diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/GeaFlowMemoryServer.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/GeaFlowMemoryServer.java index d39123183..0e8e977ce 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/GeaFlowMemoryServer.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/GeaFlowMemoryServer.java @@ -23,6 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.ai.common.util.SeDeUtil; import org.apache.geaflow.ai.graph.*; import org.apache.geaflow.ai.graph.io.*; @@ -40,193 +41,199 @@ @Controller public class GeaFlowMemoryServer { - private static final Logger LOGGER = LoggerFactory.getLogger(GeaFlowMemoryServer.class); - - private static final String SERVER_NAME = "geaflow-memory-server"; - private static final int DEFAULT_PORT = 8080; - - private static final ServerMemoryCache CACHE = new ServerMemoryCache(); - - public static void main(String[] args) { - System.setProperty("solon.app.name", SERVER_NAME); - Solon.start(GeaFlowMemoryServer.class, args, app -> { - app.cfg().loadAdd("application.yml"); - int port = app.cfg().getInt("server.port", DEFAULT_PORT); - LOGGER.info("Starting {} on port {}", SERVER_NAME, port); - app.get("/", ctx -> { + private static final Logger LOGGER = LoggerFactory.getLogger(GeaFlowMemoryServer.class); + + private static final String SERVER_NAME = "geaflow-memory-server"; + private static final int DEFAULT_PORT = 8080; + + private static final ServerMemoryCache CACHE = new ServerMemoryCache(); + + public static void main(String[] args) { + System.setProperty("solon.app.name", SERVER_NAME); + Solon.start( + GeaFlowMemoryServer.class, + args, + app -> { + app.cfg().loadAdd("application.yml"); + int port = app.cfg().getInt("server.port", DEFAULT_PORT); + LOGGER.info("Starting {} on port {}", SERVER_NAME, port); + app.get( + "/", + ctx -> { ctx.output("GeaFlow AI Server is running..."); - }); - app.get("/health", ctx -> { + }); + app.get( + "/health", + ctx -> { ctx.output("{\"status\":\"UP\",\"service\":\"" + SERVER_NAME + "\"}"); - }); + }); }); + } + + @Get + @Mapping("/api/test") + public String test() { + return "GeaFlow Memory Server is working!"; + } + + @Post + @Mapping("/graph/create") + public String createGraph(@Body String input) { + GraphSchema graphSchema = SeDeUtil.deserializeGraphSchema(input); + String graphName = graphSchema.getName(); + if (graphName == null || CACHE.getGraphByName(graphName) != null) { + throw new RuntimeException("Cannot create graph name: " + graphName); } - - @Get - @Mapping("/api/test") - public String test() { - return "GeaFlow Memory Server is working!"; - } - - @Post - @Mapping("/graph/create") - public String createGraph(@Body String input) { - GraphSchema graphSchema = SeDeUtil.deserializeGraphSchema(input); - String graphName = graphSchema.getName(); - if (graphName == null || CACHE.getGraphByName(graphName) != null) { - throw new RuntimeException("Cannot create graph name: " + graphName); - } - Map entities = new HashMap<>(); - for (VertexSchema vertexSchema : graphSchema.getVertexSchemaList()) { - entities.put(vertexSchema.getName(), new VertexGroup(vertexSchema, new ArrayList<>())); - } - for (EdgeSchema edgeSchema : graphSchema.getEdgeSchemaList()) { - entities.put(edgeSchema.getName(), new EdgeGroup(edgeSchema, new ArrayList<>())); - } - MemoryGraph graph = new MemoryGraph(graphSchema, entities); - CACHE.putGraph(graph); - LocalMemoryGraphAccessor graphAccessor = new LocalMemoryGraphAccessor(graph); - LOGGER.info("Success to init empty graph."); - - EntityAttributeIndexStore indexStore = new EntityAttributeIndexStore(); - indexStore.initStore(new SubgraphSemanticPromptFunction(graphAccessor)); - LOGGER.info("Success to init EntityAttributeIndexStore."); - - GraphMemoryServer server = new GraphMemoryServer(); - server.addGraphAccessor(graphAccessor); - server.addIndexStore(indexStore); - LOGGER.info("Success to init GraphMemoryServer."); - CACHE.putServer(server); - - LOGGER.info("Success to init graph. SCHEMA: {}", graphSchema); - return "createGraph has been called, graphName: " + graphName; - } - - @Post - @Mapping("/graph/addEntitySchema") - public String addSchema(@Param("graphName") String graphName, - @Body String input) { - Graph graph = CACHE.getGraphByName(graphName); - if (graph == null) { - throw new RuntimeException("Graph not exist."); - } - if (!(graph instanceof MemoryGraph)) { - throw new RuntimeException("Graph cannot modify."); - } - MemoryMutableGraph memoryMutableGraph = new MemoryMutableGraph((MemoryGraph) graph); - Schema schema = SeDeUtil.deserializeEntitySchema(input); - String schemaName = schema.getName(); - if (schema instanceof VertexSchema) { - memoryMutableGraph.addVertexSchema((VertexSchema) schema); - } else if (schema instanceof EdgeSchema) { - memoryMutableGraph.addEdgeSchema((EdgeSchema) schema); - } else { - throw new RuntimeException("Cannt add schema: " + input); - } - return "addSchema has been called, schemaName: " + schemaName; - } - - @Post - @Mapping("/graph/getGraphSchema") - public String getSchema(@Param("graphName") String graphName) { - Graph graph = CACHE.getGraphByName(graphName); - if (graph == null) { - throw new RuntimeException("Graph not exist."); - } - if (!(graph instanceof MemoryGraph)) { - throw new RuntimeException("Graph cannot modify."); - } - return SeDeUtil.serializeGraphSchema(graph.getGraphSchema()); - } - - @Post - @Mapping("/graph/insertEntity") - public String addEntity(@Param("graphName") String graphName, - @Body String input) { - Graph graph = CACHE.getGraphByName(graphName); - if (graph == null) { - throw new RuntimeException("Graph not exist."); - } - if (!(graph instanceof MemoryGraph)) { - throw new RuntimeException("Graph cannot modify."); - } - MemoryMutableGraph memoryMutableGraph = new MemoryMutableGraph((MemoryGraph) graph); - List graphEntities = SeDeUtil.deserializeEntities(input); - - for (GraphEntity entity : graphEntities) { - if (entity instanceof GraphVertex) { - memoryMutableGraph.addVertex(((GraphVertex) entity).getVertex()); - } else { - memoryMutableGraph.addEdge(((GraphEdge) entity).getEdge()); - } - } - CACHE.getConsolidateServer().executeConsolidateTask( + Map entities = new HashMap<>(); + for (VertexSchema vertexSchema : graphSchema.getVertexSchemaList()) { + entities.put(vertexSchema.getName(), new VertexGroup(vertexSchema, new ArrayList<>())); + } + for (EdgeSchema edgeSchema : graphSchema.getEdgeSchemaList()) { + entities.put(edgeSchema.getName(), new EdgeGroup(edgeSchema, new ArrayList<>())); + } + MemoryGraph graph = new MemoryGraph(graphSchema, entities); + CACHE.putGraph(graph); + LocalMemoryGraphAccessor graphAccessor = new LocalMemoryGraphAccessor(graph); + LOGGER.info("Success to init empty graph."); + + EntityAttributeIndexStore indexStore = new EntityAttributeIndexStore(); + indexStore.initStore(new SubgraphSemanticPromptFunction(graphAccessor)); + LOGGER.info("Success to init EntityAttributeIndexStore."); + + GraphMemoryServer server = new GraphMemoryServer(); + server.addGraphAccessor(graphAccessor); + server.addIndexStore(indexStore); + LOGGER.info("Success to init GraphMemoryServer."); + CACHE.putServer(server); + + LOGGER.info("Success to init graph. SCHEMA: {}", graphSchema); + return "createGraph has been called, graphName: " + graphName; + } + + @Post + @Mapping("/graph/addEntitySchema") + public String addSchema(@Param("graphName") String graphName, @Body String input) { + Graph graph = CACHE.getGraphByName(graphName); + if (graph == null) { + throw new RuntimeException("Graph not exist."); + } + if (!(graph instanceof MemoryGraph)) { + throw new RuntimeException("Graph cannot modify."); + } + MemoryMutableGraph memoryMutableGraph = new MemoryMutableGraph((MemoryGraph) graph); + Schema schema = SeDeUtil.deserializeEntitySchema(input); + String schemaName = schema.getName(); + if (schema instanceof VertexSchema) { + memoryMutableGraph.addVertexSchema((VertexSchema) schema); + } else if (schema instanceof EdgeSchema) { + memoryMutableGraph.addEdgeSchema((EdgeSchema) schema); + } else { + throw new RuntimeException("Cannt add schema: " + input); + } + return "addSchema has been called, schemaName: " + schemaName; + } + + @Post + @Mapping("/graph/getGraphSchema") + public String getSchema(@Param("graphName") String graphName) { + Graph graph = CACHE.getGraphByName(graphName); + if (graph == null) { + throw new RuntimeException("Graph not exist."); + } + if (!(graph instanceof MemoryGraph)) { + throw new RuntimeException("Graph cannot modify."); + } + return SeDeUtil.serializeGraphSchema(graph.getGraphSchema()); + } + + @Post + @Mapping("/graph/insertEntity") + public String addEntity(@Param("graphName") String graphName, @Body String input) { + Graph graph = CACHE.getGraphByName(graphName); + if (graph == null) { + throw new RuntimeException("Graph not exist."); + } + if (!(graph instanceof MemoryGraph)) { + throw new RuntimeException("Graph cannot modify."); + } + MemoryMutableGraph memoryMutableGraph = new MemoryMutableGraph((MemoryGraph) graph); + List graphEntities = SeDeUtil.deserializeEntities(input); + + for (GraphEntity entity : graphEntities) { + if (entity instanceof GraphVertex) { + memoryMutableGraph.addVertex(((GraphVertex) entity).getVertex()); + } else { + memoryMutableGraph.addEdge(((GraphEdge) entity).getEdge()); + } + } + CACHE + .getConsolidateServer() + .executeConsolidateTask( CACHE.getServerByName(graphName).getGraphAccessors().get(0), memoryMutableGraph); - return "Success to add entities, num: " + graphEntities.size(); - } - - @Post - @Mapping("/graph/delEntity") - public String deleteEntity(@Param("graphName") String graphName, - @Body String input) { - Graph graph = CACHE.getGraphByName(graphName); - if (graph == null) { - throw new RuntimeException("Graph not exist."); - } - if (!(graph instanceof MemoryGraph)) { - throw new RuntimeException("Graph cannot modify."); - } - MemoryMutableGraph memoryMutableGraph = new MemoryMutableGraph((MemoryGraph) graph); - List graphEntities = SeDeUtil.deserializeEntities(input); - for (GraphEntity entity : graphEntities) { - if (entity instanceof GraphVertex) { - memoryMutableGraph.removeVertex(entity.getLabel(), - ((GraphVertex) entity).getVertex().getId()); - } else { - memoryMutableGraph.removeEdge(((GraphEdge) entity).getEdge()); - } - } - return "Success to remove entities, num: " + graphEntities.size(); - } - - @Post - @Mapping("/query/context") - public String createContext(@Param("graphName") String graphName) { - GraphMemoryServer server = CACHE.getServerByName(graphName); - if (server == null) { - throw new RuntimeException("Server not exist."); - } - String sessionId = server.createSession(); - CACHE.putSession(server, sessionId); - return sessionId; - } - - @Post - @Mapping("/query/exec") - public String execQuery(@Param("sessionId") String sessionId, - @Body String query) { - String graphName = CACHE.getGraphNameBySession(sessionId); - if (graphName == null) { - throw new RuntimeException("Graph not exist."); - } - GraphMemoryServer server = CACHE.getServerByName(graphName); - VectorSearch search = new VectorSearch(null, sessionId); - search.addVector(new KeywordVector(query)); - server.search(search); - Context context = server.verbalize(sessionId, - new SubgraphSemanticPromptFunction(server.getGraphAccessors().get(0))); - return context.toString(); - } - - @Post - @Mapping("/query/result") - public String getResult(@Param("sessionId") String sessionId) { - String graphName = CACHE.getGraphNameBySession(sessionId); - if (graphName == null) { - throw new RuntimeException("Graph not exist."); - } - GraphMemoryServer server = CACHE.getServerByName(graphName); - List result = server.getSessionEntities(sessionId); - return result.toString(); + return "Success to add entities, num: " + graphEntities.size(); + } + + @Post + @Mapping("/graph/delEntity") + public String deleteEntity(@Param("graphName") String graphName, @Body String input) { + Graph graph = CACHE.getGraphByName(graphName); + if (graph == null) { + throw new RuntimeException("Graph not exist."); + } + if (!(graph instanceof MemoryGraph)) { + throw new RuntimeException("Graph cannot modify."); + } + MemoryMutableGraph memoryMutableGraph = new MemoryMutableGraph((MemoryGraph) graph); + List graphEntities = SeDeUtil.deserializeEntities(input); + for (GraphEntity entity : graphEntities) { + if (entity instanceof GraphVertex) { + memoryMutableGraph.removeVertex( + entity.getLabel(), ((GraphVertex) entity).getVertex().getId()); + } else { + memoryMutableGraph.removeEdge(((GraphEdge) entity).getEdge()); + } + } + return "Success to remove entities, num: " + graphEntities.size(); + } + + @Post + @Mapping("/query/context") + public String createContext(@Param("graphName") String graphName) { + GraphMemoryServer server = CACHE.getServerByName(graphName); + if (server == null) { + throw new RuntimeException("Server not exist."); + } + String sessionId = server.createSession(); + CACHE.putSession(server, sessionId); + return sessionId; + } + + @Post + @Mapping("/query/exec") + public String execQuery(@Param("sessionId") String sessionId, @Body String query) { + String graphName = CACHE.getGraphNameBySession(sessionId); + if (graphName == null) { + throw new RuntimeException("Graph not exist."); + } + GraphMemoryServer server = CACHE.getServerByName(graphName); + VectorSearch search = new VectorSearch(null, sessionId); + search.addVector(new KeywordVector(query)); + server.search(search); + Context context = + server.verbalize( + sessionId, new SubgraphSemanticPromptFunction(server.getGraphAccessors().get(0))); + return context.toString(); + } + + @Post + @Mapping("/query/result") + public String getResult(@Param("sessionId") String sessionId) { + String graphName = CACHE.getGraphNameBySession(sessionId); + if (graphName == null) { + throw new RuntimeException("Graph not exist."); } + GraphMemoryServer server = CACHE.getServerByName(graphName); + List result = server.getSessionEntities(sessionId); + return result.toString(); + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/GraphMemoryServer.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/GraphMemoryServer.java index b571fe540..aa00648ce 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/GraphMemoryServer.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/GraphMemoryServer.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Set; import java.util.stream.Collectors; + import org.apache.geaflow.ai.graph.GraphAccessor; import org.apache.geaflow.ai.graph.GraphEntity; import org.apache.geaflow.ai.index.EmbeddingIndexStore; @@ -40,91 +41,91 @@ public class GraphMemoryServer { - private final SessionManagement sessionManagement = new SessionManagement(); - private final List graphAccessors = new ArrayList<>(); - private final List indexStores = new ArrayList<>(); + private final SessionManagement sessionManagement = new SessionManagement(); + private final List graphAccessors = new ArrayList<>(); + private final List indexStores = new ArrayList<>(); - public void addGraphAccessor(GraphAccessor graph) { - if (graph != null) { - graphAccessors.add(graph); - } + public void addGraphAccessor(GraphAccessor graph) { + if (graph != null) { + graphAccessors.add(graph); } + } - public List getGraphAccessors() { - return graphAccessors; - } + public List getGraphAccessors() { + return graphAccessors; + } - public void addIndexStore(IndexStore indexStore) { - if (indexStore != null) { - indexStores.add(indexStore); - } + public void addIndexStore(IndexStore indexStore) { + if (indexStore != null) { + indexStores.add(indexStore); } + } - public List getIndexStores() { - return indexStores; - } + public List getIndexStores() { + return indexStores; + } - public String createSession() { - String sessionId = sessionManagement.createSession(); - if (sessionId == null) { - throw new RuntimeException("Cannot create new session"); - } - return sessionId; + public String createSession() { + String sessionId = sessionManagement.createSession(); + if (sessionId == null) { + throw new RuntimeException("Cannot create new session"); } + return sessionId; + } - public String search(VectorSearch search) { - String sessionId = search.getSessionId(); - if (sessionId == null || sessionId.isEmpty()) { - throw new RuntimeException("Session id is empty"); - } - if (!sessionManagement.sessionExists(sessionId)) { - sessionManagement.createSession(sessionId); - } - - for (IndexStore indexStore : indexStores) { - if (indexStore instanceof EntityAttributeIndexStore) { - SessionOperator searchOperator = new SessionOperator(graphAccessors.get(0), indexStore); - applySearch(sessionId, searchOperator, search); - } - if (indexStore instanceof EmbeddingIndexStore) { - EmbeddingOperator embeddingOperator = new EmbeddingOperator(graphAccessors.get(0), indexStore); - applySearch(sessionId, embeddingOperator, search); - } - } - return sessionId; + public String search(VectorSearch search) { + String sessionId = search.getSessionId(); + if (sessionId == null || sessionId.isEmpty()) { + throw new RuntimeException("Session id is empty"); + } + if (!sessionManagement.sessionExists(sessionId)) { + sessionManagement.createSession(sessionId); } - private void applySearch(String sessionId, SearchOperator operator, VectorSearch search) { - SessionManagement manager = sessionManagement; - if (!manager.sessionExists(sessionId)) { - return; - } - List result = operator.apply(manager.getSubGraph(sessionId), search); - manager.setSubGraph(sessionId, result); + for (IndexStore indexStore : indexStores) { + if (indexStore instanceof EntityAttributeIndexStore) { + SessionOperator searchOperator = new SessionOperator(graphAccessors.get(0), indexStore); + applySearch(sessionId, searchOperator, search); + } + if (indexStore instanceof EmbeddingIndexStore) { + EmbeddingOperator embeddingOperator = + new EmbeddingOperator(graphAccessors.get(0), indexStore); + applySearch(sessionId, embeddingOperator, search); + } } + return sessionId; + } - public Context verbalize(String sessionId, VerbalizationFunction verbalizationFunction) { - List subGraphList = sessionManagement.getSubGraph(sessionId); - List subGraphStringList = new ArrayList<>(subGraphList.size()); - for (SubGraph subGraph : subGraphList) { - subGraphStringList.add(verbalizationFunction.verbalize(subGraph)); - } - subGraphStringList = subGraphStringList.stream().sorted().collect(Collectors.toList()); - StringBuilder stringBuilder = new StringBuilder(); - for (String subGraph : subGraphStringList) { - stringBuilder.append(subGraph).append("\n"); - } - stringBuilder.append(verbalizationFunction.verbalizeGraphSchema()); - return new Context(stringBuilder.toString()); + private void applySearch(String sessionId, SearchOperator operator, VectorSearch search) { + SessionManagement manager = sessionManagement; + if (!manager.sessionExists(sessionId)) { + return; } + List result = operator.apply(manager.getSubGraph(sessionId), search); + manager.setSubGraph(sessionId, result); + } - public List getSessionEntities(String sessionId) { - List subGraphList = sessionManagement.getSubGraph(sessionId); - Set entitySet = new HashSet<>(); - for (SubGraph subGraph : subGraphList) { - entitySet.addAll(subGraph.getGraphEntityList()); - } - return new ArrayList<>(entitySet); + public Context verbalize(String sessionId, VerbalizationFunction verbalizationFunction) { + List subGraphList = sessionManagement.getSubGraph(sessionId); + List subGraphStringList = new ArrayList<>(subGraphList.size()); + for (SubGraph subGraph : subGraphList) { + subGraphStringList.add(verbalizationFunction.verbalize(subGraph)); + } + subGraphStringList = subGraphStringList.stream().sorted().collect(Collectors.toList()); + StringBuilder stringBuilder = new StringBuilder(); + for (String subGraph : subGraphStringList) { + stringBuilder.append(subGraph).append("\n"); } + stringBuilder.append(verbalizationFunction.verbalizeGraphSchema()); + return new Context(stringBuilder.toString()); + } + public List getSessionEntities(String sessionId) { + List subGraphList = sessionManagement.getSubGraph(sessionId); + Set entitySet = new HashSet<>(); + for (SubGraph subGraph : subGraphList) { + entitySet.addAll(subGraph.getGraphEntityList()); + } + return new ArrayList<>(entitySet); + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/client/GeaFlowMemoryClientCLI.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/client/GeaFlowMemoryClientCLI.java index 870b0eb6b..1a4753501 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/client/GeaFlowMemoryClientCLI.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/client/GeaFlowMemoryClientCLI.java @@ -19,7 +19,6 @@ package org.apache.geaflow.ai.client; -import com.google.gson.Gson; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; @@ -31,339 +30,347 @@ import java.util.*; import java.util.stream.Collectors; import java.util.stream.IntStream; + import org.apache.geaflow.ai.common.config.Constants; import org.apache.geaflow.ai.graph.io.*; -public class GeaFlowMemoryClientCLI { - - private static final String BASE_URL = "http://localhost:8080"; - private static final String SERVER_URL = BASE_URL + "/api/test"; - private static final String CREATE_URL = BASE_URL + "/graph/create"; - private static final String SCHEMA_URL = BASE_URL + "/graph/addEntitySchema"; - private static final String INSERT_URL = BASE_URL + "/graph/insertEntity"; - private static final String CONTEXT_URL = BASE_URL + "/query/context"; - private static final String EXEC_URL = BASE_URL + "/query/exec"; - - private static final String DEFAULT_GRAPH_NAME = "memory_graph"; - private static final String VERTEX_LABEL = "chunk"; - private static final String EDGE_LABEL = "relation"; - private final Scanner scanner = new Scanner(System.in); - private final Gson gson = new Gson(); - private String currentGraphName = DEFAULT_GRAPH_NAME; - private String currentSessionId = null; - - public static void main(String[] args) { - GeaFlowMemoryClientCLI client = new GeaFlowMemoryClientCLI(); - client.start(); - } - - public void start() { - printWelcome(); - - while (true) { - try { - System.out.print("\ngeaflow> "); - String input = scanner.nextLine().trim(); - - if (input.isEmpty()) { - continue; - } - - if (input.equalsIgnoreCase("exit") || input.equalsIgnoreCase("quit")) { - System.out.println("Goodbye!"); - break; - } - - if (input.equalsIgnoreCase("help")) { - printHelp(); - continue; - } +import com.google.gson.Gson; - processCommand(input); +public class GeaFlowMemoryClientCLI { - } catch (Exception e) { - System.err.println("Error: " + e.getMessage()); - if (e.getCause() != null) { - System.err.println("Cause: " + e.getCause().getMessage()); - } - } + private static final String BASE_URL = "http://localhost:8080"; + private static final String SERVER_URL = BASE_URL + "/api/test"; + private static final String CREATE_URL = BASE_URL + "/graph/create"; + private static final String SCHEMA_URL = BASE_URL + "/graph/addEntitySchema"; + private static final String INSERT_URL = BASE_URL + "/graph/insertEntity"; + private static final String CONTEXT_URL = BASE_URL + "/query/context"; + private static final String EXEC_URL = BASE_URL + "/query/exec"; + + private static final String DEFAULT_GRAPH_NAME = "memory_graph"; + private static final String VERTEX_LABEL = "chunk"; + private static final String EDGE_LABEL = "relation"; + private final Scanner scanner = new Scanner(System.in); + private final Gson gson = new Gson(); + private String currentGraphName = DEFAULT_GRAPH_NAME; + private String currentSessionId = null; + + public static void main(String[] args) { + GeaFlowMemoryClientCLI client = new GeaFlowMemoryClientCLI(); + client.start(); + } + + public void start() { + printWelcome(); + + while (true) { + try { + System.out.print("\ngeaflow> "); + String input = scanner.nextLine().trim(); + + if (input.isEmpty()) { + continue; } - scanner.close(); - } - - private void processCommand(String command) throws IOException { - String[] parts = command.split("\\s+", 2); - String cmd = parts[0].toLowerCase(); - String param = parts.length > 1 ? parts[1] : ""; - - switch (cmd) { - case "test": - testServer(); - break; - - case "use": - currentGraphName = param.isEmpty() ? DEFAULT_GRAPH_NAME : param; - break; - - case "create": - String graphName = param.isEmpty() ? DEFAULT_GRAPH_NAME : param; - createGraph(graphName); - currentGraphName = graphName; - break; - - case "remember": - if (param.isEmpty()) { - System.out.println("Please enter content to remember:"); - param = scanner.nextLine(); - } - rememberContent(param); - break; - - case "query": - if (param.isEmpty()) { - System.out.println("Please enter your query:"); - param = scanner.nextLine(); - } - executeQuery(param); - break; - - default: - System.out.println("Unknown command: " + cmd); - System.out.println("Available commands: test, create, use, remember, query, help, exit"); + if (input.equalsIgnoreCase("exit") || input.equalsIgnoreCase("quit")) { + System.out.println("Goodbye!"); + break; } - } - - private void testServer() throws IOException { - System.out.println("Testing server connection..."); - String response = sendGetRequest(SERVER_URL); - System.out.println("✓ Server response: " + response); - } - - private void createGraph(String graphName) throws IOException { - System.out.println("Creating graph: " + graphName); - - GraphSchema testGraph = new GraphSchema(); - testGraph.setName(graphName); - String graphJson = gson.toJson(testGraph); - String response = sendPostRequest(CREATE_URL, graphJson); - System.out.println("✓ Graph created: " + response); - Map params = new HashMap<>(); - params.put("graphName", graphName); - VertexSchema vertexSchema = new VertexSchema(VERTEX_LABEL, Constants.PREFIX_ID, - Collections.singletonList("text")); - response = sendPostRequest(SCHEMA_URL, gson.toJson(vertexSchema), params); - System.out.println("✓ Chunk schema added: " + response); + if (input.equalsIgnoreCase("help")) { + printHelp(); + continue; + } - EdgeSchema edgeSchema = new EdgeSchema(EDGE_LABEL, Constants.PREFIX_SRC_ID, Constants.PREFIX_DST_ID, - Collections.singletonList("rel")); - response = sendPostRequest(SCHEMA_URL, gson.toJson(edgeSchema), params); - System.out.println("✓ Relation schema added: " + response); + processCommand(input); - System.out.println("✓ Graph '" + graphName + "' is ready for use!"); + } catch (Exception e) { + System.err.println("Error: " + e.getMessage()); + if (e.getCause() != null) { + System.err.println("Cause: " + e.getCause().getMessage()); + } + } } - private void rememberContent(String content) throws IOException { - if (currentGraphName == null) { - System.out.println("No graph selected. Please create a graph first."); - return; + scanner.close(); + } + + private void processCommand(String command) throws IOException { + String[] parts = command.split("\\s+", 2); + String cmd = parts[0].toLowerCase(); + String param = parts.length > 1 ? parts[1] : ""; + + switch (cmd) { + case "test": + testServer(); + break; + + case "use": + currentGraphName = param.isEmpty() ? DEFAULT_GRAPH_NAME : param; + break; + + case "create": + String graphName = param.isEmpty() ? DEFAULT_GRAPH_NAME : param; + createGraph(graphName); + currentGraphName = graphName; + break; + + case "remember": + if (param.isEmpty()) { + System.out.println("Please enter content to remember:"); + param = scanner.nextLine(); } + rememberContent(param); + break; - if (content.trim().toLowerCase(Locale.ROOT).startsWith("doc")) { - String path = content.trim().substring(3).trim(); - - TextFileReader textFileReader = new TextFileReader(10000); - textFileReader.readFile(path); - List chunks = IntStream.range(0, textFileReader.getRowCount()) - .mapToObj(textFileReader::getRow) - .map(String::trim).collect(Collectors.toList()); - for (String chunk : chunks) { - String response = rememberChunk(chunk); - System.out.println("✓ Content remembered: " + response); - } + case "query": + if (param.isEmpty()) { + System.out.println("Please enter your query:"); + param = scanner.nextLine(); } + executeQuery(param); + break; - System.out.println("Remembering content..."); - String response = rememberChunk(content); - System.out.println("✓ Content remembered: " + response); - + default: + System.out.println("Unknown command: " + cmd); + System.out.println("Available commands: test, create, use, remember, query, help, exit"); } + } + + private void testServer() throws IOException { + System.out.println("Testing server connection..."); + String response = sendGetRequest(SERVER_URL); + System.out.println("✓ Server response: " + response); + } + + private void createGraph(String graphName) throws IOException { + System.out.println("Creating graph: " + graphName); + + GraphSchema testGraph = new GraphSchema(); + testGraph.setName(graphName); + String graphJson = gson.toJson(testGraph); + String response = sendPostRequest(CREATE_URL, graphJson); + System.out.println("✓ Graph created: " + response); + + Map params = new HashMap<>(); + params.put("graphName", graphName); + VertexSchema vertexSchema = + new VertexSchema(VERTEX_LABEL, Constants.PREFIX_ID, Collections.singletonList("text")); + response = sendPostRequest(SCHEMA_URL, gson.toJson(vertexSchema), params); + System.out.println("✓ Chunk schema added: " + response); + + EdgeSchema edgeSchema = + new EdgeSchema( + EDGE_LABEL, + Constants.PREFIX_SRC_ID, + Constants.PREFIX_DST_ID, + Collections.singletonList("rel")); + response = sendPostRequest(SCHEMA_URL, gson.toJson(edgeSchema), params); + System.out.println("✓ Relation schema added: " + response); + System.out.println("✓ Graph '" + graphName + "' is ready for use!"); + } - private String rememberChunk(String content) throws IOException { - String vertexId = "chunk_" + System.currentTimeMillis() + "_" + Math.abs(content.hashCode()); - Vertex chunkVertex = new Vertex("chunk", vertexId, Collections.singletonList(content)); - String vertexJson = gson.toJson(chunkVertex); - - Map params = new HashMap<>(); - params.put("graphName", currentGraphName); + private void rememberContent(String content) throws IOException { + if (currentGraphName == null) { + System.out.println("No graph selected. Please create a graph first."); + return; + } - return sendPostRequest(INSERT_URL, vertexJson, params); + if (content.trim().toLowerCase(Locale.ROOT).startsWith("doc")) { + String path = content.trim().substring(3).trim(); + + TextFileReader textFileReader = new TextFileReader(10000); + textFileReader.readFile(path); + List chunks = + IntStream.range(0, textFileReader.getRowCount()) + .mapToObj(textFileReader::getRow) + .map(String::trim) + .collect(Collectors.toList()); + for (String chunk : chunks) { + String response = rememberChunk(chunk); + System.out.println("✓ Content remembered: " + response); + } } - private void executeQuery(String query) throws IOException { - if (currentGraphName == null) { - System.out.println("No graph selected. Please create a graph first."); - return; - } + System.out.println("Remembering content..."); + String response = rememberChunk(content); + System.out.println("✓ Content remembered: " + response); + } - System.out.println("Creating new session..."); - Map params = new HashMap<>(); - params.put("graphName", currentGraphName); - String response = sendPostRequest(CONTEXT_URL, "", params); - currentSessionId = response.trim(); - System.out.println("✓ Session created: " + currentSessionId); + private String rememberChunk(String content) throws IOException { + String vertexId = "chunk_" + System.currentTimeMillis() + "_" + Math.abs(content.hashCode()); + Vertex chunkVertex = new Vertex("chunk", vertexId, Collections.singletonList(content)); + String vertexJson = gson.toJson(chunkVertex); - System.out.println("Executing query: " + query); + Map params = new HashMap<>(); + params.put("graphName", currentGraphName); - params = new HashMap<>(); - params.put("sessionId", currentSessionId); + return sendPostRequest(INSERT_URL, vertexJson, params); + } - response = sendPostRequest(EXEC_URL, query, params); - System.out.println("✓ Query result:"); - System.out.println("========================"); - System.out.println(response); - System.out.println("========================"); + private void executeQuery(String query) throws IOException { + if (currentGraphName == null) { + System.out.println("No graph selected. Please create a graph first."); + return; } - private String sendGetRequest(String urlStr) throws IOException { - URL url = new URL(urlStr); - HttpURLConnection conn = (HttpURLConnection) url.openConnection(); - conn.setRequestMethod("GET"); - conn.setRequestProperty("Accept", "application/json"); - - int responseCode = conn.getResponseCode(); - if (responseCode != HttpURLConnection.HTTP_OK) { - throw new IOException("HTTP error code: " + responseCode); - } - - try (BufferedReader br = new BufferedReader( - new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8))) { - StringBuilder response = new StringBuilder(); - String responseLine; - while ((responseLine = br.readLine()) != null) { - response.append(responseLine.trim()); - } - return response.toString(); - } + System.out.println("Creating new session..."); + Map params = new HashMap<>(); + params.put("graphName", currentGraphName); + String response = sendPostRequest(CONTEXT_URL, "", params); + currentSessionId = response.trim(); + System.out.println("✓ Session created: " + currentSessionId); + + System.out.println("Executing query: " + query); + + params = new HashMap<>(); + params.put("sessionId", currentSessionId); + + response = sendPostRequest(EXEC_URL, query, params); + System.out.println("✓ Query result:"); + System.out.println("========================"); + System.out.println(response); + System.out.println("========================"); + } + + private String sendGetRequest(String urlStr) throws IOException { + URL url = new URL(urlStr); + HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + conn.setRequestMethod("GET"); + conn.setRequestProperty("Accept", "application/json"); + + int responseCode = conn.getResponseCode(); + if (responseCode != HttpURLConnection.HTTP_OK) { + throw new IOException("HTTP error code: " + responseCode); } - private String sendPostRequest(String urlStr, String body) throws IOException { - return sendPostRequest(urlStr, body, Collections.emptyMap()); + try (BufferedReader br = + new BufferedReader(new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8))) { + StringBuilder response = new StringBuilder(); + String responseLine; + while ((responseLine = br.readLine()) != null) { + response.append(responseLine.trim()); + } + return response.toString(); } - - private String sendPostRequest(String urlStr, String body, Map queryParams) throws IOException { - if (!queryParams.isEmpty()) { - StringBuilder urlBuilder = new StringBuilder(urlStr); - urlBuilder.append("?"); - boolean first = true; - for (Map.Entry entry : queryParams.entrySet()) { - if (!first) { - urlBuilder.append("&"); - } - urlBuilder.append(URLEncoder.encode(entry.getKey(), "UTF-8")); - urlBuilder.append("="); - urlBuilder.append(URLEncoder.encode(entry.getValue(), "UTF-8")); - first = false; - } - urlStr = urlBuilder.toString(); - } - - URL url = new URL(urlStr); - HttpURLConnection conn = (HttpURLConnection) url.openConnection(); - conn.setRequestMethod("POST"); - conn.setRequestProperty("Content-Type", "application/json"); - conn.setRequestProperty("Accept", "application/json"); - conn.setDoOutput(true); - - try (OutputStream os = conn.getOutputStream()) { - byte[] input = body.getBytes(StandardCharsets.UTF_8); - os.write(input, 0, input.length); + } + + private String sendPostRequest(String urlStr, String body) throws IOException { + return sendPostRequest(urlStr, body, Collections.emptyMap()); + } + + private String sendPostRequest(String urlStr, String body, Map queryParams) + throws IOException { + if (!queryParams.isEmpty()) { + StringBuilder urlBuilder = new StringBuilder(urlStr); + urlBuilder.append("?"); + boolean first = true; + for (Map.Entry entry : queryParams.entrySet()) { + if (!first) { + urlBuilder.append("&"); } + urlBuilder.append(URLEncoder.encode(entry.getKey(), "UTF-8")); + urlBuilder.append("="); + urlBuilder.append(URLEncoder.encode(entry.getValue(), "UTF-8")); + first = false; + } + urlStr = urlBuilder.toString(); + } - int responseCode = conn.getResponseCode(); - if (responseCode != HttpURLConnection.HTTP_OK) { - String errorMessage = readErrorResponse(conn); - throw new IOException("HTTP error code: " + responseCode + ", Message: " + errorMessage); - } + URL url = new URL(urlStr); + HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + conn.setRequestMethod("POST"); + conn.setRequestProperty("Content-Type", "application/json"); + conn.setRequestProperty("Accept", "application/json"); + conn.setDoOutput(true); - try (BufferedReader br = new BufferedReader( - new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8))) { - StringBuilder response = new StringBuilder(); - String responseLine; - while ((responseLine = br.readLine()) != null) { - response.append(responseLine); - } - return response.toString(); - } + try (OutputStream os = conn.getOutputStream()) { + byte[] input = body.getBytes(StandardCharsets.UTF_8); + os.write(input, 0, input.length); } - private String readErrorResponse(HttpURLConnection conn) throws IOException { - try (BufferedReader br = new BufferedReader( - new InputStreamReader(conn.getErrorStream(), StandardCharsets.UTF_8))) { - StringBuilder response = new StringBuilder(); - String responseLine; - while ((responseLine = br.readLine()) != null) { - response.append(responseLine); - } - return response.toString(); - } catch (Exception e) { - return "No error message available"; - } + int responseCode = conn.getResponseCode(); + if (responseCode != HttpURLConnection.HTTP_OK) { + String errorMessage = readErrorResponse(conn); + throw new IOException("HTTP error code: " + responseCode + ", Message: " + errorMessage); } - private void printWelcome() { - System.out.println("========================================="); - System.out.println(" GeaFlow Memory Server - Simple Client"); - System.out.println("========================================="); - System.out.println("Simple Commands:"); - System.out.println(" test - Test server connection"); - System.out.println(" create [name] - Create a new memory graph"); - System.out.println(" use [name] - Use a new memory graph"); - System.out.println(" remember - Store content to memory"); - System.out.println(" query - Ask questions about memory"); - System.out.println(" help - Show this help"); - System.out.println(" exit - Quit the client"); - System.out.println("========================================="); - System.out.println("Default graph name: " + DEFAULT_GRAPH_NAME); - System.out.println("Server URL: " + BASE_URL); - System.out.println("========================================="); + try (BufferedReader br = + new BufferedReader(new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8))) { + StringBuilder response = new StringBuilder(); + String responseLine; + while ((responseLine = br.readLine()) != null) { + response.append(responseLine); + } + return response.toString(); } - - private void printHelp() { - System.out.println("\nAvailable Commands:"); - System.out.println("-------------------"); - System.out.println("test"); - System.out.println(" Test if the GeaFlow server is running"); - System.out.println(" Example: test"); - System.out.println(); - System.out.println("create [graph_name]"); - System.out.println(" Create a new memory graph with default schema"); - System.out.println(" Creates: chunk vertices and relation edges"); - System.out.println(" Default name: " + DEFAULT_GRAPH_NAME); - System.out.println(" Example: create"); - System.out.println(" Example: create my_memory"); - System.out.println(); - System.out.println("use [graph_name]"); - System.out.println(" Use a new memory graph"); - System.out.println(" Default name: " + DEFAULT_GRAPH_NAME); - System.out.println(" Example: use my_memory"); - System.out.println(); - System.out.println("remember "); - System.out.println(" Store text content into memory"); - System.out.println(" Creates a 'chunk' vertex with the content"); - System.out.println(" Example: remember \"孔子是中国古代的思想家\""); - System.out.println(" Example: remember"); - System.out.println(" (will prompt for content)"); - System.out.println(); - System.out.println("query "); - System.out.println(" Query the memory with natural language"); - System.out.println(" Example: query \"Who is Confucius?\""); - System.out.println(" Example: query"); - System.out.println(" (will prompt for question)"); - System.out.println(); - System.out.println("exit / quit"); - System.out.println(" Exit the client"); + } + + private String readErrorResponse(HttpURLConnection conn) throws IOException { + try (BufferedReader br = + new BufferedReader(new InputStreamReader(conn.getErrorStream(), StandardCharsets.UTF_8))) { + StringBuilder response = new StringBuilder(); + String responseLine; + while ((responseLine = br.readLine()) != null) { + response.append(responseLine); + } + return response.toString(); + } catch (Exception e) { + return "No error message available"; } + } + + private void printWelcome() { + System.out.println("========================================="); + System.out.println(" GeaFlow Memory Server - Simple Client"); + System.out.println("========================================="); + System.out.println("Simple Commands:"); + System.out.println(" test - Test server connection"); + System.out.println(" create [name] - Create a new memory graph"); + System.out.println(" use [name] - Use a new memory graph"); + System.out.println(" remember - Store content to memory"); + System.out.println(" query - Ask questions about memory"); + System.out.println(" help - Show this help"); + System.out.println(" exit - Quit the client"); + System.out.println("========================================="); + System.out.println("Default graph name: " + DEFAULT_GRAPH_NAME); + System.out.println("Server URL: " + BASE_URL); + System.out.println("========================================="); + } + + private void printHelp() { + System.out.println("\nAvailable Commands:"); + System.out.println("-------------------"); + System.out.println("test"); + System.out.println(" Test if the GeaFlow server is running"); + System.out.println(" Example: test"); + System.out.println(); + System.out.println("create [graph_name]"); + System.out.println(" Create a new memory graph with default schema"); + System.out.println(" Creates: chunk vertices and relation edges"); + System.out.println(" Default name: " + DEFAULT_GRAPH_NAME); + System.out.println(" Example: create"); + System.out.println(" Example: create my_memory"); + System.out.println(); + System.out.println("use [graph_name]"); + System.out.println(" Use a new memory graph"); + System.out.println(" Default name: " + DEFAULT_GRAPH_NAME); + System.out.println(" Example: use my_memory"); + System.out.println(); + System.out.println("remember "); + System.out.println(" Store text content into memory"); + System.out.println(" Creates a 'chunk' vertex with the content"); + System.out.println(" Example: remember \"孔子是中国古代的思想家\""); + System.out.println(" Example: remember"); + System.out.println(" (will prompt for content)"); + System.out.println(); + System.out.println("query "); + System.out.println(" Query the memory with natural language"); + System.out.println(" Example: query \"Who is Confucius?\""); + System.out.println(" Example: query"); + System.out.println(" (will prompt for question)"); + System.out.println(); + System.out.println("exit / quit"); + System.out.println(" Exit the client"); + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/ErrorCode.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/ErrorCode.java index 41eed3f08..0727e1da3 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/ErrorCode.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/ErrorCode.java @@ -21,12 +21,12 @@ public class ErrorCode { - public static final int SUCCESS = 0; - public static final int GRAPH_ENTITY_GROUP_NOT_EXISTS = 100001; - public static final int GRAPH_ENTITY_GROUP_NOT_MATCH = 100002; - public static final int GRAPH_ENTITY_GROUP_INSERT_FAILED = 100003; - public static final int GRAPH_ENTITY_GROUP_UPDATE_FAILED = 100004; - public static final int GRAPH_ENTITY_GROUP_REMOVE_FAILED = 100005; - public static final int GRAPH_ADD_VERTEX_SCHEMA_FAILED = 100006; - public static final int GRAPH_ADD_EDGE_SCHEMA_FAILED = 100007; + public static final int SUCCESS = 0; + public static final int GRAPH_ENTITY_GROUP_NOT_EXISTS = 100001; + public static final int GRAPH_ENTITY_GROUP_NOT_MATCH = 100002; + public static final int GRAPH_ENTITY_GROUP_INSERT_FAILED = 100003; + public static final int GRAPH_ENTITY_GROUP_UPDATE_FAILED = 100004; + public static final int GRAPH_ENTITY_GROUP_REMOVE_FAILED = 100005; + public static final int GRAPH_ADD_VERTEX_SCHEMA_FAILED = 100006; + public static final int GRAPH_ADD_EDGE_SCHEMA_FAILED = 100007; } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/config/Constants.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/config/Constants.java index ef2646035..3c5167d8c 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/config/Constants.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/config/Constants.java @@ -21,32 +21,32 @@ public class Constants { - public static String MODEL_CONTEXT_ROLE_USER = "user"; - public static String PREFIX_V = "V"; - public static String PREFIX_E = "E"; - public static String PREFIX_GRAPH = "GRAPH"; - public static String PREFIX_TMP_SESSION = "TmpSession-"; - public static String PREFIX_SRC_ID = "srcId"; - public static String PREFIX_DST_ID = "dstId"; - public static String PREFIX_ID = "id"; - - public static int HTTP_CALL_TIMEOUT_SECONDS = 300; - public static int HTTP_CONNECT_TIMEOUT_SECONDS = 300; - public static int HTTP_READ_TIMEOUT_SECONDS = 300; - public static int HTTP_WRITE_TIMEOUT_SECONDS = 300; - - public static int MODEL_CLIENT_RETRY_TIMES = 10; - public static int MODEL_CLIENT_RETRY_INTERVAL_MS = 3000; - - public static int EMBEDDING_INDEX_STORE_BATCH_SIZE = 32; - public static int EMBEDDING_INDEX_STORE_REPORT_SIZE = 100; - public static int EMBEDDING_INDEX_STORE_FLUSH_WRITE_SIZE = 1024; - public static int EMBEDDING_INDEX_STORE_SPLIT_TEXT_CHUNK_SIZE = 128; - - public static double EMBEDDING_OPERATE_DEFAULT_THRESHOLD = 0.5; - public static int EMBEDDING_OPERATE_DEFAULT_TOPN = 50; - public static int GRAPH_SEARCH_STORE_DEFAULT_TOPN = 30; - - public static String CONSOLIDATE_KEYWORD_RELATION_LABEL = "consolidate_keyword_edge"; - public static String PREFIX_COMMON_KEYWORDS = "common_keywords"; + public static String MODEL_CONTEXT_ROLE_USER = "user"; + public static String PREFIX_V = "V"; + public static String PREFIX_E = "E"; + public static String PREFIX_GRAPH = "GRAPH"; + public static String PREFIX_TMP_SESSION = "TmpSession-"; + public static String PREFIX_SRC_ID = "srcId"; + public static String PREFIX_DST_ID = "dstId"; + public static String PREFIX_ID = "id"; + + public static int HTTP_CALL_TIMEOUT_SECONDS = 300; + public static int HTTP_CONNECT_TIMEOUT_SECONDS = 300; + public static int HTTP_READ_TIMEOUT_SECONDS = 300; + public static int HTTP_WRITE_TIMEOUT_SECONDS = 300; + + public static int MODEL_CLIENT_RETRY_TIMES = 10; + public static int MODEL_CLIENT_RETRY_INTERVAL_MS = 3000; + + public static int EMBEDDING_INDEX_STORE_BATCH_SIZE = 32; + public static int EMBEDDING_INDEX_STORE_REPORT_SIZE = 100; + public static int EMBEDDING_INDEX_STORE_FLUSH_WRITE_SIZE = 1024; + public static int EMBEDDING_INDEX_STORE_SPLIT_TEXT_CHUNK_SIZE = 128; + + public static double EMBEDDING_OPERATE_DEFAULT_THRESHOLD = 0.5; + public static int EMBEDDING_OPERATE_DEFAULT_TOPN = 50; + public static int GRAPH_SEARCH_STORE_DEFAULT_TOPN = 30; + + public static String CONSOLIDATE_KEYWORD_RELATION_LABEL = "consolidate_keyword_edge"; + public static String PREFIX_COMMON_KEYWORDS = "common_keywords"; } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/config/GraphMemoryConfigKeys.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/config/GraphMemoryConfigKeys.java index 5f288787d..c611b070d 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/config/GraphMemoryConfigKeys.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/config/GraphMemoryConfigKeys.java @@ -19,5 +19,4 @@ package org.apache.geaflow.ai.common.config; -public class GraphMemoryConfigKeys { -} +public class GraphMemoryConfigKeys {} diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/AbstractModelService.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/AbstractModelService.java index a2ffd2b49..7224eb2c9 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/AbstractModelService.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/AbstractModelService.java @@ -21,17 +21,17 @@ public abstract class AbstractModelService { - public ModelConfig getModelConfig() { - return modelConfig; - } + public ModelConfig getModelConfig() { + return modelConfig; + } - public void setModelConfig(ModelConfig modelConfig) { - this.modelConfig = modelConfig; - } + public void setModelConfig(ModelConfig modelConfig) { + this.modelConfig = modelConfig; + } - private ModelConfig modelConfig; + private ModelConfig modelConfig; - public AbstractModelService(ModelConfig modelConfig) { - this.modelConfig = modelConfig; - } + public AbstractModelService(ModelConfig modelConfig) { + this.modelConfig = modelConfig; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ChatService.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ChatService.java index 4101a6231..4e1984e18 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ChatService.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ChatService.java @@ -21,20 +21,20 @@ public class ChatService extends AbstractModelService { - public ChatService() { - super(new ModelConfig()); - } + public ChatService() { + super(new ModelConfig()); + } - public ChatService(String model) { - super(new ModelConfig()); - getModelConfig().setModel(model); - } + public ChatService(String model) { + super(new ModelConfig()); + getModelConfig().setModel(model); + } - public String chat(String sentence) { - RemoteModelClient model = new RemoteModelClient(); - ModelContext context = ModelContext.emptyContext(); - context.setModelInfo(getModelConfig()); - context.userSay(sentence); - return model.chat(context); - } + public String chat(String sentence) { + RemoteModelClient model = new RemoteModelClient(); + ModelContext context = ModelContext.emptyContext(); + context.setModelInfo(getModelConfig()); + context.userSay(sentence); + return model.chat(context); + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/EmbeddingResponse.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/EmbeddingResponse.java index d8bcd9a34..6be1fed5b 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/EmbeddingResponse.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/EmbeddingResponse.java @@ -19,29 +19,30 @@ package org.apache.geaflow.ai.common.model; -import com.google.gson.annotations.SerializedName; import java.util.List; +import com.google.gson.annotations.SerializedName; + public class EmbeddingResponse { - public List data; - public String model; - public String object; - public Response.Usage usage; + public List data; + public String model; + public String object; + public Response.Usage usage; - public static class EmbeddingVector { - public int index; - public double[] embedding; - } + public static class EmbeddingVector { + public int index; + public double[] embedding; + } - public static class Usage { - @SerializedName("completion_tokens") - public int completionTokens; - @SerializedName("prompt_tokens") - public int promptTokens; - @SerializedName("total_tokens") - public int totalTokens; + public static class Usage { + @SerializedName("completion_tokens") + public int completionTokens; - } + @SerializedName("prompt_tokens") + public int promptTokens; -} \ No newline at end of file + @SerializedName("total_tokens") + public int totalTokens; + } +} diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/EmbeddingService.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/EmbeddingService.java index 2ad134762..904ab0c9e 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/EmbeddingService.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/EmbeddingService.java @@ -19,41 +19,42 @@ package org.apache.geaflow.ai.common.model; -import com.google.gson.Gson; import java.util.List; -public class EmbeddingService extends AbstractModelService { - - public EmbeddingService() { - super(new ModelConfig()); - } +import com.google.gson.Gson; - public EmbeddingService(String model) { - super(new ModelConfig()); - getModelConfig().setModel(model); - } +public class EmbeddingService extends AbstractModelService { - public String embedding(String... inputs) { - RemoteModelClient model = new RemoteModelClient(); - ModelEmbedding context = ModelEmbedding.embedding(getModelConfig(), inputs); - List embeddingResults = model.embedding(context); - Gson gson = new Gson(); - StringBuilder builder = new StringBuilder(); - for (EmbeddingResult result : embeddingResults) { - builder.append("\n"); - String json = gson.toJson(result); - builder.append(json); - } - return builder.toString(); + public EmbeddingService() { + super(new ModelConfig()); + } + + public EmbeddingService(String model) { + super(new ModelConfig()); + getModelConfig().setModel(model); + } + + public String embedding(String... inputs) { + RemoteModelClient model = new RemoteModelClient(); + ModelEmbedding context = ModelEmbedding.embedding(getModelConfig(), inputs); + List embeddingResults = model.embedding(context); + Gson gson = new Gson(); + StringBuilder builder = new StringBuilder(); + for (EmbeddingResult result : embeddingResults) { + builder.append("\n"); + String json = gson.toJson(result); + builder.append(json); } + return builder.toString(); + } - public static class EmbeddingResult { - public String input; - public double[] embedding; + public static class EmbeddingResult { + public String input; + public double[] embedding; - public EmbeddingResult(String input, double[] embedding) { - this.input = input; - this.embedding = embedding; - } + public EmbeddingResult(String input, double[] embedding) { + this.input = input; + this.embedding = embedding; } + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ModelConfig.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ModelConfig.java index 55ab1fcc1..ed57600e7 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ModelConfig.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ModelConfig.java @@ -21,50 +21,49 @@ public class ModelConfig { - private String model; - private String url; - private String api; - private String userToken; + private String model; + private String url; + private String api; + private String userToken; - public ModelConfig() { - } + public ModelConfig() {} - public ModelConfig(String model, String url, String api, String userToken) { - this.model = model; - this.url = url; - this.api = api; - this.userToken = userToken; - } + public ModelConfig(String model, String url, String api, String userToken) { + this.model = model; + this.url = url; + this.api = api; + this.userToken = userToken; + } - public String getModel() { - return model; - } + public String getModel() { + return model; + } - public void setModel(String model) { - this.model = model; - } + public void setModel(String model) { + this.model = model; + } - public String getUserToken() { - return userToken; - } + public String getUserToken() { + return userToken; + } - public void setUserToken(String userToken) { - this.userToken = userToken; - } + public void setUserToken(String userToken) { + this.userToken = userToken; + } - public String getUrl() { - return url; - } + public String getUrl() { + return url; + } - public void setUrl(String url) { - this.url = url; - } + public void setUrl(String url) { + this.url = url; + } - public String getApi() { - return api; - } + public String getApi() { + return api; + } - public void setApi(String api) { - this.api = api; - } + public void setApi(String api) { + this.api = api; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ModelContext.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ModelContext.java index 94dfc7952..ec69d8556 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ModelContext.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ModelContext.java @@ -23,51 +23,49 @@ public class ModelContext { - public String model; - public ModelMessage[] messages; - private ModelConfig modelInfo; + public String model; + public ModelMessage[] messages; + private ModelConfig modelInfo; - public ModelContext(ModelConfig modelInfo, ModelMessage[] messages) { - this.modelInfo = modelInfo; - this.messages = messages; - if (this.modelInfo != null) { - this.model = this.modelInfo.getModel(); - } + public ModelContext(ModelConfig modelInfo, ModelMessage[] messages) { + this.modelInfo = modelInfo; + this.messages = messages; + if (this.modelInfo != null) { + this.model = this.modelInfo.getModel(); } + } - public static ModelContext emptyContext() { - return new ModelContext(null, new ModelMessage[0]); - } - - public ModelConfig getModelInfo() { - return modelInfo; - } + public static ModelContext emptyContext() { + return new ModelContext(null, new ModelMessage[0]); + } - public void setModelInfo(ModelConfig modelInfo) { - this.modelInfo = modelInfo; - if (this.modelInfo != null) { - this.model = this.modelInfo.getModel(); - } - } + public ModelConfig getModelInfo() { + return modelInfo; + } - public ModelContext userSay(String sentence) { - ModelMessage msg = new ModelMessage(Constants.MODEL_CONTEXT_ROLE_USER, sentence); - ModelMessage[] newMessages = new ModelMessage[messages.length + 1]; - for (int i = 0; i < messages.length; i++) { - newMessages[i] = messages[i]; - } - newMessages[messages.length] = msg; - messages = newMessages; - return this; + public void setModelInfo(ModelConfig modelInfo) { + this.modelInfo = modelInfo; + if (this.modelInfo != null) { + this.model = this.modelInfo.getModel(); } + } - - public ModelMessage[] getMessages() { - return messages; + public ModelContext userSay(String sentence) { + ModelMessage msg = new ModelMessage(Constants.MODEL_CONTEXT_ROLE_USER, sentence); + ModelMessage[] newMessages = new ModelMessage[messages.length + 1]; + for (int i = 0; i < messages.length; i++) { + newMessages[i] = messages[i]; } + newMessages[messages.length] = msg; + messages = newMessages; + return this; + } - public void setMessages(ModelMessage[] messages) { - this.messages = messages; - } + public ModelMessage[] getMessages() { + return messages; + } + public void setMessages(ModelMessage[] messages) { + this.messages = messages; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ModelEmbedding.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ModelEmbedding.java index 6cfe4a428..5f135ec0f 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ModelEmbedding.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ModelEmbedding.java @@ -21,38 +21,38 @@ public class ModelEmbedding { - public String model; - public String[] input; - private ModelConfig modelInfo; - - public ModelEmbedding(ModelConfig modelInfo, String[] inputs) { - this.modelInfo = modelInfo; - this.input = inputs; - if (this.modelInfo != null) { - this.model = this.modelInfo.getModel(); - } + public String model; + public String[] input; + private ModelConfig modelInfo; + + public ModelEmbedding(ModelConfig modelInfo, String[] inputs) { + this.modelInfo = modelInfo; + this.input = inputs; + if (this.modelInfo != null) { + this.model = this.modelInfo.getModel(); } + } - public static ModelEmbedding embedding(ModelConfig modelInfo, String... strings) { - return new ModelEmbedding(modelInfo, strings); - } + public static ModelEmbedding embedding(ModelConfig modelInfo, String... strings) { + return new ModelEmbedding(modelInfo, strings); + } - public ModelConfig getModelInfo() { - return modelInfo; - } + public ModelConfig getModelInfo() { + return modelInfo; + } - public void setModelInfo(ModelConfig modelInfo) { - this.modelInfo = modelInfo; - if (this.modelInfo != null) { - this.model = this.modelInfo.getModel(); - } + public void setModelInfo(ModelConfig modelInfo) { + this.modelInfo = modelInfo; + if (this.modelInfo != null) { + this.model = this.modelInfo.getModel(); } + } - public String getModel() { - return model; - } + public String getModel() { + return model; + } - public void setModel(String model) { - this.model = model; - } + public void setModel(String model) { + this.model = model; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ModelMessage.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ModelMessage.java index 417fce017..38faabaec 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ModelMessage.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ModelMessage.java @@ -21,11 +21,11 @@ public class ModelMessage { - public String role; - public String content; + public String role; + public String content; - public ModelMessage(String role, String content) { - this.role = role; - this.content = content; - } + public ModelMessage(String role, String content) { + this.role = role; + this.content = content; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ModelUtils.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ModelUtils.java index 91a27d5e7..05a899450 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ModelUtils.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/ModelUtils.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.ai.common.config.Constants; import org.apache.geaflow.ai.graph.GraphEdge; import org.apache.geaflow.ai.graph.GraphEntity; @@ -28,24 +29,26 @@ public class ModelUtils { - public static List splitLongText(int maxChunkSize, String... textList) { - List chunks = new ArrayList<>(); - for (String text : textList) { - for (int i = 0; i < text.length(); i += maxChunkSize) { - int end = Math.min(i + maxChunkSize, text.length()); - chunks.add(text.substring(i, end)); - } - } - return chunks; + public static List splitLongText(int maxChunkSize, String... textList) { + List chunks = new ArrayList<>(); + for (String text : textList) { + for (int i = 0; i < text.length(); i += maxChunkSize) { + int end = Math.min(i + maxChunkSize, text.length()); + chunks.add(text.substring(i, end)); + } } + return chunks; + } - public static String getGraphEntityKey(GraphEntity entity) { - if (entity instanceof GraphVertex) { - return Constants.PREFIX_V + ((GraphVertex) entity).getVertex().getId() + entity.getLabel(); - } else if (entity instanceof GraphEdge) { - return Constants.PREFIX_E + ((GraphEdge) entity).getEdge().getSrcId() - + entity.getLabel() + ((GraphEdge) entity).getEdge().getDstId(); - } - return ""; + public static String getGraphEntityKey(GraphEntity entity) { + if (entity instanceof GraphVertex) { + return Constants.PREFIX_V + ((GraphVertex) entity).getVertex().getId() + entity.getLabel(); + } else if (entity instanceof GraphEdge) { + return Constants.PREFIX_E + + ((GraphEdge) entity).getEdge().getSrcId() + + entity.getLabel() + + ((GraphEdge) entity).getEdge().getDstId(); } + return ""; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/OkHttpDirectConnector.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/OkHttpDirectConnector.java index b7d1d3d23..94bd1a8c4 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/OkHttpDirectConnector.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/OkHttpDirectConnector.java @@ -19,102 +19,101 @@ package org.apache.geaflow.ai.common.model; -import com.google.gson.Gson; import java.io.IOException; import java.util.Objects; import java.util.concurrent.TimeUnit; + +import org.apache.geaflow.ai.common.config.Constants; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.gson.Gson; + import okhttp3.MediaType; import okhttp3.OkHttpClient; import okhttp3.Request; import okhttp3.RequestBody; -import org.apache.geaflow.ai.common.config.Constants; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class OkHttpDirectConnector { - private static final Logger LOGGER = LoggerFactory.getLogger(OkHttpDirectConnector.class); - - private static final Gson GSON = new Gson(); - private static OkHttpClient client; - - private final String endpoint; - private final String useApi; - private final String userToken; - - public OkHttpDirectConnector(String endpoint, String useApi, String userToken) { - this.endpoint = endpoint; - this.useApi = useApi; - this.userToken = userToken; - if (client == null) { - OkHttpClient.Builder builder = new OkHttpClient.Builder(); - builder.callTimeout(Constants.HTTP_CALL_TIMEOUT_SECONDS, TimeUnit.SECONDS); - builder.connectTimeout(Constants.HTTP_CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS); - builder.readTimeout(Constants.HTTP_READ_TIMEOUT_SECONDS, TimeUnit.SECONDS); - builder.writeTimeout(Constants.HTTP_WRITE_TIMEOUT_SECONDS, TimeUnit.SECONDS); - client = builder.build(); - } - } + private static final Logger LOGGER = LoggerFactory.getLogger(OkHttpDirectConnector.class); + + private static final Gson GSON = new Gson(); + private static OkHttpClient client; + + private final String endpoint; + private final String useApi; + private final String userToken; - public org.apache.geaflow.ai.common.model.Response post(String bodyJson) { - RequestBody requestBody = RequestBody.create( - MediaType.parse("application/json; charset=utf-8"), - bodyJson - ); - - String url = endpoint + useApi; - LOGGER.info(url); - Request request = new Request.Builder() - .url(url) - .addHeader("Authorization", "Bearer " + userToken) - .addHeader("Content-Type", "application/json; charset=utf-8") - .post(requestBody) - .build(); - - try (okhttp3.Response response = client.newCall(request).execute()) { - if (response.isSuccessful() && response.body() != null) { - String responseBody = response.body().string(); - return GSON.fromJson(responseBody, org.apache.geaflow.ai.common.model.Response.class); - } else { - LOGGER.info("Request failed with code: " + response.code()); - } - } catch (IOException e) { - LOGGER.error("Http connect exception", e); - throw new RuntimeException(e); - } - return null; + public OkHttpDirectConnector(String endpoint, String useApi, String userToken) { + this.endpoint = endpoint; + this.useApi = useApi; + this.userToken = userToken; + if (client == null) { + OkHttpClient.Builder builder = new OkHttpClient.Builder(); + builder.callTimeout(Constants.HTTP_CALL_TIMEOUT_SECONDS, TimeUnit.SECONDS); + builder.connectTimeout(Constants.HTTP_CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS); + builder.readTimeout(Constants.HTTP_READ_TIMEOUT_SECONDS, TimeUnit.SECONDS); + builder.writeTimeout(Constants.HTTP_WRITE_TIMEOUT_SECONDS, TimeUnit.SECONDS); + client = builder.build(); } + } + public org.apache.geaflow.ai.common.model.Response post(String bodyJson) { + RequestBody requestBody = + RequestBody.create(MediaType.parse("application/json; charset=utf-8"), bodyJson); - public EmbeddingResponse embeddingPost(String bodyJson) { - RequestBody requestBody = RequestBody.create( - MediaType.parse("application/json; charset=utf-8"), - bodyJson - ); - - String url = endpoint + useApi; - Request request = new Request.Builder() - .url(url) - .addHeader("Authorization", "Bearer " + userToken) - .addHeader("Content-Type", "application/json; charset=utf-8") - .post(requestBody) - .build(); - - try (okhttp3.Response response = client.newCall(request).execute()) { - if (response.isSuccessful() && response.body() != null) { - String responseBody = response.body().string(); - return GSON.fromJson(responseBody, EmbeddingResponse.class); - } else { - LOGGER.info("Request failed with code: " + response.code()); - LOGGER.info("Request failed with request bodyJson: " - + bodyJson); - LOGGER.info("Request failed with response body: " - + Objects.requireNonNull(response.body()).string()); - } - } catch (IOException e) { - throw new RuntimeException(e); - } - return null; + String url = endpoint + useApi; + LOGGER.info(url); + Request request = + new Request.Builder() + .url(url) + .addHeader("Authorization", "Bearer " + userToken) + .addHeader("Content-Type", "application/json; charset=utf-8") + .post(requestBody) + .build(); + + try (okhttp3.Response response = client.newCall(request).execute()) { + if (response.isSuccessful() && response.body() != null) { + String responseBody = response.body().string(); + return GSON.fromJson(responseBody, org.apache.geaflow.ai.common.model.Response.class); + } else { + LOGGER.info("Request failed with code: " + response.code()); + } + } catch (IOException e) { + LOGGER.error("Http connect exception", e); + throw new RuntimeException(e); } + return null; + } + + public EmbeddingResponse embeddingPost(String bodyJson) { + RequestBody requestBody = + RequestBody.create(MediaType.parse("application/json; charset=utf-8"), bodyJson); + String url = endpoint + useApi; + Request request = + new Request.Builder() + .url(url) + .addHeader("Authorization", "Bearer " + userToken) + .addHeader("Content-Type", "application/json; charset=utf-8") + .post(requestBody) + .build(); + + try (okhttp3.Response response = client.newCall(request).execute()) { + if (response.isSuccessful() && response.body() != null) { + String responseBody = response.body().string(); + return GSON.fromJson(responseBody, EmbeddingResponse.class); + } else { + LOGGER.info("Request failed with code: " + response.code()); + LOGGER.info("Request failed with request bodyJson: " + bodyJson); + LOGGER.info( + "Request failed with response body: " + + Objects.requireNonNull(response.body()).string()); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + return null; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/RemoteModelClient.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/RemoteModelClient.java index c37e5e7ce..27a589d5e 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/RemoteModelClient.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/RemoteModelClient.java @@ -19,60 +19,69 @@ package org.apache.geaflow.ai.common.model; -import com.google.gson.Gson; import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.ai.common.config.Constants; import org.apache.geaflow.common.utils.RetryCommand; -public class RemoteModelClient { +import com.google.gson.Gson; - public String chat(ModelContext context) { - ModelConfig info = context.getModelInfo(); - OkHttpDirectConnector connector = new OkHttpDirectConnector( - info.getUrl(), info.getApi(), info.getUserToken()); - String request = new Gson().toJson(context); - org.apache.geaflow.ai.common.model.Response response = RetryCommand.run(() -> { - return connector.post(request); - }, Constants.MODEL_CLIENT_RETRY_TIMES, Constants.MODEL_CLIENT_RETRY_INTERVAL_MS); - if (response.choices != null && !response.choices.isEmpty()) { - for (Response.Choice choice : response.choices) { - if (choice.message != null) { - return choice.message.content; - } - } +public class RemoteModelClient { + public String chat(ModelContext context) { + ModelConfig info = context.getModelInfo(); + OkHttpDirectConnector connector = + new OkHttpDirectConnector(info.getUrl(), info.getApi(), info.getUserToken()); + String request = new Gson().toJson(context); + org.apache.geaflow.ai.common.model.Response response = + RetryCommand.run( + () -> { + return connector.post(request); + }, + Constants.MODEL_CLIENT_RETRY_TIMES, + Constants.MODEL_CLIENT_RETRY_INTERVAL_MS); + if (response.choices != null && !response.choices.isEmpty()) { + for (Response.Choice choice : response.choices) { + if (choice.message != null) { + return choice.message.content; } - return null; + } } + return null; + } - public List embedding(ModelEmbedding context) { - ModelConfig info = context.getModelInfo(); - OkHttpDirectConnector connector = new OkHttpDirectConnector( - info.getUrl(), info.getApi(), info.getUserToken()); - ModelEmbedding requestContext = new ModelEmbedding(null, context.input); - requestContext.setModel(context.getModel()); - String request = new Gson().toJson(requestContext); - final EmbeddingResponse response = RetryCommand.run(() -> { - return connector.embeddingPost(request); - }, Constants.MODEL_CLIENT_RETRY_TIMES, Constants.MODEL_CLIENT_RETRY_INTERVAL_MS); - if (response == null) { - return new ArrayList<>(); - } - List embeddingResults = new ArrayList<>(); - if (response.data == null) { - throw new RuntimeException("Embedding model response is null"); - } - for (EmbeddingResponse.EmbeddingVector v : response.data) { - int index = v.index; - if (index >= context.input.length) { - throw new RuntimeException("Embedding model response contains invalid index"); - } - String input = context.input[index]; - double[] vector = v.embedding; - EmbeddingService.EmbeddingResult result = new EmbeddingService.EmbeddingResult(input, vector); - embeddingResults.add(result); - } - return embeddingResults; + public List embedding(ModelEmbedding context) { + ModelConfig info = context.getModelInfo(); + OkHttpDirectConnector connector = + new OkHttpDirectConnector(info.getUrl(), info.getApi(), info.getUserToken()); + ModelEmbedding requestContext = new ModelEmbedding(null, context.input); + requestContext.setModel(context.getModel()); + String request = new Gson().toJson(requestContext); + final EmbeddingResponse response = + RetryCommand.run( + () -> { + return connector.embeddingPost(request); + }, + Constants.MODEL_CLIENT_RETRY_TIMES, + Constants.MODEL_CLIENT_RETRY_INTERVAL_MS); + if (response == null) { + return new ArrayList<>(); + } + List embeddingResults = new ArrayList<>(); + if (response.data == null) { + throw new RuntimeException("Embedding model response is null"); + } + for (EmbeddingResponse.EmbeddingVector v : response.data) { + int index = v.index; + if (index >= context.input.length) { + throw new RuntimeException("Embedding model response contains invalid index"); + } + String input = context.input[index]; + double[] vector = v.embedding; + EmbeddingService.EmbeddingResult result = new EmbeddingService.EmbeddingResult(input, vector); + embeddingResults.add(result); } + return embeddingResults; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/Response.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/Response.java index b53c2759e..26b30af98 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/Response.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/model/Response.java @@ -19,39 +19,44 @@ package org.apache.geaflow.ai.common.model; -import com.google.gson.annotations.SerializedName; import java.util.List; +import com.google.gson.annotations.SerializedName; public class Response { - public String id; - @SerializedName("created") - public long createdTimestamp; - public String model; - public String object; - public List choices; - public Usage usage; - @SerializedName("system_fingerprint") - public String systemFingerprint; - - public static class Choice { - public String finishReason; - public int index; - public Message message; - public Object logprobs; - } - - public static class Message { - public String role; - public String content; - } - - public static class Usage { - @SerializedName("completion_tokens") - public int completionTokens; - @SerializedName("prompt_tokens") - public int promptTokens; - @SerializedName("total_tokens") - public int totalTokens; - } + public String id; + + @SerializedName("created") + public long createdTimestamp; + + public String model; + public String object; + public List choices; + public Usage usage; + + @SerializedName("system_fingerprint") + public String systemFingerprint; + + public static class Choice { + public String finishReason; + public int index; + public Message message; + public Object logprobs; + } + + public static class Message { + public String role; + public String content; + } + + public static class Usage { + @SerializedName("completion_tokens") + public int completionTokens; + + @SerializedName("prompt_tokens") + public int promptTokens; + + @SerializedName("total_tokens") + public int totalTokens; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/util/SeDeUtil.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/util/SeDeUtil.java index af05a03c1..638b5eae7 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/util/SeDeUtil.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/common/util/SeDeUtil.java @@ -19,67 +19,68 @@ package org.apache.geaflow.ai.common.util; -import com.google.gson.*; import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.ai.common.config.Constants; import org.apache.geaflow.ai.graph.GraphEdge; import org.apache.geaflow.ai.graph.GraphEntity; import org.apache.geaflow.ai.graph.GraphVertex; import org.apache.geaflow.ai.graph.io.*; +import com.google.gson.*; + public class SeDeUtil { - private static final Gson GSON = new Gson(); + private static final Gson GSON = new Gson(); - public static String serializeGraphSchema(GraphSchema schema) { - return GSON.toJson(schema); - } + public static String serializeGraphSchema(GraphSchema schema) { + return GSON.toJson(schema); + } - public static GraphSchema deserializeGraphSchema(String json) { - return GSON.fromJson(json, GraphSchema.class); - } + public static GraphSchema deserializeGraphSchema(String json) { + return GSON.fromJson(json, GraphSchema.class); + } - public static Schema deserializeEntitySchema(String json) { - boolean isVertex = new JsonParser().parse(json).getAsJsonObject().has("idField"); - if (isVertex) { - return GSON.fromJson(json, VertexSchema.class); - } else { - return GSON.fromJson(json, EdgeSchema.class); - } + public static Schema deserializeEntitySchema(String json) { + boolean isVertex = new JsonParser().parse(json).getAsJsonObject().has("idField"); + if (isVertex) { + return GSON.fromJson(json, VertexSchema.class); + } else { + return GSON.fromJson(json, EdgeSchema.class); } + } - public static String serializeEntitySchema(Schema schema) { - if (schema instanceof VertexSchema) { - return GSON.toJson((VertexSchema) schema); - } else { - return GSON.toJson((EdgeSchema) schema); - } + public static String serializeEntitySchema(Schema schema) { + if (schema instanceof VertexSchema) { + return GSON.toJson((VertexSchema) schema); + } else { + return GSON.toJson((EdgeSchema) schema); } + } - public static List deserializeEntities(String json) { - JsonArray jsonArray; - List entities = new ArrayList<>(); - try { - jsonArray = new JsonParser().parse(json).getAsJsonArray(); - } catch (Throwable e) { - JsonObject jsonObject = new JsonParser().parse(json).getAsJsonObject(); - if (jsonObject.has(Constants.PREFIX_ID)) { - entities.add(new GraphVertex(GSON.fromJson(json, Vertex.class))); - } else { - entities.add(new GraphEdge(GSON.fromJson(json, Edge.class))); - } - return entities; - } - for (JsonElement element : jsonArray) { - JsonObject jsonObject = element.getAsJsonObject(); - if (jsonObject.has(Constants.PREFIX_ID)) { - entities.add(new GraphVertex(GSON.fromJson(json, Vertex.class))); - } else { - entities.add(new GraphEdge(GSON.fromJson(json, Edge.class))); - } - } - return entities; + public static List deserializeEntities(String json) { + JsonArray jsonArray; + List entities = new ArrayList<>(); + try { + jsonArray = new JsonParser().parse(json).getAsJsonArray(); + } catch (Throwable e) { + JsonObject jsonObject = new JsonParser().parse(json).getAsJsonObject(); + if (jsonObject.has(Constants.PREFIX_ID)) { + entities.add(new GraphVertex(GSON.fromJson(json, Vertex.class))); + } else { + entities.add(new GraphEdge(GSON.fromJson(json, Edge.class))); + } + return entities; } - + for (JsonElement element : jsonArray) { + JsonObject jsonObject = element.getAsJsonObject(); + if (jsonObject.has(Constants.PREFIX_ID)) { + entities.add(new GraphVertex(GSON.fromJson(json, Vertex.class))); + } else { + entities.add(new GraphEdge(GSON.fromJson(json, Edge.class))); + } + } + return entities; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/consolidate/ConsolidateServer.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/consolidate/ConsolidateServer.java index d446b6515..14809fb64 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/consolidate/ConsolidateServer.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/consolidate/ConsolidateServer.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.ai.consolidate.function.ConsolidateFunction; import org.apache.geaflow.ai.consolidate.function.EmbeddingRelationFunction; import org.apache.geaflow.ai.consolidate.function.KeywordRelationFunction; @@ -29,17 +30,17 @@ public class ConsolidateServer { - private static final List functions = new ArrayList<>(); + private static final List functions = new ArrayList<>(); - static { - functions.add(new KeywordRelationFunction()); - functions.add(new EmbeddingRelationFunction()); - } + static { + functions.add(new KeywordRelationFunction()); + functions.add(new EmbeddingRelationFunction()); + } - public int executeConsolidateTask(GraphAccessor graphAccessor, MutableGraph graph) { - for (ConsolidateFunction function : functions) { - function.eval(graphAccessor, graph); - } - return 0; + public int executeConsolidateTask(GraphAccessor graphAccessor, MutableGraph graph) { + for (ConsolidateFunction function : functions) { + function.eval(graphAccessor, graph); } + return 0; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/consolidate/function/ConsolidateFunction.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/consolidate/function/ConsolidateFunction.java index 9dabdb64f..3aeada223 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/consolidate/function/ConsolidateFunction.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/consolidate/function/ConsolidateFunction.java @@ -24,5 +24,5 @@ public interface ConsolidateFunction { - void eval(GraphAccessor graphAccessor, MutableGraph graph); + void eval(GraphAccessor graphAccessor, MutableGraph graph); } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/consolidate/function/EmbeddingRelationFunction.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/consolidate/function/EmbeddingRelationFunction.java index 34976ef9c..d9e6b4e32 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/consolidate/function/EmbeddingRelationFunction.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/consolidate/function/EmbeddingRelationFunction.java @@ -24,8 +24,6 @@ public class EmbeddingRelationFunction implements ConsolidateFunction { - @Override - public void eval(GraphAccessor graphAccessor, MutableGraph graph) { - - } + @Override + public void eval(GraphAccessor graphAccessor, MutableGraph graph) {} } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/consolidate/function/KeywordRelationFunction.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/consolidate/function/KeywordRelationFunction.java index 29291eea2..d5924897c 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/consolidate/function/KeywordRelationFunction.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/consolidate/function/KeywordRelationFunction.java @@ -22,6 +22,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; + import org.apache.geaflow.ai.GraphMemoryServer; import org.apache.geaflow.ai.common.ErrorCode; import org.apache.geaflow.ai.common.config.Constants; @@ -37,66 +38,67 @@ public class KeywordRelationFunction implements ConsolidateFunction { - private static final Logger LOGGER = LoggerFactory.getLogger(KeywordRelationFunction.class); - - @Override - public void eval(GraphAccessor graphAccessor, MutableGraph mutableGraph) { - EntityAttributeIndexStore indexStore = new EntityAttributeIndexStore(); - indexStore.initStore(new SubgraphSemanticPromptFunction(graphAccessor)); - LOGGER.info("Success to init EntityAttributeIndexStore."); - GraphMemoryServer server = new GraphMemoryServer(); - server.addGraphAccessor(graphAccessor); - server.addIndexStore(indexStore); - LOGGER.info("Success to init GraphMemoryServer."); + private static final Logger LOGGER = LoggerFactory.getLogger(KeywordRelationFunction.class); - if (null == mutableGraph.getSchema().getSchema( - Constants.CONSOLIDATE_KEYWORD_RELATION_LABEL)) { - int code = mutableGraph.addEdgeSchema(new EdgeSchema( - Constants.CONSOLIDATE_KEYWORD_RELATION_LABEL, - Constants.PREFIX_SRC_ID, Constants.PREFIX_DST_ID, - Collections.singletonList(Constants.PREFIX_COMMON_KEYWORDS) - )); - if (code != ErrorCode.SUCCESS) { - return; - } - } + @Override + public void eval(GraphAccessor graphAccessor, MutableGraph mutableGraph) { + EntityAttributeIndexStore indexStore = new EntityAttributeIndexStore(); + indexStore.initStore(new SubgraphSemanticPromptFunction(graphAccessor)); + LOGGER.info("Success to init EntityAttributeIndexStore."); + GraphMemoryServer server = new GraphMemoryServer(); + server.addGraphAccessor(graphAccessor); + server.addIndexStore(indexStore); + LOGGER.info("Success to init GraphMemoryServer."); - Long cnt = 0L; - Iterator vertexIterator = graphAccessor.scanVertex(); - while (vertexIterator.hasNext()) { - GraphVertex vertex = vertexIterator.next(); - boolean existRelation = false; - Iterator neighborIterator = graphAccessor.scanEdge(vertex); - while (neighborIterator.hasNext()) { - GraphEdge existEdge = neighborIterator.next(); - if (existEdge.getLabel().equals(Constants.CONSOLIDATE_KEYWORD_RELATION_LABEL)) { - existRelation = true; - break; - } - } - if (existRelation) { - continue; - } - String sessionId = server.createSession(); - VectorSearch search = new VectorSearch(null, sessionId); - search.addVector(new KeywordVector(vertex.toString())); - sessionId = server.search(search); - List results = server.getSessionEntities(sessionId); - for (GraphEntity relateEntity : results) { - if (relateEntity instanceof GraphVertex) { - String srcId = vertex.getVertex().getId(); - String dstId = ((GraphVertex) relateEntity).getVertex().getId(); - mutableGraph.addEdge(new Edge( - Constants.CONSOLIDATE_KEYWORD_RELATION_LABEL, - srcId, dstId, - Collections.singletonList(Constants.PREFIX_COMMON_KEYWORDS) - )); - } - } + if (null == mutableGraph.getSchema().getSchema(Constants.CONSOLIDATE_KEYWORD_RELATION_LABEL)) { + int code = + mutableGraph.addEdgeSchema( + new EdgeSchema( + Constants.CONSOLIDATE_KEYWORD_RELATION_LABEL, + Constants.PREFIX_SRC_ID, + Constants.PREFIX_DST_ID, + Collections.singletonList(Constants.PREFIX_COMMON_KEYWORDS))); + if (code != ErrorCode.SUCCESS) { + return; + } + } - cnt++; - LOGGER.info("Process vertex num: {}", cnt); + Long cnt = 0L; + Iterator vertexIterator = graphAccessor.scanVertex(); + while (vertexIterator.hasNext()) { + GraphVertex vertex = vertexIterator.next(); + boolean existRelation = false; + Iterator neighborIterator = graphAccessor.scanEdge(vertex); + while (neighborIterator.hasNext()) { + GraphEdge existEdge = neighborIterator.next(); + if (existEdge.getLabel().equals(Constants.CONSOLIDATE_KEYWORD_RELATION_LABEL)) { + existRelation = true; + break; + } + } + if (existRelation) { + continue; + } + String sessionId = server.createSession(); + VectorSearch search = new VectorSearch(null, sessionId); + search.addVector(new KeywordVector(vertex.toString())); + sessionId = server.search(search); + List results = server.getSessionEntities(sessionId); + for (GraphEntity relateEntity : results) { + if (relateEntity instanceof GraphVertex) { + String srcId = vertex.getVertex().getId(); + String dstId = ((GraphVertex) relateEntity).getVertex().getId(); + mutableGraph.addEdge( + new Edge( + Constants.CONSOLIDATE_KEYWORD_RELATION_LABEL, + srcId, + dstId, + Collections.singletonList(Constants.PREFIX_COMMON_KEYWORDS))); } + } + cnt++; + LOGGER.info("Process vertex num: {}", cnt); } + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/engine/GraphComputeEngine.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/engine/GraphComputeEngine.java index 6942e9e0a..e732abea7 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/engine/GraphComputeEngine.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/engine/GraphComputeEngine.java @@ -19,5 +19,4 @@ package org.apache.geaflow.ai.engine; -public interface GraphComputeEngine { -} +public interface GraphComputeEngine {} diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/EmptyGraphAccessor.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/EmptyGraphAccessor.java index e0910d385..e085c21b3 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/EmptyGraphAccessor.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/EmptyGraphAccessor.java @@ -23,47 +23,48 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; + import org.apache.geaflow.ai.graph.io.GraphSchema; public class EmptyGraphAccessor implements GraphAccessor { - @Override - public GraphSchema getGraphSchema() { - return null; - } + @Override + public GraphSchema getGraphSchema() { + return null; + } - @Override - public GraphVertex getVertex(String label, String id) { - return null; - } + @Override + public GraphVertex getVertex(String label, String id) { + return null; + } - @Override - public List getEdge(String label, String src, String dst) { - return null; - } + @Override + public List getEdge(String label, String src, String dst) { + return null; + } - @Override - public Iterator scanVertex() { - return Collections.emptyIterator(); - } + @Override + public Iterator scanVertex() { + return Collections.emptyIterator(); + } - @Override - public Iterator scanEdge(GraphVertex vertex) { - return Collections.emptyIterator(); - } + @Override + public Iterator scanEdge(GraphVertex vertex) { + return Collections.emptyIterator(); + } - @Override - public List expand(GraphEntity entity) { - return new ArrayList<>(); - } + @Override + public List expand(GraphEntity entity) { + return new ArrayList<>(); + } - @Override - public GraphAccessor copy() { - return new EmptyGraphAccessor(); - } + @Override + public GraphAccessor copy() { + return new EmptyGraphAccessor(); + } - @Override - public String getType() { - return EmptyGraphAccessor.class.getSimpleName(); - } + @Override + public String getType() { + return EmptyGraphAccessor.class.getSimpleName(); + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/Graph.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/Graph.java index 8322bae16..5f9e7812e 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/Graph.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/Graph.java @@ -21,29 +21,30 @@ import java.util.Collection; import java.util.Iterator; + import org.apache.geaflow.ai.graph.io.Edge; import org.apache.geaflow.ai.graph.io.GraphSchema; import org.apache.geaflow.ai.graph.io.Vertex; public interface Graph { - GraphSchema getGraphSchema(); + GraphSchema getGraphSchema(); - Vertex getVertex(String label, String id); + Vertex getVertex(String label, String id); - int removeVertex(String label, String id); + int removeVertex(String label, String id); - int updateVertex(Vertex newVertex); + int updateVertex(Vertex newVertex); - int addVertex(Vertex newVertex); + int addVertex(Vertex newVertex); - Collection getEdge(String label, String src, String dst); + Collection getEdge(String label, String src, String dst); - int removeEdge(Edge edge); + int removeEdge(Edge edge); - int addEdge(Edge newEdge); + int addEdge(Edge newEdge); - Iterator scanEdge(Vertex vertex); + Iterator scanEdge(Vertex vertex); - Iterator scanVertex(); + Iterator scanVertex(); } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/GraphAccessor.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/GraphAccessor.java index 82dd1cd57..fbb671aa7 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/GraphAccessor.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/GraphAccessor.java @@ -21,23 +21,24 @@ import java.util.Iterator; import java.util.List; + import org.apache.geaflow.ai.graph.io.GraphSchema; public interface GraphAccessor { - GraphSchema getGraphSchema(); + GraphSchema getGraphSchema(); - GraphVertex getVertex(String label, String id); + GraphVertex getVertex(String label, String id); - List getEdge(String label, String src, String dst); + List getEdge(String label, String src, String dst); - Iterator scanVertex(); + Iterator scanVertex(); - Iterator scanEdge(GraphVertex vertex); + Iterator scanEdge(GraphVertex vertex); - List expand(GraphEntity entity); + List expand(GraphEntity entity); - GraphAccessor copy(); + GraphAccessor copy(); - String getType(); + String getType(); } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/GraphEdge.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/GraphEdge.java index 0dd3d2216..baa498f77 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/GraphEdge.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/GraphEdge.java @@ -20,44 +20,45 @@ package org.apache.geaflow.ai.graph; import java.util.Objects; + import org.apache.geaflow.ai.graph.io.Edge; public class GraphEdge implements GraphEntity { - private final Edge edge; + private final Edge edge; - public GraphEdge(Edge edge) { - this.edge = Objects.requireNonNull(edge); - } + public GraphEdge(Edge edge) { + this.edge = Objects.requireNonNull(edge); + } - public Edge getEdge() { - return edge; - } + public Edge getEdge() { + return edge; + } - @Override - public String getLabel() { - return edge.getLabel(); - } + @Override + public String getLabel() { + return edge.getLabel(); + } - @Override - public String toString() { - return edge.toString(); - } + @Override + public String toString() { + return edge.toString(); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - GraphEdge graphEdge = (GraphEdge) o; - return Objects.equals(edge, graphEdge.edge); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(edge); + if (o == null || getClass() != o.getClass()) { + return false; } + GraphEdge graphEdge = (GraphEdge) o; + return Objects.equals(edge, graphEdge.edge); + } + + @Override + public int hashCode() { + return Objects.hash(edge); + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/GraphEntity.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/GraphEntity.java index 5d63649a4..20bbc0e76 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/GraphEntity.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/GraphEntity.java @@ -21,5 +21,5 @@ public interface GraphEntity { - String getLabel(); + String getLabel(); } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/GraphVertex.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/GraphVertex.java index 5e3d58889..f48161812 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/GraphVertex.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/GraphVertex.java @@ -20,44 +20,45 @@ package org.apache.geaflow.ai.graph; import java.util.Objects; + import org.apache.geaflow.ai.graph.io.Vertex; public class GraphVertex implements GraphEntity { - private final Vertex vertex; + private final Vertex vertex; - public GraphVertex(Vertex vertex) { - this.vertex = Objects.requireNonNull(vertex); - } + public GraphVertex(Vertex vertex) { + this.vertex = Objects.requireNonNull(vertex); + } - public Vertex getVertex() { - return vertex; - } + public Vertex getVertex() { + return vertex; + } - @Override - public String getLabel() { - return vertex.getLabel(); - } + @Override + public String getLabel() { + return vertex.getLabel(); + } - @Override - public String toString() { - return vertex.toString(); - } + @Override + public String toString() { + return vertex.toString(); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - GraphVertex that = (GraphVertex) o; - return Objects.equals(vertex, that.vertex); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(vertex); + if (o == null || getClass() != o.getClass()) { + return false; } + GraphVertex that = (GraphVertex) o; + return Objects.equals(vertex, that.vertex); + } + + @Override + public int hashCode() { + return Objects.hash(vertex); + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/LocalMemoryGraphAccessor.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/LocalMemoryGraphAccessor.java index 98614bc4a..69f68d09e 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/LocalMemoryGraphAccessor.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/LocalMemoryGraphAccessor.java @@ -25,134 +25,136 @@ import java.util.List; import java.util.function.Function; import java.util.stream.Collectors; + import org.apache.geaflow.ai.graph.io.*; public class LocalMemoryGraphAccessor implements GraphAccessor { - private final MemoryGraph graph; - - public LocalMemoryGraphAccessor(ClassLoader classLoader, String resourcePath, Long limit, - Function vertexMapper, - Function edgeMapper) { - try { - this.graph = GraphFileReader.getGraph(classLoader, resourcePath, limit, - vertexMapper, edgeMapper); - } catch (Throwable e) { - throw new RuntimeException("Init local graph error", e); - } + private final MemoryGraph graph; + + public LocalMemoryGraphAccessor( + ClassLoader classLoader, + String resourcePath, + Long limit, + Function vertexMapper, + Function edgeMapper) { + try { + this.graph = + GraphFileReader.getGraph(classLoader, resourcePath, limit, vertexMapper, edgeMapper); + } catch (Throwable e) { + throw new RuntimeException("Init local graph error", e); } - - public LocalMemoryGraphAccessor(MemoryGraph memoryGraph) { - this.graph = memoryGraph; + } + + public LocalMemoryGraphAccessor(MemoryGraph memoryGraph) { + this.graph = memoryGraph; + } + + @Override + public GraphSchema getGraphSchema() { + return graph.getGraphSchema(); + } + + @Override + public GraphVertex getVertex(String label, String id) { + Vertex innerVertex = graph.getVertex(label, id); + if (innerVertex == null) { + return null; } - - @Override - public GraphSchema getGraphSchema() { - return graph.getGraphSchema(); + return new GraphVertex(innerVertex); + } + + @Override + public List getEdge(String label, String src, String dst) { + List innerEdges = graph.getEdge(label, src, dst); + if (innerEdges == null) { + return Collections.emptyList(); } - - @Override - public GraphVertex getVertex(String label, String id) { - Vertex innerVertex = graph.getVertex(label, id); - if (innerVertex == null) { - return null; - } - return new GraphVertex(innerVertex); + return innerEdges.stream().map(GraphEdge::new).collect(Collectors.toList()); + } + + @Override + public Iterator scanVertex() { + return new GraphVertexIterator(graph.scanVertex()); + } + + @Override + public Iterator scanEdge(GraphVertex vertex) { + return new GraphEdgeIterator(graph.scanEdge(vertex.getVertex())); + } + + @Override + public List expand(GraphEntity entity) { + List results = new ArrayList<>(); + if (entity instanceof GraphVertex) { + Iterator iterator = graph.scanEdge(((GraphVertex) entity).getVertex()); + while (iterator.hasNext()) { + results.add(new GraphEdge(iterator.next())); + } + } else if (entity instanceof GraphEdge) { + GraphEdge graphEdge = (GraphEdge) entity; + Vertex srcVertex = graph.getVertex(null, graphEdge.getEdge().getSrcId()); + Vertex dstVertex = graph.getVertex(null, graphEdge.getEdge().getDstId()); + if (srcVertex != null) { + results.add(new GraphVertex(srcVertex)); + } + if (dstVertex != null) { + results.add(new GraphVertex(dstVertex)); + } } + return results; + } - @Override - public List getEdge(String label, String src, String dst) { - List innerEdges = graph.getEdge(label, src, dst); - if (innerEdges == null) { - return Collections.emptyList(); - } - return innerEdges.stream().map(GraphEdge::new).collect(Collectors.toList()); - } + public MemoryMutableGraph getMutableGraph() { + return new MemoryMutableGraph(graph); + } - @Override - public Iterator scanVertex() { - return new GraphVertexIterator(graph.scanVertex()); - } + @Override + public GraphAccessor copy() { + return this; + } - @Override - public Iterator scanEdge(GraphVertex vertex) { - return new GraphEdgeIterator(graph.scanEdge(vertex.getVertex())); - } + @Override + public String getType() { + return this.getClass().getSimpleName(); + } - @Override - public List expand(GraphEntity entity) { - List results = new ArrayList<>(); - if (entity instanceof GraphVertex) { - Iterator iterator = graph.scanEdge(((GraphVertex) entity).getVertex()); - while (iterator.hasNext()) { - results.add(new GraphEdge(iterator.next())); - } - } else if (entity instanceof GraphEdge) { - GraphEdge graphEdge = (GraphEdge) entity; - Vertex srcVertex = graph.getVertex(null, graphEdge.getEdge().getSrcId()); - Vertex dstVertex = graph.getVertex(null, graphEdge.getEdge().getDstId()); - if (srcVertex != null) { - results.add(new GraphVertex(srcVertex)); - } - if (dstVertex != null) { - results.add(new GraphVertex(dstVertex)); - } - } - return results; - } + private static class GraphVertexIterator implements Iterator { + + private final Iterator vertexIterator; - public MemoryMutableGraph getMutableGraph() { - return new MemoryMutableGraph(graph); + public GraphVertexIterator(Iterator vertexIterator) { + this.vertexIterator = vertexIterator; } @Override - public GraphAccessor copy() { - return this; + public boolean hasNext() { + return vertexIterator.hasNext(); } @Override - public String getType() { - return this.getClass().getSimpleName(); + public GraphVertex next() { + Vertex nextVertex = vertexIterator.next(); + return new GraphVertex(nextVertex); } + } + private static class GraphEdgeIterator implements Iterator { + private final Iterator delegate; - private static class GraphVertexIterator implements Iterator { - - private final Iterator vertexIterator; - - public GraphVertexIterator(Iterator vertexIterator) { - this.vertexIterator = vertexIterator; - } - - @Override - public boolean hasNext() { - return vertexIterator.hasNext(); - } - - @Override - public GraphVertex next() { - Vertex nextVertex = vertexIterator.next(); - return new GraphVertex(nextVertex); - } + public GraphEdgeIterator(Iterator delegate) { + this.delegate = delegate; } - private static class GraphEdgeIterator implements Iterator { - private final Iterator delegate; - - public GraphEdgeIterator(Iterator delegate) { - this.delegate = delegate; - } - - @Override - public boolean hasNext() { - return delegate.hasNext(); - } - - @Override - public GraphEdge next() { - Edge edge = delegate.next(); - return new GraphEdge(edge); - } + @Override + public boolean hasNext() { + return delegate.hasNext(); } + @Override + public GraphEdge next() { + Edge edge = delegate.next(); + return new GraphEdge(edge); + } + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/MemoryMutableGraph.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/MemoryMutableGraph.java index 37b9a6d93..47e8115fe 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/MemoryMutableGraph.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/MemoryMutableGraph.java @@ -20,91 +20,93 @@ package org.apache.geaflow.ai.graph; import java.util.ArrayList; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.ai.common.ErrorCode; import org.apache.geaflow.ai.graph.io.*; public class MemoryMutableGraph implements MutableGraph { - private final MemoryGraph graph; + private final MemoryGraph graph; - public MemoryMutableGraph(MemoryGraph graph) { - this.graph = graph; - } + public MemoryMutableGraph(MemoryGraph graph) { + this.graph = graph; + } - @Override - public int removeVertex(String label, String id) { - return this.graph.removeVertex(label, id); - } + @Override + public int removeVertex(String label, String id) { + return this.graph.removeVertex(label, id); + } - @Override - public int updateVertex(Vertex newVertex) { - return this.graph.updateVertex(newVertex); - } + @Override + public int updateVertex(Vertex newVertex) { + return this.graph.updateVertex(newVertex); + } - @Override - public int addVertex(Vertex newVertex) { - return this.graph.addVertex(newVertex); - } + @Override + public int addVertex(Vertex newVertex) { + return this.graph.addVertex(newVertex); + } - @Override - public int removeEdge(Edge edge) { - return this.graph.removeEdge(edge); - } + @Override + public int removeEdge(Edge edge) { + return this.graph.removeEdge(edge); + } - @Override - public int addEdge(Edge newEdge) { - return this.graph.addEdge(newEdge); - } + @Override + public int addEdge(Edge newEdge) { + return this.graph.addEdge(newEdge); + } - @Override - public GraphSchema getSchema() { - return this.graph.getGraphSchema(); - } + @Override + public GraphSchema getSchema() { + return this.graph.getGraphSchema(); + } - @Override - public int addVertexSchema(VertexSchema vertexSchema) { - if (vertexSchema == null || StringUtils.isBlank(vertexSchema.getLabel())) { - return ErrorCode.GRAPH_ADD_VERTEX_SCHEMA_FAILED; - } - for (VertexSchema existSchema : this.getSchema().getVertexSchemaList()) { - if (existSchema.getLabel().equals(vertexSchema.getLabel())) { - return ErrorCode.GRAPH_ADD_VERTEX_SCHEMA_FAILED; - } - } - for (EdgeSchema existSchema : this.getSchema().getEdgeSchemaList()) { - if (existSchema.getLabel().equals(vertexSchema.getLabel())) { - return ErrorCode.GRAPH_ADD_VERTEX_SCHEMA_FAILED; - } - } - if (this.graph.entities.get(vertexSchema.getLabel()) != null) { - return ErrorCode.GRAPH_ADD_VERTEX_SCHEMA_FAILED; - } - this.graph.getGraphSchema().addVertex(vertexSchema); - this.graph.entities.put(vertexSchema.getLabel(), new VertexGroup(vertexSchema, new ArrayList<>())); - return ErrorCode.SUCCESS; + @Override + public int addVertexSchema(VertexSchema vertexSchema) { + if (vertexSchema == null || StringUtils.isBlank(vertexSchema.getLabel())) { + return ErrorCode.GRAPH_ADD_VERTEX_SCHEMA_FAILED; + } + for (VertexSchema existSchema : this.getSchema().getVertexSchemaList()) { + if (existSchema.getLabel().equals(vertexSchema.getLabel())) { + return ErrorCode.GRAPH_ADD_VERTEX_SCHEMA_FAILED; + } + } + for (EdgeSchema existSchema : this.getSchema().getEdgeSchemaList()) { + if (existSchema.getLabel().equals(vertexSchema.getLabel())) { + return ErrorCode.GRAPH_ADD_VERTEX_SCHEMA_FAILED; + } + } + if (this.graph.entities.get(vertexSchema.getLabel()) != null) { + return ErrorCode.GRAPH_ADD_VERTEX_SCHEMA_FAILED; } + this.graph.getGraphSchema().addVertex(vertexSchema); + this.graph.entities.put( + vertexSchema.getLabel(), new VertexGroup(vertexSchema, new ArrayList<>())); + return ErrorCode.SUCCESS; + } - @Override - public int addEdgeSchema(EdgeSchema edgeSchema) { - if (edgeSchema == null || StringUtils.isBlank(edgeSchema.getLabel())) { - return ErrorCode.GRAPH_ADD_EDGE_SCHEMA_FAILED; - } - for (VertexSchema existSchema : this.getSchema().getVertexSchemaList()) { - if (existSchema.getLabel().equals(edgeSchema.getLabel())) { - return ErrorCode.GRAPH_ADD_EDGE_SCHEMA_FAILED; - } - } - for (EdgeSchema existSchema : this.getSchema().getEdgeSchemaList()) { - if (existSchema.getLabel().equals(edgeSchema.getLabel())) { - return ErrorCode.GRAPH_ADD_EDGE_SCHEMA_FAILED; - } - } - if (this.graph.entities.get(edgeSchema.getLabel()) != null) { - return ErrorCode.GRAPH_ADD_EDGE_SCHEMA_FAILED; - } - this.graph.getGraphSchema().addEdge(edgeSchema); - this.graph.entities.put(edgeSchema.getLabel(), new EdgeGroup(edgeSchema, new ArrayList<>())); - return ErrorCode.SUCCESS; + @Override + public int addEdgeSchema(EdgeSchema edgeSchema) { + if (edgeSchema == null || StringUtils.isBlank(edgeSchema.getLabel())) { + return ErrorCode.GRAPH_ADD_EDGE_SCHEMA_FAILED; + } + for (VertexSchema existSchema : this.getSchema().getVertexSchemaList()) { + if (existSchema.getLabel().equals(edgeSchema.getLabel())) { + return ErrorCode.GRAPH_ADD_EDGE_SCHEMA_FAILED; + } + } + for (EdgeSchema existSchema : this.getSchema().getEdgeSchemaList()) { + if (existSchema.getLabel().equals(edgeSchema.getLabel())) { + return ErrorCode.GRAPH_ADD_EDGE_SCHEMA_FAILED; + } + } + if (this.graph.entities.get(edgeSchema.getLabel()) != null) { + return ErrorCode.GRAPH_ADD_EDGE_SCHEMA_FAILED; } + this.graph.getGraphSchema().addEdge(edgeSchema); + this.graph.entities.put(edgeSchema.getLabel(), new EdgeGroup(edgeSchema, new ArrayList<>())); + return ErrorCode.SUCCESS; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/MutableGraph.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/MutableGraph.java index 0f999a694..13d602894 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/MutableGraph.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/MutableGraph.java @@ -23,19 +23,19 @@ public interface MutableGraph { - int removeVertex(String label, String id); + int removeVertex(String label, String id); - int updateVertex(Vertex newVertex); + int updateVertex(Vertex newVertex); - int addVertex(Vertex newVertex); + int addVertex(Vertex newVertex); - int removeEdge(Edge edge); + int removeEdge(Edge edge); - int addEdge(Edge newEdge); + int addEdge(Edge newEdge); - GraphSchema getSchema(); + GraphSchema getSchema(); - int addVertexSchema(VertexSchema vertexSchema); + int addVertexSchema(VertexSchema vertexSchema); - int addEdgeSchema(EdgeSchema edgeSchema); + int addEdgeSchema(EdgeSchema edgeSchema); } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/CsvFileReader.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/CsvFileReader.java index 28f4e8c0a..d88fec9fa 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/CsvFileReader.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/CsvFileReader.java @@ -27,127 +27,128 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class CsvFileReader { - private static final Logger LOGGER = LoggerFactory.getLogger(CsvFileReader.class); + private static final Logger LOGGER = LoggerFactory.getLogger(CsvFileReader.class); + + List colSchema; + Map> fileContent; + long limit; - List colSchema; - Map> fileContent; - long limit; + public CsvFileReader(long limit) { + this.colSchema = new ArrayList<>(); + this.fileContent = new HashMap<>(); + this.limit = limit; + } - public CsvFileReader(long limit) { - this.colSchema = new ArrayList<>(); - this.fileContent = new HashMap<>(); - this.limit = limit; + public void readCsvFile(String fileName) throws IOException { + InputStream inputStream = getClass().getClassLoader().getResourceAsStream(fileName); + if (inputStream == null) { + throw new IOException("Cannot find the file: " + fileName); } - public void readCsvFile(String fileName) throws IOException { - InputStream inputStream = getClass().getClassLoader().getResourceAsStream(fileName); - if (inputStream == null) { - throw new IOException("Cannot find the file: " + fileName); + long count = 0; + try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) { + String line; + boolean isFirstLine = true; + + while ((line = reader.readLine()) != null) { + count++; + if (isFirstLine) { + parseHeader(line); + isFirstLine = false; + } else { + parseDataRow(line); } - - long count = 0; - try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) { - String line; - boolean isFirstLine = true; - - while ((line = reader.readLine()) != null) { - count++; - if (isFirstLine) { - parseHeader(line); - isFirstLine = false; - } else { - parseDataRow(line); - } - if (count > limit) { - break; - } - } + if (count > limit) { + break; } + } } + } - private void parseHeader(String headerLine) { - String[] headers = headerLine.split("\\|"); + private void parseHeader(String headerLine) { + String[] headers = headerLine.split("\\|"); - colSchema.clear(); - fileContent.clear(); + colSchema.clear(); + fileContent.clear(); - for (String header : headers) { - colSchema.add(header.trim()); - fileContent.put(header.trim(), new ArrayList<>()); - } + for (String header : headers) { + colSchema.add(header.trim()); + fileContent.put(header.trim(), new ArrayList<>()); } + } - private void parseDataRow(String dataLine) { - if (dataLine == null || dataLine.trim().isEmpty()) { - return; - } - - String[] values = dataLine.split("\\|"); - if (values.length != colSchema.size()) { - System.err.println("WARNING: line number does not match - " + dataLine); - return; - } - - for (int i = 0; i < colSchema.size(); i++) { - String columnName = colSchema.get(i); - String value = values[i].trim(); - fileContent.get(columnName).add(value); - } + private void parseDataRow(String dataLine) { + if (dataLine == null || dataLine.trim().isEmpty()) { + return; } - public List getColSchema() { - return new ArrayList<>(colSchema); + String[] values = dataLine.split("\\|"); + if (values.length != colSchema.size()) { + System.err.println("WARNING: line number does not match - " + dataLine); + return; } - public Map> getFileContent() { - Map> copy = new HashMap<>(); - for (Map.Entry> entry : fileContent.entrySet()) { - copy.put(entry.getKey(), new ArrayList<>(entry.getValue())); - } - return copy; + for (int i = 0; i < colSchema.size(); i++) { + String columnName = colSchema.get(i); + String value = values[i].trim(); + fileContent.get(columnName).add(value); } + } - public List getColumnData(String columnName) { - List data = fileContent.get(columnName); - return data != null ? new ArrayList<>(data) : new ArrayList<>(); - } + public List getColSchema() { + return new ArrayList<>(colSchema); + } - public int getRowCount() { - if (fileContent.isEmpty()) { - return 0; - } - return fileContent.values().iterator().next().size(); + public Map> getFileContent() { + Map> copy = new HashMap<>(); + for (Map.Entry> entry : fileContent.entrySet()) { + copy.put(entry.getKey(), new ArrayList<>(entry.getValue())); } + return copy; + } - public List getRow(int rowIndex) { - List row = new ArrayList<>(); - int rowCount = getRowCount(); + public List getColumnData(String columnName) { + List data = fileContent.get(columnName); + return data != null ? new ArrayList<>(data) : new ArrayList<>(); + } - if (rowIndex < 0 || rowIndex >= rowCount) { - throw new IndexOutOfBoundsException("Row index out of range: " + rowIndex); - } + public int getRowCount() { + if (fileContent.isEmpty()) { + return 0; + } + return fileContent.values().iterator().next().size(); + } - for (String columnName : colSchema) { - List columnData = fileContent.get(columnName); - if (columnData != null && rowIndex < columnData.size()) { - row.add(columnData.get(rowIndex)); - } - } + public List getRow(int rowIndex) { + List row = new ArrayList<>(); + int rowCount = getRowCount(); - return row; + if (rowIndex < 0 || rowIndex >= rowCount) { + throw new IndexOutOfBoundsException("Row index out of range: " + rowIndex); } - public void printContent() { - LOGGER.info("ColName: " + colSchema); - LOGGER.info("Data content:"); - for (Map.Entry> entry : fileContent.entrySet()) { - LOGGER.info(entry.getKey() + ": " + entry.getValue()); - } - LOGGER.info("Total row count: " + getRowCount()); + for (String columnName : colSchema) { + List columnData = fileContent.get(columnName); + if (columnData != null && rowIndex < columnData.size()) { + row.add(columnData.get(rowIndex)); + } + } + + return row; + } + + public void printContent() { + LOGGER.info("ColName: " + colSchema); + LOGGER.info("Data content:"); + for (Map.Entry> entry : fileContent.entrySet()) { + LOGGER.info(entry.getKey() + ": " + entry.getValue()); } + LOGGER.info("Total row count: " + getRowCount()); + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/Edge.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/Edge.java index fa4d98576..f8d70bfec 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/Edge.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/Edge.java @@ -24,53 +24,55 @@ public class Edge { - private final String srcId; - private final String dstId; - private final String label; - private final List values; + private final String srcId; + private final String dstId; + private final String label; + private final List values; - public Edge(String label, String srcId, String dstId, List values) { - this.label = label; - this.srcId = srcId; - this.dstId = dstId; - this.values = values; - } + public Edge(String label, String srcId, String dstId, List values) { + this.label = label; + this.srcId = srcId; + this.dstId = dstId; + this.values = values; + } - public List getValues() { - return values; - } + public List getValues() { + return values; + } - public String getSrcId() { - return srcId; - } + public String getSrcId() { + return srcId; + } - public String getDstId() { - return dstId; - } + public String getDstId() { + return dstId; + } - public String getLabel() { - return label; - } + public String getLabel() { + return label; + } - @Override - public String toString() { - return label + " | " + srcId + " | " + dstId + " | " + String.join("|", values); - } + @Override + public String toString() { + return label + " | " + srcId + " | " + dstId + " | " + String.join("|", values); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - Edge edge = (Edge) o; - return Objects.equals(srcId, edge.srcId) && Objects.equals(dstId, edge.dstId) && Objects.equals(label, edge.label); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(srcId, dstId, label); + if (o == null || getClass() != o.getClass()) { + return false; } + Edge edge = (Edge) o; + return Objects.equals(srcId, edge.srcId) + && Objects.equals(dstId, edge.dstId) + && Objects.equals(label, edge.label); + } + + @Override + public int hashCode() { + return Objects.hash(srcId, dstId, label); + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/EdgeGroup.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/EdgeGroup.java index d134d823c..00b37937a 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/EdgeGroup.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/EdgeGroup.java @@ -21,112 +21,115 @@ import java.util.*; import java.util.stream.Collectors; + import org.apache.geaflow.ai.common.ErrorCode; public class EdgeGroup implements EntityGroup { - public final EdgeSchema edgeSchema; - private final List edges; - private final Map> index; - private final Map> pointIndices; + public final EdgeSchema edgeSchema; + private final List edges; + private final Map> index; + private final Map> pointIndices; - public EdgeGroup(EdgeSchema edgeSchema, List edges) { - this.edgeSchema = edgeSchema; - this.edges = edges; - this.index = new HashMap<>(edges.size()); - this.pointIndices = new HashMap<>(10000); - buildIndex(); - } + public EdgeGroup(EdgeSchema edgeSchema, List edges) { + this.edgeSchema = edgeSchema; + this.edges = edges; + this.index = new HashMap<>(edges.size()); + this.pointIndices = new HashMap<>(10000); + buildIndex(); + } - private void buildIndex() { - int index = 0; - for (Edge e : edges) { - String key = makeKey(e.getSrcId(), e.getDstId()); - this.index.computeIfAbsent(key, k -> new ArrayList<>()).add(index); - this.pointIndices.computeIfAbsent(e.getSrcId(), k -> new ArrayList<>()).add(index); - this.pointIndices.computeIfAbsent(e.getDstId(), k -> new ArrayList<>()).add(index); - index++; - } + private void buildIndex() { + int index = 0; + for (Edge e : edges) { + String key = makeKey(e.getSrcId(), e.getDstId()); + this.index.computeIfAbsent(key, k -> new ArrayList<>()).add(index); + this.pointIndices.computeIfAbsent(e.getSrcId(), k -> new ArrayList<>()).add(index); + this.pointIndices.computeIfAbsent(e.getDstId(), k -> new ArrayList<>()).add(index); + index++; } + } - public EdgeSchema getEdgeSchema() { - return edgeSchema; - } + public EdgeSchema getEdgeSchema() { + return edgeSchema; + } - public List getOutEdges() { - return edges; - } + public List getOutEdges() { + return edges; + } - public List getOutEdges(String src) { - if (pointIndices.get(src) == null) { - return new ArrayList<>(); - } - return pointIndices.get(src).stream().map(edges::get) - .filter(e -> e.getSrcId().equals(src)) - .collect(Collectors.toList()); + public List getOutEdges(String src) { + if (pointIndices.get(src) == null) { + return new ArrayList<>(); } + return pointIndices.get(src).stream() + .map(edges::get) + .filter(e -> e.getSrcId().equals(src)) + .collect(Collectors.toList()); + } - public List getInEdges(String src) { - if (pointIndices.get(src) == null) { - return new ArrayList<>(); - } - return pointIndices.get(src).stream().map(edges::get) - .filter(e -> e.getDstId().equals(src)) - .collect(Collectors.toList()); + public List getInEdges(String src) { + if (pointIndices.get(src) == null) { + return new ArrayList<>(); } + return pointIndices.get(src).stream() + .map(edges::get) + .filter(e -> e.getDstId().equals(src)) + .collect(Collectors.toList()); + } - public List getEdge(String src, String dst) { - List edgeIndices = index.get(makeKey(src, dst)); - if (edgeIndices == null) { - return Collections.emptyList(); - } - List edges = new ArrayList<>(edgeIndices.size()); - for (int i : edgeIndices) { - edges.add(this.edges.get(i)); - } - return edges; + public List getEdge(String src, String dst) { + List edgeIndices = index.get(makeKey(src, dst)); + if (edgeIndices == null) { + return Collections.emptyList(); } - - public int addEdge(Edge newEdge) { - if (newEdge == null) { - return ErrorCode.GRAPH_ENTITY_GROUP_INSERT_FAILED; - } - this.edges.add(newEdge); - int index = this.edges.size() - 1; - String key = makeKey(newEdge.getSrcId(), newEdge.getDstId()); - this.index.computeIfAbsent(key, k -> new ArrayList<>()).add(index); - this.pointIndices.computeIfAbsent(newEdge.getSrcId(), k -> new ArrayList<>()).add(index); - this.pointIndices.computeIfAbsent(newEdge.getDstId(), k -> new ArrayList<>()).add(index); - return ErrorCode.SUCCESS; + List edges = new ArrayList<>(edgeIndices.size()); + for (int i : edgeIndices) { + edges.add(this.edges.get(i)); } + return edges; + } - public int removeEdge(Edge edge) { - if (edge == null) { - return ErrorCode.GRAPH_ENTITY_GROUP_REMOVE_FAILED; - } - String key = makeKey(edge.getSrcId(), edge.getDstId()); - if (!index.containsKey(key)) { - return ErrorCode.GRAPH_ENTITY_GROUP_REMOVE_FAILED; - } - List indices = index.get(key); - List deleteOffset = new ArrayList<>(); - for (int offset : indices) { - Edge existsEdge = this.edges.get(offset); - boolean needDelete = existsEdge.equals(edge); - if (needDelete) { - edges.set(offset, null); - deleteOffset.add(offset); - } - } - for (int del : deleteOffset) { - this.index.get(key).remove(del); - this.pointIndices.get(edge.getSrcId()).remove(del); - this.pointIndices.get(edge.getDstId()).remove(del); - } - return ErrorCode.SUCCESS; + public int addEdge(Edge newEdge) { + if (newEdge == null) { + return ErrorCode.GRAPH_ENTITY_GROUP_INSERT_FAILED; } + this.edges.add(newEdge); + int index = this.edges.size() - 1; + String key = makeKey(newEdge.getSrcId(), newEdge.getDstId()); + this.index.computeIfAbsent(key, k -> new ArrayList<>()).add(index); + this.pointIndices.computeIfAbsent(newEdge.getSrcId(), k -> new ArrayList<>()).add(index); + this.pointIndices.computeIfAbsent(newEdge.getDstId(), k -> new ArrayList<>()).add(index); + return ErrorCode.SUCCESS; + } - private String makeKey(String src, String dst) { - return src + "-" + dst; + public int removeEdge(Edge edge) { + if (edge == null) { + return ErrorCode.GRAPH_ENTITY_GROUP_REMOVE_FAILED; } + String key = makeKey(edge.getSrcId(), edge.getDstId()); + if (!index.containsKey(key)) { + return ErrorCode.GRAPH_ENTITY_GROUP_REMOVE_FAILED; + } + List indices = index.get(key); + List deleteOffset = new ArrayList<>(); + for (int offset : indices) { + Edge existsEdge = this.edges.get(offset); + boolean needDelete = existsEdge.equals(edge); + if (needDelete) { + edges.set(offset, null); + deleteOffset.add(offset); + } + } + for (int del : deleteOffset) { + this.index.get(key).remove(del); + this.pointIndices.get(edge.getSrcId()).remove(del); + this.pointIndices.get(edge.getDstId()).remove(del); + } + return ErrorCode.SUCCESS; + } + + private String makeKey(String src, String dst) { + return src + "-" + dst; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/EdgeSchema.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/EdgeSchema.java index 8e6ea0b0c..c2be84143 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/EdgeSchema.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/EdgeSchema.java @@ -23,38 +23,37 @@ public class EdgeSchema implements Schema { - private final String label; - private final List fields; - private final String srcIdField; - private final String dstIdField; - - public EdgeSchema(String label, String srcIdField, - String dstIdField, List fields) { - this.label = label; - this.fields = fields; - this.srcIdField = srcIdField; - this.dstIdField = dstIdField; - } - - @Override - public List getFields() { - return fields; - } - - public String getLabel() { - return label; - } - - public String getSrcIdField() { - return srcIdField; - } - - public String getDstIdField() { - return dstIdField; - } - - @Override - public String getName() { - return label; - } + private final String label; + private final List fields; + private final String srcIdField; + private final String dstIdField; + + public EdgeSchema(String label, String srcIdField, String dstIdField, List fields) { + this.label = label; + this.fields = fields; + this.srcIdField = srcIdField; + this.dstIdField = dstIdField; + } + + @Override + public List getFields() { + return fields; + } + + public String getLabel() { + return label; + } + + public String getSrcIdField() { + return srcIdField; + } + + public String getDstIdField() { + return dstIdField; + } + + @Override + public String getName() { + return label; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/EntityGroup.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/EntityGroup.java index f754de103..88785a26a 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/EntityGroup.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/EntityGroup.java @@ -19,5 +19,4 @@ package org.apache.geaflow.ai.graph.io; -public interface EntityGroup { -} +public interface EntityGroup {} diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/GraphFileReader.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/GraphFileReader.java index 87e72327c..5f32b1c3c 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/GraphFileReader.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/GraphFileReader.java @@ -28,53 +28,61 @@ public class GraphFileReader { - public static MemoryGraph getGraph(ClassLoader classLoader, String path, long limit, - Function vertexMapper, - Function edgeMapper) throws IOException { - Map> result = ResourceFileScanner.scanGraphLdbcSfFolder(classLoader, path); - GraphSchema graphSchema = new GraphSchema(); - Map entities = new HashMap<>(); - for (Map.Entry> entry : result.entrySet()) { - String entityName = entry.getKey(); - List fileNames = entry.getValue(); - for (String fileName : fileNames) { - CsvFileReader reader = new CsvFileReader(limit); - reader.readCsvFile(path + "/" + entityName + "/" + fileName); - List colSchema = reader.getColSchema(); - boolean isVertex = colSchema.contains("id"); - if (isVertex) { - int idIndex = colSchema.indexOf("id"); - if (idIndex < 0) { - throw new RuntimeException("Cannot find index of the id column."); - } - VertexSchema vertexSchema = new VertexSchema(entityName, colSchema.get(idIndex), colSchema); - List vertices = new ArrayList<>(reader.getRowCount()); - for (int i = 0; i < reader.getRowCount(); i++) { - List row = reader.getRow(i); - Vertex newVertex = vertexMapper.apply(new Vertex(entityName, row.get(idIndex), row)); - vertices.add(newVertex); - } - VertexGroup vertexGroup = new VertexGroup(vertexSchema, vertices); - entities.put(entityName, vertexGroup); - graphSchema.addVertex(vertexSchema); - } else { - boolean containTime = colSchema.contains("creationDate"); - int srcIdIndex = containTime ? 1 : 0; - int dstIdIndex = containTime ? 2 : 1; - EdgeSchema edgeSchema = new EdgeSchema(entityName, colSchema.get(srcIdIndex), colSchema.get(dstIdIndex), colSchema); - List edges = new ArrayList<>(reader.getRowCount()); - for (int i = 0; i < reader.getRowCount(); i++) { - List row = reader.getRow(i); - Edge newEdge = edgeMapper.apply(new Edge(entityName, row.get(srcIdIndex), row.get(dstIdIndex), row)); - edges.add(newEdge); - } - EdgeGroup edgeGroup = new EdgeGroup(edgeSchema, edges); - entities.put(entityName, edgeGroup); - graphSchema.addEdge(edgeSchema); - } - } + public static MemoryGraph getGraph( + ClassLoader classLoader, + String path, + long limit, + Function vertexMapper, + Function edgeMapper) + throws IOException { + Map> result = ResourceFileScanner.scanGraphLdbcSfFolder(classLoader, path); + GraphSchema graphSchema = new GraphSchema(); + Map entities = new HashMap<>(); + for (Map.Entry> entry : result.entrySet()) { + String entityName = entry.getKey(); + List fileNames = entry.getValue(); + for (String fileName : fileNames) { + CsvFileReader reader = new CsvFileReader(limit); + reader.readCsvFile(path + "/" + entityName + "/" + fileName); + List colSchema = reader.getColSchema(); + boolean isVertex = colSchema.contains("id"); + if (isVertex) { + int idIndex = colSchema.indexOf("id"); + if (idIndex < 0) { + throw new RuntimeException("Cannot find index of the id column."); + } + VertexSchema vertexSchema = + new VertexSchema(entityName, colSchema.get(idIndex), colSchema); + List vertices = new ArrayList<>(reader.getRowCount()); + for (int i = 0; i < reader.getRowCount(); i++) { + List row = reader.getRow(i); + Vertex newVertex = vertexMapper.apply(new Vertex(entityName, row.get(idIndex), row)); + vertices.add(newVertex); + } + VertexGroup vertexGroup = new VertexGroup(vertexSchema, vertices); + entities.put(entityName, vertexGroup); + graphSchema.addVertex(vertexSchema); + } else { + boolean containTime = colSchema.contains("creationDate"); + int srcIdIndex = containTime ? 1 : 0; + int dstIdIndex = containTime ? 2 : 1; + EdgeSchema edgeSchema = + new EdgeSchema( + entityName, colSchema.get(srcIdIndex), colSchema.get(dstIdIndex), colSchema); + List edges = new ArrayList<>(reader.getRowCount()); + for (int i = 0; i < reader.getRowCount(); i++) { + List row = reader.getRow(i); + Edge newEdge = + edgeMapper.apply( + new Edge(entityName, row.get(srcIdIndex), row.get(dstIdIndex), row)); + edges.add(newEdge); + } + EdgeGroup edgeGroup = new EdgeGroup(edgeSchema, edges); + entities.put(entityName, edgeGroup); + graphSchema.addEdge(edgeSchema); } - return new MemoryGraph(graphSchema, entities); + } } - + return new MemoryGraph(graphSchema, entities); + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/GraphSchema.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/GraphSchema.java index d18117841..728f4f046 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/GraphSchema.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/GraphSchema.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.ai.common.config.Constants; import org.apache.geaflow.ai.graph.GraphEdge; import org.apache.geaflow.ai.graph.GraphVertex; @@ -28,86 +29,86 @@ public class GraphSchema implements Schema { - private String graphName = Constants.PREFIX_GRAPH; - private final List vertexSchemaList; - private final List edgeSchemaList; - private PromptFormatter promptFormatter; - - public GraphSchema() { - this.vertexSchemaList = new ArrayList<>(); - this.edgeSchemaList = new ArrayList<>(); - } - - public GraphSchema(List vertexSchemaList, List edgeSchemaList) { - this.vertexSchemaList = vertexSchemaList; - this.edgeSchemaList = edgeSchemaList; - } - - public void addVertex(VertexSchema vertexSchema) { - vertexSchemaList.add(vertexSchema); - } - - public void addEdge(EdgeSchema edgeSchema) { - edgeSchemaList.add(edgeSchema); - } - - public Schema getSchema(String label) { - for (VertexSchema vs : vertexSchemaList) { - if (label.equals(vs.getLabel())) { - return vs; - } - } - for (EdgeSchema vs : edgeSchemaList) { - if (label.equals(vs.getLabel())) { - return vs; - } - } - return null; - } - - @Override - public List getFields() { - return null; + private String graphName = Constants.PREFIX_GRAPH; + private final List vertexSchemaList; + private final List edgeSchemaList; + private PromptFormatter promptFormatter; + + public GraphSchema() { + this.vertexSchemaList = new ArrayList<>(); + this.edgeSchemaList = new ArrayList<>(); + } + + public GraphSchema(List vertexSchemaList, List edgeSchemaList) { + this.vertexSchemaList = vertexSchemaList; + this.edgeSchemaList = edgeSchemaList; + } + + public void addVertex(VertexSchema vertexSchema) { + vertexSchemaList.add(vertexSchema); + } + + public void addEdge(EdgeSchema edgeSchema) { + edgeSchemaList.add(edgeSchema); + } + + public Schema getSchema(String label) { + for (VertexSchema vs : vertexSchemaList) { + if (label.equals(vs.getLabel())) { + return vs; + } } - - public List getVertexSchemaList() { - return vertexSchemaList; - } - - public List getEdgeSchemaList() { - return edgeSchemaList; - } - - public void setName(String graphName) { - this.graphName = graphName; - } - - @Override - public String getName() { - return graphName; - } - - public void setPromptFormatter(PromptFormatter promptFormatter) { - this.promptFormatter = promptFormatter; + for (EdgeSchema vs : edgeSchemaList) { + if (label.equals(vs.getLabel())) { + return vs; + } } - - public String getPrompt() { - return promptFormatter.prompt(this); - } - - public String getPrompt(GraphVertex entity) { - if (promptFormatter == null) { - return entity.toString(); - } else { - return promptFormatter.prompt(entity); - } + return null; + } + + @Override + public List getFields() { + return null; + } + + public List getVertexSchemaList() { + return vertexSchemaList; + } + + public List getEdgeSchemaList() { + return edgeSchemaList; + } + + public void setName(String graphName) { + this.graphName = graphName; + } + + @Override + public String getName() { + return graphName; + } + + public void setPromptFormatter(PromptFormatter promptFormatter) { + this.promptFormatter = promptFormatter; + } + + public String getPrompt() { + return promptFormatter.prompt(this); + } + + public String getPrompt(GraphVertex entity) { + if (promptFormatter == null) { + return entity.toString(); + } else { + return promptFormatter.prompt(entity); } + } - public String getPrompt(GraphEdge entity, GraphVertex start, GraphVertex end) { - if (promptFormatter == null) { - return entity.toString(); - } else { - return promptFormatter.prompt(entity, start, end); - } + public String getPrompt(GraphEdge entity, GraphVertex start, GraphVertex end) { + if (promptFormatter == null) { + return entity.toString(); + } else { + return promptFormatter.prompt(entity, start, end); } + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/MemoryGraph.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/MemoryGraph.java index c1ae84b1e..e75ef3e4a 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/MemoryGraph.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/MemoryGraph.java @@ -20,178 +20,179 @@ package org.apache.geaflow.ai.graph.io; import java.util.*; + import org.apache.geaflow.ai.common.ErrorCode; import org.apache.geaflow.ai.graph.Graph; public class MemoryGraph implements Graph { - public GraphSchema graphSchema; - public Map entities; - - public MemoryGraph(GraphSchema graphSchema, Map entities) { - this.graphSchema = graphSchema; - this.entities = entities; + public GraphSchema graphSchema; + public Map entities; + + public MemoryGraph(GraphSchema graphSchema, Map entities) { + this.graphSchema = graphSchema; + this.entities = entities; + } + + @Override + public GraphSchema getGraphSchema() { + return graphSchema; + } + + public void setGraphSchema(GraphSchema graphSchema) { + this.graphSchema = graphSchema; + } + + private EntityGroup getEntity(String entityName) { + return entities.get(entityName); + } + + @Override + public Vertex getVertex(String label, String id) { + if (label == null) { + for (VertexSchema schema : getGraphSchema().getVertexSchemaList()) { + Vertex res = getVertex(schema.getLabel(), id); + if (res != null) { + return res; + } + } + } else { + VertexGroup vg = (VertexGroup) getEntity(label); + if (vg == null) { + return null; + } + return vg.getVertex(id); } - - @Override - public GraphSchema getGraphSchema() { - return graphSchema; + return null; + } + + @Override + public int removeVertex(String label, String id) { + EntityGroup vg = entities.get(label); + if (vg == null) { + return ErrorCode.GRAPH_ENTITY_GROUP_NOT_EXISTS; } - - public void setGraphSchema(GraphSchema graphSchema) { - this.graphSchema = graphSchema; + if (!(vg instanceof VertexGroup)) { + return ErrorCode.GRAPH_ENTITY_GROUP_NOT_MATCH; } - - private EntityGroup getEntity(String entityName) { - return entities.get(entityName); + VertexGroup vertexGroup = (VertexGroup) vg; + return vertexGroup.removeVertex(id); + } + + @Override + public int updateVertex(Vertex newVertex) { + String label = newVertex.getLabel(); + EntityGroup vg = entities.get(label); + if (vg == null) { + return ErrorCode.GRAPH_ENTITY_GROUP_NOT_EXISTS; } - - @Override - public Vertex getVertex(String label, String id) { - if (label == null) { - for (VertexSchema schema : getGraphSchema().getVertexSchemaList()) { - Vertex res = getVertex(schema.getLabel(), id); - if (res != null) { - return res; - } - } - } else { - VertexGroup vg = (VertexGroup) getEntity(label); - if (vg == null) { - return null; - } - return vg.getVertex(id); - } - return null; + if (!(vg instanceof VertexGroup)) { + return ErrorCode.GRAPH_ENTITY_GROUP_NOT_MATCH; } - - @Override - public int removeVertex(String label, String id) { - EntityGroup vg = entities.get(label); - if (vg == null) { - return ErrorCode.GRAPH_ENTITY_GROUP_NOT_EXISTS; - } - if (!(vg instanceof VertexGroup)) { - return ErrorCode.GRAPH_ENTITY_GROUP_NOT_MATCH; - } - VertexGroup vertexGroup = (VertexGroup) vg; - return vertexGroup.removeVertex(id); + VertexGroup vertexGroup = (VertexGroup) vg; + return vertexGroup.updateVertex(newVertex); + } + + @Override + public int addVertex(Vertex newVertex) { + String label = newVertex.getLabel(); + EntityGroup vg = entities.get(label); + if (vg == null) { + return ErrorCode.GRAPH_ENTITY_GROUP_NOT_EXISTS; } - - @Override - public int updateVertex(Vertex newVertex) { - String label = newVertex.getLabel(); - EntityGroup vg = entities.get(label); - if (vg == null) { - return ErrorCode.GRAPH_ENTITY_GROUP_NOT_EXISTS; - } - if (!(vg instanceof VertexGroup)) { - return ErrorCode.GRAPH_ENTITY_GROUP_NOT_MATCH; - } - VertexGroup vertexGroup = (VertexGroup) vg; - return vertexGroup.updateVertex(newVertex); + if (!(vg instanceof VertexGroup)) { + return ErrorCode.GRAPH_ENTITY_GROUP_NOT_MATCH; } - - @Override - public int addVertex(Vertex newVertex) { - String label = newVertex.getLabel(); - EntityGroup vg = entities.get(label); - if (vg == null) { - return ErrorCode.GRAPH_ENTITY_GROUP_NOT_EXISTS; - } - if (!(vg instanceof VertexGroup)) { - return ErrorCode.GRAPH_ENTITY_GROUP_NOT_MATCH; - } - VertexGroup vertexGroup = (VertexGroup) vg; - return vertexGroup.addVertex(newVertex); + VertexGroup vertexGroup = (VertexGroup) vg; + return vertexGroup.addVertex(newVertex); + } + + @Override + public List getEdge(String label, String src, String dst) { + EdgeGroup eg = (EdgeGroup) getEntity(label); + if (eg == null) { + return Collections.emptyList(); } - - @Override - public List getEdge(String label, String src, String dst) { - EdgeGroup eg = (EdgeGroup) getEntity(label); - if (eg == null) { - return Collections.emptyList(); - } - return eg.getEdge(src, dst); + return eg.getEdge(src, dst); + } + + @Override + public int removeEdge(Edge edge) { + String label = edge.getLabel(); + EntityGroup vg = entities.get(label); + if (vg == null) { + return ErrorCode.GRAPH_ENTITY_GROUP_NOT_EXISTS; } - - @Override - public int removeEdge(Edge edge) { - String label = edge.getLabel(); - EntityGroup vg = entities.get(label); - if (vg == null) { - return ErrorCode.GRAPH_ENTITY_GROUP_NOT_EXISTS; - } - if (!(vg instanceof EdgeGroup)) { - return ErrorCode.GRAPH_ENTITY_GROUP_NOT_MATCH; - } - EdgeGroup edgeGroup = (EdgeGroup) vg; - return edgeGroup.removeEdge(edge); + if (!(vg instanceof EdgeGroup)) { + return ErrorCode.GRAPH_ENTITY_GROUP_NOT_MATCH; } - - @Override - public int addEdge(Edge newEdge) { - String label = newEdge.getLabel(); - EntityGroup vg = entities.get(label); - if (vg == null) { - return ErrorCode.GRAPH_ENTITY_GROUP_NOT_EXISTS; - } - if (!(vg instanceof EdgeGroup)) { - return ErrorCode.GRAPH_ENTITY_GROUP_NOT_MATCH; - } - EdgeGroup edgeGroup = (EdgeGroup) vg; - return edgeGroup.addEdge(newEdge); + EdgeGroup edgeGroup = (EdgeGroup) vg; + return edgeGroup.removeEdge(edge); + } + + @Override + public int addEdge(Edge newEdge) { + String label = newEdge.getLabel(); + EntityGroup vg = entities.get(label); + if (vg == null) { + return ErrorCode.GRAPH_ENTITY_GROUP_NOT_EXISTS; } - - @Override - public Iterator scanEdge(Vertex vertex) { - List> iterators = new ArrayList<>(); - for (EntityGroup entityGroup : this.entities.values()) { - if (entityGroup instanceof EdgeGroup) { - iterators.add(((EdgeGroup) entityGroup).getOutEdges(vertex.getId()).iterator()); - iterators.add(((EdgeGroup) entityGroup).getInEdges(vertex.getId()).iterator()); - } - } - return new CompositeIterator<>(iterators); + if (!(vg instanceof EdgeGroup)) { + return ErrorCode.GRAPH_ENTITY_GROUP_NOT_MATCH; } - - @Override - public Iterator scanVertex() { - List> iterators = new ArrayList<>(); - for (EntityGroup entityGroup : this.entities.values()) { - if (entityGroup instanceof VertexGroup) { - iterators.add(((VertexGroup) entityGroup).getVertices().iterator()); - } - } - return new CompositeIterator<>(iterators); + EdgeGroup edgeGroup = (EdgeGroup) vg; + return edgeGroup.addEdge(newEdge); + } + + @Override + public Iterator scanEdge(Vertex vertex) { + List> iterators = new ArrayList<>(); + for (EntityGroup entityGroup : this.entities.values()) { + if (entityGroup instanceof EdgeGroup) { + iterators.add(((EdgeGroup) entityGroup).getOutEdges(vertex.getId()).iterator()); + iterators.add(((EdgeGroup) entityGroup).getInEdges(vertex.getId()).iterator()); + } + } + return new CompositeIterator<>(iterators); + } + + @Override + public Iterator scanVertex() { + List> iterators = new ArrayList<>(); + for (EntityGroup entityGroup : this.entities.values()) { + if (entityGroup instanceof VertexGroup) { + iterators.add(((VertexGroup) entityGroup).getVertices().iterator()); + } } + return new CompositeIterator<>(iterators); + } - static class CompositeIterator implements Iterator { + static class CompositeIterator implements Iterator { - private final List> iterators; - private int currentIndex = 0; + private final List> iterators; + private int currentIndex = 0; - public CompositeIterator(List> iterators) { - this.iterators = iterators; - } + public CompositeIterator(List> iterators) { + this.iterators = iterators; + } - @Override - public boolean hasNext() { - while (currentIndex < iterators.size()) { - if (iterators.get(currentIndex).hasNext()) { - return true; - } - currentIndex++; - } - return false; - } + @Override + public boolean hasNext() { + while (currentIndex < iterators.size()) { + if (iterators.get(currentIndex).hasNext()) { + return true; + } + currentIndex++; + } + return false; + } - @Override - public T next() { - if (!hasNext()) { - throw new NoSuchElementException(); - } - return iterators.get(currentIndex).next(); - } + @Override + public T next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return iterators.get(currentIndex).next(); } + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/ResourceFileScanner.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/ResourceFileScanner.java index 050483e53..c5fd59661 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/ResourceFileScanner.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/ResourceFileScanner.java @@ -31,35 +31,37 @@ public class ResourceFileScanner { - public static Map> scanGraphLdbcSfFolder(ClassLoader classLoader, - String path) { - Map> resultMap = new HashMap<>(); + public static Map> scanGraphLdbcSfFolder( + ClassLoader classLoader, String path) { + Map> resultMap = new HashMap<>(); - try { - Path resourcePath = Paths.get(classLoader.getResource(path).toURI()); + try { + Path resourcePath = Paths.get(classLoader.getResource(path).toURI()); - Files.list(resourcePath) - .filter(Files::isDirectory) - .forEach(dirPath -> { - String folderName = dirPath.getFileName().toString(); - try { - List fileNames = Files.list(dirPath) - .filter(Files::isRegularFile) - .map(filePath -> filePath.getFileName().toString()) - .collect(Collectors.toList()); + Files.list(resourcePath) + .filter(Files::isDirectory) + .forEach( + dirPath -> { + String folderName = dirPath.getFileName().toString(); + try { + List fileNames = + Files.list(dirPath) + .filter(Files::isRegularFile) + .map(filePath -> filePath.getFileName().toString()) + .collect(Collectors.toList()); - resultMap.put(folderName, fileNames); - } catch (IOException e) { - System.err.println("Fail to read: " + dirPath.toString()); - e.printStackTrace(); - } - }); + resultMap.put(folderName, fileNames); + } catch (IOException e) { + System.err.println("Fail to read: " + dirPath.toString()); + e.printStackTrace(); + } + }); - } catch (IOException | URISyntaxException e) { - System.err.println("Fail to scan resource files."); - e.printStackTrace(); - } - - return resultMap; + } catch (IOException | URISyntaxException e) { + System.err.println("Fail to scan resource files."); + e.printStackTrace(); } + + return resultMap; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/Schema.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/Schema.java index 711cf04c3..273f5b625 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/Schema.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/Schema.java @@ -23,8 +23,7 @@ public interface Schema { - String getName(); - - List getFields(); + String getName(); + List getFields(); } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/TextFileReader.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/TextFileReader.java index 2ce06823b..766956b2a 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/TextFileReader.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/TextFileReader.java @@ -22,85 +22,92 @@ import java.io.*; import java.util.ArrayList; import java.util.List; + import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class TextFileReader { - private static final Logger LOGGER = LoggerFactory.getLogger(TextFileReader.class); + private static final Logger LOGGER = LoggerFactory.getLogger(TextFileReader.class); - List fileContent; - long limit; + List fileContent; + long limit; - public TextFileReader(long limit) { - this.fileContent = new ArrayList<>(); - this.limit = limit; - } + public TextFileReader(long limit) { + this.fileContent = new ArrayList<>(); + this.limit = limit; + } - public void readFile(String fileName) throws IOException { - InputStream inputStream = getClass().getClassLoader().getResourceAsStream(fileName); - if (inputStream == null) { - LOGGER.error("Cannot find the file: {} (tried as resource)", fileName); - File file = new File(fileName); - if (file.exists() && file.isFile()) { - try { - inputStream = new FileInputStream(file); - } catch (FileNotFoundException e) { - throw new IOException("Cannot find the file: " + fileName - + " (tried both as resource and as absolute path)", e); - } catch (Throwable e2) { - throw new IOException("Cannot open the file: " + fileName - + " (tried both as resource and as absolute path)", e2); - } - } else { - throw new IOException("Cannot find the file: " + fileName - + " (tried both as resource and as absolute path)"); - } + public void readFile(String fileName) throws IOException { + InputStream inputStream = getClass().getClassLoader().getResourceAsStream(fileName); + if (inputStream == null) { + LOGGER.error("Cannot find the file: {} (tried as resource)", fileName); + File file = new File(fileName); + if (file.exists() && file.isFile()) { + try { + inputStream = new FileInputStream(file); + } catch (FileNotFoundException e) { + throw new IOException( + "Cannot find the file: " + + fileName + + " (tried both as resource and as absolute path)", + e); + } catch (Throwable e2) { + throw new IOException( + "Cannot open the file: " + + fileName + + " (tried both as resource and as absolute path)", + e2); } + } else { + throw new IOException( + "Cannot find the file: " + fileName + " (tried both as resource and as absolute path)"); + } + } - long count = 0; - try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) { - String line; + long count = 0; + try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) { + String line; - while ((line = reader.readLine()) != null) { - count++; - if (StringUtils.isNotBlank(line)) { - fileContent.add(line); - } - if (count > limit) { - break; - } - } + while ((line = reader.readLine()) != null) { + count++; + if (StringUtils.isNotBlank(line)) { + fileContent.add(line); + } + if (count > limit) { + break; } + } } + } - public List getFileContent() { - return fileContent; - } + public List getFileContent() { + return fileContent; + } - public int getRowCount() { - if (fileContent.isEmpty()) { - return 0; - } - return fileContent.size(); + public int getRowCount() { + if (fileContent.isEmpty()) { + return 0; } + return fileContent.size(); + } - public String getRow(int rowIndex) { - String row = null; - int rowCount = getRowCount(); - if (rowIndex < 0 || rowIndex >= rowCount) { - throw new IndexOutOfBoundsException("Row index out of range: " + rowIndex); - } - row = fileContent.get(rowIndex); - return row; + public String getRow(int rowIndex) { + String row = null; + int rowCount = getRowCount(); + if (rowIndex < 0 || rowIndex >= rowCount) { + throw new IndexOutOfBoundsException("Row index out of range: " + rowIndex); } + row = fileContent.get(rowIndex); + return row; + } - public void printContent() { - LOGGER.info("Data content:"); - for (String content : fileContent) { - LOGGER.info(content); - } - LOGGER.info("Total row count: " + getRowCount()); + public void printContent() { + LOGGER.info("Data content:"); + for (String content : fileContent) { + LOGGER.info(content); } + LOGGER.info("Total row count: " + getRowCount()); + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/Vertex.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/Vertex.java index 72b320c0d..4d0619b72 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/Vertex.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/Vertex.java @@ -24,47 +24,47 @@ public class Vertex { - private final String id; - private final String label; - private final List values; + private final String id; + private final String label; + private final List values; - public Vertex(String label, String id, List values) { - this.label = label; - this.id = id; - this.values = values; - } + public Vertex(String label, String id, List values) { + this.label = label; + this.id = id; + this.values = values; + } - public String getId() { - return id; - } + public String getId() { + return id; + } - public String getLabel() { - return label; - } + public String getLabel() { + return label; + } - public List getValues() { - return values; - } + public List getValues() { + return values; + } - @Override - public String toString() { - return label + " | " + id + " | " + String.join("|", values); - } + @Override + public String toString() { + return label + " | " + id + " | " + String.join("|", values); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - Vertex vertex = (Vertex) o; - return Objects.equals(id, vertex.id) && Objects.equals(label, vertex.label); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(id, label); + if (o == null || getClass() != o.getClass()) { + return false; } + Vertex vertex = (Vertex) o; + return Objects.equals(id, vertex.id) && Objects.equals(label, vertex.label); + } + + @Override + public int hashCode() { + return Objects.hash(id, label); + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/VertexGroup.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/VertexGroup.java index fa274417d..0674b98b1 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/VertexGroup.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/VertexGroup.java @@ -24,81 +24,82 @@ import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; + import org.apache.geaflow.ai.common.ErrorCode; public class VertexGroup implements EntityGroup { - public final VertexSchema vertexSchema; - private final List vertices; - private final Map index; + public final VertexSchema vertexSchema; + private final List vertices; + private final Map index; - public VertexGroup(VertexSchema vertexSchema, List vertices) { - this.vertexSchema = vertexSchema; - this.vertices = vertices; - this.index = new HashMap<>(vertices.size()); - buildIndex(); - } + public VertexGroup(VertexSchema vertexSchema, List vertices) { + this.vertexSchema = vertexSchema; + this.vertices = vertices; + this.index = new HashMap<>(vertices.size()); + buildIndex(); + } - private void buildIndex() { - int index = 0; - for (Vertex v : vertices) { - this.index.put(v.getId(), index); - index++; - } + private void buildIndex() { + int index = 0; + for (Vertex v : vertices) { + this.index.put(v.getId(), index); + index++; } + } - public int addVertex(Vertex newVertex) { - if (newVertex == null || newVertex.getId() == null) { - return ErrorCode.GRAPH_ENTITY_GROUP_INSERT_FAILED; - } - if (index.containsKey(newVertex.getId())) { - return ErrorCode.GRAPH_ENTITY_GROUP_INSERT_FAILED; - } - this.vertices.add(newVertex); - this.index.put(newVertex.getId(), vertices.size() - 1); - return ErrorCode.SUCCESS; + public int addVertex(Vertex newVertex) { + if (newVertex == null || newVertex.getId() == null) { + return ErrorCode.GRAPH_ENTITY_GROUP_INSERT_FAILED; } - - public int updateVertex(Vertex newVertex) { - if (newVertex == null || newVertex.getId() == null) { - return ErrorCode.GRAPH_ENTITY_GROUP_UPDATE_FAILED; - } - if (!index.containsKey(newVertex.getId())) { - return ErrorCode.GRAPH_ENTITY_GROUP_UPDATE_FAILED; - } - int offset = index.get(newVertex.getId()); - this.vertices.set(offset, newVertex); - return ErrorCode.SUCCESS; + if (index.containsKey(newVertex.getId())) { + return ErrorCode.GRAPH_ENTITY_GROUP_INSERT_FAILED; } + this.vertices.add(newVertex); + this.index.put(newVertex.getId(), vertices.size() - 1); + return ErrorCode.SUCCESS; + } - public int removeVertex(String id) { - if (id == null) { - return ErrorCode.GRAPH_ENTITY_GROUP_REMOVE_FAILED; - } - if (!index.containsKey(id)) { - return ErrorCode.GRAPH_ENTITY_GROUP_REMOVE_FAILED; - } - int offset = index.get(id); - vertices.set(offset, null); - index.remove(id); - return ErrorCode.SUCCESS; + public int updateVertex(Vertex newVertex) { + if (newVertex == null || newVertex.getId() == null) { + return ErrorCode.GRAPH_ENTITY_GROUP_UPDATE_FAILED; } - - public VertexSchema getVertexSchema() { - return vertexSchema; + if (!index.containsKey(newVertex.getId())) { + return ErrorCode.GRAPH_ENTITY_GROUP_UPDATE_FAILED; } + int offset = index.get(newVertex.getId()); + this.vertices.set(offset, newVertex); + return ErrorCode.SUCCESS; + } - public List getVertices() { - List vertices = this.vertices.stream() - .filter(Objects::nonNull).collect(Collectors.toList()); - return vertices; + public int removeVertex(String id) { + if (id == null) { + return ErrorCode.GRAPH_ENTITY_GROUP_REMOVE_FAILED; } + if (!index.containsKey(id)) { + return ErrorCode.GRAPH_ENTITY_GROUP_REMOVE_FAILED; + } + int offset = index.get(id); + vertices.set(offset, null); + index.remove(id); + return ErrorCode.SUCCESS; + } + + public VertexSchema getVertexSchema() { + return vertexSchema; + } + + public List getVertices() { + List vertices = + this.vertices.stream().filter(Objects::nonNull).collect(Collectors.toList()); + return vertices; + } - public Vertex getVertex(String id) { - if (index.get(id) != null) { - return vertices.get(index.get(id)); - } else { - return null; - } + public Vertex getVertex(String id) { + if (index.get(id) != null) { + return vertices.get(index.get(id)); + } else { + return null; } + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/VertexSchema.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/VertexSchema.java index 633e37490..3f58bf3f3 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/VertexSchema.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/graph/io/VertexSchema.java @@ -23,31 +23,31 @@ public class VertexSchema implements Schema { - private final String label; - private final List fields; - private final String idField; - - public VertexSchema(String label, String idField, List fields) { - this.idField = idField; - this.fields = fields; - this.label = label; - } - - @Override - public List getFields() { - return fields; - } - - public String getLabel() { - return label; - } - - public String getIdField() { - return idField; - } - - @Override - public String getName() { - return label; - } + private final String label; + private final List fields; + private final String idField; + + public VertexSchema(String label, String idField, List fields) { + this.idField = idField; + this.fields = fields; + this.label = label; + } + + @Override + public List getFields() { + return fields; + } + + public String getLabel() { + return label; + } + + public String getIdField() { + return idField; + } + + @Override + public String getName() { + return label; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/EmbeddingIndexStore.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/EmbeddingIndexStore.java index 8fae2e1f0..cb49b4dd4 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/EmbeddingIndexStore.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/EmbeddingIndexStore.java @@ -19,10 +19,10 @@ package org.apache.geaflow.ai.index; -import com.google.gson.Gson; import java.io.*; import java.nio.charset.Charset; import java.util.*; + import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.apache.geaflow.ai.common.config.Constants; @@ -39,219 +39,224 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.gson.Gson; + public class EmbeddingIndexStore implements IndexStore { - private static final Logger LOGGER = LoggerFactory.getLogger(EmbeddingIndexStore.class); + private static final Logger LOGGER = LoggerFactory.getLogger(EmbeddingIndexStore.class); - private GraphAccessor graphAccessor; - private VerbalizationFunction verbFunc; - private String indexFilePath; - private ModelConfig modelConfig; - private Map> indexStoreMap; + private GraphAccessor graphAccessor; + private VerbalizationFunction verbFunc; + private String indexFilePath; + private ModelConfig modelConfig; + private Map> indexStoreMap; - public void initStore(GraphAccessor graphAccessor, VerbalizationFunction func, - String indexFilePath, ModelConfig modelInfo) { - this.graphAccessor = graphAccessor; - this.verbFunc = func; - this.indexFilePath = indexFilePath; - this.modelConfig = modelInfo; - this.indexStoreMap = new HashMap<>(); + public void initStore( + GraphAccessor graphAccessor, + VerbalizationFunction func, + String indexFilePath, + ModelConfig modelInfo) { + this.graphAccessor = graphAccessor; + this.verbFunc = func; + this.indexFilePath = indexFilePath; + this.modelConfig = modelInfo; + this.indexStoreMap = new HashMap<>(); - //Read index items from indexFilePath - Map key2EntityMap = new HashMap<>(); - for (Iterator itV = this.graphAccessor.scanVertex(); itV.hasNext(); ) { - GraphVertex vertex = itV.next(); - key2EntityMap.put(ModelUtils.getGraphEntityKey(vertex), vertex); - for (Iterator itE = this.graphAccessor.scanEdge(vertex); itE.hasNext(); ) { - GraphEdge edge = itE.next(); - key2EntityMap.put(ModelUtils.getGraphEntityKey(edge), edge); - } - } - LOGGER.info("Success to scan entities. total entities num: " + key2EntityMap.size()); + // Read index items from indexFilePath + Map key2EntityMap = new HashMap<>(); + for (Iterator itV = this.graphAccessor.scanVertex(); itV.hasNext(); ) { + GraphVertex vertex = itV.next(); + key2EntityMap.put(ModelUtils.getGraphEntityKey(vertex), vertex); + for (Iterator itE = this.graphAccessor.scanEdge(vertex); itE.hasNext(); ) { + GraphEdge edge = itE.next(); + key2EntityMap.put(ModelUtils.getGraphEntityKey(edge), edge); + } + } + LOGGER.info("Success to scan entities. total entities num: " + key2EntityMap.size()); - try { - File indexFile = new File(this.indexFilePath); + try { + File indexFile = new File(this.indexFilePath); - if (!indexFile.exists()) { - File parentDir = indexFile.getParentFile(); - if (parentDir != null && !parentDir.exists()) { - parentDir.mkdirs(); - } - indexFile.createNewFile(); - LOGGER.info("Success to create new index store file. Path: " + this.indexFilePath); - } - } catch (Throwable e) { - throw new RuntimeException(e); + if (!indexFile.exists()) { + File parentDir = indexFile.getParentFile(); + if (parentDir != null && !parentDir.exists()) { + parentDir.mkdirs(); } + indexFile.createNewFile(); + LOGGER.info("Success to create new index store file. Path: " + this.indexFilePath); + } + } catch (Throwable e) { + throw new RuntimeException(e); + } - - long count = 0; - try (BufferedReader reader = new BufferedReader( - new InputStreamReader( - new FileInputStream(this.indexFilePath), - Charset.defaultCharset()))) { - String line; - while ((line = reader.readLine()) != null) { - line = line.trim(); - if (line.isEmpty()) { - continue; - } - try { - EmbeddingService.EmbeddingResult embedding = - new Gson().fromJson(line, EmbeddingService.EmbeddingResult.class); - String key = embedding.input; - GraphEntity entity = key2EntityMap.get(key); - if (entity != null) { - this.indexStoreMap.computeIfAbsent(entity, k -> new ArrayList<>()).add(embedding); - } - count++; - } catch (Throwable e) { - LOGGER.info("Cannot parse embedding item: " + line); - } - } + long count = 0; + try (BufferedReader reader = + new BufferedReader( + new InputStreamReader( + new FileInputStream(this.indexFilePath), Charset.defaultCharset()))) { + String line; + while ((line = reader.readLine()) != null) { + line = line.trim(); + if (line.isEmpty()) { + continue; + } + try { + EmbeddingService.EmbeddingResult embedding = + new Gson().fromJson(line, EmbeddingService.EmbeddingResult.class); + String key = embedding.input; + GraphEntity entity = key2EntityMap.get(key); + if (entity != null) { + this.indexStoreMap.computeIfAbsent(entity, k -> new ArrayList<>()).add(embedding); + } + count++; } catch (Throwable e) { - throw new RuntimeException(e); + LOGGER.info("Cannot parse embedding item: " + line); } + } + } catch (Throwable e) { + throw new RuntimeException(e); + } - LOGGER.info("Success to read index store file. items num: " + count); - LOGGER.info("Success to rebuild index with file. index num: " + this.indexStoreMap.size()); - - - //Scan entities in the graph, make new index items - EmbeddingService embeddingService = new EmbeddingService(); - embeddingService.setModelConfig(modelInfo); + LOGGER.info("Success to read index store file. items num: " + count); + LOGGER.info("Success to rebuild index with file. index num: " + this.indexStoreMap.size()); - final int BATCH_SIZE = Constants.EMBEDDING_INDEX_STORE_BATCH_SIZE; - List pendingEntities = new ArrayList<>(BATCH_SIZE); - Set batchEntitiesBuffer = new HashSet<>(BATCH_SIZE); - List result = new ArrayList<>(); - final int REPORT_SIZE = Constants.EMBEDDING_INDEX_STORE_REPORT_SIZE; - long reportedCount = this.indexStoreMap.size(); - long addedCount = this.indexStoreMap.size(); - for (Iterator itV = graphAccessor.scanVertex(); itV.hasNext(); ) { - GraphVertex vertex = itV.next(); + // Scan entities in the graph, make new index items + EmbeddingService embeddingService = new EmbeddingService(); + embeddingService.setModelConfig(modelInfo); - // Scan vertices or edges, skip already indexed data, - // add un-indexed data to batch processing collection - if (!indexStoreMap.containsKey(vertex) && !batchEntitiesBuffer.contains(vertex)) { - batchEntitiesBuffer.add(vertex); - pendingEntities.add(vertex); - if (pendingEntities.size() >= BATCH_SIZE) { - result.addAll(indexBatch(embeddingService, pendingEntities)); - flushBatchIndex(result, false); - pendingEntities.clear(); - batchEntitiesBuffer.clear(); - addedCount += BATCH_SIZE; - } - } + final int BATCH_SIZE = Constants.EMBEDDING_INDEX_STORE_BATCH_SIZE; + List pendingEntities = new ArrayList<>(BATCH_SIZE); + Set batchEntitiesBuffer = new HashSet<>(BATCH_SIZE); + List result = new ArrayList<>(); + final int REPORT_SIZE = Constants.EMBEDDING_INDEX_STORE_REPORT_SIZE; + long reportedCount = this.indexStoreMap.size(); + long addedCount = this.indexStoreMap.size(); + for (Iterator itV = graphAccessor.scanVertex(); itV.hasNext(); ) { + GraphVertex vertex = itV.next(); - for (Iterator itE = graphAccessor.scanEdge(vertex); itE.hasNext(); ) { - GraphEdge edge = itE.next(); - if (!indexStoreMap.containsKey(edge) && !batchEntitiesBuffer.contains(edge)) { - batchEntitiesBuffer.add(edge); - pendingEntities.add(edge); - if (pendingEntities.size() >= BATCH_SIZE) { - result.addAll(indexBatch(embeddingService, pendingEntities)); - flushBatchIndex(result, false); - pendingEntities.clear(); - batchEntitiesBuffer.clear(); - addedCount += BATCH_SIZE; - } - } - } - if (addedCount - reportedCount > REPORT_SIZE) { - LOGGER.info("added batch index. added num: " + addedCount); - reportedCount = addedCount; - } + // Scan vertices or edges, skip already indexed data, + // add un-indexed data to batch processing collection + if (!indexStoreMap.containsKey(vertex) && !batchEntitiesBuffer.contains(vertex)) { + batchEntitiesBuffer.add(vertex); + pendingEntities.add(vertex); + if (pendingEntities.size() >= BATCH_SIZE) { + result.addAll(indexBatch(embeddingService, pendingEntities)); + flushBatchIndex(result, false); + pendingEntities.clear(); + batchEntitiesBuffer.clear(); + addedCount += BATCH_SIZE; } - if (pendingEntities.size() > 0) { + } + + for (Iterator itE = graphAccessor.scanEdge(vertex); itE.hasNext(); ) { + GraphEdge edge = itE.next(); + if (!indexStoreMap.containsKey(edge) && !batchEntitiesBuffer.contains(edge)) { + batchEntitiesBuffer.add(edge); + pendingEntities.add(edge); + if (pendingEntities.size() >= BATCH_SIZE) { result.addAll(indexBatch(embeddingService, pendingEntities)); - flushBatchIndex(result, true); - addedCount += pendingEntities.size(); + flushBatchIndex(result, false); pendingEntities.clear(); batchEntitiesBuffer.clear(); + addedCount += BATCH_SIZE; + } } - - LOGGER.info("Successfully added {} new index items. Total indexed: {}", - addedCount, indexStoreMap.size()); + } + if (addedCount - reportedCount > REPORT_SIZE) { + LOGGER.info("added batch index. added num: " + addedCount); + reportedCount = addedCount; + } + } + if (pendingEntities.size() > 0) { + result.addAll(indexBatch(embeddingService, pendingEntities)); + flushBatchIndex(result, true); + addedCount += pendingEntities.size(); + pendingEntities.clear(); + batchEntitiesBuffer.clear(); } - private List indexBatch(EmbeddingService service, List pendingEntities) { - if (pendingEntities == null || service == null || pendingEntities.isEmpty()) { - return new ArrayList<>(); - } - List pendingTexts = new ArrayList<>(pendingEntities.size()); - Map> entity2StartEndPair = new HashMap<>(); - for (GraphEntity e : pendingEntities) { - Integer start = pendingTexts.size(); - pendingTexts.addAll(ModelUtils.splitLongText( - Constants.EMBEDDING_INDEX_STORE_SPLIT_TEXT_CHUNK_SIZE, - verbFunc.verbalize(e).toArray(new String[0]))); - Integer end = pendingTexts.size(); - entity2StartEndPair.put(e, Pair.of(start, end)); - } + LOGGER.info( + "Successfully added {} new index items. Total indexed: {}", + addedCount, + indexStoreMap.size()); + } - Gson gson = new Gson(); - int batchSize = pendingEntities.size(); - List result = new ArrayList<>(); - List pendingTextsList = new ArrayList<>(pendingTexts); + private List indexBatch(EmbeddingService service, List pendingEntities) { + if (pendingEntities == null || service == null || pendingEntities.isEmpty()) { + return new ArrayList<>(); + } + List pendingTexts = new ArrayList<>(pendingEntities.size()); + Map> entity2StartEndPair = new HashMap<>(); + for (GraphEntity e : pendingEntities) { + Integer start = pendingTexts.size(); + pendingTexts.addAll( + ModelUtils.splitLongText( + Constants.EMBEDDING_INDEX_STORE_SPLIT_TEXT_CHUNK_SIZE, + verbFunc.verbalize(e).toArray(new String[0]))); + Integer end = pendingTexts.size(); + entity2StartEndPair.put(e, Pair.of(start, end)); + } - for (int i = 0; i < pendingTextsList.size(); i += batchSize) { - int end = Math.min(i + batchSize, pendingTextsList.size()); - List batch = pendingTextsList.subList(i, end); - String[] textsArray = batch.toArray(new String[0]); - String embeddingResultStr = service.embedding(textsArray); - List splitResults = Arrays.asList(embeddingResultStr.trim().split("\n")); - result.addAll(splitResults); + Gson gson = new Gson(); + int batchSize = pendingEntities.size(); + List result = new ArrayList<>(); + List pendingTextsList = new ArrayList<>(pendingTexts); - } + for (int i = 0; i < pendingTextsList.size(); i += batchSize) { + int end = Math.min(i + batchSize, pendingTextsList.size()); + List batch = pendingTextsList.subList(i, end); + String[] textsArray = batch.toArray(new String[0]); + String embeddingResultStr = service.embedding(textsArray); + List splitResults = Arrays.asList(embeddingResultStr.trim().split("\n")); + result.addAll(splitResults); + } - List formatResult = new ArrayList<>(); - for (Map.Entry> entry : entity2StartEndPair.entrySet()) { - GraphEntity e = entry.getKey(); - List embeddings = new ArrayList<>(); - for (int i = entry.getValue().getLeft(); i < entry.getValue().getRight(); i++) { - if (StringUtils.isNotBlank(result.get(i))) { - EmbeddingService.EmbeddingResult res = gson.fromJson(result.get(i), - EmbeddingService.EmbeddingResult.class); - res.input = ModelUtils.getGraphEntityKey(e); - formatResult.add(gson.toJson(res)); - embeddings.add(res); - } - } - indexStoreMap.put(e, embeddings); + List formatResult = new ArrayList<>(); + for (Map.Entry> entry : entity2StartEndPair.entrySet()) { + GraphEntity e = entry.getKey(); + List embeddings = new ArrayList<>(); + for (int i = entry.getValue().getLeft(); i < entry.getValue().getRight(); i++) { + if (StringUtils.isNotBlank(result.get(i))) { + EmbeddingService.EmbeddingResult res = + gson.fromJson(result.get(i), EmbeddingService.EmbeddingResult.class); + res.input = ModelUtils.getGraphEntityKey(e); + formatResult.add(gson.toJson(res)); + embeddings.add(res); } - return formatResult; + } + indexStoreMap.put(e, embeddings); } + return formatResult; + } - private void flushBatchIndex(List newItemStrings, boolean force) { - final int WRITE_SIZE = Constants.EMBEDDING_INDEX_STORE_FLUSH_WRITE_SIZE; - if (force || newItemStrings.size() >= WRITE_SIZE) { - try (FileWriter fw = new FileWriter(this.indexFilePath, true); - BufferedWriter writer = new BufferedWriter(fw); - PrintWriter out = new PrintWriter(writer)) { - for (String item : newItemStrings) { - out.println(item); - } - LOGGER.info("Success to append " + newItemStrings.size() + " new index items to file."); - } catch (IOException e) { - throw new RuntimeException("Failed to append to index file: " + this.indexFilePath, e); - } - newItemStrings.clear(); + private void flushBatchIndex(List newItemStrings, boolean force) { + final int WRITE_SIZE = Constants.EMBEDDING_INDEX_STORE_FLUSH_WRITE_SIZE; + if (force || newItemStrings.size() >= WRITE_SIZE) { + try (FileWriter fw = new FileWriter(this.indexFilePath, true); + BufferedWriter writer = new BufferedWriter(fw); + PrintWriter out = new PrintWriter(writer)) { + for (String item : newItemStrings) { + out.println(item); } + LOGGER.info("Success to append " + newItemStrings.size() + " new index items to file."); + } catch (IOException e) { + throw new RuntimeException("Failed to append to index file: " + this.indexFilePath, e); + } + newItemStrings.clear(); } + } - @Override - public List getEntityIndex(GraphEntity entity) { - if (entity != null && indexStoreMap.get(entity) != null) { - List resultList = indexStoreMap.get(entity); - List result = new ArrayList<>(); - for (EmbeddingService.EmbeddingResult res : resultList) { - double[] embedding = res.embedding; - result.add(new EmbeddingVector(embedding)); - } - return result; - } - return Collections.emptyList(); + @Override + public List getEntityIndex(GraphEntity entity) { + if (entity != null && indexStoreMap.get(entity) != null) { + List resultList = indexStoreMap.get(entity); + List result = new ArrayList<>(); + for (EmbeddingService.EmbeddingResult res : resultList) { + double[] embedding = res.embedding; + result.add(new EmbeddingVector(embedding)); + } + return result; } + return Collections.emptyList(); + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/EntityAttributeIndexStore.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/EntityAttributeIndexStore.java index aa9823876..7961e4bbc 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/EntityAttributeIndexStore.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/EntityAttributeIndexStore.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.ai.graph.GraphEdge; import org.apache.geaflow.ai.graph.GraphEntity; import org.apache.geaflow.ai.graph.GraphVertex; @@ -31,32 +32,32 @@ public class EntityAttributeIndexStore implements IndexStore { - private VerbalizationFunction verbFunc; + private VerbalizationFunction verbFunc; - public void initStore(VerbalizationFunction func) { - if (func != null) { - this.verbFunc = func; - } + public void initStore(VerbalizationFunction func) { + if (func != null) { + this.verbFunc = func; } + } - @Override - public List getEntityIndex(GraphEntity entity) { - if (entity instanceof GraphVertex) { - String verbalization = verbFunc.verbalize(new SubGraph().addVertex((GraphVertex) entity)); - List sentences = new ArrayList<>(); - sentences.add(verbalization); - KeywordVector keywordVector = new KeywordVector(sentences.toArray(new String[0])); - List results = new ArrayList<>(); - results.add(keywordVector); - return results; - } else { - String verbalization = verbFunc.verbalize(new SubGraph().addEdge((GraphEdge) entity)); - List sentences = new ArrayList<>(); - sentences.add(verbalization); - KeywordVector keywordVector = new KeywordVector(sentences.toArray(new String[0])); - List results = new ArrayList<>(); - results.add(keywordVector); - return results; - } + @Override + public List getEntityIndex(GraphEntity entity) { + if (entity instanceof GraphVertex) { + String verbalization = verbFunc.verbalize(new SubGraph().addVertex((GraphVertex) entity)); + List sentences = new ArrayList<>(); + sentences.add(verbalization); + KeywordVector keywordVector = new KeywordVector(sentences.toArray(new String[0])); + List results = new ArrayList<>(); + results.add(keywordVector); + return results; + } else { + String verbalization = verbFunc.verbalize(new SubGraph().addEdge((GraphEdge) entity)); + List sentences = new ArrayList<>(); + sentences.add(verbalization); + KeywordVector keywordVector = new KeywordVector(sentences.toArray(new String[0])); + List results = new ArrayList<>(); + results.add(keywordVector); + return results; } + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/IndexStore.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/IndexStore.java index 53d1d8e51..6abb4d59c 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/IndexStore.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/IndexStore.java @@ -20,10 +20,11 @@ package org.apache.geaflow.ai.index; import java.util.List; + import org.apache.geaflow.ai.graph.GraphEntity; import org.apache.geaflow.ai.index.vector.IVector; public interface IndexStore { - List getEntityIndex(GraphEntity entity); + List getEntityIndex(GraphEntity entity); } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/IndexStoreCache.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/IndexStoreCache.java index 6401b366b..945e4b109 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/IndexStoreCache.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/IndexStoreCache.java @@ -21,32 +21,33 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.ai.graph.GraphEdge; import org.apache.geaflow.ai.graph.GraphVertex; import org.apache.geaflow.ai.index.vector.IVector; public class IndexStoreCache { - public static final IndexStoreCache CACHE = new IndexStoreCache(); - public static final List STORE = new ArrayList<>(); + public static final IndexStoreCache CACHE = new IndexStoreCache(); + public static final List STORE = new ArrayList<>(); - static { - STORE.add(new EntityAttributeIndexStore()); - } + static { + STORE.add(new EntityAttributeIndexStore()); + } - public List getVertexIndex(GraphVertex graphVertex) { - List results = new ArrayList<>(); - for (IndexStore indexStore : STORE) { - results.addAll(indexStore.getEntityIndex(graphVertex)); - } - return results; + public List getVertexIndex(GraphVertex graphVertex) { + List results = new ArrayList<>(); + for (IndexStore indexStore : STORE) { + results.addAll(indexStore.getEntityIndex(graphVertex)); } + return results; + } - public List getEdgeIndex(GraphEdge graphEdge) { - List results = new ArrayList<>(); - for (IndexStore indexStore : STORE) { - results.addAll(indexStore.getEntityIndex(graphEdge)); - } - return results; + public List getEdgeIndex(GraphEdge graphEdge) { + List results = new ArrayList<>(); + for (IndexStore indexStore : STORE) { + results.addAll(indexStore.getEntityIndex(graphEdge)); } + return results; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/EmbeddingVector.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/EmbeddingVector.java index 61cd80ee6..0856d1b68 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/EmbeddingVector.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/EmbeddingVector.java @@ -24,59 +24,55 @@ public class EmbeddingVector implements IVector { + private final double[] vec; - private final double[] vec; + public EmbeddingVector(double[] vec) { + this.vec = Objects.requireNonNull(vec); + } - public EmbeddingVector(double[] vec) { - this.vec = Objects.requireNonNull(vec); + @Override + public double match(IVector other) { + // Type check: must be same implementation class + if (!(other instanceof EmbeddingVector)) { + return 0.0; } + EmbeddingVector otherVec = (EmbeddingVector) other; - @Override - public double match(IVector other) { - // Type check: must be same implementation class - if (!(other instanceof EmbeddingVector)) { - return 0.0; - } - EmbeddingVector otherVec = (EmbeddingVector) other; - - // Dimension check: vectors must have same length - if (this.vec.length != otherVec.vec.length) { - return 0.0; - } - - double dotProduct = 0.0; // Accumulator for dot product - double normSquared1 = 0.0; // Accumulator for squared L2 norm of this vector - double normSquared2 = 0.0; // Accumulator for squared L2 norm of other vector + // Dimension check: vectors must have same length + if (this.vec.length != otherVec.vec.length) { + return 0.0; + } - // Single-pass computation for dot product and squared norms - for (int i = 0; i < this.vec.length; i++) { - dotProduct += this.vec[i] * otherVec.vec[i]; - normSquared1 += this.vec[i] * this.vec[i]; - normSquared2 += otherVec.vec[i] * otherVec.vec[i]; - } + double dotProduct = 0.0; // Accumulator for dot product + double normSquared1 = 0.0; // Accumulator for squared L2 norm of this vector + double normSquared2 = 0.0; // Accumulator for squared L2 norm of other vector - // Calculate denominator (product of L2 norms) - double denominator = Math.sqrt(normSquared1) * Math.sqrt(normSquared2); + // Single-pass computation for dot product and squared norms + for (int i = 0; i < this.vec.length; i++) { + dotProduct += this.vec[i] * otherVec.vec[i]; + normSquared1 += this.vec[i] * this.vec[i]; + normSquared2 += otherVec.vec[i] * otherVec.vec[i]; + } - // Handle zero-vector case (avoid division by zero) - if (denominator == 0.0) { - return 0.0; - } + // Calculate denominator (product of L2 norms) + double denominator = Math.sqrt(normSquared1) * Math.sqrt(normSquared2); - // Return cosine similarity: dot product divided by norms product - return dotProduct / denominator; + // Handle zero-vector case (avoid division by zero) + if (denominator == 0.0) { + return 0.0; } - @Override - public VectorType getType() { - return VectorType.EmbeddingVector; - } + // Return cosine similarity: dot product divided by norms product + return dotProduct / denominator; + } + @Override + public VectorType getType() { + return VectorType.EmbeddingVector; + } - @Override - public String toString() { - return "EmbeddingVector{" - + "vec=" + Arrays.toString(vec) - + '}'; - } + @Override + public String toString() { + return "EmbeddingVector{" + "vec=" + Arrays.toString(vec) + '}'; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/IVector.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/IVector.java index 14bfdb1c1..11cdb78f8 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/IVector.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/IVector.java @@ -21,9 +21,9 @@ public interface IVector { - double match(IVector other); + double match(IVector other); - VectorType getType(); + VectorType getType(); - String toString(); + String toString(); } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/KeywordVector.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/KeywordVector.java index da92341b4..4eebfb9bb 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/KeywordVector.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/KeywordVector.java @@ -25,46 +25,44 @@ public class KeywordVector implements IVector { - private final String[] vec; + private final String[] vec; - public KeywordVector(String... vec) { - this.vec = vec; - } + public KeywordVector(String... vec) { + this.vec = vec; + } - public String[] getVec() { - return vec; - } + public String[] getVec() { + return vec; + } - @Override - public double match(IVector other) { - if (!(other instanceof KeywordVector)) { - return 0.0; - } - KeywordVector otherKeyword = (KeywordVector) other; - String[] small = this.vec.length <= otherKeyword.vec.length ? this.vec : otherKeyword.vec; - String[] large = this.vec.length <= otherKeyword.vec.length ? otherKeyword.vec : this.vec; - Map keyword2TimesMap = new HashMap<>(); - for (String word : small) { - keyword2TimesMap.put(word, keyword2TimesMap.getOrDefault(word, 0L) + 1); - } - int count = 0; - for (String keyword : large) { - if (keyword2TimesMap.containsKey(keyword)) { - count += keyword2TimesMap.get(keyword); - } - } - return count; + @Override + public double match(IVector other) { + if (!(other instanceof KeywordVector)) { + return 0.0; } - - @Override - public VectorType getType() { - return VectorType.KeywordVector; + KeywordVector otherKeyword = (KeywordVector) other; + String[] small = this.vec.length <= otherKeyword.vec.length ? this.vec : otherKeyword.vec; + String[] large = this.vec.length <= otherKeyword.vec.length ? otherKeyword.vec : this.vec; + Map keyword2TimesMap = new HashMap<>(); + for (String word : small) { + keyword2TimesMap.put(word, keyword2TimesMap.getOrDefault(word, 0L) + 1); } - - @Override - public String toString() { - return "KeywordVector{" - + "vec=" + Arrays.toString(vec) - + '}'; + int count = 0; + for (String keyword : large) { + if (keyword2TimesMap.containsKey(keyword)) { + count += keyword2TimesMap.get(keyword); + } } + return count; + } + + @Override + public VectorType getType() { + return VectorType.KeywordVector; + } + + @Override + public String toString() { + return "KeywordVector{" + "vec=" + Arrays.toString(vec) + '}'; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/MagnitudeVector.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/MagnitudeVector.java index dfe6696dd..ffcfbdd26 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/MagnitudeVector.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/MagnitudeVector.java @@ -21,18 +21,18 @@ public class MagnitudeVector implements IVector { - @Override - public double match(IVector other) { - return 0; - } + @Override + public double match(IVector other) { + return 0; + } - @Override - public VectorType getType() { - return VectorType.MagnitudeVector; - } + @Override + public VectorType getType() { + return VectorType.MagnitudeVector; + } - @Override - public String toString() { - return "MagnitudeVector{}"; - } + @Override + public String toString() { + return "MagnitudeVector{}"; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/TraversalVector.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/TraversalVector.java index 86ef1ce4c..8874a0c5d 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/TraversalVector.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/TraversalVector.java @@ -21,37 +21,37 @@ public class TraversalVector implements IVector { - private final String[] vec; + private final String[] vec; - public TraversalVector(String... vec) { - if (vec.length % 3 != 0) { - throw new RuntimeException("Traversal vector should be src-edge-dst triple"); - } - this.vec = vec; + public TraversalVector(String... vec) { + if (vec.length % 3 != 0) { + throw new RuntimeException("Traversal vector should be src-edge-dst triple"); } + this.vec = vec; + } - @Override - public double match(IVector other) { - return 0; - } + @Override + public double match(IVector other) { + return 0; + } - @Override - public VectorType getType() { - return VectorType.TraversalVector; - } + @Override + public VectorType getType() { + return VectorType.TraversalVector; + } - @Override - public String toString() { - StringBuilder sb = new StringBuilder("TraversalVector{vec="); - for (int i = 0; i < vec.length; i++) { - if (i > 0) { - sb.append(i % 3 == 0 ? "; " : "-"); - } - sb.append(vec[i]); - if (i % 3 == 2) { - sb.append(">"); - } - } - return sb.append('}').toString(); + @Override + public String toString() { + StringBuilder sb = new StringBuilder("TraversalVector{vec="); + for (int i = 0; i < vec.length; i++) { + if (i > 0) { + sb.append(i % 3 == 0 ? "; " : "-"); + } + sb.append(vec[i]); + if (i % 3 == 2) { + sb.append(">"); + } } + return sb.append('}').toString(); + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/VectorType.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/VectorType.java index 1f076c8e7..fcf2c923c 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/VectorType.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/VectorType.java @@ -20,8 +20,8 @@ package org.apache.geaflow.ai.index.vector; public enum VectorType { - TraversalVector, - EmbeddingVector, - MagnitudeVector, - KeywordVector + TraversalVector, + EmbeddingVector, + MagnitudeVector, + KeywordVector } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/EmbeddingOperator.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/EmbeddingOperator.java index 849f53ee4..b610fa617 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/EmbeddingOperator.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/EmbeddingOperator.java @@ -21,6 +21,7 @@ import java.util.*; import java.util.stream.Collectors; + import org.apache.geaflow.ai.common.config.Constants; import org.apache.geaflow.ai.graph.GraphAccessor; import org.apache.geaflow.ai.graph.GraphEntity; @@ -34,161 +35,168 @@ public class EmbeddingOperator implements SearchOperator { - private final GraphAccessor graphAccessor; - private final IndexStore indexStore; - private double threshold; - private int topN; - - public EmbeddingOperator(GraphAccessor accessor, IndexStore store) { - this.graphAccessor = Objects.requireNonNull(accessor); - this.indexStore = Objects.requireNonNull(store); - this.threshold = Constants.EMBEDDING_OPERATE_DEFAULT_THRESHOLD; - this.topN = Constants.EMBEDDING_OPERATE_DEFAULT_TOPN; + private final GraphAccessor graphAccessor; + private final IndexStore indexStore; + private double threshold; + private int topN; + + public EmbeddingOperator(GraphAccessor accessor, IndexStore store) { + this.graphAccessor = Objects.requireNonNull(accessor); + this.indexStore = Objects.requireNonNull(store); + this.threshold = Constants.EMBEDDING_OPERATE_DEFAULT_THRESHOLD; + this.topN = Constants.EMBEDDING_OPERATE_DEFAULT_TOPN; + } + + @Override + public List apply(List subGraphList, VectorSearch search) { + List queryEmbeddingVectors = search.getVectorMap().get(VectorType.EmbeddingVector); + if (queryEmbeddingVectors == null || queryEmbeddingVectors.isEmpty()) { + if (subGraphList == null) { + return new ArrayList<>(); + } + return new ArrayList<>(subGraphList); } - - @Override - public List apply(List subGraphList, VectorSearch search) { - List queryEmbeddingVectors = search.getVectorMap().get(VectorType.EmbeddingVector); - if (queryEmbeddingVectors == null || queryEmbeddingVectors.isEmpty()) { - if (subGraphList == null) { - return new ArrayList<>(); - } - return new ArrayList<>(subGraphList); + List globalResults = searchWithGlobalGraph(queryEmbeddingVectors); + if (subGraphList == null || subGraphList.isEmpty()) { + List startVertices = new ArrayList<>(); + for (GraphEntity resEntity : globalResults) { + if (resEntity instanceof GraphVertex) { + startVertices.add((GraphVertex) resEntity); } - List globalResults = searchWithGlobalGraph(queryEmbeddingVectors); - if (subGraphList == null || subGraphList.isEmpty()) { - List startVertices = new ArrayList<>(); - for (GraphEntity resEntity : globalResults) { - if (resEntity instanceof GraphVertex) { - startVertices.add((GraphVertex) resEntity); - } - } - //Apply to subgraph - return startVertices.stream().map(v -> { + } + // Apply to subgraph + return startVertices.stream() + .map( + v -> { SubGraph subGraph = new SubGraph(); subGraph.addVertex(v); return subGraph; - }).collect(Collectors.toList()); - } else { - Map> extendEntityIndexMap = new HashMap<>(); - //Traverse all extension points of the subgraph and search within the extension area - for (SubGraph subGraph : subGraphList) { - List extendEntities = getSubgraphExpand(subGraph); - for (GraphEntity extendEntity : extendEntities) { - List entityIndex = indexStore.getEntityIndex(extendEntity); - extendEntityIndexMap.put(extendEntity, entityIndex); - } - } - //recall compute - List matchEntities = searchEmbeddings(queryEmbeddingVectors, extendEntityIndexMap); - Set matchEntitiesSet = new HashSet<>(matchEntities); - - //Apply to subgraph - List subGraphs = new ArrayList<>(subGraphList); - for (SubGraph subGraph : subGraphs) { - Set subgraphEntitySet = new HashSet<>(subGraph.getGraphEntityList()); - List extendEntities = getSubgraphExpand(subGraph); - for (GraphEntity extendEntity : extendEntities) { - if (matchEntitiesSet.contains(extendEntity) - && !subgraphEntitySet.contains(extendEntity)) { - subgraphEntitySet.add(extendEntity); - subGraph.addEntity(extendEntity); - } - } - } - return subGraphs; + }) + .collect(Collectors.toList()); + } else { + Map> extendEntityIndexMap = new HashMap<>(); + // Traverse all extension points of the subgraph and search within the extension area + for (SubGraph subGraph : subGraphList) { + List extendEntities = getSubgraphExpand(subGraph); + for (GraphEntity extendEntity : extendEntities) { + List entityIndex = indexStore.getEntityIndex(extendEntity); + extendEntityIndexMap.put(extendEntity, entityIndex); } - } - - private List getSubgraphExpand(SubGraph subGraph) { - List entityList = subGraph.getGraphEntityList(); - List expandEntities = new ArrayList<>(); - for (GraphEntity entity : entityList) { - List entityExpand = graphAccessor.expand(entity); - expandEntities.addAll(entityExpand); + } + // recall compute + List matchEntities = + searchEmbeddings(queryEmbeddingVectors, extendEntityIndexMap); + Set matchEntitiesSet = new HashSet<>(matchEntities); + + // Apply to subgraph + List subGraphs = new ArrayList<>(subGraphList); + for (SubGraph subGraph : subGraphs) { + Set subgraphEntitySet = new HashSet<>(subGraph.getGraphEntityList()); + List extendEntities = getSubgraphExpand(subGraph); + for (GraphEntity extendEntity : extendEntities) { + if (matchEntitiesSet.contains(extendEntity) + && !subgraphEntitySet.contains(extendEntity)) { + subgraphEntitySet.add(extendEntity); + subGraph.addEntity(extendEntity); + } } - return expandEntities; + } + return subGraphs; } - - private List searchWithGlobalGraph(List queryEmbeddingVectors) { - Map> entityIndexMap = new HashMap<>(); - Iterator vertexIterator = graphAccessor.scanVertex(); - while (vertexIterator.hasNext()) { - GraphVertex vertex = vertexIterator.next(); - //Read all vertices indices from the index and add them to the candidate set. - List vertexIndex = indexStore.getEntityIndex(vertex); - entityIndexMap.put(vertex, vertexIndex); - } - //recall compute - return searchEmbeddings(queryEmbeddingVectors, entityIndexMap); + } + + private List getSubgraphExpand(SubGraph subGraph) { + List entityList = subGraph.getGraphEntityList(); + List expandEntities = new ArrayList<>(); + for (GraphEntity entity : entityList) { + List entityExpand = graphAccessor.expand(entity); + expandEntities.addAll(entityExpand); } - - private List searchEmbeddings(List queryEmbeddingVectors, - Map> entityIndexMap) { - // Extract valid query EmbeddingVectors from input - List queryVectors = queryEmbeddingVectors.stream() - .filter(EmbeddingVector.class::isInstance) - .map(EmbeddingVector.class::cast) - .collect(Collectors.toList()); - - // Create min-heap to maintain top N entities by maximum relevance score - PriorityQueue minHeap = new PriorityQueue<>(Comparator.comparingDouble(a -> a.score)); - - // Process each entity in the index - for (Map.Entry> entry : entityIndexMap.entrySet()) { - GraphEntity entity = entry.getKey(); - // Extract valid entity embeddings - List entityVectors = entry.getValue().stream() - .filter(EmbeddingVector.class::isInstance) - .map(EmbeddingVector.class::cast) - .collect(Collectors.toList()); - - // Skip entities without valid embeddings - if (entityVectors.isEmpty()) { - continue; - } - - // Compute maximum relevance score between query and entity embeddings - double maxRelevance = 0.0; - for (EmbeddingVector queryVector : queryVectors) { - for (EmbeddingVector entityVector : entityVectors) { - double matchScore = queryVector.match(entityVector); - if (matchScore > maxRelevance) { - maxRelevance = matchScore; - } - } - } - - // Add to candidates if above threshold - if (maxRelevance > threshold) { - // Maintain heap size <= topN - if (minHeap.size() < topN) { - minHeap.offer(new GraphEntityScorePair(entity, maxRelevance)); - } else { - assert minHeap.peek() != null; - if (minHeap.peek().score < maxRelevance) { - minHeap.poll(); - minHeap.offer(new GraphEntityScorePair(entity, maxRelevance)); - } - } - } + return expandEntities; + } + + private List searchWithGlobalGraph(List queryEmbeddingVectors) { + Map> entityIndexMap = new HashMap<>(); + Iterator vertexIterator = graphAccessor.scanVertex(); + while (vertexIterator.hasNext()) { + GraphVertex vertex = vertexIterator.next(); + // Read all vertices indices from the index and add them to the candidate set. + List vertexIndex = indexStore.getEntityIndex(vertex); + entityIndexMap.put(vertex, vertexIndex); + } + // recall compute + return searchEmbeddings(queryEmbeddingVectors, entityIndexMap); + } + + private List searchEmbeddings( + List queryEmbeddingVectors, Map> entityIndexMap) { + // Extract valid query EmbeddingVectors from input + List queryVectors = + queryEmbeddingVectors.stream() + .filter(EmbeddingVector.class::isInstance) + .map(EmbeddingVector.class::cast) + .collect(Collectors.toList()); + + // Create min-heap to maintain top N entities by maximum relevance score + PriorityQueue minHeap = + new PriorityQueue<>(Comparator.comparingDouble(a -> a.score)); + + // Process each entity in the index + for (Map.Entry> entry : entityIndexMap.entrySet()) { + GraphEntity entity = entry.getKey(); + // Extract valid entity embeddings + List entityVectors = + entry.getValue().stream() + .filter(EmbeddingVector.class::isInstance) + .map(EmbeddingVector.class::cast) + .collect(Collectors.toList()); + + // Skip entities without valid embeddings + if (entityVectors.isEmpty()) { + continue; + } + + // Compute maximum relevance score between query and entity embeddings + double maxRelevance = 0.0; + for (EmbeddingVector queryVector : queryVectors) { + for (EmbeddingVector entityVector : entityVectors) { + double matchScore = queryVector.match(entityVector); + if (matchScore > maxRelevance) { + maxRelevance = matchScore; + } } + } - // Convert heap to sorted list (descending by score) - return minHeap.stream() - .sorted(Comparator.comparingDouble((GraphEntityScorePair p) -> p.score).reversed()) - .map(pair -> pair.entity) - .collect(Collectors.toList()); + // Add to candidates if above threshold + if (maxRelevance > threshold) { + // Maintain heap size <= topN + if (minHeap.size() < topN) { + minHeap.offer(new GraphEntityScorePair(entity, maxRelevance)); + } else { + assert minHeap.peek() != null; + if (minHeap.peek().score < maxRelevance) { + minHeap.poll(); + minHeap.offer(new GraphEntityScorePair(entity, maxRelevance)); + } + } + } } - // Helper class to store entity-score pairs - private static class GraphEntityScorePair { - final GraphEntity entity; - final double score; - - GraphEntityScorePair(GraphEntity entity, double score) { - this.entity = entity; - this.score = score; - } + // Convert heap to sorted list (descending by score) + return minHeap.stream() + .sorted(Comparator.comparingDouble((GraphEntityScorePair p) -> p.score).reversed()) + .map(pair -> pair.entity) + .collect(Collectors.toList()); + } + + // Helper class to store entity-score pairs + private static class GraphEntityScorePair { + final GraphEntity entity; + final double score; + + GraphEntityScorePair(GraphEntity entity, double score) { + this.entity = entity; + this.score = score; } + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/GraphSearchStore.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/GraphSearchStore.java index 20beb48ad..a84eac2a2 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/GraphSearchStore.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/GraphSearchStore.java @@ -21,6 +21,7 @@ import java.util.*; import java.util.stream.Collectors; + import org.apache.geaflow.ai.graph.GraphAccessor; import org.apache.geaflow.ai.graph.GraphEdge; import org.apache.geaflow.ai.graph.GraphEntity; @@ -40,115 +41,117 @@ public class GraphSearchStore { - private SearchStore store; - private long entityNum = 0L; - - public GraphSearchStore() { - this.store = new SearchStore(); + private SearchStore store; + private long entityNum = 0L; + + public GraphSearchStore() { + this.store = new SearchStore(); + } + + public boolean indexVertex(GraphVertex graphVertex, List indexVectors) { + Map kv = new HashMap<>(); + Vertex vertex = graphVertex.getVertex(); + kv.put(SearchConstants.ID, vertex.getId()); + kv.put(SearchConstants.LABEL, vertex.getLabel()); + List contents = new ArrayList<>(indexVectors.size()); + for (IVector v : indexVectors) { + contents.add(v.toString()); } + String content = String.join(SearchConstants.DELIMITER, contents); + kv.put(SearchConstants.CONTENT, content); - public boolean indexVertex(GraphVertex graphVertex, List indexVectors) { - Map kv = new HashMap<>(); - Vertex vertex = graphVertex.getVertex(); - kv.put(SearchConstants.ID, vertex.getId()); - kv.put(SearchConstants.LABEL, vertex.getLabel()); - List contents = new ArrayList<>(indexVectors.size()); - for (IVector v : indexVectors) { - contents.add(v.toString()); - } - String content = String.join(SearchConstants.DELIMITER, contents); - kv.put(SearchConstants.CONTENT, content); - - try { - store.addDoc(kv); - } catch (Throwable e) { - throw new RuntimeException("Cannot index vertex to search store", e); - } - addItem(); - return true; + try { + store.addDoc(kv); + } catch (Throwable e) { + throw new RuntimeException("Cannot index vertex to search store", e); } - - public boolean indexEdge(GraphEdge graphEdge, List indexVectors) { - Map kv = new HashMap<>(); - Edge edge = graphEdge.getEdge(); - kv.put(SearchConstants.SRC, edge.getSrcId()); - kv.put(SearchConstants.DST, edge.getDstId()); - kv.put(SearchConstants.LABEL, edge.getLabel()); - List contents = new ArrayList<>(indexVectors.size()); - for (IVector v : indexVectors) { - contents.add(v.toString()); - } - String content = String.join(SearchConstants.DELIMITER, contents); - kv.put(SearchConstants.CONTENT, content); - try { - store.addDoc(kv); - } catch (Throwable e) { - throw new RuntimeException("Cannot index vertex to search store", e); - } - addItem(); - return true; + addItem(); + return true; + } + + public boolean indexEdge(GraphEdge graphEdge, List indexVectors) { + Map kv = new HashMap<>(); + Edge edge = graphEdge.getEdge(); + kv.put(SearchConstants.SRC, edge.getSrcId()); + kv.put(SearchConstants.DST, edge.getDstId()); + kv.put(SearchConstants.LABEL, edge.getLabel()); + List contents = new ArrayList<>(indexVectors.size()); + for (IVector v : indexVectors) { + contents.add(v.toString()); } - - public List search(String key1, GraphAccessor graphAccessor) { - try { - String query = SearchUtils.formatQuery(key1); - TopDocs docs = store.searchDoc(SearchConstants.CONTENT, query); - ScoreDoc[] scoreDocArray = docs.scoreDocs; - Set vertexLabels = graphAccessor.getGraphSchema().getVertexSchemaList() - .stream().map(VertexSchema::getLabel).collect(Collectors.toSet()); - Set edgeLabels = graphAccessor.getGraphSchema().getEdgeSchemaList() - .stream().map(EdgeSchema::getLabel).collect(Collectors.toSet()); - List result = new ArrayList<>(); - for (ScoreDoc scoreDoc : scoreDocArray) { - int docId = scoreDoc.doc; - Document document = store.getDoc(docId); - String label = document.get(SearchConstants.LABEL); - if (vertexLabels.contains(label)) { - String id = document.get(SearchConstants.ID); - GraphVertex graphVertex = graphAccessor.getVertex(label, id); - if (graphVertex != null) { - result.add(graphVertex); - } - } else if (edgeLabels.contains(label)) { - String src = document.get(SearchConstants.SRC); - String dst = document.get(SearchConstants.DST); - List graphEdge = graphAccessor.getEdge(label, src, dst); - if (graphEdge != null) { - result.addAll(graphEdge); - } - } - } - return result; - } catch (IndexNotFoundException notFoundException) { - return new ArrayList<>(); - } catch (Throwable e) { - throw new RuntimeException("Cannot read search store", e); - } - } - - private void addItem() { - entityNum++; + String content = String.join(SearchConstants.DELIMITER, contents); + kv.put(SearchConstants.CONTENT, content); + try { + store.addDoc(kv); + } catch (Throwable e) { + throw new RuntimeException("Cannot index vertex to search store", e); } - - public void close() { - try { - store.close(); - } catch (Throwable e) { - throw new RuntimeException("Cannot close search store", e); + addItem(); + return true; + } + + public List search(String key1, GraphAccessor graphAccessor) { + try { + String query = SearchUtils.formatQuery(key1); + TopDocs docs = store.searchDoc(SearchConstants.CONTENT, query); + ScoreDoc[] scoreDocArray = docs.scoreDocs; + Set vertexLabels = + graphAccessor.getGraphSchema().getVertexSchemaList().stream() + .map(VertexSchema::getLabel) + .collect(Collectors.toSet()); + Set edgeLabels = + graphAccessor.getGraphSchema().getEdgeSchemaList().stream() + .map(EdgeSchema::getLabel) + .collect(Collectors.toSet()); + List result = new ArrayList<>(); + for (ScoreDoc scoreDoc : scoreDocArray) { + int docId = scoreDoc.doc; + Document document = store.getDoc(docId); + String label = document.get(SearchConstants.LABEL); + if (vertexLabels.contains(label)) { + String id = document.get(SearchConstants.ID); + GraphVertex graphVertex = graphAccessor.getVertex(label, id); + if (graphVertex != null) { + result.add(graphVertex); + } + } else if (edgeLabels.contains(label)) { + String src = document.get(SearchConstants.SRC); + String dst = document.get(SearchConstants.DST); + List graphEdge = graphAccessor.getEdge(label, src, dst); + if (graphEdge != null) { + result.addAll(graphEdge); + } } + } + return result; + } catch (IndexNotFoundException notFoundException) { + return new ArrayList<>(); + } catch (Throwable e) { + throw new RuntimeException("Cannot read search store", e); } + } - public Directory getDirectory() { - return store.getDirectory(); - } + private void addItem() { + entityNum++; + } - public Analyzer getAnalyzer() { - return store.getAnalyzer(); + public void close() { + try { + store.close(); + } catch (Throwable e) { + throw new RuntimeException("Cannot close search store", e); } + } - public IndexWriterConfig getConfig() { - return store.getConfig(); - } + public Directory getDirectory() { + return store.getDirectory(); + } + public Analyzer getAnalyzer() { + return store.getAnalyzer(); + } + public IndexWriterConfig getConfig() { + return store.getConfig(); + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/SearchConstants.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/SearchConstants.java index 6a5ae0770..fafd7afee 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/SearchConstants.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/SearchConstants.java @@ -21,12 +21,11 @@ public class SearchConstants { - public static String LABEL = "label"; - public static String ID = "id"; - public static String SRC = "src"; - public static String DST = "dst"; - public static String CONTENT = "content"; - public static String OPERATOR = "operator"; - public static String DELIMITER = " "; - + public static String LABEL = "label"; + public static String ID = "id"; + public static String SRC = "src"; + public static String DST = "dst"; + public static String CONTENT = "content"; + public static String OPERATOR = "operator"; + public static String DELIMITER = " "; } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/SearchOperator.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/SearchOperator.java index 068bfd744..3b3b3f13d 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/SearchOperator.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/SearchOperator.java @@ -20,10 +20,11 @@ package org.apache.geaflow.ai.operator; import java.util.List; + import org.apache.geaflow.ai.search.VectorSearch; import org.apache.geaflow.ai.subgraph.SubGraph; public interface SearchOperator { - List apply(List subGraphList, VectorSearch search); + List apply(List subGraphList, VectorSearch search); } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/SearchStore.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/SearchStore.java index 81183ae47..2a4d7e89e 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/SearchStore.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/SearchStore.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Map; + import org.apache.geaflow.ai.common.config.Constants; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.standard.StandardAnalyzer; @@ -40,79 +41,76 @@ public class SearchStore { - private final Directory directory = new ByteBuffersDirectory(); - private final Analyzer analyzer = new StandardAnalyzer(); - private final IndexWriterConfig config = new IndexWriterConfig(analyzer); - private IndexWriter writer; - private boolean writeStats = false; - private IndexReader reader; - private IndexSearcher searcher; - private boolean readStats = false; + private final Directory directory = new ByteBuffersDirectory(); + private final Analyzer analyzer = new StandardAnalyzer(); + private final IndexWriterConfig config = new IndexWriterConfig(analyzer); + private IndexWriter writer; + private boolean writeStats = false; + private IndexReader reader; + private IndexSearcher searcher; + private boolean readStats = false; - public SearchStore() { - } + public SearchStore() {} - public void addDoc(Map kv) throws IOException { - initWriter(); - Document doc = new Document(); - for (Map.Entry entry : kv.entrySet()) { - doc.add(new TextField(entry.getKey(), entry.getValue(), Field.Store.YES)); - } - writer.addDocument(doc); + public void addDoc(Map kv) throws IOException { + initWriter(); + Document doc = new Document(); + for (Map.Entry entry : kv.entrySet()) { + doc.add(new TextField(entry.getKey(), entry.getValue(), Field.Store.YES)); } + writer.addDocument(doc); + } - public TopDocs searchDoc(String field, String content) throws ParseException, IOException { - if (!readStats) { - reader = DirectoryReader.open(directory); - searcher = new IndexSearcher(reader); - readStats = true; - } - QueryParser parser = new QueryParser(field, analyzer); - return searcher.search(parser.parse(content), Constants.GRAPH_SEARCH_STORE_DEFAULT_TOPN); + public TopDocs searchDoc(String field, String content) throws ParseException, IOException { + if (!readStats) { + reader = DirectoryReader.open(directory); + searcher = new IndexSearcher(reader); + readStats = true; } + QueryParser parser = new QueryParser(field, analyzer); + return searcher.search(parser.parse(content), Constants.GRAPH_SEARCH_STORE_DEFAULT_TOPN); + } - public Document getDoc(int docId) { - try { - if (!readStats) { - reader = DirectoryReader.open(directory); - searcher = new IndexSearcher(reader); - readStats = true; - } - return searcher.doc(docId); - } catch (Throwable e) { - return null; - } - + public Document getDoc(int docId) { + try { + if (!readStats) { + reader = DirectoryReader.open(directory); + searcher = new IndexSearcher(reader); + readStats = true; + } + return searcher.doc(docId); + } catch (Throwable e) { + return null; } + } - public void initWriter() throws IOException { - if (!writeStats) { - writer = new IndexWriter(directory, config); - writeStats = true; - } + public void initWriter() throws IOException { + if (!writeStats) { + writer = new IndexWriter(directory, config); + writeStats = true; } + } - public void close() throws IOException { - if (writeStats) { - writer.close(); - writeStats = false; - } - if (readStats) { - reader.close(); - readStats = false; - } + public void close() throws IOException { + if (writeStats) { + writer.close(); + writeStats = false; } - - - public Directory getDirectory() { - return directory; + if (readStats) { + reader.close(); + readStats = false; } + } - public Analyzer getAnalyzer() { - return analyzer; - } + public Directory getDirectory() { + return directory; + } - public IndexWriterConfig getConfig() { - return config; - } + public Analyzer getAnalyzer() { + return analyzer; + } + + public IndexWriterConfig getConfig() { + return config; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/SearchUtils.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/SearchUtils.java index c60d21af3..636117f4b 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/SearchUtils.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/SearchUtils.java @@ -26,83 +26,83 @@ public class SearchUtils { - // Set of excluded characters: these will be replaced with spaces in formatQuery - private static final Set EXCLUDED_CHARS = new HashSet<>(Arrays.asList( - '*', '#', '-', '?', '`', '{', '}', '[', ']', '(', ')', '>', '<', ':', '/', '.' - )); + // Set of excluded characters: these will be replaced with spaces in formatQuery + private static final Set EXCLUDED_CHARS = + new HashSet<>( + Arrays.asList( + '*', '#', '-', '?', '`', '{', '}', '[', ']', '(', ')', '>', '<', ':', '/', '.')); - // Set of allowed characters for validation in isAllAllowedChars - // Includes: digits (0-9), and some common safe symbols - private static final Set IGNORE_CHARS = buildIgnoredChars(); + // Set of allowed characters for validation in isAllAllowedChars + // Includes: digits (0-9), and some common safe symbols + private static final Set IGNORE_CHARS = buildIgnoredChars(); - /** - * Builds the set of allowed characters for input validation. - * Includes alphanumeric characters and selected common symbols. - * - * @return an unmodifiable set of ignored characters - */ - private static Set buildIgnoredChars() { - Set ignored = new HashSet<>(32); - // Add digits - for (char c = '0'; c <= '9'; c++) { - ignored.add(c); - } - // Add commonly allowed symbols - ignored.add('.'); - ignored.add('_'); - ignored.add('-'); - ignored.add('@'); - ignored.add('+'); - ignored.add('!'); - ignored.add('$'); - ignored.add('%'); - ignored.add('&'); - ignored.add('='); - ignored.add('~'); - return Collections.unmodifiableSet(ignored); + /** + * Builds the set of allowed characters for input validation. Includes alphanumeric characters and + * selected common symbols. + * + * @return an unmodifiable set of ignored characters + */ + private static Set buildIgnoredChars() { + Set ignored = new HashSet<>(32); + // Add digits + for (char c = '0'; c <= '9'; c++) { + ignored.add(c); } + // Add commonly allowed symbols + ignored.add('.'); + ignored.add('_'); + ignored.add('-'); + ignored.add('@'); + ignored.add('+'); + ignored.add('!'); + ignored.add('$'); + ignored.add('%'); + ignored.add('&'); + ignored.add('='); + ignored.add('~'); + return Collections.unmodifiableSet(ignored); + } - /** - * Formats the input query string by replacing each excluded character with a space. - * This helps sanitize search queries for parsing or indexing. - * - * @param query the input string to format - * @return the formatted string with excluded characters replaced by spaces - */ - public static String formatQuery(String query) { - if (query == null || query.isEmpty()) { - return query; - } - StringBuilder result = new StringBuilder(); - for (char c : query.toCharArray()) { - if (EXCLUDED_CHARS.contains(c)) { - result.append(' '); - } else { - result.append(c); - } - } - String replacedQuery = result.toString(); - replacedQuery = replacedQuery.replace("http", ""); - return replacedQuery; + /** + * Formats the input query string by replacing each excluded character with a space. This helps + * sanitize search queries for parsing or indexing. + * + * @param query the input string to format + * @return the formatted string with excluded characters replaced by spaces + */ + public static String formatQuery(String query) { + if (query == null || query.isEmpty()) { + return query; } - - /** - * Checks whether all characters in the given string are within the allowed character set. - * Useful for validating usernames, identifiers, or safe input formats. - * - * @param str the string to validate - * @return true if all characters are allowed; false otherwise - */ - public static boolean isAllAllowedChars(String str) { - if (str == null || str.isEmpty()) { - return false; // Consider empty/null invalid; adjust based on use case - } - for (char c : str.toCharArray()) { - if (IGNORE_CHARS.contains(c)) { - return false; - } - } - return true; + StringBuilder result = new StringBuilder(); + for (char c : query.toCharArray()) { + if (EXCLUDED_CHARS.contains(c)) { + result.append(' '); + } else { + result.append(c); + } } + String replacedQuery = result.toString(); + replacedQuery = replacedQuery.replace("http", ""); + return replacedQuery; + } + /** + * Checks whether all characters in the given string are within the allowed character set. Useful + * for validating usernames, identifiers, or safe input formats. + * + * @param str the string to validate + * @return true if all characters are allowed; false otherwise + */ + public static boolean isAllAllowedChars(String str) { + if (str == null || str.isEmpty()) { + return false; // Consider empty/null invalid; adjust based on use case + } + for (char c : str.toCharArray()) { + if (IGNORE_CHARS.contains(c)) { + return false; + } + } + return true; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/SessionOperator.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/SessionOperator.java index 4800c9341..e8a006c0a 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/SessionOperator.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/operator/SessionOperator.java @@ -21,6 +21,7 @@ import java.util.*; import java.util.stream.Collectors; + import org.apache.geaflow.ai.graph.GraphAccessor; import org.apache.geaflow.ai.graph.GraphEdge; import org.apache.geaflow.ai.graph.GraphEntity; @@ -33,111 +34,114 @@ public class SessionOperator implements SearchOperator { - private final GraphAccessor graphAccessor; - private final IndexStore indexStore; + private final GraphAccessor graphAccessor; + private final IndexStore indexStore; - public SessionOperator(GraphAccessor accessor, IndexStore store) { - this.graphAccessor = Objects.requireNonNull(accessor); - this.indexStore = Objects.requireNonNull(store); - } + public SessionOperator(GraphAccessor accessor, IndexStore store) { + this.graphAccessor = Objects.requireNonNull(accessor); + this.indexStore = Objects.requireNonNull(store); + } - @Override - public List apply(List subGraphList, VectorSearch search) { - List keyWordVectors = search.getVectorMap().get(VectorType.KeywordVector); - if (keyWordVectors == null || keyWordVectors.isEmpty()) { - if (subGraphList == null) { - return new ArrayList<>(); - } - return new ArrayList<>(subGraphList); - } - List contents = new ArrayList<>(keyWordVectors.size()); - for (IVector v : keyWordVectors) { - contents.add(v.toString()); + @Override + public List apply(List subGraphList, VectorSearch search) { + List keyWordVectors = search.getVectorMap().get(VectorType.KeywordVector); + if (keyWordVectors == null || keyWordVectors.isEmpty()) { + if (subGraphList == null) { + return new ArrayList<>(); + } + return new ArrayList<>(subGraphList); + } + List contents = new ArrayList<>(keyWordVectors.size()); + for (IVector v : keyWordVectors) { + contents.add(v.toString()); + } + String query = String.join(SearchConstants.DELIMITER, contents); + List globalResults = searchWithGlobalGraph(query); + if (subGraphList == null || subGraphList.isEmpty()) { + List startVertices = new ArrayList<>(); + for (GraphEntity resEntity : globalResults) { + if (resEntity instanceof GraphVertex) { + startVertices.add((GraphVertex) resEntity); } - String query = String.join(SearchConstants.DELIMITER, contents); - List globalResults = searchWithGlobalGraph(query); - if (subGraphList == null || subGraphList.isEmpty()) { - List startVertices = new ArrayList<>(); - for (GraphEntity resEntity : globalResults) { - if (resEntity instanceof GraphVertex) { - startVertices.add((GraphVertex) resEntity); - } - } - //Apply to subgraph - return startVertices.stream().map(v -> { + } + // Apply to subgraph + return startVertices.stream() + .map( + v -> { SubGraph subGraph = new SubGraph(); subGraph.addVertex(v); return subGraph; - }).collect(Collectors.toList()); - } else { - Map> extendEntityIndexMap = new HashMap<>(); - //Traverse all extension points of the subgraph and search within the extension area - for (SubGraph subGraph : subGraphList) { - List extendEntities = getSubgraphExpand(subGraph); - for (GraphEntity extendEntity : extendEntities) { - List entityIndex = indexStore.getEntityIndex(extendEntity); - extendEntityIndexMap.put(extendEntity, entityIndex); - } - } - //recall compute - GraphSearchStore searchStore = initSearchStore(extendEntityIndexMap); - searchStore.close(); - List matchEntities = searchStore.search(query, graphAccessor); - Set matchEntitiesSet = new HashSet<>(matchEntities); + }) + .collect(Collectors.toList()); + } else { + Map> extendEntityIndexMap = new HashMap<>(); + // Traverse all extension points of the subgraph and search within the extension area + for (SubGraph subGraph : subGraphList) { + List extendEntities = getSubgraphExpand(subGraph); + for (GraphEntity extendEntity : extendEntities) { + List entityIndex = indexStore.getEntityIndex(extendEntity); + extendEntityIndexMap.put(extendEntity, entityIndex); + } + } + // recall compute + GraphSearchStore searchStore = initSearchStore(extendEntityIndexMap); + searchStore.close(); + List matchEntities = searchStore.search(query, graphAccessor); + Set matchEntitiesSet = new HashSet<>(matchEntities); - //Apply to subgraph - List subGraphs = new ArrayList<>(subGraphList); - for (SubGraph subGraph : subGraphs) { - Set subgraphEntitySet = new HashSet<>(subGraph.getGraphEntityList()); - List extendEntities = getSubgraphExpand(subGraph); - for (GraphEntity extendEntity : extendEntities) { - if (matchEntitiesSet.contains(extendEntity) - && !subgraphEntitySet.contains(extendEntity)) { - subgraphEntitySet.add(extendEntity); - subGraph.addEntity(extendEntity); - } - } - } - return subGraphs; + // Apply to subgraph + List subGraphs = new ArrayList<>(subGraphList); + for (SubGraph subGraph : subGraphs) { + Set subgraphEntitySet = new HashSet<>(subGraph.getGraphEntityList()); + List extendEntities = getSubgraphExpand(subGraph); + for (GraphEntity extendEntity : extendEntities) { + if (matchEntitiesSet.contains(extendEntity) + && !subgraphEntitySet.contains(extendEntity)) { + subgraphEntitySet.add(extendEntity); + subGraph.addEntity(extendEntity); + } } + } + return subGraphs; } + } - private List getSubgraphExpand(SubGraph subGraph) { - List entityList = subGraph.getGraphEntityList(); - List expandEntities = new ArrayList<>(); - for (GraphEntity entity : entityList) { - List entityExpand = graphAccessor.expand(entity); - expandEntities.addAll(entityExpand); - } - return expandEntities; + private List getSubgraphExpand(SubGraph subGraph) { + List entityList = subGraph.getGraphEntityList(); + List expandEntities = new ArrayList<>(); + for (GraphEntity entity : entityList) { + List entityExpand = graphAccessor.expand(entity); + expandEntities.addAll(entityExpand); } + return expandEntities; + } - private List searchWithGlobalGraph(String query) { - Map> entityIndexMap = new HashMap<>(); - Iterator vertexIterator = graphAccessor.scanVertex(); - while (vertexIterator.hasNext()) { - GraphVertex vertex = vertexIterator.next(); - //Read all vertices indices from the index and add them to the candidate set. - List vertexIndex = indexStore.getEntityIndex(vertex); - entityIndexMap.put(vertex, vertexIndex); - } - //recall compute - GraphSearchStore searchStore = initSearchStore(entityIndexMap); - searchStore.close(); - return searchStore.search(query, graphAccessor); + private List searchWithGlobalGraph(String query) { + Map> entityIndexMap = new HashMap<>(); + Iterator vertexIterator = graphAccessor.scanVertex(); + while (vertexIterator.hasNext()) { + GraphVertex vertex = vertexIterator.next(); + // Read all vertices indices from the index and add them to the candidate set. + List vertexIndex = indexStore.getEntityIndex(vertex); + entityIndexMap.put(vertex, vertexIndex); } + // recall compute + GraphSearchStore searchStore = initSearchStore(entityIndexMap); + searchStore.close(); + return searchStore.search(query, graphAccessor); + } - private GraphSearchStore initSearchStore(Map> entityIndexMap) { - GraphSearchStore searchStore = new GraphSearchStore(); - for (Map.Entry> entry : entityIndexMap.entrySet()) { - if (entry.getValue() != null && !entry.getValue().isEmpty()) { - if (entry.getKey() instanceof GraphVertex) { - searchStore.indexVertex((GraphVertex) entry.getKey(), entry.getValue()); - } else if (entry.getKey() instanceof GraphEdge) { - searchStore.indexEdge((GraphEdge) entry.getKey(), entry.getValue()); - } - } + private GraphSearchStore initSearchStore(Map> entityIndexMap) { + GraphSearchStore searchStore = new GraphSearchStore(); + for (Map.Entry> entry : entityIndexMap.entrySet()) { + if (entry.getValue() != null && !entry.getValue().isEmpty()) { + if (entry.getKey() instanceof GraphVertex) { + searchStore.indexVertex((GraphVertex) entry.getKey(), entry.getValue()); + } else if (entry.getKey() instanceof GraphEdge) { + searchStore.indexEdge((GraphEdge) entry.getKey(), entry.getValue()); } - return searchStore; + } } + return searchStore; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/search/VectorSearch.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/search/VectorSearch.java index 98db3a34d..28f48b3b8 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/search/VectorSearch.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/search/VectorSearch.java @@ -20,56 +20,61 @@ package org.apache.geaflow.ai.search; import java.util.*; + import org.apache.geaflow.ai.index.vector.IVector; import org.apache.geaflow.ai.index.vector.VectorType; public class VectorSearch { - public final String memoryId; + public final String memoryId; - public final String sessionId; + public final String sessionId; - public final Map> vectorList = new LinkedHashMap<>(); + public final Map> vectorList = new LinkedHashMap<>(); - public VectorSearch(String memoryId, String sessionId) { - this.memoryId = memoryId; - this.sessionId = sessionId; - } + public VectorSearch(String memoryId, String sessionId) { + this.memoryId = memoryId; + this.sessionId = sessionId; + } - public void addVector(IVector vector) { - addVector(Collections.singletonList(vector)); - } + public void addVector(IVector vector) { + addVector(Collections.singletonList(vector)); + } - public void addVector(List vectors) { - if (vectors == null) { - return; - } - for (IVector v : vectors) { - if (v != null) { - vectorList.computeIfAbsent(v.getType(), - k -> new ArrayList<>()).add(v); - } - } + public void addVector(List vectors) { + if (vectors == null) { + return; } - - @Override - public String toString() { - return "VectorSearch{" - + "memoryId='" + memoryId + '\'' - + ", sessionId='" + sessionId + '\'' - + ", vectorList=" + vectorList - + '}'; + for (IVector v : vectors) { + if (v != null) { + vectorList.computeIfAbsent(v.getType(), k -> new ArrayList<>()).add(v); + } } + } - public String getMemoryId() { - return memoryId; - } + @Override + public String toString() { + return "VectorSearch{" + + "memoryId='" + + memoryId + + '\'' + + ", sessionId='" + + sessionId + + '\'' + + ", vectorList=" + + vectorList + + '}'; + } - public String getSessionId() { - return sessionId; - } + public String getMemoryId() { + return memoryId; + } - public Map> getVectorMap() { - return vectorList; - } + public String getSessionId() { + return sessionId; + } + + public Map> getVectorMap() { + return vectorList; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/service/ServerMemoryCache.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/service/ServerMemoryCache.java index 8e894d180..7e65d9ea8 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/service/ServerMemoryCache.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/service/ServerMemoryCache.java @@ -22,44 +22,43 @@ import java.util.HashMap; import java.util.LinkedHashMap; import java.util.Map; + import org.apache.geaflow.ai.GraphMemoryServer; import org.apache.geaflow.ai.consolidate.ConsolidateServer; import org.apache.geaflow.ai.graph.Graph; public class ServerMemoryCache { - private final Map name2Graph = new LinkedHashMap<>(); - private final Map name2Server = new LinkedHashMap<>(); - private final Map session2GraphName = new HashMap<>(); - private final ConsolidateServer consolidateServer = new ConsolidateServer(); + private final Map name2Graph = new LinkedHashMap<>(); + private final Map name2Server = new LinkedHashMap<>(); + private final Map session2GraphName = new HashMap<>(); + private final ConsolidateServer consolidateServer = new ConsolidateServer(); - public void putGraph(Graph g) { - name2Graph.put(g.getGraphSchema().getName(), g); - } + public void putGraph(Graph g) { + name2Graph.put(g.getGraphSchema().getName(), g); + } - public void putServer(GraphMemoryServer server) { - name2Server.put(server.getGraphAccessors().get(0) - .getGraphSchema().getName(), server); - } + public void putServer(GraphMemoryServer server) { + name2Server.put(server.getGraphAccessors().get(0).getGraphSchema().getName(), server); + } - public void putSession(GraphMemoryServer server, String sessionId) { - session2GraphName.put(sessionId, - server.getGraphAccessors().get(0).getGraphSchema().getName()); - } + public void putSession(GraphMemoryServer server, String sessionId) { + session2GraphName.put(sessionId, server.getGraphAccessors().get(0).getGraphSchema().getName()); + } - public Graph getGraphByName(String name) { - return name2Graph.get(name); - } + public Graph getGraphByName(String name) { + return name2Graph.get(name); + } - public GraphMemoryServer getServerByName(String name) { - return name2Server.get(name); - } + public GraphMemoryServer getServerByName(String name) { + return name2Server.get(name); + } - public String getGraphNameBySession(String sessionId) { - return session2GraphName.get(sessionId); - } + public String getGraphNameBySession(String sessionId) { + return session2GraphName.get(sessionId); + } - public ConsolidateServer getConsolidateServer() { - return consolidateServer; - } + public ConsolidateServer getConsolidateServer() { + return consolidateServer; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/session/SessionManagement.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/session/SessionManagement.java index 3e991e756..cc379a71e 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/session/SessionManagement.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/session/SessionManagement.java @@ -20,46 +20,48 @@ package org.apache.geaflow.ai.session; import java.util.*; + import org.apache.geaflow.ai.common.config.Constants; import org.apache.geaflow.ai.subgraph.SubGraph; public class SessionManagement { - private final Map session2ActiveTime = new HashMap<>(); - private final Map> session2Graphs = new HashMap<>(); + private final Map session2ActiveTime = new HashMap<>(); + private final Map> session2Graphs = new HashMap<>(); - public SessionManagement() { - } + public SessionManagement() {} - public boolean createSession(String sessionId) { - if (session2ActiveTime.containsKey(sessionId)) { - return false; - } - session2ActiveTime.put(sessionId, System.nanoTime()); - session2Graphs.putIfAbsent(sessionId, new ArrayList<>()); - return true; + public boolean createSession(String sessionId) { + if (session2ActiveTime.containsKey(sessionId)) { + return false; } + session2ActiveTime.put(sessionId, System.nanoTime()); + session2Graphs.putIfAbsent(sessionId, new ArrayList<>()); + return true; + } - public String createSession() { - String sessionId = Constants.PREFIX_TMP_SESSION + System.nanoTime() - + UUID.randomUUID().toString().replace("-", "").substring(0, 8); - if (createSession(sessionId)) { - return sessionId; - } else { - return null; - } + public String createSession() { + String sessionId = + Constants.PREFIX_TMP_SESSION + + System.nanoTime() + + UUID.randomUUID().toString().replace("-", "").substring(0, 8); + if (createSession(sessionId)) { + return sessionId; + } else { + return null; } + } - public boolean sessionExists(String session) { - return this.session2ActiveTime.containsKey(session); - } + public boolean sessionExists(String session) { + return this.session2ActiveTime.containsKey(session); + } - public List getSubGraph(String sessionId) { - List l = this.session2Graphs.get(sessionId); - return l == null ? new ArrayList<>() : l; - } + public List getSubGraph(String sessionId) { + List l = this.session2Graphs.get(sessionId); + return l == null ? new ArrayList<>() : l; + } - public void setSubGraph(String sessionId, List subGraphs) { - this.session2Graphs.put(sessionId, subGraphs); - } + public void setSubGraph(String sessionId, List subGraphs) { + this.session2Graphs.put(sessionId, subGraphs); + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/subgraph/SubGraph.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/subgraph/SubGraph.java index 6648a6aca..ec1c5376b 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/subgraph/SubGraph.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/subgraph/SubGraph.java @@ -21,37 +21,36 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.ai.graph.GraphEdge; import org.apache.geaflow.ai.graph.GraphEntity; import org.apache.geaflow.ai.graph.GraphVertex; public class SubGraph { - private final List graphEntityList = new ArrayList<>(); - - public SubGraph addVertex(GraphVertex vertex) { - graphEntityList.add(vertex); - return this; - } - - public SubGraph addEdge(GraphEdge edge) { - graphEntityList.add(edge); - return this; - } - - public SubGraph addEntity(GraphEntity e) { - graphEntityList.add(e); - return this; - } - - public List getGraphEntityList() { - return graphEntityList; - } - - @Override - public String toString() { - return "SubGraph{" - + "graphEntityList=" + graphEntityList - + '}'; - } + private final List graphEntityList = new ArrayList<>(); + + public SubGraph addVertex(GraphVertex vertex) { + graphEntityList.add(vertex); + return this; + } + + public SubGraph addEdge(GraphEdge edge) { + graphEntityList.add(edge); + return this; + } + + public SubGraph addEntity(GraphEntity e) { + graphEntityList.add(e); + return this; + } + + public List getGraphEntityList() { + return graphEntityList; + } + + @Override + public String toString() { + return "SubGraph{" + "graphEntityList=" + graphEntityList + '}'; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/verbalization/Context.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/verbalization/Context.java index 62a449831..5fb3f7ec4 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/verbalization/Context.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/verbalization/Context.java @@ -21,16 +21,14 @@ public class Context { - private final String prompt; + private final String prompt; - public Context(String prompt) { - this.prompt = prompt; - } + public Context(String prompt) { + this.prompt = prompt; + } - @Override - public String toString() { - return "Context{" - + "prompt='" + prompt + '\'' - + '}'; - } + @Override + public String toString() { + return "Context{" + "prompt='" + prompt + '\'' + '}'; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/verbalization/PromptFormatter.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/verbalization/PromptFormatter.java index 85f824dea..944beded1 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/verbalization/PromptFormatter.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/verbalization/PromptFormatter.java @@ -25,9 +25,9 @@ public interface PromptFormatter { - String prompt(GraphSchema graphSchema); + String prompt(GraphSchema graphSchema); - String prompt(GraphVertex entity); + String prompt(GraphVertex entity); - String prompt(GraphEdge entity, GraphVertex start, GraphVertex end); + String prompt(GraphEdge entity, GraphVertex start, GraphVertex end); } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/verbalization/SubgraphSemanticPromptFunction.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/verbalization/SubgraphSemanticPromptFunction.java index 0b3407c06..62333e217 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/verbalization/SubgraphSemanticPromptFunction.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/verbalization/SubgraphSemanticPromptFunction.java @@ -21,6 +21,7 @@ import java.util.*; import java.util.stream.Collectors; + import org.apache.geaflow.ai.graph.GraphAccessor; import org.apache.geaflow.ai.graph.GraphEdge; import org.apache.geaflow.ai.graph.GraphEntity; @@ -32,59 +33,63 @@ public class SubgraphSemanticPromptFunction implements VerbalizationFunction { - private final GraphAccessor graphAccessor; + private final GraphAccessor graphAccessor; - public SubgraphSemanticPromptFunction(GraphAccessor accessor) { - this.graphAccessor = Objects.requireNonNull(accessor); - } + public SubgraphSemanticPromptFunction(GraphAccessor accessor) { + this.graphAccessor = Objects.requireNonNull(accessor); + } - @Override - public String verbalize(SubGraph subGraph) { - if (subGraph == null || subGraph.getGraphEntityList().isEmpty()) { - return "Empty."; - } - List sentences = new ArrayList<>(); - GraphSchema schema = graphAccessor.getGraphSchema(); - Set existsEntities = new HashSet<>(); - for (GraphEntity entity : subGraph.getGraphEntityList()) { - if (entity instanceof GraphVertex) { - GraphVertex graphVertex = (GraphVertex) entity; - if (!existsEntities.contains(graphVertex)) { - sentences.add(schema.getPrompt(graphVertex)); - existsEntities.add(graphVertex); - } - } else if (entity instanceof GraphEdge) { - GraphEdge graphEdge = (GraphEdge) entity; - GraphVertex start = graphAccessor.getVertex(null, graphEdge.getEdge().getSrcId()); - GraphVertex end = graphAccessor.getVertex(null, graphEdge.getEdge().getDstId()); - sentences.add(schema.getPrompt(graphEdge, - existsEntities.contains(start) ? null : start, - existsEntities.contains(end) ? null : end)); - existsEntities.add(start); - existsEntities.add(end); - } - } - return String.join(SearchConstants.DELIMITER, sentences); + @Override + public String verbalize(SubGraph subGraph) { + if (subGraph == null || subGraph.getGraphEntityList().isEmpty()) { + return "Empty."; } - - @Override - public List verbalize(GraphEntity entity) { - if (entity instanceof GraphVertex) { - GraphVertex graphVertex = (GraphVertex) entity; - return graphVertex.getVertex().getValues().stream() - .filter(str -> !SearchUtils.isAllAllowedChars(str)) - .map(SearchUtils::formatQuery).collect(Collectors.toList()); - } else if (entity instanceof GraphEdge) { - GraphEdge graphEdge = (GraphEdge) entity; - return graphEdge.getEdge().getValues().stream() - .filter(str -> !SearchUtils.isAllAllowedChars(str)) - .map(SearchUtils::formatQuery).collect(Collectors.toList()); + List sentences = new ArrayList<>(); + GraphSchema schema = graphAccessor.getGraphSchema(); + Set existsEntities = new HashSet<>(); + for (GraphEntity entity : subGraph.getGraphEntityList()) { + if (entity instanceof GraphVertex) { + GraphVertex graphVertex = (GraphVertex) entity; + if (!existsEntities.contains(graphVertex)) { + sentences.add(schema.getPrompt(graphVertex)); + existsEntities.add(graphVertex); } - return new ArrayList<>(); + } else if (entity instanceof GraphEdge) { + GraphEdge graphEdge = (GraphEdge) entity; + GraphVertex start = graphAccessor.getVertex(null, graphEdge.getEdge().getSrcId()); + GraphVertex end = graphAccessor.getVertex(null, graphEdge.getEdge().getDstId()); + sentences.add( + schema.getPrompt( + graphEdge, + existsEntities.contains(start) ? null : start, + existsEntities.contains(end) ? null : end)); + existsEntities.add(start); + existsEntities.add(end); + } } + return String.join(SearchConstants.DELIMITER, sentences); + } - @Override - public String verbalizeGraphSchema() { - return ""; + @Override + public List verbalize(GraphEntity entity) { + if (entity instanceof GraphVertex) { + GraphVertex graphVertex = (GraphVertex) entity; + return graphVertex.getVertex().getValues().stream() + .filter(str -> !SearchUtils.isAllAllowedChars(str)) + .map(SearchUtils::formatQuery) + .collect(Collectors.toList()); + } else if (entity instanceof GraphEdge) { + GraphEdge graphEdge = (GraphEdge) entity; + return graphEdge.getEdge().getValues().stream() + .filter(str -> !SearchUtils.isAllAllowedChars(str)) + .map(SearchUtils::formatQuery) + .collect(Collectors.toList()); } + return new ArrayList<>(); + } + + @Override + public String verbalizeGraphSchema() { + return ""; + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/verbalization/VerbalizationFunction.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/verbalization/VerbalizationFunction.java index 6e90bc17a..71178b752 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/verbalization/VerbalizationFunction.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/verbalization/VerbalizationFunction.java @@ -20,14 +20,15 @@ package org.apache.geaflow.ai.verbalization; import java.util.List; + import org.apache.geaflow.ai.graph.GraphEntity; import org.apache.geaflow.ai.subgraph.SubGraph; public interface VerbalizationFunction { - String verbalize(SubGraph subGraph); + String verbalize(SubGraph subGraph); - List verbalize(GraphEntity entity); + List verbalize(GraphEntity entity); - String verbalizeGraphSchema(); + String verbalizeGraphSchema(); } diff --git a/geaflow-ai/src/test/java/org/apache/geaflow/ai/GraphMemoryTest.java b/geaflow-ai/src/test/java/org/apache/geaflow/ai/GraphMemoryTest.java index 6216cb344..d54ab972d 100644 --- a/geaflow-ai/src/test/java/org/apache/geaflow/ai/GraphMemoryTest.java +++ b/geaflow-ai/src/test/java/org/apache/geaflow/ai/GraphMemoryTest.java @@ -40,157 +40,189 @@ public class GraphMemoryTest { - private static final Logger LOGGER = LoggerFactory.getLogger(GraphMemoryTest.class); - - @Test - public void testVectorSearch() { - VectorSearch search = new VectorSearch(null, "test01"); - search.addVector(new EmbeddingVector(new double[0])); - search.addVector(new KeywordVector("test1", "test2")); - search.addVector(new MagnitudeVector()); - search.addVector(new TraversalVector("src", "edge", "dst")); - LOGGER.info(String.valueOf(search)); + private static final Logger LOGGER = LoggerFactory.getLogger(GraphMemoryTest.class); + + @Test + public void testVectorSearch() { + VectorSearch search = new VectorSearch(null, "test01"); + search.addVector(new EmbeddingVector(new double[0])); + search.addVector(new KeywordVector("test1", "test2")); + search.addVector(new MagnitudeVector()); + search.addVector(new TraversalVector("src", "edge", "dst")); + LOGGER.info(String.valueOf(search)); + } + + @Test + public void testEmptyMainPipeline() { + GraphMemoryServer server = new GraphMemoryServer(); + IndexStore indexStore = new EntityAttributeIndexStore(); + GraphAccessor graphAccessor = new EmptyGraphAccessor(); + server.addGraphAccessor(graphAccessor); + server.addIndexStore(indexStore); + + String sessionId = server.createSession(); + VectorSearch search = new VectorSearch(null, sessionId); + String context = produceCycle(server, search, graphAccessor); + Assertions.assertEquals(context, "Context{prompt=''}"); + + context = produceCycle(server, search, graphAccessor); + Assertions.assertEquals(context, "Context{prompt=''}"); + + context = produceCycle(server, search, graphAccessor); + Assertions.assertEquals(context, "Context{prompt=''}"); + } + + private String produceCycle( + GraphMemoryServer server, VectorSearch search, GraphAccessor graphAccessor) { + String sessionId = search.getSessionId(); + String searchResult = server.search(search); + Assertions.assertNotNull(searchResult); + Context context = + server.verbalize(sessionId, new SubgraphSemanticPromptFunction(graphAccessor)); + Assertions.assertNotNull(context); + return context.toString(); + } + + @Test + public void testLdbcMainPipeline() { + LdbcPromptFormatter ldbcPromptFormatter = new LdbcPromptFormatter(); + LocalMemoryGraphAccessor graphAccessor = + new LocalMemoryGraphAccessor( + this.getClass().getClassLoader(), + "graph_ldbc_sf", + 7500L, + ldbcPromptFormatter::vertexMapper, + ldbcPromptFormatter::edgeMapper); + graphAccessor.getGraphSchema().setPromptFormatter(ldbcPromptFormatter); + LOGGER.info("Success to init graph data."); + + EntityAttributeIndexStore indexStore = new EntityAttributeIndexStore(); + indexStore.initStore(new SubgraphSemanticPromptFunction(graphAccessor)); + LOGGER.info("Success to init EntityAttributeIndexStore."); + + ModelConfig modelInfo = new ModelConfig(null, null, null, null); + EmbeddingIndexStore embeddingStore = new EmbeddingIndexStore(); + embeddingStore.initStore( + graphAccessor, + new SubgraphSemanticPromptFunction(graphAccessor), + "src/test/resources/index/LDBCEmbeddingIndexStore", + modelInfo); + LOGGER.info("Success to init EmbeddingIndexStore."); + + GraphMemoryServer server = new GraphMemoryServer(); + server.addGraphAccessor(graphAccessor); + server.addIndexStore(indexStore); + server.addIndexStore(embeddingStore); + MockChatRobot robot = new MockChatRobot(); + robot.setModelInfo(modelInfo); + + { + String sessionId = server.createSession(); + VectorSearch search = new VectorSearch(null, sessionId); + String query = "What comments has Chaim Azriel posted?"; + search.addVector(new KeywordVector(query)); + search.addVector(new EmbeddingVector(robot.embeddingSingle(query).embedding)); + LOGGER.info("Round 1: \n" + search); + + String context = produceCycle(server, search, graphAccessor); + LOGGER.info("Round 1: \n" + context); + Assertions.assertTrue( + context.contains( + "A Person, male, register at 2010-02-19T12:42:14.255+00:00, id is Person166, name is" + + " Chaim Azriel Hleb, birthday is 1985-10-01, ip is 178.238.2.172, use browser" + + " is Firefox, use language are ru;pl;en, email address is" + + " Chaim.Azriel166@yahoo.com;Chaim.Azriel166@gmail.com;Chaim.Azriel166@gmx.com;Chaim.Azriel166@hotmail.com;Chaim.Azriel166@theblackmarket.com")); + Assertions.assertTrue( + context.contains( + "A Person, male, register at 2011-09-18T02:40:31.062+00:00, id is" + + " Person21990232556059, name is Chaim Azriel Epstein, birthday is 1981-07-17," + + " ip is 80.94.167.126, use browser is Firefox, use language are ru;pl;en, email" + + " address is" + + " Chaim.Azriel21990232556059@gmx.com;Chaim.Azriel21990232556059@gmail.com;Chaim.Azriel21990232556059@hotmail.com;Chaim.Azriel21990232556059@yahoo.com;Chaim.Azriel21990232556059@zoho.com")); + + search = new VectorSearch(null, sessionId); + query = + "Chaim Azriel, Comment_hasCreator_Person, personId, comment author, Person166," + + " Person21990232556059"; + search.addVector(new KeywordVector(query)); + search.addVector(new EmbeddingVector(robot.embeddingSingle(query).embedding)); + LOGGER.info("Round 2: \n" + search); + + context = produceCycle(server, search, graphAccessor); + LOGGER.info("Round 2: \n" + context); + Assertions.assertTrue(context.contains("Comment824633737550")); + Assertions.assertTrue(context.contains("Comment962072691263")); + Assertions.assertTrue(context.contains("Comment687194784946")); } - @Test - public void testEmptyMainPipeline() { - GraphMemoryServer server = new GraphMemoryServer(); - IndexStore indexStore = new EntityAttributeIndexStore(); - GraphAccessor graphAccessor = new EmptyGraphAccessor(); - server.addGraphAccessor(graphAccessor); - server.addIndexStore(indexStore); - - String sessionId = server.createSession(); - VectorSearch search = new VectorSearch(null, sessionId); - String context = produceCycle(server, search, graphAccessor); - Assertions.assertEquals(context, "Context{prompt=''}"); - - context = produceCycle(server, search, graphAccessor); - Assertions.assertEquals(context, "Context{prompt=''}"); - - context = produceCycle(server, search, graphAccessor); - Assertions.assertEquals(context, "Context{prompt=''}"); - } - - private String produceCycle(GraphMemoryServer server, VectorSearch search, - GraphAccessor graphAccessor) { - String sessionId = search.getSessionId(); - String searchResult = server.search(search); - Assertions.assertNotNull(searchResult); - Context context = server.verbalize(sessionId, - new SubgraphSemanticPromptFunction(graphAccessor)); - Assertions.assertNotNull(context); - return context.toString(); + { + String sessionId = server.createSession(); + VectorSearch search = new VectorSearch(null, sessionId); + String query = "How many posts has Chaim Azriel posted?"; + search.addVector(new KeywordVector(query)); + search.addVector(new EmbeddingVector(robot.embeddingSingle(query).embedding)); + + LOGGER.info("Round 1: \n" + search); + + String context = produceCycle(server, search, graphAccessor); + LOGGER.info("Round 1: \n" + context); + Assertions.assertTrue( + context.contains( + "A Forum, name is Wall of Chaim Azriel Hleb, id is Forum220, created at" + + " 2010-02-19T12:42:24.255+00:00")); + Assertions.assertTrue( + context.contains( + "A Person, male, register at 2010-02-19T12:42:14.255+00:00, id is Person166, name is" + + " Chaim Azriel Hleb, birthday is 1985-10-01, ip is 178.238.2.172, use browser" + + " is Firefox, use language are ru;pl;en, email address is" + + " Chaim.Azriel166@yahoo.com;Chaim.Azriel166@gmail.com;Chaim.Azriel166@gmx.com;Chaim.Azriel166@hotmail.com;Chaim.Azriel166@theblackmarket.com")); + Assertions.assertTrue( + context.contains( + "A Person, male, register at 2011-09-18T02:40:31.062+00:00, id is" + + " Person21990232556059, name is Chaim Azriel Epstein, birthday is 1981-07-17," + + " ip is 80.94.167.126, use browser is Firefox, use language are ru;pl;en, email" + + " address is" + + " Chaim.Azriel21990232556059@gmx.com;Chaim.Azriel21990232556059@gmail.com;Chaim.Azriel21990232556059@hotmail.com;Chaim.Azriel21990232556059@yahoo.com;Chaim.Azriel21990232556059@zoho.com")); + + search = new VectorSearch(null, sessionId); + query = "Chaim Azriel, Post_hasCreator_Person, Post"; + search.addVector(new KeywordVector(query)); + search.addVector(new EmbeddingVector(robot.embeddingSingle(query).embedding)); + LOGGER.info("Round 2: \n" + search); + + context = produceCycle(server, search, graphAccessor); + LOGGER.info("Round 2: \n" + context); + Assertions.assertTrue(context.contains("Post755914247530")); + Assertions.assertTrue(context.contains("Post1099511644342")); + Assertions.assertTrue(context.contains("Person166")); } - @Test - public void testLdbcMainPipeline() { - LdbcPromptFormatter ldbcPromptFormatter = new LdbcPromptFormatter(); - LocalMemoryGraphAccessor graphAccessor = - new LocalMemoryGraphAccessor(this.getClass().getClassLoader(), - "graph_ldbc_sf", 7500L, - ldbcPromptFormatter::vertexMapper, ldbcPromptFormatter::edgeMapper); - graphAccessor.getGraphSchema().setPromptFormatter(ldbcPromptFormatter); - LOGGER.info("Success to init graph data."); - - EntityAttributeIndexStore indexStore = new EntityAttributeIndexStore(); - indexStore.initStore(new SubgraphSemanticPromptFunction(graphAccessor)); - LOGGER.info("Success to init EntityAttributeIndexStore."); - - ModelConfig modelInfo = new ModelConfig(null, null, null, null); - EmbeddingIndexStore embeddingStore = new EmbeddingIndexStore(); - embeddingStore.initStore(graphAccessor, - new SubgraphSemanticPromptFunction(graphAccessor), - "src/test/resources/index/LDBCEmbeddingIndexStore", - modelInfo); - LOGGER.info("Success to init EmbeddingIndexStore."); - - GraphMemoryServer server = new GraphMemoryServer(); - server.addGraphAccessor(graphAccessor); - server.addIndexStore(indexStore); - server.addIndexStore(embeddingStore); - MockChatRobot robot = new MockChatRobot(); - robot.setModelInfo(modelInfo); - - { - String sessionId = server.createSession(); - VectorSearch search = new VectorSearch(null, sessionId); - String query = "What comments has Chaim Azriel posted?"; - search.addVector(new KeywordVector(query)); - search.addVector(new EmbeddingVector(robot.embeddingSingle(query).embedding)); - LOGGER.info("Round 1: \n" + search); - - String context = produceCycle(server, search, graphAccessor); - LOGGER.info("Round 1: \n" + context); - Assertions.assertTrue(context.contains("A Person, male, register at 2010-02-19T12:42:14.255+00:00, id is Person166, name is Chaim Azriel Hleb, birthday is 1985-10-01, ip is 178.238.2.172, use browser is Firefox, use language are ru;pl;en, email address is Chaim.Azriel166@yahoo.com;Chaim.Azriel166@gmail.com;Chaim.Azriel166@gmx.com;Chaim.Azriel166@hotmail.com;Chaim.Azriel166@theblackmarket.com")); - Assertions.assertTrue(context.contains("A Person, male, register at 2011-09-18T02:40:31.062+00:00, id is Person21990232556059, name is Chaim Azriel Epstein, birthday is 1981-07-17, ip is 80.94.167.126, use browser is Firefox, use language are ru;pl;en, email address is Chaim.Azriel21990232556059@gmx.com;Chaim.Azriel21990232556059@gmail.com;Chaim.Azriel21990232556059@hotmail.com;Chaim.Azriel21990232556059@yahoo.com;Chaim.Azriel21990232556059@zoho.com")); - - search = new VectorSearch(null, sessionId); - query = "Chaim Azriel, Comment_hasCreator_Person, personId, comment author, Person166, Person21990232556059"; - search.addVector(new KeywordVector(query)); - search.addVector(new EmbeddingVector(robot.embeddingSingle(query).embedding)); - LOGGER.info("Round 2: \n" + search); - - context = produceCycle(server, search, graphAccessor); - LOGGER.info("Round 2: \n" + context); - Assertions.assertTrue(context.contains("Comment824633737550")); - Assertions.assertTrue(context.contains("Comment962072691263")); - Assertions.assertTrue(context.contains("Comment687194784946")); - } - - { - String sessionId = server.createSession(); - VectorSearch search = new VectorSearch(null, sessionId); - String query = "How many posts has Chaim Azriel posted?"; - search.addVector(new KeywordVector(query)); - search.addVector(new EmbeddingVector(robot.embeddingSingle(query).embedding)); - - LOGGER.info("Round 1: \n" + search); - - String context = produceCycle(server, search, graphAccessor); - LOGGER.info("Round 1: \n" + context); - Assertions.assertTrue(context.contains("A Forum, name is Wall of Chaim Azriel Hleb, id is Forum220, created at 2010-02-19T12:42:24.255+00:00")); - Assertions.assertTrue(context.contains("A Person, male, register at 2010-02-19T12:42:14.255+00:00, id is Person166, name is Chaim Azriel Hleb, birthday is 1985-10-01, ip is 178.238.2.172, use browser is Firefox, use language are ru;pl;en, email address is Chaim.Azriel166@yahoo.com;Chaim.Azriel166@gmail.com;Chaim.Azriel166@gmx.com;Chaim.Azriel166@hotmail.com;Chaim.Azriel166@theblackmarket.com")); - Assertions.assertTrue(context.contains("A Person, male, register at 2011-09-18T02:40:31.062+00:00, id is Person21990232556059, name is Chaim Azriel Epstein, birthday is 1981-07-17, ip is 80.94.167.126, use browser is Firefox, use language are ru;pl;en, email address is Chaim.Azriel21990232556059@gmx.com;Chaim.Azriel21990232556059@gmail.com;Chaim.Azriel21990232556059@hotmail.com;Chaim.Azriel21990232556059@yahoo.com;Chaim.Azriel21990232556059@zoho.com")); - - search = new VectorSearch(null, sessionId); - query = "Chaim Azriel, Post_hasCreator_Person, Post"; - search.addVector(new KeywordVector(query)); - search.addVector(new EmbeddingVector(robot.embeddingSingle(query).embedding)); - LOGGER.info("Round 2: \n" + search); - - context = produceCycle(server, search, graphAccessor); - LOGGER.info("Round 2: \n" + context); - Assertions.assertTrue(context.contains("Post755914247530")); - Assertions.assertTrue(context.contains("Post1099511644342")); - Assertions.assertTrue(context.contains("Person166")); - } - - { - String sessionId = server.createSession(); - VectorSearch search = new VectorSearch(null, sessionId); - String query = "Which historical comments exist, who posted them, and how many have been published?"; - search.addVector(new KeywordVector(query)); - search.addVector(new EmbeddingVector(robot.embeddingSingle(query).embedding)); - - LOGGER.info("Round 1: \n" + search); - - String context = produceCycle(server, search, graphAccessor); - LOGGER.info("Round 1: \n" + context); - Assertions.assertTrue(context.contains("Comment1099511634167")); - Assertions.assertTrue(context.contains("Comment1030792164752")); - Assertions.assertTrue(context.contains("Comment1099511645848")); - - search = new VectorSearch(null, sessionId); - query = "Comment_hasCreator_Person, Person, Comment IDs, Comment1030792157359"; - search.addVector(new KeywordVector(query)); - search.addVector(new EmbeddingVector(robot.embeddingSingle(query).embedding)); - LOGGER.info("Round 2: \n" + search); - - context = produceCycle(server, search, graphAccessor); - LOGGER.info("Round 2: \n" + context); - Assertions.assertTrue(context.contains("Person24189255812253")); - Assertions.assertTrue(context.contains("Person26388279067480")); - } + { + String sessionId = server.createSession(); + VectorSearch search = new VectorSearch(null, sessionId); + String query = + "Which historical comments exist, who posted them, and how many have been published?"; + search.addVector(new KeywordVector(query)); + search.addVector(new EmbeddingVector(robot.embeddingSingle(query).embedding)); + + LOGGER.info("Round 1: \n" + search); + + String context = produceCycle(server, search, graphAccessor); + LOGGER.info("Round 1: \n" + context); + Assertions.assertTrue(context.contains("Comment1099511634167")); + Assertions.assertTrue(context.contains("Comment1030792164752")); + Assertions.assertTrue(context.contains("Comment1099511645848")); + + search = new VectorSearch(null, sessionId); + query = "Comment_hasCreator_Person, Person, Comment IDs, Comment1030792157359"; + search.addVector(new KeywordVector(query)); + search.addVector(new EmbeddingVector(robot.embeddingSingle(query).embedding)); + LOGGER.info("Round 2: \n" + search); + + context = produceCycle(server, search, graphAccessor); + LOGGER.info("Round 2: \n" + context); + Assertions.assertTrue(context.contains("Person24189255812253")); + Assertions.assertTrue(context.contains("Person26388279067480")); } + } } diff --git a/geaflow-ai/src/test/java/org/apache/geaflow/ai/LdbcPromptFormatter.java b/geaflow-ai/src/test/java/org/apache/geaflow/ai/LdbcPromptFormatter.java index f0560a547..cebf36950 100644 --- a/geaflow-ai/src/test/java/org/apache/geaflow/ai/LdbcPromptFormatter.java +++ b/geaflow-ai/src/test/java/org/apache/geaflow/ai/LdbcPromptFormatter.java @@ -20,6 +20,7 @@ package org.apache.geaflow.ai; import java.util.stream.Collectors; + import org.apache.geaflow.ai.graph.GraphEdge; import org.apache.geaflow.ai.graph.GraphVertex; import org.apache.geaflow.ai.graph.io.*; @@ -27,289 +28,348 @@ public class LdbcPromptFormatter implements PromptFormatter { - public Vertex vertexMapper(Vertex v) { - String idPrefix = v.getLabel(); - return new Vertex(v.getLabel(), idPrefix + v.getId(), v.getValues()); - } + public Vertex vertexMapper(Vertex v) { + String idPrefix = v.getLabel(); + return new Vertex(v.getLabel(), idPrefix + v.getId(), v.getValues()); + } - public Edge edgeMapper(Edge e) { - String srcPrefix = ""; - String dstPrefix = ""; - switch (e.getLabel()) { - case "Forum_hasMember_Person": - case "Forum_hasModerator_Person": - srcPrefix = "Forum"; - dstPrefix = "Person"; - break; - case "Person_knows_Person": - srcPrefix = "Person"; - dstPrefix = "Person"; - break; - case "Comment_replyOf_Post": - srcPrefix = "Comment"; - dstPrefix = "Post"; - break; - case "Person_isLocatedIn_City": - srcPrefix = "Person"; - dstPrefix = "Place"; - break; - case "Comment_hasTag_Tag": - srcPrefix = "Comment"; - dstPrefix = "Tag"; - break; - case "Person_workAt_Company": - srcPrefix = "Person"; - dstPrefix = "Company"; - break; - case "Comment_replyOf_Comment": - srcPrefix = "Comment"; - dstPrefix = "Comment"; - break; - case "Organisation_isLocatedIn_Place": - srcPrefix = "Organisation"; - dstPrefix = "Place"; - break; - case "Comment_isLocatedIn_Country": - srcPrefix = "Comment"; - dstPrefix = "Country"; - break; - case "Person_hasInterest_Tag": - srcPrefix = "Person"; - dstPrefix = "Tag"; - break; - case "Forum_hasTag_Tag": - srcPrefix = "Forum"; - dstPrefix = "Tag"; - break; - case "Person_studyAt_University": - srcPrefix = "Person"; - dstPrefix = "Organisation"; - break; - case "Comment_hasCreator_Person": - srcPrefix = "Comment"; - dstPrefix = "Person"; - break; - case "TagClass_isSubclassOf_TagClass": - srcPrefix = "TagClass"; - dstPrefix = "TagClass"; - break; - case "Post_isLocatedIn_Country": - srcPrefix = "Post"; - dstPrefix = "Country"; - break; - case "Post_hasCreator_Person": - srcPrefix = "Post"; - dstPrefix = "Person"; - break; - case "Post_hasTag_Tag": - srcPrefix = "Post"; - dstPrefix = "Tag"; - break; - case "Person_likes_Post": - srcPrefix = "Person"; - dstPrefix = "Post"; - break; - case "Person_likes_Comment": - srcPrefix = "Person"; - dstPrefix = "Comment"; - break; - case "Forum_containerOf_Post": - srcPrefix = "Forum"; - dstPrefix = "Post"; - break; - case "Tag_hasType_TagClass": - srcPrefix = "Tag"; - dstPrefix = "TagClass"; - break; - case "Place_isPartOf_Place": - srcPrefix = "Place"; - dstPrefix = "Place"; - break; - default: - srcPrefix = ""; - dstPrefix = ""; - break; - } - return new Edge(e.getLabel(), srcPrefix + e.getSrcId(), - dstPrefix + e.getDstId(), e.getValues()); + public Edge edgeMapper(Edge e) { + String srcPrefix = ""; + String dstPrefix = ""; + switch (e.getLabel()) { + case "Forum_hasMember_Person": + case "Forum_hasModerator_Person": + srcPrefix = "Forum"; + dstPrefix = "Person"; + break; + case "Person_knows_Person": + srcPrefix = "Person"; + dstPrefix = "Person"; + break; + case "Comment_replyOf_Post": + srcPrefix = "Comment"; + dstPrefix = "Post"; + break; + case "Person_isLocatedIn_City": + srcPrefix = "Person"; + dstPrefix = "Place"; + break; + case "Comment_hasTag_Tag": + srcPrefix = "Comment"; + dstPrefix = "Tag"; + break; + case "Person_workAt_Company": + srcPrefix = "Person"; + dstPrefix = "Company"; + break; + case "Comment_replyOf_Comment": + srcPrefix = "Comment"; + dstPrefix = "Comment"; + break; + case "Organisation_isLocatedIn_Place": + srcPrefix = "Organisation"; + dstPrefix = "Place"; + break; + case "Comment_isLocatedIn_Country": + srcPrefix = "Comment"; + dstPrefix = "Country"; + break; + case "Person_hasInterest_Tag": + srcPrefix = "Person"; + dstPrefix = "Tag"; + break; + case "Forum_hasTag_Tag": + srcPrefix = "Forum"; + dstPrefix = "Tag"; + break; + case "Person_studyAt_University": + srcPrefix = "Person"; + dstPrefix = "Organisation"; + break; + case "Comment_hasCreator_Person": + srcPrefix = "Comment"; + dstPrefix = "Person"; + break; + case "TagClass_isSubclassOf_TagClass": + srcPrefix = "TagClass"; + dstPrefix = "TagClass"; + break; + case "Post_isLocatedIn_Country": + srcPrefix = "Post"; + dstPrefix = "Country"; + break; + case "Post_hasCreator_Person": + srcPrefix = "Post"; + dstPrefix = "Person"; + break; + case "Post_hasTag_Tag": + srcPrefix = "Post"; + dstPrefix = "Tag"; + break; + case "Person_likes_Post": + srcPrefix = "Person"; + dstPrefix = "Post"; + break; + case "Person_likes_Comment": + srcPrefix = "Person"; + dstPrefix = "Comment"; + break; + case "Forum_containerOf_Post": + srcPrefix = "Forum"; + dstPrefix = "Post"; + break; + case "Tag_hasType_TagClass": + srcPrefix = "Tag"; + dstPrefix = "TagClass"; + break; + case "Place_isPartOf_Place": + srcPrefix = "Place"; + dstPrefix = "Place"; + break; + default: + srcPrefix = ""; + dstPrefix = ""; + break; } + return new Edge( + e.getLabel(), srcPrefix + e.getSrcId(), dstPrefix + e.getDstId(), e.getValues()); + } - @Override - public String prompt(GraphSchema graphSchema) { - return "allVerticesType:[" + graphSchema.getVertexSchemaList() - .stream().map(VertexSchema::getLabel).collect(Collectors.joining(",")) + "]\n" - + "allRelations:[" + graphSchema.getEdgeSchemaList() - .stream().map(EdgeSchema::getLabel).collect(Collectors.joining(",")) + "]\n"; - } + @Override + public String prompt(GraphSchema graphSchema) { + return "allVerticesType:[" + + graphSchema.getVertexSchemaList().stream() + .map(VertexSchema::getLabel) + .collect(Collectors.joining(",")) + + "]\n" + + "allRelations:[" + + graphSchema.getEdgeSchemaList().stream() + .map(EdgeSchema::getLabel) + .collect(Collectors.joining(",")) + + "]\n"; + } - @Override - public String prompt(GraphVertex entity) { - if (entity == null) { - return "Empty Vertex."; - } else if (entity.getVertex() == null) { - return "The Vertex."; - } - Vertex obj; - Edge edge; - switch (entity.getLabel()) { - case "Person": - obj = ((GraphVertex) entity).getVertex(); - return String.format("A Person, %s, register at %s, id is %s, name is %s %s, birthday is %s, " - + "ip is %s, use browser is %s, use language are %s, email address is %s", - obj.getValues().get(4), obj.getValues().get(0), obj.getId(), - obj.getValues().get(2), obj.getValues().get(3), obj.getValues().get(5), - obj.getValues().get(6), obj.getValues().get(7), obj.getValues().get(8), obj.getValues().get(9)); - case "Forum": - obj = ((GraphVertex) entity).getVertex(); - return String.format("A Forum, name is %s, id is %s, created at %s", - obj.getValues().get(2), obj.getId(), obj.getValues().get(0)); - case "TagClass": - obj = ((GraphVertex) entity).getVertex(); - return String.format("A TagClass, name is %s, id is %s, url %s", - obj.getValues().get(1), obj.getId(), obj.getValues().get(2)); - case "Tag": - obj = ((GraphVertex) entity).getVertex(); - return String.format("A tag, name is %s, id is %s, url %s", - obj.getValues().get(1), obj.getId(), obj.getValues().get(2)); - case "Comment": - obj = ((GraphVertex) entity).getVertex(); - return String.format("A comment, id is %s, created at %s, user ip is %s, " - + "user use browser is %s, content is %s", - obj.getId(), obj.getValues().get(0), obj.getValues().get(2), obj.getValues().get(3), obj.getValues().get(4)); - case "Post": - obj = ((GraphVertex) entity).getVertex(); - return String.format("A post, id is %s, created at %s, user ip is %s, " - + "user use browser is %s, use language is %s, content is %s", - obj.getId(), obj.getValues().get(0), obj.getValues().get(3), obj.getValues().get(4), obj.getValues().get(5), obj.getValues().get(6)); - case "Organisation": - obj = ((GraphVertex) entity).getVertex(); - return String.format("An organisation of %s, name is %s, id is %s, url %s", - obj.getValues().get(1), obj.getValues().get(2), obj.getId(), obj.getValues().get(3)); - case "Place": - obj = ((GraphVertex) entity).getVertex(); - return String.format("A place of %s, name is %s, id is %s, url %s", - obj.getValues().get(3), obj.getValues().get(1), obj.getId(), obj.getValues().get(2)); - default: - return entity.toString(); - } + @Override + public String prompt(GraphVertex entity) { + if (entity == null) { + return "Empty Vertex."; + } else if (entity.getVertex() == null) { + return "The Vertex."; } + Vertex obj; + Edge edge; + switch (entity.getLabel()) { + case "Person": + obj = ((GraphVertex) entity).getVertex(); + return String.format( + "A Person, %s, register at %s, id is %s, name is %s %s, birthday is %s, " + + "ip is %s, use browser is %s, use language are %s, email address is %s", + obj.getValues().get(4), + obj.getValues().get(0), + obj.getId(), + obj.getValues().get(2), + obj.getValues().get(3), + obj.getValues().get(5), + obj.getValues().get(6), + obj.getValues().get(7), + obj.getValues().get(8), + obj.getValues().get(9)); + case "Forum": + obj = ((GraphVertex) entity).getVertex(); + return String.format( + "A Forum, name is %s, id is %s, created at %s", + obj.getValues().get(2), obj.getId(), obj.getValues().get(0)); + case "TagClass": + obj = ((GraphVertex) entity).getVertex(); + return String.format( + "A TagClass, name is %s, id is %s, url %s", + obj.getValues().get(1), obj.getId(), obj.getValues().get(2)); + case "Tag": + obj = ((GraphVertex) entity).getVertex(); + return String.format( + "A tag, name is %s, id is %s, url %s", + obj.getValues().get(1), obj.getId(), obj.getValues().get(2)); + case "Comment": + obj = ((GraphVertex) entity).getVertex(); + return String.format( + "A comment, id is %s, created at %s, user ip is %s, " + + "user use browser is %s, content is %s", + obj.getId(), + obj.getValues().get(0), + obj.getValues().get(2), + obj.getValues().get(3), + obj.getValues().get(4)); + case "Post": + obj = ((GraphVertex) entity).getVertex(); + return String.format( + "A post, id is %s, created at %s, user ip is %s, " + + "user use browser is %s, use language is %s, content is %s", + obj.getId(), + obj.getValues().get(0), + obj.getValues().get(3), + obj.getValues().get(4), + obj.getValues().get(5), + obj.getValues().get(6)); + case "Organisation": + obj = ((GraphVertex) entity).getVertex(); + return String.format( + "An organisation of %s, name is %s, id is %s, url %s", + obj.getValues().get(1), obj.getValues().get(2), obj.getId(), obj.getValues().get(3)); + case "Place": + obj = ((GraphVertex) entity).getVertex(); + return String.format( + "A place of %s, name is %s, id is %s, url %s", + obj.getValues().get(3), obj.getValues().get(1), obj.getId(), obj.getValues().get(2)); + default: + return entity.toString(); + } + } - @Override - public String prompt(GraphEdge entity, GraphVertex start, GraphVertex end) { - if (entity == null) { - return "Empty Edge."; - } else if (entity.getEdge() == null) { - return "The Edge."; - } - StringBuilder builder = new StringBuilder(); - builder.append(promptEdge(entity)); - builder.append("\n startVertex:"); - builder.append(prompt(start)); - builder.append("\n endVertex:"); - builder.append(prompt(end)); - return builder.toString(); + @Override + public String prompt(GraphEdge entity, GraphVertex start, GraphVertex end) { + if (entity == null) { + return "Empty Edge."; + } else if (entity.getEdge() == null) { + return "The Edge."; } + StringBuilder builder = new StringBuilder(); + builder.append(promptEdge(entity)); + builder.append("\n startVertex:"); + builder.append(prompt(start)); + builder.append("\n endVertex:"); + builder.append(prompt(end)); + return builder.toString(); + } - public String promptEdge(GraphEdge entity) { - if (entity == null || entity.getEdge() == null) { - return "Empty Edge."; - } - Edge edge; - switch (entity.getLabel()) { - case "Forum_hasMember_Person": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Forum_hasMember_Person, Forum id %s has member id %s, register at %s", - edge.getSrcId(), edge.getDstId(), edge.getValues().get(0)); - case "Person_workAt_Company": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Person_workAt_Company, Person id %s work at company id %s, start from %s year", - edge.getSrcId(), edge.getDstId(), edge.getValues().get(3)); - case "Organisation_isLocatedIn_Place": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Organisation_isLocatedIn_Place, Organisation id %s is located in place id %s", - edge.getSrcId(), edge.getDstId()); - case "Person_hasInterest_Tag": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Person_hasInterest_Tag, Person id %s interest at tag id %s", - edge.getSrcId(), edge.getDstId()); - case "Forum_hasTag_Tag": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Forum_hasTag_Tag, Forum id %s has tag id %s", - edge.getSrcId(), edge.getDstId()); - case "Forum_hasModerator_Person": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Forum_hasModerator_Person, Forum id %s has moderator id %s", - edge.getSrcId(), edge.getDstId()); - case "Forum_containerOf_Post": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Forum_containerOf_Post, Forum id %s contain of the post id %s", - edge.getSrcId(), edge.getDstId()); - case "Tag_hasType_TagClass": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Tag_hasType_TagClass, Tag id %s has type of tag class id %s", - edge.getSrcId(), edge.getDstId()); - case "Place_isPartOf_Place": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Place_isPartOf_Place, The place id %s is part of place id %s", - edge.getSrcId(), edge.getDstId()); - case "Person_isLocatedIn_City": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Person_isLocatedIn_City, The person id %s is located in city id %s", - edge.getSrcId(), edge.getDstId()); - case "Person_knows_Person": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Person_knows_Person, The person id %s knows person id %s", - edge.getSrcId(), edge.getDstId()); - case "Person_studyAt_University": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Person_studyAt_University, The person id %s study at university id %s", - edge.getSrcId(), edge.getDstId()); - case "Comment_hasCreator_Person": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Comment_hasCreator_Person, The comment id %s has creator person id %s", - edge.getSrcId(), edge.getDstId()); - case "Post_hasCreator_Person": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Post_hasCreator_Person, The post id %s has creator person id %s", - edge.getSrcId(), edge.getDstId()); - case "Person_likes_Post": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Person_likes_Post, The person id %s likes person id %s", - edge.getSrcId(), edge.getDstId()); - case "Person_likes_Comment": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Person_likes_Comment, The person id %s likes comment id %s", - edge.getSrcId(), edge.getDstId()); - case "Comment_replyOf_Post": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Comment_replyOf_Post, The comment id %s reply of the post id %s", - edge.getSrcId(), edge.getDstId()); - case "Post_isLocatedIn_Country": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Post_isLocatedIn_Country, The post id %s is located in country id %s", - edge.getSrcId(), edge.getDstId()); - case "Post_hasTag_Tag": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Post_hasTag_Tag, The post id %s has tag id %s", - edge.getSrcId(), edge.getDstId()); - case "Comment_hasTag_Tag": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Comment_hasTag_Tag, The comment id %s has tag id %s", - edge.getSrcId(), edge.getDstId()); - case "Comment_isLocatedIn_Country": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Comment_isLocatedIn_Country, The comment id %s is located in country id %s", - edge.getSrcId(), edge.getDstId()); - case "Comment_replyOf_Comment": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type Comment_replyOf_Comment, The comment id %s is reply of comment id %s", - edge.getSrcId(), edge.getDstId()); - case "TagClass_isSubclassOf_TagClass": - edge = ((GraphEdge) entity).getEdge(); - return String.format("One edge of Type TagClass_isSubclassOf_TagClass, The tag class id %s is subclass of the tag class id %s", - edge.getSrcId(), edge.getDstId()); - default: - return entity.toString(); - } + public String promptEdge(GraphEdge entity) { + if (entity == null || entity.getEdge() == null) { + return "Empty Edge."; + } + Edge edge; + switch (entity.getLabel()) { + case "Forum_hasMember_Person": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Forum_hasMember_Person, Forum id %s has member id %s, register at %s", + edge.getSrcId(), edge.getDstId(), edge.getValues().get(0)); + case "Person_workAt_Company": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Person_workAt_Company, Person id %s work at company id %s, start from" + + " %s year", + edge.getSrcId(), edge.getDstId(), edge.getValues().get(3)); + case "Organisation_isLocatedIn_Place": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Organisation_isLocatedIn_Place, Organisation id %s is located in" + + " place id %s", + edge.getSrcId(), edge.getDstId()); + case "Person_hasInterest_Tag": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Person_hasInterest_Tag, Person id %s interest at tag id %s", + edge.getSrcId(), edge.getDstId()); + case "Forum_hasTag_Tag": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Forum_hasTag_Tag, Forum id %s has tag id %s", + edge.getSrcId(), edge.getDstId()); + case "Forum_hasModerator_Person": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Forum_hasModerator_Person, Forum id %s has moderator id %s", + edge.getSrcId(), edge.getDstId()); + case "Forum_containerOf_Post": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Forum_containerOf_Post, Forum id %s contain of the post id %s", + edge.getSrcId(), edge.getDstId()); + case "Tag_hasType_TagClass": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Tag_hasType_TagClass, Tag id %s has type of tag class id %s", + edge.getSrcId(), edge.getDstId()); + case "Place_isPartOf_Place": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Place_isPartOf_Place, The place id %s is part of place id %s", + edge.getSrcId(), edge.getDstId()); + case "Person_isLocatedIn_City": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Person_isLocatedIn_City, The person id %s is located in city id %s", + edge.getSrcId(), edge.getDstId()); + case "Person_knows_Person": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Person_knows_Person, The person id %s knows person id %s", + edge.getSrcId(), edge.getDstId()); + case "Person_studyAt_University": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Person_studyAt_University, The person id %s study at university id" + + " %s", + edge.getSrcId(), edge.getDstId()); + case "Comment_hasCreator_Person": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Comment_hasCreator_Person, The comment id %s has creator person id" + + " %s", + edge.getSrcId(), edge.getDstId()); + case "Post_hasCreator_Person": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Post_hasCreator_Person, The post id %s has creator person id %s", + edge.getSrcId(), edge.getDstId()); + case "Person_likes_Post": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Person_likes_Post, The person id %s likes person id %s", + edge.getSrcId(), edge.getDstId()); + case "Person_likes_Comment": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Person_likes_Comment, The person id %s likes comment id %s", + edge.getSrcId(), edge.getDstId()); + case "Comment_replyOf_Post": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Comment_replyOf_Post, The comment id %s reply of the post id %s", + edge.getSrcId(), edge.getDstId()); + case "Post_isLocatedIn_Country": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Post_isLocatedIn_Country, The post id %s is located in country id %s", + edge.getSrcId(), edge.getDstId()); + case "Post_hasTag_Tag": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Post_hasTag_Tag, The post id %s has tag id %s", + edge.getSrcId(), edge.getDstId()); + case "Comment_hasTag_Tag": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Comment_hasTag_Tag, The comment id %s has tag id %s", + edge.getSrcId(), edge.getDstId()); + case "Comment_isLocatedIn_Country": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Comment_isLocatedIn_Country, The comment id %s is located in country" + + " id %s", + edge.getSrcId(), edge.getDstId()); + case "Comment_replyOf_Comment": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type Comment_replyOf_Comment, The comment id %s is reply of comment id %s", + edge.getSrcId(), edge.getDstId()); + case "TagClass_isSubclassOf_TagClass": + edge = ((GraphEdge) entity).getEdge(); + return String.format( + "One edge of Type TagClass_isSubclassOf_TagClass, The tag class id %s is subclass of" + + " the tag class id %s", + edge.getSrcId(), edge.getDstId()); + default: + return entity.toString(); } + } } diff --git a/geaflow-ai/src/test/java/org/apache/geaflow/ai/MemoryServerTest.java b/geaflow-ai/src/test/java/org/apache/geaflow/ai/MemoryServerTest.java index 73db0a1af..e1b21b615 100644 --- a/geaflow-ai/src/test/java/org/apache/geaflow/ai/MemoryServerTest.java +++ b/geaflow-ai/src/test/java/org/apache/geaflow/ai/MemoryServerTest.java @@ -19,13 +19,12 @@ package org.apache.geaflow.ai; -import com.google.gson.Gson; import java.io.IOException; import java.util.*; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import java.util.stream.IntStream; -import okhttp3.*; + import org.apache.geaflow.ai.common.config.Constants; import org.apache.geaflow.ai.graph.io.*; import org.junit.jupiter.api.*; @@ -33,198 +32,202 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -@SolonTest(GeaFlowMemoryServer.class) -public class MemoryServerTest { - - private static final Logger LOGGER = LoggerFactory.getLogger(MemoryServerTest.class); - private static final String BASE_URL = "http://localhost:8080"; - private static final String GRAPH_NAME = "Confucius"; - private static OkHttpClient client; - - @BeforeEach - void setUp() { - LOGGER.info("Setting up test environment..."); - if (client == null) { - OkHttpClient.Builder builder = new OkHttpClient.Builder(); - builder.callTimeout(Constants.HTTP_CALL_TIMEOUT_SECONDS, TimeUnit.SECONDS); - builder.connectTimeout(Constants.HTTP_CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS); - builder.readTimeout(Constants.HTTP_READ_TIMEOUT_SECONDS, TimeUnit.SECONDS); - builder.writeTimeout(Constants.HTTP_WRITE_TIMEOUT_SECONDS, TimeUnit.SECONDS); - client = builder.build(); - } - LOGGER.info("Test HTTP client initialized, base URL: {}", BASE_URL); - } - - @AfterEach - void tearDown() { - LOGGER.info("Cleaning up test environment..."); - } +import com.google.gson.Gson; - private String get(String useApi) { - String url = BASE_URL + useApi; - Request request = new Request.Builder().url(url).get().build(); - try (okhttp3.Response response = client.newCall(request).execute()) { - return response.body().string(); - } catch (IOException e) { - throw new RuntimeException(e); - } - } +import okhttp3.*; - private String post(String useApi, String bodyJson) { - return post(useApi, bodyJson, Collections.emptyMap()); - } +@SolonTest(GeaFlowMemoryServer.class) +public class MemoryServerTest { - private String post(String useApi, String bodyJson, Map queryParams) { - RequestBody requestBody = RequestBody.create( - MediaType.parse("application/json; charset=utf-8"), bodyJson); - String url = BASE_URL + useApi; - HttpUrl.Builder urlBuilder = HttpUrl.parse(url).newBuilder(); - if (queryParams != null) { - for (Map.Entry entry : queryParams.entrySet()) { - urlBuilder.addQueryParameter(entry.getKey(), entry.getValue()); - } - } - Request request = new Request.Builder().url(urlBuilder.build()).post(requestBody).build(); - try (okhttp3.Response response = client.newCall(request).execute()) { - return response.body().string(); - } catch (IOException e) { - throw new RuntimeException(e); - } + private static final Logger LOGGER = LoggerFactory.getLogger(MemoryServerTest.class); + private static final String BASE_URL = "http://localhost:8080"; + private static final String GRAPH_NAME = "Confucius"; + private static OkHttpClient client; + + @BeforeEach + void setUp() { + LOGGER.info("Setting up test environment..."); + if (client == null) { + OkHttpClient.Builder builder = new OkHttpClient.Builder(); + builder.callTimeout(Constants.HTTP_CALL_TIMEOUT_SECONDS, TimeUnit.SECONDS); + builder.connectTimeout(Constants.HTTP_CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS); + builder.readTimeout(Constants.HTTP_READ_TIMEOUT_SECONDS, TimeUnit.SECONDS); + builder.writeTimeout(Constants.HTTP_WRITE_TIMEOUT_SECONDS, TimeUnit.SECONDS); + client = builder.build(); } - - @Test - void testMain() throws Exception { - testServerHealth(); - testCreateGraph(); - testAddEntities(); - testQueries(); + LOGGER.info("Test HTTP client initialized, base URL: {}", BASE_URL); + } + + @AfterEach + void tearDown() { + LOGGER.info("Cleaning up test environment..."); + } + + private String get(String useApi) { + String url = BASE_URL + useApi; + Request request = new Request.Builder().url(url).get().build(); + try (okhttp3.Response response = client.newCall(request).execute()) { + return response.body().string(); + } catch (IOException e) { + throw new RuntimeException(e); } - - void testServerHealth() throws Exception { - LOGGER.info("Testing server health endpoint..."); - String api = "/"; - String response = get(api); - LOGGER.info("API: {} Response: {}", api, response); - api = "/health"; - response = get(api); - LOGGER.info("API: {} Response: {}", api, response); - api = "/api/test"; - response = get(api); - LOGGER.info("API: {} Response: {}", api, response); + } + + private String post(String useApi, String bodyJson) { + return post(useApi, bodyJson, Collections.emptyMap()); + } + + private String post(String useApi, String bodyJson, Map queryParams) { + RequestBody requestBody = + RequestBody.create(MediaType.parse("application/json; charset=utf-8"), bodyJson); + String url = BASE_URL + useApi; + HttpUrl.Builder urlBuilder = HttpUrl.parse(url).newBuilder(); + if (queryParams != null) { + for (Map.Entry entry : queryParams.entrySet()) { + urlBuilder.addQueryParameter(entry.getKey(), entry.getValue()); + } } - - void testCreateGraph() throws Exception { - LOGGER.info("Testing server create graph..."); - Gson gson = new Gson(); - String api = "/graph/create"; - GraphSchema testGraph = new GraphSchema(); - String graphName = GRAPH_NAME; - testGraph.setName(graphName); - String response = post(api, gson.toJson(testGraph)); - LOGGER.info("API: {} Response: {}", api, response); - - api = "/graph/getGraphSchema"; - Map queryParams = new HashMap<>(); - queryParams.put("graphName", graphName); - response = post(api, "", queryParams); - LOGGER.info("API: {} Response: {}", api, response); - Assertions.assertEquals(gson.toJson(testGraph), response); - - VertexSchema vertexSchema = new VertexSchema("chunk", "id", - Collections.singletonList("text")); - EdgeSchema edgeSchema = new EdgeSchema("relation", "srcId", "dstId", - Collections.singletonList("rel")); - testGraph.addVertex(vertexSchema); - testGraph.addEdge(edgeSchema); - - api = "/graph/addEntitySchema"; - queryParams = new HashMap<>(); - queryParams.put("graphName", graphName); - response = post(api, gson.toJson(vertexSchema), queryParams); - LOGGER.info("API: {} Response: {}", api, response); - - api = "/graph/addEntitySchema"; - queryParams = new HashMap<>(); - queryParams.put("graphName", graphName); - response = post(api, gson.toJson(edgeSchema), queryParams); - LOGGER.info("API: {} Response: {}", api, response); - - api = "/graph/getGraphSchema"; - queryParams = new HashMap<>(); - queryParams.put("graphName", graphName); - response = post(api, "", queryParams); - LOGGER.info("API: {} Response: {}", api, response); - Assertions.assertEquals(gson.toJson(testGraph), response); + Request request = new Request.Builder().url(urlBuilder.build()).post(requestBody).build(); + try (okhttp3.Response response = client.newCall(request).execute()) { + return response.body().string(); + } catch (IOException e) { + throw new RuntimeException(e); } - - void testAddEntities() throws Exception { - LOGGER.info("Testing server add entities..."); - Gson gson = new Gson(); - String graphName = GRAPH_NAME; - - String api = "/graph/getGraphSchema"; - Map queryParams = new HashMap<>(); - queryParams.put("graphName", graphName); - String response = post(api, "", queryParams); - LOGGER.info("API: {} Response: {}", api, response); - - TextFileReader textFileReader = new TextFileReader(10000); - textFileReader.readFile("text/Confucius"); - List chunks = IntStream.range(0, textFileReader.getRowCount()) + } + + @Test + void testMain() throws Exception { + testServerHealth(); + testCreateGraph(); + testAddEntities(); + testQueries(); + } + + void testServerHealth() throws Exception { + LOGGER.info("Testing server health endpoint..."); + String api = "/"; + String response = get(api); + LOGGER.info("API: {} Response: {}", api, response); + api = "/health"; + response = get(api); + LOGGER.info("API: {} Response: {}", api, response); + api = "/api/test"; + response = get(api); + LOGGER.info("API: {} Response: {}", api, response); + } + + void testCreateGraph() throws Exception { + LOGGER.info("Testing server create graph..."); + Gson gson = new Gson(); + String api = "/graph/create"; + GraphSchema testGraph = new GraphSchema(); + String graphName = GRAPH_NAME; + testGraph.setName(graphName); + String response = post(api, gson.toJson(testGraph)); + LOGGER.info("API: {} Response: {}", api, response); + + api = "/graph/getGraphSchema"; + Map queryParams = new HashMap<>(); + queryParams.put("graphName", graphName); + response = post(api, "", queryParams); + LOGGER.info("API: {} Response: {}", api, response); + Assertions.assertEquals(gson.toJson(testGraph), response); + + VertexSchema vertexSchema = new VertexSchema("chunk", "id", Collections.singletonList("text")); + EdgeSchema edgeSchema = + new EdgeSchema("relation", "srcId", "dstId", Collections.singletonList("rel")); + testGraph.addVertex(vertexSchema); + testGraph.addEdge(edgeSchema); + + api = "/graph/addEntitySchema"; + queryParams = new HashMap<>(); + queryParams.put("graphName", graphName); + response = post(api, gson.toJson(vertexSchema), queryParams); + LOGGER.info("API: {} Response: {}", api, response); + + api = "/graph/addEntitySchema"; + queryParams = new HashMap<>(); + queryParams.put("graphName", graphName); + response = post(api, gson.toJson(edgeSchema), queryParams); + LOGGER.info("API: {} Response: {}", api, response); + + api = "/graph/getGraphSchema"; + queryParams = new HashMap<>(); + queryParams.put("graphName", graphName); + response = post(api, "", queryParams); + LOGGER.info("API: {} Response: {}", api, response); + Assertions.assertEquals(gson.toJson(testGraph), response); + } + + void testAddEntities() throws Exception { + LOGGER.info("Testing server add entities..."); + Gson gson = new Gson(); + String graphName = GRAPH_NAME; + + String api = "/graph/getGraphSchema"; + Map queryParams = new HashMap<>(); + queryParams.put("graphName", graphName); + String response = post(api, "", queryParams); + LOGGER.info("API: {} Response: {}", api, response); + + TextFileReader textFileReader = new TextFileReader(10000); + textFileReader.readFile("text/Confucius"); + List chunks = + IntStream.range(0, textFileReader.getRowCount()) .mapToObj(textFileReader::getRow) - .map(String::trim).collect(Collectors.toList()); - for(String chunk : chunks) { - String vid = UUID.randomUUID().toString().replace("-", ""); - Vertex chunkVertex = new Vertex("chunk", vid, Collections.singletonList(chunk)); - api = "/graph/insertEntity"; - queryParams = new HashMap<>(); - queryParams.put("graphName", graphName); - response = post(api, gson.toJson(chunkVertex), queryParams); - LOGGER.info("API: {} Response: {}", api, response); - } + .map(String::trim) + .collect(Collectors.toList()); + for (String chunk : chunks) { + String vid = UUID.randomUUID().toString().replace("-", ""); + Vertex chunkVertex = new Vertex("chunk", vid, Collections.singletonList(chunk)); + api = "/graph/insertEntity"; + queryParams = new HashMap<>(); + queryParams.put("graphName", graphName); + response = post(api, gson.toJson(chunkVertex), queryParams); + LOGGER.info("API: {} Response: {}", api, response); } - - void testQueries() throws Exception { - LOGGER.info("Testing server queries..."); - String graphName = GRAPH_NAME; - String sessionId = null; - String api = "/query/context"; - Map queryParams = new HashMap<>(); - queryParams.put("graphName", graphName); - String response = post(api, "", queryParams); - LOGGER.info("API: {} Response: {}", api, response); - Assertions.assertNotNull(response); - sessionId = response; - - api = "/query/exec"; - queryParams = new HashMap<>(); - queryParams.put("sessionId", sessionId); - queryParams.put("query", "Who is Confucius?"); - response = post(api, "", queryParams); - LOGGER.info("API: {} Response: {}", api, response); - Assertions.assertNotNull(response); - - api = "/query/result"; - queryParams = new HashMap<>(); - queryParams.put("sessionId", sessionId); - response = post(api, "", queryParams); - LOGGER.info("API: {} Response: {}", api, response); - Assertions.assertNotNull(response); - - api = "/query/exec"; - queryParams = new HashMap<>(); - queryParams.put("sessionId", sessionId); - queryParams.put("query", "What did he say?"); - response = post(api, "", queryParams); - LOGGER.info("API: {} Response: {}", api, response); - Assertions.assertNotNull(response); - - api = "/query/result"; - queryParams = new HashMap<>(); - queryParams.put("sessionId", sessionId); - response = post(api, "", queryParams); - LOGGER.info("API: {} Response: {}", api, response); - Assertions.assertNotNull(response); - } - + } + + void testQueries() throws Exception { + LOGGER.info("Testing server queries..."); + String graphName = GRAPH_NAME; + String sessionId = null; + String api = "/query/context"; + Map queryParams = new HashMap<>(); + queryParams.put("graphName", graphName); + String response = post(api, "", queryParams); + LOGGER.info("API: {} Response: {}", api, response); + Assertions.assertNotNull(response); + sessionId = response; + + api = "/query/exec"; + queryParams = new HashMap<>(); + queryParams.put("sessionId", sessionId); + queryParams.put("query", "Who is Confucius?"); + response = post(api, "", queryParams); + LOGGER.info("API: {} Response: {}", api, response); + Assertions.assertNotNull(response); + + api = "/query/result"; + queryParams = new HashMap<>(); + queryParams.put("sessionId", sessionId); + response = post(api, "", queryParams); + LOGGER.info("API: {} Response: {}", api, response); + Assertions.assertNotNull(response); + + api = "/query/exec"; + queryParams = new HashMap<>(); + queryParams.put("sessionId", sessionId); + queryParams.put("query", "What did he say?"); + response = post(api, "", queryParams); + LOGGER.info("API: {} Response: {}", api, response); + Assertions.assertNotNull(response); + + api = "/query/result"; + queryParams = new HashMap<>(); + queryParams.put("sessionId", sessionId); + response = post(api, "", queryParams); + LOGGER.info("API: {} Response: {}", api, response); + Assertions.assertNotNull(response); + } } diff --git a/geaflow-ai/src/test/java/org/apache/geaflow/ai/MockChatRobot.java b/geaflow-ai/src/test/java/org/apache/geaflow/ai/MockChatRobot.java index 2b1248247..fbf40ee1c 100644 --- a/geaflow-ai/src/test/java/org/apache/geaflow/ai/MockChatRobot.java +++ b/geaflow-ai/src/test/java/org/apache/geaflow/ai/MockChatRobot.java @@ -24,26 +24,26 @@ public class MockChatRobot { - private ModelConfig modelInfo; + private ModelConfig modelInfo; - public MockChatRobot() { - this.modelInfo = new ModelConfig(); - } + public MockChatRobot() { + this.modelInfo = new ModelConfig(); + } - public MockChatRobot(String model) { - this.modelInfo = new ModelConfig(); - this.modelInfo.setModel(model); - } + public MockChatRobot(String model) { + this.modelInfo = new ModelConfig(); + this.modelInfo.setModel(model); + } - public EmbeddingService.EmbeddingResult embeddingSingle(String input) { - return new EmbeddingService.EmbeddingResult(input, new double[0]); - } + public EmbeddingService.EmbeddingResult embeddingSingle(String input) { + return new EmbeddingService.EmbeddingResult(input, new double[0]); + } - public ModelConfig getModelInfo() { - return modelInfo; - } + public ModelConfig getModelInfo() { + return modelInfo; + } - public void setModelInfo(ModelConfig modelInfo) { - this.modelInfo = modelInfo; - } + public void setModelInfo(ModelConfig modelInfo) { + this.modelInfo = modelInfo; + } } diff --git a/geaflow-ai/src/test/java/org/apache/geaflow/ai/MutableGraphTest.java b/geaflow-ai/src/test/java/org/apache/geaflow/ai/MutableGraphTest.java index 1a92d7849..1206a2049 100644 --- a/geaflow-ai/src/test/java/org/apache/geaflow/ai/MutableGraphTest.java +++ b/geaflow-ai/src/test/java/org/apache/geaflow/ai/MutableGraphTest.java @@ -23,6 +23,7 @@ import java.util.*; import java.util.stream.Collectors; import java.util.stream.IntStream; + import org.apache.geaflow.ai.common.config.Constants; import org.apache.geaflow.ai.consolidate.ConsolidateServer; import org.apache.geaflow.ai.graph.GraphEntity; @@ -42,195 +43,218 @@ public class MutableGraphTest { - private static final Logger LOGGER = LoggerFactory.getLogger(MutableGraphTest.class); - - @Test - public void testMutableGraph() { - GraphSchema graphSchema = new GraphSchema(); - VertexSchema vertexSchema = new VertexSchema("chunk", "id", - Collections.singletonList("text")); - EdgeSchema edgeSchema = new EdgeSchema("relation", "srcId", "dstId", - Collections.singletonList("rel")); - graphSchema.addVertex(vertexSchema); - graphSchema.addEdge(edgeSchema); - Map entities = new HashMap<>(); - entities.put(vertexSchema.getName(), new VertexGroup(vertexSchema, new ArrayList<>())); - entities.put(edgeSchema.getName(), new EdgeGroup(edgeSchema, new ArrayList<>())); - MemoryGraph graph = new MemoryGraph(graphSchema, entities); - - LocalMemoryGraphAccessor graphAccessor = new LocalMemoryGraphAccessor(graph); - MemoryMutableGraph memoryMutableGraph = new MemoryMutableGraph(graph); - LOGGER.info("Success to init empty graph."); - - EntityAttributeIndexStore indexStore = new EntityAttributeIndexStore(); - indexStore.initStore(new SubgraphSemanticPromptFunction(graphAccessor)); - LOGGER.info("Success to init EntityAttributeIndexStore."); - - GraphMemoryServer server = new GraphMemoryServer(); - server.addGraphAccessor(graphAccessor); - server.addIndexStore(indexStore); - LOGGER.info("Success to init GraphMemoryServer."); - - memoryMutableGraph.addVertex(new Vertex(vertexSchema.getName(), - "apple", Collections.singletonList("apple is a kind of fruit."))); - memoryMutableGraph.addVertex(new Vertex(vertexSchema.getName(), - "banana", Collections.singletonList("banana is a kind of fruit."))); - memoryMutableGraph.addVertex(new Vertex(vertexSchema.getName(), - "grape", Collections.singletonList("grape is a kind of fruit."))); - - String query = "How about apple?"; - String result = searchInGraph(server, graphAccessor, query, 1); - LOGGER.info("query: {} result: {}", query, result); - Assertions.assertTrue(result.contains("apple")); - - memoryMutableGraph.addVertex(new Vertex(vertexSchema.getName(), - "red", Collections.singletonList("red is a kind of color."))); - memoryMutableGraph.addVertex(new Vertex(vertexSchema.getName(), - "yellow", Collections.singletonList("yellow is a kind of color."))); - memoryMutableGraph.addVertex(new Vertex(vertexSchema.getName(), - "purple", Collections.singletonList("purple is a kind of color."))); - memoryMutableGraph.addEdge(new Edge(edgeSchema.getName(), - "apple", "red", Collections.singletonList("apple is red."))); - memoryMutableGraph.addEdge(new Edge(edgeSchema.getName(), - "banana", "yellow", Collections.singletonList("apple is yellow."))); - memoryMutableGraph.addEdge(new Edge(edgeSchema.getName(), - "grape", "purple", Collections.singletonList("apple is purple."))); - - query = "What color is apple?"; - result = searchInGraph(server, graphAccessor, query, 3); - LOGGER.info("query: {} result: {}", query, result); - Assertions.assertTrue(result.contains("apple is red")); - - memoryMutableGraph.updateVertex(new Vertex(vertexSchema.getName(), - "red", Collections.singletonList("red is not a kind of fruit."))); - - query = "How about red?"; - result = searchInGraph(server, graphAccessor, query, 1); - LOGGER.info("query: {} result: {}", query, result); - Assertions.assertTrue(result.contains("red is not a kind of fruit.")); - - memoryMutableGraph.removeVertex(vertexSchema.getName(), "yellow"); - query = "How about yellow?"; - result = searchInGraph(server, graphAccessor, query, 1); - LOGGER.info("query: {} result: {}", query, result); - Assertions.assertFalse(result.contains("yellow is a kind of color.")); - - memoryMutableGraph.removeEdge(new Edge(edgeSchema.getName(), - "apple", "red", Collections.singletonList("apple is red."))); - - query = "What color is apple?"; - result = searchInGraph(server, graphAccessor, query, 3); - LOGGER.info("query: {} result: {}", query, result); - Assertions.assertFalse(result.contains("apple is red.")); - Assertions.assertTrue(result.contains("apple is a kind of fruit.")); - Assertions.assertTrue(result.contains("red is not a kind of fruit.")); + private static final Logger LOGGER = LoggerFactory.getLogger(MutableGraphTest.class); - } + @Test + public void testMutableGraph() { + GraphSchema graphSchema = new GraphSchema(); + VertexSchema vertexSchema = new VertexSchema("chunk", "id", Collections.singletonList("text")); + EdgeSchema edgeSchema = + new EdgeSchema("relation", "srcId", "dstId", Collections.singletonList("rel")); + graphSchema.addVertex(vertexSchema); + graphSchema.addEdge(edgeSchema); + Map entities = new HashMap<>(); + entities.put(vertexSchema.getName(), new VertexGroup(vertexSchema, new ArrayList<>())); + entities.put(edgeSchema.getName(), new EdgeGroup(edgeSchema, new ArrayList<>())); + MemoryGraph graph = new MemoryGraph(graphSchema, entities); - private String searchInGraph(GraphMemoryServer server, - LocalMemoryGraphAccessor graphAccessor, - String query, int times) { - String sessionId = server.createSession(); - Context context = null; - for (int i = 0; i < times; i++) { - VectorSearch search = new VectorSearch(null, sessionId); - search.addVector(new KeywordVector(query)); - String searchResult = server.search(search); - Assertions.assertNotNull(searchResult); - context = server.verbalize(sessionId, - new SubgraphSemanticPromptFunction(graphAccessor)); - } - assert context != null; - return context.toString(); + LocalMemoryGraphAccessor graphAccessor = new LocalMemoryGraphAccessor(graph); + MemoryMutableGraph memoryMutableGraph = new MemoryMutableGraph(graph); + LOGGER.info("Success to init empty graph."); + + EntityAttributeIndexStore indexStore = new EntityAttributeIndexStore(); + indexStore.initStore(new SubgraphSemanticPromptFunction(graphAccessor)); + LOGGER.info("Success to init EntityAttributeIndexStore."); + + GraphMemoryServer server = new GraphMemoryServer(); + server.addGraphAccessor(graphAccessor); + server.addIndexStore(indexStore); + LOGGER.info("Success to init GraphMemoryServer."); + + memoryMutableGraph.addVertex( + new Vertex( + vertexSchema.getName(), + "apple", + Collections.singletonList("apple is a kind of fruit."))); + memoryMutableGraph.addVertex( + new Vertex( + vertexSchema.getName(), + "banana", + Collections.singletonList("banana is a kind of fruit."))); + memoryMutableGraph.addVertex( + new Vertex( + vertexSchema.getName(), + "grape", + Collections.singletonList("grape is a kind of fruit."))); + + String query = "How about apple?"; + String result = searchInGraph(server, graphAccessor, query, 1); + LOGGER.info("query: {} result: {}", query, result); + Assertions.assertTrue(result.contains("apple")); + + memoryMutableGraph.addVertex( + new Vertex( + vertexSchema.getName(), "red", Collections.singletonList("red is a kind of color."))); + memoryMutableGraph.addVertex( + new Vertex( + vertexSchema.getName(), + "yellow", + Collections.singletonList("yellow is a kind of color."))); + memoryMutableGraph.addVertex( + new Vertex( + vertexSchema.getName(), + "purple", + Collections.singletonList("purple is a kind of color."))); + memoryMutableGraph.addEdge( + new Edge(edgeSchema.getName(), "apple", "red", Collections.singletonList("apple is red."))); + memoryMutableGraph.addEdge( + new Edge( + edgeSchema.getName(), + "banana", + "yellow", + Collections.singletonList("apple is yellow."))); + memoryMutableGraph.addEdge( + new Edge( + edgeSchema.getName(), + "grape", + "purple", + Collections.singletonList("apple is purple."))); + + query = "What color is apple?"; + result = searchInGraph(server, graphAccessor, query, 3); + LOGGER.info("query: {} result: {}", query, result); + Assertions.assertTrue(result.contains("apple is red")); + + memoryMutableGraph.updateVertex( + new Vertex( + vertexSchema.getName(), + "red", + Collections.singletonList("red is not a kind of fruit."))); + + query = "How about red?"; + result = searchInGraph(server, graphAccessor, query, 1); + LOGGER.info("query: {} result: {}", query, result); + Assertions.assertTrue(result.contains("red is not a kind of fruit.")); + + memoryMutableGraph.removeVertex(vertexSchema.getName(), "yellow"); + query = "How about yellow?"; + result = searchInGraph(server, graphAccessor, query, 1); + LOGGER.info("query: {} result: {}", query, result); + Assertions.assertFalse(result.contains("yellow is a kind of color.")); + + memoryMutableGraph.removeEdge( + new Edge(edgeSchema.getName(), "apple", "red", Collections.singletonList("apple is red."))); + + query = "What color is apple?"; + result = searchInGraph(server, graphAccessor, query, 3); + LOGGER.info("query: {} result: {}", query, result); + Assertions.assertFalse(result.contains("apple is red.")); + Assertions.assertTrue(result.contains("apple is a kind of fruit.")); + Assertions.assertTrue(result.contains("red is not a kind of fruit.")); + } + + private String searchInGraph( + GraphMemoryServer server, LocalMemoryGraphAccessor graphAccessor, String query, int times) { + String sessionId = server.createSession(); + Context context = null; + for (int i = 0; i < times; i++) { + VectorSearch search = new VectorSearch(null, sessionId); + search.addVector(new KeywordVector(query)); + String searchResult = server.search(search); + Assertions.assertNotNull(searchResult); + context = server.verbalize(sessionId, new SubgraphSemanticPromptFunction(graphAccessor)); } + assert context != null; + return context.toString(); + } + + @Test + public void testConsolidation() throws IOException { + GraphSchema graphSchema = new GraphSchema(); + VertexSchema vertexSchema = new VertexSchema("chunk", "id", Collections.singletonList("text")); + EdgeSchema edgeSchema = + new EdgeSchema("relation", "srcId", "dstId", Collections.singletonList("rel")); + graphSchema.addVertex(vertexSchema); + graphSchema.addEdge(edgeSchema); + Map entities = new HashMap<>(); + entities.put(vertexSchema.getName(), new VertexGroup(vertexSchema, new ArrayList<>())); + entities.put(edgeSchema.getName(), new EdgeGroup(edgeSchema, new ArrayList<>())); + MemoryGraph graph = new MemoryGraph(graphSchema, entities); + + LocalMemoryGraphAccessor graphAccessor = new LocalMemoryGraphAccessor(graph); + MemoryMutableGraph memoryMutableGraph = new MemoryMutableGraph(graph); + LOGGER.info("Success to init empty graph."); - @Test - public void testConsolidation() throws IOException { - GraphSchema graphSchema = new GraphSchema(); - VertexSchema vertexSchema = new VertexSchema("chunk", "id", - Collections.singletonList("text")); - EdgeSchema edgeSchema = new EdgeSchema("relation", "srcId", "dstId", - Collections.singletonList("rel")); - graphSchema.addVertex(vertexSchema); - graphSchema.addEdge(edgeSchema); - Map entities = new HashMap<>(); - entities.put(vertexSchema.getName(), new VertexGroup(vertexSchema, new ArrayList<>())); - entities.put(edgeSchema.getName(), new EdgeGroup(edgeSchema, new ArrayList<>())); - MemoryGraph graph = new MemoryGraph(graphSchema, entities); - - LocalMemoryGraphAccessor graphAccessor = new LocalMemoryGraphAccessor(graph); - MemoryMutableGraph memoryMutableGraph = new MemoryMutableGraph(graph); - LOGGER.info("Success to init empty graph."); - - EntityAttributeIndexStore indexStore = new EntityAttributeIndexStore(); - indexStore.initStore(new SubgraphSemanticPromptFunction(graphAccessor)); - LOGGER.info("Success to init EntityAttributeIndexStore."); - - GraphMemoryServer server = new GraphMemoryServer(); - server.addGraphAccessor(graphAccessor); - server.addIndexStore(indexStore); - LOGGER.info("Success to init GraphMemoryServer."); - - TextFileReader textFileReader = new TextFileReader(10000); - textFileReader.readFile("text/Confucius"); - List chunks = IntStream.range(0, textFileReader.getRowCount()) + EntityAttributeIndexStore indexStore = new EntityAttributeIndexStore(); + indexStore.initStore(new SubgraphSemanticPromptFunction(graphAccessor)); + LOGGER.info("Success to init EntityAttributeIndexStore."); + + GraphMemoryServer server = new GraphMemoryServer(); + server.addGraphAccessor(graphAccessor); + server.addIndexStore(indexStore); + LOGGER.info("Success to init GraphMemoryServer."); + + TextFileReader textFileReader = new TextFileReader(10000); + textFileReader.readFile("text/Confucius"); + List chunks = + IntStream.range(0, textFileReader.getRowCount()) .mapToObj(textFileReader::getRow) - .map(String::trim).collect(Collectors.toList()); - for(String chunk : chunks) { - String vid = UUID.randomUUID().toString().replace("-", ""); - memoryMutableGraph.addVertex(new Vertex(vertexSchema.getName(), - vid, Collections.singletonList(chunk))); - } + .map(String::trim) + .collect(Collectors.toList()); + for (String chunk : chunks) { + String vid = UUID.randomUUID().toString().replace("-", ""); + memoryMutableGraph.addVertex( + new Vertex(vertexSchema.getName(), vid, Collections.singletonList(chunk))); + } - GraphVertex vertexForTest = graphAccessor.scanVertex().next(); - String testId = vertexForTest.getVertex().getId(); - String query = vertexForTest.getVertex().getValues().toString(); - List relatedResult = searchGraphEntities(server, query, 3); - Assertions.assertFalse(relatedResult.isEmpty()); - LOGGER.info("query: {} result: {}", query, relatedResult); - - for (GraphEntity relatedEntity : relatedResult) { - if (relatedEntity instanceof GraphVertex) { - String relatedId = ((GraphVertex) relatedEntity).getVertex().getId(); - Assertions.assertTrue(graphAccessor.getEdge( - Constants.CONSOLIDATE_KEYWORD_RELATION_LABEL, testId, relatedId - ).isEmpty()); - } - } + GraphVertex vertexForTest = graphAccessor.scanVertex().next(); + String testId = vertexForTest.getVertex().getId(); + String query = vertexForTest.getVertex().getValues().toString(); + List relatedResult = searchGraphEntities(server, query, 3); + Assertions.assertFalse(relatedResult.isEmpty()); + LOGGER.info("query: {} result: {}", query, relatedResult); + + for (GraphEntity relatedEntity : relatedResult) { + if (relatedEntity instanceof GraphVertex) { + String relatedId = ((GraphVertex) relatedEntity).getVertex().getId(); + Assertions.assertTrue( + graphAccessor + .getEdge(Constants.CONSOLIDATE_KEYWORD_RELATION_LABEL, testId, relatedId) + .isEmpty()); + } + } + + ConsolidateServer consolidate = new ConsolidateServer(); + int taskId = consolidate.executeConsolidateTask(graphAccessor, memoryMutableGraph); + LOGGER.info("Success to run consolidation task, taskId: {}.", taskId); + relatedResult = searchGraphEntities(server, query, 3); + Assertions.assertFalse(relatedResult.isEmpty()); + LOGGER.info("query: {} result: {}", query, relatedResult); - ConsolidateServer consolidate = new ConsolidateServer(); - int taskId = consolidate.executeConsolidateTask( - graphAccessor, memoryMutableGraph); - LOGGER.info("Success to run consolidation task, taskId: {}.", taskId); - - relatedResult = searchGraphEntities(server, query, 3); - Assertions.assertFalse(relatedResult.isEmpty()); - LOGGER.info("query: {} result: {}", query, relatedResult); - - //Test for at least one related entity in result - int existNum = 0; - for (GraphEntity relatedEntity : relatedResult) { - if (relatedEntity instanceof GraphVertex) { - String relatedId = ((GraphVertex) relatedEntity).getVertex().getId(); - if(!graphAccessor.getEdge(Constants.CONSOLIDATE_KEYWORD_RELATION_LABEL, - testId, relatedId).isEmpty()) { - existNum++; - } - } + // Test for at least one related entity in result + int existNum = 0; + for (GraphEntity relatedEntity : relatedResult) { + if (relatedEntity instanceof GraphVertex) { + String relatedId = ((GraphVertex) relatedEntity).getVertex().getId(); + if (!graphAccessor + .getEdge(Constants.CONSOLIDATE_KEYWORD_RELATION_LABEL, testId, relatedId) + .isEmpty()) { + existNum++; } - LOGGER.info("relatedResult size: {} found size: {}", relatedResult.size(), existNum); - Assertions.assertTrue(existNum > 0); + } } + LOGGER.info("relatedResult size: {} found size: {}", relatedResult.size(), existNum); + Assertions.assertTrue(existNum > 0); + } - private List searchGraphEntities(GraphMemoryServer server, - String query, int times) { - String sessionId = server.createSession(); - for (int i = 0; i < times; i++) { - VectorSearch search = new VectorSearch(null, sessionId); - search.addVector(new KeywordVector(query)); - sessionId = server.search(search); - } - return server.getSessionEntities(sessionId); + private List searchGraphEntities(GraphMemoryServer server, String query, int times) { + String sessionId = server.createSession(); + for (int i = 0; i < times; i++) { + VectorSearch search = new VectorSearch(null, sessionId); + search.addVector(new KeywordVector(query)); + sessionId = server.search(search); } + return server.getSessionEntities(sessionId); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/AuditManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/AuditManager.java index 937568663..12be2286e 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/AuditManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/AuditManager.java @@ -22,6 +22,4 @@ import org.apache.geaflow.console.biz.shared.view.AuditView; import org.apache.geaflow.console.common.dal.model.AuditSearch; -public interface AuditManager extends IdManager { - -} +public interface AuditManager extends IdManager {} diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/AuthenticationManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/AuthenticationManager.java index b065d047b..693a088aa 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/AuthenticationManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/AuthenticationManager.java @@ -24,13 +24,13 @@ public interface AuthenticationManager { - AuthenticationView login(String loginName, String password, boolean systemAdmin); + AuthenticationView login(String loginName, String password, boolean systemAdmin); - AuthenticationView authenticate(String token); + AuthenticationView authenticate(String token); - SessionView currentSession(); + SessionView currentSession(); - boolean switchSession(); + boolean switchSession(); - boolean logout(); + boolean logout(); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/AuthorizationManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/AuthorizationManager.java index e2ea5fd34..19a065823 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/AuthorizationManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/AuthorizationManager.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.biz.shared; import java.util.List; + import org.apache.geaflow.console.biz.shared.view.AuthorizationView; import org.apache.geaflow.console.common.dal.model.AuthorizationSearch; import org.apache.geaflow.console.common.util.exception.GeaflowSecurityException; @@ -30,9 +31,10 @@ public interface AuthorizationManager extends IdManager { - List getUserRoleTypes(String userId); + List getUserRoleTypes(String userId); - void hasRole(GeaflowRole... roles) throws GeaflowSecurityException; + void hasRole(GeaflowRole... roles) throws GeaflowSecurityException; - void hasAuthority(GeaflowAuthority authority, GeaflowResource resource) throws GeaflowSecurityException; + void hasAuthority(GeaflowAuthority authority, GeaflowResource resource) + throws GeaflowSecurityException; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/ChatManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/ChatManager.java index 2c07e189f..b74fb108d 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/ChatManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/ChatManager.java @@ -19,15 +19,14 @@ package org.apache.geaflow.console.biz.shared; - import org.apache.geaflow.console.biz.shared.view.ChatView; import org.apache.geaflow.console.common.dal.model.ChatSearch; public interface ChatManager extends IdManager { - String callSync(ChatView chatView, boolean record, boolean withSchema); + String callSync(ChatView chatView, boolean record, boolean withSchema); - String callASync(ChatView chatView, boolean withSchema); + String callASync(ChatView chatView, boolean withSchema); - boolean dropByJobId(String jobId); + boolean dropByJobId(String jobId); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/ClusterManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/ClusterManager.java index 2a4b922ed..91879037e 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/ClusterManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/ClusterManager.java @@ -19,11 +19,7 @@ package org.apache.geaflow.console.biz.shared; - import org.apache.geaflow.console.biz.shared.view.ClusterView; import org.apache.geaflow.console.common.dal.model.ClusterSearch; -public interface ClusterManager extends NameManager { - - -} +public interface ClusterManager extends NameManager {} diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/ConfigManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/ConfigManager.java index e6483e52a..0e3b20fbe 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/ConfigManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/ConfigManager.java @@ -20,18 +20,19 @@ package org.apache.geaflow.console.biz.shared; import java.util.List; + import org.apache.geaflow.console.common.util.type.GeaflowPluginCategory; import org.apache.geaflow.console.core.model.config.ConfigDescItem; public interface ConfigManager { - List getClusterConfig(); + List getClusterConfig(); - List getJobConfig(); + List getJobConfig(); - List getPluginCategories(); + List getPluginCategories(); - List getPluginCategoryTypes(GeaflowPluginCategory category); + List getPluginCategoryTypes(GeaflowPluginCategory category); - List getPluginConfig(GeaflowPluginCategory category, String type); + List getPluginConfig(GeaflowPluginCategory category, String type); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/DataManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/DataManager.java index 8598a04d3..22af84b19 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/DataManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/DataManager.java @@ -20,29 +20,30 @@ package org.apache.geaflow.console.biz.shared; import java.util.List; + import org.apache.geaflow.console.biz.shared.view.IdView; import org.apache.geaflow.console.common.dal.model.IdSearch; import org.apache.geaflow.console.common.dal.model.PageList; public interface DataManager extends NameManager { - V getByName(String instanceName, String name); + V getByName(String instanceName, String name); - List getByNames(String instanceName, List names); + List getByNames(String instanceName, List names); - boolean dropByName(String instanceName, String name); + boolean dropByName(String instanceName, String name); - boolean dropByNames(String instanceName, List names); + boolean dropByNames(String instanceName, List names); - List create(String instanceName, List views); + List create(String instanceName, List views); - String create(String instanceName, V view); + String create(String instanceName, V view); - boolean update(String instanceName, List views); + boolean update(String instanceName, List views); - boolean updateByName(String instanceName, String name, V view); + boolean updateByName(String instanceName, String name, V view); - PageList searchByInstanceName(String instanceName, S search); + PageList searchByInstanceName(String instanceName, S search); - void createIfIdAbsent(String instanceName, List views); + void createIfIdAbsent(String instanceName, List views); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/EdgeManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/EdgeManager.java index 4ebcfd562..0983d7f67 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/EdgeManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/EdgeManager.java @@ -22,6 +22,4 @@ import org.apache.geaflow.console.biz.shared.view.EdgeView; import org.apache.geaflow.console.common.dal.model.EdgeSearch; -public interface EdgeManager extends DataManager { - -} +public interface EdgeManager extends DataManager {} diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/FunctionManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/FunctionManager.java index 1d67c789d..426267add 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/FunctionManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/FunctionManager.java @@ -25,9 +25,11 @@ public interface FunctionManager extends DataManager { - String createFunction(String instanceName, FunctionView view, MultipartFile functionFile, String fileId); + String createFunction( + String instanceName, FunctionView view, MultipartFile functionFile, String fileId); - boolean updateFunction(String instanceName, String functionName, FunctionView view, MultipartFile functionFile); + boolean updateFunction( + String instanceName, String functionName, FunctionView view, MultipartFile functionFile); - boolean deleteFunction(String instanceName, String functionName); + boolean deleteFunction(String instanceName, String functionName); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/GraphManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/GraphManager.java index a2f7db6b6..6b4b0db36 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/GraphManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/GraphManager.java @@ -19,17 +19,17 @@ package org.apache.geaflow.console.biz.shared; - import java.util.List; + import org.apache.geaflow.console.biz.shared.view.EndpointView; import org.apache.geaflow.console.biz.shared.view.GraphView; import org.apache.geaflow.console.common.dal.model.GraphSearch; public interface GraphManager extends DataManager { - boolean createEndpoints(String instanceName, String graphName, List endpoints); + boolean createEndpoints(String instanceName, String graphName, List endpoints); - boolean deleteEndpoints(String instanceName, String graphName, List endpoints); + boolean deleteEndpoints(String instanceName, String graphName, List endpoints); - boolean clean(String instanceName, String graphName); + boolean clean(String instanceName, String graphName); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/IdManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/IdManager.java index 05eec8b7a..2c13f607f 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/IdManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/IdManager.java @@ -20,27 +20,28 @@ package org.apache.geaflow.console.biz.shared; import java.util.List; + import org.apache.geaflow.console.biz.shared.view.IdView; import org.apache.geaflow.console.common.dal.model.IdSearch; import org.apache.geaflow.console.common.dal.model.PageList; public interface IdManager { - PageList search(S search); + PageList search(S search); - V get(String id); + V get(String id); - String create(V view); + String create(V view); - boolean updateById(String id, V view); + boolean updateById(String id, V view); - boolean drop(String id); + boolean drop(String id); - List get(List ids); + List get(List ids); - List create(List views); + List create(List views); - boolean update(List views); + boolean update(List views); - boolean drop(List ids); + boolean drop(List ids); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/InstallManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/InstallManager.java index 62b04f962..696f23a81 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/InstallManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/InstallManager.java @@ -23,7 +23,7 @@ public interface InstallManager { - InstallView get(); + InstallView get(); - boolean install(InstallView installView); + boolean install(InstallView installView); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/InstanceManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/InstanceManager.java index f0252f361..2fff69d6f 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/InstanceManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/InstanceManager.java @@ -20,11 +20,11 @@ package org.apache.geaflow.console.biz.shared; import java.util.List; + import org.apache.geaflow.console.biz.shared.view.InstanceView; import org.apache.geaflow.console.common.dal.model.InstanceSearch; public interface InstanceManager extends NameManager { - List search(); - + List search(); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/JobManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/JobManager.java index b25689827..8ed919feb 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/JobManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/JobManager.java @@ -19,15 +19,15 @@ package org.apache.geaflow.console.biz.shared; - import java.util.List; + import org.apache.geaflow.console.biz.shared.view.JobView; import org.apache.geaflow.console.common.dal.model.JobSearch; import org.springframework.web.multipart.MultipartFile; public interface JobManager extends IdManager { - String create(JobView jobView, MultipartFile functionFile, String fileId, List graphIds); + String create(JobView jobView, MultipartFile functionFile, String fileId, List graphIds); - boolean update(String jobId, JobView jobView, MultipartFile jarFile, String fileId); + boolean update(String jobId, JobView jobView, MultipartFile jarFile, String fileId); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/LLMManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/LLMManager.java index fbd960392..8a047c737 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/LLMManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/LLMManager.java @@ -19,13 +19,13 @@ package org.apache.geaflow.console.biz.shared; - import java.util.List; + import org.apache.geaflow.console.biz.shared.view.LLMView; import org.apache.geaflow.console.common.dal.model.LLMSearch; import org.apache.geaflow.console.common.util.type.GeaflowLLMType; public interface LLMManager extends NameManager { - List getLLMTypes(); + List getLLMTypes(); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/NameManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/NameManager.java index 9d33ddca5..3b5dd004a 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/NameManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/NameManager.java @@ -20,18 +20,19 @@ package org.apache.geaflow.console.biz.shared; import java.util.List; + import org.apache.geaflow.console.biz.shared.view.IdView; import org.apache.geaflow.console.common.dal.model.IdSearch; public interface NameManager extends IdManager { - V getByName(String name); + V getByName(String name); - List getByNames(List names); + List getByNames(List names); - boolean updateByName(String name, V view); + boolean updateByName(String name, V view); - boolean dropByName(String name); + boolean dropByName(String name); - boolean dropByNames(List names); + boolean dropByNames(List names); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/PluginConfigManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/PluginConfigManager.java index 4a263eb8e..99a245021 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/PluginConfigManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/PluginConfigManager.java @@ -20,11 +20,12 @@ package org.apache.geaflow.console.biz.shared; import java.util.List; + import org.apache.geaflow.console.biz.shared.view.PluginConfigView; import org.apache.geaflow.console.common.dal.model.PluginConfigSearch; import org.apache.geaflow.console.common.util.type.GeaflowPluginCategory; public interface PluginConfigManager extends NameManager { - List getPluginConfigs(GeaflowPluginCategory category, String type); + List getPluginConfigs(GeaflowPluginCategory category, String type); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/PluginManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/PluginManager.java index dd518e589..63f485a79 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/PluginManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/PluginManager.java @@ -25,8 +25,7 @@ public interface PluginManager extends NameManager { - String createPlugin(PluginView pluginView, MultipartFile jarPackage, String jarId); - - boolean updatePlugin(String pluginId, PluginView updateView, MultipartFile jarPackage); + String createPlugin(PluginView pluginView, MultipartFile jarPackage, String jarId); + boolean updatePlugin(String pluginId, PluginView updateView, MultipartFile jarPackage); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/ReleaseManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/ReleaseManager.java index 6c91ed5a9..e239ac8b5 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/ReleaseManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/ReleaseManager.java @@ -25,7 +25,7 @@ public interface ReleaseManager extends IdManager { - String publish(String jobId); + String publish(String jobId); - boolean updateRelease(String jobId, ReleaseUpdateView packageUpdate); + boolean updateRelease(String jobId, ReleaseUpdateView packageUpdate); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/RemoteFileManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/RemoteFileManager.java index 33c08cea5..d0655b223 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/RemoteFileManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/RemoteFileManager.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.biz.shared; import javax.servlet.http.HttpServletResponse; + import org.apache.geaflow.console.biz.shared.view.RemoteFileView; import org.apache.geaflow.console.common.dal.model.RemoteFileSearch; import org.apache.geaflow.console.common.util.type.GeaflowResourceType; @@ -27,14 +28,13 @@ public interface RemoteFileManager extends NameManager { - String create(RemoteFileView view, MultipartFile multipartFile); - - boolean upload(String remoteFileId, MultipartFile multipartFile); + String create(RemoteFileView view, MultipartFile multipartFile); - boolean download(String remoteFileId, HttpServletResponse response); + boolean upload(String remoteFileId, MultipartFile multipartFile); - boolean delete(String remoteFileId); + boolean download(String remoteFileId, HttpServletResponse response); - void deleteRefJar(String jarID, String refId, GeaflowResourceType resourceType); + boolean delete(String remoteFileId); + void deleteRefJar(String jarID, String refId, GeaflowResourceType resourceType); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/StatementManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/StatementManager.java index 9c50d6a9e..9bcf5630d 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/StatementManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/StatementManager.java @@ -19,11 +19,10 @@ package org.apache.geaflow.console.biz.shared; - import org.apache.geaflow.console.biz.shared.view.StatementView; import org.apache.geaflow.console.common.dal.model.StatementSearch; public interface StatementManager extends IdManager { - boolean dropByJobId(String jobId); + boolean dropByJobId(String jobId); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/SystemConfigManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/SystemConfigManager.java index 8fef89fb9..77120e5ed 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/SystemConfigManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/SystemConfigManager.java @@ -19,20 +19,18 @@ package org.apache.geaflow.console.biz.shared; - import org.apache.geaflow.console.biz.shared.view.SystemConfigView; import org.apache.geaflow.console.common.dal.model.SystemConfigSearch; public interface SystemConfigManager extends NameManager { - SystemConfigView getConfig(String tenantId, String key); - - String getValue(String key); + SystemConfigView getConfig(String tenantId, String key); - boolean createConfig(SystemConfigView view); + String getValue(String key); - boolean updateConfig(String key, SystemConfigView view); + boolean createConfig(SystemConfigView view); - boolean deleteConfig(String tenantId, String key); + boolean updateConfig(String key, SystemConfigView view); + boolean deleteConfig(String tenantId, String key); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/TableManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/TableManager.java index 5a9decc51..c772498cc 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/TableManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/TableManager.java @@ -22,6 +22,4 @@ import org.apache.geaflow.console.biz.shared.view.TableView; import org.apache.geaflow.console.common.dal.model.TableSearch; -public interface TableManager extends DataManager { - -} +public interface TableManager extends DataManager {} diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/TaskManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/TaskManager.java index fc8327f1a..cf9842894 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/TaskManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/TaskManager.java @@ -19,8 +19,8 @@ package org.apache.geaflow.console.biz.shared; - import javax.servlet.http.HttpServletResponse; + import org.apache.geaflow.console.biz.shared.view.TaskStartupNotifyView; import org.apache.geaflow.console.biz.shared.view.TaskView; import org.apache.geaflow.console.common.dal.model.PageList; @@ -38,29 +38,29 @@ public interface TaskManager extends IdManager { - TaskView getByJobId(String jobId); + TaskView getByJobId(String jobId); - void operate(String taskId, GeaflowOperationType action); + void operate(String taskId, GeaflowOperationType action); - GeaflowTaskStatus queryStatus(String taskId, Boolean refresh); + GeaflowTaskStatus queryStatus(String taskId, Boolean refresh); - PageList queryPipelines(String taskId); + PageList queryPipelines(String taskId); - PageList queryCycles(String taskId, String pipelineId); + PageList queryCycles(String taskId, String pipelineId); - PageList queryErrors(String taskId); + PageList queryErrors(String taskId); - PageList queryMetricMeta(String taskId); + PageList queryMetricMeta(String taskId); - PageList queryMetrics(String taskId, GeaflowMetricQueryRequest queryRequest); + PageList queryMetrics(String taskId, GeaflowMetricQueryRequest queryRequest); - PageList queryOffsets(String taskId); + PageList queryOffsets(String taskId); - GeaflowHeartbeatInfo queryHeartbeat(String taskId); + GeaflowHeartbeatInfo queryHeartbeat(String taskId); - void startupNotify(String taskId, TaskStartupNotifyView requestData); + void startupNotify(String taskId, TaskStartupNotifyView requestData); - void download(String taskId, String path, HttpServletResponse response); + void download(String taskId, String path, HttpServletResponse response); - String getLogs(String taskId); + String getLogs(String taskId); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/TenantManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/TenantManager.java index 2bbcc6523..f156afe1c 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/TenantManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/TenantManager.java @@ -21,12 +21,13 @@ import java.util.Collection; import java.util.Map; + import org.apache.geaflow.console.biz.shared.view.TenantView; import org.apache.geaflow.console.common.dal.model.TenantSearch; public interface TenantManager extends NameManager { - TenantView getActiveTenant(String userId); + TenantView getActiveTenant(String userId); - Map getTenantNames(Collection tenantIds); + Map getTenantNames(Collection tenantIds); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/UserManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/UserManager.java index 4b3efe6af..c2e55a68c 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/UserManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/UserManager.java @@ -21,20 +21,21 @@ import java.util.Collection; import java.util.Map; + import org.apache.geaflow.console.biz.shared.view.UserView; import org.apache.geaflow.console.common.dal.model.UserSearch; public interface UserManager extends NameManager { - String register(UserView view); + String register(UserView view); - UserView getUser(String userId); + UserView getUser(String userId); - String addUser(UserView view); + String addUser(UserView view); - boolean updateUser(String userId, UserView view); + boolean updateUser(String userId, UserView view); - boolean deleteUser(String userId); + boolean deleteUser(String userId); - Map getUserNames(Collection userIds); + Map getUserNames(Collection userIds); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/VersionManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/VersionManager.java index 1e531b95f..94659b2aa 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/VersionManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/VersionManager.java @@ -26,15 +26,17 @@ public interface VersionManager extends NameManager { - PageList searchVersions(VersionSearch search); + PageList searchVersions(VersionSearch search); - VersionView getVersion(String name); + VersionView getVersion(String name); - String createDefaultVersion(); + String createDefaultVersion(); - String createVersion(VersionView versionView, MultipartFile engineJarFile, MultipartFile langJarFile); + String createVersion( + VersionView versionView, MultipartFile engineJarFile, MultipartFile langJarFile); - boolean updateVersion(String name, VersionView versionView, MultipartFile engineJarFile, MultipartFile langJarFile); + boolean updateVersion( + String name, VersionView versionView, MultipartFile engineJarFile, MultipartFile langJarFile); - boolean deleteVersion(String versionName); + boolean deleteVersion(String versionName); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/VertexManager.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/VertexManager.java index 88991d288..6ecacf716 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/VertexManager.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/VertexManager.java @@ -20,13 +20,14 @@ package org.apache.geaflow.console.biz.shared; import java.util.List; + import org.apache.geaflow.console.biz.shared.view.VertexView; import org.apache.geaflow.console.common.dal.model.VertexSearch; import org.apache.geaflow.console.core.model.data.GeaflowVertex; public interface VertexManager extends DataManager { - List getVerticesByGraphId(String graphId); + List getVerticesByGraphId(String graphId); - List getVerticesByEdgeId(String edgeId); + List getVerticesByEdgeId(String edgeId); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/aspect/ViewConverterAspect.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/aspect/ViewConverterAspect.java index 6d71c3fb4..091b197d0 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/aspect/ViewConverterAspect.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/aspect/ViewConverterAspect.java @@ -29,20 +29,20 @@ @Component public class ViewConverterAspect { - @Around("execution(* org.apache.geaflow.console.biz.shared.convert.*Converter.convert(..))") - public Object handle(ProceedingJoinPoint joinPoint) throws Throwable { - Object[] args = joinPoint.getArgs(); - if (args[0] == null) { - return null; - } - - Object result = joinPoint.proceed(args); + @Around("execution(* org.apache.geaflow.console.biz.shared.convert.*Converter.convert(..))") + public Object handle(ProceedingJoinPoint joinPoint) throws Throwable { + Object[] args = joinPoint.getArgs(); + if (args[0] == null) { + return null; + } - if (result instanceof GeaflowId) { - GeaflowId geaflowId = (GeaflowId) result; - geaflowId.validate(); - } + Object result = joinPoint.proceed(args); - return result; + if (result instanceof GeaflowId) { + GeaflowId geaflowId = (GeaflowId) result; + geaflowId.validate(); } + + return result; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/AuditViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/AuditViewConverter.java index 6bd310262..1a485c73e 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/AuditViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/AuditViewConverter.java @@ -26,27 +26,27 @@ @Component public class AuditViewConverter extends IdViewConverter { - @Override - protected AuditView modelToView(GeaflowAudit model) { - AuditView view = super.modelToView(model); - view.setResourceType(model.getResourceType()); - view.setResourceId(model.getResourceId()); - view.setOperationType(model.getOperationType()); - view.setDetail(model.getDetail()); - return view; - } + @Override + protected AuditView modelToView(GeaflowAudit model) { + AuditView view = super.modelToView(model); + view.setResourceType(model.getResourceType()); + view.setResourceId(model.getResourceId()); + view.setOperationType(model.getOperationType()); + view.setDetail(model.getDetail()); + return view; + } - @Override - protected GeaflowAudit viewToModel(AuditView view) { - GeaflowAudit model = super.viewToModel(view); - model.setOperationType(view.getOperationType()); - model.setResourceId(view.getResourceId()); - model.setResourceType(view.getResourceType()); - model.setDetail(view.getDetail()); - return model; - } + @Override + protected GeaflowAudit viewToModel(AuditView view) { + GeaflowAudit model = super.viewToModel(view); + model.setOperationType(view.getOperationType()); + model.setResourceId(view.getResourceId()); + model.setResourceType(view.getResourceType()); + model.setDetail(view.getDetail()); + return model; + } - public GeaflowAudit convert(AuditView view) { - return viewToModel(view); - } + public GeaflowAudit convert(AuditView view) { + return viewToModel(view); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/AuthenticationViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/AuthenticationViewConverter.java index fe5a4c3a0..d4fe716f5 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/AuthenticationViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/AuthenticationViewConverter.java @@ -27,12 +27,12 @@ @Component public class AuthenticationViewConverter { - public AuthenticationView convert(GeaflowAuthentication model) { - AuthenticationView view = new AuthenticationView(); - view.setUserId(model.getUserId()); - view.setSessionToken(model.getSessionToken()); - view.setSystemSession(model.isSystemSession()); - view.setAccessTime(DateTimeUtil.format(model.getAccessTime())); - return view; - } + public AuthenticationView convert(GeaflowAuthentication model) { + AuthenticationView view = new AuthenticationView(); + view.setUserId(model.getUserId()); + view.setSessionToken(model.getSessionToken()); + view.setSystemSession(model.isSystemSession()); + view.setAccessTime(DateTimeUtil.format(model.getAccessTime())); + return view; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/AuthorizationViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/AuthorizationViewConverter.java index e486a6779..2ddf149bb 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/AuthorizationViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/AuthorizationViewConverter.java @@ -24,29 +24,30 @@ import org.springframework.stereotype.Component; @Component -public class AuthorizationViewConverter extends IdViewConverter { +public class AuthorizationViewConverter + extends IdViewConverter { - @Override - protected AuthorizationView modelToView(GeaflowAuthorization model) { - AuthorizationView view = super.modelToView(model); - view.setResourceType(model.getResourceType()); - view.setAuthorityType(model.getAuthorityType()); - view.setUserId(model.getUserId()); - view.setResourceId(model.getResourceId()); - return view; - } + @Override + protected AuthorizationView modelToView(GeaflowAuthorization model) { + AuthorizationView view = super.modelToView(model); + view.setResourceType(model.getResourceType()); + view.setAuthorityType(model.getAuthorityType()); + view.setUserId(model.getUserId()); + view.setResourceId(model.getResourceId()); + return view; + } - @Override - protected GeaflowAuthorization viewToModel(AuthorizationView view) { - GeaflowAuthorization model = super.viewToModel(view); - model.setResourceType(view.getResourceType()); - model.setAuthorityType(view.getAuthorityType()); - model.setUserId(view.getUserId()); - model.setResourceId(view.getResourceId()); - return model; - } + @Override + protected GeaflowAuthorization viewToModel(AuthorizationView view) { + GeaflowAuthorization model = super.viewToModel(view); + model.setResourceType(view.getResourceType()); + model.setAuthorityType(view.getAuthorityType()); + model.setUserId(view.getUserId()); + model.setResourceId(view.getResourceId()); + return model; + } - public GeaflowAuthorization convert(AuthorizationView view) { - return viewToModel(view); - } + public GeaflowAuthorization convert(AuthorizationView view) { + return viewToModel(view); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/ChatViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/ChatViewConverter.java index 5a3e9d0b5..ba20e9275 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/ChatViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/ChatViewConverter.java @@ -26,35 +26,34 @@ @Component public class ChatViewConverter extends IdViewConverter { - @Override - public void merge(ChatView view, ChatView updateView) { - super.merge(view, updateView); - } - - @Override - protected ChatView modelToView(GeaflowChat model) { - ChatView view = super.modelToView(model); - view.setModelId(model.getModelId()); - view.setAnswer(model.getAnswer()); - view.setPrompt(model.getPrompt()); - view.setStatus(model.getStatus()); - view.setJobId(model.getJobId()); - return view; - } - - @Override - protected GeaflowChat viewToModel(ChatView view) { - GeaflowChat model = super.viewToModel(view); - model.setModelId(view.getModelId()); - model.setAnswer(view.getAnswer()); - model.setPrompt(view.getPrompt()); - model.setStatus(view.getStatus()); - model.setJobId(view.getJobId()); - return model; - } - - public GeaflowChat convert(ChatView view) { - return viewToModel(view); - } - + @Override + public void merge(ChatView view, ChatView updateView) { + super.merge(view, updateView); + } + + @Override + protected ChatView modelToView(GeaflowChat model) { + ChatView view = super.modelToView(model); + view.setModelId(model.getModelId()); + view.setAnswer(model.getAnswer()); + view.setPrompt(model.getPrompt()); + view.setStatus(model.getStatus()); + view.setJobId(model.getJobId()); + return view; + } + + @Override + protected GeaflowChat viewToModel(ChatView view) { + GeaflowChat model = super.viewToModel(view); + model.setModelId(view.getModelId()); + model.setAnswer(view.getAnswer()); + model.setPrompt(view.getPrompt()); + model.setStatus(view.getStatus()); + model.setJobId(view.getJobId()); + return model; + } + + public GeaflowChat convert(ChatView view) { + return viewToModel(view); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/ClusterViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/ClusterViewConverter.java index e4d49da68..827fdf2b7 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/ClusterViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/ClusterViewConverter.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.biz.shared.convert; import java.util.Optional; + import org.apache.geaflow.console.biz.shared.view.ClusterView; import org.apache.geaflow.console.core.model.cluster.GeaflowCluster; import org.springframework.stereotype.Component; @@ -27,31 +28,30 @@ @Component public class ClusterViewConverter extends NameViewConverter { - @Override - public void merge(ClusterView view, ClusterView updateView) { - super.merge(view, updateView); - Optional.ofNullable(updateView.getType()).ifPresent(view::setType); - Optional.ofNullable(updateView.getConfig()).ifPresent(view::setConfig); - } - - @Override - protected ClusterView modelToView(GeaflowCluster model) { - ClusterView view = super.modelToView(model); - view.setType(model.getType()); - view.setConfig(model.getConfig()); - return view; - } - - @Override - protected GeaflowCluster viewToModel(ClusterView view) { - GeaflowCluster model = super.viewToModel(view); - model.setType(view.getType()); - model.setConfig(view.getConfig()); - return model; - } - - public GeaflowCluster convert(ClusterView view) { - return viewToModel(view); - } - + @Override + public void merge(ClusterView view, ClusterView updateView) { + super.merge(view, updateView); + Optional.ofNullable(updateView.getType()).ifPresent(view::setType); + Optional.ofNullable(updateView.getConfig()).ifPresent(view::setConfig); + } + + @Override + protected ClusterView modelToView(GeaflowCluster model) { + ClusterView view = super.modelToView(model); + view.setType(model.getType()); + view.setConfig(model.getConfig()); + return view; + } + + @Override + protected GeaflowCluster viewToModel(ClusterView view) { + GeaflowCluster model = super.viewToModel(view); + model.setType(view.getType()); + model.setConfig(view.getConfig()); + return model; + } + + public GeaflowCluster convert(ClusterView view) { + return viewToModel(view); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/DataViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/DataViewConverter.java index f25f34322..ad1bcfa3f 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/DataViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/DataViewConverter.java @@ -22,6 +22,5 @@ import org.apache.geaflow.console.biz.shared.view.DataView; import org.apache.geaflow.console.core.model.data.GeaflowData; -public abstract class DataViewConverter extends NameViewConverter { - -} +public abstract class DataViewConverter + extends NameViewConverter {} diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/EdgeViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/EdgeViewConverter.java index 1568bdd35..83a6d8688 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/EdgeViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/EdgeViewConverter.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.Optional; + import org.apache.geaflow.console.biz.shared.view.EdgeView; import org.apache.geaflow.console.common.util.type.GeaflowStructType; import org.apache.geaflow.console.core.model.data.GeaflowEdge; @@ -30,29 +31,29 @@ @Component public class EdgeViewConverter extends StructViewConverter { - @Override - protected EdgeView modelToView(GeaflowEdge model) { - EdgeView edgeView = super.modelToView(model); - edgeView.setType(GeaflowStructType.EDGE); - return edgeView; - } - - @Override - public void merge(EdgeView view, EdgeView updateView) { - super.merge(view, updateView); - Optional.ofNullable(updateView.getFields()).ifPresent(view::setFields); - } - - @Override - protected GeaflowEdge viewToModel(EdgeView view) { - GeaflowEdge edge = super.viewToModel(view); - edge.setType(GeaflowStructType.EDGE); - return edge; - } - - public GeaflowEdge converter(EdgeView view, List fields) { - GeaflowEdge edge = viewToModel(view); - edge.addFields(fields); - return edge; - } + @Override + protected EdgeView modelToView(GeaflowEdge model) { + EdgeView edgeView = super.modelToView(model); + edgeView.setType(GeaflowStructType.EDGE); + return edgeView; + } + + @Override + public void merge(EdgeView view, EdgeView updateView) { + super.merge(view, updateView); + Optional.ofNullable(updateView.getFields()).ifPresent(view::setFields); + } + + @Override + protected GeaflowEdge viewToModel(EdgeView view) { + GeaflowEdge edge = super.viewToModel(view); + edge.setType(GeaflowStructType.EDGE); + return edge; + } + + public GeaflowEdge converter(EdgeView view, List fields) { + GeaflowEdge edge = viewToModel(view); + edge.addFields(fields); + return edge; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/FieldViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/FieldViewConverter.java index 8a420f918..adb795c1c 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/FieldViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/FieldViewConverter.java @@ -26,23 +26,23 @@ @Component public class FieldViewConverter extends NameViewConverter { - @Override - protected FieldView modelToView(GeaflowField model) { - FieldView fieldView = super.modelToView(model); - fieldView.setType(model.getType()); - fieldView.setCategory(model.getCategory()); - return fieldView; - } + @Override + protected FieldView modelToView(GeaflowField model) { + FieldView fieldView = super.modelToView(model); + fieldView.setType(model.getType()); + fieldView.setCategory(model.getCategory()); + return fieldView; + } - @Override - protected GeaflowField viewToModel(FieldView view) { - GeaflowField field = super.viewToModel(view); - field.setType(view.getType()); - field.setCategory(view.getCategory()); - return field; - } + @Override + protected GeaflowField viewToModel(FieldView view) { + GeaflowField field = super.viewToModel(view); + field.setType(view.getType()); + field.setCategory(view.getCategory()); + return field; + } - public GeaflowField convert(FieldView view) { - return viewToModel(view); - } + public GeaflowField convert(FieldView view) { + return viewToModel(view); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/FunctionViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/FunctionViewConverter.java index 1f610f9ae..19bbae82a 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/FunctionViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/FunctionViewConverter.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.biz.shared.convert; import java.util.Optional; + import org.apache.geaflow.console.biz.shared.view.FunctionView; import org.apache.geaflow.console.core.model.data.GeaflowFunction; import org.apache.geaflow.console.core.model.file.GeaflowRemoteFile; @@ -29,33 +30,32 @@ @Component public class FunctionViewConverter extends DataViewConverter { - @Autowired - private RemoteFileViewConverter remoteFileViewConverter; - - @Override - public void merge(FunctionView view, FunctionView updateView) { - super.merge(view, updateView); - Optional.ofNullable(updateView.getEntryClass()).ifPresent(view::setEntryClass); - } - - @Override - protected FunctionView modelToView(GeaflowFunction model) { - FunctionView view = super.modelToView(model); - view.setJarPackage(remoteFileViewConverter.convert(model.getJarPackage())); - view.setEntryClass(model.getEntryClass()); - return view; - } - - @Override - protected GeaflowFunction viewToModel(FunctionView view) { - GeaflowFunction geaflowFunction = super.viewToModel(view); - geaflowFunction.setEntryClass(view.getEntryClass()); - return geaflowFunction; - } - - public GeaflowFunction convert(FunctionView entity, GeaflowRemoteFile jarPackage) { - GeaflowFunction model = viewToModel(entity); - model.setJarPackage(jarPackage); - return model; - } + @Autowired private RemoteFileViewConverter remoteFileViewConverter; + + @Override + public void merge(FunctionView view, FunctionView updateView) { + super.merge(view, updateView); + Optional.ofNullable(updateView.getEntryClass()).ifPresent(view::setEntryClass); + } + + @Override + protected FunctionView modelToView(GeaflowFunction model) { + FunctionView view = super.modelToView(model); + view.setJarPackage(remoteFileViewConverter.convert(model.getJarPackage())); + view.setEntryClass(model.getEntryClass()); + return view; + } + + @Override + protected GeaflowFunction viewToModel(FunctionView view) { + GeaflowFunction geaflowFunction = super.viewToModel(view); + geaflowFunction.setEntryClass(view.getEntryClass()); + return geaflowFunction; + } + + public GeaflowFunction convert(FunctionView entity, GeaflowRemoteFile jarPackage) { + GeaflowFunction model = viewToModel(entity); + model.setJarPackage(jarPackage); + return model; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/GraphViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/GraphViewConverter.java index 9261a98e0..78b262c42 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/GraphViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/GraphViewConverter.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.List; import java.util.Optional; + import org.apache.geaflow.console.biz.shared.view.EdgeView; import org.apache.geaflow.console.biz.shared.view.EndpointView; import org.apache.geaflow.console.biz.shared.view.GraphView; @@ -38,63 +39,74 @@ @Component public class GraphViewConverter extends DataViewConverter { - @Autowired - private VertexViewConverter vertexViewConverter; + @Autowired private VertexViewConverter vertexViewConverter; - @Autowired - private EdgeViewConverter edgeViewConverter; + @Autowired private EdgeViewConverter edgeViewConverter; - @Autowired - private PluginConfigViewConverter pluginConfigViewConverter; + @Autowired private PluginConfigViewConverter pluginConfigViewConverter; - @Override - public void merge(GraphView view, GraphView updateView) { - super.merge(view, updateView); - Optional.ofNullable(updateView.getVertices()).ifPresent(view::setVertices); - Optional.ofNullable(updateView.getEdges()).ifPresent(view::setEdges); - Optional.ofNullable(updateView.getPluginConfig()).ifPresent(e -> { - // update pluginConfig info - e.setId(view.getPluginConfig().getId()); - e.setCategory(GeaflowPluginCategory.GRAPH); - view.setPluginConfig(e); - }); - } + @Override + public void merge(GraphView view, GraphView updateView) { + super.merge(view, updateView); + Optional.ofNullable(updateView.getVertices()).ifPresent(view::setVertices); + Optional.ofNullable(updateView.getEdges()).ifPresent(view::setEdges); + Optional.ofNullable(updateView.getPluginConfig()) + .ifPresent( + e -> { + // update pluginConfig info + e.setId(view.getPluginConfig().getId()); + e.setCategory(GeaflowPluginCategory.GRAPH); + view.setPluginConfig(e); + }); + } - @Override - protected GraphView modelToView(GeaflowGraph model) { - GraphView graphView = super.modelToView(model); - graphView.setPluginConfig(pluginConfigViewConverter.modelToView(model.getPluginConfig())); - // cache model for vertex/edge - HashMap vertexMap = new HashMap<>(); - HashMap edgeMap = new HashMap<>(); - List vertexViews = ListUtil.convert(model.getVertices().values(), e -> { - vertexMap.putIfAbsent(e.getId(), e); - return vertexViewConverter.modelToView(e); - }); - List edgeViews = ListUtil.convert(model.getEdges().values(), e -> { - edgeMap.putIfAbsent(e.getId(), e); - return edgeViewConverter.modelToView(e); - }); + @Override + protected GraphView modelToView(GeaflowGraph model) { + GraphView graphView = super.modelToView(model); + graphView.setPluginConfig(pluginConfigViewConverter.modelToView(model.getPluginConfig())); + // cache model for vertex/edge + HashMap vertexMap = new HashMap<>(); + HashMap edgeMap = new HashMap<>(); + List vertexViews = + ListUtil.convert( + model.getVertices().values(), + e -> { + vertexMap.putIfAbsent(e.getId(), e); + return vertexViewConverter.modelToView(e); + }); + List edgeViews = + ListUtil.convert( + model.getEdges().values(), + e -> { + edgeMap.putIfAbsent(e.getId(), e); + return edgeViewConverter.modelToView(e); + }); - graphView.setVertices(vertexViews); - graphView.setEdges(edgeViews); + graphView.setVertices(vertexViews); + graphView.setEdges(edgeViews); - // set endpoints - List endpointViews = ListUtil.convert(model.getEndpoints(), e -> - new EndpointView(edgeMap.get(e.getEdgeId()).getName(), - vertexMap.get(e.getSourceId()).getName(), - vertexMap.get(e.getTargetId()).getName()) - ); - graphView.setEndpoints(endpointViews); - return graphView; - } + // set endpoints + List endpointViews = + ListUtil.convert( + model.getEndpoints(), + e -> + new EndpointView( + edgeMap.get(e.getEdgeId()).getName(), + vertexMap.get(e.getSourceId()).getName(), + vertexMap.get(e.getTargetId()).getName())); + graphView.setEndpoints(endpointViews); + return graphView; + } - public GeaflowGraph convert(GraphView view, List vertices, List edges, - GeaflowPluginConfig pluginConfig) { - GeaflowGraph graph = super.viewToModel(view); - graph.addVertices(vertices); - graph.addEdges(edges); - graph.setPluginConfig(pluginConfig); - return graph; - } + public GeaflowGraph convert( + GraphView view, + List vertices, + List edges, + GeaflowPluginConfig pluginConfig) { + GeaflowGraph graph = super.viewToModel(view); + graph.addVertices(vertices); + graph.addEdges(edges); + graph.setPluginConfig(pluginConfig); + return graph; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/IdViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/IdViewConverter.java index 567a1d9f2..1848ddccd 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/IdViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/IdViewConverter.java @@ -21,6 +21,7 @@ import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; + import org.apache.geaflow.console.biz.shared.view.IdView; import org.apache.geaflow.console.common.util.DateTimeUtil; import org.apache.geaflow.console.common.util.exception.GeaflowException; @@ -29,56 +30,56 @@ @SuppressWarnings("unchecked") public abstract class IdViewConverter { - public void merge(V view, V updateView) { + public void merge(V view, V updateView) {} - } + public V convert(M model) { + return modelToView(model); + } - public V convert(M model) { - return modelToView(model); - } + protected V modelToView(M model) { + Type[] args = + ((ParameterizedType) this.getClass().getGenericSuperclass()).getActualTypeArguments(); - protected V modelToView(M model) { - Type[] args = ((ParameterizedType) this.getClass().getGenericSuperclass()).getActualTypeArguments(); + try { + V view = (V) ((Class) args[1]).newInstance(); - try { - V view = (V) ((Class) args[1]).newInstance(); + view.setTenantId(model.getTenantId()); + view.setId(model.getId()); + view.setCreateTime(DateTimeUtil.format(model.getGmtCreate())); + view.setCreatorId(model.getCreatorId()); + view.setModifyTime(DateTimeUtil.format(model.getGmtModified())); + view.setModifierId(model.getModifierId()); - view.setTenantId(model.getTenantId()); - view.setId(model.getId()); - view.setCreateTime(DateTimeUtil.format(model.getGmtCreate())); - view.setCreatorId(model.getCreatorId()); - view.setModifyTime(DateTimeUtil.format(model.getGmtModified())); - view.setModifierId(model.getModifierId()); + return view; - return view; - - } catch (Exception e) { - throw new GeaflowException("Convert id model to view failed", e); - } + } catch (Exception e) { + throw new GeaflowException("Convert id model to view failed", e); } + } - protected M viewToModel(V view) { - Type[] args = ((ParameterizedType) this.getClass().getGenericSuperclass()).getActualTypeArguments(); + protected M viewToModel(V view) { + Type[] args = + ((ParameterizedType) this.getClass().getGenericSuperclass()).getActualTypeArguments(); - try { - M model = (M) ((Class) args[0]).newInstance(); + try { + M model = (M) ((Class) args[0]).newInstance(); - model.setId(view.getId()); + model.setId(view.getId()); - return model; + return model; - } catch (Exception e) { - throw new GeaflowException("Convert id view to model failed", e); - } + } catch (Exception e) { + throw new GeaflowException("Convert id view to model failed", e); } - - protected M viewToModel(V view, Class clazz) { - try { - M model = clazz.newInstance(); - model.setId(view.getId()); - return model; - } catch (Exception e) { - throw new GeaflowException("Convert id view to model failed", e); - } + } + + protected M viewToModel(V view, Class clazz) { + try { + M model = clazz.newInstance(); + model.setId(view.getId()); + return model; + } catch (Exception e) { + throw new GeaflowException("Convert id view to model failed", e); } + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/InstallViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/InstallViewConverter.java index eede80c8b..e07ebefc4 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/InstallViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/InstallViewConverter.java @@ -27,35 +27,35 @@ @Component public class InstallViewConverter extends IdViewConverter { - @Autowired - private PluginConfigViewConverter pluginConfigViewConverter; - - @Override - protected InstallView modelToView(GeaflowInstall model) { - InstallView view = super.modelToView(model); - view.setRuntimeClusterConfig(pluginConfigViewConverter.convert(model.getRuntimeClusterConfig())); - view.setRuntimeMetaConfig(pluginConfigViewConverter.convert(model.getRuntimeMetaConfig())); - view.setHaMetaConfig(pluginConfigViewConverter.convert(model.getHaMetaConfig())); - view.setMetricConfig(pluginConfigViewConverter.convert(model.getMetricConfig())); - view.setRemoteFileConfig(pluginConfigViewConverter.convert(model.getRemoteFileConfig())); - view.setDataConfig(pluginConfigViewConverter.convert(model.getDataConfig())); - return view; - } - - @Override - protected GeaflowInstall viewToModel(InstallView view) { - GeaflowInstall model = super.viewToModel(view); - model.setRuntimeClusterConfig(pluginConfigViewConverter.convert(view.getRuntimeClusterConfig())); - model.setRuntimeMetaConfig(pluginConfigViewConverter.convert(view.getRuntimeMetaConfig())); - model.setHaMetaConfig(pluginConfigViewConverter.convert(view.getHaMetaConfig())); - model.setMetricConfig(pluginConfigViewConverter.convert(view.getMetricConfig())); - model.setRemoteFileConfig(pluginConfigViewConverter.convert(view.getRemoteFileConfig())); - model.setDataConfig(pluginConfigViewConverter.convert(view.getDataConfig())); - return model; - } - - public GeaflowInstall convert(InstallView view) { - return viewToModel(view); - } - + @Autowired private PluginConfigViewConverter pluginConfigViewConverter; + + @Override + protected InstallView modelToView(GeaflowInstall model) { + InstallView view = super.modelToView(model); + view.setRuntimeClusterConfig( + pluginConfigViewConverter.convert(model.getRuntimeClusterConfig())); + view.setRuntimeMetaConfig(pluginConfigViewConverter.convert(model.getRuntimeMetaConfig())); + view.setHaMetaConfig(pluginConfigViewConverter.convert(model.getHaMetaConfig())); + view.setMetricConfig(pluginConfigViewConverter.convert(model.getMetricConfig())); + view.setRemoteFileConfig(pluginConfigViewConverter.convert(model.getRemoteFileConfig())); + view.setDataConfig(pluginConfigViewConverter.convert(model.getDataConfig())); + return view; + } + + @Override + protected GeaflowInstall viewToModel(InstallView view) { + GeaflowInstall model = super.viewToModel(view); + model.setRuntimeClusterConfig( + pluginConfigViewConverter.convert(view.getRuntimeClusterConfig())); + model.setRuntimeMetaConfig(pluginConfigViewConverter.convert(view.getRuntimeMetaConfig())); + model.setHaMetaConfig(pluginConfigViewConverter.convert(view.getHaMetaConfig())); + model.setMetricConfig(pluginConfigViewConverter.convert(view.getMetricConfig())); + model.setRemoteFileConfig(pluginConfigViewConverter.convert(view.getRemoteFileConfig())); + model.setDataConfig(pluginConfigViewConverter.convert(view.getDataConfig())); + return model; + } + + public GeaflowInstall convert(InstallView view) { + return viewToModel(view); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/InstanceViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/InstanceViewConverter.java index f5c202410..d452f8b2f 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/InstanceViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/InstanceViewConverter.java @@ -26,7 +26,7 @@ @Component public class InstanceViewConverter extends NameViewConverter { - public GeaflowInstance convert(InstanceView view) { - return viewToModel(view); - } + public GeaflowInstance convert(InstanceView view) { + return viewToModel(view); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/JobViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/JobViewConverter.java index fb4191c64..133def283 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/JobViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/JobViewConverter.java @@ -19,14 +19,14 @@ package org.apache.geaflow.console.biz.shared.convert; -import com.alibaba.fastjson.JSON; -import com.google.common.base.Preconditions; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; + import javax.annotation.PostConstruct; + import org.apache.geaflow.console.biz.shared.view.FunctionView; import org.apache.geaflow.console.biz.shared.view.GraphView; import org.apache.geaflow.console.biz.shared.view.JobView; @@ -51,131 +51,142 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; +import com.alibaba.fastjson.JSON; +import com.google.common.base.Preconditions; + @Component public class JobViewConverter extends NameViewConverter { - @Autowired - private GraphViewConverter graphViewConverter; - - @Autowired - private VertexViewConverter vertexViewConverter; - - @Autowired - private EdgeViewConverter edgeViewConverter; + @Autowired private GraphViewConverter graphViewConverter; - @Autowired - private TableViewConverter tableViewConverter; + @Autowired private VertexViewConverter vertexViewConverter; - @Autowired - private RemoteFileViewConverter remoteFileViewConverter; + @Autowired private EdgeViewConverter edgeViewConverter; - @Autowired - private FunctionViewConverter functionViewConverter; + @Autowired private TableViewConverter tableViewConverter; - @Autowired - private InstanceService instanceService; + @Autowired private RemoteFileViewConverter remoteFileViewConverter; - private final Map converterMap = new HashMap<>(); + @Autowired private FunctionViewConverter functionViewConverter; - @PostConstruct - public void init() { - converterMap.put(GeaflowStructType.VERTEX, vertexViewConverter); - converterMap.put(GeaflowStructType.EDGE, edgeViewConverter); - converterMap.put(GeaflowStructType.TABLE, tableViewConverter); - } - - @Override - public void merge(JobView view, JobView updateView) { - super.merge(view, updateView); - switch (view.getType()) { - case INTEGRATE: - Optional.ofNullable(updateView.getStructMappings()).ifPresent(view::setStructMappings); - Optional.ofNullable(updateView.getGraphs()).ifPresent(view::setGraphs); - break; - case PROCESS: - Optional.ofNullable(updateView.getUserCode()).ifPresent(view::setUserCode); - break; - case CUSTOM: - Optional.ofNullable(updateView.getEntryClass()).ifPresent(view::setEntryClass); - Optional.ofNullable(updateView.getJarPackage()).ifPresent(view::setJarPackage); - break; - case SERVE: - break; - default: - throw new GeaflowException("Unsupported job type: {}", view.getType()); - } - } - - @Override - protected JobView modelToView(GeaflowJob model) { - JobView jobView = super.modelToView(model); - jobView.setUserCode(Optional.ofNullable(model.getUserCode()).map(GeaflowCode::getText).orElse(null)); + @Autowired private InstanceService instanceService; - List graphs = ListUtil.convert(model.getGraphs(), e -> graphViewConverter.convert(e)); + private final Map converterMap = new HashMap<>(); - List structs = ListUtil.convert(model.getStructs(), e -> (StructView) converterMap.get(e.getType()).convert(e)); + @PostConstruct + public void init() { + converterMap.put(GeaflowStructType.VERTEX, vertexViewConverter); + converterMap.put(GeaflowStructType.EDGE, edgeViewConverter); + converterMap.put(GeaflowStructType.TABLE, tableViewConverter); + } - jobView.setStructs(structs); - jobView.setGraphs(graphs); - - jobView.setStructMappings(JSON.toJSONString(model.getStructMappings())); - - List functions = ListUtil.convert(model.getFunctions(), e -> functionViewConverter.convert(e)); - jobView.setFunctions(functions); - jobView.setType(model.getType()); - jobView.setInstanceId(model.getInstanceId()); - jobView.setInstanceName(instanceService.get(model.getInstanceId()).getName()); - jobView.setEntryClass(model.getEntryClass()); - jobView.setJarPackage(Optional.ofNullable(model.getJarPackage()).map(e -> remoteFileViewConverter.convert(e)).orElse(null)); - return jobView; + @Override + public void merge(JobView view, JobView updateView) { + super.merge(view, updateView); + switch (view.getType()) { + case INTEGRATE: + Optional.ofNullable(updateView.getStructMappings()).ifPresent(view::setStructMappings); + Optional.ofNullable(updateView.getGraphs()).ifPresent(view::setGraphs); + break; + case PROCESS: + Optional.ofNullable(updateView.getUserCode()).ifPresent(view::setUserCode); + break; + case CUSTOM: + Optional.ofNullable(updateView.getEntryClass()).ifPresent(view::setEntryClass); + Optional.ofNullable(updateView.getJarPackage()).ifPresent(view::setJarPackage); + break; + case SERVE: + break; + default: + throw new GeaflowException("Unsupported job type: {}", view.getType()); } - - public GeaflowJob convert(JobView view, List structs, List graphs, - List functions, GeaflowRemoteFile jarFile) { - GeaflowJobType jobType = view.getType(); - GeaflowJob job; - switch (jobType) { - case INTEGRATE: - GeaflowIntegrateJob integrateJob = (GeaflowIntegrateJob) viewToModel(view, GeaflowIntegrateJob.class); - - Preconditions.checkNotNull(view.getStructMappings()); - - List structMappings = JSON.parseArray(view.getStructMappings(), StructMapping.class); - // dedup duplicated field mappings - for (StructMapping structMapping : structMappings) { - List distinctMapping = structMapping.getFieldMappings().stream().distinct().collect(Collectors.toList()); - structMapping.setFieldMappings(distinctMapping); - } - integrateJob.setStructMappings(structMappings); - integrateJob.setGraph(graphs); - integrateJob.setStructs(structs); - job = integrateJob; - break; - case PROCESS: - GeaflowProcessJob processJob = (GeaflowProcessJob) viewToModel(view, GeaflowProcessJob.class); - processJob.setUserCode(view.getUserCode()); - processJob.setFunctions(functions); - job = processJob; - break; - case CUSTOM: - GeaflowCustomJob customJob = (GeaflowCustomJob) viewToModel(view, GeaflowCustomJob.class); - customJob.setEntryClass(view.getEntryClass()); - customJob.setJarPackage(jarFile); - job = customJob; - break; - case SERVE: - GeaflowServeJob serveJob = (GeaflowServeJob) viewToModel(view, GeaflowServeJob.class); - serveJob.setGraph(graphs); - job = serveJob; - break; - default: - throw new GeaflowException("Unsupported job type: {}", jobType); + } + + @Override + protected JobView modelToView(GeaflowJob model) { + JobView jobView = super.modelToView(model); + jobView.setUserCode( + Optional.ofNullable(model.getUserCode()).map(GeaflowCode::getText).orElse(null)); + + List graphs = + ListUtil.convert(model.getGraphs(), e -> graphViewConverter.convert(e)); + + List structs = + ListUtil.convert( + model.getStructs(), e -> (StructView) converterMap.get(e.getType()).convert(e)); + + jobView.setStructs(structs); + jobView.setGraphs(graphs); + + jobView.setStructMappings(JSON.toJSONString(model.getStructMappings())); + + List functions = + ListUtil.convert(model.getFunctions(), e -> functionViewConverter.convert(e)); + jobView.setFunctions(functions); + jobView.setType(model.getType()); + jobView.setInstanceId(model.getInstanceId()); + jobView.setInstanceName(instanceService.get(model.getInstanceId()).getName()); + jobView.setEntryClass(model.getEntryClass()); + jobView.setJarPackage( + Optional.ofNullable(model.getJarPackage()) + .map(e -> remoteFileViewConverter.convert(e)) + .orElse(null)); + return jobView; + } + + public GeaflowJob convert( + JobView view, + List structs, + List graphs, + List functions, + GeaflowRemoteFile jarFile) { + GeaflowJobType jobType = view.getType(); + GeaflowJob job; + switch (jobType) { + case INTEGRATE: + GeaflowIntegrateJob integrateJob = + (GeaflowIntegrateJob) viewToModel(view, GeaflowIntegrateJob.class); + + Preconditions.checkNotNull(view.getStructMappings()); + + List structMappings = + JSON.parseArray(view.getStructMappings(), StructMapping.class); + // dedup duplicated field mappings + for (StructMapping structMapping : structMappings) { + List distinctMapping = + structMapping.getFieldMappings().stream().distinct().collect(Collectors.toList()); + structMapping.setFieldMappings(distinctMapping); } - - job.setInstanceId(view.getInstanceId()); - job.setType(jobType); - //TODO job.setSla(entity.getSlaId()); - return job; + integrateJob.setStructMappings(structMappings); + integrateJob.setGraph(graphs); + integrateJob.setStructs(structs); + job = integrateJob; + break; + case PROCESS: + GeaflowProcessJob processJob = + (GeaflowProcessJob) viewToModel(view, GeaflowProcessJob.class); + processJob.setUserCode(view.getUserCode()); + processJob.setFunctions(functions); + job = processJob; + break; + case CUSTOM: + GeaflowCustomJob customJob = (GeaflowCustomJob) viewToModel(view, GeaflowCustomJob.class); + customJob.setEntryClass(view.getEntryClass()); + customJob.setJarPackage(jarFile); + job = customJob; + break; + case SERVE: + GeaflowServeJob serveJob = (GeaflowServeJob) viewToModel(view, GeaflowServeJob.class); + serveJob.setGraph(graphs); + job = serveJob; + break; + default: + throw new GeaflowException("Unsupported job type: {}", jobType); } + job.setInstanceId(view.getInstanceId()); + job.setType(jobType); + // TODO job.setSla(entity.getSlaId()); + return job; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/LLMViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/LLMViewConverter.java index a5db513e8..897fdda6e 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/LLMViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/LLMViewConverter.java @@ -19,46 +19,49 @@ package org.apache.geaflow.console.biz.shared.convert; -import com.alibaba.fastjson.JSON; import java.util.Optional; + import org.apache.geaflow.console.biz.shared.view.LLMView; import org.apache.geaflow.console.core.model.config.GeaflowConfig; import org.apache.geaflow.console.core.model.llm.GeaflowLLM; import org.springframework.stereotype.Component; +import com.alibaba.fastjson.JSON; + @Component public class LLMViewConverter extends NameViewConverter { - @Override - public void merge(LLMView view, LLMView updateView) { - super.merge(view, updateView); - Optional.ofNullable(updateView.getType()).ifPresent(view::setType); - Optional.ofNullable(updateView.getUrl()).ifPresent(view::setUrl); - Optional.ofNullable(updateView.getArgs()).ifPresent(view::setArgs); - } - - @Override - protected LLMView modelToView(GeaflowLLM model) { - LLMView view = super.modelToView(model); - view.setType(model.getType()); - view.setUrl(model.getUrl()); - view.setArgs(JSON.toJSONString(model.getArgs())); - return view; - } + @Override + public void merge(LLMView view, LLMView updateView) { + super.merge(view, updateView); + Optional.ofNullable(updateView.getType()).ifPresent(view::setType); + Optional.ofNullable(updateView.getUrl()).ifPresent(view::setUrl); + Optional.ofNullable(updateView.getArgs()).ifPresent(view::setArgs); + } - @Override - protected GeaflowLLM viewToModel(LLMView view) { - GeaflowLLM model = super.viewToModel(view); - model.setType(view.getType()); - model.setUrl(view.getUrl()); + @Override + protected LLMView modelToView(GeaflowLLM model) { + LLMView view = super.modelToView(model); + view.setType(model.getType()); + view.setUrl(model.getUrl()); + view.setArgs(JSON.toJSONString(model.getArgs())); + return view; + } - GeaflowConfig config = Optional.ofNullable(JSON.parseObject(view.getArgs(), GeaflowConfig.class)).orElse(new GeaflowConfig()); - model.setArgs(config); - return model; - } + @Override + protected GeaflowLLM viewToModel(LLMView view) { + GeaflowLLM model = super.viewToModel(view); + model.setType(view.getType()); + model.setUrl(view.getUrl()); - public GeaflowLLM convert(LLMView view) { - return viewToModel(view); - } + GeaflowConfig config = + Optional.ofNullable(JSON.parseObject(view.getArgs(), GeaflowConfig.class)) + .orElse(new GeaflowConfig()); + model.setArgs(config); + return model; + } + public GeaflowLLM convert(LLMView view) { + return viewToModel(view); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/NameViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/NameViewConverter.java index c5bdaf225..b579b3551 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/NameViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/NameViewConverter.java @@ -20,39 +20,41 @@ package org.apache.geaflow.console.biz.shared.convert; import java.util.Optional; + import org.apache.geaflow.console.biz.shared.view.NameView; import org.apache.geaflow.console.core.model.GeaflowName; -public abstract class NameViewConverter extends IdViewConverter { - - @Override - public void merge(V view, V updateView) { - super.merge(view, updateView); - Optional.ofNullable(updateView.getName()).ifPresent(view::setName); - Optional.ofNullable(updateView.getComment()).ifPresent(view::setComment); - } - - @Override - protected V modelToView(M model) { - V view = super.modelToView(model); - view.setName(model.getName()); - view.setComment(model.getComment()); - return view; - } - - @Override - protected M viewToModel(V view) { - M model = super.viewToModel(view); - model.setName(view.getName()); - model.setComment(view.getComment()); - return model; - } - - @Override - protected M viewToModel(V view, Class clazz) { - M model = super.viewToModel(view, clazz); - model.setName(view.getName()); - model.setComment(view.getComment()); - return model; - } +public abstract class NameViewConverter + extends IdViewConverter { + + @Override + public void merge(V view, V updateView) { + super.merge(view, updateView); + Optional.ofNullable(updateView.getName()).ifPresent(view::setName); + Optional.ofNullable(updateView.getComment()).ifPresent(view::setComment); + } + + @Override + protected V modelToView(M model) { + V view = super.modelToView(model); + view.setName(model.getName()); + view.setComment(model.getComment()); + return view; + } + + @Override + protected M viewToModel(V view) { + M model = super.viewToModel(view); + model.setName(view.getName()); + model.setComment(view.getComment()); + return model; + } + + @Override + protected M viewToModel(V view, Class clazz) { + M model = super.viewToModel(view, clazz); + model.setName(view.getName()); + model.setComment(view.getComment()); + return model; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/PluginConfigViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/PluginConfigViewConverter.java index 1779e9c96..06832edaf 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/PluginConfigViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/PluginConfigViewConverter.java @@ -20,40 +20,42 @@ package org.apache.geaflow.console.biz.shared.convert; import java.util.Optional; + import org.apache.geaflow.console.biz.shared.view.PluginConfigView; import org.apache.geaflow.console.core.model.plugin.config.GeaflowPluginConfig; import org.springframework.stereotype.Component; @Component -public class PluginConfigViewConverter extends NameViewConverter { - - @Override - public void merge(PluginConfigView view, PluginConfigView updateView) { - super.merge(view, updateView); - Optional.ofNullable(updateView.getConfig()).ifPresent(view::setConfig); - Optional.ofNullable(updateView.getCategory()).ifPresent(view::setCategory); - Optional.ofNullable(updateView.getType()).ifPresent(view::setType); - } - - @Override - protected PluginConfigView modelToView(GeaflowPluginConfig model) { - PluginConfigView view = super.modelToView(model); - view.setType(model.getType()); - view.setConfig(model.getConfig()); - view.setCategory(model.getCategory()); - return view; - } - - @Override - protected GeaflowPluginConfig viewToModel(PluginConfigView view) { - GeaflowPluginConfig model = super.viewToModel(view); - model.setType(view.getType()); - model.setConfig(view.getConfig()); - model.setCategory(view.getCategory()); - return model; - } - - public GeaflowPluginConfig convert(PluginConfigView view) { - return viewToModel(view); - } +public class PluginConfigViewConverter + extends NameViewConverter { + + @Override + public void merge(PluginConfigView view, PluginConfigView updateView) { + super.merge(view, updateView); + Optional.ofNullable(updateView.getConfig()).ifPresent(view::setConfig); + Optional.ofNullable(updateView.getCategory()).ifPresent(view::setCategory); + Optional.ofNullable(updateView.getType()).ifPresent(view::setType); + } + + @Override + protected PluginConfigView modelToView(GeaflowPluginConfig model) { + PluginConfigView view = super.modelToView(model); + view.setType(model.getType()); + view.setConfig(model.getConfig()); + view.setCategory(model.getCategory()); + return view; + } + + @Override + protected GeaflowPluginConfig viewToModel(PluginConfigView view) { + GeaflowPluginConfig model = super.viewToModel(view); + model.setType(view.getType()); + model.setConfig(view.getConfig()); + model.setCategory(view.getCategory()); + return model; + } + + public GeaflowPluginConfig convert(PluginConfigView view) { + return viewToModel(view); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/PluginViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/PluginViewConverter.java index 96bc9734f..9c9a7af72 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/PluginViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/PluginViewConverter.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.biz.shared.convert; import java.util.Optional; + import org.apache.geaflow.console.biz.shared.view.PluginView; import org.apache.geaflow.console.core.model.file.GeaflowRemoteFile; import org.apache.geaflow.console.core.model.plugin.GeaflowPlugin; @@ -29,38 +30,40 @@ @Component public class PluginViewConverter extends NameViewConverter { - @Autowired - private RemoteFileViewConverter remoteFileViewConverter; + @Autowired private RemoteFileViewConverter remoteFileViewConverter; - @Override - public void merge(PluginView view, PluginView updateView) { - super.merge(view, updateView); - Optional.ofNullable(updateView.getJarPackage()).ifPresent(view::setJarPackage); - Optional.ofNullable(updateView.getType()).ifPresent(view::setType); - Optional.ofNullable(updateView.getCategory()).ifPresent(view::setCategory); - } + @Override + public void merge(PluginView view, PluginView updateView) { + super.merge(view, updateView); + Optional.ofNullable(updateView.getJarPackage()).ifPresent(view::setJarPackage); + Optional.ofNullable(updateView.getType()).ifPresent(view::setType); + Optional.ofNullable(updateView.getCategory()).ifPresent(view::setCategory); + } - @Override - protected PluginView modelToView(GeaflowPlugin model) { - PluginView view = super.modelToView(model); - view.setType(model.getType()); - view.setCategory(model.getCategory()); - view.setJarPackage(Optional.ofNullable(model.getJarPackage()).map(e -> remoteFileViewConverter.convert(e)).orElse(null)); - view.setSystem(model.isSystem()); - return view; - } + @Override + protected PluginView modelToView(GeaflowPlugin model) { + PluginView view = super.modelToView(model); + view.setType(model.getType()); + view.setCategory(model.getCategory()); + view.setJarPackage( + Optional.ofNullable(model.getJarPackage()) + .map(e -> remoteFileViewConverter.convert(e)) + .orElse(null)); + view.setSystem(model.isSystem()); + return view; + } - @Override - protected GeaflowPlugin viewToModel(PluginView view) { - GeaflowPlugin model = super.viewToModel(view); - model.setType(view.getType()); - model.setCategory(view.getCategory()); - return model; - } + @Override + protected GeaflowPlugin viewToModel(PluginView view) { + GeaflowPlugin model = super.viewToModel(view); + model.setType(view.getType()); + model.setCategory(view.getCategory()); + return model; + } - public GeaflowPlugin convert(PluginView view, GeaflowRemoteFile jarPackage) { - GeaflowPlugin model = viewToModel(view); - model.setJarPackage(jarPackage); - return model; - } + public GeaflowPlugin convert(PluginView view, GeaflowRemoteFile jarPackage) { + GeaflowPlugin model = viewToModel(view); + model.setJarPackage(jarPackage); + return model; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/ReleaseUpdateViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/ReleaseUpdateViewConverter.java index 0466ba048..382d862a8 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/ReleaseUpdateViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/ReleaseUpdateViewConverter.java @@ -28,13 +28,14 @@ @Component public class ReleaseUpdateViewConverter { - public ReleaseUpdate converter(ReleaseUpdateView view, GeaflowVersion version, GeaflowCluster cluster) { - ReleaseUpdate model = new ReleaseUpdate(); - model.setNewJobConfig(view.getNewJobConfig()); - model.setNewClusterConfig(view.getNewClusterConfig()); - model.setNewParallelisms(view.getNewParallelisms()); - model.setNewVersion(version); - model.setNewCluster(cluster); - return model; - } + public ReleaseUpdate converter( + ReleaseUpdateView view, GeaflowVersion version, GeaflowCluster cluster) { + ReleaseUpdate model = new ReleaseUpdate(); + model.setNewJobConfig(view.getNewJobConfig()); + model.setNewClusterConfig(view.getNewClusterConfig()); + model.setNewParallelisms(view.getNewParallelisms()); + model.setNewVersion(version); + model.setNewCluster(cluster); + return model; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/ReleaseViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/ReleaseViewConverter.java index 903aa0be0..f1cca0ab2 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/ReleaseViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/ReleaseViewConverter.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.biz.shared.convert; import java.util.Optional; + import org.apache.geaflow.console.biz.shared.view.ReleaseView; import org.apache.geaflow.console.core.model.GeaflowName; import org.apache.geaflow.console.core.model.release.GeaflowRelease; @@ -29,21 +30,20 @@ @Component public class ReleaseViewConverter extends IdViewConverter { - @Autowired - private JobViewConverter jobViewConverter; - - @Override - protected ReleaseView modelToView(GeaflowRelease model) { - ReleaseView releaseView = super.modelToView(model); - releaseView.setClusterName(Optional.ofNullable(model.getCluster()).map(GeaflowName::getName).orElse(null)); - releaseView.setVersionName(Optional.ofNullable(model.getVersion()).map(GeaflowName::getName).orElse(null)); - releaseView.setJob(jobViewConverter.convert(model.getJob())); - releaseView.setJobConfig(model.getJobConfig()); - releaseView.setClusterConfig(model.getClusterConfig()); - releaseView.setJobPlan(model.getJobPlan()); - releaseView.setReleaseVersion(model.getReleaseVersion()); - return releaseView; - } - + @Autowired private JobViewConverter jobViewConverter; + @Override + protected ReleaseView modelToView(GeaflowRelease model) { + ReleaseView releaseView = super.modelToView(model); + releaseView.setClusterName( + Optional.ofNullable(model.getCluster()).map(GeaflowName::getName).orElse(null)); + releaseView.setVersionName( + Optional.ofNullable(model.getVersion()).map(GeaflowName::getName).orElse(null)); + releaseView.setJob(jobViewConverter.convert(model.getJob())); + releaseView.setJobConfig(model.getJobConfig()); + releaseView.setClusterConfig(model.getClusterConfig()); + releaseView.setJobPlan(model.getJobPlan()); + releaseView.setReleaseVersion(model.getReleaseVersion()); + return releaseView; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/RemoteFileViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/RemoteFileViewConverter.java index 549fb3141..c5846bc5b 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/RemoteFileViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/RemoteFileViewConverter.java @@ -26,28 +26,27 @@ @Component public class RemoteFileViewConverter extends NameViewConverter { - @Override - protected RemoteFileView modelToView(GeaflowRemoteFile model) { - RemoteFileView view = super.modelToView(model); - view.setMd5(model.getMd5()); - view.setType(model.getType()); - view.setPath(model.getPath()); - view.setUrl(model.getUrl()); - return view; - } + @Override + protected RemoteFileView modelToView(GeaflowRemoteFile model) { + RemoteFileView view = super.modelToView(model); + view.setMd5(model.getMd5()); + view.setType(model.getType()); + view.setPath(model.getPath()); + view.setUrl(model.getUrl()); + return view; + } + @Override + protected GeaflowRemoteFile viewToModel(RemoteFileView view) { + GeaflowRemoteFile model = super.viewToModel(view); + model.setPath(view.getPath()); + model.setMd5(view.getMd5()); + model.setUrl(view.getUrl()); + model.setType(view.getType()); + return model; + } - @Override - protected GeaflowRemoteFile viewToModel(RemoteFileView view) { - GeaflowRemoteFile model = super.viewToModel(view); - model.setPath(view.getPath()); - model.setMd5(view.getMd5()); - model.setUrl(view.getUrl()); - model.setType(view.getType()); - return model; - } - - public GeaflowRemoteFile convert(RemoteFileView view) { - return viewToModel(view); - } + public GeaflowRemoteFile convert(RemoteFileView view) { + return viewToModel(view); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/StatementViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/StatementViewConverter.java index 1976f8466..d1783a016 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/StatementViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/StatementViewConverter.java @@ -19,37 +19,38 @@ package org.apache.geaflow.console.biz.shared.convert; -import com.alibaba.fastjson.JSON; import java.util.Optional; + import org.apache.geaflow.console.biz.shared.view.StatementView; import org.apache.geaflow.console.core.model.statement.GeaflowStatement; import org.springframework.stereotype.Component; +import com.alibaba.fastjson.JSON; + @Component public class StatementViewConverter extends IdViewConverter { - - @Override - protected StatementView modelToView(GeaflowStatement model) { - StatementView view = super.modelToView(model); - view.setScript(model.getScript()); - view.setStatus(model.getStatus()); - try { - view.setResult(JSON.parseObject(model.getResult())); - } catch (Exception e) { - view.setResult(model.getResult()); - } - - view.setJobId(model.getJobId()); - return view; + @Override + protected StatementView modelToView(GeaflowStatement model) { + StatementView view = super.modelToView(model); + view.setScript(model.getScript()); + view.setStatus(model.getStatus()); + try { + view.setResult(JSON.parseObject(model.getResult())); + } catch (Exception e) { + view.setResult(model.getResult()); } - public GeaflowStatement convert(StatementView view) { - GeaflowStatement model = super.viewToModel(view); - model.setScript(view.getScript()); - model.setStatus(view.getStatus()); - model.setResult(Optional.ofNullable(view.getResult()).map(Object::toString).orElse(null)); - model.setJobId(view.getJobId()); - return model; - } + view.setJobId(model.getJobId()); + return view; + } + + public GeaflowStatement convert(StatementView view) { + GeaflowStatement model = super.viewToModel(view); + model.setScript(view.getScript()); + model.setStatus(view.getStatus()); + model.setResult(Optional.ofNullable(view.getResult()).map(Object::toString).orElse(null)); + model.setJobId(view.getJobId()); + return model; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/StructViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/StructViewConverter.java index ed1152856..899678089 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/StructViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/StructViewConverter.java @@ -21,38 +21,41 @@ import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.console.biz.shared.view.FieldView; import org.apache.geaflow.console.biz.shared.view.StructView; import org.apache.geaflow.console.core.model.data.GeaflowField; import org.apache.geaflow.console.core.model.data.GeaflowStruct; import org.springframework.beans.factory.annotation.Autowired; -public abstract class StructViewConverter extends DataViewConverter { +public abstract class StructViewConverter + extends DataViewConverter { - @Autowired - private FieldViewConverter fieldViewConverter; + @Autowired private FieldViewConverter fieldViewConverter; - @Override - protected V modelToView(M model) { - V view = super.modelToView(model); - view.setType(model.getType()); - //set field - List fieldViews = model.getFields().values().stream().map(e -> fieldViewConverter.convert(e)) + @Override + protected V modelToView(M model) { + V view = super.modelToView(model); + view.setType(model.getType()); + // set field + List fieldViews = + model.getFields().values().stream() + .map(e -> fieldViewConverter.convert(e)) .collect(Collectors.toList()); - view.setFields(fieldViews); - return view; - } - - @Override - protected M viewToModel(V view) { - M model = super.viewToModel(view); - model.setType(view.getType()); - return model; - } - - public M convert(V view, List fields) { - M struct = viewToModel(view); - struct.addFields(fields); - return struct; - } + view.setFields(fieldViews); + return view; + } + + @Override + protected M viewToModel(V view) { + M model = super.viewToModel(view); + model.setType(view.getType()); + return model; + } + + public M convert(V view, List fields) { + M struct = viewToModel(view); + struct.addFields(fields); + return struct; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/SystemConfigViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/SystemConfigViewConverter.java index 5435e51fb..a653c3359 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/SystemConfigViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/SystemConfigViewConverter.java @@ -20,35 +20,37 @@ package org.apache.geaflow.console.biz.shared.convert; import java.util.Optional; + import org.apache.geaflow.console.biz.shared.view.SystemConfigView; import org.apache.geaflow.console.core.model.config.GeaflowSystemConfig; import org.springframework.stereotype.Component; @Component -public class SystemConfigViewConverter extends NameViewConverter { - - @Override - public void merge(SystemConfigView view, SystemConfigView updateView) { - super.merge(view, updateView); - Optional.ofNullable(updateView.getValue()).ifPresent(view::setValue); - } - - @Override - protected SystemConfigView modelToView(GeaflowSystemConfig model) { - SystemConfigView view = super.modelToView(model); - view.setValue(model.getValue()); - return view; - } - - @Override - protected GeaflowSystemConfig viewToModel(SystemConfigView view) { - GeaflowSystemConfig model = super.viewToModel(view); - model.setTenantId(view.getTenantId()); - model.setValue(view.getValue()); - return model; - } - - public GeaflowSystemConfig convert(SystemConfigView view) { - return viewToModel(view); - } +public class SystemConfigViewConverter + extends NameViewConverter { + + @Override + public void merge(SystemConfigView view, SystemConfigView updateView) { + super.merge(view, updateView); + Optional.ofNullable(updateView.getValue()).ifPresent(view::setValue); + } + + @Override + protected SystemConfigView modelToView(GeaflowSystemConfig model) { + SystemConfigView view = super.modelToView(model); + view.setValue(model.getValue()); + return view; + } + + @Override + protected GeaflowSystemConfig viewToModel(SystemConfigView view) { + GeaflowSystemConfig model = super.viewToModel(view); + model.setTenantId(view.getTenantId()); + model.setValue(view.getValue()); + return model; + } + + public GeaflowSystemConfig convert(SystemConfigView view) { + return viewToModel(view); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/TableViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/TableViewConverter.java index 5e5047547..16c4464c3 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/TableViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/TableViewConverter.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.Optional; + import org.apache.geaflow.console.biz.shared.view.TableView; import org.apache.geaflow.console.common.util.type.GeaflowPluginCategory; import org.apache.geaflow.console.common.util.type.GeaflowStructType; @@ -33,42 +34,41 @@ @Component public class TableViewConverter extends StructViewConverter { - @Autowired - private PluginConfigViewConverter pluginConfigViewConverter; - - - @Override - public void merge(TableView view, TableView updateView) { - super.merge(view, updateView); - Optional.ofNullable(updateView.getFields()).ifPresent(view::setFields); - Optional.ofNullable(updateView.getPluginConfig()).ifPresent(e -> { - // update pluginConfig info - e.setId(view.getPluginConfig().getId()); - e.setCategory(GeaflowPluginCategory.TABLE); - view.setPluginConfig(e); - }); - } - - @Override - protected TableView modelToView(GeaflowTable model) { - TableView tableView = super.modelToView(model); - tableView.setType(GeaflowStructType.TABLE); - tableView.setPluginConfig(pluginConfigViewConverter.convert(model.getPluginConfig())); - return tableView; - } + @Autowired private PluginConfigViewConverter pluginConfigViewConverter; - @Override - protected GeaflowTable viewToModel(TableView view) { - GeaflowTable table = super.viewToModel(view); - table.setType(GeaflowStructType.TABLE); - return table; - } + @Override + public void merge(TableView view, TableView updateView) { + super.merge(view, updateView); + Optional.ofNullable(updateView.getFields()).ifPresent(view::setFields); + Optional.ofNullable(updateView.getPluginConfig()) + .ifPresent( + e -> { + // update pluginConfig info + e.setId(view.getPluginConfig().getId()); + e.setCategory(GeaflowPluginCategory.TABLE); + view.setPluginConfig(e); + }); + } + @Override + protected TableView modelToView(GeaflowTable model) { + TableView tableView = super.modelToView(model); + tableView.setType(GeaflowStructType.TABLE); + tableView.setPluginConfig(pluginConfigViewConverter.convert(model.getPluginConfig())); + return tableView; + } - public GeaflowTable convert(TableView view, List fields, GeaflowPluginConfig pluginConfig) { - GeaflowTable table = super.convert(view, fields); - table.setPluginConfig(pluginConfig); - return table; + @Override + protected GeaflowTable viewToModel(TableView view) { + GeaflowTable table = super.viewToModel(view); + table.setType(GeaflowStructType.TABLE); + return table; + } - } + public GeaflowTable convert( + TableView view, List fields, GeaflowPluginConfig pluginConfig) { + GeaflowTable table = super.convert(view, fields); + table.setPluginConfig(pluginConfig); + return table; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/TaskViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/TaskViewConverter.java index 147ca9fb9..db9ee2bb2 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/TaskViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/TaskViewConverter.java @@ -27,29 +27,29 @@ @Component public class TaskViewConverter extends IdViewConverter { - @Autowired - protected PluginConfigViewConverter pluginConfigViewConverter; - - @Autowired - protected ReleaseViewConverter releaseViewConverter; - - @Override - protected TaskView modelToView(GeaflowTask model) { - TaskView view = super.modelToView(model); - view.setRelease(releaseViewConverter.convert(model.getRelease())); - view.setType(model.getType()); - view.setStatus(model.getStatus()); - view.setStartTime(model.getStartTime()); - view.setEndTime(model.getEndTime()); - view.setRuntimeMetaPluginConfig(pluginConfigViewConverter.modelToView(model.getRuntimeMetaPluginConfig())); - view.setHaMetaPluginConfig(pluginConfigViewConverter.modelToView(model.getHaMetaPluginConfig())); - view.setMetricPluginConfig(pluginConfigViewConverter.modelToView(model.getMetricPluginConfig())); - view.setDataPluginConfig(pluginConfigViewConverter.modelToView(model.getDataPluginConfig())); - return view; - } - - public GeaflowTask convert(TaskView view) { - return viewToModel(view); - } - + @Autowired protected PluginConfigViewConverter pluginConfigViewConverter; + + @Autowired protected ReleaseViewConverter releaseViewConverter; + + @Override + protected TaskView modelToView(GeaflowTask model) { + TaskView view = super.modelToView(model); + view.setRelease(releaseViewConverter.convert(model.getRelease())); + view.setType(model.getType()); + view.setStatus(model.getStatus()); + view.setStartTime(model.getStartTime()); + view.setEndTime(model.getEndTime()); + view.setRuntimeMetaPluginConfig( + pluginConfigViewConverter.modelToView(model.getRuntimeMetaPluginConfig())); + view.setHaMetaPluginConfig( + pluginConfigViewConverter.modelToView(model.getHaMetaPluginConfig())); + view.setMetricPluginConfig( + pluginConfigViewConverter.modelToView(model.getMetricPluginConfig())); + view.setDataPluginConfig(pluginConfigViewConverter.modelToView(model.getDataPluginConfig())); + return view; + } + + public GeaflowTask convert(TaskView view) { + return viewToModel(view); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/TenantViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/TenantViewConverter.java index 7efc01cbe..3c506e847 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/TenantViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/TenantViewConverter.java @@ -26,7 +26,7 @@ @Component public class TenantViewConverter extends NameViewConverter { - public GeaflowTenant convert(TenantView view) { - return viewToModel(view); - } + public GeaflowTenant convert(TenantView view) { + return viewToModel(view); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/UserViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/UserViewConverter.java index e658dfebf..b8a27f319 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/UserViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/UserViewConverter.java @@ -26,24 +26,24 @@ @Component public class UserViewConverter extends NameViewConverter { - @Override - protected UserView modelToView(GeaflowUser model) { - UserView view = super.modelToView(model); - view.setPhone(model.getPhone()); - view.setEmail(model.getEmail()); - return view; - } + @Override + protected UserView modelToView(GeaflowUser model) { + UserView view = super.modelToView(model); + view.setPhone(model.getPhone()); + view.setEmail(model.getEmail()); + return view; + } - @Override - protected GeaflowUser viewToModel(UserView view) { - GeaflowUser model = super.viewToModel(view); - model.setPassword(view.getPassword()); - model.setPhone(view.getPhone()); - model.setEmail(view.getEmail()); - return model; - } + @Override + protected GeaflowUser viewToModel(UserView view) { + GeaflowUser model = super.viewToModel(view); + model.setPassword(view.getPassword()); + model.setPhone(view.getPhone()); + model.setEmail(view.getEmail()); + return model; + } - public GeaflowUser convert(UserView view) { - return viewToModel(view); - } + public GeaflowUser convert(UserView view) { + return viewToModel(view); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/VersionViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/VersionViewConverter.java index f653e1d5b..85ccdd3e1 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/VersionViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/VersionViewConverter.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.biz.shared.convert; import java.util.Optional; + import org.apache.geaflow.console.biz.shared.view.VersionView; import org.apache.geaflow.console.core.model.file.GeaflowRemoteFile; import org.apache.geaflow.console.core.model.version.GeaflowVersion; @@ -29,38 +30,37 @@ @Component public class VersionViewConverter extends NameViewConverter { - @Autowired - protected RemoteFileViewConverter remoteFileViewConverter; + @Autowired protected RemoteFileViewConverter remoteFileViewConverter; - @Override - public void merge(VersionView view, VersionView updateView) { - super.merge(view, updateView); - Optional.ofNullable(updateView.getEngineJarPackage()).ifPresent(view::setEngineJarPackage); - Optional.ofNullable(updateView.getLangJarPackage()).ifPresent(view::setLangJarPackage); - Optional.ofNullable(updateView.getPublish()).ifPresent(view::setPublish); - } + @Override + public void merge(VersionView view, VersionView updateView) { + super.merge(view, updateView); + Optional.ofNullable(updateView.getEngineJarPackage()).ifPresent(view::setEngineJarPackage); + Optional.ofNullable(updateView.getLangJarPackage()).ifPresent(view::setLangJarPackage); + Optional.ofNullable(updateView.getPublish()).ifPresent(view::setPublish); + } - @Override - protected VersionView modelToView(GeaflowVersion model) { - VersionView view = super.modelToView(model); - view.setEngineJarPackage(remoteFileViewConverter.convert(model.getEngineJarPackage())); - view.setLangJarPackage(remoteFileViewConverter.convert(model.getLangJarPackage())); - view.setPublish(model.isPublish()); - return view; - } + @Override + protected VersionView modelToView(GeaflowVersion model) { + VersionView view = super.modelToView(model); + view.setEngineJarPackage(remoteFileViewConverter.convert(model.getEngineJarPackage())); + view.setLangJarPackage(remoteFileViewConverter.convert(model.getLangJarPackage())); + view.setPublish(model.isPublish()); + return view; + } - @Override - protected GeaflowVersion viewToModel(VersionView view) { - GeaflowVersion model = super.viewToModel(view); - Optional.ofNullable(view.getPublish()).ifPresent(model::setPublish); - return model; - } + @Override + protected GeaflowVersion viewToModel(VersionView view) { + GeaflowVersion model = super.viewToModel(view); + Optional.ofNullable(view.getPublish()).ifPresent(model::setPublish); + return model; + } - public GeaflowVersion convert(VersionView view, GeaflowRemoteFile engineJarPackage, - GeaflowRemoteFile langJarPackage) { - GeaflowVersion version = viewToModel(view); - version.setEngineJarPackage(engineJarPackage); - version.setLangJarPackage(langJarPackage); - return version; - } + public GeaflowVersion convert( + VersionView view, GeaflowRemoteFile engineJarPackage, GeaflowRemoteFile langJarPackage) { + GeaflowVersion version = viewToModel(view); + version.setEngineJarPackage(engineJarPackage); + version.setLangJarPackage(langJarPackage); + return version; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/VertexViewConverter.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/VertexViewConverter.java index 936a60341..f1a64a443 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/VertexViewConverter.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/convert/VertexViewConverter.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.Optional; + import org.apache.geaflow.console.biz.shared.view.VertexView; import org.apache.geaflow.console.common.util.type.GeaflowStructType; import org.apache.geaflow.console.core.model.data.GeaflowField; @@ -30,29 +31,29 @@ @Component public class VertexViewConverter extends StructViewConverter { - @Override - protected VertexView modelToView(GeaflowVertex model) { - VertexView vertexView = super.modelToView(model); - vertexView.setType(GeaflowStructType.VERTEX); - return vertexView; - } - - @Override - public void merge(VertexView view, VertexView updateView) { - super.merge(view, updateView); - Optional.ofNullable(updateView.getFields()).ifPresent(view::setFields); - } - - @Override - protected GeaflowVertex viewToModel(VertexView view) { - GeaflowVertex vertex = super.viewToModel(view); - vertex.setType(GeaflowStructType.VERTEX); - return vertex; - } - - public GeaflowVertex converter(VertexView view, List fields) { - GeaflowVertex vertex = viewToModel(view); - vertex.addFields(fields); - return vertex; - } + @Override + protected VertexView modelToView(GeaflowVertex model) { + VertexView vertexView = super.modelToView(model); + vertexView.setType(GeaflowStructType.VERTEX); + return vertexView; + } + + @Override + public void merge(VertexView view, VertexView updateView) { + super.merge(view, updateView); + Optional.ofNullable(updateView.getFields()).ifPresent(view::setFields); + } + + @Override + protected GeaflowVertex viewToModel(VertexView view) { + GeaflowVertex vertex = super.viewToModel(view); + vertex.setType(GeaflowStructType.VERTEX); + return vertex; + } + + public GeaflowVertex converter(VertexView view, List fields) { + GeaflowVertex vertex = viewToModel(view); + vertex.addFields(fields); + return vertex; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/demo/DemoJob.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/demo/DemoJob.java index 962bcdafe..7d2177709 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/demo/DemoJob.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/demo/DemoJob.java @@ -23,5 +23,5 @@ public abstract class DemoJob { - public abstract GeaflowJob build(); + public abstract GeaflowJob build(); } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/demo/SocketDemo.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/demo/SocketDemo.java index 7a3fd9fc1..503918b8f 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/demo/SocketDemo.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/demo/SocketDemo.java @@ -20,23 +20,25 @@ package org.apache.geaflow.console.biz.shared.demo; import java.util.HashMap; -import lombok.extern.slf4j.Slf4j; + import org.apache.geaflow.console.common.util.VelocityUtil; import org.apache.geaflow.console.core.model.job.GeaflowJob; import org.apache.geaflow.console.core.model.job.GeaflowProcessJob; import org.springframework.stereotype.Component; +import lombok.extern.slf4j.Slf4j; + @Component @Slf4j public class SocketDemo extends DemoJob { - private static final String DEMO_JOB_TEMPLATE = "template/demoJob.vm"; + private static final String DEMO_JOB_TEMPLATE = "template/demoJob.vm"; - public GeaflowJob build() { - GeaflowProcessJob job = new GeaflowProcessJob(); - String code = VelocityUtil.applyResource(DEMO_JOB_TEMPLATE, new HashMap<>()); - job.setUserCode(code); - job.setName("demoJob"); - return job; - } + public GeaflowJob build() { + GeaflowProcessJob job = new GeaflowProcessJob(); + String code = VelocityUtil.applyResource(DEMO_JOB_TEMPLATE, new HashMap<>()); + job.setUserCode(code); + job.setName("demoJob"); + return job; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/AuditManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/AuditManagerImpl.java index e419982b9..fdcf37137 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/AuditManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/AuditManagerImpl.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.biz.shared.impl; import java.util.List; + import org.apache.geaflow.console.biz.shared.AuditManager; import org.apache.geaflow.console.biz.shared.convert.AuditViewConverter; import org.apache.geaflow.console.biz.shared.convert.IdViewConverter; @@ -34,27 +35,25 @@ import org.springframework.stereotype.Service; @Service -public class AuditManagerImpl extends IdManagerImpl implements AuditManager { - - @Autowired - private AuditService auditService; +public class AuditManagerImpl extends IdManagerImpl + implements AuditManager { - @Autowired - private AuditViewConverter auditViewConverter; + @Autowired private AuditService auditService; - @Override - public IdViewConverter getConverter() { - return auditViewConverter; - } + @Autowired private AuditViewConverter auditViewConverter; - @Override - public IdService getService() { - return auditService; - } + @Override + public IdViewConverter getConverter() { + return auditViewConverter; + } - @Override - public List parse(List views) { - return ListUtil.convert(views, auditViewConverter::convert); - } + @Override + public IdService getService() { + return auditService; + } + @Override + public List parse(List views) { + return ListUtil.convert(views, auditViewConverter::convert); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/AuthenticationManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/AuthenticationManagerImpl.java index 3eeae65d2..2588f82a8 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/AuthenticationManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/AuthenticationManagerImpl.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.biz.shared.impl; import java.util.Set; + import org.apache.geaflow.console.biz.shared.AuthenticationManager; import org.apache.geaflow.console.biz.shared.TenantManager; import org.apache.geaflow.console.biz.shared.UserManager; @@ -38,56 +39,53 @@ @Service public class AuthenticationManagerImpl implements AuthenticationManager { - @Autowired - private AuthenticationService authenticationService; + @Autowired private AuthenticationService authenticationService; - @Autowired - private AuthenticationViewConverter authenticationViewConverter; + @Autowired private AuthenticationViewConverter authenticationViewConverter; - @Autowired - private UserManager userManager; + @Autowired private UserManager userManager; - @Autowired - private TenantManager tenantManager; + @Autowired private TenantManager tenantManager; - @Override - public AuthenticationView login(String loginName, String password, boolean systemLogin) { - GeaflowAuthentication authentication = authenticationService.login(loginName, password, systemLogin); - return authenticationViewConverter.convert(authentication); - } + @Override + public AuthenticationView login(String loginName, String password, boolean systemLogin) { + GeaflowAuthentication authentication = + authenticationService.login(loginName, password, systemLogin); + return authenticationViewConverter.convert(authentication); + } - @Override - public AuthenticationView authenticate(String token) { - GeaflowAuthentication authentication = authenticationService.authenticate(token); - return authenticationViewConverter.convert(authentication); - } + @Override + public AuthenticationView authenticate(String token) { + GeaflowAuthentication authentication = authenticationService.authenticate(token); + return authenticationViewConverter.convert(authentication); + } - @Override - public SessionView currentSession() { - String userId = ContextHolder.get().getUserId(); - String tenantId = ContextHolder.get().getTenantId(); - Set roleTypes = ContextHolder.get().getRoleTypes(); + @Override + public SessionView currentSession() { + String userId = ContextHolder.get().getUserId(); + String tenantId = ContextHolder.get().getTenantId(); + Set roleTypes = ContextHolder.get().getRoleTypes(); - UserView user = userManager.getUser(userId); - TenantView tenant = tenantId == null ? null : tenantManager.get(tenantId); - GeaflowAuthentication authentication = authenticationService.getAuthenticationByUserId(userId); + UserView user = userManager.getUser(userId); + TenantView tenant = tenantId == null ? null : tenantManager.get(tenantId); + GeaflowAuthentication authentication = authenticationService.getAuthenticationByUserId(userId); - SessionView session = new SessionView(); - session.setUser(user); - session.setTenant(tenant); - session.setAuthentication(authenticationViewConverter.convert(authentication)); - session.setRoleTypes(roleTypes); - return session; - } + SessionView session = new SessionView(); + session.setUser(user); + session.setTenant(tenant); + session.setAuthentication(authenticationViewConverter.convert(authentication)); + session.setRoleTypes(roleTypes); + return session; + } - @Override - public boolean switchSession() { - return authenticationService.switchSession(); - } + @Override + public boolean switchSession() { + return authenticationService.switchSession(); + } - @Override - public boolean logout() { - String token = ContextHolder.get().getSessionToken(); - return authenticationService.logout(token); - } + @Override + public boolean logout() { + String token = ContextHolder.get().getSessionToken(); + return authenticationService.logout(token); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/AuthorizationManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/AuthorizationManagerImpl.java index e0a07b1ea..5dbc77833 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/AuthorizationManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/AuthorizationManagerImpl.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.Set; import java.util.stream.Collectors; + import org.apache.geaflow.console.biz.shared.AuthorizationManager; import org.apache.geaflow.console.biz.shared.convert.AuthorizationViewConverter; import org.apache.geaflow.console.biz.shared.convert.IdViewConverter; @@ -48,126 +49,135 @@ import org.springframework.stereotype.Service; @Service -public class AuthorizationManagerImpl extends - IdManagerImpl implements AuthorizationManager { +public class AuthorizationManagerImpl + extends IdManagerImpl + implements AuthorizationManager { - @Autowired - private AuthorizationService authorizationService; + @Autowired private AuthorizationService authorizationService; - @Autowired - private AuthorizationViewConverter authorizationViewConverter; + @Autowired private AuthorizationViewConverter authorizationViewConverter; - @Autowired - private UserService userService; + @Autowired private UserService userService; - @Autowired - private ResourceFactory resourceFactory; + @Autowired private ResourceFactory resourceFactory; - @Override - public List getUserRoleTypes(String userId) { - String tenantId = ContextHolder.get().getTenantId(); - return authorizationService.getUserRoleTypes(tenantId, userId); - } + @Override + public List getUserRoleTypes(String userId) { + String tenantId = ContextHolder.get().getTenantId(); + return authorizationService.getUserRoleTypes(tenantId, userId); + } - @Override - public void hasRole(GeaflowRole... roles) throws GeaflowSecurityException { - if (!validateRoles(ContextHolder.get().getRoleTypes(), roles)) { - throw new GeaflowSecurityException("User has no role of {}", - Arrays.stream(roles).map(GeaflowRole::getType).collect(Collectors.toList())); - } + @Override + public void hasRole(GeaflowRole... roles) throws GeaflowSecurityException { + if (!validateRoles(ContextHolder.get().getRoleTypes(), roles)) { + throw new GeaflowSecurityException( + "User has no role of {}", + Arrays.stream(roles).map(GeaflowRole::getType).collect(Collectors.toList())); } - - @Override - public void hasAuthority(GeaflowAuthority authority, GeaflowResource resource) throws GeaflowSecurityException { - String userId = ContextHolder.get().getUserId(); - Set roleTypes = ContextHolder.get().getRoleTypes(); - if (!validateAuthority(userId, roleTypes, authority, resource)) { - throw new GeaflowSecurityException("User has no authority {} of {} {}", authority.getType(), - resource.getType(), resource.getId()); - } - } - - protected boolean validateRoles(Set roleTypes, GeaflowRole[] expectRoles) { - for (GeaflowRole expectRole : expectRoles) { - while (expectRole != null) { - if (roleTypes.contains(expectRole.getType())) { - return true; - } - expectRole = expectRole.getParent(); - } - } - - return false; + } + + @Override + public void hasAuthority(GeaflowAuthority authority, GeaflowResource resource) + throws GeaflowSecurityException { + String userId = ContextHolder.get().getUserId(); + Set roleTypes = ContextHolder.get().getRoleTypes(); + if (!validateAuthority(userId, roleTypes, authority, resource)) { + throw new GeaflowSecurityException( + "User has no authority {} of {} {}", + authority.getType(), + resource.getType(), + resource.getId()); } + } - protected boolean validateAuthority(String userId, Set roleTypes, GeaflowAuthority authority, - GeaflowResource resource) { - // check role - GeaflowGrant expectGrant = new GeaflowGrant(authority, resource); - for (GeaflowGrant roleGrant : GeaflowRole.getGrants(roleTypes)) { - if (roleGrant.include(expectGrant)) { - return true; - } + protected boolean validateRoles(Set roleTypes, GeaflowRole[] expectRoles) { + for (GeaflowRole expectRole : expectRoles) { + while (expectRole != null) { + if (roleTypes.contains(expectRole.getType())) { + return true; } - - // check authorization - while (resource != null) { - if (authorizationService.exist(userId, authority, resource)) { - return true; - } - - resource = resource.getParent(); - } - - return false; + expectRole = expectRole.getParent(); + } } - @Override - protected IdService getService() { - return authorizationService; + return false; + } + + protected boolean validateAuthority( + String userId, + Set roleTypes, + GeaflowAuthority authority, + GeaflowResource resource) { + // check role + GeaflowGrant expectGrant = new GeaflowGrant(authority, resource); + for (GeaflowGrant roleGrant : GeaflowRole.getGrants(roleTypes)) { + if (roleGrant.include(expectGrant)) { + return true; + } } - @Override - protected IdViewConverter getConverter() { - return authorizationViewConverter; - } + // check authorization + while (resource != null) { + if (authorizationService.exist(userId, authority, resource)) { + return true; + } - @Override - protected List parse(List views) { - return ListUtil.convert(views, e -> authorizationViewConverter.convert(e)); + resource = resource.getParent(); } - @Override - public List create(List views) { - for (AuthorizationView view : views) { - String userId = view.getUserId(); - GeaflowAuthority authority = GeaflowAuthority.of(view.getAuthorityType()); - GeaflowResource resource = resourceFactory.build(view.getResourceType(), view.getResourceId()); - - // check current user has the authority - hasAuthority(authority, resource); - // check the authored user has the authority - Set roleTypes = new HashSet<>(getUserRoleTypes(userId)); - if (validateAuthority(userId, roleTypes, authority, resource)) { - String userName = userService.getUserNames(Collections.singleton(userId)).get(userId); - throw new GeaflowException("User {} already has the authority of resource {} {}", userName, - resource.getType(), resource.getId()); - } - } - return super.create(views); + return false; + } + + @Override + protected IdService getService() { + return authorizationService; + } + + @Override + protected IdViewConverter getConverter() { + return authorizationViewConverter; + } + + @Override + protected List parse(List views) { + return ListUtil.convert(views, e -> authorizationViewConverter.convert(e)); + } + + @Override + public List create(List views) { + for (AuthorizationView view : views) { + String userId = view.getUserId(); + GeaflowAuthority authority = GeaflowAuthority.of(view.getAuthorityType()); + GeaflowResource resource = + resourceFactory.build(view.getResourceType(), view.getResourceId()); + + // check current user has the authority + hasAuthority(authority, resource); + // check the authored user has the authority + Set roleTypes = new HashSet<>(getUserRoleTypes(userId)); + if (validateAuthority(userId, roleTypes, authority, resource)) { + String userName = userService.getUserNames(Collections.singleton(userId)).get(userId); + throw new GeaflowException( + "User {} already has the authority of resource {} {}", + userName, + resource.getType(), + resource.getId()); + } } - - @Override - public boolean drop(List ids) { - List authorizations = authorizationService.get(ids); - for (GeaflowAuthorization authorization : authorizations) { - GeaflowAuthority authority = GeaflowAuthority.of(authorization.getAuthorityType()); - GeaflowResource resource = resourceFactory.build(authorization.getResourceType(), - authorization.getResourceId()); - - // check current user has the authority - hasAuthority(authority, resource); - } - return super.drop(ids); + return super.create(views); + } + + @Override + public boolean drop(List ids) { + List authorizations = authorizationService.get(ids); + for (GeaflowAuthorization authorization : authorizations) { + GeaflowAuthority authority = GeaflowAuthority.of(authorization.getAuthorityType()); + GeaflowResource resource = + resourceFactory.build(authorization.getResourceType(), authorization.getResourceId()); + + // check current user has the authority + hasAuthority(authority, resource); } + return super.drop(ids); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/ChatManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/ChatManagerImpl.java index 2d99430bf..8f6784a8b 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/ChatManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/ChatManagerImpl.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.biz.shared.impl; import java.util.List; + import org.apache.geaflow.console.biz.shared.ChatManager; import org.apache.geaflow.console.biz.shared.convert.ChatViewConverter; import org.apache.geaflow.console.biz.shared.convert.IdViewConverter; @@ -34,50 +35,46 @@ import org.springframework.stereotype.Service; @Service -public class ChatManagerImpl extends IdManagerImpl implements ChatManager { - - @Autowired - private ChatService chatService; - - @Autowired - private ChatViewConverter chatViewConverter; - - @Override - protected IdService getService() { - return chatService; - } +public class ChatManagerImpl extends IdManagerImpl + implements ChatManager { - @Override - protected IdViewConverter getConverter() { - return chatViewConverter; - } + @Autowired private ChatService chatService; + @Autowired private ChatViewConverter chatViewConverter; - @Override - protected List parse(List views) { - return ListUtil.convert(views, chatViewConverter::convert); - } + @Override + protected IdService getService() { + return chatService; + } - @Override - public String callASync(ChatView view, boolean withSchema) { - view.setStatus(GeaflowStatementStatus.RUNNING); - String id = super.create(view); - GeaflowChat chat = chatViewConverter.convert(view); - chatService.callASync(chat, withSchema); + @Override + protected IdViewConverter getConverter() { + return chatViewConverter; + } - return id; - } + @Override + protected List parse(List views) { + return ListUtil.convert(views, chatViewConverter::convert); + } - @Override - public String callSync(ChatView chatView, boolean saveRecord, boolean withSchema) { - GeaflowChat geaflowChat = chatViewConverter.convert(chatView); - return chatService.callSync(geaflowChat, saveRecord, withSchema); - } + @Override + public String callASync(ChatView view, boolean withSchema) { + view.setStatus(GeaflowStatementStatus.RUNNING); + String id = super.create(view); + GeaflowChat chat = chatViewConverter.convert(view); + chatService.callASync(chat, withSchema); - @Override - public boolean dropByJobId(String jobId) { - return chatService.dropByJobId(jobId); - } + return id; + } + @Override + public String callSync(ChatView chatView, boolean saveRecord, boolean withSchema) { + GeaflowChat geaflowChat = chatViewConverter.convert(chatView); + return chatService.callSync(geaflowChat, saveRecord, withSchema); + } + @Override + public boolean dropByJobId(String jobId) { + return chatService.dropByJobId(jobId); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/ClusterManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/ClusterManagerImpl.java index a55d66e86..da6938426 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/ClusterManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/ClusterManagerImpl.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.biz.shared.impl; import java.util.List; + import org.apache.geaflow.console.biz.shared.ClusterManager; import org.apache.geaflow.console.biz.shared.convert.ClusterViewConverter; import org.apache.geaflow.console.biz.shared.convert.NameViewConverter; @@ -34,36 +35,34 @@ import org.springframework.stereotype.Service; @Service -public class ClusterManagerImpl extends NameManagerImpl implements - ClusterManager { - - @Autowired - private ClusterService clusterService; +public class ClusterManagerImpl extends NameManagerImpl + implements ClusterManager { - @Autowired - private ClusterViewConverter clusterViewConverter; + @Autowired private ClusterService clusterService; - @Override - protected NameService getService() { - return clusterService; - } + @Autowired private ClusterViewConverter clusterViewConverter; - @Override - protected NameViewConverter getConverter() { - return clusterViewConverter; - } + @Override + protected NameService getService() { + return clusterService; + } - @Override - protected List parse(List views) { - return ListUtil.convert(views, clusterViewConverter::convert); - } + @Override + protected NameViewConverter getConverter() { + return clusterViewConverter; + } - @Override - public String create(ClusterView view) { - if (clusterService.existName(view.getName())) { - throw new GeaflowIllegalException("Cluster name {} exists", view.getName()); - } + @Override + protected List parse(List views) { + return ListUtil.convert(views, clusterViewConverter::convert); + } - return super.create(view); + @Override + public String create(ClusterView view) { + if (clusterService.existName(view.getName())) { + throw new GeaflowIllegalException("Cluster name {} exists", view.getName()); } + + return super.create(view); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/ConfigManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/ConfigManagerImpl.java index da797b4da..8b3f39fb0 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/ConfigManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/ConfigManagerImpl.java @@ -19,10 +19,10 @@ package org.apache.geaflow.console.biz.shared.impl; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.console.biz.shared.ConfigManager; import org.apache.geaflow.console.common.util.ListUtil; import org.apache.geaflow.console.common.util.exception.GeaflowIllegalException; @@ -38,84 +38,86 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; +import com.google.common.collect.Lists; + @Service public class ConfigManagerImpl implements ConfigManager { - @Autowired - private DeployConfig deployConfig; + @Autowired private DeployConfig deployConfig; - @Autowired - private PluginService pluginService; + @Autowired private PluginService pluginService; - @Override - public List getClusterConfig() { - return ConfigDescFactory.getOrRegister(ClusterConfigClass.class).getItems(); - } + @Override + public List getClusterConfig() { + return ConfigDescFactory.getOrRegister(ClusterConfigClass.class).getItems(); + } - @Override - public List getJobConfig() { - return ConfigDescFactory.getOrRegister(JobConfigClass.class).getItems(); - } + @Override + public List getJobConfig() { + return ConfigDescFactory.getOrRegister(JobConfigClass.class).getItems(); + } - @Override - public List getPluginCategories() { - return Lists.newArrayList(GeaflowPluginCategory.values()); - } + @Override + public List getPluginCategories() { + return Lists.newArrayList(GeaflowPluginCategory.values()); + } - @Override - public List getPluginCategoryTypes(GeaflowPluginCategory category) { - List types = new ArrayList<>(); - switch (category) { - case TABLE: - types.add(GeaflowPluginType.FILE.name()); - types.add(GeaflowPluginType.KAFKA.name()); - types.add(GeaflowPluginType.HIVE.name()); - types.add(GeaflowPluginType.SOCKET.name()); - List plugins = pluginService.getPlugins(category); - types.addAll(ListUtil.convert(plugins, GeaflowPlugin::getType)); - break; - case GRAPH: - types.add(GeaflowPluginType.MEMORY.name()); - types.add(GeaflowPluginType.ROCKSDB.name()); - break; - case RUNTIME_CLUSTER: - if (deployConfig.isLocalMode()) { - types.add(GeaflowPluginType.CONTAINER.name()); - } - types.add(GeaflowPluginType.K8S.name()); - types.add(GeaflowPluginType.RAY.name()); - break; - case RUNTIME_META: - types.add(GeaflowPluginType.JDBC.name()); - break; - case HA_META: - types.add(GeaflowPluginType.REDIS.name()); - break; - case METRIC: - types.add(GeaflowPluginType.INFLUXDB.name()); - break; - case REMOTE_FILE: - case DATA: - if (deployConfig.isLocalMode()) { - types.add(GeaflowPluginType.LOCAL.name()); - } - types.add(GeaflowPluginType.DFS.name()); - types.add(GeaflowPluginType.OSS.name()); - break; - default: - throw new GeaflowIllegalException("Unknown category {}", category); + @Override + public List getPluginCategoryTypes(GeaflowPluginCategory category) { + List types = new ArrayList<>(); + switch (category) { + case TABLE: + types.add(GeaflowPluginType.FILE.name()); + types.add(GeaflowPluginType.KAFKA.name()); + types.add(GeaflowPluginType.HIVE.name()); + types.add(GeaflowPluginType.SOCKET.name()); + List plugins = pluginService.getPlugins(category); + types.addAll(ListUtil.convert(plugins, GeaflowPlugin::getType)); + break; + case GRAPH: + types.add(GeaflowPluginType.MEMORY.name()); + types.add(GeaflowPluginType.ROCKSDB.name()); + break; + case RUNTIME_CLUSTER: + if (deployConfig.isLocalMode()) { + types.add(GeaflowPluginType.CONTAINER.name()); } - return types.stream().distinct().collect(Collectors.toList()); - } - - @Override - public List getPluginConfig(GeaflowPluginCategory category, String type) { - if (!getPluginCategoryTypes(category).contains(type)) { - throw new GeaflowIllegalException("Plugin type {} not supported by category {}", type, category); + types.add(GeaflowPluginType.K8S.name()); + types.add(GeaflowPluginType.RAY.name()); + break; + case RUNTIME_META: + types.add(GeaflowPluginType.JDBC.name()); + break; + case HA_META: + types.add(GeaflowPluginType.REDIS.name()); + break; + case METRIC: + types.add(GeaflowPluginType.INFLUXDB.name()); + break; + case REMOTE_FILE: + case DATA: + if (deployConfig.isLocalMode()) { + types.add(GeaflowPluginType.LOCAL.name()); } + types.add(GeaflowPluginType.DFS.name()); + types.add(GeaflowPluginType.OSS.name()); + break; + default: + throw new GeaflowIllegalException("Unknown category {}", category); + } + return types.stream().distinct().collect(Collectors.toList()); + } - GeaflowPluginType geaflowPluginType = GeaflowPluginType.of(type); - return geaflowPluginType == GeaflowPluginType.None ? new ArrayList<>() : - ConfigDescFactory.get(geaflowPluginType).getItems(); + @Override + public List getPluginConfig(GeaflowPluginCategory category, String type) { + if (!getPluginCategoryTypes(category).contains(type)) { + throw new GeaflowIllegalException( + "Plugin type {} not supported by category {}", type, category); } + + GeaflowPluginType geaflowPluginType = GeaflowPluginType.of(type); + return geaflowPluginType == GeaflowPluginType.None + ? new ArrayList<>() + : ConfigDescFactory.get(geaflowPluginType).getItems(); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/DataManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/DataManagerImpl.java index 21b4bcdfe..fb29c576f 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/DataManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/DataManagerImpl.java @@ -19,10 +19,10 @@ package org.apache.geaflow.console.biz.shared.impl; -import com.google.common.base.Preconditions; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; + import org.apache.commons.collections.CollectionUtils; import org.apache.geaflow.console.biz.shared.DataManager; import org.apache.geaflow.console.biz.shared.convert.DataViewConverter; @@ -37,126 +37,128 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.transaction.annotation.Transactional; -public abstract class DataManagerImpl extends - NameManagerImpl implements DataManager { - - @Autowired - private InstanceService instanceService; - - @Autowired - private DataManager dataManager; - - @Override - protected abstract DataService getService(); - - @Override - protected abstract DataViewConverter getConverter(); - - public V getByName(String instanceName, String name) { - // by instanceName - List models = getByNames(instanceName, Collections.singletonList(name)); - return models.isEmpty() ? null : models.get(0); - } - - public List getByNames(String instanceName, List names) { - String instanceId = getInstanceIdByName(instanceName); - List models = getService().getByNames(instanceId, names); - return build(models); - } - - public boolean dropByName(String instanceName, String name) { - return dataManager.dropByNames(instanceName, Collections.singletonList(name)); - } - - @Transactional - public boolean dropByNames(String instanceName, List names) { - String instanceId = getInstanceIdByName(instanceName); - return getService().dropByNames(instanceId, names); - } - - public String create(String instanceName, V view) { - return dataManager.create(instanceName, Collections.singletonList(view)).get(0); - } - - @Transactional - public List create(String instanceName, List views) { - List models = parse(views); - // set instanceId - String instanceId = getInstanceIdByName(instanceName); - for (M model : models) { - model.setInstanceId(instanceId); - } - - List ids = getService().create(models); - - for (int i = 0; i < ids.size(); i++) { - views.get(i).setId(ids.get(i)); - } - return ids; - } - - @Override - public boolean updateByName(String name, V view) { - throw new GeaflowException("Use updateByName(instanceName, name, view) instead"); - } - - @Override - public boolean updateByName(String instanceName, String name, V view) { - String instanceId = getInstanceIdByName(instanceName); - String id = getService().getIdByName(instanceId, name); - Preconditions.checkNotNull(id, "Invalid name %s in instance %s", name, instanceName); - return updateById(id, view); - } - - @Override - public boolean update(String instanceName, List views) { - String instanceId = getInstanceIdByName(instanceName); - for (V view : views) { - String id = getService().getIdByName(instanceId, view.getName()); - Preconditions.checkNotNull(id, "Invalid name %s in instance %s", view.getName(), instanceName); - view.setId(id); - } +import com.google.common.base.Preconditions; - return dataManager.update(views); +public abstract class DataManagerImpl< + M extends GeaflowData, V extends DataView, S extends DataSearch> + extends NameManagerImpl implements DataManager { + + @Autowired private InstanceService instanceService; + + @Autowired private DataManager dataManager; + + @Override + protected abstract DataService getService(); + + @Override + protected abstract DataViewConverter getConverter(); + + public V getByName(String instanceName, String name) { + // by instanceName + List models = getByNames(instanceName, Collections.singletonList(name)); + return models.isEmpty() ? null : models.get(0); + } + + public List getByNames(String instanceName, List names) { + String instanceId = getInstanceIdByName(instanceName); + List models = getService().getByNames(instanceId, names); + return build(models); + } + + public boolean dropByName(String instanceName, String name) { + return dataManager.dropByNames(instanceName, Collections.singletonList(name)); + } + + @Transactional + public boolean dropByNames(String instanceName, List names) { + String instanceId = getInstanceIdByName(instanceName); + return getService().dropByNames(instanceId, names); + } + + public String create(String instanceName, V view) { + return dataManager.create(instanceName, Collections.singletonList(view)).get(0); + } + + @Transactional + public List create(String instanceName, List views) { + List models = parse(views); + // set instanceId + String instanceId = getInstanceIdByName(instanceName); + for (M model : models) { + model.setInstanceId(instanceId); } - @Override - public List getByNames(List names) { - throw new GeaflowIllegalException("Instance id is needed"); - } + List ids = getService().create(models); - @Override - public List create(List views) { - throw new GeaflowIllegalException("Instance id is needed"); + for (int i = 0; i < ids.size(); i++) { + views.get(i).setId(ids.get(i)); } - - @Override - public boolean dropByName(String name) { - throw new GeaflowIllegalException("Instance id is needed"); + return ids; + } + + @Override + public boolean updateByName(String name, V view) { + throw new GeaflowException("Use updateByName(instanceName, name, view) instead"); + } + + @Override + public boolean updateByName(String instanceName, String name, V view) { + String instanceId = getInstanceIdByName(instanceName); + String id = getService().getIdByName(instanceId, name); + Preconditions.checkNotNull(id, "Invalid name %s in instance %s", name, instanceName); + return updateById(id, view); + } + + @Override + public boolean update(String instanceName, List views) { + String instanceId = getInstanceIdByName(instanceName); + for (V view : views) { + String id = getService().getIdByName(instanceId, view.getName()); + Preconditions.checkNotNull( + id, "Invalid name %s in instance %s", view.getName(), instanceName); + view.setId(id); } - @Override - public PageList searchByInstanceName(String instanceName, S search) { - String instanceId = getInstanceIdByName(instanceName); - search.setInstanceId(instanceId); - return search(search); + return dataManager.update(views); + } + + @Override + public List getByNames(List names) { + throw new GeaflowIllegalException("Instance id is needed"); + } + + @Override + public List create(List views) { + throw new GeaflowIllegalException("Instance id is needed"); + } + + @Override + public boolean dropByName(String name) { + throw new GeaflowIllegalException("Instance id is needed"); + } + + @Override + public PageList searchByInstanceName(String instanceName, S search) { + String instanceId = getInstanceIdByName(instanceName); + search.setInstanceId(instanceId); + return search(search); + } + + @Override + public void createIfIdAbsent(String instanceName, List views) { + if (CollectionUtils.isEmpty(views)) { + return; } - @Override - public void createIfIdAbsent(String instanceName, List views) { - if (CollectionUtils.isEmpty(views)) { - return; - } - - List filtered = views.stream().filter(e -> e.getId() == null).collect(Collectors.toList()); - this.create(instanceName, filtered); - } + List filtered = views.stream().filter(e -> e.getId() == null).collect(Collectors.toList()); + this.create(instanceName, filtered); + } - protected String getInstanceIdByName(String instanceName) { - String id = instanceService.getIdByName(instanceName); - if (id == null) { - throw new GeaflowException("Instance name {} not found", instanceName); - } - return id; + protected String getInstanceIdByName(String instanceName) { + String id = instanceService.getIdByName(instanceName); + if (id == null) { + throw new GeaflowException("Instance name {} not found", instanceName); } + return id; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/EdgeManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/EdgeManagerImpl.java index 87b49e6aa..611f03fa4 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/EdgeManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/EdgeManagerImpl.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.console.biz.shared.EdgeManager; import org.apache.geaflow.console.biz.shared.convert.DataViewConverter; import org.apache.geaflow.console.biz.shared.convert.EdgeViewConverter; @@ -37,34 +38,34 @@ import org.springframework.stereotype.Service; @Service -public class EdgeManagerImpl extends DataManagerImpl implements EdgeManager { - - @Autowired - private EdgeService edgeService; - - @Autowired - private EdgeViewConverter edgeViewConverter; +public class EdgeManagerImpl extends DataManagerImpl + implements EdgeManager { - @Autowired - private FieldViewConverter fieldViewConverter; + @Autowired private EdgeService edgeService; - @Override - public DataViewConverter getConverter() { - return edgeViewConverter; - } + @Autowired private EdgeViewConverter edgeViewConverter; - @Override - public DataService getService() { - return edgeService; - } + @Autowired private FieldViewConverter fieldViewConverter; + @Override + public DataViewConverter getConverter() { + return edgeViewConverter; + } - @Override - protected List parse(List views) { - return views.stream().map(e -> { - List fields = ListUtil.convert(e.getFields(), fieldViewConverter::convert); - return edgeViewConverter.converter(e, fields); - }).collect(Collectors.toList()); - } + @Override + public DataService getService() { + return edgeService; + } + @Override + protected List parse(List views) { + return views.stream() + .map( + e -> { + List fields = + ListUtil.convert(e.getFields(), fieldViewConverter::convert); + return edgeViewConverter.converter(e, fields); + }) + .collect(Collectors.toList()); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/FunctionManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/FunctionManagerImpl.java index eaa064a57..dccb9c699 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/FunctionManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/FunctionManagerImpl.java @@ -21,11 +21,10 @@ import static org.apache.geaflow.console.core.service.RemoteFileService.JAR_FILE_SUFFIX; -import com.google.common.base.Preconditions; import java.util.List; import java.util.Map; import java.util.stream.Collectors; -import lombok.extern.slf4j.Slf4j; + import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.biz.shared.FunctionManager; @@ -54,156 +53,167 @@ import org.springframework.transaction.annotation.Transactional; import org.springframework.web.multipart.MultipartFile; -@Service -@Slf4j -public class FunctionManagerImpl extends DataManagerImpl implements - FunctionManager { - - @Autowired - private FunctionViewConverter functionViewConverter; - - @Autowired - private FunctionService functionService; - - @Autowired - private JobService jobService; +import com.google.common.base.Preconditions; - @Autowired - private RemoteFileService remoteFileService; +import lombok.extern.slf4j.Slf4j; - @Autowired - private RemoteFileManager remoteFileManager; +@Service +@Slf4j +public class FunctionManagerImpl + extends DataManagerImpl + implements FunctionManager { + + @Autowired private FunctionViewConverter functionViewConverter; + + @Autowired private FunctionService functionService; + + @Autowired private JobService jobService; + + @Autowired private RemoteFileService remoteFileService; + + @Autowired private RemoteFileManager remoteFileManager; + + @Override + protected DataViewConverter getConverter() { + return functionViewConverter; + } + + @Override + protected DataService getService() { + return functionService; + } + + @Override + protected List parse(List views) { + List packageIds = + views.stream().map(e -> e.getJarPackage().getId()).collect(Collectors.toList()); + List jarPackages = remoteFileService.get(packageIds); + Map map = + jarPackages.stream().collect(Collectors.toMap(GeaflowId::getId, e -> e)); + + return views.stream() + .map( + e -> { + GeaflowRemoteFile jarPackage = map.get(e.getJarPackage().getId()); + return functionViewConverter.convert(e, jarPackage); + }) + .collect(Collectors.toList()); + } + + @Override + @Transactional + public String createFunction( + String instanceName, FunctionView functionView, MultipartFile functionFile, String fileId) { + String functionName = functionView.getName(); + if (StringUtils.isBlank(functionName)) { + throw new GeaflowIllegalException("Invalid function name"); + } - @Override - protected DataViewConverter getConverter() { - return functionViewConverter; + if (functionService.existName(functionName)) { + throw new GeaflowIllegalException("Function name {} exists", functionName); } - @Override - protected DataService getService() { - return functionService; + Preconditions.checkNotNull(functionView.getEntryClass(), "Function needs entryClass"); + if (fileId == null) { + Preconditions.checkNotNull(functionFile, "Invalid function file"); + functionView.setJarPackage(createRemoteFile(functionFile)); + } else { + // bind a jar file if jarId is not null + if (!remoteFileService.exist(fileId)) { + throw new GeaflowIllegalException("File {} does not exist", fileId); + } + RemoteFileView remoteFileView = new RemoteFileView(); + remoteFileView.setId(fileId); + functionView.setJarPackage(remoteFileView); } - @Override - protected List parse(List views) { - List packageIds = views.stream().map(e -> e.getJarPackage().getId()).collect(Collectors.toList()); - List jarPackages = remoteFileService.get(packageIds); - Map map = jarPackages.stream() - .collect(Collectors.toMap(GeaflowId::getId, e -> e)); - - return views.stream().map(e -> { - GeaflowRemoteFile jarPackage = map.get(e.getJarPackage().getId()); - return functionViewConverter.convert(e, jarPackage); - }).collect(Collectors.toList()); + return super.create(instanceName, functionView); + } + + @Override + @Transactional + public boolean updateFunction( + String instanceName, + String functionName, + FunctionView updateView, + MultipartFile functionFile) { + FunctionView oldView = getByName(instanceName, functionName); + if (oldView == null) { + throw new GeaflowIllegalException("Function name {} not exists", functionName); } - @Override - @Transactional - public String createFunction(String instanceName, FunctionView functionView, MultipartFile functionFile, String fileId) { - String functionName = functionView.getName(); - if (StringUtils.isBlank(functionName)) { - throw new GeaflowIllegalException("Invalid function name"); - } - - if (functionService.existName(functionName)) { - throw new GeaflowIllegalException("Function name {} exists", functionName); - } - - Preconditions.checkNotNull(functionView.getEntryClass(), "Function needs entryClass"); - if (fileId == null) { - Preconditions.checkNotNull(functionFile, "Invalid function file"); - functionView.setJarPackage(createRemoteFile(functionFile)); - } else { - // bind a jar file if jarId is not null - if (!remoteFileService.exist(fileId)) { - throw new GeaflowIllegalException("File {} does not exist", fileId); - } - RemoteFileView remoteFileView = new RemoteFileView(); - remoteFileView.setId(fileId); - functionView.setJarPackage(remoteFileView); - } - - return super.create(instanceName, functionView); + if (functionFile != null) { + updateView.setJarPackage(updateJarFile(updateView, functionFile)); + } + return updateById(oldView.getId(), updateView); + } + + @Transactional + @Override + public boolean deleteFunction(String instanceName, String functionName) { + String instanceId = getInstanceIdByName(instanceName); + GeaflowFunction function = functionService.getByName(instanceId, functionName); + if (function == null) { + return false; } - @Override - @Transactional - public boolean updateFunction(String instanceName, String functionName, FunctionView updateView, MultipartFile functionFile) { - FunctionView oldView = getByName(instanceName, functionName); - if (oldView == null) { - throw new GeaflowIllegalException("Function name {} not exists", functionName); - } - - if (functionFile != null) { - updateView.setJarPackage(updateJarFile(updateView, functionFile)); - } - return updateById(oldView.getId(), updateView); + // check plugin is used by jobs + List jobIds = + jobService.getJobByResources( + function.getName(), function.getInstanceId(), GeaflowResourceType.FUNCTION); + if (CollectionUtils.isNotEmpty(jobIds)) { + List jobNames = ListUtil.convert(jobIds, e -> jobService.getNameById(e)); + throw new GeaflowException( + "Function {} is used by job: {}", function.getName(), String.join(",", jobNames)); } - @Transactional - @Override - public boolean deleteFunction(String instanceName, String functionName) { - String instanceId = getInstanceIdByName(instanceName); - GeaflowFunction function = functionService.getByName(instanceId, functionName); - if (function == null) { - return false; - } - - // check plugin is used by jobs - List jobIds = jobService.getJobByResources(function.getName(), function.getInstanceId(), - GeaflowResourceType.FUNCTION); - if (CollectionUtils.isNotEmpty(jobIds)) { - List jobNames = ListUtil.convert(jobIds, e -> jobService.getNameById(e)); - throw new GeaflowException("Function {} is used by job: {}", function.getName(), String.join(",", jobNames)); - } - - GeaflowRemoteFile file = function.getJarPackage(); - if (file != null) { - // do not delete if file is used by others - try { - remoteFileManager.deleteRefJar(file.getId(), function.getId(), GeaflowResourceType.FUNCTION); - - } catch (Exception e) { - log.info(" Delete function -> delete file {} failed ", file.getName(), e); - } - } - return super.dropByName(instanceName, functionName); + GeaflowRemoteFile file = function.getJarPackage(); + if (file != null) { + // do not delete if file is used by others + try { + remoteFileManager.deleteRefJar( + file.getId(), function.getId(), GeaflowResourceType.FUNCTION); + + } catch (Exception e) { + log.info(" Delete function -> delete file {} failed ", file.getName(), e); + } } + return super.dropByName(instanceName, functionName); + } + private RemoteFileView createRemoteFile(MultipartFile functionFile) { + if (!StringUtils.endsWith(functionFile.getOriginalFilename(), JAR_FILE_SUFFIX)) { + throw new GeaflowIllegalException("Invalid jar file"); + } - private RemoteFileView createRemoteFile(MultipartFile functionFile) { - if (!StringUtils.endsWith(functionFile.getOriginalFilename(), JAR_FILE_SUFFIX)) { - throw new GeaflowIllegalException("Invalid jar file"); - } + String fileName = functionFile.getOriginalFilename(); + if (remoteFileService.existName(fileName)) { + throw new GeaflowException("FileName {} exists", fileName); + } - String fileName = functionFile.getOriginalFilename(); - if (remoteFileService.existName(fileName)) { - throw new GeaflowException("FileName {} exists", fileName); - } + String path = RemoteFileStorage.getUserFilePath(ContextHolder.get().getUserId(), fileName); - String path = RemoteFileStorage.getUserFilePath(ContextHolder.get().getUserId(), fileName); + RemoteFileView remoteFileView = new RemoteFileView(); + remoteFileView.setName(fileName); + remoteFileView.setPath(path); + remoteFileManager.create(remoteFileView, functionFile); - RemoteFileView remoteFileView = new RemoteFileView(); - remoteFileView.setName(fileName); - remoteFileView.setPath(path); - remoteFileManager.create(remoteFileView, functionFile); + return remoteFileView; + } - return remoteFileView; + private RemoteFileView updateJarFile(FunctionView functionView, MultipartFile multipartFile) { + if (!StringUtils.endsWith(multipartFile.getOriginalFilename(), JAR_FILE_SUFFIX)) { + throw new GeaflowIllegalException("Invalid jar file"); } - private RemoteFileView updateJarFile(FunctionView functionView, MultipartFile multipartFile) { - if (!StringUtils.endsWith(multipartFile.getOriginalFilename(), JAR_FILE_SUFFIX)) { - throw new GeaflowIllegalException("Invalid jar file"); - } - - RemoteFileView remoteFileView = functionView.getJarPackage(); - if (remoteFileView == null) { - return createRemoteFile(multipartFile); + RemoteFileView remoteFileView = functionView.getJarPackage(); + if (remoteFileView == null) { + return createRemoteFile(multipartFile); - } else { - String remoteFileId = remoteFileView.getId(); - remoteFileManager.upload(remoteFileId, multipartFile); - return remoteFileView; - } + } else { + String remoteFileId = remoteFileView.getId(); + remoteFileManager.upload(remoteFileId, multipartFile); + return remoteFileView; } + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/GraphManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/GraphManagerImpl.java index c0a3328f7..73a304ea9 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/GraphManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/GraphManagerImpl.java @@ -19,7 +19,6 @@ package org.apache.geaflow.console.biz.shared.impl; -import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; @@ -28,6 +27,7 @@ import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; + import org.apache.commons.collections.CollectionUtils; import org.apache.geaflow.console.biz.shared.EdgeManager; import org.apache.geaflow.console.biz.shared.GraphManager; @@ -60,166 +60,172 @@ import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; -@Service -public class GraphManagerImpl extends DataManagerImpl implements GraphManager { - - @Autowired - private GraphService graphService; - - @Autowired - private VertexService vertexService; +import com.google.common.base.Preconditions; - @Autowired - private VertexManager vertexManager; +@Service +public class GraphManagerImpl extends DataManagerImpl + implements GraphManager { - @Autowired - private EdgeService edgeService; + @Autowired private GraphService graphService; - @Autowired - private EdgeManager edgeManager; + @Autowired private VertexService vertexService; - @Autowired - private PluginConfigManager pluginConfigManager; + @Autowired private VertexManager vertexManager; - @Autowired - private PluginConfigService pluginConfigService; + @Autowired private EdgeService edgeService; - @Autowired - private GraphViewConverter graphViewConverter; + @Autowired private EdgeManager edgeManager; - @Autowired - private PluginConfigViewConverter pluginConfigViewConverter; + @Autowired private PluginConfigManager pluginConfigManager; - @Override - public DataViewConverter getConverter() { - return graphViewConverter; - } + @Autowired private PluginConfigService pluginConfigService; - @Override - public DataService getService() { - return graphService; - } + @Autowired private GraphViewConverter graphViewConverter; - @Override - public List create(String instanceName, List views) { - for (GraphView g : views) { - g.setVertices(Optional.ofNullable(g.getVertices()).orElse(new ArrayList<>())); - g.setEdges(Optional.ofNullable(g.getEdges()).orElse(new ArrayList<>())); + @Autowired private PluginConfigViewConverter pluginConfigViewConverter; - // create if id is null - vertexManager.createIfIdAbsent(instanceName, g.getVertices()); - edgeManager.createIfIdAbsent(instanceName, g.getEdges()); + @Override + public DataViewConverter getConverter() { + return graphViewConverter; + } - PluginConfigView pluginConfigView = Preconditions.checkNotNull(g.getPluginConfig(), - "Graph pluginConfig is required"); - pluginConfigView.setCategory(GeaflowPluginCategory.GRAPH); - pluginConfigView.setName(Fmt.as("{}-{}-graph-config", instanceName, g.getName())); - pluginConfigManager.create(pluginConfigView); - } + @Override + public DataService getService() { + return graphService; + } - return super.create(instanceName, views); - } + @Override + public List create(String instanceName, List views) { + for (GraphView g : views) { + g.setVertices(Optional.ofNullable(g.getVertices()).orElse(new ArrayList<>())); + g.setEdges(Optional.ofNullable(g.getEdges()).orElse(new ArrayList<>())); - @Override - @Transactional - public boolean updateByName(String instanceName, String name, GraphView view) { - // only support to add vertices and edges - vertexManager.createIfIdAbsent(instanceName, view.getVertices()); - edgeManager.createIfIdAbsent(instanceName, view.getEdges()); + // create if id is null + vertexManager.createIfIdAbsent(instanceName, g.getVertices()); + edgeManager.createIfIdAbsent(instanceName, g.getEdges()); - return super.updateByName(instanceName, name, view); + PluginConfigView pluginConfigView = + Preconditions.checkNotNull(g.getPluginConfig(), "Graph pluginConfig is required"); + pluginConfigView.setCategory(GeaflowPluginCategory.GRAPH); + pluginConfigView.setName(Fmt.as("{}-{}-graph-config", instanceName, g.getName())); + pluginConfigManager.create(pluginConfigView); } - - @Override - protected List parse(List views) { - return ListUtil.convert(views, g -> { - List vertexIds = ListUtil.convert(g.getVertices(), IdView::getId); - List edgeIds = ListUtil.convert(g.getEdges(), IdView::getId); - Map orderMap = getOrderMap(vertexIds, edgeIds); - - // ensure the order of vertices and edges - List vertices = vertexService.get(vertexIds); - vertices = vertices.stream() - .sorted(Comparator.comparing(e -> { - String key = getResourceKey(e.getType(), e.getId()); - return orderMap.get(key); - })).collect(Collectors.toList()); - - List edges = edgeService.get(edgeIds); - edges = edges.stream() - .sorted(Comparator.comparing(e -> { - String key = getResourceKey(e.getType(), e.getId()); - return orderMap.get(key); - })).collect(Collectors.toList()); - - GeaflowPluginConfig pluginConfig = pluginConfigViewConverter.convert(g.getPluginConfig()); - - return graphViewConverter.convert(g, vertices, edges, pluginConfig); + return super.create(instanceName, views); + } + + @Override + @Transactional + public boolean updateByName(String instanceName, String name, GraphView view) { + // only support to add vertices and edges + vertexManager.createIfIdAbsent(instanceName, view.getVertices()); + edgeManager.createIfIdAbsent(instanceName, view.getEdges()); + + return super.updateByName(instanceName, name, view); + } + + @Override + protected List parse(List views) { + return ListUtil.convert( + views, + g -> { + List vertexIds = ListUtil.convert(g.getVertices(), IdView::getId); + List edgeIds = ListUtil.convert(g.getEdges(), IdView::getId); + Map orderMap = getOrderMap(vertexIds, edgeIds); + + // ensure the order of vertices and edges + List vertices = vertexService.get(vertexIds); + vertices = + vertices.stream() + .sorted( + Comparator.comparing( + e -> { + String key = getResourceKey(e.getType(), e.getId()); + return orderMap.get(key); + })) + .collect(Collectors.toList()); + + List edges = edgeService.get(edgeIds); + edges = + edges.stream() + .sorted( + Comparator.comparing( + e -> { + String key = getResourceKey(e.getType(), e.getId()); + return orderMap.get(key); + })) + .collect(Collectors.toList()); + + GeaflowPluginConfig pluginConfig = pluginConfigViewConverter.convert(g.getPluginConfig()); + + return graphViewConverter.convert(g, vertices, edges, pluginConfig); }); - } - - private Map getOrderMap(List vertexIds, List edgeIds) { - Map orderMap = new HashMap<>(); - for (int i = 0; i < vertexIds.size(); i++) { - orderMap.put(getResourceKey(GeaflowStructType.VERTEX, vertexIds.get(i)), i); - } - - for (int i = 0; i < edgeIds.size(); i++) { - orderMap.put(getResourceKey(GeaflowStructType.EDGE, edgeIds.get(i)), i); - } + } - return orderMap; + private Map getOrderMap(List vertexIds, List edgeIds) { + Map orderMap = new HashMap<>(); + for (int i = 0; i < vertexIds.size(); i++) { + orderMap.put(getResourceKey(GeaflowStructType.VERTEX, vertexIds.get(i)), i); } - private String getResourceKey(GeaflowStructType structType, String resourceId) { - return structType.name() + "-" + resourceId; + for (int i = 0; i < edgeIds.size(); i++) { + orderMap.put(getResourceKey(GeaflowStructType.EDGE, edgeIds.get(i)), i); } - @Override - @Transactional - public boolean createEndpoints(String instanceName, String graphName, List views) { - if (CollectionUtils.isEmpty(views)) { - return true; - } + return orderMap; + } - String instanceId = getInstanceIdByName(instanceName); - GeaflowGraph graph = graphService.getByName(instanceId, graphName); - Preconditions.checkNotNull(graph, "Graph %s not exist", graphName); + private String getResourceKey(GeaflowStructType structType, String resourceId) { + return structType.name() + "-" + resourceId; + } - List endpoints = buildEndpoints(instanceId, views); - return graphService.createEndpoints(graph, endpoints); + @Override + @Transactional + public boolean createEndpoints(String instanceName, String graphName, List views) { + if (CollectionUtils.isEmpty(views)) { + return true; } - @Override - public boolean deleteEndpoints(String instanceName, String edgeName, List views) { - String instanceId = getInstanceIdByName(instanceName); - GeaflowGraph graph = graphService.getByName(instanceId, edgeName); - Preconditions.checkNotNull(graph, "Graph %s not exist", edgeName); - List endpoints = buildEndpoints(instanceId, views); - return graphService.deleteEndpoints(graph, endpoints); - } - - @Override - public boolean clean(String instanceName, String graphName) { - String instanceId = getInstanceIdByName(instanceName); - GeaflowGraph graph = graphService.getByName(instanceId, graphName); - return graphService.clean(Collections.singletonList(graph)); - } - - private List buildEndpoints(String instanceId, List endpointViews) { - return ListUtil.convert(endpointViews, e -> { - // check vertex/edge existing and build endpoint models - String edgeName = e.getEdgeName(); - String srcName = e.getSourceName(); - String targetName = e.getTargetName(); - GeaflowEdge edge = edgeService.getByName(instanceId, edgeName); - Preconditions.checkNotNull(edge, "Edge %s not exist", edgeName); - GeaflowVertex srcVertex = vertexService.getByName(instanceId, srcName); - Preconditions.checkNotNull(srcVertex, "Vertex %s not exist", srcName); - GeaflowVertex targetVertex = vertexService.getByName(instanceId, targetName); - Preconditions.checkNotNull(targetVertex, "Vertex %s not exist", targetName); - return new GeaflowEndpoint(edge.getId(), srcVertex.getId(), targetVertex.getId()); + String instanceId = getInstanceIdByName(instanceName); + GeaflowGraph graph = graphService.getByName(instanceId, graphName); + Preconditions.checkNotNull(graph, "Graph %s not exist", graphName); + + List endpoints = buildEndpoints(instanceId, views); + return graphService.createEndpoints(graph, endpoints); + } + + @Override + public boolean deleteEndpoints(String instanceName, String edgeName, List views) { + String instanceId = getInstanceIdByName(instanceName); + GeaflowGraph graph = graphService.getByName(instanceId, edgeName); + Preconditions.checkNotNull(graph, "Graph %s not exist", edgeName); + List endpoints = buildEndpoints(instanceId, views); + return graphService.deleteEndpoints(graph, endpoints); + } + + @Override + public boolean clean(String instanceName, String graphName) { + String instanceId = getInstanceIdByName(instanceName); + GeaflowGraph graph = graphService.getByName(instanceId, graphName); + return graphService.clean(Collections.singletonList(graph)); + } + + private List buildEndpoints( + String instanceId, List endpointViews) { + return ListUtil.convert( + endpointViews, + e -> { + // check vertex/edge existing and build endpoint models + String edgeName = e.getEdgeName(); + String srcName = e.getSourceName(); + String targetName = e.getTargetName(); + GeaflowEdge edge = edgeService.getByName(instanceId, edgeName); + Preconditions.checkNotNull(edge, "Edge %s not exist", edgeName); + GeaflowVertex srcVertex = vertexService.getByName(instanceId, srcName); + Preconditions.checkNotNull(srcVertex, "Vertex %s not exist", srcName); + GeaflowVertex targetVertex = vertexService.getByName(instanceId, targetName); + Preconditions.checkNotNull(targetVertex, "Vertex %s not exist", targetName); + return new GeaflowEndpoint(edge.getId(), srcVertex.getId(), targetVertex.getId()); }); - } + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/IdManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/IdManagerImpl.java index 867c995e2..372645e25 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/IdManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/IdManagerImpl.java @@ -19,9 +19,9 @@ package org.apache.geaflow.console.biz.shared.impl; -import com.google.common.base.Preconditions; import java.util.Collections; import java.util.List; + import org.apache.geaflow.console.biz.shared.IdManager; import org.apache.geaflow.console.biz.shared.convert.IdViewConverter; import org.apache.geaflow.console.biz.shared.view.IdView; @@ -33,86 +33,90 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.transaction.annotation.Transactional; -public abstract class IdManagerImpl implements - IdManager { +import com.google.common.base.Preconditions; - protected abstract IdService getService(); +public abstract class IdManagerImpl + implements IdManager { - protected abstract IdViewConverter getConverter(); + protected abstract IdService getService(); - @Autowired - private IdManager idManager; + protected abstract IdViewConverter getConverter(); - protected V build(M model) { - return getConverter().convert(model); - } + @Autowired private IdManager idManager; - protected List build(List models) { - return ListUtil.convert(models, this::build); - } + protected V build(M model) { + return getConverter().convert(model); + } - protected M parse(V view) { - List views = parse(Collections.singletonList(view)); - return views.isEmpty() ? null : views.get(0); - } + protected List build(List models) { + return ListUtil.convert(models, this::build); + } - protected abstract List parse(List views); + protected M parse(V view) { + List views = parse(Collections.singletonList(view)); + return views.isEmpty() ? null : views.get(0); + } - @Override - public PageList search(S search) { - PageList models = getService().search(search); - return models.transform(this::build); - } + protected abstract List parse(List views); - public V get(String id) { - List list = get(Collections.singletonList(id)); - return list.isEmpty() ? null : list.get(0); - } + @Override + public PageList search(S search) { + PageList models = getService().search(search); + return models.transform(this::build); + } - public String create(V view) { - return idManager.create(Collections.singletonList(view)).get(0); - } + public V get(String id) { + List list = get(Collections.singletonList(id)); + return list.isEmpty() ? null : list.get(0); + } - public boolean updateById(String id, V updateView) { - updateView.setId(id); - return idManager.update(Collections.singletonList(updateView)); - } + public String create(V view) { + return idManager.create(Collections.singletonList(view)).get(0); + } - public boolean drop(String id) { - return idManager.drop(Collections.singletonList(id)); - } + public boolean updateById(String id, V updateView) { + updateView.setId(id); + return idManager.update(Collections.singletonList(updateView)); + } - public List get(List ids) { - List models = getService().get(ids); - return build(models); - } + public boolean drop(String id) { + return idManager.drop(Collections.singletonList(id)); + } - @Transactional - public List create(List views) { - List models = parse(views); - List ids = getService().create(models); + public List get(List ids) { + List models = getService().get(ids); + return build(models); + } - for (int i = 0; i < ids.size(); i++) { - views.get(i).setId(ids.get(i)); - } - return ids; - } - - @Transactional - public boolean update(List updateViews) { - List views = ListUtil.convert(updateViews, e -> { - V oldView = get(e.getId()); - Preconditions.checkNotNull(oldView, "Invalid id {}", e.getId()); - getConverter().merge(oldView, e); - return oldView; - }); - - List models = parse(views); - return getService().update(models); - } + @Transactional + public List create(List views) { + List models = parse(views); + List ids = getService().create(models); - @Transactional - public boolean drop(List ids) { - return getService().drop(ids); + for (int i = 0; i < ids.size(); i++) { + views.get(i).setId(ids.get(i)); } + return ids; + } + + @Transactional + public boolean update(List updateViews) { + List views = + ListUtil.convert( + updateViews, + e -> { + V oldView = get(e.getId()); + Preconditions.checkNotNull(oldView, "Invalid id {}", e.getId()); + getConverter().merge(oldView, e); + return oldView; + }); + + List models = parse(views); + return getService().update(models); + } + + @Transactional + public boolean drop(List ids) { + return getService().drop(ids); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/InstallManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/InstallManagerImpl.java index 8188369a1..b332ab750 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/InstallManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/InstallManagerImpl.java @@ -21,7 +21,7 @@ import java.util.ArrayList; import java.util.List; -import lombok.extern.slf4j.Slf4j; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.biz.shared.InstallManager; import org.apache.geaflow.console.biz.shared.VersionManager; @@ -68,314 +68,307 @@ import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; +import lombok.extern.slf4j.Slf4j; + @Slf4j @Service public class InstallManagerImpl implements InstallManager { - @Autowired - private SystemConfigService systemConfigService; - - @Autowired - private DeployConfig deployConfig; + @Autowired private SystemConfigService systemConfigService; - @Autowired - private InstallViewConverter installViewConverter; + @Autowired private DeployConfig deployConfig; - @Autowired - private DatasourceConfig datasourceConfig; + @Autowired private InstallViewConverter installViewConverter; - @Autowired - private DatasourceService datasourceService; + @Autowired private DatasourceConfig datasourceConfig; - @Autowired - private ClusterService clusterService; + @Autowired private DatasourceService datasourceService; - @Autowired - private PluginConfigService pluginConfigService; + @Autowired private ClusterService clusterService; - @Autowired - private VersionManager versionManager; + @Autowired private PluginConfigService pluginConfigService; - @Autowired - private TokenGenerator tokenGenerator; + @Autowired private VersionManager versionManager; - @Autowired - private JobService jobService; + @Autowired private TokenGenerator tokenGenerator; - @Autowired - private InstanceService instanceService; + @Autowired private JobService jobService; - @Autowired - private UserService userService; + @Autowired private InstanceService instanceService; - @Autowired - private TenantService tenantService; + @Autowired private UserService userService; - private final List demoJobs; + @Autowired private TenantService tenantService; - @Autowired - public InstallManagerImpl(List demoJobs) { - this.demoJobs = demoJobs; - } + private final List demoJobs; + @Autowired + public InstallManagerImpl(List demoJobs) { + this.demoJobs = demoJobs; + } - private interface ConfigBuilder { + private interface ConfigBuilder { - PluginConfigClass configRuntimeCluster(); + PluginConfigClass configRuntimeCluster(); - PluginConfigClass configRuntimeMeta(); + PluginConfigClass configRuntimeMeta(); - PluginConfigClass configHaMeta(); + PluginConfigClass configHaMeta(); - PluginConfigClass configMetric(); + PluginConfigClass configMetric(); - PluginConfigClass configRemoteFile(); + PluginConfigClass configRemoteFile(); - PluginConfigClass configData(); + PluginConfigClass configData(); + } - } + @Override + public InstallView get() { + GeaflowInstall install = new GeaflowInstall(); - @Override - public InstallView get() { - GeaflowInstall install = new GeaflowInstall(); - - // load default config - install.setRuntimeClusterConfig( - pluginConfigService.getDefaultPluginConfig(GeaflowPluginCategory.RUNTIME_CLUSTER)); - install.setRuntimeMetaConfig(pluginConfigService.getDefaultPluginConfig(GeaflowPluginCategory.RUNTIME_META)); - install.setHaMetaConfig(pluginConfigService.getDefaultPluginConfig(GeaflowPluginCategory.HA_META)); - install.setMetricConfig(pluginConfigService.getDefaultPluginConfig(GeaflowPluginCategory.METRIC)); - install.setRemoteFileConfig(pluginConfigService.getDefaultPluginConfig(GeaflowPluginCategory.REMOTE_FILE)); - install.setDataConfig(pluginConfigService.getDefaultPluginConfig(GeaflowPluginCategory.DATA)); - - // default configs - if (deployConfig.isLocalMode()) { - ConfigBuilder builder = new DefaultConfigBuilder(); - if (install.getRuntimeClusterConfig() == null) { - GeaflowPluginConfig runtimeClusterConfig = new GeaflowPluginConfig( - GeaflowPluginCategory.RUNTIME_CLUSTER, builder.configRuntimeCluster()); - runtimeClusterConfig.setName("cluster-default"); - runtimeClusterConfig.setComment(I18nUtil.getMessage("i18n.key.default.cluster")); - install.setRuntimeClusterConfig(runtimeClusterConfig); - } - - if (install.getRuntimeMetaConfig() == null) { - GeaflowPluginConfig runtimeMetaConfig = new GeaflowPluginConfig(GeaflowPluginCategory.RUNTIME_META, - builder.configRuntimeMeta()); - runtimeMetaConfig.setName("runtime-meta-store-default"); - runtimeMetaConfig.setComment(I18nUtil.getMessage("i18n.key.default.runtime.meta.store")); - install.setRuntimeMetaConfig(runtimeMetaConfig); - } - - if (install.getHaMetaConfig() == null) { - GeaflowPluginConfig haMetaConfig = new GeaflowPluginConfig(GeaflowPluginCategory.HA_META, - builder.configHaMeta()); - haMetaConfig.setName("ha-meta-store-default"); - haMetaConfig.setComment(I18nUtil.getMessage("i18n.key.default.ha.meta.store")); - install.setHaMetaConfig(haMetaConfig); - } - - if (install.getMetricConfig() == null) { - GeaflowPluginConfig metricConfig = new GeaflowPluginConfig(GeaflowPluginCategory.METRIC, - builder.configMetric()); - metricConfig.setName("metric-store-default"); - metricConfig.setComment(I18nUtil.getMessage("i18n.key.default.metric.store")); - install.setMetricConfig(metricConfig); - } - - if (install.getRemoteFileConfig() == null) { - GeaflowPluginConfig remoteFileConfig = new GeaflowPluginConfig(GeaflowPluginCategory.REMOTE_FILE, - builder.configRemoteFile()); - remoteFileConfig.setName("file-store-default"); - remoteFileConfig.setComment(I18nUtil.getMessage("i18n.key.default.file.store")); - install.setRemoteFileConfig(remoteFileConfig); - } - - if (install.getDataConfig() == null) { - GeaflowPluginConfig dataConfig = new GeaflowPluginConfig(GeaflowPluginCategory.DATA, - builder.configData()); - dataConfig.setName("data-store-default"); - dataConfig.setComment(I18nUtil.getMessage("i18n.key.default.data.store")); - install.setDataConfig(dataConfig); - } - } + // load default config + install.setRuntimeClusterConfig( + pluginConfigService.getDefaultPluginConfig(GeaflowPluginCategory.RUNTIME_CLUSTER)); + install.setRuntimeMetaConfig( + pluginConfigService.getDefaultPluginConfig(GeaflowPluginCategory.RUNTIME_META)); + install.setHaMetaConfig( + pluginConfigService.getDefaultPluginConfig(GeaflowPluginCategory.HA_META)); + install.setMetricConfig( + pluginConfigService.getDefaultPluginConfig(GeaflowPluginCategory.METRIC)); + install.setRemoteFileConfig( + pluginConfigService.getDefaultPluginConfig(GeaflowPluginCategory.REMOTE_FILE)); + install.setDataConfig(pluginConfigService.getDefaultPluginConfig(GeaflowPluginCategory.DATA)); - // get deploy mode - InstallView installView = installViewConverter.convert(install); - installView.setDeployMode(deployConfig.getMode()); + // default configs + if (deployConfig.isLocalMode()) { + ConfigBuilder builder = new DefaultConfigBuilder(); + if (install.getRuntimeClusterConfig() == null) { + GeaflowPluginConfig runtimeClusterConfig = + new GeaflowPluginConfig( + GeaflowPluginCategory.RUNTIME_CLUSTER, builder.configRuntimeCluster()); + runtimeClusterConfig.setName("cluster-default"); + runtimeClusterConfig.setComment(I18nUtil.getMessage("i18n.key.default.cluster")); + install.setRuntimeClusterConfig(runtimeClusterConfig); + } - return installView; - } + if (install.getRuntimeMetaConfig() == null) { + GeaflowPluginConfig runtimeMetaConfig = + new GeaflowPluginConfig( + GeaflowPluginCategory.RUNTIME_META, builder.configRuntimeMeta()); + runtimeMetaConfig.setName("runtime-meta-store-default"); + runtimeMetaConfig.setComment(I18nUtil.getMessage("i18n.key.default.runtime.meta.store")); + install.setRuntimeMetaConfig(runtimeMetaConfig); + } - @Transactional - @Override - public boolean install(InstallView installView) { - if (systemConfigService.getBoolean(SystemConfigKeys.GEAFLOW_INITIALIZED)) { - throw new GeaflowException("Geaflow has been initialized"); - } + if (install.getHaMetaConfig() == null) { + GeaflowPluginConfig haMetaConfig = + new GeaflowPluginConfig(GeaflowPluginCategory.HA_META, builder.configHaMeta()); + haMetaConfig.setName("ha-meta-store-default"); + haMetaConfig.setComment(I18nUtil.getMessage("i18n.key.default.ha.meta.store")); + install.setHaMetaConfig(haMetaConfig); + } - if (!systemConfigService.exist(null, SystemConfigKeys.GEAFLOW_INITIALIZED)) { - GeaflowSystemConfig config = new GeaflowSystemConfig(); - config.setName(SystemConfigKeys.GEAFLOW_INITIALIZED); - config.setComment(I18nUtil.getMessage("i18n.key.geaflow.system.inited.flag")); - config.setValue("false"); - systemConfigService.create(config); - } + if (install.getMetricConfig() == null) { + GeaflowPluginConfig metricConfig = + new GeaflowPluginConfig(GeaflowPluginCategory.METRIC, builder.configMetric()); + metricConfig.setName("metric-store-default"); + metricConfig.setComment(I18nUtil.getMessage("i18n.key.default.metric.store")); + install.setMetricConfig(metricConfig); + } - // check local deploy mode - if (!deployConfig.isLocalMode() && NetworkUtil.isLocal(datasourceConfig.getUrl())) { - throw new GeaflowException("Datasource '{}' can't be used in 'CLUSTER' deploy mode", - StringUtils.substringBeforeLast(datasourceConfig.getUrl(), "?")); - } + if (install.getRemoteFileConfig() == null) { + GeaflowPluginConfig remoteFileConfig = + new GeaflowPluginConfig(GeaflowPluginCategory.REMOTE_FILE, builder.configRemoteFile()); + remoteFileConfig.setName("file-store-default"); + remoteFileConfig.setComment(I18nUtil.getMessage("i18n.key.default.file.store")); + install.setRemoteFileConfig(remoteFileConfig); + } + + if (install.getDataConfig() == null) { + GeaflowPluginConfig dataConfig = + new GeaflowPluginConfig(GeaflowPluginCategory.DATA, builder.configData()); + dataConfig.setName("data-store-default"); + dataConfig.setComment(I18nUtil.getMessage("i18n.key.default.data.store")); + install.setDataConfig(dataConfig); + } + } - // prepare install - GeaflowInstall install = installViewConverter.convert(installView); - - // init plugin and config - List pluginConfigs = new ArrayList<>(); - pluginConfigs.add(install.getRuntimeClusterConfig()); - pluginConfigs.add(install.getRuntimeMetaConfig()); - pluginConfigs.add(install.getHaMetaConfig()); - pluginConfigs.add(install.getMetricConfig()); - pluginConfigs.add(install.getRemoteFileConfig()); - pluginConfigs.add(install.getDataConfig()); - pluginConfigs.forEach(pluginConfig -> { - pluginConfigService.testConnection(pluginConfig); - pluginConfigService.createDefaultPluginConfig(pluginConfig); - }); + // get deploy mode + InstallView installView = installViewConverter.convert(install); + installView.setDeployMode(deployConfig.getMode()); - // init meta table - GeaflowPluginConfig runtimeMetaConfig = install.getRuntimeMetaConfig(); - if (GeaflowPluginType.JDBC.name().equals(runtimeMetaConfig.getType())) { - JdbcPluginConfigClass jdbcConfig = runtimeMetaConfig.getConfig().parse(JdbcPluginConfigClass.class); - datasourceService.executeResource(jdbcConfig, "runtimemeta.init.sql"); - } + return installView; + } - // setup influxdb - if (deployConfig.isLocalMode()) { - GeaflowPluginConfig metricConfig = install.getMetricConfig(); - if (GeaflowPluginType.INFLUXDB.name().equals(metricConfig.getType())) { - InfluxdbPluginConfigClass influxdbConfig = metricConfig.getConfig() - .parse(InfluxdbPluginConfigClass.class); - if (influxdbConfig.getUrl().contains(deployConfig.getHost())) { - try { - String org = influxdbConfig.getOrg(); - String bucket = influxdbConfig.getBucket(); - String token = influxdbConfig.getToken(); - String setupCommand = Fmt.as("/usr/local/bin/influx setup --org '{}' --bucket '{}' " - + "--username geaflow --password geaflow123456 --token '{}' --force", org, bucket, token); - - log.info("Setup influxdb with command {}", setupCommand); - ProcessUtil.execute(setupCommand); - } catch (Exception e) { - log.error("Set up influx db failed", e); - } - - } - } - } + @Transactional + @Override + public boolean install(InstallView installView) { + if (systemConfigService.getBoolean(SystemConfigKeys.GEAFLOW_INITIALIZED)) { + throw new GeaflowException("Geaflow has been initialized"); + } - // init cluster - clusterService.create(new GeaflowCluster(install.getRuntimeClusterConfig())); + if (!systemConfigService.exist(null, SystemConfigKeys.GEAFLOW_INITIALIZED)) { + GeaflowSystemConfig config = new GeaflowSystemConfig(); + config.setName(SystemConfigKeys.GEAFLOW_INITIALIZED); + config.setComment(I18nUtil.getMessage("i18n.key.geaflow.system.inited.flag")); + config.setValue("false"); + systemConfigService.create(config); + } - // init version - versionManager.createDefaultVersion(); + // check local deploy mode + if (!deployConfig.isLocalMode() && NetworkUtil.isLocal(datasourceConfig.getUrl())) { + throw new GeaflowException( + "Datasource '{}' can't be used in 'CLUSTER' deploy mode", + StringUtils.substringBeforeLast(datasourceConfig.getUrl(), "?")); + } - createDemoJobs(); + // prepare install + GeaflowInstall install = installViewConverter.convert(installView); + + // init plugin and config + List pluginConfigs = new ArrayList<>(); + pluginConfigs.add(install.getRuntimeClusterConfig()); + pluginConfigs.add(install.getRuntimeMetaConfig()); + pluginConfigs.add(install.getHaMetaConfig()); + pluginConfigs.add(install.getMetricConfig()); + pluginConfigs.add(install.getRemoteFileConfig()); + pluginConfigs.add(install.getDataConfig()); + pluginConfigs.forEach( + pluginConfig -> { + pluginConfigService.testConnection(pluginConfig); + pluginConfigService.createDefaultPluginConfig(pluginConfig); + }); - // set install status - systemConfigService.setValue(SystemConfigKeys.GEAFLOW_INITIALIZED, true); - return true; + // init meta table + GeaflowPluginConfig runtimeMetaConfig = install.getRuntimeMetaConfig(); + if (GeaflowPluginType.JDBC.name().equals(runtimeMetaConfig.getType())) { + JdbcPluginConfigClass jdbcConfig = + runtimeMetaConfig.getConfig().parse(JdbcPluginConfigClass.class); + datasourceService.executeResource(jdbcConfig, "runtimemeta.init.sql"); } - private void createDemoJobs() { - GeaflowContext context = ContextHolder.get(); - try { - GeaflowUser user = userService.get(context.getUserId()); - GeaflowTenant tenant = tenantService.getByName(tenantService.getDefaultTenantName(user.getName())); - GeaflowInstance instance = instanceService.getByName( - instanceService.getDefaultInstanceName(user.getName())); - context.setTenantId(tenant.getId()); - context.setSystemSession(false); - - List jobs = new ArrayList<>(); - for (DemoJob demoJob : demoJobs) { - GeaflowJob job = demoJob.build(); - job.setInstanceId(instance.getId()); - jobs.add(job); - } - - jobService.create(jobs); - log.info("create demo jobs success"); - } catch (Exception e) { - log.error("create demo job failed", e); - throw e; - } finally { - context.setTenantId(null); - context.setSystemSession(true); + // setup influxdb + if (deployConfig.isLocalMode()) { + GeaflowPluginConfig metricConfig = install.getMetricConfig(); + if (GeaflowPluginType.INFLUXDB.name().equals(metricConfig.getType())) { + InfluxdbPluginConfigClass influxdbConfig = + metricConfig.getConfig().parse(InfluxdbPluginConfigClass.class); + if (influxdbConfig.getUrl().contains(deployConfig.getHost())) { + try { + String org = influxdbConfig.getOrg(); + String bucket = influxdbConfig.getBucket(); + String token = influxdbConfig.getToken(); + String setupCommand = + Fmt.as( + "/usr/local/bin/influx setup --org '{}' --bucket '{}' " + + "--username geaflow --password geaflow123456 --token '{}' --force", + org, + bucket, + token); + + log.info("Setup influxdb with command {}", setupCommand); + ProcessUtil.execute(setupCommand); + } catch (Exception e) { + log.error("Set up influx db failed", e); + } } - + } } + // init cluster + clusterService.create(new GeaflowCluster(install.getRuntimeClusterConfig())); + + // init version + versionManager.createDefaultVersion(); + + createDemoJobs(); + + // set install status + systemConfigService.setValue(SystemConfigKeys.GEAFLOW_INITIALIZED, true); + return true; + } + + private void createDemoJobs() { + GeaflowContext context = ContextHolder.get(); + try { + GeaflowUser user = userService.get(context.getUserId()); + GeaflowTenant tenant = + tenantService.getByName(tenantService.getDefaultTenantName(user.getName())); + GeaflowInstance instance = + instanceService.getByName(instanceService.getDefaultInstanceName(user.getName())); + context.setTenantId(tenant.getId()); + context.setSystemSession(false); + + List jobs = new ArrayList<>(); + for (DemoJob demoJob : demoJobs) { + GeaflowJob job = demoJob.build(); + job.setInstanceId(instance.getId()); + jobs.add(job); + } + + jobService.create(jobs); + log.info("create demo jobs success"); + } catch (Exception e) { + log.error("create demo job failed", e); + throw e; + } finally { + context.setTenantId(null); + context.setSystemSession(true); + } + } - private class DefaultConfigBuilder implements ConfigBuilder { - - @Override - public PluginConfigClass configRuntimeCluster() { - if (deployConfig.isLocalMode()) { - return new ContainerPluginConfigClass(); - } + private class DefaultConfigBuilder implements ConfigBuilder { - K8sPluginConfigClass k8sConfig = new K8sPluginConfigClass(); - k8sConfig.setMasterUrl(Fmt.as("http://{}:8000", deployConfig.getHost())); - k8sConfig.setImageUrl("tugraph/geaflow:0.1"); - k8sConfig.setServiceType("CLUSTER_IP"); - k8sConfig.setStorageLimit("10Gi"); - k8sConfig.setClientTimeout(600000); - return k8sConfig; - } + @Override + public PluginConfigClass configRuntimeCluster() { + if (deployConfig.isLocalMode()) { + return new ContainerPluginConfigClass(); + } + + K8sPluginConfigClass k8sConfig = new K8sPluginConfigClass(); + k8sConfig.setMasterUrl(Fmt.as("http://{}:8000", deployConfig.getHost())); + k8sConfig.setImageUrl("tugraph/geaflow:0.1"); + k8sConfig.setServiceType("CLUSTER_IP"); + k8sConfig.setStorageLimit("10Gi"); + k8sConfig.setClientTimeout(600000); + return k8sConfig; + } - @Override - public PluginConfigClass configRuntimeMeta() { - return datasourceConfig.buildPluginConfigClass(); - } + @Override + public PluginConfigClass configRuntimeMeta() { + return datasourceConfig.buildPluginConfigClass(); + } - @Override - public PluginConfigClass configHaMeta() { - RedisPluginConfigClass redisConfig = new RedisPluginConfigClass(); - redisConfig.setHost(deployConfig.getHost()); - redisConfig.setPort(6379); - return redisConfig; - } + @Override + public PluginConfigClass configHaMeta() { + RedisPluginConfigClass redisConfig = new RedisPluginConfigClass(); + redisConfig.setHost(deployConfig.getHost()); + redisConfig.setPort(6379); + return redisConfig; + } - @Override - public PluginConfigClass configMetric() { - InfluxdbPluginConfigClass influxdbConfig = new InfluxdbPluginConfigClass(); - influxdbConfig.setUrl(Fmt.as("http://{}:8086", deployConfig.getHost())); - influxdbConfig.setToken(tokenGenerator.nextToken()); - influxdbConfig.setOrg("geaflow"); - influxdbConfig.setBucket("geaflow"); - return influxdbConfig; - } + @Override + public PluginConfigClass configMetric() { + InfluxdbPluginConfigClass influxdbConfig = new InfluxdbPluginConfigClass(); + influxdbConfig.setUrl(Fmt.as("http://{}:8086", deployConfig.getHost())); + influxdbConfig.setToken(tokenGenerator.nextToken()); + influxdbConfig.setOrg("geaflow"); + influxdbConfig.setBucket("geaflow"); + return influxdbConfig; + } - @Override - public PluginConfigClass configRemoteFile() { - LocalPluginConfigClass localConfig = new LocalPluginConfigClass(); - localConfig.setRoot("/tmp"); - return localConfig; - } + @Override + public PluginConfigClass configRemoteFile() { + LocalPluginConfigClass localConfig = new LocalPluginConfigClass(); + localConfig.setRoot("/tmp"); + return localConfig; + } - @Override - public PluginConfigClass configData() { - LocalPluginConfigClass localConfig = new LocalPluginConfigClass(); - localConfig.setRoot("/tmp/geaflow/chk"); - return localConfig; - } + @Override + public PluginConfigClass configData() { + LocalPluginConfigClass localConfig = new LocalPluginConfigClass(); + localConfig.setRoot("/tmp/geaflow/chk"); + return localConfig; } + } } - - - - - - diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/InstanceManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/InstanceManagerImpl.java index 142bfb763..3abe85436 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/InstanceManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/InstanceManagerImpl.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.biz.shared.impl; import java.util.List; + import org.apache.geaflow.console.biz.shared.InstanceManager; import org.apache.geaflow.console.biz.shared.convert.InstanceViewConverter; import org.apache.geaflow.console.biz.shared.convert.NameViewConverter; @@ -33,33 +34,32 @@ import org.springframework.stereotype.Service; @Service -public class InstanceManagerImpl extends NameManagerImpl implements - InstanceManager { +public class InstanceManagerImpl + extends NameManagerImpl + implements InstanceManager { - @Autowired - private InstanceService instanceService; + @Autowired private InstanceService instanceService; - @Autowired - private InstanceViewConverter instanceViewConverter; + @Autowired private InstanceViewConverter instanceViewConverter; - @Override - public List search() { - List models = instanceService.search(); - return ListUtil.convert(models, e -> instanceViewConverter.convert(e)); - } + @Override + public List search() { + List models = instanceService.search(); + return ListUtil.convert(models, e -> instanceViewConverter.convert(e)); + } - @Override - protected NameViewConverter getConverter() { - return instanceViewConverter; - } + @Override + protected NameViewConverter getConverter() { + return instanceViewConverter; + } - @Override - protected List parse(List views) { - return ListUtil.convert(views, e -> instanceViewConverter.convert(e)); - } + @Override + protected List parse(List views) { + return ListUtil.convert(views, e -> instanceViewConverter.convert(e)); + } - @Override - protected NameService getService() { - return instanceService; - } + @Override + protected NameService getService() { + return instanceService; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/JobManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/JobManagerImpl.java index 2b6af3100..e14cac99d 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/JobManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/JobManagerImpl.java @@ -21,15 +21,12 @@ import static org.apache.geaflow.console.core.service.RemoteFileService.JAR_FILE_SUFFIX; -import com.alibaba.fastjson.JSON; -import com.alibaba.fastjson.TypeReference; -import com.google.common.base.Preconditions; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; -import lombok.extern.slf4j.Slf4j; + import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.biz.shared.JobManager; @@ -70,225 +67,237 @@ import org.springframework.transaction.annotation.Transactional; import org.springframework.web.multipart.MultipartFile; -@Service -@Slf4j -public class JobManagerImpl extends IdManagerImpl implements JobManager { - - @Autowired - private JobService jobService; - - @Autowired - private JobViewConverter jobViewConverter; - - @Autowired - private ReleaseService releaseService; - - @Autowired - private TaskManager taskManager; - - @Autowired - private TaskService taskService; - - @Autowired - private RemoteFileManager remoteFileManager; - - @Autowired - private AuthorizationService authorizationService; - - @Autowired - private RemoteFileService remoteFileService; - - @Autowired - private StatementService statementService; - - @Autowired - private TableService tableService; - - @Override - public IdViewConverter getConverter() { - return jobViewConverter; - } +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.TypeReference; +import com.google.common.base.Preconditions; - @Override - public IdService getService() { - return jobService; - } +import lombok.extern.slf4j.Slf4j; - @Override - public List parse(List views) { - return ListUtil.convert(views, v -> { - GeaflowJobType type = v.getType(); - Preconditions.checkNotNull(type, "job Type is null"); - switch (type) { - case PROCESS: - // get functions - List functionIds = ListUtil.convert(v.getFunctions(), IdView::getId); - List functions = jobService.getResourceService(GeaflowResourceType.FUNCTION).get(functionIds); - return jobViewConverter.convert(v, null, null, functions, null); - case CUSTOM: - GeaflowRemoteFile remoteFile = Optional.ofNullable(v.getJarPackage()) - .map(e -> remoteFileService.get(e.getId())).orElse(null); - return jobViewConverter.convert(v, null, null, null, remoteFile); - case INTEGRATE: - case SERVE: - List structs = null; - if (type == GeaflowJobType.INTEGRATE) { - // get tables - structs = getStructs(v); - } - Preconditions.checkArgument(v.getGraphs() != null && v.getGraphs().size() == 1, - "Must have one graph"); - - List graphIds = ListUtil.convert(v.getGraphs(), IdView::getId); - List graphs = ListUtil.convert(graphIds, id -> { - GeaflowGraph g = (GeaflowGraph) jobService.getResourceService(GeaflowResourceType.GRAPH).get(id); +@Service +@Slf4j +public class JobManagerImpl extends IdManagerImpl + implements JobManager { + + @Autowired private JobService jobService; + + @Autowired private JobViewConverter jobViewConverter; + + @Autowired private ReleaseService releaseService; + + @Autowired private TaskManager taskManager; + + @Autowired private TaskService taskService; + + @Autowired private RemoteFileManager remoteFileManager; + + @Autowired private AuthorizationService authorizationService; + + @Autowired private RemoteFileService remoteFileService; + + @Autowired private StatementService statementService; + + @Autowired private TableService tableService; + + @Override + public IdViewConverter getConverter() { + return jobViewConverter; + } + + @Override + public IdService getService() { + return jobService; + } + + @Override + public List parse(List views) { + return ListUtil.convert( + views, + v -> { + GeaflowJobType type = v.getType(); + Preconditions.checkNotNull(type, "job Type is null"); + switch (type) { + case PROCESS: + // get functions + List functionIds = ListUtil.convert(v.getFunctions(), IdView::getId); + List functions = + jobService.getResourceService(GeaflowResourceType.FUNCTION).get(functionIds); + return jobViewConverter.convert(v, null, null, functions, null); + case CUSTOM: + GeaflowRemoteFile remoteFile = + Optional.ofNullable(v.getJarPackage()) + .map(e -> remoteFileService.get(e.getId())) + .orElse(null); + return jobViewConverter.convert(v, null, null, null, remoteFile); + case INTEGRATE: + case SERVE: + List structs = null; + if (type == GeaflowJobType.INTEGRATE) { + // get tables + structs = getStructs(v); + } + Preconditions.checkArgument( + v.getGraphs() != null && v.getGraphs().size() == 1, "Must have one graph"); + + List graphIds = ListUtil.convert(v.getGraphs(), IdView::getId); + List graphs = + ListUtil.convert( + graphIds, + id -> { + GeaflowGraph g = + (GeaflowGraph) + jobService.getResourceService(GeaflowResourceType.GRAPH).get(id); Preconditions.checkNotNull(g, "Graph id {} is null", id); return g; - }); - return jobViewConverter.convert(v, structs, graphs, null, null); + }); + return jobViewConverter.convert(v, structs, graphs, null, null); - default: - throw new GeaflowException("Unsupported job Type: {}", v.getType()); - } + default: + throw new GeaflowException("Unsupported job Type: {}", v.getType()); + } }); + } + + private List getStructs(JobView jobView) { + List structMappings = + JSON.parseObject(jobView.getStructMappings(), new TypeReference>() {}); + Set tableNames = + structMappings.stream().map(StructMapping::getTableName).collect(Collectors.toSet()); + return tableNames.stream() + .map(e -> tableService.getByName(jobView.getInstanceId(), e)) + .collect(Collectors.toList()); + } + + @Override + @Transactional + public boolean drop(List jobIds) { + List taskIds = taskService.getIdsByJob(jobIds); + taskManager.drop(taskIds); + releaseService.dropByJobIds(jobIds); + jobService.dropResources(jobIds); + authorizationService.dropByResources(jobIds, GeaflowResourceType.JOB); + statementService.dropByJobIds(jobIds); + try { + Map jarIds = jobService.getJarIds(jobIds); + for (String jobId : jobIds) { + remoteFileManager.deleteRefJar(jarIds.get(jobId), jobId, GeaflowResourceType.JOB); + } + } catch (Exception e) { + log.info(e.getMessage()); } - private List getStructs(JobView jobView) { - List structMappings = JSON.parseObject(jobView.getStructMappings(), - new TypeReference>() { - }); - Set tableNames = structMappings.stream().map(StructMapping::getTableName).collect(Collectors.toSet()); - return tableNames.stream().map(e -> tableService.getByName(jobView.getInstanceId(), e)) - .collect(Collectors.toList()); - } + return super.drop(jobIds); + } - - @Override - @Transactional - public boolean drop(List jobIds) { - List taskIds = taskService.getIdsByJob(jobIds); - taskManager.drop(taskIds); - releaseService.dropByJobIds(jobIds); - jobService.dropResources(jobIds); - authorizationService.dropByResources(jobIds, GeaflowResourceType.JOB); - statementService.dropByJobIds(jobIds); - try { - Map jarIds = jobService.getJarIds(jobIds); - for (String jobId : jobIds) { - remoteFileManager.deleteRefJar(jarIds.get(jobId), jobId, GeaflowResourceType.JOB); - } - } catch (Exception e) { - log.info(e.getMessage()); - } - - return super.drop(jobIds); + private String createApiJob(JobView jobView, MultipartFile jarFile, String fileId) { + String jobName = jobView.getName(); + if (StringUtils.isBlank(jobName)) { + throw new GeaflowIllegalException("Invalid function name"); } - - private String createApiJob(JobView jobView, MultipartFile jarFile, String fileId) { - String jobName = jobView.getName(); - if (StringUtils.isBlank(jobName)) { - throw new GeaflowIllegalException("Invalid function name"); - } - - if (jobService.existName(jobName)) { - throw new GeaflowIllegalException("Job name {} exists", jobName); - } - - if (jobView.getType() == GeaflowJobType.CUSTOM) { - Preconditions.checkNotNull(jobView.getEntryClass(), "Custom job needs entryClass"); - } - - if (jarFile != null) { - jobView.setJarPackage(createRemoteFile(jarFile)); - } else if (fileId != null) { - // bind a jar file if jarId is not null - if (!remoteFileService.exist(fileId)) { - throw new GeaflowIllegalException("File {} does not exist", fileId); - } - RemoteFileView remoteFileView = new RemoteFileView(); - remoteFileView.setId(fileId); - jobView.setJarPackage(remoteFileView); - } - - // job package could be null - return super.create(jobView); + if (jobService.existName(jobName)) { + throw new GeaflowIllegalException("Job name {} exists", jobName); } + if (jobView.getType() == GeaflowJobType.CUSTOM) { + Preconditions.checkNotNull(jobView.getEntryClass(), "Custom job needs entryClass"); + } - private boolean updateApiJob(String jobId, JobView updateView, MultipartFile jarFile, String fileId) { - if (updateView.getType() == GeaflowJobType.CUSTOM) { - Preconditions.checkNotNull(updateView.getEntryClass(), "Hla job needs entryClass"); - } + if (jarFile != null) { + jobView.setJarPackage(createRemoteFile(jarFile)); + } else if (fileId != null) { + // bind a jar file if jarId is not null + if (!remoteFileService.exist(fileId)) { + throw new GeaflowIllegalException("File {} does not exist", fileId); + } + RemoteFileView remoteFileView = new RemoteFileView(); + remoteFileView.setId(fileId); + jobView.setJarPackage(remoteFileView); + } - if (jarFile != null) { - updateView.setJarPackage(createRemoteFile(jarFile)); - - // try to delete old file - GeaflowJob job = jobService.get(jobId); - GeaflowRemoteFile oldJar = job.getJarPackage(); - if (oldJar != null) { - String oldJarId = job.getJarPackage().getId(); - try { - remoteFileManager.deleteRefJar(oldJarId, jobId, GeaflowResourceType.JOB); - } catch (Exception e) { - log.info("delete job jar fail, jobName: {}, jarId: {}", job.getName(), oldJarId); - } - } - - } else if (fileId != null) { - // bind a jar file if jarId is not null - if (!remoteFileService.exist(fileId)) { - throw new GeaflowIllegalException("File {} does not exist", fileId); - } - RemoteFileView remoteFileView = new RemoteFileView(); - remoteFileView.setId(fileId); - updateView.setJarPackage(remoteFileView); - } + // job package could be null + return super.create(jobView); + } - return updateById(jobId, updateView); + private boolean updateApiJob( + String jobId, JobView updateView, MultipartFile jarFile, String fileId) { + if (updateView.getType() == GeaflowJobType.CUSTOM) { + Preconditions.checkNotNull(updateView.getEntryClass(), "Hla job needs entryClass"); } - private RemoteFileView createRemoteFile(MultipartFile jarFile) { - if (!StringUtils.endsWith(jarFile.getOriginalFilename(), JAR_FILE_SUFFIX)) { - throw new GeaflowIllegalException("Invalid jar file"); - } + if (jarFile != null) { + updateView.setJarPackage(createRemoteFile(jarFile)); - String fileName = jarFile.getOriginalFilename(); - if (remoteFileService.existName(fileName)) { - throw new GeaflowException("FileName {} exists", fileName); + // try to delete old file + GeaflowJob job = jobService.get(jobId); + GeaflowRemoteFile oldJar = job.getJarPackage(); + if (oldJar != null) { + String oldJarId = job.getJarPackage().getId(); + try { + remoteFileManager.deleteRefJar(oldJarId, jobId, GeaflowResourceType.JOB); + } catch (Exception e) { + log.info("delete job jar fail, jobName: {}, jarId: {}", job.getName(), oldJarId); } + } + + } else if (fileId != null) { + // bind a jar file if jarId is not null + if (!remoteFileService.exist(fileId)) { + throw new GeaflowIllegalException("File {} does not exist", fileId); + } + RemoteFileView remoteFileView = new RemoteFileView(); + remoteFileView.setId(fileId); + updateView.setJarPackage(remoteFileView); + } - String path = RemoteFileStorage.getUserFilePath(ContextHolder.get().getUserId(), fileName); + return updateById(jobId, updateView); + } - RemoteFileView remoteFileView = new RemoteFileView(); - remoteFileView.setName(fileName); - remoteFileView.setPath(path); - remoteFileManager.create(remoteFileView, jarFile); + private RemoteFileView createRemoteFile(MultipartFile jarFile) { + if (!StringUtils.endsWith(jarFile.getOriginalFilename(), JAR_FILE_SUFFIX)) { + throw new GeaflowIllegalException("Invalid jar file"); + } - return remoteFileView; + String fileName = jarFile.getOriginalFilename(); + if (remoteFileService.existName(fileName)) { + throw new GeaflowException("FileName {} exists", fileName); } - @Override - @Transactional - public String create(JobView jobView, MultipartFile jarFile, String fileId, List graphIds) { - Preconditions.checkNotNull(jobView.getType(), "Job type is null"); - if (CollectionUtils.isNotEmpty(graphIds)) { - List graphViews = ListUtil.convert(graphIds, id -> { + String path = RemoteFileStorage.getUserFilePath(ContextHolder.get().getUserId(), fileName); + + RemoteFileView remoteFileView = new RemoteFileView(); + remoteFileView.setName(fileName); + remoteFileView.setPath(path); + remoteFileManager.create(remoteFileView, jarFile); + + return remoteFileView; + } + + @Override + @Transactional + public String create( + JobView jobView, MultipartFile jarFile, String fileId, List graphIds) { + Preconditions.checkNotNull(jobView.getType(), "Job type is null"); + if (CollectionUtils.isNotEmpty(graphIds)) { + List graphViews = + ListUtil.convert( + graphIds, + id -> { GraphView graphView = new GraphView(); graphView.setId(id); return graphView; - }); - jobView.setGraphs(graphViews); - } - return jobView.getType().getTaskType() == GeaflowTaskType.API ? createApiJob(jobView, jarFile, fileId) : - super.create(jobView); - } - - @Override - public boolean update(String jobId, JobView jobView, MultipartFile jarFile, String fileId) { - Preconditions.checkNotNull(jobView.getType(), "Job type is null"); - return jobView.getType().getTaskType() == GeaflowTaskType.API ? updateApiJob(jobId, jobView, jarFile, fileId) : - super.updateById(jobId, jobView); + }); + jobView.setGraphs(graphViews); } + return jobView.getType().getTaskType() == GeaflowTaskType.API + ? createApiJob(jobView, jarFile, fileId) + : super.create(jobView); + } + + @Override + public boolean update(String jobId, JobView jobView, MultipartFile jarFile, String fileId) { + Preconditions.checkNotNull(jobView.getType(), "Job type is null"); + return jobView.getType().getTaskType() == GeaflowTaskType.API + ? updateApiJob(jobId, jobView, jarFile, fileId) + : super.updateById(jobId, jobView); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/LLMManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/LLMManagerImpl.java index e19bfda23..6f5467374 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/LLMManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/LLMManagerImpl.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.List; + import org.apache.geaflow.console.biz.shared.LLMManager; import org.apache.geaflow.console.biz.shared.convert.LLMViewConverter; import org.apache.geaflow.console.biz.shared.convert.NameViewConverter; @@ -35,32 +36,30 @@ import org.springframework.stereotype.Service; @Service -public class LLMManagerImpl extends NameManagerImpl implements - LLMManager { +public class LLMManagerImpl extends NameManagerImpl + implements LLMManager { - @Autowired - private LLMService llmService; + @Autowired private LLMService llmService; - @Autowired - private LLMViewConverter llmViewConverter; + @Autowired private LLMViewConverter llmViewConverter; - @Override - protected NameService getService() { - return llmService; - } + @Override + protected NameService getService() { + return llmService; + } - @Override - protected NameViewConverter getConverter() { - return llmViewConverter; - } + @Override + protected NameViewConverter getConverter() { + return llmViewConverter; + } - @Override - protected List parse(List views) { - return ListUtil.convert(views, llmViewConverter::convert); - } + @Override + protected List parse(List views) { + return ListUtil.convert(views, llmViewConverter::convert); + } - @Override - public List getLLMTypes() { - return Arrays.asList(GeaflowLLMType.values()); - } + @Override + public List getLLMTypes() { + return Arrays.asList(GeaflowLLMType.values()); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/NameManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/NameManagerImpl.java index c257f8fbf..68c410329 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/NameManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/NameManagerImpl.java @@ -19,9 +19,9 @@ package org.apache.geaflow.console.biz.shared.impl; -import com.google.common.base.Preconditions; import java.util.Collections; import java.util.List; + import org.apache.geaflow.console.biz.shared.NameManager; import org.apache.geaflow.console.biz.shared.convert.NameViewConverter; import org.apache.geaflow.console.biz.shared.view.NameView; @@ -31,46 +31,47 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.transaction.annotation.Transactional; -public abstract class NameManagerImpl +import com.google.common.base.Preconditions; + +public abstract class NameManagerImpl< + M extends GeaflowName, V extends NameView, S extends NameSearch> extends IdManagerImpl implements NameManager { - @Autowired - private NameManager nameManager; + @Autowired private NameManager nameManager; - @Override - protected abstract NameService getService(); + @Override + protected abstract NameService getService(); - @Override - protected abstract NameViewConverter getConverter(); + @Override + protected abstract NameViewConverter getConverter(); - @Override - public V getByName(String name) { - List models = getByNames(Collections.singletonList(name)); - return models.isEmpty() ? null : models.get(0); - } + @Override + public V getByName(String name) { + List models = getByNames(Collections.singletonList(name)); + return models.isEmpty() ? null : models.get(0); + } - @Override - public List getByNames(List names) { - List models = getService().getByNames(names); - return build(models); - } + @Override + public List getByNames(List names) { + List models = getService().getByNames(names); + return build(models); + } - @Override - public boolean updateByName(String name, V view) { - String id = getService().getIdByName(name); - Preconditions.checkNotNull(id, "Invalid name %s", name); - return updateById(id, view); - } + @Override + public boolean updateByName(String name, V view) { + String id = getService().getIdByName(name); + Preconditions.checkNotNull(id, "Invalid name %s", name); + return updateById(id, view); + } - @Override - public boolean dropByName(String name) { - return nameManager.dropByNames(Collections.singletonList(name)); - } + @Override + public boolean dropByName(String name) { + return nameManager.dropByNames(Collections.singletonList(name)); + } - @Override - @Transactional - public boolean dropByNames(List names) { - return getService().dropByNames(names); - } + @Override + @Transactional + public boolean dropByNames(List names) { + return getService().dropByNames(names); + } } - diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/PluginConfigManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/PluginConfigManagerImpl.java index c4a863fba..e0d20127b 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/PluginConfigManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/PluginConfigManagerImpl.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.console.biz.shared.PluginConfigManager; import org.apache.geaflow.console.biz.shared.convert.NameViewConverter; import org.apache.geaflow.console.biz.shared.convert.PluginConfigViewConverter; @@ -36,49 +37,48 @@ import org.springframework.stereotype.Service; @Service -public class PluginConfigManagerImpl extends - NameManagerImpl implements PluginConfigManager { +public class PluginConfigManagerImpl + extends NameManagerImpl + implements PluginConfigManager { - @Autowired - private PluginConfigService pluginConfigService; + @Autowired private PluginConfigService pluginConfigService; - @Autowired - private PluginConfigViewConverter pluginConfigViewConverter; + @Autowired private PluginConfigViewConverter pluginConfigViewConverter; - @Override - protected NameViewConverter getConverter() { - return pluginConfigViewConverter; - } + @Override + protected NameViewConverter getConverter() { + return pluginConfigViewConverter; + } - @Override - protected NameService getService() { - return pluginConfigService; - } + @Override + protected NameService getService() { + return pluginConfigService; + } - @Override - protected List parse(List views) { - return ListUtil.convert(views, v -> pluginConfigViewConverter.convert(v)); - } + @Override + protected List parse(List views) { + return ListUtil.convert(views, v -> pluginConfigViewConverter.convert(v)); + } - @Override - public List get(List ids) { - pluginConfigService.validateGetIds(ids); - return super.get(ids); - } + @Override + public List get(List ids) { + pluginConfigService.validateGetIds(ids); + return super.get(ids); + } - @Override - public boolean updateById(String id, PluginConfigView updateView) { - pluginConfigService.validateUpdateIds(Collections.singletonList(id)); - return super.updateById(id, updateView); - } + @Override + public boolean updateById(String id, PluginConfigView updateView) { + pluginConfigService.validateUpdateIds(Collections.singletonList(id)); + return super.updateById(id, updateView); + } - public boolean drop(List ids) { - pluginConfigService.validateUpdateIds(ids); - return super.drop(ids); - } + public boolean drop(List ids) { + pluginConfigService.validateUpdateIds(ids); + return super.drop(ids); + } - @Override - public List getPluginConfigs(GeaflowPluginCategory category, String type) { - return build(pluginConfigService.getPluginConfigs(category, type)); - } + @Override + public List getPluginConfigs(GeaflowPluginCategory category, String type) { + return build(pluginConfigService.getPluginConfigs(category, type)); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/PluginManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/PluginManagerImpl.java index 7f3528035..bc8ca5f90 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/PluginManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/PluginManagerImpl.java @@ -22,12 +22,11 @@ import static org.apache.geaflow.console.core.service.PluginService.PLUGIN_DEFAULT_INSTANCE_ID; import static org.apache.geaflow.console.core.service.RemoteFileService.JAR_FILE_SUFFIX; -import com.google.common.base.Preconditions; import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; -import lombok.extern.slf4j.Slf4j; + import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.biz.shared.PluginManager; @@ -61,193 +60,202 @@ import org.springframework.transaction.annotation.Transactional; import org.springframework.web.multipart.MultipartFile; -@Service -@Slf4j -public class PluginManagerImpl extends NameManagerImpl implements - PluginManager { - - @Autowired - private PluginService pluginService; - - @Autowired - private PluginViewConverter pluginViewConverter; - - @Autowired - private RemoteFileManager remoteFileManager; +import com.google.common.base.Preconditions; - @Autowired - private RemoteFileService remoteFileService; +import lombok.extern.slf4j.Slf4j; - @Autowired - private VersionService versionService; +@Service +@Slf4j +public class PluginManagerImpl extends NameManagerImpl + implements PluginManager { - @Autowired - private JobService jobService; + @Autowired private PluginService pluginService; - @Autowired - private PluginConfigService pluginConfigService; + @Autowired private PluginViewConverter pluginViewConverter; - @Override - protected NameViewConverter getConverter() { - return pluginViewConverter; - } + @Autowired private RemoteFileManager remoteFileManager; - @Override - protected List parse(List views) { - return views.stream().map(e -> { - GeaflowRemoteFile jarPackage = remoteFileService.get( - Optional.ofNullable(e.getJarPackage()).map(IdView::getId).orElse(null)); - return pluginViewConverter.convert(e, jarPackage); - }).collect(Collectors.toList()); - } + @Autowired private RemoteFileService remoteFileService; - @Transactional - @Override - public String createPlugin(PluginView pluginView, MultipartFile jarPackage, String jarId) { - String pluginName = pluginView.getName(); - if (StringUtils.isBlank(pluginName)) { - throw new GeaflowIllegalException("Invalid plugin name"); - } + @Autowired private VersionService versionService; - if (pluginService.existName(pluginName)) { - throw new GeaflowIllegalException("Plugin name {} exists", pluginName); - } + @Autowired private JobService jobService; - String type = pluginView.getType(); - GeaflowPluginCategory category = pluginView.getCategory(); - Preconditions.checkNotNull(type, "Invalid plugin name type"); - Preconditions.checkNotNull(category, "Invalid plugin name category"); - GeaflowVersion defaultVersion = versionService.getDefaultVersion(); - if (category == GeaflowPluginCategory.TABLE) { - if (jarPackage == null && jarId == null) { - throw new GeaflowIllegalException("Need upload or bind a jar"); - } - if (pluginService.pluginTypeInEngine(type, defaultVersion)) { - throw new GeaflowIllegalException("Plugin type {} of category {} exists in engine", type, category); - } - } + @Autowired private PluginConfigService pluginConfigService; - GeaflowPlugin plugin = pluginService.getPlugin(type, category); - if (plugin != null) { - throw new GeaflowIllegalException("Plugin type {} of category {} exists", type, category); - } + @Override + protected NameViewConverter getConverter() { + return pluginViewConverter; + } - if (jarId == null) { - if (jarPackage != null) { - pluginService.checkJar(type, jarPackage, defaultVersion); - RemoteFileView remoteFile = createRemoteFile(pluginName, jarPackage); - pluginView.setJarPackage(remoteFile); - } - } else { - pluginService.checkJar(type, jarId, defaultVersion); - RemoteFileView remoteFileView = new RemoteFileView(); - remoteFileView.setId(jarId); - pluginView.setJarPackage(remoteFileView); - } + @Override + protected List parse(List views) { + return views.stream() + .map( + e -> { + GeaflowRemoteFile jarPackage = + remoteFileService.get( + Optional.ofNullable(e.getJarPackage()).map(IdView::getId).orElse(null)); + return pluginViewConverter.convert(e, jarPackage); + }) + .collect(Collectors.toList()); + } - return super.create(pluginView); + @Transactional + @Override + public String createPlugin(PluginView pluginView, MultipartFile jarPackage, String jarId) { + String pluginName = pluginView.getName(); + if (StringUtils.isBlank(pluginName)) { + throw new GeaflowIllegalException("Invalid plugin name"); } - @Transactional - @Override - public boolean updatePlugin(String pluginId, PluginView updateView, MultipartFile jarPackage) { - pluginService.validateUpdateIds(Collections.singletonList(pluginId)); - PluginView view = get(pluginId); - if (view == null) { - throw new GeaflowIllegalException("plugin id {} not exists", pluginId); - } - - if (jarPackage != null) { - RemoteFileView remoteFileView = updateJarPackage(view, jarPackage); - updateView.setJarPackage(remoteFileView); - } - - return updateById(view.getId(), updateView); + if (pluginService.existName(pluginName)) { + throw new GeaflowIllegalException("Plugin name {} exists", pluginName); } - private RemoteFileView updateJarPackage(PluginView versionView, MultipartFile multipartFile) { - if (!StringUtils.endsWith(multipartFile.getOriginalFilename(), JAR_FILE_SUFFIX)) { - throw new GeaflowIllegalException("Invalid jar file"); - } + String type = pluginView.getType(); + GeaflowPluginCategory category = pluginView.getCategory(); + Preconditions.checkNotNull(type, "Invalid plugin name type"); + Preconditions.checkNotNull(category, "Invalid plugin name category"); + GeaflowVersion defaultVersion = versionService.getDefaultVersion(); + if (category == GeaflowPluginCategory.TABLE) { + if (jarPackage == null && jarId == null) { + throw new GeaflowIllegalException("Need upload or bind a jar"); + } + if (pluginService.pluginTypeInEngine(type, defaultVersion)) { + throw new GeaflowIllegalException( + "Plugin type {} of category {} exists in engine", type, category); + } + } - RemoteFileView jarPackage = versionView.getJarPackage(); - if (jarPackage == null) { - return createRemoteFile(versionView.getName(), multipartFile); + GeaflowPlugin plugin = pluginService.getPlugin(type, category); + if (plugin != null) { + throw new GeaflowIllegalException("Plugin type {} of category {} exists", type, category); + } - } else { - String remoteFileId = jarPackage.getId(); - remoteFileManager.upload(remoteFileId, multipartFile); - return null; - } + if (jarId == null) { + if (jarPackage != null) { + pluginService.checkJar(type, jarPackage, defaultVersion); + RemoteFileView remoteFile = createRemoteFile(pluginName, jarPackage); + pluginView.setJarPackage(remoteFile); + } + } else { + pluginService.checkJar(type, jarId, defaultVersion); + RemoteFileView remoteFileView = new RemoteFileView(); + remoteFileView.setId(jarId); + pluginView.setJarPackage(remoteFileView); } + return super.create(pluginView); + } - private RemoteFileView createRemoteFile(String pluginName, MultipartFile multipartFile) { - if (!StringUtils.endsWith(multipartFile.getOriginalFilename(), JAR_FILE_SUFFIX)) { - throw new GeaflowIllegalException("Invalid jar file"); - } + @Transactional + @Override + public boolean updatePlugin(String pluginId, PluginView updateView, MultipartFile jarPackage) { + pluginService.validateUpdateIds(Collections.singletonList(pluginId)); + PluginView view = get(pluginId); + if (view == null) { + throw new GeaflowIllegalException("plugin id {} not exists", pluginId); + } - String fileName = multipartFile.getOriginalFilename(); - boolean systemSession = ContextHolder.get().isSystemSession(); - String userId = ContextHolder.get().getUserId(); - String path = systemSession ? RemoteFileStorage.getPluginFilePath(pluginName, fileName) - : RemoteFileStorage.getUserFilePath(userId, fileName); + if (jarPackage != null) { + RemoteFileView remoteFileView = updateJarPackage(view, jarPackage); + updateView.setJarPackage(remoteFileView); + } - RemoteFileView remoteFileView = new RemoteFileView(); - remoteFileView.setName(fileName); - remoteFileView.setPath(path); - remoteFileManager.create(remoteFileView, multipartFile); + return updateById(view.getId(), updateView); + } - return remoteFileView; + private RemoteFileView updateJarPackage(PluginView versionView, MultipartFile multipartFile) { + if (!StringUtils.endsWith(multipartFile.getOriginalFilename(), JAR_FILE_SUFFIX)) { + throw new GeaflowIllegalException("Invalid jar file"); } - @Override - protected NameService getService() { - return pluginService; - } + RemoteFileView jarPackage = versionView.getJarPackage(); + if (jarPackage == null) { + return createRemoteFile(versionView.getName(), multipartFile); - @Override - public List get(List ids) { - pluginService.validateGetIds(ids); - return super.get(ids); + } else { + String remoteFileId = jarPackage.getId(); + remoteFileManager.upload(remoteFileId, multipartFile); + return null; } + } - @Override - public boolean drop(List ids) { - pluginService.validateUpdateIds(ids); - - for (String id : ids) { - GeaflowPlugin geaflowPlugin = pluginService.get(id); - // check plugin is used by jobs or tables - checkPluginUsed(geaflowPlugin); + private RemoteFileView createRemoteFile(String pluginName, MultipartFile multipartFile) { + if (!StringUtils.endsWith(multipartFile.getOriginalFilename(), JAR_FILE_SUFFIX)) { + throw new GeaflowIllegalException("Invalid jar file"); + } - GeaflowRemoteFile file = geaflowPlugin.getJarPackage(); - if (file != null) { - try { - remoteFileManager.deleteRefJar(file.getId(), geaflowPlugin.getId(), GeaflowResourceType.PLUGIN); + String fileName = multipartFile.getOriginalFilename(); + boolean systemSession = ContextHolder.get().isSystemSession(); + String userId = ContextHolder.get().getUserId(); + String path = + systemSession + ? RemoteFileStorage.getPluginFilePath(pluginName, fileName) + : RemoteFileStorage.getUserFilePath(userId, fileName); - } catch (Exception e) { - log.info(" Delete plugin file {} failed ", file.getName(), e); - } - } + RemoteFileView remoteFileView = new RemoteFileView(); + remoteFileView.setName(fileName); + remoteFileView.setPath(path); + remoteFileManager.create(remoteFileView, multipartFile); + + return remoteFileView; + } + + @Override + protected NameService getService() { + return pluginService; + } + + @Override + public List get(List ids) { + pluginService.validateGetIds(ids); + return super.get(ids); + } + + @Override + public boolean drop(List ids) { + pluginService.validateUpdateIds(ids); + + for (String id : ids) { + GeaflowPlugin geaflowPlugin = pluginService.get(id); + // check plugin is used by jobs or tables + checkPluginUsed(geaflowPlugin); + + GeaflowRemoteFile file = geaflowPlugin.getJarPackage(); + if (file != null) { + try { + remoteFileManager.deleteRefJar( + file.getId(), geaflowPlugin.getId(), GeaflowResourceType.PLUGIN); + + } catch (Exception e) { + log.info(" Delete plugin file {} failed ", file.getName(), e); } - - return super.drop(ids); + } } - private void checkPluginUsed(GeaflowPlugin geaflowPlugin) { - List jobIds = jobService.getJobByResources(geaflowPlugin.getName(), PLUGIN_DEFAULT_INSTANCE_ID, - GeaflowResourceType.PLUGIN); - if (CollectionUtils.isNotEmpty(jobIds)) { - List jobNames = ListUtil.convert(jobIds, e -> jobService.getNameById(e)); - throw new GeaflowException("Plugin {} is used by job: {}", geaflowPlugin.getName(), String.join(",", jobNames)); - } - - List pluginConfigs = pluginConfigService.getPluginConfigs(null, geaflowPlugin.getType()); - if (CollectionUtils.isNotEmpty(pluginConfigs)) { - List configNames = ListUtil.convert(pluginConfigs, GeaflowName::getName); - throw new GeaflowException("Plugin {} is used by config: {}", geaflowPlugin.getName(), String.join(",", configNames)); - } + return super.drop(ids); + } + + private void checkPluginUsed(GeaflowPlugin geaflowPlugin) { + List jobIds = + jobService.getJobByResources( + geaflowPlugin.getName(), PLUGIN_DEFAULT_INSTANCE_ID, GeaflowResourceType.PLUGIN); + if (CollectionUtils.isNotEmpty(jobIds)) { + List jobNames = ListUtil.convert(jobIds, e -> jobService.getNameById(e)); + throw new GeaflowException( + "Plugin {} is used by job: {}", geaflowPlugin.getName(), String.join(",", jobNames)); } + List pluginConfigs = + pluginConfigService.getPluginConfigs(null, geaflowPlugin.getType()); + if (CollectionUtils.isNotEmpty(pluginConfigs)) { + List configNames = ListUtil.convert(pluginConfigs, GeaflowName::getName); + throw new GeaflowException( + "Plugin {} is used by config: {}", + geaflowPlugin.getName(), + String.join(",", configNames)); + } + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/ReleaseManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/ReleaseManagerImpl.java index 842bb0b95..1a7be16fa 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/ReleaseManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/ReleaseManagerImpl.java @@ -19,8 +19,8 @@ package org.apache.geaflow.console.biz.shared.impl; - import java.util.List; + import org.apache.geaflow.console.biz.shared.ReleaseManager; import org.apache.geaflow.console.biz.shared.convert.IdViewConverter; import org.apache.geaflow.console.biz.shared.convert.ReleaseUpdateViewConverter; @@ -52,103 +52,94 @@ import org.springframework.transaction.annotation.Transactional; @Service -public class ReleaseManagerImpl extends IdManagerImpl implements ReleaseManager { +public class ReleaseManagerImpl extends IdManagerImpl + implements ReleaseManager { - @Autowired - private JobService jobService; + @Autowired private JobService jobService; - @Autowired - private ReleaseService releaseService; + @Autowired private ReleaseService releaseService; - @Autowired - private ReleaseUpdateViewConverter releaseUpdateViewConverter; + @Autowired private ReleaseUpdateViewConverter releaseUpdateViewConverter; - @Autowired - private TaskService taskService; + @Autowired private TaskService taskService; - @Autowired - private VersionService versionService; + @Autowired private VersionService versionService; - @Autowired - private ClusterService clusterService; + @Autowired private ClusterService clusterService; - @Autowired - private ReleaseViewConverter releaseViewConverter; + @Autowired private ReleaseViewConverter releaseViewConverter; - @Autowired - private AuditService auditService; + @Autowired private AuditService auditService; - @Override - protected IdService getService() { - return releaseService; - } + @Override + protected IdService getService() { + return releaseService; + } - @Override - protected IdViewConverter getConverter() { - return releaseViewConverter; - } + @Override + protected IdViewConverter getConverter() { + return releaseViewConverter; + } - @Override - protected List parse(List views) { - throw new UnsupportedOperationException("Release can't be converted from view"); - } + @Override + protected List parse(List views) { + throw new UnsupportedOperationException("Release can't be converted from view"); + } - @Override - @Transactional(rollbackFor = Exception.class) - public String publish(String jobId) { - GeaflowJob job = jobService.get(jobId); - - GeaflowRelease release = GeaflowReleaseBuilder.build(job); - - boolean newRelease = release.getId() == null; - String releaseId; - // handle release - if (newRelease) { - // the task status is not created, create a new release, version+1 - releaseId = releaseService.create(release); - } else { - // the task status is created, update release, version unchanged - releaseService.update(release); - releaseId = release.getId(); - } - - // handle task - if (newRelease) { - String taskId; - - if (release.getReleaseVersion() == 1) { - // create a task when first publishing - taskId = taskService.createTask(release).get(0).getId(); - } else { - // bind task with release for later publishing - taskId = taskService.bindRelease(release); - } - - if (taskId != null) { - String detail = Fmt.as("Publish version {}", release.getReleaseVersion()); - auditService.create(new GeaflowAudit(taskId, GeaflowOperationType.PUBLISH, detail)); - } - } - - return releaseId; - } + @Override + @Transactional(rollbackFor = Exception.class) + public String publish(String jobId) { + GeaflowJob job = jobService.get(jobId); + GeaflowRelease release = GeaflowReleaseBuilder.build(job); - @Override - public boolean updateRelease(String jobId, ReleaseUpdateView view) { - GeaflowTask task = taskService.getByJobId(jobId); + boolean newRelease = release.getId() == null; + String releaseId; + // handle release + if (newRelease) { + // the task status is not created, create a new release, version+1 + releaseId = releaseService.create(release); + } else { + // the task status is created, update release, version unchanged + releaseService.update(release); + releaseId = release.getId(); + } - if (task.getStatus() != GeaflowTaskStatus.CREATED) { - throw new GeaflowException("Only created status can be updated"); - } + // handle task + if (newRelease) { + String taskId; + + if (release.getReleaseVersion() == 1) { + // create a task when first publishing + taskId = taskService.createTask(release).get(0).getId(); + } else { + // bind task with release for later publishing + taskId = taskService.bindRelease(release); + } + + if (taskId != null) { + String detail = Fmt.as("Publish version {}", release.getReleaseVersion()); + auditService.create(new GeaflowAudit(taskId, GeaflowOperationType.PUBLISH, detail)); + } + } - GeaflowVersion version = versionService.getByName(view.getVersionName()); - GeaflowCluster cluster = clusterService.getByName(view.getClusterName()); - ReleaseUpdate releaseUpdate = releaseUpdateViewConverter.converter(view, version, cluster); + return releaseId; + } - GeaflowRelease newRelease = GeaflowReleaseBuilder.update(task.getRelease(), releaseUpdate); + @Override + public boolean updateRelease(String jobId, ReleaseUpdateView view) { + GeaflowTask task = taskService.getByJobId(jobId); - return releaseService.update(newRelease); + if (task.getStatus() != GeaflowTaskStatus.CREATED) { + throw new GeaflowException("Only created status can be updated"); } + GeaflowVersion version = versionService.getByName(view.getVersionName()); + GeaflowCluster cluster = clusterService.getByName(view.getClusterName()); + ReleaseUpdate releaseUpdate = releaseUpdateViewConverter.converter(view, version, cluster); + + GeaflowRelease newRelease = GeaflowReleaseBuilder.update(task.getRelease(), releaseUpdate); + + return releaseService.update(newRelease); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/RemoteFileManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/RemoteFileManagerImpl.java index 32f8bc1a0..a128aef92 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/RemoteFileManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/RemoteFileManagerImpl.java @@ -19,7 +19,6 @@ package org.apache.geaflow.console.biz.shared.impl; -import com.google.common.base.Preconditions; import java.io.InputStream; import java.util.Collections; import java.util.HashMap; @@ -29,8 +28,10 @@ import java.util.Optional; import java.util.regex.Matcher; import java.util.regex.Pattern; + import javax.annotation.PostConstruct; import javax.servlet.http.HttpServletResponse; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.biz.shared.RemoteFileManager; import org.apache.geaflow.console.biz.shared.convert.NameViewConverter; @@ -58,215 +59,213 @@ import org.springframework.transaction.annotation.Transactional; import org.springframework.web.multipart.MultipartFile; +import com.google.common.base.Preconditions; + @Service -public class RemoteFileManagerImpl extends - NameManagerImpl implements RemoteFileManager { +public class RemoteFileManagerImpl + extends NameManagerImpl + implements RemoteFileManager { - @Autowired - private RemoteFileViewConverter remoteFileViewConverter; + @Autowired private RemoteFileViewConverter remoteFileViewConverter; - @Autowired - private RemoteFileService remoteFileService; + @Autowired private RemoteFileService remoteFileService; - @Autowired - private RemoteFileStorage remoteFileStorage; + @Autowired private RemoteFileStorage remoteFileStorage; - @Autowired - private FunctionService functionService; + @Autowired private FunctionService functionService; - @Autowired - private VersionService versionService; + @Autowired private VersionService versionService; - @Autowired - private JobService jobService; + @Autowired private JobService jobService; - @Autowired - private PluginService pluginService; + @Autowired private PluginService pluginService; - private final Map fileServiceMap = new HashMap<>(); + private final Map fileServiceMap = new HashMap<>(); - @PostConstruct - public void init() { - fileServiceMap.put(GeaflowResourceType.JOB, jobService); - fileServiceMap.put(GeaflowResourceType.ENGINE_VERSION, versionService); - fileServiceMap.put(GeaflowResourceType.FUNCTION, functionService); - fileServiceMap.put(GeaflowResourceType.PLUGIN, pluginService); - } + @PostConstruct + public void init() { + fileServiceMap.put(GeaflowResourceType.JOB, jobService); + fileServiceMap.put(GeaflowResourceType.ENGINE_VERSION, versionService); + fileServiceMap.put(GeaflowResourceType.FUNCTION, functionService); + fileServiceMap.put(GeaflowResourceType.PLUGIN, pluginService); + } - @Override - protected NameViewConverter getConverter() { - return remoteFileViewConverter; - } + @Override + protected NameViewConverter getConverter() { + return remoteFileViewConverter; + } - @Override - protected NameService getService() { - return remoteFileService; - } + @Override + protected NameService getService() { + return remoteFileService; + } - @Override - protected List parse(List views) { - return ListUtil.convert(views, v -> remoteFileViewConverter.convert(v)); - } + @Override + protected List parse(List views) { + return ListUtil.convert(views, v -> remoteFileViewConverter.convert(v)); + } - private boolean checkFileName(String str) { - String pattern = "^[A-Za-z0-9@!#$%^&*()_+\\-=\\[\\]{};:'\"\\,\\.\\/\\?\\\\|`~\\s]+$"; - Pattern regex = Pattern.compile(pattern); - Matcher matcher = regex.matcher(str); - return matcher.matches(); - } - - @Transactional - @Override - public String create(RemoteFileView view, MultipartFile multipartFile) { - String name = Optional.ofNullable(view.getName()).orElse(multipartFile.getOriginalFilename()); - if (!checkFileName(name)) { - throw new GeaflowException("File name is illegal, {}", name); - } - String path = Preconditions.checkNotNull(view.getPath(), "Invalid path"); - String remoteFileId; - try { - view.setName(name); - view.setType(GeaflowFileType.of(StringUtils.substringAfterLast(multipartFile.getOriginalFilename(), "."))); - view.setMd5(Md5Util.encodeFile(multipartFile)); - // set url according to the path - String url = remoteFileStorage.getUrl(view.getPath()); - view.setUrl(url); - remoteFileId = create(view); - view.setId(remoteFileId); - } catch (Exception e) { - throw new GeaflowException("Create file {} failed", name, e); - } + private boolean checkFileName(String str) { + String pattern = "^[A-Za-z0-9@!#$%^&*()_+\\-=\\[\\]{};:'\"\\,\\.\\/\\?\\\\|`~\\s]+$"; + Pattern regex = Pattern.compile(pattern); + Matcher matcher = regex.matcher(str); + return matcher.matches(); + } - try (InputStream inputStream = multipartFile.getInputStream()) { - remoteFileStorage.upload(path, inputStream); + @Transactional + @Override + public String create(RemoteFileView view, MultipartFile multipartFile) { + String name = Optional.ofNullable(view.getName()).orElse(multipartFile.getOriginalFilename()); + if (!checkFileName(name)) { + throw new GeaflowException("File name is illegal, {}", name); + } + String path = Preconditions.checkNotNull(view.getPath(), "Invalid path"); + String remoteFileId; + try { + view.setName(name); + view.setType( + GeaflowFileType.of( + StringUtils.substringAfterLast(multipartFile.getOriginalFilename(), "."))); + view.setMd5(Md5Util.encodeFile(multipartFile)); + // set url according to the path + String url = remoteFileStorage.getUrl(view.getPath()); + view.setUrl(url); + remoteFileId = create(view); + view.setId(remoteFileId); + } catch (Exception e) { + throw new GeaflowException("Create file {} failed", name, e); + } - } catch (Exception e) { - throw new GeaflowException("Upload file {} failed", name, path, e); - } + try (InputStream inputStream = multipartFile.getInputStream()) { + remoteFileStorage.upload(path, inputStream); - return remoteFileId; + } catch (Exception e) { + throw new GeaflowException("Upload file {} failed", name, path, e); } - @Transactional - @Override - public boolean upload(String remoteFileId, MultipartFile multipartFile) { - remoteFileService.validateUpdateIds(Collections.singletonList(remoteFileId)); - GeaflowRemoteFile remoteFile = remoteFileService.get(remoteFileId); - if (remoteFile == null) { - return false; - } + return remoteFileId; + } - String name = remoteFile.getName(); - String path = remoteFile.getPath(); - - try { - String md5 = Md5Util.encodeFile(multipartFile); - if (md5.equals(remoteFile.getMd5())) { - return false; - } - remoteFileService.updateMd5ById(remoteFileId, md5); - remoteFileService.updateUrlById(remoteFileId, remoteFileStorage.getUrl(path)); - } catch (Exception e) { - throw new GeaflowException("Update file {} md5 failed", name, e); - } + @Transactional + @Override + public boolean upload(String remoteFileId, MultipartFile multipartFile) { + remoteFileService.validateUpdateIds(Collections.singletonList(remoteFileId)); + GeaflowRemoteFile remoteFile = remoteFileService.get(remoteFileId); + if (remoteFile == null) { + return false; + } - try (InputStream inputStream = multipartFile.getInputStream()) { - remoteFileStorage.upload(path, inputStream); + String name = remoteFile.getName(); + String path = remoteFile.getPath(); + + try { + String md5 = Md5Util.encodeFile(multipartFile); + if (md5.equals(remoteFile.getMd5())) { + return false; + } + remoteFileService.updateMd5ById(remoteFileId, md5); + remoteFileService.updateUrlById(remoteFileId, remoteFileStorage.getUrl(path)); + } catch (Exception e) { + throw new GeaflowException("Update file {} md5 failed", name, e); + } - } catch (Exception e) { - throw new GeaflowException("Upload file {} failed", name, path, e); - } + try (InputStream inputStream = multipartFile.getInputStream()) { + remoteFileStorage.upload(path, inputStream); - return true; + } catch (Exception e) { + throw new GeaflowException("Upload file {} failed", name, path, e); } - @Override - public boolean download(String remoteFileId, HttpServletResponse response) { - remoteFileService.validateGetIds(Collections.singletonList(remoteFileId)); - GeaflowRemoteFile remoteFile = remoteFileService.get(remoteFileId); - if (remoteFile == null) { - throw new GeaflowException("File not found"); - } + return true; + } - String path = remoteFile.getPath(); - String name = remoteFile.getName(); - - try (InputStream input = remoteFileStorage.download(path)) { - HTTPUtil.download(response, input, name); + @Override + public boolean download(String remoteFileId, HttpServletResponse response) { + remoteFileService.validateGetIds(Collections.singletonList(remoteFileId)); + GeaflowRemoteFile remoteFile = remoteFileService.get(remoteFileId); + if (remoteFile == null) { + throw new GeaflowException("File not found"); + } - } catch (Exception e) { - throw new GeaflowException("Download file {} from {} failed", name, path, e); - } + String path = remoteFile.getPath(); + String name = remoteFile.getName(); - return true; - } + try (InputStream input = remoteFileStorage.download(path)) { + HTTPUtil.download(response, input, name); - @Override - public boolean delete(String remoteFileId) { - remoteFileService.validateUpdateIds(Collections.singletonList(remoteFileId)); - // throw exception if file is used by others - checkFileUsed(remoteFileId); - return deleteFile(remoteFileId); + } catch (Exception e) { + throw new GeaflowException("Download file {} from {} failed", name, path, e); } - @Override - public List get(List ids) { - remoteFileService.validateGetIds(ids); - return super.get(ids); + return true; + } + + @Override + public boolean delete(String remoteFileId) { + remoteFileService.validateUpdateIds(Collections.singletonList(remoteFileId)); + // throw exception if file is used by others + checkFileUsed(remoteFileId); + return deleteFile(remoteFileId); + } + + @Override + public List get(List ids) { + remoteFileService.validateGetIds(ids); + return super.get(ids); + } + + @Override + public void deleteRefJar(String jarId, String refId, GeaflowResourceType resourceType) { + if (jarId == null || !fileServiceMap.containsKey(resourceType)) { + return; } - - @Override - public void deleteRefJar(String jarId, String refId, GeaflowResourceType resourceType) { - if (jarId == null || !fileServiceMap.containsKey(resourceType)) { - return; + // version jar is only used by a version, do not filter + if (resourceType != GeaflowResourceType.ENGINE_VERSION) { + for (Entry entry : fileServiceMap.entrySet()) { + FileRefService fileRefService = entry.getValue(); + // exclude itself when current type is the resourceType + long refCount = + entry.getKey() == resourceType + ? fileRefService.getFileRefCount(jarId, refId) + : fileRefService.getFileRefCount(jarId, null); + if (refCount > 0) { + return; } + } + } - // version jar is only used by a version, do not filter - if (resourceType != GeaflowResourceType.ENGINE_VERSION) { - for (Entry entry : fileServiceMap.entrySet()) { - FileRefService fileRefService = entry.getValue(); - // exclude itself when current type is the resourceType - long refCount = entry.getKey() == resourceType ? fileRefService.getFileRefCount(jarId, refId) : - fileRefService.getFileRefCount(jarId, null); - if (refCount > 0) { - return; - } - } - } + deleteFile(jarId); + } - deleteFile(jarId); + private void checkFileUsed(String remoteFileId) { + for (Entry entry : fileServiceMap.entrySet()) { + long functionCount = entry.getValue().getFileRefCount(remoteFileId, null); + if (functionCount > 0) { + throw new GeaflowException( + "file {} is used by {}, count:{}", remoteFileId, entry.getKey().name(), functionCount); + } } + } - private void checkFileUsed(String remoteFileId) { - for (Entry entry : fileServiceMap.entrySet()) { - long functionCount = entry.getValue().getFileRefCount(remoteFileId, null); - if (functionCount > 0) { - throw new GeaflowException("file {} is used by {}, count:{}", - remoteFileId, entry.getKey().name(), functionCount); - } - } + private boolean deleteFile(String remoteFileId) { + GeaflowRemoteFile remoteFile = remoteFileService.get(remoteFileId); + if (remoteFile == null) { + return false; } - private boolean deleteFile(String remoteFileId) { - GeaflowRemoteFile remoteFile = remoteFileService.get(remoteFileId); - if (remoteFile == null) { - return false; - } - - String id = remoteFile.getId(); - String name = remoteFile.getName(); - String path = remoteFile.getPath(); + String id = remoteFile.getId(); + String name = remoteFile.getName(); + String path = remoteFile.getPath(); - try { - remoteFileStorage.delete(path); - drop(id); - - } catch (Exception e) { - throw new GeaflowException("Delete file {} failed", name, e); - } + try { + remoteFileStorage.delete(path); + drop(id); - return true; + } catch (Exception e) { + throw new GeaflowException("Delete file {} failed", name, e); } + return true; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/StatementManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/StatementManagerImpl.java index 6dce4b46f..cd9863ba9 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/StatementManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/StatementManagerImpl.java @@ -21,7 +21,7 @@ import java.util.Collections; import java.util.List; -import lombok.extern.slf4j.Slf4j; + import org.apache.geaflow.console.biz.shared.StatementManager; import org.apache.geaflow.console.biz.shared.convert.IdViewConverter; import org.apache.geaflow.console.biz.shared.convert.StatementViewConverter; @@ -35,34 +35,35 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; +import lombok.extern.slf4j.Slf4j; + @Service @Slf4j -public class StatementManagerImpl extends IdManagerImpl +public class StatementManagerImpl + extends IdManagerImpl implements StatementManager { - @Autowired - private StatementService statementService; + @Autowired private StatementService statementService; - @Autowired - private StatementViewConverter statementViewConverter; + @Autowired private StatementViewConverter statementViewConverter; - @Override - public IdViewConverter getConverter() { - return statementViewConverter; - } + @Override + public IdViewConverter getConverter() { + return statementViewConverter; + } - @Override - protected List parse(List views) { - return ListUtil.convert(views, v -> statementViewConverter.convert(v)); - } + @Override + protected List parse(List views) { + return ListUtil.convert(views, v -> statementViewConverter.convert(v)); + } - @Override - public IdService getService() { - return statementService; - } + @Override + public IdService getService() { + return statementService; + } - @Override - public boolean dropByJobId(String jobId) { - return statementService.dropByJobIds(Collections.singletonList(jobId)); - } + @Override + public boolean dropByJobId(String jobId) { + return statementService.dropByJobIds(Collections.singletonList(jobId)); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/SystemConfigManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/SystemConfigManagerImpl.java index 37c896978..96a9bd64d 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/SystemConfigManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/SystemConfigManagerImpl.java @@ -19,8 +19,8 @@ package org.apache.geaflow.console.biz.shared.impl; -import com.google.common.base.Preconditions; import java.util.List; + import org.apache.geaflow.console.biz.shared.SystemConfigManager; import org.apache.geaflow.console.biz.shared.convert.NameViewConverter; import org.apache.geaflow.console.biz.shared.convert.SystemConfigViewConverter; @@ -34,72 +34,73 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; -@Service -public class SystemConfigManagerImpl extends - NameManagerImpl implements SystemConfigManager { - - @Autowired - private SystemConfigService systemConfigService; - - @Autowired - private SystemConfigViewConverter systemConfigViewConverter; - - @Override - protected NameService getService() { - return systemConfigService; - } - - @Override - protected NameViewConverter getConverter() { - return systemConfigViewConverter; - } +import com.google.common.base.Preconditions; - @Override - protected List parse(List views) { - return ListUtil.convert(views, systemConfigViewConverter::convert); +@Service +public class SystemConfigManagerImpl + extends NameManagerImpl + implements SystemConfigManager { + + @Autowired private SystemConfigService systemConfigService; + + @Autowired private SystemConfigViewConverter systemConfigViewConverter; + + @Override + protected NameService getService() { + return systemConfigService; + } + + @Override + protected NameViewConverter getConverter() { + return systemConfigViewConverter; + } + + @Override + protected List parse(List views) { + return ListUtil.convert(views, systemConfigViewConverter::convert); + } + + @Override + public SystemConfigView getConfig(String tenantId, String key) { + Preconditions.checkNotNull(key, "Invalid key"); + return build(systemConfigService.get(tenantId, key)); + } + + @Override + public String getValue(String key) { + return systemConfigService.getValue(key); + } + + @Override + public boolean createConfig(SystemConfigView view) { + String tenantId = view.getTenantId(); + String key = Preconditions.checkNotNull(view.getName(), "Invalid key"); + if (systemConfigService.exist(tenantId, key)) { + throw new GeaflowIllegalException("Key {} exists", key); } - @Override - public SystemConfigView getConfig(String tenantId, String key) { - Preconditions.checkNotNull(key, "Invalid key"); - return build(systemConfigService.get(tenantId, key)); - } + return create(view) != null; + } - @Override - public String getValue(String key) { - return systemConfigService.getValue(key); + @Override + public boolean updateConfig(String key, SystemConfigView updateView) { + String newKey = updateView.getName(); + if (newKey != null && !newKey.equals(key)) { + throw new GeaflowIllegalException("Rename key from {} to {} not allowed", key, newKey); } - @Override - public boolean createConfig(SystemConfigView view) { - String tenantId = view.getTenantId(); - String key = Preconditions.checkNotNull(view.getName(), "Invalid key"); - if (systemConfigService.exist(tenantId, key)) { - throw new GeaflowIllegalException("Key {} exists", key); - } - - return create(view) != null; + String tenantId = updateView.getTenantId(); + if (!systemConfigService.exist(tenantId, key)) { + throw new GeaflowIllegalException("Key {} not exists", key); } - @Override - public boolean updateConfig(String key, SystemConfigView updateView) { - String newKey = updateView.getName(); - if (newKey != null && !newKey.equals(key)) { - throw new GeaflowIllegalException("Rename key from {} to {} not allowed", key, newKey); - } + SystemConfigView view = getConfig(tenantId, key); + return updateById(view.getId(), updateView); + } - String tenantId = updateView.getTenantId(); - if (!systemConfigService.exist(tenantId, key)) { - throw new GeaflowIllegalException("Key {} not exists", key); - } - - SystemConfigView view = getConfig(tenantId, key); - return updateById(view.getId(), updateView); - } - - @Override - public boolean deleteConfig(String tenantId, String key) { - Preconditions.checkNotNull(key, "Invalid key"); - return systemConfigService.delete(tenantId, key); - } + @Override + public boolean deleteConfig(String tenantId, String key) { + Preconditions.checkNotNull(key, "Invalid key"); + return systemConfigService.delete(tenantId, key); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/TableManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/TableManagerImpl.java index 570728bfb..4b4a1b2ea 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/TableManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/TableManagerImpl.java @@ -19,9 +19,9 @@ package org.apache.geaflow.console.biz.shared.impl; -import com.google.common.base.Preconditions; import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.console.biz.shared.PluginConfigManager; import org.apache.geaflow.console.biz.shared.TableManager; import org.apache.geaflow.console.biz.shared.convert.DataViewConverter; @@ -43,52 +43,55 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; +import com.google.common.base.Preconditions; + @Service -public class TableManagerImpl extends DataManagerImpl implements TableManager { +public class TableManagerImpl extends DataManagerImpl + implements TableManager { - @Autowired - private TableService tableService; + @Autowired private TableService tableService; - @Autowired - private FieldViewConverter fieldViewConverter; + @Autowired private FieldViewConverter fieldViewConverter; - @Autowired - private PluginConfigViewConverter pluginConfigViewConverter; + @Autowired private PluginConfigViewConverter pluginConfigViewConverter; - @Autowired - private PluginConfigManager pluginConfigManager; + @Autowired private PluginConfigManager pluginConfigManager; - @Autowired - private TableViewConverter tableViewConverter; + @Autowired private TableViewConverter tableViewConverter; - @Override - public DataViewConverter getConverter() { - return tableViewConverter; - } + @Override + public DataViewConverter getConverter() { + return tableViewConverter; + } - @Override - public DataService getService() { - return tableService; - } + @Override + public DataService getService() { + return tableService; + } - @Override - public List parse(List views) { - return views.stream().map(v -> { - List fields = ListUtil.convert(v.getFields(), e -> fieldViewConverter.convert(e)); - GeaflowPluginConfig pluginConfig = pluginConfigViewConverter.convert(v.getPluginConfig()); - return tableViewConverter.convert(v, fields, pluginConfig); - }).collect(Collectors.toList()); - } + @Override + public List parse(List views) { + return views.stream() + .map( + v -> { + List fields = + ListUtil.convert(v.getFields(), e -> fieldViewConverter.convert(e)); + GeaflowPluginConfig pluginConfig = + pluginConfigViewConverter.convert(v.getPluginConfig()); + return tableViewConverter.convert(v, fields, pluginConfig); + }) + .collect(Collectors.toList()); + } - @Override - public List create(String instanceName, List views) { - for (TableView view : views) { - PluginConfigView pluginConfigView = Preconditions.checkNotNull(view.getPluginConfig(), - "Table pluginConfig is required"); - pluginConfigView.setCategory(GeaflowPluginCategory.TABLE); - pluginConfigView.setName(Fmt.as("{}-{}-table-config", instanceName, view.getName())); - pluginConfigManager.create(pluginConfigView); - } - return super.create(instanceName, views); + @Override + public List create(String instanceName, List views) { + for (TableView view : views) { + PluginConfigView pluginConfigView = + Preconditions.checkNotNull(view.getPluginConfig(), "Table pluginConfig is required"); + pluginConfigView.setCategory(GeaflowPluginCategory.TABLE); + pluginConfigView.setName(Fmt.as("{}-{}-table-config", instanceName, view.getName())); + pluginConfigManager.create(pluginConfigView); } + return super.create(instanceName, views); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/TaskManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/TaskManagerImpl.java index 6827d6229..e8b3a7b85 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/TaskManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/TaskManagerImpl.java @@ -29,13 +29,13 @@ import static org.apache.geaflow.console.common.util.type.GeaflowTaskStatus.RUNNING; import static org.apache.geaflow.console.common.util.type.GeaflowTaskStatus.STOPPED; -import com.alibaba.fastjson2.JSON; import java.io.File; import java.io.InputStream; import java.util.ArrayList; import java.util.List; + import javax.servlet.http.HttpServletResponse; -import lombok.extern.slf4j.Slf4j; + import org.apache.geaflow.console.biz.shared.TaskManager; import org.apache.geaflow.console.biz.shared.convert.IdViewConverter; import org.apache.geaflow.console.biz.shared.convert.TaskViewConverter; @@ -79,244 +79,246 @@ import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; -@Slf4j -@Service -public class TaskManagerImpl extends IdManagerImpl implements TaskManager { - - @Autowired - private TaskService taskService; - - @Autowired - private TaskViewConverter taskViewConverter; - - @Autowired - private GeaflowTaskOperator taskOperator; - - @Autowired - private AuditService auditService; - - @Autowired - private DeployConfig deployConfig; - - @Override - public IdViewConverter getConverter() { - return taskViewConverter; - } - - @Override - public IdService getService() { - return taskService; - } - - @Override - protected List parse(List views) { - throw new UnsupportedOperationException("Task can't be converted from view"); - } - - @Override - @Transactional(rollbackFor = Exception.class) - public void operate(String taskId, GeaflowOperationType action) { - GeaflowTask task = taskService.get(taskId); - task.getStatus().checkOperation(action); - switch (action) { - case START: - start(task); - break; - case STOP: - stop(task, false); - break; - case FINISH: - stop(task, true); - break; - case REFRESH: - taskOperator.refreshStatus(task); - break; - case RESET: - clean(task); - break; - case DELETE: - delete(task); - break; - default: - throw new UnsupportedOperationException("not supported task action: " + action); - } - } - - - protected void start(GeaflowTask task) { - GeaflowTaskStatus status = task.getStatus(); - boolean updateStatus = taskService.updateStatus(task.getId(), status, GeaflowTaskStatus.WAITING); - if (!updateStatus) { - throw new GeaflowException("task status has been changed"); - } - - task.setHost(NetworkUtil.getHostName()); - taskService.update(task); - log.info("submit task successfully, waiting for scheduling. id: {}", task.getId()); - auditService.create(new GeaflowAudit(task.getId(), GeaflowOperationType.START)); - } - - protected void stop(GeaflowTask task, boolean isFinish) { - GeaflowTaskStatus status = task.getStatus(); - if (status == RUNNING) { - taskOperator.stop(task); - } - - if (isFinish) { - taskService.updateStatus(task.getId(), status, FINISHED); - auditService.create(new GeaflowAudit(task.getId(), FINISH)); - log.info("Task {} is finished by {}", task.getId(), task.getModifierId()); - } else { - taskService.updateStatus(task.getId(), status, STOPPED); - auditService.create(new GeaflowAudit(task.getId(), STOP)); - log.info("Task {} is stopped by {}", task.getId(), task.getModifierId()); - } - } +import com.alibaba.fastjson2.JSON; - protected void clean(GeaflowTask task) { - taskOperator.cleanMeta(task); - taskOperator.cleanData(task); - auditService.create(new GeaflowAudit(task.getId(), GeaflowOperationType.RESET)); - } +import lombok.extern.slf4j.Slf4j; - protected void delete(GeaflowTask task) { +@Slf4j +@Service +public class TaskManagerImpl extends IdManagerImpl + implements TaskManager { + + @Autowired private TaskService taskService; + + @Autowired private TaskViewConverter taskViewConverter; + + @Autowired private GeaflowTaskOperator taskOperator; + + @Autowired private AuditService auditService; + + @Autowired private DeployConfig deployConfig; + + @Override + public IdViewConverter getConverter() { + return taskViewConverter; + } + + @Override + public IdService getService() { + return taskService; + } + + @Override + protected List parse(List views) { + throw new UnsupportedOperationException("Task can't be converted from view"); + } + + @Override + @Transactional(rollbackFor = Exception.class) + public void operate(String taskId, GeaflowOperationType action) { + GeaflowTask task = taskService.get(taskId); + task.getStatus().checkOperation(action); + switch (action) { + case START: + start(task); + break; + case STOP: + stop(task, false); + break; + case FINISH: + stop(task, true); + break; + case REFRESH: + taskOperator.refreshStatus(task); + break; + case RESET: clean(task); - taskService.updateStatus(task.getId(), task.getStatus(), DELETED); - auditService.create(new GeaflowAudit(task.getId(), DELETE)); + break; + case DELETE: + delete(task); + break; + default: + throw new UnsupportedOperationException("not supported task action: " + action); } - - @Override - public GeaflowTaskStatus queryStatus(String taskId, Boolean refresh) { - if (refresh != null && refresh) { - return taskOperator.refreshStatus(taskService.get(taskId)); - } - return taskService.getStatus(taskId); + } + + protected void start(GeaflowTask task) { + GeaflowTaskStatus status = task.getStatus(); + boolean updateStatus = + taskService.updateStatus(task.getId(), status, GeaflowTaskStatus.WAITING); + if (!updateStatus) { + throw new GeaflowException("task status has been changed"); } - @Override - public PageList queryPipelines(String taskId) { - GeaflowTask task = taskService.get(taskId); - return taskOperator.queryPipelines(task); - } + task.setHost(NetworkUtil.getHostName()); + taskService.update(task); + log.info("submit task successfully, waiting for scheduling. id: {}", task.getId()); + auditService.create(new GeaflowAudit(task.getId(), GeaflowOperationType.START)); + } - @Override - public PageList queryCycles(String taskId, String pipelineId) { - GeaflowTask task = taskService.get(taskId); - return taskOperator.queryCycles(task, pipelineId); + protected void stop(GeaflowTask task, boolean isFinish) { + GeaflowTaskStatus status = task.getStatus(); + if (status == RUNNING) { + taskOperator.stop(task); } - @Override - public PageList queryErrors(String taskId) { - GeaflowTask task = taskService.get(taskId); - return taskOperator.queryErrors(task); + if (isFinish) { + taskService.updateStatus(task.getId(), status, FINISHED); + auditService.create(new GeaflowAudit(task.getId(), FINISH)); + log.info("Task {} is finished by {}", task.getId(), task.getModifierId()); + } else { + taskService.updateStatus(task.getId(), status, STOPPED); + auditService.create(new GeaflowAudit(task.getId(), STOP)); + log.info("Task {} is stopped by {}", task.getId(), task.getModifierId()); } - - @Override - public PageList queryMetricMeta(String taskId) { - GeaflowTask task = taskService.get(taskId); - return taskOperator.queryMetricMeta(task); + } + + protected void clean(GeaflowTask task) { + taskOperator.cleanMeta(task); + taskOperator.cleanData(task); + auditService.create(new GeaflowAudit(task.getId(), GeaflowOperationType.RESET)); + } + + protected void delete(GeaflowTask task) { + clean(task); + taskService.updateStatus(task.getId(), task.getStatus(), DELETED); + auditService.create(new GeaflowAudit(task.getId(), DELETE)); + } + + @Override + public GeaflowTaskStatus queryStatus(String taskId, Boolean refresh) { + if (refresh != null && refresh) { + return taskOperator.refreshStatus(taskService.get(taskId)); } - - @Override - public PageList queryMetrics(String taskId, GeaflowMetricQueryRequest queryRequest) { - GeaflowTask task = taskService.get(taskId); - return taskOperator.queryMetrics(task, queryRequest); + return taskService.getStatus(taskId); + } + + @Override + public PageList queryPipelines(String taskId) { + GeaflowTask task = taskService.get(taskId); + return taskOperator.queryPipelines(task); + } + + @Override + public PageList queryCycles(String taskId, String pipelineId) { + GeaflowTask task = taskService.get(taskId); + return taskOperator.queryCycles(task, pipelineId); + } + + @Override + public PageList queryErrors(String taskId) { + GeaflowTask task = taskService.get(taskId); + return taskOperator.queryErrors(task); + } + + @Override + public PageList queryMetricMeta(String taskId) { + GeaflowTask task = taskService.get(taskId); + return taskOperator.queryMetricMeta(task); + } + + @Override + public PageList queryMetrics( + String taskId, GeaflowMetricQueryRequest queryRequest) { + GeaflowTask task = taskService.get(taskId); + return taskOperator.queryMetrics(task, queryRequest); + } + + @Override + public PageList queryOffsets(String taskId) { + GeaflowTask task = taskService.get(taskId); + return taskOperator.queryOffsets(task); + } + + @Override + public GeaflowHeartbeatInfo queryHeartbeat(String taskId) { + GeaflowTask task = taskService.get(taskId); + return taskOperator.queryHeartbeat(task); + } + + @Transactional + @Override + public void startupNotify(String taskId, TaskStartupNotifyView startupNotifyView) { + StartupNotifyInfo startupNotifyInfo; + GeaflowTaskStatus newStatus; + GeaflowTask task = taskService.get(taskId); + if (task.getHandle().getClusterType() != GeaflowPluginType.K8S) { + return; } - @Override - public PageList queryOffsets(String taskId) { - GeaflowTask task = taskService.get(taskId); - return taskOperator.queryOffsets(task); + if (startupNotifyView.isSuccess()) { + startupNotifyInfo = startupNotifyView.getData(); + newStatus = RUNNING; + } else { + startupNotifyInfo = new StartupNotifyInfo(); + newStatus = FAILED; } - - @Override - public GeaflowHeartbeatInfo queryHeartbeat(String taskId) { - GeaflowTask task = taskService.get(taskId); - return taskOperator.queryHeartbeat(task); - } - - @Transactional - @Override - public void startupNotify(String taskId, TaskStartupNotifyView startupNotifyView) { - StartupNotifyInfo startupNotifyInfo; - GeaflowTaskStatus newStatus; - GeaflowTask task = taskService.get(taskId); - if (task.getHandle().getClusterType() != GeaflowPluginType.K8S) { - return; - } - - if (startupNotifyView.isSuccess()) { - startupNotifyInfo = startupNotifyView.getData(); - newStatus = RUNNING; - } else { - startupNotifyInfo = new StartupNotifyInfo(); - newStatus = FAILED; - } - ((K8sTaskHandle) task.getHandle()).setStartupNotifyInfo(startupNotifyInfo); - taskService.update(task); - taskService.updateStatus(task.getId(), task.getStatus(), newStatus); - log.info("Task {} get startup notify '{}' from cluster", task.getId(), JSON.toJSONString(startupNotifyView)); - auditService.create(new GeaflowAudit(taskId, STARTUP_NOTIFY, "Task startup success")); + ((K8sTaskHandle) task.getHandle()).setStartupNotifyInfo(startupNotifyInfo); + taskService.update(task); + taskService.updateStatus(task.getId(), task.getStatus(), newStatus); + log.info( + "Task {} get startup notify '{}' from cluster", + task.getId(), + JSON.toJSONString(startupNotifyView)); + auditService.create(new GeaflowAudit(taskId, STARTUP_NOTIFY, "Task startup success")); + } + + @Override + public void download(String taskId, String path, HttpServletResponse response) { + // check task id + GeaflowTask task = taskService.get(taskId); + if (task == null) { + throw new GeaflowException("Invalid task id {}", taskId); } - @Override - public void download(String taskId, String path, HttpServletResponse response) { - // check task id - GeaflowTask task = taskService.get(taskId); - if (task == null) { - throw new GeaflowException("Invalid task id {}", taskId); - } - - // check task token and deploy mode - if (!taskId.equals(ContextHolder.get().getTaskId()) || !deployConfig.isLocalMode()) { - throw new GeaflowSecurityException("Download task {} file {} is not allowed", taskId, path); - } - - // check file used by task - String gatewayUrl = deployConfig.getGatewayUrl(); - String taskFileUrl = task.getTaskFileUrl(gatewayUrl, path); - List files = new ArrayList<>(); - files.addAll(task.getVersionFiles(gatewayUrl)); - files.addAll(task.getUserFiles(gatewayUrl)); - if (files.stream().noneMatch(f -> f.getUrl().equals(taskFileUrl))) { - throw new GeaflowIllegalException("Invalid task file {}", path); - } - - // download local file - String name = new File(path).getName(); - try (InputStream input = FileUtil.readFileStream(path)) { - HTTPUtil.download(response, input, name); - - } catch (Exception e) { - throw new GeaflowException("Download file {} from {} failed", name, path, e); - } + // check task token and deploy mode + if (!taskId.equals(ContextHolder.get().getTaskId()) || !deployConfig.isLocalMode()) { + throw new GeaflowSecurityException("Download task {} file {} is not allowed", taskId, path); } - @Override - public String getLogs(String taskId) { - GeaflowTask task = taskService.get(taskId); - GeaflowPluginType type = task.getRelease().getCluster().getType(); - if (type.equals(GeaflowPluginType.CONTAINER)) { - String logFilePath = ContainerRuntime.getLogFilePath(taskId); - return Fmt.as(I18nUtil.getMessage("i18n.key.container.task.log.tips"), logFilePath); - } else { - return Fmt.as(I18nUtil.getMessage("i18n.key.k8s.task.log.tips")); - } + // check file used by task + String gatewayUrl = deployConfig.getGatewayUrl(); + String taskFileUrl = task.getTaskFileUrl(gatewayUrl, path); + List files = new ArrayList<>(); + files.addAll(task.getVersionFiles(gatewayUrl)); + files.addAll(task.getUserFiles(gatewayUrl)); + if (files.stream().noneMatch(f -> f.getUrl().equals(taskFileUrl))) { + throw new GeaflowIllegalException("Invalid task file {}", path); } + // download local file + String name = new File(path).getName(); + try (InputStream input = FileUtil.readFileStream(path)) { + HTTPUtil.download(response, input, name); - @Override - public TaskView getByJobId(String jobId) { - return build(taskService.getByJobId(jobId)); + } catch (Exception e) { + throw new GeaflowException("Download file {} from {} failed", name, path, e); + } + } + + @Override + public String getLogs(String taskId) { + GeaflowTask task = taskService.get(taskId); + GeaflowPluginType type = task.getRelease().getCluster().getType(); + if (type.equals(GeaflowPluginType.CONTAINER)) { + String logFilePath = ContainerRuntime.getLogFilePath(taskId); + return Fmt.as(I18nUtil.getMessage("i18n.key.container.task.log.tips"), logFilePath); + } else { + return Fmt.as(I18nUtil.getMessage("i18n.key.k8s.task.log.tips")); } + } - @Override - public boolean drop(List ids) { - for (String id : ids) { - operate(id, DELETE); - } + @Override + public TaskView getByJobId(String jobId) { + return build(taskService.getByJobId(jobId)); + } - return super.drop(ids); + @Override + public boolean drop(List ids) { + for (String id : ids) { + operate(id, DELETE); } + return super.drop(ids); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/TenantManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/TenantManagerImpl.java index 9a9125d46..b453f3344 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/TenantManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/TenantManagerImpl.java @@ -22,6 +22,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; + import org.apache.geaflow.console.biz.shared.TenantManager; import org.apache.geaflow.console.biz.shared.convert.NameViewConverter; import org.apache.geaflow.console.biz.shared.convert.TenantViewConverter; @@ -40,65 +41,60 @@ import org.springframework.stereotype.Service; @Service -public class TenantManagerImpl extends NameManagerImpl implements - TenantManager { - - @Autowired - private TenantService tenantService; - - @Autowired - private TenantViewConverter tenantViewConverter; - - @Autowired - private UserService userService; - - @Override - protected NameService getService() { - return tenantService; +public class TenantManagerImpl extends NameManagerImpl + implements TenantManager { + + @Autowired private TenantService tenantService; + + @Autowired private TenantViewConverter tenantViewConverter; + + @Autowired private UserService userService; + + @Override + protected NameService getService() { + return tenantService; + } + + @Override + protected NameViewConverter getConverter() { + return tenantViewConverter; + } + + @Override + protected List parse(List views) { + return ListUtil.convert(views, tenantViewConverter::convert); + } + + @Override + public TenantView getActiveTenant(String userId) { + GeaflowTenant tenant = tenantService.getActiveTenant(userId); + if (tenant == null) { + List userTenants = tenantService.getUserTenants(userId); + if (userTenants.isEmpty()) { + throw new GeaflowException("User not in any tenants"); + } + + // active one tenant + tenant = userTenants.get(0); + tenantService.activateTenant(userId, tenant.getId()); } - @Override - protected NameViewConverter getConverter() { - return tenantViewConverter; - } + return tenantViewConverter.convert(tenant); + } - @Override - protected List parse(List views) { - return ListUtil.convert(views, tenantViewConverter::convert); - } + @Override + public Map getTenantNames(Collection tenantIds) { + return tenantService.getTenantNames(tenantIds); + } - @Override - public TenantView getActiveTenant(String userId) { - GeaflowTenant tenant = tenantService.getActiveTenant(userId); - if (tenant == null) { - List userTenants = tenantService.getUserTenants(userId); - if (userTenants.isEmpty()) { - throw new GeaflowException("User not in any tenants"); - } - - // active one tenant - tenant = userTenants.get(0); - tenantService.activateTenant(userId, tenant.getId()); - } - - return tenantViewConverter.convert(tenant); + @Override + public TenantView get(String tenantId) { + boolean systemSession = ContextHolder.get().isSystemSession(); + String userId = ContextHolder.get().getUserId(); + if (!systemSession && !userService.existTenantUser(tenantId, userId)) { + throw new GeaflowIllegalException("Tenant not found"); } - @Override - public Map getTenantNames(Collection tenantIds) { - return tenantService.getTenantNames(tenantIds); - } - - @Override - public TenantView get(String tenantId) { - boolean systemSession = ContextHolder.get().isSystemSession(); - String userId = ContextHolder.get().getUserId(); - if (!systemSession && !userService.existTenantUser(tenantId, userId)) { - throw new GeaflowIllegalException("Tenant not found"); - } - - return super.get(tenantId); - } - - + return super.get(tenantId); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/UserManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/UserManagerImpl.java index 1d7643ab8..f737f46d4 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/UserManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/UserManagerImpl.java @@ -22,6 +22,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; + import org.apache.geaflow.console.biz.shared.UserManager; import org.apache.geaflow.console.biz.shared.convert.NameViewConverter; import org.apache.geaflow.console.biz.shared.convert.UserViewConverter; @@ -39,83 +40,82 @@ import org.springframework.transaction.annotation.Transactional; @Service -public class UserManagerImpl extends NameManagerImpl implements UserManager { - - @Autowired - private UserService userService; - - @Autowired - private UserViewConverter userViewConverter; - - @Override - protected NameViewConverter getConverter() { - return userViewConverter; +public class UserManagerImpl extends NameManagerImpl + implements UserManager { + + @Autowired private UserService userService; + + @Autowired private UserViewConverter userViewConverter; + + @Override + protected NameViewConverter getConverter() { + return userViewConverter; + } + + @Override + protected List parse(List views) { + return ListUtil.convert(views, userViewConverter::convert); + } + + @Override + protected NameService getService() { + return userService; + } + + @Transactional + @Override + public String register(UserView view) { + GeaflowUser user = userViewConverter.convert(view); + return userService.createUser(user); + } + + @Override + public Map getUserNames(Collection userIds) { + return userService.getUserNames(userIds); + } + + @Override + public UserView getUser(String userId) { + boolean systemSession = ContextHolder.get().isSystemSession(); + String tenantId = ContextHolder.get().getTenantId(); + if (!systemSession && !userService.existTenantUser(tenantId, userId)) { + throw new GeaflowIllegalException("User not found"); } - @Override - protected List parse(List views) { - return ListUtil.convert(views, userViewConverter::convert); - } + return get(userId); + } - @Override - protected NameService getService() { - return userService; + @Override + public String addUser(UserView view) { + boolean systemSession = ContextHolder.get().isSystemSession(); + String tenantId = ContextHolder.get().getTenantId(); + if (systemSession) { + throw new GeaflowIllegalException("Use user register instead"); } - @Transactional - @Override - public String register(UserView view) { - GeaflowUser user = userViewConverter.convert(view); - return userService.createUser(user); - } + String userName = view.getName(); + String userId = userService.getIdByName(userName); + userService.addTenantUser(tenantId, userId); + return userId; + } - @Override - public Map getUserNames(Collection userIds) { - return userService.getUserNames(userIds); + @Override + public boolean updateUser(String userId, UserView view) { + if (!ContextHolder.get().getUserId().equals(userId)) { + throw new GeaflowIllegalException("Change other user failed"); } - @Override - public UserView getUser(String userId) { - boolean systemSession = ContextHolder.get().isSystemSession(); - String tenantId = ContextHolder.get().getTenantId(); - if (!systemSession && !userService.existTenantUser(tenantId, userId)) { - throw new GeaflowIllegalException("User not found"); - } + return updateById(userId, view); + } - return get(userId); + @Override + public boolean deleteUser(String userId) { + boolean systemSession = ContextHolder.get().isSystemSession(); + String tenantId = ContextHolder.get().getTenantId(); + if (systemSession) { + throw new GeaflowIllegalException("Permanently deleted user not allowed"); } - @Override - public String addUser(UserView view) { - boolean systemSession = ContextHolder.get().isSystemSession(); - String tenantId = ContextHolder.get().getTenantId(); - if (systemSession) { - throw new GeaflowIllegalException("Use user register instead"); - } - - String userName = view.getName(); - String userId = userService.getIdByName(userName); - userService.addTenantUser(tenantId, userId); - return userId; - } - - @Override - public boolean updateUser(String userId, UserView view) { - if (!ContextHolder.get().getUserId().equals(userId)) { - throw new GeaflowIllegalException("Change other user failed"); - } - - return updateById(userId, view); - } - - @Override - public boolean deleteUser(String userId) { - boolean systemSession = ContextHolder.get().isSystemSession(); - String tenantId = ContextHolder.get().getTenantId(); - if (systemSession) { - throw new GeaflowIllegalException("Permanently deleted user not allowed"); - } - - return userService.deleteTenantUser(tenantId, userId); - } + return userService.deleteTenantUser(tenantId, userId); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/VersionManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/VersionManagerImpl.java index e88c53ac4..e2084ae56 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/VersionManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/VersionManagerImpl.java @@ -21,15 +21,13 @@ import static org.apache.geaflow.console.core.service.RemoteFileService.JAR_FILE_SUFFIX; -import com.google.common.base.Preconditions; import java.io.File; import java.io.IOException; import java.io.InputStream; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; -import lombok.AllArgsConstructor; -import lombok.extern.slf4j.Slf4j; + import org.apache.commons.io.FileUtils; import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.biz.shared.RemoteFileManager; @@ -59,245 +57,255 @@ import org.springframework.transaction.annotation.Transactional; import org.springframework.web.multipart.MultipartFile; +import com.google.common.base.Preconditions; + +import lombok.AllArgsConstructor; +import lombok.extern.slf4j.Slf4j; + @Service @Slf4j -public class VersionManagerImpl extends NameManagerImpl implements - VersionManager { +public class VersionManagerImpl extends NameManagerImpl + implements VersionManager { - private static final String ENGINE_JAR_PREFIX = ""; + private static final String ENGINE_JAR_PREFIX = ""; - private static final String LANG_JAR_PREFIX = "lang-"; + private static final String LANG_JAR_PREFIX = "lang-"; - private static final String GEAFLOW_DEFAULT_VERSION_NAME = "defaultVersion"; + private static final String GEAFLOW_DEFAULT_VERSION_NAME = "defaultVersion"; - @Autowired - private VersionService versionService; + @Autowired private VersionService versionService; - @Autowired - private VersionViewConverter versionViewConverter; + @Autowired private VersionViewConverter versionViewConverter; - @Autowired - private RemoteFileService remoteFileService; + @Autowired private RemoteFileService remoteFileService; - @Autowired - private RemoteFileManager remoteFileManager; + @Autowired private RemoteFileManager remoteFileManager; - @Autowired - private RemoteFileStorage remoteFileStorage; + @Autowired private RemoteFileStorage remoteFileStorage; - @Override - protected NameViewConverter getConverter() { - return versionViewConverter; + @Override + protected NameViewConverter getConverter() { + return versionViewConverter; + } + + @Override + protected List parse(List views) { + return views.stream() + .map( + e -> { + GeaflowRemoteFile engineJar = + remoteFileService.get( + Optional.ofNullable(e.getEngineJarPackage()).map(IdView::getId).orElse(null)); + GeaflowRemoteFile langJar = + remoteFileService.get( + Optional.ofNullable(e.getLangJarPackage()).map(IdView::getId).orElse(null)); + return versionViewConverter.convert(e, engineJar, langJar); + }) + .collect(Collectors.toList()); + } + + @Override + protected NameService getService() { + return versionService; + } + + public PageList searchVersions(VersionSearch search) { + // only system admin can see none published version + if (!ContextHolder.get().isSystemSession()) { + search.setPublish(true); + } + return super.search(search); + } + + @Override + public VersionView getVersion(String name) { + // only system admin can see none published version + if (ContextHolder.get().isSystemSession()) { + return getByName(name); } - @Override - protected List parse(List views) { - return views.stream().map(e -> { - GeaflowRemoteFile engineJar = remoteFileService.get( - Optional.ofNullable(e.getEngineJarPackage()).map(IdView::getId).orElse(null)); - GeaflowRemoteFile langJar = remoteFileService.get( - Optional.ofNullable(e.getLangJarPackage()).map(IdView::getId).orElse(null)); - return versionViewConverter.convert(e, engineJar, langJar); - }).collect(Collectors.toList()); + return build(versionService.getPublishVersionByName(name)); + } + + @Override + public String createDefaultVersion() { + // in case of remote file config changed + remoteFileStorage.reset(); + + String path = + LocalFileFactory.getVersionFilePath( + GEAFLOW_DEFAULT_VERSION_NAME, GEAFLOW_DEFAULT_VERSION_NAME + ".jar"); + if (!FileUtil.exist(path)) { + throw new GeaflowIllegalException("No geaflow jar found in {}", path); } - @Override - protected NameService getService() { - return versionService; + VersionView versionView = new VersionView(); + versionView.setName(GEAFLOW_DEFAULT_VERSION_NAME); + versionView.setComment(I18nUtil.getMessage("i18n.key.default.version")); + versionView.setPublish(true); + + return createVersion(versionView, new LocalMultipartFile(new File(path)), null); + } + + @Transactional + @Override + public String createVersion( + VersionView versionView, MultipartFile engineJarFile, MultipartFile langJarFile) { + String versionName = versionView.getName(); + if (StringUtils.isBlank(versionName)) { + throw new GeaflowIllegalException("Invalid version name"); } - public PageList searchVersions(VersionSearch search) { - // only system admin can see none published version - if (!ContextHolder.get().isSystemSession()) { - search.setPublish(true); - } - return super.search(search); + if (versionService.existName(versionName)) { + throw new GeaflowIllegalException("Version name {} exists", versionName); } - @Override - public VersionView getVersion(String name) { - // only system admin can see none published version - if (ContextHolder.get().isSystemSession()) { - return getByName(name); - } + Preconditions.checkNotNull(engineJarFile, "Invalid engineJarfile"); + versionView.setEngineJarPackage( + createRemoteFile(versionName, engineJarFile, ENGINE_JAR_PREFIX)); - return build(versionService.getPublishVersionByName(name)); + if (langJarFile != null) { + versionView.setLangJarPackage(createRemoteFile(versionName, langJarFile, LANG_JAR_PREFIX)); } - @Override - public String createDefaultVersion() { - // in case of remote file config changed - remoteFileStorage.reset(); - - String path = LocalFileFactory.getVersionFilePath(GEAFLOW_DEFAULT_VERSION_NAME, - GEAFLOW_DEFAULT_VERSION_NAME + ".jar"); - if (!FileUtil.exist(path)) { - throw new GeaflowIllegalException("No geaflow jar found in {}", path); - } - - VersionView versionView = new VersionView(); - versionView.setName(GEAFLOW_DEFAULT_VERSION_NAME); - versionView.setComment(I18nUtil.getMessage("i18n.key.default.version")); - versionView.setPublish(true); - - return createVersion(versionView, new LocalMultipartFile(new File(path)), null); + return super.create(versionView); + } + + @Transactional + @Override + public boolean updateVersion( + String name, VersionView updateView, MultipartFile engineJarFile, MultipartFile langJarFile) { + VersionView view = getByName(name); + if (view == null) { + throw new GeaflowIllegalException("Version name {} not exists", name); } - @Transactional - @Override - public String createVersion(VersionView versionView, MultipartFile engineJarFile, MultipartFile langJarFile) { - String versionName = versionView.getName(); - if (StringUtils.isBlank(versionName)) { - throw new GeaflowIllegalException("Invalid version name"); - } + if (engineJarFile != null) { + updateView.setEngineJarPackage(updateEngineJarFile(view, engineJarFile)); + } - if (versionService.existName(versionName)) { - throw new GeaflowIllegalException("Version name {} exists", versionName); - } + if (langJarFile != null) { + updateView.setLangJarPackage(updateLangJarFile(view, langJarFile)); + } - Preconditions.checkNotNull(engineJarFile, "Invalid engineJarfile"); - versionView.setEngineJarPackage(createRemoteFile(versionName, engineJarFile, ENGINE_JAR_PREFIX)); + return updateById(view.getId(), updateView); + } - if (langJarFile != null) { - versionView.setLangJarPackage(createRemoteFile(versionName, langJarFile, LANG_JAR_PREFIX)); - } + @Transactional + @Override + public boolean deleteVersion(String versionName) { + GeaflowVersion version = versionService.getByName(versionName); + if (version == null) { + return false; + } - return super.create(versionView); + GeaflowRemoteFile engineJarPackage = version.getEngineJarPackage(); + if (engineJarPackage != null) { + remoteFileManager.deleteRefJar( + engineJarPackage.getId(), null, GeaflowResourceType.ENGINE_VERSION); } - @Transactional - @Override - public boolean updateVersion(String name, VersionView updateView, MultipartFile engineJarFile, - MultipartFile langJarFile) { - VersionView view = getByName(name); - if (view == null) { - throw new GeaflowIllegalException("Version name {} not exists", name); - } - - if (engineJarFile != null) { - updateView.setEngineJarPackage(updateEngineJarFile(view, engineJarFile)); - } - - if (langJarFile != null) { - updateView.setLangJarPackage(updateLangJarFile(view, langJarFile)); - } - - return updateById(view.getId(), updateView); + GeaflowRemoteFile langJarPackage = version.getLangJarPackage(); + if (langJarPackage != null) { + remoteFileManager.deleteRefJar( + langJarPackage.getId(), null, GeaflowResourceType.ENGINE_VERSION); } - @Transactional - @Override - public boolean deleteVersion(String versionName) { - GeaflowVersion version = versionService.getByName(versionName); - if (version == null) { - return false; - } - - GeaflowRemoteFile engineJarPackage = version.getEngineJarPackage(); - if (engineJarPackage != null) { - remoteFileManager.deleteRefJar(engineJarPackage.getId(), null, GeaflowResourceType.ENGINE_VERSION); - } - - GeaflowRemoteFile langJarPackage = version.getLangJarPackage(); - if (langJarPackage != null) { - remoteFileManager.deleteRefJar(langJarPackage.getId(), null, GeaflowResourceType.ENGINE_VERSION); - } - - return drop(version.getId()); + return drop(version.getId()); + } + + private RemoteFileView createRemoteFile( + String versionName, MultipartFile multipartFile, String filePrefix) { + if (!StringUtils.endsWith(multipartFile.getOriginalFilename(), JAR_FILE_SUFFIX)) { + throw new GeaflowIllegalException("Invalid jar file"); } + String fileName = filePrefix + versionName + JAR_FILE_SUFFIX; + String path = RemoteFileStorage.getVersionFilePath(versionName, fileName); - private RemoteFileView createRemoteFile(String versionName, MultipartFile multipartFile, String filePrefix) { - if (!StringUtils.endsWith(multipartFile.getOriginalFilename(), JAR_FILE_SUFFIX)) { - throw new GeaflowIllegalException("Invalid jar file"); - } + RemoteFileView remoteFileView = new RemoteFileView(); + remoteFileView.setName(fileName); + remoteFileView.setPath(path); + remoteFileManager.create(remoteFileView, multipartFile); - String fileName = filePrefix + versionName + JAR_FILE_SUFFIX; - String path = RemoteFileStorage.getVersionFilePath(versionName, fileName); + return remoteFileView; + } + + private RemoteFileView updateEngineJarFile(VersionView versionView, MultipartFile multipartFile) { + if (!StringUtils.endsWith(multipartFile.getOriginalFilename(), JAR_FILE_SUFFIX)) { + throw new GeaflowIllegalException("Invalid jar file"); + } - RemoteFileView remoteFileView = new RemoteFileView(); - remoteFileView.setName(fileName); - remoteFileView.setPath(path); - remoteFileManager.create(remoteFileView, multipartFile); + RemoteFileView engineJarPackage = versionView.getEngineJarPackage(); + if (engineJarPackage == null) { + return createRemoteFile(versionView.getName(), multipartFile, ENGINE_JAR_PREFIX); - return remoteFileView; + } else { + String remoteFileId = engineJarPackage.getId(); + remoteFileManager.upload(remoteFileId, multipartFile); + return null; } + } - private RemoteFileView updateEngineJarFile(VersionView versionView, MultipartFile multipartFile) { - if (!StringUtils.endsWith(multipartFile.getOriginalFilename(), JAR_FILE_SUFFIX)) { - throw new GeaflowIllegalException("Invalid jar file"); - } + private RemoteFileView updateLangJarFile(VersionView versionView, MultipartFile multipartFile) { + if (!StringUtils.endsWith(multipartFile.getOriginalFilename(), JAR_FILE_SUFFIX)) { + throw new GeaflowIllegalException("Invalid jar file"); + } - RemoteFileView engineJarPackage = versionView.getEngineJarPackage(); - if (engineJarPackage == null) { - return createRemoteFile(versionView.getName(), multipartFile, ENGINE_JAR_PREFIX); + RemoteFileView langJarPackage = versionView.getLangJarPackage(); + if (langJarPackage == null) { + return createRemoteFile(versionView.getName(), multipartFile, LANG_JAR_PREFIX); - } else { - String remoteFileId = engineJarPackage.getId(); - remoteFileManager.upload(remoteFileId, multipartFile); - return null; - } + } else { + String remoteFileId = langJarPackage.getId(); + remoteFileManager.upload(remoteFileId, multipartFile); + return null; } + } - private RemoteFileView updateLangJarFile(VersionView versionView, MultipartFile multipartFile) { - if (!StringUtils.endsWith(multipartFile.getOriginalFilename(), JAR_FILE_SUFFIX)) { - throw new GeaflowIllegalException("Invalid jar file"); - } + @AllArgsConstructor + private static class LocalMultipartFile implements MultipartFile { - RemoteFileView langJarPackage = versionView.getLangJarPackage(); - if (langJarPackage == null) { - return createRemoteFile(versionView.getName(), multipartFile, LANG_JAR_PREFIX); + private final File file; - } else { - String remoteFileId = langJarPackage.getId(); - remoteFileManager.upload(remoteFileId, multipartFile); - return null; - } + @Override + public String getName() { + return file.getName(); } - @AllArgsConstructor - private static class LocalMultipartFile implements MultipartFile { - - private final File file; - - @Override - public String getName() { - return file.getName(); - } - - @Override - public String getOriginalFilename() { - return file.getName(); - } - - @Override - public String getContentType() { - throw new GeaflowException("Not supported"); - } - - @Override - public boolean isEmpty() { - throw new GeaflowException("Not supported"); - } - - @Override - public long getSize() { - throw new GeaflowException("Not supported"); - } - - @Override - public byte[] getBytes() throws IOException { - throw new GeaflowException("Not supported"); - } - - @Override - public InputStream getInputStream() throws IOException { - return FileUtils.openInputStream(file); - } - - @Override - public void transferTo(File dest) throws IOException, IllegalStateException { - throw new GeaflowException("Not supported"); - } + @Override + public String getOriginalFilename() { + return file.getName(); + } + + @Override + public String getContentType() { + throw new GeaflowException("Not supported"); + } + + @Override + public boolean isEmpty() { + throw new GeaflowException("Not supported"); + } + + @Override + public long getSize() { + throw new GeaflowException("Not supported"); + } + + @Override + public byte[] getBytes() throws IOException { + throw new GeaflowException("Not supported"); + } + + @Override + public InputStream getInputStream() throws IOException { + return FileUtils.openInputStream(file); + } + + @Override + public void transferTo(File dest) throws IOException, IllegalStateException { + throw new GeaflowException("Not supported"); } + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/VertexManagerImpl.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/VertexManagerImpl.java index 9c117707f..0f129de4f 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/VertexManagerImpl.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/impl/VertexManagerImpl.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.console.biz.shared.VertexManager; import org.apache.geaflow.console.biz.shared.convert.DataViewConverter; import org.apache.geaflow.console.biz.shared.convert.FieldViewConverter; @@ -37,44 +38,44 @@ import org.springframework.stereotype.Service; @Service -public class VertexManagerImpl extends DataManagerImpl implements - VertexManager { - - @Autowired - private VertexService vertexService; +public class VertexManagerImpl extends DataManagerImpl + implements VertexManager { - @Autowired - private VertexViewConverter vertexViewConverter; + @Autowired private VertexService vertexService; - @Autowired - private FieldViewConverter fieldViewConverter; + @Autowired private VertexViewConverter vertexViewConverter; - @Override - public DataViewConverter getConverter() { - return vertexViewConverter; - } + @Autowired private FieldViewConverter fieldViewConverter; - @Override - public DataService getService() { - return vertexService; - } + @Override + public DataViewConverter getConverter() { + return vertexViewConverter; + } - @Override - protected List parse(List views) { - return views.stream().map(e -> { - List fields = ListUtil.convert(e.getFields(), fieldViewConverter::convert); - return vertexViewConverter.converter(e, fields); - }).collect(Collectors.toList()); - } + @Override + public DataService getService() { + return vertexService; + } - @Override - public List getVerticesByGraphId(String graphId) { - return vertexService.getVerticesByGraphId(graphId); - } + @Override + protected List parse(List views) { + return views.stream() + .map( + e -> { + List fields = + ListUtil.convert(e.getFields(), fieldViewConverter::convert); + return vertexViewConverter.converter(e, fields); + }) + .collect(Collectors.toList()); + } - @Override - public List getVerticesByEdgeId(String edgeId) { - return vertexService.getVerticesByEdgeId(edgeId); - } + @Override + public List getVerticesByGraphId(String graphId) { + return vertexService.getVerticesByGraphId(graphId); + } + @Override + public List getVerticesByEdgeId(String edgeId) { + return vertexService.getVerticesByEdgeId(edgeId); + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/AuditView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/AuditView.java index 8369f81de..893bfd007 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/AuditView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/AuditView.java @@ -19,22 +19,23 @@ package org.apache.geaflow.console.biz.shared.view; +import org.apache.geaflow.console.common.util.type.GeaflowOperationType; +import org.apache.geaflow.console.common.util.type.GeaflowResourceType; + import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowOperationType; -import org.apache.geaflow.console.common.util.type.GeaflowResourceType; @Getter @Setter @NoArgsConstructor public class AuditView extends IdView { - private GeaflowOperationType operationType; + private GeaflowOperationType operationType; - private String detail; + private String detail; - private String resourceId; + private String resourceId; - private GeaflowResourceType resourceType; + private GeaflowResourceType resourceType; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/AuthenticationView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/AuthenticationView.java index b98ee6a5b..4bf2bd9f6 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/AuthenticationView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/AuthenticationView.java @@ -26,12 +26,11 @@ @Setter public class AuthenticationView { - private String userId; + private String userId; - private String sessionToken; + private String sessionToken; - private boolean systemSession; - - private String accessTime; + private boolean systemSession; + private String accessTime; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/AuthorizationView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/AuthorizationView.java index e13bc4472..afbb20cf9 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/AuthorizationView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/AuthorizationView.java @@ -19,23 +19,23 @@ package org.apache.geaflow.console.biz.shared.view; +import org.apache.geaflow.console.common.util.type.GeaflowAuthorityType; +import org.apache.geaflow.console.common.util.type.GeaflowResourceType; + import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowAuthorityType; -import org.apache.geaflow.console.common.util.type.GeaflowResourceType; @Getter @Setter @NoArgsConstructor public class AuthorizationView extends IdView { - private String userId; - - private GeaflowAuthorityType authorityType; + private String userId; - private GeaflowResourceType resourceType; + private GeaflowAuthorityType authorityType; - private String resourceId; + private GeaflowResourceType resourceType; + private String resourceId; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/ChatView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/ChatView.java index f7dbc143a..32e9e6f2b 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/ChatView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/ChatView.java @@ -19,24 +19,24 @@ package org.apache.geaflow.console.biz.shared.view; +import org.apache.geaflow.console.common.util.type.GeaflowStatementStatus; + import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowStatementStatus; @Getter @Setter @NoArgsConstructor public class ChatView extends IdView { - private String prompt; - - private String answer; + private String prompt; - private String modelId; + private String answer; - private String jobId; + private String modelId; - private GeaflowStatementStatus status; + private String jobId; + private GeaflowStatementStatus status; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/ClusterOperationView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/ClusterOperationView.java index 625706b1e..395f4ae7e 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/ClusterOperationView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/ClusterOperationView.java @@ -26,6 +26,4 @@ @Getter @Setter @NoArgsConstructor -public class ClusterOperationView { - -} +public class ClusterOperationView {} diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/ClusterView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/ClusterView.java index d24afb4c2..82b4e7360 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/ClusterView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/ClusterView.java @@ -19,19 +19,19 @@ package org.apache.geaflow.console.biz.shared.view; +import org.apache.geaflow.console.common.util.type.GeaflowPluginType; +import org.apache.geaflow.console.core.model.config.GeaflowConfig; + import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowPluginType; -import org.apache.geaflow.console.core.model.config.GeaflowConfig; @Getter @Setter @NoArgsConstructor public class ClusterView extends NameView { - private GeaflowPluginType type; - - private GeaflowConfig config; + private GeaflowPluginType type; + private GeaflowConfig config; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/DataView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/DataView.java index fa9ed78c3..58ea7b05c 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/DataView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/DataView.java @@ -24,6 +24,4 @@ @Getter @Setter -public abstract class DataView extends NameView { - -} +public abstract class DataView extends NameView {} diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/EdgeView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/EdgeView.java index c014ffb7b..ee26cf7e8 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/EdgeView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/EdgeView.java @@ -19,16 +19,16 @@ package org.apache.geaflow.console.biz.shared.view; +import org.apache.geaflow.console.common.util.type.GeaflowStructType; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowStructType; @Setter @Getter public class EdgeView extends StructView { - public EdgeView() { - type = GeaflowStructType.EDGE; - } - + public EdgeView() { + type = GeaflowStructType.EDGE; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/EndpointView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/EndpointView.java index 5597c6735..d8c74bd19 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/EndpointView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/EndpointView.java @@ -30,10 +30,9 @@ @NoArgsConstructor public class EndpointView { - private String edgeName; + private String edgeName; - private String sourceName; + private String sourceName; - private String targetName; - -} \ No newline at end of file + private String targetName; +} diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/FieldView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/FieldView.java index 1594d3192..4269c243e 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/FieldView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/FieldView.java @@ -19,24 +19,26 @@ package org.apache.geaflow.console.biz.shared.view; +import org.apache.geaflow.console.common.util.type.GeaflowFieldCategory; +import org.apache.geaflow.console.common.util.type.GeaflowFieldType; + import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowFieldCategory; -import org.apache.geaflow.console.common.util.type.GeaflowFieldType; @Setter @Getter @NoArgsConstructor public class FieldView extends NameView { - private GeaflowFieldType type; + private GeaflowFieldType type; - private GeaflowFieldCategory category; + private GeaflowFieldCategory category; - public FieldView(String name, String comment, GeaflowFieldType type, GeaflowFieldCategory category) { - super(name, comment); - this.type = type; - this.category = category; - } + public FieldView( + String name, String comment, GeaflowFieldType type, GeaflowFieldCategory category) { + super(name, comment); + this.type = type; + this.category = category; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/FunctionView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/FunctionView.java index e78e53f95..62929d89f 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/FunctionView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/FunctionView.java @@ -30,8 +30,7 @@ @NoArgsConstructor public class FunctionView extends DataView { - private RemoteFileView jarPackage; - - private String entryClass; + private RemoteFileView jarPackage; + private String entryClass; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/GraphView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/GraphView.java index b5b4a888f..e73b8191d 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/GraphView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/GraphView.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.biz.shared.view; import java.util.List; + import lombok.Getter; import lombok.Setter; @@ -27,12 +28,11 @@ @Getter public class GraphView extends DataView { - private List vertices; - - private List edges; + private List vertices; - private PluginConfigView pluginConfig; + private List edges; - private List endpoints; + private PluginConfigView pluginConfig; + private List endpoints; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/IdView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/IdView.java index f8e3dbd63..0fa05efb0 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/IdView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/IdView.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.biz.shared.view; import java.io.Serializable; + import lombok.Getter; import lombok.Setter; @@ -27,22 +28,21 @@ @Setter public abstract class IdView implements Serializable { - protected String tenantId; - - protected String tenantName; + protected String tenantId; - protected String id; + protected String tenantName; - protected String createTime; + protected String id; - protected String modifyTime; + protected String createTime; - protected String creatorId; + protected String modifyTime; - protected String creatorName; + protected String creatorId; - protected String modifierId; + protected String creatorName; - protected String modifierName; + protected String modifierId; + protected String modifierName; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/InstallView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/InstallView.java index 2351cd003..187615dfc 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/InstallView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/InstallView.java @@ -19,26 +19,26 @@ package org.apache.geaflow.console.biz.shared.view; +import org.apache.geaflow.console.common.util.type.GeaflowDeployMode; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowDeployMode; @Getter @Setter public class InstallView extends IdView { - private PluginConfigView runtimeClusterConfig; - - private PluginConfigView runtimeMetaConfig; + private PluginConfigView runtimeClusterConfig; - private PluginConfigView haMetaConfig; + private PluginConfigView runtimeMetaConfig; - private PluginConfigView metricConfig; + private PluginConfigView haMetaConfig; - private PluginConfigView remoteFileConfig; + private PluginConfigView metricConfig; - private PluginConfigView dataConfig; + private PluginConfigView remoteFileConfig; - private GeaflowDeployMode deployMode; + private PluginConfigView dataConfig; + private GeaflowDeployMode deployMode; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/InstanceView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/InstanceView.java index fe993a18d..f33375167 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/InstanceView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/InstanceView.java @@ -26,6 +26,4 @@ @Getter @Setter @NoArgsConstructor -public class InstanceView extends NameView { - -} +public class InstanceView extends NameView {} diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/JobOperationView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/JobOperationView.java index 530121188..050503c33 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/JobOperationView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/JobOperationView.java @@ -26,6 +26,4 @@ @Getter @Setter @NoArgsConstructor -public class JobOperationView { - -} +public class JobOperationView {} diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/JobView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/JobView.java index 0414abad6..655e23262 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/JobView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/JobView.java @@ -20,33 +20,35 @@ package org.apache.geaflow.console.biz.shared.view; import java.util.List; + +import org.apache.geaflow.console.common.util.type.GeaflowJobType; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowJobType; @Setter @Getter public class JobView extends NameView { - private String instanceId; + private String instanceId; - private String instanceName; + private String instanceName; - private GeaflowJobType type; + private GeaflowJobType type; - private String userCode; + private String userCode; - private String structMappings; + private String structMappings; - private List structs; + private List structs; - private List graphs; + private List graphs; - private SlaView sla; + private SlaView sla; - private List functions; + private List functions; - private String entryClass; + private String entryClass; - private RemoteFileView jarPackage; + private RemoteFileView jarPackage; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/LLMView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/LLMView.java index 55707aa48..495504549 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/LLMView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/LLMView.java @@ -19,20 +19,20 @@ package org.apache.geaflow.console.biz.shared.view; +import org.apache.geaflow.console.common.util.type.GeaflowLLMType; + import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowLLMType; @Getter @Setter @NoArgsConstructor public class LLMView extends NameView { - private String url; - - private GeaflowLLMType type; + private String url; - private String args; + private GeaflowLLMType type; + private String args; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/LoginView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/LoginView.java index 432aabda2..1e679bafc 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/LoginView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/LoginView.java @@ -26,10 +26,9 @@ @Setter public class LoginView { - private String loginName; + private String loginName; - private String password; - - private boolean systemLogin; + private String password; + private boolean systemLogin; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/NameView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/NameView.java index b503d260f..880c051f2 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/NameView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/NameView.java @@ -28,12 +28,12 @@ @NoArgsConstructor public abstract class NameView extends IdView { - protected String name; + protected String name; - protected String comment; + protected String comment; - public NameView(String name, String comment) { - this.name = name; - this.comment = comment; - } + public NameView(String name, String comment) { + this.name = name; + this.comment = comment; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/PluginConfigView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/PluginConfigView.java index e589b93bb..8848c891a 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/PluginConfigView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/PluginConfigView.java @@ -19,19 +19,19 @@ package org.apache.geaflow.console.biz.shared.view; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.common.util.type.GeaflowPluginCategory; import org.apache.geaflow.console.core.model.config.GeaflowConfig; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class PluginConfigView extends NameView { - private String type; - - private GeaflowConfig config; + private String type; - private GeaflowPluginCategory category; + private GeaflowConfig config; + private GeaflowPluginCategory category; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/PluginView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/PluginView.java index 777d904e3..badf23599 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/PluginView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/PluginView.java @@ -19,21 +19,22 @@ package org.apache.geaflow.console.biz.shared.view; +import org.apache.geaflow.console.common.util.type.GeaflowPluginCategory; + import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowPluginCategory; @Getter @Setter @NoArgsConstructor public class PluginView extends NameView { - private String type; + private String type; - private GeaflowPluginCategory category; + private GeaflowPluginCategory category; - private RemoteFileView jarPackage; + private RemoteFileView jarPackage; - private boolean system; + private boolean system; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/ReleaseUpdateView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/ReleaseUpdateView.java index fcf537b69..c034f2c03 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/ReleaseUpdateView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/ReleaseUpdateView.java @@ -20,22 +20,23 @@ package org.apache.geaflow.console.biz.shared.view; import java.util.Map; + +import org.apache.geaflow.console.core.model.config.GeaflowConfig; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.core.model.config.GeaflowConfig; @Getter @Setter public class ReleaseUpdateView { - private String versionName; - - private String clusterName; + private String versionName; - private Map newParallelisms; + private String clusterName; - private GeaflowConfig newJobConfig; + private Map newParallelisms; - private GeaflowConfig newClusterConfig; + private GeaflowConfig newJobConfig; + private GeaflowConfig newClusterConfig; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/ReleaseView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/ReleaseView.java index 9add44413..0f65aa55d 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/ReleaseView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/ReleaseView.java @@ -19,28 +19,29 @@ package org.apache.geaflow.console.biz.shared.view; +import org.apache.geaflow.console.core.model.config.GeaflowConfig; +import org.apache.geaflow.console.core.model.release.JobPlan; + import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.geaflow.console.core.model.config.GeaflowConfig; -import org.apache.geaflow.console.core.model.release.JobPlan; @Getter @Setter @NoArgsConstructor public class ReleaseView extends IdView { - private JobView job; + private JobView job; - private String versionName; + private String versionName; - private JobPlan jobPlan; + private JobPlan jobPlan; - private GeaflowConfig jobConfig; + private GeaflowConfig jobConfig; - private GeaflowConfig clusterConfig; + private GeaflowConfig clusterConfig; - private String clusterName; + private String clusterName; - private int releaseVersion; + private int releaseVersion; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/RemoteFileView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/RemoteFileView.java index 442eb61d1..db28e57b6 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/RemoteFileView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/RemoteFileView.java @@ -19,11 +19,12 @@ package org.apache.geaflow.console.biz.shared.view; +import org.apache.geaflow.console.common.util.type.GeaflowFileType; + import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowFileType; @Setter @Getter @@ -31,12 +32,11 @@ @NoArgsConstructor public class RemoteFileView extends NameView { - private GeaflowFileType type; - - private String path; + private GeaflowFileType type; - protected String url; + private String path; - private String md5; + protected String url; + private String md5; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/SessionView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/SessionView.java index 3e6cce8e2..7ea370953 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/SessionView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/SessionView.java @@ -20,19 +20,21 @@ package org.apache.geaflow.console.biz.shared.view; import java.util.Set; + +import org.apache.geaflow.console.common.util.type.GeaflowRoleType; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowRoleType; @Getter @Setter public class SessionView { - private UserView user; + private UserView user; - private TenantView tenant; + private TenantView tenant; - private AuthenticationView authentication; + private AuthenticationView authentication; - private Set roleTypes; + private Set roleTypes; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/SlaView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/SlaView.java index 71b2d6e40..b1a22fa93 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/SlaView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/SlaView.java @@ -19,6 +19,4 @@ package org.apache.geaflow.console.biz.shared.view; -public class SlaView extends IdView { - -} +public class SlaView extends IdView {} diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/StatementView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/StatementView.java index 78210e06c..7f250b63d 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/StatementView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/StatementView.java @@ -19,19 +19,20 @@ package org.apache.geaflow.console.biz.shared.view; +import org.apache.geaflow.console.common.util.type.GeaflowStatementStatus; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowStatementStatus; @Setter @Getter public class StatementView extends IdView { - private String script; + private String script; - private Object result; + private Object result; - private GeaflowStatementStatus status; + private GeaflowStatementStatus status; - private String jobId; + private String jobId; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/StructView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/StructView.java index 32af4abd7..5049b8bf9 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/StructView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/StructView.java @@ -20,17 +20,17 @@ package org.apache.geaflow.console.biz.shared.view; import java.util.List; + +import org.apache.geaflow.console.common.util.type.GeaflowStructType; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowStructType; @Setter @Getter public abstract class StructView extends DataView { - protected GeaflowStructType type; - - protected List fields; - + protected GeaflowStructType type; + protected List fields; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/SystemConfigView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/SystemConfigView.java index 4f69357ed..1ac789c2b 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/SystemConfigView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/SystemConfigView.java @@ -26,6 +26,5 @@ @Setter public class SystemConfigView extends NameView { - private String value; - + private String value; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/TableView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/TableView.java index ca5b47ec5..4653cb8ff 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/TableView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/TableView.java @@ -19,17 +19,18 @@ package org.apache.geaflow.console.biz.shared.view; +import org.apache.geaflow.console.common.util.type.GeaflowStructType; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowStructType; @Setter @Getter public class TableView extends StructView { - private PluginConfigView pluginConfig; + private PluginConfigView pluginConfig; - public TableView() { - type = GeaflowStructType.TABLE; - } + public TableView() { + type = GeaflowStructType.TABLE; + } } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/TaskOperationView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/TaskOperationView.java index 6bc1fb891..c1f368463 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/TaskOperationView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/TaskOperationView.java @@ -19,16 +19,16 @@ package org.apache.geaflow.console.biz.shared.view; +import org.apache.geaflow.console.common.util.type.GeaflowOperationType; + import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowOperationType; @Getter @Setter @NoArgsConstructor public class TaskOperationView { - private GeaflowOperationType action; - + private GeaflowOperationType action; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/TaskStartupNotifyView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/TaskStartupNotifyView.java index 668b79277..471104b97 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/TaskStartupNotifyView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/TaskStartupNotifyView.java @@ -19,18 +19,18 @@ package org.apache.geaflow.console.biz.shared.view; +import org.apache.geaflow.console.core.model.task.K8sTaskHandle.StartupNotifyInfo; import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.core.model.task.K8sTaskHandle.StartupNotifyInfo; @Setter @Getter public class TaskStartupNotifyView { - private boolean success; + private boolean success; - private String message; + private String message; - private StartupNotifyInfo data; + private StartupNotifyInfo data; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/TaskView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/TaskView.java index 4b402c2fe..ce33bf2b6 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/TaskView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/TaskView.java @@ -20,30 +20,32 @@ package org.apache.geaflow.console.biz.shared.view; import java.util.Date; -import lombok.Getter; -import lombok.Setter; + import org.apache.geaflow.console.common.util.type.GeaflowTaskStatus; import org.apache.geaflow.console.common.util.type.GeaflowTaskType; +import lombok.Getter; +import lombok.Setter; + @Setter @Getter public class TaskView extends IdView { - private ReleaseView release; + private ReleaseView release; - private GeaflowTaskType type; + private GeaflowTaskType type; - private GeaflowTaskStatus status; + private GeaflowTaskStatus status; - private Date startTime; + private Date startTime; - private Date endTime; + private Date endTime; - private PluginConfigView runtimeMetaPluginConfig; + private PluginConfigView runtimeMetaPluginConfig; - private PluginConfigView haMetaPluginConfig; + private PluginConfigView haMetaPluginConfig; - private PluginConfigView metricPluginConfig; + private PluginConfigView metricPluginConfig; - private PluginConfigView dataPluginConfig; + private PluginConfigView dataPluginConfig; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/TenantView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/TenantView.java index ec282c3a1..0b70cd6f3 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/TenantView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/TenantView.java @@ -24,6 +24,4 @@ @Getter @Setter -public class TenantView extends NameView { - -} +public class TenantView extends NameView {} diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/UserView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/UserView.java index 47a913427..2e6c2cfe5 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/UserView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/UserView.java @@ -26,10 +26,9 @@ @Setter public class UserView extends NameView { - private String password; + private String password; - private String phone; - - private String email; + private String phone; + private String email; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/VersionView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/VersionView.java index 0d5ea3ee3..ddd0b8eb6 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/VersionView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/VersionView.java @@ -28,10 +28,9 @@ @NoArgsConstructor public class VersionView extends NameView { - private RemoteFileView engineJarPackage; + private RemoteFileView engineJarPackage; - private RemoteFileView langJarPackage; - - private Boolean publish = false; + private RemoteFileView langJarPackage; + private Boolean publish = false; } diff --git a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/VertexView.java b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/VertexView.java index 698f9547a..609cf0d8f 100644 --- a/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/VertexView.java +++ b/geaflow-console/app/biz/shared/src/main/java/org/apache/geaflow/console/biz/shared/view/VertexView.java @@ -19,16 +19,16 @@ package org.apache.geaflow.console.biz.shared.view; +import org.apache.geaflow.console.common.util.type.GeaflowStructType; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowStructType; @Setter @Getter public class VertexView extends StructView { - public VertexView() { - type = GeaflowStructType.VERTEX; - } - + public VertexView() { + type = GeaflowStructType.VERTEX; + } } diff --git a/geaflow-console/app/bootstrap/src/main/java/org/apache/geaflow/console/bootstrap/GeaflowApplication.java b/geaflow-console/app/bootstrap/src/main/java/org/apache/geaflow/console/bootstrap/GeaflowApplication.java index 4b2e2729a..0e3d0c878 100644 --- a/geaflow-console/app/bootstrap/src/main/java/org/apache/geaflow/console/bootstrap/GeaflowApplication.java +++ b/geaflow-console/app/bootstrap/src/main/java/org/apache/geaflow/console/bootstrap/GeaflowApplication.java @@ -34,8 +34,7 @@ @EnableScheduling public class GeaflowApplication { - public static void main(String[] args) { - SpringApplication.run(GeaflowApplication.class, args); - } - + public static void main(String[] args) { + SpringApplication.run(GeaflowApplication.class, args); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/IdGenerator.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/IdGenerator.java index 159c63eb0..c1fbe8e69 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/IdGenerator.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/IdGenerator.java @@ -25,7 +25,7 @@ @Component public class IdGenerator { - public static String nextId() { - return System.currentTimeMillis() + RandomStringUtils.randomNumeric(6); - } + public static String nextId() { + return System.currentTimeMillis() + RandomStringUtils.randomNumeric(6); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/AuditDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/AuditDao.java index a6e55e777..50c0701ed 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/AuditDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/AuditDao.java @@ -19,21 +19,32 @@ package org.apache.geaflow.console.common.dal.dao; -import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.dal.entity.AuditEntity; import org.apache.geaflow.console.common.dal.mapper.AuditMapper; import org.apache.geaflow.console.common.dal.model.AuditSearch; import org.springframework.stereotype.Repository; -@Repository -public class AuditDao extends TenantLevelDao implements IdDao { +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; - @Override - public void configSearch(LambdaQueryWrapper wrapper, AuditSearch search) { - wrapper.eq(search.getResourceType() != null, AuditEntity::getResourceType, search.getResourceType()) - .eq(StringUtils.isNotBlank(search.getResourceId()), AuditEntity::getResourceId, search.getResourceId()) - .eq(search.getOperationType() != null, AuditEntity::getOperationType, search.getOperationType()); - } +@Repository +public class AuditDao extends TenantLevelDao + implements IdDao { + @Override + public void configSearch(LambdaQueryWrapper wrapper, AuditSearch search) { + wrapper + .eq( + search.getResourceType() != null, + AuditEntity::getResourceType, + search.getResourceType()) + .eq( + StringUtils.isNotBlank(search.getResourceId()), + AuditEntity::getResourceId, + search.getResourceId()) + .eq( + search.getOperationType() != null, + AuditEntity::getOperationType, + search.getOperationType()); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/AuthorizationDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/AuthorizationDao.java index a33fe9fad..1995a0cb1 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/AuthorizationDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/AuthorizationDao.java @@ -19,8 +19,8 @@ package org.apache.geaflow.console.common.dal.dao; -import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import java.util.List; + import org.apache.geaflow.console.common.dal.entity.AuthorizationEntity; import org.apache.geaflow.console.common.dal.mapper.AuthorizationMapper; import org.apache.geaflow.console.common.dal.model.AuthorizationSearch; @@ -28,32 +28,53 @@ import org.apache.geaflow.console.common.util.type.GeaflowResourceType; import org.springframework.stereotype.Repository; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; + @Repository -public class AuthorizationDao extends TenantLevelDao implements - IdDao { - - public boolean exist(String userId, GeaflowAuthorityType authorityType, GeaflowResourceType resourceType, - String resourceId) { - return lambdaQuery().eq(AuthorizationEntity::getUserId, userId) - .eq(AuthorizationEntity::getResourceId, resourceId) - .eq(AuthorizationEntity::getResourceType, resourceType) - .and(e -> e.eq(AuthorizationEntity::getAuthorityType, authorityType).or() - // ALL includes other types - .eq(authorityType != GeaflowAuthorityType.ALL, AuthorizationEntity::getAuthorityType, GeaflowAuthorityType.ALL)) - .exists(); - } - - @Override - public void configSearch(LambdaQueryWrapper wrapper, AuthorizationSearch search) { - wrapper.eq(search.getAuthorityType() != null, AuthorizationEntity::getAuthorityType, search.getAuthorityType()); - wrapper.eq(search.getResourceId() != null, AuthorizationEntity::getResourceId, search.getResourceId()); - wrapper.eq(search.getResourceType() != null, AuthorizationEntity::getResourceType, search.getResourceType()); - wrapper.eq(search.getUserId() != null, AuthorizationEntity::getUserId, search.getUserId()); - } - - public boolean dropByResources(List resourceIds, GeaflowResourceType type) { - return lambdaUpdate().in(AuthorizationEntity::getResourceId, resourceIds) - .eq(AuthorizationEntity::getResourceType, type) - .remove(); - } +public class AuthorizationDao extends TenantLevelDao + implements IdDao { + + public boolean exist( + String userId, + GeaflowAuthorityType authorityType, + GeaflowResourceType resourceType, + String resourceId) { + return lambdaQuery() + .eq(AuthorizationEntity::getUserId, userId) + .eq(AuthorizationEntity::getResourceId, resourceId) + .eq(AuthorizationEntity::getResourceType, resourceType) + .and( + e -> + e.eq(AuthorizationEntity::getAuthorityType, authorityType) + .or() + // ALL includes other types + .eq( + authorityType != GeaflowAuthorityType.ALL, + AuthorizationEntity::getAuthorityType, + GeaflowAuthorityType.ALL)) + .exists(); + } + + @Override + public void configSearch( + LambdaQueryWrapper wrapper, AuthorizationSearch search) { + wrapper.eq( + search.getAuthorityType() != null, + AuthorizationEntity::getAuthorityType, + search.getAuthorityType()); + wrapper.eq( + search.getResourceId() != null, AuthorizationEntity::getResourceId, search.getResourceId()); + wrapper.eq( + search.getResourceType() != null, + AuthorizationEntity::getResourceType, + search.getResourceType()); + wrapper.eq(search.getUserId() != null, AuthorizationEntity::getUserId, search.getUserId()); + } + + public boolean dropByResources(List resourceIds, GeaflowResourceType type) { + return lambdaUpdate() + .in(AuthorizationEntity::getResourceId, resourceIds) + .eq(AuthorizationEntity::getResourceType, type) + .remove(); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/ChatDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/ChatDao.java index 08726bd81..a86f49d64 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/ChatDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/ChatDao.java @@ -19,28 +19,29 @@ package org.apache.geaflow.console.common.dal.dao; -import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import org.apache.geaflow.console.common.dal.entity.ChatEntity; import org.apache.geaflow.console.common.dal.mapper.ChatMapper; import org.apache.geaflow.console.common.dal.model.ChatSearch; import org.apache.geaflow.console.common.util.context.ContextHolder; import org.springframework.stereotype.Repository; -@Repository -public class ChatDao extends TenantLevelDao implements IdDao { +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; - public boolean dropByJobId(String jobId) { - String userId = ContextHolder.get().getUserId(); - return lambdaUpdate() - .eq(ChatEntity::getCreatorId, userId) - .eq(ChatEntity::getJobId, jobId) - .remove(); - } +@Repository +public class ChatDao extends TenantLevelDao + implements IdDao { + public boolean dropByJobId(String jobId) { + String userId = ContextHolder.get().getUserId(); + return lambdaUpdate() + .eq(ChatEntity::getCreatorId, userId) + .eq(ChatEntity::getJobId, jobId) + .remove(); + } - @Override - public void configSearch(LambdaQueryWrapper wrapper, ChatSearch search) { - wrapper.eq(search.getJobId() != null, ChatEntity::getJobId, search.getJobId()); - wrapper.eq(search.getModelId() != null, ChatEntity::getModelId, search.getModelId()); - } + @Override + public void configSearch(LambdaQueryWrapper wrapper, ChatSearch search) { + wrapper.eq(search.getJobId() != null, ChatEntity::getJobId, search.getJobId()); + wrapper.eq(search.getModelId() != null, ChatEntity::getModelId, search.getModelId()); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/ClusterDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/ClusterDao.java index 0ff54bcf9..9dff16f18 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/ClusterDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/ClusterDao.java @@ -26,9 +26,10 @@ import org.springframework.stereotype.Repository; @Repository -public class ClusterDao extends SystemLevelDao implements NameDao { +public class ClusterDao extends SystemLevelDao + implements NameDao { - public ClusterEntity getDefaultCluster() { - return lambdaQuery().orderByDesc(IdEntity::getGmtModified).last("limit 1").one(); - } + public ClusterEntity getDefaultCluster() { + return lambdaQuery().orderByDesc(IdEntity::getGmtModified).last("limit 1").one(); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/DataDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/DataDao.java index 6d8a0eebe..8ee0f6fa0 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/DataDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/DataDao.java @@ -19,14 +19,12 @@ package org.apache.geaflow.console.common.dal.dao; -import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; -import com.github.yulichang.wrapper.MPJLambdaWrapper; -import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.console.common.dal.entity.DataEntity; import org.apache.geaflow.console.common.dal.entity.IdEntity; import org.apache.geaflow.console.common.dal.entity.NameEntity; @@ -35,74 +33,81 @@ import org.apache.geaflow.console.common.util.exception.GeaflowException; import org.springframework.util.CollectionUtils; +import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; +import com.github.yulichang.wrapper.MPJLambdaWrapper; +import com.google.common.base.Preconditions; + public interface DataDao extends NameDao { - String INSTANCE_ID_FIELD_NAME = "instance_id"; + String INSTANCE_ID_FIELD_NAME = "instance_id"; - @Override - default List getByNames(List names) { - throw new GeaflowException("Use getByNames(instanceId, names) instead"); - } + @Override + default List getByNames(List names) { + throw new GeaflowException("Use getByNames(instanceId, names) instead"); + } - @Override - default Map getIdsByNames(List names) { - throw new GeaflowException("Use getIdsByNames(instanceId, names) instead"); - } + @Override + default Map getIdsByNames(List names) { + throw new GeaflowException("Use getIdsByNames(instanceId, names) instead"); + } - default E getByName(String instanceId, String name) { - if (name == null) { - return null; - } - List entities = getByNames(instanceId, Collections.singletonList(name)); - return entities.isEmpty() ? null : entities.get(0); + default E getByName(String instanceId, String name) { + if (name == null) { + return null; } + List entities = getByNames(instanceId, Collections.singletonList(name)); + return entities.isEmpty() ? null : entities.get(0); + } - default Map getIdByName(String instanceId, String name) { - if (name == null) { - return new HashMap<>(); - } - return getIdsByNames(instanceId, Collections.singletonList(name)); + default Map getIdByName(String instanceId, String name) { + if (name == null) { + return new HashMap<>(); } + return getIdsByNames(instanceId, Collections.singletonList(name)); + } + default List getByNames(String instanceId, List names) { + Preconditions.checkNotNull(instanceId, "Invalid instanceId"); - default List getByNames(String instanceId, List names) { - Preconditions.checkNotNull(instanceId, "Invalid instanceId"); - - if (CollectionUtils.isEmpty(names)) { - return new ArrayList<>(); - } - - return lambdaQuery().in(E::getName, names).eq(E::getInstanceId, instanceId).list(); + if (CollectionUtils.isEmpty(names)) { + return new ArrayList<>(); } - default Map getIdsByNames(String instanceId, List names) { - if (CollectionUtils.isEmpty(names)) { - return new HashMap<>(); - } + return lambdaQuery().in(E::getName, names).eq(E::getInstanceId, instanceId).list(); + } - List entities = lambdaQuery().select(E::getId, E::getName).eq(E::getInstanceId, instanceId) - .in(E::getName, names).list(); - return ListUtil.toMap(entities, NameEntity::getName, IdEntity::getId); + default Map getIdsByNames(String instanceId, List names) { + if (CollectionUtils.isEmpty(names)) { + return new HashMap<>(); } - default List getByInstanceId(String instanceId) { - Preconditions.checkNotNull(instanceId); - return lambdaQuery().eq(E::getInstanceId, instanceId).list(); - } - - @Override - default void configBaseSearch(QueryWrapper wrapper, S search) { - NameDao.super.configBaseSearch(wrapper, search); - - String instanceId = search.getInstanceId(); - wrapper.eq(instanceId != null, INSTANCE_ID_FIELD_NAME, instanceId); - } - - @Override - default void configBaseJoinSearch(MPJLambdaWrapper wrapper, S search) { - NameDao.super.configBaseJoinSearch(wrapper, search); - - String instanceId = search.getInstanceId(); - wrapper.eq(instanceId != null, E::getInstanceId, instanceId); - } + List entities = + lambdaQuery() + .select(E::getId, E::getName) + .eq(E::getInstanceId, instanceId) + .in(E::getName, names) + .list(); + return ListUtil.toMap(entities, NameEntity::getName, IdEntity::getId); + } + + default List getByInstanceId(String instanceId) { + Preconditions.checkNotNull(instanceId); + return lambdaQuery().eq(E::getInstanceId, instanceId).list(); + } + + @Override + default void configBaseSearch(QueryWrapper wrapper, S search) { + NameDao.super.configBaseSearch(wrapper, search); + + String instanceId = search.getInstanceId(); + wrapper.eq(instanceId != null, INSTANCE_ID_FIELD_NAME, instanceId); + } + + @Override + default void configBaseJoinSearch(MPJLambdaWrapper wrapper, S search) { + NameDao.super.configBaseJoinSearch(wrapper, search); + + String instanceId = search.getInstanceId(); + wrapper.eq(instanceId != null, E::getInstanceId, instanceId); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/EdgeDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/EdgeDao.java index 57bae2d8a..12a87e995 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/EdgeDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/EdgeDao.java @@ -19,8 +19,8 @@ package org.apache.geaflow.console.common.dal.dao; -import com.github.yulichang.wrapper.MPJLambdaWrapper; import java.util.List; + import org.apache.geaflow.console.common.dal.entity.EdgeEntity; import org.apache.geaflow.console.common.dal.entity.GraphStructMappingEntity; import org.apache.geaflow.console.common.dal.mapper.EdgeMapper; @@ -29,21 +29,21 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Repository; +import com.github.yulichang.wrapper.MPJLambdaWrapper; + @Repository -public class EdgeDao extends TenantLevelDao implements DataDao { +public class EdgeDao extends TenantLevelDao + implements DataDao { - @Autowired - private GraphStructMappingMapper graphStructMappingMapper; + @Autowired private GraphStructMappingMapper graphStructMappingMapper; - public List getByGraphId(String graphId) { - MPJLambdaWrapper wrapper = new MPJLambdaWrapper().selectAll( - EdgeEntity.class).innerJoin(EdgeEntity.class, EdgeEntity::getId, - GraphStructMappingEntity::getResourceId) + public List getByGraphId(String graphId) { + MPJLambdaWrapper wrapper = + new MPJLambdaWrapper() + .selectAll(EdgeEntity.class) + .innerJoin(EdgeEntity.class, EdgeEntity::getId, GraphStructMappingEntity::getResourceId) .eq(GraphStructMappingEntity::getGraphId, graphId) .orderByAsc(GraphStructMappingEntity::getSortKey); - return graphStructMappingMapper.selectJoinList(EdgeEntity.class, wrapper); - } - - + return graphStructMappingMapper.selectJoinList(EdgeEntity.class, wrapper); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/EndpointDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/EndpointDao.java index 469702ff8..4198df6b3 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/EndpointDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/EndpointDao.java @@ -20,45 +20,49 @@ package org.apache.geaflow.console.common.dal.dao; import java.util.List; + import org.apache.geaflow.console.common.dal.entity.EndpointEntity; import org.apache.geaflow.console.common.dal.mapper.EndpointMapper; import org.apache.geaflow.console.common.dal.model.IdSearch; import org.springframework.stereotype.Repository; @Repository -public class EndpointDao extends TenantLevelDao implements - IdDao { - - public List getByGraphId(String graphId) { - return lambdaQuery().eq(EndpointEntity::getGraphId, graphId).list(); - } - - public EndpointEntity getByEndpoint(String graphId, String edgeId, String sourceId, String targetId) { - return lambdaQuery() - .eq(EndpointEntity::getGraphId, graphId) - .eq(EndpointEntity::getEdgeId, edgeId) - .eq(EndpointEntity::getSourceId, sourceId) - .eq(EndpointEntity::getTargetId, targetId).one(); - } +public class EndpointDao extends TenantLevelDao + implements IdDao { - public boolean exists(String graphId, String edgeId, String sourceId, String targetId) { - return lambdaQuery().eq(EndpointEntity::getGraphId, graphId) - .eq(EndpointEntity::getEdgeId, edgeId) - .eq(EndpointEntity::getSourceId, sourceId) - .eq(EndpointEntity::getTargetId, targetId).exists(); - } + public List getByGraphId(String graphId) { + return lambdaQuery().eq(EndpointEntity::getGraphId, graphId).list(); + } + public EndpointEntity getByEndpoint( + String graphId, String edgeId, String sourceId, String targetId) { + return lambdaQuery() + .eq(EndpointEntity::getGraphId, graphId) + .eq(EndpointEntity::getEdgeId, edgeId) + .eq(EndpointEntity::getSourceId, sourceId) + .eq(EndpointEntity::getTargetId, targetId) + .one(); + } - public boolean dropByGraphIds(List graphIds) { - return lambdaUpdate().eq(EndpointEntity::getGraphId, graphIds).remove(); - } + public boolean exists(String graphId, String edgeId, String sourceId, String targetId) { + return lambdaQuery() + .eq(EndpointEntity::getGraphId, graphId) + .eq(EndpointEntity::getEdgeId, edgeId) + .eq(EndpointEntity::getSourceId, sourceId) + .eq(EndpointEntity::getTargetId, targetId) + .exists(); + } - public boolean dropByEndpoint(String graphId, String edgeId, String sourceId, String targetId) { - return lambdaUpdate() - .eq(EndpointEntity::getGraphId, graphId) - .eq(EndpointEntity::getEdgeId, edgeId) - .eq(EndpointEntity::getSourceId, sourceId) - .eq(EndpointEntity::getTargetId, targetId).remove(); - } + public boolean dropByGraphIds(List graphIds) { + return lambdaUpdate().eq(EndpointEntity::getGraphId, graphIds).remove(); + } + public boolean dropByEndpoint(String graphId, String edgeId, String sourceId, String targetId) { + return lambdaUpdate() + .eq(EndpointEntity::getGraphId, graphId) + .eq(EndpointEntity::getEdgeId, edgeId) + .eq(EndpointEntity::getSourceId, sourceId) + .eq(EndpointEntity::getTargetId, targetId) + .remove(); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/FieldDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/FieldDao.java index 31b70f1c3..cf093d1aa 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/FieldDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/FieldDao.java @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; + import org.apache.geaflow.console.common.dal.entity.FieldEntity; import org.apache.geaflow.console.common.dal.mapper.FieldMapper; import org.apache.geaflow.console.common.dal.model.FieldSearch; @@ -31,30 +32,33 @@ import org.springframework.util.CollectionUtils; @Repository -public class FieldDao extends TenantLevelDao implements NameDao { - +public class FieldDao extends TenantLevelDao + implements NameDao { - public boolean removeByResources(List resourceIds, GeaflowResourceType resourceType) { - if (CollectionUtils.isEmpty(resourceIds)) { - return true; - } - return lambdaUpdate().in(FieldEntity::getResourceId, resourceIds).eq(FieldEntity::getResourceType, resourceType).remove(); + public boolean removeByResources(List resourceIds, GeaflowResourceType resourceType) { + if (CollectionUtils.isEmpty(resourceIds)) { + return true; } + return lambdaUpdate() + .in(FieldEntity::getResourceId, resourceIds) + .eq(FieldEntity::getResourceType, resourceType) + .remove(); + } - - public List getByResources(List resourceIds, GeaflowResourceType resourceType) { - if (CollectionUtils.isEmpty(resourceIds)) { - return new ArrayList<>(); - } - - return lambdaQuery().in(FieldEntity::getResourceId, resourceIds) - .eq(FieldEntity::getResourceType, resourceType.name()) - .orderByAsc(Arrays.asList(FieldEntity::getResourceId, FieldEntity::getSortKey)) - .list(); + public List getByResources( + List resourceIds, GeaflowResourceType resourceType) { + if (CollectionUtils.isEmpty(resourceIds)) { + return new ArrayList<>(); } + return lambdaQuery() + .in(FieldEntity::getResourceId, resourceIds) + .eq(FieldEntity::getResourceType, resourceType.name()) + .orderByAsc(Arrays.asList(FieldEntity::getResourceId, FieldEntity::getSortKey)) + .list(); + } - public List getByResource(String resourceId, GeaflowResourceType resourceType) { - return getByResources(Collections.singletonList(resourceId), resourceType); - } + public List getByResource(String resourceId, GeaflowResourceType resourceType) { + return getByResources(Collections.singletonList(resourceId), resourceType); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/FunctionDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/FunctionDao.java index b0cb915b6..552465344 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/FunctionDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/FunctionDao.java @@ -26,11 +26,13 @@ import org.springframework.stereotype.Repository; @Repository -public class FunctionDao extends TenantLevelDao implements DataDao { +public class FunctionDao extends TenantLevelDao + implements DataDao { - public long getFileRefCount(String fileId, String excludeFunctionId) { - return lambdaQuery().eq(FunctionEntity::getJarPackageId, fileId) - .ne(excludeFunctionId != null, IdEntity::getId, excludeFunctionId) - .count(); - } + public long getFileRefCount(String fileId, String excludeFunctionId) { + return lambdaQuery() + .eq(FunctionEntity::getJarPackageId, fileId) + .ne(excludeFunctionId != null, IdEntity::getId, excludeFunctionId) + .count(); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/GeaflowBaseDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/GeaflowBaseDao.java index f4e82d58d..b11c67244 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/GeaflowBaseDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/GeaflowBaseDao.java @@ -19,6 +19,12 @@ package org.apache.geaflow.console.common.dal.dao; +import org.apache.geaflow.console.common.dal.entity.IdEntity; +import org.apache.geaflow.console.common.dal.mapper.GeaflowBaseMapper; +import org.apache.geaflow.console.common.dal.wrapper.GeaflowLambdaQueryChainWrapper; +import org.apache.geaflow.console.common.dal.wrapper.GeaflowLambdaUpdateChainWrapper; +import org.springframework.beans.factory.InitializingBean; + import com.baomidou.mybatisplus.annotation.TableName; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper; @@ -26,44 +32,35 @@ import com.baomidou.mybatisplus.extension.conditions.query.LambdaQueryChainWrapper; import com.baomidou.mybatisplus.extension.conditions.update.LambdaUpdateChainWrapper; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; -import org.apache.geaflow.console.common.dal.entity.IdEntity; -import org.apache.geaflow.console.common.dal.mapper.GeaflowBaseMapper; -import org.apache.geaflow.console.common.dal.wrapper.GeaflowLambdaQueryChainWrapper; -import org.apache.geaflow.console.common.dal.wrapper.GeaflowLambdaUpdateChainWrapper; -import org.springframework.beans.factory.InitializingBean; - -public abstract class GeaflowBaseDao, E extends IdEntity> extends - ServiceImpl implements InitializingBean { - - protected boolean ignoreTenant; - protected String tableName; +public abstract class GeaflowBaseDao, E extends IdEntity> + extends ServiceImpl implements InitializingBean { - @Override - public void afterPropertiesSet() throws Exception { - this.ignoreTenant = InterceptorIgnoreHelper.willIgnoreTenantLine(mapperClass.getName() + ".*"); - this.tableName = entityClass.getAnnotation(TableName.class).value(); - } + protected boolean ignoreTenant; - @Override - public LambdaQueryChainWrapper lambdaQuery() { - QueryWrapper wrapper = new QueryWrapper<>(); - configQueryWrapper(wrapper); - return new GeaflowLambdaQueryChainWrapper<>(getBaseMapper(), wrapper.lambda()); - } + protected String tableName; - @Override - public LambdaUpdateChainWrapper lambdaUpdate() { - UpdateWrapper wrapper = new UpdateWrapper<>(); - configUpdateWrapper(wrapper); - return new GeaflowLambdaUpdateChainWrapper<>(getBaseMapper(), wrapper.lambda()); - } + @Override + public void afterPropertiesSet() throws Exception { + this.ignoreTenant = InterceptorIgnoreHelper.willIgnoreTenantLine(mapperClass.getName() + ".*"); + this.tableName = entityClass.getAnnotation(TableName.class).value(); + } - public void configQueryWrapper(QueryWrapper wrapper) { + @Override + public LambdaQueryChainWrapper lambdaQuery() { + QueryWrapper wrapper = new QueryWrapper<>(); + configQueryWrapper(wrapper); + return new GeaflowLambdaQueryChainWrapper<>(getBaseMapper(), wrapper.lambda()); + } - } + @Override + public LambdaUpdateChainWrapper lambdaUpdate() { + UpdateWrapper wrapper = new UpdateWrapper<>(); + configUpdateWrapper(wrapper); + return new GeaflowLambdaUpdateChainWrapper<>(getBaseMapper(), wrapper.lambda()); + } - public void configUpdateWrapper(UpdateWrapper wrapper) { + public void configQueryWrapper(QueryWrapper wrapper) {} - } + public void configUpdateWrapper(UpdateWrapper wrapper) {} } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/GraphDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/GraphDao.java index b0e72742c..431c6cb16 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/GraphDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/GraphDao.java @@ -25,6 +25,5 @@ import org.springframework.stereotype.Repository; @Repository -public class GraphDao extends TenantLevelDao implements DataDao { - -} +public class GraphDao extends TenantLevelDao + implements DataDao {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/GraphStructMappingDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/GraphStructMappingDao.java index 2ec1e0467..2c5e403ee 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/GraphStructMappingDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/GraphStructMappingDao.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.common.dal.dao; import java.util.List; + import org.apache.commons.collections.CollectionUtils; import org.apache.geaflow.console.common.dal.entity.GraphStructMappingEntity; import org.apache.geaflow.console.common.dal.mapper.GraphStructMappingMapper; @@ -28,34 +29,41 @@ import org.springframework.stereotype.Repository; @Repository -public class GraphStructMappingDao extends TenantLevelDao implements - IdDao { - - public void removeGraphStructs(String graphId, List vertexIds, List edgeIds) { - if (!CollectionUtils.isEmpty(vertexIds)) { - lambdaUpdate().in(GraphStructMappingEntity::getResourceId, vertexIds) - .eq(GraphStructMappingEntity::getResourceType, GeaflowResourceType.VERTEX) - .eq(GraphStructMappingEntity::getGraphId, graphId).remove(); - } - - if (!CollectionUtils.isEmpty(edgeIds)) { - lambdaUpdate().in(GraphStructMappingEntity::getResourceId, edgeIds) - .eq(GraphStructMappingEntity::getResourceType, GeaflowResourceType.EDGE) - .eq(GraphStructMappingEntity::getGraphId, graphId).remove(); - } - } +public class GraphStructMappingDao + extends TenantLevelDao + implements IdDao { - public boolean removeByGraphIds(List graphIds) { - if (CollectionUtils.isEmpty(graphIds)) { - return true; - } + public void removeGraphStructs(String graphId, List vertexIds, List edgeIds) { + if (!CollectionUtils.isEmpty(vertexIds)) { + lambdaUpdate() + .in(GraphStructMappingEntity::getResourceId, vertexIds) + .eq(GraphStructMappingEntity::getResourceType, GeaflowResourceType.VERTEX) + .eq(GraphStructMappingEntity::getGraphId, graphId) + .remove(); + } - return lambdaUpdate().in(GraphStructMappingEntity::getGraphId, graphIds).remove(); + if (!CollectionUtils.isEmpty(edgeIds)) { + lambdaUpdate() + .in(GraphStructMappingEntity::getResourceId, edgeIds) + .eq(GraphStructMappingEntity::getResourceType, GeaflowResourceType.EDGE) + .eq(GraphStructMappingEntity::getGraphId, graphId) + .remove(); } + } - public List getByResourceId(String resourceId, GeaflowResourceType resourceType) { - return lambdaQuery().eq(GraphStructMappingEntity::getResourceId, resourceId) - .eq(GraphStructMappingEntity::getResourceType, resourceType) - .list(); + public boolean removeByGraphIds(List graphIds) { + if (CollectionUtils.isEmpty(graphIds)) { + return true; } + + return lambdaUpdate().in(GraphStructMappingEntity::getGraphId, graphIds).remove(); + } + + public List getByResourceId( + String resourceId, GeaflowResourceType resourceType) { + return lambdaQuery() + .eq(GraphStructMappingEntity::getResourceId, resourceId) + .eq(GraphStructMappingEntity::getResourceType, resourceType) + .list(); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/IdDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/IdDao.java index 7b97e2066..e3f1b4e8e 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/IdDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/IdDao.java @@ -19,22 +19,13 @@ package org.apache.geaflow.console.common.dal.dao; -import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; -import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; -import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper; -import com.baomidou.mybatisplus.core.metadata.OrderItem; -import com.baomidou.mybatisplus.extension.conditions.query.LambdaQueryChainWrapper; -import com.baomidou.mybatisplus.extension.plugins.pagination.Page; -import com.baomidou.mybatisplus.extension.service.IService; -import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; -import com.github.yulichang.base.mapper.MPJJoinMapper; -import com.github.yulichang.wrapper.MPJLambdaWrapper; import java.util.ArrayList; import java.util.Collections; import java.util.Date; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.dal.IdGenerator; import org.apache.geaflow.console.common.dal.entity.IdEntity; @@ -46,207 +37,218 @@ import org.apache.geaflow.console.common.util.context.ContextHolder; import org.springframework.util.CollectionUtils; -public interface IdDao extends IService { - - String TENANT_ID_FIELD_NAME = "tenant_id"; - - String CREATE_TIME_FIELD_NAME = "gmt_create"; - - String MODIFY_TIME_FIELD_NAME = "gmt_modified"; - - String CREATOR_FIELD_NAME = "creator_id"; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; +import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; +import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper; +import com.baomidou.mybatisplus.core.metadata.OrderItem; +import com.baomidou.mybatisplus.extension.conditions.query.LambdaQueryChainWrapper; +import com.baomidou.mybatisplus.extension.plugins.pagination.Page; +import com.baomidou.mybatisplus.extension.service.IService; +import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; +import com.github.yulichang.base.mapper.MPJJoinMapper; +import com.github.yulichang.wrapper.MPJLambdaWrapper; - String MODIFIER_FIELD_NAME = "modifier_id"; +public interface IdDao extends IService { - default boolean exist(String id) { - if (id == null) { - return false; - } + String TENANT_ID_FIELD_NAME = "tenant_id"; - return lambdaQuery().eq(E::getId, id).exists(); - } + String CREATE_TIME_FIELD_NAME = "gmt_create"; - default E get(String id) { - if (id == null) { - return null; - } - List entities = get(Collections.singletonList(id)); - return entities.isEmpty() ? null : entities.get(0); - } + String MODIFY_TIME_FIELD_NAME = "gmt_modified"; - default String create(E entity) { - if (entity == null) { - return null; - } - List ids = create(Collections.singletonList(entity)); - return ids.isEmpty() ? null : ids.get(0); - } + String CREATOR_FIELD_NAME = "creator_id"; - default boolean update(E entity) { - if (entity == null) { - return false; - } - return update(Collections.singletonList(entity)); - } + String MODIFIER_FIELD_NAME = "modifier_id"; - default boolean drop(String id) { - if (id == null) { - return false; - } - return drop(Collections.singletonList(id)); + default boolean exist(String id) { + if (id == null) { + return false; } - default List get(List ids) { - if (CollectionUtils.isEmpty(ids)) { - return new ArrayList<>(); - } + return lambdaQuery().eq(E::getId, id).exists(); + } - return listByIds(ids); + default E get(String id) { + if (id == null) { + return null; } + List entities = get(Collections.singletonList(id)); + return entities.isEmpty() ? null : entities.get(0); + } - default List create(List entities) { - String tenantId = ContextHolder.get().getTenantId(); - String userId = ContextHolder.get().getUserId(); - Date date = new Date(); - - boolean systemSession = ContextHolder.get().isSystemSession(); - boolean userLevel = this instanceof UserLevelDao; - - entities.forEach(e -> { - e.setTenantId(Optional.ofNullable(e.getTenantId()).orElse(tenantId)); - e.setId(IdGenerator.nextId()); - e.setGmtCreate(Optional.ofNullable(e.getGmtCreate()).orElse(date)); - e.setGmtModified(Optional.ofNullable(e.getGmtModified()).orElse(date)); - e.setCreatorId(Optional.ofNullable(e.getCreatorId()).orElse(userId)); - e.setModifierId(Optional.ofNullable(e.getModifierId()).orElse(userId)); - if (userLevel) { - ((UserLevelEntity) e).setSystem(systemSession); - } - }); - saveBatch(entities); - return entities.stream().map(IdEntity::getId).collect(Collectors.toList()); + default String create(E entity) { + if (entity == null) { + return null; } + List ids = create(Collections.singletonList(entity)); + return ids.isEmpty() ? null : ids.get(0); + } - default boolean update(List entities) { - if (CollectionUtils.isEmpty(entities)) { - return false; - } - - String userId = ContextHolder.get().getUserId(); - Date date = new Date(); - for (E e : entities) { - e.setGmtModified(Optional.ofNullable(e.getGmtModified()).orElse(date)); - e.setModifierId(Optional.ofNullable(e.getModifierId()).orElse(userId)); - } - - return updateBatchById(entities); + default boolean update(E entity) { + if (entity == null) { + return false; } + return update(Collections.singletonList(entity)); + } - default boolean drop(List ids) { - if (CollectionUtils.isEmpty(ids)) { - return true; - } - - return lambdaUpdate().in(E::getId, ids).remove(); + default boolean drop(String id) { + if (id == null) { + return false; } + return drop(Collections.singletonList(id)); + } - default PageList search(S search) { - QueryWrapper queryWrapper = new QueryWrapper<>(); - - // config general query condition - configQueryWrapper(queryWrapper); - - // config general search condition - configBaseSearch(queryWrapper, search); - - // config concrete search condition - LambdaQueryWrapper lambdaQueryWrapper = queryWrapper.lambda(); - configSearch(lambdaQueryWrapper, search); - - LambdaQueryChainWrapper wrapper = new GeaflowLambdaQueryChainWrapper<>(getBaseMapper(), lambdaQueryWrapper); - Page page = buildPage(search); - if (page != null) { - return new PageList<>(wrapper.page(page)); - - } else { - return new PageList<>(wrapper.list()); - } + default List get(List ids) { + if (CollectionUtils.isEmpty(ids)) { + return new ArrayList<>(); } - default PageList search(MPJLambdaWrapper wrapper, S search) { - MPJJoinMapper joinMapper = ((MPJJoinMapper) getBaseMapper()); - configBaseJoinSearch(wrapper, search); - configJoinSearch(wrapper, search); + return listByIds(ids); + } - Class entityClass = ((ServiceImpl) this).getEntityClass(); - Page page = buildPage(search); - if (page != null) { - return new PageList<>(joinMapper.selectJoinPage(page, entityClass, wrapper)); + default List create(List entities) { + String tenantId = ContextHolder.get().getTenantId(); + String userId = ContextHolder.get().getUserId(); + Date date = new Date(); - } else { - return new PageList<>(joinMapper.selectJoinList(entityClass, wrapper)); - } - } - - static Page buildPage(S search) { - Page page = null; - if (search.getSize() != null) { - page = new Page<>(Optional.ofNullable(search.getPage()).orElse(1), search.getSize()); - - String sort = search.getSort(); - if (StringUtils.isNotBlank(sort)) { - page.addOrder(SortOrder.ASC.equals(search.getOrder()) ? OrderItem.asc(sort) : OrderItem.desc(sort)); - } else { - page.addOrder(OrderItem.desc(MODIFY_TIME_FIELD_NAME)); - } - } - return page; - } + boolean systemSession = ContextHolder.get().isSystemSession(); + boolean userLevel = this instanceof UserLevelDao; - default void configBaseSearch(QueryWrapper wrapper, S search) { - Date startCreateTime = search.getStartCreateTime(); - Date endCreateTime = search.getEndCreateTime(); - Date startModifyTime = search.getStartModifyTime(); - Date endModifyTime = search.getEndModifyTime(); - String creatorId = search.getCreatorId(); - String modifierId = search.getModifierId(); - - wrapper.ge(startCreateTime != null, CREATE_TIME_FIELD_NAME, startCreateTime); - wrapper.le(endCreateTime != null, CREATE_TIME_FIELD_NAME, endCreateTime); - wrapper.ge(startModifyTime != null, MODIFY_TIME_FIELD_NAME, startModifyTime); - wrapper.le(endModifyTime != null, MODIFY_TIME_FIELD_NAME, endModifyTime); - wrapper.eq(StringUtils.isNotBlank(creatorId), CREATOR_FIELD_NAME, creatorId); - wrapper.eq(StringUtils.isNotBlank(modifierId), MODIFIER_FIELD_NAME, modifierId); - wrapper.orderBy(search.getOrder() != null && StringUtils.isNotBlank(search.getSort()), - search.getOrder() == SortOrder.ASC, search.getSort()); + entities.forEach( + e -> { + e.setTenantId(Optional.ofNullable(e.getTenantId()).orElse(tenantId)); + e.setId(IdGenerator.nextId()); + e.setGmtCreate(Optional.ofNullable(e.getGmtCreate()).orElse(date)); + e.setGmtModified(Optional.ofNullable(e.getGmtModified()).orElse(date)); + e.setCreatorId(Optional.ofNullable(e.getCreatorId()).orElse(userId)); + e.setModifierId(Optional.ofNullable(e.getModifierId()).orElse(userId)); + if (userLevel) { + ((UserLevelEntity) e).setSystem(systemSession); + } + }); + saveBatch(entities); + return entities.stream().map(IdEntity::getId).collect(Collectors.toList()); + } + default boolean update(List entities) { + if (CollectionUtils.isEmpty(entities)) { + return false; } - default void configSearch(LambdaQueryWrapper wrapper, S search) { - + String userId = ContextHolder.get().getUserId(); + Date date = new Date(); + for (E e : entities) { + e.setGmtModified(Optional.ofNullable(e.getGmtModified()).orElse(date)); + e.setModifierId(Optional.ofNullable(e.getModifierId()).orElse(userId)); } - void configQueryWrapper(QueryWrapper wrapper); - - void configUpdateWrapper(UpdateWrapper wrapper); + return updateBatchById(entities); + } - default void configBaseJoinSearch(MPJLambdaWrapper wrapper, S search) { - Date startCreateTime = search.getStartCreateTime(); - Date endCreateTime = search.getEndCreateTime(); - Date startModifyTime = search.getStartModifyTime(); - Date endModifyTime = search.getEndModifyTime(); - String creatorId = search.getCreatorId(); - String modifierId = search.getModifierId(); + default boolean drop(List ids) { + if (CollectionUtils.isEmpty(ids)) { + return true; + } + + return lambdaUpdate().in(E::getId, ids).remove(); + } + + default PageList search(S search) { + QueryWrapper queryWrapper = new QueryWrapper<>(); + + // config general query condition + configQueryWrapper(queryWrapper); + + // config general search condition + configBaseSearch(queryWrapper, search); + + // config concrete search condition + LambdaQueryWrapper lambdaQueryWrapper = queryWrapper.lambda(); + configSearch(lambdaQueryWrapper, search); + + LambdaQueryChainWrapper wrapper = + new GeaflowLambdaQueryChainWrapper<>(getBaseMapper(), lambdaQueryWrapper); + Page page = buildPage(search); + if (page != null) { + return new PageList<>(wrapper.page(page)); + + } else { + return new PageList<>(wrapper.list()); + } + } + + default PageList search(MPJLambdaWrapper wrapper, S search) { + MPJJoinMapper joinMapper = ((MPJJoinMapper) getBaseMapper()); + configBaseJoinSearch(wrapper, search); + configJoinSearch(wrapper, search); + + Class entityClass = ((ServiceImpl) this).getEntityClass(); + Page page = buildPage(search); + if (page != null) { + return new PageList<>(joinMapper.selectJoinPage(page, entityClass, wrapper)); + + } else { + return new PageList<>(joinMapper.selectJoinList(entityClass, wrapper)); + } + } + + static Page buildPage(S search) { + Page page = null; + if (search.getSize() != null) { + page = new Page<>(Optional.ofNullable(search.getPage()).orElse(1), search.getSize()); + + String sort = search.getSort(); + if (StringUtils.isNotBlank(sort)) { + page.addOrder( + SortOrder.ASC.equals(search.getOrder()) ? OrderItem.asc(sort) : OrderItem.desc(sort)); + } else { + page.addOrder(OrderItem.desc(MODIFY_TIME_FIELD_NAME)); + } + } + return page; + } + + default void configBaseSearch(QueryWrapper wrapper, S search) { + Date startCreateTime = search.getStartCreateTime(); + Date endCreateTime = search.getEndCreateTime(); + Date startModifyTime = search.getStartModifyTime(); + Date endModifyTime = search.getEndModifyTime(); + String creatorId = search.getCreatorId(); + String modifierId = search.getModifierId(); + + wrapper.ge(startCreateTime != null, CREATE_TIME_FIELD_NAME, startCreateTime); + wrapper.le(endCreateTime != null, CREATE_TIME_FIELD_NAME, endCreateTime); + wrapper.ge(startModifyTime != null, MODIFY_TIME_FIELD_NAME, startModifyTime); + wrapper.le(endModifyTime != null, MODIFY_TIME_FIELD_NAME, endModifyTime); + wrapper.eq(StringUtils.isNotBlank(creatorId), CREATOR_FIELD_NAME, creatorId); + wrapper.eq(StringUtils.isNotBlank(modifierId), MODIFIER_FIELD_NAME, modifierId); + wrapper.orderBy( + search.getOrder() != null && StringUtils.isNotBlank(search.getSort()), + search.getOrder() == SortOrder.ASC, + search.getSort()); + } + + default void configSearch(LambdaQueryWrapper wrapper, S search) {} - wrapper.ge(startCreateTime != null, E::getGmtCreate, startCreateTime); - wrapper.le(endCreateTime != null, E::getGmtCreate, endCreateTime); - wrapper.ge(startModifyTime != null, E::getGmtModified, startModifyTime); - wrapper.le(endModifyTime != null, E::getGmtModified, endModifyTime); - wrapper.eq(StringUtils.isNotBlank(creatorId), E::getCreatorId, creatorId); - wrapper.eq(StringUtils.isNotBlank(modifierId), E::getModifierId, modifierId); - } - - default void configJoinSearch(MPJLambdaWrapper wrapper, S search) { + void configQueryWrapper(QueryWrapper wrapper); + + void configUpdateWrapper(UpdateWrapper wrapper); - } + default void configBaseJoinSearch(MPJLambdaWrapper wrapper, S search) { + Date startCreateTime = search.getStartCreateTime(); + Date endCreateTime = search.getEndCreateTime(); + Date startModifyTime = search.getStartModifyTime(); + Date endModifyTime = search.getEndModifyTime(); + String creatorId = search.getCreatorId(); + String modifierId = search.getModifierId(); + + wrapper.ge(startCreateTime != null, E::getGmtCreate, startCreateTime); + wrapper.le(endCreateTime != null, E::getGmtCreate, endCreateTime); + wrapper.ge(startModifyTime != null, E::getGmtModified, startModifyTime); + wrapper.le(endModifyTime != null, E::getGmtModified, endModifyTime); + wrapper.eq(StringUtils.isNotBlank(creatorId), E::getCreatorId, creatorId); + wrapper.eq(StringUtils.isNotBlank(modifierId), E::getModifierId, modifierId); + } + + default void configJoinSearch(MPJLambdaWrapper wrapper, S search) {} } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/InstanceDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/InstanceDao.java index 46d9960aa..33bd3e49f 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/InstanceDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/InstanceDao.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.common.dal.dao; import java.util.List; + import org.apache.commons.collections.CollectionUtils; import org.apache.geaflow.console.common.dal.entity.InstanceEntity; import org.apache.geaflow.console.common.dal.entity.ResourceCount; @@ -29,18 +30,18 @@ import org.springframework.stereotype.Repository; @Repository -public class InstanceDao extends TenantLevelDao implements NameDao { - - public List search() { - return lambdaQuery().list(); - } +public class InstanceDao extends TenantLevelDao + implements NameDao { + public List search() { + return lambdaQuery().list(); + } - public List getResourceCount(String instanceId, List names) { - if (instanceId == null || CollectionUtils.isEmpty(names)) { - throw new GeaflowException("Empty instance or names"); - } - - return getBaseMapper().getResourceCount(instanceId, names); + public List getResourceCount(String instanceId, List names) { + if (instanceId == null || CollectionUtils.isEmpty(names)) { + throw new GeaflowException("Empty instance or names"); } + + return getBaseMapper().getResourceCount(instanceId, names); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/JobDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/JobDao.java index 353de060b..a869347ea 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/JobDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/JobDao.java @@ -19,25 +19,28 @@ package org.apache.geaflow.console.common.dal.dao; -import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import org.apache.geaflow.console.common.dal.entity.IdEntity; import org.apache.geaflow.console.common.dal.entity.JobEntity; import org.apache.geaflow.console.common.dal.mapper.JobMapper; import org.apache.geaflow.console.common.dal.model.JobSearch; import org.springframework.stereotype.Repository; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; + @Repository -public class JobDao extends TenantLevelDao implements NameDao { +public class JobDao extends TenantLevelDao + implements NameDao { - @Override - public void configSearch(LambdaQueryWrapper wrapper, JobSearch search) { - wrapper.eq(search.getJobType() != null, JobEntity::getType, search.getJobType()); - wrapper.eq(search.getInstanceId() != null, JobEntity::getInstanceId, search.getInstanceId()); - } + @Override + public void configSearch(LambdaQueryWrapper wrapper, JobSearch search) { + wrapper.eq(search.getJobType() != null, JobEntity::getType, search.getJobType()); + wrapper.eq(search.getInstanceId() != null, JobEntity::getInstanceId, search.getInstanceId()); + } - public long getFileRefCount(String fileId, String excludeJobId) { - return lambdaQuery().eq(JobEntity::getJarPackageId, fileId) - .ne(excludeJobId != null, IdEntity::getId, excludeJobId) - .count(); - } + public long getFileRefCount(String fileId, String excludeJobId) { + return lambdaQuery() + .eq(JobEntity::getJarPackageId, fileId) + .ne(excludeJobId != null, IdEntity::getId, excludeJobId) + .count(); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/JobResourceMappingDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/JobResourceMappingDao.java index 1c4763348..a7528ffd2 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/JobResourceMappingDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/JobResourceMappingDao.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.commons.collections.CollectionUtils; import org.apache.geaflow.console.common.dal.entity.JobResourceMappingEntity; import org.apache.geaflow.console.common.dal.mapper.JobResourceMappingMapper; @@ -29,44 +30,51 @@ import org.springframework.stereotype.Repository; @Repository -public class JobResourceMappingDao extends TenantLevelDao implements - IdDao { - - public List getResourcesByJobId(String jobId, GeaflowResourceType resourceType) { - if (jobId == null) { - return new ArrayList<>(); - } +public class JobResourceMappingDao + extends TenantLevelDao + implements IdDao { - return lambdaQuery().eq(JobResourceMappingEntity::getJobId, jobId) - .eq(JobResourceMappingEntity::getResourceType, resourceType).list(); + public List getResourcesByJobId( + String jobId, GeaflowResourceType resourceType) { + if (jobId == null) { + return new ArrayList<>(); } - public List getJobByResources(String resourceName, String instanceId, GeaflowResourceType resourceType) { - if (resourceName == null || instanceId == null) { - return new ArrayList<>(); - } + return lambdaQuery() + .eq(JobResourceMappingEntity::getJobId, jobId) + .eq(JobResourceMappingEntity::getResourceType, resourceType) + .list(); + } - return lambdaQuery() - .eq(JobResourceMappingEntity::getResourceName, resourceName) - .eq(JobResourceMappingEntity::getInstanceId, instanceId) - .eq(JobResourceMappingEntity::getResourceType, resourceType).list(); + public List getJobByResources( + String resourceName, String instanceId, GeaflowResourceType resourceType) { + if (resourceName == null || instanceId == null) { + return new ArrayList<>(); } - public boolean dropByJobIds(List jobIds) { - if (CollectionUtils.isEmpty(jobIds)) { - return true; - } - - return lambdaUpdate().in(JobResourceMappingEntity::getJobId, jobIds).remove(); + return lambdaQuery() + .eq(JobResourceMappingEntity::getResourceName, resourceName) + .eq(JobResourceMappingEntity::getInstanceId, instanceId) + .eq(JobResourceMappingEntity::getResourceType, resourceType) + .list(); + } + public boolean dropByJobIds(List jobIds) { + if (CollectionUtils.isEmpty(jobIds)) { + return true; } - public void removeJobResources(List entities) { - for (JobResourceMappingEntity entity : entities) { - lambdaUpdate().eq(JobResourceMappingEntity::getResourceName, entity.getResourceName()) - .eq(JobResourceMappingEntity::getInstanceId, entity.getInstanceId()) - .eq(JobResourceMappingEntity::getResourceType, entity.getResourceType()) - .eq(JobResourceMappingEntity::getJobId, entity.getJobId()).remove(); - } + return lambdaUpdate().in(JobResourceMappingEntity::getJobId, jobIds).remove(); + } + + public void removeJobResources(List entities) { + for (JobResourceMappingEntity entity : entities) { + lambdaUpdate() + .eq(JobResourceMappingEntity::getResourceName, entity.getResourceName()) + .eq(JobResourceMappingEntity::getInstanceId, entity.getInstanceId()) + .eq(JobResourceMappingEntity::getResourceType, entity.getResourceType()) + .eq(JobResourceMappingEntity::getJobId, entity.getJobId()) + .remove(); } + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/LLMDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/LLMDao.java index b184b98a1..b4b01655d 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/LLMDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/LLMDao.java @@ -25,7 +25,5 @@ import org.springframework.stereotype.Repository; @Repository -public class LLMDao extends SystemLevelDao implements NameDao { - - -} +public class LLMDao extends SystemLevelDao + implements NameDao {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/NameDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/NameDao.java index f3fa3bdab..f5649d907 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/NameDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/NameDao.java @@ -19,13 +19,12 @@ package org.apache.geaflow.console.common.dal.dao; -import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; -import com.github.yulichang.wrapper.MPJLambdaWrapper; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.dal.entity.IdEntity; import org.apache.geaflow.console.common.dal.entity.NameEntity; @@ -33,72 +32,74 @@ import org.apache.geaflow.console.common.util.ListUtil; import org.springframework.util.CollectionUtils; -public interface NameDao extends IdDao { +import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; +import com.github.yulichang.wrapper.MPJLambdaWrapper; - String NAME_FILE_NAME = "name"; +public interface NameDao extends IdDao { - String COMMENT_FILE_NAME = "comment"; + String NAME_FILE_NAME = "name"; - default boolean existName(String name) { - if (name == null) { - return false; - } + String COMMENT_FILE_NAME = "comment"; - return lambdaQuery().eq(E::getName, name).exists(); + default boolean existName(String name) { + if (name == null) { + return false; } - default E getByName(String name) { - if (name == null) { - return null; - } - List entities = getByNames(Collections.singletonList(name)); - return entities.isEmpty() ? null : entities.get(0); - } + return lambdaQuery().eq(E::getName, name).exists(); + } - default String getIdByName(String name) { - if (name == null) { - return null; - } - return getIdsByNames(Collections.singletonList(name)).get(name); + default E getByName(String name) { + if (name == null) { + return null; } + List entities = getByNames(Collections.singletonList(name)); + return entities.isEmpty() ? null : entities.get(0); + } - default List getByNames(List names) { - if (CollectionUtils.isEmpty(names)) { - return new ArrayList<>(); - } - - return lambdaQuery().in(E::getName, names).list(); + default String getIdByName(String name) { + if (name == null) { + return null; } + return getIdsByNames(Collections.singletonList(name)).get(name); + } - default Map getIdsByNames(List names) { - if (CollectionUtils.isEmpty(names)) { - return new HashMap<>(); - } + default List getByNames(List names) { + if (CollectionUtils.isEmpty(names)) { + return new ArrayList<>(); + } - List entities = lambdaQuery().select(E::getId, E::getName).in(E::getName, names).list(); - return ListUtil.toMap(entities, NameEntity::getName, IdEntity::getId); + return lambdaQuery().in(E::getName, names).list(); + } + default Map getIdsByNames(List names) { + if (CollectionUtils.isEmpty(names)) { + return new HashMap<>(); } - @Override - default void configBaseSearch(QueryWrapper wrapper, S search) { - IdDao.super.configBaseSearch(wrapper, search); + List entities = lambdaQuery().select(E::getId, E::getName).in(E::getName, names).list(); + return ListUtil.toMap(entities, NameEntity::getName, IdEntity::getId); + } - String name = search.getName(); - String comment = search.getComment(); + @Override + default void configBaseSearch(QueryWrapper wrapper, S search) { + IdDao.super.configBaseSearch(wrapper, search); - wrapper.like(StringUtils.isNotBlank(name), NAME_FILE_NAME, name); - wrapper.like(StringUtils.isNotBlank(comment), COMMENT_FILE_NAME, comment); - } + String name = search.getName(); + String comment = search.getComment(); - @Override - default void configBaseJoinSearch(MPJLambdaWrapper wrapper, S search) { - IdDao.super.configBaseJoinSearch(wrapper, search); + wrapper.like(StringUtils.isNotBlank(name), NAME_FILE_NAME, name); + wrapper.like(StringUtils.isNotBlank(comment), COMMENT_FILE_NAME, comment); + } - String name = search.getName(); - String comment = search.getComment(); + @Override + default void configBaseJoinSearch(MPJLambdaWrapper wrapper, S search) { + IdDao.super.configBaseJoinSearch(wrapper, search); - wrapper.like(StringUtils.isNotBlank(name), E::getName, name); - wrapper.like(StringUtils.isNotBlank(comment), E::getComment, comment); - } + String name = search.getName(); + String comment = search.getComment(); + + wrapper.like(StringUtils.isNotBlank(name), E::getName, name); + wrapper.like(StringUtils.isNotBlank(comment), E::getComment, comment); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/PluginConfigDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/PluginConfigDao.java index c401fd84a..386820f6d 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/PluginConfigDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/PluginConfigDao.java @@ -19,8 +19,8 @@ package org.apache.geaflow.console.common.dal.dao; -import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import java.util.List; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.dal.entity.PluginConfigEntity; import org.apache.geaflow.console.common.dal.mapper.PluginConfigMapper; @@ -29,20 +29,26 @@ import org.apache.geaflow.console.common.util.type.GeaflowPluginCategory; import org.springframework.stereotype.Repository; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; + @Repository -public class PluginConfigDao extends UserLevelDao implements - NameDao { +public class PluginConfigDao extends UserLevelDao + implements NameDao { - public List getPluginConfigs(GeaflowPluginCategory category, String type) { - return lambdaQuery().eq(category != null, PluginConfigEntity::getCategory, category) - .eq(StringUtils.isNotEmpty(type), PluginConfigEntity::getType, type).list(); - } + public List getPluginConfigs(GeaflowPluginCategory category, String type) { + return lambdaQuery() + .eq(category != null, PluginConfigEntity::getCategory, category) + .eq(StringUtils.isNotEmpty(type), PluginConfigEntity::getType, type) + .list(); + } - @Override - public void configSearch(LambdaQueryWrapper wrapper, PluginConfigSearch search) { - boolean systemSession = ContextHolder.get().isSystemSession(); - wrapper.eq(search.getType() != null, PluginConfigEntity::getType, search.getType()) - .eq(search.getCategory() != null, PluginConfigEntity::getCategory, search.getCategory()) - .eq(PluginConfigEntity::isSystem, systemSession); - } + @Override + public void configSearch( + LambdaQueryWrapper wrapper, PluginConfigSearch search) { + boolean systemSession = ContextHolder.get().isSystemSession(); + wrapper + .eq(search.getType() != null, PluginConfigEntity::getType, search.getType()) + .eq(search.getCategory() != null, PluginConfigEntity::getCategory, search.getCategory()) + .eq(PluginConfigEntity::isSystem, systemSession); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/PluginDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/PluginDao.java index 3e9d3b618..0766b697e 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/PluginDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/PluginDao.java @@ -19,8 +19,8 @@ package org.apache.geaflow.console.common.dal.dao; -import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import java.util.List; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.dal.entity.IdEntity; import org.apache.geaflow.console.common.dal.entity.PluginEntity; @@ -30,37 +30,54 @@ import org.apache.geaflow.console.common.util.type.GeaflowPluginCategory; import org.springframework.stereotype.Repository; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; + @Repository -public class PluginDao extends UserLevelDao implements NameDao { +public class PluginDao extends UserLevelDao + implements NameDao { - public List getPlugins(GeaflowPluginCategory category) { - return lambdaQuery().eq(PluginEntity::getPluginCategory, category).list(); - } + public List getPlugins(GeaflowPluginCategory category) { + return lambdaQuery().eq(PluginEntity::getPluginCategory, category).list(); + } - public PluginEntity getPlugin(String type, GeaflowPluginCategory category) { - return lambdaQuery() - .eq(PluginEntity::getPluginType, type) - .eq(PluginEntity::getPluginCategory, category).one(); - } + public PluginEntity getPlugin(String type, GeaflowPluginCategory category) { + return lambdaQuery() + .eq(PluginEntity::getPluginType, type) + .eq(PluginEntity::getPluginCategory, category) + .one(); + } - public List getSystemPlugins(GeaflowPluginCategory category) { - return lambdaQuery().eq(PluginEntity::getPluginCategory, category).eq(PluginEntity::isSystem, true).list(); - } + public List getSystemPlugins(GeaflowPluginCategory category) { + return lambdaQuery() + .eq(PluginEntity::getPluginCategory, category) + .eq(PluginEntity::isSystem, true) + .list(); + } - @Override - public void configSearch(LambdaQueryWrapper wrapper, PluginSearch search) { - boolean systemSession = ContextHolder.get().isSystemSession(); - wrapper.eq(search.getPluginType() != null, PluginEntity::getPluginType, search.getPluginType()) - .eq(search.getPluginCategory() != null, PluginEntity::getPluginCategory, search.getPluginCategory()) - .eq(PluginEntity::isSystem, systemSession) - .and(StringUtils.isNotEmpty(search.getKeyword()), e -> e.like(PluginEntity::getName, search.getKeyword()) - .or().like(PluginEntity::getPluginCategory, search.getKeyword()) - .or().like(PluginEntity::getPluginType, search.getKeyword())); - } + @Override + public void configSearch(LambdaQueryWrapper wrapper, PluginSearch search) { + boolean systemSession = ContextHolder.get().isSystemSession(); + wrapper + .eq(search.getPluginType() != null, PluginEntity::getPluginType, search.getPluginType()) + .eq( + search.getPluginCategory() != null, + PluginEntity::getPluginCategory, + search.getPluginCategory()) + .eq(PluginEntity::isSystem, systemSession) + .and( + StringUtils.isNotEmpty(search.getKeyword()), + e -> + e.like(PluginEntity::getName, search.getKeyword()) + .or() + .like(PluginEntity::getPluginCategory, search.getKeyword()) + .or() + .like(PluginEntity::getPluginType, search.getKeyword())); + } - public long getFileRefCount(String fileId, String excludeId) { - return lambdaQuery().eq(PluginEntity::getJarPackageId, fileId) - .ne(excludeId != null, IdEntity::getId, excludeId) - .count(); - } + public long getFileRefCount(String fileId, String excludeId) { + return lambdaQuery() + .eq(PluginEntity::getJarPackageId, fileId) + .ne(excludeId != null, IdEntity::getId, excludeId) + .count(); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/ReleaseDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/ReleaseDao.java index ead81892b..108858ad2 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/ReleaseDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/ReleaseDao.java @@ -19,35 +19,40 @@ package org.apache.geaflow.console.common.dal.dao; -import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import java.util.List; + import org.apache.geaflow.console.common.dal.entity.ReleaseEntity; import org.apache.geaflow.console.common.dal.mapper.ReleaseMapper; import org.apache.geaflow.console.common.dal.model.ReleaseSearch; import org.springframework.stereotype.Repository; -@Repository -public class ReleaseDao extends TenantLevelDao implements IdDao { - - public ReleaseEntity getLatestRelease(String jobId) { - return lambdaQuery().eq(ReleaseEntity::getJobId, jobId).orderByDesc(ReleaseEntity::getVersion).last("limit 1").one(); - } - - @Override - public void configSearch(LambdaQueryWrapper wrapper, ReleaseSearch search) { - wrapper.eq(search.getVersionId() != null, ReleaseEntity::getVersionId, search.getVersionId()) - .eq(search.getClusterId() != null, ReleaseEntity::getClusterId, search.getClusterId()) - .eq(search.getJobId() != null, ReleaseEntity::getJobId, search.getJobId()); - - } - - public void dropByJobIds(List ids) { - lambdaUpdate().in(ReleaseEntity::getJobId, ids).remove(); - } - - public List getByJobIds(List ids) { - return lambdaQuery().in(ReleaseEntity::getJobId, ids).list(); - } - +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; +@Repository +public class ReleaseDao extends TenantLevelDao + implements IdDao { + + public ReleaseEntity getLatestRelease(String jobId) { + return lambdaQuery() + .eq(ReleaseEntity::getJobId, jobId) + .orderByDesc(ReleaseEntity::getVersion) + .last("limit 1") + .one(); + } + + @Override + public void configSearch(LambdaQueryWrapper wrapper, ReleaseSearch search) { + wrapper + .eq(search.getVersionId() != null, ReleaseEntity::getVersionId, search.getVersionId()) + .eq(search.getClusterId() != null, ReleaseEntity::getClusterId, search.getClusterId()) + .eq(search.getJobId() != null, ReleaseEntity::getJobId, search.getJobId()); + } + + public void dropByJobIds(List ids) { + lambdaUpdate().in(ReleaseEntity::getJobId, ids).remove(); + } + + public List getByJobIds(List ids) { + return lambdaQuery().in(ReleaseEntity::getJobId, ids).list(); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/RemoteFileDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/RemoteFileDao.java index 898693670..ab71ad1f0 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/RemoteFileDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/RemoteFileDao.java @@ -19,31 +19,33 @@ package org.apache.geaflow.console.common.dal.dao; -import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import org.apache.geaflow.console.common.dal.entity.RemoteFileEntity; import org.apache.geaflow.console.common.dal.mapper.RemoteFileMapper; import org.apache.geaflow.console.common.dal.model.RemoteFileSearch; import org.apache.geaflow.console.common.util.context.ContextHolder; import org.springframework.stereotype.Repository; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; + @Repository -public class RemoteFileDao extends UserLevelDao implements - NameDao { +public class RemoteFileDao extends UserLevelDao + implements NameDao { - public void updateMd5(String id, String md5) { - lambdaUpdate().set(RemoteFileEntity::getMd5, md5).eq(RemoteFileEntity::getId, id).update(); - } + public void updateMd5(String id, String md5) { + lambdaUpdate().set(RemoteFileEntity::getMd5, md5).eq(RemoteFileEntity::getId, id).update(); + } - public void updateUrl(String id, String url) { - lambdaUpdate().set(RemoteFileEntity::getUrl, url).eq(RemoteFileEntity::getId, id).update(); - } + public void updateUrl(String id, String url) { + lambdaUpdate().set(RemoteFileEntity::getUrl, url).eq(RemoteFileEntity::getId, id).update(); + } - @Override - public void configSearch(LambdaQueryWrapper wrapper, RemoteFileSearch search) { - boolean systemSession = ContextHolder.get().isSystemSession(); - String userId = ContextHolder.get().getUserId(); - wrapper.eq(search.getType() != null, RemoteFileEntity::getType, search.getType()) - .eq(!systemSession, RemoteFileEntity::getCreatorId, userId) - .eq(RemoteFileEntity::isSystem, systemSession); - } + @Override + public void configSearch(LambdaQueryWrapper wrapper, RemoteFileSearch search) { + boolean systemSession = ContextHolder.get().isSystemSession(); + String userId = ContextHolder.get().getUserId(); + wrapper + .eq(search.getType() != null, RemoteFileEntity::getType, search.getType()) + .eq(!systemSession, RemoteFileEntity::getCreatorId, userId) + .eq(RemoteFileEntity::isSystem, systemSession); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/StatementDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/StatementDao.java index f46b8ed28..1c53b6d01 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/StatementDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/StatementDao.java @@ -19,24 +19,27 @@ package org.apache.geaflow.console.common.dal.dao; -import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import java.util.List; + import org.apache.geaflow.console.common.dal.entity.StatementEntity; import org.apache.geaflow.console.common.dal.mapper.StatementMapper; import org.apache.geaflow.console.common.dal.model.StatementSearch; import org.springframework.stereotype.Repository; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; + @Repository public class StatementDao extends TenantLevelDao implements IdDao { - @Override - public void configSearch(LambdaQueryWrapper wrapper, StatementSearch search) { - wrapper.eq(search.getJobId() != null, StatementEntity::getJobId, search.getJobId()) - .eq(search.getStatus() != null, StatementEntity::getStatus, search.getStatus()); - } + @Override + public void configSearch(LambdaQueryWrapper wrapper, StatementSearch search) { + wrapper + .eq(search.getJobId() != null, StatementEntity::getJobId, search.getJobId()) + .eq(search.getStatus() != null, StatementEntity::getStatus, search.getStatus()); + } - public boolean dropByJobIds(List jobIds) { - return lambdaUpdate().in(StatementEntity::getJobId, jobIds).remove(); - } + public boolean dropByJobIds(List jobIds) { + return lambdaUpdate().in(StatementEntity::getJobId, jobIds).remove(); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/SystemConfigDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/SystemConfigDao.java index 0e203966f..efa656c4a 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/SystemConfigDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/SystemConfigDao.java @@ -19,60 +19,71 @@ package org.apache.geaflow.console.common.dal.dao; -import com.baomidou.mybatisplus.core.conditions.interfaces.Compare; -import com.baomidou.mybatisplus.core.conditions.interfaces.Func; -import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; -import com.baomidou.mybatisplus.core.toolkit.support.SFunction; import java.util.Optional; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.dal.entity.SystemConfigEntity; import org.apache.geaflow.console.common.dal.mapper.SystemConfigMapper; import org.apache.geaflow.console.common.dal.model.SystemConfigSearch; import org.springframework.stereotype.Repository; +import com.baomidou.mybatisplus.core.conditions.interfaces.Compare; +import com.baomidou.mybatisplus.core.conditions.interfaces.Func; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; +import com.baomidou.mybatisplus.core.toolkit.support.SFunction; + @Repository -public class SystemConfigDao extends SystemLevelDao implements - NameDao { +public class SystemConfigDao extends SystemLevelDao + implements NameDao { - @Override - public void configSearch(LambdaQueryWrapper wrapper, SystemConfigSearch search) { - String tenantId = search.getTenantId(); - String value = search.getValue(); - wrapper.eq(StringUtils.isNotBlank(tenantId), SystemConfigEntity::getTenantId, tenantId); - wrapper.like(StringUtils.isNotBlank(value), SystemConfigEntity::getValue, value); - } + @Override + public void configSearch( + LambdaQueryWrapper wrapper, SystemConfigSearch search) { + String tenantId = search.getTenantId(); + String value = search.getValue(); + wrapper.eq(StringUtils.isNotBlank(tenantId), SystemConfigEntity::getTenantId, tenantId); + wrapper.like(StringUtils.isNotBlank(value), SystemConfigEntity::getValue, value); + } - public SystemConfigEntity get(String tenantId, String key) { - return wrap(lambdaQuery(), tenantId).eq(SystemConfigEntity::getName, key).one(); - } + public SystemConfigEntity get(String tenantId, String key) { + return wrap(lambdaQuery(), tenantId).eq(SystemConfigEntity::getName, key).one(); + } - public String getValue(String tenantId, String key) { - SystemConfigEntity entity = wrap(lambdaQuery(), tenantId).select(SystemConfigEntity::getValue) - .eq(SystemConfigEntity::getName, key).one(); - return Optional.ofNullable(entity).map(SystemConfigEntity::getValue).orElse(null); - } + public String getValue(String tenantId, String key) { + SystemConfigEntity entity = + wrap(lambdaQuery(), tenantId) + .select(SystemConfigEntity::getValue) + .eq(SystemConfigEntity::getName, key) + .one(); + return Optional.ofNullable(entity).map(SystemConfigEntity::getValue).orElse(null); + } - public boolean setValue(String tenantId, String key, String value) { - return wrap(lambdaUpdate(), tenantId).set(SystemConfigEntity::getValue, value) - .eq(SystemConfigEntity::getName, key).update(); - } + public boolean setValue(String tenantId, String key, String value) { + return wrap(lambdaUpdate(), tenantId) + .set(SystemConfigEntity::getValue, value) + .eq(SystemConfigEntity::getName, key) + .update(); + } - public boolean exist(String tenantId, String key) { - return wrap(lambdaQuery(), tenantId).eq(SystemConfigEntity::getName, key).exists(); - } + public boolean exist(String tenantId, String key) { + return wrap(lambdaQuery(), tenantId).eq(SystemConfigEntity::getName, key).exists(); + } - public boolean delete(String tenantId, String key) { - return wrap(lambdaUpdate(), tenantId).eq(SystemConfigEntity::getName, key).remove(); - } + public boolean delete(String tenantId, String key) { + return wrap(lambdaUpdate(), tenantId).eq(SystemConfigEntity::getName, key).remove(); + } - private > & Func>> W wrap( - W wrapper, String tenantId) { - if (StringUtils.isNotBlank(tenantId)) { - wrapper.eq(SystemConfigEntity::getTenantId, tenantId); + private < + W extends + Compare> + & Func>> + W wrap(W wrapper, String tenantId) { + if (StringUtils.isNotBlank(tenantId)) { + wrapper.eq(SystemConfigEntity::getTenantId, tenantId); - } else { - wrapper.isNull(SystemConfigEntity::getTenantId); - } - return wrapper; + } else { + wrapper.isNull(SystemConfigEntity::getTenantId); } + return wrapper; + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/SystemLevelDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/SystemLevelDao.java index 8e5e82c13..7b6fb388a 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/SystemLevelDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/SystemLevelDao.java @@ -23,15 +23,15 @@ import org.apache.geaflow.console.common.dal.mapper.GeaflowBaseMapper; import org.apache.geaflow.console.common.util.exception.GeaflowException; -public abstract class SystemLevelDao, E extends IdEntity> extends GeaflowBaseDao { +public abstract class SystemLevelDao, E extends IdEntity> + extends GeaflowBaseDao { - @Override - public void afterPropertiesSet() throws Exception { - super.afterPropertiesSet(); - if (!ignoreTenant) { - String message = "Mapper {} must be annotated by @InterceptorIgnore(tenantLine = \"true\")"; - throw new GeaflowException(message, mapperClass.getSimpleName()); - } + @Override + public void afterPropertiesSet() throws Exception { + super.afterPropertiesSet(); + if (!ignoreTenant) { + String message = "Mapper {} must be annotated by @InterceptorIgnore(tenantLine = \"true\")"; + throw new GeaflowException(message, mapperClass.getSimpleName()); } - + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TableDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TableDao.java index a2175d6fd..758eb8abb 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TableDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TableDao.java @@ -25,6 +25,5 @@ import org.springframework.stereotype.Repository; @Repository -public class TableDao extends TenantLevelDao implements DataDao { - -} +public class TableDao extends TenantLevelDao + implements DataDao {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TaskDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TaskDao.java index 934190af6..4c141773d 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TaskDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TaskDao.java @@ -19,8 +19,8 @@ package org.apache.geaflow.console.common.dal.dao; -import com.github.yulichang.wrapper.MPJLambdaWrapper; import java.util.List; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.dal.entity.JobEntity; import org.apache.geaflow.console.common.dal.entity.JobResourceMappingEntity; @@ -34,74 +34,96 @@ import org.apache.geaflow.console.common.util.type.GeaflowTaskStatus; import org.springframework.stereotype.Repository; -@Repository -public class TaskDao extends TenantLevelDao implements IdDao { - - public List getByJobId(String jobId) { - return lambdaQuery().eq(TaskEntity::getJobId, jobId).list(); - } - - public List getIdsByJobs(List jobIds) { - return lambdaQuery().select(TaskEntity::getId).in(TaskEntity::getJobId, jobIds).list(); - } - - public boolean updateStatus(String taskId, GeaflowTaskStatus oldStatus, GeaflowTaskStatus newStatus) { - return lambdaUpdate().eq(TaskEntity::getId, taskId).eq(TaskEntity::getStatus, oldStatus) - .ne(TaskEntity::getStatus, newStatus).set(TaskEntity::getStatus, newStatus).update(); - } - - @Override - public boolean update(List entities) { - entities.forEach(e -> e.setStatus(null)); - return IdDao.super.update(entities); - } - - private MPJLambdaWrapper getJoinWrapper() { - return new MPJLambdaWrapper().selectAll(TaskEntity.class) - .innerJoin(ReleaseEntity.class, ReleaseEntity::getId, TaskEntity::getReleaseId) - .innerJoin(JobEntity.class, JobEntity::getId, TaskEntity::getJobId) - .leftJoin(JobResourceMappingEntity.class, JobResourceMappingEntity::getJobId, JobEntity::getId).distinct(); - } - - @Override - public PageList search(TaskSearch search) { - if (ContextHolder.get().isSystemSession()) { - return systemSearch(search); - } else { - return search(getJoinWrapper(), search); - } - } - - public PageList systemSearch(TaskSearch search) { - // don't select by tenantId - MPJLambdaWrapper wrapper = getJoinWrapper(); - wrapper.comment(TenantLevelExtDao.IGNORE_TENANT_SIGNATURE); - return search(wrapper, search); - } - - - @Override - public void configJoinSearch(MPJLambdaWrapper wrapper, TaskSearch search) { - wrapper.eq(search.getStatus() != null, TaskEntity::getStatus, search.getStatus()); - wrapper.eq(search.getHost() != null, TaskEntity::getHost, search.getHost()); - wrapper.eq(search.getVersionId() != null, ReleaseEntity::getVersionId, search.getVersionId()); - wrapper.eq(search.getClusterId() != null, ReleaseEntity::getClusterId, search.getClusterId()); - wrapper.eq(search.getJobType() != null, JobEntity::getType, search.getJobType()); - wrapper.eq(search.getJobId() != null, TaskEntity::getJobId, search.getJobId()); - wrapper.like(StringUtils.isNotBlank(search.getJobName()), JobEntity::getName, search.getJobName()); - - wrapper.eq(search.getInstanceId() != null, JobEntity::getInstanceId, search.getInstanceId()); - wrapper.eq(search.getResourceName() != null, JobResourceMappingEntity::getResourceName, search.getResourceName()); - wrapper.eq(search.getResourceType() != null, JobResourceMappingEntity::getResourceType, search.getResourceType()); - - } - - public List getTasksByStatus(GeaflowTaskStatus status) { - return lambdaQuery().eq(TaskEntity::getStatus, status).eq(TaskEntity::getHost, NetworkUtil.getHostName()) - .comment(TenantLevelExtDao.IGNORE_TENANT_SIGNATURE).list(); - } +import com.github.yulichang.wrapper.MPJLambdaWrapper; - public TaskEntity getByToken(String token) { - return lambdaQuery().eq(TaskEntity::getToken, token).comment(TenantLevelExtDao.IGNORE_TENANT_SIGNATURE).one(); +@Repository +public class TaskDao extends TenantLevelDao + implements IdDao { + + public List getByJobId(String jobId) { + return lambdaQuery().eq(TaskEntity::getJobId, jobId).list(); + } + + public List getIdsByJobs(List jobIds) { + return lambdaQuery().select(TaskEntity::getId).in(TaskEntity::getJobId, jobIds).list(); + } + + public boolean updateStatus( + String taskId, GeaflowTaskStatus oldStatus, GeaflowTaskStatus newStatus) { + return lambdaUpdate() + .eq(TaskEntity::getId, taskId) + .eq(TaskEntity::getStatus, oldStatus) + .ne(TaskEntity::getStatus, newStatus) + .set(TaskEntity::getStatus, newStatus) + .update(); + } + + @Override + public boolean update(List entities) { + entities.forEach(e -> e.setStatus(null)); + return IdDao.super.update(entities); + } + + private MPJLambdaWrapper getJoinWrapper() { + return new MPJLambdaWrapper() + .selectAll(TaskEntity.class) + .innerJoin(ReleaseEntity.class, ReleaseEntity::getId, TaskEntity::getReleaseId) + .innerJoin(JobEntity.class, JobEntity::getId, TaskEntity::getJobId) + .leftJoin( + JobResourceMappingEntity.class, JobResourceMappingEntity::getJobId, JobEntity::getId) + .distinct(); + } + + @Override + public PageList search(TaskSearch search) { + if (ContextHolder.get().isSystemSession()) { + return systemSearch(search); + } else { + return search(getJoinWrapper(), search); } + } + + public PageList systemSearch(TaskSearch search) { + // don't select by tenantId + MPJLambdaWrapper wrapper = getJoinWrapper(); + wrapper.comment(TenantLevelExtDao.IGNORE_TENANT_SIGNATURE); + return search(wrapper, search); + } + + @Override + public void configJoinSearch(MPJLambdaWrapper wrapper, TaskSearch search) { + wrapper.eq(search.getStatus() != null, TaskEntity::getStatus, search.getStatus()); + wrapper.eq(search.getHost() != null, TaskEntity::getHost, search.getHost()); + wrapper.eq(search.getVersionId() != null, ReleaseEntity::getVersionId, search.getVersionId()); + wrapper.eq(search.getClusterId() != null, ReleaseEntity::getClusterId, search.getClusterId()); + wrapper.eq(search.getJobType() != null, JobEntity::getType, search.getJobType()); + wrapper.eq(search.getJobId() != null, TaskEntity::getJobId, search.getJobId()); + wrapper.like( + StringUtils.isNotBlank(search.getJobName()), JobEntity::getName, search.getJobName()); + + wrapper.eq(search.getInstanceId() != null, JobEntity::getInstanceId, search.getInstanceId()); + wrapper.eq( + search.getResourceName() != null, + JobResourceMappingEntity::getResourceName, + search.getResourceName()); + wrapper.eq( + search.getResourceType() != null, + JobResourceMappingEntity::getResourceType, + search.getResourceType()); + } + + public List getTasksByStatus(GeaflowTaskStatus status) { + return lambdaQuery() + .eq(TaskEntity::getStatus, status) + .eq(TaskEntity::getHost, NetworkUtil.getHostName()) + .comment(TenantLevelExtDao.IGNORE_TENANT_SIGNATURE) + .list(); + } + + public TaskEntity getByToken(String token) { + return lambdaQuery() + .eq(TaskEntity::getToken, token) + .comment(TenantLevelExtDao.IGNORE_TENANT_SIGNATURE) + .one(); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TaskScheduleDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TaskScheduleDao.java index 6fae575db..86118aad7 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TaskScheduleDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TaskScheduleDao.java @@ -25,7 +25,5 @@ import org.springframework.stereotype.Repository; @Repository -public class TaskScheduleDao extends TenantLevelDao implements - NameDao { - -} +public class TaskScheduleDao extends TenantLevelDao + implements NameDao {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TenantDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TenantDao.java index d664586ba..77062552b 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TenantDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TenantDao.java @@ -19,11 +19,11 @@ package org.apache.geaflow.console.common.dal.dao; -import com.github.yulichang.wrapper.MPJLambdaWrapper; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.stream.Collectors; + import org.apache.geaflow.console.common.dal.entity.IdEntity; import org.apache.geaflow.console.common.dal.entity.NameEntity; import org.apache.geaflow.console.common.dal.entity.TenantEntity; @@ -33,21 +33,31 @@ import org.apache.geaflow.console.common.dal.model.TenantSearch; import org.springframework.stereotype.Repository; +import com.github.yulichang.wrapper.MPJLambdaWrapper; + @Repository -public class TenantDao extends SystemLevelDao implements - NameDao { +public class TenantDao extends SystemLevelDao + implements NameDao { - public PageList search(String userId, TenantSearch search) { - MPJLambdaWrapper wrapper = new MPJLambdaWrapper().selectAll(TenantEntity.class) - .innerJoin(TenantUserMappingEntity.class, TenantUserMappingEntity::getTenantId, TenantEntity::getId) + public PageList search(String userId, TenantSearch search) { + MPJLambdaWrapper wrapper = + new MPJLambdaWrapper() + .selectAll(TenantEntity.class) + .innerJoin( + TenantUserMappingEntity.class, + TenantUserMappingEntity::getTenantId, + TenantEntity::getId) .eq(TenantUserMappingEntity::getUserId, userId); - return search(wrapper, search); - } + return search(wrapper, search); + } - public Map getTenantNames(Collection tenantIds) { - List entities = lambdaQuery().select(TenantEntity::getId, TenantEntity::getName) - .in(TenantEntity::getId, tenantIds).list(); - return entities.stream().collect(Collectors.toMap(IdEntity::getId, NameEntity::getName)); - } + public Map getTenantNames(Collection tenantIds) { + List entities = + lambdaQuery() + .select(TenantEntity::getId, TenantEntity::getName) + .in(TenantEntity::getId, tenantIds) + .list(); + return entities.stream().collect(Collectors.toMap(IdEntity::getId, NameEntity::getName)); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TenantLevelDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TenantLevelDao.java index 27f380de2..161968947 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TenantLevelDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TenantLevelDao.java @@ -23,14 +23,15 @@ import org.apache.geaflow.console.common.dal.mapper.GeaflowBaseMapper; import org.apache.geaflow.console.common.util.exception.GeaflowException; -public abstract class TenantLevelDao, E extends IdEntity> extends GeaflowBaseDao { +public abstract class TenantLevelDao, E extends IdEntity> + extends GeaflowBaseDao { - @Override - public void afterPropertiesSet() throws Exception { - super.afterPropertiesSet(); - if (ignoreTenant) { - String message = "Mapper {} can't annotated by @InterceptorIgnore(tenantLine = \"true\")"; - throw new GeaflowException(message, mapperClass.getSimpleName()); - } + @Override + public void afterPropertiesSet() throws Exception { + super.afterPropertiesSet(); + if (ignoreTenant) { + String message = "Mapper {} can't annotated by @InterceptorIgnore(tenantLine = \"true\")"; + throw new GeaflowException(message, mapperClass.getSimpleName()); } + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TenantLevelExtDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TenantLevelExtDao.java index 298130052..0a31a616b 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TenantLevelExtDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TenantLevelExtDao.java @@ -21,44 +21,43 @@ import static org.apache.geaflow.console.common.dal.dao.IdDao.TENANT_ID_FIELD_NAME; -import com.baomidou.mybatisplus.core.conditions.AbstractWrapper; -import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; -import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper; -import com.baomidou.mybatisplus.extension.conditions.query.LambdaQueryChainWrapper; -import com.baomidou.mybatisplus.extension.conditions.update.LambdaUpdateChainWrapper; import org.apache.geaflow.console.common.dal.entity.IdEntity; import org.apache.geaflow.console.common.dal.mapper.GeaflowBaseMapper; import org.apache.geaflow.console.common.dal.wrapper.GeaflowLambdaQueryChainWrapper; import org.apache.geaflow.console.common.dal.wrapper.GeaflowLambdaUpdateChainWrapper; -public abstract class TenantLevelExtDao, E extends IdEntity> extends - TenantLevelDao { - - public static final String IGNORE_TENANT_SIGNATURE = "###GEAFLOW_IGNORE_TENANT_INTERCEPTOR###"; +import com.baomidou.mybatisplus.core.conditions.AbstractWrapper; +import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; +import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper; +import com.baomidou.mybatisplus.extension.conditions.query.LambdaQueryChainWrapper; +import com.baomidou.mybatisplus.extension.conditions.update.LambdaUpdateChainWrapper; +public abstract class TenantLevelExtDao, E extends IdEntity> + extends TenantLevelDao { - public LambdaQueryChainWrapper lambdaQuery(String tenantId) { - if (tenantId != null) { - return lambdaQuery(); - } + public static final String IGNORE_TENANT_SIGNATURE = "###GEAFLOW_IGNORE_TENANT_INTERCEPTOR###"; - QueryWrapper wrapper = new QueryWrapper<>(); - configSystemWrapper(wrapper); - return new GeaflowLambdaQueryChainWrapper<>(getBaseMapper(), wrapper.lambda()); + public LambdaQueryChainWrapper lambdaQuery(String tenantId) { + if (tenantId != null) { + return lambdaQuery(); } - public LambdaUpdateChainWrapper lambdaUpdate(String tenantId) { - if (tenantId != null) { - return lambdaUpdate(); - } + QueryWrapper wrapper = new QueryWrapper<>(); + configSystemWrapper(wrapper); + return new GeaflowLambdaQueryChainWrapper<>(getBaseMapper(), wrapper.lambda()); + } - UpdateWrapper wrapper = new UpdateWrapper<>(); - configSystemWrapper(wrapper); - return new GeaflowLambdaUpdateChainWrapper<>(getBaseMapper(), wrapper.lambda()); + public LambdaUpdateChainWrapper lambdaUpdate(String tenantId) { + if (tenantId != null) { + return lambdaUpdate(); } - private void configSystemWrapper(AbstractWrapper wrapper) { - wrapper.isNull(TENANT_ID_FIELD_NAME).comment(IGNORE_TENANT_SIGNATURE); - } + UpdateWrapper wrapper = new UpdateWrapper<>(); + configSystemWrapper(wrapper); + return new GeaflowLambdaUpdateChainWrapper<>(getBaseMapper(), wrapper.lambda()); + } + private void configSystemWrapper(AbstractWrapper wrapper) { + wrapper.isNull(TENANT_ID_FIELD_NAME).comment(IGNORE_TENANT_SIGNATURE); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TenantUserMappingDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TenantUserMappingDao.java index 0fdd108de..d4d5fbc3f 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TenantUserMappingDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/TenantUserMappingDao.java @@ -19,8 +19,8 @@ package org.apache.geaflow.console.common.dal.dao; -import com.github.yulichang.wrapper.MPJLambdaWrapper; import java.util.List; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.dal.entity.TenantEntity; import org.apache.geaflow.console.common.dal.entity.TenantUserMappingEntity; @@ -28,50 +28,67 @@ import org.apache.geaflow.console.common.dal.model.IdSearch; import org.springframework.stereotype.Repository; +import com.github.yulichang.wrapper.MPJLambdaWrapper; + @Repository -public class TenantUserMappingDao extends SystemLevelDao implements - IdDao { +public class TenantUserMappingDao + extends SystemLevelDao + implements IdDao { - public boolean existUser(String tenantId, String userId) { - return lambdaQuery().eq(TenantUserMappingEntity::getTenantId, tenantId) - .eq(TenantUserMappingEntity::getUserId, userId).exists(); - } + public boolean existUser(String tenantId, String userId) { + return lambdaQuery() + .eq(TenantUserMappingEntity::getTenantId, tenantId) + .eq(TenantUserMappingEntity::getUserId, userId) + .exists(); + } - public boolean addUser(String tenantId, String userId) { - TenantUserMappingEntity entity = new TenantUserMappingEntity(); - entity.setTenantId(tenantId); - entity.setUserId(userId); - return StringUtils.isNotBlank(create(entity)); - } + public boolean addUser(String tenantId, String userId) { + TenantUserMappingEntity entity = new TenantUserMappingEntity(); + entity.setTenantId(tenantId); + entity.setUserId(userId); + return StringUtils.isNotBlank(create(entity)); + } - public boolean deleteUser(String tenantId, String userId) { - return lambdaUpdate().eq(TenantUserMappingEntity::getTenantId, tenantId) - .eq(TenantUserMappingEntity::getUserId, userId).remove(); - } + public boolean deleteUser(String tenantId, String userId) { + return lambdaUpdate() + .eq(TenantUserMappingEntity::getTenantId, tenantId) + .eq(TenantUserMappingEntity::getUserId, userId) + .remove(); + } - public List getUserTenants(String userId) { - MPJLambdaWrapper wrapper = new MPJLambdaWrapper().selectAll( - TenantEntity.class).innerJoin(TenantEntity.class, TenantEntity::getId, - TenantUserMappingEntity::getTenantId) + public List getUserTenants(String userId) { + MPJLambdaWrapper wrapper = + new MPJLambdaWrapper() + .selectAll(TenantEntity.class) + .innerJoin( + TenantEntity.class, TenantEntity::getId, TenantUserMappingEntity::getTenantId) .eq(TenantUserMappingEntity::getUserId, userId); - return getBaseMapper().selectJoinList(TenantEntity.class, wrapper); - } + return getBaseMapper().selectJoinList(TenantEntity.class, wrapper); + } - public TenantEntity getUserActiveTenant(String userId) { - MPJLambdaWrapper wrapper = new MPJLambdaWrapper().selectAll( - TenantEntity.class).innerJoin(TenantEntity.class, TenantEntity::getId, - TenantUserMappingEntity::getTenantId) - .eq(TenantUserMappingEntity::getUserId, userId).eq(TenantUserMappingEntity::isActive, true); - return getBaseMapper().selectJoinOne(TenantEntity.class, wrapper); - } + public TenantEntity getUserActiveTenant(String userId) { + MPJLambdaWrapper wrapper = + new MPJLambdaWrapper() + .selectAll(TenantEntity.class) + .innerJoin( + TenantEntity.class, TenantEntity::getId, TenantUserMappingEntity::getTenantId) + .eq(TenantUserMappingEntity::getUserId, userId) + .eq(TenantUserMappingEntity::isActive, true); + return getBaseMapper().selectJoinOne(TenantEntity.class, wrapper); + } - public void activateUserTenant(String tenantId, String userId) { - lambdaUpdate().set(TenantUserMappingEntity::isActive, true).eq(TenantUserMappingEntity::getTenantId, tenantId) - .eq(TenantUserMappingEntity::getUserId, userId).update(); - } + public void activateUserTenant(String tenantId, String userId) { + lambdaUpdate() + .set(TenantUserMappingEntity::isActive, true) + .eq(TenantUserMappingEntity::getTenantId, tenantId) + .eq(TenantUserMappingEntity::getUserId, userId) + .update(); + } - public void deactivateUserTenants(String userId) { - lambdaUpdate().set(TenantUserMappingEntity::isActive, false).eq(TenantUserMappingEntity::getUserId, userId) - .update(); - } + public void deactivateUserTenants(String userId) { + lambdaUpdate() + .set(TenantUserMappingEntity::isActive, false) + .eq(TenantUserMappingEntity::getUserId, userId) + .update(); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/UserDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/UserDao.java index faedd5b71..30259ebdb 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/UserDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/UserDao.java @@ -19,11 +19,11 @@ package org.apache.geaflow.console.common.dal.dao; -import com.github.yulichang.wrapper.MPJLambdaWrapper; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.stream.Collectors; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.dal.entity.IdEntity; import org.apache.geaflow.console.common.dal.entity.NameEntity; @@ -34,41 +34,52 @@ import org.apache.geaflow.console.common.dal.model.UserSearch; import org.springframework.stereotype.Repository; +import com.github.yulichang.wrapper.MPJLambdaWrapper; + @Repository -public class UserDao extends SystemLevelDao implements NameDao { +public class UserDao extends SystemLevelDao + implements NameDao { - public PageList search(String tenantId, UserSearch search) { - MPJLambdaWrapper wrapper = new MPJLambdaWrapper().selectAll(UserEntity.class) - .innerJoin(TenantUserMappingEntity.class, TenantUserMappingEntity::getUserId, UserEntity::getId) + public PageList search(String tenantId, UserSearch search) { + MPJLambdaWrapper wrapper = + new MPJLambdaWrapper() + .selectAll(UserEntity.class) + .innerJoin( + TenantUserMappingEntity.class, + TenantUserMappingEntity::getUserId, + UserEntity::getId) .eq(TenantUserMappingEntity::getTenantId, tenantId); - return search(wrapper, search); - } - - @Override - public void configJoinSearch(MPJLambdaWrapper wrapper, UserSearch search) { - String email = search.getEmail(); - String phone = search.getPhone(); + return search(wrapper, search); + } - wrapper.like(StringUtils.isNotBlank(email), UserEntity::getEmail, email); - wrapper.like(StringUtils.isNotBlank(phone), UserEntity::getPhone, phone); - } + @Override + public void configJoinSearch(MPJLambdaWrapper wrapper, UserSearch search) { + String email = search.getEmail(); + String phone = search.getPhone(); - public UserEntity getByToken(String token) { - if (StringUtils.isBlank(token)) { - return null; - } + wrapper.like(StringUtils.isNotBlank(email), UserEntity::getEmail, email); + wrapper.like(StringUtils.isNotBlank(phone), UserEntity::getPhone, phone); + } - return lambdaQuery().eq(UserEntity::getSessionToken, token).one(); + public UserEntity getByToken(String token) { + if (StringUtils.isBlank(token)) { + return null; } - public boolean existName(String name) { - return lambdaQuery().eq(UserEntity::getName, name).exists(); - } + return lambdaQuery().eq(UserEntity::getSessionToken, token).one(); + } - public Map getUserNames(Collection userIds) { - List entities = lambdaQuery().select(UserEntity::getId, UserEntity::getName) - .in(UserEntity::getId, userIds).list(); - return entities.stream().collect(Collectors.toMap(IdEntity::getId, NameEntity::getName)); - } + public boolean existName(String name) { + return lambdaQuery().eq(UserEntity::getName, name).exists(); + } + + public Map getUserNames(Collection userIds) { + List entities = + lambdaQuery() + .select(UserEntity::getId, UserEntity::getName) + .in(UserEntity::getId, userIds) + .list(); + return entities.stream().collect(Collectors.toMap(IdEntity::getId, NameEntity::getName)); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/UserLevelDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/UserLevelDao.java index 9071a4ae3..76e84336c 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/UserLevelDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/UserLevelDao.java @@ -19,50 +19,51 @@ package org.apache.geaflow.console.common.dal.dao; -import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; -import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper; import org.apache.geaflow.console.common.dal.entity.UserLevelEntity; import org.apache.geaflow.console.common.dal.mapper.GeaflowBaseMapper; import org.apache.geaflow.console.common.util.context.ContextHolder; -public abstract class UserLevelDao, E extends UserLevelEntity> extends - SystemLevelDao { +import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; +import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper; + +public abstract class UserLevelDao, E extends UserLevelEntity> + extends SystemLevelDao { - public static final String SYSTEM_FIELD_NAME = "`system`"; + public static final String SYSTEM_FIELD_NAME = "`system`"; - @Override - public void configQueryWrapper(QueryWrapper wrapper) { - String userId = ContextHolder.get().getUserId(); - wrapper.nested( - w -> w.eq(SYSTEM_FIELD_NAME, true).or().eq(IdDao.CREATOR_FIELD_NAME, userId)); - } + @Override + public void configQueryWrapper(QueryWrapper wrapper) { + String userId = ContextHolder.get().getUserId(); + wrapper.nested(w -> w.eq(SYSTEM_FIELD_NAME, true).or().eq(IdDao.CREATOR_FIELD_NAME, userId)); + } - @Override - public void configUpdateWrapper(UpdateWrapper wrapper) { - boolean systemSession = ContextHolder.get().isSystemSession(); - String userId = ContextHolder.get().getUserId(); + @Override + public void configUpdateWrapper(UpdateWrapper wrapper) { + boolean systemSession = ContextHolder.get().isSystemSession(); + String userId = ContextHolder.get().getUserId(); - wrapper.eq(SYSTEM_FIELD_NAME, systemSession); - if (!systemSession) { - wrapper.eq(IdDao.CREATOR_FIELD_NAME, userId); - } + wrapper.eq(SYSTEM_FIELD_NAME, systemSession); + if (!systemSession) { + wrapper.eq(IdDao.CREATOR_FIELD_NAME, userId); } + } - public boolean validateGetId(String id) { - if (id == null) { - return false; - } - // system session can query all data - return ContextHolder.get().isSystemSession() ? lambdaQuery().eq(E::getId, id).exists() - : lambdaQuery().eq(E::isSystem, false).eq(E::getId, id).exists(); + public boolean validateGetId(String id) { + if (id == null) { + return false; } + // system session can query all data + return ContextHolder.get().isSystemSession() + ? lambdaQuery().eq(E::getId, id).exists() + : lambdaQuery().eq(E::isSystem, false).eq(E::getId, id).exists(); + } - public boolean validateUpdateId(String id) { - if (id == null) { - return false; - } - // only update when system session or the current user is the creator. - boolean systemSession = ContextHolder.get().isSystemSession(); - return lambdaQuery().eq(E::isSystem, systemSession).eq(E::getId, id).exists(); + public boolean validateUpdateId(String id) { + if (id == null) { + return false; } + // only update when system session or the current user is the creator. + boolean systemSession = ContextHolder.get().isSystemSession(); + return lambdaQuery().eq(E::isSystem, systemSession).eq(E::getId, id).exists(); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/UserRoleMappingDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/UserRoleMappingDao.java index a4d480e7f..d780f0c99 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/UserRoleMappingDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/UserRoleMappingDao.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.common.dal.dao; import java.util.List; + import org.apache.geaflow.console.common.dal.entity.UserRoleMappingEntity; import org.apache.geaflow.console.common.dal.mapper.UserRoleMappingMapper; import org.apache.geaflow.console.common.dal.model.IdSearch; @@ -28,30 +29,38 @@ import org.springframework.stereotype.Repository; @Repository -public class UserRoleMappingDao extends TenantLevelExtDao implements - IdDao { - - public List getRoleTypes(String tenantId, String userId) { - List entities = lambdaQuery(tenantId).select(UserRoleMappingEntity::getRoleType) - .eq(UserRoleMappingEntity::getUserId, userId).list(); - return ListUtil.convert(entities, UserRoleMappingEntity::getRoleType); - } - - public boolean existRoleType(String tenantId, String userId, GeaflowRoleType roleType) { - return lambdaQuery(tenantId).eq(UserRoleMappingEntity::getUserId, userId) - .eq(UserRoleMappingEntity::getRoleType, roleType).exists(); - } - - public void addRoleType(String tenantId, String userId, GeaflowRoleType roleType) { - UserRoleMappingEntity entity = new UserRoleMappingEntity(); - entity.setTenantId(tenantId); - entity.setUserId(userId); - entity.setRoleType(roleType); - create(entity); - } - - public void deleteRoleType(String tenantId, String userId, GeaflowRoleType roleType) { - lambdaUpdate(tenantId).eq(UserRoleMappingEntity::getUserId, userId) - .eq(UserRoleMappingEntity::getRoleType, roleType).remove(); - } +public class UserRoleMappingDao + extends TenantLevelExtDao + implements IdDao { + + public List getRoleTypes(String tenantId, String userId) { + List entities = + lambdaQuery(tenantId) + .select(UserRoleMappingEntity::getRoleType) + .eq(UserRoleMappingEntity::getUserId, userId) + .list(); + return ListUtil.convert(entities, UserRoleMappingEntity::getRoleType); + } + + public boolean existRoleType(String tenantId, String userId, GeaflowRoleType roleType) { + return lambdaQuery(tenantId) + .eq(UserRoleMappingEntity::getUserId, userId) + .eq(UserRoleMappingEntity::getRoleType, roleType) + .exists(); + } + + public void addRoleType(String tenantId, String userId, GeaflowRoleType roleType) { + UserRoleMappingEntity entity = new UserRoleMappingEntity(); + entity.setTenantId(tenantId); + entity.setUserId(userId); + entity.setRoleType(roleType); + create(entity); + } + + public void deleteRoleType(String tenantId, String userId, GeaflowRoleType roleType) { + lambdaUpdate(tenantId) + .eq(UserRoleMappingEntity::getUserId, userId) + .eq(UserRoleMappingEntity::getRoleType, roleType) + .remove(); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/VersionDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/VersionDao.java index fb10804a5..aa4e07424 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/VersionDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/VersionDao.java @@ -19,36 +19,45 @@ package org.apache.geaflow.console.common.dal.dao; -import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import org.apache.geaflow.console.common.dal.entity.IdEntity; import org.apache.geaflow.console.common.dal.entity.VersionEntity; import org.apache.geaflow.console.common.dal.mapper.VersionMapper; import org.apache.geaflow.console.common.dal.model.VersionSearch; import org.springframework.stereotype.Repository; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; + @Repository -public class VersionDao extends SystemLevelDao implements - NameDao { - - @Override - public void configSearch(LambdaQueryWrapper wrapper, VersionSearch search) { - wrapper.eq(search.getPublish() != null, VersionEntity::isPublish, search.getPublish()); - } - - public VersionEntity getDefaultVersion() { - return lambdaQuery().eq(VersionEntity::isPublish, true).orderByDesc(IdEntity::getGmtModified).last("limit 1") - .one(); - } - - public VersionEntity getPublishVersionByName(String name) { - return lambdaQuery().eq(VersionEntity::isPublish, true).eq(VersionEntity::getName, name).last("limit 1").one(); - } - - public long getFileRefCount(String fileId, String excludeFunctionId) { - return lambdaQuery() - .eq(VersionEntity::getEngineJarId, fileId) - .or().eq(VersionEntity::getLangJarId, fileId) - .ne(excludeFunctionId != null, IdEntity::getId, excludeFunctionId) - .count(); - } +public class VersionDao extends SystemLevelDao + implements NameDao { + + @Override + public void configSearch(LambdaQueryWrapper wrapper, VersionSearch search) { + wrapper.eq(search.getPublish() != null, VersionEntity::isPublish, search.getPublish()); + } + + public VersionEntity getDefaultVersion() { + return lambdaQuery() + .eq(VersionEntity::isPublish, true) + .orderByDesc(IdEntity::getGmtModified) + .last("limit 1") + .one(); + } + + public VersionEntity getPublishVersionByName(String name) { + return lambdaQuery() + .eq(VersionEntity::isPublish, true) + .eq(VersionEntity::getName, name) + .last("limit 1") + .one(); + } + + public long getFileRefCount(String fileId, String excludeFunctionId) { + return lambdaQuery() + .eq(VersionEntity::getEngineJarId, fileId) + .or() + .eq(VersionEntity::getLangJarId, fileId) + .ne(excludeFunctionId != null, IdEntity::getId, excludeFunctionId) + .count(); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/VertexDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/VertexDao.java index ae287ca10..cbdca5a3f 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/VertexDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/VertexDao.java @@ -19,8 +19,8 @@ package org.apache.geaflow.console.common.dal.dao; -import com.github.yulichang.wrapper.MPJLambdaWrapper; import java.util.List; + import org.apache.geaflow.console.common.dal.entity.EndpointEntity; import org.apache.geaflow.console.common.dal.entity.GraphStructMappingEntity; import org.apache.geaflow.console.common.dal.entity.VertexEntity; @@ -31,31 +31,33 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Repository; -@Repository -public class VertexDao extends TenantLevelDao implements DataDao { +import com.github.yulichang.wrapper.MPJLambdaWrapper; - @Autowired - private GraphStructMappingMapper graphStructMappingMapper; +@Repository +public class VertexDao extends TenantLevelDao + implements DataDao { - @Autowired - private EndpointMapper endpointMapper; + @Autowired private GraphStructMappingMapper graphStructMappingMapper; + @Autowired private EndpointMapper endpointMapper; - public List getByGraphId(String graphId) { - MPJLambdaWrapper wrapper = new MPJLambdaWrapper().selectAll( - VertexEntity.class) - .innerJoin(VertexEntity.class, VertexEntity::getId, GraphStructMappingEntity::getResourceId) + public List getByGraphId(String graphId) { + MPJLambdaWrapper wrapper = + new MPJLambdaWrapper() + .selectAll(VertexEntity.class) + .innerJoin( + VertexEntity.class, VertexEntity::getId, GraphStructMappingEntity::getResourceId) .eq(GraphStructMappingEntity::getGraphId, graphId) .orderByAsc(GraphStructMappingEntity::getSortKey); - return graphStructMappingMapper.selectJoinList(VertexEntity.class, wrapper); - - } - - public List getByEdge(String edgeId) { - MPJLambdaWrapper wrapper = new MPJLambdaWrapper().selectAll( - VertexEntity.class).innerJoin(VertexEntity.class, VertexEntity::getId, - EndpointEntity::getSourceId) + return graphStructMappingMapper.selectJoinList(VertexEntity.class, wrapper); + } + + public List getByEdge(String edgeId) { + MPJLambdaWrapper wrapper = + new MPJLambdaWrapper() + .selectAll(VertexEntity.class) + .innerJoin(VertexEntity.class, VertexEntity::getId, EndpointEntity::getSourceId) .eq(EndpointEntity::getEdgeId, edgeId); - return endpointMapper.selectJoinList(VertexEntity.class, wrapper); - } + return endpointMapper.selectJoinList(VertexEntity.class, wrapper); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/ViewDao.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/ViewDao.java index 75b1355e1..4623f9ec3 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/ViewDao.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/dao/ViewDao.java @@ -25,7 +25,5 @@ import org.springframework.stereotype.Repository; @Repository -public class ViewDao extends TenantLevelDao implements DataDao { - -} +public class ViewDao extends TenantLevelDao + implements DataDao {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/datasource/DataSourceConfiguration.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/datasource/DataSourceConfiguration.java index 21176fcf4..ee4669ab6 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/datasource/DataSourceConfiguration.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/datasource/DataSourceConfiguration.java @@ -19,15 +19,8 @@ package org.apache.geaflow.console.common.dal.datasource; -import com.baomidou.mybatisplus.annotation.DbType; -import com.baomidou.mybatisplus.autoconfigure.SpringBootVFS; -import com.baomidou.mybatisplus.core.config.GlobalConfig; -import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor; -import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor; -import com.baomidou.mybatisplus.extension.spring.MybatisSqlSessionFactoryBean; -import com.github.yulichang.injector.MPJSqlInjector; -import com.github.yulichang.interceptor.MPJInterceptor; import javax.sql.DataSource; + import org.apache.ibatis.plugin.Interceptor; import org.apache.ibatis.session.SqlSessionFactory; import org.mybatis.spring.annotation.MapperScan; @@ -36,39 +29,50 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import com.baomidou.mybatisplus.annotation.DbType; +import com.baomidou.mybatisplus.autoconfigure.SpringBootVFS; +import com.baomidou.mybatisplus.core.config.GlobalConfig; +import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor; +import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor; +import com.baomidou.mybatisplus.extension.spring.MybatisSqlSessionFactoryBean; +import com.github.yulichang.injector.MPJSqlInjector; +import com.github.yulichang.interceptor.MPJInterceptor; + @Configuration -@MapperScan(basePackages = { - "org.apache.geaflow.console.common.dal.mapper"}, sqlSessionFactoryRef = "sqlSessionFactory", nameGenerator = - DataSourceConfiguration.SpringBeanNameGenerator.class) +@MapperScan( + basePackages = {"org.apache.geaflow.console.common.dal.mapper"}, + sqlSessionFactoryRef = "sqlSessionFactory", + nameGenerator = DataSourceConfiguration.SpringBeanNameGenerator.class) public class DataSourceConfiguration { - public static class SpringBeanNameGenerator extends AnnotationBeanNameGenerator { + public static class SpringBeanNameGenerator extends AnnotationBeanNameGenerator { - @Override - protected String buildDefaultBeanName(BeanDefinition definition) { - return definition.getBeanClassName(); - } + @Override + protected String buildDefaultBeanName(BeanDefinition definition) { + return definition.getBeanClassName(); } + } - @Bean("sqlSessionFactory") - public SqlSessionFactory sqlSessionFactory(DataSource dataSource) throws Exception { - MybatisSqlSessionFactoryBean factory = new MybatisSqlSessionFactoryBean(); - factory.setDataSource(dataSource); - factory.setVfs(SpringBootVFS.class); + @Bean("sqlSessionFactory") + public SqlSessionFactory sqlSessionFactory(DataSource dataSource) throws Exception { + MybatisSqlSessionFactoryBean factory = new MybatisSqlSessionFactoryBean(); + factory.setDataSource(dataSource); + factory.setVfs(SpringBootVFS.class); - // set mybatis tenant interceptor - MybatisPlusInterceptor mybatisPlusInterceptor = new MybatisPlusInterceptor(); - mybatisPlusInterceptor.addInnerInterceptor(new GeaflowTenantInterceptor(new TenantInterceptorHandler())); + // set mybatis tenant interceptor + MybatisPlusInterceptor mybatisPlusInterceptor = new MybatisPlusInterceptor(); + mybatisPlusInterceptor.addInnerInterceptor( + new GeaflowTenantInterceptor(new TenantInterceptorHandler())); - //set mybatis page plugin - mybatisPlusInterceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL)); + // set mybatis page plugin + mybatisPlusInterceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL)); - Interceptor[] plugins = new Interceptor[2]; - plugins[0] = mybatisPlusInterceptor; - plugins[1] = new MPJInterceptor(); - factory.setPlugins(plugins); - factory.setGlobalConfig(new GlobalConfig().setSqlInjector(new MPJSqlInjector())); + Interceptor[] plugins = new Interceptor[2]; + plugins[0] = mybatisPlusInterceptor; + plugins[1] = new MPJInterceptor(); + factory.setPlugins(plugins); + factory.setGlobalConfig(new GlobalConfig().setSqlInjector(new MPJSqlInjector())); - return factory.getObject(); - } + return factory.getObject(); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/datasource/GeaflowTenantInterceptor.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/datasource/GeaflowTenantInterceptor.java index f348ab0e1..b1d78314b 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/datasource/GeaflowTenantInterceptor.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/datasource/GeaflowTenantInterceptor.java @@ -19,11 +19,9 @@ package org.apache.geaflow.console.common.dal.datasource; -import com.baomidou.mybatisplus.core.toolkit.PluginUtils; -import com.baomidou.mybatisplus.extension.plugins.handler.TenantLineHandler; -import com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor; import java.sql.Connection; import java.sql.SQLException; + import org.apache.geaflow.console.common.dal.dao.TenantLevelExtDao; import org.apache.geaflow.console.common.util.Fmt; import org.apache.ibatis.executor.Executor; @@ -33,32 +31,43 @@ import org.apache.ibatis.session.ResultHandler; import org.apache.ibatis.session.RowBounds; -public class GeaflowTenantInterceptor extends TenantLineInnerInterceptor { +import com.baomidou.mybatisplus.core.toolkit.PluginUtils; +import com.baomidou.mybatisplus.extension.plugins.handler.TenantLineHandler; +import com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor; - public GeaflowTenantInterceptor(TenantLineHandler tenantLineHandler) { - super(tenantLineHandler); - } +public class GeaflowTenantInterceptor extends TenantLineInnerInterceptor { - @Override - public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, - ResultHandler resultHandler, BoundSql boundSql) throws SQLException { - if (ignoreSql(boundSql.getSql())) { - return; - } + public GeaflowTenantInterceptor(TenantLineHandler tenantLineHandler) { + super(tenantLineHandler); + } - super.beforeQuery(executor, ms, parameter, rowBounds, resultHandler, boundSql); + @Override + public void beforeQuery( + Executor executor, + MappedStatement ms, + Object parameter, + RowBounds rowBounds, + ResultHandler resultHandler, + BoundSql boundSql) + throws SQLException { + if (ignoreSql(boundSql.getSql())) { + return; } - @Override - public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) { - if (ignoreSql(PluginUtils.mpStatementHandler(sh).boundSql().getSql())) { - return; - } + super.beforeQuery(executor, ms, parameter, rowBounds, resultHandler, boundSql); + } - super.beforePrepare(sh, connection, transactionTimeout); + @Override + public void beforePrepare( + StatementHandler sh, Connection connection, Integer transactionTimeout) { + if (ignoreSql(PluginUtils.mpStatementHandler(sh).boundSql().getSql())) { + return; } - private static boolean ignoreSql(String sql) { - return sql.endsWith(Fmt.as("/*{}*/", TenantLevelExtDao.IGNORE_TENANT_SIGNATURE)); - } + super.beforePrepare(sh, connection, transactionTimeout); + } + + private static boolean ignoreSql(String sql) { + return sql.endsWith(Fmt.as("/*{}*/", TenantLevelExtDao.IGNORE_TENANT_SIGNATURE)); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/datasource/TenantInterceptorHandler.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/datasource/TenantInterceptorHandler.java index 928a02f68..dc7aae147 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/datasource/TenantInterceptorHandler.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/datasource/TenantInterceptorHandler.java @@ -19,38 +19,40 @@ package org.apache.geaflow.console.common.dal.datasource; +import org.apache.geaflow.console.common.dal.dao.IdDao; +import org.apache.geaflow.console.common.util.context.ContextHolder; +import org.apache.geaflow.console.common.util.context.GeaflowContext; + import com.baomidou.mybatisplus.extension.plugins.handler.TenantLineHandler; import com.google.common.base.Preconditions; + import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.StringValue; -import org.apache.geaflow.console.common.dal.dao.IdDao; -import org.apache.geaflow.console.common.util.context.ContextHolder; -import org.apache.geaflow.console.common.util.context.GeaflowContext; public class TenantInterceptorHandler implements TenantLineHandler { - @Override - public Expression getTenantId() { - GeaflowContext context = ContextHolder.get(); - Preconditions.checkNotNull(context, "Invalid context"); - - // allow null tenant id - String tenantId = context.getTenantId(); - if (tenantId == null) { - return null; - } + @Override + public Expression getTenantId() { + GeaflowContext context = ContextHolder.get(); + Preconditions.checkNotNull(context, "Invalid context"); - return new StringValue(tenantId); + // allow null tenant id + String tenantId = context.getTenantId(); + if (tenantId == null) { + return null; } - @Override - public String getTenantIdColumn() { - return IdDao.TENANT_ID_FIELD_NAME; - } + return new StringValue(tenantId); + } - @Override - public boolean ignoreTable(String tableName) { - // allow null tenant id - return ContextHolder.get().getTenantId() == null; - } + @Override + public String getTenantIdColumn() { + return IdDao.TENANT_ID_FIELD_NAME; + } + + @Override + public boolean ignoreTable(String tableName) { + // allow null tenant id + return ContextHolder.get().getTenantId() == null; + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/AuditEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/AuditEntity.java index 459b14ace..04216044e 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/AuditEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/AuditEntity.java @@ -19,22 +19,24 @@ package org.apache.geaflow.console.common.dal.entity; +import org.apache.geaflow.console.common.util.type.GeaflowOperationType; +import org.apache.geaflow.console.common.util.type.GeaflowResourceType; + import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowOperationType; -import org.apache.geaflow.console.common.util.type.GeaflowResourceType; @Getter @Setter @TableName("geaflow_audit") public class AuditEntity extends IdEntity { - private String resourceId; + private String resourceId; - private GeaflowResourceType resourceType; + private GeaflowResourceType resourceType; - private GeaflowOperationType operationType; + private GeaflowOperationType operationType; - private String detail; + private String detail; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/AuthorizationEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/AuthorizationEntity.java index baa2bc863..ad7d783e9 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/AuthorizationEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/AuthorizationEntity.java @@ -19,13 +19,15 @@ package org.apache.geaflow.console.common.dal.entity; +import org.apache.geaflow.console.common.util.type.GeaflowAuthorityType; +import org.apache.geaflow.console.common.util.type.GeaflowResourceType; + import com.baomidou.mybatisplus.annotation.TableName; + import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowAuthorityType; -import org.apache.geaflow.console.common.util.type.GeaflowResourceType; @Getter @Setter @@ -34,11 +36,11 @@ @AllArgsConstructor public class AuthorizationEntity extends IdEntity { - private String userId; + private String userId; - private GeaflowAuthorityType authorityType; + private GeaflowAuthorityType authorityType; - private GeaflowResourceType resourceType; + private GeaflowResourceType resourceType; - private String resourceId; + private String resourceId; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/ChatEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/ChatEntity.java index e0e6fba4c..3ca3edb22 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/ChatEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/ChatEntity.java @@ -19,23 +19,25 @@ package org.apache.geaflow.console.common.dal.entity; +import org.apache.geaflow.console.common.util.type.GeaflowStatementStatus; + import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowStatementStatus; @Getter @Setter @TableName("geaflow_chat") public class ChatEntity extends IdEntity { - private String prompt; + private String prompt; - private String answer; + private String answer; - private String modelId; + private String modelId; - private String jobId; + private String jobId; - private GeaflowStatementStatus status; + private GeaflowStatementStatus status; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/ClusterEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/ClusterEntity.java index 27475ff24..03e9c803a 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/ClusterEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/ClusterEntity.java @@ -19,18 +19,19 @@ package org.apache.geaflow.console.common.dal.entity; +import org.apache.geaflow.console.common.util.type.GeaflowPluginType; + import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowPluginType; @Getter @Setter @TableName("geaflow_cluster") public class ClusterEntity extends NameEntity { - private GeaflowPluginType type; - - private String config; + private GeaflowPluginType type; + private String config; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/DataEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/DataEntity.java index bbc738ad8..940d7b752 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/DataEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/DataEntity.java @@ -26,6 +26,5 @@ @Setter public abstract class DataEntity extends NameEntity { - protected String instanceId; - + protected String instanceId; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/EdgeEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/EdgeEntity.java index 1fa2c6a3d..4edb9bd9a 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/EdgeEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/EdgeEntity.java @@ -20,13 +20,11 @@ package org.apache.geaflow.console.common.dal.entity; import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; @Getter @Setter @TableName("geaflow_edge") -public class EdgeEntity extends DataEntity { - - -} +public class EdgeEntity extends DataEntity {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/EndpointEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/EndpointEntity.java index 72421768f..4d952b29f 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/EndpointEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/EndpointEntity.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.common.dal.entity; import com.baomidou.mybatisplus.annotation.TableName; + import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; @@ -31,15 +32,16 @@ @AllArgsConstructor @NoArgsConstructor @TableName("geaflow_endpoint") -@EqualsAndHashCode(of = {"graphId", "edgeId", "sourceId", "targetId"}, callSuper = false) +@EqualsAndHashCode( + of = {"graphId", "edgeId", "sourceId", "targetId"}, + callSuper = false) public class EndpointEntity extends IdEntity { - private String graphId; - - private String edgeId; + private String graphId; - private String sourceId; + private String edgeId; - private String targetId; + private String sourceId; + private String targetId; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/FieldEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/FieldEntity.java index 4d91fff1b..63309e042 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/FieldEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/FieldEntity.java @@ -19,25 +19,27 @@ package org.apache.geaflow.console.common.dal.entity; -import com.baomidou.mybatisplus.annotation.TableName; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.common.util.type.GeaflowFieldCategory; import org.apache.geaflow.console.common.util.type.GeaflowFieldType; import org.apache.geaflow.console.common.util.type.GeaflowResourceType; +import com.baomidou.mybatisplus.annotation.TableName; + +import lombok.Getter; +import lombok.Setter; + @Getter @Setter @TableName("geaflow_field") public class FieldEntity extends NameEntity { - private GeaflowFieldType type; + private GeaflowFieldType type; - private GeaflowFieldCategory category; + private GeaflowFieldCategory category; - private String resourceId; + private String resourceId; - private GeaflowResourceType resourceType; + private GeaflowResourceType resourceType; - private int sortKey; + private int sortKey; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/FunctionEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/FunctionEntity.java index 583810dee..eec87c38d 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/FunctionEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/FunctionEntity.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.common.dal.entity; import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; @@ -28,7 +29,7 @@ @TableName("geaflow_function") public class FunctionEntity extends DataEntity { - private String jarPackageId; + private String jarPackageId; - private String entryClass; + private String entryClass; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/GraphEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/GraphEntity.java index 8f31b17a5..790228a57 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/GraphEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/GraphEntity.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.common.dal.entity; import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; @@ -28,6 +29,5 @@ @TableName("geaflow_graph") public class GraphEntity extends DataEntity { - private String pluginConfigId; - + private String pluginConfigId; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/GraphStructMappingEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/GraphStructMappingEntity.java index 8080c4aae..586027180 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/GraphStructMappingEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/GraphStructMappingEntity.java @@ -19,12 +19,14 @@ package org.apache.geaflow.console.common.dal.entity; +import org.apache.geaflow.console.common.util.type.GeaflowResourceType; + import com.baomidou.mybatisplus.annotation.TableName; + import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowResourceType; @Getter @Setter @@ -33,12 +35,11 @@ @TableName("geaflow_graph_struct_mapping") public class GraphStructMappingEntity extends IdEntity { - private String graphId; - - private String resourceId; + private String graphId; - private GeaflowResourceType resourceType; + private String resourceId; - private int sortKey; + private GeaflowResourceType resourceType; + private int sortKey; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/IdEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/IdEntity.java index a3944d7b2..03a6e4f7b 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/IdEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/IdEntity.java @@ -19,9 +19,11 @@ package org.apache.geaflow.console.common.dal.entity; -import com.baomidou.mybatisplus.annotation.TableId; import java.io.Serializable; import java.util.Date; + +import com.baomidou.mybatisplus.annotation.TableId; + import lombok.Getter; import lombok.Setter; @@ -29,17 +31,16 @@ @Setter public abstract class IdEntity implements Serializable { - protected String tenantId; - - @TableId("guid") - protected String id; + protected String tenantId; - protected Date gmtCreate; + @TableId("guid") + protected String id; - protected Date gmtModified; + protected Date gmtCreate; - protected String creatorId; + protected Date gmtModified; - protected String modifierId; + protected String creatorId; + protected String modifierId; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/InstanceEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/InstanceEntity.java index f1649fbbc..354469970 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/InstanceEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/InstanceEntity.java @@ -20,12 +20,11 @@ package org.apache.geaflow.console.common.dal.entity; import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; @Getter @Setter @TableName("geaflow_instance") -public class InstanceEntity extends NameEntity { - -} +public class InstanceEntity extends NameEntity {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/JobEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/JobEntity.java index 093d2c368..25f41c21e 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/JobEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/JobEntity.java @@ -19,27 +19,29 @@ package org.apache.geaflow.console.common.dal.entity; +import org.apache.geaflow.console.common.util.type.GeaflowJobType; + import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowJobType; @Getter @Setter @TableName("geaflow_job") public class JobEntity extends NameEntity { - private String instanceId; + private String instanceId; - private GeaflowJobType type; + private GeaflowJobType type; - private String userCode; + private String userCode; - private String structMappings; + private String structMappings; - private String slaId; + private String slaId; - private String jarPackageId; + private String jarPackageId; - private String entryClass; + private String entryClass; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/JobResourceMappingEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/JobResourceMappingEntity.java index a1d15011b..058df5a9f 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/JobResourceMappingEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/JobResourceMappingEntity.java @@ -19,12 +19,14 @@ package org.apache.geaflow.console.common.dal.entity; +import org.apache.geaflow.console.common.util.type.GeaflowResourceType; + import com.baomidou.mybatisplus.annotation.TableName; + import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowResourceType; @Getter @Setter @@ -33,12 +35,11 @@ @TableName("geaflow_job_resource_mapping") public class JobResourceMappingEntity extends IdEntity { - private String jobId; - - private String resourceName; + private String jobId; - private GeaflowResourceType resourceType; + private String resourceName; - private String instanceId; + private GeaflowResourceType resourceType; + private String instanceId; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/LLMEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/LLMEntity.java index e304ea8f1..6ac70c902 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/LLMEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/LLMEntity.java @@ -19,19 +19,21 @@ package org.apache.geaflow.console.common.dal.entity; +import org.apache.geaflow.console.common.util.type.GeaflowLLMType; + import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowLLMType; @Getter @Setter @TableName("geaflow_llm") public class LLMEntity extends NameEntity { - private String url; + private String url; - private GeaflowLLMType type; + private GeaflowLLMType type; - private String args; + private String args; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/NameEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/NameEntity.java index e6069eeb8..dd505329d 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/NameEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/NameEntity.java @@ -26,7 +26,7 @@ @Setter public abstract class NameEntity extends IdEntity { - protected String name; + protected String name; - protected String comment; + protected String comment; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/PluginConfigEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/PluginConfigEntity.java index 785e2dff1..c6fdae1d3 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/PluginConfigEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/PluginConfigEntity.java @@ -19,20 +19,21 @@ package org.apache.geaflow.console.common.dal.entity; +import org.apache.geaflow.console.common.util.type.GeaflowPluginCategory; + import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowPluginCategory; @Getter @Setter @TableName("geaflow_plugin_config") public class PluginConfigEntity extends UserLevelEntity { - protected GeaflowPluginCategory category; - - protected String type; + protected GeaflowPluginCategory category; - protected String config; + protected String type; + protected String config; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/PluginEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/PluginEntity.java index 91d1bce5a..fdcd466f2 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/PluginEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/PluginEntity.java @@ -19,20 +19,21 @@ package org.apache.geaflow.console.common.dal.entity; +import org.apache.geaflow.console.common.util.type.GeaflowPluginCategory; + import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowPluginCategory; @Getter @Setter @TableName("geaflow_plugin") public class PluginEntity extends UserLevelEntity { - private String pluginType; - - private GeaflowPluginCategory pluginCategory; + private String pluginType; - private String jarPackageId; + private GeaflowPluginCategory pluginCategory; + private String jarPackageId; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/ReleaseEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/ReleaseEntity.java index 72d06c9a2..04cb9b3dd 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/ReleaseEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/ReleaseEntity.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.common.dal.entity; import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; @@ -28,21 +29,21 @@ @TableName("geaflow_release") public class ReleaseEntity extends IdEntity { - private String jobId; + private String jobId; - private int version; + private int version; - private String versionId; + private String versionId; - private String clusterId; + private String clusterId; - private String jobPlan; + private String jobPlan; - private String jobConfig; + private String jobConfig; - private String clusterConfig; + private String clusterConfig; - private String url; + private String url; - private String md5; + private String md5; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/RemoteFileEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/RemoteFileEntity.java index 138632748..eb84af5aa 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/RemoteFileEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/RemoteFileEntity.java @@ -19,22 +19,23 @@ package org.apache.geaflow.console.common.dal.entity; +import org.apache.geaflow.console.common.util.type.GeaflowFileType; + import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowFileType; @Getter @Setter @TableName("geaflow_remote_file") public class RemoteFileEntity extends UserLevelEntity { - private GeaflowFileType type; - - private String path; + private GeaflowFileType type; - private String md5; + private String path; - protected String url; + private String md5; + protected String url; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/ResourceCount.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/ResourceCount.java index 09614a315..60526f86f 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/ResourceCount.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/ResourceCount.java @@ -19,19 +19,18 @@ package org.apache.geaflow.console.common.dal.entity; +import org.apache.geaflow.console.common.util.type.GeaflowResourceType; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowResourceType; @Setter @Getter public class ResourceCount { - GeaflowResourceType type; - - String name; - - int count; + GeaflowResourceType type; + String name; + int count; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/StatementEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/StatementEntity.java index 77234233f..4874fd7e2 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/StatementEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/StatementEntity.java @@ -19,21 +19,23 @@ package org.apache.geaflow.console.common.dal.entity; +import org.apache.geaflow.console.common.util.type.GeaflowStatementStatus; + import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowStatementStatus; @Getter @Setter @TableName("geaflow_statement") public class StatementEntity extends IdEntity { - private String script; + private String script; - private String result; + private String result; - private GeaflowStatementStatus status; + private GeaflowStatementStatus status; - private String jobId; + private String jobId; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/SystemConfigEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/SystemConfigEntity.java index 00c84ab42..42f8c4c16 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/SystemConfigEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/SystemConfigEntity.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.common.dal.entity; import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; @@ -28,6 +29,5 @@ @TableName("geaflow_system_config") public class SystemConfigEntity extends NameEntity { - private String value; - + private String value; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/TableEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/TableEntity.java index 8322b71bc..eaa1d459d 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/TableEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/TableEntity.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.common.dal.entity; import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; @@ -28,6 +29,5 @@ @TableName("geaflow_table") public class TableEntity extends DataEntity { - private String pluginConfigId; - + private String pluginConfigId; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/TaskEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/TaskEntity.java index bbb1044e0..93a85aa80 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/TaskEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/TaskEntity.java @@ -19,41 +19,44 @@ package org.apache.geaflow.console.common.dal.entity; -import com.baomidou.mybatisplus.annotation.TableName; import java.util.Date; -import lombok.Getter; -import lombok.Setter; + import org.apache.geaflow.console.common.util.type.GeaflowTaskStatus; import org.apache.geaflow.console.common.util.type.GeaflowTaskType; +import com.baomidou.mybatisplus.annotation.TableName; + +import lombok.Getter; +import lombok.Setter; + @Getter @Setter @TableName("geaflow_task") public class TaskEntity extends IdEntity { - private String jobId; + private String jobId; - private String releaseId; + private String releaseId; - private String token; + private String token; - private Date startTime; + private Date startTime; - private Date endTime; + private Date endTime; - private GeaflowTaskType type; + private GeaflowTaskType type; - private GeaflowTaskStatus status; + private GeaflowTaskStatus status; - private String runtimeMetaConfigId; + private String runtimeMetaConfigId; - private String haMetaConfigId; + private String haMetaConfigId; - private String metricConfigId; + private String metricConfigId; - private String dataConfigId; + private String dataConfigId; - private String host; + private String host; - private String handle; + private String handle; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/TaskScheduleEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/TaskScheduleEntity.java index 1c80e8042..856433e4d 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/TaskScheduleEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/TaskScheduleEntity.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.common.dal.entity; import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; @@ -28,13 +29,13 @@ @TableName("geaflow_task") public class TaskScheduleEntity extends NameEntity { - private String taskId; + private String taskId; - private String approvalId; + private String approvalId; - private String grayId; + private String grayId; - private boolean failOver; + private boolean failOver; - private String scheduleCron; + private String scheduleCron; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/TenantEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/TenantEntity.java index 5c0002992..811ba0f8a 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/TenantEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/TenantEntity.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.common.dal.entity; import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; @@ -28,6 +29,5 @@ @TableName("geaflow_tenant") public class TenantEntity extends NameEntity { - private String quotaId; - + private String quotaId; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/TenantUserMappingEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/TenantUserMappingEntity.java index 96a81587f..45eef573d 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/TenantUserMappingEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/TenantUserMappingEntity.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.common.dal.entity; import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; @@ -28,8 +29,7 @@ @TableName("geaflow_tenant_user_mapping") public class TenantUserMappingEntity extends IdEntity { - private String userId; - - private boolean active; + private String userId; + private boolean active; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/UserEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/UserEntity.java index fe8d4f865..a9d0dd51c 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/UserEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/UserEntity.java @@ -19,8 +19,10 @@ package org.apache.geaflow.console.common.dal.entity; -import com.baomidou.mybatisplus.annotation.TableName; import java.util.Date; + +import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; @@ -29,15 +31,15 @@ @TableName("geaflow_user") public class UserEntity extends NameEntity { - private String email; + private String email; - private String phone; + private String phone; - private String passwordSign; + private String passwordSign; - private String sessionToken; + private String sessionToken; - private boolean systemSession; + private boolean systemSession; - private Date accessTime; + private Date accessTime; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/UserLevelEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/UserLevelEntity.java index d3475b172..973735419 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/UserLevelEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/UserLevelEntity.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.common.dal.entity; import com.baomidou.mybatisplus.annotation.TableField; + import lombok.Getter; import lombok.Setter; @@ -27,7 +28,6 @@ @Setter public abstract class UserLevelEntity extends NameEntity { - @TableField("`system`") - protected boolean system; - + @TableField("`system`") + protected boolean system; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/UserRoleMappingEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/UserRoleMappingEntity.java index c963ba3f8..052f9ab3c 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/UserRoleMappingEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/UserRoleMappingEntity.java @@ -19,17 +19,19 @@ package org.apache.geaflow.console.common.dal.entity; +import org.apache.geaflow.console.common.util.type.GeaflowRoleType; + import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowRoleType; @Getter @Setter @TableName("geaflow_user_role_mapping") public class UserRoleMappingEntity extends IdEntity { - private String userId; + private String userId; - private GeaflowRoleType roleType; + private GeaflowRoleType roleType; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/VersionEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/VersionEntity.java index 337af1b66..3d3c72e05 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/VersionEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/VersionEntity.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.common.dal.entity; import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; @@ -28,10 +29,9 @@ @TableName("geaflow_version") public class VersionEntity extends NameEntity { - private String engineJarId; - - private String langJarId; + private String engineJarId; - private boolean publish; + private String langJarId; + private boolean publish; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/VertexEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/VertexEntity.java index 94df06314..43b809cb7 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/VertexEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/VertexEntity.java @@ -20,12 +20,11 @@ package org.apache.geaflow.console.common.dal.entity; import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; @Getter @Setter @TableName("geaflow_vertex") -public class VertexEntity extends DataEntity { - -} +public class VertexEntity extends DataEntity {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/ViewEntity.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/ViewEntity.java index 46fbec6a5..c9b9d7354 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/ViewEntity.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/entity/ViewEntity.java @@ -19,18 +19,19 @@ package org.apache.geaflow.console.common.dal.entity; +import org.apache.geaflow.console.common.util.type.GeaflowViewCategory; + import com.baomidou.mybatisplus.annotation.TableName; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowViewCategory; @Getter @Setter @TableName("geaflow_view") public class ViewEntity extends DataEntity { - private GeaflowViewCategory category; - - private String code; + private GeaflowViewCategory category; + private String code; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/AuditMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/AuditMapper.java index fca7a724b..61d1c5239 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/AuditMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/AuditMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface AuditMapper extends GeaflowBaseMapper { - -} +public interface AuditMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/AuthorizationMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/AuthorizationMapper.java index c5d24e41d..18dd53e21 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/AuthorizationMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/AuthorizationMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface AuthorizationMapper extends GeaflowBaseMapper { - -} +public interface AuthorizationMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/ChatMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/ChatMapper.java index 6c897a69b..423edf52b 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/ChatMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/ChatMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface ChatMapper extends GeaflowBaseMapper { - -} +public interface ChatMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/ClusterMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/ClusterMapper.java index 50f034736..9c34bd4e4 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/ClusterMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/ClusterMapper.java @@ -19,12 +19,11 @@ package org.apache.geaflow.console.common.dal.mapper; -import com.baomidou.mybatisplus.annotation.InterceptorIgnore; import org.apache.geaflow.console.common.dal.entity.ClusterEntity; import org.apache.ibatis.annotations.Mapper; +import com.baomidou.mybatisplus.annotation.InterceptorIgnore; + @Mapper @InterceptorIgnore(tenantLine = "true") -public interface ClusterMapper extends GeaflowBaseMapper { - -} +public interface ClusterMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/DataEntityMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/DataEntityMapper.java index 8db0010cd..cb5bd87c7 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/DataEntityMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/DataEntityMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface DataEntityMapper extends GeaflowBaseMapper { - -} +public interface DataEntityMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/EdgeMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/EdgeMapper.java index 884dd63d3..ba3eca553 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/EdgeMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/EdgeMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface EdgeMapper extends GeaflowBaseMapper { - -} +public interface EdgeMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/EndpointMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/EndpointMapper.java index 03201c450..6a7ff2551 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/EndpointMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/EndpointMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface EndpointMapper extends GeaflowBaseMapper { - -} +public interface EndpointMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/FieldMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/FieldMapper.java index ce7db3ac7..6b4531eee 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/FieldMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/FieldMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface FieldMapper extends GeaflowBaseMapper { - -} +public interface FieldMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/FunctionMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/FunctionMapper.java index a1ba15ccd..77336c733 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/FunctionMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/FunctionMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface FunctionMapper extends GeaflowBaseMapper { - -} +public interface FunctionMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/GeaflowBaseMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/GeaflowBaseMapper.java index 59158f3db..9fcb53952 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/GeaflowBaseMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/GeaflowBaseMapper.java @@ -19,10 +19,9 @@ package org.apache.geaflow.console.common.dal.mapper; -import com.github.yulichang.base.MPJBaseMapper; import org.apache.ibatis.annotations.Mapper; -@Mapper -public interface GeaflowBaseMapper extends MPJBaseMapper { +import com.github.yulichang.base.MPJBaseMapper; -} +@Mapper +public interface GeaflowBaseMapper extends MPJBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/GraphMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/GraphMapper.java index 1990968ee..d51e7c9f6 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/GraphMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/GraphMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface GraphMapper extends GeaflowBaseMapper { - -} +public interface GraphMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/GraphStructMappingMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/GraphStructMappingMapper.java index deb7ce7f2..2d940c222 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/GraphStructMappingMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/GraphStructMappingMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface GraphStructMappingMapper extends GeaflowBaseMapper { - -} +public interface GraphStructMappingMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/IdEntityMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/IdEntityMapper.java index 794bc3983..6b96e2443 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/IdEntityMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/IdEntityMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface IdEntityMapper extends GeaflowBaseMapper { - -} +public interface IdEntityMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/InstanceMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/InstanceMapper.java index 82cf935a9..062841bf1 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/InstanceMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/InstanceMapper.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.common.dal.mapper; import java.util.List; + import org.apache.geaflow.console.common.dal.entity.InstanceEntity; import org.apache.geaflow.console.common.dal.entity.ResourceCount; import org.apache.ibatis.annotations.Mapper; @@ -29,22 +30,29 @@ @Mapper public interface InstanceMapper extends GeaflowBaseMapper { - @Select("") - List getResourceCount(@Param("instanceId") String instanceId, @Param("names") List names); + @Select( + "") + List getResourceCount( + @Param("instanceId") String instanceId, @Param("names") List names); } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/JobMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/JobMapper.java index 47c86d626..18a6d6f13 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/JobMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/JobMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface JobMapper extends GeaflowBaseMapper { - -} +public interface JobMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/JobResourceMappingMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/JobResourceMappingMapper.java index a96c8a13a..caf2d2dab 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/JobResourceMappingMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/JobResourceMappingMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface JobResourceMappingMapper extends GeaflowBaseMapper { - -} +public interface JobResourceMappingMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/LLMMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/LLMMapper.java index 85a3c38b5..89ece2b79 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/LLMMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/LLMMapper.java @@ -19,12 +19,11 @@ package org.apache.geaflow.console.common.dal.mapper; -import com.baomidou.mybatisplus.annotation.InterceptorIgnore; import org.apache.geaflow.console.common.dal.entity.LLMEntity; import org.apache.ibatis.annotations.Mapper; +import com.baomidou.mybatisplus.annotation.InterceptorIgnore; + @Mapper @InterceptorIgnore(tenantLine = "true") -public interface LLMMapper extends GeaflowBaseMapper { - -} +public interface LLMMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/NameEntityMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/NameEntityMapper.java index 458ca0a36..3e44dcb52 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/NameEntityMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/NameEntityMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface NameEntityMapper extends GeaflowBaseMapper { - -} +public interface NameEntityMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/PluginConfigMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/PluginConfigMapper.java index ec8a240a9..4c8fa4905 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/PluginConfigMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/PluginConfigMapper.java @@ -19,12 +19,11 @@ package org.apache.geaflow.console.common.dal.mapper; -import com.baomidou.mybatisplus.annotation.InterceptorIgnore; import org.apache.geaflow.console.common.dal.entity.PluginConfigEntity; import org.apache.ibatis.annotations.Mapper; +import com.baomidou.mybatisplus.annotation.InterceptorIgnore; + @Mapper @InterceptorIgnore(tenantLine = "true") -public interface PluginConfigMapper extends GeaflowBaseMapper { - -} +public interface PluginConfigMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/PluginMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/PluginMapper.java index 47677569b..32f5774c3 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/PluginMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/PluginMapper.java @@ -19,12 +19,11 @@ package org.apache.geaflow.console.common.dal.mapper; -import com.baomidou.mybatisplus.annotation.InterceptorIgnore; import org.apache.geaflow.console.common.dal.entity.PluginEntity; import org.apache.ibatis.annotations.Mapper; +import com.baomidou.mybatisplus.annotation.InterceptorIgnore; + @Mapper @InterceptorIgnore(tenantLine = "true") -public interface PluginMapper extends GeaflowBaseMapper { - -} +public interface PluginMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/ReleaseMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/ReleaseMapper.java index 5621f26f7..e82b511d6 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/ReleaseMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/ReleaseMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface ReleaseMapper extends GeaflowBaseMapper { - -} +public interface ReleaseMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/RemoteFileMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/RemoteFileMapper.java index afb9ce4f5..a2f5c1c57 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/RemoteFileMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/RemoteFileMapper.java @@ -19,12 +19,11 @@ package org.apache.geaflow.console.common.dal.mapper; -import com.baomidou.mybatisplus.annotation.InterceptorIgnore; import org.apache.geaflow.console.common.dal.entity.RemoteFileEntity; import org.apache.ibatis.annotations.Mapper; +import com.baomidou.mybatisplus.annotation.InterceptorIgnore; + @Mapper @InterceptorIgnore(tenantLine = "true") -public interface RemoteFileMapper extends GeaflowBaseMapper { - -} +public interface RemoteFileMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/StatementMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/StatementMapper.java index efb450c12..a0f69a581 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/StatementMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/StatementMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface StatementMapper extends GeaflowBaseMapper { - -} +public interface StatementMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/SystemConfigMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/SystemConfigMapper.java index 2938bc593..81b70c8a9 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/SystemConfigMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/SystemConfigMapper.java @@ -19,12 +19,11 @@ package org.apache.geaflow.console.common.dal.mapper; -import com.baomidou.mybatisplus.annotation.InterceptorIgnore; import org.apache.geaflow.console.common.dal.entity.SystemConfigEntity; import org.apache.ibatis.annotations.Mapper; +import com.baomidou.mybatisplus.annotation.InterceptorIgnore; + @Mapper @InterceptorIgnore(tenantLine = "true") -public interface SystemConfigMapper extends GeaflowBaseMapper { - -} +public interface SystemConfigMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/TableMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/TableMapper.java index 816db0ef7..f931bfe09 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/TableMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/TableMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface TableMapper extends GeaflowBaseMapper { - -} +public interface TableMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/TaskMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/TaskMapper.java index 4ab27b7d0..6037175a2 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/TaskMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/TaskMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface TaskMapper extends GeaflowBaseMapper { - -} +public interface TaskMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/TaskScheduleMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/TaskScheduleMapper.java index 329774bd0..a44179f23 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/TaskScheduleMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/TaskScheduleMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface TaskScheduleMapper extends GeaflowBaseMapper { - -} +public interface TaskScheduleMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/TenantMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/TenantMapper.java index a835dcd72..ad25076fe 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/TenantMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/TenantMapper.java @@ -19,12 +19,11 @@ package org.apache.geaflow.console.common.dal.mapper; -import com.baomidou.mybatisplus.annotation.InterceptorIgnore; import org.apache.geaflow.console.common.dal.entity.TenantEntity; import org.apache.ibatis.annotations.Mapper; +import com.baomidou.mybatisplus.annotation.InterceptorIgnore; + @Mapper @InterceptorIgnore(tenantLine = "true") -public interface TenantMapper extends GeaflowBaseMapper { - -} +public interface TenantMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/TenantUserMappingMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/TenantUserMappingMapper.java index 0f1d845dc..0dbf9f76f 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/TenantUserMappingMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/TenantUserMappingMapper.java @@ -19,12 +19,11 @@ package org.apache.geaflow.console.common.dal.mapper; -import com.baomidou.mybatisplus.annotation.InterceptorIgnore; import org.apache.geaflow.console.common.dal.entity.TenantUserMappingEntity; import org.apache.ibatis.annotations.Mapper; +import com.baomidou.mybatisplus.annotation.InterceptorIgnore; + @Mapper @InterceptorIgnore(tenantLine = "true") -public interface TenantUserMappingMapper extends GeaflowBaseMapper { - -} +public interface TenantUserMappingMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/UserLevelMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/UserLevelMapper.java index 8d4dd2fd4..bd1a4300d 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/UserLevelMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/UserLevelMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface UserLevelMapper extends GeaflowBaseMapper { - -} +public interface UserLevelMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/UserMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/UserMapper.java index 10235685b..ca798d0f0 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/UserMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/UserMapper.java @@ -19,12 +19,11 @@ package org.apache.geaflow.console.common.dal.mapper; -import com.baomidou.mybatisplus.annotation.InterceptorIgnore; import org.apache.geaflow.console.common.dal.entity.UserEntity; import org.apache.ibatis.annotations.Mapper; +import com.baomidou.mybatisplus.annotation.InterceptorIgnore; + @Mapper @InterceptorIgnore(tenantLine = "true") -public interface UserMapper extends GeaflowBaseMapper { - -} +public interface UserMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/UserRoleMappingMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/UserRoleMappingMapper.java index 5e2cfdc2b..9f804b325 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/UserRoleMappingMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/UserRoleMappingMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface UserRoleMappingMapper extends GeaflowBaseMapper { - -} +public interface UserRoleMappingMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/VersionMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/VersionMapper.java index 7d9b4e5a2..8713f35fb 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/VersionMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/VersionMapper.java @@ -19,12 +19,11 @@ package org.apache.geaflow.console.common.dal.mapper; -import com.baomidou.mybatisplus.annotation.InterceptorIgnore; import org.apache.geaflow.console.common.dal.entity.VersionEntity; import org.apache.ibatis.annotations.Mapper; +import com.baomidou.mybatisplus.annotation.InterceptorIgnore; + @Mapper @InterceptorIgnore(tenantLine = "true") -public interface VersionMapper extends GeaflowBaseMapper { - -} +public interface VersionMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/VertexMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/VertexMapper.java index d183a6c28..0c75b8ccf 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/VertexMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/VertexMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface VertexMapper extends GeaflowBaseMapper { - -} +public interface VertexMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/ViewMapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/ViewMapper.java index 82ffbfa6f..35bb1d7c3 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/ViewMapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/mapper/ViewMapper.java @@ -23,6 +23,4 @@ import org.apache.ibatis.annotations.Mapper; @Mapper -public interface ViewMapper extends GeaflowBaseMapper { - -} +public interface ViewMapper extends GeaflowBaseMapper {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/AuditSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/AuditSearch.java index 9b95e694e..19686c1cb 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/AuditSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/AuditSearch.java @@ -19,19 +19,19 @@ package org.apache.geaflow.console.common.dal.model; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.common.util.type.GeaflowOperationType; import org.apache.geaflow.console.common.util.type.GeaflowResourceType; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class AuditSearch extends IdSearch { - private String resourceId; - - private GeaflowResourceType resourceType; + private String resourceId; - private GeaflowOperationType operationType; + private GeaflowResourceType resourceType; + private GeaflowOperationType operationType; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/AuthorizationSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/AuthorizationSearch.java index aef5ec728..96832b34a 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/AuthorizationSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/AuthorizationSearch.java @@ -19,20 +19,21 @@ package org.apache.geaflow.console.common.dal.model; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.common.util.type.GeaflowAuthorityType; import org.apache.geaflow.console.common.util.type.GeaflowResourceType; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class AuthorizationSearch extends IdSearch { - private String userId; + private String userId; - private GeaflowAuthorityType authorityType; + private GeaflowAuthorityType authorityType; - private GeaflowResourceType resourceType; + private GeaflowResourceType resourceType; - private String resourceId; + private String resourceId; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/ChatSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/ChatSearch.java index 1380e5f38..1003ecc73 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/ChatSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/ChatSearch.java @@ -26,7 +26,7 @@ @Setter public class ChatSearch extends IdSearch { - String jobId; + String jobId; - String modelId; + String modelId; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/ClusterSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/ClusterSearch.java index e02339c64..f5ec16d98 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/ClusterSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/ClusterSearch.java @@ -24,6 +24,4 @@ @Getter @Setter -public class ClusterSearch extends NameSearch { - -} +public class ClusterSearch extends NameSearch {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/DataSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/DataSearch.java index 0a3e3c111..db95b8b9d 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/DataSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/DataSearch.java @@ -26,6 +26,5 @@ @Setter public class DataSearch extends NameSearch { - protected String instanceId; - + protected String instanceId; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/EdgeSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/EdgeSearch.java index 4e4664820..d3c70c645 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/EdgeSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/EdgeSearch.java @@ -24,6 +24,4 @@ @Getter @Setter -public class EdgeSearch extends DataSearch { - -} +public class EdgeSearch extends DataSearch {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/FieldSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/FieldSearch.java index 6bb508312..d3b8586b4 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/FieldSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/FieldSearch.java @@ -24,6 +24,4 @@ @Getter @Setter -public class FieldSearch extends NameSearch { - -} +public class FieldSearch extends NameSearch {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/FunctionSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/FunctionSearch.java index 112899ce2..3c18fe3ec 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/FunctionSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/FunctionSearch.java @@ -24,6 +24,4 @@ @Getter @Setter -public class FunctionSearch extends DataSearch { - -} +public class FunctionSearch extends DataSearch {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/GraphSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/GraphSearch.java index fc0b81c9b..56df2d67d 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/GraphSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/GraphSearch.java @@ -24,6 +24,4 @@ @Getter @Setter -public class GraphSearch extends DataSearch { - -} +public class GraphSearch extends DataSearch {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/IdDeletableSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/IdDeletableSearch.java index d11a11e92..d245fa3e7 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/IdDeletableSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/IdDeletableSearch.java @@ -26,6 +26,5 @@ @Setter public class IdDeletableSearch extends IdSearch { - protected Boolean isDeleted; - + protected Boolean isDeleted; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/IdSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/IdSearch.java index d4da99ea4..ef63742cd 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/IdSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/IdSearch.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.common.dal.model; import java.util.Date; + import lombok.Getter; import lombok.Setter; @@ -27,16 +28,15 @@ @Setter public class IdSearch extends PageSearch { - protected Date startCreateTime; - - protected Date endCreateTime; + protected Date startCreateTime; - protected Date startModifyTime; + protected Date endCreateTime; - protected Date endModifyTime; + protected Date startModifyTime; - protected String creatorId; + protected Date endModifyTime; - protected String modifierId; + protected String creatorId; + protected String modifierId; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/InstanceSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/InstanceSearch.java index 711967f7a..456f0fb16 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/InstanceSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/InstanceSearch.java @@ -19,7 +19,4 @@ package org.apache.geaflow.console.common.dal.model; - -public class InstanceSearch extends NameSearch { - -} +public class InstanceSearch extends NameSearch {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/JobSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/JobSearch.java index 071bb165a..6f39f47b2 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/JobSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/JobSearch.java @@ -19,13 +19,14 @@ package org.apache.geaflow.console.common.dal.model; +import org.apache.geaflow.console.common.util.type.GeaflowJobType; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowJobType; @Getter @Setter public class JobSearch extends DataSearch { - GeaflowJobType jobType; + GeaflowJobType jobType; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/JobSlaSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/JobSlaSearch.java index 4b05ac5bb..cc083e30f 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/JobSlaSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/JobSlaSearch.java @@ -24,6 +24,4 @@ @Getter @Setter -public class JobSlaSearch extends IdSearch { - -} +public class JobSlaSearch extends IdSearch {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/LLMSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/LLMSearch.java index e014e47ea..6ffb319ec 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/LLMSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/LLMSearch.java @@ -24,6 +24,4 @@ @Getter @Setter -public class LLMSearch extends NameSearch { - -} +public class LLMSearch extends NameSearch {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/NameSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/NameSearch.java index 1af473f81..b3631ddd9 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/NameSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/NameSearch.java @@ -26,8 +26,7 @@ @Setter public class NameSearch extends IdSearch { - protected String name; - - protected String comment; + protected String name; + protected String comment; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/PageList.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/PageList.java index 77128b408..50326aafc 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/PageList.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/PageList.java @@ -19,9 +19,11 @@ package org.apache.geaflow.console.common.dal.model; -import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import java.util.List; import java.util.function.Function; + +import com.baomidou.mybatisplus.extension.plugins.pagination.Page; + import lombok.AccessLevel; import lombok.AllArgsConstructor; import lombok.Getter; @@ -30,21 +32,20 @@ @AllArgsConstructor(access = AccessLevel.PRIVATE) public class PageList { - private final List list; - - private final long total; + private final List list; - public PageList(List data) { - this(data, data.size()); - } + private final long total; - public PageList(Page page) { - this(page.getRecords(), page.getTotal()); - } + public PageList(List data) { + this(data, data.size()); + } - public PageList transform(Function, List> converter) { - List newList = converter.apply(this.list); - return new PageList<>(newList, total); - } + public PageList(Page page) { + this(page.getRecords(), page.getTotal()); + } + public PageList transform(Function, List> converter) { + List newList = converter.apply(this.list); + return new PageList<>(newList, total); + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/PageSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/PageSearch.java index 042ff18c6..872eb72d0 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/PageSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/PageSearch.java @@ -26,24 +26,22 @@ @Getter public class PageSearch { - protected Integer page; + protected Integer page; - protected Integer size; + protected Integer size; - protected String sort; + protected String sort; - protected SortOrder order = SortOrder.ASC; + protected SortOrder order = SortOrder.ASC; - public enum SortOrder { + public enum SortOrder { + ASC, - ASC, + DESC + } - DESC - } - - public void setOrder(SortOrder order, String key) { - this.order = order; - this.sort = key; - - } + public void setOrder(SortOrder order, String key) { + this.order = order; + this.sort = key; + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/PluginConfigSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/PluginConfigSearch.java index 42d9a67d2..be985360f 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/PluginConfigSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/PluginConfigSearch.java @@ -19,17 +19,18 @@ package org.apache.geaflow.console.common.dal.model; +import org.apache.geaflow.console.common.util.type.GeaflowPluginCategory; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowPluginCategory; @Getter @Setter public class PluginConfigSearch extends NameSearch { - protected GeaflowPluginCategory category; + protected GeaflowPluginCategory category; - protected String type; + protected String type; - protected String userId; + protected String userId; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/PluginSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/PluginSearch.java index eb9b33621..39c9305d0 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/PluginSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/PluginSearch.java @@ -19,18 +19,19 @@ package org.apache.geaflow.console.common.dal.model; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.common.util.type.GeaflowPluginCategory; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class PluginSearch extends NameSearch { - private GeaflowPluginType pluginType; + private GeaflowPluginType pluginType; - private GeaflowPluginCategory pluginCategory; + private GeaflowPluginCategory pluginCategory; - private String keyword; + private String keyword; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/ReleaseSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/ReleaseSearch.java index 6e57a4051..f430b5bfe 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/ReleaseSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/ReleaseSearch.java @@ -26,10 +26,9 @@ @Setter public class ReleaseSearch extends IdSearch { - private String jobId; + private String jobId; - private String versionId; - - private String clusterId; + private String versionId; + private String clusterId; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/RemoteFileSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/RemoteFileSearch.java index 084edcbff..1a9fa5121 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/RemoteFileSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/RemoteFileSearch.java @@ -19,13 +19,14 @@ package org.apache.geaflow.console.common.dal.model; +import org.apache.geaflow.console.common.util.type.GeaflowFileType; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowFileType; @Getter @Setter public class RemoteFileSearch extends NameSearch { - GeaflowFileType type; + GeaflowFileType type; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/StatementSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/StatementSearch.java index ed1ce0c52..cb87f645e 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/StatementSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/StatementSearch.java @@ -19,15 +19,16 @@ package org.apache.geaflow.console.common.dal.model; +import org.apache.geaflow.console.common.util.type.GeaflowStatementStatus; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowStatementStatus; @Getter @Setter public class StatementSearch extends IdSearch { - GeaflowStatementStatus status; + GeaflowStatementStatus status; - String jobId; + String jobId; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/SystemConfigSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/SystemConfigSearch.java index 3a17325c1..98a43d7a7 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/SystemConfigSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/SystemConfigSearch.java @@ -26,10 +26,9 @@ @Setter public class SystemConfigSearch extends NameSearch { - private String tenantId; + private String tenantId; - private String key; - - private String value; + private String key; + private String value; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/TableSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/TableSearch.java index 792ecbe0c..032074b50 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/TableSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/TableSearch.java @@ -24,6 +24,4 @@ @Getter @Setter -public class TableSearch extends DataSearch { - -} +public class TableSearch extends DataSearch {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/TaskScheduleSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/TaskScheduleSearch.java index 5136a35c7..4dd610382 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/TaskScheduleSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/TaskScheduleSearch.java @@ -24,6 +24,4 @@ @Getter @Setter -public class TaskScheduleSearch extends NameSearch { - -} +public class TaskScheduleSearch extends NameSearch {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/TaskSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/TaskSearch.java index 1941d8165..a59837b6f 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/TaskSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/TaskSearch.java @@ -19,34 +19,34 @@ package org.apache.geaflow.console.common.dal.model; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.common.util.type.GeaflowJobType; import org.apache.geaflow.console.common.util.type.GeaflowResourceType; import org.apache.geaflow.console.common.util.type.GeaflowTaskStatus; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class TaskSearch extends NameSearch { - private String jobName; + private String jobName; - private String jobId; + private String jobId; - protected GeaflowTaskStatus status; + protected GeaflowTaskStatus status; - protected GeaflowJobType jobType; + protected GeaflowJobType jobType; - protected String host; + protected String host; - private String resourceName; + private String resourceName; - private GeaflowResourceType resourceType; + private GeaflowResourceType resourceType; - private String instanceId; + private String instanceId; - private String versionId; + private String versionId; - private String clusterId; + private String clusterId; } - diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/TenantSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/TenantSearch.java index edeec7539..52884aa61 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/TenantSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/TenantSearch.java @@ -24,6 +24,4 @@ @Getter @Setter -public class TenantSearch extends NameSearch { - -} +public class TenantSearch extends NameSearch {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/UserSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/UserSearch.java index 1f63444e6..eb4455ec5 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/UserSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/UserSearch.java @@ -26,8 +26,7 @@ @Setter public class UserSearch extends NameSearch { - private String email; - - private String phone; + private String email; + private String phone; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/VersionSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/VersionSearch.java index 8277ed25d..e1f9a7005 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/VersionSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/VersionSearch.java @@ -26,6 +26,5 @@ @Setter public class VersionSearch extends NameSearch { - private Boolean publish; - + private Boolean publish; } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/VertexSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/VertexSearch.java index 7f488f37e..fae0b5470 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/VertexSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/VertexSearch.java @@ -24,6 +24,4 @@ @Getter @Setter -public class VertexSearch extends DataSearch { - -} +public class VertexSearch extends DataSearch {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/ViewSearch.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/ViewSearch.java index fe0574861..308f62b68 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/ViewSearch.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/model/ViewSearch.java @@ -24,6 +24,4 @@ @Getter @Setter -public class ViewSearch extends DataSearch { - -} +public class ViewSearch extends DataSearch {} diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/wrapper/GeaflowLambdaQueryChainWrapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/wrapper/GeaflowLambdaQueryChainWrapper.java index b8594fae0..cd7d86b67 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/wrapper/GeaflowLambdaQueryChainWrapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/wrapper/GeaflowLambdaQueryChainWrapper.java @@ -19,15 +19,17 @@ package org.apache.geaflow.console.common.dal.wrapper; +import org.apache.geaflow.console.common.dal.entity.IdEntity; + import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.core.mapper.BaseMapper; import com.baomidou.mybatisplus.extension.conditions.query.LambdaQueryChainWrapper; -import org.apache.geaflow.console.common.dal.entity.IdEntity; public class GeaflowLambdaQueryChainWrapper extends LambdaQueryChainWrapper { - public GeaflowLambdaQueryChainWrapper(BaseMapper baseMapper, LambdaQueryWrapper queryWrapper) { - super(baseMapper); - super.wrapperChildren = queryWrapper; - } + public GeaflowLambdaQueryChainWrapper( + BaseMapper baseMapper, LambdaQueryWrapper queryWrapper) { + super(baseMapper); + super.wrapperChildren = queryWrapper; + } } diff --git a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/wrapper/GeaflowLambdaUpdateChainWrapper.java b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/wrapper/GeaflowLambdaUpdateChainWrapper.java index 7349ab652..b72552f5a 100644 --- a/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/wrapper/GeaflowLambdaUpdateChainWrapper.java +++ b/geaflow-console/app/common/dal/src/main/java/org/apache/geaflow/console/common/dal/wrapper/GeaflowLambdaUpdateChainWrapper.java @@ -19,26 +19,31 @@ package org.apache.geaflow.console.common.dal.wrapper; +import java.util.Date; + +import org.apache.geaflow.console.common.dal.entity.IdEntity; +import org.apache.geaflow.console.common.util.context.ContextHolder; + import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; import com.baomidou.mybatisplus.core.mapper.BaseMapper; import com.baomidou.mybatisplus.extension.conditions.update.LambdaUpdateChainWrapper; -import java.util.Date; + import lombok.Getter; -import org.apache.geaflow.console.common.dal.entity.IdEntity; -import org.apache.geaflow.console.common.util.context.ContextHolder; @Getter -public class GeaflowLambdaUpdateChainWrapper extends LambdaUpdateChainWrapper { - - public GeaflowLambdaUpdateChainWrapper(BaseMapper baseMapper, LambdaUpdateWrapper updateWrapper) { - super(baseMapper); - super.wrapperChildren = updateWrapper; - } - - @Override - public boolean update() { - this.set(E::getGmtModified, new Date()); - this.set(E::getModifierId, ContextHolder.get().getUserId()); - return super.update(); - } +public class GeaflowLambdaUpdateChainWrapper + extends LambdaUpdateChainWrapper { + + public GeaflowLambdaUpdateChainWrapper( + BaseMapper baseMapper, LambdaUpdateWrapper updateWrapper) { + super(baseMapper); + super.wrapperChildren = updateWrapper; + } + + @Override + public boolean update() { + this.set(E::getGmtModified, new Date()); + this.set(E::getModifierId, ContextHolder.get().getUserId()); + return super.update(); + } } diff --git a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/CompileContext.java b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/CompileContext.java index 7a36f91e6..4e5c119de 100644 --- a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/CompileContext.java +++ b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/CompileContext.java @@ -20,13 +20,13 @@ package org.apache.geaflow.console.common.service.integration.engine; import java.util.Map; + import org.apache.geaflow.console.common.util.proxy.ProxyClass; @ProxyClass("org.apache.geaflow.dsl.common.compile.CompileContext") public interface CompileContext { - void setConfig(Map config); - - void setParallelisms(Map parallelisms); + void setConfig(Map config); + void setParallelisms(Map parallelisms); } diff --git a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/CompileResult.java b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/CompileResult.java index 6313bdd29..9c970d663 100644 --- a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/CompileResult.java +++ b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/CompileResult.java @@ -20,18 +20,19 @@ package org.apache.geaflow.console.common.service.integration.engine; import java.util.Set; + import org.apache.geaflow.console.common.util.proxy.ProxyClass; @ProxyClass("org.apache.geaflow.dsl.common.compile.CompileResult") public interface CompileResult { - JsonPlan getPhysicPlan(); + JsonPlan getPhysicPlan(); - Set getSourceTables(); + Set getSourceTables(); - Set getTargetTables(); + Set getTargetTables(); - Set getSourceGraphs(); + Set getSourceGraphs(); - Set getTargetGraphs(); + Set getTargetGraphs(); } diff --git a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/Configuration.java b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/Configuration.java index 900c99df9..3ea7a5665 100644 --- a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/Configuration.java +++ b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/Configuration.java @@ -20,13 +20,13 @@ package org.apache.geaflow.console.common.service.integration.engine; import java.util.Map; + import org.apache.geaflow.console.common.util.proxy.ProxyClass; @ProxyClass("org.apache.geaflow.common.config.Configuration") public interface Configuration { - Map getConfigMap(); - - void putAll(Map map); + Map getConfigMap(); + void putAll(Map map); } diff --git a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/FsPath.java b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/FsPath.java index c7fa94294..280f73c0d 100644 --- a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/FsPath.java +++ b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/FsPath.java @@ -22,6 +22,4 @@ import org.apache.geaflow.console.common.util.proxy.ProxyClass; @ProxyClass("org.apache.hadoop.fs.Path") -public interface FsPath { - -} +public interface FsPath {} diff --git a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/FunctionInfo.java b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/FunctionInfo.java index 479baf6e8..1fe8d2eba 100644 --- a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/FunctionInfo.java +++ b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/FunctionInfo.java @@ -24,8 +24,7 @@ @ProxyClass("org.apache.geaflow.dsl.common.compile.FunctionInfo") public interface FunctionInfo { - String getInstanceName(); - - String getFunctionName(); + String getInstanceName(); + String getFunctionName(); } diff --git a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/GeaflowCompiler.java b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/GeaflowCompiler.java index 5024c89f0..a46092630 100644 --- a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/GeaflowCompiler.java +++ b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/GeaflowCompiler.java @@ -20,21 +20,21 @@ package org.apache.geaflow.console.common.service.integration.engine; import java.util.Set; + import org.apache.geaflow.console.common.util.proxy.ProxyClass; @ProxyClass("org.apache.geaflow.dsl.runtime.QueryClient") public interface GeaflowCompiler { - CompileResult compile(String script, CompileContext context); - - Set getUnResolvedFunctions(String script, CompileContext context); + CompileResult compile(String script, CompileContext context); - Set getDeclaredTablePlugins(String type, CompileContext context); + Set getUnResolvedFunctions(String script, CompileContext context); - Set getEnginePlugins(); + Set getDeclaredTablePlugins(String type, CompileContext context); - Set getUnResolvedTables(String script, CompileContext context); + Set getEnginePlugins(); - String formatOlapResult(String script, Object resultData, CompileContext context); + Set getUnResolvedTables(String script, CompileContext context); + String formatOlapResult(String script, Object resultData, CompileContext context); } diff --git a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/GraphInfo.java b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/GraphInfo.java index 1e60414a7..8f244ba24 100644 --- a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/GraphInfo.java +++ b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/GraphInfo.java @@ -24,7 +24,7 @@ @ProxyClass("org.apache.geaflow.dsl.common.compile.GraphInfo") public interface GraphInfo { - String getInstanceName(); + String getInstanceName(); - String getGraphName(); + String getGraphName(); } diff --git a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/IPersistentIO.java b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/IPersistentIO.java index 777439d43..6db39ad93 100644 --- a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/IPersistentIO.java +++ b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/IPersistentIO.java @@ -20,19 +20,19 @@ package org.apache.geaflow.console.common.service.integration.engine; import java.util.List; + import org.apache.geaflow.console.common.util.proxy.ProxyClass; @ProxyClass("org.apache.geaflow.file.IPersistentIO") public interface IPersistentIO { - List listFileName(FsPath path); - - long getFileCount(FsPath path); + List listFileName(FsPath path); - boolean exists(FsPath path); + long getFileCount(FsPath path); - boolean renameFile(FsPath from, FsPath to); + boolean exists(FsPath path); - boolean delete(FsPath path, boolean recursive); + boolean renameFile(FsPath from, FsPath to); + boolean delete(FsPath path, boolean recursive); } diff --git a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/JsonPlan.java b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/JsonPlan.java index 0fe3a77fa..a0aeb92cd 100644 --- a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/JsonPlan.java +++ b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/JsonPlan.java @@ -20,12 +20,11 @@ package org.apache.geaflow.console.common.service.integration.engine; import java.util.Map; + import org.apache.geaflow.console.common.util.proxy.ProxyClass; @ProxyClass("org.apache.geaflow.dsl.common.compile.JsonPlan") public interface JsonPlan { - Map getVertices(); - + Map getVertices(); } - diff --git a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/K8sJobClient.java b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/K8sJobClient.java index a56753a45..554cbcc80 100644 --- a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/K8sJobClient.java +++ b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/K8sJobClient.java @@ -24,10 +24,9 @@ @ProxyClass("org.apache.geaflow.cluster.k8s.client.KubernetesJobClient") public interface K8sJobClient { - void submitJob(); + void submitJob(); - void stopJob(); - - Object getMasterService(); + void stopJob(); + Object getMasterService(); } diff --git a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/PersistentIOBuilder.java b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/PersistentIOBuilder.java index 97dc01263..2986f4870 100644 --- a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/PersistentIOBuilder.java +++ b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/PersistentIOBuilder.java @@ -24,6 +24,5 @@ @ProxyClass("org.apache.geaflow.file.PersistentIOBuilder") public interface PersistentIOBuilder { - IPersistentIO build(Configuration userConfig); - + IPersistentIO build(Configuration userConfig); } diff --git a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/Predecessor.java b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/Predecessor.java index 1b3c87905..4a84cd0a3 100644 --- a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/Predecessor.java +++ b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/Predecessor.java @@ -24,9 +24,7 @@ @ProxyClass("org.apache.geaflow.dsl.common.compile.Predecessor") public interface Predecessor { - String getId(); - - String getPartitionType(); - + String getId(); + String getPartitionType(); } diff --git a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/TableInfo.java b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/TableInfo.java index 164c573c1..9a72d4cdf 100644 --- a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/TableInfo.java +++ b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/TableInfo.java @@ -24,7 +24,7 @@ @ProxyClass("org.apache.geaflow.dsl.common.compile.TableInfo") public interface TableInfo { - String getInstanceName(); + String getInstanceName(); - String getTableName(); + String getTableName(); } diff --git a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/Vertex.java b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/Vertex.java index 7703ba15f..201b13c9c 100644 --- a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/Vertex.java +++ b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/Vertex.java @@ -20,25 +20,25 @@ package org.apache.geaflow.console.common.service.integration.engine; import java.util.List; + import org.apache.geaflow.console.common.util.proxy.ProxyClass; @ProxyClass("org.apache.geaflow.dsl.common.compile.Vertex") public interface Vertex { - String getVertexType(); - - String getVertexMode(); + String getVertexType(); - String getId(); + String getVertexMode(); - int getParallelism(); + String getId(); - String getOperator(); + int getParallelism(); - String getOperatorName(); + String getOperator(); - List getParents(); + String getOperatorName(); - JsonPlan getInnerPlan(); + List getParents(); + JsonPlan getInnerPlan(); } diff --git a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/AnalyticsClient.java b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/AnalyticsClient.java index 6d1bee33a..251bc6068 100644 --- a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/AnalyticsClient.java +++ b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/AnalyticsClient.java @@ -24,5 +24,5 @@ @ProxyClass("org.apache.geaflow.analytics.service.client.AnalyticsClient") public interface AnalyticsClient { - QueryResults executeQuery(String queryScript); + QueryResults executeQuery(String queryScript); } diff --git a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/AnalyticsClientBuilder.java b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/AnalyticsClientBuilder.java index 5a017bff4..6b2df8f11 100644 --- a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/AnalyticsClientBuilder.java +++ b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/AnalyticsClientBuilder.java @@ -24,26 +24,25 @@ @ProxyClass("org.apache.geaflow.analytics.service.client.AnalyticsClientBuilder") public interface AnalyticsClientBuilder { - AnalyticsClientBuilder withHost(String host); + AnalyticsClientBuilder withHost(String host); - AnalyticsClientBuilder withPort(int port); + AnalyticsClientBuilder withPort(int port); - AnalyticsClientBuilder withInitChannelPools(boolean initChannelPools); + AnalyticsClientBuilder withInitChannelPools(boolean initChannelPools); - AnalyticsClientBuilder withNeedAuth(boolean needAuth); + AnalyticsClientBuilder withNeedAuth(boolean needAuth); - AnalyticsClientBuilder withConfiguration(Configuration configuration); + AnalyticsClientBuilder withConfiguration(Configuration configuration); - AnalyticsClientBuilder withUser(String user); + AnalyticsClientBuilder withUser(String user); - AnalyticsClientBuilder withAnalyticsZkNode(String zkBaseNode); + AnalyticsClientBuilder withAnalyticsZkNode(String zkBaseNode); - AnalyticsClientBuilder withAnalyticsZkQuorumServers(String zkQuorumServer); + AnalyticsClientBuilder withAnalyticsZkQuorumServers(String zkQuorumServer); - AnalyticsClientBuilder withTimeoutMs(int timeoutMs); + AnalyticsClientBuilder withTimeoutMs(int timeoutMs); - AnalyticsClientBuilder withRetryNum(int retryNum); - - AnalyticsClient build(); + AnalyticsClientBuilder withRetryNum(int retryNum); + AnalyticsClient build(); } diff --git a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/Configuration.java b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/Configuration.java index 85667da50..48d0cb1e6 100644 --- a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/Configuration.java +++ b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/Configuration.java @@ -20,12 +20,13 @@ package org.apache.geaflow.console.common.service.integration.engine.analytics; import java.util.Map; + import org.apache.geaflow.console.common.util.proxy.ProxyClass; @ProxyClass("org.apache.geaflow.common.config.Configuration") public interface Configuration { - void putAll(Map map); + void putAll(Map map); - void put(String key, String value); + void put(String key, String value); } diff --git a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/QueryError.java b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/QueryError.java index 5375bced9..bca170950 100644 --- a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/QueryError.java +++ b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/QueryError.java @@ -24,8 +24,7 @@ @ProxyClass("org.apache.geaflow.analytics.service.query.QueryError") public interface QueryError { - int getCode(); - - String getName(); + int getCode(); + String getName(); } diff --git a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/QueryResults.java b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/QueryResults.java index 4d12e5768..9d96ff16d 100644 --- a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/QueryResults.java +++ b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/QueryResults.java @@ -20,19 +20,19 @@ package org.apache.geaflow.console.common.service.integration.engine.analytics; import java.util.List; + import org.apache.geaflow.console.common.util.proxy.ProxyClass; @ProxyClass("org.apache.geaflow.analytics.service.query.QueryResults") public interface QueryResults { - String getQueryId(); - - List> getRawData(); + String getQueryId(); - QueryError getError(); + List> getRawData(); - boolean getQueryStatus(); + QueryError getError(); - String getFormattedData(); + boolean getQueryStatus(); + String getFormattedData(); } diff --git a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/ResponseResult.java b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/ResponseResult.java index 7499afdc9..690f31fbf 100644 --- a/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/ResponseResult.java +++ b/geaflow-console/app/common/service/integration/src/main/java/org/apache/geaflow/console/common/service/integration/engine/analytics/ResponseResult.java @@ -20,13 +20,13 @@ package org.apache.geaflow.console.common.service.integration.engine.analytics; import java.util.List; + import org.apache.geaflow.console.common.util.proxy.ProxyClass; @ProxyClass("org.apache.geaflow.cluster.response.ResponseResult") public interface ResponseResult { - public int getId(); - - public List getResponse(); + public int getId(); + public List getResponse(); } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/DateTimeUtil.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/DateTimeUtil.java index 404a193ed..60d57be40 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/DateTimeUtil.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/DateTimeUtil.java @@ -19,22 +19,23 @@ package org.apache.geaflow.console.common.util; -import com.google.common.base.Preconditions; import java.text.SimpleDateFormat; import java.util.Date; + import org.apache.commons.lang3.time.DateUtils; -public class DateTimeUtil { +import com.google.common.base.Preconditions; - public static final String DATE_TIME_FORMATTER = "yyyy-MM-dd HH:mm:ss"; +public class DateTimeUtil { - public static String format(Date date) { - return date == null ? null : new SimpleDateFormat(DATE_TIME_FORMATTER).format(date); - } + public static final String DATE_TIME_FORMATTER = "yyyy-MM-dd HH:mm:ss"; - public static boolean isExpired(Date date, int liveSeconds) { - Preconditions.checkNotNull(date, "Invalid date"); - return DateUtils.addSeconds(date, liveSeconds).before(new Date()); - } + public static String format(Date date) { + return date == null ? null : new SimpleDateFormat(DATE_TIME_FORMATTER).format(date); + } + public static boolean isExpired(Date date, int liveSeconds) { + Preconditions.checkNotNull(date, "Invalid date"); + return DateUtils.addSeconds(date, liveSeconds).before(new Date()); + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/FileUtil.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/FileUtil.java index fac47f0f4..b007d98ce 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/FileUtil.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/FileUtil.java @@ -19,87 +19,90 @@ package org.apache.geaflow.console.common.util; -import com.google.common.base.Preconditions; import java.io.File; import java.io.InputStream; import java.io.OutputStream; import java.nio.charset.StandardCharsets; + import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.util.exception.GeaflowException; +import com.google.common.base.Preconditions; + public class FileUtil { - public static boolean exist(String path) { - Preconditions.checkArgument(StringUtils.isNotBlank(path), "Invalid path"); - return new File(path).exists(); - } + public static boolean exist(String path) { + Preconditions.checkArgument(StringUtils.isNotBlank(path), "Invalid path"); + return new File(path).exists(); + } - public static void touch(String path) { - Preconditions.checkArgument(StringUtils.isNotBlank(path), "Invalid path"); - try { - FileUtils.touch(new File(path)); + public static void touch(String path) { + Preconditions.checkArgument(StringUtils.isNotBlank(path), "Invalid path"); + try { + FileUtils.touch(new File(path)); - } catch (Exception e) { - throw new GeaflowException("Create file {} failed", path, e); - } + } catch (Exception e) { + throw new GeaflowException("Create file {} failed", path, e); } + } - public static void mkdir(String path) { - Preconditions.checkArgument(StringUtils.isNotBlank(path), "Invalid path"); - try { - FileUtils.forceMkdir(new File(path)); + public static void mkdir(String path) { + Preconditions.checkArgument(StringUtils.isNotBlank(path), "Invalid path"); + try { + FileUtils.forceMkdir(new File(path)); - } catch (Exception e) { - throw new GeaflowException("Create directory {} failed", path, e); - } + } catch (Exception e) { + throw new GeaflowException("Create directory {} failed", path, e); } + } - public static boolean delete(String path) { - Preconditions.checkArgument(StringUtils.isNotBlank(path), "Invalid path"); - return FileUtils.deleteQuietly(new File(path)); - } + public static boolean delete(String path) { + Preconditions.checkArgument(StringUtils.isNotBlank(path), "Invalid path"); + return FileUtils.deleteQuietly(new File(path)); + } - public static String readFileContent(String path) { - Preconditions.checkArgument(StringUtils.isNotBlank(path), "Invalid path"); - try { - return FileUtils.readFileToString(new File(path), StandardCharsets.UTF_8.name()); + public static String readFileContent(String path) { + Preconditions.checkArgument(StringUtils.isNotBlank(path), "Invalid path"); + try { + return FileUtils.readFileToString(new File(path), StandardCharsets.UTF_8.name()); - } catch (Exception e) { - throw new GeaflowException("Read file {} failed", path, e); - } + } catch (Exception e) { + throw new GeaflowException("Read file {} failed", path, e); } + } - public static InputStream readFileStream(String path) { - Preconditions.checkArgument(StringUtils.isNotBlank(path), "Invalid path"); - try { - return FileUtils.openInputStream(new File(path)); + public static InputStream readFileStream(String path) { + Preconditions.checkArgument(StringUtils.isNotBlank(path), "Invalid path"); + try { + return FileUtils.openInputStream(new File(path)); - } catch (Exception e) { - throw new GeaflowException("Read file {} failed", path, e); - } + } catch (Exception e) { + throw new GeaflowException("Read file {} failed", path, e); } + } - public static void writeFile(String path, String content) { - Preconditions.checkArgument(StringUtils.isNotBlank(path), "Invalid path"); - try { - FileUtils.write(new File(path), content, StandardCharsets.UTF_8); + public static void writeFile(String path, String content) { + Preconditions.checkArgument(StringUtils.isNotBlank(path), "Invalid path"); + try { + FileUtils.write(new File(path), content, StandardCharsets.UTF_8); - } catch (Exception e) { - throw new GeaflowException("Write file {} failed", path, e); - } + } catch (Exception e) { + throw new GeaflowException("Write file {} failed", path, e); } + } - public static void writeFile(String path, InputStream stream) { - Preconditions.checkArgument(StringUtils.isNotBlank(path), "Invalid path"); - File file = new File(path); + public static void writeFile(String path, InputStream stream) { + Preconditions.checkArgument(StringUtils.isNotBlank(path), "Invalid path"); + File file = new File(path); - try (InputStream in = stream; OutputStream out = FileUtils.openOutputStream(file)) { - IOUtils.copy(in, out, 1024 * 1024 * 8); + try (InputStream in = stream; + OutputStream out = FileUtils.openOutputStream(file)) { + IOUtils.copy(in, out, 1024 * 1024 * 8); - } catch (Exception e) { - throw new GeaflowException("Write file {} failed", path, e); - } + } catch (Exception e) { + throw new GeaflowException("Write file {} failed", path, e); } + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/Fmt.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/Fmt.java index dd4ce09d5..4aef761a1 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/Fmt.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/Fmt.java @@ -23,8 +23,7 @@ public class Fmt { - public static String as(String fmt, Object... values) { - return MessageFormatter.arrayFormat(fmt, values).getMessage(); - } - + public static String as(String fmt, Object... values) { + return MessageFormatter.arrayFormat(fmt, values).getMessage(); + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/HTTPUtil.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/HTTPUtil.java index 46b890a6d..2c9c9afd9 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/HTTPUtil.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/HTTPUtil.java @@ -19,12 +19,19 @@ package org.apache.geaflow.console.common.util; -import com.alibaba.fastjson.JSONObject; import java.io.IOException; import java.io.InputStream; import java.util.Map; + import javax.servlet.ServletOutputStream; import javax.servlet.http.HttpServletResponse; + +import org.apache.commons.io.IOUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.alibaba.fastjson.JSONObject; + import okhttp3.Headers; import okhttp3.MediaType; import okhttp3.OkHttpClient; @@ -33,92 +40,87 @@ import okhttp3.RequestBody; import okhttp3.Response; import okhttp3.ResponseBody; -import org.apache.commons.io.IOUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class HTTPUtil { - private static final Logger LOGGER = LoggerFactory.getLogger(HTTPUtil.class); - private static final MediaType MEDIA_TYPE = MediaType.parse("application/json; charset=utf-8"); - - public static JSONObject post(String url, String json) { - return post(url, json, JSONObject.class); + private static final Logger LOGGER = LoggerFactory.getLogger(HTTPUtil.class); + private static final MediaType MEDIA_TYPE = MediaType.parse("application/json; charset=utf-8"); + + public static JSONObject post(String url, String json) { + return post(url, json, JSONObject.class); + } + + public static JSONObject post(String url, String json, Map headers) { + return post(url, json, headers, JSONObject.class); + } + + public static T post(String url, String json, Class resultClass) { + return post(url, json, null, resultClass); + } + + public static T post( + String url, String body, Map headers, Class resultClass) { + LOGGER.info("post url: {} body: {}", url, body); + RequestBody requestBody = RequestBody.create(body, MEDIA_TYPE); + Builder builder = getRequestBuilder(url, headers); + Request request = builder.post(requestBody).build(); + + OkHttpClient client = new OkHttpClient(); + try (Response response = client.newCall(request).execute()) { + ResponseBody responseBody = response.body(); + String msg = (responseBody != null) ? responseBody.string() : "{}"; + if (!response.isSuccessful()) { + throw new RuntimeException(msg); + } + + return JSONObject.toJavaObject(JSONObject.parseObject(msg), resultClass); + } catch (IOException e) { + LOGGER.info("execute post failed: {}", e.getCause(), e); + throw new RuntimeException(e); } - - public static JSONObject post(String url, String json, Map headers) { - return post(url, json, headers, JSONObject.class); - } - - public static T post(String url, String json, Class resultClass) { - return post(url, json, null, resultClass); - } - - public static T post(String url, String body, Map headers, Class resultClass) { - LOGGER.info("post url: {} body: {}", url, body); - RequestBody requestBody = RequestBody.create(body, MEDIA_TYPE); - Builder builder = getRequestBuilder(url, headers); - Request request = builder.post(requestBody).build(); - - OkHttpClient client = new OkHttpClient(); - try (Response response = client.newCall(request).execute()) { - ResponseBody responseBody = response.body(); - String msg = (responseBody != null) ? responseBody.string() : "{}"; - if (!response.isSuccessful()) { - throw new RuntimeException(msg); - } - - return JSONObject.toJavaObject(JSONObject.parseObject(msg), resultClass); - } catch (IOException e) { - LOGGER.info("execute post failed: {}", e.getCause(), e); - throw new RuntimeException(e); - } - } - - - public static JSONObject get(String url) { - return get(url, null, JSONObject.class); - } - - - public static T get(String url, Map headers, Class resultClass) { - LOGGER.info("get url: {}", url); - Builder builder = getRequestBuilder(url, headers); - Request request = builder.get().build(); - - - OkHttpClient client = new OkHttpClient(); - try (Response response = client.newCall(request).execute()) { - ResponseBody responseBody = response.body(); - String msg = (responseBody != null) ? responseBody.string() : "{}"; - if (!response.isSuccessful()) { - throw new RuntimeException(msg); - } - - return JSONObject.toJavaObject(JSONObject.parseObject(msg), resultClass); - } catch (IOException e) { - LOGGER.info("execute get failed: {}", e.getCause(), e); - throw new RuntimeException(e); - } + } + + public static JSONObject get(String url) { + return get(url, null, JSONObject.class); + } + + public static T get(String url, Map headers, Class resultClass) { + LOGGER.info("get url: {}", url); + Builder builder = getRequestBuilder(url, headers); + Request request = builder.get().build(); + + OkHttpClient client = new OkHttpClient(); + try (Response response = client.newCall(request).execute()) { + ResponseBody responseBody = response.body(); + String msg = (responseBody != null) ? responseBody.string() : "{}"; + if (!response.isSuccessful()) { + throw new RuntimeException(msg); + } + + return JSONObject.toJavaObject(JSONObject.parseObject(msg), resultClass); + } catch (IOException e) { + LOGGER.info("execute get failed: {}", e.getCause(), e); + throw new RuntimeException(e); } + } - private static Builder getRequestBuilder(String url, Map headers) { - Builder requestBuilder = new Request.Builder().url(url); - if (headers != null) { - Headers requestHeaders = Headers.of(headers); - requestBuilder.headers(requestHeaders); - } - return requestBuilder; + private static Builder getRequestBuilder(String url, Map headers) { + Builder requestBuilder = new Request.Builder().url(url); + if (headers != null) { + Headers requestHeaders = Headers.of(headers); + requestBuilder.headers(requestHeaders); } + return requestBuilder; + } - public static void download(HttpServletResponse response, InputStream inputStream, String fileName) - throws IOException { - response.setContentType("application/octet-stream"); - response.setHeader("Content-Disposition", "attachment; filename=" + fileName); + public static void download( + HttpServletResponse response, InputStream inputStream, String fileName) throws IOException { + response.setContentType("application/octet-stream"); + response.setHeader("Content-Disposition", "attachment; filename=" + fileName); - try (ServletOutputStream output = response.getOutputStream()) { - IOUtils.copy(inputStream, output, 1024 * 1024 * 8); - output.flush(); - } + try (ServletOutputStream output = response.getOutputStream()) { + IOUtils.copy(inputStream, output, 1024 * 1024 * 8); + output.flush(); } + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/I18nUtil.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/I18nUtil.java index f8de7d9b6..0e96a5ca5 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/I18nUtil.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/I18nUtil.java @@ -20,22 +20,23 @@ package org.apache.geaflow.console.common.util; import java.nio.charset.StandardCharsets; + import org.springframework.context.MessageSource; import org.springframework.context.i18n.LocaleContextHolder; import org.springframework.context.support.ResourceBundleMessageSource; public class I18nUtil { - private static final MessageSource MESSAGE_SOURCE = initMessageSource(); + private static final MessageSource MESSAGE_SOURCE = initMessageSource(); - private static MessageSource initMessageSource() { - ResourceBundleMessageSource messageSource = new ResourceBundleMessageSource(); - messageSource.setBasename("i18n/messages"); - messageSource.setDefaultEncoding(StandardCharsets.UTF_8.name()); - return messageSource; - } + private static MessageSource initMessageSource() { + ResourceBundleMessageSource messageSource = new ResourceBundleMessageSource(); + messageSource.setBasename("i18n/messages"); + messageSource.setDefaultEncoding(StandardCharsets.UTF_8.name()); + return messageSource; + } - public static String getMessage(String key, Object... params) { - return MESSAGE_SOURCE.getMessage(key, params, key, LocaleContextHolder.getLocale()); - } + public static String getMessage(String key, Object... params) { + return MESSAGE_SOURCE.getMessage(key, params, key, LocaleContextHolder.getLocale()); + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/ListUtil.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/ListUtil.java index 87a924d10..88e47ac79 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/ListUtil.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/ListUtil.java @@ -28,116 +28,117 @@ import java.util.Map.Entry; import java.util.function.Function; import java.util.stream.Collectors; + import org.apache.commons.collections.CollectionUtils; @SuppressWarnings("unchecked") public class ListUtil { - public static Collection join(Collection left, Collection right) { - if (left == null || right == null) { - return null; - } - - return (Collection) CollectionUtils.intersection(left, right); + public static Collection join(Collection left, Collection right) { + if (left == null || right == null) { + return null; } - public static Collection union(Collection left, Collection right) { - if (left == null) { - return right; - } + return (Collection) CollectionUtils.intersection(left, right); + } - if (right == null) { - return left; - } - - return (List) CollectionUtils.union(left, right); + public static Collection union(Collection left, Collection right) { + if (left == null) { + return right; } - public static Map toMap(Collection list, Function key) { - if (CollectionUtils.isEmpty(list)) { - return new HashMap<>(); - } - return list.stream().collect(Collectors.toMap(key, e -> e)); + if (right == null) { + return left; } - public static Map toMap(Collection list, Function key, Function value) { - if (CollectionUtils.isEmpty(list)) { - return new HashMap<>(); - } - return list.stream().collect(Collectors.toMap(key, value)); - } + return (List) CollectionUtils.union(left, right); + } - public static List convert(Collection list, Function function) { - if (CollectionUtils.isEmpty(list)) { - return new ArrayList<>(); - } - return list.stream().map(function).collect(Collectors.toList()); + public static Map toMap(Collection list, Function key) { + if (CollectionUtils.isEmpty(list)) { + return new HashMap<>(); } + return list.stream().collect(Collectors.toMap(key, e -> e)); + } - - public static List diff(List left, List right) { - return diff(left, right, null); + public static Map toMap( + Collection list, Function key, Function value) { + if (CollectionUtils.isEmpty(list)) { + return new HashMap<>(); } + return list.stream().collect(Collectors.toMap(key, value)); + } - public static List diff(List left, List right, Function function) { - return DiffHelper.diff(left, right, function); + public static List convert(Collection list, Function function) { + if (CollectionUtils.isEmpty(list)) { + return new ArrayList<>(); } - - private static class DiffHelper { - - private static List diff(List left, List right, Function function) { - if (left == null) { - return right; - } - - if (right == null) { - return left; - } - - ArrayList list = new ArrayList<>(); - Map mapLeft = getCardinalityMap(left, function); - Map mapRight = getCardinalityMap(right, function); - // duplicate removal for the left list - HashMap objMap = new HashMap<>(); - for (T t : left) { - objMap.putIfAbsent(getKey(t, function), t); - } - // calculate the diff of the number of the same key - for (Entry entry : objMap.entrySet()) { - Object key = entry.getKey(); - for (int i = 0, m = getFreq(key, mapLeft) - getFreq(key, mapRight); i < m; i++) { - list.add(entry.getValue()); - } - - } - - return list; + return list.stream().map(function).collect(Collectors.toList()); + } + + public static List diff(List left, List right) { + return diff(left, right, null); + } + + public static List diff(List left, List right, Function function) { + return DiffHelper.diff(left, right, function); + } + + private static class DiffHelper { + + private static List diff(List left, List right, Function function) { + if (left == null) { + return right; + } + + if (right == null) { + return left; + } + + ArrayList list = new ArrayList<>(); + Map mapLeft = getCardinalityMap(left, function); + Map mapRight = getCardinalityMap(right, function); + // duplicate removal for the left list + HashMap objMap = new HashMap<>(); + for (T t : left) { + objMap.putIfAbsent(getKey(t, function), t); + } + // calculate the diff of the number of the same key + for (Entry entry : objMap.entrySet()) { + Object key = entry.getKey(); + for (int i = 0, m = getFreq(key, mapLeft) - getFreq(key, mapRight); i < m; i++) { + list.add(entry.getValue()); } + } - private static Object getKey(T obj, Function function) { - return function == null ? obj : function.apply(obj); - } + return list; + } - private static Map getCardinalityMap(final Collection coll, Function function) { - Map count = new HashMap(); - for (Iterator it = coll.iterator(); it.hasNext(); ) { - Object key = getKey(it.next(), function); - Integer c = (count.get(key)); - if (c == null) { - count.put(key, 1); - } else { - count.put(key, c + 1); - } - } - return count; - } + private static Object getKey(T obj, Function function) { + return function == null ? obj : function.apply(obj); + } - private static int getFreq(Object key, final Map freqMap) { - Integer count = freqMap.get(key); - if (count != null) { - return count; - } - return 0; + private static Map getCardinalityMap( + final Collection coll, Function function) { + Map count = new HashMap(); + for (Iterator it = coll.iterator(); it.hasNext(); ) { + Object key = getKey(it.next(), function); + Integer c = (count.get(key)); + if (c == null) { + count.put(key, 1); + } else { + count.put(key, c + 1); } + } + return count; + } + + private static int getFreq(Object key, final Map freqMap) { + Integer count = freqMap.get(key); + if (count != null) { + return count; + } + return 0; } + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/LoaderSwitchUtil.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/LoaderSwitchUtil.java index 6e6255a15..b03647e88 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/LoaderSwitchUtil.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/LoaderSwitchUtil.java @@ -21,40 +21,39 @@ public class LoaderSwitchUtil { - public static void run(ClassLoader classLoader, Runnable runnable) throws Exception { - Thread currentThread = Thread.currentThread(); - ClassLoader oldClassLoader = currentThread.getContextClassLoader(); + public static void run(ClassLoader classLoader, Runnable runnable) throws Exception { + Thread currentThread = Thread.currentThread(); + ClassLoader oldClassLoader = currentThread.getContextClassLoader(); - try { - currentThread.setContextClassLoader(classLoader); - runnable.run(); + try { + currentThread.setContextClassLoader(classLoader); + runnable.run(); - } finally { - currentThread.setContextClassLoader(oldClassLoader); - } + } finally { + currentThread.setContextClassLoader(oldClassLoader); } + } - public static T call(ClassLoader classLoader, Callable callable) throws Exception { - Thread currentThread = Thread.currentThread(); - ClassLoader oldClassLoader = currentThread.getContextClassLoader(); + public static T call(ClassLoader classLoader, Callable callable) throws Exception { + Thread currentThread = Thread.currentThread(); + ClassLoader oldClassLoader = currentThread.getContextClassLoader(); - try { - currentThread.setContextClassLoader(classLoader); - return callable.call(); + try { + currentThread.setContextClassLoader(classLoader); + return callable.call(); - } finally { - currentThread.setContextClassLoader(oldClassLoader); - } + } finally { + currentThread.setContextClassLoader(oldClassLoader); } + } - public interface Runnable { + public interface Runnable { - void run() throws Exception; - } - - public interface Callable { + void run() throws Exception; + } - T call() throws Exception; - } + public interface Callable { + T call() throws Exception; + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/Md5Util.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/Md5Util.java index fbc19f64d..cd36c77da 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/Md5Util.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/Md5Util.java @@ -19,52 +19,54 @@ package org.apache.geaflow.console.common.util; -import com.google.common.base.Preconditions; import java.io.File; import java.io.IOException; import java.io.InputStream; import java.nio.file.Files; + import org.apache.commons.codec.digest.DigestUtils; import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.util.exception.GeaflowIllegalException; import org.springframework.core.io.InputStreamSource; +import com.google.common.base.Preconditions; + public class Md5Util { - public static String encodeString(String text) { - if (text == null) { - return null; - } - return DigestUtils.md5Hex(text); + public static String encodeString(String text) { + if (text == null) { + return null; } + return DigestUtils.md5Hex(text); + } - public static String encodeFile(String path) { - Preconditions.checkArgument(StringUtils.isNotBlank(path), "Invalid path"); - try (InputStream input = Files.newInputStream(new File(path).toPath())) { - return DigestUtils.md5Hex(input); + public static String encodeFile(String path) { + Preconditions.checkArgument(StringUtils.isNotBlank(path), "Invalid path"); + try (InputStream input = Files.newInputStream(new File(path).toPath())) { + return DigestUtils.md5Hex(input); - } catch (Exception e) { - throw new GeaflowIllegalException("Encode md5 of file {} failed", path, e); - } + } catch (Exception e) { + throw new GeaflowIllegalException("Encode md5 of file {} failed", path, e); } + } - public static String encodeFile(InputStreamSource stream) throws IOException { - Preconditions.checkNotNull(stream, "Invalid stream source"); - try (InputStream input = stream.getInputStream()) { - return DigestUtils.md5Hex(input); + public static String encodeFile(InputStreamSource stream) throws IOException { + Preconditions.checkNotNull(stream, "Invalid stream source"); + try (InputStream input = stream.getInputStream()) { + return DigestUtils.md5Hex(input); - } catch (Exception e) { - throw new GeaflowIllegalException("Encode stream source failed", e); - } + } catch (Exception e) { + throw new GeaflowIllegalException("Encode stream source failed", e); } + } - public static String encodeFile(InputStream stream) throws IOException { - Preconditions.checkNotNull(stream, "Invalid stream"); - try (InputStream input = stream) { - return DigestUtils.md5Hex(input); + public static String encodeFile(InputStream stream) throws IOException { + Preconditions.checkNotNull(stream, "Invalid stream"); + try (InputStream input = stream) { + return DigestUtils.md5Hex(input); - } catch (Exception e) { - throw new GeaflowIllegalException("Encode stream failed", e); - } + } catch (Exception e) { + throw new GeaflowIllegalException("Encode stream failed", e); } + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/NetworkUtil.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/NetworkUtil.java index c2c256ea7..e0844c7ed 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/NetworkUtil.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/NetworkUtil.java @@ -19,136 +19,139 @@ package org.apache.geaflow.console.common.util; -import com.google.common.base.Preconditions; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.Socket; import java.util.HashMap; import java.util.Map; + import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.util.exception.GeaflowException; import org.apache.geaflow.console.common.util.exception.GeaflowIllegalException; -public class NetworkUtil { - - public static final String LOCALHOST = "127.0.0.1"; +import com.google.common.base.Preconditions; - private static final Map DEFAULT_PORT_MAP = new HashMap<>(); +public class NetworkUtil { - private static String HOST_NAME; + public static final String LOCALHOST = "127.0.0.1"; - static { - DEFAULT_PORT_MAP.put("http", 80); - DEFAULT_PORT_MAP.put("https", 443); - DEFAULT_PORT_MAP.put("hdfs", 9000); - DEFAULT_PORT_MAP.put("dfs", 9000); - DEFAULT_PORT_MAP.put("jdbc:mysql", 3306); - } + private static final Map DEFAULT_PORT_MAP = new HashMap<>(); - public static String getHostName() { - if (HOST_NAME != null) { - return HOST_NAME; - } + private static String HOST_NAME; - try { - return HOST_NAME = InetAddress.getLocalHost().getHostName(); + static { + DEFAULT_PORT_MAP.put("http", 80); + DEFAULT_PORT_MAP.put("https", 443); + DEFAULT_PORT_MAP.put("hdfs", 9000); + DEFAULT_PORT_MAP.put("dfs", 9000); + DEFAULT_PORT_MAP.put("jdbc:mysql", 3306); + } - } catch (Exception e) { - throw new GeaflowException("Init local hostname failed", e); - } + public static String getHostName() { + if (HOST_NAME != null) { + return HOST_NAME; } - public static boolean isLocal(String url) { - return LOCALHOST.equals(getIp(getHost(url))); + try { + return HOST_NAME = InetAddress.getLocalHost().getHostName(); + + } catch (Exception e) { + throw new GeaflowException("Init local hostname failed", e); } + } - public static String getHost(String url) { - if (StringUtils.isBlank(url)) { - throw new GeaflowIllegalException("Invalid url"); - } + public static boolean isLocal(String url) { + return LOCALHOST.equals(getIp(getHost(url))); + } - if (url.contains("://")) { - url = StringUtils.substringAfter(url, "://"); - } + public static String getHost(String url) { + if (StringUtils.isBlank(url)) { + throw new GeaflowIllegalException("Invalid url"); + } - String[] seps = new String[]{":", "/", "?", "#"}; - for (String sep : seps) { - if (url.contains(sep)) { - url = StringUtils.substringBefore(url, sep); - } - } + if (url.contains("://")) { + url = StringUtils.substringAfter(url, "://"); + } - return url; + String[] seps = new String[] {":", "/", "?", "#"}; + for (String sep : seps) { + if (url.contains(sep)) { + url = StringUtils.substringBefore(url, sep); + } } - public static Integer getPort(String url) { - if (StringUtils.isBlank(url)) { - throw new GeaflowIllegalException("Invalid url"); - } + return url; + } - Integer port = null; - if (url.contains("://")) { - String schema = StringUtils.substringBefore(url, "://"); - port = NetworkUtil.getDefaultPort(schema); + public static Integer getPort(String url) { + if (StringUtils.isBlank(url)) { + throw new GeaflowIllegalException("Invalid url"); + } - url = StringUtils.substringAfter(url, "://"); - } + Integer port = null; + if (url.contains("://")) { + String schema = StringUtils.substringBefore(url, "://"); + port = NetworkUtil.getDefaultPort(schema); - String[] seps = new String[]{"/", "?", "#"}; - for (String sep : seps) { - if (url.contains(sep)) { - url = StringUtils.substringBefore(url, sep); - } - } + url = StringUtils.substringAfter(url, "://"); + } - if (url.contains(":")) { - port = Integer.parseInt(StringUtils.substringAfter(url, ":")); - } + String[] seps = new String[] {"/", "?", "#"}; + for (String sep : seps) { + if (url.contains(sep)) { + url = StringUtils.substringBefore(url, sep); + } + } - return port; + if (url.contains(":")) { + port = Integer.parseInt(StringUtils.substringAfter(url, ":")); } - public static String getIp(String hostname) { - try { - return InetAddress.getByName(hostname).getHostAddress(); + return port; + } - } catch (Exception e) { - throw new GeaflowIllegalException("Invalid hostname {}", hostname); - } - } + public static String getIp(String hostname) { + try { + return InetAddress.getByName(hostname).getHostAddress(); - public static Integer getDefaultPort(String schema) { - return DEFAULT_PORT_MAP.get(schema); + } catch (Exception e) { + throw new GeaflowIllegalException("Invalid hostname {}", hostname); } + } - public static void testUrls(String urls, String sep) { - String[] list = StringUtils.splitByWholeSeparator(urls, sep); - if (ArrayUtils.isEmpty(list)) { - throw new GeaflowIllegalException("Invalid urls {}", urls); - } + public static Integer getDefaultPort(String schema) { + return DEFAULT_PORT_MAP.get(schema); + } + + public static void testUrls(String urls, String sep) { + String[] list = StringUtils.splitByWholeSeparator(urls, sep); + if (ArrayUtils.isEmpty(list)) { + throw new GeaflowIllegalException("Invalid urls {}", urls); + } - for (String url : list) { - NetworkUtil.testUrl(url); - } + for (String url : list) { + NetworkUtil.testUrl(url); } + } - public static void testUrl(String url) { - String host = NetworkUtil.getHost(url); - Integer port = NetworkUtil.getPort(url); - if (port == null) { - throw new GeaflowIllegalException("Port is needed of url {}", url); - } - testHostPort(host, port); + public static void testUrl(String url) { + String host = NetworkUtil.getHost(url); + Integer port = NetworkUtil.getPort(url); + if (port == null) { + throw new GeaflowIllegalException("Port is needed of url {}", url); } + testHostPort(host, port); + } - public static void testHostPort(String host, int port) { - try (Socket socket = new Socket()) { - socket.connect(new InetSocketAddress(host, port), 3000); - Preconditions.checkArgument(socket.isConnected(), "Socket is not connected"); + public static void testHostPort(String host, int port) { + try (Socket socket = new Socket()) { + socket.connect(new InetSocketAddress(host, port), 3000); + Preconditions.checkArgument(socket.isConnected(), "Socket is not connected"); - } catch (Exception e) { - throw new GeaflowIllegalException("Connect to {}:{} failed, {}", host, port, e.getMessage(), e); - } + } catch (Exception e) { + throw new GeaflowIllegalException( + "Connect to {}:{} failed, {}", host, port, e.getMessage(), e); } + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/ProcessUtil.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/ProcessUtil.java index 433195f4e..d4170fc34 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/ProcessUtil.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/ProcessUtil.java @@ -19,7 +19,6 @@ package org.apache.geaflow.console.common.util; -import com.alibaba.fastjson.JSON; import java.io.ByteArrayOutputStream; import java.io.File; import java.io.FileOutputStream; @@ -31,9 +30,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; -import lombok.Getter; -import lombok.Setter; -import lombok.extern.slf4j.Slf4j; + import org.apache.commons.exec.CommandLine; import org.apache.commons.exec.DefaultExecuteResultHandler; import org.apache.commons.exec.DefaultExecutor; @@ -47,241 +44,258 @@ import org.apache.geaflow.console.common.util.exception.GeaflowLogException; import org.springframework.util.StreamUtils; +import com.alibaba.fastjson.JSON; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.slf4j.Slf4j; + @Slf4j public class ProcessUtil { - public static String execute(String cmd) { - return execute(cmd, -1); - } - - public static String execute(String cmd, long timeout) { - return execCommand(cmd, false, timeout); - } + public static String execute(String cmd) { + return execute(cmd, -1); + } - public static int executeAsync(String cmd) { - return executeAsync(cmd, -1); - } + public static String execute(String cmd, long timeout) { + return execCommand(cmd, false, timeout); + } - public static int executeAsync(String cmd, long timeout) { - String pid = execCommand(cmd, true, timeout); - return Integer.parseInt(pid); - } + public static int executeAsync(String cmd) { + return executeAsync(cmd, -1); + } - public static List search(String command, String include, String exclude) { - String output = execute(command); - String[] lines = StringUtils.split(output, "\n"); + public static int executeAsync(String cmd, long timeout) { + String pid = execCommand(cmd, true, timeout); + return Integer.parseInt(pid); + } - List results = new ArrayList<>(); - if (lines == null) { - return results; - } + public static List search(String command, String include, String exclude) { + String output = execute(command); + String[] lines = StringUtils.split(output, "\n"); - for (String line : lines) { - if (include != null && !line.contains(include)) { - continue; - } + List results = new ArrayList<>(); + if (lines == null) { + return results; + } - if (exclude != null && line.contains(exclude)) { - continue; - } + for (String line : lines) { + if (include != null && !line.contains(include)) { + continue; + } - results.add(StringUtils.split(line, " \t")); - } + if (exclude != null && line.contains(exclude)) { + continue; + } - return results; + results.add(StringUtils.split(line, " \t")); } - public static List searchPids(String keyword) { - List results = search("ps -a -x -o pid,command", keyword, null); - return results.stream().map(v -> v[0]).collect(Collectors.toList()); - } + return results; + } - public static boolean existPid(int pid) { - try { - List results = search(Fmt.as("ps -p {}", pid), null, "PID"); - if (results.size() > 1) { - throw new GeaflowException("To much process found of pid {}", pid); - } + public static List searchPids(String keyword) { + List results = search("ps -a -x -o pid,command", keyword, null); + return results.stream().map(v -> v[0]).collect(Collectors.toList()); + } - for (String[] result : results) { - return result[0].equals(String.valueOf(pid)); - } + public static boolean existPid(int pid) { + try { + List results = search(Fmt.as("ps -p {}", pid), null, "PID"); + if (results.size() > 1) { + throw new GeaflowException("To much process found of pid {}", pid); + } - return false; + for (String[] result : results) { + return result[0].equals(String.valueOf(pid)); + } - } catch (Exception e) { - log.error("Process {} not found, {}", pid, e.getMessage()); - return false; - } - } + return false; - public static void killPid(int pid) { - execute(Fmt.as("kill -9 {}", pid)); + } catch (Exception e) { + log.error("Process {} not found, {}", pid, e.getMessage()); + return false; } + } - public static int execAsyncCommand(CommandLine command, long waitTime, String logFile, String finishFile) { - ProcessExecutor executor = new ProcessExecutor(); + public static void killPid(int pid) { + execute(Fmt.as("kill -9 {}", pid)); + } - try { - FileOutputStream outputStream = new FileOutputStream(logFile); - executor.setStreamHandler(new PumpStreamHandler(outputStream, outputStream)); - AsyncExecuteResultHandler handler = new AsyncExecuteResultHandler(command.toString(), outputStream, - outputStream); - handler.setFinishFile(finishFile); + public static int execAsyncCommand( + CommandLine command, long waitTime, String logFile, String finishFile) { + ProcessExecutor executor = new ProcessExecutor(); - // execute command - executor.execute(command, handler); + try { + FileOutputStream outputStream = new FileOutputStream(logFile); + executor.setStreamHandler(new PumpStreamHandler(outputStream, outputStream)); + AsyncExecuteResultHandler handler = + new AsyncExecuteResultHandler(command.toString(), outputStream, outputStream); + handler.setFinishFile(finishFile); - // wait async process start at least 1s - Thread.sleep(Math.max(waitTime, 1000)); - return executor.getPid(); + // execute command + executor.execute(command, handler); - } catch (Exception e) { - throw new GeaflowLogException("Execute command `{}` failed", command.toString(), e); - } + // wait async process start at least 1s + Thread.sleep(Math.max(waitTime, 1000)); + return executor.getPid(); + + } catch (Exception e) { + throw new GeaflowLogException("Execute command `{}` failed", command.toString(), e); } + } - private static String execCommand(String cmd, boolean async, long timeout) { - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - ByteArrayOutputStream errorStream = new ByteArrayOutputStream(); + private static String execCommand(String cmd, boolean async, long timeout) { + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + ByteArrayOutputStream errorStream = new ByteArrayOutputStream(); - CommandLine command = CommandLine.parse(cmd); - ProcessExecutor executor = new ProcessExecutor(); + CommandLine command = CommandLine.parse(cmd); + ProcessExecutor executor = new ProcessExecutor(); - executor.setStreamHandler(new PumpStreamHandler(outputStream, errorStream)); - executor.setWatchdog(new ExecuteWatchdog(async ? -1 : timeout)); + executor.setStreamHandler(new PumpStreamHandler(outputStream, errorStream)); + executor.setWatchdog(new ExecuteWatchdog(async ? -1 : timeout)); - try { - if (async) { - executor.execute(command, new AsyncExecuteResultHandler(cmd, outputStream, errorStream)); + try { + if (async) { + executor.execute(command, new AsyncExecuteResultHandler(cmd, outputStream, errorStream)); - // wait async process start at least 1s - Thread.sleep(Math.max(timeout, 1000)); - return String.valueOf(executor.getPid()); - } + // wait async process start at least 1s + Thread.sleep(Math.max(timeout, 1000)); + return String.valueOf(executor.getPid()); + } - DefaultExecuteResultHandler handler = new DefaultExecuteResultHandler(); - executor.execute(command, handler); + DefaultExecuteResultHandler handler = new DefaultExecuteResultHandler(); + executor.execute(command, handler); - // wait sync process start at least 1s, at most 10s - handler.waitFor(Math.min(Math.max(timeout, 1000), 10000)); + // wait sync process start at least 1s, at most 10s + handler.waitFor(Math.min(Math.max(timeout, 1000), 10000)); - if (!handler.hasResult()) { - throw new GeaflowException("Command has no result"); - } + if (!handler.hasResult()) { + throw new GeaflowException("Command has no result"); + } - ExecuteException exception = handler.getException(); - if (exception != null) { - throw exception; - } + ExecuteException exception = handler.getException(); + if (exception != null) { + throw exception; + } - int exitValue = handler.getExitValue(); - if (exitValue != 0) { - throw new GeaflowException("Command exit({}) failed", exitValue); - } + int exitValue = handler.getExitValue(); + if (exitValue != 0) { + throw new GeaflowException("Command exit({}) failed", exitValue); + } - log.info("Execute command `{}` success", cmd); - return StreamUtils.copyToString(outputStream, StandardCharsets.UTF_8).trim(); + log.info("Execute command `{}` success", cmd); + return StreamUtils.copyToString(outputStream, StandardCharsets.UTF_8).trim(); - } catch (Exception e) { - String error = StreamUtils.copyToString(errorStream, StandardCharsets.UTF_8); - throw new GeaflowException("Execute command `{}` failed, error={}", cmd, JSON.toJSONString(error), e); - } + } catch (Exception e) { + String error = StreamUtils.copyToString(errorStream, StandardCharsets.UTF_8); + throw new GeaflowException( + "Execute command `{}` failed, error={}", cmd, JSON.toJSONString(error), e); } + } - private static class ProcessExecutor extends DefaultExecutor { + private static class ProcessExecutor extends DefaultExecutor { - @Getter - private int pid; + @Getter private int pid; - @Override - protected Process launch(CommandLine command, Map env, File dir) throws IOException { - Process process = super.launch(command, env, dir); + @Override + protected Process launch(CommandLine command, Map env, File dir) + throws IOException { + Process process = super.launch(command, env, dir); - try { - Field field = process.getClass().getDeclaredField("pid"); - field.setAccessible(true); - this.pid = (int) field.get(process); - log.debug("Start process `{}` with pid {}", command.toString(), pid); + try { + Field field = process.getClass().getDeclaredField("pid"); + field.setAccessible(true); + this.pid = (int) field.get(process); + log.debug("Start process `{}` with pid {}", command.toString(), pid); - } catch (Exception e) { - throw new GeaflowException("Get process pid failed", e); - } + } catch (Exception e) { + throw new GeaflowException("Get process pid failed", e); + } - return process; - } + return process; } + } - private static class AsyncExecuteResultHandler implements ExecuteResultHandler { + private static class AsyncExecuteResultHandler implements ExecuteResultHandler { - private final String command; + private final String command; - private final OutputStream outputStream; + private final OutputStream outputStream; - private final OutputStream errorStream; + private final OutputStream errorStream; - @Setter - private String finishFile; + @Setter private String finishFile; - public AsyncExecuteResultHandler(String command, OutputStream outputStream, OutputStream errorStream) { - this.command = command; - this.outputStream = outputStream; - this.errorStream = errorStream; - } + public AsyncExecuteResultHandler( + String command, OutputStream outputStream, OutputStream errorStream) { + this.command = command; + this.outputStream = outputStream; + this.errorStream = errorStream; + } - @Override - public void onProcessComplete(int exitValue) { - if (exitValue != 0) { - if (errorStream instanceof ByteArrayOutputStream) { - String error = StreamUtils.copyToString((ByteArrayOutputStream) errorStream, - StandardCharsets.UTF_8); - log.error("Execute async command `{}` exit({}) failed, error={}", command, exitValue, - JSON.toJSONString(error)); - - } else { - String msg = Fmt.as("Execute async command `{}` exit({}) failed\n", command, exitValue); - writeMessage(errorStream, msg); - } - - } else { - if (outputStream instanceof ByteArrayOutputStream) { - String output = StreamUtils.copyToString((ByteArrayOutputStream) outputStream, - StandardCharsets.UTF_8); - log.info("Execute async command `{}` success, output={}", command, JSON.toJSONString(output)); - - } else { - String msg = Fmt.as("Execute async command `{}` success\n", command); - writeMessage(outputStream, msg); - } - - if (finishFile != null) { - FileUtil.touch(finishFile); - } - } + @Override + public void onProcessComplete(int exitValue) { + if (exitValue != 0) { + if (errorStream instanceof ByteArrayOutputStream) { + String error = + StreamUtils.copyToString((ByteArrayOutputStream) errorStream, StandardCharsets.UTF_8); + log.error( + "Execute async command `{}` exit({}) failed, error={}", + command, + exitValue, + JSON.toJSONString(error)); + + } else { + String msg = Fmt.as("Execute async command `{}` exit({}) failed\n", command, exitValue); + writeMessage(errorStream, msg); } - @Override - public void onProcessFailed(ExecuteException e) { - if (errorStream instanceof ByteArrayOutputStream) { - String error = StreamUtils.copyToString((ByteArrayOutputStream) errorStream, StandardCharsets.UTF_8); - log.info("Execute async command `{}` failed, error={}", command, JSON.toJSONString(error), e); - - } else { - String msg = Fmt.as("Execute async command `{}` failed\n{}\n", command, - ExceptionUtils.getStackTrace(e)); - writeMessage(errorStream, msg); - } + } else { + if (outputStream instanceof ByteArrayOutputStream) { + String output = + StreamUtils.copyToString( + (ByteArrayOutputStream) outputStream, StandardCharsets.UTF_8); + log.info( + "Execute async command `{}` success, output={}", command, JSON.toJSONString(output)); + + } else { + String msg = Fmt.as("Execute async command `{}` success\n", command); + writeMessage(outputStream, msg); } - private void writeMessage(OutputStream stream, String message) { - try { - stream.write(message.getBytes()); - stream.flush(); - - } catch (Exception e) { - log.error("Write message '{}' to stream failed", message, e); - } + if (finishFile != null) { + FileUtil.touch(finishFile); } + } } + @Override + public void onProcessFailed(ExecuteException e) { + if (errorStream instanceof ByteArrayOutputStream) { + String error = + StreamUtils.copyToString((ByteArrayOutputStream) errorStream, StandardCharsets.UTF_8); + log.info( + "Execute async command `{}` failed, error={}", command, JSON.toJSONString(error), e); + + } else { + String msg = + Fmt.as( + "Execute async command `{}` failed\n{}\n", + command, + ExceptionUtils.getStackTrace(e)); + writeMessage(errorStream, msg); + } + } + + private void writeMessage(OutputStream stream, String message) { + try { + stream.write(message.getBytes()); + stream.flush(); + + } catch (Exception e) { + log.error("Write message '{}' to stream failed", message, e); + } + } + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/ReflectUtil.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/ReflectUtil.java index 913d79ae6..fdd149d2d 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/ReflectUtil.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/ReflectUtil.java @@ -27,27 +27,27 @@ public class ReflectUtil { - public static List getFields(Class clazz, Class root, Predicate filter) { - if (clazz == null) { - return null; - } - - if (root == null) { - return doGetFields(clazz, Object.class, filter); - } + public static List getFields( + Class clazz, Class root, Predicate filter) { + if (clazz == null) { + return null; + } - return doGetFields(clazz, root, filter); + if (root == null) { + return doGetFields(clazz, Object.class, filter); } - private static List doGetFields(Class clazz, Class root, Predicate filter) { - List fields = new ArrayList<>(); + return doGetFields(clazz, root, filter); + } - if (!root.equals(clazz)) { - fields.addAll(doGetFields(clazz.getSuperclass(), root, filter)); - } + private static List doGetFields(Class clazz, Class root, Predicate filter) { + List fields = new ArrayList<>(); - Arrays.stream(clazz.getDeclaredFields()).filter(filter).forEach(fields::add); - return fields; + if (!root.equals(clazz)) { + fields.addAll(doGetFields(clazz.getSuperclass(), root, filter)); } + Arrays.stream(clazz.getDeclaredFields()).filter(filter).forEach(fields::add); + return fields; + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/RetryUtil.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/RetryUtil.java index 7658ad679..12739f70d 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/RetryUtil.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/RetryUtil.java @@ -20,25 +20,26 @@ package org.apache.geaflow.console.common.util; import java.util.concurrent.Callable; + import org.apache.geaflow.console.common.util.exception.GeaflowException; public class RetryUtil { - public static T exec(Callable function, final int retryCount, long retryIntervalMs) { - int count = retryCount; - while (count > 0) { - try { - return function.call(); - - } catch (Exception e) { - if (--count == 0) { - throw new GeaflowException("exec failed withRetry", e); - } + public static T exec(Callable function, final int retryCount, long retryIntervalMs) { + int count = retryCount; + while (count > 0) { + try { + return function.call(); - ThreadUtil.sleepMilliSeconds(retryIntervalMs); - } + } catch (Exception e) { + if (--count == 0) { + throw new GeaflowException("exec failed withRetry", e); } - return null; + ThreadUtil.sleepMilliSeconds(retryIntervalMs); + } } + + return null; + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/SpringUtils.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/SpringUtils.java index bcdfadee0..c30b5d548 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/SpringUtils.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/SpringUtils.java @@ -19,32 +19,30 @@ package org.apache.geaflow.console.common.util; -import lombok.Getter; import org.springframework.beans.BeansException; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.stereotype.Component; +import lombok.Getter; + @Component("springUtils") public class SpringUtils implements ApplicationContextAware { - @Getter - private static ApplicationContext applicationContext; - - @Override - public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { - if (SpringUtils.applicationContext == null) { - SpringUtils.applicationContext = applicationContext; - } - } - - public static T getBean(Class clazz) { - return applicationContext.getBean(clazz); - } + @Getter private static ApplicationContext applicationContext; - public static Object getBean(String name) { - return applicationContext.getBean(name); + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + if (SpringUtils.applicationContext == null) { + SpringUtils.applicationContext = applicationContext; } + } + public static T getBean(Class clazz) { + return applicationContext.getBean(clazz); + } + public static Object getBean(String name) { + return applicationContext.getBean(name); + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/ThreadUtil.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/ThreadUtil.java index 53d08f2f2..7c3980410 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/ThreadUtil.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/ThreadUtil.java @@ -20,25 +20,25 @@ package org.apache.geaflow.console.common.util; import java.util.concurrent.TimeUnit; + import lombok.extern.slf4j.Slf4j; @Slf4j public class ThreadUtil { - public static void sleepSecond(long second) { - try { - TimeUnit.SECONDS.sleep(second); - } catch (InterruptedException e) { - log.error("sleep {} seconds interrupted", second); - } + public static void sleepSecond(long second) { + try { + TimeUnit.SECONDS.sleep(second); + } catch (InterruptedException e) { + log.error("sleep {} seconds interrupted", second); } + } - public static void sleepMilliSeconds(long ms) { - try { - TimeUnit.MILLISECONDS.sleep(ms); - } catch (InterruptedException e) { - log.error("sleep {} ms interrupted", ms); - } + public static void sleepMilliSeconds(long ms) { + try { + TimeUnit.MILLISECONDS.sleep(ms); + } catch (InterruptedException e) { + log.error("sleep {} ms interrupted", ms); } - + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/VelocityUtil.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/VelocityUtil.java index cbc92f24d..6c0513260 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/VelocityUtil.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/VelocityUtil.java @@ -26,6 +26,7 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; + import org.apache.commons.lang3.StringUtils; import org.apache.velocity.Template; import org.apache.velocity.VelocityContext; @@ -35,195 +36,208 @@ public class VelocityUtil { - private static VelocityEngine VELOCITY_ENGINE = null; + private static VelocityEngine VELOCITY_ENGINE = null; - private static final ConcurrentHashMap TEMPLATES = new ConcurrentHashMap<>(); + private static final ConcurrentHashMap TEMPLATES = new ConcurrentHashMap<>(); - private static final String END_FLAG = "END_FLAG"; + private static final String END_FLAG = "END_FLAG"; - private static void initEngineIfNeeded() { - if (VELOCITY_ENGINE == null) { - synchronized (VelocityUtil.class) { - try { + private static void initEngineIfNeeded() { + if (VELOCITY_ENGINE == null) { + synchronized (VelocityUtil.class) { + try { - VelocityEngine engine = new VelocityEngine(); - engine.setProperty(RuntimeConstants.RESOURCE_LOADER, "classpath"); - engine.setProperty("classpath.resource.loader.class", FormatResourceLoader.class.getName()); + VelocityEngine engine = new VelocityEngine(); + engine.setProperty(RuntimeConstants.RESOURCE_LOADER, "classpath"); + engine.setProperty( + "classpath.resource.loader.class", FormatResourceLoader.class.getName()); - engine.init(); - VELOCITY_ENGINE = engine; + engine.init(); + VELOCITY_ENGINE = engine; - } catch (Exception e) { - throw new RuntimeException("Init Velocity Engine Failed", e); - } - } + } catch (Exception e) { + throw new RuntimeException("Init Velocity Engine Failed", e); } + } } + } - public static String applyResource(String resourceName, Map params) { - if (StringUtils.isBlank(resourceName)) { - throw new RuntimeException("Invalid Resource Name"); - } + public static String applyResource(String resourceName, Map params) { + if (StringUtils.isBlank(resourceName)) { + throw new RuntimeException("Invalid Resource Name"); + } - if (params == null) { - params = new HashMap<>(); - } + if (params == null) { + params = new HashMap<>(); + } - params.put(END_FLAG, END_FLAG); - initEngineIfNeeded(); + params.put(END_FLAG, END_FLAG); + initEngineIfNeeded(); - try { - Template template = TEMPLATES.computeIfAbsent(resourceName, s -> { + try { + Template template = + TEMPLATES.computeIfAbsent( + resourceName, + s -> { try { - return VELOCITY_ENGINE.getTemplate(s, "utf-8"); + return VELOCITY_ENGINE.getTemplate(s, "utf-8"); } catch (Exception e) { - throw new RuntimeException(e); + throw new RuntimeException(e); } - }); + }); - VelocityContext context = new VelocityContext(); - params.forEach(context::put); + VelocityContext context = new VelocityContext(); + params.forEach(context::put); - StringWriter writer = new StringWriter(); - template.merge(context, writer); + StringWriter writer = new StringWriter(); + template.merge(context, writer); - return StringUtils.replacePattern(writer.toString(), ",\n*\\s*" + END_FLAG, ""); + return StringUtils.replacePattern(writer.toString(), ",\n*\\s*" + END_FLAG, ""); - } catch (Exception e) { - throw new RuntimeException(e); - } + } catch (Exception e) { + throw new RuntimeException(e); } - - public static class FormatResourceLoader extends ClasspathResourceLoader { - - public static class VTLIndentationGlobber extends FilterInputStream { - - protected String buffer = ""; - protected int bufpos = 0; - - protected enum State { - defstate, hash, comment, directive, schmoo, eol, eof - } - - protected State state = - State.defstate; - - public VTLIndentationGlobber(InputStream is) { - super(is); - } - - public int read() throws IOException { - while (true) { - switch (state) { - case defstate: { - int ch = in.read(); - switch (ch) { - case (int) '#': - state = State.hash; - buffer = ""; - bufpos = 0; - return ch; - case (int) ' ': - case (int) '\t': - buffer += (char) ' '; - break; - case -1: - state = State.eof; - break; - default: - buffer += (char) ch; - state = State.schmoo; - break; - } - break; - } - case eol: - if (bufpos < buffer.length()) { - return (int) buffer.charAt(bufpos++); - } else { - state = State.defstate; - buffer = ""; - bufpos = 0; - return '\n'; - } - case eof: - if (bufpos < buffer.length()) { - return (int) buffer.charAt(bufpos++); - } else { - return -1; - } - case hash: { - int ch = (int) in.read(); - switch (ch) { - case (int) '#': - state = State.directive; - return ch; - case -1: - state = State.eof; - return -1; - default: - state = State.directive; - buffer = "##"; - return ch; - } - } - case directive: { - int ch = (int) in.read(); - if (ch == (int) '\n') { - state = State.eol; - break; - } else if (ch == -1) { - state = State.eof; - break; - } else { - return ch; - } - } - case schmoo: { - int ch = (int) in.read(); - if (ch == (int) '\n') { - state = State.eol; - break; - } else if (ch == -1) { - state = State.eof; - break; - } else { - buffer += (char) ch; - return (int) buffer.charAt(bufpos++); - } - } - default: - break; - } + } + + public static class FormatResourceLoader extends ClasspathResourceLoader { + + public static class VTLIndentationGlobber extends FilterInputStream { + + protected String buffer = ""; + protected int bufpos = 0; + + protected enum State { + defstate, + hash, + comment, + directive, + schmoo, + eol, + eof + } + + protected State state = State.defstate; + + public VTLIndentationGlobber(InputStream is) { + super(is); + } + + public int read() throws IOException { + while (true) { + switch (state) { + case defstate: + { + int ch = in.read(); + switch (ch) { + case (int) '#': + state = State.hash; + buffer = ""; + bufpos = 0; + return ch; + case (int) ' ': + case (int) '\t': + buffer += (char) ' '; + break; + case -1: + state = State.eof; + break; + default: + buffer += (char) ch; + state = State.schmoo; + break; + } + break; + } + case eol: + if (bufpos < buffer.length()) { + return (int) buffer.charAt(bufpos++); + } else { + state = State.defstate; + buffer = ""; + bufpos = 0; + return '\n'; + } + case eof: + if (bufpos < buffer.length()) { + return (int) buffer.charAt(bufpos++); + } else { + return -1; + } + case hash: + { + int ch = (int) in.read(); + switch (ch) { + case (int) '#': + state = State.directive; + return ch; + case -1: + state = State.eof; + return -1; + default: + state = State.directive; + buffer = "##"; + return ch; + } + } + case directive: + { + int ch = (int) in.read(); + if (ch == (int) '\n') { + state = State.eol; + break; + } else if (ch == -1) { + state = State.eof; + break; + } else { + return ch; } - } - - public int read(byte[] b, int off, int len) throws IOException { - int i; - int ok = 0; - while (len-- > 0) { - i = read(); - if (i == -1) { - return (ok == 0) ? -1 : ok; - } - b[off++] = (byte) i; - ok++; + } + case schmoo: + { + int ch = (int) in.read(); + if (ch == (int) '\n') { + state = State.eol; + break; + } else if (ch == -1) { + state = State.eof; + break; + } else { + buffer += (char) ch; + return (int) buffer.charAt(bufpos++); } - return ok; - } + } + default: + break; + } + } + } + + public int read(byte[] b, int off, int len) throws IOException { + int i; + int ok = 0; + while (len-- > 0) { + i = read(); + if (i == -1) { + return (ok == 0) ? -1 : ok; + } + b[off++] = (byte) i; + ok++; + } + return ok; + } - public int read(byte[] b) throws IOException { - return read(b, 0, b.length); - } + public int read(byte[] b) throws IOException { + return read(b, 0, b.length); + } - public boolean markSupported() { - return false; - } - } + public boolean markSupported() { + return false; + } + } - @Override - public synchronized InputStream getResourceStream(String name) { - return new VTLIndentationGlobber(super.getResourceStream(name)); - } + @Override + public synchronized InputStream getResourceStream(String name) { + return new VTLIndentationGlobber(super.getResourceStream(name)); } + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/ZipUtil.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/ZipUtil.java index 8415be1bf..b8f1d8fa6 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/ZipUtil.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/ZipUtil.java @@ -34,124 +34,126 @@ import java.util.zip.ZipEntry; import java.util.zip.ZipInputStream; import java.util.zip.ZipOutputStream; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.console.common.util.exception.GeaflowException; public class ZipUtil { - public static void zipToFile(String fileName, List entries) throws IOException { - try (InputStream zipInputStream = buildZipInputStream(entries)) { - FileUtils.copyInputStreamToFile(zipInputStream, new File(fileName)); - } + public static void zipToFile(String fileName, List entries) throws IOException { + try (InputStream zipInputStream = buildZipInputStream(entries)) { + FileUtils.copyInputStreamToFile(zipInputStream, new File(fileName)); } + } - public static InputStream buildZipInputStream(GeaflowZipEntry entry) throws IOException { - return buildZipInputStream(Collections.singletonList(entry)); - } + public static InputStream buildZipInputStream(GeaflowZipEntry entry) throws IOException { + return buildZipInputStream(Collections.singletonList(entry)); + } - public static InputStream buildZipInputStream(List entries) throws IOException { - try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) { - writeZipStream(bos, entries); - byte[] zipBytes = bos.toByteArray(); - return new ByteArrayInputStream(zipBytes); - } + public static InputStream buildZipInputStream(List entries) throws IOException { + try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) { + writeZipStream(bos, entries); + byte[] zipBytes = bos.toByteArray(); + return new ByteArrayInputStream(zipBytes); } - - private static void writeZipStream(OutputStream bos, List entries) throws IOException { - try (ZipOutputStream zipOutputStream = new ZipOutputStream(bos)) { - for (GeaflowZipEntry entry : entries) { - try (InputStream inputStream = entry.getInputStream()) { - zipOutputStream.putNextEntry(new ZipEntry(entry.getEntryName())); - byte[] buff = new byte[1024]; - int len = 0; - while ((len = inputStream.read(buff)) > -1) { - zipOutputStream.write(buff, 0, len); - } - zipOutputStream.closeEntry(); - } - } - zipOutputStream.flush(); + } + + private static void writeZipStream(OutputStream bos, List entries) + throws IOException { + try (ZipOutputStream zipOutputStream = new ZipOutputStream(bos)) { + for (GeaflowZipEntry entry : entries) { + try (InputStream inputStream = entry.getInputStream()) { + zipOutputStream.putNextEntry(new ZipEntry(entry.getEntryName())); + byte[] buff = new byte[1024]; + int len = 0; + while ((len = inputStream.read(buff)) > -1) { + zipOutputStream.write(buff, 0, len); + } + zipOutputStream.closeEntry(); } + } + zipOutputStream.flush(); } - - public static void unzip(File file) { - String dir = file.getParent(); - try (ZipInputStream zipInputStream = new ZipInputStream(Files.newInputStream(file.toPath()))) { - ZipEntry entry; - while ((entry = zipInputStream.getNextEntry()) != null) { - String filePath = dir + "/" + entry.getName(); - File outFile = new File(filePath); - if (entry.isDirectory()) { - if (!outFile.exists()) { - outFile.mkdirs(); - } - continue; - } - try (FileOutputStream fileOutputStream = new FileOutputStream(filePath)) { - byte[] buf = new byte[1024 * 1024]; - int num; - while ((num = zipInputStream.read(buf, 0, buf.length)) > -1) { - fileOutputStream.write(buf, 0, num); - } - fileOutputStream.flush(); - } - zipInputStream.closeEntry(); - } - } catch (IOException e) { - throw new GeaflowException("Unzip file {} failed", file.getPath(), e); + } + + public static void unzip(File file) { + String dir = file.getParent(); + try (ZipInputStream zipInputStream = new ZipInputStream(Files.newInputStream(file.toPath()))) { + ZipEntry entry; + while ((entry = zipInputStream.getNextEntry()) != null) { + String filePath = dir + "/" + entry.getName(); + File outFile = new File(filePath); + if (entry.isDirectory()) { + if (!outFile.exists()) { + outFile.mkdirs(); + } + continue; + } + try (FileOutputStream fileOutputStream = new FileOutputStream(filePath)) { + byte[] buf = new byte[1024 * 1024]; + int num; + while ((num = zipInputStream.read(buf, 0, buf.length)) > -1) { + fileOutputStream.write(buf, 0, num); + } + fileOutputStream.flush(); } + zipInputStream.closeEntry(); + } + } catch (IOException e) { + throw new GeaflowException("Unzip file {} failed", file.getPath(), e); } + } - public interface GeaflowZipEntry { + public interface GeaflowZipEntry { - String getEntryName(); + String getEntryName(); - InputStream getInputStream(); - } + InputStream getInputStream(); + } - public static class MemoryZipEntry implements GeaflowZipEntry { + public static class MemoryZipEntry implements GeaflowZipEntry { - private String entryName; - private String content; + private String entryName; + private String content; - public MemoryZipEntry(String entryName, String content) { - this.entryName = entryName; - this.content = content; - } + public MemoryZipEntry(String entryName, String content) { + this.entryName = entryName; + this.content = content; + } - @Override - public String getEntryName() { - return entryName; - } + @Override + public String getEntryName() { + return entryName; + } - @Override - public InputStream getInputStream() { - return new ByteArrayInputStream(content.getBytes()); - } + @Override + public InputStream getInputStream() { + return new ByteArrayInputStream(content.getBytes()); } + } - public static class FileZipEntry implements GeaflowZipEntry { + public static class FileZipEntry implements GeaflowZipEntry { - private String entryName; - private File file; + private String entryName; + private File file; - public FileZipEntry(String entryName, File file) { - this.entryName = entryName; - this.file = file; - } + public FileZipEntry(String entryName, File file) { + this.entryName = entryName; + this.file = file; + } - @Override - public String getEntryName() { - return entryName; - } + @Override + public String getEntryName() { + return entryName; + } - @Override - public InputStream getInputStream() { - try { - return new FileInputStream(file); - } catch (FileNotFoundException e) { - throw new RuntimeException(e); - } - } + @Override + public InputStream getInputStream() { + try { + return new FileInputStream(file); + } catch (FileNotFoundException e) { + throw new RuntimeException(e); + } } + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/context/ContextHolder.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/context/ContextHolder.java index 9e6952847..eadad2869 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/context/ContextHolder.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/context/ContextHolder.java @@ -21,18 +21,17 @@ public class ContextHolder { - private static final ThreadLocal HOLDER = new ThreadLocal<>(); + private static final ThreadLocal HOLDER = new ThreadLocal<>(); - public static void init() { - HOLDER.set(new GeaflowContext()); - } + public static void init() { + HOLDER.set(new GeaflowContext()); + } - public static GeaflowContext get() { - return HOLDER.get(); - } - - public static void destroy() { - HOLDER.remove(); - } + public static GeaflowContext get() { + return HOLDER.get(); + } + public static void destroy() { + HOLDER.remove(); + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/context/GeaflowContext.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/context/GeaflowContext.java index 31b500a0f..0c70689ab 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/context/GeaflowContext.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/context/GeaflowContext.java @@ -21,29 +21,31 @@ import java.util.LinkedHashSet; import java.util.Set; + import javax.servlet.http.HttpServletRequest; + +import org.apache.geaflow.console.common.util.type.GeaflowRoleType; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowRoleType; @Getter @Setter public class GeaflowContext { - public static final String API_PREFIX = "/api"; - - private HttpServletRequest request; + public static final String API_PREFIX = "/api"; - private String taskId; + private HttpServletRequest request; - private String userId; + private String taskId; - private boolean systemSession; + private String userId; - private String tenantId; + private boolean systemSession; - private String sessionToken; + private String tenantId; - private Set roleTypes = new LinkedHashSet<>(); + private String sessionToken; + private Set roleTypes = new LinkedHashSet<>(); } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowCompileException.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowCompileException.java index 54c684b59..cb30a3da8 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowCompileException.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowCompileException.java @@ -21,44 +21,45 @@ import java.util.ArrayList; import java.util.List; -import lombok.AllArgsConstructor; -import lombok.Getter; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.util.Fmt; -public class GeaflowCompileException extends GeaflowException { +import lombok.AllArgsConstructor; +import lombok.Getter; - public GeaflowCompileException(String fmt, Object... args) { - super(fmt, args); - } +public class GeaflowCompileException extends GeaflowException { - public String getDisplayMessage() { - List causes = new ArrayList<>(); - for (Throwable error = getCause(); error != null; error = error.getCause()) { - String clazz = error.getClass().getSimpleName(); - String message = StringUtils.substringBefore(error.getMessage(), "\n"); - causes.add(new CauseInfo(clazz, message)); - } + public GeaflowCompileException(String fmt, Object... args) { + super(fmt, args); + } - StringBuilder sb = new StringBuilder(); - sb.append(getMessage()); - sb.append("\nCaused by:"); - for (int i = 0; i < causes.size(); i++) { - CauseInfo causeInfo = causes.get(i); - String align = StringUtils.repeat(">", i + 1); - sb.append(Fmt.as("\n{} [{}]: {}", align, causeInfo.getClassName(), causeInfo.getMessage())); - } + public String getDisplayMessage() { + List causes = new ArrayList<>(); + for (Throwable error = getCause(); error != null; error = error.getCause()) { + String clazz = error.getClass().getSimpleName(); + String message = StringUtils.substringBefore(error.getMessage(), "\n"); + causes.add(new CauseInfo(clazz, message)); + } - return sb.toString(); + StringBuilder sb = new StringBuilder(); + sb.append(getMessage()); + sb.append("\nCaused by:"); + for (int i = 0; i < causes.size(); i++) { + CauseInfo causeInfo = causes.get(i); + String align = StringUtils.repeat(">", i + 1); + sb.append(Fmt.as("\n{} [{}]: {}", align, causeInfo.getClassName(), causeInfo.getMessage())); } - @Getter - @AllArgsConstructor - private static class CauseInfo { + return sb.toString(); + } - private String className; + @Getter + @AllArgsConstructor + private static class CauseInfo { - private String message; + private String className; - } + private String message; + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowException.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowException.java index 726086b00..80826623b 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowException.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowException.java @@ -23,12 +23,11 @@ public class GeaflowException extends RuntimeException { - public GeaflowException(String fmt, Object... args) { - super(MessageFormatter.arrayFormat(fmt, args).getMessage()); + public GeaflowException(String fmt, Object... args) { + super(MessageFormatter.arrayFormat(fmt, args).getMessage()); - if (args != null && args.length > 0 && args[args.length - 1] instanceof Throwable) { - this.initCause((Throwable) args[args.length - 1]); - } + if (args != null && args.length > 0 && args[args.length - 1] instanceof Throwable) { + this.initCause((Throwable) args[args.length - 1]); } - + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowExceptionClassifier.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowExceptionClassifier.java index 312f4b7dd..b5a96bf80 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowExceptionClassifier.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowExceptionClassifier.java @@ -24,63 +24,63 @@ public class GeaflowExceptionClassifier { - public GeaflowExceptionClassificationResult classify(Throwable error) { - GeaflowApiResponseCode code; - String message = error.getMessage(); - - if (StringUtils.isBlank(message)) { - message = error.getClass().getSimpleName(); - } - - if (error instanceof GeaflowSecurityException) { - code = GeaflowApiResponseCode.FORBIDDEN; - - } else if (error instanceof GeaflowIllegalException) { - code = GeaflowApiResponseCode.ILLEGAL; - - } else if (error instanceof GeaflowCompileException) { - code = GeaflowApiResponseCode.ERROR; - message = ((GeaflowCompileException) error).getDisplayMessage(); - - } else if (error instanceof GeaflowException) { - code = GeaflowApiResponseCode.ERROR; - - } else if (error instanceof IllegalArgumentException) { - code = GeaflowApiResponseCode.ILLEGAL; - - } else if (error instanceof NullPointerException) { - code = GeaflowApiResponseCode.ERROR; - - } else { - code = GeaflowApiResponseCode.FAIL; - // Traverse to root cause for better error message - while (error.getCause() != null) { - error = error.getCause(); - } - message = error.getMessage(); - if (StringUtils.isBlank(message)) { - message = error.getClass().getSimpleName(); - } - } - - return new GeaflowExceptionClassificationResult(code, message); + public GeaflowExceptionClassificationResult classify(Throwable error) { + GeaflowApiResponseCode code; + String message = error.getMessage(); + + if (StringUtils.isBlank(message)) { + message = error.getClass().getSimpleName(); + } + + if (error instanceof GeaflowSecurityException) { + code = GeaflowApiResponseCode.FORBIDDEN; + + } else if (error instanceof GeaflowIllegalException) { + code = GeaflowApiResponseCode.ILLEGAL; + + } else if (error instanceof GeaflowCompileException) { + code = GeaflowApiResponseCode.ERROR; + message = ((GeaflowCompileException) error).getDisplayMessage(); + + } else if (error instanceof GeaflowException) { + code = GeaflowApiResponseCode.ERROR; + + } else if (error instanceof IllegalArgumentException) { + code = GeaflowApiResponseCode.ILLEGAL; + + } else if (error instanceof NullPointerException) { + code = GeaflowApiResponseCode.ERROR; + + } else { + code = GeaflowApiResponseCode.FAIL; + // Traverse to root cause for better error message + while (error.getCause() != null) { + error = error.getCause(); + } + message = error.getMessage(); + if (StringUtils.isBlank(message)) { + message = error.getClass().getSimpleName(); + } } - public static class GeaflowExceptionClassificationResult { - private final GeaflowApiResponseCode code; - private final String message; + return new GeaflowExceptionClassificationResult(code, message); + } - public GeaflowExceptionClassificationResult(GeaflowApiResponseCode code, String message) { - this.code = code; - this.message = message; - } + public static class GeaflowExceptionClassificationResult { + private final GeaflowApiResponseCode code; + private final String message; - public GeaflowApiResponseCode getCode() { - return code; - } + public GeaflowExceptionClassificationResult(GeaflowApiResponseCode code, String message) { + this.code = code; + this.message = message; + } + + public GeaflowApiResponseCode getCode() { + return code; + } - public String getMessage() { - return message; - } + public String getMessage() { + return message; } + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowIllegalException.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowIllegalException.java index 3a46205fb..bb5c9e2c2 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowIllegalException.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowIllegalException.java @@ -21,7 +21,7 @@ public class GeaflowIllegalException extends GeaflowException { - public GeaflowIllegalException(String fmt, Object... args) { - super(fmt, args); - } + public GeaflowIllegalException(String fmt, Object... args) { + super(fmt, args); + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowLogException.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowLogException.java index b07d7675c..38b6535d1 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowLogException.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowLogException.java @@ -24,8 +24,8 @@ @Slf4j public class GeaflowLogException extends GeaflowException { - public GeaflowLogException(String fmt, Object... args) { - super(fmt, args); - log.error(getMessage(), getCause()); - } + public GeaflowLogException(String fmt, Object... args) { + super(fmt, args); + log.error(getMessage(), getCause()); + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowSecurityException.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowSecurityException.java index 78652889e..6d5edff5a 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowSecurityException.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/exception/GeaflowSecurityException.java @@ -21,7 +21,7 @@ public class GeaflowSecurityException extends GeaflowException { - public GeaflowSecurityException(String fmt, Object... args) { - super(fmt, args); - } + public GeaflowSecurityException(String fmt, Object... args) { + super(fmt, args); + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/proxy/GeaflowInvocationHandler.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/proxy/GeaflowInvocationHandler.java index d4c71ae83..e1bbcfa01 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/proxy/GeaflowInvocationHandler.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/proxy/GeaflowInvocationHandler.java @@ -22,41 +22,46 @@ import java.lang.reflect.InvocationHandler; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; + +import org.apache.geaflow.console.common.util.LoaderSwitchUtil; + import lombok.AllArgsConstructor; import lombok.Getter; -import org.apache.geaflow.console.common.util.LoaderSwitchUtil; @Getter @AllArgsConstructor public class GeaflowInvocationHandler implements InvocationHandler { - private final ClassLoader targetClassLoader; + private final ClassLoader targetClassLoader; - private final Object targetInstance; + private final Object targetInstance; - @Override - public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { - try { - Class targetClass = targetInstance.getClass(); + @Override + public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { + try { + Class targetClass = targetInstance.getClass(); - return LoaderSwitchUtil.call(targetClassLoader, () -> { - // prepare target args - Object[] targetArgs = ProxyUtil.getTargetArgs(args); + return LoaderSwitchUtil.call( + targetClassLoader, + () -> { + // prepare target args + Object[] targetArgs = ProxyUtil.getTargetArgs(args); - // prepare target method - Method targetMethod = ProxyUtil.getTargetMethod(method, targetClass); + // prepare target method + Method targetMethod = ProxyUtil.getTargetMethod(method, targetClass); - // invoke target method - Object targetResult = targetMethod.invoke(targetInstance, targetArgs); + // invoke target method + Object targetResult = targetMethod.invoke(targetInstance, targetArgs); - // proxy target result - return ProxyUtil.proxyInstance(targetClassLoader, method.getGenericReturnType(), targetResult); - }); - } catch (Exception e) { - if (e instanceof InvocationTargetException) { - throw ((InvocationTargetException) e).getTargetException(); - } - throw e; - } + // proxy target result + return ProxyUtil.proxyInstance( + targetClassLoader, method.getGenericReturnType(), targetResult); + }); + } catch (Exception e) { + if (e instanceof InvocationTargetException) { + throw ((InvocationTargetException) e).getTargetException(); + } + throw e; } + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/proxy/ProxyClass.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/proxy/ProxyClass.java index 4c6715c7b..11d9c9ddb 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/proxy/ProxyClass.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/proxy/ProxyClass.java @@ -30,6 +30,5 @@ @Documented public @interface ProxyClass { - String value(); - + String value(); } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/proxy/ProxyUtil.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/proxy/ProxyUtil.java index c508a475d..04ac1ae21 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/proxy/ProxyUtil.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/proxy/ProxyUtil.java @@ -32,186 +32,195 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Set; + import org.apache.geaflow.console.common.util.LoaderSwitchUtil; import org.apache.geaflow.console.common.util.exception.GeaflowException; public class ProxyUtil { - public static T newInstance(ClassLoader targetClassLoader, Class clazz, Object... args) { - try { - return LoaderSwitchUtil.call(targetClassLoader, () -> { - ProxyClass proxyClass = clazz.getAnnotation(ProxyClass.class); - if (proxyClass == null) { - throw new GeaflowException("Use @ProxyClass on class {}", clazz.getSimpleName()); - } + public static T newInstance(ClassLoader targetClassLoader, Class clazz, Object... args) { + try { + return LoaderSwitchUtil.call( + targetClassLoader, + () -> { + ProxyClass proxyClass = clazz.getAnnotation(ProxyClass.class); + if (proxyClass == null) { + throw new GeaflowException("Use @ProxyClass on class {}", clazz.getSimpleName()); + } - String targetClassName = proxyClass.value(); - Class targetClass = targetClassLoader.loadClass(targetClassName); + String targetClassName = proxyClass.value(); + Class targetClass = targetClassLoader.loadClass(targetClassName); - Object[] targetArgs = getTargetArgs(args); - Constructor constructor = getConstructor(targetClass, targetArgs); - Object targetInstance = constructor.newInstance(targetArgs); + Object[] targetArgs = getTargetArgs(args); + Constructor constructor = getConstructor(targetClass, targetArgs); + Object targetInstance = constructor.newInstance(targetArgs); - return clazz.cast(proxyInstance(targetClassLoader, clazz, targetInstance)); - }); + return clazz.cast(proxyInstance(targetClassLoader, clazz, targetInstance)); + }); - } catch (Exception e) { - throw new GeaflowException("Create instance of {} failed", clazz.getSimpleName(), e); - } + } catch (Exception e) { + throw new GeaflowException("Create instance of {} failed", clazz.getSimpleName(), e); } + } + + protected static Constructor getConstructor(Class clazz, Object[] args) { + Constructor[] constructors = clazz.getConstructors(); + for (Constructor constructor : constructors) { + // parameter count check + if (constructor.getParameterCount() != args.length) { + continue; + } + + boolean matched = true; + Class[] parameterTypes = constructor.getParameterTypes(); + for (int i = 0; i < parameterTypes.length; i++) { + Object targetArg = args[i]; + + if (targetArg == null) { + // primitive type not allowed null + if (!Object.class.isAssignableFrom(parameterTypes[i])) { + matched = false; + break; + } - protected static Constructor getConstructor(Class clazz, Object[] args) { - Constructor[] constructors = clazz.getConstructors(); - for (Constructor constructor : constructors) { - // parameter count check - if (constructor.getParameterCount() != args.length) { - continue; - } - - boolean matched = true; - Class[] parameterTypes = constructor.getParameterTypes(); - for (int i = 0; i < parameterTypes.length; i++) { - Object targetArg = args[i]; - - if (targetArg == null) { - // primitive type not allowed null - if (!Object.class.isAssignableFrom(parameterTypes[i])) { - matched = false; - break; - } - - } else { - // type compatible - if (!parameterTypes[i].isAssignableFrom(targetArg.getClass())) { - matched = false; - break; - } - } - } - - if (matched) { - return constructor; - } + } else { + // type compatible + if (!parameterTypes[i].isAssignableFrom(targetArg.getClass())) { + matched = false; + break; + } } + } - throw new GeaflowException("No compatible constructor found of {}", clazz.getSimpleName()); + if (matched) { + return constructor; + } } - protected static Object[] getTargetArgs(Object[] args) { - if (args == null) { - return null; - } - - Object[] targetArgs = args.clone(); - for (int i = 0; i < args.length; i++) { - Object arg = args[i]; - if (arg instanceof Proxy) { - targetArgs[i] = ((GeaflowInvocationHandler) Proxy.getInvocationHandler(arg)).getTargetInstance(); - } - } + throw new GeaflowException("No compatible constructor found of {}", clazz.getSimpleName()); + } - return targetArgs; + protected static Object[] getTargetArgs(Object[] args) { + if (args == null) { + return null; } - protected static Method getTargetMethod(Method method, Class targetClass) { - try { - ClassLoader targetClassLoader = targetClass.getClassLoader(); - Class[] parameterTypes = method.getParameterTypes(); - - // cast to target parameter types - Class[] targetParameterTypes = parameterTypes.clone(); - for (int i = 0; i < targetParameterTypes.length; i++) { - ProxyClass targetParameterClass = targetParameterTypes[i].getAnnotation(ProxyClass.class); - if (targetParameterClass != null) { - targetParameterTypes[i] = targetClassLoader.loadClass(targetParameterClass.value()); - } - } - - // get target method - return targetClass.getMethod(method.getName(), targetParameterTypes); - - } catch (Exception e) { - throw new GeaflowException("Get target method {} failed", method.getName(), e); - } + Object[] targetArgs = args.clone(); + for (int i = 0; i < args.length; i++) { + Object arg = args[i]; + if (arg instanceof Proxy) { + targetArgs[i] = + ((GeaflowInvocationHandler) Proxy.getInvocationHandler(arg)).getTargetInstance(); + } } - protected static Object proxyInstance(ClassLoader targetClassLoader, Type type, Object targetInstance) { - if (type instanceof Class) { - Class clazz = (Class) type; - - // primitive type - if (clazz.isPrimitive()) { - return targetInstance; - } + return targetArgs; + } - // cast directly - if (clazz.getAnnotation(ProxyClass.class) == null || targetInstance == null) { - return clazz.cast(targetInstance); - } + protected static Method getTargetMethod(Method method, Class targetClass) { + try { + ClassLoader targetClassLoader = targetClass.getClassLoader(); + Class[] parameterTypes = method.getParameterTypes(); - // proxy instance - Object proxyInstance = Proxy.newProxyInstance(clazz.getClassLoader(), new Class[]{clazz}, - new GeaflowInvocationHandler(targetClassLoader, targetInstance)); - return clazz.cast(proxyInstance); + // cast to target parameter types + Class[] targetParameterTypes = parameterTypes.clone(); + for (int i = 0; i < targetParameterTypes.length; i++) { + ProxyClass targetParameterClass = targetParameterTypes[i].getAnnotation(ProxyClass.class); + if (targetParameterClass != null) { + targetParameterTypes[i] = targetClassLoader.loadClass(targetParameterClass.value()); } + } - if (type instanceof ParameterizedType) { - ParameterizedType parameterizedType = (ParameterizedType) type; - Class clazz = (Class) parameterizedType.getRawType(); - Type[] elementType = parameterizedType.getActualTypeArguments(); + // get target method + return targetClass.getMethod(method.getName(), targetParameterTypes); - // proxy collection type - if (Collection.class.isAssignableFrom(clazz)) { - return proxyCollection(targetClassLoader, clazz, elementType[0], targetInstance); - } + } catch (Exception e) { + throw new GeaflowException("Get target method {} failed", method.getName(), e); + } + } + + protected static Object proxyInstance( + ClassLoader targetClassLoader, Type type, Object targetInstance) { + if (type instanceof Class) { + Class clazz = (Class) type; + + // primitive type + if (clazz.isPrimitive()) { + return targetInstance; + } + + // cast directly + if (clazz.getAnnotation(ProxyClass.class) == null || targetInstance == null) { + return clazz.cast(targetInstance); + } + + // proxy instance + Object proxyInstance = + Proxy.newProxyInstance( + clazz.getClassLoader(), + new Class[] {clazz}, + new GeaflowInvocationHandler(targetClassLoader, targetInstance)); + return clazz.cast(proxyInstance); + } - // proxy map type - if (Map.class.isAssignableFrom(clazz)) { - return proxyMap(targetClassLoader, elementType[0], elementType[1], targetInstance); - } + if (type instanceof ParameterizedType) { + ParameterizedType parameterizedType = (ParameterizedType) type; + Class clazz = (Class) parameterizedType.getRawType(); + Type[] elementType = parameterizedType.getActualTypeArguments(); - // proxy raw type - return proxyInstance(targetClassLoader, clazz, targetInstance); - } + // proxy collection type + if (Collection.class.isAssignableFrom(clazz)) { + return proxyCollection(targetClassLoader, clazz, elementType[0], targetInstance); + } + + // proxy map type + if (Map.class.isAssignableFrom(clazz)) { + return proxyMap(targetClassLoader, elementType[0], elementType[1], targetInstance); + } - throw new GeaflowException("Type {} not support proxy", type.getTypeName()); + // proxy raw type + return proxyInstance(targetClassLoader, clazz, targetInstance); } - private static Object proxyCollection(ClassLoader targetClassLoader, Class clazz, Type elementType, - Object targetInstance) { - // create default collection - Collection proxyCollection; - if (Set.class.isAssignableFrom(clazz)) { - proxyCollection = new HashSet<>(); + throw new GeaflowException("Type {} not support proxy", type.getTypeName()); + } - } else if (List.class.isAssignableFrom(clazz)) { - proxyCollection = new ArrayList<>(); + private static Object proxyCollection( + ClassLoader targetClassLoader, Class clazz, Type elementType, Object targetInstance) { + // create default collection + Collection proxyCollection; + if (Set.class.isAssignableFrom(clazz)) { + proxyCollection = new HashSet<>(); - } else { - throw new GeaflowException("Result collection class {} not supported", clazz.getSimpleName()); - } + } else if (List.class.isAssignableFrom(clazz)) { + proxyCollection = new ArrayList<>(); - // proxy collection element - Collection targetCollection = (Collection) targetInstance; - for (Object element : targetCollection) { - proxyCollection.add(proxyInstance(targetClassLoader, elementType, element)); - } + } else { + throw new GeaflowException("Result collection class {} not supported", clazz.getSimpleName()); + } - return proxyCollection; + // proxy collection element + Collection targetCollection = (Collection) targetInstance; + for (Object element : targetCollection) { + proxyCollection.add(proxyInstance(targetClassLoader, elementType, element)); } - private static Object proxyMap(ClassLoader targetClassLoader, Type keyType, Type valueType, Object targetInstance) { - // create proxy map - Map proxyMap = new HashMap<>(); + return proxyCollection; + } - // proxy map key value - Map targetMap = (Map) targetInstance; - for (Entry entry : targetMap.entrySet()) { - Object proxyKey = proxyInstance(targetClassLoader, keyType, entry.getKey()); - Object proxyValue = proxyInstance(targetClassLoader, valueType, entry.getValue()); - proxyMap.put(proxyKey, proxyValue); - } + private static Object proxyMap( + ClassLoader targetClassLoader, Type keyType, Type valueType, Object targetInstance) { + // create proxy map + Map proxyMap = new HashMap<>(); - return proxyMap; + // proxy map key value + Map targetMap = (Map) targetInstance; + for (Entry entry : targetMap.entrySet()) { + Object proxyKey = proxyInstance(targetClassLoader, keyType, entry.getKey()); + Object proxyValue = proxyInstance(targetClassLoader, valueType, entry.getValue()); + proxyMap.put(proxyKey, proxyValue); } + + return proxyMap; + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/CatalogType.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/CatalogType.java index 564da1317..26e46d990 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/CatalogType.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/CatalogType.java @@ -19,16 +19,16 @@ package org.apache.geaflow.console.common.util.type; public enum CatalogType { - MEMORY("memory"), - CONSOLE("console"); + MEMORY("memory"), + CONSOLE("console"); - private String value; + private String value; - CatalogType(String value) { - this.value = value; - } + CatalogType(String value) { + this.value = value; + } - public String getValue() { - return value; - } + public String getValue() { + return value; + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowApiResponseCode.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowApiResponseCode.java index df1532a28..8a71608cd 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowApiResponseCode.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowApiResponseCode.java @@ -23,20 +23,19 @@ @Getter public enum GeaflowApiResponseCode { + SUCCESS(200), - SUCCESS(200), + ILLEGAL(400), - ILLEGAL(400), + FORBIDDEN(403), - FORBIDDEN(403), + ERROR(500), - ERROR(500), + FAIL(500); - FAIL(500); + private final int httpCode; - private final int httpCode; - - GeaflowApiResponseCode(int httpCode) { - this.httpCode = httpCode; - } + GeaflowApiResponseCode(int httpCode) { + this.httpCode = httpCode; + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowAuthorityType.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowAuthorityType.java index e46977adf..35b943014 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowAuthorityType.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowAuthorityType.java @@ -20,11 +20,11 @@ package org.apache.geaflow.console.common.util.type; public enum GeaflowAuthorityType { - ALL, + ALL, - QUERY, + QUERY, - UPDATE, + UPDATE, - EXECUTE; + EXECUTE; } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowClusterType.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowClusterType.java index ac2adf556..a697be82e 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowClusterType.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowClusterType.java @@ -20,7 +20,5 @@ package org.apache.geaflow.console.common.util.type; public enum GeaflowClusterType { - - K8S - + K8S } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowDeployMode.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowDeployMode.java index f90af6751..2179655aa 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowDeployMode.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowDeployMode.java @@ -20,9 +20,7 @@ package org.apache.geaflow.console.common.util.type; public enum GeaflowDeployMode { + LOCAL, - LOCAL, - - CLUSTER - + CLUSTER } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowFieldCategory.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowFieldCategory.java index b5e2b1592..49954b4ae 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowFieldCategory.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowFieldCategory.java @@ -23,81 +23,78 @@ import static org.apache.geaflow.console.common.util.type.GeaflowFieldCategory.NumConstraint.EXACTLY_ONCE; import static org.apache.geaflow.console.common.util.type.GeaflowFieldCategory.NumConstraint.NONE; -import com.google.common.collect.Sets; import java.util.ArrayList; import java.util.List; import java.util.Set; + import org.apache.geaflow.console.common.util.exception.GeaflowException; +import com.google.common.collect.Sets; + public enum GeaflowFieldCategory { + PROPERTY(NONE, GeaflowStructType.values()), - PROPERTY(NONE, GeaflowStructType.values()), + ID(EXACTLY_ONCE, GeaflowStructType.TABLE, GeaflowStructType.VIEW), - ID(EXACTLY_ONCE, GeaflowStructType.TABLE, GeaflowStructType.VIEW), + VERTEX_ID(EXACTLY_ONCE, GeaflowStructType.VERTEX), - VERTEX_ID(EXACTLY_ONCE, GeaflowStructType.VERTEX), + VERTEX_LABEL(EXACTLY_ONCE, GeaflowStructType.VERTEX), - VERTEX_LABEL(EXACTLY_ONCE, GeaflowStructType.VERTEX), + EDGE_SOURCE_ID(EXACTLY_ONCE, GeaflowStructType.EDGE), - EDGE_SOURCE_ID(EXACTLY_ONCE, GeaflowStructType.EDGE), + EDGE_TARGET_ID(EXACTLY_ONCE, GeaflowStructType.EDGE), - EDGE_TARGET_ID(EXACTLY_ONCE, GeaflowStructType.EDGE), + EDGE_LABEL(EXACTLY_ONCE, GeaflowStructType.EDGE), - EDGE_LABEL(EXACTLY_ONCE, GeaflowStructType.EDGE), + EDGE_TIMESTAMP(AT_MOST_ONCE, GeaflowStructType.EDGE); - EDGE_TIMESTAMP(AT_MOST_ONCE, GeaflowStructType.EDGE); + private final Set structTypes; - private final Set structTypes; + private final NumConstraint numConstraint; - private final NumConstraint numConstraint; + GeaflowFieldCategory(NumConstraint numConstraint, GeaflowStructType... structTypes) { + this.numConstraint = numConstraint; + this.structTypes = Sets.newHashSet(structTypes); + } - GeaflowFieldCategory(NumConstraint numConstraint, GeaflowStructType... structTypes) { - this.numConstraint = numConstraint; - this.structTypes = Sets.newHashSet(structTypes); + public static List of(GeaflowStructType structType) { + List constraints = new ArrayList<>(); + for (GeaflowFieldCategory value : values()) { + if (value.structTypes.contains(structType)) { + constraints.add(value); + } } - - public static List of(GeaflowStructType structType) { - List constraints = new ArrayList<>(); - for (GeaflowFieldCategory value : values()) { - if (value.structTypes.contains(structType)) { - constraints.add(value); - } + return constraints; + } + + public enum NumConstraint { + /** count == 1. */ + EXACTLY_ONCE, + /** count <= 1. */ + AT_MOST_ONCE, + NONE + } + + public NumConstraint getNumConstraint() { + return numConstraint; + } + + public void validate(int count) { + switch (this.numConstraint) { + case EXACTLY_ONCE: + if (count < 1) { + throw new GeaflowException("Must have {} field", this.name()); + } else if (count > 1) { + throw new GeaflowException("Can have only one {} field", this.name()); } - return constraints; - } - - public enum NumConstraint { - /** - * count == 1. - */ - EXACTLY_ONCE, - /** - * count <= 1. - */ - AT_MOST_ONCE, - NONE - } - - public NumConstraint getNumConstraint() { - return numConstraint; - } - - public void validate(int count) { - switch (this.numConstraint) { - case EXACTLY_ONCE: - if (count < 1) { - throw new GeaflowException("Must have {} field", this.name()); - } else if (count > 1) { - throw new GeaflowException("Can have only one {} field", this.name()); - } - break; - case AT_MOST_ONCE: - if (count > 1) { - throw new GeaflowException("Can have only one {} field", this.name()); - } - break; - default: - return; + break; + case AT_MOST_ONCE: + if (count > 1) { + throw new GeaflowException("Can have only one {} field", this.name()); } + break; + default: + return; } + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowFieldType.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowFieldType.java index ef944d5fa..7b4381282 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowFieldType.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowFieldType.java @@ -20,17 +20,15 @@ package org.apache.geaflow.console.common.util.type; public enum GeaflowFieldType { + BOOLEAN, - BOOLEAN, + INT, - INT, + BIGINT, - BIGINT, + DOUBLE, - DOUBLE, - - VARCHAR, - - TIMESTAMP; + VARCHAR, + TIMESTAMP; } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowFileType.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowFileType.java index e50327949..495622fd0 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowFileType.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowFileType.java @@ -22,15 +22,14 @@ import org.apache.geaflow.console.common.util.exception.GeaflowIllegalException; public enum GeaflowFileType { + JAR; - JAR; - - public static GeaflowFileType of(String name) { - for (GeaflowFileType value : values()) { - if (value.name().equalsIgnoreCase(name)) { - return value; - } - } - throw new GeaflowIllegalException("File type {} not supported", name); + public static GeaflowFileType of(String name) { + for (GeaflowFileType value : values()) { + if (value.name().equalsIgnoreCase(name)) { + return value; + } } + throw new GeaflowIllegalException("File type {} not supported", name); + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowJobType.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowJobType.java index 8c643a6c2..e5c0dd464 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowJobType.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowJobType.java @@ -23,23 +23,21 @@ @Getter public enum GeaflowJobType { + INTEGRATE(GeaflowTaskType.CODE), - INTEGRATE(GeaflowTaskType.CODE), + DISTRIBUTE(GeaflowTaskType.CODE), - DISTRIBUTE(GeaflowTaskType.CODE), + PROCESS(GeaflowTaskType.CODE), - PROCESS(GeaflowTaskType.CODE), + SERVE(GeaflowTaskType.API), - SERVE(GeaflowTaskType.API), + STAT(GeaflowTaskType.API), - STAT(GeaflowTaskType.API), + CUSTOM(GeaflowTaskType.API); - CUSTOM(GeaflowTaskType.API); - - private final GeaflowTaskType taskType; - - GeaflowJobType(GeaflowTaskType taskType) { - this.taskType = taskType; - } + private final GeaflowTaskType taskType; + GeaflowJobType(GeaflowTaskType taskType) { + this.taskType = taskType; + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowLLMType.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowLLMType.java index 6ca00da73..598424d51 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowLLMType.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowLLMType.java @@ -20,9 +20,7 @@ package org.apache.geaflow.console.common.util.type; public enum GeaflowLLMType { + OPEN_AI, - OPEN_AI, - - LOCAL, - + LOCAL, } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowOperationType.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowOperationType.java index aa7215b37..fdfcb707d 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowOperationType.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowOperationType.java @@ -20,28 +20,24 @@ package org.apache.geaflow.console.common.util.type; public enum GeaflowOperationType { + CREATE, - CREATE, + UPDATE, - UPDATE, + DELETE, - DELETE, + /** operations for job. */ + PUBLISH, - /** - * operations for job. - */ - PUBLISH, + START, - START, + STOP, - STOP, + REFRESH, - REFRESH, + RESET, - RESET, - - STARTUP_NOTIFY, - - FINISH, + STARTUP_NOTIFY, + FINISH, } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowPluginCategory.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowPluginCategory.java index 4c066b9da..f79776438 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowPluginCategory.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowPluginCategory.java @@ -20,21 +20,19 @@ package org.apache.geaflow.console.common.util.type; public enum GeaflowPluginCategory { + TABLE, - TABLE, + GRAPH, - GRAPH, + RUNTIME_CLUSTER, - RUNTIME_CLUSTER, + RUNTIME_META, - RUNTIME_META, + HA_META, - HA_META, + METRIC, - METRIC, - - REMOTE_FILE, - - DATA + REMOTE_FILE, + DATA } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowPluginType.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowPluginType.java index f2698d57c..0d86e606b 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowPluginType.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowPluginType.java @@ -23,57 +23,54 @@ @Slf4j public enum GeaflowPluginType { - KAFKA, + KAFKA, - HIVE, + HIVE, - FILE, + FILE, - SOCKET, + SOCKET, - CONSOLE, + CONSOLE, - MEMORY, + MEMORY, - ROCKSDB, + ROCKSDB, - LOCAL, + LOCAL, - DFS, + DFS, - OSS, + OSS, - JDBC, + JDBC, - REDIS, + REDIS, - INFLUXDB, + INFLUXDB, - K8S, + K8S, - CONTAINER, + CONTAINER, - RAY, + RAY, - /** - * just for custom define or unknown type, not a specific type. - */ - None; + /** just for custom define or unknown type, not a specific type. */ + None; - - public static GeaflowPluginType of(String type) { - try { - return GeaflowPluginType.valueOf(type); - } catch (Exception e) { - return GeaflowPluginType.None; - } + public static GeaflowPluginType of(String type) { + try { + return GeaflowPluginType.valueOf(type); + } catch (Exception e) { + return GeaflowPluginType.None; } + } - public static String getName(String type) { - if (type == null) { - return null; - } - GeaflowPluginType typeEnum = GeaflowPluginType.of(type.toUpperCase()); - return typeEnum == None ? type : typeEnum.name(); + public static String getName(String type) { + if (type == null) { + return null; } + GeaflowPluginType typeEnum = GeaflowPluginType.of(type.toUpperCase()); + return typeEnum == None ? type : typeEnum.name(); + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowResourceType.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowResourceType.java index 962c8750e..4f321f1a5 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowResourceType.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowResourceType.java @@ -20,55 +20,55 @@ package org.apache.geaflow.console.common.util.type; public enum GeaflowResourceType { - TENANT, + TENANT, - USER, + USER, - FILE, + FILE, - PLUGIN_CONFIG, + PLUGIN_CONFIG, - INSTANCE, + INSTANCE, - GRAPH, + GRAPH, - STRUCT, + STRUCT, - FUNCTION, + FUNCTION, - TABLE, + TABLE, - VIEW, + VIEW, - VERTEX, + VERTEX, - EDGE, + EDGE, - FIELD, + FIELD, - JOB, + JOB, - RELEASE, + RELEASE, - TASK, + TASK, - META, + META, - METRIC, + METRIC, - DATA, + DATA, - ENGINE_VERSION, + ENGINE_VERSION, - PLUGIN, + PLUGIN, - CLUSTER, + CLUSTER, - STORE, + STORE, - META_STORE, + META_STORE, - METRIC_STORE, + METRIC_STORE, - DATA_STORE; + DATA_STORE; } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowRoleType.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowRoleType.java index f9c68b296..a3553c3e0 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowRoleType.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowRoleType.java @@ -20,9 +20,7 @@ package org.apache.geaflow.console.common.util.type; public enum GeaflowRoleType { + SYSTEM_ADMIN, - SYSTEM_ADMIN, - - TENANT_ADMIN - + TENANT_ADMIN } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowStatementStatus.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowStatementStatus.java index e4b95c638..c82780839 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowStatementStatus.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowStatementStatus.java @@ -20,10 +20,9 @@ package org.apache.geaflow.console.common.util.type; public enum GeaflowStatementStatus { + RUNNING, - RUNNING, + FAILED, - FAILED, - - FINISHED; + FINISHED; } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowStructType.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowStructType.java index d2ba643f1..3badf955e 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowStructType.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowStructType.java @@ -20,13 +20,11 @@ package org.apache.geaflow.console.common.util.type; public enum GeaflowStructType { + TABLE, - TABLE, + VIEW, - VIEW, - - VERTEX, - - EDGE + VERTEX, + EDGE } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowTaskStatus.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowTaskStatus.java index 1ac2fc2ca..677c0ab0b 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowTaskStatus.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowTaskStatus.java @@ -23,45 +23,45 @@ import java.util.HashMap; import java.util.Map; import java.util.Set; + import org.apache.geaflow.console.common.util.exception.GeaflowException; public enum GeaflowTaskStatus { + CREATED, - CREATED, - - WAITING, + WAITING, - STARTING, + STARTING, - FAILED, + FAILED, - RUNNING, + RUNNING, - FINISHED, + FINISHED, - STOPPED, + STOPPED, - DELETED; + DELETED; - private static final Map> allowedOperations = new HashMap<>(); + private static final Map> allowedOperations = + new HashMap<>(); - static { - allowedOperations.put(GeaflowOperationType.START, EnumSet.of(CREATED, FAILED, STOPPED)); - allowedOperations.put(GeaflowOperationType.STOP, EnumSet.of(RUNNING, WAITING)); - allowedOperations.put(GeaflowOperationType.REFRESH, EnumSet.allOf(GeaflowTaskStatus.class)); - - Set unRunningStatus = EnumSet.of(CREATED, FAILED, STOPPED, FINISHED); - allowedOperations.put(GeaflowOperationType.PUBLISH, unRunningStatus); - allowedOperations.put(GeaflowOperationType.RESET, unRunningStatus); - allowedOperations.put(GeaflowOperationType.DELETE, unRunningStatus); - allowedOperations.put(GeaflowOperationType.FINISH, EnumSet.of(RUNNING, FINISHED)); - } + static { + allowedOperations.put(GeaflowOperationType.START, EnumSet.of(CREATED, FAILED, STOPPED)); + allowedOperations.put(GeaflowOperationType.STOP, EnumSet.of(RUNNING, WAITING)); + allowedOperations.put(GeaflowOperationType.REFRESH, EnumSet.allOf(GeaflowTaskStatus.class)); + Set unRunningStatus = EnumSet.of(CREATED, FAILED, STOPPED, FINISHED); + allowedOperations.put(GeaflowOperationType.PUBLISH, unRunningStatus); + allowedOperations.put(GeaflowOperationType.RESET, unRunningStatus); + allowedOperations.put(GeaflowOperationType.DELETE, unRunningStatus); + allowedOperations.put(GeaflowOperationType.FINISH, EnumSet.of(RUNNING, FINISHED)); + } - public void checkOperation(GeaflowOperationType operationType) { - Set allowedStatuses = allowedOperations.get(operationType); - if (allowedStatuses == null || !allowedStatuses.contains(this)) { - throw new GeaflowException("Task {} status can't {}", this, operationType); - } + public void checkOperation(GeaflowOperationType operationType) { + Set allowedStatuses = allowedOperations.get(operationType); + if (allowedStatuses == null || !allowedStatuses.contains(this)) { + throw new GeaflowException("Task {} status can't {}", this, operationType); } + } } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowTaskType.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowTaskType.java index 56c37a459..97d2f4e13 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowTaskType.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowTaskType.java @@ -20,9 +20,7 @@ package org.apache.geaflow.console.common.util.type; public enum GeaflowTaskType { + CODE, - CODE, - - API - + API } diff --git a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowViewCategory.java b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowViewCategory.java index b837361ea..2e855c46f 100644 --- a/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowViewCategory.java +++ b/geaflow-console/app/common/util/src/main/java/org/apache/geaflow/console/common/util/type/GeaflowViewCategory.java @@ -20,9 +20,7 @@ package org.apache.geaflow.console.common.util.type; public enum GeaflowViewCategory { + LOGICAL, - LOGICAL, - - MATERIALIZED - + MATERIALIZED } diff --git a/geaflow-console/app/common/util/src/test/java/org/apache/geaflow/console/common/util/ProcessUtilTest.java b/geaflow-console/app/common/util/src/test/java/org/apache/geaflow/console/common/util/ProcessUtilTest.java index ad7f41828..04e07cd41 100644 --- a/geaflow-console/app/common/util/src/test/java/org/apache/geaflow/console/common/util/ProcessUtilTest.java +++ b/geaflow-console/app/common/util/src/test/java/org/apache/geaflow/console/common/util/ProcessUtilTest.java @@ -22,17 +22,16 @@ import org.testng.Assert; import org.testng.annotations.Test; - public class ProcessUtilTest { - @Test - public void testCommand() throws Exception { - Assert.assertEquals(ProcessUtil.execute("ls pom.xml"), "pom.xml"); - Assert.assertEquals(ProcessUtil.execute("echo abc"), "abc"); + @Test + public void testCommand() throws Exception { + Assert.assertEquals(ProcessUtil.execute("ls pom.xml"), "pom.xml"); + Assert.assertEquals(ProcessUtil.execute("echo abc"), "abc"); - int pid = ProcessUtil.executeAsync("sleep 123"); - Assert.assertTrue(ProcessUtil.existPid(pid)); - ProcessUtil.killPid(pid); - Assert.assertFalse(ProcessUtil.existPid(pid)); - } + int pid = ProcessUtil.executeAsync("sleep 123"); + Assert.assertTrue(ProcessUtil.existPid(pid)); + ProcessUtil.killPid(pid); + Assert.assertFalse(ProcessUtil.existPid(pid)); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/GeaflowId.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/GeaflowId.java index 22f11ebe2..c11e39666 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/GeaflowId.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/GeaflowId.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.util.Date; + import lombok.Getter; import lombok.Setter; @@ -28,14 +29,12 @@ @Setter public abstract class GeaflowId implements Serializable { - protected String id; - protected Date gmtCreate; - protected Date gmtModified; - protected String creatorId; - protected String modifierId; - private String tenantId; - - public void validate() { + protected String id; + protected Date gmtCreate; + protected Date gmtModified; + protected String creatorId; + protected String modifierId; + private String tenantId; - } + public void validate() {} } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/GeaflowName.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/GeaflowName.java index d96953c5a..a5e84fd87 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/GeaflowName.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/GeaflowName.java @@ -19,25 +19,29 @@ package org.apache.geaflow.console.core.model; +import org.apache.commons.lang3.StringUtils; + import com.google.common.base.Preconditions; + import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; -import org.apache.commons.lang3.StringUtils; @Getter @Setter -@EqualsAndHashCode(of = {"name", "comment"}, callSuper = false) +@EqualsAndHashCode( + of = {"name", "comment"}, + callSuper = false) public abstract class GeaflowName extends GeaflowId { - protected String name; + protected String name; - protected String comment; + protected String comment; - @Override - public void validate() { - super.validate(); - Preconditions.checkArgument(StringUtils.isNotBlank(name), "Invalid name"); - Preconditions.checkArgument(!name.contains(" "), "Name '%s' can't contain space.", name); - } + @Override + public void validate() { + super.validate(); + Preconditions.checkArgument(StringUtils.isNotBlank(name), "Invalid name"); + Preconditions.checkArgument(!name.contains(" "), "Name '%s' can't contain space.", name); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/cluster/GeaflowCluster.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/cluster/GeaflowCluster.java index a407cd65f..77dda12db 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/cluster/GeaflowCluster.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/cluster/GeaflowCluster.java @@ -19,37 +19,39 @@ package org.apache.geaflow.console.core.model.cluster; -import com.google.common.base.Preconditions; -import lombok.Getter; -import lombok.NoArgsConstructor; -import lombok.Setter; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.GeaflowName; import org.apache.geaflow.console.core.model.config.ConfigDescFactory; import org.apache.geaflow.console.core.model.config.GeaflowConfig; import org.apache.geaflow.console.core.model.plugin.config.GeaflowPluginConfig; +import com.google.common.base.Preconditions; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + @Getter @Setter @NoArgsConstructor public class GeaflowCluster extends GeaflowName { - private GeaflowPluginType type; + private GeaflowPluginType type; - private GeaflowConfig config; + private GeaflowConfig config; - public GeaflowCluster(GeaflowPluginConfig pluginConfig) { - this.type = GeaflowPluginType.of(pluginConfig.getType()); - this.name = pluginConfig.getName(); - this.comment = pluginConfig.getComment(); - this.config = pluginConfig.getConfig(); - } + public GeaflowCluster(GeaflowPluginConfig pluginConfig) { + this.type = GeaflowPluginType.of(pluginConfig.getType()); + this.name = pluginConfig.getName(); + this.comment = pluginConfig.getComment(); + this.config = pluginConfig.getConfig(); + } - @Override - public void validate() { - super.validate(); - Preconditions.checkNotNull(type, "Invalid type"); - Preconditions.checkNotNull(config, "Invalid config"); - ConfigDescFactory.get(type).validateConfig(config); - } + @Override + public void validate() { + super.validate(); + Preconditions.checkNotNull(type, "Invalid type"); + Preconditions.checkNotNull(config, "Invalid config"); + ConfigDescFactory.get(type).validateConfig(config); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/code/GeaflowCode.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/code/GeaflowCode.java index 48ab15426..867ebf4f9 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/code/GeaflowCode.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/code/GeaflowCode.java @@ -25,9 +25,9 @@ @Getter public class GeaflowCode { - private final String text; + private final String text; - public GeaflowCode(@NonNull String text) { - this.text = text; - } + public GeaflowCode(@NonNull String text) { + this.text = text; + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/ConfigDescFactory.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/ConfigDescFactory.java index fabd082c8..184cbc2dc 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/ConfigDescFactory.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/ConfigDescFactory.java @@ -21,7 +21,7 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import lombok.extern.slf4j.Slf4j; + import org.apache.geaflow.console.common.util.Fmt; import org.apache.geaflow.console.common.util.exception.GeaflowException; import org.apache.geaflow.console.common.util.exception.GeaflowLogException; @@ -29,48 +29,51 @@ import org.apache.geaflow.console.core.model.plugin.config.PluginConfigClass; import org.springframework.stereotype.Component; +import lombok.extern.slf4j.Slf4j; + @Slf4j @Component public class ConfigDescFactory { - private static final Map, GeaflowConfigDesc> CONFIG_DESCS = - new ConcurrentHashMap<>(); + private static final Map, GeaflowConfigDesc> CONFIG_DESCS = + new ConcurrentHashMap<>(); - private static final Map PLUGIN_CONFIG_DESCS = new ConcurrentHashMap<>(); + private static final Map PLUGIN_CONFIG_DESCS = + new ConcurrentHashMap<>(); - static { - String packageName = PluginConfigClass.class.getPackage().getName(); - for (GeaflowPluginType type : GeaflowPluginType.values()) { - if (type == GeaflowPluginType.None) { - continue; - } - String prefix = type.name().charAt(0) + type.name().substring(1).toLowerCase(); - String className = Fmt.as("{}.{}PluginConfigClass", packageName, prefix); + static { + String packageName = PluginConfigClass.class.getPackage().getName(); + for (GeaflowPluginType type : GeaflowPluginType.values()) { + if (type == GeaflowPluginType.None) { + continue; + } + String prefix = type.name().charAt(0) + type.name().substring(1).toLowerCase(); + String className = Fmt.as("{}.{}PluginConfigClass", packageName, prefix); - try { - Class clazz = Class.forName(className); - PLUGIN_CONFIG_DESCS.put(type, getOrRegister((Class) clazz)); - log.info("Register {} plugin config class {} success", type, clazz.getSimpleName()); + try { + Class clazz = Class.forName(className); + PLUGIN_CONFIG_DESCS.put(type, getOrRegister((Class) clazz)); + log.info("Register {} plugin config class {} success", type, clazz.getSimpleName()); - } catch (Exception e) { - throw new GeaflowLogException("Register {} plugin config failed", type, e); - } - } + } catch (Exception e) { + throw new GeaflowLogException("Register {} plugin config failed", type, e); + } } + } - public static GeaflowConfigDesc getOrRegister(Class clazz) { - if (!CONFIG_DESCS.containsKey(clazz)) { - CONFIG_DESCS.put(clazz, new GeaflowConfigDesc(clazz)); - } - - return CONFIG_DESCS.get(clazz); + public static GeaflowConfigDesc getOrRegister(Class clazz) { + if (!CONFIG_DESCS.containsKey(clazz)) { + CONFIG_DESCS.put(clazz, new GeaflowConfigDesc(clazz)); } - public static GeaflowConfigDesc get(GeaflowPluginType type) { - GeaflowConfigDesc config = PLUGIN_CONFIG_DESCS.get(type); - if (config == null) { - throw new GeaflowException("Plugin config type {} not register", type); - } - return config; + return CONFIG_DESCS.get(clazz); + } + + public static GeaflowConfigDesc get(GeaflowPluginType type) { + GeaflowConfigDesc config = PLUGIN_CONFIG_DESCS.get(type); + if (config == null) { + throw new GeaflowException("Plugin config type {} not register", type); } + return config; + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/ConfigDescItem.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/ConfigDescItem.java index 28c0de63f..9c3d0d069 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/ConfigDescItem.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/ConfigDescItem.java @@ -19,77 +19,83 @@ package org.apache.geaflow.console.core.model.config; -import com.alibaba.fastjson.JSON; -import com.alibaba.fastjson.annotation.JSONField; -import com.google.common.base.Preconditions; import java.lang.reflect.Field; -import lombok.Getter; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.util.I18nUtil; import org.apache.geaflow.console.common.util.exception.GeaflowException; +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.annotation.JSONField; +import com.google.common.base.Preconditions; + +import lombok.Getter; + @Getter public class ConfigDescItem { - @JSONField(serialize = false, deserialize = false) - private final Field field; - - private final String key; + @JSONField(serialize = false, deserialize = false) + private final Field field; - private final String comment; + private final String key; - @JSONField(serialize = false, deserialize = false) - private final boolean jsonIgnore; + private final String comment; - private final GeaflowConfigType type; + @JSONField(serialize = false, deserialize = false) + private final boolean jsonIgnore; - private boolean required; + private final GeaflowConfigType type; - private Object defaultValue; + private boolean required; - private boolean masked; + private Object defaultValue; - @JSONField(serialize = false, deserialize = false) - private ConfigValueBehavior behavior = ConfigValueBehavior.NESTED; + private boolean masked; - @JSONField(serialize = false, deserialize = false) - private GeaflowConfigDesc innerConfigDesc; + @JSONField(serialize = false, deserialize = false) + private ConfigValueBehavior behavior = ConfigValueBehavior.NESTED; - public ConfigDescItem(Field field) { - field.setAccessible(true); + @JSONField(serialize = false, deserialize = false) + private GeaflowConfigDesc innerConfigDesc; - final Class clazz = field.getType(); - final GeaflowConfigKey keyTag = field.getAnnotation(GeaflowConfigKey.class); - final GeaflowConfigValue valueTag = field.getAnnotation(GeaflowConfigValue.class); - Preconditions.checkNotNull(keyTag, "GeaflowConfigKey annotation is required"); + public ConfigDescItem(Field field) { + field.setAccessible(true); - this.field = field; - this.key = StringUtils.trimToNull(keyTag.value()); - this.comment = StringUtils.trimToNull(keyTag.comment()); - this.jsonIgnore = keyTag.jsonIgnore(); - this.type = GeaflowConfigType.of(clazz); + final Class clazz = field.getType(); + final GeaflowConfigKey keyTag = field.getAnnotation(GeaflowConfigKey.class); + final GeaflowConfigValue valueTag = field.getAnnotation(GeaflowConfigValue.class); + Preconditions.checkNotNull(keyTag, "GeaflowConfigKey annotation is required"); - if (valueTag != null) { - this.required = valueTag.required(); + this.field = field; + this.key = StringUtils.trimToNull(keyTag.value()); + this.comment = StringUtils.trimToNull(keyTag.comment()); + this.jsonIgnore = keyTag.jsonIgnore(); + this.type = GeaflowConfigType.of(clazz); - String str = StringUtils.trimToNull(valueTag.defaultValue()); - if (StringUtils.isNotBlank(str)) { - this.defaultValue = String.class.equals(clazz) ? str : JSON.parseObject(str, clazz); - } + if (valueTag != null) { + this.required = valueTag.required(); - this.masked = valueTag.masked(); - this.behavior = valueTag.behavior(); - if (!ConfigValueBehavior.NESTED.equals(this.behavior) && !GeaflowConfigType.CONFIG.equals(this.type)) { - throw new GeaflowException("Only CONFIG type field can use NESTED behavior on key {}", this.key); - } - } + String str = StringUtils.trimToNull(valueTag.defaultValue()); + if (StringUtils.isNotBlank(str)) { + this.defaultValue = String.class.equals(clazz) ? str : JSON.parseObject(str, clazz); + } - if (GeaflowConfigClass.class.isAssignableFrom(clazz)) { - this.innerConfigDesc = ConfigDescFactory.getOrRegister((Class) clazz); - } + this.masked = valueTag.masked(); + this.behavior = valueTag.behavior(); + if (!ConfigValueBehavior.NESTED.equals(this.behavior) + && !GeaflowConfigType.CONFIG.equals(this.type)) { + throw new GeaflowException( + "Only CONFIG type field can use NESTED behavior on key {}", this.key); + } } - public String getComment() { - return I18nUtil.getMessage(comment); + if (GeaflowConfigClass.class.isAssignableFrom(clazz)) { + this.innerConfigDesc = + ConfigDescFactory.getOrRegister((Class) clazz); } + } + + public String getComment() { + return I18nUtil.getMessage(comment); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/ConfigValueBehavior.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/ConfigValueBehavior.java index b3f8d3f5a..f9812aa2b 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/ConfigValueBehavior.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/ConfigValueBehavior.java @@ -21,19 +21,12 @@ public enum ConfigValueBehavior { - /** - * default, bind key to config value. - */ - NESTED, + /** default, bind key to config value. */ + NESTED, - /** - * config value will be expand multiple values. - */ - FLATTED, - - /** - * config value will be formatted to json string. - */ - JSON + /** config value will be expand multiple values. */ + FLATTED, + /** config value will be formatted to json string. */ + JSON } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfig.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfig.java index 1f78da72d..f2d5734b0 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfig.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfig.java @@ -19,94 +19,100 @@ package org.apache.geaflow.console.core.model.config; -import com.alibaba.fastjson.JSON; import java.lang.reflect.Modifier; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.Map; import java.util.Set; + +import org.apache.geaflow.console.common.util.exception.GeaflowException; + +import com.alibaba.fastjson.JSON; + import lombok.Getter; import lombok.NoArgsConstructor; -import org.apache.geaflow.console.common.util.exception.GeaflowException; @Getter @NoArgsConstructor public class GeaflowConfig extends LinkedHashMap { - public GeaflowConfig(Object map) { - putAll((Map) map); + public GeaflowConfig(Object map) { + putAll((Map) map); + } + + public Map toStringMap() { + Map stringMap = new LinkedHashMap<>(); + forEach( + (k, v) -> + stringMap.put(k, v instanceof GeaflowConfig ? JSON.toJSONString(v) : v.toString())); + return stringMap; + } + + public final T parse(Class clazz, boolean fillWithDefault) { + if (Modifier.isAbstract(clazz.getModifiers())) { + throw new GeaflowException("Config field abstract type {} can't be parsed", clazz); } - public Map toStringMap() { - Map stringMap = new LinkedHashMap<>(); - forEach((k, v) -> stringMap.put(k, v instanceof GeaflowConfig ? JSON.toJSONString(v) : v.toString())); - return stringMap; - } + GeaflowConfigDesc configDesc = ConfigDescFactory.getOrRegister(clazz); + configDesc.validateConfig(this); - public final T parse(Class clazz, boolean fillWithDefault) { - if (Modifier.isAbstract(clazz.getModifiers())) { - throw new GeaflowException("Config field abstract type {} can't be parsed", clazz); - } + try { + T instance = clazz.newInstance(); - GeaflowConfigDesc configDesc = ConfigDescFactory.getOrRegister(clazz); - configDesc.validateConfig(this); - - try { - T instance = clazz.newInstance(); - - Set usedKeys = new HashSet<>(); - for (ConfigDescItem item : configDesc.getItems()) { - String key = item.getKey(); - Object value = this.get(key); - if (value == null && fillWithDefault) { - value = item.getDefaultValue(); - } - if (GeaflowConfigType.CONFIG.equals(item.getType())) { - value = innerParse(item, value); - - } else { - // use json do auto convert - if (value != null) { - Class fieldType = item.getField().getType(); - if (!fieldType.isAssignableFrom(value.getClass())) { - value = JSON.parseObject(JSON.toJSONString(value), fieldType); - } - } - } - - usedKeys.add(key); - item.getField().set(instance, value); + Set usedKeys = new HashSet<>(); + for (ConfigDescItem item : configDesc.getItems()) { + String key = item.getKey(); + Object value = this.get(key); + if (value == null && fillWithDefault) { + value = item.getDefaultValue(); + } + if (GeaflowConfigType.CONFIG.equals(item.getType())) { + value = innerParse(item, value); + + } else { + // use json do auto convert + if (value != null) { + Class fieldType = item.getField().getType(); + if (!fieldType.isAssignableFrom(value.getClass())) { + value = JSON.parseObject(JSON.toJSONString(value), fieldType); } + } + } - this.forEach((k, v) -> { - if (!usedKeys.contains(k)) { - instance.getExtendConfig().put(k, v); - } - }); + usedKeys.add(key); + item.getField().set(instance, value); + } - return instance; + this.forEach( + (k, v) -> { + if (!usedKeys.contains(k)) { + instance.getExtendConfig().put(k, v); + } + }); - } catch (Exception e) { - throw new GeaflowException("Parse config with {} failed", clazz.getSimpleName(), e); - } - } + return instance; - public final T parse(Class clazz) { - return parse(clazz, false); + } catch (Exception e) { + throw new GeaflowException("Parse config with {} failed", clazz.getSimpleName(), e); } - - private Object innerParse(ConfigDescItem item, Object value) { - ConfigValueBehavior behavior = item.getBehavior(); - Class innerConfigClass = item.getInnerConfigDesc().getClazz(); - switch (behavior) { - case NESTED: - return new GeaflowConfig(value).parse(innerConfigClass); - case FLATTED: - throw new GeaflowException("Unsupported FLATTED behavior for {}", item.getKey()); - case JSON: - return new GeaflowConfig(JSON.parseObject((String) value)).parse(innerConfigClass); - default: - throw new GeaflowException("Unsupported config value behavior {}", behavior); - } + } + + public final T parse(Class clazz) { + return parse(clazz, false); + } + + private Object innerParse(ConfigDescItem item, Object value) { + ConfigValueBehavior behavior = item.getBehavior(); + Class innerConfigClass = item.getInnerConfigDesc().getClazz(); + switch (behavior) { + case NESTED: + return new GeaflowConfig(value).parse(innerConfigClass); + case FLATTED: + throw new GeaflowException("Unsupported FLATTED behavior for {}", item.getKey()); + case JSON: + return new GeaflowConfig(JSON.parseObject((String) value)).parse(innerConfigClass); + default: + throw new GeaflowException("Unsupported config value behavior {}", behavior); } + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfigClass.java index c32c327a5..fda887f53 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfigClass.java @@ -19,75 +19,81 @@ package org.apache.geaflow.console.core.model.config; +import org.apache.geaflow.console.common.util.exception.GeaflowException; + import com.alibaba.fastjson.JSON; + import lombok.Getter; -import org.apache.geaflow.console.common.util.exception.GeaflowException; @Getter public abstract class GeaflowConfigClass { - private final GeaflowConfig extendConfig = new GeaflowConfig(); + private final GeaflowConfig extendConfig = new GeaflowConfig(); - public final GeaflowConfig build() { - GeaflowConfigDesc configDesc = ConfigDescFactory.getOrRegister(getClass()); + public final GeaflowConfig build() { + GeaflowConfigDesc configDesc = ConfigDescFactory.getOrRegister(getClass()); - try { - GeaflowConfig config = new GeaflowConfig(); - for (ConfigDescItem item : configDesc.getItems()) { - String key = item.getKey(); - Object value = item.getField().get(this); - if (value == null) { - continue; - } + try { + GeaflowConfig config = new GeaflowConfig(); + for (ConfigDescItem item : configDesc.getItems()) { + String key = item.getKey(); + Object value = item.getField().get(this); + if (value == null) { + continue; + } - if (GeaflowConfigType.CONFIG.equals(item.getType())) { - // build inner config - buildInner(config, item, (GeaflowConfigClass) value); + if (GeaflowConfigType.CONFIG.equals(item.getType())) { + // build inner config + buildInner(config, item, (GeaflowConfigClass) value); - } else { - config.put(key, value); - } - } + } else { + config.put(key, value); + } + } - // add extend config - config.putAll(this.getExtendConfig()); + // add extend config + config.putAll(this.getExtendConfig()); - // validate config - configDesc.validateConfig(config); + // validate config + configDesc.validateConfig(config); - return config; + return config; - } catch (Exception e) { - throw new GeaflowException("Build config of {} instance failed", getClass().getName(), e); - } + } catch (Exception e) { + throw new GeaflowException("Build config of {} instance failed", getClass().getName(), e); } - - private void buildInner(GeaflowConfig config, ConfigDescItem item, GeaflowConfigClass innerConfigClass) { - String key = item.getKey(); - ConfigValueBehavior behavior = item.getBehavior(); - - GeaflowConfig innerConfig = innerConfigClass.build(); - - // process behavior - switch (behavior) { - case NESTED: - config.put(key, innerConfig); - break; - case FLATTED: - config.putAll(innerConfig); - break; - case JSON: - GeaflowConfigDesc innerConfigDesc = ConfigDescFactory.getOrRegister(innerConfigClass.getClass()); - innerConfigDesc.getItems().forEach(innerItem -> { - if (innerItem.isJsonIgnore()) { - innerConfig.remove(innerItem.getKey()); - } + } + + private void buildInner( + GeaflowConfig config, ConfigDescItem item, GeaflowConfigClass innerConfigClass) { + String key = item.getKey(); + ConfigValueBehavior behavior = item.getBehavior(); + + GeaflowConfig innerConfig = innerConfigClass.build(); + + // process behavior + switch (behavior) { + case NESTED: + config.put(key, innerConfig); + break; + case FLATTED: + config.putAll(innerConfig); + break; + case JSON: + GeaflowConfigDesc innerConfigDesc = + ConfigDescFactory.getOrRegister(innerConfigClass.getClass()); + innerConfigDesc + .getItems() + .forEach( + innerItem -> { + if (innerItem.isJsonIgnore()) { + innerConfig.remove(innerItem.getKey()); + } }); - config.put(key, JSON.toJSONString(innerConfig)); - break; - default: - throw new GeaflowException("Unsupported config value behavior {}", behavior); - } + config.put(key, JSON.toJSONString(innerConfig)); + break; + default: + throw new GeaflowException("Unsupported config value behavior {}", behavior); } - + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfigDesc.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfigDesc.java index 6bfa6307b..5badac66b 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfigDesc.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfigDesc.java @@ -19,105 +19,111 @@ package org.apache.geaflow.console.core.model.config; -import com.alibaba.fastjson.JSON; import java.lang.reflect.Field; import java.util.List; import java.util.stream.Collectors; -import lombok.Getter; + import org.apache.geaflow.console.common.util.ReflectUtil; import org.apache.geaflow.console.common.util.exception.GeaflowException; +import com.alibaba.fastjson.JSON; + +import lombok.Getter; + @Getter public class GeaflowConfigDesc { - private final Class clazz; + private final Class clazz; + + private final List items; + + public GeaflowConfigDesc(Class clazz) { + this.clazz = clazz; + this.items = parseItems(clazz); + } + + private static List parseItems(Class clazz) { + List fields = + ReflectUtil.getFields( + clazz, GeaflowConfigClass.class, f -> f.getAnnotation(GeaflowConfigKey.class) != null); + return fields.stream().map(ConfigDescItem::new).collect(Collectors.toList()); + } + + public void validateConfig(GeaflowConfig config) { + for (ConfigDescItem item : items) { + String key = item.getKey(); + Object value = config.get(key); + + if (value != null) { + GeaflowConfigType valueType = GeaflowConfigType.of(value.getClass()); + GeaflowConfigType expectedType = + ConfigValueBehavior.JSON.equals(item.getBehavior()) + ? GeaflowConfigType.STRING + : item.getType(); + + if (!expectedType.equals(valueType)) { + try { + if (GeaflowConfigType.CONFIG.equals(item.getType())) { + throw new GeaflowException("Type [CONFIG] compatible check is not allowed"); + } + JSON.parseObject(JSON.toJSONString(value), item.getField().getType()); - private final List items; + } catch (Exception e) { + throw new GeaflowException( + "Config key {} type {} not allowed, expected {}", key, valueType, expectedType, e); + } + } - public GeaflowConfigDesc(Class clazz) { - this.clazz = clazz; - this.items = parseItems(clazz); - } + if (GeaflowConfigType.CONFIG.equals(item.getType())) { + validateInnerValue(config, item, value); + } - private static List parseItems(Class clazz) { - List fields = ReflectUtil.getFields(clazz, GeaflowConfigClass.class, - f -> f.getAnnotation(GeaflowConfigKey.class) != null); - return fields.stream().map(ConfigDescItem::new).collect(Collectors.toList()); - } + } else { + if (GeaflowConfigType.CONFIG.equals(item.getType())) { + validateInnerRequired(config, item); - public void validateConfig(GeaflowConfig config) { - for (ConfigDescItem item : items) { - String key = item.getKey(); - Object value = config.get(key); - - if (value != null) { - GeaflowConfigType valueType = GeaflowConfigType.of(value.getClass()); - GeaflowConfigType expectedType = - ConfigValueBehavior.JSON.equals(item.getBehavior()) ? GeaflowConfigType.STRING : item.getType(); - - if (!expectedType.equals(valueType)) { - try { - if (GeaflowConfigType.CONFIG.equals(item.getType())) { - throw new GeaflowException("Type [CONFIG] compatible check is not allowed"); - } - JSON.parseObject(JSON.toJSONString(value), item.getField().getType()); - - } catch (Exception e) { - throw new GeaflowException("Config key {} type {} not allowed, expected {}", key, valueType, - expectedType, e); - } - } - - if (GeaflowConfigType.CONFIG.equals(item.getType())) { - validateInnerValue(config, item, value); - } - - } else { - if (GeaflowConfigType.CONFIG.equals(item.getType())) { - validateInnerRequired(config, item); - - } else { - if (item.isRequired()) { - throw new GeaflowException("Config key {} is required", key); - } - } - } + } else { + if (item.isRequired()) { + throw new GeaflowException("Config key {} is required", key); + } } + } } - - private void validateInnerValue(GeaflowConfig config, ConfigDescItem item, Object value) { - ConfigValueBehavior behavior = item.getBehavior(); - switch (behavior) { - case NESTED: - item.getInnerConfigDesc().validateConfig(new GeaflowConfig(value)); - break; - case FLATTED: - item.getInnerConfigDesc().validateConfig(config); - break; - case JSON: - item.getInnerConfigDesc().validateConfig(new GeaflowConfig(JSON.parseObject((String) value))); - break; - default: - throw new GeaflowException("Unsupported config value behavior {}", behavior); - } + } + + private void validateInnerValue(GeaflowConfig config, ConfigDescItem item, Object value) { + ConfigValueBehavior behavior = item.getBehavior(); + switch (behavior) { + case NESTED: + item.getInnerConfigDesc().validateConfig(new GeaflowConfig(value)); + break; + case FLATTED: + item.getInnerConfigDesc().validateConfig(config); + break; + case JSON: + item.getInnerConfigDesc() + .validateConfig(new GeaflowConfig(JSON.parseObject((String) value))); + break; + default: + throw new GeaflowException("Unsupported config value behavior {}", behavior); } - - private void validateInnerRequired(GeaflowConfig config, ConfigDescItem item) { - String key = item.getKey(); - ConfigValueBehavior behavior = item.getBehavior(); - switch (behavior) { - case NESTED: - case JSON: - if (item.isRequired()) { - throw new GeaflowException("Config key {} is required", key); - } - break; - case FLATTED: - item.getInnerConfigDesc().validateConfig(config); - break; - default: - throw new GeaflowException("Unsupported config value behavior {}", behavior); + } + + private void validateInnerRequired(GeaflowConfig config, ConfigDescItem item) { + String key = item.getKey(); + ConfigValueBehavior behavior = item.getBehavior(); + switch (behavior) { + case NESTED: + case JSON: + if (item.isRequired()) { + throw new GeaflowException("Config key {} is required", key); } + break; + case FLATTED: + item.getInnerConfigDesc().validateConfig(config); + break; + default: + throw new GeaflowException("Unsupported config value behavior {}", behavior); } - + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfigKey.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfigKey.java index 2ea0e54db..da0ae176e 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfigKey.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfigKey.java @@ -30,18 +30,12 @@ @Documented public @interface GeaflowConfigKey { - /** - * config key. - */ - String value(); + /** config key. */ + String value(); - /** - * config comment, display text. - */ - String comment() default ""; + /** config comment, display text. */ + String comment() default ""; - /** - * ignore target field when config class used as json behavior field type. - */ - boolean jsonIgnore() default false; + /** ignore target field when config class used as json behavior field type. */ + boolean jsonIgnore() default false; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfigType.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfigType.java index e05a576b6..e4c920488 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfigType.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfigType.java @@ -19,38 +19,39 @@ package org.apache.geaflow.console.core.model.config; -import com.google.common.collect.Sets; import java.math.BigDecimal; import java.util.Map; import java.util.Set; + import org.apache.geaflow.console.common.util.exception.GeaflowException; -public enum GeaflowConfigType { +import com.google.common.collect.Sets; - BOOLEAN(Boolean.class), +public enum GeaflowConfigType { + BOOLEAN(Boolean.class), - LONG(Long.class, Integer.class, Short.class), + LONG(Long.class, Integer.class, Short.class), - DOUBLE(BigDecimal.class, Double.class, Float.class), + DOUBLE(BigDecimal.class, Double.class, Float.class), - STRING(String.class, Enum.class), + STRING(String.class, Enum.class), - CONFIG(GeaflowConfigClass.class, Map.class); + CONFIG(GeaflowConfigClass.class, Map.class); - private final Set> javaClasses; + private final Set> javaClasses; - GeaflowConfigType(Class... javaClasses) { - this.javaClasses = Sets.newHashSet(javaClasses); - } + GeaflowConfigType(Class... javaClasses) { + this.javaClasses = Sets.newHashSet(javaClasses); + } - public static GeaflowConfigType of(Class clazz) { - for (GeaflowConfigType value : values()) { - for (Class javaClass : value.javaClasses) { - if (javaClass.isAssignableFrom(clazz)) { - return value; - } - } + public static GeaflowConfigType of(Class clazz) { + for (GeaflowConfigType value : values()) { + for (Class javaClass : value.javaClasses) { + if (javaClass.isAssignableFrom(clazz)) { + return value; } - throw new GeaflowException("Unsupported config type {}", clazz.getSimpleName()); + } } + throw new GeaflowException("Unsupported config type {}", clazz.getSimpleName()); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfigValue.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfigValue.java index dc08cbe16..93a590afe 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfigValue.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowConfigValue.java @@ -30,24 +30,15 @@ @Documented public @interface GeaflowConfigValue { - /** - * config value is required. - */ - boolean required() default false; + /** config value is required. */ + boolean required() default false; - /** - * config default value string. - */ - String defaultValue() default ""; + /** config default value string. */ + String defaultValue() default ""; - /** - * sensitive config value, need masked display. - */ - boolean masked() default false; - - /** - * how config value integrating into map. - */ - ConfigValueBehavior behavior() default ConfigValueBehavior.NESTED; + /** sensitive config value, need masked display. */ + boolean masked() default false; + /** how config value integrating into map. */ + ConfigValueBehavior behavior() default ConfigValueBehavior.NESTED; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowSystemConfig.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowSystemConfig.java index 423850f16..1c4f465ff 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowSystemConfig.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/GeaflowSystemConfig.java @@ -19,14 +19,14 @@ package org.apache.geaflow.console.core.model.config; +import org.apache.geaflow.console.core.model.GeaflowName; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.core.model.GeaflowName; @Getter @Setter public class GeaflowSystemConfig extends GeaflowName { - private String value; - + private String value; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/SystemConfigKeys.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/SystemConfigKeys.java index 4b14fe09f..063ac9e19 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/SystemConfigKeys.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/config/SystemConfigKeys.java @@ -21,6 +21,5 @@ public class SystemConfigKeys { - public static final String GEAFLOW_INITIALIZED = "geaflow.initialized"; - + public static final String GEAFLOW_INITIALIZED = "geaflow.initialized"; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowData.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowData.java index b9adb7def..130497f5e 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowData.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowData.java @@ -19,21 +19,23 @@ package org.apache.geaflow.console.core.model.data; +import org.apache.commons.lang3.StringUtils; +import org.apache.geaflow.console.core.model.GeaflowName; + import com.google.common.base.Preconditions; + import lombok.Getter; import lombok.Setter; -import org.apache.commons.lang3.StringUtils; -import org.apache.geaflow.console.core.model.GeaflowName; @Getter @Setter public abstract class GeaflowData extends GeaflowName { - protected String instanceId; + protected String instanceId; - @Override - public void validate() { - super.validate(); - Preconditions.checkArgument(StringUtils.isNotBlank(name), "Invalid name"); - } + @Override + public void validate() { + super.validate(); + Preconditions.checkArgument(StringUtils.isNotBlank(name), "Invalid name"); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowEdge.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowEdge.java index dac2525ee..454d6c98e 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowEdge.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowEdge.java @@ -23,50 +23,50 @@ import static org.apache.geaflow.console.common.util.type.GeaflowFieldCategory.EDGE_TARGET_ID; import static org.apache.geaflow.console.common.util.type.GeaflowFieldCategory.EDGE_TIMESTAMP; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.common.util.type.GeaflowFieldCategory; import org.apache.geaflow.console.common.util.type.GeaflowStructType; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class GeaflowEdge extends GeaflowStruct { - public GeaflowEdge() { - super(GeaflowStructType.EDGE); - } - - public GeaflowEdge(String name, String comment) { - this(); - super.name = name; - super.comment = comment; - } + public GeaflowEdge() { + super(GeaflowStructType.EDGE); + } - @Override - public void validate() { - super.validate(); + public GeaflowEdge(String name, String comment) { + this(); + super.name = name; + super.comment = comment; + } - int sourceIdCount = 0; - int targetIdCount = 0; - int tsCount = 0; - for (GeaflowField value : fields.values()) { - GeaflowFieldCategory category = value.getCategory(); - if (category == EDGE_SOURCE_ID) { - sourceIdCount++; - } + @Override + public void validate() { + super.validate(); - if (category == EDGE_TARGET_ID) { - targetIdCount++; - } + int sourceIdCount = 0; + int targetIdCount = 0; + int tsCount = 0; + for (GeaflowField value : fields.values()) { + GeaflowFieldCategory category = value.getCategory(); + if (category == EDGE_SOURCE_ID) { + sourceIdCount++; + } - if (category == EDGE_TIMESTAMP) { - tsCount++; - } - } + if (category == EDGE_TARGET_ID) { + targetIdCount++; + } - EDGE_SOURCE_ID.validate(sourceIdCount); - EDGE_TARGET_ID.validate(targetIdCount); - EDGE_TIMESTAMP.validate(tsCount); + if (category == EDGE_TIMESTAMP) { + tsCount++; + } } + EDGE_SOURCE_ID.validate(sourceIdCount); + EDGE_TARGET_ID.validate(targetIdCount); + EDGE_TIMESTAMP.validate(tsCount); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowEndpoint.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowEndpoint.java index 61271330e..47b8d18f6 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowEndpoint.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowEndpoint.java @@ -30,10 +30,9 @@ @AllArgsConstructor public class GeaflowEndpoint { - private String edgeId; + private String edgeId; - private String sourceId; - - private String targetId; + private String sourceId; + private String targetId; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowField.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowField.java index 2089f59b2..124cc9c6c 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowField.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowField.java @@ -19,36 +19,41 @@ package org.apache.geaflow.console.core.model.data; +import org.apache.geaflow.console.common.util.type.GeaflowFieldCategory; +import org.apache.geaflow.console.common.util.type.GeaflowFieldType; +import org.apache.geaflow.console.core.model.GeaflowName; + import com.google.common.base.Preconditions; + import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowFieldCategory; -import org.apache.geaflow.console.common.util.type.GeaflowFieldType; -import org.apache.geaflow.console.core.model.GeaflowName; @Getter @Setter @NoArgsConstructor -@EqualsAndHashCode(of = {"type", "category"}, callSuper = true) +@EqualsAndHashCode( + of = {"type", "category"}, + callSuper = true) public class GeaflowField extends GeaflowName { - private GeaflowFieldType type; + private GeaflowFieldType type; - private GeaflowFieldCategory category = GeaflowFieldCategory.PROPERTY; + private GeaflowFieldCategory category = GeaflowFieldCategory.PROPERTY; - public GeaflowField(String name, String comment, GeaflowFieldType type, GeaflowFieldCategory category) { - super.name = name; - super.comment = comment; - this.type = type; - this.category = category; - } + public GeaflowField( + String name, String comment, GeaflowFieldType type, GeaflowFieldCategory category) { + super.name = name; + super.comment = comment; + this.type = type; + this.category = category; + } - @Override - public void validate() { - super.validate(); - Preconditions.checkNotNull(type, "Invalid type"); - Preconditions.checkNotNull(category, "Invalid constraint"); - } + @Override + public void validate() { + super.validate(); + Preconditions.checkNotNull(type, "Invalid type"); + Preconditions.checkNotNull(category, "Invalid constraint"); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowFunction.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowFunction.java index 734ffbec3..78e38dcbe 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowFunction.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowFunction.java @@ -19,30 +19,32 @@ package org.apache.geaflow.console.core.model.data; +import org.apache.geaflow.console.core.model.file.GeaflowRemoteFile; + import com.google.common.base.Preconditions; + import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.geaflow.console.core.model.file.GeaflowRemoteFile; @Getter @Setter @NoArgsConstructor public class GeaflowFunction extends GeaflowData { - private GeaflowRemoteFile jarPackage; + private GeaflowRemoteFile jarPackage; - private String entryClass; + private String entryClass; - public GeaflowFunction(String name, String comment) { - super.name = name; - super.comment = comment; - } + public GeaflowFunction(String name, String comment) { + super.name = name; + super.comment = comment; + } - @Override - public void validate() { - super.validate(); - Preconditions.checkNotNull(jarPackage, "Invalid jarPackage"); - Preconditions.checkNotNull(entryClass, "Invalid entryClass"); - } + @Override + public void validate() { + super.validate(); + Preconditions.checkNotNull(jarPackage, "Invalid jarPackage"); + Preconditions.checkNotNull(entryClass, "Invalid entryClass"); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowGraph.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowGraph.java index 611e540c7..dff2dbead 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowGraph.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowGraph.java @@ -19,78 +19,81 @@ package org.apache.geaflow.console.core.model.data; -import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; + +import org.apache.commons.collections.MapUtils; +import org.apache.geaflow.console.core.model.plugin.config.GeaflowPluginConfig; + +import com.google.common.base.Preconditions; + import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.commons.collections.MapUtils; -import org.apache.geaflow.console.core.model.plugin.config.GeaflowPluginConfig; @Getter @Setter @NoArgsConstructor public class GeaflowGraph extends GeaflowData { - private final Map vertices = new LinkedHashMap<>(); - - private final Map edges = new LinkedHashMap<>(); - - private List endpoints = new ArrayList<>(); + private final Map vertices = new LinkedHashMap<>(); - private GeaflowPluginConfig pluginConfig; + private final Map edges = new LinkedHashMap<>(); - public GeaflowGraph(String name, String comment) { - super.name = name; - super.comment = comment; - } + private List endpoints = new ArrayList<>(); - public void addVertices(List vertices) { - for (GeaflowVertex vertex : vertices) { - this.vertices.put(vertex.getName(), vertex); - } - } + private GeaflowPluginConfig pluginConfig; - public void addEdges(List edges) { - for (GeaflowEdge edge : edges) { - this.edges.put(edge.getName(), edge); - } - } + public GeaflowGraph(String name, String comment) { + super.name = name; + super.comment = comment; + } - public void addVertex(GeaflowVertex vertex) { - vertices.put(vertex.getName(), vertex); + public void addVertices(List vertices) { + for (GeaflowVertex vertex : vertices) { + this.vertices.put(vertex.getName(), vertex); } + } - public void addEdge(GeaflowEdge edge) { - edges.put(edge.getName(), edge); + public void addEdges(List edges) { + for (GeaflowEdge edge : edges) { + this.edges.put(edge.getName(), edge); } - - - public void removeVertex(String vertexName) { - vertices.remove(vertexName); + } + + public void addVertex(GeaflowVertex vertex) { + vertices.put(vertex.getName(), vertex); + } + + public void addEdge(GeaflowEdge edge) { + edges.put(edge.getName(), edge); + } + + public void removeVertex(String vertexName) { + vertices.remove(vertexName); + } + + public void removeEdge(String edgeName) { + edges.remove(edgeName); + } + + @Override + public void validate() { + super.validate(); + Preconditions.checkNotNull(pluginConfig, "pluginConfig is null"); + Preconditions.checkArgument(MapUtils.isNotEmpty(vertices), "Graph needs at least one vertex"); + Preconditions.checkArgument(MapUtils.isNotEmpty(edges), "Graph needs at least one edge"); + } + + public int getShardCount() { + String shardCount = + (String) pluginConfig.getConfig().get("geaflow.dsl.graph.store.shard.count"); + if (shardCount != null) { + return Integer.parseInt(shardCount); } - public void removeEdge(String edgeName) { - edges.remove(edgeName); - } - - @Override - public void validate() { - super.validate(); - Preconditions.checkNotNull(pluginConfig, "pluginConfig is null"); - Preconditions.checkArgument(MapUtils.isNotEmpty(vertices), "Graph needs at least one vertex"); - Preconditions.checkArgument(MapUtils.isNotEmpty(edges), "Graph needs at least one edge"); - } - - public int getShardCount() { - String shardCount = (String) pluginConfig.getConfig().get("geaflow.dsl.graph.store.shard.count"); - if (shardCount != null) { - return Integer.parseInt(shardCount); - } - - return Integer.parseInt((String) pluginConfig.getConfig().getOrDefault("shardCount", "2")); - } + return Integer.parseInt((String) pluginConfig.getConfig().getOrDefault("shardCount", "2")); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowInstance.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowInstance.java index acbde8b83..ade7771eb 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowInstance.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowInstance.java @@ -19,12 +19,11 @@ package org.apache.geaflow.console.core.model.data; +import org.apache.geaflow.console.core.model.GeaflowName; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.core.model.GeaflowName; @Getter @Setter -public class GeaflowInstance extends GeaflowName { - -} +public class GeaflowInstance extends GeaflowName {} diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowStruct.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowStruct.java index c1e5ca5d6..a39d46309 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowStruct.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowStruct.java @@ -22,43 +22,45 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import lombok.Getter; -import lombok.Setter; + import org.apache.geaflow.console.common.util.exception.GeaflowException; import org.apache.geaflow.console.common.util.type.GeaflowStructType; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public abstract class GeaflowStruct extends GeaflowData { - protected final Map fields = new LinkedHashMap<>(); - protected GeaflowStructType type; + protected final Map fields = new LinkedHashMap<>(); + protected GeaflowStructType type; - public GeaflowStruct(GeaflowStructType type) { - this.type = type; - } + public GeaflowStruct(GeaflowStructType type) { + this.type = type; + } - public void addField(GeaflowField field) { - fields.put(field.getName(), field); - } + public void addField(GeaflowField field) { + fields.put(field.getName(), field); + } - public void addFields(List fields) { - for (GeaflowField field : fields) { - String fieldName = field.getName(); - if (this.fields.containsKey(fieldName)) { - throw new GeaflowException("Field name {} duplicated", fieldName); - } - this.fields.put(fieldName, field); - } + public void addFields(List fields) { + for (GeaflowField field : fields) { + String fieldName = field.getName(); + if (this.fields.containsKey(fieldName)) { + throw new GeaflowException("Field name {} duplicated", fieldName); + } + this.fields.put(fieldName, field); } + } - public void removeField(String name) { - this.fields.remove(name); - } + public void removeField(String name) { + this.fields.remove(name); + } - public void removeFields(List names) { - for (String name : names) { - this.fields.remove(name); - } + public void removeFields(List names) { + for (String name : names) { + this.fields.remove(name); } + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowTable.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowTable.java index 404abc377..0a25753ef 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowTable.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowTable.java @@ -19,31 +19,33 @@ package org.apache.geaflow.console.core.model.data; +import org.apache.geaflow.console.common.util.type.GeaflowStructType; +import org.apache.geaflow.console.core.model.plugin.config.GeaflowPluginConfig; + import com.google.common.base.Preconditions; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowStructType; -import org.apache.geaflow.console.core.model.plugin.config.GeaflowPluginConfig; @Getter @Setter public class GeaflowTable extends GeaflowStruct { - private GeaflowPluginConfig pluginConfig; + private GeaflowPluginConfig pluginConfig; - public GeaflowTable() { - super(GeaflowStructType.TABLE); - } + public GeaflowTable() { + super(GeaflowStructType.TABLE); + } - public GeaflowTable(String name, String comment) { - this(); - super.name = name; - super.comment = comment; - } + public GeaflowTable(String name, String comment) { + this(); + super.name = name; + super.comment = comment; + } - @Override - public void validate() { - super.validate(); - Preconditions.checkNotNull(pluginConfig, "pluginConfig is null"); - } + @Override + public void validate() { + super.validate(); + Preconditions.checkNotNull(pluginConfig, "pluginConfig is null"); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowVertex.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowVertex.java index 754b17d22..14246ab89 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowVertex.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowVertex.java @@ -21,37 +21,38 @@ import static org.apache.geaflow.console.common.util.type.GeaflowFieldCategory.VERTEX_ID; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.common.util.type.GeaflowFieldCategory; import org.apache.geaflow.console.common.util.type.GeaflowStructType; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class GeaflowVertex extends GeaflowStruct { - public GeaflowVertex() { - super(GeaflowStructType.VERTEX); + public GeaflowVertex() { + super(GeaflowStructType.VERTEX); + } + + public GeaflowVertex(String name, String comment) { + this(); + super.name = name; + super.comment = comment; + } + + @Override + public void validate() { + super.validate(); + + int idCount = 0; + for (GeaflowField value : fields.values()) { + GeaflowFieldCategory category = value.getCategory(); + if (category == VERTEX_ID) { + idCount++; + } } - public GeaflowVertex(String name, String comment) { - this(); - super.name = name; - super.comment = comment; - } - - @Override - public void validate() { - super.validate(); - - int idCount = 0; - for (GeaflowField value : fields.values()) { - GeaflowFieldCategory category = value.getCategory(); - if (category == VERTEX_ID) { - idCount++; - } - } - - VERTEX_ID.validate(idCount); - } + VERTEX_ID.validate(idCount); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowView.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowView.java index 8b08d5a5e..96bb5c14c 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowView.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/data/GeaflowView.java @@ -19,21 +19,22 @@ package org.apache.geaflow.console.core.model.data; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.common.util.type.GeaflowStructType; import org.apache.geaflow.console.common.util.type.GeaflowViewCategory; import org.apache.geaflow.console.core.model.code.GeaflowCode; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class GeaflowView extends GeaflowStruct { - private GeaflowViewCategory category = GeaflowViewCategory.LOGICAL; + private GeaflowViewCategory category = GeaflowViewCategory.LOGICAL; - private GeaflowCode code; + private GeaflowCode code; - public GeaflowView() { - super(GeaflowStructType.VIEW); - } + public GeaflowView() { + super(GeaflowStructType.VIEW); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/file/GeaflowRemoteFile.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/file/GeaflowRemoteFile.java index b0a8b8a81..4692aefbf 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/file/GeaflowRemoteFile.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/file/GeaflowRemoteFile.java @@ -19,50 +19,51 @@ package org.apache.geaflow.console.core.model.file; +import org.apache.commons.lang3.StringUtils; +import org.apache.geaflow.console.common.util.type.GeaflowFileType; +import org.apache.geaflow.console.core.model.GeaflowName; + import com.google.common.base.Preconditions; + import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.commons.lang3.StringUtils; -import org.apache.geaflow.console.common.util.type.GeaflowFileType; -import org.apache.geaflow.console.core.model.GeaflowName; @Getter @Setter @NoArgsConstructor public class GeaflowRemoteFile extends GeaflowName { - protected GeaflowFileType type; - - protected String path; + protected GeaflowFileType type; - protected String url; + protected String path; - protected String md5; + protected String url; - public GeaflowRemoteFile(GeaflowFileType type) { - this.type = type; - } - - public static boolean md5Equals(GeaflowRemoteFile left, GeaflowRemoteFile right) { - if (left == null && right == null) { - return true; - } + protected String md5; - if (left != null && right != null) { - return StringUtils.equals(left.md5, right.md5); - } + public GeaflowRemoteFile(GeaflowFileType type) { + this.type = type; + } - return false; + public static boolean md5Equals(GeaflowRemoteFile left, GeaflowRemoteFile right) { + if (left == null && right == null) { + return true; } - @Override - public void validate() { - super.validate(); - Preconditions.checkNotNull(type, "Invalid type"); - Preconditions.checkNotNull(path, "Invalid path"); - Preconditions.checkNotNull(url, "Invalid url"); - Preconditions.checkNotNull(md5, "Invalid md5"); + if (left != null && right != null) { + return StringUtils.equals(left.md5, right.md5); } + return false; + } + + @Override + public void validate() { + super.validate(); + Preconditions.checkNotNull(type, "Invalid type"); + Preconditions.checkNotNull(path, "Invalid path"); + Preconditions.checkNotNull(url, "Invalid url"); + Preconditions.checkNotNull(md5, "Invalid md5"); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/install/GeaflowInstall.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/install/GeaflowInstall.java index e8d13a9f0..63882bf9e 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/install/GeaflowInstall.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/install/GeaflowInstall.java @@ -19,36 +19,38 @@ package org.apache.geaflow.console.core.model.install; +import org.apache.geaflow.console.core.model.GeaflowId; +import org.apache.geaflow.console.core.model.plugin.config.GeaflowPluginConfig; + import com.google.common.base.Preconditions; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.core.model.GeaflowId; -import org.apache.geaflow.console.core.model.plugin.config.GeaflowPluginConfig; @Getter @Setter public class GeaflowInstall extends GeaflowId { - private GeaflowPluginConfig runtimeClusterConfig; + private GeaflowPluginConfig runtimeClusterConfig; - private GeaflowPluginConfig runtimeMetaConfig; + private GeaflowPluginConfig runtimeMetaConfig; - private GeaflowPluginConfig haMetaConfig; + private GeaflowPluginConfig haMetaConfig; - private GeaflowPluginConfig metricConfig; + private GeaflowPluginConfig metricConfig; - private GeaflowPluginConfig remoteFileConfig; + private GeaflowPluginConfig remoteFileConfig; - private GeaflowPluginConfig dataConfig; + private GeaflowPluginConfig dataConfig; - @Override - public void validate() { - super.validate(); - Preconditions.checkNotNull(runtimeClusterConfig, "Invalid runtimeClusterConfig"); - Preconditions.checkNotNull(runtimeMetaConfig, "Invalid runtimeMetaConfig"); - Preconditions.checkNotNull(haMetaConfig, "Invalid haMetaConfig"); - Preconditions.checkNotNull(metricConfig, "Invalid metricConfig"); - Preconditions.checkNotNull(remoteFileConfig, "Invalid remoteFileConfig"); - Preconditions.checkNotNull(dataConfig, "Invalid dataConfig"); - } + @Override + public void validate() { + super.validate(); + Preconditions.checkNotNull(runtimeClusterConfig, "Invalid runtimeClusterConfig"); + Preconditions.checkNotNull(runtimeMetaConfig, "Invalid runtimeMetaConfig"); + Preconditions.checkNotNull(haMetaConfig, "Invalid haMetaConfig"); + Preconditions.checkNotNull(metricConfig, "Invalid metricConfig"); + Preconditions.checkNotNull(remoteFileConfig, "Invalid remoteFileConfig"); + Preconditions.checkNotNull(dataConfig, "Invalid dataConfig"); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowAnalysisJob.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowAnalysisJob.java index 389d57203..75fef0100 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowAnalysisJob.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowAnalysisJob.java @@ -23,8 +23,7 @@ public abstract class GeaflowAnalysisJob extends GeaflowApiJob { - public GeaflowAnalysisJob(GeaflowJobType type) { - super(type); - } - + public GeaflowAnalysisJob(GeaflowJobType type) { + super(type); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowApiJob.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowApiJob.java index 20c5e1c05..4416a7a0c 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowApiJob.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowApiJob.java @@ -21,43 +21,43 @@ import java.util.ArrayList; import java.util.List; -import lombok.Getter; -import lombok.Setter; + import org.apache.geaflow.console.common.util.type.GeaflowJobType; import org.apache.geaflow.console.core.model.code.GeaflowCode; import org.apache.geaflow.console.core.model.data.GeaflowFunction; import org.apache.geaflow.console.core.model.file.GeaflowRemoteFile; import org.apache.geaflow.console.core.model.job.GeaflowTransferJob.StructMapping; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public abstract class GeaflowApiJob extends GeaflowJob { - @Setter - protected GeaflowRemoteFile jarPackage; + @Setter protected GeaflowRemoteFile jarPackage; - @Setter - private String entryClass; + @Setter private String entryClass; - public GeaflowApiJob(GeaflowJobType type) { - super(type); - } + public GeaflowApiJob(GeaflowJobType type) { + super(type); + } - public List getFunctions() { - return functions; - } + public List getFunctions() { + return functions; + } - public List getStructMappings() { - return new ArrayList<>(); - } + public List getStructMappings() { + return new ArrayList<>(); + } - @Override - public boolean isApiJob() { - return true; - } + @Override + public boolean isApiJob() { + return true; + } - @Override - public GeaflowCode getUserCode() { - return null; - } + @Override + public GeaflowCode getUserCode() { + return null; + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowCodeJob.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowCodeJob.java index 99de39ce7..8a3c361fb 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowCodeJob.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowCodeJob.java @@ -20,39 +20,40 @@ package org.apache.geaflow.console.core.model.job; import java.util.Optional; + import org.apache.geaflow.console.common.util.type.GeaflowJobType; import org.apache.geaflow.console.core.model.code.GeaflowCode; import org.apache.geaflow.console.core.model.file.GeaflowRemoteFile; public abstract class GeaflowCodeJob extends GeaflowJob { - protected GeaflowCode userCode; + protected GeaflowCode userCode; - public GeaflowCodeJob(GeaflowJobType type) { - super(type); - } + public GeaflowCodeJob(GeaflowJobType type) { + super(type); + } - @Override - public GeaflowRemoteFile getJarPackage() { - return null; - } + @Override + public GeaflowRemoteFile getJarPackage() { + return null; + } - @Override - public String getEntryClass() { - return null; - } + @Override + public String getEntryClass() { + return null; + } - @Override - public boolean isApiJob() { - return false; - } + @Override + public boolean isApiJob() { + return false; + } - @Override - public GeaflowCode getUserCode() { - return userCode; - } + @Override + public GeaflowCode getUserCode() { + return userCode; + } - public void setUserCode(String code) { - this.userCode = Optional.ofNullable(code).map(e -> new GeaflowCode(code)).orElse(null); - } + public void setUserCode(String code) { + this.userCode = Optional.ofNullable(code).map(e -> new GeaflowCode(code)).orElse(null); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowCustomJob.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowCustomJob.java index 2b0d93b0d..51386fc44 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowCustomJob.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowCustomJob.java @@ -23,8 +23,7 @@ public class GeaflowCustomJob extends GeaflowApiJob { - public GeaflowCustomJob() { - super(GeaflowJobType.CUSTOM); - } - + public GeaflowCustomJob() { + super(GeaflowJobType.CUSTOM); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowDistributeJob.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowDistributeJob.java index 8873fb15f..f327ade70 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowDistributeJob.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowDistributeJob.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.core.model.job; import java.util.List; + import org.apache.geaflow.console.common.util.type.GeaflowJobType; import org.apache.geaflow.console.common.util.type.GeaflowStructType; import org.apache.geaflow.console.core.model.code.GeaflowCode; @@ -29,19 +30,22 @@ public class GeaflowDistributeJob extends GeaflowTransferJob { + public GeaflowDistributeJob() { + super(GeaflowJobType.DISTRIBUTE); + } - public GeaflowDistributeJob() { - super(GeaflowJobType.DISTRIBUTE); - } - - public void fromGraphToTable(GeaflowGraph graph, GeaflowStructType type, String name, GeaflowTable table, - List fieldMapping) { - GeaflowStruct struct = super.importGraphStruct(graph, type, name); - super.addStructMapping(struct, table, fieldMapping); - } + public void fromGraphToTable( + GeaflowGraph graph, + GeaflowStructType type, + String name, + GeaflowTable table, + List fieldMapping) { + GeaflowStruct struct = super.importGraphStruct(graph, type, name); + super.addStructMapping(struct, table, fieldMapping); + } - @Override - public GeaflowCode generateCode() { - return null; - } + @Override + public GeaflowCode generateCode() { + return null; + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowIntegrateJob.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowIntegrateJob.java index edd115563..ae79ccf40 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowIntegrateJob.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowIntegrateJob.java @@ -19,7 +19,6 @@ package org.apache.geaflow.console.core.model.job; -import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -27,6 +26,7 @@ import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; + import org.apache.commons.collections.CollectionUtils; import org.apache.geaflow.console.common.util.ListUtil; import org.apache.geaflow.console.common.util.VelocityUtil; @@ -42,119 +42,157 @@ import org.apache.geaflow.console.core.model.data.GeaflowTable; import org.apache.geaflow.console.core.model.data.GeaflowView; -public class GeaflowIntegrateJob extends GeaflowTransferJob { - - private static final String TEMPLATE = "template/integration.vm"; - - public GeaflowIntegrateJob() { - super(GeaflowJobType.INTEGRATE); - } - - public void fromTableToGraph(GeaflowTable table, GeaflowGraph graph, GeaflowStructType type, String name, - List fieldMapping) { - super.addStructMapping(table, super.importGraphStruct(graph, type, name), fieldMapping); - } +import com.google.common.base.Preconditions; - public void fromViewToGraph(GeaflowView view, GeaflowGraph graph, GeaflowStructType type, String name, - Map fieldMapping) { - throw new GeaflowException("Unsupported operation"); - } +public class GeaflowIntegrateJob extends GeaflowTransferJob { - @Override - public GeaflowCode generateCode() { - Map velocityMap = new HashMap<>(); - List graphs = getGraphs(); - String graphName = graphs.get(0).getName(); - List> insertList = new ArrayList<>(); - - for (StructMapping structMapping : structMappings) { - String tableName = structMapping.getTableName(); - String structName = structMapping.getStructName(); - Map tableInsertMap = new HashMap<>(); - List> structs = new ArrayList<>(); - - for (FieldMappingItem fieldMapping : structMapping.getFieldMappings()) { - HashMap map = new HashMap<>(); - map.put("structName", structName); - map.put("tableFieldName", fieldMapping.getTableFieldName()); - map.put("structFieldName", fieldMapping.getStructFieldName()); - structs.add(map); - } - - tableInsertMap.put("tableName", tableName); - tableInsertMap.put("structs", structs); - insertList.add(tableInsertMap); - } - velocityMap.put("graphName", graphName); - velocityMap.put("inserts", insertList); - String code = VelocityUtil.applyResource(TEMPLATE, velocityMap); - return new GeaflowCode(code); + private static final String TEMPLATE = "template/integration.vm"; + + public GeaflowIntegrateJob() { + super(GeaflowJobType.INTEGRATE); + } + + public void fromTableToGraph( + GeaflowTable table, + GeaflowGraph graph, + GeaflowStructType type, + String name, + List fieldMapping) { + super.addStructMapping(table, super.importGraphStruct(graph, type, name), fieldMapping); + } + + public void fromViewToGraph( + GeaflowView view, + GeaflowGraph graph, + GeaflowStructType type, + String name, + Map fieldMapping) { + throw new GeaflowException("Unsupported operation"); + } + + @Override + public GeaflowCode generateCode() { + Map velocityMap = new HashMap<>(); + List graphs = getGraphs(); + String graphName = graphs.get(0).getName(); + List> insertList = new ArrayList<>(); + + for (StructMapping structMapping : structMappings) { + String tableName = structMapping.getTableName(); + String structName = structMapping.getStructName(); + Map tableInsertMap = new HashMap<>(); + List> structs = new ArrayList<>(); + + for (FieldMappingItem fieldMapping : structMapping.getFieldMappings()) { + HashMap map = new HashMap<>(); + map.put("structName", structName); + map.put("tableFieldName", fieldMapping.getTableFieldName()); + map.put("structFieldName", fieldMapping.getStructFieldName()); + structs.add(map); + } + + tableInsertMap.put("tableName", tableName); + tableInsertMap.put("structs", structs); + insertList.add(tableInsertMap); } - - @Override - public void validate() { - super.validate(); - - List duplicates = checkDuplicates(structMappings, StructMapping::getStructName); - Preconditions.checkArgument(CollectionUtils.isEmpty(duplicates), - "Struct '%s' can only be integrated from one table.", String.join(",", duplicates)); - - GeaflowGraph graph = new ArrayList<>(graphs.values()).get(0); - Map graphStructMap = new HashMap<>(); - graphStructMap.putAll(graph.getEdges()); - graphStructMap.putAll(graph.getVertices()); - for (StructMapping structMapping : structMappings) { - List fieldMappings = structMapping.getFieldMappings(); - Preconditions.checkArgument(CollectionUtils.isNotEmpty(fieldMappings), - "No fieldMapping from '%s' to '%s'.", structMapping.getTableName(), structMapping.getStructName()); - duplicates = checkDuplicates(fieldMappings, FieldMappingItem::getStructFieldName); - Preconditions.checkArgument(CollectionUtils.isEmpty(duplicates), "Field '%s' can only be inserted once. (%s -> %s)", - String.join(",", duplicates), structMapping.getTableName(), structMapping.getStructName()); - - GeaflowStruct table = structs.get(structMapping.getTableName()); - GeaflowStruct struct = graphStructMap.get(structMapping.getStructName()); - - for (FieldMappingItem item : fieldMappings) { - String tableFieldName = item.getTableFieldName(); - String structFieldName = item.getStructFieldName(); - GeaflowField tableField = table.getFields().get(tableFieldName); - GeaflowField structField = struct.getFields().get(structFieldName); - Preconditions.checkNotNull(tableField, "Table '%s' has no field '%s'.", structMapping.getTableName(), tableFieldName); - Preconditions.checkNotNull(structField, "Struct '%s' has no field '%s'.", structMapping.getStructName(), structFieldName); - Preconditions.checkArgument(tableField.getType() == structField.getType(), - "Field type not match: %s (%s), %s (%s). (%s -> %s)", - tableFieldName, tableField.getType(), structFieldName, structField.getType(), - structMapping.getTableName(), structMapping.getStructName()); - } - - checkMetaFields(struct, fieldMappings); - } + velocityMap.put("graphName", graphName); + velocityMap.put("inserts", insertList); + String code = VelocityUtil.applyResource(TEMPLATE, velocityMap); + return new GeaflowCode(code); + } + + @Override + public void validate() { + super.validate(); + + List duplicates = checkDuplicates(structMappings, StructMapping::getStructName); + Preconditions.checkArgument( + CollectionUtils.isEmpty(duplicates), + "Struct '%s' can only be integrated from one table.", + String.join(",", duplicates)); + + GeaflowGraph graph = new ArrayList<>(graphs.values()).get(0); + Map graphStructMap = new HashMap<>(); + graphStructMap.putAll(graph.getEdges()); + graphStructMap.putAll(graph.getVertices()); + for (StructMapping structMapping : structMappings) { + List fieldMappings = structMapping.getFieldMappings(); + Preconditions.checkArgument( + CollectionUtils.isNotEmpty(fieldMappings), + "No fieldMapping from '%s' to '%s'.", + structMapping.getTableName(), + structMapping.getStructName()); + duplicates = checkDuplicates(fieldMappings, FieldMappingItem::getStructFieldName); + Preconditions.checkArgument( + CollectionUtils.isEmpty(duplicates), + "Field '%s' can only be inserted once. (%s -> %s)", + String.join(",", duplicates), + structMapping.getTableName(), + structMapping.getStructName()); + + GeaflowStruct table = structs.get(structMapping.getTableName()); + GeaflowStruct struct = graphStructMap.get(structMapping.getStructName()); + + for (FieldMappingItem item : fieldMappings) { + String tableFieldName = item.getTableFieldName(); + String structFieldName = item.getStructFieldName(); + GeaflowField tableField = table.getFields().get(tableFieldName); + GeaflowField structField = struct.getFields().get(structFieldName); + Preconditions.checkNotNull( + tableField, + "Table '%s' has no field '%s'.", + structMapping.getTableName(), + tableFieldName); + Preconditions.checkNotNull( + structField, + "Struct '%s' has no field '%s'.", + structMapping.getStructName(), + structFieldName); + Preconditions.checkArgument( + tableField.getType() == structField.getType(), + "Field type not match: %s (%s), %s (%s). (%s -> %s)", + tableFieldName, + tableField.getType(), + structFieldName, + structField.getType(), + structMapping.getTableName(), + structMapping.getStructName()); + } + + checkMetaFields(struct, fieldMappings); } + } - private List checkDuplicates(List list, Function nameFunction) { - List listNames = ListUtil.convert(list, nameFunction); - List distinctNames = listNames.stream().distinct().collect(Collectors.toList()); - return ListUtil.diff(listNames, distinctNames); - } + private List checkDuplicates(List list, Function nameFunction) { + List listNames = ListUtil.convert(list, nameFunction); + List distinctNames = listNames.stream().distinct().collect(Collectors.toList()); + return ListUtil.diff(listNames, distinctNames); + } - private void checkMetaFields(GeaflowStruct struct, List fieldMappings) { - Set mappingCategories = fieldMappings.stream() + private void checkMetaFields(GeaflowStruct struct, List fieldMappings) { + Set mappingCategories = + fieldMappings.stream() .map(e -> struct.getFields().get(e.getStructFieldName()).getCategory()) .collect(Collectors.toSet()); - Set structCategories = struct.getFields().values().stream().map(GeaflowField::getCategory) + Set structCategories = + struct.getFields().values().stream() + .map(GeaflowField::getCategory) .collect(Collectors.toSet()); - // check required meta fields - List needFields = new ArrayList<>(); - for (GeaflowFieldCategory structCategory : structCategories) { - if ((structCategory.getNumConstraint() == NumConstraint.EXACTLY_ONCE - || structCategory.getNumConstraint() == NumConstraint.AT_MOST_ONCE) - && !mappingCategories.contains(structCategory)) { - needFields.add(structCategory.name()); - } - } - Preconditions.checkArgument(needFields.isEmpty(), "%s '%s' needs insert '%s' field", struct.getType(), struct.getName(), - String.join(",", needFields)); + // check required meta fields + List needFields = new ArrayList<>(); + for (GeaflowFieldCategory structCategory : structCategories) { + if ((structCategory.getNumConstraint() == NumConstraint.EXACTLY_ONCE + || structCategory.getNumConstraint() == NumConstraint.AT_MOST_ONCE) + && !mappingCategories.contains(structCategory)) { + needFields.add(structCategory.name()); + } } - + Preconditions.checkArgument( + needFields.isEmpty(), + "%s '%s' needs insert '%s' field", + struct.getType(), + struct.getName(), + String.join(",", needFields)); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowJob.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowJob.java index 0b6343e3b..22153da40 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowJob.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowJob.java @@ -19,13 +19,11 @@ package org.apache.geaflow.console.core.model.job; -import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import lombok.Getter; -import lombok.Setter; + import org.apache.geaflow.console.common.util.type.GeaflowJobType; import org.apache.geaflow.console.core.model.GeaflowName; import org.apache.geaflow.console.core.model.code.GeaflowCode; @@ -36,77 +34,76 @@ import org.apache.geaflow.console.core.model.job.GeaflowTransferJob.StructMapping; import org.apache.geaflow.console.core.model.plugin.GeaflowPlugin; -@Getter -public abstract class GeaflowJob extends GeaflowName { - - protected final Map structs = new LinkedHashMap<>(); - - protected final Map graphs = new LinkedHashMap<>(); - - protected List functions = new ArrayList<>(); +import com.google.common.base.Preconditions; - protected List plugins = new ArrayList<>(); +import lombok.Getter; +import lombok.Setter; - @Setter - protected GeaflowJobType type; +@Getter +public abstract class GeaflowJob extends GeaflowName { - @Setter - protected GeaflowJobSla sla; + protected final Map structs = new LinkedHashMap<>(); - @Setter - protected String instanceId; + protected final Map graphs = new LinkedHashMap<>(); - public GeaflowJob(GeaflowJobType type) { - this.type = type; - } + protected List functions = new ArrayList<>(); + protected List plugins = new ArrayList<>(); - public abstract boolean isApiJob(); + @Setter protected GeaflowJobType type; - public abstract GeaflowRemoteFile getJarPackage(); + @Setter protected GeaflowJobSla sla; - public abstract String getEntryClass(); + @Setter protected String instanceId; - public abstract List getFunctions(); + public GeaflowJob(GeaflowJobType type) { + this.type = type; + } - public abstract List getStructMappings(); + public abstract boolean isApiJob(); - public abstract GeaflowCode getUserCode(); + public abstract GeaflowRemoteFile getJarPackage(); - public List getGraphs() { - return new ArrayList<>(graphs.values()); - } + public abstract String getEntryClass(); - public List getStructs() { - return new ArrayList<>(structs.values()); - } + public abstract List getFunctions(); - public void setStructs(List structs) { - for (GeaflowStruct struct : structs) { - this.structs.put(struct.getName(), struct); - } - } + public abstract List getStructMappings(); + public abstract GeaflowCode getUserCode(); - public void setGraph(List graphs) { - for (GeaflowGraph graph : graphs) { - this.graphs.put(graph.getName(), graph); - } - } + public List getGraphs() { + return new ArrayList<>(graphs.values()); + } - public void setFunctions(List functions) { - this.functions = functions; - } + public List getStructs() { + return new ArrayList<>(structs.values()); + } - public void setPlugins(List plugins) { - this.plugins = plugins; + public void setStructs(List structs) { + for (GeaflowStruct struct : structs) { + this.structs.put(struct.getName(), struct); } + } - @Override - public void validate() { - super.validate(); - Preconditions.checkNotNull(type, "job type is null"); - Preconditions.checkNotNull(instanceId, "instanceId is null"); + public void setGraph(List graphs) { + for (GeaflowGraph graph : graphs) { + this.graphs.put(graph.getName(), graph); } - + } + + public void setFunctions(List functions) { + this.functions = functions; + } + + public void setPlugins(List plugins) { + this.plugins = plugins; + } + + @Override + public void validate() { + super.validate(); + Preconditions.checkNotNull(type, "job type is null"); + Preconditions.checkNotNull(instanceId, "instanceId is null"); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowJobSla.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowJobSla.java index 56fcda315..7f11e6051 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowJobSla.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowJobSla.java @@ -21,6 +21,4 @@ import org.apache.geaflow.console.core.model.GeaflowId; -public class GeaflowJobSla extends GeaflowId { - -} +public class GeaflowJobSla extends GeaflowId {} diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowProcessJob.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowProcessJob.java index 655998a66..61150e158 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowProcessJob.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowProcessJob.java @@ -19,37 +19,39 @@ package org.apache.geaflow.console.core.model.job; -import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.List; -import lombok.Getter; -import lombok.Setter; + import org.apache.geaflow.console.common.util.type.GeaflowJobType; import org.apache.geaflow.console.core.model.data.GeaflowFunction; import org.apache.geaflow.console.core.model.job.GeaflowTransferJob.StructMapping; +import com.google.common.base.Preconditions; + +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class GeaflowProcessJob extends GeaflowCodeJob { - - public GeaflowProcessJob() { - super(GeaflowJobType.PROCESS); - } - - @Override - public List getFunctions() { - return functions; - } - - @Override - public List getStructMappings() { - return new ArrayList<>(); - } - - @Override - public void validate() { - super.validate(); - Preconditions.checkNotNull(userCode, "user code is null"); - } + public GeaflowProcessJob() { + super(GeaflowJobType.PROCESS); + } + + @Override + public List getFunctions() { + return functions; + } + + @Override + public List getStructMappings() { + return new ArrayList<>(); + } + + @Override + public void validate() { + super.validate(); + Preconditions.checkNotNull(userCode, "user code is null"); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowServeJob.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowServeJob.java index b1cb29600..ebad3b364 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowServeJob.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowServeJob.java @@ -23,11 +23,11 @@ public class GeaflowServeJob extends GeaflowAnalysisJob { - private static final String SERVE_JOB_ENTRY_CLASS = "org.apache.geaflow.example.service.QueryService"; - - public GeaflowServeJob() { - super(GeaflowJobType.SERVE); - setEntryClass(SERVE_JOB_ENTRY_CLASS); - } + private static final String SERVE_JOB_ENTRY_CLASS = + "org.apache.geaflow.example.service.QueryService"; + public GeaflowServeJob() { + super(GeaflowJobType.SERVE); + setEntryClass(SERVE_JOB_ENTRY_CLASS); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowStatJob.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowStatJob.java index f65d80c9a..3afabffc6 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowStatJob.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowStatJob.java @@ -23,7 +23,7 @@ public class GeaflowStatJob extends GeaflowAnalysisJob { - public GeaflowStatJob() { - super(GeaflowJobType.STAT); - } + public GeaflowStatJob() { + super(GeaflowJobType.STAT); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowTransferJob.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowTransferJob.java index 8b437a546..80172a81f 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowTransferJob.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/GeaflowTransferJob.java @@ -19,14 +19,9 @@ package org.apache.geaflow.console.core.model.job; -import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.List; -import lombok.AllArgsConstructor; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.NoArgsConstructor; -import lombok.Setter; + import org.apache.geaflow.console.common.util.exception.GeaflowException; import org.apache.geaflow.console.common.util.type.GeaflowJobType; import org.apache.geaflow.console.common.util.type.GeaflowStructType; @@ -36,86 +31,95 @@ import org.apache.geaflow.console.core.model.data.GeaflowStruct; import org.apache.geaflow.console.core.model.data.GeaflowTable; +import com.google.common.base.Preconditions; + +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + @Getter @Setter public abstract class GeaflowTransferJob extends GeaflowCodeJob { - protected List structMappings = new ArrayList<>(); - - public GeaflowTransferJob(GeaflowJobType type) { - super(type); - } - - - public abstract GeaflowCode generateCode(); - - @Override - public List getStructs() { - return new ArrayList<>(structs.values()); - } - - @Override - public List getFunctions() { - return new ArrayList<>(); - } - - public List getStructMappings() { - return structMappings; - } - - public void fromTableToTable(GeaflowTable input, GeaflowTable output, List fieldMapping) { - addStructMapping(input, output, fieldMapping); - } - - protected void addStructMapping(GeaflowStruct input, GeaflowStruct output, List fieldMapping) { - String inputName = input.getName(); - String outputName = output.getName(); - - //inputStructs.put(inputName, input); - //outputStructs.put(outputName, output); - structMappings.add(new StructMapping(inputName, outputName, fieldMapping)); + protected List structMappings = new ArrayList<>(); + + public GeaflowTransferJob(GeaflowJobType type) { + super(type); + } + + public abstract GeaflowCode generateCode(); + + @Override + public List getStructs() { + return new ArrayList<>(structs.values()); + } + + @Override + public List getFunctions() { + return new ArrayList<>(); + } + + public List getStructMappings() { + return structMappings; + } + + public void fromTableToTable( + GeaflowTable input, GeaflowTable output, List fieldMapping) { + addStructMapping(input, output, fieldMapping); + } + + protected void addStructMapping( + GeaflowStruct input, GeaflowStruct output, List fieldMapping) { + String inputName = input.getName(); + String outputName = output.getName(); + + // inputStructs.put(inputName, input); + // outputStructs.put(outputName, output); + structMappings.add(new StructMapping(inputName, outputName, fieldMapping)); + } + + protected GeaflowStruct importGraphStruct( + GeaflowGraph graph, GeaflowStructType type, String name) { + String graphName = graph.getName(); + graphs.put(graphName, graph); + Preconditions.checkArgument(graphs.size() == 1, "Only one graph supported"); + + GeaflowStruct struct; + switch (type) { + case VERTEX: + struct = graph.getVertices().get(name); + break; + case EDGE: + struct = graph.getEdges().get(name); + break; + default: + throw new GeaflowException("Struct type {} not allowed in graph", type); } - - protected GeaflowStruct importGraphStruct(GeaflowGraph graph, GeaflowStructType type, String name) { - String graphName = graph.getName(); - graphs.put(graphName, graph); - Preconditions.checkArgument(graphs.size() == 1, "Only one graph supported"); - - GeaflowStruct struct; - switch (type) { - case VERTEX: - struct = graph.getVertices().get(name); - break; - case EDGE: - struct = graph.getEdges().get(name); - break; - default: - throw new GeaflowException("Struct type {} not allowed in graph", type); - } - return struct; - } - - @Setter - @Getter - @AllArgsConstructor - @NoArgsConstructor - @EqualsAndHashCode(of = {"tableName", "structName"}) - public static class StructMapping { - - private String tableName; - private String structName; - private List fieldMappings = new ArrayList<>(); - } - - @Setter - @Getter - @AllArgsConstructor - @NoArgsConstructor - @EqualsAndHashCode(of = {"tableFieldName", "structFieldName"}) - public static class FieldMappingItem { - - private String tableFieldName; - private String structFieldName; - } - + return struct; + } + + @Setter + @Getter + @AllArgsConstructor + @NoArgsConstructor + @EqualsAndHashCode(of = {"tableName", "structName"}) + public static class StructMapping { + + private String tableName; + private String structName; + private List fieldMappings = new ArrayList<>(); + } + + @Setter + @Getter + @AllArgsConstructor + @NoArgsConstructor + @EqualsAndHashCode(of = {"tableFieldName", "structFieldName"}) + public static class FieldMappingItem { + + private String tableFieldName; + private String structFieldName; + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/ClusterArgsClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/ClusterArgsClass.java index 87e28c858..dac51088f 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/ClusterArgsClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/ClusterArgsClass.java @@ -19,19 +19,19 @@ package org.apache.geaflow.console.core.model.job.config; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.core.model.config.ConfigValueBehavior; import org.apache.geaflow.console.core.model.config.GeaflowConfigClass; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public abstract class ClusterArgsClass extends GeaflowConfigClass { - @GeaflowConfigKey(value = "taskClusterConfig", comment = "i18n.key.job.cluster.args") - @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) - private ClusterConfigClass taskClusterConfig; - + @GeaflowConfigKey(value = "taskClusterConfig", comment = "i18n.key.job.cluster.args") + @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) + private ClusterConfigClass taskClusterConfig; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/ClusterConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/ClusterConfigClass.java index d4c7d1256..4f02c7388 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/ClusterConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/ClusterConfigClass.java @@ -19,74 +19,81 @@ package org.apache.geaflow.console.core.model.job.config; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.core.model.config.GeaflowConfigClass; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class ClusterConfigClass extends GeaflowConfigClass { - @GeaflowConfigKey(value = "geaflow.container.num", comment = "i18n.key.container.count") - @GeaflowConfigValue(required = true) - private Integer containers; - - @GeaflowConfigKey(value = "geaflow.container.worker.num", comment = "i18n.key.container.worker.count") - @GeaflowConfigValue(required = true) - private Integer containerWorkers; - - @GeaflowConfigKey(value = "geaflow.container.memory.mb", comment = "i18n.key.container.memory.mb") - @GeaflowConfigValue(required = true, defaultValue = "256") - private Integer containerMemory; - - @GeaflowConfigKey(value = "geaflow.container.vcores", comment = "i18n.key.container.vcores") - @GeaflowConfigValue(required = true, defaultValue = "1") - private Double containerCores; - - @GeaflowConfigKey(value = "geaflow.container.jvm.options", comment = "i18n.key.container.jvm.args") - @GeaflowConfigValue(required = true) - private String containerJvmOptions; - - @GeaflowConfigKey(value = "geaflow.fo.enable", comment = "i18n.key.fo.enable") - @GeaflowConfigValue(defaultValue = "true") - private Boolean enableFo; - - @GeaflowConfigKey(value = "geaflow.client.memory.mb", comment = "i18n.key.client.memory.mb") - @GeaflowConfigValue(defaultValue = "1024") - private Integer clientMemory; - - @GeaflowConfigKey(value = "geaflow.master.memory.mb", comment = "i18n.key.master.memory.mb") - @GeaflowConfigValue(defaultValue = "4096") - private Integer masterMemory; - - @GeaflowConfigKey(value = "geaflow.driver.memory.mb", comment = "i18n.key.driver.memory.mb") - @GeaflowConfigValue(defaultValue = "4096") - private Integer driverMemory; - - @GeaflowConfigKey(value = "geaflow.client.vcores", comment = "i18n.key.client.vcores") - @GeaflowConfigValue(defaultValue = "1") - private Double clientCores; - - @GeaflowConfigKey(value = "geaflow.master.vcores", comment = "i18n.key.master.vcores") - @GeaflowConfigValue(defaultValue = "1") - private Double masterCores; - - @GeaflowConfigKey(value = "geaflow.driver.vcores", comment = "i18n.key.driver.vcores") - @GeaflowConfigValue(defaultValue = "1") - private Double driverCores; - - @GeaflowConfigKey(value = "geaflow.client.jvm.options", comment = "i18n.key.client.jvm.args") - @GeaflowConfigValue(defaultValue = "-Xmx1024m,-Xms1024m,-Xmn256m,-Xss256k,-XX:MaxDirectMemorySize=512m") - private String clientJvmOptions; - - @GeaflowConfigKey(value = "geaflow.master.jvm.options", comment = "i18n.key.master.jvm.args") - @GeaflowConfigValue(defaultValue = "-Xmx2048m,-Xms2048m,-Xmn512m,-Xss512k,-XX:MaxDirectMemorySize=1024m") - private String masterJvmOptions; - - @GeaflowConfigKey(value = "geaflow.driver.jvm.options", comment = "i18n.key.driver.jvm.args") - @GeaflowConfigValue(defaultValue = "-Xmx2048m,-Xms2048m,-Xmn512m,-Xss512k,-XX:MaxDirectMemorySize=1024m") - private String driverJvmOptions; - + @GeaflowConfigKey(value = "geaflow.container.num", comment = "i18n.key.container.count") + @GeaflowConfigValue(required = true) + private Integer containers; + + @GeaflowConfigKey( + value = "geaflow.container.worker.num", + comment = "i18n.key.container.worker.count") + @GeaflowConfigValue(required = true) + private Integer containerWorkers; + + @GeaflowConfigKey(value = "geaflow.container.memory.mb", comment = "i18n.key.container.memory.mb") + @GeaflowConfigValue(required = true, defaultValue = "256") + private Integer containerMemory; + + @GeaflowConfigKey(value = "geaflow.container.vcores", comment = "i18n.key.container.vcores") + @GeaflowConfigValue(required = true, defaultValue = "1") + private Double containerCores; + + @GeaflowConfigKey( + value = "geaflow.container.jvm.options", + comment = "i18n.key.container.jvm.args") + @GeaflowConfigValue(required = true) + private String containerJvmOptions; + + @GeaflowConfigKey(value = "geaflow.fo.enable", comment = "i18n.key.fo.enable") + @GeaflowConfigValue(defaultValue = "true") + private Boolean enableFo; + + @GeaflowConfigKey(value = "geaflow.client.memory.mb", comment = "i18n.key.client.memory.mb") + @GeaflowConfigValue(defaultValue = "1024") + private Integer clientMemory; + + @GeaflowConfigKey(value = "geaflow.master.memory.mb", comment = "i18n.key.master.memory.mb") + @GeaflowConfigValue(defaultValue = "4096") + private Integer masterMemory; + + @GeaflowConfigKey(value = "geaflow.driver.memory.mb", comment = "i18n.key.driver.memory.mb") + @GeaflowConfigValue(defaultValue = "4096") + private Integer driverMemory; + + @GeaflowConfigKey(value = "geaflow.client.vcores", comment = "i18n.key.client.vcores") + @GeaflowConfigValue(defaultValue = "1") + private Double clientCores; + + @GeaflowConfigKey(value = "geaflow.master.vcores", comment = "i18n.key.master.vcores") + @GeaflowConfigValue(defaultValue = "1") + private Double masterCores; + + @GeaflowConfigKey(value = "geaflow.driver.vcores", comment = "i18n.key.driver.vcores") + @GeaflowConfigValue(defaultValue = "1") + private Double driverCores; + + @GeaflowConfigKey(value = "geaflow.client.jvm.options", comment = "i18n.key.client.jvm.args") + @GeaflowConfigValue( + defaultValue = "-Xmx1024m,-Xms1024m,-Xmn256m,-Xss256k,-XX:MaxDirectMemorySize=512m") + private String clientJvmOptions; + + @GeaflowConfigKey(value = "geaflow.master.jvm.options", comment = "i18n.key.master.jvm.args") + @GeaflowConfigValue( + defaultValue = "-Xmx2048m,-Xms2048m,-Xmn512m,-Xss512k,-XX:MaxDirectMemorySize=1024m") + private String masterJvmOptions; + + @GeaflowConfigKey(value = "geaflow.driver.jvm.options", comment = "i18n.key.driver.jvm.args") + @GeaflowConfigValue( + defaultValue = "-Xmx2048m,-Xms2048m,-Xmn512m,-Xss512k,-XX:MaxDirectMemorySize=1024m") + private String driverJvmOptions; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/CodeJobConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/CodeJobConfigClass.java index 0b7d69aa0..2ae955a6c 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/CodeJobConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/CodeJobConfigClass.java @@ -19,17 +19,17 @@ package org.apache.geaflow.console.core.model.job.config; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; + @Setter @Getter public class CodeJobConfigClass extends JobConfigClass { - @GeaflowConfigKey(value = "geaflow.dsl.window.size", comment = "i18n.key.dsl.window.size") - @GeaflowConfigValue(defaultValue = "-1") - private Integer windowSize; - + @GeaflowConfigKey(value = "geaflow.dsl.window.size", comment = "i18n.key.dsl.window.size") + @GeaflowConfigValue(defaultValue = "-1") + private Integer windowSize; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/CompileContextClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/CompileContextClass.java index d472ce9e8..b1545cfea 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/CompileContextClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/CompileContextClass.java @@ -18,31 +18,32 @@ */ package org.apache.geaflow.console.core.model.job.config; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.core.model.config.GeaflowConfigClass; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class CompileContextClass extends GeaflowConfigClass { - @GeaflowConfigKey(value = "geaflow.dsl.catalog.token.key", comment = "i18n.key.api.token") - @GeaflowConfigValue(required = true) - private String tokenKey; - - @GeaflowConfigKey(value = "geaflow.dsl.catalog.instance.name", comment = "i18n.key.default.instance.name") - @GeaflowConfigValue(required = true) - private String instanceName; - - @GeaflowConfigKey(value = "geaflow.dsl.catalog.type", comment = "i18n.key.job.catalog.type") - @GeaflowConfigValue(required = true, defaultValue = "memory") - private String catalogType; + @GeaflowConfigKey(value = "geaflow.dsl.catalog.token.key", comment = "i18n.key.api.token") + @GeaflowConfigValue(required = true) + private String tokenKey; - @GeaflowConfigKey(value = "geaflow.gw.endpoint", comment = "i18n.key.k8s.server.url") - @GeaflowConfigValue(required = true) - private String endpoint; + @GeaflowConfigKey( + value = "geaflow.dsl.catalog.instance.name", + comment = "i18n.key.default.instance.name") + @GeaflowConfigValue(required = true) + private String instanceName; + @GeaflowConfigKey(value = "geaflow.dsl.catalog.type", comment = "i18n.key.job.catalog.type") + @GeaflowConfigValue(required = true, defaultValue = "memory") + private String catalogType; + @GeaflowConfigKey(value = "geaflow.gw.endpoint", comment = "i18n.key.k8s.server.url") + @GeaflowConfigValue(required = true) + private String endpoint; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/ContainerClusterArgsClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/ContainerClusterArgsClass.java index d3d999f27..e531527c2 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/ContainerClusterArgsClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/ContainerClusterArgsClass.java @@ -24,6 +24,4 @@ @Getter @Setter -public class ContainerClusterArgsClass extends ClusterArgsClass { - -} +public class ContainerClusterArgsClass extends ClusterArgsClass {} diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/GeaflowArgsClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/GeaflowArgsClass.java index 1f7e0fece..ce146ce8b 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/GeaflowArgsClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/GeaflowArgsClass.java @@ -19,26 +19,26 @@ package org.apache.geaflow.console.core.model.job.config; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.core.model.config.GeaflowConfigClass; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class GeaflowArgsClass extends GeaflowConfigClass { - @GeaflowConfigKey(value = "system", comment = "i18n.key.system.params") - @GeaflowConfigValue(required = true) - private SystemArgsClass systemArgs; - - @GeaflowConfigKey(value = "cluster", comment = "i18n.key.cluster.args") - @GeaflowConfigValue(required = true) - private ClusterArgsClass clusterArgs; + @GeaflowConfigKey(value = "system", comment = "i18n.key.system.params") + @GeaflowConfigValue(required = true) + private SystemArgsClass systemArgs; - @GeaflowConfigKey(value = "job", comment = "i18n.key.task.params") - @GeaflowConfigValue(required = true) - private JobArgsClass jobArgs; + @GeaflowConfigKey(value = "cluster", comment = "i18n.key.cluster.args") + @GeaflowConfigValue(required = true) + private ClusterArgsClass clusterArgs; + @GeaflowConfigKey(value = "job", comment = "i18n.key.task.params") + @GeaflowConfigValue(required = true) + private JobArgsClass jobArgs; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/HaMetaArgsClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/HaMetaArgsClass.java index 802f67ce8..56fd2b671 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/HaMetaArgsClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/HaMetaArgsClass.java @@ -19,9 +19,6 @@ package org.apache.geaflow.console.core.model.job.config; -import lombok.Getter; -import lombok.NoArgsConstructor; -import lombok.Setter; import org.apache.geaflow.console.common.util.exception.GeaflowIllegalException; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.config.ConfigValueBehavior; @@ -32,31 +29,36 @@ import org.apache.geaflow.console.core.model.plugin.config.PluginConfigClass; import org.apache.geaflow.console.core.model.plugin.config.RedisPluginConfigClass; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + @Getter @Setter @NoArgsConstructor public class HaMetaArgsClass extends GeaflowConfigClass { - @GeaflowConfigKey(value = "geaflow.ha.service.type", comment = "i18n.key.type") - @GeaflowConfigValue(required = true, defaultValue = "REDIS") - private GeaflowPluginType type; + @GeaflowConfigKey(value = "geaflow.ha.service.type", comment = "i18n.key.type") + @GeaflowConfigValue(required = true, defaultValue = "REDIS") + private GeaflowPluginType type; - @GeaflowConfigKey(value = "plugin", comment = "i18n.key.plugin.config") - @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) - private PluginConfigClass plugin; + @GeaflowConfigKey(value = "plugin", comment = "i18n.key.plugin.config") + @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) + private PluginConfigClass plugin; - public HaMetaArgsClass(GeaflowPluginConfig pluginConfig) { - this.type = GeaflowPluginType.of(pluginConfig.getType()); + public HaMetaArgsClass(GeaflowPluginConfig pluginConfig) { + this.type = GeaflowPluginType.of(pluginConfig.getType()); - Class configClass; - switch (type) { - case REDIS: - configClass = RedisPluginConfigClass.class; - break; - default: - throw new GeaflowIllegalException("Ha meta config type {} not supported", pluginConfig.getType()); - } - - this.plugin = pluginConfig.getConfig().parse(configClass); + Class configClass; + switch (type) { + case REDIS: + configClass = RedisPluginConfigClass.class; + break; + default: + throw new GeaflowIllegalException( + "Ha meta config type {} not supported", pluginConfig.getType()); } + + this.plugin = pluginConfig.getConfig().parse(configClass); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/JobArgsClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/JobArgsClass.java index 5068d78df..651445436 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/JobArgsClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/JobArgsClass.java @@ -19,29 +19,32 @@ package org.apache.geaflow.console.core.model.job.config; -import lombok.Getter; -import lombok.NoArgsConstructor; -import lombok.Setter; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.config.ConfigValueBehavior; import org.apache.geaflow.console.core.model.config.GeaflowConfigClass; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + @Getter @Setter @NoArgsConstructor public class JobArgsClass extends GeaflowConfigClass { - @GeaflowConfigKey(value = "geaflow.system.state.backend.type", comment = "i18n.key.state.storage.type") - @GeaflowConfigValue(required = true, defaultValue = "ROCKSDB") - private GeaflowPluginType systemStateType; + @GeaflowConfigKey( + value = "geaflow.system.state.backend.type", + comment = "i18n.key.state.storage.type") + @GeaflowConfigValue(required = true, defaultValue = "ROCKSDB") + private GeaflowPluginType systemStateType; - @GeaflowConfigKey(value = "jobConfig", comment = "i18n.key.task.user.params") - @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) - private JobConfigClass jobConfig; + @GeaflowConfigKey(value = "jobConfig", comment = "i18n.key.task.user.params") + @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) + private JobConfigClass jobConfig; - public JobArgsClass(JobConfigClass jobConfig) { - this.jobConfig = jobConfig; - } + public JobArgsClass(JobConfigClass jobConfig) { + this.jobConfig = jobConfig; + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/JobConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/JobConfigClass.java index 19250201f..a4e7285cd 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/JobConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/JobConfigClass.java @@ -19,12 +19,11 @@ package org.apache.geaflow.console.core.model.job.config; +import org.apache.geaflow.console.core.model.config.GeaflowConfigClass; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.core.model.config.GeaflowConfigClass; @Getter @Setter -public class JobConfigClass extends GeaflowConfigClass { - -} +public class JobConfigClass extends GeaflowConfigClass {} diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/K8SClusterArgsClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/K8SClusterArgsClass.java index 3952b01de..db2157195 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/K8SClusterArgsClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/K8SClusterArgsClass.java @@ -19,25 +19,25 @@ package org.apache.geaflow.console.core.model.job.config; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.core.model.config.ConfigValueBehavior; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; import org.apache.geaflow.console.core.model.plugin.config.K8sPluginConfigClass; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class K8SClusterArgsClass extends ClusterArgsClass { - @GeaflowConfigKey(value = "clusterConfig", comment = "i18n.key.k8s.cluster.config") - @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) - private K8sPluginConfigClass clusterConfig; - - @GeaflowConfigKey(value = "kubernetes.engine.jar.files", comment = "i18n.key.engine.jar.list") - private String engineJarUrls; + @GeaflowConfigKey(value = "clusterConfig", comment = "i18n.key.k8s.cluster.config") + @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) + private K8sPluginConfigClass clusterConfig; - @GeaflowConfigKey(value = "kubernetes.user.jar.files", comment = "i18n.key.task.jar.list") - private String taskJarUrls; + @GeaflowConfigKey(value = "kubernetes.engine.jar.files", comment = "i18n.key.engine.jar.list") + private String engineJarUrls; + @GeaflowConfigKey(value = "kubernetes.user.jar.files", comment = "i18n.key.task.jar.list") + private String taskJarUrls; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/K8sClientArgsClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/K8sClientArgsClass.java index 931464045..ad72113af 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/K8sClientArgsClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/K8sClientArgsClass.java @@ -19,54 +19,55 @@ package org.apache.geaflow.console.core.model.job.config; -import lombok.Getter; -import lombok.NoArgsConstructor; -import lombok.Setter; import org.apache.geaflow.console.common.util.exception.GeaflowException; import org.apache.geaflow.console.core.model.config.ConfigValueBehavior; import org.apache.geaflow.console.core.model.config.GeaflowConfigClass; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + @Getter @Setter @NoArgsConstructor public class K8sClientArgsClass extends GeaflowConfigClass { - @GeaflowConfigKey(value = "kubernetes.user.class.args", comment = "i18n.key.engine.params.json") - @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.JSON) - private GeaflowArgsClass geaflowArgs; + @GeaflowConfigKey(value = "kubernetes.user.class.args", comment = "i18n.key.engine.params.json") + @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.JSON) + private GeaflowArgsClass geaflowArgs; - @GeaflowConfigKey(value = "kubernetes.user.main.class", comment = "i18n.key.main.class") - @GeaflowConfigValue(required = true) - private String mainClass; + @GeaflowConfigKey(value = "kubernetes.user.main.class", comment = "i18n.key.main.class") + @GeaflowConfigValue(required = true) + private String mainClass; - @GeaflowConfigKey(value = "geaflow.job.cluster.id", comment = "i18n.key.running.job.id") - @GeaflowConfigValue(required = true) - private String runtimeTaskId; + @GeaflowConfigKey(value = "geaflow.job.cluster.id", comment = "i18n.key.running.job.id") + @GeaflowConfigValue(required = true) + private String runtimeTaskId; - @GeaflowConfigKey(value = "clusterArgs", comment = "i18n.key.cluster.args") - @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) - private K8SClusterArgsClass clusterArgs; + @GeaflowConfigKey(value = "clusterArgs", comment = "i18n.key.cluster.args") + @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) + private K8SClusterArgsClass clusterArgs; - @GeaflowConfigKey(value = "geaflow.gw.endpoint", comment = "i18n.key.k8s.server.url") - private String gateway; + @GeaflowConfigKey(value = "geaflow.gw.endpoint", comment = "i18n.key.k8s.server.url") + private String gateway; - @GeaflowConfigKey(value = "geaflow.dsl.catalog.token.key", comment = "i18n.key.api.token") - private String token; + @GeaflowConfigKey(value = "geaflow.dsl.catalog.token.key", comment = "i18n.key.api.token") + private String token; - public K8sClientArgsClass(GeaflowArgsClass geaflowArgs, String mainClass) { - this.geaflowArgs = geaflowArgs; - this.mainClass = mainClass; - this.gateway = geaflowArgs.getSystemArgs().getGateway(); - this.token = geaflowArgs.getSystemArgs().getTaskToken(); - this.runtimeTaskId = geaflowArgs.getSystemArgs().getRuntimeTaskId(); + public K8sClientArgsClass(GeaflowArgsClass geaflowArgs, String mainClass) { + this.geaflowArgs = geaflowArgs; + this.mainClass = mainClass; + this.gateway = geaflowArgs.getSystemArgs().getGateway(); + this.token = geaflowArgs.getSystemArgs().getTaskToken(); + this.runtimeTaskId = geaflowArgs.getSystemArgs().getRuntimeTaskId(); - ClusterArgsClass clusterArgs = geaflowArgs.getClusterArgs(); - if (clusterArgs instanceof K8SClusterArgsClass) { - this.clusterArgs = (K8SClusterArgsClass) clusterArgs; - } else { - throw new GeaflowException("Invalid clusterArgs type {}", clusterArgs.getClass()); - } + ClusterArgsClass clusterArgs = geaflowArgs.getClusterArgs(); + if (clusterArgs instanceof K8SClusterArgsClass) { + this.clusterArgs = (K8SClusterArgsClass) clusterArgs; + } else { + throw new GeaflowException("Invalid clusterArgs type {}", clusterArgs.getClass()); } + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/K8sClientStopArgsClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/K8sClientStopArgsClass.java index 9d74c3eb0..380046c2f 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/K8sClientStopArgsClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/K8sClientStopArgsClass.java @@ -19,29 +19,30 @@ package org.apache.geaflow.console.core.model.job.config; -import lombok.Getter; -import lombok.NoArgsConstructor; -import lombok.Setter; import org.apache.geaflow.console.core.model.config.ConfigValueBehavior; import org.apache.geaflow.console.core.model.config.GeaflowConfigClass; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + @Getter @Setter @NoArgsConstructor public class K8sClientStopArgsClass extends GeaflowConfigClass { - @GeaflowConfigKey(value = "geaflow.job.cluster.id", comment = "i18n.key.running.job.id") - @GeaflowConfigValue(required = true) - private String runtimeTaskId; + @GeaflowConfigKey(value = "geaflow.job.cluster.id", comment = "i18n.key.running.job.id") + @GeaflowConfigValue(required = true) + private String runtimeTaskId; - @GeaflowConfigKey(value = "clusterArgs", comment = "i18n.key.cluster.args") - @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) - private K8SClusterArgsClass clusterArgs; + @GeaflowConfigKey(value = "clusterArgs", comment = "i18n.key.cluster.args") + @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) + private K8SClusterArgsClass clusterArgs; - public K8sClientStopArgsClass(String runtimeTaskId, K8SClusterArgsClass clusterArgs) { - this.runtimeTaskId = runtimeTaskId; - this.clusterArgs = clusterArgs; - } + public K8sClientStopArgsClass(String runtimeTaskId, K8SClusterArgsClass clusterArgs) { + this.runtimeTaskId = runtimeTaskId; + this.clusterArgs = clusterArgs; + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/MetricArgsClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/MetricArgsClass.java index 10ea709cf..910c30e46 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/MetricArgsClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/MetricArgsClass.java @@ -19,9 +19,6 @@ package org.apache.geaflow.console.core.model.job.config; -import lombok.Getter; -import lombok.NoArgsConstructor; -import lombok.Setter; import org.apache.geaflow.console.common.util.exception.GeaflowIllegalException; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.config.ConfigValueBehavior; @@ -32,31 +29,36 @@ import org.apache.geaflow.console.core.model.plugin.config.InfluxdbPluginConfigClass; import org.apache.geaflow.console.core.model.plugin.config.PluginConfigClass; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + @Getter @Setter @NoArgsConstructor public class MetricArgsClass extends GeaflowConfigClass { - @GeaflowConfigKey(value = "geaflow.metric.reporters", comment = "i18n.key.metric.storage.type") - @GeaflowConfigValue(required = true, defaultValue = "INFLUXDB") - private GeaflowPluginType type; + @GeaflowConfigKey(value = "geaflow.metric.reporters", comment = "i18n.key.metric.storage.type") + @GeaflowConfigValue(required = true, defaultValue = "INFLUXDB") + private GeaflowPluginType type; - @GeaflowConfigKey(value = "plugin", comment = "i18n.key.plugin.config") - @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) - private PluginConfigClass plugin; + @GeaflowConfigKey(value = "plugin", comment = "i18n.key.plugin.config") + @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) + private PluginConfigClass plugin; - public MetricArgsClass(GeaflowPluginConfig pluginConfig) { - this.type = GeaflowPluginType.of(pluginConfig.getType()); + public MetricArgsClass(GeaflowPluginConfig pluginConfig) { + this.type = GeaflowPluginType.of(pluginConfig.getType()); - Class configClass; - switch (type) { - case INFLUXDB: - configClass = InfluxdbPluginConfigClass.class; - break; - default: - throw new GeaflowIllegalException("Metric meta config type {} not supported", pluginConfig.getType()); - } - - this.plugin = pluginConfig.getConfig().parse(configClass); + Class configClass; + switch (type) { + case INFLUXDB: + configClass = InfluxdbPluginConfigClass.class; + break; + default: + throw new GeaflowIllegalException( + "Metric meta config type {} not supported", pluginConfig.getType()); } + + this.plugin = pluginConfig.getConfig().parse(configClass); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/PersistentArgsClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/PersistentArgsClass.java index e438be7b8..7a3b19f07 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/PersistentArgsClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/PersistentArgsClass.java @@ -19,9 +19,6 @@ package org.apache.geaflow.console.core.model.job.config; -import lombok.Getter; -import lombok.NoArgsConstructor; -import lombok.Setter; import org.apache.geaflow.console.common.util.exception.GeaflowIllegalException; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.config.ConfigValueBehavior; @@ -35,53 +32,62 @@ import org.apache.geaflow.console.core.model.plugin.config.PersistentPluginConfigClass; import org.apache.geaflow.console.core.model.plugin.config.PluginConfigClass; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + @Getter @Setter @NoArgsConstructor public class PersistentArgsClass extends GeaflowConfigClass { - @GeaflowConfigKey(value = "geaflow.file.persistent.type", comment = "i18n.key.storage.type") - @GeaflowConfigValue(required = true, defaultValue = "LOCAL") - private GeaflowPluginType type; + @GeaflowConfigKey(value = "geaflow.file.persistent.type", comment = "i18n.key.storage.type") + @GeaflowConfigValue(required = true, defaultValue = "LOCAL") + private GeaflowPluginType type; - @GeaflowConfigKey(value = "geaflow.file.persistent.root", comment = "i18n.key.root.path") - @GeaflowConfigValue(required = true, defaultValue = "/geaflow/chk") - private String root; + @GeaflowConfigKey(value = "geaflow.file.persistent.root", comment = "i18n.key.root.path") + @GeaflowConfigValue(required = true, defaultValue = "/geaflow/chk") + private String root; - @GeaflowConfigKey(value = "geaflow.file.persistent.thread.size", comment = "i18n.key.local.thread.pool.count") - @GeaflowConfigValue - private Integer threadSize; + @GeaflowConfigKey( + value = "geaflow.file.persistent.thread.size", + comment = "i18n.key.local.thread.pool.count") + @GeaflowConfigValue + private Integer threadSize; - @GeaflowConfigKey(value = "geaflow.file.persistent.user.name", comment = "i18n.key.username") - @GeaflowConfigValue(defaultValue = "geaflow") - private String username; + @GeaflowConfigKey(value = "geaflow.file.persistent.user.name", comment = "i18n.key.username") + @GeaflowConfigValue(defaultValue = "geaflow") + private String username; - @GeaflowConfigKey(value = "geaflow.file.persistent.config.json", comment = "i18n.key.ext.config.json") - @GeaflowConfigValue(behavior = ConfigValueBehavior.JSON) - private PluginConfigClass plugin; + @GeaflowConfigKey( + value = "geaflow.file.persistent.config.json", + comment = "i18n.key.ext.config.json") + @GeaflowConfigValue(behavior = ConfigValueBehavior.JSON) + private PluginConfigClass plugin; - public PersistentArgsClass(GeaflowPluginConfig pluginConfig) { - this.type = GeaflowPluginType.of(pluginConfig.getType()); + public PersistentArgsClass(GeaflowPluginConfig pluginConfig) { + this.type = GeaflowPluginType.of(pluginConfig.getType()); - Class configClass; - switch (type) { - case LOCAL: - configClass = LocalPluginConfigClass.class; - break; - case DFS: - configClass = DfsPluginConfigClass.class; - break; - case OSS: - configClass = OssPluginConfigClass.class; - break; - default: - throw new GeaflowIllegalException("Persistent config type {} not supported", pluginConfig.getType()); - } - - PersistentPluginConfigClass config = pluginConfig.getConfig().parse(configClass); - this.root = config.getRoot(); - this.threadSize = config.getThreadSize(); - this.username = config.getUsername(); - this.plugin = config; + Class configClass; + switch (type) { + case LOCAL: + configClass = LocalPluginConfigClass.class; + break; + case DFS: + configClass = DfsPluginConfigClass.class; + break; + case OSS: + configClass = OssPluginConfigClass.class; + break; + default: + throw new GeaflowIllegalException( + "Persistent config type {} not supported", pluginConfig.getType()); } + + PersistentPluginConfigClass config = pluginConfig.getConfig().parse(configClass); + this.root = config.getRoot(); + this.threadSize = config.getThreadSize(); + this.username = config.getUsername(); + this.plugin = config; + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/RayClientArgsClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/RayClientArgsClass.java index 4d6f55c7a..f8ca3a552 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/RayClientArgsClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/RayClientArgsClass.java @@ -19,27 +19,28 @@ package org.apache.geaflow.console.core.model.job.config; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.core.model.config.ConfigValueBehavior; import org.apache.geaflow.console.core.model.config.GeaflowConfigClass; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class RayClientArgsClass extends GeaflowConfigClass { - @GeaflowConfigKey(value = "ray.user.class.args", comment = "i18n.key.engine.params.json") - @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.JSON) - private GeaflowArgsClass geaflowArgs; + @GeaflowConfigKey(value = "ray.user.class.args", comment = "i18n.key.engine.params.json") + @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.JSON) + private GeaflowArgsClass geaflowArgs; - @GeaflowConfigKey(value = "ray.user.main.class", comment = "i18n.key.main.class") - @GeaflowConfigValue(required = true) - private String mainClass; + @GeaflowConfigKey(value = "ray.user.main.class", comment = "i18n.key.main.class") + @GeaflowConfigValue(required = true) + private String mainClass; - public RayClientArgsClass(GeaflowArgsClass geaflowArgs, String mainClass) { - this.geaflowArgs = geaflowArgs; - this.mainClass = mainClass; - } + public RayClientArgsClass(GeaflowArgsClass geaflowArgs, String mainClass) { + this.geaflowArgs = geaflowArgs; + this.mainClass = mainClass; + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/RayClusterArgsClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/RayClusterArgsClass.java index a9bc445cc..7aa1d85e6 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/RayClusterArgsClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/RayClusterArgsClass.java @@ -19,19 +19,19 @@ package org.apache.geaflow.console.core.model.job.config; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.core.model.config.ConfigValueBehavior; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; import org.apache.geaflow.console.core.model.plugin.config.RayPluginConfigClass; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class RayClusterArgsClass extends ClusterArgsClass { - @GeaflowConfigKey(value = "clusterConfig", comment = "i18n.key.k8s.cluster.config") - @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) - private RayPluginConfigClass rayConfig; - + @GeaflowConfigKey(value = "clusterConfig", comment = "i18n.key.k8s.cluster.config") + @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) + private RayPluginConfigClass rayConfig; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/RuntimeMetaArgsClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/RuntimeMetaArgsClass.java index 6feb60b99..6deada737 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/RuntimeMetaArgsClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/RuntimeMetaArgsClass.java @@ -19,9 +19,6 @@ package org.apache.geaflow.console.core.model.job.config; -import lombok.Getter; -import lombok.NoArgsConstructor; -import lombok.Setter; import org.apache.geaflow.console.common.util.exception.GeaflowIllegalException; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.config.ConfigValueBehavior; @@ -33,46 +30,53 @@ import org.apache.geaflow.console.core.model.plugin.config.MemoryPluginConfigClass; import org.apache.geaflow.console.core.model.plugin.config.PluginConfigClass; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + @Getter @Setter @NoArgsConstructor public class RuntimeMetaArgsClass extends GeaflowConfigClass { - private static final String RUNTIME_META_TABLE_NAME = "backend_meta"; + private static final String RUNTIME_META_TABLE_NAME = "backend_meta"; - @GeaflowConfigKey(value = "geaflow.metric.stats.type", comment = "i18n.key.type") - @GeaflowConfigValue(required = true, defaultValue = "JDBC") - private GeaflowPluginType type; + @GeaflowConfigKey(value = "geaflow.metric.stats.type", comment = "i18n.key.type") + @GeaflowConfigValue(required = true, defaultValue = "JDBC") + private GeaflowPluginType type; - @GeaflowConfigKey(value = "geaflow.system.offset.backend.type", comment = "i18n.key.offset.storage.type") - @GeaflowConfigValue(required = true, defaultValue = "JDBC") - private GeaflowPluginType offsetMetaType; + @GeaflowConfigKey( + value = "geaflow.system.offset.backend.type", + comment = "i18n.key.offset.storage.type") + @GeaflowConfigValue(required = true, defaultValue = "JDBC") + private GeaflowPluginType offsetMetaType; - @GeaflowConfigKey(value = "geaflow.system.meta.table", comment = "i18n.key.table.name") - @GeaflowConfigValue(required = true, defaultValue = RUNTIME_META_TABLE_NAME) - private String table; + @GeaflowConfigKey(value = "geaflow.system.meta.table", comment = "i18n.key.table.name") + @GeaflowConfigValue(required = true, defaultValue = RUNTIME_META_TABLE_NAME) + private String table; - @GeaflowConfigKey(value = "plugin", comment = "i18n.key.plugin.config") - @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) - private PluginConfigClass plugin; + @GeaflowConfigKey(value = "plugin", comment = "i18n.key.plugin.config") + @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) + private PluginConfigClass plugin; - public RuntimeMetaArgsClass(GeaflowPluginConfig pluginConfig) { - this.type = GeaflowPluginType.of(pluginConfig.getType()); - this.offsetMetaType = GeaflowPluginType.of(pluginConfig.getType()); - this.table = RUNTIME_META_TABLE_NAME; + public RuntimeMetaArgsClass(GeaflowPluginConfig pluginConfig) { + this.type = GeaflowPluginType.of(pluginConfig.getType()); + this.offsetMetaType = GeaflowPluginType.of(pluginConfig.getType()); + this.table = RUNTIME_META_TABLE_NAME; - Class configClass; - switch (type) { - case JDBC: - configClass = JdbcPluginConfigClass.class; - break; - case MEMORY: - configClass = MemoryPluginConfigClass.class; - break; - default: - throw new GeaflowIllegalException("Runtime meta config type {} not supported", pluginConfig.getType()); - } - - this.plugin = pluginConfig.getConfig().parse(configClass); + Class configClass; + switch (type) { + case JDBC: + configClass = JdbcPluginConfigClass.class; + break; + case MEMORY: + configClass = MemoryPluginConfigClass.class; + break; + default: + throw new GeaflowIllegalException( + "Runtime meta config type {} not supported", pluginConfig.getType()); } + + this.plugin = pluginConfig.getConfig().parse(configClass); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/ServeJobConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/ServeJobConfigClass.java index 6e5f89c1e..e07e7d2ea 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/ServeJobConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/ServeJobConfigClass.java @@ -19,31 +19,31 @@ package org.apache.geaflow.console.core.model.job.config; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; + @Setter @Getter public class ServeJobConfigClass extends JobConfigClass { - @GeaflowConfigKey(value = "geaflow.job.mode", comment = "") - @GeaflowConfigValue(required = true, defaultValue = "OLAP_SERVICE") - private String jobMode; - - @GeaflowConfigKey(value = "geaflow.analytics.service.share.enable", comment = "") - @GeaflowConfigValue(required = true, defaultValue = "true") - private Boolean serviceShareEnable; + @GeaflowConfigKey(value = "geaflow.job.mode", comment = "") + @GeaflowConfigValue(required = true, defaultValue = "OLAP_SERVICE") + private String jobMode; - @GeaflowConfigKey(value = "geaflow.analytics.graph.view.name", comment = "") - @GeaflowConfigValue(required = true) - private String graphName; + @GeaflowConfigKey(value = "geaflow.analytics.service.share.enable", comment = "") + @GeaflowConfigValue(required = true, defaultValue = "true") + private Boolean serviceShareEnable; - @GeaflowConfigKey(value = "geaflow.analytics.query.parallelism", comment = "") - private Integer queryParallelism; + @GeaflowConfigKey(value = "geaflow.analytics.graph.view.name", comment = "") + @GeaflowConfigValue(required = true) + private String graphName; - @GeaflowConfigKey(value = "geaflow.driver.num", comment = "") - private Integer driverNum; + @GeaflowConfigKey(value = "geaflow.analytics.query.parallelism", comment = "") + private Integer queryParallelism; + @GeaflowConfigKey(value = "geaflow.driver.num", comment = "") + private Integer driverNum; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/StateArgsClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/StateArgsClass.java index 668f50b4d..bdc33506a 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/StateArgsClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/StateArgsClass.java @@ -19,27 +19,27 @@ package org.apache.geaflow.console.core.model.job.config; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.core.model.config.ConfigValueBehavior; import org.apache.geaflow.console.core.model.config.GeaflowConfigClass; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class StateArgsClass extends GeaflowConfigClass { - @GeaflowConfigKey(value = "runtimeMetaArgs", comment = "i18n.key.runtime.meta.params") - @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) - private RuntimeMetaArgsClass runtimeMetaArgs; - - @GeaflowConfigKey(value = "haMetaArgs", comment = "i18n.key.ha.storage.params") - @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) - private HaMetaArgsClass haMetaArgs; + @GeaflowConfigKey(value = "runtimeMetaArgs", comment = "i18n.key.runtime.meta.params") + @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) + private RuntimeMetaArgsClass runtimeMetaArgs; - @GeaflowConfigKey(value = "persistentArgs", comment = "i18n.key.persistent.storage.params") - @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) - private PersistentArgsClass persistentArgs; + @GeaflowConfigKey(value = "haMetaArgs", comment = "i18n.key.ha.storage.params") + @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) + private HaMetaArgsClass haMetaArgs; + @GeaflowConfigKey(value = "persistentArgs", comment = "i18n.key.persistent.storage.params") + @GeaflowConfigValue(required = true, behavior = ConfigValueBehavior.FLATTED) + private PersistentArgsClass persistentArgs; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/SystemArgsClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/SystemArgsClass.java index a31831428..4b15e8f45 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/SystemArgsClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/job/config/SystemArgsClass.java @@ -19,54 +19,58 @@ package org.apache.geaflow.console.core.model.job.config; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.core.model.config.GeaflowConfigClass; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class SystemArgsClass extends GeaflowConfigClass { - @GeaflowConfigKey(value = "geaflow.job.unique.id", comment = "i18n.key.job.id") - @GeaflowConfigValue(required = true) - private String taskId; - - @GeaflowConfigKey(value = "geaflow.job.cluster.id", comment = "i18n.key.running.job.id") - @GeaflowConfigValue(required = true) - private String runtimeTaskId; + @GeaflowConfigKey(value = "geaflow.job.unique.id", comment = "i18n.key.job.id") + @GeaflowConfigValue(required = true) + private String taskId; - @GeaflowConfigKey(value = "geaflow.job.runtime.name", comment = "i18n.key.running.job.name") - @GeaflowConfigValue(required = true) - private String runtimeTaskName; + @GeaflowConfigKey(value = "geaflow.job.cluster.id", comment = "i18n.key.running.job.id") + @GeaflowConfigValue(required = true) + private String runtimeTaskId; - @GeaflowConfigKey(value = "geaflow.gw.endpoint", comment = "i18n.key.k8s.server.url") - @GeaflowConfigValue(required = true, defaultValue = "http://0.0.0.0:8888") - private String gateway; + @GeaflowConfigKey(value = "geaflow.job.runtime.name", comment = "i18n.key.running.job.name") + @GeaflowConfigValue(required = true) + private String runtimeTaskName; - @GeaflowConfigKey(value = "geaflow.dsl.catalog.token.key", comment = "i18n.key.api.token") - @GeaflowConfigValue(required = true, masked = true) - private String taskToken; + @GeaflowConfigKey(value = "geaflow.gw.endpoint", comment = "i18n.key.k8s.server.url") + @GeaflowConfigValue(required = true, defaultValue = "http://0.0.0.0:8888") + private String gateway; - @GeaflowConfigKey(value = "geaflow.cluster.started.callback.url", comment = "i18n.key.startup.notify.url") - @GeaflowConfigValue(required = true) - private String startupNotifyUrl; + @GeaflowConfigKey(value = "geaflow.dsl.catalog.token.key", comment = "i18n.key.api.token") + @GeaflowConfigValue(required = true, masked = true) + private String taskToken; - @GeaflowConfigKey(value = "geaflow.dsl.catalog.instance.name", comment = "i18n.key.default.instance.name") - @GeaflowConfigValue(required = true) - private String instanceName; + @GeaflowConfigKey( + value = "geaflow.cluster.started.callback.url", + comment = "i18n.key.startup.notify.url") + @GeaflowConfigValue(required = true) + private String startupNotifyUrl; - @GeaflowConfigKey(value = "geaflow.dsl.catalog.type", comment = "i18n.key.job.catalog.type") - @GeaflowConfigValue(required = true, defaultValue = "console") - private String catalogType; + @GeaflowConfigKey( + value = "geaflow.dsl.catalog.instance.name", + comment = "i18n.key.default.instance.name") + @GeaflowConfigValue(required = true) + private String instanceName; - @GeaflowConfigKey(value = "stateConfig", comment = "i18n.key.state.params") - @GeaflowConfigValue(required = true) - private StateArgsClass stateArgs; + @GeaflowConfigKey(value = "geaflow.dsl.catalog.type", comment = "i18n.key.job.catalog.type") + @GeaflowConfigValue(required = true, defaultValue = "console") + private String catalogType; - @GeaflowConfigKey(value = "metricConfig", comment = "i18n.key.metric.params") - @GeaflowConfigValue(required = true) - private MetricArgsClass metricArgs; + @GeaflowConfigKey(value = "stateConfig", comment = "i18n.key.state.params") + @GeaflowConfigValue(required = true) + private StateArgsClass stateArgs; + @GeaflowConfigKey(value = "metricConfig", comment = "i18n.key.metric.params") + @GeaflowConfigValue(required = true) + private MetricArgsClass metricArgs; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/CodefuseConfigArgsClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/CodefuseConfigArgsClass.java index 20503d265..e0413542a 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/CodefuseConfigArgsClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/CodefuseConfigArgsClass.java @@ -19,21 +19,21 @@ package org.apache.geaflow.console.core.model.llm; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; + @Setter @Getter public class CodefuseConfigArgsClass extends LLMConfigArgsClass { - @GeaflowConfigKey(value = "chainName") - @GeaflowConfigValue(required = true) - private String chainName; - - @GeaflowConfigKey(value = "sceneName") - @GeaflowConfigValue(required = true) - private String sceneName; + @GeaflowConfigKey(value = "chainName") + @GeaflowConfigValue(required = true) + private String chainName; + @GeaflowConfigKey(value = "sceneName") + @GeaflowConfigValue(required = true) + private String sceneName; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/GeaflowChat.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/GeaflowChat.java index 3bd750313..e4b8a9850 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/GeaflowChat.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/GeaflowChat.java @@ -19,29 +19,28 @@ package org.apache.geaflow.console.core.model.llm; +import org.apache.geaflow.console.common.util.type.GeaflowStatementStatus; +import org.apache.geaflow.console.core.model.GeaflowId; + import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowStatementStatus; -import org.apache.geaflow.console.core.model.GeaflowId; @Getter @Setter @NoArgsConstructor public class GeaflowChat extends GeaflowId { - private String prompt; - - private String answer; + private String prompt; - private String modelId; + private String answer; - private String jobId; + private String modelId; - private GeaflowStatementStatus status; + private String jobId; - @Override - public void validate() { + private GeaflowStatementStatus status; - } + @Override + public void validate() {} } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/GeaflowLLM.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/GeaflowLLM.java index 3919287a8..2c73d1366 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/GeaflowLLM.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/GeaflowLLM.java @@ -19,43 +19,43 @@ package org.apache.geaflow.console.core.model.llm; +import org.apache.geaflow.console.common.util.type.GeaflowLLMType; +import org.apache.geaflow.console.core.model.GeaflowName; +import org.apache.geaflow.console.core.model.config.GeaflowConfig; + import com.google.common.base.Preconditions; + import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowLLMType; -import org.apache.geaflow.console.core.model.GeaflowName; -import org.apache.geaflow.console.core.model.config.GeaflowConfig; @Getter @Setter @NoArgsConstructor public class GeaflowLLM extends GeaflowName { - private String url; + private String url; - private GeaflowConfig args; + private GeaflowConfig args; - private GeaflowLLMType type; + private GeaflowLLMType type; - public GeaflowLLM(String name, String comment) { - super.name = name; - super.comment = comment; - } - - @Override - public void validate() { - super.validate(); - Preconditions.checkNotNull(url, "Invalid url"); - Preconditions.checkNotNull(type, "Invalid type"); - switch (type) { - case OPEN_AI: - args.parse(OpenAIConfigArgsClass.class); - break; - default: - break; - } + public GeaflowLLM(String name, String comment) { + super.name = name; + super.comment = comment; + } + @Override + public void validate() { + super.validate(); + Preconditions.checkNotNull(url, "Invalid url"); + Preconditions.checkNotNull(type, "Invalid type"); + switch (type) { + case OPEN_AI: + args.parse(OpenAIConfigArgsClass.class); + break; + default: + break; } - + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/LLMConfigArgsClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/LLMConfigArgsClass.java index debf5f499..8ed92f948 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/LLMConfigArgsClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/LLMConfigArgsClass.java @@ -19,34 +19,34 @@ package org.apache.geaflow.console.core.model.llm; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.core.model.config.GeaflowConfigClass; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; + @Setter @Getter public class LLMConfigArgsClass extends GeaflowConfigClass { - @GeaflowConfigKey(value = "retryTimes") - @GeaflowConfigValue(required = false, defaultValue = "3") - private Integer retryTimes; - - @GeaflowConfigKey(value = "retryInterval") - @GeaflowConfigValue(required = false, defaultValue = "1") - private Integer retryInterval; + @GeaflowConfigKey(value = "retryTimes") + @GeaflowConfigValue(required = false, defaultValue = "3") + private Integer retryTimes; - @GeaflowConfigKey(value = "connectTimeout") - @GeaflowConfigValue(required = false, defaultValue = "15") - private Integer connectTimeout; + @GeaflowConfigKey(value = "retryInterval") + @GeaflowConfigValue(required = false, defaultValue = "1") + private Integer retryInterval; - @GeaflowConfigKey(value = "readTimeout") - @GeaflowConfigValue(required = false, defaultValue = "15") - private Integer readTimeout; + @GeaflowConfigKey(value = "connectTimeout") + @GeaflowConfigValue(required = false, defaultValue = "15") + private Integer connectTimeout; - @GeaflowConfigKey(value = "writeTimeout") - @GeaflowConfigValue(required = false, defaultValue = "100") - private Integer writeTimeout; + @GeaflowConfigKey(value = "readTimeout") + @GeaflowConfigValue(required = false, defaultValue = "15") + private Integer readTimeout; + @GeaflowConfigKey(value = "writeTimeout") + @GeaflowConfigValue(required = false, defaultValue = "100") + private Integer writeTimeout; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/LocalConfigArgsClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/LocalConfigArgsClass.java index f3a68572d..52be660f8 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/LocalConfigArgsClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/LocalConfigArgsClass.java @@ -19,17 +19,17 @@ package org.apache.geaflow.console.core.model.llm; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; + @Setter @Getter public class LocalConfigArgsClass extends LLMConfigArgsClass { - @GeaflowConfigKey(value = "n_predict") - @GeaflowConfigValue(required = false, defaultValue = "128") - private Integer predict; - + @GeaflowConfigKey(value = "n_predict") + @GeaflowConfigValue(required = false, defaultValue = "128") + private Integer predict; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/OpenAIConfigArgsClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/OpenAIConfigArgsClass.java index 8c0684787..a0be34de4 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/OpenAIConfigArgsClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/llm/OpenAIConfigArgsClass.java @@ -19,21 +19,21 @@ package org.apache.geaflow.console.core.model.llm; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; + @Setter @Getter public class OpenAIConfigArgsClass extends LLMConfigArgsClass { - @GeaflowConfigKey(value = "apiKey") - @GeaflowConfigValue(required = true) - private String apiKey; - + @GeaflowConfigKey(value = "apiKey") + @GeaflowConfigValue(required = true) + private String apiKey; - @GeaflowConfigKey(value = "modelId") - @GeaflowConfigValue(required = true) - private String modelId; + @GeaflowConfigKey(value = "modelId") + @GeaflowConfigValue(required = true) + private String modelId; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/metric/GeaflowMetric.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/metric/GeaflowMetric.java index ee1e0422d..e5a996b09 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/metric/GeaflowMetric.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/metric/GeaflowMetric.java @@ -26,10 +26,9 @@ @Getter public class GeaflowMetric { - private String metric; + private String metric; - private Long time; - - private Object value; + private Long time; + private Object value; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/metric/GeaflowMetricMeta.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/metric/GeaflowMetricMeta.java index 1c9f1a63a..fc3bf530f 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/metric/GeaflowMetricMeta.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/metric/GeaflowMetricMeta.java @@ -26,14 +26,13 @@ @Setter public class GeaflowMetricMeta { - private String jobName; + private String jobName; - private String metricGroup; + private String metricGroup; - private String metricName; + private String metricName; - private String metricType; - - private String queries; + private String metricType; + private String queries; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/metric/GeaflowMetricQuery.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/metric/GeaflowMetricQuery.java index 9289a842a..ef20ab3f7 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/metric/GeaflowMetricQuery.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/metric/GeaflowMetricQuery.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.core.model.metric; import java.util.Map; + import lombok.Getter; import lombok.Setter; @@ -27,12 +28,11 @@ @Getter public class GeaflowMetricQuery { - private String metric; - - private String aggregator; + private String metric; - private String downsample; + private String aggregator; - private Map tags; + private String downsample; + private Map tags; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/metric/GeaflowMetricQueryRequest.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/metric/GeaflowMetricQueryRequest.java index 9703896fd..8225e2d73 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/metric/GeaflowMetricQueryRequest.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/metric/GeaflowMetricQueryRequest.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.core.model.metric; import java.util.List; + import lombok.Getter; import lombok.Setter; @@ -27,9 +28,9 @@ @Getter public class GeaflowMetricQueryRequest { - long start; + long start; - long end; + long end; - List queries; + List queries; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/GeaflowPlugin.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/GeaflowPlugin.java index b9dce2c47..567ec00ad 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/GeaflowPlugin.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/GeaflowPlugin.java @@ -19,23 +19,24 @@ package org.apache.geaflow.console.core.model.plugin; -import lombok.Getter; -import lombok.NoArgsConstructor; -import lombok.Setter; import org.apache.geaflow.console.common.util.type.GeaflowPluginCategory; import org.apache.geaflow.console.core.model.GeaflowName; import org.apache.geaflow.console.core.model.file.GeaflowRemoteFile; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + @Getter @Setter @NoArgsConstructor public class GeaflowPlugin extends GeaflowName { - private String type; + private String type; - private GeaflowPluginCategory category; + private GeaflowPluginCategory category; - private GeaflowRemoteFile jarPackage; + private GeaflowRemoteFile jarPackage; - private boolean system; + private boolean system; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/ConsolePluginConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/ConsolePluginConfigClass.java index a2e72f2aa..55e2e7f33 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/ConsolePluginConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/ConsolePluginConfigClass.java @@ -19,16 +19,16 @@ package org.apache.geaflow.console.core.model.plugin.config; +import org.apache.geaflow.console.common.util.type.GeaflowPluginType; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowPluginType; @Getter @Setter public class ConsolePluginConfigClass extends PluginConfigClass { - public ConsolePluginConfigClass() { - super(GeaflowPluginType.CONSOLE); - } - + public ConsolePluginConfigClass() { + super(GeaflowPluginType.CONSOLE); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/ContainerPluginConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/ContainerPluginConfigClass.java index 99e391267..bf31fdc12 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/ContainerPluginConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/ContainerPluginConfigClass.java @@ -19,17 +19,18 @@ package org.apache.geaflow.console.core.model.plugin.config; +import org.apache.geaflow.console.common.util.type.GeaflowPluginType; + import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; -import org.apache.geaflow.console.common.util.type.GeaflowPluginType; @Slf4j @Getter @Setter public class ContainerPluginConfigClass extends PluginConfigClass { - public ContainerPluginConfigClass() { - super(GeaflowPluginType.CONTAINER); - } + public ContainerPluginConfigClass() { + super(GeaflowPluginType.CONTAINER); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/DfsPluginConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/DfsPluginConfigClass.java index 41772f585..06f0c35dc 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/DfsPluginConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/DfsPluginConfigClass.java @@ -19,31 +19,32 @@ package org.apache.geaflow.console.core.model.plugin.config; -import lombok.Getter; -import lombok.Setter; -import lombok.extern.slf4j.Slf4j; import org.apache.geaflow.console.common.util.NetworkUtil; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.slf4j.Slf4j; + @Slf4j @Getter @Setter public class DfsPluginConfigClass extends PersistentPluginConfigClass { - public static final String DFS_URI_KEY = "fs.defaultFS"; + public static final String DFS_URI_KEY = "fs.defaultFS"; - @GeaflowConfigKey(value = DFS_URI_KEY, comment = "i18n.key.dfs.address") - @GeaflowConfigValue(required = true, defaultValue = "hdfs://0.0.0.0:9000") - private String defaultFs; + @GeaflowConfigKey(value = DFS_URI_KEY, comment = "i18n.key.dfs.address") + @GeaflowConfigValue(required = true, defaultValue = "hdfs://0.0.0.0:9000") + private String defaultFs; - public DfsPluginConfigClass() { - super(GeaflowPluginType.DFS); - } + public DfsPluginConfigClass() { + super(GeaflowPluginType.DFS); + } - @Override - public void testConnection() { - NetworkUtil.testUrl(defaultFs); - } + @Override + public void testConnection() { + NetworkUtil.testUrl(defaultFs); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/FilePluginConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/FilePluginConfigClass.java index 523c5a1a5..365f8025a 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/FilePluginConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/FilePluginConfigClass.java @@ -19,33 +19,36 @@ package org.apache.geaflow.console.core.model.plugin.config; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class FilePluginConfigClass extends PluginConfigClass { - @GeaflowConfigKey(value = "geaflow.dsl.file.path", comment = "i18n.key.file.path") - @GeaflowConfigValue(required = true, defaultValue = "/") - private String filePath; + @GeaflowConfigKey(value = "geaflow.dsl.file.path", comment = "i18n.key.file.path") + @GeaflowConfigValue(required = true, defaultValue = "/") + private String filePath; - @GeaflowConfigKey(value = "geaflow.dsl.column.separator", comment = "i18n.key.column.separator") - @GeaflowConfigValue(defaultValue = ",") - private String columnSeparator; + @GeaflowConfigKey(value = "geaflow.dsl.column.separator", comment = "i18n.key.column.separator") + @GeaflowConfigValue(defaultValue = ",") + private String columnSeparator; - @GeaflowConfigKey(value = "geaflow.dsl.line.separator", comment = "i18n.key.line.separator") - @GeaflowConfigValue(defaultValue = "\\n") - private String lineSeparator; + @GeaflowConfigKey(value = "geaflow.dsl.line.separator", comment = "i18n.key.line.separator") + @GeaflowConfigValue(defaultValue = "\\n") + private String lineSeparator; - @GeaflowConfigKey(value = "geaflow.file.persistent.config.json", comment = "i18n.key.ext.config.json") - @GeaflowConfigValue(defaultValue = "{\"fs.defaultFS\":\"local\"}") - private String persistConfig; + @GeaflowConfigKey( + value = "geaflow.file.persistent.config.json", + comment = "i18n.key.ext.config.json") + @GeaflowConfigValue(defaultValue = "{\"fs.defaultFS\":\"local\"}") + private String persistConfig; - public FilePluginConfigClass() { - super(GeaflowPluginType.FILE); - } + public FilePluginConfigClass() { + super(GeaflowPluginType.FILE); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/GeaflowPluginConfig.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/GeaflowPluginConfig.java index b536d092a..1e9dbccce 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/GeaflowPluginConfig.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/GeaflowPluginConfig.java @@ -19,55 +19,60 @@ package org.apache.geaflow.console.core.model.plugin.config; -import com.google.common.base.Preconditions; -import lombok.Getter; -import lombok.NoArgsConstructor; -import lombok.Setter; -import lombok.extern.slf4j.Slf4j; import org.apache.geaflow.console.common.util.type.GeaflowPluginCategory; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.GeaflowName; import org.apache.geaflow.console.core.model.config.ConfigDescFactory; import org.apache.geaflow.console.core.model.config.GeaflowConfig; +import com.google.common.base.Preconditions; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; +import lombok.extern.slf4j.Slf4j; + @Getter @Setter @NoArgsConstructor @Slf4j public class GeaflowPluginConfig extends GeaflowName { - private String type; + private String type; - private GeaflowPluginCategory category; + private GeaflowPluginCategory category; - private GeaflowConfig config; + private GeaflowConfig config; - public GeaflowPluginConfig(GeaflowPluginCategory category, PluginConfigClass pluginConfigClass) { - this.category = category; - this.type = pluginConfigClass.getType().name(); - this.config = pluginConfigClass.build(); - } - - public GeaflowPluginConfig(String name, String comment, GeaflowPluginType type, GeaflowPluginCategory category, - GeaflowConfig config) { - super.name = name; - super.comment = comment; - this.type = type.name(); - this.category = category; - this.config = config; - } + public GeaflowPluginConfig(GeaflowPluginCategory category, PluginConfigClass pluginConfigClass) { + this.category = category; + this.type = pluginConfigClass.getType().name(); + this.config = pluginConfigClass.build(); + } - @Override - public void validate() { - super.validate(); - Preconditions.checkNotNull(type, "Invalid plugin type"); - Preconditions.checkNotNull(category, "Invalid category"); - Preconditions.checkNotNull(config, "Invalid plugin config"); + public GeaflowPluginConfig( + String name, + String comment, + GeaflowPluginType type, + GeaflowPluginCategory category, + GeaflowConfig config) { + super.name = name; + super.comment = comment; + this.type = type.name(); + this.category = category; + this.config = config; + } - GeaflowPluginType geaflowPluginType = GeaflowPluginType.of(type); - if (geaflowPluginType != GeaflowPluginType.None) { - ConfigDescFactory.get(geaflowPluginType).validateConfig(config); - } + @Override + public void validate() { + super.validate(); + Preconditions.checkNotNull(type, "Invalid plugin type"); + Preconditions.checkNotNull(category, "Invalid category"); + Preconditions.checkNotNull(config, "Invalid plugin config"); + GeaflowPluginType geaflowPluginType = GeaflowPluginType.of(type); + if (geaflowPluginType != GeaflowPluginType.None) { + ConfigDescFactory.get(geaflowPluginType).validateConfig(config); } + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/HivePluginConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/HivePluginConfigClass.java index 1b489f5b8..3b9f9abbb 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/HivePluginConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/HivePluginConfigClass.java @@ -19,39 +19,44 @@ package org.apache.geaflow.console.core.model.plugin.config; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.common.util.NetworkUtil; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class HivePluginConfigClass extends PluginConfigClass { - @GeaflowConfigKey(value = "geaflow.dsl.hive.metastore.uris", comment = "i18n.key.metastore.address") - @GeaflowConfigValue(required = true, defaultValue = "thrift://localhost:9083") - private String metastore; - - @GeaflowConfigKey(value = "geaflow.dsl.hive.database.name", comment = "i18n.key.database.name") - @GeaflowConfigValue(required = true) - private String database; - - @GeaflowConfigKey(value = "geaflow.dsl.hive.table.name", comment = "i18n.key.table.name") - @GeaflowConfigValue(required = true) - private String table; - - @GeaflowConfigKey(value = "geaflow.dsl.hive.splits.per.partition", comment = "i18n.key.read.splits.per.partition") - @GeaflowConfigValue(defaultValue = "1") - private Integer partition; - - public HivePluginConfigClass() { - super(GeaflowPluginType.HIVE); - } - - @Override - public void testConnection() { - NetworkUtil.testUrls(metastore, ","); - } + @GeaflowConfigKey( + value = "geaflow.dsl.hive.metastore.uris", + comment = "i18n.key.metastore.address") + @GeaflowConfigValue(required = true, defaultValue = "thrift://localhost:9083") + private String metastore; + + @GeaflowConfigKey(value = "geaflow.dsl.hive.database.name", comment = "i18n.key.database.name") + @GeaflowConfigValue(required = true) + private String database; + + @GeaflowConfigKey(value = "geaflow.dsl.hive.table.name", comment = "i18n.key.table.name") + @GeaflowConfigValue(required = true) + private String table; + + @GeaflowConfigKey( + value = "geaflow.dsl.hive.splits.per.partition", + comment = "i18n.key.read.splits.per.partition") + @GeaflowConfigValue(defaultValue = "1") + private Integer partition; + + public HivePluginConfigClass() { + super(GeaflowPluginType.HIVE); + } + + @Override + public void testConnection() { + NetworkUtil.testUrls(metastore, ","); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/InfluxdbPluginConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/InfluxdbPluginConfigClass.java index ae4243ce6..6e1997196 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/InfluxdbPluginConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/InfluxdbPluginConfigClass.java @@ -19,47 +19,52 @@ package org.apache.geaflow.console.core.model.plugin.config; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.common.util.NetworkUtil; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class InfluxdbPluginConfigClass extends PluginConfigClass { - @GeaflowConfigKey(value = "geaflow.metric.influxdb.url", comment = "i18n.key.url") - @GeaflowConfigValue(required = true, defaultValue = "http://0.0.0.0:8086") - private String url; + @GeaflowConfigKey(value = "geaflow.metric.influxdb.url", comment = "i18n.key.url") + @GeaflowConfigValue(required = true, defaultValue = "http://0.0.0.0:8086") + private String url; - @GeaflowConfigKey(value = "geaflow.metric.influxdb.token", comment = "i18n.key.token") - @GeaflowConfigValue(required = true, masked = true) - private String token; + @GeaflowConfigKey(value = "geaflow.metric.influxdb.token", comment = "i18n.key.token") + @GeaflowConfigValue(required = true, masked = true) + private String token; - @GeaflowConfigKey(value = "geaflow.metric.influxdb.org", comment = "i18n.key.organization") - @GeaflowConfigValue(required = true, defaultValue = "geaflow") - private String org; + @GeaflowConfigKey(value = "geaflow.metric.influxdb.org", comment = "i18n.key.organization") + @GeaflowConfigValue(required = true, defaultValue = "geaflow") + private String org; - @GeaflowConfigKey(value = "geaflow.metric.influxdb.bucket", comment = "i18n.key.bucket") - @GeaflowConfigValue(required = true, defaultValue = "geaflow") - private String bucket; + @GeaflowConfigKey(value = "geaflow.metric.influxdb.bucket", comment = "i18n.key.bucket") + @GeaflowConfigValue(required = true, defaultValue = "geaflow") + private String bucket; - @GeaflowConfigKey(value = "geaflow.metric.influxdb.connect.timeout.ms", comment = "i18n.key.connect.timeout") - @GeaflowConfigValue(defaultValue = "30000") - private Integer connectTimeout; + @GeaflowConfigKey( + value = "geaflow.metric.influxdb.connect.timeout.ms", + comment = "i18n.key.connect.timeout") + @GeaflowConfigValue(defaultValue = "30000") + private Integer connectTimeout; - @GeaflowConfigKey(value = "geaflow.metric.influxdb.write.timeout.ms", comment = "i18n.key.write.timeout") - @GeaflowConfigValue(defaultValue = "30000") - private Integer writeTimeout; + @GeaflowConfigKey( + value = "geaflow.metric.influxdb.write.timeout.ms", + comment = "i18n.key.write.timeout") + @GeaflowConfigValue(defaultValue = "30000") + private Integer writeTimeout; - public InfluxdbPluginConfigClass() { - super(GeaflowPluginType.INFLUXDB); - } + public InfluxdbPluginConfigClass() { + super(GeaflowPluginType.INFLUXDB); + } - @Override - public void testConnection() { - NetworkUtil.testUrl(url); - } + @Override + public void testConnection() { + NetworkUtil.testUrl(url); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/JdbcPluginConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/JdbcPluginConfigClass.java index 5186448c2..76a406713 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/JdbcPluginConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/JdbcPluginConfigClass.java @@ -24,83 +24,94 @@ import java.sql.ResultSet; import java.sql.Statement; import java.util.Properties; -import lombok.Getter; -import lombok.Setter; + import org.apache.geaflow.console.common.util.exception.GeaflowException; import org.apache.geaflow.console.common.util.exception.GeaflowIllegalException; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class JdbcPluginConfigClass extends PluginConfigClass { - public static final String MYSQL_DRIVER_CLASS = "com.mysql.jdbc.Driver"; - - @GeaflowConfigKey(value = "geaflow.store.jdbc.driver.class", comment = "i18n.key.jdbc.driver") - @GeaflowConfigValue(defaultValue = MYSQL_DRIVER_CLASS) - private String driverClass; - - @GeaflowConfigKey(value = "geaflow.store.jdbc.url", comment = "i18n.key.jdbc.url") - @GeaflowConfigValue(required = true, defaultValue = "jdbc:mysql://0.0.0.0:3306/geaflow?characterEncoding=utf8" - + "&autoReconnect=true&useSSL=false") - private String url; - - @GeaflowConfigKey(value = "geaflow.store.jdbc.user.name", comment = "i18n.key.username") - @GeaflowConfigValue(required = true, defaultValue = "geaflow") - private String username; - - @GeaflowConfigKey(value = "geaflow.store.jdbc.password", comment = "i18n.key.password") - @GeaflowConfigValue(required = true, defaultValue = "geaflow", masked = true) - private String password; - - @GeaflowConfigKey(value = "geaflow.store.jdbc.max.retries", comment = "i18n.key.retry.times") - @GeaflowConfigValue(defaultValue = "3") - private Integer retryTimes; - - @GeaflowConfigKey(value = "geaflow.store.jdbc.connection.pool.size", comment = "i18n.key.connection.pool.size") - @GeaflowConfigValue(defaultValue = "10") - private Integer connectionPoolSize; - - @GeaflowConfigKey(value = "geaflow.store.jdbc.connect.config.json", comment = "i18n.key.connection.ext.config.json") - private String configJson; - - public JdbcPluginConfigClass() { - super(GeaflowPluginType.JDBC); - } - - @Override - public void testConnection() { - try (Connection connection = createConnection()) { - try (Statement statement = connection.createStatement()) { - String testSql = "SELECT 1"; - try (ResultSet resultSet = statement.executeQuery(testSql)) { - if (!resultSet.next()) { - throw new GeaflowException("No response content of query '{}'", testSql); - } - } - } - } catch (Exception e) { - throw new GeaflowIllegalException("JDBC connection test failed, caused by {}", e.getMessage(), e); + public static final String MYSQL_DRIVER_CLASS = "com.mysql.jdbc.Driver"; + + @GeaflowConfigKey(value = "geaflow.store.jdbc.driver.class", comment = "i18n.key.jdbc.driver") + @GeaflowConfigValue(defaultValue = MYSQL_DRIVER_CLASS) + private String driverClass; + + @GeaflowConfigKey(value = "geaflow.store.jdbc.url", comment = "i18n.key.jdbc.url") + @GeaflowConfigValue( + required = true, + defaultValue = + "jdbc:mysql://0.0.0.0:3306/geaflow?characterEncoding=utf8" + + "&autoReconnect=true&useSSL=false") + private String url; + + @GeaflowConfigKey(value = "geaflow.store.jdbc.user.name", comment = "i18n.key.username") + @GeaflowConfigValue(required = true, defaultValue = "geaflow") + private String username; + + @GeaflowConfigKey(value = "geaflow.store.jdbc.password", comment = "i18n.key.password") + @GeaflowConfigValue(required = true, defaultValue = "geaflow", masked = true) + private String password; + + @GeaflowConfigKey(value = "geaflow.store.jdbc.max.retries", comment = "i18n.key.retry.times") + @GeaflowConfigValue(defaultValue = "3") + private Integer retryTimes; + + @GeaflowConfigKey( + value = "geaflow.store.jdbc.connection.pool.size", + comment = "i18n.key.connection.pool.size") + @GeaflowConfigValue(defaultValue = "10") + private Integer connectionPoolSize; + + @GeaflowConfigKey( + value = "geaflow.store.jdbc.connect.config.json", + comment = "i18n.key.connection.ext.config.json") + private String configJson; + + public JdbcPluginConfigClass() { + super(GeaflowPluginType.JDBC); + } + + @Override + public void testConnection() { + try (Connection connection = createConnection()) { + try (Statement statement = connection.createStatement()) { + String testSql = "SELECT 1"; + try (ResultSet resultSet = statement.executeQuery(testSql)) { + if (!resultSet.next()) { + throw new GeaflowException("No response content of query '{}'", testSql); + } } + } + } catch (Exception e) { + throw new GeaflowIllegalException( + "JDBC connection test failed, caused by {}", e.getMessage(), e); } - - public Connection createConnection() { - try { - ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); - if (this.driverClass != null) { - classLoader.loadClass(this.driverClass); - } - - DriverManager.setLoginTimeout(3); - Properties properties = new Properties(); - properties.setProperty("user", username); - properties.setProperty("password", password); - return DriverManager.getConnection(url, properties); - - } catch (Exception e) { - throw new GeaflowIllegalException("JDBC connection create failed, caused by {}", e.getMessage(), e); - } + } + + public Connection createConnection() { + try { + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); + if (this.driverClass != null) { + classLoader.loadClass(this.driverClass); + } + + DriverManager.setLoginTimeout(3); + Properties properties = new Properties(); + properties.setProperty("user", username); + properties.setProperty("password", password); + return DriverManager.getConnection(url, properties); + + } catch (Exception e) { + throw new GeaflowIllegalException( + "JDBC connection create failed, caused by {}", e.getMessage(), e); } + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/K8sPluginConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/K8sPluginConfigClass.java index 23b9505a5..95fa629eb 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/K8sPluginConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/K8sPluginConfigClass.java @@ -19,86 +19,101 @@ package org.apache.geaflow.console.core.model.plugin.config; -import lombok.Getter; -import lombok.Setter; -import lombok.extern.slf4j.Slf4j; import org.apache.geaflow.console.common.util.NetworkUtil; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.slf4j.Slf4j; + @Slf4j @Getter @Setter public class K8sPluginConfigClass extends PluginConfigClass { - @GeaflowConfigKey(value = "kubernetes.master.url", comment = "i18n.key.k8s.server.url") - @GeaflowConfigValue(required = true, defaultValue = "https://0.0.0.0:6443") - private String masterUrl; - - @GeaflowConfigKey(value = "kubernetes.container.image", comment = "i18n.key.geaflow.registry.address") - @GeaflowConfigValue(required = true, defaultValue = "tugraph/geaflow:0.1") - private String imageUrl; - - @GeaflowConfigKey(value = "kubernetes.service.account", comment = "i18n.key.api.service.username") - @GeaflowConfigValue(defaultValue = "geaflow") - private String serviceAccount; - - @GeaflowConfigKey(value = "kubernetes.service.exposed.type", comment = "i18n.key.api.service.type") - @GeaflowConfigValue(defaultValue = "NODE_PORT") - private String serviceType; - - @GeaflowConfigKey(value = "kubernetes.namespace", comment = "i18n.key.namespace") - @GeaflowConfigValue(defaultValue = "default") - private String namespace; - - @GeaflowConfigKey(value = "kubernetes.cert.data", comment = "i18n.key.client.cert.data") - private String certData; - - @GeaflowConfigKey(value = "kubernetes.cert.key", comment = "i18n.key.client.cert.key") - private String certKey; - - @GeaflowConfigKey(value = "kubernetes.ca.data", comment = "i18n.key.cluster.ca.data") - private String caData; - - @GeaflowConfigKey(value = "kubernetes.connection.retry.times", comment = "i18n.key.retry.times") - @GeaflowConfigValue(defaultValue = "100") - private Integer retryTimes; - - @GeaflowConfigKey(value = "kubernetes.cluster.name", comment = "i18n.key.cluster.name") - private String clusterName; - - @GeaflowConfigKey(value = "kubernetes.pod.user.labels", comment = "i18n.key.pod.user.labels") - private String podUserLabels; - - @GeaflowConfigKey(value = "kubernetes.service.suffix", comment = "i18n.key.api.service.suffix") - private String serviceSuffix; - - @GeaflowConfigKey(value = "kubernetes.resource.storage.limit.size", comment = "i18n.key.storage.limit") - @GeaflowConfigValue(defaultValue = "10Gi") - private String storageLimit; - - @GeaflowConfigKey(value = "kubernetes.geaflow.cluster.timeout.ms", comment = "i18n.key.client.timeout") - @GeaflowConfigValue(defaultValue = "300000") - private Integer clientTimeout; - - @GeaflowConfigKey(value = "kubernetes.container.image.pullPolicy", comment = "i18n.key.image.pull.policy") - @GeaflowConfigValue(defaultValue = "Always") - private String pullPolicy; - - @GeaflowConfigKey(value = "kubernetes.certs.client.key.algo", comment = "i18n.key.client.cert.key.algo") - private String certKeyAlgo; - - @GeaflowConfigKey(value = "kubernetes.engine.jar.pull.always", comment = "i18n.key.engine.jar.pull.always") - @GeaflowConfigValue(defaultValue = "true") - private String alwaysPullEngineJar; - - public K8sPluginConfigClass() { - super(GeaflowPluginType.K8S); - } - - @Override - public void testConnection() { - NetworkUtil.testUrl(masterUrl); - } + @GeaflowConfigKey(value = "kubernetes.master.url", comment = "i18n.key.k8s.server.url") + @GeaflowConfigValue(required = true, defaultValue = "https://0.0.0.0:6443") + private String masterUrl; + + @GeaflowConfigKey( + value = "kubernetes.container.image", + comment = "i18n.key.geaflow.registry.address") + @GeaflowConfigValue(required = true, defaultValue = "tugraph/geaflow:0.1") + private String imageUrl; + + @GeaflowConfigKey(value = "kubernetes.service.account", comment = "i18n.key.api.service.username") + @GeaflowConfigValue(defaultValue = "geaflow") + private String serviceAccount; + + @GeaflowConfigKey( + value = "kubernetes.service.exposed.type", + comment = "i18n.key.api.service.type") + @GeaflowConfigValue(defaultValue = "NODE_PORT") + private String serviceType; + + @GeaflowConfigKey(value = "kubernetes.namespace", comment = "i18n.key.namespace") + @GeaflowConfigValue(defaultValue = "default") + private String namespace; + + @GeaflowConfigKey(value = "kubernetes.cert.data", comment = "i18n.key.client.cert.data") + private String certData; + + @GeaflowConfigKey(value = "kubernetes.cert.key", comment = "i18n.key.client.cert.key") + private String certKey; + + @GeaflowConfigKey(value = "kubernetes.ca.data", comment = "i18n.key.cluster.ca.data") + private String caData; + + @GeaflowConfigKey(value = "kubernetes.connection.retry.times", comment = "i18n.key.retry.times") + @GeaflowConfigValue(defaultValue = "100") + private Integer retryTimes; + + @GeaflowConfigKey(value = "kubernetes.cluster.name", comment = "i18n.key.cluster.name") + private String clusterName; + + @GeaflowConfigKey(value = "kubernetes.pod.user.labels", comment = "i18n.key.pod.user.labels") + private String podUserLabels; + + @GeaflowConfigKey(value = "kubernetes.service.suffix", comment = "i18n.key.api.service.suffix") + private String serviceSuffix; + + @GeaflowConfigKey( + value = "kubernetes.resource.storage.limit.size", + comment = "i18n.key.storage.limit") + @GeaflowConfigValue(defaultValue = "10Gi") + private String storageLimit; + + @GeaflowConfigKey( + value = "kubernetes.geaflow.cluster.timeout.ms", + comment = "i18n.key.client.timeout") + @GeaflowConfigValue(defaultValue = "300000") + private Integer clientTimeout; + + @GeaflowConfigKey( + value = "kubernetes.container.image.pullPolicy", + comment = "i18n.key.image.pull.policy") + @GeaflowConfigValue(defaultValue = "Always") + private String pullPolicy; + + @GeaflowConfigKey( + value = "kubernetes.certs.client.key.algo", + comment = "i18n.key.client.cert.key.algo") + private String certKeyAlgo; + + @GeaflowConfigKey( + value = "kubernetes.engine.jar.pull.always", + comment = "i18n.key.engine.jar.pull.always") + @GeaflowConfigValue(defaultValue = "true") + private String alwaysPullEngineJar; + + public K8sPluginConfigClass() { + super(GeaflowPluginType.K8S); + } + + @Override + public void testConnection() { + NetworkUtil.testUrl(masterUrl); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/KafkaPluginConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/KafkaPluginConfigClass.java index 14346147b..31869de13 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/KafkaPluginConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/KafkaPluginConfigClass.java @@ -19,35 +19,36 @@ package org.apache.geaflow.console.core.model.plugin.config; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.common.util.NetworkUtil; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class KafkaPluginConfigClass extends PluginConfigClass { - @GeaflowConfigKey(value = "geaflow.dsl.kafka.servers", comment = "i18n.key.servers") - @GeaflowConfigValue(required = true, defaultValue = "0.0.0.0:9092") - private String servers; + @GeaflowConfigKey(value = "geaflow.dsl.kafka.servers", comment = "i18n.key.servers") + @GeaflowConfigValue(required = true, defaultValue = "0.0.0.0:9092") + private String servers; - @GeaflowConfigKey(value = "geaflow.dsl.kafka.group.id", comment = "i18n.key.group.id") - @GeaflowConfigValue(required = true) - private String group; + @GeaflowConfigKey(value = "geaflow.dsl.kafka.group.id", comment = "i18n.key.group.id") + @GeaflowConfigValue(required = true) + private String group; - @GeaflowConfigKey(value = "geaflow.dsl.kafka.topic", comment = "i18n.key.topic") - @GeaflowConfigValue(required = true) - private String topic; + @GeaflowConfigKey(value = "geaflow.dsl.kafka.topic", comment = "i18n.key.topic") + @GeaflowConfigValue(required = true) + private String topic; - public KafkaPluginConfigClass() { - super(GeaflowPluginType.KAFKA); - } + public KafkaPluginConfigClass() { + super(GeaflowPluginType.KAFKA); + } - @Override - public void testConnection() { - NetworkUtil.testUrls(servers, ","); - } + @Override + public void testConnection() { + NetworkUtil.testUrls(servers, ","); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/LocalPluginConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/LocalPluginConfigClass.java index 75c65e3a0..4a5dd8a23 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/LocalPluginConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/LocalPluginConfigClass.java @@ -19,16 +19,16 @@ package org.apache.geaflow.console.core.model.plugin.config; +import org.apache.geaflow.console.common.util.type.GeaflowPluginType; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowPluginType; @Getter @Setter public class LocalPluginConfigClass extends PersistentPluginConfigClass { - public LocalPluginConfigClass() { - super(GeaflowPluginType.LOCAL); - } - + public LocalPluginConfigClass() { + super(GeaflowPluginType.LOCAL); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/MemoryPluginConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/MemoryPluginConfigClass.java index 4277b772e..ae05b493e 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/MemoryPluginConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/MemoryPluginConfigClass.java @@ -19,15 +19,16 @@ package org.apache.geaflow.console.core.model.plugin.config; +import org.apache.geaflow.console.common.util.type.GeaflowPluginType; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowPluginType; @Getter @Setter public class MemoryPluginConfigClass extends PluginConfigClass { - public MemoryPluginConfigClass() { - super(GeaflowPluginType.MEMORY); - } + public MemoryPluginConfigClass() { + super(GeaflowPluginType.MEMORY); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/OssPluginConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/OssPluginConfigClass.java index 8a2ed235a..adc99343a 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/OssPluginConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/OssPluginConfigClass.java @@ -19,45 +19,46 @@ package org.apache.geaflow.console.core.model.plugin.config; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.common.util.NetworkUtil; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class OssPluginConfigClass extends PersistentPluginConfigClass { - @GeaflowConfigKey(value = "geaflow.file.oss.endpoint", comment = "i18n.key.endpoint") - @GeaflowConfigValue(required = true, defaultValue = "cn-hangzhou.alipay.aliyun-inc.com") - private String endpoint; + @GeaflowConfigKey(value = "geaflow.file.oss.endpoint", comment = "i18n.key.endpoint") + @GeaflowConfigValue(required = true, defaultValue = "cn-hangzhou.alipay.aliyun-inc.com") + private String endpoint; - @GeaflowConfigKey(value = "geaflow.file.oss.access.id", comment = "i18n.key.access.id") - @GeaflowConfigValue(required = true) - private String accessId; + @GeaflowConfigKey(value = "geaflow.file.oss.access.id", comment = "i18n.key.access.id") + @GeaflowConfigValue(required = true) + private String accessId; - @GeaflowConfigKey(value = "geaflow.file.oss.secret.key", comment = "i18n.key.secret.key") - @GeaflowConfigValue(required = true, masked = true) - private String secretKey; + @GeaflowConfigKey(value = "geaflow.file.oss.secret.key", comment = "i18n.key.secret.key") + @GeaflowConfigValue(required = true, masked = true) + private String secretKey; - @GeaflowConfigKey(value = "geaflow.file.oss.bucket.name", comment = "i18n.key.bucket") - @GeaflowConfigValue(required = true) - private String bucket; + @GeaflowConfigKey(value = "geaflow.file.oss.bucket.name", comment = "i18n.key.bucket") + @GeaflowConfigValue(required = true) + private String bucket; - public OssPluginConfigClass() { - super(GeaflowPluginType.OSS); - } + public OssPluginConfigClass() { + super(GeaflowPluginType.OSS); + } - @Override - public void testConnection() { - if (NetworkUtil.getPort(endpoint) == null) { - String host = NetworkUtil.getHost(endpoint); - NetworkUtil.testHostPort(host, 80); + @Override + public void testConnection() { + if (NetworkUtil.getPort(endpoint) == null) { + String host = NetworkUtil.getHost(endpoint); + NetworkUtil.testHostPort(host, 80); - } else { - NetworkUtil.testUrl(endpoint); - } + } else { + NetworkUtil.testUrl(endpoint); } + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/PersistentPluginConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/PersistentPluginConfigClass.java index 83476cfdc..3ec7144ae 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/PersistentPluginConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/PersistentPluginConfigClass.java @@ -19,29 +19,39 @@ package org.apache.geaflow.console.core.model.plugin.config; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public abstract class PersistentPluginConfigClass extends PluginConfigClass { - @GeaflowConfigKey(value = "geaflow.file.persistent.root", comment = "i18n.key.root.path", jsonIgnore = true) - @GeaflowConfigValue(required = true, defaultValue = "/") - private String root; + @GeaflowConfigKey( + value = "geaflow.file.persistent.root", + comment = "i18n.key.root.path", + jsonIgnore = true) + @GeaflowConfigValue(required = true, defaultValue = "/") + private String root; - @GeaflowConfigKey(value = "geaflow.file.persistent.user.name", comment = "i18n.key.username", jsonIgnore = true) - @GeaflowConfigValue - private String username; + @GeaflowConfigKey( + value = "geaflow.file.persistent.user.name", + comment = "i18n.key.username", + jsonIgnore = true) + @GeaflowConfigValue + private String username; - @GeaflowConfigKey(value = "geaflow.file.persistent.thread.size", comment = "i18n.key.local.thread.pool.count", jsonIgnore = true) - @GeaflowConfigValue - private Integer threadSize; + @GeaflowConfigKey( + value = "geaflow.file.persistent.thread.size", + comment = "i18n.key.local.thread.pool.count", + jsonIgnore = true) + @GeaflowConfigValue + private Integer threadSize; - public PersistentPluginConfigClass(GeaflowPluginType type) { - super(type); - } + public PersistentPluginConfigClass(GeaflowPluginType type) { + super(type); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/PluginConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/PluginConfigClass.java index 0d80ca8fe..55a5f111a 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/PluginConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/PluginConfigClass.java @@ -19,20 +19,19 @@ package org.apache.geaflow.console.core.model.plugin.config; +import org.apache.geaflow.console.common.util.type.GeaflowPluginType; +import org.apache.geaflow.console.core.model.config.GeaflowConfigClass; + import lombok.AllArgsConstructor; import lombok.Getter; import lombok.extern.slf4j.Slf4j; -import org.apache.geaflow.console.common.util.type.GeaflowPluginType; -import org.apache.geaflow.console.core.model.config.GeaflowConfigClass; @Slf4j @Getter @AllArgsConstructor public abstract class PluginConfigClass extends GeaflowConfigClass { - private GeaflowPluginType type; - - public void testConnection() { + private GeaflowPluginType type; - } + public void testConnection() {} } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/RayPluginConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/RayPluginConfigClass.java index 0f8bd9c8a..e8f4a7c78 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/RayPluginConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/RayPluginConfigClass.java @@ -19,36 +19,38 @@ package org.apache.geaflow.console.core.model.plugin.config; -import lombok.Getter; -import lombok.Setter; -import lombok.extern.slf4j.Slf4j; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.slf4j.Slf4j; + @Slf4j @Getter @Setter public class RayPluginConfigClass extends PluginConfigClass { - public RayPluginConfigClass() { - super(GeaflowPluginType.RAY); - } - - @GeaflowConfigKey(value = "ray.dashboard.address", comment = "ray.dashboard.address") - @GeaflowConfigValue(required = true, defaultValue = "http://127.0.0.1:8090") - private String dashboardAddress; + public RayPluginConfigClass() { + super(GeaflowPluginType.RAY); + } - @GeaflowConfigKey(value = "ray.redis.address", comment = "ray.redis.address") - @GeaflowConfigValue(required = true, defaultValue = "127.0.0.1:6379") - private String redisAddress; + @GeaflowConfigKey(value = "ray.dashboard.address", comment = "ray.dashboard.address") + @GeaflowConfigValue(required = true, defaultValue = "http://127.0.0.1:8090") + private String dashboardAddress; + @GeaflowConfigKey(value = "ray.redis.address", comment = "ray.redis.address") + @GeaflowConfigValue(required = true, defaultValue = "127.0.0.1:6379") + private String redisAddress; - @GeaflowConfigKey(value = "ray.dist.jar.path", comment = "ray.dist.jar.path") - @GeaflowConfigValue(required = true) - private String distJarPath; + @GeaflowConfigKey(value = "ray.dist.jar.path", comment = "ray.dist.jar.path") + @GeaflowConfigValue(required = true) + private String distJarPath; - @GeaflowConfigKey(value = "ray.session.resource.jar.path", comment = "ray.session.resource.jar.path") - @GeaflowConfigValue(required = true) - private String sessionResourceJarPath; + @GeaflowConfigKey( + value = "ray.session.resource.jar.path", + comment = "ray.session.resource.jar.path") + @GeaflowConfigValue(required = true) + private String sessionResourceJarPath; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/RedisPluginConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/RedisPluginConfigClass.java index 9e81c605d..bff3e5de9 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/RedisPluginConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/RedisPluginConfigClass.java @@ -19,13 +19,14 @@ package org.apache.geaflow.console.core.model.plugin.config; -import lombok.Getter; -import lombok.Setter; -import lombok.ToString; import org.apache.geaflow.console.common.util.exception.GeaflowIllegalException; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; + +import lombok.Getter; +import lombok.Setter; +import lombok.ToString; import redis.clients.jedis.Jedis; @Getter @@ -33,43 +34,48 @@ @ToString public class RedisPluginConfigClass extends PluginConfigClass { - @GeaflowConfigKey(value = "geaflow.store.redis.host", comment = "i18n.key.host") - @GeaflowConfigValue(required = true, defaultValue = "0.0.0.0") - private String host; + @GeaflowConfigKey(value = "geaflow.store.redis.host", comment = "i18n.key.host") + @GeaflowConfigValue(required = true, defaultValue = "0.0.0.0") + private String host; - @GeaflowConfigKey(value = "geaflow.store.redis.port", comment = "i18n.key.port") - @GeaflowConfigValue(required = true, defaultValue = "6379") - private Integer port; + @GeaflowConfigKey(value = "geaflow.store.redis.port", comment = "i18n.key.port") + @GeaflowConfigValue(required = true, defaultValue = "6379") + private Integer port; - @GeaflowConfigKey(value = "geaflow.store.redis.user", comment = "i18n.key.user") - private String user; + @GeaflowConfigKey(value = "geaflow.store.redis.user", comment = "i18n.key.user") + private String user; - @GeaflowConfigKey(value = "geaflow.store.redis.password", comment = "i18n.key.password") - private String password; + @GeaflowConfigKey(value = "geaflow.store.redis.password", comment = "i18n.key.password") + private String password; - @GeaflowConfigKey(value = "geaflow.store.redis.connection.timeout", comment = "i18n.key.connection.timeout") - @GeaflowConfigValue(defaultValue = "5000") - private Integer connectionTimeoutMs; + @GeaflowConfigKey( + value = "geaflow.store.redis.connection.timeout", + comment = "i18n.key.connection.timeout") + @GeaflowConfigValue(defaultValue = "5000") + private Integer connectionTimeoutMs; - @GeaflowConfigKey(value = "geaflow.store.redis.retry.times", comment = "i18n.key.retry.times") - @GeaflowConfigValue(defaultValue = "10") - private Integer retryTimes; + @GeaflowConfigKey(value = "geaflow.store.redis.retry.times", comment = "i18n.key.retry.times") + @GeaflowConfigValue(defaultValue = "10") + private Integer retryTimes; - @GeaflowConfigKey(value = "geaflow.store.redis.retry.interval.ms", comment = "i18n.key.retry.interval.ms") - @GeaflowConfigValue(defaultValue = "500") - private Integer retryIntervalMs; + @GeaflowConfigKey( + value = "geaflow.store.redis.retry.interval.ms", + comment = "i18n.key.retry.interval.ms") + @GeaflowConfigValue(defaultValue = "500") + private Integer retryIntervalMs; - public RedisPluginConfigClass() { - super(GeaflowPluginType.REDIS); - } + public RedisPluginConfigClass() { + super(GeaflowPluginType.REDIS); + } - @Override - public void testConnection() { - try (Jedis jedis = new Jedis(host, port)) { - jedis.ping(); + @Override + public void testConnection() { + try (Jedis jedis = new Jedis(host, port)) { + jedis.ping(); - } catch (Exception e) { - throw new GeaflowIllegalException("Redis connection test failed, caused by {}", e.getMessage(), e); - } + } catch (Exception e) { + throw new GeaflowIllegalException( + "Redis connection test failed, caused by {}", e.getMessage(), e); } + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/RocksdbPluginConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/RocksdbPluginConfigClass.java index 9589c09df..182391783 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/RocksdbPluginConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/RocksdbPluginConfigClass.java @@ -23,7 +23,7 @@ public class RocksdbPluginConfigClass extends PluginConfigClass { - public RocksdbPluginConfigClass() { - super(GeaflowPluginType.ROCKSDB); - } + public RocksdbPluginConfigClass() { + super(GeaflowPluginType.ROCKSDB); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/SocketPluginConfigClass.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/SocketPluginConfigClass.java index 03d62af2a..6bbcf071f 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/SocketPluginConfigClass.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/plugin/config/SocketPluginConfigClass.java @@ -19,33 +19,34 @@ package org.apache.geaflow.console.core.model.plugin.config; -import lombok.Getter; -import lombok.Setter; -import lombok.extern.slf4j.Slf4j; import org.apache.geaflow.console.common.util.NetworkUtil; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.config.GeaflowConfigKey; import org.apache.geaflow.console.core.model.config.GeaflowConfigValue; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.slf4j.Slf4j; + @Slf4j @Getter @Setter public class SocketPluginConfigClass extends PluginConfigClass { - @GeaflowConfigKey(value = "geaflow.dsl.socket.host", comment = "i18n.key.host") - @GeaflowConfigValue(required = true) - private String host; + @GeaflowConfigKey(value = "geaflow.dsl.socket.host", comment = "i18n.key.host") + @GeaflowConfigValue(required = true) + private String host; - @GeaflowConfigKey(value = "geaflow.dsl.socket.port", comment = "i18n.key.port") - @GeaflowConfigValue(required = true) - private Integer port; + @GeaflowConfigKey(value = "geaflow.dsl.socket.port", comment = "i18n.key.port") + @GeaflowConfigValue(required = true) + private Integer port; - public SocketPluginConfigClass() { - super(GeaflowPluginType.SOCKET); - } + public SocketPluginConfigClass() { + super(GeaflowPluginType.SOCKET); + } - @Override - public void testConnection() { - NetworkUtil.testHostPort(host, port); - } + @Override + public void testConnection() { + NetworkUtil.testHostPort(host, port); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/release/GeaflowRelease.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/release/GeaflowRelease.java index 753a234bf..5ab200b77 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/release/GeaflowRelease.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/release/GeaflowRelease.java @@ -19,49 +19,51 @@ package org.apache.geaflow.console.core.model.release; -import com.google.common.base.Preconditions; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.core.model.GeaflowId; import org.apache.geaflow.console.core.model.cluster.GeaflowCluster; import org.apache.geaflow.console.core.model.config.GeaflowConfig; import org.apache.geaflow.console.core.model.job.GeaflowJob; import org.apache.geaflow.console.core.model.version.GeaflowVersion; +import com.google.common.base.Preconditions; + +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class GeaflowRelease extends GeaflowId { - private GeaflowConfig jobConfig = new GeaflowConfig(); - private GeaflowConfig clusterConfig = new GeaflowConfig(); - private GeaflowJob job; - private GeaflowVersion version; - private JobPlan jobPlan; - private GeaflowCluster cluster; + private GeaflowConfig jobConfig = new GeaflowConfig(); + private GeaflowConfig clusterConfig = new GeaflowConfig(); + private GeaflowJob job; + private GeaflowVersion version; + private JobPlan jobPlan; + private GeaflowCluster cluster; - private int releaseVersion; + private int releaseVersion; - private String url; + private String url; - private String md5; + private String md5; - @Override - public void validate() { - super.validate(); - Preconditions.checkNotNull(job, "Invalid job"); - if (!job.isApiJob()) { - Preconditions.checkNotNull(jobPlan, "Invalid jobPlan"); - } - Preconditions.checkNotNull(version, "Invalid version"); - Preconditions.checkNotNull(cluster, "Invalid cluster"); - Preconditions.checkArgument(releaseVersion >= 1); + @Override + public void validate() { + super.validate(); + Preconditions.checkNotNull(job, "Invalid job"); + if (!job.isApiJob()) { + Preconditions.checkNotNull(jobPlan, "Invalid jobPlan"); } + Preconditions.checkNotNull(version, "Invalid version"); + Preconditions.checkNotNull(cluster, "Invalid cluster"); + Preconditions.checkArgument(releaseVersion >= 1); + } - public void addJobConfig(GeaflowConfig config) { - this.jobConfig.putAll(config); - } + public void addJobConfig(GeaflowConfig config) { + this.jobConfig.putAll(config); + } - public void addClusterConfig(GeaflowConfig config) { - this.clusterConfig.putAll(config); - } + public void addClusterConfig(GeaflowConfig config) { + this.clusterConfig.putAll(config); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/release/JobPlan.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/release/JobPlan.java index d616bd247..5376f65ac 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/release/JobPlan.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/release/JobPlan.java @@ -19,10 +19,12 @@ package org.apache.geaflow.console.core.model.release; -import com.alibaba.fastjson.JSON; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; + +import com.alibaba.fastjson.JSON; + import lombok.AllArgsConstructor; import lombok.Getter; import lombok.Setter; @@ -31,40 +33,38 @@ @Setter public class JobPlan { - private final Map vertices = new LinkedHashMap<>(); - - private final Map> edgeMap = new LinkedHashMap<>(); - - public static JobPlan build(String json) { - return JSON.parseObject(json, JobPlan.class); - } + private final Map vertices = new LinkedHashMap<>(); - public String toJsonString() { - return JSON.toJSONString(this); - } + private final Map> edgeMap = new LinkedHashMap<>(); - @Getter - @Setter - public static class PlanVertex { + public static JobPlan build(String json) { + return JSON.parseObject(json, JobPlan.class); + } - private String key; + public String toJsonString() { + return JSON.toJSONString(this); + } - private int parallelism; + @Getter + @Setter + public static class PlanVertex { - private JobPlan innerPlan; - } + private String key; - @Getter - @Setter - @AllArgsConstructor - public static class PlanEdge { + private int parallelism; - private String sourceKey; + private JobPlan innerPlan; + } - private String targetKey; + @Getter + @Setter + @AllArgsConstructor + public static class PlanEdge { - private String type; - } + private String sourceKey; + private String targetKey; + private String type; + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/release/JobPlanBuilder.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/release/JobPlanBuilder.java index eb9a20ee9..69790844d 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/release/JobPlanBuilder.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/release/JobPlanBuilder.java @@ -23,112 +23,114 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import lombok.Getter; -import lombok.Setter; + import org.apache.geaflow.console.common.service.integration.engine.JsonPlan; import org.apache.geaflow.console.common.service.integration.engine.Predecessor; import org.apache.geaflow.console.common.service.integration.engine.Vertex; import org.apache.geaflow.console.core.model.release.JobPlan.PlanEdge; import org.apache.geaflow.console.core.model.release.JobPlan.PlanVertex; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class JobPlanBuilder { - public static JobPlan build(JsonPlan jsonPlan) { - JobPlanContext jobPlanContext = new JobPlanContext(); - return jobPlanContext.build(jsonPlan); + public static JobPlan build(JsonPlan jsonPlan) { + JobPlanContext jobPlanContext = new JobPlanContext(); + return jobPlanContext.build(jsonPlan); + } + + public static void setParallelisms(JobPlan jobPlan, Map map) { + for (PlanVertex vertex : jobPlan.getVertices().values()) { + String key = vertex.getKey(); + if (map.containsKey(key)) { + int parallelism = map.get(key); + // set inner parallelism recursively + setInnerParallelism(vertex, parallelism, map); + } } + } - public static void setParallelisms(JobPlan jobPlan, Map map) { - for (PlanVertex vertex : jobPlan.getVertices().values()) { - String key = vertex.getKey(); - if (map.containsKey(key)) { - int parallelism = map.get(key); - // set inner parallelism recursively - setInnerParallelism(vertex, parallelism, map); - } - } + private static void setInnerParallelism( + PlanVertex vertex, int parallelism, Map map) { + if (vertex.getInnerPlan() == null) { + return; } - private static void setInnerParallelism(PlanVertex vertex, int parallelism, Map map) { - if (vertex.getInnerPlan() == null) { - return; - } - - for (PlanVertex v : vertex.getInnerPlan().getVertices().values()) { - map.put(v.getKey(), parallelism); - setInnerParallelism(v, parallelism, map); - } + for (PlanVertex v : vertex.getInnerPlan().getVertices().values()) { + map.put(v.getKey(), parallelism); + setInnerParallelism(v, parallelism, map); } - - public static Map getParallelismMap(JobPlan jobPlan) { - Map map = new HashMap<>(); - getInnerParallelism(jobPlan, map); - return map; + } + + public static Map getParallelismMap(JobPlan jobPlan) { + Map map = new HashMap<>(); + getInnerParallelism(jobPlan, map); + return map; + } + + private static void getInnerParallelism(JobPlan jobPlan, Map map) { + for (PlanVertex v : jobPlan.getVertices().values()) { + map.put(v.getKey(), v.getParallelism()); + if (v.getInnerPlan() != null) { + getInnerParallelism(v.getInnerPlan(), map); + } } - - private static void getInnerParallelism(JobPlan jobPlan, Map map) { - for (PlanVertex v : jobPlan.getVertices().values()) { - map.put(v.getKey(), v.getParallelism()); - if (v.getInnerPlan() != null) { - getInnerParallelism(v.getInnerPlan(), map); - } - } + } + + private static class JobPlanContext { + + private int level = 0; + private Map keyMap = new HashMap<>(); + + private JobPlan build(JsonPlan jsonPlan) { + + JobPlan jobPlan = new JobPlan(); + Map vertexMap = jsonPlan.getVertices(); + for (Vertex vertex : vertexMap.values()) { + level++; + keyMap.put(vertex.getId(), getVertexKey(vertex)); + PlanVertex planVertex = getPlanVertex(vertex); + jobPlan.getVertices().put(planVertex.getKey(), planVertex); + level--; + } + // set edges after getting the mapping of vertex ids and keys + for (Vertex vertex : vertexMap.values()) { + setEdgeMap(jobPlan.getEdgeMap(), vertex); + } + return jobPlan; } - private static class JobPlanContext { - - private int level = 0; - private Map keyMap = new HashMap<>(); - - private JobPlan build(JsonPlan jsonPlan) { - - JobPlan jobPlan = new JobPlan(); - Map vertexMap = jsonPlan.getVertices(); - for (Vertex vertex : vertexMap.values()) { - level++; - keyMap.put(vertex.getId(), getVertexKey(vertex)); - PlanVertex planVertex = getPlanVertex(vertex); - jobPlan.getVertices().put(planVertex.getKey(), planVertex); - level--; - } - // set edges after getting the mapping of vertex ids and keys - for (Vertex vertex : vertexMap.values()) { - setEdgeMap(jobPlan.getEdgeMap(), vertex); - } - return jobPlan; - } - - private void setEdgeMap(Map> edgeMap, Vertex vertex) { - List parents = vertex.getParents(); - for (Predecessor parent : parents) { - String sourceVertexKey = keyMap.get(parent.getId()); - String targetVertexKey = keyMap.get(vertex.getId()); - edgeMap.putIfAbsent(sourceVertexKey, new ArrayList<>()); - - PlanEdge planEdge = new PlanEdge(sourceVertexKey, targetVertexKey, parent.getPartitionType()); - edgeMap.get(sourceVertexKey).add(planEdge); - } - } - - private PlanVertex getPlanVertex(Vertex vertex) { - PlanVertex planVertex = new PlanVertex(); - planVertex.setKey(keyMap.get(vertex.getId())); - planVertex.setParallelism(vertex.getParallelism()); - // set innerPlan recursively - if (vertex.getInnerPlan() != null) { - JobPlan jobPlan = build(vertex.getInnerPlan()); - planVertex.setInnerPlan(jobPlan); - } - - return planVertex; - } - - private String getVertexKey(Vertex vertex) { - return level == 1 ? vertex.getVertexType() + "-" + vertex.getId() : vertex.getOperatorName(); - } + private void setEdgeMap(Map> edgeMap, Vertex vertex) { + List parents = vertex.getParents(); + for (Predecessor parent : parents) { + String sourceVertexKey = keyMap.get(parent.getId()); + String targetVertexKey = keyMap.get(vertex.getId()); + edgeMap.putIfAbsent(sourceVertexKey, new ArrayList<>()); + + PlanEdge planEdge = + new PlanEdge(sourceVertexKey, targetVertexKey, parent.getPartitionType()); + edgeMap.get(sourceVertexKey).add(planEdge); + } } + private PlanVertex getPlanVertex(Vertex vertex) { + PlanVertex planVertex = new PlanVertex(); + planVertex.setKey(keyMap.get(vertex.getId())); + planVertex.setParallelism(vertex.getParallelism()); + // set innerPlan recursively + if (vertex.getInnerPlan() != null) { + JobPlan jobPlan = build(vertex.getInnerPlan()); + planVertex.setInnerPlan(jobPlan); + } + + return planVertex; + } + private String getVertexKey(Vertex vertex) { + return level == 1 ? vertex.getVertexType() + "-" + vertex.getId() : vertex.getOperatorName(); + } + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/release/ReleaseUpdate.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/release/ReleaseUpdate.java index 448800fe9..119524c37 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/release/ReleaseUpdate.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/release/ReleaseUpdate.java @@ -20,23 +20,25 @@ package org.apache.geaflow.console.core.model.release; import java.util.Map; -import lombok.Getter; -import lombok.Setter; + import org.apache.geaflow.console.core.model.cluster.GeaflowCluster; import org.apache.geaflow.console.core.model.config.GeaflowConfig; import org.apache.geaflow.console.core.model.version.GeaflowVersion; +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class ReleaseUpdate { - private GeaflowVersion newVersion; + private GeaflowVersion newVersion; - private Map newParallelisms; + private Map newParallelisms; - private GeaflowConfig newJobConfig; + private GeaflowConfig newJobConfig; - private GeaflowConfig newClusterConfig; + private GeaflowConfig newClusterConfig; - private GeaflowCluster newCluster; + private GeaflowCluster newCluster; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/runtime/GeaflowAudit.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/runtime/GeaflowAudit.java index bcb2d344e..87c1b0c22 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/runtime/GeaflowAudit.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/runtime/GeaflowAudit.java @@ -19,34 +19,35 @@ package org.apache.geaflow.console.core.model.runtime; -import lombok.Getter; -import lombok.NoArgsConstructor; -import lombok.Setter; import org.apache.geaflow.console.common.util.type.GeaflowOperationType; import org.apache.geaflow.console.common.util.type.GeaflowResourceType; import org.apache.geaflow.console.core.model.GeaflowId; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + @Getter @Setter @NoArgsConstructor public class GeaflowAudit extends GeaflowId { - private GeaflowOperationType operationType; + private GeaflowOperationType operationType; - private String resourceId; + private String resourceId; - private GeaflowResourceType resourceType; + private GeaflowResourceType resourceType; - private String detail; + private String detail; - public GeaflowAudit(String taskId, GeaflowOperationType operationType) { - this(taskId, operationType, null); - } + public GeaflowAudit(String taskId, GeaflowOperationType operationType) { + this(taskId, operationType, null); + } - public GeaflowAudit(String taskId, GeaflowOperationType operationType, String detail) { - this.operationType = operationType; - this.resourceId = taskId; - this.resourceType = GeaflowResourceType.TASK; - this.detail = detail; - } + public GeaflowAudit(String taskId, GeaflowOperationType operationType, String detail) { + this.operationType = operationType; + this.resourceId = taskId; + this.resourceType = GeaflowResourceType.TASK; + this.detail = detail; + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/runtime/GeaflowCycle.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/runtime/GeaflowCycle.java index d3e26f579..09761d921 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/runtime/GeaflowCycle.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/runtime/GeaflowCycle.java @@ -26,20 +26,19 @@ @Setter public class GeaflowCycle { - private String name; - private String pipelineName; - private String opName; - - private Long duration; - private Long startTime; - private Long totalTasks; - private Integer slowestTask; - private Long slowestTaskExecuteTime; - private Long inputRecords; - private Long inputKb; - private Long outputRecords; - private Long outputKb; - private Long avgGcTime; - private Long avgExecuteTime; + private String name; + private String pipelineName; + private String opName; + private Long duration; + private Long startTime; + private Long totalTasks; + private Integer slowestTask; + private Long slowestTaskExecuteTime; + private Long inputRecords; + private Long inputKb; + private Long outputRecords; + private Long outputKb; + private Long avgGcTime; + private Long avgExecuteTime; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/runtime/GeaflowError.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/runtime/GeaflowError.java index 43b7b8d4f..c32f48d4c 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/runtime/GeaflowError.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/runtime/GeaflowError.java @@ -26,13 +26,13 @@ @Setter public class GeaflowError { - private String timeStamp; + private String timeStamp; - private String hostname; + private String hostname; - private Integer processId; + private Integer processId; - private String severity; + private String severity; - private String message; + private String message; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/runtime/GeaflowOffset.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/runtime/GeaflowOffset.java index 6bf3aa402..3e5dd0ce6 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/runtime/GeaflowOffset.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/runtime/GeaflowOffset.java @@ -19,33 +19,34 @@ package org.apache.geaflow.console.core.model.runtime; +import org.apache.geaflow.console.core.model.GeaflowId; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.core.model.GeaflowId; @Getter @Setter public class GeaflowOffset extends GeaflowId { - private long offset; - private long writeTime; - private OffsetType type; - private String partitionName; - private long diff; + private long offset; + private long writeTime; + private OffsetType type; + private String partitionName; + private long diff; - public void formatTime() { - if (type == OffsetType.TIMESTAMP) { - String offsetString = String.valueOf(offset); - if (offsetString.length() == 10) { - // transfer to millisecond - offset = offset * 1000; - } - diff = writeTime - offset; - } + public void formatTime() { + if (type == OffsetType.TIMESTAMP) { + String offsetString = String.valueOf(offset); + if (offsetString.length() == 10) { + // transfer to millisecond + offset = offset * 1000; + } + diff = writeTime - offset; } + } - public enum OffsetType { - TIMESTAMP, - NON_TIMESTAMP - } + public enum OffsetType { + TIMESTAMP, + NON_TIMESTAMP + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/runtime/GeaflowPipeline.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/runtime/GeaflowPipeline.java index 666e84b38..57bbc241b 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/runtime/GeaflowPipeline.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/runtime/GeaflowPipeline.java @@ -26,10 +26,9 @@ @Setter public class GeaflowPipeline { - String name; + String name; - Long startTime; - - Long duration; + Long startTime; + Long duration; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowAuthentication.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowAuthentication.java index 402dd5944..6407ba065 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowAuthentication.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowAuthentication.java @@ -20,27 +20,29 @@ package org.apache.geaflow.console.core.model.security; import java.util.Date; + +import org.apache.geaflow.console.common.util.DateTimeUtil; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.DateTimeUtil; @Getter @Setter public class GeaflowAuthentication { - private String userId; - - private String sessionToken; + private String userId; - private boolean systemSession; + private String sessionToken; - private Date accessTime; + private boolean systemSession; - public boolean isExpired(int liveSeconds) { - if (sessionToken == null) { - return true; - } + private Date accessTime; - return DateTimeUtil.isExpired(accessTime, liveSeconds); + public boolean isExpired(int liveSeconds) { + if (sessionToken == null) { + return true; } + + return DateTimeUtil.isExpired(accessTime, liveSeconds); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowAuthority.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowAuthority.java index b5c3e0b80..b991bfc4a 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowAuthority.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowAuthority.java @@ -21,47 +21,49 @@ import java.util.HashMap; import java.util.Map; -import lombok.AllArgsConstructor; -import lombok.Getter; + import org.apache.geaflow.console.common.util.exception.GeaflowException; import org.apache.geaflow.console.common.util.type.GeaflowAuthorityType; +import lombok.AllArgsConstructor; +import lombok.Getter; + @Getter @AllArgsConstructor public class GeaflowAuthority { - public static final GeaflowAuthority ALL = new GeaflowAuthority(GeaflowAuthorityType.ALL); - public static final GeaflowAuthority QUERY = new GeaflowAuthority(GeaflowAuthorityType.QUERY); - public static final GeaflowAuthority UPDATE = new GeaflowAuthority(GeaflowAuthorityType.UPDATE); - public static final GeaflowAuthority EXECUTE = new GeaflowAuthority(GeaflowAuthorityType.EXECUTE); - private static final Map AUTHORITIES = new HashMap<>(); + public static final GeaflowAuthority ALL = new GeaflowAuthority(GeaflowAuthorityType.ALL); + public static final GeaflowAuthority QUERY = new GeaflowAuthority(GeaflowAuthorityType.QUERY); + public static final GeaflowAuthority UPDATE = new GeaflowAuthority(GeaflowAuthorityType.UPDATE); + public static final GeaflowAuthority EXECUTE = new GeaflowAuthority(GeaflowAuthorityType.EXECUTE); + private static final Map AUTHORITIES = new HashMap<>(); - static { - register(new GeaflowAuthority(GeaflowAuthorityType.ALL)); - register(new GeaflowAuthority(GeaflowAuthorityType.QUERY)); - register(new GeaflowAuthority(GeaflowAuthorityType.UPDATE)); - register(new GeaflowAuthority(GeaflowAuthorityType.EXECUTE)); - } + static { + register(new GeaflowAuthority(GeaflowAuthorityType.ALL)); + register(new GeaflowAuthority(GeaflowAuthorityType.QUERY)); + register(new GeaflowAuthority(GeaflowAuthorityType.UPDATE)); + register(new GeaflowAuthority(GeaflowAuthorityType.EXECUTE)); + } - private GeaflowAuthorityType type; + private GeaflowAuthorityType type; - public static GeaflowAuthority of(GeaflowAuthorityType type) { - GeaflowAuthority authority = AUTHORITIES.get(type); - if (authority == null) { - throw new GeaflowException("Authority type {} not supported", type); - } - return authority; + public static GeaflowAuthority of(GeaflowAuthorityType type) { + GeaflowAuthority authority = AUTHORITIES.get(type); + if (authority == null) { + throw new GeaflowException("Authority type {} not supported", type); } + return authority; + } - private static void register(GeaflowAuthority authority) { - AUTHORITIES.put(authority.type, authority); - } + private static void register(GeaflowAuthority authority) { + AUTHORITIES.put(authority.type, authority); + } - public boolean include(GeaflowAuthority other) { - if (other == null) { - return false; - } - - return other.getType() == type || GeaflowAuthorityType.ALL.equals(type); + public boolean include(GeaflowAuthority other) { + if (other == null) { + return false; } + + return other.getType() == type || GeaflowAuthorityType.ALL.equals(type); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowAuthorization.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowAuthorization.java index f34676fd4..953a707f1 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowAuthorization.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowAuthorization.java @@ -19,13 +19,14 @@ package org.apache.geaflow.console.core.model.security; +import org.apache.geaflow.console.common.util.type.GeaflowAuthorityType; +import org.apache.geaflow.console.common.util.type.GeaflowResourceType; +import org.apache.geaflow.console.core.model.GeaflowId; + import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowAuthorityType; -import org.apache.geaflow.console.common.util.type.GeaflowResourceType; -import org.apache.geaflow.console.core.model.GeaflowId; @Getter @Setter @@ -33,12 +34,11 @@ @NoArgsConstructor public class GeaflowAuthorization extends GeaflowId { - private String userId; - - private GeaflowAuthorityType authorityType; + private String userId; - private GeaflowResourceType resourceType; + private GeaflowAuthorityType authorityType; - private String resourceId; + private GeaflowResourceType resourceType; + private String resourceId; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowGrant.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowGrant.java index b38f5a3bd..223698f76 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowGrant.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowGrant.java @@ -19,23 +19,24 @@ package org.apache.geaflow.console.core.model.security; +import org.apache.geaflow.console.core.model.security.resource.GeaflowResource; + import lombok.AllArgsConstructor; import lombok.Getter; -import org.apache.geaflow.console.core.model.security.resource.GeaflowResource; @Getter @AllArgsConstructor public class GeaflowGrant { - private GeaflowAuthority authority; + private GeaflowAuthority authority; - private GeaflowResource resource; + private GeaflowResource resource; - public boolean include(GeaflowGrant other) { - if (other == null) { - return false; - } - - return authority.include(other.authority) && resource.include(other.resource); + public boolean include(GeaflowGrant other) { + if (other == null) { + return false; } + + return authority.include(other.authority) && resource.include(other.resource); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowQuota.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowQuota.java index 22b86d79b..389629bf0 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowQuota.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowQuota.java @@ -26,16 +26,15 @@ @Setter public class GeaflowQuota { - private long cpuCores; + private long cpuCores; - private long memoryMBytes; + private long memoryMBytes; - private long diskMBytes; + private long diskMBytes; - private long dfsMBytes; + private long dfsMBytes; - private long dfsFiles; - - private long metaConnections; + private long dfsFiles; + private long metaConnections; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowRole.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowRole.java index 4b6d1ffd3..71437aa46 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowRole.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowRole.java @@ -21,7 +21,7 @@ import java.util.LinkedHashSet; import java.util.Set; -import lombok.Getter; + import org.apache.geaflow.console.common.util.context.ContextHolder; import org.apache.geaflow.console.common.util.exception.GeaflowException; import org.apache.geaflow.console.common.util.type.GeaflowResourceType; @@ -29,55 +29,60 @@ import org.apache.geaflow.console.core.model.security.resource.AllResource; import org.apache.geaflow.console.core.model.security.resource.TenantResource; +import lombok.Getter; + @Getter public class GeaflowRole { - public static final GeaflowRole SYSTEM_ADMIN = new GeaflowRole(GeaflowRoleType.SYSTEM_ADMIN, null); + public static final GeaflowRole SYSTEM_ADMIN = + new GeaflowRole(GeaflowRoleType.SYSTEM_ADMIN, null); - public static final GeaflowRole TENANT_ADMIN = new GeaflowRole(GeaflowRoleType.TENANT_ADMIN, SYSTEM_ADMIN); + public static final GeaflowRole TENANT_ADMIN = + new GeaflowRole(GeaflowRoleType.TENANT_ADMIN, SYSTEM_ADMIN); - private final GeaflowRoleType type; + private final GeaflowRoleType type; - private final GeaflowRole parent; + private final GeaflowRole parent; - private GeaflowRole(GeaflowRoleType type, GeaflowRole parent) { - this.type = type; - this.parent = parent; - } + private GeaflowRole(GeaflowRoleType type, GeaflowRole parent) { + this.type = type; + this.parent = parent; + } - public static GeaflowRole of(GeaflowRoleType type) { - switch (type) { - case SYSTEM_ADMIN: - return SYSTEM_ADMIN; - case TENANT_ADMIN: - return TENANT_ADMIN; - default: - throw new GeaflowException("Role type {} not supported", type); - } + public static GeaflowRole of(GeaflowRoleType type) { + switch (type) { + case SYSTEM_ADMIN: + return SYSTEM_ADMIN; + case TENANT_ADMIN: + return TENANT_ADMIN; + default: + throw new GeaflowException("Role type {} not supported", type); } + } - public static Set getGrants(Set roleTypes) { - Set grants = new LinkedHashSet<>(); - for (GeaflowRoleType roleType : roleTypes) { - GeaflowRole role = GeaflowRole.of(roleType); - grants.addAll(role.getGrants()); - } - return grants; + public static Set getGrants(Set roleTypes) { + Set grants = new LinkedHashSet<>(); + for (GeaflowRoleType roleType : roleTypes) { + GeaflowRole role = GeaflowRole.of(roleType); + grants.addAll(role.getGrants()); } + return grants; + } - public Set getGrants() { - Set grants = new LinkedHashSet<>(); - switch (type) { - case SYSTEM_ADMIN: - grants.add(new GeaflowGrant(GeaflowAuthority.ALL, new AllResource(GeaflowResourceType.TENANT))); - break; - case TENANT_ADMIN: - String tenantId = ContextHolder.get().getTenantId(); - grants.add(new GeaflowGrant(GeaflowAuthority.ALL, new TenantResource(tenantId))); - break; - default: - throw new GeaflowException("Role type {} not supported", type); - } - return grants; + public Set getGrants() { + Set grants = new LinkedHashSet<>(); + switch (type) { + case SYSTEM_ADMIN: + grants.add( + new GeaflowGrant(GeaflowAuthority.ALL, new AllResource(GeaflowResourceType.TENANT))); + break; + case TENANT_ADMIN: + String tenantId = ContextHolder.get().getTenantId(); + grants.add(new GeaflowGrant(GeaflowAuthority.ALL, new TenantResource(tenantId))); + break; + default: + throw new GeaflowException("Role type {} not supported", type); } + return grants; + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowTenant.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowTenant.java index c7edd61d0..b568954c7 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowTenant.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowTenant.java @@ -19,14 +19,14 @@ package org.apache.geaflow.console.core.model.security; +import org.apache.geaflow.console.core.model.GeaflowName; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.core.model.GeaflowName; @Getter @Setter public class GeaflowTenant extends GeaflowName { - private GeaflowQuota quota; - + private GeaflowQuota quota; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowUser.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowUser.java index 166a3e73c..b10143b35 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowUser.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowUser.java @@ -19,24 +19,25 @@ package org.apache.geaflow.console.core.model.security; +import org.apache.geaflow.console.core.model.GeaflowName; + import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.geaflow.console.core.model.GeaflowName; @Getter @Setter @NoArgsConstructor public class GeaflowUser extends GeaflowName { - private String password; + private String password; - private String phone; + private String phone; - private String email; + private String email; - public GeaflowUser(String name, String comment) { - super.name = name; - super.comment = comment; - } + public GeaflowUser(String name, String comment) { + super.name = name; + super.comment = comment; + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowUserGroup.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowUserGroup.java index ee4bae380..64b75c93b 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowUserGroup.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/GeaflowUserGroup.java @@ -21,14 +21,15 @@ import java.util.ArrayList; import java.util.List; + +import org.apache.geaflow.console.core.model.GeaflowName; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.core.model.GeaflowName; @Getter @Setter public class GeaflowUserGroup extends GeaflowName { - private final List users = new ArrayList<>(); - + private final List users = new ArrayList<>(); } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/AllResource.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/AllResource.java index 0a91ea732..1c8aeb511 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/AllResource.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/AllResource.java @@ -23,27 +23,27 @@ public final class AllResource extends GeaflowResource { - private static final String ALL_RESOURCE_ID = "*"; + private static final String ALL_RESOURCE_ID = "*"; - public AllResource(GeaflowResourceType type) { - super(type, null); - } - - @Override - public String getId() { - return ALL_RESOURCE_ID; - } + public AllResource(GeaflowResourceType type) { + super(type, null); + } - @Override - public boolean include(GeaflowResource other) { - if (other == null) { - return false; - } + @Override + public String getId() { + return ALL_RESOURCE_ID; + } - if (this.type.equals(other.type)) { - return true; - } + @Override + public boolean include(GeaflowResource other) { + if (other == null) { + return false; + } - return include(other.parent); + if (this.type.equals(other.type)) { + return true; } + + return include(other.parent); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/AtomResource.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/AtomResource.java index 872000de0..c4b4e9cdb 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/AtomResource.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/AtomResource.java @@ -19,31 +19,32 @@ package org.apache.geaflow.console.core.model.security.resource; +import org.apache.geaflow.console.common.util.type.GeaflowResourceType; + import com.google.common.base.Preconditions; + import lombok.Getter; -import org.apache.geaflow.console.common.util.type.GeaflowResourceType; public abstract class AtomResource extends GeaflowResource { - @Getter - protected final String id; + @Getter protected final String id; - public AtomResource(String id, GeaflowResourceType type, GeaflowResource parent) { - super(type, parent); - this.id = Preconditions.checkNotNull(id, "Invalid resource id"); - } - - @Override - public boolean include(GeaflowResource other) { - if (!(other instanceof AtomResource)) { - return false; - } + public AtomResource(String id, GeaflowResourceType type, GeaflowResource parent) { + super(type, parent); + this.id = Preconditions.checkNotNull(id, "Invalid resource id"); + } - AtomResource otherAtom = (AtomResource) other; - if (this.id.equals(otherAtom.id) && this.type.equals(otherAtom.type)) { - return true; - } + @Override + public boolean include(GeaflowResource other) { + if (!(other instanceof AtomResource)) { + return false; + } - return include(other.parent); + AtomResource otherAtom = (AtomResource) other; + if (this.id.equals(otherAtom.id) && this.type.equals(otherAtom.type)) { + return true; } + + return include(other.parent); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/EdgeResource.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/EdgeResource.java index 1535e2da5..926df0414 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/EdgeResource.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/EdgeResource.java @@ -23,7 +23,7 @@ public class EdgeResource extends AtomResource { - public EdgeResource(String tenantId, String instanceId, String edgeId) { - super(edgeId, GeaflowResourceType.EDGE, new InstanceResource(tenantId, instanceId)); - } + public EdgeResource(String tenantId, String instanceId, String edgeId) { + super(edgeId, GeaflowResourceType.EDGE, new InstanceResource(tenantId, instanceId)); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/FunctionResource.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/FunctionResource.java index f1784dad4..d3eacb15f 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/FunctionResource.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/FunctionResource.java @@ -23,7 +23,7 @@ public class FunctionResource extends AtomResource { - public FunctionResource(String tenantId, String instanceId, String functionId) { - super(functionId, GeaflowResourceType.FUNCTION, new InstanceResource(tenantId, instanceId)); - } + public FunctionResource(String tenantId, String instanceId, String functionId) { + super(functionId, GeaflowResourceType.FUNCTION, new InstanceResource(tenantId, instanceId)); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/GeaflowResource.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/GeaflowResource.java index b0deef1af..e09a6fda9 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/GeaflowResource.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/GeaflowResource.java @@ -19,23 +19,25 @@ package org.apache.geaflow.console.core.model.security.resource; +import org.apache.geaflow.console.common.util.type.GeaflowResourceType; + import com.google.common.base.Preconditions; + import lombok.Getter; -import org.apache.geaflow.console.common.util.type.GeaflowResourceType; @Getter public abstract class GeaflowResource { - protected final GeaflowResourceType type; + protected final GeaflowResourceType type; - protected final GeaflowResource parent; + protected final GeaflowResource parent; - protected GeaflowResource(GeaflowResourceType type, GeaflowResource parent) { - this.type = Preconditions.checkNotNull(type, "Invalid resource type"); - this.parent = parent; - } + protected GeaflowResource(GeaflowResourceType type, GeaflowResource parent) { + this.type = Preconditions.checkNotNull(type, "Invalid resource type"); + this.parent = parent; + } - public abstract String getId(); + public abstract String getId(); - public abstract boolean include(GeaflowResource other); + public abstract boolean include(GeaflowResource other); } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/GraphResource.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/GraphResource.java index e87c6a075..5450ed000 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/GraphResource.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/GraphResource.java @@ -23,7 +23,7 @@ public class GraphResource extends AtomResource { - public GraphResource(String tenantId, String instanceId, String graphId) { - super(graphId, GeaflowResourceType.GRAPH, new InstanceResource(tenantId, instanceId)); - } + public GraphResource(String tenantId, String instanceId, String graphId) { + super(graphId, GeaflowResourceType.GRAPH, new InstanceResource(tenantId, instanceId)); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/InstanceResource.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/InstanceResource.java index fd9ec3754..01b06d82b 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/InstanceResource.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/InstanceResource.java @@ -23,7 +23,7 @@ public class InstanceResource extends AtomResource { - public InstanceResource(String tenantId, String instanceId) { - super(instanceId, GeaflowResourceType.INSTANCE, new TenantResource(tenantId)); - } + public InstanceResource(String tenantId, String instanceId) { + super(instanceId, GeaflowResourceType.INSTANCE, new TenantResource(tenantId)); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/JobResource.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/JobResource.java index f490cddb7..6422a32de 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/JobResource.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/JobResource.java @@ -23,7 +23,7 @@ public class JobResource extends AtomResource { - public JobResource(String tenantId, String instanceId, String jobId) { - super(jobId, GeaflowResourceType.JOB, new InstanceResource(tenantId, instanceId)); - } + public JobResource(String tenantId, String instanceId, String jobId) { + super(jobId, GeaflowResourceType.JOB, new InstanceResource(tenantId, instanceId)); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/TableResource.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/TableResource.java index fa14a1e2a..b875a958d 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/TableResource.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/TableResource.java @@ -23,7 +23,7 @@ public class TableResource extends AtomResource { - public TableResource(String tenantId, String instanceId, String tableId) { - super(tableId, GeaflowResourceType.TABLE, new InstanceResource(tenantId, instanceId)); - } + public TableResource(String tenantId, String instanceId, String tableId) { + super(tableId, GeaflowResourceType.TABLE, new InstanceResource(tenantId, instanceId)); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/TaskResource.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/TaskResource.java index 3adc28ab9..06ce59a31 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/TaskResource.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/TaskResource.java @@ -23,7 +23,7 @@ public class TaskResource extends AtomResource { - public TaskResource(JobResource jobResource, String taskId) { - super(taskId, GeaflowResourceType.TASK, jobResource); - } + public TaskResource(JobResource jobResource, String taskId) { + super(taskId, GeaflowResourceType.TASK, jobResource); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/TenantResource.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/TenantResource.java index 5346f2660..98637c9e9 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/TenantResource.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/TenantResource.java @@ -23,7 +23,7 @@ public class TenantResource extends AtomResource { - public TenantResource(String id) { - super(id, GeaflowResourceType.TENANT, null); - } + public TenantResource(String id) { + super(id, GeaflowResourceType.TENANT, null); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/VertexResource.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/VertexResource.java index dc2966a88..16ddc234a 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/VertexResource.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/VertexResource.java @@ -23,7 +23,7 @@ public class VertexResource extends AtomResource { - public VertexResource(String tenantId, String instanceId, String vertexId) { - super(vertexId, GeaflowResourceType.VERTEX, new InstanceResource(tenantId, instanceId)); - } + public VertexResource(String tenantId, String instanceId, String vertexId) { + super(vertexId, GeaflowResourceType.VERTEX, new InstanceResource(tenantId, instanceId)); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/ViewResource.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/ViewResource.java index bbd21e82a..bbe06d664 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/ViewResource.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/security/resource/ViewResource.java @@ -23,7 +23,7 @@ public class ViewResource extends AtomResource { - public ViewResource(String tenantId, String instanceId, String viewId) { - super(viewId, GeaflowResourceType.VIEW, new InstanceResource(tenantId, instanceId)); - } + public ViewResource(String tenantId, String instanceId, String viewId) { + super(viewId, GeaflowResourceType.VIEW, new InstanceResource(tenantId, instanceId)); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/statement/GeaflowStatement.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/statement/GeaflowStatement.java index 204d34d86..1abaf237f 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/statement/GeaflowStatement.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/statement/GeaflowStatement.java @@ -19,30 +19,32 @@ package org.apache.geaflow.console.core.model.statement; +import org.apache.geaflow.console.common.util.type.GeaflowStatementStatus; +import org.apache.geaflow.console.core.model.GeaflowId; + import com.google.common.base.Preconditions; + import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowStatementStatus; -import org.apache.geaflow.console.core.model.GeaflowId; @Setter @Getter @NoArgsConstructor public class GeaflowStatement extends GeaflowId { - private String script; + private String script; - private GeaflowStatementStatus status; + private GeaflowStatementStatus status; - private String result; + private String result; - private String jobId; + private String jobId; - @Override - public void validate() { - super.validate(); - Preconditions.checkNotNull(jobId, "JobId is null"); - Preconditions.checkNotNull(script, "Query script is null"); - } + @Override + public void validate() { + super.validate(); + Preconditions.checkNotNull(jobId, "JobId is null"); + Preconditions.checkNotNull(script, "Query script is null"); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/ContainerTaskHandle.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/ContainerTaskHandle.java index 85cbb1022..e7b5ae404 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/ContainerTaskHandle.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/ContainerTaskHandle.java @@ -19,19 +19,20 @@ package org.apache.geaflow.console.core.model.task; +import org.apache.geaflow.console.common.util.type.GeaflowPluginType; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowPluginType; @Setter @Getter public class ContainerTaskHandle extends GeaflowTaskHandle { - // local process pid - private int pid; + // local process pid + private int pid; - public ContainerTaskHandle(String appId, int pid) { - super(GeaflowPluginType.CONTAINER, appId); - this.pid = pid; - } + public ContainerTaskHandle(String appId, int pid) { + super(GeaflowPluginType.CONTAINER, appId); + this.pid = pid; + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/GeaflowHeartbeatInfo.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/GeaflowHeartbeatInfo.java index 7aba4e020..1e7f0dc6e 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/GeaflowHeartbeatInfo.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/GeaflowHeartbeatInfo.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.core.model.task; import java.util.List; + import lombok.Getter; import lombok.Setter; @@ -27,67 +28,66 @@ @Setter public class GeaflowHeartbeatInfo { - private Integer activeNum; + private Integer activeNum; - private Integer totalNum; + private Integer totalNum; - private Long expiredTimeMs; + private Long expiredTimeMs; - private List containers; + private List containers; - @Getter - @Setter - public static class ContainerInfo { + @Getter + @Setter + public static class ContainerInfo { - private Integer id; + private Integer id; - private String name; + private String name; - private String host; + private String host; - private int pid; + private int pid; - private Long lastTimestamp; + private Long lastTimestamp; - private boolean active; + private boolean active; - private ProcessMetric metrics; + private ProcessMetric metrics; - @Getter - @Setter - public static class ProcessMetric { + @Getter + @Setter + public static class ProcessMetric { - // This amount of memory is guaranteed for the Java virtual machine to use. - private long heapCommittedMB; - private long heapUsedMB; - private double heapUsedRatio; - private long totalMemoryMB; + // This amount of memory is guaranteed for the Java virtual machine to use. + private long heapCommittedMB; + private long heapUsedMB; + private double heapUsedRatio; + private long totalMemoryMB; - // the total number of full collections that have occurred. - private long fgcCount = 0L; + // the total number of full collections that have occurred. + private long fgcCount = 0L; - // the total cost of full collections that have occurred. - private long fgcTime = 0L; + // the total cost of full collections that have occurred. + private long fgcTime = 0L; - // the approximate accumulated collection elapsed time in milliseconds. - private long gcTime = 0L; + // the approximate accumulated collection elapsed time in milliseconds. + private long gcTime = 0L; - // the total number of collections that have occurred. - private long gcCount = 0; + // the total number of collections that have occurred. + private long gcCount = 0; - // The system load average for the last minute, or a negative value if not available. - private double avgLoad; + // The system load average for the last minute, or a negative value if not available. + private double avgLoad; - // the number of processors available to the Java virtual machine. - private int availCores; + // the number of processors available to the Java virtual machine. + private int availCores; - // cpu usage. - private double processCpu; + // cpu usage. + private double processCpu; - // the number of processors used. - private double usedCores; - private int activeThreads; - } + // the number of processors used. + private double usedCores; + private int activeThreads; } - + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/GeaflowTask.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/GeaflowTask.java index 93c5090ab..bf7d46ae9 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/GeaflowTask.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/GeaflowTask.java @@ -19,15 +19,13 @@ package org.apache.geaflow.console.core.model.task; -import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.Comparator; import java.util.Date; import java.util.List; import java.util.TreeSet; import java.util.stream.Collectors; -import lombok.Getter; -import lombok.Setter; + import org.apache.commons.collections.CollectionUtils; import org.apache.geaflow.console.common.util.Fmt; import org.apache.geaflow.console.common.util.context.GeaflowContext; @@ -44,141 +42,158 @@ import org.apache.geaflow.console.core.model.task.schedule.GeaflowSchedule; import org.apache.geaflow.console.core.model.version.GeaflowVersion; +import com.google.common.base.Preconditions; + +import lombok.Getter; +import lombok.Setter; + @Getter @Setter public class GeaflowTask extends GeaflowId { - public static final String CODE_TASK_MAIN_CLASS = "org.apache.geaflow.dsl.runtime.engine.GeaFlowGqlClient"; + public static final String CODE_TASK_MAIN_CLASS = + "org.apache.geaflow.dsl.runtime.engine.GeaFlowGqlClient"; - private GeaflowRelease release; + private GeaflowRelease release; - private GeaflowTaskType type; + private GeaflowTaskType type; - private GeaflowSchedule schedule; + private GeaflowSchedule schedule; - private GeaflowTaskStatus status; + private GeaflowTaskStatus status; - private String token; + private String token; - private Date startTime; + private Date startTime; - private Date endTime; + private Date endTime; - private GeaflowTaskHandle handle; + private GeaflowTaskHandle handle; - private String host; + private String host; - private GeaflowPluginConfig runtimeMetaPluginConfig; + private GeaflowPluginConfig runtimeMetaPluginConfig; - private GeaflowPluginConfig haMetaPluginConfig; + private GeaflowPluginConfig haMetaPluginConfig; - private GeaflowPluginConfig metricPluginConfig; + private GeaflowPluginConfig metricPluginConfig; - private GeaflowPluginConfig dataPluginConfig; + private GeaflowPluginConfig dataPluginConfig; - public static String getTaskFileUrlFormatter(String gatewayUrl, String path) { - return Fmt.as("{}{}/tasks/%s/files?path={}", gatewayUrl, GeaflowContext.API_PREFIX, path); - } + public static String getTaskFileUrlFormatter(String gatewayUrl, String path) { + return Fmt.as("{}{}/tasks/%s/files?path={}", gatewayUrl, GeaflowContext.API_PREFIX, path); + } - @Override - public void validate() { - super.validate(); - Preconditions.checkNotNull(runtimeMetaPluginConfig, "Invalid runtimeMetaPluginConfig"); - Preconditions.checkNotNull(haMetaPluginConfig, "Invalid haMetaPluginConfig"); - Preconditions.checkNotNull(metricPluginConfig, "Invalid metricPluginConfig"); - Preconditions.checkNotNull(dataPluginConfig, "Invalid dataPluginConfig"); - Preconditions.checkNotNull(release, "Invalid release"); - } - - public String getMainClass() { - GeaflowJob job = release.getJob(); - switch (type) { - case CODE: - return CODE_TASK_MAIN_CLASS; - case API: - return job.getEntryClass(); - default: - throw new GeaflowIllegalException("Task type {} not supported", type); - } - } + @Override + public void validate() { + super.validate(); + Preconditions.checkNotNull(runtimeMetaPluginConfig, "Invalid runtimeMetaPluginConfig"); + Preconditions.checkNotNull(haMetaPluginConfig, "Invalid haMetaPluginConfig"); + Preconditions.checkNotNull(metricPluginConfig, "Invalid metricPluginConfig"); + Preconditions.checkNotNull(dataPluginConfig, "Invalid dataPluginConfig"); + Preconditions.checkNotNull(release, "Invalid release"); + } - public List getVersionFiles(String gatewayUrl) { - List files = new ArrayList<>(); - getVersionJars().forEach(jar -> files.add(new TaskFile(jar.getUrl(), jar.getMd5()))); - return rewriteTaskFileUrl(files, gatewayUrl); + public String getMainClass() { + GeaflowJob job = release.getJob(); + switch (type) { + case CODE: + return CODE_TASK_MAIN_CLASS; + case API: + return job.getEntryClass(); + default: + throw new GeaflowIllegalException("Task type {} not supported", type); } - - public List getUserFiles(String gatewayUrl) { - List files = new ArrayList<>(); - files.add(new TaskFile(release.getUrl(), release.getMd5())); - getUserJars().forEach(jar -> files.add(new TaskFile(jar.getUrl(), jar.getMd5()))); - - return rewriteTaskFileUrl(files, gatewayUrl); + } + + public List getVersionFiles(String gatewayUrl) { + List files = new ArrayList<>(); + getVersionJars().forEach(jar -> files.add(new TaskFile(jar.getUrl(), jar.getMd5()))); + return rewriteTaskFileUrl(files, gatewayUrl); + } + + public List getUserFiles(String gatewayUrl) { + List files = new ArrayList<>(); + files.add(new TaskFile(release.getUrl(), release.getMd5())); + getUserJars().forEach(jar -> files.add(new TaskFile(jar.getUrl(), jar.getMd5()))); + + return rewriteTaskFileUrl(files, gatewayUrl); + } + + public List getVersionJars() { + List jars = new ArrayList<>(); + + GeaflowVersion version = release.getVersion(); + GeaflowRemoteFile engineJarPackage = version.getEngineJarPackage(); + GeaflowRemoteFile langJarPackage = version.getLangJarPackage(); + Preconditions.checkNotNull( + engineJarPackage, "Invalid engine jar of version %s", version.getName()); + + jars.add(engineJarPackage); + if (langJarPackage != null) { + jars.add(langJarPackage); } - public List getVersionJars() { - List jars = new ArrayList<>(); + return jars; + } - GeaflowVersion version = release.getVersion(); - GeaflowRemoteFile engineJarPackage = version.getEngineJarPackage(); - GeaflowRemoteFile langJarPackage = version.getLangJarPackage(); - Preconditions.checkNotNull(engineJarPackage, "Invalid engine jar of version %s", version.getName()); + public List getUserJars() { + List jars = new ArrayList<>(); - jars.add(engineJarPackage); - if (langJarPackage != null) { - jars.add(langJarPackage); - } - - return jars; - } - - public List getUserJars() { - List jars = new ArrayList<>(); - - GeaflowJob job = release.getJob(); - GeaflowRemoteFile jarPackage = job.getJarPackage(); - List functions = job.getFunctions(); - List plugins = job.getPlugins(); - if (jarPackage != null) { - jars.add(jarPackage); - } - - if (CollectionUtils.isNotEmpty(functions)) { - functions.forEach(f -> { - GeaflowRemoteFile functionJarPackage = f.getJarPackage(); - Preconditions.checkNotNull(functionJarPackage, "Invalid jar of function %s", f.getName()); - jars.add(functionJarPackage); - }); - } - - if (CollectionUtils.isNotEmpty(plugins)) { - plugins.forEach(plugin -> { - GeaflowRemoteFile pluginJarPackage = plugin.getJarPackage(); - Preconditions.checkNotNull(pluginJarPackage, "Invalid jar of plugin %s", plugin.getName()); - jars.add(pluginJarPackage); - }); - } - - List newList = jars.stream().collect(Collectors.collectingAndThen( - Collectors.toCollection(() -> new TreeSet<>(Comparator.comparing(GeaflowRemoteFile::getMd5))), ArrayList::new)); - - return newList; + GeaflowJob job = release.getJob(); + GeaflowRemoteFile jarPackage = job.getJarPackage(); + List functions = job.getFunctions(); + List plugins = job.getPlugins(); + if (jarPackage != null) { + jars.add(jarPackage); } - public String getStartupNotifyUrl(String gatewayUrl) { - return String.format("%s%s/tasks/%s/startup-notify", gatewayUrl, GeaflowContext.API_PREFIX, id); + if (CollectionUtils.isNotEmpty(functions)) { + functions.forEach( + f -> { + GeaflowRemoteFile functionJarPackage = f.getJarPackage(); + Preconditions.checkNotNull( + functionJarPackage, "Invalid jar of function %s", f.getName()); + jars.add(functionJarPackage); + }); } - public String getTaskFileUrl(String gatewayUrl, String path) { - return String.format(getTaskFileUrlFormatter(gatewayUrl, path), id); + if (CollectionUtils.isNotEmpty(plugins)) { + plugins.forEach( + plugin -> { + GeaflowRemoteFile pluginJarPackage = plugin.getJarPackage(); + Preconditions.checkNotNull( + pluginJarPackage, "Invalid jar of plugin %s", plugin.getName()); + jars.add(pluginJarPackage); + }); } - private List rewriteTaskFileUrl(List files, String gatewayUrl) { - files.forEach(f -> { - if (f.getUrl().startsWith(gatewayUrl)) { - f.setUrl(String.format(f.getUrl(), id)); - } + List newList = + jars.stream() + .collect( + Collectors.collectingAndThen( + Collectors.toCollection( + () -> new TreeSet<>(Comparator.comparing(GeaflowRemoteFile::getMd5))), + ArrayList::new)); + + return newList; + } + + public String getStartupNotifyUrl(String gatewayUrl) { + return String.format("%s%s/tasks/%s/startup-notify", gatewayUrl, GeaflowContext.API_PREFIX, id); + } + + public String getTaskFileUrl(String gatewayUrl, String path) { + return String.format(getTaskFileUrlFormatter(gatewayUrl, path), id); + } + + private List rewriteTaskFileUrl(List files, String gatewayUrl) { + files.forEach( + f -> { + if (f.getUrl().startsWith(gatewayUrl)) { + f.setUrl(String.format(f.getUrl(), id)); + } }); - return files; - } + return files; + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/GeaflowTaskHandle.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/GeaflowTaskHandle.java index e8c3c7938..d85e24158 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/GeaflowTaskHandle.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/GeaflowTaskHandle.java @@ -19,47 +19,47 @@ package org.apache.geaflow.console.core.model.task; +import org.apache.commons.lang3.StringUtils; +import org.apache.geaflow.console.common.util.exception.GeaflowException; +import org.apache.geaflow.console.common.util.type.GeaflowPluginType; + import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; + import lombok.Getter; import lombok.Setter; -import org.apache.commons.lang3.StringUtils; -import org.apache.geaflow.console.common.util.exception.GeaflowException; -import org.apache.geaflow.console.common.util.type.GeaflowPluginType; @Setter @Getter public abstract class GeaflowTaskHandle { - public GeaflowTaskHandle(GeaflowPluginType clusterType, String appId) { - this.clusterType = clusterType; - this.appId = appId; - } - - // cluster type - protected GeaflowPluginType clusterType; - - // app id - protected String appId; + public GeaflowTaskHandle(GeaflowPluginType clusterType, String appId) { + this.clusterType = clusterType; + this.appId = appId; + } + // cluster type + protected GeaflowPluginType clusterType; - public static GeaflowTaskHandle parse(String text) { - if (StringUtils.isEmpty(text)) { - return null; - } + // app id + protected String appId; - JSONObject json = JSON.parseObject(text); - GeaflowPluginType clusterType = GeaflowPluginType.of(json.get("clusterType").toString()); - switch (clusterType) { - case CONTAINER: - return json.toJavaObject(ContainerTaskHandle.class); - case K8S: - return json.toJavaObject(K8sTaskHandle.class); - case RAY: - return json.toJavaObject(RayTaskHandle.class); - default: - throw new GeaflowException("Unsupported cluster type {}", clusterType); - } + public static GeaflowTaskHandle parse(String text) { + if (StringUtils.isEmpty(text)) { + return null; } + JSONObject json = JSON.parseObject(text); + GeaflowPluginType clusterType = GeaflowPluginType.of(json.get("clusterType").toString()); + switch (clusterType) { + case CONTAINER: + return json.toJavaObject(ContainerTaskHandle.class); + case K8S: + return json.toJavaObject(K8sTaskHandle.class); + case RAY: + return json.toJavaObject(RayTaskHandle.class); + default: + throw new GeaflowException("Unsupported cluster type {}", clusterType); + } + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/K8sTaskHandle.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/K8sTaskHandle.java index da989d419..c76fbdb71 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/K8sTaskHandle.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/K8sTaskHandle.java @@ -19,32 +19,30 @@ package org.apache.geaflow.console.core.model.task; +import org.apache.geaflow.console.common.util.type.GeaflowPluginType; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowPluginType; @Setter @Getter public class K8sTaskHandle extends GeaflowTaskHandle { + // startup notify + protected StartupNotifyInfo startupNotifyInfo; - // startup notify - protected StartupNotifyInfo startupNotifyInfo; - - @Setter - @Getter - public static class StartupNotifyInfo { - - private String masterAddress; - - private String driverAddress; + @Setter + @Getter + public static class StartupNotifyInfo { - private String clientAddress; + private String masterAddress; - } + private String driverAddress; + private String clientAddress; + } - public K8sTaskHandle(String appId) { - super(GeaflowPluginType.K8S, appId); - } + public K8sTaskHandle(String appId) { + super(GeaflowPluginType.K8S, appId); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/RayTaskHandle.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/RayTaskHandle.java index d857a03e0..c720998a1 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/RayTaskHandle.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/RayTaskHandle.java @@ -19,16 +19,16 @@ package org.apache.geaflow.console.core.model.task; +import org.apache.geaflow.console.common.util.type.GeaflowPluginType; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.common.util.type.GeaflowPluginType; @Setter @Getter public class RayTaskHandle extends GeaflowTaskHandle { - public RayTaskHandle(String submissionId) { - super(GeaflowPluginType.RAY, submissionId); - - } + public RayTaskHandle(String submissionId) { + super(GeaflowPluginType.RAY, submissionId); + } } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/TaskFile.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/TaskFile.java index ea4881605..b9d2dfecc 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/TaskFile.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/TaskFile.java @@ -30,8 +30,7 @@ @AllArgsConstructor public class TaskFile { - private String url; - - private String md5; + private String url; + private String md5; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/schedule/GeaflowApprovalInfo.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/schedule/GeaflowApprovalInfo.java index 2eaa47142..a84c1aca0 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/schedule/GeaflowApprovalInfo.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/schedule/GeaflowApprovalInfo.java @@ -19,16 +19,16 @@ package org.apache.geaflow.console.core.model.task.schedule; +import org.apache.geaflow.console.core.model.GeaflowId; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.core.model.GeaflowId; @Getter @Setter public class GeaflowApprovalInfo extends GeaflowId { - private String approvalId; - - private boolean success; + private String approvalId; + private boolean success; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/schedule/GeaflowGrayInfo.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/schedule/GeaflowGrayInfo.java index 6ce522417..1afdbab27 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/schedule/GeaflowGrayInfo.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/schedule/GeaflowGrayInfo.java @@ -20,16 +20,18 @@ package org.apache.geaflow.console.core.model.task.schedule; import java.util.Date; + +import org.apache.geaflow.console.core.model.GeaflowId; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.core.model.GeaflowId; @Getter @Setter public class GeaflowGrayInfo extends GeaflowId { - boolean success; - private long duration; - private Date startTime; - private Date endTime; + boolean success; + private long duration; + private Date startTime; + private Date endTime; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/schedule/GeaflowSchedule.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/schedule/GeaflowSchedule.java index b21a24ee7..431094f3f 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/schedule/GeaflowSchedule.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/task/schedule/GeaflowSchedule.java @@ -19,19 +19,20 @@ package org.apache.geaflow.console.core.model.task.schedule; +import org.apache.geaflow.console.core.model.GeaflowId; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.core.model.GeaflowId; @Getter @Setter public class GeaflowSchedule extends GeaflowId { - private boolean failover; + private boolean failover; - private GeaflowApprovalInfo approvalInfo; + private GeaflowApprovalInfo approvalInfo; - private GeaflowGrayInfo grayInfo; + private GeaflowGrayInfo grayInfo; - private String cronExpr; + private String cronExpr; } diff --git a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/version/GeaflowVersion.java b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/version/GeaflowVersion.java index ee90f7931..68f9d9c68 100644 --- a/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/version/GeaflowVersion.java +++ b/geaflow-console/app/core/model/src/main/java/org/apache/geaflow/console/core/model/version/GeaflowVersion.java @@ -19,39 +19,40 @@ package org.apache.geaflow.console.core.model.version; +import org.apache.geaflow.console.core.model.GeaflowName; +import org.apache.geaflow.console.core.model.file.GeaflowRemoteFile; + import com.google.common.base.Preconditions; + import lombok.Getter; import lombok.Setter; -import org.apache.geaflow.console.core.model.GeaflowName; -import org.apache.geaflow.console.core.model.file.GeaflowRemoteFile; @Getter @Setter public class GeaflowVersion extends GeaflowName { - private GeaflowRemoteFile engineJarPackage; - - private GeaflowRemoteFile langJarPackage; + private GeaflowRemoteFile engineJarPackage; - private boolean publish; + private GeaflowRemoteFile langJarPackage; - public static boolean md5Equals(GeaflowVersion left, GeaflowVersion right) { - if (left == null && right == null) { - return true; - } + private boolean publish; - if (left != null && right != null) { - return GeaflowRemoteFile.md5Equals(left.engineJarPackage, right.engineJarPackage) - && GeaflowRemoteFile.md5Equals(left.langJarPackage, right.langJarPackage); - } - - return false; + public static boolean md5Equals(GeaflowVersion left, GeaflowVersion right) { + if (left == null && right == null) { + return true; } - @Override - public void validate() { - super.validate(); - Preconditions.checkNotNull(engineJarPackage, "Invalid engineJarPackage"); + if (left != null && right != null) { + return GeaflowRemoteFile.md5Equals(left.engineJarPackage, right.engineJarPackage) + && GeaflowRemoteFile.md5Equals(left.langJarPackage, right.langJarPackage); } + return false; + } + + @Override + public void validate() { + super.validate(); + Preconditions.checkNotNull(engineJarPackage, "Invalid engineJarPackage"); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/AuditService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/AuditService.java index 6e428c60c..7de62a6f7 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/AuditService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/AuditService.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.core.service; import java.util.List; + import org.apache.geaflow.console.common.dal.dao.AuditDao; import org.apache.geaflow.console.common.dal.dao.IdDao; import org.apache.geaflow.console.common.dal.entity.AuditEntity; @@ -34,25 +35,22 @@ @Service public class AuditService extends IdService { - @Autowired - private AuditDao auditDao; + @Autowired private AuditDao auditDao; - @Autowired - private AuditConverter auditConverter; + @Autowired private AuditConverter auditConverter; - @Override - protected IdDao getDao() { - return auditDao; - } + @Override + protected IdDao getDao() { + return auditDao; + } - @Override - protected IdConverter getConverter() { - return auditConverter; - } + @Override + protected IdConverter getConverter() { + return auditConverter; + } - @Override - protected List parse(List auditEntities) { - return ListUtil.convert(auditEntities, e -> auditConverter.convert(e)); - } + @Override + protected List parse(List auditEntities) { + return ListUtil.convert(auditEntities, e -> auditConverter.convert(e)); + } } - diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/AuthenticationService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/AuthenticationService.java index 610be361e..f54c982a1 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/AuthenticationService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/AuthenticationService.java @@ -19,9 +19,9 @@ package org.apache.geaflow.console.core.service; -import com.google.common.base.Preconditions; import java.util.Date; import java.util.List; + import org.apache.geaflow.console.common.dal.dao.UserDao; import org.apache.geaflow.console.common.dal.entity.UserEntity; import org.apache.geaflow.console.common.util.Md5Util; @@ -37,191 +37,200 @@ import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; +import com.google.common.base.Preconditions; + @Service public class AuthenticationService { - private static final int TOKEN_EXPIRED_SECONDS = 24 * 60 * 60; + private static final int TOKEN_EXPIRED_SECONDS = 24 * 60 * 60; - @Autowired - private UserDao userDao; + @Autowired private UserDao userDao; - @Autowired - private TokenGenerator tokenGenerator; + @Autowired private TokenGenerator tokenGenerator; - @Autowired - private TenantService tenantService; + @Autowired private TenantService tenantService; - @Autowired - private UserService userService; + @Autowired private UserService userService; - @Autowired - private InstanceService instanceService; + @Autowired private InstanceService instanceService; - @Autowired - private AuthorizationService authorizationService; + @Autowired private AuthorizationService authorizationService; - @Transactional - public GeaflowAuthentication login(String loginName, String password, boolean systemLogin) { - Preconditions.checkNotNull(loginName, "Invalid loginName"); - Preconditions.checkNotNull(password, "Invalid password"); + @Transactional + public GeaflowAuthentication login(String loginName, String password, boolean systemLogin) { + Preconditions.checkNotNull(loginName, "Invalid loginName"); + Preconditions.checkNotNull(password, "Invalid password"); - // get login user - UserEntity entity = userDao.getByName(loginName); - Preconditions.checkNotNull(entity, "User not found"); + // get login user + UserEntity entity = userDao.getByName(loginName); + Preconditions.checkNotNull(entity, "User not found"); - try { - String userId = entity.getId(); + try { + String userId = entity.getId(); - // check password - if (!Md5Util.encodeString(password).equals(entity.getPasswordSign())) { - throw new GeaflowException("Invalid password"); - } + // check password + if (!Md5Util.encodeString(password).equals(entity.getPasswordSign())) { + throw new GeaflowException("Invalid password"); + } - // check system admin role - if (systemLogin && !authorizationService.existRole(null, userId, GeaflowRole.SYSTEM_ADMIN)) { - throw new GeaflowException("Not system admin"); - } + // check system admin role + if (systemLogin && !authorizationService.existRole(null, userId, GeaflowRole.SYSTEM_ADMIN)) { + throw new GeaflowException("Not system admin"); + } - // create or reuse authentication - GeaflowAuthentication authentication = convert(entity); - if (!authentication.isExpired(TOKEN_EXPIRED_SECONDS) && authentication.isSystemSession() == systemLogin) { - updateAccessTime(userId); + // create or reuse authentication + GeaflowAuthentication authentication = convert(entity); + if (!authentication.isExpired(TOKEN_EXPIRED_SECONDS) + && authentication.isSystemSession() == systemLogin) { + updateAccessTime(userId); - } else { - createAuthentication(userId, systemLogin); - } + } else { + createAuthentication(userId, systemLogin); + } - // init user id context when authentication success - ContextHolder.get().setUserId(userId); + // init user id context when authentication success + ContextHolder.get().setUserId(userId); - // init tenant and instance when first login - ensureTenantInited(userId); + // init tenant and instance when first login + ensureTenantInited(userId); - // get authentication - return getAuthenticationByUserId(userId); + // get authentication + return getAuthenticationByUserId(userId); - } catch (Exception e) { - throw new GeaflowSecurityException("User login failed", e); - } + } catch (Exception e) { + throw new GeaflowSecurityException("User login failed", e); } + } - public GeaflowAuthentication authenticate(String token) { - try { - GeaflowAuthentication authentication = getAuthenticationByToken(token); - if (authentication.isExpired(TOKEN_EXPIRED_SECONDS)) { - throw new GeaflowException("User token expired"); - } + public GeaflowAuthentication authenticate(String token) { + try { + GeaflowAuthentication authentication = getAuthenticationByToken(token); + if (authentication.isExpired(TOKEN_EXPIRED_SECONDS)) { + throw new GeaflowException("User token expired"); + } - String userId = authentication.getUserId(); - updateAccessTime(userId); + String userId = authentication.getUserId(); + updateAccessTime(userId); - return getAuthenticationByUserId(userId); + return getAuthenticationByUserId(userId); - } catch (Exception e) { - throw new GeaflowSecurityException("User authenticate failed", e); - } + } catch (Exception e) { + throw new GeaflowSecurityException("User authenticate failed", e); } + } - public boolean switchSession() { - String userId = ContextHolder.get().getUserId(); - boolean expectSystemSession = !ContextHolder.get().isSystemSession(); - - // check system admin role - if (expectSystemSession && !authorizationService.existRole(null, userId, GeaflowRole.SYSTEM_ADMIN)) { - throw new GeaflowException("Not system admin"); - } + public boolean switchSession() { + String userId = ContextHolder.get().getUserId(); + boolean expectSystemSession = !ContextHolder.get().isSystemSession(); - return updateSystemSession(userId, expectSystemSession); + // check system admin role + if (expectSystemSession + && !authorizationService.existRole(null, userId, GeaflowRole.SYSTEM_ADMIN)) { + throw new GeaflowException("Not system admin"); } - public boolean logout(String token) { - try { - GeaflowAuthentication authentication = getAuthenticationByToken(token); - if (authentication.isExpired(TOKEN_EXPIRED_SECONDS)) { - return true; - } - - return destroyAuthentication(authentication.getUserId()); - - } catch (Exception e) { - throw new GeaflowSecurityException("User logout failed", e); - } - } - - private void ensureTenantInited(String userId) { - if (ContextHolder.get().getTenantId() != null) { - return; - } - - List userTenants = tenantService.getUserTenants(userId); - if (!userTenants.isEmpty()) { - return; - } - - GeaflowUser user = userService.get(userId); + return updateSystemSession(userId, expectSystemSession); + } - // create tenant - String tenantId = tenantService.createDefaultTenant(user); + public boolean logout(String token) { + try { + GeaflowAuthentication authentication = getAuthenticationByToken(token); + if (authentication.isExpired(TOKEN_EXPIRED_SECONDS)) { + return true; + } - // add user to new tenant - userService.addTenantUser(tenantId, userId); + return destroyAuthentication(authentication.getUserId()); - // init user as tenant admin - authorizationService.addRole(tenantId, userId, GeaflowRole.TENANT_ADMIN); - - // activate user tenant - tenantService.activateTenant(tenantId, userId); - - // create instance - instanceService.createDefaultInstance(tenantId, user); - } - - public GeaflowAuthentication getAuthenticationByUserId(String userId) { - UserEntity entity = userDao.get(userId); - Preconditions.checkNotNull(entity, "Invalid token"); - return convert(entity); - } - - public GeaflowAuthentication getAuthenticationByToken(String token) { - UserEntity entity = userDao.getByToken(token); - Preconditions.checkNotNull(entity, "Invalid token %s", token); - return convert(entity); - } - - public boolean createAuthentication(String userId, boolean systemLogin) { - UserEntity entity = new UserEntity(); - entity.setId(userId); - entity.setSessionToken(tokenGenerator.nextToken()); - entity.setSystemSession(systemLogin); - entity.setAccessTime(new Date()); - return userDao.updateById(entity); + } catch (Exception e) { + throw new GeaflowSecurityException("User logout failed", e); } + } - public boolean updateAccessTime(String userId) { - return userDao.lambdaUpdate().set(UserEntity::getAccessTime, new Date()).eq(UserEntity::getId, userId).update(); + private void ensureTenantInited(String userId) { + if (ContextHolder.get().getTenantId() != null) { + return; } - public boolean updateSystemSession(String userId, boolean systemSession) { - return userDao.lambdaUpdate().set(UserEntity::isSystemSession, systemSession).eq(UserEntity::getId, userId) - .update(); + List userTenants = tenantService.getUserTenants(userId); + if (!userTenants.isEmpty()) { + return; } - public boolean destroyAuthentication(String userId) { - return userDao.lambdaUpdate().set(UserEntity::getSessionToken, null).set(UserEntity::isSystemSession, false) - .set(UserEntity::getAccessTime, null).eq(UserEntity::getId, userId).update(); + GeaflowUser user = userService.get(userId); + + // create tenant + String tenantId = tenantService.createDefaultTenant(user); + + // add user to new tenant + userService.addTenantUser(tenantId, userId); + + // init user as tenant admin + authorizationService.addRole(tenantId, userId, GeaflowRole.TENANT_ADMIN); + + // activate user tenant + tenantService.activateTenant(tenantId, userId); + + // create instance + instanceService.createDefaultInstance(tenantId, user); + } + + public GeaflowAuthentication getAuthenticationByUserId(String userId) { + UserEntity entity = userDao.get(userId); + Preconditions.checkNotNull(entity, "Invalid token"); + return convert(entity); + } + + public GeaflowAuthentication getAuthenticationByToken(String token) { + UserEntity entity = userDao.getByToken(token); + Preconditions.checkNotNull(entity, "Invalid token %s", token); + return convert(entity); + } + + public boolean createAuthentication(String userId, boolean systemLogin) { + UserEntity entity = new UserEntity(); + entity.setId(userId); + entity.setSessionToken(tokenGenerator.nextToken()); + entity.setSystemSession(systemLogin); + entity.setAccessTime(new Date()); + return userDao.updateById(entity); + } + + public boolean updateAccessTime(String userId) { + return userDao + .lambdaUpdate() + .set(UserEntity::getAccessTime, new Date()) + .eq(UserEntity::getId, userId) + .update(); + } + + public boolean updateSystemSession(String userId, boolean systemSession) { + return userDao + .lambdaUpdate() + .set(UserEntity::isSystemSession, systemSession) + .eq(UserEntity::getId, userId) + .update(); + } + + public boolean destroyAuthentication(String userId) { + return userDao + .lambdaUpdate() + .set(UserEntity::getSessionToken, null) + .set(UserEntity::isSystemSession, false) + .set(UserEntity::getAccessTime, null) + .eq(UserEntity::getId, userId) + .update(); + } + + private GeaflowAuthentication convert(UserEntity entity) { + if (entity == null) { + return null; } - private GeaflowAuthentication convert(UserEntity entity) { - if (entity == null) { - return null; - } - - GeaflowAuthentication authentication = new GeaflowAuthentication(); - authentication.setUserId(entity.getId()); - authentication.setSessionToken(entity.getSessionToken()); - authentication.setSystemSession(entity.isSystemSession()); - authentication.setAccessTime(entity.getAccessTime()); - return authentication; - } + GeaflowAuthentication authentication = new GeaflowAuthentication(); + authentication.setUserId(entity.getId()); + authentication.setSessionToken(entity.getSessionToken()); + authentication.setSystemSession(entity.isSystemSession()); + authentication.setAccessTime(entity.getAccessTime()); + return authentication; + } } - diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/AuthorizationService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/AuthorizationService.java index c0ba79ae9..e9cccf3cd 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/AuthorizationService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/AuthorizationService.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.console.common.dal.dao.AuthorizationDao; import org.apache.geaflow.console.common.dal.dao.IdDao; import org.apache.geaflow.console.common.dal.dao.UserRoleMappingDao; @@ -41,67 +42,70 @@ import org.springframework.stereotype.Service; @Service -public class AuthorizationService extends IdService { - - @Autowired - private AuthorizationDao authorizationDao; - - @Autowired - private AuthorizationConverter authorizationConverter; - - @Autowired - private UserRoleMappingDao userRoleMappingDao; - - @Override - protected IdDao getDao() { - return authorizationDao; - } - - @Override - protected IdConverter getConverter() { - return authorizationConverter; - } +public class AuthorizationService + extends IdService { - @Override - protected List parse(List grantEntities) { - return grantEntities.stream().map(e -> authorizationConverter.convert(e)).collect(Collectors.toList()); - } + @Autowired private AuthorizationDao authorizationDao; - public List getUserRoleTypes(String tenantId, String userId) { - return userRoleMappingDao.getRoleTypes(tenantId, userId); - } + @Autowired private AuthorizationConverter authorizationConverter; - public boolean existRole(String tenantId, String userId, GeaflowRole role) { - return userRoleMappingDao.existRoleType(tenantId, userId, role.getType()); - } + @Autowired private UserRoleMappingDao userRoleMappingDao; - public void addRole(String tenantId, String userId, GeaflowRole role) { - GeaflowRoleType roleType = role.getType(); - if (userRoleMappingDao.existRoleType(tenantId, userId, roleType)) { - throw new GeaflowIllegalException("User role {} exists", roleType); - } + @Override + protected IdDao getDao() { + return authorizationDao; + } - userRoleMappingDao.addRoleType(tenantId, userId, roleType); - } + @Override + protected IdConverter getConverter() { + return authorizationConverter; + } - public void deleteRole(String tenantId, String userId, GeaflowRole role) { - userRoleMappingDao.deleteRoleType(tenantId, userId, role.getType()); - } + @Override + protected List parse(List grantEntities) { + return grantEntities.stream() + .map(e -> authorizationConverter.convert(e)) + .collect(Collectors.toList()); + } - public boolean exist(String userId, GeaflowAuthority authority, GeaflowResource resource) { - return authorizationDao.exist(userId, authority.getType(), resource.getType(), resource.getId()); - } + public List getUserRoleTypes(String tenantId, String userId) { + return userRoleMappingDao.getRoleTypes(tenantId, userId); + } - public boolean dropByResources(List resourceIds, GeaflowResourceType type) { - return authorizationDao.dropByResources(resourceIds, type); - } + public boolean existRole(String tenantId, String userId, GeaflowRole role) { + return userRoleMappingDao.existRoleType(tenantId, userId, role.getType()); + } - public void addAuthorization(List resourceIds, String userId, GeaflowAuthorityType authorityType, - GeaflowResourceType resourceType) { - List authorizations = ListUtil.convert(resourceIds, - id -> new GeaflowAuthorization(userId, authorityType, resourceType, id)); - create(authorizations); + public void addRole(String tenantId, String userId, GeaflowRole role) { + GeaflowRoleType roleType = role.getType(); + if (userRoleMappingDao.existRoleType(tenantId, userId, roleType)) { + throw new GeaflowIllegalException("User role {} exists", roleType); } + userRoleMappingDao.addRoleType(tenantId, userId, roleType); + } + + public void deleteRole(String tenantId, String userId, GeaflowRole role) { + userRoleMappingDao.deleteRoleType(tenantId, userId, role.getType()); + } + + public boolean exist(String userId, GeaflowAuthority authority, GeaflowResource resource) { + return authorizationDao.exist( + userId, authority.getType(), resource.getType(), resource.getId()); + } + + public boolean dropByResources(List resourceIds, GeaflowResourceType type) { + return authorizationDao.dropByResources(resourceIds, type); + } + + public void addAuthorization( + List resourceIds, + String userId, + GeaflowAuthorityType authorityType, + GeaflowResourceType resourceType) { + List authorizations = + ListUtil.convert( + resourceIds, id -> new GeaflowAuthorization(userId, authorityType, resourceType, id)); + create(authorizations); + } } - diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/ChatService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/ChatService.java index c57480df3..fbaac55b9 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/ChatService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/ChatService.java @@ -19,16 +19,12 @@ package org.apache.geaflow.console.core.service; - -import com.google.common.base.Preconditions; -import com.google.common.cache.Cache; -import com.google.common.cache.CacheBuilder; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; -import lombok.extern.slf4j.Slf4j; + import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.exception.ExceptionUtils; import org.apache.geaflow.console.common.dal.dao.ChatDao; @@ -50,141 +46,141 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; +import com.google.common.base.Preconditions; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; + +import lombok.extern.slf4j.Slf4j; + @Service @Slf4j public class ChatService extends IdService { - private static final ExecutorService EXECUTOR_SERVICE = new ThreadPoolExecutor(5, 5, - 10, TimeUnit.SECONDS, new LinkedBlockingQueue<>(100)); + private static final ExecutorService EXECUTOR_SERVICE = + new ThreadPoolExecutor(5, 5, 10, TimeUnit.SECONDS, new LinkedBlockingQueue<>(100)); - @Autowired - private ChatDao chatDao; + @Autowired private ChatDao chatDao; - @Autowired - private ChatConverter chatConverter; + @Autowired private ChatConverter chatConverter; - @Autowired - private LLMService llmService; + @Autowired private LLMService llmService; - @Autowired - private LLMClientFactory llmClientFactory; + @Autowired private LLMClientFactory llmClientFactory; - @Autowired - private JobService jobService; + @Autowired private JobService jobService; - private static final String HINT_STATEMENT = "Use the schema shown below: "; + private static final String HINT_STATEMENT = "Use the schema shown below: "; - private final Cache schemaCache = CacheBuilder.newBuilder() - .maximumSize(50) - .expireAfterWrite(180, TimeUnit.SECONDS) - .build(); + private final Cache schemaCache = + CacheBuilder.newBuilder().maximumSize(50).expireAfterWrite(180, TimeUnit.SECONDS).build(); - @Override - protected List parse(List entities) { - return ListUtil.convert(entities, chatConverter::convert); - } + @Override + protected List parse(List entities) { + return ListUtil.convert(entities, chatConverter::convert); + } - @Override - protected IdDao getDao() { - return chatDao; - } + @Override + protected IdDao getDao() { + return chatDao; + } - @Override - protected IdConverter getConverter() { - return chatConverter; - } + @Override + protected IdConverter getConverter() { + return chatConverter; + } - public void callASync(GeaflowChat chat, boolean withSchema) { + public void callASync(GeaflowChat chat, boolean withSchema) { - final String sessionToken = ContextHolder.get().getSessionToken(); - EXECUTOR_SERVICE.submit(() -> { - try { - ContextHolder.init(); - ContextHolder.get().setSessionToken(sessionToken); - callSync(chat, true, withSchema); + final String sessionToken = ContextHolder.get().getSessionToken(); + EXECUTOR_SERVICE.submit( + () -> { + try { + ContextHolder.init(); + ContextHolder.get().setSessionToken(sessionToken); + callSync(chat, true, withSchema); - } finally { - ContextHolder.destroy(); - } + } finally { + ContextHolder.destroy(); + } }); - - } - - - public String callSync(GeaflowChat chat, boolean saveRecord, boolean withSchema) { - String answer = null; - GeaflowStatementStatus status = null; - try { - GeaflowLLM geaflowLLM = llmService.get(chat.getModelId()); - Preconditions.checkNotNull(geaflowLLM, "Model %s not found", chat.getModelId()); - - LLMClient chatClient = llmClientFactory.getLLMClient(geaflowLLM.getType()); - - String prompt = getPromptWithSchema(chat, withSchema); - answer = chatClient.call(geaflowLLM, prompt); - status = GeaflowStatementStatus.FINISHED; - return answer; - - } catch (Exception e) { - answer = ExceptionUtils.getStackTrace(e); - status = GeaflowStatementStatus.FAILED; - return answer; - - } finally { - // whether save the query record. - if (saveRecord) { - chat.setAnswer(answer); - chat.setStatus(status); - if (chat.getId() == null) { - create(chat); - } else { - update(chat); - } - } - + } + + public String callSync(GeaflowChat chat, boolean saveRecord, boolean withSchema) { + String answer = null; + GeaflowStatementStatus status = null; + try { + GeaflowLLM geaflowLLM = llmService.get(chat.getModelId()); + Preconditions.checkNotNull(geaflowLLM, "Model %s not found", chat.getModelId()); + + LLMClient chatClient = llmClientFactory.getLLMClient(geaflowLLM.getType()); + + String prompt = getPromptWithSchema(chat, withSchema); + answer = chatClient.call(geaflowLLM, prompt); + status = GeaflowStatementStatus.FINISHED; + return answer; + + } catch (Exception e) { + answer = ExceptionUtils.getStackTrace(e); + status = GeaflowStatementStatus.FAILED; + return answer; + + } finally { + // whether save the query record. + if (saveRecord) { + chat.setAnswer(answer); + chat.setStatus(status); + if (chat.getId() == null) { + create(chat); + } else { + update(chat); } + } } + } + + private String getPromptWithSchema(GeaflowChat chat, boolean withSchema) { + String userPrompt = chat.getPrompt(); + if (withSchema && chat.getJobId() != null) { + try { + String schemaScript = schemaCache.getIfPresent(chat.getJobId()); + if (schemaScript == null) { + GeaflowJob job = jobService.get(chat.getJobId()); + if (job != null && CollectionUtils.isNotEmpty(job.getGraphs())) { + GeaflowGraph graph = job.getGraphs().get(0); + schemaScript = GraphSchemaTranslator.translateGraphSchema(graph); + schemaCache.put(chat.getJobId(), schemaScript); + } + } - private String getPromptWithSchema(GeaflowChat chat, boolean withSchema) { - String userPrompt = chat.getPrompt(); - if (withSchema && chat.getJobId() != null) { - try { - String schemaScript = schemaCache.getIfPresent(chat.getJobId()); - if (schemaScript == null) { - GeaflowJob job = jobService.get(chat.getJobId()); - if (job != null && CollectionUtils.isNotEmpty(job.getGraphs())) { - GeaflowGraph graph = job.getGraphs().get(0); - schemaScript = GraphSchemaTranslator.translateGraphSchema(graph); - schemaCache.put(chat.getJobId(), schemaScript); - } - } - - if (schemaScript == null) { - return userPrompt; - } - - // add graphSchema prompt - StringBuilder sb = new StringBuilder(); - return sb.append(HINT_STATEMENT) - .append(schemaScript) - .append("\n") - .append(userPrompt) - .toString() - .replace("\n", "") - .replaceAll("\\s+", " "); - - } catch (Exception e) { - // return userPrompt if getting schema failed. - log.info("Get SchemaScript failed, jobId: {}, modelId: {}", chat.getJobId(), chat.getModelId(), e); - return userPrompt; - } + if (schemaScript == null) { + return userPrompt; } + // add graphSchema prompt + StringBuilder sb = new StringBuilder(); + return sb.append(HINT_STATEMENT) + .append(schemaScript) + .append("\n") + .append(userPrompt) + .toString() + .replace("\n", "") + .replaceAll("\\s+", " "); + + } catch (Exception e) { + // return userPrompt if getting schema failed. + log.info( + "Get SchemaScript failed, jobId: {}, modelId: {}", + chat.getJobId(), + chat.getModelId(), + e); return userPrompt; + } } - public boolean dropByJobId(String jobId) { - return chatDao.dropByJobId(jobId); - } -} + return userPrompt; + } + public boolean dropByJobId(String jobId) { + return chatDao.dropByJobId(jobId); + } +} diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/ClusterService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/ClusterService.java index 99c8d29b1..83c4b95b8 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/ClusterService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/ClusterService.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.core.service; import java.util.List; + import org.apache.geaflow.console.common.dal.dao.ClusterDao; import org.apache.geaflow.console.common.dal.dao.NameDao; import org.apache.geaflow.console.common.dal.entity.ClusterEntity; @@ -34,29 +35,26 @@ @Service public class ClusterService extends NameService { - @Autowired - private ClusterDao clusterDao; + @Autowired private ClusterDao clusterDao; - @Autowired - private ClusterConverter clusterConverter; + @Autowired private ClusterConverter clusterConverter; - @Override - protected NameDao getDao() { - return clusterDao; - } + @Override + protected NameDao getDao() { + return clusterDao; + } - @Override - protected NameConverter getConverter() { - return clusterConverter; - } + @Override + protected NameConverter getConverter() { + return clusterConverter; + } - @Override - protected List parse(List clusterEntities) { - return ListUtil.convert(clusterEntities, clusterConverter::convert); - } + @Override + protected List parse(List clusterEntities) { + return ListUtil.convert(clusterEntities, clusterConverter::convert); + } - public GeaflowCluster getDefaultCluster() { - return parse(clusterDao.getDefaultCluster()); - } + public GeaflowCluster getDefaultCluster() { + return parse(clusterDao.getDefaultCluster()); + } } - diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/DataService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/DataService.java index 4c147d917..477c448d3 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/DataService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/DataService.java @@ -26,6 +26,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.stream.Collectors; + import org.apache.commons.collections.CollectionUtils; import org.apache.geaflow.console.common.dal.dao.DataDao; import org.apache.geaflow.console.common.dal.entity.DataEntity; @@ -39,112 +40,108 @@ import org.apache.geaflow.console.core.service.converter.DataConverter; import org.springframework.beans.factory.annotation.Autowired; -public abstract class DataService extends - NameService { - - @Autowired - private InstanceService instanceService; +public abstract class DataService + extends NameService { - @Override - protected abstract DataDao getDao(); + @Autowired private InstanceService instanceService; - @Override - protected abstract DataConverter getConverter(); + @Override + protected abstract DataDao getDao(); - @Override - public List getByNames(List names) { - throw new GeaflowException("Use getByNames(instanceId, names) instead"); - } + @Override + protected abstract DataConverter getConverter(); - @Override - public boolean dropByNames(List names) { - throw new GeaflowException("Use dropByNames(instanceId, names) instead"); - } + @Override + public List getByNames(List names) { + throw new GeaflowException("Use getByNames(instanceId, names) instead"); + } - @Override - public Map getIdsByNames(List names) { - throw new GeaflowException("Use getIdsByNames(instanceId, names) instead"); - } + @Override + public boolean dropByNames(List names) { + throw new GeaflowException("Use dropByNames(instanceId, names) instead"); + } + @Override + public Map getIdsByNames(List names) { + throw new GeaflowException("Use getIdsByNames(instanceId, names) instead"); + } - public M getByName(String instanceId, String name) { - if (name == null) { - return null; - } - List users = getByNames(instanceId, Collections.singletonList(name)); - return users.isEmpty() ? null : users.get(0); + public M getByName(String instanceId, String name) { + if (name == null) { + return null; } + List users = getByNames(instanceId, Collections.singletonList(name)); + return users.isEmpty() ? null : users.get(0); + } - public boolean dropByName(String instanceId, String name) { - if (name == null) { - return false; - } - return dropByNames(instanceId, Collections.singletonList(name)); + public boolean dropByName(String instanceId, String name) { + if (name == null) { + return false; } + return dropByNames(instanceId, Collections.singletonList(name)); + } - - public String getIdByName(String instanceId, String name) { - if (name == null) { - return null; - } - Map idsByNames = getIdsByNames(instanceId, Collections.singletonList(name)); - return idsByNames.get(name); + public String getIdByName(String instanceId, String name) { + if (name == null) { + return null; } - - - public List getByNames(String instanceId, List names) { - List entityList = getDao().getByNames(instanceId, names); - if (CollectionUtils.isEmpty(entityList)) { - return new ArrayList<>(); - } - return parse(entityList); + Map idsByNames = getIdsByNames(instanceId, Collections.singletonList(name)); + return idsByNames.get(name); + } + + public List getByNames(String instanceId, List names) { + List entityList = getDao().getByNames(instanceId, names); + if (CollectionUtils.isEmpty(entityList)) { + return new ArrayList<>(); } - - public boolean dropByNames(String instanceId, List names) { - Map idsByNames = getIdsByNames(instanceId, names); - return this.drop(new ArrayList<>(idsByNames.values())); - } - - - public Map getIdsByNames(String instanceId, List names) { - return getDao().getIdsByNames(instanceId, names); - } - - - public List getByInstanceId(String instanceId) { - List entities = getDao().getByInstanceId(instanceId); - return parse(entities); - } - - @Override - public List create(List models) { - // check duplicate names in the current instance - checkInstanceUniqueName(models); - return super.create(models); - } - - protected void checkInstanceUniqueName(List models) { - Map> map = models.stream().collect(Collectors.groupingBy(GeaflowData::getInstanceId)); - - for (Entry> entry : map.entrySet()) { - String instanceId = entry.getKey(); - List names = ListUtil.convert(entry.getValue(), GeaflowName::getName); - HashSet nameSet = new HashSet<>(names); - // check duplicate names in set - if (nameSet.size() != names.size()) { - throw new GeaflowException("Duplicated name found in {}", names); - } - List resourceCounts = instanceService.getResourceCount(instanceId, names); - List filtered = resourceCounts.stream().filter(e -> e.getCount() > 0) - .collect(Collectors.toList()); - // check duplicate names in databases - if (!filtered.isEmpty()) { - StringBuilder sb = new StringBuilder(); - for (ResourceCount resourceCount : filtered) { - sb.append(Fmt.as("type:{}, name: {};", resourceCount.getType(), resourceCount.getName())); - } - throw new GeaflowException("Name conflict in current instance {}: {}", instanceId, sb.toString()); - } + return parse(entityList); + } + + public boolean dropByNames(String instanceId, List names) { + Map idsByNames = getIdsByNames(instanceId, names); + return this.drop(new ArrayList<>(idsByNames.values())); + } + + public Map getIdsByNames(String instanceId, List names) { + return getDao().getIdsByNames(instanceId, names); + } + + public List getByInstanceId(String instanceId) { + List entities = getDao().getByInstanceId(instanceId); + return parse(entities); + } + + @Override + public List create(List models) { + // check duplicate names in the current instance + checkInstanceUniqueName(models); + return super.create(models); + } + + protected void checkInstanceUniqueName(List models) { + Map> map = + models.stream().collect(Collectors.groupingBy(GeaflowData::getInstanceId)); + + for (Entry> entry : map.entrySet()) { + String instanceId = entry.getKey(); + List names = ListUtil.convert(entry.getValue(), GeaflowName::getName); + HashSet nameSet = new HashSet<>(names); + // check duplicate names in set + if (nameSet.size() != names.size()) { + throw new GeaflowException("Duplicated name found in {}", names); + } + List resourceCounts = instanceService.getResourceCount(instanceId, names); + List filtered = + resourceCounts.stream().filter(e -> e.getCount() > 0).collect(Collectors.toList()); + // check duplicate names in databases + if (!filtered.isEmpty()) { + StringBuilder sb = new StringBuilder(); + for (ResourceCount resourceCount : filtered) { + sb.append(Fmt.as("type:{}, name: {};", resourceCount.getType(), resourceCount.getName())); } + throw new GeaflowException( + "Name conflict in current instance {}: {}", instanceId, sb.toString()); + } } + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/DatasourceService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/DatasourceService.java index ad0053ad8..45882be63 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/DatasourceService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/DatasourceService.java @@ -19,13 +19,12 @@ package org.apache.geaflow.console.core.service; -import com.google.common.base.Preconditions; import java.sql.Connection; import java.sql.ResultSet; import java.sql.Statement; + import javax.sql.DataSource; -import lombok.Getter; -import lombok.extern.slf4j.Slf4j; + import org.apache.geaflow.console.common.util.exception.GeaflowException; import org.apache.geaflow.console.core.model.plugin.config.JdbcPluginConfigClass; import org.springframework.beans.factory.InitializingBean; @@ -34,54 +33,57 @@ import org.springframework.jdbc.datasource.init.ScriptUtils; import org.springframework.stereotype.Service; +import com.google.common.base.Preconditions; + +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; + @Slf4j @Service public class DatasourceService implements InitializingBean { - @Autowired - private DataSource dataSource; + @Autowired private DataSource dataSource; - @Getter - private boolean initialized; + @Getter private boolean initialized; - @Override - public void afterPropertiesSet() throws Exception { - try (Connection connection = dataSource.getConnection()) { - try (Statement statement = connection.createStatement()) { - try (ResultSet resultSet = statement.executeQuery("SELECT database()")) { - if (resultSet.next()) { - String databaseName = resultSet.getString(1); - Preconditions.checkNotNull(databaseName, "No database selected in url"); - } - } + @Override + public void afterPropertiesSet() throws Exception { + try (Connection connection = dataSource.getConnection()) { + try (Statement statement = connection.createStatement()) { + try (ResultSet resultSet = statement.executeQuery("SELECT database()")) { + if (resultSet.next()) { + String databaseName = resultSet.getString(1); + Preconditions.checkNotNull(databaseName, "No database selected in url"); + } + } - // check database inited - try (ResultSet resultSet = statement.executeQuery("SHOW TABLES LIKE 'geaflow_%'")) { - if (resultSet.next()) { - initialized = true; - } - } - } + // check database inited + try (ResultSet resultSet = statement.executeQuery("SHOW TABLES LIKE 'geaflow_%'")) { + if (resultSet.next()) { + initialized = true; + } + } + } - if (!initialized) { - synchronized (DatasourceService.class) { - ScriptUtils.executeSqlScript(connection, new ClassPathResource("datasource.init.sql")); - initialized = true; - } - } + if (!initialized) { + synchronized (DatasourceService.class) { + ScriptUtils.executeSqlScript(connection, new ClassPathResource("datasource.init.sql")); + initialized = true; } + } } + } - public void executeResource(JdbcPluginConfigClass jdbcConfig, String resource) { - String url = jdbcConfig.getUrl(); + public void executeResource(JdbcPluginConfigClass jdbcConfig, String resource) { + String url = jdbcConfig.getUrl(); - try { - try (Connection connection = jdbcConfig.createConnection()) { - ScriptUtils.executeSqlScript(connection, new ClassPathResource(resource)); - } + try { + try (Connection connection = jdbcConfig.createConnection()) { + ScriptUtils.executeSqlScript(connection, new ClassPathResource(resource)); + } - } catch (Exception e) { - throw new GeaflowException("Execute {} on {} failed", resource, url, e); - } + } catch (Exception e) { + throw new GeaflowException("Execute {} on {} failed", resource, url, e); } + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/EdgeService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/EdgeService.java index 1d2bbd98f..17f840777 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/EdgeService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/EdgeService.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; + import org.apache.geaflow.console.common.dal.dao.DataDao; import org.apache.geaflow.console.common.dal.dao.EdgeDao; import org.apache.geaflow.console.common.dal.entity.EdgeEntity; @@ -41,77 +42,77 @@ @Service public class EdgeService extends DataService { - private final GeaflowResourceType resourceType = GeaflowResourceType.EDGE; - @Autowired - private EdgeDao edgeDao; + private final GeaflowResourceType resourceType = GeaflowResourceType.EDGE; + @Autowired private EdgeDao edgeDao; - @Autowired - private FieldService fieldService; + @Autowired private FieldService fieldService; - @Autowired - private GraphService graphService; + @Autowired private GraphService graphService; - @Autowired - private EdgeConverter edgeConverter; + @Autowired private EdgeConverter edgeConverter; - @Override - protected DataDao getDao() { - return edgeDao; - } + @Override + protected DataDao getDao() { + return edgeDao; + } - @Override - protected DataConverter getConverter() { - return edgeConverter; - } + @Override + protected DataConverter getConverter() { + return edgeConverter; + } - @Override - public List create(List models) { - List edgeIds = super.create(models); - // save fields - for (GeaflowEdge model : models) { - fieldService.createByResource(new ArrayList<>(model.getFields().values()), model.getId(), resourceType); - } - return edgeIds; + @Override + public List create(List models) { + List edgeIds = super.create(models); + // save fields + for (GeaflowEdge model : models) { + fieldService.createByResource( + new ArrayList<>(model.getFields().values()), model.getId(), resourceType); } - - @Override - protected List parse(List edgeEntities) { - List edgeIds = ListUtil.convert(edgeEntities, IdEntity::getId); - // select fields - Map> fieldsMap = fieldService.getByResources(edgeIds, GeaflowResourceType.EDGE); - - return edgeEntities.stream().map(e -> { - List fields = fieldsMap.get(e.getId()); - return edgeConverter.convert(e, fields); - }).collect(Collectors.toList()); + return edgeIds; + } + + @Override + protected List parse(List edgeEntities) { + List edgeIds = ListUtil.convert(edgeEntities, IdEntity::getId); + // select fields + Map> fieldsMap = + fieldService.getByResources(edgeIds, GeaflowResourceType.EDGE); + + return edgeEntities.stream() + .map( + e -> { + List fields = fieldsMap.get(e.getId()); + return edgeConverter.convert(e, fields); + }) + .collect(Collectors.toList()); + } + + @Override + public boolean update(List edges) { + List ids = ListUtil.convert(edges, GeaflowId::getId); + + fieldService.removeByResources(ids, resourceType); + for (GeaflowEdge edge : edges) { + List newFields = new ArrayList<>(edge.getFields().values()); + fieldService.createByResource(newFields, edge.getId(), resourceType); } - - @Override - public boolean update(List edges) { - List ids = ListUtil.convert(edges, GeaflowId::getId); - - fieldService.removeByResources(ids, resourceType); - for (GeaflowEdge edge : edges) { - List newFields = new ArrayList<>(edge.getFields().values()); - fieldService.createByResource(newFields, edge.getId(), resourceType); - } - return super.update(edges); + return super.update(edges); + } + + @Override + public boolean drop(List ids) { + // can't drop if is used in graph. + for (String id : ids) { + graphService.checkBindingRelations(id, GeaflowResourceType.EDGE); } - @Override - public boolean drop(List ids) { - // can't drop if is used in graph. - for (String id : ids) { - graphService.checkBindingRelations(id, GeaflowResourceType.EDGE); - } - - fieldService.removeByResources(ids, resourceType); - return super.drop(ids); - } - - public List getEdgesByGraphId(String graphId) { - List edges = edgeDao.getByGraphId(graphId); - return parse(edges); - } + fieldService.removeByResources(ids, resourceType); + return super.drop(ids); + } + public List getEdgesByGraphId(String graphId) { + List edges = edgeDao.getByGraphId(graphId); + return parse(edges); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/FieldService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/FieldService.java index 70c6e4c90..2736e87a8 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/FieldService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/FieldService.java @@ -25,6 +25,7 @@ import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; + import org.apache.commons.collections.CollectionUtils; import org.apache.geaflow.console.common.dal.dao.FieldDao; import org.apache.geaflow.console.common.dal.dao.NameDao; @@ -44,91 +45,97 @@ @Service public class FieldService extends NameService { - @Autowired - private FieldDao fieldDao; + @Autowired private FieldDao fieldDao; - @Autowired - private FieldConverter fieldConverter; + @Autowired private FieldConverter fieldConverter; - @Override - protected NameDao getDao() { - return fieldDao; - } + @Override + protected NameDao getDao() { + return fieldDao; + } - @Override - protected NameConverter getConverter() { - return fieldConverter; - } + @Override + protected NameConverter getConverter() { + return fieldConverter; + } - @Override - protected List parse(List entities) { - return ListUtil.convert(entities, e -> fieldConverter.convert(e)); - } + @Override + protected List parse(List entities) { + return ListUtil.convert(entities, e -> fieldConverter.convert(e)); + } - @Override - public List create(List models) { - throw new UnsupportedOperationException("Field can only be saved by vertex/edge/table"); + @Override + public List create(List models) { + throw new UnsupportedOperationException("Field can only be saved by vertex/edge/table"); + } + + public List createByResource( + List models, String resourceId, GeaflowResourceType resourceType) { + if (CollectionUtils.isEmpty(models)) { + return new ArrayList<>(); } - public List createByResource(List models, String resourceId, - GeaflowResourceType resourceType) { - if (CollectionUtils.isEmpty(models)) { - return new ArrayList<>(); - } - - List entities = new ArrayList<>(); - for (int i = 0; i < models.size(); i++) { - entities.add(fieldConverter.convert(models.get(i), resourceId, resourceType, i)); - } - return fieldDao.create(entities); + List entities = new ArrayList<>(); + for (int i = 0; i < models.size(); i++) { + entities.add(fieldConverter.convert(models.get(i), resourceId, resourceType, i)); } + return fieldDao.create(entities); + } - public Map> getByResources(List resourceIds, GeaflowResourceType resourceType) { - if (CollectionUtils.isEmpty(resourceIds)) { - return new HashMap<>(); - } - List fieldEntityList = fieldDao.getByResources(resourceIds, resourceType); + public Map> getByResources( + List resourceIds, GeaflowResourceType resourceType) { + if (CollectionUtils.isEmpty(resourceIds)) { + return new HashMap<>(); + } + List fieldEntityList = fieldDao.getByResources(resourceIds, resourceType); - // init map to avoid null fields - Map> modelMap = new HashMap<>(); - resourceIds.forEach(id -> { - modelMap.put(id, new ArrayList<>()); + // init map to avoid null fields + Map> modelMap = new HashMap<>(); + resourceIds.forEach( + id -> { + modelMap.put(id, new ArrayList<>()); }); - // convert to map according to guid. - fieldEntityList.forEach(e -> { - GeaflowField field = fieldConverter.convert(e); - String id = e.getResourceId(); - modelMap.get(id).add(field); + // convert to map according to guid. + fieldEntityList.forEach( + e -> { + GeaflowField field = fieldConverter.convert(e); + String id = e.getResourceId(); + modelMap.get(id).add(field); }); - return modelMap; - } + return modelMap; + } + public void removeByResources(List resourceIds, GeaflowResourceType resourceType) { + fieldDao.removeByResources(resourceIds, resourceType); + } - public void removeByResources(List resourceIds, GeaflowResourceType resourceType) { - fieldDao.removeByResources(resourceIds, resourceType); - } - - public void updateResourceFields(List oldFields, List newFields, GeaflowStruct struct) { - - GeaflowResourceType resourceType = GeaflowResourceType.valueOf(struct.getType().name()); - // Save old fields - Map oldFieldMap = ListUtil.toMap(oldFields, GeaflowId::getId); - // add fields if id is null - List addFields = newFields.stream().filter(e -> e.getId() == null).collect(Collectors.toList()); - this.createByResource(addFields, struct.getId(), resourceType); + public void updateResourceFields( + List oldFields, List newFields, GeaflowStruct struct) { - // update if id is null and different from the old field - List updateFields = newFields.stream() - .filter(e -> e.getId() != null && !oldFieldMap.get(e.getId()).equals(e)).collect(Collectors.toList()); - this.update(updateFields); + GeaflowResourceType resourceType = GeaflowResourceType.valueOf(struct.getType().name()); + // Save old fields + Map oldFieldMap = ListUtil.toMap(oldFields, GeaflowId::getId); + // add fields if id is null + List addFields = + newFields.stream().filter(e -> e.getId() == null).collect(Collectors.toList()); + this.createByResource(addFields, struct.getId(), resourceType); - // get the remove fields by subset - List oldIds = ListUtil.convert(oldFields, GeaflowId::getId); - List newIds = newFields.stream().map(GeaflowName::getId).filter(Objects::nonNull) + // update if id is null and different from the old field + List updateFields = + newFields.stream() + .filter(e -> e.getId() != null && !oldFieldMap.get(e.getId()).equals(e)) .collect(Collectors.toList()); - List removeIds = ListUtil.diff(oldIds, newIds); - this.drop(removeIds); - } - + this.update(updateFields); + + // get the remove fields by subset + List oldIds = ListUtil.convert(oldFields, GeaflowId::getId); + List newIds = + newFields.stream() + .map(GeaflowName::getId) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + List removeIds = ListUtil.diff(oldIds, newIds); + this.drop(removeIds); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/FunctionService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/FunctionService.java index 261b49a51..884ba4407 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/FunctionService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/FunctionService.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; + import org.apache.geaflow.console.common.dal.dao.DataDao; import org.apache.geaflow.console.common.dal.dao.FunctionDao; import org.apache.geaflow.console.common.dal.entity.FunctionEntity; @@ -37,46 +38,45 @@ import org.springframework.stereotype.Service; @Service -public class FunctionService extends DataService implements FileRefService { - - @Autowired - private FunctionDao functionDao; +public class FunctionService extends DataService + implements FileRefService { - @Autowired - private RemoteFileService remoteFileService; + @Autowired private FunctionDao functionDao; - @Autowired - private FunctionConverter functionConverter; + @Autowired private RemoteFileService remoteFileService; - @Override - protected DataDao getDao() { - return functionDao; - } + @Autowired private FunctionConverter functionConverter; - @Override - protected DataConverter getConverter() { - return functionConverter; - } + @Override + protected DataDao getDao() { + return functionDao; + } - @Override - protected List parse(List functionEntities) { - // get jar packages - List packageIds = functionEntities.stream().map(FunctionEntity::getJarPackageId) - .collect(Collectors.toList()); - List jarPackages = remoteFileService.get(packageIds); + @Override + protected DataConverter getConverter() { + return functionConverter; + } - Map map = ListUtil.toMap(jarPackages, GeaflowId::getId); + @Override + protected List parse(List functionEntities) { + // get jar packages + List packageIds = + functionEntities.stream().map(FunctionEntity::getJarPackageId).collect(Collectors.toList()); + List jarPackages = remoteFileService.get(packageIds); - return functionEntities.stream().map(e -> { - GeaflowRemoteFile jarPackage = map.get(e.getJarPackageId()); - return functionConverter.convert(e, jarPackage); - }).collect(Collectors.toList()); - } + Map map = ListUtil.toMap(jarPackages, GeaflowId::getId); - @Override - public long getFileRefCount(String fileId, String excludeFunctionId) { - return functionDao.getFileRefCount(fileId, excludeFunctionId); - } + return functionEntities.stream() + .map( + e -> { + GeaflowRemoteFile jarPackage = map.get(e.getJarPackageId()); + return functionConverter.convert(e, jarPackage); + }) + .collect(Collectors.toList()); + } + @Override + public long getFileRefCount(String fileId, String excludeFunctionId) { + return functionDao.getFileRefCount(fileId, excludeFunctionId); + } } - diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/GraphService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/GraphService.java index 466cea3fc..34d103361 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/GraphService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/GraphService.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Set; import java.util.stream.Collectors; + import org.apache.commons.collections.CollectionUtils; import org.apache.geaflow.console.common.dal.dao.DataDao; import org.apache.geaflow.console.common.dal.dao.EndpointDao; @@ -54,194 +55,213 @@ @Service public class GraphService extends DataService { - @Autowired - private GraphDao graphDao; - - @Autowired - private VertexService vertexService; + @Autowired private GraphDao graphDao; - @Autowired - private EdgeService edgeService; + @Autowired private VertexService vertexService; - @Autowired - private GraphStructMappingDao graphStructMappingDao; + @Autowired private EdgeService edgeService; - @Autowired - private PluginConfigService pluginConfigService; + @Autowired private GraphStructMappingDao graphStructMappingDao; - @Autowired - private GraphConverter graphConverter; + @Autowired private PluginConfigService pluginConfigService; - @Autowired - private EndpointDao endpointDao; + @Autowired private GraphConverter graphConverter; - @Autowired - private DataStoreFactory dataStoreFactory; + @Autowired private EndpointDao endpointDao; - @Autowired - private PluginService pluginService; + @Autowired private DataStoreFactory dataStoreFactory; - @Override - protected DataDao getDao() { - return graphDao; - } + @Autowired private PluginService pluginService; - @Override - protected DataConverter getConverter() { - return graphConverter; - } + @Override + protected DataDao getDao() { + return graphDao; + } - @Override - public List create(List graphs) { - List ids = super.create(graphs); - for (GeaflowGraph g : graphs) { - List vertices = ListUtil.convert(g.getVertices().values(), GeaflowId::getId); - List edges = ListUtil.convert(g.getEdges().values(), GeaflowId::getId); - saveGraphStructs(g, vertices, edges); - } + @Override + protected DataConverter getConverter() { + return graphConverter; + } - return ids; + @Override + public List create(List graphs) { + List ids = super.create(graphs); + for (GeaflowGraph g : graphs) { + List vertices = ListUtil.convert(g.getVertices().values(), GeaflowId::getId); + List edges = ListUtil.convert(g.getEdges().values(), GeaflowId::getId); + saveGraphStructs(g, vertices, edges); } - @Override - protected List parse(List entities) { - return ListUtil.convert(entities, g -> { - // get Vertices and edges. - String id = g.getId(); - List vertices = vertexService.getVerticesByGraphId(id); - List edges = edgeService.getEdgesByGraphId(id); - GeaflowPluginConfig pluginConfig = pluginConfigService.get(g.getPluginConfigId()); - List endpoints = ListUtil.convert(endpointDao.getByGraphId(g.getId()), - e -> new GeaflowEndpoint(e.getEdgeId(), e.getSourceId(), e.getTargetId())); - return graphConverter.convert(g, vertices, edges, endpoints, pluginConfig); + return ids; + } + + @Override + protected List parse(List entities) { + return ListUtil.convert( + entities, + g -> { + // get Vertices and edges. + String id = g.getId(); + List vertices = vertexService.getVerticesByGraphId(id); + List edges = edgeService.getEdgesByGraphId(id); + GeaflowPluginConfig pluginConfig = pluginConfigService.get(g.getPluginConfigId()); + List endpoints = + ListUtil.convert( + endpointDao.getByGraphId(g.getId()), + e -> new GeaflowEndpoint(e.getEdgeId(), e.getSourceId(), e.getTargetId())); + return graphConverter.convert(g, vertices, edges, endpoints, pluginConfig); }); + } + + @Override + public boolean update(List models) { + List ids = ListUtil.convert(models, GeaflowId::getId); + graphStructMappingDao.removeByGraphIds(ids); + + // update vertices and edges + for (GeaflowGraph newGraph : models) { + List vertexIds = ListUtil.convert(newGraph.getVertices().values(), GeaflowId::getId); + List edgeIds = ListUtil.convert(newGraph.getEdges().values(), GeaflowId::getId); + saveGraphStructs(newGraph, vertexIds, edgeIds); + updateGraphEndpoints(newGraph, vertexIds, edgeIds); + GeaflowPluginConfig pluginConfig = newGraph.getPluginConfig(); + pluginConfigService.update(pluginConfig); } - @Override - public boolean update(List models) { - List ids = ListUtil.convert(models, GeaflowId::getId); - graphStructMappingDao.removeByGraphIds(ids); - - // update vertices and edges - for (GeaflowGraph newGraph : models) { - List vertexIds = ListUtil.convert(newGraph.getVertices().values(), GeaflowId::getId); - List edgeIds = ListUtil.convert(newGraph.getEdges().values(), GeaflowId::getId); - saveGraphStructs(newGraph, vertexIds, edgeIds); - updateGraphEndpoints(newGraph, vertexIds, edgeIds); - GeaflowPluginConfig pluginConfig = newGraph.getPluginConfig(); - pluginConfigService.update(pluginConfig); - } - - return super.update(models); - } - - private void updateGraphEndpoints(GeaflowGraph graph, List vertexIds, List edgeIds) { - List entities = endpointDao.getByGraphId(graph.getId()); - // drop endpoints whose vertex/edge not in new graph. - List dropIds = entities.stream().filter(e -> !edgeIds.contains(e.getEdgeId()) - || !vertexIds.contains(e.getSourceId()) - || !vertexIds.contains(e.getTargetId())).map(IdEntity::getId).collect(Collectors.toList()); - endpointDao.drop(dropIds); + return super.update(models); + } + + private void updateGraphEndpoints( + GeaflowGraph graph, List vertexIds, List edgeIds) { + List entities = endpointDao.getByGraphId(graph.getId()); + // drop endpoints whose vertex/edge not in new graph. + List dropIds = + entities.stream() + .filter( + e -> + !edgeIds.contains(e.getEdgeId()) + || !vertexIds.contains(e.getSourceId()) + || !vertexIds.contains(e.getTargetId())) + .map(IdEntity::getId) + .collect(Collectors.toList()); + endpointDao.drop(dropIds); + } + + @Override + public boolean drop(List ids) { + List entities = graphDao.get(ids); + + List graphs = ListUtil.convert(entities, this::parse); + clean(graphs); + + // do not delete edges and vertices + graphStructMappingDao.removeByGraphIds(ids); + endpointDao.dropByGraphIds(ids); + pluginConfigService.drop(ListUtil.convert(entities, GraphEntity::getPluginConfigId)); + return super.drop(ids); + } + + public boolean clean(List graphs) { + // clean graph data, do not delete graph + GeaflowPluginCategory category = GeaflowPluginCategory.DATA; + String dataType = pluginService.getDefaultPlugin(category).getType(); + GeaflowDataStore dataStore = dataStoreFactory.getDataStore(dataType); + + for (GeaflowGraph graph : graphs) { + dataStore.cleanGraphData(graph); } - @Override - public boolean drop(List ids) { - List entities = graphDao.get(ids); + return true; + } - List graphs = ListUtil.convert(entities, this::parse); - clean(graphs); + private void saveGraphStructs(GeaflowGraph g, List vertexIds, List edgeIds) { + List graphStructs = new ArrayList<>(); + String graphId = g.getId(); - // do not delete edges and vertices - graphStructMappingDao.removeByGraphIds(ids); - endpointDao.dropByGraphIds(ids); - pluginConfigService.drop(ListUtil.convert(entities, GraphEntity::getPluginConfigId)); - return super.drop(ids); + for (int i = 0; i < vertexIds.size(); i++) { + String id = vertexIds.get(i); + GraphStructMappingEntity entity = + new GraphStructMappingEntity(graphId, id, GeaflowResourceType.VERTEX, i); + graphStructs.add(entity); } - public boolean clean(List graphs) { - // clean graph data, do not delete graph - GeaflowPluginCategory category = GeaflowPluginCategory.DATA; - String dataType = pluginService.getDefaultPlugin(category).getType(); - GeaflowDataStore dataStore = dataStoreFactory.getDataStore(dataType); - - for (GeaflowGraph graph : graphs) { - dataStore.cleanGraphData(graph); - } - - return true; + for (int i = 0; i < edgeIds.size(); i++) { + String id = edgeIds.get(i); + GraphStructMappingEntity entity = + new GraphStructMappingEntity(graphId, id, GeaflowResourceType.EDGE, i); + graphStructs.add(entity); } - private void saveGraphStructs(GeaflowGraph g, List vertexIds, List edgeIds) { - List graphStructs = new ArrayList<>(); - String graphId = g.getId(); - - for (int i = 0; i < vertexIds.size(); i++) { - String id = vertexIds.get(i); - GraphStructMappingEntity entity = new GraphStructMappingEntity(graphId, id, GeaflowResourceType.VERTEX, i); - graphStructs.add(entity); - } - - for (int i = 0; i < edgeIds.size(); i++) { - String id = edgeIds.get(i); - GraphStructMappingEntity entity = new GraphStructMappingEntity(graphId, id, GeaflowResourceType.EDGE, i); - graphStructs.add(entity); - } - - if (!graphStructs.isEmpty()) { - graphStructMappingDao.create(graphStructs); - } + if (!graphStructs.isEmpty()) { + graphStructMappingDao.create(graphStructs); } - - public boolean createEndpoints(GeaflowGraph graph, List endpoints) { - // validate vertex/edge of endpoints exist in graph - validateEndpoints(graph, endpoints); - // do not insert if exist - List entities = endpoints.stream() - .filter(e -> !endpointDao.exists(graph.getId(), e.getEdgeId(), e.getSourceId(), e.getTargetId())) - .map(e -> new EndpointEntity(graph.getId(), e.getEdgeId(), e.getSourceId(), e.getTargetId())) + } + + public boolean createEndpoints(GeaflowGraph graph, List endpoints) { + // validate vertex/edge of endpoints exist in graph + validateEndpoints(graph, endpoints); + // do not insert if exist + List entities = + endpoints.stream() + .filter( + e -> + !endpointDao.exists( + graph.getId(), e.getEdgeId(), e.getSourceId(), e.getTargetId())) + .map( + e -> + new EndpointEntity( + graph.getId(), e.getEdgeId(), e.getSourceId(), e.getTargetId())) .distinct() .collect(Collectors.toList()); - endpointDao.create(entities); - return true; + endpointDao.create(entities); + return true; + } + + private void validateEndpoints(GeaflowGraph graph, List geaflowEndpoints) { + Set edgeIds = + graph.getEdges().values().stream().map(GeaflowId::getId).collect(Collectors.toSet()); + Set vertexIds = + graph.getVertices().values().stream().map(GeaflowId::getId).collect(Collectors.toSet()); + for (GeaflowEndpoint endpoint : geaflowEndpoints) { + if (!edgeIds.contains(endpoint.getEdgeId())) { + throw new GeaflowException( + "Edge {} not exits in graph", edgeService.getNameById(endpoint.getEdgeId())); + } + + if (!vertexIds.contains(endpoint.getSourceId())) { + throw new GeaflowException( + "Vertex {} not exits in graph", vertexService.getNameById(endpoint.getSourceId())); + } + + if (!vertexIds.contains(endpoint.getTargetId())) { + throw new GeaflowException( + "Vertex {} not exits in graph", vertexService.getNameById(endpoint.getTargetId())); + } } + } - private void validateEndpoints(GeaflowGraph graph, List geaflowEndpoints) { - Set edgeIds = graph.getEdges().values().stream().map(GeaflowId::getId).collect(Collectors.toSet()); - Set vertexIds = graph.getVertices().values().stream().map(GeaflowId::getId).collect(Collectors.toSet()); - for (GeaflowEndpoint endpoint : geaflowEndpoints) { - if (!edgeIds.contains(endpoint.getEdgeId())) { - throw new GeaflowException("Edge {} not exits in graph", edgeService.getNameById(endpoint.getEdgeId())); - } - - if (!vertexIds.contains(endpoint.getSourceId())) { - throw new GeaflowException("Vertex {} not exits in graph", vertexService.getNameById(endpoint.getSourceId())); - } - - if (!vertexIds.contains(endpoint.getTargetId())) { - throw new GeaflowException("Vertex {} not exits in graph", vertexService.getNameById(endpoint.getTargetId())); - } - } + public boolean deleteEndpoints(GeaflowGraph graph, List endpoints) { + // delete all endpoints if input is empty + if (CollectionUtils.isEmpty(endpoints)) { + return endpointDao.dropByGraphIds(Collections.singletonList(graph.getId())); } - - public boolean deleteEndpoints(GeaflowGraph graph, List endpoints) { - // delete all endpoints if input is empty - if (CollectionUtils.isEmpty(endpoints)) { - return endpointDao.dropByGraphIds(Collections.singletonList(graph.getId())); - } - for (GeaflowEndpoint e : endpoints) { - endpointDao.dropByEndpoint(graph.getId(), e.getEdgeId(), e.getSourceId(), e.getTargetId()); - } - return true; + for (GeaflowEndpoint e : endpoints) { + endpointDao.dropByEndpoint(graph.getId(), e.getEdgeId(), e.getSourceId(), e.getTargetId()); } - - public void checkBindingRelations(String resourceId, GeaflowResourceType resourceType) { - // check if vertex/edge used in graph - List entities = graphStructMappingDao.getByResourceId(resourceId, resourceType); - if (CollectionUtils.isNotEmpty(entities)) { - throw new GeaflowException("{} {} is used in Graph {}", resourceType, resourceId, - this.getNameById(entities.get(0).getGraphId())); - } + return true; + } + + public void checkBindingRelations(String resourceId, GeaflowResourceType resourceType) { + // check if vertex/edge used in graph + List entities = + graphStructMappingDao.getByResourceId(resourceId, resourceType); + if (CollectionUtils.isNotEmpty(entities)) { + throw new GeaflowException( + "{} {} is used in Graph {}", + resourceType, + resourceId, + this.getNameById(entities.get(0).getGraphId())); } - + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/IdService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/IdService.java index 8eec82cf7..b3867c26d 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/IdService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/IdService.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.console.common.dal.dao.IdDao; import org.apache.geaflow.console.common.dal.entity.IdEntity; import org.apache.geaflow.console.common.dal.model.IdSearch; @@ -31,92 +32,92 @@ public abstract class IdService { - protected abstract IdDao getDao(); - - protected abstract IdConverter getConverter(); + protected abstract IdDao getDao(); - private E build(M model) { - if (model == null) { - return null; - } + protected abstract IdConverter getConverter(); - return getConverter().convert(model); + private E build(M model) { + if (model == null) { + return null; } - private List build(List models) { - return ListUtil.convert(models, this::build); - } + return getConverter().convert(model); + } - protected final M parse(E entity) { - if (entity == null) { - return null; - } + private List build(List models) { + return ListUtil.convert(models, this::build); + } - List models = parse(Collections.singletonList(entity)); - return models.isEmpty() ? null : models.get(0); + protected final M parse(E entity) { + if (entity == null) { + return null; } - protected abstract List parse(List entities); + List models = parse(Collections.singletonList(entity)); + return models.isEmpty() ? null : models.get(0); + } - public boolean exist(String id) { - return getDao().exist(id); - } + protected abstract List parse(List entities); - public PageList search(S search) { - PageList pageList = getDao().search(search); - return pageList.transform(this::parse); - } + public boolean exist(String id) { + return getDao().exist(id); + } - public M get(String id) { - if (id == null) { - return null; - } - List list = get(Collections.singletonList(id)); - return list.isEmpty() ? null : list.get(0); - } + public PageList search(S search) { + PageList pageList = getDao().search(search); + return pageList.transform(this::parse); + } - public String create(M model) { - if (model == null) { - return null; - } - return create(Collections.singletonList(model)).get(0); + public M get(String id) { + if (id == null) { + return null; } + List list = get(Collections.singletonList(id)); + return list.isEmpty() ? null : list.get(0); + } - public boolean update(M model) { - if (model == null) { - return false; - } - return update(Collections.singletonList(model)); + public String create(M model) { + if (model == null) { + return null; } + return create(Collections.singletonList(model)).get(0); + } - public boolean drop(String id) { - if (id == null) { - return false; - } - return drop(Collections.singletonList(id)); + public boolean update(M model) { + if (model == null) { + return false; } + return update(Collections.singletonList(model)); + } - public List get(List ids) { - List entityList = getDao().get(ids); - return parse(entityList); + public boolean drop(String id) { + if (id == null) { + return false; } + return drop(Collections.singletonList(id)); + } - public List create(List models) { - List entities = build(models); - List ids = getDao().create(entities); + public List get(List ids) { + List entityList = getDao().get(ids); + return parse(entityList); + } - // fill id back to model - for (int i = 0; i < models.size(); i++) { - models.get(i).setId(ids.get(i)); - } - return ids; - } + public List create(List models) { + List entities = build(models); + List ids = getDao().create(entities); - public boolean update(List models) { - return getDao().update(build(models)); + // fill id back to model + for (int i = 0; i < models.size(); i++) { + models.get(i).setId(ids.get(i)); } + return ids; + } - public boolean drop(List ids) { - return getDao().drop(ids); - } + public boolean update(List models) { + return getDao().update(build(models)); + } + + public boolean drop(List ids) { + return getDao().drop(ids); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/InstanceService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/InstanceService.java index 5d431d0a8..aff9ff42f 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/InstanceService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/InstanceService.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.core.service; import java.util.List; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.dal.dao.AuthorizationDao; import org.apache.geaflow.console.common.dal.dao.InstanceDao; @@ -45,79 +46,79 @@ @Service public class InstanceService extends NameService { - @Autowired - private InstanceDao instanceDao; - - @Autowired - private AuthorizationDao authorizationDao; - - @Autowired - private InstanceConverter instanceConverter; - - @Autowired - private AuthorizationService authorizationService; - - @Override - protected NameDao getDao() { - return instanceDao; - } - - @Override - protected NameConverter getConverter() { - return instanceConverter; - } - - @Override - protected List parse(List entities) { - return ListUtil.convert(entities, e -> instanceConverter.convert(e)); - } - - @Override - public List create(List models) { - List ids = super.create(models); - authorizationService.addAuthorization(ids, ContextHolder.get().getUserId(), GeaflowAuthorityType.ALL, - GeaflowResourceType.INSTANCE); - return ids; - } - - @Override - public boolean drop(List ids) { - authorizationService.dropByResources(ids, GeaflowResourceType.INSTANCE); - return super.drop(ids); - } - - public List search() { - return parse(instanceDao.search()); - } - - public List getResourceCount(String instanceId, List names) { - return instanceDao.getResourceCount(instanceId, names); - } - - public String getDefaultInstanceName(String userName) { - return "instance_" + userName; - } - - @Transactional - public String createDefaultInstance(String tenantId, GeaflowUser user) { - String userName = user.getName(); - String userComment = user.getComment(); - String instanceName = getDefaultInstanceName(userName); - String userDisplayName = StringUtils.isBlank(userComment) ? userName : userComment; - String instanceComment = Fmt.as(I18nUtil.getMessage("i18n.key.default.instance.comment.format"), - userDisplayName); - - // Need to set tenantId, using dao directly - InstanceEntity entity = new InstanceEntity(); - entity.setTenantId(tenantId); - entity.setName(instanceName); - entity.setComment(instanceComment); - String instanceId = instanceDao.create(entity); - - AuthorizationEntity authorizationEntity = new AuthorizationEntity(user.getId(), GeaflowAuthorityType.ALL, - GeaflowResourceType.INSTANCE, instanceId); - authorizationEntity.setTenantId(tenantId); - authorizationDao.create(authorizationEntity); - return instanceId; - } + @Autowired private InstanceDao instanceDao; + + @Autowired private AuthorizationDao authorizationDao; + + @Autowired private InstanceConverter instanceConverter; + + @Autowired private AuthorizationService authorizationService; + + @Override + protected NameDao getDao() { + return instanceDao; + } + + @Override + protected NameConverter getConverter() { + return instanceConverter; + } + + @Override + protected List parse(List entities) { + return ListUtil.convert(entities, e -> instanceConverter.convert(e)); + } + + @Override + public List create(List models) { + List ids = super.create(models); + authorizationService.addAuthorization( + ids, + ContextHolder.get().getUserId(), + GeaflowAuthorityType.ALL, + GeaflowResourceType.INSTANCE); + return ids; + } + + @Override + public boolean drop(List ids) { + authorizationService.dropByResources(ids, GeaflowResourceType.INSTANCE); + return super.drop(ids); + } + + public List search() { + return parse(instanceDao.search()); + } + + public List getResourceCount(String instanceId, List names) { + return instanceDao.getResourceCount(instanceId, names); + } + + public String getDefaultInstanceName(String userName) { + return "instance_" + userName; + } + + @Transactional + public String createDefaultInstance(String tenantId, GeaflowUser user) { + String userName = user.getName(); + String userComment = user.getComment(); + String instanceName = getDefaultInstanceName(userName); + String userDisplayName = StringUtils.isBlank(userComment) ? userName : userComment; + String instanceComment = + Fmt.as(I18nUtil.getMessage("i18n.key.default.instance.comment.format"), userDisplayName); + + // Need to set tenantId, using dao directly + InstanceEntity entity = new InstanceEntity(); + entity.setTenantId(tenantId); + entity.setName(instanceName); + entity.setComment(instanceComment); + String instanceId = instanceDao.create(entity); + + AuthorizationEntity authorizationEntity = + new AuthorizationEntity( + user.getId(), GeaflowAuthorityType.ALL, GeaflowResourceType.INSTANCE, instanceId); + authorizationEntity.setTenantId(tenantId); + authorizationDao.create(authorizationEntity); + return instanceId; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/JobService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/JobService.java index 0cc3f09d9..c5483be00 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/JobService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/JobService.java @@ -28,8 +28,9 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; + import javax.annotation.PostConstruct; -import lombok.extern.slf4j.Slf4j; + import org.apache.geaflow.console.common.dal.dao.JobDao; import org.apache.geaflow.console.common.dal.dao.JobResourceMappingDao; import org.apache.geaflow.console.common.dal.dao.NameDao; @@ -67,368 +68,427 @@ import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; +import lombok.extern.slf4j.Slf4j; + @Service @Slf4j -public class JobService extends NameService implements FileRefService { +public class JobService extends NameService + implements FileRefService { - private final Map serviceMap = new HashMap<>(); - @Autowired - private JobDao jobDao; + private final Map serviceMap = new HashMap<>(); + @Autowired private JobDao jobDao; - @Autowired - private JobConverter jobConverter; + @Autowired private JobConverter jobConverter; - @Autowired - private JobResourceMappingDao jobResourceMappingDao; + @Autowired private JobResourceMappingDao jobResourceMappingDao; - @Autowired - private GraphService graphService; + @Autowired private GraphService graphService; - @Autowired - private TableService tableService; + @Autowired private TableService tableService; - @Autowired - private VertexService vertexService; + @Autowired private VertexService vertexService; - @Autowired - private EdgeService edgeService; + @Autowired private EdgeService edgeService; - @Autowired - private FunctionService functionService; + @Autowired private FunctionService functionService; - @Autowired - private PluginService pluginService; + @Autowired private PluginService pluginService; - @Autowired - private VersionService versionService; + @Autowired private VersionService versionService; - @Autowired - private ReleaseService releaseService; + @Autowired private ReleaseService releaseService; - @Autowired - private InstanceService instanceService; + @Autowired private InstanceService instanceService; - @Autowired - private AuthorizationService authorizationService; + @Autowired private AuthorizationService authorizationService; - @Autowired - private RemoteFileService remoteFileService; + @Autowired private RemoteFileService remoteFileService; + protected NameDao getDao() { + return jobDao; + } - protected NameDao getDao() { - return jobDao; - } + @Override + protected NameConverter getConverter() { + return jobConverter; + } - @Override - protected NameConverter getConverter() { - return jobConverter; - } + @Override + protected List parse(List jobEntities) { + return jobEntities.stream() + .map( + e -> { + List structs = getJobStructs(e.getId()); + List graphs = getJobGraphs(e.getId()); + List functions = getJobFunctions(e.getId()); + List plugins = getJobPlugins(e.getId()); + GeaflowRemoteFile remoteFile = remoteFileService.get(e.getJarPackageId()); + return jobConverter.convert(e, structs, graphs, functions, plugins, remoteFile); + }) + .collect(Collectors.toList()); + } - @Override - protected List parse(List jobEntities) { - return jobEntities.stream().map(e -> { - List structs = getJobStructs(e.getId()); - List graphs = getJobGraphs(e.getId()); - List functions = getJobFunctions(e.getId()); - List plugins = getJobPlugins(e.getId()); - GeaflowRemoteFile remoteFile = remoteFileService.get(e.getJarPackageId()); - return jobConverter.convert(e, structs, graphs, functions, plugins, remoteFile); - }).collect(Collectors.toList()); + @Override + @Transactional + public List create(List models) { + // compile processJob + GeaflowVersion version = versionService.getDefaultVersion(); + for (GeaflowJob job : models) { + parseUserCode(job, version); } - - - @Override - @Transactional - public List create(List models) { - // compile processJob - GeaflowVersion version = versionService.getDefaultVersion(); - for (GeaflowJob job : models) { - parseUserCode(job, version); - } - List ids = super.create(models); - // save resourceMappings - for (GeaflowJob job : models) { - createJobResources(job.getId(), job.getStructs(), job.getGraphs(), job.getFunctions(), job.getPlugins()); - } - // save authorizations - List authorizations = ListUtil.convert(ids, - id -> new GeaflowAuthorization(ContextHolder.get().getUserId(), GeaflowAuthorityType.ALL, GeaflowResourceType.JOB, id)); - authorizationService.create(authorizations); - return ids; - } - - @Override - public boolean update(List jobs) { - GeaflowVersion version = versionService.getDefaultVersion(); - for (GeaflowJob newJob : jobs) { - if (newJob.isApiJob()) { - continue; - } - - GeaflowJob oldJob = this.get(newJob.getId()); - if (newJob instanceof GeaflowTransferJob) { - updateRefStructs(oldJob, newJob); - continue; - } - - if ((oldJob.getUserCode() == null && newJob.getUserCode() == null) - || (oldJob.getUserCode() != null && oldJob.getUserCode() != null - && oldJob.getUserCode().getText().equals(newJob.getUserCode().getText()))) { - continue; - } - - parseUserCode(newJob, version); - updateRefStructs(oldJob, newJob); - } - - return super.update(jobs); + List ids = super.create(models); + // save resourceMappings + for (GeaflowJob job : models) { + createJobResources( + job.getId(), job.getStructs(), job.getGraphs(), job.getFunctions(), job.getPlugins()); } - - void updateRefStructs(GeaflowJob oldJob, GeaflowJob newJob) { - // calculate subset of graphs - List oldGraphs = oldJob.getGraphs(); - List newGraphs = newJob.getGraphs(); - List addGraphs = ListUtil.diff(newGraphs, oldGraphs, this::getResourceKey); - List removeGraphs = ListUtil.diff(oldGraphs, newGraphs, this::getResourceKey); - - // calculate subset of structs - List oldStructs = oldJob.getStructs(); - List newStructs = newJob.getStructs(); - List addStructs = ListUtil.diff(newStructs, oldStructs, this::getResourceKey); - List removeStructs = ListUtil.diff(oldStructs, newStructs, this::getResourceKey); - - // calculate subset of functions - List oldFunctions = oldJob.getFunctions(); - List newFunctions = newJob.getFunctions(); - List addFunctions = ListUtil.diff(newFunctions, oldFunctions, this::getResourceKey); - List removeFunctions = ListUtil.diff(oldFunctions, newFunctions, this::getResourceKey); - - // calculate subset of functions - List oldJobPlugins = oldJob.getPlugins(); - List newJobPlugins = newJob.getPlugins(); - List addPlugins = ListUtil.diff(newJobPlugins, oldJobPlugins, e -> e.getType() + e.getCategory()); - List removePlugins = ListUtil.diff(oldJobPlugins, newJobPlugins, e -> e.getType() + e.getCategory()); - - // add resources - createJobResources(newJob.getId(), addStructs, addGraphs, addFunctions, addPlugins); - removeJobResources(newJob.getId(), removeStructs, removeGraphs, removeFunctions, removePlugins); + // save authorizations + List authorizations = + ListUtil.convert( + ids, + id -> + new GeaflowAuthorization( + ContextHolder.get().getUserId(), + GeaflowAuthorityType.ALL, + GeaflowResourceType.JOB, + id)); + authorizationService.create(authorizations); + return ids; + } + + @Override + public boolean update(List jobs) { + GeaflowVersion version = versionService.getDefaultVersion(); + for (GeaflowJob newJob : jobs) { + if (newJob.isApiJob()) { + continue; + } + + GeaflowJob oldJob = this.get(newJob.getId()); + if (newJob instanceof GeaflowTransferJob) { + updateRefStructs(oldJob, newJob); + continue; + } + + if ((oldJob.getUserCode() == null && newJob.getUserCode() == null) + || (oldJob.getUserCode() != null + && oldJob.getUserCode() != null + && oldJob.getUserCode().getText().equals(newJob.getUserCode().getText()))) { + continue; + } + + parseUserCode(newJob, version); + updateRefStructs(oldJob, newJob); } - public void dropResources(List jobIds) { - jobResourceMappingDao.dropByJobIds(jobIds); - } - - @PostConstruct - public void init() { - serviceMap.put(GeaflowResourceType.VERTEX, vertexService); - serviceMap.put(GeaflowResourceType.EDGE, edgeService); - serviceMap.put(GeaflowResourceType.TABLE, tableService); - serviceMap.put(GeaflowResourceType.GRAPH, graphService); - serviceMap.put(GeaflowResourceType.FUNCTION, functionService); - } - - public DataService getResourceService(GeaflowResourceType resourceType) { - return serviceMap.get(resourceType); - } - - - public long getFileRefCount(String fileId, String excludeFunctionId) { - return jobDao.getFileRefCount(fileId, excludeFunctionId); - } - - public Map getJarIds(List jobIds) { - List jobEntities = jobDao.get(jobIds); - return jobEntities.stream().filter(e -> e.getJarPackageId() != null) - .collect(Collectors.toMap(IdEntity::getId, JobEntity::getJarPackageId)); - } - - private void createJobResources(String jobId, List structs, List graphs, - List functions, List plugins) { - List entities = new ArrayList<>(); - - structs.forEach(e -> { - JobResourceMappingEntity entity = new JobResourceMappingEntity(jobId, e.getName(), - GeaflowResourceType.valueOf(e.getType().name()), e.getInstanceId()); - entities.add(entity); + return super.update(jobs); + } + + void updateRefStructs(GeaflowJob oldJob, GeaflowJob newJob) { + // calculate subset of graphs + List oldGraphs = oldJob.getGraphs(); + List newGraphs = newJob.getGraphs(); + List addGraphs = ListUtil.diff(newGraphs, oldGraphs, this::getResourceKey); + List removeGraphs = ListUtil.diff(oldGraphs, newGraphs, this::getResourceKey); + + // calculate subset of structs + List oldStructs = oldJob.getStructs(); + List newStructs = newJob.getStructs(); + List addStructs = ListUtil.diff(newStructs, oldStructs, this::getResourceKey); + List removeStructs = ListUtil.diff(oldStructs, newStructs, this::getResourceKey); + + // calculate subset of functions + List oldFunctions = oldJob.getFunctions(); + List newFunctions = newJob.getFunctions(); + List addFunctions = + ListUtil.diff(newFunctions, oldFunctions, this::getResourceKey); + List removeFunctions = + ListUtil.diff(oldFunctions, newFunctions, this::getResourceKey); + + // calculate subset of functions + List oldJobPlugins = oldJob.getPlugins(); + List newJobPlugins = newJob.getPlugins(); + List addPlugins = + ListUtil.diff(newJobPlugins, oldJobPlugins, e -> e.getType() + e.getCategory()); + List removePlugins = + ListUtil.diff(oldJobPlugins, newJobPlugins, e -> e.getType() + e.getCategory()); + + // add resources + createJobResources(newJob.getId(), addStructs, addGraphs, addFunctions, addPlugins); + removeJobResources(newJob.getId(), removeStructs, removeGraphs, removeFunctions, removePlugins); + } + + public void dropResources(List jobIds) { + jobResourceMappingDao.dropByJobIds(jobIds); + } + + @PostConstruct + public void init() { + serviceMap.put(GeaflowResourceType.VERTEX, vertexService); + serviceMap.put(GeaflowResourceType.EDGE, edgeService); + serviceMap.put(GeaflowResourceType.TABLE, tableService); + serviceMap.put(GeaflowResourceType.GRAPH, graphService); + serviceMap.put(GeaflowResourceType.FUNCTION, functionService); + } + + public DataService getResourceService(GeaflowResourceType resourceType) { + return serviceMap.get(resourceType); + } + + public long getFileRefCount(String fileId, String excludeFunctionId) { + return jobDao.getFileRefCount(fileId, excludeFunctionId); + } + + public Map getJarIds(List jobIds) { + List jobEntities = jobDao.get(jobIds); + return jobEntities.stream() + .filter(e -> e.getJarPackageId() != null) + .collect(Collectors.toMap(IdEntity::getId, JobEntity::getJarPackageId)); + } + + private void createJobResources( + String jobId, + List structs, + List graphs, + List functions, + List plugins) { + List entities = new ArrayList<>(); + + structs.forEach( + e -> { + JobResourceMappingEntity entity = + new JobResourceMappingEntity( + jobId, + e.getName(), + GeaflowResourceType.valueOf(e.getType().name()), + e.getInstanceId()); + entities.add(entity); }); - graphs.forEach(e -> { - JobResourceMappingEntity entity = new JobResourceMappingEntity(jobId, e.getName(), GeaflowResourceType.GRAPH, - e.getInstanceId()); - entities.add(entity); + graphs.forEach( + e -> { + JobResourceMappingEntity entity = + new JobResourceMappingEntity( + jobId, e.getName(), GeaflowResourceType.GRAPH, e.getInstanceId()); + entities.add(entity); }); - functions.forEach(e -> { - JobResourceMappingEntity entity = new JobResourceMappingEntity(jobId, e.getName(), GeaflowResourceType.FUNCTION, - e.getInstanceId()); - entities.add(entity); + functions.forEach( + e -> { + JobResourceMappingEntity entity = + new JobResourceMappingEntity( + jobId, e.getName(), GeaflowResourceType.FUNCTION, e.getInstanceId()); + entities.add(entity); }); - plugins.forEach(e -> { - JobResourceMappingEntity entity = new JobResourceMappingEntity(jobId, e.getName(), GeaflowResourceType.PLUGIN, - PLUGIN_DEFAULT_INSTANCE_ID); - entities.add(entity); + plugins.forEach( + e -> { + JobResourceMappingEntity entity = + new JobResourceMappingEntity( + jobId, e.getName(), GeaflowResourceType.PLUGIN, PLUGIN_DEFAULT_INSTANCE_ID); + entities.add(entity); }); - if (!entities.isEmpty()) { - jobResourceMappingDao.create(entities); - } - } - - private List getJobGraphs(String id) { - return getResourcesByJobId(id, GeaflowResourceType.GRAPH); + if (!entities.isEmpty()) { + jobResourceMappingDao.create(entities); } - - private List getJobFunctions(String id) { - return getResourcesByJobId(id, GeaflowResourceType.FUNCTION); - } - - private List getJobPlugins(String id) { - List entities = jobResourceMappingDao.getResourcesByJobId(id, GeaflowResourceType.PLUGIN); - return ListUtil.convert(entities, e -> { - GeaflowPlugin plugin = pluginService.getByName(e.getResourceName()); - if (plugin == null) { - throw new GeaflowException("Plugin {} not found", e.getResourceName()); - } - return plugin; + } + + private List getJobGraphs(String id) { + return getResourcesByJobId(id, GeaflowResourceType.GRAPH); + } + + private List getJobFunctions(String id) { + return getResourcesByJobId(id, GeaflowResourceType.FUNCTION); + } + + private List getJobPlugins(String id) { + List entities = + jobResourceMappingDao.getResourcesByJobId(id, GeaflowResourceType.PLUGIN); + return ListUtil.convert( + entities, + e -> { + GeaflowPlugin plugin = pluginService.getByName(e.getResourceName()); + if (plugin == null) { + throw new GeaflowException("Plugin {} not found", e.getResourceName()); + } + return plugin; }); + } + + private List getJobStructs(String jobId) { + List res = new ArrayList<>(); + res.addAll(getResourcesByJobId(jobId, GeaflowResourceType.TABLE)); + res.addAll(getResourcesByJobId(jobId, GeaflowResourceType.VERTEX)); + res.addAll(getResourcesByJobId(jobId, GeaflowResourceType.EDGE)); + return res; + } + + private void parseUserCode(GeaflowJob job, GeaflowVersion version) { + if (job instanceof GeaflowProcessJob) { + // parse user functions + handleFunctions(job, version); + // parse user plugins + handlePlugins(job, version); + + // compile the code + CompileResult result = releaseService.compile(job, version, null); + // set job structs and graphs + Set graphInfos = result.getSourceGraphs(); + graphInfos.addAll(result.getTargetGraphs()); + + Set tableInfos = result.getSourceTables(); + tableInfos.addAll(result.getTargetTables()); + + List graphs = getByGraphInfos(graphInfos); + List tables = getByTableInfos(tableInfos); + + job.setStructs(tables); + job.setGraph(graphs); } - - - private List getJobStructs(String jobId) { - List res = new ArrayList<>(); - res.addAll(getResourcesByJobId(jobId, GeaflowResourceType.TABLE)); - res.addAll(getResourcesByJobId(jobId, GeaflowResourceType.VERTEX)); - res.addAll(getResourcesByJobId(jobId, GeaflowResourceType.EDGE)); - return res; + } + + private void handleFunctions(GeaflowJob job, GeaflowVersion version) { + Set functionInfos = releaseService.parseFunctions(job, version); + List functions = getByFunctionInfos(functionInfos); + job.setFunctions(functions); + } + + private void handlePlugins(GeaflowJob job, GeaflowVersion version) { + // parse declared table with plugin in dsl + Set dslPluginTypes = releaseService.parseDeclaredPlugins(job, version); + + // get Unresolved tables in dsl + Set unresolvedTables = releaseService.getUnResolvedTables(job, version); + + for (TableInfo tableInfo : unresolvedTables) { + String instanceId = instanceService.getIdByName(tableInfo.getInstanceName()); + GeaflowTable table = tableService.getByName(instanceId, tableInfo.getTableName()); + // get plugin used in unresolved table + if (table != null) { + String type = table.getPluginConfig().getType(); + dslPluginTypes.add(type); + } } - private void parseUserCode(GeaflowJob job, GeaflowVersion version) { - if (job instanceof GeaflowProcessJob) { - // parse user functions - handleFunctions(job, version); - //parse user plugins - handlePlugins(job, version); - - // compile the code - CompileResult result = releaseService.compile(job, version, null); - // set job structs and graphs - Set graphInfos = result.getSourceGraphs(); - graphInfos.addAll(result.getTargetGraphs()); - - Set tableInfos = result.getSourceTables(); - tableInfos.addAll(result.getTargetTables()); - - List graphs = getByGraphInfos(graphInfos); - List tables = getByTableInfos(tableInfos); - - job.setStructs(tables); - job.setGraph(graphs); - - } - } - - private void handleFunctions(GeaflowJob job, GeaflowVersion version) { - Set functionInfos = releaseService.parseFunctions(job, version); - List functions = getByFunctionInfos(functionInfos); - job.setFunctions(functions); - } - - private void handlePlugins(GeaflowJob job, GeaflowVersion version) { - // parse declared table with plugin in dsl - Set dslPluginTypes = releaseService.parseDeclaredPlugins(job, version); - - // get Unresolved tables in dsl - Set unresolvedTables = releaseService.getUnResolvedTables(job, version); - - for (TableInfo tableInfo : unresolvedTables) { - String instanceId = instanceService.getIdByName(tableInfo.getInstanceName()); - GeaflowTable table = tableService.getByName(instanceId, tableInfo.getTableName()); - // get plugin used in unresolved table - if (table != null) { - String type = table.getPluginConfig().getType(); - dslPluginTypes.add(type); - } - } - - // filter user plugins which are not in engine - Set filteredPluginTypes = dslPluginTypes.stream().filter(e -> !pluginService.pluginTypeInEngine(e, version)) + // filter user plugins which are not in engine + Set filteredPluginTypes = + dslPluginTypes.stream() + .filter(e -> !pluginService.pluginTypeInEngine(e, version)) .collect(Collectors.toSet()); - log.info("{} used plugins: {}", job.getName(), String.join(",", filteredPluginTypes)); - List plugins = getByPluginTypes(filteredPluginTypes); - job.setPlugins(plugins); - } - - private List getByPluginTypes(Set pluginTypes) { - return ListUtil.convert(pluginTypes, type -> { - // only support TABLE category plugin currently - GeaflowPlugin plugin = pluginService.getPlugin(type, GeaflowPluginCategory.TABLE); - if (plugin == null) { - throw new GeaflowException("Plugin type {} not found, please create first", type); - } - return plugin; + log.info("{} used plugins: {}", job.getName(), String.join(",", filteredPluginTypes)); + List plugins = getByPluginTypes(filteredPluginTypes); + job.setPlugins(plugins); + } + + private List getByPluginTypes(Set pluginTypes) { + return ListUtil.convert( + pluginTypes, + type -> { + // only support TABLE category plugin currently + GeaflowPlugin plugin = pluginService.getPlugin(type, GeaflowPluginCategory.TABLE); + if (plugin == null) { + throw new GeaflowException("Plugin type {} not found, please create first", type); + } + return plugin; }); - } - - private List getByTableInfos(Collection tableInfos) { - return ListUtil.convert(tableInfos, info -> { - String instanceId = instanceService.getIdByName(info.getInstanceName()); - return getResourceOrCreate(info.getTableName(), instanceId, GeaflowResourceType.TABLE); + } + + private List getByTableInfos(Collection tableInfos) { + return ListUtil.convert( + tableInfos, + info -> { + String instanceId = instanceService.getIdByName(info.getInstanceName()); + return getResourceOrCreate(info.getTableName(), instanceId, GeaflowResourceType.TABLE); }); - } - - private List getByGraphInfos(Collection graphInfos) { - return ListUtil.convert(graphInfos, info -> { - String instanceId = instanceService.getIdByName(info.getInstanceName()); - return getResourceOrCreate(info.getGraphName(), instanceId, GeaflowResourceType.GRAPH); + } + + private List getByGraphInfos(Collection graphInfos) { + return ListUtil.convert( + graphInfos, + info -> { + String instanceId = instanceService.getIdByName(info.getInstanceName()); + return getResourceOrCreate(info.getGraphName(), instanceId, GeaflowResourceType.GRAPH); }); - } - - private List getByFunctionInfos(Collection functionInfos) { - return ListUtil.convert(functionInfos, info -> { - String instanceId = instanceService.getIdByName(info.getInstanceName()); - GeaflowFunction function = functionService.getByName(instanceId, info.getFunctionName()); - // need create function first - if (function == null) { - throw new GeaflowException("Function {} not found, please create first ", info.getFunctionName()); - } - return function; + } + + private List getByFunctionInfos(Collection functionInfos) { + return ListUtil.convert( + functionInfos, + info -> { + String instanceId = instanceService.getIdByName(info.getInstanceName()); + GeaflowFunction function = functionService.getByName(instanceId, info.getFunctionName()); + // need create function first + if (function == null) { + throw new GeaflowException( + "Function {} not found, please create first ", info.getFunctionName()); + } + return function; }); + } + + private List getResourcesByJobId( + String jobId, GeaflowResourceType resourceType) { + List entities = + jobResourceMappingDao.getResourcesByJobId(jobId, resourceType); + return ListUtil.convert( + entities, e -> getResourceOrCreate(e.getResourceName(), e.getInstanceId(), resourceType)); + } + + private T getResourceOrCreate( + String resourceName, String instanceId, GeaflowResourceType resourceType) { + // create a model only with name if no resource in database. + T data = (T) getResourceService(resourceType).getByName(instanceId, resourceName); + if (data == null) { + data = (T) GeaflowDataFactory.get(resourceName, null, instanceId, resourceType); } - - private List getResourcesByJobId(String jobId, GeaflowResourceType resourceType) { - List entities = jobResourceMappingDao.getResourcesByJobId(jobId, resourceType); - return ListUtil.convert(entities, e -> getResourceOrCreate(e.getResourceName(), e.getInstanceId(), resourceType)); - } - - private T getResourceOrCreate(String resourceName, String instanceId, GeaflowResourceType resourceType) { - // create a model only with name if no resource in database. - T data = (T) getResourceService(resourceType).getByName(instanceId, resourceName); - if (data == null) { - data = (T) GeaflowDataFactory.get(resourceName, null, instanceId, resourceType); - } - return data; - } - - private void removeJobResources(String jobId, List removeStructs, List removeGraphs, - List removeFunctions, List removePlugins) { - List removeEntities = ListUtil.convert(removeGraphs, - e -> new JobResourceMappingEntity(jobId, e.getName(), GeaflowResourceType.GRAPH, e.getInstanceId())); - removeEntities.addAll(ListUtil.convert(removeStructs, - e -> new JobResourceMappingEntity(jobId, e.getName(), GeaflowResourceType.valueOf(e.getType().name()), e.getInstanceId()))); - removeEntities.addAll(ListUtil.convert(removeFunctions, - e -> new JobResourceMappingEntity(jobId, e.getName(), GeaflowResourceType.FUNCTION, e.getInstanceId()))); - removeEntities.addAll(ListUtil.convert(removePlugins, - e -> new JobResourceMappingEntity(jobId, e.getName(), GeaflowResourceType.PLUGIN, PLUGIN_DEFAULT_INSTANCE_ID))); - jobResourceMappingDao.removeJobResources(removeEntities); - } - - private String getResourceKey(GeaflowData e) { - return e.getInstanceId() + "-" + e.getName(); - } - - public List getJobByResources(String resourceName, String instanceId, GeaflowResourceType resourceType) { - List resources = jobResourceMappingDao.getJobByResources(resourceName, instanceId, resourceType); - return ListUtil.convert(resources, JobResourceMappingEntity::getJobId); - } + return data; + } + + private void removeJobResources( + String jobId, + List removeStructs, + List removeGraphs, + List removeFunctions, + List removePlugins) { + List removeEntities = + ListUtil.convert( + removeGraphs, + e -> + new JobResourceMappingEntity( + jobId, e.getName(), GeaflowResourceType.GRAPH, e.getInstanceId())); + removeEntities.addAll( + ListUtil.convert( + removeStructs, + e -> + new JobResourceMappingEntity( + jobId, + e.getName(), + GeaflowResourceType.valueOf(e.getType().name()), + e.getInstanceId()))); + removeEntities.addAll( + ListUtil.convert( + removeFunctions, + e -> + new JobResourceMappingEntity( + jobId, e.getName(), GeaflowResourceType.FUNCTION, e.getInstanceId()))); + removeEntities.addAll( + ListUtil.convert( + removePlugins, + e -> + new JobResourceMappingEntity( + jobId, e.getName(), GeaflowResourceType.PLUGIN, PLUGIN_DEFAULT_INSTANCE_ID))); + jobResourceMappingDao.removeJobResources(removeEntities); + } + + private String getResourceKey(GeaflowData e) { + return e.getInstanceId() + "-" + e.getName(); + } + + public List getJobByResources( + String resourceName, String instanceId, GeaflowResourceType resourceType) { + List resources = + jobResourceMappingDao.getJobByResources(resourceName, instanceId, resourceType); + return ListUtil.convert(resources, JobResourceMappingEntity::getJobId); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/LLMService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/LLMService.java index 3defbdceb..8bd5b8b58 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/LLMService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/LLMService.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.core.service; import java.util.List; + import org.apache.geaflow.console.common.dal.dao.LLMDao; import org.apache.geaflow.console.common.dal.dao.NameDao; import org.apache.geaflow.console.common.dal.entity.LLMEntity; @@ -34,26 +35,22 @@ @Service public class LLMService extends NameService { - @Autowired - private LLMDao llmDao; - - @Autowired - private LLMConverter llmConverter; + @Autowired private LLMDao llmDao; - @Override - protected List parse(List entities) { - return ListUtil.convert(entities, llmConverter::convert); - } + @Autowired private LLMConverter llmConverter; - @Override - protected NameDao getDao() { - return llmDao; - } + @Override + protected List parse(List entities) { + return ListUtil.convert(entities, llmConverter::convert); + } - @Override - protected NameConverter getConverter() { - return llmConverter; - } + @Override + protected NameDao getDao() { + return llmDao; + } + @Override + protected NameConverter getConverter() { + return llmConverter; + } } - diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/NameService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/NameService.java index 1774efcab..25f92d8fa 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/NameService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/NameService.java @@ -23,65 +23,64 @@ import java.util.Collections; import java.util.List; import java.util.Map; + import org.apache.geaflow.console.common.dal.dao.NameDao; import org.apache.geaflow.console.common.dal.entity.NameEntity; import org.apache.geaflow.console.common.dal.model.NameSearch; import org.apache.geaflow.console.core.model.GeaflowName; import org.apache.geaflow.console.core.service.converter.NameConverter; -public abstract class NameService extends IdService { - - @Override - protected abstract NameDao getDao(); +public abstract class NameService + extends IdService { - @Override - protected abstract NameConverter getConverter(); + @Override + protected abstract NameDao getDao(); - public boolean existName(String name) { - return getDao().existName(name); - } + @Override + protected abstract NameConverter getConverter(); - public M getByName(String name) { - if (name == null) { - return null; - } - List users = getByNames(Collections.singletonList(name)); - return users.isEmpty() ? null : users.get(0); - } + public boolean existName(String name) { + return getDao().existName(name); + } - public boolean dropByName(String name) { - if (name == null) { - return false; - } - return dropByNames(Collections.singletonList(name)); + public M getByName(String name) { + if (name == null) { + return null; } + List users = getByNames(Collections.singletonList(name)); + return users.isEmpty() ? null : users.get(0); + } - public String getIdByName(String name) { - if (name == null) { - return null; - } - return getIdsByNames(Collections.singletonList(name)).get(name); + public boolean dropByName(String name) { + if (name == null) { + return false; } + return dropByNames(Collections.singletonList(name)); + } - - public List getByNames(List names) { - List entityList = getDao().getByNames(names); - return parse(entityList); + public String getIdByName(String name) { + if (name == null) { + return null; } + return getIdsByNames(Collections.singletonList(name)).get(name); + } - public boolean dropByNames(List names) { - Map idsByNames = getIdsByNames(names); - return this.drop(new ArrayList<>(idsByNames.values())); - } + public List getByNames(List names) { + List entityList = getDao().getByNames(names); + return parse(entityList); + } - public Map getIdsByNames(List names) { - return getDao().getIdsByNames(names); - } + public boolean dropByNames(List names) { + Map idsByNames = getIdsByNames(names); + return this.drop(new ArrayList<>(idsByNames.values())); + } - public String getNameById(String id) { - E e = getDao().get(id); - return e != null ? e.getName() : null; - } + public Map getIdsByNames(List names) { + return getDao().getIdsByNames(names); + } + public String getNameById(String id) { + E e = getDao().get(id); + return e != null ? e.getName() : null; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/PluginConfigService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/PluginConfigService.java index 45f4b26d1..fe5f39230 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/PluginConfigService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/PluginConfigService.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.core.service; import java.util.List; + import org.apache.geaflow.console.common.dal.dao.NameDao; import org.apache.geaflow.console.common.dal.dao.PluginConfigDao; import org.apache.geaflow.console.common.dal.entity.PluginConfigEntity; @@ -45,115 +46,116 @@ import org.springframework.transaction.annotation.Transactional; @Service -public class PluginConfigService extends NameService { +public class PluginConfigService + extends NameService { - @Autowired - private PluginConfigDao pluginConfigDao; + @Autowired private PluginConfigDao pluginConfigDao; - @Autowired - private PluginConfigConverter pluginConfigConverter; + @Autowired private PluginConfigConverter pluginConfigConverter; - @Autowired - private PluginService pluginService; + @Autowired private PluginService pluginService; - @Override - protected NameDao getDao() { - return pluginConfigDao; - } + @Override + protected NameDao getDao() { + return pluginConfigDao; + } - @Override - protected NameConverter getConverter() { - return pluginConfigConverter; - } + @Override + protected NameConverter getConverter() { + return pluginConfigConverter; + } - @Override - protected List parse(List entities) { - return ListUtil.convert(entities, e -> pluginConfigConverter.convert(e)); - } + @Override + protected List parse(List entities) { + return ListUtil.convert(entities, e -> pluginConfigConverter.convert(e)); + } - public List getPluginConfigs(GeaflowPluginCategory category, String type) { - return parse(pluginConfigDao.getPluginConfigs(category, type)); + public List getPluginConfigs(GeaflowPluginCategory category, String type) { + return parse(pluginConfigDao.getPluginConfigs(category, type)); + } + + public GeaflowPluginConfig getDefaultPluginConfig(GeaflowPluginCategory category, String type) { + List pluginConfigs = getPluginConfigs(category, type); + if (pluginConfigs.isEmpty()) { + throw new GeaflowException( + "At least one plugin config for {} plugin type {} needed", category, type); } - public GeaflowPluginConfig getDefaultPluginConfig(GeaflowPluginCategory category, String type) { - List pluginConfigs = getPluginConfigs(category, type); - if (pluginConfigs.isEmpty()) { - throw new GeaflowException("At least one plugin config for {} plugin type {} needed", category, type); - } + return pluginConfigs.get(0); + } - return pluginConfigs.get(0); + public void validateGetIds(List ids) { + for (String id : ids) { + if (!pluginConfigDao.validateGetId(id)) { + throw new GeaflowException( + "Invalidate id {} (Not system session or current user is not the creator)", id); + } + } + } + + public void validateUpdateIds(List ids) { + for (String id : ids) { + if (!pluginConfigDao.validateUpdateId(id)) { + throw new GeaflowException( + "Invalidate id {} (Not system session or current user is not the creator)", id); + } } + } - public void validateGetIds(List ids) { - for (String id : ids) { - if (!pluginConfigDao.validateGetId(id)) { - throw new GeaflowException("Invalidate id {} (Not system session or current user is not the creator)", - id); - } - } + public GeaflowPluginConfig getDefaultPluginConfig(GeaflowPluginCategory category) { + List plugins = pluginService.getPlugins(category); + if (plugins.isEmpty()) { + return null; } - public void validateUpdateIds(List ids) { - for (String id : ids) { - if (!pluginConfigDao.validateUpdateId(id)) { - throw new GeaflowException("Invalidate id {} (Not system session or current user is not the creator)", - id); - } - } + String type = plugins.get(0).getType(); + List pluginConfigs = getPluginConfigs(category, type); + if (pluginConfigs.isEmpty()) { + return null; } - public GeaflowPluginConfig getDefaultPluginConfig(GeaflowPluginCategory category) { - List plugins = pluginService.getPlugins(category); - if (plugins.isEmpty()) { - return null; - } + return pluginConfigs.get(0); + } - String type = plugins.get(0).getType(); - List pluginConfigs = getPluginConfigs(category, type); - if (pluginConfigs.isEmpty()) { - return null; - } + @Transactional + public String createDefaultPluginConfig(GeaflowPluginConfig pluginConfig) { + String type = pluginConfig.getType(); + GeaflowPluginCategory category = pluginConfig.getCategory(); - return pluginConfigs.get(0); + // check plugin config + List pluginConfigs = getPluginConfigs(category, type); + if (!pluginConfigs.isEmpty()) { + throw new GeaflowIllegalException("Default {} plugin {} config exists", category, type); } - @Transactional - public String createDefaultPluginConfig(GeaflowPluginConfig pluginConfig) { - String type = pluginConfig.getType(); - GeaflowPluginCategory category = pluginConfig.getCategory(); - - // check plugin config - List pluginConfigs = getPluginConfigs(category, type); - if (!pluginConfigs.isEmpty()) { - throw new GeaflowIllegalException("Default {} plugin {} config exists", category, type); - } - - // check plugin - List plugins = pluginService.getPlugins(category); - if (!plugins.isEmpty()) { - throw new GeaflowIllegalException("Default {} plugin exists", category); - } - - // create plugin config - final String pluginConfigId = create(pluginConfig); - - // create plugin - GeaflowPlugin plugin = new GeaflowPlugin(); - plugin.setName(Fmt.as("plugin-{}-{}-default", category, type).toLowerCase()); - plugin.setComment(Fmt.as(I18nUtil.getMessage("i18n.key.default.plugin.comment.format"), category, type).toLowerCase()); - plugin.setType(type); - plugin.setCategory(category); - pluginService.create(plugin); - - return pluginConfigId; + // check plugin + List plugins = pluginService.getPlugins(category); + if (!plugins.isEmpty()) { + throw new GeaflowIllegalException("Default {} plugin exists", category); } - public void testConnection(GeaflowPluginConfig pluginConfig) { - GeaflowPluginType type = GeaflowPluginType.valueOf(pluginConfig.getType()); - GeaflowConfig config = pluginConfig.getConfig(); - - GeaflowConfigDesc configDesc = ConfigDescFactory.get(type); - GeaflowConfigClass configClass = config.parse(configDesc.getClazz()); - ((PluginConfigClass) configClass).testConnection(); - } + // create plugin config + final String pluginConfigId = create(pluginConfig); + + // create plugin + GeaflowPlugin plugin = new GeaflowPlugin(); + plugin.setName(Fmt.as("plugin-{}-{}-default", category, type).toLowerCase()); + plugin.setComment( + Fmt.as(I18nUtil.getMessage("i18n.key.default.plugin.comment.format"), category, type) + .toLowerCase()); + plugin.setType(type); + plugin.setCategory(category); + pluginService.create(plugin); + + return pluginConfigId; + } + + public void testConnection(GeaflowPluginConfig pluginConfig) { + GeaflowPluginType type = GeaflowPluginType.valueOf(pluginConfig.getType()); + GeaflowConfig config = pluginConfig.getConfig(); + + GeaflowConfigDesc configDesc = ConfigDescFactory.get(type); + GeaflowConfigClass configClass = config.parse(configDesc.getClazz()); + ((PluginConfigClass) configClass).testConnection(); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/PluginService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/PluginService.java index c0694cea0..e8ad21db9 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/PluginService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/PluginService.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Set; import java.util.stream.Collectors; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.console.common.dal.dao.NameDao; import org.apache.geaflow.console.common.dal.dao.PluginDao; @@ -50,153 +51,152 @@ import org.springframework.web.multipart.MultipartFile; @Service -public class PluginService extends NameService implements FileRefService { - - - public static final String PLUGIN_DEFAULT_INSTANCE_ID = "0"; - - @Autowired - private PluginDao pluginDao; - - @Autowired - private PluginConverter pluginConverter; - - @Autowired - private RemoteFileService remoteFileService; - - @Autowired - private VersionFactory versionFactory; - - - @Override - protected NameDao getDao() { - return pluginDao; +public class PluginService extends NameService + implements FileRefService { + + public static final String PLUGIN_DEFAULT_INSTANCE_ID = "0"; + + @Autowired private PluginDao pluginDao; + + @Autowired private PluginConverter pluginConverter; + + @Autowired private RemoteFileService remoteFileService; + + @Autowired private VersionFactory versionFactory; + + @Override + protected NameDao getDao() { + return pluginDao; + } + + @Override + protected NameConverter getConverter() { + return pluginConverter; + } + + @Override + protected List parse(List pluginEntities) { + return pluginEntities.stream() + .map( + e -> { + GeaflowRemoteFile jarPackage = remoteFileService.get(e.getJarPackageId()); + return pluginConverter.convert(e, jarPackage); + }) + .collect(Collectors.toList()); + } + + public List getPlugins(GeaflowPluginCategory category) { + List plugins = pluginDao.getPlugins(category); + return parse(plugins); + } + + public GeaflowPlugin getPlugin(String type, GeaflowPluginCategory category) { + PluginEntity plugin = pluginDao.getPlugin(type, category); + return plugin == null ? null : parse(plugin); + } + + public GeaflowPlugin getDefaultPlugin(GeaflowPluginCategory category) { + List plugins = getPlugins(category); + if (plugins.size() != 1) { + throw new GeaflowException("At least one plugin for {} plugin needed", category); } - @Override - protected NameConverter getConverter() { - return pluginConverter; - } + return plugins.get(0); + } - @Override - protected List parse(List pluginEntities) { - return pluginEntities.stream().map(e -> { - GeaflowRemoteFile jarPackage = remoteFileService.get(e.getJarPackageId()); - return pluginConverter.convert(e, jarPackage); - }).collect(Collectors.toList()); - } + public List getSystemPlugins(GeaflowPluginCategory category) { + List plugins = pluginDao.getSystemPlugins(category); + return parse(plugins); + } - public List getPlugins(GeaflowPluginCategory category) { - List plugins = pluginDao.getPlugins(category); - return parse(plugins); + public GeaflowPlugin getDefaultSystemPlugin(GeaflowPluginCategory category) { + List plugins = getSystemPlugins(category); + if (plugins.size() != 1) { + throw new GeaflowException("At least one system plugin for {} plugin needed", category); } - public GeaflowPlugin getPlugin(String type, GeaflowPluginCategory category) { - PluginEntity plugin = pluginDao.getPlugin(type, category); - return plugin == null ? null : parse(plugin); - } + return plugins.get(0); + } - public GeaflowPlugin getDefaultPlugin(GeaflowPluginCategory category) { - List plugins = getPlugins(category); - if (plugins.size() != 1) { - throw new GeaflowException("At least one plugin for {} plugin needed", category); - } - - return plugins.get(0); + public void validateGetIds(List ids) { + for (String id : ids) { + if (!pluginDao.validateGetId(id)) { + throw new GeaflowException( + "Invalidate id {} (Not system session or current user is not the creator)", id); + } } - - public List getSystemPlugins(GeaflowPluginCategory category) { - List plugins = pluginDao.getSystemPlugins(category); - return parse(plugins); + } + + public void validateUpdateIds(List ids) { + for (String id : ids) { + if (!pluginDao.validateUpdateId(id)) { + throw new GeaflowException( + "Invalidate id {} (Not system session or current user is not the creator)", id); + } } - - public GeaflowPlugin getDefaultSystemPlugin(GeaflowPluginCategory category) { - List plugins = getSystemPlugins(category); - if (plugins.size() != 1) { - throw new GeaflowException("At least one system plugin for {} plugin needed", category); - } - - return plugins.get(0); - } - - public void validateGetIds(List ids) { - for (String id : ids) { - if (!pluginDao.validateGetId(id)) { - throw new GeaflowException("Invalidate id {} (Not system session or current user is not the creator)", id); - } - } - } - - public void validateUpdateIds(List ids) { - for (String id : ids) { - if (!pluginDao.validateUpdateId(id)) { - throw new GeaflowException("Invalidate id {} (Not system session or current user is not the creator)", id); - } - } - } - - - public boolean pluginTypeInEngine(String pluginType, GeaflowVersion version) { - VersionClassLoader classLoader = versionFactory.getClassLoader(version); - return checkPluginType(classLoader, pluginType); - } - - @Override - public long getFileRefCount(String jarId, String pluginId) { - return pluginDao.getFileRefCount(jarId, pluginId); + } + + public boolean pluginTypeInEngine(String pluginType, GeaflowVersion version) { + VersionClassLoader classLoader = versionFactory.getClassLoader(version); + return checkPluginType(classLoader, pluginType); + } + + @Override + public long getFileRefCount(String jarId, String pluginId) { + return pluginDao.getFileRefCount(jarId, pluginId); + } + + public void checkJar(String type, MultipartFile jarPackage, GeaflowVersion version) { + FunctionClassLoader functionClassLoader = null; + File file = null; + try { + String tmpPath = "/tmp/geaflow/tmpFile/" + jarPackage.getOriginalFilename(); + file = new File(tmpPath); + FileUtils.copyInputStreamToFile(jarPackage.getInputStream(), file); + URL url = file.toURI().toURL(); + VersionClassLoader versionClassLoader = versionFactory.getClassLoader(version); + functionClassLoader = new FunctionClassLoader(versionClassLoader, new URL[] {url}); + if (!checkPluginType(functionClassLoader, type)) { + throw new GeaflowException("Plugin type {} is not in the jar", type); + } + + } catch (IOException e) { + throw new RuntimeException(e); + + } finally { + if (functionClassLoader != null) { + functionClassLoader.closeClassLoader(); + } + + if (file != null) { + file.delete(); + } } - - public void checkJar(String type, MultipartFile jarPackage, GeaflowVersion version) { - FunctionClassLoader functionClassLoader = null; - File file = null; - try { - String tmpPath = "/tmp/geaflow/tmpFile/" + jarPackage.getOriginalFilename(); - file = new File(tmpPath); - FileUtils.copyInputStreamToFile(jarPackage.getInputStream(), file); - URL url = file.toURI().toURL(); - VersionClassLoader versionClassLoader = versionFactory.getClassLoader(version); - functionClassLoader = new FunctionClassLoader(versionClassLoader, new URL[]{url}); - if (!checkPluginType(functionClassLoader, type)) { - throw new GeaflowException("Plugin type {} is not in the jar", type); - } - - } catch (IOException e) { - throw new RuntimeException(e); - - } finally { - if (functionClassLoader != null) { - functionClassLoader.closeClassLoader(); - } - - if (file != null) { - file.delete(); - } - } + } + + public void checkJar(String type, String jarId, GeaflowVersion version) { + FunctionClassLoader functionClassLoader = null; + try { + GeaflowRemoteFile remoteFile = remoteFileService.get(jarId); + VersionClassLoader versionClassLoader = versionFactory.getClassLoader(version); + functionClassLoader = new FunctionClassLoader(versionClassLoader, Arrays.asList(remoteFile)); + if (!checkPluginType(functionClassLoader, type)) { + throw new GeaflowException("Plugin type {} is not in the jar", type); + } + } finally { + if (functionClassLoader != null) { + functionClassLoader.closeClassLoader(); + } } - - public void checkJar(String type, String jarId, GeaflowVersion version) { - FunctionClassLoader functionClassLoader = null; - try { - GeaflowRemoteFile remoteFile = remoteFileService.get(jarId); - VersionClassLoader versionClassLoader = versionFactory.getClassLoader(version); - functionClassLoader = new FunctionClassLoader(versionClassLoader, Arrays.asList(remoteFile)); - if (!checkPluginType(functionClassLoader, type)) { - throw new GeaflowException("Plugin type {} is not in the jar", type); - } - } finally { - if (functionClassLoader != null) { - functionClassLoader.closeClassLoader(); - } - } - } - - private boolean checkPluginType(CompileClassLoader classLoader, String type) { - try { - GeaflowCompiler compiler = classLoader.newInstance(GeaflowCompiler.class); - Set enginePlugins = compiler.getEnginePlugins(); - return enginePlugins.contains(type.toUpperCase()); - } catch (Exception e) { - throw new GeaflowCompileException("Compile job code failed", e); - } + } + + private boolean checkPluginType(CompileClassLoader classLoader, String type) { + try { + GeaflowCompiler compiler = classLoader.newInstance(GeaflowCompiler.class); + Set enginePlugins = compiler.getEnginePlugins(); + return enginePlugins.contains(type.toUpperCase()); + } catch (Exception e) { + throw new GeaflowCompileException("Compile job code failed", e); } + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/ReleaseService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/ReleaseService.java index eca9c4fed..af132bbf4 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/ReleaseService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/ReleaseService.java @@ -23,6 +23,7 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; + import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.MapUtils; import org.apache.geaflow.console.common.dal.dao.IdDao; @@ -62,177 +63,180 @@ @Service public class ReleaseService extends IdService { - @Autowired - private DeployConfig deployConfig; + @Autowired private DeployConfig deployConfig; - @Autowired - private ReleaseDao releaseDao; + @Autowired private ReleaseDao releaseDao; - @Autowired - private JobService jobService; + @Autowired private JobService jobService; - @Autowired - private ReleaseConverter releaseConverter; + @Autowired private ReleaseConverter releaseConverter; - @Autowired - private VersionService versionService; + @Autowired private VersionService versionService; - @Autowired - private InstanceService instanceService; + @Autowired private InstanceService instanceService; - @Autowired - private ClusterService clusterService; + @Autowired private ClusterService clusterService; - @Autowired - private VersionFactory versionFactory; + @Autowired private VersionFactory versionFactory; - @Autowired - private RemoteFileStorage remoteFileStorage; + @Autowired private RemoteFileStorage remoteFileStorage; - @Autowired - private LocalFileFactory localFileFactory; + @Autowired private LocalFileFactory localFileFactory; - protected IdDao getDao() { - return releaseDao; - } + protected IdDao getDao() { + return releaseDao; + } - @Override - protected IdConverter getConverter() { - return releaseConverter; - } + @Override + protected IdConverter getConverter() { + return releaseConverter; + } - @Override - protected List parse(List releaseEntities) { - return releaseEntities.stream().map(e -> { - GeaflowJob job = jobService.get(e.getJobId()); + @Override + protected List parse(List releaseEntities) { + return releaseEntities.stream() + .map( + e -> { + GeaflowJob job = jobService.get(e.getJobId()); - GeaflowVersion version = versionService.get(e.getVersionId()); + GeaflowVersion version = versionService.get(e.getVersionId()); - GeaflowCluster cluster = clusterService.get(e.getClusterId()); + GeaflowCluster cluster = clusterService.get(e.getClusterId()); - return releaseConverter.convert(e, job, version, cluster); - }).collect(Collectors.toList()); - } + return releaseConverter.convert(e, job, version, cluster); + }) + .collect(Collectors.toList()); + } - public Set parseFunctions(GeaflowJob job, GeaflowVersion version) { - CompilerAndContext compilerAndContext = getCompilerAndContext(version, job.getInstanceId(), CatalogType.MEMORY); - try { - return compilerAndContext.getCompiler() - .getUnResolvedFunctions(job.getUserCode().getText(), compilerAndContext.getContext()); - } catch (Exception e) { - throw new GeaflowCompileException("Parse functions failed", e); - } + public Set parseFunctions(GeaflowJob job, GeaflowVersion version) { + CompilerAndContext compilerAndContext = + getCompilerAndContext(version, job.getInstanceId(), CatalogType.MEMORY); + try { + return compilerAndContext + .getCompiler() + .getUnResolvedFunctions(job.getUserCode().getText(), compilerAndContext.getContext()); + } catch (Exception e) { + throw new GeaflowCompileException("Parse functions failed", e); } - - - public CompileResult compile(GeaflowJob job, GeaflowVersion version, Map parallelisms) { - VersionClassLoader classLoader = versionFactory.getClassLoader(version); - List udfs = ListUtil.convert(job.getFunctions(), GeaflowFunction::getJarPackage); - List plugins = ListUtil.convert(job.getPlugins(), GeaflowPlugin::getJarPackage); - udfs.addAll(plugins); - if (CollectionUtils.isNotEmpty(udfs)) { - // use FunctionClassLoader if job has udf - FunctionClassLoader functionLoader = null; - try { - functionLoader = new FunctionClassLoader(classLoader, udfs); - return compile(functionLoader, job, parallelisms); - } finally { - if (functionLoader != null) { - functionLoader.closeClassLoader(); - } - - } + } + + public CompileResult compile( + GeaflowJob job, GeaflowVersion version, Map parallelisms) { + VersionClassLoader classLoader = versionFactory.getClassLoader(version); + List udfs = + ListUtil.convert(job.getFunctions(), GeaflowFunction::getJarPackage); + List plugins = + ListUtil.convert(job.getPlugins(), GeaflowPlugin::getJarPackage); + udfs.addAll(plugins); + if (CollectionUtils.isNotEmpty(udfs)) { + // use FunctionClassLoader if job has udf + FunctionClassLoader functionLoader = null; + try { + functionLoader = new FunctionClassLoader(classLoader, udfs); + return compile(functionLoader, job, parallelisms); + } finally { + if (functionLoader != null) { + functionLoader.closeClassLoader(); } - - return compile(classLoader, job, parallelisms); + } } + return compile(classLoader, job, parallelisms); + } - private CompileResult compile(CompileClassLoader classLoader, GeaflowJob job, Map parallelisms) { - CompilerAndContext compilerAndContext = getCompilerAndContext(classLoader, job.getInstanceId(), CatalogType.CONSOLE); - if (MapUtils.isNotEmpty(parallelisms)) { - compilerAndContext.getContext().setParallelisms(parallelisms); - } - try { - return compilerAndContext.getCompiler() - .compile(job.getUserCode().getText(), compilerAndContext.getContext()); - } catch (Exception e) { - throw new GeaflowCompileException("Compile job code failed", e); - } + private CompileResult compile( + CompileClassLoader classLoader, GeaflowJob job, Map parallelisms) { + CompilerAndContext compilerAndContext = + getCompilerAndContext(classLoader, job.getInstanceId(), CatalogType.CONSOLE); + if (MapUtils.isNotEmpty(parallelisms)) { + compilerAndContext.getContext().setParallelisms(parallelisms); } - - public void dropByJobIds(List ids) { - List releases = releaseDao.getByJobIds(ids); - for (ReleaseEntity release : releases) { - for (int i = 1; i <= release.getVersion(); i++) { - String path = RemoteFileStorage.getPackageFilePath(release.getJobId(), i); - remoteFileStorage.delete(path); - } - } - releaseDao.dropByJobIds(ids); + try { + return compilerAndContext + .getCompiler() + .compile(job.getUserCode().getText(), compilerAndContext.getContext()); + } catch (Exception e) { + throw new GeaflowCompileException("Compile job code failed", e); } - - public Set parseDeclaredPlugins(GeaflowJob job, GeaflowVersion version) { - CompilerAndContext compilerAndContext = getCompilerAndContext(version, job.getInstanceId(), CatalogType.MEMORY); - try { - return compilerAndContext.getCompiler() - .getDeclaredTablePlugins(job.getUserCode().getText(), compilerAndContext.getContext()); - } catch (Exception e) { - throw new GeaflowCompileException("Parse plugins failed", e); - } + } + + public void dropByJobIds(List ids) { + List releases = releaseDao.getByJobIds(ids); + for (ReleaseEntity release : releases) { + for (int i = 1; i <= release.getVersion(); i++) { + String path = RemoteFileStorage.getPackageFilePath(release.getJobId(), i); + remoteFileStorage.delete(path); + } } - - public Set getUnResolvedTables(GeaflowJob job, GeaflowVersion version) { - CompilerAndContext compilerAndContext = getCompilerAndContext(version, job.getInstanceId(), CatalogType.MEMORY); - try { - return compilerAndContext.getCompiler() - .getUnResolvedTables(job.getUserCode().getText(), compilerAndContext.getContext()); - } catch (Exception e) { - throw new GeaflowCompileException("Parse plugins failed", e); - } + releaseDao.dropByJobIds(ids); + } + + public Set parseDeclaredPlugins(GeaflowJob job, GeaflowVersion version) { + CompilerAndContext compilerAndContext = + getCompilerAndContext(version, job.getInstanceId(), CatalogType.MEMORY); + try { + return compilerAndContext + .getCompiler() + .getDeclaredTablePlugins(job.getUserCode().getText(), compilerAndContext.getContext()); + } catch (Exception e) { + throw new GeaflowCompileException("Parse plugins failed", e); } - - private CompilerAndContext getCompilerAndContext(GeaflowVersion version, String instanceId, CatalogType catalogType) { - VersionClassLoader classLoader = versionFactory.getClassLoader(version); - return getCompilerAndContext(classLoader, instanceId, catalogType); + } + + public Set getUnResolvedTables(GeaflowJob job, GeaflowVersion version) { + CompilerAndContext compilerAndContext = + getCompilerAndContext(version, job.getInstanceId(), CatalogType.MEMORY); + try { + return compilerAndContext + .getCompiler() + .getUnResolvedTables(job.getUserCode().getText(), compilerAndContext.getContext()); + } catch (Exception e) { + throw new GeaflowCompileException("Parse plugins failed", e); } - - private CompilerAndContext getCompilerAndContext(CompileClassLoader classLoader, String instanceId, CatalogType catalogType) { - CompileContext context = classLoader.newInstance(CompileContext.class); - GeaflowCompiler compiler = classLoader.newInstance(GeaflowCompiler.class); - setContextConfig(context, instanceId, catalogType); - return new CompilerAndContext(compiler, context); + } + + private CompilerAndContext getCompilerAndContext( + GeaflowVersion version, String instanceId, CatalogType catalogType) { + VersionClassLoader classLoader = versionFactory.getClassLoader(version); + return getCompilerAndContext(classLoader, instanceId, catalogType); + } + + private CompilerAndContext getCompilerAndContext( + CompileClassLoader classLoader, String instanceId, CatalogType catalogType) { + CompileContext context = classLoader.newInstance(CompileContext.class); + GeaflowCompiler compiler = classLoader.newInstance(GeaflowCompiler.class); + setContextConfig(context, instanceId, catalogType); + return new CompilerAndContext(compiler, context); + } + + private void setContextConfig( + CompileContext context, String instanceId, CatalogType catalogType) { + GeaflowInstance instance = instanceService.get(instanceId); + CompileContextClass config = new CompileContextClass(); + config.setTokenKey(ContextHolder.get().getSessionToken()); + config.setInstanceName(instance.getName()); + config.setCatalogType(catalogType.getValue()); + config.setEndpoint(deployConfig.getGatewayUrl()); + Map map = config.build().toStringMap(); + context.setConfig(map); + } + + private static class CompilerAndContext { + + private final GeaflowCompiler compiler; + private final CompileContext context; + + public CompilerAndContext(GeaflowCompiler geaflowCompiler, CompileContext compileContext) { + this.compiler = geaflowCompiler; + this.context = compileContext; } - private void setContextConfig(CompileContext context, String instanceId, CatalogType catalogType) { - GeaflowInstance instance = instanceService.get(instanceId); - CompileContextClass config = new CompileContextClass(); - config.setTokenKey(ContextHolder.get().getSessionToken()); - config.setInstanceName(instance.getName()); - config.setCatalogType(catalogType.getValue()); - config.setEndpoint(deployConfig.getGatewayUrl()); - Map map = config.build().toStringMap(); - context.setConfig(map); + public GeaflowCompiler getCompiler() { + return compiler; } - private static class CompilerAndContext { - - private final GeaflowCompiler compiler; - private final CompileContext context; - - public CompilerAndContext(GeaflowCompiler geaflowCompiler, CompileContext compileContext) { - this.compiler = geaflowCompiler; - this.context = compileContext; - } - - public GeaflowCompiler getCompiler() { - return compiler; - } - - public CompileContext getContext() { - return context; - } + public CompileContext getContext() { + return context; } - + } } - diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/RemoteFileService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/RemoteFileService.java index 6d3d183b9..036d33661 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/RemoteFileService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/RemoteFileService.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.core.service; import java.util.List; + import org.apache.geaflow.console.common.dal.dao.NameDao; import org.apache.geaflow.console.common.dal.dao.RemoteFileDao; import org.apache.geaflow.console.common.dal.entity.RemoteFileEntity; @@ -35,71 +36,70 @@ import org.springframework.stereotype.Service; @Service -public class RemoteFileService extends NameService { - - public static final String JAR_FILE_SUFFIX = ".jar"; - - @Autowired - private RemoteFileDao remoteFileDao; - - @Autowired - private RemoteFileConverter remoteFileConverter; +public class RemoteFileService + extends NameService { - @Autowired - private RemoteFileStorage remoteFileStorage; + public static final String JAR_FILE_SUFFIX = ".jar"; - @Override - protected List parse(List entities) { - return ListUtil.convert(entities, e -> remoteFileConverter.convert(e)); - } + @Autowired private RemoteFileDao remoteFileDao; - @Override - public String create(GeaflowRemoteFile model) { - String name = model.getName(); - if (remoteFileDao.existName(model.getName())) { - throw new GeaflowIllegalException("File {} exists", name); - } + @Autowired private RemoteFileConverter remoteFileConverter; - return super.create(model); - } + @Autowired private RemoteFileStorage remoteFileStorage; - @Override - protected NameDao getDao() { - return remoteFileDao; - } + @Override + protected List parse(List entities) { + return ListUtil.convert(entities, e -> remoteFileConverter.convert(e)); + } - @Override - protected NameConverter getConverter() { - return remoteFileConverter; + @Override + public String create(GeaflowRemoteFile model) { + String name = model.getName(); + if (remoteFileDao.existName(model.getName())) { + throw new GeaflowIllegalException("File {} exists", name); } - public GeaflowRemoteFile getByName(String name) { - RemoteFileEntity entity = remoteFileDao.getByName(name); - return parse(entity); + return super.create(model); + } + + @Override + protected NameDao getDao() { + return remoteFileDao; + } + + @Override + protected NameConverter getConverter() { + return remoteFileConverter; + } + + public GeaflowRemoteFile getByName(String name) { + RemoteFileEntity entity = remoteFileDao.getByName(name); + return parse(entity); + } + + public void updateMd5ById(String id, String md5) { + remoteFileDao.updateMd5(id, md5); + } + + public void updateUrlById(String id, String url) { + remoteFileDao.updateUrl(id, url); + } + + public void validateGetIds(List ids) { + for (String id : ids) { + if (!remoteFileDao.validateGetId(id)) { + throw new GeaflowException( + "Invalidate id {} (Not system session or current user is not the creator)", id); + } } - - public void updateMd5ById(String id, String md5) { - remoteFileDao.updateMd5(id, md5); - } - - public void updateUrlById(String id, String url) { - remoteFileDao.updateUrl(id, url); - } - - public void validateGetIds(List ids) { - for (String id : ids) { - if (!remoteFileDao.validateGetId(id)) { - throw new GeaflowException("Invalidate id {} (Not system session or current user is not the creator)", id); - } - } - } - - public void validateUpdateIds(List ids) { - for (String id : ids) { - if (!remoteFileDao.validateUpdateId(id)) { - throw new GeaflowException("Invalidate id {} (Not system session or current user is not the creator)", id); - } - } + } + + public void validateUpdateIds(List ids) { + for (String id : ids) { + if (!remoteFileDao.validateUpdateId(id)) { + throw new GeaflowException( + "Invalidate id {} (Not system session or current user is not the creator)", id); + } } + } } - diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/StatementService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/StatementService.java index 459635581..a14cd3f32 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/StatementService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/StatementService.java @@ -19,10 +19,10 @@ package org.apache.geaflow.console.core.service; -import com.google.common.base.Preconditions; import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.console.common.dal.dao.IdDao; import org.apache.geaflow.console.common.dal.dao.StatementDao; import org.apache.geaflow.console.common.dal.entity.StatementEntity; @@ -40,70 +40,69 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; -@Service -public class StatementService extends IdService { - - @Autowired - private StatementDao statementDao; - - @Autowired - private StatementConverter statementConverter; - - @Autowired - private StatementSubmitter statementSubmitter; - - @Autowired - private TaskService taskService; +import com.google.common.base.Preconditions; - @Override - protected IdDao getDao() { - return statementDao; +@Service +public class StatementService + extends IdService { + + @Autowired private StatementDao statementDao; + + @Autowired private StatementConverter statementConverter; + + @Autowired private StatementSubmitter statementSubmitter; + + @Autowired private TaskService taskService; + + @Override + protected IdDao getDao() { + return statementDao; + } + + @Override + protected IdConverter getConverter() { + return statementConverter; + } + + @Override + protected List parse(List entities) { + return ListUtil.convert(entities, e -> statementConverter.convert(e)); + } + + @Override + public List create(List models) { + Map taskMap = new HashMap<>(); + for (GeaflowStatement model : models) { + String jobId = model.getJobId(); + + GeaflowTask task = null; + if (!taskMap.containsKey(jobId)) { + task = taskService.getByJobId(jobId); + Preconditions.checkNotNull(task, "Job %s task is null, please publish job", jobId); + taskMap.put(jobId, task); + } else { + task = taskMap.get(jobId); + } + + GeaflowJob job = task.getRelease().getJob(); + if (task.getStatus() != GeaflowTaskStatus.RUNNING) { + throw new GeaflowException("Job {} task is not running", job.getName()); + } + + model.setStatus(GeaflowStatementStatus.RUNNING); + model.setResult("Query is running, please wait or refresh the page"); } - @Override - protected IdConverter getConverter() { - return statementConverter; - } + List ids = super.create(models); - @Override - protected List parse(List entities) { - return ListUtil.convert(entities, e -> statementConverter.convert(e)); + for (GeaflowStatement model : models) { + statementSubmitter.asyncSubmitQuery(model, taskMap.get(model.getJobId())); } - @Override - public List create(List models) { - Map taskMap = new HashMap<>(); - for (GeaflowStatement model : models) { - String jobId = model.getJobId(); - - GeaflowTask task = null; - if (!taskMap.containsKey(jobId)) { - task = taskService.getByJobId(jobId); - Preconditions.checkNotNull(task, "Job %s task is null, please publish job", jobId); - taskMap.put(jobId, task); - } else { - task = taskMap.get(jobId); - } - - GeaflowJob job = task.getRelease().getJob(); - if (task.getStatus() != GeaflowTaskStatus.RUNNING) { - throw new GeaflowException("Job {} task is not running", job.getName()); - } - - model.setStatus(GeaflowStatementStatus.RUNNING); - model.setResult("Query is running, please wait or refresh the page"); - } - - List ids = super.create(models); - - for (GeaflowStatement model : models) { - statementSubmitter.asyncSubmitQuery(model, taskMap.get(model.getJobId())); - } - - return ids; - } + return ids; + } - public boolean dropByJobIds(List jobIds) { - return statementDao.dropByJobIds(jobIds); - } + public boolean dropByJobIds(List jobIds) { + return statementDao.dropByJobIds(jobIds); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/SystemConfigService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/SystemConfigService.java index 7474e160e..abe566b2a 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/SystemConfigService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/SystemConfigService.java @@ -19,11 +19,9 @@ package org.apache.geaflow.console.core.service; -import com.alibaba.fastjson.JSON; -import com.alibaba.fastjson.JSONArray; -import com.alibaba.fastjson.JSONObject; import java.util.List; import java.util.Optional; + import org.apache.geaflow.console.common.dal.dao.NameDao; import org.apache.geaflow.console.common.dal.dao.SystemConfigDao; import org.apache.geaflow.console.common.dal.entity.SystemConfigEntity; @@ -36,92 +34,94 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONArray; +import com.alibaba.fastjson.JSONObject; + @Service -public class SystemConfigService extends NameService { +public class SystemConfigService + extends NameService { - @Autowired - private SystemConfigDao systemConfigDao; + @Autowired private SystemConfigDao systemConfigDao; - @Autowired - private SystemConfigConverter systemConfigConvert; + @Autowired private SystemConfigConverter systemConfigConvert; - @Override - protected NameDao getDao() { - return systemConfigDao; - } + @Override + protected NameDao getDao() { + return systemConfigDao; + } - @Override - protected NameConverter getConverter() { - return systemConfigConvert; - } + @Override + protected NameConverter getConverter() { + return systemConfigConvert; + } - @Override - protected List parse(List entities) { - return ListUtil.convert(entities, systemConfigConvert::convert); - } + @Override + protected List parse(List entities) { + return ListUtil.convert(entities, systemConfigConvert::convert); + } - public GeaflowSystemConfig get(String tenantId, String key) { - return parse(systemConfigDao.get(tenantId, key)); - } + public GeaflowSystemConfig get(String tenantId, String key) { + return parse(systemConfigDao.get(tenantId, key)); + } - public boolean exist(String tenantId, String key) { - return systemConfigDao.exist(tenantId, key); - } - - public boolean delete(String tenantId, String key) { - return systemConfigDao.delete(tenantId, key); - } + public boolean exist(String tenantId, String key) { + return systemConfigDao.exist(tenantId, key); + } - public void setValue(String key, Object value) { - String str = Optional.ofNullable(value).map(Object::toString).orElse(null); + public boolean delete(String tenantId, String key) { + return systemConfigDao.delete(tenantId, key); + } - // system level config - if (ContextHolder.get().isSystemSession()) { - systemConfigDao.setValue(null, key, str); - } + public void setValue(String key, Object value) { + String str = Optional.ofNullable(value).map(Object::toString).orElse(null); - // tenant level config - String tenantId = ContextHolder.get().getTenantId(); - systemConfigDao.setValue(tenantId, key, str); + // system level config + if (ContextHolder.get().isSystemSession()) { + systemConfigDao.setValue(null, key, str); } - public String getValue(String key) { - // default config - String defaultValue = systemConfigDao.getValue(null, key); + // tenant level config + String tenantId = ContextHolder.get().getTenantId(); + systemConfigDao.setValue(tenantId, key, str); + } - // system level config - if (ContextHolder.get().isSystemSession()) { - return defaultValue; - } + public String getValue(String key) { + // default config + String defaultValue = systemConfigDao.getValue(null, key); - // tenant level config - String tenantId = ContextHolder.get().getTenantId(); - SystemConfigEntity entity = systemConfigDao.get(tenantId, key); - return entity != null ? entity.getValue() : defaultValue; + // system level config + if (ContextHolder.get().isSystemSession()) { + return defaultValue; } - public String getString(String key) { - return getValue(key); - } + // tenant level config + String tenantId = ContextHolder.get().getTenantId(); + SystemConfigEntity entity = systemConfigDao.get(tenantId, key); + return entity != null ? entity.getValue() : defaultValue; + } - public long getLong(String key) { - return Optional.ofNullable(getValue(key)).map(Long::parseLong).orElse(0L); - } + public String getString(String key) { + return getValue(key); + } - public int getInteger(String key) { - return Optional.ofNullable(getValue(key)).map(Integer::parseInt).orElse(0); - } + public long getLong(String key) { + return Optional.ofNullable(getValue(key)).map(Long::parseLong).orElse(0L); + } - public boolean getBoolean(String key) { - return Optional.ofNullable(getValue(key)).map(Boolean::parseBoolean).orElse(false); - } + public int getInteger(String key) { + return Optional.ofNullable(getValue(key)).map(Integer::parseInt).orElse(0); + } - public JSONObject getJsonObject(String key) { - return Optional.ofNullable(getValue(key)).map(JSON::parseObject).orElse(new JSONObject()); - } + public boolean getBoolean(String key) { + return Optional.ofNullable(getValue(key)).map(Boolean::parseBoolean).orElse(false); + } - public JSONArray getJsonArray(String key) { - return Optional.ofNullable(getValue(key)).map(JSON::parseArray).orElse(new JSONArray()); - } -} + public JSONObject getJsonObject(String key) { + return Optional.ofNullable(getValue(key)).map(JSON::parseObject).orElse(new JSONObject()); + } + public JSONArray getJsonArray(String key) { + return Optional.ofNullable(getValue(key)).map(JSON::parseArray).orElse(new JSONArray()); + } +} diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/TableService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/TableService.java index 9d7fd0bd0..0431f48a7 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/TableService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/TableService.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; + import org.apache.geaflow.console.common.dal.dao.DataDao; import org.apache.geaflow.console.common.dal.dao.TableDao; import org.apache.geaflow.console.common.dal.entity.IdEntity; @@ -42,76 +43,74 @@ @Service public class TableService extends DataService { - @Autowired - private TableDao tableDao; - - @Autowired - private FieldService fieldService; - - @Autowired - private TableConverter tableConverter; + @Autowired private TableDao tableDao; + @Autowired private FieldService fieldService; - @Autowired - private PluginConfigService pluginConfigService; - - private GeaflowResourceType resourceType = GeaflowResourceType.TABLE; - - @Override - protected DataDao getDao() { - return tableDao; - } + @Autowired private TableConverter tableConverter; - @Override - protected DataConverter getConverter() { - return tableConverter; - } + @Autowired private PluginConfigService pluginConfigService; - @Override - public List create(List models) { - List tableIds = super.create(models); + private GeaflowResourceType resourceType = GeaflowResourceType.TABLE; - for (GeaflowTable model : models) { - fieldService.createByResource(new ArrayList<>(model.getFields().values()), model.getId(), resourceType); - } + @Override + protected DataDao getDao() { + return tableDao; + } - return tableIds; - } + @Override + protected DataConverter getConverter() { + return tableConverter; + } - @Override - protected List parse(List tableEntities) { - List tableIds = ListUtil.convert(tableEntities, IdEntity::getId); - Map> fieldsMap = fieldService.getByResources(tableIds, resourceType); + @Override + public List create(List models) { + List tableIds = super.create(models); - return tableEntities.stream().map(e -> { - List fields = fieldsMap.get(e.getId()); - GeaflowPluginConfig pluginConfig = pluginConfigService.get(e.getPluginConfigId()); - return tableConverter.convert(e, fields, pluginConfig); - }).collect(Collectors.toList()); + for (GeaflowTable model : models) { + fieldService.createByResource( + new ArrayList<>(model.getFields().values()), model.getId(), resourceType); } - @Override - public boolean update(List tables) { - List tableIds = ListUtil.convert(tables, GeaflowId::getId); - - fieldService.removeByResources(tableIds, resourceType); - for (GeaflowTable newTable : tables) { - List newFields = new ArrayList<>(newTable.getFields().values()); - fieldService.createByResource(newFields, newTable.getId(), resourceType); - - GeaflowPluginConfig pluginConfig = newTable.getPluginConfig(); - pluginConfigService.update(pluginConfig); - } - return super.update(tables); + return tableIds; + } + + @Override + protected List parse(List tableEntities) { + List tableIds = ListUtil.convert(tableEntities, IdEntity::getId); + Map> fieldsMap = fieldService.getByResources(tableIds, resourceType); + + return tableEntities.stream() + .map( + e -> { + List fields = fieldsMap.get(e.getId()); + GeaflowPluginConfig pluginConfig = pluginConfigService.get(e.getPluginConfigId()); + return tableConverter.convert(e, fields, pluginConfig); + }) + .collect(Collectors.toList()); + } + + @Override + public boolean update(List tables) { + List tableIds = ListUtil.convert(tables, GeaflowId::getId); + + fieldService.removeByResources(tableIds, resourceType); + for (GeaflowTable newTable : tables) { + List newFields = new ArrayList<>(newTable.getFields().values()); + fieldService.createByResource(newFields, newTable.getId(), resourceType); + + GeaflowPluginConfig pluginConfig = newTable.getPluginConfig(); + pluginConfigService.update(pluginConfig); } + return super.update(tables); + } - @Override - public boolean drop(List ids) { - List entities = tableDao.get(ids); - - fieldService.removeByResources(ids, resourceType); - pluginConfigService.drop(ListUtil.convert(entities, TableEntity::getPluginConfigId)); - return super.drop(ids); - } + @Override + public boolean drop(List ids) { + List entities = tableDao.get(ids); + fieldService.removeByResources(ids, resourceType); + pluginConfigService.drop(ListUtil.convert(entities, TableEntity::getPluginConfigId)); + return super.drop(ids); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/TaskService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/TaskService.java index e27e2c67b..aa7a1d52b 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/TaskService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/TaskService.java @@ -19,12 +19,12 @@ package org.apache.geaflow.console.core.service; -import com.google.common.base.Preconditions; import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; + import org.apache.geaflow.console.common.dal.dao.IdDao; import org.apache.geaflow.console.common.dal.dao.TaskDao; import org.apache.geaflow.console.common.dal.entity.IdEntity; @@ -47,160 +47,162 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; +import com.google.common.base.Preconditions; + @Service public class TaskService extends IdService { - @Autowired - private TaskDao taskDao; - - @Autowired - private TaskConverter taskConverter; - - @Autowired - private ReleaseService releaseService; - - @Autowired - private AuditService auditService; - - - @Autowired - private PluginService pluginService; - - @Autowired - private PluginConfigService pluginConfigService; - - @Override - protected IdDao getDao() { - return taskDao; + @Autowired private TaskDao taskDao; + + @Autowired private TaskConverter taskConverter; + + @Autowired private ReleaseService releaseService; + + @Autowired private AuditService auditService; + + @Autowired private PluginService pluginService; + + @Autowired private PluginConfigService pluginConfigService; + + @Override + protected IdDao getDao() { + return taskDao; + } + + @Override + protected IdConverter getConverter() { + return taskConverter; + } + + @Override + protected List parse(List taskEntities) { + return taskEntities.stream() + .map( + e -> { + GeaflowRelease release = releaseService.get(e.getReleaseId()); + GeaflowPluginConfig runtimeConfig = + pluginConfigService.get(e.getRuntimeMetaConfigId()); + GeaflowPluginConfig haMetaConfig = pluginConfigService.get(e.getHaMetaConfigId()); + GeaflowPluginConfig metricConfig = pluginConfigService.get(e.getMetricConfigId()); + GeaflowPluginConfig dataConfig = pluginConfigService.get(e.getDataConfigId()); + + return taskConverter.convert( + e, release, runtimeConfig, haMetaConfig, metricConfig, dataConfig); + }) + .collect(Collectors.toList()); + } + + @Override + public GeaflowTask get(String id) { + GeaflowTask task = super.get(id); + Preconditions.checkNotNull(task, "task %s not exist", id); + return task; + } + + public GeaflowTask getByJobId(String jobId) { + List entities = taskDao.getByJobId(jobId); + return entities.isEmpty() ? null : parse(entities).get(0); + } + + public List getIdsByJob(List jobIds) { + List entities = taskDao.getIdsByJobs(jobIds); + return ListUtil.convert(entities, IdEntity::getId); + } + + public List createTask(GeaflowRelease release) { + if (release.getReleaseVersion() != 1) { + throw new GeaflowException("Job status is created or release version is not 1"); } - - @Override - protected IdConverter getConverter() { - return taskConverter; + try { + GeaflowTask task = new GeaflowTask(); + task.setStatus(GeaflowTaskStatus.CREATED); + GeaflowJob job = release.getJob(); + task.setType(job.getType().getTaskType()); + task.setRelease(release); + setPluginConfigs(task); + create(task); + auditService.create(new GeaflowAudit(task.getId(), GeaflowOperationType.CREATE)); + return Arrays.asList(task); + } catch (Exception e) { + throw new GeaflowException("Build task fail ", e); } - - @Override - protected List parse(List taskEntities) { - return taskEntities.stream().map(e -> { - GeaflowRelease release = releaseService.get(e.getReleaseId()); - GeaflowPluginConfig runtimeConfig = pluginConfigService.get(e.getRuntimeMetaConfigId()); - GeaflowPluginConfig haMetaConfig = pluginConfigService.get(e.getHaMetaConfigId()); - GeaflowPluginConfig metricConfig = pluginConfigService.get(e.getMetricConfigId()); - GeaflowPluginConfig dataConfig = pluginConfigService.get(e.getDataConfigId()); - - return taskConverter.convert(e, release, runtimeConfig, haMetaConfig, metricConfig, dataConfig); - }).collect(Collectors.toList()); + } + + private void setPluginConfigs(GeaflowTask task) { + GeaflowPluginCategory ha = GeaflowPluginCategory.HA_META; + GeaflowPluginCategory metric = GeaflowPluginCategory.METRIC; + GeaflowPluginCategory runtime = GeaflowPluginCategory.RUNTIME_META; + GeaflowPluginCategory data = GeaflowPluginCategory.DATA; + + String haType = pluginService.getDefaultPlugin(ha).getType(); + String metricType = pluginService.getDefaultPlugin(metric).getType(); + String runtimeType = pluginService.getDefaultPlugin(runtime).getType(); + String dataType = pluginService.getDefaultPlugin(data).getType(); + + GeaflowPluginConfig haConfig = pluginConfigService.getDefaultPluginConfig(ha, haType); + GeaflowPluginConfig metricConfig = + pluginConfigService.getDefaultPluginConfig(metric, metricType); + GeaflowPluginConfig runtimeConfig = + pluginConfigService.getDefaultPluginConfig(runtime, runtimeType); + GeaflowPluginConfig dataConfig = pluginConfigService.getDefaultPluginConfig(data, dataType); + + task.setHaMetaPluginConfig(haConfig); + task.setMetricPluginConfig(metricConfig); + task.setRuntimeMetaPluginConfig(runtimeConfig); + task.setDataPluginConfig(dataConfig); + } + + public String bindRelease(GeaflowRelease release) { + TaskEntity task = taskDao.getByJobId(release.getJob().getId()).get(0); + if (task.getStatus() == GeaflowTaskStatus.CREATED && release.getReleaseVersion() == 1) { + // don't need to bind at the first time + return null; } - - @Override - public GeaflowTask get(String id) { - GeaflowTask task = super.get(id); - Preconditions.checkNotNull(task, "task %s not exist", id); - return task; + task.setReleaseId(release.getId()); + boolean updateStatus = updateStatus(task.getId(), task.getStatus(), GeaflowTaskStatus.CREATED); + if (!updateStatus) { + throw new GeaflowException("task status has been changed"); } - - public GeaflowTask getByJobId(String jobId) { - List entities = taskDao.getByJobId(jobId); - return entities.isEmpty() ? null : parse(entities).get(0); - } - - public List getIdsByJob(List jobIds) { - List entities = taskDao.getIdsByJobs(jobIds); - return ListUtil.convert(entities, IdEntity::getId); + taskDao.update(task); + + return task.getId(); + } + + public Map getTaskHandles(GeaflowTaskStatus status) { + List tasks = taskDao.getTasksByStatus(status); + return ListUtil.toMap(tasks, IdEntity::getId, e -> GeaflowTaskHandle.parse(e.getHandle())); + } + + public List getTasksByStatus(GeaflowTaskStatus status) { + List tasks = taskDao.getTasksByStatus(status); + return ListUtil.convert(tasks, this::convertToId); + } + + public GeaflowId getByTaskToken(String token) { + TaskEntity entity = taskDao.getByToken(token); + return entity == null ? null : convertToId(entity); + } + + private GeaflowId convertToId(TaskEntity entity) { + GeaflowId model = new GeaflowTask(); + model.setTenantId(entity.getTenantId()); + model.setId(entity.getId()); + model.setCreatorId(entity.getCreatorId()); + model.setModifierId(entity.getModifierId()); + model.setGmtCreate(entity.getGmtCreate()); + model.setGmtModified(entity.getGmtModified()); + return model; + } + + public boolean updateStatus( + String taskId, GeaflowTaskStatus oldStatus, GeaflowTaskStatus newStatus) { + if (oldStatus == newStatus) { + return true; } + return taskDao.updateStatus(taskId, oldStatus, newStatus); + } - public List createTask(GeaflowRelease release) { - if (release.getReleaseVersion() != 1) { - throw new GeaflowException("Job status is created or release version is not 1"); - } - try { - GeaflowTask task = new GeaflowTask(); - task.setStatus(GeaflowTaskStatus.CREATED); - GeaflowJob job = release.getJob(); - task.setType(job.getType().getTaskType()); - task.setRelease(release); - setPluginConfigs(task); - create(task); - auditService.create(new GeaflowAudit(task.getId(), GeaflowOperationType.CREATE)); - return Arrays.asList(task); - } catch (Exception e) { - throw new GeaflowException("Build task fail ", e); - } - } - - private void setPluginConfigs(GeaflowTask task) { - GeaflowPluginCategory ha = GeaflowPluginCategory.HA_META; - GeaflowPluginCategory metric = GeaflowPluginCategory.METRIC; - GeaflowPluginCategory runtime = GeaflowPluginCategory.RUNTIME_META; - GeaflowPluginCategory data = GeaflowPluginCategory.DATA; - - String haType = pluginService.getDefaultPlugin(ha).getType(); - String metricType = pluginService.getDefaultPlugin(metric).getType(); - String runtimeType = pluginService.getDefaultPlugin(runtime).getType(); - String dataType = pluginService.getDefaultPlugin(data).getType(); - - GeaflowPluginConfig haConfig = pluginConfigService.getDefaultPluginConfig(ha, haType); - GeaflowPluginConfig metricConfig = pluginConfigService.getDefaultPluginConfig(metric, metricType); - GeaflowPluginConfig runtimeConfig = pluginConfigService.getDefaultPluginConfig(runtime, runtimeType); - GeaflowPluginConfig dataConfig = pluginConfigService.getDefaultPluginConfig(data, dataType); - - task.setHaMetaPluginConfig(haConfig); - task.setMetricPluginConfig(metricConfig); - task.setRuntimeMetaPluginConfig(runtimeConfig); - task.setDataPluginConfig(dataConfig); - } - - public String bindRelease(GeaflowRelease release) { - TaskEntity task = taskDao.getByJobId(release.getJob().getId()).get(0); - if (task.getStatus() == GeaflowTaskStatus.CREATED && release.getReleaseVersion() == 1) { - // don't need to bind at the first time - return null; - } - task.setReleaseId(release.getId()); - boolean updateStatus = updateStatus(task.getId(), task.getStatus(), GeaflowTaskStatus.CREATED); - if (!updateStatus) { - throw new GeaflowException("task status has been changed"); - } - taskDao.update(task); - - return task.getId(); - } - - public Map getTaskHandles(GeaflowTaskStatus status) { - List tasks = taskDao.getTasksByStatus(status); - return ListUtil.toMap(tasks, IdEntity::getId, e -> GeaflowTaskHandle.parse(e.getHandle())); - } - - public List getTasksByStatus(GeaflowTaskStatus status) { - List tasks = taskDao.getTasksByStatus(status); - return ListUtil.convert(tasks, this::convertToId); - } - - public GeaflowId getByTaskToken(String token) { - TaskEntity entity = taskDao.getByToken(token); - return entity == null ? null : convertToId(entity); - } - - private GeaflowId convertToId(TaskEntity entity) { - GeaflowId model = new GeaflowTask(); - model.setTenantId(entity.getTenantId()); - model.setId(entity.getId()); - model.setCreatorId(entity.getCreatorId()); - model.setModifierId(entity.getModifierId()); - model.setGmtCreate(entity.getGmtCreate()); - model.setGmtModified(entity.getGmtModified()); - return model; - } - - public boolean updateStatus(String taskId, GeaflowTaskStatus oldStatus, GeaflowTaskStatus newStatus) { - if (oldStatus == newStatus) { - return true; - } - return taskDao.updateStatus(taskId, oldStatus, newStatus); - } - - public GeaflowTaskStatus getStatus(String id) { - return Optional.ofNullable(taskDao.get(id)).map(TaskEntity::getStatus).orElse(null); - } + public GeaflowTaskStatus getStatus(String id) { + return Optional.ofNullable(taskDao.get(id)).map(TaskEntity::getStatus).orElse(null); + } } - diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/TenantService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/TenantService.java index 0bb39f79c..299d64ebd 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/TenantService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/TenantService.java @@ -22,6 +22,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.dal.dao.NameDao; import org.apache.geaflow.console.common.dal.dao.TenantDao; @@ -44,83 +45,81 @@ @Service public class TenantService extends NameService { - @Autowired - private TenantDao tenantDao; - - @Autowired - private TenantUserMappingDao tenantUserMappingDao; - - @Autowired - private TenantConverter tenantConverter; - - @Override - protected NameDao getDao() { - return tenantDao; - } - - @Override - protected NameConverter getConverter() { - return tenantConverter; - } + @Autowired private TenantDao tenantDao; - @Override - protected List parse(List entities) { - return ListUtil.convert(entities, e -> tenantConverter.convert(e)); - } + @Autowired private TenantUserMappingDao tenantUserMappingDao; - @Override - public PageList search(TenantSearch search) { - boolean systemSession = ContextHolder.get().isSystemSession(); - if (systemSession) { - return super.search(search); - } + @Autowired private TenantConverter tenantConverter; - String userId = ContextHolder.get().getUserId(); - return tenantDao.search(userId, search).transform(this::parse); - } + @Override + protected NameDao getDao() { + return tenantDao; + } - public String getDefaultTenantName(String userName) { - return "tenant_" + userName; - } + @Override + protected NameConverter getConverter() { + return tenantConverter; + } - public String createDefaultTenant(GeaflowUser user) { - String userName = user.getName(); - String userComment = user.getComment(); - String tenantName = getDefaultTenantName(userName); - String userDisplayName = StringUtils.isBlank(userComment) ? userName : userComment; - String tenantComment = Fmt.as(I18nUtil.getMessage("i18n.key.default.tenant.comment.format"), userDisplayName); - - TenantEntity entity = new TenantEntity(); - entity.setName(tenantName); - entity.setComment(tenantComment); - return tenantDao.create(entity); - } + @Override + protected List parse(List entities) { + return ListUtil.convert(entities, e -> tenantConverter.convert(e)); + } - public GeaflowTenant getActiveTenant(String userId) { - TenantEntity entity = tenantUserMappingDao.getUserActiveTenant(userId); - return parse(entity); + @Override + public PageList search(TenantSearch search) { + boolean systemSession = ContextHolder.get().isSystemSession(); + if (systemSession) { + return super.search(search); } - public List getUserTenants(String userId) { - List entities = tenantUserMappingDao.getUserTenants(userId); - return parse(entities); + String userId = ContextHolder.get().getUserId(); + return tenantDao.search(userId, search).transform(this::parse); + } + + public String getDefaultTenantName(String userName) { + return "tenant_" + userName; + } + + public String createDefaultTenant(GeaflowUser user) { + String userName = user.getName(); + String userComment = user.getComment(); + String tenantName = getDefaultTenantName(userName); + String userDisplayName = StringUtils.isBlank(userComment) ? userName : userComment; + String tenantComment = + Fmt.as(I18nUtil.getMessage("i18n.key.default.tenant.comment.format"), userDisplayName); + + TenantEntity entity = new TenantEntity(); + entity.setName(tenantName); + entity.setComment(tenantComment); + return tenantDao.create(entity); + } + + public GeaflowTenant getActiveTenant(String userId) { + TenantEntity entity = tenantUserMappingDao.getUserActiveTenant(userId); + return parse(entity); + } + + public List getUserTenants(String userId) { + List entities = tenantUserMappingDao.getUserTenants(userId); + return parse(entities); + } + + @Transactional + public void activateTenant(String tenantId, String userId) { + TenantEntity activeTenant = tenantUserMappingDao.getUserActiveTenant(userId); + if (activeTenant != null && activeTenant.getTenantId().equals(tenantId)) { + return; } - @Transactional - public void activateTenant(String tenantId, String userId) { - TenantEntity activeTenant = tenantUserMappingDao.getUserActiveTenant(userId); - if (activeTenant != null && activeTenant.getTenantId().equals(tenantId)) { - return; - } + // reset other tenants + tenantUserMappingDao.deactivateUserTenants(userId); - // reset other tenants - tenantUserMappingDao.deactivateUserTenants(userId); + // active current tenant + tenantUserMappingDao.activateUserTenant(tenantId, userId); + } - // active current tenant - tenantUserMappingDao.activateUserTenant(tenantId, userId); - } - - public Map getTenantNames(Collection tenantIds) { - return tenantDao.getTenantNames(tenantIds); - } + public Map getTenantNames(Collection tenantIds) { + return tenantDao.getTenantNames(tenantIds); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/UserService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/UserService.java index 7802a9a6b..5ff20cea7 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/UserService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/UserService.java @@ -19,10 +19,10 @@ package org.apache.geaflow.console.core.service; -import com.google.common.base.Preconditions; import java.util.Collection; import java.util.List; import java.util.Map; + import org.apache.geaflow.console.common.dal.dao.NameDao; import org.apache.geaflow.console.common.dal.dao.TenantUserMappingDao; import org.apache.geaflow.console.common.dal.dao.UserDao; @@ -40,87 +40,84 @@ import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; +import com.google.common.base.Preconditions; + @Service public class UserService extends NameService { - @Autowired - private UserDao userDao; + @Autowired private UserDao userDao; - @Autowired - private UserConverter userConverter; + @Autowired private UserConverter userConverter; - @Autowired - private AuthorizationService authorizationService; + @Autowired private AuthorizationService authorizationService; - @Autowired - private TenantUserMappingDao tenantUserMappingDao; + @Autowired private TenantUserMappingDao tenantUserMappingDao; - @Override - protected NameDao getDao() { - return userDao; - } - - @Override - protected NameConverter getConverter() { - return userConverter; - } + @Override + protected NameDao getDao() { + return userDao; + } - @Override - protected List parse(List entities) { - return ListUtil.convert(entities, e -> userConverter.convert(e)); - } + @Override + protected NameConverter getConverter() { + return userConverter; + } - @Override - public PageList search(UserSearch search) { - boolean systemSession = ContextHolder.get().isSystemSession(); - if (systemSession) { - return super.search(search); - } + @Override + protected List parse(List entities) { + return ListUtil.convert(entities, e -> userConverter.convert(e)); + } - String tenantId = ContextHolder.get().getTenantId(); - return userDao.search(tenantId, search).transform(this::parse); + @Override + public PageList search(UserSearch search) { + boolean systemSession = ContextHolder.get().isSystemSession(); + if (systemSession) { + return super.search(search); } - @Transactional - public String createUser(GeaflowUser user) { - if (userDao.existName(user.getName())) { - throw new IllegalArgumentException("User exists"); - } + String tenantId = ContextHolder.get().getTenantId(); + return userDao.search(tenantId, search).transform(this::parse); + } - // create user - String userId = create(user); + @Transactional + public String createUser(GeaflowUser user) { + if (userDao.existName(user.getName())) { + throw new IllegalArgumentException("User exists"); + } - // init user id context when register success - ContextHolder.get().setUserId(userId); + // create user + String userId = create(user); - // init the first user as system admin - if (userDao.count() == 1) { - authorizationService.addRole(null, userId, GeaflowRole.SYSTEM_ADMIN); - } + // init user id context when register success + ContextHolder.get().setUserId(userId); - return userId; + // init the first user as system admin + if (userDao.count() == 1) { + authorizationService.addRole(null, userId, GeaflowRole.SYSTEM_ADMIN); } - public Map getUserNames(Collection userIds) { - return userDao.getUserNames(userIds); - } + return userId; + } - public boolean existTenantUser(String tenantId, String userId) { - return tenantUserMappingDao.existUser(tenantId, userId); - } + public Map getUserNames(Collection userIds) { + return userDao.getUserNames(userIds); + } - public boolean addTenantUser(String tenantId, String userId) { - if (tenantUserMappingDao.existUser(tenantId, userId)) { - throw new GeaflowIllegalException("Tenant user exists"); - } + public boolean existTenantUser(String tenantId, String userId) { + return tenantUserMappingDao.existUser(tenantId, userId); + } - Preconditions.checkNotNull(tenantId, "Invalid tenantId"); - Preconditions.checkArgument(userDao.exist(userId), "User not exists"); - return tenantUserMappingDao.addUser(tenantId, userId); + public boolean addTenantUser(String tenantId, String userId) { + if (tenantUserMappingDao.existUser(tenantId, userId)) { + throw new GeaflowIllegalException("Tenant user exists"); } - public boolean deleteTenantUser(String tenantId, String userId) { - return tenantUserMappingDao.deleteUser(tenantId, userId); - } -} + Preconditions.checkNotNull(tenantId, "Invalid tenantId"); + Preconditions.checkArgument(userDao.exist(userId), "User not exists"); + return tenantUserMappingDao.addUser(tenantId, userId); + } + public boolean deleteTenantUser(String tenantId, String userId) { + return tenantUserMappingDao.deleteUser(tenantId, userId); + } +} diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/VersionService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/VersionService.java index 9eb00a86c..d6f226d88 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/VersionService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/VersionService.java @@ -19,9 +19,9 @@ package org.apache.geaflow.console.core.service; -import com.google.common.base.Preconditions; import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.console.common.dal.dao.NameDao; import org.apache.geaflow.console.common.dal.dao.VersionDao; import org.apache.geaflow.console.common.dal.entity.VersionEntity; @@ -34,50 +34,52 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; +import com.google.common.base.Preconditions; + @Service -public class VersionService extends NameService implements FileRefService { +public class VersionService extends NameService + implements FileRefService { - @Autowired - private VersionDao versionDao; + @Autowired private VersionDao versionDao; - @Autowired - private VersionConverter versionConverter; + @Autowired private VersionConverter versionConverter; - @Autowired - private RemoteFileService remoteFileService; + @Autowired private RemoteFileService remoteFileService; - @Override - protected NameDao getDao() { - return versionDao; - } + @Override + protected NameDao getDao() { + return versionDao; + } - @Override - protected NameConverter getConverter() { - return versionConverter; - } + @Override + protected NameConverter getConverter() { + return versionConverter; + } - @Override - protected List parse(List versionEntities) { - return versionEntities.stream().map(e -> { - GeaflowRemoteFile engineJar = remoteFileService.get(e.getEngineJarId()); - GeaflowRemoteFile langJar = remoteFileService.get(e.getLangJarId()); - return versionConverter.convert(e, engineJar, langJar); - }).collect(Collectors.toList()); - } + @Override + protected List parse(List versionEntities) { + return versionEntities.stream() + .map( + e -> { + GeaflowRemoteFile engineJar = remoteFileService.get(e.getEngineJarId()); + GeaflowRemoteFile langJar = remoteFileService.get(e.getLangJarId()); + return versionConverter.convert(e, engineJar, langJar); + }) + .collect(Collectors.toList()); + } - public GeaflowVersion getDefaultVersion() { - VersionEntity version = versionDao.getDefaultVersion(); - Preconditions.checkNotNull(version, "No default published version found"); - return parse(version); - } + public GeaflowVersion getDefaultVersion() { + VersionEntity version = versionDao.getDefaultVersion(); + Preconditions.checkNotNull(version, "No default published version found"); + return parse(version); + } - public GeaflowVersion getPublishVersionByName(String name) { - return parse(versionDao.getPublishVersionByName(name)); - } + public GeaflowVersion getPublishVersionByName(String name) { + return parse(versionDao.getPublishVersionByName(name)); + } - @Override - public long getFileRefCount(String fileId, String versionId) { - return versionDao.getFileRefCount(fileId, versionId); - } + @Override + public long getFileRefCount(String fileId, String versionId) { + return versionDao.getFileRefCount(fileId, versionId); + } } - diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/VertexService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/VertexService.java index 10dbe3afb..13a9387dc 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/VertexService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/VertexService.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; + import org.apache.geaflow.console.common.dal.dao.DataDao; import org.apache.geaflow.console.common.dal.dao.VertexDao; import org.apache.geaflow.console.common.dal.entity.IdEntity; @@ -41,105 +42,106 @@ @Service public class VertexService extends DataService { - @Autowired - private VertexDao vertexDao; - - @Autowired - private GraphService graphService; + @Autowired private VertexDao vertexDao; - @Autowired - private FieldService fieldService; + @Autowired private GraphService graphService; - @Autowired - private VertexConverter vertexConverter; - private GeaflowResourceType resourceType = GeaflowResourceType.VERTEX; - - @Override - protected DataDao getDao() { - return vertexDao; - } + @Autowired private FieldService fieldService; - @Override - protected DataConverter getConverter() { - return vertexConverter; - } + @Autowired private VertexConverter vertexConverter; + private GeaflowResourceType resourceType = GeaflowResourceType.VERTEX; - @Override - public List create(List models) { - List vertexIds = super.create(models); + @Override + protected DataDao getDao() { + return vertexDao; + } - for (GeaflowVertex model : models) { - fieldService.createByResource(new ArrayList<>(model.getFields().values()), model.getId(), resourceType); - } + @Override + protected DataConverter getConverter() { + return vertexConverter; + } - return vertexIds; - } + @Override + public List create(List models) { + List vertexIds = super.create(models); - @Override - protected List parse(List vertexEntities) { - List vertexIds = ListUtil.convert(vertexEntities, IdEntity::getId); - //select fields of each vertex - Map> fieldsMap = fieldService.getByResources(vertexIds, resourceType); - return vertexEntities.stream().map(e -> { - List fields = fieldsMap.get(e.getId()); - return vertexConverter.convert(e, fields); - }).collect(Collectors.toList()); + for (GeaflowVertex model : models) { + fieldService.createByResource( + new ArrayList<>(model.getFields().values()), model.getId(), resourceType); } - @Override - public boolean update(List vertices) { - // update field - List ids = ListUtil.convert(vertices, GeaflowId::getId); - - fieldService.removeByResources(ids, resourceType); - for (GeaflowVertex vertex : vertices) { - List newFields = new ArrayList<>(vertex.getFields().values()); - fieldService.createByResource(newFields, vertex.getId(), resourceType); - } - return super.update(vertices); + return vertexIds; + } + + @Override + protected List parse(List vertexEntities) { + List vertexIds = ListUtil.convert(vertexEntities, IdEntity::getId); + // select fields of each vertex + Map> fieldsMap = + fieldService.getByResources(vertexIds, resourceType); + return vertexEntities.stream() + .map( + e -> { + List fields = fieldsMap.get(e.getId()); + return vertexConverter.convert(e, fields); + }) + .collect(Collectors.toList()); + } + + @Override + public boolean update(List vertices) { + // update field + List ids = ListUtil.convert(vertices, GeaflowId::getId); + + fieldService.removeByResources(ids, resourceType); + for (GeaflowVertex vertex : vertices) { + List newFields = new ArrayList<>(vertex.getFields().values()); + fieldService.createByResource(newFields, vertex.getId(), resourceType); } - - @Override - public boolean drop(List ids) { - // can't drop if is used in graph. - for (String id : ids) { - graphService.checkBindingRelations(id, GeaflowResourceType.VERTEX); - } - - fieldService.removeByResources(ids, resourceType); - return super.drop(ids); + return super.update(vertices); + } + + @Override + public boolean drop(List ids) { + // can't drop if is used in graph. + for (String id : ids) { + graphService.checkBindingRelations(id, GeaflowResourceType.VERTEX); } - - public List getVerticesByGraphId(String graphId) { - List entities = vertexDao.getByGraphId(graphId); - return parse(entities); - } - - public List getVerticesByGraphId(String graphId, Map vertexMap) { - List entities = vertexDao.getByGraphId(graphId); - // filter vertices that not exist - List rests = new ArrayList<>(); - List exists = new ArrayList<>(); - - entities.forEach(e -> { - if (vertexMap.containsKey(e.getId())) { - exists.add(e); - } else { - rests.add(e); - } + fieldService.removeByResources(ids, resourceType); + return super.drop(ids); + } + + public List getVerticesByGraphId(String graphId) { + List entities = vertexDao.getByGraphId(graphId); + return parse(entities); + } + + public List getVerticesByGraphId( + String graphId, Map vertexMap) { + List entities = vertexDao.getByGraphId(graphId); + // filter vertices that not exist + List rests = new ArrayList<>(); + List exists = new ArrayList<>(); + + entities.forEach( + e -> { + if (vertexMap.containsKey(e.getId())) { + exists.add(e); + } else { + rests.add(e); + } }); - List vertices = parse(rests); - List existVertices = ListUtil.convert(exists, e -> vertexMap.get(e.getId())); - vertices.addAll(existVertices); - return vertices; - } - - public List getVerticesByEdgeId(String edgeId) { - // get vertexEntity - List vertices = vertexDao.getByEdge(edgeId); - return parse(vertices); - } - + List vertices = parse(rests); + List existVertices = ListUtil.convert(exists, e -> vertexMap.get(e.getId())); + vertices.addAll(existVertices); + return vertices; + } + + public List getVerticesByEdgeId(String edgeId) { + // get vertexEntity + List vertices = vertexDao.getByEdge(edgeId); + return parse(vertices); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/ViewService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/ViewService.java index 024e1c43a..f61b5ffde 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/ViewService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/ViewService.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.core.service; import java.util.List; + import org.apache.geaflow.console.common.dal.dao.NameDao; import org.apache.geaflow.console.common.dal.dao.ViewDao; import org.apache.geaflow.console.common.dal.entity.ViewEntity; @@ -33,26 +34,22 @@ @Service public class ViewService extends NameService { - @Autowired - private ViewDao viewDao; - - @Autowired - private ViewConverter viewConverter; + @Autowired private ViewDao viewDao; - @Override - protected NameDao getDao() { - return viewDao; - } + @Autowired private ViewConverter viewConverter; - @Override - protected NameConverter getConverter() { - return viewConverter; - } + @Override + protected NameDao getDao() { + return viewDao; + } - @Override - protected List parse(List viewEntities) { - return null; - } + @Override + protected NameConverter getConverter() { + return viewConverter; + } + @Override + protected List parse(List viewEntities) { + return null; + } } - diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/aspect/ConverterAspect.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/aspect/ConverterAspect.java index 5d7d1d394..92d68fedd 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/aspect/ConverterAspect.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/aspect/ConverterAspect.java @@ -29,17 +29,17 @@ @Component public class ConverterAspect { - @Around("execution(* org.apache.geaflow.console.core.service.converter.*Converter.convert(..))") - public Object handle(ProceedingJoinPoint joinPoint) throws Throwable { - Object[] args = joinPoint.getArgs(); - if (args[0] == null) { - return null; - } - - if (args[0] instanceof GeaflowId) { - ((GeaflowId) args[0]).validate(); - } + @Around("execution(* org.apache.geaflow.console.core.service.converter.*Converter.convert(..))") + public Object handle(ProceedingJoinPoint joinPoint) throws Throwable { + Object[] args = joinPoint.getArgs(); + if (args[0] == null) { + return null; + } - return joinPoint.proceed(args); + if (args[0] instanceof GeaflowId) { + ((GeaflowId) args[0]).validate(); } + + return joinPoint.proceed(args); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/config/DatasourceConfig.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/config/DatasourceConfig.java index 7db019aa2..9850467e9 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/config/DatasourceConfig.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/config/DatasourceConfig.java @@ -19,33 +19,33 @@ package org.apache.geaflow.console.core.service.config; -import lombok.Getter; -import lombok.Setter; import org.apache.geaflow.console.core.model.plugin.config.JdbcPluginConfigClass; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.stereotype.Component; +import lombok.Getter; +import lombok.Setter; + @Component @ConfigurationProperties("spring.datasource") @Getter @Setter public class DatasourceConfig { - private String driverClassName; - - private String url; + private String driverClassName; - private String username; + private String url; - private String password; + private String username; - public JdbcPluginConfigClass buildPluginConfigClass() { - JdbcPluginConfigClass jdbcConfig = new JdbcPluginConfigClass(); - jdbcConfig.setDriverClass(driverClassName); - jdbcConfig.setUrl(url); - jdbcConfig.setUsername(username); - jdbcConfig.setPassword(password); - return jdbcConfig; - } + private String password; + public JdbcPluginConfigClass buildPluginConfigClass() { + JdbcPluginConfigClass jdbcConfig = new JdbcPluginConfigClass(); + jdbcConfig.setDriverClass(driverClassName); + jdbcConfig.setUrl(url); + jdbcConfig.setUsername(username); + jdbcConfig.setPassword(password); + return jdbcConfig; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/config/DeployConfig.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/config/DeployConfig.java index a2c3d0ca8..bb2c41ca4 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/config/DeployConfig.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/config/DeployConfig.java @@ -19,35 +19,36 @@ package org.apache.geaflow.console.core.service.config; -import lombok.Getter; import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.util.type.GeaflowDeployMode; import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; +import lombok.Getter; + @Getter @Component public class DeployConfig implements InitializingBean { - @Value("${geaflow.host}") - protected String host; + @Value("${geaflow.host}") + protected String host; - @Value("${geaflow.gateway.port}") - protected Integer gatewayPort; + @Value("${geaflow.gateway.port}") + protected Integer gatewayPort; - @Value("${geaflow.gateway.url}") - protected String gatewayUrl; + @Value("${geaflow.gateway.url}") + protected String gatewayUrl; - @Value("${geaflow.deploy.mode}") - private GeaflowDeployMode mode = GeaflowDeployMode.LOCAL; + @Value("${geaflow.deploy.mode}") + private GeaflowDeployMode mode = GeaflowDeployMode.LOCAL; - public boolean isLocalMode() { - return GeaflowDeployMode.LOCAL.equals(mode); - } + public boolean isLocalMode() { + return GeaflowDeployMode.LOCAL.equals(mode); + } - @Override - public void afterPropertiesSet() throws Exception { - this.gatewayUrl = StringUtils.removeEnd(gatewayUrl, "/"); - } + @Override + public void afterPropertiesSet() throws Exception { + this.gatewayUrl = StringUtils.removeEnd(gatewayUrl, "/"); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/AuditConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/AuditConverter.java index 68fd384e0..8223fcd7d 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/AuditConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/AuditConverter.java @@ -26,24 +26,24 @@ @Component public class AuditConverter extends IdConverter { - @Override - protected AuditEntity modelToEntity(GeaflowAudit model) { - AuditEntity entity = super.modelToEntity(model); - entity.setResourceType(model.getResourceType()); - entity.setResourceId(model.getResourceId()); - entity.setOperationType(model.getOperationType()); - entity.setDetail(model.getDetail()); + @Override + protected AuditEntity modelToEntity(GeaflowAudit model) { + AuditEntity entity = super.modelToEntity(model); + entity.setResourceType(model.getResourceType()); + entity.setResourceId(model.getResourceId()); + entity.setOperationType(model.getOperationType()); + entity.setDetail(model.getDetail()); - return entity; - } + return entity; + } - public GeaflowAudit convert(AuditEntity entity) { - GeaflowAudit model = super.entityToModel(entity); - model.setResourceType(entity.getResourceType()); - model.setResourceId(entity.getResourceId()); - model.setOperationType(entity.getOperationType()); - model.setDetail(entity.getDetail()); + public GeaflowAudit convert(AuditEntity entity) { + GeaflowAudit model = super.entityToModel(entity); + model.setResourceType(entity.getResourceType()); + model.setResourceId(entity.getResourceId()); + model.setOperationType(entity.getOperationType()); + model.setDetail(entity.getDetail()); - return model; - } + return model; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/AuthorizationConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/AuthorizationConverter.java index 9a80fb851..e910f413e 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/AuthorizationConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/AuthorizationConverter.java @@ -26,28 +26,27 @@ @Component public class AuthorizationConverter extends IdConverter { - @Override - protected AuthorizationEntity modelToEntity(GeaflowAuthorization model) { - AuthorizationEntity entity = super.modelToEntity(model); - entity.setResourceType(model.getResourceType()); - entity.setAuthorityType(model.getAuthorityType()); - entity.setUserId(model.getUserId()); - entity.setResourceId(model.getResourceId()); - return entity; - } + @Override + protected AuthorizationEntity modelToEntity(GeaflowAuthorization model) { + AuthorizationEntity entity = super.modelToEntity(model); + entity.setResourceType(model.getResourceType()); + entity.setAuthorityType(model.getAuthorityType()); + entity.setUserId(model.getUserId()); + entity.setResourceId(model.getResourceId()); + return entity; + } - @Override - protected GeaflowAuthorization entityToModel(AuthorizationEntity entity) { - GeaflowAuthorization model = super.entityToModel(entity); - model.setResourceType(entity.getResourceType()); - model.setAuthorityType(entity.getAuthorityType()); - model.setUserId(entity.getUserId()); - model.setResourceId(entity.getResourceId()); - return model; - } - - public GeaflowAuthorization convert(AuthorizationEntity entity) { - return entityToModel(entity); - } + @Override + protected GeaflowAuthorization entityToModel(AuthorizationEntity entity) { + GeaflowAuthorization model = super.entityToModel(entity); + model.setResourceType(entity.getResourceType()); + model.setAuthorityType(entity.getAuthorityType()); + model.setUserId(entity.getUserId()); + model.setResourceId(entity.getResourceId()); + return model; + } + public GeaflowAuthorization convert(AuthorizationEntity entity) { + return entityToModel(entity); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/ChatConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/ChatConverter.java index cf84ebe84..c0fd3c1b2 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/ChatConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/ChatConverter.java @@ -26,29 +26,29 @@ @Component public class ChatConverter extends IdConverter { - @Override - protected ChatEntity modelToEntity(GeaflowChat model) { - ChatEntity entity = super.modelToEntity(model); - entity.setAnswer(model.getAnswer()); - entity.setPrompt(model.getPrompt()); - entity.setModelId(model.getModelId()); - entity.setStatus(model.getStatus()); - entity.setJobId(model.getJobId()); - return entity; - } + @Override + protected ChatEntity modelToEntity(GeaflowChat model) { + ChatEntity entity = super.modelToEntity(model); + entity.setAnswer(model.getAnswer()); + entity.setPrompt(model.getPrompt()); + entity.setModelId(model.getModelId()); + entity.setStatus(model.getStatus()); + entity.setJobId(model.getJobId()); + return entity; + } - @Override - protected GeaflowChat entityToModel(ChatEntity entity) { - GeaflowChat model = super.entityToModel(entity); - model.setAnswer(entity.getAnswer()); - model.setPrompt(entity.getPrompt()); - model.setModelId(entity.getModelId()); - model.setStatus(entity.getStatus()); - model.setJobId(entity.getJobId()); - return model; - } + @Override + protected GeaflowChat entityToModel(ChatEntity entity) { + GeaflowChat model = super.entityToModel(entity); + model.setAnswer(entity.getAnswer()); + model.setPrompt(entity.getPrompt()); + model.setModelId(entity.getModelId()); + model.setStatus(entity.getStatus()); + model.setJobId(entity.getJobId()); + return model; + } - public GeaflowChat convert(ChatEntity entity) { - return entityToModel(entity); - } + public GeaflowChat convert(ChatEntity entity) { + return entityToModel(entity); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/ClusterConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/ClusterConverter.java index 0bb0c891f..69d5cef62 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/ClusterConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/ClusterConverter.java @@ -19,32 +19,33 @@ package org.apache.geaflow.console.core.service.converter; -import com.alibaba.fastjson.JSON; import org.apache.geaflow.console.common.dal.entity.ClusterEntity; import org.apache.geaflow.console.core.model.cluster.GeaflowCluster; import org.apache.geaflow.console.core.model.config.GeaflowConfig; import org.springframework.stereotype.Component; +import com.alibaba.fastjson.JSON; + @Component public class ClusterConverter extends NameConverter { - @Override - protected ClusterEntity modelToEntity(GeaflowCluster model) { - ClusterEntity entity = super.modelToEntity(model); - entity.setType(model.getType()); - entity.setConfig(JSON.toJSONString(model.getConfig())); - return entity; - } + @Override + protected ClusterEntity modelToEntity(GeaflowCluster model) { + ClusterEntity entity = super.modelToEntity(model); + entity.setType(model.getType()); + entity.setConfig(JSON.toJSONString(model.getConfig())); + return entity; + } - @Override - protected GeaflowCluster entityToModel(ClusterEntity entity) { - GeaflowCluster model = super.entityToModel(entity); - model.setType(entity.getType()); - model.setConfig(JSON.parseObject(entity.getConfig(), GeaflowConfig.class)); - return model; - } + @Override + protected GeaflowCluster entityToModel(ClusterEntity entity) { + GeaflowCluster model = super.entityToModel(entity); + model.setType(entity.getType()); + model.setConfig(JSON.parseObject(entity.getConfig(), GeaflowConfig.class)); + return model; + } - public GeaflowCluster convert(ClusterEntity entity) { - return entityToModel(entity); - } + public GeaflowCluster convert(ClusterEntity entity) { + return entityToModel(entity); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/DataConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/DataConverter.java index e07902c89..22199e34e 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/DataConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/DataConverter.java @@ -22,30 +22,28 @@ import org.apache.geaflow.console.common.dal.entity.DataEntity; import org.apache.geaflow.console.core.model.data.GeaflowData; -public abstract class DataConverter extends NameConverter { - - @Override - protected E modelToEntity(M model) { - E entity = super.modelToEntity(model); - String instanceId = model.getInstanceId(); - entity.setInstanceId(instanceId); - return entity; - } - - @Override - protected M entityToModel(E entity) { - M model = super.entityToModel(entity); - model.setInstanceId(entity.getInstanceId()); - return model; - } - - - @Override - protected M entityToModel(E entity, Class clazz) { - M model = super.entityToModel(entity, clazz); - model.setInstanceId(entity.getInstanceId()); - return model; - } - - +public abstract class DataConverter + extends NameConverter { + + @Override + protected E modelToEntity(M model) { + E entity = super.modelToEntity(model); + String instanceId = model.getInstanceId(); + entity.setInstanceId(instanceId); + return entity; + } + + @Override + protected M entityToModel(E entity) { + M model = super.entityToModel(entity); + model.setInstanceId(entity.getInstanceId()); + return model; + } + + @Override + protected M entityToModel(E entity, Class clazz) { + M model = super.entityToModel(entity, clazz); + model.setInstanceId(entity.getInstanceId()); + return model; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/EdgeConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/EdgeConverter.java index 9f44aa09e..577e78746 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/EdgeConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/EdgeConverter.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.core.service.converter; import java.util.List; + import org.apache.geaflow.console.common.dal.entity.EdgeEntity; import org.apache.geaflow.console.core.model.data.GeaflowEdge; import org.apache.geaflow.console.core.model.data.GeaflowField; @@ -28,9 +29,9 @@ @Component public class EdgeConverter extends DataConverter { - public GeaflowEdge convert(EdgeEntity entity, List fields) { - GeaflowEdge edge = entityToModel(entity); - edge.addFields(fields); - return edge; - } + public GeaflowEdge convert(EdgeEntity entity, List fields) { + GeaflowEdge edge = entityToModel(entity); + edge.addFields(fields); + return edge; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/FieldConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/FieldConverter.java index 0372c90f6..704fe9662 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/FieldConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/FieldConverter.java @@ -27,33 +27,32 @@ @Component public class FieldConverter extends NameConverter { - @Override - protected FieldEntity modelToEntity(GeaflowField model) { - FieldEntity entity = super.modelToEntity(model); - entity.setType(model.getType()); - entity.setCategory(model.getCategory()); - return entity; - } - - @Override - protected GeaflowField entityToModel(FieldEntity entity) { - GeaflowField model = super.entityToModel(entity); - model.setType(entity.getType()); - model.setCategory(entity.getCategory()); - return model; - } - - public GeaflowField convert(FieldEntity entity) { - return entityToModel(entity); - } - - - public FieldEntity convert(GeaflowField field, String resourceId, GeaflowResourceType resourceType, int index) { - FieldEntity entity = modelToEntity(field); - entity.setResourceId(resourceId); - entity.setResourceType(resourceType); - entity.setSortKey(index); - return entity; - } - + @Override + protected FieldEntity modelToEntity(GeaflowField model) { + FieldEntity entity = super.modelToEntity(model); + entity.setType(model.getType()); + entity.setCategory(model.getCategory()); + return entity; + } + + @Override + protected GeaflowField entityToModel(FieldEntity entity) { + GeaflowField model = super.entityToModel(entity); + model.setType(entity.getType()); + model.setCategory(entity.getCategory()); + return model; + } + + public GeaflowField convert(FieldEntity entity) { + return entityToModel(entity); + } + + public FieldEntity convert( + GeaflowField field, String resourceId, GeaflowResourceType resourceType, int index) { + FieldEntity entity = modelToEntity(field); + entity.setResourceId(resourceId); + entity.setResourceType(resourceType); + entity.setSortKey(index); + return entity; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/FunctionConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/FunctionConverter.java index 5c651906c..9261be552 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/FunctionConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/FunctionConverter.java @@ -27,24 +27,24 @@ @Component public class FunctionConverter extends DataConverter { - @Override - protected FunctionEntity modelToEntity(GeaflowFunction model) { - FunctionEntity entity = super.modelToEntity(model); - entity.setJarPackageId(model.getJarPackage().getId()); - entity.setEntryClass(model.getEntryClass()); - return entity; - } + @Override + protected FunctionEntity modelToEntity(GeaflowFunction model) { + FunctionEntity entity = super.modelToEntity(model); + entity.setJarPackageId(model.getJarPackage().getId()); + entity.setEntryClass(model.getEntryClass()); + return entity; + } - @Override - protected GeaflowFunction entityToModel(FunctionEntity entity) { - GeaflowFunction model = super.entityToModel(entity); - model.setEntryClass(entity.getEntryClass()); - return model; - } + @Override + protected GeaflowFunction entityToModel(FunctionEntity entity) { + GeaflowFunction model = super.entityToModel(entity); + model.setEntryClass(entity.getEntryClass()); + return model; + } - public GeaflowFunction convert(FunctionEntity entity, GeaflowRemoteFile jarPackage) { - GeaflowFunction model = entityToModel(entity); - model.setJarPackage(jarPackage); - return model; - } + public GeaflowFunction convert(FunctionEntity entity, GeaflowRemoteFile jarPackage) { + GeaflowFunction model = entityToModel(entity); + model.setJarPackage(jarPackage); + return model; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/GraphConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/GraphConverter.java index 6059f7b33..18a7dc82e 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/GraphConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/GraphConverter.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.core.service.converter; import java.util.List; + import org.apache.geaflow.console.common.dal.entity.GraphEntity; import org.apache.geaflow.console.core.model.data.GeaflowEdge; import org.apache.geaflow.console.core.model.data.GeaflowEndpoint; @@ -31,21 +32,25 @@ @Component public class GraphConverter extends DataConverter { - @Override - protected GraphEntity modelToEntity(GeaflowGraph model) { - GraphEntity entity = super.modelToEntity(model); - String configId = model.getPluginConfig().getId(); - entity.setPluginConfigId(configId); - return entity; - } + @Override + protected GraphEntity modelToEntity(GeaflowGraph model) { + GraphEntity entity = super.modelToEntity(model); + String configId = model.getPluginConfig().getId(); + entity.setPluginConfigId(configId); + return entity; + } - public GeaflowGraph convert(GraphEntity entity, List vertices, List edges, - List endpoints, GeaflowPluginConfig pluginConfig) { - GeaflowGraph graph = entityToModel(entity); - graph.addVertices(vertices); - graph.addEdges(edges); - graph.setEndpoints(endpoints); - graph.setPluginConfig(pluginConfig); - return graph; - } + public GeaflowGraph convert( + GraphEntity entity, + List vertices, + List edges, + List endpoints, + GeaflowPluginConfig pluginConfig) { + GeaflowGraph graph = entityToModel(entity); + graph.addVertices(vertices); + graph.addEdges(edges); + graph.setEndpoints(endpoints); + graph.setPluginConfig(pluginConfig); + return graph; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/IdConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/IdConverter.java index 63fa2986a..3c181c6e0 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/IdConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/IdConverter.java @@ -21,6 +21,7 @@ import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; + import org.apache.geaflow.console.common.dal.entity.IdEntity; import org.apache.geaflow.console.common.util.exception.GeaflowException; import org.apache.geaflow.console.core.model.GeaflowId; @@ -28,55 +29,57 @@ @SuppressWarnings("unchecked") public abstract class IdConverter { - public E convert(M model) { - return modelToEntity(model); - } + public E convert(M model) { + return modelToEntity(model); + } - protected E modelToEntity(M model) { - Type[] args = ((ParameterizedType) this.getClass().getGenericSuperclass()).getActualTypeArguments(); + protected E modelToEntity(M model) { + Type[] args = + ((ParameterizedType) this.getClass().getGenericSuperclass()).getActualTypeArguments(); - try { - E entity = (E) ((Class) args[1]).newInstance(); + try { + E entity = (E) ((Class) args[1]).newInstance(); - entity.setId(model.getId()); - entity.setCreatorId(model.getCreatorId()); - entity.setModifierId(model.getModifierId()); - entity.setGmtCreate(model.getGmtCreate()); - entity.setGmtModified(model.getGmtModified()); - return entity; + entity.setId(model.getId()); + entity.setCreatorId(model.getCreatorId()); + entity.setModifierId(model.getModifierId()); + entity.setGmtCreate(model.getGmtCreate()); + entity.setGmtModified(model.getGmtModified()); + return entity; - } catch (Exception e) { - throw new GeaflowException("Convert id model to entity failed", e); - } + } catch (Exception e) { + throw new GeaflowException("Convert id model to entity failed", e); } + } - protected M entityToModel(E entity) { - try { - Type[] args = ((ParameterizedType) this.getClass().getGenericSuperclass()).getActualTypeArguments(); - M model = (M) ((Class) args[0]).newInstance(); - setProperty(model, entity); - return model; - } catch (Exception e) { - throw new GeaflowException("Convert id entity to model failed", e); - } + protected M entityToModel(E entity) { + try { + Type[] args = + ((ParameterizedType) this.getClass().getGenericSuperclass()).getActualTypeArguments(); + M model = (M) ((Class) args[0]).newInstance(); + setProperty(model, entity); + return model; + } catch (Exception e) { + throw new GeaflowException("Convert id entity to model failed", e); } + } - protected M entityToModel(E entity, Class clazz) { - try { - M model = clazz.newInstance(); - setProperty(model, entity); - return model; - } catch (Exception e) { - throw new GeaflowException("Convert id entity to model failed", e); - } + protected M entityToModel(E entity, Class clazz) { + try { + M model = clazz.newInstance(); + setProperty(model, entity); + return model; + } catch (Exception e) { + throw new GeaflowException("Convert id entity to model failed", e); } + } - private void setProperty(M model, E entity) { - model.setTenantId(entity.getTenantId()); - model.setId(entity.getId()); - model.setCreatorId(entity.getCreatorId()); - model.setModifierId(entity.getModifierId()); - model.setGmtCreate(entity.getGmtCreate()); - model.setGmtModified(entity.getGmtModified()); - } + private void setProperty(M model, E entity) { + model.setTenantId(entity.getTenantId()); + model.setId(entity.getId()); + model.setCreatorId(entity.getCreatorId()); + model.setModifierId(entity.getModifierId()); + model.setGmtCreate(entity.getGmtCreate()); + model.setGmtModified(entity.getGmtModified()); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/InstanceConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/InstanceConverter.java index 6fb3957e3..a5f1db0fb 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/InstanceConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/InstanceConverter.java @@ -26,7 +26,7 @@ @Component public class InstanceConverter extends NameConverter { - public GeaflowInstance convert(InstanceEntity entity) { - return entityToModel(entity); - } + public GeaflowInstance convert(InstanceEntity entity) { + return entityToModel(entity); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/JobConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/JobConverter.java index 16b006550..5a0c7d5f7 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/JobConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/JobConverter.java @@ -19,9 +19,9 @@ package org.apache.geaflow.console.core.service.converter; -import com.alibaba.fastjson.JSON; import java.util.List; import java.util.Optional; + import org.apache.geaflow.console.common.dal.entity.JobEntity; import org.apache.geaflow.console.common.util.exception.GeaflowException; import org.apache.geaflow.console.common.util.type.GeaflowJobType; @@ -40,66 +40,79 @@ import org.apache.geaflow.console.core.model.plugin.GeaflowPlugin; import org.springframework.stereotype.Component; +import com.alibaba.fastjson.JSON; + @Component public class JobConverter extends NameConverter { - @Override - protected JobEntity modelToEntity(GeaflowJob model) { - GeaflowJobType jobType = model.getType(); - JobEntity entity = super.modelToEntity(model); - entity.setType(jobType); - entity.setUserCode(Optional.ofNullable(model.getUserCode()).map(GeaflowCode::getText).orElse(null)); - entity.setStructMappings(Optional.ofNullable(model.getStructMappings()).map(JSON::toJSONString).orElse(null)); - entity.setInstanceId(model.getInstanceId()); - entity.setJarPackageId(Optional.ofNullable(model.getJarPackage()).map(GeaflowId::getId).orElse(null)); - entity.setEntryClass(model.getEntryClass()); - return entity; - } - - - public GeaflowJob convert(JobEntity entity, List structs, List graphs, List functions, - List plugins, GeaflowRemoteFile jarPackage) { - GeaflowJobType jobType = entity.getType(); - GeaflowJob job; - switch (jobType) { - case INTEGRATE: - GeaflowIntegrateJob integrateJob = (GeaflowIntegrateJob) super.entityToModel(entity, GeaflowIntegrateJob.class); - List structMappings = JSON.parseArray(entity.getStructMappings(), StructMapping.class); - integrateJob.setStructMappings(structMappings); - integrateJob.setGraph(graphs); - integrateJob.setStructs(structs); - integrateJob.setUserCode(entity.getUserCode()); - job = integrateJob; - break; - case PROCESS: - GeaflowProcessJob processJob = (GeaflowProcessJob) super.entityToModel(entity, GeaflowProcessJob.class); - processJob.setUserCode(entity.getUserCode()); - processJob.setFunctions(functions); - processJob.setPlugins(plugins); - processJob.setStructs(structs); - processJob.setGraph(graphs); - job = processJob; - break; - case CUSTOM: - GeaflowCustomJob customJob = (GeaflowCustomJob) super.entityToModel(entity, GeaflowCustomJob.class); - customJob.setEntryClass(entity.getEntryClass()); - customJob.setJarPackage(jarPackage); - job = customJob; - break; - case SERVE: - GeaflowServeJob serveJob = (GeaflowServeJob) super.entityToModel(entity, GeaflowServeJob.class); - serveJob.setEntryClass(entity.getEntryClass()); - serveJob.setGraph(graphs); - job = serveJob; - break; - default: - throw new GeaflowException("Unsupported job type: {}", jobType); - } + @Override + protected JobEntity modelToEntity(GeaflowJob model) { + GeaflowJobType jobType = model.getType(); + JobEntity entity = super.modelToEntity(model); + entity.setType(jobType); + entity.setUserCode( + Optional.ofNullable(model.getUserCode()).map(GeaflowCode::getText).orElse(null)); + entity.setStructMappings( + Optional.ofNullable(model.getStructMappings()).map(JSON::toJSONString).orElse(null)); + entity.setInstanceId(model.getInstanceId()); + entity.setJarPackageId( + Optional.ofNullable(model.getJarPackage()).map(GeaflowId::getId).orElse(null)); + entity.setEntryClass(model.getEntryClass()); + return entity; + } - job.setType(entity.getType()); - job.setInstanceId(entity.getInstanceId()); - //TODO job.setSla(entity.getSlaId()); - return job; + public GeaflowJob convert( + JobEntity entity, + List structs, + List graphs, + List functions, + List plugins, + GeaflowRemoteFile jarPackage) { + GeaflowJobType jobType = entity.getType(); + GeaflowJob job; + switch (jobType) { + case INTEGRATE: + GeaflowIntegrateJob integrateJob = + (GeaflowIntegrateJob) super.entityToModel(entity, GeaflowIntegrateJob.class); + List structMappings = + JSON.parseArray(entity.getStructMappings(), StructMapping.class); + integrateJob.setStructMappings(structMappings); + integrateJob.setGraph(graphs); + integrateJob.setStructs(structs); + integrateJob.setUserCode(entity.getUserCode()); + job = integrateJob; + break; + case PROCESS: + GeaflowProcessJob processJob = + (GeaflowProcessJob) super.entityToModel(entity, GeaflowProcessJob.class); + processJob.setUserCode(entity.getUserCode()); + processJob.setFunctions(functions); + processJob.setPlugins(plugins); + processJob.setStructs(structs); + processJob.setGraph(graphs); + job = processJob; + break; + case CUSTOM: + GeaflowCustomJob customJob = + (GeaflowCustomJob) super.entityToModel(entity, GeaflowCustomJob.class); + customJob.setEntryClass(entity.getEntryClass()); + customJob.setJarPackage(jarPackage); + job = customJob; + break; + case SERVE: + GeaflowServeJob serveJob = + (GeaflowServeJob) super.entityToModel(entity, GeaflowServeJob.class); + serveJob.setEntryClass(entity.getEntryClass()); + serveJob.setGraph(graphs); + job = serveJob; + break; + default: + throw new GeaflowException("Unsupported job type: {}", jobType); } + job.setType(entity.getType()); + job.setInstanceId(entity.getInstanceId()); + // TODO job.setSla(entity.getSlaId()); + return job; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/LLMConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/LLMConverter.java index 0453e20fa..72b616608 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/LLMConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/LLMConverter.java @@ -19,34 +19,35 @@ package org.apache.geaflow.console.core.service.converter; -import com.alibaba.fastjson.JSON; import org.apache.geaflow.console.common.dal.entity.LLMEntity; import org.apache.geaflow.console.core.model.config.GeaflowConfig; import org.apache.geaflow.console.core.model.llm.GeaflowLLM; import org.springframework.stereotype.Component; +import com.alibaba.fastjson.JSON; + @Component public class LLMConverter extends NameConverter { - @Override - protected LLMEntity modelToEntity(GeaflowLLM model) { - LLMEntity entity = super.modelToEntity(model); - entity.setType(model.getType()); - entity.setUrl(model.getUrl()); - entity.setArgs(JSON.toJSONString(model.getArgs())); - return entity; - } + @Override + protected LLMEntity modelToEntity(GeaflowLLM model) { + LLMEntity entity = super.modelToEntity(model); + entity.setType(model.getType()); + entity.setUrl(model.getUrl()); + entity.setArgs(JSON.toJSONString(model.getArgs())); + return entity; + } - @Override - protected GeaflowLLM entityToModel(LLMEntity entity) { - GeaflowLLM model = super.entityToModel(entity); - model.setType(entity.getType()); - model.setUrl(entity.getUrl()); - model.setArgs(JSON.parseObject(entity.getArgs(), GeaflowConfig.class)); - return model; - } + @Override + protected GeaflowLLM entityToModel(LLMEntity entity) { + GeaflowLLM model = super.entityToModel(entity); + model.setType(entity.getType()); + model.setUrl(entity.getUrl()); + model.setArgs(JSON.parseObject(entity.getArgs(), GeaflowConfig.class)); + return model; + } - public GeaflowLLM convert(LLMEntity entity) { - return entityToModel(entity); - } + public GeaflowLLM convert(LLMEntity entity) { + return entityToModel(entity); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/NameConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/NameConverter.java index 1bf9ff606..481223b5a 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/NameConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/NameConverter.java @@ -22,28 +22,29 @@ import org.apache.geaflow.console.common.dal.entity.NameEntity; import org.apache.geaflow.console.core.model.GeaflowName; -public abstract class NameConverter extends IdConverter { +public abstract class NameConverter + extends IdConverter { - protected E modelToEntity(M model) { - E entity = super.modelToEntity(model); - entity.setName(model.getName()); - entity.setComment(model.getComment()); - return entity; - } + protected E modelToEntity(M model) { + E entity = super.modelToEntity(model); + entity.setName(model.getName()); + entity.setComment(model.getComment()); + return entity; + } - @Override - protected M entityToModel(E entity) { - M model = super.entityToModel(entity); - model.setName(entity.getName()); - model.setComment(entity.getComment()); - return model; - } + @Override + protected M entityToModel(E entity) { + M model = super.entityToModel(entity); + model.setName(entity.getName()); + model.setComment(entity.getComment()); + return model; + } - @Override - protected M entityToModel(E entity, Class clazz) { - M model = super.entityToModel(entity, clazz); - model.setName(entity.getName()); - model.setComment(entity.getComment()); - return model; - } + @Override + protected M entityToModel(E entity, Class clazz) { + M model = super.entityToModel(entity, clazz); + model.setName(entity.getName()); + model.setComment(entity.getComment()); + return model; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/PluginConfigConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/PluginConfigConverter.java index 41402d6bc..b00fefb76 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/PluginConfigConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/PluginConfigConverter.java @@ -19,35 +19,36 @@ package org.apache.geaflow.console.core.service.converter; -import com.alibaba.fastjson.JSON; import org.apache.geaflow.console.common.dal.entity.PluginConfigEntity; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.config.GeaflowConfig; import org.apache.geaflow.console.core.model.plugin.config.GeaflowPluginConfig; import org.springframework.stereotype.Component; +import com.alibaba.fastjson.JSON; + @Component public class PluginConfigConverter extends NameConverter { - @Override - protected PluginConfigEntity modelToEntity(GeaflowPluginConfig model) { - PluginConfigEntity entity = super.modelToEntity(model); - entity.setType(model.getType()); - entity.setConfig(JSON.toJSONString(model.getConfig())); - entity.setCategory(model.getCategory()); - return entity; - } + @Override + protected PluginConfigEntity modelToEntity(GeaflowPluginConfig model) { + PluginConfigEntity entity = super.modelToEntity(model); + entity.setType(model.getType()); + entity.setConfig(JSON.toJSONString(model.getConfig())); + entity.setCategory(model.getCategory()); + return entity; + } - @Override - protected GeaflowPluginConfig entityToModel(PluginConfigEntity entity) { - GeaflowPluginConfig model = super.entityToModel(entity); - model.setType(GeaflowPluginType.getName(entity.getType())); - model.setConfig(JSON.parseObject(entity.getConfig(), GeaflowConfig.class)); - model.setCategory(entity.getCategory()); - return model; - } + @Override + protected GeaflowPluginConfig entityToModel(PluginConfigEntity entity) { + GeaflowPluginConfig model = super.entityToModel(entity); + model.setType(GeaflowPluginType.getName(entity.getType())); + model.setConfig(JSON.parseObject(entity.getConfig(), GeaflowConfig.class)); + model.setCategory(entity.getCategory()); + return model; + } - public GeaflowPluginConfig convert(PluginConfigEntity entity) { - return entityToModel(entity); - } + public GeaflowPluginConfig convert(PluginConfigEntity entity) { + return entityToModel(entity); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/PluginConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/PluginConverter.java index f44254e7a..81e2636f6 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/PluginConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/PluginConverter.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.core.service.converter; import java.util.Optional; + import org.apache.geaflow.console.common.dal.entity.PluginEntity; import org.apache.geaflow.console.core.model.GeaflowId; import org.apache.geaflow.console.core.model.file.GeaflowRemoteFile; @@ -29,27 +30,28 @@ @Component public class PluginConverter extends NameConverter { - @Override - protected PluginEntity modelToEntity(GeaflowPlugin model) { - PluginEntity entity = super.modelToEntity(model); - entity.setPluginType(model.getType()); - entity.setPluginCategory(model.getCategory()); - entity.setJarPackageId(Optional.ofNullable(model.getJarPackage()).map(GeaflowId::getId).orElse(null)); - return entity; - } + @Override + protected PluginEntity modelToEntity(GeaflowPlugin model) { + PluginEntity entity = super.modelToEntity(model); + entity.setPluginType(model.getType()); + entity.setPluginCategory(model.getCategory()); + entity.setJarPackageId( + Optional.ofNullable(model.getJarPackage()).map(GeaflowId::getId).orElse(null)); + return entity; + } - @Override - protected GeaflowPlugin entityToModel(PluginEntity entity) { - GeaflowPlugin model = super.entityToModel(entity); - model.setType(entity.getPluginType()); - model.setCategory(entity.getPluginCategory()); - model.setSystem(entity.isSystem()); - return model; - } + @Override + protected GeaflowPlugin entityToModel(PluginEntity entity) { + GeaflowPlugin model = super.entityToModel(entity); + model.setType(entity.getPluginType()); + model.setCategory(entity.getPluginCategory()); + model.setSystem(entity.isSystem()); + return model; + } - public GeaflowPlugin convert(PluginEntity entity, GeaflowRemoteFile jarPackage) { - GeaflowPlugin model = entityToModel(entity); - model.setJarPackage(jarPackage); - return model; - } + public GeaflowPlugin convert(PluginEntity entity, GeaflowRemoteFile jarPackage) { + GeaflowPlugin model = entityToModel(entity); + model.setJarPackage(jarPackage); + return model; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/ReleaseConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/ReleaseConverter.java index a5c60f936..8acbfff87 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/ReleaseConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/ReleaseConverter.java @@ -19,7 +19,6 @@ package org.apache.geaflow.console.core.service.converter; -import com.alibaba.fastjson.JSON; import org.apache.geaflow.console.common.dal.entity.ReleaseEntity; import org.apache.geaflow.console.core.model.cluster.GeaflowCluster; import org.apache.geaflow.console.core.model.config.GeaflowConfig; @@ -29,46 +28,48 @@ import org.apache.geaflow.console.core.model.version.GeaflowVersion; import org.springframework.stereotype.Component; +import com.alibaba.fastjson.JSON; + @Component public class ReleaseConverter extends IdConverter { + @Override + public ReleaseEntity modelToEntity(GeaflowRelease model) { + ReleaseEntity entity = super.modelToEntity(model); + entity.setClusterConfig(JSON.toJSONString(model.getClusterConfig())); + entity.setJobConfig(JSON.toJSONString(model.getJobConfig())); + entity.setClusterId(model.getCluster().getId()); + entity.setJobPlan(JSON.toJSONString(model.getJobPlan())); + entity.setJobId(model.getJob().getId()); + entity.setVersionId(model.getVersion().getId()); + entity.setVersion(model.getReleaseVersion()); + entity.setUrl(model.getUrl()); + entity.setMd5(model.getMd5()); + return entity; + } - @Override - public ReleaseEntity modelToEntity(GeaflowRelease model) { - ReleaseEntity entity = super.modelToEntity(model); - entity.setClusterConfig(JSON.toJSONString(model.getClusterConfig())); - entity.setJobConfig(JSON.toJSONString(model.getJobConfig())); - entity.setClusterId(model.getCluster().getId()); - entity.setJobPlan(JSON.toJSONString(model.getJobPlan())); - entity.setJobId(model.getJob().getId()); - entity.setVersionId(model.getVersion().getId()); - entity.setVersion(model.getReleaseVersion()); - entity.setUrl(model.getUrl()); - entity.setMd5(model.getMd5()); - return entity; - } - - public GeaflowRelease convert(ReleaseEntity entity, GeaflowJob job, GeaflowVersion version, GeaflowCluster cluster) { - GeaflowRelease release = super.entityToModel(entity); - // job - release.setJob(job); - // versionNumber - release.setReleaseVersion(entity.getVersion()); - // job config - GeaflowConfig jobConfig = JSON.parseObject(entity.getJobConfig(), GeaflowConfig.class); - release.getJobConfig().putAll(jobConfig); - // cluster config - GeaflowConfig clusterConfig = JSON.parseObject(entity.getClusterConfig(), GeaflowConfig.class); - release.getClusterConfig().putAll(clusterConfig); - // build jobPlan - JobPlan jobPlan = JobPlan.build(entity.getJobPlan()); - release.setJobPlan(jobPlan); - // version - release.setVersion(version); - // cluster - release.setCluster(cluster); - release.setUrl(entity.getUrl()); - release.setMd5(entity.getMd5()); - return release; - } + public GeaflowRelease convert( + ReleaseEntity entity, GeaflowJob job, GeaflowVersion version, GeaflowCluster cluster) { + GeaflowRelease release = super.entityToModel(entity); + // job + release.setJob(job); + // versionNumber + release.setReleaseVersion(entity.getVersion()); + // job config + GeaflowConfig jobConfig = JSON.parseObject(entity.getJobConfig(), GeaflowConfig.class); + release.getJobConfig().putAll(jobConfig); + // cluster config + GeaflowConfig clusterConfig = JSON.parseObject(entity.getClusterConfig(), GeaflowConfig.class); + release.getClusterConfig().putAll(clusterConfig); + // build jobPlan + JobPlan jobPlan = JobPlan.build(entity.getJobPlan()); + release.setJobPlan(jobPlan); + // version + release.setVersion(version); + // cluster + release.setCluster(cluster); + release.setUrl(entity.getUrl()); + release.setMd5(entity.getMd5()); + return release; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/RemoteFileConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/RemoteFileConverter.java index ccd62496b..441a7782c 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/RemoteFileConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/RemoteFileConverter.java @@ -26,27 +26,27 @@ @Component public class RemoteFileConverter extends NameConverter { - @Override - protected RemoteFileEntity modelToEntity(GeaflowRemoteFile model) { - RemoteFileEntity entity = super.modelToEntity(model); - entity.setMd5(model.getMd5()); - entity.setType(model.getType()); - entity.setPath(model.getPath()); - entity.setUrl(model.getUrl()); - return entity; - } + @Override + protected RemoteFileEntity modelToEntity(GeaflowRemoteFile model) { + RemoteFileEntity entity = super.modelToEntity(model); + entity.setMd5(model.getMd5()); + entity.setType(model.getType()); + entity.setPath(model.getPath()); + entity.setUrl(model.getUrl()); + return entity; + } - @Override - protected GeaflowRemoteFile entityToModel(RemoteFileEntity entity) { - GeaflowRemoteFile model = super.entityToModel(entity); - model.setPath(entity.getPath()); - model.setMd5(entity.getMd5()); - model.setUrl(entity.getUrl()); - model.setType(entity.getType()); - return model; - } + @Override + protected GeaflowRemoteFile entityToModel(RemoteFileEntity entity) { + GeaflowRemoteFile model = super.entityToModel(entity); + model.setPath(entity.getPath()); + model.setMd5(entity.getMd5()); + model.setUrl(entity.getUrl()); + model.setType(entity.getType()); + return model; + } - public GeaflowRemoteFile convert(RemoteFileEntity entity) { - return entityToModel(entity); - } + public GeaflowRemoteFile convert(RemoteFileEntity entity) { + return entityToModel(entity); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/StatementConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/StatementConverter.java index 6116e3a7f..5f0ab9ab3 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/StatementConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/StatementConverter.java @@ -26,23 +26,22 @@ @Component public class StatementConverter extends IdConverter { - @Override - protected StatementEntity modelToEntity(GeaflowStatement model) { - StatementEntity entity = super.modelToEntity(model); - entity.setScript(model.getScript()); - entity.setStatus(model.getStatus()); - entity.setResult(model.getResult()); - entity.setJobId(model.getJobId()); - return entity; - } + @Override + protected StatementEntity modelToEntity(GeaflowStatement model) { + StatementEntity entity = super.modelToEntity(model); + entity.setScript(model.getScript()); + entity.setStatus(model.getStatus()); + entity.setResult(model.getResult()); + entity.setJobId(model.getJobId()); + return entity; + } - - public GeaflowStatement convert(StatementEntity entity) { - GeaflowStatement model = super.entityToModel(entity); - model.setScript(entity.getScript()); - model.setStatus(entity.getStatus()); - model.setResult(entity.getResult()); - model.setJobId(entity.getJobId()); - return model; - } + public GeaflowStatement convert(StatementEntity entity) { + GeaflowStatement model = super.entityToModel(entity); + model.setScript(entity.getScript()); + model.setStatus(entity.getStatus()); + model.setResult(entity.getResult()); + model.setJobId(entity.getJobId()); + return model; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/SystemConfigConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/SystemConfigConverter.java index 7ddf23c48..a4d050b4f 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/SystemConfigConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/SystemConfigConverter.java @@ -26,22 +26,22 @@ @Component public class SystemConfigConverter extends NameConverter { - @Override - protected SystemConfigEntity modelToEntity(GeaflowSystemConfig model) { - SystemConfigEntity entity = super.modelToEntity(model); - entity.setTenantId(model.getTenantId()); - entity.setValue(model.getValue()); - return entity; - } + @Override + protected SystemConfigEntity modelToEntity(GeaflowSystemConfig model) { + SystemConfigEntity entity = super.modelToEntity(model); + entity.setTenantId(model.getTenantId()); + entity.setValue(model.getValue()); + return entity; + } - @Override - protected GeaflowSystemConfig entityToModel(SystemConfigEntity entity) { - GeaflowSystemConfig model = super.entityToModel(entity); - model.setValue(entity.getValue()); - return model; - } + @Override + protected GeaflowSystemConfig entityToModel(SystemConfigEntity entity) { + GeaflowSystemConfig model = super.entityToModel(entity); + model.setValue(entity.getValue()); + return model; + } - public GeaflowSystemConfig convert(SystemConfigEntity entity) { - return entityToModel(entity); - } + public GeaflowSystemConfig convert(SystemConfigEntity entity) { + return entityToModel(entity); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/TableConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/TableConverter.java index 6ef361b84..d7aaada5d 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/TableConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/TableConverter.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.core.service.converter; import java.util.List; + import org.apache.geaflow.console.common.dal.entity.TableEntity; import org.apache.geaflow.console.core.model.data.GeaflowField; import org.apache.geaflow.console.core.model.data.GeaflowTable; @@ -29,17 +30,18 @@ @Component public class TableConverter extends DataConverter { - @Override - protected TableEntity modelToEntity(GeaflowTable model) { - TableEntity entity = super.modelToEntity(model); - entity.setPluginConfigId(model.getPluginConfig().getId()); - return entity; - } + @Override + protected TableEntity modelToEntity(GeaflowTable model) { + TableEntity entity = super.modelToEntity(model); + entity.setPluginConfigId(model.getPluginConfig().getId()); + return entity; + } - public GeaflowTable convert(TableEntity entity, List fields, GeaflowPluginConfig pluginConfig) { - GeaflowTable table = entityToModel(entity); - table.addFields(fields); - table.setPluginConfig(pluginConfig); - return table; - } + public GeaflowTable convert( + TableEntity entity, List fields, GeaflowPluginConfig pluginConfig) { + GeaflowTable table = entityToModel(entity); + table.addFields(fields); + table.setPluginConfig(pluginConfig); + return table; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/TaskConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/TaskConverter.java index 539950417..0fb6715c7 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/TaskConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/TaskConverter.java @@ -19,8 +19,8 @@ package org.apache.geaflow.console.core.service.converter; -import com.alibaba.fastjson.JSON; import java.util.Optional; + import org.apache.geaflow.console.common.dal.entity.TaskEntity; import org.apache.geaflow.console.core.model.plugin.config.GeaflowPluginConfig; import org.apache.geaflow.console.core.model.release.GeaflowRelease; @@ -28,53 +28,56 @@ import org.apache.geaflow.console.core.model.task.GeaflowTaskHandle; import org.springframework.stereotype.Component; +import com.alibaba.fastjson.JSON; + @Component public class TaskConverter extends IdConverter { - @Override - protected TaskEntity modelToEntity(GeaflowTask model) { - TaskEntity task = super.modelToEntity(model); - task.setType(model.getType()); - task.setStatus(model.getStatus()); - task.setStartTime(model.getStartTime()); - task.setEndTime(model.getEndTime()); - task.setReleaseId(model.getRelease().getId()); - task.setHaMetaConfigId(model.getHaMetaPluginConfig().getId()); - task.setDataConfigId(model.getDataPluginConfig().getId()); - task.setMetricConfigId(model.getMetricPluginConfig().getId()); - task.setRuntimeMetaConfigId(model.getRuntimeMetaPluginConfig().getId()); - task.setJobId(model.getRelease().getJob().getId()); - task.setToken(model.getToken()); - task.setHandle(Optional.ofNullable(model.getHandle()).map(JSON::toJSONString).orElse(null)); - task.setHost(model.getHost()); - return task; - } - - @Override - protected GeaflowTask entityToModel(TaskEntity entity) { - GeaflowTask task = super.entityToModel(entity); - task.setStatus(entity.getStatus()); - task.setType(entity.getType()); - task.setStartTime(entity.getStartTime()); - task.setEndTime(entity.getEndTime()); - task.setToken(entity.getToken()); - task.setHandle(GeaflowTaskHandle.parse(entity.getHandle())); - task.setHost(entity.getHost()); - return task; - } + @Override + protected TaskEntity modelToEntity(GeaflowTask model) { + TaskEntity task = super.modelToEntity(model); + task.setType(model.getType()); + task.setStatus(model.getStatus()); + task.setStartTime(model.getStartTime()); + task.setEndTime(model.getEndTime()); + task.setReleaseId(model.getRelease().getId()); + task.setHaMetaConfigId(model.getHaMetaPluginConfig().getId()); + task.setDataConfigId(model.getDataPluginConfig().getId()); + task.setMetricConfigId(model.getMetricPluginConfig().getId()); + task.setRuntimeMetaConfigId(model.getRuntimeMetaPluginConfig().getId()); + task.setJobId(model.getRelease().getJob().getId()); + task.setToken(model.getToken()); + task.setHandle(Optional.ofNullable(model.getHandle()).map(JSON::toJSONString).orElse(null)); + task.setHost(model.getHost()); + return task; + } + @Override + protected GeaflowTask entityToModel(TaskEntity entity) { + GeaflowTask task = super.entityToModel(entity); + task.setStatus(entity.getStatus()); + task.setType(entity.getType()); + task.setStartTime(entity.getStartTime()); + task.setEndTime(entity.getEndTime()); + task.setToken(entity.getToken()); + task.setHandle(GeaflowTaskHandle.parse(entity.getHandle())); + task.setHost(entity.getHost()); + return task; + } - public GeaflowTask convert(TaskEntity entity, GeaflowRelease release, - GeaflowPluginConfig runtimeMetaPluginConfig, - GeaflowPluginConfig haMetaPluginConfig, - GeaflowPluginConfig metricPluginConfig, - GeaflowPluginConfig dataPluginConfig) { - GeaflowTask task = this.entityToModel(entity); - task.setRelease(release); - task.setRuntimeMetaPluginConfig(runtimeMetaPluginConfig); - task.setHaMetaPluginConfig(haMetaPluginConfig); - task.setMetricPluginConfig(metricPluginConfig); - task.setDataPluginConfig(dataPluginConfig); - return task; - } + public GeaflowTask convert( + TaskEntity entity, + GeaflowRelease release, + GeaflowPluginConfig runtimeMetaPluginConfig, + GeaflowPluginConfig haMetaPluginConfig, + GeaflowPluginConfig metricPluginConfig, + GeaflowPluginConfig dataPluginConfig) { + GeaflowTask task = this.entityToModel(entity); + task.setRelease(release); + task.setRuntimeMetaPluginConfig(runtimeMetaPluginConfig); + task.setHaMetaPluginConfig(haMetaPluginConfig); + task.setMetricPluginConfig(metricPluginConfig); + task.setDataPluginConfig(dataPluginConfig); + return task; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/TenantConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/TenantConverter.java index b830ff34f..9e0688935 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/TenantConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/TenantConverter.java @@ -26,7 +26,7 @@ @Component public class TenantConverter extends NameConverter { - public GeaflowTenant convert(TenantEntity entity) { - return entityToModel(entity); - } + public GeaflowTenant convert(TenantEntity entity) { + return entityToModel(entity); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/UserConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/UserConverter.java index 0f373b552..dabf621ad 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/UserConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/UserConverter.java @@ -27,27 +27,27 @@ @Component public class UserConverter extends NameConverter { - @Override - protected UserEntity modelToEntity(GeaflowUser model) { - UserEntity entity = super.modelToEntity(model); - entity.setPhone(model.getPhone()); - entity.setEmail(model.getEmail()); - entity.setPasswordSign(Md5Util.encodeString(model.getPassword())); - return entity; - } + @Override + protected UserEntity modelToEntity(GeaflowUser model) { + UserEntity entity = super.modelToEntity(model); + entity.setPhone(model.getPhone()); + entity.setEmail(model.getEmail()); + entity.setPasswordSign(Md5Util.encodeString(model.getPassword())); + return entity; + } - @Override - protected GeaflowUser entityToModel(UserEntity entity) { - GeaflowUser model = super.entityToModel(entity); - model.setPhone(entity.getPhone()); - model.setEmail(entity.getEmail()); - return model; - } + @Override + protected GeaflowUser entityToModel(UserEntity entity) { + GeaflowUser model = super.entityToModel(entity); + model.setPhone(entity.getPhone()); + model.setEmail(entity.getEmail()); + return model; + } - public GeaflowUser convert(UserEntity entity) { - GeaflowUser user = entityToModel(entity); - user.setPhone(entity.getPhone()); - user.setEmail(entity.getEmail()); - return user; - } + public GeaflowUser convert(UserEntity entity) { + GeaflowUser user = entityToModel(entity); + user.setPhone(entity.getPhone()); + user.setEmail(entity.getEmail()); + return user; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/VersionConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/VersionConverter.java index 32c104c27..de1fec140 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/VersionConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/VersionConverter.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.core.service.converter; import java.util.Optional; + import org.apache.geaflow.console.common.dal.entity.VersionEntity; import org.apache.geaflow.console.core.model.GeaflowId; import org.apache.geaflow.console.core.model.file.GeaflowRemoteFile; @@ -29,27 +30,31 @@ @Component public class VersionConverter extends NameConverter { - @Override - protected VersionEntity modelToEntity(GeaflowVersion model) { - VersionEntity version = super.modelToEntity(model); - version.setPublish(model.isPublish()); - Optional.ofNullable(model.getEngineJarPackage()).map(GeaflowId::getId).ifPresent(version::setEngineJarId); - Optional.ofNullable(model.getLangJarPackage()).map(GeaflowId::getId).ifPresent(version::setLangJarId); - return version; - } - + @Override + protected VersionEntity modelToEntity(GeaflowVersion model) { + VersionEntity version = super.modelToEntity(model); + version.setPublish(model.isPublish()); + Optional.ofNullable(model.getEngineJarPackage()) + .map(GeaflowId::getId) + .ifPresent(version::setEngineJarId); + Optional.ofNullable(model.getLangJarPackage()) + .map(GeaflowId::getId) + .ifPresent(version::setLangJarId); + return version; + } - @Override - protected GeaflowVersion entityToModel(VersionEntity entity) { - GeaflowVersion geaflowVersion = super.entityToModel(entity); - geaflowVersion.setPublish(entity.isPublish()); - return geaflowVersion; - } + @Override + protected GeaflowVersion entityToModel(VersionEntity entity) { + GeaflowVersion geaflowVersion = super.entityToModel(entity); + geaflowVersion.setPublish(entity.isPublish()); + return geaflowVersion; + } - public GeaflowVersion convert(VersionEntity entity, GeaflowRemoteFile engineJar, GeaflowRemoteFile langJar) { - GeaflowVersion version = entityToModel(entity); - version.setEngineJarPackage(engineJar); - version.setLangJarPackage(langJar); - return version; - } + public GeaflowVersion convert( + VersionEntity entity, GeaflowRemoteFile engineJar, GeaflowRemoteFile langJar) { + GeaflowVersion version = entityToModel(entity); + version.setEngineJarPackage(engineJar); + version.setLangJarPackage(langJar); + return version; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/VertexConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/VertexConverter.java index 86390641d..da1b63880 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/VertexConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/VertexConverter.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.core.service.converter; import java.util.List; + import org.apache.geaflow.console.common.dal.entity.VertexEntity; import org.apache.geaflow.console.core.model.data.GeaflowField; import org.apache.geaflow.console.core.model.data.GeaflowVertex; @@ -28,9 +29,9 @@ @Component public class VertexConverter extends DataConverter { - public GeaflowVertex convert(VertexEntity entity, List fields) { - GeaflowVertex vertex = entityToModel(entity); - vertex.addFields(fields); - return vertex; - } + public GeaflowVertex convert(VertexEntity entity, List fields) { + GeaflowVertex vertex = entityToModel(entity); + vertex.addFields(fields); + return vertex; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/ViewConverter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/ViewConverter.java index 4d50a5e55..027a92755 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/ViewConverter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/converter/ViewConverter.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.core.service.converter; import java.util.List; + import org.apache.geaflow.console.common.dal.entity.ViewEntity; import org.apache.geaflow.console.core.model.code.GeaflowCode; import org.apache.geaflow.console.core.model.data.GeaflowField; @@ -29,25 +30,25 @@ @Component public class ViewConverter extends DataConverter { - @Override - protected ViewEntity modelToEntity(GeaflowView model) { - ViewEntity entity = super.modelToEntity(model); - entity.setCategory(model.getCategory()); - entity.setCode(model.getCode().getText()); - return entity; - } + @Override + protected ViewEntity modelToEntity(GeaflowView model) { + ViewEntity entity = super.modelToEntity(model); + entity.setCategory(model.getCategory()); + entity.setCode(model.getCode().getText()); + return entity; + } - @Override - protected GeaflowView entityToModel(ViewEntity entity) { - GeaflowView model = super.entityToModel(entity); - model.setCategory(entity.getCategory()); - model.setCode(new GeaflowCode(entity.getCode())); - return model; - } + @Override + protected GeaflowView entityToModel(ViewEntity entity) { + GeaflowView model = super.entityToModel(entity); + model.setCategory(entity.getCategory()); + model.setCode(new GeaflowCode(entity.getCode())); + return model; + } - public GeaflowView convert(ViewEntity entity, List fields) { - GeaflowView view = entityToModel(entity); - view.addFields(fields); - return view; - } + public GeaflowView convert(ViewEntity entity, List fields) { + GeaflowView view = entityToModel(entity); + view.addFields(fields); + return view; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/factory/GeaflowDataFactory.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/factory/GeaflowDataFactory.java index 83e18610d..091879691 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/factory/GeaflowDataFactory.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/factory/GeaflowDataFactory.java @@ -33,33 +33,32 @@ @Component public class GeaflowDataFactory { - /** - * generate resource data by resource type. - */ - public static GeaflowData get(String name, String comment, String instanceId, GeaflowResourceType resourceType) { - GeaflowData data; - switch (resourceType) { - case GRAPH: - data = new GeaflowGraph(name, comment); - ((GeaflowGraph) data).setPluginConfig(new GeaflowPluginConfig()); - break; - case TABLE: - data = new GeaflowTable(name, comment); - ((GeaflowTable) data).setPluginConfig(new GeaflowPluginConfig()); - break; - case VERTEX: - data = new GeaflowVertex(name, comment); - break; - case EDGE: - data = new GeaflowEdge(name, comment); - break; - case FUNCTION: - data = new GeaflowFunction(name, comment); - break; - default: - throw new GeaflowException("Unsupported resource type", resourceType); - } - data.setInstanceId(instanceId); - return data; + /** generate resource data by resource type. */ + public static GeaflowData get( + String name, String comment, String instanceId, GeaflowResourceType resourceType) { + GeaflowData data; + switch (resourceType) { + case GRAPH: + data = new GeaflowGraph(name, comment); + ((GeaflowGraph) data).setPluginConfig(new GeaflowPluginConfig()); + break; + case TABLE: + data = new GeaflowTable(name, comment); + ((GeaflowTable) data).setPluginConfig(new GeaflowPluginConfig()); + break; + case VERTEX: + data = new GeaflowVertex(name, comment); + break; + case EDGE: + data = new GeaflowEdge(name, comment); + break; + case FUNCTION: + data = new GeaflowFunction(name, comment); + break; + default: + throw new GeaflowException("Unsupported resource type", resourceType); } + data.setInstanceId(instanceId); + return data; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/DfsFileClient.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/DfsFileClient.java index 5670ead3d..f4e9cff3e 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/DfsFileClient.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/DfsFileClient.java @@ -23,6 +23,7 @@ import java.io.InputStream; import java.net.URI; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.util.exception.GeaflowException; import org.apache.geaflow.console.core.model.config.GeaflowConfig; @@ -37,87 +38,86 @@ public class DfsFileClient implements RemoteFileClient { - private DfsPluginConfigClass dfsConfig; - - private FileSystem fileSystem; - - @Override - public void init(GeaflowPlugin plugin, GeaflowPluginConfig config) { - this.dfsConfig = config.getConfig().parse(DfsPluginConfigClass.class); - GeaflowConfig geaflowConfig = new GeaflowConfig(); - geaflowConfig.put(DFS_URI_KEY, dfsConfig.getDefaultFs()); - geaflowConfig.putAll(dfsConfig.getExtendConfig()); + private DfsPluginConfigClass dfsConfig; - Configuration conf = new Configuration(); - geaflowConfig.toStringMap().forEach(conf::set); - - try { - this.fileSystem = FileSystem.get(new URI(dfsConfig.getDefaultFs()), conf); - } catch (Exception e) { - throw new GeaflowException("Init DfsFileClient failed", e); - } - } + private FileSystem fileSystem; - @Override - public void upload(String path, InputStream inputStream) { - String fullPath = getFullPath(path); - try { - Path dfsPath = new Path(fullPath); - FSDataOutputStream outputStream = this.fileSystem.create(dfsPath); - IOUtils.copyBytes(inputStream, outputStream, 1024 * 1024 * 8, true); - } catch (Exception e) { - throw new GeaflowException("Upload file {} failed", fullPath, e); - } + @Override + public void init(GeaflowPlugin plugin, GeaflowPluginConfig config) { + this.dfsConfig = config.getConfig().parse(DfsPluginConfigClass.class); + GeaflowConfig geaflowConfig = new GeaflowConfig(); + geaflowConfig.put(DFS_URI_KEY, dfsConfig.getDefaultFs()); + geaflowConfig.putAll(dfsConfig.getExtendConfig()); - } + Configuration conf = new Configuration(); + geaflowConfig.toStringMap().forEach(conf::set); - @Override - public InputStream download(String path) { - String fullPath = getFullPath(path); - try { - Path dfsPath = new Path(fullPath); - if (!fileSystem.exists(dfsPath)) { - throw new GeaflowException("File doesn't exist {}", fullPath); - } - - return this.fileSystem.open(dfsPath); - } catch (Exception e) { - throw new GeaflowException("Download file {} failed", fullPath, e); - } + try { + this.fileSystem = FileSystem.get(new URI(dfsConfig.getDefaultFs()), conf); + } catch (Exception e) { + throw new GeaflowException("Init DfsFileClient failed", e); } - - @Override - public void delete(String path) { - String fullPath = getFullPath(path); - try { - Path dfsPath = new Path(fullPath); - if (!fileSystem.exists(dfsPath)) { - return; - } - - this.fileSystem.delete(dfsPath, true); - } catch (Exception e) { - throw new GeaflowException("Delete file {} failed", fullPath, e); - } + } + + @Override + public void upload(String path, InputStream inputStream) { + String fullPath = getFullPath(path); + try { + Path dfsPath = new Path(fullPath); + FSDataOutputStream outputStream = this.fileSystem.create(dfsPath); + IOUtils.copyBytes(inputStream, outputStream, 1024 * 1024 * 8, true); + } catch (Exception e) { + throw new GeaflowException("Upload file {} failed", fullPath, e); } - - @Override - public String getUrl(String path) { - return String.format("%s%s", dfsConfig.getDefaultFs(), getFullPath(path)); + } + + @Override + public InputStream download(String path) { + String fullPath = getFullPath(path); + try { + Path dfsPath = new Path(fullPath); + if (!fileSystem.exists(dfsPath)) { + throw new GeaflowException("File doesn't exist {}", fullPath); + } + + return this.fileSystem.open(dfsPath); + } catch (Exception e) { + throw new GeaflowException("Download file {} failed", fullPath, e); } - - @Override - public boolean checkFileExists(String path) { - //TODO - return false; + } + + @Override + public void delete(String path) { + String fullPath = getFullPath(path); + try { + Path dfsPath = new Path(fullPath); + if (!fileSystem.exists(dfsPath)) { + return; + } + + this.fileSystem.delete(dfsPath, true); + } catch (Exception e) { + throw new GeaflowException("Delete file {} failed", fullPath, e); } - - public String getFullPath(String path) { - String root = dfsConfig.getRoot(); - if (!StringUtils.startsWith(root, "/")) { - throw new GeaflowException("Invalid root config, should start with /"); - } - root = StringUtils.removeEnd(root, "/"); - return String.format("%s/%s", root, path); + } + + @Override + public String getUrl(String path) { + return String.format("%s%s", dfsConfig.getDefaultFs(), getFullPath(path)); + } + + @Override + public boolean checkFileExists(String path) { + // TODO + return false; + } + + public String getFullPath(String path) { + String root = dfsConfig.getRoot(); + if (!StringUtils.startsWith(root, "/")) { + throw new GeaflowException("Invalid root config, should start with /"); } + root = StringUtils.removeEnd(root, "/"); + return String.format("%s/%s", root, path); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/FileRefService.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/FileRefService.java index 558138d15..4a1117834 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/FileRefService.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/FileRefService.java @@ -21,5 +21,5 @@ public interface FileRefService { - long getFileRefCount(String jarId, String pluginId); + long getFileRefCount(String jarId, String pluginId); } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/LocalFileClient.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/LocalFileClient.java index e798a073c..32634fae7 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/LocalFileClient.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/LocalFileClient.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.core.service.file; import java.io.InputStream; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.util.FileUtil; import org.apache.geaflow.console.common.util.exception.GeaflowException; @@ -30,54 +31,53 @@ public class LocalFileClient implements RemoteFileClient { - private final String gatewayUrl; + private final String gatewayUrl; - private LocalPluginConfigClass localConfig; + private LocalPluginConfigClass localConfig; - public LocalFileClient(String gatewayUrl) { - this.gatewayUrl = gatewayUrl; - } + public LocalFileClient(String gatewayUrl) { + this.gatewayUrl = gatewayUrl; + } - @Override - public void init(GeaflowPlugin plugin, GeaflowPluginConfig config) { - this.localConfig = config.getConfig().parse(LocalPluginConfigClass.class); - FileUtil.mkdir(localConfig.getRoot()); - } + @Override + public void init(GeaflowPlugin plugin, GeaflowPluginConfig config) { + this.localConfig = config.getConfig().parse(LocalPluginConfigClass.class); + FileUtil.mkdir(localConfig.getRoot()); + } - @Override - public void upload(String path, InputStream inputStream) { - FileUtil.writeFile(getFullPath(path), inputStream); - } + @Override + public void upload(String path, InputStream inputStream) { + FileUtil.writeFile(getFullPath(path), inputStream); + } - @Override - public InputStream download(String path) { - return FileUtil.readFileStream(getFullPath(path)); - } + @Override + public InputStream download(String path) { + return FileUtil.readFileStream(getFullPath(path)); + } - @Override - public void delete(String path) { - FileUtil.delete(getFullPath(path)); - } - - @Override - public String getUrl(String path) { - return GeaflowTask.getTaskFileUrlFormatter(gatewayUrl, getFullPath(path)); - } + @Override + public void delete(String path) { + FileUtil.delete(getFullPath(path)); + } - @Override - public boolean checkFileExists(String path) { - //TODO - return false; - } + @Override + public String getUrl(String path) { + return GeaflowTask.getTaskFileUrlFormatter(gatewayUrl, getFullPath(path)); + } - public String getFullPath(String path) { - String root = localConfig.getRoot(); - if (!StringUtils.startsWith(root, "/")) { - throw new GeaflowException("Invalid root config, should start with /"); - } - root = StringUtils.removeEnd(root, "/"); + @Override + public boolean checkFileExists(String path) { + // TODO + return false; + } - return String.format("%s/%s", root, path); + public String getFullPath(String path) { + String root = localConfig.getRoot(); + if (!StringUtils.startsWith(root, "/")) { + throw new GeaflowException("Invalid root config, should start with /"); } + root = StringUtils.removeEnd(root, "/"); + return String.format("%s/%s", root, path); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/LocalFileFactory.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/LocalFileFactory.java index 2f2fdb4dc..d2b95c70a 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/LocalFileFactory.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/LocalFileFactory.java @@ -21,7 +21,7 @@ import java.io.File; import java.io.InputStream; -import lombok.extern.slf4j.Slf4j; + import org.apache.geaflow.console.common.util.FileUtil; import org.apache.geaflow.console.common.util.Fmt; import org.apache.geaflow.console.common.util.Md5Util; @@ -30,89 +30,90 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; +import lombok.extern.slf4j.Slf4j; + @Slf4j @Component public class LocalFileFactory { - public static final String LOCAL_VERSION_FILE_DIRECTORY = "/tmp/geaflow/local/versions"; + public static final String LOCAL_VERSION_FILE_DIRECTORY = "/tmp/geaflow/local/versions"; - public static final String LOCAL_TASK_FILE_DIRECTORY = "/tmp/geaflow/local/tasks"; + public static final String LOCAL_TASK_FILE_DIRECTORY = "/tmp/geaflow/local/tasks"; - public static final String LOCAL_USER_FILE_DIRECTORY = "/tmp/geaflow/local/users"; + public static final String LOCAL_USER_FILE_DIRECTORY = "/tmp/geaflow/local/users"; - @Autowired - private RemoteFileStorage remoteFileStorage; - - public File getVersionFile(String versionName, GeaflowRemoteFile remoteFile) { - String filePath = getVersionFilePath(versionName, remoteFile.getName()); - return downloadFileWithMd5(remoteFile.getPath(), filePath, remoteFile.getMd5()); - } - - public File getUserFile(String userId, GeaflowRemoteFile remoteFile) { - String filePath = getUserFilePath(userId, remoteFile.getName()); - return downloadFileWithMd5(remoteFile.getPath(), filePath, remoteFile.getMd5()); - } + @Autowired private RemoteFileStorage remoteFileStorage; - public File getTaskUserFile(String runtimeTaskId, GeaflowRemoteFile remoteFile) { - String filePath = getTaskFilePath(runtimeTaskId, remoteFile.getName()); - return downloadFileWithMd5(remoteFile.getPath(), filePath, remoteFile.getMd5()); - } + public File getVersionFile(String versionName, GeaflowRemoteFile remoteFile) { + String filePath = getVersionFilePath(versionName, remoteFile.getName()); + return downloadFileWithMd5(remoteFile.getPath(), filePath, remoteFile.getMd5()); + } - public File getTaskReleaseFile(String runtimeTaskId, String jobId, GeaflowRelease release) { - String path = RemoteFileStorage.getPackageFilePath(jobId, release.getReleaseVersion()); - String filePath = getTaskFilePath(runtimeTaskId, new File(path).getName()); - return downloadFileWithMd5(path, filePath, release.getMd5()); - } + public File getUserFile(String userId, GeaflowRemoteFile remoteFile) { + String filePath = getUserFilePath(userId, remoteFile.getName()); + return downloadFileWithMd5(remoteFile.getPath(), filePath, remoteFile.getMd5()); + } - public static String getVersionFilePath(String versionName, String fileName) { - return Fmt.as("{}/{}/{}", LOCAL_VERSION_FILE_DIRECTORY, versionName, fileName); - } + public File getTaskUserFile(String runtimeTaskId, GeaflowRemoteFile remoteFile) { + String filePath = getTaskFilePath(runtimeTaskId, remoteFile.getName()); + return downloadFileWithMd5(remoteFile.getPath(), filePath, remoteFile.getMd5()); + } - public static String getTaskFilePath(String runtimeTaskId, String fileName) { - return Fmt.as("{}/{}/{}", LOCAL_TASK_FILE_DIRECTORY, runtimeTaskId, fileName); - } + public File getTaskReleaseFile(String runtimeTaskId, String jobId, GeaflowRelease release) { + String path = RemoteFileStorage.getPackageFilePath(jobId, release.getReleaseVersion()); + String filePath = getTaskFilePath(runtimeTaskId, new File(path).getName()); + return downloadFileWithMd5(path, filePath, release.getMd5()); + } - public static String getUserFilePath(String userId, String fileName) { - return Fmt.as("{}/{}/{}", LOCAL_USER_FILE_DIRECTORY, userId, fileName); - } + public static String getVersionFilePath(String versionName, String fileName) { + return Fmt.as("{}/{}/{}", LOCAL_VERSION_FILE_DIRECTORY, versionName, fileName); + } - private File downloadFileWithMd5(String remotePath, String localPath, String md5) { - // check file md5 - if (!md5.equals(loadFileMd5(localPath))) { - // delete local files - FileUtil.delete(getMd5FilePath(localPath)); - FileUtil.delete(localPath); + public static String getTaskFilePath(String runtimeTaskId, String fileName) { + return Fmt.as("{}/{}/{}", LOCAL_TASK_FILE_DIRECTORY, runtimeTaskId, fileName); + } - // download file - downloadFile(remotePath, localPath); + public static String getUserFilePath(String userId, String fileName) { + return Fmt.as("{}/{}/{}", LOCAL_USER_FILE_DIRECTORY, userId, fileName); + } - // save file md5 - FileUtil.writeFile(getMd5FilePath(localPath), Md5Util.encodeFile(localPath)); - } + private File downloadFileWithMd5(String remotePath, String localPath, String md5) { + // check file md5 + if (!md5.equals(loadFileMd5(localPath))) { + // delete local files + FileUtil.delete(getMd5FilePath(localPath)); + FileUtil.delete(localPath); - return new File(localPath); - } + // download file + downloadFile(remotePath, localPath); - private String getMd5FilePath(String filePath) { - return filePath + ".md5"; + // save file md5 + FileUtil.writeFile(getMd5FilePath(localPath), Md5Util.encodeFile(localPath)); } - private String loadFileMd5(String filePath) { - if (!FileUtil.exist(filePath)) { - return null; - } + return new File(localPath); + } - String md5FilePath = getMd5FilePath(filePath); - if (!FileUtil.exist(md5FilePath)) { - return null; - } + private String getMd5FilePath(String filePath) { + return filePath + ".md5"; + } - return FileUtil.readFileContent(md5FilePath).trim(); + private String loadFileMd5(String filePath) { + if (!FileUtil.exist(filePath)) { + return null; } - private void downloadFile(String remotePath, String localPath) { - InputStream stream = remoteFileStorage.download(remotePath); - FileUtil.writeFile(localPath, stream); - log.info("Download file {} from {} success", localPath, remotePath); + String md5FilePath = getMd5FilePath(filePath); + if (!FileUtil.exist(md5FilePath)) { + return null; } + + return FileUtil.readFileContent(md5FilePath).trim(); + } + + private void downloadFile(String remotePath, String localPath) { + InputStream stream = remoteFileStorage.download(remotePath); + FileUtil.writeFile(localPath, stream); + log.info("Download file {} from {} success", localPath, remotePath); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/OssFileClient.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/OssFileClient.java index 41e0c75a4..d989a6497 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/OssFileClient.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/OssFileClient.java @@ -19,12 +19,8 @@ package org.apache.geaflow.console.core.service.file; -import com.aliyun.oss.OSSClient; -import com.aliyun.oss.OSSException; -import com.aliyun.oss.common.auth.DefaultCredentialProvider; -import com.aliyun.oss.model.OSSObject; import java.io.InputStream; -import lombok.extern.slf4j.Slf4j; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.util.NetworkUtil; import org.apache.geaflow.console.common.util.exception.GeaflowException; @@ -32,63 +28,72 @@ import org.apache.geaflow.console.core.model.plugin.config.GeaflowPluginConfig; import org.apache.geaflow.console.core.model.plugin.config.OssPluginConfigClass; -@Slf4j -public class OssFileClient implements RemoteFileClient { - - private OssPluginConfigClass ossConfig; - - private OSSClient ossClient; - - @Override - public void init(GeaflowPlugin plugin, GeaflowPluginConfig config) { - this.ossConfig = config.getConfig().parse(OssPluginConfigClass.class); - this.ossClient = new OSSClient(ossConfig.getEndpoint(), - new DefaultCredentialProvider(ossConfig.getAccessId(), ossConfig.getSecretKey()), null); - } - - @Override - public void upload(String path, InputStream inputStream) { - this.ossClient.putObject(this.ossConfig.getBucket(), getFullPath(path), inputStream); - } +import com.aliyun.oss.OSSClient; +import com.aliyun.oss.OSSException; +import com.aliyun.oss.common.auth.DefaultCredentialProvider; +import com.aliyun.oss.model.OSSObject; - @Override - public InputStream download(String path) throws OSSException { - OSSObject ossObject = this.ossClient.getObject(this.ossConfig.getBucket(), getFullPath(path)); - return ossObject.getObjectContent(); - } +import lombok.extern.slf4j.Slf4j; - @Override - public void delete(String path) { - this.ossClient.deleteObject(this.ossConfig.getBucket(), getFullPath(path)); - } +@Slf4j +public class OssFileClient implements RemoteFileClient { - @Override - public String getUrl(String path) { - return String.format("http://%s.%s/%s", ossConfig.getBucket(), NetworkUtil.getHost(ossConfig.getEndpoint()), - getFullPath(path)); + private OssPluginConfigClass ossConfig; + + private OSSClient ossClient; + + @Override + public void init(GeaflowPlugin plugin, GeaflowPluginConfig config) { + this.ossConfig = config.getConfig().parse(OssPluginConfigClass.class); + this.ossClient = + new OSSClient( + ossConfig.getEndpoint(), + new DefaultCredentialProvider(ossConfig.getAccessId(), ossConfig.getSecretKey()), + null); + } + + @Override + public void upload(String path, InputStream inputStream) { + this.ossClient.putObject(this.ossConfig.getBucket(), getFullPath(path), inputStream); + } + + @Override + public InputStream download(String path) throws OSSException { + OSSObject ossObject = this.ossClient.getObject(this.ossConfig.getBucket(), getFullPath(path)); + return ossObject.getObjectContent(); + } + + @Override + public void delete(String path) { + this.ossClient.deleteObject(this.ossConfig.getBucket(), getFullPath(path)); + } + + @Override + public String getUrl(String path) { + return String.format( + "http://%s.%s/%s", + ossConfig.getBucket(), NetworkUtil.getHost(ossConfig.getEndpoint()), getFullPath(path)); + } + + @Override + public boolean checkFileExists(String path) { + try { + return ossClient.doesObjectExist(ossConfig.getBucket(), getFullPath(path)); + } catch (Exception e) { + log.warn("check oss file failed", e); + return false; } + } - @Override - public boolean checkFileExists(String path) { - try { - return ossClient.doesObjectExist(ossConfig.getBucket(), getFullPath(path)); - } catch (Exception e) { - log.warn("check oss file failed", e); - return false; - } - + public String getFullPath(String path) { + String root = ossConfig.getRoot(); + if (!StringUtils.startsWith(root, "/")) { + throw new GeaflowException("Invalid root config, should start with /"); } - public String getFullPath(String path) { - String root = ossConfig.getRoot(); - if (!StringUtils.startsWith(root, "/")) { - throw new GeaflowException("Invalid root config, should start with /"); - } - - root = StringUtils.removeStart(root, "/"); - root = StringUtils.removeEnd(root, "/"); - - return root.isEmpty() ? path : String.format("%s/%s", root, path); - } + root = StringUtils.removeStart(root, "/"); + root = StringUtils.removeEnd(root, "/"); + return root.isEmpty() ? path : String.format("%s/%s", root, path); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/RemoteFileClient.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/RemoteFileClient.java index 913307ec7..d9625198b 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/RemoteFileClient.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/RemoteFileClient.java @@ -20,20 +20,21 @@ package org.apache.geaflow.console.core.service.file; import java.io.InputStream; + import org.apache.geaflow.console.core.model.plugin.GeaflowPlugin; import org.apache.geaflow.console.core.model.plugin.config.GeaflowPluginConfig; public interface RemoteFileClient { - void init(GeaflowPlugin plugin, GeaflowPluginConfig config); + void init(GeaflowPlugin plugin, GeaflowPluginConfig config); - void upload(String path, InputStream inputStream); + void upload(String path, InputStream inputStream); - InputStream download(String path); + InputStream download(String path); - void delete(String path); + void delete(String path); - String getUrl(String path); + String getUrl(String path); - boolean checkFileExists(String path); + boolean checkFileExists(String path); } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/RemoteFileStorage.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/RemoteFileStorage.java index bc57e3ea8..57eef7db8 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/RemoteFileStorage.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/file/RemoteFileStorage.java @@ -20,7 +20,7 @@ package org.apache.geaflow.console.core.service.file; import java.io.InputStream; -import lombok.extern.slf4j.Slf4j; + import org.apache.geaflow.console.common.util.exception.GeaflowIllegalException; import org.apache.geaflow.console.common.util.type.GeaflowPluginCategory; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; @@ -32,128 +32,128 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; +import lombok.extern.slf4j.Slf4j; + @Slf4j @Component public class RemoteFileStorage { - private static final String GEAFLOW_FILE_DIRECTORY = "geaflow/files"; + private static final String GEAFLOW_FILE_DIRECTORY = "geaflow/files"; - private static final String USER_FILE_PATH_FORMAT = GEAFLOW_FILE_DIRECTORY + "/users/%s/%s"; + private static final String USER_FILE_PATH_FORMAT = GEAFLOW_FILE_DIRECTORY + "/users/%s/%s"; - private static final String VERSION_FILE_PATH_FORMAT = GEAFLOW_FILE_DIRECTORY + "/versions/%s/%s"; + private static final String VERSION_FILE_PATH_FORMAT = GEAFLOW_FILE_DIRECTORY + "/versions/%s/%s"; - private static final String PLUGIN_FILE_PATH_FORMAT = GEAFLOW_FILE_DIRECTORY + "/plugins/%s/%s"; + private static final String PLUGIN_FILE_PATH_FORMAT = GEAFLOW_FILE_DIRECTORY + "/plugins/%s/%s"; - private static final String GEAFLOW_PACKAGE_PATH_FORMAT = "geaflow/packages/%s/release-%s.zip"; + private static final String GEAFLOW_PACKAGE_PATH_FORMAT = "geaflow/packages/%s/release-%s.zip"; - private static final String TASKS_FILE_PATH_FORMAT = "geaflow/packages/%s/udfs/%s"; + private static final String TASKS_FILE_PATH_FORMAT = "geaflow/packages/%s/udfs/%s"; - @Autowired - private PluginService pluginService; + @Autowired private PluginService pluginService; - @Autowired - private PluginConfigService pluginConfigService; + @Autowired private PluginConfigService pluginConfigService; - @Autowired - private DeployConfig deployConfig; + @Autowired private DeployConfig deployConfig; - private volatile RemoteFileClient remoteFileClient; + private volatile RemoteFileClient remoteFileClient; - public static String getUserFilePath(String userId, String fileName) { - return String.format(USER_FILE_PATH_FORMAT, userId, fileName); - } + public static String getUserFilePath(String userId, String fileName) { + return String.format(USER_FILE_PATH_FORMAT, userId, fileName); + } - public static String getVersionFilePath(String versionName, String fileName) { - return String.format(VERSION_FILE_PATH_FORMAT, versionName, fileName); - } + public static String getVersionFilePath(String versionName, String fileName) { + return String.format(VERSION_FILE_PATH_FORMAT, versionName, fileName); + } - public static String getPluginFilePath(String pluginName, String fileName) { - return String.format(PLUGIN_FILE_PATH_FORMAT, pluginName, fileName); - } + public static String getPluginFilePath(String pluginName, String fileName) { + return String.format(PLUGIN_FILE_PATH_FORMAT, pluginName, fileName); + } - public static String getPackageFilePath(String jobId, int releaseVersion) { - return String.format(GEAFLOW_PACKAGE_PATH_FORMAT, jobId, releaseVersion); - } + public static String getPackageFilePath(String jobId, int releaseVersion) { + return String.format(GEAFLOW_PACKAGE_PATH_FORMAT, jobId, releaseVersion); + } - public static String getTaskFilePath(String jobId, String fileName) { - return String.format(TASKS_FILE_PATH_FORMAT, jobId, fileName); - } + public static String getTaskFilePath(String jobId, String fileName) { + return String.format(TASKS_FILE_PATH_FORMAT, jobId, fileName); + } + public String upload(String path, InputStream stream) { + checkRemoteFileClient(); + return upload(path, stream, remoteFileClient); + } - public String upload(String path, InputStream stream) { - checkRemoteFileClient(); - return upload(path, stream, remoteFileClient); - } + public String upload(String path, InputStream stream, RemoteFileClient client) { + String url = client.getUrl(path); + log.info("Start upload file, url={}", url); + client.upload(path, stream); + log.info("Upload success, url={}", url); + return url; + } - public String upload(String path, InputStream stream, RemoteFileClient client) { - String url = client.getUrl(path); - log.info("Start upload file, url={}", url); - client.upload(path, stream); - log.info("Upload success, url={}", url); - return url; - } + public InputStream download(String path) { + checkRemoteFileClient(); + log.info("Start download file, url={}", remoteFileClient.getUrl(path)); + return remoteFileClient.download(path); + } - public InputStream download(String path) { - checkRemoteFileClient(); - log.info("Start download file, url={}", remoteFileClient.getUrl(path)); - return remoteFileClient.download(path); - } + public void delete(String path) { + checkRemoteFileClient(); + log.info("Start delete file, url={}", remoteFileClient.getUrl(path)); + remoteFileClient.delete(path); + } - public void delete(String path) { - checkRemoteFileClient(); - log.info("Start delete file, url={}", remoteFileClient.getUrl(path)); - remoteFileClient.delete(path); - } + public String getUrl(String path) { + checkRemoteFileClient(); + return remoteFileClient.getUrl(path); + } - public String getUrl(String path) { - checkRemoteFileClient(); - return remoteFileClient.getUrl(path); + public void reset() { + if (remoteFileClient != null) { + synchronized (RemoteFileStorage.class) { + remoteFileClient = null; + } } + } - public void reset() { - if (remoteFileClient != null) { - synchronized (RemoteFileStorage.class) { - remoteFileClient = null; - } - } - } + public boolean checkFileExists(String path) { + checkRemoteFileClient(); + return remoteFileClient.checkFileExists(path); + } - public boolean checkFileExists(String path) { - checkRemoteFileClient(); - return remoteFileClient.checkFileExists(path); + private void checkRemoteFileClient() { + if (remoteFileClient != null) { + return; } - private void checkRemoteFileClient() { - if (remoteFileClient != null) { - return; + synchronized (RemoteFileStorage.class) { + if (remoteFileClient == null) { + GeaflowPluginCategory category = GeaflowPluginCategory.REMOTE_FILE; + GeaflowPlugin plugin = pluginService.getDefaultSystemPlugin(category); + GeaflowPluginConfig config = + pluginConfigService.getDefaultPluginConfig(category, plugin.getType()); + + RemoteFileClient client; + switch (GeaflowPluginType.of(config.getType())) { + case LOCAL: + client = new LocalFileClient(deployConfig.getGatewayUrl()); + break; + case OSS: + client = new OssFileClient(); + break; + case DFS: + client = new DfsFileClient(); + break; + default: + throw new GeaflowIllegalException( + "Remote file client type {} not supported", plugin.getType()); } - synchronized (RemoteFileStorage.class) { - if (remoteFileClient == null) { - GeaflowPluginCategory category = GeaflowPluginCategory.REMOTE_FILE; - GeaflowPlugin plugin = pluginService.getDefaultSystemPlugin(category); - GeaflowPluginConfig config = pluginConfigService.getDefaultPluginConfig(category, plugin.getType()); - - RemoteFileClient client; - switch (GeaflowPluginType.of(config.getType())) { - case LOCAL: - client = new LocalFileClient(deployConfig.getGatewayUrl()); - break; - case OSS: - client = new OssFileClient(); - break; - case DFS: - client = new DfsFileClient(); - break; - default: - throw new GeaflowIllegalException("Remote file client type {} not supported", plugin.getType()); - } - - client.init(plugin, config); - remoteFileClient = client; - - log.info("Init remote file {} client success", plugin.getType()); - } - } + client.init(plugin, config); + remoteFileClient = client; + + log.info("Init remote file {} client success", plugin.getType()); + } } + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/CodefuseClient.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/CodefuseClient.java index b7b6d12e9..8d73c6ef7 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/CodefuseClient.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/CodefuseClient.java @@ -19,86 +19,91 @@ package org.apache.geaflow.console.core.service.llm; +import org.apache.geaflow.console.common.util.exception.GeaflowException; +import org.apache.geaflow.console.core.model.llm.CodefuseConfigArgsClass; +import org.apache.geaflow.console.core.model.llm.GeaflowLLM; + import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; + import okhttp3.MediaType; import okhttp3.OkHttpClient; import okhttp3.Request; import okhttp3.RequestBody; import okhttp3.Response; -import org.apache.geaflow.console.common.util.exception.GeaflowException; -import org.apache.geaflow.console.core.model.llm.CodefuseConfigArgsClass; -import org.apache.geaflow.console.core.model.llm.GeaflowLLM; public class CodefuseClient extends LLMClient { - private static final LLMClient INSTANCE = new CodefuseClient(); - - private static final String TEMPLATE = "{\n" - + " \"sceneName\": \"%s\",\n" - + " \"chainName\": \"%s\",\n" - + " \"itemId\": \"gpt\",\n" - + " \"modelEnv\": \"pre\",\n" - + " \"feature\": {\n" - + " \"data\": \"{\\\"api_version\\\":\\\"v2\\\",\\\"out_seq_length\\\":300," - + "\\\"prompts\\\":[{\\\"prompt\\\":[{\\\"content\\\":\\\"%s\\\",\\\"role\\\":\\\"\\\"}]," - + "\\\"repetition_penalty\\\":1.1,\\\"temperature\\\":0.2,\\\"top_k\\\":40,\\\"top_p\\\":0.9}],\\\"stream\\\":false}\"\n" - + " }\n" - + "}"; - - private String getJsonString(CodefuseConfigArgsClass config, String prompt) { - - String sceneName = config.getSceneName(); - String chainName = config.getChainName(); - return String.format(TEMPLATE, sceneName, chainName, prompt); + private static final LLMClient INSTANCE = new CodefuseClient(); + + private static final String TEMPLATE = + "{\n" + + " \"sceneName\": \"%s\",\n" + + " \"chainName\": \"%s\",\n" + + " \"itemId\": \"gpt\",\n" + + " \"modelEnv\": \"pre\",\n" + + " \"feature\": {\n" + + " \"data\":" + + " \"{\\\"api_version\\\":\\\"v2\\\",\\\"out_seq_length\\\":300,\\\"prompts\\\":[{\\\"prompt\\\":[{\\\"content\\\":\\\"%s\\\",\\\"role\\\":\\\"\\\"}],\\\"repetition_penalty\\\":1.1,\\\"temperature\\\":0.2,\\\"top_k\\\":40,\\\"top_p\\\":0.9}],\\\"stream\\\":false}\"\n" + + " }\n" + + "}"; + + private String getJsonString(CodefuseConfigArgsClass config, String prompt) { + + String sceneName = config.getSceneName(); + String chainName = config.getChainName(); + return String.format(TEMPLATE, sceneName, chainName, prompt); + } + + private CodefuseClient() {} + + public static LLMClient getInstance() { + return INSTANCE; + } + + @Override + protected Response sendRequest(GeaflowLLM llm, String prompt) { + try { + CodefuseConfigArgsClass config = getConfig(llm, CodefuseConfigArgsClass.class); + String jsonString = getJsonString(config, prompt); + OkHttpClient client = getHttpClient(config); + MediaType type = MediaType.get("application/json; charset=utf-8"); + RequestBody body = RequestBody.create(jsonString, type); + + Request request = + new Request.Builder() + .url(llm.getUrl()) + .addHeader("Content-Type", "application/json") + .post(body) + .build(); + + Response response = client.newCall(request).execute(); + return response; + + } catch (Exception e) { + throw new RuntimeException(e); } - - private CodefuseClient() { - - } - - public static LLMClient getInstance() { - return INSTANCE; - } - - @Override - protected Response sendRequest(GeaflowLLM llm, String prompt) { - try { - CodefuseConfigArgsClass config = getConfig(llm, CodefuseConfigArgsClass.class); - String jsonString = getJsonString(config, prompt); - OkHttpClient client = getHttpClient(config); - MediaType type = MediaType.get("application/json; charset=utf-8"); - RequestBody body = RequestBody.create(jsonString, type); - - Request request = new Request.Builder() - .url(llm.getUrl()) - .addHeader("Content-Type", "application/json") - .post(body) - .build(); - - Response response = client.newCall(request).execute(); - return response; - - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - protected String parseResult(Response response) { - try { - String res = response.body().string(); - JSONObject jsonObject = JSON.parseObject(res); - if (!jsonObject.getBooleanValue("success")) { - throw new GeaflowException("request failed msg:{}", jsonObject.getString("resultMsg")); - } - JSONObject o = (JSONObject) jsonObject.getJSONObject("data").getJSONArray("items").get(0); - JSONArray array = o.getJSONObject("attributes").getJSONObject("res").getJSONArray("generated_code").getJSONArray(0); - return array.get(0).toString(); - - } catch (Exception e) { - throw new RuntimeException(e); - } + } + + @Override + protected String parseResult(Response response) { + try { + String res = response.body().string(); + JSONObject jsonObject = JSON.parseObject(res); + if (!jsonObject.getBooleanValue("success")) { + throw new GeaflowException("request failed msg:{}", jsonObject.getString("resultMsg")); + } + JSONObject o = (JSONObject) jsonObject.getJSONObject("data").getJSONArray("items").get(0); + JSONArray array = + o.getJSONObject("attributes") + .getJSONObject("res") + .getJSONArray("generated_code") + .getJSONArray(0); + return array.get(0).toString(); + + } catch (Exception e) { + throw new RuntimeException(e); } + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/GraphSchemaTranslator.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/GraphSchemaTranslator.java index 7fc9894db..aec80eb6b 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/GraphSchemaTranslator.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/GraphSchemaTranslator.java @@ -22,9 +22,7 @@ import java.util.HashMap; import java.util.Map; import java.util.stream.Collectors; -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.Setter; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.util.VelocityUtil; import org.apache.geaflow.console.core.model.GeaflowId; @@ -33,45 +31,49 @@ import org.apache.geaflow.console.core.model.data.GeaflowGraph; import org.apache.geaflow.console.core.model.data.GeaflowVertex; -public class GraphSchemaTranslator { - - private static final String VERTEX_TEMPLATE = "template/graphSchema.vm"; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Setter; - private static final String END_FLAG = "END_FLAG"; +public class GraphSchemaTranslator { + private static final String VERTEX_TEMPLATE = "template/graphSchema.vm"; - public static String translateGraphSchema(GeaflowGraph graph) { - HashMap velocityMap = new HashMap<>(); + private static final String END_FLAG = "END_FLAG"; - Map vertexMap = graph.getVertices().values().stream().collect(Collectors.toMap(GeaflowId::getId, e -> e)); - Map edgeMap = graph.getEdges().values().stream().collect(Collectors.toMap(GeaflowId::getId, e -> e)); + public static String translateGraphSchema(GeaflowGraph graph) { + HashMap velocityMap = new HashMap<>(); - Map endpointMap = new HashMap<>(); - // format endpoints - for (GeaflowEndpoint endpoint : graph.getEndpoints()) { - GeaflowEdge edge = edgeMap.get(endpoint.getEdgeId()); - GeaflowVertex source = vertexMap.get(endpoint.getSourceId()); - GeaflowVertex target = vertexMap.get(endpoint.getTargetId()); - if (edge != null && source != null && target != null) { - endpointMap.put(edge.getName(), new NameEndpoint(source.getName(), target.getName())); - } - } + Map vertexMap = + graph.getVertices().values().stream().collect(Collectors.toMap(GeaflowId::getId, e -> e)); + Map edgeMap = + graph.getEdges().values().stream().collect(Collectors.toMap(GeaflowId::getId, e -> e)); - velocityMap.put("vertices", graph.getVertices().values()); - velocityMap.put("edges", graph.getEdges().values()); - velocityMap.put("endpoints", endpointMap); - velocityMap.put("endFlag", END_FLAG); - String s = VelocityUtil.applyResource(VERTEX_TEMPLATE, velocityMap); - return StringUtils.replacePattern(s, ",\n*\\s*" + END_FLAG, ""); + Map endpointMap = new HashMap<>(); + // format endpoints + for (GeaflowEndpoint endpoint : graph.getEndpoints()) { + GeaflowEdge edge = edgeMap.get(endpoint.getEdgeId()); + GeaflowVertex source = vertexMap.get(endpoint.getSourceId()); + GeaflowVertex target = vertexMap.get(endpoint.getTargetId()); + if (edge != null && source != null && target != null) { + endpointMap.put(edge.getName(), new NameEndpoint(source.getName(), target.getName())); + } } + velocityMap.put("vertices", graph.getVertices().values()); + velocityMap.put("edges", graph.getEdges().values()); + velocityMap.put("endpoints", endpointMap); + velocityMap.put("endFlag", END_FLAG); + String s = VelocityUtil.applyResource(VERTEX_TEMPLATE, velocityMap); + return StringUtils.replacePattern(s, ",\n*\\s*" + END_FLAG, ""); + } - @AllArgsConstructor - @Setter - @Getter - public static class NameEndpoint { + @AllArgsConstructor + @Setter + @Getter + public static class NameEndpoint { - String sourceName; - String targetName; - } + String sourceName; + String targetName; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/LLMClient.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/LLMClient.java index 3a2918e2c..75da87076 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/LLMClient.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/LLMClient.java @@ -21,71 +21,74 @@ import java.util.Optional; import java.util.concurrent.TimeUnit; -import lombok.extern.slf4j.Slf4j; -import okhttp3.OkHttpClient; -import okhttp3.Response; + import org.apache.geaflow.console.common.util.exception.GeaflowException; import org.apache.geaflow.console.core.model.config.GeaflowConfig; import org.apache.geaflow.console.core.model.llm.GeaflowLLM; import org.apache.geaflow.console.core.model.llm.LLMConfigArgsClass; +import lombok.extern.slf4j.Slf4j; +import okhttp3.OkHttpClient; +import okhttp3.Response; + @Slf4j public abstract class LLMClient { - public final String call(GeaflowLLM llm, String prompt) { - try { - LLMConfigArgsClass config = getConfig(llm, LLMConfigArgsClass.class); - int retryTimes = config.getRetryTimes(); - int retryInterval = config.getRetryInterval(); + public final String call(GeaflowLLM llm, String prompt) { + try { + LLMConfigArgsClass config = getConfig(llm, LLMConfigArgsClass.class); + int retryTimes = config.getRetryTimes(); + int retryInterval = config.getRetryInterval(); - return callWithRetry(llm, prompt, retryTimes, retryInterval); + return callWithRetry(llm, prompt, retryTimes, retryInterval); - } catch (Exception e) { - log.info("Call language model failed", e); - throw new GeaflowException("Call language model failed {}", e.getMessage()); - } + } catch (Exception e) { + log.info("Call language model failed", e); + throw new GeaflowException("Call language model failed {}", e.getMessage()); } + } - protected T getConfig(GeaflowLLM llm, Class clazz) { - GeaflowConfig geaflowConfig = Optional.ofNullable(llm.getArgs()).orElse(new GeaflowConfig()); - T config = geaflowConfig.parse(clazz, true); - return config; - } + protected T getConfig(GeaflowLLM llm, Class clazz) { + GeaflowConfig geaflowConfig = Optional.ofNullable(llm.getArgs()).orElse(new GeaflowConfig()); + T config = geaflowConfig.parse(clazz, true); + return config; + } - private String callWithRetry(GeaflowLLM llm, String prompt, int retryTimes, int retryInterval) { - for (int i = 0; i < retryTimes; i++) { - try { - Response response = sendRequest(llm, prompt); - return parseResult(response); - } catch (Exception e) { - if (i == retryTimes - 1) { - throw e; - } - - try { - Thread.sleep(retryInterval * 1000); - } catch (InterruptedException ex) { - throw new RuntimeException(ex); - } - } + private String callWithRetry(GeaflowLLM llm, String prompt, int retryTimes, int retryInterval) { + for (int i = 0; i < retryTimes; i++) { + try { + Response response = sendRequest(llm, prompt); + return parseResult(response); + } catch (Exception e) { + if (i == retryTimes - 1) { + throw e; } - return null; + try { + Thread.sleep(retryInterval * 1000); + } catch (InterruptedException ex) { + throw new RuntimeException(ex); + } + } } - protected abstract Response sendRequest(GeaflowLLM llm, String prompt); + return null; + } - protected abstract String parseResult(Response prompt); + protected abstract Response sendRequest(GeaflowLLM llm, String prompt); - protected OkHttpClient getHttpClient(LLMConfigArgsClass config) { - long connectTimeout = config.getConnectTimeout(); - long readTimeOut = config.getReadTimeout(); - long writeTimeOut = config.getWriteTimeout(); + protected abstract String parseResult(Response prompt); - return new OkHttpClient().newBuilder() - .connectTimeout(connectTimeout, TimeUnit.SECONDS) - .writeTimeout(readTimeOut, TimeUnit.SECONDS) - .readTimeout(writeTimeOut, TimeUnit.SECONDS) - .build(); - } + protected OkHttpClient getHttpClient(LLMConfigArgsClass config) { + long connectTimeout = config.getConnectTimeout(); + long readTimeOut = config.getReadTimeout(); + long writeTimeOut = config.getWriteTimeout(); + + return new OkHttpClient() + .newBuilder() + .connectTimeout(connectTimeout, TimeUnit.SECONDS) + .writeTimeout(readTimeOut, TimeUnit.SECONDS) + .readTimeout(writeTimeOut, TimeUnit.SECONDS) + .build(); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/LLMClientFactory.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/LLMClientFactory.java index a2d3846bc..146a3ac38 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/LLMClientFactory.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/LLMClientFactory.java @@ -26,15 +26,14 @@ @Component public class LLMClientFactory { - public LLMClient getLLMClient(GeaflowLLMType type) { - switch (type) { - case OPEN_AI: - return OpenAiClient.getInstance(); - case LOCAL: - return LocalClient.getInstance(); - default: - throw new GeaflowException("Unsupported LLM type, {}", type); - } - + public LLMClient getLLMClient(GeaflowLLMType type) { + switch (type) { + case OPEN_AI: + return OpenAiClient.getInstance(); + case LOCAL: + return LocalClient.getInstance(); + default: + throw new GeaflowException("Unsupported LLM type, {}", type); } + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/LocalClient.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/LocalClient.java index 8f23d4c54..e287b86f6 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/LocalClient.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/LocalClient.java @@ -19,78 +19,79 @@ package org.apache.geaflow.console.core.service.llm; +import org.apache.geaflow.console.core.model.llm.GeaflowLLM; +import org.apache.geaflow.console.core.model.llm.LocalConfigArgsClass; + import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.google.common.annotations.VisibleForTesting; + import okhttp3.MediaType; import okhttp3.OkHttpClient; import okhttp3.Request; import okhttp3.RequestBody; import okhttp3.Response; -import org.apache.geaflow.console.core.model.llm.GeaflowLLM; -import org.apache.geaflow.console.core.model.llm.LocalConfigArgsClass; public class LocalClient extends LLMClient { - private static final LLMClient INSTANCE = new LocalClient(); + private static final LLMClient INSTANCE = new LocalClient(); - private static final int DEFAULT_N_PREDICT = 128; + private static final int DEFAULT_N_PREDICT = 128; - @VisibleForTesting - public String getJsonString(LocalConfigArgsClass llm, String prompt) { - Integer predict = llm.getPredict(); - int nPredict = (predict != null) ? predict : DEFAULT_N_PREDICT; - JSONObject root = new JSONObject(); - root.put("prompt", prompt.trim()); - root.put("n_predict", nPredict); - return root.toJSONString(); - } + @VisibleForTesting + public String getJsonString(LocalConfigArgsClass llm, String prompt) { + Integer predict = llm.getPredict(); + int nPredict = (predict != null) ? predict : DEFAULT_N_PREDICT; + JSONObject root = new JSONObject(); + root.put("prompt", prompt.trim()); + root.put("n_predict", nPredict); + return root.toJSONString(); + } - private LocalClient() { + private LocalClient() {} - } + public static LLMClient getInstance() { + return INSTANCE; + } - public static LLMClient getInstance() { - return INSTANCE; - } + @Override + protected Response sendRequest(GeaflowLLM llm, String prompt) { + try { + LocalConfigArgsClass config = getConfig(llm, LocalConfigArgsClass.class); + String jsonString = getJsonString(config, prompt); + OkHttpClient client = getHttpClient(config); + MediaType type = MediaType.get("application/json; charset=utf-8"); + RequestBody body = RequestBody.create(jsonString, type); - @Override - protected Response sendRequest(GeaflowLLM llm, String prompt) { - try { - LocalConfigArgsClass config = getConfig(llm, LocalConfigArgsClass.class); - String jsonString = getJsonString(config, prompt); - OkHttpClient client = getHttpClient(config); - MediaType type = MediaType.get("application/json; charset=utf-8"); - RequestBody body = RequestBody.create(jsonString, type); - - Request request = new Request.Builder() - .url(llm.getUrl()) - .addHeader("Content-Type", "application/json") - .addHeader("Cache-Control", "no-cache") - .post(body) - .build(); - - Response response = client.newCall(request).execute(); - return response; - - } catch (Exception e) { - throw new RuntimeException(e); - } + Request request = + new Request.Builder() + .url(llm.getUrl()) + .addHeader("Content-Type", "application/json") + .addHeader("Cache-Control", "no-cache") + .post(body) + .build(); + + Response response = client.newCall(request).execute(); + return response; + + } catch (Exception e) { + throw new RuntimeException(e); } + } + + @Override + protected String parseResult(Response response) { + try { + String res = response.body().string(); + JSONObject jsonObject = JSON.parseObject(res); + JSONObject error = jsonObject.getJSONObject("error"); + if (error != null) { + return error.getString("message"); + } + return jsonObject.getString("content"); - @Override - protected String parseResult(Response response) { - try { - String res = response.body().string(); - JSONObject jsonObject = JSON.parseObject(res); - JSONObject error = jsonObject.getJSONObject("error"); - if (error != null) { - return error.getString("message"); - } - return jsonObject.getString("content"); - - } catch (Exception e) { - throw new RuntimeException(e); - } + } catch (Exception e) { + throw new RuntimeException(e); } + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/OpenAiClient.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/OpenAiClient.java index 2c4e5b148..f12f56008 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/OpenAiClient.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/llm/OpenAiClient.java @@ -19,90 +19,89 @@ package org.apache.geaflow.console.core.service.llm; +import java.util.Optional; +import java.util.function.Supplier; + +import org.apache.geaflow.console.common.util.exception.GeaflowException; +import org.apache.geaflow.console.core.model.llm.GeaflowLLM; +import org.apache.geaflow.console.core.model.llm.OpenAIConfigArgsClass; + import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; -import java.util.Optional; -import java.util.function.Supplier; + import okhttp3.MediaType; import okhttp3.OkHttpClient; import okhttp3.Request; import okhttp3.RequestBody; import okhttp3.Response; -import org.apache.geaflow.console.common.util.exception.GeaflowException; -import org.apache.geaflow.console.core.model.llm.GeaflowLLM; -import org.apache.geaflow.console.core.model.llm.OpenAIConfigArgsClass; public class OpenAiClient extends LLMClient { - private static final LLMClient INSTANCE = new OpenAiClient(); + private static final LLMClient INSTANCE = new OpenAiClient(); - private static final String TEMPLATE = "{" - + "\"model\":\"%s\"," - + "\"messages\": [{\"role\": \"user\", \"content\": \"%s\"}]}"; + private static final String TEMPLATE = + "{" + "\"model\":\"%s\"," + "\"messages\": [{\"role\": \"user\", \"content\": \"%s\"}]}"; - private String getJsonString(OpenAIConfigArgsClass config, String prompt) { - return String.format(TEMPLATE, config.getModelId(), prompt.trim()); - } + private String getJsonString(OpenAIConfigArgsClass config, String prompt) { + return String.format(TEMPLATE, config.getModelId(), prompt.trim()); + } - private OpenAiClient() { + private OpenAiClient() {} - } - - public static LLMClient getInstance() { - return INSTANCE; - } + public static LLMClient getInstance() { + return INSTANCE; + } - @Override - protected Response sendRequest(GeaflowLLM llm, String prompt) { - try { + @Override + protected Response sendRequest(GeaflowLLM llm, String prompt) { + try { - OpenAIConfigArgsClass config = getConfig(llm, OpenAIConfigArgsClass.class); + OpenAIConfigArgsClass config = getConfig(llm, OpenAIConfigArgsClass.class); - String jsonString = getJsonString(config, prompt); - OkHttpClient client = getHttpClient(config); - MediaType type = MediaType.get("application/json; charset=utf-8"); - RequestBody body = RequestBody.create(jsonString, type); + String jsonString = getJsonString(config, prompt); + OkHttpClient client = getHttpClient(config); + MediaType type = MediaType.get("application/json; charset=utf-8"); + RequestBody body = RequestBody.create(jsonString, type); - Request request = new Request.Builder() - .url(llm.getUrl()) - .addHeader("Content-Type", "application/json") - .addHeader("Authorization", "Bearer " + config.getApiKey()) - .addHeader("Cache-Control", "no-cache") - .post(body) - .build(); + Request request = + new Request.Builder() + .url(llm.getUrl()) + .addHeader("Content-Type", "application/json") + .addHeader("Authorization", "Bearer " + config.getApiKey()) + .addHeader("Cache-Control", "no-cache") + .post(body) + .build(); - Response response = client.newCall(request).execute(); - return response; + Response response = client.newCall(request).execute(); + return response; - } catch (Exception e) { - throw new RuntimeException(e); - } + } catch (Exception e) { + throw new RuntimeException(e); } - - @Override - protected String parseResult(Response response) { - try { - String res = response.body().string(); - JSONObject jsonObject = JSON.parseObject(res); - - JSONArray choices = jsonObject.getJSONArray("choices"); - if (choices == null) { - return Optional.ofNullable(jsonObject.getJSONObject("error")) - .map(e -> e.getString("message")) - .orElseThrow((Supplier) () -> new GeaflowException("request failed")); - } - - JSONObject choice = choices.getJSONObject(0); - if (!choice.getString("finish_reason").equals("stop")) { - throw new GeaflowException("request failed"); - } - return choice - .getJSONObject("message") - .getString("content"); - - } catch (Throwable e) { - throw new RuntimeException(e); - } + } + + @Override + protected String parseResult(Response response) { + try { + String res = response.body().string(); + JSONObject jsonObject = JSON.parseObject(res); + + JSONArray choices = jsonObject.getJSONArray("choices"); + if (choices == null) { + return Optional.ofNullable(jsonObject.getJSONObject("error")) + .map(e -> e.getString("message")) + .orElseThrow((Supplier) () -> new GeaflowException("request failed")); + } + + JSONObject choice = choices.getJSONObject(0); + if (!choice.getString("finish_reason").equals("stop")) { + throw new GeaflowException("request failed"); + } + return choice.getJSONObject("message").getString("content"); + + } catch (Throwable e) { + throw new RuntimeException(e); } + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/ClusterConfigBuilder.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/ClusterConfigBuilder.java index 9e5f8189f..d8d8b7c08 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/ClusterConfigBuilder.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/ClusterConfigBuilder.java @@ -19,58 +19,62 @@ package org.apache.geaflow.console.core.service.release; -import lombok.extern.slf4j.Slf4j; import org.apache.geaflow.console.core.model.config.GeaflowConfig; import org.apache.geaflow.console.core.model.job.config.ClusterConfigClass; import org.apache.geaflow.console.core.model.job.config.ServeJobConfigClass; import org.apache.geaflow.console.core.model.release.GeaflowRelease; import org.springframework.stereotype.Component; +import lombok.extern.slf4j.Slf4j; + @Slf4j @Component public class ClusterConfigBuilder { - private static ClusterConfigClass getDefaultClusterConfig() { - ClusterConfigClass clusterConfig = new ClusterConfigClass(); - clusterConfig.setContainers(1); - clusterConfig.setContainerWorkers(5); - clusterConfig.setContainerMemory(1024); - clusterConfig.setContainerCores(1.5); - clusterConfig.setContainerJvmOptions("-Xmx512m,-Xms512m,-Xmn256m"); + private static ClusterConfigClass getDefaultClusterConfig() { + ClusterConfigClass clusterConfig = new ClusterConfigClass(); + clusterConfig.setContainers(1); + clusterConfig.setContainerWorkers(5); + clusterConfig.setContainerMemory(1024); + clusterConfig.setContainerCores(1.5); + clusterConfig.setContainerJvmOptions("-Xmx512m,-Xms512m,-Xmn256m"); - clusterConfig.setClientMemory(1024); - clusterConfig.setMasterMemory(1024); - clusterConfig.setDriverMemory(1024); + clusterConfig.setClientMemory(1024); + clusterConfig.setMasterMemory(1024); + clusterConfig.setDriverMemory(1024); - clusterConfig.setClientCores(0.5); - clusterConfig.setMasterCores(1.0); - clusterConfig.setDriverCores(1.0); + clusterConfig.setClientCores(0.5); + clusterConfig.setMasterCores(1.0); + clusterConfig.setDriverCores(1.0); - clusterConfig.setClientJvmOptions("-Xmx512m,-Xms512m,-Xmn256m,-Xss512k,-XX:MaxDirectMemorySize=128m"); - clusterConfig.setMasterJvmOptions("-Xmx512m,-Xms512m,-Xmn256m,-Xss512k,-XX:MaxDirectMemorySize=128m"); - clusterConfig.setDriverJvmOptions("-Xmx512m,-Xms512m,-Xmn256m,-Xss512k,-XX:MaxDirectMemorySize=128m"); - - return clusterConfig; - } + clusterConfig.setClientJvmOptions( + "-Xmx512m,-Xms512m,-Xmn256m,-Xss512k,-XX:MaxDirectMemorySize=128m"); + clusterConfig.setMasterJvmOptions( + "-Xmx512m,-Xms512m,-Xmn256m,-Xss512k,-XX:MaxDirectMemorySize=128m"); + clusterConfig.setDriverJvmOptions( + "-Xmx512m,-Xms512m,-Xmn256m,-Xss512k,-XX:MaxDirectMemorySize=128m"); - public static GeaflowConfig buildDefaultConfig(GeaflowRelease release) { - ClusterConfigClass configClass = getDefaultClusterConfig(); - switch (release.getJob().getType()) { - case SERVE: - ServeJobConfigClass jobConfig = release.getJobConfig().parse(ServeJobConfigClass.class); - Integer driverNum = jobConfig.getDriverNum(); - Integer queryParallelism = jobConfig.getQueryParallelism(); - configClass.setContainerWorkers(driverNum * queryParallelism + 1); - configClass.setContainerMemory(1200); - configClass.setContainerJvmOptions("-Xmx800m,-Xms800m,-Xmn256m"); - configClass.setDriverMemory(800); - configClass.setDriverJvmOptions("-Xmx400m,-Xms400m,-Xmn150m,-Xss512k,-XX:MaxDirectMemorySize=128m"); - break; - default: - break; - } + return clusterConfig; + } - return configClass.build(); + public static GeaflowConfig buildDefaultConfig(GeaflowRelease release) { + ClusterConfigClass configClass = getDefaultClusterConfig(); + switch (release.getJob().getType()) { + case SERVE: + ServeJobConfigClass jobConfig = release.getJobConfig().parse(ServeJobConfigClass.class); + Integer driverNum = jobConfig.getDriverNum(); + Integer queryParallelism = jobConfig.getQueryParallelism(); + configClass.setContainerWorkers(driverNum * queryParallelism + 1); + configClass.setContainerMemory(1200); + configClass.setContainerJvmOptions("-Xmx800m,-Xms800m,-Xmn256m"); + configClass.setDriverMemory(800); + configClass.setDriverJvmOptions( + "-Xmx400m,-Xms400m,-Xmn150m,-Xss512k,-XX:MaxDirectMemorySize=128m"); + break; + default: + break; } + return configClass.build(); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GeaflowBuildPipeline.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GeaflowBuildPipeline.java index 0ea3b83ce..3d0412257 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GeaflowBuildPipeline.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GeaflowBuildPipeline.java @@ -21,7 +21,9 @@ import java.util.ArrayList; import java.util.List; + import javax.annotation.PostConstruct; + import org.apache.geaflow.console.core.model.job.GeaflowJob; import org.apache.geaflow.console.core.model.release.GeaflowRelease; import org.apache.geaflow.console.core.model.release.ReleaseUpdate; @@ -31,58 +33,49 @@ @Component public class GeaflowBuildPipeline { - private static final List STAGES = new ArrayList<>(); - - @Autowired - private ResolveReleaseVersionStage resolveReleaseVersionStage; + private static final List STAGES = new ArrayList<>(); - @Autowired - private ResolveVersionStage resolveVersionStage; + @Autowired private ResolveReleaseVersionStage resolveReleaseVersionStage; - @Autowired - private GenerateJobPlanStage generateJobPlanStage; + @Autowired private ResolveVersionStage resolveVersionStage; - @Autowired - private GenerateJobConfigStage generateJobConfigStage; + @Autowired private GenerateJobPlanStage generateJobPlanStage; - @Autowired - private GenerateClusterConfigStage generateClusterConfigStage; + @Autowired private GenerateJobConfigStage generateJobConfigStage; - @Autowired - private ResolveClusterStage resolveClusterStage; + @Autowired private GenerateClusterConfigStage generateClusterConfigStage; - @Autowired - private PackageStage packageStage; + @Autowired private ResolveClusterStage resolveClusterStage; - @PostConstruct - public void init() { - STAGES.add(resolveReleaseVersionStage); - STAGES.add(resolveVersionStage); - STAGES.add(generateJobPlanStage); - STAGES.add(generateJobConfigStage); - STAGES.add(generateClusterConfigStage); - STAGES.add(resolveClusterStage); - STAGES.add(packageStage); - } + @Autowired private PackageStage packageStage; + @PostConstruct + public void init() { + STAGES.add(resolveReleaseVersionStage); + STAGES.add(resolveVersionStage); + STAGES.add(generateJobPlanStage); + STAGES.add(generateJobConfigStage); + STAGES.add(generateClusterConfigStage); + STAGES.add(resolveClusterStage); + STAGES.add(packageStage); + } - public static GeaflowRelease build(GeaflowJob job) { - GeaflowRelease release = new GeaflowRelease(); - release.setJob(job); - for (GeaflowBuildStage stage : STAGES) { - stage.init(release); - } - return release; + public static GeaflowRelease build(GeaflowJob job) { + GeaflowRelease release = new GeaflowRelease(); + release.setJob(job); + for (GeaflowBuildStage stage : STAGES) { + stage.init(release); } - - public static void update(GeaflowRelease release, ReleaseUpdate update) { - boolean initNext = true; - for (GeaflowBuildStage stage : STAGES) { - if (initNext) { - stage.init(release); - } - initNext = stage.update(release, update); - } + return release; + } + + public static void update(GeaflowRelease release, ReleaseUpdate update) { + boolean initNext = true; + for (GeaflowBuildStage stage : STAGES) { + if (initNext) { + stage.init(release); + } + initNext = stage.update(release, update); } - + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GeaflowBuildStage.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GeaflowBuildStage.java index 926c0ba1f..a8a386157 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GeaflowBuildStage.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GeaflowBuildStage.java @@ -24,8 +24,7 @@ public abstract class GeaflowBuildStage { - public abstract void init(GeaflowRelease release); - - public abstract boolean update(GeaflowRelease release, ReleaseUpdate update); + public abstract void init(GeaflowRelease release); + public abstract boolean update(GeaflowRelease release, ReleaseUpdate update); } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GeaflowReleaseBuilder.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GeaflowReleaseBuilder.java index b1371d7f4..f24b20a41 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GeaflowReleaseBuilder.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GeaflowReleaseBuilder.java @@ -25,13 +25,12 @@ public class GeaflowReleaseBuilder { - public static GeaflowRelease build(GeaflowJob job) { - return GeaflowBuildPipeline.build(job); - } - - public static GeaflowRelease update(GeaflowRelease release, ReleaseUpdate update) { - GeaflowBuildPipeline.update(release, update); - return release; - } + public static GeaflowRelease build(GeaflowJob job) { + return GeaflowBuildPipeline.build(job); + } + public static GeaflowRelease update(GeaflowRelease release, ReleaseUpdate update) { + GeaflowBuildPipeline.update(release, update); + return release; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GenerateClusterConfigStage.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GenerateClusterConfigStage.java index fd952383f..8ffba751c 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GenerateClusterConfigStage.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GenerateClusterConfigStage.java @@ -29,29 +29,28 @@ @Component public class GenerateClusterConfigStage extends GeaflowBuildStage { - @Autowired - private ClusterConfigBuilder clusterConfigBuilder; + @Autowired private ClusterConfigBuilder clusterConfigBuilder; - @Override - public void init(GeaflowRelease release) { - // generate default cluster config from job config - GeaflowConfig clusterConfig = clusterConfigBuilder.buildDefaultConfig(release); + @Override + public void init(GeaflowRelease release) { + // generate default cluster config from job config + GeaflowConfig clusterConfig = clusterConfigBuilder.buildDefaultConfig(release); - if (release.getClusterConfig() != null) { - clusterConfig.putAll(release.getClusterConfig()); - } - - release.setClusterConfig(clusterConfig); + if (release.getClusterConfig() != null) { + clusterConfig.putAll(release.getClusterConfig()); } - @Override - public boolean update(GeaflowRelease release, ReleaseUpdate update) { - GeaflowConfig newClusterConfig = update.getNewClusterConfig(); - if (MapUtils.isEmpty(newClusterConfig)) { - return false; - } + release.setClusterConfig(clusterConfig); + } - release.setClusterConfig(newClusterConfig); - return true; + @Override + public boolean update(GeaflowRelease release, ReleaseUpdate update) { + GeaflowConfig newClusterConfig = update.getNewClusterConfig(); + if (MapUtils.isEmpty(newClusterConfig)) { + return false; } + + release.setClusterConfig(newClusterConfig); + return true; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GenerateJobConfigStage.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GenerateJobConfigStage.java index 440cd628c..345831e83 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GenerateJobConfigStage.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GenerateJobConfigStage.java @@ -29,30 +29,27 @@ @Component public class GenerateJobConfigStage extends GeaflowBuildStage { - @Autowired - private JobConfigBuilder jobConfigBuilder; + @Autowired private JobConfigBuilder jobConfigBuilder; - @Override - public void init(GeaflowRelease release) { - // generate default job config from job plan - GeaflowConfig jobConfig = jobConfigBuilder.buildDefaultConfig(release); - - if (release.getJobConfig() != null) { - jobConfig.putAll(release.getJobConfig()); - } - - release.setJobConfig(jobConfig); + @Override + public void init(GeaflowRelease release) { + // generate default job config from job plan + GeaflowConfig jobConfig = jobConfigBuilder.buildDefaultConfig(release); + if (release.getJobConfig() != null) { + jobConfig.putAll(release.getJobConfig()); } + release.setJobConfig(jobConfig); + } - @Override - public boolean update(GeaflowRelease release, ReleaseUpdate update) { - GeaflowConfig newJobConfig = update.getNewJobConfig(); - if (MapUtils.isEmpty(newJobConfig)) { - return false; - } - release.setJobConfig(newJobConfig); - return true; + @Override + public boolean update(GeaflowRelease release, ReleaseUpdate update) { + GeaflowConfig newJobConfig = update.getNewJobConfig(); + if (MapUtils.isEmpty(newJobConfig)) { + return false; } + release.setJobConfig(newJobConfig); + return true; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GenerateJobPlanStage.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GenerateJobPlanStage.java index f33aae4e5..743ceb7af 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GenerateJobPlanStage.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/GenerateJobPlanStage.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.core.service.release; import java.util.Map; + import org.apache.geaflow.console.common.service.integration.engine.CompileResult; import org.apache.geaflow.console.core.model.code.GeaflowCode; import org.apache.geaflow.console.core.model.job.GeaflowJob; @@ -37,57 +38,56 @@ @Component public class GenerateJobPlanStage extends GeaflowBuildStage { - @Autowired - private ReleaseService releaseService; + @Autowired private ReleaseService releaseService; - @Autowired - private JobService jobService; + @Autowired private JobService jobService; - public void init(GeaflowRelease release) { - // Hla Jobs don't need compile - GeaflowJob job = release.getJob(); - if (job.isApiJob()) { - return; - } - // Generate code for transferJobs - generateCode(job); - GeaflowVersion version = release.getVersion(); - release.setJobPlan(compileJobPlan(job, version, null)); + public void init(GeaflowRelease release) { + // Hla Jobs don't need compile + GeaflowJob job = release.getJob(); + if (job.isApiJob()) { + return; } + // Generate code for transferJobs + generateCode(job); + GeaflowVersion version = release.getVersion(); + release.setJobPlan(compileJobPlan(job, version, null)); + } - @Override - public boolean update(GeaflowRelease release, ReleaseUpdate update) { - GeaflowJob job = release.getJob(); - if (job.isApiJob()) { - return false; - } + @Override + public boolean update(GeaflowRelease release, ReleaseUpdate update) { + GeaflowJob job = release.getJob(); + if (job.isApiJob()) { + return false; + } - Map newParallelisms = update.getNewParallelisms(); - if (newParallelisms == null) { - return false; - } + Map newParallelisms = update.getNewParallelisms(); + if (newParallelisms == null) { + return false; + } - JobPlanBuilder.setParallelisms(release.getJobPlan(), newParallelisms); + JobPlanBuilder.setParallelisms(release.getJobPlan(), newParallelisms); - generateCode(job); + generateCode(job); - GeaflowVersion version = release.getVersion(); - JobPlan newJobPlan = compileJobPlan(job, version, newParallelisms); - release.setJobPlan(newJobPlan); + GeaflowVersion version = release.getVersion(); + JobPlan newJobPlan = compileJobPlan(job, version, newParallelisms); + release.setJobPlan(newJobPlan); - return true; - } + return true; + } - private JobPlan compileJobPlan(GeaflowJob job, GeaflowVersion version, Map parallelisms) { - CompileResult compileResult = releaseService.compile(job, version, parallelisms); - return JobPlanBuilder.build(compileResult.getPhysicPlan()); - } + private JobPlan compileJobPlan( + GeaflowJob job, GeaflowVersion version, Map parallelisms) { + CompileResult compileResult = releaseService.compile(job, version, parallelisms); + return JobPlanBuilder.build(compileResult.getPhysicPlan()); + } - private void generateCode(GeaflowJob job) { - if (job instanceof GeaflowTransferJob) { - GeaflowCode geaflowCode = ((GeaflowTransferJob) job).generateCode(); - ((GeaflowTransferJob) job).setUserCode(geaflowCode.getText()); - jobService.update(job); - } + private void generateCode(GeaflowJob job) { + if (job instanceof GeaflowTransferJob) { + GeaflowCode geaflowCode = ((GeaflowTransferJob) job).generateCode(); + ((GeaflowTransferJob) job).setUserCode(geaflowCode.getText()); + jobService.update(job); } + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/JobConfigBuilder.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/JobConfigBuilder.java index 1abb6e493..5cd4a5fbd 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/JobConfigBuilder.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/JobConfigBuilder.java @@ -19,7 +19,6 @@ package org.apache.geaflow.console.core.service.release; -import lombok.extern.slf4j.Slf4j; import org.apache.geaflow.console.core.model.config.GeaflowConfig; import org.apache.geaflow.console.core.model.data.GeaflowGraph; import org.apache.geaflow.console.core.model.job.GeaflowJob; @@ -29,44 +28,46 @@ import org.apache.geaflow.console.core.model.release.GeaflowRelease; import org.springframework.stereotype.Component; +import lombok.extern.slf4j.Slf4j; + @Slf4j @Component public class JobConfigBuilder { - public GeaflowConfig buildDefaultConfig(GeaflowRelease release) { - JobConfigClass configClass; - GeaflowJob job = release.getJob(); - switch (job.getType()) { - case SERVE: - configClass = initServeJobConfigClass(release); - break; - case INTEGRATE: - configClass = new CodeJobConfigClass(); - ((CodeJobConfigClass) configClass).setWindowSize(-1); - break; - default: - configClass = new JobConfigClass(); - break; - } - - return configClass.build(); + public GeaflowConfig buildDefaultConfig(GeaflowRelease release) { + JobConfigClass configClass; + GeaflowJob job = release.getJob(); + switch (job.getType()) { + case SERVE: + configClass = initServeJobConfigClass(release); + break; + case INTEGRATE: + configClass = new CodeJobConfigClass(); + ((CodeJobConfigClass) configClass).setWindowSize(-1); + break; + default: + configClass = new JobConfigClass(); + break; } - private JobConfigClass initServeJobConfigClass(GeaflowRelease release) { - GeaflowJob job = release.getJob(); - GeaflowGraph graph = job.getGraphs().get(0); + return configClass.build(); + } - ServeJobConfigClass configClass = new ServeJobConfigClass(); - configClass.setJobMode("OLAP_SERVICE"); - configClass.setServiceShareEnable(true); - configClass.setGraphName(graph.getName()); + private JobConfigClass initServeJobConfigClass(GeaflowRelease release) { + GeaflowJob job = release.getJob(); + GeaflowGraph graph = job.getGraphs().get(0); - int shardCount = graph.getShardCount(); - int driverNum = 1; - int queryParallelism = shardCount % driverNum == 0 ? shardCount / driverNum : - shardCount / driverNum + 1; - configClass.setQueryParallelism(queryParallelism); - configClass.setDriverNum(driverNum); - return configClass; - } + ServeJobConfigClass configClass = new ServeJobConfigClass(); + configClass.setJobMode("OLAP_SERVICE"); + configClass.setServiceShareEnable(true); + configClass.setGraphName(graph.getName()); + + int shardCount = graph.getShardCount(); + int driverNum = 1; + int queryParallelism = + shardCount % driverNum == 0 ? shardCount / driverNum : shardCount / driverNum + 1; + configClass.setQueryParallelism(queryParallelism); + configClass.setDriverNum(driverNum); + return configClass; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/PackageStage.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/PackageStage.java index 5b49607a9..b21af6d51 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/PackageStage.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/PackageStage.java @@ -19,12 +19,12 @@ package org.apache.geaflow.console.core.service.release; -import com.alibaba.fastjson.JSON; import java.io.IOException; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; + import org.apache.geaflow.console.common.util.Md5Util; import org.apache.geaflow.console.common.util.ZipUtil; import org.apache.geaflow.console.common.util.ZipUtil.GeaflowZipEntry; @@ -37,41 +37,46 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; +import com.alibaba.fastjson.JSON; + @Component public class PackageStage extends GeaflowBuildStage { - private static final String GQL_FILE_NAME = "user.gql"; - private static final String CONF_FILE_NAME = "user.conf"; + private static final String GQL_FILE_NAME = "user.gql"; + private static final String CONF_FILE_NAME = "user.conf"; - @Autowired - private RemoteFileStorage remoteFileStorage; + @Autowired private RemoteFileStorage remoteFileStorage; - @Override - public void init(GeaflowRelease release) { - try { - List zipEntries; - if (release.getJob().isApiJob()) { - MemoryZipEntry entry = new MemoryZipEntry("empty.txt", ""); - zipEntries = Collections.singletonList(entry); + @Override + public void init(GeaflowRelease release) { + try { + List zipEntries; + if (release.getJob().isApiJob()) { + MemoryZipEntry entry = new MemoryZipEntry("empty.txt", ""); + zipEntries = Collections.singletonList(entry); - } else { - String code = release.getJob().getUserCode().getText(); - Map parallelismMap = JobPlanBuilder.getParallelismMap(release.getJobPlan()); - MemoryZipEntry codeEntry = new MemoryZipEntry(GQL_FILE_NAME, code); - MemoryZipEntry confEntry = new MemoryZipEntry(CONF_FILE_NAME, JSON.toJSONString(parallelismMap)); - zipEntries = Arrays.asList(codeEntry, confEntry); - } - // url and md5 - String path = RemoteFileStorage.getPackageFilePath(release.getJob().getId(), release.getReleaseVersion()); - release.setUrl(remoteFileStorage.upload(path, ZipUtil.buildZipInputStream(zipEntries))); - release.setMd5(Md5Util.encodeFile(ZipUtil.buildZipInputStream(zipEntries))); - } catch (IOException e) { - throw new GeaflowException("Package job {} fail ", release.getJob().getName(), e); - } + } else { + String code = release.getJob().getUserCode().getText(); + Map parallelismMap = + JobPlanBuilder.getParallelismMap(release.getJobPlan()); + MemoryZipEntry codeEntry = new MemoryZipEntry(GQL_FILE_NAME, code); + MemoryZipEntry confEntry = + new MemoryZipEntry(CONF_FILE_NAME, JSON.toJSONString(parallelismMap)); + zipEntries = Arrays.asList(codeEntry, confEntry); + } + // url and md5 + String path = + RemoteFileStorage.getPackageFilePath( + release.getJob().getId(), release.getReleaseVersion()); + release.setUrl(remoteFileStorage.upload(path, ZipUtil.buildZipInputStream(zipEntries))); + release.setMd5(Md5Util.encodeFile(ZipUtil.buildZipInputStream(zipEntries))); + } catch (IOException e) { + throw new GeaflowException("Package job {} fail ", release.getJob().getName(), e); } + } - @Override - public boolean update(GeaflowRelease release, ReleaseUpdate update) { - return false; - } + @Override + public boolean update(GeaflowRelease release, ReleaseUpdate update) { + return false; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/ResolveClusterStage.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/ResolveClusterStage.java index 7db8d7576..0005cb70d 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/ResolveClusterStage.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/ResolveClusterStage.java @@ -29,25 +29,23 @@ @Component public class ResolveClusterStage extends GeaflowBuildStage { - - @Autowired - private ClusterService clusterService; - - @Override - public void init(GeaflowRelease release) { - // assign cluster by cluster config - GeaflowCluster cluster = clusterService.getDefaultCluster(); - release.setCluster(cluster); + @Autowired private ClusterService clusterService; + + @Override + public void init(GeaflowRelease release) { + // assign cluster by cluster config + GeaflowCluster cluster = clusterService.getDefaultCluster(); + release.setCluster(cluster); + } + + @Override + public boolean update(GeaflowRelease release, ReleaseUpdate update) { + GeaflowCluster newCluster = update.getNewCluster(); + if (newCluster == null) { + return true; } - @Override - public boolean update(GeaflowRelease release, ReleaseUpdate update) { - GeaflowCluster newCluster = update.getNewCluster(); - if (newCluster == null) { - return true; - } - - release.setCluster(newCluster); - return true; - } + release.setCluster(newCluster); + return true; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/ResolveReleaseVersionStage.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/ResolveReleaseVersionStage.java index 32cb857f9..87321d1b3 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/ResolveReleaseVersionStage.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/ResolveReleaseVersionStage.java @@ -31,39 +31,37 @@ @Component public class ResolveReleaseVersionStage extends GeaflowBuildStage { - @Autowired - private TaskService taskService; + @Autowired private TaskService taskService; - public void init(GeaflowRelease release) { - String jobId = release.getJob().getId(); + public void init(GeaflowRelease release) { + String jobId = release.getJob().getId(); - GeaflowTask task = taskService.getByJobId(jobId); - if (task == null) { - // publish at the first time - release.setReleaseVersion(1); - } else { - GeaflowTaskStatus status = task.getStatus(); - status.checkOperation(GeaflowOperationType.PUBLISH); + GeaflowTask task = taskService.getByJobId(jobId); + if (task == null) { + // publish at the first time + release.setReleaseVersion(1); + } else { + GeaflowTaskStatus status = task.getStatus(); + status.checkOperation(GeaflowOperationType.PUBLISH); - GeaflowRelease oldRelease = task.getRelease(); - int currentVersion = oldRelease.getReleaseVersion(); - if (status == GeaflowTaskStatus.CREATED) { - //update release, releaseVersion unchanged - release.setReleaseVersion(currentVersion); - release.setId(task.getRelease().getId()); - } else { - //stop, fail, finish, versionNumber + 1 - release.setReleaseVersion(currentVersion + 1); - } + GeaflowRelease oldRelease = task.getRelease(); + int currentVersion = oldRelease.getReleaseVersion(); + if (status == GeaflowTaskStatus.CREATED) { + // update release, releaseVersion unchanged + release.setReleaseVersion(currentVersion); + release.setId(task.getRelease().getId()); + } else { + // stop, fail, finish, versionNumber + 1 + release.setReleaseVersion(currentVersion + 1); + } - release.setJobConfig(oldRelease.getJobConfig()); - release.setClusterConfig(oldRelease.getClusterConfig()); - } + release.setJobConfig(oldRelease.getJobConfig()); + release.setClusterConfig(oldRelease.getClusterConfig()); } + } - - @Override - public boolean update(GeaflowRelease release, ReleaseUpdate update) { - return false; - } + @Override + public boolean update(GeaflowRelease release, ReleaseUpdate update) { + return false; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/ResolveVersionStage.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/ResolveVersionStage.java index 1525b519f..298dedc24 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/ResolveVersionStage.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/release/ResolveVersionStage.java @@ -30,38 +30,36 @@ @Component public class ResolveVersionStage extends GeaflowBuildStage { - @Autowired - private VersionService versionService; + @Autowired private VersionService versionService; - public void init(GeaflowRelease release) { - // get latest stable version - GeaflowVersion version = versionService.getDefaultVersion(); + public void init(GeaflowRelease release) { + // get latest stable version + GeaflowVersion version = versionService.getDefaultVersion(); - release.setVersion(version); - } - - @Override - public boolean update(GeaflowRelease release, ReleaseUpdate update) { - GeaflowVersion oldVersion = release.getVersion(); - GeaflowVersion newVersion = update.getNewVersion(); - - // the version is deleted if oldVersion is null; - if (oldVersion == null && newVersion == null) { - throw new GeaflowException("Version is null"); - } + release.setVersion(version); + } - if (newVersion == null) { - return false; - } + @Override + public boolean update(GeaflowRelease release, ReleaseUpdate update) { + GeaflowVersion oldVersion = release.getVersion(); + GeaflowVersion newVersion = update.getNewVersion(); - if (oldVersion == null) { - release.setVersion(newVersion); - return true; - } + // the version is deleted if oldVersion is null; + if (oldVersion == null && newVersion == null) { + throw new GeaflowException("Version is null"); + } - release.setVersion(newVersion); - // compile in next stage if version is changed - return !GeaflowVersion.md5Equals(oldVersion, newVersion); + if (newVersion == null) { + return false; + } + if (oldVersion == null) { + release.setVersion(newVersion); + return true; } + + release.setVersion(newVersion); + // compile in next stage if version is changed + return !GeaflowVersion.md5Equals(oldVersion, newVersion); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/ContainerRuntime.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/ContainerRuntime.java index c6eac60c2..0fd6c1633 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/ContainerRuntime.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/ContainerRuntime.java @@ -19,11 +19,10 @@ package org.apache.geaflow.console.core.service.runtime; -import com.alibaba.fastjson.JSON; import java.io.File; import java.util.ArrayList; import java.util.List; -import lombok.extern.slf4j.Slf4j; + import org.apache.commons.exec.CommandLine; import org.apache.commons.lang3.StringEscapeUtils; import org.apache.commons.lang3.StringUtils; @@ -44,132 +43,146 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; +import com.alibaba.fastjson.JSON; + +import lombok.extern.slf4j.Slf4j; + @Slf4j @Component public class ContainerRuntime implements GeaflowRuntime { - private static final String GEAFLOW_ENGINE_LOG_FILE = "/tmp/logs/task/%s.log"; + private static final String GEAFLOW_ENGINE_LOG_FILE = "/tmp/logs/task/%s.log"; - private static final String GEAFLOW_ENGINE_FINISH_FILE = "/tmp/logs/task/%s.finish"; + private static final String GEAFLOW_ENGINE_FINISH_FILE = "/tmp/logs/task/%s.finish"; - private static final String GEAFLOW_LOG4J_PROPERTIES = "log4j.properties"; + private static final String GEAFLOW_LOG4J_PROPERTIES = "log4j.properties"; - @Autowired - private ContainerTaskParams taskParams; + @Autowired private ContainerTaskParams taskParams; - @Autowired - protected LocalFileFactory localFileFactory; + @Autowired protected LocalFileFactory localFileFactory; - @Autowired - private InstanceService instanceService; + @Autowired private InstanceService instanceService; - public static String getLogFilePath(String taskId) { - return String.format(GEAFLOW_ENGINE_LOG_FILE, TaskParams.getRuntimeTaskName(taskId)); - } + public static String getLogFilePath(String taskId) { + return String.format(GEAFLOW_ENGINE_LOG_FILE, TaskParams.getRuntimeTaskName(taskId)); + } - public static String getFinishFilePath(String taskId) { - return String.format(GEAFLOW_ENGINE_FINISH_FILE, TaskParams.getRuntimeTaskName(taskId)); - } + public static String getFinishFilePath(String taskId) { + return String.format(GEAFLOW_ENGINE_FINISH_FILE, TaskParams.getRuntimeTaskName(taskId)); + } + + @Override + public GeaflowTaskHandle start(GeaflowTask task) { + GeaflowInstance instance = instanceService.get(task.getRelease().getJob().getInstanceId()); + return doStart(task, taskParams.buildClientArgs(instance, task)); + } + + @Override + public void stop(GeaflowTask task) { + try { + int pid = ((ContainerTaskHandle) task.getHandle()).getPid(); + if (ProcessUtil.existPid(pid)) { + ProcessUtil.killPid(pid); + } - @Override - public GeaflowTaskHandle start(GeaflowTask task) { - GeaflowInstance instance = instanceService.get(task.getRelease().getJob().getInstanceId()); - return doStart(task, taskParams.buildClientArgs(instance, task)); + } catch (Exception e) { + throw new GeaflowLogException("Stop task {} failed", task.getId(), e); } + } - @Override - public void stop(GeaflowTask task) { - try { + @Override + public GeaflowTaskStatus queryStatus(GeaflowTask task) { + try { + return RetryUtil.exec( + () -> { int pid = ((ContainerTaskHandle) task.getHandle()).getPid(); if (ProcessUtil.existPid(pid)) { - ProcessUtil.killPid(pid); + return GeaflowTaskStatus.RUNNING; } - } catch (Exception e) { - throw new GeaflowLogException("Stop task {} failed", task.getId(), e); - } - } - - @Override - public GeaflowTaskStatus queryStatus(GeaflowTask task) { - try { - return RetryUtil.exec(() -> { - int pid = ((ContainerTaskHandle) task.getHandle()).getPid(); - if (ProcessUtil.existPid(pid)) { - return GeaflowTaskStatus.RUNNING; - } - - if (FileUtil.exist(getFinishFilePath(task.getId()))) { - return GeaflowTaskStatus.FINISHED; - } - - return GeaflowTaskStatus.FAILED; - }, 5, 500); + if (FileUtil.exist(getFinishFilePath(task.getId()))) { + return GeaflowTaskStatus.FINISHED; + } - } catch (Exception e) { - log.error("Query task {} status failed, handle={}", task.getId(), JSON.toJSONString(task.getHandle()), e); return GeaflowTaskStatus.FAILED; - } + }, + 5, + 500); + + } catch (Exception e) { + log.error( + "Query task {} status failed, handle={}", + task.getId(), + JSON.toJSONString(task.getHandle()), + e); + return GeaflowTaskStatus.FAILED; } - - private GeaflowTaskHandle doStart(GeaflowTask task, GeaflowArgsClass geaflowArgs) { - String runtimeTaskId = geaflowArgs.getSystemArgs().getRuntimeTaskId(); - taskParams.validateRuntimeTaskId(runtimeTaskId); - - try { - // kill last task process if exists - GeaflowTaskHandle handle = task.getHandle(); - if (handle != null) { - int pid = ((ContainerTaskHandle) handle).getPid(); - if (ProcessUtil.existPid(pid)) { - ProcessUtil.killPid(pid); - } - } - - // clear finish file if exists - String finishFile = getFinishFilePath(task.getId()); - FileUtil.delete(finishFile); - - List classPaths = new ArrayList<>(); - - // add version jar - GeaflowRelease release = task.getRelease(); - String versionName = release.getVersion().getName(); - task.getVersionJars() - .forEach(jar -> classPaths.add(localFileFactory.getVersionFile(versionName, jar).getAbsolutePath())); - - // add user jar - task.getUserJars() - .forEach(jar -> classPaths.add(localFileFactory.getTaskUserFile(runtimeTaskId, jar).getAbsolutePath())); - - // add release zip - File releaseFile = localFileFactory.getTaskReleaseFile(runtimeTaskId, release.getJob().getId(), release); - ZipUtil.unzip(releaseFile); - classPaths.add(releaseFile.getParent()); - - // start task process - String java = System.getProperty("java.home") + "/bin/java"; - String classPathString = StringUtils.join(classPaths, ":"); - String mainClass = task.getMainClass(); - String args = StringEscapeUtils.escapeJava(JSON.toJSONString(geaflowArgs.build())); - String logFilePath = getLogFilePath(task.getId()); - CommandLine cmd = new CommandLine(java); - cmd.addArgument("-cp"); - cmd.addArgument(classPathString); - cmd.addArgument("-Dlog.file=" + logFilePath); - cmd.addArgument("-Dlog4j.configuration=" + GEAFLOW_LOG4J_PROPERTIES); - cmd.addArgument(mainClass); - cmd.addArgument(args, false); - int pid = ProcessUtil.execAsyncCommand(cmd, 1000, logFilePath, finishFile); - - // save handle - ContainerTaskHandle taskHandle = new ContainerTaskHandle(runtimeTaskId, pid); - log.info("Start task {} success, handle={}", task.getId(), JSON.toJSONString(taskHandle)); - return taskHandle; - - } catch (Exception e) { - throw new GeaflowLogException("Start task {} failed", task.getId(), e); + } + + private GeaflowTaskHandle doStart(GeaflowTask task, GeaflowArgsClass geaflowArgs) { + String runtimeTaskId = geaflowArgs.getSystemArgs().getRuntimeTaskId(); + taskParams.validateRuntimeTaskId(runtimeTaskId); + + try { + // kill last task process if exists + GeaflowTaskHandle handle = task.getHandle(); + if (handle != null) { + int pid = ((ContainerTaskHandle) handle).getPid(); + if (ProcessUtil.existPid(pid)) { + ProcessUtil.killPid(pid); } + } + + // clear finish file if exists + String finishFile = getFinishFilePath(task.getId()); + FileUtil.delete(finishFile); + + List classPaths = new ArrayList<>(); + + // add version jar + GeaflowRelease release = task.getRelease(); + String versionName = release.getVersion().getName(); + task.getVersionJars() + .forEach( + jar -> + classPaths.add( + localFileFactory.getVersionFile(versionName, jar).getAbsolutePath())); + + // add user jar + task.getUserJars() + .forEach( + jar -> + classPaths.add( + localFileFactory.getTaskUserFile(runtimeTaskId, jar).getAbsolutePath())); + + // add release zip + File releaseFile = + localFileFactory.getTaskReleaseFile(runtimeTaskId, release.getJob().getId(), release); + ZipUtil.unzip(releaseFile); + classPaths.add(releaseFile.getParent()); + + // start task process + String java = System.getProperty("java.home") + "/bin/java"; + String classPathString = StringUtils.join(classPaths, ":"); + String mainClass = task.getMainClass(); + String args = StringEscapeUtils.escapeJava(JSON.toJSONString(geaflowArgs.build())); + String logFilePath = getLogFilePath(task.getId()); + CommandLine cmd = new CommandLine(java); + cmd.addArgument("-cp"); + cmd.addArgument(classPathString); + cmd.addArgument("-Dlog.file=" + logFilePath); + cmd.addArgument("-Dlog4j.configuration=" + GEAFLOW_LOG4J_PROPERTIES); + cmd.addArgument(mainClass); + cmd.addArgument(args, false); + int pid = ProcessUtil.execAsyncCommand(cmd, 1000, logFilePath, finishFile); + + // save handle + ContainerTaskHandle taskHandle = new ContainerTaskHandle(runtimeTaskId, pid); + log.info("Start task {} success, handle={}", task.getId(), JSON.toJSONString(taskHandle)); + return taskHandle; + + } catch (Exception e) { + throw new GeaflowLogException("Start task {} failed", task.getId(), e); } - + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/ContainerTaskParams.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/ContainerTaskParams.java index a3fd157bc..21f7cb928 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/ContainerTaskParams.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/ContainerTaskParams.java @@ -31,18 +31,17 @@ @Component public class ContainerTaskParams extends TaskParams { - public GeaflowArgsClass buildClientArgs(GeaflowInstance instance, GeaflowTask task) { - return buildGeaflowArgs(instance, task); - } + public GeaflowArgsClass buildClientArgs(GeaflowInstance instance, GeaflowTask task) { + return buildGeaflowArgs(instance, task); + } - @Override - protected ClusterArgsClass buildClusterArgs(GeaflowTask task) { - GeaflowRelease release = task.getRelease(); - - ContainerClusterArgsClass clusterArgs = new ContainerClusterArgsClass(); - ClusterConfigClass clusterConfig = release.getClusterConfig().parse(ClusterConfigClass.class); - clusterArgs.setTaskClusterConfig(clusterConfig); - return clusterArgs; - } + @Override + protected ClusterArgsClass buildClusterArgs(GeaflowTask task) { + GeaflowRelease release = task.getRelease(); + ContainerClusterArgsClass clusterArgs = new ContainerClusterArgsClass(); + ClusterConfigClass clusterConfig = release.getClusterConfig().parse(ClusterConfigClass.class); + clusterArgs.setTaskClusterConfig(clusterConfig); + return clusterArgs; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/GeaflowRuntime.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/GeaflowRuntime.java index 6015bbebf..5bdc061db 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/GeaflowRuntime.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/GeaflowRuntime.java @@ -25,10 +25,9 @@ public interface GeaflowRuntime { - GeaflowTaskHandle start(GeaflowTask task); + GeaflowTaskHandle start(GeaflowTask task); - void stop(GeaflowTask task); - - GeaflowTaskStatus queryStatus(GeaflowTask task); + void stop(GeaflowTask task); + GeaflowTaskStatus queryStatus(GeaflowTask task); } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/K8sRuntime.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/K8sRuntime.java index bba5d507f..8fe7bb012 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/K8sRuntime.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/K8sRuntime.java @@ -19,9 +19,8 @@ package org.apache.geaflow.console.core.service.runtime; -import com.alibaba.fastjson.JSON; import java.util.Map; -import lombok.extern.slf4j.Slf4j; + import org.apache.geaflow.console.common.service.integration.engine.K8sJobClient; import org.apache.geaflow.console.common.util.RetryUtil; import org.apache.geaflow.console.common.util.exception.GeaflowLogException; @@ -40,90 +39,99 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; +import com.alibaba.fastjson.JSON; + +import lombok.extern.slf4j.Slf4j; + @Slf4j @Component public class K8sRuntime implements GeaflowRuntime { - @Autowired - private K8sTaskParams taskParams; - - @Autowired - private VersionFactory versionFactory; - - @Autowired - private InstanceService instanceService; - - @Override - public GeaflowTaskHandle start(GeaflowTask task) { - GeaflowInstance instance = instanceService.get(task.getRelease().getJob().getInstanceId()); - return doStart(task, taskParams.buildClientArgs(instance, task)); + @Autowired private K8sTaskParams taskParams; + + @Autowired private VersionFactory versionFactory; + + @Autowired private InstanceService instanceService; + + @Override + public GeaflowTaskHandle start(GeaflowTask task) { + GeaflowInstance instance = instanceService.get(task.getRelease().getJob().getInstanceId()); + return doStart(task, taskParams.buildClientArgs(instance, task)); + } + + @Override + public void stop(GeaflowTask task) { + doStop(task, taskParams.buildClientStopArgs(task)); + } + + @Override + public GeaflowTaskStatus queryStatus(GeaflowTask task) { + try { + return RetryUtil.exec( + () -> { + boolean existMasterService = + existMasterService(task, taskParams.buildClientStopArgs(task)); + return existMasterService ? GeaflowTaskStatus.RUNNING : GeaflowTaskStatus.FAILED; + }, + 5, + 500); + } catch (Exception e) { + log.error( + "Query task {} status failed, handle={}", + task.getId(), + JSON.toJSONString(task.getHandle()), + e); + return GeaflowTaskStatus.FAILED; } + } - @Override - public void stop(GeaflowTask task) { - doStop(task, taskParams.buildClientStopArgs(task)); - } + private GeaflowTaskHandle doStart(GeaflowTask task, K8sClientArgsClass clientArgs) { + GeaflowArgsClass geaflowArgs = clientArgs.getGeaflowArgs(); + K8SClusterArgsClass clusterArgs = (K8SClusterArgsClass) geaflowArgs.getClusterArgs(); + String runtimeTaskId = geaflowArgs.getSystemArgs().getRuntimeTaskId(); + taskParams.validateRuntimeTaskId(runtimeTaskId); - @Override - public GeaflowTaskStatus queryStatus(GeaflowTask task) { - try { - return RetryUtil.exec(() -> { - boolean existMasterService = existMasterService(task, taskParams.buildClientStopArgs(task)); - return existMasterService ? GeaflowTaskStatus.RUNNING : GeaflowTaskStatus.FAILED; - }, 5, 500); - } catch (Exception e) { - log.error("Query task {} status failed, handle={}", task.getId(), JSON.toJSONString(task.getHandle()), e); - return GeaflowTaskStatus.FAILED; - } - } + try { + Map params = clientArgs.build().toStringMap(); + String masterUrl = clusterArgs.getClusterConfig().getMasterUrl(); - private GeaflowTaskHandle doStart(GeaflowTask task, K8sClientArgsClass clientArgs) { - GeaflowArgsClass geaflowArgs = clientArgs.getGeaflowArgs(); - K8SClusterArgsClass clusterArgs = (K8SClusterArgsClass) geaflowArgs.getClusterArgs(); - String runtimeTaskId = geaflowArgs.getSystemArgs().getRuntimeTaskId(); - taskParams.validateRuntimeTaskId(runtimeTaskId); + VersionClassLoader loader = versionFactory.getClassLoader(task.getRelease().getVersion()); + K8sJobClient jobClient = loader.newInstance(K8sJobClient.class, params, masterUrl); + jobClient.submitJob(); - try { - Map params = clientArgs.build().toStringMap(); - String masterUrl = clusterArgs.getClusterConfig().getMasterUrl(); + K8sTaskHandle taskHandle = new K8sTaskHandle(runtimeTaskId); + log.info("Start task {} success, handle={}", task.getId(), JSON.toJSONString(taskHandle)); + return taskHandle; - VersionClassLoader loader = versionFactory.getClassLoader(task.getRelease().getVersion()); - K8sJobClient jobClient = loader.newInstance(K8sJobClient.class, params, masterUrl); - jobClient.submitJob(); - - K8sTaskHandle taskHandle = new K8sTaskHandle(runtimeTaskId); - log.info("Start task {} success, handle={}", task.getId(), JSON.toJSONString(taskHandle)); - return taskHandle; - - } catch (Exception e) { - throw new GeaflowLogException("Start task {} failed", task.getId(), e); - } + } catch (Exception e) { + throw new GeaflowLogException("Start task {} failed", task.getId(), e); } + } - private void doStop(GeaflowTask task, K8sClientStopArgsClass k8sClientStopArgs) { - taskParams.validateRuntimeTaskId(k8sClientStopArgs.getRuntimeTaskId()); + private void doStop(GeaflowTask task, K8sClientStopArgsClass k8sClientStopArgs) { + taskParams.validateRuntimeTaskId(k8sClientStopArgs.getRuntimeTaskId()); - try { - Map params = k8sClientStopArgs.build().toStringMap(); - String masterUrl = k8sClientStopArgs.getClusterArgs().getClusterConfig().getMasterUrl(); + try { + Map params = k8sClientStopArgs.build().toStringMap(); + String masterUrl = k8sClientStopArgs.getClusterArgs().getClusterConfig().getMasterUrl(); - VersionClassLoader loader = versionFactory.getClassLoader(task.getRelease().getVersion()); - K8sJobClient jobClient = loader.newInstance(K8sJobClient.class, params, masterUrl); + VersionClassLoader loader = versionFactory.getClassLoader(task.getRelease().getVersion()); + K8sJobClient jobClient = loader.newInstance(K8sJobClient.class, params, masterUrl); - jobClient.stopJob(); + jobClient.stopJob(); - } catch (Exception e) { - throw new GeaflowLogException("Stop task {} failed", task.getId(), e); - } + } catch (Exception e) { + throw new GeaflowLogException("Stop task {} failed", task.getId(), e); } + } - private boolean existMasterService(GeaflowTask task, K8sClientStopArgsClass k8sClientStopArgs) { - Map params = k8sClientStopArgs.build().toStringMap(); - String masterUrl = k8sClientStopArgs.getClusterArgs().getClusterConfig().getMasterUrl(); + private boolean existMasterService(GeaflowTask task, K8sClientStopArgsClass k8sClientStopArgs) { + Map params = k8sClientStopArgs.build().toStringMap(); + String masterUrl = k8sClientStopArgs.getClusterArgs().getClusterConfig().getMasterUrl(); - VersionClassLoader loader = versionFactory.getClassLoader(task.getRelease().getVersion()); - K8sJobClient jobClient = loader.newInstance(K8sJobClient.class, params, masterUrl); + VersionClassLoader loader = versionFactory.getClassLoader(task.getRelease().getVersion()); + K8sJobClient jobClient = loader.newInstance(K8sJobClient.class, params, masterUrl); - return jobClient.getMasterService() != null; - } + return jobClient.getMasterService() != null; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/K8sTaskParams.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/K8sTaskParams.java index de3650688..f30f39584 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/K8sTaskParams.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/K8sTaskParams.java @@ -19,8 +19,6 @@ package org.apache.geaflow.console.core.service.runtime; -import com.alibaba.fastjson.JSON; -import com.google.common.base.Preconditions; import org.apache.geaflow.console.core.model.data.GeaflowInstance; import org.apache.geaflow.console.core.model.job.config.ClusterArgsClass; import org.apache.geaflow.console.core.model.job.config.ClusterConfigClass; @@ -33,33 +31,38 @@ import org.apache.geaflow.console.core.model.task.GeaflowTaskHandle; import org.springframework.stereotype.Component; +import com.alibaba.fastjson.JSON; +import com.google.common.base.Preconditions; + @Component public class K8sTaskParams extends TaskParams { - public K8sClientArgsClass buildClientArgs(GeaflowInstance instance, GeaflowTask task) { - return new K8sClientArgsClass(buildGeaflowArgs(instance, task), task.getMainClass()); - } - - public K8sClientStopArgsClass buildClientStopArgs(GeaflowTask task) { - GeaflowTaskHandle handle = task.getHandle(); - Preconditions.checkNotNull(handle, "Task %s handle can't be empty", task.getId()); + public K8sClientArgsClass buildClientArgs(GeaflowInstance instance, GeaflowTask task) { + return new K8sClientArgsClass(buildGeaflowArgs(instance, task), task.getMainClass()); + } - return new K8sClientStopArgsClass(handle.getAppId(), ((K8SClusterArgsClass) buildClusterArgs(task))); - } + public K8sClientStopArgsClass buildClientStopArgs(GeaflowTask task) { + GeaflowTaskHandle handle = task.getHandle(); + Preconditions.checkNotNull(handle, "Task %s handle can't be empty", task.getId()); - @Override - protected ClusterArgsClass buildClusterArgs(GeaflowTask task) { - GeaflowRelease release = task.getRelease(); + return new K8sClientStopArgsClass( + handle.getAppId(), ((K8SClusterArgsClass) buildClusterArgs(task))); + } - K8SClusterArgsClass clusterArgs = new K8SClusterArgsClass(); - ClusterConfigClass clusterConfig = release.getClusterConfig().parse(ClusterConfigClass.class); - K8sPluginConfigClass k8sPluginConfig = release.getCluster().getConfig().parse(K8sPluginConfigClass.class); + @Override + protected ClusterArgsClass buildClusterArgs(GeaflowTask task) { + GeaflowRelease release = task.getRelease(); - clusterArgs.setTaskClusterConfig(clusterConfig); - clusterArgs.setClusterConfig(k8sPluginConfig); - clusterArgs.setEngineJarUrls(JSON.toJSONString(task.getVersionFiles(deployConfig.getGatewayUrl()))); - clusterArgs.setTaskJarUrls(JSON.toJSONString(task.getUserFiles(deployConfig.getGatewayUrl()))); - return clusterArgs; - } + K8SClusterArgsClass clusterArgs = new K8SClusterArgsClass(); + ClusterConfigClass clusterConfig = release.getClusterConfig().parse(ClusterConfigClass.class); + K8sPluginConfigClass k8sPluginConfig = + release.getCluster().getConfig().parse(K8sPluginConfigClass.class); + clusterArgs.setTaskClusterConfig(clusterConfig); + clusterArgs.setClusterConfig(k8sPluginConfig); + clusterArgs.setEngineJarUrls( + JSON.toJSONString(task.getVersionFiles(deployConfig.getGatewayUrl()))); + clusterArgs.setTaskJarUrls(JSON.toJSONString(task.getUserFiles(deployConfig.getGatewayUrl()))); + return clusterArgs; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/RayRuntime.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/RayRuntime.java index 3573a1cb0..0409e03fe 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/RayRuntime.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/RayRuntime.java @@ -19,16 +19,12 @@ package org.apache.geaflow.console.core.service.runtime; -import com.alibaba.fastjson.JSON; -import com.alibaba.fastjson.JSONObject; import java.io.ByteArrayInputStream; import java.io.File; import java.io.InputStream; import java.util.ArrayList; import java.util.List; -import lombok.Getter; -import lombok.Setter; -import lombok.extern.slf4j.Slf4j; + import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringEscapeUtils; import org.apache.geaflow.console.common.util.HTTPUtil; @@ -55,212 +51,235 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONObject; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.slf4j.Slf4j; + @Slf4j @Component public class RayRuntime implements GeaflowRuntime { - - @Autowired - private RemoteFileStorage remoteFileStorage; - - @Getter - @Setter - public static class RaySubmitResponse { - - private String jobId; - private String submissionId; - - } - - @Autowired - private RayTaskParams taskParams; - - @Autowired - private RemoteFileStorage fileStorage; - - @Autowired - private LocalFileFactory localFileFactory; - - @Autowired - private InstanceService instanceService; - - @Override - public GeaflowTaskHandle start(GeaflowTask task) { - GeaflowInstance instance = instanceService.get(task.getRelease().getJob().getInstanceId()); - uploadZipFiles(task); - return doStart(task, taskParams.buildClientArgs(instance, task)); - } - - private String getSubmitJobUrl(String rayUrl) { - return String.format("%s/api/jobs/", rayUrl); - } - - private String getQueryJobStatusUrl(String rayUrl, String rayJobId) { - return String.format("%s/api/jobs/%s", rayUrl, rayJobId); - } - - private String getStopJobUrl(String rayUrl, String rayJobId) { - return String.format("%s/api/jobs/%s/stop", rayUrl, rayJobId); - } - - private GeaflowTaskStatus toTaskStatus(String status) { - switch (status) { - // pending also shown as running in console. - case "PENDING": - case "RUNNING": - return GeaflowTaskStatus.RUNNING; - case "STOPPED": - return GeaflowTaskStatus.STOPPED; - case "SUCCEEDED": - return GeaflowTaskStatus.FINISHED; - case "FAILED": - return GeaflowTaskStatus.FAILED; - default: - throw new GeaflowException("Unknown status {}", status); - } - } - - @Override - public GeaflowTaskStatus queryStatus(GeaflowTask task) { - try { - return RetryUtil.exec(() -> { - RayPluginConfigClass rayPluginConfig = taskParams.buildRayClusterConfig(task); - String url = getQueryJobStatusUrl(rayPluginConfig.getDashboardAddress(), task.getHandle().getAppId()); - JSONObject response = HTTPUtil.get(url); - String status = response.get("status").toString(); - return toTaskStatus(status); - }, 5, 500); - - } catch (Exception e) { - log.error("Query task {} status failed, handle={}", task.getId(), JSON.toJSONString(task.getHandle()), e); - return GeaflowTaskStatus.FAILED; - } + @Autowired private RemoteFileStorage remoteFileStorage; + + @Getter + @Setter + public static class RaySubmitResponse { + + private String jobId; + private String submissionId; + } + + @Autowired private RayTaskParams taskParams; + + @Autowired private RemoteFileStorage fileStorage; + + @Autowired private LocalFileFactory localFileFactory; + + @Autowired private InstanceService instanceService; + + @Override + public GeaflowTaskHandle start(GeaflowTask task) { + GeaflowInstance instance = instanceService.get(task.getRelease().getJob().getInstanceId()); + uploadZipFiles(task); + return doStart(task, taskParams.buildClientArgs(instance, task)); + } + + private String getSubmitJobUrl(String rayUrl) { + return String.format("%s/api/jobs/", rayUrl); + } + + private String getQueryJobStatusUrl(String rayUrl, String rayJobId) { + return String.format("%s/api/jobs/%s", rayUrl, rayJobId); + } + + private String getStopJobUrl(String rayUrl, String rayJobId) { + return String.format("%s/api/jobs/%s/stop", rayUrl, rayJobId); + } + + private GeaflowTaskStatus toTaskStatus(String status) { + switch (status) { + // pending also shown as running in console. + case "PENDING": + case "RUNNING": + return GeaflowTaskStatus.RUNNING; + case "STOPPED": + return GeaflowTaskStatus.STOPPED; + case "SUCCEEDED": + return GeaflowTaskStatus.FINISHED; + case "FAILED": + return GeaflowTaskStatus.FAILED; + default: + throw new GeaflowException("Unknown status {}", status); } + } - @Override - public void stop(GeaflowTask task) { - try { - String rayJobId = task.getHandle().getAppId(); + @Override + public GeaflowTaskStatus queryStatus(GeaflowTask task) { + try { + return RetryUtil.exec( + () -> { RayPluginConfigClass rayPluginConfig = taskParams.buildRayClusterConfig(task); - String rayUrl = rayPluginConfig.getDashboardAddress(); - HTTPUtil.post(getStopJobUrl(rayUrl, rayJobId), new JSONObject().toJSONString()); - - } catch (Exception e) { - throw new GeaflowLogException("Stop task {} failed", task.getId(), e); - } + String url = + getQueryJobStatusUrl( + rayPluginConfig.getDashboardAddress(), task.getHandle().getAppId()); + JSONObject response = HTTPUtil.get(url); + String status = response.get("status").toString(); + return toTaskStatus(status); + }, + 5, + 500); + + } catch (Exception e) { + log.error( + "Query task {} status failed, handle={}", + task.getId(), + JSON.toJSONString(task.getHandle()), + e); + return GeaflowTaskStatus.FAILED; } - - - private String buildRequest(GeaflowTask task, GeaflowArgsClass geaflowArgs) { - RayClusterArgsClass clusterArgs = (RayClusterArgsClass) geaflowArgs.getClusterArgs(); - RayPluginConfigClass rayConfig = clusterArgs.getRayConfig(); - List remoteJarUrls = getDownloadJarUrls(task); - List downloadJarPaths = new ArrayList<>(remoteJarUrls.size()); - for (String remoteJarUrl : remoteJarUrls) { - String str = remoteJarUrl.replace(".zip", ""); - String result = str.replaceAll("[:/.]+", "_"); - downloadJarPaths.add(rayConfig.getSessionResourceJarPath() + result + "/*"); - } - - String downloadJarClassPath = String.join(":", downloadJarPaths); - List remoteJarUrlsStr = new ArrayList<>(remoteJarUrls.size()); - for (String remoteUrl : remoteJarUrls) { - remoteJarUrlsStr.add("\"" + remoteUrl + "\""); - } - String remoteJarJsonPath = String.join(",", remoteJarUrlsStr); - - String argString = StringEscapeUtils.escapeJava(JSON.toJSONString(geaflowArgs.build())); - argString = StringEscapeUtils.escapeJava("\"" + argString + "\""); - return String.format("{\n" + "\"entrypoint\": \"java -classpath %s:%s -Dray.address=%s %s %s\",\n" - + "\"runtime_env\": {\"java_jars\": [%s]}\n" + "}", rayConfig.getDistJarPath(), downloadJarClassPath, - rayConfig.getRedisAddress(), task.getMainClass(), argString, remoteJarJsonPath); + } + + @Override + public void stop(GeaflowTask task) { + try { + String rayJobId = task.getHandle().getAppId(); + RayPluginConfigClass rayPluginConfig = taskParams.buildRayClusterConfig(task); + String rayUrl = rayPluginConfig.getDashboardAddress(); + HTTPUtil.post(getStopJobUrl(rayUrl, rayJobId), new JSONObject().toJSONString()); + + } catch (Exception e) { + throw new GeaflowLogException("Stop task {} failed", task.getId(), e); } - - private List getDownloadJarUrls(GeaflowTask task) { - List urls = new ArrayList<>(); - GeaflowVersion version = task.getRelease().getVersion(); - - String versionUrl = formatHttp(remoteFileStorage.getUrl(getVersionFilePath(version))); - urls.add(versionUrl); - if (CollectionUtils.isNotEmpty(task.getUserJars())) { - String udfUrl = formatHttp(remoteFileStorage.getUrl(getTaskFilePath(task))); - urls.add(udfUrl); - } - return urls; + } + + private String buildRequest(GeaflowTask task, GeaflowArgsClass geaflowArgs) { + RayClusterArgsClass clusterArgs = (RayClusterArgsClass) geaflowArgs.getClusterArgs(); + RayPluginConfigClass rayConfig = clusterArgs.getRayConfig(); + List remoteJarUrls = getDownloadJarUrls(task); + List downloadJarPaths = new ArrayList<>(remoteJarUrls.size()); + for (String remoteJarUrl : remoteJarUrls) { + String str = remoteJarUrl.replace(".zip", ""); + String result = str.replaceAll("[:/.]+", "_"); + downloadJarPaths.add(rayConfig.getSessionResourceJarPath() + result + "/*"); } - private String formatHttp(String url) { - return url.replace("http://", "https://"); + String downloadJarClassPath = String.join(":", downloadJarPaths); + List remoteJarUrlsStr = new ArrayList<>(remoteJarUrls.size()); + for (String remoteUrl : remoteJarUrls) { + remoteJarUrlsStr.add("\"" + remoteUrl + "\""); } - - private GeaflowTaskHandle doStart(GeaflowTask task, RayClientArgsClass clientArgs) { - GeaflowArgsClass geaflowArgs = clientArgs.getGeaflowArgs(); - RayClusterArgsClass clusterArgs = (RayClusterArgsClass) geaflowArgs.getClusterArgs(); - try { - String request = buildRequest(task, geaflowArgs); - String rayUrl = clusterArgs.getRayConfig().getDashboardAddress(); - RaySubmitResponse response = HTTPUtil.post(getSubmitJobUrl(rayUrl), request, RaySubmitResponse.class); - - GeaflowTaskHandle taskHandle = new RayTaskHandle(response.submissionId); - log.info("Start task {} success, rayUrl={}, handle={}", rayUrl, task.getId(), - JSON.toJSONString(taskHandle)); - return taskHandle; - - } catch (Exception e) { - throw new GeaflowLogException("Start task {} failed", task.getId(), e); - } + String remoteJarJsonPath = String.join(",", remoteJarUrlsStr); + + String argString = StringEscapeUtils.escapeJava(JSON.toJSONString(geaflowArgs.build())); + argString = StringEscapeUtils.escapeJava("\"" + argString + "\""); + return String.format( + "{\n" + + "\"entrypoint\": \"java -classpath %s:%s -Dray.address=%s %s %s\",\n" + + "\"runtime_env\": {\"java_jars\": [%s]}\n" + + "}", + rayConfig.getDistJarPath(), + downloadJarClassPath, + rayConfig.getRedisAddress(), + task.getMainClass(), + argString, + remoteJarJsonPath); + } + + private List getDownloadJarUrls(GeaflowTask task) { + List urls = new ArrayList<>(); + GeaflowVersion version = task.getRelease().getVersion(); + + String versionUrl = formatHttp(remoteFileStorage.getUrl(getVersionFilePath(version))); + urls.add(versionUrl); + if (CollectionUtils.isNotEmpty(task.getUserJars())) { + String udfUrl = formatHttp(remoteFileStorage.getUrl(getTaskFilePath(task))); + urls.add(udfUrl); } - - - private String getTaskFilePath(GeaflowTask task) { - return RemoteFileStorage.getTaskFilePath(task.getRelease().getJob().getId(), task.getId() + "-task-udf.zip"); + return urls; + } + + private String formatHttp(String url) { + return url.replace("http://", "https://"); + } + + private GeaflowTaskHandle doStart(GeaflowTask task, RayClientArgsClass clientArgs) { + GeaflowArgsClass geaflowArgs = clientArgs.getGeaflowArgs(); + RayClusterArgsClass clusterArgs = (RayClusterArgsClass) geaflowArgs.getClusterArgs(); + try { + String request = buildRequest(task, geaflowArgs); + String rayUrl = clusterArgs.getRayConfig().getDashboardAddress(); + RaySubmitResponse response = + HTTPUtil.post(getSubmitJobUrl(rayUrl), request, RaySubmitResponse.class); + + GeaflowTaskHandle taskHandle = new RayTaskHandle(response.submissionId); + log.info( + "Start task {} success, rayUrl={}, handle={}", + rayUrl, + task.getId(), + JSON.toJSONString(taskHandle)); + return taskHandle; + + } catch (Exception e) { + throw new GeaflowLogException("Start task {} failed", task.getId(), e); } - - private String getVersionFilePath(GeaflowVersion version) { - return RemoteFileStorage.getVersionFilePath(version.getName(), version.getEngineJarPackage().getName()) - + ".zip"; + } + + private String getTaskFilePath(GeaflowTask task) { + return RemoteFileStorage.getTaskFilePath( + task.getRelease().getJob().getId(), task.getId() + "-task-udf.zip"); + } + + private String getVersionFilePath(GeaflowVersion version) { + return RemoteFileStorage.getVersionFilePath( + version.getName(), version.getEngineJarPackage().getName()) + + ".zip"; + } + + private String getVersionZipMd5Path(GeaflowVersion version, String md5) { + return String.format( + "%s_%s_zip.md5", + RemoteFileStorage.getVersionFilePath( + version.getName(), version.getEngineJarPackage().getName()), + md5); + } + + private void uploadZipFiles(GeaflowTask task) { + GeaflowVersion version = task.getRelease().getVersion(); + try { + // check file exists. + String zipPath = getVersionFilePath(version); + GeaflowRemoteFile jar = version.getEngineJarPackage(); + String versionZipMd5Path = getVersionZipMd5Path(version, jar.getMd5()); + if (!remoteFileStorage.checkFileExists(versionZipMd5Path)) { + File file = localFileFactory.getVersionFile(version.getName(), jar); + InputStream stream = ZipUtil.buildZipInputStream(new FileZipEntry(jar.getName(), file)); + fileStorage.upload(zipPath, stream); + // use the fileName as md5 + fileStorage.upload(versionZipMd5Path, new ByteArrayInputStream(new byte[] {})); + log.info("upload zip file for ray {}, {}", version.getName(), versionZipMd5Path); + } + } catch (Exception e) { + throw new GeaflowLogException("ZipFile engine failed {}", version.getName(), e); } - private String getVersionZipMd5Path(GeaflowVersion version, String md5) { - return String.format("%s_%s_zip.md5", - RemoteFileStorage.getVersionFilePath(version.getName(), version.getEngineJarPackage().getName()), md5); + List userJars = task.getUserJars(); + List udfs = new ArrayList<>(); + try { + for (GeaflowRemoteFile userJar : userJars) { + File file = localFileFactory.getUserFile(userJar.getCreatorId(), userJar); + udfs.add(new FileZipEntry(userJar.getName(), file)); + } + if (!udfs.isEmpty()) { + InputStream stream = ZipUtil.buildZipInputStream(udfs); + String zipPath = getTaskFilePath(task); + fileStorage.upload(zipPath, stream); + } + + } catch (Exception e) { + throw new GeaflowLogException("ZipFile udf failed", e); } - - private void uploadZipFiles(GeaflowTask task) { - GeaflowVersion version = task.getRelease().getVersion(); - try { - // check file exists. - String zipPath = getVersionFilePath(version); - GeaflowRemoteFile jar = version.getEngineJarPackage(); - String versionZipMd5Path = getVersionZipMd5Path(version, jar.getMd5()); - if (!remoteFileStorage.checkFileExists(versionZipMd5Path)) { - File file = localFileFactory.getVersionFile(version.getName(), jar); - InputStream stream = ZipUtil.buildZipInputStream(new FileZipEntry(jar.getName(), file)); - fileStorage.upload(zipPath, stream); - // use the fileName as md5 - fileStorage.upload(versionZipMd5Path, new ByteArrayInputStream(new byte[]{})); - log.info("upload zip file for ray {}, {}", version.getName(), versionZipMd5Path); - } - } catch (Exception e) { - throw new GeaflowLogException("ZipFile engine failed {}", version.getName(), e); - } - - List userJars = task.getUserJars(); - List udfs = new ArrayList<>(); - try { - for (GeaflowRemoteFile userJar : userJars) { - File file = localFileFactory.getUserFile(userJar.getCreatorId(), userJar); - udfs.add(new FileZipEntry(userJar.getName(), file)); - } - if (!udfs.isEmpty()) { - InputStream stream = ZipUtil.buildZipInputStream(udfs); - String zipPath = getTaskFilePath(task); - fileStorage.upload(zipPath, stream); - } - - } catch (Exception e) { - throw new GeaflowLogException("ZipFile udf failed", e); - } - } - + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/RayTaskParams.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/RayTaskParams.java index 6686986f5..1800c5266 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/RayTaskParams.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/RayTaskParams.java @@ -31,23 +31,23 @@ @Component public class RayTaskParams extends TaskParams { - public RayClientArgsClass buildClientArgs(GeaflowInstance instance, GeaflowTask task) { - return new RayClientArgsClass(buildGeaflowArgs(instance, task), task.getMainClass()); - } + public RayClientArgsClass buildClientArgs(GeaflowInstance instance, GeaflowTask task) { + return new RayClientArgsClass(buildGeaflowArgs(instance, task), task.getMainClass()); + } - @Override - protected ClusterArgsClass buildClusterArgs(GeaflowTask task) { + @Override + protected ClusterArgsClass buildClusterArgs(GeaflowTask task) { - RayClusterArgsClass clusterArgs = new RayClusterArgsClass(); - ClusterConfigClass clusterConfig = task.getRelease().getClusterConfig().parse(ClusterConfigClass.class); + RayClusterArgsClass clusterArgs = new RayClusterArgsClass(); + ClusterConfigClass clusterConfig = + task.getRelease().getClusterConfig().parse(ClusterConfigClass.class); - clusterArgs.setTaskClusterConfig(clusterConfig); - clusterArgs.setRayConfig(buildRayClusterConfig(task)); - return clusterArgs; - } - - RayPluginConfigClass buildRayClusterConfig(GeaflowTask task) { - return task.getRelease().getCluster().getConfig().parse(RayPluginConfigClass.class); - } + clusterArgs.setTaskClusterConfig(clusterConfig); + clusterArgs.setRayConfig(buildRayClusterConfig(task)); + return clusterArgs; + } + RayPluginConfigClass buildRayClusterConfig(GeaflowTask task) { + return task.getRelease().getCluster().getConfig().parse(RayPluginConfigClass.class); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/RuntimeFactory.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/RuntimeFactory.java index 45139dabd..0ea7d638a 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/RuntimeFactory.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/RuntimeFactory.java @@ -29,28 +29,26 @@ @Component public class RuntimeFactory { - @Autowired - ApplicationContext context; - - public GeaflowRuntime getRuntime(GeaflowTask task) { - GeaflowPluginType type = task.getRelease().getCluster().getType(); - - Class runtimeClass; - switch (type) { - case CONTAINER: - runtimeClass = ContainerRuntime.class; - break; - case K8S: - runtimeClass = K8sRuntime.class; - break; - case RAY: - runtimeClass = RayRuntime.class; - break; - default: - throw new GeaflowException("Unsupported runtime type {}", type); - } - - return context.getBean(runtimeClass); + @Autowired ApplicationContext context; + + public GeaflowRuntime getRuntime(GeaflowTask task) { + GeaflowPluginType type = task.getRelease().getCluster().getType(); + + Class runtimeClass; + switch (type) { + case CONTAINER: + runtimeClass = ContainerRuntime.class; + break; + case K8S: + runtimeClass = K8sRuntime.class; + break; + case RAY: + runtimeClass = RayRuntime.class; + break; + default: + throw new GeaflowException("Unsupported runtime type {}", type); } + return context.getBean(runtimeClass); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/TaskParams.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/TaskParams.java index c441981c2..2c176a036 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/TaskParams.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/runtime/TaskParams.java @@ -19,7 +19,6 @@ package org.apache.geaflow.console.core.service.runtime; -import com.google.common.base.Preconditions; import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.util.type.CatalogType; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; @@ -38,61 +37,65 @@ import org.apache.geaflow.console.core.service.config.DeployConfig; import org.springframework.beans.factory.annotation.Autowired; +import com.google.common.base.Preconditions; + public abstract class TaskParams { - private static final String RUNTIME_TASK_NAME_PREFIX = "geaflow"; - - @Autowired - protected DeployConfig deployConfig; - - public static String getRuntimeTaskName(String taskId) { - return RUNTIME_TASK_NAME_PREFIX + taskId; - } - - public void validateRuntimeTaskId(String runtimeTaskId) { - Preconditions.checkArgument(StringUtils.startsWith(runtimeTaskId, RUNTIME_TASK_NAME_PREFIX), - "Invalid runtimeTaskId %s", runtimeTaskId); - } - - protected final GeaflowArgsClass buildGeaflowArgs(GeaflowInstance instance, GeaflowTask task) { - GeaflowArgsClass geaflowArgs = new GeaflowArgsClass(); - geaflowArgs.setSystemArgs(buildSystemArgs(instance, task)); - geaflowArgs.setClusterArgs(buildClusterArgs(task)); - geaflowArgs.setJobArgs(buildJobArgs(task)); - return geaflowArgs; - } - - private SystemArgsClass buildSystemArgs(GeaflowInstance instance, GeaflowTask task) { - SystemArgsClass systemArgs = new SystemArgsClass(); - - String taskId = task.getId(); - String runtimeTaskName = getRuntimeTaskName(taskId); - String runtimeTaskId = runtimeTaskName + "-" + System.currentTimeMillis(); - - systemArgs.setTaskId(taskId); - systemArgs.setRuntimeTaskId(runtimeTaskId); - systemArgs.setRuntimeTaskName(runtimeTaskName); - systemArgs.setGateway(deployConfig.getGatewayUrl()); - systemArgs.setTaskToken(task.getToken()); - systemArgs.setStartupNotifyUrl(task.getStartupNotifyUrl(deployConfig.getGatewayUrl())); - systemArgs.setInstanceName(instance.getName()); - systemArgs.setCatalogType(CatalogType.CONSOLE.getValue()); - - StateArgsClass stateArgs = new StateArgsClass(); - stateArgs.setRuntimeMetaArgs(new RuntimeMetaArgsClass(task.getRuntimeMetaPluginConfig())); - stateArgs.setHaMetaArgs(new HaMetaArgsClass(task.getHaMetaPluginConfig())); - stateArgs.setPersistentArgs(new PersistentArgsClass(task.getDataPluginConfig())); - systemArgs.setStateArgs(stateArgs); - - systemArgs.setMetricArgs(new MetricArgsClass(task.getMetricPluginConfig())); - return systemArgs; - } - - protected abstract ClusterArgsClass buildClusterArgs(GeaflowTask task); - - private JobArgsClass buildJobArgs(GeaflowTask task) { - JobArgsClass jobArgs = new JobArgsClass(task.getRelease().getJobConfig().parse(JobConfigClass.class)); - jobArgs.setSystemStateType(GeaflowPluginType.ROCKSDB); - return jobArgs; - } + private static final String RUNTIME_TASK_NAME_PREFIX = "geaflow"; + + @Autowired protected DeployConfig deployConfig; + + public static String getRuntimeTaskName(String taskId) { + return RUNTIME_TASK_NAME_PREFIX + taskId; + } + + public void validateRuntimeTaskId(String runtimeTaskId) { + Preconditions.checkArgument( + StringUtils.startsWith(runtimeTaskId, RUNTIME_TASK_NAME_PREFIX), + "Invalid runtimeTaskId %s", + runtimeTaskId); + } + + protected final GeaflowArgsClass buildGeaflowArgs(GeaflowInstance instance, GeaflowTask task) { + GeaflowArgsClass geaflowArgs = new GeaflowArgsClass(); + geaflowArgs.setSystemArgs(buildSystemArgs(instance, task)); + geaflowArgs.setClusterArgs(buildClusterArgs(task)); + geaflowArgs.setJobArgs(buildJobArgs(task)); + return geaflowArgs; + } + + private SystemArgsClass buildSystemArgs(GeaflowInstance instance, GeaflowTask task) { + SystemArgsClass systemArgs = new SystemArgsClass(); + + String taskId = task.getId(); + String runtimeTaskName = getRuntimeTaskName(taskId); + String runtimeTaskId = runtimeTaskName + "-" + System.currentTimeMillis(); + + systemArgs.setTaskId(taskId); + systemArgs.setRuntimeTaskId(runtimeTaskId); + systemArgs.setRuntimeTaskName(runtimeTaskName); + systemArgs.setGateway(deployConfig.getGatewayUrl()); + systemArgs.setTaskToken(task.getToken()); + systemArgs.setStartupNotifyUrl(task.getStartupNotifyUrl(deployConfig.getGatewayUrl())); + systemArgs.setInstanceName(instance.getName()); + systemArgs.setCatalogType(CatalogType.CONSOLE.getValue()); + + StateArgsClass stateArgs = new StateArgsClass(); + stateArgs.setRuntimeMetaArgs(new RuntimeMetaArgsClass(task.getRuntimeMetaPluginConfig())); + stateArgs.setHaMetaArgs(new HaMetaArgsClass(task.getHaMetaPluginConfig())); + stateArgs.setPersistentArgs(new PersistentArgsClass(task.getDataPluginConfig())); + systemArgs.setStateArgs(stateArgs); + + systemArgs.setMetricArgs(new MetricArgsClass(task.getMetricPluginConfig())); + return systemArgs; + } + + protected abstract ClusterArgsClass buildClusterArgs(GeaflowTask task); + + private JobArgsClass buildJobArgs(GeaflowTask task) { + JobArgsClass jobArgs = + new JobArgsClass(task.getRelease().getJobConfig().parse(JobConfigClass.class)); + jobArgs.setSystemStateType(GeaflowPluginType.ROCKSDB); + return jobArgs; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/security/ResourceFactory.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/security/ResourceFactory.java index c04d4ae0c..d3ddaba10 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/security/ResourceFactory.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/security/ResourceFactory.java @@ -55,108 +55,102 @@ @Component public class ResourceFactory implements InitializingBean { - private static ResourceFactory INSTANCE; + private static ResourceFactory INSTANCE; - @Autowired - private InstanceDao instanceDao; + @Autowired private InstanceDao instanceDao; - @Autowired - private TableDao tableDao; + @Autowired private TableDao tableDao; - @Autowired - private ViewDao viewDao; + @Autowired private ViewDao viewDao; - @Autowired - private FunctionDao functionDao; + @Autowired private FunctionDao functionDao; - @Autowired - private VertexDao vertexDao; + @Autowired private VertexDao vertexDao; - @Autowired - private EdgeDao edgeDao; + @Autowired private EdgeDao edgeDao; - @Autowired - private GraphDao graphDao; + @Autowired private GraphDao graphDao; - @Autowired - private JobDao jobDao; + @Autowired private JobDao jobDao; - @Autowired - private TaskDao taskDao; + @Autowired private TaskDao taskDao; - protected static ResourceFactory getInstance() { - if (INSTANCE == null) { - throw new GeaflowException("{} is not ready", ResourceFactory.class.getSimpleName()); - } - return INSTANCE; + protected static ResourceFactory getInstance() { + if (INSTANCE == null) { + throw new GeaflowException("{} is not ready", ResourceFactory.class.getSimpleName()); } - - @Override - public void afterPropertiesSet() throws Exception { - INSTANCE = this; - } - - public T build(GeaflowResourceType resourceType, String resourceId) { - GeaflowResource resource; - try { - switch (resourceType) { - case TENANT: - resource = new TenantResource(resourceId); - break; - - case INSTANCE: - InstanceEntity instance = instanceDao.get(resourceId); - resource = new InstanceResource(instance.getTenantId(), instance.getId()); - break; - - case TABLE: - TableEntity table = tableDao.get(resourceId); - resource = new TableResource(table.getTenantId(), table.getInstanceId(), table.getId()); - break; - - case VIEW: - ViewEntity view = viewDao.get(resourceId); - resource = new ViewResource(view.getTenantId(), view.getInstanceId(), view.getId()); - break; - - case FUNCTION: - FunctionEntity function = functionDao.get(resourceId); - resource = new ViewResource(function.getTenantId(), function.getInstanceId(), function.getId()); - break; - - case VERTEX: - VertexEntity vertex = vertexDao.get(resourceId); - resource = new GraphResource(vertex.getTenantId(), vertex.getInstanceId(), vertex.getId()); - break; - - case EDGE: - EdgeEntity edge = edgeDao.get(resourceId); - resource = new GraphResource(edge.getTenantId(), edge.getInstanceId(), edge.getId()); - break; - - case GRAPH: - GraphEntity graph = graphDao.get(resourceId); - resource = new GraphResource(graph.getTenantId(), graph.getInstanceId(), graph.getId()); - break; - - case JOB: - JobEntity job = jobDao.get(resourceId); - resource = new JobResource(job.getTenantId(), job.getInstanceId(), job.getId()); - break; - - case TASK: - TaskEntity task = taskDao.get(resourceId); - resource = new TaskResource(build(GeaflowResourceType.JOB, task.getJobId()), task.getId()); - break; - - default: - throw new GeaflowException("Resource type {} not supported", resourceType); - } - - return (T) resource; - - } catch (Exception e) { - throw new GeaflowIllegalException("Build resource {} {} failed", resourceType, resourceId, e); - } + return INSTANCE; + } + + @Override + public void afterPropertiesSet() throws Exception { + INSTANCE = this; + } + + public T build(GeaflowResourceType resourceType, String resourceId) { + GeaflowResource resource; + try { + switch (resourceType) { + case TENANT: + resource = new TenantResource(resourceId); + break; + + case INSTANCE: + InstanceEntity instance = instanceDao.get(resourceId); + resource = new InstanceResource(instance.getTenantId(), instance.getId()); + break; + + case TABLE: + TableEntity table = tableDao.get(resourceId); + resource = new TableResource(table.getTenantId(), table.getInstanceId(), table.getId()); + break; + + case VIEW: + ViewEntity view = viewDao.get(resourceId); + resource = new ViewResource(view.getTenantId(), view.getInstanceId(), view.getId()); + break; + + case FUNCTION: + FunctionEntity function = functionDao.get(resourceId); + resource = + new ViewResource(function.getTenantId(), function.getInstanceId(), function.getId()); + break; + + case VERTEX: + VertexEntity vertex = vertexDao.get(resourceId); + resource = + new GraphResource(vertex.getTenantId(), vertex.getInstanceId(), vertex.getId()); + break; + + case EDGE: + EdgeEntity edge = edgeDao.get(resourceId); + resource = new GraphResource(edge.getTenantId(), edge.getInstanceId(), edge.getId()); + break; + + case GRAPH: + GraphEntity graph = graphDao.get(resourceId); + resource = new GraphResource(graph.getTenantId(), graph.getInstanceId(), graph.getId()); + break; + + case JOB: + JobEntity job = jobDao.get(resourceId); + resource = new JobResource(job.getTenantId(), job.getInstanceId(), job.getId()); + break; + + case TASK: + TaskEntity task = taskDao.get(resourceId); + resource = + new TaskResource(build(GeaflowResourceType.JOB, task.getJobId()), task.getId()); + break; + + default: + throw new GeaflowException("Resource type {} not supported", resourceType); + } + + return (T) resource; + + } catch (Exception e) { + throw new GeaflowIllegalException("Build resource {} {} failed", resourceType, resourceId, e); } + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/security/Resources.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/security/Resources.java index e95f2bfbd..0bbe21458 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/security/Resources.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/security/Resources.java @@ -33,47 +33,47 @@ public class Resources { - private static ResourceFactory getResourceFactory() { - return ResourceFactory.getInstance(); - } + private static ResourceFactory getResourceFactory() { + return ResourceFactory.getInstance(); + } - public static TenantResource tenant(String tenantId) { - return getResourceFactory().build(GeaflowResourceType.TENANT, tenantId); - } + public static TenantResource tenant(String tenantId) { + return getResourceFactory().build(GeaflowResourceType.TENANT, tenantId); + } - public static InstanceResource instance(String instanceId) { - return getResourceFactory().build(GeaflowResourceType.INSTANCE, instanceId); - } + public static InstanceResource instance(String instanceId) { + return getResourceFactory().build(GeaflowResourceType.INSTANCE, instanceId); + } - public static TableResource table(String tableId) { - return getResourceFactory().build(GeaflowResourceType.TABLE, tableId); - } + public static TableResource table(String tableId) { + return getResourceFactory().build(GeaflowResourceType.TABLE, tableId); + } - public static ViewResource view(String viewId) { - return getResourceFactory().build(GeaflowResourceType.VIEW, viewId); - } + public static ViewResource view(String viewId) { + return getResourceFactory().build(GeaflowResourceType.VIEW, viewId); + } - public static FunctionResource function(String functionId) { - return getResourceFactory().build(GeaflowResourceType.FUNCTION, functionId); - } + public static FunctionResource function(String functionId) { + return getResourceFactory().build(GeaflowResourceType.FUNCTION, functionId); + } - public static VertexResource vertex(String vertexId) { - return getResourceFactory().build(GeaflowResourceType.VERTEX, vertexId); - } + public static VertexResource vertex(String vertexId) { + return getResourceFactory().build(GeaflowResourceType.VERTEX, vertexId); + } - public static EdgeResource edge(String edgeId) { - return getResourceFactory().build(GeaflowResourceType.EDGE, edgeId); - } + public static EdgeResource edge(String edgeId) { + return getResourceFactory().build(GeaflowResourceType.EDGE, edgeId); + } - public static GraphResource graph(String graphId) { - return getResourceFactory().build(GeaflowResourceType.GRAPH, graphId); - } + public static GraphResource graph(String graphId) { + return getResourceFactory().build(GeaflowResourceType.GRAPH, graphId); + } - public static JobResource job(String jobId) { - return getResourceFactory().build(GeaflowResourceType.JOB, jobId); - } + public static JobResource job(String jobId) { + return getResourceFactory().build(GeaflowResourceType.JOB, jobId); + } - public static TaskResource task(String taskId) { - return getResourceFactory().build(GeaflowResourceType.TASK, taskId); - } + public static TaskResource task(String taskId) { + return getResourceFactory().build(GeaflowResourceType.TASK, taskId); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/security/TokenGenerator.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/security/TokenGenerator.java index dedbf7f2f..670b304c0 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/security/TokenGenerator.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/security/TokenGenerator.java @@ -26,18 +26,17 @@ @Component public class TokenGenerator { - private static final String TASK_TOKEN_PREFIX = "TASK-"; + private static final String TASK_TOKEN_PREFIX = "TASK-"; - public static boolean isTaskToken(String token) { - return StringUtils.startsWith(token, TASK_TOKEN_PREFIX); - } + public static boolean isTaskToken(String token) { + return StringUtils.startsWith(token, TASK_TOKEN_PREFIX); + } - public String nextToken() { - return RandomStringUtils.randomAlphanumeric(32); - } - - public String nextTaskToken() { - return TASK_TOKEN_PREFIX + nextToken(); - } + public String nextToken() { + return RandomStringUtils.randomAlphanumeric(32); + } + public String nextTaskToken() { + return TASK_TOKEN_PREFIX + nextToken(); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/statement/AnalyticsClientFactory.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/statement/AnalyticsClientFactory.java index fbcce8af0..1b56dda07 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/statement/AnalyticsClientFactory.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/statement/AnalyticsClientFactory.java @@ -19,7 +19,6 @@ package org.apache.geaflow.console.core.service.statement; -import lombok.extern.slf4j.Slf4j; import org.apache.geaflow.console.common.service.integration.engine.analytics.AnalyticsClient; import org.apache.geaflow.console.common.service.integration.engine.analytics.AnalyticsClientBuilder; import org.apache.geaflow.console.common.service.integration.engine.analytics.Configuration; @@ -29,25 +28,24 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; +import lombok.extern.slf4j.Slf4j; + @Component @Slf4j public class AnalyticsClientFactory { - @Autowired - private VersionFactory versionFactory; - - public AnalyticsClient buildClient(GeaflowTask task) { - final VersionClassLoader classLoader = versionFactory.getClassLoader(task.getRelease().getVersion()); - final AnalyticsClientBuilder builder = classLoader.newInstance(AnalyticsClientBuilder.class); - Configuration configuration = classLoader.newInstance(Configuration.class); - final String redisParentNamespace = "/geaflow" + task.getId(); - configuration.putAll(task.getRelease().getJobConfig().toStringMap()); - configuration.put("brpc.connect.timeout.ms", String.valueOf(8000)); - configuration.put("geaflow.meta.server.retry.times", String.valueOf(2)); - configuration.put("geaflow.job.runtime.name", redisParentNamespace); - return builder.withConfiguration(configuration) - .withInitChannelPools(true) - .build(); - } + @Autowired private VersionFactory versionFactory; + public AnalyticsClient buildClient(GeaflowTask task) { + final VersionClassLoader classLoader = + versionFactory.getClassLoader(task.getRelease().getVersion()); + final AnalyticsClientBuilder builder = classLoader.newInstance(AnalyticsClientBuilder.class); + Configuration configuration = classLoader.newInstance(Configuration.class); + final String redisParentNamespace = "/geaflow" + task.getId(); + configuration.putAll(task.getRelease().getJobConfig().toStringMap()); + configuration.put("brpc.connect.timeout.ms", String.valueOf(8000)); + configuration.put("geaflow.meta.server.retry.times", String.valueOf(2)); + configuration.put("geaflow.job.runtime.name", redisParentNamespace); + return builder.withConfiguration(configuration).withInitChannelPools(true).build(); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/statement/AnalyticsClientPool.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/statement/AnalyticsClientPool.java index d77fec6a3..86d7bd4e1 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/statement/AnalyticsClientPool.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/statement/AnalyticsClientPool.java @@ -19,55 +19,58 @@ package org.apache.geaflow.console.core.service.statement; -import com.google.common.cache.CacheBuilder; -import com.google.common.cache.CacheLoader; -import com.google.common.cache.LoadingCache; import java.util.concurrent.ExecutionException; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; -import lombok.extern.slf4j.Slf4j; + import org.apache.geaflow.console.common.service.integration.engine.analytics.AnalyticsClient; import org.apache.geaflow.console.core.model.task.GeaflowTask; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; + +import lombok.extern.slf4j.Slf4j; + @Component @Slf4j public class AnalyticsClientPool { - @Autowired - private AnalyticsClientFactory analyticsClientFactory; + @Autowired private AnalyticsClientFactory analyticsClientFactory; - private static final LoadingCache> clientPool = CacheBuilder.newBuilder().maximumSize(100) - .expireAfterWrite(180, TimeUnit.SECONDS).build(new CacheLoader>() { - @Override - public LinkedBlockingQueue load(String jobId) { - return new LinkedBlockingQueue<>(10); - } - }); + private static final LoadingCache> clientPool = + CacheBuilder.newBuilder() + .maximumSize(100) + .expireAfterWrite(180, TimeUnit.SECONDS) + .build( + new CacheLoader>() { + @Override + public LinkedBlockingQueue load(String jobId) { + return new LinkedBlockingQueue<>(10); + } + }); - public AnalyticsClient getClient(GeaflowTask task) { - try { - LinkedBlockingQueue analyticsClients = clientPool.get(task.getId()); - AnalyticsClient client = analyticsClients.poll(); - if (client == null) { - client = analyticsClientFactory.buildClient(task); - } - return client; - } catch (ExecutionException e) { - throw new RuntimeException(e); - } + public AnalyticsClient getClient(GeaflowTask task) { + try { + LinkedBlockingQueue analyticsClients = clientPool.get(task.getId()); + AnalyticsClient client = analyticsClients.poll(); + if (client == null) { + client = analyticsClientFactory.buildClient(task); + } + return client; + } catch (ExecutionException e) { + throw new RuntimeException(e); } + } - - public void addClient(GeaflowTask task, AnalyticsClient client) { - try { - LinkedBlockingQueue analyticsClients = clientPool.get(task.getId()); - analyticsClients.offer(client); - } catch (Exception e) { - log.info("add client fail", e); - } + public void addClient(GeaflowTask task, AnalyticsClient client) { + try { + LinkedBlockingQueue analyticsClients = clientPool.get(task.getId()); + analyticsClients.offer(client); + } catch (Exception e) { + log.info("add client fail", e); } - - + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/statement/StatementSubmitter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/statement/StatementSubmitter.java index 02d6cf123..94f75ee7b 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/statement/StatementSubmitter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/statement/StatementSubmitter.java @@ -23,7 +23,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; -import lombok.extern.slf4j.Slf4j; + import org.apache.commons.lang3.exception.ExceptionUtils; import org.apache.geaflow.console.common.service.integration.engine.analytics.AnalyticsClient; import org.apache.geaflow.console.common.service.integration.engine.analytics.QueryResults; @@ -37,72 +37,71 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; +import lombok.extern.slf4j.Slf4j; + @Service @Slf4j public class StatementSubmitter { - private static final ExecutorService EXECUTOR_SERVICE = new ThreadPoolExecutor(10, 30, - 30, TimeUnit.SECONDS, new LinkedBlockingQueue<>(100)); + private static final ExecutorService EXECUTOR_SERVICE = + new ThreadPoolExecutor(10, 30, 30, TimeUnit.SECONDS, new LinkedBlockingQueue<>(100)); - private static final String queryTemplate = "USE GRAPH %s; \n %s;"; - @Autowired - private VersionFactory versionFactory; + private static final String queryTemplate = "USE GRAPH %s; \n %s;"; + @Autowired private VersionFactory versionFactory; - @Autowired - private AnalyticsClientPool analyticsClientPool; + @Autowired private AnalyticsClientPool analyticsClientPool; - @Autowired - private StatementService statementService; + @Autowired private StatementService statementService; - public void asyncSubmitQuery(GeaflowStatement query, GeaflowTask task) { - final String sessionToken = ContextHolder.get().getSessionToken(); - EXECUTOR_SERVICE.submit(() -> { - try { - ContextHolder.init(); - ContextHolder.get().setSessionToken(sessionToken); + public void asyncSubmitQuery(GeaflowStatement query, GeaflowTask task) { + final String sessionToken = ContextHolder.get().getSessionToken(); + EXECUTOR_SERVICE.submit( + () -> { + try { + ContextHolder.init(); + ContextHolder.get().setSessionToken(sessionToken); - submitQuery(query, task); + submitQuery(query, task); - } finally { - ContextHolder.destroy(); - } + } finally { + ContextHolder.destroy(); + } }); - } + } - private void submitQuery(GeaflowStatement query, GeaflowTask task) { - GeaflowStatementStatus status = null; - String result = null; - AnalyticsClient client = null; - try { - String script = formatQuery(query.getScript(), task); - client = analyticsClientPool.getClient(task); - QueryResults queryResults = client.executeQuery(script); - if (queryResults.getQueryStatus()) { - status = GeaflowStatementStatus.FINISHED; - result = queryResults.getFormattedData(); - } else { - status = GeaflowStatementStatus.FAILED; - result = queryResults.getError().getName(); - } - - } catch (Exception e) { - status = GeaflowStatementStatus.FAILED; - result = ExceptionUtils.getStackTrace(e); - - } finally { - log.info("query finish {}, {}, {}", query.getScript(), status, result); - query.setStatus(status); - query.setResult(result); - statementService.update(query); - if (client != null) { - analyticsClientPool.addClient(task, client); - } - - } - } + private void submitQuery(GeaflowStatement query, GeaflowTask task) { + GeaflowStatementStatus status = null; + String result = null; + AnalyticsClient client = null; + try { + String script = formatQuery(query.getScript(), task); + client = analyticsClientPool.getClient(task); + QueryResults queryResults = client.executeQuery(script); + if (queryResults.getQueryStatus()) { + status = GeaflowStatementStatus.FINISHED; + result = queryResults.getFormattedData(); + } else { + status = GeaflowStatementStatus.FAILED; + result = queryResults.getError().getName(); + } - private String formatQuery(String script, GeaflowTask task) { - GeaflowGraph graph = task.getRelease().getJob().getGraphs().get(0); - return String.format(queryTemplate, graph.getName(), script); + } catch (Exception e) { + status = GeaflowStatementStatus.FAILED; + result = ExceptionUtils.getStackTrace(e); + + } finally { + log.info("query finish {}, {}, {}", query.getScript(), status, result); + query.setStatus(status); + query.setResult(result); + statementService.update(query); + if (client != null) { + analyticsClientPool.addClient(task, client); + } } + } + + private String formatQuery(String script, GeaflowTask task) { + GeaflowGraph graph = task.getRelease().getJob().getGraphs().get(0); + return String.format(queryTemplate, graph.getName(), script); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/GeaflowDataStore.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/GeaflowDataStore.java index 4b1615a5c..7ee69ba8d 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/GeaflowDataStore.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/GeaflowDataStore.java @@ -20,19 +20,19 @@ package org.apache.geaflow.console.core.service.store; import java.util.Date; + import org.apache.geaflow.console.core.model.data.GeaflowGraph; import org.apache.geaflow.console.core.model.task.GeaflowTask; public interface GeaflowDataStore { - Long queryStorageUsage(GeaflowTask task); - - Long queryFileCount(GeaflowTask task); + Long queryStorageUsage(GeaflowTask task); - Date queryModifyTime(GeaflowTask task); + Long queryFileCount(GeaflowTask task); - void cleanTaskData(GeaflowTask task); + Date queryModifyTime(GeaflowTask task); - void cleanGraphData(GeaflowGraph graph); + void cleanTaskData(GeaflowTask task); + void cleanGraphData(GeaflowGraph graph); } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/GeaflowHaMetaStore.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/GeaflowHaMetaStore.java index c0c476a07..1b3d3459b 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/GeaflowHaMetaStore.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/GeaflowHaMetaStore.java @@ -23,6 +23,5 @@ public interface GeaflowHaMetaStore { - void cleanHaMeta(GeaflowTask task); - + void cleanHaMeta(GeaflowTask task); } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/GeaflowMetricStore.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/GeaflowMetricStore.java index f2aaa34ad..b6508f542 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/GeaflowMetricStore.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/GeaflowMetricStore.java @@ -26,6 +26,5 @@ public interface GeaflowMetricStore { - PageList queryMetrics(GeaflowTask task, GeaflowMetricQueryRequest queryRequest); - + PageList queryMetrics(GeaflowTask task, GeaflowMetricQueryRequest queryRequest); } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/GeaflowRuntimeMetaStore.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/GeaflowRuntimeMetaStore.java index 1b2688add..72d6fe30a 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/GeaflowRuntimeMetaStore.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/GeaflowRuntimeMetaStore.java @@ -30,18 +30,17 @@ public interface GeaflowRuntimeMetaStore { - PageList queryPipelines(GeaflowTask task); + PageList queryPipelines(GeaflowTask task); - PageList queryCycles(GeaflowTask task, String pipelineId); + PageList queryCycles(GeaflowTask task, String pipelineId); - PageList queryOffsets(GeaflowTask task); + PageList queryOffsets(GeaflowTask task); - PageList queryErrors(GeaflowTask task); + PageList queryErrors(GeaflowTask task); - PageList queryMetricMeta(GeaflowTask task); + PageList queryMetricMeta(GeaflowTask task); - GeaflowHeartbeatInfo queryHeartbeat(GeaflowTask task); - - void cleanRuntimeMeta(GeaflowTask task); + GeaflowHeartbeatInfo queryHeartbeat(GeaflowTask task); + void cleanRuntimeMeta(GeaflowTask task); } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/factory/DataStoreFactory.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/factory/DataStoreFactory.java index d3e4c1478..b850a9266 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/factory/DataStoreFactory.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/factory/DataStoreFactory.java @@ -31,22 +31,19 @@ @Component public class DataStoreFactory { - @Autowired - ApplicationContext context; + @Autowired ApplicationContext context; - @Autowired - PluginService pluginService; + @Autowired PluginService pluginService; - public GeaflowDataStore getDataStore(String type) { - GeaflowPluginType typeEnum = GeaflowPluginType.of(type); - switch (typeEnum) { - case LOCAL: - case DFS: - case OSS: - return context.getBean(PersistentDataStore.class); - default: - throw new GeaflowIllegalException("Not supported data store type {}", type); - } + public GeaflowDataStore getDataStore(String type) { + GeaflowPluginType typeEnum = GeaflowPluginType.of(type); + switch (typeEnum) { + case LOCAL: + case DFS: + case OSS: + return context.getBean(PersistentDataStore.class); + default: + throw new GeaflowIllegalException("Not supported data store type {}", type); } - + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/factory/HaMetaStoreFactory.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/factory/HaMetaStoreFactory.java index f3f37d440..67a6091ea 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/factory/HaMetaStoreFactory.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/factory/HaMetaStoreFactory.java @@ -31,17 +31,16 @@ @Component public class HaMetaStoreFactory { - @Autowired - ApplicationContext context; + @Autowired ApplicationContext context; - public GeaflowHaMetaStore getHaMetaStore(GeaflowPluginConfig pluginConfig) { - GeaflowPluginType type = GeaflowPluginType.of(pluginConfig.getType()); - switch (type) { - case REDIS: - return context.getBean(RedisStore.class); - default: - throw new GeaflowIllegalException("Not supported HA meta store type {}", pluginConfig.getType()); - } + public GeaflowHaMetaStore getHaMetaStore(GeaflowPluginConfig pluginConfig) { + GeaflowPluginType type = GeaflowPluginType.of(pluginConfig.getType()); + switch (type) { + case REDIS: + return context.getBean(RedisStore.class); + default: + throw new GeaflowIllegalException( + "Not supported HA meta store type {}", pluginConfig.getType()); } - + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/factory/MetricStoreFactory.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/factory/MetricStoreFactory.java index e1aef4221..27c140d33 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/factory/MetricStoreFactory.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/factory/MetricStoreFactory.java @@ -31,17 +31,16 @@ @Component public class MetricStoreFactory { - @Autowired - ApplicationContext context; + @Autowired ApplicationContext context; - public GeaflowMetricStore getMetricStore(GeaflowPluginConfig pluginConfig) { - GeaflowPluginType type = GeaflowPluginType.of(pluginConfig.getType()); - switch (type) { - case INFLUXDB: - return context.getBean(InfluxdbStore.class); - default: - throw new GeaflowIllegalException("Not supported metric store type {}", pluginConfig.getType()); - } + public GeaflowMetricStore getMetricStore(GeaflowPluginConfig pluginConfig) { + GeaflowPluginType type = GeaflowPluginType.of(pluginConfig.getType()); + switch (type) { + case INFLUXDB: + return context.getBean(InfluxdbStore.class); + default: + throw new GeaflowIllegalException( + "Not supported metric store type {}", pluginConfig.getType()); } - + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/factory/RuntimeMetaStoreFactory.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/factory/RuntimeMetaStoreFactory.java index 4b8023d9f..234b7da5b 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/factory/RuntimeMetaStoreFactory.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/factory/RuntimeMetaStoreFactory.java @@ -31,17 +31,16 @@ @Component public class RuntimeMetaStoreFactory { - @Autowired - ApplicationContext context; + @Autowired ApplicationContext context; - public GeaflowRuntimeMetaStore getRuntimeMetaStore(GeaflowPluginConfig pluginConfig) { - GeaflowPluginType type = GeaflowPluginType.of(pluginConfig.getType()); - switch (type) { - case JDBC: - return context.getBean(JdbcStore.class); - default: - throw new GeaflowIllegalException("Not supported runtime meta store type {}", pluginConfig.getType()); - } + public GeaflowRuntimeMetaStore getRuntimeMetaStore(GeaflowPluginConfig pluginConfig) { + GeaflowPluginType type = GeaflowPluginType.of(pluginConfig.getType()); + switch (type) { + case JDBC: + return context.getBean(JdbcStore.class); + default: + throw new GeaflowIllegalException( + "Not supported runtime meta store type {}", pluginConfig.getType()); } - + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/impl/InfluxdbStore.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/impl/InfluxdbStore.java index 98e4111aa..79519bc72 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/impl/InfluxdbStore.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/impl/InfluxdbStore.java @@ -19,22 +19,12 @@ package org.apache.geaflow.console.core.service.store.impl; -import com.google.common.cache.Cache; -import com.google.common.cache.CacheBuilder; -import com.google.common.cache.RemovalListener; -import com.influxdb.client.InfluxDBClientOptions; -import com.influxdb.client.QueryApi; -import com.influxdb.client.internal.InfluxDBClientImpl; -import com.influxdb.query.FluxTable; -import com.influxdb.query.dsl.Flux; -import com.influxdb.query.dsl.functions.restriction.Restrictions; import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.concurrent.TimeUnit; -import lombok.extern.slf4j.Slf4j; -import okhttp3.OkHttpClient; + import org.apache.commons.collections.CollectionUtils; import org.apache.geaflow.console.common.dal.model.PageList; import org.apache.geaflow.console.common.util.exception.GeaflowException; @@ -46,106 +36,133 @@ import org.apache.geaflow.console.core.service.store.GeaflowMetricStore; import org.springframework.stereotype.Service; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.RemovalListener; +import com.influxdb.client.InfluxDBClientOptions; +import com.influxdb.client.QueryApi; +import com.influxdb.client.internal.InfluxDBClientImpl; +import com.influxdb.query.FluxTable; +import com.influxdb.query.dsl.Flux; +import com.influxdb.query.dsl.functions.restriction.Restrictions; + +import lombok.extern.slf4j.Slf4j; +import okhttp3.OkHttpClient; + @Service @Slf4j public class InfluxdbStore implements GeaflowMetricStore { - private static final Cache INFLUXDB_CLIENT_CACHE = - CacheBuilder.newBuilder().initialCapacity(10).maximumSize(100) - .expireAfterWrite(120, TimeUnit.SECONDS) - .removalListener( - (RemovalListener) removalNotification -> { + private static final Cache INFLUXDB_CLIENT_CACHE = + CacheBuilder.newBuilder() + .initialCapacity(10) + .maximumSize(100) + .expireAfterWrite(120, TimeUnit.SECONDS) + .removalListener( + (RemovalListener) + removalNotification -> { String key = removalNotification.getKey(); InfluxDBClientImpl client = removalNotification.getValue(); if (client != null) { - log.debug("influxdb {} close", key); - client.close(); + log.debug("influxdb {} close", key); + client.close(); } - }) - .build(); + }) + .build(); - @Override - public PageList queryMetrics(GeaflowTask task, GeaflowMetricQueryRequest queryRequest) { - if (CollectionUtils.isEmpty(queryRequest.getQueries())) { - return new PageList<>(new ArrayList<>()); - } - - InfluxdbPluginConfigClass influxdbConfig = task.getMetricPluginConfig().getConfig().parse(InfluxdbPluginConfigClass.class); - Flux flux = buildQueryFlux(influxdbConfig.getBucket(), queryRequest); - InfluxDBClientImpl client = getInfluxdbClient(influxdbConfig); - QueryApi queryApi = client.getQueryApi(); - List tables = queryApi.query(flux.toString()); - return new PageList<>(buildGeaflowMetrics(tables)); + @Override + public PageList queryMetrics( + GeaflowTask task, GeaflowMetricQueryRequest queryRequest) { + if (CollectionUtils.isEmpty(queryRequest.getQueries())) { + return new PageList<>(new ArrayList<>()); } - protected Flux buildQueryFlux(String bucketName, GeaflowMetricQueryRequest queryRequest) { - long startMs = queryRequest.getStart(); - long endMs = queryRequest.getEnd(); - List queries = queryRequest.getQueries(); - GeaflowMetricQuery metricQuery = queries.get(0); - Restrictions restriction = Restrictions.and( + InfluxdbPluginConfigClass influxdbConfig = + task.getMetricPluginConfig().getConfig().parse(InfluxdbPluginConfigClass.class); + Flux flux = buildQueryFlux(influxdbConfig.getBucket(), queryRequest); + InfluxDBClientImpl client = getInfluxdbClient(influxdbConfig); + QueryApi queryApi = client.getQueryApi(); + List tables = queryApi.query(flux.toString()); + return new PageList<>(buildGeaflowMetrics(tables)); + } + + protected Flux buildQueryFlux(String bucketName, GeaflowMetricQueryRequest queryRequest) { + long startMs = queryRequest.getStart(); + long endMs = queryRequest.getEnd(); + List queries = queryRequest.getQueries(); + GeaflowMetricQuery metricQuery = queries.get(0); + Restrictions restriction = + Restrictions.and( Restrictions.tag("jobName").equal(metricQuery.getTags().get("jobName")), - Restrictions.measurement().equal(metricQuery.getMetric()) - ); - Flux flux = Flux - .from(bucketName) + Restrictions.measurement().equal(metricQuery.getMetric())); + Flux flux = + Flux.from(bucketName) .range(startMs / 1000, endMs / 1000) .filter(restriction) .window(60L, ChronoUnit.SECONDS); - String[] downsample = metricQuery.getDownsample().split("-"); - flux = setAggregator(flux, downsample[1]); - return flux.duplicate("_stop", "_time"); - } + String[] downsample = metricQuery.getDownsample().split("-"); + flux = setAggregator(flux, downsample[1]); + return flux.duplicate("_stop", "_time"); + } - protected Flux setAggregator(Flux flux, String aggregator) { - switch (aggregator) { - case "avg": - return flux.mean(); - case "sum": - return flux.sum(); - default: - throw new GeaflowException("not supported aggregator: {}", aggregator); - } + protected Flux setAggregator(Flux flux, String aggregator) { + switch (aggregator) { + case "avg": + return flux.mean(); + case "sum": + return flux.sum(); + default: + throw new GeaflowException("not supported aggregator: {}", aggregator); } + } - protected List buildGeaflowMetrics(List tables) { - List geaflowMetricList = new ArrayList<>(); - tables.forEach(fluxTable -> fluxTable.getRecords().forEach(fluxRecord -> { - GeaflowMetric geaflowMetric = new GeaflowMetric(); - String lineName = - Optional.ofNullable((String) fluxRecord.getValues().get("worker")).orElse(fluxRecord.getMeasurement()); - geaflowMetric.setMetric(lineName); - geaflowMetric.setTime(fluxRecord.getTime().toEpochMilli()); - geaflowMetric.setValue(fluxRecord.getValue()); - geaflowMetricList.add(geaflowMetric); - })); - return geaflowMetricList; - } + protected List buildGeaflowMetrics(List tables) { + List geaflowMetricList = new ArrayList<>(); + tables.forEach( + fluxTable -> + fluxTable + .getRecords() + .forEach( + fluxRecord -> { + GeaflowMetric geaflowMetric = new GeaflowMetric(); + String lineName = + Optional.ofNullable((String) fluxRecord.getValues().get("worker")) + .orElse(fluxRecord.getMeasurement()); + geaflowMetric.setMetric(lineName); + geaflowMetric.setTime(fluxRecord.getTime().toEpochMilli()); + geaflowMetric.setValue(fluxRecord.getValue()); + geaflowMetricList.add(geaflowMetric); + })); + return geaflowMetricList; + } - protected synchronized InfluxDBClientImpl getInfluxdbClient(InfluxdbPluginConfigClass influxdbConfig) { - String cacheKey = influxdbConfig.toString(); - InfluxDBClientImpl client = INFLUXDB_CLIENT_CACHE.getIfPresent(cacheKey); - if (client == null) { - client = buildInfluxDBClientImpl(influxdbConfig); - INFLUXDB_CLIENT_CACHE.put(cacheKey, client); - } - return client; + protected synchronized InfluxDBClientImpl getInfluxdbClient( + InfluxdbPluginConfigClass influxdbConfig) { + String cacheKey = influxdbConfig.toString(); + InfluxDBClientImpl client = INFLUXDB_CLIENT_CACHE.getIfPresent(cacheKey); + if (client == null) { + client = buildInfluxDBClientImpl(influxdbConfig); + INFLUXDB_CLIENT_CACHE.put(cacheKey, client); } + return client; + } - private InfluxDBClientImpl buildInfluxDBClientImpl(InfluxdbPluginConfigClass influxdbConfig) { - log.debug("build influxdb: {}", influxdbConfig.toString()); - Integer connectTimeout = Optional.ofNullable(influxdbConfig.getConnectTimeout()).orElse(30000); - Integer writeTimeout = Optional.ofNullable(influxdbConfig.getWriteTimeout()).orElse(30000); - OkHttpClient.Builder httpClient = new OkHttpClient.Builder() + private InfluxDBClientImpl buildInfluxDBClientImpl(InfluxdbPluginConfigClass influxdbConfig) { + log.debug("build influxdb: {}", influxdbConfig.toString()); + Integer connectTimeout = Optional.ofNullable(influxdbConfig.getConnectTimeout()).orElse(30000); + Integer writeTimeout = Optional.ofNullable(influxdbConfig.getWriteTimeout()).orElse(30000); + OkHttpClient.Builder httpClient = + new OkHttpClient.Builder() .connectTimeout(connectTimeout, TimeUnit.MILLISECONDS) .writeTimeout(writeTimeout, TimeUnit.MILLISECONDS); - InfluxDBClientOptions options = InfluxDBClientOptions.builder() + InfluxDBClientOptions options = + InfluxDBClientOptions.builder() .okHttpClient(httpClient) .url(influxdbConfig.getUrl()) .org(influxdbConfig.getOrg()) .bucket(influxdbConfig.getBucket()) .authenticateToken(influxdbConfig.getToken().toCharArray()) .build(); - return new InfluxDBClientImpl(options); - } + return new InfluxDBClientImpl(options); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/impl/JdbcStore.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/impl/JdbcStore.java index dfdfee085..b8668cc8c 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/impl/JdbcStore.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/impl/JdbcStore.java @@ -21,8 +21,6 @@ import static org.apache.geaflow.console.core.model.plugin.config.JdbcPluginConfigClass.MYSQL_DRIVER_CLASS; -import com.alibaba.fastjson.JSON; -import com.google.common.collect.Lists; import java.nio.charset.StandardCharsets; import java.util.Date; import java.util.Iterator; @@ -30,10 +28,9 @@ import java.util.Map; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; + import javax.sql.DataSource; -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.Setter; + import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.console.common.dal.model.PageList; @@ -55,188 +52,219 @@ import org.springframework.jdbc.datasource.DriverManagerDataSource; import org.springframework.stereotype.Component; -@Component -public class JdbcStore implements GeaflowRuntimeMetaStore { - - private static final String QUERY_TASK_ALL_KEY = "_"; - - private static final String QUERY_TASK_OFFSET_KEY = "_offset_"; - - private static final String QUERY_TASK_EXCEPTION_KEY = "_exception_"; - - private static final String QUERY_METRIC_META_KEY = "_metrics_META_"; +import com.alibaba.fastjson.JSON; +import com.google.common.collect.Lists; - private static final String QUERY_HEARTBEAT_KEY = "_heartbeat_"; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Setter; - private static final String QUERY_TASK_PIPELINE_KEY = "_metrics_PIPELINE_"; +@Component +public class JdbcStore implements GeaflowRuntimeMetaStore { - private static final String QUERY_TASK_CYCLE_KEY = "_metrics_CYCLE_"; + private static final String QUERY_TASK_ALL_KEY = "_"; + private static final String QUERY_TASK_OFFSET_KEY = "_offset_"; - private static final Object lock = new Object(); - private static final Map jdbcTemplateCache = new ConcurrentHashMap(); + private static final String QUERY_TASK_EXCEPTION_KEY = "_exception_"; - private static String getSqlByPkLimitedAndOrdered(String tableName, String pk, Integer count) { - return String.format( - "SELECT pk, value, gmt_modified from %s WHERE pk LIKE '%s%%' order by gmt_modified desc limit %s", - tableName, pk, count); - } + private static final String QUERY_METRIC_META_KEY = "_metrics_META_"; - private static String getSqlByPkLimitedAndOrderedWithStartTime(String tableName, String pk, Integer count, - Long startTime) { - if (startTime == null) { - return getSqlByPkLimitedAndOrdered(tableName, pk, count); - } - return String.format("SELECT pk, value, gmt_modified from %s WHERE pk LIKE '%s%%' AND gmt_modified > '%s' " - + "order by gmt_modified desc limit %s", tableName, pk, DateTimeUtil.format(new Date(startTime)), count); - } + private static final String QUERY_HEARTBEAT_KEY = "_heartbeat_"; - private static String deleteSqlByPkLike(String tableName, String pk) { - return String.format("DELETE from %s WHERE pk LIKE '%s%%'", tableName, pk); - } + private static final String QUERY_TASK_PIPELINE_KEY = "_metrics_PIPELINE_"; - @Override - public PageList queryPipelines(GeaflowTask task) { - List pipelines = queryMessage(task, getQueryKey(task, QUERY_TASK_PIPELINE_KEY), 500, null); - return new PageList<>(ListUtil.convert(pipelines, pipeline -> JSON.parseObject(pipeline.getValue(), GeaflowPipeline.class))); - } + private static final String QUERY_TASK_CYCLE_KEY = "_metrics_CYCLE_"; - @Override - public PageList queryCycles(GeaflowTask task, String pipelineId) { - List cycles = queryMessage(task, getQueryKey(task, QUERY_TASK_CYCLE_KEY + pipelineId), 500, null); - return new PageList<>(ListUtil.convert(cycles, cycle -> JSON.parseObject(cycle.getValue(), GeaflowCycle.class))); - } + private static final Object lock = new Object(); + private static final Map jdbcTemplateCache = new ConcurrentHashMap(); - @Override - public PageList queryOffsets(GeaflowTask task) { - String queryKey = getQueryKey(task, QUERY_TASK_OFFSET_KEY); - List offsets = queryMessage(task, queryKey, 500, null); - List list = ListUtil.convert(offsets, e -> { - GeaflowOffset offset = JSON.parseObject(e.getValue(), GeaflowOffset.class); - offset.setPartitionName(e.getPk().substring(queryKey.length())); - offset.formatTime(); - return offset; - }); - return new PageList<>(list); - } + private static String getSqlByPkLimitedAndOrdered(String tableName, String pk, Integer count) { + return String.format( + "SELECT pk, value, gmt_modified from %s WHERE pk LIKE '%s%%' order by gmt_modified desc" + + " limit %s", + tableName, pk, count); + } - @Override - public PageList queryErrors(GeaflowTask task) { - List errors = queryMessage(task, getQueryKey(task, QUERY_TASK_EXCEPTION_KEY), 500, 0L); - return new PageList<>(ListUtil.convert(errors, error -> JSON.parseObject(error.getValue(), GeaflowError.class))); + private static String getSqlByPkLimitedAndOrderedWithStartTime( + String tableName, String pk, Integer count, Long startTime) { + if (startTime == null) { + return getSqlByPkLimitedAndOrdered(tableName, pk, count); } - - @Override - public PageList queryMetricMeta(GeaflowTask task) { - List metricMetaList = queryMessage(task, getQueryKey(task, QUERY_METRIC_META_KEY), 500, 0L); - return new PageList<>(ListUtil.convert(metricMetaList, metricMeta -> { - GeaflowMetricMeta meta = JSON.parseObject(metricMeta.getValue(), GeaflowMetricMeta.class); - String fullName = meta.getMetricName(); - if (fullName.contains("/")) { + return String.format( + "SELECT pk, value, gmt_modified from %s WHERE pk LIKE '%s%%' AND gmt_modified > '%s' " + + "order by gmt_modified desc limit %s", + tableName, pk, DateTimeUtil.format(new Date(startTime)), count); + } + + private static String deleteSqlByPkLike(String tableName, String pk) { + return String.format("DELETE from %s WHERE pk LIKE '%s%%'", tableName, pk); + } + + @Override + public PageList queryPipelines(GeaflowTask task) { + List pipelines = + queryMessage(task, getQueryKey(task, QUERY_TASK_PIPELINE_KEY), 500, null); + return new PageList<>( + ListUtil.convert( + pipelines, pipeline -> JSON.parseObject(pipeline.getValue(), GeaflowPipeline.class))); + } + + @Override + public PageList queryCycles(GeaflowTask task, String pipelineId) { + List cycles = + queryMessage(task, getQueryKey(task, QUERY_TASK_CYCLE_KEY + pipelineId), 500, null); + return new PageList<>( + ListUtil.convert(cycles, cycle -> JSON.parseObject(cycle.getValue(), GeaflowCycle.class))); + } + + @Override + public PageList queryOffsets(GeaflowTask task) { + String queryKey = getQueryKey(task, QUERY_TASK_OFFSET_KEY); + List offsets = queryMessage(task, queryKey, 500, null); + List list = + ListUtil.convert( + offsets, + e -> { + GeaflowOffset offset = JSON.parseObject(e.getValue(), GeaflowOffset.class); + offset.setPartitionName(e.getPk().substring(queryKey.length())); + offset.formatTime(); + return offset; + }); + return new PageList<>(list); + } + + @Override + public PageList queryErrors(GeaflowTask task) { + List errors = + queryMessage(task, getQueryKey(task, QUERY_TASK_EXCEPTION_KEY), 500, 0L); + return new PageList<>( + ListUtil.convert(errors, error -> JSON.parseObject(error.getValue(), GeaflowError.class))); + } + + @Override + public PageList queryMetricMeta(GeaflowTask task) { + List metricMetaList = + queryMessage(task, getQueryKey(task, QUERY_METRIC_META_KEY), 500, 0L); + return new PageList<>( + ListUtil.convert( + metricMetaList, + metricMeta -> { + GeaflowMetricMeta meta = + JSON.parseObject(metricMeta.getValue(), GeaflowMetricMeta.class); + String fullName = meta.getMetricName(); + if (fullName.contains("/")) { String groupName = StringUtils.substringBefore(fullName, "/"); String metricName = StringUtils.substringAfter(fullName, "/"); meta.setMetricGroup(groupName); meta.setMetricName(metricName); - } - return meta; - })); - } - - @Override - public GeaflowHeartbeatInfo queryHeartbeat(GeaflowTask task) { - List heartbeatList = queryMessage(task, getQueryKey(task, QUERY_HEARTBEAT_KEY), 1, 0L); - if (CollectionUtils.isEmpty(heartbeatList)) { - return null; - } - return JSON.parseObject(heartbeatList.get(0).getValue(), GeaflowHeartbeatInfo.class); - } - - @Override - public void cleanRuntimeMeta(GeaflowTask task) { - RuntimeMetaArgsClass runtimeMetaArgs = new RuntimeMetaArgsClass(task.getRuntimeMetaPluginConfig()); - JdbcTemplate template = buildJdbcTemplate(((JdbcPluginConfigClass) runtimeMetaArgs.getPlugin())); - String tableName = runtimeMetaArgs.getTable(); - String key = getQueryKey(task, QUERY_TASK_ALL_KEY); - String sql = deleteSqlByPkLike(tableName, key); - template.execute(sql); - } - - private String getQueryKey(GeaflowTask task, String separator) { - return TaskParams.getRuntimeTaskName(task.getId()) + separator; - } - - private List queryMessage(GeaflowTask task, String queryKey, int maxCount, Long startTime) { - RuntimeMetaArgsClass runtimeMetaArgs = new RuntimeMetaArgsClass(task.getRuntimeMetaPluginConfig()); - - List result = Lists.newArrayListWithCapacity(maxCount); - String tableName = runtimeMetaArgs.getTable(); - String sql = getSqlByPkLimitedAndOrderedWithStartTime(tableName, queryKey, maxCount, startTime); - try { - JdbcTemplate jdbcTemplate = buildJdbcTemplate((JdbcPluginConfigClass) runtimeMetaArgs.getPlugin()); - if (!checkTableExist(jdbcTemplate, tableName)) { - return null; - } - List> resultList = jdbcTemplate.queryForList(sql); - Iterator> iterator = resultList.iterator(); - int count = 0; - while (iterator.hasNext() && count < maxCount) { - Map rs = iterator.next(); - String pk = rs.get("pk").toString(); - String value = new String((byte[]) rs.get("value"), StandardCharsets.UTF_8); - result.add(new RuntimeMeta(pk, value)); - count++; - } - } catch (Exception e) { - throw new GeaflowException("jdbc query error", e.getMessage(), e); - } - return result; + } + return meta; + })); + } + + @Override + public GeaflowHeartbeatInfo queryHeartbeat(GeaflowTask task) { + List heartbeatList = + queryMessage(task, getQueryKey(task, QUERY_HEARTBEAT_KEY), 1, 0L); + if (CollectionUtils.isEmpty(heartbeatList)) { + return null; } - - private boolean checkTableExist(JdbcTemplate jdbcTemplate, String tableName) { - String sql = String.format("show tables like '%s'", tableName); - List> tables = jdbcTemplate.queryForList(sql); - return !tables.isEmpty(); + return JSON.parseObject(heartbeatList.get(0).getValue(), GeaflowHeartbeatInfo.class); + } + + @Override + public void cleanRuntimeMeta(GeaflowTask task) { + RuntimeMetaArgsClass runtimeMetaArgs = + new RuntimeMetaArgsClass(task.getRuntimeMetaPluginConfig()); + JdbcTemplate template = + buildJdbcTemplate(((JdbcPluginConfigClass) runtimeMetaArgs.getPlugin())); + String tableName = runtimeMetaArgs.getTable(); + String key = getQueryKey(task, QUERY_TASK_ALL_KEY); + String sql = deleteSqlByPkLike(tableName, key); + template.execute(sql); + } + + private String getQueryKey(GeaflowTask task, String separator) { + return TaskParams.getRuntimeTaskName(task.getId()) + separator; + } + + private List queryMessage( + GeaflowTask task, String queryKey, int maxCount, Long startTime) { + RuntimeMetaArgsClass runtimeMetaArgs = + new RuntimeMetaArgsClass(task.getRuntimeMetaPluginConfig()); + + List result = Lists.newArrayListWithCapacity(maxCount); + String tableName = runtimeMetaArgs.getTable(); + String sql = getSqlByPkLimitedAndOrderedWithStartTime(tableName, queryKey, maxCount, startTime); + try { + JdbcTemplate jdbcTemplate = + buildJdbcTemplate((JdbcPluginConfigClass) runtimeMetaArgs.getPlugin()); + if (!checkTableExist(jdbcTemplate, tableName)) { + return null; + } + List> resultList = jdbcTemplate.queryForList(sql); + Iterator> iterator = resultList.iterator(); + int count = 0; + while (iterator.hasNext() && count < maxCount) { + Map rs = iterator.next(); + String pk = rs.get("pk").toString(); + String value = new String((byte[]) rs.get("value"), StandardCharsets.UTF_8); + result.add(new RuntimeMeta(pk, value)); + count++; + } + } catch (Exception e) { + throw new GeaflowException("jdbc query error", e.getMessage(), e); } - - private JdbcTemplate buildJdbcTemplate(JdbcPluginConfigClass jdbcPluginConfig) { - String confKey = JSON.toJSONString(jdbcPluginConfig.build()); - JdbcTemplate jdbcTemplate = jdbcTemplateCache.get(confKey); + return result; + } + + private boolean checkTableExist(JdbcTemplate jdbcTemplate, String tableName) { + String sql = String.format("show tables like '%s'", tableName); + List> tables = jdbcTemplate.queryForList(sql); + return !tables.isEmpty(); + } + + private JdbcTemplate buildJdbcTemplate(JdbcPluginConfigClass jdbcPluginConfig) { + String confKey = JSON.toJSONString(jdbcPluginConfig.build()); + JdbcTemplate jdbcTemplate = jdbcTemplateCache.get(confKey); + if (jdbcTemplate != null) { + return jdbcTemplate; + } else { + synchronized (lock) { + jdbcTemplate = jdbcTemplateCache.get(confKey); if (jdbcTemplate != null) { - return jdbcTemplate; + return jdbcTemplate; } else { - synchronized (lock) { - jdbcTemplate = jdbcTemplateCache.get(confKey); - if (jdbcTemplate != null) { - return jdbcTemplate; - } else { - jdbcTemplate = new JdbcTemplate(); - DataSource dataSource = getDriverManagerDataSource(jdbcPluginConfig); - - jdbcTemplate.setDataSource(dataSource); - jdbcTemplateCache.put(confKey, jdbcTemplate); - return jdbcTemplate; - } - } - } - } - - private DriverManagerDataSource getDriverManagerDataSource(JdbcPluginConfigClass jdbcPluginConfig) { - DriverManagerDataSource dataSource = new DriverManagerDataSource(); - dataSource.setDriverClassName(Optional.ofNullable(jdbcPluginConfig.getDriverClass()).orElse(MYSQL_DRIVER_CLASS)); - dataSource.setUrl(jdbcPluginConfig.getUrl()); - dataSource.setUsername(jdbcPluginConfig.getUsername()); - dataSource.setPassword(jdbcPluginConfig.getPassword()); - return dataSource; - } - - @Setter - @Getter - @AllArgsConstructor - private static class RuntimeMeta { - - private String pk; - - private String value; + jdbcTemplate = new JdbcTemplate(); + DataSource dataSource = getDriverManagerDataSource(jdbcPluginConfig); + jdbcTemplate.setDataSource(dataSource); + jdbcTemplateCache.put(confKey, jdbcTemplate); + return jdbcTemplate; + } + } } + } + + private DriverManagerDataSource getDriverManagerDataSource( + JdbcPluginConfigClass jdbcPluginConfig) { + DriverManagerDataSource dataSource = new DriverManagerDataSource(); + dataSource.setDriverClassName( + Optional.ofNullable(jdbcPluginConfig.getDriverClass()).orElse(MYSQL_DRIVER_CLASS)); + dataSource.setUrl(jdbcPluginConfig.getUrl()); + dataSource.setUsername(jdbcPluginConfig.getUsername()); + dataSource.setPassword(jdbcPluginConfig.getPassword()); + return dataSource; + } + + @Setter + @Getter + @AllArgsConstructor + private static class RuntimeMeta { + + private String pk; + + private String value; + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/impl/PersistentDataStore.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/impl/PersistentDataStore.java index e3ff7b84e..5d37d5a93 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/impl/PersistentDataStore.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/impl/PersistentDataStore.java @@ -21,7 +21,7 @@ import java.util.Date; import java.util.Map; -import lombok.extern.slf4j.Slf4j; + import org.apache.geaflow.console.common.service.integration.engine.Configuration; import org.apache.geaflow.console.common.service.integration.engine.FsPath; import org.apache.geaflow.console.common.service.integration.engine.IPersistentIO; @@ -44,87 +44,86 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; +import lombok.extern.slf4j.Slf4j; + @Component @Slf4j public class PersistentDataStore implements GeaflowDataStore { - @Autowired - private VersionFactory versionFactory; - - @Autowired - private PluginConfigService pluginConfigService; - - @Autowired - private PluginService pluginService; - - @Autowired - private VersionService versionService; - - @Autowired - private InstanceService instanceService; - - @Override - public Long queryStorageUsage(GeaflowTask task) { - return null; - } - - @Override - public Long queryFileCount(GeaflowTask task) { - IPersistentIO persistentIO = buildPersistentIO(task.getDataPluginConfig(), task.getRelease().getVersion()); - FsPath path = getTaskPath(task); - return persistentIO.getFileCount(path); - } - - @Override - public Date queryModifyTime(GeaflowTask task) { - return null; - } - - @Override - public void cleanTaskData(GeaflowTask task) { - IPersistentIO persistentIO = buildPersistentIO(task.getDataPluginConfig(), task.getRelease().getVersion()); - FsPath path = getTaskPath(task); - persistentIO.delete(path, true); - } - - @Override - public void cleanGraphData(GeaflowGraph graph) { - // use default config - GeaflowPluginCategory category = GeaflowPluginCategory.DATA; - String dataType = pluginService.getDefaultPlugin(category).getType(); - GeaflowPluginConfig dataConfig = pluginConfigService.getDefaultPluginConfig(category, dataType); - - GeaflowVersion version = versionService.getDefaultVersion(); - IPersistentIO persistentIO = buildPersistentIO(dataConfig, version); - - PersistentArgsClass persistentArgs = new PersistentArgsClass(dataConfig); - String root = persistentArgs.getRoot(); - VersionClassLoader classLoader = versionFactory.getClassLoader(version); - - GeaflowInstance instance = instanceService.get(graph.getInstanceId()); - String pathSuffix = instance.getName() + "_" + graph.getName(); - - FsPath path = classLoader.newInstance(FsPath.class, root, pathSuffix); - persistentIO.delete(path, true); - log.info("clean graph data {},{}", root, pathSuffix); - } - - protected IPersistentIO buildPersistentIO(GeaflowPluginConfig pluginConfig, GeaflowVersion version) { - PersistentArgsClass persistentArgs = new PersistentArgsClass(pluginConfig); - Map config = persistentArgs.build().toStringMap(); - - VersionClassLoader classLoader = versionFactory.getClassLoader(version); - Configuration configuration = classLoader.newInstance(Configuration.class, config); - PersistentIOBuilder builder = classLoader.newInstance(PersistentIOBuilder.class); - return builder.build(configuration); - } - - protected FsPath getTaskPath(GeaflowTask task) { - PersistentArgsClass persistentArgs = new PersistentArgsClass(task.getDataPluginConfig()); - String root = persistentArgs.getRoot(); - String pathSuffix = TaskParams.getRuntimeTaskName(task.getId()); - VersionClassLoader classLoader = versionFactory.getClassLoader(task.getRelease().getVersion()); - return classLoader.newInstance(FsPath.class, root, pathSuffix); - } - + @Autowired private VersionFactory versionFactory; + + @Autowired private PluginConfigService pluginConfigService; + + @Autowired private PluginService pluginService; + + @Autowired private VersionService versionService; + + @Autowired private InstanceService instanceService; + + @Override + public Long queryStorageUsage(GeaflowTask task) { + return null; + } + + @Override + public Long queryFileCount(GeaflowTask task) { + IPersistentIO persistentIO = + buildPersistentIO(task.getDataPluginConfig(), task.getRelease().getVersion()); + FsPath path = getTaskPath(task); + return persistentIO.getFileCount(path); + } + + @Override + public Date queryModifyTime(GeaflowTask task) { + return null; + } + + @Override + public void cleanTaskData(GeaflowTask task) { + IPersistentIO persistentIO = + buildPersistentIO(task.getDataPluginConfig(), task.getRelease().getVersion()); + FsPath path = getTaskPath(task); + persistentIO.delete(path, true); + } + + @Override + public void cleanGraphData(GeaflowGraph graph) { + // use default config + GeaflowPluginCategory category = GeaflowPluginCategory.DATA; + String dataType = pluginService.getDefaultPlugin(category).getType(); + GeaflowPluginConfig dataConfig = pluginConfigService.getDefaultPluginConfig(category, dataType); + + GeaflowVersion version = versionService.getDefaultVersion(); + IPersistentIO persistentIO = buildPersistentIO(dataConfig, version); + + PersistentArgsClass persistentArgs = new PersistentArgsClass(dataConfig); + String root = persistentArgs.getRoot(); + VersionClassLoader classLoader = versionFactory.getClassLoader(version); + + GeaflowInstance instance = instanceService.get(graph.getInstanceId()); + String pathSuffix = instance.getName() + "_" + graph.getName(); + + FsPath path = classLoader.newInstance(FsPath.class, root, pathSuffix); + persistentIO.delete(path, true); + log.info("clean graph data {},{}", root, pathSuffix); + } + + protected IPersistentIO buildPersistentIO( + GeaflowPluginConfig pluginConfig, GeaflowVersion version) { + PersistentArgsClass persistentArgs = new PersistentArgsClass(pluginConfig); + Map config = persistentArgs.build().toStringMap(); + + VersionClassLoader classLoader = versionFactory.getClassLoader(version); + Configuration configuration = classLoader.newInstance(Configuration.class, config); + PersistentIOBuilder builder = classLoader.newInstance(PersistentIOBuilder.class); + return builder.build(configuration); + } + + protected FsPath getTaskPath(GeaflowTask task) { + PersistentArgsClass persistentArgs = new PersistentArgsClass(task.getDataPluginConfig()); + String root = persistentArgs.getRoot(); + String pathSuffix = TaskParams.getRuntimeTaskName(task.getId()); + VersionClassLoader classLoader = versionFactory.getClassLoader(task.getRelease().getVersion()); + return classLoader.newInstance(FsPath.class, root, pathSuffix); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/impl/RedisStore.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/impl/RedisStore.java index 2bce478e6..548aae836 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/impl/RedisStore.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/store/impl/RedisStore.java @@ -19,15 +19,12 @@ package org.apache.geaflow.console.core.service.store.impl; -import com.google.common.cache.Cache; -import com.google.common.cache.CacheBuilder; -import com.google.common.cache.RemovalListener; import java.util.HashSet; import java.util.List; import java.util.Optional; import java.util.Set; import java.util.concurrent.CompletableFuture; -import lombok.extern.slf4j.Slf4j; + import org.apache.commons.collections.CollectionUtils; import org.apache.commons.pool2.impl.GenericObjectPoolConfig; import org.apache.geaflow.console.common.util.ThreadUtil; @@ -36,6 +33,12 @@ import org.apache.geaflow.console.core.model.task.GeaflowTask; import org.apache.geaflow.console.core.service.store.GeaflowHaMetaStore; import org.springframework.stereotype.Component; + +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.RemovalListener; + +import lombok.extern.slf4j.Slf4j; import redis.clients.jedis.Jedis; import redis.clients.jedis.JedisPool; import redis.clients.jedis.ScanParams; @@ -45,103 +48,114 @@ @Component public class RedisStore implements GeaflowHaMetaStore { - private static final Integer MAX_SCAN_COUNT = Integer.MAX_VALUE; + private static final Integer MAX_SCAN_COUNT = Integer.MAX_VALUE; - private static final int DEFAULT_RETRY_INTERVAL_MS = 500; + private static final int DEFAULT_RETRY_INTERVAL_MS = 500; - private static final int DEFAULT_RETRY_TIMES = 10; + private static final int DEFAULT_RETRY_TIMES = 10; - private static final int DEFAULT_CONNECTION_TIMEOUT_MS = 5000; + private static final int DEFAULT_CONNECTION_TIMEOUT_MS = 5000; - private static final Cache JEDIS_POOL_CACHE = - CacheBuilder.newBuilder().initialCapacity(10).maximumSize(100) - .removalListener( - (RemovalListener) removalNotification -> { + private static final Cache JEDIS_POOL_CACHE = + CacheBuilder.newBuilder() + .initialCapacity(10) + .maximumSize(100) + .removalListener( + (RemovalListener) + removalNotification -> { JedisPool jedisPool = removalNotification.getValue(); // async clean the resources - CompletableFuture.runAsync(() -> { - Jedis jedis = jedisPool.getResource(); - log.info("Jedis pool closed: {}:{}", jedis.getClient().getHost(), - jedis.getClient().getPort()); - jedisPool.close(); - }); - }) - .build(); - - @Override - public void cleanHaMeta(GeaflowTask task) { - RedisPluginConfigClass redisConfig = (RedisPluginConfigClass) new HaMetaArgsClass( - task.getHaMetaPluginConfig()).getPlugin(); - - String keyPattern = String.format("*%s*", task.getId()); - deleteByKeyPattern(redisConfig, keyPattern); - } + CompletableFuture.runAsync( + () -> { + Jedis jedis = jedisPool.getResource(); + log.info( + "Jedis pool closed: {}:{}", + jedis.getClient().getHost(), + jedis.getClient().getPort()); + jedisPool.close(); + }); + }) + .build(); - private void deleteByKey(RedisPluginConfigClass config, String key) { - Jedis jedis = getJedis(config); - jedis.del(key); - } + @Override + public void cleanHaMeta(GeaflowTask task) { + RedisPluginConfigClass redisConfig = + (RedisPluginConfigClass) new HaMetaArgsClass(task.getHaMetaPluginConfig()).getPlugin(); - private void deleteByKeyPattern(RedisPluginConfigClass config, String keyPattern) { - int retryTimes = Optional.ofNullable(config.getRetryTimes()).orElse(DEFAULT_RETRY_TIMES); - int retryIntervalMs = Optional.ofNullable(config.getRetryIntervalMs()).orElse(DEFAULT_RETRY_INTERVAL_MS); - Set keys = deleteByKeyPatternWithRetry(config, keyPattern, retryTimes, - retryIntervalMs); - log.info("Successfully deleted redis data with key pattern: {}. Redis host: {}:{}. " - + "Deleted keys: {}.", - keyPattern, config.getHost(), config.getHost(), keys); - } + String keyPattern = String.format("*%s*", task.getId()); + deleteByKeyPattern(redisConfig, keyPattern); + } + + private void deleteByKey(RedisPluginConfigClass config, String key) { + Jedis jedis = getJedis(config); + jedis.del(key); + } + + private void deleteByKeyPattern(RedisPluginConfigClass config, String keyPattern) { + int retryTimes = Optional.ofNullable(config.getRetryTimes()).orElse(DEFAULT_RETRY_TIMES); + int retryIntervalMs = + Optional.ofNullable(config.getRetryIntervalMs()).orElse(DEFAULT_RETRY_INTERVAL_MS); + Set keys = deleteByKeyPatternWithRetry(config, keyPattern, retryTimes, retryIntervalMs); + log.info( + "Successfully deleted redis data with key pattern: {}. Redis host: {}:{}. " + + "Deleted keys: {}.", + keyPattern, + config.getHost(), + config.getHost(), + keys); + } - private Set deleteByKeyPatternWithRetry(RedisPluginConfigClass config, - String keyPattern, int retry, - int retryIntervalMs) { - try (Jedis jedis = getJedis(config)) { - Set keySets = getScanKeySets(jedis, keyPattern, MAX_SCAN_COUNT); - keySets.forEach(jedis::del); - return keySets; - } catch (Exception e) { - if (retry <= 0) { - throw e; - } - ThreadUtil.sleepMilliSeconds(retryIntervalMs); - return deleteByKeyPatternWithRetry(config, keyPattern, retry - 1, retryIntervalMs); - } + private Set deleteByKeyPatternWithRetry( + RedisPluginConfigClass config, String keyPattern, int retry, int retryIntervalMs) { + try (Jedis jedis = getJedis(config)) { + Set keySets = getScanKeySets(jedis, keyPattern, MAX_SCAN_COUNT); + keySets.forEach(jedis::del); + return keySets; + } catch (Exception e) { + if (retry <= 0) { + throw e; + } + ThreadUtil.sleepMilliSeconds(retryIntervalMs); + return deleteByKeyPatternWithRetry(config, keyPattern, retry - 1, retryIntervalMs); } + } - private Set getScanKeySets(Jedis jedis, String keyPattern, Integer count) { - Set keySets = new HashSet<>(); - ScanParams scanParams = new ScanParams(); - scanParams.match(keyPattern); - scanParams.count(count); - // the start cursor - String cursor = ScanParams.SCAN_POINTER_START; - ScanResult scanResult; - while (true) { - scanResult = jedis.scan(cursor, scanParams); - List result = scanResult.getResult(); - if (CollectionUtils.isNotEmpty(result)) { - keySets.addAll(result); - } - if ("0".equals(cursor)) { - break; - } - } - return keySets; + private Set getScanKeySets(Jedis jedis, String keyPattern, Integer count) { + Set keySets = new HashSet<>(); + ScanParams scanParams = new ScanParams(); + scanParams.match(keyPattern); + scanParams.count(count); + // the start cursor + String cursor = ScanParams.SCAN_POINTER_START; + ScanResult scanResult; + while (true) { + scanResult = jedis.scan(cursor, scanParams); + List result = scanResult.getResult(); + if (CollectionUtils.isNotEmpty(result)) { + keySets.addAll(result); + } + if ("0".equals(cursor)) { + break; + } } + return keySets; + } - private Jedis getJedis(RedisPluginConfigClass config) { - String cacheKey = config.toString(); - JedisPool jedisPool = JEDIS_POOL_CACHE.getIfPresent(cacheKey); - if (jedisPool == null) { - String host = config.getHost(); - int port = config.getPort(); - int timeout = Optional.ofNullable(config.getConnectionTimeoutMs()).orElse(DEFAULT_CONNECTION_TIMEOUT_MS); - log.info("Jedis pool created: {}", config); - GenericObjectPoolConfig poolConfig = new GenericObjectPoolConfig(); - jedisPool = new JedisPool(poolConfig, host, port, timeout, config.getUser(), - config.getPassword()); - JEDIS_POOL_CACHE.put(cacheKey, jedisPool); - } - return jedisPool.getResource(); + private Jedis getJedis(RedisPluginConfigClass config) { + String cacheKey = config.toString(); + JedisPool jedisPool = JEDIS_POOL_CACHE.getIfPresent(cacheKey); + if (jedisPool == null) { + String host = config.getHost(); + int port = config.getPort(); + int timeout = + Optional.ofNullable(config.getConnectionTimeoutMs()) + .orElse(DEFAULT_CONNECTION_TIMEOUT_MS); + log.info("Jedis pool created: {}", config); + GenericObjectPoolConfig poolConfig = new GenericObjectPoolConfig(); + jedisPool = + new JedisPool(poolConfig, host, port, timeout, config.getUser(), config.getPassword()); + JEDIS_POOL_CACHE.put(cacheKey, jedisPool); } + return jedisPool.getResource(); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/task/GeaflowTaskOperator.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/task/GeaflowTaskOperator.java index 6ada7dc64..5595f9942 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/task/GeaflowTaskOperator.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/task/GeaflowTaskOperator.java @@ -23,10 +23,9 @@ import static org.apache.geaflow.console.common.util.type.GeaflowTaskStatus.RUNNING; import static org.apache.geaflow.console.common.util.type.GeaflowTaskStatus.STARTING; -import com.alibaba.fastjson.JSON; import java.util.Date; import java.util.Optional; -import lombok.extern.slf4j.Slf4j; + import org.apache.geaflow.console.common.dal.model.PageList; import org.apache.geaflow.console.common.util.DateTimeUtil; import org.apache.geaflow.console.common.util.Fmt; @@ -63,171 +62,176 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; +import com.alibaba.fastjson.JSON; + +import lombok.extern.slf4j.Slf4j; + @Service @Slf4j public class GeaflowTaskOperator { - private static final int TASK_STARTUP_TIMEOUT = 8 * 60; - - @Autowired - private TaskService taskService; - - @Autowired - private TokenGenerator tokenGenerator; - - @Autowired - private RuntimeFactory runtimeFactory; - - @Autowired - private RuntimeMetaStoreFactory runtimeMetaStoreFactory; - - @Autowired - private MetricStoreFactory metricStoreFactory; - - @Autowired - private HaMetaStoreFactory haMetaStoreFactory; - - @Autowired - private DataStoreFactory dataStoreFactory; - - @Autowired - private AuditService auditService; + private static final int TASK_STARTUP_TIMEOUT = 8 * 60; - public boolean start(GeaflowTask task) { - GeaflowRuntime runtime = runtimeFactory.getRuntime(task); + @Autowired private TaskService taskService; - // generate task token and save before start - task.setToken(tokenGenerator.nextTaskToken()); - taskService.update(task); + @Autowired private TokenGenerator tokenGenerator; - // submit job to the engine - try { - GeaflowTaskHandle handle = runtime.start(task); - task.setHandle(handle); - task.setStartTime(new Date()); - taskService.update(task); - taskService.updateStatus(task.getId(), STARTING, RUNNING); - log.info("Submit task {} successfully", task.getId()); - return true; + @Autowired private RuntimeFactory runtimeFactory; - } catch (Exception e) { - log.error("Submit task {} failed", task.getId(), e); - taskService.updateStatus(task.getId(), STARTING, FAILED); - throw e; - } - } + @Autowired private RuntimeMetaStoreFactory runtimeMetaStoreFactory; - public void stop(GeaflowTask task) { - runtimeFactory.getRuntime(task).stop(task); - log.info("Stop task {} success, handle={}", task.getId(), JSON.toJSONString(task.getHandle())); - } + @Autowired private MetricStoreFactory metricStoreFactory; - public GeaflowTaskStatus refreshStatus(GeaflowTask task) { - GeaflowTaskStatus oldStatus = task.getStatus(); + @Autowired private HaMetaStoreFactory haMetaStoreFactory; - // only refresh running task - if (!RUNNING.equals(oldStatus)) { - return oldStatus; - } + @Autowired private DataStoreFactory dataStoreFactory; - GeaflowTaskStatus newStatus = runtimeFactory.getRuntime(task).queryStatus(task); - if (newStatus == FAILED && task.getRelease().getCluster().getType() == GeaflowPluginType.K8S) { - // task has not been started completely - if (!Optional.ofNullable(((K8sTaskHandle) task.getHandle())).map(K8sTaskHandle::getStartupNotifyInfo) - .isPresent()) { - if (DateTimeUtil.isExpired(task.getStartTime(), TASK_STARTUP_TIMEOUT)) { - // release task resource - this.stop(task); - - // waiting startup timeout - String detail = Fmt.as("Waiting task startup timeout after {}s", TASK_STARTUP_TIMEOUT); - log.info(detail); - auditService.create(new GeaflowAudit(task.getId(), GeaflowOperationType.STOP, detail)); - - } else { - // waiting startup, keep status not changed - newStatus = oldStatus; - } - } - } + @Autowired private AuditService auditService; - taskService.updateStatus(task.getId(), oldStatus, newStatus); - return newStatus; - } + public boolean start(GeaflowTask task) { + GeaflowRuntime runtime = runtimeFactory.getRuntime(task); - public boolean cleanMeta(GeaflowTask task) { - GeaflowPluginConfig runtimeMetaConfig = task.getRuntimeMetaPluginConfig(); - GeaflowPluginConfig haMetaConfig = task.getHaMetaPluginConfig(); - GeaflowRuntimeMetaStore runtimeMetaStore = runtimeMetaStoreFactory.getRuntimeMetaStore(runtimeMetaConfig); - GeaflowHaMetaStore haMetaStore = haMetaStoreFactory.getHaMetaStore(haMetaConfig); - runtimeMetaStore.cleanRuntimeMeta(task); - haMetaStore.cleanHaMeta(task); - return true; - } + // generate task token and save before start + task.setToken(tokenGenerator.nextTaskToken()); + taskService.update(task); - public boolean cleanData(GeaflowTask task) { - GeaflowPluginConfig dataConfig = task.getDataPluginConfig(); - GeaflowDataStore dataStore = dataStoreFactory.getDataStore(dataConfig.getType()); - dataStore.cleanTaskData(task); - return true; - } + // submit job to the engine + try { + GeaflowTaskHandle handle = runtime.start(task); + task.setHandle(handle); + task.setStartTime(new Date()); + taskService.update(task); + taskService.updateStatus(task.getId(), STARTING, RUNNING); + log.info("Submit task {} successfully", task.getId()); + return true; - public PageList queryPipelines(GeaflowTask task) { - GeaflowPluginConfig runtimeMetaConfig = task.getRuntimeMetaPluginConfig(); - GeaflowRuntimeMetaStore runtimeMetaStore = runtimeMetaStoreFactory.getRuntimeMetaStore(runtimeMetaConfig); - return runtimeMetaStore.queryPipelines(task); + } catch (Exception e) { + log.error("Submit task {} failed", task.getId(), e); + taskService.updateStatus(task.getId(), STARTING, FAILED); + throw e; } + } - public PageList queryCycles(GeaflowTask task, String pipelineId) { - GeaflowPluginConfig runtimeMetaConfig = task.getRuntimeMetaPluginConfig(); - GeaflowRuntimeMetaStore runtimeMetaStore = runtimeMetaStoreFactory.getRuntimeMetaStore(runtimeMetaConfig); - return runtimeMetaStore.queryCycles(task, pipelineId); - } - - public PageList queryMetricMeta(GeaflowTask task) { - GeaflowPluginConfig runtimeMetaConfig = task.getRuntimeMetaPluginConfig(); - GeaflowRuntimeMetaStore runtimeMetaStore = runtimeMetaStoreFactory.getRuntimeMetaStore(runtimeMetaConfig); - return runtimeMetaStore.queryMetricMeta(task); - } + public void stop(GeaflowTask task) { + runtimeFactory.getRuntime(task).stop(task); + log.info("Stop task {} success, handle={}", task.getId(), JSON.toJSONString(task.getHandle())); + } - public PageList queryMetrics(GeaflowTask task, GeaflowMetricQueryRequest queryRequest) { - GeaflowPluginConfig metricConfig = task.getMetricPluginConfig(); - GeaflowMetricStore metaStore = metricStoreFactory.getMetricStore(metricConfig); - return metaStore.queryMetrics(task, queryRequest); - } + public GeaflowTaskStatus refreshStatus(GeaflowTask task) { + GeaflowTaskStatus oldStatus = task.getStatus(); - public PageList queryOffsets(GeaflowTask task) { - GeaflowPluginConfig runtimeMetaConfig = task.getRuntimeMetaPluginConfig(); - GeaflowRuntimeMetaStore runtimeMetaStore = runtimeMetaStoreFactory.getRuntimeMetaStore(runtimeMetaConfig); - return runtimeMetaStore.queryOffsets(task); + // only refresh running task + if (!RUNNING.equals(oldStatus)) { + return oldStatus; } - public PageList queryErrors(GeaflowTask task) { - GeaflowPluginConfig runtimeMetaConfig = task.getRuntimeMetaPluginConfig(); - GeaflowRuntimeMetaStore runtimeMetaStore = runtimeMetaStoreFactory.getRuntimeMetaStore(runtimeMetaConfig); - return runtimeMetaStore.queryErrors(task); - } - - public GeaflowHeartbeatInfo queryHeartbeat(GeaflowTask task) { - GeaflowPluginConfig runtimeMetaConfig = task.getRuntimeMetaPluginConfig(); - GeaflowRuntimeMetaStore runtimeMetaStore = runtimeMetaStoreFactory.getRuntimeMetaStore(runtimeMetaConfig); - GeaflowHeartbeatInfo heartbeatInfo = runtimeMetaStore.queryHeartbeat(task); - setupHeartbeatInfo(heartbeatInfo); - return heartbeatInfo; + GeaflowTaskStatus newStatus = runtimeFactory.getRuntime(task).queryStatus(task); + if (newStatus == FAILED && task.getRelease().getCluster().getType() == GeaflowPluginType.K8S) { + // task has not been started completely + if (!Optional.ofNullable(((K8sTaskHandle) task.getHandle())) + .map(K8sTaskHandle::getStartupNotifyInfo) + .isPresent()) { + if (DateTimeUtil.isExpired(task.getStartTime(), TASK_STARTUP_TIMEOUT)) { + // release task resource + this.stop(task); + + // waiting startup timeout + String detail = Fmt.as("Waiting task startup timeout after {}s", TASK_STARTUP_TIMEOUT); + log.info(detail); + auditService.create(new GeaflowAudit(task.getId(), GeaflowOperationType.STOP, detail)); + + } else { + // waiting startup, keep status not changed + newStatus = oldStatus; + } + } } - private void setupHeartbeatInfo(GeaflowHeartbeatInfo heartbeatInfo) { - if (heartbeatInfo != null) { - int activeNum = 0; - long now = System.currentTimeMillis(); - long expiredTime = now - heartbeatInfo.getExpiredTimeMs(); - for (ContainerInfo container : heartbeatInfo.getContainers()) { - if (container.getLastTimestamp() != null && container.getLastTimestamp() > expiredTime) { - container.setActive(true); - activeNum++; - } - } - heartbeatInfo.setActiveNum(activeNum); + taskService.updateStatus(task.getId(), oldStatus, newStatus); + return newStatus; + } + + public boolean cleanMeta(GeaflowTask task) { + GeaflowPluginConfig runtimeMetaConfig = task.getRuntimeMetaPluginConfig(); + GeaflowPluginConfig haMetaConfig = task.getHaMetaPluginConfig(); + GeaflowRuntimeMetaStore runtimeMetaStore = + runtimeMetaStoreFactory.getRuntimeMetaStore(runtimeMetaConfig); + GeaflowHaMetaStore haMetaStore = haMetaStoreFactory.getHaMetaStore(haMetaConfig); + runtimeMetaStore.cleanRuntimeMeta(task); + haMetaStore.cleanHaMeta(task); + return true; + } + + public boolean cleanData(GeaflowTask task) { + GeaflowPluginConfig dataConfig = task.getDataPluginConfig(); + GeaflowDataStore dataStore = dataStoreFactory.getDataStore(dataConfig.getType()); + dataStore.cleanTaskData(task); + return true; + } + + public PageList queryPipelines(GeaflowTask task) { + GeaflowPluginConfig runtimeMetaConfig = task.getRuntimeMetaPluginConfig(); + GeaflowRuntimeMetaStore runtimeMetaStore = + runtimeMetaStoreFactory.getRuntimeMetaStore(runtimeMetaConfig); + return runtimeMetaStore.queryPipelines(task); + } + + public PageList queryCycles(GeaflowTask task, String pipelineId) { + GeaflowPluginConfig runtimeMetaConfig = task.getRuntimeMetaPluginConfig(); + GeaflowRuntimeMetaStore runtimeMetaStore = + runtimeMetaStoreFactory.getRuntimeMetaStore(runtimeMetaConfig); + return runtimeMetaStore.queryCycles(task, pipelineId); + } + + public PageList queryMetricMeta(GeaflowTask task) { + GeaflowPluginConfig runtimeMetaConfig = task.getRuntimeMetaPluginConfig(); + GeaflowRuntimeMetaStore runtimeMetaStore = + runtimeMetaStoreFactory.getRuntimeMetaStore(runtimeMetaConfig); + return runtimeMetaStore.queryMetricMeta(task); + } + + public PageList queryMetrics( + GeaflowTask task, GeaflowMetricQueryRequest queryRequest) { + GeaflowPluginConfig metricConfig = task.getMetricPluginConfig(); + GeaflowMetricStore metaStore = metricStoreFactory.getMetricStore(metricConfig); + return metaStore.queryMetrics(task, queryRequest); + } + + public PageList queryOffsets(GeaflowTask task) { + GeaflowPluginConfig runtimeMetaConfig = task.getRuntimeMetaPluginConfig(); + GeaflowRuntimeMetaStore runtimeMetaStore = + runtimeMetaStoreFactory.getRuntimeMetaStore(runtimeMetaConfig); + return runtimeMetaStore.queryOffsets(task); + } + + public PageList queryErrors(GeaflowTask task) { + GeaflowPluginConfig runtimeMetaConfig = task.getRuntimeMetaPluginConfig(); + GeaflowRuntimeMetaStore runtimeMetaStore = + runtimeMetaStoreFactory.getRuntimeMetaStore(runtimeMetaConfig); + return runtimeMetaStore.queryErrors(task); + } + + public GeaflowHeartbeatInfo queryHeartbeat(GeaflowTask task) { + GeaflowPluginConfig runtimeMetaConfig = task.getRuntimeMetaPluginConfig(); + GeaflowRuntimeMetaStore runtimeMetaStore = + runtimeMetaStoreFactory.getRuntimeMetaStore(runtimeMetaConfig); + GeaflowHeartbeatInfo heartbeatInfo = runtimeMetaStore.queryHeartbeat(task); + setupHeartbeatInfo(heartbeatInfo); + return heartbeatInfo; + } + + private void setupHeartbeatInfo(GeaflowHeartbeatInfo heartbeatInfo) { + if (heartbeatInfo != null) { + int activeNum = 0; + long now = System.currentTimeMillis(); + long expiredTime = now - heartbeatInfo.getExpiredTimeMs(); + for (ContainerInfo container : heartbeatInfo.getContainers()) { + if (container.getLastTimestamp() != null && container.getLastTimestamp() > expiredTime) { + container.setActive(true); + activeNum++; } + } + heartbeatInfo.setActiveNum(activeNum); } + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/task/GeaflowTaskStatusRefresher.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/task/GeaflowTaskStatusRefresher.java index aadb80b12..af790b00b 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/task/GeaflowTaskStatusRefresher.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/task/GeaflowTaskStatusRefresher.java @@ -25,7 +25,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; -import lombok.extern.slf4j.Slf4j; + import org.apache.commons.collections.CollectionUtils; import org.apache.geaflow.console.common.util.context.ContextHolder; import org.apache.geaflow.console.common.util.type.GeaflowTaskStatus; @@ -36,59 +36,61 @@ import org.springframework.scheduling.annotation.Scheduled; import org.springframework.stereotype.Service; +import lombok.extern.slf4j.Slf4j; + @Service @Slf4j public class GeaflowTaskStatusRefresher { - private static final ExecutorService EXECUTOR_SERVICE; + private static final ExecutorService EXECUTOR_SERVICE; - static { - EXECUTOR_SERVICE = new ThreadPoolExecutor(50, 500, 30, TimeUnit.SECONDS, new LinkedBlockingQueue<>(1000)); - } + static { + EXECUTOR_SERVICE = + new ThreadPoolExecutor(50, 500, 30, TimeUnit.SECONDS, new LinkedBlockingQueue<>(1000)); + } - @Autowired - private TaskService taskService; - - @Autowired - private GeaflowTaskOperator taskOperator; - - @Scheduled(cron = "30 * * * * ?") - void refresh() { - try { - ContextHolder.init(); - List taskIds = this.getRunningTasks(); - if (CollectionUtils.isEmpty(taskIds)) { - return; - } - log.info("Task status refresh start, task size: {}", taskIds.size()); - - for (GeaflowId taskId : taskIds) { - CompletableFuture.runAsync(() -> { - try { - // set tenant and user by task - ContextHolder.init(); - ContextHolder.get().setUserId(taskId.getModifierId()); - ContextHolder.get().setTenantId(taskId.getTenantId()); - - GeaflowTask task = taskService.get(taskId.getId()); - taskOperator.refreshStatus(task); - - } catch (Exception e) { - log.error("Task {} status refresh error: {}", taskId.getId(), e.getMessage(), e); - - } finally { - ContextHolder.destroy(); - } - - }, EXECUTOR_SERVICE); - } - - } finally { - ContextHolder.destroy(); - } - } + @Autowired private TaskService taskService; + + @Autowired private GeaflowTaskOperator taskOperator; - private List getRunningTasks() { - return taskService.getTasksByStatus(GeaflowTaskStatus.RUNNING); + @Scheduled(cron = "30 * * * * ?") + void refresh() { + try { + ContextHolder.init(); + List taskIds = this.getRunningTasks(); + if (CollectionUtils.isEmpty(taskIds)) { + return; + } + log.info("Task status refresh start, task size: {}", taskIds.size()); + + for (GeaflowId taskId : taskIds) { + CompletableFuture.runAsync( + () -> { + try { + // set tenant and user by task + ContextHolder.init(); + ContextHolder.get().setUserId(taskId.getModifierId()); + ContextHolder.get().setTenantId(taskId.getTenantId()); + + GeaflowTask task = taskService.get(taskId.getId()); + taskOperator.refreshStatus(task); + + } catch (Exception e) { + log.error("Task {} status refresh error: {}", taskId.getId(), e.getMessage(), e); + + } finally { + ContextHolder.destroy(); + } + }, + EXECUTOR_SERVICE); + } + + } finally { + ContextHolder.destroy(); } + } + + private List getRunningTasks() { + return taskService.getTasksByStatus(GeaflowTaskStatus.RUNNING); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/task/GeaflowTaskSubmitter.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/task/GeaflowTaskSubmitter.java index 1cde799c1..7899e6edb 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/task/GeaflowTaskSubmitter.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/task/GeaflowTaskSubmitter.java @@ -19,8 +19,6 @@ package org.apache.geaflow.console.core.service.task; -import com.alibaba.fastjson.JSON; -import com.google.common.collect.Maps; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -29,7 +27,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; -import lombok.extern.slf4j.Slf4j; + import org.apache.commons.collections.CollectionUtils; import org.apache.geaflow.console.common.util.context.ContextHolder; import org.apache.geaflow.console.common.util.exception.GeaflowException; @@ -41,82 +39,89 @@ import org.springframework.scheduling.annotation.Scheduled; import org.springframework.stereotype.Service; +import com.alibaba.fastjson.JSON; +import com.google.common.collect.Maps; + +import lombok.extern.slf4j.Slf4j; + @Service @Slf4j public class GeaflowTaskSubmitter { - private static final ExecutorService EXECUTOR_SERVICE = new ThreadPoolExecutor(50, 500, 30, TimeUnit.SECONDS, - new LinkedBlockingQueue<>(1000)); - - private static final int JOB_SUBMIT_TIMEOUT_MINUTE = 4; - - @Autowired - private TaskService taskService; - - @Autowired - private GeaflowTaskOperator taskOperator; - - @Scheduled(cron = "0-59/10 * * * * ?") - void submit() { - try { - ContextHolder.init(); - List taskIds = this.getWaitingTasks(); - if (CollectionUtils.isEmpty(taskIds)) { - return; - } - - log.info("task submitter start, task size: {}", taskIds.size()); - List> futureList = new ArrayList<>(taskIds.size()); - Map submitMap = Maps.newConcurrentMap(); - - for (GeaflowId taskId : taskIds) { - futureList.add(CompletableFuture.runAsync(() -> { - try { - // set tenant and user by task - ContextHolder.init(); - ContextHolder.get().setUserId(taskId.getModifierId()); - ContextHolder.get().setTenantId(taskId.getTenantId()); - GeaflowTask task = taskService.get(taskId.getId()); - submitMap.put(task.getId(), task); - - log.info("task {} submit start, curr status: {}", task.getId(), task.getStatus()); - - // update status to avoid submitting repeatedly. - boolean updateStatus = taskService.updateStatus(task.getId(), GeaflowTaskStatus.WAITING, - GeaflowTaskStatus.STARTING); - if (!updateStatus) { - throw new GeaflowException("task status has been changed, need {} status", - GeaflowTaskStatus.WAITING); - } - taskOperator.start(task); - - } catch (Exception e) { - log.error("task {} submit error: {}", taskId.getId(), e.getMessage(), e); - - } finally { - submitMap.remove(taskId.getId()); - ContextHolder.destroy(); + private static final ExecutorService EXECUTOR_SERVICE = + new ThreadPoolExecutor(50, 500, 30, TimeUnit.SECONDS, new LinkedBlockingQueue<>(1000)); + + private static final int JOB_SUBMIT_TIMEOUT_MINUTE = 4; + + @Autowired private TaskService taskService; + + @Autowired private GeaflowTaskOperator taskOperator; + + @Scheduled(cron = "0-59/10 * * * * ?") + void submit() { + try { + ContextHolder.init(); + List taskIds = this.getWaitingTasks(); + if (CollectionUtils.isEmpty(taskIds)) { + return; + } + + log.info("task submitter start, task size: {}", taskIds.size()); + List> futureList = new ArrayList<>(taskIds.size()); + Map submitMap = Maps.newConcurrentMap(); + + for (GeaflowId taskId : taskIds) { + futureList.add( + CompletableFuture.runAsync( + () -> { + try { + // set tenant and user by task + ContextHolder.init(); + ContextHolder.get().setUserId(taskId.getModifierId()); + ContextHolder.get().setTenantId(taskId.getTenantId()); + GeaflowTask task = taskService.get(taskId.getId()); + submitMap.put(task.getId(), task); + + log.info( + "task {} submit start, curr status: {}", task.getId(), task.getStatus()); + + // update status to avoid submitting repeatedly. + boolean updateStatus = + taskService.updateStatus( + task.getId(), GeaflowTaskStatus.WAITING, GeaflowTaskStatus.STARTING); + if (!updateStatus) { + throw new GeaflowException( + "task status has been changed, need {} status", + GeaflowTaskStatus.WAITING); } + taskOperator.start(task); - }, EXECUTOR_SERVICE)); - } + } catch (Exception e) { + log.error("task {} submit error: {}", taskId.getId(), e.getMessage(), e); - if (!futureList.isEmpty()) { - try { - CompletableFuture.allOf(futureList.toArray(new CompletableFuture[futureList.size()])) - .get(JOB_SUBMIT_TIMEOUT_MINUTE, TimeUnit.MINUTES); - } catch (Exception e) { - log.error("Task {} Submit Waiting Timeout", JSON.toJSONString(submitMap.keySet()), e); - } - } + } finally { + submitMap.remove(taskId.getId()); + ContextHolder.destroy(); + } + }, + EXECUTOR_SERVICE)); + } - } finally { - ContextHolder.destroy(); + if (!futureList.isEmpty()) { + try { + CompletableFuture.allOf(futureList.toArray(new CompletableFuture[futureList.size()])) + .get(JOB_SUBMIT_TIMEOUT_MINUTE, TimeUnit.MINUTES); + } catch (Exception e) { + log.error("Task {} Submit Waiting Timeout", JSON.toJSONString(submitMap.keySet()), e); } + } + } finally { + ContextHolder.destroy(); } + } - private List getWaitingTasks() { - return taskService.getTasksByStatus(GeaflowTaskStatus.WAITING); - } + private List getWaitingTasks() { + return taskService.getTasksByStatus(GeaflowTaskStatus.WAITING); + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/version/CompileClassLoader.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/version/CompileClassLoader.java index e36bc409f..980e7f574 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/version/CompileClassLoader.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/version/CompileClassLoader.java @@ -20,6 +20,5 @@ public interface CompileClassLoader { - T newInstance(Class clazz, Object... parameters); - + T newInstance(Class clazz, Object... parameters); } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/version/FunctionClassLoader.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/version/FunctionClassLoader.java index 277d1f042..58e93e4ed 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/version/FunctionClassLoader.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/version/FunctionClassLoader.java @@ -24,61 +24,67 @@ import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; -import lombok.extern.slf4j.Slf4j; + import org.apache.geaflow.console.common.util.ListUtil; import org.apache.geaflow.console.common.util.exception.GeaflowException; import org.apache.geaflow.console.common.util.proxy.ProxyUtil; import org.apache.geaflow.console.core.model.file.GeaflowRemoteFile; +import lombok.extern.slf4j.Slf4j; + @Slf4j public class FunctionClassLoader implements CompileClassLoader { - private final URLClassLoader functionLoader; + private final URLClassLoader functionLoader; - private final VersionClassLoader versionClassLoader; + private final VersionClassLoader versionClassLoader; - public FunctionClassLoader(VersionClassLoader versionClassLoader, List jars) { - this.versionClassLoader = versionClassLoader; - this.functionLoader = createFunctionLoader(jars); - } + public FunctionClassLoader(VersionClassLoader versionClassLoader, List jars) { + this.versionClassLoader = versionClassLoader; + this.functionLoader = createFunctionLoader(jars); + } - public FunctionClassLoader(VersionClassLoader versionClassLoader, URL[] urls) { - this.versionClassLoader = versionClassLoader; - this.functionLoader = createFunctionLoader(urls); - } + public FunctionClassLoader(VersionClassLoader versionClassLoader, URL[] urls) { + this.versionClassLoader = versionClassLoader; + this.functionLoader = createFunctionLoader(urls); + } - private URLClassLoader createFunctionLoader(URL[] urls) { - return new URLClassLoader(urls, versionClassLoader.getClassLoader()); - } + private URLClassLoader createFunctionLoader(URL[] urls) { + return new URLClassLoader(urls, versionClassLoader.getClassLoader()); + } - private URLClassLoader createFunctionLoader(List userJars) { - List userUrls = ListUtil.convert(userJars, jar -> { - try { - File file = versionClassLoader.getLocalFileFactory().getUserFile(jar.getCreatorId(), jar); + private URLClassLoader createFunctionLoader(List userJars) { + List userUrls = + ListUtil.convert( + userJars, + jar -> { + try { + File file = + versionClassLoader.getLocalFileFactory().getUserFile(jar.getCreatorId(), jar); return file.toURI().toURL(); - } catch (Exception e) { + } catch (Exception e) { throw new GeaflowException("Add function jar file {} failed", jar.getName(), e); - } - }); + } + }); - return new URLClassLoader(userUrls.toArray(new URL[]{}), versionClassLoader.getClassLoader()); + return new URLClassLoader(userUrls.toArray(new URL[] {}), versionClassLoader.getClassLoader()); + } - } - - @Override - public T newInstance(Class clazz, Object... parameters) { - return ProxyUtil.newInstance(functionLoader, clazz, parameters); - } + @Override + public T newInstance(Class clazz, Object... parameters) { + return ProxyUtil.newInstance(functionLoader, clazz, parameters); + } - public void closeClassLoader() { - String files = Arrays.stream(functionLoader.getURLs()).map(URL::getFile).collect(Collectors.joining(";")); - try { - functionLoader.close(); - log.info("Close functionLoader {}", files); + public void closeClassLoader() { + String files = + Arrays.stream(functionLoader.getURLs()).map(URL::getFile).collect(Collectors.joining(";")); + try { + functionLoader.close(); + log.info("Close functionLoader {}", files); - } catch (Exception e) { - log.info("Fail to close functionLoader {}", files); - } + } catch (Exception e) { + log.info("Fail to close functionLoader {}", files); } + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/version/VersionClassLoader.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/version/VersionClassLoader.java index 627c30715..ac13b6d09 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/version/VersionClassLoader.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/version/VersionClassLoader.java @@ -24,71 +24,72 @@ import java.net.URLClassLoader; import java.util.ArrayList; import java.util.List; -import lombok.Getter; -import lombok.extern.slf4j.Slf4j; + import org.apache.geaflow.console.common.util.exception.GeaflowException; import org.apache.geaflow.console.common.util.proxy.ProxyUtil; import org.apache.geaflow.console.core.model.file.GeaflowRemoteFile; import org.apache.geaflow.console.core.model.version.GeaflowVersion; import org.apache.geaflow.console.core.service.file.LocalFileFactory; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; + @Slf4j @Getter public class VersionClassLoader implements CompileClassLoader { - protected final GeaflowVersion version; + protected final GeaflowVersion version; - protected final LocalFileFactory localFileFactory; + protected final LocalFileFactory localFileFactory; - protected final URLClassLoader classLoader; + protected final URLClassLoader classLoader; - protected VersionClassLoader(GeaflowVersion version, LocalFileFactory localFileFactory) { - this.version = version; - this.localFileFactory = localFileFactory; - this.classLoader = createClassLoader(); - } + protected VersionClassLoader(GeaflowVersion version, LocalFileFactory localFileFactory) { + this.version = version; + this.localFileFactory = localFileFactory; + this.classLoader = createClassLoader(); + } - public T newInstance(Class clazz, Object... parameters) { - return ProxyUtil.newInstance(classLoader, clazz, parameters); - } + public T newInstance(Class clazz, Object... parameters) { + return ProxyUtil.newInstance(classLoader, clazz, parameters); + } - protected void closeClassLoader() { - try { - classLoader.close(); + protected void closeClassLoader() { + try { + classLoader.close(); - log.info("Close classloader of version {}", version.getName()); - } catch (Exception e) { - log.info("Close classloader of version {} failed", version.getName(), e); - } + log.info("Close classloader of version {}", version.getName()); + } catch (Exception e) { + log.info("Close classloader of version {} failed", version.getName(), e); } - - private URLClassLoader createClassLoader() { - try { - String versionName = version.getName(); - GeaflowRemoteFile engineJarPackage = version.getEngineJarPackage(); - GeaflowRemoteFile langJarPackage = version.getLangJarPackage(); - if (engineJarPackage == null) { - throw new GeaflowException("Engine jar not found in version {}", versionName); - } - - // prepare engine jar file - List urlList = new ArrayList<>(); - File engineJarFile = localFileFactory.getVersionFile(versionName, engineJarPackage); - urlList.add(engineJarFile.toURI().toURL()); - - // prepare lang jar file - if (langJarPackage != null) { - File langJarFile = localFileFactory.getVersionFile(versionName, langJarPackage); - urlList.add(langJarFile.toURI().toURL()); - } - - // create classloader - ClassLoader extClassLoader = ClassLoader.getSystemClassLoader().getParent(); - return new URLClassLoader(urlList.toArray(new URL[]{}), extClassLoader); - - } catch (Exception e) { - throw new GeaflowException("Create classloader of version {} failed", version.getName(), e); - } + } + + private URLClassLoader createClassLoader() { + try { + String versionName = version.getName(); + GeaflowRemoteFile engineJarPackage = version.getEngineJarPackage(); + GeaflowRemoteFile langJarPackage = version.getLangJarPackage(); + if (engineJarPackage == null) { + throw new GeaflowException("Engine jar not found in version {}", versionName); + } + + // prepare engine jar file + List urlList = new ArrayList<>(); + File engineJarFile = localFileFactory.getVersionFile(versionName, engineJarPackage); + urlList.add(engineJarFile.toURI().toURL()); + + // prepare lang jar file + if (langJarPackage != null) { + File langJarFile = localFileFactory.getVersionFile(versionName, langJarPackage); + urlList.add(langJarFile.toURI().toURL()); + } + + // create classloader + ClassLoader extClassLoader = ClassLoader.getSystemClassLoader().getParent(); + return new URLClassLoader(urlList.toArray(new URL[] {}), extClassLoader); + + } catch (Exception e) { + throw new GeaflowException("Create classloader of version {} failed", version.getName(), e); } - + } } diff --git a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/version/VersionFactory.java b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/version/VersionFactory.java index 92fa0da7c..a788e6467 100644 --- a/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/version/VersionFactory.java +++ b/geaflow-console/app/core/service/src/main/java/org/apache/geaflow/console/core/service/version/VersionFactory.java @@ -21,6 +21,7 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.console.core.model.version.GeaflowVersion; import org.apache.geaflow.console.core.service.file.LocalFileFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -29,22 +30,23 @@ @Component public class VersionFactory { - private final Map classLoaderMap = new ConcurrentHashMap<>(); + private final Map classLoaderMap = new ConcurrentHashMap<>(); - @Autowired - private LocalFileFactory localFileFactory; + @Autowired private LocalFileFactory localFileFactory; - public VersionClassLoader getClassLoader(GeaflowVersion version) { - return classLoaderMap.compute(version.getName(), (k, vcl) -> { - if (vcl != null && GeaflowVersion.md5Equals(version, vcl.version)) { - return vcl; - } + public VersionClassLoader getClassLoader(GeaflowVersion version) { + return classLoaderMap.compute( + version.getName(), + (k, vcl) -> { + if (vcl != null && GeaflowVersion.md5Equals(version, vcl.version)) { + return vcl; + } - if (vcl != null) { - vcl.closeClassLoader(); - } + if (vcl != null) { + vcl.closeClassLoader(); + } - return new VersionClassLoader(version, localFileFactory); + return new VersionClassLoader(version, localFileFactory); }); - } + } } diff --git a/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/GeaflowCompilerTest.java b/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/GeaflowCompilerTest.java index a12121650..04aa11989 100644 --- a/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/GeaflowCompilerTest.java +++ b/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/GeaflowCompilerTest.java @@ -26,7 +26,7 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; -import lombok.extern.slf4j.Slf4j; + import org.apache.commons.io.IOUtils; import org.apache.geaflow.console.common.service.integration.engine.CompileContext; import org.apache.geaflow.console.common.service.integration.engine.CompileResult; @@ -38,58 +38,66 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import lombok.extern.slf4j.Slf4j; + @Slf4j public class GeaflowCompilerTest { - private final static String ENGINE_JAR_PATH; + private static final String ENGINE_JAR_PATH; - static { - ENGINE_JAR_PATH = "/tmp/geaflow.jar"; - } + static { + ENGINE_JAR_PATH = "/tmp/geaflow.jar"; + } - private static String SCRIPT; + private static String SCRIPT; - @BeforeClass - public static void beforeClass() throws Exception { - log.info(ENGINE_JAR_PATH); - if (!FileUtil.exist(ENGINE_JAR_PATH)) { - throw new GeaflowException("Prepare engine jar at path {}", ENGINE_JAR_PATH); - } - SCRIPT = IOUtils.resourceToString("/compile.sql", StandardCharsets.UTF_8); + @BeforeClass + public static void beforeClass() throws Exception { + log.info(ENGINE_JAR_PATH); + if (!FileUtil.exist(ENGINE_JAR_PATH)) { + throw new GeaflowException("Prepare engine jar at path {}", ENGINE_JAR_PATH); } + SCRIPT = IOUtils.resourceToString("/compile.sql", StandardCharsets.UTF_8); + } - private URLClassLoader createClassLoader() throws Exception { - List urlList = new ArrayList<>(); - urlList.add(new File(ENGINE_JAR_PATH).toURI().toURL()); - ClassLoader extClassLoader = ClassLoader.getSystemClassLoader().getParent(); - return new URLClassLoader(urlList.toArray(new URL[]{}), extClassLoader); - } + private URLClassLoader createClassLoader() throws Exception { + List urlList = new ArrayList<>(); + urlList.add(new File(ENGINE_JAR_PATH).toURI().toURL()); + ClassLoader extClassLoader = ClassLoader.getSystemClassLoader().getParent(); + return new URLClassLoader(urlList.toArray(new URL[] {}), extClassLoader); + } - @Test(enabled = false) - public void testCompile() throws Exception { - try (URLClassLoader classLoader = createClassLoader()) { - LoaderSwitchUtil.run(classLoader, () -> { - Object context = classLoader.loadClass("org.apache.geaflow.dsl.common.compile.CompileContext") + @Test(enabled = false) + public void testCompile() throws Exception { + try (URLClassLoader classLoader = createClassLoader()) { + LoaderSwitchUtil.run( + classLoader, + () -> { + Object context = + classLoader + .loadClass("org.apache.geaflow.dsl.common.compile.CompileContext") .newInstance(); - Object compiler = classLoader.loadClass("org.apache.geaflow.dsl.runtime.QueryClient").newInstance(); + Object compiler = + classLoader.loadClass("org.apache.geaflow.dsl.runtime.QueryClient").newInstance(); - Method method = compiler.getClass().getMethod("compile", String.class, context.getClass()); - Object result = method.invoke(compiler, SCRIPT, context); + Method method = + compiler.getClass().getMethod("compile", String.class, context.getClass()); + Object result = method.invoke(compiler, SCRIPT, context); - String physicPlan = result.getClass().getMethod("getPhysicPlan").invoke(result).toString(); - log.info(physicPlan); - }); - } + String physicPlan = + result.getClass().getMethod("getPhysicPlan").invoke(result).toString(); + log.info(physicPlan); + }); } + } - @Test(enabled = false) - public void testProxyCompile() throws Exception { - try (URLClassLoader classLoader = createClassLoader()) { - CompileContext context = ProxyUtil.newInstance(classLoader, CompileContext.class); - GeaflowCompiler compiler = ProxyUtil.newInstance(classLoader, GeaflowCompiler.class); - CompileResult result = compiler.compile(SCRIPT, context); - log.info(result.getPhysicPlan().toString()); - } + @Test(enabled = false) + public void testProxyCompile() throws Exception { + try (URLClassLoader classLoader = createClassLoader()) { + CompileContext context = ProxyUtil.newInstance(classLoader, CompileContext.class); + GeaflowCompiler compiler = ProxyUtil.newInstance(classLoader, GeaflowCompiler.class); + CompileResult result = compiler.compile(SCRIPT, context); + log.info(result.getPhysicPlan().toString()); } - + } } diff --git a/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/GeaflowConfigClassTest.java b/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/GeaflowConfigClassTest.java index 94c4d7646..46024800d 100644 --- a/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/GeaflowConfigClassTest.java +++ b/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/GeaflowConfigClassTest.java @@ -19,7 +19,6 @@ package org.apache.geaflow.console.test; -import lombok.extern.slf4j.Slf4j; import org.apache.geaflow.console.common.util.type.GeaflowPluginCategory; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; import org.apache.geaflow.console.core.model.job.config.ClusterConfigClass; @@ -46,165 +45,175 @@ import org.testng.Assert; import org.testng.annotations.Test; +import lombok.extern.slf4j.Slf4j; + @Slf4j public class GeaflowConfigClassTest { - @Test() - public void testGeaflowArgs() { - Assert.assertEquals(buildGeaflowClientArgs().build().size(), 35); - Assert.assertEquals(buildGeaflowClientStopArgs().build().size(), 31); - Assert.assertEquals(buildGeaflowArgs().build().size(), 3); - Assert.assertEquals(buildSystemArgs().build().size(), 10); - Assert.assertEquals(buildClusterArgs().build().size(), 30); - Assert.assertEquals(buildJobArgs().build().size(), 2); - Assert.assertEquals(buildStateArgs().build().size(), 19); - Assert.assertEquals(buildMetricArgs().build().size(), 7); - } - - private static K8sClientArgsClass buildGeaflowClientArgs() { - return new K8sClientArgsClass(buildGeaflowArgs(), GeaflowTask.CODE_TASK_MAIN_CLASS); - } - - private static K8sClientStopArgsClass buildGeaflowClientStopArgs() { - return new K8sClientStopArgsClass("geaflow123456-123456", buildClusterArgs()); - } - - private static GeaflowArgsClass buildGeaflowArgs() { - GeaflowArgsClass config = new GeaflowArgsClass(); - config.setSystemArgs(buildSystemArgs()); - config.setClusterArgs(buildClusterArgs()); - config.setJobArgs(buildJobArgs()); - return config; - } - - private static SystemArgsClass buildSystemArgs() { - SystemArgsClass config = new SystemArgsClass(); - config.setTaskId("123455"); - config.setRuntimeTaskId("geaflow123456-123456"); - config.setRuntimeTaskName("geaflow123456"); - config.setGateway("http://127.0.0.1:8888"); - config.setTaskToken("qwertyuiopasdfghjklzxcvbnm"); - config.setStartupNotifyUrl("http://127.0.0.1:8888/api/tasks/123455/operations"); - config.setInstanceName("test-instance"); - config.setCatalogType("console"); - config.setStateArgs(buildStateArgs()); - config.setMetricArgs(buildMetricArgs()); - return config; - } - - private static StateArgsClass buildStateArgs() { - StateArgsClass stateArgs = new StateArgsClass(); - stateArgs.setRuntimeMetaArgs(new RuntimeMetaArgsClass( + @Test() + public void testGeaflowArgs() { + Assert.assertEquals(buildGeaflowClientArgs().build().size(), 35); + Assert.assertEquals(buildGeaflowClientStopArgs().build().size(), 31); + Assert.assertEquals(buildGeaflowArgs().build().size(), 3); + Assert.assertEquals(buildSystemArgs().build().size(), 10); + Assert.assertEquals(buildClusterArgs().build().size(), 30); + Assert.assertEquals(buildJobArgs().build().size(), 2); + Assert.assertEquals(buildStateArgs().build().size(), 19); + Assert.assertEquals(buildMetricArgs().build().size(), 7); + } + + private static K8sClientArgsClass buildGeaflowClientArgs() { + return new K8sClientArgsClass(buildGeaflowArgs(), GeaflowTask.CODE_TASK_MAIN_CLASS); + } + + private static K8sClientStopArgsClass buildGeaflowClientStopArgs() { + return new K8sClientStopArgsClass("geaflow123456-123456", buildClusterArgs()); + } + + private static GeaflowArgsClass buildGeaflowArgs() { + GeaflowArgsClass config = new GeaflowArgsClass(); + config.setSystemArgs(buildSystemArgs()); + config.setClusterArgs(buildClusterArgs()); + config.setJobArgs(buildJobArgs()); + return config; + } + + private static SystemArgsClass buildSystemArgs() { + SystemArgsClass config = new SystemArgsClass(); + config.setTaskId("123455"); + config.setRuntimeTaskId("geaflow123456-123456"); + config.setRuntimeTaskName("geaflow123456"); + config.setGateway("http://127.0.0.1:8888"); + config.setTaskToken("qwertyuiopasdfghjklzxcvbnm"); + config.setStartupNotifyUrl("http://127.0.0.1:8888/api/tasks/123455/operations"); + config.setInstanceName("test-instance"); + config.setCatalogType("console"); + config.setStateArgs(buildStateArgs()); + config.setMetricArgs(buildMetricArgs()); + return config; + } + + private static StateArgsClass buildStateArgs() { + StateArgsClass stateArgs = new StateArgsClass(); + stateArgs.setRuntimeMetaArgs( + new RuntimeMetaArgsClass( new GeaflowPluginConfig(GeaflowPluginCategory.RUNTIME_META, buildJdbcPluginConfig()))); - stateArgs.setHaMetaArgs( - new HaMetaArgsClass(new GeaflowPluginConfig(GeaflowPluginCategory.HA_META, buildRedisPluginConfig()))); - stateArgs.setPersistentArgs( - new PersistentArgsClass(new GeaflowPluginConfig(GeaflowPluginCategory.DATA, buildDfsPluginConfig()))); - return stateArgs; - } - - private static MetricArgsClass buildMetricArgs() { - return new MetricArgsClass(new GeaflowPluginConfig(GeaflowPluginCategory.METRIC, buildInfluxPluginConfig())); - } - - private static K8SClusterArgsClass buildClusterArgs() { - K8SClusterArgsClass config = new K8SClusterArgsClass(); - config.setTaskClusterConfig(buildClusterConfig()); - config.setClusterConfig(buildK8sPluginConfig()); - config.setEngineJarUrls("http://127.0.0.1/geaflow/files/versions/0.1/geaflow.jar"); - config.setTaskJarUrls("http://127.0.0.1/geaflow/files/users/123456/udf.jar"); - return config; - } - - private static JobArgsClass buildJobArgs() { - JobArgsClass config = new JobArgsClass(); - config.setSystemStateType(GeaflowPluginType.ROCKSDB); - config.setJobConfig(buildJobConfig()); - return config; - } - - private static K8sPluginConfigClass buildK8sPluginConfig() { - K8sPluginConfigClass config = new K8sPluginConfigClass(); - config.setMasterUrl("https://0.0.0.0:6443"); - config.setImageUrl("tugraph/geaflow:0.1"); - config.setServiceAccount("geaflow"); - config.setServiceType("NODE_PORT"); - config.setNamespace("default"); - config.setCertData("xxx"); - config.setCertKey("yyy"); - config.setCaData("zzz"); - config.setRetryTimes(100); - config.setClusterName("aaa"); - config.setPodUserLabels("bbb"); - config.setServiceSuffix("ccc"); - config.setStorageLimit("50Gi"); - return config; - } - - private static JdbcPluginConfigClass buildJdbcPluginConfig() { - JdbcPluginConfigClass config = new JdbcPluginConfigClass(); - config.setDriverClass("com.mysql.jdbc.Driver"); - config.setUrl("jdbc:mysql://127..0.0.1:3306/geaflow"); - config.setUsername("geaflow"); - config.setPassword("geaflow"); - config.setRetryTimes(3); - config.setConnectionPoolSize(10); - config.setConfigJson("{}"); - return config; - } - - private static RedisPluginConfigClass buildRedisPluginConfig() { - RedisPluginConfigClass config = new RedisPluginConfigClass(); - config.setHost("127.0.0.1"); - config.setPort(6379); - config.setRetryTimes(10); - return config; - } - - private static PluginConfigClass buildInfluxPluginConfig() { - InfluxdbPluginConfigClass config = new InfluxdbPluginConfigClass(); - config.setUrl("http://127.0.0.1:8086"); - config.setToken("qwertyuiopkjhgfdxcvb"); - config.setOrg("geaflow"); - config.setBucket("geaflow"); - config.setConnectTimeout(30000); - config.setWriteTimeout(30000); - return config; - } - - private static DfsPluginConfigClass buildDfsPluginConfig() { - DfsPluginConfigClass config = new DfsPluginConfigClass(); - config.setDefaultFs("hdfs://127.0.0.1"); - config.setRoot("/geaflow/chk"); - config.setThreadSize(16); - config.setUsername("geaflow"); - config.getExtendConfig().put("fs.custom", "custom"); - return config; - } - - private static ClusterConfigClass buildClusterConfig() { - ClusterConfigClass config = new ClusterConfigClass(); - config.setContainers(1); - config.setContainerMemory(4096); - config.setContainerCores(1.0); - config.setContainerWorkers(1); - config.setContainerJvmOptions("-Xmx2048m,-Xms2048m,-Xmn512m,-Xss512k,-XX:MaxDirectMemorySize=1024m"); - config.setEnableFo(true); - config.setMasterMemory(4096); - config.setDriverMemory(4096); - config.setClientMemory(1024); - config.setMasterCores(1.0); - config.setDriverCores(1.0); - config.setClientCores(1.0); - config.setMasterJvmOptions("-Xmx2048m,-Xms2048m,-Xmn512m,-Xss512k,-XX:MaxDirectMemorySize=1024m"); - config.setDriverJvmOptions("-Xmx2048m,-Xms2048m,-Xmn512m,-Xss512k,-XX:MaxDirectMemorySize=1024m"); - config.setClientJvmOptions("-Xmx1024m,-Xms1024m,-Xmn256m,-Xss256k,-XX:MaxDirectMemorySize=512m"); - return config; - } - - private static JobConfigClass buildJobConfig() { - JobConfigClass config = new JobConfigClass(); - config.getExtendConfig().put("job.custom", "custom"); - return config; - } + stateArgs.setHaMetaArgs( + new HaMetaArgsClass( + new GeaflowPluginConfig(GeaflowPluginCategory.HA_META, buildRedisPluginConfig()))); + stateArgs.setPersistentArgs( + new PersistentArgsClass( + new GeaflowPluginConfig(GeaflowPluginCategory.DATA, buildDfsPluginConfig()))); + return stateArgs; + } + + private static MetricArgsClass buildMetricArgs() { + return new MetricArgsClass( + new GeaflowPluginConfig(GeaflowPluginCategory.METRIC, buildInfluxPluginConfig())); + } + + private static K8SClusterArgsClass buildClusterArgs() { + K8SClusterArgsClass config = new K8SClusterArgsClass(); + config.setTaskClusterConfig(buildClusterConfig()); + config.setClusterConfig(buildK8sPluginConfig()); + config.setEngineJarUrls("http://127.0.0.1/geaflow/files/versions/0.1/geaflow.jar"); + config.setTaskJarUrls("http://127.0.0.1/geaflow/files/users/123456/udf.jar"); + return config; + } + + private static JobArgsClass buildJobArgs() { + JobArgsClass config = new JobArgsClass(); + config.setSystemStateType(GeaflowPluginType.ROCKSDB); + config.setJobConfig(buildJobConfig()); + return config; + } + + private static K8sPluginConfigClass buildK8sPluginConfig() { + K8sPluginConfigClass config = new K8sPluginConfigClass(); + config.setMasterUrl("https://0.0.0.0:6443"); + config.setImageUrl("tugraph/geaflow:0.1"); + config.setServiceAccount("geaflow"); + config.setServiceType("NODE_PORT"); + config.setNamespace("default"); + config.setCertData("xxx"); + config.setCertKey("yyy"); + config.setCaData("zzz"); + config.setRetryTimes(100); + config.setClusterName("aaa"); + config.setPodUserLabels("bbb"); + config.setServiceSuffix("ccc"); + config.setStorageLimit("50Gi"); + return config; + } + + private static JdbcPluginConfigClass buildJdbcPluginConfig() { + JdbcPluginConfigClass config = new JdbcPluginConfigClass(); + config.setDriverClass("com.mysql.jdbc.Driver"); + config.setUrl("jdbc:mysql://127..0.0.1:3306/geaflow"); + config.setUsername("geaflow"); + config.setPassword("geaflow"); + config.setRetryTimes(3); + config.setConnectionPoolSize(10); + config.setConfigJson("{}"); + return config; + } + + private static RedisPluginConfigClass buildRedisPluginConfig() { + RedisPluginConfigClass config = new RedisPluginConfigClass(); + config.setHost("127.0.0.1"); + config.setPort(6379); + config.setRetryTimes(10); + return config; + } + + private static PluginConfigClass buildInfluxPluginConfig() { + InfluxdbPluginConfigClass config = new InfluxdbPluginConfigClass(); + config.setUrl("http://127.0.0.1:8086"); + config.setToken("qwertyuiopkjhgfdxcvb"); + config.setOrg("geaflow"); + config.setBucket("geaflow"); + config.setConnectTimeout(30000); + config.setWriteTimeout(30000); + return config; + } + + private static DfsPluginConfigClass buildDfsPluginConfig() { + DfsPluginConfigClass config = new DfsPluginConfigClass(); + config.setDefaultFs("hdfs://127.0.0.1"); + config.setRoot("/geaflow/chk"); + config.setThreadSize(16); + config.setUsername("geaflow"); + config.getExtendConfig().put("fs.custom", "custom"); + return config; + } + + private static ClusterConfigClass buildClusterConfig() { + ClusterConfigClass config = new ClusterConfigClass(); + config.setContainers(1); + config.setContainerMemory(4096); + config.setContainerCores(1.0); + config.setContainerWorkers(1); + config.setContainerJvmOptions( + "-Xmx2048m,-Xms2048m,-Xmn512m,-Xss512k,-XX:MaxDirectMemorySize=1024m"); + config.setEnableFo(true); + config.setMasterMemory(4096); + config.setDriverMemory(4096); + config.setClientMemory(1024); + config.setMasterCores(1.0); + config.setDriverCores(1.0); + config.setClientCores(1.0); + config.setMasterJvmOptions( + "-Xmx2048m,-Xms2048m,-Xmn512m,-Xss512k,-XX:MaxDirectMemorySize=1024m"); + config.setDriverJvmOptions( + "-Xmx2048m,-Xms2048m,-Xmn512m,-Xss512k,-XX:MaxDirectMemorySize=1024m"); + config.setClientJvmOptions( + "-Xmx1024m,-Xms1024m,-Xmn256m,-Xss256k,-XX:MaxDirectMemorySize=512m"); + return config; + } + + private static JobConfigClass buildJobConfig() { + JobConfigClass config = new JobConfigClass(); + config.getExtendConfig().put("job.custom", "custom"); + return config; + } } diff --git a/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/GraphDiffTest.java b/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/GraphDiffTest.java index 6b4ceb0f8..eaca387b9 100644 --- a/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/GraphDiffTest.java +++ b/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/GraphDiffTest.java @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.console.common.util.ListUtil; import org.apache.geaflow.console.core.model.data.GeaflowGraph; import org.apache.geaflow.console.core.model.job.GeaflowTransferJob.StructMapping; @@ -31,78 +32,79 @@ public class GraphDiffTest { - @Test - public void testDiff() { - List list1 = new ArrayList<>(); - List list2 = new ArrayList<>(); - list1.add("1"); - list1.add("2"); - list1.add("3"); - - list2.add("3"); - list2.add("4"); - List diff1 = ListUtil.diff(list1, list2); - List diff2 = ListUtil.diff(list2, list1); - Assert.assertEquals(diff1.size(), 2); - Assert.assertEquals(diff2.size(), 1); - } - - @Test - public void testDiff2() { - List list1 = new ArrayList<>(); - List list2 = new ArrayList<>(); - list1.add("1"); - list1.add("2"); - list1.add("2"); - list1.add("2"); - list1.add("3"); + @Test + public void testDiff() { + List list1 = new ArrayList<>(); + List list2 = new ArrayList<>(); + list1.add("1"); + list1.add("2"); + list1.add("3"); - list2.add("3"); - list2.add("3"); - list2.add("4"); - List diff1 = ListUtil.diff(list1, list2); - List diff2 = ListUtil.diff(list2, list1); - Assert.assertEquals(diff1.size(), 4); - Assert.assertEquals(diff2.size(), 2); - } + list2.add("3"); + list2.add("4"); + List diff1 = ListUtil.diff(list1, list2); + List diff2 = ListUtil.diff(list2, list1); + Assert.assertEquals(diff1.size(), 2); + Assert.assertEquals(diff2.size(), 1); + } - @Test - public void testNameDiff() { - GeaflowGraph g1 = new GeaflowGraph("g1", null); - g1.setInstanceId("1"); - GeaflowGraph g2 = new GeaflowGraph("g1", null); - g2.setInstanceId("2"); - GeaflowGraph g3 = new GeaflowGraph("g1", null); - g3.setInstanceId("3"); - GeaflowGraph g4 = new GeaflowGraph("g1", null); - g4.setInstanceId("1"); - List list1 = Arrays.asList(g1, g2); - List list2 = Arrays.asList(g2, g3, g4); + @Test + public void testDiff2() { + List list1 = new ArrayList<>(); + List list2 = new ArrayList<>(); + list1.add("1"); + list1.add("2"); + list1.add("2"); + list1.add("2"); + list1.add("3"); - List diff1 = ListUtil.diff(list1, list2); - List diff2 = ListUtil.diff(list2, list1); - Assert.assertEquals(diff1.size(), 0); - Assert.assertEquals(diff2.size(), 1); + list2.add("3"); + list2.add("3"); + list2.add("4"); + List diff1 = ListUtil.diff(list1, list2); + List diff2 = ListUtil.diff(list2, list1); + Assert.assertEquals(diff1.size(), 4); + Assert.assertEquals(diff2.size(), 2); + } - List diff3 = ListUtil.diff(list1, list2, graph -> graph.getName() + "-" + graph.getInstanceId()); - List diff4 = ListUtil.diff(list2, list1, graph -> graph.getName() + "-" + graph.getInstanceId()); + @Test + public void testNameDiff() { + GeaflowGraph g1 = new GeaflowGraph("g1", null); + g1.setInstanceId("1"); + GeaflowGraph g2 = new GeaflowGraph("g1", null); + g2.setInstanceId("2"); + GeaflowGraph g3 = new GeaflowGraph("g1", null); + g3.setInstanceId("3"); + GeaflowGraph g4 = new GeaflowGraph("g1", null); + g4.setInstanceId("1"); + List list1 = Arrays.asList(g1, g2); + List list2 = Arrays.asList(g2, g3, g4); - Assert.assertEquals(diff3.size(), 0); - Assert.assertEquals(diff4.get(0).getInstanceId(), "3"); - } + List diff1 = ListUtil.diff(list1, list2); + List diff2 = ListUtil.diff(list2, list1); + Assert.assertEquals(diff1.size(), 0); + Assert.assertEquals(diff2.size(), 1); - @Test - public void testStructMappingDiff() { - List structMappings = new ArrayList<>(); - structMappings.add(new StructMapping("t1", "v1", null)); - structMappings.add(new StructMapping("t1", "v1", null)); - structMappings.add(new StructMapping("t2", "v2", null)); - structMappings.add(new StructMapping("t2", "v2", null)); + List diff3 = + ListUtil.diff(list1, list2, graph -> graph.getName() + "-" + graph.getInstanceId()); + List diff4 = + ListUtil.diff(list2, list1, graph -> graph.getName() + "-" + graph.getInstanceId()); - List distinctList = structMappings.stream().distinct().collect(Collectors.toList()); - List diff = ListUtil.diff(structMappings, distinctList); - Assert.assertEquals(diff.size(), 2); + Assert.assertEquals(diff3.size(), 0); + Assert.assertEquals(diff4.get(0).getInstanceId(), "3"); + } - } + @Test + public void testStructMappingDiff() { + List structMappings = new ArrayList<>(); + structMappings.add(new StructMapping("t1", "v1", null)); + structMappings.add(new StructMapping("t1", "v1", null)); + structMappings.add(new StructMapping("t2", "v2", null)); + structMappings.add(new StructMapping("t2", "v2", null)); + List distinctList = + structMappings.stream().distinct().collect(Collectors.toList()); + List diff = ListUtil.diff(structMappings, distinctList); + Assert.assertEquals(diff.size(), 2); + } } diff --git a/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/GraphSchemaTranslateTest.java b/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/GraphSchemaTranslateTest.java index b2bf10c30..a19060221 100644 --- a/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/GraphSchemaTranslateTest.java +++ b/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/GraphSchemaTranslateTest.java @@ -33,56 +33,70 @@ public class GraphSchemaTranslateTest { - @Test - public void testGenerateCode() { + @Test + public void testGenerateCode() { - GeaflowVertex person = new GeaflowVertex("person", ""); - person.setId("1"); - person.addFields(Lists.newArrayList( + GeaflowVertex person = new GeaflowVertex("person", ""); + person.setId("1"); + person.addFields( + Lists.newArrayList( new GeaflowField("id", "", GeaflowFieldType.INT, GeaflowFieldCategory.VERTEX_ID), new GeaflowField("name", "", GeaflowFieldType.VARCHAR, GeaflowFieldCategory.PROPERTY), new GeaflowField("age", "", GeaflowFieldType.INT, GeaflowFieldCategory.PROPERTY))); - GeaflowVertex software = new GeaflowVertex("software", ""); - software.setId("2"); - software.addFields(Lists.newArrayList( + GeaflowVertex software = new GeaflowVertex("software", ""); + software.setId("2"); + software.addFields( + Lists.newArrayList( new GeaflowField("id", "", GeaflowFieldType.INT, GeaflowFieldCategory.VERTEX_ID), new GeaflowField("lang", "", GeaflowFieldType.VARCHAR, GeaflowFieldCategory.PROPERTY), new GeaflowField("price", "", GeaflowFieldType.INT, GeaflowFieldCategory.PROPERTY))); + GeaflowEdge knows = new GeaflowEdge("knows", ""); + knows.setId("3"); + knows.addFields( + Lists.newArrayList( + new GeaflowField( + "srcId", "", GeaflowFieldType.INT, GeaflowFieldCategory.EDGE_SOURCE_ID), + new GeaflowField( + "targetId", "", GeaflowFieldType.INT, GeaflowFieldCategory.EDGE_TARGET_ID), + new GeaflowField( + "weight", "", GeaflowFieldType.DOUBLE, GeaflowFieldCategory.PROPERTY))); - GeaflowEdge knows = new GeaflowEdge("knows", ""); - knows.setId("3"); - knows.addFields(Lists.newArrayList( - new GeaflowField("srcId", "", GeaflowFieldType.INT, GeaflowFieldCategory.EDGE_SOURCE_ID), - new GeaflowField("targetId", "", GeaflowFieldType.INT, GeaflowFieldCategory.EDGE_TARGET_ID), - new GeaflowField("weight", "", GeaflowFieldType.DOUBLE, GeaflowFieldCategory.PROPERTY))); + GeaflowEdge creates = new GeaflowEdge("creates", ""); + creates.setId("4"); + creates.addFields( + Lists.newArrayList( + new GeaflowField( + "srcId", "", GeaflowFieldType.INT, GeaflowFieldCategory.EDGE_SOURCE_ID), + new GeaflowField( + "targetId", "", GeaflowFieldType.INT, GeaflowFieldCategory.EDGE_TARGET_ID), + new GeaflowField( + "weight", "", GeaflowFieldType.DOUBLE, GeaflowFieldCategory.PROPERTY))); - GeaflowEdge creates = new GeaflowEdge("creates", ""); - creates.setId("4"); - creates.addFields(Lists.newArrayList( - new GeaflowField("srcId", "", GeaflowFieldType.INT, GeaflowFieldCategory.EDGE_SOURCE_ID), - new GeaflowField("targetId", "", GeaflowFieldType.INT, GeaflowFieldCategory.EDGE_TARGET_ID), - new GeaflowField("weight", "", GeaflowFieldType.DOUBLE, GeaflowFieldCategory.PROPERTY))); + GeaflowEdge uses = new GeaflowEdge("uses", ""); + uses.setId("5"); + uses.addFields( + Lists.newArrayList( + new GeaflowField( + "srcId", "", GeaflowFieldType.INT, GeaflowFieldCategory.EDGE_SOURCE_ID), + new GeaflowField( + "targetId", "", GeaflowFieldType.INT, GeaflowFieldCategory.EDGE_TARGET_ID), + new GeaflowField( + "weight", "", GeaflowFieldType.DOUBLE, GeaflowFieldCategory.PROPERTY))); - GeaflowEdge uses = new GeaflowEdge("uses", ""); - uses.setId("5"); - uses.addFields(Lists.newArrayList( - new GeaflowField("srcId", "", GeaflowFieldType.INT, GeaflowFieldCategory.EDGE_SOURCE_ID), - new GeaflowField("targetId", "", GeaflowFieldType.INT, GeaflowFieldCategory.EDGE_TARGET_ID), - new GeaflowField("weight", "", GeaflowFieldType.DOUBLE, GeaflowFieldCategory.PROPERTY))); - - GeaflowGraph graph = new GeaflowGraph(); - graph.addVertex(person); - graph.addVertex(software); - graph.addEdge(knows); - graph.addEdge(creates); - graph.addEdge(uses); - graph.setEndpoints(Lists.newArrayList( - new GeaflowEndpoint("3", "1", "1"), - new GeaflowEndpoint("4", "1", "2"))); - String result = GraphSchemaTranslator.translateGraphSchema(graph); - Assert.assertEquals(result, "CREATE GRAPH g (\n" + GeaflowGraph graph = new GeaflowGraph(); + graph.addVertex(person); + graph.addVertex(software); + graph.addEdge(knows); + graph.addEdge(creates); + graph.addEdge(uses); + graph.setEndpoints( + Lists.newArrayList(new GeaflowEndpoint("3", "1", "1"), new GeaflowEndpoint("4", "1", "2"))); + String result = GraphSchemaTranslator.translateGraphSchema(graph); + Assert.assertEquals( + result, + "CREATE GRAPH g (\n" + " Vertex person (\n" + " id INT ID,\n" + " name VARCHAR,\n" @@ -109,5 +123,5 @@ public void testGenerateCode() { + " weight DOUBLE\n" + " )\n" + ");"); - } + } } diff --git a/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/IntegrationTest.java b/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/IntegrationTest.java index a6fb8fdf8..854ae75d3 100644 --- a/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/IntegrationTest.java +++ b/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/IntegrationTest.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.console.biz.shared.convert.JobViewConverter; import org.apache.geaflow.console.biz.shared.view.JobView; import org.apache.geaflow.console.common.util.type.GeaflowJobType; @@ -34,9 +35,10 @@ public class IntegrationTest { - @Test - public void testGenerateCode() { - String structString = "[\n" + @Test + public void testGenerateCode() { + String structString = + "[\n" + " {\n" + " \"tableName\": \"t1\",\n" + " \"structName\": \"v1\",\n" @@ -66,18 +68,21 @@ public void testGenerateCode() { + " }\n" + "]"; - JobView jobView = new JobView(); - jobView.setStructMappings(structString); - jobView.setType(GeaflowJobType.INTEGRATE); - JobViewConverter jobViewConverter = new JobViewConverter(); - GeaflowJob job = jobViewConverter.convert(jobView, new ArrayList<>(), new ArrayList<>(), null, null); - List structMappings = job.getStructMappings(); - Assert.assertEquals(structMappings.get(0).getFieldMappings().size(), 2); - Assert.assertEquals(structMappings.get(1).getFieldMappings().size(), 3); + JobView jobView = new JobView(); + jobView.setStructMappings(structString); + jobView.setType(GeaflowJobType.INTEGRATE); + JobViewConverter jobViewConverter = new JobViewConverter(); + GeaflowJob job = + jobViewConverter.convert(jobView, new ArrayList<>(), new ArrayList<>(), null, null); + List structMappings = job.getStructMappings(); + Assert.assertEquals(structMappings.get(0).getFieldMappings().size(), 2); + Assert.assertEquals(structMappings.get(1).getFieldMappings().size(), 3); - job.setGraph(Lists.newArrayList(new GeaflowGraph("g1", null))); - String text = ((GeaflowIntegrateJob) job).generateCode().getText(); - Assert.assertEquals(text, "USE GRAPH g1;\n" + job.setGraph(Lists.newArrayList(new GeaflowGraph("g1", null))); + String text = ((GeaflowIntegrateJob) job).generateCode().getText(); + Assert.assertEquals( + text, + "USE GRAPH g1;\n" + "\n" + "insert into g1(\n" + " v1.v_id,\n" @@ -105,7 +110,5 @@ public void testGenerateCode() { + " t_targetId\n" + "from edgeTable;\n" + " \n"); - } - - + } } diff --git a/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/JobPlanTest.java b/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/JobPlanTest.java index 75d721f58..a1e5cf6ea 100644 --- a/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/JobPlanTest.java +++ b/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/JobPlanTest.java @@ -19,18 +19,20 @@ package org.apache.geaflow.console.test; -import com.alibaba.fastjson.JSON; import org.apache.geaflow.console.common.service.integration.engine.JsonPlan; import org.apache.geaflow.console.core.model.release.JobPlan; import org.apache.geaflow.console.core.model.release.JobPlanBuilder; import org.testng.Assert; import org.testng.annotations.Test; +import com.alibaba.fastjson.JSON; + public class JobPlanTest { - @Test - public void test() { - String plan = "{\"vertices\":{\"1\":{\"vertexType\":\"source\",\"id\":\"1\",\"parallelism\":1,\"parents\":[]," + @Test + public void test() { + String plan = + "{\"vertices\":{\"1\":{\"vertexType\":\"source\",\"id\":\"1\",\"parallelism\":1,\"parents\":[]," + "\"innerPlan\":{\"vertices\":{\"1-1\":{\"id\":\"1-1\",\"parallelism\":1,\"operator\":\"WindowStreamSourceOperator\"," + "\"operatorName\":\"1\",\"parents\":[]},\"1-4\":{\"id\":\"1-4\",\"parallelism\":1," + "\"operator\":\"KeySelectorStreamOperator\",\"operatorName\":\"4\",\"parents\":[{\"id\":\"1-1\"}]}}}}," @@ -48,11 +50,10 @@ public void test() { + "\"6-6\":{\"id\":\"6-6\",\"parallelism\":2,\"operator\":\"FlatMapStreamOperator\"," + "\"operatorName\":\"TraversalResponseToRow-0\",\"parents\":[]}}}}}}"; - JsonPlan jsonPlan = JSON.parseObject(plan, JsonPlan.class); - JobPlan jobPlan = JobPlanBuilder.build(jsonPlan); - Assert.assertEquals(jobPlan.getVertices().size(), 4); - Assert.assertEquals(jobPlan.getVertices().get("vertex_centric-3").getParallelism(), 2); - Assert.assertEquals(jobPlan.getEdgeMap().get("source-1").get(0).getSourceKey(), "source-1"); - } - + JsonPlan jsonPlan = JSON.parseObject(plan, JsonPlan.class); + JobPlan jobPlan = JobPlanBuilder.build(jsonPlan); + Assert.assertEquals(jobPlan.getVertices().size(), 4); + Assert.assertEquals(jobPlan.getVertices().get("vertex_centric-3").getParallelism(), 2); + Assert.assertEquals(jobPlan.getEdgeMap().get("source-1").get(0).getSourceKey(), "source-1"); + } } diff --git a/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/LLMClientTest.java b/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/LLMClientTest.java index 489604721..3af27507d 100644 --- a/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/LLMClientTest.java +++ b/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/LLMClientTest.java @@ -19,7 +19,6 @@ package org.apache.geaflow.console.test; -import lombok.extern.slf4j.Slf4j; import org.apache.geaflow.console.common.util.exception.GeaflowException; import org.apache.geaflow.console.core.model.config.GeaflowConfig; import org.apache.geaflow.console.core.model.llm.CodefuseConfigArgsClass; @@ -32,62 +31,61 @@ import org.testng.Assert; import org.testng.annotations.Test; +import lombok.extern.slf4j.Slf4j; + @Slf4j public class LLMClientTest { - @Test(enabled = false) - public void testLocal() { - LLMClient llmClient = LocalClient.getInstance(); - - GeaflowLLM geaflowLLM = new GeaflowLLM(); - geaflowLLM.setUrl("http://127.0.0.1:8000/completion"); - String answer = llmClient.call(geaflowLLM, "找出小红的10个朋友?"); - log.info(answer); - } + @Test(enabled = false) + public void testLocal() { + LLMClient llmClient = LocalClient.getInstance(); - @Test(enabled = false) - public void testOpenAI() { + GeaflowLLM geaflowLLM = new GeaflowLLM(); + geaflowLLM.setUrl("http://127.0.0.1:8000/completion"); + String answer = llmClient.call(geaflowLLM, "找出小红的10个朋友?"); + log.info(answer); + } - LLMClient llmClient = OpenAiClient.getInstance(); + @Test(enabled = false) + public void testOpenAI() { - GeaflowLLM geaflowLLM = new GeaflowLLM(); - geaflowLLM.setUrl("https://api.openai.com/v1/chat/completions"); + LLMClient llmClient = OpenAiClient.getInstance(); - GeaflowConfig geaflowConfig = new GeaflowConfig(); + GeaflowLLM geaflowLLM = new GeaflowLLM(); + geaflowLLM.setUrl("https://api.openai.com/v1/chat/completions"); - OpenAIConfigArgsClass configArgsClass = new OpenAIConfigArgsClass(); - // set own sk - String sk = "sk-xxxx"; - configArgsClass.setApiKey(sk); + GeaflowConfig geaflowConfig = new GeaflowConfig(); - Assert.assertThrows(GeaflowException.class, () -> geaflowConfig.parse(OpenAIConfigArgsClass.class)); + OpenAIConfigArgsClass configArgsClass = new OpenAIConfigArgsClass(); + // set own sk + String sk = "sk-xxxx"; + configArgsClass.setApiKey(sk); - configArgsClass.setModelId("ft:gpt-3.5-turbo-1106:personal:geaflow:8zLbe4Ua"); - geaflowLLM.setArgs(configArgsClass.build()); - - for (int i = 0; i < 5; i++) { - String answer = llmClient.call(geaflowLLM, "找出小红的10个朋友?"); - log.info(answer); - } + Assert.assertThrows( + GeaflowException.class, () -> geaflowConfig.parse(OpenAIConfigArgsClass.class)); + configArgsClass.setModelId("ft:gpt-3.5-turbo-1106:personal:geaflow:8zLbe4Ua"); + geaflowLLM.setArgs(configArgsClass.build()); + for (int i = 0; i < 5; i++) { + String answer = llmClient.call(geaflowLLM, "找出小红的10个朋友?"); + log.info(answer); } + } - @Test(enabled = false) - public void testCodefuse() { - LLMClient llmClient = CodefuseClient.getInstance(); - - GeaflowLLM geaflowLLM = new GeaflowLLM(); - geaflowLLM.setUrl("https://riskautopilot-pre.alipay.com/v1/gpt/codegpt/task"); + @Test(enabled = false) + public void testCodefuse() { + LLMClient llmClient = CodefuseClient.getInstance(); - CodefuseConfigArgsClass configArgsClass = new CodefuseConfigArgsClass(); - configArgsClass.setChainName("v18"); - configArgsClass.setSceneName("codegpt_single_finetune_v18"); - geaflowLLM.setArgs(configArgsClass.build()); - - String answer = llmClient.call(geaflowLLM, "找出小红的10个朋友?"); - log.info(answer); - } + GeaflowLLM geaflowLLM = new GeaflowLLM(); + geaflowLLM.setUrl("https://riskautopilot-pre.alipay.com/v1/gpt/codegpt/task"); + CodefuseConfigArgsClass configArgsClass = new CodefuseConfigArgsClass(); + configArgsClass.setChainName("v18"); + configArgsClass.setSceneName("codegpt_single_finetune_v18"); + geaflowLLM.setArgs(configArgsClass.build()); + String answer = llmClient.call(geaflowLLM, "找出小红的10个朋友?"); + log.info(answer); + } } diff --git a/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/LocalClientTest.java b/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/LocalClientTest.java index 5bfe8a15c..ce2fc6d57 100644 --- a/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/LocalClientTest.java +++ b/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/LocalClientTest.java @@ -19,114 +19,113 @@ package org.apache.geaflow.console.test; -import com.alibaba.fastjson.JSON; -import com.alibaba.fastjson.JSONObject; import org.apache.geaflow.console.core.model.llm.LocalConfigArgsClass; import org.apache.geaflow.console.core.service.llm.LocalClient; import org.testng.Assert; import org.testng.annotations.Test; +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONObject; + public class LocalClientTest { - @Test - public void testGetJsonString_WithPredictValue() { - LocalClient client = (LocalClient) LocalClient.getInstance(); - LocalConfigArgsClass config = new LocalConfigArgsClass(); - config.setPredict(256); - - String prompt = "Test prompt"; - String jsonString = client.getJsonString(config, prompt); - - JSONObject json = JSON.parseObject(jsonString); - Assert.assertEquals(json.getString("prompt"), "Test prompt"); - Assert.assertEquals(json.getInteger("n_predict").intValue(), 256); - } - - @Test - public void testGetJsonString_WithNullPredict() { - LocalClient client = (LocalClient) LocalClient.getInstance(); - LocalConfigArgsClass config = new LocalConfigArgsClass(); - config.setPredict(null); - - String prompt = "Test prompt"; - String jsonString = client.getJsonString(config, prompt); - - JSONObject json = JSON.parseObject(jsonString); - Assert.assertEquals(json.getString("prompt"), "Test prompt"); - // Should use default value 128 when predict is null - Assert.assertEquals(json.getInteger("n_predict").intValue(), 128); - } - - @Test - public void testGetJsonString_WithPromptContainingSpecialCharacters() { - LocalClient client = (LocalClient) LocalClient.getInstance(); - LocalConfigArgsClass config = new LocalConfigArgsClass(); - config.setPredict(100); - - // Test with special characters that need JSON escaping - String prompt = "Test \"quoted\" prompt\nwith newline\tand tab"; - String jsonString = client.getJsonString(config, prompt); - - // Should be valid JSON and properly escaped - JSONObject json = JSON.parseObject(jsonString); - Assert.assertNotNull(json); - Assert.assertEquals(json.getInteger("n_predict").intValue(), 100); - - // The prompt should be properly escaped in JSON string (check raw JSON string) - // In the raw JSON string, quotes should be escaped as \" - Assert.assertTrue(jsonString.contains("\\\""), - "JSON string should contain escaped quotes"); - - // Verify the parsed prompt matches the original (special characters preserved) - String parsedPrompt = json.getString("prompt"); - Assert.assertNotNull(parsedPrompt); - Assert.assertEquals(parsedPrompt, prompt, - "Parsed prompt should match original prompt with special characters"); - } - - @Test - public void testGetJsonString_WithPromptTrim() { - LocalClient client = (LocalClient) LocalClient.getInstance(); - LocalConfigArgsClass config = new LocalConfigArgsClass(); - config.setPredict(200); - - // Test with trailing newline that should be trimmed - String prompt = "Test prompt\n"; - String jsonString = client.getJsonString(config, prompt); - - JSONObject json = JSON.parseObject(jsonString); - Assert.assertEquals(json.getString("prompt"), "Test prompt"); - Assert.assertEquals(json.getInteger("n_predict").intValue(), 200); - } - - @Test - public void testGetJsonString_WithZeroPredict() { - LocalClient client = (LocalClient) LocalClient.getInstance(); - LocalConfigArgsClass config = new LocalConfigArgsClass(); - config.setPredict(0); - - String prompt = "Test prompt"; - String jsonString = client.getJsonString(config, prompt); - - JSONObject json = JSON.parseObject(jsonString); - Assert.assertEquals(json.getString("prompt"), "Test prompt"); - Assert.assertEquals(json.getInteger("n_predict").intValue(), 0); - } - - @Test - public void testGetJsonString_JsonStructure() { - LocalClient client = (LocalClient) LocalClient.getInstance(); - LocalConfigArgsClass config = new LocalConfigArgsClass(); - config.setPredict(128); - - String prompt = "Test"; - String jsonString = client.getJsonString(config, prompt); - - // Verify JSON structure - JSONObject json = JSON.parseObject(jsonString); - Assert.assertTrue(json.containsKey("prompt")); - Assert.assertTrue(json.containsKey("n_predict")); - Assert.assertEquals(json.size(), 2); // Should only have these two fields - } -} + @Test + public void testGetJsonString_WithPredictValue() { + LocalClient client = (LocalClient) LocalClient.getInstance(); + LocalConfigArgsClass config = new LocalConfigArgsClass(); + config.setPredict(256); + + String prompt = "Test prompt"; + String jsonString = client.getJsonString(config, prompt); + + JSONObject json = JSON.parseObject(jsonString); + Assert.assertEquals(json.getString("prompt"), "Test prompt"); + Assert.assertEquals(json.getInteger("n_predict").intValue(), 256); + } + + @Test + public void testGetJsonString_WithNullPredict() { + LocalClient client = (LocalClient) LocalClient.getInstance(); + LocalConfigArgsClass config = new LocalConfigArgsClass(); + config.setPredict(null); + + String prompt = "Test prompt"; + String jsonString = client.getJsonString(config, prompt); + + JSONObject json = JSON.parseObject(jsonString); + Assert.assertEquals(json.getString("prompt"), "Test prompt"); + // Should use default value 128 when predict is null + Assert.assertEquals(json.getInteger("n_predict").intValue(), 128); + } + + @Test + public void testGetJsonString_WithPromptContainingSpecialCharacters() { + LocalClient client = (LocalClient) LocalClient.getInstance(); + LocalConfigArgsClass config = new LocalConfigArgsClass(); + config.setPredict(100); + + // Test with special characters that need JSON escaping + String prompt = "Test \"quoted\" prompt\nwith newline\tand tab"; + String jsonString = client.getJsonString(config, prompt); + // Should be valid JSON and properly escaped + JSONObject json = JSON.parseObject(jsonString); + Assert.assertNotNull(json); + Assert.assertEquals(json.getInteger("n_predict").intValue(), 100); + + // The prompt should be properly escaped in JSON string (check raw JSON string) + // In the raw JSON string, quotes should be escaped as \" + Assert.assertTrue(jsonString.contains("\\\""), "JSON string should contain escaped quotes"); + + // Verify the parsed prompt matches the original (special characters preserved) + String parsedPrompt = json.getString("prompt"); + Assert.assertNotNull(parsedPrompt); + Assert.assertEquals( + parsedPrompt, prompt, "Parsed prompt should match original prompt with special characters"); + } + + @Test + public void testGetJsonString_WithPromptTrim() { + LocalClient client = (LocalClient) LocalClient.getInstance(); + LocalConfigArgsClass config = new LocalConfigArgsClass(); + config.setPredict(200); + + // Test with trailing newline that should be trimmed + String prompt = "Test prompt\n"; + String jsonString = client.getJsonString(config, prompt); + + JSONObject json = JSON.parseObject(jsonString); + Assert.assertEquals(json.getString("prompt"), "Test prompt"); + Assert.assertEquals(json.getInteger("n_predict").intValue(), 200); + } + + @Test + public void testGetJsonString_WithZeroPredict() { + LocalClient client = (LocalClient) LocalClient.getInstance(); + LocalConfigArgsClass config = new LocalConfigArgsClass(); + config.setPredict(0); + + String prompt = "Test prompt"; + String jsonString = client.getJsonString(config, prompt); + + JSONObject json = JSON.parseObject(jsonString); + Assert.assertEquals(json.getString("prompt"), "Test prompt"); + Assert.assertEquals(json.getInteger("n_predict").intValue(), 0); + } + + @Test + public void testGetJsonString_JsonStructure() { + LocalClient client = (LocalClient) LocalClient.getInstance(); + LocalConfigArgsClass config = new LocalConfigArgsClass(); + config.setPredict(128); + + String prompt = "Test"; + String jsonString = client.getJsonString(config, prompt); + + // Verify JSON structure + JSONObject json = JSON.parseObject(jsonString); + Assert.assertTrue(json.containsKey("prompt")); + Assert.assertTrue(json.containsKey("n_predict")); + Assert.assertEquals(json.size(), 2); // Should only have these two fields + } +} diff --git a/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/PluginConfigViewConverterTest.java b/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/PluginConfigViewConverterTest.java index 61298a1e6..a58643327 100644 --- a/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/PluginConfigViewConverterTest.java +++ b/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/PluginConfigViewConverterTest.java @@ -20,8 +20,7 @@ package org.apache.geaflow.console.test; import java.util.Date; -import lombok.Getter; -import lombok.Setter; + import org.apache.geaflow.console.biz.shared.convert.PluginConfigViewConverter; import org.apache.geaflow.console.biz.shared.view.PluginConfigView; import org.apache.geaflow.console.common.util.type.GeaflowPluginType; @@ -35,63 +34,66 @@ import org.testng.annotations.BeforeTest; import org.testng.annotations.Test; -public class PluginConfigViewConverterTest { - - PluginConfigViewConverter pluginConfigViewConverter = new PluginConfigViewConverter(); - - @BeforeTest - public void setUp() throws Exception { - ConfigDescFactory.getOrRegister(TestPluginConfigView.class); - } - - @Test - public void test() { - GeaflowConfig geaflowConfig = new GeaflowConfig(); - geaflowConfig.put("key.boolean", null); - geaflowConfig.put("key.long", 1234L); - geaflowConfig.put("key.double", 3.14); - geaflowConfig.put("key.string", "hello"); - geaflowConfig.put("a", 1); - geaflowConfig.put("b", "xxx"); - - PluginConfigView view = new PluginConfigView(); - view.setId("xxxx"); - view.setName("test-plugin"); - view.setComment("test plugin"); - view.setType(GeaflowPluginType.MEMORY.name()); - view.setConfig(geaflowConfig); - - GeaflowPluginConfig model = pluginConfigViewConverter.convert(view); - - model.setGmtCreate(new Date()); - model.setGmtModified(new Date()); - PluginConfigView newView = pluginConfigViewConverter.convert(model); - - Assert.assertEquals(view.getConfig().size(), 6); - Assert.assertEquals(newView.getConfig().size(), 6); - } - - @Getter - @Setter - public static class TestPluginConfigView extends PluginConfigClass { - - @GeaflowConfigKey(value = "key.boolean") - private Boolean booleanField; - - @GeaflowConfigKey(value = "key.long") - @GeaflowConfigValue(required = true) - private Long longField; - - @GeaflowConfigKey("key.double") - @GeaflowConfigValue(defaultValue = "3.14") - private Double doubleField; +import lombok.Getter; +import lombok.Setter; - @GeaflowConfigKey(value = "key.string", comment = "String Value") - @GeaflowConfigValue(required = true, defaultValue = "stringValue", masked = true) - private String stringField; +public class PluginConfigViewConverterTest { - public TestPluginConfigView() { - super(GeaflowPluginType.MEMORY); - } + PluginConfigViewConverter pluginConfigViewConverter = new PluginConfigViewConverter(); + + @BeforeTest + public void setUp() throws Exception { + ConfigDescFactory.getOrRegister(TestPluginConfigView.class); + } + + @Test + public void test() { + GeaflowConfig geaflowConfig = new GeaflowConfig(); + geaflowConfig.put("key.boolean", null); + geaflowConfig.put("key.long", 1234L); + geaflowConfig.put("key.double", 3.14); + geaflowConfig.put("key.string", "hello"); + geaflowConfig.put("a", 1); + geaflowConfig.put("b", "xxx"); + + PluginConfigView view = new PluginConfigView(); + view.setId("xxxx"); + view.setName("test-plugin"); + view.setComment("test plugin"); + view.setType(GeaflowPluginType.MEMORY.name()); + view.setConfig(geaflowConfig); + + GeaflowPluginConfig model = pluginConfigViewConverter.convert(view); + + model.setGmtCreate(new Date()); + model.setGmtModified(new Date()); + PluginConfigView newView = pluginConfigViewConverter.convert(model); + + Assert.assertEquals(view.getConfig().size(), 6); + Assert.assertEquals(newView.getConfig().size(), 6); + } + + @Getter + @Setter + public static class TestPluginConfigView extends PluginConfigClass { + + @GeaflowConfigKey(value = "key.boolean") + private Boolean booleanField; + + @GeaflowConfigKey(value = "key.long") + @GeaflowConfigValue(required = true) + private Long longField; + + @GeaflowConfigKey("key.double") + @GeaflowConfigValue(defaultValue = "3.14") + private Double doubleField; + + @GeaflowConfigKey(value = "key.string", comment = "String Value") + @GeaflowConfigValue(required = true, defaultValue = "stringValue", masked = true) + private String stringField; + + public TestPluginConfigView() { + super(GeaflowPluginType.MEMORY); } + } } diff --git a/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/ZipTest.java b/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/ZipTest.java index 165601b85..525eecab8 100644 --- a/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/ZipTest.java +++ b/geaflow-console/app/test/src/test/java/org/apache/geaflow/console/test/ZipTest.java @@ -19,8 +19,6 @@ package org.apache.geaflow.console.test; -import com.alibaba.fastjson.JSON; -import com.alibaba.fastjson.JSONObject; import java.io.ByteArrayOutputStream; import java.io.File; import java.io.IOException; @@ -37,7 +35,7 @@ import java.util.UUID; import java.util.zip.ZipEntry; import java.util.zip.ZipInputStream; -import lombok.extern.slf4j.Slf4j; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.console.common.util.Fmt; import org.apache.geaflow.console.common.util.ZipUtil; @@ -47,125 +45,161 @@ import org.testng.Assert; import org.testng.annotations.Test; +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONObject; + +import lombok.extern.slf4j.Slf4j; + @Slf4j public class ZipTest { - @Test - void unzip() throws IOException { - String testPath = "/tmp/geaflow-zipTest-" + UUID.randomUUID() + "/"; - String zipPath = testPath + "test-zip.zip"; - String gqlPath = testPath + "user.gql"; - String confPath = testPath + "user.conf"; - try { - FileUtils.forceMkdir(new File(testPath)); - HashMap map = new HashMap<>(); - map.put("aa", 1); - map.put("b", 3); - map.put("c", 4); - GeaflowZipEntry userGql = new MemoryZipEntry("user.gql", code); - GeaflowZipEntry conf = new MemoryZipEntry("user.conf", JSON.toJSONString(map)); - - InputStream inputStream = ZipUtil.buildZipInputStream(Arrays.asList(userGql, conf)); - FileUtils.copyInputStreamToFile(inputStream, new File(zipPath)); - - ZipUtil.unzip(new File(zipPath)); - - File confFile = new File(confPath); - String confString = FileUtils.readFileToString(confFile, Charset.defaultCharset()); - JSONObject jsonObject = JSON.parseObject(confString); - Assert.assertEquals(jsonObject.getInteger("aa").intValue(), 1); - - File gqlFile = new File(gqlPath); - String gqlString = FileUtils.readFileToString(gqlFile, Charset.defaultCharset()); - Assert.assertEquals(gqlString, code); - } finally { - FileUtils.deleteDirectory(new File(testPath)); - } + @Test + void unzip() throws IOException { + String testPath = "/tmp/geaflow-zipTest-" + UUID.randomUUID() + "/"; + String zipPath = testPath + "test-zip.zip"; + String gqlPath = testPath + "user.gql"; + String confPath = testPath + "user.conf"; + try { + FileUtils.forceMkdir(new File(testPath)); + HashMap map = new HashMap<>(); + map.put("aa", 1); + map.put("b", 3); + map.put("c", 4); + GeaflowZipEntry userGql = new MemoryZipEntry("user.gql", code); + GeaflowZipEntry conf = new MemoryZipEntry("user.conf", JSON.toJSONString(map)); + + InputStream inputStream = ZipUtil.buildZipInputStream(Arrays.asList(userGql, conf)); + FileUtils.copyInputStreamToFile(inputStream, new File(zipPath)); + + ZipUtil.unzip(new File(zipPath)); + + File confFile = new File(confPath); + String confString = FileUtils.readFileToString(confFile, Charset.defaultCharset()); + JSONObject jsonObject = JSON.parseObject(confString); + Assert.assertEquals(jsonObject.getInteger("aa").intValue(), 1); + + File gqlFile = new File(gqlPath); + String gqlString = FileUtils.readFileToString(gqlFile, Charset.defaultCharset()); + Assert.assertEquals(gqlString, code); + } finally { + FileUtils.deleteDirectory(new File(testPath)); } - - @Test - public void test() throws IOException { - String path = Fmt.as("/tmp/test-zip-{}.zip", System.currentTimeMillis()); - try { - HashMap map = new HashMap<>(); - map.put("aa", 1); - map.put("b", 3); - map.put("c", 4); - GeaflowZipEntry userGql = new MemoryZipEntry("user.gql", code); - GeaflowZipEntry conf = new MemoryZipEntry("user.conf", JSON.toJSONString(map)); - - InputStream inputStream = ZipUtil.buildZipInputStream(Arrays.asList(userGql, conf)); - FileUtils.copyInputStreamToFile(inputStream, new File(path)); - inputStream.close(); - - inputStream = Files.newInputStream(Paths.get(path)); - ZipInputStream zipInputStream = new ZipInputStream(inputStream); - ZipEntry entry; - while ((entry = zipInputStream.getNextEntry()) != null) { - ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); - byte[] byte_s = new byte[1024]; - int num; - while ((num = zipInputStream.read(byte_s, 0, byte_s.length)) > -1) { - byteArrayOutputStream.write(byte_s, 0, num); - - } - String s = byteArrayOutputStream.toString(); - - log.info(s); - zipInputStream.closeEntry(); - byteArrayOutputStream.close(); - if (entry.getName().equals("user.conf")) { - JSONObject json = JSON.parseObject(s); - Assert.assertEquals(json.get("aa"), 1); - Assert.assertEquals(json.get("b"), 3); - } - if (entry.getName().equals("user.gql")) { - Assert.assertEquals(s, code); - } - } - } finally { - File file = new File(path); - file.delete(); + } + + @Test + public void test() throws IOException { + String path = Fmt.as("/tmp/test-zip-{}.zip", System.currentTimeMillis()); + try { + HashMap map = new HashMap<>(); + map.put("aa", 1); + map.put("b", 3); + map.put("c", 4); + GeaflowZipEntry userGql = new MemoryZipEntry("user.gql", code); + GeaflowZipEntry conf = new MemoryZipEntry("user.conf", JSON.toJSONString(map)); + + InputStream inputStream = ZipUtil.buildZipInputStream(Arrays.asList(userGql, conf)); + FileUtils.copyInputStreamToFile(inputStream, new File(path)); + inputStream.close(); + + inputStream = Files.newInputStream(Paths.get(path)); + ZipInputStream zipInputStream = new ZipInputStream(inputStream); + ZipEntry entry; + while ((entry = zipInputStream.getNextEntry()) != null) { + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + byte[] byte_s = new byte[1024]; + int num; + while ((num = zipInputStream.read(byte_s, 0, byte_s.length)) > -1) { + byteArrayOutputStream.write(byte_s, 0, num); + } + String s = byteArrayOutputStream.toString(); + + log.info(s); + zipInputStream.closeEntry(); + byteArrayOutputStream.close(); + if (entry.getName().equals("user.conf")) { + JSONObject json = JSON.parseObject(s); + Assert.assertEquals(json.get("aa"), 1); + Assert.assertEquals(json.get("b"), 3); } + if (entry.getName().equals("user.gql")) { + Assert.assertEquals(s, code); + } + } + } finally { + File file = new File(path); + file.delete(); } + } - @Test - public void testMultiFiles() throws IOException, URISyntaxException { - String path = Fmt.as("/tmp/ziptest/test-zip-{}.zip", System.currentTimeMillis()); - try { - URL src1 = getClass().getClassLoader().getResource("zip_test.txt"); - - String entry1 = "dir/zip_test.txt"; + @Test + public void testMultiFiles() throws IOException, URISyntaxException { + String path = Fmt.as("/tmp/ziptest/test-zip-{}.zip", System.currentTimeMillis()); + try { + URL src1 = getClass().getClassLoader().getResource("zip_test.txt"); - URL src2 = getClass().getClassLoader().getResource("zip_test2.txt"); - String entry2 = "dir/zip_test2.txt"; + String entry1 = "dir/zip_test.txt"; + URL src2 = getClass().getClassLoader().getResource("zip_test2.txt"); + String entry2 = "dir/zip_test2.txt"; - List entries = new ArrayList<>(); - entries.add(new FileZipEntry(entry1, new File(src1.toURI()))); - entries.add(new FileZipEntry(entry2, new File(src2.toURI()))); - InputStream stream = ZipUtil.buildZipInputStream(entries); - - FileUtils.copyInputStreamToFile(stream, new File(path)); - } finally { - File file = new File(path); - file.delete(); - } + List entries = new ArrayList<>(); + entries.add(new FileZipEntry(entry1, new File(src1.toURI()))); + entries.add(new FileZipEntry(entry2, new File(src2.toURI()))); + InputStream stream = ZipUtil.buildZipInputStream(entries); + FileUtils.copyInputStreamToFile(stream, new File(path)); + } finally { + File file = new File(path); + file.delete(); } - - String code = "CREATE GRAPH modern (\n" + "\tVertex person (\n" + "\t id bigint ID,\n" + "\t name varchar,\n" - + "\t age int\n" + "\t),\n" + "\tVertex software (\n" + "\t id bigint ID,\n" + "\t name varchar,\n" - + "\t lang varchar\n" + "\t),\n" + "\tEdge knows (\n" + "\t srcId bigint SOURCE ID,\n" - + "\t targetId bigint DESTINATION ID,\n" + "\t weight double\n" + "\t),\n" + "\tEdge created (\n" - + "\t srcId bigint SOURCE ID,\n" + " \ttargetId bigint DESTINATION ID,\n" + " \tweight double\n" + "\t)\n" - + ") WITH (\n" + "\tstoreType='rocksdb',\n" - + "\tgeaflow.dsl.using.vertex.path = 'resource:///data/modern_vertex.txt',\n" - + "\tgeaflow.dsl.using.edge.path = 'resource:///data/modern_edge.txt'\n" + ");\n" + "\n" - + "CREATE TABLE tbl_result (\n" + " a_id bigint,\n" + " weight double,\n" + " b_id bigint\n" + ") WITH (\n" - + "\ttype='file',\n" + "\tgeaflow.dsl.file.path='${target}'\n" + ");\n" + "\n" + "USE GRAPH modern;\n" + "\n" - + "INSERT INTO tbl_result\n" + "SELECT\n" + "\ta_id,\n" + "\tweight,\n" + "\tb_id\n" + "FROM (\n" - + " MATCH (a) -[e:knows]->(b:person where b.id != 1)\n" - + " RETURN a.id as a_id, e.weight as weight, b.id as b_id\n" + ")"; - + } + + String code = + "CREATE GRAPH modern (\n" + + "\tVertex person (\n" + + "\t id bigint ID,\n" + + "\t name varchar,\n" + + "\t age int\n" + + "\t),\n" + + "\tVertex software (\n" + + "\t id bigint ID,\n" + + "\t name varchar,\n" + + "\t lang varchar\n" + + "\t),\n" + + "\tEdge knows (\n" + + "\t srcId bigint SOURCE ID,\n" + + "\t targetId bigint DESTINATION ID,\n" + + "\t weight double\n" + + "\t),\n" + + "\tEdge created (\n" + + "\t srcId bigint SOURCE ID,\n" + + " \ttargetId bigint DESTINATION ID,\n" + + " \tweight double\n" + + "\t)\n" + + ") WITH (\n" + + "\tstoreType='rocksdb',\n" + + "\tgeaflow.dsl.using.vertex.path = 'resource:///data/modern_vertex.txt',\n" + + "\tgeaflow.dsl.using.edge.path = 'resource:///data/modern_edge.txt'\n" + + ");\n" + + "\n" + + "CREATE TABLE tbl_result (\n" + + " a_id bigint,\n" + + " weight double,\n" + + " b_id bigint\n" + + ") WITH (\n" + + "\ttype='file',\n" + + "\tgeaflow.dsl.file.path='${target}'\n" + + ");\n" + + "\n" + + "USE GRAPH modern;\n" + + "\n" + + "INSERT INTO tbl_result\n" + + "SELECT\n" + + "\ta_id,\n" + + "\tweight,\n" + + "\tb_id\n" + + "FROM (\n" + + " MATCH (a) -[e:knows]->(b:person where b.id != 1)\n" + + " RETURN a.id as a_id, e.weight as weight, b.id as b_id\n" + + ")"; } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/api/ErrorApiCorsConfigurer.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/api/ErrorApiCorsConfigurer.java index 2e7244ea3..896ad152a 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/api/ErrorApiCorsConfigurer.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/api/ErrorApiCorsConfigurer.java @@ -20,25 +20,26 @@ package org.apache.geaflow.console.web.api; import javax.servlet.http.HttpServletResponse; + import org.apache.geaflow.console.common.util.context.ContextHolder; import org.springframework.http.HttpHeaders; public class ErrorApiCorsConfigurer { - private ErrorApiCorsConfigurer() { - } + private ErrorApiCorsConfigurer() {} - public static void configure(HttpServletResponse response) { - String originKey = HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN; - String credentialsKey = HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS; - String headersKey = HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS; - String methodsKey = HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS; - String ageKey = HttpHeaders.ACCESS_CONTROL_MAX_AGE; + public static void configure(HttpServletResponse response) { + String originKey = HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN; + String credentialsKey = HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS; + String headersKey = HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS; + String methodsKey = HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS; + String ageKey = HttpHeaders.ACCESS_CONTROL_MAX_AGE; - response.setHeader(originKey, ContextHolder.get().getRequest().getHeader(HttpHeaders.ORIGIN)); - response.setHeader(methodsKey, "OPTIONS,HEAD,GET,POST,PUT,PATCH,DELETE,TRACE"); - response.setHeader(headersKey, "Origin,X-Requested-With,Content-Type,Accept,geaflow-token,geaflow-task-token"); - response.setHeader(credentialsKey, "true"); - response.setHeader(ageKey, "3600"); - } + response.setHeader(originKey, ContextHolder.get().getRequest().getHeader(HttpHeaders.ORIGIN)); + response.setHeader(methodsKey, "OPTIONS,HEAD,GET,POST,PUT,PATCH,DELETE,TRACE"); + response.setHeader( + headersKey, "Origin,X-Requested-With,Content-Type,Accept,geaflow-token,geaflow-task-token"); + response.setHeader(credentialsKey, "true"); + response.setHeader(ageKey, "3600"); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/api/ErrorApiResponse.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/api/ErrorApiResponse.java index 4c97371e3..ad7bd3bd4 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/api/ErrorApiResponse.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/api/ErrorApiResponse.java @@ -20,35 +20,38 @@ package org.apache.geaflow.console.web.api; import javax.servlet.http.HttpServletResponse; -import lombok.Getter; + import org.apache.geaflow.console.common.util.exception.GeaflowExceptionClassifier; import org.apache.geaflow.console.common.util.exception.GeaflowExceptionClassifier.GeaflowExceptionClassificationResult; +import lombok.Getter; + @Getter public class ErrorApiResponse extends GeaflowApiResponse { - private static final GeaflowExceptionClassifier EXCEPTION_CLASSIFIER = new GeaflowExceptionClassifier(); + private static final GeaflowExceptionClassifier EXCEPTION_CLASSIFIER = + new GeaflowExceptionClassifier(); - private final GeaflowApiRequest request; + private final GeaflowApiRequest request; - private final String message; + private final String message; - protected ErrorApiResponse(Throwable error) { - super(false); + protected ErrorApiResponse(Throwable error) { + super(false); - // Use classifier to classify the exception - GeaflowExceptionClassificationResult result = EXCEPTION_CLASSIFIER.classify(error); - this.code = result.getCode(); - this.message = result.getMessage(); + // Use classifier to classify the exception + GeaflowExceptionClassificationResult result = EXCEPTION_CLASSIFIER.classify(error); + this.code = result.getCode(); + this.message = result.getMessage(); - this.request = GeaflowApiRequest.currentRequest(); - } + this.request = GeaflowApiRequest.currentRequest(); + } - @Override - public void write(HttpServletResponse response) { - response.reset(); - // Use centralized CORS configuration - ErrorApiCorsConfigurer.configure(response); - super.write(response); - } + @Override + public void write(HttpServletResponse response) { + response.reset(); + // Use centralized CORS configuration + ErrorApiCorsConfigurer.configure(response); + super.write(response); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/api/GeaflowApiRequest.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/api/GeaflowApiRequest.java index 7a22cc47f..8edcc3040 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/api/GeaflowApiRequest.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/api/GeaflowApiRequest.java @@ -21,69 +21,70 @@ import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; -import lombok.Getter; -import lombok.Setter; + import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; +import lombok.Getter; +import lombok.Setter; + @Getter public class GeaflowApiRequest { - private static final String GEAFLOW_TOKEN_KEY = "geaflow-token"; + private static final String GEAFLOW_TOKEN_KEY = "geaflow-token"; - private String url; + private String url; - @Setter - private RequestMethod method; + @Setter private RequestMethod method; - @Setter - private T body; + @Setter private T body; - public static GeaflowApiRequest currentRequest() { - GeaflowApiRequest apiRequest = new GeaflowApiRequest<>(); + public static GeaflowApiRequest currentRequest() { + GeaflowApiRequest apiRequest = new GeaflowApiRequest<>(); - ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); - if (attributes != null) { - HttpServletRequest request = attributes.getRequest(); - apiRequest.url = request.getRequestURI(); - apiRequest.method = RequestMethod.valueOf(request.getMethod()); - } - return apiRequest; + ServletRequestAttributes attributes = + (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); + if (attributes != null) { + HttpServletRequest request = attributes.getRequest(); + apiRequest.url = request.getRequestURI(); + apiRequest.method = RequestMethod.valueOf(request.getMethod()); } + return apiRequest; + } - public static String getSessionToken(HttpServletRequest request) { - return getRequestParameter(request, GEAFLOW_TOKEN_KEY); - } + public static String getSessionToken(HttpServletRequest request) { + return getRequestParameter(request, GEAFLOW_TOKEN_KEY); + } - public static String getRequestParameter(HttpServletRequest request, String key) { - String arg = getUrlParameter(request, key); - if (arg == null) { - arg = getHeader(request, key); - } - if (arg == null) { - arg = getCookie(request, key); - } - return arg; + public static String getRequestParameter(HttpServletRequest request, String key) { + String arg = getUrlParameter(request, key); + if (arg == null) { + arg = getHeader(request, key); } - - public static String getUrlParameter(HttpServletRequest request, String key) { - return request.getParameter(key); + if (arg == null) { + arg = getCookie(request, key); } + return arg; + } - public static String getHeader(HttpServletRequest request, String key) { - return request.getHeader(key); - } + public static String getUrlParameter(HttpServletRequest request, String key) { + return request.getParameter(key); + } + + public static String getHeader(HttpServletRequest request, String key) { + return request.getHeader(key); + } - public static String getCookie(HttpServletRequest request, String key) { - Cookie[] cookies = request.getCookies(); - if (cookies != null) { - for (Cookie cookie : cookies) { - if (key.equals(cookie.getName())) { - return cookie.getValue(); - } - } + public static String getCookie(HttpServletRequest request, String key) { + Cookie[] cookies = request.getCookies(); + if (cookies != null) { + for (Cookie cookie : cookies) { + if (key.equals(cookie.getName())) { + return cookie.getValue(); } - return null; + } } + return null; + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/api/GeaflowApiResponse.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/api/GeaflowApiResponse.java index 1edc440b9..5aef97400 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/api/GeaflowApiResponse.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/api/GeaflowApiResponse.java @@ -19,49 +19,53 @@ package org.apache.geaflow.console.web.api; -import com.alibaba.fastjson.JSON; import java.io.PrintWriter; import java.nio.charset.StandardCharsets; + import javax.servlet.http.HttpServletResponse; -import lombok.Getter; + import org.apache.geaflow.console.common.util.NetworkUtil; import org.apache.geaflow.console.common.util.exception.GeaflowException; import org.apache.geaflow.console.common.util.type.GeaflowApiResponseCode; import org.springframework.http.MediaType; +import com.alibaba.fastjson.JSON; + +import lombok.Getter; + @Getter public abstract class GeaflowApiResponse { - private final boolean success; + private final boolean success; - private final String host; + private final String host; - protected GeaflowApiResponseCode code; + protected GeaflowApiResponseCode code; - protected GeaflowApiResponse(boolean success) { - this.host = NetworkUtil.getHostName(); - this.success = success; - } + protected GeaflowApiResponse(boolean success) { + this.host = NetworkUtil.getHostName(); + this.success = success; + } - public static GeaflowApiResponse success(T data) { - return new SuccessApiResponse<>(data); - } + public static GeaflowApiResponse success(T data) { + return new SuccessApiResponse<>(data); + } - public static GeaflowApiResponse error(Throwable error) { - return new ErrorApiResponse<>(error); - } + public static GeaflowApiResponse error(Throwable error) { + return new ErrorApiResponse<>(error); + } - public void write(HttpServletResponse response) { - try { - response.setContentType(MediaType.APPLICATION_JSON_VALUE); - response.setCharacterEncoding(StandardCharsets.UTF_8.toString()); - response.setStatus(GeaflowApiResponseCode.SUCCESS.getHttpCode()); - PrintWriter out = response.getWriter(); - JSON.writeJSONString(out, this); - out.flush(); + public void write(HttpServletResponse response) { + try { + response.setContentType(MediaType.APPLICATION_JSON_VALUE); + response.setCharacterEncoding(StandardCharsets.UTF_8.toString()); + response.setStatus(GeaflowApiResponseCode.SUCCESS.getHttpCode()); + PrintWriter out = response.getWriter(); + JSON.writeJSONString(out, this); + out.flush(); - } catch (Exception e) { - throw new GeaflowException("Write api response failed", e); - } + } catch (Exception e) { + throw new GeaflowException("Write api response failed", e); } + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/api/SuccessApiResponse.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/api/SuccessApiResponse.java index 13a2ff51c..30e977180 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/api/SuccessApiResponse.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/api/SuccessApiResponse.java @@ -19,18 +19,18 @@ package org.apache.geaflow.console.web.api; -import lombok.Getter; import org.apache.geaflow.console.common.util.type.GeaflowApiResponseCode; +import lombok.Getter; + @Getter public class SuccessApiResponse extends GeaflowApiResponse { - private final T data; - - protected SuccessApiResponse(T data) { - super(true); - this.code = GeaflowApiResponseCode.SUCCESS; - this.data = data; - } + private final T data; + protected SuccessApiResponse(T data) { + super(true); + this.code = GeaflowApiResponseCode.SUCCESS; + this.data = data; + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/aspect/ViewNameAspect.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/aspect/ViewNameAspect.java index 4f6c022be..d6aad3d10 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/aspect/ViewNameAspect.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/aspect/ViewNameAspect.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Map; import java.util.Set; + import org.apache.geaflow.console.biz.shared.TenantManager; import org.apache.geaflow.console.biz.shared.UserManager; import org.apache.geaflow.console.biz.shared.view.IdView; @@ -41,86 +42,87 @@ @Component public class ViewNameAspect { - @Autowired - private UserManager userManager; - - @Autowired - private TenantManager tenantManager; + @Autowired private UserManager userManager; - @AfterReturning(value = "execution(* org.apache.geaflow.console.web.controller..*Controller.*(..))", returning - = "response") - public void handle(JoinPoint joinPoint, Object response) throws Throwable { - if (response instanceof SuccessApiResponse) { - Object data = ((SuccessApiResponse) response).getData(); + @Autowired private TenantManager tenantManager; - if (data instanceof IdView) { - fillName(Collections.singleton(data)); + @AfterReturning( + value = "execution(* org.apache.geaflow.console.web.controller..*Controller.*(..))", + returning = "response") + public void handle(JoinPoint joinPoint, Object response) throws Throwable { + if (response instanceof SuccessApiResponse) { + Object data = ((SuccessApiResponse) response).getData(); - } else if (data instanceof Collection) { - fillName((Collection) data); + if (data instanceof IdView) { + fillName(Collections.singleton(data)); - } else if (data instanceof Map) { - fillName(((Map) data).values()); + } else if (data instanceof Collection) { + fillName((Collection) data); - } else if (data instanceof PageList) { - fillName(((PageList) data).getList()); - } - } - } + } else if (data instanceof Map) { + fillName(((Map) data).values()); - private void fillName(Collection collection) { - fillUserName(collection); - fillTenantName(collection); + } else if (data instanceof PageList) { + fillName(((PageList) data).getList()); + } } - - private void fillUserName(Collection collection) { - List creatorViews = new ArrayList<>(); - List modifierViews = new ArrayList<>(); - Set userIds = new HashSet<>(); - - collection.forEach(v -> { - if (v instanceof IdView) { - IdView idView = (IdView) v; - String creatorId = idView.getCreatorId(); - String modifierId = idView.getModifierId(); - - if (creatorId != null) { - userIds.add(creatorId); - creatorViews.add(idView); - } - if (modifierId != null) { - userIds.add(modifierId); - modifierViews.add(idView); - } + } + + private void fillName(Collection collection) { + fillUserName(collection); + fillTenantName(collection); + } + + private void fillUserName(Collection collection) { + List creatorViews = new ArrayList<>(); + List modifierViews = new ArrayList<>(); + Set userIds = new HashSet<>(); + + collection.forEach( + v -> { + if (v instanceof IdView) { + IdView idView = (IdView) v; + String creatorId = idView.getCreatorId(); + String modifierId = idView.getModifierId(); + + if (creatorId != null) { + userIds.add(creatorId); + creatorViews.add(idView); + } + if (modifierId != null) { + userIds.add(modifierId); + modifierViews.add(idView); } + } }); - if (!userIds.isEmpty()) { - Map userNames = userManager.getUserNames(userIds); - creatorViews.forEach(v -> v.setCreatorName(userNames.get(v.getCreatorId()))); - modifierViews.forEach(v -> v.setModifierName(userNames.get(v.getModifierId()))); - } + if (!userIds.isEmpty()) { + Map userNames = userManager.getUserNames(userIds); + creatorViews.forEach(v -> v.setCreatorName(userNames.get(v.getCreatorId()))); + modifierViews.forEach(v -> v.setModifierName(userNames.get(v.getModifierId()))); } + } - private void fillTenantName(Collection collection) { - List views = new ArrayList<>(); - Set tenantIds = new HashSet<>(); + private void fillTenantName(Collection collection) { + List views = new ArrayList<>(); + Set tenantIds = new HashSet<>(); - collection.forEach(v -> { - if (v instanceof IdView) { - IdView idView = (IdView) v; - String tenantId = idView.getTenantId(); + collection.forEach( + v -> { + if (v instanceof IdView) { + IdView idView = (IdView) v; + String tenantId = idView.getTenantId(); - if (tenantId != null) { - tenantIds.add(tenantId); - views.add(idView); - } + if (tenantId != null) { + tenantIds.add(tenantId); + views.add(idView); } + } }); - if (!tenantIds.isEmpty()) { - Map tenantNames = tenantManager.getTenantNames(tenantIds); - views.forEach(v -> v.setTenantName(tenantNames.get(v.getTenantId()))); - } + if (!tenantIds.isEmpty()) { + Map tenantNames = tenantManager.getTenantNames(tenantIds); + views.forEach(v -> v.setTenantName(tenantNames.get(v.getTenantId()))); } + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/AuthenticationController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/AuthenticationController.java index aa6252e90..8e448e4b2 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/AuthenticationController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/AuthenticationController.java @@ -36,27 +36,25 @@ @RequestMapping("/auth") public class AuthenticationController { - @Autowired - private AuthenticationManager authenticationManager; - - @Autowired - private UserManager userManager; - - @PostMapping("/register") - @ResponseBody - public GeaflowApiResponse register(@RequestBody UserView userView) { - String userId = userManager.register(userView); - return GeaflowApiResponse.success(userId); - } - - @PostMapping("/login") - @ResponseBody - public GeaflowApiResponse login(@RequestBody LoginView loginView) { - String loginName = loginView.getLoginName(); - String password = loginView.getPassword(); - boolean systemLogin = loginView.isSystemLogin(); - AuthenticationView authentication = authenticationManager.login(loginName, password, systemLogin); - return GeaflowApiResponse.success(authentication); - } - + @Autowired private AuthenticationManager authenticationManager; + + @Autowired private UserManager userManager; + + @PostMapping("/register") + @ResponseBody + public GeaflowApiResponse register(@RequestBody UserView userView) { + String userId = userManager.register(userView); + return GeaflowApiResponse.success(userId); + } + + @PostMapping("/login") + @ResponseBody + public GeaflowApiResponse login(@RequestBody LoginView loginView) { + String loginName = loginView.getLoginName(); + String password = loginView.getPassword(); + boolean systemLogin = loginView.isSystemLogin(); + AuthenticationView authentication = + authenticationManager.login(loginName, password, systemLogin); + return GeaflowApiResponse.success(authentication); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/HealthController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/HealthController.java index 53a630e63..f0b18ca5a 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/HealthController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/HealthController.java @@ -27,10 +27,9 @@ @Controller public class HealthController { - @GetMapping("/health") - @ResponseBody - public GeaflowApiResponse health() { - return GeaflowApiResponse.success(true); - } - + @GetMapping("/health") + @ResponseBody + public GeaflowApiResponse health() { + return GeaflowApiResponse.success(true); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/IndexController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/IndexController.java index 625162ec9..b15f2824d 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/IndexController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/IndexController.java @@ -25,10 +25,8 @@ @Controller public class IndexController { - @GetMapping("/") - public String index() { - return "index.html"; - } - + @GetMapping("/") + public String index() { + return "index.html"; + } } - \ No newline at end of file diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/AuditController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/AuditController.java index f1906a907..1bbcec9aa 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/AuditController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/AuditController.java @@ -30,22 +30,19 @@ import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; - @RestController @RequestMapping("/audits") public class AuditController { - @Autowired - private AuditManager auditManager; - - @GetMapping - public GeaflowApiResponse> searchAudits(AuditSearch search) { - return GeaflowApiResponse.success(auditManager.search(search)); - } + @Autowired private AuditManager auditManager; - @GetMapping("/{id}") - public GeaflowApiResponse queryAudit(@PathVariable String id) { - return GeaflowApiResponse.success(auditManager.get(id)); - } + @GetMapping + public GeaflowApiResponse> searchAudits(AuditSearch search) { + return GeaflowApiResponse.success(auditManager.search(search)); + } + @GetMapping("/{id}") + public GeaflowApiResponse queryAudit(@PathVariable String id) { + return GeaflowApiResponse.success(auditManager.get(id)); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/AuthorizationController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/AuthorizationController.java index 0b11842d9..cd3267da3 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/AuthorizationController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/AuthorizationController.java @@ -33,31 +33,30 @@ import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; - @RestController @RequestMapping("/authorizations") public class AuthorizationController { - @Autowired - private AuthorizationManager authorizationManager; + @Autowired private AuthorizationManager authorizationManager; - @GetMapping - public GeaflowApiResponse> searchAuthorizations(AuthorizationSearch search) { - return GeaflowApiResponse.success(authorizationManager.search(search)); - } + @GetMapping + public GeaflowApiResponse> searchAuthorizations( + AuthorizationSearch search) { + return GeaflowApiResponse.success(authorizationManager.search(search)); + } - @GetMapping("/{id}") - public GeaflowApiResponse queryAuthorization(@PathVariable String id) { - return GeaflowApiResponse.success(authorizationManager.get(id)); - } + @GetMapping("/{id}") + public GeaflowApiResponse queryAuthorization(@PathVariable String id) { + return GeaflowApiResponse.success(authorizationManager.get(id)); + } - @PostMapping - public GeaflowApiResponse applyAuthorization(@RequestBody AuthorizationView view) { - return GeaflowApiResponse.success(authorizationManager.create(view)); - } + @PostMapping + public GeaflowApiResponse applyAuthorization(@RequestBody AuthorizationView view) { + return GeaflowApiResponse.success(authorizationManager.create(view)); + } - @DeleteMapping("/{id}") - public GeaflowApiResponse deleteAuthorization(@PathVariable String id) { - return GeaflowApiResponse.success(authorizationManager.drop(id)); - } + @DeleteMapping("/{id}") + public GeaflowApiResponse deleteAuthorization(@PathVariable String id) { + return GeaflowApiResponse.success(authorizationManager.drop(id)); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/ChatController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/ChatController.java index d0f52a9da..d3dcd0f4c 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/ChatController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/ChatController.java @@ -37,40 +37,36 @@ import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; - @RestController @RequestMapping("/chats") public class ChatController { - @Autowired - private ChatManager chatManager; - - - @GetMapping - public GeaflowApiResponse> searchLLMs(ChatSearch search) { - if (StringUtils.isEmpty(search.getSort())) { - search.setOrder(SortOrder.ASC, CREATE_TIME_FIELD_NAME); - } - return GeaflowApiResponse.success(chatManager.search(search)); - } + @Autowired private ChatManager chatManager; - @PostMapping - public GeaflowApiResponse createChat(ChatView chatView, - @RequestParam(required = false) Boolean withSchema) { - return GeaflowApiResponse.success(chatManager.callASync(chatView, withSchema)); + @GetMapping + public GeaflowApiResponse> searchLLMs(ChatSearch search) { + if (StringUtils.isEmpty(search.getSort())) { + search.setOrder(SortOrder.ASC, CREATE_TIME_FIELD_NAME); } + return GeaflowApiResponse.success(chatManager.search(search)); + } - @PostMapping("/callSync") - public GeaflowApiResponse callSync(ChatView chatView, - @RequestParam(required = false) Boolean saveRecord, - @RequestParam(required = false) Boolean withSchema) { - return GeaflowApiResponse.success(chatManager.callSync(chatView, saveRecord, withSchema)); - } - - @DeleteMapping("/jobs/{jobId}") - public GeaflowApiResponse deleteChat(@PathVariable String jobId) { - return GeaflowApiResponse.success(chatManager.dropByJobId(jobId)); - } + @PostMapping + public GeaflowApiResponse createChat( + ChatView chatView, @RequestParam(required = false) Boolean withSchema) { + return GeaflowApiResponse.success(chatManager.callASync(chatView, withSchema)); + } + @PostMapping("/callSync") + public GeaflowApiResponse callSync( + ChatView chatView, + @RequestParam(required = false) Boolean saveRecord, + @RequestParam(required = false) Boolean withSchema) { + return GeaflowApiResponse.success(chatManager.callSync(chatView, saveRecord, withSchema)); + } + @DeleteMapping("/jobs/{jobId}") + public GeaflowApiResponse deleteChat(@PathVariable String jobId) { + return GeaflowApiResponse.success(chatManager.dropByJobId(jobId)); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/ClusterController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/ClusterController.java index 08a01eb7f..ec1956f04 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/ClusterController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/ClusterController.java @@ -39,51 +39,47 @@ import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; - @RestController @RequestMapping("/clusters") public class ClusterController { - @Autowired - private ClusterManager clusterManager; - - @Autowired - private AuthorizationManager authorizationManager; + @Autowired private ClusterManager clusterManager; - @GetMapping - public GeaflowApiResponse> searchClusters(ClusterSearch search) { - return GeaflowApiResponse.success(clusterManager.search(search)); - } + @Autowired private AuthorizationManager authorizationManager; - @GetMapping("/{clusterName}") - public GeaflowApiResponse queryCluster(@PathVariable String clusterName) { - return GeaflowApiResponse.success(clusterManager.getByName(clusterName)); - } + @GetMapping + public GeaflowApiResponse> searchClusters(ClusterSearch search) { + return GeaflowApiResponse.success(clusterManager.search(search)); + } - @PostMapping - public GeaflowApiResponse createCluster(@RequestBody ClusterView clusterView) { - authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); - return GeaflowApiResponse.success(clusterManager.create(clusterView)); - } + @GetMapping("/{clusterName}") + public GeaflowApiResponse queryCluster(@PathVariable String clusterName) { + return GeaflowApiResponse.success(clusterManager.getByName(clusterName)); + } - @PutMapping("/{clusterName}") - public GeaflowApiResponse updateCluster(@PathVariable String clusterName, - @RequestBody ClusterView clusterView) { - authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); - return GeaflowApiResponse.success(clusterManager.updateByName(clusterName, clusterView)); - } + @PostMapping + public GeaflowApiResponse createCluster(@RequestBody ClusterView clusterView) { + authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); + return GeaflowApiResponse.success(clusterManager.create(clusterView)); + } - @DeleteMapping("/{clusterName}") - public GeaflowApiResponse deleteCluster(@PathVariable String clusterName) { - authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); - return GeaflowApiResponse.success(clusterManager.dropByName(clusterName)); - } + @PutMapping("/{clusterName}") + public GeaflowApiResponse updateCluster( + @PathVariable String clusterName, @RequestBody ClusterView clusterView) { + authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); + return GeaflowApiResponse.success(clusterManager.updateByName(clusterName, clusterView)); + } - @PostMapping("/{clusterName}/operations") - public GeaflowApiResponse operateCluster(@PathVariable String clusterName, - @RequestParam GeaflowOperationType clusterAction) { + @DeleteMapping("/{clusterName}") + public GeaflowApiResponse deleteCluster(@PathVariable String clusterName) { + authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); + return GeaflowApiResponse.success(clusterManager.dropByName(clusterName)); + } - throw new GeaflowException("Cluster operation not supported"); - } + @PostMapping("/{clusterName}/operations") + public GeaflowApiResponse operateCluster( + @PathVariable String clusterName, @RequestParam GeaflowOperationType clusterAction) { + throw new GeaflowException("Cluster operation not supported"); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/ConfigController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/ConfigController.java index 98a57e0f9..5f91a0b24 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/ConfigController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/ConfigController.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.web.controller.api; import java.util.List; + import org.apache.geaflow.console.biz.shared.ConfigManager; import org.apache.geaflow.console.common.util.type.GeaflowPluginCategory; import org.apache.geaflow.console.core.model.config.ConfigDescItem; @@ -30,39 +31,36 @@ import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; - @RestController @RequestMapping("/config") public class ConfigController { - @Autowired - private ConfigManager configManager; - - @GetMapping("/cluster") - public GeaflowApiResponse> getClusterConfig() { - return GeaflowApiResponse.success(configManager.getClusterConfig()); - } + @Autowired private ConfigManager configManager; - @GetMapping("/job") - public GeaflowApiResponse> getJobConfig() { - return GeaflowApiResponse.success(configManager.getJobConfig()); - } + @GetMapping("/cluster") + public GeaflowApiResponse> getClusterConfig() { + return GeaflowApiResponse.success(configManager.getClusterConfig()); + } - @GetMapping("/plugin/categories") - public GeaflowApiResponse> getPluginCategories() { - return GeaflowApiResponse.success(configManager.getPluginCategories()); - } + @GetMapping("/job") + public GeaflowApiResponse> getJobConfig() { + return GeaflowApiResponse.success(configManager.getJobConfig()); + } - @GetMapping("/plugin/categories/{category}/types") - public GeaflowApiResponse> getPluginCategoryTypes( - @PathVariable GeaflowPluginCategory category) { - return GeaflowApiResponse.success(configManager.getPluginCategoryTypes(category)); - } + @GetMapping("/plugin/categories") + public GeaflowApiResponse> getPluginCategories() { + return GeaflowApiResponse.success(configManager.getPluginCategories()); + } - @GetMapping("/plugin/categories/{category}/types/{type}") - public GeaflowApiResponse> getPluginConfig(@PathVariable GeaflowPluginCategory category, - @PathVariable String type) { - return GeaflowApiResponse.success(configManager.getPluginConfig(category, type)); - } + @GetMapping("/plugin/categories/{category}/types") + public GeaflowApiResponse> getPluginCategoryTypes( + @PathVariable GeaflowPluginCategory category) { + return GeaflowApiResponse.success(configManager.getPluginCategoryTypes(category)); + } + @GetMapping("/plugin/categories/{category}/types/{type}") + public GeaflowApiResponse> getPluginConfig( + @PathVariable GeaflowPluginCategory category, @PathVariable String type) { + return GeaflowApiResponse.success(configManager.getPluginConfig(category, type)); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/EdgeController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/EdgeController.java index 8b336e548..6638a41cc 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/EdgeController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/EdgeController.java @@ -38,45 +38,44 @@ @RequestMapping public class EdgeController { - @Autowired - private EdgeManager edgeManager; + @Autowired private EdgeManager edgeManager; - @GetMapping("/edges") - public GeaflowApiResponse> searchEdges(EdgeSearch search) { - return GeaflowApiResponse.success(edgeManager.search(search)); - } + @GetMapping("/edges") + public GeaflowApiResponse> searchEdges(EdgeSearch search) { + return GeaflowApiResponse.success(edgeManager.search(search)); + } + @GetMapping("/instances/{instanceName}/edges") + public GeaflowApiResponse> searchEdges( + @PathVariable("instanceName") String instanceName, EdgeSearch search) { + return GeaflowApiResponse.success(edgeManager.searchByInstanceName(instanceName, search)); + } - @GetMapping("/instances/{instanceName}/edges") - public GeaflowApiResponse> searchEdges(@PathVariable("instanceName") String instanceName, - EdgeSearch search) { - return GeaflowApiResponse.success(edgeManager.searchByInstanceName(instanceName, search)); - } + @GetMapping("/instances/{instanceName}/edges/{edgeName}") + public GeaflowApiResponse getEdge( + @PathVariable("instanceName") String instanceName, + @PathVariable("edgeName") String edgeName) { + return GeaflowApiResponse.success(edgeManager.getByName(instanceName, edgeName)); + } - @GetMapping("/instances/{instanceName}/edges/{edgeName}") - public GeaflowApiResponse getEdge(@PathVariable("instanceName") String instanceName, - @PathVariable("edgeName") String edgeName) { - return GeaflowApiResponse.success(edgeManager.getByName(instanceName, edgeName)); - } + @PostMapping("/instances/{instanceName}/edges") + public GeaflowApiResponse create( + @PathVariable("instanceName") String instanceName, @RequestBody EdgeView edgeView) { + return GeaflowApiResponse.success(edgeManager.create(instanceName, edgeView)); + } - @PostMapping("/instances/{instanceName}/edges") - public GeaflowApiResponse create(@PathVariable("instanceName") String instanceName, - @RequestBody EdgeView edgeView) { - return GeaflowApiResponse.success(edgeManager.create(instanceName, edgeView)); - } - - @PutMapping("/instances/{instanceName}/edges/{edgeName}") - public GeaflowApiResponse update(@PathVariable("instanceName") String instanceName, - @PathVariable("edgeName") String edgeName, - @RequestBody EdgeView edgeView) { - return GeaflowApiResponse.success(edgeManager.updateByName(instanceName, edgeName, edgeView)); - } - - - @DeleteMapping("/instances/{instanceName}/edges/{edgeName}") - public GeaflowApiResponse drop(@PathVariable("instanceName") String instanceName, - @PathVariable("edgeName") String edgeName) { - return GeaflowApiResponse.success(edgeManager.dropByName(instanceName, edgeName)); - } + @PutMapping("/instances/{instanceName}/edges/{edgeName}") + public GeaflowApiResponse update( + @PathVariable("instanceName") String instanceName, + @PathVariable("edgeName") String edgeName, + @RequestBody EdgeView edgeView) { + return GeaflowApiResponse.success(edgeManager.updateByName(instanceName, edgeName, edgeView)); + } + @DeleteMapping("/instances/{instanceName}/edges/{edgeName}") + public GeaflowApiResponse drop( + @PathVariable("instanceName") String instanceName, + @PathVariable("edgeName") String edgeName) { + return GeaflowApiResponse.success(edgeManager.dropByName(instanceName, edgeName)); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/FunctionController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/FunctionController.java index 276c5f52e..b7256a97c 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/FunctionController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/FunctionController.java @@ -36,37 +36,43 @@ @RestController public class FunctionController { - @Autowired - private FunctionManager functionManager; + @Autowired private FunctionManager functionManager; - @GetMapping("/instances/{instanceName}/functions") - public GeaflowApiResponse> instanceSearch(@PathVariable("instanceName") String instanceName, - FunctionSearch search) { - return GeaflowApiResponse.success(functionManager.searchByInstanceName(instanceName, search)); - } + @GetMapping("/instances/{instanceName}/functions") + public GeaflowApiResponse> instanceSearch( + @PathVariable("instanceName") String instanceName, FunctionSearch search) { + return GeaflowApiResponse.success(functionManager.searchByInstanceName(instanceName, search)); + } - @GetMapping("/instances/{instanceName}/functions/{functionName}") - public GeaflowApiResponse getFunction(@PathVariable String instanceName, - @PathVariable String functionName) { - return GeaflowApiResponse.success(functionManager.getByName(instanceName, functionName)); - } + @GetMapping("/instances/{instanceName}/functions/{functionName}") + public GeaflowApiResponse getFunction( + @PathVariable String instanceName, @PathVariable String functionName) { + return GeaflowApiResponse.success(functionManager.getByName(instanceName, functionName)); + } - @PostMapping("/instances/{instanceName}/functions") - public GeaflowApiResponse createFunction(@PathVariable String instanceName, FunctionView view, - @RequestParam(required = false) MultipartFile functionFile, - @RequestParam(required = false) String fileId) { - return GeaflowApiResponse.success(functionManager.createFunction(instanceName, view, functionFile, fileId)); - } + @PostMapping("/instances/{instanceName}/functions") + public GeaflowApiResponse createFunction( + @PathVariable String instanceName, + FunctionView view, + @RequestParam(required = false) MultipartFile functionFile, + @RequestParam(required = false) String fileId) { + return GeaflowApiResponse.success( + functionManager.createFunction(instanceName, view, functionFile, fileId)); + } - @PutMapping("/instances/{instanceName}/functions/{functionName}") - public GeaflowApiResponse updateFunction(@PathVariable String instanceName, @PathVariable String functionName, - FunctionView view, @RequestParam(required = false) MultipartFile functionFile) { - return GeaflowApiResponse.success(functionManager.updateFunction(instanceName, functionName, view, functionFile)); - } + @PutMapping("/instances/{instanceName}/functions/{functionName}") + public GeaflowApiResponse updateFunction( + @PathVariable String instanceName, + @PathVariable String functionName, + FunctionView view, + @RequestParam(required = false) MultipartFile functionFile) { + return GeaflowApiResponse.success( + functionManager.updateFunction(instanceName, functionName, view, functionFile)); + } - @DeleteMapping("/instances/{instanceName}/functions/{functionName}") - public GeaflowApiResponse deleteFunction(@PathVariable String instanceName, - @PathVariable String functionName) { - return GeaflowApiResponse.success(functionManager.deleteFunction(instanceName, functionName)); - } + @DeleteMapping("/instances/{instanceName}/functions/{functionName}") + public GeaflowApiResponse deleteFunction( + @PathVariable String instanceName, @PathVariable String functionName) { + return GeaflowApiResponse.success(functionManager.deleteFunction(instanceName, functionName)); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/GraphController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/GraphController.java index 85b6c53ca..b27a21cea 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/GraphController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/GraphController.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.web.controller.api; import java.util.List; + import org.apache.geaflow.console.biz.shared.GraphManager; import org.apache.geaflow.console.biz.shared.view.EndpointView; import org.apache.geaflow.console.biz.shared.view.GraphView; @@ -38,64 +39,70 @@ @RestController public class GraphController { - @Autowired - private GraphManager graphManager; - - @GetMapping("/graphs") - public GeaflowApiResponse> search(GraphSearch search) { - return GeaflowApiResponse.success(graphManager.search(search)); - } - - @GetMapping("/instances/{instanceName}/graphs") - public GeaflowApiResponse> instanceSearch(@PathVariable("instanceName") String instanceName, - GraphSearch search) { - return GeaflowApiResponse.success(graphManager.searchByInstanceName(instanceName, search)); - } + @Autowired private GraphManager graphManager; - @GetMapping("/instances/{instanceName}/graphs/{graphName}") - public GeaflowApiResponse getGraph(@PathVariable("instanceName") String instanceName, - @PathVariable("graphName") String graphName) { - return GeaflowApiResponse.success(graphManager.getByName(instanceName, graphName)); - } + @GetMapping("/graphs") + public GeaflowApiResponse> search(GraphSearch search) { + return GeaflowApiResponse.success(graphManager.search(search)); + } - @PostMapping("/instances/{instanceName}/graphs") - public GeaflowApiResponse create(@PathVariable("instanceName") String instanceName, - @RequestBody GraphView graphView) { - return GeaflowApiResponse.success(graphManager.create(instanceName, graphView)); - } + @GetMapping("/instances/{instanceName}/graphs") + public GeaflowApiResponse> instanceSearch( + @PathVariable("instanceName") String instanceName, GraphSearch search) { + return GeaflowApiResponse.success(graphManager.searchByInstanceName(instanceName, search)); + } - @PutMapping("/instances/{instanceName}/graphs/{graphName}") - public GeaflowApiResponse update(@PathVariable("instanceName") String instanceName, - @PathVariable("graphName") String graphName, - @RequestBody GraphView graphView) { - return GeaflowApiResponse.success(graphManager.updateByName(instanceName, graphName, graphView)); - } + @GetMapping("/instances/{instanceName}/graphs/{graphName}") + public GeaflowApiResponse getGraph( + @PathVariable("instanceName") String instanceName, + @PathVariable("graphName") String graphName) { + return GeaflowApiResponse.success(graphManager.getByName(instanceName, graphName)); + } + @PostMapping("/instances/{instanceName}/graphs") + public GeaflowApiResponse create( + @PathVariable("instanceName") String instanceName, @RequestBody GraphView graphView) { + return GeaflowApiResponse.success(graphManager.create(instanceName, graphView)); + } - @DeleteMapping("/instances/{instanceName}/graphs/{graphName}") - public GeaflowApiResponse drop(@PathVariable("instanceName") String instanceName, - @PathVariable("graphName") String graphName) { - return GeaflowApiResponse.success(graphManager.dropByName(instanceName, graphName)); - } + @PutMapping("/instances/{instanceName}/graphs/{graphName}") + public GeaflowApiResponse update( + @PathVariable("instanceName") String instanceName, + @PathVariable("graphName") String graphName, + @RequestBody GraphView graphView) { + return GeaflowApiResponse.success( + graphManager.updateByName(instanceName, graphName, graphView)); + } - @PostMapping("/instances/{instanceName}/graphs/{graphName}/clean") - public GeaflowApiResponse clean(@PathVariable("instanceName") String instanceName, - @PathVariable("graphName") String graphName) { - return GeaflowApiResponse.success(graphManager.clean(instanceName, graphName)); - } + @DeleteMapping("/instances/{instanceName}/graphs/{graphName}") + public GeaflowApiResponse drop( + @PathVariable("instanceName") String instanceName, + @PathVariable("graphName") String graphName) { + return GeaflowApiResponse.success(graphManager.dropByName(instanceName, graphName)); + } - @PostMapping("/instances/{instanceName}/graphs/{graphName}/endpoints") - public GeaflowApiResponse createEndpoints(@PathVariable("instanceName") String instanceName, - @PathVariable("graphName") String graphName, - @RequestBody List endpoints) { - return GeaflowApiResponse.success(graphManager.createEndpoints(instanceName, graphName, endpoints)); - } + @PostMapping("/instances/{instanceName}/graphs/{graphName}/clean") + public GeaflowApiResponse clean( + @PathVariable("instanceName") String instanceName, + @PathVariable("graphName") String graphName) { + return GeaflowApiResponse.success(graphManager.clean(instanceName, graphName)); + } - @DeleteMapping("/instances/{instanceName}/graphs/{graphName}/endpoints") - public GeaflowApiResponse deleteEndpoints(@PathVariable("instanceName") String instanceName, - @PathVariable("graphName") String graphName, - @RequestBody(required = false) List endpoints) { - return GeaflowApiResponse.success(graphManager.deleteEndpoints(instanceName, graphName, endpoints)); - } + @PostMapping("/instances/{instanceName}/graphs/{graphName}/endpoints") + public GeaflowApiResponse createEndpoints( + @PathVariable("instanceName") String instanceName, + @PathVariable("graphName") String graphName, + @RequestBody List endpoints) { + return GeaflowApiResponse.success( + graphManager.createEndpoints(instanceName, graphName, endpoints)); + } + @DeleteMapping("/instances/{instanceName}/graphs/{graphName}/endpoints") + public GeaflowApiResponse deleteEndpoints( + @PathVariable("instanceName") String instanceName, + @PathVariable("graphName") String graphName, + @RequestBody(required = false) List endpoints) { + return GeaflowApiResponse.success( + graphManager.deleteEndpoints(instanceName, graphName, endpoints)); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/InstallController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/InstallController.java index bcc9c3f74..3d8145464 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/InstallController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/InstallController.java @@ -31,26 +31,23 @@ import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; - @RestController @RequestMapping("/install") public class InstallController { - @Autowired - private InstallManager installManager; + @Autowired private InstallManager installManager; - @Autowired - private AuthorizationManager authorizationManager; + @Autowired private AuthorizationManager authorizationManager; - @GetMapping - public GeaflowApiResponse get() { - authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); - return GeaflowApiResponse.success(installManager.get()); - } + @GetMapping + public GeaflowApiResponse get() { + authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); + return GeaflowApiResponse.success(installManager.get()); + } - @PostMapping - public GeaflowApiResponse install(@RequestBody InstallView installView) { - authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); - return GeaflowApiResponse.success(installManager.install(installView)); - } + @PostMapping + public GeaflowApiResponse install(@RequestBody InstallView installView) { + authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); + return GeaflowApiResponse.success(installManager.install(installView)); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/InstanceController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/InstanceController.java index 18a271564..258b362a0 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/InstanceController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/InstanceController.java @@ -35,38 +35,35 @@ import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; - @RestController @RequestMapping("/instances") public class InstanceController { - @Autowired - private InstanceManager instanceManager; - - @GetMapping - public GeaflowApiResponse> searchInstances(InstanceSearch instanceSearch) { - return GeaflowApiResponse.success(instanceManager.search(instanceSearch)); - } + @Autowired private InstanceManager instanceManager; - @GetMapping("/{instanceName}") - public GeaflowApiResponse queryInstance(@PathVariable String instanceName) { - return GeaflowApiResponse.success(instanceManager.getByName(instanceName)); - } + @GetMapping + public GeaflowApiResponse> searchInstances(InstanceSearch instanceSearch) { + return GeaflowApiResponse.success(instanceManager.search(instanceSearch)); + } - @PostMapping - public GeaflowApiResponse createInstance(@RequestBody InstanceView instanceView) { - return GeaflowApiResponse.success(instanceManager.create(instanceView)); - } + @GetMapping("/{instanceName}") + public GeaflowApiResponse queryInstance(@PathVariable String instanceName) { + return GeaflowApiResponse.success(instanceManager.getByName(instanceName)); + } - @PutMapping("/{instanceName}") - public GeaflowApiResponse updateInstance(@PathVariable String instanceName, - @RequestBody InstanceView instanceView) { - return GeaflowApiResponse.success(instanceManager.updateByName(instanceName, instanceView)); - } + @PostMapping + public GeaflowApiResponse createInstance(@RequestBody InstanceView instanceView) { + return GeaflowApiResponse.success(instanceManager.create(instanceView)); + } - @DeleteMapping("/{instanceName}") - public GeaflowApiResponse deleteInstance(@PathVariable String instanceName) { - throw new GeaflowException("Delete instance {} not allowed", instanceName); - } + @PutMapping("/{instanceName}") + public GeaflowApiResponse updateInstance( + @PathVariable String instanceName, @RequestBody InstanceView instanceView) { + return GeaflowApiResponse.success(instanceManager.updateByName(instanceName, instanceView)); + } + @DeleteMapping("/{instanceName}") + public GeaflowApiResponse deleteInstance(@PathVariable String instanceName) { + throw new GeaflowException("Delete instance {} not allowed", instanceName); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/JobController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/JobController.java index c3584f733..bb0da3026 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/JobController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/JobController.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.web.controller.api; import java.util.List; + import org.apache.geaflow.console.biz.shared.AuthorizationManager; import org.apache.geaflow.console.biz.shared.JobManager; import org.apache.geaflow.console.biz.shared.view.JobView; @@ -39,50 +40,49 @@ import org.springframework.web.bind.annotation.RestController; import org.springframework.web.multipart.MultipartFile; - @RestController @RequestMapping("/jobs") public class JobController { - @Autowired - private JobManager jobManager; - - @Autowired - private AuthorizationManager authorizationManager; - + @Autowired private JobManager jobManager; - @GetMapping - public GeaflowApiResponse> searchJob(JobSearch search) { - return GeaflowApiResponse.success(jobManager.search(search)); - } + @Autowired private AuthorizationManager authorizationManager; - @GetMapping("/{jobId}") - public GeaflowApiResponse getJob(@PathVariable String jobId) { - authorizationManager.hasAuthority(GeaflowAuthority.QUERY, Resources.job(jobId)); - return GeaflowApiResponse.success(jobManager.get(jobId)); - } + @GetMapping + public GeaflowApiResponse> searchJob(JobSearch search) { + return GeaflowApiResponse.success(jobManager.search(search)); + } - @PostMapping - public GeaflowApiResponse createJob(JobView jobView, - @RequestParam(required = false) MultipartFile jarFile, - @RequestParam(required = false) String fileId, - @RequestParam(required = false) List graphIds) { - authorizationManager.hasAuthority(GeaflowAuthority.ALL, Resources.instance(jobView.getInstanceId())); - return GeaflowApiResponse.success(jobManager.create(jobView, jarFile, fileId, graphIds)); - } + @GetMapping("/{jobId}") + public GeaflowApiResponse getJob(@PathVariable String jobId) { + authorizationManager.hasAuthority(GeaflowAuthority.QUERY, Resources.job(jobId)); + return GeaflowApiResponse.success(jobManager.get(jobId)); + } - @PutMapping("/{jobId}") - public GeaflowApiResponse updateJob(@PathVariable String jobId, JobView jobView, - @RequestParam(required = false) MultipartFile jarFile, - @RequestParam(required = false) String fileId) { - authorizationManager.hasAuthority(GeaflowAuthority.UPDATE, Resources.job(jobId)); - return GeaflowApiResponse.success(jobManager.update(jobId, jobView, jarFile, fileId)); - } + @PostMapping + public GeaflowApiResponse createJob( + JobView jobView, + @RequestParam(required = false) MultipartFile jarFile, + @RequestParam(required = false) String fileId, + @RequestParam(required = false) List graphIds) { + authorizationManager.hasAuthority( + GeaflowAuthority.ALL, Resources.instance(jobView.getInstanceId())); + return GeaflowApiResponse.success(jobManager.create(jobView, jarFile, fileId, graphIds)); + } - @DeleteMapping("/{jobId}") - public GeaflowApiResponse deleteJob(@PathVariable String jobId) { - authorizationManager.hasAuthority(GeaflowAuthority.ALL, Resources.job(jobId)); - return GeaflowApiResponse.success(jobManager.drop(jobId)); - } + @PutMapping("/{jobId}") + public GeaflowApiResponse updateJob( + @PathVariable String jobId, + JobView jobView, + @RequestParam(required = false) MultipartFile jarFile, + @RequestParam(required = false) String fileId) { + authorizationManager.hasAuthority(GeaflowAuthority.UPDATE, Resources.job(jobId)); + return GeaflowApiResponse.success(jobManager.update(jobId, jobView, jarFile, fileId)); + } + @DeleteMapping("/{jobId}") + public GeaflowApiResponse deleteJob(@PathVariable String jobId) { + authorizationManager.hasAuthority(GeaflowAuthority.ALL, Resources.job(jobId)); + return GeaflowApiResponse.success(jobManager.drop(jobId)); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/LLMController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/LLMController.java index c4ed540de..b0d2bec8f 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/LLMController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/LLMController.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.web.controller.api; import java.util.List; + import org.apache.geaflow.console.biz.shared.AuthorizationManager; import org.apache.geaflow.console.biz.shared.LLMManager; import org.apache.geaflow.console.biz.shared.view.LLMView; @@ -38,49 +39,45 @@ import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; - @RestController @RequestMapping("/llms") public class LLMController { - @Autowired - private LLMManager llmManager; - - @Autowired - private AuthorizationManager authorizationManager; + @Autowired private LLMManager llmManager; - @GetMapping - public GeaflowApiResponse> searchLLMs(LLMSearch search) { - return GeaflowApiResponse.success(llmManager.search(search)); - } + @Autowired private AuthorizationManager authorizationManager; - @GetMapping("/{llmName}") - public GeaflowApiResponse queryLLM(@PathVariable String llmName) { - return GeaflowApiResponse.success(llmManager.getByName(llmName)); - } + @GetMapping + public GeaflowApiResponse> searchLLMs(LLMSearch search) { + return GeaflowApiResponse.success(llmManager.search(search)); + } - @PostMapping - public GeaflowApiResponse createLLM(@RequestBody LLMView llmView) { - authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); - return GeaflowApiResponse.success(llmManager.create(llmView)); - } + @GetMapping("/{llmName}") + public GeaflowApiResponse queryLLM(@PathVariable String llmName) { + return GeaflowApiResponse.success(llmManager.getByName(llmName)); + } - @PutMapping("/{llmName}") - public GeaflowApiResponse updateLLM(@PathVariable String llmName, - @RequestBody LLMView llmView) { - authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); - return GeaflowApiResponse.success(llmManager.updateByName(llmName, llmView)); - } + @PostMapping + public GeaflowApiResponse createLLM(@RequestBody LLMView llmView) { + authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); + return GeaflowApiResponse.success(llmManager.create(llmView)); + } - @DeleteMapping("/{llmName}") - public GeaflowApiResponse deleteLLM(@PathVariable String llmName) { - authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); - return GeaflowApiResponse.success(llmManager.dropByName(llmName)); - } + @PutMapping("/{llmName}") + public GeaflowApiResponse updateLLM( + @PathVariable String llmName, @RequestBody LLMView llmView) { + authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); + return GeaflowApiResponse.success(llmManager.updateByName(llmName, llmView)); + } - @GetMapping("/types") - public GeaflowApiResponse> getLLMTypes() { - return GeaflowApiResponse.success(llmManager.getLLMTypes()); - } + @DeleteMapping("/{llmName}") + public GeaflowApiResponse deleteLLM(@PathVariable String llmName) { + authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); + return GeaflowApiResponse.success(llmManager.dropByName(llmName)); + } + @GetMapping("/types") + public GeaflowApiResponse> getLLMTypes() { + return GeaflowApiResponse.success(llmManager.getLLMTypes()); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/PluginConfigController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/PluginConfigController.java index b46267f5e..0432ce481 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/PluginConfigController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/PluginConfigController.java @@ -38,32 +38,32 @@ @RequestMapping("/plugin-configs") public class PluginConfigController { - @Autowired - private PluginConfigManager pluginConfigManager; + @Autowired private PluginConfigManager pluginConfigManager; - @GetMapping - public GeaflowApiResponse> searchPluginConfigs(PluginConfigSearch searchView) { - return GeaflowApiResponse.success(pluginConfigManager.search(searchView)); - } + @GetMapping + public GeaflowApiResponse> searchPluginConfigs( + PluginConfigSearch searchView) { + return GeaflowApiResponse.success(pluginConfigManager.search(searchView)); + } - @GetMapping("/{id}") - public GeaflowApiResponse getPluginConfig(@PathVariable String id) { - return GeaflowApiResponse.success(pluginConfigManager.get(id)); - } + @GetMapping("/{id}") + public GeaflowApiResponse getPluginConfig(@PathVariable String id) { + return GeaflowApiResponse.success(pluginConfigManager.get(id)); + } - @PostMapping - public GeaflowApiResponse createPluginConfig(@RequestBody PluginConfigView view) { - return GeaflowApiResponse.success(pluginConfigManager.create(view)); - } + @PostMapping + public GeaflowApiResponse createPluginConfig(@RequestBody PluginConfigView view) { + return GeaflowApiResponse.success(pluginConfigManager.create(view)); + } - @PutMapping("/{id}") - public GeaflowApiResponse updatePluginConfig(@PathVariable String id, @RequestBody PluginConfigView view) { - return GeaflowApiResponse.success(pluginConfigManager.updateById(id, view)); - } - - @DeleteMapping("/{id}") - public GeaflowApiResponse deletePluginConfig(@PathVariable String id) { - return GeaflowApiResponse.success(pluginConfigManager.drop(id)); - } + @PutMapping("/{id}") + public GeaflowApiResponse updatePluginConfig( + @PathVariable String id, @RequestBody PluginConfigView view) { + return GeaflowApiResponse.success(pluginConfigManager.updateById(id, view)); + } + @DeleteMapping("/{id}") + public GeaflowApiResponse deletePluginConfig(@PathVariable String id) { + return GeaflowApiResponse.success(pluginConfigManager.drop(id)); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/PluginController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/PluginController.java index c72fbce48..fe53168da 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/PluginController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/PluginController.java @@ -35,41 +35,40 @@ import org.springframework.web.bind.annotation.RestController; import org.springframework.web.multipart.MultipartFile; - @RestController @RequestMapping("/plugins") public class PluginController { - @Autowired - private PluginManager pluginManager; - - @GetMapping("/{pluginId}") - public GeaflowApiResponse queryPlugin(@PathVariable String pluginId) { - return GeaflowApiResponse.success(pluginManager.get(pluginId)); - } - - @GetMapping - public GeaflowApiResponse> searchPlugins(PluginSearch search) { - return GeaflowApiResponse.success(pluginManager.search(search)); - } + @Autowired private PluginManager pluginManager; - @PostMapping - public GeaflowApiResponse createPlugin(PluginView pluginView, - @RequestParam(required = false) MultipartFile jarFile, - @RequestParam(required = false) String fileId) { - return GeaflowApiResponse.success(pluginManager.createPlugin(pluginView, jarFile, fileId)); - } + @GetMapping("/{pluginId}") + public GeaflowApiResponse queryPlugin(@PathVariable String pluginId) { + return GeaflowApiResponse.success(pluginManager.get(pluginId)); + } - @PutMapping("/{pluginId}") - public GeaflowApiResponse updatePlugin(@PathVariable String pluginId, PluginView view, - @RequestParam(required = false) MultipartFile jarFile) { - return GeaflowApiResponse.success(pluginManager.updatePlugin(pluginId, view, jarFile)); - } + @GetMapping + public GeaflowApiResponse> searchPlugins(PluginSearch search) { + return GeaflowApiResponse.success(pluginManager.search(search)); + } - @DeleteMapping("/{pluginId}") - public GeaflowApiResponse deletePlugin(@PathVariable String pluginId) { - return GeaflowApiResponse.success(pluginManager.drop(pluginId)); - } + @PostMapping + public GeaflowApiResponse createPlugin( + PluginView pluginView, + @RequestParam(required = false) MultipartFile jarFile, + @RequestParam(required = false) String fileId) { + return GeaflowApiResponse.success(pluginManager.createPlugin(pluginView, jarFile, fileId)); + } + @PutMapping("/{pluginId}") + public GeaflowApiResponse updatePlugin( + @PathVariable String pluginId, + PluginView view, + @RequestParam(required = false) MultipartFile jarFile) { + return GeaflowApiResponse.success(pluginManager.updatePlugin(pluginId, view, jarFile)); + } + @DeleteMapping("/{pluginId}") + public GeaflowApiResponse deletePlugin(@PathVariable String pluginId) { + return GeaflowApiResponse.success(pluginManager.drop(pluginId)); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/ReleaseController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/ReleaseController.java index c8cc1bb59..018cae7b8 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/ReleaseController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/ReleaseController.java @@ -39,30 +39,27 @@ @RestController public class ReleaseController { - @Autowired - private ReleaseManager releaseManager; + @Autowired private ReleaseManager releaseManager; - @Autowired - private AuthorizationManager authorizationManager; + @Autowired private AuthorizationManager authorizationManager; - @GetMapping("jobs/{jobId}/releases") - public GeaflowApiResponse> searchReleases(@PathVariable("jobId") String jobId, ReleaseSearch search) { - search.setJobId(jobId); - return GeaflowApiResponse.success(releaseManager.search(search)); - } - - @PostMapping("/jobs/{jobId}/releases") - public GeaflowApiResponse publish(@PathVariable("jobId") String jobId) { - authorizationManager.hasAuthority(GeaflowAuthority.EXECUTE, Resources.job(jobId)); - return GeaflowApiResponse.success(releaseManager.publish(jobId)); - } - - @PutMapping("jobs/{jobId}/releases") - public GeaflowApiResponse updateRelease(@PathVariable("jobId") String jobId, - @RequestBody ReleaseUpdateView updateView) { - authorizationManager.hasAuthority(GeaflowAuthority.UPDATE, Resources.job(jobId)); - return GeaflowApiResponse.success(releaseManager.updateRelease(jobId, updateView)); - } + @GetMapping("jobs/{jobId}/releases") + public GeaflowApiResponse> searchReleases( + @PathVariable("jobId") String jobId, ReleaseSearch search) { + search.setJobId(jobId); + return GeaflowApiResponse.success(releaseManager.search(search)); + } + @PostMapping("/jobs/{jobId}/releases") + public GeaflowApiResponse publish(@PathVariable("jobId") String jobId) { + authorizationManager.hasAuthority(GeaflowAuthority.EXECUTE, Resources.job(jobId)); + return GeaflowApiResponse.success(releaseManager.publish(jobId)); + } + @PutMapping("jobs/{jobId}/releases") + public GeaflowApiResponse updateRelease( + @PathVariable("jobId") String jobId, @RequestBody ReleaseUpdateView updateView) { + authorizationManager.hasAuthority(GeaflowAuthority.UPDATE, Resources.job(jobId)); + return GeaflowApiResponse.success(releaseManager.updateRelease(jobId, updateView)); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/RemoteFileController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/RemoteFileController.java index 4ea4d834b..d4b622116 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/RemoteFileController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/RemoteFileController.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.web.controller.api; import javax.servlet.http.HttpServletResponse; + import org.apache.geaflow.console.biz.shared.RemoteFileManager; import org.apache.geaflow.console.biz.shared.view.RemoteFileView; import org.apache.geaflow.console.common.dal.model.PageList; @@ -36,38 +37,38 @@ import org.springframework.web.bind.annotation.RestController; import org.springframework.web.multipart.MultipartFile; - @RestController @RequestMapping("/remote-files") public class RemoteFileController { - @Autowired - private RemoteFileManager remoteFileManager; - - @GetMapping - public GeaflowApiResponse> searchFiles(RemoteFileSearch search) { - return GeaflowApiResponse.success(remoteFileManager.search(search)); - } + @Autowired private RemoteFileManager remoteFileManager; - @GetMapping("/{remoteFileId}") - public void getFile(HttpServletResponse response, @PathVariable String remoteFileId, - @RequestParam(required = false) boolean download) { - if (download) { - remoteFileManager.download(remoteFileId, response); + @GetMapping + public GeaflowApiResponse> searchFiles(RemoteFileSearch search) { + return GeaflowApiResponse.success(remoteFileManager.search(search)); + } - } else { - GeaflowApiResponse.success(remoteFileManager.get(remoteFileId)).write(response); - } - } + @GetMapping("/{remoteFileId}") + public void getFile( + HttpServletResponse response, + @PathVariable String remoteFileId, + @RequestParam(required = false) boolean download) { + if (download) { + remoteFileManager.download(remoteFileId, response); - @PutMapping("/{remoteFileId}") - public GeaflowApiResponse updateFile(@PathVariable String remoteFileId, @RequestPart MultipartFile file) { - return GeaflowApiResponse.success(remoteFileManager.upload(remoteFileId, file)); + } else { + GeaflowApiResponse.success(remoteFileManager.get(remoteFileId)).write(response); } + } - @DeleteMapping("/{remoteFileId}") - public GeaflowApiResponse deleteFile(@PathVariable String remoteFileId) { - return GeaflowApiResponse.success(remoteFileManager.delete(remoteFileId)); - } + @PutMapping("/{remoteFileId}") + public GeaflowApiResponse updateFile( + @PathVariable String remoteFileId, @RequestPart MultipartFile file) { + return GeaflowApiResponse.success(remoteFileManager.upload(remoteFileId, file)); + } + @DeleteMapping("/{remoteFileId}") + public GeaflowApiResponse deleteFile(@PathVariable String remoteFileId) { + return GeaflowApiResponse.success(remoteFileManager.delete(remoteFileId)); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/SessionController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/SessionController.java index 8f2bff629..d2551c6eb 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/SessionController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/SessionController.java @@ -29,29 +29,26 @@ import org.springframework.web.bind.annotation.ResponseBody; import org.springframework.web.bind.annotation.RestController; - @RestController @RequestMapping("/session") public class SessionController { - @Autowired - private AuthenticationManager authenticationManager; - - @GetMapping - public GeaflowApiResponse currentSession() { - return GeaflowApiResponse.success(authenticationManager.currentSession()); - } + @Autowired private AuthenticationManager authenticationManager; - @PostMapping("/switch") - @ResponseBody - public GeaflowApiResponse switchSession() { - return GeaflowApiResponse.success(authenticationManager.switchSession()); - } + @GetMapping + public GeaflowApiResponse currentSession() { + return GeaflowApiResponse.success(authenticationManager.currentSession()); + } - @PostMapping("/logout") - @ResponseBody - public GeaflowApiResponse logout() { - return GeaflowApiResponse.success(authenticationManager.logout()); - } + @PostMapping("/switch") + @ResponseBody + public GeaflowApiResponse switchSession() { + return GeaflowApiResponse.success(authenticationManager.switchSession()); + } + @PostMapping("/logout") + @ResponseBody + public GeaflowApiResponse logout() { + return GeaflowApiResponse.success(authenticationManager.logout()); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/StatementController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/StatementController.java index 54a35a8e3..51f3f570a 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/StatementController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/StatementController.java @@ -40,36 +40,33 @@ @RequestMapping("/statements") public class StatementController { - @Autowired - private StatementManager statementManager; + @Autowired private StatementManager statementManager; - @GetMapping - public GeaflowApiResponse> searchStatement(StatementSearch search) { - if (StringUtils.isEmpty(search.getSort())) { - search.setOrder(SortOrder.DESC, CREATE_TIME_FIELD_NAME); - } - return GeaflowApiResponse.success(statementManager.search(search)); + @GetMapping + public GeaflowApiResponse> searchStatement(StatementSearch search) { + if (StringUtils.isEmpty(search.getSort())) { + search.setOrder(SortOrder.DESC, CREATE_TIME_FIELD_NAME); } + return GeaflowApiResponse.success(statementManager.search(search)); + } - @GetMapping("/{statementId}") - public GeaflowApiResponse getStatement(@PathVariable String statementId) { - return GeaflowApiResponse.success(statementManager.get(statementId)); - } - - @PostMapping - public GeaflowApiResponse createStatement(StatementView statementView) { - return GeaflowApiResponse.success(statementManager.create(statementView)); - } + @GetMapping("/{statementId}") + public GeaflowApiResponse getStatement(@PathVariable String statementId) { + return GeaflowApiResponse.success(statementManager.get(statementId)); + } + @PostMapping + public GeaflowApiResponse createStatement(StatementView statementView) { + return GeaflowApiResponse.success(statementManager.create(statementView)); + } - @DeleteMapping("/{statementId}") - public GeaflowApiResponse deleteStatement(@PathVariable String statementId) { - return GeaflowApiResponse.success(statementManager.drop(statementId)); - } - - @DeleteMapping("/jobs/{jobId}") - public GeaflowApiResponse deleteJobStatement(@PathVariable String jobId) { - return GeaflowApiResponse.success(statementManager.dropByJobId(jobId)); - } + @DeleteMapping("/{statementId}") + public GeaflowApiResponse deleteStatement(@PathVariable String statementId) { + return GeaflowApiResponse.success(statementManager.drop(statementId)); + } + @DeleteMapping("/jobs/{jobId}") + public GeaflowApiResponse deleteJobStatement(@PathVariable String jobId) { + return GeaflowApiResponse.success(statementManager.dropByJobId(jobId)); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/SystemConfigController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/SystemConfigController.java index 8ed546663..0e5e889d5 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/SystemConfigController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/SystemConfigController.java @@ -37,52 +37,49 @@ import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; - @RestController @RequestMapping("/configs") public class SystemConfigController { - @Autowired - private SystemConfigManager systemConfigManager; - - @Autowired - private AuthorizationManager authorizationManager; + @Autowired private SystemConfigManager systemConfigManager; - @GetMapping - public GeaflowApiResponse> searchConfigs(SystemConfigSearch search) { - authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); - return GeaflowApiResponse.success(systemConfigManager.search(search)); - } + @Autowired private AuthorizationManager authorizationManager; - @GetMapping("/{key}") - public GeaflowApiResponse getConfig(@PathVariable String key, - @RequestParam(required = false) String tenantId) { - authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); - return GeaflowApiResponse.success(systemConfigManager.getConfig(tenantId, key)); - } + @GetMapping + public GeaflowApiResponse> searchConfigs(SystemConfigSearch search) { + authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); + return GeaflowApiResponse.success(systemConfigManager.search(search)); + } - @GetMapping("/{key}/value") - public GeaflowApiResponse getValue(@PathVariable String key) { - return GeaflowApiResponse.success(systemConfigManager.getValue(key)); - } + @GetMapping("/{key}") + public GeaflowApiResponse getConfig( + @PathVariable String key, @RequestParam(required = false) String tenantId) { + authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); + return GeaflowApiResponse.success(systemConfigManager.getConfig(tenantId, key)); + } - @PostMapping - public GeaflowApiResponse createConfig(@RequestBody SystemConfigView view) { - authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); - return GeaflowApiResponse.success(systemConfigManager.createConfig(view)); - } + @GetMapping("/{key}/value") + public GeaflowApiResponse getValue(@PathVariable String key) { + return GeaflowApiResponse.success(systemConfigManager.getValue(key)); + } - @PutMapping("/{key}") - public GeaflowApiResponse updateConfig(@PathVariable String key, @RequestBody SystemConfigView view) { - authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); - return GeaflowApiResponse.success(systemConfigManager.updateConfig(key, view)); - } + @PostMapping + public GeaflowApiResponse createConfig(@RequestBody SystemConfigView view) { + authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); + return GeaflowApiResponse.success(systemConfigManager.createConfig(view)); + } - @DeleteMapping("/{key}") - public GeaflowApiResponse deleteConfig(@PathVariable String key, - @RequestParam(required = false) String tenantId) { - authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); - return GeaflowApiResponse.success(systemConfigManager.deleteConfig(tenantId, key)); - } + @PutMapping("/{key}") + public GeaflowApiResponse updateConfig( + @PathVariable String key, @RequestBody SystemConfigView view) { + authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); + return GeaflowApiResponse.success(systemConfigManager.updateConfig(key, view)); + } + @DeleteMapping("/{key}") + public GeaflowApiResponse deleteConfig( + @PathVariable String key, @RequestParam(required = false) String tenantId) { + authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); + return GeaflowApiResponse.success(systemConfigManager.deleteConfig(tenantId, key)); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/TableController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/TableController.java index 88b2732d1..fdcb4ce55 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/TableController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/TableController.java @@ -38,44 +38,45 @@ @RequestMapping public class TableController { - @Autowired - private TableManager tableManager; + @Autowired private TableManager tableManager; - @GetMapping("/tables") - public GeaflowApiResponse> search(TableSearch search) { - return GeaflowApiResponse.success(tableManager.search(search)); - } + @GetMapping("/tables") + public GeaflowApiResponse> search(TableSearch search) { + return GeaflowApiResponse.success(tableManager.search(search)); + } - @GetMapping("/instances/{instanceName}/tables") - public GeaflowApiResponse> instanceSearch(@PathVariable("instanceName") String instanceName, - TableSearch search) { - return GeaflowApiResponse.success(tableManager.searchByInstanceName(instanceName, search)); - } + @GetMapping("/instances/{instanceName}/tables") + public GeaflowApiResponse> instanceSearch( + @PathVariable("instanceName") String instanceName, TableSearch search) { + return GeaflowApiResponse.success(tableManager.searchByInstanceName(instanceName, search)); + } - @GetMapping("/instances/{instanceName}/tables/{tableName}") - public GeaflowApiResponse getTable(@PathVariable("instanceName") String instanceName, - @PathVariable("tableName") String tableName) { - return GeaflowApiResponse.success(tableManager.getByName(instanceName, tableName)); - } + @GetMapping("/instances/{instanceName}/tables/{tableName}") + public GeaflowApiResponse getTable( + @PathVariable("instanceName") String instanceName, + @PathVariable("tableName") String tableName) { + return GeaflowApiResponse.success(tableManager.getByName(instanceName, tableName)); + } - @PostMapping("/instances/{instanceName}/tables") - public GeaflowApiResponse createTable(@PathVariable("instanceName") String instanceName, - @RequestBody TableView tableView) { - return GeaflowApiResponse.success(tableManager.create(instanceName, tableView)); - } + @PostMapping("/instances/{instanceName}/tables") + public GeaflowApiResponse createTable( + @PathVariable("instanceName") String instanceName, @RequestBody TableView tableView) { + return GeaflowApiResponse.success(tableManager.create(instanceName, tableView)); + } - @PutMapping("/instances/{instanceName}/tables/{tableName}") - public GeaflowApiResponse updateTable(@PathVariable("instanceName") String instanceName, - @PathVariable("tableName") String tableName, - @RequestBody TableView tableView) { - return GeaflowApiResponse.success(tableManager.updateByName(instanceName, tableName, tableView)); - } - - - @DeleteMapping("/instances/{instanceName}/tables/{tableName}") - public GeaflowApiResponse dropTable(@PathVariable("instanceName") String instanceName, - @PathVariable("tableName") String tableName) { - return GeaflowApiResponse.success(tableManager.dropByName(instanceName, tableName)); - } + @PutMapping("/instances/{instanceName}/tables/{tableName}") + public GeaflowApiResponse updateTable( + @PathVariable("instanceName") String instanceName, + @PathVariable("tableName") String tableName, + @RequestBody TableView tableView) { + return GeaflowApiResponse.success( + tableManager.updateByName(instanceName, tableName, tableView)); + } + @DeleteMapping("/instances/{instanceName}/tables/{tableName}") + public GeaflowApiResponse dropTable( + @PathVariable("instanceName") String instanceName, + @PathVariable("tableName") String tableName) { + return GeaflowApiResponse.success(tableManager.dropByName(instanceName, tableName)); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/TaskController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/TaskController.java index 6fc260d3d..4162e4388 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/TaskController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/TaskController.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.web.controller.api; import javax.servlet.http.HttpServletResponse; + import org.apache.geaflow.console.biz.shared.AuthorizationManager; import org.apache.geaflow.console.biz.shared.TaskManager; import org.apache.geaflow.console.biz.shared.view.TaskOperationView; @@ -50,87 +51,88 @@ @RestController public class TaskController { - @Autowired - private TaskManager taskManager; - - @Autowired - private AuthorizationManager authorizationManager; - - @GetMapping("/tasks") - public GeaflowApiResponse> searchTasks(TaskSearch search) { - return GeaflowApiResponse.success(taskManager.search(search)); - } - - @GetMapping("/tasks/{taskId}") - public GeaflowApiResponse getTask(@PathVariable String taskId) { - return GeaflowApiResponse.success(taskManager.get(taskId)); - } - - @PostMapping("/tasks/{taskId}/operations") - public GeaflowApiResponse operateTask(@PathVariable String taskId, - @RequestBody TaskOperationView request) { - authorizationManager.hasAuthority(GeaflowAuthority.EXECUTE, Resources.task(taskId)); - taskManager.operate(taskId, request.getAction()); - return GeaflowApiResponse.success(true); - } - - @GetMapping("/tasks/{taskId}/status") - public GeaflowApiResponse queryTaskStatus(@PathVariable String taskId, - @RequestParam(required = false) Boolean refresh) { - return GeaflowApiResponse.success(taskManager.queryStatus(taskId, refresh)); - } - - @GetMapping("/tasks/{taskId}/pipelines") - public GeaflowApiResponse> queryTaskPipelines(@PathVariable String taskId) { - return GeaflowApiResponse.success(taskManager.queryPipelines(taskId)); - } - - @GetMapping("/tasks/{taskId}/pipelines/{pipelineName}/cycles") - public GeaflowApiResponse> queryTaskCycles(@PathVariable String taskId, - @PathVariable String pipelineName) { - return GeaflowApiResponse.success(taskManager.queryCycles(taskId, pipelineName)); - } - - @GetMapping("/tasks/{taskId}/errors") - public GeaflowApiResponse> queryTaskErrors(@PathVariable String taskId) { - return GeaflowApiResponse.success(taskManager.queryErrors(taskId)); - } - - @GetMapping("tasks/{taskId}/metric-meta") - public GeaflowApiResponse> queryTaskMetricMeta(@PathVariable String taskId) { - return GeaflowApiResponse.success(taskManager.queryMetricMeta(taskId)); - } - - @PostMapping("/tasks/{taskId}/metrics") - public GeaflowApiResponse> queryTaskMetrics(@PathVariable String taskId, - @RequestBody GeaflowMetricQueryRequest queryRequest) { - return GeaflowApiResponse.success(taskManager.queryMetrics(taskId, queryRequest)); - } - - @GetMapping("/tasks/{taskId}/offsets") - public GeaflowApiResponse> queryTaskOffsets(@PathVariable String taskId) { - return GeaflowApiResponse.success(taskManager.queryOffsets(taskId)); - } - - @GetMapping("tasks/{taskId}/heartbeat") - public GeaflowApiResponse queryTaskHeartbeat(@PathVariable String taskId) { - return GeaflowApiResponse.success(taskManager.queryHeartbeat(taskId)); - } - - @GetMapping("tasks/{taskId}/logs") - public GeaflowApiResponse getLogs(@PathVariable String taskId) { - return GeaflowApiResponse.success(taskManager.getLogs(taskId)); - } - - @PostMapping("/tasks/{taskId}/startup-notify") - public GeaflowApiResponse startupNotify(@PathVariable String taskId, - @RequestBody TaskStartupNotifyView startupNotifyView) { - taskManager.startupNotify(taskId, startupNotifyView); - return GeaflowApiResponse.success(null); - } - - @GetMapping("/tasks/{taskId}/files") - public void downloadTaskFile(HttpServletResponse response, @PathVariable String taskId, @RequestParam String path) { - taskManager.download(taskId, path, response); - } + @Autowired private TaskManager taskManager; + + @Autowired private AuthorizationManager authorizationManager; + + @GetMapping("/tasks") + public GeaflowApiResponse> searchTasks(TaskSearch search) { + return GeaflowApiResponse.success(taskManager.search(search)); + } + + @GetMapping("/tasks/{taskId}") + public GeaflowApiResponse getTask(@PathVariable String taskId) { + return GeaflowApiResponse.success(taskManager.get(taskId)); + } + + @PostMapping("/tasks/{taskId}/operations") + public GeaflowApiResponse operateTask( + @PathVariable String taskId, @RequestBody TaskOperationView request) { + authorizationManager.hasAuthority(GeaflowAuthority.EXECUTE, Resources.task(taskId)); + taskManager.operate(taskId, request.getAction()); + return GeaflowApiResponse.success(true); + } + + @GetMapping("/tasks/{taskId}/status") + public GeaflowApiResponse queryTaskStatus( + @PathVariable String taskId, @RequestParam(required = false) Boolean refresh) { + return GeaflowApiResponse.success(taskManager.queryStatus(taskId, refresh)); + } + + @GetMapping("/tasks/{taskId}/pipelines") + public GeaflowApiResponse> queryTaskPipelines( + @PathVariable String taskId) { + return GeaflowApiResponse.success(taskManager.queryPipelines(taskId)); + } + + @GetMapping("/tasks/{taskId}/pipelines/{pipelineName}/cycles") + public GeaflowApiResponse> queryTaskCycles( + @PathVariable String taskId, @PathVariable String pipelineName) { + return GeaflowApiResponse.success(taskManager.queryCycles(taskId, pipelineName)); + } + + @GetMapping("/tasks/{taskId}/errors") + public GeaflowApiResponse> queryTaskErrors(@PathVariable String taskId) { + return GeaflowApiResponse.success(taskManager.queryErrors(taskId)); + } + + @GetMapping("tasks/{taskId}/metric-meta") + public GeaflowApiResponse> queryTaskMetricMeta( + @PathVariable String taskId) { + return GeaflowApiResponse.success(taskManager.queryMetricMeta(taskId)); + } + + @PostMapping("/tasks/{taskId}/metrics") + public GeaflowApiResponse> queryTaskMetrics( + @PathVariable String taskId, @RequestBody GeaflowMetricQueryRequest queryRequest) { + return GeaflowApiResponse.success(taskManager.queryMetrics(taskId, queryRequest)); + } + + @GetMapping("/tasks/{taskId}/offsets") + public GeaflowApiResponse> queryTaskOffsets(@PathVariable String taskId) { + return GeaflowApiResponse.success(taskManager.queryOffsets(taskId)); + } + + @GetMapping("tasks/{taskId}/heartbeat") + public GeaflowApiResponse queryTaskHeartbeat(@PathVariable String taskId) { + return GeaflowApiResponse.success(taskManager.queryHeartbeat(taskId)); + } + + @GetMapping("tasks/{taskId}/logs") + public GeaflowApiResponse getLogs(@PathVariable String taskId) { + return GeaflowApiResponse.success(taskManager.getLogs(taskId)); + } + + @PostMapping("/tasks/{taskId}/startup-notify") + public GeaflowApiResponse startupNotify( + @PathVariable String taskId, @RequestBody TaskStartupNotifyView startupNotifyView) { + taskManager.startupNotify(taskId, startupNotifyView); + return GeaflowApiResponse.success(null); + } + + @GetMapping("/tasks/{taskId}/files") + public void downloadTaskFile( + HttpServletResponse response, @PathVariable String taskId, @RequestParam String path) { + taskManager.download(taskId, path, response); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/TenantController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/TenantController.java index d190da3bf..7e68cecdc 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/TenantController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/TenantController.java @@ -39,41 +39,39 @@ import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; - @RestController @RequestMapping("/tenants") public class TenantController { - @Autowired - private TenantManager tenantManager; + @Autowired private TenantManager tenantManager; - @Autowired - private AuthorizationManager authorizationManager; + @Autowired private AuthorizationManager authorizationManager; - @GetMapping - public GeaflowApiResponse> searchTenants(TenantSearch search) { - return GeaflowApiResponse.success(tenantManager.search(search)); - } + @GetMapping + public GeaflowApiResponse> searchTenants(TenantSearch search) { + return GeaflowApiResponse.success(tenantManager.search(search)); + } - @GetMapping("/{tenantId}") - public GeaflowApiResponse getTenant(@PathVariable String tenantId) { - return GeaflowApiResponse.success(tenantManager.get(tenantId)); - } + @GetMapping("/{tenantId}") + public GeaflowApiResponse getTenant(@PathVariable String tenantId) { + return GeaflowApiResponse.success(tenantManager.get(tenantId)); + } - @PostMapping - public GeaflowApiResponse createTenant(@RequestBody TenantView tenantView) { - authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); - return GeaflowApiResponse.success(tenantManager.create(tenantView)); - } + @PostMapping + public GeaflowApiResponse createTenant(@RequestBody TenantView tenantView) { + authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); + return GeaflowApiResponse.success(tenantManager.create(tenantView)); + } - @PutMapping("/{tenantId}") - public GeaflowApiResponse updateTenant(@PathVariable String tenantId, @RequestBody TenantView tenantView) { - authorizationManager.hasAuthority(GeaflowAuthority.UPDATE, Resources.tenant(tenantId)); - return GeaflowApiResponse.success(tenantManager.updateById(tenantId, tenantView)); - } + @PutMapping("/{tenantId}") + public GeaflowApiResponse updateTenant( + @PathVariable String tenantId, @RequestBody TenantView tenantView) { + authorizationManager.hasAuthority(GeaflowAuthority.UPDATE, Resources.tenant(tenantId)); + return GeaflowApiResponse.success(tenantManager.updateById(tenantId, tenantView)); + } - @DeleteMapping("/{tenantId}") - public GeaflowApiResponse deleteTenant(@PathVariable String tenantId) { - throw new GeaflowException("Delete tenant {} not allowed", tenantId); - } + @DeleteMapping("/{tenantId}") + public GeaflowApiResponse deleteTenant(@PathVariable String tenantId) { + throw new GeaflowException("Delete tenant {} not allowed", tenantId); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/UserController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/UserController.java index b5594f93e..44b85d461 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/UserController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/UserController.java @@ -35,41 +35,38 @@ import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; - @RestController @RequestMapping("/users") public class UserController { - @Autowired - private UserManager userManager; + @Autowired private UserManager userManager; - @Autowired - private AuthorizationManager authorizationManager; + @Autowired private AuthorizationManager authorizationManager; - @GetMapping - public GeaflowApiResponse> searchUsers(UserSearch search) { - return GeaflowApiResponse.success(userManager.search(search)); - } + @GetMapping + public GeaflowApiResponse> searchUsers(UserSearch search) { + return GeaflowApiResponse.success(userManager.search(search)); + } - @GetMapping("/{userId}") - public GeaflowApiResponse getUser(@PathVariable String userId) { - return GeaflowApiResponse.success(userManager.getUser(userId)); - } + @GetMapping("/{userId}") + public GeaflowApiResponse getUser(@PathVariable String userId) { + return GeaflowApiResponse.success(userManager.getUser(userId)); + } - @PostMapping - public GeaflowApiResponse addUser(UserView userView) { - authorizationManager.hasRole(GeaflowRole.TENANT_ADMIN); - return GeaflowApiResponse.success(userManager.addUser(userView)); - } + @PostMapping + public GeaflowApiResponse addUser(UserView userView) { + authorizationManager.hasRole(GeaflowRole.TENANT_ADMIN); + return GeaflowApiResponse.success(userManager.addUser(userView)); + } - @PutMapping("/{userId}") - public GeaflowApiResponse updateUser(@PathVariable String userId, UserView userView) { - return GeaflowApiResponse.success(userManager.updateUser(userId, userView)); - } + @PutMapping("/{userId}") + public GeaflowApiResponse updateUser(@PathVariable String userId, UserView userView) { + return GeaflowApiResponse.success(userManager.updateUser(userId, userView)); + } - @DeleteMapping("/{userId}") - public GeaflowApiResponse deleteUser(@PathVariable String userId) { - authorizationManager.hasRole(GeaflowRole.TENANT_ADMIN); - return GeaflowApiResponse.success(userManager.deleteUser(userId)); - } + @DeleteMapping("/{userId}") + public GeaflowApiResponse deleteUser(@PathVariable String userId) { + authorizationManager.hasRole(GeaflowRole.TENANT_ADMIN); + return GeaflowApiResponse.success(userManager.deleteUser(userId)); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/VersionController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/VersionController.java index e1ea530d8..6c51e6520 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/VersionController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/VersionController.java @@ -37,47 +37,48 @@ import org.springframework.web.bind.annotation.RestController; import org.springframework.web.multipart.MultipartFile; - @RestController @RequestMapping("/versions") public class VersionController { - @Autowired - private VersionManager versionManager; - - @Autowired - private AuthorizationManager authorizationManager; + @Autowired private VersionManager versionManager; - @GetMapping - public GeaflowApiResponse> searchVersions(VersionSearch search) { - return GeaflowApiResponse.success(versionManager.searchVersions(search)); - } + @Autowired private AuthorizationManager authorizationManager; - @GetMapping("/{versionName}") - public GeaflowApiResponse getVersion(@PathVariable String versionName) { - return GeaflowApiResponse.success(versionManager.getVersion(versionName)); - } + @GetMapping + public GeaflowApiResponse> searchVersions(VersionSearch search) { + return GeaflowApiResponse.success(versionManager.searchVersions(search)); + } - @PostMapping - public GeaflowApiResponse createVersion(VersionView view, - @RequestParam(required = false) MultipartFile engineJarFile, - @RequestParam(required = false) MultipartFile langJarFile) { - authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); - return GeaflowApiResponse.success(versionManager.createVersion(view, engineJarFile, langJarFile)); - } + @GetMapping("/{versionName}") + public GeaflowApiResponse getVersion(@PathVariable String versionName) { + return GeaflowApiResponse.success(versionManager.getVersion(versionName)); + } - @PutMapping("/{versionName}") - public GeaflowApiResponse updateVersion(@PathVariable String versionName, VersionView view, - @RequestParam(required = false) MultipartFile engineJarFile, - @RequestParam(required = false) MultipartFile langJarFile) { - authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); - return GeaflowApiResponse.success(versionManager.updateVersion(versionName, view, engineJarFile, langJarFile)); - } + @PostMapping + public GeaflowApiResponse createVersion( + VersionView view, + @RequestParam(required = false) MultipartFile engineJarFile, + @RequestParam(required = false) MultipartFile langJarFile) { + authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); + return GeaflowApiResponse.success( + versionManager.createVersion(view, engineJarFile, langJarFile)); + } - @DeleteMapping("/{versionName}") - public GeaflowApiResponse deleteVersion(@PathVariable String versionName) { - authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); - return GeaflowApiResponse.success(versionManager.deleteVersion(versionName)); - } + @PutMapping("/{versionName}") + public GeaflowApiResponse updateVersion( + @PathVariable String versionName, + VersionView view, + @RequestParam(required = false) MultipartFile engineJarFile, + @RequestParam(required = false) MultipartFile langJarFile) { + authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); + return GeaflowApiResponse.success( + versionManager.updateVersion(versionName, view, engineJarFile, langJarFile)); + } + @DeleteMapping("/{versionName}") + public GeaflowApiResponse deleteVersion(@PathVariable String versionName) { + authorizationManager.hasRole(GeaflowRole.SYSTEM_ADMIN); + return GeaflowApiResponse.success(versionManager.deleteVersion(versionName)); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/VertexController.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/VertexController.java index b3366087b..7f9e9226c 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/VertexController.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/controller/api/VertexController.java @@ -38,44 +38,46 @@ @RequestMapping public class VertexController { - @Autowired - private VertexManager vertexManager; + @Autowired private VertexManager vertexManager; - @GetMapping("/vertices") - public GeaflowApiResponse> search(VertexSearch search) { - return GeaflowApiResponse.success(vertexManager.search(search)); - } + @GetMapping("/vertices") + public GeaflowApiResponse> search(VertexSearch search) { + return GeaflowApiResponse.success(vertexManager.search(search)); + } - @GetMapping("/instances/{instanceName}/vertices") - public GeaflowApiResponse> instanceSearch(@PathVariable("instanceName") String instanceName, - VertexSearch search) { - return GeaflowApiResponse.success(vertexManager.searchByInstanceName(instanceName, search)); - } + @GetMapping("/instances/{instanceName}/vertices") + public GeaflowApiResponse> instanceSearch( + @PathVariable("instanceName") String instanceName, VertexSearch search) { + return GeaflowApiResponse.success(vertexManager.searchByInstanceName(instanceName, search)); + } - @GetMapping("/instances/{instanceName}/vertices/{vertexName}") - public GeaflowApiResponse getVertex(@PathVariable("instanceName") String instanceName, - @PathVariable("vertexName") String vertexName) { - return GeaflowApiResponse.success(vertexManager.getByName(instanceName, vertexName)); - } + @GetMapping("/instances/{instanceName}/vertices/{vertexName}") + public GeaflowApiResponse getVertex( + @PathVariable("instanceName") String instanceName, + @PathVariable("vertexName") String vertexName) { + return GeaflowApiResponse.success(vertexManager.getByName(instanceName, vertexName)); + } - @PostMapping("/instances/{instanceName}/vertices") - public GeaflowApiResponse create(@PathVariable("instanceName") String instanceName, - @RequestBody VertexView vertexView) { - return GeaflowApiResponse.success(vertexManager.create(instanceName, vertexView)); - } + @PostMapping("/instances/{instanceName}/vertices") + public GeaflowApiResponse create( + @PathVariable("instanceName") String instanceName, @RequestBody VertexView vertexView) { + return GeaflowApiResponse.success(vertexManager.create(instanceName, vertexView)); + } - @PutMapping("/instances/{instanceName}/vertices/{vertexName}") - public GeaflowApiResponse update(@PathVariable("instanceName") String instanceName, - @PathVariable("vertexName") String vertexName, - @RequestBody VertexView vertexView) { - return GeaflowApiResponse.success(vertexManager.updateByName(instanceName, vertexName, vertexView)); - } + @PutMapping("/instances/{instanceName}/vertices/{vertexName}") + public GeaflowApiResponse update( + @PathVariable("instanceName") String instanceName, + @PathVariable("vertexName") String vertexName, + @RequestBody VertexView vertexView) { + return GeaflowApiResponse.success( + vertexManager.updateByName(instanceName, vertexName, vertexView)); + } + @DeleteMapping("/instances/{instanceName}/vertices/{vertexName}") + public GeaflowApiResponse drop( + @PathVariable("instanceName") String instanceName, + @PathVariable("vertexName") String vertexName) { - @DeleteMapping("/instances/{instanceName}/vertices/{vertexName}") - public GeaflowApiResponse drop(@PathVariable("instanceName") String instanceName, - @PathVariable("vertexName") String vertexName) { - - return GeaflowApiResponse.success(vertexManager.dropByName(instanceName, vertexName)); - } + return GeaflowApiResponse.success(vertexManager.dropByName(instanceName, vertexName)); + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/mvc/GeaflowAuthInterceptor.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/mvc/GeaflowAuthInterceptor.java index a9d4a873d..1caf0e640 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/mvc/GeaflowAuthInterceptor.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/mvc/GeaflowAuthInterceptor.java @@ -19,10 +19,9 @@ package org.apache.geaflow.console.web.mvc; -import com.google.common.base.Preconditions; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import lombok.extern.slf4j.Slf4j; + import org.apache.geaflow.console.biz.shared.AuthenticationManager; import org.apache.geaflow.console.biz.shared.AuthorizationManager; import org.apache.geaflow.console.biz.shared.TenantManager; @@ -39,64 +38,66 @@ import org.springframework.web.servlet.HandlerInterceptor; import org.springframework.web.servlet.ModelAndView; +import com.google.common.base.Preconditions; + +import lombok.extern.slf4j.Slf4j; + @Component @Slf4j public class GeaflowAuthInterceptor implements HandlerInterceptor { - @Autowired - private TenantManager tenantManager; - - @Autowired - private AuthenticationManager authenticationManager; - - @Autowired - private AuthorizationManager authorizationManager; - - @Autowired - private TaskService taskService; - - @Override - public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) - throws Exception { - if (request.getMethod().equals("OPTIONS")) { - return true; - } - - GeaflowContext context = ContextHolder.get(); - - String token = GeaflowApiRequest.getSessionToken(request); - - if (TokenGenerator.isTaskToken(token)) { - // The request is from tasks - GeaflowId task = taskService.getByTaskToken(token); - Preconditions.checkNotNull(task, "Invalid task token %s", token); - context.setTaskId(task.getId()); - context.setUserId(task.getModifierId()); - context.setSystemSession(false); - context.setSessionToken(token); - context.setTenantId(task.getTenantId()); - context.getRoleTypes().addAll(authorizationManager.getUserRoleTypes(task.getModifierId())); - return true; - } - - AuthenticationView authentication = authenticationManager.authenticate(token); - String userId = authentication.getUserId(); - boolean systemSession = authentication.isSystemSession(); - context.setUserId(userId); - context.setSystemSession(systemSession); - context.setSessionToken(token); - if (!systemSession) { - TenantView tenant = tenantManager.getActiveTenant(userId); - context.setTenantId(tenant.getId()); - } - - context.getRoleTypes().addAll(authorizationManager.getUserRoleTypes(userId)); - return true; + @Autowired private TenantManager tenantManager; + + @Autowired private AuthenticationManager authenticationManager; + + @Autowired private AuthorizationManager authorizationManager; + + @Autowired private TaskService taskService; + + @Override + public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) + throws Exception { + if (request.getMethod().equals("OPTIONS")) { + return true; } - @Override - public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, - ModelAndView modelAndView) throws Exception { + GeaflowContext context = ContextHolder.get(); + + String token = GeaflowApiRequest.getSessionToken(request); + if (TokenGenerator.isTaskToken(token)) { + // The request is from tasks + GeaflowId task = taskService.getByTaskToken(token); + Preconditions.checkNotNull(task, "Invalid task token %s", token); + context.setTaskId(task.getId()); + context.setUserId(task.getModifierId()); + context.setSystemSession(false); + context.setSessionToken(token); + context.setTenantId(task.getTenantId()); + context.getRoleTypes().addAll(authorizationManager.getUserRoleTypes(task.getModifierId())); + return true; } + + AuthenticationView authentication = authenticationManager.authenticate(token); + String userId = authentication.getUserId(); + boolean systemSession = authentication.isSystemSession(); + context.setUserId(userId); + context.setSystemSession(systemSession); + context.setSessionToken(token); + if (!systemSession) { + TenantView tenant = tenantManager.getActiveTenant(userId); + context.setTenantId(tenant.getId()); + } + + context.getRoleTypes().addAll(authorizationManager.getUserRoleTypes(userId)); + return true; + } + + @Override + public void postHandle( + HttpServletRequest request, + HttpServletResponse response, + Object handler, + ModelAndView modelAndView) + throws Exception {} } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/mvc/GeaflowGlobalFilter.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/mvc/GeaflowGlobalFilter.java index af650c264..c38999e2f 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/mvc/GeaflowGlobalFilter.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/mvc/GeaflowGlobalFilter.java @@ -20,6 +20,7 @@ package org.apache.geaflow.console.web.mvc; import java.io.IOException; + import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.ServletException; @@ -28,40 +29,43 @@ import javax.servlet.annotation.WebFilter; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import lombok.extern.slf4j.Slf4j; + import org.apache.geaflow.console.common.util.Fmt; import org.apache.geaflow.console.common.util.context.ContextHolder; import org.apache.geaflow.console.web.api.GeaflowApiResponse; import org.springframework.core.Ordered; +import lombok.extern.slf4j.Slf4j; + @Slf4j @WebFilter(urlPatterns = {"/auth/*", "/api/*"}) public class GeaflowGlobalFilter implements Filter, Ordered { - @Override - public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) - throws IOException, ServletException { - HttpServletRequest request = (HttpServletRequest) servletRequest; - HttpServletResponse response = (HttpServletResponse) servletResponse; + @Override + public void doFilter( + ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) + throws IOException, ServletException { + HttpServletRequest request = (HttpServletRequest) servletRequest; + HttpServletResponse response = (HttpServletResponse) servletResponse; - try { - ContextHolder.init(); - ContextHolder.get().setRequest(request); - filterChain.doFilter(request, response); + try { + ContextHolder.init(); + ContextHolder.get().setRequest(request); + filterChain.doFilter(request, response); - } catch (Exception e) { - String message = Fmt.as("Request url {} failed", request.getRequestURI()); - log.info(message); - log.error(message, e.getCause()); - GeaflowApiResponse.error(e.getCause()).write(response); + } catch (Exception e) { + String message = Fmt.as("Request url {} failed", request.getRequestURI()); + log.info(message); + log.error(message, e.getCause()); + GeaflowApiResponse.error(e.getCause()).write(response); - } finally { - ContextHolder.destroy(); - } + } finally { + ContextHolder.destroy(); } + } - @Override - public int getOrder() { - return 0; - } + @Override + public int getOrder() { + return 0; + } } diff --git a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/mvc/GeaflowWebConfig.java b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/mvc/GeaflowWebConfig.java index c4d8c0a18..05803f8d9 100644 --- a/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/mvc/GeaflowWebConfig.java +++ b/geaflow-console/app/web/src/main/java/org/apache/geaflow/console/web/mvc/GeaflowWebConfig.java @@ -19,7 +19,6 @@ package org.apache.geaflow.console.web.mvc; -import com.alibaba.fastjson.support.spring.FastJsonHttpMessageConverter; import org.apache.geaflow.console.common.util.context.GeaflowContext; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.autoconfigure.http.HttpMessageConverters; @@ -30,31 +29,40 @@ import org.springframework.web.servlet.config.annotation.PathMatchConfigurer; import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; +import com.alibaba.fastjson.support.spring.FastJsonHttpMessageConverter; + @Configuration public class GeaflowWebConfig implements WebMvcConfigurer { - @Autowired - private GeaflowAuthInterceptor authenticateInterceptor; - - @Override - public void configurePathMatch(PathMatchConfigurer configurer) { - configurer.addPathPrefix(GeaflowContext.API_PREFIX, - c -> c.getPackage().getName().startsWith("org.apache.geaflow.console.web.controller.api")); - } - - @Override - public void addInterceptors(InterceptorRegistry registry) { - registry.addInterceptor(authenticateInterceptor).addPathPatterns(GeaflowContext.API_PREFIX + "/**"); - } - - @Override - public void addCorsMappings(CorsRegistry registry) { - registry.addMapping("/**").allowedHeaders("*").allowedMethods("*").allowedOriginPatterns("*") - .allowCredentials(true).maxAge(3600); - } - - @Bean - public HttpMessageConverters fastJsonConfigure() { - return new HttpMessageConverters(new FastJsonHttpMessageConverter()); - } + @Autowired private GeaflowAuthInterceptor authenticateInterceptor; + + @Override + public void configurePathMatch(PathMatchConfigurer configurer) { + configurer.addPathPrefix( + GeaflowContext.API_PREFIX, + c -> c.getPackage().getName().startsWith("org.apache.geaflow.console.web.controller.api")); + } + + @Override + public void addInterceptors(InterceptorRegistry registry) { + registry + .addInterceptor(authenticateInterceptor) + .addPathPatterns(GeaflowContext.API_PREFIX + "/**"); + } + + @Override + public void addCorsMappings(CorsRegistry registry) { + registry + .addMapping("/**") + .allowedHeaders("*") + .allowedMethods("*") + .allowedOriginPatterns("*") + .allowCredentials(true) + .maxAge(3600); + } + + @Bean + public HttpMessageConverters fastJsonConfigure() { + return new HttpMessageConverters(new FastJsonHttpMessageConverter()); + } } diff --git a/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/GeaFlowMcpActions.java b/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/GeaFlowMcpActions.java index 4c66fedbc..dd8024340 100644 --- a/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/GeaFlowMcpActions.java +++ b/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/GeaFlowMcpActions.java @@ -21,13 +21,13 @@ public interface GeaFlowMcpActions { - String createGraph(String graphName, String ddl); + String createGraph(String graphName, String ddl); - String queryGraph(String graphName, String gql); + String queryGraph(String graphName, String gql); - String queryType(String graphName, String type); + String queryType(String graphName, String type); - String getGraphSchema(String graphName); + String getGraphSchema(String graphName); - void withUser(String user); + void withUser(String user); } diff --git a/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/GeaFlowMcpActionsLocalImpl.java b/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/GeaFlowMcpActionsLocalImpl.java index eefb2e3cb..e23cabcf7 100644 --- a/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/GeaFlowMcpActionsLocalImpl.java +++ b/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/GeaFlowMcpActionsLocalImpl.java @@ -27,6 +27,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; + import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.cluster.local.client.LocalEnvironment; @@ -43,179 +44,188 @@ public class GeaFlowMcpActionsLocalImpl implements GeaFlowMcpActions { - private static final Logger LOGGER = LoggerFactory.getLogger(GeaFlowMcpActionsLocalImpl.class); + private static final Logger LOGGER = LoggerFactory.getLogger(GeaFlowMcpActionsLocalImpl.class); - private Map configs; - private String user; - private IEnvironment localEnv = new LocalEnvironment(); + private Map configs; + private String user; + private IEnvironment localEnv = new LocalEnvironment(); - public GeaFlowMcpActionsLocalImpl(Map configs) { - this.configs = configs; - } + public GeaFlowMcpActionsLocalImpl(Map configs) { + this.configs = configs; + } - @Override - public String createGraph(String graphName, String ddl) { - QueryLocalRunner runner = new QueryLocalRunner(); - runner.withGraphName(graphName).withGraphDefine(ddl); - GeaFlowGraph graph; - try { - graph = runner.compileGraph(); - } catch (Throwable e) { - LOGGER.error("Compile error: " + e.getCause().getMessage()); - throw new GeaflowRuntimeException("Compile error: " + e.getCause().getMessage()); - } - if (graph == null) { - LOGGER.error("Cannot create graph: " + graphName); - throw new GeaflowRuntimeException("Cannot create graph: " + graphName); - } - //Store graph ddl to schema - try { - McpLocalFileUtil.createAndWriteFile( - QueryLocalRunner.DSL_STATE_REMOTE_SCHEM_PATH, ddl, graphName); - } catch (Throwable e) { - return runner.getErrorMsg(); - } - return "Create graph " + graphName + " success."; + @Override + public String createGraph(String graphName, String ddl) { + QueryLocalRunner runner = new QueryLocalRunner(); + runner.withGraphName(graphName).withGraphDefine(ddl); + GeaFlowGraph graph; + try { + graph = runner.compileGraph(); + } catch (Throwable e) { + LOGGER.error("Compile error: " + e.getCause().getMessage()); + throw new GeaflowRuntimeException("Compile error: " + e.getCause().getMessage()); } + if (graph == null) { + LOGGER.error("Cannot create graph: " + graphName); + throw new GeaflowRuntimeException("Cannot create graph: " + graphName); + } + // Store graph ddl to schema + try { + McpLocalFileUtil.createAndWriteFile( + QueryLocalRunner.DSL_STATE_REMOTE_SCHEM_PATH, ddl, graphName); + } catch (Throwable e) { + return runner.getErrorMsg(); + } + return "Create graph " + graphName + " success."; + } - @Override - public String queryGraph(String graphName, String dml) { - QueryLocalRunner compileRunner = new QueryLocalRunner(); - String ddl = null; - try { - ddl = McpLocalFileUtil.readFile(QueryLocalRunner.DSL_STATE_REMOTE_SCHEM_PATH, graphName); - } catch (Throwable e) { - LOGGER.error("Cannot get graph schema for: " + graphName); - throw new GeaflowRuntimeException("Cannot get graph schema for: " + graphName); - } - compileRunner.withGraphName(graphName).withGraphDefine(ddl); - GeaFlowGraph graph; - try { - graph = compileRunner.compileGraph(); - } catch (Throwable e) { - LOGGER.error("Compile error: " + compileRunner.getErrorMsg()); - throw new GeaflowRuntimeException("Compile error: " + compileRunner.getErrorMsg()); - } - if (graph == null) { - LOGGER.error("Cannot create graph: " + graphName); - throw new GeaflowRuntimeException("Cannot create graph: " + graphName); - } - - QueryLocalRunner runner = new QueryLocalRunner(); - runner.withGraphDefine(ddl); - runner.withQuery(ddl + "\n" + QueryFormatUtil.makeUseGraph(graphName) + dml); - try { - runner.execute(); - } catch (Throwable e) { - LOGGER.error("Run query error: " + e.getCause().getMessage()); - throw new GeaflowRuntimeException("Run query error: " + e.getCause().getMessage()); - } - return "run query success: " + dml; + @Override + public String queryGraph(String graphName, String dml) { + QueryLocalRunner compileRunner = new QueryLocalRunner(); + String ddl = null; + try { + ddl = McpLocalFileUtil.readFile(QueryLocalRunner.DSL_STATE_REMOTE_SCHEM_PATH, graphName); + } catch (Throwable e) { + LOGGER.error("Cannot get graph schema for: " + graphName); + throw new GeaflowRuntimeException("Cannot get graph schema for: " + graphName); + } + compileRunner.withGraphName(graphName).withGraphDefine(ddl); + GeaFlowGraph graph; + try { + graph = compileRunner.compileGraph(); + } catch (Throwable e) { + LOGGER.error("Compile error: " + compileRunner.getErrorMsg()); + throw new GeaflowRuntimeException("Compile error: " + compileRunner.getErrorMsg()); + } + if (graph == null) { + LOGGER.error("Cannot create graph: " + graphName); + throw new GeaflowRuntimeException("Cannot create graph: " + graphName); } - @Override - public String queryType(String graphName, String type) { - QueryLocalRunner compileRunner = new QueryLocalRunner(); - String ddl = null; - try { - ddl = McpLocalFileUtil.readFile(QueryLocalRunner.DSL_STATE_REMOTE_SCHEM_PATH, graphName); - } catch (Throwable e) { - LOGGER.error("Cannot get graph schema for: " + graphName); - return "Cannot get graph schema for: " + graphName; - } - compileRunner.withGraphName(graphName).withGraphDefine(ddl); - GeaFlowGraph graph; - try { - graph = compileRunner.compileGraph(); - } catch (Throwable e) { - return compileRunner.getErrorMsg(); - } - if (graph == null) { - LOGGER.error("Cannot create graph: " + graphName); - throw new GeaflowRuntimeException("Cannot create graph: " + graphName); - } + QueryLocalRunner runner = new QueryLocalRunner(); + runner.withGraphDefine(ddl); + runner.withQuery(ddl + "\n" + QueryFormatUtil.makeUseGraph(graphName) + dml); + try { + runner.execute(); + } catch (Throwable e) { + LOGGER.error("Run query error: " + e.getCause().getMessage()); + throw new GeaflowRuntimeException("Run query error: " + e.getCause().getMessage()); + } + return "run query success: " + dml; + } - QueryLocalRunner runner = new QueryLocalRunner(); - runner.withGraphDefine(ddl); - String dql = null; - String dirName = "query_result_" + Instant.now().toEpochMilli(); - String resultPath = QueryLocalRunner.DSL_STATE_REMOTE_PATH + "/" + dirName; - GeaFlowTable resultTable = null; - for (GeaFlowGraph.VertexTable vertexTable : graph.getVertexTables()) { - if (vertexTable.getTypeName().equals(type)) { - dql = QueryFormatUtil.makeResultTable(vertexTable, resultPath) - + "\n" + QueryFormatUtil.makeEntityTableQuery(vertexTable); - resultTable = vertexTable; - } - } - for (GeaFlowGraph.EdgeTable edgeTable : graph.getEdgeTables()) { - if (edgeTable.getTypeName().equals(type)) { - dql = QueryFormatUtil.makeResultTable(edgeTable, resultPath) - + "\n" + QueryFormatUtil.makeEntityTableQuery(edgeTable); - resultTable = edgeTable; - } - } - if (resultTable == null) { - LOGGER.error("Cannot find type: " + type + " in graph: " + graphName); - throw new GeaflowRuntimeException("Cannot find type: " + type + " in graph: " + graphName); - } - runner.withQuery(ddl + "\n" + QueryFormatUtil.makeUseGraph(graphName) + dql); - String resultContent = "null"; - try { - runner.execute(); - resultContent = readFile(resultPath); - } catch (Throwable e) { - return runner.getErrorMsg(); - } - String schemaContent = "type: " + type + "\nschema: " + resultTable.getFields().stream() - .map(TableField::getName).collect(Collectors.joining("|")); - return schemaContent + "\n" + resultContent; + @Override + public String queryType(String graphName, String type) { + QueryLocalRunner compileRunner = new QueryLocalRunner(); + String ddl = null; + try { + ddl = McpLocalFileUtil.readFile(QueryLocalRunner.DSL_STATE_REMOTE_SCHEM_PATH, graphName); + } catch (Throwable e) { + LOGGER.error("Cannot get graph schema for: " + graphName); + return "Cannot get graph schema for: " + graphName; + } + compileRunner.withGraphName(graphName).withGraphDefine(ddl); + GeaFlowGraph graph; + try { + graph = compileRunner.compileGraph(); + } catch (Throwable e) { + return compileRunner.getErrorMsg(); + } + if (graph == null) { + LOGGER.error("Cannot create graph: " + graphName); + throw new GeaflowRuntimeException("Cannot create graph: " + graphName); } - @Override - public String getGraphSchema(String graphName) { - String ddl = null; - try { - ddl = McpLocalFileUtil.readFile(QueryLocalRunner.DSL_STATE_REMOTE_SCHEM_PATH, graphName); - } catch (Throwable e) { - LOGGER.error("Cannot get graph schema for: " + graphName); - return "Cannot get graph schema for: " + graphName; - } - return ddl; + QueryLocalRunner runner = new QueryLocalRunner(); + runner.withGraphDefine(ddl); + String dql = null; + String dirName = "query_result_" + Instant.now().toEpochMilli(); + String resultPath = QueryLocalRunner.DSL_STATE_REMOTE_PATH + "/" + dirName; + GeaFlowTable resultTable = null; + for (GeaFlowGraph.VertexTable vertexTable : graph.getVertexTables()) { + if (vertexTable.getTypeName().equals(type)) { + dql = + QueryFormatUtil.makeResultTable(vertexTable, resultPath) + + "\n" + + QueryFormatUtil.makeEntityTableQuery(vertexTable); + resultTable = vertexTable; + } + } + for (GeaFlowGraph.EdgeTable edgeTable : graph.getEdgeTables()) { + if (edgeTable.getTypeName().equals(type)) { + dql = + QueryFormatUtil.makeResultTable(edgeTable, resultPath) + + "\n" + + QueryFormatUtil.makeEntityTableQuery(edgeTable); + resultTable = edgeTable; + } + } + if (resultTable == null) { + LOGGER.error("Cannot find type: " + type + " in graph: " + graphName); + throw new GeaflowRuntimeException("Cannot find type: " + type + " in graph: " + graphName); } + runner.withQuery(ddl + "\n" + QueryFormatUtil.makeUseGraph(graphName) + dql); + String resultContent = "null"; + try { + runner.execute(); + resultContent = readFile(resultPath); + } catch (Throwable e) { + return runner.getErrorMsg(); + } + String schemaContent = + "type: " + + type + + "\nschema: " + + resultTable.getFields().stream() + .map(TableField::getName) + .collect(Collectors.joining("|")); + return schemaContent + "\n" + resultContent; + } - @Override - public void withUser(String user) { - this.user = user; + @Override + public String getGraphSchema(String graphName) { + String ddl = null; + try { + ddl = McpLocalFileUtil.readFile(QueryLocalRunner.DSL_STATE_REMOTE_SCHEM_PATH, graphName); + } catch (Throwable e) { + LOGGER.error("Cannot get graph schema for: " + graphName); + return "Cannot get graph schema for: " + graphName; } + return ddl; + } - private String readFile(String path) throws IOException { - File file = new File(path); - if (file.isHidden()) { - return ""; - } - if (file.isFile()) { - return IOUtils.toString(new File(path).toURI(), Charset.defaultCharset()).trim(); - } - File[] files = file.listFiles(); - StringBuilder content = new StringBuilder(); - List readTextList = new ArrayList<>(); - if (files != null) { - for (File subFile : files) { - String readText = readFile(subFile.getAbsolutePath()); - if (StringUtils.isBlank(readText)) { - continue; - } - readTextList.add(readText); - } - } - readTextList = readTextList.stream().sorted().collect(Collectors.toList()); - for (String readText : readTextList) { - if (content.length() > 0) { - content.append("\n"); - } - content.append(readText); - } - return content.toString().trim(); + @Override + public void withUser(String user) { + this.user = user; + } + + private String readFile(String path) throws IOException { + File file = new File(path); + if (file.isHidden()) { + return ""; + } + if (file.isFile()) { + return IOUtils.toString(new File(path).toURI(), Charset.defaultCharset()).trim(); + } + File[] files = file.listFiles(); + StringBuilder content = new StringBuilder(); + List readTextList = new ArrayList<>(); + if (files != null) { + for (File subFile : files) { + String readText = readFile(subFile.getAbsolutePath()); + if (StringUtils.isBlank(readText)) { + continue; + } + readTextList.add(readText); + } + } + readTextList = readTextList.stream().sorted().collect(Collectors.toList()); + for (String readText : readTextList) { + if (content.length() > 0) { + content.append("\n"); + } + content.append(readText); } + return content.toString().trim(); + } } diff --git a/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/GeaFlowMcpServer.java b/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/GeaFlowMcpServer.java index d580c9f01..24f3d2cd4 100644 --- a/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/GeaFlowMcpServer.java +++ b/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/GeaFlowMcpServer.java @@ -25,26 +25,28 @@ import org.noear.solon.ai.mcp.server.resource.MethodResourceProvider; import org.noear.solon.annotation.Controller; -/** - * Main entrance, support jdk8+ environment. - */ +/** Main entrance, support jdk8+ environment. */ @Controller public class GeaFlowMcpServer { - private static final String SERVER_NAME = "geaflow-mcp-server"; - private static final String SSE_CHANNEL = "sse"; - private static final String SSE_ENDPOINT = "/geaflow/sse"; + private static final String SERVER_NAME = "geaflow-mcp-server"; + private static final String SSE_CHANNEL = "sse"; + private static final String SSE_ENDPOINT = "/geaflow/sse"; - public static void main(String[] args) { - Solon.start(GeaFlowMcpServer.class, args, app -> { - // Manually build the mcp service endpoint. - McpServerEndpointProvider endpointProvider = McpServerEndpointProvider.builder() - .name(SERVER_NAME) - .channel(SSE_CHANNEL) - .sseEndpoint(SSE_ENDPOINT) - .build(); - endpointProvider.addTool(new MethodToolProvider(new GeaFlowMcpServerTools())); - endpointProvider.addResource(new MethodResourceProvider(new GeaFlowMcpServerTools())); - endpointProvider.postStart(); + public static void main(String[] args) { + Solon.start( + GeaFlowMcpServer.class, + args, + app -> { + // Manually build the mcp service endpoint. + McpServerEndpointProvider endpointProvider = + McpServerEndpointProvider.builder() + .name(SERVER_NAME) + .channel(SSE_CHANNEL) + .sseEndpoint(SSE_ENDPOINT) + .build(); + endpointProvider.addTool(new MethodToolProvider(new GeaFlowMcpServerTools())); + endpointProvider.addResource(new MethodResourceProvider(new GeaFlowMcpServerTools())); + endpointProvider.postStart(); }); - } + } } diff --git a/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/GeaFlowMcpServerTools.java b/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/GeaFlowMcpServerTools.java index e19625d94..3af7748d1 100644 --- a/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/GeaFlowMcpServerTools.java +++ b/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/GeaFlowMcpServerTools.java @@ -19,9 +19,8 @@ package org.apache.geaflow.mcp.server; -import com.alibaba.fastjson.JSON; -import com.alibaba.fastjson.JSONObject; import java.util.Map; + import org.apache.geaflow.analytics.service.client.AnalyticsClient; import org.apache.geaflow.analytics.service.client.AnalyticsClientBuilder; import org.apache.geaflow.analytics.service.query.QueryResults; @@ -35,177 +34,183 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONObject; + @McpServerEndpoint(name = "geaflow-mcp-server", channel = "sse", sseEndpoint = "/geaflow/sse") public class GeaFlowMcpServerTools { - private static final Logger LOGGER = LoggerFactory.getLogger(GeaFlowMcpServerTools.class); - - private static final String RETRY_TIMES = "analytics.retry.times"; - private static final int DEFAULT_RETRY_TIMES = 3; - private static final String ERROR = "error"; - public static final String SERVER_HOST = "analytics.server.host"; - public static final String SERVER_PORT = "analytics.server.port"; - public static final String SERVER_USER = "analytics.query.user"; - public static final String QUERY_TIMEOUT_MS = "analytics.query.timeout.ms"; - public static final String INIT_CHANNEL_POOLS = "analytics.init.channel.pools"; - public static final String CONFIG = "analytics.client.config"; - public static final String CURRENT_VERSION = "v1.0.0"; - - /** - * Resource that provides getting geaflow mcp server version. - * - * @return version id. - */ - @ResourceMapping(uri = "config://mcp-server-version", description = "Get mcp server version") - public String getServerVersion() { - return CURRENT_VERSION; + private static final Logger LOGGER = LoggerFactory.getLogger(GeaFlowMcpServerTools.class); + + private static final String RETRY_TIMES = "analytics.retry.times"; + private static final int DEFAULT_RETRY_TIMES = 3; + private static final String ERROR = "error"; + public static final String SERVER_HOST = "analytics.server.host"; + public static final String SERVER_PORT = "analytics.server.port"; + public static final String SERVER_USER = "analytics.query.user"; + public static final String QUERY_TIMEOUT_MS = "analytics.query.timeout.ms"; + public static final String INIT_CHANNEL_POOLS = "analytics.init.channel.pools"; + public static final String CONFIG = "analytics.client.config"; + public static final String CURRENT_VERSION = "v1.0.0"; + + /** + * Resource that provides getting geaflow mcp server version. + * + * @return version id. + */ + @ResourceMapping(uri = "config://mcp-server-version", description = "Get mcp server version") + public String getServerVersion() { + return CURRENT_VERSION; + } + + /** + * A tool that provides graph query capabilities. + * + * @param query GQL query. + * @return query result or error code. + */ + public String executeQuery(@Param(name = "query", description = "query") String query) { + AnalyticsClient analyticsClient = null; + + try { + Map config = YamlParser.loadConfig(); + int retryTimes = DEFAULT_RETRY_TIMES; + if (config.containsKey(RETRY_TIMES)) { + retryTimes = Integer.parseInt(config.get(RETRY_TIMES).toString()); + } + + AnalyticsClientBuilder builder = + AnalyticsClient.builder() + .withHost(config.get(SERVER_HOST).toString()) + .withPort((Integer) config.get(SERVER_PORT)) + .withRetryNum(retryTimes); + if (config.containsKey(CONFIG)) { + Map clientConfig = + JSON.parseObject(config.get(CONFIG).toString(), Map.class); + Configuration configuration = new Configuration(clientConfig); + builder.withConfiguration(configuration); + LOGGER.info("client config: {}", configuration); + } + if (config.containsKey(SERVER_USER)) { + builder.withUser(config.get(SERVER_USER).toString()); + } + if (config.containsKey(QUERY_TIMEOUT_MS)) { + builder.withTimeoutMs((Integer) config.get(QUERY_TIMEOUT_MS)); + } + if (config.containsKey(INIT_CHANNEL_POOLS)) { + builder.withInitChannelPools((Boolean) config.get(INIT_CHANNEL_POOLS)); + } + analyticsClient = builder.build(); + + QueryResults queryResults = analyticsClient.executeQuery(query); + if (queryResults.getError() != null) { + final JSONObject error = new JSONObject(); + error.put(ERROR, queryResults.getError()); + return error.toJSONString(); + } + return queryResults.getFormattedData(); + } catch (Exception e) { + LOGGER.error(e.getMessage(), e); + throw new RuntimeException(e); + } finally { + if (analyticsClient != null) { + analyticsClient.shutdown(); + } } - - /** - * A tool that provides graph query capabilities. - * - * @param query GQL query. - * @return query result or error code. - */ - public String executeQuery(@Param(name = "query", description = "query") String query) { - AnalyticsClient analyticsClient = null; - - try { - Map config = YamlParser.loadConfig(); - int retryTimes = DEFAULT_RETRY_TIMES; - if (config.containsKey(RETRY_TIMES)) { - retryTimes = Integer.parseInt(config.get(RETRY_TIMES).toString()); - } - - AnalyticsClientBuilder builder = AnalyticsClient - .builder() - .withHost(config.get(SERVER_HOST).toString()) - .withPort((Integer) config.get(SERVER_PORT)) - .withRetryNum(retryTimes); - if (config.containsKey(CONFIG)) { - Map clientConfig = JSON.parseObject(config.get(CONFIG).toString(), Map.class); - Configuration configuration = new Configuration(clientConfig); - builder.withConfiguration(configuration); - LOGGER.info("client config: {}", configuration); - } - if (config.containsKey(SERVER_USER)) { - builder.withUser(config.get(SERVER_USER).toString()); - } - if (config.containsKey(QUERY_TIMEOUT_MS)) { - builder.withTimeoutMs((Integer) config.get(QUERY_TIMEOUT_MS)); - } - if (config.containsKey(INIT_CHANNEL_POOLS)) { - builder.withInitChannelPools((Boolean) config.get(INIT_CHANNEL_POOLS)); - } - analyticsClient = builder.build(); - - QueryResults queryResults = analyticsClient.executeQuery(query); - if (queryResults.getError() != null) { - final JSONObject error = new JSONObject(); - error.put(ERROR, queryResults.getError()); - return error.toJSONString(); - } - return queryResults.getFormattedData(); - } catch (Exception e) { - LOGGER.error(e.getMessage(), e); - throw new RuntimeException(e); - } finally { - if (analyticsClient != null) { - analyticsClient.shutdown(); - } - } + } + + /** + * A tool that provides create graph capabilities. + * + * @param graphName graph name to create. + * @param ddl Create graph ddl. + * @return execute result or error message. + */ + @ToolMapping(description = ToolDesc.createGraph) + public String createGraph( + @Param(name = McpConstants.GRAPH_NAME, description = "create graph name") String graphName, + @Param(name = McpConstants.DDL, description = "create graph ddl") String ddl) { + try { + Map config = YamlParser.loadConfig(); + GeaFlowMcpActions mcpActions = new GeaFlowMcpActionsLocalImpl(config); + if (config.containsKey(SERVER_USER)) { + mcpActions.withUser(config.get(SERVER_USER).toString()); + } + return mcpActions.createGraph(graphName, ddl); + } catch (Exception e) { + LOGGER.error(e.getMessage(), e); + return e.getMessage(); } - - - /** - * A tool that provides create graph capabilities. - * - * @param graphName graph name to create. - * @param ddl Create graph ddl. - * @return execute result or error message. - */ - @ToolMapping(description = ToolDesc.createGraph) - public String createGraph(@Param(name = McpConstants.GRAPH_NAME, description = "create graph name") String graphName, - @Param(name = McpConstants.DDL, description = "create graph ddl") String ddl) { - try { - Map config = YamlParser.loadConfig(); - GeaFlowMcpActions mcpActions = new GeaFlowMcpActionsLocalImpl(config); - if (config.containsKey(SERVER_USER)) { - mcpActions.withUser(config.get(SERVER_USER).toString()); - } - return mcpActions.createGraph(graphName, ddl); - } catch (Exception e) { - LOGGER.error(e.getMessage(), e); - return e.getMessage(); - } - } - - /** - * A tool that get graph schema. - * - * @param graphName graphName to get. - * @return execute result or error message. - */ - @ToolMapping(description = ToolDesc.getGraphSchema) - public String getGraphSchema(@Param(name = McpConstants.GRAPH_NAME, description = "get graph schema name") String graphName) { - try { - Map config = YamlParser.loadConfig(); - GeaFlowMcpActions mcpActions = new GeaFlowMcpActionsLocalImpl(config); - if (config.containsKey(SERVER_USER)) { - mcpActions.withUser(config.get(SERVER_USER).toString()); - } - return mcpActions.getGraphSchema(graphName); - } catch (Exception e) { - LOGGER.error(e.getMessage(), e); - return e.getMessage(); - } + } + + /** + * A tool that get graph schema. + * + * @param graphName graphName to get. + * @return execute result or error message. + */ + @ToolMapping(description = ToolDesc.getGraphSchema) + public String getGraphSchema( + @Param(name = McpConstants.GRAPH_NAME, description = "get graph schema name") + String graphName) { + try { + Map config = YamlParser.loadConfig(); + GeaFlowMcpActions mcpActions = new GeaFlowMcpActionsLocalImpl(config); + if (config.containsKey(SERVER_USER)) { + mcpActions.withUser(config.get(SERVER_USER).toString()); + } + return mcpActions.getGraphSchema(graphName); + } catch (Exception e) { + LOGGER.error(e.getMessage(), e); + return e.getMessage(); } - - - /** - * A tool that provides insert data into graph capabilities. - * - * @param graphName graph name to operate. - * @param dml dml to run with graph. - * @return execute result or error message. - */ - @ToolMapping(description = ToolDesc.insertGraph) - public String insertGraph(@Param(name = McpConstants.GRAPH_NAME, description = "graph name") String graphName, - @Param(name = McpConstants.DML, description = "dml insert values into graph") String dml) { - try { - Map config = YamlParser.loadConfig(); - GeaFlowMcpActions mcpActions = new GeaFlowMcpActionsLocalImpl(config); - if (config.containsKey(SERVER_USER)) { - mcpActions.withUser(config.get(SERVER_USER).toString()); - } - return mcpActions.queryGraph(graphName, dml); - } catch (Exception e) { - LOGGER.error(e.getMessage(), e); - return e.getMessage(); - } + } + + /** + * A tool that provides insert data into graph capabilities. + * + * @param graphName graph name to operate. + * @param dml dml to run with graph. + * @return execute result or error message. + */ + @ToolMapping(description = ToolDesc.insertGraph) + public String insertGraph( + @Param(name = McpConstants.GRAPH_NAME, description = "graph name") String graphName, + @Param(name = McpConstants.DML, description = "dml insert values into graph") String dml) { + try { + Map config = YamlParser.loadConfig(); + GeaFlowMcpActions mcpActions = new GeaFlowMcpActionsLocalImpl(config); + if (config.containsKey(SERVER_USER)) { + mcpActions.withUser(config.get(SERVER_USER).toString()); + } + return mcpActions.queryGraph(graphName, dml); + } catch (Exception e) { + LOGGER.error(e.getMessage(), e); + return e.getMessage(); } - - - /** - * A tool that provides graph query capabilities. - * - * @param graphName graph name to query. - * @param type query graph entity type. - * @return execute result or error message. - */ - @ToolMapping(description = ToolDesc.queryType) - public String queryType(@Param(name = McpConstants.GRAPH_NAME, description = "query graph name") String graphName, - @Param(name = McpConstants.TYPE, description = "query graph vertex or edge type name") String type) { - try { - Map config = YamlParser.loadConfig(); - GeaFlowMcpActions mcpActions = new GeaFlowMcpActionsLocalImpl(config); - if (config.containsKey(SERVER_USER)) { - mcpActions.withUser(config.get(SERVER_USER).toString()); - } - return mcpActions.queryType(graphName, type); - } catch (Exception e) { - LOGGER.error(e.getMessage(), e); - return e.getMessage(); - } + } + + /** + * A tool that provides graph query capabilities. + * + * @param graphName graph name to query. + * @param type query graph entity type. + * @return execute result or error message. + */ + @ToolMapping(description = ToolDesc.queryType) + public String queryType( + @Param(name = McpConstants.GRAPH_NAME, description = "query graph name") String graphName, + @Param(name = McpConstants.TYPE, description = "query graph vertex or edge type name") + String type) { + try { + Map config = YamlParser.loadConfig(); + GeaFlowMcpActions mcpActions = new GeaFlowMcpActionsLocalImpl(config); + if (config.containsKey(SERVER_USER)) { + mcpActions.withUser(config.get(SERVER_USER).toString()); + } + return mcpActions.queryType(graphName, type); + } catch (Exception e) { + LOGGER.error(e.getMessage(), e); + return e.getMessage(); } - + } } diff --git a/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/ToolDesc.java b/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/ToolDesc.java index 41bcd896d..ddb89201a 100644 --- a/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/ToolDesc.java +++ b/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/ToolDesc.java @@ -21,44 +21,47 @@ public class ToolDesc { - public static final String createGraph = "create graph with ddl, Set the storeType to rocksdb, " - + "ensuring the syntax is correct and do not use any syntax not present in the examples. " - + "DDL statements must end with a semicolon. " - + "example: CREATE GRAPH `modern` (\n" - + "\tVertex `person` (\n" - + "\t `id` bigint ID,\n" - + "\t `name` varchar,\n" - + "\t `age` int\n" - + "\t),\n" - + "\tVertex `software` (\n" - + "\t `id` bigint ID,\n" - + "\t `name` varchar,\n" - + "\t `lang` varchar\n" - + "\t),\n" - + "\tEdge `knows` (\n" - + "\t `srcId` bigint SOURCE ID,\n" - + "\t `targetId` bigint DESTINATION ID,\n" - + "\t `weight` double\n" - + "\t),\n" - + "\tEdge `created` (\n" - + "\t `srcId` bigint SOURCE ID,\n" - + " \t`targetId` bigint DESTINATION ID,\n" - + " \t`weight` double\n" - + "\t)\n" - + ") WITH (\n" - + "\tstoreType='rocksdb'\n" - + ");"; + public static final String createGraph = + "create graph with ddl, Set the storeType to rocksdb, " + + "ensuring the syntax is correct and do not use any syntax not present in the examples. " + + "DDL statements must end with a semicolon. " + + "example: CREATE GRAPH `modern` (\n" + + "\tVertex `person` (\n" + + "\t `id` bigint ID,\n" + + "\t `name` varchar,\n" + + "\t `age` int\n" + + "\t),\n" + + "\tVertex `software` (\n" + + "\t `id` bigint ID,\n" + + "\t `name` varchar,\n" + + "\t `lang` varchar\n" + + "\t),\n" + + "\tEdge `knows` (\n" + + "\t `srcId` bigint SOURCE ID,\n" + + "\t `targetId` bigint DESTINATION ID,\n" + + "\t `weight` double\n" + + "\t),\n" + + "\tEdge `created` (\n" + + "\t `srcId` bigint SOURCE ID,\n" + + " \t`targetId` bigint DESTINATION ID,\n" + + " \t`weight` double\n" + + "\t)\n" + + ") WITH (\n" + + "\tstoreType='rocksdb'\n" + + ");"; - public static final String insertGraph = "Insert into graph with dml. " - + "A single call can only insert data into one vertex or edge type, and can only use the VALUES syntax. " - + "Do not use any syntax not present in the examples. " - + "example: INSERT INTO `modern`.`person`(`id`, `name`, `age`)\n" - + "VALUES (1, 'jim', 20), (2, 'kate', 22)\n" - + ";"; + public static final String insertGraph = + "Insert into graph with dml. A single call can only insert data into one vertex or edge type," + + " and can only use the VALUES syntax. Do not use any syntax not present in the" + + " examples. example: INSERT INTO `modern`.`person`(`id`, `name`, `age`)\n" + + "VALUES (1, 'jim', 20), (2, 'kate', 22)\n" + + ";"; - public static final String queryType = "You need to provide the graph name and the type name of the vertex or edge " - + "type to be queried. The query tool will return all data of this type in the graph. A single call can only " - + "query data from one vertex or edge type, and only one type name needs to be provided."; + public static final String queryType = + "You need to provide the graph name and the type name of the vertex or edge type to be" + + " queried. The query tool will return all data of this type in the graph. A single call" + + " can only query data from one vertex or edge type, and only one type name needs to be" + + " provided."; - public static final String getGraphSchema = "query graph schema."; + public static final String getGraphSchema = "query graph schema."; } diff --git a/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/util/McpConstants.java b/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/util/McpConstants.java index 1b00fd6c7..e37a62454 100644 --- a/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/util/McpConstants.java +++ b/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/util/McpConstants.java @@ -21,14 +21,14 @@ public class McpConstants { - public static final String CREATE_GRAPH = "create_graph"; - public static final String GRAPH_NAME = "graph_name"; - public static final String DDL = "create_graph_ddl"; - public static final String DQL = "query_graph_dql"; - public static final String DML = "insert_graph_dml"; - public static final String TYPE = "query_type_name"; - public static final String CREATE_GRAPH_TOOL_NAME = "createGraph"; - public static final String INSERT_GRAPH_TOOL_NAME = "insertGraph"; - public static final String QUERY_GRAPH_TOOL_NAME = "queryGraph"; - public static final String QUERY_TYPE_TOOL_NAME = "queryType"; + public static final String CREATE_GRAPH = "create_graph"; + public static final String GRAPH_NAME = "graph_name"; + public static final String DDL = "create_graph_ddl"; + public static final String DQL = "query_graph_dql"; + public static final String DML = "insert_graph_dml"; + public static final String TYPE = "query_type_name"; + public static final String CREATE_GRAPH_TOOL_NAME = "createGraph"; + public static final String INSERT_GRAPH_TOOL_NAME = "insertGraph"; + public static final String QUERY_GRAPH_TOOL_NAME = "queryGraph"; + public static final String QUERY_TYPE_TOOL_NAME = "queryType"; } diff --git a/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/util/McpLocalFileUtil.java b/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/util/McpLocalFileUtil.java index 2852ea8b9..52323e85d 100644 --- a/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/util/McpLocalFileUtil.java +++ b/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/util/McpLocalFileUtil.java @@ -29,38 +29,38 @@ public class McpLocalFileUtil { + public static String createAndWriteFile(String root, String text, String... fileNames) + throws IOException { + Files.createDirectories(Paths.get(root)); + String fileName = "execute_query_" + Instant.now().toEpochMilli(); + if (fileNames != null && fileNames.length > 0) { + fileName = fileNames[0]; + } - public static String createAndWriteFile(String root, String text, String... fileNames) throws IOException { - Files.createDirectories(Paths.get(root)); - String fileName = "execute_query_" + Instant.now().toEpochMilli(); - if (fileNames != null && fileNames.length > 0) { - fileName = fileNames[0]; - } - - String fullPath = Paths.get(root, fileName).toString(); - - try (FileWriter writer = new FileWriter(fullPath)) { - if (text != null) { - writer.write(text); - } - } catch (IOException e) { - e.printStackTrace(); - } + String fullPath = Paths.get(root, fileName).toString(); - return fileName; + try (FileWriter writer = new FileWriter(fullPath)) { + if (text != null) { + writer.write(text); + } + } catch (IOException e) { + e.printStackTrace(); } - public static String readFile(String root, String fileName) throws IOException { - Path filePath = Paths.get(root, fileName); + return fileName; + } - if (!Files.exists(filePath)) { - throw new IOException("File not exist: " + filePath); - } + public static String readFile(String root, String fileName) throws IOException { + Path filePath = Paths.get(root, fileName); - if (!Files.isRegularFile(filePath)) { - throw new IOException("Path is not file: " + filePath); - } + if (!Files.exists(filePath)) { + throw new IOException("File not exist: " + filePath); + } - return new String(Files.readAllBytes(filePath), StandardCharsets.UTF_8); + if (!Files.isRegularFile(filePath)) { + throw new IOException("Path is not file: " + filePath); } + + return new String(Files.readAllBytes(filePath), StandardCharsets.UTF_8); + } } diff --git a/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/util/QueryFormatUtil.java b/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/util/QueryFormatUtil.java index ec739308d..407b2e18c 100644 --- a/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/util/QueryFormatUtil.java +++ b/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/util/QueryFormatUtil.java @@ -20,80 +20,85 @@ package org.apache.geaflow.mcp.server.util; import java.util.Locale; + import org.apache.geaflow.dsl.common.types.TableField; import org.apache.geaflow.dsl.schema.GeaFlowGraph; import org.apache.geaflow.dsl.schema.GeaFlowTable; public class QueryFormatUtil { - public static String makeResultTable(GeaFlowTable table, String resultPath) { - StringBuilder builder = new StringBuilder(); - builder.append("CREATE TABLE output_table(\n"); - int index = 0; - for (TableField field : table.getFields()) { - if (index > 0) { - builder.append(",\n"); - } - builder.append("`").append(field.getName()).append("` ") - .append(tableTypeMapper(field.getType().getName())); - index++; - } - builder.append("\n) WITH ( \n"); - builder.append("type='file',\n"); - builder.append("geaflow.dsl.file.path='").append(resultPath).append("'\n"); - builder.append(");\n"); - return builder.toString(); + public static String makeResultTable(GeaFlowTable table, String resultPath) { + StringBuilder builder = new StringBuilder(); + builder.append("CREATE TABLE output_table(\n"); + int index = 0; + for (TableField field : table.getFields()) { + if (index > 0) { + builder.append(",\n"); + } + builder + .append("`") + .append(field.getName()) + .append("` ") + .append(tableTypeMapper(field.getType().getName())); + index++; } + builder.append("\n) WITH ( \n"); + builder.append("type='file',\n"); + builder.append("geaflow.dsl.file.path='").append(resultPath).append("'\n"); + builder.append(");\n"); + return builder.toString(); + } - public static String makeEntityTableQuery(GeaFlowTable table) { - StringBuilder builder = new StringBuilder(); - builder.append("INSERT INTO output_table\n"); - if (table instanceof GeaFlowGraph.VertexTable) { - builder.append("Match(a:`") - .append(((GeaFlowGraph.VertexTable)table).getTypeName()) - .append("`)\n"); - } else { - builder.append("Match()-[a:`") - .append(((GeaFlowGraph.EdgeTable)table).getTypeName()) - .append("`]-()\n"); - } - builder.append("Return "); - int index = 0; - for (TableField field : table.getFields()) { - if (index > 0) { - builder.append(",\n"); - } - builder.append("a.`").append(field.getName()).append("` "); - index++; - } - builder.append("\n;\n"); - return builder.toString(); + public static String makeEntityTableQuery(GeaFlowTable table) { + StringBuilder builder = new StringBuilder(); + builder.append("INSERT INTO output_table\n"); + if (table instanceof GeaFlowGraph.VertexTable) { + builder + .append("Match(a:`") + .append(((GeaFlowGraph.VertexTable) table).getTypeName()) + .append("`)\n"); + } else { + builder + .append("Match()-[a:`") + .append(((GeaFlowGraph.EdgeTable) table).getTypeName()) + .append("`]-()\n"); } - - public static String makeUseGraph(String graphName) { - return "USE GRAPH " + graphName + ";\n"; + builder.append("Return "); + int index = 0; + for (TableField field : table.getFields()) { + if (index > 0) { + builder.append(",\n"); + } + builder.append("a.`").append(field.getName()).append("` "); + index++; } + builder.append("\n;\n"); + return builder.toString(); + } - private static String tableTypeMapper(String iType) { - String upper = iType.toUpperCase(Locale.ROOT); - switch (upper) { - case "STRING": - case "BINARY_STRING": - case "VARCHAR": - return "VARCHAR"; - case "LONG": - case "INTEGER": - case "SHORT": - return "BIGINT"; - case "FLOAT": - case "DOUBLE": - return "DOUBLE"; - case "BOOL": - case "BOOLEAN": - return "BOOL"; - default: - throw new RuntimeException("Cannt convert type name: " + iType); - } - } + public static String makeUseGraph(String graphName) { + return "USE GRAPH " + graphName + ";\n"; + } + private static String tableTypeMapper(String iType) { + String upper = iType.toUpperCase(Locale.ROOT); + switch (upper) { + case "STRING": + case "BINARY_STRING": + case "VARCHAR": + return "VARCHAR"; + case "LONG": + case "INTEGER": + case "SHORT": + return "BIGINT"; + case "FLOAT": + case "DOUBLE": + return "DOUBLE"; + case "BOOL": + case "BOOLEAN": + return "BOOL"; + default: + throw new RuntimeException("Cannt convert type name: " + iType); + } + } } diff --git a/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/util/QueryLocalRunner.java b/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/util/QueryLocalRunner.java index 3f662a489..8be637130 100644 --- a/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/util/QueryLocalRunner.java +++ b/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/server/util/QueryLocalRunner.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.Map; import java.util.Objects; + import org.apache.geaflow.cluster.system.ClusterMetaStore; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.DSLConfigKeys; @@ -42,101 +43,104 @@ public class QueryLocalRunner { - - public static final String DSL_STATE_REMOTE_PATH = "/tmp/dsl/mcp"; - public static final String DSL_STATE_REMOTE_SCHEM_PATH = "/tmp/dsl/mcp/schema"; - private String graphDefine; - private String graphName; - private String errorMsg; - private String query; - private final Map config = new HashMap<>(); - - public QueryLocalRunner withConfig(Map config) { - this.config.putAll(config); - return this; - } - - public QueryLocalRunner withConfig(String key, Object value) { - this.config.put(key, String.valueOf(value)); - return this; - } - - public QueryLocalRunner withGraphDefine(String graphDefine) { - this.graphDefine = Objects.requireNonNull(graphDefine); - return this; + public static final String DSL_STATE_REMOTE_PATH = "/tmp/dsl/mcp"; + public static final String DSL_STATE_REMOTE_SCHEM_PATH = "/tmp/dsl/mcp/schema"; + private String graphDefine; + private String graphName; + private String errorMsg; + private String query; + private final Map config = new HashMap<>(); + + public QueryLocalRunner withConfig(Map config) { + this.config.putAll(config); + return this; + } + + public QueryLocalRunner withConfig(String key, Object value) { + this.config.put(key, String.valueOf(value)); + return this; + } + + public QueryLocalRunner withGraphDefine(String graphDefine) { + this.graphDefine = Objects.requireNonNull(graphDefine); + return this; + } + + public QueryLocalRunner withGraphName(String graphName) { + this.graphName = Objects.requireNonNull(graphName); + return this; + } + + public QueryLocalRunner withQuery(String query) { + this.query = Objects.requireNonNull(query); + return this; + } + + public String getErrorMsg() { + return errorMsg; + } + + public GeaFlowGraph compileGraph() throws Exception { + Map config = new HashMap<>(); + config.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), String.valueOf(-1L)); + if (this.graphDefine == null) { + throw new RuntimeException("Create graph ddl is empty"); } - - public QueryLocalRunner withGraphName(String graphName) { - this.graphName = Objects.requireNonNull(graphName); - return this; + config.put(FrameworkConfigKeys.SYSTEM_STATE_BACKEND_TYPE.getKey(), "memory"); + config.put(FileConfigKeys.ROOT.getKey(), DSL_STATE_REMOTE_PATH); + String fileName = McpLocalFileUtil.createAndWriteFile(DSL_STATE_REMOTE_PATH, this.graphDefine); + config.put( + DSLConfigKeys.GEAFLOW_DSL_QUERY_PATH.getKey(), DSL_STATE_REMOTE_PATH + "/" + fileName); + config.put(DSLConfigKeys.GEAFLOW_DSL_QUERY_PATH_TYPE.getKey(), "file"); + config.putAll(this.config); + Environment environment = EnvironmentFactory.onLocalEnvironment(); + environment.getEnvironmentContext().withConfig(config); + // Compile graph name + CompileContext compileContext = new CompileContext(); + config.put(DSLConfigKeys.GEAFLOW_DSL_COMPILE_PHYSICAL_PLAN_ENABLE.getKey(), "false"); + compileContext.setConfig(config); + PipelineContext pipelineContext = + new PipelineContext( + PipelineTaskType.CompileTask.name(), new Configuration(compileContext.getConfig())); + PipelineTaskContext pipelineTaskCxt = new PipelineTaskContext(0L, pipelineContext); + QueryEngine engineContext = new GeaFlowQueryEngine(pipelineTaskCxt); + QueryContext queryContext = + QueryContext.builder() + .setEngineContext(engineContext) + .setCompile(true) + .setTraversalParallelism(-1) + .build(); + QueryClient queryClient = new QueryClient(); + queryClient.executeQuery(this.graphDefine, queryContext); + return queryContext.getGraph(graphName); + } + + public QueryLocalRunner execute() throws Exception { + Map config = new HashMap<>(); + config.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), String.valueOf(-1L)); + if (this.graphDefine == null) { + throw new RuntimeException("Create graph ddl is empty"); } - - public QueryLocalRunner withQuery(String query) { - this.query = Objects.requireNonNull(query); - return this; - } - - public String getErrorMsg() { - return errorMsg; - } - - public GeaFlowGraph compileGraph() throws Exception { - Map config = new HashMap<>(); - config.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), String.valueOf(-1L)); - if (this.graphDefine == null) { - throw new RuntimeException("Create graph ddl is empty"); - } - config.put(FrameworkConfigKeys.SYSTEM_STATE_BACKEND_TYPE.getKey(), "memory"); - config.put(FileConfigKeys.ROOT.getKey(), DSL_STATE_REMOTE_PATH); - String fileName = McpLocalFileUtil.createAndWriteFile(DSL_STATE_REMOTE_PATH, this.graphDefine); - config.put(DSLConfigKeys.GEAFLOW_DSL_QUERY_PATH.getKey(), DSL_STATE_REMOTE_PATH + "/" + fileName); - config.put(DSLConfigKeys.GEAFLOW_DSL_QUERY_PATH_TYPE.getKey(), "file"); - config.putAll(this.config); - Environment environment = EnvironmentFactory.onLocalEnvironment(); - environment.getEnvironmentContext().withConfig(config); - // Compile graph name - CompileContext compileContext = new CompileContext(); - config.put(DSLConfigKeys.GEAFLOW_DSL_COMPILE_PHYSICAL_PLAN_ENABLE.getKey(), "false"); - compileContext.setConfig(config); - PipelineContext pipelineContext = new PipelineContext(PipelineTaskType.CompileTask.name(), - new Configuration(compileContext.getConfig())); - PipelineTaskContext pipelineTaskCxt = new PipelineTaskContext(0L, pipelineContext); - QueryEngine engineContext = new GeaFlowQueryEngine(pipelineTaskCxt); - QueryContext queryContext = QueryContext.builder() - .setEngineContext(engineContext) - .setCompile(true) - .setTraversalParallelism(-1) - .build(); - QueryClient queryClient = new QueryClient(); - queryClient.executeQuery(this.graphDefine, queryContext); - return queryContext.getGraph(graphName); - } - - public QueryLocalRunner execute() throws Exception { - Map config = new HashMap<>(); - config.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), String.valueOf(-1L)); - if (this.graphDefine == null) { - throw new RuntimeException("Create graph ddl is empty"); - } - config.put(FrameworkConfigKeys.SYSTEM_STATE_BACKEND_TYPE.getKey(), "memory"); - config.put(FileConfigKeys.ROOT.getKey(), DSL_STATE_REMOTE_PATH); - String fileName = McpLocalFileUtil.createAndWriteFile(DSL_STATE_REMOTE_PATH, this.query); - config.put(DSLConfigKeys.GEAFLOW_DSL_QUERY_PATH.getKey(), DSL_STATE_REMOTE_PATH + "/" + fileName); - config.put(DSLConfigKeys.GEAFLOW_DSL_QUERY_PATH_TYPE.getKey(), "file"); - config.put(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT.getKey(), "1"); - config.putAll(this.config); - - Environment environment = EnvironmentFactory.onLocalEnvironment(); - environment.getEnvironmentContext().withConfig(config); - - GQLPipeLine gqlPipeLine = new GQLPipeLine(environment, 0); - - try { - gqlPipeLine.execute(); - } finally { - environment.shutdown(); - ClusterMetaStore.close(); - } - return this; + config.put(FrameworkConfigKeys.SYSTEM_STATE_BACKEND_TYPE.getKey(), "memory"); + config.put(FileConfigKeys.ROOT.getKey(), DSL_STATE_REMOTE_PATH); + String fileName = McpLocalFileUtil.createAndWriteFile(DSL_STATE_REMOTE_PATH, this.query); + config.put( + DSLConfigKeys.GEAFLOW_DSL_QUERY_PATH.getKey(), DSL_STATE_REMOTE_PATH + "/" + fileName); + config.put(DSLConfigKeys.GEAFLOW_DSL_QUERY_PATH_TYPE.getKey(), "file"); + config.put(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT.getKey(), "1"); + config.putAll(this.config); + + Environment environment = EnvironmentFactory.onLocalEnvironment(); + environment.getEnvironmentContext().withConfig(config); + + GQLPipeLine gqlPipeLine = new GQLPipeLine(environment, 0); + + try { + gqlPipeLine.execute(); + } finally { + environment.shutdown(); + ClusterMetaStore.close(); } + return this; + } } diff --git a/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/util/YamlParser.java b/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/util/YamlParser.java index 91dcb4242..351178ab4 100644 --- a/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/util/YamlParser.java +++ b/geaflow-mcp/src/main/java/org/apache/geaflow/mcp/util/YamlParser.java @@ -21,19 +21,19 @@ import java.io.InputStream; import java.util.Map; + import org.yaml.snakeyaml.Yaml; public class YamlParser { - private static final String CONFIG_FILE = "/app.yml"; + private static final String CONFIG_FILE = "/app.yml"; - public static Map loadConfig() { - try (InputStream inputStream = YamlParser.class - .getResourceAsStream(CONFIG_FILE)) { - Yaml yaml = new Yaml(); - return yaml.load(inputStream); - } catch (Exception e) { - throw new RuntimeException(e); - } + public static Map loadConfig() { + try (InputStream inputStream = YamlParser.class.getResourceAsStream(CONFIG_FILE)) { + Yaml yaml = new Yaml(); + return yaml.load(inputStream); + } catch (Exception e) { + throw new RuntimeException(e); } + } } diff --git a/geaflow-mcp/src/test/java/org/apache/geaflow/mcp/server/GeaFlowMcpClientTest.java b/geaflow-mcp/src/test/java/org/apache/geaflow/mcp/server/GeaFlowMcpClientTest.java index a564d6381..f0a47b66a 100644 --- a/geaflow-mcp/src/test/java/org/apache/geaflow/mcp/server/GeaFlowMcpClientTest.java +++ b/geaflow-mcp/src/test/java/org/apache/geaflow/mcp/server/GeaFlowMcpClientTest.java @@ -19,6 +19,10 @@ package org.apache.geaflow.mcp.server; +import java.io.IOException; +import java.nio.charset.Charset; +import java.util.*; + import org.apache.commons.io.IOUtils; import org.apache.geaflow.mcp.server.util.McpConstants; import org.junit.jupiter.api.Assertions; @@ -28,189 +32,166 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.IOException; -import java.nio.charset.Charset; -import java.util.*; - @SolonTest(GeaFlowMcpServer.class) public class GeaFlowMcpClientTest { - private static final Logger LOGGER = LoggerFactory.getLogger(GeaFlowMcpClientTest.class); - - private static final String SSE_CHANNEL = "sse"; - private static final String SSE_ENDPOINT = "http://localhost:8088/geaflow/sse"; - private static final String QUERY = "query"; - private static final String EXECUTE_QUERY_TOOL_NAME = "executeQuery"; - private static final String GRAPH_G1 = "create graph g1(" - + "vertex user(" - + " id bigint ID," - + "name varchar" - + ")," - + "vertex person(" - + " id bigint ID," - + "name varchar," - + "gender int," - + "age integer" - + ")," - + "edge knows(" - + " src_id bigint SOURCE ID," - + " target_id bigint DESTINATION ID," - + " time bigint TIMESTAMP," - + " weight double" - + ")" - + ")"; - private static final String EXECUTE_QUERY = GRAPH_G1 + "match(a:user)-[e:knows]->(b:user)"; - - /** - * Call execute query tool. - */ - public void testExecuteQuery() { - McpClientProvider toolProvider = McpClientProvider.builder() - .apiUrl(SSE_ENDPOINT) - .build(); - - Map map = Collections.singletonMap(QUERY, EXECUTE_QUERY); - String queryResults = toolProvider.callToolAsText(EXECUTE_QUERY_TOOL_NAME, map).getContent(); - LOGGER.info("queryResults: {}", queryResults); - Assertions.assertEquals("{\"error\":{\"code\":1001,\"name\":\"ANALYTICS_SERVER_UNAVAILABLE\"}}", - queryResults); - - } - - @Test - public void testGetServerVersion() { - McpClientProvider toolProvider = McpClientProvider.builder() - .channel(SSE_CHANNEL) - .apiUrl(SSE_ENDPOINT) - .build(); - - String resourceContent = toolProvider.readResourceAsText("config://mcp-server-version").getContent(); - Assertions.assertTrue(resourceContent.contains("v1.0.0")); - } - - @Test - public void testListTools() { - McpClientProvider toolProvider = McpClientProvider.builder() - .channel(SSE_CHANNEL) - .apiUrl(SSE_ENDPOINT) - .build(); - - Set toolNames = new HashSet<>(); - toolProvider.getTools().stream().forEach(tool -> { - LOGGER.info("Tool: {}, desc: {}, input schema: {}", tool.name(), tool.description(), tool.inputSchema()); - toolNames.add(tool.name()); - }); - Assertions.assertTrue(toolNames.contains(McpConstants.CREATE_GRAPH_TOOL_NAME)); - } - - /** - * Call create graph tool. - */ - @Test - public void testCreateGraph() throws IOException { - McpClientProvider toolProvider = McpClientProvider.builder() - .channel(SSE_CHANNEL) - .apiUrl(SSE_ENDPOINT) - .build(); - - String gql = IOUtils.resourceToString("/gql/modern", Charset.defaultCharset()).trim(); - Map map = new HashMap<>(); - map.put(McpConstants.DDL, gql); - map.put(McpConstants.GRAPH_NAME, "modern"); - - String queryResults = toolProvider.callToolAsText(McpConstants.CREATE_GRAPH_TOOL_NAME, map).getContent(); - LOGGER.info("queryResults: {}", queryResults); - Assertions.assertEquals("Create graph modern success.", - queryResults); - - } - - /** - * Call insert graph tool. - */ - @Test - public void testInsertGraph() throws IOException { - McpClientProvider toolProvider = McpClientProvider.builder() - .channel(SSE_CHANNEL) - .apiUrl(SSE_ENDPOINT) - .build(); - - String gql = IOUtils.resourceToString("/gql/insert1", Charset.defaultCharset()).trim(); - Map map = new HashMap<>(); - map.put(McpConstants.DML, gql); - map.put(McpConstants.GRAPH_NAME, "modern"); - - String queryResults = toolProvider.callToolAsText(McpConstants.INSERT_GRAPH_TOOL_NAME, map).getContent(); - LOGGER.info("queryResults: {}", queryResults); - String licenseHead = IOUtils.resourceToString("/gql/licenseHead", Charset.defaultCharset()); - Assertions.assertEquals("run query success: " + licenseHead + "INSERT INTO modern.person(id, name, age)\n" + - "VALUES (1, 'jim', 20), (2, 'kate', 22)\n" + - ";", - queryResults); - - } - - /** - * Call query graph tool. - */ - @Test - public void testQueryGraphType() throws IOException { - McpClientProvider toolProvider = McpClientProvider.builder() - .channel(SSE_CHANNEL) - .apiUrl(SSE_ENDPOINT) - .build(); - - Map map = new HashMap<>(); - map.put(McpConstants.TYPE, "person"); - map.put(McpConstants.GRAPH_NAME, "modern"); - - String queryResults = toolProvider.callToolAsText(McpConstants.QUERY_TYPE_TOOL_NAME, map).getContent(); - LOGGER.info("queryResults: {}", queryResults); - Assertions.assertTrue(queryResults.startsWith("type: person")); - Assertions.assertTrue(queryResults.contains("schema: id|name|age")); - Assertions.assertTrue(queryResults.contains("1,jim,20")); - Assertions.assertTrue(queryResults.contains("2,kate,22")); - - } - - /** - * Call create graph tool with error. - */ - @Test - public void testCreateGraphFailed() throws IOException { - McpClientProvider toolProvider = McpClientProvider.builder() - .channel(SSE_CHANNEL) - .apiUrl(SSE_ENDPOINT) - .build(); - - String gql = IOUtils.resourceToString("/gql/modern_error", Charset.defaultCharset()).trim(); - Map map = new HashMap<>(); - map.put(McpConstants.DDL, gql); - map.put(McpConstants.GRAPH_NAME, "modern"); - - String queryResults = toolProvider.callToolAsText(McpConstants.CREATE_GRAPH_TOOL_NAME, map).getContent(); - LOGGER.info("queryResults: {}", queryResults); - Assertions.assertTrue(queryResults.startsWith("Compile error: Encountered \"character\"")); - - } - - /** - * Call insert graph tool. - */ - @Test - public void testInsertGraphWithError() throws IOException { - McpClientProvider toolProvider = McpClientProvider.builder() - .channel(SSE_CHANNEL) - .apiUrl(SSE_ENDPOINT) - .build(); - - String gql = IOUtils.resourceToString("/gql/insert1_error", Charset.defaultCharset()).trim(); - Map map = new HashMap<>(); - map.put(McpConstants.DML, gql); - map.put(McpConstants.GRAPH_NAME, "modern"); - - String queryResults = toolProvider.callToolAsText(McpConstants.INSERT_GRAPH_TOOL_NAME, map).getContent(); - LOGGER.info("queryResults: {}", queryResults); - Assertions.assertTrue(queryResults.contains("Field:er is not found")); - - } + private static final Logger LOGGER = LoggerFactory.getLogger(GeaFlowMcpClientTest.class); + + private static final String SSE_CHANNEL = "sse"; + private static final String SSE_ENDPOINT = "http://localhost:8088/geaflow/sse"; + private static final String QUERY = "query"; + private static final String EXECUTE_QUERY_TOOL_NAME = "executeQuery"; + private static final String GRAPH_G1 = + "create graph g1(" + + "vertex user(" + + " id bigint ID," + + "name varchar" + + ")," + + "vertex person(" + + " id bigint ID," + + "name varchar," + + "gender int," + + "age integer" + + ")," + + "edge knows(" + + " src_id bigint SOURCE ID," + + " target_id bigint DESTINATION ID," + + " time bigint TIMESTAMP," + + " weight double" + + ")" + + ")"; + private static final String EXECUTE_QUERY = GRAPH_G1 + "match(a:user)-[e:knows]->(b:user)"; + + /** Call execute query tool. */ + public void testExecuteQuery() { + McpClientProvider toolProvider = McpClientProvider.builder().apiUrl(SSE_ENDPOINT).build(); + + Map map = Collections.singletonMap(QUERY, EXECUTE_QUERY); + String queryResults = toolProvider.callToolAsText(EXECUTE_QUERY_TOOL_NAME, map).getContent(); + LOGGER.info("queryResults: {}", queryResults); + Assertions.assertEquals( + "{\"error\":{\"code\":1001,\"name\":\"ANALYTICS_SERVER_UNAVAILABLE\"}}", queryResults); + } + + @Test + public void testGetServerVersion() { + McpClientProvider toolProvider = + McpClientProvider.builder().channel(SSE_CHANNEL).apiUrl(SSE_ENDPOINT).build(); + + String resourceContent = + toolProvider.readResourceAsText("config://mcp-server-version").getContent(); + Assertions.assertTrue(resourceContent.contains("v1.0.0")); + } + + @Test + public void testListTools() { + McpClientProvider toolProvider = + McpClientProvider.builder().channel(SSE_CHANNEL).apiUrl(SSE_ENDPOINT).build(); + + Set toolNames = new HashSet<>(); + toolProvider.getTools().stream() + .forEach( + tool -> { + LOGGER.info( + "Tool: {}, desc: {}, input schema: {}", + tool.name(), + tool.description(), + tool.inputSchema()); + toolNames.add(tool.name()); + }); + Assertions.assertTrue(toolNames.contains(McpConstants.CREATE_GRAPH_TOOL_NAME)); + } + + /** Call create graph tool. */ + @Test + public void testCreateGraph() throws IOException { + McpClientProvider toolProvider = + McpClientProvider.builder().channel(SSE_CHANNEL).apiUrl(SSE_ENDPOINT).build(); + + String gql = IOUtils.resourceToString("/gql/modern", Charset.defaultCharset()).trim(); + Map map = new HashMap<>(); + map.put(McpConstants.DDL, gql); + map.put(McpConstants.GRAPH_NAME, "modern"); + + String queryResults = + toolProvider.callToolAsText(McpConstants.CREATE_GRAPH_TOOL_NAME, map).getContent(); + LOGGER.info("queryResults: {}", queryResults); + Assertions.assertEquals("Create graph modern success.", queryResults); + } + + /** Call insert graph tool. */ + @Test + public void testInsertGraph() throws IOException { + McpClientProvider toolProvider = + McpClientProvider.builder().channel(SSE_CHANNEL).apiUrl(SSE_ENDPOINT).build(); + + String gql = IOUtils.resourceToString("/gql/insert1", Charset.defaultCharset()).trim(); + Map map = new HashMap<>(); + map.put(McpConstants.DML, gql); + map.put(McpConstants.GRAPH_NAME, "modern"); + + String queryResults = + toolProvider.callToolAsText(McpConstants.INSERT_GRAPH_TOOL_NAME, map).getContent(); + LOGGER.info("queryResults: {}", queryResults); + String licenseHead = IOUtils.resourceToString("/gql/licenseHead", Charset.defaultCharset()); + Assertions.assertEquals( + "run query success: " + + licenseHead + + "INSERT INTO modern.person(id, name, age)\n" + + "VALUES (1, 'jim', 20), (2, 'kate', 22)\n" + + ";", + queryResults); + } + + /** Call query graph tool. */ + @Test + public void testQueryGraphType() throws IOException { + McpClientProvider toolProvider = + McpClientProvider.builder().channel(SSE_CHANNEL).apiUrl(SSE_ENDPOINT).build(); + + Map map = new HashMap<>(); + map.put(McpConstants.TYPE, "person"); + map.put(McpConstants.GRAPH_NAME, "modern"); + + String queryResults = + toolProvider.callToolAsText(McpConstants.QUERY_TYPE_TOOL_NAME, map).getContent(); + LOGGER.info("queryResults: {}", queryResults); + Assertions.assertTrue(queryResults.startsWith("type: person")); + Assertions.assertTrue(queryResults.contains("schema: id|name|age")); + Assertions.assertTrue(queryResults.contains("1,jim,20")); + Assertions.assertTrue(queryResults.contains("2,kate,22")); + } + + /** Call create graph tool with error. */ + @Test + public void testCreateGraphFailed() throws IOException { + McpClientProvider toolProvider = + McpClientProvider.builder().channel(SSE_CHANNEL).apiUrl(SSE_ENDPOINT).build(); + + String gql = IOUtils.resourceToString("/gql/modern_error", Charset.defaultCharset()).trim(); + Map map = new HashMap<>(); + map.put(McpConstants.DDL, gql); + map.put(McpConstants.GRAPH_NAME, "modern"); + + String queryResults = + toolProvider.callToolAsText(McpConstants.CREATE_GRAPH_TOOL_NAME, map).getContent(); + LOGGER.info("queryResults: {}", queryResults); + Assertions.assertTrue(queryResults.startsWith("Compile error: Encountered \"character\"")); + } + + /** Call insert graph tool. */ + @Test + public void testInsertGraphWithError() throws IOException { + McpClientProvider toolProvider = + McpClientProvider.builder().channel(SSE_CHANNEL).apiUrl(SSE_ENDPOINT).build(); + + String gql = IOUtils.resourceToString("/gql/insert1_error", Charset.defaultCharset()).trim(); + Map map = new HashMap<>(); + map.put(McpConstants.DML, gql); + map.put(McpConstants.GRAPH_NAME, "modern"); + + String queryResults = + toolProvider.callToolAsText(McpConstants.INSERT_GRAPH_TOOL_NAME, map).getContent(); + LOGGER.info("queryResults: {}", queryResults); + Assertions.assertTrue(queryResults.contains("Field:er is not found")); + } } diff --git a/geaflow-mcp/src/test/java/org/apache/geaflow/mcp/server/util/YamlParserTest.java b/geaflow-mcp/src/test/java/org/apache/geaflow/mcp/server/util/YamlParserTest.java index b5af4f02d..462bf10ba 100644 --- a/geaflow-mcp/src/test/java/org/apache/geaflow/mcp/server/util/YamlParserTest.java +++ b/geaflow-mcp/src/test/java/org/apache/geaflow/mcp/server/util/YamlParserTest.java @@ -19,40 +19,41 @@ package org.apache.geaflow.mcp.server.util; -import com.alibaba.fastjson.JSON; +import static org.apache.geaflow.mcp.server.GeaFlowMcpServerTools.*; + +import java.util.Map; + import org.apache.geaflow.analytics.service.config.AnalyticsClientConfigKeys; import org.apache.geaflow.mcp.util.YamlParser; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import java.util.Map; - -import static org.apache.geaflow.mcp.server.GeaFlowMcpServerTools.*; +import com.alibaba.fastjson.JSON; public class YamlParserTest { - private static final String LOCAL_HOST = "localhost"; - private static final String NOT_EXIST = "not-exist"; - private static final int LOCAL_SERVER_PORT = 8090; - private static final int DEFAULT_MESSAGE_SIZE = 4194304; + private static final String LOCAL_HOST = "localhost"; + private static final String NOT_EXIST = "not-exist"; + private static final int LOCAL_SERVER_PORT = 8090; + private static final int DEFAULT_MESSAGE_SIZE = 4194304; - @Test - public void testLoadConfig() { - Map config = YamlParser.loadConfig(); - Assertions.assertNotNull(config); + @Test + public void testLoadConfig() { + Map config = YamlParser.loadConfig(); + Assertions.assertNotNull(config); - Assertions.assertTrue(config.containsKey(SERVER_HOST)); - Assertions.assertEquals(LOCAL_HOST, config.get(SERVER_HOST)); + Assertions.assertTrue(config.containsKey(SERVER_HOST)); + Assertions.assertEquals(LOCAL_HOST, config.get(SERVER_HOST)); - Assertions.assertTrue(config.containsKey(SERVER_PORT)); - Assertions.assertEquals(LOCAL_SERVER_PORT, (int) config.get(SERVER_PORT)); + Assertions.assertTrue(config.containsKey(SERVER_PORT)); + Assertions.assertEquals(LOCAL_SERVER_PORT, (int) config.get(SERVER_PORT)); - Assertions.assertFalse(config.containsKey(NOT_EXIST)); + Assertions.assertFalse(config.containsKey(NOT_EXIST)); - Assertions.assertTrue(config.containsKey(CONFIG)); - Map clientConfig = JSON.parseObject(config.get(CONFIG).toString(), Map.class); - String key = AnalyticsClientConfigKeys.ANALYTICS_CLIENT_MAX_INBOUND_MESSAGE_SIZE.getKey(); - Assertions.assertTrue(clientConfig.containsKey(key)); - Assertions.assertEquals(DEFAULT_MESSAGE_SIZE, clientConfig.get(key)); - } + Assertions.assertTrue(config.containsKey(CONFIG)); + Map clientConfig = JSON.parseObject(config.get(CONFIG).toString(), Map.class); + String key = AnalyticsClientConfigKeys.ANALYTICS_CLIENT_MAX_INBOUND_MESSAGE_SIZE.getKey(); + Assertions.assertTrue(clientConfig.containsKey(key)); + Assertions.assertEquals(DEFAULT_MESSAGE_SIZE, clientConfig.get(key)); + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AbstractQueryRunner.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AbstractQueryRunner.java index e450a1cdb..676f29206 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AbstractQueryRunner.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AbstractQueryRunner.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Random; import java.util.concurrent.atomic.AtomicReference; + import org.apache.commons.collections4.CollectionUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -35,84 +36,84 @@ import org.slf4j.LoggerFactory; public abstract class AbstractQueryRunner implements IQueryRunner { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractQueryRunner.class); - - protected static final Random RANDOM = new Random(System.currentTimeMillis()); - protected List coordinatorAddresses; - - protected Configuration config; - protected MetaServerQueryClient serverQueryClient; - protected AnalyticsServiceInfo analyticsServiceInfo; - protected HostAndPort hostAndPort; - protected AtomicReference queryRunnerStatus; - private QueryRunnerContext context; - - @Override - public void init(QueryRunnerContext context) { - this.context = context; - this.config = context.getConfiguration(); - this.hostAndPort = context.getHostAndPort(); - initAnalyticsServiceAddress(); - initManagedChannel(); - this.queryRunnerStatus = new AtomicReference<>(QueryRunnerStatus.RUNNING); + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractQueryRunner.class); + + protected static final Random RANDOM = new Random(System.currentTimeMillis()); + protected List coordinatorAddresses; + + protected Configuration config; + protected MetaServerQueryClient serverQueryClient; + protected AnalyticsServiceInfo analyticsServiceInfo; + protected HostAndPort hostAndPort; + protected AtomicReference queryRunnerStatus; + private QueryRunnerContext context; + + @Override + public void init(QueryRunnerContext context) { + this.context = context; + this.config = context.getConfiguration(); + this.hostAndPort = context.getHostAndPort(); + initAnalyticsServiceAddress(); + initManagedChannel(); + this.queryRunnerStatus = new AtomicReference<>(QueryRunnerStatus.RUNNING); + } + + protected void initManagedChannel() { + for (HostAndPort coordinatorAddress : this.coordinatorAddresses) { + initManagedChannel(coordinatorAddress); } - - protected void initManagedChannel() { - for (HostAndPort coordinatorAddress : this.coordinatorAddresses) { - initManagedChannel(coordinatorAddress); - } + } + + protected abstract void initManagedChannel(HostAndPort address); + + public void initAnalyticsServiceAddress() { + if (hostAndPort != null) { + String serverName = String.format("%s:%d", hostAndPort.getHost(), hostAndPort.getPort()); + List hostAndPorts = Collections.singletonList(hostAndPort); + this.analyticsServiceInfo = new AnalyticsServiceInfo(hostAndPorts); + this.coordinatorAddresses = hostAndPorts; + LOGGER.info("init single analytics service: [{}] finish", serverName); + return; } - - protected abstract void initManagedChannel(HostAndPort address); - - public void initAnalyticsServiceAddress() { - if (hostAndPort != null) { - String serverName = String.format("%s:%d", hostAndPort.getHost(), hostAndPort.getPort()); - List hostAndPorts = Collections.singletonList(hostAndPort); - this.analyticsServiceInfo = new AnalyticsServiceInfo(hostAndPorts); - this.coordinatorAddresses = hostAndPorts; - LOGGER.info("init single analytics service: [{}] finish", serverName); - return; - } - this.serverQueryClient = MetaServerQueryClient.getClient(config); - List serviceAddresses = getServiceAddresses(); - this.coordinatorAddresses = serviceAddresses; - this.analyticsServiceInfo = new AnalyticsServiceInfo(serviceAddresses); - LOGGER.info("init analytics service finish by meta server"); + this.serverQueryClient = MetaServerQueryClient.getClient(config); + List serviceAddresses = getServiceAddresses(); + this.coordinatorAddresses = serviceAddresses; + this.analyticsServiceInfo = new AnalyticsServiceInfo(serviceAddresses); + LOGGER.info("init analytics service finish by meta server"); + } + + private List getServiceAddresses() { + List hostAndPorts; + try { + hostAndPorts = serverQueryClient.queryAllServices(DEFAULT); + } catch (Throwable e) { + throw new GeaflowRuntimeException("query analytics coordinator addresses failed", e); } - - private List getServiceAddresses() { - List hostAndPorts; - try { - hostAndPorts = serverQueryClient.queryAllServices(DEFAULT); - } catch (Throwable e) { - throw new GeaflowRuntimeException("query analytics coordinator addresses failed", e); - } - if (CollectionUtils.isEmpty(hostAndPorts)) { - throw new GeaflowRuntimeException("query analytics coordinator addresses is empty"); - } - LOGGER.info("query analytics coordinator addresses is {}", Arrays.toString(hostAndPorts.toArray())); - return hostAndPorts; + if (CollectionUtils.isEmpty(hostAndPorts)) { + throw new GeaflowRuntimeException("query analytics coordinator addresses is empty"); } - - @Override - public boolean isRunning() { - return queryRunnerStatus.get() == QueryRunnerStatus.RUNNING; - } - - @Override - public boolean isAborted() { - return queryRunnerStatus.get() == QueryRunnerStatus.ABORTED; - } - - @Override - public boolean isError() { - return queryRunnerStatus.get() == QueryRunnerStatus.ERROR; - } - - @Override - public boolean isFinished() { - return queryRunnerStatus.get() == QueryRunnerStatus.FINISHED; - } - + LOGGER.info( + "query analytics coordinator addresses is {}", Arrays.toString(hostAndPorts.toArray())); + return hostAndPorts; + } + + @Override + public boolean isRunning() { + return queryRunnerStatus.get() == QueryRunnerStatus.RUNNING; + } + + @Override + public boolean isAborted() { + return queryRunnerStatus.get() == QueryRunnerStatus.ABORTED; + } + + @Override + public boolean isError() { + return queryRunnerStatus.get() == QueryRunnerStatus.ERROR; + } + + @Override + public boolean isFinished() { + return queryRunnerStatus.get() == QueryRunnerStatus.FINISHED; + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AnalyticsClient.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AnalyticsClient.java index 98b617098..07dc60e79 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AnalyticsClient.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AnalyticsClient.java @@ -25,7 +25,6 @@ import static org.apache.geaflow.analytics.service.query.StandardError.ANALYTICS_SERVER_BUSY; import static org.apache.geaflow.analytics.service.query.StandardError.ANALYTICS_SERVER_UNAVAILABLE; -import com.google.common.base.Preconditions; import org.apache.geaflow.analytics.service.client.QueryRunnerContext.ClientHandlerContextBuilder; import org.apache.geaflow.analytics.service.query.QueryError; import org.apache.geaflow.analytics.service.query.QueryResults; @@ -40,98 +39,104 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; + public class AnalyticsClient { - private static final Logger LOGGER = LoggerFactory.getLogger(AnalyticsClient.class); - private final Configuration config; - private final String host; - private final int port; - private final ServiceType serviceType; - private final boolean initChannelPools; - private final int executeRetryNum; - private IQueryRunner queryRunner; - private final long sleepTimeMs; + private static final Logger LOGGER = LoggerFactory.getLogger(AnalyticsClient.class); + private final Configuration config; + private final String host; + private final int port; + private final ServiceType serviceType; + private final boolean initChannelPools; + private final int executeRetryNum; + private IQueryRunner queryRunner; + private final long sleepTimeMs; - public static AnalyticsClientBuilder builder() { - return new AnalyticsClientBuilder(); - } + public static AnalyticsClientBuilder builder() { + return new AnalyticsClientBuilder(); + } - protected AnalyticsClient(AnalyticsClientBuilder builder) { - this.config = builder.getConfiguration(); - this.host = builder.getHost(); - this.port = builder.getPort(); - this.initChannelPools = builder.enableInitChannelPools(); - this.executeRetryNum = config.getInteger(ANALYTICS_CLIENT_EXECUTE_RETRY_NUM); - this.sleepTimeMs = config.getLong(ANALYTICS_CLIENT_SLEEP_TIME_MS); - this.serviceType = ServiceType.getEnum(config); - init(); - } + protected AnalyticsClient(AnalyticsClientBuilder builder) { + this.config = builder.getConfiguration(); + this.host = builder.getHost(); + this.port = builder.getPort(); + this.initChannelPools = builder.enableInitChannelPools(); + this.executeRetryNum = config.getInteger(ANALYTICS_CLIENT_EXECUTE_RETRY_NUM); + this.sleepTimeMs = config.getLong(ANALYTICS_CLIENT_SLEEP_TIME_MS); + this.serviceType = ServiceType.getEnum(config); + init(); + } - private void init() { - if (this.host == null) { - checkAnalyticsClientConfig(config); - } - ClientHandlerContextBuilder clientHandlerContextBuilder = QueryRunnerContext.newBuilder() + private void init() { + if (this.host == null) { + checkAnalyticsClientConfig(config); + } + ClientHandlerContextBuilder clientHandlerContextBuilder = + QueryRunnerContext.newBuilder() .setConfiguration(config) .enableInitChannelPools(initChannelPools); - if (host != null) { - clientHandlerContextBuilder.setHost(new HostAndPort(host, port)); - } - QueryRunnerContext clientHandlerContext = clientHandlerContextBuilder.build(); - this.queryRunner = QueryRunnerFactory.loadQueryRunner(clientHandlerContext); + if (host != null) { + clientHandlerContextBuilder.setHost(new HostAndPort(host, port)); } + QueryRunnerContext clientHandlerContext = clientHandlerContextBuilder.build(); + this.queryRunner = QueryRunnerFactory.loadQueryRunner(clientHandlerContext); + } - - public QueryResults executeQuery(String queryScript) { - QueryResults result = null; - for (int i = 0; i < this.executeRetryNum; i++) { - result = this.queryRunner.executeQuery(queryScript); - boolean serviceBusy = false; - boolean serviceUnavailable = false; - if (result.getError() != null) { - int resultErrorCode = result.getError().getCode(); - serviceBusy = resultErrorCode == ANALYTICS_SERVER_BUSY.getQueryError().getCode(); - serviceUnavailable = resultErrorCode == ANALYTICS_SERVER_UNAVAILABLE.getQueryError().getCode(); - } - if (result.getQueryStatus() || (!serviceBusy && !serviceUnavailable)) { - return result; - } - LOGGER.info("all coordinator busy or unavailable, sleep {}ms and retry", sleepTimeMs); - SleepUtils.sleepMilliSecond(sleepTimeMs); - } - if (result == null) { - QueryError queryError = ANALYTICS_NULL_RESULT.getQueryError(); - return new QueryResults(queryError); - } + public QueryResults executeQuery(String queryScript) { + QueryResults result = null; + for (int i = 0; i < this.executeRetryNum; i++) { + result = this.queryRunner.executeQuery(queryScript); + boolean serviceBusy = false; + boolean serviceUnavailable = false; + if (result.getError() != null) { + int resultErrorCode = result.getError().getCode(); + serviceBusy = resultErrorCode == ANALYTICS_SERVER_BUSY.getQueryError().getCode(); + serviceUnavailable = + resultErrorCode == ANALYTICS_SERVER_UNAVAILABLE.getQueryError().getCode(); + } + if (result.getQueryStatus() || (!serviceBusy && !serviceUnavailable)) { return result; + } + LOGGER.info("all coordinator busy or unavailable, sleep {}ms and retry", sleepTimeMs); + SleepUtils.sleepMilliSecond(sleepTimeMs); } - - public void shutdown() { - try { - queryRunner.close(); - } catch (Throwable e) { - LOGGER.error("client handler close error", e); - } + if (result == null) { + QueryError queryError = ANALYTICS_NULL_RESULT.getQueryError(); + return new QueryResults(queryError); } + return result; + } - protected static void checkAnalyticsClientConfig(Configuration config) { - // Check job mode. - checkAnalyticsClientJobMode(config); + public void shutdown() { + try { + queryRunner.close(); + } catch (Throwable e) { + LOGGER.error("client handler close error", e); } + } - private static void checkAnalyticsClientJobMode(Configuration config) { - if (config.contains(ExecutionConfigKeys.JOB_MODE)) { - JobMode jobMode = JobMode.getJobMode(config); - Preconditions.checkArgument(JobMode.OLAP_SERVICE.equals(jobMode), "analytics job mode must set OLAP_SERVICE"); - return; - } - throw new GeaflowRuntimeException("analytics client config miss: " + ExecutionConfigKeys.JOB_MODE.getKey()); - } + protected static void checkAnalyticsClientConfig(Configuration config) { + // Check job mode. + checkAnalyticsClientJobMode(config); + } - private static void configIsExist(Configuration config, ConfigKey configKey) { - Preconditions.checkArgument( - config.contains(configKey) && !config.getConfigMap().get(configKey.getKey()).isEmpty(), - "client missing config: " + configKey.getKey() + ", description: " - + configKey.getDescription()); + private static void checkAnalyticsClientJobMode(Configuration config) { + if (config.contains(ExecutionConfigKeys.JOB_MODE)) { + JobMode jobMode = JobMode.getJobMode(config); + Preconditions.checkArgument( + JobMode.OLAP_SERVICE.equals(jobMode), "analytics job mode must set OLAP_SERVICE"); + return; } + throw new GeaflowRuntimeException( + "analytics client config miss: " + ExecutionConfigKeys.JOB_MODE.getKey()); + } + private static void configIsExist(Configuration config, ConfigKey configKey) { + Preconditions.checkArgument( + config.contains(configKey) && !config.getConfigMap().get(configKey.getKey()).isEmpty(), + "client missing config: " + + configKey.getKey() + + ", description: " + + configKey.getDescription()); + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AnalyticsClientBuilder.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AnalyticsClientBuilder.java index 20ccf5205..c3ed0f1fa 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AnalyticsClientBuilder.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AnalyticsClientBuilder.java @@ -26,81 +26,80 @@ import org.apache.geaflow.common.config.Configuration; public class AnalyticsClientBuilder { - private final Configuration configuration = new Configuration(); - private String host; - private int port; - private String user; - private int queryRetryNum; - private boolean initChannelPools; - - public AnalyticsClientBuilder() { + private final Configuration configuration = new Configuration(); + private String host; + private int port; + private String user; + private int queryRetryNum; + private boolean initChannelPools; + + public AnalyticsClientBuilder() {} + + public AnalyticsClientBuilder withHost(String host) { + this.host = host; + return this; + } + + public AnalyticsClientBuilder withPort(int port) { + this.port = port; + return this; + } + + public AnalyticsClientBuilder withInitChannelPools(boolean initChannelPools) { + this.initChannelPools = initChannelPools; + return this; + } + + public AnalyticsClientBuilder withConfiguration(Configuration configuration) { + this.configuration.putAll(configuration.getConfigMap()); + return this; + } + + public AnalyticsClientBuilder withUser(String user) { + this.user = user; + return this; + } + + public AnalyticsClientBuilder withTimeoutMs(int timeoutMs) { + this.configuration.put( + AnalyticsClientConfigKeys.ANALYTICS_CLIENT_CONNECT_TIMEOUT_MS, String.valueOf(timeoutMs)); + return this; + } + + public AnalyticsClientBuilder withRetryNum(int retryNum) { + this.configuration.put(ANALYTICS_CLIENT_CONNECT_RETRY_NUM, String.valueOf(retryNum)); + this.queryRetryNum = retryNum; + return this; + } + + public Configuration getConfiguration() { + return configuration; + } + + public String getHost() { + return host; + } + + public int getPort() { + return port; + } + + public String getUser() { + return user; + } + + public int getQueryRetryNum() { + return queryRetryNum; + } + + public boolean enableInitChannelPools() { + return initChannelPools; + } + + public AnalyticsClient build() { + if (host == null) { + checkAnalyticsClientConfig(configuration); } - - public AnalyticsClientBuilder withHost(String host) { - this.host = host; - return this; - } - - public AnalyticsClientBuilder withPort(int port) { - this.port = port; - return this; - } - - public AnalyticsClientBuilder withInitChannelPools(boolean initChannelPools) { - this.initChannelPools = initChannelPools; - return this; - } - - public AnalyticsClientBuilder withConfiguration(Configuration configuration) { - this.configuration.putAll(configuration.getConfigMap()); - return this; - } - - public AnalyticsClientBuilder withUser(String user) { - this.user = user; - return this; - } - - public AnalyticsClientBuilder withTimeoutMs(int timeoutMs) { - this.configuration.put(AnalyticsClientConfigKeys.ANALYTICS_CLIENT_CONNECT_TIMEOUT_MS, String.valueOf(timeoutMs)); - return this; - } - - public AnalyticsClientBuilder withRetryNum(int retryNum) { - this.configuration.put(ANALYTICS_CLIENT_CONNECT_RETRY_NUM, String.valueOf(retryNum)); - this.queryRetryNum = retryNum; - return this; - } - - public Configuration getConfiguration() { - return configuration; - } - - public String getHost() { - return host; - } - - public int getPort() { - return port; - } - - public String getUser() { - return user; - } - - public int getQueryRetryNum() { - return queryRetryNum; - } - - public boolean enableInitChannelPools() { - return initChannelPools; - } - - public AnalyticsClient build() { - if (host == null) { - checkAnalyticsClientConfig(configuration); - } - return new AnalyticsClient(this); - } - + return new AnalyticsClient(this); + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AnalyticsManagerOptions.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AnalyticsManagerOptions.java index 5e0bd9055..8d67fee29 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AnalyticsManagerOptions.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AnalyticsManagerOptions.java @@ -22,68 +22,86 @@ import static java.util.Collections.emptyMap; import static java.util.Locale.ENGLISH; -import com.google.common.net.HostAndPort; import java.net.URI; import java.net.URISyntaxException; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; +import com.google.common.net.HostAndPort; + public class AnalyticsManagerOptions { - private static final int DEFAULT_PORT = 8080; + private static final int DEFAULT_PORT = 8080; - private static final int DEFAULT_REQUEST_TIMEOUT_SECOND = 120; + private static final int DEFAULT_REQUEST_TIMEOUT_SECOND = 120; - private static final String HTTP_PREFIX = "http://"; + private static final String HTTP_PREFIX = "http://"; - private static final String DELIMITER = ":"; + private static final String DELIMITER = ":"; - private static final String PARSER_PREFIX = "http"; + private static final String PARSER_PREFIX = "http"; - private static final String HTTPS_PREFIX = "https://"; + private static final String HTTPS_PREFIX = "https://"; - private static final String DEFAULT_SERVER = "localhost"; + private static final String DEFAULT_SERVER = "localhost"; - private static final String DEFAULT_USER = System.getProperty("user.name"); + private static final String DEFAULT_USER = System.getProperty("user.name"); - private static final String DEFAULT_SOURCE = "geaflow-client"; + private static final String DEFAULT_SOURCE = "geaflow-client"; - public static AnalyticsManagerSession createClientSession(String host, int port) { - String url = HTTP_PREFIX + host + DELIMITER + port; - return new AnalyticsManagerSession(parseServer(url), DEFAULT_USER, DEFAULT_SOURCE, - DEFAULT_REQUEST_TIMEOUT_SECOND, false, emptyMap(), emptyMap()); - } + public static AnalyticsManagerSession createClientSession(String host, int port) { + String url = HTTP_PREFIX + host + DELIMITER + port; + return new AnalyticsManagerSession( + parseServer(url), + DEFAULT_USER, + DEFAULT_SOURCE, + DEFAULT_REQUEST_TIMEOUT_SECOND, + false, + emptyMap(), + emptyMap()); + } - public static AnalyticsManagerSession createClientSession(int port) { - return new AnalyticsManagerSession(parseServer(DEFAULT_SERVER, port), DEFAULT_USER, DEFAULT_SOURCE, - DEFAULT_REQUEST_TIMEOUT_SECOND, false, emptyMap(), emptyMap()); - } + public static AnalyticsManagerSession createClientSession(int port) { + return new AnalyticsManagerSession( + parseServer(DEFAULT_SERVER, port), + DEFAULT_USER, + DEFAULT_SOURCE, + DEFAULT_REQUEST_TIMEOUT_SECOND, + false, + emptyMap(), + emptyMap()); + } - public static URI parseServer(String server) { - server = server.toLowerCase(ENGLISH); - if (server.startsWith(HTTP_PREFIX) || server.startsWith(HTTPS_PREFIX)) { - return URI.create(server); - } - HostAndPort host = HostAndPort.fromString(server); - try { - return new URI(PARSER_PREFIX, null, host.getHost(), host.getPortOrDefault(DEFAULT_PORT), - null, null, null); - } catch (URISyntaxException e) { - throw new GeaflowRuntimeException("parse http server error", e); - } + public static URI parseServer(String server) { + server = server.toLowerCase(ENGLISH); + if (server.startsWith(HTTP_PREFIX) || server.startsWith(HTTPS_PREFIX)) { + return URI.create(server); } - - public static URI parseServer(String server, int port) { - server = server.toLowerCase(ENGLISH); - if (server.startsWith(HTTP_PREFIX) || server.startsWith(HTTPS_PREFIX)) { - return URI.create(server); - } - HostAndPort host = HostAndPort.fromString(server); - try { - return new URI(PARSER_PREFIX, null, host.getHost(), port, - null, null, null); - } catch (URISyntaxException e) { - throw new GeaflowRuntimeException("parse http server error", e); - } + HostAndPort host = HostAndPort.fromString(server); + try { + return new URI( + PARSER_PREFIX, + null, + host.getHost(), + host.getPortOrDefault(DEFAULT_PORT), + null, + null, + null); + } catch (URISyntaxException e) { + throw new GeaflowRuntimeException("parse http server error", e); } + } + public static URI parseServer(String server, int port) { + server = server.toLowerCase(ENGLISH); + if (server.startsWith(HTTP_PREFIX) || server.startsWith(HTTPS_PREFIX)) { + return URI.create(server); + } + HostAndPort host = HostAndPort.fromString(server); + try { + return new URI(PARSER_PREFIX, null, host.getHost(), port, null, null, null); + } catch (URISyntaxException e) { + throw new GeaflowRuntimeException("parse http server error", e); + } + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AnalyticsManagerSession.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AnalyticsManagerSession.java index 341faf0f7..5a3cfd286 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AnalyticsManagerSession.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AnalyticsManagerSession.java @@ -22,15 +22,88 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static java.util.Objects.requireNonNull; -import com.google.common.collect.ImmutableMap; import java.net.URI; import java.util.Map; +import com.google.common.collect.ImmutableMap; + public class AnalyticsManagerSession { - private static final String USER = "user"; - private static final String SERVER = "server"; - private static final String PROPERTIES = "properties"; + private static final String USER = "user"; + private static final String SERVER = "server"; + private static final String PROPERTIES = "properties"; + private URI server; + private String user; + private String source; + private Map properties; + private Map customHeaders; + private boolean compressionDisabled; + private long clientRequestTimeoutMs; + + public static Builder builder() { + return new Builder(); + } + + public AnalyticsManagerSession( + URI server, + String user, + String source, + long clientRequestTimeoutMs, + boolean compressionDisabled, + Map properties, + Map customHeaders) { + this.server = requireNonNull(server, "server is null"); + this.user = user; + this.source = source; + this.compressionDisabled = compressionDisabled; + this.clientRequestTimeoutMs = clientRequestTimeoutMs; + this.properties = ImmutableMap.copyOf(requireNonNull(properties, "properties is null")); + this.customHeaders = + ImmutableMap.copyOf(requireNonNull(customHeaders, "customHeaders is null")); + } + + public AnalyticsManagerSession() {} + + public long getClientRequestTimeout() { + return clientRequestTimeoutMs; + } + + public URI getServer() { + return server; + } + + public String getUser() { + return user; + } + + public String getSource() { + return source; + } + + public Map getProperties() { + return properties; + } + + public Map getCustomHeaders() { + return customHeaders; + } + + public boolean isCompressionDisabled() { + return compressionDisabled; + } + + @Override + public String toString() { + return toStringHelper(this) + .add(SERVER, server) + .add(USER, user) + .add(PROPERTIES, properties) + .omitNullValues() + .toString(); + } + + public static final class Builder { + private URI server; private String user; private String source; @@ -39,121 +112,53 @@ public class AnalyticsManagerSession { private boolean compressionDisabled; private long clientRequestTimeoutMs; - public static Builder builder() { - return new Builder(); - } - - public AnalyticsManagerSession(URI server, String user, String source, long clientRequestTimeoutMs, - boolean compressionDisabled, Map properties, - Map customHeaders) { - this.server = requireNonNull(server, "server is null"); - this.user = user; - this.source = source; - this.compressionDisabled = compressionDisabled; - this.clientRequestTimeoutMs = clientRequestTimeoutMs; - this.properties = ImmutableMap.copyOf(requireNonNull(properties, "properties is null")); - this.customHeaders = ImmutableMap.copyOf( - requireNonNull(customHeaders, "customHeaders is null")); - } - - public AnalyticsManagerSession() { - } - - public long getClientRequestTimeout() { - return clientRequestTimeoutMs; - } - - public URI getServer() { - return server; - } - - public String getUser() { - return user; + public Builder setServer(URI server) { + this.server = server; + return this; } - public String getSource() { - return source; + public Builder setClientUser(String user) { + this.user = user; + return this; } - public Map getProperties() { - return properties; + public Builder setClientSource(String source) { + this.source = source; + return this; } - public Map getCustomHeaders() { - return customHeaders; + public Builder setCompressionDisabled(boolean compressionDisabled) { + this.compressionDisabled = compressionDisabled; + return this; } - public boolean isCompressionDisabled() { - return compressionDisabled; + public Builder setClientRequestTimeoutMs(long clientRequestTimeoutMs) { + this.clientRequestTimeoutMs = clientRequestTimeoutMs; + return this; } - @Override - public String toString() { - return toStringHelper(this) - .add(SERVER, server) - .add(USER, user) - .add(PROPERTIES, properties) - .omitNullValues() - .toString(); + public Builder setCustomHeaders(Map customHeaders) { + this.customHeaders = customHeaders; + return this; } - public static final class Builder { - - private URI server; - private String user; - private String source; - private Map properties; - private Map customHeaders; - private boolean compressionDisabled; - private long clientRequestTimeoutMs; - - public Builder setServer(URI server) { - this.server = server; - return this; - } - - public Builder setClientUser(String user) { - this.user = user; - return this; - } - - public Builder setClientSource(String source) { - this.source = source; - return this; - } - - public Builder setCompressionDisabled(boolean compressionDisabled) { - this.compressionDisabled = compressionDisabled; - return this; - } - - public Builder setClientRequestTimeoutMs(long clientRequestTimeoutMs) { - this.clientRequestTimeoutMs = clientRequestTimeoutMs; - return this; - } - - public Builder setCustomHeaders(Map customHeaders) { - this.customHeaders = customHeaders; - return this; - } - - public Builder setProperties(Map properties) { - this.properties = properties; - return this; - } - - public AnalyticsManagerSession build() { - return new AnalyticsManagerSession(this); - } + public Builder setProperties(Map properties) { + this.properties = properties; + return this; } - private AnalyticsManagerSession(Builder builder) { - this.server = builder.server; - this.user = builder.user; - this.source = builder.source; - this.compressionDisabled = builder.compressionDisabled; - this.properties = builder.properties; - this.customHeaders = builder.customHeaders; - this.clientRequestTimeoutMs = builder.clientRequestTimeoutMs; + public AnalyticsManagerSession build() { + return new AnalyticsManagerSession(this); } + } + + private AnalyticsManagerSession(Builder builder) { + this.server = builder.server; + this.user = builder.user; + this.source = builder.source; + this.compressionDisabled = builder.compressionDisabled; + this.properties = builder.properties; + this.customHeaders = builder.customHeaders; + this.clientRequestTimeoutMs = builder.clientRequestTimeoutMs; + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AnalyticsServiceInfo.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AnalyticsServiceInfo.java index 28b02e71e..83c0eaa1f 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AnalyticsServiceInfo.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/AnalyticsServiceInfo.java @@ -20,38 +20,38 @@ package org.apache.geaflow.analytics.service.client; import java.util.List; + import org.apache.geaflow.common.rpc.HostAndPort; public class AnalyticsServiceInfo { - private String serverName; - private final List coordinatorAddresses; - private final int coordinatorNum; - - public AnalyticsServiceInfo(String serverName, List coordinatorAddresses) { - this.serverName = serverName; - this.coordinatorAddresses = coordinatorAddresses; - this.coordinatorNum = coordinatorAddresses.size(); - } - - public AnalyticsServiceInfo(List coordinatorAddresses) { - this.coordinatorAddresses = coordinatorAddresses; - this.coordinatorNum = coordinatorAddresses.size(); - } - - public String getServerName() { - return serverName; - } - - public List getCoordinatorAddresses() { - return coordinatorAddresses; - } - - public HostAndPort getCoordinatorAddresses(int index) { - return coordinatorAddresses.get(index); - } - - public int getCoordinatorNum() { - return coordinatorNum; - } - + private String serverName; + private final List coordinatorAddresses; + private final int coordinatorNum; + + public AnalyticsServiceInfo(String serverName, List coordinatorAddresses) { + this.serverName = serverName; + this.coordinatorAddresses = coordinatorAddresses; + this.coordinatorNum = coordinatorAddresses.size(); + } + + public AnalyticsServiceInfo(List coordinatorAddresses) { + this.coordinatorAddresses = coordinatorAddresses; + this.coordinatorNum = coordinatorAddresses.size(); + } + + public String getServerName() { + return serverName; + } + + public List getCoordinatorAddresses() { + return coordinatorAddresses; + } + + public HostAndPort getCoordinatorAddresses(int index) { + return coordinatorAddresses.get(index); + } + + public int getCoordinatorNum() { + return coordinatorNum; + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/HttpQueryRunner.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/HttpQueryRunner.java index 25127e3dc..b8f7f5df5 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/HttpQueryRunner.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/HttpQueryRunner.java @@ -32,7 +32,6 @@ import static org.apache.geaflow.analytics.service.query.StandardError.ANALYTICS_SERVER_BUSY; import static org.apache.http.HttpStatus.SC_OK; -import com.google.gson.Gson; import java.io.IOException; import java.net.URI; import java.util.HashMap; @@ -40,6 +39,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.analytics.service.query.QueryError; import org.apache.geaflow.analytics.service.query.QueryResults; import org.apache.geaflow.common.config.Configuration; @@ -59,145 +59,157 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class HttpQueryRunner extends AbstractQueryRunner { - - private static final Logger LOGGER = LoggerFactory.getLogger(HttpQueryRunner.class); - private static final String USER_AGENT_VALUE = HttpQueryRunner.class.getSimpleName() + "/" + firstNonNull( - HttpQueryRunner.class.getPackage().getImplementationVersion(), "unknown"); +import com.google.gson.Gson; - protected static final String URL_PATH = "/rest/analytics/query/execute"; - protected static final String QUERY = "query"; - private Map coordinatorAddress2Stub; +public class HttpQueryRunner extends AbstractQueryRunner { - @Override - public void init(QueryRunnerContext context) { - this.coordinatorAddress2Stub = new HashMap<>(); - super.init(context); + private static final Logger LOGGER = LoggerFactory.getLogger(HttpQueryRunner.class); + private static final String USER_AGENT_VALUE = + HttpQueryRunner.class.getSimpleName() + + "/" + + firstNonNull(HttpQueryRunner.class.getPackage().getImplementationVersion(), "unknown"); + + protected static final String URL_PATH = "/rest/analytics/query/execute"; + protected static final String QUERY = "query"; + private Map coordinatorAddress2Stub; + + @Override + public void init(QueryRunnerContext context) { + this.coordinatorAddress2Stub = new HashMap<>(); + super.init(context); + } + + @Override + protected void initManagedChannel(HostAndPort address) { + CloseableHttpClient httpClient = this.coordinatorAddress2Stub.get(address); + if (httpClient == null) { + this.coordinatorAddress2Stub.put(address, createHttpClient(config)); } - - @Override - protected void initManagedChannel(HostAndPort address) { - CloseableHttpClient httpClient = this.coordinatorAddress2Stub.get(address); - if (httpClient == null) { - this.coordinatorAddress2Stub.put(address, createHttpClient(config)); - } + } + + @Override + public QueryResults executeQuery(String queryScript) { + int coordinatorNum = analyticsServiceInfo.getCoordinatorNum(); + if (coordinatorNum == 0) { + QueryError queryError = ANALYTICS_NO_COORDINATOR.getQueryError(); + return new QueryResults(queryError); } - - @Override - public QueryResults executeQuery(String queryScript) { - int coordinatorNum = analyticsServiceInfo.getCoordinatorNum(); - if (coordinatorNum == 0) { - QueryError queryError = ANALYTICS_NO_COORDINATOR.getQueryError(); - return new QueryResults(queryError); - } - int idx = RANDOM.nextInt(coordinatorNum); - List coordinatorAddresses = analyticsServiceInfo.getCoordinatorAddresses(); - QueryResults result = null; - for (int i = 0; i < coordinatorAddresses.size(); i++) { - HostAndPort address = coordinatorAddresses.get(idx); - final long start = System.currentTimeMillis(); - result = executeInternal(address, queryScript); - LOGGER.info("coordinator {} execute query script {} finish, cost {} ms", address, queryScript, System.currentTimeMillis() - start); - if (!result.getQueryStatus() && result.getError().getCode() == ANALYTICS_SERVER_BUSY.getQueryError().getCode()) { - LOGGER.warn("coordinator[{}] [{}] is busy, try next", idx, address.toString()); - idx = (idx + 1) % coordinatorNum; - continue; - } - queryRunnerStatus.compareAndSet(RUNNING, FINISHED); - return result; - } - - if (result != null && (!result.getQueryStatus() && result.getError().getCode() == ANALYTICS_SERVER_BUSY.getQueryError().getCode())) { - QueryError queryError = ANALYTICS_SERVER_BUSY.getQueryError(); - LOGGER.error(queryError.getName()); - queryRunnerStatus.compareAndSet(RUNNING, ERROR); - return new QueryResults(queryError); - } - throw new GeaflowRuntimeException(RuntimeErrors.INST.analyticsClientError(String.format("execute query [%s] error", queryScript))); + int idx = RANDOM.nextInt(coordinatorNum); + List coordinatorAddresses = analyticsServiceInfo.getCoordinatorAddresses(); + QueryResults result = null; + for (int i = 0; i < coordinatorAddresses.size(); i++) { + HostAndPort address = coordinatorAddresses.get(idx); + final long start = System.currentTimeMillis(); + result = executeInternal(address, queryScript); + LOGGER.info( + "coordinator {} execute query script {} finish, cost {} ms", + address, + queryScript, + System.currentTimeMillis() - start); + if (!result.getQueryStatus() + && result.getError().getCode() == ANALYTICS_SERVER_BUSY.getQueryError().getCode()) { + LOGGER.warn("coordinator[{}] [{}] is busy, try next", idx, address.toString()); + idx = (idx + 1) % coordinatorNum; + continue; + } + queryRunnerStatus.compareAndSet(RUNNING, FINISHED); + return result; } - @Override - public ServiceType getServiceType() { - return ServiceType.analytics_http; + if (result != null + && (!result.getQueryStatus() + && result.getError().getCode() == ANALYTICS_SERVER_BUSY.getQueryError().getCode())) { + QueryError queryError = ANALYTICS_SERVER_BUSY.getQueryError(); + LOGGER.error(queryError.getName()); + queryRunnerStatus.compareAndSet(RUNNING, ERROR); + return new QueryResults(queryError); } - - @Override - public QueryResults cancelQuery(long queryId) { - throw new GeaflowRuntimeException("not support cancel query"); + throw new GeaflowRuntimeException( + RuntimeErrors.INST.analyticsClientError( + String.format("execute query [%s] error", queryScript))); + } + + @Override + public ServiceType getServiceType() { + return ServiceType.analytics_http; + } + + @Override + public QueryResults cancelQuery(long queryId) { + throw new GeaflowRuntimeException("not support cancel query"); + } + + private QueryResults executeInternal(HostAndPort address, String script) { + CloseableHttpClient httpClient = coordinatorAddress2Stub.get(address); + if (httpClient == null) { + initManagedChannel(address); + httpClient = coordinatorAddress2Stub.get(address); } - - - private QueryResults executeInternal(HostAndPort address, String script) { - CloseableHttpClient httpClient = coordinatorAddress2Stub.get(address); - if (httpClient == null) { - initManagedChannel(address); - httpClient = coordinatorAddress2Stub.get(address); - } - AnalyticsManagerSession analyticsManagerSession = createClientSession(address.getHost(), address.getPort()); - HttpUriRequest queryRequest = buildQueryRequest(analyticsManagerSession, script); - HttpResponse response = HttpResponse.execute(httpClient, queryRequest); - if ((response.getStatusCode() != SC_OK) || !response.enableQuerySuccess()) { - this.initAnalyticsServiceAddress(); - queryRunnerStatus.compareAndSet(RUNNING, ABORTED); - LOGGER.warn("coordinator execute query error, need re-init"); - } - return response.getValue(); + AnalyticsManagerSession analyticsManagerSession = + createClientSession(address.getHost(), address.getPort()); + HttpUriRequest queryRequest = buildQueryRequest(analyticsManagerSession, script); + HttpResponse response = HttpResponse.execute(httpClient, queryRequest); + if ((response.getStatusCode() != SC_OK) || !response.enableQuerySuccess()) { + this.initAnalyticsServiceAddress(); + queryRunnerStatus.compareAndSet(RUNNING, ABORTED); + LOGGER.warn("coordinator execute query error, need re-init"); } - - @Override - public void close() throws IOException { - for (Entry entry : this.coordinatorAddress2Stub.entrySet()) { - HostAndPort address = entry.getKey(); - CloseableHttpClient client = entry.getValue(); - if (client != null) { - try { - client.close(); - } catch (Exception e) { - LOGGER.warn("coordinator [{}:{}] shutdown failed", address.getHost(), - address.getPort(), e); - throw new GeaflowRuntimeException(String.format("coordinator [%s:%d] " - + "shutdown error", address.getHost(), address.getPort()), e); - } - } + return response.getValue(); + } + + @Override + public void close() throws IOException { + for (Entry entry : this.coordinatorAddress2Stub.entrySet()) { + HostAndPort address = entry.getKey(); + CloseableHttpClient client = entry.getValue(); + if (client != null) { + try { + client.close(); + } catch (Exception e) { + LOGGER.warn( + "coordinator [{}:{}] shutdown failed", address.getHost(), address.getPort(), e); + throw new GeaflowRuntimeException( + String.format( + "coordinator [%s:%d] " + "shutdown error", address.getHost(), address.getPort()), + e); } - if (this.serverQueryClient != null) { - this.serverQueryClient.close(); - } - LOGGER.info("http query executor shutdown"); + } } - - private CloseableHttpClient createHttpClient(Configuration config) { - HttpClientBuilder clientBuilder = HttpClientBuilder.create(); - int connectTimeout = config.getInteger(ANALYTICS_CLIENT_CONNECT_TIMEOUT_MS); - clientBuilder.setConnectionTimeToLive(connectTimeout, TimeUnit.MILLISECONDS); - int retryNum = config.getInteger(ANALYTICS_CLIENT_CONNECT_RETRY_NUM); - clientBuilder.setRetryHandler(new DefaultHttpRequestRetryHandler(retryNum, true)); - clientBuilder.setRedirectStrategy(new LaxRedirectStrategy()); - int requestTimeout = config.getInteger(ANALYTICS_CLIENT_REQUEST_TIMEOUT_MS); - SocketConfig socketConfig = SocketConfig.custom() - .setSoTimeout(requestTimeout) - .setTcpNoDelay(true) - .build(); - clientBuilder.setDefaultSocketConfig(socketConfig); - clientBuilder.setUserAgent(USER_AGENT_VALUE); - return clientBuilder.build(); + if (this.serverQueryClient != null) { + this.serverQueryClient.close(); } - - - private HttpUriRequest buildQueryRequest(AnalyticsManagerSession session, String script) { - URI serverUri = session.getServer(); - if (serverUri == null) { - throw new GeaflowRuntimeException("Invalid server URL is null"); - } - String fullUri = serverUri.resolve(URL_PATH).toString(); - HttpPost httpPost = new HttpPost(fullUri); - Map params = new HashMap<>(); - params.put(QUERY, script); - StringEntity requestEntity = new StringEntity(new Gson().toJson(params), ContentType.APPLICATION_JSON); - httpPost.setEntity(requestEntity); - Map customHeaders = session.getCustomHeaders(); - customHeaders.forEach(httpPost::setHeader); - return httpPost; + LOGGER.info("http query executor shutdown"); + } + + private CloseableHttpClient createHttpClient(Configuration config) { + HttpClientBuilder clientBuilder = HttpClientBuilder.create(); + int connectTimeout = config.getInteger(ANALYTICS_CLIENT_CONNECT_TIMEOUT_MS); + clientBuilder.setConnectionTimeToLive(connectTimeout, TimeUnit.MILLISECONDS); + int retryNum = config.getInteger(ANALYTICS_CLIENT_CONNECT_RETRY_NUM); + clientBuilder.setRetryHandler(new DefaultHttpRequestRetryHandler(retryNum, true)); + clientBuilder.setRedirectStrategy(new LaxRedirectStrategy()); + int requestTimeout = config.getInteger(ANALYTICS_CLIENT_REQUEST_TIMEOUT_MS); + SocketConfig socketConfig = + SocketConfig.custom().setSoTimeout(requestTimeout).setTcpNoDelay(true).build(); + clientBuilder.setDefaultSocketConfig(socketConfig); + clientBuilder.setUserAgent(USER_AGENT_VALUE); + return clientBuilder.build(); + } + + private HttpUriRequest buildQueryRequest(AnalyticsManagerSession session, String script) { + URI serverUri = session.getServer(); + if (serverUri == null) { + throw new GeaflowRuntimeException("Invalid server URL is null"); } - + String fullUri = serverUri.resolve(URL_PATH).toString(); + HttpPost httpPost = new HttpPost(fullUri); + Map params = new HashMap<>(); + params.put(QUERY, script); + StringEntity requestEntity = + new StringEntity(new Gson().toJson(params), ContentType.APPLICATION_JSON); + httpPost.setEntity(requestEntity); + Map customHeaders = session.getCustomHeaders(); + customHeaders.forEach(httpPost::setHeader); + return httpPost; + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/HttpResponse.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/HttpResponse.java index 4df0b8899..9fe9dd797 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/HttpResponse.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/HttpResponse.java @@ -22,6 +22,7 @@ import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.util.Arrays; + import org.apache.geaflow.analytics.service.query.QueryResults; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.serialize.SerializerFactory; @@ -34,103 +35,104 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * This class is an adaptation of Presto's com.facebook.presto.client.JsonResponse. - */ +/** This class is an adaptation of Presto's com.facebook.presto.client.JsonResponse. */ public class HttpResponse { - private static final Logger LOGGER = LoggerFactory.getLogger(HttpResponse.class); - private static final int RESULT_BUFFER_SIZE = 4096; - private static final String STATUS_CODE = "statusCode"; - private static final String STATUS_MESSAGE = "statusMessage"; - private static final String HEADERS = "headers"; - private static final String QUERY_SUCCESS_KEY = "querySuccess"; - private static final String VALUE_KEY = "value"; - - private final int statusCode; - - private final String statusMessage; - - private final Header[] headers; - private byte[] responseBytes; - private boolean querySuccess; - private QueryResults value; - - private GeaflowRuntimeException exception; - - private HttpResponse(int statusCode, String statusMessage, Header[] responseAllHeaders, - byte[] responseBytes) { - this.statusCode = statusCode; - this.statusMessage = statusMessage; - this.headers = requireNonNull(responseAllHeaders, "headers is null"); - this.responseBytes = requireNonNull(responseBytes, "responseBytes is null"); - if (statusCode == HttpStatus.SC_OK) { - this.value = (QueryResults) SerializerFactory.getKryoSerializer().deserialize(responseBytes); - this.querySuccess = true; - } else { - String response = new String(responseBytes, StandardCharsets.UTF_8); - exception = new GeaflowRuntimeException(String.format("analytics http request failed," - + " status code: %s, status message: %s, response: %s", statusCode, statusMessage, response)); - this.querySuccess = false; - } - } - - public static HttpResponse execute(CloseableHttpClient client, HttpUriRequest request) { - try (CloseableHttpResponse response = client.execute(request)) { - HttpEntity entity = response.getEntity(); - InputStream content = entity.getContent(); - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - byte[] buffer = new byte[RESULT_BUFFER_SIZE]; - int bytesRead; - while ((bytesRead = content.read(buffer)) != -1) { - outputStream.write(buffer, 0, bytesRead); - } - byte[] resultBytes = outputStream.toByteArray(); - outputStream.close(); - content.close(); - int responseCode = response.getStatusLine().getStatusCode(); - Header[] responseAllHeaders = response.getAllHeaders(); - String responseMessage = response.getStatusLine().toString(); - return new HttpResponse(responseCode, responseMessage, responseAllHeaders, resultBytes); - } catch (IOException e) { - throw new GeaflowRuntimeException(String.format("execute analytics http request %s " - + "fail", request.getURI().toString()), e); - } - } - - - public boolean enableQuerySuccess() { - return querySuccess; + private static final Logger LOGGER = LoggerFactory.getLogger(HttpResponse.class); + private static final int RESULT_BUFFER_SIZE = 4096; + private static final String STATUS_CODE = "statusCode"; + private static final String STATUS_MESSAGE = "statusMessage"; + private static final String HEADERS = "headers"; + private static final String QUERY_SUCCESS_KEY = "querySuccess"; + private static final String VALUE_KEY = "value"; + + private final int statusCode; + + private final String statusMessage; + + private final Header[] headers; + private byte[] responseBytes; + private boolean querySuccess; + private QueryResults value; + + private GeaflowRuntimeException exception; + + private HttpResponse( + int statusCode, String statusMessage, Header[] responseAllHeaders, byte[] responseBytes) { + this.statusCode = statusCode; + this.statusMessage = statusMessage; + this.headers = requireNonNull(responseAllHeaders, "headers is null"); + this.responseBytes = requireNonNull(responseBytes, "responseBytes is null"); + if (statusCode == HttpStatus.SC_OK) { + this.value = (QueryResults) SerializerFactory.getKryoSerializer().deserialize(responseBytes); + this.querySuccess = true; + } else { + String response = new String(responseBytes, StandardCharsets.UTF_8); + exception = + new GeaflowRuntimeException( + String.format( + "analytics http request failed," + + " status code: %s, status message: %s, response: %s", + statusCode, statusMessage, response)); + this.querySuccess = false; } - - public QueryResults getValue() { - if (!querySuccess) { - throw new GeaflowRuntimeException("Response does not contain value", exception); - } - return value; - } - - public int getStatusCode() { - return statusCode; + } + + public static HttpResponse execute(CloseableHttpClient client, HttpUriRequest request) { + try (CloseableHttpResponse response = client.execute(request)) { + HttpEntity entity = response.getEntity(); + InputStream content = entity.getContent(); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + byte[] buffer = new byte[RESULT_BUFFER_SIZE]; + int bytesRead; + while ((bytesRead = content.read(buffer)) != -1) { + outputStream.write(buffer, 0, bytesRead); + } + byte[] resultBytes = outputStream.toByteArray(); + outputStream.close(); + content.close(); + int responseCode = response.getStatusLine().getStatusCode(); + Header[] responseAllHeaders = response.getAllHeaders(); + String responseMessage = response.getStatusLine().toString(); + return new HttpResponse(responseCode, responseMessage, responseAllHeaders, resultBytes); + } catch (IOException e) { + throw new GeaflowRuntimeException( + String.format("execute analytics http request %s " + "fail", request.getURI().toString()), + e); } + } - public Header[] getHeaders() { - return headers; - } + public boolean enableQuerySuccess() { + return querySuccess; + } - public GeaflowRuntimeException getException() { - return exception; + public QueryResults getValue() { + if (!querySuccess) { + throw new GeaflowRuntimeException("Response does not contain value", exception); } - - @Override - public String toString() { - return toStringHelper(this) - .add(STATUS_CODE, statusCode) - .add(STATUS_MESSAGE, statusMessage) - .add(HEADERS, Arrays.toString(headers)) - .add(QUERY_SUCCESS_KEY, querySuccess) - .add(VALUE_KEY, value) - .omitNullValues() - .toString(); - } - + return value; + } + + public int getStatusCode() { + return statusCode; + } + + public Header[] getHeaders() { + return headers; + } + + public GeaflowRuntimeException getException() { + return exception; + } + + @Override + public String toString() { + return toStringHelper(this) + .add(STATUS_CODE, statusCode) + .add(STATUS_MESSAGE, statusMessage) + .add(HEADERS, Arrays.toString(headers)) + .add(QUERY_SUCCESS_KEY, querySuccess) + .add(VALUE_KEY, value) + .omitNullValues() + .toString(); + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/IQueryRunner.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/IQueryRunner.java index a6436281c..47a752cc1 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/IQueryRunner.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/IQueryRunner.java @@ -20,48 +20,33 @@ package org.apache.geaflow.analytics.service.client; import java.io.Closeable; + import org.apache.geaflow.analytics.service.query.QueryResults; import org.apache.geaflow.pipeline.service.ServiceType; public interface IQueryRunner extends Closeable { - /** - * Init query runner. - */ - void init(QueryRunnerContext handlerContext); + /** Init query runner. */ + void init(QueryRunnerContext handlerContext); - /** - * Execute query. - */ - QueryResults executeQuery(String queryScript); + /** Execute query. */ + QueryResults executeQuery(String queryScript); - /** - * Get service type. - */ - ServiceType getServiceType(); + /** Get service type. */ + ServiceType getServiceType(); - /** - * Cancel query. - */ - QueryResults cancelQuery(long queryId); + /** Cancel query. */ + QueryResults cancelQuery(long queryId); - /** - * Query runner is running. - */ - boolean isRunning(); + /** Query runner is running. */ + boolean isRunning(); - /** - * Query runner is aborted, when request fail. - */ - boolean isAborted(); + /** Query runner is aborted, when request fail. */ + boolean isAborted(); - /** - * Query runner is error. - */ - boolean isError(); + /** Query runner is error. */ + boolean isError(); - /** - * Query runner is finished. - */ - boolean isFinished(); + /** Query runner is finished. */ + boolean isFinished(); } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/QueryRunnerContext.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/QueryRunnerContext.java index 281543699..29f3c2114 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/QueryRunnerContext.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/QueryRunnerContext.java @@ -23,84 +23,83 @@ import org.apache.geaflow.common.rpc.HostAndPort; public class QueryRunnerContext { - private final Configuration configuration; - private final boolean initChannelPools; - private final HostAndPort hostAndPort; - private final String metaServerBaseNode; - private final String metaServerAddress; - - public Configuration getConfiguration() { - return configuration; + private final Configuration configuration; + private final boolean initChannelPools; + private final HostAndPort hostAndPort; + private final String metaServerBaseNode; + private final String metaServerAddress; + + public Configuration getConfiguration() { + return configuration; + } + + public HostAndPort getHostAndPort() { + return hostAndPort; + } + + public boolean isInitChannelPools() { + return initChannelPools; + } + + public String getMetaServerBaseNode() { + return metaServerBaseNode; + } + + public String getMetaServerAddress() { + return metaServerAddress; + } + + public boolean enableInitChannelPools() { + return initChannelPools; + } + + public static ClientHandlerContextBuilder newBuilder() { + return new ClientHandlerContextBuilder(); + } + + public static class ClientHandlerContextBuilder { + private Configuration configuration; + private boolean initChannelPools; + + private HostAndPort hostAndPort; + private String analyticsServiceJobName; + private String metaServerAddress; + + public ClientHandlerContextBuilder setAnalyticsServiceJobName(String analyticsServiceJobName) { + this.analyticsServiceJobName = analyticsServiceJobName; + return this; } - public HostAndPort getHostAndPort() { - return hostAndPort; + public ClientHandlerContextBuilder setMetaServerAddress(String metaServerAddress) { + this.metaServerAddress = metaServerAddress; + return this; } - public boolean isInitChannelPools() { - return initChannelPools; + public ClientHandlerContextBuilder setConfiguration(Configuration configuration) { + this.configuration = configuration; + return this; } - public String getMetaServerBaseNode() { - return metaServerBaseNode; + public ClientHandlerContextBuilder setHost(HostAndPort hostAndPort) { + this.hostAndPort = hostAndPort; + return this; } - public String getMetaServerAddress() { - return metaServerAddress; + public ClientHandlerContextBuilder enableInitChannelPools(boolean initChannelPools) { + this.initChannelPools = initChannelPools; + return this; } - public boolean enableInitChannelPools() { - return initChannelPools; - } - - public static ClientHandlerContextBuilder newBuilder() { - return new ClientHandlerContextBuilder(); - } - - public static class ClientHandlerContextBuilder { - private Configuration configuration; - private boolean initChannelPools; - - private HostAndPort hostAndPort; - private String analyticsServiceJobName; - private String metaServerAddress; - - public ClientHandlerContextBuilder setAnalyticsServiceJobName(String analyticsServiceJobName) { - this.analyticsServiceJobName = analyticsServiceJobName; - return this; - } - - public ClientHandlerContextBuilder setMetaServerAddress(String metaServerAddress) { - this.metaServerAddress = metaServerAddress; - return this; - } - - public ClientHandlerContextBuilder setConfiguration(Configuration configuration) { - this.configuration = configuration; - return this; - } - - public ClientHandlerContextBuilder setHost(HostAndPort hostAndPort) { - this.hostAndPort = hostAndPort; - return this; - } - - public ClientHandlerContextBuilder enableInitChannelPools(boolean initChannelPools) { - this.initChannelPools = initChannelPools; - return this; - } - - public QueryRunnerContext build() { - return new QueryRunnerContext(this); - } - } - - - private QueryRunnerContext(ClientHandlerContextBuilder builder) { - this.configuration = builder.configuration; - this.initChannelPools = builder.initChannelPools; - this.hostAndPort = builder.hostAndPort; - this.metaServerAddress = builder.metaServerAddress; - this.metaServerBaseNode = builder.analyticsServiceJobName; + public QueryRunnerContext build() { + return new QueryRunnerContext(this); } + } + + private QueryRunnerContext(ClientHandlerContextBuilder builder) { + this.configuration = builder.configuration; + this.initChannelPools = builder.initChannelPools; + this.hostAndPort = builder.hostAndPort; + this.metaServerAddress = builder.metaServerAddress; + this.metaServerBaseNode = builder.analyticsServiceJobName; + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/QueryRunnerFactory.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/QueryRunnerFactory.java index 143692a50..095d9ffa9 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/QueryRunnerFactory.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/QueryRunnerFactory.java @@ -21,6 +21,7 @@ import java.util.Iterator; import java.util.ServiceLoader; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.FrameworkConfigKeys; import org.apache.geaflow.common.errorcode.RuntimeErrors; @@ -31,22 +32,23 @@ public class QueryRunnerFactory { - private static final Logger LOGGER = LoggerFactory.getLogger(QueryRunnerFactory.class); + private static final Logger LOGGER = LoggerFactory.getLogger(QueryRunnerFactory.class); - public static IQueryRunner loadQueryRunner(QueryRunnerContext handlerContext) { - Configuration configuration = handlerContext.getConfiguration(); - String type = configuration.getString(FrameworkConfigKeys.SERVICE_SERVER_TYPE); - ServiceLoader contextLoader = ServiceLoader.load(IQueryRunner.class); - Iterator contextIterable = contextLoader.iterator(); - while (contextIterable.hasNext()) { - IQueryRunner clientHandler = contextIterable.next(); - if (clientHandler.getServiceType() == ServiceType.getEnum(type)) { - LOGGER.info("loaded IClientHandler implementation {}", clientHandler); - clientHandler.init(handlerContext); - return clientHandler; - } - } - LOGGER.error("NOT found IClientHandler implementation with type:{}", type); - throw new GeaflowRuntimeException(RuntimeErrors.INST.spiNotFoundError(IQueryRunner.class.getSimpleName())); + public static IQueryRunner loadQueryRunner(QueryRunnerContext handlerContext) { + Configuration configuration = handlerContext.getConfiguration(); + String type = configuration.getString(FrameworkConfigKeys.SERVICE_SERVER_TYPE); + ServiceLoader contextLoader = ServiceLoader.load(IQueryRunner.class); + Iterator contextIterable = contextLoader.iterator(); + while (contextIterable.hasNext()) { + IQueryRunner clientHandler = contextIterable.next(); + if (clientHandler.getServiceType() == ServiceType.getEnum(type)) { + LOGGER.info("loaded IClientHandler implementation {}", clientHandler); + clientHandler.init(handlerContext); + return clientHandler; + } } + LOGGER.error("NOT found IClientHandler implementation with type:{}", type); + throw new GeaflowRuntimeException( + RuntimeErrors.INST.spiNotFoundError(IQueryRunner.class.getSimpleName())); + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/QueryRunnerStatus.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/QueryRunnerStatus.java index 99959f589..532c263c3 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/QueryRunnerStatus.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/QueryRunnerStatus.java @@ -21,23 +21,15 @@ public enum QueryRunnerStatus { - /** - * Running status. - */ - RUNNING, + /** Running status. */ + RUNNING, - /** - * Error status. - */ - ERROR, + /** Error status. */ + ERROR, - /** - * Aborted status. - */ - ABORTED, + /** Aborted status. */ + ABORTED, - /** - * Finished status. - */ - FINISHED + /** Finished status. */ + FINISHED } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/RpcQueryRunner.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/RpcQueryRunner.java index c43cb9013..9b06ac978 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/RpcQueryRunner.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/RpcQueryRunner.java @@ -28,17 +28,13 @@ import static org.apache.geaflow.analytics.service.query.StandardError.ANALYTICS_NO_COORDINATOR; import static org.apache.geaflow.analytics.service.query.StandardError.ANALYTICS_SERVER_BUSY; -import com.google.protobuf.ByteString; -import io.grpc.ManagedChannel; -import io.grpc.StatusRuntimeException; -import io.grpc.netty.NettyChannelBuilder; -import io.netty.channel.ChannelOption; import java.io.IOException; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.analytics.service.config.AnalyticsClientConfigKeys; import org.apache.geaflow.analytics.service.query.QueryError; import org.apache.geaflow.analytics.service.query.QueryResults; @@ -60,196 +56,236 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class RpcQueryRunner extends AbstractQueryRunner { +import com.google.protobuf.ByteString; - private static final Logger LOGGER = LoggerFactory.getLogger(RpcQueryRunner.class); - private static final int DEFAULT_SHUTDOWN_AWAIT_MS = 5000; - private static final ISerializer SERIALIZER = SerializerFactory.getKryoSerializer(); - private static final String REQUEST = "request"; - private Map coordinatorAddress2Channel; - private Map coordinatorAddress2Stub; - - @Override - public void init(QueryRunnerContext context) { - this.coordinatorAddress2Channel = new HashMap<>(); - this.coordinatorAddress2Stub = new HashMap<>(); - super.init(context); - } +import io.grpc.ManagedChannel; +import io.grpc.StatusRuntimeException; +import io.grpc.netty.NettyChannelBuilder; +import io.netty.channel.ChannelOption; - @Override - public QueryResults executeQuery(String queryScript) { - int coordinatorNum = analyticsServiceInfo.getCoordinatorNum(); - if (coordinatorNum == 0) { - QueryError queryError = ANALYTICS_NO_COORDINATOR.getQueryError(); - return new QueryResults(queryError); - } - int idx = RANDOM.nextInt(coordinatorNum); - List coordinatorAddresses = analyticsServiceInfo.getCoordinatorAddresses(); - QueryResults result = null; - for (int i = 0; i < coordinatorAddresses.size(); i++) { - HostAndPort address = coordinatorAddresses.get(idx); - final long start = System.currentTimeMillis(); - result = executeInternal(address, queryScript); - LOGGER.info("coordinator {} execute query script {} finish, cost {} ms", address, queryScript, System.currentTimeMillis() - start); - if (!result.getQueryStatus() && result.getError().getCode() == ANALYTICS_SERVER_BUSY.getQueryError().getCode()) { - LOGGER.warn("coordinator[{}] [{}] is busy, try next", idx, address.toString()); - idx = (idx + 1) % coordinatorNum; - continue; - } - queryRunnerStatus.compareAndSet(RUNNING, FINISHED); - return result; - } +public class RpcQueryRunner extends AbstractQueryRunner { - if (result != null && (!result.getQueryStatus() && result.getError().getCode() == ANALYTICS_SERVER_BUSY.getQueryError().getCode())) { - QueryError queryError = ANALYTICS_SERVER_BUSY.getQueryError(); - LOGGER.error(queryError.getName()); - queryRunnerStatus.compareAndSet(RUNNING, ERROR); - return new QueryResults(queryError); - } - throw new GeaflowRuntimeException(RuntimeErrors.INST.analyticsClientError(String.format("execute query [%s] error", queryScript))); - } + private static final Logger LOGGER = LoggerFactory.getLogger(RpcQueryRunner.class); + private static final int DEFAULT_SHUTDOWN_AWAIT_MS = 5000; + private static final ISerializer SERIALIZER = SerializerFactory.getKryoSerializer(); + private static final String REQUEST = "request"; + private Map coordinatorAddress2Channel; + private Map coordinatorAddress2Stub; - @Override - public void initManagedChannel(HostAndPort address) { - ManagedChannel coordinatorChannel = this.coordinatorAddress2Channel.get(address); - if (coordinatorChannel != null && (!coordinatorChannel.isShutdown() || !coordinatorChannel.isTerminated())) { - coordinatorChannel.shutdownNow(); - this.coordinatorAddress2Channel.remove(address); - } - ManagedChannel managedChannel = createManagedChannel(config, address); - this.coordinatorAddress2Channel.put(address, managedChannel); - this.coordinatorAddress2Stub.put(address, AnalyticsServiceGrpc.newBlockingStub(managedChannel)); - } + @Override + public void init(QueryRunnerContext context) { + this.coordinatorAddress2Channel = new HashMap<>(); + this.coordinatorAddress2Stub = new HashMap<>(); + super.init(context); + } - public void ensureChannelAlive(HostAndPort address) { - ManagedChannel channel = this.coordinatorAddress2Channel.get(address); - if (channel == null || channel.isShutdown() || channel.isTerminated()) { - LOGGER.warn("connection of [{}:{}] lost, reconnect...", address.getHost(), address.getPort()); - this.initManagedChannel(address); - } + @Override + public QueryResults executeQuery(String queryScript) { + int coordinatorNum = analyticsServiceInfo.getCoordinatorNum(); + if (coordinatorNum == 0) { + QueryError queryError = ANALYTICS_NO_COORDINATOR.getQueryError(); + return new QueryResults(queryError); } - - private ManagedChannel buildChannel(String host, int port, int timeoutMs) { - return NettyChannelBuilder.forAddress(host, port) - .withOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, timeoutMs) - .maxInboundMessageSize(config.getInteger(AnalyticsClientConfigKeys.ANALYTICS_CLIENT_MAX_INBOUND_MESSAGE_SIZE)) - .maxRetryAttempts(config.getInteger(AnalyticsClientConfigKeys.ANALYTICS_CLIENT_MAX_RETRY_ATTEMPTS)) - .retryBufferSize(config.getLong(AnalyticsClientConfigKeys.ANALYTICS_CLIENT_DEFALUT_RETRY_BUFFER_SIZE)) - .perRpcBufferLimit(config.getLong(AnalyticsClientConfigKeys.ANALYTICS_CLIENT_PER_RPC_BUFFER_LIMIT)) - .usePlaintext() - .build(); + int idx = RANDOM.nextInt(coordinatorNum); + List coordinatorAddresses = analyticsServiceInfo.getCoordinatorAddresses(); + QueryResults result = null; + for (int i = 0; i < coordinatorAddresses.size(); i++) { + HostAndPort address = coordinatorAddresses.get(idx); + final long start = System.currentTimeMillis(); + result = executeInternal(address, queryScript); + LOGGER.info( + "coordinator {} execute query script {} finish, cost {} ms", + address, + queryScript, + System.currentTimeMillis() - start); + if (!result.getQueryStatus() + && result.getError().getCode() == ANALYTICS_SERVER_BUSY.getQueryError().getCode()) { + LOGGER.warn("coordinator[{}] [{}] is busy, try next", idx, address.toString()); + idx = (idx + 1) % coordinatorNum; + continue; + } + queryRunnerStatus.compareAndSet(RUNNING, FINISHED); + return result; } + if (result != null + && (!result.getQueryStatus() + && result.getError().getCode() == ANALYTICS_SERVER_BUSY.getQueryError().getCode())) { + QueryError queryError = ANALYTICS_SERVER_BUSY.getQueryError(); + LOGGER.error(queryError.getName()); + queryRunnerStatus.compareAndSet(RUNNING, ERROR); + return new QueryResults(queryError); + } + throw new GeaflowRuntimeException( + RuntimeErrors.INST.analyticsClientError( + String.format("execute query [%s] error", queryScript))); + } - private QueryResults executeInternal(HostAndPort address, String script) { - AnalyticsServiceBlockingStub analyticsServiceBlockingStub = getAnalyticsServiceBlockingStub(address); - if (analyticsServiceBlockingStub == null) { - initManagedChannel(address); - analyticsServiceBlockingStub = getAnalyticsServiceBlockingStub(address); - } - QueryRequest queryRequest = buildRpcRequest(script); - QueryResult queryResponse; - try { - queryResponse = analyticsServiceBlockingStub.executeQuery(queryRequest); - } catch (StatusRuntimeException e) { - LOGGER.error("query {} execute failed with status {} and cause {}", script, e.getStatus(), LogMsgUtil.getStackMsg(e)); - return getResultIfException(e, address); - } - return RpcMessageEncoder.decode(queryResponse.getQueryResult()); + @Override + public void initManagedChannel(HostAndPort address) { + ManagedChannel coordinatorChannel = this.coordinatorAddress2Channel.get(address); + if (coordinatorChannel != null + && (!coordinatorChannel.isShutdown() || !coordinatorChannel.isTerminated())) { + coordinatorChannel.shutdownNow(); + this.coordinatorAddress2Channel.remove(address); } + ManagedChannel managedChannel = createManagedChannel(config, address); + this.coordinatorAddress2Channel.put(address, managedChannel); + this.coordinatorAddress2Stub.put(address, AnalyticsServiceGrpc.newBlockingStub(managedChannel)); + } - public AnalyticsServiceBlockingStub getAnalyticsServiceBlockingStub(HostAndPort address) { - ensureChannelAlive(address); - AnalyticsServiceBlockingStub analyticsServiceBlockingStub = this.coordinatorAddress2Stub.get(address); - if (analyticsServiceBlockingStub != null) { - return analyticsServiceBlockingStub; - } - throw new GeaflowRuntimeException(String.format("coordinator address [%s:%d] get rpc stub " - + "fail", address.getHost(), address.getPort())); + public void ensureChannelAlive(HostAndPort address) { + ManagedChannel channel = this.coordinatorAddress2Channel.get(address); + if (channel == null || channel.isShutdown() || channel.isTerminated()) { + LOGGER.warn("connection of [{}:{}] lost, reconnect...", address.getHost(), address.getPort()); + this.initManagedChannel(address); } + } + private ManagedChannel buildChannel(String host, int port, int timeoutMs) { + return NettyChannelBuilder.forAddress(host, port) + .withOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, timeoutMs) + .maxInboundMessageSize( + config.getInteger(AnalyticsClientConfigKeys.ANALYTICS_CLIENT_MAX_INBOUND_MESSAGE_SIZE)) + .maxRetryAttempts( + config.getInteger(AnalyticsClientConfigKeys.ANALYTICS_CLIENT_MAX_RETRY_ATTEMPTS)) + .retryBufferSize( + config.getLong(AnalyticsClientConfigKeys.ANALYTICS_CLIENT_DEFALUT_RETRY_BUFFER_SIZE)) + .perRpcBufferLimit( + config.getLong(AnalyticsClientConfigKeys.ANALYTICS_CLIENT_PER_RPC_BUFFER_LIMIT)) + .usePlaintext() + .build(); + } - private QueryResults getResultIfException(StatusRuntimeException e, HostAndPort address) { - QueryError queryError; - switch (e.getStatus().getCode()) { - case UNAVAILABLE: - LOGGER.warn("coordinator address {} server unavailable", address); - initAnalyticsServiceAddress(); - queryError = StandardError.ANALYTICS_SERVER_UNAVAILABLE.getQueryError(); - queryRunnerStatus.compareAndSet(RUNNING, ABORTED); - return new QueryResults(queryError); - case RESOURCE_EXHAUSTED: - LOGGER.warn("query result too big: {}", e.getMessage()); - queryRunnerStatus.compareAndSet(RUNNING, ERROR); - queryError = StandardError.ANALYTICS_RESULT_TO_LONG.getQueryError(); - return new QueryResults(queryError); - default: - LOGGER.warn("re-init channel {} for unexpected exception: {}", address, e.getStatus()); - queryRunnerStatus.compareAndSet(RUNNING, ERROR); - initManagedChannel(address); - queryError = StandardError.ANALYTICS_RPC_ERROR.getQueryError(); - return new QueryResults(queryError); - } + private QueryResults executeInternal(HostAndPort address, String script) { + AnalyticsServiceBlockingStub analyticsServiceBlockingStub = + getAnalyticsServiceBlockingStub(address); + if (analyticsServiceBlockingStub == null) { + initManagedChannel(address); + analyticsServiceBlockingStub = getAnalyticsServiceBlockingStub(address); + } + QueryRequest queryRequest = buildRpcRequest(script); + QueryResult queryResponse; + try { + queryResponse = analyticsServiceBlockingStub.executeQuery(queryRequest); + } catch (StatusRuntimeException e) { + LOGGER.error( + "query {} execute failed with status {} and cause {}", + script, + e.getStatus(), + LogMsgUtil.getStackMsg(e)); + return getResultIfException(e, address); } + return RpcMessageEncoder.decode(queryResponse.getQueryResult()); + } - @Override - public ServiceType getServiceType() { - return ServiceType.analytics_rpc; + public AnalyticsServiceBlockingStub getAnalyticsServiceBlockingStub(HostAndPort address) { + ensureChannelAlive(address); + AnalyticsServiceBlockingStub analyticsServiceBlockingStub = + this.coordinatorAddress2Stub.get(address); + if (analyticsServiceBlockingStub != null) { + return analyticsServiceBlockingStub; } + throw new GeaflowRuntimeException( + String.format( + "coordinator address [%s:%d] get rpc stub " + "fail", + address.getHost(), address.getPort())); + } - @Override - public QueryResults cancelQuery(long queryId) { - throw new GeaflowRuntimeException("not support cancel query"); + private QueryResults getResultIfException(StatusRuntimeException e, HostAndPort address) { + QueryError queryError; + switch (e.getStatus().getCode()) { + case UNAVAILABLE: + LOGGER.warn("coordinator address {} server unavailable", address); + initAnalyticsServiceAddress(); + queryError = StandardError.ANALYTICS_SERVER_UNAVAILABLE.getQueryError(); + queryRunnerStatus.compareAndSet(RUNNING, ABORTED); + return new QueryResults(queryError); + case RESOURCE_EXHAUSTED: + LOGGER.warn("query result too big: {}", e.getMessage()); + queryRunnerStatus.compareAndSet(RUNNING, ERROR); + queryError = StandardError.ANALYTICS_RESULT_TO_LONG.getQueryError(); + return new QueryResults(queryError); + default: + LOGGER.warn("re-init channel {} for unexpected exception: {}", address, e.getStatus()); + queryRunnerStatus.compareAndSet(RUNNING, ERROR); + initManagedChannel(address); + queryError = StandardError.ANALYTICS_RPC_ERROR.getQueryError(); + return new QueryResults(queryError); } + } - @Override - public void close() throws IOException { - for (Entry entry : this.coordinatorAddress2Channel.entrySet()) { - ManagedChannel channel = entry.getValue(); - HostAndPort address = entry.getKey(); - if (channel != null) { - try { - channel.shutdown().awaitTermination(DEFAULT_SHUTDOWN_AWAIT_MS, TimeUnit.MILLISECONDS); - } catch (Exception e) { - LOGGER.warn("coordinator [{}:{}] shutdown failed", address.getHost(), address.getPort(), e); - throw new GeaflowRuntimeException(String.format("coordinator [%s:%d] " - + "shutdown error", address.getHost(), address.getPort()), e); - } - } - } - if (this.serverQueryClient != null) { - this.serverQueryClient.close(); - } + @Override + public ServiceType getServiceType() { + return ServiceType.analytics_rpc; + } - LOGGER.info("rpc query executor shutdown"); - } + @Override + public QueryResults cancelQuery(long queryId) { + throw new GeaflowRuntimeException("not support cancel query"); + } - private ManagedChannel createManagedChannel(Configuration config, HostAndPort address) { - Throwable latestException = null; - int connectTimeoutMs = config.getInteger(ANALYTICS_CLIENT_CONNECT_TIMEOUT_MS); - int retryNum = config.getInteger(ANALYTICS_CLIENT_CONNECT_RETRY_NUM); - for (int i = 0; i < retryNum; i++) { - try { - ManagedChannel managedChannel = buildChannel(address.getHost(), address.getPort(), connectTimeoutMs); - LOGGER.info("init managed channel with address {}:{}", address.getHost(), address.getPort()); - return managedChannel; - } catch (Throwable e) { - latestException = e; - LOGGER.warn("init managed channel [{}:{}] failed, retry {}", address.getHost(), address.getPort(), i + 1, e); - } + @Override + public void close() throws IOException { + for (Entry entry : this.coordinatorAddress2Channel.entrySet()) { + ManagedChannel channel = entry.getValue(); + HostAndPort address = entry.getKey(); + if (channel != null) { + try { + channel.shutdown().awaitTermination(DEFAULT_SHUTDOWN_AWAIT_MS, TimeUnit.MILLISECONDS); + } catch (Exception e) { + LOGGER.warn( + "coordinator [{}:{}] shutdown failed", address.getHost(), address.getPort(), e); + throw new GeaflowRuntimeException( + String.format( + "coordinator [%s:%d] " + "shutdown error", address.getHost(), address.getPort()), + e); } - String msg = String.format("try connect to [%s:%d] fail after %d times", address.getHost(), address.getPort(), retryNum); - LOGGER.error(msg, latestException); - throw new GeaflowRuntimeException(RuntimeErrors.INST.analyticsClientError(msg), latestException); + } } + if (this.serverQueryClient != null) { + this.serverQueryClient.close(); + } + + LOGGER.info("rpc query executor shutdown"); + } - private Analytics.QueryRequest buildRpcRequest(String queryScript) { - ByteString.Output output = ByteString.newOutput(); - SERIALIZER.serialize(REQUEST, output); - return Analytics.QueryRequest.newBuilder() - .setQuery(queryScript) - .setQueryConfig(output.toByteString()) - .build(); + private ManagedChannel createManagedChannel(Configuration config, HostAndPort address) { + Throwable latestException = null; + int connectTimeoutMs = config.getInteger(ANALYTICS_CLIENT_CONNECT_TIMEOUT_MS); + int retryNum = config.getInteger(ANALYTICS_CLIENT_CONNECT_RETRY_NUM); + for (int i = 0; i < retryNum; i++) { + try { + ManagedChannel managedChannel = + buildChannel(address.getHost(), address.getPort(), connectTimeoutMs); + LOGGER.info( + "init managed channel with address {}:{}", address.getHost(), address.getPort()); + return managedChannel; + } catch (Throwable e) { + latestException = e; + LOGGER.warn( + "init managed channel [{}:{}] failed, retry {}", + address.getHost(), + address.getPort(), + i + 1, + e); + } } + String msg = + String.format( + "try connect to [%s:%d] fail after %d times", + address.getHost(), address.getPort(), retryNum); + LOGGER.error(msg, latestException); + throw new GeaflowRuntimeException( + RuntimeErrors.INST.analyticsClientError(msg), latestException); + } + private Analytics.QueryRequest buildRpcRequest(String queryScript) { + ByteString.Output output = ByteString.newOutput(); + SERIALIZER.serialize(REQUEST, output); + return Analytics.QueryRequest.newBuilder() + .setQuery(queryScript) + .setQueryConfig(output.toByteString()) + .build(); + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsConnection.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsConnection.java index 8fe36be00..03ecab180 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsConnection.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsConnection.java @@ -23,7 +23,6 @@ import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MINUTES; -import com.google.common.primitives.Ints; import java.net.URI; import java.sql.Array; import java.sql.Blob; @@ -46,368 +45,368 @@ import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; + import org.apache.geaflow.analytics.service.client.AnalyticsClient; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class AnalyticsConnection implements Connection { - - private static final Logger LOGGER = LoggerFactory.getLogger(AnalyticsConnection.class); - - private static final String SEPARATOR_COLON = ":"; - private final AtomicReference networkTimeoutMillis = new AtomicReference<>(Ints.saturatedCast(MINUTES.toMillis(2))); - private final Map clientInfo = new ConcurrentHashMap<>(); - private final AtomicBoolean closed = new AtomicBoolean(); - private final AtomicReference catalog = new AtomicReference<>(); - private final URI httpUri; - private final String user; - private final Map sessionProperties; - - private final AnalyticsDriverURI analyticsDriverURI; - - private AnalyticsClient client; - - public static Connection newInstance(String url, Properties properties) { - AnalyticsDriverURI analyticsDriverURI = new AnalyticsDriverURI(url, properties); - return new AnalyticsConnection(analyticsDriverURI); - } - - private AnalyticsConnection(AnalyticsDriverURI uri) { - requireNonNull(uri, "analytics driver uri is null"); - this.analyticsDriverURI = uri; - this.user = uri.getUser(); - this.httpUri = uri.getHttpUri(); - this.sessionProperties = new ConcurrentHashMap<>(uri.getSessionProperties()); - initAnalyticsClient(); - } - - - private void initAnalyticsClient() { - String authority = this.analyticsDriverURI.getAuthority(); - try { - if (authority != null && !authority.isEmpty()) { - String[] split = authority.split(SEPARATOR_COLON); - if (split.length == 2) { - String host = split[0]; - int port = Integer.parseInt(split[1]); - this.client = AnalyticsClient.builder() - .withHost(host).withPort(port).build(); - LOGGER.info("init geaflow analytics connection with host [{}] and port [{}]", - host, port); - } else { - LOGGER.warn("illegal authority: [{}]", authority); - throw new GeaflowRuntimeException("illegal authority: " + authority); - } - } - } catch (Throwable e) { - LOGGER.warn("parse authority from driver uri failed [{}]", authority, e); - throw new GeaflowRuntimeException("analytics jdbc create client failed"); - } - } - - @Override - public Statement createStatement() throws SQLException { - checkOpen(); - return AnalyticsStatement.newInstance(this, client); - } - - private void checkOpen() throws SQLException { - if (isClosed()) { - throw new GeaflowRuntimeException("Connection is closed"); - } - } - - @Override - public PreparedStatement prepareStatement(String sql) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public CallableStatement prepareCall(String sql) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String nativeSQL(String sql) throws SQLException { - checkOpen(); - return sql; - } - - @Override - public void setAutoCommit(boolean autoCommit) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean getAutoCommit() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void commit() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void rollback() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void close() throws SQLException { - closed.set(true); - this.client.shutdown(); - } - - @Override - public boolean isClosed() throws SQLException { - return closed.get(); - } - - @Override - public DatabaseMetaData getMetaData() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void setReadOnly(boolean readOnly) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isReadOnly() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void setCatalog(String catalog) throws SQLException { - checkOpen(); - this.catalog.set(catalog); - } - - @Override - public String getCatalog() throws SQLException { - checkOpen(); - return catalog.get(); - } - - @Override - public void setTransactionIsolation(int level) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getTransactionIsolation() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public SQLWarning getWarnings() throws SQLException { - checkOpen(); - return null; - } - - @Override - public void clearWarnings() throws SQLException { - checkOpen(); - } - - @Override - public Statement createStatement(int resultSetType, int resultSetConcurrency) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public PreparedStatement prepareStatement(String sql, int resultSetType, - int resultSetConcurrency) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Map> getTypeMap() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void setTypeMap(Map> map) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void setHoldability(int holdability) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getHoldability() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Savepoint setSavepoint() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Savepoint setSavepoint(String name) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void rollback(Savepoint savepoint) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void releaseSavepoint(Savepoint savepoint) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Statement createStatement(int resultSetType, int resultSetConcurrency, - int resultSetHoldability) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public PreparedStatement prepareStatement(String sql, int resultSetType, - int resultSetConcurrency, int resultSetHoldability) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency, - int resultSetHoldability) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public PreparedStatement prepareStatement(String sql, String[] columnNames) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Clob createClob() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Blob createBlob() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public NClob createNClob() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public SQLXML createSQLXML() throws SQLException { - throw new UnsupportedOperationException(); - } +import com.google.common.primitives.Ints; - @Override - public boolean isValid(int timeout) throws SQLException { - throw new UnsupportedOperationException(); - } +public class AnalyticsConnection implements Connection { - @Override - public void setClientInfo(String name, String value) throws SQLClientInfoException { - requireNonNull(name, "name is null"); - if (value != null) { - clientInfo.put(name, value); + private static final Logger LOGGER = LoggerFactory.getLogger(AnalyticsConnection.class); + + private static final String SEPARATOR_COLON = ":"; + private final AtomicReference networkTimeoutMillis = + new AtomicReference<>(Ints.saturatedCast(MINUTES.toMillis(2))); + private final Map clientInfo = new ConcurrentHashMap<>(); + private final AtomicBoolean closed = new AtomicBoolean(); + private final AtomicReference catalog = new AtomicReference<>(); + private final URI httpUri; + private final String user; + private final Map sessionProperties; + + private final AnalyticsDriverURI analyticsDriverURI; + + private AnalyticsClient client; + + public static Connection newInstance(String url, Properties properties) { + AnalyticsDriverURI analyticsDriverURI = new AnalyticsDriverURI(url, properties); + return new AnalyticsConnection(analyticsDriverURI); + } + + private AnalyticsConnection(AnalyticsDriverURI uri) { + requireNonNull(uri, "analytics driver uri is null"); + this.analyticsDriverURI = uri; + this.user = uri.getUser(); + this.httpUri = uri.getHttpUri(); + this.sessionProperties = new ConcurrentHashMap<>(uri.getSessionProperties()); + initAnalyticsClient(); + } + + private void initAnalyticsClient() { + String authority = this.analyticsDriverURI.getAuthority(); + try { + if (authority != null && !authority.isEmpty()) { + String[] split = authority.split(SEPARATOR_COLON); + if (split.length == 2) { + String host = split[0]; + int port = Integer.parseInt(split[1]); + this.client = AnalyticsClient.builder().withHost(host).withPort(port).build(); + LOGGER.info("init geaflow analytics connection with host [{}] and port [{}]", host, port); } else { - clientInfo.remove(name); - } - } - - @Override - public void setClientInfo(Properties properties) throws SQLClientInfoException { - clientInfo.putAll(fromProperties(properties)); - } - - @Override - public String getClientInfo(String name) throws SQLException { - return clientInfo.get(name); - } - - @Override - public Properties getClientInfo() throws SQLException { - Properties properties = new Properties(); - for (Map.Entry entry : clientInfo.entrySet()) { - properties.setProperty(entry.getKey(), entry.getValue()); - } - return properties; - } - - @Override - public Array createArrayOf(String typeName, Object[] elements) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Struct createStruct(String typeName, Object[] attributes) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void setSchema(String schema) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String getSchema() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void abort(Executor executor) throws SQLException { - close(); - } - - @Override - public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException { - checkOpen(); - if (milliseconds < 0) { - throw new GeaflowRuntimeException("NetWork timeout is negative"); + LOGGER.warn("illegal authority: [{}]", authority); + throw new GeaflowRuntimeException("illegal authority: " + authority); } - networkTimeoutMillis.set(milliseconds); - } - - @Override - public int getNetworkTimeout() throws SQLException { - checkOpen(); - return networkTimeoutMillis.get(); - } - - @Override - public T unwrap(Class iFace) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isWrapperFor(Class iFace) throws SQLException { - throw new UnsupportedOperationException(); - } + } + } catch (Throwable e) { + LOGGER.warn("parse authority from driver uri failed [{}]", authority, e); + throw new GeaflowRuntimeException("analytics jdbc create client failed"); + } + } + + @Override + public Statement createStatement() throws SQLException { + checkOpen(); + return AnalyticsStatement.newInstance(this, client); + } + + private void checkOpen() throws SQLException { + if (isClosed()) { + throw new GeaflowRuntimeException("Connection is closed"); + } + } + + @Override + public PreparedStatement prepareStatement(String sql) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public CallableStatement prepareCall(String sql) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public String nativeSQL(String sql) throws SQLException { + checkOpen(); + return sql; + } + + @Override + public void setAutoCommit(boolean autoCommit) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean getAutoCommit() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void commit() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void rollback() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void close() throws SQLException { + closed.set(true); + this.client.shutdown(); + } + + @Override + public boolean isClosed() throws SQLException { + return closed.get(); + } + + @Override + public DatabaseMetaData getMetaData() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void setReadOnly(boolean readOnly) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isReadOnly() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void setCatalog(String catalog) throws SQLException { + checkOpen(); + this.catalog.set(catalog); + } + + @Override + public String getCatalog() throws SQLException { + checkOpen(); + return catalog.get(); + } + + @Override + public void setTransactionIsolation(int level) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getTransactionIsolation() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public SQLWarning getWarnings() throws SQLException { + checkOpen(); + return null; + } + + @Override + public void clearWarnings() throws SQLException { + checkOpen(); + } + + @Override + public Statement createStatement(int resultSetType, int resultSetConcurrency) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Map> getTypeMap() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void setTypeMap(Map> map) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void setHoldability(int holdability) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getHoldability() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Savepoint setSavepoint() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Savepoint setSavepoint(String name) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void rollback(Savepoint savepoint) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void releaseSavepoint(Savepoint savepoint) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Statement createStatement( + int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public PreparedStatement prepareStatement( + String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public CallableStatement prepareCall( + String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Clob createClob() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Blob createBlob() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public NClob createNClob() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public SQLXML createSQLXML() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isValid(int timeout) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void setClientInfo(String name, String value) throws SQLClientInfoException { + requireNonNull(name, "name is null"); + if (value != null) { + clientInfo.put(name, value); + } else { + clientInfo.remove(name); + } + } + + @Override + public void setClientInfo(Properties properties) throws SQLClientInfoException { + clientInfo.putAll(fromProperties(properties)); + } + + @Override + public String getClientInfo(String name) throws SQLException { + return clientInfo.get(name); + } + + @Override + public Properties getClientInfo() throws SQLException { + Properties properties = new Properties(); + for (Map.Entry entry : clientInfo.entrySet()) { + properties.setProperty(entry.getKey(), entry.getValue()); + } + return properties; + } + + @Override + public Array createArrayOf(String typeName, Object[] elements) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Struct createStruct(String typeName, Object[] attributes) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void setSchema(String schema) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public String getSchema() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void abort(Executor executor) throws SQLException { + close(); + } + + @Override + public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException { + checkOpen(); + if (milliseconds < 0) { + throw new GeaflowRuntimeException("NetWork timeout is negative"); + } + networkTimeoutMillis.set(milliseconds); + } + + @Override + public int getNetworkTimeout() throws SQLException { + checkOpen(); + return networkTimeoutMillis.get(); + } + + @Override + public T unwrap(Class iFace) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isWrapperFor(Class iFace) throws SQLException { + throw new UnsupportedOperationException(); + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsDriver.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsDriver.java index 41acb263e..5a9077a3a 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsDriver.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsDriver.java @@ -29,82 +29,80 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; + import org.apache.geaflow.analytics.service.client.utils.JDBCUtils; import org.apache.geaflow.common.exception.GeaflowRuntimeException; -/** - * This class is an adaptation of Presto's com.facebook.presto.jdbc.PrestoDriver. - */ +/** This class is an adaptation of Presto's com.facebook.presto.jdbc.PrestoDriver. */ public class AnalyticsDriver implements Driver { - static final String DRIVER_VERSION; - static final int DRIVER_MAJOR_VERSION; - static final int DRIVER_MINOR_VERSION; + static final String DRIVER_VERSION; + static final int DRIVER_MAJOR_VERSION; + static final int DRIVER_MINOR_VERSION; - static { - String version = nullToEmpty(AnalyticsDriver.class.getPackage().getImplementationVersion()); - Matcher matcher = Pattern.compile("^(\\d+)\\.(\\d+)($|[.-])").matcher(version); - if (!matcher.find()) { - DRIVER_VERSION = "unknown"; - DRIVER_MAJOR_VERSION = 0; - DRIVER_MINOR_VERSION = 0; - } else { - DRIVER_VERSION = version; - DRIVER_MAJOR_VERSION = parseInt(matcher.group(1)); - DRIVER_MINOR_VERSION = parseInt(matcher.group(2)); - } - - try { - DriverManager.registerDriver(new AnalyticsDriver()); - } catch (SQLException e) { - throw new GeaflowRuntimeException("can not register analytics driver", e); - } + static { + String version = nullToEmpty(AnalyticsDriver.class.getPackage().getImplementationVersion()); + Matcher matcher = Pattern.compile("^(\\d+)\\.(\\d+)($|[.-])").matcher(version); + if (!matcher.find()) { + DRIVER_VERSION = "unknown"; + DRIVER_MAJOR_VERSION = 0; + DRIVER_MINOR_VERSION = 0; + } else { + DRIVER_VERSION = version; + DRIVER_MAJOR_VERSION = parseInt(matcher.group(1)); + DRIVER_MINOR_VERSION = parseInt(matcher.group(2)); } - @Override - public Connection connect(String url, Properties properties) { - if (!acceptsURL(url)) { - return null; - } - return AnalyticsConnection.newInstance(url, properties); + try { + DriverManager.registerDriver(new AnalyticsDriver()); + } catch (SQLException e) { + throw new GeaflowRuntimeException("can not register analytics driver", e); } + } - @Override - public boolean acceptsURL(String url) { - return JDBCUtils.acceptsURL(url); + @Override + public Connection connect(String url, Properties properties) { + if (!acceptsURL(url)) { + return null; } + return AnalyticsConnection.newInstance(url, properties); + } - @Override - public DriverPropertyInfo[] getPropertyInfo(String url, Properties info) { - AnalyticsDriverURI analyticsDriverURI = new AnalyticsDriverURI(url, info); - Properties properties = analyticsDriverURI.getProperties(); - ArrayList driverPropertyInfos = new ArrayList<>(); - Set keySets = properties.keySet().stream().map(Object::toString) - .collect(Collectors.toSet()); - for (String key : keySets) { - driverPropertyInfos.add(new DriverPropertyInfo(key, properties.getProperty(key))); - } - return driverPropertyInfos.toArray(new DriverPropertyInfo[0]); - } + @Override + public boolean acceptsURL(String url) { + return JDBCUtils.acceptsURL(url); + } - @Override - public int getMajorVersion() { - return DRIVER_MAJOR_VERSION; + @Override + public DriverPropertyInfo[] getPropertyInfo(String url, Properties info) { + AnalyticsDriverURI analyticsDriverURI = new AnalyticsDriverURI(url, info); + Properties properties = analyticsDriverURI.getProperties(); + ArrayList driverPropertyInfos = new ArrayList<>(); + Set keySets = + properties.keySet().stream().map(Object::toString).collect(Collectors.toSet()); + for (String key : keySets) { + driverPropertyInfos.add(new DriverPropertyInfo(key, properties.getProperty(key))); } + return driverPropertyInfos.toArray(new DriverPropertyInfo[0]); + } - @Override - public int getMinorVersion() { - return DRIVER_MINOR_VERSION; - } + @Override + public int getMajorVersion() { + return DRIVER_MAJOR_VERSION; + } - @Override - public boolean jdbcCompliant() { - return false; - } + @Override + public int getMinorVersion() { + return DRIVER_MINOR_VERSION; + } - @Override - public Logger getParentLogger() { - throw new UnsupportedOperationException(); - } + @Override + public boolean jdbcCompliant() { + return false; + } + @Override + public Logger getParentLogger() { + throw new UnsupportedOperationException(); + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsDriverURI.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsDriverURI.java index 4f4d2c0e2..a080dc4bc 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsDriverURI.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsDriverURI.java @@ -26,10 +26,6 @@ import static org.apache.geaflow.analytics.service.client.jdbc.property.ConnectProperties.SESSION_PROPERTIES; import static org.apache.geaflow.analytics.service.client.utils.JDBCUtils.acceptsURL; -import com.google.common.base.Splitter; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Maps; -import com.google.common.net.HostAndPort; import java.net.URI; import java.net.URISyntaxException; import java.util.HashMap; @@ -37,169 +33,175 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Properties; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Maps; +import com.google.common.net.HostAndPort; + public final class AnalyticsDriverURI { - private static final Logger LOGGER = LoggerFactory.getLogger(AnalyticsDriverURI.class); - private static final Splitter QUERY_SPLITTER = Splitter.on('&').omitEmptyStrings(); - private static final Splitter ARG_SPLITTER = Splitter.on('=').limit(2); - private static final String HTTP_PREFIX = "http"; - private static final String USER = "user"; - private static final String SECURE_HTTP_PREFIX = "https"; - private static final int JDBC_SCHEMA_OFFSET = 5; - private static final int JDBC_SECURE_PORT = 443; - private static final int MAX_PORT_LIMIT = 65535; - private static final int MIN_PORT_LIMIT = 1; - private static final String SYMBOL_SLASH = "/"; - private final URI uri; - private final HostAndPort address; - private final Properties properties; - private final boolean useSecureConnection; - private String graphView; - private final String authority; - - public AnalyticsDriverURI(String url, Properties driverProperties) { - this(parseDriverUrl(url), driverProperties); + private static final Logger LOGGER = LoggerFactory.getLogger(AnalyticsDriverURI.class); + private static final Splitter QUERY_SPLITTER = Splitter.on('&').omitEmptyStrings(); + private static final Splitter ARG_SPLITTER = Splitter.on('=').limit(2); + private static final String HTTP_PREFIX = "http"; + private static final String USER = "user"; + private static final String SECURE_HTTP_PREFIX = "https"; + private static final int JDBC_SCHEMA_OFFSET = 5; + private static final int JDBC_SECURE_PORT = 443; + private static final int MAX_PORT_LIMIT = 65535; + private static final int MIN_PORT_LIMIT = 1; + private static final String SYMBOL_SLASH = "/"; + private final URI uri; + private final HostAndPort address; + private final Properties properties; + private final boolean useSecureConnection; + private String graphView; + private final String authority; + + public AnalyticsDriverURI(String url, Properties driverProperties) { + this(parseDriverUrl(url), driverProperties); + } + + private AnalyticsDriverURI(URI uri, Properties driverProperties) { + this.uri = requireNonNull(uri, "analytics jdbc uri is null"); + this.address = HostAndPort.fromParts(uri.getHost(), uri.getPort()); + this.properties = mergeConnectionProperties(uri, driverProperties); + this.useSecureConnection = uri.getPort() == JDBC_SECURE_PORT; + this.authority = uri.getAuthority(); + initGraphView(); + } + + public String getGraphView() { + return graphView; + } + + public Properties getProperties() { + return properties; + } + + private static URI parseDriverUrl(String url) { + URI jdbcUrl; + try { + jdbcUrl = acceptsURL(url) ? new URI(url.substring(JDBC_SCHEMA_OFFSET)) : null; + } catch (URISyntaxException e) { + String msg = String.format("illegal url: %s", url); + throw new GeaflowRuntimeException(msg); } - private AnalyticsDriverURI(URI uri, Properties driverProperties) { - this.uri = requireNonNull(uri, "analytics jdbc uri is null"); - this.address = HostAndPort.fromParts(uri.getHost(), uri.getPort()); - this.properties = mergeConnectionProperties(uri, driverProperties); - this.useSecureConnection = uri.getPort() == JDBC_SECURE_PORT; - this.authority = uri.getAuthority(); - initGraphView(); + if (isNullOrEmpty(jdbcUrl.getHost())) { + throw new GeaflowRuntimeException("No host specified: " + url); } - public String getGraphView() { - return graphView; + if (jdbcUrl.getPort() == -1) { + throw new GeaflowRuntimeException("No port number specified: " + url); } - - public Properties getProperties() { - return properties; + if ((jdbcUrl.getPort() < MIN_PORT_LIMIT) || (jdbcUrl.getPort() > MAX_PORT_LIMIT)) { + throw new GeaflowRuntimeException("Invalid port number: " + url); } - - private static URI parseDriverUrl(String url) { - URI jdbcUrl; - try { - jdbcUrl = acceptsURL(url) ? new URI(url.substring(JDBC_SCHEMA_OFFSET)) : null; - } catch (URISyntaxException e) { - String msg = String.format("illegal url: %s", url); - throw new GeaflowRuntimeException(msg); - } - - if (isNullOrEmpty(jdbcUrl.getHost())) { - throw new GeaflowRuntimeException("No host specified: " + url); - } - - if (jdbcUrl.getPort() == -1) { - throw new GeaflowRuntimeException("No port number specified: " + url); - } - if ((jdbcUrl.getPort() < MIN_PORT_LIMIT) || (jdbcUrl.getPort() > MAX_PORT_LIMIT)) { - throw new GeaflowRuntimeException("Invalid port number: " + url); - } - return jdbcUrl; + return jdbcUrl; + } + + private static Properties mergeConnectionProperties(URI uri, Properties properties) { + Map urlProperties = parseParameters(uri.getQuery()); + Map suppliedProperties = Maps.fromProperties(properties); + + for (String key : urlProperties.keySet()) { + if (suppliedProperties.containsKey(key)) { + throw new GeaflowRuntimeException( + String.format("Connection property '%s' is both in the URL and an " + "argument", key)); + } } - private static Properties mergeConnectionProperties(URI uri, Properties properties) { - Map urlProperties = parseParameters(uri.getQuery()); - Map suppliedProperties = Maps.fromProperties(properties); + Properties result = new Properties(); + setProperties(result, urlProperties); + setProperties(result, suppliedProperties); + return result; + } - for (String key : urlProperties.keySet()) { - if (suppliedProperties.containsKey(key)) { - throw new GeaflowRuntimeException( - String.format("Connection property '%s' is both in the URL and an " - + "argument", key)); - } - } - - Properties result = new Properties(); - setProperties(result, urlProperties); - setProperties(result, suppliedProperties); - return result; + private void initGraphView() { + String path = uri.getPath(); + if (isNullOrEmpty(uri.getPath()) || SYMBOL_SLASH.equals(path)) { + return; } - private void initGraphView() { - String path = uri.getPath(); - if (isNullOrEmpty(uri.getPath()) || SYMBOL_SLASH.equals(path)) { - return; - } - - if (!path.startsWith(SYMBOL_SLASH)) { - throw new GeaflowRuntimeException("Path does not start with a slash: " + uri); - } - - path = path.substring(1); - LOGGER.info("get server path {}, from url {}", path, uri); - - List parts = Splitter.on(SYMBOL_SLASH).splitToList(path); - if (parts.get(parts.size() - 1).isEmpty()) { - parts = parts.subList(0, parts.size() - 1); - } - if (parts.size() > 2) { - throw new GeaflowRuntimeException("Invalid path segments in URL: " + uri); - } - - if (parts.get(0).isEmpty()) { - throw new GeaflowRuntimeException("Graph view is empty: " + uri); - } - graphView = parts.get(0); + if (!path.startsWith(SYMBOL_SLASH)) { + throw new GeaflowRuntimeException("Path does not start with a slash: " + uri); } - private static void setProperties(Properties properties, Map values) { - for (Entry entry : values.entrySet()) { - properties.setProperty(entry.getKey(), entry.getValue()); - } - } + path = path.substring(1); + LOGGER.info("get server path {}, from url {}", path, uri); - public String getUser() { - String user = properties.getProperty(USER); - if (StringUtils.isEmpty(user)) { - throw new GeaflowRuntimeException("connect user is null"); - } - return user; + List parts = Splitter.on(SYMBOL_SLASH).splitToList(path); + if (parts.get(parts.size() - 1).isEmpty()) { + parts = parts.subList(0, parts.size() - 1); } - - public boolean isCompressionDisabled() { - return false; + if (parts.size() > 2) { + throw new GeaflowRuntimeException("Invalid path segments in URL: " + uri); } - public URI getHttpUri() { - String scheme = useSecureConnection ? SECURE_HTTP_PREFIX : HTTP_PREFIX; - try { - return new URI(scheme, null, address.getHost(), address.getPort(), null, null, null); - } catch (URISyntaxException e) { - throw new GeaflowRuntimeException("get http uri fail", e); - } + if (parts.get(0).isEmpty()) { + throw new GeaflowRuntimeException("Graph view is empty: " + uri); } + graphView = parts.get(0); + } - public Map getSessionProperties() { - return SESSION_PROPERTIES.getValue(properties).orElse(ImmutableMap.of()); + private static void setProperties(Properties properties, Map values) { + for (Entry entry : values.entrySet()) { + properties.setProperty(entry.getKey(), entry.getValue()); } + } - public Map getCustomHeaders() { - return CUSTOM_HEADERS.getValue(properties).orElse(ImmutableMap.of()); + public String getUser() { + String user = properties.getProperty(USER); + if (StringUtils.isEmpty(user)) { + throw new GeaflowRuntimeException("connect user is null"); } - - private static Map parseParameters(String query) { - Map result = new HashMap<>(); - if (query != null) { - Iterable queryArgs = QUERY_SPLITTER.split(query); - for (String queryArg : queryArgs) { - List parts = ARG_SPLITTER.splitToList(queryArg); - if (result.put(parts.get(0), parts.get(1)) != null) { - throw new GeaflowRuntimeException(format("Connection property '%s' is in URL multiple times", parts.get(0))); - } - } + return user; + } + + public boolean isCompressionDisabled() { + return false; + } + + public URI getHttpUri() { + String scheme = useSecureConnection ? SECURE_HTTP_PREFIX : HTTP_PREFIX; + try { + return new URI(scheme, null, address.getHost(), address.getPort(), null, null, null); + } catch (URISyntaxException e) { + throw new GeaflowRuntimeException("get http uri fail", e); + } + } + + public Map getSessionProperties() { + return SESSION_PROPERTIES.getValue(properties).orElse(ImmutableMap.of()); + } + + public Map getCustomHeaders() { + return CUSTOM_HEADERS.getValue(properties).orElse(ImmutableMap.of()); + } + + private static Map parseParameters(String query) { + Map result = new HashMap<>(); + if (query != null) { + Iterable queryArgs = QUERY_SPLITTER.split(query); + for (String queryArg : queryArgs) { + List parts = ARG_SPLITTER.splitToList(queryArg); + if (result.put(parts.get(0), parts.get(1)) != null) { + throw new GeaflowRuntimeException( + format("Connection property '%s' is in URL multiple times", parts.get(0))); } - return result; + } } + return result; + } - public String getAuthority() { - return authority; - } + public String getAuthority() { + return authority; + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsResultMetaData.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsResultMetaData.java index 106347b03..08388b729 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsResultMetaData.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsResultMetaData.java @@ -27,151 +27,155 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.geaflow.common.type.IType; public class AnalyticsResultMetaData implements ResultSetMetaData { - private final RelDataType resultType; - private final List fields; - private final Map> tableFieldName2Type; - - public AnalyticsResultMetaData(RelDataType resultType) { - this.resultType = resultType; - this.fields = resultType != null ? resultType.getFieldList() : new ArrayList<>(); - this.tableFieldName2Type = fields.stream().collect(Collectors.toMap(RelDataTypeField::getName, - relDataTypeField -> convertType(relDataTypeField.getType()))); - } - - public Map> getTableFieldName2Type() { - return tableFieldName2Type; - } - - public List getFields() { - return fields; - } - - - @Override - public int getColumnCount() throws SQLException { - return this.resultType.getFieldCount(); - } - - @Override - public boolean isAutoIncrement(int column) throws SQLException { - return false; - } - - @Override - public boolean isCaseSensitive(int column) throws SQLException { - return false; - } - - @Override - public boolean isSearchable(int column) throws SQLException { - return true; - } - - @Override - public boolean isCurrency(int column) throws SQLException { - return false; - } - - @Override - public int isNullable(int column) throws SQLException { - return ResultSetMetaData.columnNoNulls; - } - - @Override - public boolean isSigned(int column) throws SQLException { - return false; - } - - @Override - public int getColumnDisplaySize(int column) throws SQLException { - return 0; - } - - @Override - public String getColumnLabel(int column) throws SQLException { - return null; - } - - @Override - public String getColumnName(int column) throws SQLException { - return getRelDataTypeField(column).getName(); - } - - @Override - public String getSchemaName(int column) throws SQLException { - return null; - } - - @Override - public int getPrecision(int column) throws SQLException { - return 0; - } - - @Override - public int getScale(int column) throws SQLException { - return 0; - } - - @Override - public String getTableName(int column) throws SQLException { - return null; - } - - @Override - public String getCatalogName(int column) throws SQLException { - return null; - } - - @Override - public int getColumnType(int column) throws SQLException { - return 0; - } - - @Override - public String getColumnTypeName(int column) throws SQLException { - return null; - } - - @Override - public boolean isReadOnly(int column) throws SQLException { - return false; - } - - @Override - public boolean isWritable(int column) throws SQLException { - return false; - } - - @Override - public boolean isDefinitelyWritable(int column) throws SQLException { - return false; - } - - @Override - public String getColumnClassName(int column) throws SQLException { - return null; - } - - @Override - public T unwrap(Class iface) throws SQLException { - return null; - } - - @Override - public boolean isWrapperFor(Class iface) throws SQLException { - return false; - } - - private RelDataTypeField getRelDataTypeField(int column) throws SQLException { - if ((column <= 0) || (column > this.fields.size())) { - throw new SQLException("Invalid column index: " + column); - } - return this.fields.get(column); - } + private final RelDataType resultType; + private final List fields; + private final Map> tableFieldName2Type; + + public AnalyticsResultMetaData(RelDataType resultType) { + this.resultType = resultType; + this.fields = resultType != null ? resultType.getFieldList() : new ArrayList<>(); + this.tableFieldName2Type = + fields.stream() + .collect( + Collectors.toMap( + RelDataTypeField::getName, + relDataTypeField -> convertType(relDataTypeField.getType()))); + } + + public Map> getTableFieldName2Type() { + return tableFieldName2Type; + } + + public List getFields() { + return fields; + } + + @Override + public int getColumnCount() throws SQLException { + return this.resultType.getFieldCount(); + } + + @Override + public boolean isAutoIncrement(int column) throws SQLException { + return false; + } + + @Override + public boolean isCaseSensitive(int column) throws SQLException { + return false; + } + + @Override + public boolean isSearchable(int column) throws SQLException { + return true; + } + + @Override + public boolean isCurrency(int column) throws SQLException { + return false; + } + + @Override + public int isNullable(int column) throws SQLException { + return ResultSetMetaData.columnNoNulls; + } + + @Override + public boolean isSigned(int column) throws SQLException { + return false; + } + + @Override + public int getColumnDisplaySize(int column) throws SQLException { + return 0; + } + + @Override + public String getColumnLabel(int column) throws SQLException { + return null; + } + + @Override + public String getColumnName(int column) throws SQLException { + return getRelDataTypeField(column).getName(); + } + + @Override + public String getSchemaName(int column) throws SQLException { + return null; + } + + @Override + public int getPrecision(int column) throws SQLException { + return 0; + } + + @Override + public int getScale(int column) throws SQLException { + return 0; + } + + @Override + public String getTableName(int column) throws SQLException { + return null; + } + + @Override + public String getCatalogName(int column) throws SQLException { + return null; + } + + @Override + public int getColumnType(int column) throws SQLException { + return 0; + } + + @Override + public String getColumnTypeName(int column) throws SQLException { + return null; + } + + @Override + public boolean isReadOnly(int column) throws SQLException { + return false; + } + + @Override + public boolean isWritable(int column) throws SQLException { + return false; + } + + @Override + public boolean isDefinitelyWritable(int column) throws SQLException { + return false; + } + + @Override + public String getColumnClassName(int column) throws SQLException { + return null; + } + + @Override + public T unwrap(Class iface) throws SQLException { + return null; + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + return false; + } + + private RelDataTypeField getRelDataTypeField(int column) throws SQLException { + if ((column <= 0) || (column > this.fields.size())) { + throw new SQLException("Invalid column index: " + column); + } + return this.fields.get(column); + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsResultSet.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsResultSet.java index 0a3897ff3..82d45d97d 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsResultSet.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsResultSet.java @@ -23,9 +23,6 @@ import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; -import com.google.common.base.Preconditions; -import com.google.common.collect.AbstractIterator; -import com.google.common.collect.ImmutableMap; import java.io.InputStream; import java.io.Reader; import java.math.BigDecimal; @@ -53,6 +50,7 @@ import java.util.Map.Entry; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.geaflow.analytics.service.query.IQueryStatus; @@ -73,1243 +71,1259 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class AnalyticsResultSet implements ResultSet { - - private static final Logger LOGGER = LoggerFactory.getLogger(AnalyticsResultSet.class); - - private final QueryResults queryResults; - private final Statement statement; - private final Iterator> dataResults; - private final AtomicReference> currentResult = new AtomicReference<>(); - private final AtomicBoolean closed = new AtomicBoolean(); - private final AtomicBoolean wasNull = new AtomicBoolean(); - private String queryId; - private final AnalyticsResultMetaData resultSetMetaData; - private final Map fieldLabel2Index; - private final Map fieldIndex2Label; - - public AnalyticsResultSet(AnalyticsStatement statement, QueryResults result) { - requireNonNull(result, "query result is null"); - this.statement = statement; - this.resultSetMetaData = new AnalyticsResultMetaData(result.getResultMeta()); - this.queryId = result.getQueryId(); - this.queryResults = result; - this.fieldLabel2Index = getFieldLabel2IndexMap(result.getResultMeta()); - this.fieldIndex2Label = getFieldIndex2LabelMap(result.getResultMeta()); - List> iteratorData = queryResults.getRawData(); - this.dataResults = flatten(new PageIterator(iteratorData), Long.MAX_VALUE); - } - - public static GeaflowRuntimeException resultsException(IQueryStatus result) { - QueryError error = requireNonNull(result.getError()); - String message = format("QueryId failed (#%s): %s", result.getQueryId(), error.getName()); - return new GeaflowRuntimeException(message); - } +import com.google.common.base.Preconditions; +import com.google.common.collect.AbstractIterator; +import com.google.common.collect.ImmutableMap; - public String getQueryId() { - return this.queryId; - } +public class AnalyticsResultSet implements ResultSet { - @Override - public boolean next() throws SQLException { - checkResultSetOpen(); - try { - if (!dataResults.hasNext()) { - currentResult.set(null); - return false; - } - currentResult.set(dataResults.next()); - return true; - } catch (RuntimeException e) { - if (e.getCause() instanceof SQLException) { - throw new GeaflowRuntimeException(e); - } - throw new GeaflowRuntimeException("fetching results failed", e); + private static final Logger LOGGER = LoggerFactory.getLogger(AnalyticsResultSet.class); + + private final QueryResults queryResults; + private final Statement statement; + private final Iterator> dataResults; + private final AtomicReference> currentResult = new AtomicReference<>(); + private final AtomicBoolean closed = new AtomicBoolean(); + private final AtomicBoolean wasNull = new AtomicBoolean(); + private String queryId; + private final AnalyticsResultMetaData resultSetMetaData; + private final Map fieldLabel2Index; + private final Map fieldIndex2Label; + + public AnalyticsResultSet(AnalyticsStatement statement, QueryResults result) { + requireNonNull(result, "query result is null"); + this.statement = statement; + this.resultSetMetaData = new AnalyticsResultMetaData(result.getResultMeta()); + this.queryId = result.getQueryId(); + this.queryResults = result; + this.fieldLabel2Index = getFieldLabel2IndexMap(result.getResultMeta()); + this.fieldIndex2Label = getFieldIndex2LabelMap(result.getResultMeta()); + List> iteratorData = queryResults.getRawData(); + this.dataResults = flatten(new PageIterator(iteratorData), Long.MAX_VALUE); + } + + public static GeaflowRuntimeException resultsException(IQueryStatus result) { + QueryError error = requireNonNull(result.getError()); + String message = format("QueryId failed (#%s): %s", result.getQueryId(), error.getName()); + return new GeaflowRuntimeException(message); + } + + public String getQueryId() { + return this.queryId; + } + + @Override + public boolean next() throws SQLException { + checkResultSetOpen(); + try { + if (!dataResults.hasNext()) { + currentResult.set(null); + return false; + } + currentResult.set(dataResults.next()); + return true; + } catch (RuntimeException e) { + if (e.getCause() instanceof SQLException) { + throw new GeaflowRuntimeException(e); + } + throw new GeaflowRuntimeException("fetching results failed", e); + } + } + + @Override + public void close() throws SQLException { + closed.set(true); + statement.close(); + } + + @Override + public boolean wasNull() throws SQLException { + return wasNull.get(); + } + + @Override + public String getString(int index) throws SQLException { + Object value = getFiledByIndex(index); + return (value != null) ? value.toString() : null; + } + + @Override + public boolean getBoolean(int index) throws SQLException { + Object value = getFiledByIndex(index); + return (value != null) ? (Boolean) value : false; + } + + @Override + public byte getByte(int index) throws SQLException { + return convertNumber(getFiledByIndex(index)).byteValue(); + } + + @Override + public short getShort(int index) throws SQLException { + return convertNumber(getFiledByIndex(index)).shortValue(); + } + + @Override + public int getInt(int index) throws SQLException { + return convertNumber(getFiledByIndex(index)).intValue(); + } + + @Override + public long getLong(int index) throws SQLException { + return convertNumber(getFiledByIndex(index)).longValue(); + } + + @Override + public float getFloat(int index) throws SQLException { + return convertNumber(getFiledByIndex(index)).floatValue(); + } + + @Override + public double getDouble(int index) throws SQLException { + return convertNumber(getFiledByIndex(index)).doubleValue(); + } + + @Override + public BigDecimal getBigDecimal(int index, int scale) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public byte[] getBytes(int index) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Date getDate(int index) throws SQLException { + throw new UnsupportedOperationException(); + } + + public IVertex> getVertex(int index) throws SQLException { + Object field = getObject(index); + Preconditions.checkState( + field instanceof RowVertex, + "field index " + + index + + " must be " + + "vertex, but get type is " + + field.getClass().getCanonicalName()); + RowVertex rowVertex = (RowVertex) field; + String fieldLabel = getFieldLabelByIndex(index); + Map properties = parseVertexValue(rowVertex, fieldLabel); + return new ValueLabelVertex<>(rowVertex.getId(), properties, rowVertex.getLabel()); + } + + public IVertex> getVertex(String label) throws SQLException { + return getVertex(getFieldIndexByLabel(label)); + } + + public IEdge> getEdge(int index) throws SQLException { + Object field = getObject(index); + Preconditions.checkState( + field instanceof RowEdge, + "field index " + + index + + " must be " + + "edge, but get type is " + + field.getClass().getCanonicalName()); + RowEdge rowEdge = (RowEdge) field; + String fieldLabel = getFieldLabelByIndex(index); + Map properties = parseEdgeValue(rowEdge, fieldLabel); + return new ValueLabelEdge<>( + rowEdge.getSrcId(), + rowEdge.getTargetId(), + properties, + rowEdge.getDirect(), + rowEdge.getLabel()); + } + + public IEdge> getEdge(String label) throws SQLException { + return getEdge(getFieldIndexByLabel(label)); + } + + @Override + public Time getTime(int index) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Timestamp getTimestamp(int index) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public InputStream getAsciiStream(int index) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public InputStream getUnicodeStream(int index) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public InputStream getBinaryStream(int index) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public String getString(String columnLabel) throws SQLException { + Object value = getFiledByLabel(columnLabel); + return (value != null) ? value.toString() : null; + } + + @Override + public boolean getBoolean(String columnLabel) throws SQLException { + Object value = getFiledByLabel(columnLabel); + return (value != null) ? (Boolean) value : false; + } + + @Override + public byte getByte(String columnLabel) throws SQLException { + return convertNumber(getFiledByLabel(columnLabel)).byteValue(); + } + + @Override + public short getShort(String columnLabel) throws SQLException { + return convertNumber(getFiledByLabel(columnLabel)).shortValue(); + } + + @Override + public int getInt(String columnLabel) throws SQLException { + return convertNumber(getFiledByLabel(columnLabel)).intValue(); + } + + @Override + public long getLong(String columnLabel) throws SQLException { + return convertNumber(getFiledByLabel(columnLabel)).longValue(); + } + + @Override + public float getFloat(String columnLabel) throws SQLException { + return convertNumber(getFiledByLabel(columnLabel)).floatValue(); + } + + @Override + public double getDouble(String columnLabel) throws SQLException { + return convertNumber(getFiledByLabel(columnLabel)).doubleValue(); + } + + @Override + public BigDecimal getBigDecimal(String columnLabel, int scale) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public byte[] getBytes(String columnLabel) throws SQLException { + return (byte[]) getFiledByLabel(columnLabel); + } + + @Override + public Date getDate(String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Time getTime(String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Timestamp getTimestamp(String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public InputStream getAsciiStream(String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public InputStream getUnicodeStream(String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public InputStream getBinaryStream(String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public SQLWarning getWarnings() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void clearWarnings() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public String getCursorName() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public ResultSetMetaData getMetaData() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Object getObject(int columnIndex) throws SQLException { + return getFiledByIndex(columnIndex); + } + + @Override + public Object getObject(String columnLabel) throws SQLException { + return getObject(getFieldIndexByLabel(columnLabel)); + } + + @Override + public int findColumn(String columnLabel) throws SQLException { + checkResultSetOpen(); + return getFieldIndexByLabel(columnLabel); + } + + @Override + public Reader getCharacterStream(int index) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Reader getCharacterStream(String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public BigDecimal getBigDecimal(int index) throws SQLException { + Object value = getFiledByIndex(index); + if (value == null) { + return null; + } + return new BigDecimal(String.valueOf(value)); + } + + @Override + public BigDecimal getBigDecimal(String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isBeforeFirst() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isAfterLast() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isFirst() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isLast() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void beforeFirst() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void afterLast() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean first() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean last() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getRow() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean absolute(int row) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean relative(int rows) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean previous() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void setFetchDirection(int direction) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getFetchDirection() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void setFetchSize(int rows) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getFetchSize() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getType() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getConcurrency() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean rowUpdated() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean rowInserted() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean rowDeleted() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNull(int index) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBoolean(int index, boolean x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateByte(int index, byte x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateShort(int index, short x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateInt(int index, int x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateLong(int index, long x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateFloat(int index, float x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateDouble(int index, double x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBigDecimal(int index, BigDecimal x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateString(int index, String x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBytes(int index, byte[] x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateDate(int index, Date x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateTime(int index, Time x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateTimestamp(int index, Timestamp x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(int index, InputStream x, int length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(int index, InputStream x, int length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(int index, Reader x, int length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateObject(int index, Object x, int scaleOrLength) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateObject(int index, Object x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNull(String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBoolean(String columnLabel, boolean x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateByte(String columnLabel, byte x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateShort(String columnLabel, short x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateInt(String columnLabel, int x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateLong(String columnLabel, long x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateFloat(String columnLabel, float x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateDouble(String columnLabel, double x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBigDecimal(String columnLabel, BigDecimal x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateString(String columnLabel, String x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBytes(String columnLabel, byte[] x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateDate(String columnLabel, Date x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateTime(String columnLabel, Time x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateTimestamp(String columnLabel, Timestamp x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(String columnLabel, InputStream x, int length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(String columnLabel, InputStream x, int length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(String columnLabel, Reader reader, int length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateObject(String columnLabel, Object x, int scaleOrLength) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateObject(String columnLabel, Object x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void insertRow() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateRow() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void deleteRow() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void refreshRow() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void cancelRowUpdates() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void moveToInsertRow() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void moveToCurrentRow() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Statement getStatement() throws SQLException { + return this.statement; + } + + @Override + public Object getObject(int index, Map> map) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Ref getRef(int index) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Blob getBlob(int index) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Clob getClob(int index) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Array getArray(int index) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Object getObject(String columnLabel, Map> map) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Ref getRef(String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Blob getBlob(String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Clob getClob(String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Array getArray(String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Date getDate(int index, Calendar cal) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Date getDate(String columnLabel, Calendar cal) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Time getTime(int index, Calendar cal) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Time getTime(String columnLabel, Calendar cal) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Timestamp getTimestamp(int index, Calendar cal) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Timestamp getTimestamp(String columnLabel, Calendar cal) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public URL getURL(int index) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public URL getURL(String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateRef(int index, Ref x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateRef(String columnLabel, Ref x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBlob(int index, Blob x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBlob(String columnLabel, Blob x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateClob(int index, Clob x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateClob(String columnLabel, Clob x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateArray(int index, Array x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateArray(String columnLabel, Array x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public RowId getRowId(int index) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public RowId getRowId(String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateRowId(int index, RowId x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateRowId(String columnLabel, RowId x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getHoldability() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isClosed() throws SQLException { + return closed.get(); + } + + @Override + public void updateNString(int index, String nString) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNString(String columnLabel, String nString) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNClob(int index, NClob nClob) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNClob(String columnLabel, NClob nClob) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public NClob getNClob(int index) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public NClob getNClob(String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public SQLXML getSQLXML(int index) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public SQLXML getSQLXML(String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateSQLXML(int index, SQLXML xmlObject) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateSQLXML(String columnLabel, SQLXML xmlObject) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public String getNString(int index) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public String getNString(String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Reader getNCharacterStream(int index) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Reader getNCharacterStream(String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNCharacterStream(int index, Reader x, long length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNCharacterStream(String columnLabel, Reader reader, long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(int index, InputStream x, long length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(int index, InputStream x, long length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(int index, Reader x, long length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(String columnLabel, InputStream x, long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(String columnLabel, InputStream x, long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(String columnLabel, Reader reader, long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBlob(int index, InputStream inputStream, long length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBlob(String columnLabel, InputStream inputStream, long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateClob(int index, Reader reader, long length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateClob(String columnLabel, Reader reader, long length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNClob(int index, Reader reader, long length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNClob(String columnLabel, Reader reader, long length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNCharacterStream(int index, Reader x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNCharacterStream(String columnLabel, Reader reader) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(int index, InputStream x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(int index, InputStream x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(int index, Reader x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(String columnLabel, InputStream x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(String columnLabel, InputStream x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(String columnLabel, Reader reader) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBlob(int index, InputStream inputStream) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBlob(String columnLabel, InputStream inputStream) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateClob(int index, Reader reader) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateClob(String columnLabel, Reader reader) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNClob(int index, Reader reader) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNClob(String columnLabel, Reader reader) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public T getObject(int index, Class type) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public T getObject(String columnLabel, Class type) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public T unwrap(Class iFace) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isWrapperFor(Class iFace) throws SQLException { + throw new UnsupportedOperationException(); + } + + private void checkResultSetOpen() throws SQLException { + if (isClosed()) { + throw new SQLException("Analytics result set is closed"); + } + } + + private int getFieldIndexByLabel(String label) throws SQLException { + if (label == null) { + throw new SQLException("field label is null"); + } + Integer index = fieldLabel2Index.get(label.toLowerCase(ENGLISH)); + if (index == null) { + throw new SQLException("invalid field label: " + label); + } + return index; + } + + private String getFieldLabelByIndex(int index) throws SQLException { + if (index <= 0) { + throw new SQLException( + String.format("field index must be positive,but current index is %d", index)); + } + String label = fieldIndex2Label.get(index); + if (label == null) { + throw new SQLException(String.format("label is not exist, index is %d", index)); + } + return label; + } + + private static Number convertNumber(Object value) throws SQLException { + if (value == null) { + return 0; + } + if (value instanceof Number) { + return (Number) value; + } + if (value instanceof Boolean) { + return ((Boolean) value) ? 1 : 0; + } + throw new SQLException( + String.format("value %s is not number", value.getClass().getCanonicalName())); + } + + private Object getFiledByIndex(int index) throws SQLException { + checkResultSetOpen(); + checkValidResult(); + if ((index <= 0) || (index > resultSetMetaData.getColumnCount())) { + throw new SQLException("Invalid column index: " + index); + } + Object value = currentResult.get().get(index - 1); + wasNull.set(value == null); + return value; + } + + private Object getFiledByLabel(String label) throws SQLException { + checkResultSetOpen(); + checkValidResult(); + Object value = currentResult.get().get(getFieldIndexByLabel(label) - 1); + wasNull.set(value == null); + return value; + } + + private static Map getFieldLabel2IndexMap(RelDataType relDataType) { + Map map = new HashMap<>(); + List fieldList = relDataType.getFieldList(); + for (int i = 0; i < fieldList.size(); i++) { + String name = fieldList.get(i).getName().toLowerCase(ENGLISH); + if (!map.containsKey(name)) { + map.put(name, i + 1); + } + } + return ImmutableMap.copyOf(map); + } + + private Map parseVertexValue(RowVertex rowVertex, String fieldLabel) { + Row vertexValue = rowVertex.getValue(); + Map properties = new HashMap<>(); + if (vertexValue != null) { + Map> tableFieldName2Type = resultSetMetaData.getTableFieldName2Type(); + for (Entry> entry : tableFieldName2Type.entrySet()) { + if (fieldLabel.equalsIgnoreCase(entry.getKey())) { + VertexType vertexType = (VertexType) entry.getValue(); + List valueFields = vertexType.getValueFields(); + for (int i = 0; i < valueFields.size(); i++) { + TableField tableField = valueFields.get(i); + Object valueField = vertexValue.getField(i, tableField.getType()); + properties.put(tableField.getName(), valueField); + } } - } - - @Override - public void close() throws SQLException { - closed.set(true); - statement.close(); - } - - @Override - public boolean wasNull() throws SQLException { - return wasNull.get(); - } - - @Override - public String getString(int index) throws SQLException { - Object value = getFiledByIndex(index); - return (value != null) ? value.toString() : null; - } - - @Override - public boolean getBoolean(int index) throws SQLException { - Object value = getFiledByIndex(index); - return (value != null) ? (Boolean) value : false; - } - - @Override - public byte getByte(int index) throws SQLException { - return convertNumber(getFiledByIndex(index)).byteValue(); - } - - @Override - public short getShort(int index) throws SQLException { - return convertNumber(getFiledByIndex(index)).shortValue(); - } - - @Override - public int getInt(int index) throws SQLException { - return convertNumber(getFiledByIndex(index)).intValue(); - } - - @Override - public long getLong(int index) throws SQLException { - return convertNumber(getFiledByIndex(index)).longValue(); - } - - @Override - public float getFloat(int index) throws SQLException { - return convertNumber(getFiledByIndex(index)).floatValue(); - } - - @Override - public double getDouble(int index) throws SQLException { - return convertNumber(getFiledByIndex(index)).doubleValue(); - } - - @Override - public BigDecimal getBigDecimal(int index, int scale) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public byte[] getBytes(int index) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(int index) throws SQLException { - throw new UnsupportedOperationException(); - } - - public IVertex> getVertex(int index) throws SQLException { - Object field = getObject(index); - Preconditions.checkState(field instanceof RowVertex, "field index " + index + " must be " - + "vertex, but get type is " + field.getClass().getCanonicalName()); - RowVertex rowVertex = (RowVertex) field; - String fieldLabel = getFieldLabelByIndex(index); - Map properties = parseVertexValue(rowVertex, fieldLabel); - return new ValueLabelVertex<>(rowVertex.getId(), properties, rowVertex.getLabel()); - } - - public IVertex> getVertex(String label) throws SQLException { - return getVertex(getFieldIndexByLabel(label)); - } - - public IEdge> getEdge(int index) throws SQLException { - Object field = getObject(index); - Preconditions.checkState(field instanceof RowEdge, "field index " + index + " must be " - + "edge, but get type is " + field.getClass().getCanonicalName()); - RowEdge rowEdge = (RowEdge) field; - String fieldLabel = getFieldLabelByIndex(index); - Map properties = parseEdgeValue(rowEdge, fieldLabel); - return new ValueLabelEdge<>(rowEdge.getSrcId(), rowEdge.getTargetId(), properties, rowEdge.getDirect(), rowEdge.getLabel()); - } - - public IEdge> getEdge(String label) throws SQLException { - return getEdge(getFieldIndexByLabel(label)); - } - - @Override - public Time getTime(int index) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Timestamp getTimestamp(int index) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public InputStream getAsciiStream(int index) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public InputStream getUnicodeStream(int index) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public InputStream getBinaryStream(int index) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String getString(String columnLabel) throws SQLException { - Object value = getFiledByLabel(columnLabel); - return (value != null) ? value.toString() : null; - } - - @Override - public boolean getBoolean(String columnLabel) throws SQLException { - Object value = getFiledByLabel(columnLabel); - return (value != null) ? (Boolean) value : false; - } - - @Override - public byte getByte(String columnLabel) throws SQLException { - return convertNumber(getFiledByLabel(columnLabel)).byteValue(); - } - - @Override - public short getShort(String columnLabel) throws SQLException { - return convertNumber(getFiledByLabel(columnLabel)).shortValue(); - } - - @Override - public int getInt(String columnLabel) throws SQLException { - return convertNumber(getFiledByLabel(columnLabel)).intValue(); - } - - @Override - public long getLong(String columnLabel) throws SQLException { - return convertNumber(getFiledByLabel(columnLabel)).longValue(); - } - - @Override - public float getFloat(String columnLabel) throws SQLException { - return convertNumber(getFiledByLabel(columnLabel)).floatValue(); - } - - @Override - public double getDouble(String columnLabel) throws SQLException { - return convertNumber(getFiledByLabel(columnLabel)).doubleValue(); - } - - @Override - public BigDecimal getBigDecimal(String columnLabel, int scale) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public byte[] getBytes(String columnLabel) throws SQLException { - return (byte[]) getFiledByLabel(columnLabel); - } - - @Override - public Date getDate(String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Time getTime(String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Timestamp getTimestamp(String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public InputStream getAsciiStream(String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public InputStream getUnicodeStream(String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public InputStream getBinaryStream(String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public SQLWarning getWarnings() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void clearWarnings() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String getCursorName() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public ResultSetMetaData getMetaData() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Object getObject(int columnIndex) throws SQLException { - return getFiledByIndex(columnIndex); - } - - @Override - public Object getObject(String columnLabel) throws SQLException { - return getObject(getFieldIndexByLabel(columnLabel)); - } - - - @Override - public int findColumn(String columnLabel) throws SQLException { - checkResultSetOpen(); - return getFieldIndexByLabel(columnLabel); - } - - @Override - public Reader getCharacterStream(int index) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Reader getCharacterStream(String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public BigDecimal getBigDecimal(int index) throws SQLException { - Object value = getFiledByIndex(index); - if (value == null) { - return null; + } + } + return properties; + } + + private Map parseEdgeValue(RowEdge rowEdge, String fieldLabel) { + Row edgeValue = rowEdge.getValue(); + Map properties = new HashMap<>(); + if (edgeValue != null) { + Map> tableFieldName2Type = resultSetMetaData.getTableFieldName2Type(); + for (Entry> entry : tableFieldName2Type.entrySet()) { + if (fieldLabel.equalsIgnoreCase(entry.getKey())) { + EdgeType edgeType = (EdgeType) entry.getValue(); + List valueFields = edgeType.getValueFields(); + for (int i = 0; i < valueFields.size(); i++) { + TableField tableField = valueFields.get(i); + Object valueField = edgeValue.getField(i, tableField.getType()); + properties.put(tableField.getName(), valueField); + } } - return new BigDecimal(String.valueOf(value)); - } - - @Override - public BigDecimal getBigDecimal(String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isBeforeFirst() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isAfterLast() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isFirst() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isLast() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void beforeFirst() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void afterLast() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean first() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean last() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean absolute(int row) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean relative(int rows) throws SQLException { - throw new UnsupportedOperationException(); + } } + return properties; + } - @Override - public boolean previous() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void setFetchDirection(int direction) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getFetchDirection() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void setFetchSize(int rows) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getFetchSize() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getType() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getConcurrency() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean rowUpdated() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean rowInserted() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean rowDeleted() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNull(int index) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBoolean(int index, boolean x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateByte(int index, byte x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateShort(int index, short x) throws SQLException { - throw new UnsupportedOperationException(); + private static Map getFieldIndex2LabelMap(RelDataType relDataType) { + Map map = new HashMap<>(); + List fieldList = relDataType.getFieldList(); + for (int i = 0; i < fieldList.size(); i++) { + String name = fieldList.get(i).getName().toLowerCase(ENGLISH); + if (!map.containsValue(name)) { + map.put(i + 1, name); + } } + return ImmutableMap.copyOf(map); + } - @Override - public void updateInt(int index, int x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateLong(int index, long x) throws SQLException { - throw new UnsupportedOperationException(); + private void checkValidResult() throws SQLException { + if (currentResult.get() == null) { + throw new SQLException("Not on a valid row"); } + } - @Override - public void updateFloat(int index, float x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateDouble(int index, double x) throws SQLException { - throw new UnsupportedOperationException(); - } + private static Iterator flatten(Iterator rowsIterator, long maxRows) { + return (maxRows > 0) ? new RowLimitedIterator<>(rowsIterator, maxRows) : rowsIterator; + } - @Override - public void updateBigDecimal(int index, BigDecimal x) throws SQLException { - throw new UnsupportedOperationException(); - } + private static class RowLimitedIterator implements Iterator { - @Override - public void updateString(int index, String x) throws SQLException { - throw new UnsupportedOperationException(); - } + private final Iterator iterator; - @Override - public void updateBytes(int index, byte[] x) throws SQLException { - throw new UnsupportedOperationException(); - } + private final long maxRows; - @Override - public void updateDate(int index, Date x) throws SQLException { - throw new UnsupportedOperationException(); - } + private long cursor; - @Override - public void updateTime(int index, Time x) throws SQLException { - throw new UnsupportedOperationException(); + public RowLimitedIterator(Iterator iterator, long maxRows) { + Preconditions.checkState(maxRows >= 0, "max rows is negative"); + this.iterator = iterator; + this.maxRows = maxRows; } @Override - public void updateTimestamp(int index, Timestamp x) throws SQLException { - throw new UnsupportedOperationException(); + public boolean hasNext() { + return cursor < maxRows && iterator.hasNext(); } @Override - public void updateAsciiStream(int index, InputStream x, int length) throws SQLException { - throw new UnsupportedOperationException(); + public T next() { + if (!hasNext()) { + throw new GeaflowRuntimeException("no such element"); + } + cursor++; + return iterator.next(); } + } - @Override - public void updateBinaryStream(int index, InputStream x, int length) throws SQLException { - throw new UnsupportedOperationException(); - } + private static class PageIterator extends AbstractIterator> { - @Override - public void updateCharacterStream(int index, Reader x, int length) throws SQLException { - throw new UnsupportedOperationException(); - } + private final List> queryData; + private int cursor; - @Override - public void updateObject(int index, Object x, int scaleOrLength) throws SQLException { - throw new UnsupportedOperationException(); + private PageIterator(List> queryData) { + this.queryData = queryData; + this.cursor = 0; } @Override - public void updateObject(int index, Object x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNull(String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBoolean(String columnLabel, boolean x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateByte(String columnLabel, byte x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateShort(String columnLabel, short x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateInt(String columnLabel, int x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateLong(String columnLabel, long x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateFloat(String columnLabel, float x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateDouble(String columnLabel, double x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBigDecimal(String columnLabel, BigDecimal x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateString(String columnLabel, String x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBytes(String columnLabel, byte[] x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateDate(String columnLabel, Date x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateTime(String columnLabel, Time x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateTimestamp(String columnLabel, Timestamp x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(String columnLabel, InputStream x, int length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(String columnLabel, InputStream x, int length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(String columnLabel, Reader reader, int length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateObject(String columnLabel, Object x, int scaleOrLength) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateObject(String columnLabel, Object x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void insertRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void deleteRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void refreshRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void cancelRowUpdates() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void moveToInsertRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void moveToCurrentRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Statement getStatement() throws SQLException { - return this.statement; - } - - @Override - public Object getObject(int index, Map> map) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Ref getRef(int index) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Blob getBlob(int index) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Clob getClob(int index) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Array getArray(int index) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Object getObject(String columnLabel, Map> map) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Ref getRef(String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Blob getBlob(String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Clob getClob(String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Array getArray(String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(int index, Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(String columnLabel, Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Time getTime(int index, Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Time getTime(String columnLabel, Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Timestamp getTimestamp(int index, Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Timestamp getTimestamp(String columnLabel, Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public URL getURL(int index) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public URL getURL(String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateRef(int index, Ref x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateRef(String columnLabel, Ref x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(int index, Blob x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(String columnLabel, Blob x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(int index, Clob x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(String columnLabel, Clob x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateArray(int index, Array x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateArray(String columnLabel, Array x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public RowId getRowId(int index) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public RowId getRowId(String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateRowId(int index, RowId x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateRowId(String columnLabel, RowId x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getHoldability() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isClosed() throws SQLException { - return closed.get(); - } - - @Override - public void updateNString(int index, String nString) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNString(String columnLabel, String nString) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNClob(int index, NClob nClob) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNClob(String columnLabel, NClob nClob) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public NClob getNClob(int index) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public NClob getNClob(String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public SQLXML getSQLXML(int index) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public SQLXML getSQLXML(String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateSQLXML(int index, SQLXML xmlObject) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateSQLXML(String columnLabel, SQLXML xmlObject) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String getNString(int index) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String getNString(String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Reader getNCharacterStream(int index) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Reader getNCharacterStream(String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNCharacterStream(int index, Reader x, long length) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNCharacterStream(String columnLabel, Reader reader, long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(int index, InputStream x, long length) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(int index, InputStream x, long length) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(int index, Reader x, long length) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(String columnLabel, InputStream x, long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(String columnLabel, InputStream x, long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(String columnLabel, Reader reader, long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(int index, InputStream inputStream, long length) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(String columnLabel, InputStream inputStream, long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(int index, Reader reader, long length) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(String columnLabel, Reader reader, long length) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNClob(int index, Reader reader, long length) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNClob(String columnLabel, Reader reader, long length) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNCharacterStream(int index, Reader x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNCharacterStream(String columnLabel, Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(int index, InputStream x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(int index, InputStream x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(int index, Reader x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(String columnLabel, InputStream x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(String columnLabel, InputStream x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(String columnLabel, Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(int index, InputStream inputStream) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(String columnLabel, InputStream inputStream) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(int index, Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(String columnLabel, Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNClob(int index, Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNClob(String columnLabel, Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public T getObject(int index, Class type) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public T getObject(String columnLabel, Class type) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public T unwrap(Class iFace) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isWrapperFor(Class iFace) throws SQLException { - throw new UnsupportedOperationException(); - } - - private void checkResultSetOpen() throws SQLException { - if (isClosed()) { - throw new SQLException("Analytics result set is closed"); - } - } - - private int getFieldIndexByLabel(String label) throws SQLException { - if (label == null) { - throw new SQLException("field label is null"); - } - Integer index = fieldLabel2Index.get(label.toLowerCase(ENGLISH)); - if (index == null) { - throw new SQLException("invalid field label: " + label); - } - return index; - } - - private String getFieldLabelByIndex(int index) throws SQLException { - if (index <= 0) { - throw new SQLException(String.format("field index must be positive,but current index is %d", - index)); - } - String label = fieldIndex2Label.get(index); - if (label == null) { - throw new SQLException(String.format("label is not exist, index is %d", index)); - } - return label; - } - - private static Number convertNumber(Object value) throws SQLException { - if (value == null) { - return 0; - } - if (value instanceof Number) { - return (Number) value; - } - if (value instanceof Boolean) { - return ((Boolean) value) ? 1 : 0; - } - throw new SQLException( - String.format("value %s is not number", value.getClass().getCanonicalName())); - } - - private Object getFiledByIndex(int index) throws SQLException { - checkResultSetOpen(); - checkValidResult(); - if ((index <= 0) || (index > resultSetMetaData.getColumnCount())) { - throw new SQLException("Invalid column index: " + index); - } - Object value = currentResult.get().get(index - 1); - wasNull.set(value == null); - return value; - } - - private Object getFiledByLabel(String label) throws SQLException { - checkResultSetOpen(); - checkValidResult(); - Object value = currentResult.get().get(getFieldIndexByLabel(label) - 1); - wasNull.set(value == null); - return value; - } - - private static Map getFieldLabel2IndexMap(RelDataType relDataType) { - Map map = new HashMap<>(); - List fieldList = relDataType.getFieldList(); - for (int i = 0; i < fieldList.size(); i++) { - String name = fieldList.get(i).getName().toLowerCase(ENGLISH); - if (!map.containsKey(name)) { - map.put(name, i + 1); - } - } - return ImmutableMap.copyOf(map); - } - - private Map parseVertexValue(RowVertex rowVertex, String fieldLabel) { - Row vertexValue = rowVertex.getValue(); - Map properties = new HashMap<>(); - if (vertexValue != null) { - Map> tableFieldName2Type = resultSetMetaData.getTableFieldName2Type(); - for (Entry> entry : tableFieldName2Type.entrySet()) { - if (fieldLabel.equalsIgnoreCase(entry.getKey())) { - VertexType vertexType = (VertexType) entry.getValue(); - List valueFields = vertexType.getValueFields(); - for (int i = 0; i < valueFields.size(); i++) { - TableField tableField = valueFields.get(i); - Object valueField = vertexValue.getField(i, tableField.getType()); - properties.put(tableField.getName(), valueField); - } - } - } - } - return properties; - } - - private Map parseEdgeValue(RowEdge rowEdge, String fieldLabel) { - Row edgeValue = rowEdge.getValue(); - Map properties = new HashMap<>(); - if (edgeValue != null) { - Map> tableFieldName2Type = resultSetMetaData.getTableFieldName2Type(); - for (Entry> entry : tableFieldName2Type.entrySet()) { - if (fieldLabel.equalsIgnoreCase(entry.getKey())) { - EdgeType edgeType = (EdgeType) entry.getValue(); - List valueFields = edgeType.getValueFields(); - for (int i = 0; i < valueFields.size(); i++) { - TableField tableField = valueFields.get(i); - Object valueField = edgeValue.getField(i, tableField.getType()); - properties.put(tableField.getName(), valueField); - } - } - } - } - return properties; - } - - private static Map getFieldIndex2LabelMap(RelDataType relDataType) { - Map map = new HashMap<>(); - List fieldList = relDataType.getFieldList(); - for (int i = 0; i < fieldList.size(); i++) { - String name = fieldList.get(i).getName().toLowerCase(ENGLISH); - if (!map.containsValue(name)) { - map.put(i + 1, name); - } - } - return ImmutableMap.copyOf(map); - } - - private void checkValidResult() throws SQLException { - if (currentResult.get() == null) { - throw new SQLException("Not on a valid row"); - } - } - - - private static Iterator flatten(Iterator rowsIterator, long maxRows) { - return (maxRows > 0) ? new RowLimitedIterator<>(rowsIterator, maxRows) : rowsIterator; - } - - private static class RowLimitedIterator implements Iterator { - - private final Iterator iterator; - - private final long maxRows; - - private long cursor; - - public RowLimitedIterator(Iterator iterator, long maxRows) { - Preconditions.checkState(maxRows >= 0, "max rows is negative"); - this.iterator = iterator; - this.maxRows = maxRows; - } - - @Override - public boolean hasNext() { - return cursor < maxRows && iterator.hasNext(); - } - - @Override - public T next() { - if (!hasNext()) { - throw new GeaflowRuntimeException("no such element"); - } - cursor++; - return iterator.next(); - } - } - - private static class PageIterator extends AbstractIterator> { - - private final List> queryData; - private int cursor; - - private PageIterator(List> queryData) { - this.queryData = queryData; - this.cursor = 0; - } - - @Override - protected List computeNext() { - if (cursor < queryData.size()) { - List element = queryData.get(cursor); - cursor++; - return element; - } else { - return endOfData(); - } - } + protected List computeNext() { + if (cursor < queryData.size()) { + List element = queryData.get(cursor); + cursor++; + return element; + } else { + return endOfData(); + } } + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsStatement.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsStatement.java index e4660631d..1c6149904 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsStatement.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsStatement.java @@ -28,313 +28,290 @@ import java.sql.Statement; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; + import org.apache.geaflow.analytics.service.client.AnalyticsClient; import org.apache.geaflow.analytics.service.query.QueryResults; import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class AnalyticsStatement implements Statement { - private final AtomicReference connection; - private final AtomicInteger queryTimeoutSeconds = new AtomicInteger(); - private final AtomicReference currentResult = new AtomicReference<>(); - private final AtomicReference executingClient; - - private AnalyticsStatement(AnalyticsConnection connection, AnalyticsClient client) { - this.connection = new AtomicReference<>(requireNonNull(connection, "analytics connection " - + "is null")); - this.executingClient = new AtomicReference<>(requireNonNull(client, "analytics client " - + "is null")); - } - - protected static AnalyticsStatement newInstance(AnalyticsConnection connection, - AnalyticsClient client) { - return new AnalyticsStatement(connection, client); - } - - @Override - public ResultSet executeQuery(String script) throws SQLException { - if (!execute(script)) { - throw new GeaflowRuntimeException("Execute statement is not a query: " + script); - } - return currentResult.get(); - } - - @Override - public boolean execute(String statement) throws SQLException { - return executeQueryInternal(statement); - } - - @Override - public int executeUpdate(String query) { - return 0; - } - - @Override - public void close() throws SQLException { - if (executingClient.get() != null) { - executingClient.get().shutdown(); - } - if (connection.get() != null) { - connection.get().close(); - } - } - - private boolean executeQueryInternal(String script) throws SQLException { - // Remove current results. - removeCurrentResults(); - // Check connection. - checkConnect(); - AnalyticsClient client = executingClient.get(); - ResultSet resultSet = null; - try { - QueryResults queryResults = client.executeQuery(script); - resultSet = new AnalyticsResultSet(this, queryResults); - executingClient.set(client); - currentResult.set(resultSet); - return true; - } catch (Exception e) { - throw new GeaflowRuntimeException("execute query fail", e); - } finally { - executingClient.set(null); - if (currentResult.get() == null) { - if (resultSet != null) { - resultSet.close(); - } - if (client != null) { - client.shutdown(); - } - } - } - } - - private void removeCurrentResults() throws SQLException { - ResultSet resultSet = currentResult.getAndSet(null); + private final AtomicReference connection; + private final AtomicInteger queryTimeoutSeconds = new AtomicInteger(); + private final AtomicReference currentResult = new AtomicReference<>(); + private final AtomicReference executingClient; + + private AnalyticsStatement(AnalyticsConnection connection, AnalyticsClient client) { + this.connection = + new AtomicReference<>(requireNonNull(connection, "analytics connection " + "is null")); + this.executingClient = + new AtomicReference<>(requireNonNull(client, "analytics client " + "is null")); + } + + protected static AnalyticsStatement newInstance( + AnalyticsConnection connection, AnalyticsClient client) { + return new AnalyticsStatement(connection, client); + } + + @Override + public ResultSet executeQuery(String script) throws SQLException { + if (!execute(script)) { + throw new GeaflowRuntimeException("Execute statement is not a query: " + script); + } + return currentResult.get(); + } + + @Override + public boolean execute(String statement) throws SQLException { + return executeQueryInternal(statement); + } + + @Override + public int executeUpdate(String query) { + return 0; + } + + @Override + public void close() throws SQLException { + if (executingClient.get() != null) { + executingClient.get().shutdown(); + } + if (connection.get() != null) { + connection.get().close(); + } + } + + private boolean executeQueryInternal(String script) throws SQLException { + // Remove current results. + removeCurrentResults(); + // Check connection. + checkConnect(); + AnalyticsClient client = executingClient.get(); + ResultSet resultSet = null; + try { + QueryResults queryResults = client.executeQuery(script); + resultSet = new AnalyticsResultSet(this, queryResults); + executingClient.set(client); + currentResult.set(resultSet); + return true; + } catch (Exception e) { + throw new GeaflowRuntimeException("execute query fail", e); + } finally { + executingClient.set(null); + if (currentResult.get() == null) { if (resultSet != null) { - resultSet.close(); + resultSet.close(); } - } - - private void checkConnect() throws SQLException { - connection(); - } - - private AnalyticsConnection connection() throws SQLException { - AnalyticsConnection connection = this.connection.get(); - if (connection == null || connection.isClosed()) { - throw new GeaflowRuntimeException("analytics connection is closed"); + if (client != null) { + client.shutdown(); } - return connection; + } } + } - - @Override - public int getMaxFieldSize() throws SQLException { - return 0; + private void removeCurrentResults() throws SQLException { + ResultSet resultSet = currentResult.getAndSet(null); + if (resultSet != null) { + resultSet.close(); } + } - @Override - public void setMaxFieldSize(int max) throws SQLException { - - } - - @Override - public int getMaxRows() throws SQLException { - return 0; - } - - @Override - public void setMaxRows(int max) throws SQLException { - - } - - @Override - public void setEscapeProcessing(boolean enable) throws SQLException { - - } - - @Override - public int getQueryTimeout() throws SQLException { - checkConnect(); - return this.queryTimeoutSeconds.get(); - } - - @Override - public void setQueryTimeout(int seconds) throws SQLException { - checkConnect(); - if (seconds < 0) { - throw new GeaflowRuntimeException("Query timeout seconds must be positive"); - } - queryTimeoutSeconds.set(seconds); - } - - @Override - public void cancel() throws SQLException { - checkConnect(); - AnalyticsClient analyticsClient = executingClient.get(); - removeCurrentResults(); - if (analyticsClient != null) { - analyticsClient.shutdown(); - } - } - - @Override - public SQLWarning getWarnings() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void clearWarnings() throws SQLException { - - } - - @Override - public void setCursorName(String name) throws SQLException { - - } - - @Override - public ResultSet getResultSet() throws SQLException { - return this.currentResult.get(); - } - - @Override - public int getUpdateCount() throws SQLException { - return 0; - } - - @Override - public boolean getMoreResults() throws SQLException { - throw new UnsupportedOperationException(); - } + private void checkConnect() throws SQLException { + connection(); + } - @Override - public void setFetchDirection(int direction) throws SQLException { + private AnalyticsConnection connection() throws SQLException { + AnalyticsConnection connection = this.connection.get(); + if (connection == null || connection.isClosed()) { + throw new GeaflowRuntimeException("analytics connection is closed"); + } + return connection; + } - } - - @Override - public int getFetchDirection() throws SQLException { - return 0; - } - - @Override - public void setFetchSize(int rows) throws SQLException { - - } - - @Override - public int getFetchSize() throws SQLException { - return 0; - } - - @Override - public int getResultSetConcurrency() throws SQLException { - return 0; - } - - @Override - public int getResultSetType() throws SQLException { - return 0; - } - - @Override - public void addBatch(String query) throws SQLException { - - } - - @Override - public void clearBatch() throws SQLException { - - } - - @Override - public int[] executeBatch() throws SQLException { - return new int[0]; - } - - @Override - public Connection getConnection() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean getMoreResults(int current) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public ResultSet getGeneratedKeys() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int executeUpdate(String query, int autoGeneratedKeys) throws SQLException { - return 0; - } - - @Override - public int executeUpdate(String query, int[] columnIndexes) throws SQLException { - return 0; - } - - @Override - public int executeUpdate(String query, String[] columnNames) throws SQLException { - return 0; - } - - @Override - public boolean execute(String query, int autoGeneratedKeys) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean execute(String query, int[] columnIndexes) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean execute(String query, String[] columnNames) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getResultSetHoldability() throws SQLException { - return 0; - } - - @Override - public boolean isClosed() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void setPoolable(boolean enablePool) throws SQLException { - - } - - @Override - public boolean isPoolable() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void closeOnCompletion() throws SQLException { - - } - - @Override - public boolean isCloseOnCompletion() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public T unwrap(Class param) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isWrapperFor(Class param) throws SQLException { - throw new UnsupportedOperationException(); - } + @Override + public int getMaxFieldSize() throws SQLException { + return 0; + } + + @Override + public void setMaxFieldSize(int max) throws SQLException {} + @Override + public int getMaxRows() throws SQLException { + return 0; + } + + @Override + public void setMaxRows(int max) throws SQLException {} + + @Override + public void setEscapeProcessing(boolean enable) throws SQLException {} + + @Override + public int getQueryTimeout() throws SQLException { + checkConnect(); + return this.queryTimeoutSeconds.get(); + } + + @Override + public void setQueryTimeout(int seconds) throws SQLException { + checkConnect(); + if (seconds < 0) { + throw new GeaflowRuntimeException("Query timeout seconds must be positive"); + } + queryTimeoutSeconds.set(seconds); + } + + @Override + public void cancel() throws SQLException { + checkConnect(); + AnalyticsClient analyticsClient = executingClient.get(); + removeCurrentResults(); + if (analyticsClient != null) { + analyticsClient.shutdown(); + } + } + + @Override + public SQLWarning getWarnings() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void clearWarnings() throws SQLException {} + + @Override + public void setCursorName(String name) throws SQLException {} + + @Override + public ResultSet getResultSet() throws SQLException { + return this.currentResult.get(); + } + + @Override + public int getUpdateCount() throws SQLException { + return 0; + } + + @Override + public boolean getMoreResults() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void setFetchDirection(int direction) throws SQLException {} + + @Override + public int getFetchDirection() throws SQLException { + return 0; + } + + @Override + public void setFetchSize(int rows) throws SQLException {} + + @Override + public int getFetchSize() throws SQLException { + return 0; + } + + @Override + public int getResultSetConcurrency() throws SQLException { + return 0; + } + + @Override + public int getResultSetType() throws SQLException { + return 0; + } + + @Override + public void addBatch(String query) throws SQLException {} + + @Override + public void clearBatch() throws SQLException {} + + @Override + public int[] executeBatch() throws SQLException { + return new int[0]; + } + + @Override + public Connection getConnection() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean getMoreResults(int current) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public ResultSet getGeneratedKeys() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int executeUpdate(String query, int autoGeneratedKeys) throws SQLException { + return 0; + } + + @Override + public int executeUpdate(String query, int[] columnIndexes) throws SQLException { + return 0; + } + + @Override + public int executeUpdate(String query, String[] columnNames) throws SQLException { + return 0; + } + + @Override + public boolean execute(String query, int autoGeneratedKeys) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean execute(String query, int[] columnIndexes) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean execute(String query, String[] columnNames) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getResultSetHoldability() throws SQLException { + return 0; + } + + @Override + public boolean isClosed() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void setPoolable(boolean enablePool) throws SQLException {} + + @Override + public boolean isPoolable() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void closeOnCompletion() throws SQLException {} + + @Override + public boolean isCloseOnCompletion() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public T unwrap(Class param) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isWrapperFor(Class param) throws SQLException { + throw new UnsupportedOperationException(); + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/property/AbstractConnectProperty.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/property/AbstractConnectProperty.java index 18c6d127b..e8e647591 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/property/AbstractConnectProperty.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/property/AbstractConnectProperty.java @@ -19,181 +19,188 @@ import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; -import com.google.common.base.CharMatcher; -import com.google.common.base.Splitter; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Properties; import java.util.function.Predicate; import java.util.stream.Collectors; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * This class is an adaptation of Presto's com.facebook.presto.jdbc.AbstractConnectionProperty. - */ +import com.google.common.base.CharMatcher; +import com.google.common.base.Splitter; + +/** This class is an adaptation of Presto's com.facebook.presto.jdbc.AbstractConnectionProperty. */ public abstract class AbstractConnectProperty implements ConnectProperty { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractConnectProperty.class); - private static final String TRUE_FLAG = "true"; - private static final String FALSE_FLAG = "false"; - - private final String key; - private final Optional defaultValue; - private final Predicate enableRequired; - private final Predicate enableAllowed; - private final Converter converter; - - protected AbstractConnectProperty(String key, Optional defaultValue, - Predicate enableRequired, - Predicate enableAllowed, - Converter converter) { - this.key = requireNonNull(key, "key is null"); - this.defaultValue = requireNonNull(defaultValue, "defaultValue is null"); - this.enableRequired = requireNonNull(enableRequired, "enableRequired is null"); - this.enableAllowed = requireNonNull(enableAllowed, "enableAllowed is null"); - this.converter = requireNonNull(converter, "converter is null"); + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractConnectProperty.class); + private static final String TRUE_FLAG = "true"; + private static final String FALSE_FLAG = "false"; + + private final String key; + private final Optional defaultValue; + private final Predicate enableRequired; + private final Predicate enableAllowed; + private final Converter converter; + + protected AbstractConnectProperty( + String key, + Optional defaultValue, + Predicate enableRequired, + Predicate enableAllowed, + Converter converter) { + this.key = requireNonNull(key, "key is null"); + this.defaultValue = requireNonNull(defaultValue, "defaultValue is null"); + this.enableRequired = requireNonNull(enableRequired, "enableRequired is null"); + this.enableAllowed = requireNonNull(enableAllowed, "enableAllowed is null"); + this.converter = requireNonNull(converter, "converter is null"); + } + + protected AbstractConnectProperty( + String key, + Predicate required, + Predicate allowed, + Converter converter) { + this(key, Optional.empty(), required, allowed, converter); + } + + @Override + public String getPropertyKey() { + return key; + } + + @Override + public Optional getDefault() { + return defaultValue; + } + + @Override + public boolean enableRequired(Properties properties) { + return enableRequired.test(properties); + } + + @Override + public boolean enableAllowed(Properties properties) { + return !properties.containsKey(key) || enableAllowed.test(properties); + } + + @Override + public Optional getValue(Properties properties) throws GeaflowRuntimeException { + String value = properties.getProperty(key); + if (value == null) { + if (enableRequired(properties)) { + throw new GeaflowRuntimeException(format("Connection property '%s' is required", key)); + } + return Optional.empty(); } - protected AbstractConnectProperty( - String key, - Predicate required, - Predicate allowed, - Converter converter) { - this(key, Optional.empty(), required, allowed, converter); + try { + return Optional.of(converter.convert(value)); + } catch (RuntimeException e) { + if (value.isEmpty()) { + throw new GeaflowRuntimeException( + format("Connection property '%s' value is empty", key), e); + } + throw new GeaflowRuntimeException( + format("Connection property '%s' value is invalid: %s", key, value), e); } + } - @Override - public String getPropertyKey() { - return key; + @Override + public void validate(Properties properties) throws GeaflowRuntimeException { + if (!enableAllowed(properties)) { + throw new GeaflowRuntimeException(format("Connection property '%s' is not allowed", key)); } - @Override - public Optional getDefault() { - return defaultValue; - } + getValue(properties); + } - @Override - public boolean enableRequired(Properties properties) { - return enableRequired.test(properties); - } + protected static final Predicate REQUIRED = properties -> true; + protected static final Predicate NOT_REQUIRED = properties -> false; - @Override - public boolean enableAllowed(Properties properties) { - return !properties.containsKey(key) || enableAllowed.test(properties); - } + protected static final Predicate ALLOWED = properties -> true; - @Override - public Optional getValue(Properties properties) throws GeaflowRuntimeException { - String value = properties.getProperty(key); - if (value == null) { - if (enableRequired(properties)) { - throw new GeaflowRuntimeException(format("Connection property '%s' is required", key)); - } - return Optional.empty(); - } + interface Converter { - try { - return Optional.of(converter.convert(value)); - } catch (RuntimeException e) { - if (value.isEmpty()) { - throw new GeaflowRuntimeException(format("Connection property '%s' value is empty", key), e); - } - throw new GeaflowRuntimeException( - format("Connection property '%s' value is invalid: %s", key, value), e); - } - } + T convert(String value); + } - @Override - public void validate(Properties properties) - throws GeaflowRuntimeException { - if (!enableAllowed(properties)) { - throw new GeaflowRuntimeException(format("Connection property '%s' is not allowed", key)); - } + protected static final Converter STRING_CONVERTER = value -> value; - getValue(properties); - } - - protected static final Predicate REQUIRED = properties -> true; - protected static final Predicate NOT_REQUIRED = properties -> false; - - protected static final Predicate ALLOWED = properties -> true; - - interface Converter { - - T convert(String value); - } - - protected static final Converter STRING_CONVERTER = value -> value; - - protected static final Converter NON_EMPTY_STRING_CONVERTER = value -> { + protected static final Converter NON_EMPTY_STRING_CONVERTER = + value -> { checkArgument(!value.isEmpty(), "value is empty"); return value; - }; + }; - protected static final Converter BOOLEAN_CONVERTER = value -> { + protected static final Converter BOOLEAN_CONVERTER = + value -> { switch (value.toLowerCase(ENGLISH)) { - case TRUE_FLAG: - return true; - case FALSE_FLAG: - return false; - default: - break; + case TRUE_FLAG: + return true; + case FALSE_FLAG: + return false; + default: + break; } throw new IllegalArgumentException("value must be 'true' or 'false'"); - }; + }; - protected static final class StringMapConverter implements Converter> { + protected static final class StringMapConverter implements Converter> { - private static final CharMatcher PRINTABLE_ASCII = CharMatcher.inRange((char) 0x21, (char) 0x7E); + private static final CharMatcher PRINTABLE_ASCII = + CharMatcher.inRange((char) 0x21, (char) 0x7E); - public static final StringMapConverter STRING_MAP_CONVERTER = new StringMapConverter(); + public static final StringMapConverter STRING_MAP_CONVERTER = new StringMapConverter(); - private static final char DELIMITER_COLON = ':'; + private static final char DELIMITER_COLON = ':'; - private static final char DELIMITER_SEMICOLON = ';'; + private static final char DELIMITER_SEMICOLON = ';'; - private StringMapConverter() { + private StringMapConverter() {} - } - - @Override - public Map convert(String value) { - return Splitter.on(DELIMITER_SEMICOLON).splitToList(value).stream() - .map(this::parseKeyValuePair) - .collect(Collectors.toMap(entry -> entry.get(0), entry -> entry.get(1))); - } - - public List parseKeyValuePair(String keyValue) { - List nameValue = Splitter.on(DELIMITER_COLON).splitToList(keyValue); - checkArgument(nameValue.size() == 2, "Malformed key value pair: %s", keyValue); - String name = nameValue.get(0); - String value = nameValue.get(1); - checkArgument(!name.isEmpty(), "Key is empty"); - checkArgument(!value.isEmpty(), "Value is empty"); - - checkArgument(PRINTABLE_ASCII.matchesAllOf(name), - "Key contains spaces or is not printable ASCII: %s", name); - checkArgument(PRINTABLE_ASCII.matchesAllOf(value), - "Value contains spaces or is not printable ASCII: %s", name); - return nameValue; - } - } - - protected interface CheckedPredicate { - boolean test(T t) throws GeaflowRuntimeException; + @Override + public Map convert(String value) { + return Splitter.on(DELIMITER_SEMICOLON).splitToList(value).stream() + .map(this::parseKeyValuePair) + .collect(Collectors.toMap(entry -> entry.get(0), entry -> entry.get(1))); } - protected static Predicate checkedPredicate(CheckedPredicate predicate) { - return t -> { - try { - return predicate.test(t); - } catch (GeaflowRuntimeException e) { - LOGGER.warn("check predicate error", e); - return false; - } - }; + public List parseKeyValuePair(String keyValue) { + List nameValue = Splitter.on(DELIMITER_COLON).splitToList(keyValue); + checkArgument(nameValue.size() == 2, "Malformed key value pair: %s", keyValue); + String name = nameValue.get(0); + String value = nameValue.get(1); + checkArgument(!name.isEmpty(), "Key is empty"); + checkArgument(!value.isEmpty(), "Value is empty"); + + checkArgument( + PRINTABLE_ASCII.matchesAllOf(name), + "Key contains spaces or is not printable ASCII: %s", + name); + checkArgument( + PRINTABLE_ASCII.matchesAllOf(value), + "Value contains spaces or is not printable ASCII: %s", + name); + return nameValue; } + } + + protected interface CheckedPredicate { + boolean test(T t) throws GeaflowRuntimeException; + } + + protected static Predicate checkedPredicate(CheckedPredicate predicate) { + return t -> { + try { + return predicate.test(t); + } catch (GeaflowRuntimeException e) { + LOGGER.warn("check predicate error", e); + return false; + } + }; + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/property/ConnectProperties.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/property/ConnectProperties.java index a816c50b8..e0d6b5703 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/property/ConnectProperties.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/property/ConnectProperties.java @@ -19,141 +19,149 @@ import static java.util.stream.Collectors.toMap; import static org.apache.geaflow.analytics.service.client.jdbc.property.AbstractConnectProperty.StringMapConverter.STRING_MAP_CONVERTER; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; -import com.google.common.net.HostAndPort; import java.util.Map; import java.util.Properties; import java.util.Set; import java.util.function.Predicate; -/** - * This class is an adaptation of Presto's com.facebook.presto.jdbc.ConnectionProperties. - */ +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.net.HostAndPort; + +/** This class is an adaptation of Presto's com.facebook.presto.jdbc.ConnectionProperties. */ public class ConnectProperties { - public static final ConnectProperty USER = new User(); - public static final ConnectProperty PASSWORD = new Password(); - public static final ConnectProperty SOCKS_PROXY = new SocksProxy(); - public static final ConnectProperty HTTP_PROXY = new HttpProxy(); - public static final ConnectProperty DISABLE_COMPRESSION = new DisableCompression(); - public static final ConnectProperty ACCESS_TOKEN = new AccessToken(); - public static final ConnectProperty> CUSTOM_HEADERS = new CustomHeaders(); - public static final ConnectProperty> SESSION_PROPERTIES = new SessionProperties(); - - private static final Set> ALL_PROPERTIES = - ImmutableSet.>builder() - .add(USER) - .add(PASSWORD) - .add(SOCKS_PROXY) - .add(HTTP_PROXY) - .add(DISABLE_COMPRESSION) - .add(ACCESS_TOKEN) - .add(CUSTOM_HEADERS) - .add(SESSION_PROPERTIES) - .build(); - - private static final Map> KEY_LOOKUP = unmodifiableMap( - ALL_PROPERTIES.stream() - .collect(toMap(ConnectProperty::getPropertyKey, identity()))); - - private static final Map DEFAULTS; - - static { - ImmutableMap.Builder defaults = ImmutableMap.builder(); - for (ConnectProperty property : ALL_PROPERTIES) { - property.getDefault() - .ifPresent(value -> defaults.put(property.getPropertyKey(), value)); - } - DEFAULTS = defaults.build(); + public static final ConnectProperty USER = new User(); + public static final ConnectProperty PASSWORD = new Password(); + public static final ConnectProperty SOCKS_PROXY = new SocksProxy(); + public static final ConnectProperty HTTP_PROXY = new HttpProxy(); + public static final ConnectProperty DISABLE_COMPRESSION = new DisableCompression(); + public static final ConnectProperty ACCESS_TOKEN = new AccessToken(); + public static final ConnectProperty> CUSTOM_HEADERS = new CustomHeaders(); + public static final ConnectProperty> SESSION_PROPERTIES = + new SessionProperties(); + + private static final Set> ALL_PROPERTIES = + ImmutableSet.>builder() + .add(USER) + .add(PASSWORD) + .add(SOCKS_PROXY) + .add(HTTP_PROXY) + .add(DISABLE_COMPRESSION) + .add(ACCESS_TOKEN) + .add(CUSTOM_HEADERS) + .add(SESSION_PROPERTIES) + .build(); + + private static final Map> KEY_LOOKUP = + unmodifiableMap( + ALL_PROPERTIES.stream().collect(toMap(ConnectProperty::getPropertyKey, identity()))); + + private static final Map DEFAULTS; + + static { + ImmutableMap.Builder defaults = ImmutableMap.builder(); + for (ConnectProperty property : ALL_PROPERTIES) { + property.getDefault().ifPresent(value -> defaults.put(property.getPropertyKey(), value)); } + DEFAULTS = defaults.build(); + } - private ConnectProperties() { - } + private ConnectProperties() {} - public static ConnectProperty forKey(String propertiesKey) { - return KEY_LOOKUP.get(propertiesKey); - } + public static ConnectProperty forKey(String propertiesKey) { + return KEY_LOOKUP.get(propertiesKey); + } - public static Set> allProperties() { - return ALL_PROPERTIES; - } + public static Set> allProperties() { + return ALL_PROPERTIES; + } - public static Map getDefaults() { - return DEFAULTS; - } + public static Map getDefaults() { + return DEFAULTS; + } - private static class User - extends AbstractConnectProperty { + private static class User extends AbstractConnectProperty { - public User() { - super(User.class.getSimpleName().toLowerCase(), REQUIRED, ALLOWED, NON_EMPTY_STRING_CONVERTER); - } + public User() { + super( + User.class.getSimpleName().toLowerCase(), REQUIRED, ALLOWED, NON_EMPTY_STRING_CONVERTER); } + } - private static class Password - extends AbstractConnectProperty { + private static class Password extends AbstractConnectProperty { - public Password() { - super(Password.class.getSimpleName().toLowerCase(), NOT_REQUIRED, ALLOWED, STRING_CONVERTER); - } + public Password() { + super(Password.class.getSimpleName().toLowerCase(), NOT_REQUIRED, ALLOWED, STRING_CONVERTER); } + } - private static class SocksProxy - extends AbstractConnectProperty { + private static class SocksProxy extends AbstractConnectProperty { - private static final Predicate NO_HTTP_PROXY = - checkedPredicate(properties -> !HTTP_PROXY.getValue(properties).isPresent()); + private static final Predicate NO_HTTP_PROXY = + checkedPredicate(properties -> !HTTP_PROXY.getValue(properties).isPresent()); - public SocksProxy() { - super(SocksProxy.class.getSimpleName().toLowerCase(), NOT_REQUIRED, NO_HTTP_PROXY, HostAndPort::fromString); - } + public SocksProxy() { + super( + SocksProxy.class.getSimpleName().toLowerCase(), + NOT_REQUIRED, + NO_HTTP_PROXY, + HostAndPort::fromString); } + } - private static class HttpProxy - extends AbstractConnectProperty { + private static class HttpProxy extends AbstractConnectProperty { - private static final Predicate NO_SOCKS_PROXY = - checkedPredicate(properties -> !SOCKS_PROXY.getValue(properties).isPresent()); + private static final Predicate NO_SOCKS_PROXY = + checkedPredicate(properties -> !SOCKS_PROXY.getValue(properties).isPresent()); - public HttpProxy() { - super(HttpProxy.class.getSimpleName().toLowerCase(), NOT_REQUIRED, NO_SOCKS_PROXY, - HostAndPort::fromString); - } + public HttpProxy() { + super( + HttpProxy.class.getSimpleName().toLowerCase(), + NOT_REQUIRED, + NO_SOCKS_PROXY, + HostAndPort::fromString); } + } + private static class DisableCompression extends AbstractConnectProperty { - private static class DisableCompression - extends AbstractConnectProperty { - - public DisableCompression() { - super(DisableCompression.class.getSimpleName().toLowerCase(), NOT_REQUIRED, ALLOWED, - BOOLEAN_CONVERTER); - } + public DisableCompression() { + super( + DisableCompression.class.getSimpleName().toLowerCase(), + NOT_REQUIRED, + ALLOWED, + BOOLEAN_CONVERTER); } + } - private static class AccessToken - extends AbstractConnectProperty { + private static class AccessToken extends AbstractConnectProperty { - public AccessToken() { - super(AccessToken.class.getSimpleName().toLowerCase(), NOT_REQUIRED, ALLOWED, STRING_CONVERTER); - } + public AccessToken() { + super( + AccessToken.class.getSimpleName().toLowerCase(), NOT_REQUIRED, ALLOWED, STRING_CONVERTER); } + } - private static class CustomHeaders - extends AbstractConnectProperty> { + private static class CustomHeaders extends AbstractConnectProperty> { - public CustomHeaders() { - super(CustomHeaders.class.getSimpleName().toLowerCase(), NOT_REQUIRED, ALLOWED, STRING_MAP_CONVERTER); - } + public CustomHeaders() { + super( + CustomHeaders.class.getSimpleName().toLowerCase(), + NOT_REQUIRED, + ALLOWED, + STRING_MAP_CONVERTER); } + } - private static class SessionProperties - extends AbstractConnectProperty> { + private static class SessionProperties extends AbstractConnectProperty> { - public SessionProperties() { - super(SessionProperties.class.getSimpleName().toLowerCase(), NOT_REQUIRED, ALLOWED, STRING_MAP_CONVERTER); - } + public SessionProperties() { + super( + SessionProperties.class.getSimpleName().toLowerCase(), + NOT_REQUIRED, + ALLOWED, + STRING_MAP_CONVERTER); } - + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/property/ConnectProperty.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/property/ConnectProperty.java index b568b3236..9a5ee8c07 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/property/ConnectProperty.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/jdbc/property/ConnectProperty.java @@ -18,27 +18,29 @@ import java.util.Optional; import java.util.Properties; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; -/** - * This class is an adaptation of Presto's com.facebook.presto.jdbc.ConnectionProperty. - */ +/** This class is an adaptation of Presto's com.facebook.presto.jdbc.ConnectionProperty. */ public interface ConnectProperty { - String getPropertyKey(); + String getPropertyKey(); - Optional getDefault(); + Optional getDefault(); - boolean enableRequired(Properties properties); + boolean enableRequired(Properties properties); - boolean enableAllowed(Properties properties); + boolean enableAllowed(Properties properties); - Optional getValue(Properties properties) throws GeaflowRuntimeException; + Optional getValue(Properties properties) throws GeaflowRuntimeException; - void validate(Properties properties) throws GeaflowRuntimeException; + void validate(Properties properties) throws GeaflowRuntimeException; - default T getRequiredValue(Properties properties) throws GeaflowRuntimeException { - return getValue(properties).orElseThrow(() -> - new GeaflowRuntimeException(format("Connect property '%s' is required", getPropertyKey()))); - } + default T getRequiredValue(Properties properties) throws GeaflowRuntimeException { + return getValue(properties) + .orElseThrow( + () -> + new GeaflowRuntimeException( + format("Connect property '%s' is required", getPropertyKey()))); + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/utils/JDBCUtils.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/utils/JDBCUtils.java index 85eee0106..7bc84c8b7 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/utils/JDBCUtils.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/main/java/org/apache/geaflow/analytics/service/client/utils/JDBCUtils.java @@ -23,13 +23,12 @@ public class JDBCUtils { - public static final String DRIVER_URL_START = "jdbc:geaflow://"; + public static final String DRIVER_URL_START = "jdbc:geaflow://"; - public static boolean acceptsURL(String url) { - if (url.startsWith(JDBCUtils.DRIVER_URL_START)) { - return true; - } - throw new GeaflowRuntimeException("Invalid GeaFlow JDBC URL " + url); + public static boolean acceptsURL(String url) { + if (url.startsWith(JDBCUtils.DRIVER_URL_START)) { + return true; } - + throw new GeaflowRuntimeException("Invalid GeaFlow JDBC URL " + url); + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/test/java/org/apache/geaflow/analytics/service/client/AnalyticsManagerOptionsTest.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/test/java/org/apache/geaflow/analytics/service/client/AnalyticsManagerOptionsTest.java index ada413518..e642f8184 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/test/java/org/apache/geaflow/analytics/service/client/AnalyticsManagerOptionsTest.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/test/java/org/apache/geaflow/analytics/service/client/AnalyticsManagerOptionsTest.java @@ -23,25 +23,26 @@ import org.testng.annotations.Test; public class AnalyticsManagerOptionsTest { - private static final String URL_PATH = "/rest/analytics/query/execute"; + private static final String URL_PATH = "/rest/analytics/query/execute"; - @Test - public void testCreateClientSession() { - String host = "localhost"; - int port = 8080; - AnalyticsManagerSession clientSession1 = AnalyticsManagerOptions.createClientSession(host, port); - AnalyticsManagerSession clientSession2 = AnalyticsManagerOptions.createClientSession(port); - Assert.assertEquals(clientSession1.getServer().toString(), clientSession2.getServer().toString()); - Assert.assertEquals(clientSession1.getServer().toString(), "http://localhost:8080"); - } - - @Test - public void testServerResolve() { - String host = "localhost"; - int port = 8080; - AnalyticsManagerSession clientSession = AnalyticsManagerOptions.createClientSession(host, port); - String fullUri = clientSession.getServer().resolve(URL_PATH).toString(); - Assert.assertEquals(fullUri, "http://localhost:8080/rest/analytics/query/execute"); - } + @Test + public void testCreateClientSession() { + String host = "localhost"; + int port = 8080; + AnalyticsManagerSession clientSession1 = + AnalyticsManagerOptions.createClientSession(host, port); + AnalyticsManagerSession clientSession2 = AnalyticsManagerOptions.createClientSession(port); + Assert.assertEquals( + clientSession1.getServer().toString(), clientSession2.getServer().toString()); + Assert.assertEquals(clientSession1.getServer().toString(), "http://localhost:8080"); + } + @Test + public void testServerResolve() { + String host = "localhost"; + int port = 8080; + AnalyticsManagerSession clientSession = AnalyticsManagerOptions.createClientSession(host, port); + String fullUri = clientSession.getServer().resolve(URL_PATH).toString(); + Assert.assertEquals(fullUri, "http://localhost:8080/rest/analytics/query/execute"); + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/test/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsDriverURITest.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/test/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsDriverURITest.java index 9cb4a7be9..1720b2e3d 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/test/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsDriverURITest.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-client/src/test/java/org/apache/geaflow/analytics/service/client/jdbc/AnalyticsDriverURITest.java @@ -24,30 +24,31 @@ import static org.testng.Assert.fail; import java.util.Properties; + import org.testng.annotations.Test; public class AnalyticsDriverURITest { - @Test - public void testInvalidURI() { - assertInvalid("jdbc:geaflow://localhost/", "No port number specified:"); - } + @Test + public void testInvalidURI() { + assertInvalid("jdbc:geaflow://localhost/", "No port number specified:"); + } - private static void assertInvalid(String url, String prefix) { - try { - createDriverUri(url); - fail("expected exception"); - } catch (Exception e) { - assertNotNull(e.getMessage()); - if (!e.getMessage().startsWith(prefix)) { - fail(format("expected:<%s> to start with <%s>", e.getMessage(), prefix)); - } - } + private static void assertInvalid(String url, String prefix) { + try { + createDriverUri(url); + fail("expected exception"); + } catch (Exception e) { + assertNotNull(e.getMessage()); + if (!e.getMessage().startsWith(prefix)) { + fail(format("expected:<%s> to start with <%s>", e.getMessage(), prefix)); + } } + } - private static AnalyticsDriverURI createDriverUri(String url) { - Properties properties = new Properties(); - properties.setProperty("user", "only-test"); - return new AnalyticsDriverURI(url, properties); - } + private static AnalyticsDriverURI createDriverUri(String url) { + Properties properties = new Properties(); + properties.setProperty("user", "only-test"); + return new AnalyticsDriverURI(url, properties); + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/config/AnalyticsClientConfigKeys.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/config/AnalyticsClientConfigKeys.java index 9186660f7..5fcc3c017 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/config/AnalyticsClientConfigKeys.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/config/AnalyticsClientConfigKeys.java @@ -20,58 +20,59 @@ package org.apache.geaflow.analytics.service.config; import java.io.Serializable; + import org.apache.geaflow.common.config.ConfigKey; import org.apache.geaflow.common.config.ConfigKeys; public class AnalyticsClientConfigKeys implements Serializable { - public static final ConfigKey ANALYTICS_CLIENT_CONNECT_TIMEOUT_MS = ConfigKeys - .key("geaflow.analytics.client.connect.timeout.ms") - .defaultValue(30000) - .description("analytics client connect timeout ms, default is 30000"); + public static final ConfigKey ANALYTICS_CLIENT_CONNECT_TIMEOUT_MS = + ConfigKeys.key("geaflow.analytics.client.connect.timeout.ms") + .defaultValue(30000) + .description("analytics client connect timeout ms, default is 30000"); - public static final ConfigKey ANALYTICS_CLIENT_REQUEST_TIMEOUT_MS = ConfigKeys - .key("geaflow.analytics.client.request.timeout.ms") - .defaultValue(30000) - .description("analytics client request timeout ms, default is 30000"); + public static final ConfigKey ANALYTICS_CLIENT_REQUEST_TIMEOUT_MS = + ConfigKeys.key("geaflow.analytics.client.request.timeout.ms") + .defaultValue(30000) + .description("analytics client request timeout ms, default is 30000"); - public static final ConfigKey ANALYTICS_CLIENT_CONNECT_RETRY_NUM = ConfigKeys - .key("geaflow.analytics.client.connect.retry.num") - .defaultValue(3) - .description("analytics client connect retry num, default is 3"); + public static final ConfigKey ANALYTICS_CLIENT_CONNECT_RETRY_NUM = + ConfigKeys.key("geaflow.analytics.client.connect.retry.num") + .defaultValue(3) + .description("analytics client connect retry num, default is 3"); - public static final ConfigKey ANALYTICS_CLIENT_EXECUTE_RETRY_NUM = ConfigKeys - .key("geaflow.analytics.client.execute.retry.num") - .defaultValue(3) - .description("analytics client execute retry num, default is 3"); + public static final ConfigKey ANALYTICS_CLIENT_EXECUTE_RETRY_NUM = + ConfigKeys.key("geaflow.analytics.client.execute.retry.num") + .defaultValue(3) + .description("analytics client execute retry num, default is 3"); - public static final ConfigKey ANALYTICS_CLIENT_MAX_INBOUND_MESSAGE_SIZE = ConfigKeys - .key("geaflow.analytics.client.max.inbound.message.size") - .defaultValue(4194304) - .description("analytics client max inbound message size for rpc, default is 4194304"); + public static final ConfigKey ANALYTICS_CLIENT_MAX_INBOUND_MESSAGE_SIZE = + ConfigKeys.key("geaflow.analytics.client.max.inbound.message.size") + .defaultValue(4194304) + .description("analytics client max inbound message size for rpc, default is 4194304"); - public static final ConfigKey ANALYTICS_CLIENT_MAX_RETRY_ATTEMPTS = ConfigKeys - .key("geaflow.analytics.client.max.retry.attempts") - .defaultValue(5) - .description("analytics client max retry attempts for rpc, default is 5"); + public static final ConfigKey ANALYTICS_CLIENT_MAX_RETRY_ATTEMPTS = + ConfigKeys.key("geaflow.analytics.client.max.retry.attempts") + .defaultValue(5) + .description("analytics client max retry attempts for rpc, default is 5"); - public static final ConfigKey ANALYTICS_CLIENT_DEFALUT_RETRY_BUFFER_SIZE = ConfigKeys - .key("geaflow.analytics.client.retry.buffer.size") - .defaultValue(16777216L) - .description("analytics client default retry buffer size for rpc, default is 16777216"); + public static final ConfigKey ANALYTICS_CLIENT_DEFALUT_RETRY_BUFFER_SIZE = + ConfigKeys.key("geaflow.analytics.client.retry.buffer.size") + .defaultValue(16777216L) + .description("analytics client default retry buffer size for rpc, default is 16777216"); - public static final ConfigKey ANALYTICS_CLIENT_PER_RPC_BUFFER_LIMIT = ConfigKeys - .key("geaflow.analytics.client.per.rpc.buffer.limit") - .defaultValue(1048576L) - .description("analytics client per rpc buffer limit, default is 1048576"); + public static final ConfigKey ANALYTICS_CLIENT_PER_RPC_BUFFER_LIMIT = + ConfigKeys.key("geaflow.analytics.client.per.rpc.buffer.limit") + .defaultValue(1048576L) + .description("analytics client per rpc buffer limit, default is 1048576"); - public static final ConfigKey ANALYTICS_CLIENT_ACCESS_TOKEN = ConfigKeys - .key("geaflow.analytics.client.access.token") - .noDefaultValue() - .description("analytics client access token for auth"); + public static final ConfigKey ANALYTICS_CLIENT_ACCESS_TOKEN = + ConfigKeys.key("geaflow.analytics.client.access.token") + .noDefaultValue() + .description("analytics client access token for auth"); - public static final ConfigKey ANALYTICS_CLIENT_SLEEP_TIME_MS = ConfigKeys - .key("geaflow.analytics.client.sleep.time.ms") - .defaultValue(3000L) - .description("analytics client sleep time"); + public static final ConfigKey ANALYTICS_CLIENT_SLEEP_TIME_MS = + ConfigKeys.key("geaflow.analytics.client.sleep.time.ms") + .defaultValue(3000L) + .description("analytics client sleep time"); } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/config/AnalyticsServiceConfigKeys.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/config/AnalyticsServiceConfigKeys.java index 1fe08be8b..0dcf0e879 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/config/AnalyticsServiceConfigKeys.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/config/AnalyticsServiceConfigKeys.java @@ -20,38 +20,38 @@ package org.apache.geaflow.analytics.service.config; import java.io.Serializable; + import org.apache.geaflow.common.config.ConfigKey; import org.apache.geaflow.common.config.ConfigKeys; public class AnalyticsServiceConfigKeys implements Serializable { - public static final ConfigKey MAX_REQUEST_PER_SERVER = ConfigKeys - .key("geaflow.max.request.per.server") - .defaultValue(1) - .description("the maximum number of requests that can be accepted simultaneously for per server"); - - public static final ConfigKey ANALYTICS_SERVICE_PORT = ConfigKeys - .key("geaflow.analytics.service.port") - .defaultValue(0) - .description("analytics service port, default is 0"); - - public static final ConfigKey ANALYTICS_QUERY_PARALLELISM = ConfigKeys - .key("geaflow.analytics.query.parallelism") - .defaultValue(1) - .description("analytics query parallelism"); - - public static final ConfigKey ANALYTICS_QUERY = ConfigKeys - .key("geaflow.analytics.query") - .noDefaultValue() - .description("analytics query"); - - public static final ConfigKey ANALYTICS_SERVICE_REGISTER_ENABLE = ConfigKeys - .key("geaflow.analytics.service.register.enable") - .defaultValue(true) - .description("enable analytics service info register"); - - public static final ConfigKey ANALYTICS_COMPILE_SCHEMA_ENABLE = ConfigKeys - .key("geaflow.analytics.compile.schema.enable") - .defaultValue(true) - .description("enable analytics compile schema"); + public static final ConfigKey MAX_REQUEST_PER_SERVER = + ConfigKeys.key("geaflow.max.request.per.server") + .defaultValue(1) + .description( + "the maximum number of requests that can be accepted simultaneously for per server"); + + public static final ConfigKey ANALYTICS_SERVICE_PORT = + ConfigKeys.key("geaflow.analytics.service.port") + .defaultValue(0) + .description("analytics service port, default is 0"); + + public static final ConfigKey ANALYTICS_QUERY_PARALLELISM = + ConfigKeys.key("geaflow.analytics.query.parallelism") + .defaultValue(1) + .description("analytics query parallelism"); + + public static final ConfigKey ANALYTICS_QUERY = + ConfigKeys.key("geaflow.analytics.query").noDefaultValue().description("analytics query"); + + public static final ConfigKey ANALYTICS_SERVICE_REGISTER_ENABLE = + ConfigKeys.key("geaflow.analytics.service.register.enable") + .defaultValue(true) + .description("enable analytics service info register"); + + public static final ConfigKey ANALYTICS_COMPILE_SCHEMA_ENABLE = + ConfigKeys.key("geaflow.analytics.compile.schema.enable") + .defaultValue(true) + .description("enable analytics compile schema"); } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/DefaultResultSetFormatUtils.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/DefaultResultSetFormatUtils.java index 706843c6e..b8ae958c7 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/DefaultResultSetFormatUtils.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/DefaultResultSetFormatUtils.java @@ -19,10 +19,6 @@ package org.apache.geaflow.analytics.service.query; -import com.alibaba.fastjson.JSON; -import com.alibaba.fastjson.JSONArray; -import com.alibaba.fastjson.JSONObject; -import com.alibaba.fastjson.serializer.SerializerFeature; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; @@ -30,6 +26,7 @@ import java.util.Map; import java.util.TreeSet; import java.util.stream.Collectors; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.geaflow.cluster.response.ResponseResult; @@ -41,117 +38,133 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONArray; +import com.alibaba.fastjson.JSONObject; +import com.alibaba.fastjson.serializer.SerializerFeature; + public class DefaultResultSetFormatUtils { - private static final Logger LOGGER = LoggerFactory.getLogger(DefaultResultSetFormatUtils.class); + private static final Logger LOGGER = LoggerFactory.getLogger(DefaultResultSetFormatUtils.class); - private static final String VERTEX = "nodes"; - private static final String EDGE = "edges"; + private static final String VERTEX = "nodes"; + private static final String EDGE = "edges"; - private static final String VIEW_RESULT = "viewResult"; + private static final String VIEW_RESULT = "viewResult"; - private static final String JSON_RESULT = "jsonResult"; + private static final String JSON_RESULT = "jsonResult"; - public static String formatResult(Object queryResult, RelDataType currentResultType) { - final JSONObject finalResult = new JSONObject(); - JSONArray jsonResult = new JSONArray(); - JSONObject viewResult = new JSONObject(); - List vertices = new ArrayList<>(); - List edges = new ArrayList<>(); - List> list = (List>) queryResult; + public static String formatResult(Object queryResult, RelDataType currentResultType) { + final JSONObject finalResult = new JSONObject(); + JSONArray jsonResult = new JSONArray(); + JSONObject viewResult = new JSONObject(); + List vertices = new ArrayList<>(); + List edges = new ArrayList<>(); + List> list = (List>) queryResult; - for (List responseResults : list) { - for (ResponseResult responseResult : responseResults) { - for (Object o : responseResult.getResponse()) { - jsonResult.add(formatRow(o, currentResultType, vertices, edges)); - } - } + for (List responseResults : list) { + for (ResponseResult responseResult : responseResults) { + for (Object o : responseResult.getResponse()) { + jsonResult.add(formatRow(o, currentResultType, vertices, edges)); } - - List filteredVertices = - vertices.stream().collect(Collectors.collectingAndThen(Collectors.toCollection(() -> new TreeSet<>( - Comparator.comparing(ViewVertex::getId))), ArrayList::new)); - - viewResult.put(VERTEX, filteredVertices); - viewResult.put(EDGE, edges); - finalResult.put(VIEW_RESULT, viewResult); - finalResult.put(JSON_RESULT, jsonResult); - return JSON.toJSONString(finalResult, SerializerFeature.DisableCircularReferenceDetect); + } } - private static Object formatRow(Object o, RelDataType currentResultType, List vertices, List edges) { - if (o == null) { - return null; - } - if (o instanceof ObjectRow) { - JSONObject jsonObject = new JSONObject(); - ObjectRow objectRow = (ObjectRow) o; - Object[] fields = objectRow.getFields(); - for (int i = 0; i < fields.length; i++) { - RelDataTypeField relDataTypeField = currentResultType.getFieldList().get(i); - Object field = fields[i]; - Object formatResult; - if (field instanceof RowVertex) { - RowVertex vertex = (RowVertex) field; - ObjectRow vertexValue = (ObjectRow) vertex.getValue(); - Map properties = new HashMap<>(); - if (vertexValue != null) { - Object[] vertexProperties = vertexValue.getFields(); - int metaFieldCount = getMetaFieldCount(relDataTypeField.getType()); - List typeList = relDataTypeField.getType().getFieldList(); - // Find the correspond key in properties. - for (int j = 0; j < vertexProperties.length; j++) { - properties.put(typeList.get(j + metaFieldCount).getName(), vertexProperties[j]); - } - } - - formatResult = new ViewVertex(String.valueOf(vertex.getId()), getLabel(vertex), properties); - vertices.add((ViewVertex) formatResult); - } else if (field instanceof RowEdge) { - RowEdge edge = (RowEdge) field; - ObjectRow edgeValue = (ObjectRow) edge.getValue(); - Map properties = new HashMap<>(); - if (edgeValue != null) { - Object[] edgeProperties = edgeValue.getFields(); - int metaFieldCount = getMetaFieldCount(relDataTypeField.getType()); - List typeList = relDataTypeField.getType().getFieldList(); - for (int j = 0; j < edgeProperties.length; j++) { - properties.put(typeList.get(j + metaFieldCount).getName(), edgeProperties[j]); - } - } - formatResult = new ViewEdge(String.valueOf(edge.getSrcId()), String.valueOf(edge.getTargetId()), - getLabel(edge), properties, edge.getDirect().name()); - edges.add((ViewEdge) formatResult); - } else { - formatResult = field.toString(); - } - jsonObject.put(relDataTypeField.getKey(), formatResult); + List filteredVertices = + vertices.stream() + .collect( + Collectors.collectingAndThen( + Collectors.toCollection( + () -> new TreeSet<>(Comparator.comparing(ViewVertex::getId))), + ArrayList::new)); + + viewResult.put(VERTEX, filteredVertices); + viewResult.put(EDGE, edges); + finalResult.put(VIEW_RESULT, viewResult); + finalResult.put(JSON_RESULT, jsonResult); + return JSON.toJSONString(finalResult, SerializerFeature.DisableCircularReferenceDetect); + } + + private static Object formatRow( + Object o, RelDataType currentResultType, List vertices, List edges) { + if (o == null) { + return null; + } + if (o instanceof ObjectRow) { + JSONObject jsonObject = new JSONObject(); + ObjectRow objectRow = (ObjectRow) o; + Object[] fields = objectRow.getFields(); + for (int i = 0; i < fields.length; i++) { + RelDataTypeField relDataTypeField = currentResultType.getFieldList().get(i); + Object field = fields[i]; + Object formatResult; + if (field instanceof RowVertex) { + RowVertex vertex = (RowVertex) field; + ObjectRow vertexValue = (ObjectRow) vertex.getValue(); + Map properties = new HashMap<>(); + if (vertexValue != null) { + Object[] vertexProperties = vertexValue.getFields(); + int metaFieldCount = getMetaFieldCount(relDataTypeField.getType()); + List typeList = relDataTypeField.getType().getFieldList(); + // Find the correspond key in properties. + for (int j = 0; j < vertexProperties.length; j++) { + properties.put(typeList.get(j + metaFieldCount).getName(), vertexProperties[j]); } - - return jsonObject; + } + + formatResult = + new ViewVertex(String.valueOf(vertex.getId()), getLabel(vertex), properties); + vertices.add((ViewVertex) formatResult); + } else if (field instanceof RowEdge) { + RowEdge edge = (RowEdge) field; + ObjectRow edgeValue = (ObjectRow) edge.getValue(); + Map properties = new HashMap<>(); + if (edgeValue != null) { + Object[] edgeProperties = edgeValue.getFields(); + int metaFieldCount = getMetaFieldCount(relDataTypeField.getType()); + List typeList = relDataTypeField.getType().getFieldList(); + for (int j = 0; j < edgeProperties.length; j++) { + properties.put(typeList.get(j + metaFieldCount).getName(), edgeProperties[j]); + } + } + formatResult = + new ViewEdge( + String.valueOf(edge.getSrcId()), + String.valueOf(edge.getTargetId()), + getLabel(edge), + properties, + edge.getDirect().name()); + edges.add((ViewEdge) formatResult); } else { - return o.toString(); + formatResult = field.toString(); } - } + jsonObject.put(relDataTypeField.getKey(), formatResult); + } - private static String getLabel(IGraphElementWithLabelField field) { - try { - return field.getLabel(); - } catch (Exception e) { - LOGGER.warn("field {} get label error", field, e); - return null; - } + return jsonObject; + } else { + return o.toString(); } - - public static int getMetaFieldCount(RelDataType type) { - List fieldList = type.getFieldList(); - int count = 0; - for (RelDataTypeField relDataTypeField : fieldList) { - if (!(relDataTypeField.getType() instanceof MetaFieldType)) { - break; - } - count++; - } - return count; + } + + private static String getLabel(IGraphElementWithLabelField field) { + try { + return field.getLabel(); + } catch (Exception e) { + LOGGER.warn("field {} get label error", field, e); + return null; + } + } + + public static int getMetaFieldCount(RelDataType type) { + List fieldList = type.getFieldList(); + int count = 0; + for (RelDataTypeField relDataTypeField : fieldList) { + if (!(relDataTypeField.getType() instanceof MetaFieldType)) { + break; + } + count++; } + return count; + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/IQueryStatus.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/IQueryStatus.java index 551a06b7e..21c09e968 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/IQueryStatus.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/IQueryStatus.java @@ -18,18 +18,12 @@ * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. */ -/** - * This class is an adaptation of Presto's com.facebook.presto.client.QueryStatusInfo. - */ +/** This class is an adaptation of Presto's com.facebook.presto.client.QueryStatusInfo. */ public interface IQueryStatus { - /** - * Get query id. - */ - String getQueryId(); + /** Get query id. */ + String getQueryId(); - /** - * Get query error. - */ - QueryError getError(); + /** Get query error. */ + QueryError getError(); } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/QueryError.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/QueryError.java index 910a921f9..b0c948fb5 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/QueryError.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/QueryError.java @@ -24,71 +24,71 @@ import java.io.ObjectInput; import java.io.ObjectOutput; import java.util.Objects; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class QueryError implements Externalizable { - private static final int DEFAULT_ERROR_CODE = 0; - private static final String DELIMITER = ":"; - private int code; - private String name; - - public QueryError(String name, int code) { - if (code < 0) { - throw new GeaflowRuntimeException("query error code is negative"); - } - this.name = name; - this.code = code; - } + private static final int DEFAULT_ERROR_CODE = 0; + private static final String DELIMITER = ":"; + private int code; + private String name; - public QueryError(String name) { - this.name = name; - this.code = DEFAULT_ERROR_CODE; + public QueryError(String name, int code) { + if (code < 0) { + throw new GeaflowRuntimeException("query error code is negative"); } + this.name = name; + this.code = code; + } - public QueryError() { - } + public QueryError(String name) { + this.name = name; + this.code = DEFAULT_ERROR_CODE; + } - public int getCode() { - return code; - } + public QueryError() {} - public String getName() { - return name; - } + public int getCode() { + return code; + } - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null || getClass() != obj.getClass()) { - return false; - } - - QueryError that = (QueryError) obj; - return Objects.equals(this.code, that.code); - } + public String getName() { + return name; + } - @Override - public int hashCode() { - return Objects.hash(code); + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; } - - @Override - public String toString() { - return name + DELIMITER + code; + if (obj == null || getClass() != obj.getClass()) { + return false; } - @Override - public void writeExternal(ObjectOutput out) throws IOException { - out.writeInt(code); - out.writeObject(name); - } + QueryError that = (QueryError) obj; + return Objects.equals(this.code, that.code); + } - @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - code = in.readInt(); - name = (String) in.readObject(); - } + @Override + public int hashCode() { + return Objects.hash(code); + } + + @Override + public String toString() { + return name + DELIMITER + code; + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + out.writeInt(code); + out.writeObject(name); + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + code = in.readInt(); + name = (String) in.readObject(); + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/QueryIdGenerator.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/QueryIdGenerator.java index a4c720c77..3cf164a80 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/QueryIdGenerator.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/QueryIdGenerator.java @@ -17,81 +17,82 @@ import static com.google.common.base.Preconditions.checkState; import static java.util.concurrent.TimeUnit.MILLISECONDS; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableSet; -import com.google.common.primitives.Chars; -import com.google.common.util.concurrent.Uninterruptibles; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; + import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; -/** - * This class is an adaptation of Presto's com.facebook.presto.execution.QueryIdGenerator. - */ -public class QueryIdGenerator { +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableSet; +import com.google.common.primitives.Chars; +import com.google.common.util.concurrent.Uninterruptibles; - private static final char[] BASE_32 = { - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', - 'i', 'j', 'k', 'm', 'n', 'p', 'q', 'r', - 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', - '2', '3', '4', '5', '6', '7', '8', '9'}; +/** This class is an adaptation of Presto's com.facebook.presto.execution.QueryIdGenerator. */ +public class QueryIdGenerator { - static { - checkState(ImmutableSet.copyOf(Chars.asList(BASE_32)).size() == 32); - } + private static final char[] BASE_32 = { + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'm', 'n', 'p', 'q', 'r', + 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + '2', '3', '4', '5', '6', '7', '8', '9' + }; - private static final DateTimeFormatter TIMESTAMP_FORMAT = DateTimeFormat.forPattern( - "YYYYMMdd_HHmmss").withZoneUTC(); + static { + checkState(ImmutableSet.copyOf(Chars.asList(BASE_32)).size() == 32); + } - private static final long BASE_SYSTEM_TIME_MILLIS = System.currentTimeMillis(); - private static final long BASE_NANO_TIME = System.nanoTime(); - private static final int COUNT_LIMIT = 99_999; - private static final int ID_SIZE = 5; - private final String coordinatorId; - private long lastTimeInDays; + private static final DateTimeFormatter TIMESTAMP_FORMAT = + DateTimeFormat.forPattern("YYYYMMdd_HHmmss").withZoneUTC(); - private long lastTimeInSeconds; + private static final long BASE_SYSTEM_TIME_MILLIS = System.currentTimeMillis(); + private static final long BASE_NANO_TIME = System.nanoTime(); + private static final int COUNT_LIMIT = 99_999; + private static final int ID_SIZE = 5; + private final String coordinatorId; + private long lastTimeInDays; - private String lastTimestamp; + private long lastTimeInSeconds; - private int counter; + private String lastTimestamp; - public QueryIdGenerator() { - StringBuilder coordinatorId = new StringBuilder(ID_SIZE); - for (int i = 0; i < ID_SIZE; i++) { - coordinatorId.append(BASE_32[ThreadLocalRandom.current().nextInt(BASE_32.length)]); - } - this.coordinatorId = coordinatorId.toString(); - } + private int counter; - public String getCoordinatorId() { - return coordinatorId; + public QueryIdGenerator() { + StringBuilder coordinatorId = new StringBuilder(ID_SIZE); + for (int i = 0; i < ID_SIZE; i++) { + coordinatorId.append(BASE_32[ThreadLocalRandom.current().nextInt(BASE_32.length)]); } - - public synchronized String createQueryId() { - if (counter > COUNT_LIMIT) { - while (MILLISECONDS.toSeconds(nowInMillis()) == lastTimeInSeconds) { - Uninterruptibles.sleepUninterruptibly(1, TimeUnit.SECONDS); - } - counter = 0; - } - - long now = nowInMillis(); - if (MILLISECONDS.toSeconds(now) != lastTimeInSeconds) { - lastTimeInSeconds = MILLISECONDS.toSeconds(now); - lastTimestamp = TIMESTAMP_FORMAT.print(now); - if (MILLISECONDS.toDays(now) != lastTimeInDays) { - lastTimeInDays = MILLISECONDS.toDays(now); - counter = 0; - } - } - return String.format("%s_%05d_%s", lastTimestamp, counter++, coordinatorId); + this.coordinatorId = coordinatorId.toString(); + } + + public String getCoordinatorId() { + return coordinatorId; + } + + public synchronized String createQueryId() { + if (counter > COUNT_LIMIT) { + while (MILLISECONDS.toSeconds(nowInMillis()) == lastTimeInSeconds) { + Uninterruptibles.sleepUninterruptibly(1, TimeUnit.SECONDS); + } + counter = 0; } - @VisibleForTesting - protected long nowInMillis() { - return BASE_SYSTEM_TIME_MILLIS + TimeUnit.NANOSECONDS.toMillis( - System.nanoTime() - BASE_NANO_TIME); + long now = nowInMillis(); + if (MILLISECONDS.toSeconds(now) != lastTimeInSeconds) { + lastTimeInSeconds = MILLISECONDS.toSeconds(now); + lastTimestamp = TIMESTAMP_FORMAT.print(now); + if (MILLISECONDS.toDays(now) != lastTimeInDays) { + lastTimeInDays = MILLISECONDS.toDays(now); + counter = 0; + } } + return String.format("%s_%05d_%s", lastTimestamp, counter++, coordinatorId); + } + + @VisibleForTesting + protected long nowInMillis() { + return BASE_SYSTEM_TIME_MILLIS + + TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - BASE_NANO_TIME); + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/QueryInfo.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/QueryInfo.java index 40c0ac841..fe69eaeb8 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/QueryInfo.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/QueryInfo.java @@ -23,35 +23,35 @@ public class QueryInfo { - private final String queryId; + private final String queryId; - private final String queryScript; + private final String queryScript; - private RelDataType scriptSchema; + private RelDataType scriptSchema; - public QueryInfo(String queryId, String queryScript) { - this.queryId = queryId; - this.queryScript = queryScript; - } + public QueryInfo(String queryId, String queryScript) { + this.queryId = queryId; + this.queryScript = queryScript; + } - public RelDataType getScriptSchema() { - return scriptSchema; - } + public RelDataType getScriptSchema() { + return scriptSchema; + } - public void setScriptSchema(RelDataType scriptSchema) { - this.scriptSchema = scriptSchema; - } + public void setScriptSchema(RelDataType scriptSchema) { + this.scriptSchema = scriptSchema; + } - public String getQueryId() { - return queryId; - } + public String getQueryId() { + return queryId; + } - public String getQueryScript() { - return queryScript; - } + public String getQueryScript() { + return queryScript; + } - @Override - public String toString() { - return "queryId: " + queryId + " queryScript: " + queryScript; - } + @Override + public String toString() { + return "queryId: " + queryId + " queryScript: " + queryScript; + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/QueryResults.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/QueryResults.java index 437a24f90..25cdd6d71 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/QueryResults.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/QueryResults.java @@ -22,142 +22,143 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; + import org.apache.calcite.rel.type.RelDataType; import org.apache.geaflow.cluster.response.ResponseResult; import org.apache.geaflow.dsl.common.data.impl.ObjectRow; -/** - * This class is an adaptation of Presto's com.facebook.presto.client.QueryResults. - */ +/** This class is an adaptation of Presto's com.facebook.presto.client.QueryResults. */ public class QueryResults implements IQueryStatus, Externalizable { - private static final String DEFAULT_QUERY_ID = "0"; - private static final String SEPARATOR = ":"; - private String queryId; - private List> queryOriginData; - private List> rawData; - private QueryError error; - private boolean queryStatus; - private RelDataType resultMeta; - - public QueryResults() { - - } - - public QueryResults(String queryId, List> data, QueryError error, boolean queryStatus) { - this.queryId = queryId; - this.queryOriginData = data; - this.rawData = getRawData(); - this.error = error; - this.queryStatus = queryStatus; - } - - public RelDataType getResultMeta() { - return resultMeta; - } - - public void setResultMeta(RelDataType resultMeta) { - this.resultMeta = resultMeta; - } - - public QueryResults(String queryId, List> data) { - this(queryId, data, null, true); - } - - public QueryResults(String queryId, QueryError error) { - this(queryId, null, error, false); - } - - public QueryResults(QueryError error) { - this(DEFAULT_QUERY_ID, null, error, false); + private static final String DEFAULT_QUERY_ID = "0"; + private static final String SEPARATOR = ":"; + private String queryId; + private List> queryOriginData; + private List> rawData; + private QueryError error; + private boolean queryStatus; + private RelDataType resultMeta; + + public QueryResults() {} + + public QueryResults( + String queryId, List> data, QueryError error, boolean queryStatus) { + this.queryId = queryId; + this.queryOriginData = data; + this.rawData = getRawData(); + this.error = error; + this.queryStatus = queryStatus; + } + + public RelDataType getResultMeta() { + return resultMeta; + } + + public void setResultMeta(RelDataType resultMeta) { + this.resultMeta = resultMeta; + } + + public QueryResults(String queryId, List> data) { + this(queryId, data, null, true); + } + + public QueryResults(String queryId, QueryError error) { + this(queryId, null, error, false); + } + + public QueryResults(QueryError error) { + this(DEFAULT_QUERY_ID, null, error, false); + } + + @Override + public String getQueryId() { + return queryId; + } + + public String getFormattedData() { + if (this.queryOriginData != null) { + return DefaultResultSetFormatUtils.formatResult(this.queryOriginData, this.resultMeta); } + return null; + } - @Override - public String getQueryId() { - return queryId; + public List> getRawData() { + List> result = new ArrayList<>(); + if (queryOriginData == null) { + return result; } - - - public String getFormattedData() { - if (this.queryOriginData != null) { - return DefaultResultSetFormatUtils.formatResult(this.queryOriginData, this.resultMeta); + for (List responseResults : queryOriginData) { + for (ResponseResult responseResult : responseResults) { + for (Object response : responseResult.getResponse()) { + if (response == null) { + continue; + } + if (response instanceof ObjectRow) { + ObjectRow objectRow = (ObjectRow) response; + Object[] fields = objectRow.getFields(); + result.add(Arrays.asList(fields)); + } else { + result.add(Collections.singletonList(response)); + } } - return null; + } } - - public List> getRawData() { - List> result = new ArrayList<>(); - if (queryOriginData == null) { - return result; + return result; + } + + public QueryError getError() { + return error; + } + + public boolean getQueryStatus() { + return queryStatus; + } + + @Override + public String toString() { + return this.queryId + + SEPARATOR + + this.getQueryStatus() + + SEPARATOR + + (this.getQueryStatus() ? this.rawData : this.error.toString()); + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(queryId); + out.writeBoolean(queryStatus); + if (queryStatus) { + out.writeInt(this.rawData.size()); + for (List rawDatum : rawData) { + out.writeInt(rawDatum.size()); + for (Object data : rawDatum) { + out.writeObject(data); } - for (List responseResults : queryOriginData) { - for (ResponseResult responseResult : responseResults) { - for (Object response : responseResult.getResponse()) { - if (response == null) { - continue; - } - if (response instanceof ObjectRow) { - ObjectRow objectRow = (ObjectRow) response; - Object[] fields = objectRow.getFields(); - result.add(Arrays.asList(fields)); - } else { - result.add(Collections.singletonList(response)); - } - } - } - } - return result; - } - - public QueryError getError() { - return error; + } + out.writeObject(resultMeta); + } else { + out.writeObject(error); } - - public boolean getQueryStatus() { - return queryStatus; - } - - @Override - public String toString() { - return this.queryId + SEPARATOR + this.getQueryStatus() + SEPARATOR + (this.getQueryStatus() ? this.rawData : this.error.toString()); - } - - @Override - public void writeExternal(ObjectOutput out) throws IOException { - out.writeObject(queryId); - out.writeBoolean(queryStatus); - if (queryStatus) { - out.writeInt(this.rawData.size()); - for (List rawDatum : rawData) { - out.writeInt(rawDatum.size()); - for (Object data : rawDatum) { - out.writeObject(data); - } - } - out.writeObject(resultMeta); - } else { - out.writeObject(error); - } - } - - @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - queryId = (String) in.readObject(); - queryStatus = in.readBoolean(); - if (queryStatus) { - int dataSize = in.readInt(); - List> rawData = new ArrayList<>(dataSize); - for (int i = 0; i < dataSize; i++) { - int datumSize = in.readInt(); - List datum = new ArrayList<>(); - for (int j = 0; j < datumSize; j++) { - datum.add(in.readObject()); - } - rawData.add(datum); - } - this.rawData = rawData; - resultMeta = (RelDataType) in.readObject(); - } else { - error = (QueryError) in.readObject(); + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + queryId = (String) in.readObject(); + queryStatus = in.readBoolean(); + if (queryStatus) { + int dataSize = in.readInt(); + List> rawData = new ArrayList<>(dataSize); + for (int i = 0; i < dataSize; i++) { + int datumSize = in.readInt(); + List datum = new ArrayList<>(); + for (int j = 0; j < datumSize; j++) { + datum.add(in.readObject()); } + rawData.add(datum); + } + this.rawData = rawData; + resultMeta = (RelDataType) in.readObject(); + } else { + error = (QueryError) in.readObject(); } + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/StandardError.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/StandardError.java index b172eb4f5..f5d97daaa 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/StandardError.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/StandardError.java @@ -21,43 +21,31 @@ public enum StandardError { - /** - * Analytics server un available. - */ - ANALYTICS_SERVER_UNAVAILABLE(1001), - - /** - * Analytics server busy. - */ - ANALYTICS_SERVER_BUSY(1002), - - /** - * Analytics no available coordinator. - */ - ANALYTICS_NO_COORDINATOR(1003), - - /** - * Analytics query result is null. - */ - ANALYTICS_NULL_RESULT(1004), - - /** - * Analytics rpc error. - */ - ANALYTICS_RPC_ERROR(1005), - - /** - * Analytics query result is too long. - */ - ANALYTICS_RESULT_TO_LONG(1006); - - private final QueryError errorCode; - - StandardError(int code) { - errorCode = new QueryError(name(), code); - } - - public QueryError getQueryError() { - return errorCode; - } + /** Analytics server un available. */ + ANALYTICS_SERVER_UNAVAILABLE(1001), + + /** Analytics server busy. */ + ANALYTICS_SERVER_BUSY(1002), + + /** Analytics no available coordinator. */ + ANALYTICS_NO_COORDINATOR(1003), + + /** Analytics query result is null. */ + ANALYTICS_NULL_RESULT(1004), + + /** Analytics rpc error. */ + ANALYTICS_RPC_ERROR(1005), + + /** Analytics query result is too long. */ + ANALYTICS_RESULT_TO_LONG(1006); + + private final QueryError errorCode; + + StandardError(int code) { + errorCode = new QueryError(name(), code); + } + + public QueryError getQueryError() { + return errorCode; + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/ViewEdge.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/ViewEdge.java index 6143fb9a3..b72cda3da 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/ViewEdge.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/ViewEdge.java @@ -23,39 +23,42 @@ public class ViewEdge { - private final String source; - private final String target; - private final String label; - private final String direction; - private final Map properties; + private final String source; + private final String target; + private final String label; + private final String direction; + private final Map properties; - public ViewEdge(String source, String target, String label, Map properties, String direction) { - this.source = source; - this.target = target; - this.label = label; - this.properties = properties; - this.direction = direction; - } + public ViewEdge( + String source, + String target, + String label, + Map properties, + String direction) { + this.source = source; + this.target = target; + this.label = label; + this.properties = properties; + this.direction = direction; + } - public String getSource() { - return source; - } + public String getSource() { + return source; + } - public String getTarget() { - return target; - } + public String getTarget() { + return target; + } - public String getLabel() { - return label; - } + public String getLabel() { + return label; + } + public String getDirection() { + return direction; + } - public String getDirection() { - return direction; - } - - public Map getProperties() { - return properties; - } - + public Map getProperties() { + return properties; + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/ViewVertex.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/ViewVertex.java index f3c869586..ca385d266 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/ViewVertex.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/main/java/org/apache/geaflow/analytics/service/query/ViewVertex.java @@ -23,26 +23,25 @@ public class ViewVertex { - private final String id; - private final String label; - private final Map properties; - - public ViewVertex(String identifier, String label, Map properties) { - this.id = identifier; - this.label = label; - this.properties = properties; - } - - public String getId() { - return id; - } - - public String getLabel() { - return label; - } - - public Map getProperties() { - return properties; - } - + private final String id; + private final String label; + private final Map properties; + + public ViewVertex(String identifier, String label, Map properties) { + this.id = identifier; + this.label = label; + this.properties = properties; + } + + public String getId() { + return id; + } + + public String getLabel() { + return label; + } + + public Map getProperties() { + return properties; + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/test/java/org/apache/geaflow/analytics/service/query/QueryIdGeneratorTest.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/test/java/org/apache/geaflow/analytics/service/query/QueryIdGeneratorTest.java index 663f934bc..3ba9f945f 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/test/java/org/apache/geaflow/analytics/service/query/QueryIdGeneratorTest.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/test/java/org/apache/geaflow/analytics/service/query/QueryIdGeneratorTest.java @@ -26,47 +26,55 @@ public class QueryIdGeneratorTest { - @Test - public void testCreateNextQueryId() { - QueryIdGeneratorOnlyTest queryIdGenerator = new QueryIdGeneratorOnlyTest(); - long currentMillis = new DateTime(2023, 8, 12, 1, 2, 3, 4, DateTimeZone.UTC).getMillis(); - queryIdGenerator.setCurrentTime(currentMillis); + @Test + public void testCreateNextQueryId() { + QueryIdGeneratorOnlyTest queryIdGenerator = new QueryIdGeneratorOnlyTest(); + long currentMillis = new DateTime(2023, 8, 12, 1, 2, 3, 4, DateTimeZone.UTC).getMillis(); + queryIdGenerator.setCurrentTime(currentMillis); - // Generate ids to 99,999 - for (int i = 0; i < 100_000; i++) { - assertEquals(queryIdGenerator.createQueryId(), String.format("20230812_010203_%05d_%s", i, queryIdGenerator.getCoordinatorId())); - } + // Generate ids to 99,999 + for (int i = 0; i < 100_000; i++) { + assertEquals( + queryIdGenerator.createQueryId(), + String.format("20230812_010203_%05d_%s", i, queryIdGenerator.getCoordinatorId())); + } - currentMillis += 1000; - queryIdGenerator.setCurrentTime(currentMillis); - for (int i = 0; i < 100_000; i++) { - assertEquals(queryIdGenerator.createQueryId(), String.format("20230812_010204_%05d_%s", i, queryIdGenerator.getCoordinatorId())); - } + currentMillis += 1000; + queryIdGenerator.setCurrentTime(currentMillis); + for (int i = 0; i < 100_000; i++) { + assertEquals( + queryIdGenerator.createQueryId(), + String.format("20230812_010204_%05d_%s", i, queryIdGenerator.getCoordinatorId())); + } - currentMillis += 1000; - queryIdGenerator.setCurrentTime(currentMillis); - for (int i = 0; i < 100; i++) { - assertEquals(queryIdGenerator.createQueryId(), String.format("20230812_010205_%05d_%s", i, queryIdGenerator.getCoordinatorId())); - } + currentMillis += 1000; + queryIdGenerator.setCurrentTime(currentMillis); + for (int i = 0; i < 100; i++) { + assertEquals( + queryIdGenerator.createQueryId(), + String.format("20230812_010205_%05d_%s", i, queryIdGenerator.getCoordinatorId())); + } - currentMillis = new DateTime(2023, 8, 13, 0, 0, 0, 0, DateTimeZone.UTC).getMillis(); - queryIdGenerator.setCurrentTime(currentMillis); - for (int i = 0; i < 100_000; i++) { - assertEquals(queryIdGenerator.createQueryId(), String.format("20230813_000000_%05d_%s", i, queryIdGenerator.getCoordinatorId())); - } + currentMillis = new DateTime(2023, 8, 13, 0, 0, 0, 0, DateTimeZone.UTC).getMillis(); + queryIdGenerator.setCurrentTime(currentMillis); + for (int i = 0; i < 100_000; i++) { + assertEquals( + queryIdGenerator.createQueryId(), + String.format("20230813_000000_%05d_%s", i, queryIdGenerator.getCoordinatorId())); } + } - private static class QueryIdGeneratorOnlyTest extends QueryIdGenerator { + private static class QueryIdGeneratorOnlyTest extends QueryIdGenerator { - private long currentTime; + private long currentTime; - public void setCurrentTime(long currentTime) { - this.currentTime = currentTime; - } + public void setCurrentTime(long currentTime) { + this.currentTime = currentTime; + } - @Override - protected long nowInMillis() { - return currentTime; - } + @Override + protected long nowInMillis() { + return currentTime; } + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/test/java/org/apache/geaflow/analytics/service/query/QueryResultTest.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/test/java/org/apache/geaflow/analytics/service/query/QueryResultTest.java index 5fd79e5e8..9e0412d65 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/test/java/org/apache/geaflow/analytics/service/query/QueryResultTest.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-common/src/test/java/org/apache/geaflow/analytics/service/query/QueryResultTest.java @@ -27,6 +27,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; + import org.apache.calcite.rel.type.RelRecordType; import org.apache.geaflow.cluster.response.ResponseResult; import org.apache.geaflow.common.serialize.SerializerFactory; @@ -36,111 +37,112 @@ public class QueryResultTest { - private static final String ERROR_MSG = "server execute failed"; - - @Test - public void testQueryError() { - QueryError queryError = new QueryError(ERROR_MSG); - Assert.assertEquals(queryError.getCode(), 0); - Assert.assertEquals(queryError.getName(), ERROR_MSG); - queryError = new QueryError(ERROR_MSG, 1); - Assert.assertEquals(queryError.getCode(), 1); - Assert.assertEquals(queryError.getName(), ERROR_MSG); - } + private static final String ERROR_MSG = "server execute failed"; - @Test - public void testQueryResult() { - String queryId = "1"; - QueryResults queryResults = new QueryResults(queryId, new QueryError(ERROR_MSG)); - Assert.assertNull(queryResults.getFormattedData()); - Assert.assertEquals(queryResults.getQueryId(), queryId); - Assert.assertEquals(queryResults.getError().getName(), ERROR_MSG); - queryResults = new QueryResults(queryId, new QueryError(ERROR_MSG)); - Assert.assertEquals(queryResults.getQueryId(), queryId); - Assert.assertEquals(queryResults.getError().getName(), ERROR_MSG); - } + @Test + public void testQueryError() { + QueryError queryError = new QueryError(ERROR_MSG); + Assert.assertEquals(queryError.getCode(), 0); + Assert.assertEquals(queryError.getName(), ERROR_MSG); + queryError = new QueryError(ERROR_MSG, 1); + Assert.assertEquals(queryError.getCode(), 1); + Assert.assertEquals(queryError.getName(), ERROR_MSG); + } - @Test - public void testQueryStatus() { - String queryId = "1"; - List> result = new ArrayList<>(); - ArrayList responseResults = new ArrayList<>(); - responseResults.add(new ResponseResult(1, OutputType.RESPONSE, - Collections.singletonList("result"))); - result.add(responseResults); - QueryResults queryResults = new QueryResults(queryId, new QueryError(ERROR_MSG)); - Assert.assertNull(queryResults.getFormattedData()); - Assert.assertEquals(queryResults.getQueryId(), queryId); - Assert.assertEquals(queryResults.getError().getName(), ERROR_MSG); - Assert.assertFalse(queryResults.getQueryStatus()); + @Test + public void testQueryResult() { + String queryId = "1"; + QueryResults queryResults = new QueryResults(queryId, new QueryError(ERROR_MSG)); + Assert.assertNull(queryResults.getFormattedData()); + Assert.assertEquals(queryResults.getQueryId(), queryId); + Assert.assertEquals(queryResults.getError().getName(), ERROR_MSG); + queryResults = new QueryResults(queryId, new QueryError(ERROR_MSG)); + Assert.assertEquals(queryResults.getQueryId(), queryId); + Assert.assertEquals(queryResults.getError().getName(), ERROR_MSG); + } - queryResults = new QueryResults(queryId, result); + @Test + public void testQueryStatus() { + String queryId = "1"; + List> result = new ArrayList<>(); + ArrayList responseResults = new ArrayList<>(); + responseResults.add( + new ResponseResult(1, OutputType.RESPONSE, Collections.singletonList("result"))); + result.add(responseResults); + QueryResults queryResults = new QueryResults(queryId, new QueryError(ERROR_MSG)); + Assert.assertNull(queryResults.getFormattedData()); + Assert.assertEquals(queryResults.getQueryId(), queryId); + Assert.assertEquals(queryResults.getError().getName(), ERROR_MSG); + Assert.assertFalse(queryResults.getQueryStatus()); - Assert.assertEquals(queryResults.getQueryId(), queryId); - Assert.assertTrue(queryResults.getQueryStatus()); - } + queryResults = new QueryResults(queryId, result); - @Test - public void testQueryResultWriteAndReadObjectWithData() throws IOException, ClassNotFoundException { - String queryId = "1"; - List> result = new ArrayList<>(); - ArrayList responseResults = new ArrayList<>(); - responseResults.add(new ResponseResult(1, OutputType.RESPONSE, - Collections.singletonList("result"))); - result.add(responseResults); - QueryResults queryResults = new QueryResults(queryId, result); - queryResults.setResultMeta(new RelRecordType(new ArrayList<>())); - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - ObjectOutputStream objectOutputStream = new ObjectOutputStream(outputStream); - objectOutputStream.writeObject(queryResults); + Assert.assertEquals(queryResults.getQueryId(), queryId); + Assert.assertTrue(queryResults.getQueryStatus()); + } - ByteArrayInputStream inputStream = new ByteArrayInputStream(outputStream.toByteArray()); - ObjectInputStream objectInputStream = new ObjectInputStream(inputStream); - QueryResults deserializedQueryResults = (QueryResults) objectInputStream.readObject(); - Assert.assertEquals(deserializedQueryResults.getQueryId(), queryId); - Assert.assertTrue(deserializedQueryResults.getQueryStatus()); - Assert.assertTrue(deserializedQueryResults.getResultMeta().getFieldNames().isEmpty()); - outputStream.close(); - inputStream.close(); - } + @Test + public void testQueryResultWriteAndReadObjectWithData() + throws IOException, ClassNotFoundException { + String queryId = "1"; + List> result = new ArrayList<>(); + ArrayList responseResults = new ArrayList<>(); + responseResults.add( + new ResponseResult(1, OutputType.RESPONSE, Collections.singletonList("result"))); + result.add(responseResults); + QueryResults queryResults = new QueryResults(queryId, result); + queryResults.setResultMeta(new RelRecordType(new ArrayList<>())); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + ObjectOutputStream objectOutputStream = new ObjectOutputStream(outputStream); + objectOutputStream.writeObject(queryResults); - @Test - public void testQueryResultWriteAndReadObjectWithQueryError() throws IOException, - ClassNotFoundException { - String queryId = "1"; - QueryResults queryResults = new QueryResults(queryId, new QueryError(ERROR_MSG)); - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - ObjectOutputStream objectOutputStream = new ObjectOutputStream(outputStream); - objectOutputStream.writeObject(queryResults); + ByteArrayInputStream inputStream = new ByteArrayInputStream(outputStream.toByteArray()); + ObjectInputStream objectInputStream = new ObjectInputStream(inputStream); + QueryResults deserializedQueryResults = (QueryResults) objectInputStream.readObject(); + Assert.assertEquals(deserializedQueryResults.getQueryId(), queryId); + Assert.assertTrue(deserializedQueryResults.getQueryStatus()); + Assert.assertTrue(deserializedQueryResults.getResultMeta().getFieldNames().isEmpty()); + outputStream.close(); + inputStream.close(); + } - ByteArrayInputStream inputStream = new ByteArrayInputStream(outputStream.toByteArray()); - ObjectInputStream objectInputStream = new ObjectInputStream(inputStream); - QueryResults deserializedQueryResults = (QueryResults) objectInputStream.readObject(); - Assert.assertEquals(deserializedQueryResults.getQueryId(), queryId); - Assert.assertFalse(deserializedQueryResults.getQueryStatus()); - Assert.assertEquals(deserializedQueryResults.getError().getName(), ERROR_MSG); - outputStream.close(); - inputStream.close(); - } + @Test + public void testQueryResultWriteAndReadObjectWithQueryError() + throws IOException, ClassNotFoundException { + String queryId = "1"; + QueryResults queryResults = new QueryResults(queryId, new QueryError(ERROR_MSG)); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + ObjectOutputStream objectOutputStream = new ObjectOutputStream(outputStream); + objectOutputStream.writeObject(queryResults); - @Test - public void testQueryResultWriteAndReadObjectWithKryo() { - String queryId = "1"; - List> result = new ArrayList<>(); - ArrayList responseResults = new ArrayList<>(); - responseResults.add(new ResponseResult(1, OutputType.RESPONSE, - Collections.singletonList("result"))); - result.add(responseResults); - QueryResults queryResults = new QueryResults(queryId, result); - queryResults.setResultMeta(new RelRecordType(new ArrayList<>())); - byte[] serialize = SerializerFactory.getKryoSerializer().serialize(queryResults); - QueryResults deserializedQueryResults = (QueryResults) SerializerFactory.getKryoSerializer().deserialize(serialize); - Assert.assertEquals(deserializedQueryResults.getQueryId(), queryId); - Assert.assertTrue(deserializedQueryResults.getQueryStatus()); - String expectResult = "{\"viewResult\":{\"nodes\":[],\"edges\":[]}," - + "\"jsonResult\":[\"result\"]}"; - Assert.assertEquals(deserializedQueryResults.getFormattedData(), expectResult); - Assert.assertTrue(deserializedQueryResults.getResultMeta().getFieldNames().isEmpty()); - } + ByteArrayInputStream inputStream = new ByteArrayInputStream(outputStream.toByteArray()); + ObjectInputStream objectInputStream = new ObjectInputStream(inputStream); + QueryResults deserializedQueryResults = (QueryResults) objectInputStream.readObject(); + Assert.assertEquals(deserializedQueryResults.getQueryId(), queryId); + Assert.assertFalse(deserializedQueryResults.getQueryStatus()); + Assert.assertEquals(deserializedQueryResults.getError().getName(), ERROR_MSG); + outputStream.close(); + inputStream.close(); + } + @Test + public void testQueryResultWriteAndReadObjectWithKryo() { + String queryId = "1"; + List> result = new ArrayList<>(); + ArrayList responseResults = new ArrayList<>(); + responseResults.add( + new ResponseResult(1, OutputType.RESPONSE, Collections.singletonList("result"))); + result.add(responseResults); + QueryResults queryResults = new QueryResults(queryId, result); + queryResults.setResultMeta(new RelRecordType(new ArrayList<>())); + byte[] serialize = SerializerFactory.getKryoSerializer().serialize(queryResults); + QueryResults deserializedQueryResults = + (QueryResults) SerializerFactory.getKryoSerializer().deserialize(serialize); + Assert.assertEquals(deserializedQueryResults.getQueryId(), queryId); + Assert.assertTrue(deserializedQueryResults.getQueryStatus()); + String expectResult = + "{\"viewResult\":{\"nodes\":[],\"edges\":[]}," + "\"jsonResult\":[\"result\"]}"; + Assert.assertEquals(deserializedQueryResults.getFormattedData(), expectResult); + Assert.assertTrue(deserializedQueryResults.getResultMeta().getFieldNames().isEmpty()); + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-server/src/main/java/org/apache/geaflow/analytics/service/server/AbstractAnalyticsServiceServer.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-server/src/main/java/org/apache/geaflow/analytics/service/server/AbstractAnalyticsServiceServer.java index 6925aca29..d11402baf 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-server/src/main/java/org/apache/geaflow/analytics/service/server/AbstractAnalyticsServiceServer.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-server/src/main/java/org/apache/geaflow/analytics/service/server/AbstractAnalyticsServiceServer.java @@ -24,7 +24,6 @@ import static org.apache.geaflow.analytics.service.config.AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_REGISTER_ENABLE; import static org.apache.geaflow.common.config.keys.DSLConfigKeys.GEAFLOW_DSL_COMPILE_PHYSICAL_PLAN_ENABLE; -import com.google.common.base.Preconditions; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.ExecutorService; @@ -32,6 +31,7 @@ import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.Semaphore; + import org.apache.calcite.rel.type.RelDataType; import org.apache.commons.lang3.exception.ExceptionUtils; import org.apache.geaflow.analytics.service.config.AnalyticsServiceConfigKeys; @@ -63,166 +63,186 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; + public abstract class AbstractAnalyticsServiceServer implements IServiceServer { - private static final Logger LOGGER = LoggerFactory.getLogger( - AbstractAnalyticsServiceServer.class); - - private static final String ANALYTICS_SERVICE_PREFIX = "analytics-service-"; - private static final String SERVER_EXECUTOR = "server-executor"; - - protected int port; - protected int maxRequests; - protected boolean running; - protected PipelineService pipelineService; - protected PipelineServiceExecutorContext serviceExecutorContext; - protected BlockingQueue requestBlockingQueue; - protected BlockingMap> responseBlockingMap; - protected BlockingQueue cancelRequestBlockingQueue; - protected BlockingQueue cancelResponseBlockingQueue; - protected MetaServerClient metaServerClient; - protected Semaphore semaphore; - private ExecutorService executorService; - - protected Configuration configuration; - protected boolean enableCompileSchema; - - @Override - public void init(IPipelineServiceExecutorContext context) { - this.serviceExecutorContext = (PipelineServiceExecutorContext) context; - this.pipelineService = this.serviceExecutorContext.getPipelineService(); - this.configuration = context.getConfiguration(); - this.port = configuration.getInteger(AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_PORT); - this.maxRequests = configuration.getInteger( - AnalyticsServiceConfigKeys.MAX_REQUEST_PER_SERVER); - this.requestBlockingQueue = new LinkedBlockingQueue<>(maxRequests); - this.responseBlockingMap = new BlockingMap<>(); - this.cancelRequestBlockingQueue = new LinkedBlockingQueue<>(maxRequests); - this.cancelResponseBlockingQueue = new LinkedBlockingQueue<>(maxRequests); - this.semaphore = new Semaphore(maxRequests); - this.executorService = Executors.newFixedThreadPool(this.maxRequests, - ThreadUtil.namedThreadFactory(true, SERVER_EXECUTOR, - new ComponentUncaughtExceptionHandler())); - this.enableCompileSchema = configuration.getBoolean(ANALYTICS_COMPILE_SCHEMA_ENABLE); + private static final Logger LOGGER = + LoggerFactory.getLogger(AbstractAnalyticsServiceServer.class); + + private static final String ANALYTICS_SERVICE_PREFIX = "analytics-service-"; + private static final String SERVER_EXECUTOR = "server-executor"; + + protected int port; + protected int maxRequests; + protected boolean running; + protected PipelineService pipelineService; + protected PipelineServiceExecutorContext serviceExecutorContext; + protected BlockingQueue requestBlockingQueue; + protected BlockingMap> responseBlockingMap; + protected BlockingQueue cancelRequestBlockingQueue; + protected BlockingQueue cancelResponseBlockingQueue; + protected MetaServerClient metaServerClient; + protected Semaphore semaphore; + private ExecutorService executorService; + + protected Configuration configuration; + protected boolean enableCompileSchema; + + @Override + public void init(IPipelineServiceExecutorContext context) { + this.serviceExecutorContext = (PipelineServiceExecutorContext) context; + this.pipelineService = this.serviceExecutorContext.getPipelineService(); + this.configuration = context.getConfiguration(); + this.port = configuration.getInteger(AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_PORT); + this.maxRequests = configuration.getInteger(AnalyticsServiceConfigKeys.MAX_REQUEST_PER_SERVER); + this.requestBlockingQueue = new LinkedBlockingQueue<>(maxRequests); + this.responseBlockingMap = new BlockingMap<>(); + this.cancelRequestBlockingQueue = new LinkedBlockingQueue<>(maxRequests); + this.cancelResponseBlockingQueue = new LinkedBlockingQueue<>(maxRequests); + this.semaphore = new Semaphore(maxRequests); + this.executorService = + Executors.newFixedThreadPool( + this.maxRequests, + ThreadUtil.namedThreadFactory( + true, SERVER_EXECUTOR, new ComponentUncaughtExceptionHandler())); + this.enableCompileSchema = configuration.getBoolean(ANALYTICS_COMPILE_SCHEMA_ENABLE); + } + + @Override + public void stopServer() { + this.running = false; + if (this.metaServerClient != null) { + this.metaServerClient.close(); } - - @Override - public void stopServer() { - this.running = false; - if (this.metaServerClient != null) { - this.metaServerClient.close(); - } + } + + public static QueryResults getQueryResults( + QueryInfo queryInfo, BlockingMap> responseBlockingMap) + throws Exception { + Future resultFuture = responseBlockingMap.get(queryInfo.getQueryId()); + IExecutionResult result = resultFuture.get(); + QueryResults queryResults; + String queryId = queryInfo.getQueryId(); + if (result.isSuccess()) { + List> responseResult = (List>) result.getResult(); + queryResults = new QueryResults(queryId, responseResult); + } else { + String errorMsg = result.getError().toString(); + queryResults = new QueryResults(queryId, new QueryError(errorMsg)); } - - public static QueryResults getQueryResults(QueryInfo queryInfo, - BlockingMap> responseBlockingMap) throws Exception { - Future resultFuture = responseBlockingMap.get(queryInfo.getQueryId()); - IExecutionResult result = resultFuture.get(); - QueryResults queryResults; - String queryId = queryInfo.getQueryId(); - if (result.isSuccess()) { - List> responseResult = (List>) result.getResult(); - queryResults = new QueryResults(queryId, responseResult); - } else { - String errorMsg = result.getError().toString(); - queryResults = new QueryResults(queryId, new QueryError(errorMsg)); + queryResults.setResultMeta(queryInfo.getScriptSchema()); + return queryResults; + } + + protected void waitForExecuted() { + registerServiceInfo(); + + while (this.running) { + try { + QueryInfo queryInfo = requestBlockingQueue.take(); + final String queryScript = queryInfo.getQueryScript(); + final String queryId = queryInfo.getQueryId(); + + if (enableCompileSchema) { + try { + CompileResult compileResult = compileQuerySchema(queryScript, configuration); + RelDataType relDataType = compileResult.getCurrentResultType(); + queryInfo.setScriptSchema(relDataType); + } catch (Throwable e) { + // Set error code if precompile failed. + LOGGER.error("precompile query: {} failed", queryInfo, e); + Future future = + executorService.submit( + () -> + new ExecutionResult( + queryId, new QueryError(ExceptionUtils.getStackTrace(e)), false)); + responseBlockingMap.put(queryId, future); + continue; + } } - queryResults.setResultMeta(queryInfo.getScriptSchema()); - return queryResults; - } - protected void waitForExecuted() { - registerServiceInfo(); - - while (this.running) { - try { - QueryInfo queryInfo = requestBlockingQueue.take(); - final String queryScript = queryInfo.getQueryScript(); - final String queryId = queryInfo.getQueryId(); - - if (enableCompileSchema) { - try { - CompileResult compileResult = compileQuerySchema(queryScript, configuration); - RelDataType relDataType = compileResult.getCurrentResultType(); - queryInfo.setScriptSchema(relDataType); - } catch (Throwable e) { - // Set error code if precompile failed. - LOGGER.error("precompile query: {} failed", queryInfo, e); - Future future = executorService.submit( - () -> new ExecutionResult(queryId, new QueryError(ExceptionUtils.getStackTrace(e)), false)); - responseBlockingMap.put(queryId, future); - continue; - } - } - - Future future = executorService.submit(() -> { - try { - return executeQuery(queryScript); - } catch (Throwable e) { - LOGGER.error("execute query: {} failed", queryInfo, e); - return new ExecutionResult(queryId, new QueryError(ExceptionUtils.getStackTrace(e)), false); - } + Future future = + executorService.submit( + () -> { + try { + return executeQuery(queryScript); + } catch (Throwable e) { + LOGGER.error("execute query: {} failed", queryInfo, e); + return new ExecutionResult( + queryId, new QueryError(ExceptionUtils.getStackTrace(e)), false); + } }); - responseBlockingMap.put(queryId, future); - } catch (Throwable t) { - if (this.running) { - LOGGER.error("analytics service abnormal {}", t.getMessage(), t); - } - } + responseBlockingMap.put(queryId, future); + } catch (Throwable t) { + if (this.running) { + LOGGER.error("analytics service abnormal {}", t.getMessage(), t); } + } } - - protected static CompileResult compileQuerySchema(String query, Configuration configuration) { - QueryClient queryManager = new QueryClient(); - CompileContext compileContext = new CompileContext(); - compileContext.setConfig(configuration.getConfigMap()); - compileContext.getConfig().put(GEAFLOW_DSL_COMPILE_PHYSICAL_PLAN_ENABLE.getKey(), - Boolean.FALSE.toString()); - return queryManager.compile(query, compileContext); + } + + protected static CompileResult compileQuerySchema(String query, Configuration configuration) { + QueryClient queryManager = new QueryClient(); + CompileContext compileContext = new CompileContext(); + compileContext.setConfig(configuration.getConfigMap()); + compileContext + .getConfig() + .put(GEAFLOW_DSL_COMPILE_PHYSICAL_PLAN_ENABLE.getKey(), Boolean.FALSE.toString()); + return queryManager.compile(query, compileContext); + } + + private void registerServiceInfo() { + // First initialize analytics service instance and only in service 0. + if (serviceExecutorContext.getDriverIndex() == 0) { + String analyticsQuery = + serviceExecutorContext.getPipelineContext().getConfig().getString(ANALYTICS_QUERY); + Preconditions.checkArgument(analyticsQuery != null, "analytics query must be not null"); + executeQuery(analyticsQuery); + LOGGER.info( + "service index {} analytics query execute successfully", + serviceExecutorContext.getDriverIndex()); } - - private void registerServiceInfo() { - // First initialize analytics service instance and only in service 0. - if (serviceExecutorContext.getDriverIndex() == 0) { - String analyticsQuery = serviceExecutorContext.getPipelineContext().getConfig() - .getString(ANALYTICS_QUERY); - Preconditions.checkArgument(analyticsQuery != null, "analytics query must be not null"); - executeQuery(analyticsQuery); - LOGGER.info("service index {} analytics query execute successfully", - serviceExecutorContext.getDriverIndex()); - } - // Register analytics service info. - if (serviceExecutorContext.getConfiguration() - .getBoolean(ANALYTICS_SERVICE_REGISTER_ENABLE)) { - metaServerClient = new MetaServerClient(serviceExecutorContext.getConfiguration()); - metaServerClient.registerService(NamespaceType.DEFAULT, - ANALYTICS_SERVICE_PREFIX + serviceExecutorContext.getDriverIndex(), - new HostAndPort(ProcessUtil.getHostIp(), port)); - LOGGER.info("service index {} register analytics service {}:{}", - serviceExecutorContext.getDriverIndex(), ProcessUtil.getHostIp(), port); - } - this.running = true; + // Register analytics service info. + if (serviceExecutorContext.getConfiguration().getBoolean(ANALYTICS_SERVICE_REGISTER_ENABLE)) { + metaServerClient = new MetaServerClient(serviceExecutorContext.getConfiguration()); + metaServerClient.registerService( + NamespaceType.DEFAULT, + ANALYTICS_SERVICE_PREFIX + serviceExecutorContext.getDriverIndex(), + new HostAndPort(ProcessUtil.getHostIp(), port)); + LOGGER.info( + "service index {} register analytics service {}:{}", + serviceExecutorContext.getDriverIndex(), + ProcessUtil.getHostIp(), + port); } + this.running = true; + } - private IExecutionResult executeQuery(String query) { - // User pipeline Task. - PipelineContext pipelineContext = new PipelineContext( + private IExecutionResult executeQuery(String query) { + // User pipeline Task. + PipelineContext pipelineContext = + new PipelineContext( serviceExecutorContext.getPipelineContext().getName(), serviceExecutorContext.getPipelineContext().getConfig()); - serviceExecutorContext.getPipelineContext().getViewDescMap().forEach( - (s, iViewDesc) -> pipelineContext.addView(iViewDesc)); - PipelineServiceContext serviceContext = new PipelineServiceContext( - System.currentTimeMillis(), pipelineContext, query); - pipelineService.execute(serviceContext); - PipelinePlanBuilder pipelinePlanBuilder = new PipelinePlanBuilder(); - // 1. Build pipeline graph plan. - PipelineGraph pipelineGraph = pipelinePlanBuilder.buildPlan(pipelineContext); - // 2. Opt pipeline graph plan. - pipelinePlanBuilder.optimizePlan(pipelineContext.getConfig()); - // 3. Execute query. - IExecutionResult result = this.serviceExecutorContext.getPipelineRunner() + serviceExecutorContext + .getPipelineContext() + .getViewDescMap() + .forEach((s, iViewDesc) -> pipelineContext.addView(iViewDesc)); + PipelineServiceContext serviceContext = + new PipelineServiceContext(System.currentTimeMillis(), pipelineContext, query); + pipelineService.execute(serviceContext); + PipelinePlanBuilder pipelinePlanBuilder = new PipelinePlanBuilder(); + // 1. Build pipeline graph plan. + PipelineGraph pipelineGraph = pipelinePlanBuilder.buildPlan(pipelineContext); + // 2. Opt pipeline graph plan. + pipelinePlanBuilder.optimizePlan(pipelineContext.getConfig()); + // 3. Execute query. + IExecutionResult result = + this.serviceExecutorContext + .getPipelineRunner() .runPipelineGraph(pipelineGraph, serviceExecutorContext); - return result; - } + return result; + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-server/src/main/java/org/apache/geaflow/analytics/service/server/http/HttpAnalyticsServiceServer.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-server/src/main/java/org/apache/geaflow/analytics/service/server/http/HttpAnalyticsServiceServer.java index 44c6cd435..130ee3ce8 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-server/src/main/java/org/apache/geaflow/analytics/service/server/http/HttpAnalyticsServiceServer.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-server/src/main/java/org/apache/geaflow/analytics/service/server/http/HttpAnalyticsServiceServer.java @@ -41,98 +41,97 @@ public class HttpAnalyticsServiceServer extends AbstractAnalyticsServiceServer { - private static final Logger LOGGER = LoggerFactory.getLogger(HttpAnalyticsServiceServer.class); - private static final String SERVER_NAME = "analytics-service-server"; - private static final String SERVER_SCHEDULER = "analytics-service-server-scheduler"; - private static final String QUERY_EXECUTE_REST_API = "/rest/analytics/query/execute"; - - private HttpAnalyticsServiceHandler httpHandler; - private Server server; - private QueuedThreadPool threadPool; - private ScheduledExecutorScheduler serverExecutor; - - @Override - public void init(IPipelineServiceExecutorContext context) { - super.init(context); - - this.httpHandler = new HttpAnalyticsServiceHandler(requestBlockingQueue, - responseBlockingMap, semaphore); - this.threadPool = new QueuedThreadPool(); - this.threadPool.setDaemon(true); - this.threadPool.setName(SERVER_NAME); - this.server = new Server(threadPool); - this.port = PortUtil.getPort(port); - ErrorHandler errorHandler = new ErrorHandler(); - errorHandler.setShowStacks(true); - errorHandler.setServer(this.server); - this.server.addBean(errorHandler); - - this.serverExecutor = new ScheduledExecutorScheduler(SERVER_SCHEDULER, true); + private static final Logger LOGGER = LoggerFactory.getLogger(HttpAnalyticsServiceServer.class); + private static final String SERVER_NAME = "analytics-service-server"; + private static final String SERVER_SCHEDULER = "analytics-service-server-scheduler"; + private static final String QUERY_EXECUTE_REST_API = "/rest/analytics/query/execute"; + + private HttpAnalyticsServiceHandler httpHandler; + private Server server; + private QueuedThreadPool threadPool; + private ScheduledExecutorScheduler serverExecutor; + + @Override + public void init(IPipelineServiceExecutorContext context) { + super.init(context); + + this.httpHandler = + new HttpAnalyticsServiceHandler(requestBlockingQueue, responseBlockingMap, semaphore); + this.threadPool = new QueuedThreadPool(); + this.threadPool.setDaemon(true); + this.threadPool.setName(SERVER_NAME); + this.server = new Server(threadPool); + this.port = PortUtil.getPort(port); + ErrorHandler errorHandler = new ErrorHandler(); + errorHandler.setShowStacks(true); + errorHandler.setServer(this.server); + this.server.addBean(errorHandler); + + this.serverExecutor = new ScheduledExecutorScheduler(SERVER_SCHEDULER, true); + } + + @Override + public void startServer() { + // Jetty's processing collection. + ContextHandlerCollection contexts = new ContextHandlerCollection(); + server.setHandler(contexts); + + // Add servlet. + ServletContextHandler contextHandler = + new ServletContextHandler(ServletContextHandler.SESSIONS); + contextHandler.addServlet(new ServletHolder(httpHandler), QUERY_EXECUTE_REST_API); + contexts.addHandler(contextHandler); + + try { + ServerConnector connector = newConnector(server, serverExecutor, null, port); + connector.setName(SERVER_NAME); + server.addConnector(connector); + + int minThreads = 1; + minThreads += connector.getAcceptors() * 2; + threadPool.setMaxThreads(Math.max(threadPool.getMaxThreads(), minThreads)); + server.start(); + String hostIpAddress = ProcessUtil.getHostIp(); + LOGGER.info("Http analytics Server started: ip {}, port {}", hostIpAddress, port); + } catch (Exception e) { + LOGGER.error("Http analytics Server start failed:", e); + throw new GeaflowRuntimeException(e); } - - @Override - public void startServer() { - // Jetty's processing collection. - ContextHandlerCollection contexts = new ContextHandlerCollection(); - server.setHandler(contexts); - - // Add servlet. - ServletContextHandler contextHandler = new ServletContextHandler( - ServletContextHandler.SESSIONS); - contextHandler.addServlet(new ServletHolder(httpHandler), QUERY_EXECUTE_REST_API); - contexts.addHandler(contextHandler); - - try { - ServerConnector connector = newConnector(server, serverExecutor, null, port); - connector.setName(SERVER_NAME); - server.addConnector(connector); - - int minThreads = 1; - minThreads += connector.getAcceptors() * 2; - threadPool.setMaxThreads(Math.max(threadPool.getMaxThreads(), minThreads)); - server.start(); - String hostIpAddress = ProcessUtil.getHostIp(); - LOGGER.info("Http analytics Server started: ip {}, port {}", hostIpAddress, port); - } catch (Exception e) { - LOGGER.error("Http analytics Server start failed:", e); - throw new GeaflowRuntimeException(e); - } - waitForExecuted(); + waitForExecuted(); + } + + @Override + public void stopServer() { + try { + super.stopServer(); + server.stop(); + if (threadPool.isStarted()) { + threadPool.stop(); + } + if (serverExecutor.isStarted()) { + serverExecutor.stop(); + } + } catch (Exception e) { + LOGGER.warn("stop analytics server failed", e); + throw new GeaflowRuntimeException(e); } - - @Override - public void stopServer() { - try { - super.stopServer(); - server.stop(); - if (threadPool.isStarted()) { - threadPool.stop(); - } - if (serverExecutor.isStarted()) { - serverExecutor.stop(); - } - } catch (Exception e) { - LOGGER.warn("stop analytics server failed", e); - throw new GeaflowRuntimeException(e); - } - } - - @Override - public ServiceType getServiceType() { - return ServiceType.analytics_http; - } - - private ServerConnector newConnector(Server server, ScheduledExecutorScheduler serverExecutor, - String hostName, int port) throws Exception { - ConnectionFactory[] connectionFactories = new ConnectionFactory[]{ - new HttpConnectionFactory()}; - ServerConnector connector = new ServerConnector(server, null, serverExecutor, null, -1, -1, - connectionFactories); - connector.setHost(hostName); - connector.setPort(port); - connector.start(); - connector.setAcceptQueueSize(Math.min(connector.getAcceptors(), maxRequests)); - return connector; - } - + } + + @Override + public ServiceType getServiceType() { + return ServiceType.analytics_http; + } + + private ServerConnector newConnector( + Server server, ScheduledExecutorScheduler serverExecutor, String hostName, int port) + throws Exception { + ConnectionFactory[] connectionFactories = new ConnectionFactory[] {new HttpConnectionFactory()}; + ServerConnector connector = + new ServerConnector(server, null, serverExecutor, null, -1, -1, connectionFactories); + connector.setHost(hostName); + connector.setPort(port); + connector.start(); + connector.setAcceptQueueSize(Math.min(connector.getAcceptors(), maxRequests)); + return connector; + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-server/src/main/java/org/apache/geaflow/analytics/service/server/http/handler/AbstractHttpHandler.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-server/src/main/java/org/apache/geaflow/analytics/service/server/http/handler/AbstractHttpHandler.java index fa9a7ed02..f98a97a09 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-server/src/main/java/org/apache/geaflow/analytics/service/server/http/handler/AbstractHttpHandler.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-server/src/main/java/org/apache/geaflow/analytics/service/server/http/handler/AbstractHttpHandler.java @@ -24,16 +24,15 @@ public class AbstractHttpHandler extends HttpServlet { - protected static final String ERROR_KEY = "error"; - - protected void addHeader(HttpServletResponse response) { - response.setHeader("Access-Control-Allow-Origin", "*"); - response.setHeader("Access-Control-Allow-Methods", "*"); - response.setHeader("Access-Control-Max-Age", "3600"); - response.addHeader("Access-Control-Allow-Headers", "*"); - response.setHeader("Access-Control-Allow-Credentials", "*"); - response.setContentType("application/json"); - response.setCharacterEncoding("UTF-8"); - } + protected static final String ERROR_KEY = "error"; + protected void addHeader(HttpServletResponse response) { + response.setHeader("Access-Control-Allow-Origin", "*"); + response.setHeader("Access-Control-Allow-Methods", "*"); + response.setHeader("Access-Control-Max-Age", "3600"); + response.addHeader("Access-Control-Allow-Headers", "*"); + response.setHeader("Access-Control-Allow-Credentials", "*"); + response.setContentType("application/json"); + response.setCharacterEncoding("UTF-8"); + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-server/src/main/java/org/apache/geaflow/analytics/service/server/http/handler/HttpAnalyticsServiceHandler.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-server/src/main/java/org/apache/geaflow/analytics/service/server/http/handler/HttpAnalyticsServiceHandler.java index 8689d4201..4eab6f2ce 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-server/src/main/java/org/apache/geaflow/analytics/service/server/http/handler/HttpAnalyticsServiceHandler.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-server/src/main/java/org/apache/geaflow/analytics/service/server/http/handler/HttpAnalyticsServiceHandler.java @@ -21,8 +21,6 @@ import static org.apache.geaflow.analytics.service.server.AbstractAnalyticsServiceServer.getQueryResults; -import com.google.gson.Gson; -import com.google.gson.reflect.TypeToken; import java.io.IOException; import java.lang.reflect.Type; import java.util.Map; @@ -30,8 +28,10 @@ import java.util.concurrent.Future; import java.util.concurrent.Semaphore; import java.util.stream.Collectors; + import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; + import org.apache.geaflow.analytics.service.query.QueryError; import org.apache.geaflow.analytics.service.query.QueryIdGenerator; import org.apache.geaflow.analytics.service.query.QueryInfo; @@ -43,58 +43,66 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; + public class HttpAnalyticsServiceHandler extends AbstractHttpHandler { - private static final Logger LOGGER = LoggerFactory.getLogger(HttpAnalyticsServiceHandler.class); - private static final Type DEFAULT_REQUEST_TYPE = new TypeToken>() { - }.getType(); - private static final String QUERY = "query"; + private static final Logger LOGGER = LoggerFactory.getLogger(HttpAnalyticsServiceHandler.class); + private static final Type DEFAULT_REQUEST_TYPE = + new TypeToken>() {}.getType(); + private static final String QUERY = "query"; - private final BlockingQueue requestBlockingQueue; - private final BlockingMap> responseBlockingMap; - private final QueryIdGenerator queryIdGenerator; - private final Semaphore semaphore; + private final BlockingQueue requestBlockingQueue; + private final BlockingMap> responseBlockingMap; + private final QueryIdGenerator queryIdGenerator; + private final Semaphore semaphore; - public HttpAnalyticsServiceHandler(BlockingQueue requestBlockingQueue, - BlockingMap> responseBlockingMap, - Semaphore semaphore) { - this.requestBlockingQueue = requestBlockingQueue; - this.responseBlockingMap = responseBlockingMap; - this.semaphore = semaphore; - this.queryIdGenerator = new QueryIdGenerator(); - } + public HttpAnalyticsServiceHandler( + BlockingQueue requestBlockingQueue, + BlockingMap> responseBlockingMap, + Semaphore semaphore) { + this.requestBlockingQueue = requestBlockingQueue; + this.responseBlockingMap = responseBlockingMap; + this.semaphore = semaphore; + this.queryIdGenerator = new QueryIdGenerator(); + } - @Override - protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException { - QueryResults result = null; - String queryId = queryIdGenerator.createQueryId(); - try { - if (!this.semaphore.tryAcquire()) { - QueryError queryError = StandardError.ANALYTICS_SERVER_BUSY.getQueryError(); - result = new QueryResults(queryId, queryError); - resp.setStatus(HttpServletResponse.SC_OK); - } else { - String requestBody = req.getReader().lines().collect(Collectors.joining(System.lineSeparator())); - Map requestParam = new Gson().fromJson(requestBody, DEFAULT_REQUEST_TYPE); - String query = requestParam.get(QUERY).toString(); - QueryInfo queryInfo = new QueryInfo(queryId, query); - LOGGER.info("start execute query [{}]", queryInfo); - final long start = System.currentTimeMillis(); - requestBlockingQueue.put(queryInfo); - result = getQueryResults(queryInfo, responseBlockingMap); - LOGGER.info("finish execute query [{}], result {}, cost {}ms", result, resp, System.currentTimeMillis() - start); - resp.setStatus(HttpServletResponse.SC_OK); - } - } catch (Throwable t) { - result = new QueryResults(queryId, new QueryError(t.getMessage())); - resp.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); - } finally { - addHeader(resp); - byte[] serializeResult = SerializerFactory.getKryoSerializer().serialize(result); - resp.getOutputStream().write(serializeResult); - resp.getOutputStream().flush(); - this.semaphore.release(); - } + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException { + QueryResults result = null; + String queryId = queryIdGenerator.createQueryId(); + try { + if (!this.semaphore.tryAcquire()) { + QueryError queryError = StandardError.ANALYTICS_SERVER_BUSY.getQueryError(); + result = new QueryResults(queryId, queryError); + resp.setStatus(HttpServletResponse.SC_OK); + } else { + String requestBody = + req.getReader().lines().collect(Collectors.joining(System.lineSeparator())); + Map requestParam = new Gson().fromJson(requestBody, DEFAULT_REQUEST_TYPE); + String query = requestParam.get(QUERY).toString(); + QueryInfo queryInfo = new QueryInfo(queryId, query); + LOGGER.info("start execute query [{}]", queryInfo); + final long start = System.currentTimeMillis(); + requestBlockingQueue.put(queryInfo); + result = getQueryResults(queryInfo, responseBlockingMap); + LOGGER.info( + "finish execute query [{}], result {}, cost {}ms", + result, + resp, + System.currentTimeMillis() - start); + resp.setStatus(HttpServletResponse.SC_OK); + } + } catch (Throwable t) { + result = new QueryResults(queryId, new QueryError(t.getMessage())); + resp.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + } finally { + addHeader(resp); + byte[] serializeResult = SerializerFactory.getKryoSerializer().serialize(result); + resp.getOutputStream().write(serializeResult); + resp.getOutputStream().flush(); + this.semaphore.release(); } - + } } diff --git a/geaflow/geaflow-analytics-service/geaflow-analytics-service-server/src/main/java/org/apache/geaflow/analytics/service/server/rpc/RpcAnalyticsServiceServer.java b/geaflow/geaflow-analytics-service/geaflow-analytics-service-server/src/main/java/org/apache/geaflow/analytics/service/server/rpc/RpcAnalyticsServiceServer.java index 0f99abc6b..8c614825c 100644 --- a/geaflow/geaflow-analytics-service/geaflow-analytics-service-server/src/main/java/org/apache/geaflow/analytics/service/server/rpc/RpcAnalyticsServiceServer.java +++ b/geaflow/geaflow-analytics-service/geaflow-analytics-service-server/src/main/java/org/apache/geaflow/analytics/service/server/rpc/RpcAnalyticsServiceServer.java @@ -19,12 +19,10 @@ package org.apache.geaflow.analytics.service.server.rpc; -import io.grpc.Server; -import io.grpc.ServerBuilder; -import io.grpc.stub.StreamObserver; import java.util.concurrent.BlockingQueue; import java.util.concurrent.Future; import java.util.concurrent.Semaphore; + import org.apache.geaflow.analytics.service.query.QueryError; import org.apache.geaflow.analytics.service.query.QueryIdGenerator; import org.apache.geaflow.analytics.service.query.QueryInfo; @@ -45,105 +43,116 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.stub.StreamObserver; + public class RpcAnalyticsServiceServer extends AbstractAnalyticsServiceServer { - private static final Logger LOGGER = LoggerFactory.getLogger(RpcAnalyticsServiceServer.class); + private static final Logger LOGGER = LoggerFactory.getLogger(RpcAnalyticsServiceServer.class); - private Server server; + private Server server; - @Override - public void startServer() { - try { - // Will support server ip and port report in future. - this.server = ServerBuilder.forPort(this.port).addService(new CoordinatorImpl(this)).build().start(); - String hostIpAddress = ProcessUtil.getHostIp(); - this.port = this.server.getPort(); - LOGGER.info("Server started: {}, listening on: {}", hostIpAddress, port); - Runtime.getRuntime().addShutdownHook(new Thread(() -> { - // Use stderr here since the logger may have been reset by its JVM shutdown hook. - LOGGER.warn("*** Shutting down analytics server since JVM is shutting down."); - stopServer(); - LOGGER.warn("*** Geaflow analytics server shutdown."); - })); - } catch (Throwable t) { - LOGGER.error(t.getMessage(), t); - throw new GeaflowRuntimeException(t); - } - waitForExecuted(); + @Override + public void startServer() { + try { + // Will support server ip and port report in future. + this.server = + ServerBuilder.forPort(this.port).addService(new CoordinatorImpl(this)).build().start(); + String hostIpAddress = ProcessUtil.getHostIp(); + this.port = this.server.getPort(); + LOGGER.info("Server started: {}, listening on: {}", hostIpAddress, port); + Runtime.getRuntime() + .addShutdownHook( + new Thread( + () -> { + // Use stderr here since the logger may have been reset by its JVM shutdown + // hook. + LOGGER.warn("*** Shutting down analytics server since JVM is shutting down."); + stopServer(); + LOGGER.warn("*** Geaflow analytics server shutdown."); + })); + } catch (Throwable t) { + LOGGER.error(t.getMessage(), t); + throw new GeaflowRuntimeException(t); } + waitForExecuted(); + } - @Override - public void stopServer() { - super.stopServer(); - if (this.server != null) { - this.server.shutdown(); - } + @Override + public void stopServer() { + super.stopServer(); + if (this.server != null) { + this.server.shutdown(); } + } - @Override - public ServiceType getServiceType() { - return ServiceType.analytics_rpc; - } + @Override + public ServiceType getServiceType() { + return ServiceType.analytics_rpc; + } - static class CoordinatorImpl extends AnalyticsServiceGrpc.AnalyticsServiceImplBase { + static class CoordinatorImpl extends AnalyticsServiceGrpc.AnalyticsServiceImplBase { - private final BlockingQueue requestBlockingQueue; - private final BlockingMap> responseBlockingMap; - private final BlockingQueue cancelRequestBlockingQueue; - private final BlockingQueue cancelResponseBlockingQueue; - private final QueryIdGenerator queryIdGenerator; - private final Semaphore semaphore; + private final BlockingQueue requestBlockingQueue; + private final BlockingMap> responseBlockingMap; + private final BlockingQueue cancelRequestBlockingQueue; + private final BlockingQueue cancelResponseBlockingQueue; + private final QueryIdGenerator queryIdGenerator; + private final Semaphore semaphore; - public CoordinatorImpl(RpcAnalyticsServiceServer server) { - this.requestBlockingQueue = server.requestBlockingQueue; - this.responseBlockingMap = server.responseBlockingMap; - this.cancelRequestBlockingQueue = server.cancelRequestBlockingQueue; - this.cancelResponseBlockingQueue = server.cancelResponseBlockingQueue; - this.semaphore = server.semaphore; - this.queryIdGenerator = new QueryIdGenerator(); - } + public CoordinatorImpl(RpcAnalyticsServiceServer server) { + this.requestBlockingQueue = server.requestBlockingQueue; + this.responseBlockingMap = server.responseBlockingMap; + this.cancelRequestBlockingQueue = server.cancelRequestBlockingQueue; + this.cancelResponseBlockingQueue = server.cancelResponseBlockingQueue; + this.semaphore = server.semaphore; + this.queryIdGenerator = new QueryIdGenerator(); + } - @Override - public void executeQuery(Analytics.QueryRequest request, StreamObserver responseObserver) { - String queryId = queryIdGenerator.createQueryId(); - if (!this.semaphore.tryAcquire()) { - QueryError queryError = StandardError.ANALYTICS_SERVER_BUSY.getQueryError(); - QueryResults queryResults = new QueryResults(queryId, queryError); - QueryResult result = QueryResult.newBuilder() - .setQueryResult(RpcMessageEncoder.encode(queryResults)) - .build(); - responseObserver.onNext(result); - responseObserver.onCompleted(); - } - try { - String query = request.getQuery(); - QueryInfo queryInfo = new QueryInfo(queryId, query); - final long start = System.currentTimeMillis(); - requestBlockingQueue.put(queryInfo); - QueryResults queryResults = getQueryResults(queryInfo, responseBlockingMap); - LOGGER.info("finish execute query [{}], cost {}ms, query result {}", queryInfo, - System.currentTimeMillis() - start, queryResults); - QueryResult queryResult = QueryResult.newBuilder() - .setQueryResult(RpcMessageEncoder.encode(queryResults)) - .build(); - responseObserver.onNext(queryResult); - responseObserver.onCompleted(); - } catch (Throwable t) { - LOGGER.error("execute query: [{}] failed, cause: {}", request.getQuery(), t); - QueryResults queryResults = new QueryResults(queryId, new QueryError(t.getMessage())); - QueryResult result = QueryResult.newBuilder() - .setQueryResult(RpcMessageEncoder.encode(queryResults)) - .build(); - responseObserver.onNext(result); - responseObserver.onCompleted(); - } finally { - this.semaphore.release(); - } - } + @Override + public void executeQuery( + Analytics.QueryRequest request, StreamObserver responseObserver) { + String queryId = queryIdGenerator.createQueryId(); + if (!this.semaphore.tryAcquire()) { + QueryError queryError = StandardError.ANALYTICS_SERVER_BUSY.getQueryError(); + QueryResults queryResults = new QueryResults(queryId, queryError); + QueryResult result = + QueryResult.newBuilder().setQueryResult(RpcMessageEncoder.encode(queryResults)).build(); + responseObserver.onNext(result); + responseObserver.onCompleted(); + } + try { + String query = request.getQuery(); + QueryInfo queryInfo = new QueryInfo(queryId, query); + final long start = System.currentTimeMillis(); + requestBlockingQueue.put(queryInfo); + QueryResults queryResults = getQueryResults(queryInfo, responseBlockingMap); + LOGGER.info( + "finish execute query [{}], cost {}ms, query result {}", + queryInfo, + System.currentTimeMillis() - start, + queryResults); + QueryResult queryResult = + QueryResult.newBuilder().setQueryResult(RpcMessageEncoder.encode(queryResults)).build(); + responseObserver.onNext(queryResult); + responseObserver.onCompleted(); + } catch (Throwable t) { + LOGGER.error("execute query: [{}] failed, cause: {}", request.getQuery(), t); + QueryResults queryResults = new QueryResults(queryId, new QueryError(t.getMessage())); + QueryResult result = + QueryResult.newBuilder().setQueryResult(RpcMessageEncoder.encode(queryResults)).build(); + responseObserver.onNext(result); + responseObserver.onCompleted(); + } finally { + this.semaphore.release(); + } + } - @Override - public void cancelQuery(QueryCancelRequest request, StreamObserver responseObserver) { - throw new GeaflowRuntimeException("Not supported cancel query yet."); - } + @Override + public void cancelQuery( + QueryCancelRequest request, StreamObserver responseObserver) { + throw new GeaflowRuntimeException("Not supported cancel query yet."); } + } } diff --git a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/PrimitiveType.java b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/PrimitiveType.java index f53b8c653..b458f14b5 100644 --- a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/PrimitiveType.java +++ b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/PrimitiveType.java @@ -20,32 +20,32 @@ package org.apache.geaflow.collection; public enum PrimitiveType { - INT, - LONG, - DOUBLE, - BYTE, - SHORT, - BOOLEAN, - FLOAT, - BYTE_ARRAY, - OBJECT; + INT, + LONG, + DOUBLE, + BYTE, + SHORT, + BOOLEAN, + FLOAT, + BYTE_ARRAY, + OBJECT; - private static final String INTEGER = "INTEGER"; - private static final String BYTES = "BYTE[]"; + private static final String INTEGER = "INTEGER"; + private static final String BYTES = "BYTE[]"; - public static PrimitiveType getEnum(String value) { - String up = value.toUpperCase(); - for (PrimitiveType v : values()) { - if (v.name().equals(up)) { - return v; - } - } - if (INTEGER.equals(up)) { - return INT; - } else if (BYTES.equals(up)) { - return BYTE_ARRAY; - } else { - return OBJECT; - } + public static PrimitiveType getEnum(String value) { + String up = value.toUpperCase(); + for (PrimitiveType v : values()) { + if (v.name().equals(up)) { + return v; + } } + if (INTEGER.equals(up)) { + return INT; + } else if (BYTES.equals(up)) { + return BYTE_ARRAY; + } else { + return OBJECT; + } + } } diff --git a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/BooleanArray.java b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/BooleanArray.java index 7b0064581..5dacda6ed 100644 --- a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/BooleanArray.java +++ b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/BooleanArray.java @@ -21,25 +21,24 @@ public class BooleanArray implements PrimitiveArray { - private boolean[] arr; - - public BooleanArray(int capacity) { - arr = new boolean[capacity]; - } - - @Override - public void set(int pos, Boolean value) { - arr[pos] = value; - } - - @Override - public Boolean get(int pos) { - return arr[pos]; - } - - @Override - public void drop() { - arr = null; - } - + private boolean[] arr; + + public BooleanArray(int capacity) { + arr = new boolean[capacity]; + } + + @Override + public void set(int pos, Boolean value) { + arr[pos] = value; + } + + @Override + public Boolean get(int pos) { + return arr[pos]; + } + + @Override + public void drop() { + arr = null; + } } diff --git a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/ByteArray.java b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/ByteArray.java index 3b0936a63..4f9946e8c 100644 --- a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/ByteArray.java +++ b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/ByteArray.java @@ -21,25 +21,24 @@ public class ByteArray implements PrimitiveArray { - private byte[] arr; - - public ByteArray(int capacity) { - arr = new byte[capacity]; - } - - @Override - public void set(int pos, Byte value) { - arr[pos] = value; - } - - @Override - public Byte get(int pos) { - return arr[pos]; - } - - @Override - public void drop() { - arr = null; - } - + private byte[] arr; + + public ByteArray(int capacity) { + arr = new byte[capacity]; + } + + @Override + public void set(int pos, Byte value) { + arr[pos] = value; + } + + @Override + public Byte get(int pos) { + return arr[pos]; + } + + @Override + public void drop() { + arr = null; + } } diff --git a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/BytesArray.java b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/BytesArray.java index 24598799c..4d3f02413 100644 --- a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/BytesArray.java +++ b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/BytesArray.java @@ -21,24 +21,24 @@ public class BytesArray implements PrimitiveArray { - private byte[][] arr; + private byte[][] arr; - public BytesArray(int capacity) { - arr = new byte[capacity][]; - } + public BytesArray(int capacity) { + arr = new byte[capacity][]; + } - @Override - public void set(int pos, byte[] value) { - arr[pos] = value; - } + @Override + public void set(int pos, byte[] value) { + arr[pos] = value; + } - @Override - public byte[] get(int pos) { - return arr[pos]; - } + @Override + public byte[] get(int pos) { + return arr[pos]; + } - @Override - public void drop() { - this.arr = null; - } + @Override + public void drop() { + this.arr = null; + } } diff --git a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/DoubleArray.java b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/DoubleArray.java index 3d39cdc2a..7c4ed2c5d 100644 --- a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/DoubleArray.java +++ b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/DoubleArray.java @@ -21,25 +21,24 @@ public class DoubleArray implements PrimitiveArray { - private double[] arr; - - public DoubleArray(int capacity) { - arr = new double[capacity]; - } - - @Override - public void set(int pos, Double value) { - arr[pos] = value; - } - - @Override - public Double get(int pos) { - return arr[pos]; - } - - @Override - public void drop() { - arr = null; - } - + private double[] arr; + + public DoubleArray(int capacity) { + arr = new double[capacity]; + } + + @Override + public void set(int pos, Double value) { + arr[pos] = value; + } + + @Override + public Double get(int pos) { + return arr[pos]; + } + + @Override + public void drop() { + arr = null; + } } diff --git a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/FloatArray.java b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/FloatArray.java index 74a9e04c5..34aaed3fb 100644 --- a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/FloatArray.java +++ b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/FloatArray.java @@ -19,28 +19,26 @@ package org.apache.geaflow.collection.array; - public class FloatArray implements PrimitiveArray { - private float[] arr; - - public FloatArray(int capacity) { - arr = new float[capacity]; - } + private float[] arr; - @Override - public void set(int pos, Float value) { - arr[pos] = value; - } + public FloatArray(int capacity) { + arr = new float[capacity]; + } - @Override - public Float get(int pos) { - return arr[pos]; - } + @Override + public void set(int pos, Float value) { + arr[pos] = value; + } - @Override - public void drop() { - arr = null; - } + @Override + public Float get(int pos) { + return arr[pos]; + } + @Override + public void drop() { + arr = null; + } } diff --git a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/IntArray.java b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/IntArray.java index 7bb26fe37..1e4c3d867 100644 --- a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/IntArray.java +++ b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/IntArray.java @@ -21,25 +21,24 @@ public class IntArray implements PrimitiveArray { - private int[] arr; - - public IntArray(int capacity) { - arr = new int[capacity]; - } - - @Override - public void set(int pos, Integer value) { - arr[pos] = value; - } - - @Override - public Integer get(int pos) { - return arr[pos]; - } - - @Override - public void drop() { - arr = null; - } - + private int[] arr; + + public IntArray(int capacity) { + arr = new int[capacity]; + } + + @Override + public void set(int pos, Integer value) { + arr[pos] = value; + } + + @Override + public Integer get(int pos) { + return arr[pos]; + } + + @Override + public void drop() { + arr = null; + } } diff --git a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/LongArray.java b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/LongArray.java index 857758d2d..40a8896d1 100644 --- a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/LongArray.java +++ b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/LongArray.java @@ -21,25 +21,24 @@ public class LongArray implements PrimitiveArray { - private long[] arr; - - public LongArray(int capacity) { - arr = new long[capacity]; - } - - @Override - public void set(int pos, Long value) { - arr[pos] = value; - } - - @Override - public Long get(int pos) { - return arr[pos]; - } - - @Override - public void drop() { - arr = null; - } - + private long[] arr; + + public LongArray(int capacity) { + arr = new long[capacity]; + } + + @Override + public void set(int pos, Long value) { + arr[pos] = value; + } + + @Override + public Long get(int pos) { + return arr[pos]; + } + + @Override + public void drop() { + arr = null; + } } diff --git a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/ObjectArray.java b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/ObjectArray.java index ae20ac24d..462fbccf4 100644 --- a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/ObjectArray.java +++ b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/ObjectArray.java @@ -19,28 +19,26 @@ package org.apache.geaflow.collection.array; - public class ObjectArray implements PrimitiveArray { - private Object[] arr; - - public ObjectArray(int capacity) { - arr = new Object[capacity]; - } + private Object[] arr; - @Override - public void set(int pos, Object value) { - arr[pos] = value; - } + public ObjectArray(int capacity) { + arr = new Object[capacity]; + } - @Override - public Object get(int pos) { - return arr[pos]; - } + @Override + public void set(int pos, Object value) { + arr[pos] = value; + } - @Override - public void drop() { - arr = null; - } + @Override + public Object get(int pos) { + return arr[pos]; + } + @Override + public void drop() { + arr = null; + } } diff --git a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/PrimitiveArray.java b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/PrimitiveArray.java index a6aab865d..0ee32a0a3 100644 --- a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/PrimitiveArray.java +++ b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/PrimitiveArray.java @@ -21,10 +21,9 @@ public interface PrimitiveArray { - void set(int pos, E value); + void set(int pos, E value); - E get(int pos); - - void drop(); + E get(int pos); + void drop(); } diff --git a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/PrimitiveArrayFactory.java b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/PrimitiveArrayFactory.java index 8b13bf55d..cc9b1bde5 100644 --- a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/PrimitiveArrayFactory.java +++ b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/PrimitiveArrayFactory.java @@ -25,28 +25,29 @@ public class PrimitiveArrayFactory { - private static final Function DEFAULT_BUILDER = ObjectArray::new; - private static final Map, Function> UNMARSHALL = new HashMap<>(); + private static final Function DEFAULT_BUILDER = ObjectArray::new; + private static final Map, Function> UNMARSHALL = + new HashMap<>(); - static { - UNMARSHALL.put(Integer.class, IntArray::new); - UNMARSHALL.put(Integer.TYPE, IntArray::new); - UNMARSHALL.put(Long.class, LongArray::new); - UNMARSHALL.put(Long.TYPE, LongArray::new); - UNMARSHALL.put(Double.class, DoubleArray::new); - UNMARSHALL.put(Double.TYPE, DoubleArray::new); - UNMARSHALL.put(Byte.class, ByteArray::new); - UNMARSHALL.put(Byte.TYPE, ByteArray::new); - UNMARSHALL.put(Float.class, FloatArray::new); - UNMARSHALL.put(Float.TYPE, FloatArray::new); - UNMARSHALL.put(Boolean.class, BooleanArray::new); - UNMARSHALL.put(Boolean.TYPE, BooleanArray::new); - UNMARSHALL.put(Short.class, ShortArray::new); - UNMARSHALL.put(Short.TYPE, ShortArray::new); - UNMARSHALL.put(byte[].class, BytesArray::new); - } + static { + UNMARSHALL.put(Integer.class, IntArray::new); + UNMARSHALL.put(Integer.TYPE, IntArray::new); + UNMARSHALL.put(Long.class, LongArray::new); + UNMARSHALL.put(Long.TYPE, LongArray::new); + UNMARSHALL.put(Double.class, DoubleArray::new); + UNMARSHALL.put(Double.TYPE, DoubleArray::new); + UNMARSHALL.put(Byte.class, ByteArray::new); + UNMARSHALL.put(Byte.TYPE, ByteArray::new); + UNMARSHALL.put(Float.class, FloatArray::new); + UNMARSHALL.put(Float.TYPE, FloatArray::new); + UNMARSHALL.put(Boolean.class, BooleanArray::new); + UNMARSHALL.put(Boolean.TYPE, BooleanArray::new); + UNMARSHALL.put(Short.class, ShortArray::new); + UNMARSHALL.put(Short.TYPE, ShortArray::new); + UNMARSHALL.put(byte[].class, BytesArray::new); + } - public static PrimitiveArray getCustomArray(Class type, int capacity) { - return UNMARSHALL.getOrDefault(type, DEFAULT_BUILDER).apply(capacity); - } + public static PrimitiveArray getCustomArray(Class type, int capacity) { + return UNMARSHALL.getOrDefault(type, DEFAULT_BUILDER).apply(capacity); + } } diff --git a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/ShortArray.java b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/ShortArray.java index e40f30f2f..74c0de507 100644 --- a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/ShortArray.java +++ b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/array/ShortArray.java @@ -19,28 +19,26 @@ package org.apache.geaflow.collection.array; - public class ShortArray implements PrimitiveArray { - private short[] arr; - - public ShortArray(int capacity) { - arr = new short[capacity]; - } + private short[] arr; - @Override - public void set(int pos, Short value) { - arr[pos] = value; - } + public ShortArray(int capacity) { + arr = new short[capacity]; + } - @Override - public Short get(int pos) { - return arr[pos]; - } + @Override + public void set(int pos, Short value) { + arr[pos] = value; + } - @Override - public void drop() { - arr = null; - } + @Override + public Short get(int pos) { + return arr[pos]; + } + @Override + public void drop() { + arr = null; + } } diff --git a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/map/MapFactory.java b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/map/MapFactory.java index d0886d7eb..860144db8 100644 --- a/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/map/MapFactory.java +++ b/geaflow/geaflow-collection/src/main/java/org/apache/geaflow/collection/map/MapFactory.java @@ -19,6 +19,11 @@ package org.apache.geaflow.collection.map; +import java.util.HashMap; +import java.util.Map; + +import org.apache.geaflow.collection.PrimitiveType; + import it.unimi.dsi.fastutil.bytes.Byte2BooleanOpenHashMap; import it.unimi.dsi.fastutil.bytes.Byte2ByteArrayMap; import it.unimi.dsi.fastutil.bytes.Byte2ByteOpenHashMap; @@ -82,229 +87,226 @@ import it.unimi.dsi.fastutil.shorts.Short2LongOpenHashMap; import it.unimi.dsi.fastutil.shorts.Short2ObjectOpenHashMap; import it.unimi.dsi.fastutil.shorts.Short2ShortOpenHashMap; -import java.util.HashMap; -import java.util.Map; -import org.apache.geaflow.collection.PrimitiveType; public class MapFactory { - private static final Map ADAPTOR_MAP = new HashMap<>(); - private static final MapFactoryAdaptor DEFAULT_ADAPTOR = new ObjectMapFactoryAdaptor(); - - static { - ADAPTOR_MAP.put(Integer.class, new IntMapFactoryAdaptor()); - ADAPTOR_MAP.put(Integer.TYPE, new IntMapFactoryAdaptor()); - ADAPTOR_MAP.put(Byte.class, new ByteMapFactoryAdaptor()); - ADAPTOR_MAP.put(Byte.TYPE, new ByteMapFactoryAdaptor()); - ADAPTOR_MAP.put(Double.class, new DoubleMapFactoryAdaptor()); - ADAPTOR_MAP.put(Double.TYPE, new DoubleMapFactoryAdaptor()); - ADAPTOR_MAP.put(Long.class, new LongMapFactoryAdaptor()); - ADAPTOR_MAP.put(Long.TYPE, new LongMapFactoryAdaptor()); - ADAPTOR_MAP.put(Float.class, new FloatMapFactoryAdaptor()); - ADAPTOR_MAP.put(Float.TYPE, new FloatMapFactoryAdaptor()); - ADAPTOR_MAP.put(Short.class, new ShortMapFactoryAdaptor()); - ADAPTOR_MAP.put(Short.TYPE, new ShortMapFactoryAdaptor()); - } + private static final Map ADAPTOR_MAP = new HashMap<>(); + private static final MapFactoryAdaptor DEFAULT_ADAPTOR = new ObjectMapFactoryAdaptor(); - public static Map buildMap(Class key, Class value) { - MapFactoryAdaptor adaptor = ADAPTOR_MAP.get(key); - if (adaptor != null) { - return adaptor.buildMap(value); - } + static { + ADAPTOR_MAP.put(Integer.class, new IntMapFactoryAdaptor()); + ADAPTOR_MAP.put(Integer.TYPE, new IntMapFactoryAdaptor()); + ADAPTOR_MAP.put(Byte.class, new ByteMapFactoryAdaptor()); + ADAPTOR_MAP.put(Byte.TYPE, new ByteMapFactoryAdaptor()); + ADAPTOR_MAP.put(Double.class, new DoubleMapFactoryAdaptor()); + ADAPTOR_MAP.put(Double.TYPE, new DoubleMapFactoryAdaptor()); + ADAPTOR_MAP.put(Long.class, new LongMapFactoryAdaptor()); + ADAPTOR_MAP.put(Long.TYPE, new LongMapFactoryAdaptor()); + ADAPTOR_MAP.put(Float.class, new FloatMapFactoryAdaptor()); + ADAPTOR_MAP.put(Float.TYPE, new FloatMapFactoryAdaptor()); + ADAPTOR_MAP.put(Short.class, new ShortMapFactoryAdaptor()); + ADAPTOR_MAP.put(Short.TYPE, new ShortMapFactoryAdaptor()); + } - return DEFAULT_ADAPTOR.buildMap(value); + public static Map buildMap(Class key, Class value) { + MapFactoryAdaptor adaptor = ADAPTOR_MAP.get(key); + if (adaptor != null) { + return adaptor.buildMap(value); } - public interface MapFactoryAdaptor { + return DEFAULT_ADAPTOR.buildMap(value); + } - Map buildMap(Class value); - } + public interface MapFactoryAdaptor { + + Map buildMap(Class value); + } - public static class ByteMapFactoryAdaptor implements MapFactoryAdaptor { + public static class ByteMapFactoryAdaptor implements MapFactoryAdaptor { - @Override - public Map buildMap(Class value) { - switch (PrimitiveType.getEnum(value.getSimpleName())) { - case INT: - return (Map) new Byte2IntOpenHashMap(); - case LONG: - return (Map) new Byte2LongOpenHashMap(); - case BYTE: - return (Map) new Byte2ByteOpenHashMap(); - case FLOAT: - return (Map) new Byte2FloatOpenHashMap(); - case BOOLEAN: - return (Map) new Byte2BooleanOpenHashMap(); - case SHORT: - return (Map) new Byte2ShortOpenHashMap(); - case DOUBLE: - return (Map) new Byte2DoubleOpenHashMap(); - case BYTE_ARRAY: - return (Map) new Byte2ByteArrayMap(); - default: - return (Map) new Byte2ObjectOpenHashMap(); - } - } + @Override + public Map buildMap(Class value) { + switch (PrimitiveType.getEnum(value.getSimpleName())) { + case INT: + return (Map) new Byte2IntOpenHashMap(); + case LONG: + return (Map) new Byte2LongOpenHashMap(); + case BYTE: + return (Map) new Byte2ByteOpenHashMap(); + case FLOAT: + return (Map) new Byte2FloatOpenHashMap(); + case BOOLEAN: + return (Map) new Byte2BooleanOpenHashMap(); + case SHORT: + return (Map) new Byte2ShortOpenHashMap(); + case DOUBLE: + return (Map) new Byte2DoubleOpenHashMap(); + case BYTE_ARRAY: + return (Map) new Byte2ByteArrayMap(); + default: + return (Map) new Byte2ObjectOpenHashMap(); + } } + } - public static class DoubleMapFactoryAdaptor implements MapFactoryAdaptor { + public static class DoubleMapFactoryAdaptor implements MapFactoryAdaptor { - @Override - public Map buildMap(Class value) { - switch (PrimitiveType.getEnum(value.getSimpleName())) { - case INT: - return (Map) new Double2IntOpenHashMap(); - case LONG: - return (Map) new Double2LongOpenHashMap(); - case BYTE: - return (Map) new Double2ByteOpenHashMap(); - case FLOAT: - return (Map) new Double2FloatOpenHashMap(); - case BOOLEAN: - return (Map) new Double2BooleanOpenHashMap(); - case SHORT: - return (Map) new Double2ShortOpenHashMap(); - case DOUBLE: - return (Map) new Double2DoubleOpenHashMap(); - case BYTE_ARRAY: - return (Map) new Double2ByteArrayMap(); - default: - return (Map) new Double2ObjectOpenHashMap(); - } - } + @Override + public Map buildMap(Class value) { + switch (PrimitiveType.getEnum(value.getSimpleName())) { + case INT: + return (Map) new Double2IntOpenHashMap(); + case LONG: + return (Map) new Double2LongOpenHashMap(); + case BYTE: + return (Map) new Double2ByteOpenHashMap(); + case FLOAT: + return (Map) new Double2FloatOpenHashMap(); + case BOOLEAN: + return (Map) new Double2BooleanOpenHashMap(); + case SHORT: + return (Map) new Double2ShortOpenHashMap(); + case DOUBLE: + return (Map) new Double2DoubleOpenHashMap(); + case BYTE_ARRAY: + return (Map) new Double2ByteArrayMap(); + default: + return (Map) new Double2ObjectOpenHashMap(); + } } + } - public static class FloatMapFactoryAdaptor implements MapFactoryAdaptor { + public static class FloatMapFactoryAdaptor implements MapFactoryAdaptor { - @Override - public Map buildMap(Class value) { - switch (PrimitiveType.getEnum(value.getSimpleName())) { - case INT: - return (Map) new Float2IntOpenHashMap(); - case LONG: - return (Map) new Float2LongOpenHashMap(); - case BYTE: - return (Map) new Float2ByteOpenHashMap(); - case FLOAT: - return (Map) new Float2FloatOpenHashMap(); - case BOOLEAN: - return (Map) new Float2BooleanOpenHashMap(); - case SHORT: - return (Map) new Float2ShortOpenHashMap(); - case DOUBLE: - return (Map) new Float2DoubleOpenHashMap(); - case BYTE_ARRAY: - return (Map) new Float2ByteArrayMap(); - default: - return (Map) new Float2ObjectOpenHashMap(); - } - } + @Override + public Map buildMap(Class value) { + switch (PrimitiveType.getEnum(value.getSimpleName())) { + case INT: + return (Map) new Float2IntOpenHashMap(); + case LONG: + return (Map) new Float2LongOpenHashMap(); + case BYTE: + return (Map) new Float2ByteOpenHashMap(); + case FLOAT: + return (Map) new Float2FloatOpenHashMap(); + case BOOLEAN: + return (Map) new Float2BooleanOpenHashMap(); + case SHORT: + return (Map) new Float2ShortOpenHashMap(); + case DOUBLE: + return (Map) new Float2DoubleOpenHashMap(); + case BYTE_ARRAY: + return (Map) new Float2ByteArrayMap(); + default: + return (Map) new Float2ObjectOpenHashMap(); + } } + } - public static class IntMapFactoryAdaptor implements MapFactoryAdaptor { + public static class IntMapFactoryAdaptor implements MapFactoryAdaptor { - @Override - public Map buildMap(Class value) { - switch (PrimitiveType.getEnum(value.getSimpleName())) { - case INT: - return (Map) new Int2IntOpenHashMap(); - case LONG: - return (Map) new Int2LongOpenHashMap(); - case BYTE: - return (Map) new Int2ByteOpenHashMap(); - case FLOAT: - return (Map) new Int2FloatOpenHashMap(); - case BOOLEAN: - return (Map) new Int2BooleanOpenHashMap(); - case SHORT: - return (Map) new Int2ShortOpenHashMap(); - case DOUBLE: - return (Map) new Int2DoubleOpenHashMap(); - case BYTE_ARRAY: - return (Map) new Int2ByteArrayMap(); - default: - return (Map) new Int2ObjectOpenHashMap(); - } - } + @Override + public Map buildMap(Class value) { + switch (PrimitiveType.getEnum(value.getSimpleName())) { + case INT: + return (Map) new Int2IntOpenHashMap(); + case LONG: + return (Map) new Int2LongOpenHashMap(); + case BYTE: + return (Map) new Int2ByteOpenHashMap(); + case FLOAT: + return (Map) new Int2FloatOpenHashMap(); + case BOOLEAN: + return (Map) new Int2BooleanOpenHashMap(); + case SHORT: + return (Map) new Int2ShortOpenHashMap(); + case DOUBLE: + return (Map) new Int2DoubleOpenHashMap(); + case BYTE_ARRAY: + return (Map) new Int2ByteArrayMap(); + default: + return (Map) new Int2ObjectOpenHashMap(); + } } + } - public static class LongMapFactoryAdaptor implements MapFactoryAdaptor { + public static class LongMapFactoryAdaptor implements MapFactoryAdaptor { - @Override - public Map buildMap(Class value) { - switch (PrimitiveType.getEnum(value.getSimpleName())) { - case INT: - return (Map) new Long2IntOpenHashMap(); - case LONG: - return (Map) new Long2LongOpenHashMap(); - case BYTE: - return (Map) new Long2ByteOpenHashMap(); - case FLOAT: - return (Map) new Long2FloatOpenHashMap(); - case BOOLEAN: - return (Map) new Long2BooleanOpenHashMap(); - case SHORT: - return (Map) new Long2ShortOpenHashMap(); - case DOUBLE: - return (Map) new Long2DoubleOpenHashMap(); - case BYTE_ARRAY: - return (Map) new Long2ByteArrayMap(); - default: - return (Map) new Long2ObjectOpenHashMap(); - } - } + @Override + public Map buildMap(Class value) { + switch (PrimitiveType.getEnum(value.getSimpleName())) { + case INT: + return (Map) new Long2IntOpenHashMap(); + case LONG: + return (Map) new Long2LongOpenHashMap(); + case BYTE: + return (Map) new Long2ByteOpenHashMap(); + case FLOAT: + return (Map) new Long2FloatOpenHashMap(); + case BOOLEAN: + return (Map) new Long2BooleanOpenHashMap(); + case SHORT: + return (Map) new Long2ShortOpenHashMap(); + case DOUBLE: + return (Map) new Long2DoubleOpenHashMap(); + case BYTE_ARRAY: + return (Map) new Long2ByteArrayMap(); + default: + return (Map) new Long2ObjectOpenHashMap(); + } } + } - public static class ShortMapFactoryAdaptor implements MapFactoryAdaptor { + public static class ShortMapFactoryAdaptor implements MapFactoryAdaptor { - public Map buildMap(Class value) { - switch (PrimitiveType.getEnum(value.getSimpleName())) { - case INT: - return (Map) new Short2IntOpenHashMap(); - case LONG: - return (Map) new Short2LongOpenHashMap(); - case BYTE: - return (Map) new Short2ByteOpenHashMap(); - case FLOAT: - return (Map) new Short2FloatOpenHashMap(); - case BOOLEAN: - return (Map) new Short2BooleanOpenHashMap(); - case SHORT: - return (Map) new Short2ShortOpenHashMap(); - case DOUBLE: - return (Map) new Short2DoubleOpenHashMap(); - case BYTE_ARRAY: - return (Map) new Short2ByteArrayMap(); - default: - return (Map) new Short2ObjectOpenHashMap(); - } - } + public Map buildMap(Class value) { + switch (PrimitiveType.getEnum(value.getSimpleName())) { + case INT: + return (Map) new Short2IntOpenHashMap(); + case LONG: + return (Map) new Short2LongOpenHashMap(); + case BYTE: + return (Map) new Short2ByteOpenHashMap(); + case FLOAT: + return (Map) new Short2FloatOpenHashMap(); + case BOOLEAN: + return (Map) new Short2BooleanOpenHashMap(); + case SHORT: + return (Map) new Short2ShortOpenHashMap(); + case DOUBLE: + return (Map) new Short2DoubleOpenHashMap(); + case BYTE_ARRAY: + return (Map) new Short2ByteArrayMap(); + default: + return (Map) new Short2ObjectOpenHashMap(); + } } + } - public static class ObjectMapFactoryAdaptor implements MapFactoryAdaptor { + public static class ObjectMapFactoryAdaptor implements MapFactoryAdaptor { - @Override - public Map buildMap(Class value) { - switch (PrimitiveType.getEnum(value.getSimpleName())) { - case INT: - return (Map) new Object2IntOpenHashMap(); - case LONG: - return (Map) new Object2LongOpenHashMap(); - case BYTE: - return (Map) new Object2ByteOpenHashMap(); - case FLOAT: - return (Map) new Object2FloatOpenHashMap(); - case BOOLEAN: - return (Map) new Object2BooleanOpenHashMap(); - case SHORT: - return (Map) new Object2ShortOpenHashMap(); - case DOUBLE: - return (Map) new Object2DoubleOpenHashMap(); - case BYTE_ARRAY: - return (Map) new Object2ByteArrayMap(); - default: - return (Map) new Object2ObjectOpenHashMap(); - } - } + @Override + public Map buildMap(Class value) { + switch (PrimitiveType.getEnum(value.getSimpleName())) { + case INT: + return (Map) new Object2IntOpenHashMap(); + case LONG: + return (Map) new Object2LongOpenHashMap(); + case BYTE: + return (Map) new Object2ByteOpenHashMap(); + case FLOAT: + return (Map) new Object2FloatOpenHashMap(); + case BOOLEAN: + return (Map) new Object2BooleanOpenHashMap(); + case SHORT: + return (Map) new Object2ShortOpenHashMap(); + case DOUBLE: + return (Map) new Object2DoubleOpenHashMap(); + case BYTE_ARRAY: + return (Map) new Object2ByteArrayMap(); + default: + return (Map) new Object2ObjectOpenHashMap(); + } } + } } diff --git a/geaflow/geaflow-collection/src/test/java/org/apache/geaflow/collection/array/PrimitiveArrayFactoryTest.java b/geaflow/geaflow-collection/src/test/java/org/apache/geaflow/collection/array/PrimitiveArrayFactoryTest.java index 5585c5f73..11f62a956 100644 --- a/geaflow/geaflow-collection/src/test/java/org/apache/geaflow/collection/array/PrimitiveArrayFactoryTest.java +++ b/geaflow/geaflow-collection/src/test/java/org/apache/geaflow/collection/array/PrimitiveArrayFactoryTest.java @@ -19,47 +19,48 @@ package org.apache.geaflow.collection.array; - import java.util.Arrays; import java.util.List; import java.util.Random; + import org.apache.geaflow.collection.PrimitiveType; import org.testng.Assert; import org.testng.annotations.Test; public class PrimitiveArrayFactoryTest { - private Random random = new Random(); + private Random random = new Random(); - private Object getRandomValue(Class clazz) { - PrimitiveType type = PrimitiveType.getEnum(clazz.getSimpleName()); - switch (type) { - case INT: - return random.nextInt(10); - case BYTE: - return (byte) (random.nextInt(8)); - case DOUBLE: - return random.nextDouble(); - case FLOAT: - return random.nextFloat(); - case BOOLEAN: - return random.nextBoolean(); - case SHORT: - return (short) random.nextInt(1000); - case LONG: - return random.nextLong(); - case BYTE_ARRAY: - byte[] bytes = new byte[5]; - random.nextBytes(bytes); - return bytes; - default: - throw new UnsupportedOperationException(); - } + private Object getRandomValue(Class clazz) { + PrimitiveType type = PrimitiveType.getEnum(clazz.getSimpleName()); + switch (type) { + case INT: + return random.nextInt(10); + case BYTE: + return (byte) (random.nextInt(8)); + case DOUBLE: + return random.nextDouble(); + case FLOAT: + return random.nextFloat(); + case BOOLEAN: + return random.nextBoolean(); + case SHORT: + return (short) random.nextInt(1000); + case LONG: + return random.nextLong(); + case BYTE_ARRAY: + byte[] bytes = new byte[5]; + random.nextBytes(bytes); + return bytes; + default: + throw new UnsupportedOperationException(); } + } - @Test - public void test() { - List list = Arrays.asList( + @Test + public void test() { + List list = + Arrays.asList( Integer.TYPE, Long.TYPE, Double.TYPE, @@ -67,18 +68,16 @@ public void test() { Byte.TYPE, Short.TYPE, Boolean.TYPE, - byte[].class - ); + byte[].class); - for (int i = 0; i < list.size(); i++) { - PrimitiveArray array = PrimitiveArrayFactory.getCustomArray(list.get(i), 10); - for (int j = 0; j < 10; j++) { - Object obj = getRandomValue(list.get(i)); - array.set(j, obj); - Assert.assertEquals(array.get(j), obj); - } - array.drop(); - } + for (int i = 0; i < list.size(); i++) { + PrimitiveArray array = PrimitiveArrayFactory.getCustomArray(list.get(i), 10); + for (int j = 0; j < 10; j++) { + Object obj = getRandomValue(list.get(i)); + array.set(j, obj); + Assert.assertEquals(array.get(j), obj); + } + array.drop(); } - + } } diff --git a/geaflow/geaflow-collection/src/test/java/org/apache/geaflow/collection/map/MapFactoryTest.java b/geaflow/geaflow-collection/src/test/java/org/apache/geaflow/collection/map/MapFactoryTest.java index eccd6c1a0..cfd7f8bd6 100644 --- a/geaflow/geaflow-collection/src/test/java/org/apache/geaflow/collection/map/MapFactoryTest.java +++ b/geaflow/geaflow-collection/src/test/java/org/apache/geaflow/collection/map/MapFactoryTest.java @@ -22,38 +22,34 @@ import java.util.Arrays; import java.util.List; import java.util.Map; + import org.testng.Assert; import org.testng.annotations.Test; public class MapFactoryTest { - @Test - public void test() { - List list = Arrays.asList( - Integer.TYPE, - Long.TYPE, - Double.TYPE, - Float.TYPE, - Byte.TYPE, - Short.TYPE, - byte[].class - ); - - for (int i = 0; i < list.size() - 1; i++) { - for (int j = 0; j < list.size(); j++) { - Map map = MapFactory.buildMap(list.get(i), list.get(j)); - - String key = list.get(i).getSimpleName(); - String value = list.get(j).getSimpleName(); - String mapClass = "OpenHashMap"; - if (value.equals("byte[]")) { - value = "ByteArray"; - mapClass = "Map"; - } - - Assert.assertEquals(map.getClass().getSimpleName().toLowerCase(), - (key + "2" + value + mapClass).toLowerCase()); - } + @Test + public void test() { + List list = + Arrays.asList( + Integer.TYPE, Long.TYPE, Double.TYPE, Float.TYPE, Byte.TYPE, Short.TYPE, byte[].class); + + for (int i = 0; i < list.size() - 1; i++) { + for (int j = 0; j < list.size(); j++) { + Map map = MapFactory.buildMap(list.get(i), list.get(j)); + + String key = list.get(i).getSimpleName(); + String value = list.get(j).getSimpleName(); + String mapClass = "OpenHashMap"; + if (value.equals("byte[]")) { + value = "ByteArray"; + mapClass = "Map"; } + + Assert.assertEquals( + map.getClass().getSimpleName().toLowerCase(), + (key + "2" + value + mapClass).toLowerCase()); + } } + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/binary/BinaryOperations.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/binary/BinaryOperations.java index 4359006d0..795b7189d 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/binary/BinaryOperations.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/binary/BinaryOperations.java @@ -24,212 +24,229 @@ public class BinaryOperations { - private static final sun.misc.Unsafe UNSAFE; - - public static final int BOOLEAN_ARRAY_OFFSET; - - public static final int BYTE_ARRAY_OFFSET; - - public static final int SHORT_ARRAY_OFFSET; - - public static final int INT_ARRAY_OFFSET; - - public static final int LONG_ARRAY_OFFSET; - - public static final int FLOAT_ARRAY_OFFSET; - - public static final int DOUBLE_ARRAY_OFFSET; - - private static final int MAJOR_VERSION = - Integer.parseInt(System.getProperty("java.version").split("\\D+")[0]); - - private static final boolean UNALIGNED; - - static { - sun.misc.Unsafe unsafe; - try { - Field unsafeField = sun.misc.Unsafe.class.getDeclaredField("theUnsafe"); - unsafeField.setAccessible(true); - unsafe = (sun.misc.Unsafe) unsafeField.get(null); - } catch (Throwable cause) { - unsafe = null; - } - UNSAFE = unsafe; - - if (UNSAFE != null) { - BOOLEAN_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(boolean[].class); - BYTE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(byte[].class); - SHORT_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(short[].class); - INT_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(int[].class); - LONG_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(long[].class); - FLOAT_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(float[].class); - DOUBLE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(double[].class); + private static final sun.misc.Unsafe UNSAFE; + + public static final int BOOLEAN_ARRAY_OFFSET; + + public static final int BYTE_ARRAY_OFFSET; + + public static final int SHORT_ARRAY_OFFSET; + + public static final int INT_ARRAY_OFFSET; + + public static final int LONG_ARRAY_OFFSET; + + public static final int FLOAT_ARRAY_OFFSET; + + public static final int DOUBLE_ARRAY_OFFSET; + + private static final int MAJOR_VERSION = + Integer.parseInt(System.getProperty("java.version").split("\\D+")[0]); + + private static final boolean UNALIGNED; + + static { + sun.misc.Unsafe unsafe; + try { + Field unsafeField = sun.misc.Unsafe.class.getDeclaredField("theUnsafe"); + unsafeField.setAccessible(true); + unsafe = (sun.misc.Unsafe) unsafeField.get(null); + } catch (Throwable cause) { + unsafe = null; + } + UNSAFE = unsafe; + + if (UNSAFE != null) { + BOOLEAN_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(boolean[].class); + BYTE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(byte[].class); + SHORT_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(short[].class); + INT_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(int[].class); + LONG_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(long[].class); + FLOAT_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(float[].class); + DOUBLE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(double[].class); + } else { + BOOLEAN_ARRAY_OFFSET = 0; + BYTE_ARRAY_OFFSET = 0; + SHORT_ARRAY_OFFSET = 0; + INT_ARRAY_OFFSET = 0; + LONG_ARRAY_OFFSET = 0; + FLOAT_ARRAY_OFFSET = 0; + DOUBLE_ARRAY_OFFSET = 0; + } + } + + static { + boolean _unaligned; + String arch = System.getProperty("os.arch", ""); + if (arch.equals("ppc64le") || arch.equals("ppc64") || arch.equals("s390x")) { + _unaligned = true; + } else { + try { + Class bitsClass = + Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader()); + if (UNSAFE != null && MAJOR_VERSION >= 9) { + // Java 9/10 and 11/12 have different field names. + Field unalignedField = + bitsClass.getDeclaredField(MAJOR_VERSION >= 11 ? "UNALIGNED" : "unaligned"); + _unaligned = + UNSAFE.getBoolean( + UNSAFE.staticFieldBase(unalignedField), UNSAFE.staticFieldOffset(unalignedField)); } else { - BOOLEAN_ARRAY_OFFSET = 0; - BYTE_ARRAY_OFFSET = 0; - SHORT_ARRAY_OFFSET = 0; - INT_ARRAY_OFFSET = 0; - LONG_ARRAY_OFFSET = 0; - FLOAT_ARRAY_OFFSET = 0; - DOUBLE_ARRAY_OFFSET = 0; + Method unalignedMethod = bitsClass.getDeclaredMethod("unaligned"); + unalignedMethod.setAccessible(true); + _unaligned = Boolean.TRUE.equals(unalignedMethod.invoke(null)); } - } - - static { - boolean _unaligned; - String arch = System.getProperty("os.arch", ""); - if (arch.equals("ppc64le") || arch.equals("ppc64") || arch.equals("s390x")) { - _unaligned = true; - } else { - try { - Class bitsClass = - Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader()); - if (UNSAFE != null && MAJOR_VERSION >= 9) { - // Java 9/10 and 11/12 have different field names. - Field unalignedField = - bitsClass.getDeclaredField(MAJOR_VERSION >= 11 ? "UNALIGNED" : "unaligned"); - _unaligned = UNSAFE.getBoolean( - UNSAFE.staticFieldBase(unalignedField), UNSAFE.staticFieldOffset(unalignedField)); - } else { - Method unalignedMethod = bitsClass.getDeclaredMethod("unaligned"); - unalignedMethod.setAccessible(true); - _unaligned = Boolean.TRUE.equals(unalignedMethod.invoke(null)); - } - } catch (Throwable t) { - _unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64|aarch64)$"); - } + } catch (Throwable t) { + _unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64|aarch64)$"); + } + } + UNALIGNED = _unaligned; + } + + private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L; + + public static int getInt(IBinaryObject object, long offset) { + return UNSAFE.getInt(object.getBaseObject(), object.getAbsoluteAddress(offset)); + } + + public static void putInt(IBinaryObject object, long offset, int value) { + UNSAFE.putInt(object.getBaseObject(), object.getAbsoluteAddress(offset), value); + } + + public static boolean getBoolean(IBinaryObject object, long offset) { + return UNSAFE.getBoolean(object.getBaseObject(), object.getAbsoluteAddress(offset)); + } + + public static void putBoolean(IBinaryObject object, long offset, boolean value) { + UNSAFE.putBoolean(object.getBaseObject(), object.getAbsoluteAddress(offset), value); + } + + public static byte getByte(IBinaryObject object, long offset) { + return UNSAFE.getByte(object.getBaseObject(), object.getAbsoluteAddress(offset)); + } + + public static void putByte(IBinaryObject object, long offset, byte value) { + UNSAFE.putByte(object.getBaseObject(), object.getAbsoluteAddress(offset), value); + } + + public static short getShort(IBinaryObject object, long offset) { + return UNSAFE.getShort(object.getBaseObject(), object.getAbsoluteAddress(offset)); + } + + public static void putShort(IBinaryObject object, long offset, short value) { + UNSAFE.putShort(object.getBaseObject(), object.getAbsoluteAddress(offset), value); + } + + public static long getLong(IBinaryObject object, long offset) { + return UNSAFE.getLong(object.getBaseObject(), object.getAbsoluteAddress(offset)); + } + + public static void putLong(IBinaryObject object, long offset, long value) { + UNSAFE.putLong(object.getBaseObject(), object.getAbsoluteAddress(offset), value); + } + + public static float getFloat(IBinaryObject object, long offset) { + return UNSAFE.getFloat(object.getBaseObject(), object.getAbsoluteAddress(offset)); + } + + public static void putFloat(IBinaryObject object, long offset, float value) { + UNSAFE.putFloat(object.getBaseObject(), object.getAbsoluteAddress(offset), value); + } + + public static double getDouble(IBinaryObject object, long offset) { + return UNSAFE.getDouble(object.getBaseObject(), object.getAbsoluteAddress(offset)); + } + + public static void putDouble(IBinaryObject object, long offset, double value) { + UNSAFE.putDouble(object.getBaseObject(), object.getAbsoluteAddress(offset), value); + } + + public static void copyMemory( + byte[] src, long srcOffset, byte[] dst, long dstOffset, long length) { + HeapBinaryObject srcObject = HeapBinaryObject.of(src); + HeapBinaryObject dstObject = HeapBinaryObject.of(dst); + copyMemory(srcObject, srcOffset, dstObject, dstOffset, length); + } + + public static void copyMemory( + byte[] src, long srcOffset, IBinaryObject dst, long dstOffset, long length) { + HeapBinaryObject srcObject = HeapBinaryObject.of(src); + copyMemory(srcObject, srcOffset, dst, dstOffset, length); + } + + public static void copyMemory( + IBinaryObject src, long srcOffset, byte[] dst, long dstOffset, long length) { + HeapBinaryObject dstObject = HeapBinaryObject.of(dst); + copyMemory(src, srcOffset, dstObject, dstOffset, length); + } + + public static void copyMemory( + IBinaryObject src, long srcOffset, IBinaryObject dst, long dstOffset, long length) { + if (src.size() == 0 || length == 0) { + return; + } + if (dst.getAbsoluteAddress(dstOffset) < src.getAbsoluteAddress(srcOffset)) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + UNSAFE.copyMemory( + src.getBaseObject(), + src.getAbsoluteAddress(srcOffset), + dst.getBaseObject(), + dst.getAbsoluteAddress(dstOffset), + size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } else { + srcOffset += length; + dstOffset += length; + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + srcOffset -= size; + dstOffset -= size; + UNSAFE.copyMemory( + src.getBaseObject(), + src.getAbsoluteAddress(srcOffset), + dst.getBaseObject(), + dst.getAbsoluteAddress(dstOffset), + size); + length -= size; + } + } + } + + public static boolean arrayEquals( + IBinaryObject leftBase, + long leftOffset, + IBinaryObject rightBase, + long rightOffset, + final long length) { + int i = 0; + + // check if stars align and we can get both offsets to be aligned + if ((leftOffset % 8) == (rightOffset % 8)) { + while ((leftOffset + i) % 8 != 0 && i < length) { + if (getByte(leftBase, leftOffset + i) != getByte(rightBase, rightOffset + i)) { + return false; } - UNALIGNED = _unaligned; - } - - private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L; - - public static int getInt(IBinaryObject object, long offset) { - return UNSAFE.getInt(object.getBaseObject(), object.getAbsoluteAddress(offset)); - } - - public static void putInt(IBinaryObject object, long offset, int value) { - UNSAFE.putInt(object.getBaseObject(), object.getAbsoluteAddress(offset), value); - } - - public static boolean getBoolean(IBinaryObject object, long offset) { - return UNSAFE.getBoolean(object.getBaseObject(), object.getAbsoluteAddress(offset)); - } - - public static void putBoolean(IBinaryObject object, long offset, boolean value) { - UNSAFE.putBoolean(object.getBaseObject(), object.getAbsoluteAddress(offset), value); - } - - public static byte getByte(IBinaryObject object, long offset) { - return UNSAFE.getByte(object.getBaseObject(), object.getAbsoluteAddress(offset)); - } - - public static void putByte(IBinaryObject object, long offset, byte value) { - UNSAFE.putByte(object.getBaseObject(), object.getAbsoluteAddress(offset), value); - } - - public static short getShort(IBinaryObject object, long offset) { - return UNSAFE.getShort(object.getBaseObject(), object.getAbsoluteAddress(offset)); - } - - public static void putShort(IBinaryObject object, long offset, short value) { - UNSAFE.putShort(object.getBaseObject(), object.getAbsoluteAddress(offset), value); - } - - public static long getLong(IBinaryObject object, long offset) { - return UNSAFE.getLong(object.getBaseObject(), object.getAbsoluteAddress(offset)); - } - - public static void putLong(IBinaryObject object, long offset, long value) { - UNSAFE.putLong(object.getBaseObject(), object.getAbsoluteAddress(offset), value); - } - - public static float getFloat(IBinaryObject object, long offset) { - return UNSAFE.getFloat(object.getBaseObject(), object.getAbsoluteAddress(offset)); - } - - public static void putFloat(IBinaryObject object, long offset, float value) { - UNSAFE.putFloat(object.getBaseObject(), object.getAbsoluteAddress(offset), value); - } - - public static double getDouble(IBinaryObject object, long offset) { - return UNSAFE.getDouble(object.getBaseObject(), object.getAbsoluteAddress(offset)); - } - - public static void putDouble(IBinaryObject object, long offset, double value) { - UNSAFE.putDouble(object.getBaseObject(), object.getAbsoluteAddress(offset), value); - } - - public static void copyMemory(byte[] src, long srcOffset, byte[] dst, long dstOffset, long length) { - HeapBinaryObject srcObject = HeapBinaryObject.of(src); - HeapBinaryObject dstObject = HeapBinaryObject.of(dst); - copyMemory(srcObject, srcOffset, dstObject, dstOffset, length); - } - - public static void copyMemory(byte[] src, long srcOffset, IBinaryObject dst, long dstOffset, long length) { - HeapBinaryObject srcObject = HeapBinaryObject.of(src); - copyMemory(srcObject, srcOffset, dst, dstOffset, length); - } - - public static void copyMemory(IBinaryObject src, long srcOffset, byte[] dst, long dstOffset, long length) { - HeapBinaryObject dstObject = HeapBinaryObject.of(dst); - copyMemory(src, srcOffset, dstObject, dstOffset, length); + i += 1; + } } - - public static void copyMemory(IBinaryObject src, long srcOffset, IBinaryObject dst, long dstOffset, long length) { - if (src.size() == 0 || length == 0) { - return; - } - if (dst.getAbsoluteAddress(dstOffset) < src.getAbsoluteAddress(srcOffset)) { - while (length > 0) { - long size = Math.min(length, UNSAFE_COPY_THRESHOLD); - UNSAFE.copyMemory(src.getBaseObject(), src.getAbsoluteAddress(srcOffset), dst.getBaseObject(), - dst.getAbsoluteAddress(dstOffset), size); - length -= size; - srcOffset += size; - dstOffset += size; - } - } else { - srcOffset += length; - dstOffset += length; - while (length > 0) { - long size = Math.min(length, UNSAFE_COPY_THRESHOLD); - srcOffset -= size; - dstOffset -= size; - UNSAFE.copyMemory(src.getBaseObject(), src.getAbsoluteAddress(srcOffset), - dst.getBaseObject(), dst.getAbsoluteAddress(dstOffset), size); - length -= size; - } + if (UNALIGNED || (((leftOffset + i) % 8 == 0) && ((rightOffset + i) % 8 == 0))) { + while (i <= length - 8) { + if (getLong(leftBase, leftOffset + i) != getLong(rightBase, rightOffset + i)) { + return false; } + i += 8; + } } - - public static boolean arrayEquals( - IBinaryObject leftBase, long leftOffset, IBinaryObject rightBase, long rightOffset, final long length) { - int i = 0; - - // check if stars align and we can get both offsets to be aligned - if ((leftOffset % 8) == (rightOffset % 8)) { - while ((leftOffset + i) % 8 != 0 && i < length) { - if (getByte(leftBase, leftOffset + i) != getByte(rightBase, rightOffset + i)) { - return false; - } - i += 1; - } - } - if (UNALIGNED || (((leftOffset + i) % 8 == 0) && ((rightOffset + i) % 8 == 0))) { - while (i <= length - 8) { - if (getLong(leftBase, leftOffset + i) != getLong(rightBase, rightOffset + i)) { - return false; - } - i += 8; - } - } - while (i < length) { - if (getByte(leftBase, leftOffset + i) != getByte(rightBase, rightOffset + i)) { - return false; - } - i += 1; - } - return true; + while (i < length) { + if (getByte(leftBase, leftOffset + i) != getByte(rightBase, rightOffset + i)) { + return false; + } + i += 1; } + return true; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/binary/BinaryString.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/binary/BinaryString.java index d8a5c8362..12040826e 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/binary/BinaryString.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/binary/BinaryString.java @@ -39,435 +39,663 @@ import static org.apache.geaflow.common.binary.BinaryOperations.copyMemory; import static org.apache.geaflow.common.binary.BinaryOperations.getLong; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.io.Serializable; import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; import java.util.Objects; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + /* This file is based on source code from the Spark Project (http://spark.apache.org/), licensed by the Apache * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. */ -/** - * This class is an adaptation of Spark's org.apache.spark.unsafe.types.UTF8String. - */ +/** This class is an adaptation of Spark's org.apache.spark.unsafe.types.UTF8String. */ public class BinaryString implements Comparable, Serializable, KryoSerializable { - private static final boolean IS_LITTLE_ENDIAN = - ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN; - - private IBinaryObject binaryObject; - private long offset; - private int numBytes; - - private transient int hashCode = 0; - - private static byte[] bytesOfCodePointInUTF8 = { - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x00..0x0F - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x10..0x1F - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x20..0x2F - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x30..0x3F - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x40..0x4F - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x50..0x5F - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x60..0x6F - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x70..0x7F - // Continuation bytes cannot appear as the first byte - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x80..0x8F - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x90..0x9F - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0xA0..0xAF - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0xB0..0xBF - 0, 0, // 0xC0..0xC1 - disallowed in UTF-8 - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // 0xC2..0xCF - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // 0xD0..0xDF - 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, // 0xE0..0xEF - 4, 4, 4, 4, 4, // 0xF0..0xF4 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 // 0xF5..0xFF - disallowed in UTF-8 - }; - - public static final BinaryString EMPTY_STRING = BinaryString.fromString(""); - - public BinaryString() { - - } - - public BinaryString(IBinaryObject binaryObject, long offset, int numBytes) { - this.binaryObject = binaryObject; - this.offset = offset; - this.numBytes = numBytes; - } - - public static BinaryString fromString(String string) { - byte[] bytes = string.getBytes(StandardCharsets.UTF_8); - return new BinaryString(HeapBinaryObject.of(bytes), 0, bytes.length); - } - - public static BinaryString fromBytes(byte[] bytes) { - return new BinaryString(HeapBinaryObject.of(bytes), 0, bytes.length); - } - - public byte[] getBytes() { - if (offset == 0 - && binaryObject instanceof HeapBinaryObject - && (((HeapBinaryObject) binaryObject).getBaseObject()).length == numBytes) { - return ((HeapBinaryObject) binaryObject).getBaseObject(); - } - byte[] bytes = new byte[numBytes]; - copyMemory(binaryObject, offset, bytes, 0, numBytes); - return bytes; - } - - @Override - public String toString() { - return new String(binaryObject.toBytes(), (int) offset, numBytes, StandardCharsets.UTF_8); - } - - public IBinaryObject getBinaryObject() { - return binaryObject; - } - - public long getOffset() { - return offset; - } - - public int getNumBytes() { - return numBytes; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof BinaryString)) { - return false; - } - BinaryString that = (BinaryString) o; - if (numBytes != that.numBytes) { - return false; - } - return BinaryOperations.arrayEquals(binaryObject, offset, that.binaryObject, that.offset, numBytes); - } - - @Override - public int hashCode() { - if (hashCode == 0) { - hashCode = hashUnsafeBytes(binaryObject, offset, numBytes, 42); - } - return hashCode; - } - - @Override - public int compareTo(BinaryString other) { - int len = Math.min(numBytes, other.numBytes); - int wordMax = (len / 8) * 8; - long otherOffset = other.offset; - IBinaryObject otherBase = other.binaryObject; - for (int i = 0; i < wordMax; i += 8) { - long left = getLong(binaryObject, offset + i); - long right = getLong(otherBase, otherOffset + i); - if (left != right) { - if (IS_LITTLE_ENDIAN) { - // Use binary search - int n = 0; - int y; - long diff = left ^ right; - int x = (int) diff; - if (x == 0) { - x = (int) (diff >>> 32); - n = 32; - } - - y = x << 16; - if (y == 0) { - n += 16; - } else { - x = y; - } - - y = x << 8; - if (y == 0) { - n += 8; - } - return (int) (((left >>> n) & 0xFFL) - ((right >>> n) & 0xFFL)); - } else { - return Long.compareUnsigned(left, right); - } - } - } - for (int i = wordMax; i < len; i++) { - // In UTF-8, the byte should be unsigned, so we should compare them as unsigned int. - int res = (getByte(i) & 0xFF) - (BinaryOperations.getByte(otherBase, otherOffset + i) & 0xFF); - if (res != 0) { - return res; - } + private static final boolean IS_LITTLE_ENDIAN = + ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN; + + private IBinaryObject binaryObject; + private long offset; + private int numBytes; + + private transient int hashCode = 0; + + private static byte[] bytesOfCodePointInUTF8 = { + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, // 0x00..0x0F + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, // 0x10..0x1F + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, // 0x20..0x2F + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, // 0x30..0x3F + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, // 0x40..0x4F + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, // 0x50..0x5F + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, // 0x60..0x6F + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, // 0x70..0x7F + // Continuation bytes cannot appear as the first byte + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, // 0x80..0x8F + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, // 0x90..0x9F + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, // 0xA0..0xAF + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, // 0xB0..0xBF + 0, + 0, // 0xC0..0xC1 - disallowed in UTF-8 + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, // 0xC2..0xCF + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, // 0xD0..0xDF + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, // 0xE0..0xEF + 4, + 4, + 4, + 4, + 4, // 0xF0..0xF4 + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 // 0xF5..0xFF - disallowed in UTF-8 + }; + + public static final BinaryString EMPTY_STRING = BinaryString.fromString(""); + + public BinaryString() {} + + public BinaryString(IBinaryObject binaryObject, long offset, int numBytes) { + this.binaryObject = binaryObject; + this.offset = offset; + this.numBytes = numBytes; + } + + public static BinaryString fromString(String string) { + byte[] bytes = string.getBytes(StandardCharsets.UTF_8); + return new BinaryString(HeapBinaryObject.of(bytes), 0, bytes.length); + } + + public static BinaryString fromBytes(byte[] bytes) { + return new BinaryString(HeapBinaryObject.of(bytes), 0, bytes.length); + } + + public byte[] getBytes() { + if (offset == 0 + && binaryObject instanceof HeapBinaryObject + && (((HeapBinaryObject) binaryObject).getBaseObject()).length == numBytes) { + return ((HeapBinaryObject) binaryObject).getBaseObject(); + } + byte[] bytes = new byte[numBytes]; + copyMemory(binaryObject, offset, bytes, 0, numBytes); + return bytes; + } + + @Override + public String toString() { + return new String(binaryObject.toBytes(), (int) offset, numBytes, StandardCharsets.UTF_8); + } + + public IBinaryObject getBinaryObject() { + return binaryObject; + } + + public long getOffset() { + return offset; + } + + public int getNumBytes() { + return numBytes; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof BinaryString)) { + return false; + } + BinaryString that = (BinaryString) o; + if (numBytes != that.numBytes) { + return false; + } + return BinaryOperations.arrayEquals( + binaryObject, offset, that.binaryObject, that.offset, numBytes); + } + + @Override + public int hashCode() { + if (hashCode == 0) { + hashCode = hashUnsafeBytes(binaryObject, offset, numBytes, 42); + } + return hashCode; + } + + @Override + public int compareTo(BinaryString other) { + int len = Math.min(numBytes, other.numBytes); + int wordMax = (len / 8) * 8; + long otherOffset = other.offset; + IBinaryObject otherBase = other.binaryObject; + for (int i = 0; i < wordMax; i += 8) { + long left = getLong(binaryObject, offset + i); + long right = getLong(otherBase, otherOffset + i); + if (left != right) { + if (IS_LITTLE_ENDIAN) { + // Use binary search + int n = 0; + int y; + long diff = left ^ right; + int x = (int) diff; + if (x == 0) { + x = (int) (diff >>> 32); + n = 32; + } + + y = x << 16; + if (y == 0) { + n += 16; + } else { + x = y; + } + + y = x << 8; + if (y == 0) { + n += 8; + } + return (int) (((left >>> n) & 0xFFL) - ((right >>> n) & 0xFFL)); + } else { + return Long.compareUnsigned(left, right); } - return numBytes - other.numBytes; + } } - - public byte getByte(int i) { - return BinaryOperations.getByte(binaryObject, offset + i); - } - - public BinaryString[] split(BinaryString pattern, int limit) { - // Java String's split method supports "ignore empty string" behavior when the limit is 0 - // whereas other languages do not. To avoid this java specific behavior, we fall back to - // -1 when the limit is 0. - if (limit == 0) { - limit = -1; - } - String[] splits = toString().split(pattern.toString(), limit); - BinaryString[] res = new BinaryString[splits.length]; - for (int i = 0; i < res.length; i++) { - res[i] = fromString(splits[i]); - } + for (int i = wordMax; i < len; i++) { + // In UTF-8, the byte should be unsigned, so we should compare them as unsigned int. + int res = (getByte(i) & 0xFF) - (BinaryOperations.getByte(otherBase, otherOffset + i) & 0xFF); + if (res != 0) { return res; - } - - /** - * Get the length of chars in this string. - */ - public int getLength() { - int len = 0; - for (int i = 0; i < numBytes; i += numBytesForFirstByte(getByte(i))) { - len += 1; - } - return len; - } - - public boolean contains(final BinaryString substring) { - if (substring.numBytes == 0) { - return true; - } - - byte first = substring.getByte(0); - for (int i = 0; i <= numBytes - substring.numBytes; i++) { - if (getByte(i) == first && matchAt(substring, i)) { - return true; - } - } - return false; - } - - public static BinaryString concat(BinaryString... inputs) { - long totalLength = 0; - for (BinaryString input : inputs) { - if (Objects.isNull(input)) { - continue; - } - totalLength += input.numBytes; - } - - byte[] result = new byte[Math.toIntExact(totalLength)]; - int offset = 0; - for (BinaryString input : inputs) { - if (Objects.isNull(input)) { - continue; - } - int len = input.numBytes; - copyMemory(input.binaryObject, input.offset, result, offset, len); - offset += len; - } - return fromBytes(result); - } - - public static BinaryString concatWs(BinaryString separator, BinaryString... inputs) { - if (Objects.isNull(separator)) { - separator = EMPTY_STRING; - } - - // total number of bytes from inputs - long numInputBytes = 0L; - int numInputs = inputs.length; - for (BinaryString input : inputs) { - if (Objects.nonNull(input)) { - numInputBytes += input.numBytes; - } - } - - int resultSize = - Math.toIntExact(numInputBytes + (numInputs - 1) * (long) separator.numBytes); - byte[] result = new byte[resultSize]; - int offset = 0; - - for (int i = 0, j = 0; i < inputs.length; i++) { - if (Objects.nonNull(inputs[i])) { - int len = inputs[i].numBytes; - copyMemory(inputs[i].binaryObject, inputs[i].offset, result, - offset, len); - offset += len; - } - - j++; - // Add separator if this is not the last input. - if (j < numInputs) { - copyMemory(separator.binaryObject, separator.offset, result, - offset, separator.numBytes); - offset += separator.numBytes; - } - } - return fromBytes(result); - } - - public int indexOf(BinaryString s, int start) { - if (s.numBytes == 0) { - return 0; - } - - // locate to the start position. - int i = 0; // position in byte - int c = 0; // position in character - while (i < numBytes && c < start) { - i += numBytesForFirstByte(getByte(i)); - c += 1; - } - - do { - if (i + s.numBytes > numBytes) { - return -1; - } - if (BinaryOperations.arrayEquals(binaryObject, offset + i, s.binaryObject, s.offset, s.numBytes)) { - return c; - } - i += numBytesForFirstByte(getByte(i)); - c += 1; - } while (i < numBytes); - + } + } + return numBytes - other.numBytes; + } + + public byte getByte(int i) { + return BinaryOperations.getByte(binaryObject, offset + i); + } + + public BinaryString[] split(BinaryString pattern, int limit) { + // Java String's split method supports "ignore empty string" behavior when the limit is 0 + // whereas other languages do not. To avoid this java specific behavior, we fall back to + // -1 when the limit is 0. + if (limit == 0) { + limit = -1; + } + String[] splits = toString().split(pattern.toString(), limit); + BinaryString[] res = new BinaryString[splits.length]; + for (int i = 0; i < res.length; i++) { + res[i] = fromString(splits[i]); + } + return res; + } + + /** Get the length of chars in this string. */ + public int getLength() { + int len = 0; + for (int i = 0; i < numBytes; i += numBytesForFirstByte(getByte(i))) { + len += 1; + } + return len; + } + + public boolean contains(final BinaryString substring) { + if (substring.numBytes == 0) { + return true; + } + + byte first = substring.getByte(0); + for (int i = 0; i <= numBytes - substring.numBytes; i++) { + if (getByte(i) == first && matchAt(substring, i)) { + return true; + } + } + return false; + } + + public static BinaryString concat(BinaryString... inputs) { + long totalLength = 0; + for (BinaryString input : inputs) { + if (Objects.isNull(input)) { + continue; + } + totalLength += input.numBytes; + } + + byte[] result = new byte[Math.toIntExact(totalLength)]; + int offset = 0; + for (BinaryString input : inputs) { + if (Objects.isNull(input)) { + continue; + } + int len = input.numBytes; + copyMemory(input.binaryObject, input.offset, result, offset, len); + offset += len; + } + return fromBytes(result); + } + + public static BinaryString concatWs(BinaryString separator, BinaryString... inputs) { + if (Objects.isNull(separator)) { + separator = EMPTY_STRING; + } + + // total number of bytes from inputs + long numInputBytes = 0L; + int numInputs = inputs.length; + for (BinaryString input : inputs) { + if (Objects.nonNull(input)) { + numInputBytes += input.numBytes; + } + } + + int resultSize = Math.toIntExact(numInputBytes + (numInputs - 1) * (long) separator.numBytes); + byte[] result = new byte[resultSize]; + int offset = 0; + + for (int i = 0, j = 0; i < inputs.length; i++) { + if (Objects.nonNull(inputs[i])) { + int len = inputs[i].numBytes; + copyMemory(inputs[i].binaryObject, inputs[i].offset, result, offset, len); + offset += len; + } + + j++; + // Add separator if this is not the last input. + if (j < numInputs) { + copyMemory(separator.binaryObject, separator.offset, result, offset, separator.numBytes); + offset += separator.numBytes; + } + } + return fromBytes(result); + } + + public int indexOf(BinaryString s, int start) { + if (s.numBytes == 0) { + return 0; + } + + // locate to the start position. + int i = 0; // position in byte + int c = 0; // position in character + while (i < numBytes && c < start) { + i += numBytesForFirstByte(getByte(i)); + c += 1; + } + + do { + if (i + s.numBytes > numBytes) { return -1; - } - - public boolean matchAt(final BinaryString s, int pos) { - if (s.numBytes + pos > numBytes || pos < 0) { - return false; - } - return BinaryOperations.arrayEquals(binaryObject, offset + pos, - s.binaryObject, s.offset, s.numBytes); - } - - public boolean startsWith(final BinaryString prefix) { - return matchAt(prefix, 0); - } - - public boolean endsWith(final BinaryString suffix) { - return matchAt(suffix, numBytes - suffix.numBytes); - } - - private static int numBytesForFirstByte(final byte b) { - final int offset = b & 0xFF; - byte numBytes = bytesOfCodePointInUTF8[offset]; - return (numBytes == 0) ? 1 : numBytes; // Skip the first byte disallowed in UTF-8 - } - - public BinaryString substring(final int start) { - return substring(start, getLength()); - } - - /** - * This method is an adaptation of Spark's BinaryString#substring. - */ - public BinaryString substring(final int start, final int end) { - if (end <= start || start >= numBytes) { - return EMPTY_STRING; - } - - int i = 0; - int c = 0; - while (i < numBytes && c < start) { - i += numBytesForFirstByte(getByte(i)); - c += 1; - } - - int j = i; - while (i < numBytes && c < end) { - i += numBytesForFirstByte(getByte(i)); - c += 1; - } - - if (i > j) { - byte[] bytes = new byte[i - j]; - copyMemory(binaryObject, offset + j, bytes, 0, i - j); - return fromBytes(bytes); - } else { - return EMPTY_STRING; - } - } - - public BinaryString reverse() { - byte[] bytes = new byte[numBytes]; - - for (int i = 0; i < numBytes; i++) { - bytes[i] = getByte(numBytes - i - 1); - } - return fromBytes(bytes); - } - - @Override - public void write(Kryo kryo, Output output) { - output.writeInt(numBytes); - output.write(binaryObject.toBytes(), (int) offset, numBytes); - } - - @Override - public void read(Kryo kryo, Input input) { - int size = input.readInt(); - byte[] bytes = new byte[size]; - input.read(bytes); - this.binaryObject = HeapBinaryObject.of(bytes); - this.offset = 0; - this.numBytes = bytes.length; - } - - /** - * * This method is an adaptation of Spark's org.apache.spark.unsafe.hash.Murmur3_x86_32. - */ - private static int hashUnsafeBytes(IBinaryObject base, long offset, int lengthInBytes, int seed) { - int lengthAligned = lengthInBytes - lengthInBytes % 4; - int h1 = hashBytesByInt(base, offset, lengthAligned, seed); - for (int i = lengthAligned; i < lengthInBytes; i++) { - int halfWord = BinaryOperations.getByte(base, offset + i); - int k1 = mixK1(halfWord); - h1 = mixH1(h1, k1); - } - return fmix(h1, lengthInBytes); - } - - private static int hashBytesByInt(IBinaryObject base, long offset, int lengthInBytes, int seed) { - assert (lengthInBytes % 4 == 0); - int h1 = seed; - for (int i = 0; i < lengthInBytes; i += 4) { - int halfWord = BinaryOperations.getInt(base, offset + i); - if (!IS_LITTLE_ENDIAN) { - halfWord = Integer.reverseBytes(halfWord); - } - h1 = mixH1(h1, mixK1(halfWord)); - } - return h1; - } - - private static int mixK1(int k1) { - k1 *= 0xcc9e2d51; - k1 = Integer.rotateLeft(k1, 15); - k1 *= 0x1b873593; - return k1; - } - - private static int mixH1(int h1, int k1) { - h1 ^= k1; - h1 = Integer.rotateLeft(h1, 13); - h1 = h1 * 5 + 0xe6546b64; - return h1; - } - - private static int fmix(int h1, int length) { - h1 ^= length; - h1 ^= h1 >>> 16; - h1 *= 0x85ebca6b; - h1 ^= h1 >>> 13; - h1 *= 0xc2b2ae35; - h1 ^= h1 >>> 16; - return h1; - } + } + if (BinaryOperations.arrayEquals( + binaryObject, offset + i, s.binaryObject, s.offset, s.numBytes)) { + return c; + } + i += numBytesForFirstByte(getByte(i)); + c += 1; + } while (i < numBytes); + + return -1; + } + + public boolean matchAt(final BinaryString s, int pos) { + if (s.numBytes + pos > numBytes || pos < 0) { + return false; + } + return BinaryOperations.arrayEquals( + binaryObject, offset + pos, s.binaryObject, s.offset, s.numBytes); + } + + public boolean startsWith(final BinaryString prefix) { + return matchAt(prefix, 0); + } + + public boolean endsWith(final BinaryString suffix) { + return matchAt(suffix, numBytes - suffix.numBytes); + } + + private static int numBytesForFirstByte(final byte b) { + final int offset = b & 0xFF; + byte numBytes = bytesOfCodePointInUTF8[offset]; + return (numBytes == 0) ? 1 : numBytes; // Skip the first byte disallowed in UTF-8 + } + + public BinaryString substring(final int start) { + return substring(start, getLength()); + } + + /** This method is an adaptation of Spark's BinaryString#substring. */ + public BinaryString substring(final int start, final int end) { + if (end <= start || start >= numBytes) { + return EMPTY_STRING; + } + + int i = 0; + int c = 0; + while (i < numBytes && c < start) { + i += numBytesForFirstByte(getByte(i)); + c += 1; + } + + int j = i; + while (i < numBytes && c < end) { + i += numBytesForFirstByte(getByte(i)); + c += 1; + } + + if (i > j) { + byte[] bytes = new byte[i - j]; + copyMemory(binaryObject, offset + j, bytes, 0, i - j); + return fromBytes(bytes); + } else { + return EMPTY_STRING; + } + } + + public BinaryString reverse() { + byte[] bytes = new byte[numBytes]; + + for (int i = 0; i < numBytes; i++) { + bytes[i] = getByte(numBytes - i - 1); + } + return fromBytes(bytes); + } + + @Override + public void write(Kryo kryo, Output output) { + output.writeInt(numBytes); + output.write(binaryObject.toBytes(), (int) offset, numBytes); + } + + @Override + public void read(Kryo kryo, Input input) { + int size = input.readInt(); + byte[] bytes = new byte[size]; + input.read(bytes); + this.binaryObject = HeapBinaryObject.of(bytes); + this.offset = 0; + this.numBytes = bytes.length; + } + + /** * This method is an adaptation of Spark's org.apache.spark.unsafe.hash.Murmur3_x86_32. */ + private static int hashUnsafeBytes(IBinaryObject base, long offset, int lengthInBytes, int seed) { + int lengthAligned = lengthInBytes - lengthInBytes % 4; + int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + for (int i = lengthAligned; i < lengthInBytes; i++) { + int halfWord = BinaryOperations.getByte(base, offset + i); + int k1 = mixK1(halfWord); + h1 = mixH1(h1, k1); + } + return fmix(h1, lengthInBytes); + } + + private static int hashBytesByInt(IBinaryObject base, long offset, int lengthInBytes, int seed) { + assert (lengthInBytes % 4 == 0); + int h1 = seed; + for (int i = 0; i < lengthInBytes; i += 4) { + int halfWord = BinaryOperations.getInt(base, offset + i); + if (!IS_LITTLE_ENDIAN) { + halfWord = Integer.reverseBytes(halfWord); + } + h1 = mixH1(h1, mixK1(halfWord)); + } + return h1; + } + + private static int mixK1(int k1) { + k1 *= 0xcc9e2d51; + k1 = Integer.rotateLeft(k1, 15); + k1 *= 0x1b873593; + return k1; + } + + private static int mixH1(int h1, int k1) { + h1 ^= k1; + h1 = Integer.rotateLeft(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + return h1; + } + + private static int fmix(int h1, int length) { + h1 ^= length; + h1 ^= h1 >>> 16; + h1 *= 0x85ebca6b; + h1 ^= h1 >>> 13; + h1 *= 0xc2b2ae35; + h1 ^= h1 >>> 16; + return h1; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/binary/HeapBinaryObject.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/binary/HeapBinaryObject.java index 6cb243d42..68d8dea79 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/binary/HeapBinaryObject.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/binary/HeapBinaryObject.java @@ -19,94 +19,94 @@ package org.apache.geaflow.common.binary; +import java.util.Arrays; + +import org.apache.geaflow.common.exception.GeaflowRuntimeException; + import com.esotericsoftware.kryo.Kryo; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; -import java.util.Arrays; -import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class HeapBinaryObject implements IBinaryObject { - private byte[] bytes; - - public HeapBinaryObject() { + private byte[] bytes; - } + public HeapBinaryObject() {} - private HeapBinaryObject(byte[] bytes) { - this.bytes = bytes; - } + private HeapBinaryObject(byte[] bytes) { + this.bytes = bytes; + } - public static HeapBinaryObject of(byte[] bytes) { - return new HeapBinaryObject(bytes); - } + public static HeapBinaryObject of(byte[] bytes) { + return new HeapBinaryObject(bytes); + } - @Override - public byte[] getBaseObject() { - return bytes; - } + @Override + public byte[] getBaseObject() { + return bytes; + } - @Override - public long getAbsoluteAddress(long address) { - if (address < 0 || address >= size()) { - throw new GeaflowRuntimeException("Illegal address: " + address + ", is out of visit " - + "range[0," + size() + ")"); - } - return BinaryOperations.BYTE_ARRAY_OFFSET + address; + @Override + public long getAbsoluteAddress(long address) { + if (address < 0 || address >= size()) { + throw new GeaflowRuntimeException( + "Illegal address: " + address + ", is out of visit " + "range[0," + size() + ")"); } - - @Override - public int size() { - return bytes.length; - } - - @Override - public void release() { - bytes = null; + return BinaryOperations.BYTE_ARRAY_OFFSET + address; + } + + @Override + public int size() { + return bytes.length; + } + + @Override + public void release() { + bytes = null; + } + + @Override + public boolean isReleased() { + return bytes == null; + } + + @Override + public byte[] toBytes() { + return bytes; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public boolean isReleased() { - return bytes == null; - } - - @Override - public byte[] toBytes() { - return bytes; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof HeapBinaryObject)) { - return false; - } - HeapBinaryObject that = (HeapBinaryObject) o; - return Arrays.equals(bytes, that.bytes); - } - - @Override - public int hashCode() { - return Arrays.hashCode(bytes); - } - - @Override - public void write(Kryo kryo, Output output) { - output.writeVarInt(bytes.length, true); - output.write(bytes); - } - - @Override - public void read(Kryo kryo, Input input) { - int length = input.readVarInt(true); - this.bytes = new byte[length]; - input.read(bytes); - } - - @Override - public String toString() { - return "HeapBinaryObject{" + "bytes=" + Arrays.toString(bytes) + '}'; + if (!(o instanceof HeapBinaryObject)) { + return false; } + HeapBinaryObject that = (HeapBinaryObject) o; + return Arrays.equals(bytes, that.bytes); + } + + @Override + public int hashCode() { + return Arrays.hashCode(bytes); + } + + @Override + public void write(Kryo kryo, Output output) { + output.writeVarInt(bytes.length, true); + output.write(bytes); + } + + @Override + public void read(Kryo kryo, Input input) { + int length = input.readVarInt(true); + this.bytes = new byte[length]; + input.read(bytes); + } + + @Override + public String toString() { + return "HeapBinaryObject{" + "bytes=" + Arrays.toString(bytes) + '}'; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/binary/IBinaryObject.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/binary/IBinaryObject.java index 9a16f353d..0ceb0c9d0 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/binary/IBinaryObject.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/binary/IBinaryObject.java @@ -19,40 +19,31 @@ package org.apache.geaflow.common.binary; -import com.esotericsoftware.kryo.KryoSerializable; import java.io.Serializable; +import com.esotericsoftware.kryo.KryoSerializable; + public interface IBinaryObject extends KryoSerializable, Serializable { - /** - * Get base binary object. - */ - Object getBaseObject(); - - /** - * Get absolute address. - * - * @param address relative address. - */ - long getAbsoluteAddress(long address); - - /** - * Binary size. - */ - int size(); - - /** - * Release memory. - */ - void release(); - - /** - * Judge release or not. - */ - boolean isReleased(); - - /** - * Convert to byte array. - */ - byte[] toBytes(); + /** Get base binary object. */ + Object getBaseObject(); + + /** + * Get absolute address. + * + * @param address relative address. + */ + long getAbsoluteAddress(long address); + + /** Binary size. */ + int size(); + + /** Release memory. */ + void release(); + + /** Judge release or not. */ + boolean isReleased(); + + /** Convert to byte array. */ + byte[] toBytes(); } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/blocking/map/BlockingMap.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/blocking/map/BlockingMap.java index c15fbcc42..0f4e186ce 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/blocking/map/BlockingMap.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/blocking/map/BlockingMap.java @@ -24,26 +24,28 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class BlockingMap { - private Map> map = new ConcurrentHashMap<>(); + private Map> map = new ConcurrentHashMap<>(); - private BlockingQueue getQueue(K key) { - return map.computeIfAbsent(key, k -> new ArrayBlockingQueue<>(1)); - } + private BlockingQueue getQueue(K key) { + return map.computeIfAbsent(key, k -> new ArrayBlockingQueue<>(1)); + } - public void put(K key, V value) { - if (!getQueue(key).offer(value)) { - throw new GeaflowRuntimeException(String.format("BlockingMap offer element (%s, %s) failed.", key, value)); - } + public void put(K key, V value) { + if (!getQueue(key).offer(value)) { + throw new GeaflowRuntimeException( + String.format("BlockingMap offer element (%s, %s) failed.", key, value)); } + } - public V get(K key) throws InterruptedException { - return getQueue(key).take(); - } + public V get(K key) throws InterruptedException { + return getQueue(key).take(); + } - public V get(K key, long timeout, TimeUnit unit) throws InterruptedException { - return getQueue(key).poll(timeout, unit); - } + public V get(K key, long timeout, TimeUnit unit) throws InterruptedException { + return getQueue(key).poll(timeout, unit); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/ConfigHelper.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/ConfigHelper.java index f3057b169..2c9c662db 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/ConfigHelper.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/ConfigHelper.java @@ -20,88 +20,88 @@ package org.apache.geaflow.common.config; import java.util.Map; + import org.apache.geaflow.common.exception.ConfigException; public class ConfigHelper { - public static int getInteger(Map config, String configKey) { - if (config.containsKey(configKey)) { - return Integer.valueOf(String.valueOf(config.get(configKey))); - } else { - throw new ConfigException(configKey); - } + public static int getInteger(Map config, String configKey) { + if (config.containsKey(configKey)) { + return Integer.valueOf(String.valueOf(config.get(configKey))); + } else { + throw new ConfigException(configKey); } + } - public static int getIntegerOrDefault(Map config, String configKey, int defaultValue) { - if (config.containsKey(configKey)) { - return Integer.valueOf(String.valueOf(config.get(configKey))); - } else { - return defaultValue; - } + public static int getIntegerOrDefault(Map config, String configKey, int defaultValue) { + if (config.containsKey(configKey)) { + return Integer.valueOf(String.valueOf(config.get(configKey))); + } else { + return defaultValue; } + } - public static long getLong(Map config, String configKey) { - if (config.containsKey(configKey)) { - return Long.valueOf(String.valueOf(config.get(configKey))); - } else { - throw new ConfigException(configKey); - } + public static long getLong(Map config, String configKey) { + if (config.containsKey(configKey)) { + return Long.valueOf(String.valueOf(config.get(configKey))); + } else { + throw new ConfigException(configKey); } + } - public static long getLongOrDefault(Map config, String configKey, long defaultValue) { - if (config.containsKey(configKey)) { - return Long.valueOf(String.valueOf(config.get(configKey))); - } else { - return defaultValue; - } + public static long getLongOrDefault(Map config, String configKey, long defaultValue) { + if (config.containsKey(configKey)) { + return Long.valueOf(String.valueOf(config.get(configKey))); + } else { + return defaultValue; } + } - public static boolean getBoolean(Map config, String configKey) { - if (config.containsKey(configKey)) { - return Boolean.valueOf(String.valueOf(config.get(configKey))); - } else { - throw new ConfigException(configKey); - } + public static boolean getBoolean(Map config, String configKey) { + if (config.containsKey(configKey)) { + return Boolean.valueOf(String.valueOf(config.get(configKey))); + } else { + throw new ConfigException(configKey); } + } - public static boolean getBooleanOrDefault(Map config, String configKey, boolean defaultValue) { - if (config.containsKey(configKey)) { - return Boolean.valueOf(String.valueOf(config.get(configKey))); - } else { - return defaultValue; - } + public static boolean getBooleanOrDefault(Map config, String configKey, boolean defaultValue) { + if (config.containsKey(configKey)) { + return Boolean.valueOf(String.valueOf(config.get(configKey))); + } else { + return defaultValue; } + } - public static String getString(Map config, String configKey) { - if (config.containsKey(configKey)) { - return String.valueOf(config.get(configKey)); - } else { - throw new ConfigException("Missing config:'" + configKey + "'"); - } + public static String getString(Map config, String configKey) { + if (config.containsKey(configKey)) { + return String.valueOf(config.get(configKey)); + } else { + throw new ConfigException("Missing config:'" + configKey + "'"); } + } - public static String getStringOrDefault(Map config, String configKey, String defaultValue) { - if (config.containsKey(configKey)) { - return String.valueOf(config.get(configKey)); - } else { - return defaultValue; - } + public static String getStringOrDefault(Map config, String configKey, String defaultValue) { + if (config.containsKey(configKey)) { + return String.valueOf(config.get(configKey)); + } else { + return defaultValue; } + } - public static Double getDoubleOrDefault(Map config, String configKey, double defaultValue) { - if (config.containsKey(configKey)) { - return Double.valueOf(String.valueOf(config.get(configKey))); - } else { - return defaultValue; - } + public static Double getDoubleOrDefault(Map config, String configKey, double defaultValue) { + if (config.containsKey(configKey)) { + return Double.valueOf(String.valueOf(config.get(configKey))); + } else { + return defaultValue; } + } - public static Double getDouble(Map config, String configKey) { - if (config.containsKey(configKey)) { - return Double.valueOf(String.valueOf(config.get(configKey))); - } else { - throw new ConfigException(configKey); - } + public static Double getDouble(Map config, String configKey) { + if (config.containsKey(configKey)) { + return Double.valueOf(String.valueOf(config.get(configKey))); + } else { + throw new ConfigException(configKey); } - + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/ConfigKey.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/ConfigKey.java index 2035b3c9a..85b71673f 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/ConfigKey.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/ConfigKey.java @@ -23,38 +23,38 @@ public class ConfigKey implements Serializable { - private final String key; - private String description; - private Object defaultValue; - - public ConfigKey(String key) { - this.key = key; - } - - public ConfigKey(String key, Object defaultValue) { - this.key = key; - this.defaultValue = defaultValue; - } - - public ConfigKey description(String description) { - this.description = description; - return this; - } - - public ConfigKey defaultValue(Object defaultValue) { - this.defaultValue = defaultValue; - return this; - } - - public String getKey() { - return key; - } - - public String getDescription() { - return description; - } - - public Object getDefaultValue() { - return defaultValue; - } + private final String key; + private String description; + private Object defaultValue; + + public ConfigKey(String key) { + this.key = key; + } + + public ConfigKey(String key, Object defaultValue) { + this.key = key; + this.defaultValue = defaultValue; + } + + public ConfigKey description(String description) { + this.description = description; + return this; + } + + public ConfigKey defaultValue(Object defaultValue) { + this.defaultValue = defaultValue; + return this; + } + + public String getKey() { + return key; + } + + public String getDescription() { + return description; + } + + public Object getDefaultValue() { + return defaultValue; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/ConfigKeys.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/ConfigKeys.java index 0efeb425d..3420264bc 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/ConfigKeys.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/ConfigKeys.java @@ -23,26 +23,24 @@ public class ConfigKeys implements Serializable { - public static ConfigKeyBuilder key(String key) { - return new ConfigKeyBuilder(key); - } - - public static class ConfigKeyBuilder { - - private final String key; + public static ConfigKeyBuilder key(String key) { + return new ConfigKeyBuilder(key); + } - public ConfigKeyBuilder(String key) { - this.key = key; - } + public static class ConfigKeyBuilder { - public ConfigKey defaultValue(Object defaultValue) { - return new ConfigKey(key, defaultValue); - } + private final String key; - public ConfigKey noDefaultValue() { - return new ConfigKey(key); - } + public ConfigKeyBuilder(String key) { + this.key = key; + } + public ConfigKey defaultValue(Object defaultValue) { + return new ConfigKey(key, defaultValue); } + public ConfigKey noDefaultValue() { + return new ConfigKey(key); + } + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/Configuration.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/Configuration.java index 1611b2b65..97815540a 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/Configuration.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/Configuration.java @@ -25,153 +25,152 @@ public class Configuration implements Serializable { - private String masterId; - private final Map config; + private String masterId; + private final Map config; - public Configuration() { - this.config = new HashMap<>(); - } - - public Configuration(Map config) { - this.config = config; - } - - public Map getConfigMap() { - return config; - } - - public String getMasterId() { - return masterId; - } - - public void setMasterId(String masterId) { - this.masterId = masterId; - } - - public boolean contains(ConfigKey key) { - return config.containsKey(key.getKey()); - } - - public boolean contains(String key) { - return config.containsKey(key); - } - - public void put(String key, String value) { - config.put(key, value); - } - - public void put(ConfigKey key, String value) { - config.put(key.getKey(), value); - } - - public void putAll(Map map) { - config.putAll(map); - } - - public String getString(ConfigKey configKey) { - return getString(configKey, config); - } - - public String getString(ConfigKey configKey, String defaultValue) { - return getString(configKey, defaultValue, config); - } - - public String getString(String configKey) { - return config.get(configKey); - } - - public String getString(String configKey, String defaultValue) { - return ConfigHelper.getStringOrDefault(config, configKey, defaultValue); - } - - public int getInteger(ConfigKey configKey) { - return getInteger(configKey, config); - } - - public int getInteger(ConfigKey configKey, int defaultValue) { - return getInteger(configKey, defaultValue, config); - } - - public int getInteger(String configKey, int defaultValue) { - return ConfigHelper.getIntegerOrDefault(config, configKey, defaultValue); - } - - public double getDouble(ConfigKey configKey) { - return getDouble(configKey, config); - } - - public boolean getBoolean(ConfigKey configKey) { - return getBoolean(configKey, config); - } - - public long getLong(ConfigKey configKey) { - return getLong(configKey, config); - } - - public long getLong(String configKey, long defaultValue) { - return ConfigHelper.getLongOrDefault(config, configKey, defaultValue); - } - - public static String getString(ConfigKey configKey, Map config) { - if (configKey.getDefaultValue() != null) { - return ConfigHelper.getStringOrDefault(config, configKey.getKey(), - String.valueOf(configKey.getDefaultValue())); - } else { - return ConfigHelper.getString(config, configKey.getKey()); - } - } - - public static String getString(ConfigKey configKey, String defaultValue, - Map config) { - return ConfigHelper.getStringOrDefault(config, configKey.getKey(), defaultValue); - } - - public static boolean getBoolean(ConfigKey configKey, Map config) { - if (configKey.getDefaultValue() != null) { - return ConfigHelper.getBooleanOrDefault(config, configKey.getKey(), - (Boolean) configKey.getDefaultValue()); - } else { - return ConfigHelper.getBoolean(config, configKey.getKey()); - } - } - - public static int getInteger(ConfigKey configKey, Map config) { - if (configKey.getDefaultValue() != null) { - return ConfigHelper.getIntegerOrDefault(config, configKey.getKey(), - (Integer) configKey.getDefaultValue()); - } else { - return ConfigHelper.getInteger(config, configKey.getKey()); - } - } + public Configuration() { + this.config = new HashMap<>(); + } + + public Configuration(Map config) { + this.config = config; + } - public static int getInteger(ConfigKey configKey, int defaultValue, - Map config) { - return ConfigHelper.getIntegerOrDefault(config, configKey.getKey(), defaultValue); - } - - public static long getLong(ConfigKey configKey, Map config) { - if (configKey.getDefaultValue() != null) { - return ConfigHelper - .getLongOrDefault(config, configKey.getKey(), (Long) configKey.getDefaultValue()); - } else { - return ConfigHelper.getLong(config, configKey.getKey()); - } - } + public Map getConfigMap() { + return config; + } + + public String getMasterId() { + return masterId; + } - public static long getLong(ConfigKey configKey, long defaultValue, Map config) { - return ConfigHelper.getLongOrDefault(config, configKey.getKey(), defaultValue); - } + public void setMasterId(String masterId) { + this.masterId = masterId; + } - public static double getDouble(ConfigKey configKey, Map config) { - if (configKey.getDefaultValue() != null) { - return ConfigHelper.getDoubleOrDefault(config, configKey.getKey(), - (Double) configKey.getDefaultValue()); - } - return ConfigHelper.getDouble(config, configKey.getKey()); - } - - @Override - public String toString() { - return "Configuration{" + config + '}'; - } + public boolean contains(ConfigKey key) { + return config.containsKey(key.getKey()); + } + + public boolean contains(String key) { + return config.containsKey(key); + } + + public void put(String key, String value) { + config.put(key, value); + } + + public void put(ConfigKey key, String value) { + config.put(key.getKey(), value); + } + + public void putAll(Map map) { + config.putAll(map); + } + + public String getString(ConfigKey configKey) { + return getString(configKey, config); + } + + public String getString(ConfigKey configKey, String defaultValue) { + return getString(configKey, defaultValue, config); + } + + public String getString(String configKey) { + return config.get(configKey); + } + + public String getString(String configKey, String defaultValue) { + return ConfigHelper.getStringOrDefault(config, configKey, defaultValue); + } + + public int getInteger(ConfigKey configKey) { + return getInteger(configKey, config); + } + + public int getInteger(ConfigKey configKey, int defaultValue) { + return getInteger(configKey, defaultValue, config); + } + + public int getInteger(String configKey, int defaultValue) { + return ConfigHelper.getIntegerOrDefault(config, configKey, defaultValue); + } + + public double getDouble(ConfigKey configKey) { + return getDouble(configKey, config); + } + + public boolean getBoolean(ConfigKey configKey) { + return getBoolean(configKey, config); + } + + public long getLong(ConfigKey configKey) { + return getLong(configKey, config); + } + + public long getLong(String configKey, long defaultValue) { + return ConfigHelper.getLongOrDefault(config, configKey, defaultValue); + } + + public static String getString(ConfigKey configKey, Map config) { + if (configKey.getDefaultValue() != null) { + return ConfigHelper.getStringOrDefault( + config, configKey.getKey(), String.valueOf(configKey.getDefaultValue())); + } else { + return ConfigHelper.getString(config, configKey.getKey()); + } + } + + public static String getString( + ConfigKey configKey, String defaultValue, Map config) { + return ConfigHelper.getStringOrDefault(config, configKey.getKey(), defaultValue); + } + + public static boolean getBoolean(ConfigKey configKey, Map config) { + if (configKey.getDefaultValue() != null) { + return ConfigHelper.getBooleanOrDefault( + config, configKey.getKey(), (Boolean) configKey.getDefaultValue()); + } else { + return ConfigHelper.getBoolean(config, configKey.getKey()); + } + } + + public static int getInteger(ConfigKey configKey, Map config) { + if (configKey.getDefaultValue() != null) { + return ConfigHelper.getIntegerOrDefault( + config, configKey.getKey(), (Integer) configKey.getDefaultValue()); + } else { + return ConfigHelper.getInteger(config, configKey.getKey()); + } + } + + public static int getInteger(ConfigKey configKey, int defaultValue, Map config) { + return ConfigHelper.getIntegerOrDefault(config, configKey.getKey(), defaultValue); + } + + public static long getLong(ConfigKey configKey, Map config) { + if (configKey.getDefaultValue() != null) { + return ConfigHelper.getLongOrDefault( + config, configKey.getKey(), (Long) configKey.getDefaultValue()); + } else { + return ConfigHelper.getLong(config, configKey.getKey()); + } + } + + public static long getLong(ConfigKey configKey, long defaultValue, Map config) { + return ConfigHelper.getLongOrDefault(config, configKey.getKey(), defaultValue); + } + + public static double getDouble(ConfigKey configKey, Map config) { + if (configKey.getDefaultValue() != null) { + return ConfigHelper.getDoubleOrDefault( + config, configKey.getKey(), (Double) configKey.getDefaultValue()); + } + return ConfigHelper.getDouble(config, configKey.getKey()); + } + + @Override + public String toString() { + return "Configuration{" + config + '}'; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/ConnectorConfigKeys.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/ConnectorConfigKeys.java index 039f07be3..c943ec63e 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/ConnectorConfigKeys.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/ConnectorConfigKeys.java @@ -20,106 +20,106 @@ package org.apache.geaflow.common.config.keys; import java.io.Serializable; + import org.apache.geaflow.common.config.ConfigKey; import org.apache.geaflow.common.config.ConfigKeys; public class ConnectorConfigKeys implements Serializable { - /************************************************* - * Connectors Common Parameters. - *************************************************/ - public static final ConfigKey GEAFLOW_DSL_LINE_SEPARATOR = ConfigKeys - .key("geaflow.dsl.line.separator") - .defaultValue("\n") - .description("The line separator for split text to columns."); - - public static final ConfigKey GEAFLOW_DSL_COLUMN_SEPARATOR = ConfigKeys - .key("geaflow.dsl.column.separator") - .defaultValue(",") - .description("The column separator for split text to columns."); - - public static final ConfigKey GEAFLOW_DSL_PARTITIONS_PER_SOURCE_PARALLELISM = ConfigKeys - .key("geaflow.dsl.partitions.per.source.parallelism") - .defaultValue(1) - .description("Partitions to read for each source parallelism."); - - public static final ConfigKey GEAFLOW_DSL_COLUMN_TRIM = ConfigKeys - .key("geaflow.dsl.column.trim") - .defaultValue(true) - .description("Whether doing trim operation for column text when split text to columns."); - - public static final ConfigKey GEAFLOW_DSL_START_TIME = ConfigKeys - .key("geaflow.dsl.start.time") - .defaultValue("begin") - .description("Specifies the starting unix timestamp for reading the data table. Format " - + "must be 'yyyy-MM-dd HH:mm:ss'."); - - public static final ConfigKey GEAFLOW_DSL_CONNECTOR_FORMAT = ConfigKeys - .key("geaflow.dsl.connector.format") - .defaultValue("text") - .description("Specifies the deserialization format for reading from external source like kafka, " - + "possible option currently: json/text"); - - - public static final ConfigKey GEAFLOW_DSL_CONNECTOR_FORMAT_JSON_IGNORE_PARSE_ERROR = ConfigKeys - .key("geaflow.dsl.connector.format.json.ignore-parse-error") - .defaultValue(false) - .description("for json format, skip fields and rows with parse errors instead of failing. " - + "Fields are set to null in case of errors."); - - public static final ConfigKey GEAFLOW_DSL_CONNECTOR_FORMAT_JSON_FAIL_ON_MISSING_FIELD = ConfigKeys - .key("geaflow.dsl.connector.format.json.fail-on-missing-field") - .defaultValue(false) - .description("for json format, whether to fail if a field is missing or not."); - - /************************************************* - * FILE Connector Parameters. - *************************************************/ - public static final ConfigKey GEAFLOW_DSL_FILE_PATH = ConfigKeys - .key("geaflow.dsl.file.path") - .noDefaultValue() - .description("The file path for the file table."); - - public static final ConfigKey GEAFLOW_DSL_FILE_NAME_REGEX = ConfigKeys - .key("geaflow.dsl.file.name.regex") - .defaultValue("") - .description("The regular expression for filtering the files in the path."); - - public static final ConfigKey GEAFLOW_DSL_FILE_FORMAT = ConfigKeys - .key("geaflow.dsl.file.format") - .defaultValue("txt") - .description("The file format to read or write, default value is 'txt'. "); - - public static final ConfigKey GEAFLOW_DSL_SKIP_HEADER = ConfigKeys - .key("geaflow.dsl.skip.header") - .defaultValue(false) - .description("Whether skip the header for csv format."); - - public static final ConfigKey GEAFLOW_DSL_SOURCE_FILE_PARALLEL_MOD = ConfigKeys - .key("geaflow.dsl.source.file.parallel.mod") - .defaultValue(false) - .description("Whether read single file by index"); - - public static final ConfigKey GEAFLOW_DSL_SINK_FILE_COLLISION = ConfigKeys - .key("geaflow.dsl.sink.file.collision") - .defaultValue("newfile") - .description("Whether create new file when collision occurs."); - - public static final ConfigKey GEAFLOW_DSL_FILE_LINE_SPLIT_SIZE = ConfigKeys - .key("geaflow.dsl.file.line.split.size") - .defaultValue(-1) - .description("file line split size set by user"); - - public static final ConfigKey GEAFLOW_DSL_SOURCE_ENABLE_UPLOAD_METRICS = ConfigKeys - .key("geaflow.dsl.source.enable.upload.metrics") - .defaultValue(true) - .description("source enable upload metrics"); - - public static final ConfigKey GEAFLOW_DSL_SINK_ENABLE_SKIP = ConfigKeys - .key("geaflow.dsl.sink.enable.skip") - .defaultValue(false) - .description("sink enable skip"); - + /************************************************* + * Connectors Common Parameters. + *************************************************/ + public static final ConfigKey GEAFLOW_DSL_LINE_SEPARATOR = + ConfigKeys.key("geaflow.dsl.line.separator") + .defaultValue("\n") + .description("The line separator for split text to columns."); + + public static final ConfigKey GEAFLOW_DSL_COLUMN_SEPARATOR = + ConfigKeys.key("geaflow.dsl.column.separator") + .defaultValue(",") + .description("The column separator for split text to columns."); + + public static final ConfigKey GEAFLOW_DSL_PARTITIONS_PER_SOURCE_PARALLELISM = + ConfigKeys.key("geaflow.dsl.partitions.per.source.parallelism") + .defaultValue(1) + .description("Partitions to read for each source parallelism."); + + public static final ConfigKey GEAFLOW_DSL_COLUMN_TRIM = + ConfigKeys.key("geaflow.dsl.column.trim") + .defaultValue(true) + .description("Whether doing trim operation for column text when split text to columns."); + + public static final ConfigKey GEAFLOW_DSL_START_TIME = + ConfigKeys.key("geaflow.dsl.start.time") + .defaultValue("begin") + .description( + "Specifies the starting unix timestamp for reading the data table. Format " + + "must be 'yyyy-MM-dd HH:mm:ss'."); + + public static final ConfigKey GEAFLOW_DSL_CONNECTOR_FORMAT = + ConfigKeys.key("geaflow.dsl.connector.format") + .defaultValue("text") + .description( + "Specifies the deserialization format for reading from external source like kafka, " + + "possible option currently: json/text"); + + public static final ConfigKey GEAFLOW_DSL_CONNECTOR_FORMAT_JSON_IGNORE_PARSE_ERROR = + ConfigKeys.key("geaflow.dsl.connector.format.json.ignore-parse-error") + .defaultValue(false) + .description( + "for json format, skip fields and rows with parse errors instead of failing. " + + "Fields are set to null in case of errors."); + + public static final ConfigKey GEAFLOW_DSL_CONNECTOR_FORMAT_JSON_FAIL_ON_MISSING_FIELD = + ConfigKeys.key("geaflow.dsl.connector.format.json.fail-on-missing-field") + .defaultValue(false) + .description("for json format, whether to fail if a field is missing or not."); + + /************************************************* + * FILE Connector Parameters. + *************************************************/ + public static final ConfigKey GEAFLOW_DSL_FILE_PATH = + ConfigKeys.key("geaflow.dsl.file.path") + .noDefaultValue() + .description("The file path for the file table."); + + public static final ConfigKey GEAFLOW_DSL_FILE_NAME_REGEX = + ConfigKeys.key("geaflow.dsl.file.name.regex") + .defaultValue("") + .description("The regular expression for filtering the files in the path."); + + public static final ConfigKey GEAFLOW_DSL_FILE_FORMAT = + ConfigKeys.key("geaflow.dsl.file.format") + .defaultValue("txt") + .description("The file format to read or write, default value is 'txt'. "); + + public static final ConfigKey GEAFLOW_DSL_SKIP_HEADER = + ConfigKeys.key("geaflow.dsl.skip.header") + .defaultValue(false) + .description("Whether skip the header for csv format."); + + public static final ConfigKey GEAFLOW_DSL_SOURCE_FILE_PARALLEL_MOD = + ConfigKeys.key("geaflow.dsl.source.file.parallel.mod") + .defaultValue(false) + .description("Whether read single file by index"); + + public static final ConfigKey GEAFLOW_DSL_SINK_FILE_COLLISION = + ConfigKeys.key("geaflow.dsl.sink.file.collision") + .defaultValue("newfile") + .description("Whether create new file when collision occurs."); + + public static final ConfigKey GEAFLOW_DSL_FILE_LINE_SPLIT_SIZE = + ConfigKeys.key("geaflow.dsl.file.line.split.size") + .defaultValue(-1) + .description("file line split size set by user"); + + public static final ConfigKey GEAFLOW_DSL_SOURCE_ENABLE_UPLOAD_METRICS = + ConfigKeys.key("geaflow.dsl.source.enable.upload.metrics") + .defaultValue(true) + .description("source enable upload metrics"); + + public static final ConfigKey GEAFLOW_DSL_SINK_ENABLE_SKIP = + ConfigKeys.key("geaflow.dsl.sink.enable.skip") + .defaultValue(false) + .description("sink enable skip"); } - - diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/DSLConfigKeys.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/DSLConfigKeys.java index 76cbf2902..829b475ed 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/DSLConfigKeys.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/DSLConfigKeys.java @@ -20,126 +20,124 @@ package org.apache.geaflow.common.config.keys; import java.io.Serializable; + import org.apache.geaflow.common.config.ConfigKey; import org.apache.geaflow.common.config.ConfigKeys; public class DSLConfigKeys implements Serializable { - private static final long serialVersionUID = 3550044668482560581L; - - public static final ConfigKey INCR_TRAVERSAL_ITERATION_THRESHOLD = ConfigKeys - .key("geaflow.dsl.incr.traversal.iteration.threshold") - .defaultValue(4) - .description("The max iteration to enable incr match"); - - public static final ConfigKey TABLE_SINK_SPLIT_LINE = ConfigKeys - .key("geaflow.dsl.table.sink.split.line") - .noDefaultValue() - .description("The file sink split line."); - - public static final ConfigKey ENABLE_INCR_TRAVERSAL = ConfigKeys - .key("geaflow.dsl.graph.enable.incr.traversal") - .defaultValue(false) - .description("Enable incr match"); - - public static final ConfigKey INCR_TRAVERSAL_WINDOW = ConfigKeys - .key("geaflow.dsl.graph.incr.traversal.window") - .defaultValue(-1L) - .description("When window id is large than this parameter to enable incr match"); - - public static final ConfigKey GEAFLOW_DSL_STORE_TYPE = ConfigKeys - .key("geaflow.dsl.graph.store.type") - .noDefaultValue() - .description("The graph store type."); - - public static final ConfigKey GEAFLOW_DSL_STORE_SHARD_COUNT = ConfigKeys - .key("geaflow.dsl.graph.store.shard.count") - .defaultValue(2) - .description("The graph store shard count."); - - public static final ConfigKey GEAFLOW_DSL_WINDOW_SIZE = ConfigKeys - .key("geaflow.dsl.window.size") - .defaultValue(1L) - .description("Window size, -1 represent the all window."); - - public static final ConfigKey GEAFLOW_DSL_TIME_WINDOW_SIZE = ConfigKeys - .key("geaflow.dsl.time.window.size") - .defaultValue(-1L) - .description("Specifies source time window size in second unites"); - - public static final ConfigKey GEAFLOW_DSL_TABLE_TYPE = ConfigKeys - .key("geaflow.dsl.table.type") - .noDefaultValue() - .description("The table type."); - - public static final ConfigKey GEAFLOW_DSL_TABLE_PARALLELISM = ConfigKeys - .key("geaflow.dsl.table.parallelism") - .defaultValue(1) - .description("The table parallelism."); - - public static final ConfigKey GEAFLOW_DSL_MAX_TRAVERSAL = ConfigKeys - .key("geaflow.dsl.max.traversal") - .defaultValue(64) - .description("The max traversal count."); - - public static final ConfigKey GEAFLOW_DSL_CUSTOM_SOURCE_FUNCTION = ConfigKeys - .key("geaflow.dsl.custom.source.function") - .noDefaultValue() - .description("Custom source function class name."); - - public static final ConfigKey GEAFLOW_DSL_CUSTOM_SINK_FUNCTION = ConfigKeys - .key("geaflow.dsl.custom.sink.function") - .noDefaultValue() - .description("Custom sink function class name."); - - public static final ConfigKey GEAFLOW_DSL_QUERY_PATH = ConfigKeys - .key("geaflow.dsl.query.path") - .noDefaultValue() - .description("The gql query path."); - - public static final ConfigKey GEAFLOW_DSL_QUERY_PATH_TYPE = ConfigKeys - .key("geaflow.dsl.query.path.type") - .noDefaultValue() - .description("The gql query path type."); - - public static final ConfigKey GEAFLOW_DSL_PARALLELISM_CONFIG_PATH = ConfigKeys - .key("geaflow.dsl.parallelism.config.path") - .noDefaultValue() - .description("The gql query path."); - - public static final ConfigKey GEAFLOW_DSL_CATALOG_TYPE = ConfigKeys - .key("geaflow.dsl.catalog.type") - .defaultValue("memory") - .description("The catalog type name. Optional internal implementations " - + "include 'memory' and 'console'"); - - public static final ConfigKey GEAFLOW_DSL_CATALOG_TOKEN_KEY = ConfigKeys - .key("geaflow.dsl.catalog.token.key") - .noDefaultValue() - .description("The catalog token key set by console platform"); - - public static final ConfigKey GEAFLOW_DSL_CATALOG_INSTANCE_NAME = ConfigKeys - .key("geaflow.dsl.catalog.instance.name") - .defaultValue("default") - .description("The default instance name set by console platform"); - - public static final ConfigKey GEAFLOW_DSL_SKIP_EXCEPTION = ConfigKeys - .key("geaflow.dsl.ignore.exception") - .defaultValue(false) - .description("If set true, dsl will skip the exception for dirty data."); - - public static final ConfigKey GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE = ConfigKeys - .key("geaflow.dsl.traversal.all.split.enable") - .defaultValue(false) - .description("Whether enable the split of the ids for traversal all. "); - - public static final ConfigKey GEAFLOW_DSL_COMPILE_PHYSICAL_PLAN_ENABLE = ConfigKeys - .key("geaflow.dsl.compile.physical.plan.enable") - .defaultValue(true) - .description("Whether enable compile query physical plan. "); - - public static final ConfigKey GEAFLOW_DSL_SOURCE_PARALLELISM = ConfigKeys - .key("geaflow.dsl.source.parallelism") - .noDefaultValue() - .description("Set source parallelism"); + private static final long serialVersionUID = 3550044668482560581L; + + public static final ConfigKey INCR_TRAVERSAL_ITERATION_THRESHOLD = + ConfigKeys.key("geaflow.dsl.incr.traversal.iteration.threshold") + .defaultValue(4) + .description("The max iteration to enable incr match"); + + public static final ConfigKey TABLE_SINK_SPLIT_LINE = + ConfigKeys.key("geaflow.dsl.table.sink.split.line") + .noDefaultValue() + .description("The file sink split line."); + + public static final ConfigKey ENABLE_INCR_TRAVERSAL = + ConfigKeys.key("geaflow.dsl.graph.enable.incr.traversal") + .defaultValue(false) + .description("Enable incr match"); + + public static final ConfigKey INCR_TRAVERSAL_WINDOW = + ConfigKeys.key("geaflow.dsl.graph.incr.traversal.window") + .defaultValue(-1L) + .description("When window id is large than this parameter to enable incr match"); + + public static final ConfigKey GEAFLOW_DSL_STORE_TYPE = + ConfigKeys.key("geaflow.dsl.graph.store.type") + .noDefaultValue() + .description("The graph store type."); + + public static final ConfigKey GEAFLOW_DSL_STORE_SHARD_COUNT = + ConfigKeys.key("geaflow.dsl.graph.store.shard.count") + .defaultValue(2) + .description("The graph store shard count."); + + public static final ConfigKey GEAFLOW_DSL_WINDOW_SIZE = + ConfigKeys.key("geaflow.dsl.window.size") + .defaultValue(1L) + .description("Window size, -1 represent the all window."); + + public static final ConfigKey GEAFLOW_DSL_TIME_WINDOW_SIZE = + ConfigKeys.key("geaflow.dsl.time.window.size") + .defaultValue(-1L) + .description("Specifies source time window size in second unites"); + + public static final ConfigKey GEAFLOW_DSL_TABLE_TYPE = + ConfigKeys.key("geaflow.dsl.table.type").noDefaultValue().description("The table type."); + + public static final ConfigKey GEAFLOW_DSL_TABLE_PARALLELISM = + ConfigKeys.key("geaflow.dsl.table.parallelism") + .defaultValue(1) + .description("The table parallelism."); + + public static final ConfigKey GEAFLOW_DSL_MAX_TRAVERSAL = + ConfigKeys.key("geaflow.dsl.max.traversal") + .defaultValue(64) + .description("The max traversal count."); + + public static final ConfigKey GEAFLOW_DSL_CUSTOM_SOURCE_FUNCTION = + ConfigKeys.key("geaflow.dsl.custom.source.function") + .noDefaultValue() + .description("Custom source function class name."); + + public static final ConfigKey GEAFLOW_DSL_CUSTOM_SINK_FUNCTION = + ConfigKeys.key("geaflow.dsl.custom.sink.function") + .noDefaultValue() + .description("Custom sink function class name."); + + public static final ConfigKey GEAFLOW_DSL_QUERY_PATH = + ConfigKeys.key("geaflow.dsl.query.path").noDefaultValue().description("The gql query path."); + + public static final ConfigKey GEAFLOW_DSL_QUERY_PATH_TYPE = + ConfigKeys.key("geaflow.dsl.query.path.type") + .noDefaultValue() + .description("The gql query path type."); + + public static final ConfigKey GEAFLOW_DSL_PARALLELISM_CONFIG_PATH = + ConfigKeys.key("geaflow.dsl.parallelism.config.path") + .noDefaultValue() + .description("The gql query path."); + + public static final ConfigKey GEAFLOW_DSL_CATALOG_TYPE = + ConfigKeys.key("geaflow.dsl.catalog.type") + .defaultValue("memory") + .description( + "The catalog type name. Optional internal implementations " + + "include 'memory' and 'console'"); + + public static final ConfigKey GEAFLOW_DSL_CATALOG_TOKEN_KEY = + ConfigKeys.key("geaflow.dsl.catalog.token.key") + .noDefaultValue() + .description("The catalog token key set by console platform"); + + public static final ConfigKey GEAFLOW_DSL_CATALOG_INSTANCE_NAME = + ConfigKeys.key("geaflow.dsl.catalog.instance.name") + .defaultValue("default") + .description("The default instance name set by console platform"); + + public static final ConfigKey GEAFLOW_DSL_SKIP_EXCEPTION = + ConfigKeys.key("geaflow.dsl.ignore.exception") + .defaultValue(false) + .description("If set true, dsl will skip the exception for dirty data."); + + public static final ConfigKey GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE = + ConfigKeys.key("geaflow.dsl.traversal.all.split.enable") + .defaultValue(false) + .description("Whether enable the split of the ids for traversal all. "); + + public static final ConfigKey GEAFLOW_DSL_COMPILE_PHYSICAL_PLAN_ENABLE = + ConfigKeys.key("geaflow.dsl.compile.physical.plan.enable") + .defaultValue(true) + .description("Whether enable compile query physical plan. "); + + public static final ConfigKey GEAFLOW_DSL_SOURCE_PARALLELISM = + ConfigKeys.key("geaflow.dsl.source.parallelism") + .noDefaultValue() + .description("Set source parallelism"); } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/ExecutionConfigKeys.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/ExecutionConfigKeys.java index c6df0b869..af2a8373e 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/ExecutionConfigKeys.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/ExecutionConfigKeys.java @@ -20,6 +20,7 @@ package org.apache.geaflow.common.config.keys; import java.io.Serializable; + import org.apache.geaflow.common.config.ConfigKey; import org.apache.geaflow.common.config.ConfigKeys; import org.apache.geaflow.common.shuffle.StorageLevel; @@ -27,615 +28,594 @@ public class ExecutionConfigKeys implements Serializable { - // ------------------------------------------------------------------------ - // console - // ------------------------------------------------------------------------ - - public static final ConfigKey GEAFLOW_GW_ENDPOINT = ConfigKeys - .key("geaflow.gw.endpoint") - .noDefaultValue() - .description("console address, such as http://localhost:8888"); - - public static final ConfigKey JOB_APP_NAME = ConfigKeys - .key("geaflow.job.runtime.name") - .defaultValue("default") - .description("job app name generated by console"); - - public static final ConfigKey JOB_UNIQUE_ID = ConfigKeys - .key("geaflow.job.unique.id") - .noDefaultValue() - .description("job unique id generated by console"); - - public static final ConfigKey CLUSTER_ID = ConfigKeys - .key("geaflow.job.cluster.id") - .defaultValue("") - .description("geaflow job cluster id"); - - public static final ConfigKey SYSTEM_META_TABLE = ConfigKeys - .key("geaflow.system.meta.table") - .noDefaultValue() - .description("system meta table"); - - public static final ConfigKey CLUSTER_STARTED_CALLBACK_URL = ConfigKeys - .key("geaflow.cluster.started.callback.url") - .defaultValue("") - .description("callback url to register the cluster info"); - - // ------------------------------------------------------------------------ - // rpc - // ------------------------------------------------------------------------ - - public static final ConfigKey MASTER_HTTP_PORT = ConfigKeys.key("geaflow.master.http.port") - .defaultValue(8090) - .description("master http port"); - - public static final ConfigKey AGENT_HTTP_PORT = ConfigKeys.key("geaflow.agent.http.port") - .defaultValue(0) - .description("agent http port"); - - public static final ConfigKey DRIVER_RPC_PORT = ConfigKeys.key("geaflow.driver.rpc.port") - .defaultValue(6123) - .description("driver rpc port"); - - public static final ConfigKey RPC_ASYNC_THREADS = ConfigKeys - .key("geaflow.rpc.async.thread.num") - .defaultValue(2) - .description("rpc thread pool number"); - - public static final ConfigKey HEARTBEAT_INTERVAL_MS = ConfigKeys - .key("geaflow.heartbeat.interval.ms") - .defaultValue(30000) - .description("heartbeat interval"); - - public static final ConfigKey HEARTBEAT_INITIAL_DELAY_MS = ConfigKeys - .key("geaflow.heartbeat.initial.delay.ms") - .defaultValue(10000) - .description("heart beat thread initial delay"); - - public static final ConfigKey HEARTBEAT_TIMEOUT_MS = ConfigKeys - .key("geaflow.heartbeat.timeout.ms") - .defaultValue(120000) - .description("heartbeat timeout in ms"); - - public static final ConfigKey HEARTBEAT_REPORT_INTERVAL_MS = ConfigKeys - .key("geaflow.heartbeat.report.interval.ms") - .defaultValue(30000) - .description("heartbeat report interval in ms"); - - public static final ConfigKey HEARTBEAT_REPORT_EXPIRED_MS = ConfigKeys - .key("geaflow.heartbeat.report.expired.ms") - .noDefaultValue() - .description("heartbeat report expired time in ms"); - - public static final ConfigKey RPC_RETRY_TIMES = ConfigKeys - .key("geaflow.rpc.retry.times") - .defaultValue(20) - .description("max retry of rpc connection"); - - public static final ConfigKey RPC_RETRY_INTERVAL_MS = ConfigKeys - .key("geaflow.rpc.retry.interval.ms") - .defaultValue(1000) - .description("retry interval of rpc connection in ms"); - - public static final ConfigKey RPC_CONNECT_TIMEOUT_MS = ConfigKeys - .key("geaflow.rpc.connect.timeout.ms") - .defaultValue(5000) - .description("rpc connect timeout"); - - public static final ConfigKey RPC_READ_TIMEOUT_MS = ConfigKeys - .key("geaflow.rpc.read.timeout.ms") - .defaultValue(Integer.MAX_VALUE) - .description("rpc read timeout"); - - public static final ConfigKey RPC_WRITE_TIMEOUT_MS = ConfigKeys - .key("geaflow.rpc.write.timeout.ms") - .defaultValue(Integer.MAX_VALUE) - .description("rpc write timeout"); - - public static final ConfigKey RPC_MAX_TOTAL_CONNECTION_NUM = ConfigKeys - .key("geaflow.rpc.max.total.connection.num") - .defaultValue(2) - .description("rpc max total connection num"); - - public static final ConfigKey RPC_MIN_IDLE_CONNECTION_NUM = ConfigKeys - .key("geaflow.rpc.min.idle.connection.num") - .defaultValue(2) - .description("rpc min idle connection num"); - - public static final ConfigKey RPC_MAX_RETRY_TIMES = ConfigKeys - .key("geaflow.rpc.max.retry.times") - .defaultValue(3) - .description("rpc max retry times"); - - public static final ConfigKey RPC_KEEP_ALIVE_TIME_SEC = ConfigKeys - .key("geaflow.rpc.keep.alive.time.sec") - .defaultValue(0) - .description("rpc keep alive time sec"); - - public static final ConfigKey RPC_THREADPOOL_SHARING_ENABLE = ConfigKeys - .key("geaflow.rpc.threadpool.sharing.enable") - .defaultValue(true) - .description("rpc threadpool sharing enable"); - - public static final ConfigKey RPC_IO_THREAD_NUM = ConfigKeys - .key("geaflow.rpc.io.thread.num") - .defaultValue(8) - .description("rpc io thread num"); - - public static final ConfigKey RPC_WORKER_THREAD_NUM = ConfigKeys - .key("geaflow.rpc.worker.thread.num") - .defaultValue(8) - .description("rpc worker thread num"); - - public static final ConfigKey RPC_BUFFER_SIZE_BYTES = ConfigKeys - .key("geaflow.rpc.buffer.size.bytes") - .defaultValue(256 * 1024) - .description("rpc buffer size bytes"); - - public static final ConfigKey RPC_CHANNEL_CONNECT_TYPE = ConfigKeys - .key("geaflow.rpc.channel.connect.type") - .defaultValue("pooled_connection") - .description("rpc channel connect type, e.g. [pooled_connection, short_connection, single_connection]"); - - // ------------------------------------------------------------------------ - // cluster - // ------------------------------------------------------------------------ - - public static final ConfigKey RUN_LOCAL_MODE = ConfigKeys - .key("geaflow.run.local.mode") - .defaultValue(false) - .description("job run in single process or distributed"); - - public static final ConfigKey JOB_WORK_PATH = ConfigKeys - .key("geaflow.work.path") - .defaultValue("/tmp") - .description("job work path on disk"); - - public static final ConfigKey CONTAINER_DISPATCH_THREADS = ConfigKeys - .key("geaflow.container.dispatcher.threads") - .defaultValue(1) - .description("container event dispatcher thread number"); - - public static final ConfigKey CLIENT_VCORES = ConfigKeys.key("geaflow.client.vcores") - .defaultValue(1.0) - .description("client cpu number"); - - public static final ConfigKey CLIENT_MEMORY_MB = ConfigKeys.key("geaflow.client.memory.mb") - .defaultValue(1024) - .description("client container memory"); - - public static final ConfigKey CLIENT_DISK_GB = ConfigKeys.key("geaflow.client.disk.gb") - .defaultValue(0) - .description("client container disk"); - - public static final ConfigKey CLIENT_JVM_OPTIONS = ConfigKeys.key("geaflow.client.jvm.options") - .defaultValue("-Xmx640m,-Xms640m,-Xmn256m,-Xss256k") - .description("client jvm options"); - - public static final ConfigKey MASTER_MEMORY_MB = ConfigKeys.key("geaflow.master.memory.mb") - .defaultValue(4096) - .description("master container memory"); - - public static final ConfigKey MASTER_DISK_GB = ConfigKeys.key("geaflow.master.disk.gb") - .defaultValue(0) - .description("master container disk"); - - public static final ConfigKey MASTER_JVM_OPTIONS = ConfigKeys.key("geaflow.master.jvm.options") - .defaultValue("-Xmx2048m,-Xms2048m,-Xmn512m,-Xss512k,-XX:MaxDirectMemorySize=1024m") - .description("master container jvm options"); - - public static final ConfigKey MASTER_VCORES = ConfigKeys.key("geaflow.master.vcores") - .defaultValue(1.0) - .description("master cpu"); - - public static final ConfigKey DRIVER_NUM = ConfigKeys.key("geaflow.driver.num") - .defaultValue(1) - .description("driver number"); - - public static final ConfigKey DRIVER_MEMORY_MB = ConfigKeys.key("geaflow.driver.memory.mb") - .defaultValue(4096) - .description("driver container memory"); - - public static final ConfigKey DRIVER_DISK_GB = ConfigKeys.key("geaflow.driver.disk.gb") - .defaultValue(0) - .description("driver container disk"); - - public static final ConfigKey DRIVER_JVM_OPTION = ConfigKeys.key("geaflow.driver.jvm.options") - .defaultValue("-Xmx2048m,-Xms2048m,-Xmn512m,-Xss512k,-XX:MaxDirectMemorySize=1024m") - .description("driver container jvm options"); - - public static final ConfigKey DRIVER_VCORES = ConfigKeys.key("geaflow.driver.vcores") - .defaultValue(1.0) - .description("driver cpu"); - - public static final ConfigKey CONTAINER_NUM = ConfigKeys.key("geaflow.container.num") - .defaultValue(1) - .description("container num"); - - public static final ConfigKey CONTAINER_VCORES = ConfigKeys.key("geaflow.container.vcores") - .defaultValue(1.0) - .description("container cpu"); - - public static final ConfigKey CONTAINER_MEMORY_MB = ConfigKeys - .key("geaflow.container.memory.mb") - .defaultValue(256) - .description("container memory"); - - public static final ConfigKey CONTAINER_DISK_GB = ConfigKeys - .key("geaflow.container.disk.gb") - .defaultValue(0) - .description("container disk"); - - public static final ConfigKey CONTAINER_WORKER_NUM = ConfigKeys - .key("geaflow.container.worker.num") - .defaultValue(16) - .description("max worker num in container"); - - public static final ConfigKey CONTAINER_JVM_OPTION = ConfigKeys - .key("geaflow.container.jvm.options") - .noDefaultValue() - .description("container jvm options"); - - public static final ConfigKey CONTAINER_HEAP_SIZE_MB = ConfigKeys - .key("geaflow.container.heap.size.mb") - .noDefaultValue() - .description("container max heap size in mb"); - - public static final ConfigKey EXECUTOR_MAX_MULTIPLE = ConfigKeys - .key("geaflow.executor.thread.max.multiple") - .defaultValue(10) - .description("Maximum thread pool size multiplier (maxThreads = multiple * available cores)"); - - public static final ConfigKey FO_ENABLE = ConfigKeys - .key("geaflow.fo.enable") - .defaultValue(true) - .description("whether to enable fo"); - - public static final ConfigKey FO_STRATEGY = ConfigKeys - .key("geaflow.fo.strategy") - .defaultValue("cluster_fo") - .description("whether to enable fo"); - - public static final ConfigKey FO_TIMEOUT_MS = ConfigKeys.key("geaflow.fo.timeout.ms") - .defaultValue(300000) - .description("fo timeout in ms"); - - public static final ConfigKey FO_MAX_RESTARTS = ConfigKeys - .key("geaflow.fo.max.restarts") - .defaultValue(Integer.MAX_VALUE) - .description("process max restart times, value in [0, Int.maxValue]"); - - public static final ConfigKey HA_SERVICE_TYPE = ConfigKeys - .key("geaflow.ha.service.type") - .defaultValue("") - .description("ha service type, e.g., [redis, hbase, memory]"); - - public static final ConfigKey ENABLE_MASTER_LEADER_ELECTION = ConfigKeys - .key("geaflow.master.leader-election.enable") - .defaultValue(false) - .description("whether to enable leader-election of master, currently only supports in k8s env"); - - public static final ConfigKey LEADER_ELECTION_TYPE = ConfigKeys - .key("geaflow.leader-election.type") - .defaultValue("kubernetes") - .description("leader-election type, e.g., [kubernetes]"); - - public static final ConfigKey HTTP_REST_SERVICE_ENABLE = ConfigKeys - .key("geaflow.http.rest.service.enable") - .defaultValue(true) - .description("whether to enable http rest service"); - - public static final ConfigKey PROFILER_FILENAME_EXTENSION = ConfigKeys - .key("geaflow.profiler.filename.extension") - .defaultValue(".html") - .description("filename extension of profiler results, e.g., [.html, .svg]"); - - public static final ConfigKey CLUSTER_CLIENT_TIMEOUT_MS = ConfigKeys - .key("geaflow.cluster.client.timeout.ms") - .defaultValue(300000) - .description("cluster client timeout in ms"); - - public static final ConfigKey CLIENT_EXIT_WAIT_SECONDS = ConfigKeys - .key("geaflow.cluster.client.exit.wait.secs") - .defaultValue(5) - .description("cluster client exit wait time in seconds"); - - public static final ConfigKey SERVICE_DISCOVERY_TYPE = ConfigKeys - .key("geaflow.service.discovery.type") - .defaultValue("redis") - .description("service discovery type, e.g.[zookeeper, redis]"); - - public static final ConfigKey JOB_MODE = ConfigKeys - .key("geaflow.job.mode") - .defaultValue("compute") - .description("job mode, e.g.[compute, olap service]"); - - // ------------------------------------------------------------------------ - // supervisor - // ------------------------------------------------------------------------ - - public static final ConfigKey SUPERVISOR_ENABLE = ConfigKeys.key("geaflow.supervisor.enable") - .defaultValue(false) - .description("enable supervisor or not"); - - public static final ConfigKey SUPERVISOR_RPC_PORT = ConfigKeys.key("geaflow.supervisor.rpc.port") - .defaultValue(0) - .description("supervisor rpc port"); - - public static final ConfigKey SUPERVISOR_JVM_OPTIONS = ConfigKeys.key("geaflow.supervisor.jvm.options") - .defaultValue("-Xmx128m,-Xms64m,-Xmn32m") - .description("supervisor jvm options"); - - public static final ConfigKey LOG_DIR = ConfigKeys.key("geaflow.log.dir") - .defaultValue("/home/admin/logs/geaflow") - .description("geaflow job log directory"); - - public static final ConfigKey CONF_DIR = ConfigKeys.key("geaflow.conf.dir") - .defaultValue("/etc/geaflow/conf") - .description("geaflow conf directory"); - - public static final ConfigKey PROCESS_AUTO_RESTART = ConfigKeys - .key("geaflow.process.auto-restart") - .defaultValue("unexpected") - .description("whether to restart process automatically"); - - public static final ConfigKey PROCESS_EXIT_WAIT_SECONDS = ConfigKeys - .key("geaflow.process.exit.wait.secs") - .defaultValue(3) - .description("process exit max wait seconds"); - - // ------------------------------------------------------------------------ - // shuffle - // ------------------------------------------------------------------------ - - /** - * Shuffle common config. - */ - - public static final ConfigKey SHUFFLE_PREFETCH = ConfigKeys - .key("geaflow.shuffle.prefetch.enable") - .defaultValue(true) - .description("if enable shuffle prefetch"); - - public static final ConfigKey SHUFFLE_MEMORY_POOL_ENABLE = ConfigKeys - .key("geaflow.shuffle.memory.pool.enable") - .defaultValue(false) - .description("whether to enable shuffle memory pool"); - - public static final ConfigKey SHUFFLE_COMPRESSION_ENABLE = ConfigKeys - .key("geaflow.shuffle.compression.enable") - .defaultValue(false) - .description("whether to enable shuffle compression"); - - public static final ConfigKey SHUFFLE_COMPRESSION_CODEC = ConfigKeys - .key("geaflow.shuffle.compression.codec") - .defaultValue("snappy") - .description("codec of shuffle compression"); - - public static final ConfigKey SHUFFLE_BACKPRESSURE_ENABLE = ConfigKeys - .key("geaflow.shuffle.backpressure.enable") - .defaultValue(false) - .description("whether to enable shuffle backpressure"); - - /** - * Shuffle network config. - */ - - public static final ConfigKey NETTY_SERVER_HOST = ConfigKeys - .key("geaflow.netty.server.host") - .defaultValue(ProcessUtil.LOCAL_ADDRESS) - .description("netty server host"); - - public static final ConfigKey NETTY_SERVER_PORT = ConfigKeys - .key("geaflow.netty.server.port") - .defaultValue(0) - .description("netty server port"); - - public static final ConfigKey NETTY_SERVER_BACKLOG = ConfigKeys - .key("geaflow.netty.server.backlog") - .defaultValue(512) - .description("requested maximum length of the queue of incoming connections"); - - public static final ConfigKey NETTY_SERVER_THREADS_NUM = ConfigKeys - .key("geaflow.netty.server.threads") - .defaultValue(4) - .description("number of threads used in the server thread pool. default to 4, and 0 means 2x#cores"); - - public static final ConfigKey NETTY_CLIENT_THREADS_NUM = ConfigKeys - .key("geaflow.netty.client.threads") - .defaultValue(4) - .description("number of threads used in the client thread pool. default to 4, and 0 means 2x#cores"); - - public static final ConfigKey NETTY_CONNECT_TIMEOUT_MS = ConfigKeys - .key("geaflow.netty.connect.timeout.ms") - .defaultValue(180000) - .description("netty connection timeout in milliseconds"); - - public static final ConfigKey NETTY_CONNECT_MAX_RETRY_TIMES = ConfigKeys - .key("geaflow.netty.connect.retry.times") - .defaultValue(100) - .description("max retry times of netty connection"); - - public static final ConfigKey NETTY_CONNECT_INITIAL_BACKOFF_MS = ConfigKeys - .key("geaflow.netty.connect.initial.backoff.ms") - .defaultValue(10) - .description("initial backoff time of netty connection in milliseconds"); - - public static final ConfigKey NETTY_CONNECT_MAX_BACKOFF_MS = ConfigKeys - .key("geaflow.netty.connect.max.backoff.ms") - .defaultValue(300000) - .description("max backoff time of netty connection in milliseconds"); - - public static final ConfigKey NETTY_RECEIVE_BUFFER_SIZE = ConfigKeys - .key("geaflow.netty.receive.buffer.size") - .defaultValue(0) - .description("netty receive buffer size"); - - public static final ConfigKey NETTY_SEND_BUFFER_SIZE = ConfigKeys - .key("geaflow.netty.send.buffer.size") - .defaultValue(0) - .description("netty send buffer size"); - - public static final ConfigKey NETTY_THREAD_CACHE_ENABLE = ConfigKeys - .key("geaflow.netty.thread.cache.enable") - .defaultValue(true) - .description("whether to enable netty thread cache"); - - public static final ConfigKey NETTY_PREFER_DIRECT_BUFFER = ConfigKeys - .key("geaflow.netty.prefer.direct.buffer") - .defaultValue(true) - .description("whether to prefer direct buffer"); - - public static final ConfigKey NETTY_CUSTOM_FRAME_DECODER_ENABLE = ConfigKeys - .key("geaflow.netty.custom.frame.decoder.enable") - .defaultValue(true) - .description("whether to enable custom frame decoder"); - - /** - * shuffle fetch config. - */ - - public static final ConfigKey SHUFFLE_FETCH_TIMEOUT_MS = ConfigKeys - .key("geaflow.shuffle.fetch.timeout.ms") - .defaultValue(600000) - .description("shuffle fetch timeout in milliseconds"); - - public static final ConfigKey SHUFFLE_FETCH_QUEUE_SIZE = ConfigKeys - .key("geaflow.shuffle.fetch.queue.size") - .defaultValue(1) - .description("size of shuffle fetch queue"); - - public static final ConfigKey SHUFFLE_FETCH_CHANNEL_QUEUE_SIZE = ConfigKeys - .key("geaflow.shuffle.fetch.channel.queue.size") - .defaultValue(64) - .description("buffer number per channel"); - - /** - * Shuffle write config. - */ - - public static final ConfigKey SHUFFLE_SPILL_RECORDS = ConfigKeys - .key("geaflow.shuffle.spill.records") - .defaultValue(10000) - .description("num of shuffle spill records"); - - public static final ConfigKey SHUFFLE_SLICE_MAX_SPILL_SIZE = ConfigKeys - .key("geaflow.max.spill.size.perSlice") - .defaultValue(1610612736L) // 1.5G - .description("max size of each spill per slice in Bytes"); - - public static final ConfigKey SHUFFLE_FLUSH_BUFFER_SIZE_BYTES = ConfigKeys - .key("geaflow.shuffle.flush.buffer.size.bytes") - .defaultValue(128 * 1024) - .description("size of shuffle write buffer"); - - public static final ConfigKey SHUFFLE_WRITER_BUFFER_SIZE = ConfigKeys - .key("geaflow.shuffle.writer.buffer.size") - .defaultValue(64 * 1024 * 1024) - .description("max buffer size for the shuffle writer in bytes"); - - public static final ConfigKey SHUFFLE_EMIT_BUFFER_SIZE = ConfigKeys - .key("geaflow.shuffle.emit.buffer.size") - .defaultValue(1024) - .description("size of shuffle emit buffer of java object"); - - public static final ConfigKey SHUFFLE_EMIT_QUEUE_SIZE = ConfigKeys - .key("geaflow.shuffle.emit.queue.size") - .defaultValue(1) - .description("size of shuffle emit queue"); - - public static final ConfigKey SHUFFLE_FLUSH_BUFFER_TIMEOUT_MS = ConfigKeys - .key("geaflow.shuffle.flush.buffer.timeout.ms") - .defaultValue(100) - .description("shuffle flush buffer timeout ms"); - - /** - * Shuffle storage. - */ - - public static final ConfigKey SHUFFLE_STORAGE_TYPE = ConfigKeys - .key("geaflow.shuffle.storage.type") - .defaultValue(StorageLevel.MEMORY_AND_DISK) - .description("type of shuffle storage"); - - public static final ConfigKey SHUFFLE_OFFHEAP_MEMORY_FRACTION = ConfigKeys - .key("geaflow.shuffle.offheap.fraction") - .defaultValue(0.2) - .description("fraction of shuffle offheap memory"); - - public static final ConfigKey SHUFFLE_HEAP_MEMORY_FRACTION = ConfigKeys - .key("geaflow.shuffle.heap.memory.fraction") - .defaultValue(0.2) - .description("fraction of shuffle heap memory"); - - public static final ConfigKey SHUFFLE_MEMORY_SAFETY_FRACTION = ConfigKeys - .key("geaflow.shuffle.memory.safety.fraction") - .defaultValue(0.9) - .description("fraction of shuffle memory to ensure safety"); - - - // ------------------------------------------------------------------------ - // metrics - // ------------------------------------------------------------------------ - - public static final ConfigKey SCHEDULE_PERIOD = ConfigKeys - .key("geaflow.metric.schedule.period.sec") - .defaultValue(30) - .description("metric report interval in seconds"); - - public static final ConfigKey REPORTER_LIST = ConfigKeys.key("geaflow.metric.reporters") - .defaultValue("") - .description("metric reporter list. Multiple reporters are separated by comma. for " - + "example: influxdb,tsdb,prometheus,slf4j"); - - public static final ConfigKey METRIC_META_REPORT_DELAY = ConfigKeys - .key("geaflow.metric.meta.delay.sec") - .defaultValue(5) - .description("metric meta report thread initial delay, used in tsdb reporter"); - - public static final ConfigKey METRIC_META_REPORT_PERIOD = ConfigKeys - .key("geaflow.metric.meta.period.sec") - .defaultValue(10) - .description("metric meta report period in seconds, used in tsdb reporter"); - - public static final ConfigKey METRIC_META_REPORT_RETRIES = ConfigKeys - .key("geaflow.metric.meta.retries") - .defaultValue(5) - .description("metric meta report max retry times"); - - public static final ConfigKey METRIC_MAX_CACHED_PIPELINES = ConfigKeys - .key("geaflow.metric.max.cached.pipelines") - .defaultValue(50) - .description("max cached pipeline metrics"); - - public static final ConfigKey METRIC_SERVICE_PORT = ConfigKeys - .key("geaflow.metric.service.port") - .defaultValue(0) - .description("metric service port"); - - public static final ConfigKey STATS_METRIC_STORE_TYPE = ConfigKeys - .key("geaflow.metric.stats.type") - .defaultValue("MEMORY") - .description("stats metrics store type, e.g., [MEMORY, JDBC, HBASE]"); - - public static final ConfigKey STATS_METRIC_FLUSH_THREADS = ConfigKeys - .key("geaflow.metric.flush.threads") - .defaultValue(1) - .description("stats metrics flush thread number"); - - public static final ConfigKey STATS_METRIC_FLUSH_BATCH_SIZE = ConfigKeys - .key("geaflow.metric.flush.batch.size") - .defaultValue(200) - .description("stats metrics flush batch size"); - - public static final ConfigKey STATS_METRIC_FLUSH_INTERVAL_MS = ConfigKeys - .key("geaflow.metric.flush.interval.ms") - .defaultValue(1000) - .description("stats flush interval in ms"); - - public static final ConfigKey ENABLE_DETAIL_METRIC = ConfigKeys - .key("geaflow.metric.detail.enable") - .defaultValue(false) - .description("if enable detail job metric"); - + // ------------------------------------------------------------------------ + // console + // ------------------------------------------------------------------------ + + public static final ConfigKey GEAFLOW_GW_ENDPOINT = + ConfigKeys.key("geaflow.gw.endpoint") + .noDefaultValue() + .description("console address, such as http://localhost:8888"); + + public static final ConfigKey JOB_APP_NAME = + ConfigKeys.key("geaflow.job.runtime.name") + .defaultValue("default") + .description("job app name generated by console"); + + public static final ConfigKey JOB_UNIQUE_ID = + ConfigKeys.key("geaflow.job.unique.id") + .noDefaultValue() + .description("job unique id generated by console"); + + public static final ConfigKey CLUSTER_ID = + ConfigKeys.key("geaflow.job.cluster.id") + .defaultValue("") + .description("geaflow job cluster id"); + + public static final ConfigKey SYSTEM_META_TABLE = + ConfigKeys.key("geaflow.system.meta.table").noDefaultValue().description("system meta table"); + + public static final ConfigKey CLUSTER_STARTED_CALLBACK_URL = + ConfigKeys.key("geaflow.cluster.started.callback.url") + .defaultValue("") + .description("callback url to register the cluster info"); + + // ------------------------------------------------------------------------ + // rpc + // ------------------------------------------------------------------------ + + public static final ConfigKey MASTER_HTTP_PORT = + ConfigKeys.key("geaflow.master.http.port").defaultValue(8090).description("master http port"); + + public static final ConfigKey AGENT_HTTP_PORT = + ConfigKeys.key("geaflow.agent.http.port").defaultValue(0).description("agent http port"); + + public static final ConfigKey DRIVER_RPC_PORT = + ConfigKeys.key("geaflow.driver.rpc.port").defaultValue(6123).description("driver rpc port"); + + public static final ConfigKey RPC_ASYNC_THREADS = + ConfigKeys.key("geaflow.rpc.async.thread.num") + .defaultValue(2) + .description("rpc thread pool number"); + + public static final ConfigKey HEARTBEAT_INTERVAL_MS = + ConfigKeys.key("geaflow.heartbeat.interval.ms") + .defaultValue(30000) + .description("heartbeat interval"); + + public static final ConfigKey HEARTBEAT_INITIAL_DELAY_MS = + ConfigKeys.key("geaflow.heartbeat.initial.delay.ms") + .defaultValue(10000) + .description("heart beat thread initial delay"); + + public static final ConfigKey HEARTBEAT_TIMEOUT_MS = + ConfigKeys.key("geaflow.heartbeat.timeout.ms") + .defaultValue(120000) + .description("heartbeat timeout in ms"); + + public static final ConfigKey HEARTBEAT_REPORT_INTERVAL_MS = + ConfigKeys.key("geaflow.heartbeat.report.interval.ms") + .defaultValue(30000) + .description("heartbeat report interval in ms"); + + public static final ConfigKey HEARTBEAT_REPORT_EXPIRED_MS = + ConfigKeys.key("geaflow.heartbeat.report.expired.ms") + .noDefaultValue() + .description("heartbeat report expired time in ms"); + + public static final ConfigKey RPC_RETRY_TIMES = + ConfigKeys.key("geaflow.rpc.retry.times") + .defaultValue(20) + .description("max retry of rpc connection"); + + public static final ConfigKey RPC_RETRY_INTERVAL_MS = + ConfigKeys.key("geaflow.rpc.retry.interval.ms") + .defaultValue(1000) + .description("retry interval of rpc connection in ms"); + + public static final ConfigKey RPC_CONNECT_TIMEOUT_MS = + ConfigKeys.key("geaflow.rpc.connect.timeout.ms") + .defaultValue(5000) + .description("rpc connect timeout"); + + public static final ConfigKey RPC_READ_TIMEOUT_MS = + ConfigKeys.key("geaflow.rpc.read.timeout.ms") + .defaultValue(Integer.MAX_VALUE) + .description("rpc read timeout"); + + public static final ConfigKey RPC_WRITE_TIMEOUT_MS = + ConfigKeys.key("geaflow.rpc.write.timeout.ms") + .defaultValue(Integer.MAX_VALUE) + .description("rpc write timeout"); + + public static final ConfigKey RPC_MAX_TOTAL_CONNECTION_NUM = + ConfigKeys.key("geaflow.rpc.max.total.connection.num") + .defaultValue(2) + .description("rpc max total connection num"); + + public static final ConfigKey RPC_MIN_IDLE_CONNECTION_NUM = + ConfigKeys.key("geaflow.rpc.min.idle.connection.num") + .defaultValue(2) + .description("rpc min idle connection num"); + + public static final ConfigKey RPC_MAX_RETRY_TIMES = + ConfigKeys.key("geaflow.rpc.max.retry.times") + .defaultValue(3) + .description("rpc max retry times"); + + public static final ConfigKey RPC_KEEP_ALIVE_TIME_SEC = + ConfigKeys.key("geaflow.rpc.keep.alive.time.sec") + .defaultValue(0) + .description("rpc keep alive time sec"); + + public static final ConfigKey RPC_THREADPOOL_SHARING_ENABLE = + ConfigKeys.key("geaflow.rpc.threadpool.sharing.enable") + .defaultValue(true) + .description("rpc threadpool sharing enable"); + + public static final ConfigKey RPC_IO_THREAD_NUM = + ConfigKeys.key("geaflow.rpc.io.thread.num").defaultValue(8).description("rpc io thread num"); + + public static final ConfigKey RPC_WORKER_THREAD_NUM = + ConfigKeys.key("geaflow.rpc.worker.thread.num") + .defaultValue(8) + .description("rpc worker thread num"); + + public static final ConfigKey RPC_BUFFER_SIZE_BYTES = + ConfigKeys.key("geaflow.rpc.buffer.size.bytes") + .defaultValue(256 * 1024) + .description("rpc buffer size bytes"); + + public static final ConfigKey RPC_CHANNEL_CONNECT_TYPE = + ConfigKeys.key("geaflow.rpc.channel.connect.type") + .defaultValue("pooled_connection") + .description( + "rpc channel connect type, e.g. [pooled_connection, short_connection," + + " single_connection]"); + + // ------------------------------------------------------------------------ + // cluster + // ------------------------------------------------------------------------ + + public static final ConfigKey RUN_LOCAL_MODE = + ConfigKeys.key("geaflow.run.local.mode") + .defaultValue(false) + .description("job run in single process or distributed"); + + public static final ConfigKey JOB_WORK_PATH = + ConfigKeys.key("geaflow.work.path").defaultValue("/tmp").description("job work path on disk"); + + public static final ConfigKey CONTAINER_DISPATCH_THREADS = + ConfigKeys.key("geaflow.container.dispatcher.threads") + .defaultValue(1) + .description("container event dispatcher thread number"); + + public static final ConfigKey CLIENT_VCORES = + ConfigKeys.key("geaflow.client.vcores").defaultValue(1.0).description("client cpu number"); + + public static final ConfigKey CLIENT_MEMORY_MB = + ConfigKeys.key("geaflow.client.memory.mb") + .defaultValue(1024) + .description("client container memory"); + + public static final ConfigKey CLIENT_DISK_GB = + ConfigKeys.key("geaflow.client.disk.gb").defaultValue(0).description("client container disk"); + + public static final ConfigKey CLIENT_JVM_OPTIONS = + ConfigKeys.key("geaflow.client.jvm.options") + .defaultValue("-Xmx640m,-Xms640m,-Xmn256m,-Xss256k") + .description("client jvm options"); + + public static final ConfigKey MASTER_MEMORY_MB = + ConfigKeys.key("geaflow.master.memory.mb") + .defaultValue(4096) + .description("master container memory"); + + public static final ConfigKey MASTER_DISK_GB = + ConfigKeys.key("geaflow.master.disk.gb").defaultValue(0).description("master container disk"); + + public static final ConfigKey MASTER_JVM_OPTIONS = + ConfigKeys.key("geaflow.master.jvm.options") + .defaultValue("-Xmx2048m,-Xms2048m,-Xmn512m,-Xss512k,-XX:MaxDirectMemorySize=1024m") + .description("master container jvm options"); + + public static final ConfigKey MASTER_VCORES = + ConfigKeys.key("geaflow.master.vcores").defaultValue(1.0).description("master cpu"); + + public static final ConfigKey DRIVER_NUM = + ConfigKeys.key("geaflow.driver.num").defaultValue(1).description("driver number"); + + public static final ConfigKey DRIVER_MEMORY_MB = + ConfigKeys.key("geaflow.driver.memory.mb") + .defaultValue(4096) + .description("driver container memory"); + + public static final ConfigKey DRIVER_DISK_GB = + ConfigKeys.key("geaflow.driver.disk.gb").defaultValue(0).description("driver container disk"); + + public static final ConfigKey DRIVER_JVM_OPTION = + ConfigKeys.key("geaflow.driver.jvm.options") + .defaultValue("-Xmx2048m,-Xms2048m,-Xmn512m,-Xss512k,-XX:MaxDirectMemorySize=1024m") + .description("driver container jvm options"); + + public static final ConfigKey DRIVER_VCORES = + ConfigKeys.key("geaflow.driver.vcores").defaultValue(1.0).description("driver cpu"); + + public static final ConfigKey CONTAINER_NUM = + ConfigKeys.key("geaflow.container.num").defaultValue(1).description("container num"); + + public static final ConfigKey CONTAINER_VCORES = + ConfigKeys.key("geaflow.container.vcores").defaultValue(1.0).description("container cpu"); + + public static final ConfigKey CONTAINER_MEMORY_MB = + ConfigKeys.key("geaflow.container.memory.mb") + .defaultValue(256) + .description("container memory"); + + public static final ConfigKey CONTAINER_DISK_GB = + ConfigKeys.key("geaflow.container.disk.gb").defaultValue(0).description("container disk"); + + public static final ConfigKey CONTAINER_WORKER_NUM = + ConfigKeys.key("geaflow.container.worker.num") + .defaultValue(16) + .description("max worker num in container"); + + public static final ConfigKey CONTAINER_JVM_OPTION = + ConfigKeys.key("geaflow.container.jvm.options") + .noDefaultValue() + .description("container jvm options"); + + public static final ConfigKey CONTAINER_HEAP_SIZE_MB = + ConfigKeys.key("geaflow.container.heap.size.mb") + .noDefaultValue() + .description("container max heap size in mb"); + + public static final ConfigKey EXECUTOR_MAX_MULTIPLE = + ConfigKeys.key("geaflow.executor.thread.max.multiple") + .defaultValue(10) + .description( + "Maximum thread pool size multiplier (maxThreads = multiple * available cores)"); + + public static final ConfigKey FO_ENABLE = + ConfigKeys.key("geaflow.fo.enable").defaultValue(true).description("whether to enable fo"); + + public static final ConfigKey FO_STRATEGY = + ConfigKeys.key("geaflow.fo.strategy") + .defaultValue("cluster_fo") + .description("whether to enable fo"); + + public static final ConfigKey FO_TIMEOUT_MS = + ConfigKeys.key("geaflow.fo.timeout.ms").defaultValue(300000).description("fo timeout in ms"); + + public static final ConfigKey FO_MAX_RESTARTS = + ConfigKeys.key("geaflow.fo.max.restarts") + .defaultValue(Integer.MAX_VALUE) + .description("process max restart times, value in [0, Int.maxValue]"); + + public static final ConfigKey HA_SERVICE_TYPE = + ConfigKeys.key("geaflow.ha.service.type") + .defaultValue("") + .description("ha service type, e.g., [redis, hbase, memory]"); + + public static final ConfigKey ENABLE_MASTER_LEADER_ELECTION = + ConfigKeys.key("geaflow.master.leader-election.enable") + .defaultValue(false) + .description( + "whether to enable leader-election of master, currently only supports in k8s env"); + + public static final ConfigKey LEADER_ELECTION_TYPE = + ConfigKeys.key("geaflow.leader-election.type") + .defaultValue("kubernetes") + .description("leader-election type, e.g., [kubernetes]"); + + public static final ConfigKey HTTP_REST_SERVICE_ENABLE = + ConfigKeys.key("geaflow.http.rest.service.enable") + .defaultValue(true) + .description("whether to enable http rest service"); + + public static final ConfigKey PROFILER_FILENAME_EXTENSION = + ConfigKeys.key("geaflow.profiler.filename.extension") + .defaultValue(".html") + .description("filename extension of profiler results, e.g., [.html, .svg]"); + + public static final ConfigKey CLUSTER_CLIENT_TIMEOUT_MS = + ConfigKeys.key("geaflow.cluster.client.timeout.ms") + .defaultValue(300000) + .description("cluster client timeout in ms"); + + public static final ConfigKey CLIENT_EXIT_WAIT_SECONDS = + ConfigKeys.key("geaflow.cluster.client.exit.wait.secs") + .defaultValue(5) + .description("cluster client exit wait time in seconds"); + + public static final ConfigKey SERVICE_DISCOVERY_TYPE = + ConfigKeys.key("geaflow.service.discovery.type") + .defaultValue("redis") + .description("service discovery type, e.g.[zookeeper, redis]"); + + public static final ConfigKey JOB_MODE = + ConfigKeys.key("geaflow.job.mode") + .defaultValue("compute") + .description("job mode, e.g.[compute, olap service]"); + + // ------------------------------------------------------------------------ + // supervisor + // ------------------------------------------------------------------------ + + public static final ConfigKey SUPERVISOR_ENABLE = + ConfigKeys.key("geaflow.supervisor.enable") + .defaultValue(false) + .description("enable supervisor or not"); + + public static final ConfigKey SUPERVISOR_RPC_PORT = + ConfigKeys.key("geaflow.supervisor.rpc.port") + .defaultValue(0) + .description("supervisor rpc port"); + + public static final ConfigKey SUPERVISOR_JVM_OPTIONS = + ConfigKeys.key("geaflow.supervisor.jvm.options") + .defaultValue("-Xmx128m,-Xms64m,-Xmn32m") + .description("supervisor jvm options"); + + public static final ConfigKey LOG_DIR = + ConfigKeys.key("geaflow.log.dir") + .defaultValue("/home/admin/logs/geaflow") + .description("geaflow job log directory"); + + public static final ConfigKey CONF_DIR = + ConfigKeys.key("geaflow.conf.dir") + .defaultValue("/etc/geaflow/conf") + .description("geaflow conf directory"); + + public static final ConfigKey PROCESS_AUTO_RESTART = + ConfigKeys.key("geaflow.process.auto-restart") + .defaultValue("unexpected") + .description("whether to restart process automatically"); + + public static final ConfigKey PROCESS_EXIT_WAIT_SECONDS = + ConfigKeys.key("geaflow.process.exit.wait.secs") + .defaultValue(3) + .description("process exit max wait seconds"); + + // ------------------------------------------------------------------------ + // shuffle + // ------------------------------------------------------------------------ + + /** Shuffle common config. */ + public static final ConfigKey SHUFFLE_PREFETCH = + ConfigKeys.key("geaflow.shuffle.prefetch.enable") + .defaultValue(true) + .description("if enable shuffle prefetch"); + + public static final ConfigKey SHUFFLE_MEMORY_POOL_ENABLE = + ConfigKeys.key("geaflow.shuffle.memory.pool.enable") + .defaultValue(false) + .description("whether to enable shuffle memory pool"); + + public static final ConfigKey SHUFFLE_COMPRESSION_ENABLE = + ConfigKeys.key("geaflow.shuffle.compression.enable") + .defaultValue(false) + .description("whether to enable shuffle compression"); + + public static final ConfigKey SHUFFLE_COMPRESSION_CODEC = + ConfigKeys.key("geaflow.shuffle.compression.codec") + .defaultValue("snappy") + .description("codec of shuffle compression"); + + public static final ConfigKey SHUFFLE_BACKPRESSURE_ENABLE = + ConfigKeys.key("geaflow.shuffle.backpressure.enable") + .defaultValue(false) + .description("whether to enable shuffle backpressure"); + + /** Shuffle network config. */ + public static final ConfigKey NETTY_SERVER_HOST = + ConfigKeys.key("geaflow.netty.server.host") + .defaultValue(ProcessUtil.LOCAL_ADDRESS) + .description("netty server host"); + + public static final ConfigKey NETTY_SERVER_PORT = + ConfigKeys.key("geaflow.netty.server.port").defaultValue(0).description("netty server port"); + + public static final ConfigKey NETTY_SERVER_BACKLOG = + ConfigKeys.key("geaflow.netty.server.backlog") + .defaultValue(512) + .description("requested maximum length of the queue of incoming connections"); + + public static final ConfigKey NETTY_SERVER_THREADS_NUM = + ConfigKeys.key("geaflow.netty.server.threads") + .defaultValue(4) + .description( + "number of threads used in the server thread pool. default to 4, and 0 means" + + " 2x#cores"); + + public static final ConfigKey NETTY_CLIENT_THREADS_NUM = + ConfigKeys.key("geaflow.netty.client.threads") + .defaultValue(4) + .description( + "number of threads used in the client thread pool. default to 4, and 0 means" + + " 2x#cores"); + + public static final ConfigKey NETTY_CONNECT_TIMEOUT_MS = + ConfigKeys.key("geaflow.netty.connect.timeout.ms") + .defaultValue(180000) + .description("netty connection timeout in milliseconds"); + + public static final ConfigKey NETTY_CONNECT_MAX_RETRY_TIMES = + ConfigKeys.key("geaflow.netty.connect.retry.times") + .defaultValue(100) + .description("max retry times of netty connection"); + + public static final ConfigKey NETTY_CONNECT_INITIAL_BACKOFF_MS = + ConfigKeys.key("geaflow.netty.connect.initial.backoff.ms") + .defaultValue(10) + .description("initial backoff time of netty connection in milliseconds"); + + public static final ConfigKey NETTY_CONNECT_MAX_BACKOFF_MS = + ConfigKeys.key("geaflow.netty.connect.max.backoff.ms") + .defaultValue(300000) + .description("max backoff time of netty connection in milliseconds"); + + public static final ConfigKey NETTY_RECEIVE_BUFFER_SIZE = + ConfigKeys.key("geaflow.netty.receive.buffer.size") + .defaultValue(0) + .description("netty receive buffer size"); + + public static final ConfigKey NETTY_SEND_BUFFER_SIZE = + ConfigKeys.key("geaflow.netty.send.buffer.size") + .defaultValue(0) + .description("netty send buffer size"); + + public static final ConfigKey NETTY_THREAD_CACHE_ENABLE = + ConfigKeys.key("geaflow.netty.thread.cache.enable") + .defaultValue(true) + .description("whether to enable netty thread cache"); + + public static final ConfigKey NETTY_PREFER_DIRECT_BUFFER = + ConfigKeys.key("geaflow.netty.prefer.direct.buffer") + .defaultValue(true) + .description("whether to prefer direct buffer"); + + public static final ConfigKey NETTY_CUSTOM_FRAME_DECODER_ENABLE = + ConfigKeys.key("geaflow.netty.custom.frame.decoder.enable") + .defaultValue(true) + .description("whether to enable custom frame decoder"); + + /** shuffle fetch config. */ + public static final ConfigKey SHUFFLE_FETCH_TIMEOUT_MS = + ConfigKeys.key("geaflow.shuffle.fetch.timeout.ms") + .defaultValue(600000) + .description("shuffle fetch timeout in milliseconds"); + + public static final ConfigKey SHUFFLE_FETCH_QUEUE_SIZE = + ConfigKeys.key("geaflow.shuffle.fetch.queue.size") + .defaultValue(1) + .description("size of shuffle fetch queue"); + + public static final ConfigKey SHUFFLE_FETCH_CHANNEL_QUEUE_SIZE = + ConfigKeys.key("geaflow.shuffle.fetch.channel.queue.size") + .defaultValue(64) + .description("buffer number per channel"); + + /** Shuffle write config. */ + public static final ConfigKey SHUFFLE_SPILL_RECORDS = + ConfigKeys.key("geaflow.shuffle.spill.records") + .defaultValue(10000) + .description("num of shuffle spill records"); + + public static final ConfigKey SHUFFLE_SLICE_MAX_SPILL_SIZE = + ConfigKeys.key("geaflow.max.spill.size.perSlice") + .defaultValue(1610612736L) // 1.5G + .description("max size of each spill per slice in Bytes"); + + public static final ConfigKey SHUFFLE_FLUSH_BUFFER_SIZE_BYTES = + ConfigKeys.key("geaflow.shuffle.flush.buffer.size.bytes") + .defaultValue(128 * 1024) + .description("size of shuffle write buffer"); + + public static final ConfigKey SHUFFLE_WRITER_BUFFER_SIZE = + ConfigKeys.key("geaflow.shuffle.writer.buffer.size") + .defaultValue(64 * 1024 * 1024) + .description("max buffer size for the shuffle writer in bytes"); + + public static final ConfigKey SHUFFLE_EMIT_BUFFER_SIZE = + ConfigKeys.key("geaflow.shuffle.emit.buffer.size") + .defaultValue(1024) + .description("size of shuffle emit buffer of java object"); + + public static final ConfigKey SHUFFLE_EMIT_QUEUE_SIZE = + ConfigKeys.key("geaflow.shuffle.emit.queue.size") + .defaultValue(1) + .description("size of shuffle emit queue"); + + public static final ConfigKey SHUFFLE_FLUSH_BUFFER_TIMEOUT_MS = + ConfigKeys.key("geaflow.shuffle.flush.buffer.timeout.ms") + .defaultValue(100) + .description("shuffle flush buffer timeout ms"); + + /** Shuffle storage. */ + public static final ConfigKey SHUFFLE_STORAGE_TYPE = + ConfigKeys.key("geaflow.shuffle.storage.type") + .defaultValue(StorageLevel.MEMORY_AND_DISK) + .description("type of shuffle storage"); + + public static final ConfigKey SHUFFLE_OFFHEAP_MEMORY_FRACTION = + ConfigKeys.key("geaflow.shuffle.offheap.fraction") + .defaultValue(0.2) + .description("fraction of shuffle offheap memory"); + + public static final ConfigKey SHUFFLE_HEAP_MEMORY_FRACTION = + ConfigKeys.key("geaflow.shuffle.heap.memory.fraction") + .defaultValue(0.2) + .description("fraction of shuffle heap memory"); + + public static final ConfigKey SHUFFLE_MEMORY_SAFETY_FRACTION = + ConfigKeys.key("geaflow.shuffle.memory.safety.fraction") + .defaultValue(0.9) + .description("fraction of shuffle memory to ensure safety"); + + // ------------------------------------------------------------------------ + // metrics + // ------------------------------------------------------------------------ + + public static final ConfigKey SCHEDULE_PERIOD = + ConfigKeys.key("geaflow.metric.schedule.period.sec") + .defaultValue(30) + .description("metric report interval in seconds"); + + public static final ConfigKey REPORTER_LIST = + ConfigKeys.key("geaflow.metric.reporters") + .defaultValue("") + .description( + "metric reporter list. Multiple reporters are separated by comma. for " + + "example: influxdb,tsdb,prometheus,slf4j"); + + public static final ConfigKey METRIC_META_REPORT_DELAY = + ConfigKeys.key("geaflow.metric.meta.delay.sec") + .defaultValue(5) + .description("metric meta report thread initial delay, used in tsdb reporter"); + + public static final ConfigKey METRIC_META_REPORT_PERIOD = + ConfigKeys.key("geaflow.metric.meta.period.sec") + .defaultValue(10) + .description("metric meta report period in seconds, used in tsdb reporter"); + + public static final ConfigKey METRIC_META_REPORT_RETRIES = + ConfigKeys.key("geaflow.metric.meta.retries") + .defaultValue(5) + .description("metric meta report max retry times"); + + public static final ConfigKey METRIC_MAX_CACHED_PIPELINES = + ConfigKeys.key("geaflow.metric.max.cached.pipelines") + .defaultValue(50) + .description("max cached pipeline metrics"); + + public static final ConfigKey METRIC_SERVICE_PORT = + ConfigKeys.key("geaflow.metric.service.port") + .defaultValue(0) + .description("metric service port"); + + public static final ConfigKey STATS_METRIC_STORE_TYPE = + ConfigKeys.key("geaflow.metric.stats.type") + .defaultValue("MEMORY") + .description("stats metrics store type, e.g., [MEMORY, JDBC, HBASE]"); + + public static final ConfigKey STATS_METRIC_FLUSH_THREADS = + ConfigKeys.key("geaflow.metric.flush.threads") + .defaultValue(1) + .description("stats metrics flush thread number"); + + public static final ConfigKey STATS_METRIC_FLUSH_BATCH_SIZE = + ConfigKeys.key("geaflow.metric.flush.batch.size") + .defaultValue(200) + .description("stats metrics flush batch size"); + + public static final ConfigKey STATS_METRIC_FLUSH_INTERVAL_MS = + ConfigKeys.key("geaflow.metric.flush.interval.ms") + .defaultValue(1000) + .description("stats flush interval in ms"); + + public static final ConfigKey ENABLE_DETAIL_METRIC = + ConfigKeys.key("geaflow.metric.detail.enable") + .defaultValue(false) + .description("if enable detail job metric"); } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/FrameworkConfigKeys.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/FrameworkConfigKeys.java index 441370ab5..7acd86fe0 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/FrameworkConfigKeys.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/FrameworkConfigKeys.java @@ -20,153 +20,151 @@ package org.apache.geaflow.common.config.keys; import java.io.Serializable; + import org.apache.geaflow.common.config.ConfigKey; import org.apache.geaflow.common.config.ConfigKeys; public class FrameworkConfigKeys implements Serializable { - private static final long serialVersionUID = 0L; - - public static final ConfigKey ENABLE_EXTRA_OPTIMIZE = ConfigKeys - .key("geaflow.extra.optimize.enable") - .defaultValue(false) - .description("union optimization, disabled by default"); - - public static final ConfigKey ENABLE_EXTRA_OPTIMIZE_SINK = ConfigKeys - .key("geaflow.extra.optimize.sink.enable") - .defaultValue(false) - .description("union optimization starts on the sink, disabled by default"); - - public static final ConfigKey JOB_MAX_PARALLEL = ConfigKeys - .key("geaflow.job.max.parallel") - .defaultValue(1024) - .description("maximum parallelism of the job, default value is 1024"); - - public static final ConfigKey STREAMING_RUN_TIMES = ConfigKeys - .key("geaflow.streaming.job.run.times") - .defaultValue(Long.MAX_VALUE) - .description("maximum number of job runs"); - - public static final ConfigKey STREAMING_FLYING_BATCH_NUM = ConfigKeys - .key("geaflow.streaming.flying.batch.num") - .defaultValue(5) - .description("the number of batches that pipelined job runs simultaneously, default value is 5"); - - public static final ConfigKey BATCH_NUMBER_PER_CHECKPOINT = ConfigKeys - .key("geaflow.batch.number.per.checkpoint") - .defaultValue(5L) - .description("do checkpoint every specified number of batch"); - - public static final ConfigKey SYSTEM_STATE_BACKEND_TYPE = ConfigKeys - .key("geaflow.system.state.backend.type") - .defaultValue("ROCKSDB") - .description("system state backend store type, e.g., [rocksdb, memory]"); - - public static final ConfigKey SYSTEM_OFFSET_BACKEND_TYPE = ConfigKeys - .key("geaflow.system.offset.backend.type") - .defaultValue("MEMORY") - .description("system offset backend store type, e.g., [jdbc, memory]"); - - public static final ConfigKey INC_STREAM_MATERIALIZE_DISABLE = ConfigKeys - .key("geaflow.inc.stream.materialize.disable") - .defaultValue(false) - .description("inc stream materialize, enabled by default"); - - public static final ConfigKey SERVICE_SERVER_TYPE = ConfigKeys - .key("geaflow.analytics.service.server.type") - .defaultValue("analytics_rpc") - .description("analytics service server type, e.g., [analytics_rpc, analytics_http, storage]"); - - public static final ConfigKey SERVICE_SHARE_ENABLE = ConfigKeys - .key("geaflow.analytics.service.share.enable") - .defaultValue(false) - .description("whether enable analytics service using share mode, default is false"); - - public static final ConfigKey CLIENT_QUERY_TIMEOUT = ConfigKeys - .key("geaflow.analytics.client.query.timeout") - .noDefaultValue() - .description("analytics client query max run time"); - - public static final ConfigKey CLIENT_REQUEST_TIMEOUT_MILLISECOND = ConfigKeys - .key("geaflow.analytics.client.request.timeout.millisecond") - .defaultValue(2 * 60 * 1000) - .description("analytics client request max run time"); - - - public static final ConfigKey INFER_ENV_ENABLE = ConfigKeys - .key("geaflow.infer.env.enable") - .defaultValue(false) - .description("infer env enable, default is false"); - - public static final ConfigKey INFER_ENV_SHARE_MEMORY_QUEUE_SIZE = ConfigKeys - .key("geaflow.infer.env.share.memory.queue.size") - .defaultValue(8 * 1024 * 1024) - .description("infer env share memory queue size, default is 8 * 1024 * 1024"); - - public static final ConfigKey INFER_ENV_SO_LIB_URL = ConfigKeys - .key("geaflow.infer.env.so.lib.url") - .noDefaultValue() - .description("infer env so lib package oss url"); - - public static final ConfigKey INFER_ENV_INIT_TIMEOUT_SEC = ConfigKeys - .key("geaflow.infer.env.init.timeout.sec") - .defaultValue(120) - .description("infer env init timeout sec, default is 120"); - - public static final ConfigKey INFER_ENV_SUPPRESS_LOG_ENABLE = ConfigKeys - .key("geaflow.infer.env.suppress.log.enable") - .defaultValue(true) - .description("infer env suppress log enable, default is true"); - - public static final ConfigKey INFER_USER_DEFINE_LIB_PATH = ConfigKeys - .key("geaflow.infer.user.define.lib.path") - .noDefaultValue() - .description("infer user define lib path"); - - public static final ConfigKey INFER_ENV_USER_TRANSFORM_CLASSNAME = ConfigKeys - .key("geaflow.infer.env.user.transform.classname") - .noDefaultValue() - .description("infer env user custom define transform class name"); - - public static final ConfigKey INFER_ENV_OSS_ACCESS_KEY = ConfigKeys - .key("geaflow.infer.env.oss.access.key") - .noDefaultValue() - .description("infer env oss access key"); - - public static final ConfigKey INFER_ENV_OSS_ACCESS_ID = ConfigKeys - .key("geaflow.infer.env.oss.access.id") - .noDefaultValue() - .description("infer env oss access id"); - - public static final ConfigKey INFER_ENV_OSS_ENDPOINT = ConfigKeys - .key("geaflow.infer.env.oss.endpoint") - .noDefaultValue() - .description("infer env oss endpoint"); - - public static final ConfigKey INFER_ENV_OSS_DOWNLOAD_RETRY_NUM = ConfigKeys - .key("geaflow.infer.env.oss.download.retry.num") - .defaultValue(3) - .description("infer env oss download retry num, default is 3"); - - public static final ConfigKey INFER_ENV_CONDA_URL = ConfigKeys - .key("geaflow.infer.env.conda.url") - .noDefaultValue() - .description("infer env conda url"); - - public static final ConfigKey ASP_ENABLE = ConfigKeys - .key("geaflow.iteration.asp.enable") - .defaultValue(false) - .description("whether enable iteration asp mode, disabled by default"); - - public static final ConfigKey ADD_INVOKE_VIDS_EACH_ITERATION = ConfigKeys - .key("geaflow.add.invoke.vids.each.iteration") - .defaultValue(true) - .description(""); - - public static final ConfigKey UDF_MATERIALIZE_GRAPH_IN_FINISH = ConfigKeys - .key("geaflow.udf.materialize.graph.in.finish") - .defaultValue(false) - .description("in dynmic graph, whether udf function materialize graph in finish"); - + private static final long serialVersionUID = 0L; + + public static final ConfigKey ENABLE_EXTRA_OPTIMIZE = + ConfigKeys.key("geaflow.extra.optimize.enable") + .defaultValue(false) + .description("union optimization, disabled by default"); + + public static final ConfigKey ENABLE_EXTRA_OPTIMIZE_SINK = + ConfigKeys.key("geaflow.extra.optimize.sink.enable") + .defaultValue(false) + .description("union optimization starts on the sink, disabled by default"); + + public static final ConfigKey JOB_MAX_PARALLEL = + ConfigKeys.key("geaflow.job.max.parallel") + .defaultValue(1024) + .description("maximum parallelism of the job, default value is 1024"); + + public static final ConfigKey STREAMING_RUN_TIMES = + ConfigKeys.key("geaflow.streaming.job.run.times") + .defaultValue(Long.MAX_VALUE) + .description("maximum number of job runs"); + + public static final ConfigKey STREAMING_FLYING_BATCH_NUM = + ConfigKeys.key("geaflow.streaming.flying.batch.num") + .defaultValue(5) + .description( + "the number of batches that pipelined job runs simultaneously, default value is 5"); + + public static final ConfigKey BATCH_NUMBER_PER_CHECKPOINT = + ConfigKeys.key("geaflow.batch.number.per.checkpoint") + .defaultValue(5L) + .description("do checkpoint every specified number of batch"); + + public static final ConfigKey SYSTEM_STATE_BACKEND_TYPE = + ConfigKeys.key("geaflow.system.state.backend.type") + .defaultValue("ROCKSDB") + .description("system state backend store type, e.g., [rocksdb, memory]"); + + public static final ConfigKey SYSTEM_OFFSET_BACKEND_TYPE = + ConfigKeys.key("geaflow.system.offset.backend.type") + .defaultValue("MEMORY") + .description("system offset backend store type, e.g., [jdbc, memory]"); + + public static final ConfigKey INC_STREAM_MATERIALIZE_DISABLE = + ConfigKeys.key("geaflow.inc.stream.materialize.disable") + .defaultValue(false) + .description("inc stream materialize, enabled by default"); + + public static final ConfigKey SERVICE_SERVER_TYPE = + ConfigKeys.key("geaflow.analytics.service.server.type") + .defaultValue("analytics_rpc") + .description( + "analytics service server type, e.g., [analytics_rpc, analytics_http, storage]"); + + public static final ConfigKey SERVICE_SHARE_ENABLE = + ConfigKeys.key("geaflow.analytics.service.share.enable") + .defaultValue(false) + .description("whether enable analytics service using share mode, default is false"); + + public static final ConfigKey CLIENT_QUERY_TIMEOUT = + ConfigKeys.key("geaflow.analytics.client.query.timeout") + .noDefaultValue() + .description("analytics client query max run time"); + + public static final ConfigKey CLIENT_REQUEST_TIMEOUT_MILLISECOND = + ConfigKeys.key("geaflow.analytics.client.request.timeout.millisecond") + .defaultValue(2 * 60 * 1000) + .description("analytics client request max run time"); + + public static final ConfigKey INFER_ENV_ENABLE = + ConfigKeys.key("geaflow.infer.env.enable") + .defaultValue(false) + .description("infer env enable, default is false"); + + public static final ConfigKey INFER_ENV_SHARE_MEMORY_QUEUE_SIZE = + ConfigKeys.key("geaflow.infer.env.share.memory.queue.size") + .defaultValue(8 * 1024 * 1024) + .description("infer env share memory queue size, default is 8 * 1024 * 1024"); + + public static final ConfigKey INFER_ENV_SO_LIB_URL = + ConfigKeys.key("geaflow.infer.env.so.lib.url") + .noDefaultValue() + .description("infer env so lib package oss url"); + + public static final ConfigKey INFER_ENV_INIT_TIMEOUT_SEC = + ConfigKeys.key("geaflow.infer.env.init.timeout.sec") + .defaultValue(120) + .description("infer env init timeout sec, default is 120"); + + public static final ConfigKey INFER_ENV_SUPPRESS_LOG_ENABLE = + ConfigKeys.key("geaflow.infer.env.suppress.log.enable") + .defaultValue(true) + .description("infer env suppress log enable, default is true"); + + public static final ConfigKey INFER_USER_DEFINE_LIB_PATH = + ConfigKeys.key("geaflow.infer.user.define.lib.path") + .noDefaultValue() + .description("infer user define lib path"); + + public static final ConfigKey INFER_ENV_USER_TRANSFORM_CLASSNAME = + ConfigKeys.key("geaflow.infer.env.user.transform.classname") + .noDefaultValue() + .description("infer env user custom define transform class name"); + + public static final ConfigKey INFER_ENV_OSS_ACCESS_KEY = + ConfigKeys.key("geaflow.infer.env.oss.access.key") + .noDefaultValue() + .description("infer env oss access key"); + + public static final ConfigKey INFER_ENV_OSS_ACCESS_ID = + ConfigKeys.key("geaflow.infer.env.oss.access.id") + .noDefaultValue() + .description("infer env oss access id"); + + public static final ConfigKey INFER_ENV_OSS_ENDPOINT = + ConfigKeys.key("geaflow.infer.env.oss.endpoint") + .noDefaultValue() + .description("infer env oss endpoint"); + + public static final ConfigKey INFER_ENV_OSS_DOWNLOAD_RETRY_NUM = + ConfigKeys.key("geaflow.infer.env.oss.download.retry.num") + .defaultValue(3) + .description("infer env oss download retry num, default is 3"); + + public static final ConfigKey INFER_ENV_CONDA_URL = + ConfigKeys.key("geaflow.infer.env.conda.url") + .noDefaultValue() + .description("infer env conda url"); + + public static final ConfigKey ASP_ENABLE = + ConfigKeys.key("geaflow.iteration.asp.enable") + .defaultValue(false) + .description("whether enable iteration asp mode, disabled by default"); + + public static final ConfigKey ADD_INVOKE_VIDS_EACH_ITERATION = + ConfigKeys.key("geaflow.add.invoke.vids.each.iteration").defaultValue(true).description(""); + + public static final ConfigKey UDF_MATERIALIZE_GRAPH_IN_FINISH = + ConfigKeys.key("geaflow.udf.materialize.graph.in.finish") + .defaultValue(false) + .description("in dynmic graph, whether udf function materialize graph in finish"); } - diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/StateConfigKeys.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/StateConfigKeys.java index 3fbde5276..3b51cdc86 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/StateConfigKeys.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/config/keys/StateConfigKeys.java @@ -20,71 +20,72 @@ package org.apache.geaflow.common.config.keys; import java.io.Serializable; + import org.apache.geaflow.common.config.ConfigKey; import org.apache.geaflow.common.config.ConfigKeys; public class StateConfigKeys implements Serializable { - public static final byte[] DELIMITER = "\u0001\u0008".getBytes(); - - public static final ConfigKey STATE_ARCHIVED_VERSION_NUM = ConfigKeys - .key("geaflow.state.archived.version.num") - .defaultValue(1) - .description("state archived version number, default 1"); - - public static final ConfigKey STATE_PARANOID_CHECK_ENABLE = ConfigKeys - .key("geaflow.state.paranoid.check.enable") - .defaultValue(false) - .description("state paranoid check, default false"); - - public static final ConfigKey STATE_WRITE_ASYNC_ENABLE = ConfigKeys - .key("geaflow.state.write.async.enable") - .defaultValue(true) - .description("state async write, default true"); - - public static final ConfigKey STATE_WRITE_BUFFER_DEEP_COPY = ConfigKeys - .key("geaflow.state.write.buffer.deep.copy") - .defaultValue(false) - .description("state write buffer deep copy read, default false"); - - public static final ConfigKey STATE_WRITE_BUFFER_NUMBER = ConfigKeys - .key("geaflow.state.write.buffer.number") - .defaultValue(3) - .description("state write buffer number, default 3"); - - public static final ConfigKey STATE_WRITE_BUFFER_SIZE = ConfigKeys - .key("geaflow.state.write.buffer.size") - .defaultValue(10000) - .description("state write buffer size, default 10000"); - - public static final ConfigKey STATE_KV_ENCODER_CLASS = ConfigKeys - .key("geaflow.state.kv.encoder.class") - .defaultValue("org.apache.geaflow.state.graph.encoder.GraphKVEncoder") - .description("state kv encoder"); - - public static final ConfigKey STATE_KV_ENCODER_EDGE_ORDER = ConfigKeys - .key("geaflow.state.kv.encoder.edge.order") - .defaultValue("") - .description("state kv encoder edge atom order, splitter ,"); - - // for read only state. - public static final ConfigKey STATE_RECOVER_LATEST_VERSION_ENABLE = ConfigKeys - .key("geaflow.state.recover.latest.version.enable") - .defaultValue(false) - .description("enable recover latest version, default false"); - - public static final ConfigKey STATE_BACKGROUND_SYNC_ENABLE = ConfigKeys - .key("geaflow.state.background.sync.enable") - .defaultValue(false) - .description("enable state background sync, default false"); - - public static final ConfigKey STATE_SYNC_GAP_MS = ConfigKeys - .key("geaflow.state.sync.gap.ms") - .defaultValue(600000) - .description("state background sync ms, default 600000ms"); - - public static final ConfigKey STATE_ROCKSDB_PERSIST_TIMEOUT_SECONDS = ConfigKeys - .key("geaflow.state.rocksdb.persist.timeout.second") - .defaultValue(Integer.MAX_VALUE) - .description("rocksdb persist timeout second, default Integer.MAX_VALUE"); + public static final byte[] DELIMITER = "\u0001\u0008".getBytes(); + + public static final ConfigKey STATE_ARCHIVED_VERSION_NUM = + ConfigKeys.key("geaflow.state.archived.version.num") + .defaultValue(1) + .description("state archived version number, default 1"); + + public static final ConfigKey STATE_PARANOID_CHECK_ENABLE = + ConfigKeys.key("geaflow.state.paranoid.check.enable") + .defaultValue(false) + .description("state paranoid check, default false"); + + public static final ConfigKey STATE_WRITE_ASYNC_ENABLE = + ConfigKeys.key("geaflow.state.write.async.enable") + .defaultValue(true) + .description("state async write, default true"); + + public static final ConfigKey STATE_WRITE_BUFFER_DEEP_COPY = + ConfigKeys.key("geaflow.state.write.buffer.deep.copy") + .defaultValue(false) + .description("state write buffer deep copy read, default false"); + + public static final ConfigKey STATE_WRITE_BUFFER_NUMBER = + ConfigKeys.key("geaflow.state.write.buffer.number") + .defaultValue(3) + .description("state write buffer number, default 3"); + + public static final ConfigKey STATE_WRITE_BUFFER_SIZE = + ConfigKeys.key("geaflow.state.write.buffer.size") + .defaultValue(10000) + .description("state write buffer size, default 10000"); + + public static final ConfigKey STATE_KV_ENCODER_CLASS = + ConfigKeys.key("geaflow.state.kv.encoder.class") + .defaultValue("org.apache.geaflow.state.graph.encoder.GraphKVEncoder") + .description("state kv encoder"); + + public static final ConfigKey STATE_KV_ENCODER_EDGE_ORDER = + ConfigKeys.key("geaflow.state.kv.encoder.edge.order") + .defaultValue("") + .description("state kv encoder edge atom order, splitter ,"); + + // for read only state. + public static final ConfigKey STATE_RECOVER_LATEST_VERSION_ENABLE = + ConfigKeys.key("geaflow.state.recover.latest.version.enable") + .defaultValue(false) + .description("enable recover latest version, default false"); + + public static final ConfigKey STATE_BACKGROUND_SYNC_ENABLE = + ConfigKeys.key("geaflow.state.background.sync.enable") + .defaultValue(false) + .description("enable state background sync, default false"); + + public static final ConfigKey STATE_SYNC_GAP_MS = + ConfigKeys.key("geaflow.state.sync.gap.ms") + .defaultValue(600000) + .description("state background sync ms, default 600000ms"); + + public static final ConfigKey STATE_ROCKSDB_PERSIST_TIMEOUT_SECONDS = + ConfigKeys.key("geaflow.state.rocksdb.persist.timeout.second") + .defaultValue(Integer.MAX_VALUE) + .description("rocksdb persist timeout second, default Integer.MAX_VALUE"); } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/EncoderResolver.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/EncoderResolver.java index 470bd1e69..51d52ad2f 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/EncoderResolver.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/EncoderResolver.java @@ -19,7 +19,6 @@ package org.apache.geaflow.common.encoder; -import com.google.common.annotations.VisibleForTesting; import java.lang.reflect.Array; import java.lang.reflect.Constructor; import java.lang.reflect.Field; @@ -32,6 +31,7 @@ import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.common.encoder.impl.EnumEncoder; import org.apache.geaflow.common.encoder.impl.GenericArrayEncoder; import org.apache.geaflow.common.encoder.impl.PojoEncoder; @@ -42,364 +42,383 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.annotations.VisibleForTesting; + public class EncoderResolver { - private static final Logger LOGGER = LoggerFactory.getLogger(EncoderResolver.class); + private static final Logger LOGGER = LoggerFactory.getLogger(EncoderResolver.class); - private static final String EMPTY = ""; - private static final String UNDERLINE = "_"; - private static final String METHOD_READ_OBJECT = "readObject"; - private static final String METHOD_WRITE_OBJECT = "writeObject"; + private static final String EMPTY = ""; + private static final String UNDERLINE = "_"; + private static final String METHOD_READ_OBJECT = "readObject"; + private static final String METHOD_WRITE_OBJECT = "writeObject"; - public static IEncoder resolveClass(Class clazz) { - return resolveType(clazz); - } + public static IEncoder resolveClass(Class clazz) { + return resolveType(clazz); + } - @SuppressWarnings({"unchecked", "rawtypes"}) - public static IEncoder resolveType(Type type) { - if (isClassType(type)) { - Class clazz = typeToClass(type); - if (clazz == Object.class) { - return null; - } - if (clazz == Class.class) { - return null; - } - if (Modifier.isInterface(clazz.getModifiers())) { - return null; - } - if (Encoders.PRIMITIVE_ENCODER_MAP.containsKey(clazz)) { - return Encoders.PRIMITIVE_ENCODER_MAP.get(clazz); - } - if (Tuple.class.isAssignableFrom(clazz)) { - return resolveTuple(type); - } - if (Triple.class.isAssignableFrom(clazz)) { - return resolveTriple(type); - } - if (Enum.class.isAssignableFrom(clazz)) { - return new EnumEncoder<>(clazz); - } - if (clazz.isArray()) { - if (Encoders.PRIMITIVE_ARR_ENCODER_MAP.containsKey(clazz)) { - return Encoders.PRIMITIVE_ARR_ENCODER_MAP.get(clazz); - } - Class componentClass = clazz.getComponentType(); - IEncoder componentEncoder = resolveClass(clazz.getComponentType()); - if (componentEncoder != null) { - GenericArrayEncoder.ArrayConstructor constructor = - length -> (Object[]) Array.newInstance(componentClass, length); - return new GenericArrayEncoder(componentEncoder, constructor); - } - } - return resolvePojo(type); - } + @SuppressWarnings({"unchecked", "rawtypes"}) + public static IEncoder resolveType(Type type) { + if (isClassType(type)) { + Class clazz = typeToClass(type); + if (clazz == Object.class) { return null; - } - - public static IEncoder resolveTuple(Type type) { - List subTypeTree = new ArrayList<>(); - Type curType = type; - while (!(isClassType(curType) && typeToClass(curType).equals(Tuple.class))) { - if (curType instanceof ParameterizedType) { - subTypeTree.add((ParameterizedType) curType); - } - curType = typeToClass(curType).getGenericSuperclass(); - } - - if (curType instanceof Class) { - LOGGER.warn("Tuple needs to be parameterized with generics"); - return null; - } - - ParameterizedType parameterizedType = (ParameterizedType) curType; - subTypeTree.add(parameterizedType); - - IEncoder[] subEncoders = resolveSubEncoder(subTypeTree, parameterizedType); - if (subEncoders == null || subEncoders.length != 2) { - return null; + } + if (clazz == Class.class) { + return null; + } + if (Modifier.isInterface(clazz.getModifiers())) { + return null; + } + if (Encoders.PRIMITIVE_ENCODER_MAP.containsKey(clazz)) { + return Encoders.PRIMITIVE_ENCODER_MAP.get(clazz); + } + if (Tuple.class.isAssignableFrom(clazz)) { + return resolveTuple(type); + } + if (Triple.class.isAssignableFrom(clazz)) { + return resolveTriple(type); + } + if (Enum.class.isAssignableFrom(clazz)) { + return new EnumEncoder<>(clazz); + } + if (clazz.isArray()) { + if (Encoders.PRIMITIVE_ARR_ENCODER_MAP.containsKey(clazz)) { + return Encoders.PRIMITIVE_ARR_ENCODER_MAP.get(clazz); } - if (countClassFields(typeToClass(type)) != subEncoders.length) { - LOGGER.warn("tuple filed num does not match encoder num"); - return null; + Class componentClass = clazz.getComponentType(); + IEncoder componentEncoder = resolveClass(clazz.getComponentType()); + if (componentEncoder != null) { + GenericArrayEncoder.ArrayConstructor constructor = + length -> (Object[]) Array.newInstance(componentClass, length); + return new GenericArrayEncoder(componentEncoder, constructor); } - return Encoders.tuple(subEncoders[0], subEncoders[1]); + } + return resolvePojo(type); + } + return null; + } + + public static IEncoder resolveTuple(Type type) { + List subTypeTree = new ArrayList<>(); + Type curType = type; + while (!(isClassType(curType) && typeToClass(curType).equals(Tuple.class))) { + if (curType instanceof ParameterizedType) { + subTypeTree.add((ParameterizedType) curType); + } + curType = typeToClass(curType).getGenericSuperclass(); } - public static IEncoder resolveTriple(Type type) { - List subTypeTree = new ArrayList<>(); - Type curType = type; - while (!(isClassType(curType) && typeToClass(curType).equals(Triple.class))) { - if (curType instanceof ParameterizedType) { - subTypeTree.add((ParameterizedType) curType); - } - curType = typeToClass(curType).getGenericSuperclass(); - } - - if (curType instanceof Class) { - LOGGER.warn("Tuple needs to be parameterized with generics"); - return null; - } + if (curType instanceof Class) { + LOGGER.warn("Tuple needs to be parameterized with generics"); + return null; + } - ParameterizedType parameterizedType = (ParameterizedType) curType; - subTypeTree.add(parameterizedType); + ParameterizedType parameterizedType = (ParameterizedType) curType; + subTypeTree.add(parameterizedType); - IEncoder[] subEncoders = resolveSubEncoder(subTypeTree, parameterizedType); - if (subEncoders == null || subEncoders.length != 3) { - return null; - } - if (countClassFields(typeToClass(type)) != subEncoders.length) { - LOGGER.warn("triple filed num does not match encoder num"); - return null; - } - return Encoders.triple(subEncoders[0], subEncoders[1], subEncoders[2]); + IEncoder[] subEncoders = resolveSubEncoder(subTypeTree, parameterizedType); + if (subEncoders == null || subEncoders.length != 2) { + return null; + } + if (countClassFields(typeToClass(type)) != subEncoders.length) { + LOGGER.warn("tuple filed num does not match encoder num"); + return null; + } + return Encoders.tuple(subEncoders[0], subEncoders[1]); + } + + public static IEncoder resolveTriple(Type type) { + List subTypeTree = new ArrayList<>(); + Type curType = type; + while (!(isClassType(curType) && typeToClass(curType).equals(Triple.class))) { + if (curType instanceof ParameterizedType) { + subTypeTree.add((ParameterizedType) curType); + } + curType = typeToClass(curType).getGenericSuperclass(); } - private static IEncoder[] resolveSubEncoder(List typeTree, - ParameterizedType parameterizedType) { - Type[] typeArguments = parameterizedType.getActualTypeArguments(); - IEncoder[] encoders = new IEncoder[typeArguments.length]; - for (int i = 0; i < typeArguments.length; i++) { - Type typeArgument = typeArguments[i]; - Type concreteType = typeArgument; - if (typeArgument instanceof TypeVariable) { - concreteType = getConcreteTypeofTypeVariable(typeTree, (TypeVariable) typeArgument); - } - IEncoder encoder = resolveType(concreteType); - if (encoder == null) { - return null; - } - encoders[i] = encoder; - } - return encoders; + if (curType instanceof Class) { + LOGGER.warn("Tuple needs to be parameterized with generics"); + return null; } - public static IEncoder resolvePojo(Type type) { - if (isClassType(type)) { - Class clazz = typeToClass(type); - try { - analysisPojo(clazz); - } catch (GeaflowRuntimeException e) { - return null; - } - return PojoEncoder.build(clazz); - } + ParameterizedType parameterizedType = (ParameterizedType) curType; + subTypeTree.add(parameterizedType); + + IEncoder[] subEncoders = resolveSubEncoder(subTypeTree, parameterizedType); + if (subEncoders == null || subEncoders.length != 3) { + return null; + } + if (countClassFields(typeToClass(type)) != subEncoders.length) { + LOGGER.warn("triple filed num does not match encoder num"); + return null; + } + return Encoders.triple(subEncoders[0], subEncoders[1], subEncoders[2]); + } + + private static IEncoder[] resolveSubEncoder( + List typeTree, ParameterizedType parameterizedType) { + Type[] typeArguments = parameterizedType.getActualTypeArguments(); + IEncoder[] encoders = new IEncoder[typeArguments.length]; + for (int i = 0; i < typeArguments.length; i++) { + Type typeArgument = typeArguments[i]; + Type concreteType = typeArgument; + if (typeArgument instanceof TypeVariable) { + concreteType = getConcreteTypeofTypeVariable(typeTree, (TypeVariable) typeArgument); + } + IEncoder encoder = resolveType(concreteType); + if (encoder == null) { return null; + } + encoders[i] = encoder; } - - @VisibleForTesting - protected static void analysisPojo(Class clazz) { - if (!Modifier.isPublic(clazz.getModifiers())) { - String msg = "Class [" + clazz.getName() + "] is not public, it cannot be used as a POJO type"; - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); - } - if (clazz.isInterface() || Modifier.isAbstract(clazz.getModifiers())) { - String msg = "Class [" + clazz.getName() + "] is abstract or an interface," - + "it cannot be used as a POJO type"; - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); - } - if (clazz.getSuperclass() != Object.class) { - String msg = "Class [" + clazz.getName() + "] does not extends Object directly, " + return encoders; + } + + public static IEncoder resolvePojo(Type type) { + if (isClassType(type)) { + Class clazz = typeToClass(type); + try { + analysisPojo(clazz); + } catch (GeaflowRuntimeException e) { + return null; + } + return PojoEncoder.build(clazz); + } + return null; + } + + @VisibleForTesting + protected static void analysisPojo(Class clazz) { + if (!Modifier.isPublic(clazz.getModifiers())) { + String msg = + "Class [" + clazz.getName() + "] is not public, it cannot be used as a POJO type"; + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); + } + if (clazz.isInterface() || Modifier.isAbstract(clazz.getModifiers())) { + String msg = + "Class [" + + clazz.getName() + + "] is abstract or an interface," + + "it cannot be used as a POJO type"; + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); + } + if (clazz.getSuperclass() != Object.class) { + String msg = + "Class [" + + clazz.getName() + + "] does not extends Object directly, " + + "it cannot be used as a POJO type"; + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); + } + Method[] methods = clazz.getDeclaredMethods(); + for (Method method : methods) { + if (METHOD_READ_OBJECT.equals(method.getName()) + || METHOD_WRITE_OBJECT.equals(method.getName())) { + String msg = + "Class [" + + clazz.getName() + + "] contains custom serialization methods we do not call, " + "it cannot be used as a POJO type"; - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); - } - Method[] methods = clazz.getDeclaredMethods(); - for (Method method : methods) { - if (METHOD_READ_OBJECT.equals(method.getName()) || METHOD_WRITE_OBJECT.equals(method.getName())) { - String msg = "Class [" + clazz.getName() + "] contains custom serialization methods we do not call, " - + "it cannot be used as a POJO type"; - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); - } - } - - // check default constructor - Constructor defaultConstructor; - try { - defaultConstructor = clazz.getDeclaredConstructor(); - } catch (NoSuchMethodException e) { - String msg = "Class [" + clazz.getName() + "] has no default constructor, it cannot be used as a POJO type"; - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg), e); - } - if (!Modifier.isPublic(defaultConstructor.getModifiers())) { - String msg = "The default constructor of [" + clazz + "] is not Public, it cannot be used as a POJO type"; - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); - } - - List fields = getPojoFields(clazz); - if (fields.isEmpty()) { - String msg = "Class [" + clazz.getName() + "] has no declared fields"; - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); - } - for (Field field : fields) { - checkValidPojoField(field, clazz); - } + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); + } } - private static void checkValidPojoField(Field f, Class clazz) { - if (!Modifier.isPublic(f.getModifiers())) { - final String fieldNameLow = f.getName().toLowerCase().replaceAll(UNDERLINE, EMPTY); - - Type fieldType = f.getGenericType(); - Class fieldTypeWrapper = f.getType(); - if (fieldTypeWrapper.isPrimitive()) { - fieldTypeWrapper = Encoders.PRIMITIVE_WRAPPER_MAP.get(fieldTypeWrapper); - } - if (fieldType instanceof TypeVariable) { - String msg = "do not support generics yet"; - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); - } - - boolean hasGetter = false; - boolean hasSetter = false; - for (Method m : clazz.getMethods()) { - final String methodNameLow = m.getName().toLowerCase().replaceAll(UNDERLINE, EMPTY); - // check getter, one qualified method is ok - if (!hasGetter - && (methodNameLow.equals("get" + fieldNameLow) || methodNameLow.equals("is" + fieldNameLow)) - && m.getParameterTypes().length == 0 - && (m.getGenericReturnType().equals(fieldType) || m.getReturnType().equals(fieldTypeWrapper)) - ) { - hasGetter = true; - } - // check setter, one qualified method is ok - if (!hasSetter - && methodNameLow.equals("set" + fieldNameLow) - && m.getParameterTypes().length == 1 - && (m.getGenericParameterTypes()[0].equals(fieldType) || m.getParameterTypes()[0].equals(fieldTypeWrapper)) - && (m.getReturnType().equals(Void.TYPE) || m.getReturnType().equals(clazz)) - ) { - hasSetter = true; - } - } - - if (!hasGetter) { - String msg = clazz + " does not contain a getter for field " + f.getName(); - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); - } - if (!hasSetter) { - String msg = clazz + " does not contain a setter for field " + f.getName(); - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); - } - } + // check default constructor + Constructor defaultConstructor; + try { + defaultConstructor = clazz.getDeclaredConstructor(); + } catch (NoSuchMethodException e) { + String msg = + "Class [" + + clazz.getName() + + "] has no default constructor, it cannot be used as a POJO type"; + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg), e); } - - /** - * Resolve encoder from function generics. - */ - - public static IEncoder resolveFunction(Class baseClass, Object function) { - return resolveFunction(baseClass, function, 0); + if (!Modifier.isPublic(defaultConstructor.getModifiers())) { + String msg = + "The default constructor of [" + + clazz + + "] is not Public, it cannot be used as a POJO type"; + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); } - public static IEncoder resolveFunction(Class baseClass, Object function, int typeParaIdx) { - if (baseClass == Object.class) { - return null; - } - Class funcClass = function.getClass(); - List typeTree = new ArrayList<>(); - Type paraType = extractParaType(typeTree, baseClass, funcClass, typeParaIdx); - - if (paraType instanceof TypeVariable) { - Type concreteType = getConcreteTypeofTypeVariable(typeTree, (TypeVariable) paraType); - return resolveType(concreteType); - } - - return resolveType(paraType); + List fields = getPojoFields(clazz); + if (fields.isEmpty()) { + String msg = "Class [" + clazz.getName() + "] has no declared fields"; + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); } - - private static Type extractParaType(List typeTree, - Class baseClass, - Class functionClass, - int typeParaIdx) { - Type[] gInterfaces = functionClass.getGenericInterfaces(); - for (Type gInterface : gInterfaces) { - Type type = extractParaTypeFromGeneric(typeTree, baseClass, gInterface, typeParaIdx); - if (type != null) { - return type; - } - } - Type gClass = functionClass.getGenericSuperclass(); - return extractParaTypeFromGeneric(typeTree, baseClass, gClass, typeParaIdx); + for (Field field : fields) { + checkValidPojoField(field, clazz); } - - private static Type extractParaTypeFromGeneric(List typeTree, - Class baseClass, - Type type, - int typeParaIdx) { - if (type instanceof ParameterizedType) { - ParameterizedType parameterizedType = (ParameterizedType) type; - typeTree.add(parameterizedType); - Class rawType = (Class) parameterizedType.getRawType(); - if (baseClass.equals(rawType)) { - return parameterizedType.getActualTypeArguments()[typeParaIdx]; - } - if (baseClass.isAssignableFrom(rawType)) { - return extractParaType(typeTree, baseClass, rawType, typeParaIdx); - } + } + + private static void checkValidPojoField(Field f, Class clazz) { + if (!Modifier.isPublic(f.getModifiers())) { + final String fieldNameLow = f.getName().toLowerCase().replaceAll(UNDERLINE, EMPTY); + + Type fieldType = f.getGenericType(); + Class fieldTypeWrapper = f.getType(); + if (fieldTypeWrapper.isPrimitive()) { + fieldTypeWrapper = Encoders.PRIMITIVE_WRAPPER_MAP.get(fieldTypeWrapper); + } + if (fieldType instanceof TypeVariable) { + String msg = "do not support generics yet"; + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); + } + + boolean hasGetter = false; + boolean hasSetter = false; + for (Method m : clazz.getMethods()) { + final String methodNameLow = m.getName().toLowerCase().replaceAll(UNDERLINE, EMPTY); + // check getter, one qualified method is ok + if (!hasGetter + && (methodNameLow.equals("get" + fieldNameLow) + || methodNameLow.equals("is" + fieldNameLow)) + && m.getParameterTypes().length == 0 + && (m.getGenericReturnType().equals(fieldType) + || m.getReturnType().equals(fieldTypeWrapper))) { + hasGetter = true; } - if (type instanceof Class) { - Class clazz = (Class) type; - if (baseClass.isAssignableFrom(clazz)) { - return extractParaType(typeTree, baseClass, clazz, typeParaIdx); - } + // check setter, one qualified method is ok + if (!hasSetter + && methodNameLow.equals("set" + fieldNameLow) + && m.getParameterTypes().length == 1 + && (m.getGenericParameterTypes()[0].equals(fieldType) + || m.getParameterTypes()[0].equals(fieldTypeWrapper)) + && (m.getReturnType().equals(Void.TYPE) || m.getReturnType().equals(clazz))) { + hasSetter = true; } - return null; + } + + if (!hasGetter) { + String msg = clazz + " does not contain a getter for field " + f.getName(); + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); + } + if (!hasSetter) { + String msg = clazz + " does not contain a setter for field " + f.getName(); + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); + } } + } - private static Type getConcreteTypeofTypeVariable(List typeTree, TypeVariable typeVar) { - TypeVariable curTypeVar = typeVar; - for (int i = typeTree.size() - 1; i >= 0; i--) { - ParameterizedType curType = typeTree.get(i); - Class rawType = (Class) curType.getRawType(); - TypeVariable>[] rawTypeParameters = rawType.getTypeParameters(); - for (int idx = 0; idx < rawTypeParameters.length; idx++) { - TypeVariable rawTypeVar = rawType.getTypeParameters()[idx]; - // check if variable match - if (curTypeVar.getName().equals(rawTypeVar.getName()) - && curTypeVar.getGenericDeclaration().equals(rawTypeVar.getGenericDeclaration())) { - Type actualTypeArgument = curType.getActualTypeArguments()[idx]; - if (actualTypeArgument instanceof TypeVariable) { - // another type variable level - curTypeVar = (TypeVariable) actualTypeArgument; - } else { - // class - return actualTypeArgument; - } - } - } - } - // most likely type erasure - return curTypeVar; - } + /** Resolve encoder from function generics. */ + public static IEncoder resolveFunction(Class baseClass, Object function) { + return resolveFunction(baseClass, function, 0); + } - public static boolean isClassType(Type t) { - return t instanceof Class || t instanceof ParameterizedType; + public static IEncoder resolveFunction(Class baseClass, Object function, int typeParaIdx) { + if (baseClass == Object.class) { + return null; } + Class funcClass = function.getClass(); + List typeTree = new ArrayList<>(); + Type paraType = extractParaType(typeTree, baseClass, funcClass, typeParaIdx); - public static Class typeToClass(Type t) { - if (t instanceof Class) { - return (Class) t; - } else if (t instanceof ParameterizedType) { - return ((Class) ((ParameterizedType) t).getRawType()); - } - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError("Cannot convert type to class")); + if (paraType instanceof TypeVariable) { + Type concreteType = getConcreteTypeofTypeVariable(typeTree, (TypeVariable) paraType); + return resolveType(concreteType); } - private static int countClassFields(Class clazz) { - int fieldCount = 0; - for (Field field : clazz.getFields()) { - if (!Modifier.isStatic(field.getModifiers()) - && !Modifier.isTransient(field.getModifiers())) { - fieldCount++; - } + return resolveType(paraType); + } + + private static Type extractParaType( + List typeTree, + Class baseClass, + Class functionClass, + int typeParaIdx) { + Type[] gInterfaces = functionClass.getGenericInterfaces(); + for (Type gInterface : gInterfaces) { + Type type = extractParaTypeFromGeneric(typeTree, baseClass, gInterface, typeParaIdx); + if (type != null) { + return type; + } + } + Type gClass = functionClass.getGenericSuperclass(); + return extractParaTypeFromGeneric(typeTree, baseClass, gClass, typeParaIdx); + } + + private static Type extractParaTypeFromGeneric( + List typeTree, Class baseClass, Type type, int typeParaIdx) { + if (type instanceof ParameterizedType) { + ParameterizedType parameterizedType = (ParameterizedType) type; + typeTree.add(parameterizedType); + Class rawType = (Class) parameterizedType.getRawType(); + if (baseClass.equals(rawType)) { + return parameterizedType.getActualTypeArguments()[typeParaIdx]; + } + if (baseClass.isAssignableFrom(rawType)) { + return extractParaType(typeTree, baseClass, rawType, typeParaIdx); + } + } + if (type instanceof Class) { + Class clazz = (Class) type; + if (baseClass.isAssignableFrom(clazz)) { + return extractParaType(typeTree, baseClass, clazz, typeParaIdx); + } + } + return null; + } + + private static Type getConcreteTypeofTypeVariable( + List typeTree, TypeVariable typeVar) { + TypeVariable curTypeVar = typeVar; + for (int i = typeTree.size() - 1; i >= 0; i--) { + ParameterizedType curType = typeTree.get(i); + Class rawType = (Class) curType.getRawType(); + TypeVariable>[] rawTypeParameters = rawType.getTypeParameters(); + for (int idx = 0; idx < rawTypeParameters.length; idx++) { + TypeVariable rawTypeVar = rawType.getTypeParameters()[idx]; + // check if variable match + if (curTypeVar.getName().equals(rawTypeVar.getName()) + && curTypeVar.getGenericDeclaration().equals(rawTypeVar.getGenericDeclaration())) { + Type actualTypeArgument = curType.getActualTypeArguments()[idx]; + if (actualTypeArgument instanceof TypeVariable) { + // another type variable level + curTypeVar = (TypeVariable) actualTypeArgument; + } else { + // class + return actualTypeArgument; + } } - return fieldCount; + } } - - private static List getPojoFields(Class clazz) { - return Arrays.stream(clazz.getDeclaredFields()) - .filter(field -> !Modifier.isTransient(field.getModifiers()) && !Modifier.isStatic(field.getModifiers())) - .collect(Collectors.toList()); + // most likely type erasure + return curTypeVar; + } + + public static boolean isClassType(Type t) { + return t instanceof Class || t instanceof ParameterizedType; + } + + public static Class typeToClass(Type t) { + if (t instanceof Class) { + return (Class) t; + } else if (t instanceof ParameterizedType) { + return ((Class) ((ParameterizedType) t).getRawType()); } - + throw new GeaflowRuntimeException( + RuntimeErrors.INST.typeSysError("Cannot convert type to class")); + } + + private static int countClassFields(Class clazz) { + int fieldCount = 0; + for (Field field : clazz.getFields()) { + if (!Modifier.isStatic(field.getModifiers()) && !Modifier.isTransient(field.getModifiers())) { + fieldCount++; + } + } + return fieldCount; + } + + private static List getPojoFields(Class clazz) { + return Arrays.stream(clazz.getDeclaredFields()) + .filter( + field -> + !Modifier.isTransient(field.getModifiers()) + && !Modifier.isStatic(field.getModifiers())) + .collect(Collectors.toList()); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/Encoders.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/Encoders.java index 17274adf9..8cd05529a 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/Encoders.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/Encoders.java @@ -21,6 +21,7 @@ import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.encoder.impl.BooleanArrEncoder; import org.apache.geaflow.common.encoder.impl.BooleanEncoder; import org.apache.geaflow.common.encoder.impl.ByteArrEncoder; @@ -46,90 +47,93 @@ public class Encoders { - public static final IEncoder BOOLEAN = BooleanEncoder.INSTANCE; - public static final IEncoder BYTE = ByteEncoder.INSTANCE; - public static final IEncoder SHORT = ShortEncoder.INSTANCE; - public static final IEncoder INTEGER = IntegerEncoder.INSTANCE; - public static final IEncoder LONG = LongEncoder.INSTANCE; - public static final IEncoder FLOAT = FloatEncoder.INSTANCE; - public static final IEncoder DOUBLE = DoubleEncoder.INSTANCE; - public static final IEncoder CHARACTER = CharacterEncoder.INSTANCE; - public static final IEncoder STRING = StringEncoder.INSTANCE; - - public static final IEncoder BOOLEAN_ARR = BooleanArrEncoder.INSTANCE; - public static final IEncoder BYTE_ARR = ByteArrEncoder.INSTANCE; - public static final IEncoder SHORT_ARR = ShortArrEncoder.INSTANCE; - public static final IEncoder INTEGER_ARR = IntegerArrEncoder.INSTANCE; - public static final IEncoder LONG_ARR = LongArrEncoder.INSTANCE; - public static final IEncoder FLOAT_ARR = FloatArrEncoder.INSTANCE; - public static final IEncoder DOUBLE_ARR = DoubleArrEncoder.INSTANCE; - public static final IEncoder CHARACTER_ARR = CharacterArrEncoder.INSTANCE; + public static final IEncoder BOOLEAN = BooleanEncoder.INSTANCE; + public static final IEncoder BYTE = ByteEncoder.INSTANCE; + public static final IEncoder SHORT = ShortEncoder.INSTANCE; + public static final IEncoder INTEGER = IntegerEncoder.INSTANCE; + public static final IEncoder LONG = LongEncoder.INSTANCE; + public static final IEncoder FLOAT = FloatEncoder.INSTANCE; + public static final IEncoder DOUBLE = DoubleEncoder.INSTANCE; + public static final IEncoder CHARACTER = CharacterEncoder.INSTANCE; + public static final IEncoder STRING = StringEncoder.INSTANCE; - public static final Map, IEncoder> PRIMITIVE_ENCODER_MAP = new HashMap<>(); + public static final IEncoder BOOLEAN_ARR = BooleanArrEncoder.INSTANCE; + public static final IEncoder BYTE_ARR = ByteArrEncoder.INSTANCE; + public static final IEncoder SHORT_ARR = ShortArrEncoder.INSTANCE; + public static final IEncoder INTEGER_ARR = IntegerArrEncoder.INSTANCE; + public static final IEncoder LONG_ARR = LongArrEncoder.INSTANCE; + public static final IEncoder FLOAT_ARR = FloatArrEncoder.INSTANCE; + public static final IEncoder DOUBLE_ARR = DoubleArrEncoder.INSTANCE; + public static final IEncoder CHARACTER_ARR = CharacterArrEncoder.INSTANCE; - static { - PRIMITIVE_ENCODER_MAP.put(boolean.class, BOOLEAN); - PRIMITIVE_ENCODER_MAP.put(Boolean.class, BOOLEAN); - PRIMITIVE_ENCODER_MAP.put(byte.class, BYTE); - PRIMITIVE_ENCODER_MAP.put(Byte.class, BYTE); - PRIMITIVE_ENCODER_MAP.put(short.class, SHORT); - PRIMITIVE_ENCODER_MAP.put(Short.class, SHORT); - PRIMITIVE_ENCODER_MAP.put(int.class, INTEGER); - PRIMITIVE_ENCODER_MAP.put(Integer.class, INTEGER); - PRIMITIVE_ENCODER_MAP.put(long.class, LONG); - PRIMITIVE_ENCODER_MAP.put(Long.class, LONG); - PRIMITIVE_ENCODER_MAP.put(float.class, FLOAT); - PRIMITIVE_ENCODER_MAP.put(Float.class, FLOAT); - PRIMITIVE_ENCODER_MAP.put(double.class, DOUBLE); - PRIMITIVE_ENCODER_MAP.put(Double.class, DOUBLE); - PRIMITIVE_ENCODER_MAP.put(char.class, CHARACTER); - PRIMITIVE_ENCODER_MAP.put(Character.class, CHARACTER); - PRIMITIVE_ENCODER_MAP.put(String.class, STRING); - } + public static final Map, IEncoder> PRIMITIVE_ENCODER_MAP = new HashMap<>(); - public static final Map, IEncoder> PRIMITIVE_ARR_ENCODER_MAP = new HashMap<>(); + static { + PRIMITIVE_ENCODER_MAP.put(boolean.class, BOOLEAN); + PRIMITIVE_ENCODER_MAP.put(Boolean.class, BOOLEAN); + PRIMITIVE_ENCODER_MAP.put(byte.class, BYTE); + PRIMITIVE_ENCODER_MAP.put(Byte.class, BYTE); + PRIMITIVE_ENCODER_MAP.put(short.class, SHORT); + PRIMITIVE_ENCODER_MAP.put(Short.class, SHORT); + PRIMITIVE_ENCODER_MAP.put(int.class, INTEGER); + PRIMITIVE_ENCODER_MAP.put(Integer.class, INTEGER); + PRIMITIVE_ENCODER_MAP.put(long.class, LONG); + PRIMITIVE_ENCODER_MAP.put(Long.class, LONG); + PRIMITIVE_ENCODER_MAP.put(float.class, FLOAT); + PRIMITIVE_ENCODER_MAP.put(Float.class, FLOAT); + PRIMITIVE_ENCODER_MAP.put(double.class, DOUBLE); + PRIMITIVE_ENCODER_MAP.put(Double.class, DOUBLE); + PRIMITIVE_ENCODER_MAP.put(char.class, CHARACTER); + PRIMITIVE_ENCODER_MAP.put(Character.class, CHARACTER); + PRIMITIVE_ENCODER_MAP.put(String.class, STRING); + } - static { - PRIMITIVE_ARR_ENCODER_MAP.put(boolean[].class, BOOLEAN_ARR); - PRIMITIVE_ARR_ENCODER_MAP.put(Boolean[].class, new GenericArrayEncoder<>(BOOLEAN, Boolean[]::new)); - PRIMITIVE_ARR_ENCODER_MAP.put(byte[].class, BYTE_ARR); - PRIMITIVE_ARR_ENCODER_MAP.put(Byte[].class, new GenericArrayEncoder<>(BYTE, Byte[]::new)); - PRIMITIVE_ARR_ENCODER_MAP.put(short[].class, SHORT_ARR); - PRIMITIVE_ARR_ENCODER_MAP.put(Short[].class, new GenericArrayEncoder<>(SHORT, Short[]::new)); - PRIMITIVE_ARR_ENCODER_MAP.put(int[].class, INTEGER_ARR); - PRIMITIVE_ARR_ENCODER_MAP.put(Integer[].class, new GenericArrayEncoder<>(INTEGER, Integer[]::new)); - PRIMITIVE_ARR_ENCODER_MAP.put(long[].class, LONG_ARR); - PRIMITIVE_ARR_ENCODER_MAP.put(Long[].class, new GenericArrayEncoder<>(LONG, Long[]::new)); - PRIMITIVE_ARR_ENCODER_MAP.put(float[].class, FLOAT_ARR); - PRIMITIVE_ARR_ENCODER_MAP.put(Float[].class, new GenericArrayEncoder<>(FLOAT, Float[]::new)); - PRIMITIVE_ARR_ENCODER_MAP.put(double[].class, DOUBLE_ARR); - PRIMITIVE_ARR_ENCODER_MAP.put(Double[].class, new GenericArrayEncoder<>(DOUBLE, Double[]::new)); - PRIMITIVE_ARR_ENCODER_MAP.put(char[].class, CHARACTER_ARR); - PRIMITIVE_ARR_ENCODER_MAP.put(Character[].class, new GenericArrayEncoder<>(CHARACTER, Character[]::new)); - PRIMITIVE_ARR_ENCODER_MAP.put(String[].class, new GenericArrayEncoder<>(STRING, String[]::new)); - } + public static final Map, IEncoder> PRIMITIVE_ARR_ENCODER_MAP = new HashMap<>(); - public static final Map, Class> PRIMITIVE_WRAPPER_MAP = new HashMap<>(); + static { + PRIMITIVE_ARR_ENCODER_MAP.put(boolean[].class, BOOLEAN_ARR); + PRIMITIVE_ARR_ENCODER_MAP.put( + Boolean[].class, new GenericArrayEncoder<>(BOOLEAN, Boolean[]::new)); + PRIMITIVE_ARR_ENCODER_MAP.put(byte[].class, BYTE_ARR); + PRIMITIVE_ARR_ENCODER_MAP.put(Byte[].class, new GenericArrayEncoder<>(BYTE, Byte[]::new)); + PRIMITIVE_ARR_ENCODER_MAP.put(short[].class, SHORT_ARR); + PRIMITIVE_ARR_ENCODER_MAP.put(Short[].class, new GenericArrayEncoder<>(SHORT, Short[]::new)); + PRIMITIVE_ARR_ENCODER_MAP.put(int[].class, INTEGER_ARR); + PRIMITIVE_ARR_ENCODER_MAP.put( + Integer[].class, new GenericArrayEncoder<>(INTEGER, Integer[]::new)); + PRIMITIVE_ARR_ENCODER_MAP.put(long[].class, LONG_ARR); + PRIMITIVE_ARR_ENCODER_MAP.put(Long[].class, new GenericArrayEncoder<>(LONG, Long[]::new)); + PRIMITIVE_ARR_ENCODER_MAP.put(float[].class, FLOAT_ARR); + PRIMITIVE_ARR_ENCODER_MAP.put(Float[].class, new GenericArrayEncoder<>(FLOAT, Float[]::new)); + PRIMITIVE_ARR_ENCODER_MAP.put(double[].class, DOUBLE_ARR); + PRIMITIVE_ARR_ENCODER_MAP.put(Double[].class, new GenericArrayEncoder<>(DOUBLE, Double[]::new)); + PRIMITIVE_ARR_ENCODER_MAP.put(char[].class, CHARACTER_ARR); + PRIMITIVE_ARR_ENCODER_MAP.put( + Character[].class, new GenericArrayEncoder<>(CHARACTER, Character[]::new)); + PRIMITIVE_ARR_ENCODER_MAP.put(String[].class, new GenericArrayEncoder<>(STRING, String[]::new)); + } - static { - PRIMITIVE_WRAPPER_MAP.put(Boolean.TYPE, Boolean.class); - PRIMITIVE_WRAPPER_MAP.put(Byte.TYPE, Byte.class); - PRIMITIVE_WRAPPER_MAP.put(Character.TYPE, Character.class); - PRIMITIVE_WRAPPER_MAP.put(Short.TYPE, Short.class); - PRIMITIVE_WRAPPER_MAP.put(Integer.TYPE, Integer.class); - PRIMITIVE_WRAPPER_MAP.put(Long.TYPE, Long.class); - PRIMITIVE_WRAPPER_MAP.put(Double.TYPE, Double.class); - PRIMITIVE_WRAPPER_MAP.put(Float.TYPE, Float.class); - PRIMITIVE_WRAPPER_MAP.put(Void.TYPE, Void.TYPE); - } + public static final Map, Class> PRIMITIVE_WRAPPER_MAP = new HashMap<>(); - public static IEncoder> tuple(IEncoder encoder0, IEncoder encoder1) { - return new TupleEncoder<>(encoder0, encoder1); - } + static { + PRIMITIVE_WRAPPER_MAP.put(Boolean.TYPE, Boolean.class); + PRIMITIVE_WRAPPER_MAP.put(Byte.TYPE, Byte.class); + PRIMITIVE_WRAPPER_MAP.put(Character.TYPE, Character.class); + PRIMITIVE_WRAPPER_MAP.put(Short.TYPE, Short.class); + PRIMITIVE_WRAPPER_MAP.put(Integer.TYPE, Integer.class); + PRIMITIVE_WRAPPER_MAP.put(Long.TYPE, Long.class); + PRIMITIVE_WRAPPER_MAP.put(Double.TYPE, Double.class); + PRIMITIVE_WRAPPER_MAP.put(Float.TYPE, Float.class); + PRIMITIVE_WRAPPER_MAP.put(Void.TYPE, Void.TYPE); + } - public static IEncoder> triple( - IEncoder encoder0, IEncoder encoder1, IEncoder encoder2) { - return new TripleEncoder<>(encoder0, encoder1, encoder2); - } + public static IEncoder> tuple( + IEncoder encoder0, IEncoder encoder1) { + return new TupleEncoder<>(encoder0, encoder1); + } + public static IEncoder> triple( + IEncoder encoder0, IEncoder encoder1, IEncoder encoder2) { + return new TripleEncoder<>(encoder0, encoder1, encoder2); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/IEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/IEncoder.java index a19a404e8..51a1c2052 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/IEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/IEncoder.java @@ -23,33 +23,33 @@ import java.io.InputStream; import java.io.OutputStream; import java.io.Serializable; + import org.apache.geaflow.common.config.Configuration; public interface IEncoder extends Serializable { - /** - * Init with config. - * - * @param config config - */ - void init(Configuration config); - - /** - * Encode an object to output stream. - * - * @param data data - * @param outputStream output stream - * @throws IOException IO exception - */ - void encode(T data, OutputStream outputStream) throws IOException; + /** + * Init with config. + * + * @param config config + */ + void init(Configuration config); - /** - * Decode an object from input stream. - * - * @param inputStream input stream - * @return data - * @throws IOException IO exception - */ - T decode(InputStream inputStream) throws IOException; + /** + * Encode an object to output stream. + * + * @param data data + * @param outputStream output stream + * @throws IOException IO exception + */ + void encode(T data, OutputStream outputStream) throws IOException; + /** + * Decode an object from input stream. + * + * @param inputStream input stream + * @return data + * @throws IOException IO exception + */ + T decode(InputStream inputStream) throws IOException; } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/RpcMessageEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/RpcMessageEncoder.java index fb12000ac..60cafb593 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/RpcMessageEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/RpcMessageEncoder.java @@ -19,22 +19,24 @@ package org.apache.geaflow.common.encoder; -import com.google.protobuf.ByteString; import java.io.Serializable; + import org.apache.geaflow.common.serialize.SerializerFactory; +import com.google.protobuf.ByteString; + public class RpcMessageEncoder implements Serializable { - public static T decode(ByteString payload) { - if (payload == null || payload.isEmpty()) { - throw new IllegalArgumentException("Cannot decode null or empty ByteString payload"); - } - return SerializerFactory.getKryoSerializer().deserialize(payload.newInput()); + public static T decode(ByteString payload) { + if (payload == null || payload.isEmpty()) { + throw new IllegalArgumentException("Cannot decode null or empty ByteString payload"); } + return SerializerFactory.getKryoSerializer().deserialize(payload.newInput()); + } - public static ByteString encode(T request) { - ByteString.Output output = ByteString.newOutput(); - SerializerFactory.getKryoSerializer().serialize(request, output); - return output.toByteString(); - } + public static ByteString encode(T request) { + ByteString.Output output = ByteString.newOutput(); + SerializerFactory.getKryoSerializer().serialize(request, output); + return output.toByteString(); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/AbstractEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/AbstractEncoder.java index a304a2bf5..8f25cba61 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/AbstractEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/AbstractEncoder.java @@ -24,12 +24,10 @@ public abstract class AbstractEncoder implements IEncoder { - protected static final String MSG_ARR_TOO_BIG = "arr is too big"; - protected static final int NULL = 0; - protected static final int NOT_NULL = 1; - - @Override - public void init(Configuration config) { - } + protected static final String MSG_ARR_TOO_BIG = "arr is too big"; + protected static final int NULL = 0; + protected static final int NOT_NULL = 1; + @Override + public void init(Configuration config) {} } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/BooleanArrEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/BooleanArrEncoder.java index 18abf36dd..1dec54796 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/BooleanArrEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/BooleanArrEncoder.java @@ -22,42 +22,42 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; + import org.apache.geaflow.common.encoder.Encoders; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class BooleanArrEncoder extends AbstractEncoder { - public static final BooleanArrEncoder INSTANCE = new BooleanArrEncoder(); + public static final BooleanArrEncoder INSTANCE = new BooleanArrEncoder(); - @Override - public void encode(boolean[] data, OutputStream outputStream) throws IOException { - if (data == null) { - Encoders.INTEGER.encode(NULL, outputStream); - return; - } - int lenToWrite = data.length + 1; - if (lenToWrite < 0) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(MSG_ARR_TOO_BIG)); - } - Encoders.INTEGER.encode(lenToWrite, outputStream); - for (boolean datum : data) { - Encoders.BOOLEAN.encode(datum, outputStream); - } + @Override + public void encode(boolean[] data, OutputStream outputStream) throws IOException { + if (data == null) { + Encoders.INTEGER.encode(NULL, outputStream); + return; } - - @Override - public boolean[] decode(InputStream inputStream) throws IOException { - int flag = Encoders.INTEGER.decode(inputStream); - if (flag == NULL) { - return null; - } - int length = flag - 1; - boolean[] arr = new boolean[length]; - for (int i = 0; i < length; i++) { - arr[i] = Encoders.BOOLEAN.decode(inputStream); - } - return arr; + int lenToWrite = data.length + 1; + if (lenToWrite < 0) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(MSG_ARR_TOO_BIG)); + } + Encoders.INTEGER.encode(lenToWrite, outputStream); + for (boolean datum : data) { + Encoders.BOOLEAN.encode(datum, outputStream); } + } + @Override + public boolean[] decode(InputStream inputStream) throws IOException { + int flag = Encoders.INTEGER.decode(inputStream); + if (flag == NULL) { + return null; + } + int length = flag - 1; + boolean[] arr = new boolean[length]; + for (int i = 0; i < length; i++) { + arr[i] = Encoders.BOOLEAN.decode(inputStream); + } + return arr; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/BooleanEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/BooleanEncoder.java index 59759cbaa..613f0270b 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/BooleanEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/BooleanEncoder.java @@ -22,36 +22,36 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; + import org.apache.geaflow.common.encoder.Encoders; public class BooleanEncoder extends AbstractEncoder { - public static final BooleanEncoder INSTANCE = new BooleanEncoder(); - - private static final int TRUE = 1; - private static final int FALSE = 2; - - @Override - public void encode(Boolean data, OutputStream outputStream) throws IOException { - // 0: null - // 1: true - // 2: false - if (data == null) { - Encoders.INTEGER.encode(NULL, outputStream); - } else if (data) { - Encoders.INTEGER.encode(TRUE, outputStream); - } else { - Encoders.INTEGER.encode(FALSE, outputStream); - } + public static final BooleanEncoder INSTANCE = new BooleanEncoder(); + + private static final int TRUE = 1; + private static final int FALSE = 2; + + @Override + public void encode(Boolean data, OutputStream outputStream) throws IOException { + // 0: null + // 1: true + // 2: false + if (data == null) { + Encoders.INTEGER.encode(NULL, outputStream); + } else if (data) { + Encoders.INTEGER.encode(TRUE, outputStream); + } else { + Encoders.INTEGER.encode(FALSE, outputStream); } + } - @Override - public Boolean decode(InputStream inputStream) throws IOException { - Integer value = Encoders.INTEGER.decode(inputStream); - if (value == NULL) { - return null; - } - return value == TRUE; + @Override + public Boolean decode(InputStream inputStream) throws IOException { + Integer value = Encoders.INTEGER.decode(inputStream); + if (value == NULL) { + return null; } - + return value == TRUE; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/ByteArrEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/ByteArrEncoder.java index 733adc0f9..de6fa5709 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/ByteArrEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/ByteArrEncoder.java @@ -22,42 +22,42 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; + import org.apache.geaflow.common.encoder.Encoders; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class ByteArrEncoder extends AbstractEncoder { - public static final ByteArrEncoder INSTANCE = new ByteArrEncoder(); + public static final ByteArrEncoder INSTANCE = new ByteArrEncoder(); - @Override - public void encode(byte[] data, OutputStream outputStream) throws IOException { - if (data == null) { - Encoders.INTEGER.encode(NULL, outputStream); - return; - } - int lenToWrite = data.length + 1; - if (lenToWrite < 0) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(MSG_ARR_TOO_BIG)); - } - Encoders.INTEGER.encode(lenToWrite, outputStream); - for (byte datum : data) { - Encoders.BYTE.encode(datum, outputStream); - } + @Override + public void encode(byte[] data, OutputStream outputStream) throws IOException { + if (data == null) { + Encoders.INTEGER.encode(NULL, outputStream); + return; } - - @Override - public byte[] decode(InputStream inputStream) throws IOException { - int flag = Encoders.INTEGER.decode(inputStream); - if (flag == NULL) { - return null; - } - int length = flag - 1; - byte[] arr = new byte[length]; - for (int i = 0; i < length; i++) { - arr[i] = Encoders.BYTE.decode(inputStream); - } - return arr; + int lenToWrite = data.length + 1; + if (lenToWrite < 0) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(MSG_ARR_TOO_BIG)); + } + Encoders.INTEGER.encode(lenToWrite, outputStream); + for (byte datum : data) { + Encoders.BYTE.encode(datum, outputStream); } + } + @Override + public byte[] decode(InputStream inputStream) throws IOException { + int flag = Encoders.INTEGER.decode(inputStream); + if (flag == NULL) { + return null; + } + int length = flag - 1; + byte[] arr = new byte[length]; + for (int i = 0; i < length; i++) { + arr[i] = Encoders.BYTE.decode(inputStream); + } + return arr; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/ByteEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/ByteEncoder.java index f248bb38a..6c702468d 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/ByteEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/ByteEncoder.java @@ -25,16 +25,15 @@ public class ByteEncoder extends AbstractEncoder { - public static final ByteEncoder INSTANCE = new ByteEncoder(); + public static final ByteEncoder INSTANCE = new ByteEncoder(); - @Override - public void encode(Byte data, OutputStream outputStream) throws IOException { - outputStream.write(data); - } - - @Override - public Byte decode(InputStream inputStream) throws IOException { - return (byte) inputStream.read(); - } + @Override + public void encode(Byte data, OutputStream outputStream) throws IOException { + outputStream.write(data); + } + @Override + public Byte decode(InputStream inputStream) throws IOException { + return (byte) inputStream.read(); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/CharacterArrEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/CharacterArrEncoder.java index 1e0c50497..54a25cad3 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/CharacterArrEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/CharacterArrEncoder.java @@ -22,42 +22,42 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; + import org.apache.geaflow.common.encoder.Encoders; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class CharacterArrEncoder extends AbstractEncoder { - public static final CharacterArrEncoder INSTANCE = new CharacterArrEncoder(); + public static final CharacterArrEncoder INSTANCE = new CharacterArrEncoder(); - @Override - public void encode(char[] data, OutputStream outputStream) throws IOException { - if (data == null) { - Encoders.INTEGER.encode(NULL, outputStream); - return; - } - int lenToWrite = data.length + 1; - if (lenToWrite < 0) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(MSG_ARR_TOO_BIG)); - } - Encoders.INTEGER.encode(lenToWrite, outputStream); - for (char datum : data) { - Encoders.CHARACTER.encode(datum, outputStream); - } + @Override + public void encode(char[] data, OutputStream outputStream) throws IOException { + if (data == null) { + Encoders.INTEGER.encode(NULL, outputStream); + return; } - - @Override - public char[] decode(InputStream inputStream) throws IOException { - int flag = Encoders.INTEGER.decode(inputStream); - if (flag == NULL) { - return null; - } - int length = flag - 1; - char[] arr = new char[length]; - for (int i = 0; i < length; i++) { - arr[i] = Encoders.CHARACTER.decode(inputStream); - } - return arr; + int lenToWrite = data.length + 1; + if (lenToWrite < 0) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(MSG_ARR_TOO_BIG)); + } + Encoders.INTEGER.encode(lenToWrite, outputStream); + for (char datum : data) { + Encoders.CHARACTER.encode(datum, outputStream); } + } + @Override + public char[] decode(InputStream inputStream) throws IOException { + int flag = Encoders.INTEGER.decode(inputStream); + if (flag == NULL) { + return null; + } + int length = flag - 1; + char[] arr = new char[length]; + for (int i = 0; i < length; i++) { + arr[i] = Encoders.CHARACTER.decode(inputStream); + } + return arr; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/CharacterEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/CharacterEncoder.java index 1c7ea81e8..7f1ee8100 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/CharacterEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/CharacterEncoder.java @@ -25,19 +25,18 @@ public class CharacterEncoder extends AbstractEncoder { - public static final CharacterEncoder INSTANCE = new CharacterEncoder(); + public static final CharacterEncoder INSTANCE = new CharacterEncoder(); - @Override - public void encode(Character data, OutputStream outputStream) throws IOException { - outputStream.write(data); - outputStream.write(data >> 8); - } - - @Override - public Character decode(InputStream inputStream) throws IOException { - int b1 = inputStream.read(); - int b2 = inputStream.read(); - return (char) (b1 | (b2 << 8)); - } + @Override + public void encode(Character data, OutputStream outputStream) throws IOException { + outputStream.write(data); + outputStream.write(data >> 8); + } + @Override + public Character decode(InputStream inputStream) throws IOException { + int b1 = inputStream.read(); + int b2 = inputStream.read(); + return (char) (b1 | (b2 << 8)); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/DoubleArrEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/DoubleArrEncoder.java index 8bbff8953..1f4fca3ae 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/DoubleArrEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/DoubleArrEncoder.java @@ -22,42 +22,42 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; + import org.apache.geaflow.common.encoder.Encoders; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class DoubleArrEncoder extends AbstractEncoder { - public static final DoubleArrEncoder INSTANCE = new DoubleArrEncoder(); + public static final DoubleArrEncoder INSTANCE = new DoubleArrEncoder(); - @Override - public void encode(double[] data, OutputStream outputStream) throws IOException { - if (data == null) { - Encoders.INTEGER.encode(NULL, outputStream); - return; - } - int lenToWrite = data.length + 1; - if (lenToWrite < 0) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(MSG_ARR_TOO_BIG)); - } - Encoders.INTEGER.encode(lenToWrite, outputStream); - for (double datum : data) { - Encoders.DOUBLE.encode(datum, outputStream); - } + @Override + public void encode(double[] data, OutputStream outputStream) throws IOException { + if (data == null) { + Encoders.INTEGER.encode(NULL, outputStream); + return; } - - @Override - public double[] decode(InputStream inputStream) throws IOException { - int flag = Encoders.INTEGER.decode(inputStream); - if (flag == NULL) { - return null; - } - int length = flag - 1; - double[] arr = new double[length]; - for (int i = 0; i < length; i++) { - arr[i] = Encoders.DOUBLE.decode(inputStream); - } - return arr; + int lenToWrite = data.length + 1; + if (lenToWrite < 0) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(MSG_ARR_TOO_BIG)); + } + Encoders.INTEGER.encode(lenToWrite, outputStream); + for (double datum : data) { + Encoders.DOUBLE.encode(datum, outputStream); } + } + @Override + public double[] decode(InputStream inputStream) throws IOException { + int flag = Encoders.INTEGER.decode(inputStream); + if (flag == NULL) { + return null; + } + int length = flag - 1; + double[] arr = new double[length]; + for (int i = 0; i < length; i++) { + arr[i] = Encoders.DOUBLE.decode(inputStream); + } + return arr; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/DoubleEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/DoubleEncoder.java index c5f7384b5..62b41ee73 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/DoubleEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/DoubleEncoder.java @@ -22,20 +22,20 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; + import org.apache.geaflow.common.encoder.Encoders; public class DoubleEncoder extends AbstractEncoder { - public static final DoubleEncoder INSTANCE = new DoubleEncoder(); - - @Override - public void encode(Double data, OutputStream outputStream) throws IOException { - Encoders.LONG.encode(Double.doubleToLongBits(data), outputStream); - } + public static final DoubleEncoder INSTANCE = new DoubleEncoder(); - @Override - public Double decode(InputStream inputStream) throws IOException { - return Double.longBitsToDouble(Encoders.LONG.decode(inputStream)); - } + @Override + public void encode(Double data, OutputStream outputStream) throws IOException { + Encoders.LONG.encode(Double.doubleToLongBits(data), outputStream); + } + @Override + public Double decode(InputStream inputStream) throws IOException { + return Double.longBitsToDouble(Encoders.LONG.decode(inputStream)); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/EnumEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/EnumEncoder.java index fa124767c..f9746f8f9 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/EnumEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/EnumEncoder.java @@ -24,6 +24,7 @@ import java.io.OutputStream; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.encoder.Encoders; import org.apache.geaflow.common.errorcode.RuntimeErrors; @@ -31,44 +32,43 @@ public class EnumEncoder extends AbstractEncoder { - private final Class enumClass; - private T[] values; - private Map value2ordinal; + private final Class enumClass; + private T[] values; + private Map value2ordinal; - public EnumEncoder(Class enumClass) { - if (!Enum.class.isAssignableFrom(enumClass)) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError( - enumClass.getCanonicalName() + "is not an enum")); - } - this.enumClass = enumClass; + public EnumEncoder(Class enumClass) { + if (!Enum.class.isAssignableFrom(enumClass)) { + throw new GeaflowRuntimeException( + RuntimeErrors.INST.typeSysError(enumClass.getCanonicalName() + "is not an enum")); } + this.enumClass = enumClass; + } - @Override - public void init(Configuration config) { - this.value2ordinal = new HashMap<>(); - this.values = this.enumClass.getEnumConstants(); - int i = 0; - for (T value : this.values) { - this.value2ordinal.put(value, ++i); - } + @Override + public void init(Configuration config) { + this.value2ordinal = new HashMap<>(); + this.values = this.enumClass.getEnumConstants(); + int i = 0; + for (T value : this.values) { + this.value2ordinal.put(value, ++i); } + } - @Override - public void encode(T data, OutputStream outputStream) throws IOException { - if (data == null) { - Encoders.INTEGER.encode(NULL, outputStream); - return; - } - Encoders.INTEGER.encode(this.value2ordinal.get(data), outputStream); + @Override + public void encode(T data, OutputStream outputStream) throws IOException { + if (data == null) { + Encoders.INTEGER.encode(NULL, outputStream); + return; } + Encoders.INTEGER.encode(this.value2ordinal.get(data), outputStream); + } - @Override - public T decode(InputStream inputStream) throws IOException { - int flag = Encoders.INTEGER.decode(inputStream); - if (flag == NULL) { - return null; - } - return this.values[flag - 1]; + @Override + public T decode(InputStream inputStream) throws IOException { + int flag = Encoders.INTEGER.decode(inputStream); + if (flag == NULL) { + return null; } - + return this.values[flag - 1]; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/FloatArrEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/FloatArrEncoder.java index 21f0afc29..098560538 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/FloatArrEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/FloatArrEncoder.java @@ -22,42 +22,42 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; + import org.apache.geaflow.common.encoder.Encoders; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class FloatArrEncoder extends AbstractEncoder { - public static final FloatArrEncoder INSTANCE = new FloatArrEncoder(); + public static final FloatArrEncoder INSTANCE = new FloatArrEncoder(); - @Override - public void encode(float[] data, OutputStream outputStream) throws IOException { - if (data == null) { - Encoders.INTEGER.encode(NULL, outputStream); - return; - } - int lenToWrite = data.length + 1; - if (lenToWrite < 0) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(MSG_ARR_TOO_BIG)); - } - Encoders.INTEGER.encode(lenToWrite, outputStream); - for (float datum : data) { - Encoders.FLOAT.encode(datum, outputStream); - } + @Override + public void encode(float[] data, OutputStream outputStream) throws IOException { + if (data == null) { + Encoders.INTEGER.encode(NULL, outputStream); + return; } - - @Override - public float[] decode(InputStream inputStream) throws IOException { - int flag = Encoders.INTEGER.decode(inputStream); - if (flag == NULL) { - return null; - } - int length = flag - 1; - float[] arr = new float[length]; - for (int i = 0; i < length; i++) { - arr[i] = Encoders.FLOAT.decode(inputStream); - } - return arr; + int lenToWrite = data.length + 1; + if (lenToWrite < 0) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(MSG_ARR_TOO_BIG)); + } + Encoders.INTEGER.encode(lenToWrite, outputStream); + for (float datum : data) { + Encoders.FLOAT.encode(datum, outputStream); } + } + @Override + public float[] decode(InputStream inputStream) throws IOException { + int flag = Encoders.INTEGER.decode(inputStream); + if (flag == NULL) { + return null; + } + int length = flag - 1; + float[] arr = new float[length]; + for (int i = 0; i < length; i++) { + arr[i] = Encoders.FLOAT.decode(inputStream); + } + return arr; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/FloatEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/FloatEncoder.java index 9cb09df52..1c3a0f92e 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/FloatEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/FloatEncoder.java @@ -22,20 +22,20 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; + import org.apache.geaflow.common.encoder.Encoders; public class FloatEncoder extends AbstractEncoder { - public static final FloatEncoder INSTANCE = new FloatEncoder(); - - @Override - public void encode(Float data, OutputStream outputStream) throws IOException { - Encoders.INTEGER.encode(Float.floatToIntBits(data), outputStream); - } + public static final FloatEncoder INSTANCE = new FloatEncoder(); - @Override - public Float decode(InputStream inputStream) throws IOException { - return Float.intBitsToFloat(Encoders.INTEGER.decode(inputStream)); - } + @Override + public void encode(Float data, OutputStream outputStream) throws IOException { + Encoders.INTEGER.encode(Float.floatToIntBits(data), outputStream); + } + @Override + public Float decode(InputStream inputStream) throws IOException { + return Float.intBitsToFloat(Encoders.INTEGER.decode(inputStream)); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/GenericArrayEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/GenericArrayEncoder.java index 4c63be6f9..77081a60b 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/GenericArrayEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/GenericArrayEncoder.java @@ -24,6 +24,7 @@ import java.io.OutputStream; import java.io.Serializable; import java.util.function.IntFunction; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.encoder.Encoders; import org.apache.geaflow.common.encoder.IEncoder; @@ -32,52 +33,50 @@ public class GenericArrayEncoder extends AbstractEncoder { - private final IEncoder encoder; - private final ArrayConstructor constructor; - - public GenericArrayEncoder(IEncoder encoder, ArrayConstructor constructor) { - this.encoder = encoder; - this.constructor = constructor; - } + private final IEncoder encoder; + private final ArrayConstructor constructor; - @Override - public void init(Configuration config) { - this.encoder.init(config); - } + public GenericArrayEncoder(IEncoder encoder, ArrayConstructor constructor) { + this.encoder = encoder; + this.constructor = constructor; + } - @Override - public void encode(T[] data, OutputStream outputStream) throws IOException { - if (data == null) { - Encoders.INTEGER.encode(NULL, outputStream); - return; - } + @Override + public void init(Configuration config) { + this.encoder.init(config); + } - int lenToWrite = data.length + 1; - if (lenToWrite < 0) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(MSG_ARR_TOO_BIG)); - } - Encoders.INTEGER.encode(lenToWrite, outputStream); - for (T datum : data) { - this.encoder.encode(datum, outputStream); - } + @Override + public void encode(T[] data, OutputStream outputStream) throws IOException { + if (data == null) { + Encoders.INTEGER.encode(NULL, outputStream); + return; } - @Override - public T[] decode(InputStream inputStream) throws IOException { - int flag = Encoders.INTEGER.decode(inputStream); - if (flag == NULL) { - return null; - } - int length = flag - 1; - T[] arr = this.constructor.apply(length); - for (int i = 0; i < length; i++) { - arr[i] = this.encoder.decode(inputStream); - } - return arr; + int lenToWrite = data.length + 1; + if (lenToWrite < 0) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(MSG_ARR_TOO_BIG)); + } + Encoders.INTEGER.encode(lenToWrite, outputStream); + for (T datum : data) { + this.encoder.encode(datum, outputStream); } + } - @FunctionalInterface - public interface ArrayConstructor extends IntFunction, Serializable { + @Override + public T[] decode(InputStream inputStream) throws IOException { + int flag = Encoders.INTEGER.decode(inputStream); + if (flag == NULL) { + return null; + } + int length = flag - 1; + T[] arr = this.constructor.apply(length); + for (int i = 0; i < length; i++) { + arr[i] = this.encoder.decode(inputStream); } + return arr; + } + @FunctionalInterface + public interface ArrayConstructor extends IntFunction, Serializable {} } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/IntegerArrEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/IntegerArrEncoder.java index f13cea259..13ab7bb2c 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/IntegerArrEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/IntegerArrEncoder.java @@ -22,42 +22,42 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; + import org.apache.geaflow.common.encoder.Encoders; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class IntegerArrEncoder extends AbstractEncoder { - public static final IntegerArrEncoder INSTANCE = new IntegerArrEncoder(); + public static final IntegerArrEncoder INSTANCE = new IntegerArrEncoder(); - @Override - public void encode(int[] data, OutputStream outputStream) throws IOException { - if (data == null) { - Encoders.INTEGER.encode(NULL, outputStream); - return; - } - int lenToWrite = data.length + 1; - if (lenToWrite < 0) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(MSG_ARR_TOO_BIG)); - } - Encoders.INTEGER.encode(lenToWrite, outputStream); - for (int datum : data) { - Encoders.INTEGER.encode(datum, outputStream); - } + @Override + public void encode(int[] data, OutputStream outputStream) throws IOException { + if (data == null) { + Encoders.INTEGER.encode(NULL, outputStream); + return; } - - @Override - public int[] decode(InputStream inputStream) throws IOException { - int flag = Encoders.INTEGER.decode(inputStream); - if (flag == NULL) { - return null; - } - int length = flag - 1; - int[] arr = new int[length]; - for (int i = 0; i < length; i++) { - arr[i] = Encoders.INTEGER.decode(inputStream); - } - return arr; + int lenToWrite = data.length + 1; + if (lenToWrite < 0) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(MSG_ARR_TOO_BIG)); + } + Encoders.INTEGER.encode(lenToWrite, outputStream); + for (int datum : data) { + Encoders.INTEGER.encode(datum, outputStream); } + } + @Override + public int[] decode(InputStream inputStream) throws IOException { + int flag = Encoders.INTEGER.decode(inputStream); + if (flag == NULL) { + return null; + } + int length = flag - 1; + int[] arr = new int[length]; + for (int i = 0; i < length; i++) { + arr[i] = Encoders.INTEGER.decode(inputStream); + } + return arr; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/IntegerEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/IntegerEncoder.java index c3963f31b..9fdb89f35 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/IntegerEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/IntegerEncoder.java @@ -25,86 +25,85 @@ public class IntegerEncoder extends AbstractEncoder { - public static final IntegerEncoder INSTANCE = new IntegerEncoder(); + public static final IntegerEncoder INSTANCE = new IntegerEncoder(); - // VarInt encoding constants - private static final int VARINT_MASK = 0x7F; - private static final int VARINT_CONTINUE_FLAG = 0x80; - private static final int VARINT_SHIFT = 7; - private static final int DIRECT_WRITE_THRESHOLD = 128; + // VarInt encoding constants + private static final int VARINT_MASK = 0x7F; + private static final int VARINT_CONTINUE_FLAG = 0x80; + private static final int VARINT_SHIFT = 7; + private static final int DIRECT_WRITE_THRESHOLD = 128; - @Override - public void encode(Integer data, OutputStream outputStream) throws IOException { - // if between 0 ~ 127, just write the byte - if (data >= 0 && data < DIRECT_WRITE_THRESHOLD) { - outputStream.write(data); - return; - } + @Override + public void encode(Integer data, OutputStream outputStream) throws IOException { + // if between 0 ~ 127, just write the byte + if (data >= 0 && data < DIRECT_WRITE_THRESHOLD) { + outputStream.write(data); + return; + } - // write var int, takes 1 ~ 5 byte - int value = data; - int varInt = (value & VARINT_MASK); - value >>>= VARINT_SHIFT; + // write var int, takes 1 ~ 5 byte + int value = data; + int varInt = (value & VARINT_MASK); + value >>>= VARINT_SHIFT; - varInt |= VARINT_CONTINUE_FLAG; - varInt |= ((value & VARINT_MASK) << 8); - value >>>= VARINT_SHIFT; - if (value == 0) { - outputStream.write(varInt); - outputStream.write(varInt >> 8); - return; - } - - varInt |= (VARINT_CONTINUE_FLAG << 8); - varInt |= ((value & VARINT_MASK) << 16); - value >>>= VARINT_SHIFT; - if (value == 0) { - outputStream.write(varInt); - outputStream.write(varInt >> 8); - outputStream.write(varInt >> 16); - return; - } + varInt |= VARINT_CONTINUE_FLAG; + varInt |= ((value & VARINT_MASK) << 8); + value >>>= VARINT_SHIFT; + if (value == 0) { + outputStream.write(varInt); + outputStream.write(varInt >> 8); + return; + } - varInt |= (VARINT_CONTINUE_FLAG << 16); - varInt |= ((value & VARINT_MASK) << 24); - value >>>= VARINT_SHIFT; - if (value == 0) { - outputStream.write(varInt); - outputStream.write(varInt >> 8); - outputStream.write(varInt >> 16); - outputStream.write(varInt >> 24); - return; - } + varInt |= (VARINT_CONTINUE_FLAG << 8); + varInt |= ((value & VARINT_MASK) << 16); + value >>>= VARINT_SHIFT; + if (value == 0) { + outputStream.write(varInt); + outputStream.write(varInt >> 8); + outputStream.write(varInt >> 16); + return; + } - varInt |= (VARINT_CONTINUE_FLAG << 24); - outputStream.write(varInt); - outputStream.write(varInt >> 8); - outputStream.write(varInt >> 16); - outputStream.write(varInt >> 24); - outputStream.write(value); + varInt |= (VARINT_CONTINUE_FLAG << 16); + varInt |= ((value & VARINT_MASK) << 24); + value >>>= VARINT_SHIFT; + if (value == 0) { + outputStream.write(varInt); + outputStream.write(varInt >> 8); + outputStream.write(varInt >> 16); + outputStream.write(varInt >> 24); + return; } - @Override - public Integer decode(InputStream inputStream) throws IOException { - int b = inputStream.read(); - int result = b & VARINT_MASK; + varInt |= (VARINT_CONTINUE_FLAG << 24); + outputStream.write(varInt); + outputStream.write(varInt >> 8); + outputStream.write(varInt >> 16); + outputStream.write(varInt >> 24); + outputStream.write(value); + } + + @Override + public Integer decode(InputStream inputStream) throws IOException { + int b = inputStream.read(); + int result = b & VARINT_MASK; + if ((b & VARINT_CONTINUE_FLAG) != 0) { + b = inputStream.read(); + result |= (b & VARINT_MASK) << VARINT_SHIFT; + if ((b & VARINT_CONTINUE_FLAG) != 0) { + b = inputStream.read(); + result |= (b & VARINT_MASK) << (VARINT_SHIFT * 2); if ((b & VARINT_CONTINUE_FLAG) != 0) { + b = inputStream.read(); + result |= (b & VARINT_MASK) << (VARINT_SHIFT * 3); + if ((b & VARINT_CONTINUE_FLAG) != 0) { b = inputStream.read(); - result |= (b & VARINT_MASK) << VARINT_SHIFT; - if ((b & VARINT_CONTINUE_FLAG) != 0) { - b = inputStream.read(); - result |= (b & VARINT_MASK) << (VARINT_SHIFT * 2); - if ((b & VARINT_CONTINUE_FLAG) != 0) { - b = inputStream.read(); - result |= (b & VARINT_MASK) << (VARINT_SHIFT * 3); - if ((b & VARINT_CONTINUE_FLAG) != 0) { - b = inputStream.read(); - result |= (b & VARINT_MASK) << (VARINT_SHIFT * 4); - } - } - } + result |= (b & VARINT_MASK) << (VARINT_SHIFT * 4); + } } - return result; + } } - + return result; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/LongArrEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/LongArrEncoder.java index 18bc74bb3..1404e6023 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/LongArrEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/LongArrEncoder.java @@ -22,42 +22,42 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; + import org.apache.geaflow.common.encoder.Encoders; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class LongArrEncoder extends AbstractEncoder { - public static final LongArrEncoder INSTANCE = new LongArrEncoder(); + public static final LongArrEncoder INSTANCE = new LongArrEncoder(); - @Override - public void encode(long[] data, OutputStream outputStream) throws IOException { - if (data == null) { - Encoders.INTEGER.encode(NULL, outputStream); - return; - } - int lenToWrite = data.length + 1; - if (lenToWrite < 0) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(MSG_ARR_TOO_BIG)); - } - Encoders.INTEGER.encode(lenToWrite, outputStream); - for (long datum : data) { - Encoders.LONG.encode(datum, outputStream); - } + @Override + public void encode(long[] data, OutputStream outputStream) throws IOException { + if (data == null) { + Encoders.INTEGER.encode(NULL, outputStream); + return; } - - @Override - public long[] decode(InputStream inputStream) throws IOException { - int flag = Encoders.INTEGER.decode(inputStream); - if (flag == NULL) { - return null; - } - int length = flag - 1; - long[] arr = new long[length]; - for (int i = 0; i < length; i++) { - arr[i] = Encoders.LONG.decode(inputStream); - } - return arr; + int lenToWrite = data.length + 1; + if (lenToWrite < 0) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(MSG_ARR_TOO_BIG)); + } + Encoders.INTEGER.encode(lenToWrite, outputStream); + for (long datum : data) { + Encoders.LONG.encode(datum, outputStream); } + } + @Override + public long[] decode(InputStream inputStream) throws IOException { + int flag = Encoders.INTEGER.decode(inputStream); + if (flag == NULL) { + return null; + } + int length = flag - 1; + long[] arr = new long[length]; + for (int i = 0; i < length; i++) { + arr[i] = Encoders.LONG.decode(inputStream); + } + return arr; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/LongEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/LongEncoder.java index f47a950ba..4645ea2c5 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/LongEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/LongEncoder.java @@ -25,155 +25,154 @@ public class LongEncoder extends AbstractEncoder { - public static final LongEncoder INSTANCE = new LongEncoder(); - - @Override - public void encode(Long data, OutputStream outputStream) throws IOException { - // if between 0 ~ 127, just write the byte - if (data >= 0 && data < 128) { - outputStream.write(data.intValue()); - return; - } - - // write var long, takes 1 ~ 9 byte - long value = data; - int varInt = (int) (value & 0x7F); - value >>>= 7; - - varInt |= 0x80; - varInt |= ((value & 0x7F) << 8); - value >>>= 7; - if (value == 0) { - outputStream.write(varInt); - outputStream.write(varInt >> 8); - return; - } + public static final LongEncoder INSTANCE = new LongEncoder(); + + @Override + public void encode(Long data, OutputStream outputStream) throws IOException { + // if between 0 ~ 127, just write the byte + if (data >= 0 && data < 128) { + outputStream.write(data.intValue()); + return; + } - varInt |= (0x80 << 8); - varInt |= ((value & 0x7F) << 16); - value >>>= 7; - if (value == 0) { - outputStream.write(varInt); - outputStream.write(varInt >> 8); - outputStream.write(varInt >> 16); - return; - } + // write var long, takes 1 ~ 9 byte + long value = data; + int varInt = (int) (value & 0x7F); + value >>>= 7; + + varInt |= 0x80; + varInt |= ((value & 0x7F) << 8); + value >>>= 7; + if (value == 0) { + outputStream.write(varInt); + outputStream.write(varInt >> 8); + return; + } - varInt |= (0x80 << 16); - varInt |= ((value & 0x7F) << 24); - value >>>= 7; - if (value == 0) { - outputStream.write(varInt); - outputStream.write(varInt >> 8); - outputStream.write(varInt >> 16); - outputStream.write(varInt >> 24); - return; - } + varInt |= (0x80 << 8); + varInt |= ((value & 0x7F) << 16); + value >>>= 7; + if (value == 0) { + outputStream.write(varInt); + outputStream.write(varInt >> 8); + outputStream.write(varInt >> 16); + return; + } - varInt |= (0x80 << 24); - long varLong = (varInt & 0xFFFFFFFFL); - varLong |= (((value & 0x7F)) << 32); - value >>>= 7; - if (value == 0) { - outputStream.write((int) varLong); - outputStream.write((int) (varLong >> 8)); - outputStream.write((int) (varLong >> 16)); - outputStream.write((int) (varLong >> 24)); - outputStream.write((int) (varLong >> 32)); - return; - } + varInt |= (0x80 << 16); + varInt |= ((value & 0x7F) << 24); + value >>>= 7; + if (value == 0) { + outputStream.write(varInt); + outputStream.write(varInt >> 8); + outputStream.write(varInt >> 16); + outputStream.write(varInt >> 24); + return; + } - varLong |= (0x80L << 32); - varLong |= (((value & 0x7F)) << 40); - value >>>= 7; - if (value == 0) { - outputStream.write((int) varLong); - outputStream.write((int) (varLong >> 8)); - outputStream.write((int) (varLong >> 16)); - outputStream.write((int) (varLong >> 24)); - outputStream.write((int) (varLong >> 32)); - outputStream.write((int) (varLong >> 40)); - return; - } + varInt |= (0x80 << 24); + long varLong = (varInt & 0xFFFFFFFFL); + varLong |= (((value & 0x7F)) << 32); + value >>>= 7; + if (value == 0) { + outputStream.write((int) varLong); + outputStream.write((int) (varLong >> 8)); + outputStream.write((int) (varLong >> 16)); + outputStream.write((int) (varLong >> 24)); + outputStream.write((int) (varLong >> 32)); + return; + } - varLong |= (0x80L << 40); - varLong |= (((value & 0x7F)) << 48); - value >>>= 7; - if (value == 0) { - outputStream.write((int) varLong); - outputStream.write((int) (varLong >> 8)); - outputStream.write((int) (varLong >> 16)); - outputStream.write((int) (varLong >> 24)); - outputStream.write((int) (varLong >> 32)); - outputStream.write((int) (varLong >> 40)); - outputStream.write((int) (varLong >> 48)); - return; - } + varLong |= (0x80L << 32); + varLong |= (((value & 0x7F)) << 40); + value >>>= 7; + if (value == 0) { + outputStream.write((int) varLong); + outputStream.write((int) (varLong >> 8)); + outputStream.write((int) (varLong >> 16)); + outputStream.write((int) (varLong >> 24)); + outputStream.write((int) (varLong >> 32)); + outputStream.write((int) (varLong >> 40)); + return; + } - varLong |= (0x80L << 48); - varLong |= (((value & 0x7F)) << 56); - value >>>= 7; - if (value == 0) { - outputStream.write((int) varLong); - outputStream.write((int) (varLong >> 8)); - outputStream.write((int) (varLong >> 16)); - outputStream.write((int) (varLong >> 24)); - outputStream.write((int) (varLong >> 32)); - outputStream.write((int) (varLong >> 40)); - outputStream.write((int) (varLong >> 48)); - outputStream.write((int) (varLong >> 56)); - return; - } + varLong |= (0x80L << 40); + varLong |= (((value & 0x7F)) << 48); + value >>>= 7; + if (value == 0) { + outputStream.write((int) varLong); + outputStream.write((int) (varLong >> 8)); + outputStream.write((int) (varLong >> 16)); + outputStream.write((int) (varLong >> 24)); + outputStream.write((int) (varLong >> 32)); + outputStream.write((int) (varLong >> 40)); + outputStream.write((int) (varLong >> 48)); + return; + } - varLong |= (0x80L << 56); - outputStream.write((int) varLong); - outputStream.write((int) (varLong >> 8)); - outputStream.write((int) (varLong >> 16)); - outputStream.write((int) (varLong >> 24)); - outputStream.write((int) (varLong >> 32)); - outputStream.write((int) (varLong >> 40)); - outputStream.write((int) (varLong >> 48)); - outputStream.write((int) (varLong >> 56)); - outputStream.write((int) value); + varLong |= (0x80L << 48); + varLong |= (((value & 0x7F)) << 56); + value >>>= 7; + if (value == 0) { + outputStream.write((int) varLong); + outputStream.write((int) (varLong >> 8)); + outputStream.write((int) (varLong >> 16)); + outputStream.write((int) (varLong >> 24)); + outputStream.write((int) (varLong >> 32)); + outputStream.write((int) (varLong >> 40)); + outputStream.write((int) (varLong >> 48)); + outputStream.write((int) (varLong >> 56)); + return; } - @Override - public Long decode(InputStream inputStream) throws IOException { - int b = inputStream.read(); - long result = b & 0x7F; + varLong |= (0x80L << 56); + outputStream.write((int) varLong); + outputStream.write((int) (varLong >> 8)); + outputStream.write((int) (varLong >> 16)); + outputStream.write((int) (varLong >> 24)); + outputStream.write((int) (varLong >> 32)); + outputStream.write((int) (varLong >> 40)); + outputStream.write((int) (varLong >> 48)); + outputStream.write((int) (varLong >> 56)); + outputStream.write((int) value); + } + + @Override + public Long decode(InputStream inputStream) throws IOException { + int b = inputStream.read(); + long result = b & 0x7F; + if ((b & 0x80) != 0) { + b = inputStream.read(); + result |= (b & 0x7F) << 7; + if ((b & 0x80) != 0) { + b = inputStream.read(); + result |= (b & 0x7F) << 14; if ((b & 0x80) != 0) { + b = inputStream.read(); + result |= (b & 0x7F) << 21; + if ((b & 0x80) != 0) { b = inputStream.read(); - result |= (b & 0x7F) << 7; + result |= (long) (b & 0x7F) << 28; if ((b & 0x80) != 0) { + b = inputStream.read(); + result |= (long) (b & 0x7F) << 35; + if ((b & 0x80) != 0) { b = inputStream.read(); - result |= (b & 0x7F) << 14; + result |= (long) (b & 0x7F) << 42; if ((b & 0x80) != 0) { + b = inputStream.read(); + result |= (long) (b & 0x7F) << 49; + if ((b & 0x80) != 0) { b = inputStream.read(); - result |= (b & 0x7F) << 21; - if ((b & 0x80) != 0) { - b = inputStream.read(); - result |= (long) (b & 0x7F) << 28; - if ((b & 0x80) != 0) { - b = inputStream.read(); - result |= (long) (b & 0x7F) << 35; - if ((b & 0x80) != 0) { - b = inputStream.read(); - result |= (long) (b & 0x7F) << 42; - if ((b & 0x80) != 0) { - b = inputStream.read(); - result |= (long) (b & 0x7F) << 49; - if ((b & 0x80) != 0) { - b = inputStream.read(); - result |= (long) b << 56; - } - } - } - } - } + result |= (long) b << 56; + } } + } } + } } - return result; + } } - + return result; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/PojoEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/PojoEncoder.java index c54a9a182..9deac9d42 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/PojoEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/PojoEncoder.java @@ -32,6 +32,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.encoder.EncoderResolver; import org.apache.geaflow.common.encoder.Encoders; @@ -41,123 +42,124 @@ public class PojoEncoder extends AbstractEncoder { - private static final Map, PojoField[]> POJO_FIELDS_CACHE = new HashMap<>(); + private static final Map, PojoField[]> POJO_FIELDS_CACHE = new HashMap<>(); - private final Class clazz; - private PojoField[] pojoFields; + private final Class clazz; + private PojoField[] pojoFields; - public static PojoField[] getPojoFields(Class clazz) { + public static PojoField[] getPojoFields(Class clazz) { + if (!POJO_FIELDS_CACHE.containsKey(clazz)) { + synchronized (PojoEncoder.class) { if (!POJO_FIELDS_CACHE.containsKey(clazz)) { - synchronized (PojoEncoder.class) { - if (!POJO_FIELDS_CACHE.containsKey(clazz)) { - List fields = Arrays.stream(clazz.getDeclaredFields()) - .filter(field -> !Modifier.isTransient(field.getModifiers()) && !Modifier.isStatic(field.getModifiers())) - .sorted(Comparator.comparing(Field::getName)) - .peek(f -> f.setAccessible(true)) - .collect(Collectors.toList()); - - List pojoFields = new ArrayList<>(); - for (Field field : fields) { - Class fieldType = field.getType(); - IEncoder encoder = EncoderResolver.resolveClass(fieldType); - pojoFields.add(PojoField.build(field, encoder)); - } - - PojoField[] fieldsArr = pojoFields.stream() - .peek(f -> f.getField().setAccessible(true)) - .sorted(Comparator.comparing(f -> f.getField().getName())) - .toArray(PojoField[]::new); - POJO_FIELDS_CACHE.put(clazz, fieldsArr); - } - } + List fields = + Arrays.stream(clazz.getDeclaredFields()) + .filter( + field -> + !Modifier.isTransient(field.getModifiers()) + && !Modifier.isStatic(field.getModifiers())) + .sorted(Comparator.comparing(Field::getName)) + .peek(f -> f.setAccessible(true)) + .collect(Collectors.toList()); + + List pojoFields = new ArrayList<>(); + for (Field field : fields) { + Class fieldType = field.getType(); + IEncoder encoder = EncoderResolver.resolveClass(fieldType); + pojoFields.add(PojoField.build(field, encoder)); + } + + PojoField[] fieldsArr = + pojoFields.stream() + .peek(f -> f.getField().setAccessible(true)) + .sorted(Comparator.comparing(f -> f.getField().getName())) + .toArray(PojoField[]::new); + POJO_FIELDS_CACHE.put(clazz, fieldsArr); } - return POJO_FIELDS_CACHE.get(clazz); + } } + return POJO_FIELDS_CACHE.get(clazz); + } - public static PojoEncoder build(Class clazz) { - return new PojoEncoder<>(clazz); - } + public static PojoEncoder build(Class clazz) { + return new PojoEncoder<>(clazz); + } - public PojoEncoder(Class clazz) { - this.clazz = clazz; - } + public PojoEncoder(Class clazz) { + this.clazz = clazz; + } - @Override - public void init(Configuration config) { - if (this.pojoFields == null) { - this.pojoFields = getPojoFields(this.clazz); - } + @Override + public void init(Configuration config) { + if (this.pojoFields == null) { + this.pojoFields = getPojoFields(this.clazz); } - - @Override - public void encode(T data, OutputStream outputStream) throws IOException { - if (data == null) { - Encoders.INTEGER.encode(NULL, outputStream); - return; + } + + @Override + public void encode(T data, OutputStream outputStream) throws IOException { + if (data == null) { + Encoders.INTEGER.encode(NULL, outputStream); + return; + } else { + Encoders.INTEGER.encode(NOT_NULL, outputStream); + } + try { + for (int i = 0; i < this.pojoFields.length; i++) { + PojoField pojoField = this.pojoFields[i]; + Object value = pojoField.getField().get(data); + if (value == null) { + Encoders.INTEGER.encode(NULL, outputStream); } else { - Encoders.INTEGER.encode(NOT_NULL, outputStream); - } - try { - for (int i = 0; i < this.pojoFields.length; i++) { - PojoField pojoField = this.pojoFields[i]; - Object value = pojoField.getField().get(data); - if (value == null) { - Encoders.INTEGER.encode(NULL, outputStream); - } else { - Encoders.INTEGER.encode(NOT_NULL, outputStream); - ((IEncoder) pojoField.getEncoder()).encode(value, outputStream); - } - } - } catch (IllegalAccessException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(e.getMessage()), e); + Encoders.INTEGER.encode(NOT_NULL, outputStream); + ((IEncoder) pojoField.getEncoder()).encode(value, outputStream); } - + } + } catch (IllegalAccessException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(e.getMessage()), e); } + } - @Override - public T decode(InputStream inputStream) throws IOException { - if (Encoders.INTEGER.decode(inputStream) == NULL) { - return null; - } - try { - T obj = clazz.newInstance(); - for (int i = 0; i < this.pojoFields.length; i++) { - int flag = Encoders.INTEGER.decode(inputStream); - if (flag == NOT_NULL) { - PojoField pojoField = this.pojoFields[i]; - Object value = pojoField.getEncoder().decode(inputStream); - pojoField.getField().set(obj, value); - } - } - return obj; - } catch (InstantiationException | IllegalAccessException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(e.getMessage()), e); + @Override + public T decode(InputStream inputStream) throws IOException { + if (Encoders.INTEGER.decode(inputStream) == NULL) { + return null; + } + try { + T obj = clazz.newInstance(); + for (int i = 0; i < this.pojoFields.length; i++) { + int flag = Encoders.INTEGER.decode(inputStream); + if (flag == NOT_NULL) { + PojoField pojoField = this.pojoFields[i]; + Object value = pojoField.getEncoder().decode(inputStream); + pojoField.getField().set(obj, value); } - + } + return obj; + } catch (InstantiationException | IllegalAccessException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(e.getMessage()), e); } + } - public static class PojoField implements Serializable { - - private final Field field; - private final IEncoder encoder; + public static class PojoField implements Serializable { - public static PojoField build(Field field, IEncoder encoder) { - return new PojoField(field, encoder); - } - - public PojoField(Field field, IEncoder encoder) { - this.field = field; - this.encoder = encoder; - } + private final Field field; + private final IEncoder encoder; - public Field getField() { - return this.field; - } + public static PojoField build(Field field, IEncoder encoder) { + return new PojoField(field, encoder); + } - public IEncoder getEncoder() { - return this.encoder; - } + public PojoField(Field field, IEncoder encoder) { + this.field = field; + this.encoder = encoder; + } + public Field getField() { + return this.field; } + public IEncoder getEncoder() { + return this.encoder; + } + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/ShortArrEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/ShortArrEncoder.java index 67343b7e1..fbb1007af 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/ShortArrEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/ShortArrEncoder.java @@ -22,42 +22,42 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; + import org.apache.geaflow.common.encoder.Encoders; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class ShortArrEncoder extends AbstractEncoder { - public static final ShortArrEncoder INSTANCE = new ShortArrEncoder(); + public static final ShortArrEncoder INSTANCE = new ShortArrEncoder(); - @Override - public void encode(short[] data, OutputStream outputStream) throws IOException { - if (data == null) { - Encoders.INTEGER.encode(NULL, outputStream); - return; - } - int lenToWrite = data.length + 1; - if (lenToWrite < 0) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(MSG_ARR_TOO_BIG)); - } - Encoders.INTEGER.encode(lenToWrite, outputStream); - for (short datum : data) { - Encoders.SHORT.encode(datum, outputStream); - } + @Override + public void encode(short[] data, OutputStream outputStream) throws IOException { + if (data == null) { + Encoders.INTEGER.encode(NULL, outputStream); + return; } - - @Override - public short[] decode(InputStream inputStream) throws IOException { - int flag = Encoders.INTEGER.decode(inputStream); - if (flag == NULL) { - return null; - } - int length = flag - 1; - short[] arr = new short[length]; - for (int i = 0; i < length; i++) { - arr[i] = Encoders.SHORT.decode(inputStream); - } - return arr; + int lenToWrite = data.length + 1; + if (lenToWrite < 0) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(MSG_ARR_TOO_BIG)); + } + Encoders.INTEGER.encode(lenToWrite, outputStream); + for (short datum : data) { + Encoders.SHORT.encode(datum, outputStream); } + } + @Override + public short[] decode(InputStream inputStream) throws IOException { + int flag = Encoders.INTEGER.decode(inputStream); + if (flag == NULL) { + return null; + } + int length = flag - 1; + short[] arr = new short[length]; + for (int i = 0; i < length; i++) { + arr[i] = Encoders.SHORT.decode(inputStream); + } + return arr; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/ShortEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/ShortEncoder.java index 61a7cc020..c43f4c511 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/ShortEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/ShortEncoder.java @@ -25,19 +25,18 @@ public class ShortEncoder extends AbstractEncoder { - public static ShortEncoder INSTANCE = new ShortEncoder(); + public static ShortEncoder INSTANCE = new ShortEncoder(); - @Override - public void encode(Short data, OutputStream outputStream) throws IOException { - outputStream.write(data); - outputStream.write(data >> 8); - } - - @Override - public Short decode(InputStream inputStream) throws IOException { - int b1 = inputStream.read(); - int b2 = inputStream.read(); - return (short) (b1 | (b2 << 8)); - } + @Override + public void encode(Short data, OutputStream outputStream) throws IOException { + outputStream.write(data); + outputStream.write(data >> 8); + } + @Override + public Short decode(InputStream inputStream) throws IOException { + int b1 = inputStream.read(); + int b2 = inputStream.read(); + return (short) (b1 | (b2 << 8)); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/StringEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/StringEncoder.java index cbe84d227..b2a2809e9 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/StringEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/StringEncoder.java @@ -22,44 +22,44 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; + import org.apache.geaflow.common.encoder.Encoders; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class StringEncoder extends AbstractEncoder { - public static final StringEncoder INSTANCE = new StringEncoder(); - - private static final int NULL = 0; + public static final StringEncoder INSTANCE = new StringEncoder(); - @Override - public void encode(String data, OutputStream outputStream) throws IOException { - if (data == null) { - Encoders.INTEGER.encode(NULL, outputStream); - return; - } + private static final int NULL = 0; - // the length we write is offset by one, because a length of zero indicates a null value - int lenToWrite = data.length() + 1; - if (lenToWrite < 0) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError("string is too long")); - } + @Override + public void encode(String data, OutputStream outputStream) throws IOException { + if (data == null) { + Encoders.INTEGER.encode(NULL, outputStream); + return; + } - byte[] bytes = data.getBytes(); - IntegerEncoder.INSTANCE.encode(bytes.length + 1, outputStream); - outputStream.write(bytes); + // the length we write is offset by one, because a length of zero indicates a null value + int lenToWrite = data.length() + 1; + if (lenToWrite < 0) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError("string is too long")); } - @Override - public String decode(InputStream inputStream) throws IOException { - int length = Encoders.INTEGER.decode(inputStream); - if (length == NULL) { - return null; - } + byte[] bytes = data.getBytes(); + IntegerEncoder.INSTANCE.encode(bytes.length + 1, outputStream); + outputStream.write(bytes); + } - byte[] bytes = new byte[length - 1]; - inputStream.read(bytes); - return new String(bytes); + @Override + public String decode(InputStream inputStream) throws IOException { + int length = Encoders.INTEGER.decode(inputStream); + if (length == NULL) { + return null; } + byte[] bytes = new byte[length - 1]; + inputStream.read(bytes); + return new String(bytes); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/TripleEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/TripleEncoder.java index b08113d1d..7454f59b2 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/TripleEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/TripleEncoder.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.encoder.Encoders; import org.apache.geaflow.common.encoder.IEncoder; @@ -29,46 +30,44 @@ public class TripleEncoder extends AbstractEncoder> { - private final IEncoder encoder0; - private final IEncoder encoder1; - private final IEncoder encoder2; + private final IEncoder encoder0; + private final IEncoder encoder1; + private final IEncoder encoder2; - public TripleEncoder(IEncoder encoder0, IEncoder encoder1, IEncoder encoder2) { - this.encoder0 = encoder0; - this.encoder1 = encoder1; - this.encoder2 = encoder2; - } + public TripleEncoder(IEncoder encoder0, IEncoder encoder1, IEncoder encoder2) { + this.encoder0 = encoder0; + this.encoder1 = encoder1; + this.encoder2 = encoder2; + } - @Override - public void init(Configuration config) { - this.encoder0.init(config); - this.encoder1.init(config); - this.encoder2.init(config); - } + @Override + public void init(Configuration config) { + this.encoder0.init(config); + this.encoder1.init(config); + this.encoder2.init(config); + } - @Override - public void encode(Triple data, - OutputStream outputStream) throws IOException { - int flag = data == null ? NULL : NOT_NULL; - Encoders.INTEGER.encode(flag, outputStream); - if (flag == NULL) { - return; - } - this.encoder0.encode(data.f0, outputStream); - this.encoder1.encode(data.f1, outputStream); - this.encoder2.encode(data.f2, outputStream); + @Override + public void encode(Triple data, OutputStream outputStream) throws IOException { + int flag = data == null ? NULL : NOT_NULL; + Encoders.INTEGER.encode(flag, outputStream); + if (flag == NULL) { + return; } + this.encoder0.encode(data.f0, outputStream); + this.encoder1.encode(data.f1, outputStream); + this.encoder2.encode(data.f2, outputStream); + } - @Override - public Triple decode(InputStream inputStream) throws IOException { - int flag = Encoders.INTEGER.decode(inputStream); - if (flag == NULL) { - return null; - } - F0 f0 = this.encoder0.decode(inputStream); - F1 f1 = this.encoder1.decode(inputStream); - F2 f2 = this.encoder2.decode(inputStream); - return Triple.of(f0, f1, f2); + @Override + public Triple decode(InputStream inputStream) throws IOException { + int flag = Encoders.INTEGER.decode(inputStream); + if (flag == NULL) { + return null; } - + F0 f0 = this.encoder0.decode(inputStream); + F1 f1 = this.encoder1.decode(inputStream); + F2 f2 = this.encoder2.decode(inputStream); + return Triple.of(f0, f1, f2); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/TupleEncoder.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/TupleEncoder.java index a3835b0fc..127eaddcd 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/TupleEncoder.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/encoder/impl/TupleEncoder.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.encoder.Encoders; import org.apache.geaflow.common.encoder.IEncoder; @@ -29,41 +30,39 @@ public class TupleEncoder extends AbstractEncoder> { - private final IEncoder encoder0; - private final IEncoder encoder1; + private final IEncoder encoder0; + private final IEncoder encoder1; - public TupleEncoder(IEncoder encoder0, IEncoder encoder1) { - this.encoder0 = encoder0; - this.encoder1 = encoder1; - } + public TupleEncoder(IEncoder encoder0, IEncoder encoder1) { + this.encoder0 = encoder0; + this.encoder1 = encoder1; + } - @Override - public void init(Configuration config) { - this.encoder0.init(config); - this.encoder1.init(config); - } + @Override + public void init(Configuration config) { + this.encoder0.init(config); + this.encoder1.init(config); + } - @Override - public void encode(Tuple data, - OutputStream outputStream) throws IOException { - int flag = data == null ? NULL : NOT_NULL; - Encoders.INTEGER.encode(flag, outputStream); - if (flag == NULL) { - return; - } - this.encoder0.encode(data.f0, outputStream); - this.encoder1.encode(data.f1, outputStream); + @Override + public void encode(Tuple data, OutputStream outputStream) throws IOException { + int flag = data == null ? NULL : NOT_NULL; + Encoders.INTEGER.encode(flag, outputStream); + if (flag == NULL) { + return; } + this.encoder0.encode(data.f0, outputStream); + this.encoder1.encode(data.f1, outputStream); + } - @Override - public Tuple decode(InputStream inputStream) throws IOException { - int flag = Encoders.INTEGER.decode(inputStream); - if (flag == NULL) { - return null; - } - F0 f0 = this.encoder0.decode(inputStream); - F1 f1 = this.encoder1.decode(inputStream); - return Tuple.of(f0, f1); + @Override + public Tuple decode(InputStream inputStream) throws IOException { + int flag = Encoders.INTEGER.decode(inputStream); + if (flag == NULL) { + return null; } - + F0 f0 = this.encoder0.decode(inputStream); + F1 f1 = this.encoder1.decode(inputStream); + return Tuple.of(f0, f1); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/errorcode/ErrorFactory.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/errorcode/ErrorFactory.java index 7722ff7be..78b7a7d6f 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/errorcode/ErrorFactory.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/errorcode/ErrorFactory.java @@ -38,333 +38,348 @@ import java.util.List; import java.util.Set; import java.util.regex.Pattern; + import org.apache.commons.lang3.StringUtils; public class ErrorFactory { - protected static Pattern pattern1 = Pattern.compile("[0-9]{8}"); - protected static Set modNames = new HashSet<>(); - - public static String ERR_ID = "ERR_ID"; - public static String ERROR_MSG_PREFIX = "\n************\n"; - - static { - modNames.add("SQL"); - modNames.add("DSL"); - modNames.add("IO"); - modNames.add("STB"); - modNames.add("STS"); - modNames.add("RUN"); + protected static Pattern pattern1 = Pattern.compile("[0-9]{8}"); + protected static Set modNames = new HashSet<>(); + + public static String ERR_ID = "ERR_ID"; + public static String ERROR_MSG_PREFIX = "\n************\n"; + + static { + modNames.add("SQL"); + modNames.add("DSL"); + modNames.add("IO"); + modNames.add("STB"); + modNames.add("STS"); + modNames.add("RUN"); + } + + private ErrorFactory() {} + + /** + * Create proxy instance for one module's error code interface. + * + * @param clazz interface that has error code definitions for one module. + * @return instance of the interface that can be used by developer for specifying error code. + */ + public static Object createProxy(Class clazz) { + + return Proxy.newProxyInstance( + clazz.getClassLoader(), + clazz.isInterface() ? new Class[] {clazz} : clazz.getInterfaces(), + (obj, method, args) -> { + checkParam(method, args); + return assemblyErrCodeString(method, args); + }); + } + + /** + * Parameter check when invoking method. + * + * @param method + * @param args + */ + protected static void checkParam(Method method, Object[] args) { + ErrCode errCode = method.getAnnotation(ErrCode.class); + String errDetail = errCode.details(); + String errCause = errCode.cause(); + + MessageFormat format1 = new MessageFormat(errDetail); + MessageFormat format2 = new MessageFormat(errCause); + + if (args == null || args.length == 0) { + if ((format1.getFormatsByArgumentIndex() != null + && format1.getFormatsByArgumentIndex().length > 0) + || (format2.getFormatsByArgumentIndex() != null + && format2.getFormatsByArgumentIndex().length > 0)) { + throw new AssertionError( + "mismatched parameter length between " + + method.getName() + + " and its annotation @ErrCode"); + } + } else { + if ((format1.getFormatsByArgumentIndex() != null + && format1.getFormatsByArgumentIndex().length > args.length) + || format1.getFormatsByArgumentIndex() == null + || (format2.getFormatsByArgumentIndex() != null + && format2.getFormatsByArgumentIndex().length > args.length)) { + throw new AssertionError( + "mismatched parameter length between " + + method.getName() + + " and its annotation @ErrCode"); + } } - - private ErrorFactory() { + } + + /** + * Assembly error code messages. + * + * @param method error code related function declared in error interface + * @param args args passed to that related function + * @return error code messages containing code id, cause and action. + */ + protected static String assemblyErrCodeString(Method method, Object[] args) { + ErrCode errCode = method.getAnnotation(ErrCode.class); + String errId = errCode.codeId(); + String errCause = errCode.cause(); + String errDetail = errCode.details(); + + if (args != null && args.length != 0) { + MessageFormat format1 = new MessageFormat(errDetail); + errDetail = format1.format(args); + + MessageFormat format2 = new MessageFormat(errCause); + errCause = format2.format(args); } - /** - * Create proxy instance for one module's error code interface. - * - * @param clazz interface that has error code definitions for one module. - * @return instance of the interface that can be used by developer for specifying error code. - */ - public static Object createProxy(Class clazz) { + errId = prettyPrint(errId); + errCause = prettyPrint(errCause); + errDetail = prettyPrint(errDetail); + + String errAction = errCode.action(); + errAction = prettyPrint(errAction); + + String msg = + ERROR_MSG_PREFIX + + "ERR_ID:\n" + + errId + + "\n" + + "CAUSE:\n" + + errCause + + "\n" + + "ACTION:\n" + + errAction + + "\n" + + "DETAIL:\n" + + errDetail + + "\n" + + "************"; + return msg; + } + + /** + * Print out error code in a pretty way. + * + * @param str + * @return + */ + public static String prettyPrint(String str) { + if (str != null && str.length() != 0) { + str = indent(5) + str.replaceAll("\n", "\n" + indent(5)); + } - return Proxy.newProxyInstance(clazz.getClassLoader(), - clazz.isInterface() ? new Class[]{clazz} : clazz.getInterfaces(), - (obj, method, args) -> { - checkParam(method, args); - return assemblyErrCodeString(method, args); - }); + return str; + } + + /** + * Validate an error code definition interface to check its annotation usage and err code format. + * + * @param clazz err code definition interface + */ + public static void validate(Class clazz) { + validate(clazz, EnumSet.allOf(ValidationType.class)); + } + + /** + * Validate an error code definition interface to check its annotation usage and err code format. + * + * @param clazz err code definition interface. + * @param validations types of validations to perform. + */ + public static void validate(Class clazz, EnumSet validations) { + int cnt = 0; + Set errIds = new HashSet<>(); + + for (Method method : clazz.getMethods()) { + if (!Modifier.isStatic(method.getModifiers())) { + cnt++; + + final ErrCode anno1 = method.getAnnotation(ErrCode.class); + + for (ValidationType validation : validations) { + switch (validation) { + case ANNOTATION_SPECIFIED: + if (anno1 == null + || StringUtils.isEmpty(anno1.codeId()) + || StringUtils.isEmpty(anno1.cause())) { + throw new AssertionError( + String.format( + "error code method[%s] " + + "must specify @ErrCode annotation with cause, details " + + "and none-empty codeId!", + method.getName())); + } + break; + + case ERROR_ID_CHECK: + if (anno1 == null) { + throw new AssertionError( + String.format( + "error code method[%s]" + "has no @ErrId annotation!", method.getName())); + } + String errId = anno1.codeId(); + if (!checkErrorCodeFmt(errId)) { + throw new AssertionError( + String.format( + "error code method[%s]" + "has invalid error code: %s", + method.getName(), errId)); + } + if (errIds.contains(errId)) { + throw new AssertionError( + String.format( + "error code method[%s]" + "has duplicated err id: %s", + method.getName(), errId)); + } + errIds.add(errId); + break; + + case ARGUMENT_MATCH: + if (anno1 == null) { + throw new AssertionError( + String.format( + "error code method[%s]" + "has no @ErrCode annotation!", method.getName())); + } + + String msg = anno1.details(); + String cause = anno1.cause(); + + MessageFormat msgFmt = new MessageFormat(msg); + MessageFormat causeFmt = new MessageFormat(cause); + + final Format[] msgFormats = msgFmt.getFormatsByArgumentIndex(); + final Format[] causeFormats = causeFmt.getFormatsByArgumentIndex(); + + final Format[] formats = msgFormats.length != 0 ? msgFormats : causeFormats; + + final List types = new ArrayList<>(); + final Class[] paramTypes = method.getParameterTypes(); + + if (!(msgFormats.length != 0 && causeFormats.length != 0)) { + + for (int i = 0; i < formats.length; i++) { + Format fmt1 = formats[i]; + Class paramType = paramTypes[i]; + final Class e; + if (fmt1 instanceof NumberFormat) { + e = + paramType == short.class + || paramType == int.class + || paramType == long.class + || paramType == float.class + || paramType == double.class + || Number.class.isAssignableFrom(paramType) + ? paramType + : Number.class; + } else if (fmt1 instanceof DateFormat) { + e = Date.class; + } else { + e = String.class; + } + + types.add(e); + } - } + final List> paramTypeList = Arrays.asList(paramTypes); + if (!types.equals(paramTypeList)) { + throw new AssertionError( + String.format( + "error code[%s]" + + " has type mismatch(s) between method param %s and" + + " format elements %s in annotation", + method.getName(), types, paramTypeList)); + } + } + break; - /** - * Parameter check when invoking method. - * - * @param method - * @param args - */ - protected static void checkParam(Method method, Object[] args) { - ErrCode errCode = method.getAnnotation(ErrCode.class); - String errDetail = errCode.details(); - String errCause = errCode.cause(); - - MessageFormat format1 = new MessageFormat(errDetail); - MessageFormat format2 = new MessageFormat(errCause); - - if (args == null || args.length == 0) { - if ((format1.getFormatsByArgumentIndex() != null - && format1.getFormatsByArgumentIndex().length > 0) || ( - format2.getFormatsByArgumentIndex() != null - && format2.getFormatsByArgumentIndex().length > 0)) { - throw new AssertionError("mismatched parameter length between " + method.getName() - + " and its annotation @ErrCode"); - } - } else { - if ((format1.getFormatsByArgumentIndex() != null - && format1.getFormatsByArgumentIndex().length > args.length) - || format1.getFormatsByArgumentIndex() == null || ( - format2.getFormatsByArgumentIndex() != null - && format2.getFormatsByArgumentIndex().length > args.length)) { - throw new AssertionError("mismatched parameter length between " + method.getName() - + " and its annotation @ErrCode"); - } + default: + break; + } } + } } - /** - * Assembly error code messages. - * - * @param method error code related function declared in error interface - * @param args args passed to that related function - * @return error code messages containing code id, cause and action. - */ - protected static String assemblyErrCodeString(Method method, Object[] args) { - ErrCode errCode = method.getAnnotation(ErrCode.class); - String errId = errCode.codeId(); - String errCause = errCode.cause(); - String errDetail = errCode.details(); - - if (args != null && args.length != 0) { - MessageFormat format1 = new MessageFormat(errDetail); - errDetail = format1.format(args); - - MessageFormat format2 = new MessageFormat(errCause); - errCause = format2.format(args); - } - - errId = prettyPrint(errId); - errCause = prettyPrint(errCause); - errDetail = prettyPrint(errDetail); - - String errAction = errCode.action(); - errAction = prettyPrint(errAction); - - String msg = ERROR_MSG_PREFIX - + "ERR_ID:\n" + errId + "\n" + "CAUSE:\n" + errCause + "\n" + "ACTION:\n" + errAction - + "\n" + "DETAIL:\n" + errDetail + "\n" - + "************"; - return msg; + if (cnt == 0 && validations.contains(ValidationType.AT_LEAST_ONE)) { + throw new AssertionError(clazz + " contains no error code"); } - - /** - * Print out error code in a pretty way. - * - * @param str - * @return - */ - public static String prettyPrint(String str) { - if (str != null && str.length() != 0) { - str = indent(5) + str.replaceAll("\n", "\n" + indent(5)); - } - - return str; + } + + /** + * Check error code id format. + * + * @param errId + * @return + */ + protected static boolean checkErrorCodeFmt(String errId) { + if (errId == null || errId.isEmpty()) { + return false; } - - /** - * Validate an error code definition interface to check its annotation usage - * and err code format. - * - * @param clazz err code definition interface - */ - public static void validate(Class clazz) { - validate(clazz, EnumSet.allOf(ValidationType.class)); + String[] parts = errId.split("-"); + if (parts.length != 2) { + return false; + } + if (!modNames.contains(parts[0])) { + return false; + } + if (!pattern1.matcher(parts[1]).matches()) { + return false; } - /** - * Validate an error code definition interface to check its annotation usage - * and err code format. - * - * @param clazz err code definition interface. - * @param validations types of validations to perform. - */ - public static void validate(Class clazz, EnumSet validations) { - int cnt = 0; - Set errIds = new HashSet<>(); - - for (Method method : clazz.getMethods()) { - if (!Modifier.isStatic(method.getModifiers())) { - cnt++; - - final ErrCode anno1 = method.getAnnotation(ErrCode.class); - - for (ValidationType validation : validations) { - switch (validation) { - case ANNOTATION_SPECIFIED: - if (anno1 == null || StringUtils.isEmpty(anno1.codeId()) || StringUtils - .isEmpty(anno1.cause())) { - throw new AssertionError(String.format("error code method[%s] " - + "must specify @ErrCode annotation with cause, details " + "and none-empty codeId!", - method.getName())); - } - break; - - case ERROR_ID_CHECK: - if (anno1 == null) { - throw new AssertionError(String - .format("error code method[%s]" + "has no @ErrId annotation!", - method.getName())); - } - String errId = anno1.codeId(); - if (!checkErrorCodeFmt(errId)) { - throw new AssertionError(String - .format("error code method[%s]" + "has invalid error code: %s", - method.getName(), errId)); - } - if (errIds.contains(errId)) { - throw new AssertionError(String - .format("error code method[%s]" + "has duplicated err id: %s", - method.getName(), errId)); - } - errIds.add(errId); - break; - - case ARGUMENT_MATCH: - if (anno1 == null) { - throw new AssertionError(String - .format("error code method[%s]" + "has no @ErrCode annotation!", - method.getName())); - } - - String msg = anno1.details(); - String cause = anno1.cause(); - - MessageFormat msgFmt = new MessageFormat(msg); - MessageFormat causeFmt = new MessageFormat(cause); - - final Format[] msgFormats = msgFmt.getFormatsByArgumentIndex(); - final Format[] causeFormats = causeFmt.getFormatsByArgumentIndex(); - - final Format[] formats = - msgFormats.length != 0 ? msgFormats : causeFormats; - - final List types = new ArrayList<>(); - final Class[] paramTypes = method.getParameterTypes(); - - if (!(msgFormats.length != 0 && causeFormats.length != 0)) { - - for (int i = 0; i < formats.length; i++) { - Format fmt1 = formats[i]; - Class paramType = paramTypes[i]; - final Class e; - if (fmt1 instanceof NumberFormat) { - e = paramType == short.class || paramType == int.class - || paramType == long.class - || paramType == float.class - || paramType == double.class || Number.class - .isAssignableFrom(paramType) ? paramType : Number.class; - } else if (fmt1 instanceof DateFormat) { - e = Date.class; - } else { - e = String.class; - } - - types.add(e); - } - - final List> paramTypeList = Arrays.asList(paramTypes); - if (!types.equals(paramTypeList)) { - throw new AssertionError(String.format("error code[%s]" - + " has type mismatch(s) between method param %s and" - + " format elements %s in annotation", method.getName(), - types, paramTypeList)); - } - } - break; - - default: - break; - } - } - } - } - - if (cnt == 0 && validations.contains(ValidationType.AT_LEAST_ONE)) { - throw new AssertionError(clazz + " contains no error code"); - } + return true; + } + + /** + * Get multiple indent. + * + * @param cnt + * @return + */ + public static String indent(int cnt) { + if (cnt <= 0) { + return null; } + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < cnt; i++) { + sb.append(" "); + } + return sb.toString(); + } - /** - * Check error code id format. - * - * @param errId - * @return - */ - protected static boolean checkErrorCodeFmt(String errId) { - if (errId == null || errId.isEmpty()) { - return false; - } - String[] parts = errId.split("-"); - if (parts.length != 2) { - return false; - } - if (!modNames.contains(parts[0])) { - return false; - } - if (!pattern1.matcher(parts[1]).matches()) { - return false; - } + /** Types of validation that can be performed on a resource. */ + public enum ValidationType { + /** Checks that the ErrId, ErrCause and ErrAction annotations are on every resource. */ + ANNOTATION_SPECIFIED, - return true; - } + /** Checks that there is at least one resource. */ + AT_LEAST_ONE, /** - * Get multiple indent. - * - * @param cnt - * @return + * Checks that @ErrId anno has non-null value which has expected error id format, and there's no + * duplicated error id in resource instance. */ - public static String indent(int cnt) { - if (cnt <= 0) { - return null; - } - StringBuilder sb = new StringBuilder(); - for (int i = 0; i < cnt; i++) { - sb.append(" "); - } - return sb.toString(); - } + ERROR_ID_CHECK, /** - * Types of validation that can be performed on a resource. + * Checks that the parameters of the method are consistent with the format elements in the error + * cause message. */ - public enum ValidationType { - /** - * Checks that the ErrId, ErrCause and ErrAction annotations are on every resource. - */ - ANNOTATION_SPECIFIED, - - /** - * Checks that there is at least one resource. - */ - AT_LEAST_ONE, - - /** - * Checks that @ErrId anno has non-null value which has expected error id format, - * and there's no duplicated error id in resource instance. - */ - ERROR_ID_CHECK, - - /** - * Checks that the parameters of the method are consistent with the - * format elements in the error cause message. - */ - ARGUMENT_MATCH, - } + ARGUMENT_MATCH, + } - /** - * err code id, cause, detailed message and action. - **/ - @Retention(RetentionPolicy.RUNTIME) - @Target(ElementType.METHOD) - public @interface ErrCode { - - String codeId(); + /** err code id, cause, detailed message and action. */ + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.METHOD) + public @interface ErrCode { - String cause(); + String codeId(); - String details(); + String cause(); - String action(); - } + String details(); + String action(); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/errorcode/RuntimeErrorCode.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/errorcode/RuntimeErrorCode.java index d52302d77..515a0985a 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/errorcode/RuntimeErrorCode.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/errorcode/RuntimeErrorCode.java @@ -23,196 +23,215 @@ public interface RuntimeErrorCode { - // Common module. - // RUN-00xxxxxx. - @ErrorFactory.ErrCode(codeId = "RUN-00000001", - cause = "Undefined error: {0}", - details = "", - action = "Please check your code or dsl, and contact admin.") - String undefinedError(String info); - - @ErrorFactory.ErrCode(codeId = "RUN-00000002", - cause = "Run error: {0}", - details = "", - action = "Please check your code, or contact admin.") - String runError(String info); - - @ErrorFactory.ErrCode(codeId = "RUN-00000003", - cause = "SystemInternalError - not find ''{0}'' SPI implement", - details = "", - action = "Please contact admin.") - String spiNotFoundError(String className); - - @ErrorFactory.ErrCode(codeId = "RUN-00000004", - cause = "Key conflicts error: userKey={0}, systemKeys={1}", - details = "", - action = "Please check your config or contact admin.") - String keyConflictsError(Set userArgsKeys, Set systemArgsKeys); - - @ErrorFactory.ErrCode(codeId = "RUN-00000005", - cause = "Operator {0} type not support error.", - details = "", - action = "Please check your code, or contact admin.") - String operatorTypeNotSupportError(String opType); - - @ErrorFactory.ErrCode(codeId = "RUN-00000006", - cause = "unsupported operation", - details = "", - action = "Please check your config, or contact admin.") - String unsupportedError(); - - @ErrorFactory.ErrCode(codeId = "RUN-00000007", - cause = "SystemInternalError - not support ''{0}'' event type", - details = "", - action = "Please contact admin.") - String requestTypeNotSupportError(String requestType); - - @ErrorFactory.ErrCode(codeId = "RUN-00000008", - cause = "Config key not found: {0}", - details = "", - action = "Please check your config or contact admin.") - String configKeyNotFound(String key); - - // Plan module. - // RUN-01xxxxxx. - @ErrorFactory.ErrCode(codeId = "RUN-01000001", - cause = "SystemInternalError - the previous vertex of vertex '{0}' is null", - details = "", - action = "Please check your code or dsl, and contact admin.") - String previousVertexIsNullError(String info); - - @ErrorFactory.ErrCode(codeId = "RUN-01000002", - cause = "Logical error - action is empty, missing necessary sink or get action", - details = "", - action = "Please contact admin for help.") - String actionIsEmptyError(); - - @ErrorFactory.ErrCode(codeId = "RUN-01000003", - cause = "SystemInternalError - not support stream type ''{0}''", - details = "", - action = "Please contact admin for help.") - String streamTypeNotSupportError(String streamType); - - - // Execution module. - // RUN-02xxxxxx. - @ErrorFactory.ErrCode(codeId = "RUN-02000001", - cause = "shuffle serialize error: {0}", - details = "", - action = "Check the cause, and contact admin for help if necessary.") - String shuffleSerializeError(String info); - - @ErrorFactory.ErrCode(codeId = "RUN-02000002", - cause = "shuffle deserialize error: {0}", - details = "", - action = "Check the cause, and contact admin for help if necessary.") - String shuffleDeserializeError(String info); - - - // state module. - // RUN-03xxxxxx. - - /** - * State common. - * - * @param info info - * @return msg - */ - @ErrorFactory.ErrCode(codeId = "RUN-03000001", - cause = "StateCommonError, error message: {0}", - details = "", - action = "Please contact admin for help.") - String stateCommonError(String info); - - /** - * State RocksDB. - * - * @param info info - * @return msg - */ - @ErrorFactory.ErrCode(codeId = "RUN-03000002", - cause = "StateRocksDbError, error message: {0}", - details = "", - action = "Please contact admin for help.") - String stateRocksDbError(String info); - - - // framework module. - // RUN-04xxxxxx. - - /** - * Resource. - * - * @param info info - * @return msg - */ - @ErrorFactory.ErrCode(codeId = "RUN-04000001", - cause = "ResourceError, error message: {0}", - details = "", - action = "Please contact admin for help.") - String resourceError(String info); - - /** - * Type system. - * - * @param info info - * @return msg - */ - @ErrorFactory.ErrCode(codeId = "RUN-04000002", - cause = "Type system error, error message: {0}", - details = "", - action = "Please contact admin for help.") - String typeSysError(String info); - - @ErrorFactory.ErrCode(codeId = "RUN-04000003", - cause = "IllegalRequireNumError, error message: {0}", - details = "", - action = "Please contact admin for help.") - String resourceIllegalRequireNumError(String info); - - @ErrorFactory.ErrCode(codeId = "RUN-04000004", - cause = "ResourceRecoveringError, error message: {0}", - details = "", - action = "Please contact admin for help.") - String resourceRecoveringError(String info); - - @ErrorFactory.ErrCode(codeId = "RUN-04000005", - cause = "ResourceNotReadyError, error message: {0}", - details = "", - action = "Please contact admin for help.") - String resourceNotReadyError(String info); - - @ErrorFactory.ErrCode(codeId = "RUN-04000006", - cause = "analyticsClientError, error message: {0}", - details = "", - action = "Please contact admin for help.") - String analyticsClientError(String info); - - // DSL module. - // RUN-05xxxxxx. - - /** - * DSL runtime. - * - * @param info info - * @return msg - */ - @ErrorFactory.ErrCode(codeId = "RUN-05000001", - cause = "DslRuntimeError, error message: {0}", - details = "", - action = "Please contact admin for help.") - String dslRuntimeError(String info); - - /** - * DSL parser. - * - * @param info info - * @return msg - */ - @ErrorFactory.ErrCode(codeId = "RUN-05000002", - cause = "DslParserError, error message: {0}", - details = "", - action = "Please contact admin for help.") - String dslParserError(String info); - + // Common module. + // RUN-00xxxxxx. + @ErrorFactory.ErrCode( + codeId = "RUN-00000001", + cause = "Undefined error: {0}", + details = "", + action = "Please check your code or dsl, and contact admin.") + String undefinedError(String info); + + @ErrorFactory.ErrCode( + codeId = "RUN-00000002", + cause = "Run error: {0}", + details = "", + action = "Please check your code, or contact admin.") + String runError(String info); + + @ErrorFactory.ErrCode( + codeId = "RUN-00000003", + cause = "SystemInternalError - not find ''{0}'' SPI implement", + details = "", + action = "Please contact admin.") + String spiNotFoundError(String className); + + @ErrorFactory.ErrCode( + codeId = "RUN-00000004", + cause = "Key conflicts error: userKey={0}, systemKeys={1}", + details = "", + action = "Please check your config or contact admin.") + String keyConflictsError(Set userArgsKeys, Set systemArgsKeys); + + @ErrorFactory.ErrCode( + codeId = "RUN-00000005", + cause = "Operator {0} type not support error.", + details = "", + action = "Please check your code, or contact admin.") + String operatorTypeNotSupportError(String opType); + + @ErrorFactory.ErrCode( + codeId = "RUN-00000006", + cause = "unsupported operation", + details = "", + action = "Please check your config, or contact admin.") + String unsupportedError(); + + @ErrorFactory.ErrCode( + codeId = "RUN-00000007", + cause = "SystemInternalError - not support ''{0}'' event type", + details = "", + action = "Please contact admin.") + String requestTypeNotSupportError(String requestType); + + @ErrorFactory.ErrCode( + codeId = "RUN-00000008", + cause = "Config key not found: {0}", + details = "", + action = "Please check your config or contact admin.") + String configKeyNotFound(String key); + + // Plan module. + // RUN-01xxxxxx. + @ErrorFactory.ErrCode( + codeId = "RUN-01000001", + cause = "SystemInternalError - the previous vertex of vertex '{0}' is null", + details = "", + action = "Please check your code or dsl, and contact admin.") + String previousVertexIsNullError(String info); + + @ErrorFactory.ErrCode( + codeId = "RUN-01000002", + cause = "Logical error - action is empty, missing necessary sink or get action", + details = "", + action = "Please contact admin for help.") + String actionIsEmptyError(); + + @ErrorFactory.ErrCode( + codeId = "RUN-01000003", + cause = "SystemInternalError - not support stream type ''{0}''", + details = "", + action = "Please contact admin for help.") + String streamTypeNotSupportError(String streamType); + + // Execution module. + // RUN-02xxxxxx. + @ErrorFactory.ErrCode( + codeId = "RUN-02000001", + cause = "shuffle serialize error: {0}", + details = "", + action = "Check the cause, and contact admin for help if necessary.") + String shuffleSerializeError(String info); + + @ErrorFactory.ErrCode( + codeId = "RUN-02000002", + cause = "shuffle deserialize error: {0}", + details = "", + action = "Check the cause, and contact admin for help if necessary.") + String shuffleDeserializeError(String info); + + // state module. + // RUN-03xxxxxx. + + /** + * State common. + * + * @param info info + * @return msg + */ + @ErrorFactory.ErrCode( + codeId = "RUN-03000001", + cause = "StateCommonError, error message: {0}", + details = "", + action = "Please contact admin for help.") + String stateCommonError(String info); + + /** + * State RocksDB. + * + * @param info info + * @return msg + */ + @ErrorFactory.ErrCode( + codeId = "RUN-03000002", + cause = "StateRocksDbError, error message: {0}", + details = "", + action = "Please contact admin for help.") + String stateRocksDbError(String info); + + // framework module. + // RUN-04xxxxxx. + + /** + * Resource. + * + * @param info info + * @return msg + */ + @ErrorFactory.ErrCode( + codeId = "RUN-04000001", + cause = "ResourceError, error message: {0}", + details = "", + action = "Please contact admin for help.") + String resourceError(String info); + + /** + * Type system. + * + * @param info info + * @return msg + */ + @ErrorFactory.ErrCode( + codeId = "RUN-04000002", + cause = "Type system error, error message: {0}", + details = "", + action = "Please contact admin for help.") + String typeSysError(String info); + + @ErrorFactory.ErrCode( + codeId = "RUN-04000003", + cause = "IllegalRequireNumError, error message: {0}", + details = "", + action = "Please contact admin for help.") + String resourceIllegalRequireNumError(String info); + + @ErrorFactory.ErrCode( + codeId = "RUN-04000004", + cause = "ResourceRecoveringError, error message: {0}", + details = "", + action = "Please contact admin for help.") + String resourceRecoveringError(String info); + + @ErrorFactory.ErrCode( + codeId = "RUN-04000005", + cause = "ResourceNotReadyError, error message: {0}", + details = "", + action = "Please contact admin for help.") + String resourceNotReadyError(String info); + + @ErrorFactory.ErrCode( + codeId = "RUN-04000006", + cause = "analyticsClientError, error message: {0}", + details = "", + action = "Please contact admin for help.") + String analyticsClientError(String info); + + // DSL module. + // RUN-05xxxxxx. + + /** + * DSL runtime. + * + * @param info info + * @return msg + */ + @ErrorFactory.ErrCode( + codeId = "RUN-05000001", + cause = "DslRuntimeError, error message: {0}", + details = "", + action = "Please contact admin for help.") + String dslRuntimeError(String info); + + /** + * DSL parser. + * + * @param info info + * @return msg + */ + @ErrorFactory.ErrCode( + codeId = "RUN-05000002", + cause = "DslParserError, error message: {0}", + details = "", + action = "Please contact admin for help.") + String dslParserError(String info); } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/errorcode/RuntimeErrors.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/errorcode/RuntimeErrors.java index 3f2a5c8c5..db4f9c7e1 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/errorcode/RuntimeErrors.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/errorcode/RuntimeErrors.java @@ -21,6 +21,6 @@ public class RuntimeErrors { - public static final RuntimeErrorCode INST = (RuntimeErrorCode) ErrorFactory - .createProxy(RuntimeErrorCode.class); + public static final RuntimeErrorCode INST = + (RuntimeErrorCode) ErrorFactory.createProxy(RuntimeErrorCode.class); } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/ConfigException.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/ConfigException.java index 4744c601f..0f1aa9bf8 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/ConfigException.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/ConfigException.java @@ -21,11 +21,11 @@ public class ConfigException extends RuntimeException { - public ConfigException(Throwable t) { - super(t); - } + public ConfigException(Throwable t) { + super(t); + } - public ConfigException(String msg) { - super(msg); - } + public ConfigException(String msg) { + super(msg); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/GeaflowDispatchException.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/GeaflowDispatchException.java index 371e8ee0e..ac641e24f 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/GeaflowDispatchException.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/GeaflowDispatchException.java @@ -21,15 +21,15 @@ public class GeaflowDispatchException extends GeaflowRuntimeException { - public GeaflowDispatchException(String msg) { - super(msg); - } + public GeaflowDispatchException(String msg) { + super(msg); + } - public GeaflowDispatchException(Throwable t) { - super(t); - } + public GeaflowDispatchException(Throwable t) { + super(t); + } - public GeaflowDispatchException(String msg, Throwable t) { - super(msg, t); - } + public GeaflowDispatchException(String msg, Throwable t) { + super(msg, t); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/GeaflowHeartbeatException.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/GeaflowHeartbeatException.java index 25090ef4f..a695c4458 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/GeaflowHeartbeatException.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/GeaflowHeartbeatException.java @@ -21,17 +21,17 @@ public class GeaflowHeartbeatException extends RuntimeException { - private static final String MESSAGE = "heartbeat timeout"; + private static final String MESSAGE = "heartbeat timeout"; - public GeaflowHeartbeatException() { - super(MESSAGE); - } + public GeaflowHeartbeatException() { + super(MESSAGE); + } - public GeaflowHeartbeatException(String message) { - super(message); - } + public GeaflowHeartbeatException(String message) { + super(message); + } - public GeaflowHeartbeatException(String message, Throwable cause) { - super(MESSAGE, cause); - } + public GeaflowHeartbeatException(String message, Throwable cause) { + super(MESSAGE, cause); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/GeaflowInterruptedException.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/GeaflowInterruptedException.java index bbd40ec5e..4df330660 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/GeaflowInterruptedException.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/GeaflowInterruptedException.java @@ -21,16 +21,15 @@ public class GeaflowInterruptedException extends GeaflowRuntimeException { - public GeaflowInterruptedException(Throwable e) { - super(e); - } + public GeaflowInterruptedException(Throwable e) { + super(e); + } - public GeaflowInterruptedException(String errorMessage) { - super(errorMessage); - } - - public GeaflowInterruptedException(String errorMessage, Throwable cause) { - super(errorMessage, cause); - } + public GeaflowInterruptedException(String errorMessage) { + super(errorMessage); + } + public GeaflowInterruptedException(String errorMessage, Throwable cause) { + super(errorMessage, cause); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/GeaflowRuntimeException.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/GeaflowRuntimeException.java index 8149b4576..45973ce39 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/GeaflowRuntimeException.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/GeaflowRuntimeException.java @@ -20,26 +20,25 @@ package org.apache.geaflow.common.exception; public class GeaflowRuntimeException extends RuntimeException { - private static final long serialVersionUID = 8832569372505798566L; + private static final long serialVersionUID = 8832569372505798566L; - private String errorMessage; + private String errorMessage; - public GeaflowRuntimeException(Throwable e) { - super(e); - } + public GeaflowRuntimeException(Throwable e) { + super(e); + } - public GeaflowRuntimeException(String errorMessage) { - super(errorMessage); - this.errorMessage = errorMessage; - } + public GeaflowRuntimeException(String errorMessage) { + super(errorMessage); + this.errorMessage = errorMessage; + } - public GeaflowRuntimeException(String errorMessage, Throwable cause) { - super(errorMessage, cause); - this.errorMessage = errorMessage; - } - - public String getErrorMessage() { - return this.errorMessage; - } + public GeaflowRuntimeException(String errorMessage, Throwable cause) { + super(errorMessage, cause); + this.errorMessage = errorMessage; + } + public String getErrorMessage() { + return this.errorMessage; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/NullFieldException.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/NullFieldException.java index bea319e89..c431f4357 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/NullFieldException.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/exception/NullFieldException.java @@ -21,7 +21,7 @@ public class NullFieldException extends GeaflowRuntimeException { - public NullFieldException(int pos) { - super("pos" + pos + " is null"); - } + public NullFieldException(int pos) { + super("pos" + pos + " is null"); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/heartbeat/Heartbeat.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/heartbeat/Heartbeat.java index 902093ff8..7f3576201 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/heartbeat/Heartbeat.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/heartbeat/Heartbeat.java @@ -20,55 +20,62 @@ package org.apache.geaflow.common.heartbeat; import java.io.Serializable; + import org.apache.geaflow.common.metric.ProcessMetrics; public class Heartbeat implements Serializable { - private int containerId; - private long timestamp; - private String containerName; - private ProcessMetrics processMetrics; + private int containerId; + private long timestamp; + private String containerName; + private ProcessMetrics processMetrics; - public Heartbeat(int resourceId) { - this.containerId = resourceId; - this.timestamp = System.currentTimeMillis(); - } + public Heartbeat(int resourceId) { + this.containerId = resourceId; + this.timestamp = System.currentTimeMillis(); + } - public int getContainerId() { - return containerId; - } + public int getContainerId() { + return containerId; + } - public void setContainerId(int containerId) { - this.containerId = containerId; - } + public void setContainerId(int containerId) { + this.containerId = containerId; + } - public long getTimestamp() { - return timestamp; - } + public long getTimestamp() { + return timestamp; + } - public void setTimestamp(long timestamp) { - this.timestamp = timestamp; - } + public void setTimestamp(long timestamp) { + this.timestamp = timestamp; + } - public ProcessMetrics getProcessMetrics() { - return processMetrics; - } + public ProcessMetrics getProcessMetrics() { + return processMetrics; + } - public void setProcessMetrics(ProcessMetrics processMetrics) { - this.processMetrics = processMetrics; - } + public void setProcessMetrics(ProcessMetrics processMetrics) { + this.processMetrics = processMetrics; + } - public String getContainerName() { - return containerName; - } + public String getContainerName() { + return containerName; + } - public void setContainerName(String containerName) { - this.containerName = containerName; - } + public void setContainerName(String containerName) { + this.containerName = containerName; + } - @Override - public String toString() { - return "Heartbeat{" + "containerId=" + containerId + ", timestamp=" + timestamp - + ", processMetrics=" + processMetrics + '}'; - } + @Override + public String toString() { + return "Heartbeat{" + + "containerId=" + + containerId + + ", timestamp=" + + timestamp + + ", processMetrics=" + + processMetrics + + '}'; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/heartbeat/HeartbeatInfo.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/heartbeat/HeartbeatInfo.java index 6309642cb..fff8d8c98 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/heartbeat/HeartbeatInfo.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/heartbeat/HeartbeatInfo.java @@ -21,120 +21,142 @@ import java.io.Serializable; import java.util.List; + import org.apache.geaflow.common.metric.ProcessMetrics; public class HeartbeatInfo implements Serializable { - private int totalNum; - private int activeNum; - private List containers; - private long expiredTimeMs; - - public static class ContainerHeartbeatInfo { - - private Integer id; - - private String name; - - private String host; - - private int pid; - - private Long lastTimestamp; + private int totalNum; + private int activeNum; + private List containers; + private long expiredTimeMs; - private ProcessMetrics metrics; + public static class ContainerHeartbeatInfo { - public Integer getId() { - return id; - } + private Integer id; - public void setId(Integer id) { - this.id = id; - } + private String name; - public String getName() { - return name; - } + private String host; - public void setName(String name) { - this.name = name; - } + private int pid; - public String getHost() { - return host; - } + private Long lastTimestamp; - public void setHost(String host) { - this.host = host; - } + private ProcessMetrics metrics; - public int getPid() { - return pid; - } - - public void setPid(int pid) { - this.pid = pid; - } - - public Long getLastTimestamp() { - return lastTimestamp; - } - - public void setLastTimestamp(Long lastTimestamp) { - this.lastTimestamp = lastTimestamp; - } + public Integer getId() { + return id; + } - public ProcessMetrics getMetrics() { - return metrics; - } + public void setId(Integer id) { + this.id = id; + } - public void setMetrics(ProcessMetrics metrics) { - this.metrics = metrics; - } + public String getName() { + return name; + } - @Override - public String toString() { - return "ContainerHeartbeatInfo{" + "id=" + id + ", name='" + name + '\'' + ", host='" - + host + '\'' + ", pid=" + pid + ", lastTimestamp=" + lastTimestamp + ", metrics=" - + metrics + '}'; - } + public void setName(String name) { + this.name = name; } - public int getTotalNum() { - return totalNum; + public String getHost() { + return host; } - public void setTotalNum(int totalNum) { - this.totalNum = totalNum; + public void setHost(String host) { + this.host = host; } - public int getActiveNum() { - return activeNum; + public int getPid() { + return pid; } - public void setActiveNum(int activeNum) { - this.activeNum = activeNum; + public void setPid(int pid) { + this.pid = pid; } - public List getContainers() { - return containers; + public Long getLastTimestamp() { + return lastTimestamp; } - public void setContainers(List containers) { - this.containers = containers; + public void setLastTimestamp(Long lastTimestamp) { + this.lastTimestamp = lastTimestamp; } - public long getExpiredTimeMs() { - return expiredTimeMs; + public ProcessMetrics getMetrics() { + return metrics; } - public void setExpiredTimeMs(long expiredTimeMs) { - this.expiredTimeMs = expiredTimeMs; + public void setMetrics(ProcessMetrics metrics) { + this.metrics = metrics; } @Override public String toString() { - return "HeartbeatInfo{" + "totalNum=" + totalNum + ", activeNum=" + activeNum - + ", containers=" + containers + ", expiredTime=" + expiredTimeMs + '}'; + return "ContainerHeartbeatInfo{" + + "id=" + + id + + ", name='" + + name + + '\'' + + ", host='" + + host + + '\'' + + ", pid=" + + pid + + ", lastTimestamp=" + + lastTimestamp + + ", metrics=" + + metrics + + '}'; } + } + + public int getTotalNum() { + return totalNum; + } + + public void setTotalNum(int totalNum) { + this.totalNum = totalNum; + } + + public int getActiveNum() { + return activeNum; + } + + public void setActiveNum(int activeNum) { + this.activeNum = activeNum; + } + + public List getContainers() { + return containers; + } + + public void setContainers(List containers) { + this.containers = containers; + } + + public long getExpiredTimeMs() { + return expiredTimeMs; + } + + public void setExpiredTimeMs(long expiredTimeMs) { + this.expiredTimeMs = expiredTimeMs; + } + + @Override + public String toString() { + return "HeartbeatInfo{" + + "totalNum=" + + totalNum + + ", activeNum=" + + activeNum + + ", containers=" + + containers + + ", expiredTime=" + + expiredTimeMs + + '}'; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/iterator/ChainedCloseableIterator.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/iterator/ChainedCloseableIterator.java index 3743b3d8c..06f27bf2e 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/iterator/ChainedCloseableIterator.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/iterator/ChainedCloseableIterator.java @@ -24,46 +24,46 @@ public class ChainedCloseableIterator implements CloseableIterator { - private final Iterator> closeableIterators; - private CloseableIterator currentIterator; + private final Iterator> closeableIterators; + private CloseableIterator currentIterator; - public ChainedCloseableIterator(List> iterators) { - if (iterators == null || iterators.isEmpty()) { - this.currentIterator = null; - } - - this.closeableIterators = iterators.iterator(); - this.currentIterator = closeableIterators.hasNext() ? closeableIterators.next() : null; + public ChainedCloseableIterator(List> iterators) { + if (iterators == null || iterators.isEmpty()) { + this.currentIterator = null; } - @Override - public boolean hasNext() { - while (currentIterator != null) { - if (currentIterator.hasNext()) { - return true; - } else { - currentIterator.close(); - currentIterator = closeableIterators.hasNext() ? closeableIterators.next() : null; - } - } + this.closeableIterators = iterators.iterator(); + this.currentIterator = closeableIterators.hasNext() ? closeableIterators.next() : null; + } - return false; + @Override + public boolean hasNext() { + while (currentIterator != null) { + if (currentIterator.hasNext()) { + return true; + } else { + currentIterator.close(); + currentIterator = closeableIterators.hasNext() ? closeableIterators.next() : null; + } } - @Override - public T next() { - return currentIterator.next(); - } + return false; + } - @Override - public void close() { - if (currentIterator != null) { - currentIterator.close(); - } + @Override + public T next() { + return currentIterator.next(); + } + + @Override + public void close() { + if (currentIterator != null) { + currentIterator.close(); + } - while (closeableIterators.hasNext()) { - CloseableIterator next = closeableIterators.next(); - next.close(); - } + while (closeableIterators.hasNext()) { + CloseableIterator next = closeableIterators.next(); + next.close(); } -} \ No newline at end of file + } +} diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/iterator/CloseableIterator.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/iterator/CloseableIterator.java index 331c6d430..8038efa1b 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/iterator/CloseableIterator.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/iterator/CloseableIterator.java @@ -23,5 +23,5 @@ public interface CloseableIterator extends Iterator, AutoCloseable { - void close(); + void close(); } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/CycleMetrics.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/CycleMetrics.java index 25451a5fb..f2838e75e 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/CycleMetrics.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/CycleMetrics.java @@ -23,178 +23,198 @@ public class CycleMetrics implements Serializable { - private String name; - private String pipelineName; - private String opName; - - private long duration; - private long startTime; - private long totalTasks; - private int slowestTask; - private long slowestTaskExecuteTime; - private long inputRecords; - private long inputKb; - private long outputRecords; - private long outputKb; - private long avgGcTime; - private long avgExecuteTime; - - public CycleMetrics(String name, String pipelineName, String opName) { - this.name = name; - this.pipelineName = pipelineName; - this.opName = opName; - } - - public String getName() { - return name; - } - - public String getPipelineName() { - return pipelineName; - } - - public String getOpName() { - return opName; - } - - public void setOpName(String opName) { - this.opName = opName; - } - - public long getDuration() { - return duration; - } - - public void setDuration(long duration) { - this.duration = duration; - } - - public long getStartTime() { - return startTime; - } - - public void setStartTime(long startTime) { - this.startTime = startTime; - } - - public long getTotalTasks() { - return totalTasks; - } - - public void setTotalTasks(long totalTasks) { - this.totalTasks = totalTasks; - } - - public int getSlowestTask() { - return slowestTask; - } - - public void setSlowestTask(int slowestTask) { - this.slowestTask = slowestTask; - } - - public long getAvgGcTime() { - return avgGcTime; - } - - public void setAvgGcTime(long avgGcTime) { - this.avgGcTime = avgGcTime; - } - - public long getAvgExecuteTime() { - return avgExecuteTime; - } - - public void setAvgExecuteTime(long avgExecuteTime) { - this.avgExecuteTime = avgExecuteTime; - } - - public long getSlowestTaskExecuteTime() { - return slowestTaskExecuteTime; - } - - public void setSlowestTaskExecuteTime(long slowestTaskExecuteTime) { - this.slowestTaskExecuteTime = slowestTaskExecuteTime; - } - - public long getInputRecords() { - return inputRecords; - } - - public void setInputRecords(long inputRecords) { - this.inputRecords = inputRecords; - } - - public long getOutputRecords() { - return outputRecords; - } - - public void setOutputRecords(long outputRecords) { - this.outputRecords = outputRecords; - } - - public long getInputKb() { - return inputKb; - } - - public void setInputKb(long inputKb) { - this.inputKb = inputKb; - } - - public long getOutputKb() { - return outputKb; - } - - public void setOutputKb(long outputKb) { - this.outputKb = outputKb; - } - - public static CycleMetrics build(String metricName, - String pipelineName, - String opName, - int taskNum, - int slowestTask, - long startTime, - long duration, - long totalExecuteTime, - long totalGcTime, - long slowestTaskExecuteTime, - long totalInputRecords, - long totalInputBytes, - long totalOutputRecords, - long totalOutputBytes) { - CycleMetrics cycleMetrics = new CycleMetrics(metricName, pipelineName, opName); - cycleMetrics.setStartTime(startTime); - cycleMetrics.setTotalTasks(taskNum); - cycleMetrics.setSlowestTask(slowestTask); - cycleMetrics.setDuration(duration); - cycleMetrics.setAvgExecuteTime(totalExecuteTime / taskNum); - cycleMetrics.setAvgGcTime(totalGcTime / taskNum); - cycleMetrics.setSlowestTaskExecuteTime(slowestTaskExecuteTime); - cycleMetrics.setInputRecords(totalInputRecords); - cycleMetrics.setInputKb(totalInputBytes / 1024); - cycleMetrics.setOutputRecords(totalOutputRecords); - cycleMetrics.setOutputKb(totalOutputBytes / 1024); - return cycleMetrics; - } - - @Override - public String toString() { - return "CycleMetrics{" - + "pipelineName='" + pipelineName + '\'' - + ", name='" + name + '\'' - + ", opName='" + opName + '\'' - + ", duration=" + duration + "ms" - + ", totalTasks=" + totalTasks - + ", slowestTask=" + slowestTask - + ", slowestTaskExecuteTime=" + slowestTaskExecuteTime + "ms" - + ", avgGcTime=" + avgGcTime + "ms" - + ", avgExecuteTime=" + avgExecuteTime + "ms" - + ", inputRecords=" + inputRecords - + ", inputKb=" + inputKb - + ", outputRecords=" + outputRecords - + ", outputKb=" + outputKb - + '}'; - } - + private String name; + private String pipelineName; + private String opName; + + private long duration; + private long startTime; + private long totalTasks; + private int slowestTask; + private long slowestTaskExecuteTime; + private long inputRecords; + private long inputKb; + private long outputRecords; + private long outputKb; + private long avgGcTime; + private long avgExecuteTime; + + public CycleMetrics(String name, String pipelineName, String opName) { + this.name = name; + this.pipelineName = pipelineName; + this.opName = opName; + } + + public String getName() { + return name; + } + + public String getPipelineName() { + return pipelineName; + } + + public String getOpName() { + return opName; + } + + public void setOpName(String opName) { + this.opName = opName; + } + + public long getDuration() { + return duration; + } + + public void setDuration(long duration) { + this.duration = duration; + } + + public long getStartTime() { + return startTime; + } + + public void setStartTime(long startTime) { + this.startTime = startTime; + } + + public long getTotalTasks() { + return totalTasks; + } + + public void setTotalTasks(long totalTasks) { + this.totalTasks = totalTasks; + } + + public int getSlowestTask() { + return slowestTask; + } + + public void setSlowestTask(int slowestTask) { + this.slowestTask = slowestTask; + } + + public long getAvgGcTime() { + return avgGcTime; + } + + public void setAvgGcTime(long avgGcTime) { + this.avgGcTime = avgGcTime; + } + + public long getAvgExecuteTime() { + return avgExecuteTime; + } + + public void setAvgExecuteTime(long avgExecuteTime) { + this.avgExecuteTime = avgExecuteTime; + } + + public long getSlowestTaskExecuteTime() { + return slowestTaskExecuteTime; + } + + public void setSlowestTaskExecuteTime(long slowestTaskExecuteTime) { + this.slowestTaskExecuteTime = slowestTaskExecuteTime; + } + + public long getInputRecords() { + return inputRecords; + } + + public void setInputRecords(long inputRecords) { + this.inputRecords = inputRecords; + } + + public long getOutputRecords() { + return outputRecords; + } + + public void setOutputRecords(long outputRecords) { + this.outputRecords = outputRecords; + } + + public long getInputKb() { + return inputKb; + } + + public void setInputKb(long inputKb) { + this.inputKb = inputKb; + } + + public long getOutputKb() { + return outputKb; + } + + public void setOutputKb(long outputKb) { + this.outputKb = outputKb; + } + + public static CycleMetrics build( + String metricName, + String pipelineName, + String opName, + int taskNum, + int slowestTask, + long startTime, + long duration, + long totalExecuteTime, + long totalGcTime, + long slowestTaskExecuteTime, + long totalInputRecords, + long totalInputBytes, + long totalOutputRecords, + long totalOutputBytes) { + CycleMetrics cycleMetrics = new CycleMetrics(metricName, pipelineName, opName); + cycleMetrics.setStartTime(startTime); + cycleMetrics.setTotalTasks(taskNum); + cycleMetrics.setSlowestTask(slowestTask); + cycleMetrics.setDuration(duration); + cycleMetrics.setAvgExecuteTime(totalExecuteTime / taskNum); + cycleMetrics.setAvgGcTime(totalGcTime / taskNum); + cycleMetrics.setSlowestTaskExecuteTime(slowestTaskExecuteTime); + cycleMetrics.setInputRecords(totalInputRecords); + cycleMetrics.setInputKb(totalInputBytes / 1024); + cycleMetrics.setOutputRecords(totalOutputRecords); + cycleMetrics.setOutputKb(totalOutputBytes / 1024); + return cycleMetrics; + } + + @Override + public String toString() { + return "CycleMetrics{" + + "pipelineName='" + + pipelineName + + '\'' + + ", name='" + + name + + '\'' + + ", opName='" + + opName + + '\'' + + ", duration=" + + duration + + "ms" + + ", totalTasks=" + + totalTasks + + ", slowestTask=" + + slowestTask + + ", slowestTaskExecuteTime=" + + slowestTaskExecuteTime + + "ms" + + ", avgGcTime=" + + avgGcTime + + "ms" + + ", avgExecuteTime=" + + avgExecuteTime + + "ms" + + ", inputRecords=" + + inputRecords + + ", inputKb=" + + inputKb + + ", outputRecords=" + + outputRecords + + ", outputKb=" + + outputKb + + '}'; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/EventMetrics.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/EventMetrics.java index ed8cb4405..2f1d4c5d0 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/EventMetrics.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/EventMetrics.java @@ -20,187 +20,195 @@ package org.apache.geaflow.common.metric; import java.io.Serializable; + import org.apache.geaflow.common.utils.GcUtil; public class EventMetrics implements Serializable { - /** - * Meta. - */ - private final int vertexId; - private final int parallelism; - private final int index; - - /** - * Execution. - */ - private long startTime; - private long finishTime; - private long executeCostMs; - private long processCostMs; - private long gcCostMs; - - /** - * Shuffle. - */ - private long shuffleReadRecords; - private long shuffleReadBytes; - private long shuffleReadCostMs; - - private long shuffleWriteRecords; - private long shuffleWriteBytes; - private long shuffleWriteCostMs; - - private transient long startGcTs; - - public EventMetrics(int vertexId, int parallelism, int index) { - this.vertexId = vertexId; - this.parallelism = parallelism; - this.index = index; - this.startTime = System.currentTimeMillis(); - this.startGcTs = GcUtil.computeCurrentTotalGcTime(); - } - - public int getVertexId() { - return this.vertexId; - } - - public int getParallelism() { - return this.parallelism; - } - - public int getIndex() { - return this.index; - } - - public long getStartTime() { - return this.startTime; - } - - public void setStartTime(long startTime) { - this.startTime = startTime; - } - - public long getFinishTime() { - return this.finishTime; - } - - public void setFinishTime(long finishTime) { - this.finishTime = finishTime; - this.executeCostMs = this.finishTime - this.startTime; - } - - public long getExecuteCostMs() { - return this.executeCostMs; - } - - public long getProcessCostMs() { - return this.processCostMs; - } - - public void setProcessCostMs(long processCostMs) { - this.processCostMs = processCostMs; - } - - public void addProcessCostMs(long processCostMs) { - this.processCostMs += processCostMs; - } - - public long getGcCostMs() { - return this.gcCostMs; - } - - public long getShuffleReadRecords() { - return this.shuffleReadRecords; - } - - public void setShuffleReadRecords(long shuffleReadRecords) { - this.shuffleReadRecords = shuffleReadRecords; - } - - public void addShuffleReadRecords(long shuffleReadRecords) { - this.shuffleReadRecords += shuffleReadRecords; - } - - public long getShuffleReadBytes() { - return this.shuffleReadBytes; - } - - public void setShuffleReadBytes(long shuffleReadBytes) { - this.shuffleReadBytes = shuffleReadBytes; - } - - public void addShuffleReadBytes(long shuffleReadBytes) { - this.shuffleReadBytes += shuffleReadBytes; - } - - public long getShuffleReadCostMs() { - return this.shuffleReadCostMs; - } - - public void setShuffleReadCostMs(long shuffleReadCostMs) { - this.shuffleReadCostMs = shuffleReadCostMs; - } - - public void addShuffleReadCostMs(long shuffleReadCostMs) { - this.shuffleReadCostMs += shuffleReadCostMs; - } - - public long getShuffleWriteRecords() { - return this.shuffleWriteRecords; - } - - public void setShuffleWriteRecords(long shuffleWriteRecords) { - this.shuffleWriteRecords = shuffleWriteRecords; - } - - public void addShuffleWriteRecords(long shuffleWriteRecords) { - this.shuffleWriteRecords += shuffleWriteRecords; - } - - public long getShuffleWriteBytes() { - return this.shuffleWriteBytes; - } - - public void setShuffleWriteBytes(long shuffleWriteBytes) { - this.shuffleWriteBytes = shuffleWriteBytes; - } - - public void addShuffleWriteBytes(long shuffleWriteBytes) { - this.shuffleWriteBytes += shuffleWriteBytes; - } - - public long getShuffleWriteCostMs() { - return this.shuffleWriteCostMs; - } - - public void addShuffleWriteCostMs(long shuffleWriteCostMs) { - this.shuffleWriteCostMs += shuffleWriteCostMs; - } - - public void setStartGcTs(long startGcTs) { - this.startGcTs = startGcTs; - } - - public void setFinishGcTs(long finishGcTs) { - this.gcCostMs = finishGcTs - this.startGcTs; - } - - @Override - public String toString() { - return "EventMetrics{" - + "startTime=" + startTime - + ", finishTime=" + finishTime - + ", executeCostMs=" + executeCostMs - + ", processCostMs=" + processCostMs - + ", gcCostMs=" + gcCostMs - + ", shuffleReadRecords=" + shuffleReadRecords - + ", shuffleReadBytes=" + shuffleReadBytes - + ", shuffleReadCostMs=" + shuffleReadCostMs - + ", shuffleWriteRecords=" + shuffleWriteRecords - + ", shuffleWriteBytes=" + shuffleWriteBytes - + ", shuffleWriteCostMs=" + shuffleWriteCostMs - + '}'; - } + /** Meta. */ + private final int vertexId; + + private final int parallelism; + private final int index; + + /** Execution. */ + private long startTime; + + private long finishTime; + private long executeCostMs; + private long processCostMs; + private long gcCostMs; + + /** Shuffle. */ + private long shuffleReadRecords; + + private long shuffleReadBytes; + private long shuffleReadCostMs; + + private long shuffleWriteRecords; + private long shuffleWriteBytes; + private long shuffleWriteCostMs; + + private transient long startGcTs; + + public EventMetrics(int vertexId, int parallelism, int index) { + this.vertexId = vertexId; + this.parallelism = parallelism; + this.index = index; + this.startTime = System.currentTimeMillis(); + this.startGcTs = GcUtil.computeCurrentTotalGcTime(); + } + + public int getVertexId() { + return this.vertexId; + } + + public int getParallelism() { + return this.parallelism; + } + + public int getIndex() { + return this.index; + } + + public long getStartTime() { + return this.startTime; + } + + public void setStartTime(long startTime) { + this.startTime = startTime; + } + + public long getFinishTime() { + return this.finishTime; + } + + public void setFinishTime(long finishTime) { + this.finishTime = finishTime; + this.executeCostMs = this.finishTime - this.startTime; + } + public long getExecuteCostMs() { + return this.executeCostMs; + } + + public long getProcessCostMs() { + return this.processCostMs; + } + + public void setProcessCostMs(long processCostMs) { + this.processCostMs = processCostMs; + } + + public void addProcessCostMs(long processCostMs) { + this.processCostMs += processCostMs; + } + + public long getGcCostMs() { + return this.gcCostMs; + } + + public long getShuffleReadRecords() { + return this.shuffleReadRecords; + } + + public void setShuffleReadRecords(long shuffleReadRecords) { + this.shuffleReadRecords = shuffleReadRecords; + } + + public void addShuffleReadRecords(long shuffleReadRecords) { + this.shuffleReadRecords += shuffleReadRecords; + } + + public long getShuffleReadBytes() { + return this.shuffleReadBytes; + } + + public void setShuffleReadBytes(long shuffleReadBytes) { + this.shuffleReadBytes = shuffleReadBytes; + } + + public void addShuffleReadBytes(long shuffleReadBytes) { + this.shuffleReadBytes += shuffleReadBytes; + } + + public long getShuffleReadCostMs() { + return this.shuffleReadCostMs; + } + + public void setShuffleReadCostMs(long shuffleReadCostMs) { + this.shuffleReadCostMs = shuffleReadCostMs; + } + + public void addShuffleReadCostMs(long shuffleReadCostMs) { + this.shuffleReadCostMs += shuffleReadCostMs; + } + + public long getShuffleWriteRecords() { + return this.shuffleWriteRecords; + } + + public void setShuffleWriteRecords(long shuffleWriteRecords) { + this.shuffleWriteRecords = shuffleWriteRecords; + } + + public void addShuffleWriteRecords(long shuffleWriteRecords) { + this.shuffleWriteRecords += shuffleWriteRecords; + } + + public long getShuffleWriteBytes() { + return this.shuffleWriteBytes; + } + + public void setShuffleWriteBytes(long shuffleWriteBytes) { + this.shuffleWriteBytes = shuffleWriteBytes; + } + + public void addShuffleWriteBytes(long shuffleWriteBytes) { + this.shuffleWriteBytes += shuffleWriteBytes; + } + + public long getShuffleWriteCostMs() { + return this.shuffleWriteCostMs; + } + + public void addShuffleWriteCostMs(long shuffleWriteCostMs) { + this.shuffleWriteCostMs += shuffleWriteCostMs; + } + + public void setStartGcTs(long startGcTs) { + this.startGcTs = startGcTs; + } + + public void setFinishGcTs(long finishGcTs) { + this.gcCostMs = finishGcTs - this.startGcTs; + } + + @Override + public String toString() { + return "EventMetrics{" + + "startTime=" + + startTime + + ", finishTime=" + + finishTime + + ", executeCostMs=" + + executeCostMs + + ", processCostMs=" + + processCostMs + + ", gcCostMs=" + + gcCostMs + + ", shuffleReadRecords=" + + shuffleReadRecords + + ", shuffleReadBytes=" + + shuffleReadBytes + + ", shuffleReadCostMs=" + + shuffleReadCostMs + + ", shuffleWriteRecords=" + + shuffleWriteRecords + + ", shuffleWriteBytes=" + + shuffleWriteBytes + + ", shuffleWriteCostMs=" + + shuffleWriteCostMs + + '}'; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/PipelineMetrics.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/PipelineMetrics.java index 8449501a3..458ec021d 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/PipelineMetrics.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/PipelineMetrics.java @@ -23,40 +23,36 @@ public class PipelineMetrics implements Serializable { - private String name; - private long duration; - private long startTime; - - public PipelineMetrics(String name) { - this.name = name; - } - - public String getName() { - return name; - } - - public long getDuration() { - return duration; - } - - public void setDuration(long duration) { - this.duration = duration; - } - - public long getStartTime() { - return startTime; - } - - public void setStartTime(long startTime) { - this.startTime = startTime; - } - - @Override - public String toString() { - return "PipelineMetrics{" - + "name='" + name + '\'' - + ", duration=" + duration + "ms" - + '}'; - } - + private String name; + private long duration; + private long startTime; + + public PipelineMetrics(String name) { + this.name = name; + } + + public String getName() { + return name; + } + + public long getDuration() { + return duration; + } + + public void setDuration(long duration) { + this.duration = duration; + } + + public long getStartTime() { + return startTime; + } + + public void setStartTime(long startTime) { + this.startTime = startTime; + } + + @Override + public String toString() { + return "PipelineMetrics{" + "name='" + name + '\'' + ", duration=" + duration + "ms" + '}'; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/ProcessMetrics.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/ProcessMetrics.java index 1e46b0a03..c92f23f5f 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/ProcessMetrics.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/ProcessMetrics.java @@ -21,161 +21,172 @@ public class ProcessMetrics { - /** - * This amount of memory is guaranteed for the Java virtual machine to use. - */ - private long heapCommittedMB; - private long heapUsedMB; - private double heapUsedRatio; - private long totalMemoryMB; - - /** - * the total number of full collections that have occurred. - */ - private long fgcCount = 0L; - /** - * the total cost of full collections that have occurred. - */ - private long fgcTime = 0L; - /** - * the approximate accumulated collection elapsed time in milliseconds. - */ - private long gcTime = 0L; - /** - * the total number of collections that have occurred. - */ - private long gcCount = 0; - - /** - * the system load average for the last minute, or a negative value if not available. - */ - private double avgLoad; - /** - * the number of processors available to the Java virtual machine. - */ - private int availCores; - /** - * cpu usage. - */ - private double processCpu; - /** - * the number of processors used. - */ - private double usedCores; - private int activeThreads; - - public long getHeapCommittedMB() { - return heapCommittedMB; - } - - public void setHeapCommittedMB(long heapCommittedMB) { - this.heapCommittedMB = heapCommittedMB; - } - - public long getHeapUsedMB() { - return heapUsedMB; - } - - public void setHeapUsedMB(long heapUsedMB) { - this.heapUsedMB = heapUsedMB; - } - - public double getHeapUsedRatio() { - return heapUsedRatio; - } - - public void setHeapUsedRatio(double heapUsedRatio) { - this.heapUsedRatio = heapUsedRatio; - } - - public long getTotalMemoryMB() { - return totalMemoryMB; - } - - public void setTotalMemoryMB(long totalMemoryMB) { - this.totalMemoryMB = totalMemoryMB; - } - - public long getFgcCount() { - return fgcCount; - } - - public void setFgcCount(long fgcCount) { - this.fgcCount = fgcCount; - } - - public long getFgcTime() { - return fgcTime; - } - - public void setFgcTime(long fgcTime) { - this.fgcTime = fgcTime; - } - - public long getGcTime() { - return gcTime; - } - - public void setGcTime(long gcTime) { - this.gcTime = gcTime; - } - - public long getGcCount() { - return gcCount; - } - - public void setGcCount(long gcCount) { - this.gcCount = gcCount; - } - - public double getAvgLoad() { - return avgLoad; - } - - public void setAvgLoad(double avgLoad) { - this.avgLoad = avgLoad; - } - - public int getAvailCores() { - return availCores; - } - - public void setAvailCores(int availCores) { - this.availCores = availCores; - } - - public double getUsedCores() { - return usedCores; - } - - public void setUsedCores(double usedCores) { - this.usedCores = usedCores; - } - - public double getProcessCpu() { - return processCpu; - } - - public void setProcessCpu(double processCpu) { - this.processCpu = processCpu; - } - - public int getActiveThreads() { - return activeThreads; - } + /** This amount of memory is guaranteed for the Java virtual machine to use. */ + private long heapCommittedMB; - public void setActiveThreads(int activeThreads) { - this.activeThreads = activeThreads; - } + private long heapUsedMB; + private double heapUsedRatio; + private long totalMemoryMB; - @Override - public String toString() { - return "ProcessMetrics{" + "heapCommittedMB=" + heapCommittedMB + ", heapUsedMB=" - + heapUsedMB + ", heapUsedRatio=" + heapUsedRatio + ", totalMemoryMB=" + totalMemoryMB - + ", fgcCount=" + fgcCount + ", fgcTime=" + fgcTime + ", gcTime=" + gcTime - + ", gcCount=" + gcCount + ", avgLoad=" + avgLoad + ", availCores=" + availCores - + ", processCpu=" + processCpu + ", usedCores=" + usedCores + ", activeThreads=" - + activeThreads + '}'; - } + /** the total number of full collections that have occurred. */ + private long fgcCount = 0L; + /** the total cost of full collections that have occurred. */ + private long fgcTime = 0L; + + /** the approximate accumulated collection elapsed time in milliseconds. */ + private long gcTime = 0L; + + /** the total number of collections that have occurred. */ + private long gcCount = 0; + + /** the system load average for the last minute, or a negative value if not available. */ + private double avgLoad; + + /** the number of processors available to the Java virtual machine. */ + private int availCores; + + /** cpu usage. */ + private double processCpu; + + /** the number of processors used. */ + private double usedCores; + + private int activeThreads; + + public long getHeapCommittedMB() { + return heapCommittedMB; + } + + public void setHeapCommittedMB(long heapCommittedMB) { + this.heapCommittedMB = heapCommittedMB; + } + + public long getHeapUsedMB() { + return heapUsedMB; + } + + public void setHeapUsedMB(long heapUsedMB) { + this.heapUsedMB = heapUsedMB; + } + + public double getHeapUsedRatio() { + return heapUsedRatio; + } + + public void setHeapUsedRatio(double heapUsedRatio) { + this.heapUsedRatio = heapUsedRatio; + } + + public long getTotalMemoryMB() { + return totalMemoryMB; + } + + public void setTotalMemoryMB(long totalMemoryMB) { + this.totalMemoryMB = totalMemoryMB; + } + + public long getFgcCount() { + return fgcCount; + } + + public void setFgcCount(long fgcCount) { + this.fgcCount = fgcCount; + } + + public long getFgcTime() { + return fgcTime; + } + + public void setFgcTime(long fgcTime) { + this.fgcTime = fgcTime; + } + + public long getGcTime() { + return gcTime; + } + + public void setGcTime(long gcTime) { + this.gcTime = gcTime; + } + + public long getGcCount() { + return gcCount; + } + + public void setGcCount(long gcCount) { + this.gcCount = gcCount; + } + + public double getAvgLoad() { + return avgLoad; + } + + public void setAvgLoad(double avgLoad) { + this.avgLoad = avgLoad; + } + + public int getAvailCores() { + return availCores; + } + + public void setAvailCores(int availCores) { + this.availCores = availCores; + } + + public double getUsedCores() { + return usedCores; + } + + public void setUsedCores(double usedCores) { + this.usedCores = usedCores; + } + + public double getProcessCpu() { + return processCpu; + } + + public void setProcessCpu(double processCpu) { + this.processCpu = processCpu; + } + + public int getActiveThreads() { + return activeThreads; + } + + public void setActiveThreads(int activeThreads) { + this.activeThreads = activeThreads; + } + + @Override + public String toString() { + return "ProcessMetrics{" + + "heapCommittedMB=" + + heapCommittedMB + + ", heapUsedMB=" + + heapUsedMB + + ", heapUsedRatio=" + + heapUsedRatio + + ", totalMemoryMB=" + + totalMemoryMB + + ", fgcCount=" + + fgcCount + + ", fgcTime=" + + fgcTime + + ", gcTime=" + + gcTime + + ", gcCount=" + + gcCount + + ", avgLoad=" + + avgLoad + + ", availCores=" + + availCores + + ", processCpu=" + + processCpu + + ", usedCores=" + + usedCores + + ", activeThreads=" + + activeThreads + + '}'; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/ShuffleReadMetrics.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/ShuffleReadMetrics.java index c9b0fa56a..5ca4d1ce2 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/ShuffleReadMetrics.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/ShuffleReadMetrics.java @@ -23,107 +23,113 @@ public class ShuffleReadMetrics implements Serializable { - private int fetchSlices; - /** - * total records of fetch response. - */ - private long fetchRecords; - private long decodeBytes; - /** - * time to write response. - */ - private long fetchWaitMs; - /** - * time to decode response. - */ - private long decodeMs; - /** - * cost from request sent to response back. - */ - private long requestMs; - - public void merge(ShuffleReadMetrics readMetrics) { - if (readMetrics != null) { - fetchSlices += readMetrics.getFetchSlices(); - fetchRecords += readMetrics.getFetchRecords(); - decodeBytes += readMetrics.getDecodeBytes(); - fetchWaitMs += readMetrics.getFetchWaitMs(); - decodeMs += readMetrics.getDecodeMs(); - } - } - - public int getFetchSlices() { - return fetchSlices; - } - - public void setFetchSlices(int fetchSlices) { - this.fetchSlices = fetchSlices; - } + private int fetchSlices; - public long getFetchRecords() { - return fetchRecords; - } - - public void setFetchRecords(long fetchRecords) { - this.fetchRecords = fetchRecords; - } - - public long getDecodeBytes() { - return decodeBytes; - } - - public void setDecodeBytes(long decodeBytes) { - this.decodeBytes = decodeBytes; - } - - public void increaseDecodeBytes(long decodeBytes) { - this.decodeBytes += decodeBytes; - } - - - public long getFetchWaitMs() { - return fetchWaitMs; - } - - public void setFetchWaitMs(long fetchWaitMs) { - this.fetchWaitMs = fetchWaitMs; - } - - public long getDecodeMs() { - return decodeMs; - } - - public void setDecodeMs(long decodeMs) { - this.decodeMs = decodeMs; - } - - public long getRequestMs() { - return requestMs; - } - - public void setRequestMs(long requestMs) { - this.requestMs = requestMs; - } - - public void updateDecodeBytes(long decodeBytes) { - this.decodeBytes += decodeBytes; - } - - public void updateFetchRecords(long records) { - this.fetchRecords += records; - } - - public void increaseDecodeMs(long decodeMs) { - this.decodeMs += decodeMs; - } - - public void incFetchWaitMs(long fetchWaitMs) { - this.fetchWaitMs += fetchWaitMs; - } - - @Override - public String toString() { - return "ReadMetrics{" + "fetchSlices=" + fetchSlices + ", fetchRecords=" + fetchRecords + ", decodeKB=" + decodeBytes / 1024 + ", fetchWaitMs=" + fetchWaitMs - + ", decodeMs=" + decodeMs + '}'; - } + /** total records of fetch response. */ + private long fetchRecords; + + private long decodeBytes; + + /** time to write response. */ + private long fetchWaitMs; + + /** time to decode response. */ + private long decodeMs; + + /** cost from request sent to response back. */ + private long requestMs; + + public void merge(ShuffleReadMetrics readMetrics) { + if (readMetrics != null) { + fetchSlices += readMetrics.getFetchSlices(); + fetchRecords += readMetrics.getFetchRecords(); + decodeBytes += readMetrics.getDecodeBytes(); + fetchWaitMs += readMetrics.getFetchWaitMs(); + decodeMs += readMetrics.getDecodeMs(); + } + } + + public int getFetchSlices() { + return fetchSlices; + } + + public void setFetchSlices(int fetchSlices) { + this.fetchSlices = fetchSlices; + } + + public long getFetchRecords() { + return fetchRecords; + } + + public void setFetchRecords(long fetchRecords) { + this.fetchRecords = fetchRecords; + } + + public long getDecodeBytes() { + return decodeBytes; + } + + public void setDecodeBytes(long decodeBytes) { + this.decodeBytes = decodeBytes; + } + + public void increaseDecodeBytes(long decodeBytes) { + this.decodeBytes += decodeBytes; + } + + public long getFetchWaitMs() { + return fetchWaitMs; + } + + public void setFetchWaitMs(long fetchWaitMs) { + this.fetchWaitMs = fetchWaitMs; + } + + public long getDecodeMs() { + return decodeMs; + } + + public void setDecodeMs(long decodeMs) { + this.decodeMs = decodeMs; + } + + public long getRequestMs() { + return requestMs; + } + + public void setRequestMs(long requestMs) { + this.requestMs = requestMs; + } + + public void updateDecodeBytes(long decodeBytes) { + this.decodeBytes += decodeBytes; + } + + public void updateFetchRecords(long records) { + this.fetchRecords += records; + } + + public void increaseDecodeMs(long decodeMs) { + this.decodeMs += decodeMs; + } + + public void incFetchWaitMs(long fetchWaitMs) { + this.fetchWaitMs += fetchWaitMs; + } + + @Override + public String toString() { + return "ReadMetrics{" + + "fetchSlices=" + + fetchSlices + + ", fetchRecords=" + + fetchRecords + + ", decodeKB=" + + decodeBytes / 1024 + + ", fetchWaitMs=" + + fetchWaitMs + + ", decodeMs=" + + decodeMs + + '}'; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/ShuffleWriteMetrics.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/ShuffleWriteMetrics.java index 5cc37405c..26e4fa2bc 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/ShuffleWriteMetrics.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/metric/ShuffleWriteMetrics.java @@ -23,203 +23,208 @@ public class ShuffleWriteMetrics implements Serializable { - /** - * total output channels. - */ - private long numChannels; - /** - * total output channels with data written. - */ - private long writtenChannels; - - /** - * total written records. - */ - private long writtenRecords; - /** - * total written bytes. - */ - private long encodedSize; - /** - * time cost on serializing. - */ - private long encodeMs; - /** - * max written slice size in KB. - */ - private long maxSliceKB; - - /** - * total spills. - */ - private int spillNum; - /** - * total spills to disk. - */ - private int spillDisk; - /** - * total oom count. - */ - private int oomCount; - /** - * max spilled size in KB. - */ - private long maxSpillKB; - - /** - * total spill cost in ms. - */ - private long spillMs; - private long flushMs; - - public void merge(ShuffleWriteMetrics metrics) { - this.numChannels += metrics.numChannels; - this.writtenChannels += metrics.writtenChannels; - this.writtenRecords += metrics.writtenRecords; - this.encodedSize += metrics.encodedSize; - this.encodeMs += metrics.encodeMs; - this.spillNum += metrics.spillNum; - this.spillDisk += metrics.spillDisk; - this.oomCount += metrics.oomCount; - this.spillMs += metrics.spillMs; - this.flushMs += metrics.flushMs; - this.maxSliceKB = Math.max(maxSliceKB, metrics.maxSliceKB); - this.maxSpillKB = Math.max(maxSpillKB, metrics.maxSpillKB); - } + /** total output channels. */ + private long numChannels; - public long getWrittenRecords() { - return writtenRecords; - } + /** total output channels with data written. */ + private long writtenChannels; - public void setWrittenRecords(long writtenRecords) { - this.writtenRecords = writtenRecords; - } + /** total written records. */ + private long writtenRecords; - public long getEncodedSize() { - return encodedSize; - } + /** total written bytes. */ + private long encodedSize; - public void setEncodedSize(long encodedSize) { - this.encodedSize = encodedSize; - } + /** time cost on serializing. */ + private long encodeMs; - public long getEncodeMs() { - return encodeMs; - } + /** max written slice size in KB. */ + private long maxSliceKB; - public void setEncodeMs(long encodeMs) { - this.encodeMs = encodeMs; - } + /** total spills. */ + private int spillNum; - public int getSpillNum() { - return spillNum; - } + /** total spills to disk. */ + private int spillDisk; - public void setSpillNum(int spillNum) { - this.spillNum = spillNum; - } + /** total oom count. */ + private int oomCount; - public int getSpillDisk() { - return spillDisk; - } + /** max spilled size in KB. */ + private long maxSpillKB; - public void setSpillDisk(int spillDisk) { - this.spillDisk = spillDisk; - } + /** total spill cost in ms. */ + private long spillMs; - public int getOomCount() { - return oomCount; - } + private long flushMs; - public void setOomCount(int oomCount) { - this.oomCount = oomCount; - } + public void merge(ShuffleWriteMetrics metrics) { + this.numChannels += metrics.numChannels; + this.writtenChannels += metrics.writtenChannels; + this.writtenRecords += metrics.writtenRecords; + this.encodedSize += metrics.encodedSize; + this.encodeMs += metrics.encodeMs; + this.spillNum += metrics.spillNum; + this.spillDisk += metrics.spillDisk; + this.oomCount += metrics.oomCount; + this.spillMs += metrics.spillMs; + this.flushMs += metrics.flushMs; + this.maxSliceKB = Math.max(maxSliceKB, metrics.maxSliceKB); + this.maxSpillKB = Math.max(maxSpillKB, metrics.maxSpillKB); + } - public long getMaxSpillKB() { - return maxSpillKB; - } + public long getWrittenRecords() { + return writtenRecords; + } - public void setMaxSpillKB(long maxSpillKB) { - this.maxSpillKB = maxSpillKB; - } + public void setWrittenRecords(long writtenRecords) { + this.writtenRecords = writtenRecords; + } - public long getMaxSliceKB() { - return maxSliceKB; - } + public long getEncodedSize() { + return encodedSize; + } - public void setMaxSliceKB(long maxSliceSizeKB) { - this.maxSliceKB = maxSliceSizeKB; - } + public void setEncodedSize(long encodedSize) { + this.encodedSize = encodedSize; + } - public long getSpillMs() { - return spillMs; - } + public long getEncodeMs() { + return encodeMs; + } - public void setSpillMs(long spillMs) { - this.spillMs = spillMs; - } + public void setEncodeMs(long encodeMs) { + this.encodeMs = encodeMs; + } - public long getNumChannels() { - return numChannels; - } + public int getSpillNum() { + return spillNum; + } - public void setNumChannels(long numChannels) { - this.numChannels = numChannels; - } - - public long getWrittenChannels() { - return writtenChannels; - } - - public void setWrittenChannels(long writtenChannels) { - this.writtenChannels = writtenChannels; - } - - public long getFlushMs() { - return flushMs; - } - - public void setFlushMs(long flushMs) { - this.flushMs = flushMs; - } - - public void increaseRecords(long recordNum) { - this.writtenRecords += recordNum; - } - - public void increaseEncodedSize(long bytes) { - this.encodedSize += bytes; - } - - public void increaseEncodeMs(long encodeMs) { - this.encodeMs += encodeMs; - } - - public void increaseSpillMs(long spillMs) { - this.spillMs += spillMs; - } - - public void increaseSpillNum() { - this.spillNum++; - } - - public void increaseWrittenChannels() { - writtenChannels++; - } - - public void updateMaxSpillKB(long spillKB) { - if (this.maxSpillKB < spillKB) { - this.maxSpillKB = spillKB; - } - } - - @Override - public String toString() { - return "WriteMetrics{" + "outputRecords=" + writtenRecords + ", encodedKb=" - + encodedSize / 1024 + ", encodeMs=" + encodeMs + ", spillNum=" + spillNum - + ", spillDisk=" + spillDisk + ", oomCnt=" + oomCount + ", spillMs=" + spillMs - + ", maxSpillKB=" + maxSpillKB + ", " + "maxSliceKB=" + maxSliceKB + ", channels=" - + numChannels + ", writtenChannels=" + writtenChannels + '}'; - } + public void setSpillNum(int spillNum) { + this.spillNum = spillNum; + } + public int getSpillDisk() { + return spillDisk; + } + + public void setSpillDisk(int spillDisk) { + this.spillDisk = spillDisk; + } + + public int getOomCount() { + return oomCount; + } + + public void setOomCount(int oomCount) { + this.oomCount = oomCount; + } + + public long getMaxSpillKB() { + return maxSpillKB; + } + + public void setMaxSpillKB(long maxSpillKB) { + this.maxSpillKB = maxSpillKB; + } + + public long getMaxSliceKB() { + return maxSliceKB; + } + + public void setMaxSliceKB(long maxSliceSizeKB) { + this.maxSliceKB = maxSliceSizeKB; + } + + public long getSpillMs() { + return spillMs; + } + + public void setSpillMs(long spillMs) { + this.spillMs = spillMs; + } + + public long getNumChannels() { + return numChannels; + } + + public void setNumChannels(long numChannels) { + this.numChannels = numChannels; + } + + public long getWrittenChannels() { + return writtenChannels; + } + + public void setWrittenChannels(long writtenChannels) { + this.writtenChannels = writtenChannels; + } + + public long getFlushMs() { + return flushMs; + } + + public void setFlushMs(long flushMs) { + this.flushMs = flushMs; + } + + public void increaseRecords(long recordNum) { + this.writtenRecords += recordNum; + } + + public void increaseEncodedSize(long bytes) { + this.encodedSize += bytes; + } + + public void increaseEncodeMs(long encodeMs) { + this.encodeMs += encodeMs; + } + + public void increaseSpillMs(long spillMs) { + this.spillMs += spillMs; + } + + public void increaseSpillNum() { + this.spillNum++; + } + + public void increaseWrittenChannels() { + writtenChannels++; + } + + public void updateMaxSpillKB(long spillKB) { + if (this.maxSpillKB < spillKB) { + this.maxSpillKB = spillKB; + } + } + + @Override + public String toString() { + return "WriteMetrics{" + + "outputRecords=" + + writtenRecords + + ", encodedKb=" + + encodedSize / 1024 + + ", encodeMs=" + + encodeMs + + ", spillNum=" + + spillNum + + ", spillDisk=" + + spillDisk + + ", oomCnt=" + + oomCount + + ", spillMs=" + + spillMs + + ", maxSpillKB=" + + maxSpillKB + + ", " + + "maxSliceKB=" + + maxSliceKB + + ", channels=" + + numChannels + + ", writtenChannels=" + + writtenChannels + + '}'; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/mode/JobMode.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/mode/JobMode.java index 6f2a79fe4..399053717 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/mode/JobMode.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/mode/JobMode.java @@ -19,28 +19,21 @@ package org.apache.geaflow.common.mode; - import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; public enum JobMode { - /** - * compute job mode. - */ - COMPUTE, + /** compute job mode. */ + COMPUTE, - /** - * Olap service job mode. - */ - OLAP_SERVICE, + /** Olap service job mode. */ + OLAP_SERVICE, - /** - * State service job mode. - */ - STATE_SERVICE; + /** State service job mode. */ + STATE_SERVICE; - public static JobMode getJobMode(Configuration configuration) { - String jobMode = configuration.getString(ExecutionConfigKeys.JOB_MODE); - return valueOf(jobMode.toUpperCase()); - } + public static JobMode getJobMode(Configuration configuration) { + String jobMode = configuration.getString(ExecutionConfigKeys.JOB_MODE); + return valueOf(jobMode.toUpperCase()); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/rpc/ConfigurableClientOption.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/rpc/ConfigurableClientOption.java index 9e0f38b72..b9f8f84d4 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/rpc/ConfigurableClientOption.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/rpc/ConfigurableClientOption.java @@ -19,82 +19,82 @@ package org.apache.geaflow.common.rpc; +import org.apache.geaflow.common.config.Configuration; +import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import com.baidu.brpc.client.RpcClientOptions; import com.baidu.brpc.client.channel.ChannelType; import com.baidu.brpc.loadbalance.LoadBalanceStrategy; import com.baidu.brpc.protocol.Options; import com.baidu.brpc.utils.BrpcConstants; -import org.apache.geaflow.common.config.Configuration; -import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class ConfigurableClientOption { - private static final String SHORT_CONNECTION = "short_connection"; - private static final String SINGLE_CONNECTION = "single_connection"; - - private static final Logger LOGGER = LoggerFactory.getLogger(ConfigurableClientOption.class); - - public static RpcClientOptions build(Configuration config) { - RpcClientOptions clientOption = new RpcClientOptions(); - clientOption.setProtocolType(Options.ProtocolType.PROTOCOL_BAIDU_STD_VALUE); - int maxRetryTimes = config.getInteger(ExecutionConfigKeys.RPC_MAX_RETRY_TIMES); - clientOption.setMaxTryTimes(maxRetryTimes); - - // only works for pooled connection - int maxTotalConnectionNum = config.getInteger(ExecutionConfigKeys.RPC_MAX_TOTAL_CONNECTION_NUM); - clientOption.setMaxTotalConnections(maxTotalConnectionNum); - // only works for pooled connection - int minIdleConnectionNum = config.getInteger(ExecutionConfigKeys.RPC_MIN_IDLE_CONNECTION_NUM); - clientOption.setMinIdleConnections(minIdleConnectionNum); - - clientOption.setLoadBalanceType(LoadBalanceStrategy.LOAD_BALANCE_ROUND_ROBIN); - clientOption.setCompressType(Options.CompressType.COMPRESS_TYPE_NONE); - - boolean threadSharing = config.getBoolean(ExecutionConfigKeys.RPC_THREADPOOL_SHARING_ENABLE); - clientOption.setGlobalThreadPoolSharing(threadSharing); - - int ioThreadNum = config.getInteger(ExecutionConfigKeys.RPC_IO_THREAD_NUM); - int workerThreadNum = config.getInteger(ExecutionConfigKeys.RPC_WORKER_THREAD_NUM); - int defaultThreads = Runtime.getRuntime().availableProcessors(); - - clientOption.setIoThreadNum(Math.max(ioThreadNum, defaultThreads)); - clientOption.setWorkThreadNum(Math.max(workerThreadNum, defaultThreads)); - - clientOption.setIoEventType(BrpcConstants.IO_EVENT_NETTY_EPOLL); - int rpcBufferSize = config.getInteger(ExecutionConfigKeys.RPC_BUFFER_SIZE_BYTES); - clientOption.setSendBufferSize(rpcBufferSize); - clientOption.setReceiveBufferSize(rpcBufferSize); - - ChannelType channelType = getChannelType(config); - if (ChannelType.SINGLE_CONNECTION.equals(channelType)) { - // Only SINGLE_CONNECTION type needs to be set keepAliveTime. - int keepAliveTime = config.getInteger(ExecutionConfigKeys.RPC_KEEP_ALIVE_TIME_SEC); - clientOption.setKeepAliveTime(keepAliveTime); - } - clientOption.setChannelType(channelType); - - int writeTimeout = config.getInteger(ExecutionConfigKeys.RPC_WRITE_TIMEOUT_MS); - clientOption.setWriteTimeoutMillis(writeTimeout); - int readTimeout = config.getInteger(ExecutionConfigKeys.RPC_READ_TIMEOUT_MS); - clientOption.setReadTimeoutMillis(readTimeout); - int connectTimeout = config.getInteger(ExecutionConfigKeys.RPC_CONNECT_TIMEOUT_MS); - clientOption.setConnectTimeoutMillis(connectTimeout); - - LOGGER.info("rpc client options set: {}", clientOption); - return clientOption; - } + private static final String SHORT_CONNECTION = "short_connection"; + private static final String SINGLE_CONNECTION = "single_connection"; - private static ChannelType getChannelType(Configuration config) { - String channelType = config.getString(ExecutionConfigKeys.RPC_CHANNEL_CONNECT_TYPE); - if (channelType.equalsIgnoreCase(SHORT_CONNECTION)) { - return ChannelType.SHORT_CONNECTION; - } else if (channelType.equalsIgnoreCase(SINGLE_CONNECTION)) { - return ChannelType.SINGLE_CONNECTION; - } else { - return ChannelType.POOLED_CONNECTION; - } - } + private static final Logger LOGGER = LoggerFactory.getLogger(ConfigurableClientOption.class); + + public static RpcClientOptions build(Configuration config) { + RpcClientOptions clientOption = new RpcClientOptions(); + clientOption.setProtocolType(Options.ProtocolType.PROTOCOL_BAIDU_STD_VALUE); + int maxRetryTimes = config.getInteger(ExecutionConfigKeys.RPC_MAX_RETRY_TIMES); + clientOption.setMaxTryTimes(maxRetryTimes); + + // only works for pooled connection + int maxTotalConnectionNum = config.getInteger(ExecutionConfigKeys.RPC_MAX_TOTAL_CONNECTION_NUM); + clientOption.setMaxTotalConnections(maxTotalConnectionNum); + // only works for pooled connection + int minIdleConnectionNum = config.getInteger(ExecutionConfigKeys.RPC_MIN_IDLE_CONNECTION_NUM); + clientOption.setMinIdleConnections(minIdleConnectionNum); + + clientOption.setLoadBalanceType(LoadBalanceStrategy.LOAD_BALANCE_ROUND_ROBIN); + clientOption.setCompressType(Options.CompressType.COMPRESS_TYPE_NONE); + + boolean threadSharing = config.getBoolean(ExecutionConfigKeys.RPC_THREADPOOL_SHARING_ENABLE); + clientOption.setGlobalThreadPoolSharing(threadSharing); + int ioThreadNum = config.getInteger(ExecutionConfigKeys.RPC_IO_THREAD_NUM); + int workerThreadNum = config.getInteger(ExecutionConfigKeys.RPC_WORKER_THREAD_NUM); + int defaultThreads = Runtime.getRuntime().availableProcessors(); + + clientOption.setIoThreadNum(Math.max(ioThreadNum, defaultThreads)); + clientOption.setWorkThreadNum(Math.max(workerThreadNum, defaultThreads)); + + clientOption.setIoEventType(BrpcConstants.IO_EVENT_NETTY_EPOLL); + int rpcBufferSize = config.getInteger(ExecutionConfigKeys.RPC_BUFFER_SIZE_BYTES); + clientOption.setSendBufferSize(rpcBufferSize); + clientOption.setReceiveBufferSize(rpcBufferSize); + + ChannelType channelType = getChannelType(config); + if (ChannelType.SINGLE_CONNECTION.equals(channelType)) { + // Only SINGLE_CONNECTION type needs to be set keepAliveTime. + int keepAliveTime = config.getInteger(ExecutionConfigKeys.RPC_KEEP_ALIVE_TIME_SEC); + clientOption.setKeepAliveTime(keepAliveTime); + } + clientOption.setChannelType(channelType); + + int writeTimeout = config.getInteger(ExecutionConfigKeys.RPC_WRITE_TIMEOUT_MS); + clientOption.setWriteTimeoutMillis(writeTimeout); + int readTimeout = config.getInteger(ExecutionConfigKeys.RPC_READ_TIMEOUT_MS); + clientOption.setReadTimeoutMillis(readTimeout); + int connectTimeout = config.getInteger(ExecutionConfigKeys.RPC_CONNECT_TIMEOUT_MS); + clientOption.setConnectTimeoutMillis(connectTimeout); + + LOGGER.info("rpc client options set: {}", clientOption); + return clientOption; + } + + private static ChannelType getChannelType(Configuration config) { + String channelType = config.getString(ExecutionConfigKeys.RPC_CHANNEL_CONNECT_TYPE); + if (channelType.equalsIgnoreCase(SHORT_CONNECTION)) { + return ChannelType.SHORT_CONNECTION; + } else if (channelType.equalsIgnoreCase(SINGLE_CONNECTION)) { + return ChannelType.SINGLE_CONNECTION; + } else { + return ChannelType.POOLED_CONNECTION; + } + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/rpc/ConfigurableServerOption.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/rpc/ConfigurableServerOption.java index dfbbe9639..09c8b91d0 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/rpc/ConfigurableServerOption.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/rpc/ConfigurableServerOption.java @@ -19,43 +19,44 @@ package org.apache.geaflow.common.rpc; -import com.baidu.brpc.protocol.Options; -import com.baidu.brpc.server.RpcServerOptions; -import com.baidu.brpc.utils.BrpcConstants; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.baidu.brpc.protocol.Options; +import com.baidu.brpc.server.RpcServerOptions; +import com.baidu.brpc.utils.BrpcConstants; + public class ConfigurableServerOption { - private static final Logger LOGGER = LoggerFactory.getLogger(ConfigurableServerOption.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ConfigurableServerOption.class); - public static RpcServerOptions build(Configuration config) { - RpcServerOptions serverOptions = new RpcServerOptions(); - serverOptions.setProtocolType(Options.ProtocolType.PROTOCOL_BAIDU_STD_VALUE); - int maxRetryTimes = config.getInteger(ExecutionConfigKeys.RPC_MAX_RETRY_TIMES); - serverOptions.setMaxTryTimes(maxRetryTimes); + public static RpcServerOptions build(Configuration config) { + RpcServerOptions serverOptions = new RpcServerOptions(); + serverOptions.setProtocolType(Options.ProtocolType.PROTOCOL_BAIDU_STD_VALUE); + int maxRetryTimes = config.getInteger(ExecutionConfigKeys.RPC_MAX_RETRY_TIMES); + serverOptions.setMaxTryTimes(maxRetryTimes); - int keepAliveTime = config.getInteger(ExecutionConfigKeys.RPC_KEEP_ALIVE_TIME_SEC); - serverOptions.setKeepAliveTime(keepAliveTime); + int keepAliveTime = config.getInteger(ExecutionConfigKeys.RPC_KEEP_ALIVE_TIME_SEC); + serverOptions.setKeepAliveTime(keepAliveTime); - boolean threadSharing = config.getBoolean(ExecutionConfigKeys.RPC_THREADPOOL_SHARING_ENABLE); - serverOptions.setGlobalThreadPoolSharing(threadSharing); + boolean threadSharing = config.getBoolean(ExecutionConfigKeys.RPC_THREADPOOL_SHARING_ENABLE); + serverOptions.setGlobalThreadPoolSharing(threadSharing); - int ioThreadNum = config.getInteger(ExecutionConfigKeys.RPC_IO_THREAD_NUM); - int workerThreadNum = config.getInteger(ExecutionConfigKeys.RPC_WORKER_THREAD_NUM); - int defaultThreads = Math.max(Runtime.getRuntime().availableProcessors(), 8); + int ioThreadNum = config.getInteger(ExecutionConfigKeys.RPC_IO_THREAD_NUM); + int workerThreadNum = config.getInteger(ExecutionConfigKeys.RPC_WORKER_THREAD_NUM); + int defaultThreads = Math.max(Runtime.getRuntime().availableProcessors(), 8); - serverOptions.setIoThreadNum(Math.max(ioThreadNum, defaultThreads)); - serverOptions.setWorkThreadNum(Math.max(workerThreadNum, defaultThreads)); - serverOptions.setIoEventType(BrpcConstants.IO_EVENT_NETTY_EPOLL); + serverOptions.setIoThreadNum(Math.max(ioThreadNum, defaultThreads)); + serverOptions.setWorkThreadNum(Math.max(workerThreadNum, defaultThreads)); + serverOptions.setIoEventType(BrpcConstants.IO_EVENT_NETTY_EPOLL); - int rpcBufferSize = config.getInteger(ExecutionConfigKeys.RPC_BUFFER_SIZE_BYTES); - serverOptions.setSendBufferSize(rpcBufferSize); - serverOptions.setReceiveBufferSize(rpcBufferSize); + int rpcBufferSize = config.getInteger(ExecutionConfigKeys.RPC_BUFFER_SIZE_BYTES); + serverOptions.setSendBufferSize(rpcBufferSize); + serverOptions.setReceiveBufferSize(rpcBufferSize); - LOGGER.info("server options set: {}", serverOptions); - return serverOptions; - } + LOGGER.info("server options set: {}", serverOptions); + return serverOptions; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/rpc/HostAndPort.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/rpc/HostAndPort.java index 6055d436d..c3a00ae29 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/rpc/HostAndPort.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/rpc/HostAndPort.java @@ -19,55 +19,56 @@ package org.apache.geaflow.common.rpc; -import com.google.common.base.Preconditions; import java.io.Serializable; import java.util.Objects; +import com.google.common.base.Preconditions; + public class HostAndPort implements Serializable { - private static final String DELIMITER = "_"; - private final String host; + private static final String DELIMITER = "_"; + private final String host; - private final int port; + private final int port; - public HostAndPort(String host, int port) { - this.host = host; - this.port = port; - } + public HostAndPort(String host, int port) { + this.host = host; + this.port = port; + } - public static HostAndPort of(String hostAndPort) { - String[] strs = hostAndPort.split(DELIMITER); - Preconditions.checkState(strs.length == 2); - return new HostAndPort(strs[0], Integer.parseInt(strs[1])); - } + public static HostAndPort of(String hostAndPort) { + String[] strs = hostAndPort.split(DELIMITER); + Preconditions.checkState(strs.length == 2); + return new HostAndPort(strs[0], Integer.parseInt(strs[1])); + } - public String getHost() { - return host; - } + public String getHost() { + return host; + } - public int getPort() { - return port; - } + public int getPort() { + return port; + } - @Override - public int hashCode() { - return Objects.hash(host, port); - } + @Override + public int hashCode() { + return Objects.hash(host, port); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - HostAndPort that = (HostAndPort) o; - return port == that.port && Objects.equals(host, that.host); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public String toString() { - return host + DELIMITER + port; + if (o == null || getClass() != o.getClass()) { + return false; } + HostAndPort that = (HostAndPort) o; + return port == that.port && Objects.equals(host, that.host); + } + + @Override + public String toString() { + return host + DELIMITER + port; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/schema/Field.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/schema/Field.java index 67402201a..3440e9574 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/schema/Field.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/schema/Field.java @@ -21,235 +21,237 @@ import java.io.Serializable; import java.util.Objects; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; public class Field implements Serializable { - private final String name; - private final IType type; - private final boolean nullable; - protected final Object defaultValue; - - public Field(String name, IType type) { - this(name, type, true, null); - } - - public Field(String name, IType type, boolean nullable, Object defaultValue) { - this.name = name; - this.type = type; - this.nullable = nullable; - this.defaultValue = defaultValue; - } - - public String getName() { - return this.name; - } - - public IType getType() { - return this.type; - } - - public boolean isNullable() { - return this.nullable; - } - - public Object getDefaultValue() { - return this.defaultValue; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - Field field = (Field) o; - return Objects.equals(this.name, field.name) && this.type == field.type; - } - - @Override - public int hashCode() { - return Objects.hash(name, type); + private final String name; + private final IType type; + private final boolean nullable; + protected final Object defaultValue; + + public Field(String name, IType type) { + this(name, type, true, null); + } + + public Field(String name, IType type, boolean nullable, Object defaultValue) { + this.name = name; + this.type = type; + this.nullable = nullable; + this.defaultValue = defaultValue; + } + + public String getName() { + return this.name; + } + + public IType getType() { + return this.type; + } + + public boolean isNullable() { + return this.nullable; + } + + public Object getDefaultValue() { + return this.defaultValue; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Field field = (Field) o; + return Objects.equals(this.name, field.name) && this.type == field.type; + } + + @Override + public int hashCode() { + return Objects.hash(name, type); + } + + @Override + public String toString() { + return "Field{" + + "name='" + + name + + '\'' + + ", type=" + + type + + ", nullable=" + + nullable + + ", defaultValue=" + + defaultValue + + '}'; + } + + public static class ByteField extends Field { + + public ByteField(String name, boolean nullable, Byte defaultValue) { + super(name, Types.BYTE, nullable, defaultValue); } @Override - public String toString() { - return "Field{" - + "name='" + name + '\'' - + ", type=" + type - + ", nullable=" + nullable - + ", defaultValue=" + defaultValue - + '}'; - } - - public static class ByteField extends Field { - - public ByteField(String name, boolean nullable, Byte defaultValue) { - super(name, Types.BYTE, nullable, defaultValue); - } - - @Override - public Byte getDefaultValue() { - return this.defaultValue == null ? null : (Byte) this.defaultValue; - } - - } - - public static class ShortField extends Field { - - public ShortField(String name, boolean nullable, Short defaultValue) { - super(name, Types.SHORT, nullable, defaultValue); - } - - @Override - public Short getDefaultValue() { - return this.defaultValue == null ? null : (Short) this.defaultValue; - } - + public Byte getDefaultValue() { + return this.defaultValue == null ? null : (Byte) this.defaultValue; } + } - public static class IntegerField extends Field { - - public IntegerField(String name, boolean nullable, Integer defaultValue) { - super(name, Types.INTEGER, nullable, defaultValue); - } - - @Override - public Integer getDefaultValue() { - return this.defaultValue == null ? null : (Integer) this.defaultValue; - } + public static class ShortField extends Field { + public ShortField(String name, boolean nullable, Short defaultValue) { + super(name, Types.SHORT, nullable, defaultValue); } - public static class LongField extends Field { - - public LongField(String name, boolean nullable, Long defaultValue) { - super(name, Types.LONG, nullable, defaultValue); - } - - @Override - public Long getDefaultValue() { - return this.defaultValue == null ? null : (Long) this.defaultValue; - } - - } - - public static class BooleanField extends Field { - - public BooleanField(String name, boolean nullable, Boolean defaultValue) { - super(name, Types.BOOLEAN, nullable, defaultValue); - } - - @Override - public Boolean getDefaultValue() { - return this.defaultValue == null ? null : (Boolean) this.defaultValue; - } - - } - - public static class FloatField extends Field { - - public FloatField(String name, boolean nullable, Float defaultValue) { - super(name, Types.FLOAT, nullable, defaultValue); - } - - @Override - public Float getDefaultValue() { - return this.defaultValue == null ? null : (Float) this.defaultValue; - } - - } - - public static class DoubleField extends Field { - - public DoubleField(String name, boolean nullable, Double defaultValue) { - super(name, Types.DOUBLE, nullable, defaultValue); - } - - @Override - public Double getDefaultValue() { - return this.defaultValue == null ? null : (Double) this.defaultValue; - } - + @Override + public Short getDefaultValue() { + return this.defaultValue == null ? null : (Short) this.defaultValue; } + } - public static class StringField extends Field { - - public StringField(String name, boolean nullable, String defaultValue) { - super(name, Types.STRING, nullable, defaultValue); - } - - @Override - public String getDefaultValue() { - return this.defaultValue == null ? null : this.defaultValue.toString(); - } - } + public static class IntegerField extends Field { - public static ByteField newByteField(String fieldName) { - return new ByteField(fieldName, true, null); + public IntegerField(String name, boolean nullable, Integer defaultValue) { + super(name, Types.INTEGER, nullable, defaultValue); } - public static ByteField newByteField(String fieldName, boolean nullable, Byte defaultValue) { - return new ByteField(fieldName, nullable, defaultValue); + @Override + public Integer getDefaultValue() { + return this.defaultValue == null ? null : (Integer) this.defaultValue; } + } - public static ShortField newShortField(String fieldName) { - return new ShortField(fieldName, true, null); - } + public static class LongField extends Field { - public static ShortField newShortField(String fieldName, boolean nullable, Short defaultValue) { - return new ShortField(fieldName, nullable, defaultValue); + public LongField(String name, boolean nullable, Long defaultValue) { + super(name, Types.LONG, nullable, defaultValue); } - public static IntegerField newIntegerField(String fieldName) { - return new IntegerField(fieldName, true, null); + @Override + public Long getDefaultValue() { + return this.defaultValue == null ? null : (Long) this.defaultValue; } + } - public static IntegerField newIntegerField(String fieldName, boolean nullable, Integer defaultValue) { - return new IntegerField(fieldName, nullable, defaultValue); - } + public static class BooleanField extends Field { - public static LongField newLongField(String fieldName) { - return new LongField(fieldName, true, null); + public BooleanField(String name, boolean nullable, Boolean defaultValue) { + super(name, Types.BOOLEAN, nullable, defaultValue); } - public static LongField newLongField(String fieldName, boolean nullable, Long defaultValue) { - return new LongField(fieldName, nullable, defaultValue); + @Override + public Boolean getDefaultValue() { + return this.defaultValue == null ? null : (Boolean) this.defaultValue; } + } - public static BooleanField newBooleanField(String fieldName) { - return new BooleanField(fieldName, true, null); - } + public static class FloatField extends Field { - public static BooleanField newBooleanField(String fieldName, boolean nullable, Boolean defaultValue) { - return new BooleanField(fieldName, nullable, defaultValue); + public FloatField(String name, boolean nullable, Float defaultValue) { + super(name, Types.FLOAT, nullable, defaultValue); } - public static FloatField newFloatField(String fieldName) { - return new FloatField(fieldName, true, null); + @Override + public Float getDefaultValue() { + return this.defaultValue == null ? null : (Float) this.defaultValue; } + } - public static FloatField newFloatField(String fieldName, boolean nullable, Float defaultValue) { - return new FloatField(fieldName, nullable, defaultValue); - } + public static class DoubleField extends Field { - public static DoubleField newDoubleField(String fieldName) { - return new DoubleField(fieldName, true, null); + public DoubleField(String name, boolean nullable, Double defaultValue) { + super(name, Types.DOUBLE, nullable, defaultValue); } - public static DoubleField newDoubleField(String fieldName, boolean nullable, Double defaultValue) { - return new DoubleField(fieldName, nullable, defaultValue); + @Override + public Double getDefaultValue() { + return this.defaultValue == null ? null : (Double) this.defaultValue; } + } - public static StringField newStringField(String fieldName) { - return new StringField(fieldName, true, null); - } + public static class StringField extends Field { - public static StringField newStringField(String fieldName, boolean nullable, String defaultValue) { - return new StringField(fieldName, nullable, defaultValue); + public StringField(String name, boolean nullable, String defaultValue) { + super(name, Types.STRING, nullable, defaultValue); } + @Override + public String getDefaultValue() { + return this.defaultValue == null ? null : this.defaultValue.toString(); + } + } + + public static ByteField newByteField(String fieldName) { + return new ByteField(fieldName, true, null); + } + + public static ByteField newByteField(String fieldName, boolean nullable, Byte defaultValue) { + return new ByteField(fieldName, nullable, defaultValue); + } + + public static ShortField newShortField(String fieldName) { + return new ShortField(fieldName, true, null); + } + + public static ShortField newShortField(String fieldName, boolean nullable, Short defaultValue) { + return new ShortField(fieldName, nullable, defaultValue); + } + + public static IntegerField newIntegerField(String fieldName) { + return new IntegerField(fieldName, true, null); + } + + public static IntegerField newIntegerField( + String fieldName, boolean nullable, Integer defaultValue) { + return new IntegerField(fieldName, nullable, defaultValue); + } + + public static LongField newLongField(String fieldName) { + return new LongField(fieldName, true, null); + } + + public static LongField newLongField(String fieldName, boolean nullable, Long defaultValue) { + return new LongField(fieldName, nullable, defaultValue); + } + + public static BooleanField newBooleanField(String fieldName) { + return new BooleanField(fieldName, true, null); + } + + public static BooleanField newBooleanField( + String fieldName, boolean nullable, Boolean defaultValue) { + return new BooleanField(fieldName, nullable, defaultValue); + } + + public static FloatField newFloatField(String fieldName) { + return new FloatField(fieldName, true, null); + } + + public static FloatField newFloatField(String fieldName, boolean nullable, Float defaultValue) { + return new FloatField(fieldName, nullable, defaultValue); + } + + public static DoubleField newDoubleField(String fieldName) { + return new DoubleField(fieldName, true, null); + } + + public static DoubleField newDoubleField( + String fieldName, boolean nullable, Double defaultValue) { + return new DoubleField(fieldName, nullable, defaultValue); + } + + public static StringField newStringField(String fieldName) { + return new StringField(fieldName, true, null); + } + + public static StringField newStringField( + String fieldName, boolean nullable, String defaultValue) { + return new StringField(fieldName, nullable, defaultValue); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/schema/ISchema.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/schema/ISchema.java index e2ff27252..d6f8d8e76 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/schema/ISchema.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/schema/ISchema.java @@ -24,41 +24,40 @@ public interface ISchema extends Serializable { - /** - * Get schema id. - * - * @return schema id - */ - int getSchemaId(); - - /** - * Get schema name. - * - * @return schema name - */ - String getSchemaName(); - - /** - * Get a field of an index. - * - * @param index field index - * @return field - */ - Field getField(int index); - - /** - * Get a field of a specified field name. - * - * @param fieldName field name - * @return field - */ - Field getField(String fieldName); - - /** - * Get all fields. - * - * @return field array - */ - List getFields(); - + /** + * Get schema id. + * + * @return schema id + */ + int getSchemaId(); + + /** + * Get schema name. + * + * @return schema name + */ + String getSchemaName(); + + /** + * Get a field of an index. + * + * @param index field index + * @return field + */ + Field getField(int index); + + /** + * Get a field of a specified field name. + * + * @param fieldName field name + * @return field + */ + Field getField(String fieldName); + + /** + * Get all fields. + * + * @return field array + */ + List getFields(); } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/schema/SchemaImpl.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/schema/SchemaImpl.java index 1c6b73d7c..7a4b6a7b0 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/schema/SchemaImpl.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/schema/SchemaImpl.java @@ -26,53 +26,52 @@ public class SchemaImpl implements ISchema { - private final int id; - private final String name; - private final List fields; - private final Map fieldMap; + private final int id; + private final String name; + private final List fields; + private final Map fieldMap; - public SchemaImpl(String name, List fields) { - this(0, name, fields); - } - - public SchemaImpl(int id, String name, List fields) { - this.id = id; - this.name = name; - this.fields = Collections.unmodifiableList(fields); - this.fieldMap = Collections.unmodifiableMap(generateFieldMap(fields)); - } + public SchemaImpl(String name, List fields) { + this(0, name, fields); + } - private static Map generateFieldMap(List fields) { - Map map = new HashMap<>(fields.size()); - for (Field field : fields) { - map.put(field.getName(), field); - } - return map; - } + public SchemaImpl(int id, String name, List fields) { + this.id = id; + this.name = name; + this.fields = Collections.unmodifiableList(fields); + this.fieldMap = Collections.unmodifiableMap(generateFieldMap(fields)); + } - @Override - public int getSchemaId() { - return this.id; + private static Map generateFieldMap(List fields) { + Map map = new HashMap<>(fields.size()); + for (Field field : fields) { + map.put(field.getName(), field); } + return map; + } - @Override - public String getSchemaName() { - return this.name; - } + @Override + public int getSchemaId() { + return this.id; + } - @Override - public Field getField(int index) { - return this.fields.get(index); - } + @Override + public String getSchemaName() { + return this.name; + } - @Override - public Field getField(String fieldName) { - return this.fieldMap.get(fieldName); - } + @Override + public Field getField(int index) { + return this.fields.get(index); + } - @Override - public List getFields() { - return this.fields; - } + @Override + public Field getField(String fieldName) { + return this.fieldMap.get(fieldName); + } + @Override + public List getFields() { + return this.fields; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/serialize/ISerializer.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/serialize/ISerializer.java index a9343d356..a4dea2d78 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/serialize/ISerializer.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/serialize/ISerializer.java @@ -25,14 +25,13 @@ public interface ISerializer extends Serializable { - byte[] serialize(Object o); + byte[] serialize(Object o); - Object deserialize(byte[] bytes); + Object deserialize(byte[] bytes); - void serialize(Object o, OutputStream outputStream); + void serialize(Object o, OutputStream outputStream); - T deserialize(InputStream inputStream); + T deserialize(InputStream inputStream); - - T copy(T target); + T copy(T target); } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/serialize/SerializerFactory.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/serialize/SerializerFactory.java index 961802759..0266f6ef0 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/serialize/SerializerFactory.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/serialize/SerializerFactory.java @@ -20,13 +20,14 @@ package org.apache.geaflow.common.serialize; import java.io.Serializable; + import org.apache.geaflow.common.serialize.impl.KryoSerializer; public class SerializerFactory implements Serializable { - private static KryoSerializer kryoSerializer = new KryoSerializer(); + private static KryoSerializer kryoSerializer = new KryoSerializer(); - public static ISerializer getKryoSerializer() { - return kryoSerializer; - } + public static ISerializer getKryoSerializer() { + return kryoSerializer; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/serialize/impl/KryoSerializer.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/serialize/impl/KryoSerializer.java index 860d4fb12..b47d379b1 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/serialize/impl/KryoSerializer.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/serialize/impl/KryoSerializer.java @@ -19,11 +19,28 @@ package org.apache.geaflow.common.serialize.impl; +import java.io.ByteArrayOutputStream; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.apache.geaflow.common.exception.GeaflowRuntimeException; +import org.apache.geaflow.common.serialize.ISerializer; +import org.apache.geaflow.common.serialize.kryo.SubListSerializers4Jdk9; +import org.apache.geaflow.common.utils.ClassUtil; +import org.objenesis.strategy.StdInstantiatorStrategy; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import com.esotericsoftware.kryo.Kryo; import com.esotericsoftware.kryo.Serializer; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; import com.esotericsoftware.kryo.serializers.ClosureSerializer; + import de.javakaffee.kryoserializers.ArraysAsListSerializer; import de.javakaffee.kryoserializers.CollectionsEmptyListSerializer; import de.javakaffee.kryoserializers.CollectionsSingletonListSerializer; @@ -40,274 +57,334 @@ import de.javakaffee.kryoserializers.guava.ReverseListSerializer; import de.javakaffee.kryoserializers.guava.TreeMultimapSerializer; import de.javakaffee.kryoserializers.guava.UnmodifiableNavigableSetSerializer; -import java.io.ByteArrayOutputStream; -import java.io.InputStream; -import java.io.OutputStream; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import org.apache.geaflow.common.exception.GeaflowRuntimeException; -import org.apache.geaflow.common.serialize.ISerializer; -import org.apache.geaflow.common.serialize.kryo.SubListSerializers4Jdk9; -import org.apache.geaflow.common.utils.ClassUtil; -import org.objenesis.strategy.StdInstantiatorStrategy; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class KryoSerializer implements ISerializer { - private static final Logger LOGGER = LoggerFactory.getLogger(KryoSerializer.class); - private static final int INITIAL_BUFFER_SIZE = 4096; - private static List needRegisterClasses; - private static Map registeredSerializers; + private static final Logger LOGGER = LoggerFactory.getLogger(KryoSerializer.class); + private static final int INITIAL_BUFFER_SIZE = 4096; + private static List needRegisterClasses; + private static Map registeredSerializers; - private final ThreadLocal local = new ThreadLocal() { + private final ThreadLocal local = + new ThreadLocal() { @Override protected Kryo initialValue() { - Kryo kryo = new Kryo(); - Kryo.DefaultInstantiatorStrategy is = new Kryo.DefaultInstantiatorStrategy(); - is.setFallbackInstantiatorStrategy(new StdInstantiatorStrategy()); - kryo.setInstantiatorStrategy(is); - - kryo.getFieldSerializerConfig().setOptimizedGenerics(false); - kryo.setRegistrationRequired(false); - - kryo.register(Arrays.asList("").getClass(), new ArraysAsListSerializer()); - kryo.register(Collections.EMPTY_LIST.getClass(), new CollectionsEmptyListSerializer()); - kryo.register(Collections.singletonList("").getClass(), - new CollectionsSingletonListSerializer()); - kryo.register(ClosureSerializer.Closure.class, new ClosureSerializer()); - - ArrayListMultimapSerializer.registerSerializers(kryo); - HashMultimapSerializer.registerSerializers(kryo); - ImmutableListSerializer.registerSerializers(kryo); - ImmutableMapSerializer.registerSerializers(kryo); - ImmutableMultimapSerializer.registerSerializers(kryo); - ImmutableSetSerializer.registerSerializers(kryo); - ImmutableSortedSetSerializer.registerSerializers(kryo); - LinkedHashMultimapSerializer.registerSerializers(kryo); - LinkedListMultimapSerializer.registerSerializers(kryo); - ReverseListSerializer.registerSerializers(kryo); - TreeMultimapSerializer.registerSerializers(kryo); - UnmodifiableNavigableSetSerializer.registerSerializers(kryo); - SubListSerializers4Jdk9.addDefaultSerializers(kryo); - UnmodifiableCollectionsSerializer.registerSerializers(kryo); - - ClassLoader tcl = Thread.currentThread().getContextClassLoader(); - if (tcl != null) { - kryo.setClassLoader(tcl); - } + Kryo kryo = new Kryo(); + Kryo.DefaultInstantiatorStrategy is = new Kryo.DefaultInstantiatorStrategy(); + is.setFallbackInstantiatorStrategy(new StdInstantiatorStrategy()); + kryo.setInstantiatorStrategy(is); - if (registeredSerializers != null) { - for (Map.Entry entry : registeredSerializers.entrySet()) { - LOGGER.info("register class:{} serializer", entry.getKey().getSimpleName()); - kryo.register(entry.getKey(), entry.getValue()); - } - } + kryo.getFieldSerializerConfig().setOptimizedGenerics(false); + kryo.setRegistrationRequired(false); - if (needRegisterClasses != null && needRegisterClasses.size() != 0) { - for (String clazz : needRegisterClasses) { - String[] clazzToId = clazz.trim().split(":"); - if (clazzToId.length != 2) { - throw new GeaflowRuntimeException("invalid clazzToId format:" + clazz); - } - int registerId = Integer.parseInt(clazzToId[1]); - registerClass(kryo, clazzToId[0], registerId); - } - } + kryo.register(Arrays.asList("").getClass(), new ArraysAsListSerializer()); + kryo.register(Collections.EMPTY_LIST.getClass(), new CollectionsEmptyListSerializer()); + kryo.register( + Collections.singletonList("").getClass(), new CollectionsSingletonListSerializer()); + kryo.register(ClosureSerializer.Closure.class, new ClosureSerializer()); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.BinaryRow", 1011); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.DefaultParameterizedPath", - "org.apache.geaflow.dsl.common.data.impl.DefaultParameterizedPath$DefaultParameterizedPathSerializer", 1012); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.DefaultParameterizedRow", - "org.apache.geaflow.dsl.common.data.impl.DefaultParameterizedRow$DefaultParameterizedRowSerializer", 1013); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.DefaultPath", - "org.apache.geaflow.dsl.common.data.impl.DefaultPath$DefaultPathSerializer", 1014); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.DefaultRowKeyWithRequestId", - "org.apache.geaflow.dsl.common.data.impl.DefaultRowKeyWithRequestId$DefaultRowKeyWithRequestIdSerializer", 1015); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.ObjectRow", - "org.apache.geaflow.dsl.common.data.impl.ObjectRow$ObjectRowSerializer", 1016); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.ObjectRowKey", - "org.apache.geaflow.dsl.common.data.impl.ObjectRowKey$ObjectRowKeySerializer", 1017); - - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.BinaryStringEdge", 1018); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.BinaryStringTsEdge", 1019); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.BinaryStringVertex", 1020); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.DoubleEdge", 1021); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.DoubleTsEdge", 1022); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.DoubleVertex", 1023); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.IntEdge", 1024); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.IntTsEdge", 1025); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.IntVertex", 1026); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.LongEdge", 1027); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.LongTsEdge", 1028); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.LongVertex", 1029); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.ObjectEdge", 1030); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.ObjectTsEdge", 1031); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.ObjectVertex", 1032); - - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.data.FieldAlignEdge", - "org.apache.geaflow.dsl.runtime.traversal.data.FieldAlignEdge$FieldAlignEdgeSerializer", 1033); - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.data.FieldAlignPath", - "org.apache.geaflow.dsl.runtime.traversal.data.FieldAlignPath$FieldAlignPathSerializer", 1034); - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.data.FieldAlignVertex", - "org.apache.geaflow.dsl.runtime.traversal.data.FieldAlignVertex$FieldAlignVertexSerializer", 1035); - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.data.IdOnlyVertex", 1036); - - registerClass(kryo, "org.apache.geaflow.dsl.common.data.ParameterizedRow", 1037); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.Path", 1038); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.Row", 1039); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.RowEdge", 1040); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.RowKey", 1041); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.RowKeyWithRequestId", 1042); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.RowVertex", 1043); - registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.ParameterizedPath", 1044); - - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.message.EODMessage", - "org.apache.geaflow.dsl.runtime.traversal.message.EODMessage$EODMessageSerializer", 1045); - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.message.IPathMessage", 1046); - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.message.JoinPathMessage", - "org.apache.geaflow.dsl.runtime.traversal.message.JoinPathMessage$JoinPathMessageSerializer", 1047); - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.message.KeyGroupMessage", 1048); - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.message.KeyGroupMessageImpl", - "org.apache.geaflow.dsl.runtime.traversal.message.KeyGroupMessageImpl$KeyGroupMessageImplSerializer", 1049); - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.message.ParameterRequestMessage", - "org.apache.geaflow.dsl.runtime.traversal.message.ParameterRequestMessage$ParameterRequestMessageSerializer", 1050); - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.message.RequestIsolationMessage", 1051); - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.message.ReturnMessage", 1052); - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.message.ReturnMessageImpl", - "org.apache.geaflow.dsl.runtime.traversal.message.ReturnMessageImpl$ReturnMessageImplSerializer", 1053); - - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.path.AbstractSingleTreePath", 1054); - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.path.AbstractTreePath", 1055); - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.path.EdgeTreePath", - "org.apache.geaflow.dsl.runtime.traversal.path.EdgeTreePath$EdgeTreePathSerializer", 1056); - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.path.EmptyTreePath", 1057); - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.path.ITreePath", 1058); - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.path.ParameterizedTreePath", - "org.apache.geaflow.dsl.runtime.traversal.path.ParameterizedTreePath$ParameterizedTreePathSerializer", 1059); - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.path.SourceEdgeTreePath", - "org.apache.geaflow.dsl.runtime.traversal.path.SourceEdgeTreePath$SourceEdgeTreePathSerializer", 1060); - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.path.SourceVertexTreePath", - "org.apache.geaflow.dsl.runtime.traversal.path.SourceVertexTreePath$SourceVertexTreePathSerializer", 1061); - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.path.UnionTreePath", - "org.apache.geaflow.dsl.runtime.traversal.path.UnionTreePath$UnionTreePathSerializer", 1062); - registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.path.VertexTreePath", - "org.apache.geaflow.dsl.runtime.traversal.path.VertexTreePath$VertexTreePathSerializer", 1063); - - // Register MST algorithm related classes - registerClass(kryo, "org.apache.geaflow.dsl.udf.graph.mst.MSTMessage", 1064); - registerClass(kryo, "org.apache.geaflow.dsl.udf.graph.mst.MSTVertexState", 1065); - registerClass(kryo, "org.apache.geaflow.dsl.udf.graph.mst.MSTEdge", 1066); - registerClass(kryo, "org.apache.geaflow.dsl.udf.graph.mst.MSTMessage$MessageType", 1067); - - // Register binary object classes - registerClass(kryo, "org.apache.geaflow.common.binary.IBinaryObject", 106); - registerClass(kryo, "org.apache.geaflow.common.binary.HeapBinaryObject", 112); - - // Force registration of binary object classes to avoid unregistered class ID errors - try { - Class iBinaryObjectClass = ClassUtil.classForName("org.apache.geaflow.common.binary.IBinaryObject"); - Class heapBinaryObjectClass = ClassUtil.classForName("org.apache.geaflow.common.binary.HeapBinaryObject"); - kryo.register(iBinaryObjectClass, 106); - kryo.register(heapBinaryObjectClass, 112); - LOGGER.debug("Force registered binary object classes with IDs 106 and 112"); - } catch (Exception e) { - LOGGER.warn("Failed to force register binary object classes: {}", e.getMessage()); - } + ArrayListMultimapSerializer.registerSerializers(kryo); + HashMultimapSerializer.registerSerializers(kryo); + ImmutableListSerializer.registerSerializers(kryo); + ImmutableMapSerializer.registerSerializers(kryo); + ImmutableMultimapSerializer.registerSerializers(kryo); + ImmutableSetSerializer.registerSerializers(kryo); + ImmutableSortedSetSerializer.registerSerializers(kryo); + LinkedHashMultimapSerializer.registerSerializers(kryo); + LinkedListMultimapSerializer.registerSerializers(kryo); + ReverseListSerializer.registerSerializers(kryo); + TreeMultimapSerializer.registerSerializers(kryo); + UnmodifiableNavigableSetSerializer.registerSerializers(kryo); + SubListSerializers4Jdk9.addDefaultSerializers(kryo); + UnmodifiableCollectionsSerializer.registerSerializers(kryo); - return kryo; - } - }; - - private void registerClass(Kryo kryo, String className, int kryoId) { - try { - LOGGER.debug("register class:{} id:{}", className, kryoId); - Class clazz = ClassUtil.classForName(className); - kryo.register(clazz, kryoId); - } catch (GeaflowRuntimeException e) { - if (e.getCause() instanceof ClassNotFoundException) { - LOGGER.warn("class not found: {} skip register id:{}", className, kryoId); + ClassLoader tcl = Thread.currentThread().getContextClassLoader(); + if (tcl != null) { + kryo.setClassLoader(tcl); + } + + if (registeredSerializers != null) { + for (Map.Entry entry : registeredSerializers.entrySet()) { + LOGGER.info("register class:{} serializer", entry.getKey().getSimpleName()); + kryo.register(entry.getKey(), entry.getValue()); } - } catch (Throwable e) { - LOGGER.error("error in register class: {} to kryo.", className); - throw new GeaflowRuntimeException(e); - } - } + } - private void registerClass(Kryo kryo, String className, String serializerClassName, int kryoId) { - try { - LOGGER.debug("register class:{} id:{}", className, kryoId); - Class clazz = ClassUtil.classForName(className); - Class serializerClazz = ClassUtil.classForName(serializerClassName); - Serializer serializer = (Serializer) serializerClazz.newInstance(); - kryo.register(clazz, serializer, kryoId); - } catch (GeaflowRuntimeException e) { - if (e.getCause() instanceof ClassNotFoundException) { - LOGGER.warn("class not found: {} skip register id:{}", className, kryoId); + if (needRegisterClasses != null && needRegisterClasses.size() != 0) { + for (String clazz : needRegisterClasses) { + String[] clazzToId = clazz.trim().split(":"); + if (clazzToId.length != 2) { + throw new GeaflowRuntimeException("invalid clazzToId format:" + clazz); + } + int registerId = Integer.parseInt(clazzToId[1]); + registerClass(kryo, clazzToId[0], registerId); } - } catch (Throwable e) { - LOGGER.error("error in register class: {} to kryo.", className); - throw new GeaflowRuntimeException(e); - } - } + } - @Override - public byte[] serialize(Object o) { - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(INITIAL_BUFFER_SIZE); - Output output = new Output(outputStream); - try { - local.get().writeClassAndObject(output, o); - output.flush(); - } finally { - output.clear(); - output.close(); - } - return outputStream.toByteArray(); - } + registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.BinaryRow", 1011); + registerClass( + kryo, + "org.apache.geaflow.dsl.common.data.impl.DefaultParameterizedPath", + "org.apache.geaflow.dsl.common.data.impl.DefaultParameterizedPath$DefaultParameterizedPathSerializer", + 1012); + registerClass( + kryo, + "org.apache.geaflow.dsl.common.data.impl.DefaultParameterizedRow", + "org.apache.geaflow.dsl.common.data.impl.DefaultParameterizedRow$DefaultParameterizedRowSerializer", + 1013); + registerClass( + kryo, + "org.apache.geaflow.dsl.common.data.impl.DefaultPath", + "org.apache.geaflow.dsl.common.data.impl.DefaultPath$DefaultPathSerializer", + 1014); + registerClass( + kryo, + "org.apache.geaflow.dsl.common.data.impl.DefaultRowKeyWithRequestId", + "org.apache.geaflow.dsl.common.data.impl.DefaultRowKeyWithRequestId$DefaultRowKeyWithRequestIdSerializer", + 1015); + registerClass( + kryo, + "org.apache.geaflow.dsl.common.data.impl.ObjectRow", + "org.apache.geaflow.dsl.common.data.impl.ObjectRow$ObjectRowSerializer", + 1016); + registerClass( + kryo, + "org.apache.geaflow.dsl.common.data.impl.ObjectRowKey", + "org.apache.geaflow.dsl.common.data.impl.ObjectRowKey$ObjectRowKeySerializer", + 1017); - @Override - public Object deserialize(byte[] bytes) { - try { - Input input = new Input(bytes); - return local.get().readClassAndObject(input); - } catch (Exception e) { - // Handle Kryo serialization errors by returning null - // This allows the algorithm to create a new state instead of crashing - LOGGER.warn("Failed to deserialize object: {}, returning null", e.getMessage()); - return null; - } - } + registerClass( + kryo, "org.apache.geaflow.dsl.common.data.impl.types.BinaryStringEdge", 1018); + registerClass( + kryo, "org.apache.geaflow.dsl.common.data.impl.types.BinaryStringTsEdge", 1019); + registerClass( + kryo, "org.apache.geaflow.dsl.common.data.impl.types.BinaryStringVertex", 1020); + registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.DoubleEdge", 1021); + registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.DoubleTsEdge", 1022); + registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.DoubleVertex", 1023); + registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.IntEdge", 1024); + registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.IntTsEdge", 1025); + registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.IntVertex", 1026); + registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.LongEdge", 1027); + registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.LongTsEdge", 1028); + registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.LongVertex", 1029); + registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.ObjectEdge", 1030); + registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.ObjectTsEdge", 1031); + registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.types.ObjectVertex", 1032); + + registerClass( + kryo, + "org.apache.geaflow.dsl.runtime.traversal.data.FieldAlignEdge", + "org.apache.geaflow.dsl.runtime.traversal.data.FieldAlignEdge$FieldAlignEdgeSerializer", + 1033); + registerClass( + kryo, + "org.apache.geaflow.dsl.runtime.traversal.data.FieldAlignPath", + "org.apache.geaflow.dsl.runtime.traversal.data.FieldAlignPath$FieldAlignPathSerializer", + 1034); + registerClass( + kryo, + "org.apache.geaflow.dsl.runtime.traversal.data.FieldAlignVertex", + "org.apache.geaflow.dsl.runtime.traversal.data.FieldAlignVertex$FieldAlignVertexSerializer", + 1035); + registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.data.IdOnlyVertex", 1036); + + registerClass(kryo, "org.apache.geaflow.dsl.common.data.ParameterizedRow", 1037); + registerClass(kryo, "org.apache.geaflow.dsl.common.data.Path", 1038); + registerClass(kryo, "org.apache.geaflow.dsl.common.data.Row", 1039); + registerClass(kryo, "org.apache.geaflow.dsl.common.data.RowEdge", 1040); + registerClass(kryo, "org.apache.geaflow.dsl.common.data.RowKey", 1041); + registerClass(kryo, "org.apache.geaflow.dsl.common.data.RowKeyWithRequestId", 1042); + registerClass(kryo, "org.apache.geaflow.dsl.common.data.RowVertex", 1043); + registerClass(kryo, "org.apache.geaflow.dsl.common.data.impl.ParameterizedPath", 1044); + + registerClass( + kryo, + "org.apache.geaflow.dsl.runtime.traversal.message.EODMessage", + "org.apache.geaflow.dsl.runtime.traversal.message.EODMessage$EODMessageSerializer", + 1045); + registerClass( + kryo, "org.apache.geaflow.dsl.runtime.traversal.message.IPathMessage", 1046); + registerClass( + kryo, + "org.apache.geaflow.dsl.runtime.traversal.message.JoinPathMessage", + "org.apache.geaflow.dsl.runtime.traversal.message.JoinPathMessage$JoinPathMessageSerializer", + 1047); + registerClass( + kryo, "org.apache.geaflow.dsl.runtime.traversal.message.KeyGroupMessage", 1048); + registerClass( + kryo, + "org.apache.geaflow.dsl.runtime.traversal.message.KeyGroupMessageImpl", + "org.apache.geaflow.dsl.runtime.traversal.message.KeyGroupMessageImpl$KeyGroupMessageImplSerializer", + 1049); + registerClass( + kryo, + "org.apache.geaflow.dsl.runtime.traversal.message.ParameterRequestMessage", + "org.apache.geaflow.dsl.runtime.traversal.message.ParameterRequestMessage$ParameterRequestMessageSerializer", + 1050); + registerClass( + kryo, + "org.apache.geaflow.dsl.runtime.traversal.message.RequestIsolationMessage", + 1051); + registerClass( + kryo, "org.apache.geaflow.dsl.runtime.traversal.message.ReturnMessage", 1052); + registerClass( + kryo, + "org.apache.geaflow.dsl.runtime.traversal.message.ReturnMessageImpl", + "org.apache.geaflow.dsl.runtime.traversal.message.ReturnMessageImpl$ReturnMessageImplSerializer", + 1053); + + registerClass( + kryo, "org.apache.geaflow.dsl.runtime.traversal.path.AbstractSingleTreePath", 1054); + registerClass( + kryo, "org.apache.geaflow.dsl.runtime.traversal.path.AbstractTreePath", 1055); + registerClass( + kryo, + "org.apache.geaflow.dsl.runtime.traversal.path.EdgeTreePath", + "org.apache.geaflow.dsl.runtime.traversal.path.EdgeTreePath$EdgeTreePathSerializer", + 1056); + registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.path.EmptyTreePath", 1057); + registerClass(kryo, "org.apache.geaflow.dsl.runtime.traversal.path.ITreePath", 1058); + registerClass( + kryo, + "org.apache.geaflow.dsl.runtime.traversal.path.ParameterizedTreePath", + "org.apache.geaflow.dsl.runtime.traversal.path.ParameterizedTreePath$ParameterizedTreePathSerializer", + 1059); + registerClass( + kryo, + "org.apache.geaflow.dsl.runtime.traversal.path.SourceEdgeTreePath", + "org.apache.geaflow.dsl.runtime.traversal.path.SourceEdgeTreePath$SourceEdgeTreePathSerializer", + 1060); + registerClass( + kryo, + "org.apache.geaflow.dsl.runtime.traversal.path.SourceVertexTreePath", + "org.apache.geaflow.dsl.runtime.traversal.path.SourceVertexTreePath$SourceVertexTreePathSerializer", + 1061); + registerClass( + kryo, + "org.apache.geaflow.dsl.runtime.traversal.path.UnionTreePath", + "org.apache.geaflow.dsl.runtime.traversal.path.UnionTreePath$UnionTreePathSerializer", + 1062); + registerClass( + kryo, + "org.apache.geaflow.dsl.runtime.traversal.path.VertexTreePath", + "org.apache.geaflow.dsl.runtime.traversal.path.VertexTreePath$VertexTreePathSerializer", + 1063); + + // Register MST algorithm related classes + registerClass(kryo, "org.apache.geaflow.dsl.udf.graph.mst.MSTMessage", 1064); + registerClass(kryo, "org.apache.geaflow.dsl.udf.graph.mst.MSTVertexState", 1065); + registerClass(kryo, "org.apache.geaflow.dsl.udf.graph.mst.MSTEdge", 1066); + registerClass(kryo, "org.apache.geaflow.dsl.udf.graph.mst.MSTMessage$MessageType", 1067); - @Override - public void serialize(Object o, OutputStream outputStream) { - Output output = new Output(outputStream); - try { - local.get().writeClassAndObject(output, o); - output.flush(); - } finally { - output.clear(); - output.close(); + // Register binary object classes + registerClass(kryo, "org.apache.geaflow.common.binary.IBinaryObject", 106); + registerClass(kryo, "org.apache.geaflow.common.binary.HeapBinaryObject", 112); + + // Force registration of binary object classes to avoid unregistered class ID errors + try { + Class iBinaryObjectClass = + ClassUtil.classForName("org.apache.geaflow.common.binary.IBinaryObject"); + Class heapBinaryObjectClass = + ClassUtil.classForName("org.apache.geaflow.common.binary.HeapBinaryObject"); + kryo.register(iBinaryObjectClass, 106); + kryo.register(heapBinaryObjectClass, 112); + LOGGER.debug("Force registered binary object classes with IDs 106 and 112"); + } catch (Exception e) { + LOGGER.warn("Failed to force register binary object classes: {}", e.getMessage()); + } + + return kryo; } + }; + + private void registerClass(Kryo kryo, String className, int kryoId) { + try { + LOGGER.debug("register class:{} id:{}", className, kryoId); + Class clazz = ClassUtil.classForName(className); + kryo.register(clazz, kryoId); + } catch (GeaflowRuntimeException e) { + if (e.getCause() instanceof ClassNotFoundException) { + LOGGER.warn("class not found: {} skip register id:{}", className, kryoId); + } + } catch (Throwable e) { + LOGGER.error("error in register class: {} to kryo.", className); + throw new GeaflowRuntimeException(e); } + } - @Override - public Object deserialize(InputStream inputStream) { - Input input = new Input(inputStream); - return local.get().readClassAndObject(input); + private void registerClass(Kryo kryo, String className, String serializerClassName, int kryoId) { + try { + LOGGER.debug("register class:{} id:{}", className, kryoId); + Class clazz = ClassUtil.classForName(className); + Class serializerClazz = ClassUtil.classForName(serializerClassName); + Serializer serializer = (Serializer) serializerClazz.newInstance(); + kryo.register(clazz, serializer, kryoId); + } catch (GeaflowRuntimeException e) { + if (e.getCause() instanceof ClassNotFoundException) { + LOGGER.warn("class not found: {} skip register id:{}", className, kryoId); + } + } catch (Throwable e) { + LOGGER.error("error in register class: {} to kryo.", className); + throw new GeaflowRuntimeException(e); } + } - public Kryo getThreadKryo() { - return local.get(); + @Override + public byte[] serialize(Object o) { + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(INITIAL_BUFFER_SIZE); + Output output = new Output(outputStream); + try { + local.get().writeClassAndObject(output, o); + output.flush(); + } finally { + output.clear(); + output.close(); } + return outputStream.toByteArray(); + } - @Override - public T copy(T target) { - return local.get().copy(target); + @Override + public Object deserialize(byte[] bytes) { + try { + Input input = new Input(bytes); + return local.get().readClassAndObject(input); + } catch (Exception e) { + // Handle Kryo serialization errors by returning null + // This allows the algorithm to create a new state instead of crashing + LOGGER.warn("Failed to deserialize object: {}, returning null", e.getMessage()); + return null; } + } - public void clean() { - local.remove(); + @Override + public void serialize(Object o, OutputStream outputStream) { + Output output = new Output(outputStream); + try { + local.get().writeClassAndObject(output, o); + output.flush(); + } finally { + output.clear(); + output.close(); } + } + + @Override + public Object deserialize(InputStream inputStream) { + Input input = new Input(inputStream); + return local.get().readClassAndObject(input); + } + + public Kryo getThreadKryo() { + return local.get(); + } + + @Override + public T copy(T target) { + return local.get().copy(target); + } + + public void clean() { + local.remove(); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/serialize/kryo/SubListSerializers4Jdk9.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/serialize/kryo/SubListSerializers4Jdk9.java index 8add636fc..92f31e63b 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/serialize/kryo/SubListSerializers4Jdk9.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/serialize/kryo/SubListSerializers4Jdk9.java @@ -19,334 +19,331 @@ package org.apache.geaflow.common.serialize.kryo; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.Serializer; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.lang.reflect.Field; import java.util.AbstractList; import java.util.ArrayList; import java.util.LinkedList; import java.util.List; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + /** - * Support JDK 9+ SubListSerializers. - * see https://github.com/magro/kryo-serializers/commit/39e662fc2cb94fb5867af0a8043cdaa5f62e3ef0 - * #diff-103927eece6ef23d20059d8f711a7275d0a0064d25317928c83a025a88675e35 - * Kryo {@link Serializer}s for lists created via {@link List#subList(int, int)}. - * An instance of a serializer can be obtained via {@link #createFor(Class)}, which - * just returns null if the given type is not supported by these - * serializers. + * Support JDK 9+ SubListSerializers. see + * https://github.com/magro/kryo-serializers/commit/39e662fc2cb94fb5867af0a8043cdaa5f62e3ef0 + * #diff-103927eece6ef23d20059d8f711a7275d0a0064d25317928c83a025a88675e35 Kryo {@link Serializer}s + * for lists created via {@link List#subList(int, int)}. An instance of a serializer can be obtained + * via {@link #createFor(Class)}, which just returns null if the given type is not + * supported by these serializers. * * @author Martin Grotzke */ public class SubListSerializers4Jdk9 { - static Class getClass(final String className) { - try { - return Class.forName(className); - } catch (final Exception e) { - throw new RuntimeException(e); - } + static Class getClass(final String className) { + try { + return Class.forName(className); + } catch (final Exception e) { + throw new RuntimeException(e); } + } - static Class getClassOrNull(final String className) { - try { - return Class.forName(className); - } catch (final Exception e) { - return null; - } + static Class getClassOrNull(final String className) { + try { + return Class.forName(className); + } catch (final Exception e) { + return null; + } + } + + // Workaround reference reading, this should be removed sometimes. See also + // https://groups.google.com/d/msg/kryo-users/Eu5V4bxCfws/k-8UQ22y59AJ + private static final Object FAKE_REFERENCE = new Object(); + + /** + * Obtain a serializer for the given sublist type. If the type is not supported null + * is returned. + * + * @param type the class of the sublist. + * @return a serializer instance or null. + */ + @SuppressWarnings("rawtypes") + public static Serializer> createFor(final Class type) { + if (ArrayListSubListSerializer.canSerialize(type)) { + return new ArrayListSubListSerializer(); + } + if (JavaUtilSubListSerializer.canSerialize(type)) { + return new JavaUtilSubListSerializer(); + } + return null; + } + + /** Adds appropriate sublist serializers as default serializers. */ + public static Kryo addDefaultSerializers(Kryo kryo) { + ArrayListSubListSerializer.addDefaultSerializer(kryo); + AbstractListSubListSerializer.addDefaultSerializer(kryo); + JavaUtilSubListSerializer.addDefaultSerializer(kryo); + return kryo; + } + + /** + * Supports sublists created via {@link ArrayList#subList(int, int)} since java7 and {@link + * LinkedList#subList(int, int)} since java9 (openjdk). + */ + private static class SubListSerializer extends Serializer> { + + private Field parentField; + private Field parentOffsetField; + private Field sizeField; + + public SubListSerializer(String subListClassName) { + try { + final Class clazz = Class.forName(subListClassName); + parentField = getParentField(clazz); + parentOffsetField = getOffsetField(clazz); + sizeField = clazz.getDeclaredField("size"); + parentField.setAccessible(true); + parentOffsetField.setAccessible(true); + sizeField.setAccessible(true); + } catch (final Exception e) { + throw new RuntimeException(e); + } } - // Workaround reference reading, this should be removed sometimes. See also - // https://groups.google.com/d/msg/kryo-users/Eu5V4bxCfws/k-8UQ22y59AJ - private static final Object FAKE_REFERENCE = new Object(); + private static Field getParentField(Class clazz) throws NoSuchFieldException { + try { + // java 9+ + return clazz.getDeclaredField("root"); + } catch (NoSuchFieldException e) { + return clazz.getDeclaredField("parent"); + } + } + + private static Field getOffsetField(Class clazz) throws NoSuchFieldException { + try { + // up to jdk8 (which also has an "offset" field (we don't need) - therefore we + // check "parentOffset" first + return clazz.getDeclaredField("parentOffset"); + } catch (NoSuchFieldException e) { + // jdk9+ only has "offset" which is the parent offset + return clazz.getDeclaredField("offset"); + } + } + + @Override + public List read(final Kryo kryo, final Input input, final Class> clazz) { + kryo.reference(FAKE_REFERENCE); + final List list = (List) kryo.readClassAndObject(input); + final int fromIndex = input.readInt(true); + final int toIndex = input.readInt(true); + return list.subList(fromIndex, toIndex); + } + + @Override + public void write(final Kryo kryo, final Output output, final List obj) { + try { + kryo.writeClassAndObject(output, parentField.get(obj)); + final int parentOffset = parentOffsetField.getInt(obj); + final int fromIndex = parentOffset; + output.writeInt(fromIndex, true); + final int toIndex = fromIndex + sizeField.getInt(obj); + output.writeInt(toIndex, true); + } catch (final RuntimeException e) { + // Don't eat and wrap RuntimeExceptions because the ObjectBuffer.write... + // handles SerializationException specifically (resizing the buffer)... + throw e; + } catch (final Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public List copy(final Kryo kryo, final List original) { + kryo.reference(FAKE_REFERENCE); + try { + final List list = (List) parentField.get(original); + final int parentOffset = parentOffsetField.getInt(original); + final int fromIndex = parentOffset; + final int toIndex = fromIndex + sizeField.getInt(original); + return kryo.copy(list).subList(fromIndex, toIndex); + } catch (final Exception e) { + throw new RuntimeException(e); + } + } + } + + /** + * Supports sublists created via {@link ArrayList#subList(int, int)} since java7 (oracle jdk, + * represented by java.util.ArrayList$SubList). + */ + public static class ArrayListSubListSerializer extends Serializer> { + + public static final Class SUBLIST_CLASS = + SubListSerializers4Jdk9.getClassOrNull("java.util.ArrayList$SubList"); + + private final SubListSerializer delegate = new SubListSerializer("java.util.ArrayList$SubList"); /** - * Obtain a serializer for the given sublist type. If the type is not supported - * null is returned. + * Can be used to determine, if the given type can be handled by this serializer. * - * @param type the class of the sublist. - * @return a serializer instance or null. + * @param type the class to check. + * @return true if the given class can be serialized/deserialized by this serializer. */ - @SuppressWarnings("rawtypes") - public static Serializer> createFor(final Class type) { - if (ArrayListSubListSerializer.canSerialize(type)) { - return new ArrayListSubListSerializer(); - } - if (JavaUtilSubListSerializer.canSerialize(type)) { - return new JavaUtilSubListSerializer(); - } - return null; + public static boolean canSerialize(final Class type) { + return SUBLIST_CLASS != null && SUBLIST_CLASS.isAssignableFrom(type); } - /** - * Adds appropriate sublist serializers as default serializers. - */ - public static Kryo addDefaultSerializers(Kryo kryo) { - ArrayListSubListSerializer.addDefaultSerializer(kryo); - AbstractListSubListSerializer.addDefaultSerializer(kryo); - JavaUtilSubListSerializer.addDefaultSerializer(kryo); - return kryo; + public static Kryo addDefaultSerializer(Kryo kryo) { + if (SUBLIST_CLASS != null) { + kryo.addDefaultSerializer(SUBLIST_CLASS, new ArrayListSubListSerializer()); + } + return kryo; } - /** - * Supports sublists created via {@link ArrayList#subList(int, int)} since java7 and {@link - * LinkedList#subList(int, int)} since java9 (openjdk). - */ - private static class SubListSerializer extends Serializer> { - - private Field parentField; - private Field parentOffsetField; - private Field sizeField; - - public SubListSerializer(String subListClassName) { - try { - final Class clazz = Class.forName(subListClassName); - parentField = getParentField(clazz); - parentOffsetField = getOffsetField(clazz); - sizeField = clazz.getDeclaredField("size"); - parentField.setAccessible(true); - parentOffsetField.setAccessible(true); - sizeField.setAccessible(true); - } catch (final Exception e) { - throw new RuntimeException(e); - } - } - - private static Field getParentField(Class clazz) throws NoSuchFieldException { - try { - // java 9+ - return clazz.getDeclaredField("root"); - } catch (NoSuchFieldException e) { - return clazz.getDeclaredField("parent"); - } - } - - private static Field getOffsetField(Class clazz) throws NoSuchFieldException { - try { - // up to jdk8 (which also has an "offset" field (we don't need) - therefore we - // check "parentOffset" first - return clazz.getDeclaredField("parentOffset"); - } catch (NoSuchFieldException e) { - // jdk9+ only has "offset" which is the parent offset - return clazz.getDeclaredField("offset"); - } - } - - @Override - public List read(final Kryo kryo, final Input input, final Class> clazz) { - kryo.reference(FAKE_REFERENCE); - final List list = (List) kryo.readClassAndObject(input); - final int fromIndex = input.readInt(true); - final int toIndex = input.readInt(true); - return list.subList(fromIndex, toIndex); - } - - @Override - public void write(final Kryo kryo, final Output output, final List obj) { - try { - kryo.writeClassAndObject(output, parentField.get(obj)); - final int parentOffset = parentOffsetField.getInt(obj); - final int fromIndex = parentOffset; - output.writeInt(fromIndex, true); - final int toIndex = fromIndex + sizeField.getInt(obj); - output.writeInt(toIndex, true); - } catch (final RuntimeException e) { - // Don't eat and wrap RuntimeExceptions because the ObjectBuffer.write... - // handles SerializationException specifically (resizing the buffer)... - throw e; - } catch (final Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public List copy(final Kryo kryo, final List original) { - kryo.reference(FAKE_REFERENCE); - try { - final List list = (List) parentField.get(original); - final int parentOffset = parentOffsetField.getInt(original); - final int fromIndex = parentOffset; - final int toIndex = fromIndex + sizeField.getInt(original); - return kryo.copy(list).subList(fromIndex, toIndex); - } catch (final Exception e) { - throw new RuntimeException(e); - } - } + @Override + public List read(final Kryo kryo, final Input input, final Class> clazz) { + return delegate.read(kryo, input, clazz); } - /** - * Supports sublists created via {@link ArrayList#subList(int, int)} since java7 (oracle jdk, - * represented by java.util.ArrayList$SubList). - */ - public static class ArrayListSubListSerializer extends Serializer> { - - public static final Class SUBLIST_CLASS = SubListSerializers4Jdk9 - .getClassOrNull("java.util.ArrayList$SubList"); - - private final SubListSerializer delegate = new SubListSerializer( - "java.util.ArrayList$SubList"); - - /** - * Can be used to determine, if the given type can be handled by this serializer. - * - * @param type the class to check. - * @return true if the given class can be serialized/deserialized by this serializer. - */ - public static boolean canSerialize(final Class type) { - return SUBLIST_CLASS != null && SUBLIST_CLASS.isAssignableFrom(type); - } - - public static Kryo addDefaultSerializer(Kryo kryo) { - if (SUBLIST_CLASS != null) { - kryo.addDefaultSerializer(SUBLIST_CLASS, new ArrayListSubListSerializer()); - } - return kryo; - } - - @Override - public List read(final Kryo kryo, final Input input, final Class> clazz) { - return delegate.read(kryo, input, clazz); - } - - @Override - public void write(final Kryo kryo, final Output output, final List obj) { - delegate.write(kryo, output, obj); - } - - @Override - public List copy(final Kryo kryo, final List original) { - return delegate.copy(kryo, original); - } + @Override + public void write(final Kryo kryo, final Output output, final List obj) { + delegate.write(kryo, output, obj); } + @Override + public List copy(final Kryo kryo, final List original) { + return delegate.copy(kryo, original); + } + } + + /** + * Supports sublists created via {@link LinkedList#subList(int, int)} since java9 (oracle jdk, + * represented by java.util.AbstractList$SubList). + */ + public static class AbstractListSubListSerializer extends Serializer> { + + public static final Class SUBLIST_CLASS = + SubListSerializers4Jdk9.getClassOrNull("java.util.AbstractList$SubList"); + + private final SubListSerializer delegate = + new SubListSerializer("java.util.AbstractList$SubList"); + /** - * Supports sublists created via {@link LinkedList#subList(int, int)} since java9 (oracle jdk, - * represented by java.util.AbstractList$SubList). + * Can be used to determine, if the given type can be handled by this serializer. + * + * @param type the class to check. + * @return true if the given class can be serialized/deserialized by this serializer. */ - public static class AbstractListSubListSerializer extends Serializer> { - - public static final Class SUBLIST_CLASS = SubListSerializers4Jdk9 - .getClassOrNull("java.util.AbstractList$SubList"); - - private final SubListSerializer delegate = new SubListSerializer( - "java.util.AbstractList$SubList"); - - /** - * Can be used to determine, if the given type can be handled by this serializer. - * - * @param type the class to check. - * @return true if the given class can be serialized/deserialized by this serializer. - */ - public static boolean canSerialize(final Class type) { - return SUBLIST_CLASS != null && SUBLIST_CLASS.isAssignableFrom(type); - } - - public static Kryo addDefaultSerializer(Kryo kryo) { - if (SUBLIST_CLASS != null) { - kryo.addDefaultSerializer(SUBLIST_CLASS, new AbstractListSubListSerializer()); - } - return kryo; - } - - @Override - public List read(final Kryo kryo, final Input input, final Class> clazz) { - return delegate.read(kryo, input, clazz); - } - - @Override - public void write(final Kryo kryo, final Output output, final List obj) { - delegate.write(kryo, output, obj); - } - - @Override - public List copy(final Kryo kryo, final List original) { - return delegate.copy(kryo, original); - } + public static boolean canSerialize(final Class type) { + return SUBLIST_CLASS != null && SUBLIST_CLASS.isAssignableFrom(type); + } + + public static Kryo addDefaultSerializer(Kryo kryo) { + if (SUBLIST_CLASS != null) { + kryo.addDefaultSerializer(SUBLIST_CLASS, new AbstractListSubListSerializer()); + } + return kryo; + } + + @Override + public List read(final Kryo kryo, final Input input, final Class> clazz) { + return delegate.read(kryo, input, clazz); + } + + @Override + public void write(final Kryo kryo, final Output output, final List obj) { + delegate.write(kryo, output, obj); + } + + @Override + public List copy(final Kryo kryo, final List original) { + return delegate.copy(kryo, original); + } + } + + /** + * Supports sublists created via {@link AbstractList#subList(int, int)}, e.g. LinkedList. In + * oracle jdk such sublists are represented by java.util.SubList. + */ + public static class JavaUtilSubListSerializer extends Serializer> { + + public static final Class SUBLIST_CLASS = + SubListSerializers4Jdk9.getClassOrNull("java.util.SubList"); + + private Field listField; + private Field offsetField; + private Field sizeField; + + public JavaUtilSubListSerializer() { + try { + final Class clazz = Class.forName("java.util.SubList"); + listField = clazz.getDeclaredField("l"); + offsetField = clazz.getDeclaredField("offset"); + sizeField = clazz.getDeclaredField("size"); + listField.setAccessible(true); + offsetField.setAccessible(true); + sizeField.setAccessible(true); + } catch (final Exception e) { + throw new RuntimeException(e); + } } /** - * Supports sublists created via {@link AbstractList#subList(int, int)}, e.g. LinkedList. - * In oracle jdk such sublists are represented by java.util.SubList. + * Can be used to determine, if the given type can be handled by this serializer. + * + * @param type the class to check. + * @return true if the given class can be serialized/deserialized by this serializer. */ - public static class JavaUtilSubListSerializer extends Serializer> { - - public static final Class SUBLIST_CLASS = SubListSerializers4Jdk9 - .getClassOrNull("java.util.SubList"); - - private Field listField; - private Field offsetField; - private Field sizeField; - - public JavaUtilSubListSerializer() { - try { - final Class clazz = Class.forName("java.util.SubList"); - listField = clazz.getDeclaredField("l"); - offsetField = clazz.getDeclaredField("offset"); - sizeField = clazz.getDeclaredField("size"); - listField.setAccessible(true); - offsetField.setAccessible(true); - sizeField.setAccessible(true); - } catch (final Exception e) { - throw new RuntimeException(e); - } - } - - /** - * Can be used to determine, if the given type can be handled by this serializer. - * - * @param type the class to check. - * @return true if the given class can be serialized/deserialized by this serializer. - */ - public static boolean canSerialize(final Class type) { - return SUBLIST_CLASS != null && SUBLIST_CLASS.isAssignableFrom(type); - } - - public static Kryo addDefaultSerializer(Kryo kryo) { - if (SUBLIST_CLASS != null) { - kryo.addDefaultSerializer(SUBLIST_CLASS, new JavaUtilSubListSerializer()); - } - return kryo; - } - - @Override - public List read(final Kryo kryo, final Input input, final Class> clazz) { - kryo.reference(FAKE_REFERENCE); - final List list = (List) kryo.readClassAndObject(input); - final int fromIndex = input.readInt(true); - final int toIndex = input.readInt(true); - return list.subList(fromIndex, toIndex); - } - - @Override - public void write(final Kryo kryo, final Output output, final List obj) { - try { - kryo.writeClassAndObject(output, listField.get(obj)); - final int fromIndex = offsetField.getInt(obj); - output.writeInt(fromIndex, true); - final int toIndex = fromIndex + sizeField.getInt(obj); - output.writeInt(toIndex, true); - } catch (final RuntimeException e) { - // Don't eat and wrap RuntimeExceptions because the ObjectBuffer.write... - // handles SerializationException specifically (resizing the buffer)... - throw e; - } catch (final Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public List copy(final Kryo kryo, final List obj) { - kryo.reference(FAKE_REFERENCE); - try { - final List list = (List) listField.get(obj); - final int fromIndex = offsetField.getInt(obj); - final int toIndex = fromIndex + sizeField.getInt(obj); - return kryo.copy(list).subList(fromIndex, toIndex); - } catch (final Exception e) { - throw new RuntimeException(e); - } - } + public static boolean canSerialize(final Class type) { + return SUBLIST_CLASS != null && SUBLIST_CLASS.isAssignableFrom(type); + } + + public static Kryo addDefaultSerializer(Kryo kryo) { + if (SUBLIST_CLASS != null) { + kryo.addDefaultSerializer(SUBLIST_CLASS, new JavaUtilSubListSerializer()); + } + return kryo; + } + + @Override + public List read(final Kryo kryo, final Input input, final Class> clazz) { + kryo.reference(FAKE_REFERENCE); + final List list = (List) kryo.readClassAndObject(input); + final int fromIndex = input.readInt(true); + final int toIndex = input.readInt(true); + return list.subList(fromIndex, toIndex); + } + + @Override + public void write(final Kryo kryo, final Output output, final List obj) { + try { + kryo.writeClassAndObject(output, listField.get(obj)); + final int fromIndex = offsetField.getInt(obj); + output.writeInt(fromIndex, true); + final int toIndex = fromIndex + sizeField.getInt(obj); + output.writeInt(toIndex, true); + } catch (final RuntimeException e) { + // Don't eat and wrap RuntimeExceptions because the ObjectBuffer.write... + // handles SerializationException specifically (resizing the buffer)... + throw e; + } catch (final Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public List copy(final Kryo kryo, final List obj) { + kryo.reference(FAKE_REFERENCE); + try { + final List list = (List) listField.get(obj); + final int fromIndex = offsetField.getInt(obj); + final int toIndex = fromIndex + sizeField.getInt(obj); + return kryo.copy(list).subList(fromIndex, toIndex); + } catch (final Exception e) { + throw new RuntimeException(e); + } } + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/shuffle/BatchPhase.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/shuffle/BatchPhase.java index 0e49a6e87..998f17b6d 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/shuffle/BatchPhase.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/shuffle/BatchPhase.java @@ -21,17 +21,10 @@ public enum BatchPhase { - /** - * Execute the shuffle process in pull mode. - */ - CLASSIC, - /** - * Read the pre-fetched shuffle data. - */ - PREFETCH_READ, - /** - * Pre-fetch the data to the reduce side. - */ - PREFETCH_WRITE - + /** Execute the shuffle process in pull mode. */ + CLASSIC, + /** Read the pre-fetched shuffle data. */ + PREFETCH_READ, + /** Pre-fetch the data to the reduce side. */ + PREFETCH_WRITE } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/shuffle/DataExchangeMode.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/shuffle/DataExchangeMode.java index 3b0473d24..b0ed64b15 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/shuffle/DataExchangeMode.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/shuffle/DataExchangeMode.java @@ -21,15 +21,12 @@ public enum DataExchangeMode { - /** - * The data exchange is streamed, sender and receiver are online at the same time. - */ - PIPELINE, - - /** - * The data exchange is decoupled. The sender first produces its entire result and - * finishes. After that, the receiver is started and may consume the data. - */ - BATCH + /** The data exchange is streamed, sender and receiver are online at the same time. */ + PIPELINE, + /** + * The data exchange is decoupled. The sender first produces its entire result and finishes. After + * that, the receiver is started and may consume the data. + */ + BATCH } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/shuffle/ShuffleAddress.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/shuffle/ShuffleAddress.java index 6fe117fc5..2f933106f 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/shuffle/ShuffleAddress.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/shuffle/ShuffleAddress.java @@ -24,45 +24,43 @@ public class ShuffleAddress implements Serializable { - /** - * ip address. - */ - private final String host; - private final int port; + /** ip address. */ + private final String host; - public ShuffleAddress(String host, int port) { - this.host = host; - this.port = port; - } + private final int port; - public String host() { - return host; - } + public ShuffleAddress(String host, int port) { + this.host = host; + this.port = port; + } - public int port() { - return port; - } + public String host() { + return host; + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - ShuffleAddress that = (ShuffleAddress) o; - return port == that.port && Objects.equals(host, that.host); - } + public int port() { + return port; + } - @Override - public int hashCode() { - return Objects.hash(host, port); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public String toString() { - return host + ":" + port; + if (o == null || getClass() != o.getClass()) { + return false; } + ShuffleAddress that = (ShuffleAddress) o; + return port == that.port && Objects.equals(host, that.host); + } + + @Override + public int hashCode() { + return Objects.hash(host, port); + } + @Override + public String toString() { + return host + ":" + port; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/shuffle/StorageLevel.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/shuffle/StorageLevel.java index 71ae6fd15..98fe98430 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/shuffle/StorageLevel.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/shuffle/StorageLevel.java @@ -19,25 +19,18 @@ package org.apache.geaflow.common.shuffle; -/** - * Shuffle data storage level. - */ +/** Shuffle data storage level. */ public enum StorageLevel { - /** - * Shuffle data is stored in memory. - */ - MEMORY, - - /** - * Shuffle data are stored on local disks. - */ - DISK, + /** Shuffle data is stored in memory. */ + MEMORY, - /** - * Shuffle data is written to memory first, and if there is insufficient memory, it is then - * written to disk. - */ - MEMORY_AND_DISK + /** Shuffle data are stored on local disks. */ + DISK, + /** + * Shuffle data is written to memory first, and if there is insufficient memory, it is then + * written to disk. + */ + MEMORY_AND_DISK } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/task/TaskArgs.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/task/TaskArgs.java index 37584ea5c..52eb571d2 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/task/TaskArgs.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/task/TaskArgs.java @@ -22,50 +22,54 @@ import java.io.Serializable; /** - * TaskArgs that denotes relevant information of task, including: taskId, taskIndex, taskParallelism and - * maxParallelism, processIndex. + * TaskArgs that denotes relevant information of task, including: taskId, taskIndex, taskParallelism + * and maxParallelism, processIndex. */ public class TaskArgs implements Serializable { - private int taskId; - private int taskIndex; - private String taskName; - private int parallelism; - private int maxParallelism; - private int processIndex; + private int taskId; + private int taskIndex; + private String taskName; + private int parallelism; + private int maxParallelism; + private int processIndex; - public TaskArgs(int taskId, int taskIndex, String taskName, int parallelism, - int maxParallelism, int processIndex) { - this.taskId = taskId; - this.taskIndex = taskIndex; - this.taskName = taskName; - this.parallelism = parallelism; - this.maxParallelism = maxParallelism; - this.processIndex = processIndex; - } + public TaskArgs( + int taskId, + int taskIndex, + String taskName, + int parallelism, + int maxParallelism, + int processIndex) { + this.taskId = taskId; + this.taskIndex = taskIndex; + this.taskName = taskName; + this.parallelism = parallelism; + this.maxParallelism = maxParallelism; + this.processIndex = processIndex; + } - public int getTaskId() { - return taskId; - } + public int getTaskId() { + return taskId; + } - public int getTaskIndex() { - return taskIndex; - } + public int getTaskIndex() { + return taskIndex; + } - public int getParallelism() { - return parallelism; - } + public int getParallelism() { + return parallelism; + } - public int getMaxParallelism() { - return maxParallelism; - } + public int getMaxParallelism() { + return maxParallelism; + } - public String getTaskName() { - return taskName; - } - - public int getProcessIndex() { - return processIndex; - } + public String getTaskName() { + return taskName; + } + public int getProcessIndex() { + return processIndex; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/thread/BoundedExecutor.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/thread/BoundedExecutor.java index 6f9918727..f2fcfa35a 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/thread/BoundedExecutor.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/thread/BoundedExecutor.java @@ -26,87 +26,86 @@ import java.util.concurrent.Semaphore; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class BoundedExecutor extends ThreadPoolExecutor { - private static final Logger LOGGER = LoggerFactory.getLogger(BoundedExecutor.class); - private static final int DEFAULT_KEEP_ALIVE_MINUTES = 30; + private static final Logger LOGGER = LoggerFactory.getLogger(BoundedExecutor.class); + private static final int DEFAULT_KEEP_ALIVE_MINUTES = 30; - private final Semaphore semaphore; - private final int counter; + private final Semaphore semaphore; + private final int counter; - public BoundedExecutor(int bound, int capacity) { - this(bound, capacity, DEFAULT_KEEP_ALIVE_MINUTES, TimeUnit.MINUTES); - } + public BoundedExecutor(int bound, int capacity) { + this(bound, capacity, DEFAULT_KEEP_ALIVE_MINUTES, TimeUnit.MINUTES); + } - public BoundedExecutor(int bound, int capacity, long keepAliveTime, TimeUnit unit) { - super(bound, bound, keepAliveTime, unit, new LinkedBlockingQueue<>(capacity)); - counter = capacity + bound; - semaphore = new Semaphore(counter); - } + public BoundedExecutor(int bound, int capacity, long keepAliveTime, TimeUnit unit) { + super(bound, bound, keepAliveTime, unit, new LinkedBlockingQueue<>(capacity)); + counter = capacity + bound; + semaphore = new Semaphore(counter); + } - @Override - protected void afterExecute(Runnable r, Throwable t) { - super.afterExecute(r, t); - semaphore.release(); - } + @Override + protected void afterExecute(Runnable r, Throwable t) { + super.afterExecute(r, t); + semaphore.release(); + } - public void tryExecute(Runnable command) { - while (true) { - try { - semaphore.acquire(); - super.execute(command); - break; - } catch (RejectedExecutionException e) { - LOGGER.info("reject task, retry to submit"); - semaphore.release(); - continue; - } catch (InterruptedException e) { - LOGGER.error(e.getMessage(), e); - throw new GeaflowRuntimeException(e); - } - } - } - - @Override - public Future submit(Runnable task) { - try { - semaphore.acquire(); - return super.submit(task); - } catch (RejectedExecutionException e) { - LOGGER.error(e.getMessage(), e); - semaphore.release(); - throw e; - } catch (InterruptedException e) { - LOGGER.error(e.getMessage(), e); - throw new GeaflowRuntimeException(e); - } + public void tryExecute(Runnable command) { + while (true) { + try { + semaphore.acquire(); + super.execute(command); + break; + } catch (RejectedExecutionException e) { + LOGGER.info("reject task, retry to submit"); + semaphore.release(); + continue; + } catch (InterruptedException e) { + LOGGER.error(e.getMessage(), e); + throw new GeaflowRuntimeException(e); + } } + } - @Override - public Future submit(Callable task) { - while (true) { - try { - semaphore.acquire(); - return super.submit(task); - } catch (RejectedExecutionException e) { - LOGGER.info("reject task, retry to submit"); - semaphore.release(); - continue; - } catch (InterruptedException e) { - LOGGER.error(e.getMessage(), e); - throw new GeaflowRuntimeException(e); - } - } + @Override + public Future submit(Runnable task) { + try { + semaphore.acquire(); + return super.submit(task); + } catch (RejectedExecutionException e) { + LOGGER.error(e.getMessage(), e); + semaphore.release(); + throw e; + } catch (InterruptedException e) { + LOGGER.error(e.getMessage(), e); + throw new GeaflowRuntimeException(e); } + } - public boolean isEmpty() { - LOGGER.info("current available:{}, counter:{}", - this.semaphore.availablePermits(), counter); - return this.semaphore.availablePermits() == counter; + @Override + public Future submit(Callable task) { + while (true) { + try { + semaphore.acquire(); + return super.submit(task); + } catch (RejectedExecutionException e) { + LOGGER.info("reject task, retry to submit"); + semaphore.release(); + continue; + } catch (InterruptedException e) { + LOGGER.error(e.getMessage(), e); + throw new GeaflowRuntimeException(e); + } } + } + public boolean isEmpty() { + LOGGER.info("current available:{}, counter:{}", this.semaphore.availablePermits(), counter); + return this.semaphore.availablePermits() == counter; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/thread/Executors.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/thread/Executors.java index 20f9bceba..d6acbad6b 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/thread/Executors.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/thread/Executors.java @@ -19,8 +19,6 @@ package org.apache.geaflow.common.thread; -import com.google.common.base.Preconditions; -import com.google.common.util.concurrent.ThreadFactoryBuilder; import java.util.HashMap; import java.util.Map; import java.util.concurrent.ExecutorService; @@ -28,151 +26,167 @@ import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; -import org.apache.commons.lang3.StringUtils; - -public class Executors { - - private static final int DEFAULT_KEEP_ALIVE_MINUTES = 30; - private static final int DEFAULT_QUEUE_CAPACITY = 1024; - private static final int DEFAULT_MAGNIFICATION = 2; - private static final int DEFAULT_MAX_MULTIPLE = 10; - - private static final Map BOUNDED_EXECUTORS = new HashMap<>(); - private static final Map UNBOUNDED_EXECUTORS = new HashMap<>(); - private static final int CORE_NUM = Runtime.getRuntime().availableProcessors(); - - private static String getKey(String type, int bound, int capacity, long keepAliveTime, - TimeUnit unit) { - return String.format("%s%s%s%s%s", type, bound, capacity, keepAliveTime, unit); - } - - public static synchronized ExecutorService getBoundedService(int bound, int capacity, - long keepAliveTime, - TimeUnit unit) { - String key = getKey("bound", bound, capacity, keepAliveTime, unit); - if (BOUNDED_EXECUTORS.get(key) == null) { - BoundedExecutor boundedExecutor = new BoundedExecutor(bound, capacity, keepAliveTime, - unit); - BOUNDED_EXECUTORS.put(key, boundedExecutor); - } - return BOUNDED_EXECUTORS.get(key); - } - - public static synchronized ExecutorService getMaxCoreBoundedService() { - return getMaxCoreBoundedService(DEFAULT_MAGNIFICATION); - } - - public static synchronized ExecutorService getMaxCoreBoundedService(int magnification) { - int cores = Runtime.getRuntime().availableProcessors(); - return getBoundedService(magnification * cores, DEFAULT_QUEUE_CAPACITY, - DEFAULT_KEEP_ALIVE_MINUTES, TimeUnit.MINUTES); - } - public static synchronized ExecutorService getService(int bound, int capacity, - long keepAliveTime, TimeUnit unit) { - String key = getKey("normal", bound, capacity, keepAliveTime, unit); - if (BOUNDED_EXECUTORS.get(key) == null) { - ExecutorService boundedExecutor = new ThreadPoolExecutor(bound, bound, keepAliveTime, - unit, new LinkedBlockingQueue<>(capacity)); - BOUNDED_EXECUTORS.put(key, boundedExecutor); - } - return BOUNDED_EXECUTORS.get(key); - } +import org.apache.commons.lang3.StringUtils; - public static synchronized ExecutorService getMultiCoreExecutorService(int maxMultiple, - double magnification) { - return getExecutorService(maxMultiple, (int) (magnification * CORE_NUM)); - } +import com.google.common.base.Preconditions; +import com.google.common.util.concurrent.ThreadFactoryBuilder; - public static synchronized ExecutorService getExecutorService(int maxMultiple, int coreNumber) { - Preconditions.checkArgument(coreNumber > 0 && coreNumber <= maxMultiple * CORE_NUM, - "executor core not right " + coreNumber + " is greater than " + maxMultiple * CORE_NUM); - return getService(coreNumber, Integer.MAX_VALUE, DEFAULT_KEEP_ALIVE_MINUTES, - TimeUnit.MINUTES); - } +public class Executors { - public static ExecutorService getExecutorService(int coreNumber, - String threadFormat) { - return getExecutorService(DEFAULT_MAX_MULTIPLE, coreNumber, threadFormat, null); + private static final int DEFAULT_KEEP_ALIVE_MINUTES = 30; + private static final int DEFAULT_QUEUE_CAPACITY = 1024; + private static final int DEFAULT_MAGNIFICATION = 2; + private static final int DEFAULT_MAX_MULTIPLE = 10; + + private static final Map BOUNDED_EXECUTORS = new HashMap<>(); + private static final Map UNBOUNDED_EXECUTORS = new HashMap<>(); + private static final int CORE_NUM = Runtime.getRuntime().availableProcessors(); + + private static String getKey( + String type, int bound, int capacity, long keepAliveTime, TimeUnit unit) { + return String.format("%s%s%s%s%s", type, bound, capacity, keepAliveTime, unit); + } + + public static synchronized ExecutorService getBoundedService( + int bound, int capacity, long keepAliveTime, TimeUnit unit) { + String key = getKey("bound", bound, capacity, keepAliveTime, unit); + if (BOUNDED_EXECUTORS.get(key) == null) { + BoundedExecutor boundedExecutor = new BoundedExecutor(bound, capacity, keepAliveTime, unit); + BOUNDED_EXECUTORS.put(key, boundedExecutor); } - - public static ExecutorService getExecutorService(int coreNumber, - String threadFormat, - Thread.UncaughtExceptionHandler handler) { - return getExecutorService(DEFAULT_MAX_MULTIPLE, coreNumber, threadFormat, handler); + return BOUNDED_EXECUTORS.get(key); + } + + public static synchronized ExecutorService getMaxCoreBoundedService() { + return getMaxCoreBoundedService(DEFAULT_MAGNIFICATION); + } + + public static synchronized ExecutorService getMaxCoreBoundedService(int magnification) { + int cores = Runtime.getRuntime().availableProcessors(); + return getBoundedService( + magnification * cores, + DEFAULT_QUEUE_CAPACITY, + DEFAULT_KEEP_ALIVE_MINUTES, + TimeUnit.MINUTES); + } + + public static synchronized ExecutorService getService( + int bound, int capacity, long keepAliveTime, TimeUnit unit) { + String key = getKey("normal", bound, capacity, keepAliveTime, unit); + if (BOUNDED_EXECUTORS.get(key) == null) { + ExecutorService boundedExecutor = + new ThreadPoolExecutor( + bound, bound, keepAliveTime, unit, new LinkedBlockingQueue<>(capacity)); + BOUNDED_EXECUTORS.put(key, boundedExecutor); } - - public static ExecutorService getExecutorService(int maxMultiple, - int coreNumber, - String threadFormat) { - return getExecutorService(maxMultiple, coreNumber, threadFormat, null); + return BOUNDED_EXECUTORS.get(key); + } + + public static synchronized ExecutorService getMultiCoreExecutorService( + int maxMultiple, double magnification) { + return getExecutorService(maxMultiple, (int) (magnification * CORE_NUM)); + } + + public static synchronized ExecutorService getExecutorService(int maxMultiple, int coreNumber) { + Preconditions.checkArgument( + coreNumber > 0 && coreNumber <= maxMultiple * CORE_NUM, + "executor core not right " + coreNumber + " is greater than " + maxMultiple * CORE_NUM); + return getService(coreNumber, Integer.MAX_VALUE, DEFAULT_KEEP_ALIVE_MINUTES, TimeUnit.MINUTES); + } + + public static ExecutorService getExecutorService(int coreNumber, String threadFormat) { + return getExecutorService(DEFAULT_MAX_MULTIPLE, coreNumber, threadFormat, null); + } + + public static ExecutorService getExecutorService( + int coreNumber, String threadFormat, Thread.UncaughtExceptionHandler handler) { + return getExecutorService(DEFAULT_MAX_MULTIPLE, coreNumber, threadFormat, handler); + } + + public static ExecutorService getExecutorService( + int maxMultiple, int coreNumber, String threadFormat) { + return getExecutorService(maxMultiple, coreNumber, threadFormat, null); + } + + /** + * Creates an ExecutorService with following params. + * + * @param maxMultiple Maximum threads multiplier + * @param coreNumber Number of core threads + * @param threadFormat Thread name format + * @param handler Uncaught exception handler + * @return Configured ExecutorService + */ + public static synchronized ExecutorService getExecutorService( + int maxMultiple, + int coreNumber, + String threadFormat, + Thread.UncaughtExceptionHandler handler) { + int maxThreads = maxMultiple * CORE_NUM; + Preconditions.checkArgument( + coreNumber > 0 && coreNumber <= maxThreads, + "executor threads should be smaller than " + maxThreads); + Preconditions.checkArgument( + StringUtils.isNotEmpty(threadFormat), "thread format couldn't" + " be empty"); + return getNamedService( + coreNumber, + Integer.MAX_VALUE, + DEFAULT_KEEP_ALIVE_MINUTES, + TimeUnit.MINUTES, + threadFormat, + handler); + } + + private static synchronized ExecutorService getNamedService( + int bound, + int capacity, + long keepAliveTime, + TimeUnit unit, + String threadFormat, + Thread.UncaughtExceptionHandler handler) { + String key = getKey(threadFormat, bound, capacity, keepAliveTime, unit); + if (BOUNDED_EXECUTORS.get(key) == null || BOUNDED_EXECUTORS.get(key).isShutdown()) { + ThreadFactoryBuilder builder = + new ThreadFactoryBuilder().setNameFormat(threadFormat).setDaemon(true); + if (handler != null) { + builder.setUncaughtExceptionHandler(handler); + } + ExecutorService boundedExecutor = + new ThreadPoolExecutor( + bound, + bound, + keepAliveTime, + unit, + new LinkedBlockingQueue<>(capacity), + builder.build()); + BOUNDED_EXECUTORS.put(key, boundedExecutor); } - - /** - * Creates an ExecutorService with following params. - * - * @param maxMultiple Maximum threads multiplier - * @param coreNumber Number of core threads - * @param threadFormat Thread name format - * @param handler Uncaught exception handler - * @return Configured ExecutorService - */ - public static synchronized ExecutorService getExecutorService(int maxMultiple, - int coreNumber, - String threadFormat, - Thread.UncaughtExceptionHandler handler) { - int maxThreads = maxMultiple * CORE_NUM; - Preconditions.checkArgument(coreNumber > 0 && coreNumber <= maxThreads, - "executor threads should be smaller than " + maxThreads); - Preconditions.checkArgument(StringUtils.isNotEmpty(threadFormat), - "thread format couldn't" + " be empty"); - return getNamedService(coreNumber, Integer.MAX_VALUE, DEFAULT_KEEP_ALIVE_MINUTES, - TimeUnit.MINUTES, threadFormat, handler); + return BOUNDED_EXECUTORS.get(key); + } + + public static synchronized ExecutorService getUnboundedExecutorService( + String name, + long keepAliveTime, + TimeUnit unit, + String threadFormat, + Thread.UncaughtExceptionHandler handler) { + ExecutorService cached = UNBOUNDED_EXECUTORS.get(name); + if (cached != null && !cached.isShutdown()) { + return cached; } - private static synchronized ExecutorService getNamedService(int bound, int capacity, - long keepAliveTime, TimeUnit unit, - String threadFormat, - Thread.UncaughtExceptionHandler handler) { - String key = getKey(threadFormat, bound, capacity, keepAliveTime, unit); - if (BOUNDED_EXECUTORS.get(key) == null || BOUNDED_EXECUTORS.get(key).isShutdown()) { - ThreadFactoryBuilder builder = new ThreadFactoryBuilder() - .setNameFormat(threadFormat) - .setDaemon(true); - if (handler != null) { - builder.setUncaughtExceptionHandler(handler); - } - ExecutorService boundedExecutor = new ThreadPoolExecutor(bound, bound, keepAliveTime, - unit, new LinkedBlockingQueue<>(capacity), builder.build()); - BOUNDED_EXECUTORS.put(key, boundedExecutor); - } - return BOUNDED_EXECUTORS.get(key); + ThreadFactoryBuilder builder = new ThreadFactoryBuilder().setDaemon(true); + if (threadFormat != null) { + builder.setNameFormat(threadFormat); } - - public static synchronized ExecutorService getUnboundedExecutorService(String name, - long keepAliveTime, - TimeUnit unit, - String threadFormat, - Thread.UncaughtExceptionHandler handler) { - ExecutorService cached = UNBOUNDED_EXECUTORS.get(name); - if (cached != null && !cached.isShutdown()) { - return cached; - } - - ThreadFactoryBuilder builder = new ThreadFactoryBuilder() - .setDaemon(true); - if (threadFormat != null) { - builder.setNameFormat(threadFormat); - } - if (handler != null) { - builder.setUncaughtExceptionHandler(handler); - } - ThreadPoolExecutor pool = new ThreadPoolExecutor( - 0, Integer.MAX_VALUE, keepAliveTime, unit, - new SynchronousQueue<>(), builder.build()); - UNBOUNDED_EXECUTORS.put(name, pool); - return pool; - + if (handler != null) { + builder.setUncaughtExceptionHandler(handler); } - + ThreadPoolExecutor pool = + new ThreadPoolExecutor( + 0, Integer.MAX_VALUE, keepAliveTime, unit, new SynchronousQueue<>(), builder.build()); + UNBOUNDED_EXECUTORS.put(name, pool); + return pool; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/tuple/Triple.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/tuple/Triple.java index 56cb86666..b3e5d8237 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/tuple/Triple.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/tuple/Triple.java @@ -21,72 +21,78 @@ import java.io.Serializable; import java.util.Objects; + import org.apache.geaflow.common.utils.StringUtils; public class Triple implements Serializable { - public F0 f0; - - public F1 f1; + public F0 f0; - public F2 f2; + public F1 f1; - public Triple(F0 f0, F1 f1, F2 f2) { - this.f0 = f0; - this.f1 = f1; - this.f2 = f2; - } + public F2 f2; - public static Triple of(F0 f0, F1 f1, F2 f2) { - return new Triple<>(f0, f1, f2); - } + public Triple(F0 f0, F1 f1, F2 f2) { + this.f0 = f0; + this.f1 = f1; + this.f2 = f2; + } - public F0 getF0() { - return f0; - } + public static Triple of(F0 f0, F1 f1, F2 f2) { + return new Triple<>(f0, f1, f2); + } - public void setF0(F0 f0) { - this.f0 = f0; - } + public F0 getF0() { + return f0; + } - public F1 getF1() { - return f1; - } + public void setF0(F0 f0) { + this.f0 = f0; + } - public void setF1(F1 f1) { - this.f1 = f1; - } + public F1 getF1() { + return f1; + } - public F2 getF2() { - return f2; - } + public void setF1(F1 f1) { + this.f1 = f1; + } - public void setF2(F2 f2) { - this.f2 = f2; - } + public F2 getF2() { + return f2; + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof Triple)) { - return false; - } - Triple tuple3 = (Triple) o; - return Objects.equals(f0, tuple3.f0) && Objects.equals(f1, tuple3.f1) - && Objects.equals(f2, tuple3.f2); - } + public void setF2(F2 f2) { + this.f2 = f2; + } - @Override - public int hashCode() { - return Objects.hash(f0, f1, f2); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public String toString() { - return "(" + StringUtils.arrayAwareToString(f0) - + "," + StringUtils.arrayAwareToString(f1) - + "," + StringUtils.arrayAwareToString(f2) + ")"; + if (!(o instanceof Triple)) { + return false; } + Triple tuple3 = (Triple) o; + return Objects.equals(f0, tuple3.f0) + && Objects.equals(f1, tuple3.f1) + && Objects.equals(f2, tuple3.f2); + } + + @Override + public int hashCode() { + return Objects.hash(f0, f1, f2); + } + + @Override + public String toString() { + return "(" + + StringUtils.arrayAwareToString(f0) + + "," + + StringUtils.arrayAwareToString(f1) + + "," + + StringUtils.arrayAwareToString(f2) + + ")"; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/tuple/Tuple.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/tuple/Tuple.java index 4f36f2b7c..dd52b6d51 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/tuple/Tuple.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/tuple/Tuple.java @@ -21,59 +21,63 @@ import java.io.Serializable; import java.util.Objects; + import org.apache.geaflow.common.utils.StringUtils; public class Tuple implements Serializable { - public F0 f0; + public F0 f0; - public F1 f1; + public F1 f1; - public Tuple(F0 f0, F1 f1) { - this.f0 = f0; - this.f1 = f1; - } + public Tuple(F0 f0, F1 f1) { + this.f0 = f0; + this.f1 = f1; + } - public static Tuple of(F0 f0, F1 f1) { - return new Tuple<>(f0, f1); - } + public static Tuple of(F0 f0, F1 f1) { + return new Tuple<>(f0, f1); + } - public F0 getF0() { - return f0; - } + public F0 getF0() { + return f0; + } - public void setF0(F0 f0) { - this.f0 = f0; - } + public void setF0(F0 f0) { + this.f0 = f0; + } - public F1 getF1() { - return f1; - } + public F1 getF1() { + return f1; + } - public void setF1(F1 f1) { - this.f1 = f1; - } + public void setF1(F1 f1) { + this.f1 = f1; + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof Tuple)) { - return false; - } - Tuple tuple = (Tuple) o; - return Objects.equals(f0, tuple.f0) && Objects.equals(f1, tuple.f1); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(f0, f1); + if (!(o instanceof Tuple)) { + return false; } + Tuple tuple = (Tuple) o; + return Objects.equals(f0, tuple.f0) && Objects.equals(f1, tuple.f1); + } - @Override - public String toString() { - return "(" + StringUtils.arrayAwareToString(f0) - + "," + StringUtils.arrayAwareToString(f1) + ")"; - } + @Override + public int hashCode() { + return Objects.hash(f0, f1); + } + + @Override + public String toString() { + return "(" + + StringUtils.arrayAwareToString(f0) + + "," + + StringUtils.arrayAwareToString(f1) + + ")"; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/IType.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/IType.java index 477d58d61..875145342 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/IType.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/IType.java @@ -23,15 +23,15 @@ public interface IType extends Serializable { - String getName(); + String getName(); - Class getTypeClass(); + Class getTypeClass(); - byte[] serialize(T obj); + byte[] serialize(T obj); - T deserialize(byte[] bytes); + T deserialize(byte[] bytes); - int compare(T x, T y); + int compare(T x, T y); - boolean isPrimitive(); + boolean isPrimitive(); } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/Types.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/Types.java index 64481553e..4583c12f2 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/Types.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/Types.java @@ -19,11 +19,11 @@ package org.apache.geaflow.common.type; -import com.google.common.collect.ImmutableMap; import java.math.BigDecimal; import java.sql.Date; import java.sql.Timestamp; import java.util.Locale; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.primitive.BinaryStringType; import org.apache.geaflow.common.type.primitive.BooleanType; @@ -38,101 +38,103 @@ import org.apache.geaflow.common.type.primitive.StringType; import org.apache.geaflow.common.type.primitive.TimestampType; +import com.google.common.collect.ImmutableMap; + public class Types { - public static final String TYPE_NAME_BOOLEAN = "BOOLEAN"; - public static final String TYPE_NAME_BYTE = "BYTE"; - public static final String TYPE_NAME_SHORT = "SHORT"; - public static final String TYPE_NAME_INTEGER = "INTEGER"; - public static final String TYPE_NAME_LONG = "LONG"; - public static final String TYPE_NAME_FLOAT = "FLOAT"; - public static final String TYPE_NAME_DOUBLE = "DOUBLE"; - public static final String TYPE_NAME_STRING = "STRING"; - public static final String TYPE_NAME_STRUCT = "STRUCT"; - public static final String TYPE_NAME_ARRAY = "ARRAY"; - public static final String TYPE_NAME_VERTEX = "VERTEX"; - public static final String TYPE_NAME_EDGE = "EDGE"; - public static final String TYPE_NAME_PATH = "PATH"; - public static final String TYPE_NAME_CLASS = "CLASS"; - public static final String TYPE_NAME_DECIMAL = "DECIMAL"; - public static final String TYPE_NAME_GRAPH = "GRAPH"; - public static final String TYPE_NAME_OBJECT = "OBJECT"; - public static final String TYPE_NAME_BINARY_STRING = "BINARY_STRING"; - public static final String TYPE_NAME_TIMESTAMP = "TIMESTAMP"; - public static final String TYPE_NAME_DATE = "DATE"; + public static final String TYPE_NAME_BOOLEAN = "BOOLEAN"; + public static final String TYPE_NAME_BYTE = "BYTE"; + public static final String TYPE_NAME_SHORT = "SHORT"; + public static final String TYPE_NAME_INTEGER = "INTEGER"; + public static final String TYPE_NAME_LONG = "LONG"; + public static final String TYPE_NAME_FLOAT = "FLOAT"; + public static final String TYPE_NAME_DOUBLE = "DOUBLE"; + public static final String TYPE_NAME_STRING = "STRING"; + public static final String TYPE_NAME_STRUCT = "STRUCT"; + public static final String TYPE_NAME_ARRAY = "ARRAY"; + public static final String TYPE_NAME_VERTEX = "VERTEX"; + public static final String TYPE_NAME_EDGE = "EDGE"; + public static final String TYPE_NAME_PATH = "PATH"; + public static final String TYPE_NAME_CLASS = "CLASS"; + public static final String TYPE_NAME_DECIMAL = "DECIMAL"; + public static final String TYPE_NAME_GRAPH = "GRAPH"; + public static final String TYPE_NAME_OBJECT = "OBJECT"; + public static final String TYPE_NAME_BINARY_STRING = "BINARY_STRING"; + public static final String TYPE_NAME_TIMESTAMP = "TIMESTAMP"; + public static final String TYPE_NAME_DATE = "DATE"; - public static final IType BOOLEAN = BooleanType.INSTANCE; - public static final IType BYTE = ByteType.INSTANCE; - public static final IType SHORT = ShortType.INSTANCE; - public static final IType INTEGER = IntegerType.INSTANCE; - public static final IType LONG = LongType.INSTANCE; - public static final IType FLOAT = FloatType.INSTANCE; - public static final IType DOUBLE = DoubleType.INSTANCE; - public static final IType STRING = StringType.INSTANCE; - public static final IType DECIMAL = DecimalType.INSTANCE; - public static final IType BINARY_STRING = BinaryStringType.INSTANCE; - public static final IType TIMESTAMP = TimestampType.INSTANCE; - public static final IType DATE = DateType.INSTANCE; + public static final IType BOOLEAN = BooleanType.INSTANCE; + public static final IType BYTE = ByteType.INSTANCE; + public static final IType SHORT = ShortType.INSTANCE; + public static final IType INTEGER = IntegerType.INSTANCE; + public static final IType LONG = LongType.INSTANCE; + public static final IType FLOAT = FloatType.INSTANCE; + public static final IType DOUBLE = DoubleType.INSTANCE; + public static final IType STRING = StringType.INSTANCE; + public static final IType DECIMAL = DecimalType.INSTANCE; + public static final IType BINARY_STRING = BinaryStringType.INSTANCE; + public static final IType TIMESTAMP = TimestampType.INSTANCE; + public static final IType DATE = DateType.INSTANCE; - public static final ImmutableMap TYPE_IMMUTABLE_MAP = - ImmutableMap.builder() - .put(BOOLEAN.getTypeClass(), BOOLEAN) - .put(BYTE.getTypeClass(), BYTE) - .put(SHORT.getTypeClass(), SHORT) - .put(INTEGER.getTypeClass(), INTEGER) - .put(LONG.getTypeClass(), LONG) - .put(FLOAT.getTypeClass(), FLOAT) - .put(DOUBLE.getTypeClass(), DOUBLE) - .put(STRING.getTypeClass(), STRING) - .put(DECIMAL.getTypeClass(), DECIMAL) - .put(BINARY_STRING.getTypeClass(), BINARY_STRING) - .put(TIMESTAMP.getTypeClass(), TIMESTAMP) - .put(DATE.getTypeClass(), DATE) - .build(); + public static final ImmutableMap TYPE_IMMUTABLE_MAP = + ImmutableMap.builder() + .put(BOOLEAN.getTypeClass(), BOOLEAN) + .put(BYTE.getTypeClass(), BYTE) + .put(SHORT.getTypeClass(), SHORT) + .put(INTEGER.getTypeClass(), INTEGER) + .put(LONG.getTypeClass(), LONG) + .put(FLOAT.getTypeClass(), FLOAT) + .put(DOUBLE.getTypeClass(), DOUBLE) + .put(STRING.getTypeClass(), STRING) + .put(DECIMAL.getTypeClass(), DECIMAL) + .put(BINARY_STRING.getTypeClass(), BINARY_STRING) + .put(TIMESTAMP.getTypeClass(), TIMESTAMP) + .put(DATE.getTypeClass(), DATE) + .build(); - public static IType getType(Class type) { - return TYPE_IMMUTABLE_MAP.get(type); - } + public static IType getType(Class type) { + return TYPE_IMMUTABLE_MAP.get(type); + } - public static IType of(String typeName, int precision) { - if (typeName == null) { - throw new IllegalArgumentException("typeName is null"); - } - switch (typeName.toUpperCase(Locale.ROOT)) { - case TYPE_NAME_BOOLEAN: - return BOOLEAN; - case TYPE_NAME_BYTE: - return BYTE; - case TYPE_NAME_DOUBLE: - return DOUBLE; - case TYPE_NAME_FLOAT: - return FLOAT; - case TYPE_NAME_INTEGER: - return INTEGER; - case TYPE_NAME_LONG: - return LONG; - case TYPE_NAME_STRING: - return STRING; - case TYPE_NAME_DECIMAL: - return DECIMAL; - case TYPE_NAME_BINARY_STRING: - return new BinaryStringType(precision); - case TYPE_NAME_TIMESTAMP: - return TIMESTAMP; - case TYPE_NAME_DATE: - return DATE; - default: - throw new IllegalArgumentException("Not support typeName: " + typeName); - } + public static IType of(String typeName, int precision) { + if (typeName == null) { + throw new IllegalArgumentException("typeName is null"); + } + switch (typeName.toUpperCase(Locale.ROOT)) { + case TYPE_NAME_BOOLEAN: + return BOOLEAN; + case TYPE_NAME_BYTE: + return BYTE; + case TYPE_NAME_DOUBLE: + return DOUBLE; + case TYPE_NAME_FLOAT: + return FLOAT; + case TYPE_NAME_INTEGER: + return INTEGER; + case TYPE_NAME_LONG: + return LONG; + case TYPE_NAME_STRING: + return STRING; + case TYPE_NAME_DECIMAL: + return DECIMAL; + case TYPE_NAME_BINARY_STRING: + return new BinaryStringType(precision); + case TYPE_NAME_TIMESTAMP: + return TIMESTAMP; + case TYPE_NAME_DATE: + return DATE; + default: + throw new IllegalArgumentException("Not support typeName: " + typeName); } + } - public static int compare(Comparable a, Comparable b) { - if (null == a) { - return b == null ? 0 : -1; - } else if (b == null) { - return 1; - } else { - return a.compareTo(b); - } + public static int compare(Comparable a, Comparable b) { + if (null == a) { + return b == null ? 0 : -1; + } else if (b == null) { + return 1; + } else { + return a.compareTo(b); } + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/BinaryStringType.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/BinaryStringType.java index d9817f073..c57c4f73a 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/BinaryStringType.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/BinaryStringType.java @@ -25,77 +25,75 @@ public class BinaryStringType implements IType { - public static final BinaryStringType INSTANCE = new BinaryStringType(); + public static final BinaryStringType INSTANCE = new BinaryStringType(); - private int precision; + private int precision; - public BinaryStringType() { + public BinaryStringType() {} - } + public BinaryStringType(int precision) { + this.precision = precision; + } - public BinaryStringType(int precision) { - this.precision = precision; - } + public int getPrecision() { + return precision; + } - public int getPrecision() { - return precision; + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; } - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - - if (obj == null) { - return false; - } - - if (getClass() != obj.getClass()) { - return false; - } - - return true; + if (obj == null) { + return false; } - @Override - public String getName() { - return Types.TYPE_NAME_BINARY_STRING; + if (getClass() != obj.getClass()) { + return false; } - @Override - public Class getTypeClass() { - return BinaryString.class; + return true; + } + + @Override + public String getName() { + return Types.TYPE_NAME_BINARY_STRING; + } + + @Override + public Class getTypeClass() { + return BinaryString.class; + } + + @Override + public byte[] serialize(BinaryString obj) { + return obj.getBytes(); + } + + @Override + public BinaryString deserialize(byte[] bytes) { + return BinaryString.fromBytes(bytes); + } + + @Override + public int compare(BinaryString x, BinaryString y) { + if (x == null) { + return y == null ? 0 : -1; + } else if (y == null) { + return 1; + } else { + return x.compareTo(y); } + } - @Override - public byte[] serialize(BinaryString obj) { - return obj.getBytes(); - } + @Override + public boolean isPrimitive() { + return true; + } - @Override - public BinaryString deserialize(byte[] bytes) { - return BinaryString.fromBytes(bytes); - } - - @Override - public int compare(BinaryString x, BinaryString y) { - if (x == null) { - return y == null ? 0 : -1; - } else if (y == null) { - return 1; - } else { - return x.compareTo(y); - } - } - - @Override - public boolean isPrimitive() { - return true; - } - - @Override - public String toString() { - return getName(); - } + @Override + public String toString() { + return getName(); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/BooleanType.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/BooleanType.java index a3c6c8d2e..b6f694f59 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/BooleanType.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/BooleanType.java @@ -26,60 +26,60 @@ public class BooleanType implements IType { - public static final BooleanType INSTANCE = new BooleanType(); + public static final BooleanType INSTANCE = new BooleanType(); - @Override - public String getName() { - return Types.TYPE_NAME_BOOLEAN; - } + @Override + public String getName() { + return Types.TYPE_NAME_BOOLEAN; + } - @Override - public Class getTypeClass() { - return Boolean.class; - } + @Override + public Class getTypeClass() { + return Boolean.class; + } - @Override - public byte[] serialize(Boolean obj) { - if (obj == null) { - return null; - } - if (Boolean.TRUE.equals(obj)) { - return new byte[]{1}; - } - if (Boolean.FALSE.equals(obj)) { - return new byte[]{0}; - } - String msg = "illegal boolean value: " + obj; - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); + @Override + public byte[] serialize(Boolean obj) { + if (obj == null) { + return null; } - - @Override - public Boolean deserialize(byte[] bytes) { - if (bytes == null || bytes.length == 0) { - return null; - } - if (bytes[0] == 1) { - return Boolean.TRUE; - } - if (bytes[0] == 0) { - return Boolean.FALSE; - } - String msg = "illegal boolean value: " + bytes[0]; - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); + if (Boolean.TRUE.equals(obj)) { + return new byte[] {1}; } - - @Override - public int compare(Boolean a, Boolean b) { - return Types.compare(a, b); + if (Boolean.FALSE.equals(obj)) { + return new byte[] {0}; } + String msg = "illegal boolean value: " + obj; + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); + } - @Override - public boolean isPrimitive() { - return true; + @Override + public Boolean deserialize(byte[] bytes) { + if (bytes == null || bytes.length == 0) { + return null; } - - @Override - public String toString() { - return getName(); + if (bytes[0] == 1) { + return Boolean.TRUE; } + if (bytes[0] == 0) { + return Boolean.FALSE; + } + String msg = "illegal boolean value: " + bytes[0]; + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError(msg)); + } + + @Override + public int compare(Boolean a, Boolean b) { + return Types.compare(a, b); + } + + @Override + public boolean isPrimitive() { + return true; + } + + @Override + public String toString() { + return getName(); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/ByteType.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/ByteType.java index abb5b2a60..729fdde4c 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/ByteType.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/ByteType.java @@ -24,40 +24,40 @@ public class ByteType implements IType { - public static final ByteType INSTANCE = new ByteType(); - - @Override - public String getName() { - return Types.TYPE_NAME_BYTE; - } - - @Override - public Class getTypeClass() { - return Byte.class; - } - - @Override - public byte[] serialize(Byte obj) { - return new byte[]{obj}; - } - - @Override - public Byte deserialize(byte[] bytes) { - return bytes[0]; - } - - @Override - public int compare(Byte a, Byte b) { - return Types.compare(a, b); - } - - @Override - public boolean isPrimitive() { - return true; - } - - @Override - public String toString() { - return getName(); - } + public static final ByteType INSTANCE = new ByteType(); + + @Override + public String getName() { + return Types.TYPE_NAME_BYTE; + } + + @Override + public Class getTypeClass() { + return Byte.class; + } + + @Override + public byte[] serialize(Byte obj) { + return new byte[] {obj}; + } + + @Override + public Byte deserialize(byte[] bytes) { + return bytes[0]; + } + + @Override + public int compare(Byte a, Byte b) { + return Types.compare(a, b); + } + + @Override + public boolean isPrimitive() { + return true; + } + + @Override + public String toString() { + return getName(); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/DateType.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/DateType.java index 0e735626c..3a30e1d34 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/DateType.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/DateType.java @@ -19,42 +19,44 @@ package org.apache.geaflow.common.type.primitive; -import com.google.common.primitives.Longs; import java.sql.Date; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; +import com.google.common.primitives.Longs; + public class DateType implements IType { - public static final DateType INSTANCE = new DateType(); - - @Override - public String getName() { - return Types.TYPE_NAME_DATE; - } - - @Override - public Class getTypeClass() { - return Date.class; - } - - @Override - public byte[] serialize(Date obj) { - return Longs.toByteArray(obj.getTime()); - } - - @Override - public Date deserialize(byte[] bytes) { - return new Date(Longs.fromByteArray(bytes)); - } - - @Override - public int compare(Date x, Date y) { - return Long.compare(x.getTime(), y.getTime()); - } - - @Override - public boolean isPrimitive() { - return true; - } + public static final DateType INSTANCE = new DateType(); + + @Override + public String getName() { + return Types.TYPE_NAME_DATE; + } + + @Override + public Class getTypeClass() { + return Date.class; + } + + @Override + public byte[] serialize(Date obj) { + return Longs.toByteArray(obj.getTime()); + } + + @Override + public Date deserialize(byte[] bytes) { + return new Date(Longs.fromByteArray(bytes)); + } + + @Override + public int compare(Date x, Date y) { + return Long.compare(x.getTime(), y.getTime()); + } + + @Override + public boolean isPrimitive() { + return true; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/DecimalType.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/DecimalType.java index 57a1e507a..d131162f5 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/DecimalType.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/DecimalType.java @@ -20,63 +20,64 @@ package org.apache.geaflow.common.type.primitive; import java.math.BigDecimal; + import org.apache.geaflow.common.serialize.SerializerFactory; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; public class DecimalType implements IType { - public static final DecimalType INSTANCE = new DecimalType(); + public static final DecimalType INSTANCE = new DecimalType(); - @Override - public String getName() { - return Types.TYPE_NAME_DECIMAL; - } + @Override + public String getName() { + return Types.TYPE_NAME_DECIMAL; + } - @Override - public Class getTypeClass() { - return BigDecimal.class; - } + @Override + public Class getTypeClass() { + return BigDecimal.class; + } - @Override - public byte[] serialize(BigDecimal obj) { - return SerializerFactory.getKryoSerializer().serialize(obj); - } + @Override + public byte[] serialize(BigDecimal obj) { + return SerializerFactory.getKryoSerializer().serialize(obj); + } - @Override - public BigDecimal deserialize(byte[] bytes) { - return (BigDecimal) SerializerFactory.getKryoSerializer().deserialize(bytes); - } + @Override + public BigDecimal deserialize(byte[] bytes) { + return (BigDecimal) SerializerFactory.getKryoSerializer().deserialize(bytes); + } - @Override - public int compare(BigDecimal a, BigDecimal b) { - return Types.compare(a, b); - } + @Override + public int compare(BigDecimal a, BigDecimal b) { + return Types.compare(a, b); + } - @Override - public boolean isPrimitive() { - return true; - } + @Override + public boolean isPrimitive() { + return true; + } - @Override - public String toString() { - return getName(); - } + @Override + public String toString() { + return getName(); + } - @Override - public int hashCode() { - return getName().hashCode(); - } + @Override + public int hashCode() { + return getName().hashCode(); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - DecimalType decimal = (DecimalType) o; - return this.getName().equals(decimal.getName()); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; } + DecimalType decimal = (DecimalType) o; + return this.getName().equals(decimal.getName()); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/DoubleType.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/DoubleType.java index b630ce62a..b8bcc0827 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/DoubleType.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/DoubleType.java @@ -19,46 +19,47 @@ package org.apache.geaflow.common.type.primitive; -import com.google.common.primitives.Longs; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; +import com.google.common.primitives.Longs; + public class DoubleType implements IType { - public static final DoubleType INSTANCE = new DoubleType(); + public static final DoubleType INSTANCE = new DoubleType(); - @Override - public String getName() { - return Types.TYPE_NAME_DOUBLE; - } + @Override + public String getName() { + return Types.TYPE_NAME_DOUBLE; + } - @Override - public Class getTypeClass() { - return Double.class; - } + @Override + public Class getTypeClass() { + return Double.class; + } - @Override - public byte[] serialize(Double obj) { - return Longs.toByteArray(Double.doubleToLongBits(obj)); - } + @Override + public byte[] serialize(Double obj) { + return Longs.toByteArray(Double.doubleToLongBits(obj)); + } - @Override - public Double deserialize(byte[] bytes) { - return Double.longBitsToDouble(Longs.fromByteArray(bytes)); - } + @Override + public Double deserialize(byte[] bytes) { + return Double.longBitsToDouble(Longs.fromByteArray(bytes)); + } - @Override - public int compare(Double a, Double b) { - return Types.compare(a, b); - } + @Override + public int compare(Double a, Double b) { + return Types.compare(a, b); + } - @Override - public boolean isPrimitive() { - return true; - } + @Override + public boolean isPrimitive() { + return true; + } - @Override - public String toString() { - return getName(); - } + @Override + public String toString() { + return getName(); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/FloatType.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/FloatType.java index 99e22b9b5..8f852c469 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/FloatType.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/FloatType.java @@ -19,46 +19,47 @@ package org.apache.geaflow.common.type.primitive; -import com.google.common.primitives.Ints; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; +import com.google.common.primitives.Ints; + public class FloatType implements IType { - public static final FloatType INSTANCE = new FloatType(); + public static final FloatType INSTANCE = new FloatType(); - @Override - public String getName() { - return Types.TYPE_NAME_FLOAT; - } + @Override + public String getName() { + return Types.TYPE_NAME_FLOAT; + } - @Override - public Class getTypeClass() { - return Float.class; - } + @Override + public Class getTypeClass() { + return Float.class; + } - @Override - public byte[] serialize(Float obj) { - return Ints.toByteArray(Float.floatToIntBits(obj)); - } + @Override + public byte[] serialize(Float obj) { + return Ints.toByteArray(Float.floatToIntBits(obj)); + } - @Override - public Float deserialize(byte[] bytes) { - return Float.intBitsToFloat(Ints.fromByteArray(bytes)); - } + @Override + public Float deserialize(byte[] bytes) { + return Float.intBitsToFloat(Ints.fromByteArray(bytes)); + } - @Override - public int compare(Float a, Float b) { - return Types.compare(a, b); - } + @Override + public int compare(Float a, Float b) { + return Types.compare(a, b); + } - @Override - public boolean isPrimitive() { - return true; - } + @Override + public boolean isPrimitive() { + return true; + } - @Override - public String toString() { - return getName(); - } + @Override + public String toString() { + return getName(); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/IntegerType.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/IntegerType.java index bf48e1972..02510aae9 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/IntegerType.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/IntegerType.java @@ -19,46 +19,47 @@ package org.apache.geaflow.common.type.primitive; -import com.google.common.primitives.Ints; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; +import com.google.common.primitives.Ints; + public class IntegerType implements IType { - public static final IntegerType INSTANCE = new IntegerType(); + public static final IntegerType INSTANCE = new IntegerType(); - @Override - public String getName() { - return Types.TYPE_NAME_INTEGER; - } + @Override + public String getName() { + return Types.TYPE_NAME_INTEGER; + } - @Override - public Class getTypeClass() { - return Integer.class; - } + @Override + public Class getTypeClass() { + return Integer.class; + } - @Override - public byte[] serialize(Integer obj) { - return Ints.toByteArray(obj); - } + @Override + public byte[] serialize(Integer obj) { + return Ints.toByteArray(obj); + } - @Override - public Integer deserialize(byte[] bytes) { - return Ints.fromByteArray(bytes); - } + @Override + public Integer deserialize(byte[] bytes) { + return Ints.fromByteArray(bytes); + } - @Override - public int compare(Integer a, Integer b) { - return Types.compare(a, b); - } + @Override + public int compare(Integer a, Integer b) { + return Types.compare(a, b); + } - @Override - public boolean isPrimitive() { - return true; - } + @Override + public boolean isPrimitive() { + return true; + } - @Override - public String toString() { - return getName(); - } + @Override + public String toString() { + return getName(); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/LongType.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/LongType.java index c35c3bcea..b37711ecd 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/LongType.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/LongType.java @@ -19,46 +19,47 @@ package org.apache.geaflow.common.type.primitive; -import com.google.common.primitives.Longs; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; +import com.google.common.primitives.Longs; + public class LongType implements IType { - public static final LongType INSTANCE = new LongType(); + public static final LongType INSTANCE = new LongType(); - @Override - public String getName() { - return Types.TYPE_NAME_LONG; - } + @Override + public String getName() { + return Types.TYPE_NAME_LONG; + } - @Override - public Class getTypeClass() { - return Long.class; - } + @Override + public Class getTypeClass() { + return Long.class; + } - @Override - public byte[] serialize(Long obj) { - return Longs.toByteArray(obj); - } + @Override + public byte[] serialize(Long obj) { + return Longs.toByteArray(obj); + } - @Override - public Long deserialize(byte[] bytes) { - return Longs.fromByteArray(bytes); - } + @Override + public Long deserialize(byte[] bytes) { + return Longs.fromByteArray(bytes); + } - @Override - public int compare(Long a, Long b) { - return Types.compare(a, b); - } + @Override + public int compare(Long a, Long b) { + return Types.compare(a, b); + } - @Override - public boolean isPrimitive() { - return true; - } + @Override + public boolean isPrimitive() { + return true; + } - @Override - public String toString() { - return getName(); - } + @Override + public String toString() { + return getName(); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/ShortType.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/ShortType.java index 53d197d3a..b5527fd6f 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/ShortType.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/ShortType.java @@ -19,46 +19,47 @@ package org.apache.geaflow.common.type.primitive; -import com.google.common.primitives.Shorts; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; +import com.google.common.primitives.Shorts; + public class ShortType implements IType { - public static final ShortType INSTANCE = new ShortType(); + public static final ShortType INSTANCE = new ShortType(); - @Override - public String getName() { - return Types.TYPE_NAME_SHORT; - } + @Override + public String getName() { + return Types.TYPE_NAME_SHORT; + } - @Override - public Class getTypeClass() { - return Short.class; - } + @Override + public Class getTypeClass() { + return Short.class; + } - @Override - public byte[] serialize(Short obj) { - return Shorts.toByteArray(obj); - } + @Override + public byte[] serialize(Short obj) { + return Shorts.toByteArray(obj); + } - @Override - public Short deserialize(byte[] bytes) { - return Shorts.fromByteArray(bytes); - } + @Override + public Short deserialize(byte[] bytes) { + return Shorts.fromByteArray(bytes); + } - @Override - public int compare(Short a, Short b) { - return Types.compare(a, b); - } + @Override + public int compare(Short a, Short b) { + return Types.compare(a, b); + } - @Override - public boolean isPrimitive() { - return true; - } + @Override + public boolean isPrimitive() { + return true; + } - @Override - public String toString() { - return getName(); - } + @Override + public String toString() { + return getName(); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/StringType.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/StringType.java index f0c54d4a9..367552cdc 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/StringType.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/StringType.java @@ -24,40 +24,40 @@ public class StringType implements IType { - public static final StringType INSTANCE = new StringType(); - - @Override - public String getName() { - return Types.TYPE_NAME_STRING; - } - - @Override - public Class getTypeClass() { - return String.class; - } - - @Override - public byte[] serialize(String obj) { - return obj.getBytes(); - } - - @Override - public String deserialize(byte[] bytes) { - return new String(bytes); - } - - @Override - public int compare(String a, String b) { - return Types.compare(a, b); - } - - @Override - public boolean isPrimitive() { - return true; - } - - @Override - public String toString() { - return getName(); - } + public static final StringType INSTANCE = new StringType(); + + @Override + public String getName() { + return Types.TYPE_NAME_STRING; + } + + @Override + public Class getTypeClass() { + return String.class; + } + + @Override + public byte[] serialize(String obj) { + return obj.getBytes(); + } + + @Override + public String deserialize(byte[] bytes) { + return new String(bytes); + } + + @Override + public int compare(String a, String b) { + return Types.compare(a, b); + } + + @Override + public boolean isPrimitive() { + return true; + } + + @Override + public String toString() { + return getName(); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/TimestampType.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/TimestampType.java index 1cb39890e..1ba789527 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/TimestampType.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/type/primitive/TimestampType.java @@ -19,48 +19,50 @@ package org.apache.geaflow.common.type.primitive; -import com.google.common.primitives.Longs; import java.sql.Timestamp; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; +import com.google.common.primitives.Longs; + public class TimestampType implements IType { - public static final TimestampType INSTANCE = new TimestampType(); + public static final TimestampType INSTANCE = new TimestampType(); - @Override - public String getName() { - return Types.TYPE_NAME_TIMESTAMP; - } + @Override + public String getName() { + return Types.TYPE_NAME_TIMESTAMP; + } - @Override - public Class getTypeClass() { - return Timestamp.class; - } + @Override + public Class getTypeClass() { + return Timestamp.class; + } - @Override - public byte[] serialize(Timestamp obj) { - return Longs.toByteArray(obj.getTime()); - } + @Override + public byte[] serialize(Timestamp obj) { + return Longs.toByteArray(obj.getTime()); + } - @Override - public Timestamp deserialize(byte[] bytes) { - long time = Longs.fromByteArray(bytes); - return new Timestamp(time); - } + @Override + public Timestamp deserialize(byte[] bytes) { + long time = Longs.fromByteArray(bytes); + return new Timestamp(time); + } - @Override - public int compare(Timestamp x, Timestamp y) { - return Types.compare(x, y); - } + @Override + public int compare(Timestamp x, Timestamp y) { + return Types.compare(x, y); + } - @Override - public boolean isPrimitive() { - return true; - } + @Override + public boolean isPrimitive() { + return true; + } - @Override - public String toString() { - return getName(); - } + @Override + public String toString() { + return getName(); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ArrayUtil.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ArrayUtil.java index 6a55c827d..1b8e612ea 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ArrayUtil.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ArrayUtil.java @@ -19,7 +19,6 @@ package org.apache.geaflow.common.utils; -import com.google.common.collect.Sets; import java.lang.reflect.Array; import java.util.ArrayList; import java.util.Arrays; @@ -27,125 +26,127 @@ import java.util.List; import java.util.Set; +import com.google.common.collect.Sets; + public class ArrayUtil { - public static int[] toIntArray(Collection list) { - if (list == null) { - return null; - } - int[] intArray = new int[list.size()]; - int i = 0; - for (Integer e : list) { - intArray[i++] = e; - } - return intArray; - } - - public static long[] toLongArray(Collection list) { - if (list == null) { - return null; - } - long[] longArray = new long[list.size()]; - int i = 0; - for (Long e : list) { - longArray[i++] = e; - } - return longArray; - } - - public static List toList(int[] array) { - if (array == null) { - return null; - } - List list = new ArrayList<>(); - for (int i : array) { - list.add(i); - } - return list; - } - - public static int indexOf(long[] longs, long value) { - if (longs == null) { - return -1; - } - int index = -1; - for (int i = 0; i < longs.length; i++) { - if (longs[i] == value) { - index = i; - break; - } - } - return index; - } - - public static long[] grow(long[] longs, int growSize) { - long[] newArray; - if (longs == null) { - newArray = new long[growSize]; - } else { - newArray = new long[longs.length + growSize]; - System.arraycopy(longs, 0, newArray, 0, longs.length); - } - return newArray; - } - - public static long[] copy(long[] longs) { - if (longs == null) { - return null; - } - return Arrays.copyOf(longs, longs.length); - } - - @SuppressWarnings("unchecked") - public static T[] concat(T[] a, T[] b) { - if (a == null) { - return b; - } - if (b == null) { - return a; - } - T[] c = (T[]) Array.newInstance(a.getClass().getComponentType(), a.length + b.length); - System.arraycopy(a, 0, c, 0, a.length); - System.arraycopy(b, 0, c, a.length, b.length); - return c; - } - - public static List castList(List list) { - if (list == null) { - return null; - } - List outList = new ArrayList<>(list.size()); - for (IN in : list) { - outList.add((OUT) in); - } - return outList; - } - - public static boolean isEmpty(Collection collection) { - return collection == null || collection.isEmpty(); - } - - public static boolean isEmpty(int[] array) { - return array == null || array.length == 0; - } - - public static Set copySet(Set set) { - if (set == null) { - return null; - } - return Sets.newHashSet(set); - } - - public static Object[] concatArray(Object[] array1, Object[] array2) { - if (array1 == null) { - return array2; - } - if (array2 == null) { - return array1; - } - Object[] concat = new Object[array1.length + array2.length]; - System.arraycopy(array1, 0, concat, 0, array1.length); - System.arraycopy(array2, 0, concat, array1.length, array2.length); - return concat; + public static int[] toIntArray(Collection list) { + if (list == null) { + return null; + } + int[] intArray = new int[list.size()]; + int i = 0; + for (Integer e : list) { + intArray[i++] = e; + } + return intArray; + } + + public static long[] toLongArray(Collection list) { + if (list == null) { + return null; + } + long[] longArray = new long[list.size()]; + int i = 0; + for (Long e : list) { + longArray[i++] = e; + } + return longArray; + } + + public static List toList(int[] array) { + if (array == null) { + return null; + } + List list = new ArrayList<>(); + for (int i : array) { + list.add(i); + } + return list; + } + + public static int indexOf(long[] longs, long value) { + if (longs == null) { + return -1; + } + int index = -1; + for (int i = 0; i < longs.length; i++) { + if (longs[i] == value) { + index = i; + break; + } + } + return index; + } + + public static long[] grow(long[] longs, int growSize) { + long[] newArray; + if (longs == null) { + newArray = new long[growSize]; + } else { + newArray = new long[longs.length + growSize]; + System.arraycopy(longs, 0, newArray, 0, longs.length); + } + return newArray; + } + + public static long[] copy(long[] longs) { + if (longs == null) { + return null; + } + return Arrays.copyOf(longs, longs.length); + } + + @SuppressWarnings("unchecked") + public static T[] concat(T[] a, T[] b) { + if (a == null) { + return b; + } + if (b == null) { + return a; + } + T[] c = (T[]) Array.newInstance(a.getClass().getComponentType(), a.length + b.length); + System.arraycopy(a, 0, c, 0, a.length); + System.arraycopy(b, 0, c, a.length, b.length); + return c; + } + + public static List castList(List list) { + if (list == null) { + return null; + } + List outList = new ArrayList<>(list.size()); + for (IN in : list) { + outList.add((OUT) in); + } + return outList; + } + + public static boolean isEmpty(Collection collection) { + return collection == null || collection.isEmpty(); + } + + public static boolean isEmpty(int[] array) { + return array == null || array.length == 0; + } + + public static Set copySet(Set set) { + if (set == null) { + return null; + } + return Sets.newHashSet(set); + } + + public static Object[] concatArray(Object[] array1, Object[] array2) { + if (array1 == null) { + return array2; + } + if (array2 == null) { + return array1; } + Object[] concat = new Object[array1.length + array2.length]; + System.arraycopy(array1, 0, concat, 0, array1.length); + System.arraycopy(array2, 0, concat, array1.length, array2.length); + return concat; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/CheckpointUtil.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/CheckpointUtil.java index ce7ab8d30..f9cc30a61 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/CheckpointUtil.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/CheckpointUtil.java @@ -21,7 +21,7 @@ public class CheckpointUtil { - public static boolean needDoCheckpoint(long batchId, long checkpointDuration) { - return batchId % checkpointDuration == 0; - } + public static boolean needDoCheckpoint(long batchId, long checkpointDuration) { + return batchId % checkpointDuration == 0; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ClassUtil.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ClassUtil.java index 7de0ce157..d548a41d5 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ClassUtil.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ClassUtil.java @@ -23,44 +23,44 @@ public class ClassUtil { - public static Class classForName(String className) { - return classForName(className, true); - } + public static Class classForName(String className) { + return classForName(className, true); + } - public static Class classForName(String className, boolean initialize) { - try { - return (Class) Class.forName(className, initialize, getClassLoader()); - } catch (ClassNotFoundException e) { - throw new GeaflowRuntimeException(e); - } + public static Class classForName(String className, boolean initialize) { + try { + return (Class) Class.forName(className, initialize, getClassLoader()); + } catch (ClassNotFoundException e) { + throw new GeaflowRuntimeException(e); } + } - public static Class classForName(String className, ClassLoader classLoader) { - try { - return (Class) Class.forName(className, true, classLoader); - } catch (ClassNotFoundException e) { - throw new GeaflowRuntimeException("fail to load class:" + className, e); - } + public static Class classForName(String className, ClassLoader classLoader) { + try { + return (Class) Class.forName(className, true, classLoader); + } catch (ClassNotFoundException e) { + throw new GeaflowRuntimeException("fail to load class:" + className, e); } + } - public static ClassLoader getClassLoader() { - ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); - if (classLoader == null) { - classLoader = ClassUtil.class.getClassLoader(); - } - return classLoader; + public static ClassLoader getClassLoader() { + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); + if (classLoader == null) { + classLoader = ClassUtil.class.getClassLoader(); } + return classLoader; + } - public static O newInstance(Class clazz) { - try { - return clazz.newInstance(); - } catch (InstantiationException | IllegalAccessException e) { - throw new GeaflowRuntimeException("fail to create instance for: " + clazz, e); - } + public static O newInstance(Class clazz) { + try { + return clazz.newInstance(); + } catch (InstantiationException | IllegalAccessException e) { + throw new GeaflowRuntimeException("fail to create instance for: " + clazz, e); } + } - public static O newInstance(String className) { - Class clazz = classForName(className, Thread.currentThread().getContextClassLoader()); - return newInstance(clazz); - } + public static O newInstance(String className) { + Class clazz = classForName(className, Thread.currentThread().getContextClassLoader()); + return newInstance(clazz); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/CollectionUtil.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/CollectionUtil.java index 90517a83d..4382fec99 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/CollectionUtil.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/CollectionUtil.java @@ -24,10 +24,10 @@ public class CollectionUtil { - public static List emptyIfNull(List list) { - if (list == null) { - return Collections.emptyList(); - } - return list; + public static List emptyIfNull(List list) { + if (list == null) { + return Collections.emptyList(); } + return list; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/CustomClassLoader.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/CustomClassLoader.java index 91564918e..949548ff6 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/CustomClassLoader.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/CustomClassLoader.java @@ -24,12 +24,11 @@ public class CustomClassLoader extends URLClassLoader { - public CustomClassLoader(URL[] urls) { - this(urls, CustomClassLoader.class.getClassLoader()); - } - - public CustomClassLoader(URL[] urls, ClassLoader loader) { - super(urls, loader); - } + public CustomClassLoader(URL[] urls) { + this(urls, CustomClassLoader.class.getClassLoader()); + } + public CustomClassLoader(URL[] urls, ClassLoader loader) { + super(urls, loader); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/DateTimeUtil.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/DateTimeUtil.java index f361042c3..fe4bb0574 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/DateTimeUtil.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/DateTimeUtil.java @@ -24,16 +24,16 @@ public class DateTimeUtil { - public static long toUnixTime(String dateStr, String format) { - if (dateStr == null || dateStr.isEmpty()) { - return -1; - } else { - DateTimeFormatter dateTimeFormat = DateTimeFormat.forPattern(format); - return dateTimeFormat.parseMillis(dateStr); - } + public static long toUnixTime(String dateStr, String format) { + if (dateStr == null || dateStr.isEmpty()) { + return -1; + } else { + DateTimeFormatter dateTimeFormat = DateTimeFormat.forPattern(format); + return dateTimeFormat.parseMillis(dateStr); } + } - public static String fromUnixTime(long unixTime, String format) { - return DateTimeFormat.forPattern(format).print(unixTime); - } + public static String fromUnixTime(long unixTime, String format) { + return DateTimeFormat.forPattern(format).print(unixTime); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ExecutorUtil.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ExecutorUtil.java index d68a8d74c..be3ad5eb3 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ExecutorUtil.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ExecutorUtil.java @@ -22,51 +22,52 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class ExecutorUtil { - private static final Logger LOGGER = LoggerFactory.getLogger(ExecutorUtil.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ExecutorUtil.class); - private static final int SHUTDOWN_TIMEOUT_MS = 1000; + private static final int SHUTDOWN_TIMEOUT_MS = 1000; - public static void execute(ExecutorService service, Runnable command, ExceptionHandler handler) { - service.execute(() -> { - try { - command.run(); - } catch (Throwable throwable) { - handler.handle(throwable); - } + public static void execute(ExecutorService service, Runnable command, ExceptionHandler handler) { + service.execute( + () -> { + try { + command.run(); + } catch (Throwable throwable) { + handler.handle(throwable); + } }); - } + } - public static void spinLockMs(Supplier condition, Runnable checkFun, long ms) { - while (!condition.get()) { - SleepUtils.sleepMilliSecond(ms); - checkFun.run(); - } + public static void spinLockMs(Supplier condition, Runnable checkFun, long ms) { + while (!condition.get()) { + SleepUtils.sleepMilliSecond(ms); + checkFun.run(); } + } - public static void shutdown(ExecutorService executorService, long timeout, TimeUnit timeUnit) { - LOGGER.info("shutdown executor service {}", executorService); - executorService.shutdown(); - try { - if (!executorService.awaitTermination(timeout, timeUnit)) { - LOGGER.info("shutdown executor service force"); - executorService.shutdownNow(); - } - } catch (InterruptedException e) { - LOGGER.warn("Interrupted when shutdown executor service", e); - } + public static void shutdown(ExecutorService executorService, long timeout, TimeUnit timeUnit) { + LOGGER.info("shutdown executor service {}", executorService); + executorService.shutdown(); + try { + if (!executorService.awaitTermination(timeout, timeUnit)) { + LOGGER.info("shutdown executor service force"); + executorService.shutdownNow(); + } + } catch (InterruptedException e) { + LOGGER.warn("Interrupted when shutdown executor service", e); } + } - public static void shutdown(ExecutorService executorService) { - shutdown(executorService, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS); - } - - public interface ExceptionHandler { - void handle(Throwable exp); - } + public static void shutdown(ExecutorService executorService) { + shutdown(executorService, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS); + } + public interface ExceptionHandler { + void handle(Throwable exp); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/FileUtil.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/FileUtil.java index 80f05c9b3..6581ce0d3 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/FileUtil.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/FileUtil.java @@ -19,44 +19,45 @@ package org.apache.geaflow.common.utils; -import com.google.common.base.Joiner; import java.io.BufferedReader; import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStreamReader; +import com.google.common.base.Joiner; + public class FileUtil { - public static String concatPath(String baseDir, String fileName) { - if (baseDir == null || fileName == null) { - throw new NullPointerException(); - } - if (baseDir.endsWith("/")) { - return baseDir + fileName; - } - return baseDir + "/" + fileName; + public static String concatPath(String baseDir, String fileName) { + if (baseDir == null || fileName == null) { + throw new NullPointerException(); } - - public static String constitutePath(String... args) { - return File.separator + Joiner.on(File.separator).join(args); + if (baseDir.endsWith("/")) { + return baseDir + fileName; } + return baseDir + "/" + fileName; + } + + public static String constitutePath(String... args) { + return File.separator + Joiner.on(File.separator).join(args); + } - public static String getContentFromFile(String filePath) { - File file = new File(filePath); - if (file.exists()) { - StringBuilder content = new StringBuilder(); - String line; - try (BufferedReader reader = new BufferedReader( - new InputStreamReader(new FileInputStream(file)))) { - while ((line = reader.readLine()) != null) { - content.append(line).append(System.lineSeparator()); - } - } catch (IOException e) { - throw new RuntimeException("Error read file content.", e); - } - return content.toString(); + public static String getContentFromFile(String filePath) { + File file = new File(filePath); + if (file.exists()) { + StringBuilder content = new StringBuilder(); + String line; + try (BufferedReader reader = + new BufferedReader(new InputStreamReader(new FileInputStream(file)))) { + while ((line = reader.readLine()) != null) { + content.append(line).append(System.lineSeparator()); } - return null; + } catch (IOException e) { + throw new RuntimeException("Error read file content.", e); + } + return content.toString(); } + return null; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/FutureUtil.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/FutureUtil.java index dabca4b0f..082b77a53 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/FutureUtil.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/FutureUtil.java @@ -24,32 +24,33 @@ import java.util.List; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class FutureUtil { - public static List wait(List> futureList) { - return wait(futureList, 0); - } + public static List wait(List> futureList) { + return wait(futureList, 0); + } - public static List wait(List> futureList, int timeoutMs) { - return wait(futureList, timeoutMs, TimeUnit.MILLISECONDS); - } + public static List wait(List> futureList, int timeoutMs) { + return wait(futureList, timeoutMs, TimeUnit.MILLISECONDS); + } - public static List wait(Collection> futureList, long timeout, - TimeUnit timeUnit) { - List result = new ArrayList<>(); - for (Future future : futureList) { - try { - if (timeout > 0) { - result.add(future.get(timeout, timeUnit)); - } else { - result.add(future.get()); - } - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + public static List wait( + Collection> futureList, long timeout, TimeUnit timeUnit) { + List result = new ArrayList<>(); + for (Future future : futureList) { + try { + if (timeout > 0) { + result.add(future.get(timeout, timeUnit)); + } else { + result.add(future.get()); } - return result; + } catch (Exception e) { + throw new GeaflowRuntimeException(e); + } } + return result; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/GcUtil.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/GcUtil.java index 592923d9d..031fe0a30 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/GcUtil.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/GcUtil.java @@ -25,12 +25,12 @@ public class GcUtil { - public static long computeCurrentTotalGcTime() { - long totalGcTime = 0; - List beans = ManagementFactory.getGarbageCollectorMXBeans(); - for (GarbageCollectorMXBean bean : beans) { - totalGcTime += bean.getCollectionTime(); - } - return totalGcTime; + public static long computeCurrentTotalGcTime() { + long totalGcTime = 0; + List beans = ManagementFactory.getGarbageCollectorMXBeans(); + for (GarbageCollectorMXBean bean : beans) { + totalGcTime += bean.getCollectionTime(); } + return totalGcTime; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/GsonUtil.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/GsonUtil.java index e2ab26ff9..5c35ecb43 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/GsonUtil.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/GsonUtil.java @@ -19,24 +19,23 @@ package org.apache.geaflow.common.utils; -import com.google.gson.Gson; -import com.google.gson.reflect.TypeToken; import java.lang.reflect.Type; import java.util.Map; -public class GsonUtil { +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; - private static final Type TYPE = new TypeToken>() { - }.getType(); +public class GsonUtil { - public static Map parse(String json) { - Gson gson = new Gson(); - return gson.fromJson(json, TYPE); - } + private static final Type TYPE = new TypeToken>() {}.getType(); - public static String toJson(Map config) { - Gson gson = new Gson(); - return gson.toJson(config, TYPE); - } + public static Map parse(String json) { + Gson gson = new Gson(); + return gson.fromJson(json, TYPE); + } + public static String toJson(Map config) { + Gson gson = new Gson(); + return gson.toJson(config, TYPE); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/IdGenerator.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/IdGenerator.java index 154723c1f..730fa8ed8 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/IdGenerator.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/IdGenerator.java @@ -20,88 +20,84 @@ package org.apache.geaflow.common.utils; /** - * Twitter-like snowflake id generator. - * Id is composed of: - * timestamp | containerId | sequence - * 41 | 16 | 6 - * timestamp - 41 bits (millisecond precision w/ a custom epoch gives us 69 years) - * containerId - 16 bits - gives us up to 65536 workers - * sequence number - 6 bits - rolls over every 64 per machine (with protection to avoid - * rollover in the same ms) + * Twitter-like snowflake id generator. Id is composed of: timestamp | containerId | sequence 41 | + * 16 | 6 timestamp - 41 bits (millisecond precision w/ a custom epoch gives us 69 years) + * containerId - 16 bits - gives us up to 65536 workers sequence number - 6 bits - rolls over every + * 64 per machine (with protection to avoid rollover in the same ms) */ public class IdGenerator { - /** - * Start timestamp.(2022-01-01 00:00:00.000) - */ - private static final long START_EPOCH = 1611158400000L; + /** Start timestamp.(2022-01-01 00:00:00.000) */ + private static final long START_EPOCH = 1611158400000L; - private final long sequenceBits = 6L; - private final int containerIdBits = 16; + private final long sequenceBits = 6L; + private final int containerIdBits = 16; - private final long maxContainerId = -1L ^ (-1L << containerIdBits); - private final long containerIdShift = sequenceBits; - private final long timestampLeftShift = sequenceBits + containerIdBits; + private final long maxContainerId = -1L ^ (-1L << containerIdBits); + private final long containerIdShift = sequenceBits; + private final long timestampLeftShift = sequenceBits + containerIdBits; - private final long sequenceMask = -1L ^ (-1L << sequenceBits); + private final long sequenceMask = -1L ^ (-1L << sequenceBits); - private long containerId; + private long containerId; - private long lastTimestamp = -1L; - private long sequence = 0L; + private long lastTimestamp = -1L; + private long sequence = 0L; - /** - * Id generator. - * @param containerId (0~65534). - */ - public IdGenerator(long containerId) { - if (containerId > maxContainerId || containerId < 0) { - throw new IllegalArgumentException( - String.format("worker Id can't be greater than %d or less than 0", maxContainerId)); - } - this.containerId = containerId; + /** + * Id generator. + * + * @param containerId (0~65534). + */ + public IdGenerator(long containerId) { + if (containerId > maxContainerId || containerId < 0) { + throw new IllegalArgumentException( + String.format("worker Id can't be greater than %d or less than 0", maxContainerId)); } - - /** - * Generate next Id. - * - * @return - */ - public synchronized long nextId() { - long timestamp = currentTimeMillis(); - - if (timestamp < lastTimestamp) { - throw new RuntimeException( - String.format("Clock moved backwards. Refusing to generate id for %d milliseconds", - lastTimestamp - timestamp)); - } - - if (lastTimestamp == timestamp) { - sequence = (sequence + 1) & sequenceMask; - if (sequence == 0) { - // blocking till next millis second - timestamp = tilNextMillis(lastTimestamp); - } - } else { - sequence = 0L; - } - - lastTimestamp = timestamp; - - return ((timestamp - START_EPOCH) << timestampLeftShift) - | (containerId << containerIdShift) - | sequence; + this.containerId = containerId; + } + + /** + * Generate next Id. + * + * @return + */ + public synchronized long nextId() { + long timestamp = currentTimeMillis(); + + if (timestamp < lastTimestamp) { + throw new RuntimeException( + String.format( + "Clock moved backwards. Refusing to generate id for %d milliseconds", + lastTimestamp - timestamp)); } - protected long tilNextMillis(long lastTimestamp) { - long timestamp = currentTimeMillis(); - while (timestamp <= lastTimestamp) { - timestamp = currentTimeMillis(); - } - return timestamp; + if (lastTimestamp == timestamp) { + sequence = (sequence + 1) & sequenceMask; + if (sequence == 0) { + // blocking till next millis second + timestamp = tilNextMillis(lastTimestamp); + } + } else { + sequence = 0L; } - protected long currentTimeMillis() { - return System.currentTimeMillis(); + lastTimestamp = timestamp; + + return ((timestamp - START_EPOCH) << timestampLeftShift) + | (containerId << containerIdShift) + | sequence; + } + + protected long tilNextMillis(long lastTimestamp) { + long timestamp = currentTimeMillis(); + while (timestamp <= lastTimestamp) { + timestamp = currentTimeMillis(); } + return timestamp; + } + + protected long currentTimeMillis() { + return System.currentTimeMillis(); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/LogMsgUtil.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/LogMsgUtil.java index 84dd5e72e..daa4a8f3a 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/LogMsgUtil.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/LogMsgUtil.java @@ -21,13 +21,12 @@ public class LogMsgUtil { - public static String getStackMsg(Exception e) { - StringBuilder sb = new StringBuilder(); - StackTraceElement[] stackArray = e.getStackTrace(); - for (StackTraceElement element : stackArray) { - sb.append(element.toString()).append("\n"); - } - return sb.toString(); + public static String getStackMsg(Exception e) { + StringBuilder sb = new StringBuilder(); + StackTraceElement[] stackArray = e.getStackTrace(); + for (StackTraceElement element : stackArray) { + sb.append(element.toString()).append("\n"); } - + return sb.toString(); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/LoggerFormatter.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/LoggerFormatter.java index 92bad5e6b..c62f47377 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/LoggerFormatter.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/LoggerFormatter.java @@ -23,125 +23,131 @@ public class LoggerFormatter { - public LoggerFormatter() { - } - - public static String getCycleName(int cycleId) { - return String.format("cycle#%s", cycleId); - } - - public static String getCycleName(int cycleId, long windowId) { - return String.format("cycle#%s-%s", cycleId, windowId); - } - - public static String getCycleTag(String pipelineName, int cycleId) { - return String.format("%s %s", pipelineName, getCycleName(cycleId)); - } - - public static String getCycleTag(String pipelineName, int cycleId, long windowId) { - return String.format("%s %s", pipelineName, getCycleName(cycleId, windowId)); - } - - public static String getCycleMetricName(int cycleId, int vertexId) { - return String.format("%s[%d]", getCycleName(cycleId), vertexId); - } - - public static String getCycleMetricName(int cycleId, long windowId, int vertexId) { - return String.format("%s[%d]", getCycleName(cycleId, windowId), vertexId); - } - - public static String getTaskTag(String pipelineName, int cycleId, - int taskId, int vertexId, int index, int parallelism) { - return String.format("%s task#%d [%d-%d/%d]", getCycleTag(pipelineName, cycleId), - taskId, vertexId, index, parallelism); - } - - public static String getTaskTag(String pipelineName, int cycleId, long windowId, - int taskId, int vertexId, int index, int parallelism) { - return String.format("%s task#%d [%d-%d/%d]", getCycleTag(pipelineName, cycleId, windowId), - taskId, vertexId, index, parallelism); - } - - /** - * Get the exception stack message in order to troubleshoot problems. - * - * @param e - * @return - */ - public static String getStackMsg(Exception e) { - StringBuffer sb = new StringBuffer(); - StackTraceElement[] stackArray = e.getStackTrace(); - for (int i = 0; i < stackArray.length; i++) { - StackTraceElement element = stackArray[i]; - sb.append(element.toString() + "\n"); - } - return sb.toString(); - } - - public static void debug(Logger logger, String msg) { - if (logger.isDebugEnabled()) { - logger.debug(msg); - } - } - - public static void debug(Logger logger, String msg, Object o) { - if (logger.isDebugEnabled()) { - logger.debug(msg, o); - } - } - - public static void debug(Logger logger, String msg, Object... o) { - if (logger.isDebugEnabled()) { - logger.debug(msg, o); - } - } - - public static void info(Logger logger, String msg) { - if (logger.isInfoEnabled()) { - logger.info(msg); - } - } - - public static void info(Logger logger, String msg, Object o) { - if (logger.isInfoEnabled()) { - logger.info(msg, o); - } - } - - public static void info(Logger logger, String msg, Object... o) { - if (logger.isInfoEnabled()) { - logger.info(msg, o); - } - } - - public static void info(Logger logger, String msg, Throwable t) { - if (logger.isInfoEnabled()) { - logger.info(msg, t); - } - } - - public static void warn(Logger logger, String msg) { - logger.warn(msg); - } - - public static void warn(Logger logger, String msg, Object... o) { - logger.warn(msg, o); - } - - public static void warn(Logger logger, String msg, Throwable t) { - logger.warn(msg, t); - } - - public static void error(Logger logger, String msg) { - logger.error(msg); - } - - public static void error(Logger logger, String msg, Object... o) { - logger.error(msg, o); - } - - public static void error(Logger logger, String msg, Throwable t) { - logger.error(msg, t); - } - + public LoggerFormatter() {} + + public static String getCycleName(int cycleId) { + return String.format("cycle#%s", cycleId); + } + + public static String getCycleName(int cycleId, long windowId) { + return String.format("cycle#%s-%s", cycleId, windowId); + } + + public static String getCycleTag(String pipelineName, int cycleId) { + return String.format("%s %s", pipelineName, getCycleName(cycleId)); + } + + public static String getCycleTag(String pipelineName, int cycleId, long windowId) { + return String.format("%s %s", pipelineName, getCycleName(cycleId, windowId)); + } + + public static String getCycleMetricName(int cycleId, int vertexId) { + return String.format("%s[%d]", getCycleName(cycleId), vertexId); + } + + public static String getCycleMetricName(int cycleId, long windowId, int vertexId) { + return String.format("%s[%d]", getCycleName(cycleId, windowId), vertexId); + } + + public static String getTaskTag( + String pipelineName, int cycleId, int taskId, int vertexId, int index, int parallelism) { + return String.format( + "%s task#%d [%d-%d/%d]", + getCycleTag(pipelineName, cycleId), taskId, vertexId, index, parallelism); + } + + public static String getTaskTag( + String pipelineName, + int cycleId, + long windowId, + int taskId, + int vertexId, + int index, + int parallelism) { + return String.format( + "%s task#%d [%d-%d/%d]", + getCycleTag(pipelineName, cycleId, windowId), taskId, vertexId, index, parallelism); + } + + /** + * Get the exception stack message in order to troubleshoot problems. + * + * @param e + * @return + */ + public static String getStackMsg(Exception e) { + StringBuffer sb = new StringBuffer(); + StackTraceElement[] stackArray = e.getStackTrace(); + for (int i = 0; i < stackArray.length; i++) { + StackTraceElement element = stackArray[i]; + sb.append(element.toString() + "\n"); + } + return sb.toString(); + } + + public static void debug(Logger logger, String msg) { + if (logger.isDebugEnabled()) { + logger.debug(msg); + } + } + + public static void debug(Logger logger, String msg, Object o) { + if (logger.isDebugEnabled()) { + logger.debug(msg, o); + } + } + + public static void debug(Logger logger, String msg, Object... o) { + if (logger.isDebugEnabled()) { + logger.debug(msg, o); + } + } + + public static void info(Logger logger, String msg) { + if (logger.isInfoEnabled()) { + logger.info(msg); + } + } + + public static void info(Logger logger, String msg, Object o) { + if (logger.isInfoEnabled()) { + logger.info(msg, o); + } + } + + public static void info(Logger logger, String msg, Object... o) { + if (logger.isInfoEnabled()) { + logger.info(msg, o); + } + } + + public static void info(Logger logger, String msg, Throwable t) { + if (logger.isInfoEnabled()) { + logger.info(msg, t); + } + } + + public static void warn(Logger logger, String msg) { + logger.warn(msg); + } + + public static void warn(Logger logger, String msg, Object... o) { + logger.warn(msg, o); + } + + public static void warn(Logger logger, String msg, Throwable t) { + logger.warn(msg, t); + } + + public static void error(Logger logger, String msg) { + logger.error(msg); + } + + public static void error(Logger logger, String msg, Object... o) { + logger.error(msg, o); + } + + public static void error(Logger logger, String msg, Throwable t) { + logger.error(msg, t); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/MemoryUtils.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/MemoryUtils.java index 8bb8b8e46..c69814a0e 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/MemoryUtils.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/MemoryUtils.java @@ -23,21 +23,21 @@ public class MemoryUtils { - public static final long KB = 1024L; - public static final long MB = 1024L * KB; - public static final long GB = 1024L * MB; - public static final long TB = 1024L * GB; - public static final long PB = 1024L * TB; - private static final String BYTE_UNITS = "KMGTPE"; + public static final long KB = 1024L; + public static final long MB = 1024L * KB; + public static final long GB = 1024L * MB; + public static final long TB = 1024L * GB; + public static final long PB = 1024L * TB; + private static final String BYTE_UNITS = "KMGTPE"; - public static String humanReadableByteCount(long bytes) { - int unit = 1024; - if (bytes < unit) { - return bytes + "B"; - } - int exp = (int) (Math.log(bytes) / Math.log(unit)); - char pre = BYTE_UNITS.charAt(exp - 1); - BigDecimal bd = new BigDecimal(bytes / Math.pow(unit, exp)); - return String.format("%s%cB", bd.setScale(2, BigDecimal.ROUND_DOWN), pre); + public static String humanReadableByteCount(long bytes) { + int unit = 1024; + if (bytes < unit) { + return bytes + "B"; } + int exp = (int) (Math.log(bytes) / Math.log(unit)); + char pre = BYTE_UNITS.charAt(exp - 1); + BigDecimal bd = new BigDecimal(bytes / Math.pow(unit, exp)); + return String.format("%s%cB", bd.setScale(2, BigDecimal.ROUND_DOWN), pre); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/PortUtil.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/PortUtil.java index b97b45e7c..4adcac94f 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/PortUtil.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/PortUtil.java @@ -25,42 +25,41 @@ public class PortUtil { - private static final int MAX_NUM = 200; - private static final int DEFAULT_MIN_PORT = 50000; - private static final int DEFAULT_MAX_PORT = 60000; + private static final int MAX_NUM = 200; + private static final int DEFAULT_MIN_PORT = 50000; + private static final int DEFAULT_MAX_PORT = 60000; - public static int getPort(int minPort, int maxPort) { + public static int getPort(int minPort, int maxPort) { - int num = 0; - int port; - while (num < MAX_NUM) { - try { - port = getAvailablePort(minPort, maxPort); - if (port > 0) { - return port; - } - } catch (Exception e) { - num++; - } + int num = 0; + int port; + while (num < MAX_NUM) { + try { + port = getAvailablePort(minPort, maxPort); + if (port > 0) { + return port; } - throw new RuntimeException(String.format("no available port in [%d,%d]", minPort, maxPort)); + } catch (Exception e) { + num++; + } } + throw new RuntimeException(String.format("no available port in [%d,%d]", minPort, maxPort)); + } - public static int getPort(int port) { - return port != 0 ? port : getPort(DEFAULT_MIN_PORT, DEFAULT_MAX_PORT); - } + public static int getPort(int port) { + return port != 0 ? port : getPort(DEFAULT_MIN_PORT, DEFAULT_MAX_PORT); + } - private static int getAvailablePort(int minPort, int maxPort) throws IOException { - Random random = new Random(); - int port = 0; - while (true) { - int tempPort = random.nextInt(maxPort) % (maxPort - minPort + 1) + minPort; - ServerSocket serverSocket = new ServerSocket(tempPort); - port = serverSocket.getLocalPort(); - serverSocket.close(); - break; - } - return port; + private static int getAvailablePort(int minPort, int maxPort) throws IOException { + Random random = new Random(); + int port = 0; + while (true) { + int tempPort = random.nextInt(maxPort) % (maxPort - minPort + 1) + minPort; + ServerSocket serverSocket = new ServerSocket(tempPort); + port = serverSocket.getLocalPort(); + serverSocket.close(); + break; } - + return port; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ProcessUtil.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ProcessUtil.java index d21486646..f2607228c 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ProcessUtil.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ProcessUtil.java @@ -24,6 +24,7 @@ import java.lang.management.RuntimeMXBean; import java.lang.reflect.Field; import java.net.InetAddress; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.slf4j.Logger; @@ -31,105 +32,99 @@ public class ProcessUtil { - private static final Logger LOGGER = LoggerFactory.getLogger(ProcessUtil.class); - private static final String HOSTNAME; - private static final String HOST_IP; - private static final String HOSTNAME_AND_IP; - private static final String HOST_AND_PID; - private static final int PROCESS_ID; - public static final String LOCAL_ADDRESS; - - static { - try { - InetAddress addr = InetAddress.getLocalHost(); - HOSTNAME = addr.getHostName(); - HOST_IP = addr.getHostAddress(); - LOCAL_ADDRESS = addr.getHostAddress(); - HOSTNAME_AND_IP = HOSTNAME + "/" + HOST_IP; - - RuntimeMXBean runtime = ManagementFactory.getRuntimeMXBean(); - String name = runtime.getName(); // format: "pid@hostname" - PROCESS_ID = Integer.parseInt(name.substring(0, name.indexOf('@'))); - HOST_AND_PID = HOSTNAME + ":" + PROCESS_ID; - } catch (Exception e) { - LOGGER.error(e.getMessage(), e); - throw new RuntimeException(e); - } - } - - public static String getHostname() { - return HOSTNAME; - } - - public static String getHostIp() { - return HOST_IP; - } - - public static int getProcessId() { - return PROCESS_ID; - } - - public static String getHostAndPid() { - return HOST_AND_PID; - } - - public static String getHostAndIp() { - return HOSTNAME_AND_IP; + private static final Logger LOGGER = LoggerFactory.getLogger(ProcessUtil.class); + private static final String HOSTNAME; + private static final String HOST_IP; + private static final String HOSTNAME_AND_IP; + private static final String HOST_AND_PID; + private static final int PROCESS_ID; + public static final String LOCAL_ADDRESS; + + static { + try { + InetAddress addr = InetAddress.getLocalHost(); + HOSTNAME = addr.getHostName(); + HOST_IP = addr.getHostAddress(); + LOCAL_ADDRESS = addr.getHostAddress(); + HOSTNAME_AND_IP = HOSTNAME + "/" + HOST_IP; + + RuntimeMXBean runtime = ManagementFactory.getRuntimeMXBean(); + String name = runtime.getName(); // format: "pid@hostname" + PROCESS_ID = Integer.parseInt(name.substring(0, name.indexOf('@'))); + HOST_AND_PID = HOSTNAME + ":" + PROCESS_ID; + } catch (Exception e) { + LOGGER.error(e.getMessage(), e); + throw new RuntimeException(e); } - - /** - * -Xmx. - */ - public static long getMaxMemory() { - Runtime runtime = Runtime.getRuntime(); - return (runtime.maxMemory()) / FileUtils.ONE_MB; - } - - /** - * start is -Xms, Process Current Get Memory from Os. - */ - public static long getTotalMemory() { - Runtime runtime = Runtime.getRuntime(); - return (runtime.totalMemory()) / FileUtils.ONE_MB; + } + + public static String getHostname() { + return HOSTNAME; + } + + public static String getHostIp() { + return HOST_IP; + } + + public static int getProcessId() { + return PROCESS_ID; + } + + public static String getHostAndPid() { + return HOST_AND_PID; + } + + public static String getHostAndIp() { + return HOSTNAME_AND_IP; + } + + /** -Xmx. */ + public static long getMaxMemory() { + Runtime runtime = Runtime.getRuntime(); + return (runtime.maxMemory()) / FileUtils.ONE_MB; + } + + /** start is -Xms, Process Current Get Memory from Os. */ + public static long getTotalMemory() { + Runtime runtime = Runtime.getRuntime(); + return (runtime.totalMemory()) / FileUtils.ONE_MB; + } + + /** Get Memory from os and not use. */ + public static long getFreeMemory() { + Runtime runtime = Runtime.getRuntime(); + return (runtime.freeMemory()) / FileUtils.ONE_MB; + } + + public static synchronized int getProcessPid(Process p) { + int pid = -1; + try { + if (ReflectionUtil.JAVA_VERSION >= 9 + || p.getClass().getName().equals("java.lang.UNIXProcess")) { + Field f = p.getClass().getDeclaredField("pid"); + f.setAccessible(true); + pid = f.getInt(p); + f.setAccessible(false); + } + } catch (Exception e) { + LOGGER.warn("fail to get pid from {}", p.getClass().getCanonicalName()); + pid = -1; } - - /** - * Get Memory from os and not use. - */ - public static long getFreeMemory() { - Runtime runtime = Runtime.getRuntime(); - return (runtime.freeMemory()) / FileUtils.ONE_MB; - } - - public static synchronized int getProcessPid(Process p) { - int pid = -1; - try { - if (ReflectionUtil.JAVA_VERSION >= 9 || p.getClass().getName().equals("java.lang.UNIXProcess")) { - Field f = p.getClass().getDeclaredField("pid"); - f.setAccessible(true); - pid = f.getInt(p); - f.setAccessible(false); - } - } catch (Exception e) { - LOGGER.warn("fail to get pid from {}", p.getClass().getCanonicalName()); - pid = -1; - } - return pid; + return pid; + } + + public static void killProcess(int pid) { + execute("kill -9 " + pid); + } + + public static void execute(String cmd) { + LOGGER.info(cmd); + try { + Process process = Runtime.getRuntime().exec(cmd); + process.waitFor(); + } catch (InterruptedException | IOException e) { + LOGGER.error(" {} failed: {}", cmd, e); + throw new GeaflowRuntimeException(e.getMessage(), e); } - - public static void killProcess(int pid) { - execute("kill -9 " + pid); - } - - public static void execute(String cmd) { - LOGGER.info(cmd); - try { - Process process = Runtime.getRuntime().exec(cmd); - process.waitFor(); - } catch (InterruptedException | IOException e) { - LOGGER.error(" {} failed: {}", cmd, e); - throw new GeaflowRuntimeException(e.getMessage(), e); - } - } - + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ReflectionUtil.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ReflectionUtil.java index 848acee47..9a16b680f 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ReflectionUtil.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ReflectionUtil.java @@ -23,74 +23,73 @@ package org.apache.geaflow.common.utils; -import com.google.common.base.Preconditions; import java.lang.reflect.AccessibleObject; import java.lang.reflect.Field; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; -public final class ReflectionUtil { +import com.google.common.base.Preconditions; - public static final int JAVA_VERSION = majorVersion( - SystemArgsUtil.get("java.specification.version", "1.6")); +public final class ReflectionUtil { - private static int majorVersion(final String javaSpecVersion) { - final String[] components = javaSpecVersion.split("\\."); - final int[] version = new int[components.length]; - for (int i = 0; i < components.length; i++) { - version[i] = Integer.parseInt(components[i]); - } + public static final int JAVA_VERSION = + majorVersion(SystemArgsUtil.get("java.specification.version", "1.6")); - if (version[0] == 1) { - Preconditions.checkArgument(version[1] >= 6); - return version[1]; - } else { - return version[0]; - } + private static int majorVersion(final String javaSpecVersion) { + final String[] components = javaSpecVersion.split("\\."); + final int[] version = new int[components.length]; + for (int i = 0; i < components.length; i++) { + version[i] = Integer.parseInt(components[i]); } - private ReflectionUtil() { + if (version[0] == 1) { + Preconditions.checkArgument(version[1] >= 6); + return version[1]; + } else { + return version[0]; } + } - /** - * Set visibility. - */ - public static Throwable trySetAccessible(AccessibleObject object, boolean checkAccessible) { - if (checkAccessible && JAVA_VERSION >= 9) { - return new UnsupportedOperationException("Reflective setAccessible(true) disabled"); - } - try { - object.setAccessible(true); - return null; - } catch (Exception e) { - return new GeaflowRuntimeException(e); - } - } + private ReflectionUtil() {} - public static Object getField(Object object, String fieldName) { - try { - Field field = getField(object.getClass(), fieldName); - field.setAccessible(true); - return field.get(object); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + /** Set visibility. */ + public static Throwable trySetAccessible(AccessibleObject object, boolean checkAccessible) { + if (checkAccessible && JAVA_VERSION >= 9) { + return new UnsupportedOperationException("Reflective setAccessible(true) disabled"); } + try { + object.setAccessible(true); + return null; + } catch (Exception e) { + return new GeaflowRuntimeException(e); + } + } - private static Field getField(Class clazz, String fieldName) throws NoSuchFieldException { - while (clazz != Object.class) { - try { - return clazz.getDeclaredField(fieldName); - } catch (NoSuchFieldException e) { - // ignore current class field, try super class. - } - clazz = clazz.getSuperclass(); - } - throw new NoSuchFieldException(fieldName); + public static Object getField(Object object, String fieldName) { + try { + Field field = getField(object.getClass(), fieldName); + field.setAccessible(true); + return field.get(object); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } + } - public static void setField(Object instance, String fieldName, Object value) throws Exception { - Field field = getField(instance.getClass(), fieldName); - field.setAccessible(true); - field.set(instance, value); + private static Field getField(Class clazz, String fieldName) throws NoSuchFieldException { + while (clazz != Object.class) { + try { + return clazz.getDeclaredField(fieldName); + } catch (NoSuchFieldException e) { + // ignore current class field, try super class. + } + clazz = clazz.getSuperclass(); } + throw new NoSuchFieldException(fieldName); + } + + public static void setField(Object instance, String fieldName, Object value) throws Exception { + Field field = getField(instance.getClass(), fieldName); + field.setAccessible(true); + field.set(instance, value); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/RetryCommand.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/RetryCommand.java index 69d030438..760e027fd 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/RetryCommand.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/RetryCommand.java @@ -21,59 +21,61 @@ import java.util.Random; import java.util.concurrent.Callable; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Utils for retrying to execute. - */ +/** Utils for retrying to execute. */ public class RetryCommand { - private static final Logger LOGGER = LoggerFactory.getLogger(RetryCommand.class); - private static final Random RANDOM = new Random(); + private static final Logger LOGGER = LoggerFactory.getLogger(RetryCommand.class); + private static final Random RANDOM = new Random(); - public static T run(Callable function, int retryCount) { - return run(function, retryCount, 0); - } + public static T run(Callable function, int retryCount) { + return run(function, retryCount, 0); + } - public static T run(Callable function, int retryCount, long retryIntervalMs) { - return run(function, null, retryCount, retryIntervalMs); - } + public static T run(Callable function, int retryCount, long retryIntervalMs) { + return run(function, null, retryCount, retryIntervalMs); + } - public static T run(Callable function, Callable retryFunction, int retryCount, - long retryIntervalMs) { - return run(function, retryFunction, retryCount, retryIntervalMs, false); - } + public static T run( + Callable function, Callable retryFunction, int retryCount, long retryIntervalMs) { + return run(function, retryFunction, retryCount, retryIntervalMs, false); + } - public static T run(Callable function, Callable retryFunction, final int retryCount, - long retryIntervalMs, boolean needRandom) { - int i = retryCount; - while (0 < i) { - try { - return function.call(); - } catch (Exception e) { - i--; + public static T run( + Callable function, + Callable retryFunction, + final int retryCount, + long retryIntervalMs, + boolean needRandom) { + int i = retryCount; + while (0 < i) { + try { + return function.call(); + } catch (Exception e) { + i--; - if (i == 0) { - LOGGER.error("Retry failed and reached the maximum retried times.", e); - throw new GeaflowRuntimeException(e); - } + if (i == 0) { + LOGGER.error("Retry failed and reached the maximum retried times.", e); + throw new GeaflowRuntimeException(e); + } - try { - long sleepTime = needRandom ? retryIntervalMs * (RANDOM.nextInt(retryCount) + 1) - : retryIntervalMs; - LOGGER.warn("Retry failed, will retry {} times with interval {} ms", i, - sleepTime); - Thread.sleep(sleepTime); - if (retryFunction != null) { - retryFunction.call(); - } - } catch (Exception e1) { - throw new RuntimeException(e1); - } - } + try { + long sleepTime = + needRandom ? retryIntervalMs * (RANDOM.nextInt(retryCount) + 1) : retryIntervalMs; + LOGGER.warn("Retry failed, will retry {} times with interval {} ms", i, sleepTime); + Thread.sleep(sleepTime); + if (retryFunction != null) { + retryFunction.call(); + } + } catch (Exception e1) { + throw new RuntimeException(e1); } - return null; + } } -} \ No newline at end of file + return null; + } +} diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ShellUtil.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ShellUtil.java index b289c0150..cf62d6631 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ShellUtil.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ShellUtil.java @@ -23,6 +23,7 @@ import java.io.InputStream; import java.io.InputStreamReader; import java.util.concurrent.TimeUnit; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.slf4j.Logger; @@ -30,55 +31,53 @@ public class ShellUtil { - private static final Logger LOGGER = LoggerFactory.getLogger(ShellUtil.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ShellUtil.class); - public static void executeShellCommand(ProcessBuilder command, int timeoutSeconds) { - Process process = null; - try { - LOGGER.info("Start executing shell command: {}", command.command()); - process = command.start(); - boolean exit = process.waitFor(timeoutSeconds, TimeUnit.SECONDS); - if (!exit) { - throw new GeaflowRuntimeException( - String.format("Command %s execute timeout.", command.command())); - } - int code = process.exitValue(); - if (code != 0) { - String message = getCommandErrorMessage(process); - LOGGER.error("Execute command {} failed with code {}. Error message: {}", - command, code, message); - throw new GeaflowRuntimeException( - String.format("Code: %s, Message: %s", code, message)); - } - LOGGER.info("Finished executing shell command finished: {}", command.command()); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } finally { - if (process != null && process.isAlive()) { - LOGGER.info("Killed subprocess generated by command {}.", command.command()); - process.destroy(); - } - } + public static void executeShellCommand(ProcessBuilder command, int timeoutSeconds) { + Process process = null; + try { + LOGGER.info("Start executing shell command: {}", command.command()); + process = command.start(); + boolean exit = process.waitFor(timeoutSeconds, TimeUnit.SECONDS); + if (!exit) { + throw new GeaflowRuntimeException( + String.format("Command %s execute timeout.", command.command())); + } + int code = process.exitValue(); + if (code != 0) { + String message = getCommandErrorMessage(process); + LOGGER.error( + "Execute command {} failed with code {}. Error message: {}", command, code, message); + throw new GeaflowRuntimeException(String.format("Code: %s, Message: %s", code, message)); + } + LOGGER.info("Finished executing shell command finished: {}", command.command()); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); + } finally { + if (process != null && process.isAlive()) { + LOGGER.info("Killed subprocess generated by command {}.", command.command()); + process.destroy(); + } } + } - public static String getCommandErrorMessage(Process process) { - String errorMessage; - try (InputStream inputStream = process.getErrorStream()) { - BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream)); - StringBuilder stringBuilder = new StringBuilder(); - String line; - while ((line = reader.readLine()) != null) { - stringBuilder.append(line).append("\n"); - } - errorMessage = stringBuilder.toString(); - if (StringUtils.isNotEmpty(errorMessage) && errorMessage.endsWith("\n")) { - errorMessage = errorMessage.substring(0, errorMessage.length() - 1); - } - } catch (Exception e) { - errorMessage = "Get error message from error-stream of process failed."; - LOGGER.warn(errorMessage); - } - return errorMessage; + public static String getCommandErrorMessage(Process process) { + String errorMessage; + try (InputStream inputStream = process.getErrorStream()) { + BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream)); + StringBuilder stringBuilder = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + stringBuilder.append(line).append("\n"); + } + errorMessage = stringBuilder.toString(); + if (StringUtils.isNotEmpty(errorMessage) && errorMessage.endsWith("\n")) { + errorMessage = errorMessage.substring(0, errorMessage.length() - 1); + } + } catch (Exception e) { + errorMessage = "Get error message from error-stream of process failed."; + LOGGER.warn(errorMessage); } - + return errorMessage; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/SleepUtils.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/SleepUtils.java index b4bf5ca34..5c0a6376f 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/SleepUtils.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/SleepUtils.java @@ -20,29 +20,30 @@ package org.apache.geaflow.common.utils; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class SleepUtils { - private static final Logger LOGGER = LoggerFactory.getLogger(SleepUtils.class); + private static final Logger LOGGER = LoggerFactory.getLogger(SleepUtils.class); - public static void sleepSecond(long second) { - try { - TimeUnit.SECONDS.sleep(second); - } catch (InterruptedException e) { - LOGGER.warn("sleep {}s interrupted", second); - throw new GeaflowRuntimeException(e); - } + public static void sleepSecond(long second) { + try { + TimeUnit.SECONDS.sleep(second); + } catch (InterruptedException e) { + LOGGER.warn("sleep {}s interrupted", second); + throw new GeaflowRuntimeException(e); } + } - public static void sleepMilliSecond(long mileSecond) { - try { - TimeUnit.MILLISECONDS.sleep(mileSecond); - } catch (InterruptedException e) { - LOGGER.warn("sleepMilliSecond {}ms interrupted", mileSecond); - throw new GeaflowRuntimeException(e); - } + public static void sleepMilliSecond(long mileSecond) { + try { + TimeUnit.MILLISECONDS.sleep(mileSecond); + } catch (InterruptedException e) { + LOGGER.warn("sleepMilliSecond {}ms interrupted", mileSecond); + throw new GeaflowRuntimeException(e); } + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/StringUtils.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/StringUtils.java index e56a6e081..cb52cb356 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/StringUtils.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/StringUtils.java @@ -21,59 +21,57 @@ import java.util.Arrays; -/** - * Created by eagle on 2017/8/30. - */ +/** Created by eagle on 2017/8/30. */ public class StringUtils { - public static String arrayAwareToString(Object o) { - if (o == null) { - return "null"; - } - if (o.getClass().isArray()) { - return arrayToString(o); - } - - return o.toString(); + public static String arrayAwareToString(Object o) { + if (o == null) { + return "null"; + } + if (o.getClass().isArray()) { + return arrayToString(o); } - public static String arrayToString(Object array) { - if (array == null) { - throw new NullPointerException(); - } + return o.toString(); + } - if (array instanceof int[]) { - return Arrays.toString((int[]) array); - } - if (array instanceof long[]) { - return Arrays.toString((long[]) array); - } - if (array instanceof Object[]) { - return Arrays.toString((Object[]) array); - } - if (array instanceof byte[]) { - return Arrays.toString((byte[]) array); - } - if (array instanceof double[]) { - return Arrays.toString((double[]) array); - } - if (array instanceof float[]) { - return Arrays.toString((float[]) array); - } - if (array instanceof boolean[]) { - return Arrays.toString((boolean[]) array); - } - if (array instanceof char[]) { - return Arrays.toString((char[]) array); - } - if (array instanceof short[]) { - return Arrays.toString((short[]) array); - } + public static String arrayToString(Object array) { + if (array == null) { + throw new NullPointerException(); + } + + if (array instanceof int[]) { + return Arrays.toString((int[]) array); + } + if (array instanceof long[]) { + return Arrays.toString((long[]) array); + } + if (array instanceof Object[]) { + return Arrays.toString((Object[]) array); + } + if (array instanceof byte[]) { + return Arrays.toString((byte[]) array); + } + if (array instanceof double[]) { + return Arrays.toString((double[]) array); + } + if (array instanceof float[]) { + return Arrays.toString((float[]) array); + } + if (array instanceof boolean[]) { + return Arrays.toString((boolean[]) array); + } + if (array instanceof char[]) { + return Arrays.toString((char[]) array); + } + if (array instanceof short[]) { + return Arrays.toString((short[]) array); + } - if (array.getClass().isArray()) { - return ""; - } else { - throw new IllegalArgumentException("The given argument is no array."); - } + if (array.getClass().isArray()) { + return ""; + } else { + throw new IllegalArgumentException("The given argument is no array."); } + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/SystemArgsUtil.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/SystemArgsUtil.java index 0376cf468..97eb0e61a 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/SystemArgsUtil.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/SystemArgsUtil.java @@ -27,60 +27,58 @@ import java.security.AccessController; import java.security.PrivilegedAction; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class SystemArgsUtil { - private static final Logger LOGGER = LoggerFactory.getLogger(SystemArgsUtil.class); - - private SystemArgsUtil() { - // Unused - } + private static final Logger LOGGER = LoggerFactory.getLogger(SystemArgsUtil.class); - /** - * Returns the value of the Java system property with the specified - * {@code key}, while falling back to {@code null} if the property access fails. - * - * @return the property value or {@code null} - */ - public static String get(String key) { - return get(key, null); - } + private SystemArgsUtil() { + // Unused + } - /** - * Returns the value of the Java system property with the specified - * {@code key}, while falling back to the specified default value if - * the property access fails. - * - * @return the property value. - * {@code def} if there's no such property or if an access to the - * specified property is not allowed. - */ - public static String get(final String key, String def) { - requireNonNull(key, "key"); - if (key.isEmpty()) { - throw new IllegalArgumentException("key must not be empty."); - } + /** + * Returns the value of the Java system property with the specified {@code key}, while falling + * back to {@code null} if the property access fails. + * + * @return the property value or {@code null} + */ + public static String get(String key) { + return get(key, null); + } - String value = null; - try { - if (System.getSecurityManager() == null) { - value = System.getProperty(key); - } else { - value = AccessController - .doPrivileged((PrivilegedAction) () -> System.getProperty(key)); - } - } catch (Exception ignore) { - LOGGER.warn("Unable to retrieve a system property '{}'; default values will be used.", - key, ignore); - } + /** + * Returns the value of the Java system property with the specified {@code key}, while falling + * back to the specified default value if the property access fails. + * + * @return the property value. {@code def} if there's no such property or if an access to the + * specified property is not allowed. + */ + public static String get(final String key, String def) { + requireNonNull(key, "key"); + if (key.isEmpty()) { + throw new IllegalArgumentException("key must not be empty."); + } - if (value == null) { - return def; - } + String value = null; + try { + if (System.getSecurityManager() == null) { + value = System.getProperty(key); + } else { + value = + AccessController.doPrivileged((PrivilegedAction) () -> System.getProperty(key)); + } + } catch (Exception ignore) { + LOGGER.warn( + "Unable to retrieve a system property '{}'; default values will be used.", key, ignore); + } - return value; + if (value == null) { + return def; } + return value; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ThreadUtil.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ThreadUtil.java index cd76102a5..05197df44 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ThreadUtil.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/utils/ThreadUtil.java @@ -19,22 +19,22 @@ package org.apache.geaflow.common.utils; -import com.google.common.util.concurrent.ThreadFactoryBuilder; import java.util.concurrent.ThreadFactory; -public class ThreadUtil { +import com.google.common.util.concurrent.ThreadFactoryBuilder; - public static ThreadFactory namedThreadFactory(boolean isDaemon, String prefix) { - return new ThreadFactoryBuilder().setDaemon(isDaemon).setNameFormat(prefix + "-%d").build(); - } +public class ThreadUtil { - public static ThreadFactory namedThreadFactory(boolean isDaemon, String prefix, - Thread.UncaughtExceptionHandler handler) { - return new ThreadFactoryBuilder() - .setDaemon(isDaemon) - .setNameFormat(prefix + "-%d") - .setUncaughtExceptionHandler(handler) - .build(); - } + public static ThreadFactory namedThreadFactory(boolean isDaemon, String prefix) { + return new ThreadFactoryBuilder().setDaemon(isDaemon).setNameFormat(prefix + "-%d").build(); + } + public static ThreadFactory namedThreadFactory( + boolean isDaemon, String prefix, Thread.UncaughtExceptionHandler handler) { + return new ThreadFactoryBuilder() + .setDaemon(isDaemon) + .setNameFormat(prefix + "-%d") + .setUncaughtExceptionHandler(handler) + .build(); + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/visualization/console/ConsoleVisualizeVertex.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/visualization/console/ConsoleVisualizeVertex.java index 4f281317e..0de883bcf 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/visualization/console/ConsoleVisualizeVertex.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/visualization/console/ConsoleVisualizeVertex.java @@ -24,76 +24,76 @@ public class ConsoleVisualizeVertex { - public String vertexType; - public String vertexMode; - public String id; - public int parallelism; - public String operator; - public String operatorName; - public List parents = new ArrayList<>(); - public JsonPlan innerPlan; - - public String getVertexType() { - return vertexType; - } - - public void setVertexType(String vertexType) { - this.vertexType = vertexType; - } - - public String getVertexMode() { - return vertexMode; - } - - public void setVertexMode(String vertexMode) { - this.vertexMode = vertexMode; - } - - public String getId() { - return id; - } - - public void setId(String id) { - this.id = id; - } - - public int getParallelism() { - return parallelism; - } - - public void setParallelism(int parallelism) { - this.parallelism = parallelism; - } - - public String getOperator() { - return operator; - } - - public void setOperator(String operator) { - this.operator = operator; - } - - public String getOperatorName() { - return operatorName; - } - - public void setOperatorName(String operatorName) { - this.operatorName = operatorName; - } - - public List getParents() { - return parents; - } - - public void setParents(List parents) { - this.parents = parents; - } - - public JsonPlan getInnerPlan() { - return innerPlan; - } - - public void setInnerPlan(JsonPlan innerPlan) { - this.innerPlan = innerPlan; - } + public String vertexType; + public String vertexMode; + public String id; + public int parallelism; + public String operator; + public String operatorName; + public List parents = new ArrayList<>(); + public JsonPlan innerPlan; + + public String getVertexType() { + return vertexType; + } + + public void setVertexType(String vertexType) { + this.vertexType = vertexType; + } + + public String getVertexMode() { + return vertexMode; + } + + public void setVertexMode(String vertexMode) { + this.vertexMode = vertexMode; + } + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public int getParallelism() { + return parallelism; + } + + public void setParallelism(int parallelism) { + this.parallelism = parallelism; + } + + public String getOperator() { + return operator; + } + + public void setOperator(String operator) { + this.operator = operator; + } + + public String getOperatorName() { + return operatorName; + } + + public void setOperatorName(String operatorName) { + this.operatorName = operatorName; + } + + public List getParents() { + return parents; + } + + public void setParents(List parents) { + this.parents = parents; + } + + public JsonPlan getInnerPlan() { + return innerPlan; + } + + public void setInnerPlan(JsonPlan innerPlan) { + this.innerPlan = innerPlan; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/visualization/console/JsonPlan.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/visualization/console/JsonPlan.java index c683bfbbe..5c9d7766c 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/visualization/console/JsonPlan.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/visualization/console/JsonPlan.java @@ -24,13 +24,13 @@ public class JsonPlan { - public Map vertices = new HashMap<>(); + public Map vertices = new HashMap<>(); - public Map getVertices() { - return vertices; - } + public Map getVertices() { + return vertices; + } - public void setVertices(Map vertices) { - this.vertices = vertices; - } + public void setVertices(Map vertices) { + this.vertices = vertices; + } } diff --git a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/visualization/console/Predecessor.java b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/visualization/console/Predecessor.java index 37f9fc639..748e69155 100644 --- a/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/visualization/console/Predecessor.java +++ b/geaflow/geaflow-common/src/main/java/org/apache/geaflow/common/visualization/console/Predecessor.java @@ -21,22 +21,22 @@ public class Predecessor { - public String id; - public String partitionType; + public String id; + public String partitionType; - public String getId() { - return id; - } + public String getId() { + return id; + } - public void setId(String id) { - this.id = id; - } + public void setId(String id) { + this.id = id; + } - public String getPartitionType() { - return partitionType; - } + public String getPartitionType() { + return partitionType; + } - public void setPartitionType(String partitionType) { - this.partitionType = partitionType; - } + public void setPartitionType(String partitionType) { + this.partitionType = partitionType; + } } diff --git a/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/TupleTest.java b/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/TupleTest.java index a5bd5d612..1180b4e49 100644 --- a/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/TupleTest.java +++ b/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/TupleTest.java @@ -26,28 +26,28 @@ public class TupleTest { - @Test - public void testTuple() { - Tuple tuple = Tuple.of(1, "a"); - Assert.assertEquals(tuple.getF0().intValue(), 1); - Assert.assertEquals(tuple.getF1(), "a"); - } + @Test + public void testTuple() { + Tuple tuple = Tuple.of(1, "a"); + Assert.assertEquals(tuple.getF0().intValue(), 1); + Assert.assertEquals(tuple.getF1(), "a"); + } - @Test - public void testTriple() { - Triple triple1 = Triple.of(1, 2, "a"); - Assert.assertEquals(triple1.getF0().intValue(), 1); - Assert.assertEquals(triple1.getF1().intValue(), 2); - Assert.assertEquals(triple1.getF2(), "a"); - Assert.assertEquals(triple1.toString(), "(1,2,a)"); + @Test + public void testTriple() { + Triple triple1 = Triple.of(1, 2, "a"); + Assert.assertEquals(triple1.getF0().intValue(), 1); + Assert.assertEquals(triple1.getF1().intValue(), 2); + Assert.assertEquals(triple1.getF2(), "a"); + Assert.assertEquals(triple1.toString(), "(1,2,a)"); - Triple triple2 = Triple.of(1, 2, "a"); - Assert.assertEquals(triple2, triple1); - Assert.assertEquals(triple2.hashCode(), triple1.hashCode()); + Triple triple2 = Triple.of(1, 2, "a"); + Assert.assertEquals(triple2, triple1); + Assert.assertEquals(triple2.hashCode(), triple1.hashCode()); - triple2.setF0(2); - triple2.setF1(1); - triple2.setF2("b"); - Assert.assertNotEquals(triple2, triple1); - } + triple2.setF0(2); + triple2.setF1(1); + triple2.setF2("b"); + Assert.assertNotEquals(triple2, triple1); + } } diff --git a/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/binary/BinaryStringTest.java b/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/binary/BinaryStringTest.java index d411c0846..9d228e511 100644 --- a/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/binary/BinaryStringTest.java +++ b/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/binary/BinaryStringTest.java @@ -20,44 +20,45 @@ package org.apache.geaflow.common.binary; import java.nio.charset.StandardCharsets; + import org.testng.Assert; import org.testng.annotations.Test; public class BinaryStringTest { - @Test - public void testBinaryString() { - String testStr = "djwakfmnlkgritio3175453406fsdjhhflkdsk1`26ad09~|?!!"; - byte[] bytes = testStr.getBytes(StandardCharsets.UTF_8); - BinaryString binaryTestStr = BinaryString.fromBytes(bytes); - Assert.assertEquals(testStr.length(), binaryTestStr.getLength()); - Assert.assertEquals(binaryTestStr.getBytes(), bytes); - - String compareStr = "3218478293djadhfue8917535566"; - BinaryString binaryCompareStr = BinaryString.fromString(compareStr); - Assert.assertTrue(binaryCompareStr.compareTo(binaryTestStr) < 0); - Assert.assertNotEquals(binaryTestStr, binaryCompareStr); - - BinaryString splitStr = BinaryString.fromString("123_445_13da_deg"); - BinaryString[] splits = splitStr.split(BinaryString.fromString("_"), 0); - Assert.assertEquals(splits.length, 4); - Assert.assertTrue(splitStr.contains(BinaryString.fromString("13da"))); - Assert.assertFalse(splitStr.contains(BinaryString.fromString("21"))); - Assert.assertTrue(splitStr.startsWith(BinaryString.fromString("123"))); - } - - @Test - public void testStartWith() { - BinaryString str = BinaryString.fromString("abcabc"); - Assert.assertTrue(str.startsWith(BinaryString.fromString("abc"))); - Assert.assertFalse(str.startsWith(BinaryString.fromString("c"))); - } - - @Test - public void testEndsWith() { - BinaryString str = BinaryString.fromString("abcabc"); - Assert.assertTrue(str.endsWith(BinaryString.fromString("abc"))); - Assert.assertTrue(str.endsWith(BinaryString.fromString("c"))); - Assert.assertFalse(str.endsWith(BinaryString.fromString("d"))); - } + @Test + public void testBinaryString() { + String testStr = "djwakfmnlkgritio3175453406fsdjhhflkdsk1`26ad09~|?!!"; + byte[] bytes = testStr.getBytes(StandardCharsets.UTF_8); + BinaryString binaryTestStr = BinaryString.fromBytes(bytes); + Assert.assertEquals(testStr.length(), binaryTestStr.getLength()); + Assert.assertEquals(binaryTestStr.getBytes(), bytes); + + String compareStr = "3218478293djadhfue8917535566"; + BinaryString binaryCompareStr = BinaryString.fromString(compareStr); + Assert.assertTrue(binaryCompareStr.compareTo(binaryTestStr) < 0); + Assert.assertNotEquals(binaryTestStr, binaryCompareStr); + + BinaryString splitStr = BinaryString.fromString("123_445_13da_deg"); + BinaryString[] splits = splitStr.split(BinaryString.fromString("_"), 0); + Assert.assertEquals(splits.length, 4); + Assert.assertTrue(splitStr.contains(BinaryString.fromString("13da"))); + Assert.assertFalse(splitStr.contains(BinaryString.fromString("21"))); + Assert.assertTrue(splitStr.startsWith(BinaryString.fromString("123"))); + } + + @Test + public void testStartWith() { + BinaryString str = BinaryString.fromString("abcabc"); + Assert.assertTrue(str.startsWith(BinaryString.fromString("abc"))); + Assert.assertFalse(str.startsWith(BinaryString.fromString("c"))); + } + + @Test + public void testEndsWith() { + BinaryString str = BinaryString.fromString("abcabc"); + Assert.assertTrue(str.endsWith(BinaryString.fromString("abc"))); + Assert.assertTrue(str.endsWith(BinaryString.fromString("c"))); + Assert.assertFalse(str.endsWith(BinaryString.fromString("d"))); + } } diff --git a/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/encoder/EncoderResolverTest.java b/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/encoder/EncoderResolverTest.java index 7f78812b3..bd5425c1a 100644 --- a/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/encoder/EncoderResolverTest.java +++ b/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/encoder/EncoderResolverTest.java @@ -22,6 +22,7 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.util.Random; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.encoder.impl.EnumEncoder; import org.apache.geaflow.common.encoder.impl.PojoEncoder; @@ -35,414 +36,376 @@ public class EncoderResolverTest { - @Test - public void testResolveClass() { - IEncoder int1 = EncoderResolver.resolveClass(int.class); - Assert.assertEquals(int1, Encoders.INTEGER); - IEncoder int2 = EncoderResolver.resolveClass(Integer.class); - Assert.assertEquals(int2, Encoders.INTEGER); - - IEncoder encoderObj = EncoderResolver.resolveClass(Object.class); - Assert.assertNull(encoderObj); - IEncoder encoderCls = EncoderResolver.resolveClass(Class.class); - Assert.assertNull(encoderCls); - IEncoder encoderIfc = EncoderResolver.resolveClass(ITest0.class); - Assert.assertNull(encoderIfc); - } - - @Test - public void testResolveInterfaceAndImpl() { - IEncoder encoder0 = EncoderResolver.resolveFunction(ITest0.class, new CTest0()); - Assert.assertEquals(encoder0, Encoders.INTEGER); - - IEncoder encoder1 = EncoderResolver.resolveFunction(ITest0.class, new CTest1()); - Assert.assertEquals(encoder1, Encoders.STRING); - - IEncoder encoder2 = EncoderResolver.resolveFunction(ITest0.class, new CTest2()); - Assert.assertEquals(encoder2, Encoders.LONG); - - IEncoder encoder3 = EncoderResolver.resolveFunction(ITest0.class, new CTest3()); - Assert.assertEquals(encoder3, Encoders.INTEGER); - - IEncoder encoder4 = EncoderResolver.resolveFunction(ITest0.class, new CTest4()); - Assert.assertEquals(encoder4, Encoders.STRING); + @Test + public void testResolveClass() { + IEncoder int1 = EncoderResolver.resolveClass(int.class); + Assert.assertEquals(int1, Encoders.INTEGER); + IEncoder int2 = EncoderResolver.resolveClass(Integer.class); + Assert.assertEquals(int2, Encoders.INTEGER); - IEncoder encoder5 = EncoderResolver.resolveFunction(ITest0.class, new CTest5()); - Assert.assertEquals(encoder5, Encoders.LONG); + IEncoder encoderObj = EncoderResolver.resolveClass(Object.class); + Assert.assertNull(encoderObj); + IEncoder encoderCls = EncoderResolver.resolveClass(Class.class); + Assert.assertNull(encoderCls); + IEncoder encoderIfc = EncoderResolver.resolveClass(ITest0.class); + Assert.assertNull(encoderIfc); + } - IEncoder encoder6 = EncoderResolver.resolveFunction(ITest0.class, new CTest6<>()); - Assert.assertEquals(encoder6, Encoders.INTEGER); + @Test + public void testResolveInterfaceAndImpl() { + IEncoder encoder0 = EncoderResolver.resolveFunction(ITest0.class, new CTest0()); + Assert.assertEquals(encoder0, Encoders.INTEGER); - IEncoder encoder7 = EncoderResolver.resolveFunction(ITest0.class, new CTest7()); - Assert.assertEquals(encoder7, Encoders.INTEGER); + IEncoder encoder1 = EncoderResolver.resolveFunction(ITest0.class, new CTest1()); + Assert.assertEquals(encoder1, Encoders.STRING); - IEncoder encoder8 = EncoderResolver.resolveFunction(ITest0.class, new CTest8()); - Assert.assertEquals(encoder8, Encoders.BOOLEAN); + IEncoder encoder2 = EncoderResolver.resolveFunction(ITest0.class, new CTest2()); + Assert.assertEquals(encoder2, Encoders.LONG); - IEncoder encoder9 = EncoderResolver.resolveFunction(ITest0.class, new CTest9<>()); - Assert.assertEquals(encoder9, Encoders.BOOLEAN); + IEncoder encoder3 = EncoderResolver.resolveFunction(ITest0.class, new CTest3()); + Assert.assertEquals(encoder3, Encoders.INTEGER); - IEncoder encoder10 = EncoderResolver.resolveFunction(ITest0.class, new CTest10<>()); - Assert.assertNull(encoder10); + IEncoder encoder4 = EncoderResolver.resolveFunction(ITest0.class, new CTest4()); + Assert.assertEquals(encoder4, Encoders.STRING); - IEncoder encoder11 = EncoderResolver.resolveFunction(ITest0.class, new CTest11<>()); - Assert.assertNull(encoder11); + IEncoder encoder5 = EncoderResolver.resolveFunction(ITest0.class, new CTest5()); + Assert.assertEquals(encoder5, Encoders.LONG); - IEncoder encoder12_0 = EncoderResolver.resolveFunction(ITest1.class, new CTest12(), 0); - Assert.assertEquals(encoder12_0, Encoders.STRING); - IEncoder encoder12_1 = EncoderResolver.resolveFunction(ITest1.class, new CTest12(), 1); - Assert.assertEquals(encoder12_1, Encoders.INTEGER); + IEncoder encoder6 = EncoderResolver.resolveFunction(ITest0.class, new CTest6<>()); + Assert.assertEquals(encoder6, Encoders.INTEGER); - IEncoder encoder13_0 = EncoderResolver.resolveFunction(ITest1.class, new CTest13<>(), 0); - IEncoder encoder13_1 = EncoderResolver.resolveFunction(ITest1.class, new CTest13<>(), 1); - IEncoder encoder13_2 = EncoderResolver.resolveFunction(ITest1.class, new CTest13(), 0); - IEncoder encoder13_3 = EncoderResolver.resolveFunction(ITest1.class, new CTest13(), 1); - IEncoder encoder13_4 = EncoderResolver.resolveFunction(ITest1.class, new CTest13() { - }, 0); - Assert.assertNull(encoder13_0); - Assert.assertEquals(encoder13_1, Encoders.LONG); - Assert.assertNull(encoder13_2); - Assert.assertEquals(encoder13_3, Encoders.LONG); - Assert.assertEquals(encoder13_4, Encoders.INTEGER); + IEncoder encoder7 = EncoderResolver.resolveFunction(ITest0.class, new CTest7()); + Assert.assertEquals(encoder7, Encoders.INTEGER); - IEncoder encoder15 = EncoderResolver.resolveFunction(ITest1.class, new CTest15(), 1); - Assert.assertEquals(encoder15, Encoders.DOUBLE); - IEncoder encoder16_0 = EncoderResolver.resolveFunction(ITest1.class, new CTest16(), 0); - IEncoder encoder16_1 = EncoderResolver.resolveFunction(ITest1.class, new CTest16(), 1); - Assert.assertEquals(encoder16_0, Encoders.INTEGER); - Assert.assertEquals(encoder16_1, Encoders.FLOAT); - } + IEncoder encoder8 = EncoderResolver.resolveFunction(ITest0.class, new CTest8()); + Assert.assertEquals(encoder8, Encoders.BOOLEAN); - public interface ITest0 { - } + IEncoder encoder9 = EncoderResolver.resolveFunction(ITest0.class, new CTest9<>()); + Assert.assertEquals(encoder9, Encoders.BOOLEAN); - public static abstract class ACTest0 implements ITest0 { - } + IEncoder encoder10 = EncoderResolver.resolveFunction(ITest0.class, new CTest10<>()); + Assert.assertNull(encoder10); - public static abstract class ACTest1 implements ITest0 { - } + IEncoder encoder11 = EncoderResolver.resolveFunction(ITest0.class, new CTest11<>()); + Assert.assertNull(encoder11); - public static class CTest0 implements ITest0 { - } - - public static class CTest1 extends ACTest0 { - } + IEncoder encoder12_0 = EncoderResolver.resolveFunction(ITest1.class, new CTest12(), 0); + Assert.assertEquals(encoder12_0, Encoders.STRING); + IEncoder encoder12_1 = EncoderResolver.resolveFunction(ITest1.class, new CTest12(), 1); + Assert.assertEquals(encoder12_1, Encoders.INTEGER); - public static class CTest2 extends ACTest1 { - } + IEncoder encoder13_0 = EncoderResolver.resolveFunction(ITest1.class, new CTest13<>(), 0); + IEncoder encoder13_1 = EncoderResolver.resolveFunction(ITest1.class, new CTest13<>(), 1); + IEncoder encoder13_2 = + EncoderResolver.resolveFunction(ITest1.class, new CTest13(), 0); + IEncoder encoder13_3 = + EncoderResolver.resolveFunction(ITest1.class, new CTest13(), 1); + IEncoder encoder13_4 = + EncoderResolver.resolveFunction(ITest1.class, new CTest13() {}, 0); + Assert.assertNull(encoder13_0); + Assert.assertEquals(encoder13_1, Encoders.LONG); + Assert.assertNull(encoder13_2); + Assert.assertEquals(encoder13_3, Encoders.LONG); + Assert.assertEquals(encoder13_4, Encoders.INTEGER); - public static class CTest3 extends CTest0 { - } + IEncoder encoder15 = + EncoderResolver.resolveFunction(ITest1.class, new CTest15(), 1); + Assert.assertEquals(encoder15, Encoders.DOUBLE); + IEncoder encoder16_0 = EncoderResolver.resolveFunction(ITest1.class, new CTest16(), 0); + IEncoder encoder16_1 = EncoderResolver.resolveFunction(ITest1.class, new CTest16(), 1); + Assert.assertEquals(encoder16_0, Encoders.INTEGER); + Assert.assertEquals(encoder16_1, Encoders.FLOAT); + } - public static class CTest4 extends CTest1 { - } + public interface ITest0 {} - public static class CTest5 extends CTest2 { - } + public abstract static class ACTest0 implements ITest0 {} - public static abstract class ACTest2 implements ITest0 { - } + public abstract static class ACTest1 implements ITest0 {} - public static class CTest6 extends ACTest2 { - } + public static class CTest0 implements ITest0 {} - public static class CTest7 extends CTest6 { - } + public static class CTest1 extends ACTest0 {} - public static class CTest8 extends ACTest2 { - } + public static class CTest2 extends ACTest1 {} - public static class CTest9 extends ACTest2 { - } + public static class CTest3 extends CTest0 {} - public static class CTest10 extends ACTest2 { - } + public static class CTest4 extends CTest1 {} - public static abstract class ACTest3 implements ITest0 { - } + public static class CTest5 extends CTest2 {} - public static class CTest11 extends ACTest3 { - } + public abstract static class ACTest2 implements ITest0 {} - public interface ITest1 { - } + public static class CTest6 extends ACTest2 {} - public static class CTest12 implements ITest1 { - } + public static class CTest7 extends CTest6 {} - public static class CTest13 implements ITest1 { - } + public static class CTest8 extends ACTest2 {} - public static abstract class ACTest4 implements ITest1 { - } + public static class CTest9 extends ACTest2 {} - public static class CTest15 extends ACTest4 { - } + public static class CTest10 extends ACTest2 {} - public static class CTest16 extends ACTest4 { - } + public abstract static class ACTest3 implements ITest0 {} - @Test - public void testResolveTuple() { - IEncoder encoder17 = EncoderResolver.resolveFunction(ITest0.class, new CTest17()); - Assert.assertNotNull(encoder17); - Assert.assertEquals(encoder17.getClass(), TupleEncoder.class); + public static class CTest11 extends ACTest3 {} - IEncoder encoder18 = EncoderResolver.resolveFunction(ITest0.class, new CTest18()); - Assert.assertNotNull(encoder18); - Assert.assertEquals(encoder18.getClass(), TupleEncoder.class); + public interface ITest1 {} - IEncoder encoder19 = EncoderResolver.resolveFunction(ITest0.class, new CTest19()); - Assert.assertNotNull(encoder19); - Assert.assertEquals(encoder19.getClass(), TupleEncoder.class); + public static class CTest12 implements ITest1 {} - IEncoder encoder20 = EncoderResolver.resolveFunction(ITest0.class, new CTest20()); - Assert.assertNotNull(encoder20); - Assert.assertEquals(encoder20.getClass(), TupleEncoder.class); + public static class CTest13 implements ITest1 {} - IEncoder encoder21 = EncoderResolver.resolveFunction(ITest0.class, new CTest21()); - Assert.assertNotNull(encoder21); - Assert.assertEquals(encoder21.getClass(), TupleEncoder.class); - } + public abstract static class ACTest4 implements ITest1 {} - @Test - public void testResolveTriple() { - IEncoder encoder22 = EncoderResolver.resolveFunction(ITest0.class, new CTest22()); - Assert.assertNotNull(encoder22); - Assert.assertEquals(encoder22.getClass(), TripleEncoder.class); - } + public static class CTest15 extends ACTest4 {} - public static class CTest17 implements ITest0> { - } + public static class CTest16 extends ACTest4 {} - public static class CTest18 extends ACTest1> { - } + @Test + public void testResolveTuple() { + IEncoder encoder17 = EncoderResolver.resolveFunction(ITest0.class, new CTest17()); + Assert.assertNotNull(encoder17); + Assert.assertEquals(encoder17.getClass(), TupleEncoder.class); - public static class CTest19 extends ACTest1>> { - } + IEncoder encoder18 = EncoderResolver.resolveFunction(ITest0.class, new CTest18()); + Assert.assertNotNull(encoder18); + Assert.assertEquals(encoder18.getClass(), TupleEncoder.class); - public static class MyTuple0 extends Tuple { + IEncoder encoder19 = EncoderResolver.resolveFunction(ITest0.class, new CTest19()); + Assert.assertNotNull(encoder19); + Assert.assertEquals(encoder19.getClass(), TupleEncoder.class); - public MyTuple0(String s, Integer integer) { - super(s, integer); - } + IEncoder encoder20 = EncoderResolver.resolveFunction(ITest0.class, new CTest20()); + Assert.assertNotNull(encoder20); + Assert.assertEquals(encoder20.getClass(), TupleEncoder.class); - } + IEncoder encoder21 = EncoderResolver.resolveFunction(ITest0.class, new CTest21()); + Assert.assertNotNull(encoder21); + Assert.assertEquals(encoder21.getClass(), TupleEncoder.class); + } - public static class CTest20 extends ACTest1 { - } + @Test + public void testResolveTriple() { + IEncoder encoder22 = EncoderResolver.resolveFunction(ITest0.class, new CTest22()); + Assert.assertNotNull(encoder22); + Assert.assertEquals(encoder22.getClass(), TripleEncoder.class); + } - public static class MyTuple1 extends Tuple { + public static class CTest17 implements ITest0> {} - public MyTuple1(T0 t0, T1 t1) { - super(t0, t1); - } + public static class CTest18 extends ACTest1> {} - } + public static class CTest19 extends ACTest1>> {} - public static class CTest21 extends ACTest1> { - } + public static class MyTuple0 extends Tuple { - public static class CTest22 extends ACTest1> { + public MyTuple0(String s, Integer integer) { + super(s, integer); } + } - @SuppressWarnings("unchecked") - @Test - public void testPojoEncoder0() throws Exception { - Random random = new Random(); - random.setSeed(37); - PojoEncoder encoder = (PojoEncoder) EncoderResolver.resolvePojo(Pojo0.class); - encoder.init(new Configuration()); - - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - for (int i = 0; i < 100; i++) { - int b = random.nextInt(); - String a = "abc" + b; - long c = random.nextLong(); - encoder.encode(new Pojo0(a, b, c), bos); - } - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - - random.setSeed(37); - while (bis.available() > 0) { - int b = random.nextInt(); - String a = "abc" + b; - long c = random.nextLong(); - Pojo0 res = encoder.decode(bis); - Assert.assertEquals(res.getA(), a); - Assert.assertEquals((int) res.getB(), b); - Assert.assertEquals(res.getC(), c); - } + public static class CTest20 extends ACTest1 {} - } + public static class MyTuple1 extends Tuple { - @Test(expectedExceptions = {GeaflowRuntimeException.class}) - public void testPojoEncoder1() { - EncoderResolver.analysisPojo(Pojo1.class); + public MyTuple1(T0 t0, T1 t1) { + super(t0, t1); } + } - @Test(expectedExceptions = {GeaflowRuntimeException.class}) - public void testPojoEncoder2() { - EncoderResolver.analysisPojo(Pojo2.class); - } + public static class CTest21 extends ACTest1> {} - @Test(expectedExceptions = {GeaflowRuntimeException.class}) - public void testPojoEncoder3() { - EncoderResolver.analysisPojo(Pojo3.class); - } + public static class CTest22 extends ACTest1> {} - @Test(expectedExceptions = {GeaflowRuntimeException.class}) - public void testPojoEncoder4() { - EncoderResolver.analysisPojo(Pojo4.class); - } + @SuppressWarnings("unchecked") + @Test + public void testPojoEncoder0() throws Exception { + Random random = new Random(); + random.setSeed(37); + PojoEncoder encoder = (PojoEncoder) EncoderResolver.resolvePojo(Pojo0.class); + encoder.init(new Configuration()); - @Test(expectedExceptions = {GeaflowRuntimeException.class}) - public void testPojoEncoder5() { - EncoderResolver.analysisPojo(Pojo5.class); + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + for (int i = 0; i < 100; i++) { + int b = random.nextInt(); + String a = "abc" + b; + long c = random.nextLong(); + encoder.encode(new Pojo0(a, b, c), bos); } - @Test(expectedExceptions = {GeaflowRuntimeException.class}) - public void testPojoEncoder6() { - EncoderResolver.analysisPojo(Pojo6.class); - } + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); - @Test(expectedExceptions = {GeaflowRuntimeException.class}) - public void testPojoEncoder7() { - EncoderResolver.analysisPojo(Pojo7.class); + random.setSeed(37); + while (bis.available() > 0) { + int b = random.nextInt(); + String a = "abc" + b; + long c = random.nextLong(); + Pojo0 res = encoder.decode(bis); + Assert.assertEquals(res.getA(), a); + Assert.assertEquals((int) res.getB(), b); + Assert.assertEquals(res.getC(), c); } + } - @Test(expectedExceptions = {GeaflowRuntimeException.class}) - public void testPojoEncoder8() { - EncoderResolver.analysisPojo(Pojo8.class); - } + @Test(expectedExceptions = {GeaflowRuntimeException.class}) + public void testPojoEncoder1() { + EncoderResolver.analysisPojo(Pojo1.class); + } - @Test(expectedExceptions = {GeaflowRuntimeException.class}) - public void testPojoEncoder9() { - EncoderResolver.analysisPojo(Pojo9.class); - } + @Test(expectedExceptions = {GeaflowRuntimeException.class}) + public void testPojoEncoder2() { + EncoderResolver.analysisPojo(Pojo2.class); + } - public static class Pojo0 { + @Test(expectedExceptions = {GeaflowRuntimeException.class}) + public void testPojoEncoder3() { + EncoderResolver.analysisPojo(Pojo3.class); + } - private String a; - private Integer b; - private long c; + @Test(expectedExceptions = {GeaflowRuntimeException.class}) + public void testPojoEncoder4() { + EncoderResolver.analysisPojo(Pojo4.class); + } - public Pojo0() { - } + @Test(expectedExceptions = {GeaflowRuntimeException.class}) + public void testPojoEncoder5() { + EncoderResolver.analysisPojo(Pojo5.class); + } - public Pojo0(String a, Integer b, long c) { - this.a = a; - this.b = b; - this.c = c; - } + @Test(expectedExceptions = {GeaflowRuntimeException.class}) + public void testPojoEncoder6() { + EncoderResolver.analysisPojo(Pojo6.class); + } - public String getA() { - return this.a; - } + @Test(expectedExceptions = {GeaflowRuntimeException.class}) + public void testPojoEncoder7() { + EncoderResolver.analysisPojo(Pojo7.class); + } - public void setA(String a) { - this.a = a; - } + @Test(expectedExceptions = {GeaflowRuntimeException.class}) + public void testPojoEncoder8() { + EncoderResolver.analysisPojo(Pojo8.class); + } - public Integer getB() { - return this.b; - } + @Test(expectedExceptions = {GeaflowRuntimeException.class}) + public void testPojoEncoder9() { + EncoderResolver.analysisPojo(Pojo9.class); + } - public void setB(Integer b) { - this.b = b; - } + public static class Pojo0 { - public long getC() { - return this.c; - } + private String a; + private Integer b; + private long c; - public void setC(long c) { - this.c = c; - } + public Pojo0() {} + public Pojo0(String a, Integer b, long c) { + this.a = a; + this.b = b; + this.c = c; } - public static class Pojo1 { - - private Pojo1() { - } + public String getA() { + return this.a; + } + public void setA(String a) { + this.a = a; } - private static class Pojo2 { + public Integer getB() { + return this.b; } - public static abstract class Pojo3 { + public void setB(Integer b) { + this.b = b; } - public interface Pojo4 { + public long getC() { + return this.c; } - public static class Pojo5 extends Pojo1 { + public void setC(long c) { + this.c = c; + } + } - private int a; + public static class Pojo1 { - public int getA() { - return this.a; - } + private Pojo1() {} + } - public void setA(int a) { - this.a = a; - } + private static class Pojo2 {} - } + public abstract static class Pojo3 {} - public static class Pojo6 { + public interface Pojo4 {} - private int a; + public static class Pojo5 extends Pojo1 { - public int getA() { - return this.a; - } + private int a; + public int getA() { + return this.a; } - public static class Pojo7 { + public void setA(int a) { + this.a = a; + } + } - private int a; + public static class Pojo6 { - public void setA(int a) { - this.a = a; - } + private int a; + public int getA() { + return this.a; } + } - public static class Pojo8 { - } + public static class Pojo7 { - public static class Pojo9 { + private int a; - private int a; + public void setA(int a) { + this.a = a; + } + } - public Pojo9(int a) { - this.a = a; - } + public static class Pojo8 {} - public int getA() { - return this.a; - } + public static class Pojo9 { - public void setA(int a) { - this.a = a; - } + private int a; + public Pojo9(int a) { + this.a = a; } - @Test - public void testEnumEncoder() { - IEncoder encoder = EncoderResolver.resolveClass(TestEnum.class); - Assert.assertEquals(encoder.getClass(), EnumEncoder.class); + public int getA() { + return this.a; } - public enum TestEnum { - A, B, C, D, E + public void setA(int a) { + this.a = a; } + } + + @Test + public void testEnumEncoder() { + IEncoder encoder = EncoderResolver.resolveClass(TestEnum.class); + Assert.assertEquals(encoder.getClass(), EnumEncoder.class); + } + public enum TestEnum { + A, + B, + C, + D, + E + } } diff --git a/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/encoder/EncoderTest.java b/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/encoder/EncoderTest.java index 846d1d5cd..2a37b4e91 100644 --- a/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/encoder/EncoderTest.java +++ b/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/encoder/EncoderTest.java @@ -22,6 +22,7 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.encoder.impl.CharacterEncoder; import org.apache.geaflow.common.encoder.impl.DoubleEncoder; @@ -40,616 +41,599 @@ public class EncoderTest { - @Test - public void testBooleanEncoder() { - IEncoder encoder = Encoders.BOOLEAN; - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - try { - encoder.encode(null, bos); - encoder.encode(true, bos); - encoder.encode(false, bos); - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - Boolean d1 = encoder.decode(bis); - Boolean d2 = encoder.decode(bis); - Boolean d3 = encoder.decode(bis); - Assert.assertNull(d1); - Assert.assertEquals(d2, Boolean.TRUE); - Assert.assertEquals(d3, Boolean.FALSE); - } catch (Exception e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); - } - } - - @Test - public void testBooleanArrEncoder() { - IEncoder encoder = Encoders.BOOLEAN_ARR; - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - boolean[] booleans = {true, false}; - try { - encoder.encode(null, bos); - encoder.encode(booleans, bos); - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - boolean[] d1 = encoder.decode(bis); - boolean[] d2 = encoder.decode(bis); - Assert.assertNull(d1); - Assert.assertTrue(d2[0]); - Assert.assertFalse(d2[1]); - } catch (Exception e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); - } - } - - - @Test - public void testByteEncoder() { - try { - IEncoder encoder = Encoders.BYTE; - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - for (int i = Byte.MIN_VALUE; i <= Byte.MAX_VALUE; i++) { - encoder.encode((byte) i, bos); - } - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - - int c = Byte.MIN_VALUE; - while (bis.available() > 0) { - byte res = encoder.decode(bis); - Assert.assertEquals(res, c); - c++; - } - } catch (IOException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); - } - } - - @Test - public void testByteArrEncoder() { - try { - IEncoder encoder = Encoders.BYTE_ARR; - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - byte[] bytes = new byte[256]; - byte v = Byte.MIN_VALUE; - for (int i = 0; i < 256; i++) { - bytes[i] = v; - v++; - } - encoder.encode(null, bos); - encoder.encode(bytes, bos); - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - - byte[] bytes1 = encoder.decode(bis); - Assert.assertNull(bytes1); - - byte[] bytes2 = encoder.decode(bis); - v = Byte.MIN_VALUE; - for (int i = 0; i < 256; i++) { - Assert.assertEquals(bytes2[i], v); - v++; - } - } catch (IOException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); - } - } - - @Test - public void testShortEncoder() { - try { - doTestShortEncoder(ShortEncoder.INSTANCE, 1); - doTestShortEncoder(ShortEncoder.INSTANCE, -1); - doTestShortEncoder(Encoders.SHORT, 1); - doTestShortEncoder(Encoders.SHORT, -1); - } catch (IOException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); - } - } - - private static void doTestShortEncoder(IEncoder encoder, int flag) throws IOException { - int offset = flag > 0 ? 1 : 0; - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - for (int i = 0; i < 16; i++) { - int v = (flag << i) - offset; - encoder.encode((short) v, bos); - } - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - - int c = 0; - while (bis.available() > 0) { - int v = (flag << c) - offset; - short res = encoder.decode(bis); - Assert.assertEquals(v, res); - c++; - } - } - - @Test - public void testShortArrEncoder() { - try { - int n = 16; - IEncoder encoder = Encoders.SHORT_ARR; - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - short[] shorts = new short[n]; - for (int i = 0; i < n; i++) { - int v = (1 << i) - 1; - shorts[i] = (short) v; - } - - encoder.encode(null, bos); - encoder.encode(shorts, bos); - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - - short[] shorts1 = encoder.decode(bis); - Assert.assertNull(shorts1); - - short[] shorts2 = encoder.decode(bis); - for (int i = 0; i < n; i++) { - int v = (1 << i) - 1; - Assert.assertEquals(shorts2[i], v); - } - } catch (IOException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); - } - } - - @Test - public void testIntegerEncoder() { - try { - doTestIntegerEncoder(Encoders.INTEGER, 1); - doTestIntegerEncoder(Encoders.INTEGER, -1); - doTestIntegerEncoder(IntegerEncoder.INSTANCE, 1); - doTestIntegerEncoder(IntegerEncoder.INSTANCE, -1); - } catch (IOException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); - } - } - - private static void doTestIntegerEncoder(IEncoder encoder, int flag) throws IOException { - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - for (int i = 0; i < 32; i++) { - int v = flag << i; - encoder.encode(v, bos); - } - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - - int c = 0; - while (bis.available() > 0) { - int v = flag << c; - int res = encoder.decode(bis); - Assert.assertEquals(v, res); - c++; - } - } - - @Test - public void testIntArrEncoder() { - try { - int n = 32; - IEncoder encoder = Encoders.INTEGER_ARR; - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - int[] ints = new int[n]; - for (int i = 0; i < n; i++) { - int v = (1 << i) - 1; - ints[i] = v; - } - - encoder.encode(null, bos); - encoder.encode(ints, bos); - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - - int[] ints1 = encoder.decode(bis); - Assert.assertNull(ints1); - - int[] ints2 = encoder.decode(bis); - for (int i = 0; i < n; i++) { - int v = (1 << i) - 1; - Assert.assertEquals(ints2[i], v); - } - } catch (IOException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); - } - } - - @Test - public void testLongEncoder() { - try { - doTestLongEncoder(Encoders.LONG, 1); - doTestLongEncoder(Encoders.LONG, -1); - doTestLongEncoder(LongEncoder.INSTANCE, 1); - doTestLongEncoder(LongEncoder.INSTANCE, -1); - } catch (IOException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); - } - } - - private static void doTestLongEncoder(IEncoder encoder, long flag) throws IOException { - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - for (int i = 0; i < 64; i++) { - long v = flag << i; - encoder.encode(v, bos); - } - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - - int c = 0; - while (bis.available() > 0) { - long v = flag << c; - long res = encoder.decode(bis); - Assert.assertEquals(v, res); - c++; - } - } - - @Test - public void testLongArrEncoder() { - try { - int n = 64; - IEncoder encoder = Encoders.LONG_ARR; - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - long[] longs = new long[n]; - for (int i = 0; i < n; i++) { - long v = (1L << i) - 1; - longs[i] = v; - } - - encoder.encode(null, bos); - encoder.encode(longs, bos); - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - - long[] longs1 = encoder.decode(bis); - Assert.assertNull(longs1); - - long[] longs2 = encoder.decode(bis); - for (int i = 0; i < n; i++) { - long v = (1L << i) - 1; - Assert.assertEquals(longs2[i], v); - } - } catch (IOException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); - } - } - - @Test - public void testFloatEncoder() { - try { - doTestFloatEncoder(Encoders.FLOAT, 1.25f); - doTestFloatEncoder(Encoders.FLOAT, -1.25f); - doTestFloatEncoder(FloatEncoder.INSTANCE, 1.25f); - doTestFloatEncoder(FloatEncoder.INSTANCE, -1.25f); - } catch (IOException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); - } - } - - private void doTestFloatEncoder(IEncoder encoder, float flag) throws IOException { - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - for (int i = 0; i < 32; i++) { - float v = flag * (1 << i); - encoder.encode(v, bos); - } - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - - int c = 0; - while (bis.available() > 0) { - float v = flag * (1 << c); - float res = encoder.decode(bis); - Assert.assertEquals(v, res); - c++; - } - } - - @Test - public void testFloatArrEncoder() { - try { - int n = 32; - IEncoder encoder = Encoders.FLOAT_ARR; - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - float[] floats = new float[n]; - for (int i = 0; i < n; i++) { - float v = (1 << i) - 1; - floats[i] = v; - } - - encoder.encode(null, bos); - encoder.encode(floats, bos); - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - - float[] floats1 = encoder.decode(bis); - Assert.assertNull(floats1); - - float[] floats2 = encoder.decode(bis); - for (int i = 0; i < n; i++) { - float v = (1 << i) - 1; - Assert.assertEquals(floats2[i], v); - } - } catch (IOException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); - } - } - - @Test - public void testDoubleEncoder() { - try { - doTestDoubleEncoder(Encoders.DOUBLE, 1.25); - doTestDoubleEncoder(Encoders.DOUBLE, -1.25); - doTestDoubleEncoder(DoubleEncoder.INSTANCE, 1.25); - doTestDoubleEncoder(DoubleEncoder.INSTANCE, -1.25); - } catch (IOException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); - } - } - - private void doTestDoubleEncoder(IEncoder encoder, double flag) throws IOException { - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - for (int i = 0; i < 64; i++) { - double v = flag * (1 << i); - encoder.encode(v, bos); - } - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - - int c = 0; - while (bis.available() > 0) { - double v = flag * (1 << c); - double res = encoder.decode(bis); - Assert.assertEquals(v, res); - c++; - } - } - - @Test - public void testDoubleArrEncoder() { - try { - int n = 64; - IEncoder encoder = Encoders.DOUBLE_ARR; - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - double[] doubles = new double[n]; - for (int i = 0; i < n; i++) { - double v = (1L << i) - 1; - doubles[i] = v; - } - - encoder.encode(null, bos); - encoder.encode(doubles, bos); - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - - double[] doubles1 = encoder.decode(bis); - Assert.assertNull(doubles1); - - double[] doubles2 = encoder.decode(bis); - for (int i = 0; i < n; i++) { - double v = (1L << i) - 1; - Assert.assertEquals(doubles2[i], v); - } - } catch (IOException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); - } - } - - @Test - public void testCharEncoder() { - try { - doTestCharEncoder(CharacterEncoder.INSTANCE); - doTestCharEncoder(Encoders.CHARACTER); - } catch (IOException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); - } - } - - private static void doTestCharEncoder(IEncoder encoder) throws IOException { - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - for (int i = 0; i < 16; i++) { - int v = 1 << i; - encoder.encode((char) v, bos); - } - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - - int c = 0; - while (bis.available() > 0) { - int v = 1 << c; - char res = encoder.decode(bis); - Assert.assertEquals(v, res); - c++; - } - } - - @Test - public void testCharArrEncoder() { - try { - int n = 16; - IEncoder encoder = Encoders.CHARACTER_ARR; - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - char[] chars = new char[n]; - for (int i = 0; i < n; i++) { - int v = (1 << i) - 1; - chars[i] = (char) v; - } - - encoder.encode(null, bos); - encoder.encode(chars, bos); - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - - char[] chars1 = encoder.decode(bis); - Assert.assertNull(chars1); - - char[] chars2 = encoder.decode(bis); - for (int i = 0; i < n; i++) { - int v = (1 << i) - 1; - Assert.assertEquals(chars2[i], v); - } - } catch (IOException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); - } - } - - @Test - public void testStringEncoder() throws Exception { - String[] strings = { - "abc", - "abc", - null, - null, - null, - "[ 007]", - "!@#$%^&*()_=-", - "abc", - "abc", - "abc", - "让世界没有难用的图" - }; - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - - for (String s : strings) { - Encoders.STRING.encode(s, bos); - } - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - - int i = 0; - while (bis.available() > 0) { - String res = Encoders.STRING.decode(bis); - if (i == 10) { - Assert.assertEquals(strings[i], res, "" + strings[i].length() + " " + res.length()); - } else { - Assert.assertEquals(strings[i], res); - } - i++; - } - } - - @Test - public void testStringArrEncoder() throws Exception { - String[] strings = { - "abc", - "abc", - null, - null, - null, - "[ 007]", - "!@#$%^&*()_=-", - "abc", - "abc", - "abc", - "让世界没有难用的图" - }; - IEncoder encoder = new GenericArrayEncoder<>(Encoders.STRING, String[]::new); - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - encoder.encode(null, bos); - encoder.encode(strings, bos); - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - String[] strings1 = encoder.decode(bis); - Assert.assertNull(strings1); - String[] strings2 = encoder.decode(bis); - for (int i = 0; i < strings.length; i++) { - Assert.assertEquals(strings[i], strings2[i], "not right " + i); - } - } - - @Test - public void testEnumEncoder() throws Exception { - IEncoder encoder = new EnumEncoder<>(TestEnum.class); - encoder.init(new Configuration()); - - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - encoder.encode(null, bos); - encoder.encode(TestEnum.A, bos); - encoder.encode(TestEnum.B, bos); - encoder.encode(TestEnum.C, bos); - encoder.encode(TestEnum.D, bos); - encoder.encode(TestEnum.E, bos); - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - - TestEnum e0 = encoder.decode(bis); - TestEnum e1 = encoder.decode(bis); - TestEnum e2 = encoder.decode(bis); - TestEnum e3 = encoder.decode(bis); - TestEnum e4 = encoder.decode(bis); - TestEnum e5 = encoder.decode(bis); - Assert.assertNull(e0); - Assert.assertEquals(e1, TestEnum.A); - Assert.assertEquals(e2, TestEnum.B); - Assert.assertEquals(e3, TestEnum.C); - Assert.assertEquals(e4, TestEnum.D); - Assert.assertEquals(e5, TestEnum.E); - } - - public enum TestEnum { - A, B, C, D, E - } - - @Test - public void testTupleEncoder() throws IOException { - IEncoder> tupleEncoder = Encoders.tuple(Encoders.INTEGER, Encoders.INTEGER); - Assert.assertNotNull(tupleEncoder); - tupleEncoder.init(new Configuration()); - - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - for (int i = 0; i < 10; i++) { - Tuple data = Tuple.of(1, 1); - tupleEncoder.encode(data, bos); - } - tupleEncoder.encode(null, bos); - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - for (int i = 0; i < 10; i++) { - Tuple data = tupleEncoder.decode(bis); - Assert.assertNotNull(data); - } - Assert.assertNull(tupleEncoder.decode(bis)); - } - - @Test - public void testTripleEncoder() throws IOException { - IEncoder> tripleEncoder = - Encoders.triple(Encoders.INTEGER, Encoders.INTEGER, Encoders.INTEGER); - Assert.assertNotNull(tripleEncoder); - tripleEncoder.init(new Configuration()); - - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - for (int i = 0; i < 10; i++) { - Triple data = Triple.of(1, 1, 1); - tripleEncoder.encode(data, bos); - } - tripleEncoder.encode(null, bos); - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - for (int i = 0; i < 10; i++) { - Triple data = tripleEncoder.decode(bis); - Assert.assertNotNull(data); - } - Assert.assertNull(tripleEncoder.decode(bis)); + @Test + public void testBooleanEncoder() { + IEncoder encoder = Encoders.BOOLEAN; + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + try { + encoder.encode(null, bos); + encoder.encode(true, bos); + encoder.encode(false, bos); + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + Boolean d1 = encoder.decode(bis); + Boolean d2 = encoder.decode(bis); + Boolean d3 = encoder.decode(bis); + Assert.assertNull(d1); + Assert.assertEquals(d2, Boolean.TRUE); + Assert.assertEquals(d3, Boolean.FALSE); + } catch (Exception e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); + } + } + + @Test + public void testBooleanArrEncoder() { + IEncoder encoder = Encoders.BOOLEAN_ARR; + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + boolean[] booleans = {true, false}; + try { + encoder.encode(null, bos); + encoder.encode(booleans, bos); + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + boolean[] d1 = encoder.decode(bis); + boolean[] d2 = encoder.decode(bis); + Assert.assertNull(d1); + Assert.assertTrue(d2[0]); + Assert.assertFalse(d2[1]); + } catch (Exception e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); + } + } + + @Test + public void testByteEncoder() { + try { + IEncoder encoder = Encoders.BYTE; + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + for (int i = Byte.MIN_VALUE; i <= Byte.MAX_VALUE; i++) { + encoder.encode((byte) i, bos); + } + + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + + int c = Byte.MIN_VALUE; + while (bis.available() > 0) { + byte res = encoder.decode(bis); + Assert.assertEquals(res, c); + c++; + } + } catch (IOException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); + } + } + + @Test + public void testByteArrEncoder() { + try { + IEncoder encoder = Encoders.BYTE_ARR; + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + byte[] bytes = new byte[256]; + byte v = Byte.MIN_VALUE; + for (int i = 0; i < 256; i++) { + bytes[i] = v; + v++; + } + encoder.encode(null, bos); + encoder.encode(bytes, bos); + + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + + byte[] bytes1 = encoder.decode(bis); + Assert.assertNull(bytes1); + + byte[] bytes2 = encoder.decode(bis); + v = Byte.MIN_VALUE; + for (int i = 0; i < 256; i++) { + Assert.assertEquals(bytes2[i], v); + v++; + } + } catch (IOException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); + } + } + + @Test + public void testShortEncoder() { + try { + doTestShortEncoder(ShortEncoder.INSTANCE, 1); + doTestShortEncoder(ShortEncoder.INSTANCE, -1); + doTestShortEncoder(Encoders.SHORT, 1); + doTestShortEncoder(Encoders.SHORT, -1); + } catch (IOException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); + } + } + + private static void doTestShortEncoder(IEncoder encoder, int flag) throws IOException { + int offset = flag > 0 ? 1 : 0; + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + for (int i = 0; i < 16; i++) { + int v = (flag << i) - offset; + encoder.encode((short) v, bos); + } + + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + + int c = 0; + while (bis.available() > 0) { + int v = (flag << c) - offset; + short res = encoder.decode(bis); + Assert.assertEquals(v, res); + c++; + } + } + + @Test + public void testShortArrEncoder() { + try { + int n = 16; + IEncoder encoder = Encoders.SHORT_ARR; + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + short[] shorts = new short[n]; + for (int i = 0; i < n; i++) { + int v = (1 << i) - 1; + shorts[i] = (short) v; + } + + encoder.encode(null, bos); + encoder.encode(shorts, bos); + + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + + short[] shorts1 = encoder.decode(bis); + Assert.assertNull(shorts1); + + short[] shorts2 = encoder.decode(bis); + for (int i = 0; i < n; i++) { + int v = (1 << i) - 1; + Assert.assertEquals(shorts2[i], v); + } + } catch (IOException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); + } + } + + @Test + public void testIntegerEncoder() { + try { + doTestIntegerEncoder(Encoders.INTEGER, 1); + doTestIntegerEncoder(Encoders.INTEGER, -1); + doTestIntegerEncoder(IntegerEncoder.INSTANCE, 1); + doTestIntegerEncoder(IntegerEncoder.INSTANCE, -1); + } catch (IOException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); + } + } + + private static void doTestIntegerEncoder(IEncoder encoder, int flag) throws IOException { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + for (int i = 0; i < 32; i++) { + int v = flag << i; + encoder.encode(v, bos); + } + + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + + int c = 0; + while (bis.available() > 0) { + int v = flag << c; + int res = encoder.decode(bis); + Assert.assertEquals(v, res); + c++; + } + } + + @Test + public void testIntArrEncoder() { + try { + int n = 32; + IEncoder encoder = Encoders.INTEGER_ARR; + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + int[] ints = new int[n]; + for (int i = 0; i < n; i++) { + int v = (1 << i) - 1; + ints[i] = v; + } + + encoder.encode(null, bos); + encoder.encode(ints, bos); + + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + + int[] ints1 = encoder.decode(bis); + Assert.assertNull(ints1); + + int[] ints2 = encoder.decode(bis); + for (int i = 0; i < n; i++) { + int v = (1 << i) - 1; + Assert.assertEquals(ints2[i], v); + } + } catch (IOException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); + } + } + + @Test + public void testLongEncoder() { + try { + doTestLongEncoder(Encoders.LONG, 1); + doTestLongEncoder(Encoders.LONG, -1); + doTestLongEncoder(LongEncoder.INSTANCE, 1); + doTestLongEncoder(LongEncoder.INSTANCE, -1); + } catch (IOException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); + } + } + + private static void doTestLongEncoder(IEncoder encoder, long flag) throws IOException { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + for (int i = 0; i < 64; i++) { + long v = flag << i; + encoder.encode(v, bos); + } + + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + + int c = 0; + while (bis.available() > 0) { + long v = flag << c; + long res = encoder.decode(bis); + Assert.assertEquals(v, res); + c++; + } + } + + @Test + public void testLongArrEncoder() { + try { + int n = 64; + IEncoder encoder = Encoders.LONG_ARR; + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + long[] longs = new long[n]; + for (int i = 0; i < n; i++) { + long v = (1L << i) - 1; + longs[i] = v; + } + + encoder.encode(null, bos); + encoder.encode(longs, bos); + + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + + long[] longs1 = encoder.decode(bis); + Assert.assertNull(longs1); + + long[] longs2 = encoder.decode(bis); + for (int i = 0; i < n; i++) { + long v = (1L << i) - 1; + Assert.assertEquals(longs2[i], v); + } + } catch (IOException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); + } + } + + @Test + public void testFloatEncoder() { + try { + doTestFloatEncoder(Encoders.FLOAT, 1.25f); + doTestFloatEncoder(Encoders.FLOAT, -1.25f); + doTestFloatEncoder(FloatEncoder.INSTANCE, 1.25f); + doTestFloatEncoder(FloatEncoder.INSTANCE, -1.25f); + } catch (IOException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); + } + } + + private void doTestFloatEncoder(IEncoder encoder, float flag) throws IOException { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + for (int i = 0; i < 32; i++) { + float v = flag * (1 << i); + encoder.encode(v, bos); + } + + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + + int c = 0; + while (bis.available() > 0) { + float v = flag * (1 << c); + float res = encoder.decode(bis); + Assert.assertEquals(v, res); + c++; + } + } + + @Test + public void testFloatArrEncoder() { + try { + int n = 32; + IEncoder encoder = Encoders.FLOAT_ARR; + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + float[] floats = new float[n]; + for (int i = 0; i < n; i++) { + float v = (1 << i) - 1; + floats[i] = v; + } + + encoder.encode(null, bos); + encoder.encode(floats, bos); + + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + + float[] floats1 = encoder.decode(bis); + Assert.assertNull(floats1); + + float[] floats2 = encoder.decode(bis); + for (int i = 0; i < n; i++) { + float v = (1 << i) - 1; + Assert.assertEquals(floats2[i], v); + } + } catch (IOException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); + } + } + + @Test + public void testDoubleEncoder() { + try { + doTestDoubleEncoder(Encoders.DOUBLE, 1.25); + doTestDoubleEncoder(Encoders.DOUBLE, -1.25); + doTestDoubleEncoder(DoubleEncoder.INSTANCE, 1.25); + doTestDoubleEncoder(DoubleEncoder.INSTANCE, -1.25); + } catch (IOException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); + } + } + + private void doTestDoubleEncoder(IEncoder encoder, double flag) throws IOException { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + for (int i = 0; i < 64; i++) { + double v = flag * (1 << i); + encoder.encode(v, bos); + } + + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + + int c = 0; + while (bis.available() > 0) { + double v = flag * (1 << c); + double res = encoder.decode(bis); + Assert.assertEquals(v, res); + c++; + } + } + + @Test + public void testDoubleArrEncoder() { + try { + int n = 64; + IEncoder encoder = Encoders.DOUBLE_ARR; + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + double[] doubles = new double[n]; + for (int i = 0; i < n; i++) { + double v = (1L << i) - 1; + doubles[i] = v; + } + + encoder.encode(null, bos); + encoder.encode(doubles, bos); + + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + + double[] doubles1 = encoder.decode(bis); + Assert.assertNull(doubles1); + + double[] doubles2 = encoder.decode(bis); + for (int i = 0; i < n; i++) { + double v = (1L << i) - 1; + Assert.assertEquals(doubles2[i], v); + } + } catch (IOException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); + } + } + + @Test + public void testCharEncoder() { + try { + doTestCharEncoder(CharacterEncoder.INSTANCE); + doTestCharEncoder(Encoders.CHARACTER); + } catch (IOException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); + } + } + + private static void doTestCharEncoder(IEncoder encoder) throws IOException { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + for (int i = 0; i < 16; i++) { + int v = 1 << i; + encoder.encode((char) v, bos); + } + + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + + int c = 0; + while (bis.available() > 0) { + int v = 1 << c; + char res = encoder.decode(bis); + Assert.assertEquals(v, res); + c++; + } + } + + @Test + public void testCharArrEncoder() { + try { + int n = 16; + IEncoder encoder = Encoders.CHARACTER_ARR; + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + char[] chars = new char[n]; + for (int i = 0; i < n; i++) { + int v = (1 << i) - 1; + chars[i] = (char) v; + } + + encoder.encode(null, bos); + encoder.encode(chars, bos); + + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + + char[] chars1 = encoder.decode(bis); + Assert.assertNull(chars1); + + char[] chars2 = encoder.decode(bis); + for (int i = 0; i < n; i++) { + int v = (1 << i) - 1; + Assert.assertEquals(chars2[i], v); + } + } catch (IOException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); + } + } + + @Test + public void testStringEncoder() throws Exception { + String[] strings = { + "abc", "abc", null, null, null, "[ 007]", "!@#$%^&*()_=-", "abc", "abc", "abc", "让世界没有难用的图" + }; + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + + for (String s : strings) { + Encoders.STRING.encode(s, bos); + } + + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + + int i = 0; + while (bis.available() > 0) { + String res = Encoders.STRING.decode(bis); + if (i == 10) { + Assert.assertEquals(strings[i], res, "" + strings[i].length() + " " + res.length()); + } else { + Assert.assertEquals(strings[i], res); + } + i++; + } + } + + @Test + public void testStringArrEncoder() throws Exception { + String[] strings = { + "abc", "abc", null, null, null, "[ 007]", "!@#$%^&*()_=-", "abc", "abc", "abc", "让世界没有难用的图" + }; + IEncoder encoder = new GenericArrayEncoder<>(Encoders.STRING, String[]::new); + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + encoder.encode(null, bos); + encoder.encode(strings, bos); + + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + String[] strings1 = encoder.decode(bis); + Assert.assertNull(strings1); + String[] strings2 = encoder.decode(bis); + for (int i = 0; i < strings.length; i++) { + Assert.assertEquals(strings[i], strings2[i], "not right " + i); } + } + + @Test + public void testEnumEncoder() throws Exception { + IEncoder encoder = new EnumEncoder<>(TestEnum.class); + encoder.init(new Configuration()); + + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + encoder.encode(null, bos); + encoder.encode(TestEnum.A, bos); + encoder.encode(TestEnum.B, bos); + encoder.encode(TestEnum.C, bos); + encoder.encode(TestEnum.D, bos); + encoder.encode(TestEnum.E, bos); + + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + + TestEnum e0 = encoder.decode(bis); + TestEnum e1 = encoder.decode(bis); + TestEnum e2 = encoder.decode(bis); + TestEnum e3 = encoder.decode(bis); + TestEnum e4 = encoder.decode(bis); + TestEnum e5 = encoder.decode(bis); + Assert.assertNull(e0); + Assert.assertEquals(e1, TestEnum.A); + Assert.assertEquals(e2, TestEnum.B); + Assert.assertEquals(e3, TestEnum.C); + Assert.assertEquals(e4, TestEnum.D); + Assert.assertEquals(e5, TestEnum.E); + } + + public enum TestEnum { + A, + B, + C, + D, + E + } + + @Test + public void testTupleEncoder() throws IOException { + IEncoder> tupleEncoder = + Encoders.tuple(Encoders.INTEGER, Encoders.INTEGER); + Assert.assertNotNull(tupleEncoder); + tupleEncoder.init(new Configuration()); + + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + for (int i = 0; i < 10; i++) { + Tuple data = Tuple.of(1, 1); + tupleEncoder.encode(data, bos); + } + tupleEncoder.encode(null, bos); + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + for (int i = 0; i < 10; i++) { + Tuple data = tupleEncoder.decode(bis); + Assert.assertNotNull(data); + } + Assert.assertNull(tupleEncoder.decode(bis)); + } + + @Test + public void testTripleEncoder() throws IOException { + IEncoder> tripleEncoder = + Encoders.triple(Encoders.INTEGER, Encoders.INTEGER, Encoders.INTEGER); + Assert.assertNotNull(tripleEncoder); + tripleEncoder.init(new Configuration()); + + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + for (int i = 0; i < 10; i++) { + Triple data = Triple.of(1, 1, 1); + tripleEncoder.encode(data, bos); + } + tripleEncoder.encode(null, bos); + + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + for (int i = 0; i < 10; i++) { + Triple data = tripleEncoder.decode(bis); + Assert.assertNotNull(data); + } + Assert.assertNull(tripleEncoder.decode(bis)); + } } diff --git a/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/schema/FieldTest.java b/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/schema/FieldTest.java index 1e481a5e2..cd21b0513 100644 --- a/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/schema/FieldTest.java +++ b/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/schema/FieldTest.java @@ -25,47 +25,46 @@ public class FieldTest { - private static final String FIELD_NAME = "f"; + private static final String FIELD_NAME = "f"; - @Test - public void testField() { + @Test + public void testField() { - Field.BooleanField field0 = Field.newBooleanField(FIELD_NAME); - Assert.assertEquals(field0.getName(), FIELD_NAME); - Assert.assertEquals(field0.getType(), BooleanType.INSTANCE); - Assert.assertTrue(field0.isNullable()); - Assert.assertNull(field0.getDefaultValue()); - - Field.BooleanField field1 = Field.newBooleanField(FIELD_NAME, Boolean.TRUE, Boolean.FALSE); - Field.ByteField field2 = Field.newByteField(FIELD_NAME); - Field.ByteField field3 = Field.newByteField(FIELD_NAME, Boolean.TRUE, (byte) 0); - Field.ShortField field4 = Field.newShortField(FIELD_NAME); - Field.ShortField field5 = Field.newShortField(FIELD_NAME, Boolean.TRUE, (short) 0); - Field.IntegerField field6 = Field.newIntegerField(FIELD_NAME); - Field.IntegerField field7 = Field.newIntegerField(FIELD_NAME, Boolean.TRUE, 0); - Field.LongField field8 = Field.newLongField(FIELD_NAME); - Field.LongField field9 = Field.newLongField(FIELD_NAME, Boolean.TRUE, 0L); - Field.FloatField field10 = Field.newFloatField(FIELD_NAME); - Field.FloatField field11 = Field.newFloatField(FIELD_NAME, Boolean.TRUE, 0.0f); - Field.DoubleField field12 = Field.newDoubleField(FIELD_NAME); - Field.DoubleField field13 = Field.newDoubleField(FIELD_NAME, Boolean.TRUE, 0.0); - Field.StringField field14 = Field.newStringField(FIELD_NAME); - Field.StringField field15 = Field.newStringField(FIELD_NAME, Boolean.TRUE, ""); - Assert.assertNotNull(field1); - Assert.assertNotNull(field2); - Assert.assertNotNull(field3); - Assert.assertNotNull(field4); - Assert.assertNotNull(field5); - Assert.assertNotNull(field6); - Assert.assertNotNull(field7); - Assert.assertNotNull(field8); - Assert.assertNotNull(field9); - Assert.assertNotNull(field10); - Assert.assertNotNull(field11); - Assert.assertNotNull(field12); - Assert.assertNotNull(field13); - Assert.assertNotNull(field14); - Assert.assertNotNull(field15); - } + Field.BooleanField field0 = Field.newBooleanField(FIELD_NAME); + Assert.assertEquals(field0.getName(), FIELD_NAME); + Assert.assertEquals(field0.getType(), BooleanType.INSTANCE); + Assert.assertTrue(field0.isNullable()); + Assert.assertNull(field0.getDefaultValue()); + Field.BooleanField field1 = Field.newBooleanField(FIELD_NAME, Boolean.TRUE, Boolean.FALSE); + Field.ByteField field2 = Field.newByteField(FIELD_NAME); + Field.ByteField field3 = Field.newByteField(FIELD_NAME, Boolean.TRUE, (byte) 0); + Field.ShortField field4 = Field.newShortField(FIELD_NAME); + Field.ShortField field5 = Field.newShortField(FIELD_NAME, Boolean.TRUE, (short) 0); + Field.IntegerField field6 = Field.newIntegerField(FIELD_NAME); + Field.IntegerField field7 = Field.newIntegerField(FIELD_NAME, Boolean.TRUE, 0); + Field.LongField field8 = Field.newLongField(FIELD_NAME); + Field.LongField field9 = Field.newLongField(FIELD_NAME, Boolean.TRUE, 0L); + Field.FloatField field10 = Field.newFloatField(FIELD_NAME); + Field.FloatField field11 = Field.newFloatField(FIELD_NAME, Boolean.TRUE, 0.0f); + Field.DoubleField field12 = Field.newDoubleField(FIELD_NAME); + Field.DoubleField field13 = Field.newDoubleField(FIELD_NAME, Boolean.TRUE, 0.0); + Field.StringField field14 = Field.newStringField(FIELD_NAME); + Field.StringField field15 = Field.newStringField(FIELD_NAME, Boolean.TRUE, ""); + Assert.assertNotNull(field1); + Assert.assertNotNull(field2); + Assert.assertNotNull(field3); + Assert.assertNotNull(field4); + Assert.assertNotNull(field5); + Assert.assertNotNull(field6); + Assert.assertNotNull(field7); + Assert.assertNotNull(field8); + Assert.assertNotNull(field9); + Assert.assertNotNull(field10); + Assert.assertNotNull(field11); + Assert.assertNotNull(field12); + Assert.assertNotNull(field13); + Assert.assertNotNull(field14); + Assert.assertNotNull(field15); + } } diff --git a/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/serialize/impl/KryoSerializerTest.java b/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/serialize/impl/KryoSerializerTest.java index ba97fe3ec..dd141d80d 100644 --- a/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/serialize/impl/KryoSerializerTest.java +++ b/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/serialize/impl/KryoSerializerTest.java @@ -20,39 +20,39 @@ package org.apache.geaflow.common.serialize.impl; import java.io.Serializable; + import org.testng.Assert; import org.testng.annotations.Test; public class KryoSerializerTest { - @Test - public void testSerializeLambda() { - KryoSerializer kryoSerializer = new KryoSerializer(); - - LambdaMsg msg = new LambdaMsg(e -> e + 1); - byte[] serialize = kryoSerializer.serialize(msg); - LambdaMsg deserialized = (LambdaMsg) kryoSerializer.deserialize(serialize); + @Test + public void testSerializeLambda() { + KryoSerializer kryoSerializer = new KryoSerializer(); - Assert.assertNotNull(deserialized.func); - Assert.assertEquals(2, deserialized.getFunc().accept(1)); + LambdaMsg msg = new LambdaMsg(e -> e + 1); + byte[] serialize = kryoSerializer.serialize(msg); + LambdaMsg deserialized = (LambdaMsg) kryoSerializer.deserialize(serialize); - } + Assert.assertNotNull(deserialized.func); + Assert.assertEquals(2, deserialized.getFunc().accept(1)); + } - static class LambdaMsg { + static class LambdaMsg { - Func func; + Func func; - public LambdaMsg(Func msg) { - this.func = msg; - } - - public Func getFunc() { - return func; - } + public LambdaMsg(Func msg) { + this.func = msg; } - @FunctionalInterface - interface Func extends Serializable { - int accept(int input); + public Func getFunc() { + return func; } + } + + @FunctionalInterface + interface Func extends Serializable { + int accept(int input); + } } diff --git a/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/utils/DateTimeUtilTest.java b/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/utils/DateTimeUtilTest.java index 03c5464bb..f3d4cf5c4 100644 --- a/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/utils/DateTimeUtilTest.java +++ b/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/utils/DateTimeUtilTest.java @@ -24,13 +24,12 @@ public class DateTimeUtilTest { - @Test - public void testDateTimeUtil() { - Assert.assertEquals(DateTimeUtil.fromUnixTime(1111, "yyyy-MM-dd hh:mm:ss"), - "1970-01-01 08:00:01"); - Assert.assertEquals(DateTimeUtil.toUnixTime("1970-01-01 08:18:31", "yyyy-MM-dd hh:mm:ss"), - 1111000); - Assert.assertEquals(DateTimeUtil.toUnixTime("", "yyyy-MM-dd hh:mm:ss"), - -1); - } + @Test + public void testDateTimeUtil() { + Assert.assertEquals( + DateTimeUtil.fromUnixTime(1111, "yyyy-MM-dd hh:mm:ss"), "1970-01-01 08:00:01"); + Assert.assertEquals( + DateTimeUtil.toUnixTime("1970-01-01 08:18:31", "yyyy-MM-dd hh:mm:ss"), 1111000); + Assert.assertEquals(DateTimeUtil.toUnixTime("", "yyyy-MM-dd hh:mm:ss"), -1); + } } diff --git a/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/utils/ProcessUtilTest.java b/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/utils/ProcessUtilTest.java index d84ff50c0..318879caf 100644 --- a/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/utils/ProcessUtilTest.java +++ b/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/utils/ProcessUtilTest.java @@ -25,6 +25,7 @@ import static org.testng.Assert.assertThrows; import java.io.IOException; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -33,32 +34,32 @@ public class ProcessUtilTest { - @Mock - private Runtime runtime; + @Mock private Runtime runtime; - @Mock - private Process process; + @Mock private Process process; - @BeforeMethod - public void setUp() throws IOException { - MockitoAnnotations.initMocks(this); - when(runtime.exec(anyString())).thenReturn(process); - } + @BeforeMethod + public void setUp() throws IOException { + MockitoAnnotations.initMocks(this); + when(runtime.exec(anyString())).thenReturn(process); + } - @Test - public void execute_CommandThrowsIOException_ExceptionHandled() throws IOException, InterruptedException { - String cmd = "some command"; - when(runtime.exec(anyString())).thenThrow(new IOException("IO error")); + @Test + public void execute_CommandThrowsIOException_ExceptionHandled() + throws IOException, InterruptedException { + String cmd = "some command"; + when(runtime.exec(anyString())).thenThrow(new IOException("IO error")); - assertThrows(GeaflowRuntimeException.class, () -> ProcessUtil.execute(cmd)); - } + assertThrows(GeaflowRuntimeException.class, () -> ProcessUtil.execute(cmd)); + } - @Test - public void execute_CommandThrowsInterruptedException_ExceptionHandled() throws IOException, InterruptedException { - String cmd = "some command"; - when(runtime.exec(anyString())).thenReturn(process); - doThrow(new InterruptedException("Interrupted")).when(process).waitFor(); + @Test + public void execute_CommandThrowsInterruptedException_ExceptionHandled() + throws IOException, InterruptedException { + String cmd = "some command"; + when(runtime.exec(anyString())).thenReturn(process); + doThrow(new InterruptedException("Interrupted")).when(process).waitFor(); - assertThrows(GeaflowRuntimeException.class, () -> ProcessUtil.execute(cmd)); - } + assertThrows(GeaflowRuntimeException.class, () -> ProcessUtil.execute(cmd)); + } } diff --git a/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/utils/RetryCommandTest.java b/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/utils/RetryCommandTest.java index 0f5dd45ca..519a84b5b 100644 --- a/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/utils/RetryCommandTest.java +++ b/geaflow/geaflow-common/src/test/java/org/apache/geaflow/common/utils/RetryCommandTest.java @@ -27,33 +27,40 @@ public class RetryCommandTest { - private static final Logger LOGGER = LoggerFactory.getLogger(RetryCommandTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(RetryCommandTest.class); - @Test(expectedExceptions = GeaflowRuntimeException.class) - public void testFail() { - final int time = 0; - RetryCommand.run(() -> 10 / time > 0, 3, 1); - } + @Test(expectedExceptions = GeaflowRuntimeException.class) + public void testFail() { + final int time = 0; + RetryCommand.run(() -> 10 / time > 0, 3, 1); + } - @Test - public void testRun() { - Object result = RetryCommand.run(() -> { - LOGGER.info("hello"); - return "null"; - }, 1, 100); - Assert.assertEquals(result, "null"); - } + @Test + public void testRun() { + Object result = + RetryCommand.run( + () -> { + LOGGER.info("hello"); + return "null"; + }, + 1, + 100); + Assert.assertEquals(result, "null"); + } - @Test - public void testException() { - Object result = null; - try { - RetryCommand.run(() -> { - throw new RuntimeException("exception"); - }, 1, 100); - } catch (Throwable e) { - result = e; - } - Assert.assertTrue(result instanceof RuntimeException); + @Test + public void testException() { + Object result = null; + try { + RetryCommand.run( + () -> { + throw new RuntimeException("exception"); + }, + 1, + 100); + } catch (Throwable e) { + result = e; } + Assert.assertTrue(result instanceof RuntimeException); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/collector/Collector.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/collector/Collector.java index 798b37818..07f28f8c8 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/collector/Collector.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/collector/Collector.java @@ -21,12 +21,8 @@ import java.io.Serializable; - public interface Collector extends Serializable { - /** - * Partition data with value itself. - */ - void partition(T value); - + /** Partition data with value itself. */ + void partition(T value); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/context/RuntimeContext.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/context/RuntimeContext.java index d439e7940..6c0d98414 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/context/RuntimeContext.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/context/RuntimeContext.java @@ -21,51 +21,38 @@ import java.io.Serializable; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.task.TaskArgs; import org.apache.geaflow.metrics.common.api.MetricGroup; public interface RuntimeContext extends Serializable, Cloneable { - /** - * Returns pipeline id. - */ - long getPipelineId(); + /** Returns pipeline id. */ + long getPipelineId(); - /** - * Returns pipeline name. - */ - String getPipelineName(); + /** Returns pipeline name. */ + String getPipelineName(); - /** - * Get Relevant information of task. - */ - TaskArgs getTaskArgs(); + /** Get Relevant information of task. */ + TaskArgs getTaskArgs(); - /** - * Returns runtime configuration. - */ - Configuration getConfiguration(); + /** Returns runtime configuration. */ + Configuration getConfiguration(); - /** - * Returns runtime work path dir. - */ - String getWorkPath(); + /** Returns runtime work path dir. */ + String getWorkPath(); - /** - * Returns metric group ref. - */ - MetricGroup getMetric(); + /** Returns metric group ref. */ + MetricGroup getMetric(); - /** - * Clone runtime context and put all opConfig into runtime config. - * - * @param opConfig The config of corresponding operator. - */ - RuntimeContext clone(Map opConfig); + /** + * Clone runtime context and put all opConfig into runtime config. + * + * @param opConfig The config of corresponding operator. + */ + RuntimeContext clone(Map opConfig); - /** - * Returns current window id. - */ - long getWindowId(); + /** Returns current window id. */ + long getWindowId(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/Function.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/Function.java index b59dd07f0..d0049773d 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/Function.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/Function.java @@ -21,6 +21,4 @@ import java.io.Serializable; -public interface Function extends Serializable { - -} +public interface Function extends Serializable {} diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/RichFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/RichFunction.java index 73c761311..969b29b2f 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/RichFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/RichFunction.java @@ -19,19 +19,13 @@ package org.apache.geaflow.api.function; - import org.apache.geaflow.api.context.RuntimeContext; public abstract class RichFunction implements Function { - /** - * Open function. - */ - public abstract void open(RuntimeContext runtimeContext); - - /** - * Close function. - */ - public abstract void close(); + /** Open function. */ + public abstract void open(RuntimeContext runtimeContext); + /** Close function. */ + public abstract void close(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/RichWindowFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/RichWindowFunction.java index 890e0ff50..2e598437e 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/RichWindowFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/RichWindowFunction.java @@ -21,8 +21,6 @@ public abstract class RichWindowFunction extends RichFunction { - /** - * Finish window function. - */ - public abstract void finish(); + /** Finish window function. */ + public abstract void finish(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/AggregateFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/AggregateFunction.java index 2fed75f37..f5e368f7d 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/AggregateFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/AggregateFunction.java @@ -23,24 +23,15 @@ public interface AggregateFunction extends Function { - /** - * Create aggregate accumulator for aggregate function to store the aggregate value. - */ - ACC createAccumulator(); + /** Create aggregate accumulator for aggregate function to store the aggregate value. */ + ACC createAccumulator(); - /** - * Accumulate the input to the accumulator. - */ - void add(IN value, ACC accumulator); + /** Accumulate the input to the accumulator. */ + void add(IN value, ACC accumulator); - /** - * Get aggregate result from the accumulator. - */ - OUT getResult(ACC accumulator); - - /** - * Merge a with b. - */ - ACC merge(ACC a, ACC b); + /** Get aggregate result from the accumulator. */ + OUT getResult(ACC accumulator); + /** Merge a with b. */ + ACC merge(ACC a, ACC b); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/FilterFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/FilterFunction.java index 9836597e5..380e64d48 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/FilterFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/FilterFunction.java @@ -24,9 +24,6 @@ @FunctionalInterface public interface FilterFunction extends Function { - /** - * If false then filter the record. - */ - boolean filter(T record); - + /** If false then filter the record. */ + boolean filter(T record); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/FlatMapFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/FlatMapFunction.java index 0a0c23bfa..17c288457 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/FlatMapFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/FlatMapFunction.java @@ -25,8 +25,9 @@ @FunctionalInterface public interface FlatMapFunction extends Function { - /** - * Process input value to produce 0~n records, and then partition all records by collector directly. - */ - void flatMap(IN value, Collector collector); + /** + * Process input value to produce 0~n records, and then partition all records by collector + * directly. + */ + void flatMap(IN value, Collector collector); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/KeySelector.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/KeySelector.java index e8097a7e4..780141c7a 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/KeySelector.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/KeySelector.java @@ -24,9 +24,6 @@ @FunctionalInterface public interface KeySelector extends Function { - /** - * Extract the partition key from value. - */ - KEY getKey(IN value); - + /** Extract the partition key from value. */ + KEY getKey(IN value); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/MapFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/MapFunction.java index 52e36ea56..1b958ddfd 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/MapFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/MapFunction.java @@ -24,9 +24,6 @@ @FunctionalInterface public interface MapFunction extends Function { - /** - * Process input value to produce a new record. - */ - O map(T value); - + /** Process input value to produce a new record. */ + O map(T value); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/ReduceFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/ReduceFunction.java index ed0177c59..92dab80d5 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/ReduceFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/base/ReduceFunction.java @@ -24,9 +24,6 @@ @FunctionalInterface public interface ReduceFunction extends Function { - /** - * Reduce new value with old value, and then return a new T. - */ - T reduce(T oldValue, T newValue); - + /** Reduce new value with old value, and then return a new T. */ + T reduce(T oldValue, T newValue); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/internal/CollectionSource.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/internal/CollectionSource.java index 3e0631c30..39375e42c 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/internal/CollectionSource.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/internal/CollectionSource.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.RichFunction; import org.apache.geaflow.api.function.io.SourceFunction; @@ -34,65 +35,64 @@ public class CollectionSource extends RichFunction implements SourceFunction { - private static final Logger LOGGER = LoggerFactory.getLogger(CollectionSource.class); + private static final Logger LOGGER = LoggerFactory.getLogger(CollectionSource.class); - private int batchSize; - private List records; - private static Map readPosMap = new ConcurrentHashMap<>(); + private int batchSize; + private List records; + private static Map readPosMap = new ConcurrentHashMap<>(); - private transient RuntimeContext runtimeContext; + private transient RuntimeContext runtimeContext; - public CollectionSource(Collection records, int batchSize) { - this.records = new ArrayList<>(records); - this.batchSize = batchSize; - } + public CollectionSource(Collection records, int batchSize) { + this.records = new ArrayList<>(records); + this.batchSize = batchSize; + } - public CollectionSource(Collection records) { - this(records, 2); - } + public CollectionSource(Collection records) { + this(records, 2); + } - public CollectionSource(OUT... collections) { - this(Arrays.asList(collections)); - } + public CollectionSource(OUT... collections) { + this(Arrays.asList(collections)); + } - @Override - public void open(RuntimeContext runtimeContext) { - this.runtimeContext = runtimeContext; - } + @Override + public void open(RuntimeContext runtimeContext) { + this.runtimeContext = runtimeContext; + } - @Override - public void init(int parallel, int index) { - if (parallel != 1) { - List allRecords = records; - records = new ArrayList<>(); - for (int i = 0; i < allRecords.size(); i++) { - if (i % parallel == index) { - records.add(allRecords.get(i)); - } - } + @Override + public void init(int parallel, int index) { + if (parallel != 1) { + List allRecords = records; + records = new ArrayList<>(); + for (int i = 0; i < allRecords.size(); i++) { + if (i % parallel == index) { + records.add(allRecords.get(i)); } + } } + } - @Override - public boolean fetch(IWindow window, SourceContext ctx) throws Exception { - String taskName = runtimeContext.getTaskArgs().getTaskName(); - int taskId = runtimeContext.getTaskArgs().getTaskId(); - int readPos = readPosMap.getOrDefault(taskId, 0); - LOGGER.info("taskName:{} fetch batchId:{} readPos:{}", taskName, window.windowId(), readPos); - while (readPos < records.size()) { - OUT out = records.get(readPos); - if (window.assignWindow(out) == window.windowId() && ctx.collect(out)) { - readPos++; - } else { - break; - } - } - LOGGER.info("taskName:{} save batchId:{} readPos:{}", taskName, window.windowId(), readPos); - readPosMap.put(taskId, readPos); - return readPos < records.size(); + @Override + public boolean fetch(IWindow window, SourceContext ctx) throws Exception { + String taskName = runtimeContext.getTaskArgs().getTaskName(); + int taskId = runtimeContext.getTaskArgs().getTaskId(); + int readPos = readPosMap.getOrDefault(taskId, 0); + LOGGER.info("taskName:{} fetch batchId:{} readPos:{}", taskName, window.windowId(), readPos); + while (readPos < records.size()) { + OUT out = records.get(readPos); + if (window.assignWindow(out) == window.windowId() && ctx.collect(out)) { + readPos++; + } else { + break; + } } + LOGGER.info("taskName:{} save batchId:{} readPos:{}", taskName, window.windowId(), readPos); + readPosMap.put(taskId, readPos); + return readPos < records.size(); + } - @Override - public void close() { - } + @Override + public void close() {} } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/io/GraphSourceFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/io/GraphSourceFunction.java index 52f0cf9f1..f7e8f3823 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/io/GraphSourceFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/io/GraphSourceFunction.java @@ -24,30 +24,27 @@ import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; -public interface GraphSourceFunction extends SourceFunction, IEdge>> { - - /** - * Fetch vertex/edge data from source by window, and collect data by ctx. - */ - @Override - default boolean fetch(IWindow, IEdge>> window, - SourceContext, IEdge>> ctx) throws Exception { - return fetch(window.windowId(), (GraphSourceContext) ctx); - } - - boolean fetch(long windowId, GraphSourceContext ctx) throws Exception; - - interface GraphSourceContext extends SourceContext, - IEdge>> { - - /** - * Partition vertex. - */ - void collectVertex(IVertex vertex) throws Exception; - - /** - * Partition edge. - */ - void collectEdge(IEdge edge) throws Exception; - } +public interface GraphSourceFunction + extends SourceFunction, IEdge>> { + + /** Fetch vertex/edge data from source by window, and collect data by ctx. */ + @Override + default boolean fetch( + IWindow, IEdge>> window, + SourceContext, IEdge>> ctx) + throws Exception { + return fetch(window.windowId(), (GraphSourceContext) ctx); + } + + boolean fetch(long windowId, GraphSourceContext ctx) throws Exception; + + interface GraphSourceContext + extends SourceContext, IEdge>> { + + /** Partition vertex. */ + void collectVertex(IVertex vertex) throws Exception; + + /** Partition edge. */ + void collectEdge(IEdge edge) throws Exception; + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/io/SinkFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/io/SinkFunction.java index c28509d70..7cfe0a66f 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/io/SinkFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/io/SinkFunction.java @@ -24,9 +24,6 @@ @FunctionalInterface public interface SinkFunction extends Function { - /** - * The write method for Outputting data t. - */ - void write(T t) throws Exception; - + /** The write method for Outputting data t. */ + void write(T t) throws Exception; } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/io/SourceFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/io/SourceFunction.java index 31a4d66f0..6683e7569 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/io/SourceFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/io/SourceFunction.java @@ -24,31 +24,23 @@ public interface SourceFunction extends Function { - /** - * Initialize source function. - */ - void init(int parallel, int index); + /** Initialize source function. */ + void init(int parallel, int index); - /** - * Fetch data from source by window, and collect data by ctx. - * - * @param window Used to split windows for source. - * @param ctx The source context. - */ - boolean fetch(IWindow window, SourceContext ctx) throws Exception; + /** + * Fetch data from source by window, and collect data by ctx. + * + * @param window Used to split windows for source. + * @param ctx The source context. + */ + boolean fetch(IWindow window, SourceContext ctx) throws Exception; - /** - * Close source function. - */ - void close(); + /** Close source function. */ + void close(); - interface SourceContext { - - /** - * Partition element data. - */ - boolean collect(T element) throws Exception; - - } + interface SourceContext { + /** Partition element data. */ + boolean collect(T element) throws Exception; + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/iterator/IteratorFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/iterator/IteratorFunction.java index 45ecf9666..f74dde522 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/iterator/IteratorFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/iterator/IteratorFunction.java @@ -23,9 +23,6 @@ public interface IteratorFunction extends Function { - /** - * Get the max iteration count for graph compute or traversal. - */ - long getMaxIterationCount(); - + /** Get the max iteration count for graph compute or traversal. */ + long getMaxIterationCount(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/iterator/RichIteratorFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/iterator/RichIteratorFunction.java index c670656eb..6dd548f31 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/iterator/RichIteratorFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/function/iterator/RichIteratorFunction.java @@ -21,13 +21,9 @@ public interface RichIteratorFunction { - /** - * Initialize the windowId iteration. - */ - void initIteration(long windowId); + /** Initialize the windowId iteration. */ + void initIteration(long windowId); - /** - * Finish the windowId iteration. - */ - void finishIteration(long windowId); + /** Finish the windowId iteration. */ + void finishIteration(long windowId); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/PGraphWindow.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/PGraphWindow.java index f181d19a3..afbf34583 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/PGraphWindow.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/PGraphWindow.java @@ -31,35 +31,24 @@ public interface PGraphWindow { - /** - * Build ComputeWindowGraph based on vertexCentricCompute function. - */ - PGraphCompute compute(VertexCentricCompute vertexCentricCompute); + /** Build ComputeWindowGraph based on vertexCentricCompute function. */ + PGraphCompute compute(VertexCentricCompute vertexCentricCompute); - /** - * Build ComputeWindowGraph based on vertexCentricAggCompute function. - */ - PGraphCompute compute( - VertexCentricAggCompute vertexCentricAggCompute); + /** Build ComputeWindowGraph based on vertexCentricAggCompute function. */ + PGraphCompute compute( + VertexCentricAggCompute vertexCentricAggCompute); - /** - * Build PGraphTraversal based on vertexCentricTraversal function. - */ - PGraphTraversal traversal(VertexCentricTraversal vertexCentricTraversal); + /** Build PGraphTraversal based on vertexCentricTraversal function. */ + PGraphTraversal traversal( + VertexCentricTraversal vertexCentricTraversal); - /** - * Build PGraphTraversal based on vertexCentricTraversal function. - */ - PGraphTraversal traversal( - VertexCentricAggTraversal vertexCentricAggTraversal); + /** Build PGraphTraversal based on vertexCentricTraversal function. */ + PGraphTraversal traversal( + VertexCentricAggTraversal vertexCentricAggTraversal); - /** - * Returns the edges. - */ - PWindowStream> getEdges(); + /** Returns the edges. */ + PWindowStream> getEdges(); - /** - * Returns the vertices. - */ - PWindowStream> getVertices(); + /** Returns the vertices. */ + PWindowStream> getVertices(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/AbstractIncVertexCentricComputeAlgo.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/AbstractIncVertexCentricComputeAlgo.java index c26edafb4..7b09a4381 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/AbstractIncVertexCentricComputeAlgo.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/AbstractIncVertexCentricComputeAlgo.java @@ -21,17 +21,17 @@ import org.apache.geaflow.api.graph.function.vc.IncVertexCentricComputeFunction; -public abstract class AbstractIncVertexCentricComputeAlgo> extends VertexCentricAlgo { +public abstract class AbstractIncVertexCentricComputeAlgo< + K, VV, EV, M, FUNC extends IncVertexCentricComputeFunction> + extends VertexCentricAlgo { - public AbstractIncVertexCentricComputeAlgo(long iterations) { - super(iterations); - } + public AbstractIncVertexCentricComputeAlgo(long iterations) { + super(iterations); + } - public AbstractIncVertexCentricComputeAlgo(long iterations, String name) { - super(iterations, name); - } - - public abstract FUNC getIncComputeFunction(); + public AbstractIncVertexCentricComputeAlgo(long iterations, String name) { + super(iterations, name); + } + public abstract FUNC getIncComputeFunction(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/AbstractIncVertexCentricTraversalAlgo.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/AbstractIncVertexCentricTraversalAlgo.java index 645edd64e..7b687fe46 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/AbstractIncVertexCentricTraversalAlgo.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/AbstractIncVertexCentricTraversalAlgo.java @@ -21,18 +21,17 @@ import org.apache.geaflow.api.graph.function.vc.IncVertexCentricTraversalFunction; -public abstract class AbstractIncVertexCentricTraversalAlgo> +public abstract class AbstractIncVertexCentricTraversalAlgo< + K, VV, EV, M, R, FUNC extends IncVertexCentricTraversalFunction> extends VertexCentricAlgo { - public AbstractIncVertexCentricTraversalAlgo(long iterations) { - super(iterations); - } + public AbstractIncVertexCentricTraversalAlgo(long iterations) { + super(iterations); + } - public AbstractIncVertexCentricTraversalAlgo(long iterations, String name) { - super(iterations, name); - } - - public abstract FUNC getIncTraversalFunction(); + public AbstractIncVertexCentricTraversalAlgo(long iterations, String name) { + super(iterations, name); + } + public abstract FUNC getIncTraversalFunction(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/AbstractVertexCentricComputeAlgo.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/AbstractVertexCentricComputeAlgo.java index 58a7fbf4a..cf6dc3313 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/AbstractVertexCentricComputeAlgo.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/AbstractVertexCentricComputeAlgo.java @@ -21,16 +21,17 @@ import org.apache.geaflow.api.graph.function.vc.VertexCentricComputeFunction; -public abstract class AbstractVertexCentricComputeAlgo> extends VertexCentricAlgo { +public abstract class AbstractVertexCentricComputeAlgo< + K, VV, EV, M, FUNC extends VertexCentricComputeFunction> + extends VertexCentricAlgo { - public AbstractVertexCentricComputeAlgo(long iterations) { - super(iterations); - } + public AbstractVertexCentricComputeAlgo(long iterations) { + super(iterations); + } - public AbstractVertexCentricComputeAlgo(long iterations, String name) { - super(iterations, name); - } + public AbstractVertexCentricComputeAlgo(long iterations, String name) { + super(iterations, name); + } - public abstract FUNC getComputeFunction(); + public abstract FUNC getComputeFunction(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/AbstractVertexCentricTraversalAlgo.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/AbstractVertexCentricTraversalAlgo.java index cb56efd4e..8df83da42 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/AbstractVertexCentricTraversalAlgo.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/AbstractVertexCentricTraversalAlgo.java @@ -21,17 +21,17 @@ import org.apache.geaflow.api.graph.function.vc.VertexCentricTraversalFunction; -public abstract class AbstractVertexCentricTraversalAlgo> +public abstract class AbstractVertexCentricTraversalAlgo< + K, VV, EV, M, R, FUNC extends VertexCentricTraversalFunction> extends VertexCentricAlgo { - public AbstractVertexCentricTraversalAlgo(long iterations) { - super(iterations); - } + public AbstractVertexCentricTraversalAlgo(long iterations) { + super(iterations); + } - public AbstractVertexCentricTraversalAlgo(long iterations, String name) { - super(iterations, name); - } + public AbstractVertexCentricTraversalAlgo(long iterations, String name) { + super(iterations, name); + } - public abstract FUNC getTraversalFunction(); + public abstract FUNC getTraversalFunction(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/GraphAggregationAlgo.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/GraphAggregationAlgo.java index 771d7052e..f101703ce 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/GraphAggregationAlgo.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/GraphAggregationAlgo.java @@ -24,7 +24,7 @@ /** * Interface for graph aggregation algo function. * - * @param The type of aggregate input iterm. + * @param The type of aggregate input iterm. * @param The type of partial aggregate iterm. * @param The type of partial aggregate result. * @param The type of global aggregate iterm. @@ -32,5 +32,5 @@ */ public interface GraphAggregationAlgo { - VertexCentricAggregateFunction getAggregateFunction(); + VertexCentricAggregateFunction getAggregateFunction(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/GraphExecAlgo.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/GraphExecAlgo.java index d3e566bfe..8aa0b556e 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/GraphExecAlgo.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/GraphExecAlgo.java @@ -20,8 +20,6 @@ package org.apache.geaflow.api.graph.base.algo; public enum GraphExecAlgo { - /** - * Vertex centric algorithm. - */ - VertexCentric, + /** Vertex centric algorithm. */ + VertexCentric, } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/VertexCentricAlgo.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/VertexCentricAlgo.java index 44b88422b..4d1cad6c8 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/VertexCentricAlgo.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/VertexCentricAlgo.java @@ -26,42 +26,40 @@ public abstract class VertexCentricAlgo implements IteratorFunction { - protected String name; - protected long iterations; + protected String name; + protected long iterations; - protected VertexCentricAlgo(long iterations) { - this.iterations = iterations; - this.name = this.getClass().getSimpleName(); - } + protected VertexCentricAlgo(long iterations) { + this.iterations = iterations; + this.name = this.getClass().getSimpleName(); + } - protected VertexCentricAlgo(long iterations, String name) { - this.iterations = iterations; - this.name = name; - } + protected VertexCentricAlgo(long iterations, String name) { + this.iterations = iterations; + this.name = name; + } - /** - * Returns vertex centric combine function. - */ - public abstract VertexCentricCombineFunction getCombineFunction(); + /** Returns vertex centric combine function. */ + public abstract VertexCentricCombineFunction getCombineFunction(); - public IEncoder getKeyEncoder() { - return null; - } + public IEncoder getKeyEncoder() { + return null; + } - public IEncoder getMessageEncoder() { - return null; - } + public IEncoder getMessageEncoder() { + return null; + } - public IGraphVCPartition getGraphPartition() { - return null; - } + public IGraphVCPartition getGraphPartition() { + return null; + } - @Override - public final long getMaxIterationCount() { - return this.iterations; - } + @Override + public final long getMaxIterationCount() { + return this.iterations; + } - public String getName() { - return this.name; - } + public String getName() { + return this.name; + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/compute/IncVertexCentricAggCompute.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/compute/IncVertexCentricAggCompute.java index b9f878423..23d3ccf7c 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/compute/IncVertexCentricAggCompute.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/compute/IncVertexCentricAggCompute.java @@ -23,17 +23,16 @@ import org.apache.geaflow.api.graph.base.algo.GraphAggregationAlgo; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricAggComputeFunction; - public abstract class IncVertexCentricAggCompute - extends AbstractIncVertexCentricComputeAlgo> + extends AbstractIncVertexCentricComputeAlgo< + K, VV, EV, M, IncVertexCentricAggComputeFunction> implements GraphAggregationAlgo { + public IncVertexCentricAggCompute(long iterations) { + super(iterations); + } - public IncVertexCentricAggCompute(long iterations) { - super(iterations); - } - - public IncVertexCentricAggCompute(long iterations, String name) { - super(iterations, name); - } + public IncVertexCentricAggCompute(long iterations, String name) { + super(iterations, name); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/compute/IncVertexCentricCompute.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/compute/IncVertexCentricCompute.java index d6594b5e3..5107364f9 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/compute/IncVertexCentricCompute.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/compute/IncVertexCentricCompute.java @@ -22,17 +22,14 @@ import org.apache.geaflow.api.graph.base.algo.AbstractIncVertexCentricComputeAlgo; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricComputeFunction; - public abstract class IncVertexCentricCompute - extends AbstractIncVertexCentricComputeAlgo> { - - public IncVertexCentricCompute(long iterations) { - super(iterations); - } + extends AbstractIncVertexCentricComputeAlgo< + K, VV, EV, M, IncVertexCentricComputeFunction> { - /** - * Returns incremental vertex centric compute function. - */ - public abstract IncVertexCentricComputeFunction getIncComputeFunction(); + public IncVertexCentricCompute(long iterations) { + super(iterations); + } + /** Returns incremental vertex centric compute function. */ + public abstract IncVertexCentricComputeFunction getIncComputeFunction(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/compute/PGraphCompute.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/compute/PGraphCompute.java index 6765facc2..0bfa1f394 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/compute/PGraphCompute.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/compute/PGraphCompute.java @@ -25,24 +25,15 @@ public interface PGraphCompute { - /** - * Returns the vertices of graph. - */ - PWindowStream> getVertices(); + /** Returns the vertices of graph. */ + PWindowStream> getVertices(); - /** - * Returns the PGraphCompute itself. - */ - PGraphCompute compute(); + /** Returns the PGraphCompute itself. */ + PGraphCompute compute(); - /** - * Set parallelism of graph compute and return the PGraphCompute itself. - */ - PGraphCompute compute(int parallelism); - - /** - * Returns the graph compute type, {@link GraphExecAlgo}. - */ - GraphExecAlgo getGraphComputeType(); + /** Set parallelism of graph compute and return the PGraphCompute itself. */ + PGraphCompute compute(int parallelism); + /** Returns the graph compute type, {@link GraphExecAlgo}. */ + GraphExecAlgo getGraphComputeType(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/compute/VertexCentricAggCompute.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/compute/VertexCentricAggCompute.java index 40a76b4e8..4998d2791 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/compute/VertexCentricAggCompute.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/compute/VertexCentricAggCompute.java @@ -19,20 +19,20 @@ package org.apache.geaflow.api.graph.compute; - import org.apache.geaflow.api.graph.base.algo.AbstractVertexCentricComputeAlgo; import org.apache.geaflow.api.graph.base.algo.GraphAggregationAlgo; import org.apache.geaflow.api.graph.function.vc.VertexCentricAggComputeFunction; public abstract class VertexCentricAggCompute - extends AbstractVertexCentricComputeAlgo> + extends AbstractVertexCentricComputeAlgo< + K, VV, EV, M, VertexCentricAggComputeFunction> implements GraphAggregationAlgo { - public VertexCentricAggCompute(long iterations) { - super(iterations); - } + public VertexCentricAggCompute(long iterations) { + super(iterations); + } - public VertexCentricAggCompute(long iterations, String name) { - super(iterations, name); - } + public VertexCentricAggCompute(long iterations, String name) { + super(iterations, name); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/compute/VertexCentricCompute.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/compute/VertexCentricCompute.java index e3df0b8fc..f15a4dd74 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/compute/VertexCentricCompute.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/compute/VertexCentricCompute.java @@ -19,18 +19,18 @@ package org.apache.geaflow.api.graph.compute; - import org.apache.geaflow.api.graph.base.algo.AbstractVertexCentricComputeAlgo; import org.apache.geaflow.api.graph.function.vc.VertexCentricComputeFunction; public abstract class VertexCentricCompute - extends AbstractVertexCentricComputeAlgo> { + extends AbstractVertexCentricComputeAlgo< + K, VV, EV, M, VertexCentricComputeFunction> { - public VertexCentricCompute(long iterations) { - super(iterations); - } + public VertexCentricCompute(long iterations) { + super(iterations); + } - public VertexCentricCompute(long iterations, String name) { - super(iterations, name); - } + public VertexCentricCompute(long iterations, String name) { + super(iterations, name); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/aggregate/VertexCentricAggContextFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/aggregate/VertexCentricAggContextFunction.java index e2a4c9a4e..fc0194246 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/aggregate/VertexCentricAggContextFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/aggregate/VertexCentricAggContextFunction.java @@ -26,22 +26,15 @@ * @param The aggregate result. */ public interface VertexCentricAggContextFunction { - /** - * Init aggregation context. - */ - void initContext(VertexCentricAggContext aggContext); + /** Init aggregation context. */ + void initContext(VertexCentricAggContext aggContext); - interface VertexCentricAggContext { + interface VertexCentricAggContext { - /** - * Return current global aggregation result. - */ - R getAggregateResult(); + /** Return current global aggregation result. */ + R getAggregateResult(); - /** - * Do aggregate for input iterm. - */ - void aggregate(I iterm); - - } + /** Do aggregate for input iterm. */ + void aggregate(I iterm); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/IncVertexCentricAggComputeFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/IncVertexCentricAggComputeFunction.java index bcc1fd09f..03ef3353f 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/IncVertexCentricAggComputeFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/IncVertexCentricAggComputeFunction.java @@ -24,14 +24,12 @@ /** * Interface for incremental vertex centric compute function with graph aggregation. * - * @param The id type of vertex/edge. + * @param The id type of vertex/edge. * @param The value type of vertex. * @param The value type of edge. - * @param The message type during iterations. - * @param The type of aggregate input iterm. + * @param The message type during iterations. + * @param The type of aggregate input iterm. * @param The type of aggregate global result. */ public interface IncVertexCentricAggComputeFunction - extends IncVertexCentricComputeFunction, VertexCentricAggContextFunction { - -} + extends IncVertexCentricComputeFunction, VertexCentricAggContextFunction {} diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/IncVertexCentricAggTraversalFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/IncVertexCentricAggTraversalFunction.java index ac2bfdb83..478e66548 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/IncVertexCentricAggTraversalFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/IncVertexCentricAggTraversalFunction.java @@ -24,14 +24,13 @@ /** * Interface for incremental vertex centric traversal function with graph aggregation. * - * @param The id type of vertex/edge. + * @param The id type of vertex/edge. * @param The value type of vertex. * @param The value type of edge. - * @param The message type during iterations. - * @param The type of aggregate input iterm. + * @param The message type during iterations. + * @param The type of aggregate input iterm. * @param The type of aggregate global result. */ public interface IncVertexCentricAggTraversalFunction - extends IncVertexCentricTraversalFunction, VertexCentricAggContextFunction { - -} + extends IncVertexCentricTraversalFunction, + VertexCentricAggContextFunction {} diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/IncVertexCentricComputeFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/IncVertexCentricComputeFunction.java index 1aa1cd2fc..bbbc55528 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/IncVertexCentricComputeFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/IncVertexCentricComputeFunction.java @@ -25,23 +25,19 @@ /** * Interface for incremental vertex centric compute function with graph aggregation. * - * @param The id type of vertex/edge. + * @param The id type of vertex/edge. * @param The value type of vertex. * @param The value type of edge. - * @param The message type during iterations. + * @param The message type during iterations. */ -public interface IncVertexCentricComputeFunction extends IncVertexCentricFunction { +public interface IncVertexCentricComputeFunction + extends IncVertexCentricFunction { - /** - * Initialize compute function based on context. - */ - void init(IncGraphComputeContext incGraphContext); - - interface IncGraphComputeContext extends IncGraphContext { - /** - * Partition vertex. - */ - void collect(IVertex vertex); - } + /** Initialize compute function based on context. */ + void init(IncGraphComputeContext incGraphContext); + interface IncGraphComputeContext extends IncGraphContext { + /** Partition vertex. */ + void collect(IVertex vertex); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/IncVertexCentricTraversalFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/IncVertexCentricTraversalFunction.java index a70192589..ae2179e05 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/IncVertexCentricTraversalFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/IncVertexCentricTraversalFunction.java @@ -29,83 +29,58 @@ /** * Interface for incremental vertex centric traversal function. * - * @param The id type of vertex/edge. + * @param The id type of vertex/edge. * @param The value type of vertex. * @param The value type of edge. - * @param The message type during iterations. - * @param The request type for traversal. + * @param The message type during iterations. + * @param The request type for traversal. */ -public interface IncVertexCentricTraversalFunction extends IncVertexCentricFunction { - - /** - * Open incremental traversal function based on context. - */ - void open(IncVertexCentricTraversalFuncContext vertexCentricFuncContext); - - /** - * Initialize the traversal by request. - */ - void init(ITraversalRequest traversalRequest); - - /** - * Finish iteration traversal. - */ - void finish(); - - /** - * Close resources in iteration traversal. - */ - void close(); - - interface IncVertexCentricTraversalFuncContext extends IncGraphContext { - - /** - * Active traversal request to process. - */ - void activeRequest(ITraversalRequest request); - - /** - * Receive the response. - */ - void takeResponse(ITraversalResponse response); - - /** - * Broadcast message. - */ - void broadcast(IGraphMessage message); - - /** - * Get the historical graph of graph state. - */ - TraversalHistoricalGraph getHistoricalGraph(); - - /** - * Get the traversal operator name. - */ - String getTraversalOpName(); - } - - - interface TraversalHistoricalGraph extends HistoricalGraph { - - /** - * Get the graph snapshot of specified version. - */ - TraversalGraphSnapShot getSnapShot(long version); - } - - interface TraversalGraphSnapShot extends GraphSnapShot { - - /** - * Returns the TraversalVertexQuery. - */ - TraversalVertexQuery vertex(); - - /** - * Returns the TraversalEdgeQuery. - */ - TraversalEdgeQuery edges(); - } +public interface IncVertexCentricTraversalFunction + extends IncVertexCentricFunction { + + /** Open incremental traversal function based on context. */ + void open(IncVertexCentricTraversalFuncContext vertexCentricFuncContext); + + /** Initialize the traversal by request. */ + void init(ITraversalRequest traversalRequest); + + /** Finish iteration traversal. */ + void finish(); + + /** Close resources in iteration traversal. */ + void close(); + + interface IncVertexCentricTraversalFuncContext + extends IncGraphContext { + + /** Active traversal request to process. */ + void activeRequest(ITraversalRequest request); + + /** Receive the response. */ + void takeResponse(ITraversalResponse response); + + /** Broadcast message. */ + void broadcast(IGraphMessage message); + + /** Get the historical graph of graph state. */ + TraversalHistoricalGraph getHistoricalGraph(); + + /** Get the traversal operator name. */ + String getTraversalOpName(); + } + + interface TraversalHistoricalGraph extends HistoricalGraph { + + /** Get the graph snapshot of specified version. */ + TraversalGraphSnapShot getSnapShot(long version); + } + + interface TraversalGraphSnapShot extends GraphSnapShot { + + /** Returns the TraversalVertexQuery. */ + TraversalVertexQuery vertex(); + + /** Returns the TraversalEdgeQuery. */ + TraversalEdgeQuery edges(); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricAggComputeFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricAggComputeFunction.java index 91141f898..732d2e139 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricAggComputeFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricAggComputeFunction.java @@ -24,14 +24,12 @@ /** * Interface for static vertex centric compute function with graph aggregation. * - * @param The id type of vertex/edge. + * @param The id type of vertex/edge. * @param The value type of vertex. * @param The value type of edge. - * @param The message type during iterations. - * @param The type of aggregate input iterm. + * @param The message type during iterations. + * @param The type of aggregate input iterm. * @param The type of aggregate global result. */ public interface VertexCentricAggComputeFunction - extends VertexCentricComputeFunction, VertexCentricAggContextFunction { - -} + extends VertexCentricComputeFunction, VertexCentricAggContextFunction {} diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricAggTraversalFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricAggTraversalFunction.java index 44b44fd91..41504e7a6 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricAggTraversalFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricAggTraversalFunction.java @@ -24,13 +24,13 @@ /** * Interface for static vertex centric traversal function with graph aggregation. * - * @param The id type of vertex/edge. + * @param The id type of vertex/edge. * @param The value type of vertex. * @param The value type of edge. - * @param The message type during iterations. - * @param The type of aggregate input iterm. + * @param The message type during iterations. + * @param The type of aggregate input iterm. * @param The type of aggregate global result. */ public interface VertexCentricAggTraversalFunction - extends VertexCentricTraversalFunction, VertexCentricAggContextFunction { -} + extends VertexCentricTraversalFunction, + VertexCentricAggContextFunction {} diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricAggregateFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricAggregateFunction.java index 891556634..253e5824f 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricAggregateFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricAggregateFunction.java @@ -22,52 +22,49 @@ /** * Interface for graph aggregate function. * - * @param The type of aggregate input iterm. - * @param The type of partial aggregate iterm. + * @param The type of aggregate input iterm. + * @param The type of partial aggregate iterm. * @param The type of partial aggregate result. - * @param The type of global aggregate iterm. - * @param The type of global aggregate result. + * @param The type of global aggregate iterm. + * @param The type of global aggregate result. */ public interface VertexCentricAggregateFunction { - IPartialGraphAggFunction getPartialAggregation(); + IPartialGraphAggFunction getPartialAggregation(); - IGraphAggregateFunction getGlobalAggregation(); + IGraphAggregateFunction getGlobalAggregation(); - interface IPartialAggContext { + interface IPartialAggContext { - long getIteration(); + long getIteration(); - void collect(VALUE result); + void collect(VALUE result); + } - } + interface IPartialGraphAggFunction { - interface IPartialGraphAggFunction { + PAGG create(IPartialAggContext partialAggContext); - PAGG create(IPartialAggContext partialAggContext); + PRESULT aggregate(ITERM iterm, PAGG result); - PRESULT aggregate(ITERM iterm, PAGG result); + void finish(PRESULT result); + } - void finish(PRESULT result); - } + interface IGlobalGraphAggContext { - interface IGlobalGraphAggContext { + long getIteration(); - long getIteration(); + void broadcast(RESULT result); - void broadcast(RESULT result); + void terminate(); + } - void terminate(); + interface IGraphAggregateFunction { - } + GAGG create(IGlobalGraphAggContext globalGraphAggContext); - interface IGraphAggregateFunction { - - GAGG create(IGlobalGraphAggContext globalGraphAggContext); - - RESULT aggregate(ITERM iterm, GAGG agg); - - void finish(RESULT result); - } + RESULT aggregate(ITERM iterm, GAGG agg); + void finish(RESULT result); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricCombineFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricCombineFunction.java index c4aeabb90..a45a6cbde 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricCombineFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricCombineFunction.java @@ -26,9 +26,6 @@ */ public interface VertexCentricCombineFunction { - /** - * Combine old message with new message. - */ - M combine(M oldMessage, M newMessage); - + /** Combine old message with new message. */ + M combine(M oldMessage, M newMessage); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricComputeFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricComputeFunction.java index db7f30ec7..661f23484 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricComputeFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricComputeFunction.java @@ -20,40 +20,33 @@ package org.apache.geaflow.api.graph.function.vc; import java.util.Iterator; + import org.apache.geaflow.api.graph.function.vc.base.VertexCentricFunction; /** * Interface for vertex centric compute function. * - * @param The id type of vertex/edge. + * @param The id type of vertex/edge. * @param The value type of vertex. * @param The value type of edge. - * @param The message type during iterations. + * @param The message type during iterations. */ -public interface VertexCentricComputeFunction extends VertexCentricFunction { - - /** - * Initialize compute function based on context. - */ - void init(VertexCentricComputeFuncContext vertexCentricFuncContext); - - /** - * Perform traversing based on message iterator during iterations. - */ - void compute(K vertexId, Iterator messageIterator); +public interface VertexCentricComputeFunction + extends VertexCentricFunction { - /** - * Finish iteration computation. - */ - void finish(); + /** Initialize compute function based on context. */ + void init(VertexCentricComputeFuncContext vertexCentricFuncContext); - interface VertexCentricComputeFuncContext extends VertexCentricFuncContext { + /** Perform traversing based on message iterator during iterations. */ + void compute(K vertexId, Iterator messageIterator); - /** - * Update new value of current vertex. - */ - void setNewVertexValue(VV value); + /** Finish iteration computation. */ + void finish(); - } + interface VertexCentricComputeFuncContext + extends VertexCentricFuncContext { + /** Update new value of current vertex. */ + void setNewVertexValue(VV value); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricTraversalFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricTraversalFunction.java index 22944237e..5744f234b 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricTraversalFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/VertexCentricTraversalFunction.java @@ -20,6 +20,7 @@ package org.apache.geaflow.api.graph.function.vc; import java.util.Iterator; + import org.apache.geaflow.api.graph.function.vc.base.VertexCentricFunction; import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.model.graph.message.IGraphMessage; @@ -29,84 +30,63 @@ /** * Interface for vertex centric traversal function. * - * @param The id type of vertex/edge. + * @param The id type of vertex/edge. * @param The value type of vertex. * @param The value type of edge. - * @param The message type during iterations. - * @param The request type for traversal. + * @param The message type during iterations. + * @param The request type for traversal. */ -public interface VertexCentricTraversalFunction extends VertexCentricFunction { +public interface VertexCentricTraversalFunction + extends VertexCentricFunction { - /** - * Open traversal function based on context. - */ - void open(VertexCentricTraversalFuncContext vertexCentricFuncContext); + /** Open traversal function based on context. */ + void open(VertexCentricTraversalFuncContext vertexCentricFuncContext); - /** - * Initialize the traversal by request. - */ - void init(ITraversalRequest traversalRequest); + /** Initialize the traversal by request. */ + void init(ITraversalRequest traversalRequest); - /** - * Perform traversing based on message iterator during iterations. - */ - void compute(K vertexId, Iterator messageIterator); + /** Perform traversing based on message iterator during iterations. */ + void compute(K vertexId, Iterator messageIterator); - /** - * Finish iteration traversal. - */ - void finish(); + /** Finish iteration traversal. */ + void finish(); + + /** Close resources in iteration traversal. */ + void close(); + + interface VertexCentricTraversalFuncContext + extends VertexCentricFuncContext { + + /** Receive the response. */ + void takeResponse(ITraversalResponse response); + + /** Returns the TraversalVertexQuery. */ + TraversalVertexQuery vertex(); + + /** Returns the TraversalEdgeQuery. */ + TraversalEdgeQuery edges(); + + /** Broadcast message. */ + void broadcast(IGraphMessage message); + + /** Get the traversal operator name. */ + String getTraversalOpName(); + } + + interface TraversalVertexQuery extends VertexQuery { + + /** Load vertex id iterator. */ + CloseableIterator loadIdIterator(); + } + + interface TraversalEdgeQuery extends EdgeQuery { /** - * Close resources in iteration traversal. + * Set vertex id. + * + * @param vertexId + * @return */ - void close(); - - interface VertexCentricTraversalFuncContext extends VertexCentricFuncContext { - - /** - * Receive the response. - */ - void takeResponse(ITraversalResponse response); - - /** - * Returns the TraversalVertexQuery. - */ - TraversalVertexQuery vertex(); - - /** - * Returns the TraversalEdgeQuery. - */ - TraversalEdgeQuery edges(); - - /** - * Broadcast message. - */ - void broadcast(IGraphMessage message); - - /** - * Get the traversal operator name. - */ - String getTraversalOpName(); - } - - interface TraversalVertexQuery extends VertexQuery { - - /** - * Load vertex id iterator. - */ - CloseableIterator loadIdIterator(); - } - - interface TraversalEdgeQuery extends EdgeQuery { - - /** - * Set vertex id. - * - * @param vertexId - * @return - */ - TraversalEdgeQuery withId(K vertexId); - } + TraversalEdgeQuery withId(K vertexId); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/base/IncGraphInferContext.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/base/IncGraphInferContext.java index 96cf60595..e6e583a61 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/base/IncGraphInferContext.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/base/IncGraphInferContext.java @@ -22,9 +22,6 @@ public interface IncGraphInferContext extends Closeable { - /** - * Model infer. - */ - OUT infer(Object... modelInputs); - + /** Model infer. */ + OUT infer(Object... modelInputs); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/base/IncVertexCentricFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/base/IncVertexCentricFunction.java index e05a77ecc..27025be58 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/base/IncVertexCentricFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/base/IncVertexCentricFunction.java @@ -22,6 +22,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.Function; import org.apache.geaflow.api.graph.function.vc.base.VertexCentricFunction.EdgeQuery; @@ -33,160 +34,108 @@ /** * Interface for incremental vertex centric compute function. * - * @param The id type of vertex/edge. + * @param The id type of vertex/edge. * @param The value type of vertex. * @param The value type of edge. - * @param The message type during iterations. + * @param The message type during iterations. */ public interface IncVertexCentricFunction extends Function { - /** - * Evolve based on temporary graph in first iteration. - * - * @param vertexId The vertex id. - * @param temporaryGraph The incremental memory graph. - */ - void evolve(K vertexId, TemporaryGraph temporaryGraph); - - /** - * Perform computing based on message iterator during iterations. - */ - void compute(K vertexId, Iterator messageIterator); - - /** - * Finish iteration computation, could add vertices and edges from temporaryGraph into mutableGraph. - */ - void finish(K vertexId, MutableGraph mutableGraph); - - interface IncGraphContext { - - /** - * Returns the job id. - */ - long getJobId(); - - /** - * Returns the current iteration id. - */ - long getIterationId(); - - /** - * Returns the runtime context. - */ - RuntimeContext getRuntimeContext(); - - /** - * Returns the mutable graph. - */ - MutableGraph getMutableGraph(); - - /** - * Returns the incremental graph. - */ - TemporaryGraph getTemporaryGraph(); - - /** - * Returns the historical graph on graph state. - */ - HistoricalGraph getHistoricalGraph(); + /** + * Evolve based on temporary graph in first iteration. + * + * @param vertexId The vertex id. + * @param temporaryGraph The incremental memory graph. + */ + void evolve(K vertexId, TemporaryGraph temporaryGraph); + + /** Perform computing based on message iterator during iterations. */ + void compute(K vertexId, Iterator messageIterator); + + /** + * Finish iteration computation, could add vertices and edges from temporaryGraph into + * mutableGraph. + */ + void finish(K vertexId, MutableGraph mutableGraph); - /** - * Send message to vertex. - */ - void sendMessage(K vertexId, M message); + interface IncGraphContext { - /** - * Send message to neighbors of current vertex. - */ - void sendMessageToNeighbors(M message); + /** Returns the job id. */ + long getJobId(); - } + /** Returns the current iteration id. */ + long getIterationId(); - interface TemporaryGraph { + /** Returns the runtime context. */ + RuntimeContext getRuntimeContext(); - /** - * Returns the current vertex. - */ - IVertex getVertex(); + /** Returns the mutable graph. */ + MutableGraph getMutableGraph(); - /** - * Returns the edges of current vertex. - */ - List> getEdges(); + /** Returns the incremental graph. */ + TemporaryGraph getTemporaryGraph(); - /** - * Update value of current vertex. - */ - void updateVertexValue(VV value); + /** Returns the historical graph on graph state. */ + HistoricalGraph getHistoricalGraph(); - } + /** Send message to vertex. */ + void sendMessage(K vertexId, M message); - interface HistoricalGraph { + /** Send message to neighbors of current vertex. */ + void sendMessageToNeighbors(M message); + } - /** - * Returns the latest version id of graph state. - */ - Long getLatestVersionId(); + interface TemporaryGraph { - /** - * Returns all version ids of graph state. - */ - List getAllVersionIds(); + /** Returns the current vertex. */ + IVertex getVertex(); - /** - * Returns all vertices of all versions. - */ - Map> getAllVertex(); + /** Returns the edges of current vertex. */ + List> getEdges(); - /** - * Get all vertices of specified versions. - */ - Map> getAllVertex(List versions); + /** Update value of current vertex. */ + void updateVertexValue(VV value); + } - /** - * Get all vertices of specified versions which satisfy the filter. - */ - Map> getAllVertex(List versions, IVertexFilter vertexFilter); + interface HistoricalGraph { - /** - * Get the graph snapshot of specified version. - */ - GraphSnapShot getSnapShot(long version); + /** Returns the latest version id of graph state. */ + Long getLatestVersionId(); - } + /** Returns all version ids of graph state. */ + List getAllVersionIds(); - interface GraphSnapShot { + /** Returns all vertices of all versions. */ + Map> getAllVertex(); - /** - * Returns the snapshot's version. - */ - long getVersion(); + /** Get all vertices of specified versions. */ + Map> getAllVertex(List versions); - /** - * Returns the VertexQuery. - */ - VertexQuery vertex(); + /** Get all vertices of specified versions which satisfy the filter. */ + Map> getAllVertex(List versions, IVertexFilter vertexFilter); - /** - * Returns the EdgeQuery. - */ - EdgeQuery edges(); + /** Get the graph snapshot of specified version. */ + GraphSnapShot getSnapShot(long version); + } - } + interface GraphSnapShot { - interface MutableGraph { + /** Returns the snapshot's version. */ + long getVersion(); - /** - * Add vertex into mutable graph with specified version. - */ - void addVertex(long version, IVertex vertex); + /** Returns the VertexQuery. */ + VertexQuery vertex(); - /** - * Add edge into mutable graph with specified version. - */ - void addEdge(long version, IEdge edge); + /** Returns the EdgeQuery. */ + EdgeQuery edges(); + } - } + interface MutableGraph { + /** Add vertex into mutable graph with specified version. */ + void addVertex(long version, IVertex vertex); + /** Add edge into mutable graph with specified version. */ + void addEdge(long version, IEdge edge); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/base/VertexCentricFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/base/VertexCentricFunction.java index afdf52b52..2fd99248b 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/base/VertexCentricFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/function/vc/base/VertexCentricFunction.java @@ -20,6 +20,7 @@ package org.apache.geaflow.api.graph.function.vc.base; import java.util.List; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.Function; import org.apache.geaflow.common.iterator.CloseableIterator; @@ -29,84 +30,54 @@ public interface VertexCentricFunction extends Function { - interface VertexCentricFuncContext { - - /** - * Returns the job id. - */ - long getJobId(); - - /** - * Returns the current iteration id. - */ - long getIterationId(); - - /** - * Returns the runtime context. - */ - RuntimeContext getRuntimeContext(); - - /** - * Returns the VertexQuery. - */ - VertexQuery vertex(); - - /** - * Returns the EdgeQuery. - */ - EdgeQuery edges(); - - /** - * Send message to vertex. - */ - void sendMessage(K vertexId, M message); - - /** - * Send message to neighbors of current vertex. - */ - void sendMessageToNeighbors(M message); - - } - - interface VertexQuery { - - /** - * Set vertex id. - */ - VertexQuery withId(K vertexId); - - /** - * Returns the current vertex. - */ - IVertex get(); - - /** - * Get the vertex which satisfies filter condition. - */ - IVertex get(IFilter vertexFilter); - - } - - interface EdgeQuery { - - /** - * Returns the both edges. - */ - List> getEdges(); - - /** - * Returns the out edges. - */ - List> getOutEdges(); - - /** - * Returns the in edges. - */ - List> getInEdges(); - - /** - * Get the edges which satisfies filter condition. - */ - CloseableIterator> getEdges(IFilter edgeFilter); - } + interface VertexCentricFuncContext { + + /** Returns the job id. */ + long getJobId(); + + /** Returns the current iteration id. */ + long getIterationId(); + + /** Returns the runtime context. */ + RuntimeContext getRuntimeContext(); + + /** Returns the VertexQuery. */ + VertexQuery vertex(); + + /** Returns the EdgeQuery. */ + EdgeQuery edges(); + + /** Send message to vertex. */ + void sendMessage(K vertexId, M message); + + /** Send message to neighbors of current vertex. */ + void sendMessageToNeighbors(M message); + } + + interface VertexQuery { + + /** Set vertex id. */ + VertexQuery withId(K vertexId); + + /** Returns the current vertex. */ + IVertex get(); + + /** Get the vertex which satisfies filter condition. */ + IVertex get(IFilter vertexFilter); + } + + interface EdgeQuery { + + /** Returns the both edges. */ + List> getEdges(); + + /** Returns the out edges. */ + List> getOutEdges(); + + /** Returns the in edges. */ + List> getInEdges(); + + /** Get the edges which satisfies filter condition. */ + CloseableIterator> getEdges(IFilter edgeFilter); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/materialize/GraphMaterializeFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/materialize/GraphMaterializeFunction.java index c4aebacba..44cdaf5bd 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/materialize/GraphMaterializeFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/materialize/GraphMaterializeFunction.java @@ -25,13 +25,9 @@ public interface GraphMaterializeFunction extends Function { - /** - * Materialize vertex into state. - */ - void materializeVertex(IVertex vertex); + /** Materialize vertex into state. */ + void materializeVertex(IVertex vertex); - /** - * Materialize edge into state. - */ - void materializeEdge(IEdge edge); + /** Materialize edge into state. */ + void materializeEdge(IEdge edge); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/materialize/PGraphMaterialize.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/materialize/PGraphMaterialize.java index fc9cef255..e9e0db914 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/materialize/PGraphMaterialize.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/materialize/PGraphMaterialize.java @@ -23,9 +23,6 @@ public interface PGraphMaterialize extends PAction { - /** - * Build graph materialize. - */ - void materialize(); - + /** Build graph materialize. */ + void materialize(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/traversal/IncVertexCentricAggTraversal.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/traversal/IncVertexCentricAggTraversal.java index 68465fb46..f64c7fda4 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/traversal/IncVertexCentricAggTraversal.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/traversal/IncVertexCentricAggTraversal.java @@ -24,15 +24,15 @@ import org.apache.geaflow.api.graph.function.vc.IncVertexCentricAggTraversalFunction; public abstract class IncVertexCentricAggTraversal - extends AbstractIncVertexCentricTraversalAlgo> + extends AbstractIncVertexCentricTraversalAlgo< + K, VV, EV, M, R, IncVertexCentricAggTraversalFunction> implements GraphAggregationAlgo { - public IncVertexCentricAggTraversal(long iterations) { - super(iterations); - } + public IncVertexCentricAggTraversal(long iterations) { + super(iterations); + } - public IncVertexCentricAggTraversal(long iterations, String name) { - super(iterations, name); - } + public IncVertexCentricAggTraversal(long iterations, String name) { + super(iterations, name); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/traversal/IncVertexCentricTraversal.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/traversal/IncVertexCentricTraversal.java index ac3f90db8..90c342547 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/traversal/IncVertexCentricTraversal.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/traversal/IncVertexCentricTraversal.java @@ -23,13 +23,14 @@ import org.apache.geaflow.api.graph.function.vc.IncVertexCentricTraversalFunction; public abstract class IncVertexCentricTraversal - extends AbstractIncVertexCentricTraversalAlgo> { + extends AbstractIncVertexCentricTraversalAlgo< + K, VV, EV, M, R, IncVertexCentricTraversalFunction> { - public IncVertexCentricTraversal(long iterations) { - super(iterations); - } + public IncVertexCentricTraversal(long iterations) { + super(iterations); + } - public IncVertexCentricTraversal(long iterations, String name) { - super(iterations, name); - } + public IncVertexCentricTraversal(long iterations, String name) { + super(iterations, name); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/traversal/PGraphTraversal.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/traversal/PGraphTraversal.java index 12373a114..848ab45db 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/traversal/PGraphTraversal.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/traversal/PGraphTraversal.java @@ -19,8 +19,8 @@ package org.apache.geaflow.api.graph.traversal; - import java.util.List; + import org.apache.geaflow.api.graph.base.algo.GraphExecAlgo; import org.apache.geaflow.api.pdata.stream.window.PWindowStream; import org.apache.geaflow.model.traversal.ITraversalRequest; @@ -28,34 +28,22 @@ public interface PGraphTraversal { - /** - * Start traversal all computing. - */ - PWindowStream> start(); - - /** - * Start traversal computing with vid. - */ - PWindowStream> start(K vId); + /** Start traversal all computing. */ + PWindowStream> start(); - /** - * Start traversal computing with vid list. - */ - PWindowStream> start(List vId); + /** Start traversal computing with vid. */ + PWindowStream> start(K vId); - /** - * Start traversal computing with requests. - */ - PWindowStream> start(PWindowStream> requests); + /** Start traversal computing with vid list. */ + PWindowStream> start(List vId); - /** - * Set the traversal parallelism. - */ - PGraphTraversal withParallelism(int parallelism); + /** Start traversal computing with requests. */ + PWindowStream> start( + PWindowStream> requests); - /** - * Returns graph traversal type. - */ - GraphExecAlgo getGraphTraversalType(); + /** Set the traversal parallelism. */ + PGraphTraversal withParallelism(int parallelism); + /** Returns graph traversal type. */ + GraphExecAlgo getGraphTraversalType(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/traversal/VertexCentricAggTraversal.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/traversal/VertexCentricAggTraversal.java index cb91acacf..b5fefaa8f 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/traversal/VertexCentricAggTraversal.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/traversal/VertexCentricAggTraversal.java @@ -24,16 +24,15 @@ import org.apache.geaflow.api.graph.function.vc.VertexCentricAggTraversalFunction; public abstract class VertexCentricAggTraversal - extends AbstractVertexCentricTraversalAlgo> + extends AbstractVertexCentricTraversalAlgo< + K, VV, EV, M, R, VertexCentricAggTraversalFunction> implements GraphAggregationAlgo { - public VertexCentricAggTraversal(long iterations) { - super(iterations); - } - - public VertexCentricAggTraversal(long iterations, String name) { - super(iterations, name); - } + public VertexCentricAggTraversal(long iterations) { + super(iterations); + } + public VertexCentricAggTraversal(long iterations, String name) { + super(iterations, name); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/traversal/VertexCentricTraversal.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/traversal/VertexCentricTraversal.java index 88fa8fd59..96c93dc06 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/traversal/VertexCentricTraversal.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/traversal/VertexCentricTraversal.java @@ -23,13 +23,14 @@ import org.apache.geaflow.api.graph.function.vc.VertexCentricTraversalFunction; public abstract class VertexCentricTraversal - extends AbstractVertexCentricTraversalAlgo> { + extends AbstractVertexCentricTraversalAlgo< + K, VV, EV, M, R, VertexCentricTraversalFunction> { - public VertexCentricTraversal(long iterations) { - super(iterations); - } + public VertexCentricTraversal(long iterations) { + super(iterations); + } - public VertexCentricTraversal(long iterations, String name) { - super(iterations, name); - } + public VertexCentricTraversal(long iterations, String name) { + super(iterations, name); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/IPartition.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/IPartition.java index 63e6d012b..acd17203f 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/IPartition.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/IPartition.java @@ -21,14 +21,10 @@ import java.io.Serializable; -/** - * Interface for defining partition rule. - */ +/** Interface for defining partition rule. */ @FunctionalInterface public interface IPartition extends Serializable { - /** - * Compute the partition list for value. - */ - int[] partition(T value, int numPartition); + /** Compute the partition list for value. */ + int[] partition(T value, int numPartition); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/IGraphPartitioner.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/IGraphPartitioner.java index 851ef8d10..eb6b2d0ba 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/IGraphPartitioner.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/IGraphPartitioner.java @@ -20,6 +20,7 @@ package org.apache.geaflow.api.partition.graph; import java.io.Serializable; + import org.apache.geaflow.api.partition.IPartition; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; @@ -27,8 +28,7 @@ public interface IGraphPartitioner, E extends IEdge> extends Serializable { - IPartition getVertexPartitioner(); - - IPartition getEdgePartitioner(); + IPartition getVertexPartitioner(); + IPartition getEdgePartitioner(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/edge/CustomEdgeVCPartition.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/edge/CustomEdgeVCPartition.java index 91f5217db..fb07c7fad 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/edge/CustomEdgeVCPartition.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/edge/CustomEdgeVCPartition.java @@ -24,14 +24,14 @@ public class CustomEdgeVCPartition implements KeySelector, Integer> { - private IGraphVCPartition graphVCPartition; + private IGraphVCPartition graphVCPartition; - public CustomEdgeVCPartition(IGraphVCPartition graphVCPartition) { - this.graphVCPartition = graphVCPartition; - } + public CustomEdgeVCPartition(IGraphVCPartition graphVCPartition) { + this.graphVCPartition = graphVCPartition; + } - @Override - public Integer getKey(IEdge value) { - return graphVCPartition.getPartition(value.getSrcId()); - } + @Override + public Integer getKey(IEdge value) { + return graphVCPartition.getPartition(value.getSrcId()); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/edge/CustomVertexVCPartition.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/edge/CustomVertexVCPartition.java index 8405628c7..628f2b555 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/edge/CustomVertexVCPartition.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/edge/CustomVertexVCPartition.java @@ -24,15 +24,14 @@ public class CustomVertexVCPartition implements KeySelector, Integer> { - private IGraphVCPartition graphVCPartition; + private IGraphVCPartition graphVCPartition; - public CustomVertexVCPartition(IGraphVCPartition graphVCPartition) { - this.graphVCPartition = graphVCPartition; - } - - @Override - public Integer getKey(IVertex vertex) { - return graphVCPartition.getPartition(vertex.getId()); - } + public CustomVertexVCPartition(IGraphVCPartition graphVCPartition) { + this.graphVCPartition = graphVCPartition; + } + @Override + public Integer getKey(IVertex vertex) { + return graphVCPartition.getPartition(vertex.getId()); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/edge/IGraphVCPartition.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/edge/IGraphVCPartition.java index 97f3fd80d..1c4095483 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/edge/IGraphVCPartition.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/edge/IGraphVCPartition.java @@ -23,6 +23,5 @@ public interface IGraphVCPartition extends Serializable { - Integer getPartition(K value); - + Integer getPartition(K value); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/edge/VertexCentricPartitioner.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/edge/VertexCentricPartitioner.java index 24417027d..202304baa 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/edge/VertexCentricPartitioner.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/edge/VertexCentricPartitioner.java @@ -24,33 +24,32 @@ import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; -public class VertexCentricPartitioner implements IGraphPartitioner, IEdge> { +public class VertexCentricPartitioner + implements IGraphPartitioner, IEdge> { - @Override - public IPartition getVertexPartitioner() { - return new DefaultVertexPartition(); - } + @Override + public IPartition getVertexPartitioner() { + return new DefaultVertexPartition(); + } - static class DefaultVertexPartition implements IPartition> { - - @Override - public int[] partition(IVertex value, int numPartition) { - return new int[]{Math.abs(value.getId().hashCode() % numPartition)}; - } - } + static class DefaultVertexPartition implements IPartition> { @Override - public IPartition getEdgePartitioner() { - return new DefaultEdgePartition(); + public int[] partition(IVertex value, int numPartition) { + return new int[] {Math.abs(value.getId().hashCode() % numPartition)}; } + } - static class DefaultEdgePartition implements IPartition> { + @Override + public IPartition getEdgePartitioner() { + return new DefaultEdgePartition(); + } - @Override - public int[] partition(IEdge value, int numPartition) { - return new int[]{Math.abs(value.getSrcId().hashCode() % numPartition)}; - } - } + static class DefaultEdgePartition implements IPartition> { + @Override + public int[] partition(IEdge value, int numPartition) { + return new int[] {Math.abs(value.getSrcId().hashCode() % numPartition)}; + } + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/request/DefaultTraversalRequestPartition.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/request/DefaultTraversalRequestPartition.java index f6ac8ff93..48766066e 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/request/DefaultTraversalRequestPartition.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/request/DefaultTraversalRequestPartition.java @@ -24,8 +24,8 @@ public class DefaultTraversalRequestPartition implements KeySelector, K> { - @Override - public K getKey(ITraversalRequest request) { - return request.getVId(); - } + @Override + public K getKey(ITraversalRequest request) { + return request.getVId(); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/vertex/GraphPartitioner.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/vertex/GraphPartitioner.java index a36fbadb3..48d9d08b0 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/vertex/GraphPartitioner.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/graph/vertex/GraphPartitioner.java @@ -24,11 +24,10 @@ import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; -public interface GraphPartitioner extends - IGraphPartitioner, IEdge> { +public interface GraphPartitioner + extends IGraphPartitioner, IEdge> { - IPartition> getVertexPartitioner(); - - IPartition> getEdgePartitioner(); + IPartition> getVertexPartitioner(); + IPartition> getEdgePartitioner(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/kv/BroadCastPartition.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/kv/BroadCastPartition.java index 7d069b575..91232e2b0 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/kv/BroadCastPartition.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/kv/BroadCastPartition.java @@ -20,19 +20,20 @@ package org.apache.geaflow.api.partition.kv; import java.util.stream.IntStream; + import org.apache.geaflow.api.partition.IPartition; public class BroadCastPartition implements IPartition { - private int[] partitions = null; + private int[] partitions = null; - @Override - public int[] partition(T value, int numPartition) { - if (partitions != null) { - return partitions; - } else { - partitions = IntStream.rangeClosed(0, numPartition - 1).toArray(); - return partitions; - } + @Override + public int[] partition(T value, int numPartition) { + if (partitions != null) { + return partitions; + } else { + partitions = IntStream.rangeClosed(0, numPartition - 1).toArray(); + return partitions; } + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/kv/KeyByPartition.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/kv/KeyByPartition.java index a8236c3b1..4589c3eb9 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/kv/KeyByPartition.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/kv/KeyByPartition.java @@ -24,17 +24,15 @@ public class KeyByPartition implements IPartition { - private int maxParallelism; - - public KeyByPartition(int maxParallelism) { - this.maxParallelism = maxParallelism; - } - - @Override - public int[] partition(K key, int numPartition) { - int channel = KeyGroupAssignment.assignKeyToParallelTask(key, maxParallelism, numPartition); - return new int[]{channel}; - } + private int maxParallelism; + public KeyByPartition(int maxParallelism) { + this.maxParallelism = maxParallelism; + } + @Override + public int[] partition(K key, int numPartition) { + int channel = KeyGroupAssignment.assignKeyToParallelTask(key, maxParallelism, numPartition); + return new int[] {channel}; + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/kv/RandomPartition.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/kv/RandomPartition.java index 20ad54d69..670043681 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/kv/RandomPartition.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/partition/kv/RandomPartition.java @@ -23,10 +23,10 @@ public class RandomPartition implements IPartition { - private long index = System.currentTimeMillis() % 173; + private long index = System.currentTimeMillis() % 173; - @Override - public int[] partition(T value, int numPartition) { - return new int[]{(int) ((index++) % numPartition)}; - } + @Override + public int[] partition(T value, int numPartition) { + return new int[] {(int) ((index++) % numPartition)}; + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/PStreamSink.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/PStreamSink.java index cfb2b205a..c79848bee 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/PStreamSink.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/PStreamSink.java @@ -20,16 +20,16 @@ package org.apache.geaflow.api.pdata; import java.util.Map; + import org.apache.geaflow.api.pdata.base.PAction; public interface PStreamSink extends PAction { - PStreamSink withParallelism(int parallelism); - - PStreamSink withName(String name); + PStreamSink withParallelism(int parallelism); - PStreamSink withConfig(Map map); + PStreamSink withName(String name); - PStreamSink withConfig(String key, String value); + PStreamSink withConfig(Map map); + PStreamSink withConfig(String key, String value); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/PStreamSource.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/PStreamSource.java index 31ba5b1f0..67f1db28c 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/PStreamSource.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/PStreamSource.java @@ -22,6 +22,7 @@ import java.util.Iterator; import java.util.Map; import java.util.ServiceLoader; + import org.apache.geaflow.api.function.io.SourceFunction; import org.apache.geaflow.api.pdata.base.PSource; import org.apache.geaflow.api.pdata.stream.PStream; @@ -36,49 +37,48 @@ public interface PStreamSource extends PStream, PSource { - Logger LOGGER = LoggerFactory.getLogger(PStreamSource.class); + Logger LOGGER = LoggerFactory.getLogger(PStreamSource.class); - static PStreamSource from(IPipelineContext pipelineContext, - SourceFunction sourceFunction, - IWindow window) { - LOGGER.info("load PStreamSource SPI Implementation"); - ServiceLoader serviceLoader = ServiceLoader.load(PStreamSource.class); - Iterator iterator = serviceLoader.iterator(); - boolean hasImpl = iterator.hasNext(); - if (hasImpl) { - PStreamSource streamSource = iterator.next(); - return streamSource.build(pipelineContext, sourceFunction, window); - } else { - throw new GeaflowRuntimeException(RuntimeErrors.INST.spiNotFoundError(PStreamSource.class.getSimpleName())); - } + static PStreamSource from( + IPipelineContext pipelineContext, SourceFunction sourceFunction, IWindow window) { + LOGGER.info("load PStreamSource SPI Implementation"); + ServiceLoader serviceLoader = ServiceLoader.load(PStreamSource.class); + Iterator iterator = serviceLoader.iterator(); + boolean hasImpl = iterator.hasNext(); + if (hasImpl) { + PStreamSource streamSource = iterator.next(); + return streamSource.build(pipelineContext, sourceFunction, window); + } else { + throw new GeaflowRuntimeException( + RuntimeErrors.INST.spiNotFoundError(PStreamSource.class.getSimpleName())); } + } - /** - * Build source by window. - * - * @param pipelineContext - * @param sourceFunction - * @param window - * @return - */ - PStreamSource build(IPipelineContext pipelineContext, SourceFunction sourceFunction, - IWindow window); - - PWindowSource window(IWindow window); + /** + * Build source by window. + * + * @param pipelineContext + * @param sourceFunction + * @param window + * @return + */ + PStreamSource build( + IPipelineContext pipelineContext, SourceFunction sourceFunction, IWindow window); - @Override - PStreamSource withConfig(Map map); + PWindowSource window(IWindow window); - @Override - PStreamSource withConfig(String key, String value); + @Override + PStreamSource withConfig(Map map); - @Override - PStreamSource withName(String name); + @Override + PStreamSource withConfig(String key, String value); - @Override - PStreamSource withParallelism(int parallelism); + @Override + PStreamSource withName(String name); - @Override - PStreamSource withEncoder(IEncoder encoder); + @Override + PStreamSource withParallelism(int parallelism); + @Override + PStreamSource withEncoder(IEncoder encoder); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/PWindowCollect.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/PWindowCollect.java index 640d2c026..1e440b6b6 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/PWindowCollect.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/PWindowCollect.java @@ -20,20 +20,20 @@ package org.apache.geaflow.api.pdata; import java.util.Map; + import org.apache.geaflow.api.pdata.base.PAction; public interface PWindowCollect extends PAction { - @Override - PWindowCollect withParallelism(int parallelism); - - @Override - PWindowCollect withName(String name); + @Override + PWindowCollect withParallelism(int parallelism); - @Override - PWindowCollect withConfig(Map map); + @Override + PWindowCollect withName(String name); - @Override - PWindowCollect withConfig(String key, String value); + @Override + PWindowCollect withConfig(Map map); + @Override + PWindowCollect withConfig(String key, String value); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/base/PAction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/base/PAction.java index 75848baa4..34de9cedd 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/base/PAction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/base/PAction.java @@ -21,21 +21,19 @@ import java.io.Serializable; import java.util.Map; + import org.apache.geaflow.api.pdata.PStreamSink; -/** - * Interface for triggering executing, such as {@link PStreamSink}. - */ +/** Interface for triggering executing, such as {@link PStreamSink}. */ public interface PAction extends Serializable { - int getId(); - - R withParallelism(int parallelism); + int getId(); - R withName(String name); + R withParallelism(int parallelism); - R withConfig(Map map); + R withName(String name); - R withConfig(String key, String value); + R withConfig(Map map); + R withConfig(String key, String value); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/base/PData.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/base/PData.java index 25a8c83b8..20799a9e0 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/base/PData.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/base/PData.java @@ -24,29 +24,18 @@ public interface PData extends Serializable { - /** - * Returns id. - */ - int getId(); - - /** - * Set config. - */ - R withConfig(Map map); - - /** - * Add key value pair config. - */ - R withConfig(String key, String value); - - /** - * Set name. - */ - R withName(String name); - - /** - * Set parallelism. - */ - R withParallelism(int parallelism); + /** Returns id. */ + int getId(); + /** Set config. */ + R withConfig(Map map); + + /** Add key value pair config. */ + R withConfig(String key, String value); + + /** Set name. */ + R withName(String name); + + /** Set parallelism. */ + R withParallelism(int parallelism); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/base/PSource.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/base/PSource.java index da4f5b85d..53ed89503 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/base/PSource.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/base/PSource.java @@ -21,6 +21,4 @@ import java.io.Serializable; -public interface PSource extends Serializable { - -} +public interface PSource extends Serializable {} diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/PKeyStream.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/PKeyStream.java index ce9813749..1d9d2c8bb 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/PKeyStream.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/PKeyStream.java @@ -20,35 +20,37 @@ package org.apache.geaflow.api.pdata.stream; import java.util.Map; + import org.apache.geaflow.api.function.base.AggregateFunction; import org.apache.geaflow.api.function.base.ReduceFunction; import org.apache.geaflow.common.encoder.IEncoder; public interface PKeyStream extends PStream { - /** - * Default is incremental reduce compute, otherwise is window reduce compute if INC_STREAM_MATERIALIZE_DISABLE is true. - */ - PStream reduce(ReduceFunction reduceFunction); - - /** - * Default is incremental aggregate compute, otherwise is window aggregate compute if INC_STREAM_MATERIALIZE_DISABLE is true. - */ - PStream aggregate(AggregateFunction aggregateFunction); + /** + * Default is incremental reduce compute, otherwise is window reduce compute if + * INC_STREAM_MATERIALIZE_DISABLE is true. + */ + PStream reduce(ReduceFunction reduceFunction); - @Override - PKeyStream withConfig(Map map); + /** + * Default is incremental aggregate compute, otherwise is window aggregate compute if + * INC_STREAM_MATERIALIZE_DISABLE is true. + */ + PStream aggregate(AggregateFunction aggregateFunction); - @Override - PKeyStream withConfig(String key, String value); + @Override + PKeyStream withConfig(Map map); - @Override - PKeyStream withName(String name); + @Override + PKeyStream withConfig(String key, String value); - @Override - PKeyStream withParallelism(int parallelism); + @Override + PKeyStream withName(String name); - @Override - PKeyStream withEncoder(IEncoder encoder); + @Override + PKeyStream withParallelism(int parallelism); + @Override + PKeyStream withEncoder(IEncoder encoder); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/PStream.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/PStream.java index aeb839a6a..a7bba3eb1 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/PStream.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/PStream.java @@ -20,6 +20,7 @@ package org.apache.geaflow.api.pdata.stream; import java.util.Map; + import org.apache.geaflow.api.function.base.FilterFunction; import org.apache.geaflow.api.function.base.FlatMapFunction; import org.apache.geaflow.api.function.base.KeySelector; @@ -31,33 +32,31 @@ public interface PStream extends PData { - PStream map(MapFunction mapFunction); - - PStream filter(FilterFunction filterFunction); - - PStream flatMap(FlatMapFunction flatMapFunction); + PStream map(MapFunction mapFunction); - PStream union(PStream uStream); + PStream filter(FilterFunction filterFunction); - PStream broadcast(); + PStream flatMap(FlatMapFunction flatMapFunction); - PKeyStream keyBy(KeySelector selectorFunction); + PStream union(PStream uStream); - PStreamSink sink(SinkFunction sinkFunction); + PStream broadcast(); - @Override - PStream withConfig(Map map); + PKeyStream keyBy(KeySelector selectorFunction); - @Override - PStream withConfig(String key, String value); + PStreamSink sink(SinkFunction sinkFunction); - @Override - PStream withName(String name); + @Override + PStream withConfig(Map map); - @Override - PStream withParallelism(int parallelism); + @Override + PStream withConfig(String key, String value); + @Override + PStream withName(String name); - PStream withEncoder(IEncoder encoder); + @Override + PStream withParallelism(int parallelism); + PStream withEncoder(IEncoder encoder); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/PUnionStream.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/PUnionStream.java index d3b0333ef..3fc2cdb67 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/PUnionStream.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/PUnionStream.java @@ -20,22 +20,22 @@ package org.apache.geaflow.api.pdata.stream; import java.util.Map; + import org.apache.geaflow.common.encoder.IEncoder; public interface PUnionStream extends PStream { - @Override - PUnionStream withConfig(Map map); - - @Override - PUnionStream withConfig(String key, String value); + @Override + PUnionStream withConfig(Map map); - @Override - PUnionStream withName(String name); + @Override + PUnionStream withConfig(String key, String value); - @Override - PUnionStream withParallelism(int parallelism); + @Override + PUnionStream withName(String name); - PUnionStream withEncoder(IEncoder encoder); + @Override + PUnionStream withParallelism(int parallelism); + PUnionStream withEncoder(IEncoder encoder); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/view/PIncStreamView.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/view/PIncStreamView.java index d456d5a5d..5f1da631e 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/view/PIncStreamView.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/view/PIncStreamView.java @@ -25,15 +25,9 @@ public interface PIncStreamView extends PStreamView { - /** - * Incremental reduce compute. - */ - PWindowStream reduce(ReduceFunction reduceFunction); - - /** - * Incremental aggregate compute. - */ - PWindowStream aggregate(AggregateFunction aggregateFunction); - + /** Incremental reduce compute. */ + PWindowStream reduce(ReduceFunction reduceFunction); + /** Incremental aggregate compute. */ + PWindowStream aggregate(AggregateFunction aggregateFunction); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/view/PStreamView.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/view/PStreamView.java index 615c7478b..2ce14f768 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/view/PStreamView.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/view/PStreamView.java @@ -25,14 +25,9 @@ public interface PStreamView extends PView { - /** - * Initialize stream view. - */ - PStreamView init(IViewDesc viewDesc); - - /** - * Append windowStream into incremental view. - */ - PIncStreamView append(PWindowStream windowStream); + /** Initialize stream view. */ + PStreamView init(IViewDesc viewDesc); + /** Append windowStream into incremental view. */ + PIncStreamView append(PWindowStream windowStream); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/window/PWindowBroadcastStream.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/window/PWindowBroadcastStream.java index 49d375c22..ea3c7229d 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/window/PWindowBroadcastStream.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/window/PWindowBroadcastStream.java @@ -20,22 +20,23 @@ package org.apache.geaflow.api.pdata.stream.window; import java.util.Map; + import org.apache.geaflow.common.encoder.IEncoder; public interface PWindowBroadcastStream extends PWindowStream { - @Override - PWindowBroadcastStream withConfig(Map config); + @Override + PWindowBroadcastStream withConfig(Map config); - @Override - PWindowBroadcastStream withConfig(String key, String value); + @Override + PWindowBroadcastStream withConfig(String key, String value); - @Override - PWindowBroadcastStream withName(String name); + @Override + PWindowBroadcastStream withName(String name); - @Override - PWindowBroadcastStream withParallelism(int parallelism); + @Override + PWindowBroadcastStream withParallelism(int parallelism); - @Override - PWindowBroadcastStream withEncoder(IEncoder encoder); + @Override + PWindowBroadcastStream withEncoder(IEncoder encoder); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/window/PWindowKeyStream.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/window/PWindowKeyStream.java index 9a1c149fa..406e827a9 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/window/PWindowKeyStream.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/window/PWindowKeyStream.java @@ -20,6 +20,7 @@ package org.apache.geaflow.api.pdata.stream.window; import java.util.Map; + import org.apache.geaflow.api.function.base.AggregateFunction; import org.apache.geaflow.api.function.base.ReduceFunction; import org.apache.geaflow.api.pdata.stream.PKeyStream; @@ -28,40 +29,36 @@ public interface PWindowKeyStream extends PWindowStream, PKeyStream { - /** - * Default is incremental reduce compute, otherwise is window reduce compute if INC_STREAM_MATERIALIZE_DISABLE is true. - */ - @Override - PWindowStream reduce(ReduceFunction reduceFunction); - - /** - * Default is incremental aggregate compute, otherwise is window aggregate compute if INC_STREAM_MATERIALIZE_DISABLE is true. - */ - @Override - PWindowStream aggregate(AggregateFunction aggregateFunction); - - /** - * Build incremental stream view. - */ - PIncStreamView materialize(); + /** + * Default is incremental reduce compute, otherwise is window reduce compute if + * INC_STREAM_MATERIALIZE_DISABLE is true. + */ + @Override + PWindowStream reduce(ReduceFunction reduceFunction); - @Override - PWindowKeyStream withConfig(Map config); + /** + * Default is incremental aggregate compute, otherwise is window aggregate compute if + * INC_STREAM_MATERIALIZE_DISABLE is true. + */ + @Override + PWindowStream aggregate(AggregateFunction aggregateFunction); - @Override - PWindowKeyStream withConfig(String key, String value); + /** Build incremental stream view. */ + PIncStreamView materialize(); - @Override - PWindowKeyStream withName(String name); + @Override + PWindowKeyStream withConfig(Map config); - @Override - PWindowKeyStream withParallelism(int parallelism); + @Override + PWindowKeyStream withConfig(String key, String value); - /** - * Set the encoder for performance. - */ - @Override - PWindowKeyStream withEncoder(IEncoder encoder); + @Override + PWindowKeyStream withName(String name); + @Override + PWindowKeyStream withParallelism(int parallelism); + /** Set the encoder for performance. */ + @Override + PWindowKeyStream withEncoder(IEncoder encoder); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/window/PWindowSource.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/window/PWindowSource.java index 5c37db1c4..c01955842 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/window/PWindowSource.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/window/PWindowSource.java @@ -20,24 +20,24 @@ package org.apache.geaflow.api.pdata.stream.window; import java.util.Map; + import org.apache.geaflow.api.pdata.PStreamSource; import org.apache.geaflow.common.encoder.IEncoder; public interface PWindowSource extends PWindowStream, PStreamSource { - @Override - PWindowSource withConfig(Map map); - - @Override - PWindowSource withConfig(String key, String value); + @Override + PWindowSource withConfig(Map map); - @Override - PWindowSource withName(String name); + @Override + PWindowSource withConfig(String key, String value); - @Override - PWindowSource withParallelism(int parallelism); + @Override + PWindowSource withName(String name); - @Override - PWindowSource withEncoder(IEncoder encoder); + @Override + PWindowSource withParallelism(int parallelism); + @Override + PWindowSource withEncoder(IEncoder encoder); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/window/PWindowStream.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/window/PWindowStream.java index 944af863a..114be5fb4 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/window/PWindowStream.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/pdata/stream/window/PWindowStream.java @@ -20,6 +20,7 @@ package org.apache.geaflow.api.pdata.stream.window; import java.util.Map; + import org.apache.geaflow.api.function.base.FilterFunction; import org.apache.geaflow.api.function.base.FlatMapFunction; import org.apache.geaflow.api.function.base.KeySelector; @@ -32,84 +33,56 @@ public interface PWindowStream extends PStream { - /** - * Transform T to R by mapFunction. - */ - @Override - PWindowStream map(MapFunction mapFunction); - - /** - * Filter T with filterFunction return false. - */ - @Override - PWindowStream filter(FilterFunction filterFunction); - - /** - * Transform T into 0~n R by flatMapFunction. - */ - @Override - PWindowStream flatMap(FlatMapFunction flatMapFunction); - - /** - * Perform union operation with uStream. - */ - @Override - PWindowStream union(PStream uStream); - - /** - * Broadcast records to downstream. - */ - PWindowBroadcastStream broadcast(); - - /** - * Partition by some key based on selectorFunction. - */ - @Override - PWindowKeyStream keyBy(KeySelector selectorFunction); - - /** - * Output data by sinkFunction. - */ - @Override - PStreamSink sink(SinkFunction sinkFunction); - - /** - * Collect result. - */ - PWindowCollect collect(); - - /** - * Set config. - */ - @Override - PWindowStream withConfig(Map map); - - /** - * Set config with key value pair. - */ - @Override - PWindowStream withConfig(String key, String value); - - /** - * Set name. - */ - @Override - PWindowStream withName(String name); - - /** - * Set parallelism of stream. - */ - @Override - PWindowStream withParallelism(int parallelism); - - /** - * Set encoder for performance. - */ - @Override - PWindowStream withEncoder(IEncoder encoder); - - /** - * Returns the parallelism. - */ - int getParallelism(); + /** Transform T to R by mapFunction. */ + @Override + PWindowStream map(MapFunction mapFunction); + + /** Filter T with filterFunction return false. */ + @Override + PWindowStream filter(FilterFunction filterFunction); + + /** Transform T into 0~n R by flatMapFunction. */ + @Override + PWindowStream flatMap(FlatMapFunction flatMapFunction); + + /** Perform union operation with uStream. */ + @Override + PWindowStream union(PStream uStream); + + /** Broadcast records to downstream. */ + PWindowBroadcastStream broadcast(); + + /** Partition by some key based on selectorFunction. */ + @Override + PWindowKeyStream keyBy(KeySelector selectorFunction); + + /** Output data by sinkFunction. */ + @Override + PStreamSink sink(SinkFunction sinkFunction); + + /** Collect result. */ + PWindowCollect collect(); + + /** Set config. */ + @Override + PWindowStream withConfig(Map map); + + /** Set config with key value pair. */ + @Override + PWindowStream withConfig(String key, String value); + + /** Set name. */ + @Override + PWindowStream withName(String name); + + /** Set parallelism of stream. */ + @Override + PWindowStream withParallelism(int parallelism); + + /** Set encoder for performance. */ + @Override + PWindowStream withEncoder(IEncoder encoder); + + /** Returns the parallelism. */ + int getParallelism(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/trait/CancellableTrait.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/trait/CancellableTrait.java index bb192d859..858a59b44 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/trait/CancellableTrait.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/trait/CancellableTrait.java @@ -21,8 +21,6 @@ public interface CancellableTrait { - /** - * Cancel current running task. - */ - void cancel(); + /** Cancel current running task. */ + void cancel(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/trait/CheckpointTrait.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/trait/CheckpointTrait.java index 18be3c423..2612f634b 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/trait/CheckpointTrait.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/trait/CheckpointTrait.java @@ -21,8 +21,6 @@ public interface CheckpointTrait { - /** - * Do checkpoint. - */ - void checkpoint(long windowId); + /** Do checkpoint. */ + void checkpoint(long windowId); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/trait/TransactionTrait.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/trait/TransactionTrait.java index 1e7fb0116..13db51aad 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/trait/TransactionTrait.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/trait/TransactionTrait.java @@ -21,13 +21,9 @@ public interface TransactionTrait { - /** - * Finish the window execution. - */ - void finish(long windowId); + /** Finish the window execution. */ + void finish(long windowId); - /** - * Rollback to windowId. - */ - void rollback(long windowId); + /** Rollback to windowId. */ + void rollback(long windowId); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/ITumblingWindow.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/ITumblingWindow.java index a1c12709d..4492a4de0 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/ITumblingWindow.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/ITumblingWindow.java @@ -19,6 +19,4 @@ package org.apache.geaflow.api.window; -public interface ITumblingWindow extends IWindow { - -} +public interface ITumblingWindow extends IWindow {} diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/IWindow.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/IWindow.java index bbb671e7f..a5aa1a6cf 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/IWindow.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/IWindow.java @@ -20,28 +20,20 @@ package org.apache.geaflow.api.window; import java.io.Serializable; + import org.apache.geaflow.api.function.Function; public interface IWindow extends Function, Serializable { - /** - * Returns the window id. - */ - long windowId(); - - /** - * Initialize window with windowId. - */ - void initWindow(long windowId); + /** Returns the window id. */ + long windowId(); - /** - * Assign window id for value. - */ - long assignWindow(T value); + /** Initialize window with windowId. */ + void initWindow(long windowId); - /** - * Return window type. - */ - WindowType getType(); + /** Assign window id for value. */ + long assignWindow(T value); + /** Return window type. */ + WindowType getType(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/WindowFactory.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/WindowFactory.java index 76bdb6ab8..0dc078709 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/WindowFactory.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/WindowFactory.java @@ -20,17 +20,17 @@ package org.apache.geaflow.api.window; import java.io.Serializable; + import org.apache.geaflow.api.window.impl.AllWindow; import org.apache.geaflow.api.window.impl.SizeTumblingWindow; public class WindowFactory implements Serializable { - public static IWindow createSizeTumblingWindow(long size) { - return SizeTumblingWindow.of(size); - } - - public static IWindow allWindow() { - return AllWindow.getInstance(); - } + public static IWindow createSizeTumblingWindow(long size) { + return SizeTumblingWindow.of(size); + } + public static IWindow allWindow() { + return AllWindow.getInstance(); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/WindowType.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/WindowType.java index 908804095..e81188296 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/WindowType.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/WindowType.java @@ -20,8 +20,8 @@ package org.apache.geaflow.api.window; public enum WindowType { - ALL_WINDOW, // all data - FIXED_TIME_TUMBLING_WINDOW, // window with time unit - SIZE_TUMBLING_WINDOW, // window with size unit - CUSTOM + ALL_WINDOW, // all data + FIXED_TIME_TUMBLING_WINDOW, // window with time unit + SIZE_TUMBLING_WINDOW, // window with size unit + CUSTOM } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/impl/AllWindow.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/impl/AllWindow.java index 6e6e0d8ec..c770b0e49 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/impl/AllWindow.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/impl/AllWindow.java @@ -24,31 +24,29 @@ public class AllWindow implements IWindow { - private static final long DEFAULT_GLOBAL_WINDOW_ID = 0; + private static final long DEFAULT_GLOBAL_WINDOW_ID = 0; - private AllWindow() { - } + private AllWindow() {} - @Override - public long windowId() { - return DEFAULT_GLOBAL_WINDOW_ID; - } + @Override + public long windowId() { + return DEFAULT_GLOBAL_WINDOW_ID; + } - @Override - public void initWindow(long windowId) { - } + @Override + public void initWindow(long windowId) {} - @Override - public long assignWindow(T value) { - return DEFAULT_GLOBAL_WINDOW_ID; - } + @Override + public long assignWindow(T value) { + return DEFAULT_GLOBAL_WINDOW_ID; + } - public static synchronized AllWindow getInstance() { - return new AllWindow<>(); - } + public static synchronized AllWindow getInstance() { + return new AllWindow<>(); + } - @Override - public WindowType getType() { - return WindowType.ALL_WINDOW; - } + @Override + public WindowType getType() { + return WindowType.ALL_WINDOW; + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/impl/FixedTimeTumblingWindow.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/impl/FixedTimeTumblingWindow.java index 8d2ef63ff..676ef8240 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/impl/FixedTimeTumblingWindow.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/impl/FixedTimeTumblingWindow.java @@ -24,34 +24,34 @@ public class FixedTimeTumblingWindow implements ITumblingWindow { - private final long timeWindowInSecond; - private long windowId; - - public FixedTimeTumblingWindow(long timeWindowInSecond) { - this.timeWindowInSecond = timeWindowInSecond; - } - - public long getTimeWindowSize() { - return timeWindowInSecond; - } - - @Override - public long windowId() { - return windowId; - } - - @Override - public void initWindow(long windowId) { - this.windowId = windowId; - } - - @Override - public long assignWindow(T value) { - return windowId; - } - - @Override - public WindowType getType() { - return WindowType.FIXED_TIME_TUMBLING_WINDOW; - } + private final long timeWindowInSecond; + private long windowId; + + public FixedTimeTumblingWindow(long timeWindowInSecond) { + this.timeWindowInSecond = timeWindowInSecond; + } + + public long getTimeWindowSize() { + return timeWindowInSecond; + } + + @Override + public long windowId() { + return windowId; + } + + @Override + public void initWindow(long windowId) { + this.windowId = windowId; + } + + @Override + public long assignWindow(T value) { + return windowId; + } + + @Override + public WindowType getType() { + return WindowType.FIXED_TIME_TUMBLING_WINDOW; + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/impl/SizeTumblingWindow.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/impl/SizeTumblingWindow.java index 5713dc5f3..dcc4695c2 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/impl/SizeTumblingWindow.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/window/impl/SizeTumblingWindow.java @@ -24,45 +24,45 @@ public class SizeTumblingWindow implements ITumblingWindow { - private final long size; - private long count; - private long windowId; + private final long size; + private long count; + private long windowId; - public SizeTumblingWindow(long size) { - this.size = size; - this.count = 0; - } + public SizeTumblingWindow(long size) { + this.size = size; + this.count = 0; + } - public long getSize() { - return size; - } + public long getSize() { + return size; + } - public static SizeTumblingWindow of(long size) { - return new SizeTumblingWindow(size); - } + public static SizeTumblingWindow of(long size) { + return new SizeTumblingWindow(size); + } - @Override - public long windowId() { - return this.windowId; - } + @Override + public long windowId() { + return this.windowId; + } - @Override - public void initWindow(long windowId) { - this.windowId = windowId; - this.count = 0; - } + @Override + public void initWindow(long windowId) { + this.windowId = windowId; + this.count = 0; + } - @Override - public long assignWindow(T value) { - if (count++ < size) { - return windowId; - } else { - return windowId + 1; - } + @Override + public long assignWindow(T value) { + if (count++ < size) { + return windowId; + } else { + return windowId + 1; } + } - @Override - public WindowType getType() { - return WindowType.SIZE_TUMBLING_WINDOW; - } + @Override + public WindowType getType() { + return WindowType.SIZE_TUMBLING_WINDOW; + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/Environment.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/Environment.java index f790b3751..d35947d7f 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/Environment.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/Environment.java @@ -19,25 +19,24 @@ package org.apache.geaflow.env; - import org.apache.geaflow.env.ctx.EnvironmentContext; import org.apache.geaflow.env.ctx.IEnvironmentContext; import org.apache.geaflow.pipeline.Pipeline; public abstract class Environment implements IEnvironment { - protected Pipeline pipeline; - protected IEnvironmentContext context; + protected Pipeline pipeline; + protected IEnvironmentContext context; - public Environment() { - context = new EnvironmentContext(); - } + public Environment() { + context = new EnvironmentContext(); + } - public void addPipeline(Pipeline pipeline) { - this.pipeline = pipeline; - } + public void addPipeline(Pipeline pipeline) { + this.pipeline = pipeline; + } - public IEnvironmentContext getEnvironmentContext() { - return context; - } + public IEnvironmentContext getEnvironmentContext() { + return context; + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/EnvironmentFactory.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/EnvironmentFactory.java index 9bf606432..956e0bdd9 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/EnvironmentFactory.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/EnvironmentFactory.java @@ -27,6 +27,7 @@ import java.util.Iterator; import java.util.Map; import java.util.ServiceLoader; + import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -39,81 +40,82 @@ public class EnvironmentFactory { - private static final Logger LOGGER = LoggerFactory.getLogger(EnvironmentFactory.class); + private static final Logger LOGGER = LoggerFactory.getLogger(EnvironmentFactory.class); - public static Environment onLocalEnvironment() { - Map config = new HashMap<>(); - config.put(ExecutionConfigKeys.HTTP_REST_SERVICE_ENABLE.getKey(), Boolean.FALSE.toString()); - return onLocalEnvironment(config); - } + public static Environment onLocalEnvironment() { + Map config = new HashMap<>(); + config.put(ExecutionConfigKeys.HTTP_REST_SERVICE_ENABLE.getKey(), Boolean.FALSE.toString()); + return onLocalEnvironment(config); + } - public static Environment onLocalEnvironment(String[] args) { - IEnvironmentArgsParser argsParser = loadEnvironmentArgsParser(); - Map config = new HashMap<>(argsParser.parse(args)); - return onLocalEnvironment(config); - } + public static Environment onLocalEnvironment(String[] args) { + IEnvironmentArgsParser argsParser = loadEnvironmentArgsParser(); + Map config = new HashMap<>(argsParser.parse(args)); + return onLocalEnvironment(config); + } - private static Environment onLocalEnvironment(Map config) { - config.put(RUN_LOCAL_MODE.getKey(), Boolean.TRUE.toString()); - // Set default state backend type to memory on local env. - config.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); - if (!config.containsKey(SYSTEM_OFFSET_BACKEND_TYPE.getKey())) { - config.put(SYSTEM_OFFSET_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); - } - Environment environment = (Environment) loadEnvironment(EnvType.LOCAL); - environment.getEnvironmentContext().withConfig(config); - return environment; + private static Environment onLocalEnvironment(Map config) { + config.put(RUN_LOCAL_MODE.getKey(), Boolean.TRUE.toString()); + // Set default state backend type to memory on local env. + config.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); + if (!config.containsKey(SYSTEM_OFFSET_BACKEND_TYPE.getKey())) { + config.put(SYSTEM_OFFSET_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); } + Environment environment = (Environment) loadEnvironment(EnvType.LOCAL); + environment.getEnvironmentContext().withConfig(config); + return environment; + } - public static Environment onRayEnvironment() { - return (Environment) loadEnvironment(EnvType.RAY); - } + public static Environment onRayEnvironment() { + return (Environment) loadEnvironment(EnvType.RAY); + } - public static Environment onRayEnvironment(String[] args) { - Environment environment = (Environment) loadEnvironment(EnvType.RAY); - IEnvironmentArgsParser argsParser = loadEnvironmentArgsParser(); - environment.getEnvironmentContext().withConfig(argsParser.parse(args)); - return environment; - } + public static Environment onRayEnvironment(String[] args) { + Environment environment = (Environment) loadEnvironment(EnvType.RAY); + IEnvironmentArgsParser argsParser = loadEnvironmentArgsParser(); + environment.getEnvironmentContext().withConfig(argsParser.parse(args)); + return environment; + } - public static Environment onK8SEnvironment() { - return (Environment) loadEnvironment(EnvType.K8S); - } + public static Environment onK8SEnvironment() { + return (Environment) loadEnvironment(EnvType.K8S); + } - public static Environment onK8SEnvironment(String[] args) { - Environment environment = (Environment) loadEnvironment(EnvType.K8S); - IEnvironmentArgsParser argsParser = loadEnvironmentArgsParser(); - environment.getEnvironmentContext().withConfig(argsParser.parse(args)); - return environment; - } + public static Environment onK8SEnvironment(String[] args) { + Environment environment = (Environment) loadEnvironment(EnvType.K8S); + IEnvironmentArgsParser argsParser = loadEnvironmentArgsParser(); + environment.getEnvironmentContext().withConfig(argsParser.parse(args)); + return environment; + } - private static IEnvironment loadEnvironment(EnvType envType) { - ServiceLoader contextLoader = ServiceLoader.load(IEnvironment.class); - Iterator contextIterable = contextLoader.iterator(); - while (contextIterable.hasNext()) { - IEnvironment environment = contextIterable.next(); - if (environment.getEnvType() == envType) { - LOGGER.info("loaded IEnvironment implementation {}", environment); - return environment; - } - } - LOGGER.error("NOT found IEnvironment implementation with type:{}", envType); - throw new GeaflowRuntimeException( - RuntimeErrors.INST.spiNotFoundError(IEnvironment.class.getSimpleName())); + private static IEnvironment loadEnvironment(EnvType envType) { + ServiceLoader contextLoader = ServiceLoader.load(IEnvironment.class); + Iterator contextIterable = contextLoader.iterator(); + while (contextIterable.hasNext()) { + IEnvironment environment = contextIterable.next(); + if (environment.getEnvType() == envType) { + LOGGER.info("loaded IEnvironment implementation {}", environment); + return environment; + } } + LOGGER.error("NOT found IEnvironment implementation with type:{}", envType); + throw new GeaflowRuntimeException( + RuntimeErrors.INST.spiNotFoundError(IEnvironment.class.getSimpleName())); + } - private static IEnvironmentArgsParser loadEnvironmentArgsParser() { - ServiceLoader contextLoader = ServiceLoader.load(IEnvironmentArgsParser.class); - Iterator contextIterable = contextLoader.iterator(); - boolean hasNext = contextIterable.hasNext(); - IEnvironmentArgsParser argsParser; - if (hasNext) { - argsParser = contextIterable.next(); - } else { - // Use default argument parser. - argsParser = new EnvironmentArgumentParser(); - } - LOGGER.info("loaded IEnvironmentArgsParser implementation {}", argsParser); - return argsParser; + private static IEnvironmentArgsParser loadEnvironmentArgsParser() { + ServiceLoader contextLoader = + ServiceLoader.load(IEnvironmentArgsParser.class); + Iterator contextIterable = contextLoader.iterator(); + boolean hasNext = contextIterable.hasNext(); + IEnvironmentArgsParser argsParser; + if (hasNext) { + argsParser = contextIterable.next(); + } else { + // Use default argument parser. + argsParser = new EnvironmentArgumentParser(); } + LOGGER.info("loaded IEnvironmentArgsParser implementation {}", argsParser); + return argsParser; + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/IEnvironment.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/IEnvironment.java index 43477b7a5..7f2d7c825 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/IEnvironment.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/IEnvironment.java @@ -20,46 +20,32 @@ package org.apache.geaflow.env; import java.io.Serializable; + import org.apache.geaflow.pipeline.IPipelineResult; public interface IEnvironment extends Serializable { - /** - * Initialize environment. - */ - void init(); - - /** - * Submit pipeline by geaflow client. - */ - IPipelineResult submit(); + /** Initialize environment. */ + void init(); - /** - * Shutdown geaflow client. - */ - void shutdown(); + /** Submit pipeline by geaflow client. */ + IPipelineResult submit(); - /** - * Returns the env type. - */ - EnvType getEnvType(); + /** Shutdown geaflow client. */ + void shutdown(); - enum EnvType { + /** Returns the env type. */ + EnvType getEnvType(); - /** - * Ray cluster. - */ - RAY, + enum EnvType { - /** - * K8s cluster. - */ - K8S, + /** Ray cluster. */ + RAY, - /** - * Local. - */ - LOCAL, - } + /** K8s cluster. */ + K8S, + /** Local. */ + LOCAL, + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/args/EnvironmentArgumentParser.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/args/EnvironmentArgumentParser.java index ad2f4859a..52ef37d6f 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/args/EnvironmentArgumentParser.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/args/EnvironmentArgumentParser.java @@ -27,6 +27,7 @@ import java.util.HashMap; import java.util.Map; import java.util.Set; + import org.apache.commons.lang3.StringEscapeUtils; import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.common.config.ConfigKey; @@ -38,108 +39,106 @@ public class EnvironmentArgumentParser implements IEnvironmentArgsParser { - private static final Logger LOGGER = LoggerFactory.getLogger(EnvironmentArgumentParser.class); - - private static final String JOB_ARGS = "job"; - private static final String SYSTEM_ARGS = "system"; - private static final String CLUSTER_ARGS = "cluster"; - private static final String GEAFLOW_PREFIX = "geaflow"; - - private static final String STATE_CONFIG = "stateConfig"; - private static final String METRIC_CONFIG = "metricConfig"; - - @Override - public Map parse(String[] args) { - if (args == null || args.length == 0) { - return Collections.emptyMap(); - } - - LOGGER.warn("user config: {}", Arrays.asList(args)); - Map mainArgs = JsonUtils.parseJson2map(StringEscapeUtils.unescapeJava(args[0])); - - Map systemArgs = null; - if (mainArgs.containsKey(SYSTEM_ARGS)) { - systemArgs = parseSystemArgs(mainArgs.remove(SYSTEM_ARGS)); - } - - Map userArgs = null; - if (mainArgs.containsKey(JOB_ARGS)) { - userArgs = JsonUtils.parseJson2map(mainArgs.remove(JOB_ARGS)); - } - - if (systemArgs != null && userArgs != null) { - Set systemArgsKeys = systemArgs.keySet(); - Set userArgsKeys = userArgs.keySet(); - if (userArgsKeys.stream().anyMatch(systemArgsKeys::contains)) { - throw new GeaflowRuntimeException( - RuntimeErrors.INST.keyConflictsError(userArgsKeys, systemArgsKeys)); - } - } - - Map clusterArgs = null; - if (mainArgs.containsKey(CLUSTER_ARGS)) { - clusterArgs = JsonUtils.parseJson2map(mainArgs.remove(CLUSTER_ARGS)); - } - - LOGGER.info("build Env with systemArgs: {}, ", JsonUtils.toJsonString(systemArgs)); - LOGGER.info("build Env with userArgs: {}", JsonUtils.toJsonString(userArgs)); - LOGGER.info("build Env with clusterArgs: {}", JsonUtils.toJsonString(clusterArgs)); - - Map envConfig = new HashMap<>(); - if (systemArgs != null) { - envConfig.putAll(systemArgs); - } - if (clusterArgs != null) { - envConfig.putAll(clusterArgs); - } - if (userArgs != null) { - envConfig.putAll(userArgs); - } - LOGGER.info("build envConfig: {}", JsonUtils.toJsonString(envConfig)); - - return envConfig; - } + private static final Logger LOGGER = LoggerFactory.getLogger(EnvironmentArgumentParser.class); - private Map parseSystemArgs(String jsonStr) { - Map systemArgs = JsonUtils.parseJson2map(jsonStr); + private static final String JOB_ARGS = "job"; + private static final String SYSTEM_ARGS = "system"; + private static final String CLUSTER_ARGS = "cluster"; + private static final String GEAFLOW_PREFIX = "geaflow"; - ensureConfigExist(systemArgs, JOB_APP_NAME); - ensureConfigExist(systemArgs, JOB_UNIQUE_ID); + private static final String STATE_CONFIG = "stateConfig"; + private static final String METRIC_CONFIG = "metricConfig"; - Map finalSystemArgs = new HashMap<>(); - fillSystemConfig(finalSystemArgs, systemArgs); + @Override + public Map parse(String[] args) { + if (args == null || args.length == 0) { + return Collections.emptyMap(); + } - if (systemArgs.containsKey(STATE_CONFIG)) { - Map stateConfig = JsonUtils.parseJson2map(systemArgs.remove(STATE_CONFIG)); - fillSystemConfig(finalSystemArgs, stateConfig); + LOGGER.warn("user config: {}", Arrays.asList(args)); + Map mainArgs = JsonUtils.parseJson2map(StringEscapeUtils.unescapeJava(args[0])); + + Map systemArgs = null; + if (mainArgs.containsKey(SYSTEM_ARGS)) { + systemArgs = parseSystemArgs(mainArgs.remove(SYSTEM_ARGS)); + } - } - if (systemArgs.containsKey(METRIC_CONFIG)) { - Map metricConfig = JsonUtils.parseJson2map(systemArgs.remove(METRIC_CONFIG)); - fillSystemConfig(finalSystemArgs, metricConfig); - } + Map userArgs = null; + if (mainArgs.containsKey(JOB_ARGS)) { + userArgs = JsonUtils.parseJson2map(mainArgs.remove(JOB_ARGS)); + } - return finalSystemArgs; + if (systemArgs != null && userArgs != null) { + Set systemArgsKeys = systemArgs.keySet(); + Set userArgsKeys = userArgs.keySet(); + if (userArgsKeys.stream().anyMatch(systemArgsKeys::contains)) { + throw new GeaflowRuntimeException( + RuntimeErrors.INST.keyConflictsError(userArgsKeys, systemArgsKeys)); + } } - private static void ensureConfigExist(Map config, ConfigKey configKey) { - String key = configKey.getKey(); - if (!config.containsKey(key) - || !StringUtils.isNotBlank(config.get(key))) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.configKeyNotFound(key)); - } + Map clusterArgs = null; + if (mainArgs.containsKey(CLUSTER_ARGS)) { + clusterArgs = JsonUtils.parseJson2map(mainArgs.remove(CLUSTER_ARGS)); } - private static void fillSystemConfig(Map finalSystemArgs, Map tmp) { - for (Map.Entry entry : tmp.entrySet()) { - String key = entry.getKey(); - String value = entry.getValue(); - if (key.startsWith(GEAFLOW_PREFIX)) { - finalSystemArgs.put(key, value); - } else { - LOGGER.warn("ignore nonstandard system config: {} {}", key, value); - } - } + LOGGER.info("build Env with systemArgs: {}, ", JsonUtils.toJsonString(systemArgs)); + LOGGER.info("build Env with userArgs: {}", JsonUtils.toJsonString(userArgs)); + LOGGER.info("build Env with clusterArgs: {}", JsonUtils.toJsonString(clusterArgs)); + + Map envConfig = new HashMap<>(); + if (systemArgs != null) { + envConfig.putAll(systemArgs); + } + if (clusterArgs != null) { + envConfig.putAll(clusterArgs); } + if (userArgs != null) { + envConfig.putAll(userArgs); + } + LOGGER.info("build envConfig: {}", JsonUtils.toJsonString(envConfig)); + + return envConfig; + } + private Map parseSystemArgs(String jsonStr) { + Map systemArgs = JsonUtils.parseJson2map(jsonStr); + + ensureConfigExist(systemArgs, JOB_APP_NAME); + ensureConfigExist(systemArgs, JOB_UNIQUE_ID); + + Map finalSystemArgs = new HashMap<>(); + fillSystemConfig(finalSystemArgs, systemArgs); + + if (systemArgs.containsKey(STATE_CONFIG)) { + Map stateConfig = JsonUtils.parseJson2map(systemArgs.remove(STATE_CONFIG)); + fillSystemConfig(finalSystemArgs, stateConfig); + } + if (systemArgs.containsKey(METRIC_CONFIG)) { + Map metricConfig = JsonUtils.parseJson2map(systemArgs.remove(METRIC_CONFIG)); + fillSystemConfig(finalSystemArgs, metricConfig); + } + + return finalSystemArgs; + } + + private static void ensureConfigExist(Map config, ConfigKey configKey) { + String key = configKey.getKey(); + if (!config.containsKey(key) || !StringUtils.isNotBlank(config.get(key))) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.configKeyNotFound(key)); + } + } + + private static void fillSystemConfig( + Map finalSystemArgs, Map tmp) { + for (Map.Entry entry : tmp.entrySet()) { + String key = entry.getKey(); + String value = entry.getValue(); + if (key.startsWith(GEAFLOW_PREFIX)) { + finalSystemArgs.put(key, value); + } else { + LOGGER.warn("ignore nonstandard system config: {} {}", key, value); + } + } + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/args/IEnvironmentArgsParser.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/args/IEnvironmentArgsParser.java index e8124a410..ef4da6f79 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/args/IEnvironmentArgsParser.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/args/IEnvironmentArgsParser.java @@ -23,8 +23,6 @@ public interface IEnvironmentArgsParser { - /** - * Parse input arguments. - */ - Map parse(String[] args); + /** Parse input arguments. */ + Map parse(String[] args); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/ctx/EnvironmentContext.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/ctx/EnvironmentContext.java index 51d22b838..771fdaf8f 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/ctx/EnvironmentContext.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/ctx/EnvironmentContext.java @@ -21,33 +21,33 @@ import java.util.Map; import java.util.UUID; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; public class EnvironmentContext implements IEnvironmentContext { - private final Configuration config; + private final Configuration config; - public EnvironmentContext() { - this.config = new Configuration(); - buildDefaultConfig(); - } + public EnvironmentContext() { + this.config = new Configuration(); + buildDefaultConfig(); + } - @Override - public void withConfig(Map config) { - if (config != null) { - this.config.putAll(config); - } + @Override + public void withConfig(Map config) { + if (config != null) { + this.config.putAll(config); } + } - public Configuration getConfig() { - return config; - } - - private void buildDefaultConfig() { - String jobUid = UUID.randomUUID().toString(); - this.config.put(ExecutionConfigKeys.JOB_UNIQUE_ID, jobUid); - this.config.put(ExecutionConfigKeys.JOB_APP_NAME, "geaflow" + jobUid); - } + public Configuration getConfig() { + return config; + } + private void buildDefaultConfig() { + String jobUid = UUID.randomUUID().toString(); + this.config.put(ExecutionConfigKeys.JOB_UNIQUE_ID, jobUid); + this.config.put(ExecutionConfigKeys.JOB_APP_NAME, "geaflow" + jobUid); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/ctx/IEnvironmentContext.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/ctx/IEnvironmentContext.java index 691cc3e19..9793f8d4a 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/ctx/IEnvironmentContext.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/env/ctx/IEnvironmentContext.java @@ -20,17 +20,14 @@ package org.apache.geaflow.env.ctx; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; public interface IEnvironmentContext { - /** - * Put config into global config. - */ - void withConfig(Map config); + /** Put config into global config. */ + void withConfig(Map config); - /** - * Get the config. - */ - Configuration getConfig(); + /** Get the config. */ + Configuration getConfig(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/IPipelineResult.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/IPipelineResult.java index 45e61b4b1..b87df7ed8 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/IPipelineResult.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/IPipelineResult.java @@ -21,18 +21,17 @@ public interface IPipelineResult { - /** - * PipelineTask execute successful. - * - * @return true if pipeline execute successful, otherwise false. - */ - boolean isSuccess(); + /** + * PipelineTask execute successful. + * + * @return true if pipeline execute successful, otherwise false. + */ + boolean isSuccess(); - /** - * Get the result of PipelineTask. - * Will block until PipelineTask success to return result. - * - * @return the final pipeline result. - */ - R get(); + /** + * Get the result of PipelineTask. Will block until PipelineTask success to return result. + * + * @return the final pipeline result. + */ + R get(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/Pipeline.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/Pipeline.java index 846dfa611..5660a1760 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/Pipeline.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/Pipeline.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; + import org.apache.geaflow.env.Environment; import org.apache.geaflow.pipeline.callback.TaskCallBack; import org.apache.geaflow.pipeline.service.PipelineService; @@ -33,70 +34,67 @@ public class Pipeline implements Serializable { - private transient Environment environment; - private List pipelineTaskList; - private List pipelineTaskCallbacks; - private List pipelineServices; - private Map viewDescMap; - - public Pipeline(Environment environment) { - this.environment = environment; - this.environment.addPipeline(this); - this.viewDescMap = new HashMap<>(); - this.pipelineTaskList = new ArrayList<>(); - this.pipelineTaskCallbacks = new ArrayList<>(); - this.pipelineServices = new ArrayList<>(); - } - - public void init() { - - } - - public Pipeline withView(String viewName, IViewDesc viewDesc) { - this.viewDescMap.put(viewName, viewDesc); - return this; - } - - public TaskCallBack submit(PipelineTask pipelineTask) { - this.pipelineTaskList.add(pipelineTask); - TaskCallBack taskCallBack = new TaskCallBack(); - this.pipelineTaskCallbacks.add(taskCallBack); - return taskCallBack; - } - - public Pipeline start(PipelineService pipelineService) { - this.pipelineServices.add(pipelineService); - return this; - } - - public Pipeline schedule(PipelineTask pipelineTask) { - this.pipelineTaskList.add(pipelineTask); - return this; - } - - public IPipelineResult execute() { - this.environment.init(); - return this.environment.submit(); - } - - - public void shutdown() { - this.environment.shutdown(); - } - - public List getViewDescMap() { - return viewDescMap.values().stream().collect(Collectors.toList()); - } - - public List getPipelineTaskList() { - return pipelineTaskList; - } - - public List getPipelineTaskCallbacks() { - return pipelineTaskCallbacks; - } - - public List getPipelineServices() { - return pipelineServices; - } + private transient Environment environment; + private List pipelineTaskList; + private List pipelineTaskCallbacks; + private List pipelineServices; + private Map viewDescMap; + + public Pipeline(Environment environment) { + this.environment = environment; + this.environment.addPipeline(this); + this.viewDescMap = new HashMap<>(); + this.pipelineTaskList = new ArrayList<>(); + this.pipelineTaskCallbacks = new ArrayList<>(); + this.pipelineServices = new ArrayList<>(); + } + + public void init() {} + + public Pipeline withView(String viewName, IViewDesc viewDesc) { + this.viewDescMap.put(viewName, viewDesc); + return this; + } + + public TaskCallBack submit(PipelineTask pipelineTask) { + this.pipelineTaskList.add(pipelineTask); + TaskCallBack taskCallBack = new TaskCallBack(); + this.pipelineTaskCallbacks.add(taskCallBack); + return taskCallBack; + } + + public Pipeline start(PipelineService pipelineService) { + this.pipelineServices.add(pipelineService); + return this; + } + + public Pipeline schedule(PipelineTask pipelineTask) { + this.pipelineTaskList.add(pipelineTask); + return this; + } + + public IPipelineResult execute() { + this.environment.init(); + return this.environment.submit(); + } + + public void shutdown() { + this.environment.shutdown(); + } + + public List getViewDescMap() { + return viewDescMap.values().stream().collect(Collectors.toList()); + } + + public List getPipelineTaskList() { + return pipelineTaskList; + } + + public List getPipelineTaskCallbacks() { + return pipelineTaskCallbacks; + } + + public List getPipelineServices() { + return pipelineServices; + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/PipelineFactory.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/PipelineFactory.java index 7400d633e..591dd8168 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/PipelineFactory.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/PipelineFactory.java @@ -23,12 +23,11 @@ public class PipelineFactory { - public static Pipeline buildPipeline(Environment environment) { - return new Pipeline(environment); - } - - public static SchedulerPipeline buildSchedulerPipeline(Environment environment) { - return new SchedulerPipeline(environment); - } + public static Pipeline buildPipeline(Environment environment) { + return new Pipeline(environment); + } + public static SchedulerPipeline buildSchedulerPipeline(Environment environment) { + return new SchedulerPipeline(environment); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/SchedulerPipeline.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/SchedulerPipeline.java index 93f1c5242..542a96a8b 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/SchedulerPipeline.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/SchedulerPipeline.java @@ -25,18 +25,18 @@ public class SchedulerPipeline extends Pipeline { - private ISchedule scheduler; + private ISchedule scheduler; - public SchedulerPipeline(Environment environment) { - super(environment); - } + public SchedulerPipeline(Environment environment) { + super(environment); + } - public SchedulerPipeline schedule(PipelineTask pipelineTask) { - submit(pipelineTask); - return this; - } + public SchedulerPipeline schedule(PipelineTask pipelineTask) { + submit(pipelineTask); + return this; + } - public void with(ISchedule scheduler) { - this.scheduler = scheduler; - } + public void with(ISchedule scheduler) { + this.scheduler = scheduler; + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/callback/ICallbackFunction.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/callback/ICallbackFunction.java index 53c8eebc8..37ebd3dc7 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/callback/ICallbackFunction.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/callback/ICallbackFunction.java @@ -21,13 +21,9 @@ public interface ICallbackFunction { - /** - * Pass window id by callback. - */ - void window(long windowId); + /** Pass window id by callback. */ + void window(long windowId); - /** - * Logical definition of all Windows finished. - */ - void terminal(); + /** Logical definition of all Windows finished. */ + void terminal(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/callback/TaskCallBack.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/callback/TaskCallBack.java index 0ba5a9832..02bd3c971 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/callback/TaskCallBack.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/callback/TaskCallBack.java @@ -21,13 +21,13 @@ public class TaskCallBack { - private ICallbackFunction callbackFunction; + private ICallbackFunction callbackFunction; - public void addCallBack(ICallbackFunction callbackFunction) { - this.callbackFunction = callbackFunction; - } + public void addCallBack(ICallbackFunction callbackFunction) { + this.callbackFunction = callbackFunction; + } - public ICallbackFunction getCallbackFunction() { - return callbackFunction; - } + public ICallbackFunction getCallbackFunction() { + return callbackFunction; + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/context/IPipelineContext.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/context/IPipelineContext.java index 438812ca0..9fbd084f2 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/context/IPipelineContext.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/context/IPipelineContext.java @@ -23,28 +23,18 @@ public interface IPipelineContext { - /** - * Generate pipeline id. - */ - int generateId(); + /** Generate pipeline id. */ + int generateId(); - /** - * Add the action that triggers the computation. - */ - void addPAction(PAction action); + /** Add the action that triggers the computation. */ + void addPAction(PAction action); - enum PipelineType { - /** - * Time series graph simulation. - */ - TimeSeriesGraphSimulation, - /** - * Time series graph analytics. - */ - TimeSeriesGraphAnalytics, - /** - * Interactive graph analytics. - */ - InteractiveGraphAnalytics, - } + enum PipelineType { + /** Time series graph simulation. */ + TimeSeriesGraphSimulation, + /** Time series graph analytics. */ + TimeSeriesGraphAnalytics, + /** Interactive graph analytics. */ + InteractiveGraphAnalytics, + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/context/IPipelineExecutorContext.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/context/IPipelineExecutorContext.java index bba032706..bac912d63 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/context/IPipelineExecutorContext.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/context/IPipelineExecutorContext.java @@ -21,28 +21,18 @@ public interface IPipelineExecutorContext { - /** - * Returns the driver id. - */ - String getDriverId(); + /** Returns the driver id. */ + String getDriverId(); - /** - * Returns the pipeline task id. - */ - long getPipelineTaskId(); + /** Returns the pipeline task id. */ + long getPipelineTaskId(); - /** - * Returns the pipeline task name. - */ - String getPipelineTaskName(); + /** Returns the pipeline task name. */ + String getPipelineTaskName(); - /** - * Returns the pipeline context. - */ - IPipelineContext getPipelineContext(); + /** Returns the pipeline context. */ + IPipelineContext getPipelineContext(); - /** - * Returns the driver index. - */ - int getDriverIndex(); + /** Returns the driver index. */ + int getDriverIndex(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/job/IPipelineJobContext.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/job/IPipelineJobContext.java index 51709b827..dc602bbad 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/job/IPipelineJobContext.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/job/IPipelineJobContext.java @@ -20,6 +20,7 @@ package org.apache.geaflow.pipeline.job; import java.io.Serializable; + import org.apache.geaflow.api.function.io.SourceFunction; import org.apache.geaflow.api.graph.PGraphWindow; import org.apache.geaflow.api.pdata.stream.window.PWindowSource; @@ -34,35 +35,24 @@ public interface IPipelineJobContext extends Serializable { - /** - * Returns pipeline id. - */ - long getId(); + /** Returns pipeline id. */ + long getId(); - /** - * Returns pipeline config. - */ - Configuration getConfig(); + /** Returns pipeline config. */ + Configuration getConfig(); - /** - * Build window source with source function and window. - */ - PWindowSource buildSource(SourceFunction sourceFunction, IWindow window); + /** Build window source with source function and window. */ + PWindowSource buildSource(SourceFunction sourceFunction, IWindow window); - /** - * Returns graph view with view name. - */ - PGraphView getGraphView(String viewName); + /** Returns graph view with view name. */ + PGraphView getGraphView(String viewName); - /** - * Create graph view with view desc. - */ - PGraphView createGraphView(IViewDesc viewDesc); + /** Create graph view with view desc. */ + PGraphView createGraphView(IViewDesc viewDesc); - /** - * Build window stream graph. - */ - PGraphWindow buildWindowStreamGraph(PWindowStream> vertexWindowSteam, - PWindowStream> edgeWindowStream, - GraphViewDesc graphViewDesc); + /** Build window stream graph. */ + PGraphWindow buildWindowStreamGraph( + PWindowStream> vertexWindowSteam, + PWindowStream> edgeWindowStream, + GraphViewDesc graphViewDesc); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/scheduler/ISchedule.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/scheduler/ISchedule.java index b210df2f9..02cbcf7f5 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/scheduler/ISchedule.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/scheduler/ISchedule.java @@ -21,6 +21,4 @@ import java.io.Serializable; -public interface ISchedule extends Serializable { - -} +public interface ISchedule extends Serializable {} diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/service/IPipelineServiceContext.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/service/IPipelineServiceContext.java index a0e555ce1..dccb324af 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/service/IPipelineServiceContext.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/service/IPipelineServiceContext.java @@ -23,15 +23,9 @@ public interface IPipelineServiceContext extends IPipelineJobContext { - /** - * Returns request from client. - */ - Object getRequest(); - - /** - * Sets the response. - */ - void response(Object response); - + /** Returns request from client. */ + Object getRequest(); + /** Sets the response. */ + void response(Object response); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/service/IPipelineServiceExecutorContext.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/service/IPipelineServiceExecutorContext.java index 00916df4f..00838d244 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/service/IPipelineServiceExecutorContext.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/service/IPipelineServiceExecutorContext.java @@ -24,13 +24,9 @@ public interface IPipelineServiceExecutorContext extends IPipelineExecutorContext { - /** - * Returns the config of pipeline service executor context. - */ - Configuration getConfiguration(); + /** Returns the config of pipeline service executor context. */ + Configuration getConfiguration(); - /** - * Returns the pipeline service which need be started. - */ - PipelineService getPipelineService(); + /** Returns the pipeline service which need be started. */ + PipelineService getPipelineService(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/service/IServiceServer.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/service/IServiceServer.java index d7762d1e7..d55b594a6 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/service/IServiceServer.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/service/IServiceServer.java @@ -23,24 +23,15 @@ public interface IServiceServer extends Serializable { - /** - * Init service server context. - */ - void init(IPipelineServiceExecutorContext context); + /** Init service server context. */ + void init(IPipelineServiceExecutorContext context); - /** - * Start service server. - */ - void startServer(); + /** Start service server. */ + void startServer(); - /** - * Stop service server. - */ - void stopServer(); - - /** - * Returns service type. - */ - ServiceType getServiceType(); + /** Stop service server. */ + void stopServer(); + /** Returns service type. */ + ServiceType getServiceType(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/service/PipelineService.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/service/PipelineService.java index 26b7704dc..79bdfe5c8 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/service/PipelineService.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/service/PipelineService.java @@ -23,10 +23,6 @@ public interface PipelineService extends Serializable { - /** - * Define the execution logic of service. - */ - void execute(IPipelineServiceContext pipelineServiceContext); - - + /** Define the execution logic of service. */ + void execute(IPipelineServiceContext pipelineServiceContext); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/service/ServiceType.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/service/ServiceType.java index 32d63bfd4..a7e813abd 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/service/ServiceType.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/service/ServiceType.java @@ -24,33 +24,26 @@ public enum ServiceType { - /** - * Rpc analytics service. - */ - analytics_rpc, - - /** - * Http analytics service. - */ - analytics_http, - - /** - * Storage service. - */ - storage; - - public static ServiceType getEnum(String type) { - for (ServiceType serviceType : values()) { - if (serviceType.name().equalsIgnoreCase(type)) { - return serviceType; - } - } - return analytics_rpc; - } + /** Rpc analytics service. */ + analytics_rpc, + + /** Http analytics service. */ + analytics_http, + + /** Storage service. */ + storage; - public static ServiceType getEnum(Configuration config) { - String type = config.getString(FrameworkConfigKeys.SERVICE_SERVER_TYPE); - return getEnum(type); + public static ServiceType getEnum(String type) { + for (ServiceType serviceType : values()) { + if (serviceType.name().equalsIgnoreCase(type)) { + return serviceType; + } } + return analytics_rpc; + } + public static ServiceType getEnum(Configuration config) { + String type = config.getString(FrameworkConfigKeys.SERVICE_SERVER_TYPE); + return getEnum(type); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/task/IPipelineTaskContext.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/task/IPipelineTaskContext.java index c55c64862..bd9770942 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/task/IPipelineTaskContext.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/task/IPipelineTaskContext.java @@ -21,6 +21,4 @@ import org.apache.geaflow.pipeline.job.IPipelineJobContext; -public interface IPipelineTaskContext extends IPipelineJobContext { - -} +public interface IPipelineTaskContext extends IPipelineJobContext {} diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/task/PipelineTask.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/task/PipelineTask.java index b6ff14770..3d25e0b1f 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/task/PipelineTask.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/pipeline/task/PipelineTask.java @@ -23,9 +23,6 @@ public interface PipelineTask extends Serializable { - /** - * Define the execution logic of task. - */ - void execute(IPipelineTaskContext pipelineTaskCxt); - + /** Define the execution logic of task. */ + void execute(IPipelineTaskContext pipelineTaskCxt); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/GraphViewBuilder.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/GraphViewBuilder.java index ebf4dd84d..c9f724ffd 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/GraphViewBuilder.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/GraphViewBuilder.java @@ -19,66 +19,69 @@ package org.apache.geaflow.view; -import com.google.common.base.Preconditions; import java.util.Map; + import org.apache.geaflow.api.partition.graph.vertex.GraphPartitioner; import org.apache.geaflow.model.graph.meta.GraphMetaType; import org.apache.geaflow.utils.math.MathUtil; import org.apache.geaflow.view.IViewDesc.BackendType; import org.apache.geaflow.view.graph.GraphViewDesc; -public class GraphViewBuilder { - - public static final String DEFAULT_GRAPH = "default_graph"; - - private final String viewName; - - private int shardNum; - private BackendType backend; - private GraphPartitioner partitioner; - private GraphMetaType graphMetaType; - private Map props; - private long latestVersion = -1L; - - private GraphViewBuilder(String name) { - this.viewName = name; - } - - public static GraphViewBuilder createGraphView(String name) { - return new GraphViewBuilder(name); - } - - public GraphViewBuilder withShardNum(int shardNum) { - this.shardNum = shardNum; - return this; - } - - public GraphViewBuilder withBackend(BackendType backend) { - this.backend = backend; - return this; - } - - public GraphViewBuilder withSchema(GraphMetaType graphMetaType) { - this.graphMetaType = graphMetaType; - return this; - } - - public GraphViewBuilder withProps(Map props) { - this.props = props; - return this; - } - - public GraphViewBuilder withLatestVersion(long latestVersion) { - this.latestVersion = latestVersion; - return this; - } - - public GraphViewDesc build() { - Preconditions.checkArgument(this.viewName != null, "this name is empty"); - Preconditions.checkArgument(MathUtil.isPowerOf2(this.shardNum), "this shardNum must be power of 2"); - Preconditions.checkArgument(this.backend != null, "this backend is null"); +import com.google.common.base.Preconditions; - return new GraphViewDesc(viewName, shardNum, backend, partitioner, graphMetaType, props, latestVersion); - } +public class GraphViewBuilder { + public static final String DEFAULT_GRAPH = "default_graph"; + + private final String viewName; + + private int shardNum; + private BackendType backend; + private GraphPartitioner partitioner; + private GraphMetaType graphMetaType; + private Map props; + private long latestVersion = -1L; + + private GraphViewBuilder(String name) { + this.viewName = name; + } + + public static GraphViewBuilder createGraphView(String name) { + return new GraphViewBuilder(name); + } + + public GraphViewBuilder withShardNum(int shardNum) { + this.shardNum = shardNum; + return this; + } + + public GraphViewBuilder withBackend(BackendType backend) { + this.backend = backend; + return this; + } + + public GraphViewBuilder withSchema(GraphMetaType graphMetaType) { + this.graphMetaType = graphMetaType; + return this; + } + + public GraphViewBuilder withProps(Map props) { + this.props = props; + return this; + } + + public GraphViewBuilder withLatestVersion(long latestVersion) { + this.latestVersion = latestVersion; + return this; + } + + public GraphViewDesc build() { + Preconditions.checkArgument(this.viewName != null, "this name is empty"); + Preconditions.checkArgument( + MathUtil.isPowerOf2(this.shardNum), "this shardNum must be power of 2"); + Preconditions.checkArgument(this.backend != null, "this backend is null"); + + return new GraphViewDesc( + viewName, shardNum, backend, partitioner, graphMetaType, props, latestVersion); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/IViewDesc.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/IViewDesc.java index 9bd006b63..31ff208f8 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/IViewDesc.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/IViewDesc.java @@ -24,59 +24,47 @@ public interface IViewDesc extends Serializable { - /** - * Returns the view name. - */ - String getName(); + /** Returns the view name. */ + String getName(); - /** - * Returns the shard num of view. - */ - int getShardNum(); + /** Returns the shard num of view. */ + int getShardNum(); - /** - * Returns the data model. - */ - DataModel getDataModel(); + /** Returns the data model. */ + DataModel getDataModel(); - /** - * Returns the backend type. - */ - BackendType getBackend(); + /** Returns the backend type. */ + BackendType getBackend(); - /** - * Returns the view properties. - */ - Map getViewProps(); + /** Returns the view properties. */ + Map getViewProps(); + enum DataModel { + // Table data model. + TABLE, + // Graph data model. + GRAPH, + } - enum DataModel { - // Table data model. - TABLE, - // Graph data model. - GRAPH, - } - - enum BackendType { - // Default view backend, current is pangu. - Native, - // RocksDB backend. - RocksDB, - // Memory backend. - Memory, - // Paimon backend. - Paimon, - // Custom backend. - Custom; + enum BackendType { + // Default view backend, current is pangu. + Native, + // RocksDB backend. + RocksDB, + // Memory backend. + Memory, + // Paimon backend. + Paimon, + // Custom backend. + Custom; - public static BackendType of(String type) { - for (BackendType value : values()) { - if (value.name().equalsIgnoreCase(type)) { - return value; - } - } - throw new IllegalArgumentException("Illegal backend type: " + type); + public static BackendType of(String type) { + for (BackendType value : values()) { + if (value.name().equalsIgnoreCase(type)) { + return value; } + } + throw new IllegalArgumentException("Illegal backend type: " + type); } - + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/PView.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/PView.java index 0924825e0..22c7594ed 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/PView.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/PView.java @@ -21,9 +21,5 @@ import java.io.Serializable; -/** - * Interface for unifying graph and stream view. - */ -public interface PView extends Serializable { - -} +/** Interface for unifying graph and stream view. */ +public interface PView extends Serializable {} diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/graph/GraphSnapshotDesc.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/graph/GraphSnapshotDesc.java index 8b51f12e7..47d64b4c5 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/graph/GraphSnapshotDesc.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/graph/GraphSnapshotDesc.java @@ -20,20 +20,25 @@ package org.apache.geaflow.view.graph; import java.util.Map; + import org.apache.geaflow.api.partition.graph.vertex.GraphPartitioner; import org.apache.geaflow.model.graph.meta.GraphMetaType; public class GraphSnapshotDesc extends GraphViewDesc { - public GraphSnapshotDesc(String viewName, int shardNum, BackendType backend, - GraphPartitioner partitioner, - GraphMetaType graphMetaType, - Map props, long latestVersion) { - super(viewName, shardNum, backend, partitioner, graphMetaType, props, latestVersion); - } + public GraphSnapshotDesc( + String viewName, + int shardNum, + BackendType backend, + GraphPartitioner partitioner, + GraphMetaType graphMetaType, + Map props, + long latestVersion) { + super(viewName, shardNum, backend, partitioner, graphMetaType, props, latestVersion); + } - @Override - public boolean isStatic() { - return false; - } + @Override + public boolean isStatic() { + return false; + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/graph/GraphViewDesc.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/graph/GraphViewDesc.java index 340332c54..cad35548f 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/graph/GraphViewDesc.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/graph/GraphViewDesc.java @@ -22,93 +22,96 @@ import java.util.HashMap; import java.util.Map; import java.util.Objects; + import org.apache.geaflow.api.partition.graph.vertex.GraphPartitioner; import org.apache.geaflow.model.graph.meta.GraphMetaType; import org.apache.geaflow.view.IViewDesc; public class GraphViewDesc implements IViewDesc { - private final String viewName; - private final int shardNum; - private final BackendType backend; - private final GraphPartitioner partitioner; - private final GraphMetaType graphMetaType; - private final Map props; - - private final long currentVersion; - - public GraphViewDesc(String viewName, int shardNum, BackendType backend, - GraphPartitioner partitioner, GraphMetaType graphMetaType, - Map props, long currentVersion) { - this.viewName = Objects.requireNonNull(viewName, "view name is null"); - this.shardNum = shardNum; - this.backend = backend; - this.partitioner = partitioner; - this.graphMetaType = graphMetaType; - this.props = props; - this.currentVersion = currentVersion; - } - - @Override - public String getName() { - return viewName; - } - - @Override - public int getShardNum() { - return shardNum; - } - - @Override - public DataModel getDataModel() { - return DataModel.GRAPH; - } - - @Override - public BackendType getBackend() { - return backend; - } - - @Override - public Map getViewProps() { - return props == null ? new HashMap() : props; + private final String viewName; + private final int shardNum; + private final BackendType backend; + private final GraphPartitioner partitioner; + private final GraphMetaType graphMetaType; + private final Map props; + + private final long currentVersion; + + public GraphViewDesc( + String viewName, + int shardNum, + BackendType backend, + GraphPartitioner partitioner, + GraphMetaType graphMetaType, + Map props, + long currentVersion) { + this.viewName = Objects.requireNonNull(viewName, "view name is null"); + this.shardNum = shardNum; + this.backend = backend; + this.partitioner = partitioner; + this.graphMetaType = graphMetaType; + this.props = props; + this.currentVersion = currentVersion; + } + + @Override + public String getName() { + return viewName; + } + + @Override + public int getShardNum() { + return shardNum; + } + + @Override + public DataModel getDataModel() { + return DataModel.GRAPH; + } + + @Override + public BackendType getBackend() { + return backend; + } + + @Override + public Map getViewProps() { + return props == null ? new HashMap() : props; + } + + public GraphMetaType getGraphMetaType() { + return graphMetaType; + } + + public long getCurrentVersion() { + return currentVersion; + } + + public GraphSnapshotDesc snapshot(long snapshotVersion) { + assert !isStatic() : "Only dynamic graph have the snapshot() method."; + return new GraphSnapshotDesc( + viewName, shardNum, backend, partitioner, graphMetaType, props, snapshotVersion); + } + + /** Whether the graph is static or dynamic for graph state format. */ + public boolean isStatic() { + // static graph version is 0 + return currentVersion == 0L; + } + + public long getCheckpoint(long currentWindowId) { + if (isStatic()) { // static graph checkpoint is 0 + return 0L; + } else { + if (currentVersion <= 0) { // dynamic graph checkpoint start from 1 + return Math.max(currentWindowId, 1L); + } + return currentWindowId + currentVersion; } + } - public GraphMetaType getGraphMetaType() { - return graphMetaType; - } - - public long getCurrentVersion() { - return currentVersion; - } - - public GraphSnapshotDesc snapshot(long snapshotVersion) { - assert !isStatic() : "Only dynamic graph have the snapshot() method."; - return new GraphSnapshotDesc(viewName, shardNum, backend, partitioner, graphMetaType, - props, snapshotVersion); - } - - /** - * Whether the graph is static or dynamic for graph state format. - */ - public boolean isStatic() { - // static graph version is 0 - return currentVersion == 0L; - } - - public long getCheckpoint(long currentWindowId) { - if (isStatic()) { // static graph checkpoint is 0 - return 0L; - } else { - if (currentVersion <= 0) { // dynamic graph checkpoint start from 1 - return Math.max(currentWindowId, 1L); - } - return currentWindowId + currentVersion; - } - } - - public GraphViewDesc asStatic() { - return new GraphViewDesc(viewName, shardNum, backend, partitioner, - graphMetaType, props, 0L); - } + public GraphViewDesc asStatic() { + return new GraphViewDesc(viewName, shardNum, backend, partitioner, graphMetaType, props, 0L); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/graph/PGraphView.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/graph/PGraphView.java index e67fb51ee..a8c6d8a8b 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/graph/PGraphView.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/graph/PGraphView.java @@ -21,6 +21,7 @@ import java.util.Iterator; import java.util.ServiceLoader; + import org.apache.geaflow.api.graph.PGraphWindow; import org.apache.geaflow.api.pdata.stream.window.PWindowStream; import org.apache.geaflow.common.errorcode.RuntimeErrors; @@ -31,44 +32,32 @@ public interface PGraphView extends PView { - /** - * Load and initialize graph view. - */ - static PGraphView loadPGraphView(GraphViewDesc graphViewDesc) { - ServiceLoader contextLoader = ServiceLoader.load(PGraphView.class); - Iterator graphViewIterable = contextLoader.iterator(); - while (graphViewIterable.hasNext()) { - PGraphView pGraphView = graphViewIterable.next(); - pGraphView.init(graphViewDesc); - return pGraphView; - } - throw new GeaflowRuntimeException(RuntimeErrors.INST.spiNotFoundError(PGraphView.class.getSimpleName())); + /** Load and initialize graph view. */ + static PGraphView loadPGraphView(GraphViewDesc graphViewDesc) { + ServiceLoader contextLoader = ServiceLoader.load(PGraphView.class); + Iterator graphViewIterable = contextLoader.iterator(); + while (graphViewIterable.hasNext()) { + PGraphView pGraphView = graphViewIterable.next(); + pGraphView.init(graphViewDesc); + return pGraphView; } + throw new GeaflowRuntimeException( + RuntimeErrors.INST.spiNotFoundError(PGraphView.class.getSimpleName())); + } - /** - * Initialize graph view by desc. - */ - PGraphView init(GraphViewDesc graphViewDesc); - - /** - * Append vertex stream into incremental graphView. - */ - PIncGraphView appendVertex(PWindowStream> vertexStream); + /** Initialize graph view by desc. */ + PGraphView init(GraphViewDesc graphViewDesc); - /** - * Append edge stream into incremental graphView. - */ - PIncGraphView appendEdge(PWindowStream> edgeStream); + /** Append vertex stream into incremental graphView. */ + PIncGraphView appendVertex(PWindowStream> vertexStream); - /** - * Append vertex/edge stream into incremental graphView. - */ - PIncGraphView appendGraph(PWindowStream> vertexStream, - PWindowStream> edgeStream); + /** Append edge stream into incremental graphView. */ + PIncGraphView appendEdge(PWindowStream> edgeStream); - /** - * Build graph snapshot of specified version. - */ - PGraphWindow snapshot(long version); + /** Append vertex/edge stream into incremental graphView. */ + PIncGraphView appendGraph( + PWindowStream> vertexStream, PWindowStream> edgeStream); + /** Build graph snapshot of specified version. */ + PGraphWindow snapshot(long version); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/graph/PIncGraphView.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/graph/PIncGraphView.java index 1dd45a418..4d32d2d5c 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/graph/PIncGraphView.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/graph/PIncGraphView.java @@ -28,31 +28,22 @@ public interface PIncGraphView extends PGraphView { - /** - * Incremental graph traversal. - */ - PGraphTraversal incrementalTraversal(IncVertexCentricTraversal incVertexCentricTraversal); - - /** - * Incremental graph traversal with aggregation. - */ - PGraphTraversal incrementalTraversal(IncVertexCentricAggTraversal incVertexCentricTraversal); - - /** - * Incremental graph compute. - */ - PGraphCompute incrementalCompute(IncVertexCentricCompute incVertexCentricCompute); - - /** - * Incremental graph compute with aggregation. - */ - PGraphCompute incrementalCompute( - IncVertexCentricAggCompute incVertexCentricCompute); - - /** - * Materialize graph data into graph state. - */ - void materialize(); + /** Incremental graph traversal. */ + PGraphTraversal incrementalTraversal( + IncVertexCentricTraversal incVertexCentricTraversal); + + /** Incremental graph traversal with aggregation. */ + PGraphTraversal incrementalTraversal( + IncVertexCentricAggTraversal incVertexCentricTraversal); + + /** Incremental graph compute. */ + PGraphCompute incrementalCompute( + IncVertexCentricCompute incVertexCentricCompute); + + /** Incremental graph compute with aggregation. */ + PGraphCompute incrementalCompute( + IncVertexCentricAggCompute incVertexCentricCompute); + + /** Materialize graph data into graph state. */ + void materialize(); } diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/stream/StreamViewDesc.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/stream/StreamViewDesc.java index 9de831818..b4fc7a22c 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/stream/StreamViewDesc.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/view/stream/StreamViewDesc.java @@ -20,52 +20,53 @@ package org.apache.geaflow.view.stream; import java.util.HashMap; + import org.apache.geaflow.api.partition.IPartition; import org.apache.geaflow.view.IViewDesc; public class StreamViewDesc implements IViewDesc { - private String viewName; - private int shardNum; - private BackendType backend; - private IPartition partitioner; - private HashMap props; + private String viewName; + private int shardNum; + private BackendType backend; + private IPartition partitioner; + private HashMap props; - public StreamViewDesc(String viewName, int shardNum, BackendType backend) { - this.viewName = viewName; - this.shardNum = shardNum; - this.backend = backend; - } + public StreamViewDesc(String viewName, int shardNum, BackendType backend) { + this.viewName = viewName; + this.shardNum = shardNum; + this.backend = backend; + } - public StreamViewDesc(String viewName, int shardNum, BackendType backend, - IPartition partitioner, HashMap props) { - this(viewName, shardNum, backend); - this.partitioner = partitioner; - this.props = props; - } + public StreamViewDesc( + String viewName, int shardNum, BackendType backend, IPartition partitioner, HashMap props) { + this(viewName, shardNum, backend); + this.partitioner = partitioner; + this.props = props; + } - @Override - public String getName() { - return this.viewName; - } + @Override + public String getName() { + return this.viewName; + } - @Override - public int getShardNum() { - return this.shardNum; - } + @Override + public int getShardNum() { + return this.shardNum; + } - @Override - public DataModel getDataModel() { - return DataModel.TABLE; - } + @Override + public DataModel getDataModel() { + return DataModel.TABLE; + } - @Override - public BackendType getBackend() { - return backend; - } + @Override + public BackendType getBackend() { + return backend; + } - @Override - public HashMap getViewProps() { - return this.props; - } + @Override + public HashMap getViewProps() { + return this.props; + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/test/java/org/apache/geaflow/api/window/WindowFactoryTest.java b/geaflow/geaflow-core/geaflow-api/src/test/java/org/apache/geaflow/api/window/WindowFactoryTest.java index 2d6568ff0..46c54ce2d 100644 --- a/geaflow/geaflow-core/geaflow-api/src/test/java/org/apache/geaflow/api/window/WindowFactoryTest.java +++ b/geaflow/geaflow-core/geaflow-api/src/test/java/org/apache/geaflow/api/window/WindowFactoryTest.java @@ -26,15 +26,15 @@ public class WindowFactoryTest { - @Test - public void testCreateSizeTumblingWindow() { - IWindow window = WindowFactory.createSizeTumblingWindow(10); - Assert.assertTrue(window instanceof SizeTumblingWindow); - } + @Test + public void testCreateSizeTumblingWindow() { + IWindow window = WindowFactory.createSizeTumblingWindow(10); + Assert.assertTrue(window instanceof SizeTumblingWindow); + } - @Test - public void testAllWindow() { - IWindow window = WindowFactory.allWindow(); - Assert.assertTrue(window instanceof AllWindow); - } + @Test + public void testAllWindow() { + IWindow window = WindowFactory.allWindow(); + Assert.assertTrue(window instanceof AllWindow); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/test/java/org/apache/geaflow/env/args/EnvironmentArgumentParserTest.java b/geaflow/geaflow-core/geaflow-api/src/test/java/org/apache/geaflow/env/args/EnvironmentArgumentParserTest.java index f3d8214f8..73b88c009 100644 --- a/geaflow/geaflow-core/geaflow-api/src/test/java/org/apache/geaflow/env/args/EnvironmentArgumentParserTest.java +++ b/geaflow/geaflow-core/geaflow-api/src/test/java/org/apache/geaflow/env/args/EnvironmentArgumentParserTest.java @@ -20,78 +20,79 @@ package org.apache.geaflow.env.args; import java.util.Map; + import org.apache.commons.lang3.StringEscapeUtils; import org.testng.Assert; import org.testng.annotations.Test; public class EnvironmentArgumentParserTest { - @Test - public void testConsoleArgs() { - String args = "{\n" + - " \"job\":\n" + - " {\n" + - " \"geaflow.state.write.async.enable\": \"true\",\n" + - " \"geaflow.fo.enable\": false,\n" + - " \"geaflow.fo.max.restarts\": 0,\n" + - " \"geaflow.batch.number.per.checkpoint\": 1\n" + - " },\n" + - " \"system\":\n" + - " {\n" + - " \t\"geaflow.job.runtime.name\": \"geaflow123\",\n" + - " \t\"geaflow.job.unique.id\": \"123\",\n" + - " \t\"geaflow.job.id\": \"123456\",\n" + - " \"geaflow.job.owner\": \"test\",\n" + - " \"stateConfig\":\n" + - " {\n" + - " \"geaflow.file.persistent.root\": \"/geaflow/chk\",\n" + - " \"geaflow.file.persistent.config.json\":\n" + - " {\n" + - " \"fs.defaultFS\": \"dfs://xxxxxx\",\n" + - " \"dfs.usergroupservice.impl\": \"xxxxxx.class\",\n" + - " \"fs.AbstractFileSystem.dfs.impl\": \"xxxxxx\",\n" + - " \"ipc.client.connection.maxidletime\": \"300000\",\n" + - " \"alidfs.default.write.buffer.size\": \"1048576\",\n" + - " \"alidfs.default.read.buffer.size\": \"1048576\",\n" + - " \"alidfs.perf.counter.enable\": \"false\",\n" + - " \"fs.dfs.impl\": \"xxxxxx\"\n" + - " },\n" + - " \"geaflow.store.redis.host\": \"xxxxxx\",\n" + - " \"geaflow.file.persistent.type\": \"DFS\",\n" + - " \"geaflow.file.persistent.user.name\": \"geaflow\",\n" + - " \"geaflow.store.redis.port\": 8016\n" + - " },\n" + - " \n" + - " \"geaflow.cluster.started.callback.url\": \"http://xxxxxx\",\n" + - " \"metricConfig\":\n" + - " {\n" + - " \"geaflow.metric.reporters\": \"influxdb\",\n" + - " \"geaflow.metric.influxdb.url\": \"http://xxxxxx\",\n" + - " \"geaflow.metric.influxdb.bucket\": \"geaflow_metric\",\n" + - " \"geaflow.metric.influxdb.token\": \"xxxxxx\",\n" + - " \"geaflow.metric.influxdb.org\": \"geaflow\"\n" + - " },\n" + - " \"geaflow.gw.endpoint\": \"http://xxxxxx\",\n" + - " \"geaflow.job.cluster.id\": \"geaflow123-1684396791903\"\n" + - " },\n" + - " \"cluster\":\n" + - " {\n" + - " \"geaflow.container.memory.mb\": 20000,\n" + - " \"geaflow.system.state.backend.type\": \"ROCKSDB\",\n" + - " \"geaflow.container.worker.num\": 4,\n" + - " \"geaflow.container.num\": 1,\n" + - " \"geaflow.container.jvm.options\": \"-Xmx15000m,-Xms15000m,-Xmn10000m\",\n" + - " \"geaflow.container.vcores\": 8\n" + - " }\n" + - "}"; - - EnvironmentArgumentParser parser = new EnvironmentArgumentParser(); - Map config = parser.parse(new String[]{args}); - Assert.assertNotNull(config); + @Test + public void testConsoleArgs() { + String args = + "{\n" + + " \"job\":\n" + + " {\n" + + " \"geaflow.state.write.async.enable\": \"true\",\n" + + " \"geaflow.fo.enable\": false,\n" + + " \"geaflow.fo.max.restarts\": 0,\n" + + " \"geaflow.batch.number.per.checkpoint\": 1\n" + + " },\n" + + " \"system\":\n" + + " {\n" + + " \t\"geaflow.job.runtime.name\": \"geaflow123\",\n" + + " \t\"geaflow.job.unique.id\": \"123\",\n" + + " \t\"geaflow.job.id\": \"123456\",\n" + + " \"geaflow.job.owner\": \"test\",\n" + + " \"stateConfig\":\n" + + " {\n" + + " \"geaflow.file.persistent.root\": \"/geaflow/chk\",\n" + + " \"geaflow.file.persistent.config.json\":\n" + + " {\n" + + " \"fs.defaultFS\": \"dfs://xxxxxx\",\n" + + " \"dfs.usergroupservice.impl\": \"xxxxxx.class\",\n" + + " \"fs.AbstractFileSystem.dfs.impl\": \"xxxxxx\",\n" + + " \"ipc.client.connection.maxidletime\": \"300000\",\n" + + " \"alidfs.default.write.buffer.size\": \"1048576\",\n" + + " \"alidfs.default.read.buffer.size\": \"1048576\",\n" + + " \"alidfs.perf.counter.enable\": \"false\",\n" + + " \"fs.dfs.impl\": \"xxxxxx\"\n" + + " },\n" + + " \"geaflow.store.redis.host\": \"xxxxxx\",\n" + + " \"geaflow.file.persistent.type\": \"DFS\",\n" + + " \"geaflow.file.persistent.user.name\": \"geaflow\",\n" + + " \"geaflow.store.redis.port\": 8016\n" + + " },\n" + + " \n" + + " \"geaflow.cluster.started.callback.url\": \"http://xxxxxx\",\n" + + " \"metricConfig\":\n" + + " {\n" + + " \"geaflow.metric.reporters\": \"influxdb\",\n" + + " \"geaflow.metric.influxdb.url\": \"http://xxxxxx\",\n" + + " \"geaflow.metric.influxdb.bucket\": \"geaflow_metric\",\n" + + " \"geaflow.metric.influxdb.token\": \"xxxxxx\",\n" + + " \"geaflow.metric.influxdb.org\": \"geaflow\"\n" + + " },\n" + + " \"geaflow.gw.endpoint\": \"http://xxxxxx\",\n" + + " \"geaflow.job.cluster.id\": \"geaflow123-1684396791903\"\n" + + " },\n" + + " \"cluster\":\n" + + " {\n" + + " \"geaflow.container.memory.mb\": 20000,\n" + + " \"geaflow.system.state.backend.type\": \"ROCKSDB\",\n" + + " \"geaflow.container.worker.num\": 4,\n" + + " \"geaflow.container.num\": 1,\n" + + " \"geaflow.container.jvm.options\": \"-Xmx15000m,-Xms15000m,-Xmn10000m\",\n" + + " \"geaflow.container.vcores\": 8\n" + + " }\n" + + "}"; - Map escapeConfig = - parser.parse(new String[]{StringEscapeUtils.escapeJava(args)}); - Assert.assertNotNull(escapeConfig); - } + EnvironmentArgumentParser parser = new EnvironmentArgumentParser(); + Map config = parser.parse(new String[] {args}); + Assert.assertNotNull(config); + Map escapeConfig = + parser.parse(new String[] {StringEscapeUtils.escapeJava(args)}); + Assert.assertNotNull(escapeConfig); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/test/java/org/apache/geaflow/env/ctx/EnvironmentContextTest.java b/geaflow/geaflow-core/geaflow-api/src/test/java/org/apache/geaflow/env/ctx/EnvironmentContextTest.java index 5262040e5..095e5af59 100644 --- a/geaflow/geaflow-core/geaflow-api/src/test/java/org/apache/geaflow/env/ctx/EnvironmentContextTest.java +++ b/geaflow/geaflow-core/geaflow-api/src/test/java/org/apache/geaflow/env/ctx/EnvironmentContextTest.java @@ -25,11 +25,10 @@ public class EnvironmentContextTest { - @Test - public void test() { - EnvironmentContext context = new EnvironmentContext(); - Assert.assertNotNull(context.getConfig().getString(ExecutionConfigKeys.JOB_APP_NAME)); - Assert.assertNotNull(context.getConfig().getString(ExecutionConfigKeys.JOB_UNIQUE_ID)); - } - + @Test + public void test() { + EnvironmentContext context = new EnvironmentContext(); + Assert.assertNotNull(context.getConfig().getString(ExecutionConfigKeys.JOB_APP_NAME)); + Assert.assertNotNull(context.getConfig().getString(ExecutionConfigKeys.JOB_UNIQUE_ID)); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/test/java/org/apache/geaflow/pipeline/service/ServiceTypeTest.java b/geaflow/geaflow-core/geaflow-api/src/test/java/org/apache/geaflow/pipeline/service/ServiceTypeTest.java index 81178a2f6..b320b3d85 100644 --- a/geaflow/geaflow-core/geaflow-api/src/test/java/org/apache/geaflow/pipeline/service/ServiceTypeTest.java +++ b/geaflow/geaflow-core/geaflow-api/src/test/java/org/apache/geaflow/pipeline/service/ServiceTypeTest.java @@ -24,16 +24,16 @@ public class ServiceTypeTest { - @Test - public void testServiceType() { - String serviceTypeName = ServiceType.analytics_rpc.name(); - Assert.assertTrue(ServiceType.getEnum(serviceTypeName) == ServiceType.analytics_rpc); - serviceTypeName = ServiceType.storage.name(); - Assert.assertTrue(ServiceType.getEnum(serviceTypeName) == ServiceType.storage); - serviceTypeName = null; - Assert.assertTrue(ServiceType.getEnum(serviceTypeName) == ServiceType.analytics_rpc); + @Test + public void testServiceType() { + String serviceTypeName = ServiceType.analytics_rpc.name(); + Assert.assertTrue(ServiceType.getEnum(serviceTypeName) == ServiceType.analytics_rpc); + serviceTypeName = ServiceType.storage.name(); + Assert.assertTrue(ServiceType.getEnum(serviceTypeName) == ServiceType.storage); + serviceTypeName = null; + Assert.assertTrue(ServiceType.getEnum(serviceTypeName) == ServiceType.analytics_rpc); - serviceTypeName = ServiceType.analytics_http.name(); - Assert.assertTrue(ServiceType.getEnum(serviceTypeName) == ServiceType.analytics_http); - } + serviceTypeName = ServiceType.analytics_http.name(); + Assert.assertTrue(ServiceType.getEnum(serviceTypeName) == ServiceType.analytics_http); + } } diff --git a/geaflow/geaflow-core/geaflow-api/src/test/java/org/apache/geaflow/view/stream/StreamViewDescTest.java b/geaflow/geaflow-core/geaflow-api/src/test/java/org/apache/geaflow/view/stream/StreamViewDescTest.java index b4fad8d61..d3ac708b7 100644 --- a/geaflow/geaflow-core/geaflow-api/src/test/java/org/apache/geaflow/view/stream/StreamViewDescTest.java +++ b/geaflow/geaflow-core/geaflow-api/src/test/java/org/apache/geaflow/view/stream/StreamViewDescTest.java @@ -20,6 +20,7 @@ package org.apache.geaflow.view.stream; import java.util.HashMap; + import org.apache.geaflow.api.partition.kv.KeyByPartition; import org.apache.geaflow.view.IViewDesc; import org.testng.Assert; @@ -27,21 +28,23 @@ public class StreamViewDescTest { - @Test - public void testStreamViewDesc() { - StreamViewDesc desc = new StreamViewDesc("view", 1, IViewDesc.BackendType.Memory); - Assert.assertTrue(desc.getBackend().equals(IViewDesc.BackendType.Memory)); - Assert.assertTrue(desc.getName().equals("view")); - Assert.assertTrue(desc.getShardNum() == 1); - Assert.assertNull(desc.getViewProps()); - } + @Test + public void testStreamViewDesc() { + StreamViewDesc desc = new StreamViewDesc("view", 1, IViewDesc.BackendType.Memory); + Assert.assertTrue(desc.getBackend().equals(IViewDesc.BackendType.Memory)); + Assert.assertTrue(desc.getName().equals("view")); + Assert.assertTrue(desc.getShardNum() == 1); + Assert.assertNull(desc.getViewProps()); + } - @Test - public void testStreamViewDescWithPartition() { - StreamViewDesc desc = new StreamViewDesc("view", 1, IViewDesc.BackendType.Memory, new KeyByPartition(2), new HashMap()); - Assert.assertTrue(desc.getBackend().equals(IViewDesc.BackendType.Memory)); - Assert.assertTrue(desc.getName().equals("view")); - Assert.assertTrue(desc.getShardNum() == 1); - Assert.assertTrue(desc.getViewProps().size() == 0); - } + @Test + public void testStreamViewDescWithPartition() { + StreamViewDesc desc = + new StreamViewDesc( + "view", 1, IViewDesc.BackendType.Memory, new KeyByPartition(2), new HashMap()); + Assert.assertTrue(desc.getBackend().equals(IViewDesc.BackendType.Memory)); + Assert.assertTrue(desc.getName().equals("view")); + Assert.assertTrue(desc.getShardNum() == 1); + Assert.assertTrue(desc.getViewProps().size() == 0); + } } diff --git a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/collector/AbstractCollector.java b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/collector/AbstractCollector.java index 88c065981..340693fe4 100644 --- a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/collector/AbstractCollector.java +++ b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/collector/AbstractCollector.java @@ -24,29 +24,27 @@ public abstract class AbstractCollector { - protected int id; - protected RuntimeContext runtimeContext; - protected Meter outputMeter; + protected int id; + protected RuntimeContext runtimeContext; + protected Meter outputMeter; - public AbstractCollector(int id) { - this.id = id; - } + public AbstractCollector(int id) { + this.id = id; + } - public void setUp(RuntimeContext runtimeContext) { - this.runtimeContext = runtimeContext; - } + public void setUp(RuntimeContext runtimeContext) { + this.runtimeContext = runtimeContext; + } - public void setOutputMetric(Meter outputMeter) { - this.outputMeter = outputMeter; - } + public void setOutputMetric(Meter outputMeter) { + this.outputMeter = outputMeter; + } - public int getId() { - return id; - } + public int getId() { + return id; + } - public void finish() { - } + public void finish() {} - public void close() { - } + public void close() {} } diff --git a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/collector/CollectionCollector.java b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/collector/CollectionCollector.java index 29b92c4e3..86d6e72c6 100644 --- a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/collector/CollectionCollector.java +++ b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/collector/CollectionCollector.java @@ -20,44 +20,42 @@ package org.apache.geaflow.collector; import java.util.List; + import org.apache.geaflow.shuffle.desc.OutputType; public class CollectionCollector extends AbstractCollector implements ICollector { - private List> collectors; - - public CollectionCollector(int id, List> collectors) { - super(id); - this.collectors = collectors; - } + private List> collectors; + public CollectionCollector(int id, List> collectors) { + super(id); + this.collectors = collectors; + } - @Override - public void partition(T value) { - for (ICollector collector : collectors) { - collector.partition(value); - } + @Override + public void partition(T value) { + for (ICollector collector : collectors) { + collector.partition(value); } + } - @Override - public String getTag() { - return ""; - } + @Override + public String getTag() { + return ""; + } - @Override - public OutputType getType() { - return OutputType.FORWARD; - } - - @Override - public void broadcast(T value) { + @Override + public OutputType getType() { + return OutputType.FORWARD; + } - } + @Override + public void broadcast(T value) {} - @Override - public void partition(KEY key, T value) { - for (ICollector collector : collectors) { - collector.partition(key, value); - } + @Override + public void partition(KEY key, T value) { + for (ICollector collector : collectors) { + collector.partition(key, value); } + } } diff --git a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/collector/ICollector.java b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/collector/ICollector.java index f6ff60f6a..7ed744117 100644 --- a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/collector/ICollector.java +++ b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/collector/ICollector.java @@ -25,46 +25,31 @@ public interface ICollector extends Collector { - /** - * Returns op id. - */ - int getId(); + /** Returns op id. */ + int getId(); - /** - * Returns tag. - */ - String getTag(); + /** Returns tag. */ + String getTag(); - /** - * Returns type. - */ - OutputType getType(); + /** Returns type. */ + OutputType getType(); - /** - * Initialize collector. - * - * @param runtimeContext The runtime context. - */ - void setUp(RuntimeContext runtimeContext); + /** + * Initialize collector. + * + * @param runtimeContext The runtime context. + */ + void setUp(RuntimeContext runtimeContext); - /** - * Broadcast value to downstream. - */ - void broadcast(T value); + /** Broadcast value to downstream. */ + void broadcast(T value); - /** - * Partition value by key. - */ - void partition(KEY key, T value); + /** Partition value by key. */ + void partition(KEY key, T value); - /** - * Finish flush. - */ - void finish(); - - /** - * Close pipeline writer. - */ - void close(); + /** Finish flush. */ + void finish(); + /** Close pipeline writer. */ + void close(); } diff --git a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/collector/IResultCollector.java b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/collector/IResultCollector.java index a4ca21fa9..0d992c4da 100644 --- a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/collector/IResultCollector.java +++ b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/collector/IResultCollector.java @@ -21,8 +21,6 @@ public interface IResultCollector { - /** - * Collect result. - */ - R collectResult(); + /** Collect result. */ + R collectResult(); } diff --git a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/io/AbstractMessageBuffer.java b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/io/AbstractMessageBuffer.java index 767831607..1a16d2239 100644 --- a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/io/AbstractMessageBuffer.java +++ b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/io/AbstractMessageBuffer.java @@ -21,48 +21,48 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.metric.EventMetrics; public abstract class AbstractMessageBuffer implements IMessageBuffer { - private static final int DEFAULT_TIMEOUT_MS = 100; - - private final LinkedBlockingQueue queue; - protected volatile EventMetrics eventMetrics; + private static final int DEFAULT_TIMEOUT_MS = 100; - public AbstractMessageBuffer(int capacity) { - this.queue = new LinkedBlockingQueue<>(capacity); - } + private final LinkedBlockingQueue queue; + protected volatile EventMetrics eventMetrics; - @Override - public void offer(R record) { - while (true) { - try { - if (this.queue.offer(record, DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS)) { - break; - } - } catch (InterruptedException e) { - throw new GeaflowRuntimeException(e); - } - } - } + public AbstractMessageBuffer(int capacity) { + this.queue = new LinkedBlockingQueue<>(capacity); + } - @Override - public R poll(long timeout, TimeUnit unit) { - try { - return this.queue.poll(timeout, unit); - } catch (InterruptedException e) { - throw new GeaflowRuntimeException(e); + @Override + public void offer(R record) { + while (true) { + try { + if (this.queue.offer(record, DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS)) { + break; } + } catch (InterruptedException e) { + throw new GeaflowRuntimeException(e); + } } + } - public void setEventMetrics(EventMetrics eventMetrics) { - this.eventMetrics = eventMetrics; + @Override + public R poll(long timeout, TimeUnit unit) { + try { + return this.queue.poll(timeout, unit); + } catch (InterruptedException e) { + throw new GeaflowRuntimeException(e); } + } - public EventMetrics getEventMetrics() { - return this.eventMetrics; - } + public void setEventMetrics(EventMetrics eventMetrics) { + this.eventMetrics = eventMetrics; + } + public EventMetrics getEventMetrics() { + return this.eventMetrics; + } } diff --git a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/io/IMessageBuffer.java b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/io/IMessageBuffer.java index 1efebcc35..7c1175216 100644 --- a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/io/IMessageBuffer.java +++ b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/io/IMessageBuffer.java @@ -23,20 +23,19 @@ public interface IMessageBuffer { - /** - * Push a message to this pipe. - * - * @param message message - */ - void offer(M message); - - /** - * Pull a record from this pipe with timeout. - * - * @param timeout timeout number - * @param unit timeout unit - * @return message - */ - M poll(long timeout, TimeUnit unit); + /** + * Push a message to this pipe. + * + * @param message message + */ + void offer(M message); + /** + * Pull a record from this pipe with timeout. + * + * @param timeout timeout number + * @param unit timeout unit + * @return message + */ + M poll(long timeout, TimeUnit unit); } diff --git a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/partitioner/IPartitioner.java b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/partitioner/IPartitioner.java index 78a97a006..776329baa 100644 --- a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/partitioner/IPartitioner.java +++ b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/partitioner/IPartitioner.java @@ -20,57 +20,41 @@ package org.apache.geaflow.partitioner; import java.io.Serializable; + import org.apache.geaflow.api.partition.IPartition; public interface IPartitioner extends Serializable { - /** - * Returns op id. - */ - int getOpId(); - - /** - * Returns the partition. - */ - IPartition getPartition(); + /** Returns op id. */ + int getOpId(); - /** - * Returns the partition type. - */ - PartitionType getPartitionType(); + /** Returns the partition. */ + IPartition getPartition(); - enum PartitionType { + /** Returns the partition type. */ + PartitionType getPartitionType(); - /** - * Random partition. - */ - forward(true), - /** - * Broadcast partition. - */ - broadcast(true), - /** - * Key partition. - */ - key(true), - /** - * Custom partition. - */ - custom(false), - /** - * Iterator partition. - */ - iterator(false); + enum PartitionType { - boolean enablePushUp; + /** Random partition. */ + forward(true), + /** Broadcast partition. */ + broadcast(true), + /** Key partition. */ + key(true), + /** Custom partition. */ + custom(false), + /** Iterator partition. */ + iterator(false); - PartitionType(boolean enablePushUp) { - this.enablePushUp = enablePushUp; - } + boolean enablePushUp; - public boolean isEnablePushUp() { - return enablePushUp; - } + PartitionType(boolean enablePushUp) { + this.enablePushUp = enablePushUp; } + public boolean isEnablePushUp() { + return enablePushUp; + } + } } diff --git a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/partitioner/impl/AbstractPartitioner.java b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/partitioner/impl/AbstractPartitioner.java index b45446a86..7591569ab 100644 --- a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/partitioner/impl/AbstractPartitioner.java +++ b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/partitioner/impl/AbstractPartitioner.java @@ -23,16 +23,14 @@ public abstract class AbstractPartitioner implements IPartitioner { - private final int opId; - - public AbstractPartitioner(int opId) { - this.opId = opId; - } - - @Override - public int getOpId() { - return this.opId; - } + private final int opId; + public AbstractPartitioner(int opId) { + this.opId = opId; + } + @Override + public int getOpId() { + return this.opId; + } } diff --git a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/partitioner/impl/KeyPartitioner.java b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/partitioner/impl/KeyPartitioner.java index be5c9b29b..4a609a35f 100644 --- a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/partitioner/impl/KeyPartitioner.java +++ b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/partitioner/impl/KeyPartitioner.java @@ -24,24 +24,23 @@ public class KeyPartitioner extends AbstractPartitioner { - private int maxParallelism; + private int maxParallelism; - public KeyPartitioner(int opId) { - super(opId); - } + public KeyPartitioner(int opId) { + super(opId); + } - public void init(int maxParallelism) { - this.maxParallelism = maxParallelism; - } + public void init(int maxParallelism) { + this.maxParallelism = maxParallelism; + } - @Override - public IPartition getPartition() { - return new KeyByPartition(maxParallelism); - } - - @Override - public PartitionType getPartitionType() { - return PartitionType.key; - } + @Override + public IPartition getPartition() { + return new KeyByPartition(maxParallelism); + } + @Override + public PartitionType getPartitionType() { + return PartitionType.key; + } } diff --git a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/selector/ISelector.java b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/selector/ISelector.java index 3ca102a01..d2ebaf51c 100644 --- a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/selector/ISelector.java +++ b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/selector/ISelector.java @@ -23,9 +23,6 @@ public interface ISelector extends Serializable { - /** - * Compute the channel list for partitionKey. - */ - int[] selectChannels(T partitionKey); - + /** Compute the channel list for partitionKey. */ + int[] selectChannels(T partitionKey); } diff --git a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/selector/impl/ChannelSelector.java b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/selector/impl/ChannelSelector.java index 3929775e9..bd16e5f63 100644 --- a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/selector/impl/ChannelSelector.java +++ b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/selector/impl/ChannelSelector.java @@ -25,17 +25,16 @@ public class ChannelSelector implements ISelector { - private int numChannels; - private IPartition partition; + private int numChannels; + private IPartition partition; - public ChannelSelector(int numChannels, IPartitioner partitioner) { - this.numChannels = numChannels; - this.partition = partitioner.getPartition(); - } - - @Override - public int[] selectChannels(KEY partitionKey) { - return partition.partition(partitionKey, numChannels); - } + public ChannelSelector(int numChannels, IPartitioner partitioner) { + this.numChannels = numChannels; + this.partition = partitioner.getPartition(); + } + @Override + public int[] selectChannels(KEY partitionKey) { + return partition.partition(partitionKey, numChannels); + } } diff --git a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/ForwardOutputDesc.java b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/ForwardOutputDesc.java index 860692518..fd4e03450 100644 --- a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/ForwardOutputDesc.java +++ b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/ForwardOutputDesc.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.util.List; + import org.apache.geaflow.common.encoder.IEncoder; import org.apache.geaflow.common.shuffle.DataExchangeMode; import org.apache.geaflow.partitioner.IPartitioner; @@ -29,77 +30,76 @@ public class ForwardOutputDesc implements IOutputDesc, Serializable { - // Execution vertex id. - private final int vertexId; - // The current output edge id. - private final int edgeId; - // Partition number. - private final int numPartitions; - // Name of the output edge. - private final String edgeName; - // Data exchange mode. - private final DataExchangeMode dataExchangeMode; - // Describe the target task ids which the current output will send data to. - private final List targetTaskIndices; - // The partitioner of the output data. - private final IPartitioner partitioner; - // Data encoder, for serialization and deserialization. - private final IEncoder encoder; - - public ForwardOutputDesc( - int vertexId, - int edgeId, - int numPartitions, - String edgeName, - DataExchangeMode dataExchangeMode, - List targetTaskIndices, - IPartitioner partitioner, - IEncoder encoder) { - this.vertexId = vertexId; - this.edgeId = edgeId; - this.numPartitions = numPartitions; - this.edgeName = edgeName; - this.dataExchangeMode = dataExchangeMode; - this.targetTaskIndices = targetTaskIndices; - this.partitioner = partitioner; - this.encoder = encoder; - } + // Execution vertex id. + private final int vertexId; + // The current output edge id. + private final int edgeId; + // Partition number. + private final int numPartitions; + // Name of the output edge. + private final String edgeName; + // Data exchange mode. + private final DataExchangeMode dataExchangeMode; + // Describe the target task ids which the current output will send data to. + private final List targetTaskIndices; + // The partitioner of the output data. + private final IPartitioner partitioner; + // Data encoder, for serialization and deserialization. + private final IEncoder encoder; - public int getVertexId() { - return this.vertexId; - } + public ForwardOutputDesc( + int vertexId, + int edgeId, + int numPartitions, + String edgeName, + DataExchangeMode dataExchangeMode, + List targetTaskIndices, + IPartitioner partitioner, + IEncoder encoder) { + this.vertexId = vertexId; + this.edgeId = edgeId; + this.numPartitions = numPartitions; + this.edgeName = edgeName; + this.dataExchangeMode = dataExchangeMode; + this.targetTaskIndices = targetTaskIndices; + this.partitioner = partitioner; + this.encoder = encoder; + } - public int getEdgeId() { - return this.edgeId; - } + public int getVertexId() { + return this.vertexId; + } - public int getNumPartitions() { - return this.numPartitions; - } + public int getEdgeId() { + return this.edgeId; + } - public String getEdgeName() { - return this.edgeName; - } + public int getNumPartitions() { + return this.numPartitions; + } - public DataExchangeMode getDataExchangeMode() { - return this.dataExchangeMode; - } + public String getEdgeName() { + return this.edgeName; + } - public List getTargetTaskIndices() { - return this.targetTaskIndices; - } + public DataExchangeMode getDataExchangeMode() { + return this.dataExchangeMode; + } - public IPartitioner getPartitioner() { - return this.partitioner; - } + public List getTargetTaskIndices() { + return this.targetTaskIndices; + } - public IEncoder getEncoder() { - return this.encoder; - } + public IPartitioner getPartitioner() { + return this.partitioner; + } - @Override - public OutputType getType() { - return OutputType.FORWARD; - } + public IEncoder getEncoder() { + return this.encoder; + } + @Override + public OutputType getType() { + return OutputType.FORWARD; + } } diff --git a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/InputDescriptor.java b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/InputDescriptor.java index 5636dc2f2..a31149f0d 100644 --- a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/InputDescriptor.java +++ b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/InputDescriptor.java @@ -21,18 +21,18 @@ import java.io.Serializable; import java.util.Map; + import org.apache.geaflow.shuffle.desc.IInputDesc; public class InputDescriptor implements Serializable { - private final Map> inputDescMap; - - public InputDescriptor(Map> inputDescMap) { - this.inputDescMap = inputDescMap; - } + private final Map> inputDescMap; - public Map> getInputDescMap() { - return inputDescMap; - } + public InputDescriptor(Map> inputDescMap) { + this.inputDescMap = inputDescMap; + } + public Map> getInputDescMap() { + return inputDescMap; + } } diff --git a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/IoDescriptor.java b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/IoDescriptor.java index dc7dcc43f..a3ea8190a 100644 --- a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/IoDescriptor.java +++ b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/IoDescriptor.java @@ -23,22 +23,21 @@ public class IoDescriptor implements Serializable { - // A InputDescriptor that represents the upstream input descriptor info. - private final InputDescriptor inputDescriptor; - // A OutputDescriptor that represents the downstream output descriptor info. - private final OutputDescriptor outputDescriptor; - - public IoDescriptor(InputDescriptor inputDescriptor, OutputDescriptor outputDescriptor) { - this.inputDescriptor = inputDescriptor; - this.outputDescriptor = outputDescriptor; - } - - public InputDescriptor getInputDescriptor() { - return inputDescriptor; - } - - public OutputDescriptor getOutputDescriptor() { - return outputDescriptor; - } - + // A InputDescriptor that represents the upstream input descriptor info. + private final InputDescriptor inputDescriptor; + // A OutputDescriptor that represents the downstream output descriptor info. + private final OutputDescriptor outputDescriptor; + + public IoDescriptor(InputDescriptor inputDescriptor, OutputDescriptor outputDescriptor) { + this.inputDescriptor = inputDescriptor; + this.outputDescriptor = outputDescriptor; + } + + public InputDescriptor getInputDescriptor() { + return inputDescriptor; + } + + public OutputDescriptor getOutputDescriptor() { + return outputDescriptor; + } } diff --git a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/OutputDescriptor.java b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/OutputDescriptor.java index 8407fa552..bb3d1d1d4 100644 --- a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/OutputDescriptor.java +++ b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/OutputDescriptor.java @@ -21,19 +21,19 @@ import java.io.Serializable; import java.util.List; + import org.apache.geaflow.shuffle.desc.IOutputDesc; public class OutputDescriptor implements Serializable { - // The output info list is the info of downstream output dependencies. - private final List outputDescList; - - public OutputDescriptor(List outputDescList) { - this.outputDescList = outputDescList; - } + // The output info list is the info of downstream output dependencies. + private final List outputDescList; - public List getOutputDescList() { - return outputDescList; - } + public OutputDescriptor(List outputDescList) { + this.outputDescList = outputDescList; + } + public List getOutputDescList() { + return outputDescList; + } } diff --git a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/RawDataInputDesc.java b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/RawDataInputDesc.java index fef506304..ba73ca756 100644 --- a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/RawDataInputDesc.java +++ b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/RawDataInputDesc.java @@ -20,38 +20,39 @@ package org.apache.geaflow.shuffle; import java.util.List; + import org.apache.geaflow.shuffle.desc.IInputDesc; import org.apache.geaflow.shuffle.desc.InputType; public class RawDataInputDesc implements IInputDesc { - private final int edgeId; - private final String edgeName; - private final List rawData; - - public RawDataInputDesc(int edgeId, String edgeName, List rawData) { - this.edgeId = edgeId; - this.edgeName = edgeName; - this.rawData = rawData; - } - - @Override - public int getEdgeId() { - return edgeId; - } - - @Override - public String getName() { - return edgeName; - } - - @Override - public List getInput() { - return rawData; - } - - @Override - public InputType getInputType() { - return InputType.DATA; - } + private final int edgeId; + private final String edgeName; + private final List rawData; + + public RawDataInputDesc(int edgeId, String edgeName, List rawData) { + this.edgeId = edgeId; + this.edgeName = edgeName; + this.rawData = rawData; + } + + @Override + public int getEdgeId() { + return edgeId; + } + + @Override + public String getName() { + return edgeName; + } + + @Override + public List getInput() { + return rawData; + } + + @Override + public InputType getInputType() { + return InputType.DATA; + } } diff --git a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/ResponseOutputDesc.java b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/ResponseOutputDesc.java index 2a5a7528a..98b85dff7 100644 --- a/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/ResponseOutputDesc.java +++ b/geaflow/geaflow-core/geaflow-core-common/src/main/java/org/apache/geaflow/shuffle/ResponseOutputDesc.java @@ -24,33 +24,32 @@ public class ResponseOutputDesc implements IOutputDesc { - private final int opId; - private final int edgeId; - private final String edgeName; - - public ResponseOutputDesc(int opId, int edgeId, String edgeName) { - this.opId = opId; - this.edgeId = edgeId; - this.edgeName = edgeName; - } - - public int getOpId() { - return opId; - } - - @Override - public int getEdgeId() { - return edgeId; - } - - @Override - public String getEdgeName() { - return edgeName; - } - - @Override - public OutputType getType() { - return OutputType.RESPONSE; - } - + private final int opId; + private final int edgeId; + private final String edgeName; + + public ResponseOutputDesc(int opId, int edgeId, String edgeName) { + this.opId = opId; + this.edgeId = edgeId; + this.edgeName = edgeName; + } + + public int getOpId() { + return opId; + } + + @Override + public int getEdgeId() { + return edgeId; + } + + @Override + public String getEdgeName() { + return edgeName; + } + + @Override + public OutputType getType() { + return OutputType.RESPONSE; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/AbstractClusterClient.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/AbstractClusterClient.java index 46a2f769d..a4c64db0a 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/AbstractClusterClient.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/AbstractClusterClient.java @@ -27,17 +27,17 @@ public abstract class AbstractClusterClient implements IClusterClient { - private static final String MASTER_ID = "_MASTER"; + private static final String MASTER_ID = "_MASTER"; - protected String masterId; - protected Configuration config; - protected ClusterStartedCallback callback; + protected String masterId; + protected Configuration config; + protected ClusterStartedCallback callback; - public void init(IEnvironmentContext environmentContext) { - EnvironmentContext context = (EnvironmentContext) environmentContext; - this.config = context.getConfig(); - this.masterId = MASTER_ID; - this.config.setMasterId(masterId); - this.callback = ClusterCallbackFactory.createClusterStartCallback(config); - } + public void init(IEnvironmentContext environmentContext) { + EnvironmentContext context = (EnvironmentContext) environmentContext; + this.config = context.getConfig(); + this.masterId = MASTER_ID; + this.config.setMasterId(masterId); + this.callback = ClusterCallbackFactory.createClusterStartCallback(config); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/AbstractEnvironment.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/AbstractEnvironment.java index b6689d05e..2335d3189 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/AbstractEnvironment.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/AbstractEnvironment.java @@ -24,28 +24,27 @@ public abstract class AbstractEnvironment extends Environment { - protected GeaFlowClient geaflowClient; - - @Override - public void init() { - IClusterClient clusterClient = getClusterClient(); - this.geaflowClient = new GeaFlowClient(); - this.geaflowClient.init(context, clusterClient); - this.geaflowClient.startCluster(); + protected GeaFlowClient geaflowClient; + + @Override + public void init() { + IClusterClient clusterClient = getClusterClient(); + this.geaflowClient = new GeaFlowClient(); + this.geaflowClient.init(context, clusterClient); + this.geaflowClient.startCluster(); + } + + @Override + public IPipelineResult submit() { + return this.geaflowClient.submit(this.pipeline); + } + + @Override + public void shutdown() { + if (this.geaflowClient != null) { + this.geaflowClient.shutdown(); } + } - @Override - public IPipelineResult submit() { - return this.geaflowClient.submit(this.pipeline); - } - - @Override - public void shutdown() { - if (this.geaflowClient != null) { - this.geaflowClient.shutdown(); - } - } - - protected abstract IClusterClient getClusterClient(); - + protected abstract IClusterClient getClusterClient(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/AbstractPipelineClient.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/AbstractPipelineClient.java index 05928d07c..fb68a645e 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/AbstractPipelineClient.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/AbstractPipelineClient.java @@ -20,19 +20,19 @@ package org.apache.geaflow.cluster.client; import java.util.Map; + import org.apache.geaflow.cluster.rpc.ConnectAddress; import org.apache.geaflow.cluster.rpc.RpcClient; import org.apache.geaflow.common.config.Configuration; public abstract class AbstractPipelineClient implements IPipelineClient { - protected RpcClient rpcClient; - protected Map driverAddresses; - - @Override - public void init(Map driverAddresses, Configuration config) { - this.rpcClient = RpcClient.init(config); - this.driverAddresses = driverAddresses; - } + protected RpcClient rpcClient; + protected Map driverAddresses; + @Override + public void init(Map driverAddresses, Configuration config) { + this.rpcClient = RpcClient.init(config); + this.driverAddresses = driverAddresses; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/AsyncPipelineClient.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/AsyncPipelineClient.java index 9ac6519c3..bc29eda26 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/AsyncPipelineClient.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/AsyncPipelineClient.java @@ -29,6 +29,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.cluster.rpc.ConnectAddress; import org.apache.geaflow.cluster.rpc.RpcClient; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -41,70 +42,79 @@ public class AsyncPipelineClient extends AbstractPipelineClient { - private static final Logger LOGGER = LoggerFactory.getLogger(AsyncPipelineClient.class); + private static final Logger LOGGER = LoggerFactory.getLogger(AsyncPipelineClient.class); - private static final String PREFIX_DRIVER_EXECUTE_PIPELINE = "driver-submit-pipeline-"; + private static final String PREFIX_DRIVER_EXECUTE_PIPELINE = "driver-submit-pipeline-"; - private ExecutorService executorService; + private ExecutorService executorService; - @Override - public IPipelineResult submit(Pipeline pipeline) { - int driverNum = driverAddresses.size(); - executorService = new ThreadPoolExecutor(driverNum, driverNum, 0, - TimeUnit.SECONDS, new LinkedBlockingQueue<>(driverNum), + @Override + public IPipelineResult submit(Pipeline pipeline) { + int driverNum = driverAddresses.size(); + executorService = + new ThreadPoolExecutor( + driverNum, + driverNum, + 0, + TimeUnit.SECONDS, + new LinkedBlockingQueue<>(driverNum), ThreadUtil.namedThreadFactory(true, PREFIX_DRIVER_EXECUTE_PIPELINE)); - List> list = new ArrayList<>(driverNum); - int pipelineIndex = 0; - for (Map.Entry entry : driverAddresses.entrySet()) { - list.add(executorService.submit(new ExecutePipelineTask(driverNum, pipelineIndex, - pipeline, entry.getKey()))); - pipelineIndex++; - } - - try { - return list.get(0).get(); - } catch (InterruptedException | ExecutionException e) { - LOGGER.error("submit pipeline failed", e); - throw new GeaflowRuntimeException(e); - } + List> list = new ArrayList<>(driverNum); + int pipelineIndex = 0; + for (Map.Entry entry : driverAddresses.entrySet()) { + list.add( + executorService.submit( + new ExecutePipelineTask(driverNum, pipelineIndex, pipeline, entry.getKey()))); + pipelineIndex++; } - @Override - public boolean isSync() { - return false; + try { + return list.get(0).get(); + } catch (InterruptedException | ExecutionException e) { + LOGGER.error("submit pipeline failed", e); + throw new GeaflowRuntimeException(e); } + } - @Override - public void close() { - if (executorService != null) { - ExecutorUtil.shutdown(executorService); - } - } + @Override + public boolean isSync() { + return false; + } - private class ExecutePipelineTask implements Callable { + @Override + public void close() { + if (executorService != null) { + ExecutorUtil.shutdown(executorService); + } + } - private final String driverId; - private final Pipeline pipeline; - private final int total; - private final int index; + private class ExecutePipelineTask implements Callable { - private ExecutePipelineTask(int total, int index, Pipeline pipeline, String driverId) { - this.driverId = driverId; - this.pipeline = pipeline; - this.total = total; - this.index = index; - } + private final String driverId; + private final Pipeline pipeline; + private final int total; + private final int index; - @Override - public IPipelineResult call() throws Exception { - int num = this.index + 1; - LOGGER.info("execute pipeline [{}/{}]", num, this.total); - long start = System.currentTimeMillis(); - IPipelineResult future = RpcClient.getInstance().executePipeline(driverId, pipeline); - LOGGER.info("execute pipeline [{}/{}] costs {}ms, driver: {}", num, this.total, - System.currentTimeMillis() - start, driverId); - return future; - } + private ExecutePipelineTask(int total, int index, Pipeline pipeline, String driverId) { + this.driverId = driverId; + this.pipeline = pipeline; + this.total = total; + this.index = index; + } + @Override + public IPipelineResult call() throws Exception { + int num = this.index + 1; + LOGGER.info("execute pipeline [{}/{}]", num, this.total); + long start = System.currentTimeMillis(); + IPipelineResult future = RpcClient.getInstance().executePipeline(driverId, pipeline); + LOGGER.info( + "execute pipeline [{}/{}] costs {}ms, driver: {}", + num, + this.total, + System.currentTimeMillis() - start, + driverId); + return future; } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/GeaFlowClient.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/GeaFlowClient.java index dcef65eab..7d0fcb99c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/GeaFlowClient.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/GeaFlowClient.java @@ -20,34 +20,33 @@ package org.apache.geaflow.cluster.client; import java.io.Serializable; + import org.apache.geaflow.env.ctx.IEnvironmentContext; import org.apache.geaflow.pipeline.IPipelineResult; import org.apache.geaflow.pipeline.Pipeline; public class GeaFlowClient implements Serializable { - private IClusterClient clusterClient; - private IPipelineClient pipelineClient; + private IClusterClient clusterClient; + private IPipelineClient pipelineClient; - public void init(IEnvironmentContext environmentContext, IClusterClient clusterClient) { - this.clusterClient = clusterClient; - this.clusterClient.init(environmentContext); - } + public void init(IEnvironmentContext environmentContext, IClusterClient clusterClient) { + this.clusterClient = clusterClient; + this.clusterClient.init(environmentContext); + } - public void startCluster() { - this.pipelineClient = this.clusterClient.startCluster(); - } + public void startCluster() { + this.pipelineClient = this.clusterClient.startCluster(); + } - public IPipelineResult submit(Pipeline pipeline) { - return this.pipelineClient.submit(pipeline); - } + public IPipelineResult submit(Pipeline pipeline) { + return this.pipelineClient.submit(pipeline); + } - public void shutdown() { - if (this.pipelineClient != null) { - this.pipelineClient.close(); - } - this.clusterClient.shutdown(); + public void shutdown() { + if (this.pipelineClient != null) { + this.pipelineClient.close(); } - + this.clusterClient.shutdown(); + } } - diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/IClusterClient.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/IClusterClient.java index b9dd59d7f..0c1a84ffc 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/IClusterClient.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/IClusterClient.java @@ -20,23 +20,17 @@ package org.apache.geaflow.cluster.client; import java.io.Serializable; + import org.apache.geaflow.env.ctx.IEnvironmentContext; public interface IClusterClient extends Serializable { - /** - * Initialize cluster client. - */ - void init(IEnvironmentContext environmentContext); - - /** - * Start cluster. - */ - IPipelineClient startCluster(); + /** Initialize cluster client. */ + void init(IEnvironmentContext environmentContext); - /** - * Shutdown cluster. - */ - void shutdown(); + /** Start cluster. */ + IPipelineClient startCluster(); + /** Shutdown cluster. */ + void shutdown(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/IPipelineClient.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/IPipelineClient.java index c56401a6a..a95e7a50d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/IPipelineClient.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/IPipelineClient.java @@ -20,6 +20,7 @@ package org.apache.geaflow.cluster.client; import java.util.Map; + import org.apache.geaflow.cluster.rpc.ConnectAddress; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.pipeline.IPipelineResult; @@ -27,26 +28,19 @@ public interface IPipelineClient { - /** - * Init pipeline client. - * - * @param driverAddresses Driver Address map. - */ - void init(Map driverAddresses, Configuration config); - - /** - * Submit pipeline to execute. - */ - IPipelineResult submit(Pipeline pipeline); + /** + * Init pipeline client. + * + * @param driverAddresses Driver Address map. + */ + void init(Map driverAddresses, Configuration config); - /** - * Returns whether is sync client. - */ - boolean isSync(); + /** Submit pipeline to execute. */ + IPipelineResult submit(Pipeline pipeline); - /** - * Close client. - */ - void close(); + /** Returns whether is sync client. */ + boolean isSync(); + /** Close client. */ + void close(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/PipelineClientFactory.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/PipelineClientFactory.java index 84891631c..ad7817994 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/PipelineClientFactory.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/PipelineClientFactory.java @@ -22,6 +22,7 @@ import java.util.Iterator; import java.util.Map; import java.util.ServiceLoader; + import org.apache.geaflow.cluster.client.utils.PipelineUtil; import org.apache.geaflow.cluster.rpc.ConnectAddress; import org.apache.geaflow.common.config.Configuration; @@ -32,21 +33,22 @@ public class PipelineClientFactory { - private static final Logger LOGGER = LoggerFactory.getLogger(PipelineClientFactory.class); + private static final Logger LOGGER = LoggerFactory.getLogger(PipelineClientFactory.class); - public static IPipelineClient createPipelineClient(Map driverAddresses, Configuration config) { - ServiceLoader clientLoader = ServiceLoader.load(IPipelineClient.class); - Iterator clientIterable = clientLoader.iterator(); - boolean isSync = !PipelineUtil.isAsync(config); - while (clientIterable.hasNext()) { - IPipelineClient client = clientIterable.next(); - if (client.isSync() == isSync) { - client.init(driverAddresses, config); - return client; - } - } - LOGGER.error("NOT found IPipelineClient implementation"); - throw new GeaflowRuntimeException( - RuntimeErrors.INST.spiNotFoundError(IPipelineClient.class.getSimpleName())); + public static IPipelineClient createPipelineClient( + Map driverAddresses, Configuration config) { + ServiceLoader clientLoader = ServiceLoader.load(IPipelineClient.class); + Iterator clientIterable = clientLoader.iterator(); + boolean isSync = !PipelineUtil.isAsync(config); + while (clientIterable.hasNext()) { + IPipelineClient client = clientIterable.next(); + if (client.isSync() == isSync) { + client.init(driverAddresses, config); + return client; + } } + LOGGER.error("NOT found IPipelineClient implementation"); + throw new GeaflowRuntimeException( + RuntimeErrors.INST.spiNotFoundError(IPipelineClient.class.getSimpleName())); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/PipelineResult.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/PipelineResult.java index ff53ee4c3..362a4416c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/PipelineResult.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/PipelineResult.java @@ -20,6 +20,7 @@ package org.apache.geaflow.cluster.client; import java.util.concurrent.CompletableFuture; + import org.apache.geaflow.common.encoder.RpcMessageEncoder; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.pipeline.IPipelineResult; @@ -27,37 +28,37 @@ public class PipelineResult implements IPipelineResult { - private final CompletableFuture resultFuture; - private Boolean success; - private R result; + private final CompletableFuture resultFuture; + private Boolean success; + private R result; - public PipelineResult(CompletableFuture resultFuture) { - this.resultFuture = resultFuture; - this.success = null; - } + public PipelineResult(CompletableFuture resultFuture) { + this.resultFuture = resultFuture; + this.success = null; + } - @Override - public boolean isSuccess() { - if (success == null) { - try { - PipelineRes pipelineRes = resultFuture.get(); - result = RpcMessageEncoder.decode(pipelineRes.getPayload()); - } catch (Exception e) { - throw new GeaflowRuntimeException("get pipeline result error", e); - } - success = true; - return true; - } else { - return success; - } + @Override + public boolean isSuccess() { + if (success == null) { + try { + PipelineRes pipelineRes = resultFuture.get(); + result = RpcMessageEncoder.decode(pipelineRes.getPayload()); + } catch (Exception e) { + throw new GeaflowRuntimeException("get pipeline result error", e); + } + success = true; + return true; + } else { + return success; } + } - @Override - public R get() { - if (isSuccess()) { - return result; - } else { - throw new GeaflowRuntimeException("failed to get result"); - } + @Override + public R get() { + if (isSuccess()) { + return result; + } else { + throw new GeaflowRuntimeException("failed to get result"); } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/SyncPipelineClient.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/SyncPipelineClient.java index eff902c62..58129ac95 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/SyncPipelineClient.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/SyncPipelineClient.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; + import org.apache.geaflow.cluster.rpc.ConnectAddress; import org.apache.geaflow.pipeline.IPipelineResult; import org.apache.geaflow.pipeline.Pipeline; @@ -30,25 +31,23 @@ public class SyncPipelineClient extends AbstractPipelineClient { - private static final Logger LOGGER = LoggerFactory.getLogger(SyncPipelineClient.class); - - @Override - public IPipelineResult submit(Pipeline pipeline) { - List results = new ArrayList<>(); - for (Map.Entry entry : driverAddresses.entrySet()) { - LOGGER.info("submit pipeline to driver {}: {}", entry.getKey(), entry.getValue()); - results.add(rpcClient.executePipeline(entry.getKey(), pipeline)); - } - return results.get(0); - } + private static final Logger LOGGER = LoggerFactory.getLogger(SyncPipelineClient.class); - @Override - public boolean isSync() { - return true; + @Override + public IPipelineResult submit(Pipeline pipeline) { + List results = new ArrayList<>(); + for (Map.Entry entry : driverAddresses.entrySet()) { + LOGGER.info("submit pipeline to driver {}: {}", entry.getKey(), entry.getValue()); + results.add(rpcClient.executePipeline(entry.getKey(), pipeline)); } + return results.get(0); + } - @Override - public void close() { + @Override + public boolean isSync() { + return true; + } - } + @Override + public void close() {} } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/ClusterCallbackFactory.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/ClusterCallbackFactory.java index de4dc812a..4d62519c6 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/ClusterCallbackFactory.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/ClusterCallbackFactory.java @@ -26,12 +26,11 @@ public class ClusterCallbackFactory { - public static ClusterStartedCallback createClusterStartCallback(Configuration configuration) { - String callbackUrl = configuration.getString(CLUSTER_STARTED_CALLBACK_URL); - if (StringUtils.isNotBlank(callbackUrl)) { - return new RestClusterStartedCallback(configuration, callbackUrl); - } - return new SimpleClusterStartedCallback(); + public static ClusterStartedCallback createClusterStartCallback(Configuration configuration) { + String callbackUrl = configuration.getString(CLUSTER_STARTED_CALLBACK_URL); + if (StringUtils.isNotBlank(callbackUrl)) { + return new RestClusterStartedCallback(configuration, callbackUrl); } - + return new SimpleClusterStartedCallback(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/ClusterStartedCallback.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/ClusterStartedCallback.java index 4f3aa7e36..4ed455763 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/ClusterStartedCallback.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/ClusterStartedCallback.java @@ -22,70 +22,73 @@ import java.io.Serializable; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.cluster.clustermanager.ClusterInfo; import org.apache.geaflow.cluster.rpc.ConnectAddress; import org.apache.geaflow.common.utils.ProcessUtil; public interface ClusterStartedCallback extends Serializable { - /** - * The callback for cluster start succeed. - */ - void onSuccess(ClusterMeta clusterInfo); - - /** - * The callback for cluster start failed. - */ - void onFailure(Throwable e); + /** The callback for cluster start succeed. */ + void onSuccess(ClusterMeta clusterInfo); - class ClusterMeta implements Serializable { + /** The callback for cluster start failed. */ + void onFailure(Throwable e); - private String masterAddress; - private String clientAddress; - private Map driverAddresses; + class ClusterMeta implements Serializable { - public ClusterMeta() { - } + private String masterAddress; + private String clientAddress; + private Map driverAddresses; - public ClusterMeta(ClusterInfo clusterInfo) { - this(clusterInfo.getDriverAddresses(), clusterInfo.getMasterAddress().toString()); - } + public ClusterMeta() {} - public ClusterMeta(Map driverAddresses, String masterAddress) { - this.driverAddresses = new HashMap<>(driverAddresses); - this.masterAddress = masterAddress; - this.clientAddress = ProcessUtil.getHostAndIp(); - } + public ClusterMeta(ClusterInfo clusterInfo) { + this(clusterInfo.getDriverAddresses(), clusterInfo.getMasterAddress().toString()); + } - public String getMasterAddress() { - return masterAddress; - } + public ClusterMeta(Map driverAddresses, String masterAddress) { + this.driverAddresses = new HashMap<>(driverAddresses); + this.masterAddress = masterAddress; + this.clientAddress = ProcessUtil.getHostAndIp(); + } - public void setMasterAddress(String masterAddress) { - this.masterAddress = masterAddress; - } + public String getMasterAddress() { + return masterAddress; + } - public Map getDriverAddresses() { - return driverAddresses; - } + public void setMasterAddress(String masterAddress) { + this.masterAddress = masterAddress; + } - public void setDriverAddresses(Map driverAddresses) { - this.driverAddresses = driverAddresses; - } + public Map getDriverAddresses() { + return driverAddresses; + } - public String getClientAddress() { - return clientAddress; - } + public void setDriverAddresses(Map driverAddresses) { + this.driverAddresses = driverAddresses; + } - public void setClientAddress(String clientAddress) { - this.clientAddress = clientAddress; - } + public String getClientAddress() { + return clientAddress; + } - @Override - public String toString() { - return "ClusterMeta{" + "clientAddress='" + clientAddress + '\'' + ", masterAddress='" - + masterAddress + '\'' + ", driverAddresses=" + driverAddresses + '}'; - } + public void setClientAddress(String clientAddress) { + this.clientAddress = clientAddress; } + @Override + public String toString() { + return "ClusterMeta{" + + "clientAddress='" + + clientAddress + + '\'' + + ", masterAddress='" + + masterAddress + + '\'' + + ", driverAddresses=" + + driverAddresses + + '}'; + } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/HttpRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/HttpRequest.java index ff968fca3..fbe5e006c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/HttpRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/HttpRequest.java @@ -22,32 +22,32 @@ import java.io.Serializable; class HttpRequest implements Serializable { - private static final long serialVersionUID = 0L; - private boolean success; - private String message; - private Object data; - - public boolean isSuccess() { - return success; - } - - public void setSuccess(boolean success) { - this.success = success; - } - - public String getMessage() { - return message; - } - - public void setMessage(String message) { - this.message = message; - } - - public Object getData() { - return data; - } - - public void setData(Object data) { - this.data = data; - } -} \ No newline at end of file + private static final long serialVersionUID = 0L; + private boolean success; + private String message; + private Object data; + + public boolean isSuccess() { + return success; + } + + public void setSuccess(boolean success) { + this.success = success; + } + + public String getMessage() { + return message; + } + + public void setMessage(String message) { + this.message = message; + } + + public Object getData() { + return data; + } + + public void setData(Object data) { + this.data = data; + } +} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/JobOperatorCallback.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/JobOperatorCallback.java index aed6ac49c..b33105302 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/JobOperatorCallback.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/JobOperatorCallback.java @@ -23,34 +23,32 @@ public interface JobOperatorCallback { - /** - * The callback for job finish succeed. - */ - void onFinish(); - - class JobOperatorMeta implements Serializable { - private boolean success; - private String action; - - public boolean isSuccess() { - return success; - } - - public void setSuccess(boolean success) { - this.success = success; - } - - public String getAction() { - return action; - } - - public void setAction(String action) { - this.action = action; - } - - @Override - public String toString() { - return "JobOperatorMeta{" + "action='" + action + ", success=" + success + '}'; - } + /** The callback for job finish succeed. */ + void onFinish(); + + class JobOperatorMeta implements Serializable { + private boolean success; + private String action; + + public boolean isSuccess() { + return success; + } + + public void setSuccess(boolean success) { + this.success = success; + } + + public String getAction() { + return action; + } + + public void setAction(String action) { + this.action = action; + } + + @Override + public String toString() { + return "JobOperatorMeta{" + "action='" + action + ", success=" + success + '}'; } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/JobOperatorCallbackFactory.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/JobOperatorCallbackFactory.java index 77e4c9505..f7b01709d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/JobOperatorCallbackFactory.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/JobOperatorCallbackFactory.java @@ -26,12 +26,11 @@ public class JobOperatorCallbackFactory { - public static JobOperatorCallback createJobOperatorCallback(Configuration configuration) { - String callbackUrl = configuration.getString(GEAFLOW_GW_ENDPOINT, ""); - if (StringUtils.isNotBlank(callbackUrl)) { - return new RestJobOperatorCallback(configuration, callbackUrl); - } - return new SimpleJobOperatorCallback(); + public static JobOperatorCallback createJobOperatorCallback(Configuration configuration) { + String callbackUrl = configuration.getString(GEAFLOW_GW_ENDPOINT, ""); + if (StringUtils.isNotBlank(callbackUrl)) { + return new RestJobOperatorCallback(configuration, callbackUrl); } - + return new SimpleJobOperatorCallback(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/RestClusterStartedCallback.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/RestClusterStartedCallback.java index 942515679..c03ab2f0e 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/RestClusterStartedCallback.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/RestClusterStartedCallback.java @@ -19,41 +19,43 @@ package org.apache.geaflow.cluster.client.callback; -import com.google.gson.Gson; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.DSLConfigKeys; import org.apache.geaflow.utils.HttpUtil; -public class RestClusterStartedCallback implements ClusterStartedCallback { - - private static final String GEAFLOW_TOKEN_KEY = "geaflow-token"; - - private final String callbackUrl; - - private final Map headers; - - public RestClusterStartedCallback(Configuration config, String url) { - this.callbackUrl = url; - this.headers = new HashMap<>(); - this.headers.put(GEAFLOW_TOKEN_KEY, config.getString(DSLConfigKeys.GEAFLOW_DSL_CATALOG_TOKEN_KEY, "")); - } - - @Override - public void onSuccess(ClusterMeta clusterInfo) { - HttpRequest request = new HttpRequest(); - request.setSuccess(true); - request.setData(clusterInfo); - HttpUtil.post(callbackUrl, new Gson().toJson(request), headers); - } +import com.google.gson.Gson; - @Override - public void onFailure(Throwable e) { - HttpRequest request = new HttpRequest(); - request.setSuccess(false); - request.setMessage(e.getMessage()); - HttpUtil.post(callbackUrl, new Gson().toJson(request), headers); - } +public class RestClusterStartedCallback implements ClusterStartedCallback { + private static final String GEAFLOW_TOKEN_KEY = "geaflow-token"; + + private final String callbackUrl; + + private final Map headers; + + public RestClusterStartedCallback(Configuration config, String url) { + this.callbackUrl = url; + this.headers = new HashMap<>(); + this.headers.put( + GEAFLOW_TOKEN_KEY, config.getString(DSLConfigKeys.GEAFLOW_DSL_CATALOG_TOKEN_KEY, "")); + } + + @Override + public void onSuccess(ClusterMeta clusterInfo) { + HttpRequest request = new HttpRequest(); + request.setSuccess(true); + request.setData(clusterInfo); + HttpUtil.post(callbackUrl, new Gson().toJson(request), headers); + } + + @Override + public void onFailure(Throwable e) { + HttpRequest request = new HttpRequest(); + request.setSuccess(false); + request.setMessage(e.getMessage()); + HttpUtil.post(callbackUrl, new Gson().toJson(request), headers); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/RestJobOperatorCallback.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/RestJobOperatorCallback.java index b366c2e77..449137b87 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/RestJobOperatorCallback.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/RestJobOperatorCallback.java @@ -21,56 +21,58 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.JOB_UNIQUE_ID; -import com.google.gson.Gson; import java.net.URI; import java.net.URISyntaxException; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.DSLConfigKeys; import org.apache.geaflow.utils.HttpUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.gson.Gson; + public class RestJobOperatorCallback implements JobOperatorCallback { - private static final Logger LOGGER = LoggerFactory.getLogger(RestJobOperatorCallback.class); - private static final String GEAFLOW_TOKEN_KEY = "geaflow-token"; - private static final String FINISH_JOB_PATH = "/api/tasks/%s/operations"; - private static final String FINISH_ACTION_KEY = "finish"; + private static final Logger LOGGER = LoggerFactory.getLogger(RestJobOperatorCallback.class); + private static final String GEAFLOW_TOKEN_KEY = "geaflow-token"; + private static final String FINISH_JOB_PATH = "/api/tasks/%s/operations"; + private static final String FINISH_ACTION_KEY = "finish"; - private final String callbackUrl; - private final Map headers; - private final long uniqueId; + private final String callbackUrl; + private final Map headers; + private final long uniqueId; - public RestJobOperatorCallback(Configuration config, String url) { - this.uniqueId = config.getLong(JOB_UNIQUE_ID); - this.callbackUrl = url; - this.headers = new HashMap<>(); - this.headers.put(GEAFLOW_TOKEN_KEY, config.getString(DSLConfigKeys.GEAFLOW_DSL_CATALOG_TOKEN_KEY, "")); - } + public RestJobOperatorCallback(Configuration config, String url) { + this.uniqueId = config.getLong(JOB_UNIQUE_ID); + this.callbackUrl = url; + this.headers = new HashMap<>(); + this.headers.put( + GEAFLOW_TOKEN_KEY, config.getString(DSLConfigKeys.GEAFLOW_DSL_CATALOG_TOKEN_KEY, "")); + } - @Override - public void onFinish() { - JobOperatorCallback.JobOperatorMeta jobOperatorMeta = new JobOperatorCallback.JobOperatorMeta(); - jobOperatorMeta.setSuccess(true); - jobOperatorMeta.setAction(FINISH_ACTION_KEY); + @Override + public void onFinish() { + JobOperatorCallback.JobOperatorMeta jobOperatorMeta = new JobOperatorCallback.JobOperatorMeta(); + jobOperatorMeta.setSuccess(true); + jobOperatorMeta.setAction(FINISH_ACTION_KEY); - String fullUrl = getFullUrl(jobOperatorMeta); - if (fullUrl != null) { - HttpUtil.post(fullUrl, new Gson().toJson(jobOperatorMeta), headers); - } + String fullUrl = getFullUrl(jobOperatorMeta); + if (fullUrl != null) { + HttpUtil.post(fullUrl, new Gson().toJson(jobOperatorMeta), headers); } + } - private String getFullUrl(JobOperatorMeta jobOperatorMeta) { - String fullUrl = null; - try { - URI uri = new URI(this.callbackUrl); - String path = String.format(FINISH_JOB_PATH, uniqueId); - fullUrl = uri.resolve(path).toString(); - } catch (URISyntaxException e) { - LOGGER.error("post {} failed: {}, msg: {}", fullUrl, jobOperatorMeta, e.getMessage()); - } - return fullUrl; + private String getFullUrl(JobOperatorMeta jobOperatorMeta) { + String fullUrl = null; + try { + URI uri = new URI(this.callbackUrl); + String path = String.format(FINISH_JOB_PATH, uniqueId); + fullUrl = uri.resolve(path).toString(); + } catch (URISyntaxException e) { + LOGGER.error("post {} failed: {}, msg: {}", fullUrl, jobOperatorMeta, e.getMessage()); } - + return fullUrl; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/SimpleClusterStartedCallback.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/SimpleClusterStartedCallback.java index df51b2444..970f9006a 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/SimpleClusterStartedCallback.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/SimpleClusterStartedCallback.java @@ -23,16 +23,15 @@ import org.slf4j.LoggerFactory; public class SimpleClusterStartedCallback implements ClusterStartedCallback { - private static final Logger LOGGER = LoggerFactory.getLogger(SimpleClusterStartedCallback.class); + private static final Logger LOGGER = LoggerFactory.getLogger(SimpleClusterStartedCallback.class); - @Override - public void onSuccess(ClusterMeta clusterInfo) { - LOGGER.info("start cluster successfully: {}", clusterInfo); - } - - @Override - public void onFailure(Throwable e) { - LOGGER.error("start cluster failed", e); - } + @Override + public void onSuccess(ClusterMeta clusterInfo) { + LOGGER.info("start cluster successfully: {}", clusterInfo); + } + @Override + public void onFailure(Throwable e) { + LOGGER.error("start cluster failed", e); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/SimpleJobOperatorCallback.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/SimpleJobOperatorCallback.java index 5d7f61844..17de2bb12 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/SimpleJobOperatorCallback.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/callback/SimpleJobOperatorCallback.java @@ -23,11 +23,10 @@ import org.slf4j.LoggerFactory; public class SimpleJobOperatorCallback implements JobOperatorCallback { - private static final Logger LOGGER = LoggerFactory.getLogger(SimpleJobOperatorCallback.class); - - @Override - public void onFinish() { - LOGGER.info("finish job successfully"); - } + private static final Logger LOGGER = LoggerFactory.getLogger(SimpleJobOperatorCallback.class); + @Override + public void onFinish() { + LOGGER.info("finish job successfully"); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/utils/PipelineUtil.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/utils/PipelineUtil.java index 46ba0c167..aac73523e 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/utils/PipelineUtil.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/client/utils/PipelineUtil.java @@ -24,8 +24,8 @@ public class PipelineUtil { - public static boolean isAsync(Configuration configuration) { - // TODO Currently check whether is async mode using service share, maybe refactor in later. - return configuration.getBoolean(FrameworkConfigKeys.SERVICE_SHARE_ENABLE); - } + public static boolean isAsync(Configuration configuration) { + // TODO Currently check whether is async mode using service share, maybe refactor in later. + return configuration.getBoolean(FrameworkConfigKeys.SERVICE_SHARE_ENABLE); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/AbstractClusterManager.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/AbstractClusterManager.java index ddaaccf63..0746c5962 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/AbstractClusterManager.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/AbstractClusterManager.java @@ -19,7 +19,6 @@ package org.apache.geaflow.cluster.clustermanager; -import com.google.common.base.Preconditions; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -28,6 +27,7 @@ import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; + import org.apache.geaflow.cluster.config.ClusterConfig; import org.apache.geaflow.cluster.constants.ClusterConstants; import org.apache.geaflow.cluster.container.ContainerInfo; @@ -49,246 +49,240 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; + public abstract class AbstractClusterManager implements IClusterManager { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractClusterManager.class); - - protected String masterId; - protected ClusterConfig clusterConfig; - protected ClusterContext clusterContext; - protected Configuration config; - protected ClusterId clusterInfo; - protected Map containerInfos; - protected Map driverInfos; - protected Map> driverFutureMap; - protected IFailoverStrategy foStrategy; - protected long driverTimeoutSec; - private AtomicInteger idGenerator; - - @Override - public void init(ClusterContext clusterContext) { - this.config = clusterContext.getConfig(); - this.clusterConfig = clusterContext.getClusterConfig(); - this.driverTimeoutSec = clusterConfig.getDriverRegisterTimeoutSec(); - this.containerInfos = new ConcurrentHashMap<>(); - this.driverInfos = new ConcurrentHashMap<>(); - this.clusterContext = clusterContext; - this.idGenerator = new AtomicInteger(clusterContext.getMaxComponentId()); - this.masterId = clusterContext.getConfig().getMasterId(); - Preconditions.checkNotNull(masterId, "masterId is not set"); - this.foStrategy = buildFailoverStrategy(); - this.driverFutureMap = new ConcurrentHashMap<>(); - if (clusterContext.isRecover()) { - for (Integer driverId : clusterContext.getDriverIds().keySet()) { - driverFutureMap.put(driverId, new CompletableFuture<>()); - } - } - RpcClient.init(clusterContext.getConfig()); + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractClusterManager.class); + + protected String masterId; + protected ClusterConfig clusterConfig; + protected ClusterContext clusterContext; + protected Configuration config; + protected ClusterId clusterInfo; + protected Map containerInfos; + protected Map driverInfos; + protected Map> driverFutureMap; + protected IFailoverStrategy foStrategy; + protected long driverTimeoutSec; + private AtomicInteger idGenerator; + + @Override + public void init(ClusterContext clusterContext) { + this.config = clusterContext.getConfig(); + this.clusterConfig = clusterContext.getClusterConfig(); + this.driverTimeoutSec = clusterConfig.getDriverRegisterTimeoutSec(); + this.containerInfos = new ConcurrentHashMap<>(); + this.driverInfos = new ConcurrentHashMap<>(); + this.clusterContext = clusterContext; + this.idGenerator = new AtomicInteger(clusterContext.getMaxComponentId()); + this.masterId = clusterContext.getConfig().getMasterId(); + Preconditions.checkNotNull(masterId, "masterId is not set"); + this.foStrategy = buildFailoverStrategy(); + this.driverFutureMap = new ConcurrentHashMap<>(); + if (clusterContext.isRecover()) { + for (Integer driverId : clusterContext.getDriverIds().keySet()) { + driverFutureMap.put(driverId, new CompletableFuture<>()); + } } - - @Override - public void allocateWorkers(int workerNum) { - int workersPerContainer = clusterConfig.getContainerWorkerNum(); - int containerNum = (workerNum + workersPerContainer - 1) / workersPerContainer; - LOGGER.info("allocate {} containers with {} workers", containerNum, workerNum); - startContainers(containerNum); - doCheckpoint(); + RpcClient.init(clusterContext.getConfig()); + } + + @Override + public void allocateWorkers(int workerNum) { + int workersPerContainer = clusterConfig.getContainerWorkerNum(); + int containerNum = (workerNum + workersPerContainer - 1) / workersPerContainer; + LOGGER.info("allocate {} containers with {} workers", containerNum, workerNum); + startContainers(containerNum); + doCheckpoint(); + } + + protected void startContainers(int containerNum) { + validateContainerNum(containerNum); + Map containerIds = new HashMap<>(); + for (int i = 0; i < containerNum; i++) { + int containerId = generateNextComponentId(); + createNewContainer(containerId, false); + containerIds.put(containerId, ClusterConstants.getContainerName(containerId)); } - - protected void startContainers(int containerNum) { - validateContainerNum(containerNum); - Map containerIds = new HashMap<>(); - for (int i = 0; i < containerNum; i++) { - int containerId = generateNextComponentId(); - createNewContainer(containerId, false); - containerIds.put(containerId, ClusterConstants.getContainerName(containerId)); - } - clusterContext.getContainerIds().putAll(containerIds); + clusterContext.getContainerIds().putAll(containerIds); + } + + @Override + public Map startDrivers() { + int driverNum = clusterConfig.getDriverNum(); + LOGGER.info("start driver number: {}", driverNum); + if (!clusterContext.isRecover()) { + Map driverIds = new HashMap<>(); + for (int driverIndex = 0; driverIndex < driverNum; driverIndex++) { + int driverId = generateNextComponentId(); + driverFutureMap.put(driverId, new CompletableFuture<>()); + createNewDriver(driverId, driverIndex); + driverIds.put(driverId, ClusterConstants.getDriverName(driverId)); + } + clusterContext.getDriverIds().putAll(driverIds); + doCheckpoint(); } - - @Override - public Map startDrivers() { - int driverNum = clusterConfig.getDriverNum(); - LOGGER.info("start driver number: {}", driverNum); - if (!clusterContext.isRecover()) { - Map driverIds = new HashMap<>(); - for (int driverIndex = 0; driverIndex < driverNum; driverIndex++) { - int driverId = generateNextComponentId(); - driverFutureMap.put(driverId, new CompletableFuture<>()); - createNewDriver(driverId, driverIndex); - driverIds.put(driverId, ClusterConstants.getDriverName(driverId)); - } - clusterContext.getDriverIds().putAll(driverIds); - doCheckpoint(); - } - Map driverAddresses = new HashMap<>(driverNum); - List driverInfoList = FutureUtil - .wait(driverFutureMap.values(), driverTimeoutSec, TimeUnit.SECONDS); - driverInfoList.forEach(driverInfo -> driverAddresses - .put(driverInfo.getName(), new ConnectAddress(driverInfo.getHost(), - driverInfo.getRpcPort()))); - return driverAddresses; + Map driverAddresses = new HashMap<>(driverNum); + List driverInfoList = + FutureUtil.wait(driverFutureMap.values(), driverTimeoutSec, TimeUnit.SECONDS); + driverInfoList.forEach( + driverInfo -> + driverAddresses.put( + driverInfo.getName(), + new ConnectAddress(driverInfo.getHost(), driverInfo.getRpcPort()))); + return driverAddresses; + } + + /** Restart all driver. */ + public void restartAllDrivers() { + Map driverIds = clusterContext.getDriverIds(); + LOGGER.info("Restart all drivers: {}", driverIds); + for (Map.Entry entry : driverIds.entrySet()) { + restartDriver(entry.getKey()); } - - /** - * Restart all driver. - */ - public void restartAllDrivers() { - Map driverIds = clusterContext.getDriverIds(); - LOGGER.info("Restart all drivers: {}", driverIds); - for (Map.Entry entry : driverIds.entrySet()) { - restartDriver(entry.getKey()); - } - } - - /** - * Restart all containers. - */ - public void restartAllContainers() { - Map containerIds = clusterContext.getContainerIds(); - LOGGER.info("Restart all containers: {}", containerIds); - for (Map.Entry entry : containerIds.entrySet()) { - restartContainer(entry.getKey()); - } + } + + /** Restart all containers. */ + public void restartAllContainers() { + Map containerIds = clusterContext.getContainerIds(); + LOGGER.info("Restart all containers: {}", containerIds); + for (Map.Entry entry : containerIds.entrySet()) { + restartContainer(entry.getKey()); } + } - /** - * Restart a driver. - */ - public abstract void restartDriver(int driverId); + /** Restart a driver. */ + public abstract void restartDriver(int driverId); - /** - * Restart a container. - */ - public abstract void restartContainer(int containerId); + /** Restart a container. */ + public abstract void restartContainer(int containerId); - /** - * Create a new driver. - */ - protected abstract void createNewDriver(int driverId, int index); + /** Create a new driver. */ + protected abstract void createNewDriver(int driverId, int index); - /** - * Create a new container. - */ - protected abstract void createNewContainer(int containerId, boolean isRecover); + /** Create a new container. */ + protected abstract void createNewContainer(int containerId, boolean isRecover); - protected abstract IFailoverStrategy buildFailoverStrategy(); + protected abstract IFailoverStrategy buildFailoverStrategy(); - protected void validateContainerNum(int containerNum) { - } - - @Override - public void doFailover(int componentId, Throwable cause) { - foStrategy.doFailover(componentId, cause); - } + protected void validateContainerNum(int containerNum) {} - @Override - public void close() { - if (clusterInfo != null) { - LOGGER.info("close master {}", masterId); - RpcClient.getInstance().closeMasterConnection(masterId); - } - - for (ContainerInfo containerInfo : containerInfos.values()) { - LOGGER.info("close container {}", containerInfo.getName()); - RpcClient.getInstance().closeContainerConnection(containerInfo.getName()); - } - - for (DriverInfo driverInfo : driverInfos.values()) { - LOGGER.info("close driver {}", driverInfo.getName()); - RpcClient.getInstance().closeDriverConnection(driverInfo.getName()); - } - } + @Override + public void doFailover(int componentId, Throwable cause) { + foStrategy.doFailover(componentId, cause); + } - private int generateNextComponentId() { - int id = idGenerator.incrementAndGet(); - clusterContext.setMaxComponentId(id); - return id; + @Override + public void close() { + if (clusterInfo != null) { + LOGGER.info("close master {}", masterId); + RpcClient.getInstance().closeMasterConnection(masterId); } - public RegisterResponse registerContainer(ContainerInfo request) { - LOGGER.info("register container:{}", request); - containerInfos.put(request.getId(), request); - RpcUtil.asyncExecute(() -> openContainer(request)); - return RegisterResponse.newBuilder().setSuccess(true).build(); + for (ContainerInfo containerInfo : containerInfos.values()) { + LOGGER.info("close container {}", containerInfo.getName()); + RpcClient.getInstance().closeContainerConnection(containerInfo.getName()); } - public RegisterResponse registerDriver(DriverInfo driverInfo) { - LOGGER.info("register driver:{}", driverInfo); - driverInfos.put(driverInfo.getId(), driverInfo); - CompletableFuture completableFuture = - (CompletableFuture) driverFutureMap.get(driverInfo.getId()); - completableFuture.complete(driverInfo); - return RegisterResponse.newBuilder().setSuccess(true).build(); + for (DriverInfo driverInfo : driverInfos.values()) { + LOGGER.info("close driver {}", driverInfo.getName()); + RpcClient.getInstance().closeDriverConnection(driverInfo.getName()); } - - protected void openContainer(ContainerInfo containerInfo) { - ContainerEndpointRef endpointRef = RpcEndpointRefFactory.getInstance() + } + + private int generateNextComponentId() { + int id = idGenerator.incrementAndGet(); + clusterContext.setMaxComponentId(id); + return id; + } + + public RegisterResponse registerContainer(ContainerInfo request) { + LOGGER.info("register container:{}", request); + containerInfos.put(request.getId(), request); + RpcUtil.asyncExecute(() -> openContainer(request)); + return RegisterResponse.newBuilder().setSuccess(true).build(); + } + + public RegisterResponse registerDriver(DriverInfo driverInfo) { + LOGGER.info("register driver:{}", driverInfo); + driverInfos.put(driverInfo.getId(), driverInfo); + CompletableFuture completableFuture = + (CompletableFuture) driverFutureMap.get(driverInfo.getId()); + completableFuture.complete(driverInfo); + return RegisterResponse.newBuilder().setSuccess(true).build(); + } + + protected void openContainer(ContainerInfo containerInfo) { + ContainerEndpointRef endpointRef = + RpcEndpointRefFactory.getInstance() .connectContainer(containerInfo.getHost(), containerInfo.getRpcPort()); - int workerNum = clusterConfig.getContainerWorkerNum(); - endpointRef.process(new OpenContainerEvent(workerNum), new RpcCallback() { - @Override - public void onSuccess(Response response) { - byte[] payload = response.getPayload().toByteArray(); - OpenContainerResponseEvent openResult = - (OpenContainerResponseEvent) SerializerFactory - .getKryoSerializer().deserialize(payload); - ContainerExecutorInfo executorInfo = new ContainerExecutorInfo(containerInfo, - openResult.getFirstWorkerIndex(), workerNum); - handleRegisterResponse(executorInfo, openResult, null); - } - - @Override - public void onFailure(Throwable t) { - handleRegisterResponse(null, null, t); - } - }); + int workerNum = clusterConfig.getContainerWorkerNum(); + endpointRef.process( + new OpenContainerEvent(workerNum), + new RpcCallback() { + @Override + public void onSuccess(Response response) { + byte[] payload = response.getPayload().toByteArray(); + OpenContainerResponseEvent openResult = + (OpenContainerResponseEvent) + SerializerFactory.getKryoSerializer().deserialize(payload); + ContainerExecutorInfo executorInfo = + new ContainerExecutorInfo( + containerInfo, openResult.getFirstWorkerIndex(), workerNum); + handleRegisterResponse(executorInfo, openResult, null); + } + + @Override + public void onFailure(Throwable t) { + handleRegisterResponse(null, null, t); + } + }); + } + + private void handleRegisterResponse( + ContainerExecutorInfo executorInfo, OpenContainerResponseEvent response, Throwable e) { + List callbacks = clusterContext.getCallbacks(); + if (e != null || !response.isSuccess()) { + for (ExecutorRegisteredCallback callback : callbacks) { + callback.onFailure(new ExecutorRegisterException(e)); + } + } else { + for (ExecutorRegisteredCallback callback : callbacks) { + callback.onSuccess(executorInfo); + } } + } - private void handleRegisterResponse(ContainerExecutorInfo executorInfo, - OpenContainerResponseEvent response, Throwable e) { - List callbacks = clusterContext.getCallbacks(); - if (e != null || !response.isSuccess()) { - for (ExecutorRegisteredCallback callback : callbacks) { - callback.onFailure(new ExecutorRegisterException(e)); - } - } else { - for (ExecutorRegisteredCallback callback : callbacks) { - callback.onSuccess(executorInfo); - } - } - } + private synchronized void doCheckpoint() { + clusterContext.checkpoint(new ClusterContext.ClusterCheckpointFunction()); + } - private synchronized void doCheckpoint() { - clusterContext.checkpoint(new ClusterContext.ClusterCheckpointFunction()); - } + public ClusterContext getClusterContext() { + return clusterContext; + } - public ClusterContext getClusterContext() { - return clusterContext; - } + public int getTotalContainers() { + return clusterContext.getContainerIds().size(); + } - public int getTotalContainers() { - return clusterContext.getContainerIds().size(); - } + public int getTotalDrivers() { + return clusterContext.getDriverIds().size(); + } - public int getTotalDrivers() { - return clusterContext.getDriverIds().size(); - } + public Map getContainerInfos() { + return containerInfos; + } - public Map getContainerInfos() { - return containerInfos; - } - - public Map getDriverInfos() { - return driverInfos; - } + public Map getDriverInfos() { + return driverInfos; + } - public Map getContainerIds() { - return clusterContext.getContainerIds(); - } - - public Map getDriverIds() { - return clusterContext.getDriverIds(); - } + public Map getContainerIds() { + return clusterContext.getContainerIds(); + } + public Map getDriverIds() { + return clusterContext.getDriverIds(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ClusterContext.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ClusterContext.java index 50629e414..32aa3a2f3 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ClusterContext.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ClusterContext.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.cluster.common.IReliableContext; import org.apache.geaflow.cluster.common.ReliableContainerContext; import org.apache.geaflow.cluster.config.ClusterConfig; @@ -37,117 +38,124 @@ public class ClusterContext extends ReliableContainerContext { - private static final Logger LOGGER = LoggerFactory.getLogger(ClusterContext.class); - - private final Configuration config; - private final ClusterConfig clusterConfig; - private final List callbacks; - private HeartbeatManager heartbeatManager; - private Map containerIds; - private Map driverIds; - private int maxComponentId; - - public ClusterContext(Configuration configuration) { - super(DEFAULT_MASTER_ID, getMasterName(), configuration); - this.config = configuration; - this.clusterConfig = ClusterConfig.build(configuration); - this.callbacks = new ArrayList<>(); - } - - public Configuration getConfig() { - return config; - } - - public ClusterConfig getClusterConfig() { - return clusterConfig; - } - - public void addExecutorRegisteredCallback(ExecutorRegisteredCallback callback) { - this.callbacks.add(callback); - } - - public List getCallbacks() { - return callbacks; - } - - public HeartbeatManager getHeartbeatManager() { - return heartbeatManager; - } - - public void setHeartbeatManager(HeartbeatManager heartbeatManager) { - this.heartbeatManager = heartbeatManager; - } - - public Map getContainerIds() { - return containerIds; - } - - public void setContainerIds(Map containerIds) { - this.containerIds = containerIds; - } - - public Map getDriverIds() { - return driverIds; - } - - public void setDriverIds(Map driverIds) { - this.driverIds = driverIds; - } - - public int getMaxComponentId() { - return maxComponentId; - } - - public void setMaxComponentId(int maxComponentId) { - this.maxComponentId = maxComponentId; - } + private static final Logger LOGGER = LoggerFactory.getLogger(ClusterContext.class); + + private final Configuration config; + private final ClusterConfig clusterConfig; + private final List callbacks; + private HeartbeatManager heartbeatManager; + private Map containerIds; + private Map driverIds; + private int maxComponentId; + + public ClusterContext(Configuration configuration) { + super(DEFAULT_MASTER_ID, getMasterName(), configuration); + this.config = configuration; + this.clusterConfig = ClusterConfig.build(configuration); + this.callbacks = new ArrayList<>(); + } + + public Configuration getConfig() { + return config; + } + + public ClusterConfig getClusterConfig() { + return clusterConfig; + } + + public void addExecutorRegisteredCallback(ExecutorRegisteredCallback callback) { + this.callbacks.add(callback); + } + + public List getCallbacks() { + return callbacks; + } + + public HeartbeatManager getHeartbeatManager() { + return heartbeatManager; + } + + public void setHeartbeatManager(HeartbeatManager heartbeatManager) { + this.heartbeatManager = heartbeatManager; + } + + public Map getContainerIds() { + return containerIds; + } + + public void setContainerIds(Map containerIds) { + this.containerIds = containerIds; + } + + public Map getDriverIds() { + return driverIds; + } + + public void setDriverIds(Map driverIds) { + this.driverIds = driverIds; + } + + public int getMaxComponentId() { + return maxComponentId; + } + + public void setMaxComponentId(int maxComponentId) { + this.maxComponentId = maxComponentId; + } + + @Override + public void load() { + ClusterMetaStore metaStore = ClusterMetaStore.getInstance(id, name, config); + Map drivers = metaStore.getDriverIds(); + Map containerIds = metaStore.getContainerIds(); + int driverNum = drivers == null ? 0 : drivers.size(); + int containerNum = containerIds == null ? 0 : containerIds.size(); + if (driverNum != 0 && containerNum != 0) { + this.isRecover = true; + this.driverIds = drivers; + this.containerIds = containerIds; + this.maxComponentId = metaStore.getMaxContainerId(); + LOGGER.info( + "recover {} containers and {} drivers maxComponentId: {}", + containerNum, + driverNum, + maxComponentId); + } else { + this.isRecover = false; + this.driverIds = new ConcurrentHashMap<>(); + this.containerIds = new ConcurrentHashMap<>(); + this.maxComponentId = 0; + LOGGER.info("init with maxComponentId: {}", maxComponentId); + } + } + + public void setRecover(boolean isRecovered) { + this.isRecover = isRecovered; + } + + public static class ClusterCheckpointFunction implements IReliableContextCheckpointFunction { @Override - public void load() { - ClusterMetaStore metaStore = ClusterMetaStore.getInstance(id, name, config); - Map drivers = metaStore.getDriverIds(); - Map containerIds = metaStore.getContainerIds(); - int driverNum = drivers == null ? 0 : drivers.size(); - int containerNum = containerIds == null ? 0 : containerIds.size(); - if (driverNum != 0 && containerNum != 0) { - this.isRecover = true; - this.driverIds = drivers; - this.containerIds = containerIds; - this.maxComponentId = metaStore.getMaxContainerId(); - LOGGER.info("recover {} containers and {} drivers maxComponentId: {}", - containerNum, driverNum, maxComponentId); - } else { - this.isRecover = false; - this.driverIds = new ConcurrentHashMap<>(); - this.containerIds = new ConcurrentHashMap<>(); - this.maxComponentId = 0; - LOGGER.info("init with maxComponentId: {}", maxComponentId); - } - } - - public void setRecover(boolean isRecovered) { - this.isRecover = isRecovered; - } - - public static class ClusterCheckpointFunction implements IReliableContextCheckpointFunction { - - @Override - public void doCheckpoint(IReliableContext context) { - ClusterContext clusterContext = (ClusterContext) context; - Map containerIds = clusterContext.getContainerIds(); - Map driverIds = clusterContext.getDriverIds(); - ClusterMetaStore metaStore = ClusterMetaStore - .getInstance(clusterContext.id, clusterContext.name, clusterContext.config); - if (containerIds != null && !containerIds.isEmpty() && driverIds != null && !driverIds - .isEmpty()) { - LOGGER.info("persist {} containers and {} drivers into metaStore", - containerIds.size(), driverIds.size()); - metaStore.saveMaxContainerId(clusterContext.getMaxComponentId()); - metaStore.saveContainerIds(containerIds); - metaStore.saveDriverIds(driverIds); - metaStore.flush(); - } - } - } - + public void doCheckpoint(IReliableContext context) { + ClusterContext clusterContext = (ClusterContext) context; + Map containerIds = clusterContext.getContainerIds(); + Map driverIds = clusterContext.getDriverIds(); + ClusterMetaStore metaStore = + ClusterMetaStore.getInstance( + clusterContext.id, clusterContext.name, clusterContext.config); + if (containerIds != null + && !containerIds.isEmpty() + && driverIds != null + && !driverIds.isEmpty()) { + LOGGER.info( + "persist {} containers and {} drivers into metaStore", + containerIds.size(), + driverIds.size()); + metaStore.saveMaxContainerId(clusterContext.getMaxComponentId()); + metaStore.saveContainerIds(containerIds); + metaStore.saveDriverIds(driverIds); + metaStore.flush(); + } + } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ClusterId.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ClusterId.java index 10ebf7f85..5a01b61c7 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ClusterId.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ClusterId.java @@ -21,6 +21,4 @@ import java.io.Serializable; -public interface ClusterId extends Serializable { - -} +public interface ClusterId extends Serializable {} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ClusterInfo.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ClusterInfo.java index 89f9f3bf6..5c68a5036 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ClusterInfo.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ClusterInfo.java @@ -22,54 +22,57 @@ import java.io.Serializable; import java.util.Map; import java.util.Objects; + import org.apache.geaflow.cluster.rpc.ConnectAddress; public class ClusterInfo implements Serializable { - private ConnectAddress masterAddress; - private Map driverAddresses; + private ConnectAddress masterAddress; + private Map driverAddresses; - public ClusterInfo() { - } + public ClusterInfo() {} - public ConnectAddress getMasterAddress() { - return masterAddress; - } + public ConnectAddress getMasterAddress() { + return masterAddress; + } - public void setMasterAddress(ConnectAddress masterAddress) { - this.masterAddress = masterAddress; - } + public void setMasterAddress(ConnectAddress masterAddress) { + this.masterAddress = masterAddress; + } - public Map getDriverAddresses() { - return driverAddresses; - } + public Map getDriverAddresses() { + return driverAddresses; + } - public void setDriverAddresses(Map driverAddresses) { - this.driverAddresses = driverAddresses; - } + public void setDriverAddresses(Map driverAddresses) { + this.driverAddresses = driverAddresses; + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - ClusterInfo that = (ClusterInfo) o; - return Objects.equals(masterAddress, that.masterAddress) && Objects - .equals(driverAddresses, that.driverAddresses); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(masterAddress, driverAddresses); + if (o == null || getClass() != o.getClass()) { + return false; } + ClusterInfo that = (ClusterInfo) o; + return Objects.equals(masterAddress, that.masterAddress) + && Objects.equals(driverAddresses, that.driverAddresses); + } - @Override - public String toString() { - return "ClusterInfo{" + "masterAddress=" + masterAddress + ", driverAddresses=" - + driverAddresses + '}'; - } + @Override + public int hashCode() { + return Objects.hash(masterAddress, driverAddresses); + } + @Override + public String toString() { + return "ClusterInfo{" + + "masterAddress=" + + masterAddress + + ", driverAddresses=" + + driverAddresses + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ContainerExecutorInfo.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ContainerExecutorInfo.java index 55d7da62e..0b3480660 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ContainerExecutorInfo.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ContainerExecutorInfo.java @@ -22,79 +22,70 @@ import java.io.Serializable; import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.cluster.container.ContainerInfo; public class ContainerExecutorInfo implements Serializable { - /** - * container id. - */ - private int containerId; - /** - * container name. - */ - private String containerName; - /** - * host ip. - */ - private String host; - /** - * process id. - */ - private int processId; - /** - * rpc service port. - */ - private int rpcPort; - /** - * shuffle service port. - */ - private int shufflePort; - /** - * executor index list. - */ - private List executorIds; - - public ContainerExecutorInfo(ContainerInfo containerInfo, int firstWorkerIndex, - int workerNum) { - this.containerId = containerInfo.getId(); - this.containerName = containerInfo.getName(); - this.host = containerInfo.getHost(); - this.rpcPort = containerInfo.getRpcPort(); - this.shufflePort = containerInfo.getShufflePort(); - this.processId = containerInfo.getPid(); - this.executorIds = new ArrayList<>(workerNum); - for (int i = 0; i < workerNum; i++) { - this.executorIds.add(firstWorkerIndex + i); - } - } + /** container id. */ + private int containerId; - public int getContainerId() { - return containerId; - } + /** container name. */ + private String containerName; - public String getContainerName() { - return containerName; - } + /** host ip. */ + private String host; - public String getHost() { - return host; - } + /** process id. */ + private int processId; - public int getProcessId() { - return processId; - } + /** rpc service port. */ + private int rpcPort; - public int getRpcPort() { - return rpcPort; - } + /** shuffle service port. */ + private int shufflePort; - public int getShufflePort() { - return shufflePort; - } + /** executor index list. */ + private List executorIds; - public List getExecutorIds() { - return executorIds; + public ContainerExecutorInfo(ContainerInfo containerInfo, int firstWorkerIndex, int workerNum) { + this.containerId = containerInfo.getId(); + this.containerName = containerInfo.getName(); + this.host = containerInfo.getHost(); + this.rpcPort = containerInfo.getRpcPort(); + this.shufflePort = containerInfo.getShufflePort(); + this.processId = containerInfo.getPid(); + this.executorIds = new ArrayList<>(workerNum); + for (int i = 0; i < workerNum; i++) { + this.executorIds.add(firstWorkerIndex + i); } + } + + public int getContainerId() { + return containerId; + } + + public String getContainerName() { + return containerName; + } + + public String getHost() { + return host; + } + + public int getProcessId() { + return processId; + } + + public int getRpcPort() { + return rpcPort; + } + + public int getShufflePort() { + return shufflePort; + } + public List getExecutorIds() { + return executorIds; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ExecutorRegisterException.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ExecutorRegisterException.java index e9e74bdf6..335664ca3 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ExecutorRegisterException.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ExecutorRegisterException.java @@ -23,8 +23,7 @@ public class ExecutorRegisterException extends GeaflowRuntimeException { - public ExecutorRegisterException(Throwable e) { - super(e); - } - + public ExecutorRegisterException(Throwable e) { + super(e); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ExecutorRegisteredCallback.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ExecutorRegisteredCallback.java index 0fd808857..9a23d350a 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ExecutorRegisteredCallback.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/ExecutorRegisteredCallback.java @@ -23,14 +23,9 @@ public interface ExecutorRegisteredCallback extends Serializable { - /** - * The callback for executor register succeed. - */ - void onSuccess(ContainerExecutorInfo executorInfos); - - /** - * The callback for executor register failed. - */ - void onFailure(ExecutorRegisterException e); + /** The callback for executor register succeed. */ + void onSuccess(ContainerExecutorInfo executorInfos); + /** The callback for executor register failed. */ + void onFailure(ExecutorRegisterException e); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/IClusterManager.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/IClusterManager.java index 83faf30b4..b821806cb 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/IClusterManager.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/clustermanager/IClusterManager.java @@ -21,38 +21,26 @@ import java.io.Serializable; import java.util.Map; + import org.apache.geaflow.cluster.rpc.ConnectAddress; public interface IClusterManager extends Serializable { - /** - * Initialize cluster manager. - */ - void init(ClusterContext context); - - /** - * Start master. - */ - ClusterId startMaster(); + /** Initialize cluster manager. */ + void init(ClusterContext context); - /** - * Start drivers drivers and returns rpc addresses. - */ - Map startDrivers(); + /** Start master. */ + ClusterId startMaster(); - /** - * Start worker threads. - */ - void allocateWorkers(int workerNum); + /** Start drivers drivers and returns rpc addresses. */ + Map startDrivers(); - /** - * Trigger job failover. - */ - void doFailover(int componentId, Throwable cause); + /** Start worker threads. */ + void allocateWorkers(int workerNum); - /** - * Close cluster manager. - */ - void close(); + /** Trigger job failover. */ + void doFailover(int componentId, Throwable cause); + /** Close cluster manager. */ + void close(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/AbstractEmitterRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/AbstractEmitterRequest.java index 3a1657485..0f689f3ac 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/AbstractEmitterRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/AbstractEmitterRequest.java @@ -21,22 +21,21 @@ public abstract class AbstractEmitterRequest implements IEmitterRequest { - private final int taskId; - private final long windowId; + private final int taskId; + private final long windowId; - public AbstractEmitterRequest(int taskId, long windowId) { - this.taskId = taskId; - this.windowId = windowId; - } + public AbstractEmitterRequest(int taskId, long windowId) { + this.taskId = taskId; + this.windowId = windowId; + } - @Override - public int getTaskId() { - return this.taskId; - } - - @Override - public long getWindowId() { - return this.windowId; - } + @Override + public int getTaskId() { + return this.taskId; + } + @Override + public long getWindowId() { + return this.windowId; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/AbstractPipelineCollector.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/AbstractPipelineCollector.java index 2b920e32d..95796565d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/AbstractPipelineCollector.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/AbstractPipelineCollector.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.stream.IntStream; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.collector.AbstractCollector; import org.apache.geaflow.collector.ICollector; @@ -31,100 +32,94 @@ import org.apache.geaflow.selector.impl.ChannelSelector; import org.apache.geaflow.shuffle.ForwardOutputDesc; -public abstract class AbstractPipelineCollector - extends AbstractCollector implements ICollector { - - protected transient IOutputMessageBuffer outputBuffer; - protected transient ISelector recordISelector; - protected ForwardOutputDesc outputDesc; - protected long windowId; - - public AbstractPipelineCollector(ForwardOutputDesc outputDesc) { - super(outputDesc.getPartitioner().getOpId()); - this.outputDesc = outputDesc; - } - - @Override - public void setUp(RuntimeContext runtimeContext) { - super.setUp(runtimeContext); - List targetTaskIds = outputDesc.getTargetTaskIndices(); - IPartitioner partitioner = outputDesc.getPartitioner(); - if (partitioner.getPartitionType() == IPartitioner.PartitionType.key) { - ((KeyPartitioner) partitioner).init(outputDesc.getNumPartitions()); - } - this.recordISelector = new ChannelSelector(targetTaskIds.size(), - partitioner); - } - - public void setOutputBuffer(IOutputMessageBuffer outputBuffer) { - this.outputBuffer = outputBuffer; - } - - public long getWindowId() { - return windowId; - } - - public void setWindowId(long windowId) { - this.windowId = windowId; +public abstract class AbstractPipelineCollector extends AbstractCollector + implements ICollector { + + protected transient IOutputMessageBuffer outputBuffer; + protected transient ISelector recordISelector; + protected ForwardOutputDesc outputDesc; + protected long windowId; + + public AbstractPipelineCollector(ForwardOutputDesc outputDesc) { + super(outputDesc.getPartitioner().getOpId()); + this.outputDesc = outputDesc; + } + + @Override + public void setUp(RuntimeContext runtimeContext) { + super.setUp(runtimeContext); + List targetTaskIds = outputDesc.getTargetTaskIndices(); + IPartitioner partitioner = outputDesc.getPartitioner(); + if (partitioner.getPartitionType() == IPartitioner.PartitionType.key) { + ((KeyPartitioner) partitioner).init(outputDesc.getNumPartitions()); } - - @Override - public void broadcast(T value) { - List targetTaskIds = outputDesc.getTargetTaskIndices(); - int[] channels = IntStream.rangeClosed(0, targetTaskIds.size() - 1).toArray(); - try { - this.outputBuffer.emit(this.windowId, value, false, channels); - this.outputMeter.mark(); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + this.recordISelector = new ChannelSelector(targetTaskIds.size(), partitioner); + } + + public void setOutputBuffer(IOutputMessageBuffer outputBuffer) { + this.outputBuffer = outputBuffer; + } + + public long getWindowId() { + return windowId; + } + + public void setWindowId(long windowId) { + this.windowId = windowId; + } + + @Override + public void broadcast(T value) { + List targetTaskIds = outputDesc.getTargetTaskIndices(); + int[] channels = IntStream.rangeClosed(0, targetTaskIds.size() - 1).toArray(); + try { + this.outputBuffer.emit(this.windowId, value, false, channels); + this.outputMeter.mark(); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } - - @Override - public void partition(T value) { - shuffle(value, false); + } + + @Override + public void partition(T value) { + shuffle(value, false); + } + + @Override + public void partition(KEY key, T value) { + shuffle(key, value, false); + } + + @Override + public void finish() { + try { + this.outputBuffer.finish(this.windowId); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } + } - @Override - public void partition(KEY key, T value) { - shuffle(key, value, false); - } + /** Shuffle data with value itself. */ + protected void shuffle(T value, boolean isRetract) { + int[] targetChannels = this.recordISelector.selectChannels(value); - @Override - public void finish() { - try { - this.outputBuffer.finish(this.windowId); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + try { + this.outputBuffer.emit(this.windowId, value, isRetract, targetChannels); + this.outputMeter.mark(); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } + } - /** - * Shuffle data with value itself. - */ - protected void shuffle(T value, boolean isRetract) { - int[] targetChannels = this.recordISelector.selectChannels(value); - - try { - this.outputBuffer.emit(this.windowId, value, isRetract, targetChannels); - this.outputMeter.mark(); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } - } + /** Shuffle data with key. */ + protected void shuffle(KEY key, T value, boolean isRetract) { + int[] targetChannels = this.recordISelector.selectChannels(key); - /** - * Shuffle data with key. - */ - protected void shuffle(KEY key, T value, boolean isRetract) { - int[] targetChannels = this.recordISelector.selectChannels(key); - - try { - this.outputBuffer.emit(this.windowId, value, isRetract, targetChannels); - this.outputMeter.mark(); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + try { + this.outputBuffer.emit(this.windowId, value, isRetract, targetChannels); + this.outputMeter.mark(); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/AbstractPipelineOutputCollector.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/AbstractPipelineOutputCollector.java index 13fcb545c..8147881ac 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/AbstractPipelineOutputCollector.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/AbstractPipelineOutputCollector.java @@ -29,57 +29,63 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class AbstractPipelineOutputCollector - extends AbstractPipelineCollector implements IResultCollector { +public abstract class AbstractPipelineOutputCollector extends AbstractPipelineCollector + implements IResultCollector { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractPipelineOutputCollector.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(AbstractPipelineOutputCollector.class); - private int edgeId; - private String edgeName; - private Shard shard; - private OutputType collectType; + private int edgeId; + private String edgeName; + private Shard shard; + private OutputType collectType; - public AbstractPipelineOutputCollector(ForwardOutputDesc outputDesc) { - super(outputDesc); - this.edgeName = outputDesc.getEdgeName(); - this.edgeId = outputDesc.getEdgeId(); - this.collectType = outputDesc.getType(); - } + public AbstractPipelineOutputCollector(ForwardOutputDesc outputDesc) { + super(outputDesc); + this.edgeName = outputDesc.getEdgeName(); + this.edgeId = outputDesc.getEdgeId(); + this.collectType = outputDesc.getType(); + } - @Override - public void setUp(RuntimeContext runtimeContext) { - super.setUp(runtimeContext); - int taskId = runtimeContext.getTaskArgs().getTaskId(); - int taskIndex = runtimeContext.getTaskArgs().getTaskIndex(); - this.shard = null; - LOGGER.info("setup PipelineOutputCollector {} taskId {} taskIndex {} edgeName {} edgeId {}", - this, taskId, taskIndex, this.edgeName, this.edgeId); - } + @Override + public void setUp(RuntimeContext runtimeContext) { + super.setUp(runtimeContext); + int taskId = runtimeContext.getTaskArgs().getTaskId(); + int taskIndex = runtimeContext.getTaskArgs().getTaskIndex(); + this.shard = null; + LOGGER.info( + "setup PipelineOutputCollector {} taskId {} taskIndex {} edgeName {} edgeId {}", + this, + taskId, + taskIndex, + this.edgeName, + this.edgeId); + } - @Override - public void finish() { - try { - this.shard = (Shard) this.outputBuffer.finish(this.windowId); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + @Override + public void finish() { + try { + this.shard = (Shard) this.outputBuffer.finish(this.windowId); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } + } - @Override - public String getTag() { - return edgeName; - } + @Override + public String getTag() { + return edgeName; + } - /** - * Collect shard result. - * - * @return - */ - @Override - public ShardResult collectResult() { - if (shard == null) { - return null; - } - return new ShardResult(shard.getEdgeId(), collectType, shard.getSlices()); + /** + * Collect shard result. + * + * @return + */ + @Override + public ShardResult collectResult() { + if (shard == null) { + return null; } + return new ShardResult(shard.getEdgeId(), collectType, shard.getSlices()); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/ClearEmitterRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/ClearEmitterRequest.java index 92da21522..f6f70561a 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/ClearEmitterRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/ClearEmitterRequest.java @@ -21,15 +21,14 @@ public class ClearEmitterRequest extends AbstractEmitterRequest { - public static final ClearEmitterRequest INSTANCE = new ClearEmitterRequest(); + public static final ClearEmitterRequest INSTANCE = new ClearEmitterRequest(); - private ClearEmitterRequest() { - super(-1, -1); - } - - @Override - public RequestType getRequestType() { - return RequestType.CLEAR; - } + private ClearEmitterRequest() { + super(-1, -1); + } + @Override + public RequestType getRequestType() { + return RequestType.CLEAR; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/CloseEmitterRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/CloseEmitterRequest.java index fe4b8fff2..2eee9841e 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/CloseEmitterRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/CloseEmitterRequest.java @@ -21,13 +21,12 @@ public class CloseEmitterRequest extends AbstractEmitterRequest { - public CloseEmitterRequest(int taskId, long windowId) { - super(taskId, windowId); - } - - @Override - public RequestType getRequestType() { - return RequestType.CLOSE; - } + public CloseEmitterRequest(int taskId, long windowId) { + super(taskId, windowId); + } + @Override + public RequestType getRequestType() { + return RequestType.CLOSE; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/CollectResponseCollector.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/CollectResponseCollector.java index c1c3b5db5..c2652f3b5 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/CollectResponseCollector.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/CollectResponseCollector.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.cluster.response.ResponseResult; import org.apache.geaflow.collector.AbstractCollector; import org.apache.geaflow.collector.ICollector; @@ -30,58 +31,56 @@ public class CollectResponseCollector extends AbstractCollector implements IResultCollector, ICollector { - private int edgeId; - private OutputType collectorType; - private String edgeName; - private final List buffer; - private final List result; - - public CollectResponseCollector(ResponseOutputDesc outputDesc) { - super(outputDesc.getOpId()); - this.edgeId = outputDesc.getEdgeId(); - this.collectorType = outputDesc.getType(); - this.edgeName = outputDesc.getEdgeName(); - this.buffer = new ArrayList<>(); - this.result = new ArrayList<>(); - } + private int edgeId; + private OutputType collectorType; + private String edgeName; + private final List buffer; + private final List result; - @Override - public void partition(T value) { - buffer.add(value); - this.outputMeter.mark(); - } + public CollectResponseCollector(ResponseOutputDesc outputDesc) { + super(outputDesc.getOpId()); + this.edgeId = outputDesc.getEdgeId(); + this.collectorType = outputDesc.getType(); + this.edgeName = outputDesc.getEdgeName(); + this.buffer = new ArrayList<>(); + this.result = new ArrayList<>(); + } - @Override - public void finish() { - result.clear(); - result.addAll(buffer); - buffer.clear(); - } + @Override + public void partition(T value) { + buffer.add(value); + this.outputMeter.mark(); + } - @Override - public String getTag() { - return edgeName; - } + @Override + public void finish() { + result.clear(); + result.addAll(buffer); + buffer.clear(); + } - @Override - public OutputType getType() { - return collectorType; - } + @Override + public String getTag() { + return edgeName; + } - @Override - public void broadcast(T value) { + @Override + public OutputType getType() { + return collectorType; + } - } + @Override + public void broadcast(T value) {} - @Override - public void partition(KEY key, T value) { - partition(value); - } + @Override + public void partition(KEY key, T value) { + partition(value); + } - @Override - public ResponseResult collectResult() { - ResponseResult responseResult = new ResponseResult(edgeId, getType(), new ArrayList<>(result)); - result.clear(); - return responseResult; - } -} \ No newline at end of file + @Override + public ResponseResult collectResult() { + ResponseResult responseResult = new ResponseResult(edgeId, getType(), new ArrayList<>(result)); + result.clear(); + return responseResult; + } +} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/CollectorFactory.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/CollectorFactory.java index 05b8c0e81..a44f18742 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/CollectorFactory.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/CollectorFactory.java @@ -27,17 +27,16 @@ public class CollectorFactory { - public static ICollector create(IOutputDesc outputDesc) { - switch (outputDesc.getType()) { - case FORWARD: - return new ForwardOutputCollector<>((ForwardOutputDesc) outputDesc); - case LOOP: - return new IterationOutputCollector<>((ForwardOutputDesc) outputDesc); - case RESPONSE: - return new CollectResponseCollector<>((ResponseOutputDesc) outputDesc); - default: - throw new GeaflowRuntimeException("not support output type " + outputDesc.getType()); - - } + public static ICollector create(IOutputDesc outputDesc) { + switch (outputDesc.getType()) { + case FORWARD: + return new ForwardOutputCollector<>((ForwardOutputDesc) outputDesc); + case LOOP: + return new IterationOutputCollector<>((ForwardOutputDesc) outputDesc); + case RESPONSE: + return new CollectResponseCollector<>((ResponseOutputDesc) outputDesc); + default: + throw new GeaflowRuntimeException("not support output type " + outputDesc.getType()); } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/EmitterRunner.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/EmitterRunner.java index 93632ed75..308b26a18 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/EmitterRunner.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/EmitterRunner.java @@ -26,34 +26,33 @@ public class EmitterRunner extends AbstractTaskRunner { - private final PipelineOutputEmitter outputEmitter; + private final PipelineOutputEmitter outputEmitter; - public EmitterRunner(Configuration configuration, int index) { - this.outputEmitter = new PipelineOutputEmitter(configuration, index); - } + public EmitterRunner(Configuration configuration, int index) { + this.outputEmitter = new PipelineOutputEmitter(configuration, index); + } - @Override - protected void process(IEmitterRequest request) { - switch (request.getRequestType()) { - case INIT: - this.outputEmitter.init((InitEmitterRequest) request); - break; - case POP: - this.outputEmitter.update((UpdateEmitterRequest) request); - break; - case CLOSE: - this.outputEmitter.close((CloseEmitterRequest) request); - break; - case STASH: - this.outputEmitter.stash((StashEmitterRequest) request); - break; - case CLEAR: - this.outputEmitter.clear(); - break; - default: - throw new GeaflowRuntimeException( - RuntimeErrors.INST.requestTypeNotSupportError(request.getRequestType().name())); - } + @Override + protected void process(IEmitterRequest request) { + switch (request.getRequestType()) { + case INIT: + this.outputEmitter.init((InitEmitterRequest) request); + break; + case POP: + this.outputEmitter.update((UpdateEmitterRequest) request); + break; + case CLOSE: + this.outputEmitter.close((CloseEmitterRequest) request); + break; + case STASH: + this.outputEmitter.stash((StashEmitterRequest) request); + break; + case CLEAR: + this.outputEmitter.clear(); + break; + default: + throw new GeaflowRuntimeException( + RuntimeErrors.INST.requestTypeNotSupportError(request.getRequestType().name())); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/EmitterService.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/EmitterService.java index 495e44b04..cfa22e04e 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/EmitterService.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/EmitterService.java @@ -19,29 +19,32 @@ package org.apache.geaflow.cluster.collector; -import com.google.common.base.Preconditions; import java.io.Serializable; + import org.apache.geaflow.cluster.task.service.AbstractTaskService; import org.apache.geaflow.common.config.Configuration; -public class EmitterService extends AbstractTaskService implements Serializable { +import com.google.common.base.Preconditions; - private static final String EMITTER_FORMAT = "geaflow-emitter-%d"; +public class EmitterService extends AbstractTaskService + implements Serializable { - private final int slots; + private static final String EMITTER_FORMAT = "geaflow-emitter-%d"; - public EmitterService(int slots, Configuration configuration) { - super(configuration, EMITTER_FORMAT); - this.slots = slots; - } + private final int slots; + + public EmitterService(int slots, Configuration configuration) { + super(configuration, EMITTER_FORMAT); + this.slots = slots; + } - protected EmitterRunner[] buildTaskRunner() { - Preconditions.checkArgument(slots > 0, "fetcher pool should be larger than 0"); - EmitterRunner[] emitterRunners = new EmitterRunner[slots]; - for (int i = 0; i < slots; i++) { - EmitterRunner runner = new EmitterRunner(this.configuration, i); - emitterRunners[i] = runner; - } - return emitterRunners; + protected EmitterRunner[] buildTaskRunner() { + Preconditions.checkArgument(slots > 0, "fetcher pool should be larger than 0"); + EmitterRunner[] emitterRunners = new EmitterRunner[slots]; + for (int i = 0; i < slots; i++) { + EmitterRunner runner = new EmitterRunner(this.configuration, i); + emitterRunners[i] = runner; } + return emitterRunners; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/ForwardOutputCollector.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/ForwardOutputCollector.java index 010fdaa02..b44c0d5b6 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/ForwardOutputCollector.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/ForwardOutputCollector.java @@ -26,18 +26,17 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class ForwardOutputCollector - extends AbstractPipelineOutputCollector implements IResultCollector { +public class ForwardOutputCollector extends AbstractPipelineOutputCollector + implements IResultCollector { - private static final Logger LOGGER = LoggerFactory.getLogger(ForwardOutputCollector.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ForwardOutputCollector.class); - public ForwardOutputCollector(ForwardOutputDesc outputDesc) { - super(outputDesc); - } - - @Override - public OutputType getType() { - return OutputType.FORWARD; - } + public ForwardOutputCollector(ForwardOutputDesc outputDesc) { + super(outputDesc); + } + @Override + public OutputType getType() { + return OutputType.FORWARD; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/IEmitterRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/IEmitterRequest.java index d372e72e4..e9b26d32f 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/IEmitterRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/IEmitterRequest.java @@ -23,51 +23,39 @@ public interface IEmitterRequest extends Serializable { - /** - * Return the task id of the emitter request. - * - * @return task id. - */ - int getTaskId(); - - /** - * Return the window id of the emitter request. - * - * @return window id. - */ - long getWindowId(); - - /** - * Return the request type of the emitter request. - * - * @return request type. - */ - RequestType getRequestType(); - - enum RequestType { - - /** - * Init request. - */ - INIT, - /** - * Close request. - */ - CLOSE, - - /** - * Stash request. - */ - STASH, - /** - * Pop request, update the request when cached. - */ - POP, - /** - * Clear the init emitter request in cache. - */ - CLEAR - - } - + /** + * Return the task id of the emitter request. + * + * @return task id. + */ + int getTaskId(); + + /** + * Return the window id of the emitter request. + * + * @return window id. + */ + long getWindowId(); + + /** + * Return the request type of the emitter request. + * + * @return request type. + */ + RequestType getRequestType(); + + enum RequestType { + + /** Init request. */ + INIT, + /** Close request. */ + CLOSE, + + /** Stash request. */ + STASH, + /** Pop request, update the request when cached. */ + POP, + /** Clear the init emitter request in cache. */ + CLEAR + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/IOutputMessageBuffer.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/IOutputMessageBuffer.java index 69be1638b..e7ae9e396 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/IOutputMessageBuffer.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/IOutputMessageBuffer.java @@ -24,37 +24,36 @@ public interface IOutputMessageBuffer extends IMessageBuffer> { - /** - * Emit a record. - * - * @param windowId window id - * @param data data - * @param isRetract if this data is retract - * @param targetChannels target channels - */ - void emit(long windowId, T data, boolean isRetract, int[] targetChannels); - - /** - * For the consumer to finish. - * - * @param windowId window id - * @param result finish result - */ - void setResult(long windowId, R result); - - /** - * For the producer to call finish. - * - * @param windowId window id - * @return finish result - */ - R finish(long windowId); - - /** - * Meet error. - * - * @param t error - */ - void error(Throwable t); - + /** + * Emit a record. + * + * @param windowId window id + * @param data data + * @param isRetract if this data is retract + * @param targetChannels target channels + */ + void emit(long windowId, T data, boolean isRetract, int[] targetChannels); + + /** + * For the consumer to finish. + * + * @param windowId window id + * @param result finish result + */ + void setResult(long windowId, R result); + + /** + * For the producer to call finish. + * + * @param windowId window id + * @return finish result + */ + R finish(long windowId); + + /** + * Meet error. + * + * @param t error + */ + void error(Throwable t); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/InitEmitterRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/InitEmitterRequest.java index 20a58b7fa..4166d33bb 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/InitEmitterRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/InitEmitterRequest.java @@ -20,6 +20,7 @@ package org.apache.geaflow.cluster.collector; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.task.TaskArgs; import org.apache.geaflow.shuffle.OutputDescriptor; @@ -27,56 +28,56 @@ public class InitEmitterRequest extends AbstractEmitterRequest { - private final Configuration configuration; - private final long pipelineId; - private final String pipelineName; - private final TaskArgs taskArgs; - private final OutputDescriptor outputDescriptor; - private final List> outputBuffers; - - public InitEmitterRequest(Configuration configuration, - long windowId, - long pipelineId, - String pipelineName, - TaskArgs taskArgs, - OutputDescriptor outputDescriptor, - List> outputBuffers) { - super(taskArgs.getTaskId(), windowId); - this.configuration = configuration; - this.pipelineId = pipelineId; - this.pipelineName = pipelineName; - this.taskArgs = taskArgs; - this.outputDescriptor = outputDescriptor; - this.outputBuffers = outputBuffers; - } + private final Configuration configuration; + private final long pipelineId; + private final String pipelineName; + private final TaskArgs taskArgs; + private final OutputDescriptor outputDescriptor; + private final List> outputBuffers; - public Configuration getConfiguration() { - return this.configuration; - } + public InitEmitterRequest( + Configuration configuration, + long windowId, + long pipelineId, + String pipelineName, + TaskArgs taskArgs, + OutputDescriptor outputDescriptor, + List> outputBuffers) { + super(taskArgs.getTaskId(), windowId); + this.configuration = configuration; + this.pipelineId = pipelineId; + this.pipelineName = pipelineName; + this.taskArgs = taskArgs; + this.outputDescriptor = outputDescriptor; + this.outputBuffers = outputBuffers; + } - public long getPipelineId() { - return this.pipelineId; - } + public Configuration getConfiguration() { + return this.configuration; + } - public String getPipelineName() { - return this.pipelineName; - } + public long getPipelineId() { + return this.pipelineId; + } - public TaskArgs getTaskArgs() { - return this.taskArgs; - } + public String getPipelineName() { + return this.pipelineName; + } - public OutputDescriptor getOutputDescriptor() { - return this.outputDescriptor; - } + public TaskArgs getTaskArgs() { + return this.taskArgs; + } - public List> getOutputBuffers() { - return this.outputBuffers; - } + public OutputDescriptor getOutputDescriptor() { + return this.outputDescriptor; + } - @Override - public RequestType getRequestType() { - return RequestType.INIT; - } + public List> getOutputBuffers() { + return this.outputBuffers; + } + @Override + public RequestType getRequestType() { + return RequestType.INIT; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/IterationOutputCollector.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/IterationOutputCollector.java index e367211ab..9f5ce4f8e 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/IterationOutputCollector.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/IterationOutputCollector.java @@ -26,18 +26,17 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class IterationOutputCollector - extends AbstractPipelineOutputCollector implements IResultCollector { +public class IterationOutputCollector extends AbstractPipelineOutputCollector + implements IResultCollector { - private static final Logger LOGGER = LoggerFactory.getLogger(IterationOutputCollector.class); + private static final Logger LOGGER = LoggerFactory.getLogger(IterationOutputCollector.class); - public IterationOutputCollector(ForwardOutputDesc outputDesc) { - super(outputDesc); - } - - @Override - public OutputType getType() { - return OutputType.LOOP; - } + public IterationOutputCollector(ForwardOutputDesc outputDesc) { + super(outputDesc); + } + @Override + public OutputType getType() { + return OutputType.LOOP; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/PipelineOutputEmitter.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/PipelineOutputEmitter.java index efd4c6130..3771f7ecb 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/PipelineOutputEmitter.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/PipelineOutputEmitter.java @@ -26,6 +26,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; + import org.apache.geaflow.cluster.exception.ComponentUncaughtExceptionHandler; import org.apache.geaflow.cluster.protocol.OutputMessage; import org.apache.geaflow.common.config.Configuration; @@ -52,195 +53,203 @@ public class PipelineOutputEmitter { - private static final Logger LOGGER = LoggerFactory.getLogger(PipelineOutputEmitter.class); + private static final Logger LOGGER = LoggerFactory.getLogger(PipelineOutputEmitter.class); - private static final ExecutorService EMIT_EXECUTOR = Executors.getUnboundedExecutorService( - PipelineOutputEmitter.class.getSimpleName(), 60, TimeUnit.SECONDS, null, ComponentUncaughtExceptionHandler.INSTANCE); + private static final ExecutorService EMIT_EXECUTOR = + Executors.getUnboundedExecutorService( + PipelineOutputEmitter.class.getSimpleName(), + 60, + TimeUnit.SECONDS, + null, + ComponentUncaughtExceptionHandler.INSTANCE); - private static final int DEFAULT_TIMEOUT_MS = 100; + private static final int DEFAULT_TIMEOUT_MS = 100; - private final Configuration configuration; - private final int index; - private final Map initRequestCache = new HashMap<>(); - private final Map runningFlags = new HashMap<>(); + private final Configuration configuration; + private final int index; + private final Map initRequestCache = new HashMap<>(); + private final Map runningFlags = new HashMap<>(); - public PipelineOutputEmitter(Configuration configuration, int index) { - this.configuration = configuration; - this.index = index; - } + public PipelineOutputEmitter(Configuration configuration, int index) { + this.configuration = configuration; + this.index = index; + } - public void init(InitEmitterRequest request) { - this.initRequestCache.put(request.getTaskId(), request); - UpdateEmitterRequest updateEmitterRequest = new UpdateEmitterRequest( + public void init(InitEmitterRequest request) { + this.initRequestCache.put(request.getTaskId(), request); + UpdateEmitterRequest updateEmitterRequest = + new UpdateEmitterRequest( request.getTaskId(), request.getWindowId(), request.getPipelineId(), request.getPipelineName(), request.getOutputBuffers()); - this.update(updateEmitterRequest); + this.update(updateEmitterRequest); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + public void update(UpdateEmitterRequest request) { + int taskId = request.getTaskId(); + if (!this.initRequestCache.containsKey(taskId)) { + throw new GeaflowRuntimeException("init emitter request not found for task " + taskId); } - - @SuppressWarnings({"unchecked", "rawtypes"}) - public void update(UpdateEmitterRequest request) { - int taskId = request.getTaskId(); - if (!this.initRequestCache.containsKey(taskId)) { - throw new GeaflowRuntimeException("init emitter request not found for task " + taskId); - } - InitEmitterRequest initEmitterRequest = this.initRequestCache.get(taskId); - OutputDescriptor outputDescriptor = initEmitterRequest.getOutputDescriptor(); - List> outputBuffers = request.getOutputBuffers(); - List outputDescList = outputDescriptor.getOutputDescList(); - - int outputNum = outputDescList.size(); - AtomicBoolean[] flags = new AtomicBoolean[outputNum]; - ShuffleConfig shuffleConfig = ShuffleManager.getInstance().getShuffleConfig(); - for (int i = 0; i < outputNum; i++) { - IOutputDesc outputDesc = outputDescList.get(i); - if (outputDesc.getType() == OutputType.RESPONSE) { - continue; - } - ForwardOutputDesc forwardOutputDesc = (ForwardOutputDesc) outputDesc; - IShuffleWriter pipeRecordWriter = ShuffleManager.getInstance().loadShuffleWriter(); - IEncoder encoder = forwardOutputDesc.getEncoder(); - if (encoder != null) { - encoder.init(initEmitterRequest.getConfiguration()); - } - TaskArgs taskArgs = initEmitterRequest.getTaskArgs(); - IWriterContext writerContext = WriterContext.newBuilder() - .setPipelineId(request.getPipelineId()) - .setPipelineName(request.getPipelineName()) - .setConfig(shuffleConfig) - .setVertexId(forwardOutputDesc.getPartitioner().getOpId()) - .setEdgeId(forwardOutputDesc.getEdgeId()) - .setTaskId(taskArgs.getTaskId()) - .setTaskIndex(taskArgs.getTaskIndex()) - .setTaskName(taskArgs.getTaskName()) - .setChannelNum(forwardOutputDesc.getTargetTaskIndices().size()) - .setEncoder(encoder) - .setDataExchangeMode(forwardOutputDesc.getDataExchangeMode()); - pipeRecordWriter.init(writerContext); - - AtomicBoolean flag = new AtomicBoolean(true); - flags[i] = flag; - String emitterId = String.format("%d[%d/%d]", taskId, taskArgs.getTaskIndex(), taskArgs.getParallelism()); - EmitterTask emitterTask = new EmitterTask( - pipeRecordWriter, - outputBuffers.get(i), - flag, - request.getWindowId(), - this.index, - forwardOutputDesc.getEdgeName(), - emitterId); - EMIT_EXECUTOR.execute(emitterTask); - } - this.runningFlags.put(taskId, flags); + InitEmitterRequest initEmitterRequest = this.initRequestCache.get(taskId); + OutputDescriptor outputDescriptor = initEmitterRequest.getOutputDescriptor(); + List> outputBuffers = request.getOutputBuffers(); + List outputDescList = outputDescriptor.getOutputDescList(); + + int outputNum = outputDescList.size(); + AtomicBoolean[] flags = new AtomicBoolean[outputNum]; + ShuffleConfig shuffleConfig = ShuffleManager.getInstance().getShuffleConfig(); + for (int i = 0; i < outputNum; i++) { + IOutputDesc outputDesc = outputDescList.get(i); + if (outputDesc.getType() == OutputType.RESPONSE) { + continue; + } + ForwardOutputDesc forwardOutputDesc = (ForwardOutputDesc) outputDesc; + IShuffleWriter pipeRecordWriter = ShuffleManager.getInstance().loadShuffleWriter(); + IEncoder encoder = forwardOutputDesc.getEncoder(); + if (encoder != null) { + encoder.init(initEmitterRequest.getConfiguration()); + } + TaskArgs taskArgs = initEmitterRequest.getTaskArgs(); + IWriterContext writerContext = + WriterContext.newBuilder() + .setPipelineId(request.getPipelineId()) + .setPipelineName(request.getPipelineName()) + .setConfig(shuffleConfig) + .setVertexId(forwardOutputDesc.getPartitioner().getOpId()) + .setEdgeId(forwardOutputDesc.getEdgeId()) + .setTaskId(taskArgs.getTaskId()) + .setTaskIndex(taskArgs.getTaskIndex()) + .setTaskName(taskArgs.getTaskName()) + .setChannelNum(forwardOutputDesc.getTargetTaskIndices().size()) + .setEncoder(encoder) + .setDataExchangeMode(forwardOutputDesc.getDataExchangeMode()); + pipeRecordWriter.init(writerContext); + + AtomicBoolean flag = new AtomicBoolean(true); + flags[i] = flag; + String emitterId = + String.format("%d[%d/%d]", taskId, taskArgs.getTaskIndex(), taskArgs.getParallelism()); + EmitterTask emitterTask = + new EmitterTask( + pipeRecordWriter, + outputBuffers.get(i), + flag, + request.getWindowId(), + this.index, + forwardOutputDesc.getEdgeName(), + emitterId); + EMIT_EXECUTOR.execute(emitterTask); } - - public void close(CloseEmitterRequest request) { - int taskId = request.getTaskId(); - this.initRequestCache.remove(taskId); - this.handleRunningFlags(taskId); + this.runningFlags.put(taskId, flags); + } + + public void close(CloseEmitterRequest request) { + int taskId = request.getTaskId(); + this.initRequestCache.remove(taskId); + this.handleRunningFlags(taskId); + } + + public void stash(StashEmitterRequest request) { + this.handleRunningFlags(request.getTaskId()); + } + + public Configuration getConfiguration() { + return this.configuration; + } + + public void clear() { + LOGGER.info("clear emitter cache of task {}", this.initRequestCache.keySet()); + this.initRequestCache.clear(); + } + + private void handleRunningFlags(int taskId) { + if (!this.runningFlags.containsKey(taskId)) { + return; } - - public void stash(StashEmitterRequest request) { - this.handleRunningFlags(request.getTaskId()); + for (AtomicBoolean flag : this.runningFlags.remove(taskId)) { + if (flag != null) { + flag.set(false); + } } - - public Configuration getConfiguration() { - return this.configuration; + } + + private static class EmitterTask implements Runnable { + + private static final String WRITER_NAME_PATTERN = "shuffle-writer-%d-%s"; + + private final IShuffleWriter writer; + private final IOutputMessageBuffer pipe; + private final AtomicBoolean running; + private final long windowId; + private final String name; + private final String emitterId; + private final boolean isMessage; + + public EmitterTask( + IShuffleWriter writer, + IOutputMessageBuffer pipe, + AtomicBoolean running, + long windowId, + int workerIndex, + String edgeName, + String emitterId) { + this.writer = writer; + this.pipe = pipe; + this.running = running; + this.windowId = windowId; + this.name = String.format(WRITER_NAME_PATTERN, workerIndex, edgeName); + this.emitterId = emitterId; + this.isMessage = edgeName.equals(RecordArgs.GraphRecordNames.Message.name()); } - public void clear() { - LOGGER.info("clear emitter cache of task {}", this.initRequestCache.keySet()); - this.initRequestCache.clear(); + @Override + public void run() { + Thread.currentThread().setName(this.name); + try { + this.execute(); + } catch (Throwable t) { + this.pipe.error(t); + LOGGER.error("emitter task err in window id {} {}", this.windowId, this.emitterId, t); + throw new GeaflowRuntimeException(t); + } + LOGGER.info("emitter task finish window id {} {}", this.windowId, this.emitterId); } - private void handleRunningFlags(int taskId) { - if (!this.runningFlags.containsKey(taskId)) { - return; + private void execute() throws Exception { + while (this.running.get()) { + OutputMessage record = this.pipe.poll(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS); + if (record == null) { + continue; } - for (AtomicBoolean flag : this.runningFlags.remove(taskId)) { - if (flag != null) { - flag.set(false); - } + long windowId = record.getWindowId(); + if (record.isBarrier()) { + Optional result = this.writer.flush(windowId); + this.handleMetrics(); + this.pipe.setResult(windowId, result.orElse(null)); + } else { + this.writer.emit(windowId, record.getMessage(), false, record.getTargetChannel()); } + } + this.writer.close(); } - private static class EmitterTask implements Runnable { - - private static final String WRITER_NAME_PATTERN = "shuffle-writer-%d-%s"; - - private final IShuffleWriter writer; - private final IOutputMessageBuffer pipe; - private final AtomicBoolean running; - private final long windowId; - private final String name; - private final String emitterId; - private final boolean isMessage; - - public EmitterTask(IShuffleWriter writer, - IOutputMessageBuffer pipe, - AtomicBoolean running, - long windowId, - int workerIndex, - String edgeName, - String emitterId) { - this.writer = writer; - this.pipe = pipe; - this.running = running; - this.windowId = windowId; - this.name = String.format(WRITER_NAME_PATTERN, workerIndex, edgeName); - this.emitterId = emitterId; - this.isMessage = edgeName.equals(RecordArgs.GraphRecordNames.Message.name()); - } - - @Override - public void run() { - Thread.currentThread().setName(this.name); - try { - this.execute(); - } catch (Throwable t) { - this.pipe.error(t); - LOGGER.error("emitter task err in window id {} {}", this.windowId, this.emitterId, t); - throw new GeaflowRuntimeException(t); - } - LOGGER.info("emitter task finish window id {} {}", this.windowId, this.emitterId); - } - - private void execute() throws Exception { - while (this.running.get()) { - OutputMessage record = this.pipe.poll(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS); - if (record == null) { - continue; - } - long windowId = record.getWindowId(); - if (record.isBarrier()) { - Optional result = this.writer.flush(windowId); - this.handleMetrics(); - this.pipe.setResult(windowId, result.orElse(null)); - } else { - this.writer.emit(windowId, record.getMessage(), false, record.getTargetChannel()); - } - } - this.writer.close(); - } - - @SuppressWarnings("unchecked") - private void handleMetrics() { - ShuffleWriteMetrics shuffleWriteMetrics = this.writer.getShuffleWriteMetrics(); - EventMetrics eventMetrics = ((AbstractMessageBuffer) this.pipe).getEventMetrics(); - if (this.isMessage) { - // When send message, all iteration share the same context and writer, just set the total metric. - eventMetrics.setShuffleWriteRecords(shuffleWriteMetrics.getWrittenRecords()); - eventMetrics.setShuffleWriteBytes(shuffleWriteMetrics.getEncodedSize()); - } else { - // In FINISH iteration or other case, just add output metric. - eventMetrics.addShuffleWriteRecords(shuffleWriteMetrics.getWrittenRecords()); - eventMetrics.addShuffleWriteBytes(shuffleWriteMetrics.getEncodedSize()); - } - - } - + @SuppressWarnings("unchecked") + private void handleMetrics() { + ShuffleWriteMetrics shuffleWriteMetrics = this.writer.getShuffleWriteMetrics(); + EventMetrics eventMetrics = ((AbstractMessageBuffer) this.pipe).getEventMetrics(); + if (this.isMessage) { + // When send message, all iteration share the same context and writer, just set the total + // metric. + eventMetrics.setShuffleWriteRecords(shuffleWriteMetrics.getWrittenRecords()); + eventMetrics.setShuffleWriteBytes(shuffleWriteMetrics.getEncodedSize()); + } else { + // In FINISH iteration or other case, just add output metric. + eventMetrics.addShuffleWriteRecords(shuffleWriteMetrics.getWrittenRecords()); + eventMetrics.addShuffleWriteBytes(shuffleWriteMetrics.getEncodedSize()); + } } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/StashEmitterRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/StashEmitterRequest.java index b32e0d5c1..df205f8f6 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/StashEmitterRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/StashEmitterRequest.java @@ -21,13 +21,12 @@ public class StashEmitterRequest extends AbstractEmitterRequest { - public StashEmitterRequest(int taskId, long windowId) { - super(taskId, windowId); - } - - @Override - public RequestType getRequestType() { - return RequestType.STASH; - } + public StashEmitterRequest(int taskId, long windowId) { + super(taskId, windowId); + } + @Override + public RequestType getRequestType() { + return RequestType.STASH; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/UpdateEmitterRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/UpdateEmitterRequest.java index 41b152c13..11e603001 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/UpdateEmitterRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/collector/UpdateEmitterRequest.java @@ -20,40 +20,41 @@ package org.apache.geaflow.cluster.collector; import java.util.List; + import org.apache.geaflow.shuffle.message.Shard; public class UpdateEmitterRequest extends AbstractEmitterRequest { - private final long pipelineId; - private final String pipelineName; - private final List> outputBuffers; - - public UpdateEmitterRequest(int taskId, - long windowId, - long pipelineId, - String pipelineName, - List> outputBuffers) { - super(taskId, windowId); - this.pipelineId = pipelineId; - this.pipelineName = pipelineName; - this.outputBuffers = outputBuffers; - } - - public long getPipelineId() { - return this.pipelineId; - } - - public String getPipelineName() { - return this.pipelineName; - } - - public List> getOutputBuffers() { - return this.outputBuffers; - } - - @Override - public RequestType getRequestType() { - return RequestType.POP; - } - + private final long pipelineId; + private final String pipelineName; + private final List> outputBuffers; + + public UpdateEmitterRequest( + int taskId, + long windowId, + long pipelineId, + String pipelineName, + List> outputBuffers) { + super(taskId, windowId); + this.pipelineId = pipelineId; + this.pipelineName = pipelineName; + this.outputBuffers = outputBuffers; + } + + public long getPipelineId() { + return this.pipelineId; + } + + public String getPipelineName() { + return this.pipelineName; + } + + public List> getOutputBuffers() { + return this.outputBuffers; + } + + @Override + public RequestType getRequestType() { + return RequestType.POP; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/AbstractComponent.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/AbstractComponent.java index e54364327..7105cc9e7 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/AbstractComponent.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/AbstractComponent.java @@ -36,79 +36,80 @@ public abstract class AbstractComponent { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractComponent.class); - - protected int id; - protected String name; - protected String masterId; - protected int rpcPort; - protected int supervisorPort; - - protected Configuration configuration; - protected IHAService haService; - protected RpcServiceImpl rpcService; - protected MetricGroup metricGroup; - protected MetricGroupRegistry metricGroupRegistry; - - public AbstractComponent() { + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractComponent.class); + + protected int id; + protected String name; + protected String masterId; + protected int rpcPort; + protected int supervisorPort; + + protected Configuration configuration; + protected IHAService haService; + protected RpcServiceImpl rpcService; + protected MetricGroup metricGroup; + protected MetricGroupRegistry metricGroupRegistry; + + public AbstractComponent() {} + + public AbstractComponent(int rpcPort) { + this.rpcPort = rpcPort; + } + + public void init(int id, String name, Configuration configuration) { + this.id = id; + this.name = name; + this.configuration = configuration; + this.masterId = configuration.getMasterId(); + + this.metricGroupRegistry = MetricGroupRegistry.getInstance(configuration); + this.metricGroup = metricGroupRegistry.getMetricGroup(); + this.haService = HAServiceFactory.getService(configuration); + + RpcClient.init(configuration); + ClusterMetaStore.init(id, name, configuration); + StatsCollectorFactory.init(configuration); + + Runtime.getRuntime() + .addShutdownHook( + new Thread( + () -> { + // Use stderr here since the logger may have been reset by its JVM shutdown hook. + LOGGER.warn("*** Shutting ClusterMetaStore since JVM is shutting down."); + ClusterMetaStore.close(); + LOGGER.warn("*** ClusterMetaStore is shutdown."); + })); + } + + protected void registerHAService() { + ResourceData resourceData = buildResourceData(); + LOGGER.info("register {}: {}", name, resourceData); + haService.register(name, resourceData); + } + + protected ResourceData buildResourceData() { + ResourceData resourceData = new ResourceData(); + resourceData.setProcessId(ProcessUtil.getProcessId()); + resourceData.setHost(ProcessUtil.getHostIp()); + resourceData.setRpcPort(rpcPort); + ShuffleManager shuffleManager = ShuffleManager.getInstance(); + if (shuffleManager != null) { + resourceData.setShufflePort(shuffleManager.getShufflePort()); } + return resourceData; + } - public AbstractComponent(int rpcPort) { - this.rpcPort = rpcPort; + public void close() { + if (haService != null) { + haService.close(); } - - public void init(int id, String name, Configuration configuration) { - this.id = id; - this.name = name; - this.configuration = configuration; - this.masterId = configuration.getMasterId(); - - this.metricGroupRegistry = MetricGroupRegistry.getInstance(configuration); - this.metricGroup = metricGroupRegistry.getMetricGroup(); - this.haService = HAServiceFactory.getService(configuration); - - RpcClient.init(configuration); - ClusterMetaStore.init(id, name, configuration); - StatsCollectorFactory.init(configuration); - - Runtime.getRuntime().addShutdownHook(new Thread(() -> { - // Use stderr here since the logger may have been reset by its JVM shutdown hook. - LOGGER.warn("*** Shutting ClusterMetaStore since JVM is shutting down."); - ClusterMetaStore.close(); - LOGGER.warn("*** ClusterMetaStore is shutdown."); - })); - } - - protected void registerHAService() { - ResourceData resourceData = buildResourceData(); - LOGGER.info("register {}: {}", name, resourceData); - haService.register(name, resourceData); - } - - protected ResourceData buildResourceData() { - ResourceData resourceData = new ResourceData(); - resourceData.setProcessId(ProcessUtil.getProcessId()); - resourceData.setHost(ProcessUtil.getHostIp()); - resourceData.setRpcPort(rpcPort); - ShuffleManager shuffleManager = ShuffleManager.getInstance(); - if (shuffleManager != null) { - resourceData.setShufflePort(shuffleManager.getShufflePort()); - } - return resourceData; - } - - public void close() { - if (haService != null) { - haService.close(); - } - if (rpcService != null) { - rpcService.stopService(); - } - ClusterMetaStore.close(); - } - - public void waitTermination() { - rpcService.waitTermination(); + if (rpcService != null) { + rpcService.stopService(); } + ClusterMetaStore.close(); + } + public void waitTermination() { + rpcService.waitTermination(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/AbstractContainer.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/AbstractContainer.java index 224f4085c..1ee52dac7 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/AbstractContainer.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/AbstractContainer.java @@ -35,78 +35,77 @@ public abstract class AbstractContainer extends AbstractComponent { - protected HeartbeatClient heartbeatClient; - protected ExceptionCollectService exceptionCollectService; - protected MetricServer metricServer; - protected int metricPort; - protected int supervisorPort; - protected boolean enableInfer; - - public AbstractContainer(int rpcPort) { - super(rpcPort); + protected HeartbeatClient heartbeatClient; + protected ExceptionCollectService exceptionCollectService; + protected MetricServer metricServer; + protected int metricPort; + protected int supervisorPort; + protected boolean enableInfer; + + public AbstractContainer(int rpcPort) { + super(rpcPort); + } + + @Override + public void init(int id, String name, Configuration configuration) { + super.init(id, name, configuration); + + startRpcService(); + ShuffleManager.init(configuration); + ExceptionClient.init(id, name, masterId); + this.heartbeatClient = new HeartbeatClient(id, name, configuration); + this.exceptionCollectService = new ExceptionCollectService(); + this.metricServer = new MetricServer(configuration); + this.metricPort = metricServer.start(); + this.supervisorPort = configuration.getInteger(SUPERVISOR_RPC_PORT); + this.enableInfer = configuration.getBoolean(FrameworkConfigKeys.INFER_ENV_ENABLE); + initInferEnvironment(configuration); + } + + protected void registerToMaster() { + this.heartbeatClient.init(masterId, buildComponentInfo()); + } + + @Override + protected ResourceData buildResourceData() { + ResourceData resourceData = super.buildResourceData(); + resourceData.setMetricPort(metricPort); + resourceData.setSupervisorPort(supervisorPort); + return resourceData; + } + + protected abstract void startRpcService(); + + protected abstract ComponentInfo buildComponentInfo(); + + protected void fillComponentInfo(ComponentInfo componentInfo) { + componentInfo.setId(id); + componentInfo.setName(name); + componentInfo.setHost(ProcessUtil.getHostIp()); + componentInfo.setPid(ProcessUtil.getProcessId()); + componentInfo.setRpcPort(rpcPort); + componentInfo.setMetricPort(metricPort); + componentInfo.setAgentPort(configuration.getInteger(ExecutionConfigKeys.AGENT_HTTP_PORT)); + } + + public void close() { + super.close(); + if (exceptionCollectService != null) { + exceptionCollectService.shutdown(); } - - @Override - public void init(int id, String name, Configuration configuration) { - super.init(id, name, configuration); - - startRpcService(); - ShuffleManager.init(configuration); - ExceptionClient.init(id, name, masterId); - this.heartbeatClient = new HeartbeatClient(id, name, configuration); - this.exceptionCollectService = new ExceptionCollectService(); - this.metricServer = new MetricServer(configuration); - this.metricPort = metricServer.start(); - this.supervisorPort = configuration.getInteger(SUPERVISOR_RPC_PORT); - this.enableInfer = configuration.getBoolean(FrameworkConfigKeys.INFER_ENV_ENABLE); - initInferEnvironment(configuration); + if (heartbeatClient != null) { + heartbeatClient.close(); } - - protected void registerToMaster() { - this.heartbeatClient.init(masterId, buildComponentInfo()); + if (metricServer != null) { + metricServer.stop(); } + } - @Override - protected ResourceData buildResourceData() { - ResourceData resourceData = super.buildResourceData(); - resourceData.setMetricPort(metricPort); - resourceData.setSupervisorPort(supervisorPort); - return resourceData; + private void initInferEnvironment(Configuration configuration) { + if (enableInfer) { + InferEnvironmentManager inferEnvironmentManager = + InferEnvironmentManager.buildInferEnvironmentManager(configuration); + inferEnvironmentManager.createEnvironment(); } - - protected abstract void startRpcService(); - - protected abstract ComponentInfo buildComponentInfo(); - - protected void fillComponentInfo(ComponentInfo componentInfo) { - componentInfo.setId(id); - componentInfo.setName(name); - componentInfo.setHost(ProcessUtil.getHostIp()); - componentInfo.setPid(ProcessUtil.getProcessId()); - componentInfo.setRpcPort(rpcPort); - componentInfo.setMetricPort(metricPort); - componentInfo.setAgentPort(configuration.getInteger(ExecutionConfigKeys.AGENT_HTTP_PORT)); - } - - public void close() { - super.close(); - if (exceptionCollectService != null) { - exceptionCollectService.shutdown(); - } - if (heartbeatClient != null) { - heartbeatClient.close(); - } - if (metricServer != null) { - metricServer.stop(); - } - } - - private void initInferEnvironment(Configuration configuration) { - if (enableInfer) { - InferEnvironmentManager inferEnvironmentManager = - InferEnvironmentManager.buildInferEnvironmentManager(configuration); - inferEnvironmentManager.createEnvironment(); - } - } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/ComponentInfo.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/ComponentInfo.java index 8ca7dc07c..8c460d1d3 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/ComponentInfo.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/ComponentInfo.java @@ -23,99 +23,90 @@ public class ComponentInfo implements Serializable { - /** - * component id. - */ - protected int id; - /** - * component name. - */ - protected String name; - /** - * host ip. - */ - protected String host; - /** - * process id. - */ - protected int pid; - /** - * rpc service port. - */ - protected int rpcPort; - /** - * metric query port. - */ - protected int metricPort; - /** - * agent service port. - */ - protected int agentPort; - - public ComponentInfo() { - } - - public ComponentInfo(int id, String name, String host, int pid, int rpcPort) { - this.id = id; - this.host = host; - this.pid = pid; - this.rpcPort = rpcPort; - this.name = name; - } - - public int getId() { - return id; - } - - public void setId(int id) { - this.id = id; - } - - public String getName() { - return name; - } - - public void setName(String name) { - this.name = name; - } - - public String getHost() { - return host; - } - - public void setHost(String host) { - this.host = host; - } - - public int getPid() { - return pid; - } - - public void setPid(int pid) { - this.pid = pid; - } - - public int getRpcPort() { - return rpcPort; - } - - public void setRpcPort(int rpcPort) { - this.rpcPort = rpcPort; - } - - public int getMetricPort() { - return metricPort; - } - - public void setMetricPort(int metricPort) { - this.metricPort = metricPort; - } - - public int getAgentPort() { - return agentPort; - } - - public void setAgentPort(int agentPort) { - this.agentPort = agentPort; - } + /** component id. */ + protected int id; + + /** component name. */ + protected String name; + + /** host ip. */ + protected String host; + + /** process id. */ + protected int pid; + + /** rpc service port. */ + protected int rpcPort; + + /** metric query port. */ + protected int metricPort; + + /** agent service port. */ + protected int agentPort; + + public ComponentInfo() {} + + public ComponentInfo(int id, String name, String host, int pid, int rpcPort) { + this.id = id; + this.host = host; + this.pid = pid; + this.rpcPort = rpcPort; + this.name = name; + } + + public int getId() { + return id; + } + + public void setId(int id) { + this.id = id; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getHost() { + return host; + } + + public void setHost(String host) { + this.host = host; + } + + public int getPid() { + return pid; + } + + public void setPid(int pid) { + this.pid = pid; + } + + public int getRpcPort() { + return rpcPort; + } + + public void setRpcPort(int rpcPort) { + this.rpcPort = rpcPort; + } + + public int getMetricPort() { + return metricPort; + } + + public void setMetricPort(int metricPort) { + this.metricPort = metricPort; + } + + public int getAgentPort() { + return agentPort; + } + + public void setAgentPort(int agentPort) { + this.agentPort = agentPort; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/ExecutionIdGenerator.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/ExecutionIdGenerator.java index d14e2d203..fe6a2bc5e 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/ExecutionIdGenerator.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/ExecutionIdGenerator.java @@ -23,31 +23,30 @@ public class ExecutionIdGenerator { - private static volatile ExecutionIdGenerator INSTANCE; + private static volatile ExecutionIdGenerator INSTANCE; - private IdGenerator idGenerator; + private IdGenerator idGenerator; - private ExecutionIdGenerator(int containerId) { - this.idGenerator = new IdGenerator(containerId); - } + private ExecutionIdGenerator(int containerId) { + this.idGenerator = new IdGenerator(containerId); + } - public static ExecutionIdGenerator init(int containerId) { + public static ExecutionIdGenerator init(int containerId) { + if (INSTANCE == null) { + synchronized (ExecutionIdGenerator.class) { if (INSTANCE == null) { - synchronized (ExecutionIdGenerator.class) { - if (INSTANCE == null) { - INSTANCE = new ExecutionIdGenerator(containerId); - } - } + INSTANCE = new ExecutionIdGenerator(containerId); } - return INSTANCE; - } - - public static ExecutionIdGenerator getInstance() { - return INSTANCE; + } } + return INSTANCE; + } - public long generateId() { - return idGenerator.nextId(); - } + public static ExecutionIdGenerator getInstance() { + return INSTANCE; + } + public long generateId() { + return idGenerator.nextId(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/IDispatcher.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/IDispatcher.java index 57d561bcd..925cef55e 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/IDispatcher.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/IDispatcher.java @@ -24,8 +24,6 @@ public interface IDispatcher { - /** - * Dispatch an event to all registered listeners. - */ - void dispatch(IEvent event) throws GeaflowDispatchException; + /** Dispatch an event to all registered listeners. */ + void dispatch(IEvent event) throws GeaflowDispatchException; } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/IEventListener.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/IEventListener.java index 18a55cebe..76e324cba 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/IEventListener.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/IEventListener.java @@ -23,8 +23,6 @@ public interface IEventListener { - /** - * Handle input event if the listener register to driver. - */ - void handleEvent(IEvent event); + /** Handle input event if the listener register to driver. */ + void handleEvent(IEvent event); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/IEventProcessor.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/IEventProcessor.java index 77550df45..3848b6032 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/IEventProcessor.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/IEventProcessor.java @@ -23,9 +23,6 @@ public interface IEventProcessor extends Serializable { - /** - * Execute an input event. - */ - O process(I input); - + /** Execute an input event. */ + O process(I input); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/IReliableContext.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/IReliableContext.java index 704da1d37..dbb22f582 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/IReliableContext.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/IReliableContext.java @@ -23,22 +23,15 @@ public interface IReliableContext extends Serializable { - /** - * Execute function doCheckpoint interface. - */ - void checkpoint(IReliableContextCheckpointFunction function); + /** Execute function doCheckpoint interface. */ + void checkpoint(IReliableContextCheckpointFunction function); - /** - * Load events from meta store. - */ - void load(); + /** Load events from meta store. */ + void load(); - interface IReliableContextCheckpointFunction { - - /** - * Do checkpoint based on context. - */ - void doCheckpoint(IReliableContext context); - } + interface IReliableContextCheckpointFunction { + /** Do checkpoint based on context. */ + void doCheckpoint(IReliableContext context); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/ReliableContainerContext.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/ReliableContainerContext.java index 5ca472db7..2c8a1463e 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/ReliableContainerContext.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/common/ReliableContainerContext.java @@ -20,39 +20,41 @@ package org.apache.geaflow.cluster.common; import java.io.Serializable; + import org.apache.geaflow.common.config.Configuration; public abstract class ReliableContainerContext implements IReliableContext, Serializable { - protected final int id; - protected final String name; - protected final Configuration config; - protected boolean isRecover; - - public ReliableContainerContext(int id, String name, Configuration config) { - this.id = id; - this.name = name; - this.config = config; - } - - public int getId() { - return id; - } - - public String getName() { - return name; - } - - public Configuration getConfig() { - return config; - } - - public boolean isRecover() { - return isRecover; - } - - @Override - public synchronized void checkpoint(IReliableContext.IReliableContextCheckpointFunction function) { - function.doCheckpoint(this); - } + protected final int id; + protected final String name; + protected final Configuration config; + protected boolean isRecover; + + public ReliableContainerContext(int id, String name, Configuration config) { + this.id = id; + this.name = name; + this.config = config; + } + + public int getId() { + return id; + } + + public String getName() { + return name; + } + + public Configuration getConfig() { + return config; + } + + public boolean isRecover() { + return isRecover; + } + + @Override + public synchronized void checkpoint( + IReliableContext.IReliableContextCheckpointFunction function) { + function.doCheckpoint(this); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/config/ClusterConfig.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/config/ClusterConfig.java index 7f669dd36..08ae55334 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/config/ClusterConfig.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/config/ClusterConfig.java @@ -45,311 +45,326 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.MASTER_VCORES; import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.SUPERVISOR_JVM_OPTIONS; -import com.google.common.base.Preconditions; import java.io.Serializable; + import org.apache.geaflow.cluster.client.utils.PipelineUtil; import org.apache.geaflow.cluster.failover.FailoverStrategyType; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; -public class ClusterConfig implements Serializable { - - private static final double DEFAULT_HEAP_FRACTION = 0.8; - - private int containerNum; - private int containerMemoryMB; - private int containerDiskGB; - private double containerVcores; - private int containerWorkerNum; - private ClusterJvmOptions containerJvmOptions; - - private int masterMemoryMB; - private int masterDiskGB; - private double masterVcores; - private ClusterJvmOptions masterJvmOptions; - - private int driverNum; - private int driverMemoryMB; - private int driverDiskGB; - private double driverVcores; - private ClusterJvmOptions driverJvmOptions; - - private int clientMemoryMB; - private int clientDiskGB; - private double clientVcores; - private ClusterJvmOptions clientJvmOptions; - - private boolean isFoEnable; - private int maxRestarts; - private Configuration config; - private ClusterJvmOptions supervisorJvmOptions; - - public static ClusterConfig build(Configuration config) { - ClusterConfig clusterConfig = new ClusterConfig(); - - clusterConfig.setMasterMemoryMB(config.getInteger(MASTER_MEMORY_MB)); - clusterConfig.setMasterDiskGB(config.getInteger(MASTER_DISK_GB)); - ClusterJvmOptions masterJvmOptions = ClusterJvmOptions.build( - config.getString(MASTER_JVM_OPTIONS)); - clusterConfig.setMasterJvmOptions(masterJvmOptions); - clusterConfig.setMasterVcores(config.getDouble(MASTER_VCORES)); - - ClusterJvmOptions clientJvmOptions = ClusterJvmOptions.build( - config.getString(CLIENT_JVM_OPTIONS)); - clusterConfig.setClientJvmOptions(clientJvmOptions); - clusterConfig.setClientVcores(config.getDouble(CLIENT_VCORES)); - clusterConfig.setClientMemoryMB(config.getInteger(CLIENT_MEMORY_MB)); - clusterConfig.setClientDiskGB(config.getInteger(CLIENT_DISK_GB)); - - int driverMB = config.getInteger(DRIVER_MEMORY_MB); - clusterConfig.setDriverMemoryMB(driverMB); - int driverDiskGB = config.getInteger(DRIVER_DISK_GB); - clusterConfig.setDriverDiskGB(driverDiskGB); - ClusterJvmOptions driverJvmOptions = ClusterJvmOptions.build( - config.getString(DRIVER_JVM_OPTION)); - clusterConfig.setDriverJvmOptions(driverJvmOptions); - clusterConfig.setDriverVcores(config.getDouble(DRIVER_VCORES)); - - int driverNum = config.getInteger(DRIVER_NUM); - Preconditions.checkArgument( - driverNum == 1 || driverNum > 1 && PipelineUtil.isAsync(config), - "only one driver is allowed in no-share mode"); - clusterConfig.setDriverNum(driverNum); - - clusterConfig.setContainerMemoryMB(config.getInteger(CONTAINER_MEMORY_MB)); - clusterConfig.setContainerDiskGB(config.getInteger(CONTAINER_DISK_GB)); - clusterConfig.setContainerVcores(config.getDouble(CONTAINER_VCORES)); - int workersPerContainer = config.getInteger(CONTAINER_WORKER_NUM); - clusterConfig.setContainerWorkerNum(workersPerContainer); - - int containerNum = config.getInteger(CONTAINER_NUM); - clusterConfig.setContainerNum(containerNum); - - ClusterJvmOptions containerJvmOptions; - if (config.contains(CONTAINER_JVM_OPTION)) { - containerJvmOptions = ClusterJvmOptions.build(config.getString(CONTAINER_JVM_OPTION)); - } else { - containerJvmOptions = new ClusterJvmOptions(); - containerJvmOptions.setMaxHeapMB((int) (driverMB * DEFAULT_HEAP_FRACTION)); - } - clusterConfig.setContainerJvmOptions(containerJvmOptions); - config.put(CONTAINER_HEAP_SIZE_MB, String.valueOf(containerJvmOptions.getMaxHeapMB())); - - ClusterJvmOptions supervisorJvmOptions = - ClusterJvmOptions.build(config.getString(SUPERVISOR_JVM_OPTIONS)); - clusterConfig.setSupervisorJvmOptions(supervisorJvmOptions); - - boolean isFoEnabled = config.getBoolean(FO_ENABLE); - clusterConfig.setFoEnable(isFoEnabled); - clusterConfig.setMaxRestarts(config.getInteger(FO_MAX_RESTARTS)); - - FailoverStrategyType strategyType = - FailoverStrategyType.valueOf(config.getString(FO_STRATEGY)); - if (!isFoEnabled || strategyType == FailoverStrategyType.disable_fo) { - clusterConfig.setMaxRestarts(0); - config.put(ExecutionConfigKeys.FO_STRATEGY, FailoverStrategyType.disable_fo.name()); - } - clusterConfig.setConfig(config); - - return clusterConfig; - } - - public int getContainerNum() { - return containerNum; - } - - public void setContainerNum(int containerNum) { - this.containerNum = containerNum; - } - - public int getContainerMemoryMB() { - return containerMemoryMB; - } - - public void setContainerMemoryMB(int containerMemoryMB) { - this.containerMemoryMB = containerMemoryMB; - } - - public int getContainerDiskGB() { - return containerDiskGB; - } - - public void setContainerDiskGB(int containerDiskGB) { - this.containerDiskGB = containerDiskGB; - } - - public int getContainerWorkerNum() { - return containerWorkerNum; - } - - public void setContainerWorkerNum(int containerWorkerNum) { - this.containerWorkerNum = containerWorkerNum; - } - - public double getContainerVcores() { - return containerVcores; - } - - public void setContainerVcores(double containerVcores) { - this.containerVcores = containerVcores; - } - - public int getMasterMemoryMB() { - return masterMemoryMB; - } - - public void setMasterMemoryMB(int masterMemoryMB) { - this.masterMemoryMB = masterMemoryMB; - } - - public int getDriverMemoryMB() { - return driverMemoryMB; - } - - public void setDriverMemoryMB(int driverMemoryMB) { - this.driverMemoryMB = driverMemoryMB; - } - - public int getMasterDiskGB() { - return masterDiskGB; - } - - public void setMasterDiskGB(int masterDiskGB) { - this.masterDiskGB = masterDiskGB; - } - - public int getDriverDiskGB() { - return driverDiskGB; - } - - public void setDriverDiskGB(int driverDiskGB) { - this.driverDiskGB = driverDiskGB; - } - - public ClusterJvmOptions getDriverJvmOptions() { - return driverJvmOptions; - } - - public void setDriverJvmOptions(ClusterJvmOptions driverJvmOptions) { - this.driverJvmOptions = driverJvmOptions; - } - - public ClusterJvmOptions getMasterJvmOptions() { - return masterJvmOptions; - } - - public void setMasterJvmOptions(ClusterJvmOptions masterJvmOptions) { - this.masterJvmOptions = masterJvmOptions; - } - - public double getMasterVcores() { - return masterVcores; - } - - public void setMasterVcores(double masterVcores) { - this.masterVcores = masterVcores; - } - - public double getDriverVcores() { - return driverVcores; - } - - public void setDriverVcores(double driverVcores) { - this.driverVcores = driverVcores; - } - - public Configuration getConfig() { - return config; - } - - public void setConfig(Configuration config) { - this.config = config; - } - - public boolean isFoEnable() { - return isFoEnable; - } - - public void setFoEnable(boolean isFoEnable) { - this.isFoEnable = isFoEnable; - } - - public int getMaxRestarts() { - return maxRestarts; - } - - public void setMaxRestarts(int maxRestarts) { - this.maxRestarts = maxRestarts; - } - - public int getDriverNum() { - return driverNum; - } - - public void setDriverNum(int driverNum) { - this.driverNum = driverNum; - } - - public ClusterJvmOptions getContainerJvmOptions() { - return containerJvmOptions; - } - - public void setContainerJvmOptions(ClusterJvmOptions containerJvmOptions) { - this.containerJvmOptions = containerJvmOptions; - } - - public ClusterJvmOptions getClientJvmOptions() { - return clientJvmOptions; - } +import com.google.common.base.Preconditions; - public void setClientJvmOptions(ClusterJvmOptions clientJvmOptions) { - this.clientJvmOptions = clientJvmOptions; - } +public class ClusterConfig implements Serializable { - public double getClientVcores() { - return clientVcores; - } + private static final double DEFAULT_HEAP_FRACTION = 0.8; + + private int containerNum; + private int containerMemoryMB; + private int containerDiskGB; + private double containerVcores; + private int containerWorkerNum; + private ClusterJvmOptions containerJvmOptions; + + private int masterMemoryMB; + private int masterDiskGB; + private double masterVcores; + private ClusterJvmOptions masterJvmOptions; + + private int driverNum; + private int driverMemoryMB; + private int driverDiskGB; + private double driverVcores; + private ClusterJvmOptions driverJvmOptions; + + private int clientMemoryMB; + private int clientDiskGB; + private double clientVcores; + private ClusterJvmOptions clientJvmOptions; + + private boolean isFoEnable; + private int maxRestarts; + private Configuration config; + private ClusterJvmOptions supervisorJvmOptions; + + public static ClusterConfig build(Configuration config) { + ClusterConfig clusterConfig = new ClusterConfig(); + + clusterConfig.setMasterMemoryMB(config.getInteger(MASTER_MEMORY_MB)); + clusterConfig.setMasterDiskGB(config.getInteger(MASTER_DISK_GB)); + ClusterJvmOptions masterJvmOptions = + ClusterJvmOptions.build(config.getString(MASTER_JVM_OPTIONS)); + clusterConfig.setMasterJvmOptions(masterJvmOptions); + clusterConfig.setMasterVcores(config.getDouble(MASTER_VCORES)); + + ClusterJvmOptions clientJvmOptions = + ClusterJvmOptions.build(config.getString(CLIENT_JVM_OPTIONS)); + clusterConfig.setClientJvmOptions(clientJvmOptions); + clusterConfig.setClientVcores(config.getDouble(CLIENT_VCORES)); + clusterConfig.setClientMemoryMB(config.getInteger(CLIENT_MEMORY_MB)); + clusterConfig.setClientDiskGB(config.getInteger(CLIENT_DISK_GB)); + + int driverMB = config.getInteger(DRIVER_MEMORY_MB); + clusterConfig.setDriverMemoryMB(driverMB); + int driverDiskGB = config.getInteger(DRIVER_DISK_GB); + clusterConfig.setDriverDiskGB(driverDiskGB); + ClusterJvmOptions driverJvmOptions = + ClusterJvmOptions.build(config.getString(DRIVER_JVM_OPTION)); + clusterConfig.setDriverJvmOptions(driverJvmOptions); + clusterConfig.setDriverVcores(config.getDouble(DRIVER_VCORES)); + + int driverNum = config.getInteger(DRIVER_NUM); + Preconditions.checkArgument( + driverNum == 1 || driverNum > 1 && PipelineUtil.isAsync(config), + "only one driver is allowed in no-share mode"); + clusterConfig.setDriverNum(driverNum); + + clusterConfig.setContainerMemoryMB(config.getInteger(CONTAINER_MEMORY_MB)); + clusterConfig.setContainerDiskGB(config.getInteger(CONTAINER_DISK_GB)); + clusterConfig.setContainerVcores(config.getDouble(CONTAINER_VCORES)); + int workersPerContainer = config.getInteger(CONTAINER_WORKER_NUM); + clusterConfig.setContainerWorkerNum(workersPerContainer); + + int containerNum = config.getInteger(CONTAINER_NUM); + clusterConfig.setContainerNum(containerNum); + + ClusterJvmOptions containerJvmOptions; + if (config.contains(CONTAINER_JVM_OPTION)) { + containerJvmOptions = ClusterJvmOptions.build(config.getString(CONTAINER_JVM_OPTION)); + } else { + containerJvmOptions = new ClusterJvmOptions(); + containerJvmOptions.setMaxHeapMB((int) (driverMB * DEFAULT_HEAP_FRACTION)); + } + clusterConfig.setContainerJvmOptions(containerJvmOptions); + config.put(CONTAINER_HEAP_SIZE_MB, String.valueOf(containerJvmOptions.getMaxHeapMB())); + + ClusterJvmOptions supervisorJvmOptions = + ClusterJvmOptions.build(config.getString(SUPERVISOR_JVM_OPTIONS)); + clusterConfig.setSupervisorJvmOptions(supervisorJvmOptions); + + boolean isFoEnabled = config.getBoolean(FO_ENABLE); + clusterConfig.setFoEnable(isFoEnabled); + clusterConfig.setMaxRestarts(config.getInteger(FO_MAX_RESTARTS)); + + FailoverStrategyType strategyType = FailoverStrategyType.valueOf(config.getString(FO_STRATEGY)); + if (!isFoEnabled || strategyType == FailoverStrategyType.disable_fo) { + clusterConfig.setMaxRestarts(0); + config.put(ExecutionConfigKeys.FO_STRATEGY, FailoverStrategyType.disable_fo.name()); + } + clusterConfig.setConfig(config); + + return clusterConfig; + } + + public int getContainerNum() { + return containerNum; + } + + public void setContainerNum(int containerNum) { + this.containerNum = containerNum; + } + + public int getContainerMemoryMB() { + return containerMemoryMB; + } + + public void setContainerMemoryMB(int containerMemoryMB) { + this.containerMemoryMB = containerMemoryMB; + } + + public int getContainerDiskGB() { + return containerDiskGB; + } + + public void setContainerDiskGB(int containerDiskGB) { + this.containerDiskGB = containerDiskGB; + } + + public int getContainerWorkerNum() { + return containerWorkerNum; + } + + public void setContainerWorkerNum(int containerWorkerNum) { + this.containerWorkerNum = containerWorkerNum; + } + + public double getContainerVcores() { + return containerVcores; + } + + public void setContainerVcores(double containerVcores) { + this.containerVcores = containerVcores; + } + + public int getMasterMemoryMB() { + return masterMemoryMB; + } + + public void setMasterMemoryMB(int masterMemoryMB) { + this.masterMemoryMB = masterMemoryMB; + } + + public int getDriverMemoryMB() { + return driverMemoryMB; + } + + public void setDriverMemoryMB(int driverMemoryMB) { + this.driverMemoryMB = driverMemoryMB; + } - public void setClientVcores(double clientVcores) { - this.clientVcores = clientVcores; - } + public int getMasterDiskGB() { + return masterDiskGB; + } - public int getClientMemoryMB() { - return clientMemoryMB; - } + public void setMasterDiskGB(int masterDiskGB) { + this.masterDiskGB = masterDiskGB; + } - public void setClientMemoryMB(int clientMemoryMB) { - this.clientMemoryMB = clientMemoryMB; - } + public int getDriverDiskGB() { + return driverDiskGB; + } - public int getClientDiskGB() { - return clientDiskGB; - } + public void setDriverDiskGB(int driverDiskGB) { + this.driverDiskGB = driverDiskGB; + } - public void setClientDiskGB(int clientDiskGB) { - this.clientDiskGB = clientDiskGB; - } + public ClusterJvmOptions getDriverJvmOptions() { + return driverJvmOptions; + } - public int getDriverRegisterTimeoutSec() { - return config.getInteger(FO_TIMEOUT_MS) / 1000; - } + public void setDriverJvmOptions(ClusterJvmOptions driverJvmOptions) { + this.driverJvmOptions = driverJvmOptions; + } - public ClusterJvmOptions getSupervisorJvmOptions() { - return supervisorJvmOptions; - } + public ClusterJvmOptions getMasterJvmOptions() { + return masterJvmOptions; + } - public void setSupervisorJvmOptions(ClusterJvmOptions supervisorJvmOptions) { - this.supervisorJvmOptions = supervisorJvmOptions; - } + public void setMasterJvmOptions(ClusterJvmOptions masterJvmOptions) { + this.masterJvmOptions = masterJvmOptions; + } + public double getMasterVcores() { + return masterVcores; + } - @Override - public String toString() { - return "ClusterConfig{" + "containerNum=" + containerNum + ", containerMemoryMB=" - + containerMemoryMB + ", containerWorkers=" + containerWorkerNum + ", " - + "containerJvmOptions=" + containerJvmOptions + ", masterMemoryMB=" + masterMemoryMB - + ", masterJvmOptions=" + masterJvmOptions + ", driverMemoryMB=" + driverMemoryMB - + ", driverJvmOptions=" + driverJvmOptions + ", restartAllFo=" - + isFoEnable + '}'; - } + public void setMasterVcores(double masterVcores) { + this.masterVcores = masterVcores; + } + + public double getDriverVcores() { + return driverVcores; + } + + public void setDriverVcores(double driverVcores) { + this.driverVcores = driverVcores; + } + + public Configuration getConfig() { + return config; + } + + public void setConfig(Configuration config) { + this.config = config; + } + + public boolean isFoEnable() { + return isFoEnable; + } + + public void setFoEnable(boolean isFoEnable) { + this.isFoEnable = isFoEnable; + } + + public int getMaxRestarts() { + return maxRestarts; + } + + public void setMaxRestarts(int maxRestarts) { + this.maxRestarts = maxRestarts; + } + + public int getDriverNum() { + return driverNum; + } + + public void setDriverNum(int driverNum) { + this.driverNum = driverNum; + } + + public ClusterJvmOptions getContainerJvmOptions() { + return containerJvmOptions; + } + + public void setContainerJvmOptions(ClusterJvmOptions containerJvmOptions) { + this.containerJvmOptions = containerJvmOptions; + } + + public ClusterJvmOptions getClientJvmOptions() { + return clientJvmOptions; + } + + public void setClientJvmOptions(ClusterJvmOptions clientJvmOptions) { + this.clientJvmOptions = clientJvmOptions; + } + + public double getClientVcores() { + return clientVcores; + } + + public void setClientVcores(double clientVcores) { + this.clientVcores = clientVcores; + } + + public int getClientMemoryMB() { + return clientMemoryMB; + } + + public void setClientMemoryMB(int clientMemoryMB) { + this.clientMemoryMB = clientMemoryMB; + } + + public int getClientDiskGB() { + return clientDiskGB; + } + + public void setClientDiskGB(int clientDiskGB) { + this.clientDiskGB = clientDiskGB; + } + + public int getDriverRegisterTimeoutSec() { + return config.getInteger(FO_TIMEOUT_MS) / 1000; + } + + public ClusterJvmOptions getSupervisorJvmOptions() { + return supervisorJvmOptions; + } + + public void setSupervisorJvmOptions(ClusterJvmOptions supervisorJvmOptions) { + this.supervisorJvmOptions = supervisorJvmOptions; + } + + @Override + public String toString() { + return "ClusterConfig{" + + "containerNum=" + + containerNum + + ", containerMemoryMB=" + + containerMemoryMB + + ", containerWorkers=" + + containerWorkerNum + + ", " + + "containerJvmOptions=" + + containerJvmOptions + + ", masterMemoryMB=" + + masterMemoryMB + + ", masterJvmOptions=" + + masterJvmOptions + + ", driverMemoryMB=" + + driverMemoryMB + + ", driverJvmOptions=" + + driverJvmOptions + + ", restartAllFo=" + + isFoEnable + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/config/ClusterJvmOptions.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/config/ClusterJvmOptions.java index ddd29e67c..0ce6e409a 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/config/ClusterJvmOptions.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/config/ClusterJvmOptions.java @@ -19,216 +19,236 @@ package org.apache.geaflow.cluster.config; -import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; import java.util.List; -import org.apache.commons.lang3.StringUtils; - -public class ClusterJvmOptions { - - private static final String XMX_OPTION = "-Xmx"; - private static final String XMS_OPTION = "-Xms"; - private static final String XMN_OPTION = "-Xmn"; - private static final String XSS_OPTION = "-Xss"; - private static final String MAX_DIRECT_MEM_SIZE = "MaxDirectMemorySize"; - private static final String CLUSTER_WORKER_JVM_ARG_PATTERN = - "(-XX:\\+(HeapDumpOnOutOfMemoryError|CrashOnOutOfMemoryError|UseG1GC))" - + "|(-XX:MaxDirectMemorySize=(\\d+)(m|g|M|G))" - + "|(-XX:(HeapDumpPath|ErrorFile)=(.+))" - + "|(-(Xmx|Xms|Xmn|Xss)(\\d+)(k|m|g|M|G))" - + "|(-Dray\\.logging\\.max\\.log\\.file\\.num=(\\d)+)" - + "|(-Dray\\.logging\\.level=(INFO|DEBUG|WARN|ERROR))" - + "|(-Dray\\.logging\\.max\\.log\\.file\\.size=(\\d)+(MB|GB))" - + "|(-Dray\\.task\\.return_task_exception=(false|true))"; - - private int xmsMB = 0; - private int xmnMB = 0; - private int maxHeapMB; - private int maxDirectMB; - private final List jvmOptions = new ArrayList<>(); - private final List extraOptions = new ArrayList<>(); - - public int getXmsMB() { - return xmsMB; - } - - public int getXmnMB() { - return xmnMB; - } - - public int getMaxHeapMB() { - return maxHeapMB; - } - - public void setMaxHeapMB(int maxHeapMB) { - this.maxHeapMB = maxHeapMB; - } - - public int getMaxDirectMB() { - return maxDirectMB; - } - - public List getJvmOptions() { - return jvmOptions; - } - - public List getExtraOptions() { - return extraOptions; - } - - public static ClusterJvmOptions build(List options) { - return parseJvmOptions(options.iterator()); - } - - public static ClusterJvmOptions build(String jvmArgs) { - if (StringUtils.isBlank(jvmArgs)) { - return null; - } - - String[] args = jvmArgs.trim().split("\\s*,\\s*"); - Iterator iterator = Arrays.stream(args).iterator(); - - return parseJvmOptions(iterator); - } - private static ClusterJvmOptions parseJvmOptions(Iterator args) { - String jvmArgPattern = CLUSTER_WORKER_JVM_ARG_PATTERN; - ClusterJvmOptions jvmOptions = new ClusterJvmOptions(); - - while (args.hasNext()) { - String jvmArg = args.next(); - if (StringUtils.isBlank(jvmArg)) { - continue; - } - - if (!jvmArg.matches(jvmArgPattern)) { - throw new IllegalArgumentException( - String.format("jvm arg %s not match the pattern %s", jvmArg, jvmArgPattern)); - } - - if (jvmArg.startsWith("-XX:")) { - jvmOptions.parseNonStableOption(jvmArg); - } else if (jvmArg.startsWith("-X")) { - jvmOptions.parseNonStandardOption(jvmArg); - } else if (jvmArg.startsWith("-D")) { - jvmOptions.parseSystemOption(jvmArg); - } else { - throw new RuntimeException("not support jvm option " + jvmArg + " yet"); - } - } - - return jvmOptions; - } - - /** - * parse -X options. - * - * @param jvmArg - */ - private void parseNonStandardOption(String jvmArg) { - if (jvmArg.startsWith(XMX_OPTION)) { - this.maxHeapMB = parseMemoryOptionToMB(jvmArg, XMX_OPTION); - } else if (jvmArg.startsWith(XMS_OPTION)) { - this.xmsMB = parseMemoryOptionToMB(jvmArg, XMS_OPTION); - } else if (jvmArg.startsWith(XMN_OPTION)) { - this.xmnMB = parseMemoryOptionToMB(jvmArg, XMN_OPTION); - } else if (jvmArg.startsWith(XSS_OPTION)) { - //this.xss = parseMemoryOptionToMB(jvmArg, XSS_OPTION); - this.extraOptions.add(jvmArg); - } else { - throw new RuntimeException("not support -X option " + jvmArg + " yet"); - } - } - - /** - * parse -D options. - * - * @param jvmArg - */ - private void parseSystemOption(String jvmArg) { - String option = jvmArg.substring(2); - if (StringUtils.isBlank(option)) { - throw new IllegalArgumentException("invalid jvm option " + jvmArg); - } - - this.jvmOptions.add(jvmArg); - } - - /** - * parse -XX: options. - * - * @param jvmArg - */ - private void parseNonStableOption(String jvmArg) { - String option = jvmArg.substring(4); - if (option.startsWith(MAX_DIRECT_MEM_SIZE)) { - Preconditions.checkArgument(option.length() > MAX_DIRECT_MEM_SIZE.length() + 1, - "the jvm option %s is too short", jvmArg); - - String size = option.substring(MAX_DIRECT_MEM_SIZE.length() + 1); - this.maxDirectMB = convertMemoryToMB(size); - } else { - this.extraOptions.add(jvmArg); - } - this.jvmOptions.add(jvmArg); - } - - /** - * parse -Xmx, -Xms, -Xmn, -Xss option to MB value. - * - * @param memArg - * @return - */ - private int parseMemoryOptionToMB(String memArg, String prefix) { - Preconditions.checkArgument(memArg.length() > prefix.length(), - "invalid memory argument %s for option %s", memArg, prefix); - String memory = memArg.substring(prefix.length()); - int size = convertMemoryToMB(memory); - Preconditions.checkArgument(size > 0, - "memory size should greater than 0m while current is %s", size); - this.jvmOptions.add(prefix + size + "m"); - return size; - } +import org.apache.commons.lang3.StringUtils; - private int convertMemoryToMB(String memory) { - int size; - if (memory.endsWith("G") || memory.endsWith("g") - || memory.endsWith("M") || memory.endsWith("m") - || memory.endsWith("K") || memory.endsWith("k")) { - char memoryUnit = memory.charAt(memory.length() - 1); - int memorySize = Integer.valueOf(memory.substring(0, memory.length() - 1)); - switch (memoryUnit) { - case 'M': - case 'm': - size = memorySize; - break; - case 'G': - case 'g': - size = memorySize * 1024; - break; - case 'K': - case 'k': - size = memorySize / 1024; - break; - default: - throw new IllegalArgumentException("invalid memory size " + memory); - } - } else { - int memorySize = Integer.parseInt(memory); - size = memorySize / 1024 / 1024; - } - Preconditions.checkArgument(size > 0, - "memory size should greater than 0m while current is %s", size); - return size; - } +import com.google.common.base.Preconditions; - @Override - public String toString() { - return "ClusterJvmOptions{" + "xmsMB=" + xmsMB + ", xmnMB=" + xmnMB + ", maxHeapMB=" - + maxHeapMB + ", maxDirectMB=" + maxDirectMB + ", jvmOptions=" + jvmOptions - + ", extraOptions=" + extraOptions + '}'; - } +public class ClusterJvmOptions { + private static final String XMX_OPTION = "-Xmx"; + private static final String XMS_OPTION = "-Xms"; + private static final String XMN_OPTION = "-Xmn"; + private static final String XSS_OPTION = "-Xss"; + private static final String MAX_DIRECT_MEM_SIZE = "MaxDirectMemorySize"; + private static final String CLUSTER_WORKER_JVM_ARG_PATTERN = + "(-XX:\\+(HeapDumpOnOutOfMemoryError|CrashOnOutOfMemoryError|UseG1GC))" + + "|(-XX:MaxDirectMemorySize=(\\d+)(m|g|M|G))" + + "|(-XX:(HeapDumpPath|ErrorFile)=(.+))" + + "|(-(Xmx|Xms|Xmn|Xss)(\\d+)(k|m|g|M|G))" + + "|(-Dray\\.logging\\.max\\.log\\.file\\.num=(\\d)+)" + + "|(-Dray\\.logging\\.level=(INFO|DEBUG|WARN|ERROR))" + + "|(-Dray\\.logging\\.max\\.log\\.file\\.size=(\\d)+(MB|GB))" + + "|(-Dray\\.task\\.return_task_exception=(false|true))"; + + private int xmsMB = 0; + private int xmnMB = 0; + private int maxHeapMB; + private int maxDirectMB; + private final List jvmOptions = new ArrayList<>(); + private final List extraOptions = new ArrayList<>(); + + public int getXmsMB() { + return xmsMB; + } + + public int getXmnMB() { + return xmnMB; + } + + public int getMaxHeapMB() { + return maxHeapMB; + } + + public void setMaxHeapMB(int maxHeapMB) { + this.maxHeapMB = maxHeapMB; + } + + public int getMaxDirectMB() { + return maxDirectMB; + } + + public List getJvmOptions() { + return jvmOptions; + } + + public List getExtraOptions() { + return extraOptions; + } + + public static ClusterJvmOptions build(List options) { + return parseJvmOptions(options.iterator()); + } + + public static ClusterJvmOptions build(String jvmArgs) { + if (StringUtils.isBlank(jvmArgs)) { + return null; + } + + String[] args = jvmArgs.trim().split("\\s*,\\s*"); + Iterator iterator = Arrays.stream(args).iterator(); + + return parseJvmOptions(iterator); + } + + private static ClusterJvmOptions parseJvmOptions(Iterator args) { + String jvmArgPattern = CLUSTER_WORKER_JVM_ARG_PATTERN; + ClusterJvmOptions jvmOptions = new ClusterJvmOptions(); + + while (args.hasNext()) { + String jvmArg = args.next(); + if (StringUtils.isBlank(jvmArg)) { + continue; + } + + if (!jvmArg.matches(jvmArgPattern)) { + throw new IllegalArgumentException( + String.format("jvm arg %s not match the pattern %s", jvmArg, jvmArgPattern)); + } + + if (jvmArg.startsWith("-XX:")) { + jvmOptions.parseNonStableOption(jvmArg); + } else if (jvmArg.startsWith("-X")) { + jvmOptions.parseNonStandardOption(jvmArg); + } else if (jvmArg.startsWith("-D")) { + jvmOptions.parseSystemOption(jvmArg); + } else { + throw new RuntimeException("not support jvm option " + jvmArg + " yet"); + } + } + + return jvmOptions; + } + + /** + * parse -X options. + * + * @param jvmArg + */ + private void parseNonStandardOption(String jvmArg) { + if (jvmArg.startsWith(XMX_OPTION)) { + this.maxHeapMB = parseMemoryOptionToMB(jvmArg, XMX_OPTION); + } else if (jvmArg.startsWith(XMS_OPTION)) { + this.xmsMB = parseMemoryOptionToMB(jvmArg, XMS_OPTION); + } else if (jvmArg.startsWith(XMN_OPTION)) { + this.xmnMB = parseMemoryOptionToMB(jvmArg, XMN_OPTION); + } else if (jvmArg.startsWith(XSS_OPTION)) { + // this.xss = parseMemoryOptionToMB(jvmArg, XSS_OPTION); + this.extraOptions.add(jvmArg); + } else { + throw new RuntimeException("not support -X option " + jvmArg + " yet"); + } + } + + /** + * parse -D options. + * + * @param jvmArg + */ + private void parseSystemOption(String jvmArg) { + String option = jvmArg.substring(2); + if (StringUtils.isBlank(option)) { + throw new IllegalArgumentException("invalid jvm option " + jvmArg); + } + + this.jvmOptions.add(jvmArg); + } + + /** + * parse -XX: options. + * + * @param jvmArg + */ + private void parseNonStableOption(String jvmArg) { + String option = jvmArg.substring(4); + if (option.startsWith(MAX_DIRECT_MEM_SIZE)) { + Preconditions.checkArgument( + option.length() > MAX_DIRECT_MEM_SIZE.length() + 1, + "the jvm option %s is too short", + jvmArg); + + String size = option.substring(MAX_DIRECT_MEM_SIZE.length() + 1); + this.maxDirectMB = convertMemoryToMB(size); + } else { + this.extraOptions.add(jvmArg); + } + this.jvmOptions.add(jvmArg); + } + + /** + * parse -Xmx, -Xms, -Xmn, -Xss option to MB value. + * + * @param memArg + * @return + */ + private int parseMemoryOptionToMB(String memArg, String prefix) { + Preconditions.checkArgument( + memArg.length() > prefix.length(), + "invalid memory argument %s for option %s", + memArg, + prefix); + String memory = memArg.substring(prefix.length()); + int size = convertMemoryToMB(memory); + Preconditions.checkArgument( + size > 0, "memory size should greater than 0m while current is %s", size); + this.jvmOptions.add(prefix + size + "m"); + return size; + } + + private int convertMemoryToMB(String memory) { + int size; + if (memory.endsWith("G") + || memory.endsWith("g") + || memory.endsWith("M") + || memory.endsWith("m") + || memory.endsWith("K") + || memory.endsWith("k")) { + char memoryUnit = memory.charAt(memory.length() - 1); + int memorySize = Integer.valueOf(memory.substring(0, memory.length() - 1)); + switch (memoryUnit) { + case 'M': + case 'm': + size = memorySize; + break; + case 'G': + case 'g': + size = memorySize * 1024; + break; + case 'K': + case 'k': + size = memorySize / 1024; + break; + default: + throw new IllegalArgumentException("invalid memory size " + memory); + } + } else { + int memorySize = Integer.parseInt(memory); + size = memorySize / 1024 / 1024; + } + Preconditions.checkArgument( + size > 0, "memory size should greater than 0m while current is %s", size); + return size; + } + + @Override + public String toString() { + return "ClusterJvmOptions{" + + "xmsMB=" + + xmsMB + + ", xmnMB=" + + xmnMB + + ", maxHeapMB=" + + maxHeapMB + + ", maxDirectMB=" + + maxDirectMB + + ", jvmOptions=" + + jvmOptions + + ", extraOptions=" + + extraOptions + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/constants/ClusterConstants.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/constants/ClusterConstants.java index 5f74aeeff..b69ae7cec 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/constants/ClusterConstants.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/constants/ClusterConstants.java @@ -21,44 +21,43 @@ public class ClusterConstants { - private static final String MASTER_PREFIX = "master-"; - private static final String DRIVER_PREFIX = "driver-"; - private static final String CONTAINER_PREFIX = "container-"; - - public static final String MASTER_LOG_SUFFIX = "master.log"; - public static final String DRIVER_LOG_SUFFIX = "driver.log"; - public static final String CONTAINER_LOG_SUFFIX = "container.log"; - public static final String CLUSTER_TYPE = "clusterType"; - public static final String LOCAL_CLUSTER = "LOCAL"; - - public static final int DEFAULT_MASTER_ID = 0; - public static final int EXIT_CODE = -1; - - public static final String ENV_AGENT_PORT = "AGENT_PORT"; - public static final String ENV_SUPERVISOR_PORT = "SUPERVISOR_PORT"; - - public static final String MASTER_ID = "GEAFLOW_MASTER_ID"; - public static final String CONTAINER_ID = "GEAFLOW_CONTAINER_ID"; - public static final String CONTAINER_INDEX = "GEAFLOW_CONTAINER_INDEX"; - public static final String AUTO_RESTART = "GEAFLOW_AUTO_RESTART"; - public static final String IS_RECOVER = "GEAFLOW_IS_RECOVER"; - public static final String JOB_CONFIG = "GEAFLOW_JOB_CONFIG"; - public static final String CONTAINER_START_COMMAND = "CONTAINER_START_COMMAND"; - public static final String CONTAINER_START_COMMAND_TEMPLATE = - "%java% %classpath% %jvmmem% %jvmopts% %logging% %class% %redirects%"; - public static final String AGENT_PROFILER_PATH = "AGENT_PROFILER_PATH"; - public static final String CONFIG_FILE_LOG4J_NAME = "log4j.properties"; - - public static String getMasterName() { - return String.format("%s%s", MASTER_PREFIX, DEFAULT_MASTER_ID); - } - - public static String getDriverName(int id) { - return String.format("%s%s", DRIVER_PREFIX, id); - } - - public static String getContainerName(int id) { - return String.format("%s%s", CONTAINER_PREFIX, id); - } - + private static final String MASTER_PREFIX = "master-"; + private static final String DRIVER_PREFIX = "driver-"; + private static final String CONTAINER_PREFIX = "container-"; + + public static final String MASTER_LOG_SUFFIX = "master.log"; + public static final String DRIVER_LOG_SUFFIX = "driver.log"; + public static final String CONTAINER_LOG_SUFFIX = "container.log"; + public static final String CLUSTER_TYPE = "clusterType"; + public static final String LOCAL_CLUSTER = "LOCAL"; + + public static final int DEFAULT_MASTER_ID = 0; + public static final int EXIT_CODE = -1; + + public static final String ENV_AGENT_PORT = "AGENT_PORT"; + public static final String ENV_SUPERVISOR_PORT = "SUPERVISOR_PORT"; + + public static final String MASTER_ID = "GEAFLOW_MASTER_ID"; + public static final String CONTAINER_ID = "GEAFLOW_CONTAINER_ID"; + public static final String CONTAINER_INDEX = "GEAFLOW_CONTAINER_INDEX"; + public static final String AUTO_RESTART = "GEAFLOW_AUTO_RESTART"; + public static final String IS_RECOVER = "GEAFLOW_IS_RECOVER"; + public static final String JOB_CONFIG = "GEAFLOW_JOB_CONFIG"; + public static final String CONTAINER_START_COMMAND = "CONTAINER_START_COMMAND"; + public static final String CONTAINER_START_COMMAND_TEMPLATE = + "%java% %classpath% %jvmmem% %jvmopts% %logging% %class% %redirects%"; + public static final String AGENT_PROFILER_PATH = "AGENT_PROFILER_PATH"; + public static final String CONFIG_FILE_LOG4J_NAME = "log4j.properties"; + + public static String getMasterName() { + return String.format("%s%s", MASTER_PREFIX, DEFAULT_MASTER_ID); + } + + public static String getDriverName(int id) { + return String.format("%s%s", DRIVER_PREFIX, id); + } + + public static String getContainerName(int id) { + return String.format("%s%s", CONTAINER_PREFIX, id); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/container/Container.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/container/Container.java index 685ba3d93..faa8add95 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/container/Container.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/container/Container.java @@ -19,9 +19,8 @@ package org.apache.geaflow.cluster.container; -import com.baidu.brpc.server.RpcServerOptions; -import com.google.common.base.Preconditions; import java.util.concurrent.atomic.AtomicBoolean; + import org.apache.geaflow.cluster.collector.EmitterService; import org.apache.geaflow.cluster.common.AbstractContainer; import org.apache.geaflow.cluster.constants.ClusterConstants; @@ -42,121 +41,124 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class Container extends AbstractContainer implements IContainer { - - private static final Logger LOGGER = LoggerFactory.getLogger(Container.class); - - private ContainerContext containerContext; - private Dispatcher dispatcher; - private AtomicBoolean isOpened; - protected FetcherService fetcherService; - protected EmitterService emitterService; - protected TaskService workerService; - protected DispatcherService dispatcherService; +import com.baidu.brpc.server.RpcServerOptions; +import com.google.common.base.Preconditions; - public Container() { - this(0); - } +public class Container extends AbstractContainer implements IContainer { - public Container(int rpcPort) { - super(rpcPort); - this.isOpened = new AtomicBoolean(false); + private static final Logger LOGGER = LoggerFactory.getLogger(Container.class); + + private ContainerContext containerContext; + private Dispatcher dispatcher; + private AtomicBoolean isOpened; + protected FetcherService fetcherService; + protected EmitterService emitterService; + protected TaskService workerService; + protected DispatcherService dispatcherService; + + public Container() { + this(0); + } + + public Container(int rpcPort) { + super(rpcPort); + this.isOpened = new AtomicBoolean(false); + } + + @Override + public void init(ContainerContext containerContext) { + try { + this.containerContext = containerContext; + String containerName = ClusterConstants.getContainerName(containerContext.getId()); + super.init(containerContext.getId(), containerName, containerContext.getConfig()); + registerToMaster(); + LOGGER.info("container {} init finish", name); + } catch (Throwable t) { + LOGGER.error("init container err", t); + throw new GeaflowRuntimeException(t); } - - @Override - public void init(ContainerContext containerContext) { - try { - this.containerContext = containerContext; - String containerName = ClusterConstants.getContainerName(containerContext.getId()); - super.init(containerContext.getId(), containerName, containerContext.getConfig()); - registerToMaster(); - LOGGER.info("container {} init finish", name); - } catch (Throwable t) { - LOGGER.error("init container err", t); - throw new GeaflowRuntimeException(t); + } + + @Override + protected void startRpcService() { + RpcServerOptions serverOptions = ConfigurableServerOption.build(configuration); + this.rpcService = new RpcServiceImpl(PortUtil.getPort(rpcPort), serverOptions); + this.rpcService.addEndpoint(new ContainerEndpoint(this)); + this.rpcPort = rpcService.startService(); + } + + public OpenContainerResponseEvent open(OpenContainerEvent event) { + try { + if (isOpened.compareAndSet(false, true)) { + int num = event.getExecutorNum(); + Preconditions.checkArgument(num > 0, "worker num should > 0"); + LOGGER.info("open container {} with {} executors", name, num); + + this.fetcherService = new FetcherService(num, configuration); + this.emitterService = new EmitterService(num, configuration); + this.workerService = + new TaskService(id, num, configuration, metricGroup, fetcherService, emitterService); + this.dispatcher = new Dispatcher(workerService); + this.dispatcherService = new DispatcherService(dispatcher, configuration); + + // start task service + this.fetcherService.start(); + this.emitterService.start(); + this.workerService.start(); + this.dispatcherService.start(); + + if (containerContext.getReliableEvents() != null) { + for (IEvent reliableEvent : containerContext.getReliableEvents()) { + LOGGER.info("{} replay event {}", name, reliableEvent); + this.dispatcher.add((ICommand) reliableEvent); + } } + registerHAService(); + } + return new OpenContainerResponseEvent(id, 0); + } catch (Throwable throwable) { + LOGGER.error("{} open error", name, throwable); + throw throwable; } - - @Override - protected void startRpcService() { - RpcServerOptions serverOptions = ConfigurableServerOption.build(configuration); - this.rpcService = new RpcServiceImpl(PortUtil.getPort(rpcPort), serverOptions); - this.rpcService.addEndpoint(new ContainerEndpoint(this)); - this.rpcPort = rpcService.startService(); + } + + @Override + public IEvent process(IEvent input) { + LOGGER.info("{} process event {}", name, input); + try { + this.containerContext.addEvent(input); + this.containerContext.checkpoint(new ContainerContext.EventCheckpointFunction()); + this.dispatcher.add((ICommand) input); + return null; + } catch (Throwable throwable) { + LOGGER.error("{} process error", name, throwable); + throw throwable; } + } - public OpenContainerResponseEvent open(OpenContainerEvent event) { - try { - if (isOpened.compareAndSet(false, true)) { - int num = event.getExecutorNum(); - Preconditions.checkArgument(num > 0, "worker num should > 0"); - LOGGER.info("open container {} with {} executors", name, num); - - this.fetcherService = new FetcherService(num, configuration); - this.emitterService = new EmitterService(num, configuration); - this.workerService = new TaskService(id, num, - configuration, metricGroup, fetcherService, emitterService); - this.dispatcher = new Dispatcher(workerService); - this.dispatcherService = new DispatcherService(dispatcher, configuration); - - // start task service - this.fetcherService.start(); - this.emitterService.start(); - this.workerService.start(); - this.dispatcherService.start(); - - if (containerContext.getReliableEvents() != null) { - for (IEvent reliableEvent : containerContext.getReliableEvents()) { - LOGGER.info("{} replay event {}", name, reliableEvent); - this.dispatcher.add((ICommand) reliableEvent); - } - } - registerHAService(); - } - return new OpenContainerResponseEvent(id, 0); - } catch (Throwable throwable) { - LOGGER.error("{} open error", name, throwable); - throw throwable; - } + @Override + public void close() { + super.close(); + if (fetcherService != null) { + fetcherService.shutdown(); } - - @Override - public IEvent process(IEvent input) { - LOGGER.info("{} process event {}", name, input); - try { - this.containerContext.addEvent(input); - this.containerContext.checkpoint(new ContainerContext.EventCheckpointFunction()); - this.dispatcher.add((ICommand) input); - return null; - } catch (Throwable throwable) { - LOGGER.error("{} process error", name, throwable); - throw throwable; - } + if (workerService != null) { + workerService.shutdown(); } - - @Override - public void close() { - super.close(); - if (fetcherService != null) { - fetcherService.shutdown(); - } - if (workerService != null) { - workerService.shutdown(); - } - if (dispatcherService != null) { - dispatcherService.shutdown(); - } - if (emitterService != null) { - emitterService.shutdown(); - } - LOGGER.info("container {} closed", name); + if (dispatcherService != null) { + dispatcherService.shutdown(); } - - @Override - protected ContainerInfo buildComponentInfo() { - ContainerInfo containerInfo = new ContainerInfo(); - fillComponentInfo(containerInfo); - containerInfo.setShufflePort(ShuffleManager.getInstance().getShufflePort()); - return containerInfo; + if (emitterService != null) { + emitterService.shutdown(); } + LOGGER.info("container {} closed", name); + } + + @Override + protected ContainerInfo buildComponentInfo() { + ContainerInfo containerInfo = new ContainerInfo(); + fillComponentInfo(containerInfo); + containerInfo.setShufflePort(ShuffleManager.getInstance().getShufflePort()); + return containerInfo; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/container/ContainerContext.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/container/ContainerContext.java index 23fd1aba9..c937ee89a 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/container/ContainerContext.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/container/ContainerContext.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.cluster.common.IReliableContext; import org.apache.geaflow.cluster.common.ReliableContainerContext; import org.apache.geaflow.cluster.constants.ClusterConstants; @@ -36,91 +37,94 @@ public class ContainerContext extends ReliableContainerContext { - private static final Logger LOGGER = LoggerFactory.getLogger(ContainerContext.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ContainerContext.class); - private List reliableEvents; - private transient List waitingCheckpointEvents; + private List reliableEvents; + private transient List waitingCheckpointEvents; - public ContainerContext(int id, Configuration config) { - super(id, ClusterConstants.getContainerName(id), config); - this.reliableEvents = new ArrayList<>(); - this.waitingCheckpointEvents = new ArrayList<>(); - } + public ContainerContext(int id, Configuration config) { + super(id, ClusterConstants.getContainerName(id), config); + this.reliableEvents = new ArrayList<>(); + this.waitingCheckpointEvents = new ArrayList<>(); + } - public ContainerContext(int id, Configuration config, boolean isRecover) { - this(id, config); - this.isRecover = isRecover; - } + public ContainerContext(int id, Configuration config, boolean isRecover) { + this(id, config); + this.isRecover = isRecover; + } - public ContainerContext(int id, Configuration config, boolean isRecover, List reliableEvents) { - this(id, config, isRecover); - this.reliableEvents = reliableEvents; - } + public ContainerContext( + int id, Configuration config, boolean isRecover, List reliableEvents) { + this(id, config, isRecover); + this.reliableEvents = reliableEvents; + } - @Override - public void load() { - List events = ClusterMetaStore.getInstance(id, name, config).getEvents(); - if (events != null) { - LOGGER.info("container {} recover events {}", id, events); - reliableEvents = events; - } - if (waitingCheckpointEvents == null) { - waitingCheckpointEvents = new ArrayList<>(); - } else { - waitingCheckpointEvents.clear(); - } + @Override + public void load() { + List events = ClusterMetaStore.getInstance(id, name, config).getEvents(); + if (events != null) { + LOGGER.info("container {} recover events {}", id, events); + reliableEvents = events; } - - public List getReliableEvents() { - return reliableEvents; + if (waitingCheckpointEvents == null) { + waitingCheckpointEvents = new ArrayList<>(); + } else { + waitingCheckpointEvents.clear(); } + } - public synchronized void addEvent(IEvent input) { - if (input instanceof IHighAvailableEvent) { - if (((IHighAvailableEvent) input).getHaLevel() == HighAvailableLevel.CHECKPOINT) { - if (waitingCheckpointEvents == null) { - waitingCheckpointEvents = new ArrayList<>(); - } - if (!waitingCheckpointEvents.contains(input)) { - waitingCheckpointEvents.add(input); - LOGGER.info("container {} add recoverable event {}", id, input); - } else { - LOGGER.info("container {} already has recoverable event {}", id, input); - } - } - } else if (input.getEventType() == EventType.COMPOSE) { - IComposeEvent composeEvent = (IComposeEvent) input; - for (IEvent event : composeEvent.getEventList()) { - addEvent(event); - } + public List getReliableEvents() { + return reliableEvents; + } + + public synchronized void addEvent(IEvent input) { + if (input instanceof IHighAvailableEvent) { + if (((IHighAvailableEvent) input).getHaLevel() == HighAvailableLevel.CHECKPOINT) { + if (waitingCheckpointEvents == null) { + waitingCheckpointEvents = new ArrayList<>(); } + if (!waitingCheckpointEvents.contains(input)) { + waitingCheckpointEvents.add(input); + LOGGER.info("container {} add recoverable event {}", id, input); + } else { + LOGGER.info("container {} already has recoverable event {}", id, input); + } + } + } else if (input.getEventType() == EventType.COMPOSE) { + IComposeEvent composeEvent = (IComposeEvent) input; + for (IEvent event : composeEvent.getEventList()) { + addEvent(event); + } } + } - public static class EventCheckpointFunction implements IReliableContextCheckpointFunction { + public static class EventCheckpointFunction implements IReliableContextCheckpointFunction { - @Override - public void doCheckpoint(IReliableContext context) { - ContainerContext containerContext = ((ContainerContext) context); - if (containerContext.waitingCheckpointEvents == null || containerContext.waitingCheckpointEvents.isEmpty()) { - LOGGER.info("container {} has no new events to checkpoint", containerContext.getId()); - return; - } - List reliableEvents = ClusterMetaStore.getInstance().getEvents(); + @Override + public void doCheckpoint(IReliableContext context) { + ContainerContext containerContext = ((ContainerContext) context); + if (containerContext.waitingCheckpointEvents == null + || containerContext.waitingCheckpointEvents.isEmpty()) { + LOGGER.info("container {} has no new events to checkpoint", containerContext.getId()); + return; + } + List reliableEvents = ClusterMetaStore.getInstance().getEvents(); - if (reliableEvents == null) { - reliableEvents = new ArrayList<>(containerContext.waitingCheckpointEvents); - } else { - for (IEvent event : containerContext.waitingCheckpointEvents) { - if (reliableEvents.contains(event)) { - LOGGER.info("container {} already has saved recoverable event {}", containerContext.id, event); - } else { - reliableEvents.add(event); - } - } - } - ClusterMetaStore.getInstance().saveEvent(reliableEvents).flush(); - LOGGER.info("container {} checkpoint events {}", containerContext.getId(), reliableEvents); - containerContext.waitingCheckpointEvents.clear(); + if (reliableEvents == null) { + reliableEvents = new ArrayList<>(containerContext.waitingCheckpointEvents); + } else { + for (IEvent event : containerContext.waitingCheckpointEvents) { + if (reliableEvents.contains(event)) { + LOGGER.info( + "container {} already has saved recoverable event {}", containerContext.id, event); + } else { + reliableEvents.add(event); + } } + } + ClusterMetaStore.getInstance().saveEvent(reliableEvents).flush(); + LOGGER.info("container {} checkpoint events {}", containerContext.getId(), reliableEvents); + containerContext.waitingCheckpointEvents.clear(); } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/container/ContainerInfo.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/container/ContainerInfo.java index 39c266b2a..f3c2cdacc 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/container/ContainerInfo.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/container/ContainerInfo.java @@ -23,31 +23,42 @@ public class ContainerInfo extends ComponentInfo { - /** - * shuffle service port. - */ - private int shufflePort; - - public ContainerInfo() { - } - - public ContainerInfo(int containerId, String containerName, String host, int pid, int rpcPort, - int shufflePort) { - super(containerId, containerName, host, pid, rpcPort); - this.shufflePort = shufflePort; - } - - public int getShufflePort() { - return shufflePort; - } - - public void setShufflePort(int shufflePort) { - this.shufflePort = shufflePort; - } - - @Override - public String toString() { - return "ContainerInfo{" + "id=" + id + ", name='" + name + '\'' + ", host='" + host + '\'' - + ", pid=" + pid + ", rpcPort=" + rpcPort + ", shufflePort=" + shufflePort + "}"; - } + /** shuffle service port. */ + private int shufflePort; + + public ContainerInfo() {} + + public ContainerInfo( + int containerId, String containerName, String host, int pid, int rpcPort, int shufflePort) { + super(containerId, containerName, host, pid, rpcPort); + this.shufflePort = shufflePort; + } + + public int getShufflePort() { + return shufflePort; + } + + public void setShufflePort(int shufflePort) { + this.shufflePort = shufflePort; + } + + @Override + public String toString() { + return "ContainerInfo{" + + "id=" + + id + + ", name='" + + name + + '\'' + + ", host='" + + host + + '\'' + + ", pid=" + + pid + + ", rpcPort=" + + rpcPort + + ", shufflePort=" + + shufflePort + + "}"; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/container/IContainer.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/container/IContainer.java index 8ac8e4267..dda4ae0f6 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/container/IContainer.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/container/IContainer.java @@ -25,19 +25,12 @@ public interface IContainer extends IEventProcessor { - /** - * Initialize container. - */ - void init(ContainerContext containerContext); + /** Initialize container. */ + void init(ContainerContext containerContext); - /** - * Open container to run workers. - */ - OpenContainerResponseEvent open(OpenContainerEvent event); - - /** - * Close container. - */ - void close(); + /** Open container to run workers. */ + OpenContainerResponseEvent open(OpenContainerEvent event); + /** Close container. */ + void close(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/driver/Driver.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/driver/Driver.java index 3957bfe94..e3a9d86bc 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/driver/Driver.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/driver/Driver.java @@ -19,7 +19,6 @@ package org.apache.geaflow.cluster.driver; -import com.baidu.brpc.server.RpcServerOptions; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -27,6 +26,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicInteger; + import org.apache.geaflow.cluster.client.callback.JobOperatorCallback; import org.apache.geaflow.cluster.client.callback.JobOperatorCallbackFactory; import org.apache.geaflow.cluster.common.AbstractContainer; @@ -51,134 +51,143 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Driver process. - */ -public class Driver extends AbstractContainer implements IDriver { - - private static final Logger LOGGER = LoggerFactory.getLogger(Driver.class); - private static final String DRIVER_EXECUTOR = "driver-executor"; - private static final AtomicInteger pipelineTaskIdGenerator = new AtomicInteger(0); - - private DriverEventDispatcher eventDispatcher; - private DriverContext driverContext; - private ExecutorService executorService; - private Map pipelineExecutorMap; - private JobOperatorCallback jobOperatorCallback; - - public Driver() { - this(0); - } +import com.baidu.brpc.server.RpcServerOptions; - public Driver(int rpcPort) { - super(rpcPort); - } +/** Driver process. */ +public class Driver extends AbstractContainer implements IDriver { - @Override - public void init(DriverContext driverContext) { - super.init(driverContext.getId(), ClusterConstants.getDriverName(driverContext.getId()), - driverContext.getConfig()); - this.driverContext = driverContext; - this.eventDispatcher = new DriverEventDispatcher(); - this.executorService = Executors.newFixedThreadPool( + private static final Logger LOGGER = LoggerFactory.getLogger(Driver.class); + private static final String DRIVER_EXECUTOR = "driver-executor"; + private static final AtomicInteger pipelineTaskIdGenerator = new AtomicInteger(0); + + private DriverEventDispatcher eventDispatcher; + private DriverContext driverContext; + private ExecutorService executorService; + private Map pipelineExecutorMap; + private JobOperatorCallback jobOperatorCallback; + + public Driver() { + this(0); + } + + public Driver(int rpcPort) { + super(rpcPort); + } + + @Override + public void init(DriverContext driverContext) { + super.init( + driverContext.getId(), + ClusterConstants.getDriverName(driverContext.getId()), + driverContext.getConfig()); + this.driverContext = driverContext; + this.eventDispatcher = new DriverEventDispatcher(); + this.executorService = + Executors.newFixedThreadPool( 1, - ThreadUtil.namedThreadFactory(true, DRIVER_EXECUTOR, ComponentUncaughtExceptionHandler.INSTANCE)); - this.pipelineExecutorMap = new HashMap<>(); - this.jobOperatorCallback = JobOperatorCallbackFactory.createJobOperatorCallback(configuration); - - ExecutionIdGenerator.init(id); - if (driverContext.getPipeline() != null) { - LOGGER.info("driver {} execute pipeline from recovered context", name); - executorService.execute(() -> executePipelineInternal(driverContext.getPipeline())); - } - registerToMaster(); - registerHAService(); - LOGGER.info("driver {} init finish", name); + ThreadUtil.namedThreadFactory( + true, DRIVER_EXECUTOR, ComponentUncaughtExceptionHandler.INSTANCE)); + this.pipelineExecutorMap = new HashMap<>(); + this.jobOperatorCallback = JobOperatorCallbackFactory.createJobOperatorCallback(configuration); + + ExecutionIdGenerator.init(id); + if (driverContext.getPipeline() != null) { + LOGGER.info("driver {} execute pipeline from recovered context", name); + executorService.execute(() -> executePipelineInternal(driverContext.getPipeline())); } - - @Override - protected void startRpcService() { - RpcServerOptions serverOptions = ConfigurableServerOption.build(configuration); - this.rpcService = new RpcServiceImpl(PortUtil.getPort(rpcPort), serverOptions); - this.rpcService.addEndpoint(new DriverEndpoint(this)); - this.rpcService.addEndpoint(new PipelineMasterEndpoint(this)); - this.rpcPort = rpcService.startService(); - } - - @Override - public Boolean executePipeline(Pipeline pipeline) { - LOGGER.info("driver {} execute pipeline {}", name, pipeline); - Future future = executorService.submit(() -> executePipelineInternal(pipeline)); - try { - return future.get(); - } catch (Throwable e) { - LOGGER.error(e.getMessage(), e); - throw new GeaflowRuntimeException(e); - } - } - - public Boolean executePipelineInternal(Pipeline pipeline) { - try { - LOGGER.info("start execute pipeline {}", pipeline); - driverContext.addPipeline(pipeline); - driverContext.checkpoint(new DriverContext.PipelineCheckpointFunction()); - - IPipelineExecutor pipelineExecutor = PipelineExecutorFactory.createPipelineExecutor(); - PipelineExecutorContext executorContext = new PipelineExecutorContext(name, driverContext.getIndex(), - eventDispatcher, configuration, pipelineTaskIdGenerator); - pipelineExecutor.init(executorContext); - pipelineExecutor.register(pipeline.getViewDescMap()); - - List pipelineTaskList = pipeline.getPipelineTaskList(); - List taskCallBackList = pipeline.getPipelineTaskCallbacks(); - for (int i = 0, size = pipelineTaskList.size(); i < size; i++) { - if (driverContext.getFinishedPipelineTasks() == null || !driverContext.getFinishedPipelineTasks().contains(i)) { - pipelineExecutor.runPipelineTask(pipelineTaskList.get(i), - taskCallBackList.get(i)); - driverContext.addFinishedPipelineTask(i); - driverContext.checkpoint(new DriverContext.PipelineTaskCheckpointFunction()); - } - } - - List pipelineServices = pipeline.getPipelineServices(); - for (PipelineService pipelineService : pipelineServices) { - LOGGER.info("execute service"); - pipelineExecutorMap.put(pipelineService, pipelineExecutor); - pipelineExecutor.startPipelineService(pipelineService); - } - - this.jobOperatorCallback.onFinish(); - LOGGER.info("finish execute pipeline {}", pipeline); - return true; - } catch (Throwable e) { - LOGGER.error("driver exception", e); - throw e; - } - } - - @Override - public Boolean process(IEvent input) { - LOGGER.info("{} process event {}", name, input); - eventDispatcher.dispatch(input); - return true; + registerToMaster(); + registerHAService(); + LOGGER.info("driver {} init finish", name); + } + + @Override + protected void startRpcService() { + RpcServerOptions serverOptions = ConfigurableServerOption.build(configuration); + this.rpcService = new RpcServiceImpl(PortUtil.getPort(rpcPort), serverOptions); + this.rpcService.addEndpoint(new DriverEndpoint(this)); + this.rpcService.addEndpoint(new PipelineMasterEndpoint(this)); + this.rpcPort = rpcService.startService(); + } + + @Override + public Boolean executePipeline(Pipeline pipeline) { + LOGGER.info("driver {} execute pipeline {}", name, pipeline); + Future future = executorService.submit(() -> executePipelineInternal(pipeline)); + try { + return future.get(); + } catch (Throwable e) { + LOGGER.error(e.getMessage(), e); + throw new GeaflowRuntimeException(e); } - - @Override - public void close() { - executorService.shutdownNow(); - for (PipelineService service : pipelineExecutorMap.keySet()) { - pipelineExecutorMap.get(service).stopPipelineService(service); + } + + public Boolean executePipelineInternal(Pipeline pipeline) { + try { + LOGGER.info("start execute pipeline {}", pipeline); + driverContext.addPipeline(pipeline); + driverContext.checkpoint(new DriverContext.PipelineCheckpointFunction()); + + IPipelineExecutor pipelineExecutor = PipelineExecutorFactory.createPipelineExecutor(); + PipelineExecutorContext executorContext = + new PipelineExecutorContext( + name, + driverContext.getIndex(), + eventDispatcher, + configuration, + pipelineTaskIdGenerator); + pipelineExecutor.init(executorContext); + pipelineExecutor.register(pipeline.getViewDescMap()); + + List pipelineTaskList = pipeline.getPipelineTaskList(); + List taskCallBackList = pipeline.getPipelineTaskCallbacks(); + for (int i = 0, size = pipelineTaskList.size(); i < size; i++) { + if (driverContext.getFinishedPipelineTasks() == null + || !driverContext.getFinishedPipelineTasks().contains(i)) { + pipelineExecutor.runPipelineTask(pipelineTaskList.get(i), taskCallBackList.get(i)); + driverContext.addFinishedPipelineTask(i); + driverContext.checkpoint(new DriverContext.PipelineTaskCheckpointFunction()); } - pipelineExecutorMap.clear(); - - super.close(); - LOGGER.info("driver {} closed", name); + } + + List pipelineServices = pipeline.getPipelineServices(); + for (PipelineService pipelineService : pipelineServices) { + LOGGER.info("execute service"); + pipelineExecutorMap.put(pipelineService, pipelineExecutor); + pipelineExecutor.startPipelineService(pipelineService); + } + + this.jobOperatorCallback.onFinish(); + LOGGER.info("finish execute pipeline {}", pipeline); + return true; + } catch (Throwable e) { + LOGGER.error("driver exception", e); + throw e; } - - @Override - protected DriverInfo buildComponentInfo() { - DriverInfo driverInfo = new DriverInfo(); - fillComponentInfo(driverInfo); - return driverInfo; + } + + @Override + public Boolean process(IEvent input) { + LOGGER.info("{} process event {}", name, input); + eventDispatcher.dispatch(input); + return true; + } + + @Override + public void close() { + executorService.shutdownNow(); + for (PipelineService service : pipelineExecutorMap.keySet()) { + pipelineExecutorMap.get(service).stopPipelineService(service); } + pipelineExecutorMap.clear(); + + super.close(); + LOGGER.info("driver {} closed", name); + } + + @Override + protected DriverInfo buildComponentInfo() { + DriverInfo driverInfo = new DriverInfo(); + fillComponentInfo(driverInfo); + return driverInfo; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/driver/DriverContext.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/driver/DriverContext.java index afd14ca2c..0bf65f4be 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/driver/DriverContext.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/driver/DriverContext.java @@ -24,6 +24,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.cluster.common.ExecutionIdGenerator; import org.apache.geaflow.cluster.common.IReliableContext; import org.apache.geaflow.cluster.common.ReliableContainerContext; @@ -37,120 +38,136 @@ public class DriverContext extends ReliableContainerContext { - private static final Logger LOGGER = LoggerFactory.getLogger(Driver.class); - - private Pipeline pipeline; - - public List getPipelineTaskIds() { - return pipelineTaskIds; - } - - private List pipelineTaskIds; - private List finishedPipelineTasks; - private int index; - - public DriverContext(int id, int index, Configuration config) { - super(id, ClusterConstants.getDriverName(id), config); - this.index = index; - this.finishedPipelineTasks = new ArrayList<>(); - this.pipelineTaskIds = new ArrayList<>(); + private static final Logger LOGGER = LoggerFactory.getLogger(Driver.class); + + private Pipeline pipeline; + + public List getPipelineTaskIds() { + return pipelineTaskIds; + } + + private List pipelineTaskIds; + private List finishedPipelineTasks; + private int index; + + public DriverContext(int id, int index, Configuration config) { + super(id, ClusterConstants.getDriverName(id), config); + this.index = index; + this.finishedPipelineTasks = new ArrayList<>(); + this.pipelineTaskIds = new ArrayList<>(); + } + + public DriverContext(int id, int index, Configuration config, boolean isRecover) { + this(id, index, config); + this.isRecover = isRecover; + } + + @Override + public void load() { + Pipeline pipeline = ClusterMetaStore.getInstance(id, name, config).getPipeline(); + if (pipeline != null) { + List finishedPipelineTasks = ClusterMetaStore.getInstance().getPipelineTasks(); + if (finishedPipelineTasks == null) { + finishedPipelineTasks = new ArrayList<>(); + } + List pipelineTaskIds = ClusterMetaStore.getInstance().getPipelineTaskIds(); + if (pipeline.getPipelineTaskList() != null && pipelineTaskIds == null) { + throw new GeaflowRuntimeException( + String.format( + "driver %s recover context %s " + "error: pipeline task ids is null", id, this)); + } + this.pipeline = pipeline; + this.finishedPipelineTasks = finishedPipelineTasks; + this.pipelineTaskIds = pipelineTaskIds; + LOGGER.info( + "driver {} recover context {} pipeline {} finishedPipelineTasks {} pipelineTaskIds {}", + id, + this, + pipeline, + finishedPipelineTasks, + pipelineTaskIds); } + } - public DriverContext(int id, int index, Configuration config, boolean isRecover) { - this(id, index, config); - this.isRecover = isRecover; - } + public Pipeline getPipeline() { + return pipeline; + } - @Override - public void load() { - Pipeline pipeline = - ClusterMetaStore.getInstance(id, name, config).getPipeline(); - if (pipeline != null) { - List finishedPipelineTasks = ClusterMetaStore.getInstance().getPipelineTasks(); - if (finishedPipelineTasks == null) { - finishedPipelineTasks = new ArrayList<>(); - } - List pipelineTaskIds = ClusterMetaStore.getInstance().getPipelineTaskIds(); - if (pipeline.getPipelineTaskList() != null && pipelineTaskIds == null) { - throw new GeaflowRuntimeException(String.format("driver %s recover context %s " - + "error: pipeline task ids is null", id, this)); - } - this.pipeline = pipeline; - this.finishedPipelineTasks = finishedPipelineTasks; - this.pipelineTaskIds = pipelineTaskIds; - LOGGER.info("driver {} recover context {} pipeline {} finishedPipelineTasks {} pipelineTaskIds {}", - id, this, pipeline, finishedPipelineTasks, pipelineTaskIds); - } + public void addPipeline(Pipeline pipeline) { + genPipelineTaskIds(pipeline); + validatePipeline(pipeline); + if (!pipeline.equals(this.pipeline)) { + this.pipeline = pipeline; } + } - public Pipeline getPipeline() { - return pipeline; - } + public int getIndex() { + return index; + } - public void addPipeline(Pipeline pipeline) { - genPipelineTaskIds(pipeline); - validatePipeline(pipeline); - if (!pipeline.equals(this.pipeline)) { - this.pipeline = pipeline; - } - } + public List getFinishedPipelineTasks() { + return finishedPipelineTasks; + } - public int getIndex() { - return index; + public void addFinishedPipelineTask(int pipelineTaskIndex) { + if (!finishedPipelineTasks.contains(pipelineTaskIndex)) { + finishedPipelineTasks.add(pipelineTaskIndex); } - - public List getFinishedPipelineTasks() { - return finishedPipelineTasks; + } + + private void validatePipeline(Pipeline pipeline) { + // Given that partial components fo only supported for pipeline service, + // do validation for pipeline. + if (!pipeline.getPipelineTaskList().isEmpty() + && config.getString(FO_STRATEGY).equalsIgnoreCase(component_fo.name())) { + throw new GeaflowRuntimeException("not support component_fo for executing pipeline tasks"); } + } - public void addFinishedPipelineTask(int pipelineTaskIndex) { - if (!finishedPipelineTasks.contains(pipelineTaskIndex)) { - finishedPipelineTasks.add(pipelineTaskIndex); - } - } + public static class PipelineCheckpointFunction implements IReliableContextCheckpointFunction { - private void validatePipeline(Pipeline pipeline) { - // Given that partial components fo only supported for pipeline service, - // do validation for pipeline. - if (!pipeline.getPipelineTaskList().isEmpty() - && config.getString(FO_STRATEGY).equalsIgnoreCase(component_fo.name())) { - throw new GeaflowRuntimeException("not support component_fo for executing pipeline tasks"); - } + @Override + public void doCheckpoint(IReliableContext context) { + DriverContext driverContext = ((DriverContext) context); + if (driverContext.getPipeline() != null) { + ClusterMetaStore.getInstance().savePipeline(driverContext.getPipeline()).flush(); + ClusterMetaStore.getInstance() + .savePipelineTaskIds(driverContext.getPipelineTaskIds()) + .flush(); + LOGGER.info( + "driver {} checkpoint context {} pipeline {}, PipelineTaskIds {}", + driverContext.getId(), + driverContext, + driverContext.getPipeline(), + driverContext.getPipelineTaskIds()); + } } + } - public static class PipelineCheckpointFunction implements IReliableContextCheckpointFunction { - - @Override - public void doCheckpoint(IReliableContext context) { - DriverContext driverContext = ((DriverContext) context); - if (driverContext.getPipeline() != null) { - ClusterMetaStore.getInstance().savePipeline(driverContext.getPipeline()).flush(); - ClusterMetaStore.getInstance().savePipelineTaskIds(driverContext.getPipelineTaskIds()).flush(); - LOGGER.info("driver {} checkpoint context {} pipeline {}, PipelineTaskIds {}", - driverContext.getId(), driverContext, driverContext.getPipeline(), driverContext.getPipelineTaskIds()); - } - } - } + public static class PipelineTaskCheckpointFunction implements IReliableContextCheckpointFunction { - public static class PipelineTaskCheckpointFunction implements IReliableContextCheckpointFunction { - - @Override - public void doCheckpoint(IReliableContext context) { - DriverContext driverContext = ((DriverContext) context); - if (driverContext.getFinishedPipelineTasks() != null && !driverContext.getFinishedPipelineTasks().isEmpty()) { - ClusterMetaStore.getInstance().savePipelineTasks(driverContext.getFinishedPipelineTasks()).flush(); - LOGGER.info("driver {} checkpoint pipeline finished tasks {}", - driverContext.getId(), driverContext.getFinishedPipelineTasks()); - } - } + @Override + public void doCheckpoint(IReliableContext context) { + DriverContext driverContext = ((DriverContext) context); + if (driverContext.getFinishedPipelineTasks() != null + && !driverContext.getFinishedPipelineTasks().isEmpty()) { + ClusterMetaStore.getInstance() + .savePipelineTasks(driverContext.getFinishedPipelineTasks()) + .flush(); + LOGGER.info( + "driver {} checkpoint pipeline finished tasks {}", + driverContext.getId(), + driverContext.getFinishedPipelineTasks()); + } } - - private void genPipelineTaskIds(Pipeline pipeline) { - // When recover, we need not generate pipeline task ids. - if (this.pipelineTaskIds.isEmpty()) { - for (int i = 0, size = pipeline.getPipelineTaskList().size(); i < size; i++) { - this.pipelineTaskIds.add(ExecutionIdGenerator.getInstance().generateId()); - } - } + } + + private void genPipelineTaskIds(Pipeline pipeline) { + // When recover, we need not generate pipeline task ids. + if (this.pipelineTaskIds.isEmpty()) { + for (int i = 0, size = pipeline.getPipelineTaskList().size(); i < size; i++) { + this.pipelineTaskIds.add(ExecutionIdGenerator.getInstance().generateId()); + } } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/driver/DriverEventDispatcher.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/driver/DriverEventDispatcher.java index 0f95df83e..c47e1a86a 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/driver/DriverEventDispatcher.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/driver/DriverEventDispatcher.java @@ -21,6 +21,7 @@ import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.cluster.common.IDispatcher; import org.apache.geaflow.cluster.common.IEventListener; import org.apache.geaflow.cluster.protocol.ICycleResponseEvent; @@ -29,26 +30,28 @@ public class DriverEventDispatcher implements IDispatcher { - private Map eventListenerMap; + private Map eventListenerMap; - public DriverEventDispatcher() { - this.eventListenerMap = new HashMap<>(); - } + public DriverEventDispatcher() { + this.eventListenerMap = new HashMap<>(); + } - public void dispatch(IEvent event) { - ICycleResponseEvent doneEvent = (ICycleResponseEvent) event; - IEventListener eventListener = eventListenerMap.get(doneEvent.getSchedulerId()); - if (eventListener == null) { - throw new GeaflowRuntimeException(String.format("event %s do not find handle listener %s", event, doneEvent.getSchedulerId())); - } - eventListener.handleEvent(event); + public void dispatch(IEvent event) { + ICycleResponseEvent doneEvent = (ICycleResponseEvent) event; + IEventListener eventListener = eventListenerMap.get(doneEvent.getSchedulerId()); + if (eventListener == null) { + throw new GeaflowRuntimeException( + String.format( + "event %s do not find handle listener %s", event, doneEvent.getSchedulerId())); } + eventListener.handleEvent(event); + } - public void registerListener(long schedulerId, IEventListener eventListener) { - this.eventListenerMap.put(schedulerId, eventListener); - } + public void registerListener(long schedulerId, IEventListener eventListener) { + this.eventListenerMap.put(schedulerId, eventListener); + } - public void removeListener(long schedulerId) { - this.eventListenerMap.remove(schedulerId); - } + public void removeListener(long schedulerId) { + this.eventListenerMap.remove(schedulerId); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/driver/DriverInfo.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/driver/DriverInfo.java index 150a1fd03..5df7942d7 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/driver/DriverInfo.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/driver/DriverInfo.java @@ -23,9 +23,21 @@ public class DriverInfo extends ComponentInfo { - @Override - public String toString() { - return "DriverInfo{" + "id=" + id + ", name='" + name + '\'' + ", host='" + host + '\'' - + ", pid=" + pid + ", rpcPort=" + rpcPort + "}"; - } + @Override + public String toString() { + return "DriverInfo{" + + "id=" + + id + + ", name='" + + name + + '\'' + + ", host='" + + host + + '\'' + + ", pid=" + + pid + + ", rpcPort=" + + rpcPort + + "}"; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/driver/IDriver.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/driver/IDriver.java index 461bc4861..8cbcd92c8 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/driver/IDriver.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/driver/IDriver.java @@ -24,19 +24,12 @@ public interface IDriver extends IEventProcessor { - /** - * Initialize driver. - */ - void init(DriverContext driverContext); + /** Initialize driver. */ + void init(DriverContext driverContext); - /** - * Execute pipeline task/service. - */ - R executePipeline(Pipeline pipeline); - - /** - * Close driver. - */ - void close(); + /** Execute pipeline task/service. */ + R executePipeline(Pipeline pipeline); + /** Close driver. */ + void close(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/exception/ComponentExceptionSupervisor.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/exception/ComponentExceptionSupervisor.java index 145174b2b..b87e587a7 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/exception/ComponentExceptionSupervisor.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/exception/ComponentExceptionSupervisor.java @@ -25,54 +25,56 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class ComponentExceptionSupervisor extends AbstractTaskRunner { +public class ComponentExceptionSupervisor + extends AbstractTaskRunner { - private static final Logger LOGGER = LoggerFactory.getLogger(ComponentExceptionSupervisor.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ComponentExceptionSupervisor.class); - private static ComponentExceptionSupervisor INSTANCE; + private static ComponentExceptionSupervisor INSTANCE; - @Override - protected void process(ExceptionElement exceptionElement) { - // Send exception to master. - ExceptionClient exceptionClient = ExceptionClient.getInstance(); - if (exceptionClient != null) { - exceptionClient.sendException(exceptionElement.cause); - } - // Exit current process if supervisor is running. - if (running) { - LOGGER.error(String.format("%s occur fatal exception, exit process now", - exceptionElement.thread), exceptionElement.cause); - System.exit(EXIT_CODE); - } else { - LOGGER.info("{} ignore exception because supervisor is shutdown", exceptionElement.thread); - } + @Override + protected void process(ExceptionElement exceptionElement) { + // Send exception to master. + ExceptionClient exceptionClient = ExceptionClient.getInstance(); + if (exceptionClient != null) { + exceptionClient.sendException(exceptionElement.cause); } - - public static synchronized ComponentExceptionSupervisor getInstance() { - if (INSTANCE == null) { - INSTANCE = new ComponentExceptionSupervisor(); - } - return INSTANCE; + // Exit current process if supervisor is running. + if (running) { + LOGGER.error( + String.format("%s occur fatal exception, exit process now", exceptionElement.thread), + exceptionElement.cause); + System.exit(EXIT_CODE); + } else { + LOGGER.info("{} ignore exception because supervisor is shutdown", exceptionElement.thread); } + } - @Override - public void shutdown() { - super.shutdown(); - this.INSTANCE = null; + public static synchronized ComponentExceptionSupervisor getInstance() { + if (INSTANCE == null) { + INSTANCE = new ComponentExceptionSupervisor(); } + return INSTANCE; + } + + @Override + public void shutdown() { + super.shutdown(); + this.INSTANCE = null; + } - public static class ExceptionElement { + public static class ExceptionElement { - private Thread thread; - private Throwable cause; + private Thread thread; + private Throwable cause; - public ExceptionElement(Thread thread, Throwable cause) { - this.thread = thread; - this.cause = cause; - } + public ExceptionElement(Thread thread, Throwable cause) { + this.thread = thread; + this.cause = cause; + } - public static ExceptionElement of(Thread thread, Throwable cause) { - return new ExceptionElement(thread, cause); - } + public static ExceptionElement of(Thread thread, Throwable cause) { + return new ExceptionElement(thread, cause); } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/exception/ComponentUncaughtExceptionHandler.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/exception/ComponentUncaughtExceptionHandler.java index 8d99d8786..5de422dff 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/exception/ComponentUncaughtExceptionHandler.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/exception/ComponentUncaughtExceptionHandler.java @@ -26,18 +26,20 @@ public class ComponentUncaughtExceptionHandler implements Thread.UncaughtExceptionHandler { - private static final Logger LOGGER = LoggerFactory.getLogger(ComponentUncaughtExceptionHandler.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(ComponentUncaughtExceptionHandler.class); - public static final ComponentUncaughtExceptionHandler INSTANCE = new ComponentUncaughtExceptionHandler(); + public static final ComponentUncaughtExceptionHandler INSTANCE = + new ComponentUncaughtExceptionHandler(); - @Override - public void uncaughtException(Thread thread, Throwable cause) { - LOGGER.error("FATAL exception in thread: {}", thread.getName(), cause); - StatsCollectorFactory collectorFactory = StatsCollectorFactory.getInstance(); - if (collectorFactory != null) { - collectorFactory.getExceptionCollector().reportException(ExceptionLevel.FATAL, cause); - } - ComponentExceptionSupervisor.getInstance() - .add(ComponentExceptionSupervisor.ExceptionElement.of(thread, cause)); + @Override + public void uncaughtException(Thread thread, Throwable cause) { + LOGGER.error("FATAL exception in thread: {}", thread.getName(), cause); + StatsCollectorFactory collectorFactory = StatsCollectorFactory.getInstance(); + if (collectorFactory != null) { + collectorFactory.getExceptionCollector().reportException(ExceptionLevel.FATAL, cause); } + ComponentExceptionSupervisor.getInstance() + .add(ComponentExceptionSupervisor.ExceptionElement.of(thread, cause)); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/exception/ExceptionClient.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/exception/ExceptionClient.java index 618db0546..0cdfb2fe0 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/exception/ExceptionClient.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/exception/ExceptionClient.java @@ -25,38 +25,37 @@ public class ExceptionClient { - private static final Logger LOGGER = LoggerFactory.getLogger(ExceptionClient.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ExceptionClient.class); - private static ExceptionClient INSTANCE; + private static ExceptionClient INSTANCE; - private final String masterId; - private final int containerId; - private final String containerName; + private final String masterId; + private final int containerId; + private final String containerName; - public ExceptionClient(int containerId, String containerName, String masterId) { - this.containerId = containerId; - this.containerName = containerName; - this.masterId = masterId; - } + public ExceptionClient(int containerId, String containerName, String masterId) { + this.containerId = containerId; + this.containerName = containerName; + this.masterId = masterId; + } - public static synchronized ExceptionClient init(int containerId, String name, String masterId) { - if (INSTANCE == null) { - INSTANCE = new ExceptionClient(containerId, name, masterId); - } - return INSTANCE; + public static synchronized ExceptionClient init(int containerId, String name, String masterId) { + if (INSTANCE == null) { + INSTANCE = new ExceptionClient(containerId, name, masterId); } - - public static synchronized ExceptionClient getInstance() { - return INSTANCE; + return INSTANCE; + } + + public static synchronized ExceptionClient getInstance() { + return INSTANCE; + } + + public void sendException(Throwable throwable) { + try { + LOGGER.info("Send exception {} to master.", throwable.getMessage()); + RpcClient.getInstance().sendException(masterId, containerId, containerName, throwable); + } catch (Throwable e) { + LOGGER.error("Send exception {} to master failed.", throwable.getMessage(), e); } - - public void sendException(Throwable throwable) { - try { - LOGGER.info("Send exception {} to master.", throwable.getMessage()); - RpcClient.getInstance().sendException(masterId, containerId, containerName, throwable); - } catch (Throwable e) { - LOGGER.error("Send exception {} to master failed.", throwable.getMessage(), e); - } - } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/exception/ExceptionCollectService.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/exception/ExceptionCollectService.java index d5ea4eef5..b45864108 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/exception/ExceptionCollectService.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/exception/ExceptionCollectService.java @@ -20,6 +20,7 @@ package org.apache.geaflow.cluster.exception; import java.util.concurrent.ExecutorService; + import org.apache.geaflow.common.thread.Executors; import org.apache.geaflow.common.utils.ExecutorUtil; import org.slf4j.Logger; @@ -27,29 +28,29 @@ public class ExceptionCollectService { - private static final Logger LOGGER = LoggerFactory.getLogger(ExceptionCollectService.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ExceptionCollectService.class); - private static final String EXCEPTION_COLLECT_FORMAT = "geaflow-exception-collect-%d"; - private static final int EXCEPTION_COLLECTOR_THREAD_NUM = 1; + private static final String EXCEPTION_COLLECT_FORMAT = "geaflow-exception-collect-%d"; + private static final int EXCEPTION_COLLECTOR_THREAD_NUM = 1; - private ExecutorService exceptionCollectService; - private ComponentExceptionSupervisor supervisor; + private ExecutorService exceptionCollectService; + private ComponentExceptionSupervisor supervisor; - public ExceptionCollectService() { - this.exceptionCollectService = - Executors.getExecutorService(EXCEPTION_COLLECTOR_THREAD_NUM, EXCEPTION_COLLECT_FORMAT); - supervisor = ComponentExceptionSupervisor.getInstance(); - this.exceptionCollectService.execute(supervisor); - } + public ExceptionCollectService() { + this.exceptionCollectService = + Executors.getExecutorService(EXCEPTION_COLLECTOR_THREAD_NUM, EXCEPTION_COLLECT_FORMAT); + supervisor = ComponentExceptionSupervisor.getInstance(); + this.exceptionCollectService.execute(supervisor); + } - public void shutdown() { + public void shutdown() { - LOGGER.info("shutdown exception collect service"); - if (supervisor != null) { - supervisor.shutdown(); - } - if (exceptionCollectService != null) { - ExecutorUtil.shutdown(exceptionCollectService); - } + LOGGER.info("shutdown exception collect service"); + if (supervisor != null) { + supervisor.shutdown(); + } + if (exceptionCollectService != null) { + ExecutorUtil.shutdown(exceptionCollectService); } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/executor/IPipelineExecutor.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/executor/IPipelineExecutor.java index f2164f0e6..7dc2f6232 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/executor/IPipelineExecutor.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/executor/IPipelineExecutor.java @@ -20,6 +20,7 @@ package org.apache.geaflow.cluster.executor; import java.util.List; + import org.apache.geaflow.pipeline.callback.TaskCallBack; import org.apache.geaflow.pipeline.service.PipelineService; import org.apache.geaflow.pipeline.task.PipelineTask; @@ -27,28 +28,18 @@ public interface IPipelineExecutor { - /** - * Init pipeline executor. - */ - void init(PipelineExecutorContext executorContext); - - /** - * Register view desc list. - */ - void register(List viewDescList); - - /** - * Trigger to run pipeline task. - */ - void runPipelineTask(PipelineTask pipelineTask, TaskCallBack taskCallBack); - - /** - * Trigger to start pipeline service. - */ - void startPipelineService(PipelineService pipelineService); - - /** - * Stop pipeline service server. - */ - void stopPipelineService(PipelineService pipelineService); + /** Init pipeline executor. */ + void init(PipelineExecutorContext executorContext); + + /** Register view desc list. */ + void register(List viewDescList); + + /** Trigger to run pipeline task. */ + void runPipelineTask(PipelineTask pipelineTask, TaskCallBack taskCallBack); + + /** Trigger to start pipeline service. */ + void startPipelineService(PipelineService pipelineService); + + /** Stop pipeline service server. */ + void stopPipelineService(PipelineService pipelineService); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/executor/PipelineExecutorContext.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/executor/PipelineExecutorContext.java index 97d90a05f..aa3ec1fcc 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/executor/PipelineExecutorContext.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/executor/PipelineExecutorContext.java @@ -20,46 +20,48 @@ package org.apache.geaflow.cluster.executor; import java.util.concurrent.atomic.AtomicInteger; + import org.apache.geaflow.cluster.driver.DriverEventDispatcher; import org.apache.geaflow.common.config.Configuration; public class PipelineExecutorContext { - private DriverEventDispatcher eventDispatcher; - private Configuration envConfig; - private String driverId; - private int driverIndex; - private AtomicInteger idGenerator; + private DriverEventDispatcher eventDispatcher; + private Configuration envConfig; + private String driverId; + private int driverIndex; + private AtomicInteger idGenerator; - public PipelineExecutorContext(String driverId, int driverIndex, DriverEventDispatcher eventDispatcher, - Configuration envConfig, - AtomicInteger idGenerator) { - this.eventDispatcher = eventDispatcher; - this.envConfig = envConfig; - this.driverId = driverId; - this.driverIndex = driverIndex; - this.idGenerator = idGenerator; - } + public PipelineExecutorContext( + String driverId, + int driverIndex, + DriverEventDispatcher eventDispatcher, + Configuration envConfig, + AtomicInteger idGenerator) { + this.eventDispatcher = eventDispatcher; + this.envConfig = envConfig; + this.driverId = driverId; + this.driverIndex = driverIndex; + this.idGenerator = idGenerator; + } - public DriverEventDispatcher getEventDispatcher() { - return eventDispatcher; - } + public DriverEventDispatcher getEventDispatcher() { + return eventDispatcher; + } - public Configuration getEnvConfig() { - return this.envConfig; - } + public Configuration getEnvConfig() { + return this.envConfig; + } - public String getDriverId() { - return driverId; - } + public String getDriverId() { + return driverId; + } - public int getDriverIndex() { - return driverIndex; - } + public int getDriverIndex() { + return driverIndex; + } - public AtomicInteger getIdGenerator() { - return idGenerator; - } + public AtomicInteger getIdGenerator() { + return idGenerator; + } } - - diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/executor/PipelineExecutorFactory.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/executor/PipelineExecutorFactory.java index f396a8262..610f32b11 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/executor/PipelineExecutorFactory.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/executor/PipelineExecutorFactory.java @@ -21,6 +21,7 @@ import java.util.Iterator; import java.util.ServiceLoader; + import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.slf4j.Logger; @@ -28,16 +29,16 @@ public class PipelineExecutorFactory { - private static final Logger LOGGER = LoggerFactory.getLogger(PipelineExecutorFactory.class); + private static final Logger LOGGER = LoggerFactory.getLogger(PipelineExecutorFactory.class); - public static IPipelineExecutor createPipelineExecutor() { - ServiceLoader executorLoader = ServiceLoader.load(IPipelineExecutor.class); - Iterator executorIterable = executorLoader.iterator(); - while (executorIterable.hasNext()) { - return executorIterable.next(); - } - LOGGER.error("NOT found IPipelineExecutor implementation"); - throw new GeaflowRuntimeException( - RuntimeErrors.INST.spiNotFoundError(IPipelineExecutor.class.getSimpleName())); + public static IPipelineExecutor createPipelineExecutor() { + ServiceLoader executorLoader = ServiceLoader.load(IPipelineExecutor.class); + Iterator executorIterable = executorLoader.iterator(); + while (executorIterable.hasNext()) { + return executorIterable.next(); } + LOGGER.error("NOT found IPipelineExecutor implementation"); + throw new GeaflowRuntimeException( + RuntimeErrors.INST.spiNotFoundError(IPipelineExecutor.class.getSimpleName())); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/failover/FailoverStrategyFactory.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/failover/FailoverStrategyFactory.java index 8c1b3d76e..eb819197d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/failover/FailoverStrategyFactory.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/failover/FailoverStrategyFactory.java @@ -21,6 +21,7 @@ import java.util.Iterator; import java.util.ServiceLoader; + import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.env.IEnvironment.EnvType; @@ -29,19 +30,20 @@ public class FailoverStrategyFactory { - private static final Logger LOGGER = LoggerFactory.getLogger(FailoverStrategyFactory.class); + private static final Logger LOGGER = LoggerFactory.getLogger(FailoverStrategyFactory.class); - public static IFailoverStrategy loadFailoverStrategy(EnvType envType, String foStrategyType) { - ServiceLoader contextLoader = ServiceLoader.load(IFailoverStrategy.class); - Iterator contextIterable = contextLoader.iterator(); - while (contextIterable.hasNext()) { - IFailoverStrategy strategy = contextIterable.next(); - if (strategy.getEnv() == envType && strategy.getType().name().equalsIgnoreCase(foStrategyType)) { - return strategy; - } - } - LOGGER.error("NOT found IFoStrategy implementation with type:{}", foStrategyType); - throw new GeaflowRuntimeException( - RuntimeErrors.INST.spiNotFoundError(IFailoverStrategy.class.getSimpleName())); + public static IFailoverStrategy loadFailoverStrategy(EnvType envType, String foStrategyType) { + ServiceLoader contextLoader = ServiceLoader.load(IFailoverStrategy.class); + Iterator contextIterable = contextLoader.iterator(); + while (contextIterable.hasNext()) { + IFailoverStrategy strategy = contextIterable.next(); + if (strategy.getEnv() == envType + && strategy.getType().name().equalsIgnoreCase(foStrategyType)) { + return strategy; + } } + LOGGER.error("NOT found IFoStrategy implementation with type:{}", foStrategyType); + throw new GeaflowRuntimeException( + RuntimeErrors.INST.spiNotFoundError(IFailoverStrategy.class.getSimpleName())); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/failover/FailoverStrategyType.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/failover/FailoverStrategyType.java index 75cb875a5..bbab2c0fb 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/failover/FailoverStrategyType.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/failover/FailoverStrategyType.java @@ -21,18 +21,12 @@ public enum FailoverStrategyType { - /** - * Restart all components. - */ - cluster_fo, + /** Restart all components. */ + cluster_fo, - /** - * Component only restarts itself. - */ - component_fo, + /** Component only restarts itself. */ + component_fo, - /** - * Disable failover. - */ - disable_fo + /** Disable failover. */ + disable_fo } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/failover/IFailoverStrategy.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/failover/IFailoverStrategy.java index 30197a351..0e6d09d59 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/failover/IFailoverStrategy.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/failover/IFailoverStrategy.java @@ -24,26 +24,19 @@ public interface IFailoverStrategy { - /** - * Init fo strategy by context. - */ - void init(ClusterContext context); + /** Init fo strategy by context. */ + void init(ClusterContext context); - /** - * Trigger failover by input component id. - */ - void doFailover(int componentId, Throwable cause); + /** Trigger failover by input component id. */ + void doFailover(int componentId, Throwable cause); - /** - * Get failover strategy name. - */ - FailoverStrategyType getType(); - - /** - * Get Env type. - * - * @return - */ - EnvType getEnv(); + /** Get failover strategy name. */ + FailoverStrategyType getType(); + /** + * Get Env type. + * + * @return + */ + EnvType getEnv(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/BarrierHandler.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/BarrierHandler.java index 8553b8c96..44b2ec6d3 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/BarrierHandler.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/BarrierHandler.java @@ -25,79 +25,84 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.shuffle.desc.ShardInputDesc; import org.apache.geaflow.shuffle.message.PipelineBarrier; -/** - * Process the barrier event. - */ +/** Process the barrier event. */ public class BarrierHandler implements Serializable { - private final int taskId; - private final Map inputShards; - private final Map edgeId2sliceNum; - private final int inputSliceNum; - private final Map> edgeBarrierCache; - private final Map> windowBarrierCache; + private final int taskId; + private final Map inputShards; + private final Map edgeId2sliceNum; + private final int inputSliceNum; + private final Map> edgeBarrierCache; + private final Map> windowBarrierCache; - private long finishedWindowId; - private long totalWindowCount; + private long finishedWindowId; + private long totalWindowCount; - public BarrierHandler(int taskId, Map inputShards) { - this.taskId = taskId; - this.inputShards = inputShards; - this.edgeId2sliceNum = inputShards.entrySet().stream() + public BarrierHandler(int taskId, Map inputShards) { + this.taskId = taskId; + this.inputShards = inputShards; + this.edgeId2sliceNum = + inputShards.entrySet().stream() .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getSliceNum())); - this.inputSliceNum = this.edgeId2sliceNum.values().stream().mapToInt(i -> i).sum(); - this.edgeBarrierCache = this.inputShards.entrySet().stream() + this.inputSliceNum = this.edgeId2sliceNum.values().stream().mapToInt(i -> i).sum(); + this.edgeBarrierCache = + this.inputShards.entrySet().stream() .collect(Collectors.toMap(Map.Entry::getKey, e -> new HashSet<>())); - this.windowBarrierCache = new HashMap<>(); - this.finishedWindowId = -1; - this.totalWindowCount = 0; - } + this.windowBarrierCache = new HashMap<>(); + this.finishedWindowId = -1; + this.totalWindowCount = 0; + } - public boolean checkCompleted(PipelineBarrier barrier) { - if (barrier.getWindowId() <= this.finishedWindowId) { - throw new GeaflowRuntimeException(String.format( - "illegal state: taskId %s window %s has finished, last finished window is: %s", - this.taskId, barrier.getWindowId(), this.finishedWindowId)); - } - int edgeId = barrier.getEdgeId(); - long windowId = barrier.getWindowId(); - Set edgeBarriers = this.edgeBarrierCache.computeIfAbsent(edgeId, k -> new HashSet<>()); - Set windowBarriers = this.windowBarrierCache.computeIfAbsent(windowId, k -> new HashSet<>()); - edgeBarriers.add(barrier); - windowBarriers.add(barrier); + public boolean checkCompleted(PipelineBarrier barrier) { + if (barrier.getWindowId() <= this.finishedWindowId) { + throw new GeaflowRuntimeException( + String.format( + "illegal state: taskId %s window %s has finished, last finished window is: %s", + this.taskId, barrier.getWindowId(), this.finishedWindowId)); + } + int edgeId = barrier.getEdgeId(); + long windowId = barrier.getWindowId(); + Set edgeBarriers = + this.edgeBarrierCache.computeIfAbsent(edgeId, k -> new HashSet<>()); + Set windowBarriers = + this.windowBarrierCache.computeIfAbsent(windowId, k -> new HashSet<>()); + edgeBarriers.add(barrier); + windowBarriers.add(barrier); - if (this.inputShards.get(edgeId).isPrefetchWrite()) { - int barrierSize = edgeBarriers.size(); - if (barrierSize == this.edgeId2sliceNum.get(edgeId)) { - this.edgeBarrierCache.remove(edgeId); - edgeBarriers.clear(); - if (this.edgeBarrierCache.isEmpty()) { - this.windowBarrierCache.remove(windowId); - this.finishedWindowId = windowId; - this.totalWindowCount = windowBarriers.stream().mapToLong(PipelineBarrier::getCount).sum(); - windowBarriers.clear(); - } - return true; - } - } else { - int barrierSize = windowBarriers.size(); - if (barrierSize == this.inputSliceNum) { - this.windowBarrierCache.remove(windowId); - this.finishedWindowId = windowId; - this.totalWindowCount = windowBarriers.stream().mapToLong(PipelineBarrier::getCount).sum(); - windowBarriers.clear(); - return true; - } + if (this.inputShards.get(edgeId).isPrefetchWrite()) { + int barrierSize = edgeBarriers.size(); + if (barrierSize == this.edgeId2sliceNum.get(edgeId)) { + this.edgeBarrierCache.remove(edgeId); + edgeBarriers.clear(); + if (this.edgeBarrierCache.isEmpty()) { + this.windowBarrierCache.remove(windowId); + this.finishedWindowId = windowId; + this.totalWindowCount = + windowBarriers.stream().mapToLong(PipelineBarrier::getCount).sum(); + windowBarriers.clear(); } - - return false; + return true; + } + } else { + int barrierSize = windowBarriers.size(); + if (barrierSize == this.inputSliceNum) { + this.windowBarrierCache.remove(windowId); + this.finishedWindowId = windowId; + this.totalWindowCount = windowBarriers.stream().mapToLong(PipelineBarrier::getCount).sum(); + windowBarriers.clear(); + return true; + } } - public long getTotalWindowCount() { - return totalWindowCount; - } + return false; + } + + public long getTotalWindowCount() { + return totalWindowCount; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/CloseFetchRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/CloseFetchRequest.java index 134b6989b..8c8af56ec 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/CloseFetchRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/CloseFetchRequest.java @@ -21,20 +21,19 @@ public class CloseFetchRequest implements IFetchRequest { - private final int taskId; + private final int taskId; - public CloseFetchRequest(int taskId) { - this.taskId = taskId; - } + public CloseFetchRequest(int taskId) { + this.taskId = taskId; + } - @Override - public int getTaskId() { - return this.taskId; - } - - @Override - public RequestType getRequestType() { - return RequestType.CLOSE; - } + @Override + public int getTaskId() { + return this.taskId; + } + @Override + public RequestType getRequestType() { + return RequestType.CLOSE; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/FetchRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/FetchRequest.java index cc26989e8..f3939d97f 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/FetchRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/FetchRequest.java @@ -21,32 +21,31 @@ public class FetchRequest implements IFetchRequest { - private final int taskId; - private final long windowId; - private final long windowCount; - - public FetchRequest(int taskId, long windowId, long windowCount) { - this.taskId = taskId; - this.windowId = windowId; - this.windowCount = windowCount; - } - - @Override - public int getTaskId() { - return this.taskId; - } - - public long getWindowId() { - return this.windowId; - } - - public long getWindowCount() { - return this.windowCount; - } - - @Override - public RequestType getRequestType() { - return RequestType.FETCH; - } - + private final int taskId; + private final long windowId; + private final long windowCount; + + public FetchRequest(int taskId, long windowId, long windowCount) { + this.taskId = taskId; + this.windowId = windowId; + this.windowCount = windowCount; + } + + @Override + public int getTaskId() { + return this.taskId; + } + + public long getWindowId() { + return this.windowId; + } + + public long getWindowCount() { + return this.windowCount; + } + + @Override + public RequestType getRequestType() { + return RequestType.FETCH; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/FetcherRunner.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/FetcherRunner.java index 9fe6f3462..31d496c78 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/FetcherRunner.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/FetcherRunner.java @@ -28,43 +28,42 @@ public class FetcherRunner extends AbstractTaskRunner { - private static final Logger LOGGER = LoggerFactory.getLogger(FetcherRunner.class); + private static final Logger LOGGER = LoggerFactory.getLogger(FetcherRunner.class); - private final PipelineInputFetcher fetcher; + private final PipelineInputFetcher fetcher; - public FetcherRunner(Configuration configuration) { - this.fetcher = new PipelineInputFetcher(configuration); - } - - @Override - protected void process(IFetchRequest task) { - IFetchRequest.RequestType requestType = task.getRequestType(); - switch (requestType) { - case INIT: - this.fetcher.init((InitFetchRequest) task); - break; - case FETCH: - this.fetcher.fetch((FetchRequest) task); - break; - case CLOSE: - this.fetcher.close((CloseFetchRequest) task); - break; - default: - throw new GeaflowRuntimeException( - RuntimeErrors.INST.requestTypeNotSupportError(requestType.name())); - } - } + public FetcherRunner(Configuration configuration) { + this.fetcher = new PipelineInputFetcher(configuration); + } - @Override - public void interrupt() { - LOGGER.info("cancel fetcher runner"); - fetcher.cancel(); + @Override + protected void process(IFetchRequest task) { + IFetchRequest.RequestType requestType = task.getRequestType(); + switch (requestType) { + case INIT: + this.fetcher.init((InitFetchRequest) task); + break; + case FETCH: + this.fetcher.fetch((FetchRequest) task); + break; + case CLOSE: + this.fetcher.close((CloseFetchRequest) task); + break; + default: + throw new GeaflowRuntimeException( + RuntimeErrors.INST.requestTypeNotSupportError(requestType.name())); } + } - @Override - public void shutdown() { - super.shutdown(); - fetcher.close(); - } + @Override + public void interrupt() { + LOGGER.info("cancel fetcher runner"); + fetcher.cancel(); + } + @Override + public void shutdown() { + super.shutdown(); + fetcher.close(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/FetcherService.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/FetcherService.java index cecb67147..ca243904d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/FetcherService.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/FetcherService.java @@ -19,34 +19,37 @@ package org.apache.geaflow.cluster.fetcher; -import com.google.common.base.Preconditions; import java.io.Serializable; + import org.apache.geaflow.cluster.task.service.AbstractTaskService; import org.apache.geaflow.common.config.Configuration; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class FetcherService extends AbstractTaskService implements Serializable { +import com.google.common.base.Preconditions; - private static final Logger LOGGER = LoggerFactory.getLogger(FetcherService.class); +public class FetcherService extends AbstractTaskService + implements Serializable { - private static final String FETCHER_FORMAT = "geaflow-fetcher-%d"; + private static final Logger LOGGER = LoggerFactory.getLogger(FetcherService.class); - private int slots; + private static final String FETCHER_FORMAT = "geaflow-fetcher-%d"; - public FetcherService(int slots, Configuration configuration) { - super(configuration, FETCHER_FORMAT); - this.slots = slots; - } + private int slots; + + public FetcherService(int slots, Configuration configuration) { + super(configuration, FETCHER_FORMAT); + this.slots = slots; + } - @Override - protected FetcherRunner[] buildTaskRunner() { - Preconditions.checkArgument(slots > 0, "fetcher pool should be larger than 0"); - FetcherRunner[] fetcherRunners = new FetcherRunner[slots]; - for (int i = 0; i < slots; i++) { - FetcherRunner runner = new FetcherRunner(configuration); - fetcherRunners[i] = runner; - } - return fetcherRunners; + @Override + protected FetcherRunner[] buildTaskRunner() { + Preconditions.checkArgument(slots > 0, "fetcher pool should be larger than 0"); + FetcherRunner[] fetcherRunners = new FetcherRunner[slots]; + for (int i = 0; i < slots; i++) { + FetcherRunner runner = new FetcherRunner(configuration); + fetcherRunners[i] = runner; } + return fetcherRunners; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/IFetchRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/IFetchRequest.java index 984ccf475..dfa80bfcb 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/IFetchRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/IFetchRequest.java @@ -23,35 +23,27 @@ public interface IFetchRequest extends Serializable { - /** - * Get the task id of this request. - * - * @return task id - */ - int getTaskId(); - - /** - * Get the request type. - * - * @return request type - */ - RequestType getRequestType(); - - enum RequestType { - - /** - * Init fetch request, setup fetch context and input slice meta. - */ - INIT, - /** - * Fetch data of a window id. - */ - FETCH, - /** - * Close the fetch task when data finish. - */ - CLOSE - - } - + /** + * Get the task id of this request. + * + * @return task id + */ + int getTaskId(); + + /** + * Get the request type. + * + * @return request type + */ + RequestType getRequestType(); + + enum RequestType { + + /** Init fetch request, setup fetch context and input slice meta. */ + INIT, + /** Fetch data of a window id. */ + FETCH, + /** Close the fetch task when data finish. */ + CLOSE + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/IInputMessageBuffer.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/IInputMessageBuffer.java index 4c88a45d1..a930c45a9 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/IInputMessageBuffer.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/IInputMessageBuffer.java @@ -26,8 +26,7 @@ public interface IInputMessageBuffer extends IMessageBuffer> { - void onMessage(PipelineMessage message); - - void onBarrier(PipelineBarrier barrier); + void onMessage(PipelineMessage message); + void onBarrier(PipelineBarrier barrier); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/InitFetchRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/InitFetchRequest.java index 9542b39af..ef83e1e2d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/InitFetchRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/InitFetchRequest.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; + import org.apache.geaflow.common.shuffle.BatchPhase; import org.apache.geaflow.common.shuffle.ShuffleAddress; import org.apache.geaflow.common.tuple.Tuple; @@ -38,135 +39,139 @@ public class InitFetchRequest implements IFetchRequest { - private final long pipelineId; - private final String pipelineName; - private final int vertexId; - private final int taskId; - private final int taskIndex; - private final int taskParallelism; - private final String taskName; - private final Map inputShards; - private final Map shufflePhases; - private final int totalSliceNum; - private final Map> inputSlices; - private final List> fetchListeners; - - public InitFetchRequest(long pipelineId, - String pipelineName, - int vertexId, - int taskId, - int taskIndex, - int taskParallelism, - String taskName, - Map inputShards) { - this.pipelineId = pipelineId; - this.pipelineName = pipelineName; - this.vertexId = vertexId; - this.taskId = taskId; - this.taskIndex = taskIndex; - this.taskParallelism = taskParallelism; - this.taskName = taskName; - this.inputShards = inputShards; - this.shufflePhases = this.inputShards.entrySet().stream() + private final long pipelineId; + private final String pipelineName; + private final int vertexId; + private final int taskId; + private final int taskIndex; + private final int taskParallelism; + private final String taskName; + private final Map inputShards; + private final Map shufflePhases; + private final int totalSliceNum; + private final Map> inputSlices; + private final List> fetchListeners; + + public InitFetchRequest( + long pipelineId, + String pipelineName, + int vertexId, + int taskId, + int taskIndex, + int taskParallelism, + String taskName, + Map inputShards) { + this.pipelineId = pipelineId; + this.pipelineName = pipelineName; + this.vertexId = vertexId; + this.taskId = taskId; + this.taskIndex = taskIndex; + this.taskParallelism = taskParallelism; + this.taskName = taskName; + this.inputShards = inputShards; + this.shufflePhases = + this.inputShards.entrySet().stream() .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getBatchPhase())); - Tuple>> tuple = buildSlices(this.inputShards); - this.totalSliceNum = tuple.f0; - this.inputSlices = tuple.f1; - this.fetchListeners = new ArrayList<>(); - } - - @Override - public int getTaskId() { - return this.taskId; - } - - @Override - public RequestType getRequestType() { - return RequestType.INIT; - } - - public long getPipelineId() { - return this.pipelineId; - } - - public String getPipelineName() { - return this.pipelineName; - } - - public int getVertexId() { - return this.vertexId; - } - - public int getTaskIndex() { - return this.taskIndex; - } - - public int getTaskParallelism() { - return this.taskParallelism; - } - - public String getTaskName() { - return this.taskName; - } - - public Map getInputShards() { - return this.inputShards; - } - - public Map getShufflePhases() { - return this.shufflePhases; - } - - public int getSliceNum() { - return this.totalSliceNum; - } - - public Map> getInputSlices() { - return this.inputSlices; - } - - public List> getFetchListeners() { - return this.fetchListeners; - } - - public void addListener(IInputMessageBuffer listener) { - this.fetchListeners.add(listener); - } - - private static Tuple>> buildSlices(Map inputShards) { - int totalSliceNum = 0; - Map> inputSlices = new HashMap<>(); - - for (Map.Entry entry : inputShards.entrySet()) { - Integer edgeId = entry.getKey(); - ShardInputDesc inputDesc = entry.getValue(); - List shards = inputDesc.getInput(); - List slices = new ArrayList<>(); - for (Shard shard : shards) { - for (ISliceMeta slice : shard.getSlices()) { - if (slice instanceof LogicalPipelineSliceMeta) { - // Convert to physical shuffle slice meta. - LogicalPipelineSliceMeta logicalPipelineSliceMeta = (LogicalPipelineSliceMeta) slice; - slices.add(toPhysicalSliceMeta(logicalPipelineSliceMeta)); - } else { - slices.add((PipelineSliceMeta) slice); - } - } - } - inputSlices.put(edgeId, slices); - totalSliceNum += inputDesc.getSliceNum(); + Tuple>> tuple = buildSlices(this.inputShards); + this.totalSliceNum = tuple.f0; + this.inputSlices = tuple.f1; + this.fetchListeners = new ArrayList<>(); + } + + @Override + public int getTaskId() { + return this.taskId; + } + + @Override + public RequestType getRequestType() { + return RequestType.INIT; + } + + public long getPipelineId() { + return this.pipelineId; + } + + public String getPipelineName() { + return this.pipelineName; + } + + public int getVertexId() { + return this.vertexId; + } + + public int getTaskIndex() { + return this.taskIndex; + } + + public int getTaskParallelism() { + return this.taskParallelism; + } + + public String getTaskName() { + return this.taskName; + } + + public Map getInputShards() { + return this.inputShards; + } + + public Map getShufflePhases() { + return this.shufflePhases; + } + + public int getSliceNum() { + return this.totalSliceNum; + } + + public Map> getInputSlices() { + return this.inputSlices; + } + + public List> getFetchListeners() { + return this.fetchListeners; + } + + public void addListener(IInputMessageBuffer listener) { + this.fetchListeners.add(listener); + } + + private static Tuple>> buildSlices( + Map inputShards) { + int totalSliceNum = 0; + Map> inputSlices = new HashMap<>(); + + for (Map.Entry entry : inputShards.entrySet()) { + Integer edgeId = entry.getKey(); + ShardInputDesc inputDesc = entry.getValue(); + List shards = inputDesc.getInput(); + List slices = new ArrayList<>(); + for (Shard shard : shards) { + for (ISliceMeta slice : shard.getSlices()) { + if (slice instanceof LogicalPipelineSliceMeta) { + // Convert to physical shuffle slice meta. + LogicalPipelineSliceMeta logicalPipelineSliceMeta = (LogicalPipelineSliceMeta) slice; + slices.add(toPhysicalSliceMeta(logicalPipelineSliceMeta)); + } else { + slices.add((PipelineSliceMeta) slice); + } } - - return Tuple.of(totalSliceNum, inputSlices); - } - - private static PipelineSliceMeta toPhysicalSliceMeta(LogicalPipelineSliceMeta sliceMeta) { - String containerId = sliceMeta.getContainerId(); - SliceId sliceId = sliceMeta.getSliceId(); - long windowId = sliceMeta.getWindowId(); - ResourceData resourceData = HAServiceFactory.getService().resolveResource(containerId); - return new PipelineSliceMeta(sliceId, windowId, - new ShuffleAddress(resourceData.getHost(), resourceData.getShufflePort())); - } - + } + inputSlices.put(edgeId, slices); + totalSliceNum += inputDesc.getSliceNum(); + } + + return Tuple.of(totalSliceNum, inputSlices); + } + + private static PipelineSliceMeta toPhysicalSliceMeta(LogicalPipelineSliceMeta sliceMeta) { + String containerId = sliceMeta.getContainerId(); + SliceId sliceId = sliceMeta.getSliceId(); + long windowId = sliceMeta.getWindowId(); + ResourceData resourceData = HAServiceFactory.getService().resolveResource(containerId); + return new PipelineSliceMeta( + sliceId, + windowId, + new ShuffleAddress(resourceData.getHost(), resourceData.getShufflePort())); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/PipelineInputFetcher.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/PipelineInputFetcher.java index 937f81894..83aa3d6ef 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/PipelineInputFetcher.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/PipelineInputFetcher.java @@ -23,6 +23,7 @@ import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.cluster.exception.ComponentUncaughtExceptionHandler; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.encoder.IEncoder; @@ -43,202 +44,205 @@ public class PipelineInputFetcher { - private static final Logger LOGGER = LoggerFactory.getLogger(PipelineInputFetcher.class); - - private static final ExecutorService FETCH_EXECUTOR = Executors.getUnboundedExecutorService( - PipelineInputFetcher.class.getSimpleName(), 60, TimeUnit.SECONDS, null, - ComponentUncaughtExceptionHandler.INSTANCE); + private static final Logger LOGGER = LoggerFactory.getLogger(PipelineInputFetcher.class); + + private static final ExecutorService FETCH_EXECUTOR = + Executors.getUnboundedExecutorService( + PipelineInputFetcher.class.getSimpleName(), + 60, + TimeUnit.SECONDS, + null, + ComponentUncaughtExceptionHandler.INSTANCE); + + private final Map taskId2fetchTask = new HashMap<>(); + private final Configuration config; + + public PipelineInputFetcher(Configuration config) { + this.config = config; + } + + /** + * Init input fetcher reader. + * + * @param request init fetch request + */ + public void init(InitFetchRequest request) { + int taskId = request.getTaskId(); + if (this.taskId2fetchTask.containsKey(taskId)) { + throw new GeaflowRuntimeException("task already exists: " + taskId); + } + for (ShardInputDesc inputDesc : request.getInputShards().values()) { + IEncoder encoder = inputDesc.getEncoder(); + if (encoder != null) { + encoder.init(this.config); + } + } + this.taskId2fetchTask.put(taskId, new FetcherTask(this.config, request)); + LOGGER.info("init fetcher task {} {}", request.getTaskName(), request.getShufflePhases()); + } + + /** + * Fetch data according to fetch request and process by worker. + * + * @param request fetch request + */ + protected void fetch(FetchRequest request) { + FetcherTask fetcherTask = this.taskId2fetchTask.get(request.getTaskId()); + if (fetcherTask != null) { + long targetWindowId = request.getWindowId() + request.getWindowCount() - 1; + fetcherTask.updateWindowId(targetWindowId); + if (!fetcherTask.isRunning()) { + fetcherTask.start(); + FETCH_EXECUTOR.execute(fetcherTask); + } + } + } + + public void close(CloseFetchRequest request) { + int taskId = request.getTaskId(); + FetcherTask task = this.taskId2fetchTask.remove(taskId); + if (task != null) { + task.close(); + LOGGER.info( + "close fetcher task {} {}", + task.initFetchRequest.getTaskName(), + task.initFetchRequest.getShufflePhases()); + } + } - private final Map taskId2fetchTask = new HashMap<>(); - private final Configuration config; + public void cancel() { + // TODO Cancel fetching task. + // Shuffle reader should support cancel. + } - public PipelineInputFetcher(Configuration config) { - this.config = config; + /** Close the shuffle reader. */ + public void close() { + for (FetcherTask task : this.taskId2fetchTask.values()) { + task.close(); } + this.taskId2fetchTask.clear(); + } - /** - * Init input fetcher reader. - * - * @param request init fetch request - */ - public void init(InitFetchRequest request) { - int taskId = request.getTaskId(); - if (this.taskId2fetchTask.containsKey(taskId)) { - throw new GeaflowRuntimeException("task already exists: " + taskId); - } - for (ShardInputDesc inputDesc : request.getInputShards().values()) { - IEncoder encoder = inputDesc.getEncoder(); - if (encoder != null) { - encoder.init(this.config); - } - } - this.taskId2fetchTask.put(taskId, new FetcherTask(this.config, request)); - LOGGER.info("init fetcher task {} {}", request.getTaskName(), request.getShufflePhases()); - } + private static class FetcherTask implements Runnable { - /** - * Fetch data according to fetch request and process by worker. - * - * @param request fetch request - */ - protected void fetch(FetchRequest request) { - FetcherTask fetcherTask = this.taskId2fetchTask.get(request.getTaskId()); - if (fetcherTask != null) { - long targetWindowId = request.getWindowId() + request.getWindowCount() - 1; - fetcherTask.updateWindowId(targetWindowId); - if (!fetcherTask.isRunning()) { - fetcherTask.start(); - FETCH_EXECUTOR.execute(fetcherTask); - } - } - } + private static final String READER_NAME_PATTERN = "shuffle-reader-%d[%d/%d]"; + private static final int WAIT_TIME_OUT_MS = 100; - public void close(CloseFetchRequest request) { - int taskId = request.getTaskId(); - FetcherTask task = this.taskId2fetchTask.remove(taskId); - if (task != null) { - task.close(); - LOGGER.info("close fetcher task {} {}", task.initFetchRequest.getTaskName(), - task.initFetchRequest.getShufflePhases()); - } + private final Configuration config; + private final InitFetchRequest initFetchRequest; + private final PipelineReader shuffleReader; + private final IInputMessageBuffer[] fetchListeners; + private final BarrierHandler barrierHandler; + private final String name; + + private volatile boolean running; + private volatile long targetWindowId; + + private FetcherTask(Configuration config, InitFetchRequest request) { + this.config = config; + this.initFetchRequest = request; + this.shuffleReader = (PipelineReader) ShuffleManager.getInstance().loadShuffleReader(); + this.shuffleReader.init(this.buildReaderContext()); + this.fetchListeners = request.getFetchListeners().toArray(new IInputMessageBuffer[] {}); + this.barrierHandler = new BarrierHandler(request.getTaskId(), request.getInputShards()); + this.name = + String.format( + READER_NAME_PATTERN, + request.getTaskId(), + request.getTaskIndex(), + request.getTaskParallelism()); } - public void cancel() { - // TODO Cancel fetching task. - // Shuffle reader should support cancel. + public void start() { + this.running = true; } - /** - * Close the shuffle reader. - */ - public void close() { - for (FetcherTask task : this.taskId2fetchTask.values()) { - task.close(); - } - this.taskId2fetchTask.clear(); + public boolean isRunning() { + return this.running; } - private static class FetcherTask implements Runnable { - - private static final String READER_NAME_PATTERN = "shuffle-reader-%d[%d/%d]"; - private static final int WAIT_TIME_OUT_MS = 100; - - private final Configuration config; - private final InitFetchRequest initFetchRequest; - private final PipelineReader shuffleReader; - private final IInputMessageBuffer[] fetchListeners; - private final BarrierHandler barrierHandler; - private final String name; - - private volatile boolean running; - private volatile long targetWindowId; - - private FetcherTask(Configuration config, InitFetchRequest request) { - this.config = config; - this.initFetchRequest = request; - this.shuffleReader = (PipelineReader) ShuffleManager.getInstance().loadShuffleReader(); - this.shuffleReader.init(this.buildReaderContext()); - this.fetchListeners = request.getFetchListeners() - .toArray(new IInputMessageBuffer[]{}); - this.barrierHandler = new BarrierHandler(request.getTaskId(), request.getInputShards()); - this.name = String.format(READER_NAME_PATTERN, request.getTaskId(), - request.getTaskIndex(), request.getTaskParallelism()); - } - - public void start() { - this.running = true; + public void updateWindowId(long windowId) { + if (this.targetWindowId < windowId) { + this.targetWindowId = windowId; + synchronized (this) { + this.notifyAll(); } + } + } - public boolean isRunning() { - return this.running; - } + @Override + public void run() { + Thread.currentThread().setName(this.name); + try { + this.fetch(); + } catch (GeaflowRuntimeException e) { + LOGGER.error("fetcher task err with window id {} {}", this.targetWindowId, this.name, e); + throw e; + } catch (Throwable e) { + LOGGER.error("fetcher task err with window id {} {}", this.targetWindowId, this.name, e); + throw new GeaflowRuntimeException(e.getMessage(), e); + } + } - public void updateWindowId(long windowId) { - if (this.targetWindowId < windowId) { - this.targetWindowId = windowId; - synchronized (this) { - this.notifyAll(); - } - } + public void fetch() throws InterruptedException { + while (this.running) { + this.shuffleReader.fetch(this.targetWindowId); + if (!this.shuffleReader.hasNext()) { + synchronized (this) { + this.wait(WAIT_TIME_OUT_MS); + } + continue; } - - @Override - public void run() { - Thread.currentThread().setName(this.name); - try { - this.fetch(); - } catch (GeaflowRuntimeException e) { - LOGGER.error("fetcher task err with window id {} {}", this.targetWindowId, - this.name, e); - throw e; - } catch (Throwable e) { - LOGGER.error("fetcher task err with window id {} {}", this.targetWindowId, - this.name, e); - throw new GeaflowRuntimeException(e.getMessage(), e); + PipelineEvent event = this.shuffleReader.next(); + if (event != null) { + if (event instanceof PipelineMessage) { + PipelineMessage message = (PipelineMessage) event; + for (IInputMessageBuffer listener : this.fetchListeners) { + listener.onMessage(message); } - } - - public void fetch() throws InterruptedException { - while (this.running) { - this.shuffleReader.fetch(this.targetWindowId); - if (!this.shuffleReader.hasNext()) { - synchronized (this) { - this.wait(WAIT_TIME_OUT_MS); - } - continue; - } - PipelineEvent event = this.shuffleReader.next(); - if (event != null) { - if (event instanceof PipelineMessage) { - PipelineMessage message = (PipelineMessage) event; - for (IInputMessageBuffer listener : this.fetchListeners) { - listener.onMessage(message); - } - } else { - PipelineBarrier barrier = (PipelineBarrier) event; - if (this.barrierHandler.checkCompleted(barrier)) { - long windowCount = this.barrierHandler.getTotalWindowCount(); - this.handleMetrics(); - PipelineBarrier windowBarrier = new PipelineBarrier( - barrier.getWindowId(), barrier.getEdgeId(), windowCount); - for (IInputMessageBuffer listener : this.fetchListeners) { - listener.onBarrier(windowBarrier); - } - } - } - } + } else { + PipelineBarrier barrier = (PipelineBarrier) event; + if (this.barrierHandler.checkCompleted(barrier)) { + long windowCount = this.barrierHandler.getTotalWindowCount(); + this.handleMetrics(); + PipelineBarrier windowBarrier = + new PipelineBarrier(barrier.getWindowId(), barrier.getEdgeId(), windowCount); + for (IInputMessageBuffer listener : this.fetchListeners) { + listener.onBarrier(windowBarrier); + } } - LOGGER.info("fetcher task finish window id {} {}", this.targetWindowId, this.name); - } - - private ReaderContext buildReaderContext() { - ReaderContext context = new ReaderContext(); - context.setConfig(this.config); - context.setVertexId(this.initFetchRequest.getVertexId()); - context.setTaskName(this.initFetchRequest.getTaskName()); - context.setInputShardMap(this.initFetchRequest.getInputShards()); - context.setInputSlices(this.initFetchRequest.getInputSlices()); - context.setSliceNum(this.initFetchRequest.getSliceNum()); - return context; + } } + } + LOGGER.info("fetcher task finish window id {} {}", this.targetWindowId, this.name); + } - private void handleMetrics() { - ShuffleReadMetrics shuffleReadMetrics = this.shuffleReader.getShuffleReadMetrics(); - for (IInputMessageBuffer listener : this.fetchListeners) { - if (listener instanceof AbstractMessageBuffer) { - EventMetrics eventMetrics = ((AbstractMessageBuffer) listener).getEventMetrics(); - eventMetrics.addShuffleReadBytes(shuffleReadMetrics.getDecodeBytes()); - } - } - } + private ReaderContext buildReaderContext() { + ReaderContext context = new ReaderContext(); + context.setConfig(this.config); + context.setVertexId(this.initFetchRequest.getVertexId()); + context.setTaskName(this.initFetchRequest.getTaskName()); + context.setInputShardMap(this.initFetchRequest.getInputShards()); + context.setInputSlices(this.initFetchRequest.getInputSlices()); + context.setSliceNum(this.initFetchRequest.getSliceNum()); + return context; + } - public void close() { - this.running = false; - if (this.shuffleReader != null) { - this.shuffleReader.close(); - } + private void handleMetrics() { + ShuffleReadMetrics shuffleReadMetrics = this.shuffleReader.getShuffleReadMetrics(); + for (IInputMessageBuffer listener : this.fetchListeners) { + if (listener instanceof AbstractMessageBuffer) { + EventMetrics eventMetrics = ((AbstractMessageBuffer) listener).getEventMetrics(); + eventMetrics.addShuffleReadBytes(shuffleReadMetrics.getDecodeBytes()); } - + } } -} \ No newline at end of file + public void close() { + this.running = false; + if (this.shuffleReader != null) { + this.shuffleReader.close(); + } + } + } +} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/PrefetchMessageBuffer.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/PrefetchMessageBuffer.java index 05b1b66df..5ed64e1cb 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/PrefetchMessageBuffer.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/fetcher/PrefetchMessageBuffer.java @@ -21,6 +21,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.cluster.protocol.InputMessage; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.shuffle.message.PipelineBarrier; @@ -35,57 +36,56 @@ public class PrefetchMessageBuffer implements IInputMessageBuffer { - // The latch to notify data finish. - private final CountDownLatch latch = new CountDownLatch(1); - // Data slice. - private final SpillablePipelineSlice slice; - // Data edge id. - private final int edgeId; + // The latch to notify data finish. + private final CountDownLatch latch = new CountDownLatch(1); + // Data slice. + private final SpillablePipelineSlice slice; + // Data edge id. + private final int edgeId; - public PrefetchMessageBuffer(String logTag, SliceId sliceId) { - this.slice = new SpillablePipelineSlice(logTag, sliceId); - this.edgeId = sliceId.getEdgeId(); - SliceManager sliceManager = ShuffleManager.getInstance().getSliceManager(); - sliceManager.register(sliceId, this.slice); - } + public PrefetchMessageBuffer(String logTag, SliceId sliceId) { + this.slice = new SpillablePipelineSlice(logTag, sliceId); + this.edgeId = sliceId.getEdgeId(); + SliceManager sliceManager = ShuffleManager.getInstance().getSliceManager(); + sliceManager.register(sliceId, this.slice); + } - @Override - public void offer(InputMessage message) { - throw new UnsupportedOperationException(); - } + @Override + public void offer(InputMessage message) { + throw new UnsupportedOperationException(); + } - @Override - public InputMessage poll(long timeout, TimeUnit unit) { - throw new UnsupportedOperationException(); - } + @Override + public InputMessage poll(long timeout, TimeUnit unit) { + throw new UnsupportedOperationException(); + } - @Override - public void onMessage(PipelineMessage message) { - if (message.getEdgeId() != this.edgeId) { - return; - } - AbstractMessageIterator iterator = (AbstractMessageIterator) message.getMessageIterator(); - OutBuffer outBuffer = iterator.getOutBuffer(); - long windowId = message.getRecordArgs().getWindowId(); - this.slice.add(new PipeBuffer(outBuffer, windowId)); + @Override + public void onMessage(PipelineMessage message) { + if (message.getEdgeId() != this.edgeId) { + return; } + AbstractMessageIterator iterator = (AbstractMessageIterator) message.getMessageIterator(); + OutBuffer outBuffer = iterator.getOutBuffer(); + long windowId = message.getRecordArgs().getWindowId(); + this.slice.add(new PipeBuffer(outBuffer, windowId)); + } - @Override - public void onBarrier(PipelineBarrier barrier) { - if (barrier.getEdgeId() != this.edgeId) { - return; - } - this.slice.add(new PipeBuffer(barrier.getWindowId(), (int) barrier.getCount(), true)); - this.slice.flush(); - this.latch.countDown(); + @Override + public void onBarrier(PipelineBarrier barrier) { + if (barrier.getEdgeId() != this.edgeId) { + return; } + this.slice.add(new PipeBuffer(barrier.getWindowId(), (int) barrier.getCount(), true)); + this.slice.flush(); + this.latch.countDown(); + } - public void waitUtilFinish() { - try { - this.latch.await(); - } catch (InterruptedException e) { - throw new GeaflowRuntimeException(e); - } + public void waitUtilFinish() { + try { + this.latch.await(); + } catch (InterruptedException e) { + throw new GeaflowRuntimeException(e); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/heartbeat/HeartbeatClient.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/heartbeat/HeartbeatClient.java index c485e7d98..138bb62f2 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/heartbeat/HeartbeatClient.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/heartbeat/HeartbeatClient.java @@ -20,6 +20,7 @@ package org.apache.geaflow.cluster.heartbeat; import java.io.Serializable; + import org.apache.geaflow.cluster.common.ComponentInfo; import org.apache.geaflow.cluster.rpc.RpcClient; import org.apache.geaflow.cluster.rpc.RpcEndpointRef.RpcCallback; @@ -32,70 +33,78 @@ import org.slf4j.LoggerFactory; public class HeartbeatClient implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(HeartbeatClient.class); + private static final Logger LOGGER = LoggerFactory.getLogger(HeartbeatClient.class); - private final int containerId; - private final String containerName; - private final Configuration config; - private HeartbeatSender heartbeatSender; - private final ProcessStatsCollector statsCollector; - private T info; - private String masterId; + private final int containerId; + private final String containerName; + private final Configuration config; + private HeartbeatSender heartbeatSender; + private final ProcessStatsCollector statsCollector; + private T info; + private String masterId; - public HeartbeatClient(int containerId, String containerName, Configuration config) { - this.containerId = containerId; - this.containerName = containerName; - this.config = config; - this.statsCollector = StatsCollectorFactory.getInstance().getProcessStatsCollector(); - } + public HeartbeatClient(int containerId, String containerName, Configuration config) { + this.containerId = containerId; + this.containerName = containerName; + this.config = config; + this.statsCollector = StatsCollectorFactory.getInstance().getProcessStatsCollector(); + } - public void init(String masterId, T info) { - this.masterId = masterId; - this.info = info; - registerToMaster(); - startHeartBeat(masterId); - } + public void init(String masterId, T info) { + this.masterId = masterId; + this.info = info; + registerToMaster(); + startHeartBeat(masterId); + } - public void registerToMaster() { - LOGGER.info("register: {}", info); - RpcClient.init(config); - doRegister(masterId, info); - } + public void registerToMaster() { + LOGGER.info("register: {}", info); + RpcClient.init(config); + doRegister(masterId, info); + } - private void doRegister(String masterId, T info) { - RpcClient.getInstance().registerContainer(masterId, info, new RpcCallback() { + private void doRegister(String masterId, T info) { + RpcClient.getInstance() + .registerContainer( + masterId, + info, + new RpcCallback() { - @Override - public void onSuccess(RegisterResponse event) { + @Override + public void onSuccess(RegisterResponse event) { LOGGER.info("{} registered success:{}", containerName, event.getSuccess()); - } + } - @Override - public void onFailure(Throwable t) { + @Override + public void onFailure(Throwable t) { LOGGER.error("register info failed", t); - } - }); - } + } + }); + } - public void startHeartBeat(String masterId) { - LOGGER.info("start {} heartbeat", containerName); - this.heartbeatSender = new HeartbeatSender(masterId, () -> { - Heartbeat heartbeat = null; - if (containerName != null) { + public void startHeartBeat(String masterId) { + LOGGER.info("start {} heartbeat", containerName); + this.heartbeatSender = + new HeartbeatSender( + masterId, + () -> { + Heartbeat heartbeat = null; + if (containerName != null) { heartbeat = new Heartbeat(containerId); heartbeat.setContainerName(containerName); heartbeat.setProcessMetrics(statsCollector.collect()); - } - return heartbeat; - }, config, this); + } + return heartbeat; + }, + config, + this); - this.heartbeatSender.start(); - } + this.heartbeatSender.start(); + } - public void close() { - if (heartbeatSender != null) { - heartbeatSender.close(); - } + public void close() { + if (heartbeatSender != null) { + heartbeatSender.close(); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/heartbeat/HeartbeatManager.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/heartbeat/HeartbeatManager.java index 6ccc08e0e..8e74ce7c7 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/heartbeat/HeartbeatManager.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/heartbeat/HeartbeatManager.java @@ -36,6 +36,7 @@ import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.cluster.clustermanager.AbstractClusterManager; import org.apache.geaflow.cluster.clustermanager.IClusterManager; import org.apache.geaflow.cluster.common.ComponentInfo; @@ -57,211 +58,215 @@ public class HeartbeatManager { - private static final Logger LOGGER = LoggerFactory.getLogger(HeartbeatManager.class); + private static final Logger LOGGER = LoggerFactory.getLogger(HeartbeatManager.class); - private final long heartbeatTimeoutMs; - private final long heartbeatReportExpiredMs; - private final Map senderMap; - private final AbstractClusterManager clusterManager; - private final ScheduledFuture timeoutFuture; - private final ScheduledFuture reportFuture; - private final ScheduledExecutorService checkTimeoutService; - private final ScheduledExecutorService heartbeatReportService; - private final IStatsWriter statsWriter; - private ScheduledFuture checkFuture; - private volatile boolean isRunning = true; + private final long heartbeatTimeoutMs; + private final long heartbeatReportExpiredMs; + private final Map senderMap; + private final AbstractClusterManager clusterManager; + private final ScheduledFuture timeoutFuture; + private final ScheduledFuture reportFuture; + private final ScheduledExecutorService checkTimeoutService; + private final ScheduledExecutorService heartbeatReportService; + private final IStatsWriter statsWriter; + private ScheduledFuture checkFuture; + private volatile boolean isRunning = true; - public HeartbeatManager(Configuration config, IClusterManager clusterManager) { - this.senderMap = new ConcurrentHashMap<>(); - this.heartbeatTimeoutMs = config.getInteger(HEARTBEAT_TIMEOUT_MS); - int heartbeatReportMs = config.getInteger(HEARTBEAT_REPORT_INTERVAL_MS); - int defaultReportExpiredMs = (int) ((heartbeatTimeoutMs + heartbeatReportMs) * 1.2); - this.heartbeatReportExpiredMs = config.getInteger(HEARTBEAT_REPORT_EXPIRED_MS, defaultReportExpiredMs); + public HeartbeatManager(Configuration config, IClusterManager clusterManager) { + this.senderMap = new ConcurrentHashMap<>(); + this.heartbeatTimeoutMs = config.getInteger(HEARTBEAT_TIMEOUT_MS); + int heartbeatReportMs = config.getInteger(HEARTBEAT_REPORT_INTERVAL_MS); + int defaultReportExpiredMs = (int) ((heartbeatTimeoutMs + heartbeatReportMs) * 1.2); + this.heartbeatReportExpiredMs = + config.getInteger(HEARTBEAT_REPORT_EXPIRED_MS, defaultReportExpiredMs); - boolean supervisorEnable = config.getBoolean(SUPERVISOR_ENABLE); - int corePoolSize = supervisorEnable ? 2 : 1; - this.checkTimeoutService = new ScheduledThreadPoolExecutor(corePoolSize, - ThreadUtil.namedThreadFactory(true, "heartbeat-manager")); - int initDelayMs = config.getInteger(HEARTBEAT_INITIAL_DELAY_MS); - this.timeoutFuture = checkTimeoutService.scheduleAtFixedRate(this::checkHeartBeat, - initDelayMs, heartbeatTimeoutMs, TimeUnit.MILLISECONDS); - if (supervisorEnable) { - long heartbeatCheckMs = config.getInteger(HEARTBEAT_INTERVAL_MS); - this.checkFuture = checkTimeoutService.scheduleAtFixedRate(this::checkWorkerHealth, - heartbeatCheckMs, heartbeatCheckMs, TimeUnit.MILLISECONDS); - } + boolean supervisorEnable = config.getBoolean(SUPERVISOR_ENABLE); + int corePoolSize = supervisorEnable ? 2 : 1; + this.checkTimeoutService = + new ScheduledThreadPoolExecutor( + corePoolSize, ThreadUtil.namedThreadFactory(true, "heartbeat-manager")); + int initDelayMs = config.getInteger(HEARTBEAT_INITIAL_DELAY_MS); + this.timeoutFuture = + checkTimeoutService.scheduleAtFixedRate( + this::checkHeartBeat, initDelayMs, heartbeatTimeoutMs, TimeUnit.MILLISECONDS); + if (supervisorEnable) { + long heartbeatCheckMs = config.getInteger(HEARTBEAT_INTERVAL_MS); + this.checkFuture = + checkTimeoutService.scheduleAtFixedRate( + this::checkWorkerHealth, heartbeatCheckMs, heartbeatCheckMs, TimeUnit.MILLISECONDS); + } - this.heartbeatReportService = new ScheduledThreadPoolExecutor(1, - ThreadUtil.namedThreadFactory(true, "heartbeat-report")); - this.reportFuture = heartbeatReportService - .scheduleAtFixedRate(this::reportHeartbeat, heartbeatReportMs, heartbeatReportMs, - TimeUnit.MILLISECONDS); + this.heartbeatReportService = + new ScheduledThreadPoolExecutor(1, ThreadUtil.namedThreadFactory(true, "heartbeat-report")); + this.reportFuture = + heartbeatReportService.scheduleAtFixedRate( + this::reportHeartbeat, heartbeatReportMs, heartbeatReportMs, TimeUnit.MILLISECONDS); - this.clusterManager = (AbstractClusterManager) clusterManager; - this.statsWriter = StatsCollectorFactory.init(config).getStatsWriter(); - } + this.clusterManager = (AbstractClusterManager) clusterManager; + this.statsWriter = StatsCollectorFactory.init(config).getStatsWriter(); + } - public HeartbeatResponse receivedHeartbeat(Heartbeat heartbeat) { - senderMap.put(heartbeat.getContainerId(), heartbeat); - boolean registered = isRegistered(heartbeat.getContainerId()); - return HeartbeatResponse.newBuilder().setSuccess(true).setRegistered(registered).build(); - } + public HeartbeatResponse receivedHeartbeat(Heartbeat heartbeat) { + senderMap.put(heartbeat.getContainerId(), heartbeat); + boolean registered = isRegistered(heartbeat.getContainerId()); + return HeartbeatResponse.newBuilder().setSuccess(true).setRegistered(registered).build(); + } - public void registerMasterHeartbeat(ComponentInfo masterInfo) { - this.statsWriter.addMetric(masterInfo.getName(), masterInfo); - } + public void registerMasterHeartbeat(ComponentInfo masterInfo) { + this.statsWriter.addMetric(masterInfo.getName(), masterInfo); + } - void checkHeartBeat() { - try { - long checkTime = System.currentTimeMillis(); - checkTimeout(clusterManager.getContainerIds(), checkTime); - checkTimeout(clusterManager.getDriverIds(), checkTime); - } catch (Throwable e) { - LOGGER.warn("Catch unexpect error", e); - } + void checkHeartBeat() { + try { + long checkTime = System.currentTimeMillis(); + checkTimeout(clusterManager.getContainerIds(), checkTime); + checkTimeout(clusterManager.getDriverIds(), checkTime); + } catch (Throwable e) { + LOGGER.warn("Catch unexpect error", e); } + } - private void checkTimeout(Map map, long checkTime) { - for (Map.Entry entry : map.entrySet()) { - int componentId = entry.getKey(); - Heartbeat heartbeat = senderMap.get(componentId); - if (heartbeat == null) { - if (isRegistered(componentId)) { - LOGGER.warn("{} heartbeat is not received", entry.getValue()); - } else { - LOGGER.warn("{} is not registered", entry.getValue()); - } - } else if (checkTime > heartbeat.getTimestamp() + heartbeatTimeoutMs) { - String message = String.format("%s heartbeat is lost", entry.getValue()); - LOGGER.error(message); - doFailover(componentId, new GeaflowHeartbeatException(message)); - } + private void checkTimeout(Map map, long checkTime) { + for (Map.Entry entry : map.entrySet()) { + int componentId = entry.getKey(); + Heartbeat heartbeat = senderMap.get(componentId); + if (heartbeat == null) { + if (isRegistered(componentId)) { + LOGGER.warn("{} heartbeat is not received", entry.getValue()); + } else { + LOGGER.warn("{} is not registered", entry.getValue()); } + } else if (checkTime > heartbeat.getTimestamp() + heartbeatTimeoutMs) { + String message = String.format("%s heartbeat is lost", entry.getValue()); + LOGGER.error(message); + doFailover(componentId, new GeaflowHeartbeatException(message)); + } } + } - public void reportHeartbeat() { - HeartbeatInfo heartbeatInfo = buildHeartbeatInfo(); - StatsCollectorFactory collectorFactory = StatsCollectorFactory.getInstance(); - if (collectorFactory != null) { - collectorFactory.getHeartbeatCollector().reportHeartbeat(heartbeatInfo); - } + public void reportHeartbeat() { + HeartbeatInfo heartbeatInfo = buildHeartbeatInfo(); + StatsCollectorFactory collectorFactory = StatsCollectorFactory.getInstance(); + if (collectorFactory != null) { + collectorFactory.getHeartbeatCollector().reportHeartbeat(heartbeatInfo); } + } - void checkWorkerHealth() { - try { - checkWorkerHealth(clusterManager.getContainerIds()); - checkWorkerHealth(clusterManager.getDriverIds()); - } catch (Throwable e) { - LOGGER.warn("Check container healthy error: {}", e.getMessage(), e); - } + void checkWorkerHealth() { + try { + checkWorkerHealth(clusterManager.getContainerIds()); + checkWorkerHealth(clusterManager.getDriverIds()); + } catch (Throwable e) { + LOGGER.warn("Check container healthy error: {}", e.getMessage(), e); } + } - private void checkWorkerHealth(Map map) { - for (Map.Entry entry : map.entrySet()) { - String name = entry.getValue(); - try { - StatusResponse response = RpcClient.getInstance().queryWorkerStatusBySupervisor(name); - if (!response.getIsAlive()) { - String message = String.format("worker %s is not alive", name); - LOGGER.error(message); - doFailover(entry.getKey(), new GeaflowHeartbeatException(message)); - } - } catch (Throwable e) { - String message = String.format("connect to supervisor of %s failed: %s", name, - e.getMessage()); - LOGGER.error(message, e); - doFailover(entry.getKey(), new GeaflowHeartbeatException(message, e)); - } + private void checkWorkerHealth(Map map) { + for (Map.Entry entry : map.entrySet()) { + String name = entry.getValue(); + try { + StatusResponse response = RpcClient.getInstance().queryWorkerStatusBySupervisor(name); + if (!response.getIsAlive()) { + String message = String.format("worker %s is not alive", name); + LOGGER.error(message); + doFailover(entry.getKey(), new GeaflowHeartbeatException(message)); } + } catch (Throwable e) { + String message = + String.format("connect to supervisor of %s failed: %s", name, e.getMessage()); + LOGGER.error(message, e); + doFailover(entry.getKey(), new GeaflowHeartbeatException(message, e)); + } } + } - void doFailover(int componentId, Throwable e) { - clusterManager.doFailover(componentId, e); - } + void doFailover(int componentId, Throwable e) { + clusterManager.doFailover(componentId, e); + } - protected boolean isRegistered(int componentId) { - AbstractClusterManager cm = clusterManager; - return cm.getContainerInfos().containsKey(componentId) || cm.getDriverInfos() - .containsKey(componentId); - } + protected boolean isRegistered(int componentId) { + AbstractClusterManager cm = clusterManager; + return cm.getContainerInfos().containsKey(componentId) + || cm.getDriverInfos().containsKey(componentId); + } - protected HeartbeatInfo buildHeartbeatInfo() { - Map heartbeatMap = getHeartBeatMap(); - Map containerMap = clusterManager.getContainerInfos(); - Map containerIndex = clusterManager.getContainerIds(); - int totalContainerNum = containerIndex.size(); - List containerList = new ArrayList<>(); - int activeContainers = 0; - for (Map.Entry entry : containerMap.entrySet()) { - ContainerHeartbeatInfo containerHeartbeatInfo = new ContainerHeartbeatInfo(); - containerHeartbeatInfo.setId(entry.getKey()); - ContainerInfo info = entry.getValue(); - containerHeartbeatInfo.setName(info.getName()); - containerHeartbeatInfo.setHost(info.getHost()); - containerHeartbeatInfo.setPid(info.getPid()); - Heartbeat heartbeat = heartbeatMap.get(entry.getKey()); - if (heartbeat != null) { - containerHeartbeatInfo.setLastTimestamp(heartbeat.getTimestamp()); - containerHeartbeatInfo.setMetrics(heartbeat.getProcessMetrics()); - activeContainers++; - } - containerList.add(containerHeartbeatInfo); - } - HeartbeatInfo heartbeatInfo = new HeartbeatInfo(); - heartbeatInfo.setExpiredTimeMs(heartbeatReportExpiredMs); - heartbeatInfo.setTotalNum(totalContainerNum); - heartbeatInfo.setActiveNum(activeContainers); - heartbeatInfo.setContainers(containerList); - return heartbeatInfo; + protected HeartbeatInfo buildHeartbeatInfo() { + Map heartbeatMap = getHeartBeatMap(); + Map containerMap = clusterManager.getContainerInfos(); + Map containerIndex = clusterManager.getContainerIds(); + int totalContainerNum = containerIndex.size(); + List containerList = new ArrayList<>(); + int activeContainers = 0; + for (Map.Entry entry : containerMap.entrySet()) { + ContainerHeartbeatInfo containerHeartbeatInfo = new ContainerHeartbeatInfo(); + containerHeartbeatInfo.setId(entry.getKey()); + ContainerInfo info = entry.getValue(); + containerHeartbeatInfo.setName(info.getName()); + containerHeartbeatInfo.setHost(info.getHost()); + containerHeartbeatInfo.setPid(info.getPid()); + Heartbeat heartbeat = heartbeatMap.get(entry.getKey()); + if (heartbeat != null) { + containerHeartbeatInfo.setLastTimestamp(heartbeat.getTimestamp()); + containerHeartbeatInfo.setMetrics(heartbeat.getProcessMetrics()); + activeContainers++; + } + containerList.add(containerHeartbeatInfo); } + HeartbeatInfo heartbeatInfo = new HeartbeatInfo(); + heartbeatInfo.setExpiredTimeMs(heartbeatReportExpiredMs); + heartbeatInfo.setTotalNum(totalContainerNum); + heartbeatInfo.setActiveNum(activeContainers); + heartbeatInfo.setContainers(containerList); + return heartbeatInfo; + } - public Map getHeartBeatMap() { - return senderMap; - } + public Map getHeartBeatMap() { + return senderMap; + } - public Set getActiveContainerIds() { - Map containerIdMap = clusterManager.getContainerIds(); - return getActiveComponentIds(containerIdMap); - } + public Set getActiveContainerIds() { + Map containerIdMap = clusterManager.getContainerIds(); + return getActiveComponentIds(containerIdMap); + } - public Set getActiveDriverIds() { - Map driverIdMap = clusterManager.getDriverIds(); - return getActiveComponentIds(driverIdMap); - } + public Set getActiveDriverIds() { + Map driverIdMap = clusterManager.getDriverIds(); + return getActiveComponentIds(driverIdMap); + } - private Set getActiveComponentIds(Map map) { - long checkTime = System.currentTimeMillis(); - Set activeComponentIds = new HashSet<>(); - for (Map.Entry entry : map.entrySet()) { - int componentId = entry.getKey(); - Heartbeat heartbeat = senderMap.get(componentId); - if (heartbeat != null && checkTime <= heartbeat.getTimestamp() + heartbeatTimeoutMs) { - activeComponentIds.add(componentId); - } - } - return activeComponentIds; + private Set getActiveComponentIds(Map map) { + long checkTime = System.currentTimeMillis(); + Set activeComponentIds = new HashSet<>(); + for (Map.Entry entry : map.entrySet()) { + int componentId = entry.getKey(); + Heartbeat heartbeat = senderMap.get(componentId); + if (heartbeat != null && checkTime <= heartbeat.getTimestamp() + heartbeatTimeoutMs) { + activeComponentIds.add(componentId); + } } + return activeComponentIds; + } - public void close() { - if (!isRunning) { - return; - } - isRunning = false; - if (timeoutFuture != null) { - timeoutFuture.cancel(true); - } - if (checkFuture != null) { - checkFuture.cancel(true); - } - if (checkTimeoutService != null) { - ExecutorUtil.shutdown(checkTimeoutService); - } - if (reportFuture != null) { - reportFuture.cancel(true); - } - if (heartbeatReportService != null) { - ExecutorUtil.shutdown(heartbeatReportService); - } - LOGGER.info("HeartbeatManager is closed"); + public void close() { + if (!isRunning) { + return; + } + isRunning = false; + if (timeoutFuture != null) { + timeoutFuture.cancel(true); + } + if (checkFuture != null) { + checkFuture.cancel(true); + } + if (checkTimeoutService != null) { + ExecutorUtil.shutdown(checkTimeoutService); + } + if (reportFuture != null) { + reportFuture.cancel(true); + } + if (heartbeatReportService != null) { + ExecutorUtil.shutdown(heartbeatReportService); } + LOGGER.info("HeartbeatManager is closed"); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/heartbeat/HeartbeatSender.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/heartbeat/HeartbeatSender.java index 0b10a1492..0df6b6757 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/heartbeat/HeartbeatSender.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/heartbeat/HeartbeatSender.java @@ -28,6 +28,7 @@ import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; + import org.apache.geaflow.cluster.rpc.RpcClient; import org.apache.geaflow.cluster.rpc.RpcEndpointRef.RpcCallback; import org.apache.geaflow.common.config.Configuration; @@ -39,63 +40,74 @@ import org.slf4j.LoggerFactory; public class HeartbeatSender implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(HeartbeatSender.class); + private static final Logger LOGGER = LoggerFactory.getLogger(HeartbeatSender.class); - private final String masterId; - private final ScheduledExecutorService scheduledService; - private final Supplier heartbeatTrigger; - private final HeartbeatClient heartbeatClient; - private final long initialDelayMs; - private final long intervalMs; - private ScheduledFuture scheduledFuture; + private final String masterId; + private final ScheduledExecutorService scheduledService; + private final Supplier heartbeatTrigger; + private final HeartbeatClient heartbeatClient; + private final long initialDelayMs; + private final long intervalMs; + private ScheduledFuture scheduledFuture; - public HeartbeatSender(String masterId, Supplier heartbeatTrigger, - Configuration config, HeartbeatClient heartbeatClient) { - this.masterId = masterId; - this.heartbeatTrigger = heartbeatTrigger; - this.scheduledService = new ScheduledThreadPoolExecutor(1, - ThreadUtil.namedThreadFactory(true, "heartbeat-sender")); - this.heartbeatClient = heartbeatClient; - this.initialDelayMs = config.getInteger(HEARTBEAT_INITIAL_DELAY_MS); - this.intervalMs = config.getInteger(HEARTBEAT_INTERVAL_MS); - } + public HeartbeatSender( + String masterId, + Supplier heartbeatTrigger, + Configuration config, + HeartbeatClient heartbeatClient) { + this.masterId = masterId; + this.heartbeatTrigger = heartbeatTrigger; + this.scheduledService = + new ScheduledThreadPoolExecutor(1, ThreadUtil.namedThreadFactory(true, "heartbeat-sender")); + this.heartbeatClient = heartbeatClient; + this.initialDelayMs = config.getInteger(HEARTBEAT_INITIAL_DELAY_MS); + this.intervalMs = config.getInteger(HEARTBEAT_INTERVAL_MS); + } - public void start() { - scheduledFuture = scheduledService.scheduleWithFixedDelay(() -> { - Heartbeat message = null; - try { + public void start() { + scheduledFuture = + scheduledService.scheduleWithFixedDelay( + () -> { + Heartbeat message = null; + try { message = heartbeatTrigger.get(); if (message != null) { - RpcClient.getInstance().sendHeartBeat(masterId, message, new RpcCallback() { + RpcClient.getInstance() + .sendHeartBeat( + masterId, + message, + new RpcCallback() { - @Override - public void onSuccess(HeartbeatResponse event) { - if (!event.getRegistered()) { + @Override + public void onSuccess(HeartbeatResponse event) { + if (!event.getRegistered()) { LOGGER.warn("Heartbeat is not registered."); heartbeatClient.registerToMaster(); + } } - } - @Override - public void onFailure(Throwable t) { - LOGGER.error("Send heartbeat failed.", t); - } - }); + @Override + public void onFailure(Throwable t) { + LOGGER.error("Send heartbeat failed.", t); + } + }); } - } catch (Throwable e) { + } catch (Throwable e) { LOGGER.error("send heartbeat {} failed", message, e); - } - }, initialDelayMs, intervalMs, TimeUnit.MILLISECONDS); - } + } + }, + initialDelayMs, + intervalMs, + TimeUnit.MILLISECONDS); + } - public void close() { - LOGGER.info("Close heartbeat sender"); - if (scheduledFuture != null) { - scheduledFuture.cancel(true); - } - if (scheduledService != null) { - ExecutorUtil.shutdown(scheduledService); - } + public void close() { + LOGGER.info("Close heartbeat sender"); + if (scheduledFuture != null) { + scheduledFuture.cancel(true); } - + if (scheduledService != null) { + ExecutorUtil.shutdown(scheduledService); + } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/AbstractMaster.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/AbstractMaster.java index b530c7759..14b9dcd4b 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/AbstractMaster.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/AbstractMaster.java @@ -22,8 +22,8 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.AGENT_HTTP_PORT; import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.MASTER_HTTP_PORT; -import com.baidu.brpc.server.RpcServerOptions; import java.util.Map; + import org.apache.geaflow.cluster.clustermanager.ClusterContext; import org.apache.geaflow.cluster.clustermanager.ClusterInfo; import org.apache.geaflow.cluster.clustermanager.IClusterManager; @@ -50,139 +50,138 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class AbstractMaster extends AbstractComponent implements IMaster { - - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractMaster.class); - - protected IResourceManager resourceManager; - protected IClusterManager clusterManager; - protected HeartbeatManager heartbeatManager; - protected ConnectAddress masterAddress; - protected int agentPort; - protected int httpPort; - protected HttpServer httpServer; - protected ClusterContext clusterContext; - protected ILeaderElectionService leaderElectionService; - - public AbstractMaster() { - this(0); - } - - public AbstractMaster(int rpcPort) { - super(rpcPort); - } - - @Override - public void init(MasterContext context) { - super.init(context.getId(), context.getConfiguration().getMasterId(), - context.getConfiguration()); - - this.clusterManager = context.getClusterManager(); - this.clusterContext = context.getClusterContext(); - this.heartbeatManager = new HeartbeatManager(configuration, clusterManager); - this.resourceManager = new DefaultResourceManager(clusterManager); - this.clusterContext.setHeartbeatManager(heartbeatManager); - this.httpPort = configuration.getInteger(MASTER_HTTP_PORT); - - initEnv(context); - } - - protected void initEnv(MasterContext context) { - this.clusterManager.init(clusterContext); - startRpcService(clusterManager, resourceManager); - - // Register service info and initialize cluster. - registerHAService(); - // Start container. - resourceManager.init(ResourceManagerContext.build(context, clusterContext)); - - if (configuration.getBoolean(ExecutionConfigKeys.HTTP_REST_SERVICE_ENABLE)) { - this.agentPort = startAgent(); - httpServer = new HttpServer(configuration, clusterManager, heartbeatManager, - resourceManager, buildMasterInfo()); - httpServer.start(); - } - registerHeartbeat(); - } - - public void initLeaderElectionService(ILeaderContender contender, - Configuration configuration, - int componentId) { - leaderElectionService = LeaderElectionServiceFactory.loadElectionService(configuration); - leaderElectionService.init(configuration, String.valueOf(componentId)); - leaderElectionService.open(contender); - LOGGER.info("Leader election service enabled for master."); - } - - public void waitForLeaderElection() throws InterruptedException { - LOGGER.info("Wait for becoming a leader..."); - synchronized (leaderElectionService) { - leaderElectionService.wait(); - } - } - - public void notifyLeaderElection() { - synchronized (leaderElectionService) { - leaderElectionService.notify(); - } - } +import com.baidu.brpc.server.RpcServerOptions; - protected void startRpcService(IClusterManager clusterManager, - IResourceManager resourceManager) { - RpcServerOptions serverOptions = ConfigurableServerOption.build(configuration); - int port = PortUtil.getPort(rpcPort); - this.rpcService = new RpcServiceImpl(port, serverOptions); - this.rpcService.addEndpoint(new MasterEndpoint(this, clusterManager)); - this.rpcService.addEndpoint(new ResourceManagerEndpoint(resourceManager)); - this.rpcPort = rpcService.startService(); - this.masterAddress = new ConnectAddress(ProcessUtil.getHostIp(), httpPort); - } +public abstract class AbstractMaster extends AbstractComponent implements IMaster { - public ClusterInfo startCluster() { - ClusterInfo clusterInfo = new ClusterInfo(); - clusterInfo.setMasterAddress(masterAddress); - Map driverAddresses = clusterManager.startDrivers(); - clusterInfo.setDriverAddresses(driverAddresses); - LOGGER.info("init cluster with info: {}", clusterInfo); - return clusterInfo; + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractMaster.class); + + protected IResourceManager resourceManager; + protected IClusterManager clusterManager; + protected HeartbeatManager heartbeatManager; + protected ConnectAddress masterAddress; + protected int agentPort; + protected int httpPort; + protected HttpServer httpServer; + protected ClusterContext clusterContext; + protected ILeaderElectionService leaderElectionService; + + public AbstractMaster() { + this(0); + } + + public AbstractMaster(int rpcPort) { + super(rpcPort); + } + + @Override + public void init(MasterContext context) { + super.init( + context.getId(), context.getConfiguration().getMasterId(), context.getConfiguration()); + + this.clusterManager = context.getClusterManager(); + this.clusterContext = context.getClusterContext(); + this.heartbeatManager = new HeartbeatManager(configuration, clusterManager); + this.resourceManager = new DefaultResourceManager(clusterManager); + this.clusterContext.setHeartbeatManager(heartbeatManager); + this.httpPort = configuration.getInteger(MASTER_HTTP_PORT); + + initEnv(context); + } + + protected void initEnv(MasterContext context) { + this.clusterManager.init(clusterContext); + startRpcService(clusterManager, resourceManager); + + // Register service info and initialize cluster. + registerHAService(); + // Start container. + resourceManager.init(ResourceManagerContext.build(context, clusterContext)); + + if (configuration.getBoolean(ExecutionConfigKeys.HTTP_REST_SERVICE_ENABLE)) { + this.agentPort = startAgent(); + httpServer = + new HttpServer( + configuration, clusterManager, heartbeatManager, resourceManager, buildMasterInfo()); + httpServer.start(); } - - private int startAgent() { - int port = PortUtil.getPort(configuration.getInteger(AGENT_HTTP_PORT)); - AgentWebServer agentServer = new AgentWebServer(port, configuration); - agentServer.start(); - return port; + registerHeartbeat(); + } + + public void initLeaderElectionService( + ILeaderContender contender, Configuration configuration, int componentId) { + leaderElectionService = LeaderElectionServiceFactory.loadElectionService(configuration); + leaderElectionService.init(configuration, String.valueOf(componentId)); + leaderElectionService.open(contender); + LOGGER.info("Leader election service enabled for master."); + } + + public void waitForLeaderElection() throws InterruptedException { + LOGGER.info("Wait for becoming a leader..."); + synchronized (leaderElectionService) { + leaderElectionService.wait(); } + } - protected MasterInfo buildMasterInfo() { - MasterInfo componentInfo = new MasterInfo(); - componentInfo.setId(id); - componentInfo.setName(name); - componentInfo.setHost(ProcessUtil.getHostIp()); - componentInfo.setPid(ProcessUtil.getProcessId()); - componentInfo.setRpcPort(rpcPort); - componentInfo.setAgentPort(agentPort); - componentInfo.setHttpPort(httpPort); - return componentInfo; + public void notifyLeaderElection() { + synchronized (leaderElectionService) { + leaderElectionService.notify(); } - - protected void registerHeartbeat() { - ComponentInfo componentInfo = buildMasterInfo(); - heartbeatManager.registerMasterHeartbeat(componentInfo); + } + + protected void startRpcService(IClusterManager clusterManager, IResourceManager resourceManager) { + RpcServerOptions serverOptions = ConfigurableServerOption.build(configuration); + int port = PortUtil.getPort(rpcPort); + this.rpcService = new RpcServiceImpl(port, serverOptions); + this.rpcService.addEndpoint(new MasterEndpoint(this, clusterManager)); + this.rpcService.addEndpoint(new ResourceManagerEndpoint(resourceManager)); + this.rpcPort = rpcService.startService(); + this.masterAddress = new ConnectAddress(ProcessUtil.getHostIp(), httpPort); + } + + public ClusterInfo startCluster() { + ClusterInfo clusterInfo = new ClusterInfo(); + clusterInfo.setMasterAddress(masterAddress); + Map driverAddresses = clusterManager.startDrivers(); + clusterInfo.setDriverAddresses(driverAddresses); + LOGGER.info("init cluster with info: {}", clusterInfo); + return clusterInfo; + } + + private int startAgent() { + int port = PortUtil.getPort(configuration.getInteger(AGENT_HTTP_PORT)); + AgentWebServer agentServer = new AgentWebServer(port, configuration); + agentServer.start(); + return port; + } + + protected MasterInfo buildMasterInfo() { + MasterInfo componentInfo = new MasterInfo(); + componentInfo.setId(id); + componentInfo.setName(name); + componentInfo.setHost(ProcessUtil.getHostIp()); + componentInfo.setPid(ProcessUtil.getProcessId()); + componentInfo.setRpcPort(rpcPort); + componentInfo.setAgentPort(agentPort); + componentInfo.setHttpPort(httpPort); + return componentInfo; + } + + protected void registerHeartbeat() { + ComponentInfo componentInfo = buildMasterInfo(); + heartbeatManager.registerMasterHeartbeat(componentInfo); + } + + @Override + public void close() { + super.close(); + clusterManager.close(); + if (heartbeatManager != null) { + heartbeatManager.close(); } - - @Override - public void close() { - super.close(); - clusterManager.close(); - if (heartbeatManager != null) { - heartbeatManager.close(); - } - if (httpServer != null) { - httpServer.stop(); - } - LOGGER.info("master {} closed", name); + if (httpServer != null) { + httpServer.stop(); } - + LOGGER.info("master {} closed", name); + } } - diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/IMaster.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/IMaster.java index d93283338..736b7716c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/IMaster.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/IMaster.java @@ -23,14 +23,9 @@ public interface IMaster extends Serializable { - /** - * Initialize master. - */ - void init(MasterContext context); - - /** - * Close master. - */ - void close(); + /** Initialize master. */ + void init(MasterContext context); + /** Close master. */ + void close(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/JobMaster.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/JobMaster.java index 2e1db1884..95c5b7f7d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/JobMaster.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/JobMaster.java @@ -19,6 +19,4 @@ package org.apache.geaflow.cluster.master; -public class JobMaster extends AbstractMaster { - -} +public class JobMaster extends AbstractMaster {} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/MasterContext.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/MasterContext.java index 29d98eb5e..6d2ec2c71 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/MasterContext.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/MasterContext.java @@ -29,44 +29,43 @@ public class MasterContext extends ReliableContainerContext { - private Configuration configuration; - private IClusterManager clusterManager; - private ClusterContext clusterContext; + private Configuration configuration; + private IClusterManager clusterManager; + private ClusterContext clusterContext; - public MasterContext(Configuration configuration) { - super(DEFAULT_MASTER_ID, ClusterConstants.getMasterName(), configuration); - this.configuration = configuration; - } + public MasterContext(Configuration configuration) { + super(DEFAULT_MASTER_ID, ClusterConstants.getMasterName(), configuration); + this.configuration = configuration; + } - public Configuration getConfiguration() { - return configuration; - } + public Configuration getConfiguration() { + return configuration; + } - public void setConfiguration(Configuration configuration) { - this.configuration = configuration; - } + public void setConfiguration(Configuration configuration) { + this.configuration = configuration; + } - public IClusterManager getClusterManager() { - return clusterManager; - } + public IClusterManager getClusterManager() { + return clusterManager; + } - public void setClusterManager(IClusterManager clusterManager) { - this.clusterManager = clusterManager; - } + public void setClusterManager(IClusterManager clusterManager) { + this.clusterManager = clusterManager; + } - public ClusterContext getClusterContext() { - return clusterContext; - } + public ClusterContext getClusterContext() { + return clusterContext; + } - @Override - public boolean isRecover() { - return clusterContext.isRecover(); - } - - @Override - public void load() { - this.clusterContext = new ClusterContext(configuration); - this.clusterContext.load(); - } + @Override + public boolean isRecover() { + return clusterContext.isRecover(); + } + @Override + public void load() { + this.clusterContext = new ClusterContext(configuration); + this.clusterContext.load(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/MasterFactory.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/MasterFactory.java index bf8a154e1..6b51bf0cb 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/MasterFactory.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/MasterFactory.java @@ -24,16 +24,15 @@ public class MasterFactory { - public static synchronized AbstractMaster create(Configuration configuration) { + public static synchronized AbstractMaster create(Configuration configuration) { - AbstractMaster master; - if (JobMode.getJobMode(configuration) == JobMode.OLAP_SERVICE - || JobMode.getJobMode(configuration) == JobMode.STATE_SERVICE) { - master = new ServiceMaster(); - } else { - master = new JobMaster(); - } - return master; + AbstractMaster master; + if (JobMode.getJobMode(configuration) == JobMode.OLAP_SERVICE + || JobMode.getJobMode(configuration) == JobMode.STATE_SERVICE) { + master = new ServiceMaster(); + } else { + master = new JobMaster(); } - + return master; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/MasterInfo.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/MasterInfo.java index 87b2b7365..9cf81ad21 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/MasterInfo.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/MasterInfo.java @@ -22,20 +22,38 @@ import org.apache.geaflow.cluster.common.ComponentInfo; public class MasterInfo extends ComponentInfo { - private int httpPort; + private int httpPort; - public int getHttpPort() { - return httpPort; - } + public int getHttpPort() { + return httpPort; + } - public void setHttpPort(int httpPort) { - this.httpPort = httpPort; - } + public void setHttpPort(int httpPort) { + this.httpPort = httpPort; + } - @Override - public String toString() { - return "MasterInfo{" + "httpPort=" + httpPort + ", id=" + id + ", name='" + name + '\'' - + ", host='" + host + '\'' + ", pid=" + pid + ", rpcPort=" + rpcPort + ", metricPort=" - + metricPort + ", agentPort=" + agentPort + "} " + super.toString(); - } + @Override + public String toString() { + return "MasterInfo{" + + "httpPort=" + + httpPort + + ", id=" + + id + + ", name='" + + name + + '\'' + + ", host='" + + host + + '\'' + + ", pid=" + + pid + + ", rpcPort=" + + rpcPort + + ", metricPort=" + + metricPort + + ", agentPort=" + + agentPort + + "} " + + super.toString(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/ServiceMaster.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/ServiceMaster.java index 200bd8b8e..8e226d050 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/ServiceMaster.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/master/ServiceMaster.java @@ -24,27 +24,27 @@ public class ServiceMaster extends AbstractMaster { - private MetaServer metaServer; - - @Override - public void init(MasterContext context) { - super.init(context); - } - - @Override - protected void initEnv(MasterContext context) { - // Start meta server. - this.metaServer = new MetaServer(); - MetaServerContext metaServerContext = new MetaServerContext(context.getConfiguration()); - metaServerContext.setRecover(clusterContext.isRecover()); - this.metaServer.init(metaServerContext); - - super.initEnv(context); - } - - @Override - public void close() { - super.close(); - metaServer.close(); - } + private MetaServer metaServer; + + @Override + public void init(MasterContext context) { + super.init(context); + } + + @Override + protected void initEnv(MasterContext context) { + // Start meta server. + this.metaServer = new MetaServer(); + MetaServerContext metaServerContext = new MetaServerContext(context.getConfiguration()); + metaServerContext.setRecover(clusterContext.isRecover()); + this.metaServer.init(metaServerContext); + + super.initEnv(context); + } + + @Override + public void close() { + super.close(); + metaServer.close(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/AbstractMessage.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/AbstractMessage.java index 666434148..bc8fdd700 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/AbstractMessage.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/AbstractMessage.java @@ -21,19 +21,18 @@ public abstract class AbstractMessage implements IMessage { - private final long windowId; + private final long windowId; - public AbstractMessage(long windowId) { - this.windowId = windowId; - } + public AbstractMessage(long windowId) { + this.windowId = windowId; + } - public long getWindowId() { - return this.windowId; - } - - @Override - public EventType getEventType() { - return EventType.MESSAGE; - } + public long getWindowId() { + return this.windowId; + } + @Override + public EventType getEventType() { + return EventType.MESSAGE; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/EventType.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/EventType.java index e8bc81967..7e67dc3b3 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/EventType.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/EventType.java @@ -21,142 +21,96 @@ import java.io.Serializable; -/** - * The enum class for all event types. - */ +/** The enum class for all event types. */ public enum EventType implements Serializable { - /** - * A basic cycle execution command to assign resource and initialize cycle context. - */ - INIT_CYCLE, - - /** - * A basic cycle execution command to assign loop data exchange for iteration. - */ - INIT_ITERATION, - - /** - * A basic cycle execution command to prefetch data. - */ - PREFETCH, - - /** - * A basic cycle execution command to load graph vertex and edge.. - */ - PRE_GRAPH_PROCESS, - - /** - * A basic cycle execution command to reset input or output shuffle descriptors. - */ - REASSIGN, - - /** - * A basic cycle command to start one cycle iteration. - * The command send from scheduler to cycle heads. - */ - EXECUTE_COMPUTE, - - /** - * A basic cycle command to that execute iteration with agg info. - */ - ITERATIVE_COMPUTE_WITH_AGGREGATE, - - /** - * A basic cycle command to start one cycle iteration to fetch source. - * The command send from scheduler to cycle heads. - */ - LAUNCH_SOURCE, - - /** - * A basic cycle command to start one the first round of iteration cycle. - * The command send from scheduler to cycle heads. - */ - EXECUTE_FIRST_ITERATION, - - /** - * A basic cycle command denotes the end of a cycle iteration. - * The command send from cycle tails to scheduler. - */ - DONE, - - /** - * A basic cycle command that transfer between cycle intermediate tasks. - */ - BARRIER, - - /** - * A basic cycle command to finish iteration cycle. - */ - FINISH_ITERATION, - - /** - * A basic cycle command to clean up cycle context after all iterations finished. - */ - CLEAN_CYCLE, - - /** - * A compose command contains a collection of basic cycle command. - */ - COMPOSE, - - /** - * Clean env after all cycle of a pipeline task finished. - */ - CLEAN_ENV, - - /** - * Rollback to certain iteration id. - */ - ROLLBACK, - - /** - * Interrupt running task. - */ - INTERRUPT_TASK, - - /** - * Open container event. - */ - OPEN_CONTAINER, - - /** - * Response event by open container. - */ - OPEN_CONTAINER_RESPONSE, - - /** - * A inner message event to trigger processing worker to execute. - */ - MESSAGE, - - /** - * Create task. - */ - CREATE_TASK, - - /** - * Destroy task. - */ - DESTROY_TASK, - - /** - * A basic cycle command to create worker. - */ - CREATE_WORKER, - - /** - * Stash current worker that can be reused on demand. - */ - STASH_WORKER, - - /** - * Pop worker to reuse worker for following events. - */ - POP_WORKER, - - /** - * Collect execute result data. - */ - COLLECT_DATA, + /** A basic cycle execution command to assign resource and initialize cycle context. */ + INIT_CYCLE, + + /** A basic cycle execution command to assign loop data exchange for iteration. */ + INIT_ITERATION, + + /** A basic cycle execution command to prefetch data. */ + PREFETCH, + + /** A basic cycle execution command to load graph vertex and edge.. */ + PRE_GRAPH_PROCESS, + + /** A basic cycle execution command to reset input or output shuffle descriptors. */ + REASSIGN, + + /** + * A basic cycle command to start one cycle iteration. The command send from scheduler to cycle + * heads. + */ + EXECUTE_COMPUTE, + + /** A basic cycle command to that execute iteration with agg info. */ + ITERATIVE_COMPUTE_WITH_AGGREGATE, + + /** + * A basic cycle command to start one cycle iteration to fetch source. The command send from + * scheduler to cycle heads. + */ + LAUNCH_SOURCE, + + /** + * A basic cycle command to start one the first round of iteration cycle. The command send from + * scheduler to cycle heads. + */ + EXECUTE_FIRST_ITERATION, + + /** + * A basic cycle command denotes the end of a cycle iteration. The command send from cycle tails + * to scheduler. + */ + DONE, + + /** A basic cycle command that transfer between cycle intermediate tasks. */ + BARRIER, + + /** A basic cycle command to finish iteration cycle. */ + FINISH_ITERATION, + + /** A basic cycle command to clean up cycle context after all iterations finished. */ + CLEAN_CYCLE, + + /** A compose command contains a collection of basic cycle command. */ + COMPOSE, + + /** Clean env after all cycle of a pipeline task finished. */ + CLEAN_ENV, + + /** Rollback to certain iteration id. */ + ROLLBACK, + + /** Interrupt running task. */ + INTERRUPT_TASK, + + /** Open container event. */ + OPEN_CONTAINER, + + /** Response event by open container. */ + OPEN_CONTAINER_RESPONSE, + + /** A inner message event to trigger processing worker to execute. */ + MESSAGE, + + /** Create task. */ + CREATE_TASK, + + /** Destroy task. */ + DESTROY_TASK, + + /** A basic cycle command to create worker. */ + CREATE_WORKER, + + /** Stash current worker that can be reused on demand. */ + STASH_WORKER, + + /** Pop worker to reuse worker for following events. */ + POP_WORKER, + + /** Collect execute result data. */ + COLLECT_DATA, } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/ICommand.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/ICommand.java index f137406b1..925961cb5 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/ICommand.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/ICommand.java @@ -19,13 +19,9 @@ package org.apache.geaflow.cluster.protocol; -/** - * A command is the event of control flow of among cycle scheduling. - */ +/** A command is the event of control flow of among cycle scheduling. */ public interface ICommand extends IEvent { - /** - * Define the target worker to execute the event. - */ - int getWorkerId(); + /** Define the target worker to execute the event. */ + int getWorkerId(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IComposeEvent.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IComposeEvent.java index d206704d6..6d5cbb18d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IComposeEvent.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IComposeEvent.java @@ -21,14 +21,12 @@ import java.util.List; -/** - * A compose command contains a collection of basic cycle commands. - */ +/** A compose command contains a collection of basic cycle commands. */ public interface IComposeEvent extends IEvent { - /** - * A collection of basic cycle command, e.g: - * {@link EventType#INIT_CYCLE}, {@link EventType#EXECUTE_COMPUTE} and {@link EventType#BARRIER} - */ - List getEventList(); + /** + * A collection of basic cycle command, e.g: {@link EventType#INIT_CYCLE}, {@link + * EventType#EXECUTE_COMPUTE} and {@link EventType#BARRIER} + */ + List getEventList(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/ICycleResponseEvent.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/ICycleResponseEvent.java index 0812741a3..f403a5256 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/ICycleResponseEvent.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/ICycleResponseEvent.java @@ -21,13 +21,9 @@ public interface ICycleResponseEvent extends IEvent { - /** - * The cycle id of the callback event. - */ - int getCycleId(); + /** The cycle id of the callback event. */ + int getCycleId(); - /** - * Returns the scheduler id of the callback event. - */ - long getSchedulerId(); + /** Returns the scheduler id of the callback event. */ + long getSchedulerId(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IEvent.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IEvent.java index 3c70d38f2..9610904d6 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IEvent.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IEvent.java @@ -21,14 +21,9 @@ import java.io.Serializable; -/** - * An interface that defined the data/control flow of the cycle scheduling. - */ +/** An interface that defined the data/control flow of the cycle scheduling. */ public interface IEvent extends Serializable { - /** - * Return the type of event. - */ - EventType getEventType(); - + /** Return the type of event. */ + EventType getEventType(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IEventContext.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IEventContext.java index 8850a06ca..fef25e597 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IEventContext.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IEventContext.java @@ -21,5 +21,4 @@ import java.io.Serializable; -public interface IEventContext extends Serializable { -} +public interface IEventContext extends Serializable {} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IExecutableCommand.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IExecutableCommand.java index ab80e1d9b..ecb4fc502 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IExecutableCommand.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IExecutableCommand.java @@ -21,18 +21,12 @@ import org.apache.geaflow.cluster.task.ITaskContext; -/** - * A executable command is the event of executable command. - */ +/** A executable command is the event of executable command. */ public interface IExecutableCommand extends ICommand { - /** - * Define compute logic process for the corresponding command. - */ - void execute(ITaskContext taskContext); + /** Define compute logic process for the corresponding command. */ + void execute(ITaskContext taskContext); - /** - * Interrupt current running command. - */ - void interrupt(); + /** Interrupt current running command. */ + void interrupt(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IHighAvailableEvent.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IHighAvailableEvent.java index 598cf228e..ef84beb77 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IHighAvailableEvent.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IHighAvailableEvent.java @@ -21,13 +21,9 @@ import org.apache.geaflow.ha.runtime.HighAvailableLevel; -/** - * A recoverable event that save and replay after fail over. - */ +/** A recoverable event that save and replay after fail over. */ public interface IHighAvailableEvent { - /** - * Returns the HA level. - */ - HighAvailableLevel getHaLevel(); + /** Returns the HA level. */ + HighAvailableLevel getHaLevel(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IMessage.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IMessage.java index 994865a52..27c50a53b 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IMessage.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/IMessage.java @@ -19,16 +19,13 @@ package org.apache.geaflow.cluster.protocol; -/** - * A message is the event of data flow among cycle scheduling. - */ +/** A message is the event of data flow among cycle scheduling. */ public interface IMessage extends IEvent { - /** - * Get the message content. - * - * @return message - */ - T getMessage(); - + /** + * Get the message content. + * + * @return message + */ + T getMessage(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/InputMessage.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/InputMessage.java index ffc658d96..b4c1469fe 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/InputMessage.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/InputMessage.java @@ -21,33 +21,30 @@ import org.apache.geaflow.shuffle.message.PipelineMessage; -/** - * A message which is processed by worker. - */ +/** A message which is processed by worker. */ public class InputMessage extends AbstractMessage> { - private final PipelineMessage message; - private final long windowCount; - - public InputMessage(PipelineMessage message) { - super(message.getWindowId()); - this.message = message; - this.windowCount = -1; - } - - public InputMessage(long windowId, long windowCount) { - super(windowId); - this.message = null; - this.windowCount = windowCount; - } - - @Override - public PipelineMessage getMessage() { - return message; - } - - public long getWindowCount() { - return windowCount; - } - + private final PipelineMessage message; + private final long windowCount; + + public InputMessage(PipelineMessage message) { + super(message.getWindowId()); + this.message = message; + this.windowCount = -1; + } + + public InputMessage(long windowId, long windowCount) { + super(windowId); + this.message = null; + this.windowCount = windowCount; + } + + @Override + public PipelineMessage getMessage() { + return message; + } + + public long getWindowCount() { + return windowCount; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/OpenContainerEvent.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/OpenContainerEvent.java index 3d7a80294..43ab32164 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/OpenContainerEvent.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/OpenContainerEvent.java @@ -21,25 +21,24 @@ public class OpenContainerEvent implements IEvent { - private int executorNum; + private int executorNum; - public OpenContainerEvent() { - } + public OpenContainerEvent() {} - public OpenContainerEvent(int executorNum) { - this.executorNum = executorNum; - } + public OpenContainerEvent(int executorNum) { + this.executorNum = executorNum; + } - public int getExecutorNum() { - return executorNum; - } + public int getExecutorNum() { + return executorNum; + } - public void setExecutorNum(int executorNum) { - this.executorNum = executorNum; - } + public void setExecutorNum(int executorNum) { + this.executorNum = executorNum; + } - @Override - public EventType getEventType() { - return EventType.OPEN_CONTAINER; - } + @Override + public EventType getEventType() { + return EventType.OPEN_CONTAINER; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/OpenContainerResponseEvent.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/OpenContainerResponseEvent.java index d10ffe74e..51e633872 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/OpenContainerResponseEvent.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/OpenContainerResponseEvent.java @@ -21,45 +21,44 @@ public class OpenContainerResponseEvent implements IEvent { - private boolean success; - private int containerId; - private int firstWorkerIndex; + private boolean success; + private int containerId; + private int firstWorkerIndex; - public OpenContainerResponseEvent() { - } + public OpenContainerResponseEvent() {} - public OpenContainerResponseEvent(boolean success) { - this.success = success; - } + public OpenContainerResponseEvent(boolean success) { + this.success = success; + } - public OpenContainerResponseEvent(int containerId, int firstWorkerIndex) { - this.success = true; - this.containerId = containerId; - this.firstWorkerIndex = firstWorkerIndex; - } + public OpenContainerResponseEvent(int containerId, int firstWorkerIndex) { + this.success = true; + this.containerId = containerId; + this.firstWorkerIndex = firstWorkerIndex; + } - public boolean isSuccess() { - return success; - } + public boolean isSuccess() { + return success; + } - public int getContainerId() { - return containerId; - } + public int getContainerId() { + return containerId; + } - public void setContainerId(int containerId) { - this.containerId = containerId; - } + public void setContainerId(int containerId) { + this.containerId = containerId; + } - public int getFirstWorkerIndex() { - return firstWorkerIndex; - } + public int getFirstWorkerIndex() { + return firstWorkerIndex; + } - public void setFirstWorkerIndex(int firstWorkerIndex) { - this.firstWorkerIndex = firstWorkerIndex; - } + public void setFirstWorkerIndex(int firstWorkerIndex) { + this.firstWorkerIndex = firstWorkerIndex; + } - @Override - public EventType getEventType() { - return EventType.OPEN_CONTAINER_RESPONSE; - } + @Override + public EventType getEventType() { + return EventType.OPEN_CONTAINER_RESPONSE; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/OutputMessage.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/OutputMessage.java index 6e55c9762..3734493f8 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/OutputMessage.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/OutputMessage.java @@ -23,36 +23,35 @@ public class OutputMessage extends AbstractMessage> { - private final int targetChannel; - private final List data; - private final boolean isBarrier; - - public OutputMessage(long windowId, int targetChannel, List data) { - super(windowId); - this.targetChannel = targetChannel; - this.data = data; - this.isBarrier = data == null; - } - - @Override - public List getMessage() { - return this.data; - } - - public int getTargetChannel() { - return this.targetChannel; - } - - public boolean isBarrier() { - return this.isBarrier; - } - - public static OutputMessage data(long windowId, int targetChannel, List data) { - return new OutputMessage<>(windowId, targetChannel, data); - } - - public static OutputMessage barrier(long windowId) { - return new OutputMessage<>(windowId, -1, null); - } - + private final int targetChannel; + private final List data; + private final boolean isBarrier; + + public OutputMessage(long windowId, int targetChannel, List data) { + super(windowId); + this.targetChannel = targetChannel; + this.data = data; + this.isBarrier = data == null; + } + + @Override + public List getMessage() { + return this.data; + } + + public int getTargetChannel() { + return this.targetChannel; + } + + public boolean isBarrier() { + return this.isBarrier; + } + + public static OutputMessage data(long windowId, int targetChannel, List data) { + return new OutputMessage<>(windowId, targetChannel, data); + } + + public static OutputMessage barrier(long windowId) { + return new OutputMessage<>(windowId, -1, null); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/ScheduleStateType.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/ScheduleStateType.java index 6f860950d..e9559c2a6 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/ScheduleStateType.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/protocol/ScheduleStateType.java @@ -23,59 +23,36 @@ public enum ScheduleStateType implements Serializable { - /** - * Start state. - */ - START, + /** Start state. */ + START, - /** - * Shuffle prefetch state. - */ - PREFETCH, + /** Shuffle prefetch state. */ + PREFETCH, - /** - * Shuffle finish prefetch state. - */ - FINISH_PREFETCH, + /** Shuffle finish prefetch state. */ + FINISH_PREFETCH, - /** - * Init state. - */ - INIT, + /** Init state. */ + INIT, - /** - * Init graph and execute first iteration state. - */ - ITERATION_INIT, + /** Init graph and execute first iteration state. */ + ITERATION_INIT, - /** - * Execute state. - */ - EXECUTE_COMPUTE, + /** Execute state. */ + EXECUTE_COMPUTE, - /** - * Finish iteration state. - */ - ITERATION_FINISH, + /** Finish iteration state. */ + ITERATION_FINISH, - /** - * Clean cycle. - */ - CLEAN_CYCLE, + /** Clean cycle. */ + CLEAN_CYCLE, - /** - * Rollback state. - */ - ROLLBACK, + /** Rollback state. */ + ROLLBACK, - /** - * Compose state. - */ - COMPOSE, - - /** - * End state. - */ - END, + /** Compose state. */ + COMPOSE, + /** End state. */ + END, } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/DefaultResourceManager.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/DefaultResourceManager.java index 6214a4cb3..385b3dd36 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/DefaultResourceManager.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/DefaultResourceManager.java @@ -19,7 +19,6 @@ package org.apache.geaflow.cluster.resourcemanager; -import com.google.common.annotations.VisibleForTesting; import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; @@ -33,6 +32,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; + import org.apache.geaflow.cluster.clustermanager.ClusterContext; import org.apache.geaflow.cluster.clustermanager.ContainerExecutorInfo; import org.apache.geaflow.cluster.clustermanager.ExecutorRegisterException; @@ -50,317 +50,376 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class DefaultResourceManager implements IResourceManager, ExecutorRegisteredCallback, Serializable { +import com.google.common.annotations.VisibleForTesting; - private static final Logger LOGGER = LoggerFactory.getLogger(DefaultResourceManager.class); +public class DefaultResourceManager + implements IResourceManager, ExecutorRegisteredCallback, Serializable { - private static final String OPERATION_REQUIRE = "require"; - private static final String OPERATION_RELEASE = "release"; - private static final String OPERATION_ALLOCATE = "allocate"; + private static final Logger LOGGER = LoggerFactory.getLogger(DefaultResourceManager.class); - private static final int DEFAULT_SLEEP_MS = 10; - private static final int MAX_REQUIRE_RETRY_TIMES = 1; - private static final int MAX_RELEASE_RETRY_TIMES = 100; + private static final String OPERATION_REQUIRE = "require"; + private static final String OPERATION_RELEASE = "release"; + private static final String OPERATION_ALLOCATE = "allocate"; - private final AtomicReference allocateWorkerErr = new AtomicReference<>(); - private final AtomicInteger pendingWorkerCounter = new AtomicInteger(0); - private final AtomicBoolean resourceLock = new AtomicBoolean(true); - private final AtomicBoolean recovering = new AtomicBoolean(false); - private final AtomicBoolean inited = new AtomicBoolean(false); - private final Map> allocators = new HashMap<>(); - protected final IClusterManager clusterManager; - protected final ClusterMetaStore metaKeeper; - protected int totalWorkerNum; + private static final int DEFAULT_SLEEP_MS = 10; + private static final int MAX_REQUIRE_RETRY_TIMES = 1; + private static final int MAX_RELEASE_RETRY_TIMES = 100; - private final Map availableWorkers = new TreeMap<>( - Comparator.comparing(WorkerInfo.WorkerId::getContainerName).thenComparing(WorkerInfo.WorkerId::getWorkerIndex)); - private final Map sessions = new HashMap<>(); + private final AtomicReference allocateWorkerErr = new AtomicReference<>(); + private final AtomicInteger pendingWorkerCounter = new AtomicInteger(0); + private final AtomicBoolean resourceLock = new AtomicBoolean(true); + private final AtomicBoolean recovering = new AtomicBoolean(false); + private final AtomicBoolean inited = new AtomicBoolean(false); + private final Map> allocators = + new HashMap<>(); + protected final IClusterManager clusterManager; + protected final ClusterMetaStore metaKeeper; + protected int totalWorkerNum; - public DefaultResourceManager(IClusterManager clusterManager) { - this.clusterManager = clusterManager; - this.metaKeeper = ClusterMetaStore.getInstance(); - } + private final Map availableWorkers = + new TreeMap<>( + Comparator.comparing(WorkerInfo.WorkerId::getContainerName) + .thenComparing(WorkerInfo.WorkerId::getWorkerIndex)); + private final Map sessions = new HashMap<>(); + + public DefaultResourceManager(IClusterManager clusterManager) { + this.clusterManager = clusterManager; + this.metaKeeper = ClusterMetaStore.getInstance(); + } - @Override - public void init(ResourceManagerContext context) { - this.allocators.put(IAllocator.AllocateStrategy.ROUND_ROBIN, new RoundRobinAllocator()); - this.allocators.put(IAllocator.AllocateStrategy.PROCESS_FAIR, new ProcessFairAllocator()); - ClusterContext clusterContext = context.getClusterContext(); - clusterContext.addExecutorRegisteredCallback(this); + @Override + public void init(ResourceManagerContext context) { + this.allocators.put(IAllocator.AllocateStrategy.ROUND_ROBIN, new RoundRobinAllocator()); + this.allocators.put(IAllocator.AllocateStrategy.PROCESS_FAIR, new ProcessFairAllocator()); + ClusterContext clusterContext = context.getClusterContext(); + clusterContext.addExecutorRegisteredCallback(this); - boolean isRecover = context.isRecover(); - this.recovering.set(isRecover); - int workerNum = clusterContext.getClusterConfig().getContainerNum() + boolean isRecover = context.isRecover(); + this.recovering.set(isRecover); + int workerNum = + clusterContext.getClusterConfig().getContainerNum() * clusterContext.getClusterConfig().getContainerWorkerNum(); - this.totalWorkerNum += workerNum; - this.pendingWorkerCounter.set(workerNum); - LOGGER.info("init worker number {}, isRecover {}", workerNum, isRecover); - if (isRecover) { - this.recover(); - } else { - this.clusterManager.allocateWorkers(workerNum); - } - this.inited.set(true); - LOGGER.info("init worker manager finish"); + this.totalWorkerNum += workerNum; + this.pendingWorkerCounter.set(workerNum); + LOGGER.info("init worker number {}, isRecover {}", workerNum, isRecover); + if (isRecover) { + this.recover(); + } else { + this.clusterManager.allocateWorkers(workerNum); } + this.inited.set(true); + LOGGER.info("init worker manager finish"); + } - @Override - public RequireResponse requireResource(RequireResourceRequest requireRequest) { - String requireId = requireRequest.getRequireId(); - int requiredNum = requireRequest.getRequiredNum(); - if (this.sessions.containsKey(requireId)) { - Map sessionWorkers = this.sessions.get(requireId).getWorkers(); - if (requiredNum != sessionWorkers.size()) { - String msg = "require number mismatch, old " + sessionWorkers.size() + " new " + requiredNum; - LOGGER.error("[{}] require from session err: {}", requireId, msg); - return RequireResponse.fail(requireId, msg); - } - List workers = new ArrayList<>(sessionWorkers.values()); - LOGGER.info("[{}] require from session with {} worker", requireId, workers.size()); - return RequireResponse.success(requireId, workers); - } - if (requiredNum <= 0) { - String msg = RuntimeErrors.INST.resourceIllegalRequireNumError("illegal num " + requiredNum); - LOGGER.error("[{}] {}", requireId, msg); - return RequireResponse.fail(requireId, msg); - } - if (this.recovering.get()) { - String msg = "resource manager still recovering"; - LOGGER.warn("[{}] {}", requireId, msg); - return RequireResponse.fail(requireId, msg); - } - if (this.pendingWorkerCounter.get() > 0) { - String msg = "some worker still pending creation"; - LOGGER.warn("[{}] {}", requireId, msg); - return RequireResponse.fail(requireId, msg); - } + @Override + public RequireResponse requireResource(RequireResourceRequest requireRequest) { + String requireId = requireRequest.getRequireId(); + int requiredNum = requireRequest.getRequiredNum(); + if (this.sessions.containsKey(requireId)) { + Map sessionWorkers = + this.sessions.get(requireId).getWorkers(); + if (requiredNum != sessionWorkers.size()) { + String msg = + "require number mismatch, old " + sessionWorkers.size() + " new " + requiredNum; + LOGGER.error("[{}] require from session err: {}", requireId, msg); + return RequireResponse.fail(requireId, msg); + } + List workers = new ArrayList<>(sessionWorkers.values()); + LOGGER.info("[{}] require from session with {} worker", requireId, workers.size()); + return RequireResponse.success(requireId, workers); + } + if (requiredNum <= 0) { + String msg = RuntimeErrors.INST.resourceIllegalRequireNumError("illegal num " + requiredNum); + LOGGER.error("[{}] {}", requireId, msg); + return RequireResponse.fail(requireId, msg); + } + if (this.recovering.get()) { + String msg = "resource manager still recovering"; + LOGGER.warn("[{}] {}", requireId, msg); + return RequireResponse.fail(requireId, msg); + } + if (this.pendingWorkerCounter.get() > 0) { + String msg = "some worker still pending creation"; + LOGGER.warn("[{}] {}", requireId, msg); + return RequireResponse.fail(requireId, msg); + } - Optional> optional = this.withLock(OPERATION_REQUIRE, num -> { - if (this.availableWorkers.size() < num) { - LOGGER.warn("[{}] require {}, available {}, return empty", - requireId, num, this.availableWorkers.size()); + Optional> optional = + this.withLock( + OPERATION_REQUIRE, + num -> { + if (this.availableWorkers.size() < num) { + LOGGER.warn( + "[{}] require {}, available {}, return empty", + requireId, + num, + this.availableWorkers.size()); return Collections.emptyList(); - } - IAllocator.AllocateStrategy strategy = requireRequest.getAllocateStrategy(); - List allocated = this.allocators.get(strategy) - .allocate(this.availableWorkers.values(), num); - for (WorkerInfo worker : allocated) { + } + IAllocator.AllocateStrategy strategy = requireRequest.getAllocateStrategy(); + List allocated = + this.allocators.get(strategy).allocate(this.availableWorkers.values(), num); + for (WorkerInfo worker : allocated) { WorkerInfo.WorkerId workerId = worker.generateWorkerId(); this.availableWorkers.remove(workerId); - ResourceSession session = this.sessions.computeIfAbsent(requireId, ResourceSession::new); + ResourceSession session = + this.sessions.computeIfAbsent(requireId, ResourceSession::new); session.addWorker(workerId, worker); - } - LOGGER.info("[{}] require {} allocated {} available {}", - requireId, num, allocated.size(), this.availableWorkers.size()); - if (!allocated.isEmpty()) { + } + LOGGER.info( + "[{}] require {} allocated {} available {}", + requireId, + num, + allocated.size(), + this.availableWorkers.size()); + if (!allocated.isEmpty()) { this.persist(); - } - return allocated; - }, requiredNum, MAX_REQUIRE_RETRY_TIMES); - List allocated = optional.orElse(Collections.emptyList()); - return RequireResponse.success(requireId, allocated); - } + } + return allocated; + }, + requiredNum, + MAX_REQUIRE_RETRY_TIMES); + List allocated = optional.orElse(Collections.emptyList()); + return RequireResponse.success(requireId, allocated); + } - @Override - public ReleaseResponse releaseResource(ReleaseResourceRequest releaseRequest) { - String releaseId = String.valueOf(releaseRequest.getReleaseId()); - if (!this.sessions.containsKey(releaseId)) { - String msg = "release fail, session not exists: " + releaseId; - LOGGER.error(msg); - return ReleaseResponse.fail(releaseId, msg); - } - int expectSize = this.sessions.get(releaseId).getWorkers().size(); - int actualSize = releaseRequest.getWorkers().size(); - if (expectSize != actualSize) { - String msg = String.format("release fail, worker num of session %s mismatch, expected %d, actual %d", - releaseId, expectSize, actualSize); - LOGGER.error(msg); - return ReleaseResponse.fail(releaseId, msg); - } - Optional optional = this.withLock(OPERATION_RELEASE, workers -> { - for (WorkerInfo worker : workers) { + @Override + public ReleaseResponse releaseResource(ReleaseResourceRequest releaseRequest) { + String releaseId = String.valueOf(releaseRequest.getReleaseId()); + if (!this.sessions.containsKey(releaseId)) { + String msg = "release fail, session not exists: " + releaseId; + LOGGER.error(msg); + return ReleaseResponse.fail(releaseId, msg); + } + int expectSize = this.sessions.get(releaseId).getWorkers().size(); + int actualSize = releaseRequest.getWorkers().size(); + if (expectSize != actualSize) { + String msg = + String.format( + "release fail, worker num of session %s mismatch, expected %d, actual %d", + releaseId, expectSize, actualSize); + LOGGER.error(msg); + return ReleaseResponse.fail(releaseId, msg); + } + Optional optional = + this.withLock( + OPERATION_RELEASE, + workers -> { + for (WorkerInfo worker : workers) { WorkerInfo.WorkerId workerId = worker.generateWorkerId(); this.availableWorkers.put(workerId, worker); if (!this.sessions.get(releaseId).removeWorker(workerId)) { - String msg = String.format("worker %s not exists in session %s", workerId, releaseId); - LOGGER.error(msg); - throw new GeaflowRuntimeException(msg); + String msg = + String.format("worker %s not exists in session %s", workerId, releaseId); + LOGGER.error(msg); + throw new GeaflowRuntimeException(msg); } - } - this.sessions.remove(releaseId); - LOGGER.info("[{}] release {} available {}", - releaseId, workers.size(), this.availableWorkers.size()); - this.persist(); - return true; - }, releaseRequest.getWorkers(), MAX_RELEASE_RETRY_TIMES); - - if (!optional.orElse(false)) { - String msg = "release fail after " + MAX_RELEASE_RETRY_TIMES + " times"; - LOGGER.error(msg); - return ReleaseResponse.fail(releaseId, msg); - } - return ReleaseResponse.success(releaseId); - } + } + this.sessions.remove(releaseId); + LOGGER.info( + "[{}] release {} available {}", + releaseId, + workers.size(), + this.availableWorkers.size()); + this.persist(); + return true; + }, + releaseRequest.getWorkers(), + MAX_RELEASE_RETRY_TIMES); - @Override - public void onSuccess(ContainerExecutorInfo containerExecutorInfo) { - this.waitForInit(); - this.withLock(OPERATION_ALLOCATE, container -> { - String containerName = container.getContainerName(); - String host = container.getHost(); - int rpcPort = container.getRpcPort(); - int shufflePort = container.getShufflePort(); - int processId = container.getProcessId(); - List executorIds = container.getExecutorIds(); - onRegister(containerName, host, rpcPort, shufflePort, processId, executorIds); - return true; - }, containerExecutorInfo, Integer.MAX_VALUE); + if (!optional.orElse(false)) { + String msg = "release fail after " + MAX_RELEASE_RETRY_TIMES + " times"; + LOGGER.error(msg); + return ReleaseResponse.fail(releaseId, msg); } + return ReleaseResponse.success(releaseId); + } - private void onRegister(String containerName, - String host, - int rpcPort, - int shufflePort, - int processId, - List executorIds) { - for (Integer workerIndex : executorIds) { - WorkerInfo.WorkerId workerId = new WorkerInfo.WorkerId(containerName, workerIndex); - WorkerInfo worker = null; - if (this.availableWorkers.containsKey(workerId)) { - worker = this.availableWorkers.get(workerId); - } - for (ResourceSession session : this.sessions.values()) { - if (session.getWorkers().containsKey(workerId)) { - worker = session.getWorkers().get(workerId); - } - } - if (worker == null) { - worker = WorkerInfo.build( - host, rpcPort, shufflePort, processId, workerIndex, containerName); - this.availableWorkers.put(worker.generateWorkerId(), worker); - this.pendingWorkerCounter.addAndGet(-1); - } else { - worker.setHost(host); - worker.setProcessId(processId); - worker.setRpcPort(rpcPort); - worker.setShufflePort(shufflePort); - } - } - int pending = this.pendingWorkerCounter.get(); - LOGGER.info("register {} worker from cluster manager container:{}, host:{}, processId:{}," - + " pending:{}", executorIds.size(), containerName, host, processId, pending); + @Override + public void onSuccess(ContainerExecutorInfo containerExecutorInfo) { + this.waitForInit(); + this.withLock( + OPERATION_ALLOCATE, + container -> { + String containerName = container.getContainerName(); + String host = container.getHost(); + int rpcPort = container.getRpcPort(); + int shufflePort = container.getShufflePort(); + int processId = container.getProcessId(); + List executorIds = container.getExecutorIds(); + onRegister(containerName, host, rpcPort, shufflePort, processId, executorIds); + return true; + }, + containerExecutorInfo, + Integer.MAX_VALUE); + } - int used = this.sessions.values().stream().mapToInt(s -> s.getWorkers().size()).sum(); - if (pending <= 0) { - this.recovering.set(false); - LOGGER.info("register worker over, available/used : {}/{}, pending {}", - this.availableWorkers.size(), used, pending); - persist(); - } else { - LOGGER.debug("still pending : {}, available/used : {}/{}", pending, - this.availableWorkers.size(), used); + private void onRegister( + String containerName, + String host, + int rpcPort, + int shufflePort, + int processId, + List executorIds) { + for (Integer workerIndex : executorIds) { + WorkerInfo.WorkerId workerId = new WorkerInfo.WorkerId(containerName, workerIndex); + WorkerInfo worker = null; + if (this.availableWorkers.containsKey(workerId)) { + worker = this.availableWorkers.get(workerId); + } + for (ResourceSession session : this.sessions.values()) { + if (session.getWorkers().containsKey(workerId)) { + worker = session.getWorkers().get(workerId); } + } + if (worker == null) { + worker = + WorkerInfo.build(host, rpcPort, shufflePort, processId, workerIndex, containerName); + this.availableWorkers.put(worker.generateWorkerId(), worker); + this.pendingWorkerCounter.addAndGet(-1); + } else { + worker.setHost(host); + worker.setProcessId(processId); + worker.setRpcPort(rpcPort); + worker.setShufflePort(shufflePort); + } } + int pending = this.pendingWorkerCounter.get(); + LOGGER.info( + "register {} worker from cluster manager container:{}, host:{}, processId:{}," + + " pending:{}", + executorIds.size(), + containerName, + host, + processId, + pending); - @Override - public void onFailure(ExecutorRegisterException e) { - LOGGER.error("create worker err", e); - this.allocateWorkerErr.compareAndSet(null, e); - } - - private Optional withLock(String operation, Function function, T input, int maxRetryTimes) { - this.checkError(); - try { - int retry = 0; - while (!this.resourceLock.compareAndSet(true, false)) { - SleepUtils.sleepMilliSecond(DEFAULT_SLEEP_MS); - retry++; - if (retry >= maxRetryTimes) { - LOGGER.warn("[{}] lock not ready, return empty", operation); - return Optional.empty(); - } - if (retry % 100 == 0) { - LOGGER.warn("[{}] lock not ready after {} times", operation, retry); - } - } - return Optional.of(function.apply(input)); - } finally { - this.resourceLock.set(true); - } + int used = this.sessions.values().stream().mapToInt(s -> s.getWorkers().size()).sum(); + if (pending <= 0) { + this.recovering.set(false); + LOGGER.info( + "register worker over, available/used : {}/{}, pending {}", + this.availableWorkers.size(), + used, + pending); + persist(); + } else { + LOGGER.debug( + "still pending : {}, available/used : {}/{}", + pending, + this.availableWorkers.size(), + used); } + } - private void persist() { - final long start = System.currentTimeMillis(); - List available = new ArrayList<>(this.availableWorkers.values()); - List sessions = new ArrayList<>(this.sessions.values()); - int used = sessions.stream().mapToInt(s -> s.getWorkers().size()).sum(); - this.metaKeeper.saveWorkers(new WorkerSnapshot(available, sessions)); - LOGGER.info("persist {}/{} workers costs {}ms", - this.availableWorkers.size(), used, System.currentTimeMillis() - start); - } + @Override + public void onFailure(ExecutorRegisterException e) { + LOGGER.error("create worker err", e); + this.allocateWorkerErr.compareAndSet(null, e); + } - private void recover() { - final long start = System.currentTimeMillis(); - WorkerSnapshot workerSnapshot = this.metaKeeper.getWorkers(); - List available = workerSnapshot.getAvailableWorkers(); - List sessions = workerSnapshot.getSessions(); - for (WorkerInfo worker : available) { - WorkerInfo.WorkerId workerId = worker.generateWorkerId(); - this.availableWorkers.put(workerId, worker); + private Optional withLock( + String operation, Function function, T input, int maxRetryTimes) { + this.checkError(); + try { + int retry = 0; + while (!this.resourceLock.compareAndSet(true, false)) { + SleepUtils.sleepMilliSecond(DEFAULT_SLEEP_MS); + retry++; + if (retry >= maxRetryTimes) { + LOGGER.warn("[{}] lock not ready, return empty", operation); + return Optional.empty(); } - int usedWorkerNum = 0; - for (ResourceSession session : sessions) { - String sessionId = session.getId(); - this.sessions.put(sessionId, session); - usedWorkerNum += session.getWorkers().size(); + if (retry % 100 == 0) { + LOGGER.warn("[{}] lock not ready after {} times", operation, retry); } - int availableWorkerNum = this.availableWorkers.size(); - - this.pendingWorkerCounter.addAndGet(-availableWorkerNum - usedWorkerNum); - LOGGER.info("recover {}/{} workers, pending {}, costs {}ms", availableWorkerNum, - usedWorkerNum, pendingWorkerCounter.get(), System.currentTimeMillis() - start); - if (this.pendingWorkerCounter.get() <= 0) { - this.recovering.set(false); - LOGGER.info("recover worker over, available/used : {}/{}", this.availableWorkers.size(), - usedWorkerNum); - } - clusterManager.doFailover(ClusterConstants.DEFAULT_MASTER_ID, null); + } + return Optional.of(function.apply(input)); + } finally { + this.resourceLock.set(true); } + } - private void waitForInit() { - int count = 0; - while (!this.inited.get()) { - count++; - if (count % 100 == 0) { - LOGGER.warn("resource manager not inited, wait {}ms and retry", DEFAULT_SLEEP_MS); - } - SleepUtils.sleepMilliSecond(DEFAULT_SLEEP_MS); - } - } + private void persist() { + final long start = System.currentTimeMillis(); + List available = new ArrayList<>(this.availableWorkers.values()); + List sessions = new ArrayList<>(this.sessions.values()); + int used = sessions.stream().mapToInt(s -> s.getWorkers().size()).sum(); + this.metaKeeper.saveWorkers(new WorkerSnapshot(available, sessions)); + LOGGER.info( + "persist {}/{} workers costs {}ms", + this.availableWorkers.size(), + used, + System.currentTimeMillis() - start); + } - public ResourceMetrics getResourceMetrics() { - ResourceMetrics metrics = new ResourceMetrics(); - metrics.setPendingWorkers(pendingWorkerCounter.get()); - metrics.setAvailableWorkers(availableWorkers.size()); - metrics.setTotalWorkers(totalWorkerNum); - return metrics; + private void recover() { + final long start = System.currentTimeMillis(); + WorkerSnapshot workerSnapshot = this.metaKeeper.getWorkers(); + List available = workerSnapshot.getAvailableWorkers(); + List sessions = workerSnapshot.getSessions(); + for (WorkerInfo worker : available) { + WorkerInfo.WorkerId workerId = worker.generateWorkerId(); + this.availableWorkers.put(workerId, worker); } - - @VisibleForTesting - protected AtomicInteger getPendingWorkerCounter() { - return this.pendingWorkerCounter; + int usedWorkerNum = 0; + for (ResourceSession session : sessions) { + String sessionId = session.getId(); + this.sessions.put(sessionId, session); + usedWorkerNum += session.getWorkers().size(); } + int availableWorkerNum = this.availableWorkers.size(); - @VisibleForTesting - protected AtomicBoolean getResourceLock() { - return this.resourceLock; + this.pendingWorkerCounter.addAndGet(-availableWorkerNum - usedWorkerNum); + LOGGER.info( + "recover {}/{} workers, pending {}, costs {}ms", + availableWorkerNum, + usedWorkerNum, + pendingWorkerCounter.get(), + System.currentTimeMillis() - start); + if (this.pendingWorkerCounter.get() <= 0) { + this.recovering.set(false); + LOGGER.info( + "recover worker over, available/used : {}/{}", + this.availableWorkers.size(), + usedWorkerNum); } + clusterManager.doFailover(ClusterConstants.DEFAULT_MASTER_ID, null); + } - private void checkError() { - Throwable firstException = this.allocateWorkerErr.get(); - if (firstException != null) { - throw new GeaflowRuntimeException(firstException); - } + private void waitForInit() { + int count = 0; + while (!this.inited.get()) { + count++; + if (count % 100 == 0) { + LOGGER.warn("resource manager not inited, wait {}ms and retry", DEFAULT_SLEEP_MS); + } + SleepUtils.sleepMilliSecond(DEFAULT_SLEEP_MS); } + } + + public ResourceMetrics getResourceMetrics() { + ResourceMetrics metrics = new ResourceMetrics(); + metrics.setPendingWorkers(pendingWorkerCounter.get()); + metrics.setAvailableWorkers(availableWorkers.size()); + metrics.setTotalWorkers(totalWorkerNum); + return metrics; + } + + @VisibleForTesting + protected AtomicInteger getPendingWorkerCounter() { + return this.pendingWorkerCounter; + } + @VisibleForTesting + protected AtomicBoolean getResourceLock() { + return this.resourceLock; + } + + private void checkError() { + Throwable firstException = this.allocateWorkerErr.get(); + if (firstException != null) { + throw new GeaflowRuntimeException(firstException); + } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/IResourceManager.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/IResourceManager.java index 35162dc76..edf75d633 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/IResourceManager.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/IResourceManager.java @@ -19,32 +19,29 @@ package org.apache.geaflow.cluster.resourcemanager; -/** - * IResourceManager interface. - */ +/** IResourceManager interface. */ public interface IResourceManager { - /** - * Initialization. - * - * @param context resource manager context - */ - void init(ResourceManagerContext context); - - /** - * Require resource with specified allocate strategy. - * - * @param requireRequest require request - * @return require response - */ - RequireResponse requireResource(RequireResourceRequest requireRequest); + /** + * Initialization. + * + * @param context resource manager context + */ + void init(ResourceManagerContext context); - /** - * Release resource. - * - * @param releaseRequest release request - * @return release response - */ - ReleaseResponse releaseResource(ReleaseResourceRequest releaseRequest); + /** + * Require resource with specified allocate strategy. + * + * @param requireRequest require request + * @return require response + */ + RequireResponse requireResource(RequireResourceRequest requireRequest); + /** + * Release resource. + * + * @param releaseRequest release request + * @return release response + */ + ReleaseResponse releaseResource(ReleaseResourceRequest releaseRequest); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/ReleaseResourceRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/ReleaseResourceRequest.java index c39e182aa..93f081a15 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/ReleaseResourceRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/ReleaseResourceRequest.java @@ -23,24 +23,23 @@ public class ReleaseResourceRequest { - private final String releaseId; - private final List workers; + private final String releaseId; + private final List workers; - private ReleaseResourceRequest(String releaseId, List workers) { - this.releaseId = releaseId; - this.workers = workers; - } + private ReleaseResourceRequest(String releaseId, List workers) { + this.releaseId = releaseId; + this.workers = workers; + } - public String getReleaseId() { - return this.releaseId; - } + public String getReleaseId() { + return this.releaseId; + } - public List getWorkers() { - return this.workers; - } - - public static ReleaseResourceRequest build(String releaseId, List workers) { - return new ReleaseResourceRequest(releaseId, workers); - } + public List getWorkers() { + return this.workers; + } + public static ReleaseResourceRequest build(String releaseId, List workers) { + return new ReleaseResourceRequest(releaseId, workers); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/ReleaseResponse.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/ReleaseResponse.java index eface15cf..a3cc6dc16 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/ReleaseResponse.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/ReleaseResponse.java @@ -21,34 +21,33 @@ public class ReleaseResponse { - private final String releaseId; - private final boolean success; - private final String msg; - - private ReleaseResponse(String releaseId, boolean success, String msg) { - this.releaseId = releaseId; - this.success = success; - this.msg = msg; - } - - public String getReleaseId() { - return this.releaseId; - } - - public boolean isSuccess() { - return this.success; - } - - public String getMsg() { - return this.msg; - } - - public static ReleaseResponse success(String releaseId) { - return new ReleaseResponse(releaseId, true, null); - } - - public static ReleaseResponse fail(String releaseId, String msg) { - return new ReleaseResponse(releaseId, false, msg); - } - + private final String releaseId; + private final boolean success; + private final String msg; + + private ReleaseResponse(String releaseId, boolean success, String msg) { + this.releaseId = releaseId; + this.success = success; + this.msg = msg; + } + + public String getReleaseId() { + return this.releaseId; + } + + public boolean isSuccess() { + return this.success; + } + + public String getMsg() { + return this.msg; + } + + public static ReleaseResponse success(String releaseId) { + return new ReleaseResponse(releaseId, true, null); + } + + public static ReleaseResponse fail(String releaseId, String msg) { + return new ReleaseResponse(releaseId, false, msg); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/RequireResourceRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/RequireResourceRequest.java index 24edbd006..db2fd326b 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/RequireResourceRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/RequireResourceRequest.java @@ -23,36 +23,35 @@ public class RequireResourceRequest { - private final String requireId; - private final int requiredNum; - private final IAllocator.AllocateStrategy allocateStrategy; - - private RequireResourceRequest(String requireId, - int requiredNum, - IAllocator.AllocateStrategy allocateStrategy) { - this.requireId = requireId; - this.requiredNum = requiredNum; - this.allocateStrategy = allocateStrategy; - } - - public String getRequireId() { - return this.requireId; - } - - public int getRequiredNum() { - return this.requiredNum; - } - - public IAllocator.AllocateStrategy getAllocateStrategy() { - return this.allocateStrategy; - } - - public static RequireResourceRequest build(String requireId, int requiredNum) { - return new RequireResourceRequest(requireId, requiredNum, IAllocator.DEFAULT_ALLOCATE_STRATEGY); - } - - public static RequireResourceRequest build(String requireId, int requiredNum, IAllocator.AllocateStrategy allocateStrategy) { - return new RequireResourceRequest(requireId, requiredNum, allocateStrategy); - } - + private final String requireId; + private final int requiredNum; + private final IAllocator.AllocateStrategy allocateStrategy; + + private RequireResourceRequest( + String requireId, int requiredNum, IAllocator.AllocateStrategy allocateStrategy) { + this.requireId = requireId; + this.requiredNum = requiredNum; + this.allocateStrategy = allocateStrategy; + } + + public String getRequireId() { + return this.requireId; + } + + public int getRequiredNum() { + return this.requiredNum; + } + + public IAllocator.AllocateStrategy getAllocateStrategy() { + return this.allocateStrategy; + } + + public static RequireResourceRequest build(String requireId, int requiredNum) { + return new RequireResourceRequest(requireId, requiredNum, IAllocator.DEFAULT_ALLOCATE_STRATEGY); + } + + public static RequireResourceRequest build( + String requireId, int requiredNum, IAllocator.AllocateStrategy allocateStrategy) { + return new RequireResourceRequest(requireId, requiredNum, allocateStrategy); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/RequireResponse.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/RequireResponse.java index 4fe76282d..0d885e92f 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/RequireResponse.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/RequireResponse.java @@ -23,40 +23,39 @@ public class RequireResponse { - private final String requireId; - private final boolean success; - private final List workers; - private final String msg; - - private RequireResponse(String requireId, boolean success, List workers, String msg) { - this.requireId = requireId; - this.success = success; - this.workers = workers; - this.msg = msg; - } - - public String getRequireId() { - return this.requireId; - } - - public boolean isSuccess() { - return this.success; - } - - public List getWorkers() { - return this.workers; - } - - public String getMsg() { - return this.msg; - } - - public static RequireResponse success(String requireId, List workers) { - return new RequireResponse(requireId, true, workers, null); - } - - public static RequireResponse fail(String requireId, String msg) { - return new RequireResponse(requireId, false, null, msg); - } - + private final String requireId; + private final boolean success; + private final List workers; + private final String msg; + + private RequireResponse(String requireId, boolean success, List workers, String msg) { + this.requireId = requireId; + this.success = success; + this.workers = workers; + this.msg = msg; + } + + public String getRequireId() { + return this.requireId; + } + + public boolean isSuccess() { + return this.success; + } + + public List getWorkers() { + return this.workers; + } + + public String getMsg() { + return this.msg; + } + + public static RequireResponse success(String requireId, List workers) { + return new RequireResponse(requireId, true, workers, null); + } + + public static RequireResponse fail(String requireId, String msg) { + return new RequireResponse(requireId, false, null, msg); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/ResourceInfo.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/ResourceInfo.java index 75d4aec00..6168f78b3 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/ResourceInfo.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/ResourceInfo.java @@ -24,35 +24,32 @@ public class ResourceInfo implements Serializable { - private String resourceId; - private List workers; - - public ResourceInfo(String resourceId, List workers) { - this.resourceId = resourceId; - this.workers = workers; - } - - public void setResourceId(String resourceId) { - this.resourceId = resourceId; - } - - public void setWorkers(List workers) { - this.workers = workers; - } - - public String getResourceId() { - return resourceId; - } - - public List getWorkers() { - return workers; - } - - @Override - public String toString() { - return "ResourceInfo{" - + "resourceId=" + resourceId - + ", workers=" + workers - + '}'; - } + private String resourceId; + private List workers; + + public ResourceInfo(String resourceId, List workers) { + this.resourceId = resourceId; + this.workers = workers; + } + + public void setResourceId(String resourceId) { + this.resourceId = resourceId; + } + + public void setWorkers(List workers) { + this.workers = workers; + } + + public String getResourceId() { + return resourceId; + } + + public List getWorkers() { + return workers; + } + + @Override + public String toString() { + return "ResourceInfo{" + "resourceId=" + resourceId + ", workers=" + workers + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/ResourceManagerContext.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/ResourceManagerContext.java index 669362f7f..e8521e33b 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/ResourceManagerContext.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/ResourceManagerContext.java @@ -25,34 +25,34 @@ public class ResourceManagerContext { - private final Configuration config; - private final ClusterContext clusterContext; - private boolean recover; - - private ResourceManagerContext(MasterContext masterContext, ClusterContext clusterContext) { - this.config = masterContext.getConfiguration(); - this.clusterContext = clusterContext; - this.recover = clusterContext.isRecover(); - } - - public Configuration getConfig() { - return this.config; - } - - public ClusterContext getClusterContext() { - return this.clusterContext; - } - - public boolean isRecover() { - return this.recover; - } - - public void setRecover(boolean recover) { - this.recover = recover; - } - - public static ResourceManagerContext build(MasterContext masterContext, ClusterContext clusterContext) { - return new ResourceManagerContext(masterContext, clusterContext); - } - + private final Configuration config; + private final ClusterContext clusterContext; + private boolean recover; + + private ResourceManagerContext(MasterContext masterContext, ClusterContext clusterContext) { + this.config = masterContext.getConfiguration(); + this.clusterContext = clusterContext; + this.recover = clusterContext.isRecover(); + } + + public Configuration getConfig() { + return this.config; + } + + public ClusterContext getClusterContext() { + return this.clusterContext; + } + + public boolean isRecover() { + return this.recover; + } + + public void setRecover(boolean recover) { + this.recover = recover; + } + + public static ResourceManagerContext build( + MasterContext masterContext, ClusterContext clusterContext) { + return new ResourceManagerContext(masterContext, clusterContext); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/ResourceSession.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/ResourceSession.java index 0bd204811..5a49b5891 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/ResourceSession.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/ResourceSession.java @@ -25,27 +25,26 @@ public class ResourceSession implements Serializable { - private final Map workers = new HashMap<>(); - private final String id; + private final Map workers = new HashMap<>(); + private final String id; - public ResourceSession(String id) { - this.id = id; - } + public ResourceSession(String id) { + this.id = id; + } - public String getId() { - return this.id; - } + public String getId() { + return this.id; + } - public Map getWorkers() { - return this.workers; - } + public Map getWorkers() { + return this.workers; + } - public void addWorker(WorkerInfo.WorkerId workerId, WorkerInfo worker) { - this.workers.put(workerId, worker); - } - - public boolean removeWorker(WorkerInfo.WorkerId workerId) { - return this.workers.remove(workerId) != null; - } + public void addWorker(WorkerInfo.WorkerId workerId, WorkerInfo worker) { + this.workers.put(workerId, worker); + } + public boolean removeWorker(WorkerInfo.WorkerId workerId) { + return this.workers.remove(workerId) != null; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/WorkerInfo.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/WorkerInfo.java index 67fa625ff..fce3c454d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/WorkerInfo.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/WorkerInfo.java @@ -24,205 +24,217 @@ public class WorkerInfo implements Comparable, Serializable { - private String host; - private int rpcPort; - private int shufflePort; - private int processId; - private int processIndex; - private int workerIndex; - private String containerName; - - public WorkerInfo() { - } - - public WorkerInfo(String host, - int rpcPort, - int shufflePort, - int processId, - int workerId, - String containerName) { - this.host = host; - this.rpcPort = rpcPort; - this.shufflePort = shufflePort; - this.processId = processId; - this.workerIndex = workerId; - this.containerName = containerName; - } - - public WorkerInfo(String host, - int rpcPort, - int shufflePort, - int processId, - int processIndex, - int workerId, - String containerName) { - this(host, rpcPort, shufflePort, processId, workerId, containerName); - this.processIndex = processIndex; - } - - public String getHost() { - return this.host; - } - - public void setHost(String host) { - this.host = host; - } - - public int getProcessId() { - return this.processId; - } - - public void setProcessId(int processId) { - this.processId = processId; - } - - public int getProcessIndex() { - return processIndex; - } - - public void setProcessIndex(int processIndex) { - this.processIndex = processIndex; - } - - public int getRpcPort() { - return this.rpcPort; - } - - public void setRpcPort(int rpcPort) { - this.rpcPort = rpcPort; - } - - public int getShufflePort() { - return shufflePort; - } - - public void setShufflePort(int shufflePort) { - this.shufflePort = shufflePort; - } - - public int getWorkerIndex() { - return this.workerIndex; - } - - public void setWorkerIndex(int workerIndex) { - this.workerIndex = workerIndex; + private String host; + private int rpcPort; + private int shufflePort; + private int processId; + private int processIndex; + private int workerIndex; + private String containerName; + + public WorkerInfo() {} + + public WorkerInfo( + String host, + int rpcPort, + int shufflePort, + int processId, + int workerId, + String containerName) { + this.host = host; + this.rpcPort = rpcPort; + this.shufflePort = shufflePort; + this.processId = processId; + this.workerIndex = workerId; + this.containerName = containerName; + } + + public WorkerInfo( + String host, + int rpcPort, + int shufflePort, + int processId, + int processIndex, + int workerId, + String containerName) { + this(host, rpcPort, shufflePort, processId, workerId, containerName); + this.processIndex = processIndex; + } + + public String getHost() { + return this.host; + } + + public void setHost(String host) { + this.host = host; + } + + public int getProcessId() { + return this.processId; + } + + public void setProcessId(int processId) { + this.processId = processId; + } + + public int getProcessIndex() { + return processIndex; + } + + public void setProcessIndex(int processIndex) { + this.processIndex = processIndex; + } + + public int getRpcPort() { + return this.rpcPort; + } + + public void setRpcPort(int rpcPort) { + this.rpcPort = rpcPort; + } + + public int getShufflePort() { + return shufflePort; + } + + public void setShufflePort(int shufflePort) { + this.shufflePort = shufflePort; + } + + public int getWorkerIndex() { + return this.workerIndex; + } + + public void setWorkerIndex(int workerIndex) { + this.workerIndex = workerIndex; + } + + public String getContainerName() { + return containerName; + } + + public void setContainerName(String containerName) { + this.containerName = containerName; + } + + public WorkerId generateWorkerId() { + return new WorkerId(this.containerName, this.workerIndex); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + WorkerInfo that = (WorkerInfo) o; + return Objects.equals(this.containerName, that.containerName) + && this.processId == that.processId + && this.workerIndex == that.workerIndex; + } + + @Override + public int hashCode() { + return Objects.hash(this.containerName, this.processId, this.workerIndex); + } + + @Override + public int compareTo(WorkerInfo o) { + int flag = this.containerName.compareTo(o.containerName); + if (flag == 0) { + flag = Integer.compare(this.processId, o.processId); + if (flag == 0) { + flag = Integer.compare(this.workerIndex, o.workerIndex); + } + } + return flag; + } + + @Override + public String toString() { + return "WorkerInfo{" + + "host='" + + host + + '\'' + + ", rpcPort=" + + rpcPort + + ", shufflePort=" + + shufflePort + + ", processId=" + + processId + + ", processIndex=" + + processIndex + + ", workerIndex=" + + workerIndex + + ", containerName='" + + containerName + + '\'' + + '}'; + } + + public static WorkerInfo build( + String host, + int rpcPort, + int shufflePort, + int processId, + int workerId, + String containerName) { + return new WorkerInfo(host, rpcPort, shufflePort, processId, workerId, containerName); + } + + public static WorkerInfo build( + String host, + int rpcPort, + int shufflePort, + int processId, + int processIndex, + int workerId, + String containerName) { + return new WorkerInfo( + host, rpcPort, shufflePort, processId, processIndex, workerId, containerName); + } + + public static class WorkerId { + + private final String containerName; + private final int workerIndex; + + public WorkerId(String containerName, int workerIndex) { + this.containerName = containerName; + this.workerIndex = workerIndex; } public String getContainerName() { - return containerName; - } - - public void setContainerName(String containerName) { - this.containerName = containerName; + return this.containerName; } - public WorkerId generateWorkerId() { - return new WorkerId(this.containerName, this.workerIndex); + public int getWorkerIndex() { + return this.workerIndex; } @Override public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - WorkerInfo that = (WorkerInfo) o; - return Objects.equals(this.containerName, that.containerName) - && this.processId == that.processId - && this.workerIndex == that.workerIndex; + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + WorkerId workerId = (WorkerId) o; + return Objects.equals(this.containerName, workerId.containerName) + && this.workerIndex == workerId.workerIndex; } @Override public int hashCode() { - return Objects.hash(this.containerName, this.processId, this.workerIndex); - } - - @Override - public int compareTo(WorkerInfo o) { - int flag = this.containerName.compareTo(o.containerName); - if (flag == 0) { - flag = Integer.compare(this.processId, o.processId); - if (flag == 0) { - flag = Integer.compare(this.workerIndex, o.workerIndex); - } - } - return flag; + return Objects.hash(this.containerName, this.workerIndex); } @Override public String toString() { - return "WorkerInfo{" - + "host='" + host + '\'' - + ", rpcPort=" + rpcPort - + ", shufflePort=" + shufflePort - + ", processId=" + processId - + ", processIndex=" + processIndex - + ", workerIndex=" + workerIndex - + ", containerName='" + containerName + '\'' - + '}'; - } - - public static WorkerInfo build(String host, - int rpcPort, - int shufflePort, - int processId, - int workerId, - String containerName) { - return new WorkerInfo(host, rpcPort, shufflePort, processId, workerId, containerName); - } - - public static WorkerInfo build(String host, - int rpcPort, - int shufflePort, - int processId, - int processIndex, - int workerId, - String containerName) { - return new WorkerInfo(host, rpcPort, shufflePort, processId, processIndex, workerId, containerName); + return this.containerName + '/' + this.workerIndex; } - - public static class WorkerId { - - private final String containerName; - private final int workerIndex; - - public WorkerId(String containerName, int workerIndex) { - this.containerName = containerName; - this.workerIndex = workerIndex; - } - - public String getContainerName() { - return this.containerName; - } - - public int getWorkerIndex() { - return this.workerIndex; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - WorkerId workerId = (WorkerId) o; - return Objects.equals(this.containerName, workerId.containerName) && this.workerIndex == workerId.workerIndex; - } - - @Override - public int hashCode() { - return Objects.hash(this.containerName, this.workerIndex); - } - - @Override - public String toString() { - return this.containerName + '/' + this.workerIndex; - } - - } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/WorkerSnapshot.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/WorkerSnapshot.java index 50814821f..8e037327c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/WorkerSnapshot.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/WorkerSnapshot.java @@ -24,31 +24,29 @@ public class WorkerSnapshot implements Serializable { - private List availableWorkers; - private List sessions; + private List availableWorkers; + private List sessions; - public WorkerSnapshot() { - } + public WorkerSnapshot() {} - public WorkerSnapshot(List availableWorkers, List sessions) { - this.availableWorkers = availableWorkers; - this.sessions = sessions; - } + public WorkerSnapshot(List availableWorkers, List sessions) { + this.availableWorkers = availableWorkers; + this.sessions = sessions; + } - public List getAvailableWorkers() { - return this.availableWorkers; - } + public List getAvailableWorkers() { + return this.availableWorkers; + } - public void setAvailableWorkers(List availableWorkers) { - this.availableWorkers = availableWorkers; - } + public void setAvailableWorkers(List availableWorkers) { + this.availableWorkers = availableWorkers; + } - public List getSessions() { - return this.sessions; - } - - public void setSessions(List sessions) { - this.sessions = sessions; - } + public List getSessions() { + return this.sessions; + } + public void setSessions(List sessions) { + this.sessions = sessions; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/allocator/AbstractAllocator.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/allocator/AbstractAllocator.java index 470b123e4..7b61ce0bf 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/allocator/AbstractAllocator.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/allocator/AbstractAllocator.java @@ -25,54 +25,54 @@ import java.util.List; import java.util.Map; import java.util.TreeMap; + import org.apache.geaflow.cluster.resourcemanager.WorkerInfo; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public abstract class AbstractAllocator implements IAllocator { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractAllocator.class); - - protected static final WorkerGroupByFunction PROC_GROUP_SELECTOR = - worker -> String.format("%s-%d", worker.getHost(), worker.getProcessId()); + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractAllocator.class); - protected Map, LinkedList> group2workers; - - protected AbstractAllocator() { - this.group2workers = new TreeMap<>(); - } + protected static final WorkerGroupByFunction PROC_GROUP_SELECTOR = + worker -> String.format("%s-%d", worker.getHost(), worker.getProcessId()); - @Override - public List allocate(Collection idleWorkers, int num) { + protected Map, LinkedList> group2workers; - if (idleWorkers.size() < num) { - LOGGER.warn("worker not enough, available {} require {}", idleWorkers.size(), num); - return Collections.emptyList(); - } + protected AbstractAllocator() { + this.group2workers = new TreeMap<>(); + } - WorkerGroupByFunction groupSelector = this.getWorkerGroupByFunction(); - for (W worker : idleWorkers) { - Comparable group = groupSelector.getGroup(worker); - List list = this.group2workers.computeIfAbsent(group, g -> new LinkedList<>()); - list.add(worker); - } + @Override + public List allocate(Collection idleWorkers, int num) { - List allocated = doAllocate(num); - reset(); + if (idleWorkers.size() < num) { + LOGGER.warn("worker not enough, available {} require {}", idleWorkers.size(), num); + return Collections.emptyList(); + } - return allocated; + WorkerGroupByFunction groupSelector = this.getWorkerGroupByFunction(); + for (W worker : idleWorkers) { + Comparable group = groupSelector.getGroup(worker); + List list = this.group2workers.computeIfAbsent(group, g -> new LinkedList<>()); + list.add(worker); } - /** - * Allocate workers with strategy. - * - * @param num number - * @return workers - */ - protected abstract List doAllocate(int num); + List allocated = doAllocate(num); + reset(); - private void reset() { - this.group2workers.clear(); - } + return allocated; + } + + /** + * Allocate workers with strategy. + * + * @param num number + * @return workers + */ + protected abstract List doAllocate(int num); + private void reset() { + this.group2workers.clear(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/allocator/IAllocator.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/allocator/IAllocator.java index 02f5b3594..de12e1ff6 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/allocator/IAllocator.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/allocator/IAllocator.java @@ -24,51 +24,46 @@ public interface IAllocator { - AllocateStrategy DEFAULT_ALLOCATE_STRATEGY = AllocateStrategy.ROUND_ROBIN; + AllocateStrategy DEFAULT_ALLOCATE_STRATEGY = AllocateStrategy.ROUND_ROBIN; - enum AllocateStrategy { - /** - * Round-robin. - */ - ROUND_ROBIN, - - /** - * Allocate same number of workers on every JVM process. - * Require number should be a multiple of the number of JVM, or else return zero. - */ - PROCESS_FAIR - } - - @FunctionalInterface - interface WorkerGroupByFunction { - - /** - * Get the group the worker belongs. - */ - Comparable getGroup(W worker); - } + enum AllocateStrategy { + /** Round-robin. */ + ROUND_ROBIN, /** - * Strategy of this allocator. - * - * @return allocate strategy + * Allocate same number of workers on every JVM process. Require number should be a multiple of + * the number of JVM, or else return zero. */ - AllocateStrategy getStrategy(); + PROCESS_FAIR + } - /** - * Worker group selector of this allocator. - * - * @return worker group selector - */ - WorkerGroupByFunction getWorkerGroupByFunction(); + @FunctionalInterface + interface WorkerGroupByFunction { - /** - * Allocate workers. - * - * @param idleWorkers workers to allocate - * @param num number - * @return allocated workers - */ - List allocate(Collection idleWorkers, int num); + /** Get the group the worker belongs. */ + Comparable getGroup(W worker); + } + + /** + * Strategy of this allocator. + * + * @return allocate strategy + */ + AllocateStrategy getStrategy(); + + /** + * Worker group selector of this allocator. + * + * @return worker group selector + */ + WorkerGroupByFunction getWorkerGroupByFunction(); + /** + * Allocate workers. + * + * @param idleWorkers workers to allocate + * @param num number + * @return allocated workers + */ + List allocate(Collection idleWorkers, int num); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/allocator/ProcessFairAllocator.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/allocator/ProcessFairAllocator.java index c47dc866e..ba500ba64 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/allocator/ProcessFairAllocator.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/allocator/ProcessFairAllocator.java @@ -26,76 +26,83 @@ import java.util.List; import java.util.Map; import java.util.stream.IntStream; + import org.apache.geaflow.cluster.resourcemanager.WorkerInfo; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class ProcessFairAllocator extends AbstractAllocator { - private static final Logger LOGGER = LoggerFactory.getLogger(ProcessFairAllocator.class); - - private static final String SEPARATOR = "-"; + private static final Logger LOGGER = LoggerFactory.getLogger(ProcessFairAllocator.class); - @Override - public List doAllocate(int num) { + private static final String SEPARATOR = "-"; - int processNum = this.group2workers.size(); - if (num % processNum != 0) { - LOGGER.warn("require num must be a multiple of process num, available {} require {}", this.group2workers.size(), num); - return Collections.emptyList(); - } + @Override + public List doAllocate(int num) { - List allocated = new ArrayList<>(); - int nPerProc = num / processNum; - for (Map.Entry, LinkedList> entry : this.group2workers.entrySet()) { - String key = entry.getKey().toString(); - LinkedList list = entry.getValue(); - if (list.size() < nPerProc) { - LOGGER.warn("not enough worker for jvm {}, available {} require {}", key, list.size(), nPerProc); - return Collections.emptyList(); - } - IntStream.range(0, nPerProc).forEach(i -> allocated.add(list.pollFirst())); - } - - // Sort workers and assemble process index id for every worker. - assembleProcessIndexId(processNum, allocated); - return allocated; + int processNum = this.group2workers.size(); + if (num % processNum != 0) { + LOGGER.warn( + "require num must be a multiple of process num, available {} require {}", + this.group2workers.size(), + num); + return Collections.emptyList(); } - @Override - public WorkerGroupByFunction getWorkerGroupByFunction() { - return PROC_GROUP_SELECTOR; + List allocated = new ArrayList<>(); + int nPerProc = num / processNum; + for (Map.Entry, LinkedList> entry : + this.group2workers.entrySet()) { + String key = entry.getKey().toString(); + LinkedList list = entry.getValue(); + if (list.size() < nPerProc) { + LOGGER.warn( + "not enough worker for jvm {}, available {} require {}", key, list.size(), nPerProc); + return Collections.emptyList(); + } + IntStream.range(0, nPerProc).forEach(i -> allocated.add(list.pollFirst())); } - @Override - public AllocateStrategy getStrategy() { - return AllocateStrategy.PROCESS_FAIR; - } + // Sort workers and assemble process index id for every worker. + assembleProcessIndexId(processNum, allocated); + return allocated; + } - /** - * Sort workers and assemble process index id for every worker. - * - * @param workers - */ - public void assembleProcessIndexId(int processNum, List workers) { - // Sort worker by hostname + process id in order to ensure sort by jvm. - Collections.sort(workers, new Comparator() { - @Override - public int compare(WorkerInfo o1, WorkerInfo o2) { - return (o1.getHost() + SEPARATOR + o1.getProcessId()).compareTo( - o2.getHost() + SEPARATOR + o2.getProcessId()); - } + @Override + public WorkerGroupByFunction getWorkerGroupByFunction() { + return PROC_GROUP_SELECTOR; + } + + @Override + public AllocateStrategy getStrategy() { + return AllocateStrategy.PROCESS_FAIR; + } + + /** + * Sort workers and assemble process index id for every worker. + * + * @param workers + */ + public void assembleProcessIndexId(int processNum, List workers) { + // Sort worker by hostname + process id in order to ensure sort by jvm. + Collections.sort( + workers, + new Comparator() { + @Override + public int compare(WorkerInfo o1, WorkerInfo o2) { + return (o1.getHost() + SEPARATOR + o1.getProcessId()) + .compareTo(o2.getHost() + SEPARATOR + o2.getProcessId()); + } }); - int step = workers.size() / processNum; - int globalProcessIndex = 0; - // Assemble the worker jvm index. - for (int i = 0; i < workers.size(); i += step) { - for (int j = 0; j < step; j++) { - workers.get(i + j).setProcessIndex(globalProcessIndex); - } - globalProcessIndex++; - } + int step = workers.size() / processNum; + int globalProcessIndex = 0; + // Assemble the worker jvm index. + for (int i = 0; i < workers.size(); i += step) { + for (int j = 0; j < step; j++) { + workers.get(i + j).setProcessIndex(globalProcessIndex); + } + globalProcessIndex++; } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/allocator/RoundRobinAllocator.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/allocator/RoundRobinAllocator.java index 2b7aa565d..99d9f4cf0 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/allocator/RoundRobinAllocator.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/resourcemanager/allocator/RoundRobinAllocator.java @@ -23,41 +23,40 @@ import java.util.Iterator; import java.util.LinkedList; import java.util.List; + import org.apache.geaflow.cluster.resourcemanager.WorkerInfo; public class RoundRobinAllocator extends AbstractAllocator { - @Override - public List doAllocate(int num) { - - List allocated = new ArrayList<>(); - int n = 0; - Iterator> groupIterator = this.group2workers.keySet().iterator(); - while (n < num) { - if (groupIterator.hasNext()) { - Comparable group = groupIterator.next(); - LinkedList next = this.group2workers.get(group); - if (!next.isEmpty()) { - allocated.add(next.pollFirst()); - n++; - } - } else { - groupIterator = this.group2workers.keySet().iterator(); - } + @Override + public List doAllocate(int num) { + + List allocated = new ArrayList<>(); + int n = 0; + Iterator> groupIterator = this.group2workers.keySet().iterator(); + while (n < num) { + if (groupIterator.hasNext()) { + Comparable group = groupIterator.next(); + LinkedList next = this.group2workers.get(group); + if (!next.isEmpty()) { + allocated.add(next.pollFirst()); + n++; } - - return allocated; + } else { + groupIterator = this.group2workers.keySet().iterator(); + } } - @Override - public AllocateStrategy getStrategy() { - return AllocateStrategy.ROUND_ROBIN; - } + return allocated; + } - @Override - public WorkerGroupByFunction getWorkerGroupByFunction() { - return PROC_GROUP_SELECTOR; - } + @Override + public AllocateStrategy getStrategy() { + return AllocateStrategy.ROUND_ROBIN; + } + @Override + public WorkerGroupByFunction getWorkerGroupByFunction() { + return PROC_GROUP_SELECTOR; + } } - diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/response/IResult.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/response/IResult.java index 1a76df82b..697f05857 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/response/IResult.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/response/IResult.java @@ -20,23 +20,17 @@ package org.apache.geaflow.cluster.response; import java.util.List; + import org.apache.geaflow.shuffle.desc.OutputType; public interface IResult { - /** - * Returns id. - */ - int getId(); - - /** - * Returns the response. - */ - List getResponse(); + /** Returns id. */ + int getId(); - /** - * Returns the response type. - */ - OutputType getType(); + /** Returns the response. */ + List getResponse(); + /** Returns the response type. */ + OutputType getType(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/response/ResponseResult.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/response/ResponseResult.java index 77adc64ba..a418c20d0 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/response/ResponseResult.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/response/ResponseResult.java @@ -21,34 +21,33 @@ import java.io.Serializable; import java.util.List; + import org.apache.geaflow.shuffle.desc.OutputType; public class ResponseResult implements IResult, Serializable { - private int collectId; - private OutputType outputType; - private List responses; - - public ResponseResult(int collectId, OutputType outputType, List responses) { - this.collectId = collectId; - this.outputType = outputType; - this.responses = responses; - } - - @Override - public int getId() { - return collectId; - } - - @Override - public List getResponse() { - return responses; - } - - @Override - public OutputType getType() { - return outputType; - } - + private int collectId; + private OutputType outputType; + private List responses; + + public ResponseResult(int collectId, OutputType outputType, List responses) { + this.collectId = collectId; + this.outputType = outputType; + this.responses = responses; + } + + @Override + public int getId() { + return collectId; + } + + @Override + public List getResponse() { + return responses; + } + + @Override + public OutputType getType() { + return outputType; + } } - diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/response/ShardResult.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/response/ShardResult.java index fb5887aea..b3a965c9b 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/response/ShardResult.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/response/ShardResult.java @@ -20,52 +20,52 @@ package org.apache.geaflow.cluster.response; import java.util.List; + import org.apache.geaflow.shuffle.desc.OutputType; import org.apache.geaflow.shuffle.message.ISliceMeta; public class ShardResult implements IResult { - /** - * Use edge id of output info to identify the result. - */ - private int id; - private OutputType outputType; - private List slices; - private long recordNum; - private long recordBytes; + /** Use edge id of output info to identify the result. */ + private int id; - public ShardResult(int id, OutputType outputType, List slices) { - this.id = id; - this.outputType = outputType; - this.slices = slices; - if (slices != null) { - for (ISliceMeta sliceMeta : slices) { - recordNum += sliceMeta.getRecordNum(); - recordBytes += sliceMeta.getEncodedSize(); - } - } - } + private OutputType outputType; + private List slices; + private long recordNum; + private long recordBytes; - @Override - public int getId() { - return id; + public ShardResult(int id, OutputType outputType, List slices) { + this.id = id; + this.outputType = outputType; + this.slices = slices; + if (slices != null) { + for (ISliceMeta sliceMeta : slices) { + recordNum += sliceMeta.getRecordNum(); + recordBytes += sliceMeta.getEncodedSize(); + } } + } - @Override - public List getResponse() { - return slices; - } + @Override + public int getId() { + return id; + } - public long getRecordNum() { - return recordNum; - } + @Override + public List getResponse() { + return slices; + } - public long getRecordBytes() { - return recordBytes; - } + public long getRecordNum() { + return recordNum; + } - @Override - public OutputType getType() { - return outputType; - } + public long getRecordBytes() { + return recordBytes; + } + + @Override + public OutputType getType() { + return outputType; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/ConnectAddress.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/ConnectAddress.java index e7413a927..9ada544c7 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/ConnectAddress.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/ConnectAddress.java @@ -23,62 +23,60 @@ import java.util.Objects; public class ConnectAddress implements Serializable { - public static final String PORT_SEPARATOR = ":"; - private String host; - private int port; + public static final String PORT_SEPARATOR = ":"; + private String host; + private int port; - public ConnectAddress() { - } + public ConnectAddress() {} - public ConnectAddress(String host, int port) { - this.host = host; - this.port = port; - } + public ConnectAddress(String host, int port) { + this.host = host; + this.port = port; + } - public String getHost() { - return host; - } + public String getHost() { + return host; + } - public void setHost(String host) { - this.host = host; - } + public void setHost(String host) { + this.host = host; + } - public int getPort() { - return port; - } + public int getPort() { + return port; + } - public void setPort(int port) { - this.port = port; - } + public void setPort(int port) { + this.port = port; + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - ConnectAddress that = (ConnectAddress) o; - return port == that.port && Objects.equals(host, that.host); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(host, port); + if (o == null || getClass() != o.getClass()) { + return false; } + ConnectAddress that = (ConnectAddress) o; + return port == that.port && Objects.equals(host, that.host); + } - @Override - public String toString() { - return host + PORT_SEPARATOR + port; - } + @Override + public int hashCode() { + return Objects.hash(host, port); + } - public static ConnectAddress build(String address) { - ConnectAddress rpcAddress = new ConnectAddress(); - String[] hostAndPort = address.split(PORT_SEPARATOR); - rpcAddress.setHost(hostAndPort[0]); - rpcAddress.setPort(Integer.parseInt(hostAndPort[1])); - return rpcAddress; - } + @Override + public String toString() { + return host + PORT_SEPARATOR + port; + } + public static ConnectAddress build(String address) { + ConnectAddress rpcAddress = new ConnectAddress(); + String[] hostAndPort = address.split(PORT_SEPARATOR); + rpcAddress.setHost(hostAndPort[0]); + rpcAddress.setPort(Integer.parseInt(hostAndPort[1])); + return rpcAddress; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IAsyncContainerEndpoint.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IAsyncContainerEndpoint.java index 57a15abd7..08b52d561 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IAsyncContainerEndpoint.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IAsyncContainerEndpoint.java @@ -19,16 +19,15 @@ package org.apache.geaflow.cluster.rpc; -import com.baidu.brpc.client.RpcCallback; import java.util.concurrent.Future; + import org.apache.geaflow.rpc.proto.Container.Request; import org.apache.geaflow.rpc.proto.Container.Response; -public interface IAsyncContainerEndpoint extends IContainerEndpoint { +import com.baidu.brpc.client.RpcCallback; - /** - * Async container process. - */ - Future process(Request request, RpcCallback callback); +public interface IAsyncContainerEndpoint extends IContainerEndpoint { + /** Async container process. */ + Future process(Request request, RpcCallback callback); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IAsyncDriverEndpoint.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IAsyncDriverEndpoint.java index c9cb34f57..24f892825 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IAsyncDriverEndpoint.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IAsyncDriverEndpoint.java @@ -19,16 +19,15 @@ package org.apache.geaflow.cluster.rpc; -import com.baidu.brpc.client.RpcCallback; import java.util.concurrent.Future; + import org.apache.geaflow.rpc.proto.Driver.PipelineReq; import org.apache.geaflow.rpc.proto.Driver.PipelineRes; -public interface IAsyncDriverEndpoint extends IDriverEndpoint { +import com.baidu.brpc.client.RpcCallback; - /** - * Async execute pipeline. - */ - Future executePipeline(PipelineReq request, RpcCallback callback); +public interface IAsyncDriverEndpoint extends IDriverEndpoint { + /** Async execute pipeline. */ + Future executePipeline(PipelineReq request, RpcCallback callback); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IAsyncMasterEndpoint.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IAsyncMasterEndpoint.java index 720b83550..befdefc91 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IAsyncMasterEndpoint.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IAsyncMasterEndpoint.java @@ -19,29 +19,26 @@ package org.apache.geaflow.cluster.rpc; -import com.baidu.brpc.client.RpcCallback; import java.util.concurrent.Future; + import org.apache.geaflow.rpc.proto.Master.HeartbeatRequest; import org.apache.geaflow.rpc.proto.Master.HeartbeatResponse; import org.apache.geaflow.rpc.proto.Master.RegisterRequest; import org.apache.geaflow.rpc.proto.Master.RegisterResponse; -public interface IAsyncMasterEndpoint extends IMasterEndpoint { +import com.baidu.brpc.client.RpcCallback; - /** - * Async register container. - */ - Future registerContainer(RegisterRequest request, RpcCallback callback); +public interface IAsyncMasterEndpoint extends IMasterEndpoint { - /** - * Async register driver. - */ - Future registerDriver(RegisterRequest request, RpcCallback callback); + /** Async register container. */ + Future registerContainer( + RegisterRequest request, RpcCallback callback); - /** - * Async receive heartbeat. - */ - Future receiveHeartbeat(HeartbeatRequest request, RpcCallback callback); + /** Async register driver. */ + Future registerDriver( + RegisterRequest request, RpcCallback callback); + /** Async receive heartbeat. */ + Future receiveHeartbeat( + HeartbeatRequest request, RpcCallback callback); } - diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IAsyncMetricEndpoint.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IAsyncMetricEndpoint.java index f8e4a9745..fc0de56f0 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IAsyncMetricEndpoint.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IAsyncMetricEndpoint.java @@ -19,16 +19,16 @@ package org.apache.geaflow.cluster.rpc; -import com.baidu.brpc.client.RpcCallback; import java.util.concurrent.Future; + import org.apache.geaflow.rpc.proto.Metrics.MetricQueryRequest; import org.apache.geaflow.rpc.proto.Metrics.MetricQueryResponse; +import com.baidu.brpc.client.RpcCallback; + public interface IAsyncMetricEndpoint extends IMetricEndpoint { - /** - * Async query metrics. - */ - Future queryMetrics(MetricQueryRequest request, - RpcCallback callback); + /** Async query metrics. */ + Future queryMetrics( + MetricQueryRequest request, RpcCallback callback); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IAsyncSupervisorEndpoint.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IAsyncSupervisorEndpoint.java index a1771f525..2b5de920b 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IAsyncSupervisorEndpoint.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IAsyncSupervisorEndpoint.java @@ -19,12 +19,14 @@ package org.apache.geaflow.cluster.rpc; -import com.baidu.brpc.client.RpcCallback; -import com.google.protobuf.Empty; import java.util.concurrent.Future; + import org.apache.geaflow.rpc.proto.Supervisor.RestartRequest; +import com.baidu.brpc.client.RpcCallback; +import com.google.protobuf.Empty; + public interface IAsyncSupervisorEndpoint extends ISupervisorEndpoint { - Future restart(RestartRequest request, RpcCallback callback); + Future restart(RestartRequest request, RpcCallback callback); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IContainerEndpoint.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IContainerEndpoint.java index a2a7b3c43..1962e4004 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IContainerEndpoint.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IContainerEndpoint.java @@ -19,21 +19,18 @@ package org.apache.geaflow.cluster.rpc; -import com.google.protobuf.Empty; import java.io.Serializable; + import org.apache.geaflow.rpc.proto.Container.Request; import org.apache.geaflow.rpc.proto.Container.Response; -public interface IContainerEndpoint extends Serializable { +import com.google.protobuf.Empty; - /** - * Container process. - */ - Response process(Request request); +public interface IContainerEndpoint extends Serializable { - /** - * Container close. - */ - Empty close(Empty request); + /** Container process. */ + Response process(Request request); + /** Container close. */ + Empty close(Empty request); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IContainerEndpointRef.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IContainerEndpointRef.java index c0472d9f0..24b4e274d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IContainerEndpointRef.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IContainerEndpointRef.java @@ -21,15 +21,13 @@ import java.io.Serializable; import java.util.concurrent.Future; + import org.apache.geaflow.cluster.protocol.IEvent; import org.apache.geaflow.cluster.rpc.RpcEndpointRef.RpcCallback; import org.apache.geaflow.rpc.proto.Container.Response; public interface IContainerEndpointRef extends Serializable { - /** - * Process event request. - */ - Future process(IEvent request, RpcCallback callback); - + /** Process event request. */ + Future process(IEvent request, RpcCallback callback); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IDriverEndpoint.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IDriverEndpoint.java index d5e80d464..06c77667a 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IDriverEndpoint.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IDriverEndpoint.java @@ -19,19 +19,16 @@ package org.apache.geaflow.cluster.rpc; -import com.google.protobuf.Empty; import org.apache.geaflow.rpc.proto.Driver.PipelineReq; import org.apache.geaflow.rpc.proto.Driver.PipelineRes; +import com.google.protobuf.Empty; + public interface IDriverEndpoint { - /** - * Driver execute pipeline. - */ - PipelineRes executePipeline(PipelineReq request); + /** Driver execute pipeline. */ + PipelineRes executePipeline(PipelineReq request); - /** - * Driver close. - */ - Empty close(Empty request); + /** Driver close. */ + Empty close(Empty request); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IDriverEndpointRef.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IDriverEndpointRef.java index 328a24e7e..f4045c75a 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IDriverEndpointRef.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IDriverEndpointRef.java @@ -20,14 +20,12 @@ package org.apache.geaflow.cluster.rpc; import java.io.Serializable; + import org.apache.geaflow.pipeline.IPipelineResult; import org.apache.geaflow.pipeline.Pipeline; public interface IDriverEndpointRef extends Serializable { - /** - * Receive and execute pipeline. - */ - IPipelineResult executePipeline(Pipeline pipeline); - + /** Receive and execute pipeline. */ + IPipelineResult executePipeline(Pipeline pipeline); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IMasterEndpoint.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IMasterEndpoint.java index 237254703..a21cf364f 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IMasterEndpoint.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IMasterEndpoint.java @@ -19,37 +19,27 @@ package org.apache.geaflow.cluster.rpc; -import com.google.protobuf.Empty; import org.apache.geaflow.rpc.proto.Master.HeartbeatRequest; import org.apache.geaflow.rpc.proto.Master.HeartbeatResponse; import org.apache.geaflow.rpc.proto.Master.RegisterRequest; import org.apache.geaflow.rpc.proto.Master.RegisterResponse; +import com.google.protobuf.Empty; + public interface IMasterEndpoint { - /** - * Register container into master. - */ - RegisterResponse registerContainer(RegisterRequest request); - - /** - * Register driver into master. - */ - RegisterResponse registerDriver(RegisterRequest request); - - /** - * Receive heart beat. - */ - HeartbeatResponse receiveHeartbeat(HeartbeatRequest request); - - /** - * Receive exception. - */ - Empty receiveException(HeartbeatRequest request); - - /** - * Close master. - */ - Empty close(Empty request); -} + /** Register container into master. */ + RegisterResponse registerContainer(RegisterRequest request); + /** Register driver into master. */ + RegisterResponse registerDriver(RegisterRequest request); + + /** Receive heart beat. */ + HeartbeatResponse receiveHeartbeat(HeartbeatRequest request); + + /** Receive exception. */ + Empty receiveException(HeartbeatRequest request); + + /** Close master. */ + Empty close(Empty request); +} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IMasterEndpointRef.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IMasterEndpointRef.java index 02b588dc4..1a6edb590 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IMasterEndpointRef.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IMasterEndpointRef.java @@ -19,29 +19,25 @@ package org.apache.geaflow.cluster.rpc; -import com.google.protobuf.Empty; import java.io.Serializable; import java.util.concurrent.Future; + import org.apache.geaflow.cluster.rpc.RpcEndpointRef.RpcCallback; import org.apache.geaflow.common.heartbeat.Heartbeat; import org.apache.geaflow.rpc.proto.Master.HeartbeatResponse; import org.apache.geaflow.rpc.proto.Master.RegisterResponse; -public interface IMasterEndpointRef extends Serializable { +import com.google.protobuf.Empty; - /** - * Register container into master. - */ - Future registerContainer(T request, RpcCallback callback); +public interface IMasterEndpointRef extends Serializable { - /** - * Send heartbeat. - */ - Future sendHeartBeat(Heartbeat heartbeat, RpcCallback callback); + /** Register container into master. */ + Future registerContainer(T request, RpcCallback callback); - /** - * Send exception. - */ - Empty sendException(Integer containerId, String containerName, String message); + /** Send heartbeat. */ + Future sendHeartBeat( + Heartbeat heartbeat, RpcCallback callback); + /** Send exception. */ + Empty sendException(Integer containerId, String containerName, String message); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IMetricEndpoint.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IMetricEndpoint.java index 23102c20e..1f3806525 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IMetricEndpoint.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IMetricEndpoint.java @@ -24,9 +24,6 @@ public interface IMetricEndpoint { - /** - * Query metrics. - */ - MetricQueryResponse queryMetrics(MetricQueryRequest request); - + /** Query metrics. */ + MetricQueryResponse queryMetrics(MetricQueryRequest request); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IMetricEndpointRef.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IMetricEndpointRef.java index 52c7bdd27..74601d87c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IMetricEndpointRef.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IMetricEndpointRef.java @@ -20,14 +20,13 @@ package org.apache.geaflow.cluster.rpc; import java.util.concurrent.Future; + import org.apache.geaflow.rpc.proto.Metrics.MetricQueryRequest; import org.apache.geaflow.rpc.proto.Metrics.MetricQueryResponse; public interface IMetricEndpointRef extends RpcEndpointRef { - /** - * Async query metrics. - */ - Future queryMetrics(MetricQueryRequest request, RpcCallback callback); - + /** Async query metrics. */ + Future queryMetrics( + MetricQueryRequest request, RpcCallback callback); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IPipelineManagerEndpointRef.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IPipelineManagerEndpointRef.java index 1a610c4df..4fadc8353 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IPipelineManagerEndpointRef.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IPipelineManagerEndpointRef.java @@ -19,6 +19,4 @@ package org.apache.geaflow.cluster.rpc; -public interface IPipelineManagerEndpointRef extends IContainerEndpointRef { - -} +public interface IPipelineManagerEndpointRef extends IContainerEndpointRef {} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IPipelineMasterEndpoint.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IPipelineMasterEndpoint.java index dc09321ce..79034361a 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IPipelineMasterEndpoint.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IPipelineMasterEndpoint.java @@ -24,9 +24,6 @@ public interface IPipelineMasterEndpoint { - /** - * Pipeline master process. - */ - Response process(Request request); - -} \ No newline at end of file + /** Pipeline master process. */ + Response process(Request request); +} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IResourceEndpointRef.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IResourceEndpointRef.java index aafdd8a13..1fac62b04 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IResourceEndpointRef.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IResourceEndpointRef.java @@ -20,6 +20,7 @@ package org.apache.geaflow.cluster.rpc; import java.io.Serializable; + import org.apache.geaflow.cluster.resourcemanager.ReleaseResourceRequest; import org.apache.geaflow.cluster.resourcemanager.ReleaseResponse; import org.apache.geaflow.cluster.resourcemanager.RequireResourceRequest; @@ -27,14 +28,9 @@ public interface IResourceEndpointRef extends Serializable { - /** - * Require resource. - */ - RequireResponse requireResource(RequireResourceRequest request); - - /** - * Release resource. - */ - ReleaseResponse releaseResource(ReleaseResourceRequest request); + /** Require resource. */ + RequireResponse requireResource(RequireResourceRequest request); + /** Release resource. */ + ReleaseResponse releaseResource(ReleaseResourceRequest request); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IResourceManagerEndpoint.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IResourceManagerEndpoint.java index 866d2f4e2..8ff0a3c95 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IResourceManagerEndpoint.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/IResourceManagerEndpoint.java @@ -25,14 +25,9 @@ public interface IResourceManagerEndpoint { - /** - * Require resource. - */ - RequireResourceResponse requireResource(Resource.RequireResourceRequest request); - - /** - * Release resource. - */ - ReleaseResourceResponse releaseResource(Resource.ReleaseResourceRequest request); + /** Require resource. */ + RequireResourceResponse requireResource(Resource.RequireResourceRequest request); + /** Release resource. */ + ReleaseResourceResponse releaseResource(Resource.ReleaseResourceRequest request); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/ISupervisorEndpoint.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/ISupervisorEndpoint.java index ca30875ee..fa491868c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/ISupervisorEndpoint.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/ISupervisorEndpoint.java @@ -19,14 +19,14 @@ package org.apache.geaflow.cluster.rpc; -import com.google.protobuf.Empty; import org.apache.geaflow.rpc.proto.Supervisor.RestartRequest; import org.apache.geaflow.rpc.proto.Supervisor.StatusResponse; -public interface ISupervisorEndpoint extends RpcEndpoint { +import com.google.protobuf.Empty; - Empty restart(RestartRequest request); +public interface ISupervisorEndpoint extends RpcEndpoint { - StatusResponse status(Empty empty); + Empty restart(RestartRequest request); + StatusResponse status(Empty empty); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/ISupervisorEndpointRef.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/ISupervisorEndpointRef.java index 3f01d5888..ac14a16ac 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/ISupervisorEndpointRef.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/ISupervisorEndpointRef.java @@ -19,14 +19,15 @@ package org.apache.geaflow.cluster.rpc; -import com.google.protobuf.Empty; import java.util.concurrent.Future; + import org.apache.geaflow.rpc.proto.Supervisor.StatusResponse; -public interface ISupervisorEndpointRef extends RpcEndpointRef { +import com.google.protobuf.Empty; - Future restart(int pid, RpcCallback callback); +public interface ISupervisorEndpointRef extends RpcEndpointRef { - StatusResponse status(); + Future restart(int pid, RpcCallback callback); + StatusResponse status(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcClient.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcClient.java index a771364c9..ab4c1b34e 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcClient.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcClient.java @@ -29,7 +29,6 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.HEARTBEAT_TIMEOUT_MS; import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.RPC_ASYNC_THREADS; -import com.google.protobuf.Empty; import java.io.Serializable; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; @@ -38,6 +37,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.cluster.protocol.IEvent; import org.apache.geaflow.cluster.resourcemanager.ReleaseResourceRequest; import org.apache.geaflow.cluster.resourcemanager.ReleaseResponse; @@ -74,305 +74,350 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.protobuf.Empty; + public class RpcClient implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(RpcClient.class); - private static final int RPC_RETRY_EXTRA_MS = 30000; - private static IHAService haService; - private static RpcEndpointRefFactory refFactory; - private static RpcClient INSTANCE; - private final int retryTimes; - private final int retryIntervalMs; - private final ExecutorService executorService; - - private RpcClient(Configuration configuration) { - // Ensure total retry time be longer than (heartbeat timeout + 30s). - retryIntervalMs = configuration.getInteger(ExecutionConfigKeys.RPC_RETRY_INTERVAL_MS); - int heartbeatTimeoutMs = configuration.getInteger(HEARTBEAT_TIMEOUT_MS); - int minTimes = (int) Math.ceil( - (double) (heartbeatTimeoutMs + RPC_RETRY_EXTRA_MS) / retryIntervalMs); - int rpcRetryTimes = configuration.getInteger(ExecutionConfigKeys.RPC_RETRY_TIMES); - retryTimes = Math.max(minTimes, rpcRetryTimes); - refFactory = RpcEndpointRefFactory.getInstance(configuration); - haService = HAServiceFactory.getService(configuration); - - int threads = configuration.getInteger(RPC_ASYNC_THREADS); - this.executorService = new ThreadPoolExecutor(threads, threads, Long.MAX_VALUE, - TimeUnit.MINUTES, new LinkedBlockingQueue<>(), + private static final Logger LOGGER = LoggerFactory.getLogger(RpcClient.class); + private static final int RPC_RETRY_EXTRA_MS = 30000; + private static IHAService haService; + private static RpcEndpointRefFactory refFactory; + private static RpcClient INSTANCE; + private final int retryTimes; + private final int retryIntervalMs; + private final ExecutorService executorService; + + private RpcClient(Configuration configuration) { + // Ensure total retry time be longer than (heartbeat timeout + 30s). + retryIntervalMs = configuration.getInteger(ExecutionConfigKeys.RPC_RETRY_INTERVAL_MS); + int heartbeatTimeoutMs = configuration.getInteger(HEARTBEAT_TIMEOUT_MS); + int minTimes = + (int) Math.ceil((double) (heartbeatTimeoutMs + RPC_RETRY_EXTRA_MS) / retryIntervalMs); + int rpcRetryTimes = configuration.getInteger(ExecutionConfigKeys.RPC_RETRY_TIMES); + retryTimes = Math.max(minTimes, rpcRetryTimes); + refFactory = RpcEndpointRefFactory.getInstance(configuration); + haService = HAServiceFactory.getService(configuration); + + int threads = configuration.getInteger(RPC_ASYNC_THREADS); + this.executorService = + new ThreadPoolExecutor( + threads, + threads, + Long.MAX_VALUE, + TimeUnit.MINUTES, + new LinkedBlockingQueue<>(), ThreadUtil.namedThreadFactory(true, "rpc-executor")); - LOGGER.info("RpcClient init retryTimes:{} retryIntervalMs:{} threads:{}", retryTimes, - retryIntervalMs, threads); - } - - public static synchronized RpcClient init(Configuration configuration) { - if (INSTANCE == null) { - INSTANCE = new RpcClient(configuration); - } - return INSTANCE; - } - - public static synchronized RpcClient getInstance() { - return INSTANCE; - } - - // Master endpoint ref. - public void registerContainer(String masterId, T info, - RpcCallback callback) { - doRpcWithRetry(() -> { - MasterEndpointRef endpointRef = connectMaster(masterId); - if (endpointRef == null) { - LOGGER.warn("Cannot register container with master {}: endpoint not available", masterId); - return; - } - endpointRef.registerContainer(info, new DefaultRpcCallbackImpl<>(callback, masterId, haService)); - }, masterId, MASTER); - } - - public void sendHeartBeat(String masterId, Heartbeat heartbeat, - RpcCallback callback) { - doRpcWithRetry(() -> { - MasterEndpointRef endpointRef = connectMaster(masterId); - if (endpointRef == null) { - LOGGER.warn("Cannot send heartbeat to master {}: endpoint not available", masterId); - return; - } - endpointRef.sendHeartBeat(heartbeat, new DefaultRpcCallbackImpl<>(callback, masterId, haService)); - }, masterId, MASTER); - } - - public Empty sendException(String masterId, Integer containerId, String containerName, - Throwable throwable) { - return doRpcWithRetry(() -> { - MasterEndpointRef endpointRef = connectMaster(masterId); - if (endpointRef == null) { - LOGGER.warn("Cannot send exception to master {}: endpoint not available", masterId); - return Empty.getDefaultInstance(); - } - return endpointRef.sendException(containerId, containerName, throwable.getMessage()); - }, masterId, MASTER); - } - - // Container endpoint ref. - public Future processContainer(String containerId, IEvent event) { - return doRpcWithRetry(() -> { - ContainerEndpointRef endpointRef = connectContainer(containerId); - if (endpointRef == null) { - LOGGER.warn("Cannot process container event for {}: endpoint not available", containerId); - return null; - } - return endpointRef.process(event, new DefaultRpcCallbackImpl(null, containerId, haService)); - }, containerId, CONTAINER); - } - - public void processContainer(String containerId, IEvent event, RpcCallback callback) { - doRpcWithRetry(() -> { - ContainerEndpointRef endpointRef = connectContainer(containerId); - if (endpointRef == null) { - LOGGER.warn("Cannot process container event for {}: endpoint not available", containerId); - return; - } - endpointRef.process(event, new DefaultRpcCallbackImpl<>(callback, containerId, haService)); - }, containerId, CONTAINER); - } - - // Pipeline endpoint ref. - public void processPipeline(String driverId, IEvent event) { - doRpcWithRetry( - () -> connectPipelineManager(driverId).process(event, new DefaultRpcCallbackImpl<>()), - driverId, PIPELINE_MANAGER); - } - - public IPipelineResult executePipeline(String driverId, Pipeline pipeline) { - return doRpcWithRetry(() -> connectDriver(driverId).executePipeline(pipeline), driverId, - DRIVER); - } - - // Resource manager endpoint ref. - public RequireResponse requireResource(String masterId, RequireResourceRequest request) { - return doRpcWithRetry(() -> connectRM(masterId).requireResource(request), masterId, - RESOURCE_MANAGER); - } - - public ReleaseResponse releaseResource(String masterId, ReleaseResourceRequest request) { - return doRpcWithRetry(() -> connectRM(masterId).releaseResource(request), masterId, - RESOURCE_MANAGER); - } - - public Future requestMetrics(String id, MetricQueryRequest request, - RpcCallback callback) { - return doRpcWithRetry(() -> connectMetricServer(id).queryMetrics(request, - new DefaultRpcCallbackImpl<>(callback, id, haService)), id, METRIC); - } - - public Future restartWorkerBySupervisor(String id, boolean fastFailure) { - int retries = fastFailure ? 1 : retryTimes; - try { - return doRpcWithRetry(() -> { - ResourceData resourceData = loadSupervisorData(id, fastFailure); - return connectSupervisor(resourceData).restart(resourceData.getProcessId(), - new DefaultRpcCallbackImpl<>()); - }, id, SUPERVISOR, retries); - } catch (Throwable e) { - CompletableFuture result = new CompletableFuture<>(); - result.completeExceptionally(e); - return result; - } - } - - public StatusResponse queryWorkerStatusBySupervisor(String id) { - return connectSupervisor(id).status(); - } - - // Close endpoint connection. - public void closeMasterConnection(String masterId) { - MasterEndpointRef endpointRef = connectMaster(masterId); - if (endpointRef != null) { - endpointRef.closeEndpoint(); - } else { - LOGGER.debug("No endpoint reference found for master: {}, skipping close", masterId); - } - } - - public void closeDriverConnection(String driverId) { - DriverEndpointRef endpointRef = connectDriver(driverId); - if (endpointRef != null) { - endpointRef.closeEndpoint(); - } else { - LOGGER.debug("No endpoint reference found for driver: {}, skipping close", driverId); - } - } - - public void closeContainerConnection(String containerId) { - ContainerEndpointRef endpointRef = connectContainer(containerId); - if (endpointRef != null) { - endpointRef.closeEndpoint(); - } else { - LOGGER.debug("No endpoint reference found for container: {}, skipping close", containerId); - } - } - - private MasterEndpointRef connectMaster(String masterId) { - ResourceData resourceData = getResourceData(masterId); - if (resourceData == null) { - LOGGER.warn("Resource data not found for master: {}, skipping connection", masterId); - return null; - } - return refFactory.connectMaster(resourceData.getHost(), resourceData.getRpcPort()); - } - - private ResourceManagerEndpointRef connectRM(String masterId) { - ResourceData resourceData = getResourceData(masterId); - return refFactory.connectResourceManager(resourceData.getHost(), resourceData.getRpcPort()); - } - - private DriverEndpointRef connectDriver(String driverId) { - ResourceData resourceData = getResourceData(driverId); - if (resourceData == null) { - LOGGER.warn("Resource data not found for driver: {}, skipping connection", driverId); - return null; - } - return refFactory.connectDriver(resourceData.getHost(), resourceData.getRpcPort()); - } - - private ContainerEndpointRef connectContainer(String containerId) { - ResourceData resourceData = getResourceData(containerId); - if (resourceData == null) { - LOGGER.warn("Resource data not found for container: {}, skipping connection", containerId); + LOGGER.info( + "RpcClient init retryTimes:{} retryIntervalMs:{} threads:{}", + retryTimes, + retryIntervalMs, + threads); + } + + public static synchronized RpcClient init(Configuration configuration) { + if (INSTANCE == null) { + INSTANCE = new RpcClient(configuration); + } + return INSTANCE; + } + + public static synchronized RpcClient getInstance() { + return INSTANCE; + } + + // Master endpoint ref. + public void registerContainer( + String masterId, T info, RpcCallback callback) { + doRpcWithRetry( + () -> { + MasterEndpointRef endpointRef = connectMaster(masterId); + if (endpointRef == null) { + LOGGER.warn( + "Cannot register container with master {}: endpoint not available", masterId); + return; + } + endpointRef.registerContainer( + info, new DefaultRpcCallbackImpl<>(callback, masterId, haService)); + }, + masterId, + MASTER); + } + + public void sendHeartBeat( + String masterId, Heartbeat heartbeat, RpcCallback callback) { + doRpcWithRetry( + () -> { + MasterEndpointRef endpointRef = connectMaster(masterId); + if (endpointRef == null) { + LOGGER.warn("Cannot send heartbeat to master {}: endpoint not available", masterId); + return; + } + endpointRef.sendHeartBeat( + heartbeat, new DefaultRpcCallbackImpl<>(callback, masterId, haService)); + }, + masterId, + MASTER); + } + + public Empty sendException( + String masterId, Integer containerId, String containerName, Throwable throwable) { + return doRpcWithRetry( + () -> { + MasterEndpointRef endpointRef = connectMaster(masterId); + if (endpointRef == null) { + LOGGER.warn("Cannot send exception to master {}: endpoint not available", masterId); + return Empty.getDefaultInstance(); + } + return endpointRef.sendException(containerId, containerName, throwable.getMessage()); + }, + masterId, + MASTER); + } + + // Container endpoint ref. + public Future processContainer(String containerId, IEvent event) { + return doRpcWithRetry( + () -> { + ContainerEndpointRef endpointRef = connectContainer(containerId); + if (endpointRef == null) { + LOGGER.warn( + "Cannot process container event for {}: endpoint not available", containerId); return null; - } - return refFactory.connectContainer(resourceData.getHost(), resourceData.getRpcPort()); - } - - private PipelineMasterEndpointRef connectPipelineManager(String id) { - ResourceData resourceData = getResourceData(id); - return refFactory.connectPipelineManager(resourceData.getHost(), resourceData.getRpcPort()); - } - - private MetricEndpointRef connectMetricServer(String id) { - ResourceData resourceData = getResourceData(id); - return refFactory.connectMetricServer(resourceData.getHost(), resourceData.getMetricPort()); - } - - private SupervisorEndpointRef connectSupervisor(String id) { - ResourceData resourceData = loadSupervisorData(id, true); - return connectSupervisor(resourceData); - } - - private SupervisorEndpointRef connectSupervisor(ResourceData resourceData) { - return refFactory.connectSupervisor(resourceData.getHost(), - resourceData.getSupervisorPort()); - } - - private ResourceData loadSupervisorData(String id, boolean fastFailure) { - ResourceData resourceData; - if (fastFailure) { - resourceData = haService.loadResource(id); - } else { - resourceData = ((AbstractHAService) haService).loadDataFromStore(id, - true, ResourceData::getSupervisorPort); - } - return resourceData; - } - - private T doRpcWithRetry(Callable function, String resourceId, - EndpointType endpointType) { - return doRpcWithRetry(function, resourceId, endpointType, retryTimes); - } - - private T doRpcWithRetry(Callable function, String resourceId, EndpointType endpointType, - int retryTimes) { - return RetryCommand.run(() -> { - try { - return function.call(); - } catch (Throwable t) { - throw handleRpcException(resourceId, endpointType, t); - } - }, retryTimes, retryIntervalMs); - } - - private void doRpcWithRetry(Runnable function, String resourceId, EndpointType endpointType) { - RetryCommand.run(() -> { - try { - function.run(); - } catch (Throwable t) { - throw handleRpcException(resourceId, endpointType, t); - } - return null; - }, retryTimes, retryIntervalMs); - } - - private Exception handleRpcException(String resourceId, EndpointType endpointType, - Throwable t) { - try { - invalidateEndpointCache(resourceId, endpointType); - } catch (Throwable e) { - LOGGER.warn("invalidate rpc cache {} failed: {}", resourceId, e); - } - return new GeaflowRuntimeException(String.format("do rpc failed. %s", t.getMessage()), t); - } - - protected void invalidateEndpointCache(String resourceId, EndpointType endpointType) { - ResourceData resourceData = haService.invalidateResource(resourceId); - if (resourceData != null) { - refFactory.invalidateEndpointCache(resourceData.getHost(), resourceData.getRpcPort(), - endpointType); - } - } - - protected ResourceData getResourceData(String resourceId) { - if (haService == null) { - LOGGER.warn("HAService is not initialized, cannot resolve resource: {}", resourceId); - return null; - } - ResourceData resourceData = haService.resolveResource(resourceId); - if (resourceData == null) { - LOGGER.warn("Resource data not found for resource: {}", resourceId); - } - return resourceData; - } - - public ExecutorService getExecutor() { - return executorService; - } + } + return endpointRef.process( + event, new DefaultRpcCallbackImpl(null, containerId, haService)); + }, + containerId, + CONTAINER); + } + + public void processContainer(String containerId, IEvent event, RpcCallback callback) { + doRpcWithRetry( + () -> { + ContainerEndpointRef endpointRef = connectContainer(containerId); + if (endpointRef == null) { + LOGGER.warn( + "Cannot process container event for {}: endpoint not available", containerId); + return; + } + endpointRef.process( + event, new DefaultRpcCallbackImpl<>(callback, containerId, haService)); + }, + containerId, + CONTAINER); + } + + // Pipeline endpoint ref. + public void processPipeline(String driverId, IEvent event) { + doRpcWithRetry( + () -> connectPipelineManager(driverId).process(event, new DefaultRpcCallbackImpl<>()), + driverId, + PIPELINE_MANAGER); + } + + public IPipelineResult executePipeline(String driverId, Pipeline pipeline) { + return doRpcWithRetry( + () -> connectDriver(driverId).executePipeline(pipeline), driverId, DRIVER); + } + + // Resource manager endpoint ref. + public RequireResponse requireResource(String masterId, RequireResourceRequest request) { + return doRpcWithRetry( + () -> connectRM(masterId).requireResource(request), masterId, RESOURCE_MANAGER); + } + + public ReleaseResponse releaseResource(String masterId, ReleaseResourceRequest request) { + return doRpcWithRetry( + () -> connectRM(masterId).releaseResource(request), masterId, RESOURCE_MANAGER); + } + + public Future requestMetrics( + String id, MetricQueryRequest request, RpcCallback callback) { + return doRpcWithRetry( + () -> + connectMetricServer(id) + .queryMetrics(request, new DefaultRpcCallbackImpl<>(callback, id, haService)), + id, + METRIC); + } + + public Future restartWorkerBySupervisor(String id, boolean fastFailure) { + int retries = fastFailure ? 1 : retryTimes; + try { + return doRpcWithRetry( + () -> { + ResourceData resourceData = loadSupervisorData(id, fastFailure); + return connectSupervisor(resourceData) + .restart(resourceData.getProcessId(), new DefaultRpcCallbackImpl<>()); + }, + id, + SUPERVISOR, + retries); + } catch (Throwable e) { + CompletableFuture result = new CompletableFuture<>(); + result.completeExceptionally(e); + return result; + } + } + + public StatusResponse queryWorkerStatusBySupervisor(String id) { + return connectSupervisor(id).status(); + } + + // Close endpoint connection. + public void closeMasterConnection(String masterId) { + MasterEndpointRef endpointRef = connectMaster(masterId); + if (endpointRef != null) { + endpointRef.closeEndpoint(); + } else { + LOGGER.debug("No endpoint reference found for master: {}, skipping close", masterId); + } + } + + public void closeDriverConnection(String driverId) { + DriverEndpointRef endpointRef = connectDriver(driverId); + if (endpointRef != null) { + endpointRef.closeEndpoint(); + } else { + LOGGER.debug("No endpoint reference found for driver: {}, skipping close", driverId); + } + } + + public void closeContainerConnection(String containerId) { + ContainerEndpointRef endpointRef = connectContainer(containerId); + if (endpointRef != null) { + endpointRef.closeEndpoint(); + } else { + LOGGER.debug("No endpoint reference found for container: {}, skipping close", containerId); + } + } + + private MasterEndpointRef connectMaster(String masterId) { + ResourceData resourceData = getResourceData(masterId); + if (resourceData == null) { + LOGGER.warn("Resource data not found for master: {}, skipping connection", masterId); + return null; + } + return refFactory.connectMaster(resourceData.getHost(), resourceData.getRpcPort()); + } + + private ResourceManagerEndpointRef connectRM(String masterId) { + ResourceData resourceData = getResourceData(masterId); + return refFactory.connectResourceManager(resourceData.getHost(), resourceData.getRpcPort()); + } + + private DriverEndpointRef connectDriver(String driverId) { + ResourceData resourceData = getResourceData(driverId); + if (resourceData == null) { + LOGGER.warn("Resource data not found for driver: {}, skipping connection", driverId); + return null; + } + return refFactory.connectDriver(resourceData.getHost(), resourceData.getRpcPort()); + } + + private ContainerEndpointRef connectContainer(String containerId) { + ResourceData resourceData = getResourceData(containerId); + if (resourceData == null) { + LOGGER.warn("Resource data not found for container: {}, skipping connection", containerId); + return null; + } + return refFactory.connectContainer(resourceData.getHost(), resourceData.getRpcPort()); + } + + private PipelineMasterEndpointRef connectPipelineManager(String id) { + ResourceData resourceData = getResourceData(id); + return refFactory.connectPipelineManager(resourceData.getHost(), resourceData.getRpcPort()); + } + + private MetricEndpointRef connectMetricServer(String id) { + ResourceData resourceData = getResourceData(id); + return refFactory.connectMetricServer(resourceData.getHost(), resourceData.getMetricPort()); + } + + private SupervisorEndpointRef connectSupervisor(String id) { + ResourceData resourceData = loadSupervisorData(id, true); + return connectSupervisor(resourceData); + } + + private SupervisorEndpointRef connectSupervisor(ResourceData resourceData) { + return refFactory.connectSupervisor(resourceData.getHost(), resourceData.getSupervisorPort()); + } + + private ResourceData loadSupervisorData(String id, boolean fastFailure) { + ResourceData resourceData; + if (fastFailure) { + resourceData = haService.loadResource(id); + } else { + resourceData = + ((AbstractHAService) haService) + .loadDataFromStore(id, true, ResourceData::getSupervisorPort); + } + return resourceData; + } + + private T doRpcWithRetry(Callable function, String resourceId, EndpointType endpointType) { + return doRpcWithRetry(function, resourceId, endpointType, retryTimes); + } + + private T doRpcWithRetry( + Callable function, String resourceId, EndpointType endpointType, int retryTimes) { + return RetryCommand.run( + () -> { + try { + return function.call(); + } catch (Throwable t) { + throw handleRpcException(resourceId, endpointType, t); + } + }, + retryTimes, + retryIntervalMs); + } + + private void doRpcWithRetry(Runnable function, String resourceId, EndpointType endpointType) { + RetryCommand.run( + () -> { + try { + function.run(); + } catch (Throwable t) { + throw handleRpcException(resourceId, endpointType, t); + } + return null; + }, + retryTimes, + retryIntervalMs); + } + + private Exception handleRpcException(String resourceId, EndpointType endpointType, Throwable t) { + try { + invalidateEndpointCache(resourceId, endpointType); + } catch (Throwable e) { + LOGGER.warn("invalidate rpc cache {} failed: {}", resourceId, e); + } + return new GeaflowRuntimeException(String.format("do rpc failed. %s", t.getMessage()), t); + } + + protected void invalidateEndpointCache(String resourceId, EndpointType endpointType) { + ResourceData resourceData = haService.invalidateResource(resourceId); + if (resourceData != null) { + refFactory.invalidateEndpointCache( + resourceData.getHost(), resourceData.getRpcPort(), endpointType); + } + } + + protected ResourceData getResourceData(String resourceId) { + if (haService == null) { + LOGGER.warn("HAService is not initialized, cannot resolve resource: {}", resourceId); + return null; + } + ResourceData resourceData = haService.resolveResource(resourceId); + if (resourceData == null) { + LOGGER.warn("Resource data not found for resource: {}", resourceId); + } + return resourceData; + } + + public ExecutorService getExecutor() { + return executorService; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcEndpoint.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcEndpoint.java index 39734c017..c2a2ab972 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcEndpoint.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcEndpoint.java @@ -21,6 +21,4 @@ import java.io.Serializable; -public interface RpcEndpoint extends Serializable { - -} +public interface RpcEndpoint extends Serializable {} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcEndpointRef.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcEndpointRef.java index d8bf942a7..3cb050867 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcEndpointRef.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcEndpointRef.java @@ -24,22 +24,15 @@ public interface RpcEndpointRef extends Closeable, Serializable { - /** - * Close rpc endpoint. - */ - void closeEndpoint(); + /** Close rpc endpoint. */ + void closeEndpoint(); - interface RpcCallback { + interface RpcCallback { - /** - * The callback for rpc process succeed. - */ - void onSuccess(T value); - - /** - * The callback for rpc process failed. - */ - void onFailure(Throwable t); - } + /** The callback for rpc process succeed. */ + void onSuccess(T value); + /** The callback for rpc process failed. */ + void onFailure(Throwable t); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcEndpointRefFactory.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcEndpointRefFactory.java index 6a61c736d..a2ec7fb2d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcEndpointRefFactory.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcEndpointRefFactory.java @@ -23,6 +23,7 @@ import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.cluster.rpc.impl.ContainerEndpointRef; import org.apache.geaflow.cluster.rpc.impl.DriverEndpointRef; import org.apache.geaflow.cluster.rpc.impl.MasterEndpointRef; @@ -33,206 +34,201 @@ import org.apache.geaflow.common.config.Configuration; public class RpcEndpointRefFactory implements Serializable { - private final Map endpointRefMap; - private final Configuration configuration; - - private static RpcEndpointRefFactory INSTANCE; - - private RpcEndpointRefFactory(Configuration config) { - this.endpointRefMap = new ConcurrentHashMap<>(); - this.configuration = config; - } - - public static synchronized RpcEndpointRefFactory getInstance(Configuration config) { - if (INSTANCE == null) { - INSTANCE = new RpcEndpointRefFactory(config); - } - return INSTANCE; - } - - public static synchronized RpcEndpointRefFactory getInstance() { - return INSTANCE; - } - - public MasterEndpointRef connectMaster(String host, int port) { - EndpointRefID refID = new EndpointRefID(host, port, EndpointType.MASTER); - try { - return (MasterEndpointRef) endpointRefMap - .computeIfAbsent(refID, - key -> new MasterEndpointRef(host, port, configuration)); - } catch (Throwable t) { - invalidateRef(refID); - throw new RuntimeException("connect master error, host " + host + " port " + port, t); - } - } - - public ResourceManagerEndpointRef connectResourceManager(String host, int port) { - EndpointRefID refID = new EndpointRefID(host, port, EndpointType.RESOURCE_MANAGER); - try { - return (ResourceManagerEndpointRef) endpointRefMap - .computeIfAbsent(refID, - key -> new ResourceManagerEndpointRef(host, port, configuration)); - } catch (Throwable t) { - invalidateRef(refID); - throw new RuntimeException("connect rm error, host " + host + " port " + port, t); - } - } - - public DriverEndpointRef connectDriver(String host, int port) { - EndpointRefID refID = new EndpointRefID(host, port, EndpointType.DRIVER); - try { - return (DriverEndpointRef) endpointRefMap - .computeIfAbsent(refID, - key -> new DriverEndpointRef(host, port, configuration)); - } catch (Throwable t) { - invalidateRef(refID); - throw new RuntimeException("connect driver error, host " + host + " port " + port, t); - } - } - - public PipelineMasterEndpointRef connectPipelineManager(String host, int port) { - EndpointRefID refID = new EndpointRefID(host, port, EndpointType.PIPELINE_MANAGER); - try { - return (PipelineMasterEndpointRef) endpointRefMap - .computeIfAbsent(refID, - key -> new PipelineMasterEndpointRef(host, port, configuration)); - } catch (Throwable t) { - invalidateRef(refID); - throw new RuntimeException("connect pipeline master error, host " + host + " port " + port, t); - } - } - - public ContainerEndpointRef connectContainer(String host, int port) { - EndpointRefID refID = new EndpointRefID(host, port, EndpointType.CONTAINER); - try { - return (ContainerEndpointRef) endpointRefMap - .computeIfAbsent(refID, - key -> new ContainerEndpointRef(host, port, configuration)); - } catch (Throwable t) { - invalidateRef(refID); - throw new RuntimeException("connect container error, host " + host + " port " + port, t); - } - } - - public SupervisorEndpointRef connectSupervisor(String host, int port) { - EndpointRefID refID = new EndpointRefID(host, port, EndpointType.SUPERVISOR); - try { - return (SupervisorEndpointRef) endpointRefMap - .computeIfAbsent(refID, key -> new SupervisorEndpointRef(host, port, configuration)); - } catch (Throwable t) { - invalidateRef(refID); - throw new RuntimeException("connect container error, host " + host + " port " + port, t); - } - } - - public MetricEndpointRef connectMetricServer(String host, int port) { - EndpointRefID refID = new EndpointRefID(host, port, EndpointType.METRIC); - try { - return (MetricEndpointRef) endpointRefMap - .computeIfAbsent(refID, key -> new MetricEndpointRef(host, port, configuration)); - } catch (Throwable t) { - invalidateRef(refID); - throw new RuntimeException("connect container error, host " + host + " port " + port, t); - } - } - - public void invalidateEndpointCache(String host, int port, EndpointType endpointType) { - invalidateRef(new EndpointRefID(host, port, endpointType)); - } - - public void invalidateRef(EndpointRefID refId) { - endpointRefMap.remove(refId); - } - - public enum EndpointType { - /** - * Master endpoint. - */ - MASTER, - /** - * ResourceManager endpoint. - */ - RESOURCE_MANAGER, - /** - * Driver endpoint. - */ - DRIVER, - /** - * Pipeline endpoint. - */ - PIPELINE_MANAGER, - /** - * Container endpoint. - */ - CONTAINER, - /** - * Worker endpoint. - */ - SUPERVISOR, - /** - * Metric query endpoint. - */ - METRIC - } - - public static class EndpointRefID { - - private String host; - private int port; - private EndpointType endpointType; - - public EndpointRefID(String host, int port, EndpointType endpointType) { - this.host = host; - this.port = port; - this.endpointType = endpointType; - } - - public String getHost() { - return host; - } - - public void setHost(String host) { - this.host = host; - } - - public int getPort() { - return port; - } - - public void setPort(int port) { - this.port = port; - } - - public EndpointType getEndpointType() { - return endpointType; - } - - public void setEndpointType(EndpointType endpointType) { - this.endpointType = endpointType; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - EndpointRefID that = (EndpointRefID) o; - return port == that.port && host.equals(that.host) && endpointType == that.endpointType; - } - - @Override - public int hashCode() { - return Objects.hash(host, port, endpointType); - } - - @Override - public String toString() { - return "EndpointRefID{" + "host='" + host + '\'' + ", port=" + port + ", endpointType=" - + endpointType + '}'; - } + private final Map endpointRefMap; + private final Configuration configuration; + + private static RpcEndpointRefFactory INSTANCE; + + private RpcEndpointRefFactory(Configuration config) { + this.endpointRefMap = new ConcurrentHashMap<>(); + this.configuration = config; + } + + public static synchronized RpcEndpointRefFactory getInstance(Configuration config) { + if (INSTANCE == null) { + INSTANCE = new RpcEndpointRefFactory(config); + } + return INSTANCE; + } + + public static synchronized RpcEndpointRefFactory getInstance() { + return INSTANCE; + } + + public MasterEndpointRef connectMaster(String host, int port) { + EndpointRefID refID = new EndpointRefID(host, port, EndpointType.MASTER); + try { + return (MasterEndpointRef) + endpointRefMap.computeIfAbsent( + refID, key -> new MasterEndpointRef(host, port, configuration)); + } catch (Throwable t) { + invalidateRef(refID); + throw new RuntimeException("connect master error, host " + host + " port " + port, t); + } + } + + public ResourceManagerEndpointRef connectResourceManager(String host, int port) { + EndpointRefID refID = new EndpointRefID(host, port, EndpointType.RESOURCE_MANAGER); + try { + return (ResourceManagerEndpointRef) + endpointRefMap.computeIfAbsent( + refID, key -> new ResourceManagerEndpointRef(host, port, configuration)); + } catch (Throwable t) { + invalidateRef(refID); + throw new RuntimeException("connect rm error, host " + host + " port " + port, t); + } + } + + public DriverEndpointRef connectDriver(String host, int port) { + EndpointRefID refID = new EndpointRefID(host, port, EndpointType.DRIVER); + try { + return (DriverEndpointRef) + endpointRefMap.computeIfAbsent( + refID, key -> new DriverEndpointRef(host, port, configuration)); + } catch (Throwable t) { + invalidateRef(refID); + throw new RuntimeException("connect driver error, host " + host + " port " + port, t); + } + } + + public PipelineMasterEndpointRef connectPipelineManager(String host, int port) { + EndpointRefID refID = new EndpointRefID(host, port, EndpointType.PIPELINE_MANAGER); + try { + return (PipelineMasterEndpointRef) + endpointRefMap.computeIfAbsent( + refID, key -> new PipelineMasterEndpointRef(host, port, configuration)); + } catch (Throwable t) { + invalidateRef(refID); + throw new RuntimeException( + "connect pipeline master error, host " + host + " port " + port, t); + } + } + + public ContainerEndpointRef connectContainer(String host, int port) { + EndpointRefID refID = new EndpointRefID(host, port, EndpointType.CONTAINER); + try { + return (ContainerEndpointRef) + endpointRefMap.computeIfAbsent( + refID, key -> new ContainerEndpointRef(host, port, configuration)); + } catch (Throwable t) { + invalidateRef(refID); + throw new RuntimeException("connect container error, host " + host + " port " + port, t); + } + } + + public SupervisorEndpointRef connectSupervisor(String host, int port) { + EndpointRefID refID = new EndpointRefID(host, port, EndpointType.SUPERVISOR); + try { + return (SupervisorEndpointRef) + endpointRefMap.computeIfAbsent( + refID, key -> new SupervisorEndpointRef(host, port, configuration)); + } catch (Throwable t) { + invalidateRef(refID); + throw new RuntimeException("connect container error, host " + host + " port " + port, t); + } + } + + public MetricEndpointRef connectMetricServer(String host, int port) { + EndpointRefID refID = new EndpointRefID(host, port, EndpointType.METRIC); + try { + return (MetricEndpointRef) + endpointRefMap.computeIfAbsent( + refID, key -> new MetricEndpointRef(host, port, configuration)); + } catch (Throwable t) { + invalidateRef(refID); + throw new RuntimeException("connect container error, host " + host + " port " + port, t); + } + } + + public void invalidateEndpointCache(String host, int port, EndpointType endpointType) { + invalidateRef(new EndpointRefID(host, port, endpointType)); + } + + public void invalidateRef(EndpointRefID refId) { + endpointRefMap.remove(refId); + } + + public enum EndpointType { + /** Master endpoint. */ + MASTER, + /** ResourceManager endpoint. */ + RESOURCE_MANAGER, + /** Driver endpoint. */ + DRIVER, + /** Pipeline endpoint. */ + PIPELINE_MANAGER, + /** Container endpoint. */ + CONTAINER, + /** Worker endpoint. */ + SUPERVISOR, + /** Metric query endpoint. */ + METRIC + } + + public static class EndpointRefID { + + private String host; + private int port; + private EndpointType endpointType; + + public EndpointRefID(String host, int port, EndpointType endpointType) { + this.host = host; + this.port = port; + this.endpointType = endpointType; } + public String getHost() { + return host; + } + + public void setHost(String host) { + this.host = host; + } + + public int getPort() { + return port; + } + + public void setPort(int port) { + this.port = port; + } + + public EndpointType getEndpointType() { + return endpointType; + } + + public void setEndpointType(EndpointType endpointType) { + this.endpointType = endpointType; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + EndpointRefID that = (EndpointRefID) o; + return port == that.port && host.equals(that.host) && endpointType == that.endpointType; + } + + @Override + public int hashCode() { + return Objects.hash(host, port, endpointType); + } + + @Override + public String toString() { + return "EndpointRefID{" + + "host='" + + host + + '\'' + + ", port=" + + port + + ", endpointType=" + + endpointType + + '}'; + } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcResponseFuture.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcResponseFuture.java index ee6e5ba8d..32d1d60bf 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcResponseFuture.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcResponseFuture.java @@ -19,56 +19,58 @@ package org.apache.geaflow.cluster.rpc; -import com.google.protobuf.ByteString; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; + import org.apache.geaflow.cluster.protocol.IEvent; import org.apache.geaflow.common.encoder.RpcMessageEncoder; import org.apache.geaflow.rpc.proto.Container; import org.jetbrains.annotations.NotNull; +import com.google.protobuf.ByteString; + public class RpcResponseFuture implements Future { - private final Future delegate; + private final Future delegate; - public RpcResponseFuture(Future delegate) { - this.delegate = delegate; - } + public RpcResponseFuture(Future delegate) { + this.delegate = delegate; + } - @Override - public boolean cancel(boolean mayInterruptIfRunning) { - return delegate.cancel(mayInterruptIfRunning); - } + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return delegate.cancel(mayInterruptIfRunning); + } - @Override - public boolean isCancelled() { - return delegate.isCancelled(); - } + @Override + public boolean isCancelled() { + return delegate.isCancelled(); + } - @Override - public boolean isDone() { - return delegate.isDone(); - } + @Override + public boolean isDone() { + return delegate.isDone(); + } - @Override - public IEvent get() throws InterruptedException, ExecutionException { - return delegate.get(); - } + @Override + public IEvent get() throws InterruptedException, ExecutionException { + return delegate.get(); + } - @Override - public IEvent get(long timeout, @NotNull TimeUnit unit) - throws InterruptedException, ExecutionException, TimeoutException { - return delegate.get(timeout, unit); - } + @Override + public IEvent get(long timeout, @NotNull TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + return delegate.get(timeout, unit); + } - private IEvent getEvent(Container.Response response) { - ByteString payload = response.getPayload(); - if (payload == ByteString.EMPTY) { - return null; - } else { - return RpcMessageEncoder.decode(payload); - } + private IEvent getEvent(Container.Response response) { + ByteString payload = response.getPayload(); + if (payload == ByteString.EMPTY) { + return null; + } else { + return RpcMessageEncoder.decode(payload); } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcService.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcService.java index 517a6dd51..98e5d5e19 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcService.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcService.java @@ -21,19 +21,12 @@ public interface RpcService { - /** - * start rpc service. - */ - int startService(); + /** start rpc service. */ + int startService(); - /** - * terminate the rpc service. - */ - void stopService(); - - /** - * Waits for the service to become terminated. - */ - void waitTermination(); + /** terminate the rpc service. */ + void stopService(); + /** Waits for the service to become terminated. */ + void waitTermination(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcUtil.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcUtil.java index dfc682d76..22d3be7dd 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcUtil.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/RpcUtil.java @@ -19,40 +19,38 @@ package org.apache.geaflow.cluster.rpc; -import com.baidu.brpc.client.RpcCallback; -import io.grpc.Context; import java.io.Serializable; import java.util.concurrent.CompletableFuture; -public class RpcUtil implements Serializable { +import com.baidu.brpc.client.RpcCallback; - public static void asyncExecute(Runnable runnable) { - Context.current().fork().run(runnable); - } - - /** - * Build brpc callback. - */ - public static RpcCallback buildRpcCallback(RpcEndpointRef.RpcCallback listener, - CompletableFuture result) { - return new RpcCallback() { - @Override - public void success(T response) { - if (listener != null) { - listener.onSuccess(response); - } - result.complete(response); - } - - @Override - public void fail(Throwable t) { - if (listener != null) { - listener.onFailure(t); - } - result.completeExceptionally(t); - } - }; - } +import io.grpc.Context; +public class RpcUtil implements Serializable { + public static void asyncExecute(Runnable runnable) { + Context.current().fork().run(runnable); + } + + /** Build brpc callback. */ + public static RpcCallback buildRpcCallback( + RpcEndpointRef.RpcCallback listener, CompletableFuture result) { + return new RpcCallback() { + @Override + public void success(T response) { + if (listener != null) { + listener.onSuccess(response); + } + result.complete(response); + } + + @Override + public void fail(Throwable t) { + if (listener != null) { + listener.onFailure(t); + } + result.completeExceptionally(t); + } + }; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/AbstractRpcEndpointRef.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/AbstractRpcEndpointRef.java index 103bddb27..9f7707462 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/AbstractRpcEndpointRef.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/AbstractRpcEndpointRef.java @@ -19,44 +19,44 @@ package org.apache.geaflow.cluster.rpc.impl; -import com.baidu.brpc.client.RpcClient; -import com.baidu.brpc.client.RpcClientOptions; -import com.baidu.brpc.client.channel.Endpoint; import org.apache.geaflow.cluster.rpc.RpcEndpointRef; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.rpc.ConfigurableClientOption; -public abstract class AbstractRpcEndpointRef implements RpcEndpointRef { - - protected RpcClient rpcClient; - protected final String host; - protected final int port; - protected final Configuration configuration; - - public AbstractRpcEndpointRef(String host, int port, Configuration configuration) { - this.host = host; - this.port = port; - this.configuration = configuration; - this.rpcClient = new RpcClient(new Endpoint(host, port), getClientOptions()); - getRpcEndpoint(); - } - - protected abstract void getRpcEndpoint(); - - protected synchronized RpcClientOptions getClientOptions() { - return ConfigurableClientOption.build(configuration); - } +import com.baidu.brpc.client.RpcClient; +import com.baidu.brpc.client.RpcClientOptions; +import com.baidu.brpc.client.channel.Endpoint; - @Override - public void closeEndpoint() { - close(); - } +public abstract class AbstractRpcEndpointRef implements RpcEndpointRef { - @Override - public void close() { - if (!rpcClient.isShutdown()) { - this.rpcClient.stop(); - } + protected RpcClient rpcClient; + protected final String host; + protected final int port; + protected final Configuration configuration; + + public AbstractRpcEndpointRef(String host, int port, Configuration configuration) { + this.host = host; + this.port = port; + this.configuration = configuration; + this.rpcClient = new RpcClient(new Endpoint(host, port), getClientOptions()); + getRpcEndpoint(); + } + + protected abstract void getRpcEndpoint(); + + protected synchronized RpcClientOptions getClientOptions() { + return ConfigurableClientOption.build(configuration); + } + + @Override + public void closeEndpoint() { + close(); + } + + @Override + public void close() { + if (!rpcClient.isShutdown()) { + this.rpcClient.stop(); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/ContainerEndpoint.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/ContainerEndpoint.java index 5f1457f8e..2de145399 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/ContainerEndpoint.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/ContainerEndpoint.java @@ -19,8 +19,6 @@ package org.apache.geaflow.cluster.rpc.impl; -import com.google.protobuf.ByteString; -import com.google.protobuf.Empty; import org.apache.geaflow.cluster.container.Container; import org.apache.geaflow.cluster.protocol.IEvent; import org.apache.geaflow.cluster.protocol.OpenContainerEvent; @@ -32,46 +30,49 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.protobuf.ByteString; +import com.google.protobuf.Empty; + public class ContainerEndpoint implements IContainerEndpoint { - private static final Logger LOGGER = LoggerFactory.getLogger(ContainerEndpoint.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ContainerEndpoint.class); - private final Container container; + private final Container container; - public ContainerEndpoint(Container workerContainer) { - this.container = workerContainer; - } + public ContainerEndpoint(Container workerContainer) { + this.container = workerContainer; + } - @Override - public Response process(Request request) { - Response.Builder builder = Response.newBuilder(); - try { - IEvent res; - IEvent event = RpcMessageEncoder.decode(request.getPayload()); - if (event instanceof OpenContainerEvent) { - res = container.open((OpenContainerEvent) event); - } else { - res = container.process(event); - } - if (res != null) { - ByteString payload = RpcMessageEncoder.encode(res); - builder.setPayload(payload); - } - return builder.build(); - } catch (Throwable t) { - LOGGER.error("process request failed: {}", t.getMessage(), t); - throw new GeaflowRuntimeException("process request failed", t); - } + @Override + public Response process(Request request) { + Response.Builder builder = Response.newBuilder(); + try { + IEvent res; + IEvent event = RpcMessageEncoder.decode(request.getPayload()); + if (event instanceof OpenContainerEvent) { + res = container.open((OpenContainerEvent) event); + } else { + res = container.process(event); + } + if (res != null) { + ByteString payload = RpcMessageEncoder.encode(res); + builder.setPayload(payload); + } + return builder.build(); + } catch (Throwable t) { + LOGGER.error("process request failed: {}", t.getMessage(), t); + throw new GeaflowRuntimeException("process request failed", t); } + } - @Override - public Empty close(Empty request) { - try { - container.close(); - return Empty.newBuilder().build(); - } catch (Throwable t) { - LOGGER.error("close failed: {}", t.getMessage(), t); - throw new GeaflowRuntimeException(String.format("close failed: %s", t.getMessage()), t); - } + @Override + public Empty close(Empty request) { + try { + container.close(); + return Empty.newBuilder().build(); + } catch (Throwable t) { + LOGGER.error("close failed: {}", t.getMessage(), t); + throw new GeaflowRuntimeException(String.format("close failed: %s", t.getMessage()), t); } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/ContainerEndpointRef.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/ContainerEndpointRef.java index 5e0aec3a8..a95f78248 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/ContainerEndpointRef.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/ContainerEndpointRef.java @@ -19,11 +19,9 @@ package org.apache.geaflow.cluster.rpc.impl; -import com.baidu.brpc.client.BrpcProxy; -import com.google.protobuf.ByteString; -import com.google.protobuf.Empty; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; + import org.apache.geaflow.cluster.protocol.IEvent; import org.apache.geaflow.cluster.rpc.IAsyncContainerEndpoint; import org.apache.geaflow.cluster.rpc.IContainerEndpointRef; @@ -33,57 +31,62 @@ import org.apache.geaflow.rpc.proto.Container.Request; import org.apache.geaflow.rpc.proto.Container.Response; +import com.baidu.brpc.client.BrpcProxy; +import com.google.protobuf.ByteString; +import com.google.protobuf.Empty; + public class ContainerEndpointRef extends AbstractRpcEndpointRef implements IContainerEndpointRef { - protected IAsyncContainerEndpoint containerEndpoint; + protected IAsyncContainerEndpoint containerEndpoint; - public ContainerEndpointRef(String host, int port, Configuration configuration) { - super(host, port, configuration); - } + public ContainerEndpointRef(String host, int port, Configuration configuration) { + super(host, port, configuration); + } - @Override - protected void getRpcEndpoint() { - this.containerEndpoint = BrpcProxy.getProxy(rpcClient, IAsyncContainerEndpoint.class); - } + @Override + protected void getRpcEndpoint() { + this.containerEndpoint = BrpcProxy.getProxy(rpcClient, IAsyncContainerEndpoint.class); + } - @Override - public Future process(IEvent request, RpcCallback callback) { - CompletableFuture result = new CompletableFuture<>(); - Container.Request req = buildRequest(request); - this.containerEndpoint.process(req, new com.baidu.brpc.client.RpcCallback() { - @Override - public void success(Response response) { - if (callback != null) { - callback.onSuccess(response); - } - ByteString payload = response.getPayload(); - IEvent event; - if (payload == ByteString.EMPTY) { - event = null; - } else { - event = RpcMessageEncoder.decode(payload); - } - result.complete(event); + @Override + public Future process(IEvent request, RpcCallback callback) { + CompletableFuture result = new CompletableFuture<>(); + Container.Request req = buildRequest(request); + this.containerEndpoint.process( + req, + new com.baidu.brpc.client.RpcCallback() { + @Override + public void success(Response response) { + if (callback != null) { + callback.onSuccess(response); } - - @Override - public void fail(Throwable throwable) { - callback.onFailure(throwable); - result.completeExceptionally(throwable); + ByteString payload = response.getPayload(); + IEvent event; + if (payload == ByteString.EMPTY) { + event = null; + } else { + event = RpcMessageEncoder.decode(payload); } - }); - return result; - } + result.complete(event); + } - @Override - public void closeEndpoint() { - this.containerEndpoint.close(Empty.newBuilder().build()); - super.closeEndpoint(); - } + @Override + public void fail(Throwable throwable) { + callback.onFailure(throwable); + result.completeExceptionally(throwable); + } + }); + return result; + } - protected Request buildRequest(IEvent request) { - ByteString payload = RpcMessageEncoder.encode(request); - return Request.newBuilder().setPayload(payload).build(); - } + @Override + public void closeEndpoint() { + this.containerEndpoint.close(Empty.newBuilder().build()); + super.closeEndpoint(); + } + protected Request buildRequest(IEvent request) { + ByteString payload = RpcMessageEncoder.encode(request); + return Request.newBuilder().setPayload(payload).build(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/DefaultRpcCallbackImpl.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/DefaultRpcCallbackImpl.java index c6f681b65..f1e137b49 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/DefaultRpcCallbackImpl.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/DefaultRpcCallbackImpl.java @@ -24,38 +24,36 @@ public class DefaultRpcCallbackImpl implements RpcCallback { - private RpcCallback rpcCallback; + private RpcCallback rpcCallback; - private String resourceId; + private String resourceId; - private IHAService haService; + private IHAService haService; - public DefaultRpcCallbackImpl(RpcCallback rpcCallback, String resourceId, IHAService haService) { - this.rpcCallback = rpcCallback; - this.resourceId = resourceId; - this.haService = haService; - } + public DefaultRpcCallbackImpl( + RpcCallback rpcCallback, String resourceId, IHAService haService) { + this.rpcCallback = rpcCallback; + this.resourceId = resourceId; + this.haService = haService; + } - public DefaultRpcCallbackImpl() { + public DefaultRpcCallbackImpl() {} + @Override + public void onSuccess(T value) { + if (rpcCallback != null) { + rpcCallback.onSuccess(value); } + } - @Override - public void onSuccess(T value) { - if (rpcCallback != null) { - rpcCallback.onSuccess(value); - } + @Override + public void onFailure(Throwable t) { + if (rpcCallback != null) { + rpcCallback.onFailure(t); } - @Override - public void onFailure(Throwable t) { - if (rpcCallback != null) { - rpcCallback.onFailure(t); - } - - if (resourceId != null) { - haService.invalidateResource(resourceId); - } + if (resourceId != null) { + haService.invalidateResource(resourceId); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/DriverEndpoint.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/DriverEndpoint.java index c6f18de10..7cabaf8b6 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/DriverEndpoint.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/DriverEndpoint.java @@ -19,7 +19,6 @@ package org.apache.geaflow.cluster.rpc.impl; -import com.google.protobuf.Empty; import org.apache.geaflow.cluster.driver.IDriver; import org.apache.geaflow.cluster.rpc.IDriverEndpoint; import org.apache.geaflow.common.encoder.RpcMessageEncoder; @@ -30,39 +29,39 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.protobuf.Empty; + public class DriverEndpoint implements IDriverEndpoint { - private static final Logger LOGGER = LoggerFactory.getLogger(DriverEndpoint.class); + private static final Logger LOGGER = LoggerFactory.getLogger(DriverEndpoint.class); - private final IDriver driver; + private final IDriver driver; - public DriverEndpoint(IDriver driver) { - this.driver = driver; - } + public DriverEndpoint(IDriver driver) { + this.driver = driver; + } - @Override - public PipelineRes executePipeline(PipelineReq request) { - try { - Pipeline pipeline = RpcMessageEncoder.decode(request.getPayload()); - Object result = driver.executePipeline(pipeline); - return PipelineRes.newBuilder() - .setPayload(RpcMessageEncoder.encode(result)) - .build(); - } catch (Throwable e) { - LOGGER.error("execute pipeline failed: {}", e.getMessage(), e); - throw new GeaflowRuntimeException(String.format("execute pipeline failed: %s", - e.getMessage()), e); - } + @Override + public PipelineRes executePipeline(PipelineReq request) { + try { + Pipeline pipeline = RpcMessageEncoder.decode(request.getPayload()); + Object result = driver.executePipeline(pipeline); + return PipelineRes.newBuilder().setPayload(RpcMessageEncoder.encode(result)).build(); + } catch (Throwable e) { + LOGGER.error("execute pipeline failed: {}", e.getMessage(), e); + throw new GeaflowRuntimeException( + String.format("execute pipeline failed: %s", e.getMessage()), e); } + } - @Override - public Empty close(Empty request) { - try { - driver.close(); - return Empty.newBuilder().build(); - } catch (Throwable t) { - LOGGER.error("close failed: {}", t.getMessage(), t); - throw new GeaflowRuntimeException(String.format("close failed: %s", t.getMessage(), t)); - } + @Override + public Empty close(Empty request) { + try { + driver.close(); + return Empty.newBuilder().build(); + } catch (Throwable t) { + LOGGER.error("close failed: {}", t.getMessage(), t); + throw new GeaflowRuntimeException(String.format("close failed: %s", t.getMessage(), t)); } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/DriverEndpointRef.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/DriverEndpointRef.java index 95213fb36..585b61352 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/DriverEndpointRef.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/DriverEndpointRef.java @@ -19,10 +19,8 @@ package org.apache.geaflow.cluster.rpc.impl; -import com.baidu.brpc.client.BrpcProxy; -import com.google.protobuf.ByteString; -import com.google.protobuf.Empty; import java.util.concurrent.CompletableFuture; + import org.apache.geaflow.cluster.client.PipelineResult; import org.apache.geaflow.cluster.rpc.IAsyncDriverEndpoint; import org.apache.geaflow.cluster.rpc.IDriverEndpointRef; @@ -36,35 +34,41 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.baidu.brpc.client.BrpcProxy; +import com.google.protobuf.ByteString; +import com.google.protobuf.Empty; + public class DriverEndpointRef extends AbstractRpcEndpointRef implements IDriverEndpointRef { - private static final Logger LOGGER = LoggerFactory.getLogger(DriverEndpointRef.class); + private static final Logger LOGGER = LoggerFactory.getLogger(DriverEndpointRef.class); - private IAsyncDriverEndpoint driverEndpoint; + private IAsyncDriverEndpoint driverEndpoint; - public DriverEndpointRef(String host, int port, Configuration configuration) { - super(host, port, configuration); - } + public DriverEndpointRef(String host, int port, Configuration configuration) { + super(host, port, configuration); + } - @Override - protected void getRpcEndpoint() { - this.driverEndpoint = BrpcProxy.getProxy(rpcClient, IAsyncDriverEndpoint.class); - } + @Override + protected void getRpcEndpoint() { + this.driverEndpoint = BrpcProxy.getProxy(rpcClient, IAsyncDriverEndpoint.class); + } - @Override - public IPipelineResult executePipeline(Pipeline pipeline) { - LOGGER.info("send pipeline to driver, driver host:{}, port:{}. {}", super.host, super.port, pipeline); - ByteString payload = RpcMessageEncoder.encode(pipeline); - PipelineReq req = PipelineReq.newBuilder().setPayload(payload).build(); - CompletableFuture result = new CompletableFuture<>(); - com.baidu.brpc.client.RpcCallback rpcCallback = RpcUtil.buildRpcCallback(null, result); - this.driverEndpoint.executePipeline(req, rpcCallback); - return new PipelineResult(result); - } + @Override + public IPipelineResult executePipeline(Pipeline pipeline) { + LOGGER.info( + "send pipeline to driver, driver host:{}, port:{}. {}", super.host, super.port, pipeline); + ByteString payload = RpcMessageEncoder.encode(pipeline); + PipelineReq req = PipelineReq.newBuilder().setPayload(payload).build(); + CompletableFuture result = new CompletableFuture<>(); + com.baidu.brpc.client.RpcCallback rpcCallback = + RpcUtil.buildRpcCallback(null, result); + this.driverEndpoint.executePipeline(req, rpcCallback); + return new PipelineResult(result); + } - @Override - public void closeEndpoint() { - this.driverEndpoint.close(Empty.newBuilder().build()); - super.closeEndpoint(); - } + @Override + public void closeEndpoint() { + this.driverEndpoint.close(Empty.newBuilder().build()); + super.closeEndpoint(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/MasterEndpoint.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/MasterEndpoint.java index 2d226f21a..29cd32d2c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/MasterEndpoint.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/MasterEndpoint.java @@ -19,7 +19,6 @@ package org.apache.geaflow.cluster.rpc.impl; -import com.google.protobuf.Empty; import org.apache.geaflow.cluster.clustermanager.AbstractClusterManager; import org.apache.geaflow.cluster.clustermanager.IClusterManager; import org.apache.geaflow.cluster.container.ContainerInfo; @@ -37,83 +36,85 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.protobuf.Empty; + public class MasterEndpoint implements IMasterEndpoint { - private static final Logger LOGGER = LoggerFactory.getLogger(MasterEndpoint.class); + private static final Logger LOGGER = LoggerFactory.getLogger(MasterEndpoint.class); - private final IMaster master; - private final IClusterManager clusterManager; + private final IMaster master; + private final IClusterManager clusterManager; - public MasterEndpoint(IMaster master, IClusterManager clusterManager) { - this.master = master; - this.clusterManager = clusterManager; - } + public MasterEndpoint(IMaster master, IClusterManager clusterManager) { + this.master = master; + this.clusterManager = clusterManager; + } - @Override - public RegisterResponse registerContainer(RegisterRequest request) { - try { - ContainerInfo containerInfo = RpcMessageEncoder.decode(request.getPayload()); - return ((AbstractClusterManager) clusterManager).registerContainer(containerInfo); - } catch (Throwable t) { - LOGGER.error("register container failed: {}", t.getMessage(), t); - throw new GeaflowRuntimeException(String.format("register container failed: %s", - t.getMessage()), t); - } + @Override + public RegisterResponse registerContainer(RegisterRequest request) { + try { + ContainerInfo containerInfo = RpcMessageEncoder.decode(request.getPayload()); + return ((AbstractClusterManager) clusterManager).registerContainer(containerInfo); + } catch (Throwable t) { + LOGGER.error("register container failed: {}", t.getMessage(), t); + throw new GeaflowRuntimeException( + String.format("register container failed: %s", t.getMessage()), t); } + } - @Override - public RegisterResponse registerDriver(RegisterRequest request) { - try { - DriverInfo driverInfo = RpcMessageEncoder.decode(request.getPayload()); - return ((AbstractClusterManager) clusterManager).registerDriver(driverInfo); - } catch (Throwable t) { - LOGGER.error("register driver failed: {}", t.getMessage(), t); - throw new GeaflowRuntimeException(String.format("register driver failed: %s", - t.getMessage()), t); - } + @Override + public RegisterResponse registerDriver(RegisterRequest request) { + try { + DriverInfo driverInfo = RpcMessageEncoder.decode(request.getPayload()); + return ((AbstractClusterManager) clusterManager).registerDriver(driverInfo); + } catch (Throwable t) { + LOGGER.error("register driver failed: {}", t.getMessage(), t); + throw new GeaflowRuntimeException( + String.format("register driver failed: %s", t.getMessage()), t); } + } - @Override - public HeartbeatResponse receiveHeartbeat(HeartbeatRequest request) { - try { - Heartbeat heartbeat = new Heartbeat(request.getId()); - heartbeat.setTimestamp(request.getTimestamp()); - heartbeat.setContainerName(RpcMessageEncoder.decode(request.getName())); - heartbeat.setProcessMetrics(RpcMessageEncoder.decode(request.getPayload())); - HeartbeatManager heartbeatManager = - ((AbstractClusterManager) clusterManager).getClusterContext().getHeartbeatManager(); - return heartbeatManager.receivedHeartbeat(heartbeat); - } catch (Throwable t) { - LOGGER.error("process {} heartbeat failed: {}", request.getId(), t.getMessage(), t); - throw new GeaflowRuntimeException(String.format("process %s heartbeat failed: %s", - request.getId(), t.getMessage()), t); - } + @Override + public HeartbeatResponse receiveHeartbeat(HeartbeatRequest request) { + try { + Heartbeat heartbeat = new Heartbeat(request.getId()); + heartbeat.setTimestamp(request.getTimestamp()); + heartbeat.setContainerName(RpcMessageEncoder.decode(request.getName())); + heartbeat.setProcessMetrics(RpcMessageEncoder.decode(request.getPayload())); + HeartbeatManager heartbeatManager = + ((AbstractClusterManager) clusterManager).getClusterContext().getHeartbeatManager(); + return heartbeatManager.receivedHeartbeat(heartbeat); + } catch (Throwable t) { + LOGGER.error("process {} heartbeat failed: {}", request.getId(), t.getMessage(), t); + throw new GeaflowRuntimeException( + String.format("process %s heartbeat failed: %s", request.getId(), t.getMessage()), t); } + } - @Override - public Empty receiveException(HeartbeatRequest request) { - try { - int containerId = request.getId(); - String containerName = RpcMessageEncoder.decode(request.getName()); - String errMessage = RpcMessageEncoder.decode(request.getPayload()); - LOGGER.info("received exception from {}: {}", containerName, errMessage); - clusterManager.doFailover(containerId, new RuntimeException(errMessage)); - return Empty.newBuilder().build(); - } catch (Throwable t) { - LOGGER.error("process {} heartbeat failed: {}", request.getId(), t.getMessage(), t); - throw new GeaflowRuntimeException(String.format("process %s heartbeat failed: %s", - request.getId(), t.getMessage()), t); - } + @Override + public Empty receiveException(HeartbeatRequest request) { + try { + int containerId = request.getId(); + String containerName = RpcMessageEncoder.decode(request.getName()); + String errMessage = RpcMessageEncoder.decode(request.getPayload()); + LOGGER.info("received exception from {}: {}", containerName, errMessage); + clusterManager.doFailover(containerId, new RuntimeException(errMessage)); + return Empty.newBuilder().build(); + } catch (Throwable t) { + LOGGER.error("process {} heartbeat failed: {}", request.getId(), t.getMessage(), t); + throw new GeaflowRuntimeException( + String.format("process %s heartbeat failed: %s", request.getId(), t.getMessage()), t); } + } - @Override - public Empty close(Empty request) { - try { - master.close(); - return Empty.newBuilder().build(); - } catch (Throwable t) { - LOGGER.error("close failed: {}", t.getMessage(), t); - throw new GeaflowRuntimeException(String.format("close failed: %s", t.getMessage()), t); - } + @Override + public Empty close(Empty request) { + try { + master.close(); + return Empty.newBuilder().build(); + } catch (Throwable t) { + LOGGER.error("close failed: {}", t.getMessage(), t); + throw new GeaflowRuntimeException(String.format("close failed: %s", t.getMessage()), t); } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/MasterEndpointRef.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/MasterEndpointRef.java index e61025904..88a820cc0 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/MasterEndpointRef.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/MasterEndpointRef.java @@ -19,11 +19,9 @@ package org.apache.geaflow.cluster.rpc.impl; -import com.baidu.brpc.client.BrpcProxy; -import com.google.protobuf.ByteString; -import com.google.protobuf.Empty; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; + import org.apache.geaflow.cluster.driver.DriverInfo; import org.apache.geaflow.cluster.rpc.IAsyncMasterEndpoint; import org.apache.geaflow.cluster.rpc.IMasterEndpointRef; @@ -37,59 +35,69 @@ import org.apache.geaflow.rpc.proto.Master.RegisterRequest; import org.apache.geaflow.rpc.proto.Master.RegisterResponse; +import com.baidu.brpc.client.BrpcProxy; +import com.google.protobuf.ByteString; +import com.google.protobuf.Empty; + public class MasterEndpointRef extends AbstractRpcEndpointRef implements IMasterEndpointRef { - private IAsyncMasterEndpoint masterEndpoint; + private IAsyncMasterEndpoint masterEndpoint; - public MasterEndpointRef(String host, int port, Configuration configuration) { - super(host, port, configuration); - } + public MasterEndpointRef(String host, int port, Configuration configuration) { + super(host, port, configuration); + } - @Override - protected void getRpcEndpoint() { - this.masterEndpoint = BrpcProxy.getProxy(rpcClient, IAsyncMasterEndpoint.class); - } + @Override + protected void getRpcEndpoint() { + this.masterEndpoint = BrpcProxy.getProxy(rpcClient, IAsyncMasterEndpoint.class); + } - public Future registerContainer(T info, RpcEndpointRef.RpcCallback callback) { - CompletableFuture result = new CompletableFuture<>(); - ByteString payload = RpcMessageEncoder.encode(info); - RegisterRequest register = RegisterRequest.newBuilder().setPayload(payload).build(); - com.baidu.brpc.client.RpcCallback rpcCallback = RpcUtil.buildRpcCallback(callback, result); - if (info instanceof DriverInfo) { - this.masterEndpoint.registerDriver(register, rpcCallback); - } else { - this.masterEndpoint.registerContainer(register, rpcCallback); - } - return result; + public Future registerContainer( + T info, RpcEndpointRef.RpcCallback callback) { + CompletableFuture result = new CompletableFuture<>(); + ByteString payload = RpcMessageEncoder.encode(info); + RegisterRequest register = RegisterRequest.newBuilder().setPayload(payload).build(); + com.baidu.brpc.client.RpcCallback rpcCallback = + RpcUtil.buildRpcCallback(callback, result); + if (info instanceof DriverInfo) { + this.masterEndpoint.registerDriver(register, rpcCallback); + } else { + this.masterEndpoint.registerContainer(register, rpcCallback); } + return result; + } - @Override - public Future sendHeartBeat(Heartbeat heartbeat, RpcCallback callback) { - CompletableFuture result = new CompletableFuture<>(); - HeartbeatRequest heartbeatRequest = HeartbeatRequest.newBuilder() + @Override + public Future sendHeartBeat( + Heartbeat heartbeat, RpcCallback callback) { + CompletableFuture result = new CompletableFuture<>(); + HeartbeatRequest heartbeatRequest = + HeartbeatRequest.newBuilder() .setId(heartbeat.getContainerId()) .setTimestamp(heartbeat.getTimestamp()) .setName(RpcMessageEncoder.encode(heartbeat.getContainerName())) .setPayload(RpcMessageEncoder.encode(heartbeat.getProcessMetrics())) .build(); - com.baidu.brpc.client.RpcCallback rpcCallback = RpcUtil.buildRpcCallback(callback, result); - this.masterEndpoint.receiveHeartbeat(heartbeatRequest, rpcCallback); - return result; - } + com.baidu.brpc.client.RpcCallback rpcCallback = + RpcUtil.buildRpcCallback(callback, result); + this.masterEndpoint.receiveHeartbeat(heartbeatRequest, rpcCallback); + return result; + } - @Override - public Empty sendException(Integer containerId, String containerName, String message) { - HeartbeatRequest heartbeatRequest = HeartbeatRequest.newBuilder() + @Override + public Empty sendException(Integer containerId, String containerName, String message) { + HeartbeatRequest heartbeatRequest = + HeartbeatRequest.newBuilder() .setId(containerId) .setName(RpcMessageEncoder.encode(containerName)) .setPayload(RpcMessageEncoder.encode(message)) .build(); - return masterEndpoint.receiveException(heartbeatRequest); - } + return masterEndpoint.receiveException(heartbeatRequest); + } - @Override - public void closeEndpoint() { - this.masterEndpoint.close(Empty.newBuilder().build()); - super.closeEndpoint(); - } + @Override + public void closeEndpoint() { + this.masterEndpoint.close(Empty.newBuilder().build()); + super.closeEndpoint(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/MetricEndpoint.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/MetricEndpoint.java index dd66c13ad..80acb79c2 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/MetricEndpoint.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/MetricEndpoint.java @@ -32,24 +32,24 @@ public class MetricEndpoint implements IMetricEndpoint { - private static final Logger LOGGER = LoggerFactory.getLogger(MetricEndpoint.class); - private final Configuration configuration; + private static final Logger LOGGER = LoggerFactory.getLogger(MetricEndpoint.class); + private final Configuration configuration; - public MetricEndpoint(Configuration configuration) { - this.configuration = configuration; - } + public MetricEndpoint(Configuration configuration) { + this.configuration = configuration; + } - @Override - public MetricQueryResponse queryMetrics(MetricQueryRequest request) { - try { - MetricCache cache = StatsCollectorFactory.init(configuration).getMetricCache(); - MetricQueryResponse.Builder builder = MetricQueryResponse.newBuilder(); - builder.setPayload(RpcMessageEncoder.encode(cache)); - return builder.build(); - } catch (Throwable t) { - LOGGER.error("process request failed: {}", t.getMessage(), t); - throw new GeaflowRuntimeException(String.format("process request failed: %s", t.getMessage()), - t); - } + @Override + public MetricQueryResponse queryMetrics(MetricQueryRequest request) { + try { + MetricCache cache = StatsCollectorFactory.init(configuration).getMetricCache(); + MetricQueryResponse.Builder builder = MetricQueryResponse.newBuilder(); + builder.setPayload(RpcMessageEncoder.encode(cache)); + return builder.build(); + } catch (Throwable t) { + LOGGER.error("process request failed: {}", t.getMessage(), t); + throw new GeaflowRuntimeException( + String.format("process request failed: %s", t.getMessage()), t); } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/MetricEndpointRef.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/MetricEndpointRef.java index 74a0a821b..d2be271bd 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/MetricEndpointRef.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/MetricEndpointRef.java @@ -19,11 +19,9 @@ package org.apache.geaflow.cluster.rpc.impl; -import com.baidu.brpc.client.BrpcProxy; -import com.baidu.brpc.client.RpcClientOptions; -import com.baidu.brpc.loadbalance.LoadBalanceStrategy; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; + import org.apache.geaflow.cluster.rpc.IAsyncMetricEndpoint; import org.apache.geaflow.cluster.rpc.IMetricEndpointRef; import org.apache.geaflow.cluster.rpc.RpcEndpointRefFactory; @@ -34,45 +32,49 @@ import org.apache.geaflow.rpc.proto.Metrics.MetricQueryRequest; import org.apache.geaflow.rpc.proto.Metrics.MetricQueryResponse; +import com.baidu.brpc.client.BrpcProxy; +import com.baidu.brpc.client.RpcClientOptions; +import com.baidu.brpc.loadbalance.LoadBalanceStrategy; + public class MetricEndpointRef extends AbstractRpcEndpointRef implements IMetricEndpointRef { - private IAsyncMetricEndpoint metricEndpoint; - private final EndpointRefID refID; + private IAsyncMetricEndpoint metricEndpoint; + private final EndpointRefID refID; - public MetricEndpointRef(String host, int port, Configuration configuration) { - super(host, port, configuration); - this.refID = new EndpointRefID(host, port, EndpointType.METRIC); - } + public MetricEndpointRef(String host, int port, Configuration configuration) { + super(host, port, configuration); + this.refID = new EndpointRefID(host, port, EndpointType.METRIC); + } - @Override - protected void getRpcEndpoint() { - this.metricEndpoint = BrpcProxy.getProxy(rpcClient, IAsyncMetricEndpoint.class); - } + @Override + protected void getRpcEndpoint() { + this.metricEndpoint = BrpcProxy.getProxy(rpcClient, IAsyncMetricEndpoint.class); + } - @Override - protected RpcClientOptions getClientOptions() { - RpcClientOptions options = super.getClientOptions(); - options.setGlobalThreadPoolSharing(false); - options.setMaxTotalConnections(2); - options.setMinIdleConnections(1); - options.setIoThreadNum(1); - options.setWorkThreadNum(1); - options.setLoadBalanceType(LoadBalanceStrategy.LOAD_BALANCE_ROUND_ROBIN); - return options; - } + @Override + protected RpcClientOptions getClientOptions() { + RpcClientOptions options = super.getClientOptions(); + options.setGlobalThreadPoolSharing(false); + options.setMaxTotalConnections(2); + options.setMinIdleConnections(1); + options.setIoThreadNum(1); + options.setWorkThreadNum(1); + options.setLoadBalanceType(LoadBalanceStrategy.LOAD_BALANCE_ROUND_ROBIN); + return options; + } - @Override - public Future queryMetrics(MetricQueryRequest request, RpcCallback callback) { - CompletableFuture result = new CompletableFuture<>(); - com.baidu.brpc.client.RpcCallback rpcCallback = - RpcUtil.buildRpcCallback(callback, result); - try { - this.metricEndpoint.queryMetrics(request, rpcCallback); - } catch (Throwable e) { - rpcCallback.fail(e); - RpcEndpointRefFactory.getInstance().invalidateRef(refID); - } - return result; + @Override + public Future queryMetrics( + MetricQueryRequest request, RpcCallback callback) { + CompletableFuture result = new CompletableFuture<>(); + com.baidu.brpc.client.RpcCallback rpcCallback = + RpcUtil.buildRpcCallback(callback, result); + try { + this.metricEndpoint.queryMetrics(request, rpcCallback); + } catch (Throwable e) { + rpcCallback.fail(e); + RpcEndpointRefFactory.getInstance().invalidateRef(refID); } - + return result; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/PipelineMasterEndpoint.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/PipelineMasterEndpoint.java index 4228f6a8f..304341c29 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/PipelineMasterEndpoint.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/PipelineMasterEndpoint.java @@ -31,24 +31,24 @@ public class PipelineMasterEndpoint implements IPipelineMasterEndpoint { - private static final Logger LOGGER = LoggerFactory.getLogger(PipelineMasterEndpoint.class); - - private final Driver driver; - - public PipelineMasterEndpoint(Driver driver) { - this.driver = driver; + private static final Logger LOGGER = LoggerFactory.getLogger(PipelineMasterEndpoint.class); + + private final Driver driver; + + public PipelineMasterEndpoint(Driver driver) { + this.driver = driver; + } + + @Override + public Response process(Request request) { + try { + IEvent event = RpcMessageEncoder.decode(request.getPayload()); + driver.process(event); + return Response.newBuilder().build(); + } catch (Throwable t) { + LOGGER.error("process event failed: {}", t.getMessage(), t); + throw new GeaflowRuntimeException( + String.format("process event failed: %s", t.getMessage()), t); } - - @Override - public Response process(Request request) { - try { - IEvent event = RpcMessageEncoder.decode(request.getPayload()); - driver.process(event); - return Response.newBuilder().build(); - } catch (Throwable t) { - LOGGER.error("process event failed: {}", t.getMessage(), t); - throw new GeaflowRuntimeException(String.format("process event failed: %s", t.getMessage()), t); - } - } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/PipelineMasterEndpointRef.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/PipelineMasterEndpointRef.java index 0f1ede5b6..95aae9384 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/PipelineMasterEndpointRef.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/PipelineMasterEndpointRef.java @@ -19,9 +19,8 @@ package org.apache.geaflow.cluster.rpc.impl; -import com.baidu.brpc.client.BrpcProxy; -import com.google.protobuf.ByteString; import java.util.concurrent.Future; + import org.apache.geaflow.cluster.protocol.IEvent; import org.apache.geaflow.cluster.rpc.IPipelineManagerEndpointRef; import org.apache.geaflow.cluster.rpc.IPipelineMasterEndpoint; @@ -31,29 +30,32 @@ import org.apache.geaflow.rpc.proto.Container.Request; import org.apache.geaflow.rpc.proto.Container.Response; -public class PipelineMasterEndpointRef extends AbstractRpcEndpointRef implements IPipelineManagerEndpointRef { +import com.baidu.brpc.client.BrpcProxy; +import com.google.protobuf.ByteString; + +public class PipelineMasterEndpointRef extends AbstractRpcEndpointRef + implements IPipelineManagerEndpointRef { - protected IPipelineMasterEndpoint pipelineMasterEndpoint; + protected IPipelineMasterEndpoint pipelineMasterEndpoint; - public PipelineMasterEndpointRef(String host, int port, - Configuration configuration) { - super(host, port, configuration); - } + public PipelineMasterEndpointRef(String host, int port, Configuration configuration) { + super(host, port, configuration); + } - @Override - protected void getRpcEndpoint() { - this.pipelineMasterEndpoint = BrpcProxy.getProxy(rpcClient, IPipelineMasterEndpoint.class); - } + @Override + protected void getRpcEndpoint() { + this.pipelineMasterEndpoint = BrpcProxy.getProxy(rpcClient, IPipelineMasterEndpoint.class); + } - @Override - public Future process(IEvent request, RpcCallback callback) { - Container.Request taskEvent = buildRequest(request); - this.pipelineMasterEndpoint.process(taskEvent); - return null; - } + @Override + public Future process(IEvent request, RpcCallback callback) { + Container.Request taskEvent = buildRequest(request); + this.pipelineMasterEndpoint.process(taskEvent); + return null; + } - protected Request buildRequest(IEvent request) { - ByteString payload = RpcMessageEncoder.encode(request); - return Request.newBuilder().setPayload(payload).build(); - } + protected Request buildRequest(IEvent request) { + ByteString payload = RpcMessageEncoder.encode(request); + return Request.newBuilder().setPayload(payload).build(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/ResourceManagerEndpoint.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/ResourceManagerEndpoint.java index bc9c8d589..cbc0a1588 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/ResourceManagerEndpoint.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/ResourceManagerEndpoint.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.cluster.resourcemanager.IResourceManager; import org.apache.geaflow.cluster.resourcemanager.ReleaseResourceRequest; import org.apache.geaflow.cluster.resourcemanager.ReleaseResponse; @@ -39,101 +40,112 @@ public class ResourceManagerEndpoint implements IResourceManagerEndpoint { - private static final Logger LOGGER = LoggerFactory.getLogger(ResourceManagerEndpoint.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ResourceManagerEndpoint.class); - private final IResourceManager resourceManager; + private final IResourceManager resourceManager; - public ResourceManagerEndpoint(IResourceManager resourceManager) { - this.resourceManager = resourceManager; - } + public ResourceManagerEndpoint(IResourceManager resourceManager) { + this.resourceManager = resourceManager; + } - @Override - public RequireResourceResponse requireResource(Resource.RequireResourceRequest request) { - try { - RequireResponse requireResponse = this.resourceManager.requireResource(convertRequireRequest(request)); - return convertRequireResponse(requireResponse); - } catch (Throwable t) { - LOGGER.error("require resource failed: {}", t.getMessage(), t); - throw new GeaflowRuntimeException(String.format("require resource failed: %s", - t.getMessage()), t); - } + @Override + public RequireResourceResponse requireResource(Resource.RequireResourceRequest request) { + try { + RequireResponse requireResponse = + this.resourceManager.requireResource(convertRequireRequest(request)); + return convertRequireResponse(requireResponse); + } catch (Throwable t) { + LOGGER.error("require resource failed: {}", t.getMessage(), t); + throw new GeaflowRuntimeException( + String.format("require resource failed: %s", t.getMessage()), t); } + } - @Override - public ReleaseResourceResponse releaseResource(Resource.ReleaseResourceRequest request) { - try { - ReleaseResponse releaseResponse = this.resourceManager - .releaseResource(convertReleaseRequest(request)); - return convertReleaseResponse(releaseResponse); - } catch (Throwable t) { - LOGGER.error("release resource failed: {}", t.getMessage(), t); - throw new GeaflowRuntimeException(String.format("release resource failed: %s", - t.getMessage()), t); - } + @Override + public ReleaseResourceResponse releaseResource(Resource.ReleaseResourceRequest request) { + try { + ReleaseResponse releaseResponse = + this.resourceManager.releaseResource(convertReleaseRequest(request)); + return convertReleaseResponse(releaseResponse); + } catch (Throwable t) { + LOGGER.error("release resource failed: {}", t.getMessage(), t); + throw new GeaflowRuntimeException( + String.format("release resource failed: %s", t.getMessage()), t); } + } - private static RequireResourceRequest convertRequireRequest( - Resource.RequireResourceRequest request) { - IAllocator.AllocateStrategy strategy; - switch (request.getAllocStrategy()) { - case ROUND_ROBIN: - strategy = IAllocator.AllocateStrategy.ROUND_ROBIN; - break; - case PROCESS_FAIR: - strategy = IAllocator.AllocateStrategy.PROCESS_FAIR; - break; - default: - String msg = "unrecognized allocate strategy" + request.getAllocStrategy(); - throw new GeaflowRuntimeException(RuntimeErrors.INST.resourceError(msg)); - } - return RequireResourceRequest.build(request.getRequireId(), request.getWorkersNum(), strategy); + private static RequireResourceRequest convertRequireRequest( + Resource.RequireResourceRequest request) { + IAllocator.AllocateStrategy strategy; + switch (request.getAllocStrategy()) { + case ROUND_ROBIN: + strategy = IAllocator.AllocateStrategy.ROUND_ROBIN; + break; + case PROCESS_FAIR: + strategy = IAllocator.AllocateStrategy.PROCESS_FAIR; + break; + default: + String msg = "unrecognized allocate strategy" + request.getAllocStrategy(); + throw new GeaflowRuntimeException(RuntimeErrors.INST.resourceError(msg)); } + return RequireResourceRequest.build(request.getRequireId(), request.getWorkersNum(), strategy); + } - private static Resource.RequireResourceResponse convertRequireResponse(RequireResponse response) { - Resource.RequireResourceResponse.Builder builder = Resource.RequireResourceResponse.newBuilder(); - boolean success = response.isSuccess(); - builder.setRequireId(response.getRequireId()); - builder.setSuccess(success); - if (response.getMsg() != null) { - builder.setMsg(response.getMsg()); - } - if (!success) { - return builder.build(); - } - for (WorkerInfo workerInfo : response.getWorkers()) { - Resource.Worker worker = Resource.Worker.newBuilder() - .setHost(workerInfo.getHost()) - .setProcessId(workerInfo.getProcessId()) - .setProcessIndex(workerInfo.getProcessIndex()) - .setRpcPort(workerInfo.getRpcPort()) - .setShufflePort(workerInfo.getShufflePort()) - .setWorkerId(workerInfo.getWorkerIndex()) - .setContainerId(workerInfo.getContainerName()) - .build(); - builder.addWorker(worker); - } - return builder.build(); + private static Resource.RequireResourceResponse convertRequireResponse(RequireResponse response) { + Resource.RequireResourceResponse.Builder builder = + Resource.RequireResourceResponse.newBuilder(); + boolean success = response.isSuccess(); + builder.setRequireId(response.getRequireId()); + builder.setSuccess(success); + if (response.getMsg() != null) { + builder.setMsg(response.getMsg()); + } + if (!success) { + return builder.build(); } + for (WorkerInfo workerInfo : response.getWorkers()) { + Resource.Worker worker = + Resource.Worker.newBuilder() + .setHost(workerInfo.getHost()) + .setProcessId(workerInfo.getProcessId()) + .setProcessIndex(workerInfo.getProcessIndex()) + .setRpcPort(workerInfo.getRpcPort()) + .setShufflePort(workerInfo.getShufflePort()) + .setWorkerId(workerInfo.getWorkerIndex()) + .setContainerId(workerInfo.getContainerName()) + .build(); + builder.addWorker(worker); + } + return builder.build(); + } - private static ReleaseResourceRequest convertReleaseRequest( - Resource.ReleaseResourceRequest request) { - List workerInfoList = request.getWorkerList().stream().map( - w -> WorkerInfo.build(w.getHost(), w.getRpcPort(), w.getShufflePort(), - w.getProcessId(), w.getProcessIndex(), w.getWorkerId(), w.getContainerId())) + private static ReleaseResourceRequest convertReleaseRequest( + Resource.ReleaseResourceRequest request) { + List workerInfoList = + request.getWorkerList().stream() + .map( + w -> + WorkerInfo.build( + w.getHost(), + w.getRpcPort(), + w.getShufflePort(), + w.getProcessId(), + w.getProcessIndex(), + w.getWorkerId(), + w.getContainerId())) .collect(Collectors.toList()); - return ReleaseResourceRequest.build(request.getReleaseId(), workerInfoList); - } + return ReleaseResourceRequest.build(request.getReleaseId(), workerInfoList); + } - private static Resource.ReleaseResourceResponse convertReleaseResponse( - ReleaseResponse response) { - boolean success = response.isSuccess(); - Resource.ReleaseResourceResponse.Builder builder = Resource.ReleaseResourceResponse.newBuilder() + private static Resource.ReleaseResourceResponse convertReleaseResponse(ReleaseResponse response) { + boolean success = response.isSuccess(); + Resource.ReleaseResourceResponse.Builder builder = + Resource.ReleaseResourceResponse.newBuilder() .setReleaseId(response.getReleaseId()) .setSuccess(success); - if (!success) { - builder.setMsg(response.getMsg()); - } - return builder.build(); + if (!success) { + builder.setMsg(response.getMsg()); } - + return builder.build(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/ResourceManagerEndpointRef.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/ResourceManagerEndpointRef.java index 1870c4b77..737f5bf82 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/ResourceManagerEndpointRef.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/ResourceManagerEndpointRef.java @@ -19,9 +19,9 @@ package org.apache.geaflow.cluster.rpc.impl; -import com.baidu.brpc.client.BrpcProxy; import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.cluster.resourcemanager.ReleaseResourceRequest; import org.apache.geaflow.cluster.resourcemanager.ReleaseResponse; import org.apache.geaflow.cluster.resourcemanager.RequireResourceRequest; @@ -34,87 +34,103 @@ import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.rpc.proto.Resource; -public class ResourceManagerEndpointRef extends AbstractRpcEndpointRef implements IResourceEndpointRef { +import com.baidu.brpc.client.BrpcProxy; - private IResourceManagerEndpoint resourceManagerEndpoint; +public class ResourceManagerEndpointRef extends AbstractRpcEndpointRef + implements IResourceEndpointRef { - public ResourceManagerEndpointRef(String host, int port, - Configuration configuration) { - super(host, port, configuration); - } + private IResourceManagerEndpoint resourceManagerEndpoint; - @Override - protected void getRpcEndpoint() { - this.resourceManagerEndpoint = BrpcProxy.getProxy(rpcClient, IResourceManagerEndpoint.class); - } + public ResourceManagerEndpointRef(String host, int port, Configuration configuration) { + super(host, port, configuration); + } - @Override - public RequireResponse requireResource(RequireResourceRequest request) { - Resource.RequireResourceResponse response = this.resourceManagerEndpoint.requireResource(convertRequireRequest(request)); - return convertRequireResponse(response); - } + @Override + protected void getRpcEndpoint() { + this.resourceManagerEndpoint = BrpcProxy.getProxy(rpcClient, IResourceManagerEndpoint.class); + } - @Override - public ReleaseResponse releaseResource(ReleaseResourceRequest request) { - Resource.ReleaseResourceResponse response = this.resourceManagerEndpoint.releaseResource(convertReleaseRequest(request)); - return convertReleaseResponse(response); - } + @Override + public RequireResponse requireResource(RequireResourceRequest request) { + Resource.RequireResourceResponse response = + this.resourceManagerEndpoint.requireResource(convertRequireRequest(request)); + return convertRequireResponse(response); + } - private static Resource.RequireResourceRequest convertRequireRequest(RequireResourceRequest request) { - Resource.AllocateStrategy strategy; - switch (request.getAllocateStrategy()) { - case ROUND_ROBIN: - strategy = Resource.AllocateStrategy.ROUND_ROBIN; - break; - case PROCESS_FAIR: - strategy = Resource.AllocateStrategy.PROCESS_FAIR; - break; - default: - String msg = "unrecognized allocate strategy" + request.getAllocateStrategy(); - throw new GeaflowRuntimeException(RuntimeErrors.INST.resourceError(msg)); - } - return Resource.RequireResourceRequest.newBuilder() - .setRequireId(request.getRequireId()) - .setWorkersNum(request.getRequiredNum()) - .setAllocStrategy(strategy).build(); - } + @Override + public ReleaseResponse releaseResource(ReleaseResourceRequest request) { + Resource.ReleaseResourceResponse response = + this.resourceManagerEndpoint.releaseResource(convertReleaseRequest(request)); + return convertReleaseResponse(response); + } - private static RequireResponse convertRequireResponse(Resource.RequireResourceResponse response) { - String requireId = response.getRequireId(); - boolean success = response.getSuccess(); - String msg = response.getMsg(); - if (!success) { - return RequireResponse.fail(requireId, msg); - } - List workers = response.getWorkerList().stream() - .map(w -> WorkerInfo.build(w.getHost(), w.getRpcPort(), - w.getShufflePort(), w.getProcessId(), w.getProcessIndex(), w.getWorkerId(), w.getContainerId())) - .collect(Collectors.toList()); - return RequireResponse.success(requireId, workers); + private static Resource.RequireResourceRequest convertRequireRequest( + RequireResourceRequest request) { + Resource.AllocateStrategy strategy; + switch (request.getAllocateStrategy()) { + case ROUND_ROBIN: + strategy = Resource.AllocateStrategy.ROUND_ROBIN; + break; + case PROCESS_FAIR: + strategy = Resource.AllocateStrategy.PROCESS_FAIR; + break; + default: + String msg = "unrecognized allocate strategy" + request.getAllocateStrategy(); + throw new GeaflowRuntimeException(RuntimeErrors.INST.resourceError(msg)); } + return Resource.RequireResourceRequest.newBuilder() + .setRequireId(request.getRequireId()) + .setWorkersNum(request.getRequiredNum()) + .setAllocStrategy(strategy) + .build(); + } - private static Resource.ReleaseResourceRequest convertReleaseRequest(ReleaseResourceRequest request) { - Resource.ReleaseResourceRequest.Builder builder = Resource.ReleaseResourceRequest.newBuilder(); - builder.setReleaseId(request.getReleaseId()); - for (WorkerInfo workerInfo : request.getWorkers()) { - Resource.Worker worker = Resource.Worker.newBuilder() - .setHost(workerInfo.getHost()) - .setProcessId(workerInfo.getProcessId()) - .setProcessIndex(workerInfo.getProcessIndex()) - .setRpcPort(workerInfo.getRpcPort()) - .setWorkerId(workerInfo.getWorkerIndex()) - .setContainerId(workerInfo.getContainerName()) - .build(); - builder.addWorker(worker); - } - return builder.build(); + private static RequireResponse convertRequireResponse(Resource.RequireResourceResponse response) { + String requireId = response.getRequireId(); + boolean success = response.getSuccess(); + String msg = response.getMsg(); + if (!success) { + return RequireResponse.fail(requireId, msg); } + List workers = + response.getWorkerList().stream() + .map( + w -> + WorkerInfo.build( + w.getHost(), + w.getRpcPort(), + w.getShufflePort(), + w.getProcessId(), + w.getProcessIndex(), + w.getWorkerId(), + w.getContainerId())) + .collect(Collectors.toList()); + return RequireResponse.success(requireId, workers); + } - private static ReleaseResponse convertReleaseResponse(Resource.ReleaseResourceResponse response) { - String releaseId = response.getReleaseId(); - return response.getSuccess() - ? ReleaseResponse.success(releaseId) - : ReleaseResponse.fail(releaseId, response.getMsg()); + private static Resource.ReleaseResourceRequest convertReleaseRequest( + ReleaseResourceRequest request) { + Resource.ReleaseResourceRequest.Builder builder = Resource.ReleaseResourceRequest.newBuilder(); + builder.setReleaseId(request.getReleaseId()); + for (WorkerInfo workerInfo : request.getWorkers()) { + Resource.Worker worker = + Resource.Worker.newBuilder() + .setHost(workerInfo.getHost()) + .setProcessId(workerInfo.getProcessId()) + .setProcessIndex(workerInfo.getProcessIndex()) + .setRpcPort(workerInfo.getRpcPort()) + .setWorkerId(workerInfo.getWorkerIndex()) + .setContainerId(workerInfo.getContainerName()) + .build(); + builder.addWorker(worker); } + return builder.build(); + } + private static ReleaseResponse convertReleaseResponse(Resource.ReleaseResourceResponse response) { + String releaseId = response.getReleaseId(); + return response.getSuccess() + ? ReleaseResponse.success(releaseId) + : ReleaseResponse.fail(releaseId, response.getMsg()); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/RpcServiceImpl.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/RpcServiceImpl.java index 0ffef01c7..ad42eec78 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/RpcServiceImpl.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/RpcServiceImpl.java @@ -19,61 +19,63 @@ package org.apache.geaflow.cluster.rpc.impl; -import com.baidu.brpc.server.RpcServer; -import com.baidu.brpc.server.RpcServerOptions; import java.io.Serializable; + import org.apache.geaflow.cluster.rpc.RpcService; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.baidu.brpc.server.RpcServer; +import com.baidu.brpc.server.RpcServerOptions; + public class RpcServiceImpl implements RpcService, Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(RpcServiceImpl.class); + private static final Logger LOGGER = LoggerFactory.getLogger(RpcServiceImpl.class); - private final int port; + private final int port; - private final RpcServer server; + private final RpcServer server; - public RpcServiceImpl(int port, RpcServerOptions options) { - this.port = port; - this.server = new RpcServer(port, options); - } + public RpcServiceImpl(int port, RpcServerOptions options) { + this.port = port; + this.server = new RpcServer(port, options); + } - public void addEndpoint(Object rpcEndpoint) { - server.registerService(rpcEndpoint); - } + public void addEndpoint(Object rpcEndpoint) { + server.registerService(rpcEndpoint); + } - @Override - public int startService() { - try { - this.server.start(); - LOGGER.info("Brpc Server started: {}", port); - return port; - } catch (Throwable t) { - LOGGER.error(t.getMessage(), t); - throw new GeaflowRuntimeException(t); - } + @Override + public int startService() { + try { + this.server.start(); + LOGGER.info("Brpc Server started: {}", port); + return port; + } catch (Throwable t) { + LOGGER.error(t.getMessage(), t); + throw new GeaflowRuntimeException(t); } + } - @Test - public void waitTermination() { - synchronized (server) { - while (!server.isShutdown()) { - try { - server.wait(); - } catch (InterruptedException e) { - LOGGER.warn("shutdown is interrupted"); - } - } + @Test + public void waitTermination() { + synchronized (server) { + while (!server.isShutdown()) { + try { + server.wait(); + } catch (InterruptedException e) { + LOGGER.warn("shutdown is interrupted"); } + } } + } - @Override - public void stopService() { - if (server != null) { - server.shutdown(); - } + @Override + public void stopService() { + if (server != null) { + server.shutdown(); } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/SupervisorEndpoint.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/SupervisorEndpoint.java index c7980e5b6..d06039406 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/SupervisorEndpoint.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/SupervisorEndpoint.java @@ -19,32 +19,32 @@ package org.apache.geaflow.cluster.rpc.impl; -import com.google.protobuf.Empty; import org.apache.geaflow.cluster.rpc.ISupervisorEndpoint; import org.apache.geaflow.cluster.runner.Supervisor; import org.apache.geaflow.rpc.proto.Supervisor.RestartRequest; import org.apache.geaflow.rpc.proto.Supervisor.StatusResponse; -public class SupervisorEndpoint implements ISupervisorEndpoint { - - private final Supervisor supervisor; - private final Empty empty; - - public SupervisorEndpoint(Supervisor supervisor) { - this.supervisor = supervisor; - this.empty = Empty.newBuilder().build(); - } - - @Override - public Empty restart(RestartRequest request) { - supervisor.restartWorker(request.getPid()); - return empty; - } +import com.google.protobuf.Empty; - @Override - public StatusResponse status(Empty empty) { - boolean isAlive = supervisor.isWorkerAlive(); - return StatusResponse.newBuilder().setIsAlive(isAlive).build(); - } +public class SupervisorEndpoint implements ISupervisorEndpoint { + private final Supervisor supervisor; + private final Empty empty; + + public SupervisorEndpoint(Supervisor supervisor) { + this.supervisor = supervisor; + this.empty = Empty.newBuilder().build(); + } + + @Override + public Empty restart(RestartRequest request) { + supervisor.restartWorker(request.getPid()); + return empty; + } + + @Override + public StatusResponse status(Empty empty) { + boolean isAlive = supervisor.isWorkerAlive(); + return StatusResponse.newBuilder().setIsAlive(isAlive).build(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/SupervisorEndpointRef.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/SupervisorEndpointRef.java index 86991078c..6665e3aeb 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/SupervisorEndpointRef.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/rpc/impl/SupervisorEndpointRef.java @@ -19,12 +19,9 @@ package org.apache.geaflow.cluster.rpc.impl; -import com.baidu.brpc.client.BrpcProxy; -import com.baidu.brpc.client.RpcClientOptions; -import com.baidu.brpc.loadbalance.LoadBalanceStrategy; -import com.google.protobuf.Empty; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; + import org.apache.geaflow.cluster.rpc.IAsyncSupervisorEndpoint; import org.apache.geaflow.cluster.rpc.ISupervisorEndpointRef; import org.apache.geaflow.cluster.rpc.RpcUtil; @@ -32,46 +29,51 @@ import org.apache.geaflow.rpc.proto.Supervisor.RestartRequest; import org.apache.geaflow.rpc.proto.Supervisor.StatusResponse; -public class SupervisorEndpointRef extends AbstractRpcEndpointRef implements ISupervisorEndpointRef { +import com.baidu.brpc.client.BrpcProxy; +import com.baidu.brpc.client.RpcClientOptions; +import com.baidu.brpc.loadbalance.LoadBalanceStrategy; +import com.google.protobuf.Empty; - private IAsyncSupervisorEndpoint supervisorEndpoint; - private final Empty empty; +public class SupervisorEndpointRef extends AbstractRpcEndpointRef + implements ISupervisorEndpointRef { - public SupervisorEndpointRef(String host, int port, Configuration configuration) { - super(host, port, configuration); - this.empty = Empty.newBuilder().build(); - } + private IAsyncSupervisorEndpoint supervisorEndpoint; + private final Empty empty; - @Override - protected void getRpcEndpoint() { - this.supervisorEndpoint = BrpcProxy.getProxy(rpcClient, IAsyncSupervisorEndpoint.class); - } + public SupervisorEndpointRef(String host, int port, Configuration configuration) { + super(host, port, configuration); + this.empty = Empty.newBuilder().build(); + } - @Override - protected RpcClientOptions getClientOptions() { - RpcClientOptions options = super.getClientOptions(); - options.setGlobalThreadPoolSharing(false); - options.setMaxTotalConnections(2); - options.setMinIdleConnections(2); - options.setIoThreadNum(1); - options.setWorkThreadNum(2); - options.setLoadBalanceType(LoadBalanceStrategy.LOAD_BALANCE_ROUND_ROBIN); - return options; - } + @Override + protected void getRpcEndpoint() { + this.supervisorEndpoint = BrpcProxy.getProxy(rpcClient, IAsyncSupervisorEndpoint.class); + } - @Override - public Future restart(int pid, RpcCallback callback) { - CompletableFuture result = new CompletableFuture<>(); - com.baidu.brpc.client.RpcCallback rpcCallback = - RpcUtil.buildRpcCallback(callback, result); - RestartRequest request = RestartRequest.newBuilder().setPid(pid).build(); - supervisorEndpoint.restart(request, rpcCallback); - return result; - } + @Override + protected RpcClientOptions getClientOptions() { + RpcClientOptions options = super.getClientOptions(); + options.setGlobalThreadPoolSharing(false); + options.setMaxTotalConnections(2); + options.setMinIdleConnections(2); + options.setIoThreadNum(1); + options.setWorkThreadNum(2); + options.setLoadBalanceType(LoadBalanceStrategy.LOAD_BALANCE_ROUND_ROBIN); + return options; + } - @Override - public StatusResponse status() { - return supervisorEndpoint.status(empty); - } + @Override + public Future restart(int pid, RpcCallback callback) { + CompletableFuture result = new CompletableFuture<>(); + com.baidu.brpc.client.RpcCallback rpcCallback = + RpcUtil.buildRpcCallback(callback, result); + RestartRequest request = RestartRequest.newBuilder().setPid(pid).build(); + supervisorEndpoint.restart(request, rpcCallback); + return result; + } + @Override + public StatusResponse status() { + return supervisorEndpoint.status(empty); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/runner/CommandRunner.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/runner/CommandRunner.java index 6c836ddd0..a3865847f 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/runner/CommandRunner.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/runner/CommandRunner.java @@ -21,12 +21,12 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.PROCESS_EXIT_WAIT_SECONDS; -import com.google.common.base.Preconditions; import java.io.IOException; import java.lang.ProcessBuilder.Redirect; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.utils.ProcessUtil; @@ -36,126 +36,133 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class CommandRunner { +import com.google.common.base.Preconditions; - private static final Logger LOGGER = LoggerFactory.getLogger(CommandRunner.class); - - private int pid; - private Process process; - private final String command; - private final int maxRestarts; - private final Map env; - private final Configuration configuration; - private final int exitWaitSecs; - - public CommandRunner(String command, int maxRestarts, Map env, - Configuration config) { - this.command = command; - this.maxRestarts = maxRestarts; - this.env = env; - this.configuration = config; - this.exitWaitSecs = config.getInteger(PROCESS_EXIT_WAIT_SECONDS); - } +public class CommandRunner { - public void asyncStart() { - CompletableFuture.runAsync(() -> { - try { - startProcess(); - } catch (Throwable e) { - LOGGER.error("Start process failed: {}", e.getMessage(), e); - String errMsg = String.format("Worker process exited: %s", e.getMessage()); - StatsCollectorFactory.init(configuration).getEventCollector() - .reportEvent(ExceptionLevel.ERROR, EventLabel.WORKER_PROCESS_EXITED, errMsg); - } + private static final Logger LOGGER = LoggerFactory.getLogger(CommandRunner.class); + + private int pid; + private Process process; + private final String command; + private final int maxRestarts; + private final Map env; + private final Configuration configuration; + private final int exitWaitSecs; + + public CommandRunner( + String command, int maxRestarts, Map env, Configuration config) { + this.command = command; + this.maxRestarts = maxRestarts; + this.env = env; + this.configuration = config; + this.exitWaitSecs = config.getInteger(PROCESS_EXIT_WAIT_SECONDS); + } + + public void asyncStart() { + CompletableFuture.runAsync( + () -> { + try { + startProcess(); + } catch (Throwable e) { + LOGGER.error("Start process failed: {}", e.getMessage(), e); + String errMsg = String.format("Worker process exited: %s", e.getMessage()); + StatsCollectorFactory.init(configuration) + .getEventCollector() + .reportEvent(ExceptionLevel.ERROR, EventLabel.WORKER_PROCESS_EXITED, errMsg); + } }); - } - - public void startProcess() { - try { - int restarts = maxRestarts; - do { - Process childProcess = doStartProcess(command); - int code = childProcess.waitFor(); - LOGGER.warn("Child process {} exits with code: {} and alive: {}", pid, code, - childProcess.isAlive()); - // 0: success, 137: killed by SIGKILL, 143: killed by SIGTERM - if (code == 0 || code == 137 || code == 143) { - return; - } - if (restarts == 0) { - String errMsg; - if (maxRestarts == 0) { - errMsg = String.format("process exits code: %s", code); - } else { - errMsg = String.format("process exits code: %s, exhausted %s restarts", - code, maxRestarts); - } - throw new GeaflowRuntimeException(errMsg); - } - restarts--; - } while (true); - } catch (GeaflowRuntimeException e) { - LOGGER.error("FATAL: start command failed: {}", command, e); - throw e; - } catch (Throwable e) { - LOGGER.error("FATAL: start command failed: {}", command, e); - throw new GeaflowRuntimeException(e.getMessage(), e); + } + + public void startProcess() { + try { + int restarts = maxRestarts; + do { + Process childProcess = doStartProcess(command); + int code = childProcess.waitFor(); + LOGGER.warn( + "Child process {} exits with code: {} and alive: {}", + pid, + code, + childProcess.isAlive()); + // 0: success, 137: killed by SIGKILL, 143: killed by SIGTERM + if (code == 0 || code == 137 || code == 143) { + return; } - } - - private Process doStartProcess(String startCommand) throws IOException { - LOGGER.info("Start process with command: {}", startCommand); - ProcessBuilder pb = new ProcessBuilder(); - //pb.redirectInput(Redirect.INHERIT); - pb.redirectOutput(Redirect.INHERIT); - if (env != null) { - pb.environment().putAll(env); + if (restarts == 0) { + String errMsg; + if (maxRestarts == 0) { + errMsg = String.format("process exits code: %s", code); + } else { + errMsg = + String.format("process exits code: %s, exhausted %s restarts", code, maxRestarts); + } + throw new GeaflowRuntimeException(errMsg); } - String[] cmds = startCommand.split("\\s+"); - pb.command(cmds); - Process childProcess = pb.start(); - this.process = childProcess; - this.pid = ProcessUtil.getProcessPid(childProcess); - LOGGER.info("Process started with pid: {}", pid); - return childProcess; - } - - public Process getProcess() { - return process; + restarts--; + } while (true); + } catch (GeaflowRuntimeException e) { + LOGGER.error("FATAL: start command failed: {}", command, e); + throw e; + } catch (Throwable e) { + LOGGER.error("FATAL: start command failed: {}", command, e); + throw new GeaflowRuntimeException(e.getMessage(), e); } - - public int getProcessId() { - return pid; + } + + private Process doStartProcess(String startCommand) throws IOException { + LOGGER.info("Start process with command: {}", startCommand); + ProcessBuilder pb = new ProcessBuilder(); + // pb.redirectInput(Redirect.INHERIT); + pb.redirectOutput(Redirect.INHERIT); + if (env != null) { + pb.environment().putAll(env); } - - public void stop() { - stop(pid); + String[] cmds = startCommand.split("\\s+"); + pb.command(cmds); + Process childProcess = pb.start(); + this.process = childProcess; + this.pid = ProcessUtil.getProcessPid(childProcess); + LOGGER.info("Process started with pid: {}", pid); + return childProcess; + } + + public Process getProcess() { + return process; + } + + public int getProcessId() { + return pid; + } + + public void stop() { + stop(pid); + } + + public void stop(int oldPid) { + Preconditions.checkArgument(pid > 0, "pid should be larger than 0"); + LOGGER.info("Stop old process if exists: {}", oldPid); + + Process curProcess = process; + int curPid = pid; + + // If bash process is alive, kill it, and it's child process is supposed to be killed. + if (curProcess.isAlive()) { + if (curPid <= 0) { + LOGGER.warn("Process is alive but pid not found: {}", curProcess); + return; + } + curProcess.destroy(); + try { + boolean status = curProcess.waitFor(exitWaitSecs, TimeUnit.SECONDS); + LOGGER.info("Destroy current process {}: {}", curPid, status); + } catch (InterruptedException e) { + LOGGER.warn("Interrupted while waiting for process to exit: {}", pid); + } } - - public void stop(int oldPid) { - Preconditions.checkArgument(pid > 0, "pid should be larger than 0"); - LOGGER.info("Stop old process if exists: {}", oldPid); - - Process curProcess = process; - int curPid = pid; - - // If bash process is alive, kill it, and it's child process is supposed to be killed. - if (curProcess.isAlive()) { - if (curPid <= 0) { - LOGGER.warn("Process is alive but pid not found: {}", curProcess); - return; - } - curProcess.destroy(); - try { - boolean status = curProcess.waitFor(exitWaitSecs, TimeUnit.SECONDS); - LOGGER.info("Destroy current process {}: {}", curPid, status); - } catch (InterruptedException e) { - LOGGER.warn("Interrupted while waiting for process to exit: {}", pid); - } - } - if (curPid != oldPid) { - LOGGER.info("Kill old process: {}", oldPid); - ProcessUtil.killProcess(oldPid); - } + if (curPid != oldPid) { + LOGGER.info("Kill old process: {}", oldPid); + ProcessUtil.killProcess(oldPid); } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/runner/Supervisor.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/runner/Supervisor.java index df9f05acb..5bfb5e25a 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/runner/Supervisor.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/runner/Supervisor.java @@ -25,9 +25,9 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.FO_MAX_RESTARTS; import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.SUPERVISOR_RPC_PORT; -import com.baidu.brpc.server.RpcServerOptions; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.cluster.rpc.impl.RpcServiceImpl; import org.apache.geaflow.cluster.rpc.impl.SupervisorEndpoint; import org.apache.geaflow.cluster.web.agent.AgentWebServer; @@ -40,100 +40,104 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class Supervisor { - - private static final Logger LOGGER = LoggerFactory.getLogger(Supervisor.class); - private static final int DEFAULT_RETRIES = 3; - - private final RpcServiceImpl rpcService; - private final CommandRunner mainRunner; - private final Configuration configuration; - private final int maxRestarts; - private final Map envMap; - - public Supervisor(String startCommand, Configuration configuration, boolean autoRestart) { - this.configuration = configuration; - this.maxRestarts = autoRestart ? configuration.getInteger(FO_MAX_RESTARTS) : 0; - - RpcServerOptions serverOptions = getServerOptions(configuration); - int port = PortUtil.getPort(configuration.getInteger(SUPERVISOR_RPC_PORT)); - this.rpcService = new RpcServiceImpl(port, serverOptions); - this.rpcService.addEndpoint(new SupervisorEndpoint(this)); - this.rpcService.startService(); - - this.envMap = new HashMap<>(); - envMap.put(ENV_SUPERVISOR_PORT, String.valueOf(port)); - - this.mainRunner = new CommandRunner(startCommand, maxRestarts, envMap, configuration); - LOGGER.info("Start supervisor with maxRestarts: {}", maxRestarts); - } - - private RpcServerOptions getServerOptions(Configuration configuration) { - RpcServerOptions serverOptions = ConfigurableServerOption.build(configuration); - serverOptions.setGlobalThreadPoolSharing(false); - serverOptions.setIoThreadNum(1); - serverOptions.setWorkThreadNum(2); - return serverOptions; - } - - public void start() { - try { - startAgent(); - startWorker(); - } catch (Throwable e) { - StatsCollectorFactory.init(configuration).getExceptionCollector().reportException( - ExceptionLevel.FATAL, e); - throw e; - } - } - - public void restartWorker(int pid) { - LOGGER.info("Restart worker process: {}", pid); - stopWorker(pid); - startWorker(); - } - - public void startWorker() { - mainRunner.asyncStart(); - } - - public boolean isWorkerAlive() { - Process process = mainRunner.getProcess(); - if (maxRestarts > 0 || process != null && process.isAlive()) { - return true; - } - LOGGER.warn("Worker process {} is dead.", mainRunner.getProcessId()); - return false; - } +import com.baidu.brpc.server.RpcServerOptions; - public void stopWorker() { - mainRunner.stop(); - } +public class Supervisor { - public void stopWorker(int pid) { - mainRunner.stop(pid); + private static final Logger LOGGER = LoggerFactory.getLogger(Supervisor.class); + private static final int DEFAULT_RETRIES = 3; + + private final RpcServiceImpl rpcService; + private final CommandRunner mainRunner; + private final Configuration configuration; + private final int maxRestarts; + private final Map envMap; + + public Supervisor(String startCommand, Configuration configuration, boolean autoRestart) { + this.configuration = configuration; + this.maxRestarts = autoRestart ? configuration.getInteger(FO_MAX_RESTARTS) : 0; + + RpcServerOptions serverOptions = getServerOptions(configuration); + int port = PortUtil.getPort(configuration.getInteger(SUPERVISOR_RPC_PORT)); + this.rpcService = new RpcServiceImpl(port, serverOptions); + this.rpcService.addEndpoint(new SupervisorEndpoint(this)); + this.rpcService.startService(); + + this.envMap = new HashMap<>(); + envMap.put(ENV_SUPERVISOR_PORT, String.valueOf(port)); + + this.mainRunner = new CommandRunner(startCommand, maxRestarts, envMap, configuration); + LOGGER.info("Start supervisor with maxRestarts: {}", maxRestarts); + } + + private RpcServerOptions getServerOptions(Configuration configuration) { + RpcServerOptions serverOptions = ConfigurableServerOption.build(configuration); + serverOptions.setGlobalThreadPoolSharing(false); + serverOptions.setIoThreadNum(1); + serverOptions.setWorkThreadNum(2); + return serverOptions; + } + + public void start() { + try { + startAgent(); + startWorker(); + } catch (Throwable e) { + StatsCollectorFactory.init(configuration) + .getExceptionCollector() + .reportException(ExceptionLevel.FATAL, e); + throw e; } - - public void startAgent() { - RetryCommand.run(() -> { - int agentPort = PortUtil.getPort(configuration.getInteger(AGENT_HTTP_PORT)); - envMap.put(ENV_AGENT_PORT, String.valueOf(agentPort)); - AgentWebServer server = new AgentWebServer(agentPort, configuration); - server.start(); - return null; - }, DEFAULT_RETRIES); + } + + public void restartWorker(int pid) { + LOGGER.info("Restart worker process: {}", pid); + stopWorker(pid); + startWorker(); + } + + public void startWorker() { + mainRunner.asyncStart(); + } + + public boolean isWorkerAlive() { + Process process = mainRunner.getProcess(); + if (maxRestarts > 0 || process != null && process.isAlive()) { + return true; } - - public void waitForTermination() { - if (rpcService != null) { - rpcService.waitTermination(); - } + LOGGER.warn("Worker process {} is dead.", mainRunner.getProcessId()); + return false; + } + + public void stopWorker() { + mainRunner.stop(); + } + + public void stopWorker(int pid) { + mainRunner.stop(pid); + } + + public void startAgent() { + RetryCommand.run( + () -> { + int agentPort = PortUtil.getPort(configuration.getInteger(AGENT_HTTP_PORT)); + envMap.put(ENV_AGENT_PORT, String.valueOf(agentPort)); + AgentWebServer server = new AgentWebServer(agentPort, configuration); + server.start(); + return null; + }, + DEFAULT_RETRIES); + } + + public void waitForTermination() { + if (rpcService != null) { + rpcService.waitTermination(); } + } - public void stop() { - if (rpcService != null) { - rpcService.stopService(); - } + public void stop() { + if (rpcService != null) { + rpcService.stopService(); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/system/ClusterMetaStore.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/system/ClusterMetaStore.java index 86fd7cf5d..589b322a4 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/system/ClusterMetaStore.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/system/ClusterMetaStore.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.cluster.protocol.IEvent; import org.apache.geaflow.cluster.resourcemanager.WorkerSnapshot; import org.apache.geaflow.common.config.Configuration; @@ -34,275 +35,270 @@ public class ClusterMetaStore { - private static final Logger LOGGER = LoggerFactory.getLogger(ClusterMetaStore.class); - - private static final String CLUSTER_META_NAMESPACE_LABEL = "framework"; - private static final String CLUSTER_NAMESPACE_PREFIX = "cluster"; - private static final String OFFSET_NAMESPACE = "offset"; - - private static ClusterMetaStore INSTANCE; - - private final int componentId; - private final String componentName; - private final String clusterId; - private final Configuration configuration; - private final IClusterMetaKVStore componentBackend; - private Map> backends; - - private ClusterMetaStore(int id, String name, Configuration configuration) { - this.componentId = id; - this.componentName = name; - this.configuration = configuration; - this.backends = new ConcurrentHashMap<>(); - this.clusterId = configuration.getString(CLUSTER_ID); - String namespace = String.format("%s/%s/%s", CLUSTER_NAMESPACE_PREFIX, clusterId, componentName); - this.componentBackend = createBackend(namespace); - this.backends.put(namespace, componentBackend); - } - - public static void init(int id, String name, Configuration configuration) { - if (INSTANCE == null) { - synchronized (ClusterMetaStore.class) { - if (INSTANCE == null) { - INSTANCE = new ClusterMetaStore(id, name, configuration); - } - } - } - } - - public static ClusterMetaStore getInstance(int id, String name, Configuration configuration) { + private static final Logger LOGGER = LoggerFactory.getLogger(ClusterMetaStore.class); + + private static final String CLUSTER_META_NAMESPACE_LABEL = "framework"; + private static final String CLUSTER_NAMESPACE_PREFIX = "cluster"; + private static final String OFFSET_NAMESPACE = "offset"; + + private static ClusterMetaStore INSTANCE; + + private final int componentId; + private final String componentName; + private final String clusterId; + private final Configuration configuration; + private final IClusterMetaKVStore componentBackend; + private Map> backends; + + private ClusterMetaStore(int id, String name, Configuration configuration) { + this.componentId = id; + this.componentName = name; + this.configuration = configuration; + this.backends = new ConcurrentHashMap<>(); + this.clusterId = configuration.getString(CLUSTER_ID); + String namespace = + String.format("%s/%s/%s", CLUSTER_NAMESPACE_PREFIX, clusterId, componentName); + this.componentBackend = createBackend(namespace); + this.backends.put(namespace, componentBackend); + } + + public static void init(int id, String name, Configuration configuration) { + if (INSTANCE == null) { + synchronized (ClusterMetaStore.class) { if (INSTANCE == null) { - init(id, name, configuration); + INSTANCE = new ClusterMetaStore(id, name, configuration); } - return INSTANCE; - } - - public static ClusterMetaStore getInstance() { - return INSTANCE; - } - - public static synchronized void close() { - LOGGER.info("close ClusterMetaStore"); - if (INSTANCE != null) { - Map> backends = INSTANCE.backends; - INSTANCE.backends = null; - for (IClusterMetaKVStore backend : backends.values()) { - backend.close(); - } - INSTANCE = null; + } + } + } + + public static ClusterMetaStore getInstance(int id, String name, Configuration configuration) { + if (INSTANCE == null) { + init(id, name, configuration); + } + return INSTANCE; + } + + public static ClusterMetaStore getInstance() { + return INSTANCE; + } + + public static synchronized void close() { + LOGGER.info("close ClusterMetaStore"); + if (INSTANCE != null) { + Map> backends = INSTANCE.backends; + INSTANCE.backends = null; + for (IClusterMetaKVStore backend : backends.values()) { + backend.close(); + } + INSTANCE = null; + } + } + + public ClusterMetaStore savePipeline(Pipeline pipeline) { + save(ClusterMetaKey.PIPELINE, pipeline); + return this; + } + + public ClusterMetaStore savePipelineTaskIds(List pipelineTaskIds) { + save(ClusterMetaKey.PIPELINE_TASK_IDS, pipelineTaskIds); + return this; + } + + public ClusterMetaStore savePipelineTasks(List taskIndices) { + save(ClusterMetaKey.PIPELINE_TASKS, taskIndices); + return this; + } + + /** Auto flush after save value. */ + public void saveWindowId(Long windowId, long pipelineTaskId) { + save(ClusterMetaKey.WINDOW_ID, windowId, pipelineTaskId); + } + + public ClusterMetaStore saveCycle(Object cycle, long pipelineTaskId) { + save(ClusterMetaKey.CYCLE, cycle, pipelineTaskId); + return this; + } + + public ClusterMetaStore saveEvent(List event) { + save(ClusterMetaKey.EVENTS, event); + return this; + } + + /** Auto flush after save value. */ + public void saveWorkers(WorkerSnapshot workers) { + save(ClusterMetaKey.WORKERS, workers); + } + + public ClusterMetaStore saveContainerIds(Map containerIds) { + save(ClusterMetaKey.CONTAINER_IDS, containerIds); + return this; + } + + public ClusterMetaStore saveDriverIds(Map driverIds) { + save(ClusterMetaKey.DRIVER_IDS, driverIds); + return this; + } + + public ClusterMetaStore saveMaxContainerId(int containerId) { + save(ClusterMetaKey.MAX_CONTAINER_ID, containerId); + return this; + } + + public Pipeline getPipeline() { + return get(ClusterMetaKey.PIPELINE); + } + + public List getPipelineTaskIds() { + return get(ClusterMetaKey.PIPELINE_TASK_IDS); + } + + public List getPipelineTasks() { + return get(ClusterMetaKey.PIPELINE_TASKS); + } + + public Long getWindowId(long pipelineTaskId) { + return get(ClusterMetaKey.WINDOW_ID, pipelineTaskId); + } + + public Object getCycle(long pipelineTaskId) { + return get(ClusterMetaKey.CYCLE, pipelineTaskId); + } + + public List getEvents() { + return get(ClusterMetaKey.EVENTS); + } + + public WorkerSnapshot getWorkers() { + return get(ClusterMetaKey.WORKERS); + } + + public int getMaxContainerId() { + return get(ClusterMetaKey.MAX_CONTAINER_ID); + } + + public Map getContainerIds() { + return get(ClusterMetaKey.CONTAINER_IDS); + } + + public Map getDriverIds() { + return get(ClusterMetaKey.DRIVER_IDS); + } + + public void flush() { + componentBackend.flush(); + } + + public void clean() { + // TODO Clean namespace directly from backend. + } + + private void save(ClusterMetaKey key, T value) { + getBackend(key).put(key.name(), value); + } + + private void save(ClusterMetaKey key, T value, long pipelineTaskId) { + getBackend(key).put(getKeyTag(key.name(), pipelineTaskId), value); + } + + private T get(ClusterMetaKey key) { + return (T) getBackend(key).get(key.name()); + } + + private T get(ClusterMetaKey key, long pipelineTaskId) { + return (T) getBackend(key).get(getKeyTag(key.name(), pipelineTaskId)); + } + + private String getKeyTag(String key, long pipelineTaskId) { + return String.format("%s#%s", key, pipelineTaskId); + } + + private IClusterMetaKVStore getBackend(ClusterMetaKey metaKey) { + String namespace; + switch (metaKey) { + case WORKERS: + namespace = + String.format( + "%s/%s/%s", CLUSTER_NAMESPACE_PREFIX, clusterId, metaKey.name().toLowerCase()); + break; + case WINDOW_ID: + namespace = OFFSET_NAMESPACE; + break; + default: + return componentBackend; + } + // Cluster meta store is closed. + if (backends == null) { + return null; + } + if (!backends.containsKey(namespace)) { + synchronized (ClusterMetaStore.class) { + if (!backends.containsKey(namespace)) { + IClusterMetaKVStore backend = createBackend(namespace); + backends.put(namespace, new ClusterMetaKVStoreProxy<>(backend)); } + } } + return backends.get(namespace); + } - public ClusterMetaStore savePipeline(Pipeline pipeline) { - save(ClusterMetaKey.PIPELINE, pipeline); - return this; - } - - public ClusterMetaStore savePipelineTaskIds(List pipelineTaskIds) { - save(ClusterMetaKey.PIPELINE_TASK_IDS, pipelineTaskIds); - return this; - } - - public ClusterMetaStore savePipelineTasks(List taskIndices) { - save(ClusterMetaKey.PIPELINE_TASKS, taskIndices); - return this; - } - - /** - * Auto flush after save value. - */ - public void saveWindowId(Long windowId, long pipelineTaskId) { - save(ClusterMetaKey.WINDOW_ID, windowId, pipelineTaskId); - } - - public ClusterMetaStore saveCycle(Object cycle, long pipelineTaskId) { - save(ClusterMetaKey.CYCLE, cycle, pipelineTaskId); - return this; - } - - public ClusterMetaStore saveEvent(List event) { - save(ClusterMetaKey.EVENTS, event); - return this; - } - - /** - * Auto flush after save value. - */ - public void saveWorkers(WorkerSnapshot workers) { - save(ClusterMetaKey.WORKERS, workers); - } - - public ClusterMetaStore saveContainerIds(Map containerIds) { - save(ClusterMetaKey.CONTAINER_IDS, containerIds); - return this; - } - - public ClusterMetaStore saveDriverIds(Map driverIds) { - save(ClusterMetaKey.DRIVER_IDS, driverIds); - return this; - } - - public ClusterMetaStore saveMaxContainerId(int containerId) { - save(ClusterMetaKey.MAX_CONTAINER_ID, containerId); - return this; - } - - public Pipeline getPipeline() { - return get(ClusterMetaKey.PIPELINE); - } - - public List getPipelineTaskIds() { - return get(ClusterMetaKey.PIPELINE_TASK_IDS); - } - - public List getPipelineTasks() { - return get(ClusterMetaKey.PIPELINE_TASKS); - } - - public Long getWindowId(long pipelineTaskId) { - return get(ClusterMetaKey.WINDOW_ID, pipelineTaskId); - } - - public Object getCycle(long pipelineTaskId) { - return get(ClusterMetaKey.CYCLE, pipelineTaskId); - } - - public List getEvents() { - return get(ClusterMetaKey.EVENTS); - } + private IClusterMetaKVStore createBackend(String namespace) { + String storeKey = String.format("%s/%s", CLUSTER_META_NAMESPACE_LABEL, namespace); + IClusterMetaKVStore backend = + ClusterMetaStoreFactory.create(storeKey, componentId, configuration); + LOGGER.info("create ClusterMetaStore, store key {}, id {}", storeKey, componentId); + return backend; + } - public WorkerSnapshot getWorkers() { - return get(ClusterMetaKey.WORKERS); - } + public enum ClusterMetaKey { + PIPELINE, + PIPELINE_TASK_IDS, + PIPELINE_TASKS, + WINDOW_ID, + CYCLE, + EVENTS, + WORKERS, + CONTAINER_IDS, + DRIVER_IDS, + MAX_CONTAINER_ID + } - public int getMaxContainerId() { - return get(ClusterMetaKey.MAX_CONTAINER_ID); - } + /** A proxy that flush immediately once put entry to store. */ + private class ClusterMetaKVStoreProxy implements IClusterMetaKVStore { - public Map getContainerIds() { - return get(ClusterMetaKey.CONTAINER_IDS); + public ClusterMetaKVStoreProxy(IClusterMetaKVStore store) { + this.store = store; } - public Map getDriverIds() { - return get(ClusterMetaKey.DRIVER_IDS); - } - - public void flush() { - componentBackend.flush(); - } + private IClusterMetaKVStore store; - public void clean() { - // TODO Clean namespace directly from backend. + @Override + public void init(StoreContext storeContext) { + store.init(storeContext); } - private void save(ClusterMetaKey key, T value) { - getBackend(key).put(key.name(), value); - } - - private void save(ClusterMetaKey key, T value, long pipelineTaskId) { - getBackend(key).put(getKeyTag(key.name(), pipelineTaskId), value); - } - - private T get(ClusterMetaKey key) { - return (T) getBackend(key).get(key.name()); - } - - private T get(ClusterMetaKey key, long pipelineTaskId) { - return (T) getBackend(key).get(getKeyTag(key.name(), pipelineTaskId)); - } - - private String getKeyTag(String key, long pipelineTaskId) { - return String.format("%s#%s", key, pipelineTaskId); - } + @Override + public void flush() {} - private IClusterMetaKVStore getBackend(ClusterMetaKey metaKey) { - String namespace; - switch (metaKey) { - case WORKERS: - namespace = String.format("%s/%s/%s", CLUSTER_NAMESPACE_PREFIX, clusterId, metaKey.name().toLowerCase()); - break; - case WINDOW_ID: - namespace = OFFSET_NAMESPACE; - break; - default: - return componentBackend; - } - // Cluster meta store is closed. - if (backends == null) { - return null; - } - if (!backends.containsKey(namespace)) { - synchronized (ClusterMetaStore.class) { - if (!backends.containsKey(namespace)) { - IClusterMetaKVStore backend = createBackend(namespace); - backends.put(namespace, new ClusterMetaKVStoreProxy<>(backend)); - } - } - } - return backends.get(namespace); + @Override + public void close() { + store.close(); } - private IClusterMetaKVStore createBackend(String namespace) { - String storeKey = String.format("%s/%s", CLUSTER_META_NAMESPACE_LABEL, namespace); - IClusterMetaKVStore backend = - ClusterMetaStoreFactory.create(storeKey, componentId, configuration); - LOGGER.info("create ClusterMetaStore, store key {}, id {}", storeKey, componentId); - return backend; + @Override + public V get(K key) { + return store.get(key); } - public enum ClusterMetaKey { - - PIPELINE, - PIPELINE_TASK_IDS, - PIPELINE_TASKS, - WINDOW_ID, - CYCLE, - EVENTS, - WORKERS, - CONTAINER_IDS, - DRIVER_IDS, - MAX_CONTAINER_ID + @Override + public void put(K key, V value) { + store.put(key, value); + store.flush(); } - /** - * A proxy that flush immediately once put entry to store. - */ - private class ClusterMetaKVStoreProxy implements IClusterMetaKVStore { - - public ClusterMetaKVStoreProxy(IClusterMetaKVStore store) { - this.store = store; - } - - private IClusterMetaKVStore store; - - @Override - public void init(StoreContext storeContext) { - store.init(storeContext); - } - - @Override - public void flush() { - } - - @Override - public void close() { - store.close(); - } - - @Override - public V get(K key) { - return store.get(key); - } - - @Override - public void put(K key, V value) { - store.put(key, value); - store.flush(); - } - - @Override - public void remove(K key) { - store.remove(key); - } + @Override + public void remove(K key) { + store.remove(key); } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/system/ClusterMetaStoreFactory.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/system/ClusterMetaStoreFactory.java index 496cd575e..3abe0efb6 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/system/ClusterMetaStoreFactory.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/system/ClusterMetaStoreFactory.java @@ -29,37 +29,38 @@ public class ClusterMetaStoreFactory { - private static final Logger LOGGER = LoggerFactory.getLogger(ClusterMetaStoreFactory.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ClusterMetaStoreFactory.class); - private static final int DEFAULT_SHARD_ID = 0; + private static final int DEFAULT_SHARD_ID = 0; - public static IClusterMetaKVStore create(String name, Configuration configuration) { - return create(name, DEFAULT_SHARD_ID, configuration); - } + public static IClusterMetaKVStore create(String name, Configuration configuration) { + return create(name, DEFAULT_SHARD_ID, configuration); + } - public static IClusterMetaKVStore create(String name, int shardId, Configuration configuration) { - StoreContext storeContext = new StoreContext(name); - storeContext.withKeySerializer(new DefaultKVSerializer(null, null)); - storeContext.withConfig(configuration); - storeContext.withShardId(shardId); + public static IClusterMetaKVStore create( + String name, int shardId, Configuration configuration) { + StoreContext storeContext = new StoreContext(name); + storeContext.withKeySerializer(new DefaultKVSerializer(null, null)); + storeContext.withConfig(configuration); + storeContext.withShardId(shardId); - String backendType = configuration.getString(FrameworkConfigKeys.SYSTEM_STATE_BACKEND_TYPE); - IClusterMetaKVStore store = create(StoreType.getEnum(backendType)); - store.init(storeContext); - return store; - } + String backendType = configuration.getString(FrameworkConfigKeys.SYSTEM_STATE_BACKEND_TYPE); + IClusterMetaKVStore store = create(StoreType.getEnum(backendType)); + store.init(storeContext); + return store; + } - private static IClusterMetaKVStore create(StoreType storeType) { - IClusterMetaKVStore clusterMetaKVStore; - switch (storeType) { - case ROCKSDB: - LOGGER.info("create rocksdb cluster metastore"); - clusterMetaKVStore = new RocksdbClusterMetaKVStore(); - break; - default: - LOGGER.info("create memory cluster metastore"); - clusterMetaKVStore = new MemoryClusterMetaKVStore(); - } - return clusterMetaKVStore; + private static IClusterMetaKVStore create(StoreType storeType) { + IClusterMetaKVStore clusterMetaKVStore; + switch (storeType) { + case ROCKSDB: + LOGGER.info("create rocksdb cluster metastore"); + clusterMetaKVStore = new RocksdbClusterMetaKVStore(); + break; + default: + LOGGER.info("create memory cluster metastore"); + clusterMetaKVStore = new MemoryClusterMetaKVStore(); } + return clusterMetaKVStore; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/system/IClusterMetaKVStore.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/system/IClusterMetaKVStore.java index df69b2679..07443f82a 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/system/IClusterMetaKVStore.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/system/IClusterMetaKVStore.java @@ -24,19 +24,12 @@ public interface IClusterMetaKVStore extends KeyValueTrait { - /** - * Init cluster meta kv store. - */ - void init(StoreContext storeContext); + /** Init cluster meta kv store. */ + void init(StoreContext storeContext); - /** - * Flush meta info into kv store. - */ - void flush(); - - /** - * Close cluster meta kv store. - */ - void close(); + /** Flush meta info into kv store. */ + void flush(); + /** Close cluster meta kv store. */ + void close(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/system/MemoryClusterMetaKVStore.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/system/MemoryClusterMetaKVStore.java index f81a09ae6..5177b9d81 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/system/MemoryClusterMetaKVStore.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/system/MemoryClusterMetaKVStore.java @@ -28,36 +28,34 @@ public class MemoryClusterMetaKVStore implements IClusterMetaKVStore { - private IKVStore kvStore; - - @Override - public void init(StoreContext storeContext) { - IStoreBuilder builder = StoreBuilderFactory.build(StoreType.MEMORY.name()); - kvStore = (IKVStore) builder.getStore(DataModel.KV, storeContext.getConfig()); - } - - @Override - public void flush() { - - } - - @Override - public V get(String key) { - return (V) kvStore.get(key); - } - - @Override - public void put(String key, V value) { - kvStore.put(key, value); - } - - @Override - public void remove(String key) { - kvStore.remove(key); - } - - @Override - public void close() { - kvStore.close(); - } + private IKVStore kvStore; + + @Override + public void init(StoreContext storeContext) { + IStoreBuilder builder = StoreBuilderFactory.build(StoreType.MEMORY.name()); + kvStore = (IKVStore) builder.getStore(DataModel.KV, storeContext.getConfig()); + } + + @Override + public void flush() {} + + @Override + public V get(String key) { + return (V) kvStore.get(key); + } + + @Override + public void put(String key, V value) { + kvStore.put(key, value); + } + + @Override + public void remove(String key) { + kvStore.remove(key); + } + + @Override + public void close() { + kvStore.close(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/system/RocksdbClusterMetaKVStore.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/system/RocksdbClusterMetaKVStore.java index 68ecb49c9..f786e7243 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/system/RocksdbClusterMetaKVStore.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/system/RocksdbClusterMetaKVStore.java @@ -30,57 +30,57 @@ public class RocksdbClusterMetaKVStore implements IClusterMetaKVStore { - private static final Logger LOGGER = LoggerFactory.getLogger(RocksdbClusterMetaKVStore.class); + private static final Logger LOGGER = LoggerFactory.getLogger(RocksdbClusterMetaKVStore.class); - private static final Integer DEFAULT_VERSION = 1; + private static final Integer DEFAULT_VERSION = 1; - private IKVStatefulStore kvStore; - private transient long version; - private String name; + private IKVStatefulStore kvStore; + private transient long version; + private String name; - @Override - public void init(StoreContext storeContext) { - IStoreBuilder builder = StoreBuilderFactory.build(StoreType.ROCKSDB.name()); - this.name = storeContext.getName(); - kvStore = (IKVStatefulStore) builder.getStore(DataModel.KV, - storeContext.getConfig()); - kvStore.init(storeContext); + @Override + public void init(StoreContext storeContext) { + IStoreBuilder builder = StoreBuilderFactory.build(StoreType.ROCKSDB.name()); + this.name = storeContext.getName(); + kvStore = + (IKVStatefulStore) builder.getStore(DataModel.KV, storeContext.getConfig()); + kvStore.init(storeContext); - // recovery - long latest = kvStore.recoveryLatest(); - if (latest > 0) { - LOGGER.info("recovery to latest version {}", latest); - version = latest + 1; - } else { - LOGGER.info("not found any version to recovery"); - version = DEFAULT_VERSION; - } + // recovery + long latest = kvStore.recoveryLatest(); + if (latest > 0) { + LOGGER.info("recovery to latest version {}", latest); + version = latest + 1; + } else { + LOGGER.info("not found any version to recovery"); + version = DEFAULT_VERSION; } + } - @Override - public void flush() { - LOGGER.info("cluster meta {} do flush", name); - kvStore.archive(version); - version++; - } + @Override + public void flush() { + LOGGER.info("cluster meta {} do flush", name); + kvStore.archive(version); + version++; + } - @Override - public V get(K key) { - return (V) kvStore.get(key); - } + @Override + public V get(K key) { + return (V) kvStore.get(key); + } - @Override - public void put(K key, V value) { - kvStore.put(key, value); - } + @Override + public void put(K key, V value) { + kvStore.put(key, value); + } - @Override - public void remove(K key) { - kvStore.remove(key); - } + @Override + public void remove(K key) { + kvStore.remove(key); + } - @Override - public void close() { - kvStore.close(); - } + @Override + public void close() { + kvStore.close(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/ITask.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/ITask.java index 0a5524d53..b1ce1a9e5 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/ITask.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/ITask.java @@ -23,23 +23,15 @@ public interface ITask { - /** - * Init task context. - */ - void init(ITaskContext taskContext); + /** Init task context. */ + void init(ITaskContext taskContext); - /** - * Execute corresponding command. - */ - void execute(IExecutableCommand command); + /** Execute corresponding command. */ + void execute(IExecutableCommand command); - /** - * Terminate current running task. - */ - void interrupt(); + /** Terminate current running task. */ + void interrupt(); - /** - * Close task resource. - */ - void close(); + /** Close task resource. */ + void close(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/ITaskContext.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/ITaskContext.java index b42150883..dae7b8bd1 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/ITaskContext.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/ITaskContext.java @@ -27,43 +27,27 @@ public interface ITaskContext { - /** - * Register worker into task context. - */ - void registerWorker(IWorker worker); + /** Register worker into task context. */ + void registerWorker(IWorker worker); - /** - * Returns the worker. - */ - IWorker getWorker(); + /** Returns the worker. */ + IWorker getWorker(); - /** - * Returns the fetcher service. - */ - FetcherService getFetcherService(); + /** Returns the fetcher service. */ + FetcherService getFetcherService(); - /** - * Returns the emitter service. - */ - EmitterService getEmitterService(); + /** Returns the emitter service. */ + EmitterService getEmitterService(); - /** - * Returns the worker index. - */ - int getWorkerIndex(); + /** Returns the worker index. */ + int getWorkerIndex(); - /** - * Returns config. - */ - Configuration getConfig(); + /** Returns config. */ + Configuration getConfig(); - /** - * Returns the metric group ref. - */ - MetricGroup getMetricGroup(); + /** Returns the metric group ref. */ + MetricGroup getMetricGroup(); - /** - * Close worker and fetcher/emitter service. - */ - void close(); + /** Close worker and fetcher/emitter service. */ + void close(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/Task.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/Task.java index edade4ca4..97c485b78 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/Task.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/Task.java @@ -23,34 +23,32 @@ public class Task implements ITask { - private ITaskContext context; - private IExecutableCommand command; - - public Task() { - } - - @Override - public void init(ITaskContext taskContext) { - this.context = taskContext; - this.command = null; - } - - @Override - public void execute(IExecutableCommand command) { - this.command = command; - command.execute(this.context); - } - - @Override - public void interrupt() { - if (command != null) { - command.interrupt(); - } - } - - @Override - public void close() { - this.context.close(); + private ITaskContext context; + private IExecutableCommand command; + + public Task() {} + + @Override + public void init(ITaskContext taskContext) { + this.context = taskContext; + this.command = null; + } + + @Override + public void execute(IExecutableCommand command) { + this.command = command; + command.execute(this.context); + } + + @Override + public void interrupt() { + if (command != null) { + command.interrupt(); } + } + @Override + public void close() { + this.context.close(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/TaskContext.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/TaskContext.java index 267afbb35..3c4a421d4 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/TaskContext.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/TaskContext.java @@ -28,65 +28,62 @@ public class TaskContext implements ITaskContext { - private int workerIndex; - private Configuration config; - private MetricGroup metricGroup; + private int workerIndex; + private Configuration config; + private MetricGroup metricGroup; - private IWorker worker; - private FetcherService fetcherService; - private EmitterService emitterService; + private IWorker worker; + private FetcherService fetcherService; + private EmitterService emitterService; - public TaskContext(ITaskRunnerContext taskContext) { - this.workerIndex = taskContext.getWorkerIndex(); - this.config = taskContext.getConfig(); - this.metricGroup = taskContext.getMetricGroup(); - this.fetcherService = taskContext.getFetcherService(); - this.emitterService = taskContext.getEmitterService(); - } + public TaskContext(ITaskRunnerContext taskContext) { + this.workerIndex = taskContext.getWorkerIndex(); + this.config = taskContext.getConfig(); + this.metricGroup = taskContext.getMetricGroup(); + this.fetcherService = taskContext.getFetcherService(); + this.emitterService = taskContext.getEmitterService(); + } - @Override - public IWorker getWorker() { - return worker; - } + @Override + public IWorker getWorker() { + return worker; + } - @Override - public void registerWorker(IWorker worker) { - this.worker = worker; - } + @Override + public void registerWorker(IWorker worker) { + this.worker = worker; + } - @Override - public FetcherService getFetcherService() { - return fetcherService; - } + @Override + public FetcherService getFetcherService() { + return fetcherService; + } - @Override - public EmitterService getEmitterService() { - return emitterService; - } + @Override + public EmitterService getEmitterService() { + return emitterService; + } - @Override - public int getWorkerIndex() { - return workerIndex; - } + @Override + public int getWorkerIndex() { + return workerIndex; + } - @Override - public Configuration getConfig() { - return config; - } + @Override + public Configuration getConfig() { + return config; + } - @Override - public MetricGroup getMetricGroup() { - return metricGroup; - } - - /** - * Close worker and io resources. - */ - @Override - public void close() { - worker.close(); - fetcherService.shutdown(); - emitterService.shutdown(); - } + @Override + public MetricGroup getMetricGroup() { + return metricGroup; + } + /** Close worker and io resources. */ + @Override + public void close() { + worker.close(); + fetcherService.shutdown(); + emitterService.shutdown(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/runner/AbstractTaskRunner.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/runner/AbstractTaskRunner.java index 59f97b477..83107cea5 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/runner/AbstractTaskRunner.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/runner/AbstractTaskRunner.java @@ -21,6 +21,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.exception.GeaflowInterruptedException; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.slf4j.Logger; @@ -28,48 +29,48 @@ public abstract class AbstractTaskRunner implements ITaskRunner { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractTaskRunner.class); - private static final int POOL_TIMEOUT = 100; + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractTaskRunner.class); + private static final int POOL_TIMEOUT = 100; - private final LinkedBlockingQueue taskQueue; - protected volatile boolean running; + private final LinkedBlockingQueue taskQueue; + protected volatile boolean running; - public AbstractTaskRunner() { - this.running = true; - this.taskQueue = new LinkedBlockingQueue<>(); - } + public AbstractTaskRunner() { + this.running = true; + this.taskQueue = new LinkedBlockingQueue<>(); + } - @Override - public void run() { - while (running) { - try { - TASK task = taskQueue.poll(POOL_TIMEOUT, TimeUnit.MILLISECONDS); - if (running && task != null) { - process(task); - } - } catch (InterruptedException e) { - throw new GeaflowInterruptedException(e); - } catch (Throwable t) { - LOGGER.error(t.getMessage(), t); - throw new GeaflowRuntimeException(t); - } + @Override + public void run() { + while (running) { + try { + TASK task = taskQueue.poll(POOL_TIMEOUT, TimeUnit.MILLISECONDS); + if (running && task != null) { + process(task); } + } catch (InterruptedException e) { + throw new GeaflowInterruptedException(e); + } catch (Throwable t) { + LOGGER.error(t.getMessage(), t); + throw new GeaflowRuntimeException(t); + } } + } - @Override - public void add(TASK task) { - this.taskQueue.add(task); - } + @Override + public void add(TASK task) { + this.taskQueue.add(task); + } - protected abstract void process(TASK task); + protected abstract void process(TASK task); - @Override - public void interrupt() { - // TODO interrupt running task. - } + @Override + public void interrupt() { + // TODO interrupt running task. + } - @Override - public void shutdown() { - this.running = false; - } + @Override + public void shutdown() { + this.running = false; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/runner/ITaskRunner.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/runner/ITaskRunner.java index 469952708..18884a23b 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/runner/ITaskRunner.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/runner/ITaskRunner.java @@ -21,19 +21,12 @@ public interface ITaskRunner extends Runnable { - /** - * Add task into processing queue. - */ - void add(TASK task); + /** Add task into processing queue. */ + void add(TASK task); - /** - * Interrupt current running task events. - */ - void interrupt(); - - /** - * Shutdown task runner. - */ - void shutdown(); + /** Interrupt current running task events. */ + void interrupt(); + /** Shutdown task runner. */ + void shutdown(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/runner/ITaskRunnerContext.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/runner/ITaskRunnerContext.java index cd7fe973c..26f12eaa6 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/runner/ITaskRunnerContext.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/runner/ITaskRunnerContext.java @@ -20,6 +20,7 @@ package org.apache.geaflow.cluster.task.runner; import java.io.Serializable; + import org.apache.geaflow.cluster.collector.EmitterService; import org.apache.geaflow.cluster.fetcher.FetcherService; import org.apache.geaflow.common.config.Configuration; @@ -27,33 +28,21 @@ public interface ITaskRunnerContext extends Serializable { - /** - * Returns container id. - */ - int getContainerId(); - - /** - * Returns index of current task. - */ - int getWorkerIndex(); - - /** - * Returns the worker config. - */ - Configuration getConfig(); - - /** - * Returns the metric group ref. - */ - MetricGroup getMetricGroup(); - - /** - * Returns the fetcher service. - */ - FetcherService getFetcherService(); - - /** - * Returns teh emitter service. - */ - EmitterService getEmitterService(); + /** Returns container id. */ + int getContainerId(); + + /** Returns index of current task. */ + int getWorkerIndex(); + + /** Returns the worker config. */ + Configuration getConfig(); + + /** Returns the metric group ref. */ + MetricGroup getMetricGroup(); + + /** Returns the fetcher service. */ + FetcherService getFetcherService(); + + /** Returns teh emitter service. */ + EmitterService getEmitterService(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/runner/TaskRunner.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/runner/TaskRunner.java index 368933338..95e6fb913 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/runner/TaskRunner.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/runner/TaskRunner.java @@ -28,44 +28,44 @@ public class TaskRunner extends AbstractTaskRunner { - private static final Logger LOGGER = LoggerFactory.getLogger(TaskRunner.class); - private Task task; - private ITaskRunnerContext taskRunnerContext; + private static final Logger LOGGER = LoggerFactory.getLogger(TaskRunner.class); + private Task task; + private ITaskRunnerContext taskRunnerContext; - public TaskRunner() { - super(); - } + public TaskRunner() { + super(); + } - public void init(ITaskRunnerContext taskRunnerContext) { - this.taskRunnerContext = taskRunnerContext; - } + public void init(ITaskRunnerContext taskRunnerContext) { + this.taskRunnerContext = taskRunnerContext; + } - @Override - protected void process(ICommand command) { - LOGGER.info("task Executor:{}", command); - switch (command.getEventType()) { - case CREATE_TASK: - // Starting of task's life cycle. - task = new Task(); - task.init(new TaskContext(taskRunnerContext)); - break; - case DESTROY_TASK: - // Ending of task's life cycle. - task.close(); - task = null; - break; - default: - // Execute task command. - task.execute((IExecutableCommand) command); - break; - } + @Override + protected void process(ICommand command) { + LOGGER.info("task Executor:{}", command); + switch (command.getEventType()) { + case CREATE_TASK: + // Starting of task's life cycle. + task = new Task(); + task.init(new TaskContext(taskRunnerContext)); + break; + case DESTROY_TASK: + // Ending of task's life cycle. + task.close(); + task = null; + break; + default: + // Execute task command. + task.execute((IExecutableCommand) command); + break; } + } - @Override - public void interrupt() { - super.interrupt(); - if (task != null) { - task.interrupt(); - } + @Override + public void interrupt() { + super.interrupt(); + if (task != null) { + task.interrupt(); } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/runner/TaskRunnerContext.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/runner/TaskRunnerContext.java index bd50e53ff..7900c981c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/runner/TaskRunnerContext.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/runner/TaskRunnerContext.java @@ -26,52 +26,55 @@ public class TaskRunnerContext implements ITaskRunnerContext { - private final int workerIndex; - private final int containerId; - private final Configuration config; - private final MetricGroup metricGroup; - private final FetcherService fetcherService; - private final EmitterService emitterService; + private final int workerIndex; + private final int containerId; + private final Configuration config; + private final MetricGroup metricGroup; + private final FetcherService fetcherService; + private final EmitterService emitterService; - public TaskRunnerContext(int containerId, int workerIndex, Configuration config, - MetricGroup metricGroup, FetcherService fetcherService, - EmitterService emitterService) { - this.containerId = containerId; - this.workerIndex = workerIndex; - this.config = config; - this.metricGroup = metricGroup; - this.fetcherService = fetcherService; - this.emitterService = emitterService; - } + public TaskRunnerContext( + int containerId, + int workerIndex, + Configuration config, + MetricGroup metricGroup, + FetcherService fetcherService, + EmitterService emitterService) { + this.containerId = containerId; + this.workerIndex = workerIndex; + this.config = config; + this.metricGroup = metricGroup; + this.fetcherService = fetcherService; + this.emitterService = emitterService; + } - @Override - public int getContainerId() { - return containerId; - } + @Override + public int getContainerId() { + return containerId; + } - @Override - public int getWorkerIndex() { - return workerIndex; - } + @Override + public int getWorkerIndex() { + return workerIndex; + } - @Override - public Configuration getConfig() { - return config; - } + @Override + public Configuration getConfig() { + return config; + } - @Override - public MetricGroup getMetricGroup() { - return metricGroup; - } + @Override + public MetricGroup getMetricGroup() { + return metricGroup; + } - @Override - public FetcherService getFetcherService() { - return fetcherService; - } - - @Override - public EmitterService getEmitterService() { - return emitterService; - } + @Override + public FetcherService getFetcherService() { + return fetcherService; + } + @Override + public EmitterService getEmitterService() { + return emitterService; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/service/AbstractTaskService.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/service/AbstractTaskService.java index 9dfc2267d..75cf74df1 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/service/AbstractTaskService.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/service/AbstractTaskService.java @@ -19,8 +19,8 @@ package org.apache.geaflow.cluster.task.service; -import com.google.common.base.Preconditions; import java.util.concurrent.ExecutorService; + import org.apache.geaflow.cluster.exception.ComponentUncaughtExceptionHandler; import org.apache.geaflow.cluster.task.runner.ITaskRunner; import org.apache.geaflow.common.config.Configuration; @@ -30,66 +30,73 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class AbstractTaskService> implements ITaskService { +import com.google.common.base.Preconditions; - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractTaskService.class); +public abstract class AbstractTaskService> + implements ITaskService { - protected ExecutorService executorService; - private R[] tasks; - private String threadFormat; - protected final Configuration configuration; + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractTaskService.class); - public AbstractTaskService(Configuration configuration, String threadFormat) { - this.threadFormat = threadFormat; - this.configuration = configuration; - } + protected ExecutorService executorService; + private R[] tasks; + private String threadFormat; + protected final Configuration configuration; - public void start() { - this.tasks = buildTaskRunner(); - Preconditions.checkArgument(tasks != null && tasks.length != 0, "must specify at least one task"); - this.executorService = Executors.getExecutorService(getMaxMultiple(), tasks.length, threadFormat, - ComponentUncaughtExceptionHandler.INSTANCE); - for (int i = 0; i < tasks.length; i++) { - executorService.execute(tasks[i]); - } - } + public AbstractTaskService(Configuration configuration, String threadFormat) { + this.threadFormat = threadFormat; + this.configuration = configuration; + } - /** - * Provides the maximum thread multiplier value. - * - * @return the maximum thread multiplier - */ - protected int getMaxMultiple() { - return configuration.getInteger(ExecutionConfigKeys.EXECUTOR_MAX_MULTIPLE); + public void start() { + this.tasks = buildTaskRunner(); + Preconditions.checkArgument( + tasks != null && tasks.length != 0, "must specify at least one task"); + this.executorService = + Executors.getExecutorService( + getMaxMultiple(), + tasks.length, + threadFormat, + ComponentUncaughtExceptionHandler.INSTANCE); + for (int i = 0; i < tasks.length; i++) { + executorService.execute(tasks[i]); } + } - public void process(int workerId, TASK task) { - tasks[workerId].add(task); - } + /** + * Provides the maximum thread multiplier value. + * + * @return the maximum thread multiplier + */ + protected int getMaxMultiple() { + return configuration.getInteger(ExecutionConfigKeys.EXECUTOR_MAX_MULTIPLE); + } - @Override - public void interrupt(int workerId) { - // TODO Interrupt specified worker running task. - // 1. Try interrupt task runner. - // 2. If failed or timeout, try shutdown and then rebuild executor service. - // 3. If failed or timeout, report exception, may need exit process. - tasks[workerId].interrupt(); - } + public void process(int workerId, TASK task) { + tasks[workerId].add(task); + } - @Override - public void shutdown() { - LOGGER.info("shutdown executor service {}", threadFormat); - for (int i = 0; i < tasks.length; i++) { - tasks[i].shutdown(); - } - // try shutdown executor service - ExecutorUtil.shutdown(executorService); - } + @Override + public void interrupt(int workerId) { + // TODO Interrupt specified worker running task. + // 1. Try interrupt task runner. + // 2. If failed or timeout, try shutdown and then rebuild executor service. + // 3. If failed or timeout, report exception, may need exit process. + tasks[workerId].interrupt(); + } - public R getRunner(int workerId) { - return tasks[workerId]; + @Override + public void shutdown() { + LOGGER.info("shutdown executor service {}", threadFormat); + for (int i = 0; i < tasks.length; i++) { + tasks[i].shutdown(); } + // try shutdown executor service + ExecutorUtil.shutdown(executorService); + } - protected abstract R[] buildTaskRunner(); + public R getRunner(int workerId) { + return tasks[workerId]; + } + protected abstract R[] buildTaskRunner(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/service/ITaskService.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/service/ITaskService.java index acfb9aeef..70a29a5c7 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/service/ITaskService.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/service/ITaskService.java @@ -21,23 +21,15 @@ public interface ITaskService { - /** - * Start task service. - */ - void start(); + /** Start task service. */ + void start(); - /** - * Process event on workerId. - */ - void process(int workerId, TASK task); + /** Process event on workerId. */ + void process(int workerId, TASK task); - /** - * Interrupt running task worker. - */ - void interrupt(int workerId); + /** Interrupt running task worker. */ + void interrupt(int workerId); - /** - * Shutdown task service. - */ - void shutdown(); + /** Shutdown task service. */ + void shutdown(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/service/TaskService.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/service/TaskService.java index a39a86b81..b8dbb100d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/service/TaskService.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/task/service/TaskService.java @@ -31,44 +31,54 @@ public class TaskService extends AbstractTaskService { - private static final Logger LOGGER = LoggerFactory.getLogger(TaskService.class); + private static final Logger LOGGER = LoggerFactory.getLogger(TaskService.class); - private static final String WORKER_FORMAT = "geaflow-worker-%d"; + private static final String WORKER_FORMAT = "geaflow-worker-%d"; - private int containerId; - private int taskNum; - private MetricGroup metricGroup; - private FetcherService fetcherService; - private EmitterService emitterService; + private int containerId; + private int taskNum; + private MetricGroup metricGroup; + private FetcherService fetcherService; + private EmitterService emitterService; - public TaskService(int containerId, int taskNum, Configuration configuration, - MetricGroup metricGroup, FetcherService fetcherService, EmitterService emitterService) { - super(configuration, WORKER_FORMAT); - this.containerId = containerId; - this.taskNum = taskNum; - this.metricGroup = metricGroup; - this.fetcherService = fetcherService; - this.emitterService = emitterService; - } + public TaskService( + int containerId, + int taskNum, + Configuration configuration, + MetricGroup metricGroup, + FetcherService fetcherService, + EmitterService emitterService) { + super(configuration, WORKER_FORMAT); + this.containerId = containerId; + this.taskNum = taskNum; + this.metricGroup = metricGroup; + this.fetcherService = fetcherService; + this.emitterService = emitterService; + } - @Override - protected TaskRunner[] buildTaskRunner() { - TaskRunner[] taskRunners = new TaskRunner[taskNum]; - for (int i = 0; i < taskNum; i++) { - TaskRunner taskRunner = buildTask(containerId, i, configuration, - metricGroup, fetcherService, emitterService); - taskRunners[i] = taskRunner; - } - return taskRunners; + @Override + protected TaskRunner[] buildTaskRunner() { + TaskRunner[] taskRunners = new TaskRunner[taskNum]; + for (int i = 0; i < taskNum; i++) { + TaskRunner taskRunner = + buildTask(containerId, i, configuration, metricGroup, fetcherService, emitterService); + taskRunners[i] = taskRunner; } + return taskRunners; + } - private TaskRunner buildTask(int containerId, int taskIndex, - Configuration configuration, MetricGroup metricGroup, - FetcherService fetcherService, EmitterService emitterService) { - TaskRunner taskRunner = new TaskRunner(); - TaskRunnerContext taskRunnerContext = new TaskRunnerContext(containerId, taskIndex, - configuration, metricGroup, fetcherService, emitterService); - taskRunner.init(taskRunnerContext); - return taskRunner; - } + private TaskRunner buildTask( + int containerId, + int taskIndex, + Configuration configuration, + MetricGroup metricGroup, + FetcherService fetcherService, + EmitterService emitterService) { + TaskRunner taskRunner = new TaskRunner(); + TaskRunnerContext taskRunnerContext = + new TaskRunnerContext( + containerId, taskIndex, configuration, metricGroup, fetcherService, emitterService); + taskRunner.init(taskRunnerContext); + return taskRunner; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/HttpServer.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/HttpServer.java index 60a89079d..f45f8fb8d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/HttpServer.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/HttpServer.java @@ -22,6 +22,7 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.MASTER_HTTP_PORT; import java.net.URL; + import org.apache.geaflow.cluster.clustermanager.IClusterManager; import org.apache.geaflow.cluster.common.ComponentInfo; import org.apache.geaflow.cluster.heartbeat.HeartbeatManager; @@ -53,109 +54,110 @@ public class HttpServer { - private static final Logger LOGGER = LoggerFactory.getLogger(HttpServer.class); - private static final String SERVER_NAME = "jetty-server"; - private static final int DEFAULT_ACCEPT_QUEUE_SIZE = 8; - private static final int HTTP_NOTFOUND_CODE = 404; - private static final String STATIC_RESOURCES_FOLDER_PATH = "dist"; - - private final Server server; - private final int httpPort; - private final QueuedThreadPool threadPool; - private final ScheduledExecutorScheduler serverExecutor; - - public HttpServer(Configuration configuration, IClusterManager clusterManager, - HeartbeatManager heartbeatManager, IResourceManager resourceManager, - ComponentInfo masterInfo) { - httpPort = configuration.getInteger(MASTER_HTTP_PORT); - threadPool = new QueuedThreadPool(); - threadPool.setDaemon(true); - threadPool.setName(SERVER_NAME); - server = new Server(threadPool); - - ErrorHandler errorHandler = new ErrorHandler(); - errorHandler.setShowStacks(true); - errorHandler.setServer(server); - server.addBean(errorHandler); - - MetricCache metricCache = new MetricCache(); - MetricFetcher metricFetcher = new MetricFetcher(configuration, clusterManager, - metricCache); - - ResourceConfig resourceConfig = new ResourceConfig(); - resourceConfig.register(new MasterRestHandler(masterInfo, configuration)); - resourceConfig.register(new ClusterRestHandler(clusterManager, heartbeatManager, - resourceManager, metricFetcher)); - resourceConfig.register(new PipelineRestHandler(metricCache, metricFetcher)); - - ServletContextHandler handler = new ServletContextHandler(ServletContextHandler.NO_SESSIONS); - handler.setContextPath("/"); - handler.addServlet(new ServletHolder(new ProxyHandler()), "/proxy/*"); - handler.addServlet(new ServletHolder(new ServletContainer(resourceConfig)), "/rest/*"); - handler.addServlet(new ServletHolder(new DefaultServlet()), "/"); - - try { - URL resourcePath = - HttpServer.class.getClassLoader().getResource(STATIC_RESOURCES_FOLDER_PATH); - LOGGER.info("Try Loading static resources of path: {}", resourcePath); - handler.setBaseResource(Resource.newResource(resourcePath)); - } catch (Exception e) { - LOGGER.error("Failed to load static resources. {}", e.getMessage(), e); - } - - ErrorPageErrorHandler errorPageHandler = new ErrorPageErrorHandler(); - errorPageHandler.addErrorPage(HTTP_NOTFOUND_CODE, "/"); - handler.setErrorHandler(errorPageHandler); - - server.setHandler(handler); - - serverExecutor = new ScheduledExecutorScheduler("jetty-scheduler", true); + private static final Logger LOGGER = LoggerFactory.getLogger(HttpServer.class); + private static final String SERVER_NAME = "jetty-server"; + private static final int DEFAULT_ACCEPT_QUEUE_SIZE = 8; + private static final int HTTP_NOTFOUND_CODE = 404; + private static final String STATIC_RESOURCES_FOLDER_PATH = "dist"; + + private final Server server; + private final int httpPort; + private final QueuedThreadPool threadPool; + private final ScheduledExecutorScheduler serverExecutor; + + public HttpServer( + Configuration configuration, + IClusterManager clusterManager, + HeartbeatManager heartbeatManager, + IResourceManager resourceManager, + ComponentInfo masterInfo) { + httpPort = configuration.getInteger(MASTER_HTTP_PORT); + threadPool = new QueuedThreadPool(); + threadPool.setDaemon(true); + threadPool.setName(SERVER_NAME); + server = new Server(threadPool); + + ErrorHandler errorHandler = new ErrorHandler(); + errorHandler.setShowStacks(true); + errorHandler.setServer(server); + server.addBean(errorHandler); + + MetricCache metricCache = new MetricCache(); + MetricFetcher metricFetcher = new MetricFetcher(configuration, clusterManager, metricCache); + + ResourceConfig resourceConfig = new ResourceConfig(); + resourceConfig.register(new MasterRestHandler(masterInfo, configuration)); + resourceConfig.register( + new ClusterRestHandler(clusterManager, heartbeatManager, resourceManager, metricFetcher)); + resourceConfig.register(new PipelineRestHandler(metricCache, metricFetcher)); + + ServletContextHandler handler = new ServletContextHandler(ServletContextHandler.NO_SESSIONS); + handler.setContextPath("/"); + handler.addServlet(new ServletHolder(new ProxyHandler()), "/proxy/*"); + handler.addServlet(new ServletHolder(new ServletContainer(resourceConfig)), "/rest/*"); + handler.addServlet(new ServletHolder(new DefaultServlet()), "/"); + + try { + URL resourcePath = + HttpServer.class.getClassLoader().getResource(STATIC_RESOURCES_FOLDER_PATH); + LOGGER.info("Try Loading static resources of path: {}", resourcePath); + handler.setBaseResource(Resource.newResource(resourcePath)); + } catch (Exception e) { + LOGGER.error("Failed to load static resources. {}", e.getMessage(), e); } - public void start() { - try { - ServerConnector connector = newConnector(server, serverExecutor, null, httpPort); - connector.setName(SERVER_NAME); - server.addConnector(connector); - - int minThreads = 1; - minThreads += connector.getAcceptors() * 2; - threadPool.setMaxThreads(Math.max(threadPool.getMaxThreads(), minThreads)); - - server.start(); - LOGGER.info("Jetty Server started: {}", httpPort); - } catch (Exception e) { - LOGGER.error("jetty server failed:", e); - throw new GeaflowRuntimeException(e); - } - } + ErrorPageErrorHandler errorPageHandler = new ErrorPageErrorHandler(); + errorPageHandler.addErrorPage(HTTP_NOTFOUND_CODE, "/"); + handler.setErrorHandler(errorPageHandler); - public void stop() { - try { - server.stop(); - if (threadPool.isStarted()) { - threadPool.stop(); - } - if (serverExecutor.isStarted()) { - serverExecutor.stop(); - } - } catch (Exception e) { - LOGGER.warn("stop jetty server failed", e); - throw new GeaflowRuntimeException(e); - } - } + server.setHandler(handler); - private ServerConnector newConnector(Server server, ScheduledExecutorScheduler serverExecutor, - String hostName, int port) throws Exception { - ConnectionFactory[] connectionFactories = new ConnectionFactory[]{ - new HttpConnectionFactory()}; - ServerConnector connector = new ServerConnector(server, null, serverExecutor, null, -1, -1, - connectionFactories); - connector.setHost(hostName); - connector.setPort(port); - connector.start(); - connector.setAcceptQueueSize(Math.min(connector.getAcceptors(), DEFAULT_ACCEPT_QUEUE_SIZE)); - return connector; - } + serverExecutor = new ScheduledExecutorScheduler("jetty-scheduler", true); + } -} \ No newline at end of file + public void start() { + try { + ServerConnector connector = newConnector(server, serverExecutor, null, httpPort); + connector.setName(SERVER_NAME); + server.addConnector(connector); + + int minThreads = 1; + minThreads += connector.getAcceptors() * 2; + threadPool.setMaxThreads(Math.max(threadPool.getMaxThreads(), minThreads)); + + server.start(); + LOGGER.info("Jetty Server started: {}", httpPort); + } catch (Exception e) { + LOGGER.error("jetty server failed:", e); + throw new GeaflowRuntimeException(e); + } + } + + public void stop() { + try { + server.stop(); + if (threadPool.isStarted()) { + threadPool.stop(); + } + if (serverExecutor.isStarted()) { + serverExecutor.stop(); + } + } catch (Exception e) { + LOGGER.warn("stop jetty server failed", e); + throw new GeaflowRuntimeException(e); + } + } + + private ServerConnector newConnector( + Server server, ScheduledExecutorScheduler serverExecutor, String hostName, int port) + throws Exception { + ConnectionFactory[] connectionFactories = new ConnectionFactory[] {new HttpConnectionFactory()}; + ServerConnector connector = + new ServerConnector(server, null, serverExecutor, null, -1, -1, connectionFactories); + connector.setHost(hostName); + connector.setPort(port); + connector.start(); + connector.setAcceptQueueSize(Math.min(connector.getAcceptors(), DEFAULT_ACCEPT_QUEUE_SIZE)); + return connector; + } +} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/AgentWebServer.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/AgentWebServer.java index 432b71214..b2da15749 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/AgentWebServer.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/AgentWebServer.java @@ -46,113 +46,117 @@ public class AgentWebServer { - private static final Logger LOGGER = LoggerFactory.getLogger(AgentWebServer.class); - - private static final String AGENT_SERVER_NAME = "agent-jetty-server"; - - private static final int DEFAULT_ACCEPT_QUEUE_SIZE = 8; - - private final Server server; - - private final int httpPort; - - private final QueuedThreadPool threadPool; - - private final ScheduledExecutorScheduler serverExecutor; - - private final Object lock = new Object(); - - public AgentWebServer(int httpPort, Configuration configuration) { - this(httpPort, configuration.getString(LOG_DIR), - configuration.getString(AGENT_PROFILER_PATH), - configuration.getString(PROFILER_FILENAME_EXTENSION), - configuration.getString(JOB_WORK_PATH)); + private static final Logger LOGGER = LoggerFactory.getLogger(AgentWebServer.class); + + private static final String AGENT_SERVER_NAME = "agent-jetty-server"; + + private static final int DEFAULT_ACCEPT_QUEUE_SIZE = 8; + + private final Server server; + + private final int httpPort; + + private final QueuedThreadPool threadPool; + + private final ScheduledExecutorScheduler serverExecutor; + + private final Object lock = new Object(); + + public AgentWebServer(int httpPort, Configuration configuration) { + this( + httpPort, + configuration.getString(LOG_DIR), + configuration.getString(AGENT_PROFILER_PATH), + configuration.getString(PROFILER_FILENAME_EXTENSION), + configuration.getString(JOB_WORK_PATH)); + } + + public AgentWebServer( + int httpPort, + String runtimeLogDirPath, + String flameGraphProfilerPath, + String flameGraphFileNameExtension, + String agentDir) { + this.httpPort = httpPort; + threadPool = new QueuedThreadPool(); + threadPool.setDaemon(true); + threadPool.setName(AGENT_SERVER_NAME); + server = new Server(threadPool); + + ErrorHandler errorHandler = new ErrorHandler(); + errorHandler.setShowStacks(true); + errorHandler.setServer(server); + server.addBean(errorHandler); + + ResourceConfig resourceConfig = new ResourceConfig(); + resourceConfig.register(new LogRestHandler(runtimeLogDirPath)); + resourceConfig.register( + new FlameGraphRestHandler(flameGraphProfilerPath, flameGraphFileNameExtension, agentDir)); + resourceConfig.register(new ThreadDumpRestHandler(agentDir)); + + ServletContextHandler handler = new ServletContextHandler(ServletContextHandler.NO_SESSIONS); + handler.setContextPath("/"); + handler.addServlet(new ServletHolder(new ServletContainer(resourceConfig)), "/rest/*"); + handler.addServlet(new ServletHolder(new DefaultServlet()), "/"); + + server.setHandler(handler); + + serverExecutor = new ScheduledExecutorScheduler("jetty-scheduler", true); + } + + public void start() { + try { + ServerConnector connector = newConnector(server, serverExecutor, null, httpPort); + connector.setName(AGENT_SERVER_NAME); + server.addConnector(connector); + + int minThreads = 1; + minThreads += connector.getAcceptors() * 2; + threadPool.setMaxThreads(Math.max(threadPool.getMaxThreads(), minThreads)); + + server.start(); + LOGGER.info("Jetty Server started: {}.", httpPort); + } catch (Exception e) { + LOGGER.error("Jetty server failed.", e); + throw new GeaflowRuntimeException(e); } + } - public AgentWebServer(int httpPort, String runtimeLogDirPath, String flameGraphProfilerPath, - String flameGraphFileNameExtension, String agentDir) { - this.httpPort = httpPort; - threadPool = new QueuedThreadPool(); - threadPool.setDaemon(true); - threadPool.setName(AGENT_SERVER_NAME); - server = new Server(threadPool); - - ErrorHandler errorHandler = new ErrorHandler(); - errorHandler.setShowStacks(true); - errorHandler.setServer(server); - server.addBean(errorHandler); - - ResourceConfig resourceConfig = new ResourceConfig(); - resourceConfig.register(new LogRestHandler(runtimeLogDirPath)); - resourceConfig.register(new FlameGraphRestHandler(flameGraphProfilerPath, flameGraphFileNameExtension, agentDir)); - resourceConfig.register(new ThreadDumpRestHandler(agentDir)); - - ServletContextHandler handler = new ServletContextHandler( - ServletContextHandler.NO_SESSIONS); - handler.setContextPath("/"); - handler.addServlet(new ServletHolder(new ServletContainer(resourceConfig)), "/rest/*"); - handler.addServlet(new ServletHolder(new DefaultServlet()), "/"); - - server.setHandler(handler); - - serverExecutor = new ScheduledExecutorScheduler("jetty-scheduler", true); + public void await() throws InterruptedException { + LOGGER.info("Wait for agent jetty server stopped."); + synchronized (lock) { + lock.wait(); } - - public void start() { - try { - ServerConnector connector = newConnector(server, serverExecutor, null, httpPort); - connector.setName(AGENT_SERVER_NAME); - server.addConnector(connector); - - int minThreads = 1; - minThreads += connector.getAcceptors() * 2; - threadPool.setMaxThreads(Math.max(threadPool.getMaxThreads(), minThreads)); - - server.start(); - LOGGER.info("Jetty Server started: {}.", httpPort); - } catch (Exception e) { - LOGGER.error("Jetty server failed.", e); - throw new GeaflowRuntimeException(e); - } + } + + public void stop() { + try { + server.stop(); + if (threadPool.isStarted()) { + threadPool.stop(); + } + if (serverExecutor.isStarted()) { + serverExecutor.stop(); + } + synchronized (lock) { + lock.notify(); + } + } catch (Exception e) { + LOGGER.warn("Stop jetty server failed.", e); + throw new GeaflowRuntimeException(e); } - - public void await() throws InterruptedException { - LOGGER.info("Wait for agent jetty server stopped."); - synchronized (lock) { - lock.wait(); - } - } - - public void stop() { - try { - server.stop(); - if (threadPool.isStarted()) { - threadPool.stop(); - } - if (serverExecutor.isStarted()) { - serverExecutor.stop(); - } - synchronized (lock) { - lock.notify(); - } - } catch (Exception e) { - LOGGER.warn("Stop jetty server failed.", e); - throw new GeaflowRuntimeException(e); - } - } - - private ServerConnector newConnector(Server server, ScheduledExecutorScheduler serverExecutor, - String hostName, int port) throws Exception { - ConnectionFactory[] connectionFactories = new ConnectionFactory[]{ - new HttpConnectionFactory()}; - ServerConnector connector = new ServerConnector(server, null, serverExecutor, null, -1, -1, - connectionFactories); - connector.setHost(hostName); - connector.setPort(port); - connector.start(); - connector.setAcceptQueueSize(Math.min(connector.getAcceptors(), DEFAULT_ACCEPT_QUEUE_SIZE)); - return connector; - } - - + } + + private ServerConnector newConnector( + Server server, ScheduledExecutorScheduler serverExecutor, String hostName, int port) + throws Exception { + ConnectionFactory[] connectionFactories = new ConnectionFactory[] {new HttpConnectionFactory()}; + ServerConnector connector = + new ServerConnector(server, null, serverExecutor, null, -1, -1, connectionFactories); + connector.setHost(hostName); + connector.setPort(port); + connector.start(); + connector.setAcceptQueueSize(Math.min(connector.getAcceptors(), DEFAULT_ACCEPT_QUEUE_SIZE)); + return connector; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/handler/FlameGraphRestHandler.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/handler/FlameGraphRestHandler.java index 59797e440..09f5cd38e 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/handler/FlameGraphRestHandler.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/handler/FlameGraphRestHandler.java @@ -21,7 +21,6 @@ import static org.apache.geaflow.cluster.constants.ClusterConstants.AGENT_PROFILER_PATH; -import com.google.common.base.Preconditions; import java.io.File; import java.util.ArrayList; import java.util.List; @@ -29,6 +28,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; + import javax.ws.rs.Consumes; import javax.ws.rs.DELETE; import javax.ws.rs.GET; @@ -37,6 +37,7 @@ import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; import javax.ws.rs.core.MediaType; + import org.apache.commons.lang.RandomStringUtils; import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.cluster.web.agent.model.FileInfo; @@ -51,177 +52,206 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; + @Path("/flame-graphs") public class FlameGraphRestHandler { - private static final Logger LOGGER = LoggerFactory.getLogger(FlameGraphRestHandler.class); + private static final Logger LOGGER = LoggerFactory.getLogger(FlameGraphRestHandler.class); - private static final String FLAME_GRAPH_FILE_PREFIX = "geaflow-flamegraph"; + private static final String FLAME_GRAPH_FILE_PREFIX = "geaflow-flamegraph"; - private static final int FLAME_GRAPH_FILE_MAX_CNT = 10; + private static final int FLAME_GRAPH_FILE_MAX_CNT = 10; - private final ExecutorService profileService = new ThreadPoolExecutor(1, 10, 30, - TimeUnit.SECONDS, new LinkedBlockingQueue<>(10), - ThreadUtil.namedThreadFactory(true, "flame-graph-profiler")); + private final ExecutorService profileService = + new ThreadPoolExecutor( + 1, + 10, + 30, + TimeUnit.SECONDS, + new LinkedBlockingQueue<>(10), + ThreadUtil.namedThreadFactory(true, "flame-graph-profiler")); - private final String flameGraphProfilerPath; + private final String flameGraphProfilerPath; - private final String flameGraphFileNameExtension; + private final String flameGraphFileNameExtension; - private final String agentDir; + private final String agentDir; - public FlameGraphRestHandler(String flameGraphProfilerPath, String flameGraphFileNameExtension, String agentDir) { - this.flameGraphProfilerPath = flameGraphProfilerPath; - this.flameGraphFileNameExtension = flameGraphFileNameExtension; - this.agentDir = agentDir; - } + public FlameGraphRestHandler( + String flameGraphProfilerPath, String flameGraphFileNameExtension, String agentDir) { + this.flameGraphProfilerPath = flameGraphProfilerPath; + this.flameGraphFileNameExtension = flameGraphFileNameExtension; + this.agentDir = agentDir; + } - @GET - @Path("/") - @Produces(MediaType.APPLICATION_JSON) - public ApiResponse> getFlameGraphFileList() { - try { - List flameGraphFiles = new ArrayList<>(); - File file = new File(agentDir); - String[] fileList = file.list(); - if (fileList != null) { - for (String f : fileList) { - String filePath = agentDir + File.separator + f; - File flameGraphFile = new File(filePath); - if (flameGraphFile.isFile() && f.startsWith(FLAME_GRAPH_FILE_PREFIX) - && f.endsWith(flameGraphFileNameExtension)) { - FileInfo fileInfo = FileUtil.buildFileInfo(flameGraphFile, filePath); - flameGraphFiles.add(fileInfo); - } - } - } - return ApiResponse.success(flameGraphFiles); - } catch (Throwable t) { - LOGGER.error("Query flame-graph file list failed. {}", t.getMessage(), t); - return ApiResponse.error(t); + @GET + @Path("/") + @Produces(MediaType.APPLICATION_JSON) + public ApiResponse> getFlameGraphFileList() { + try { + List flameGraphFiles = new ArrayList<>(); + File file = new File(agentDir); + String[] fileList = file.list(); + if (fileList != null) { + for (String f : fileList) { + String filePath = agentDir + File.separator + f; + File flameGraphFile = new File(filePath); + if (flameGraphFile.isFile() + && f.startsWith(FLAME_GRAPH_FILE_PREFIX) + && f.endsWith(flameGraphFileNameExtension)) { + FileInfo fileInfo = FileUtil.buildFileInfo(flameGraphFile, filePath); + flameGraphFiles.add(fileInfo); + } } + } + return ApiResponse.success(flameGraphFiles); + } catch (Throwable t) { + LOGGER.error("Query flame-graph file list failed. {}", t.getMessage(), t); + return ApiResponse.error(t); } + } - @GET - @Path("/content") - @Produces(MediaType.APPLICATION_JSON) - public ApiResponse getFlameGraphFileContent(@QueryParam("path") String filePath) { - try { - checkFlameGraphFilePath(filePath); - String content = org.apache.geaflow.common.utils.FileUtil.getContentFromFile(filePath); - if (content == null) { - throw new GeaflowRuntimeException( - String.format("Flame-graph file %s not exists.", filePath)); - } - return ApiResponse.success(content); - } catch (Throwable t) { - LOGGER.error("Query flame-graph file content failed. {}", t.getMessage(), t); - return ApiResponse.error(t); - } + @GET + @Path("/content") + @Produces(MediaType.APPLICATION_JSON) + public ApiResponse getFlameGraphFileContent(@QueryParam("path") String filePath) { + try { + checkFlameGraphFilePath(filePath); + String content = org.apache.geaflow.common.utils.FileUtil.getContentFromFile(filePath); + if (content == null) { + throw new GeaflowRuntimeException( + String.format("Flame-graph file %s not exists.", filePath)); + } + return ApiResponse.success(content); + } catch (Throwable t) { + LOGGER.error("Query flame-graph file content failed. {}", t.getMessage(), t); + return ApiResponse.error(t); } + } - @POST - @Path("/") - @Produces(MediaType.APPLICATION_JSON) - @Consumes(MediaType.APPLICATION_JSON) - public ApiResponse executeFlameGraphProfiler(FlameGraphRequest request) { - try { - checkProfilerPath(); - checkFlameGraphRequest(request); - checkFlameGraphFileCount(); - ProcessBuilder command = getCommand(request); - profileService.submit(() -> ShellUtil.executeShellCommand(command, 90)); - return ApiResponse.success(); - } catch (Throwable t) { - LOGGER.error("Execute flame-graph profiler command failed. {}", t.getMessage(), t); - return ApiResponse.error(t); - } + @POST + @Path("/") + @Produces(MediaType.APPLICATION_JSON) + @Consumes(MediaType.APPLICATION_JSON) + public ApiResponse executeFlameGraphProfiler(FlameGraphRequest request) { + try { + checkProfilerPath(); + checkFlameGraphRequest(request); + checkFlameGraphFileCount(); + ProcessBuilder command = getCommand(request); + profileService.submit(() -> ShellUtil.executeShellCommand(command, 90)); + return ApiResponse.success(); + } catch (Throwable t) { + LOGGER.error("Execute flame-graph profiler command failed. {}", t.getMessage(), t); + return ApiResponse.error(t); } + } - @DELETE - @Path("/") - @Produces(MediaType.APPLICATION_JSON) - public ApiResponse deleteFlameGraphFile(@QueryParam("path") String filePath) { - try { - checkFlameGraphFilePath(filePath); - File file = new File(filePath); - if (!file.exists() || !file.isFile()) { - throw new GeaflowRuntimeException(String.format("File %s not found.", filePath)); - } - file.delete(); - return ApiResponse.success(); - } catch (Throwable t) { - LOGGER.error("Delete flame-graph file failed. {}", t.getMessage(), t); - return ApiResponse.error(t); - } + @DELETE + @Path("/") + @Produces(MediaType.APPLICATION_JSON) + public ApiResponse deleteFlameGraphFile(@QueryParam("path") String filePath) { + try { + checkFlameGraphFilePath(filePath); + File file = new File(filePath); + if (!file.exists() || !file.isFile()) { + throw new GeaflowRuntimeException(String.format("File %s not found.", filePath)); + } + file.delete(); + return ApiResponse.success(); + } catch (Throwable t) { + LOGGER.error("Delete flame-graph file failed. {}", t.getMessage(), t); + return ApiResponse.error(t); } + } - private void checkFlameGraphFileCount() { - File file = new File(agentDir); - String[] fileList = file.list(); - int cnt = 0; - if (fileList != null) { - for (String f : fileList) { - String filePath = agentDir + File.separator + f; - File flameGraphFile = new File(filePath); - if (flameGraphFile.isFile() && f.startsWith(FLAME_GRAPH_FILE_PREFIX) && f.endsWith( - flameGraphFileNameExtension)) { - cnt++; - } - } - } - if (cnt >= FLAME_GRAPH_FILE_MAX_CNT) { - throw new GeaflowRuntimeException(String.format( - "The count of flame-graph files is " + "limited to " + "%s. " - + "Please delete some of them first.", FLAME_GRAPH_FILE_MAX_CNT)); + private void checkFlameGraphFileCount() { + File file = new File(agentDir); + String[] fileList = file.list(); + int cnt = 0; + if (fileList != null) { + for (String f : fileList) { + String filePath = agentDir + File.separator + f; + File flameGraphFile = new File(filePath); + if (flameGraphFile.isFile() + && f.startsWith(FLAME_GRAPH_FILE_PREFIX) + && f.endsWith(flameGraphFileNameExtension)) { + cnt++; } + } } - - private void checkProfilerPath() { - if (StringUtils.isEmpty(flameGraphProfilerPath)) { - throw new GeaflowRuntimeException(String.format("Async-profiler shell script path is " - + "not set. Please set the file path of async-profiler path: %s", - AGENT_PROFILER_PATH)); - } + if (cnt >= FLAME_GRAPH_FILE_MAX_CNT) { + throw new GeaflowRuntimeException( + String.format( + "The count of flame-graph files is " + + "limited to " + + "%s. " + + "Please delete some of them first.", + FLAME_GRAPH_FILE_MAX_CNT)); } + } - private void checkFlameGraphFilePath(String path) { - Preconditions.checkArgument( - path != null && path.startsWith(agentDir) && path.endsWith( - flameGraphFileNameExtension), "File path is invalid."); + private void checkProfilerPath() { + if (StringUtils.isEmpty(flameGraphProfilerPath)) { + throw new GeaflowRuntimeException( + String.format( + "Async-profiler shell script path is " + + "not set. Please set the file path of async-profiler path: %s", + AGENT_PROFILER_PATH)); } + } - private void checkFlameGraphRequest(FlameGraphRequest request) { - Preconditions.checkArgument(request.getType() != null, "Profiler type cannot be null."); - Preconditions.checkArgument(request.getDuration() > 0 && request.getDuration() <= 60, - "Duration must be within 0~60 seconds."); - Preconditions.checkArgument(request.getPid() > 0, "Pid must be larger than 0."); - } + private void checkFlameGraphFilePath(String path) { + Preconditions.checkArgument( + path != null && path.startsWith(agentDir) && path.endsWith(flameGraphFileNameExtension), + "File path is invalid."); + } - private ProcessBuilder getCommand(FlameGraphRequest request) { - String now = DateUtil.simpleFormat(System.currentTimeMillis()); - String randomSuffix = RandomStringUtils.randomAlphabetic(4); - StringBuilder filePath = new StringBuilder(); - filePath.append(agentDir).append("/").append(FLAME_GRAPH_FILE_PREFIX).append("-") - .append("pid").append(request.getPid()).append("-").append(request.getType()) - .append("-").append(request.getDuration()).append("s").append("-").append(now) - .append("-").append(randomSuffix).append(flameGraphFileNameExtension); - List commands = new ArrayList<>(); - commands.add("sh"); - commands.add(flameGraphProfilerPath); - commands.add("--all-user"); - commands.add("-d"); - commands.add(String.valueOf(request.getDuration())); - if (request.getType() == FlameGraphType.ALLOC) { - commands.add("-e"); - commands.add(FlameGraphType.ALLOC.name().toLowerCase()); - } - commands.add("-f"); - commands.add(filePath.toString()); - commands.add(String.valueOf(request.getPid())); - ProcessBuilder processBuilder = new ProcessBuilder(); - processBuilder.command(commands); - return processBuilder; - } + private void checkFlameGraphRequest(FlameGraphRequest request) { + Preconditions.checkArgument(request.getType() != null, "Profiler type cannot be null."); + Preconditions.checkArgument( + request.getDuration() > 0 && request.getDuration() <= 60, + "Duration must be within 0~60 seconds."); + Preconditions.checkArgument(request.getPid() > 0, "Pid must be larger than 0."); + } + private ProcessBuilder getCommand(FlameGraphRequest request) { + String now = DateUtil.simpleFormat(System.currentTimeMillis()); + String randomSuffix = RandomStringUtils.randomAlphabetic(4); + StringBuilder filePath = new StringBuilder(); + filePath + .append(agentDir) + .append("/") + .append(FLAME_GRAPH_FILE_PREFIX) + .append("-") + .append("pid") + .append(request.getPid()) + .append("-") + .append(request.getType()) + .append("-") + .append(request.getDuration()) + .append("s") + .append("-") + .append(now) + .append("-") + .append(randomSuffix) + .append(flameGraphFileNameExtension); + List commands = new ArrayList<>(); + commands.add("sh"); + commands.add(flameGraphProfilerPath); + commands.add("--all-user"); + commands.add("-d"); + commands.add(String.valueOf(request.getDuration())); + if (request.getType() == FlameGraphType.ALLOC) { + commands.add("-e"); + commands.add(FlameGraphType.ALLOC.name().toLowerCase()); + } + commands.add("-f"); + commands.add(filePath.toString()); + commands.add(String.valueOf(request.getPid())); + ProcessBuilder processBuilder = new ProcessBuilder(); + processBuilder.command(commands); + return processBuilder; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/handler/LogRestHandler.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/handler/LogRestHandler.java index 2db8d629e..124f5f4d9 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/handler/LogRestHandler.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/handler/LogRestHandler.java @@ -25,11 +25,13 @@ import java.util.ArrayList; import java.util.List; import java.util.regex.Pattern; + import javax.ws.rs.GET; import javax.ws.rs.Path; import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; import javax.ws.rs.core.MediaType; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.cluster.web.agent.model.FileInfo; import org.apache.geaflow.cluster.web.agent.model.PaginationRequest; @@ -43,73 +45,76 @@ @Path("/logs") public class LogRestHandler { - private static final Logger LOGGER = LoggerFactory.getLogger(LogRestHandler.class); + private static final Logger LOGGER = LoggerFactory.getLogger(LogRestHandler.class); - private final String runtimeLogDirPath; + private final String runtimeLogDirPath; - private final Pattern logPattern; + private final Pattern logPattern; - public LogRestHandler(String runtimeLogDirPath) { - this.runtimeLogDirPath = runtimeLogDirPath; - this.logPattern = Pattern.compile(String.format("%s.*\\.log(\\.\\d*)?", this.runtimeLogDirPath)); - } + public LogRestHandler(String runtimeLogDirPath) { + this.runtimeLogDirPath = runtimeLogDirPath; + this.logPattern = + Pattern.compile(String.format("%s.*\\.log(\\.\\d*)?", this.runtimeLogDirPath)); + } - @GET - @Path("/") - @Produces(MediaType.APPLICATION_JSON) - public ApiResponse> getLogList() { - try { - checkRuntimeLogDirPath(); - List logs = new ArrayList<>(); - File file = new File(runtimeLogDirPath); - String[] fileList = file.list(); - for (String f : fileList) { - String logPath = runtimeLogDirPath + File.separator + f; - File logFile = new File(logPath); - if (logFile.isFile() && logPattern.matcher(logPath).matches()) { - FileInfo fileInfo = FileUtil.buildFileInfo(logFile, logPath); - logs.add(fileInfo); - } - } - return ApiResponse.success(logs); - } catch (Throwable t) { - LOGGER.error("Query log file list failed. {}", t.getMessage(), t); - return ApiResponse.error(t); + @GET + @Path("/") + @Produces(MediaType.APPLICATION_JSON) + public ApiResponse> getLogList() { + try { + checkRuntimeLogDirPath(); + List logs = new ArrayList<>(); + File file = new File(runtimeLogDirPath); + String[] fileList = file.list(); + for (String f : fileList) { + String logPath = runtimeLogDirPath + File.separator + f; + File logFile = new File(logPath); + if (logFile.isFile() && logPattern.matcher(logPath).matches()) { + FileInfo fileInfo = FileUtil.buildFileInfo(logFile, logPath); + logs.add(fileInfo); } + } + return ApiResponse.success(logs); + } catch (Throwable t) { + LOGGER.error("Query log file list failed. {}", t.getMessage(), t); + return ApiResponse.error(t); } + } - @GET - @Path("/content") - @Produces(MediaType.APPLICATION_JSON) - public ApiResponse> getLogContent(@QueryParam("path") String logPath, - @QueryParam("pageNo") int pageNo, - @QueryParam("pageSize") int pageSize) { - try { - checkLogPath(logPath); - PaginationRequest request = new PaginationRequest(pageNo, pageSize); - FileUtil.checkPaginationRequest(request); - PaginationResponse response = FileUtil.getFileContent(request, logPath); - if (response == null) { - throw new GeaflowRuntimeException( - String.format("Log file %s not exists.", logPath)); - } - return ApiResponse.success(response); - } catch (Throwable t) { - LOGGER.error("Query log content {} failed. {}", logPath, t.getMessage(), t); - return ApiResponse.error(t); - } + @GET + @Path("/content") + @Produces(MediaType.APPLICATION_JSON) + public ApiResponse> getLogContent( + @QueryParam("path") String logPath, + @QueryParam("pageNo") int pageNo, + @QueryParam("pageSize") int pageSize) { + try { + checkLogPath(logPath); + PaginationRequest request = new PaginationRequest(pageNo, pageSize); + FileUtil.checkPaginationRequest(request); + PaginationResponse response = FileUtil.getFileContent(request, logPath); + if (response == null) { + throw new GeaflowRuntimeException(String.format("Log file %s not exists.", logPath)); + } + return ApiResponse.success(response); + } catch (Throwable t) { + LOGGER.error("Query log content {} failed. {}", logPath, t.getMessage(), t); + return ApiResponse.error(t); } + } - private void checkRuntimeLogDirPath() { - if (StringUtils.isEmpty(runtimeLogDirPath)) { - throw new GeaflowRuntimeException(String.format("Log dir path is not set. Please set the log " - + "dir path config: %s", LOG_DIR.getKey())); - } + private void checkRuntimeLogDirPath() { + if (StringUtils.isEmpty(runtimeLogDirPath)) { + throw new GeaflowRuntimeException( + String.format( + "Log dir path is not set. Please set the log " + "dir path config: %s", + LOG_DIR.getKey())); } + } - private void checkLogPath(String logPath) { - if (logPath == null || !logPattern.matcher(logPath).matches()) { - throw new GeaflowRuntimeException(String.format("Log path %s is invalid.", logPath)); - } + private void checkLogPath(String logPath) { + if (logPath == null || !logPattern.matcher(logPath).matches()) { + throw new GeaflowRuntimeException(String.format("Log path %s is invalid.", logPath)); } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/handler/ThreadDumpRestHandler.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/handler/ThreadDumpRestHandler.java index d05799b15..12bbb2103 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/handler/ThreadDumpRestHandler.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/handler/ThreadDumpRestHandler.java @@ -19,11 +19,11 @@ package org.apache.geaflow.cluster.web.agent.handler; -import com.google.common.base.Preconditions; import java.io.File; import java.nio.file.Paths; import java.util.ArrayList; import java.util.List; + import javax.ws.rs.Consumes; import javax.ws.rs.GET; import javax.ws.rs.POST; @@ -31,6 +31,7 @@ import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; import javax.ws.rs.core.MediaType; + import org.apache.geaflow.cluster.web.agent.model.PaginationRequest; import org.apache.geaflow.cluster.web.agent.model.PaginationResponse; import org.apache.geaflow.cluster.web.agent.model.ThreadDumpRequest; @@ -41,71 +42,72 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; + @Path("/thread-dump") public class ThreadDumpRestHandler { - private static final Logger LOGGER = LoggerFactory.getLogger(ThreadDumpRestHandler.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ThreadDumpRestHandler.class); - private static final String THREAD_DUMP_FILE_NAME = "geaflow-thread-dump.log"; + private static final String THREAD_DUMP_FILE_NAME = "geaflow-thread-dump.log"; - private final String threadDumpFilePath; + private final String threadDumpFilePath; - public ThreadDumpRestHandler(String agentDir) { - this.threadDumpFilePath = Paths.get(agentDir, THREAD_DUMP_FILE_NAME).toString(); - } - - @GET - @Path("/content") - @Produces(MediaType.APPLICATION_JSON) - public ApiResponse> getThreadDumpFileContent(@QueryParam("pageNo") int pageNo, - @QueryParam("pageSize") int pageSize) { - try { - PaginationRequest request = new PaginationRequest(pageNo, pageSize); - FileUtil.checkPaginationRequest(request); - PaginationResponse response = FileUtil.getFileContent(request, threadDumpFilePath); - if (response == null) { - LOGGER.warn("Thread-dump log file {} not exists.", threadDumpFilePath); - return ApiResponse.success(new PaginationResponse<>(0, null)); - } - File file = new File(threadDumpFilePath); - ThreadDumpResponse threadDumpResponse = new ThreadDumpResponse(); - threadDumpResponse.setLastDumpTime(file.lastModified()); - threadDumpResponse.setContent(response.getData()); - return ApiResponse.success(new PaginationResponse<>(response.getTotal(), threadDumpResponse)); - } catch (Throwable t) { - LOGGER.error("Query thread-dump log content failed. {}", t.getMessage(), t); - return ApiResponse.error(t); - } - } + public ThreadDumpRestHandler(String agentDir) { + this.threadDumpFilePath = Paths.get(agentDir, THREAD_DUMP_FILE_NAME).toString(); + } - @POST - @Path("/") - @Produces(MediaType.APPLICATION_JSON) - @Consumes(MediaType.APPLICATION_JSON) - public ApiResponse executeThreadDumpProfiler(ThreadDumpRequest request) { - try { - checkThreadDumpRequest(request); - ProcessBuilder command = getCommand(request.getPid()); - ShellUtil.executeShellCommand(command, 30); - return ApiResponse.success(); - } catch (Throwable t) { - LOGGER.error("Execute thread-dump command failed. {}", t.getMessage(), t); - return ApiResponse.error(t); - } + @GET + @Path("/content") + @Produces(MediaType.APPLICATION_JSON) + public ApiResponse> getThreadDumpFileContent( + @QueryParam("pageNo") int pageNo, @QueryParam("pageSize") int pageSize) { + try { + PaginationRequest request = new PaginationRequest(pageNo, pageSize); + FileUtil.checkPaginationRequest(request); + PaginationResponse response = FileUtil.getFileContent(request, threadDumpFilePath); + if (response == null) { + LOGGER.warn("Thread-dump log file {} not exists.", threadDumpFilePath); + return ApiResponse.success(new PaginationResponse<>(0, null)); + } + File file = new File(threadDumpFilePath); + ThreadDumpResponse threadDumpResponse = new ThreadDumpResponse(); + threadDumpResponse.setLastDumpTime(file.lastModified()); + threadDumpResponse.setContent(response.getData()); + return ApiResponse.success(new PaginationResponse<>(response.getTotal(), threadDumpResponse)); + } catch (Throwable t) { + LOGGER.error("Query thread-dump log content failed. {}", t.getMessage(), t); + return ApiResponse.error(t); } + } - private void checkThreadDumpRequest(ThreadDumpRequest request) { - Preconditions.checkArgument(request.getPid() > 0, "Pid must be larger than 0."); + @POST + @Path("/") + @Produces(MediaType.APPLICATION_JSON) + @Consumes(MediaType.APPLICATION_JSON) + public ApiResponse executeThreadDumpProfiler(ThreadDumpRequest request) { + try { + checkThreadDumpRequest(request); + ProcessBuilder command = getCommand(request.getPid()); + ShellUtil.executeShellCommand(command, 30); + return ApiResponse.success(); + } catch (Throwable t) { + LOGGER.error("Execute thread-dump command failed. {}", t.getMessage(), t); + return ApiResponse.error(t); } + } - private ProcessBuilder getCommand(int pid) { - List commands = new ArrayList<>(); - commands.add("sh"); - commands.add("-c"); - commands.add(String.format("jstack -l %d > %s", pid, threadDumpFilePath)); - ProcessBuilder processBuilder = new ProcessBuilder(); - processBuilder.command(commands); - return processBuilder; - } + private void checkThreadDumpRequest(ThreadDumpRequest request) { + Preconditions.checkArgument(request.getPid() > 0, "Pid must be larger than 0."); + } + private ProcessBuilder getCommand(int pid) { + List commands = new ArrayList<>(); + commands.add("sh"); + commands.add("-c"); + commands.add(String.format("jstack -l %d > %s", pid, threadDumpFilePath)); + ProcessBuilder processBuilder = new ProcessBuilder(); + processBuilder.command(commands); + return processBuilder; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/FileInfo.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/FileInfo.java index 5927002fc..611c888fc 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/FileInfo.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/FileInfo.java @@ -21,39 +21,46 @@ public class FileInfo { - private String path; + private String path; - private Long size; + private Long size; - private Long createdTime; + private Long createdTime; - public String getPath() { - return path; - } + public String getPath() { + return path; + } - public void setPath(String path) { - this.path = path; - } + public void setPath(String path) { + this.path = path; + } - public Long getSize() { - return size; - } + public Long getSize() { + return size; + } - public void setSize(Long size) { - this.size = size; - } + public void setSize(Long size) { + this.size = size; + } - public Long getCreatedTime() { - return createdTime; - } + public Long getCreatedTime() { + return createdTime; + } - public void setCreatedTime(Long createdTime) { - this.createdTime = createdTime; - } + public void setCreatedTime(Long createdTime) { + this.createdTime = createdTime; + } - @Override - public String toString() { - return "FileInfo{" + "path='" + path + '\'' + ", size=" + size + ", createdTime=" - + createdTime + '}'; - } + @Override + public String toString() { + return "FileInfo{" + + "path='" + + path + + '\'' + + ", size=" + + size + + ", createdTime=" + + createdTime + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/FlameGraphRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/FlameGraphRequest.java index 168bfa817..4d0f1edf4 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/FlameGraphRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/FlameGraphRequest.java @@ -21,42 +21,36 @@ public class FlameGraphRequest { - /** - * Profiler type. Support CPU or ALLOC. - */ - private FlameGraphType type; - - /** - * Duration of async-profiler execution, in second time-unit. - */ - private int duration; - - /** - * The pid to profile. - */ - private int pid; - - public FlameGraphType getType() { - return type; - } - - public void setType(FlameGraphType type) { - this.type = type; - } - - public int getDuration() { - return duration; - } - - public void setDuration(int duration) { - this.duration = duration; - } - - public int getPid() { - return pid; - } - - public void setPid(int pid) { - this.pid = pid; - } + /** Profiler type. Support CPU or ALLOC. */ + private FlameGraphType type; + + /** Duration of async-profiler execution, in second time-unit. */ + private int duration; + + /** The pid to profile. */ + private int pid; + + public FlameGraphType getType() { + return type; + } + + public void setType(FlameGraphType type) { + this.type = type; + } + + public int getDuration() { + return duration; + } + + public void setDuration(int duration) { + this.duration = duration; + } + + public int getPid() { + return pid; + } + + public void setPid(int pid) { + this.pid = pid; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/FlameGraphType.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/FlameGraphType.java index fe9d66d2d..1da6d8e32 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/FlameGraphType.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/FlameGraphType.java @@ -21,13 +21,9 @@ public enum FlameGraphType { - /** - * Cpu profiler. - */ - CPU, + /** Cpu profiler. */ + CPU, - /** - * Heap memory alloc profiler. - */ - ALLOC + /** Heap memory alloc profiler. */ + ALLOC } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/PaginationRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/PaginationRequest.java index 3389acc4d..68783c8b7 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/PaginationRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/PaginationRequest.java @@ -21,31 +21,30 @@ public class PaginationRequest { - private int pageNo; + private int pageNo; - private int pageSize; + private int pageSize; - public PaginationRequest() { - } + public PaginationRequest() {} - public PaginationRequest(int pageNo, int pageSize) { - this.pageNo = pageNo; - this.pageSize = pageSize; - } + public PaginationRequest(int pageNo, int pageSize) { + this.pageNo = pageNo; + this.pageSize = pageSize; + } - public int getPageNo() { - return pageNo; - } + public int getPageNo() { + return pageNo; + } - public void setPageNo(int pageNo) { - this.pageNo = pageNo; - } + public void setPageNo(int pageNo) { + this.pageNo = pageNo; + } - public int getPageSize() { - return pageSize; - } + public int getPageSize() { + return pageSize; + } - public void setPageSize(int pageSize) { - this.pageSize = pageSize; - } + public void setPageSize(int pageSize) { + this.pageSize = pageSize; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/PaginationResponse.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/PaginationResponse.java index f4acac4ac..be1251ed8 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/PaginationResponse.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/PaginationResponse.java @@ -21,31 +21,30 @@ public class PaginationResponse { - private long total; + private long total; - private T data; + private T data; - public PaginationResponse() { - } + public PaginationResponse() {} - public PaginationResponse(long total, T data) { - this.total = total; - this.data = data; - } + public PaginationResponse(long total, T data) { + this.total = total; + this.data = data; + } - public long getTotal() { - return total; - } + public long getTotal() { + return total; + } - public void setTotal(long total) { - this.total = total; - } + public void setTotal(long total) { + this.total = total; + } - public T getData() { - return data; - } + public T getData() { + return data; + } - public void setData(T data) { - this.data = data; - } + public void setData(T data) { + this.data = data; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/ThreadDumpRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/ThreadDumpRequest.java index fa44c1c58..51fa50806 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/ThreadDumpRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/ThreadDumpRequest.java @@ -21,13 +21,13 @@ public class ThreadDumpRequest { - private int pid; + private int pid; - public int getPid() { - return pid; - } + public int getPid() { + return pid; + } - public void setPid(int pid) { - this.pid = pid; - } + public void setPid(int pid) { + this.pid = pid; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/ThreadDumpResponse.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/ThreadDumpResponse.java index 54c668710..68a56ed82 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/ThreadDumpResponse.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/model/ThreadDumpResponse.java @@ -21,23 +21,23 @@ public class ThreadDumpResponse { - private long lastDumpTime; + private long lastDumpTime; - private String content; + private String content; - public long getLastDumpTime() { - return lastDumpTime; - } + public long getLastDumpTime() { + return lastDumpTime; + } - public void setLastDumpTime(long lastDumpTime) { - this.lastDumpTime = lastDumpTime; - } + public void setLastDumpTime(long lastDumpTime) { + this.lastDumpTime = lastDumpTime; + } - public String getContent() { - return content; - } + public String getContent() { + return content; + } - public void setContent(String content) { - this.content = content; - } + public void setContent(String content) { + this.content = content; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/util/DateUtil.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/util/DateUtil.java index 5fec7bec5..52a6ac85c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/util/DateUtil.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/util/DateUtil.java @@ -23,8 +23,8 @@ public class DateUtil { - public static String simpleFormat(long time) { - SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); - return formatter.format(time); - } + public static String simpleFormat(long time) { + SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); + return formatter.format(time); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/util/FileUtil.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/util/FileUtil.java index 1eb23ce87..b86d5627e 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/util/FileUtil.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/agent/util/FileUtil.java @@ -19,7 +19,6 @@ package org.apache.geaflow.cluster.web.agent.util; -import com.google.common.base.Preconditions; import java.io.File; import java.io.FileInputStream; import java.io.IOException; @@ -27,58 +26,58 @@ import java.nio.file.LinkOption; import java.nio.file.Paths; import java.nio.file.attribute.BasicFileAttributeView; + import org.apache.geaflow.cluster.web.agent.model.FileInfo; import org.apache.geaflow.cluster.web.agent.model.PaginationRequest; import org.apache.geaflow.cluster.web.agent.model.PaginationResponse; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class FileUtil { +import com.google.common.base.Preconditions; - private static final Logger LOGGER = LoggerFactory.getLogger(FileUtil.class); +public class FileUtil { - private static final String NUL_CHAR = "\\u0000"; + private static final Logger LOGGER = LoggerFactory.getLogger(FileUtil.class); - public static PaginationResponse getFileContent(PaginationRequest request, - String filePath) { - int start = (request.getPageNo() - 1) * request.getPageSize(); - File file = new File(filePath); - if (file.exists()) { - byte[] buf = new byte[request.getPageSize()]; - try (FileInputStream inputStream = new FileInputStream(file)) { - inputStream.skip(start); - inputStream.read(buf); - } catch (IOException e) { - throw new RuntimeException("Error read file content.", e); - } - PaginationResponse response = new PaginationResponse<>(); - response.setData(new String(buf).replaceAll(NUL_CHAR, "")); - response.setTotal(file.length()); - return response; - } - return null; - } + private static final String NUL_CHAR = "\\u0000"; - public static FileInfo buildFileInfo(File file, String path) { - FileInfo fileInfo = new FileInfo(); - fileInfo.setPath(path); - fileInfo.setSize(file.length()); - try { - BasicFileAttributeView basicFileAttributeView = - Files.getFileAttributeView(Paths.get(path), - BasicFileAttributeView.class, LinkOption.NOFOLLOW_LINKS); - fileInfo.setCreatedTime(basicFileAttributeView.readAttributes().creationTime().toMillis()); - } catch (IOException e) { - LOGGER.error("Get created time of file {} failed. {}", path, e.getMessage(), e); - } - return fileInfo; + public static PaginationResponse getFileContent( + PaginationRequest request, String filePath) { + int start = (request.getPageNo() - 1) * request.getPageSize(); + File file = new File(filePath); + if (file.exists()) { + byte[] buf = new byte[request.getPageSize()]; + try (FileInputStream inputStream = new FileInputStream(file)) { + inputStream.skip(start); + inputStream.read(buf); + } catch (IOException e) { + throw new RuntimeException("Error read file content.", e); + } + PaginationResponse response = new PaginationResponse<>(); + response.setData(new String(buf).replaceAll(NUL_CHAR, "")); + response.setTotal(file.length()); + return response; } + return null; + } - public static void checkPaginationRequest(PaginationRequest request) { - Preconditions.checkArgument(request.getPageNo() > 0, - "Page number should be greater than 0."); - Preconditions.checkArgument(request.getPageSize() > 0, - "Page size should be greater than 0."); + public static FileInfo buildFileInfo(File file, String path) { + FileInfo fileInfo = new FileInfo(); + fileInfo.setPath(path); + fileInfo.setSize(file.length()); + try { + BasicFileAttributeView basicFileAttributeView = + Files.getFileAttributeView( + Paths.get(path), BasicFileAttributeView.class, LinkOption.NOFOLLOW_LINKS); + fileInfo.setCreatedTime(basicFileAttributeView.readAttributes().creationTime().toMillis()); + } catch (IOException e) { + LOGGER.error("Get created time of file {} failed. {}", path, e.getMessage(), e); } + return fileInfo; + } + public static void checkPaginationRequest(PaginationRequest request) { + Preconditions.checkArgument(request.getPageNo() > 0, "Page number should be greater than 0."); + Preconditions.checkArgument(request.getPageSize() > 0, "Page size should be greater than 0."); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/api/ApiResponse.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/api/ApiResponse.java index 4f7307e39..4f10e0059 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/api/ApiResponse.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/api/ApiResponse.java @@ -23,49 +23,56 @@ public class ApiResponse implements Serializable { - private final boolean success; + private final boolean success; - private final String message; + private final String message; - private final T data; + private final T data; - public ApiResponse(boolean success, String message, T data) { - this.success = success; - this.message = message; - this.data = data; - } + public ApiResponse(boolean success, String message, T data) { + this.success = success; + this.message = message; + this.data = data; + } - public static ApiResponse success() { - return new ApiResponse<>(true, null, null); - } + public static ApiResponse success() { + return new ApiResponse<>(true, null, null); + } - public static ApiResponse success(T data) { - return new ApiResponse<>(true, null, data); - } + public static ApiResponse success(T data) { + return new ApiResponse<>(true, null, data); + } - public static ApiResponse error(String message) { - return new ApiResponse<>(false, message, null); - } + public static ApiResponse error(String message) { + return new ApiResponse<>(false, message, null); + } - public static ApiResponse error(Throwable t) { - return error(t.getMessage()); - } + public static ApiResponse error(Throwable t) { + return error(t.getMessage()); + } - public boolean isSuccess() { - return success; - } + public boolean isSuccess() { + return success; + } - public String getMessage() { - return message; - } + public String getMessage() { + return message; + } - public T getData() { - return data; - } + public T getData() { + return data; + } - @Override - public String toString() { - return "ApiResponse{" + "success=" + success + ", message='" + message + '\'' + ", data=" - + data + '}'; - } -} \ No newline at end of file + @Override + public String toString() { + return "ApiResponse{" + + "success=" + + success + + ", message='" + + message + + '\'' + + ", data=" + + data + + '}'; + } +} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/handler/ClusterRestHandler.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/handler/ClusterRestHandler.java index eedc1b5a0..484baebcc 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/handler/ClusterRestHandler.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/handler/ClusterRestHandler.java @@ -19,7 +19,6 @@ package org.apache.geaflow.cluster.web.handler; -import com.alibaba.fastjson.JSONObject; import java.io.IOException; import java.io.Serializable; import java.util.ArrayList; @@ -28,10 +27,12 @@ import java.util.List; import java.util.Map; import java.util.Set; + import javax.ws.rs.GET; import javax.ws.rs.Path; import javax.ws.rs.Produces; import javax.ws.rs.core.MediaType; + import org.apache.geaflow.cluster.clustermanager.AbstractClusterManager; import org.apache.geaflow.cluster.clustermanager.IClusterManager; import org.apache.geaflow.cluster.common.ComponentInfo; @@ -47,132 +48,136 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.alibaba.fastjson.JSONObject; + @Path("/") public class ClusterRestHandler implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(ClusterRestHandler.class); - - private static final String TOTAL_CONTAINER_NUM = "totalContainers"; - private static final String TOTAL_DRIVER_NUM = "totalDrivers"; - private static final String ACTIVE_CONTAINER_NUM = "activeContainers"; - private static final String ACTIVE_DRIVER_NUM = "activeDrivers"; - private static final String TOTAL_WORKER_NUM = "totalWorkers"; - private static final String USED_WORKER_NUM = "usedWorkers"; - private static final String AVAILABLE_WORKER_NUM = "availableWorkers"; - private static final String PENDING_WORKER_NUM = "pendingWorkers"; - - private static final String CONTAINER_ID_KEY = "id"; - private static final String CONTAINER_NAME_KEY = "name"; - private static final String CONTAINER_HOST_KEY = "host"; - private static final String PROCESS_ID_KEY = "pid"; - private static final String AGENT_PORT_KEY = "agentPort"; - private static final String PROCESS_METRICS_KEY = "metrics"; - private static final String LAST_UPDATE_TIME = "lastTimestamp"; - private static final String IS_ACTIVE = "isActive"; - - private final AbstractClusterManager clusterManager; - private final HeartbeatManager heartbeatManager; - private final IResourceManager resourceManager; - private final MetricFetcher metricFetcher; - - public ClusterRestHandler(IClusterManager clusterManager, HeartbeatManager heartbeatManager, - IResourceManager resourceManager, MetricFetcher metricFetcher) { - this.clusterManager = (AbstractClusterManager) clusterManager; - this.heartbeatManager = heartbeatManager; - this.resourceManager = resourceManager; - this.metricFetcher = metricFetcher; - } + private static final Logger LOGGER = LoggerFactory.getLogger(ClusterRestHandler.class); + + private static final String TOTAL_CONTAINER_NUM = "totalContainers"; + private static final String TOTAL_DRIVER_NUM = "totalDrivers"; + private static final String ACTIVE_CONTAINER_NUM = "activeContainers"; + private static final String ACTIVE_DRIVER_NUM = "activeDrivers"; + private static final String TOTAL_WORKER_NUM = "totalWorkers"; + private static final String USED_WORKER_NUM = "usedWorkers"; + private static final String AVAILABLE_WORKER_NUM = "availableWorkers"; + private static final String PENDING_WORKER_NUM = "pendingWorkers"; + + private static final String CONTAINER_ID_KEY = "id"; + private static final String CONTAINER_NAME_KEY = "name"; + private static final String CONTAINER_HOST_KEY = "host"; + private static final String PROCESS_ID_KEY = "pid"; + private static final String AGENT_PORT_KEY = "agentPort"; + private static final String PROCESS_METRICS_KEY = "metrics"; + private static final String LAST_UPDATE_TIME = "lastTimestamp"; + private static final String IS_ACTIVE = "isActive"; - @GET - @Path("/overview") - @Produces(MediaType.APPLICATION_JSON) - public ApiResponse> getOverview() throws IOException { - try { - metricFetcher.update(); - Set heartbeatIds = heartbeatManager.getHeartBeatMap().keySet(); - Set activeContainerIds = new HashSet<>(heartbeatIds); - Set activeDriverIds = new HashSet<>(heartbeatIds); - activeContainerIds.retainAll(heartbeatManager.getActiveContainerIds()); - activeDriverIds.retainAll(heartbeatManager.getActiveDriverIds()); - int activeContainerNum = activeContainerIds.size(); - int activeDriverNum = activeDriverIds.size(); - int totalContainerNum = clusterManager.getTotalContainers(); - int totalDriverNum = clusterManager.getTotalDrivers(); - Map ret = new HashMap<>(); - ret.put(TOTAL_CONTAINER_NUM, totalContainerNum); - ret.put(TOTAL_DRIVER_NUM, totalDriverNum); - ret.put(ACTIVE_CONTAINER_NUM, activeContainerNum); - ret.put(ACTIVE_DRIVER_NUM, activeDriverNum); - if (resourceManager instanceof DefaultResourceManager) { - ResourceMetrics metrics = - ((DefaultResourceManager) resourceManager).getResourceMetrics(); - ret.put(TOTAL_WORKER_NUM, metrics.getTotalWorkers()); - ret.put(AVAILABLE_WORKER_NUM, metrics.getAvailableWorkers()); - ret.put(PENDING_WORKER_NUM, metrics.getPendingWorkers()); - ret.put(USED_WORKER_NUM, metrics.getTotalWorkers() - metrics.getAvailableWorkers()); - } - return ApiResponse.success(ret); - } catch (Throwable t) { - LOGGER.error("Query overview info failed. {}", t.getMessage(), t); - return ApiResponse.error(t); - } + private final AbstractClusterManager clusterManager; + private final HeartbeatManager heartbeatManager; + private final IResourceManager resourceManager; + private final MetricFetcher metricFetcher; + public ClusterRestHandler( + IClusterManager clusterManager, + HeartbeatManager heartbeatManager, + IResourceManager resourceManager, + MetricFetcher metricFetcher) { + this.clusterManager = (AbstractClusterManager) clusterManager; + this.heartbeatManager = heartbeatManager; + this.resourceManager = resourceManager; + this.metricFetcher = metricFetcher; + } + + @GET + @Path("/overview") + @Produces(MediaType.APPLICATION_JSON) + public ApiResponse> getOverview() throws IOException { + try { + metricFetcher.update(); + Set heartbeatIds = heartbeatManager.getHeartBeatMap().keySet(); + Set activeContainerIds = new HashSet<>(heartbeatIds); + Set activeDriverIds = new HashSet<>(heartbeatIds); + activeContainerIds.retainAll(heartbeatManager.getActiveContainerIds()); + activeDriverIds.retainAll(heartbeatManager.getActiveDriverIds()); + int activeContainerNum = activeContainerIds.size(); + int activeDriverNum = activeDriverIds.size(); + int totalContainerNum = clusterManager.getTotalContainers(); + int totalDriverNum = clusterManager.getTotalDrivers(); + Map ret = new HashMap<>(); + ret.put(TOTAL_CONTAINER_NUM, totalContainerNum); + ret.put(TOTAL_DRIVER_NUM, totalDriverNum); + ret.put(ACTIVE_CONTAINER_NUM, activeContainerNum); + ret.put(ACTIVE_DRIVER_NUM, activeDriverNum); + if (resourceManager instanceof DefaultResourceManager) { + ResourceMetrics metrics = ((DefaultResourceManager) resourceManager).getResourceMetrics(); + ret.put(TOTAL_WORKER_NUM, metrics.getTotalWorkers()); + ret.put(AVAILABLE_WORKER_NUM, metrics.getAvailableWorkers()); + ret.put(PENDING_WORKER_NUM, metrics.getPendingWorkers()); + ret.put(USED_WORKER_NUM, metrics.getTotalWorkers() - metrics.getAvailableWorkers()); + } + return ApiResponse.success(ret); + } catch (Throwable t) { + LOGGER.error("Query overview info failed. {}", t.getMessage(), t); + return ApiResponse.error(t); } + } - @GET - @Path("/containers") - @Produces(MediaType.APPLICATION_JSON) - public ApiResponse> getContainers() throws IOException { - try { - metricFetcher.update(); - Map containerMap = clusterManager.getContainerInfos(); - Map heartbeatMap = heartbeatManager.getHeartBeatMap(); - Set activeContainerIds = heartbeatManager.getActiveContainerIds(); - return ApiResponse.success(buildDetailInfo(containerMap, heartbeatMap, activeContainerIds)); - } catch (Throwable t) { - LOGGER.error("Query containers failed. {}", t.getMessage(), t); - return ApiResponse.error(t); - } + @GET + @Path("/containers") + @Produces(MediaType.APPLICATION_JSON) + public ApiResponse> getContainers() throws IOException { + try { + metricFetcher.update(); + Map containerMap = clusterManager.getContainerInfos(); + Map heartbeatMap = heartbeatManager.getHeartBeatMap(); + Set activeContainerIds = heartbeatManager.getActiveContainerIds(); + return ApiResponse.success(buildDetailInfo(containerMap, heartbeatMap, activeContainerIds)); + } catch (Throwable t) { + LOGGER.error("Query containers failed. {}", t.getMessage(), t); + return ApiResponse.error(t); } + } + + @GET + @Path("/drivers") + @Produces(MediaType.APPLICATION_JSON) + public ApiResponse> getDrivers() throws IOException { - @GET - @Path("/drivers") - @Produces(MediaType.APPLICATION_JSON) - public ApiResponse> getDrivers() throws IOException { - - try { - metricFetcher.update(); - Map driverMap = clusterManager.getDriverInfos(); - Map heartbeatMap = heartbeatManager.getHeartBeatMap(); - Set activeDriverIds = heartbeatManager.getActiveDriverIds(); - return ApiResponse.success(buildDetailInfo(driverMap, heartbeatMap, activeDriverIds)); - } catch (Throwable t) { - LOGGER.error("Query drivers failed. {}", t.getMessage(), t); - return ApiResponse.error(t); - } + try { + metricFetcher.update(); + Map driverMap = clusterManager.getDriverInfos(); + Map heartbeatMap = heartbeatManager.getHeartBeatMap(); + Set activeDriverIds = heartbeatManager.getActiveDriverIds(); + return ApiResponse.success(buildDetailInfo(driverMap, heartbeatMap, activeDriverIds)); + } catch (Throwable t) { + LOGGER.error("Query drivers failed. {}", t.getMessage(), t); + return ApiResponse.error(t); } + } - private List buildDetailInfo(Map componentMap, - Map heartbeatMap, - Set activeComponentIds) { - List result = new ArrayList<>(); - for (Map.Entry entry : componentMap.entrySet()) { - Integer componentId = entry.getKey(); - JSONObject containerObj = new JSONObject(); - containerObj.put(CONTAINER_ID_KEY, componentId); - ComponentInfo info = entry.getValue(); - containerObj.put(CONTAINER_NAME_KEY, info.getName()); - containerObj.put(CONTAINER_HOST_KEY, info.getHost()); - containerObj.put(AGENT_PORT_KEY, info.getAgentPort()); - containerObj.put(PROCESS_ID_KEY, info.getPid()); - Heartbeat heartbeat = heartbeatMap.get(componentId); - if (heartbeat != null) { - containerObj.put(LAST_UPDATE_TIME, heartbeat.getTimestamp()); - containerObj.put(PROCESS_METRICS_KEY, heartbeat.getProcessMetrics()); - } - containerObj.put(IS_ACTIVE, activeComponentIds.contains(componentId)); - result.add(containerObj); - } - return result; + private List buildDetailInfo( + Map componentMap, + Map heartbeatMap, + Set activeComponentIds) { + List result = new ArrayList<>(); + for (Map.Entry entry : componentMap.entrySet()) { + Integer componentId = entry.getKey(); + JSONObject containerObj = new JSONObject(); + containerObj.put(CONTAINER_ID_KEY, componentId); + ComponentInfo info = entry.getValue(); + containerObj.put(CONTAINER_NAME_KEY, info.getName()); + containerObj.put(CONTAINER_HOST_KEY, info.getHost()); + containerObj.put(AGENT_PORT_KEY, info.getAgentPort()); + containerObj.put(PROCESS_ID_KEY, info.getPid()); + Heartbeat heartbeat = heartbeatMap.get(componentId); + if (heartbeat != null) { + containerObj.put(LAST_UPDATE_TIME, heartbeat.getTimestamp()); + containerObj.put(PROCESS_METRICS_KEY, heartbeat.getProcessMetrics()); + } + containerObj.put(IS_ACTIVE, activeComponentIds.contains(componentId)); + result.add(containerObj); } + return result; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/handler/MasterRestHandler.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/handler/MasterRestHandler.java index d113470b1..a0e0cf221 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/handler/MasterRestHandler.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/handler/MasterRestHandler.java @@ -21,10 +21,12 @@ import java.io.Serializable; import java.util.Map; + import javax.ws.rs.GET; import javax.ws.rs.Path; import javax.ws.rs.Produces; import javax.ws.rs.core.MediaType; + import org.apache.geaflow.cluster.common.ComponentInfo; import org.apache.geaflow.cluster.web.api.ApiResponse; import org.apache.geaflow.common.config.Configuration; @@ -36,50 +38,50 @@ @Path("/master") public class MasterRestHandler implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(MasterRestHandler.class); + private static final Logger LOGGER = LoggerFactory.getLogger(MasterRestHandler.class); - private final Configuration configuration; - private final ComponentInfo componentInfo; + private final Configuration configuration; + private final ComponentInfo componentInfo; - public MasterRestHandler(ComponentInfo componentInfo, Configuration configuration) { - this.configuration = configuration; - this.componentInfo = componentInfo; - } + public MasterRestHandler(ComponentInfo componentInfo, Configuration configuration) { + this.configuration = configuration; + this.componentInfo = componentInfo; + } - @GET - @Path("/configuration") - @Produces(MediaType.APPLICATION_JSON) - public ApiResponse> queryConfiguration() { - try { - return ApiResponse.success(configuration.getConfigMap()); - } catch (Throwable t) { - LOGGER.error("Query master configuration failed. {}", t.getMessage(), t); - return ApiResponse.error(t); - } + @GET + @Path("/configuration") + @Produces(MediaType.APPLICATION_JSON) + public ApiResponse> queryConfiguration() { + try { + return ApiResponse.success(configuration.getConfigMap()); + } catch (Throwable t) { + LOGGER.error("Query master configuration failed. {}", t.getMessage(), t); + return ApiResponse.error(t); } + } - @GET - @Path("/metrics") - @Produces(MediaType.APPLICATION_JSON) - public ApiResponse queryProcessMetrics() { - try { - return ApiResponse.success(StatsCollectorFactory.init(configuration).getProcessStatsCollector().collect()); - } catch (Throwable t) { - LOGGER.error("Query master process metrics failed. {}", t.getMessage(), t); - return ApiResponse.error(t); - } + @GET + @Path("/metrics") + @Produces(MediaType.APPLICATION_JSON) + public ApiResponse queryProcessMetrics() { + try { + return ApiResponse.success( + StatsCollectorFactory.init(configuration).getProcessStatsCollector().collect()); + } catch (Throwable t) { + LOGGER.error("Query master process metrics failed. {}", t.getMessage(), t); + return ApiResponse.error(t); } + } - @GET - @Path("/info") - @Produces(MediaType.APPLICATION_JSON) - public ApiResponse queryMasterInfo() { - try { - return ApiResponse.success(componentInfo); - } catch (Throwable t) { - LOGGER.error("Query master process metrics failed. {}", t.getMessage(), t); - return ApiResponse.error(t); - } + @GET + @Path("/info") + @Produces(MediaType.APPLICATION_JSON) + public ApiResponse queryMasterInfo() { + try { + return ApiResponse.success(componentInfo); + } catch (Throwable t) { + LOGGER.error("Query master process metrics failed. {}", t.getMessage(), t); + return ApiResponse.error(t); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/handler/PipelineRestHandler.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/handler/PipelineRestHandler.java index 10af91015..1353527ca 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/handler/PipelineRestHandler.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/handler/PipelineRestHandler.java @@ -24,11 +24,13 @@ import java.util.Collection; import java.util.Collections; import java.util.List; + import javax.ws.rs.GET; import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.core.MediaType; + import org.apache.geaflow.cluster.web.api.ApiResponse; import org.apache.geaflow.cluster.web.metrics.MetricFetcher; import org.apache.geaflow.common.metric.CycleMetrics; @@ -41,51 +43,51 @@ @Path("/pipelines") public class PipelineRestHandler implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(PipelineRestHandler.class); + private static final Logger LOGGER = LoggerFactory.getLogger(PipelineRestHandler.class); - private final MetricCache metricCache; - private final MetricFetcher metricFetcher; + private final MetricCache metricCache; + private final MetricFetcher metricFetcher; - public PipelineRestHandler(MetricCache metricCache, MetricFetcher metricFetcher) { - this.metricCache = metricCache; - this.metricFetcher = metricFetcher; - } + public PipelineRestHandler(MetricCache metricCache, MetricFetcher metricFetcher) { + this.metricCache = metricCache; + this.metricFetcher = metricFetcher; + } - @GET - @Path("/") - @Produces(MediaType.APPLICATION_JSON) - public ApiResponse> queryPipelineList() { - try { - metricFetcher.update(); - List list = new ArrayList<>(); - for (PipelineMetricCache cache : metricCache.getPipelineMetricCaches().values()) { - if (cache.getPipelineMetrics() != null) { - list.add(cache.getPipelineMetrics()); - } - } - return ApiResponse.success(list); - } catch (Throwable t) { - LOGGER.error("Query pipeline list failed. {}", t.getMessage(), t); - return ApiResponse.error(t); + @GET + @Path("/") + @Produces(MediaType.APPLICATION_JSON) + public ApiResponse> queryPipelineList() { + try { + metricFetcher.update(); + List list = new ArrayList<>(); + for (PipelineMetricCache cache : metricCache.getPipelineMetricCaches().values()) { + if (cache.getPipelineMetrics() != null) { + list.add(cache.getPipelineMetrics()); } + } + return ApiResponse.success(list); + } catch (Throwable t) { + LOGGER.error("Query pipeline list failed. {}", t.getMessage(), t); + return ApiResponse.error(t); } + } - @GET - @Path("/{pipelineName}/cycles") - @Produces(MediaType.APPLICATION_JSON) - public ApiResponse> queryCycleList(@PathParam("pipelineName") String pipelineName) { - try { - metricFetcher.update(); - PipelineMetricCache cache = metricCache.getPipelineMetricCaches().get(pipelineName); - if (cache == null) { - return ApiResponse.success(Collections.EMPTY_LIST); - } - return ApiResponse.success(cache.getCycleMetricList().values()); - } catch (Throwable t) { - LOGGER.error("Query cycle metric list of pipeline {} failed. {}", pipelineName, - t.getMessage(), t); - return ApiResponse.error(t); - } + @GET + @Path("/{pipelineName}/cycles") + @Produces(MediaType.APPLICATION_JSON) + public ApiResponse> queryCycleList( + @PathParam("pipelineName") String pipelineName) { + try { + metricFetcher.update(); + PipelineMetricCache cache = metricCache.getPipelineMetricCaches().get(pipelineName); + if (cache == null) { + return ApiResponse.success(Collections.EMPTY_LIST); + } + return ApiResponse.success(cache.getCycleMetricList().values()); + } catch (Throwable t) { + LOGGER.error( + "Query cycle metric list of pipeline {} failed. {}", pipelineName, t.getMessage(), t); + return ApiResponse.error(t); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/handler/ProxyHandler.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/handler/ProxyHandler.java index 49606ee53..b6883245f 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/handler/ProxyHandler.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/handler/ProxyHandler.java @@ -20,29 +20,31 @@ package org.apache.geaflow.cluster.web.handler; import java.util.Arrays; + import javax.servlet.http.HttpServletRequest; + import org.eclipse.jetty.proxy.ProxyServlet; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class ProxyHandler extends ProxyServlet { - private static final Logger LOGGER = LoggerFactory.getLogger(ProxyHandler.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ProxyHandler.class); - @Override - protected String rewriteTarget(HttpServletRequest request) { - return getTargetUrl(request); - } + @Override + protected String rewriteTarget(HttpServletRequest request) { + return getTargetUrl(request); + } - private String getTargetUrl(HttpServletRequest request) { - String path = request.getRequestURI(); - String[] pathParts = path.split("/"); - String fullUri = String.join("/", Arrays.copyOfRange(pathParts, 2, pathParts.length)); - StringBuilder target = new StringBuilder(); - target.append("http://").append(fullUri); - if (request.getQueryString() != null) { - target.append("?").append(request.getQueryString()); - } - return target.toString(); + private String getTargetUrl(HttpServletRequest request) { + String path = request.getRequestURI(); + String[] pathParts = path.split("/"); + String fullUri = String.join("/", Arrays.copyOfRange(pathParts, 2, pathParts.length)); + StringBuilder target = new StringBuilder(); + target.append("http://").append(fullUri); + if (request.getQueryString() != null) { + target.append("?").append(request.getQueryString()); } + return target.toString(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/metrics/MetricFetcher.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/metrics/MetricFetcher.java index 1030678b9..0f2caa083 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/metrics/MetricFetcher.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/metrics/MetricFetcher.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicInteger; + import org.apache.geaflow.cluster.clustermanager.AbstractClusterManager; import org.apache.geaflow.cluster.clustermanager.IClusterManager; import org.apache.geaflow.cluster.rpc.RpcClient; @@ -37,55 +38,58 @@ import org.slf4j.LoggerFactory; public class MetricFetcher implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(MetricFetcher.class); - private static final int DEFAULT_TIMEOUT = 10000; + private static final Logger LOGGER = LoggerFactory.getLogger(MetricFetcher.class); + private static final int DEFAULT_TIMEOUT = 10000; - private final Map driverIds; - private final MetricCache metricCache; - private final int updateIntervalMs; - private long lastUpdateTime; + private final Map driverIds; + private final MetricCache metricCache; + private final int updateIntervalMs; + private long lastUpdateTime; - public MetricFetcher(Configuration configuration, IClusterManager clusterManager, - MetricCache metricCache) { - this.driverIds = ((AbstractClusterManager) clusterManager).getDriverIds(); - this.metricCache = metricCache; - this.updateIntervalMs = DEFAULT_TIMEOUT; - RpcClient.init(configuration); - } + public MetricFetcher( + Configuration configuration, IClusterManager clusterManager, MetricCache metricCache) { + this.driverIds = ((AbstractClusterManager) clusterManager).getDriverIds(); + this.metricCache = metricCache; + this.updateIntervalMs = DEFAULT_TIMEOUT; + RpcClient.init(configuration); + } - public synchronized void update() { - long currentTime = System.currentTimeMillis(); - if (lastUpdateTime + updateIntervalMs <= currentTime) { - lastUpdateTime = currentTime; - fetch(); - } + public synchronized void update() { + long currentTime = System.currentTimeMillis(); + if (lastUpdateTime + updateIntervalMs <= currentTime) { + lastUpdateTime = currentTime; + fetch(); } + } - private void fetch() { - MetricQueryRequest request = MetricQueryRequest.newBuilder().build(); - Map futureList = new HashMap<>(); - MetricCache newMetricCache = new MetricCache(); - AtomicInteger count = new AtomicInteger(driverIds.values().size()); - for (String driverId : driverIds.values()) { - Future responseFuture = RpcClient.getInstance() - .requestMetrics(driverId, request, new RpcCallback() { + private void fetch() { + MetricQueryRequest request = MetricQueryRequest.newBuilder().build(); + Map futureList = new HashMap<>(); + MetricCache newMetricCache = new MetricCache(); + AtomicInteger count = new AtomicInteger(driverIds.values().size()); + for (String driverId : driverIds.values()) { + Future responseFuture = + RpcClient.getInstance() + .requestMetrics( + driverId, + request, + new RpcCallback() { @Override public void onSuccess(MetricQueryResponse value) { - MetricCache cache = RpcMessageEncoder.decode(value.getPayload()); - newMetricCache.mergeMetricCache(cache); - if (count.decrementAndGet() == 0) { - metricCache.clearAll(); - metricCache.mergeMetricCache(newMetricCache); - } + MetricCache cache = RpcMessageEncoder.decode(value.getPayload()); + newMetricCache.mergeMetricCache(cache); + if (count.decrementAndGet() == 0) { + metricCache.clearAll(); + metricCache.mergeMetricCache(newMetricCache); + } } @Override public void onFailure(Throwable t) { - LOGGER.warn("fail to fetch metric from " + driverId, t); + LOGGER.warn("fail to fetch metric from " + driverId, t); } - }); - futureList.put(driverId, responseFuture); - } + }); + futureList.put(driverId, responseFuture); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/metrics/MetricServer.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/metrics/MetricServer.java index 35399da1f..fdf34b18b 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/metrics/MetricServer.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/metrics/MetricServer.java @@ -22,8 +22,8 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.HTTP_REST_SERVICE_ENABLE; import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.METRIC_SERVICE_PORT; -import com.baidu.brpc.server.RpcServerOptions; import java.io.Serializable; + import org.apache.geaflow.cluster.rpc.RpcService; import org.apache.geaflow.cluster.rpc.impl.MetricEndpoint; import org.apache.geaflow.cluster.rpc.impl.RpcServiceImpl; @@ -33,45 +33,46 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.baidu.brpc.server.RpcServerOptions; + public class MetricServer implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(MetricServer.class); + private static final Logger LOGGER = LoggerFactory.getLogger(MetricServer.class); - private final int port; - private RpcService rpcService; + private final int port; + private RpcService rpcService; - public MetricServer(Configuration configuration) { - this.port = configuration.getInteger(METRIC_SERVICE_PORT); - if (configuration.getBoolean(HTTP_REST_SERVICE_ENABLE)) { - RpcServerOptions serverOptions = getServerOptions(configuration); - RpcServiceImpl rpcService = new RpcServiceImpl(PortUtil.getPort(port), serverOptions); - rpcService.addEndpoint(new MetricEndpoint(configuration)); - this.rpcService = rpcService; - } + public MetricServer(Configuration configuration) { + this.port = configuration.getInteger(METRIC_SERVICE_PORT); + if (configuration.getBoolean(HTTP_REST_SERVICE_ENABLE)) { + RpcServerOptions serverOptions = getServerOptions(configuration); + RpcServiceImpl rpcService = new RpcServiceImpl(PortUtil.getPort(port), serverOptions); + rpcService.addEndpoint(new MetricEndpoint(configuration)); + this.rpcService = rpcService; } + } - private RpcServerOptions getServerOptions(Configuration configuration) { - RpcServerOptions serverOptions = ConfigurableServerOption.build(configuration); - serverOptions.setGlobalThreadPoolSharing(false); - serverOptions.setIoThreadNum(1); - serverOptions.setWorkThreadNum(2); - return serverOptions; - } + private RpcServerOptions getServerOptions(Configuration configuration) { + RpcServerOptions serverOptions = ConfigurableServerOption.build(configuration); + serverOptions.setGlobalThreadPoolSharing(false); + serverOptions.setIoThreadNum(1); + serverOptions.setWorkThreadNum(2); + return serverOptions; + } - public int start() { - if (rpcService != null) { - int metricPort = rpcService.startService(); - LOGGER.info("started metric service on port:{}", metricPort); - return metricPort; - } else { - return port; - } + public int start() { + if (rpcService != null) { + int metricPort = rpcService.startService(); + LOGGER.info("started metric service on port:{}", metricPort); + return metricPort; + } else { + return port; } + } - public void stop() { - if (rpcService != null) { - LOGGER.info("stopping metric query service"); - rpcService.stopService(); - } + public void stop() { + if (rpcService != null) { + LOGGER.info("stopping metric query service"); + rpcService.stopService(); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/metrics/ResourceMetrics.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/metrics/ResourceMetrics.java index 3a4f3ca4a..01e68cbcf 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/metrics/ResourceMetrics.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/web/metrics/ResourceMetrics.java @@ -23,32 +23,31 @@ public class ResourceMetrics implements Serializable { - private int totalWorkers; - private int availableWorkers; - private int pendingWorkers; + private int totalWorkers; + private int availableWorkers; + private int pendingWorkers; - public int getTotalWorkers() { - return totalWorkers; - } + public int getTotalWorkers() { + return totalWorkers; + } - public void setTotalWorkers(int totalWorkers) { - this.totalWorkers = totalWorkers; - } + public void setTotalWorkers(int totalWorkers) { + this.totalWorkers = totalWorkers; + } - public int getAvailableWorkers() { - return availableWorkers; - } + public int getAvailableWorkers() { + return availableWorkers; + } - public void setAvailableWorkers(int availableWorkers) { - this.availableWorkers = availableWorkers; - } + public void setAvailableWorkers(int availableWorkers) { + this.availableWorkers = availableWorkers; + } - public int getPendingWorkers() { - return pendingWorkers; - } - - public void setPendingWorkers(int pendingWorkers) { - this.pendingWorkers = pendingWorkers; - } + public int getPendingWorkers() { + return pendingWorkers; + } + public void setPendingWorkers(int pendingWorkers) { + this.pendingWorkers = pendingWorkers; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/Dispatcher.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/Dispatcher.java index f396ccb80..a00e3e0b6 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/Dispatcher.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/Dispatcher.java @@ -29,32 +29,30 @@ public class Dispatcher extends AbstractTaskRunner { - private static final Logger LOGGER = LoggerFactory.getLogger(Dispatcher.class); + private static final Logger LOGGER = LoggerFactory.getLogger(Dispatcher.class); - private TaskService taskService; + private TaskService taskService; - public Dispatcher(TaskService taskService) { - super(); - this.taskService = taskService; - } - - @Override - protected void process(ICommand command) { - switch (command.getEventType()) { - case COMPOSE: - for (IEvent event : ((IComposeEvent) command).getEventList()) { - process((ICommand) event); - } - break; - case INTERRUPT_TASK: - LOGGER.info("{} interrupt current running task", command.getWorkerId()); - this.taskService.interrupt(command.getWorkerId()); - break; - default: - this.taskService.process(command.getWorkerId(), command); - break; + public Dispatcher(TaskService taskService) { + super(); + this.taskService = taskService; + } + @Override + protected void process(ICommand command) { + switch (command.getEventType()) { + case COMPOSE: + for (IEvent event : ((IComposeEvent) command).getEventList()) { + process((ICommand) event); } + break; + case INTERRUPT_TASK: + LOGGER.info("{} interrupt current running task", command.getWorkerId()); + this.taskService.interrupt(command.getWorkerId()); + break; + default: + this.taskService.process(command.getWorkerId(), command); + break; } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/DispatcherService.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/DispatcherService.java index beb7c9391..c32dbc82b 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/DispatcherService.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/DispatcherService.java @@ -25,17 +25,17 @@ public class DispatcherService extends AbstractTaskService { - private static final String MESSAGE_FORMAT = "geaflow-message-%d"; + private static final String MESSAGE_FORMAT = "geaflow-message-%d"; - private Dispatcher dispatcher; + private Dispatcher dispatcher; - public DispatcherService(Dispatcher dispatcher, Configuration configuration) { - super(configuration, MESSAGE_FORMAT); - this.dispatcher = dispatcher; - } + public DispatcherService(Dispatcher dispatcher, Configuration configuration) { + super(configuration, MESSAGE_FORMAT); + this.dispatcher = dispatcher; + } - @Override - public Dispatcher[] buildTaskRunner() { - return new Dispatcher[]{dispatcher}; - } + @Override + public Dispatcher[] buildTaskRunner() { + return new Dispatcher[] {dispatcher}; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/IAffinityWorker.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/IAffinityWorker.java index 0834be8f1..a94177641 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/IAffinityWorker.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/IAffinityWorker.java @@ -21,13 +21,9 @@ public interface IAffinityWorker { - /** - * Stash current worker context. - */ - void stash(); + /** Stash current worker context. */ + void stash(); - /** - * Pop worker context. - */ - void pop(IWorkerContext workerContext); + /** Pop worker context. */ + void pop(IWorkerContext workerContext); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/IWorker.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/IWorker.java index 20526b13b..d2a9107ee 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/IWorker.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/IWorker.java @@ -23,43 +23,27 @@ public interface IWorker extends Serializable { - /** - * Open worker processor. - */ - void open(IWorkerContext workerContext); + /** Open worker processor. */ + void open(IWorkerContext workerContext); - /** - * Init worker processor runtime info. - */ - void init(long windowId); + /** Init worker processor runtime info. */ + void init(long windowId); - /** - * Worker do processing of processor. - */ - O process(I input); + /** Worker do processing of processor. */ + O process(I input); - /** - * Worker finish processing of processor. - */ - void finish(long windowId); + /** Worker finish processing of processor. */ + void finish(long windowId); - /** - * Interrupt the processing of current window. - */ - void interrupt(); + /** Interrupt the processing of current window. */ + void interrupt(); - /** - * Release worker resource if needed. - */ - void close(); + /** Release worker resource if needed. */ + void close(); - /** - * Returns the runtime context of worker. - */ - IWorkerContext getWorkerContext(); + /** Returns the runtime context of worker. */ + IWorkerContext getWorkerContext(); - /** - * Returns the worker type. - */ - WorkerType getWorkerType(); + /** Returns the worker type. */ + WorkerType getWorkerType(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/IWorkerContext.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/IWorkerContext.java index 21b19408b..c6158154f 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/IWorkerContext.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/IWorkerContext.java @@ -20,18 +20,14 @@ package org.apache.geaflow.cluster.worker; import java.io.Serializable; + import org.apache.geaflow.cluster.protocol.IEventContext; public interface IWorkerContext extends Serializable { - /** - * Init the worker context. - */ - void init(IEventContext eventContext); - - /** - * Close runtime resource of worker. - */ - void close(); + /** Init the worker context. */ + void init(IEventContext eventContext); + /** Close runtime resource of worker. */ + void close(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/WorkerType.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/WorkerType.java index 7fd303430..b7caa7ffa 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/WorkerType.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/org/apache/geaflow/cluster/worker/WorkerType.java @@ -20,22 +20,18 @@ package org.apache.geaflow.cluster.worker; public enum WorkerType { - /** - * Aligned compute worker type. - */ - aligned_compute, + /** Aligned compute worker type. */ + aligned_compute, - /** - * Unaligned compute worker type. - */ - unaligned_compute; + /** Unaligned compute worker type. */ + unaligned_compute; - public static WorkerType getEnum(String value) { - for (WorkerType v : values()) { - if (v.name().equalsIgnoreCase(value)) { - return v; - } - } - return aligned_compute; + public static WorkerType getEnum(String value) { + for (WorkerType v : values()) { + if (v.name().equalsIgnoreCase(value)) { + return v; + } } + return aligned_compute; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/client/PipelineResultTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/client/PipelineResultTest.java index 90371a26d..75c78b93d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/client/PipelineResultTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/client/PipelineResultTest.java @@ -20,6 +20,7 @@ package org.apache.geaflow.cluster.client; import java.util.concurrent.CompletableFuture; + import org.apache.geaflow.common.encoder.RpcMessageEncoder; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.rpc.proto.Driver; @@ -29,21 +30,22 @@ public class PipelineResultTest { - @Test - public void testResult() { - CompletableFuture res = new CompletableFuture<>(); - res.complete(Driver.PipelineRes.newBuilder().setPayload(RpcMessageEncoder.encode("test")).build()); - PipelineResult result = new PipelineResult(res); - Assert.assertEquals("test", result.get()); - } - - @Test(expectedExceptions = GeaflowRuntimeException.class, - expectedExceptionsMessageRegExp = ".*get pipeline result error.*") - public void testNotHasResult() { - CompletableFuture res = new CompletableFuture<>(); - res.completeExceptionally(new Throwable()); - PipelineResult result = new PipelineResult(res); - Assert.assertTrue(result.isSuccess()); - } + @Test + public void testResult() { + CompletableFuture res = new CompletableFuture<>(); + res.complete( + Driver.PipelineRes.newBuilder().setPayload(RpcMessageEncoder.encode("test")).build()); + PipelineResult result = new PipelineResult(res); + Assert.assertEquals("test", result.get()); + } + @Test( + expectedExceptions = GeaflowRuntimeException.class, + expectedExceptionsMessageRegExp = ".*get pipeline result error.*") + public void testNotHasResult() { + CompletableFuture res = new CompletableFuture<>(); + res.completeExceptionally(new Throwable()); + PipelineResult result = new PipelineResult(res); + Assert.assertTrue(result.isSuccess()); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/client/callback/RestClusterStartedCallbackTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/client/callback/RestClusterStartedCallbackTest.java index 8619f2c14..9df66622e 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/client/callback/RestClusterStartedCallbackTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/client/callback/RestClusterStartedCallbackTest.java @@ -19,15 +19,12 @@ package org.apache.geaflow.cluster.client.callback; -import com.google.gson.Gson; import java.io.IOException; import java.net.URI; import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.Map; -import okhttp3.mockwebserver.MockResponse; -import okhttp3.mockwebserver.MockWebServer; -import okhttp3.mockwebserver.RecordedRequest; + import org.apache.geaflow.cluster.client.callback.ClusterStartedCallback.ClusterMeta; import org.apache.geaflow.cluster.rpc.ConnectAddress; import org.apache.geaflow.common.config.Configuration; @@ -36,52 +33,59 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import com.google.gson.Gson; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; + public class RestClusterStartedCallbackTest { - MockWebServer server; - String baseUrl; + MockWebServer server; + String baseUrl; - @BeforeClass - public void prepare() throws IOException { - // Create a MockWebServer. - server = new MockWebServer(); - // Schedule some responses. - server.enqueue(new MockResponse().setBody("{key:value,success:true}")); - server.enqueue(new MockResponse().setBody("{success:true}")); - // Start the server. - server.start(); - baseUrl = "http://" + server.getHostName() + ":" + server.getPort(); - } + @BeforeClass + public void prepare() throws IOException { + // Create a MockWebServer. + server = new MockWebServer(); + // Schedule some responses. + server.enqueue(new MockResponse().setBody("{key:value,success:true}")); + server.enqueue(new MockResponse().setBody("{success:true}")); + // Start the server. + server.start(); + baseUrl = "http://" + server.getHostName() + ":" + server.getPort(); + } - @AfterClass - public void tearUp() throws IOException { - // Shut down the server. Instances cannot be reused. - server.shutdown(); - } + @AfterClass + public void tearUp() throws IOException { + // Shut down the server. Instances cannot be reused. + server.shutdown(); + } - @Test - public void test() throws InterruptedException { - // Ask the server for its URL. You'll need this to make HTTP requests. - Configuration configuration = new Configuration(); - String url = URI.create(baseUrl).resolve("/v1/cluster").toString(); - RestClusterStartedCallback callback = new RestClusterStartedCallback(configuration, url); - Map addressList = new HashMap<>(); - addressList.put("1", new ConnectAddress()); - ClusterMeta clusterMeta = new ClusterMeta(addressList, "master1"); - callback.onSuccess(clusterMeta); + @Test + public void test() throws InterruptedException { + // Ask the server for its URL. You'll need this to make HTTP requests. + Configuration configuration = new Configuration(); + String url = URI.create(baseUrl).resolve("/v1/cluster").toString(); + RestClusterStartedCallback callback = new RestClusterStartedCallback(configuration, url); + Map addressList = new HashMap<>(); + addressList.put("1", new ConnectAddress()); + ClusterMeta clusterMeta = new ClusterMeta(addressList, "master1"); + callback.onSuccess(clusterMeta); - // confirm that your app made the HTTP requests you were expecting. - RecordedRequest request1 = server.takeRequest(); - Assert.assertEquals("/v1/cluster", request1.getPath()); - HttpRequest result1 = new Gson() + // confirm that your app made the HTTP requests you were expecting. + RecordedRequest request1 = server.takeRequest(); + Assert.assertEquals("/v1/cluster", request1.getPath()); + HttpRequest result1 = + new Gson() .fromJson(request1.getBody().readString(StandardCharsets.UTF_8), HttpRequest.class); - Assert.assertTrue(result1.isSuccess()); + Assert.assertTrue(result1.isSuccess()); - callback.onFailure(new RuntimeException("error")); - RecordedRequest request2 = server.takeRequest(); - HttpRequest result2 = new Gson() + callback.onFailure(new RuntimeException("error")); + RecordedRequest request2 = server.takeRequest(); + HttpRequest result2 = + new Gson() .fromJson(request2.getBody().readString(StandardCharsets.UTF_8), HttpRequest.class); - Assert.assertFalse(result2.isSuccess()); - } - + Assert.assertFalse(result2.isSuccess()); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/client/callback/RestJobOperatorCallbackTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/client/callback/RestJobOperatorCallbackTest.java index 9f29174ed..c7ece431c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/client/callback/RestJobOperatorCallbackTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/client/callback/RestJobOperatorCallbackTest.java @@ -21,59 +21,61 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.JOB_UNIQUE_ID; -import com.google.gson.Gson; import java.io.IOException; import java.nio.charset.StandardCharsets; -import java.util.HashMap; -import java.util.Map; -import okhttp3.mockwebserver.MockResponse; -import okhttp3.mockwebserver.MockWebServer; -import okhttp3.mockwebserver.RecordedRequest; -import org.apache.geaflow.cluster.rpc.ConnectAddress; + import org.apache.geaflow.common.config.Configuration; import org.testng.Assert; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -public class RestJobOperatorCallbackTest { +import com.google.gson.Gson; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; - MockWebServer server; - String baseUrl; +public class RestJobOperatorCallbackTest { - @BeforeClass - public void prepare() throws IOException { - // Create a MockWebServer. - server = new MockWebServer(); - // Schedule some responses. - server.enqueue(new MockResponse().setBody("{key:value,success:true}")); - server.enqueue(new MockResponse().setBody("{success:true}")); - // Start the server. - server.start(); - baseUrl = "http://" + server.getHostName() + ":" + server.getPort(); - } + MockWebServer server; + String baseUrl; - @AfterClass - public void tearUp() throws IOException { - // Shut down the server. Instances cannot be reused. - server.shutdown(); - } + @BeforeClass + public void prepare() throws IOException { + // Create a MockWebServer. + server = new MockWebServer(); + // Schedule some responses. + server.enqueue(new MockResponse().setBody("{key:value,success:true}")); + server.enqueue(new MockResponse().setBody("{success:true}")); + // Start the server. + server.start(); + baseUrl = "http://" + server.getHostName() + ":" + server.getPort(); + } - @Test - public void test() throws InterruptedException { - // Ask the server for its URL. You'll need this to make HTTP requests. - Configuration configuration = new Configuration(); - configuration.put(JOB_UNIQUE_ID, String.valueOf(0L)); - RestJobOperatorCallback callback = new RestJobOperatorCallback(configuration, baseUrl); - callback.onFinish(); + @AfterClass + public void tearUp() throws IOException { + // Shut down the server. Instances cannot be reused. + server.shutdown(); + } - // confirm that your app made the HTTP requests you were expecting. - RecordedRequest request = server.takeRequest(); - Assert.assertEquals("/api/tasks/0/operations", request.getPath()); - JobOperatorCallback.JobOperatorMeta reqBody = new Gson() - .fromJson(request.getBody().readString(StandardCharsets.UTF_8), JobOperatorCallback.JobOperatorMeta.class); - Assert.assertTrue(reqBody.isSuccess()); - Assert.assertEquals(reqBody.getAction(), "finish"); - } + @Test + public void test() throws InterruptedException { + // Ask the server for its URL. You'll need this to make HTTP requests. + Configuration configuration = new Configuration(); + configuration.put(JOB_UNIQUE_ID, String.valueOf(0L)); + RestJobOperatorCallback callback = new RestJobOperatorCallback(configuration, baseUrl); + callback.onFinish(); + // confirm that your app made the HTTP requests you were expecting. + RecordedRequest request = server.takeRequest(); + Assert.assertEquals("/api/tasks/0/operations", request.getPath()); + JobOperatorCallback.JobOperatorMeta reqBody = + new Gson() + .fromJson( + request.getBody().readString(StandardCharsets.UTF_8), + JobOperatorCallback.JobOperatorMeta.class); + Assert.assertTrue(reqBody.isSuccess()); + Assert.assertEquals(reqBody.getAction(), "finish"); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/config/ClusterConfigTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/config/ClusterConfigTest.java index 788de1677..2ae9e87d8 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/config/ClusterConfigTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/config/ClusterConfigTest.java @@ -30,27 +30,26 @@ public class ClusterConfigTest { - @Test - public void testMasterConfig() { - ClusterConfig clusterConfig = ClusterConfig.build(new Configuration()); - ClusterJvmOptions masterOptions = clusterConfig.getMasterJvmOptions(); - - Assert.assertEquals(masterOptions.getJvmOptions().size(), 4); - Assert.assertEquals(masterOptions.getExtraOptions().size(), 1); - } - - @Test - public void test() { - Configuration config = new Configuration(); - config.put(CONTAINER_MEMORY_MB.getKey(), "256"); - config.put(CONTAINER_JVM_OPTION.getKey(), "-Xmx200m,-Xms200m,-Xmn128m"); - config.put(MASTER_MEMORY_MB.getKey(), "256"); - config.put(MASTER_JVM_OPTIONS.getKey(), "-Xmx200m,-Xms200m,-Xmn128m"); - ClusterConfig clusterConfig = ClusterConfig.build(config); - Assert.assertNotNull(clusterConfig); - Assert.assertNotNull(clusterConfig.getMasterJvmOptions()); - Assert.assertNotNull(clusterConfig.getDriverJvmOptions()); - Assert.assertNotNull(clusterConfig.getContainerJvmOptions()); - } - + @Test + public void testMasterConfig() { + ClusterConfig clusterConfig = ClusterConfig.build(new Configuration()); + ClusterJvmOptions masterOptions = clusterConfig.getMasterJvmOptions(); + + Assert.assertEquals(masterOptions.getJvmOptions().size(), 4); + Assert.assertEquals(masterOptions.getExtraOptions().size(), 1); + } + + @Test + public void test() { + Configuration config = new Configuration(); + config.put(CONTAINER_MEMORY_MB.getKey(), "256"); + config.put(CONTAINER_JVM_OPTION.getKey(), "-Xmx200m,-Xms200m,-Xmn128m"); + config.put(MASTER_MEMORY_MB.getKey(), "256"); + config.put(MASTER_JVM_OPTIONS.getKey(), "-Xmx200m,-Xms200m,-Xmn128m"); + ClusterConfig clusterConfig = ClusterConfig.build(config); + Assert.assertNotNull(clusterConfig); + Assert.assertNotNull(clusterConfig.getMasterJvmOptions()); + Assert.assertNotNull(clusterConfig.getDriverJvmOptions()); + Assert.assertNotNull(clusterConfig.getContainerJvmOptions()); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/constants/ClusterConstantsTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/constants/ClusterConstantsTest.java index 1bc2cfad1..c2c63c689 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/constants/ClusterConstantsTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/constants/ClusterConstantsTest.java @@ -24,51 +24,50 @@ public class ClusterConstantsTest { - @Test - public void testDefaultValues() { - // Test default values - Assert.assertEquals(ClusterConstants.getMasterName(), "master-0"); - Assert.assertEquals(ClusterConstants.getDriverName(1), "driver-1"); - Assert.assertEquals(ClusterConstants.getContainerName(2), "container-2"); - - Assert.assertEquals(ClusterConstants.MASTER_LOG_SUFFIX, "master.log"); - Assert.assertEquals(ClusterConstants.DRIVER_LOG_SUFFIX, "driver.log"); - Assert.assertEquals(ClusterConstants.CONTAINER_LOG_SUFFIX, "container.log"); - - Assert.assertEquals(ClusterConstants.DEFAULT_MASTER_ID, 0); - } + @Test + public void testDefaultValues() { + // Test default values + Assert.assertEquals(ClusterConstants.getMasterName(), "master-0"); + Assert.assertEquals(ClusterConstants.getDriverName(1), "driver-1"); + Assert.assertEquals(ClusterConstants.getContainerName(2), "container-2"); - @Test - public void testGetMasterName() { - String masterName = ClusterConstants.getMasterName(); - Assert.assertEquals(masterName, "master-0"); - } + Assert.assertEquals(ClusterConstants.MASTER_LOG_SUFFIX, "master.log"); + Assert.assertEquals(ClusterConstants.DRIVER_LOG_SUFFIX, "driver.log"); + Assert.assertEquals(ClusterConstants.CONTAINER_LOG_SUFFIX, "container.log"); - @Test - public void testGetDriverName() { - Assert.assertEquals(ClusterConstants.getDriverName(0), "driver-0"); - Assert.assertEquals(ClusterConstants.getDriverName(1), "driver-1"); - Assert.assertEquals(ClusterConstants.getDriverName(10), "driver-10"); - } + Assert.assertEquals(ClusterConstants.DEFAULT_MASTER_ID, 0); + } - @Test - public void testGetContainerName() { - Assert.assertEquals(ClusterConstants.getContainerName(0), "container-0"); - Assert.assertEquals(ClusterConstants.getContainerName(1), "container-1"); - Assert.assertEquals(ClusterConstants.getContainerName(100), "container-100"); - } + @Test + public void testGetMasterName() { + String masterName = ClusterConstants.getMasterName(); + Assert.assertEquals(masterName, "master-0"); + } - @Test - public void testConstants() { - // Test all constants are properly defined - Assert.assertNotNull(ClusterConstants.MASTER_LOG_SUFFIX); - Assert.assertNotNull(ClusterConstants.DRIVER_LOG_SUFFIX); - Assert.assertNotNull(ClusterConstants.CONTAINER_LOG_SUFFIX); - Assert.assertNotNull(ClusterConstants.CLUSTER_TYPE); - Assert.assertNotNull(ClusterConstants.LOCAL_CLUSTER); - Assert.assertNotNull(ClusterConstants.MASTER_ID); - Assert.assertNotNull(ClusterConstants.CONTAINER_ID); - Assert.assertNotNull(ClusterConstants.CONTAINER_INDEX); - } -} + @Test + public void testGetDriverName() { + Assert.assertEquals(ClusterConstants.getDriverName(0), "driver-0"); + Assert.assertEquals(ClusterConstants.getDriverName(1), "driver-1"); + Assert.assertEquals(ClusterConstants.getDriverName(10), "driver-10"); + } + + @Test + public void testGetContainerName() { + Assert.assertEquals(ClusterConstants.getContainerName(0), "container-0"); + Assert.assertEquals(ClusterConstants.getContainerName(1), "container-1"); + Assert.assertEquals(ClusterConstants.getContainerName(100), "container-100"); + } + @Test + public void testConstants() { + // Test all constants are properly defined + Assert.assertNotNull(ClusterConstants.MASTER_LOG_SUFFIX); + Assert.assertNotNull(ClusterConstants.DRIVER_LOG_SUFFIX); + Assert.assertNotNull(ClusterConstants.CONTAINER_LOG_SUFFIX); + Assert.assertNotNull(ClusterConstants.CLUSTER_TYPE); + Assert.assertNotNull(ClusterConstants.LOCAL_CLUSTER); + Assert.assertNotNull(ClusterConstants.MASTER_ID); + Assert.assertNotNull(ClusterConstants.CONTAINER_ID); + Assert.assertNotNull(ClusterConstants.CONTAINER_INDEX); + } +} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/container/ContainerContextTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/container/ContainerContextTest.java index a0a0abe67..0c830af73 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/container/ContainerContextTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/container/ContainerContextTest.java @@ -21,6 +21,7 @@ import java.io.File; import java.util.Objects; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.cluster.protocol.EventType; import org.apache.geaflow.cluster.protocol.IEvent; @@ -37,122 +38,118 @@ public class ContainerContextTest { - private Configuration configuration = new Configuration(); - - @BeforeMethod - public void before() { - String path = "/tmp/" + ContainerContextTest.class.getSimpleName(); - FileUtils.deleteQuietly(new File(path)); - - configuration.getConfigMap().clear(); - configuration.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), ContainerContextTest.class.getSimpleName()); - configuration.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); - configuration.put(FileConfigKeys.ROOT.getKey(), path); - configuration.put(ExecutionConfigKeys.CLUSTER_ID, "test1"); - ClusterMetaStore.close(); + private Configuration configuration = new Configuration(); + + @BeforeMethod + public void before() { + String path = "/tmp/" + ContainerContextTest.class.getSimpleName(); + FileUtils.deleteQuietly(new File(path)); + + configuration.getConfigMap().clear(); + configuration.put( + ExecutionConfigKeys.JOB_APP_NAME.getKey(), ContainerContextTest.class.getSimpleName()); + configuration.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); + configuration.put(FileConfigKeys.ROOT.getKey(), path); + configuration.put(ExecutionConfigKeys.CLUSTER_ID, "test1"); + ClusterMetaStore.close(); + } + + @AfterMethod + public void after() { + String path = "/tmp/" + ContainerContextTest.class.getSimpleName(); + FileUtils.deleteQuietly(new File(path)); + ClusterMetaStore.close(); + } + + @Test + public void testContainer() { + + int containerId = 1; + ClusterMetaStore.init(containerId, "container-0", configuration); + ContainerContext containerContext = new ContainerContext(containerId, configuration); + + TestHAEvent event = new TestHAEvent(); + containerContext.addEvent(event); + containerContext.checkpoint(new ContainerContext.EventCheckpointFunction()); + + // recover + ContainerContext recoverContext = new ContainerContext(containerId, configuration, true); + recoverContext.load(); + Assert.assertEquals(1, recoverContext.getReliableEvents().size()); + Assert.assertEquals(event, recoverContext.getReliableEvents().get(0)); + + // ---- mock restart job ---- + // cluster id is changed, re-init cluster metastore. + ClusterMetaStore.close(); + configuration.put(ExecutionConfigKeys.CLUSTER_ID, "test2"); + ClusterMetaStore.init(containerId, "container-0", configuration); + // rebuild, context reliable event list is empty, and metastore is cleaned. + ContainerContext restartContext = new ContainerContext(containerId, configuration); + restartContext.load(); + Assert.assertEquals(0, restartContext.getReliableEvents().size()); + Assert.assertNull(ClusterMetaStore.getInstance().getEvents()); + } + + @Test + public void testCheckpoint() { + int containerId = 1; + ClusterMetaStore.init(containerId, "container-0", configuration); + ContainerContext containerContext = new ContainerContext(containerId, configuration); + + TestHAEvent event = new TestHAEvent("test1", 1); + containerContext.addEvent(event); + containerContext.checkpoint(new ContainerContext.EventCheckpointFunction()); + event.variable = 2; + + TestHAEvent event2 = new TestHAEvent("test2", 2); + containerContext.addEvent(event2); + containerContext.checkpoint(new ContainerContext.EventCheckpointFunction()); + + ContainerContext newContext = new ContainerContext(containerId, configuration, true); + newContext.load(); + Assert.assertEquals(2, newContext.getReliableEvents().size()); + Assert.assertEquals(1, ((TestHAEvent) (newContext.getReliableEvents().get(0))).variable); + Assert.assertEquals(2, ((TestHAEvent) (newContext.getReliableEvents().get(1))).variable); + } + + private class TestHAEvent implements IEvent, IHighAvailableEvent { + + private String name = "testEvent"; + private int variable = 0; + + public TestHAEvent() {} + + public TestHAEvent(String name, int variable) { + this.name = name; + this.variable = variable; } - @AfterMethod - public void after() { - String path = "/tmp/" + ContainerContextTest.class.getSimpleName(); - FileUtils.deleteQuietly(new File(path)); - ClusterMetaStore.close(); + @Override + public EventType getEventType() { + return null; } - @Test - public void testContainer() { - - int containerId = 1; - ClusterMetaStore.init(containerId, "container-0", configuration); - ContainerContext containerContext = new ContainerContext(containerId, configuration); - - TestHAEvent event = new TestHAEvent(); - containerContext.addEvent(event); - containerContext.checkpoint(new ContainerContext.EventCheckpointFunction()); - - // recover - ContainerContext recoverContext = new ContainerContext(containerId, configuration, true); - recoverContext.load(); - Assert.assertEquals(1, recoverContext.getReliableEvents().size()); - Assert.assertEquals(event, recoverContext.getReliableEvents().get(0)); - - // ---- mock restart job ---- - // cluster id is changed, re-init cluster metastore. - ClusterMetaStore.close(); - configuration.put(ExecutionConfigKeys.CLUSTER_ID, "test2"); - ClusterMetaStore.init(containerId, "container-0", configuration); - // rebuild, context reliable event list is empty, and metastore is cleaned. - ContainerContext restartContext = new ContainerContext(containerId, configuration); - restartContext.load(); - Assert.assertEquals(0, restartContext.getReliableEvents().size()); - Assert.assertNull(ClusterMetaStore.getInstance().getEvents()); + @Override + public HighAvailableLevel getHaLevel() { + return HighAvailableLevel.CHECKPOINT; + } + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TestHAEvent that = (TestHAEvent) o; + return Objects.equals(name, that.name); } - @Test - public void testCheckpoint() { - int containerId = 1; - ClusterMetaStore.init(containerId, "container-0", configuration); - ContainerContext containerContext = new ContainerContext(containerId, configuration); - - TestHAEvent event = new TestHAEvent("test1", 1); - containerContext.addEvent(event); - containerContext.checkpoint(new ContainerContext.EventCheckpointFunction()); - event.variable = 2; - - TestHAEvent event2 = new TestHAEvent("test2", 2); - containerContext.addEvent(event2); - containerContext.checkpoint(new ContainerContext.EventCheckpointFunction()); - - ContainerContext newContext = new ContainerContext(containerId, configuration, true); - newContext.load(); - Assert.assertEquals(2, newContext.getReliableEvents().size()); - Assert.assertEquals(1, ((TestHAEvent) (newContext.getReliableEvents().get(0))).variable); - Assert.assertEquals(2, ((TestHAEvent) (newContext.getReliableEvents().get(1))).variable); + @Override + public int hashCode() { + return Objects.hash(name); } - private class TestHAEvent implements IEvent, IHighAvailableEvent { - - private String name = "testEvent"; - private int variable = 0; - - public TestHAEvent() { - } - - public TestHAEvent(String name, int variable) { - this.name = name; - this.variable = variable; - } - - @Override - public EventType getEventType() { - return null; - } - - @Override - public HighAvailableLevel getHaLevel() { - return HighAvailableLevel.CHECKPOINT; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - TestHAEvent that = (TestHAEvent) o; - return Objects.equals(name, that.name); - } - - @Override - public int hashCode() { - return Objects.hash(name); - } - - @Override - public String toString() { - return "TestHAEvent{" - + "name='" + name + '\'' - + ", variable=" + variable - + '}'; - } + @Override + public String toString() { + return "TestHAEvent{" + "name='" + name + '\'' + ", variable=" + variable + '}'; } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/container/ContainerTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/container/ContainerTest.java index afb457056..602f1fb4a 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/container/ContainerTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/container/ContainerTest.java @@ -27,6 +27,7 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; + import org.apache.geaflow.cluster.exception.ExceptionCollectService; import org.apache.geaflow.cluster.protocol.EventType; import org.apache.geaflow.cluster.protocol.ICommand; @@ -46,123 +47,121 @@ public class ContainerTest { - private static AtomicBoolean eventExecuted = new AtomicBoolean(false); - private static AtomicBoolean hasException = new AtomicBoolean(false); - private static SecurityManager securityManager; - - @BeforeClass - public void before() { - securityManager = System.getSecurityManager(); - System.setSecurityManager(new SystemExitSignalCatcher(hasException)); + private static AtomicBoolean eventExecuted = new AtomicBoolean(false); + private static AtomicBoolean hasException = new AtomicBoolean(false); + private static SecurityManager securityManager; + + @BeforeClass + public void before() { + securityManager = System.getSecurityManager(); + System.setSecurityManager(new SystemExitSignalCatcher(hasException)); + } + + @AfterClass + public void after() { + System.setSecurityManager(securityManager); + } + + @BeforeMethod + public void beforeMethod() { + hasException.set(false); + } + + @Test + public void testProcessEventHandleException() throws Exception { + + eventExecuted.set(false); + Container container = new Container(); + Map config = new HashMap<>(); + config.put(CLUSTER_ID.getKey(), "0"); + config.put(REPORTER_LIST.getKey(), "slf4j"); + config.put(RUN_LOCAL_MODE.getKey(), "true"); + config.put(CONTAINER_DISPATCH_THREADS.getKey(), "1"); + Configuration configuration = new Configuration(config); + ReflectionUtil.setField(container, "name", "test"); + ReflectionUtil.setField(container, "configuration", configuration); + ReflectionUtil.setField(container, "haService", HAServiceFactory.getService(configuration)); + ReflectionUtil.setField(container, "containerContext", new ContainerContext(0, configuration)); + ReflectionUtil.setField(container, "exceptionCollectService", new ExceptionCollectService()); + container.open(new OpenContainerEvent(1)); + container.process(new TestCreateTaskEvent()); + container.process(new ExceptionCommandEvent()); + + waitTestResult(); + Assert.assertTrue(hasException.get()); + container.close(); + } + + @Test + public void testProcessMultiEventHandleException() throws Exception { + + eventExecuted.set(false); + Container container = new Container(); + Map config = new HashMap<>(); + config.put(CLUSTER_ID.getKey(), "0"); + config.put(REPORTER_LIST.getKey(), "slf4j"); + config.put(RUN_LOCAL_MODE.getKey(), "true"); + config.put(CONTAINER_DISPATCH_THREADS.getKey(), "1"); + Configuration configuration = new Configuration(config); + ReflectionUtil.setField(container, "name", "test"); + ReflectionUtil.setField(container, "configuration", configuration); + ReflectionUtil.setField(container, "haService", HAServiceFactory.getService(configuration)); + ReflectionUtil.setField(container, "containerContext", new ContainerContext(0, configuration)); + ReflectionUtil.setField(container, "exceptionCollectService", new ExceptionCollectService()); + container.open(new OpenContainerEvent(1)); + container.process(new TestCreateTaskEvent()); + container.process(new ExceptionCommandEvent()); + + waitTestResult(); + Assert.assertTrue(hasException.get()); + container.close(); + } + + private void waitTestResult() { + int retry = 10; + while (!eventExecuted.compareAndSet(true, false) && retry > 0) { + SleepUtils.sleepMilliSecond(100); + retry--; } - - @AfterClass - public void after() { - System.setSecurityManager(securityManager); + retry = 10; + while (!hasException.get() && retry > 0) { + SleepUtils.sleepMilliSecond(100); + retry--; } + } - @BeforeMethod - public void beforeMethod() { - hasException.set(false); - } + static class ExceptionCommandEvent implements IExecutableCommand { - @Test - public void testProcessEventHandleException() throws Exception { - - eventExecuted.set(false); - Container container = new Container(); - Map config = new HashMap<>(); - config.put(CLUSTER_ID.getKey(), "0"); - config.put(REPORTER_LIST.getKey(), "slf4j"); - config.put(RUN_LOCAL_MODE.getKey(), "true"); - config.put(CONTAINER_DISPATCH_THREADS.getKey(), "1"); - Configuration configuration = new Configuration(config); - ReflectionUtil.setField(container, "name", "test"); - ReflectionUtil.setField(container, "configuration", configuration); - ReflectionUtil.setField(container, "haService", HAServiceFactory.getService(configuration)); - ReflectionUtil.setField(container, "containerContext", new ContainerContext(0, configuration)); - ReflectionUtil.setField(container, "exceptionCollectService", new ExceptionCollectService()); - container.open(new OpenContainerEvent(1)); - container.process(new TestCreateTaskEvent()); - container.process(new ExceptionCommandEvent()); - - waitTestResult(); - Assert.assertTrue(hasException.get()); - container.close(); + @Override + public int getWorkerId() { + return 0; } - @Test - public void testProcessMultiEventHandleException() throws Exception { - - eventExecuted.set(false); - Container container = new Container(); - Map config = new HashMap<>(); - config.put(CLUSTER_ID.getKey(), "0"); - config.put(REPORTER_LIST.getKey(), "slf4j"); - config.put(RUN_LOCAL_MODE.getKey(), "true"); - config.put(CONTAINER_DISPATCH_THREADS.getKey(), "1"); - Configuration configuration = new Configuration(config); - ReflectionUtil.setField(container, "name", "test"); - ReflectionUtil.setField(container, "configuration", configuration); - ReflectionUtil.setField(container, "haService", HAServiceFactory.getService(configuration)); - ReflectionUtil.setField(container, "containerContext", new ContainerContext(0, configuration)); - ReflectionUtil.setField(container, "exceptionCollectService", new ExceptionCollectService()); - container.open(new OpenContainerEvent(1)); - container.process(new TestCreateTaskEvent()); - container.process(new ExceptionCommandEvent()); - - waitTestResult(); - Assert.assertTrue(hasException.get()); - container.close(); + @Override + public EventType getEventType() { + return EventType.INIT_CYCLE; } - private void waitTestResult() { - int retry = 10; - while (!eventExecuted.compareAndSet(true, false) && retry > 0) { - SleepUtils.sleepMilliSecond(100); - retry--; - } - retry = 10; - while (!hasException.get() && retry > 0) { - SleepUtils.sleepMilliSecond(100); - retry--; - } + @Override + public void execute(ITaskContext taskContext) { + eventExecuted.set(true); + throw new RuntimeException("fatal error"); } - static class ExceptionCommandEvent implements IExecutableCommand { - - @Override - public int getWorkerId() { - return 0; - } + @Override + public void interrupt() {} + } - @Override - public EventType getEventType() { - return EventType.INIT_CYCLE; - } + static class TestCreateTaskEvent implements ICommand { - @Override - public void execute(ITaskContext taskContext) { - eventExecuted.set(true); - throw new RuntimeException("fatal error"); - } - - @Override - public void interrupt() { - - } + @Override + public int getWorkerId() { + return 0; } - static class TestCreateTaskEvent implements ICommand { - - @Override - public int getWorkerId() { - return 0; - } - - @Override - public EventType getEventType() { - return EventType.CREATE_TASK; - } + @Override + public EventType getEventType() { + return EventType.CREATE_TASK; } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/driver/DriverContextTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/driver/DriverContextTest.java index e63734adc..0b1bdcd53 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/driver/DriverContextTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/driver/DriverContextTest.java @@ -24,6 +24,7 @@ import java.io.File; import java.util.List; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.cluster.common.ExecutionIdGenerator; import org.apache.geaflow.cluster.failover.FailoverStrategyType; @@ -45,91 +46,95 @@ public class DriverContextTest { - private Configuration configuration = new Configuration(); - - @BeforeMethod - public void before() { - String path = "/tmp/" + DriverContextTest.class.getSimpleName(); - FileUtils.deleteQuietly(new File(path)); - - configuration.getConfigMap().clear(); - configuration.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), DriverContextTest.class.getSimpleName()); - configuration.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); - configuration.put(FileConfigKeys.ROOT.getKey(), path); - configuration.put(ExecutionConfigKeys.CLUSTER_ID, "test1"); - ExecutionIdGenerator.init(0); - } - - @AfterMethod - public void after() { - String path = "/tmp/" + DriverContextTest.class.getSimpleName(); - FileUtils.deleteQuietly(new File(path)); - ClusterMetaStore.close(); - } - - @Test - public void testRecoverContext() { - - int driverId = 1; - ClusterMetaStore.init(driverId, "driver-0", configuration); - DriverContext driverContext = new DriverContext(driverId, 0, configuration); - - Environment environment = Mockito.mock(Environment.class); - Mockito.doNothing().when(environment).addPipeline(any()); - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.getPipelineTaskList().add(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - - } - }); - driverContext.addPipeline(pipeline); - List pipelineTaskIds = driverContext.getPipelineTaskIds(); - driverContext.addFinishedPipelineTask(0); - driverContext.addFinishedPipelineTask(1); - driverContext.checkpoint(new DriverContext.PipelineCheckpointFunction()); - driverContext.checkpoint(new DriverContext.PipelineTaskCheckpointFunction()); - - DriverContext newContext = new DriverContext(driverId, 0, configuration); - newContext.load(); - - Assert.assertNotNull(pipeline); - Assert.assertEquals(2, newContext.getFinishedPipelineTasks().size()); - Assert.assertEquals(0, newContext.getFinishedPipelineTasks().get(0).intValue()); - Assert.assertEquals(1, newContext.getFinishedPipelineTasks().get(1).intValue()); - Assert.assertEquals(pipelineTaskIds.get(0), newContext.getPipelineTaskIds().get(0)); - - // ---- mock restart job ---- - // cluster id is changed, re-init cluster metastore. - ClusterMetaStore.close(); - configuration.put(ExecutionConfigKeys.CLUSTER_ID, "test2"); - ClusterMetaStore.init(driverId, "driver-0", configuration); - // rebuild, context reliable event list is empty, and metastore is cleaned. - DriverContext restarted = new DriverContext(driverId, 0, configuration); - restarted.load(); - Assert.assertNull(restarted.getPipeline()); - Assert.assertTrue(restarted.getFinishedPipelineTasks().isEmpty()); - Assert.assertTrue(restarted.getPipelineTaskIds().isEmpty()); - } - - @Test(expectedExceptions = GeaflowRuntimeException.class, - expectedExceptionsMessageRegExp = "not support component_fo for executing pipeline tasks") - public void testPipelineAndCheckFoStrategy() { - - int driverId = 1; - Configuration recoverConfig = new Configuration(configuration.getConfigMap()); - recoverConfig.put(FO_STRATEGY, FailoverStrategyType.component_fo.name()); - DriverContext driverContext = new DriverContext(driverId, 0, recoverConfig); - - Environment environment = Mockito.mock(Environment.class); - Mockito.doNothing().when(environment).addPipeline(any()); - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.getPipelineTaskList().add(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - - } - }); - driverContext.addPipeline(pipeline); - } + private Configuration configuration = new Configuration(); + + @BeforeMethod + public void before() { + String path = "/tmp/" + DriverContextTest.class.getSimpleName(); + FileUtils.deleteQuietly(new File(path)); + + configuration.getConfigMap().clear(); + configuration.put( + ExecutionConfigKeys.JOB_APP_NAME.getKey(), DriverContextTest.class.getSimpleName()); + configuration.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); + configuration.put(FileConfigKeys.ROOT.getKey(), path); + configuration.put(ExecutionConfigKeys.CLUSTER_ID, "test1"); + ExecutionIdGenerator.init(0); + } + + @AfterMethod + public void after() { + String path = "/tmp/" + DriverContextTest.class.getSimpleName(); + FileUtils.deleteQuietly(new File(path)); + ClusterMetaStore.close(); + } + + @Test + public void testRecoverContext() { + + int driverId = 1; + ClusterMetaStore.init(driverId, "driver-0", configuration); + DriverContext driverContext = new DriverContext(driverId, 0, configuration); + + Environment environment = Mockito.mock(Environment.class); + Mockito.doNothing().when(environment).addPipeline(any()); + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline + .getPipelineTaskList() + .add( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) {} + }); + driverContext.addPipeline(pipeline); + List pipelineTaskIds = driverContext.getPipelineTaskIds(); + driverContext.addFinishedPipelineTask(0); + driverContext.addFinishedPipelineTask(1); + driverContext.checkpoint(new DriverContext.PipelineCheckpointFunction()); + driverContext.checkpoint(new DriverContext.PipelineTaskCheckpointFunction()); + + DriverContext newContext = new DriverContext(driverId, 0, configuration); + newContext.load(); + + Assert.assertNotNull(pipeline); + Assert.assertEquals(2, newContext.getFinishedPipelineTasks().size()); + Assert.assertEquals(0, newContext.getFinishedPipelineTasks().get(0).intValue()); + Assert.assertEquals(1, newContext.getFinishedPipelineTasks().get(1).intValue()); + Assert.assertEquals(pipelineTaskIds.get(0), newContext.getPipelineTaskIds().get(0)); + + // ---- mock restart job ---- + // cluster id is changed, re-init cluster metastore. + ClusterMetaStore.close(); + configuration.put(ExecutionConfigKeys.CLUSTER_ID, "test2"); + ClusterMetaStore.init(driverId, "driver-0", configuration); + // rebuild, context reliable event list is empty, and metastore is cleaned. + DriverContext restarted = new DriverContext(driverId, 0, configuration); + restarted.load(); + Assert.assertNull(restarted.getPipeline()); + Assert.assertTrue(restarted.getFinishedPipelineTasks().isEmpty()); + Assert.assertTrue(restarted.getPipelineTaskIds().isEmpty()); + } + + @Test( + expectedExceptions = GeaflowRuntimeException.class, + expectedExceptionsMessageRegExp = "not support component_fo for executing pipeline tasks") + public void testPipelineAndCheckFoStrategy() { + + int driverId = 1; + Configuration recoverConfig = new Configuration(configuration.getConfigMap()); + recoverConfig.put(FO_STRATEGY, FailoverStrategyType.component_fo.name()); + DriverContext driverContext = new DriverContext(driverId, 0, recoverConfig); + + Environment environment = Mockito.mock(Environment.class); + Mockito.doNothing().when(environment).addPipeline(any()); + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline + .getPipelineTaskList() + .add( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) {} + }); + driverContext.addPipeline(pipeline); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/exception/ComponentUncaughtExceptionHandlerTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/exception/ComponentUncaughtExceptionHandlerTest.java index 1bae92aca..93fc67ad3 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/exception/ComponentUncaughtExceptionHandlerTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/exception/ComponentUncaughtExceptionHandlerTest.java @@ -22,6 +22,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; + import org.apache.geaflow.cluster.util.SystemExitSignalCatcher; import org.apache.geaflow.common.utils.ThreadUtil; import org.testng.Assert; @@ -32,39 +33,42 @@ public class ComponentUncaughtExceptionHandlerTest { - private static SecurityManager securityManager; - private static AtomicBoolean hasException = new AtomicBoolean(false); - + private static SecurityManager securityManager; + private static AtomicBoolean hasException = new AtomicBoolean(false); - @BeforeClass - public void before() { - securityManager = System.getSecurityManager(); - System.setSecurityManager(new SystemExitSignalCatcher(hasException)); - } + @BeforeClass + public void before() { + securityManager = System.getSecurityManager(); + System.setSecurityManager(new SystemExitSignalCatcher(hasException)); + } - @AfterClass - public void after() { - System.setSecurityManager(securityManager); - } + @AfterClass + public void after() { + System.setSecurityManager(securityManager); + } - @BeforeMethod - public void beforeMethod() { - hasException.set(false); - } + @BeforeMethod + public void beforeMethod() { + hasException.set(false); + } - @Test - public void testHandleExceptionInThreadPool() throws InterruptedException { + @Test + public void testHandleExceptionInThreadPool() throws InterruptedException { - ComponentExceptionSupervisor.getInstance(); - ExecutorService executorService = Executors.newFixedThreadPool(2, - ThreadUtil.namedThreadFactory(true, "test-handler", new ComponentUncaughtExceptionHandler())); + ComponentExceptionSupervisor.getInstance(); + ExecutorService executorService = + Executors.newFixedThreadPool( + 2, + ThreadUtil.namedThreadFactory( + true, "test-handler", new ComponentUncaughtExceptionHandler())); - executorService.execute(() -> { - throw new RuntimeException("test exception"); + executorService.execute( + () -> { + throw new RuntimeException("test exception"); }); - executorService.execute(ComponentExceptionSupervisor.getInstance()); - // wait async thread catch and handle exception - Thread.sleep(100); - Assert.assertTrue(hasException.get()); - } + executorService.execute(ComponentExceptionSupervisor.getInstance()); + // wait async thread catch and handle exception + Thread.sleep(100); + Assert.assertTrue(hasException.get()); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/failover/FoStrategyFactoryTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/failover/FoStrategyFactoryTest.java index cd8626897..f5f046b0a 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/failover/FoStrategyFactoryTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/failover/FoStrategyFactoryTest.java @@ -25,9 +25,8 @@ public class FoStrategyFactoryTest { - @Test(expectedExceptions = GeaflowRuntimeException.class) - public void testLoad() { - FailoverStrategyFactory.loadFailoverStrategy(EnvType.RAY, ""); - } - + @Test(expectedExceptions = GeaflowRuntimeException.class) + public void testLoad() { + FailoverStrategyFactory.loadFailoverStrategy(EnvType.RAY, ""); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/heartbeat/HeartbeatManagerTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/heartbeat/HeartbeatManagerTest.java index 1db413176..5acf9ad2e 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/heartbeat/HeartbeatManagerTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/heartbeat/HeartbeatManagerTest.java @@ -28,6 +28,7 @@ import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.cluster.clustermanager.AbstractClusterManager; import org.apache.geaflow.cluster.container.ContainerInfo; import org.apache.geaflow.common.config.Configuration; @@ -44,86 +45,89 @@ public class HeartbeatManagerTest { - @Mock - private AbstractClusterManager clusterManager; - - private HeartbeatManager heartbeatManager; - - @BeforeMethod - public void setUp() { - MockitoAnnotations.initMocks(this); - Configuration config = new Configuration(); - config.put(ExecutionConfigKeys.HEARTBEAT_INITIAL_DELAY_MS, "60000"); - config.put(ExecutionConfigKeys.HEARTBEAT_INTERVAL_MS, "60000"); - config.put(ExecutionConfigKeys.HEARTBEAT_TIMEOUT_MS, "500"); - config.put(ExecutionConfigKeys.SUPERVISOR_ENABLE, "true"); - heartbeatManager = new HeartbeatManager(config, clusterManager); - } - - @AfterMethod - public void tearDown() { - heartbeatManager.close(); - } - - @Test - public void receivedHeartbeat_RegisteredHeartbeat_ReturnsSuccessAndRegistered() { - Heartbeat heartbeat = new Heartbeat(1); - when(clusterManager.getContainerInfos()).thenReturn(new HashMap() {{ - put(1, new ContainerInfo()); - }}); - - HeartbeatResponse response = heartbeatManager.receivedHeartbeat(heartbeat); - - assertEquals(true, response.getSuccess()); - assertEquals(true, response.getRegistered()); - } - - @Test - public void receivedHeartbeat_UnregisteredHeartbeat_ReturnsSuccessAndNotRegistered() { - Heartbeat heartbeat = new Heartbeat(2); - when(clusterManager.getContainerInfos()).thenReturn(new HashMap()); - - HeartbeatResponse response = heartbeatManager.receivedHeartbeat(heartbeat); - - assertEquals(true, response.getSuccess()); - assertEquals(false, response.getRegistered()); - } - - @Test - public void checkHeartBeat_LogsWarningsAndErrors() { - Map containerMap = new HashMap<>(); - containerMap.put(1, "container1"); - when(clusterManager.getContainerIds()).thenReturn(containerMap); - when(clusterManager.getDriverIds()).thenReturn(new HashMap()); - - Heartbeat heartbeat = new Heartbeat(1); - heartbeatManager.receivedHeartbeat(heartbeat); - SleepUtils.sleepMilliSecond(600); - heartbeatManager.checkHeartBeat(); - - verify(clusterManager, times(1)).doFailover(eq(1), isA(GeaflowHeartbeatException.class)); - } - - @Test - public void checkWorkHealth_LogsWarningsAndErrors() { - Map containerMap = new HashMap<>(); - containerMap.put(1, "container1"); - when(clusterManager.getContainerIds()).thenReturn(containerMap); - when(clusterManager.getDriverIds()).thenReturn(new HashMap()); - - heartbeatManager.checkWorkerHealth(); - heartbeatManager.close(); - - verify(clusterManager, times(1)).doFailover(eq(1), isA(GeaflowHeartbeatException.class)); - } - - @Test - public void doFailover_ReportsExceptionAndCallsFailover() { - int componentId = 1; - Throwable exception = new RuntimeException("Test exception"); - - heartbeatManager.doFailover(componentId, exception); - - verify(clusterManager, times(1)).doFailover(eq(componentId), eq(exception)); - } -} \ No newline at end of file + @Mock private AbstractClusterManager clusterManager; + + private HeartbeatManager heartbeatManager; + + @BeforeMethod + public void setUp() { + MockitoAnnotations.initMocks(this); + Configuration config = new Configuration(); + config.put(ExecutionConfigKeys.HEARTBEAT_INITIAL_DELAY_MS, "60000"); + config.put(ExecutionConfigKeys.HEARTBEAT_INTERVAL_MS, "60000"); + config.put(ExecutionConfigKeys.HEARTBEAT_TIMEOUT_MS, "500"); + config.put(ExecutionConfigKeys.SUPERVISOR_ENABLE, "true"); + heartbeatManager = new HeartbeatManager(config, clusterManager); + } + + @AfterMethod + public void tearDown() { + heartbeatManager.close(); + } + + @Test + public void receivedHeartbeat_RegisteredHeartbeat_ReturnsSuccessAndRegistered() { + Heartbeat heartbeat = new Heartbeat(1); + when(clusterManager.getContainerInfos()) + .thenReturn( + new HashMap() { + { + put(1, new ContainerInfo()); + } + }); + + HeartbeatResponse response = heartbeatManager.receivedHeartbeat(heartbeat); + + assertEquals(true, response.getSuccess()); + assertEquals(true, response.getRegistered()); + } + + @Test + public void receivedHeartbeat_UnregisteredHeartbeat_ReturnsSuccessAndNotRegistered() { + Heartbeat heartbeat = new Heartbeat(2); + when(clusterManager.getContainerInfos()).thenReturn(new HashMap()); + + HeartbeatResponse response = heartbeatManager.receivedHeartbeat(heartbeat); + + assertEquals(true, response.getSuccess()); + assertEquals(false, response.getRegistered()); + } + + @Test + public void checkHeartBeat_LogsWarningsAndErrors() { + Map containerMap = new HashMap<>(); + containerMap.put(1, "container1"); + when(clusterManager.getContainerIds()).thenReturn(containerMap); + when(clusterManager.getDriverIds()).thenReturn(new HashMap()); + + Heartbeat heartbeat = new Heartbeat(1); + heartbeatManager.receivedHeartbeat(heartbeat); + SleepUtils.sleepMilliSecond(600); + heartbeatManager.checkHeartBeat(); + + verify(clusterManager, times(1)).doFailover(eq(1), isA(GeaflowHeartbeatException.class)); + } + + @Test + public void checkWorkHealth_LogsWarningsAndErrors() { + Map containerMap = new HashMap<>(); + containerMap.put(1, "container1"); + when(clusterManager.getContainerIds()).thenReturn(containerMap); + when(clusterManager.getDriverIds()).thenReturn(new HashMap()); + + heartbeatManager.checkWorkerHealth(); + heartbeatManager.close(); + + verify(clusterManager, times(1)).doFailover(eq(1), isA(GeaflowHeartbeatException.class)); + } + + @Test + public void doFailover_ReportsExceptionAndCallsFailover() { + int componentId = 1; + Throwable exception = new RuntimeException("Test exception"); + + heartbeatManager.doFailover(componentId, exception); + + verify(clusterManager, times(1)).doFailover(eq(componentId), eq(exception)); + } +} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/resourcemanager/DefaultResourceManagerTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/resourcemanager/DefaultResourceManagerTest.java index e42bd8b81..aae935c8d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/resourcemanager/DefaultResourceManagerTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/resourcemanager/DefaultResourceManagerTest.java @@ -33,6 +33,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; + import org.apache.geaflow.cluster.clustermanager.ClusterContext; import org.apache.geaflow.cluster.clustermanager.ClusterId; import org.apache.geaflow.cluster.clustermanager.ContainerExecutorInfo; @@ -56,598 +57,606 @@ public class DefaultResourceManagerTest { - private static final String TEST = "test"; - - private static final ExecutorService POOL = Executors.newCachedThreadPool(); - private Configuration config = new Configuration(); - - @BeforeMethod - public void setUp() { - config = new Configuration(); - config.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); - ClusterMetaStore.close(); + private static final String TEST = "test"; + + private static final ExecutorService POOL = Executors.newCachedThreadPool(); + private Configuration config = new Configuration(); + + @BeforeMethod + public void setUp() { + config = new Configuration(); + config.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); + ClusterMetaStore.close(); + } + + @Test + public void testAllocateWithWorkerEqUserDefinedWorker() { + config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(5)); + config.put(JOB_UNIQUE_ID.getKey(), "geaflow12345"); + ClusterMetaStore.init(0, "master-0", config); + ClusterContext clusterContext = new ClusterContext(config); + MockClusterManager clusterManager = new MockClusterManager(); + clusterManager.init(clusterContext); + + DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); + resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); + + AtomicInteger pending = resourceManager.getPendingWorkerCounter(); + AtomicBoolean lock = resourceManager.getResourceLock(); + while (pending.get() > 0 || !lock.get()) { + SleepUtils.sleepMilliSecond(10); } - @Test - public void testAllocateWithWorkerEqUserDefinedWorker() { - config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(5)); - config.put(JOB_UNIQUE_ID.getKey(), "geaflow12345"); - ClusterMetaStore.init(0, "master-0", config); - ClusterContext clusterContext = new ClusterContext(config); - MockClusterManager clusterManager = new MockClusterManager(); - clusterManager.init(clusterContext); - - DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); - resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); - - AtomicInteger pending = resourceManager.getPendingWorkerCounter(); - AtomicBoolean lock = resourceManager.getResourceLock(); - while (pending.get() > 0 || !lock.get()) { - SleepUtils.sleepMilliSecond(10); - } - - RequireResourceRequest request1 = RequireResourceRequest.build(TEST + 1, 4); - RequireResponse response1 = resourceManager.requireResource(request1); - Assert.assertEquals(response1.getWorkers().size(), 4); - - RequireResourceRequest request2 = RequireResourceRequest.build(TEST + 2, 2); - RequireResponse response2 = resourceManager.requireResource(request2); - Assert.assertEquals(response2.getWorkers().size(), 0); + RequireResourceRequest request1 = RequireResourceRequest.build(TEST + 1, 4); + RequireResponse response1 = resourceManager.requireResource(request1); + Assert.assertEquals(response1.getWorkers().size(), 4); + + RequireResourceRequest request2 = RequireResourceRequest.build(TEST + 2, 2); + RequireResponse response2 = resourceManager.requireResource(request2); + Assert.assertEquals(response2.getWorkers().size(), 0); + } + + @Test + public void testAllocateWithWorkerGtUserDefinedWorker() { + config.put(JOB_UNIQUE_ID.getKey(), "geaflow23456"); + config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(20)); + ClusterMetaStore.init(0, "master-0", config); + ClusterContext clusterContext = new ClusterContext(config); + MockClusterManager clusterManager = new MockClusterManager(); + clusterManager.init(clusterContext); + + DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); + resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); + + AtomicInteger pending = resourceManager.getPendingWorkerCounter(); + AtomicBoolean lock = resourceManager.getResourceLock(); + while (pending.get() > 0 || !lock.get()) { + SleepUtils.sleepMilliSecond(10); } - @Test - public void testAllocateWithWorkerGtUserDefinedWorker() { - config.put(JOB_UNIQUE_ID.getKey(), "geaflow23456"); - config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(20)); - ClusterMetaStore.init(0, "master-0", config); - ClusterContext clusterContext = new ClusterContext(config); - MockClusterManager clusterManager = new MockClusterManager(); - clusterManager.init(clusterContext); - - DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); - resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); - - AtomicInteger pending = resourceManager.getPendingWorkerCounter(); - AtomicBoolean lock = resourceManager.getResourceLock(); - while (pending.get() > 0 || !lock.get()) { - SleepUtils.sleepMilliSecond(10); - } - - RequireResourceRequest request1 = RequireResourceRequest.build(TEST + 1, 4); - RequireResponse response1 = resourceManager.requireResource(request1); - Assert.assertEquals(response1.getWorkers().size(), 4); - - RequireResourceRequest request2 = RequireResourceRequest.build(TEST + 2, 10); - RequireResponse response2 = resourceManager.requireResource(request2); - Assert.assertEquals(response2.getWorkers().size(), 10); - - RequireResourceRequest request3 = RequireResourceRequest.build(TEST + 3, 10); - RequireResponse response3 = resourceManager.requireResource(request3); - Assert.assertEquals(response3.getWorkers().size(), 0); + RequireResourceRequest request1 = RequireResourceRequest.build(TEST + 1, 4); + RequireResponse response1 = resourceManager.requireResource(request1); + Assert.assertEquals(response1.getWorkers().size(), 4); + + RequireResourceRequest request2 = RequireResourceRequest.build(TEST + 2, 10); + RequireResponse response2 = resourceManager.requireResource(request2); + Assert.assertEquals(response2.getWorkers().size(), 10); + + RequireResourceRequest request3 = RequireResourceRequest.build(TEST + 3, 10); + RequireResponse response3 = resourceManager.requireResource(request3); + Assert.assertEquals(response3.getWorkers().size(), 0); + } + + @Test + public void testAllocateWithWorkerLtUserDefinedWorker() { + config.put(JOB_UNIQUE_ID.getKey(), "geaflow23456"); + config.put(CONTAINER_NUM.getKey(), String.valueOf(4)); + config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(5)); + ClusterMetaStore.init(0, "master-0", config); + ClusterContext clusterContext = new ClusterContext(config); + MockClusterManager clusterManager = new MockClusterManager(); + clusterManager.init(clusterContext); + + DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); + resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); + + AtomicInteger pending = resourceManager.getPendingWorkerCounter(); + AtomicBoolean lock = resourceManager.getResourceLock(); + while (pending.get() > 0 || !lock.get()) { + SleepUtils.sleepMilliSecond(10); } - @Test - public void testAllocateWithWorkerLtUserDefinedWorker() { - config.put(JOB_UNIQUE_ID.getKey(), "geaflow23456"); - config.put(CONTAINER_NUM.getKey(), String.valueOf(4)); - config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(5)); - ClusterMetaStore.init(0, "master-0", config); - ClusterContext clusterContext = new ClusterContext(config); - MockClusterManager clusterManager = new MockClusterManager(); - clusterManager.init(clusterContext); - - DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); - resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); - - AtomicInteger pending = resourceManager.getPendingWorkerCounter(); - AtomicBoolean lock = resourceManager.getResourceLock(); - while (pending.get() > 0 || !lock.get()) { - SleepUtils.sleepMilliSecond(10); - } - - RequireResourceRequest request1 = RequireResourceRequest.build(TEST + 1, 4); - RequireResponse response1 = resourceManager.requireResource(request1); - Assert.assertEquals(response1.getWorkers().size(), 4); - - RequireResourceRequest request2 = RequireResourceRequest.build(TEST + 2, 10); - RequireResponse response2 = resourceManager.requireResource(request2); - Assert.assertEquals(response2.getWorkers().size(), 10); - - RequireResourceRequest request3 = RequireResourceRequest.build(TEST + 3, 6); - RequireResponse response3 = resourceManager.requireResource(request3); - Assert.assertEquals(response3.getWorkers().size(), 6); - - RequireResourceRequest request4 = RequireResourceRequest.build(TEST + 4, 1); - RequireResponse response4 = resourceManager.requireResource(request4); - Assert.assertEquals(response4.getWorkers().size(), 0); + RequireResourceRequest request1 = RequireResourceRequest.build(TEST + 1, 4); + RequireResponse response1 = resourceManager.requireResource(request1); + Assert.assertEquals(response1.getWorkers().size(), 4); + + RequireResourceRequest request2 = RequireResourceRequest.build(TEST + 2, 10); + RequireResponse response2 = resourceManager.requireResource(request2); + Assert.assertEquals(response2.getWorkers().size(), 10); + + RequireResourceRequest request3 = RequireResourceRequest.build(TEST + 3, 6); + RequireResponse response3 = resourceManager.requireResource(request3); + Assert.assertEquals(response3.getWorkers().size(), 6); + + RequireResourceRequest request4 = RequireResourceRequest.build(TEST + 4, 1); + RequireResponse response4 = resourceManager.requireResource(request4); + Assert.assertEquals(response4.getWorkers().size(), 0); + } + + @Test + public void testAllocateAndRelease() { + config.put(JOB_UNIQUE_ID.getKey(), "geaflow23456"); + config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(20)); + ClusterMetaStore.init(0, "master-0", config); + ClusterContext clusterContext = new ClusterContext(config); + MockClusterManager clusterManager = new MockClusterManager(); + clusterManager.init(clusterContext); + + DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); + resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); + + AtomicInteger pending = resourceManager.getPendingWorkerCounter(); + AtomicBoolean lock = resourceManager.getResourceLock(); + while (pending.get() > 0 || !lock.get()) { + SleepUtils.sleepMilliSecond(10); } - @Test - public void testAllocateAndRelease() { - config.put(JOB_UNIQUE_ID.getKey(), "geaflow23456"); - config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(20)); - ClusterMetaStore.init(0, "master-0", config); - ClusterContext clusterContext = new ClusterContext(config); - MockClusterManager clusterManager = new MockClusterManager(); - clusterManager.init(clusterContext); - - DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); - resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); - - AtomicInteger pending = resourceManager.getPendingWorkerCounter(); - AtomicBoolean lock = resourceManager.getResourceLock(); - while (pending.get() > 0 || !lock.get()) { - SleepUtils.sleepMilliSecond(10); - } - - RequireResourceRequest request1 = RequireResourceRequest.build(TEST + 1, 4); - RequireResponse response1 = resourceManager.requireResource(request1); - Assert.assertEquals(response1.getWorkers().size(), 4); - - RequireResourceRequest request2 = RequireResourceRequest.build(TEST + 2, 10); - RequireResponse response2 = resourceManager.requireResource(request2); - Assert.assertEquals(response2.getWorkers().size(), 10); - - RequireResourceRequest request3 = RequireResourceRequest.build(TEST + 3, 6); - RequireResponse response3 = resourceManager.requireResource(request3); - Assert.assertEquals(response3.getWorkers().size(), 6); - - RequireResourceRequest request4 = RequireResourceRequest.build(TEST + 4, 1); - RequireResponse response4 = resourceManager.requireResource(request4); - Assert.assertEquals(response4.getWorkers().size(), 0); - - resourceManager.releaseResource(ReleaseResourceRequest.build(TEST + 1, response1.getWorkers())); - resourceManager.releaseResource(ReleaseResourceRequest.build(TEST + 2, response2.getWorkers())); - - RequireResourceRequest request5 = RequireResourceRequest.build(TEST + 5, 15); - RequireResponse response5 = resourceManager.requireResource(request5); - Assert.assertEquals(response5.getWorkers().size(), 0); - - resourceManager.releaseResource(ReleaseResourceRequest.build(TEST + 3, response3.getWorkers())); - RequireResourceRequest request6 = RequireResourceRequest.build(TEST + 6, 15); - RequireResponse response6 = resourceManager.requireResource(request6); - Assert.assertEquals(response6.getWorkers().size(), 15); + RequireResourceRequest request1 = RequireResourceRequest.build(TEST + 1, 4); + RequireResponse response1 = resourceManager.requireResource(request1); + Assert.assertEquals(response1.getWorkers().size(), 4); + + RequireResourceRequest request2 = RequireResourceRequest.build(TEST + 2, 10); + RequireResponse response2 = resourceManager.requireResource(request2); + Assert.assertEquals(response2.getWorkers().size(), 10); + + RequireResourceRequest request3 = RequireResourceRequest.build(TEST + 3, 6); + RequireResponse response3 = resourceManager.requireResource(request3); + Assert.assertEquals(response3.getWorkers().size(), 6); + + RequireResourceRequest request4 = RequireResourceRequest.build(TEST + 4, 1); + RequireResponse response4 = resourceManager.requireResource(request4); + Assert.assertEquals(response4.getWorkers().size(), 0); + + resourceManager.releaseResource(ReleaseResourceRequest.build(TEST + 1, response1.getWorkers())); + resourceManager.releaseResource(ReleaseResourceRequest.build(TEST + 2, response2.getWorkers())); + + RequireResourceRequest request5 = RequireResourceRequest.build(TEST + 5, 15); + RequireResponse response5 = resourceManager.requireResource(request5); + Assert.assertEquals(response5.getWorkers().size(), 0); + + resourceManager.releaseResource(ReleaseResourceRequest.build(TEST + 3, response3.getWorkers())); + RequireResourceRequest request6 = RequireResourceRequest.build(TEST + 6, 15); + RequireResponse response6 = resourceManager.requireResource(request6); + Assert.assertEquals(response6.getWorkers().size(), 15); + } + + @Test + public void testAllocateAndReleaseFail() { + config.put(JOB_UNIQUE_ID.getKey(), "geaflow23456"); + config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(20)); + ClusterMetaStore.init(0, "master-0", config); + ClusterContext clusterContext = new ClusterContext(config); + MockClusterManager clusterManager = new MockClusterManager(); + clusterManager.init(clusterContext); + + DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); + resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); + + AtomicInteger pending = resourceManager.getPendingWorkerCounter(); + AtomicBoolean lock = resourceManager.getResourceLock(); + while (pending.get() > 0 || !lock.get()) { + SleepUtils.sleepMilliSecond(10); } - @Test - public void testAllocateAndReleaseFail() { - config.put(JOB_UNIQUE_ID.getKey(), "geaflow23456"); - config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(20)); - ClusterMetaStore.init(0, "master-0", config); - ClusterContext clusterContext = new ClusterContext(config); - MockClusterManager clusterManager = new MockClusterManager(); - clusterManager.init(clusterContext); - - DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); - resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); - - AtomicInteger pending = resourceManager.getPendingWorkerCounter(); - AtomicBoolean lock = resourceManager.getResourceLock(); - while (pending.get() > 0 || !lock.get()) { - SleepUtils.sleepMilliSecond(10); - } - - RequireResourceRequest request1 = RequireResourceRequest.build(TEST + 1, 4); - RequireResponse response1 = resourceManager.requireResource(request1); - Assert.assertEquals(response1.getWorkers().size(), 4); - - ReleaseResponse releaseRes1 = resourceManager.releaseResource(ReleaseResourceRequest.build(TEST + 1, response1.getWorkers().subList(0, 1))); - Assert.assertFalse(releaseRes1.isSuccess()); - - ReleaseResponse releaseRes2 = resourceManager.releaseResource(ReleaseResourceRequest.build(TEST + 1, response1.getWorkers())); - Assert.assertTrue(releaseRes2.isSuccess()); - - ReleaseResponse releaseRes3 = resourceManager.releaseResource(ReleaseResourceRequest.build(TEST + 3, response1.getWorkers())); - Assert.assertFalse(releaseRes3.isSuccess()); - - + RequireResourceRequest request1 = RequireResourceRequest.build(TEST + 1, 4); + RequireResponse response1 = resourceManager.requireResource(request1); + Assert.assertEquals(response1.getWorkers().size(), 4); + + ReleaseResponse releaseRes1 = + resourceManager.releaseResource( + ReleaseResourceRequest.build(TEST + 1, response1.getWorkers().subList(0, 1))); + Assert.assertFalse(releaseRes1.isSuccess()); + + ReleaseResponse releaseRes2 = + resourceManager.releaseResource( + ReleaseResourceRequest.build(TEST + 1, response1.getWorkers())); + Assert.assertTrue(releaseRes2.isSuccess()); + + ReleaseResponse releaseRes3 = + resourceManager.releaseResource( + ReleaseResourceRequest.build(TEST + 3, response1.getWorkers())); + Assert.assertFalse(releaseRes3.isSuccess()); + } + + @Test(expectedExceptions = {GeaflowRuntimeException.class}) + public void testReleaseException() { + config.put(JOB_UNIQUE_ID.getKey(), "geaflow23456"); + config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(20)); + ClusterMetaStore.init(0, "master-0", config); + ClusterContext clusterContext = new ClusterContext(config); + MockClusterManager clusterManager = new MockClusterManager(); + clusterManager.init(clusterContext); + + DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); + resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); + + AtomicInteger pending = resourceManager.getPendingWorkerCounter(); + AtomicBoolean lock = resourceManager.getResourceLock(); + while (pending.get() > 0 || !lock.get()) { + SleepUtils.sleepMilliSecond(10); } - @Test(expectedExceptions = {GeaflowRuntimeException.class}) - public void testReleaseException() { - config.put(JOB_UNIQUE_ID.getKey(), "geaflow23456"); - config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(20)); - ClusterMetaStore.init(0, "master-0", config); - ClusterContext clusterContext = new ClusterContext(config); - MockClusterManager clusterManager = new MockClusterManager(); - clusterManager.init(clusterContext); - - DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); - resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); - - AtomicInteger pending = resourceManager.getPendingWorkerCounter(); - AtomicBoolean lock = resourceManager.getResourceLock(); - while (pending.get() > 0 || !lock.get()) { - SleepUtils.sleepMilliSecond(10); - } - - RequireResourceRequest request1 = RequireResourceRequest.build(TEST + 1, 4); - RequireResponse response1 = resourceManager.requireResource(request1); - Assert.assertEquals(response1.getWorkers().size(), 4); - - RequireResourceRequest request2 = RequireResourceRequest.build(TEST + 2, 4); - RequireResponse response2 = resourceManager.requireResource(request2); - Assert.assertEquals(response2.getWorkers().size(), 4); - - resourceManager.releaseResource(ReleaseResourceRequest.build(TEST + 1, response2.getWorkers())); + RequireResourceRequest request1 = RequireResourceRequest.build(TEST + 1, 4); + RequireResponse response1 = resourceManager.requireResource(request1); + Assert.assertEquals(response1.getWorkers().size(), 4); + + RequireResourceRequest request2 = RequireResourceRequest.build(TEST + 2, 4); + RequireResponse response2 = resourceManager.requireResource(request2); + Assert.assertEquals(response2.getWorkers().size(), 4); + + resourceManager.releaseResource(ReleaseResourceRequest.build(TEST + 1, response2.getWorkers())); + } + + @Test + public void testAllocateWithIllegalNum() { + config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(5)); + config.put(JOB_UNIQUE_ID.getKey(), "geaflow12345"); + ClusterMetaStore.init(0, "master-0", config); + ClusterContext clusterContext = new ClusterContext(config); + MockClusterManager clusterManager = new MockClusterManager(); + clusterManager.init(clusterContext); + + DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); + resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); + + AtomicInteger pending = resourceManager.getPendingWorkerCounter(); + AtomicBoolean lock = resourceManager.getResourceLock(); + while (pending.get() > 0 || !lock.get()) { + SleepUtils.sleepMilliSecond(10); } - @Test - public void testAllocateWithIllegalNum() { - config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(5)); - config.put(JOB_UNIQUE_ID.getKey(), "geaflow12345"); - ClusterMetaStore.init(0, "master-0", config); - ClusterContext clusterContext = new ClusterContext(config); - MockClusterManager clusterManager = new MockClusterManager(); - clusterManager.init(clusterContext); - - DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); - resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); - - AtomicInteger pending = resourceManager.getPendingWorkerCounter(); - AtomicBoolean lock = resourceManager.getResourceLock(); - while (pending.get() > 0 || !lock.get()) { - SleepUtils.sleepMilliSecond(10); - } - - RequireResourceRequest request1 = RequireResourceRequest.build(TEST, -1); - RequireResponse response1 = resourceManager.requireResource(request1); - Assert.assertFalse(response1.isSuccess()); + RequireResourceRequest request1 = RequireResourceRequest.build(TEST, -1); + RequireResponse response1 = resourceManager.requireResource(request1); + Assert.assertFalse(response1.isSuccess()); + } + + @Test + public void testReleaseWithIllegalResourceId() { + config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(5)); + config.put(JOB_UNIQUE_ID.getKey(), "geaflow12345"); + ClusterMetaStore.init(0, "master-0", config); + ClusterContext clusterContext = new ClusterContext(config); + MockClusterManager clusterManager = new MockClusterManager(); + clusterManager.init(clusterContext); + + DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); + resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); + + AtomicInteger pending = resourceManager.getPendingWorkerCounter(); + AtomicBoolean lock = resourceManager.getResourceLock(); + while (pending.get() > 0 || !lock.get()) { + SleepUtils.sleepMilliSecond(10); } - @Test - public void testReleaseWithIllegalResourceId() { - config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(5)); - config.put(JOB_UNIQUE_ID.getKey(), "geaflow12345"); - ClusterMetaStore.init(0, "master-0", config); - ClusterContext clusterContext = new ClusterContext(config); - MockClusterManager clusterManager = new MockClusterManager(); - clusterManager.init(clusterContext); - - DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); - resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); - - AtomicInteger pending = resourceManager.getPendingWorkerCounter(); - AtomicBoolean lock = resourceManager.getResourceLock(); - while (pending.get() > 0 || !lock.get()) { - SleepUtils.sleepMilliSecond(10); - } - - RequireResourceRequest request1 = RequireResourceRequest.build(TEST, -1); - RequireResponse response1 = resourceManager.requireResource(request1); - Assert.assertFalse(response1.isSuccess()); - - ReleaseResponse response2 = resourceManager.releaseResource(ReleaseResourceRequest.build(TEST + 1, response1.getWorkers())); - Assert.assertFalse(response2.isSuccess()); + RequireResourceRequest request1 = RequireResourceRequest.build(TEST, -1); + RequireResponse response1 = resourceManager.requireResource(request1); + Assert.assertFalse(response1.isSuccess()); + + ReleaseResponse response2 = + resourceManager.releaseResource( + ReleaseResourceRequest.build(TEST + 1, response1.getWorkers())); + Assert.assertFalse(response2.isSuccess()); + } + + @Test + public void testAllocateWithOneRequireId() { + config.put(JOB_UNIQUE_ID.getKey(), "geaflow23456"); + config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(5)); + ClusterMetaStore.init(0, "master-0", config); + ClusterContext clusterContext = new ClusterContext(config); + MockClusterManager clusterManager = new MockClusterManager(); + clusterManager.init(clusterContext); + + DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); + resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); + + AtomicInteger pending = resourceManager.getPendingWorkerCounter(); + AtomicBoolean lock = resourceManager.getResourceLock(); + while (pending.get() > 0 || !lock.get()) { + SleepUtils.sleepMilliSecond(10); } - @Test - public void testAllocateWithOneRequireId() { - config.put(JOB_UNIQUE_ID.getKey(), "geaflow23456"); - config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(5)); - ClusterMetaStore.init(0, "master-0", config); - ClusterContext clusterContext = new ClusterContext(config); - MockClusterManager clusterManager = new MockClusterManager(); - clusterManager.init(clusterContext); - - DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); - resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); - - AtomicInteger pending = resourceManager.getPendingWorkerCounter(); - AtomicBoolean lock = resourceManager.getResourceLock(); - while (pending.get() > 0 || !lock.get()) { - SleepUtils.sleepMilliSecond(10); - } - - RequireResourceRequest request1 = RequireResourceRequest.build(TEST, 4); - RequireResponse response1 = resourceManager.requireResource(request1); - Assert.assertTrue(response1.isSuccess()); - Assert.assertEquals(response1.getWorkers().size(), 4); - - RequireResourceRequest request2 = RequireResourceRequest.build(TEST, 4); - RequireResponse response2 = resourceManager.requireResource(request2); - Assert.assertTrue(response2.isSuccess()); - Assert.assertEquals(response2.getWorkers().size(), 4); + RequireResourceRequest request1 = RequireResourceRequest.build(TEST, 4); + RequireResponse response1 = resourceManager.requireResource(request1); + Assert.assertTrue(response1.isSuccess()); + Assert.assertEquals(response1.getWorkers().size(), 4); + + RequireResourceRequest request2 = RequireResourceRequest.build(TEST, 4); + RequireResponse response2 = resourceManager.requireResource(request2); + Assert.assertTrue(response2.isSuccess()); + Assert.assertEquals(response2.getWorkers().size(), 4); + } + + @Test + public void testAllocateWithOneRequireIdFail() { + config.put(JOB_UNIQUE_ID.getKey(), "geaflow23456"); + config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(5)); + ClusterMetaStore.init(0, "master-0", config); + ClusterContext clusterContext = new ClusterContext(config); + MockClusterManager clusterManager = new MockClusterManager(); + clusterManager.init(clusterContext); + + DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); + resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); + + AtomicInteger pending = resourceManager.getPendingWorkerCounter(); + AtomicBoolean lock = resourceManager.getResourceLock(); + while (pending.get() > 0 || !lock.get()) { + SleepUtils.sleepMilliSecond(10); } - @Test - public void testAllocateWithOneRequireIdFail() { - config.put(JOB_UNIQUE_ID.getKey(), "geaflow23456"); - config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(5)); - ClusterMetaStore.init(0, "master-0", config); - ClusterContext clusterContext = new ClusterContext(config); - MockClusterManager clusterManager = new MockClusterManager(); - clusterManager.init(clusterContext); - - DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); - resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); - - AtomicInteger pending = resourceManager.getPendingWorkerCounter(); - AtomicBoolean lock = resourceManager.getResourceLock(); - while (pending.get() > 0 || !lock.get()) { - SleepUtils.sleepMilliSecond(10); - } - - RequireResourceRequest request1 = RequireResourceRequest.build(TEST, 4); - RequireResponse response1 = resourceManager.requireResource(request1); - Assert.assertTrue(response1.isSuccess()); - Assert.assertEquals(response1.getWorkers().size(), 4); - - RequireResourceRequest request2 = RequireResourceRequest.build(TEST, 10); - RequireResponse response2 = resourceManager.requireResource(request2); - Assert.assertFalse(response2.isSuccess()); + RequireResourceRequest request1 = RequireResourceRequest.build(TEST, 4); + RequireResponse response1 = resourceManager.requireResource(request1); + Assert.assertTrue(response1.isSuccess()); + Assert.assertEquals(response1.getWorkers().size(), 4); + + RequireResourceRequest request2 = RequireResourceRequest.build(TEST, 10); + RequireResponse response2 = resourceManager.requireResource(request2); + Assert.assertFalse(response2.isSuccess()); + } + + @Test + public void testRecoverWithWorkerGtUserDefinedWorker() { + config.put(JOB_UNIQUE_ID.getKey(), "geaflow23456"); + config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(20)); + ClusterMetaStore.init(0, "master-0", config); + ClusterContext clusterContext = new ClusterContext(config); + MockRecoverClusterManager clusterManager = new MockRecoverClusterManager(); + clusterManager.init(clusterContext); + + DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); + resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); + + AtomicInteger pending = resourceManager.getPendingWorkerCounter(); + AtomicBoolean lock = resourceManager.getResourceLock(); + while (pending.get() > 0 || !lock.get()) { + SleepUtils.sleepMilliSecond(10); } - @Test - public void testRecoverWithWorkerGtUserDefinedWorker() { - config.put(JOB_UNIQUE_ID.getKey(), "geaflow23456"); - config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(20)); - ClusterMetaStore.init(0, "master-0", config); - ClusterContext clusterContext = new ClusterContext(config); - MockRecoverClusterManager clusterManager = new MockRecoverClusterManager(); - clusterManager.init(clusterContext); - - DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); - resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); - - AtomicInteger pending = resourceManager.getPendingWorkerCounter(); - AtomicBoolean lock = resourceManager.getResourceLock(); - while (pending.get() > 0 || !lock.get()) { - SleepUtils.sleepMilliSecond(10); - } - - clusterContext.getCallbacks().clear(); - - DefaultResourceManager recoverRm = new DefaultResourceManager(clusterManager); - MasterContext recoverMasterContext = new MasterContext(config); - clusterContext.setRecover(true); - recoverRm.init(ResourceManagerContext.build(recoverMasterContext, clusterContext)); - - pending = recoverRm.getPendingWorkerCounter(); - lock = recoverRm.getResourceLock(); - // wait async allocate worker ready - do { - SleepUtils.sleepMilliSecond(10); - } while (pending.get() > 0 || !lock.get()); - - RequireResourceRequest request1 = RequireResourceRequest.build(TEST + 1, 4); - RequireResponse response1 = recoverRm.requireResource(request1); - Assert.assertEquals(response1.getWorkers().size(), 4); - - RequireResourceRequest request2 = RequireResourceRequest.build(TEST + 2, 10); - RequireResponse response2 = recoverRm.requireResource(request2); - Assert.assertEquals(response2.getWorkers().size(), 10); - - RequireResourceRequest request3 = RequireResourceRequest.build(TEST + 3, 10); - RequireResponse response3 = recoverRm.requireResource(request3); - Assert.assertEquals(response3.getWorkers().size(), 0); + clusterContext.getCallbacks().clear(); + + DefaultResourceManager recoverRm = new DefaultResourceManager(clusterManager); + MasterContext recoverMasterContext = new MasterContext(config); + clusterContext.setRecover(true); + recoverRm.init(ResourceManagerContext.build(recoverMasterContext, clusterContext)); + + pending = recoverRm.getPendingWorkerCounter(); + lock = recoverRm.getResourceLock(); + // wait async allocate worker ready + do { + SleepUtils.sleepMilliSecond(10); + } while (pending.get() > 0 || !lock.get()); + + RequireResourceRequest request1 = RequireResourceRequest.build(TEST + 1, 4); + RequireResponse response1 = recoverRm.requireResource(request1); + Assert.assertEquals(response1.getWorkers().size(), 4); + + RequireResourceRequest request2 = RequireResourceRequest.build(TEST + 2, 10); + RequireResponse response2 = recoverRm.requireResource(request2); + Assert.assertEquals(response2.getWorkers().size(), 10); + + RequireResourceRequest request3 = RequireResourceRequest.build(TEST + 3, 10); + RequireResponse response3 = recoverRm.requireResource(request3); + Assert.assertEquals(response3.getWorkers().size(), 0); + } + + @Test + public void testRecoverWithWorkerLtUserDefinedWorker() { + config.put(JOB_UNIQUE_ID.getKey(), "geaflow23456"); + config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(5)); + config.put(CONTAINER_NUM.getKey(), String.valueOf(4)); + ClusterMetaStore.init(0, "master-0", config); + ClusterContext clusterContext = new ClusterContext(config); + MockRecoverClusterManager clusterManager = new MockRecoverClusterManager(); + clusterManager.init(clusterContext); + + DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); + resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); + + AtomicInteger pending = resourceManager.getPendingWorkerCounter(); + AtomicBoolean lock = resourceManager.getResourceLock(); + while (pending.get() > 0 || !lock.get()) { + SleepUtils.sleepMilliSecond(10); } - @Test - public void testRecoverWithWorkerLtUserDefinedWorker() { - config.put(JOB_UNIQUE_ID.getKey(), "geaflow23456"); - config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(5)); - config.put(CONTAINER_NUM.getKey(), String.valueOf(4)); - ClusterMetaStore.init(0, "master-0", config); - ClusterContext clusterContext = new ClusterContext(config); - MockRecoverClusterManager clusterManager = new MockRecoverClusterManager(); - clusterManager.init(clusterContext); - - DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); - resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); - - AtomicInteger pending = resourceManager.getPendingWorkerCounter(); - AtomicBoolean lock = resourceManager.getResourceLock(); - while (pending.get() > 0 || !lock.get()) { - SleepUtils.sleepMilliSecond(10); - } - - clusterContext.getCallbacks().clear(); - - DefaultResourceManager recoverRm = new DefaultResourceManager(clusterManager); - MasterContext recoverMasterContext = new MasterContext(config); - clusterContext.setRecover(true); - recoverRm.init(ResourceManagerContext.build(recoverMasterContext, clusterContext)); - - pending = recoverRm.getPendingWorkerCounter(); - lock = recoverRm.getResourceLock(); - // wait async allocate worker ready - do { - SleepUtils.sleepMilliSecond(10); - } while (pending.get() > 0 || !lock.get()); - - RequireResourceRequest request1 = RequireResourceRequest.build(TEST + 1, 4); - RequireResponse response1 = recoverRm.requireResource(request1); - Assert.assertEquals(response1.getWorkers().size(), 4); - - RequireResourceRequest request2 = RequireResourceRequest.build(TEST + 2, 10); - RequireResponse response2 = recoverRm.requireResource(request2); - Assert.assertEquals(response2.getWorkers().size(), 10); - - RequireResourceRequest request3 = RequireResourceRequest.build(TEST + 3, 6); - RequireResponse response3 = recoverRm.requireResource(request3); - Assert.assertEquals(response3.getWorkers().size(), 6); - - RequireResourceRequest request4 = RequireResourceRequest.build(TEST + 4, 1); - RequireResponse response4 = recoverRm.requireResource(request4); - Assert.assertEquals(response4.getWorkers().size(), 0); + clusterContext.getCallbacks().clear(); + + DefaultResourceManager recoverRm = new DefaultResourceManager(clusterManager); + MasterContext recoverMasterContext = new MasterContext(config); + clusterContext.setRecover(true); + recoverRm.init(ResourceManagerContext.build(recoverMasterContext, clusterContext)); + + pending = recoverRm.getPendingWorkerCounter(); + lock = recoverRm.getResourceLock(); + // wait async allocate worker ready + do { + SleepUtils.sleepMilliSecond(10); + } while (pending.get() > 0 || !lock.get()); + + RequireResourceRequest request1 = RequireResourceRequest.build(TEST + 1, 4); + RequireResponse response1 = recoverRm.requireResource(request1); + Assert.assertEquals(response1.getWorkers().size(), 4); + + RequireResourceRequest request2 = RequireResourceRequest.build(TEST + 2, 10); + RequireResponse response2 = recoverRm.requireResource(request2); + Assert.assertEquals(response2.getWorkers().size(), 10); + + RequireResourceRequest request3 = RequireResourceRequest.build(TEST + 3, 6); + RequireResponse response3 = recoverRm.requireResource(request3); + Assert.assertEquals(response3.getWorkers().size(), 6); + + RequireResourceRequest request4 = RequireResourceRequest.build(TEST + 4, 1); + RequireResponse response4 = recoverRm.requireResource(request4); + Assert.assertEquals(response4.getWorkers().size(), 0); + } + + @Test + public void testUseSomeAndRecover() { + config.put(JOB_UNIQUE_ID.getKey(), "geaflow23456"); + config.put(CONTAINER_NUM.getKey(), String.valueOf(4)); + config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(5)); + ClusterMetaStore.init(0, "master-0", config); + ClusterContext clusterContext = new ClusterContext(config); + MockRecoverClusterManager clusterManager = new MockRecoverClusterManager(); + clusterManager.init(clusterContext); + + DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); + resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); + + AtomicInteger pending = resourceManager.getPendingWorkerCounter(); + AtomicBoolean lock = resourceManager.getResourceLock(); + while (pending.get() > 0 || !lock.get()) { + SleepUtils.sleepMilliSecond(10); } - @Test - public void testUseSomeAndRecover() { - config.put(JOB_UNIQUE_ID.getKey(), "geaflow23456"); - config.put(CONTAINER_NUM.getKey(), String.valueOf(4)); - config.put(CONTAINER_WORKER_NUM.getKey(), String.valueOf(5)); - ClusterMetaStore.init(0, "master-0", config); - ClusterContext clusterContext = new ClusterContext(config); - MockRecoverClusterManager clusterManager = new MockRecoverClusterManager(); - clusterManager.init(clusterContext); - - DefaultResourceManager resourceManager = new DefaultResourceManager(clusterManager); - resourceManager.init(ResourceManagerContext.build(new MasterContext(config), clusterContext)); - - AtomicInteger pending = resourceManager.getPendingWorkerCounter(); - AtomicBoolean lock = resourceManager.getResourceLock(); - while (pending.get() > 0 || !lock.get()) { - SleepUtils.sleepMilliSecond(10); - } - - RequireResourceRequest request0 = RequireResourceRequest.build(TEST + 0, 7); - RequireResponse response0 = resourceManager.requireResource(request0); - Assert.assertEquals(response0.getWorkers().size(), 7); - - clusterContext.getCallbacks().clear(); - - DefaultResourceManager recoverRm = new DefaultResourceManager(clusterManager); - MasterContext recoverMasterContext = new MasterContext(config); - clusterContext.setRecover(true); - recoverRm.init(ResourceManagerContext.build(recoverMasterContext, clusterContext)); - - pending = recoverRm.getPendingWorkerCounter(); - lock = recoverRm.getResourceLock(); - while (pending.get() > 0 || !lock.get()) { - SleepUtils.sleepMilliSecond(10); - } - - RequireResourceRequest request1 = RequireResourceRequest.build(TEST + 1, 4); - RequireResponse response1 = recoverRm.requireResource(request1); - Assert.assertEquals(response1.getWorkers().size(), 4); - - RequireResourceRequest request2 = RequireResourceRequest.build(TEST + 2, 10); - RequireResponse response2 = recoverRm.requireResource(request2); - Assert.assertEquals(response2.getWorkers().size(), 0); - } + RequireResourceRequest request0 = RequireResourceRequest.build(TEST + 0, 7); + RequireResponse response0 = resourceManager.requireResource(request0); + Assert.assertEquals(response0.getWorkers().size(), 7); - @Test - public void testPersistWorkers() { - List available = buildMockWorkers(buildMockContainerExecutorInfo(0, 10)); - List used = buildMockWorkers(buildMockContainerExecutorInfo(1, 10)); - ResourceSession session = new ResourceSession(TEST); - for (WorkerInfo worker : used) { - session.addWorker(worker.generateWorkerId(), worker); - } - List sessions = Collections.singletonList(session); - - WorkerSnapshot workerSnapshot = new WorkerSnapshot(); - workerSnapshot.setAvailableWorkers(available); - workerSnapshot.setSessions(sessions); - Assert.assertNotNull(workerSnapshot.getAvailableWorkers()); - Assert.assertNotNull(workerSnapshot.getSessions()); - Assert.assertEquals(workerSnapshot.getAvailableWorkers().size(), 10); - Assert.assertEquals(workerSnapshot.getSessions().size(), 1); - Assert.assertEquals(workerSnapshot.getSessions().get(0).getWorkers().size(), 10); - // test serialize - ISerializer kryoSerializer = SerializerFactory.getKryoSerializer(); - byte[] bytes = kryoSerializer.serialize(workerSnapshot); - WorkerSnapshot deserialized = (WorkerSnapshot) kryoSerializer.deserialize(bytes); - List deAvailable = deserialized.getAvailableWorkers(); - List deSessions = deserialized.getSessions(); - Assert.assertNotNull(deAvailable); - Assert.assertNotNull(deSessions); - for (int i = 0; i < 10; i++) { - Assert.assertEquals(deAvailable.get(i).getContainerName(), "host0"); - Assert.assertEquals(deAvailable.get(i).getHost(), "host0"); - Assert.assertEquals(deAvailable.get(i).getWorkerIndex(), i); - } - List deUsed = new ArrayList<>(deSessions.get(0).getWorkers().values()); - deUsed.sort(Comparator.comparing(WorkerInfo::getWorkerIndex)); - for (int i = 0; i < 10; i++) { - Assert.assertEquals(deUsed.get(i).getContainerName(), "host1"); - Assert.assertEquals(deUsed.get(i).getHost(), "host1"); - Assert.assertEquals(deUsed.get(i).getWorkerIndex(), i); - } + clusterContext.getCallbacks().clear(); - } + DefaultResourceManager recoverRm = new DefaultResourceManager(clusterManager); + MasterContext recoverMasterContext = new MasterContext(config); + clusterContext.setRecover(true); + recoverRm.init(ResourceManagerContext.build(recoverMasterContext, clusterContext)); - @AfterTest - public void destroy() { - POOL.shutdown(); + pending = recoverRm.getPendingWorkerCounter(); + lock = recoverRm.getResourceLock(); + while (pending.get() > 0 || !lock.get()) { + SleepUtils.sleepMilliSecond(10); } - private static ContainerExecutorInfo buildMockContainerExecutorInfo(int hostId, int workerNum) { - String host = "host" + hostId; - return new ContainerExecutorInfo(new ContainerInfo(hostId, host, host, 0, 0, 0), 0, workerNum); + RequireResourceRequest request1 = RequireResourceRequest.build(TEST + 1, 4); + RequireResponse response1 = recoverRm.requireResource(request1); + Assert.assertEquals(response1.getWorkers().size(), 4); + + RequireResourceRequest request2 = RequireResourceRequest.build(TEST + 2, 10); + RequireResponse response2 = recoverRm.requireResource(request2); + Assert.assertEquals(response2.getWorkers().size(), 0); + } + + @Test + public void testPersistWorkers() { + List available = buildMockWorkers(buildMockContainerExecutorInfo(0, 10)); + List used = buildMockWorkers(buildMockContainerExecutorInfo(1, 10)); + ResourceSession session = new ResourceSession(TEST); + for (WorkerInfo worker : used) { + session.addWorker(worker.generateWorkerId(), worker); + } + List sessions = Collections.singletonList(session); + + WorkerSnapshot workerSnapshot = new WorkerSnapshot(); + workerSnapshot.setAvailableWorkers(available); + workerSnapshot.setSessions(sessions); + Assert.assertNotNull(workerSnapshot.getAvailableWorkers()); + Assert.assertNotNull(workerSnapshot.getSessions()); + Assert.assertEquals(workerSnapshot.getAvailableWorkers().size(), 10); + Assert.assertEquals(workerSnapshot.getSessions().size(), 1); + Assert.assertEquals(workerSnapshot.getSessions().get(0).getWorkers().size(), 10); + // test serialize + ISerializer kryoSerializer = SerializerFactory.getKryoSerializer(); + byte[] bytes = kryoSerializer.serialize(workerSnapshot); + WorkerSnapshot deserialized = (WorkerSnapshot) kryoSerializer.deserialize(bytes); + List deAvailable = deserialized.getAvailableWorkers(); + List deSessions = deserialized.getSessions(); + Assert.assertNotNull(deAvailable); + Assert.assertNotNull(deSessions); + for (int i = 0; i < 10; i++) { + Assert.assertEquals(deAvailable.get(i).getContainerName(), "host0"); + Assert.assertEquals(deAvailable.get(i).getHost(), "host0"); + Assert.assertEquals(deAvailable.get(i).getWorkerIndex(), i); + } + List deUsed = new ArrayList<>(deSessions.get(0).getWorkers().values()); + deUsed.sort(Comparator.comparing(WorkerInfo::getWorkerIndex)); + for (int i = 0; i < 10; i++) { + Assert.assertEquals(deUsed.get(i).getContainerName(), "host1"); + Assert.assertEquals(deUsed.get(i).getHost(), "host1"); + Assert.assertEquals(deUsed.get(i).getWorkerIndex(), i); + } + } + + @AfterTest + public void destroy() { + POOL.shutdown(); + } + + private static ContainerExecutorInfo buildMockContainerExecutorInfo(int hostId, int workerNum) { + String host = "host" + hostId; + return new ContainerExecutorInfo(new ContainerInfo(hostId, host, host, 0, 0, 0), 0, workerNum); + } + + private static List buildMockWorkers(ContainerExecutorInfo container) { + String containerName = container.getContainerName(); + String host = container.getHost(); + int rpcPort = container.getRpcPort(); + int shufflePort = container.getShufflePort(); + int processId = container.getProcessId(); + List executorIds = container.getExecutorIds(); + List workers = new ArrayList<>(); + for (Integer workerIndex : executorIds) { + WorkerInfo worker = + WorkerInfo.build(host, rpcPort, shufflePort, processId, 0, workerIndex, containerName); + workers.add(worker); + } + return workers; + } + + private static class MockClusterManager implements IClusterManager { + + protected final AtomicInteger counter = new AtomicInteger(); + protected ClusterContext clusterContext; + protected int containerWorkerNum; + + @Override + public void init(ClusterContext context) { + this.clusterContext = context; + this.clusterContext + .getClusterConfig() + .getConfig() + .put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); + ClusterConfig clusterConfig = ClusterConfig.build(context.getConfig()); + this.containerWorkerNum = clusterConfig.getContainerWorkerNum(); } - private static List buildMockWorkers(ContainerExecutorInfo container) { - String containerName = container.getContainerName(); - String host = container.getHost(); - int rpcPort = container.getRpcPort(); - int shufflePort = container.getShufflePort(); - int processId = container.getProcessId(); - List executorIds = container.getExecutorIds(); - List workers = new ArrayList<>(); - for (Integer workerIndex : executorIds) { - WorkerInfo worker = WorkerInfo.build( - host, rpcPort, shufflePort, processId, 0, workerIndex, containerName); - workers.add(worker); - } - return workers; + @Override + public ClusterId startMaster() { + return null; } - private static class MockClusterManager implements IClusterManager { - - protected final AtomicInteger counter = new AtomicInteger(); - protected ClusterContext clusterContext; - protected int containerWorkerNum; - - @Override - public void init(ClusterContext context) { - this.clusterContext = context; - this.clusterContext.getClusterConfig().getConfig().put(SYSTEM_STATE_BACKEND_TYPE.getKey(), - StoreType.MEMORY.name()); - ClusterConfig clusterConfig = ClusterConfig.build(context.getConfig()); - this.containerWorkerNum = clusterConfig.getContainerWorkerNum(); - } - - @Override - public ClusterId startMaster() { - return null; - } - - @Override - public Map startDrivers() { - return null; - } - - @Override - public void allocateWorkers(int executorNum) { - POOL.execute(() -> { - int left = executorNum; - while (left > 0) { - for (ExecutorRegisteredCallback callback : this.clusterContext.getCallbacks()) { - callback.onSuccess(buildMockContainerExecutorInfo(this.counter.getAndIncrement(), this.containerWorkerNum)); - } - left -= this.containerWorkerNum; - } - }); - } - - @Override - public void doFailover(int componentId, Throwable cause) { - } - - @Override - public void close() { - } + @Override + public Map startDrivers() { + return null; } - private static class MockRecoverClusterManager extends MockClusterManager { - - @Override - public void allocateWorkers(int executorNum) { - POOL.execute(() -> { - int left = executorNum; - int hostId = 0; - while (left > 0) { - for (ExecutorRegisteredCallback callback : this.clusterContext.getCallbacks()) { - callback.onSuccess(buildMockContainerExecutorInfo(hostId, this.containerWorkerNum)); - } - left -= this.containerWorkerNum; - hostId++; - } - }); - } + @Override + public void allocateWorkers(int executorNum) { + POOL.execute( + () -> { + int left = executorNum; + while (left > 0) { + for (ExecutorRegisteredCallback callback : this.clusterContext.getCallbacks()) { + callback.onSuccess( + buildMockContainerExecutorInfo( + this.counter.getAndIncrement(), this.containerWorkerNum)); + } + left -= this.containerWorkerNum; + } + }); } + @Override + public void doFailover(int componentId, Throwable cause) {} + + @Override + public void close() {} + } + + private static class MockRecoverClusterManager extends MockClusterManager { + + @Override + public void allocateWorkers(int executorNum) { + POOL.execute( + () -> { + int left = executorNum; + int hostId = 0; + while (left > 0) { + for (ExecutorRegisteredCallback callback : this.clusterContext.getCallbacks()) { + callback.onSuccess(buildMockContainerExecutorInfo(hostId, this.containerWorkerNum)); + } + left -= this.containerWorkerNum; + hostId++; + } + }); + } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/rpc/AsyncRpcTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/rpc/AsyncRpcTest.java index 7e1b57d69..4a2c52b24 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/rpc/AsyncRpcTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/rpc/AsyncRpcTest.java @@ -19,12 +19,12 @@ package org.apache.geaflow.cluster.rpc; -import com.google.protobuf.Empty; import java.util.ArrayList; import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.cluster.protocol.EventType; import org.apache.geaflow.cluster.protocol.IEvent; import org.apache.geaflow.cluster.rpc.impl.ContainerEndpointRef; @@ -45,183 +45,176 @@ import org.testng.Assert; import org.testng.annotations.Test; +import com.google.protobuf.Empty; + public class AsyncRpcTest { - private static final Logger LOGGER = LoggerFactory.getLogger(AsyncRpcTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(AsyncRpcTest.class); - @Test - public void testAsyncRpc() throws Exception { + @Test + public void testAsyncRpc() throws Exception { - Server server = new Server(); - server.startServer(); - String host = ProcessUtil.getHostIp(); - ContainerEndpointRef client = new ContainerEndpointRef(host, server.rpcPort, - new Configuration()); + Server server = new Server(); + server.startServer(); + String host = ProcessUtil.getHostIp(); + ContainerEndpointRef client = + new ContainerEndpointRef(host, server.rpcPort, new Configuration()); - int eventCount = 100; - List request = new ArrayList<>(); - List> events = new ArrayList<>(); - for (int i = 0; i < eventCount; i++) { - IEvent event = new TestEvent(i); - request.add(event); - events.add(new RpcResponseFuture(client.process(event, new DefaultRpcCallbackImpl<>()))); - } - validateResult(events, eventCount, 5000); - LOGGER.info("send event finish"); + int eventCount = 100; + List request = new ArrayList<>(); + List> events = new ArrayList<>(); + for (int i = 0; i < eventCount; i++) { + IEvent event = new TestEvent(i); + request.add(event); + events.add(new RpcResponseFuture(client.process(event, new DefaultRpcCallbackImpl<>()))); } - - @Test - public void testShutdownChannel() throws Exception { - Server server = new Server(); - server.startServer(); - String host = ProcessUtil.getHostIp(); - ContainerEndpointRef client = new ContainerEndpointRef(host, server.rpcPort, - new Configuration()); - - int eventCount = 1000; - List eventIds = new ArrayList<>(); - List processedIds = new ArrayList<>(); - for (int i = 0; i < eventCount; i++) { - TestEvent event = new TestEvent(i); - event.processTimeMs = 1; - Future future = new RpcResponseFuture(client.process(event, new DefaultRpcCallbackImpl<>())); - processedIds.add(((TestEvent) (future.get(1000, TimeUnit.MILLISECONDS))).id); - eventIds.add(i); - // Do shutdown. - if (i % 100 == 0) { - LOGGER.info("shutdown channel"); - server.stopServer(); - LOGGER.info("shutdown channel finish"); - SleepUtils.sleepMilliSecond(10); - } - } - - Assert.assertEquals(processedIds, eventIds); - LOGGER.info("send event finish"); + validateResult(events, eventCount, 5000); + LOGGER.info("send event finish"); + } + + @Test + public void testShutdownChannel() throws Exception { + Server server = new Server(); + server.startServer(); + String host = ProcessUtil.getHostIp(); + ContainerEndpointRef client = + new ContainerEndpointRef(host, server.rpcPort, new Configuration()); + + int eventCount = 1000; + List eventIds = new ArrayList<>(); + List processedIds = new ArrayList<>(); + for (int i = 0; i < eventCount; i++) { + TestEvent event = new TestEvent(i); + event.processTimeMs = 1; + Future future = + new RpcResponseFuture(client.process(event, new DefaultRpcCallbackImpl<>())); + processedIds.add(((TestEvent) (future.get(1000, TimeUnit.MILLISECONDS))).id); + eventIds.add(i); + // Do shutdown. + if (i % 100 == 0) { + LOGGER.info("shutdown channel"); + server.stopServer(); + LOGGER.info("shutdown channel finish"); + SleepUtils.sleepMilliSecond(10); + } } - @Test(expectedExceptions = ExecutionException.class) - public void testServerError() throws Exception { - - Server server = new Server(); - server.startServer(); - String host = ProcessUtil.getHostIp(); - ContainerEndpointRef client = new ContainerEndpointRef(host, server.rpcPort, - new Configuration()); - - int eventCount = 100; - List> results = new ArrayList<>(); - for (int i = 0; i < eventCount; i++) { - TestEvent event = new TestEvent(i); - if (i == 50) { - event.isException = true; - } - results.add(new RpcResponseFuture(client.process(event, new DefaultRpcCallbackImpl<>()))); - } - LOGGER.info("send event finish"); - validateResult(results, eventCount, 5000); + Assert.assertEquals(processedIds, eventIds); + LOGGER.info("send event finish"); + } + + @Test(expectedExceptions = ExecutionException.class) + public void testServerError() throws Exception { + + Server server = new Server(); + server.startServer(); + String host = ProcessUtil.getHostIp(); + ContainerEndpointRef client = + new ContainerEndpointRef(host, server.rpcPort, new Configuration()); + + int eventCount = 100; + List> results = new ArrayList<>(); + for (int i = 0; i < eventCount; i++) { + TestEvent event = new TestEvent(i); + if (i == 50) { + event.isException = true; + } + results.add(new RpcResponseFuture(client.process(event, new DefaultRpcCallbackImpl<>()))); } - - public void validateResult(List> results, int count, int waitTimeMs) - throws Exception { - List eventIds = new ArrayList<>(); - List processedIds = new ArrayList<>(); - LOGGER.info("validate result"); - for (int i = 0; i < count; i++) { - eventIds.add(i); - processedIds.add( - ((TestEvent) (results.get(i).get(waitTimeMs, TimeUnit.MILLISECONDS))).id); - } - Assert.assertEquals(processedIds, eventIds); - LOGGER.info("validate finish"); + LOGGER.info("send event finish"); + validateResult(results, eventCount, 5000); + } + + public void validateResult(List> results, int count, int waitTimeMs) + throws Exception { + List eventIds = new ArrayList<>(); + List processedIds = new ArrayList<>(); + LOGGER.info("validate result"); + for (int i = 0; i < count; i++) { + eventIds.add(i); + processedIds.add(((TestEvent) (results.get(i).get(waitTimeMs, TimeUnit.MILLISECONDS))).id); } + Assert.assertEquals(processedIds, eventIds); + LOGGER.info("validate finish"); + } - /** - * Mock event with dummy info. - */ - public class TestEvent implements IEvent { + /** Mock event with dummy info. */ + public class TestEvent implements IEvent { - private int id; - private int processTimeMs; - private boolean isException; - - public TestEvent(int id) { - this.id = id; - } + private int id; + private int processTimeMs; + private boolean isException; + public TestEvent(int id) { + this.id = id; + } - @Override - public EventType getEventType() { - return null; - } + @Override + public EventType getEventType() { + return null; + } - @Override - public String toString() { - return "TestEvent{" + - "id=" + id + - '}'; - } + @Override + public String toString() { + return "TestEvent{" + "id=" + id + '}'; } + } - public class Server { + public class Server { - protected int rpcPort; + protected int rpcPort; - protected Configuration configuration = new Configuration(); - protected RpcServiceImpl rpcService; + protected Configuration configuration = new Configuration(); + protected RpcServiceImpl rpcService; - public void startServer() { - this.rpcService = new RpcServiceImpl(PortUtil.getPort(rpcPort), - ConfigurableServerOption.build(configuration)); - this.rpcService.addEndpoint(new MockContainerEndpoint()); - this.rpcPort = rpcService.startService(); - } + public void startServer() { + this.rpcService = + new RpcServiceImpl( + PortUtil.getPort(rpcPort), ConfigurableServerOption.build(configuration)); + this.rpcService.addEndpoint(new MockContainerEndpoint()); + this.rpcPort = rpcService.startService(); + } - public void stopServer() { - rpcService.stopService(); - } + public void stopServer() { + rpcService.stopService(); } + } - /** - * Mock endpoint to process events. - */ - public class MockContainerEndpoint implements IContainerEndpoint { + /** Mock endpoint to process events. */ + public class MockContainerEndpoint implements IContainerEndpoint { - public MockContainerEndpoint() { - } + public MockContainerEndpoint() {} - @Override - public Response process(Request request) { - try { - IEvent event = RpcMessageEncoder.decode(request.getPayload()); - if (((TestEvent) event).processTimeMs > 0) { - SleepUtils.sleepMilliSecond(((TestEvent) event).processTimeMs); - } - if (((TestEvent) event).isException) { - LOGGER.info("on error: mock exception"); - throw new GeaflowRuntimeException("occur mock exception"); - } else { - Container.Response.Builder builder = - Container.Response.newBuilder(); - builder.setPayload(request.getPayload()); - return builder.build(); - } - } catch (Throwable t) { - LOGGER.error("process request failed: {}", t.getMessage(), t); - throw new GeaflowRuntimeException("process request failed", t); - } + @Override + public Response process(Request request) { + try { + IEvent event = RpcMessageEncoder.decode(request.getPayload()); + if (((TestEvent) event).processTimeMs > 0) { + SleepUtils.sleepMilliSecond(((TestEvent) event).processTimeMs); } - - @Override - public Empty close(Empty request) { - try { - LOGGER.info("close"); - return Empty.newBuilder().build(); - } catch (Throwable t) { - LOGGER.error("close failed: {}", t.getMessage(), t); - throw new GeaflowRuntimeException(String.format("close failed: %s", t.getMessage()), t); - } + if (((TestEvent) event).isException) { + LOGGER.info("on error: mock exception"); + throw new GeaflowRuntimeException("occur mock exception"); + } else { + Container.Response.Builder builder = Container.Response.newBuilder(); + builder.setPayload(request.getPayload()); + return builder.build(); } + } catch (Throwable t) { + LOGGER.error("process request failed: {}", t.getMessage(), t); + throw new GeaflowRuntimeException("process request failed", t); + } } + @Override + public Empty close(Empty request) { + try { + LOGGER.info("close"); + return Empty.newBuilder().build(); + } catch (Throwable t) { + LOGGER.error("close failed: {}", t.getMessage(), t); + throw new GeaflowRuntimeException(String.format("close failed: %s", t.getMessage()), t); + } + } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/rpc/RpcAddressTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/rpc/RpcAddressTest.java index 12d6c1c3e..4124ab0a4 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/rpc/RpcAddressTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/rpc/RpcAddressTest.java @@ -24,11 +24,10 @@ public class RpcAddressTest { - @Test - public void testAddress() { - ConnectAddress address = new ConnectAddress("localhost", 0); - ConnectAddress newAddr = ConnectAddress.build(address.toString()); - Assert.assertEquals(address, newAddr); - } - + @Test + public void testAddress() { + ConnectAddress address = new ConnectAddress("localhost", 0); + ConnectAddress newAddr = ConnectAddress.build(address.toString()); + Assert.assertEquals(address, newAddr); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/rpc/RpcClientTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/rpc/RpcClientTest.java index c6e4fe987..27606f11d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/rpc/RpcClientTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/rpc/RpcClientTest.java @@ -31,25 +31,23 @@ public class RpcClientTest { - @Test - public void testInvalidateResource() { - Configuration config = new Configuration(); - config.put(RUN_LOCAL_MODE, "true"); - RpcClient rpcClient = RpcClient.init(config); - IHAService haService = HAServiceFactory.getService(); - - ResourceData resourceData = new ResourceData(); - resourceData.setHost("host"); - resourceData.setRpcPort(2); - haService.register("1", resourceData); - - resourceData = rpcClient.getResourceData("1"); - Assert.assertNotNull(resourceData); - - rpcClient.invalidateEndpointCache("1", EndpointType.MASTER); - resourceData = rpcClient.getResourceData("1"); - Assert.assertNull(resourceData); - } - - + @Test + public void testInvalidateResource() { + Configuration config = new Configuration(); + config.put(RUN_LOCAL_MODE, "true"); + RpcClient rpcClient = RpcClient.init(config); + IHAService haService = HAServiceFactory.getService(); + + ResourceData resourceData = new ResourceData(); + resourceData.setHost("host"); + resourceData.setRpcPort(2); + haService.register("1", resourceData); + + resourceData = rpcClient.getResourceData("1"); + Assert.assertNotNull(resourceData); + + rpcClient.invalidateEndpointCache("1", EndpointType.MASTER); + resourceData = rpcClient.getResourceData("1"); + Assert.assertNull(resourceData); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/rpc/RpcEndpointRefTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/rpc/RpcEndpointRefTest.java index b582c36b3..bbe876e57 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/rpc/RpcEndpointRefTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/rpc/RpcEndpointRefTest.java @@ -19,72 +19,82 @@ package org.apache.geaflow.cluster.rpc; -import com.google.common.util.concurrent.FutureCallback; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.SettableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; + import javax.annotation.Nullable; + import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.testng.Assert; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; + public class RpcEndpointRefTest { - private static final Logger LOGGER = LoggerFactory.getLogger(RpcEndpointRefTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(RpcEndpointRefTest.class); - @Test - public void testFuture() throws ExecutionException, InterruptedException { - ExecutorService executorService = Executors.newFixedThreadPool(2); - SettableFuture future = SettableFuture.create(); + @Test + public void testFuture() throws ExecutionException, InterruptedException { + ExecutorService executorService = Executors.newFixedThreadPool(2); + SettableFuture future = SettableFuture.create(); - AtomicInteger result = new AtomicInteger(0); - CountDownLatch countDownLatch = new CountDownLatch(1); - handleFutureCallback(future, new RpcEndpointRef.RpcCallback() { - @Override - public void onSuccess(Object value) { - LOGGER.info("on success"); - result.set(1); - countDownLatch.countDown(); - } + AtomicInteger result = new AtomicInteger(0); + CountDownLatch countDownLatch = new CountDownLatch(1); + handleFutureCallback( + future, + new RpcEndpointRef.RpcCallback() { + @Override + public void onSuccess(Object value) { + LOGGER.info("on success"); + result.set(1); + countDownLatch.countDown(); + } - @Override - public void onFailure(Throwable t) { - LOGGER.info("on failure"); - result.set(2); - countDownLatch.countDown(); - } - }, executorService); + @Override + public void onFailure(Throwable t) { + LOGGER.info("on failure"); + result.set(2); + countDownLatch.countDown(); + } + }, + executorService); - future.set("test"); + future.set("test"); - countDownLatch.await(2, TimeUnit.SECONDS); - Assert.assertEquals(1, result.get()); - Assert.assertEquals("test", future.get()); + countDownLatch.await(2, TimeUnit.SECONDS); + Assert.assertEquals(1, result.get()); + Assert.assertEquals("test", future.get()); - System.out.println("final result" + future.get()); - } + System.out.println("final result" + future.get()); + } - public void handleFutureCallback(ListenableFuture future, - RpcEndpointRef.RpcCallback listener, - ExecutorService executorService) { - Futures.addCallback(future, new FutureCallback() { - @Override - public void onSuccess(@Nullable T result) { - listener.onSuccess(result); - } + public void handleFutureCallback( + ListenableFuture future, + RpcEndpointRef.RpcCallback listener, + ExecutorService executorService) { + Futures.addCallback( + future, + new FutureCallback() { + @Override + public void onSuccess(@Nullable T result) { + listener.onSuccess(result); + } - @Override - public void onFailure(Throwable t) { - LOGGER.error("rpc call failed " + t); - listener.onFailure(t); - } - }, executorService); - } -} \ No newline at end of file + @Override + public void onFailure(Throwable t) { + LOGGER.error("rpc call failed " + t); + listener.onFailure(t); + } + }, + executorService); + } +} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/runner/SupervisorTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/runner/SupervisorTest.java index 1f3de201e..46099285d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/runner/SupervisorTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/runner/SupervisorTest.java @@ -21,27 +21,28 @@ import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.utils.SleepUtils; import org.testng.Assert; import org.testng.annotations.Test; public class SupervisorTest { - @Test - public void test() { - Configuration configuration = new Configuration(); - Map envMap = new HashMap<>(); - String cmd = "sleep 30000"; - Supervisor supervisor = new Supervisor(cmd, configuration, false); - supervisor.startWorker(); - // wait for process starts. - while (!supervisor.isWorkerAlive()) { - SleepUtils.sleepMilliSecond(500); - } - supervisor.stopWorker(); - // wait for process exits. - SleepUtils.sleepSecond(1); - Assert.assertFalse(supervisor.isWorkerAlive()); - supervisor.stop(); + @Test + public void test() { + Configuration configuration = new Configuration(); + Map envMap = new HashMap<>(); + String cmd = "sleep 30000"; + Supervisor supervisor = new Supervisor(cmd, configuration, false); + supervisor.startWorker(); + // wait for process starts. + while (!supervisor.isWorkerAlive()) { + SleepUtils.sleepMilliSecond(500); } + supervisor.stopWorker(); + // wait for process exits. + SleepUtils.sleepSecond(1); + Assert.assertFalse(supervisor.isWorkerAlive()); + supervisor.stop(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/system/ClusterMetaStoreFactoryTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/system/ClusterMetaStoreFactoryTest.java index d31fb8b5f..d72bdedc0 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/system/ClusterMetaStoreFactoryTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/system/ClusterMetaStoreFactoryTest.java @@ -20,6 +20,7 @@ package org.apache.geaflow.cluster.system; import java.io.File; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -30,27 +31,28 @@ public class ClusterMetaStoreFactoryTest { - @Test - public void testCreateMemoryKvStore() { - Configuration configuration = new Configuration(); - configuration.put(FrameworkConfigKeys.SYSTEM_STATE_BACKEND_TYPE, "memory"); - IClusterMetaKVStore ikvStore = ClusterMetaStoreFactory.create("test", configuration); - Assert.assertNotNull(ikvStore); - } - - @Test - public void testCreateRocksdbKvStore() { - - String path = "/tmp/" + ClusterMetaStoreFactoryTest.class.getSimpleName(); - FileUtils.deleteQuietly(new File(path)); - - Configuration configuration = new Configuration(); - configuration.put(ExecutionConfigKeys.JOB_APP_NAME, ClusterMetaStoreFactoryTest.class.getSimpleName()); - configuration.put(FileConfigKeys.PERSISTENT_TYPE, "LOCAL"); - configuration.put(FileConfigKeys.ROOT, path); - configuration.put(FrameworkConfigKeys.SYSTEM_STATE_BACKEND_TYPE, "rocksdb"); - configuration.put(ExecutionConfigKeys.SYSTEM_META_TABLE, "cluster_metastore_kv_test"); - IClusterMetaKVStore ikvStore = ClusterMetaStoreFactory.create("test", configuration); - Assert.assertNotNull(ikvStore); - } + @Test + public void testCreateMemoryKvStore() { + Configuration configuration = new Configuration(); + configuration.put(FrameworkConfigKeys.SYSTEM_STATE_BACKEND_TYPE, "memory"); + IClusterMetaKVStore ikvStore = ClusterMetaStoreFactory.create("test", configuration); + Assert.assertNotNull(ikvStore); + } + + @Test + public void testCreateRocksdbKvStore() { + + String path = "/tmp/" + ClusterMetaStoreFactoryTest.class.getSimpleName(); + FileUtils.deleteQuietly(new File(path)); + + Configuration configuration = new Configuration(); + configuration.put( + ExecutionConfigKeys.JOB_APP_NAME, ClusterMetaStoreFactoryTest.class.getSimpleName()); + configuration.put(FileConfigKeys.PERSISTENT_TYPE, "LOCAL"); + configuration.put(FileConfigKeys.ROOT, path); + configuration.put(FrameworkConfigKeys.SYSTEM_STATE_BACKEND_TYPE, "rocksdb"); + configuration.put(ExecutionConfigKeys.SYSTEM_META_TABLE, "cluster_metastore_kv_test"); + IClusterMetaKVStore ikvStore = ClusterMetaStoreFactory.create("test", configuration); + Assert.assertNotNull(ikvStore); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/system/ClusterMetaStoreTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/system/ClusterMetaStoreTest.java index 5999f715c..b57c0ab79 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/system/ClusterMetaStoreTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/system/ClusterMetaStoreTest.java @@ -22,6 +22,7 @@ import java.io.File; import java.util.ArrayList; import java.util.HashMap; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.cluster.resourcemanager.WorkerSnapshot; import org.apache.geaflow.common.config.Configuration; @@ -33,40 +34,42 @@ public class ClusterMetaStoreTest { - private Configuration configuration = new Configuration(); - - @BeforeMethod - public void before() { - String path = "/tmp/" + ClusterMetaStoreTest.class.getSimpleName(); - FileUtils.deleteQuietly(new File(path)); - - configuration.getConfigMap().clear(); - configuration.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), ClusterMetaStoreTest.class.getSimpleName()); - configuration.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); - configuration.put(FileConfigKeys.ROOT.getKey(), path); - configuration.put(ExecutionConfigKeys.CLUSTER_ID, "test1"); + private Configuration configuration = new Configuration(); - } + @BeforeMethod + public void before() { + String path = "/tmp/" + ClusterMetaStoreTest.class.getSimpleName(); + FileUtils.deleteQuietly(new File(path)); - @Test - public void testMultiThreadSave() { - int id = 0; - ClusterMetaStore metaStore = ClusterMetaStore.getInstance(id, "master-0", configuration); - metaStore.getContainerIds(); + configuration.getConfigMap().clear(); + configuration.put( + ExecutionConfigKeys.JOB_APP_NAME.getKey(), ClusterMetaStoreTest.class.getSimpleName()); + configuration.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); + configuration.put(FileConfigKeys.ROOT.getKey(), path); + configuration.put(ExecutionConfigKeys.CLUSTER_ID, "test1"); + } - Thread thread = new Thread(() -> { - WorkerSnapshot workerSnapshot = new WorkerSnapshot(new ArrayList<>(), new ArrayList<>()); - metaStore.saveWorkers(workerSnapshot); - SleepUtils.sleepSecond(1); - metaStore.flush(); - }); - thread.start(); + @Test + public void testMultiThreadSave() { + int id = 0; + ClusterMetaStore metaStore = ClusterMetaStore.getInstance(id, "master-0", configuration); + metaStore.getContainerIds(); - HashMap ids = new HashMap<>(); - ids.put(1, "1"); - metaStore.saveContainerIds(ids); - SleepUtils.sleepSecond(2); - metaStore.flush(); - } + Thread thread = + new Thread( + () -> { + WorkerSnapshot workerSnapshot = + new WorkerSnapshot(new ArrayList<>(), new ArrayList<>()); + metaStore.saveWorkers(workerSnapshot); + SleepUtils.sleepSecond(1); + metaStore.flush(); + }); + thread.start(); -} \ No newline at end of file + HashMap ids = new HashMap<>(); + ids.put(1, "1"); + metaStore.saveContainerIds(ids); + SleepUtils.sleepSecond(2); + metaStore.flush(); + } +} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/system/MemoryClusterMetaKVStoreTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/system/MemoryClusterMetaKVStoreTest.java index 9612fe70d..1b844bb56 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/system/MemoryClusterMetaKVStoreTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/system/MemoryClusterMetaKVStoreTest.java @@ -26,35 +26,35 @@ public class MemoryClusterMetaKVStoreTest { - @Test - public void testStore() { - Configuration config = new Configuration(); - IClusterMetaKVStore kvStore = new MemoryClusterMetaKVStore(); - StoreContext storeContext = new StoreContext("cluster_meta_test"); - storeContext.withConfig(config); - kvStore.init(storeContext); - - kvStore.put("key1", "value1"); - kvStore.put("key2", "value2"); - kvStore.flush(); - Assert.assertEquals(kvStore.get("key1"), "value1"); - Assert.assertEquals(kvStore.get("key2"), "value2"); - - kvStore.put("key1", "value1"); - kvStore.put("key3", "value3"); - kvStore.flush(); - - Assert.assertEquals(kvStore.get("key1"), "value1"); - Assert.assertEquals(kvStore.get("key2"), "value2"); - Assert.assertEquals(kvStore.get("key3"), "value3"); - Assert.assertEquals(kvStore.get("key4"), null); - - kvStore.remove("key1"); - kvStore.remove("key4"); - kvStore.flush(); - Assert.assertEquals(kvStore.get("key1"), null); - Assert.assertEquals(kvStore.get("key2"), "value2"); - Assert.assertEquals(kvStore.get("key3"), "value3"); - Assert.assertEquals(kvStore.get("key4"), null); - } + @Test + public void testStore() { + Configuration config = new Configuration(); + IClusterMetaKVStore kvStore = new MemoryClusterMetaKVStore(); + StoreContext storeContext = new StoreContext("cluster_meta_test"); + storeContext.withConfig(config); + kvStore.init(storeContext); + + kvStore.put("key1", "value1"); + kvStore.put("key2", "value2"); + kvStore.flush(); + Assert.assertEquals(kvStore.get("key1"), "value1"); + Assert.assertEquals(kvStore.get("key2"), "value2"); + + kvStore.put("key1", "value1"); + kvStore.put("key3", "value3"); + kvStore.flush(); + + Assert.assertEquals(kvStore.get("key1"), "value1"); + Assert.assertEquals(kvStore.get("key2"), "value2"); + Assert.assertEquals(kvStore.get("key3"), "value3"); + Assert.assertEquals(kvStore.get("key4"), null); + + kvStore.remove("key1"); + kvStore.remove("key4"); + kvStore.flush(); + Assert.assertEquals(kvStore.get("key1"), null); + Assert.assertEquals(kvStore.get("key2"), "value2"); + Assert.assertEquals(kvStore.get("key3"), "value3"); + Assert.assertEquals(kvStore.get("key4"), null); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/system/RocksdbClusterMetaKVStoreTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/system/RocksdbClusterMetaKVStoreTest.java index c29a17483..7f21a0aee 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/system/RocksdbClusterMetaKVStoreTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/system/RocksdbClusterMetaKVStoreTest.java @@ -20,6 +20,7 @@ package org.apache.geaflow.cluster.system; import java.io.File; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -31,42 +32,42 @@ public class RocksdbClusterMetaKVStoreTest { - @Test - public void testRocksdbKvStore() { - Configuration config = new Configuration(); - FileUtils.deleteQuietly(new File("/tmp/RocksdbStoreBuilderTest")); - config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), "RocksdbStoreBuilderTest"); - config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); - config.put(FileConfigKeys.ROOT.getKey(), "/tmp/RocksdbStoreBuilderTest"); + @Test + public void testRocksdbKvStore() { + Configuration config = new Configuration(); + FileUtils.deleteQuietly(new File("/tmp/RocksdbStoreBuilderTest")); + config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), "RocksdbStoreBuilderTest"); + config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); + config.put(FileConfigKeys.ROOT.getKey(), "/tmp/RocksdbStoreBuilderTest"); - IClusterMetaKVStore kvStore = new RocksdbClusterMetaKVStore<>(); - StoreContext storeContext = new StoreContext("cluster_meta_test"); - storeContext.withConfig(config); - storeContext.withKeySerializer(new DefaultKVSerializer(null, null)); - kvStore.init(storeContext); + IClusterMetaKVStore kvStore = new RocksdbClusterMetaKVStore<>(); + StoreContext storeContext = new StoreContext("cluster_meta_test"); + storeContext.withConfig(config); + storeContext.withKeySerializer(new DefaultKVSerializer(null, null)); + kvStore.init(storeContext); - kvStore.put("key1", "value1"); - kvStore.put("key2", "value2"); - kvStore.flush(); - Assert.assertEquals(kvStore.get("key1"), "value1"); - Assert.assertEquals(kvStore.get("key2"), "value2"); + kvStore.put("key1", "value1"); + kvStore.put("key2", "value2"); + kvStore.flush(); + Assert.assertEquals(kvStore.get("key1"), "value1"); + Assert.assertEquals(kvStore.get("key2"), "value2"); - kvStore.put("key1", "value1"); - kvStore.put("key3", "value3"); - kvStore.flush(); - Assert.assertEquals(kvStore.get("key1"), "value1"); - Assert.assertEquals(kvStore.get("key2"), "value2"); - Assert.assertEquals(kvStore.get("key3"), "value3"); + kvStore.put("key1", "value1"); + kvStore.put("key3", "value3"); + kvStore.flush(); + Assert.assertEquals(kvStore.get("key1"), "value1"); + Assert.assertEquals(kvStore.get("key2"), "value2"); + Assert.assertEquals(kvStore.get("key3"), "value3"); - kvStore.remove("key1"); - kvStore.remove("key5"); - kvStore.put("key4", "value4"); - kvStore.flush(); - Assert.assertEquals(kvStore.get("key1"), null); - Assert.assertEquals(kvStore.get("key2"), "value2"); - Assert.assertEquals(kvStore.get("key3"), "value3"); - Assert.assertEquals(kvStore.get("key4"), "value4"); + kvStore.remove("key1"); + kvStore.remove("key5"); + kvStore.put("key4", "value4"); + kvStore.flush(); + Assert.assertEquals(kvStore.get("key1"), null); + Assert.assertEquals(kvStore.get("key2"), "value2"); + Assert.assertEquals(kvStore.get("key3"), "value3"); + Assert.assertEquals(kvStore.get("key4"), "value4"); - kvStore.close(); - } + kvStore.close(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/util/SystemExitSignalCatcher.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/util/SystemExitSignalCatcher.java index bf9c9ee2b..60c410fa4 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/util/SystemExitSignalCatcher.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/util/SystemExitSignalCatcher.java @@ -21,28 +21,27 @@ import java.security.Permission; import java.util.concurrent.atomic.AtomicBoolean; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class SystemExitSignalCatcher extends SecurityManager { - private AtomicBoolean hasSignal; + private AtomicBoolean hasSignal; - public SystemExitSignalCatcher(AtomicBoolean hasSignal) { - this.hasSignal = hasSignal; - } + public SystemExitSignalCatcher(AtomicBoolean hasSignal) { + this.hasSignal = hasSignal; + } - @Override - public void checkPermission(Permission perm) { - } + @Override + public void checkPermission(Permission perm) {} - @Override - public void checkPermission(Permission perm, Object context) { - } + @Override + public void checkPermission(Permission perm, Object context) {} - @Override - public void checkExit(int status) { - super.checkExit(status); - hasSignal.set(true); - throw new GeaflowRuntimeException("throw exception instead of exit process"); - } + @Override + public void checkExit(int status) { + super.checkExit(status); + hasSignal.set(true); + throw new GeaflowRuntimeException("throw exception instead of exit process"); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/web/HttpServerTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/web/HttpServerTest.java index a4904350c..466a61cfa 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/web/HttpServerTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/web/HttpServerTest.java @@ -24,10 +24,7 @@ import java.net.URI; import java.util.Collections; import java.util.concurrent.CountDownLatch; -import okhttp3.OkHttpClient; -import okhttp3.Request; -import okhttp3.Response; -import okhttp3.ResponseBody; + import org.apache.geaflow.cluster.clustermanager.AbstractClusterManager; import org.apache.geaflow.cluster.common.ComponentInfo; import org.apache.geaflow.cluster.heartbeat.HeartbeatManager; @@ -39,52 +36,58 @@ import org.testng.Assert; import org.testng.annotations.Test; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; +import okhttp3.ResponseBody; + public class HttpServerTest { - @Test - public void test() throws Exception { - Configuration configuration = new Configuration(); - configuration.put(HA_SERVICE_TYPE, HAServiceType.memory.name()); - AbstractClusterManager clusterManager = Mockito.mock(AbstractClusterManager.class); - HeartbeatManager heartbeatManager = new HeartbeatManager(configuration, clusterManager); - IResourceManager resourceManager = new DefaultResourceManager(clusterManager); - HttpServer httpServer = new HttpServer(configuration, clusterManager, heartbeatManager, - resourceManager, new ComponentInfo()); - CountDownLatch latch = new CountDownLatch(1); + @Test + public void test() throws Exception { + Configuration configuration = new Configuration(); + configuration.put(HA_SERVICE_TYPE, HAServiceType.memory.name()); + AbstractClusterManager clusterManager = Mockito.mock(AbstractClusterManager.class); + HeartbeatManager heartbeatManager = new HeartbeatManager(configuration, clusterManager); + IResourceManager resourceManager = new DefaultResourceManager(clusterManager); + HttpServer httpServer = + new HttpServer( + configuration, clusterManager, heartbeatManager, resourceManager, new ComponentInfo()); + CountDownLatch latch = new CountDownLatch(1); - new Thread(new Runnable() { - @Override - public void run() { + new Thread( + new Runnable() { + @Override + public void run() { httpServer.start(); latch.countDown(); - } - }).start(); + } + }) + .start(); - Mockito.when(clusterManager.getContainerInfos()).thenReturn(Collections.EMPTY_MAP); + Mockito.when(clusterManager.getContainerInfos()).thenReturn(Collections.EMPTY_MAP); - latch.await(); - doGet("http://localhost:8090/", "rest/overview"); - doGet("http://localhost:8090/", "rest/containers"); - doGet("http://localhost:8090/", "rest/drivers"); - doGet("http://localhost:8090/", "rest/master/configuration"); - doGet("http://localhost:8090/", "rest/pipelines"); - doGet("http://localhost:8090/", "rest/pipelines/1/cycles"); - httpServer.stop(); - } + latch.await(); + doGet("http://localhost:8090/", "rest/overview"); + doGet("http://localhost:8090/", "rest/containers"); + doGet("http://localhost:8090/", "rest/drivers"); + doGet("http://localhost:8090/", "rest/master/configuration"); + doGet("http://localhost:8090/", "rest/pipelines"); + doGet("http://localhost:8090/", "rest/pipelines/1/cycles"); + httpServer.stop(); + } - private void doGet(String url, String path) throws Exception { - URI uri = new URI(url); - String fullUrl = uri.resolve(path).toString(); - Request request = new Request.Builder().url(fullUrl) - .get().build(); + private void doGet(String url, String path) throws Exception { + URI uri = new URI(url); + String fullUrl = uri.resolve(path).toString(); + Request request = new Request.Builder().url(fullUrl).get().build(); - OkHttpClient client = new OkHttpClient(); - try (Response response = client.newCall(request).execute()) { - ResponseBody responseBody = response.body(); - Assert.assertTrue(response.isSuccessful()); - Assert.assertNotNull(responseBody); - Assert.assertNotNull(responseBody.string()); - } + OkHttpClient client = new OkHttpClient(); + try (Response response = client.newCall(request).execute()) { + ResponseBody responseBody = response.body(); + Assert.assertTrue(response.isSuccessful()); + Assert.assertNotNull(responseBody); + Assert.assertNotNull(responseBody.string()); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/web/agent/AgentWebServerTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/web/agent/AgentWebServerTest.java index 9d8ab3f81..bebf36600 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/web/agent/AgentWebServerTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/web/agent/AgentWebServerTest.java @@ -19,120 +19,128 @@ package org.apache.geaflow.cluster.web.agent; -import com.alibaba.fastjson.JSON; import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.net.URI; import java.util.concurrent.CountDownLatch; -import okhttp3.OkHttpClient; -import okhttp3.Request; -import okhttp3.Response; -import okhttp3.ResponseBody; + import org.apache.geaflow.cluster.web.api.ApiResponse; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; import org.testng.Assert; +import com.alibaba.fastjson.JSON; + +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; +import okhttp3.ResponseBody; + public class AgentWebServerTest { - private static final int AGENT_PORT = 8088; + private static final int AGENT_PORT = 8088; - private static final String LOG_DIR = "/tmp/geaflow/test/logs/geaflow"; + private static final String LOG_DIR = "/tmp/geaflow/test/logs/geaflow"; - private static final String CLIENT_LOG_PATH = LOG_DIR + File.separator + "client.log"; + private static final String CLIENT_LOG_PATH = LOG_DIR + File.separator + "client.log"; - private static final String FLAME_GRAPH_PROFILER_PATH = "/tmp/async-profiler/profiler.sh"; + private static final String FLAME_GRAPH_PROFILER_PATH = "/tmp/async-profiler/profiler.sh"; - private static final String FLAME_GRAPH_PROFILER_FILENAME_EXTENSION = ".html"; + private static final String FLAME_GRAPH_PROFILER_FILENAME_EXTENSION = ".html"; - private static final String AGENT_DIR = "/tmp/agent"; + private static final String AGENT_DIR = "/tmp/agent"; - private static final String THREAD_DUMP_LOG_FILE_PATH = "/tmp/agent/geaflow-thread-dump.log"; + private static final String THREAD_DUMP_LOG_FILE_PATH = "/tmp/agent/geaflow-thread-dump.log"; - @Test - public void testServer() throws Exception { - AgentWebServer httpServer = new AgentWebServer(AGENT_PORT, LOG_DIR, - FLAME_GRAPH_PROFILER_PATH, FLAME_GRAPH_PROFILER_FILENAME_EXTENSION, AGENT_DIR); - CountDownLatch latch = new CountDownLatch(1); - new Thread(new Runnable() { - @Override - public void run() { + @Test + public void testServer() throws Exception { + AgentWebServer httpServer = + new AgentWebServer( + AGENT_PORT, + LOG_DIR, + FLAME_GRAPH_PROFILER_PATH, + FLAME_GRAPH_PROFILER_FILENAME_EXTENSION, + AGENT_DIR); + CountDownLatch latch = new CountDownLatch(1); + new Thread( + new Runnable() { + @Override + public void run() { httpServer.start(); latch.countDown(); - } - }).start(); - - latch.await(); - doGet("http://localhost:8088/", "rest/logs"); - doGet("http://localhost:8088/", - "rest/logs/content?path=/tmp/geaflow/test/logs/geaflow/client.log&pageNo=1&pageSize=1024"); - doGet("http://localhost:8088/", - "rest/thread-dump/content?pageNo=1&pageSize=1024"); - httpServer.stop(); + } + }) + .start(); + + latch.await(); + doGet("http://localhost:8088/", "rest/logs"); + doGet( + "http://localhost:8088/", + "rest/logs/content?path=/tmp/geaflow/test/logs/geaflow/client.log&pageNo=1&pageSize=1024"); + doGet("http://localhost:8088/", "rest/thread-dump/content?pageNo=1&pageSize=1024"); + httpServer.stop(); + } + + @BeforeClass + public static void init() { + initLogFiles(); + } + + @AfterClass + public static void afterClass() { + cleanup(); + } + + private static void initLogFiles() { + initLogFile(CLIENT_LOG_PATH); + initLogFile(THREAD_DUMP_LOG_FILE_PATH); + } + + private static void initLogFile(String filePath) { + try { + File file = new File(filePath); + File fileParent = file.getParentFile(); + if (!fileParent.exists()) { + fileParent.mkdirs(); + } + if (!file.exists()) { + file.createNewFile(); + } + FileOutputStream fos = new FileOutputStream(file); + for (int i = 0; i < 10240; i++) { + fos.write(("Mock log " + i + "\n").getBytes()); + } + + } catch (IOException e) { + throw new RuntimeException(e); } + } - @BeforeClass - public static void init() { - initLogFiles(); - } - - @AfterClass - public static void afterClass() { - cleanup(); - } - - private static void initLogFiles() { - initLogFile(CLIENT_LOG_PATH); - initLogFile(THREAD_DUMP_LOG_FILE_PATH); - } - - private static void initLogFile(String filePath) { - try { - File file = new File(filePath); - File fileParent = file.getParentFile(); - if (!fileParent.exists()) { - fileParent.mkdirs(); - } - if (!file.exists()) { - file.createNewFile(); - } - FileOutputStream fos = new FileOutputStream(file); - for (int i = 0; i < 10240; i++) { - fos.write(("Mock log " + i + "\n").getBytes()); - } - - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - private static void cleanup() { - cleanLogFile(CLIENT_LOG_PATH); - cleanLogFile(THREAD_DUMP_LOG_FILE_PATH); - } + private static void cleanup() { + cleanLogFile(CLIENT_LOG_PATH); + cleanLogFile(THREAD_DUMP_LOG_FILE_PATH); + } - private static void cleanLogFile(String filePath) { - File file = new File(filePath); - if (file.exists() && file.isFile()) { - file.delete(); - } + private static void cleanLogFile(String filePath) { + File file = new File(filePath); + if (file.exists() && file.isFile()) { + file.delete(); } - - private void doGet(String url, String path) throws Exception { - URI uri = new URI(url); - String fullUrl = uri.resolve(path).toString(); - Request request = new Request.Builder().url(fullUrl).get().build(); - - OkHttpClient client = new OkHttpClient(); - try (Response response = client.newCall(request).execute()) { - ResponseBody responseBody = response.body(); - Assert.assertEquals(response.code(), 200); - ApiResponse apiResponse = JSON.parseObject(responseBody.string(), ApiResponse.class); - Assert.assertTrue(apiResponse.isSuccess()); - } + } + + private void doGet(String url, String path) throws Exception { + URI uri = new URI(url); + String fullUrl = uri.resolve(path).toString(); + Request request = new Request.Builder().url(fullUrl).get().build(); + + OkHttpClient client = new OkHttpClient(); + try (Response response = client.newCall(request).execute()) { + ResponseBody responseBody = response.body(); + Assert.assertEquals(response.code(), 200); + ApiResponse apiResponse = JSON.parseObject(responseBody.string(), ApiResponse.class); + Assert.assertTrue(apiResponse.isSuccess()); } - - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/web/metrics/MetricCacheTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/web/metrics/MetricCacheTest.java index a6b2cdad2..e965c8f0c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/web/metrics/MetricCacheTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/test/java/org/apache/geaflow/cluster/web/metrics/MetricCacheTest.java @@ -28,17 +28,16 @@ public class MetricCacheTest { - @Test - public void test() { - PipelineMetrics pipelineMetrics = new PipelineMetrics("1"); - MetricCache metricCache = new MetricCache(); - metricCache.addPipelineMetrics(pipelineMetrics); - CycleMetrics cycleMetrics = new CycleMetrics("1", "1", ""); - metricCache.addCycleMetrics(cycleMetrics); - byte[] bytes = SerializerFactory.getKryoSerializer().serialize(metricCache); - MetricCache newCache = (MetricCache) SerializerFactory.getKryoSerializer().deserialize(bytes); - Assert.assertNotNull(newCache); - Assert.assertEquals(newCache.getPipelineMetricCaches().size(), 1); - } - + @Test + public void test() { + PipelineMetrics pipelineMetrics = new PipelineMetrics("1"); + MetricCache metricCache = new MetricCache(); + metricCache.addPipelineMetrics(pipelineMetrics); + CycleMetrics cycleMetrics = new CycleMetrics("1", "1", ""); + metricCache.addCycleMetrics(cycleMetrics); + byte[] bytes = SerializerFactory.getKryoSerializer().serialize(metricCache); + MetricCache newCache = (MetricCache) SerializerFactory.getKryoSerializer().deserialize(bytes); + Assert.assertNotNull(newCache); + Assert.assertEquals(newCache.getPipelineMetricCaches().size(), 1); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/ILeaderContender.java b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/ILeaderContender.java index f86f856ab..f63f5b0db 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/ILeaderContender.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/ILeaderContender.java @@ -21,10 +21,9 @@ public interface ILeaderContender { - void handleLeadershipGranted(); + void handleLeadershipGranted(); - void handleLeadershipLost(); - - LeaderContenderType getType(); + void handleLeadershipLost(); + LeaderContenderType getType(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/ILeaderElectionDriver.java b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/ILeaderElectionDriver.java index e3d3b9862..95a043812 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/ILeaderElectionDriver.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/ILeaderElectionDriver.java @@ -23,8 +23,7 @@ public interface ILeaderElectionDriver { - void open(ILeaderContender contender, Configuration configuration); - - void close(); + void open(ILeaderContender contender, Configuration configuration); + void close(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/ILeaderElectionEventListener.java b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/ILeaderElectionEventListener.java index bcc0638b8..a0ea3a645 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/ILeaderElectionEventListener.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/ILeaderElectionEventListener.java @@ -21,10 +21,9 @@ public interface ILeaderElectionEventListener { - void handleLeadershipGranted(); + void handleLeadershipGranted(); - void handleLeadershipLost(); - - void handleNewLeadership(String newLeader); + void handleLeadershipLost(); + void handleNewLeadership(String newLeader); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/ILeaderElectionService.java b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/ILeaderElectionService.java index a6c752022..efc177849 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/ILeaderElectionService.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/ILeaderElectionService.java @@ -23,14 +23,13 @@ public interface ILeaderElectionService { - void init(Configuration configuration, String componentId); + void init(Configuration configuration, String componentId); - void open(ILeaderContender contender); + void open(ILeaderContender contender); - void close(); + void close(); - boolean isLeader(); - - LeaderElectionServiceType getType(); + boolean isLeader(); + LeaderElectionServiceType getType(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/LeaderContenderType.java b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/LeaderContenderType.java index c8b652e59..9c5038fc7 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/LeaderContenderType.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/LeaderContenderType.java @@ -20,10 +20,9 @@ package org.apache.geaflow.ha.leaderelection; public enum LeaderContenderType { + supervisor, - supervisor, + master, - master, - - driver + driver } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/LeaderElectionServiceFactory.java b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/LeaderElectionServiceFactory.java index b72a9109e..8a3db618c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/LeaderElectionServiceFactory.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/LeaderElectionServiceFactory.java @@ -21,6 +21,7 @@ import java.util.Iterator; import java.util.ServiceLoader; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.common.errorcode.RuntimeErrors; @@ -30,23 +31,27 @@ public class LeaderElectionServiceFactory { - private static final Logger LOGGER = LoggerFactory.getLogger(LeaderElectionServiceFactory.class); + private static final Logger LOGGER = LoggerFactory.getLogger(LeaderElectionServiceFactory.class); - public static ILeaderElectionService loadElectionService(Configuration configuration) { - String leaderElectionServiceType = configuration.getString(ExecutionConfigKeys.LEADER_ELECTION_TYPE); - ServiceLoader contextLoader = ServiceLoader.load(ILeaderElectionService.class); - Iterator contextIterable = contextLoader.iterator(); - while (contextIterable.hasNext()) { - ILeaderElectionService service = contextIterable.next(); - if (service.getType().name().equalsIgnoreCase(leaderElectionServiceType)) { - LOGGER.info("load ILeaderElectionService implementation {} with type {}", - service.getClass().getName(), leaderElectionServiceType); - return service; - } - } - LOGGER.error("Not found ILeaderElectionService implementation with type {}", leaderElectionServiceType); - throw new GeaflowRuntimeException( - RuntimeErrors.INST.spiNotFoundError(ILeaderElectionService.class.getSimpleName())); + public static ILeaderElectionService loadElectionService(Configuration configuration) { + String leaderElectionServiceType = + configuration.getString(ExecutionConfigKeys.LEADER_ELECTION_TYPE); + ServiceLoader contextLoader = + ServiceLoader.load(ILeaderElectionService.class); + Iterator contextIterable = contextLoader.iterator(); + while (contextIterable.hasNext()) { + ILeaderElectionService service = contextIterable.next(); + if (service.getType().name().equalsIgnoreCase(leaderElectionServiceType)) { + LOGGER.info( + "load ILeaderElectionService implementation {} with type {}", + service.getClass().getName(), + leaderElectionServiceType); + return service; + } } - + LOGGER.error( + "Not found ILeaderElectionService implementation with type {}", leaderElectionServiceType); + throw new GeaflowRuntimeException( + RuntimeErrors.INST.spiNotFoundError(ILeaderElectionService.class.getSimpleName())); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/LeaderElectionServiceType.java b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/LeaderElectionServiceType.java index 617dcb0b1..a080650a8 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/LeaderElectionServiceType.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/leaderelection/LeaderElectionServiceType.java @@ -21,9 +21,6 @@ public enum LeaderElectionServiceType { - /** - * Kubernetes native leader election service. - */ - kubernetes - + /** Kubernetes native leader election service. */ + kubernetes } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/runtime/HighAvailableLevel.java b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/runtime/HighAvailableLevel.java index c5a207adf..f278acc67 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/runtime/HighAvailableLevel.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/runtime/HighAvailableLevel.java @@ -20,12 +20,8 @@ package org.apache.geaflow.ha.runtime; public enum HighAvailableLevel { - /** - * Redo level. - */ - REDO, - /** - * Checkpoint level. - */ - CHECKPOINT + /** Redo level. */ + REDO, + /** Checkpoint level. */ + CHECKPOINT } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/AbstractHAService.java b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/AbstractHAService.java index a84b789a0..f91a1ae0c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/AbstractHAService.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/AbstractHAService.java @@ -23,6 +23,7 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -34,105 +35,106 @@ public abstract class AbstractHAService implements IHAService { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractHAService.class); - protected static final String TABLE_PREFIX = "WORKERS_"; - protected static final int LOAD_INTERVAL_MS = 200; - - protected int connectTimeout; - protected int recoverTimeout; - protected Map resourceDataCache; - protected IKVStore kvStore; - - public AbstractHAService() { - this.resourceDataCache = new ConcurrentHashMap<>(); + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractHAService.class); + protected static final String TABLE_PREFIX = "WORKERS_"; + protected static final int LOAD_INTERVAL_MS = 200; + + protected int connectTimeout; + protected int recoverTimeout; + protected Map resourceDataCache; + protected IKVStore kvStore; + + public AbstractHAService() { + this.resourceDataCache = new ConcurrentHashMap<>(); + } + + @Override + public void open(Configuration configuration) { + this.recoverTimeout = configuration.getInteger(ExecutionConfigKeys.FO_TIMEOUT_MS); + this.connectTimeout = configuration.getInteger(ExecutionConfigKeys.RPC_CONNECT_TIMEOUT_MS); + } + + @Override + public void register(String resourceId, ResourceData resourceData) { + if (kvStore != null) { + kvStore.put(resourceId, resourceData); } - - @Override - public void open(Configuration configuration) { - this.recoverTimeout = configuration.getInteger(ExecutionConfigKeys.FO_TIMEOUT_MS); - this.connectTimeout = configuration.getInteger(ExecutionConfigKeys.RPC_CONNECT_TIMEOUT_MS); - } - - @Override - public void register(String resourceId, ResourceData resourceData) { - if (kvStore != null) { - kvStore.put(resourceId, resourceData); - } - } - - @Override - public ResourceData resolveResource(String resourceId) { - return resourceDataCache.computeIfAbsent(resourceId, key -> loadDataFromStore(key, true)); - } - - @Override - public ResourceData loadResource(String resourceId) { - return resourceDataCache.computeIfAbsent(resourceId, key -> loadDataFromStore(key, false)); + } + + @Override + public ResourceData resolveResource(String resourceId) { + return resourceDataCache.computeIfAbsent(resourceId, key -> loadDataFromStore(key, true)); + } + + @Override + public ResourceData loadResource(String resourceId) { + return resourceDataCache.computeIfAbsent(resourceId, key -> loadDataFromStore(key, false)); + } + + @Override + public ResourceData invalidateResource(String resourceId) { + return resourceDataCache.remove(resourceId); + } + + @Override + public void close() { + if (kvStore != null) { + kvStore.close(); } + } - @Override - public ResourceData invalidateResource(String resourceId) { - return resourceDataCache.remove(resourceId); + protected ResourceData getResourceData(String resourceId) { + if (kvStore != null) { + return kvStore.get(resourceId); } - - @Override - public void close() { - if (kvStore != null) { - kvStore.close(); + return null; + } + + private ResourceData loadDataFromStore(String resourceId, boolean resolve) { + return loadDataFromStore(resourceId, resolve, recoverTimeout, ResourceData::getRpcPort); + } + + public ResourceData loadDataFromStore( + String resourceId, boolean resolve, Function portFunc) { + return loadDataFromStore(resourceId, resolve, recoverTimeout, portFunc); + } + + private ResourceData loadDataFromStore( + String resourceId, boolean resolve, int timeoutMs, Function portFunc) { + long currentTime = System.currentTimeMillis(); + long startTime = currentTime; + long checkTime = currentTime; + Throwable throwable = null; + ResourceData resourceData; + do { + currentTime = System.currentTimeMillis(); + if (currentTime - checkTime > 2000) { + long elapsedTime = currentTime - startTime; + checkTime = currentTime; + if (elapsedTime > timeoutMs) { + String reason = throwable != null ? throwable.getMessage() : null; + String msg = + String.format( + "load resource %s timeout after %sms, reason:%s", + resourceId, elapsedTime, reason); + LOGGER.error(msg); + throw new GeaflowRuntimeException(msg); } - } - - protected ResourceData getResourceData(String resourceId) { - if (kvStore != null) { - return kvStore.get(resourceId); + SleepUtils.sleepMilliSecond(LOAD_INTERVAL_MS); + } + resourceData = getResourceData(resourceId); + if (resourceData != null) { + try { + if (resolve) { + int port = portFunc.apply(resourceData); + NetworkUtil.checkServiceAvailable(resourceData.getHost(), port, connectTimeout); + } + break; + } catch (IOException ex) { + throwable = ex; } - return null; - } - - private ResourceData loadDataFromStore(String resourceId, boolean resolve) { - return loadDataFromStore(resourceId, resolve, recoverTimeout, ResourceData::getRpcPort); - } - - public ResourceData loadDataFromStore(String resourceId, boolean resolve, - Function portFunc) { - return loadDataFromStore(resourceId, resolve, recoverTimeout, portFunc); - } - - private ResourceData loadDataFromStore(String resourceId, boolean resolve, int timeoutMs, - Function portFunc) { - long currentTime = System.currentTimeMillis(); - long startTime = currentTime; - long checkTime = currentTime; - Throwable throwable = null; - ResourceData resourceData; - do { - currentTime = System.currentTimeMillis(); - if (currentTime - checkTime > 2000) { - long elapsedTime = currentTime - startTime; - checkTime = currentTime; - if (elapsedTime > timeoutMs) { - String reason = throwable != null ? throwable.getMessage() : null; - String msg = String.format("load resource %s timeout after %sms, reason:%s", - resourceId, elapsedTime, reason); - LOGGER.error(msg); - throw new GeaflowRuntimeException(msg); - } - SleepUtils.sleepMilliSecond(LOAD_INTERVAL_MS); - } - resourceData = getResourceData(resourceId); - if (resourceData != null) { - try { - if (resolve) { - int port = portFunc.apply(resourceData); - NetworkUtil.checkServiceAvailable(resourceData.getHost(), port, - connectTimeout); - } - break; - } catch (IOException ex) { - throwable = ex; - } - } - } while (true); - return resourceData; - } + } + } while (true); + return resourceData; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/HAServiceFactory.java b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/HAServiceFactory.java index 8a95d9cdb..ce3a18090 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/HAServiceFactory.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/HAServiceFactory.java @@ -27,39 +27,38 @@ import org.slf4j.LoggerFactory; public class HAServiceFactory { - private static final Logger LOGGER = LoggerFactory.getLogger(HAServiceFactory.class); - private static IHAService haService; + private static final Logger LOGGER = LoggerFactory.getLogger(HAServiceFactory.class); + private static IHAService haService; - public static synchronized IHAService getService(Configuration configuration) { - if (haService == null) { - String serviceType = configuration.getString(ExecutionConfigKeys.HA_SERVICE_TYPE); - if (StringUtils.isEmpty(serviceType)) { - if (configuration.getBoolean(ExecutionConfigKeys.RUN_LOCAL_MODE)) { - serviceType = HAServiceType.memory.name(); - } else { - serviceType = HAServiceType.redis.name(); - } - } - haService = createHAService(serviceType); - haService.open(configuration); + public static synchronized IHAService getService(Configuration configuration) { + if (haService == null) { + String serviceType = configuration.getString(ExecutionConfigKeys.HA_SERVICE_TYPE); + if (StringUtils.isEmpty(serviceType)) { + if (configuration.getBoolean(ExecutionConfigKeys.RUN_LOCAL_MODE)) { + serviceType = HAServiceType.memory.name(); + } else { + serviceType = HAServiceType.redis.name(); } - return haService; + } + haService = createHAService(serviceType); + haService.open(configuration); } + return haService; + } - public static IHAService getService() { - if (haService == null) { - throw new GeaflowRuntimeException("HAService not initialized"); - } - return haService; + public static IHAService getService() { + if (haService == null) { + throw new GeaflowRuntimeException("HAService not initialized"); } + return haService; + } - private static IHAService createHAService(String serviceType) { - if (serviceType.equalsIgnoreCase(HAServiceType.redis.name())) { - return new RedisHAService(); - } else { - LOGGER.warn("unknown ha service type:{}, use default memoryHaService", serviceType); - return new MemoryHAService(); - } + private static IHAService createHAService(String serviceType) { + if (serviceType.equalsIgnoreCase(HAServiceType.redis.name())) { + return new RedisHAService(); + } else { + LOGGER.warn("unknown ha service type:{}, use default memoryHaService", serviceType); + return new MemoryHAService(); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/HAServiceType.java b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/HAServiceType.java index a8179d88a..6c0885250 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/HAServiceType.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/HAServiceType.java @@ -21,19 +21,12 @@ public enum HAServiceType { - /** - * Service based on redis. - */ - redis, + /** Service based on redis. */ + redis, - /** - * Service based on hbase. - */ - hbase, - - /** - * Service based on memory. - */ - memory + /** Service based on hbase. */ + hbase, + /** Service based on memory. */ + memory } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/IHAService.java b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/IHAService.java index 5d4d8cca2..969c56e44 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/IHAService.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/IHAService.java @@ -20,38 +20,26 @@ package org.apache.geaflow.ha.service; import java.io.Serializable; + import org.apache.geaflow.common.config.Configuration; public interface IHAService extends Serializable { - /** - * HA service init. - */ - void open(Configuration configuration); - - /** - * Register resource data. - */ - void register(String resourceId, ResourceData resourceData); + /** HA service init. */ + void open(Configuration configuration); - /** - * Load and resolve resource by resource id. - */ - ResourceData resolveResource(String resourceId); + /** Register resource data. */ + void register(String resourceId, ResourceData resourceData); - /** - * Load resource data for corresponding resource id. - */ - ResourceData loadResource(String resourceId); + /** Load and resolve resource by resource id. */ + ResourceData resolveResource(String resourceId); - /** - * Invalidate resource data for corresponding resource id. - */ - ResourceData invalidateResource(String resourceId); + /** Load resource data for corresponding resource id. */ + ResourceData loadResource(String resourceId); - /** - * HA service close. - */ - void close(); + /** Invalidate resource data for corresponding resource id. */ + ResourceData invalidateResource(String resourceId); + /** HA service close. */ + void close(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/MemoryHAService.java b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/MemoryHAService.java index 5e52e1f49..23e8f2433 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/MemoryHAService.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/MemoryHAService.java @@ -23,19 +23,18 @@ public class MemoryHAService extends AbstractHAService { - @Override - public void open(Configuration configuration) { - super.open(configuration); - } + @Override + public void open(Configuration configuration) { + super.open(configuration); + } - @Override - public void register(String resourceId, ResourceData resourceData) { - resourceDataCache.put(resourceId, resourceData); - } - - @Override - public ResourceData resolveResource(String resourceId) { - return resourceDataCache.get(resourceId); - } + @Override + public void register(String resourceId, ResourceData resourceData) { + resourceDataCache.put(resourceId, resourceData); + } + @Override + public ResourceData resolveResource(String resourceId) { + return resourceDataCache.get(resourceId); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/RedisHAService.java b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/RedisHAService.java index 854f57fff..3518bd4a3 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/RedisHAService.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/RedisHAService.java @@ -31,31 +31,29 @@ public class RedisHAService extends AbstractHAService { - @Override - public void open(Configuration configuration) { - super.open(configuration); - IStoreBuilder builder = StoreBuilderFactory.build(StoreType.REDIS.name()); - this.kvStore = (KVRedisStore) builder.getStore(DataModel.KV, configuration); - String namespace = - configuration.getString(ExecutionConfigKeys.JOB_UNIQUE_ID) + TABLE_PREFIX; - StoreContext storeContext = new StoreContext(namespace); - storeContext.withKeySerializer(new DefaultKVSerializer(String.class, null)); - storeContext.withConfig(configuration); - this.kvStore.init(storeContext); - } + @Override + public void open(Configuration configuration) { + super.open(configuration); + IStoreBuilder builder = StoreBuilderFactory.build(StoreType.REDIS.name()); + this.kvStore = (KVRedisStore) builder.getStore(DataModel.KV, configuration); + String namespace = configuration.getString(ExecutionConfigKeys.JOB_UNIQUE_ID) + TABLE_PREFIX; + StoreContext storeContext = new StoreContext(namespace); + storeContext.withKeySerializer(new DefaultKVSerializer(String.class, null)); + storeContext.withConfig(configuration); + this.kvStore.init(storeContext); + } - @Override - public void register(String resourceId, ResourceData resourceData) { - synchronized (kvStore) { - kvStore.put(resourceId, resourceData); - } + @Override + public void register(String resourceId, ResourceData resourceData) { + synchronized (kvStore) { + kvStore.put(resourceId, resourceData); } + } - @Override - protected ResourceData getResourceData(String resourceId) { - synchronized (kvStore) { - return kvStore.get(resourceId); - } + @Override + protected ResourceData getResourceData(String resourceId) { + synchronized (kvStore) { + return kvStore.get(resourceId); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/ResourceData.java b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/ResourceData.java index af6173eb1..789e97c86 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/ResourceData.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/main/java/org/apache/geaflow/ha/service/ResourceData.java @@ -22,78 +22,85 @@ import java.io.Serializable; public class ResourceData implements Serializable { - private String host; - private int processId; - /** - * rpc service port. - */ - - private int rpcPort; - /** - * shuffle service port. - */ - private int shufflePort; - /** - * shuffle service port. - */ - private int metricPort; - /** - * worker rpc porker. - */ - private int supervisorPort; - - public String getHost() { - return host; - } - - public void setHost(String host) { - this.host = host; - } - - public int getProcessId() { - return processId; - } - - public void setProcessId(int processId) { - this.processId = processId; - } - - public int getRpcPort() { - return rpcPort; - } - - public void setRpcPort(int rpcPort) { - this.rpcPort = rpcPort; - } - - public int getShufflePort() { - return shufflePort; - } - - public void setShufflePort(int shufflePort) { - this.shufflePort = shufflePort; - } - - public int getMetricPort() { - return metricPort; - } - - public void setMetricPort(int metricPort) { - this.metricPort = metricPort; - } - - public int getSupervisorPort() { - return supervisorPort; - } - - public void setSupervisorPort(int supervisorPort) { - this.supervisorPort = supervisorPort; - } - - @Override - public String toString() { - return "ResourceData{" + "host='" + host + '\'' + ", processId=" + processId + ", rpcPort=" - + rpcPort + ", shufflePort=" + shufflePort + ", metricPort=" + metricPort - + ", supervisorPort=" + supervisorPort + '}'; - } + private String host; + private int processId; + + /** rpc service port. */ + private int rpcPort; + + /** shuffle service port. */ + private int shufflePort; + + /** shuffle service port. */ + private int metricPort; + + /** worker rpc porker. */ + private int supervisorPort; + + public String getHost() { + return host; + } + + public void setHost(String host) { + this.host = host; + } + + public int getProcessId() { + return processId; + } + + public void setProcessId(int processId) { + this.processId = processId; + } + + public int getRpcPort() { + return rpcPort; + } + + public void setRpcPort(int rpcPort) { + this.rpcPort = rpcPort; + } + + public int getShufflePort() { + return shufflePort; + } + + public void setShufflePort(int shufflePort) { + this.shufflePort = shufflePort; + } + + public int getMetricPort() { + return metricPort; + } + + public void setMetricPort(int metricPort) { + this.metricPort = metricPort; + } + + public int getSupervisorPort() { + return supervisorPort; + } + + public void setSupervisorPort(int supervisorPort) { + this.supervisorPort = supervisorPort; + } + + @Override + public String toString() { + return "ResourceData{" + + "host='" + + host + + '\'' + + ", processId=" + + processId + + ", rpcPort=" + + rpcPort + + ", shufflePort=" + + shufflePort + + ", metricPort=" + + metricPort + + ", supervisorPort=" + + supervisorPort + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/test/java/org/apache/geaflow/ha/service/MemoryHAServiceTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/test/java/org/apache/geaflow/ha/service/MemoryHAServiceTest.java index a93f85e45..e2260f59f 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/test/java/org/apache/geaflow/ha/service/MemoryHAServiceTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/test/java/org/apache/geaflow/ha/service/MemoryHAServiceTest.java @@ -25,17 +25,16 @@ public class MemoryHAServiceTest { - @Test - public void test() { - MemoryHAService memoryHAService = new MemoryHAService(); - memoryHAService.open(new Configuration()); - - ResourceData resourceData = new ResourceData(); - memoryHAService.register("1", resourceData); - Assert.assertEquals(resourceData, memoryHAService.resolveResource("1")); - Assert.assertNull(memoryHAService.resolveResource("2")); - memoryHAService.invalidateResource("1"); - Assert.assertNull(memoryHAService.resolveResource("1")); - } + @Test + public void test() { + MemoryHAService memoryHAService = new MemoryHAService(); + memoryHAService.open(new Configuration()); + ResourceData resourceData = new ResourceData(); + memoryHAService.register("1", resourceData); + Assert.assertEquals(resourceData, memoryHAService.resolveResource("1")); + Assert.assertNull(memoryHAService.resolveResource("2")); + memoryHAService.invalidateResource("1"); + Assert.assertNull(memoryHAService.resolveResource("1")); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/test/java/org/apache/geaflow/ha/service/RedisHAServiceTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/test/java/org/apache/geaflow/ha/service/RedisHAServiceTest.java index 0f124372d..d258f08af 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/test/java/org/apache/geaflow/ha/service/RedisHAServiceTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-highavailability/src/test/java/org/apache/geaflow/ha/service/RedisHAServiceTest.java @@ -24,11 +24,11 @@ import static org.apache.geaflow.store.redis.RedisConfigKeys.REDIS_HOST; import static org.apache.geaflow.store.redis.RedisConfigKeys.REDIS_PORT; -import com.github.fppt.jedismock.RedisServer; import java.io.IOException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.testng.Assert; @@ -36,66 +36,68 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import com.github.fppt.jedismock.RedisServer; + public class RedisHAServiceTest { - private RedisServer redisServer; + private RedisServer redisServer; - @BeforeClass - public void prepare() throws IOException { - redisServer = RedisServer.newRedisServer().start(); - } + @BeforeClass + public void prepare() throws IOException { + redisServer = RedisServer.newRedisServer().start(); + } - @AfterClass - public void tearUp() throws IOException { - redisServer.stop(); - } + @AfterClass + public void tearUp() throws IOException { + redisServer.stop(); + } - @Test(expectedExceptions = GeaflowRuntimeException.class) - public void test() { - RedisHAService haService = new RedisHAService(); - Configuration configuration = new Configuration(); - configuration.put(REDIS_HOST, redisServer.getHost()); - configuration.put(REDIS_PORT, String.valueOf(redisServer.getBindPort())); - configuration.put(JOB_UNIQUE_ID, "123"); - configuration.put(FO_TIMEOUT_MS, "2000"); - haService.open(configuration); + @Test(expectedExceptions = GeaflowRuntimeException.class) + public void test() { + RedisHAService haService = new RedisHAService(); + Configuration configuration = new Configuration(); + configuration.put(REDIS_HOST, redisServer.getHost()); + configuration.put(REDIS_PORT, String.valueOf(redisServer.getBindPort())); + configuration.put(JOB_UNIQUE_ID, "123"); + configuration.put(FO_TIMEOUT_MS, "2000"); + haService.open(configuration); - ResourceData resourceData = new ResourceData(); - resourceData.setHost("127.0.0.1"); - resourceData.setRpcPort(6055); - haService.register("1", resourceData); - ResourceData result = haService.resolveResource("1"); - Assert.assertNotNull(result); - haService.close(); - } + ResourceData resourceData = new ResourceData(); + resourceData.setHost("127.0.0.1"); + resourceData.setRpcPort(6055); + haService.register("1", resourceData); + ResourceData result = haService.resolveResource("1"); + Assert.assertNotNull(result); + haService.close(); + } - @Test - public void testMultiThread() throws InterruptedException { - ExecutorService executorService = Executors.newFixedThreadPool(1); + @Test + public void testMultiThread() throws InterruptedException { + ExecutorService executorService = Executors.newFixedThreadPool(1); - RedisHAService haService = new RedisHAService(); - Configuration configuration = new Configuration(); - configuration.put(REDIS_HOST, redisServer.getHost()); - configuration.put(REDIS_PORT, String.valueOf(redisServer.getBindPort())); - configuration.put(JOB_UNIQUE_ID, "2300087"); - configuration.put(FO_TIMEOUT_MS, "2000"); - haService.open(configuration); + RedisHAService haService = new RedisHAService(); + Configuration configuration = new Configuration(); + configuration.put(REDIS_HOST, redisServer.getHost()); + configuration.put(REDIS_PORT, String.valueOf(redisServer.getBindPort())); + configuration.put(JOB_UNIQUE_ID, "2300087"); + configuration.put(FO_TIMEOUT_MS, "2000"); + haService.open(configuration); - String resourceId = "test-multi-thread"; - CountDownLatch latch = new CountDownLatch(1); - executorService.execute(() -> { - while (true) { - ResourceData resourceData = new ResourceData(); - resourceData.setHost("abc"); - haService.register(resourceId, resourceData); - latch.countDown(); - } + String resourceId = "test-multi-thread"; + CountDownLatch latch = new CountDownLatch(1); + executorService.execute( + () -> { + while (true) { + ResourceData resourceData = new ResourceData(); + resourceData.setHost("abc"); + haService.register(resourceId, resourceData); + latch.countDown(); + } }); - latch.await(); - ResourceData data = haService.getResourceData(resourceId); - Assert.assertNotNull(data); - executorService.shutdown(); - } - + latch.await(); + ResourceData data = haService.getResourceData(resourceId); + Assert.assertNotNull(data); + executorService.shutdown(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/reader/IReaderContext.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/reader/IReaderContext.java index 4a077e6c5..2d92d556c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/reader/IReaderContext.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/reader/IReaderContext.java @@ -20,15 +20,15 @@ package org.apache.geaflow.shuffle.api.reader; import java.io.Serializable; + import org.apache.geaflow.common.config.Configuration; public interface IReaderContext extends Serializable { - /** - * Get the configuration. - * - * @return configuration. - */ - Configuration getConfig(); - + /** + * Get the configuration. + * + * @return configuration. + */ + Configuration getConfig(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/reader/IShuffleReader.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/reader/IShuffleReader.java index e7333a342..570c4f959 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/reader/IShuffleReader.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/reader/IShuffleReader.java @@ -20,49 +20,47 @@ package org.apache.geaflow.shuffle.api.reader; import java.io.Serializable; + import org.apache.geaflow.common.metric.ShuffleReadMetrics; import org.apache.geaflow.shuffle.message.PipelineEvent; public interface IShuffleReader extends Serializable { - /** - * Init shuffle reader with reader context. - * - * @param context reader context. - */ - void init(IReaderContext context); - - /** - * Fetch upstream shards. - * - * @param targetWindowId target window id - */ - void fetch(long targetWindowId); + /** + * Init shuffle reader with reader context. + * + * @param context reader context. + */ + void init(IReaderContext context); - /** - * Returns true if the requested batches is not fetched completely. - * - * @return true/false. - */ - boolean hasNext(); + /** + * Fetch upstream shards. + * + * @param targetWindowId target window id + */ + void fetch(long targetWindowId); - /** - * Returns the next batch. - * - * @return batch data or event. - */ - PipelineEvent next(); + /** + * Returns true if the requested batches is not fetched completely. + * + * @return true/false. + */ + boolean hasNext(); - /** - * Get read metrics. - * - * @return read metrics. - */ - ShuffleReadMetrics getShuffleReadMetrics(); + /** + * Returns the next batch. + * + * @return batch data or event. + */ + PipelineEvent next(); - /** - * Close. - */ - void close(); + /** + * Get read metrics. + * + * @return read metrics. + */ + ShuffleReadMetrics getShuffleReadMetrics(); + /** Close. */ + void close(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/reader/PipelineReader.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/reader/PipelineReader.java index bef1b3bcc..368a34a36 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/reader/PipelineReader.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/reader/PipelineReader.java @@ -19,7 +19,6 @@ package org.apache.geaflow.shuffle.api.reader; -import com.google.common.base.Preconditions; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; @@ -27,6 +26,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; + import org.apache.geaflow.common.encoder.IEncoder; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.metric.ShuffleReadMetrics; @@ -51,190 +51,195 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class PipelineReader implements IShuffleReader { - - private static final Logger LOGGER = LoggerFactory.getLogger(PipelineReader.class); - - private final IConnectionManager connectionManager; - private ReaderContext readerContext; - private Map> encoders; - private String taskName; - private ShuffleReadMetrics readMetrics; - - private int channels; - private int totalSliceNum; - private int processedNum; - private long targetWindowId; +import com.google.common.base.Preconditions; - private ShardFetcher inputFetcher; - private volatile boolean isRunning; +public class PipelineReader implements IShuffleReader { - public PipelineReader(IConnectionManager connectionManager) { - this.connectionManager = connectionManager; + private static final Logger LOGGER = LoggerFactory.getLogger(PipelineReader.class); + + private final IConnectionManager connectionManager; + private ReaderContext readerContext; + private Map> encoders; + private String taskName; + private ShuffleReadMetrics readMetrics; + + private int channels; + private int totalSliceNum; + private int processedNum; + private long targetWindowId; + + private ShardFetcher inputFetcher; + private volatile boolean isRunning; + + public PipelineReader(IConnectionManager connectionManager) { + this.connectionManager = connectionManager; + } + + @Override + public void init(IReaderContext context) { + this.readerContext = (ReaderContext) context; + this.encoders = new HashMap<>(); + for (Map.Entry entry : + this.readerContext.getInputShardMap().entrySet()) { + this.encoders.put(entry.getKey(), entry.getValue().getEncoder()); } - - @Override - public void init(IReaderContext context) { - this.readerContext = (ReaderContext) context; - this.encoders = new HashMap<>(); - for (Map.Entry entry : this.readerContext.getInputShardMap() - .entrySet()) { - this.encoders.put(entry.getKey(), entry.getValue().getEncoder()); - } - this.taskName = this.readerContext.getTaskName(); - this.channels = this.readerContext.getSliceNum(); - this.readMetrics = new ShuffleReadMetrics(); + this.taskName = this.readerContext.getTaskName(); + this.channels = this.readerContext.getSliceNum(); + this.readMetrics = new ShuffleReadMetrics(); + } + + @Override + public void fetch(long windowId) { + Preconditions.checkArgument(windowId > 0, "window should be larger than 0"); + if (windowId <= this.targetWindowId) { + return; } - @Override - public void fetch(long windowId) { - Preconditions.checkArgument(windowId > 0, "window should be larger than 0"); - if (windowId <= this.targetWindowId) { - return; - } - - this.totalSliceNum = this.channels; - this.targetWindowId = windowId; - this.processedNum = 0; - - try { - if (this.inputFetcher == null) { - this.inputFetcher = this.createShardFetcher(this.targetWindowId); - } - this.inputFetcher.requestSlices(this.targetWindowId); - } catch (IOException e) { - LOGGER.error(e.getMessage(), e.getCause()); - throw new GeaflowRuntimeException("fetch error", e); - } - this.isRunning = true; + this.totalSliceNum = this.channels; + this.targetWindowId = windowId; + this.processedNum = 0; + + try { + if (this.inputFetcher == null) { + this.inputFetcher = this.createShardFetcher(this.targetWindowId); + } + this.inputFetcher.requestSlices(this.targetWindowId); + } catch (IOException e) { + LOGGER.error(e.getMessage(), e.getCause()); + throw new GeaflowRuntimeException("fetch error", e); } - - @Override - public boolean hasNext() { - boolean longTerm = this.targetWindowId == Long.MAX_VALUE; - boolean moreAvailable = this.processedNum < this.totalSliceNum; - return longTerm || moreAvailable; + this.isRunning = true; + } + + @Override + public boolean hasNext() { + boolean longTerm = this.targetWindowId == Long.MAX_VALUE; + boolean moreAvailable = this.processedNum < this.totalSliceNum; + return longTerm || moreAvailable; + } + + @Override + public PipelineEvent next() { + if (!this.isRunning) { + return null; } - - @Override - public PipelineEvent next() { - if (!this.isRunning) { - return null; - } - long startTime = System.currentTimeMillis(); - try { - Optional next = this.inputFetcher.getNext(); - if (next.isPresent()) { - PipeFetcherBuffer buffer = next.get(); - if (buffer.isBarrier()) { - if (this.targetWindowId != Long.MAX_VALUE && ( - buffer.getBatchId() == this.targetWindowId || buffer.isFinish())) { - this.processedNum++; - } - - SliceId sliceId = buffer.getSliceId(); - PipelineBarrier barrier = new PipelineBarrier(buffer.getBatchId(), - sliceId.getEdgeId(), sliceId.getShardIndex(), sliceId.getSliceIndex(), - buffer.getBatchCount()); - barrier.setFinish(buffer.isFinish()); - return barrier; - } else { - int edgeId = buffer.getSliceId().getEdgeId(); - this.readMetrics.increaseDecodeBytes(buffer.getBufferSize()); - IMessageIterator msgIterator = this.getMessageIterator(edgeId, - buffer.getBuffer()); - return new PipelineMessage<>(edgeId, buffer.getBatchId(), - buffer.getStreamName(), msgIterator); - } - } else { - return null; - } - } catch (IOException | InterruptedException e) { - LOGGER.error(e.getMessage(), e.getCause()); - throw new GeaflowRuntimeException(e); - } finally { - this.readMetrics.incFetchWaitMs(System.currentTimeMillis() - startTime); + long startTime = System.currentTimeMillis(); + try { + Optional next = this.inputFetcher.getNext(); + if (next.isPresent()) { + PipeFetcherBuffer buffer = next.get(); + if (buffer.isBarrier()) { + if (this.targetWindowId != Long.MAX_VALUE + && (buffer.getBatchId() == this.targetWindowId || buffer.isFinish())) { + this.processedNum++; + } + + SliceId sliceId = buffer.getSliceId(); + PipelineBarrier barrier = + new PipelineBarrier( + buffer.getBatchId(), + sliceId.getEdgeId(), + sliceId.getShardIndex(), + sliceId.getSliceIndex(), + buffer.getBatchCount()); + barrier.setFinish(buffer.isFinish()); + return barrier; + } else { + int edgeId = buffer.getSliceId().getEdgeId(); + this.readMetrics.increaseDecodeBytes(buffer.getBufferSize()); + IMessageIterator msgIterator = this.getMessageIterator(edgeId, buffer.getBuffer()); + return new PipelineMessage<>( + edgeId, buffer.getBatchId(), buffer.getStreamName(), msgIterator); } + } else { + return null; + } + } catch (IOException | InterruptedException e) { + LOGGER.error(e.getMessage(), e.getCause()); + throw new GeaflowRuntimeException(e); + } finally { + this.readMetrics.incFetchWaitMs(System.currentTimeMillis() - startTime); } + } - @Override - public ShuffleReadMetrics getShuffleReadMetrics() { - return this.readMetrics; - } + @Override + public ShuffleReadMetrics getShuffleReadMetrics() { + return this.readMetrics; + } - @Override - public void close() { - checkIfCloseable(); + @Override + public void close() { + checkIfCloseable(); - this.isRunning = false; - if (this.inputFetcher != null) { - this.inputFetcher.close(); - } + this.isRunning = false; + if (this.inputFetcher != null) { + this.inputFetcher.close(); } + } - private void checkIfCloseable() { - boolean longTerm = this.targetWindowId == Long.MAX_VALUE; - boolean moreAvailable = this.processedNum < this.totalSliceNum; - if (!longTerm && moreAvailable) { - throw new GeaflowRuntimeException("shuffle reader has unfinished messages"); - } + private void checkIfCloseable() { + boolean longTerm = this.targetWindowId == Long.MAX_VALUE; + boolean moreAvailable = this.processedNum < this.totalSliceNum; + if (!longTerm && moreAvailable) { + throw new GeaflowRuntimeException("shuffle reader has unfinished messages"); } - - private ShardFetcher createShardFetcher(long targetWindowId) { - Map> inputSlices = this.readerContext.getInputSlices(); - - int fetcherIndex = 0; - - List fetchers = new ArrayList<>(inputSlices.size()); - Map inputShards = this.readerContext.getInputShardMap(); - for (Map.Entry entry : inputShards.entrySet()) { - Integer edgeId = entry.getKey(); - ShardInputDesc inputDesc = entry.getValue(); - String streamName = inputDesc.getName(); - List slices = inputDesc.isPrefetchRead() - ? this.buildPrefetchSlice(inputSlices.get(edgeId)) - : inputSlices.get(edgeId); - OneShardFetcher inputFetcher = new OneShardFetcher( - this.readerContext.getVertexId(), - this.taskName, - fetcherIndex, - edgeId, - streamName, - slices, - targetWindowId, - this.connectionManager); - fetchers.add(inputFetcher); - fetcherIndex++; - } - - if (fetchers.size() == 1) { - return fetchers.get(0); - } else { - return new MultiShardFetcher(fetchers.toArray(new OneShardFetcher[0])); - } + } + + private ShardFetcher createShardFetcher(long targetWindowId) { + Map> inputSlices = this.readerContext.getInputSlices(); + + int fetcherIndex = 0; + + List fetchers = new ArrayList<>(inputSlices.size()); + Map inputShards = this.readerContext.getInputShardMap(); + for (Map.Entry entry : inputShards.entrySet()) { + Integer edgeId = entry.getKey(); + ShardInputDesc inputDesc = entry.getValue(); + String streamName = inputDesc.getName(); + List slices = + inputDesc.isPrefetchRead() + ? this.buildPrefetchSlice(inputSlices.get(edgeId)) + : inputSlices.get(edgeId); + OneShardFetcher inputFetcher = + new OneShardFetcher( + this.readerContext.getVertexId(), + this.taskName, + fetcherIndex, + edgeId, + streamName, + slices, + targetWindowId, + this.connectionManager); + fetchers.add(inputFetcher); + fetcherIndex++; } - private List buildPrefetchSlice(List slices) { - PipelineSliceMeta slice = slices.get(0); - SliceId tmp = slice.getSliceId(); - SliceId sliceId = new SliceId(tmp.getPipelineId(), tmp.getEdgeId(), -1, - tmp.getSliceIndex()); - SliceManager sliceManager = ShuffleManager.getInstance().getSliceManager(); - SpillablePipelineSlice resultSlice = (SpillablePipelineSlice) sliceManager.getSlice( - sliceId); - if (resultSlice == null || !resultSlice.isReady2read() || resultSlice.isReleased()) { - throw new GeaflowRuntimeException("illegal slice: " + sliceId); - } - PipelineSliceMeta newSlice = new PipelineSliceMeta(sliceId, slice.getWindowId(), - this.connectionManager.getShuffleAddress()); - return Collections.singletonList(newSlice); + if (fetchers.size() == 1) { + return fetchers.get(0); + } else { + return new MultiShardFetcher(fetchers.toArray(new OneShardFetcher[0])); } - - private IMessageIterator getMessageIterator(int edgeId, OutBuffer outBuffer) { - IEncoder encoder = this.encoders.get(edgeId); - return encoder == null - ? new MessageIterator<>(outBuffer) - : new EncoderMessageIterator<>(outBuffer, encoder); + } + + private List buildPrefetchSlice(List slices) { + PipelineSliceMeta slice = slices.get(0); + SliceId tmp = slice.getSliceId(); + SliceId sliceId = new SliceId(tmp.getPipelineId(), tmp.getEdgeId(), -1, tmp.getSliceIndex()); + SliceManager sliceManager = ShuffleManager.getInstance().getSliceManager(); + SpillablePipelineSlice resultSlice = (SpillablePipelineSlice) sliceManager.getSlice(sliceId); + if (resultSlice == null || !resultSlice.isReady2read() || resultSlice.isReleased()) { + throw new GeaflowRuntimeException("illegal slice: " + sliceId); } - + PipelineSliceMeta newSlice = + new PipelineSliceMeta( + sliceId, slice.getWindowId(), this.connectionManager.getShuffleAddress()); + return Collections.singletonList(newSlice); + } + + private IMessageIterator getMessageIterator(int edgeId, OutBuffer outBuffer) { + IEncoder encoder = this.encoders.get(edgeId); + return encoder == null + ? new MessageIterator<>(outBuffer) + : new EncoderMessageIterator<>(outBuffer, encoder); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/reader/ReaderContext.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/reader/ReaderContext.java index bb8c10b96..7a85af625 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/reader/ReaderContext.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/reader/ReaderContext.java @@ -21,66 +21,66 @@ import java.util.List; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.shuffle.desc.ShardInputDesc; import org.apache.geaflow.shuffle.message.PipelineSliceMeta; public class ReaderContext implements IReaderContext { - private Configuration config; - private int vertexId; - private String taskName; - private Map inputShardMap; - private Map> inputSlices; - private int sliceNum; - - @Override - public Configuration getConfig() { - return this.config; - } - - public int getVertexId() { - return this.vertexId; - } - - public String getTaskName() { - return this.taskName; - } - - public Map getInputShardMap() { - return this.inputShardMap; - } - - public Map> getInputSlices() { - return this.inputSlices; - } - - public int getSliceNum() { - return this.sliceNum; - } - - public void setConfig(Configuration config) { - this.config = config; - } - - public void setVertexId(int vertexId) { - this.vertexId = vertexId; - } - - public void setTaskName(String taskName) { - this.taskName = taskName; - } - - public void setInputShardMap(Map inputShardMap) { - this.inputShardMap = inputShardMap; - } - - public void setInputSlices(Map> inputSlices) { - this.inputSlices = inputSlices; - } - - public void setSliceNum(int sliceNum) { - this.sliceNum = sliceNum; - } - -} \ No newline at end of file + private Configuration config; + private int vertexId; + private String taskName; + private Map inputShardMap; + private Map> inputSlices; + private int sliceNum; + + @Override + public Configuration getConfig() { + return this.config; + } + + public int getVertexId() { + return this.vertexId; + } + + public String getTaskName() { + return this.taskName; + } + + public Map getInputShardMap() { + return this.inputShardMap; + } + + public Map> getInputSlices() { + return this.inputSlices; + } + + public int getSliceNum() { + return this.sliceNum; + } + + public void setConfig(Configuration config) { + this.config = config; + } + + public void setVertexId(int vertexId) { + this.vertexId = vertexId; + } + + public void setTaskName(String taskName) { + this.taskName = taskName; + } + + public void setInputShardMap(Map inputShardMap) { + this.inputShardMap = inputShardMap; + } + + public void setInputSlices(Map> inputSlices) { + this.inputSlices = inputSlices; + } + + public void setSliceNum(int sliceNum) { + this.sliceNum = sliceNum; + } +} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/IShuffleWriter.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/IShuffleWriter.java index 7de8acddf..2d4d97924 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/IShuffleWriter.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/IShuffleWriter.java @@ -22,53 +22,52 @@ import java.io.IOException; import java.util.List; import java.util.Optional; + import org.apache.geaflow.common.metric.ShuffleWriteMetrics; public interface IShuffleWriter { - /** - * Init with shuffle writer context. - * - * @param writerContext writer context. - */ - void init(IWriterContext writerContext); + /** + * Init with shuffle writer context. + * + * @param writerContext writer context. + */ + void init(IWriterContext writerContext); - /** - * Emit value to output channels. - * - * @param channels output channels. - * @throws IOException io exception. - */ - void emit(long batchId, T value, boolean isRetract, int[] channels) throws IOException; + /** + * Emit value to output channels. + * + * @param channels output channels. + * @throws IOException io exception. + */ + void emit(long batchId, T value, boolean isRetract, int[] channels) throws IOException; - /** - * Emit values to output channels. - * - * @param batchId batch id - * @param value data list - * @param isRetract if retract - * @param channels output channels - * @throws IOException err - */ - void emit(long batchId, List value, boolean isRetract, int channels) throws IOException; + /** + * Emit values to output channels. + * + * @param batchId batch id + * @param value data list + * @param isRetract if retract + * @param channels output channels + * @throws IOException err + */ + void emit(long batchId, List value, boolean isRetract, int channels) throws IOException; - /** - * Flush buffered data. - * - * @return shuffle result. - * @throws IOException io exception. - */ - Optional flush(long batchId) throws IOException; + /** + * Flush buffered data. + * + * @return shuffle result. + * @throws IOException io exception. + */ + Optional flush(long batchId) throws IOException; - /** - * Get write metrics. - * - * @return write metrics. - */ - ShuffleWriteMetrics getShuffleWriteMetrics(); + /** + * Get write metrics. + * + * @return write metrics. + */ + ShuffleWriteMetrics getShuffleWriteMetrics(); - /** - * Close. - */ - void close(); + /** Close. */ + void close(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/IWriterContext.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/IWriterContext.java index c9a1442d4..b5d696da7 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/IWriterContext.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/IWriterContext.java @@ -20,6 +20,7 @@ package org.apache.geaflow.shuffle.api.writer; import java.io.Serializable; + import org.apache.geaflow.common.encoder.IEncoder; import org.apache.geaflow.common.shuffle.DataExchangeMode; import org.apache.geaflow.shuffle.config.ShuffleConfig; @@ -27,69 +28,68 @@ public interface IWriterContext extends Serializable { - /** - * Get the pipeline information. - * - * @return pipeline info. - */ - PipelineInfo getPipelineInfo(); - - /** - * Get the vertex id in the DAG. - * - * @return vertex id. - */ - int getVertexId(); + /** + * Get the pipeline information. + * + * @return pipeline info. + */ + PipelineInfo getPipelineInfo(); - /** - * Get the edge id in the DAG. - * - * @return edge id. - */ - int getEdgeId(); + /** + * Get the vertex id in the DAG. + * + * @return vertex id. + */ + int getVertexId(); - /** - * Get the task index of the writer. - * - * @return task index. - */ - int getTaskIndex(); + /** + * Get the edge id in the DAG. + * + * @return edge id. + */ + int getEdgeId(); - /** - * Get the task id of the writer. - * - * @return task id. - */ - int getTaskId(); + /** + * Get the task index of the writer. + * + * @return task index. + */ + int getTaskIndex(); - /** - * Get task name of the writer. - * - * @return task name. - */ - String getTaskName(); + /** + * Get the task id of the writer. + * + * @return task id. + */ + int getTaskId(); - /** - * Get the target channel number of downstream. - * - * @return target channel number. - */ - int getTargetChannelNum(); + /** + * Get task name of the writer. + * + * @return task name. + */ + String getTaskName(); - /** - * Get the configuration. - * - * @return configuration. - */ - ShuffleConfig getConfig(); + /** + * Get the target channel number of downstream. + * + * @return target channel number. + */ + int getTargetChannelNum(); - /** - * Get the encoder for serialize and deserialize data. - * - * @return data encoder. - */ - IEncoder getEncoder(); + /** + * Get the configuration. + * + * @return configuration. + */ + ShuffleConfig getConfig(); - DataExchangeMode getDataExchangeMode(); + /** + * Get the encoder for serialize and deserialize data. + * + * @return data encoder. + */ + IEncoder getEncoder(); + DataExchangeMode getDataExchangeMode(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/PipelineShardWriter.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/PipelineShardWriter.java index 8b2e445ec..f200ce25f 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/PipelineShardWriter.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/PipelineShardWriter.java @@ -24,6 +24,7 @@ import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.shuffle.message.Shard; import org.apache.geaflow.shuffle.message.SliceId; @@ -36,161 +37,160 @@ public class PipelineShardWriter extends ShardWriter { - private static final Logger LOGGER = LoggerFactory.getLogger(PipelineWriter.class); - - private OutputFlusher outputFlusher; - private final AtomicReference throwable; - private final AtomicInteger curBufferBytes; - private int maxWriteBufferSize; - - public PipelineShardWriter() { - this.throwable = new AtomicReference<>(); - this.curBufferBytes = new AtomicInteger(0); - } - - @Override - public void init(IWriterContext writerContext) { - super.init(writerContext); - String threadName = String.format("flusher-%s", writerContext.getTaskName()); - int flushTimeout = this.shuffleConfig.getFlushBufferTimeoutMs(); - this.maxWriteBufferSize = shuffleConfig.getMaxWriteBufferSize(); - this.outputFlusher = new OutputFlusher(threadName, flushTimeout); - this.outputFlusher.start(); + private static final Logger LOGGER = LoggerFactory.getLogger(PipelineWriter.class); + + private OutputFlusher outputFlusher; + private final AtomicReference throwable; + private final AtomicInteger curBufferBytes; + private int maxWriteBufferSize; + + public PipelineShardWriter() { + this.throwable = new AtomicReference<>(); + this.curBufferBytes = new AtomicInteger(0); + } + + @Override + public void init(IWriterContext writerContext) { + super.init(writerContext); + String threadName = String.format("flusher-%s", writerContext.getTaskName()); + int flushTimeout = this.shuffleConfig.getFlushBufferTimeoutMs(); + this.maxWriteBufferSize = shuffleConfig.getMaxWriteBufferSize(); + this.outputFlusher = new OutputFlusher(threadName, flushTimeout); + this.outputFlusher.start(); + } + + @Override + protected IPipelineSlice newSlice(String taskLogTag, SliceId sliceId) { + if (enableBackPressure) { + return new BlockingSlice(taskLogTag, sliceId, this); + } else { + return new PipelineSlice(taskLogTag, sliceId); } - - @Override - protected IPipelineSlice newSlice(String taskLogTag, SliceId sliceId) { - if (enableBackPressure) { - return new BlockingSlice(taskLogTag, sliceId, this); - } else { - return new PipelineSlice(taskLogTag, sliceId); + } + + @Override + public void emit(long batchId, T value, boolean isRetract, int[] channels) throws IOException { + this.checkError(); + super.emit(batchId, value, isRetract, channels); + } + + @Override + public void emit(long batchId, List data, int channel) throws IOException { + this.checkError(); + super.emit(batchId, data, channel); + } + + @Override + public Optional doFinish(long windowId) throws IOException { + this.checkError(); + return Optional.empty(); + } + + @Override + protected void sendBuffer(int sliceIndex, BufferBuilder builder, long windowId) { + if (enableBackPressure) { + if (curBufferBytes.get() >= maxWriteBufferSize) { + synchronized (this) { + while (curBufferBytes.get() >= maxWriteBufferSize) { + try { + this.wait(); + } catch (InterruptedException e) { + throw new GeaflowRuntimeException(e); + } + } } + } + curBufferBytes.addAndGet(builder.getBufferSize()); } - - @Override - public void emit(long batchId, T value, boolean isRetract, int[] channels) throws IOException { - this.checkError(); - super.emit(batchId, value, isRetract, channels); + super.sendBuffer(sliceIndex, builder, windowId); + } + + public void notifyBufferConsumed(int bufferBytes) { + int preBytes = curBufferBytes.getAndAdd(-bufferBytes); + if (preBytes >= maxWriteBufferSize && curBufferBytes.get() < maxWriteBufferSize) { + synchronized (this) { + this.notifyAll(); + } } + } - @Override - public void emit(long batchId, List data, int channel) throws IOException { - this.checkError(); - super.emit(batchId, data, channel); + private void flushAll() { + boolean flushed = this.flushSlices(); + if (!flushed) { + LOGGER.warn("terminate flusher due to slices released"); + this.outputFlusher.terminate(); } + } - @Override - public Optional doFinish(long windowId) throws IOException { - this.checkError(); - return Optional.empty(); + @Override + public void close() { + if (this.outputFlusher != null) { + this.outputFlusher.terminate(); + this.outputFlusher = null; } - - @Override - protected void sendBuffer(int sliceIndex, BufferBuilder builder, long windowId) { - if (enableBackPressure) { - if (curBufferBytes.get() >= maxWriteBufferSize) { - synchronized (this) { - while (curBufferBytes.get() >= maxWriteBufferSize) { - try { - this.wait(); - } catch (InterruptedException e) { - throw new GeaflowRuntimeException(e); - } - } - } - } - curBufferBytes.addAndGet(builder.getBufferSize()); - } - super.sendBuffer(sliceIndex, builder, windowId); + } + + private void checkError() throws IOException { + if (this.throwable.get() != null) { + Throwable t = this.throwable.get(); + if (t instanceof IOException) { + throw (IOException) t; + } else { + throw new GeaflowRuntimeException(t); + } } + } - public void notifyBufferConsumed(int bufferBytes) { - int preBytes = curBufferBytes.getAndAdd(-bufferBytes); - if (preBytes >= maxWriteBufferSize && curBufferBytes.get() < maxWriteBufferSize) { - synchronized (this) { - this.notifyAll(); - } - } - } + /** + * A dedicated thread that periodically flushes the output buffers, to set upper latency bounds. + * + *

The thread is daemonic, because it is only a utility thread. + */ + private class OutputFlusher extends Thread { - private void flushAll() { - boolean flushed = this.flushSlices(); - if (!flushed) { - LOGGER.warn("terminate flusher due to slices released"); - this.outputFlusher.terminate(); - } - } + private final long timeout; - @Override - public void close() { - if (this.outputFlusher != null) { - this.outputFlusher.terminate(); - this.outputFlusher = null; - } - } + private volatile boolean running = true; - private void checkError() throws IOException { - if (this.throwable.get() != null) { - Throwable t = this.throwable.get(); - if (t instanceof IOException) { - throw (IOException) t; - } else { - throw new GeaflowRuntimeException(t); - } - } + OutputFlusher(String name, long timeout) { + super(name); + setDaemon(true); + this.timeout = timeout; + LOGGER.info("start {} with timeout {}ms", name, timeout); } - /** - * A dedicated thread that periodically flushes the output buffers, to set upper latency bounds. - * - *

The thread is daemonic, because it is only a utility thread. - */ - private class OutputFlusher extends Thread { - - private final long timeout; - - private volatile boolean running = true; - - OutputFlusher(String name, long timeout) { - super(name); - setDaemon(true); - this.timeout = timeout; - LOGGER.info("start {} with timeout {}ms", name, timeout); - } + public void terminate() { + if (running) { + running = false; + interrupt(); + } + } - public void terminate() { - if (running) { - running = false; - interrupt(); + @Override + public void run() { + try { + while (this.running) { + try { + Thread.sleep(this.timeout); + } catch (InterruptedException e) { + // Propagate this if we are still running, + // because it should not happen in that case. + if (this.running) { + LOGGER.error("Interrupted", e); + throw e; } - } + } - @Override - public void run() { - try { - while (this.running) { - try { - Thread.sleep(this.timeout); - } catch (InterruptedException e) { - // Propagate this if we are still running, - // because it should not happen in that case. - if (this.running) { - LOGGER.error("Interrupted", e); - throw e; - } - } - - // Any errors here should let the thread come to a halt and be - // recognized by the writer. - flushAll(); - } - flushAll(); - } catch (Throwable t) { - if (throwable.compareAndSet(null, t)) { - LOGGER.error("flush failed", t); - } - } + // Any errors here should let the thread come to a halt and be + // recognized by the writer. + flushAll(); } + flushAll(); + } catch (Throwable t) { + if (throwable.compareAndSet(null, t)) { + LOGGER.error("flush failed", t); + } + } } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/PipelineWriter.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/PipelineWriter.java index 229d540d1..8a136d64c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/PipelineWriter.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/PipelineWriter.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.List; import java.util.Optional; + import org.apache.geaflow.common.metric.ShuffleWriteMetrics; import org.apache.geaflow.common.shuffle.DataExchangeMode; import org.apache.geaflow.shuffle.message.Shard; @@ -29,46 +30,46 @@ public class PipelineWriter implements IShuffleWriter { - private final IConnectionManager connectionManager; - private ShardWriter shardWriter; + private final IConnectionManager connectionManager; + private ShardWriter shardWriter; - public PipelineWriter(IConnectionManager connectionManager) { - this.connectionManager = connectionManager; - } + public PipelineWriter(IConnectionManager connectionManager) { + this.connectionManager = connectionManager; + } - @Override - public void init(IWriterContext writerContext) { - this.shardWriter = writerContext.getDataExchangeMode() == DataExchangeMode.BATCH + @Override + public void init(IWriterContext writerContext) { + this.shardWriter = + writerContext.getDataExchangeMode() == DataExchangeMode.BATCH ? new SpillableShardWriter<>(this.connectionManager.getShuffleAddress()) : new PipelineShardWriter<>(); - this.shardWriter.init(writerContext); - } + this.shardWriter.init(writerContext); + } - @Override - public void emit(long batchId, T value, boolean isRetract, int[] channels) throws IOException { - this.shardWriter.emit(batchId, value, isRetract, channels); - } + @Override + public void emit(long batchId, T value, boolean isRetract, int[] channels) throws IOException { + this.shardWriter.emit(batchId, value, isRetract, channels); + } - @Override - public void emit(long batchId, List data, boolean isRetract, int channel) throws IOException { - this.shardWriter.emit(batchId, data, channel); - } + @Override + public void emit(long batchId, List data, boolean isRetract, int channel) throws IOException { + this.shardWriter.emit(batchId, data, channel); + } - @Override - public Optional flush(long batchId) throws IOException { - return this.shardWriter.finish(batchId); - } + @Override + public Optional flush(long batchId) throws IOException { + return this.shardWriter.finish(batchId); + } - @Override - public ShuffleWriteMetrics getShuffleWriteMetrics() { - return this.shardWriter.getShuffleWriteMetrics(); - } + @Override + public ShuffleWriteMetrics getShuffleWriteMetrics() { + return this.shardWriter.getShuffleWriteMetrics(); + } - @Override - public void close() { - if (this.shardWriter != null) { - this.shardWriter.close(); - } + @Override + public void close() { + if (this.shardWriter != null) { + this.shardWriter.close(); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/ShardWriter.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/ShardWriter.java index f65c6652a..7e07dcab8 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/ShardWriter.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/ShardWriter.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.List; import java.util.Optional; + import org.apache.geaflow.common.encoder.IEncoder; import org.apache.geaflow.common.metric.ShuffleWriteMetrics; import org.apache.geaflow.shuffle.config.ShuffleConfig; @@ -41,178 +42,174 @@ public abstract class ShardWriter { - protected IWriterContext writerContext; - protected ShuffleConfig shuffleConfig; - protected ShuffleWriteMetrics writeMetrics; - - - protected long pipelineId; - protected String pipelineName; - protected int edgeId; - protected int taskIndex; - protected int targetChannels; - protected boolean enableBackPressure; - - protected String taskLogTag; - protected long[] recordCounter; - protected long[] bytesCounter; - protected long maxBufferSize; - - protected BufferBuilder[] buffers; - protected volatile IPipelineSlice[] resultSlices; - protected IRecordSerializer recordSerializer; - - ////////////////////////////// - // Init. - - /// /////////////////////////// - - public void init(IWriterContext writerContext) { - this.writerContext = writerContext; - this.shuffleConfig = writerContext.getConfig(); - this.writeMetrics = new ShuffleWriteMetrics(); - - this.pipelineId = writerContext.getPipelineInfo().getPipelineId(); - this.pipelineName = writerContext.getPipelineInfo().getPipelineName(); - this.edgeId = writerContext.getEdgeId(); - this.taskIndex = writerContext.getTaskIndex(); - this.targetChannels = writerContext.getTargetChannelNum(); - this.taskLogTag = writerContext.getTaskName(); - this.recordCounter = new long[this.targetChannels]; - this.bytesCounter = new long[this.targetChannels]; - this.maxBufferSize = this.shuffleConfig.getMaxBufferSizeBytes(); - this.enableBackPressure = this.shuffleConfig.isBackpressureEnabled(); - - this.buffers = this.buildBufferBuilder(this.targetChannels, - this.shuffleConfig.isMemoryPoolEnable()); - this.resultSlices = this.buildResultSlices(this.targetChannels); - this.recordSerializer = this.getRecordSerializer(); + protected IWriterContext writerContext; + protected ShuffleConfig shuffleConfig; + protected ShuffleWriteMetrics writeMetrics; + + protected long pipelineId; + protected String pipelineName; + protected int edgeId; + protected int taskIndex; + protected int targetChannels; + protected boolean enableBackPressure; + + protected String taskLogTag; + protected long[] recordCounter; + protected long[] bytesCounter; + protected long maxBufferSize; + + protected BufferBuilder[] buffers; + protected volatile IPipelineSlice[] resultSlices; + protected IRecordSerializer recordSerializer; + + ////////////////////////////// + // Init. + + /// /////////////////////////// + + public void init(IWriterContext writerContext) { + this.writerContext = writerContext; + this.shuffleConfig = writerContext.getConfig(); + this.writeMetrics = new ShuffleWriteMetrics(); + + this.pipelineId = writerContext.getPipelineInfo().getPipelineId(); + this.pipelineName = writerContext.getPipelineInfo().getPipelineName(); + this.edgeId = writerContext.getEdgeId(); + this.taskIndex = writerContext.getTaskIndex(); + this.targetChannels = writerContext.getTargetChannelNum(); + this.taskLogTag = writerContext.getTaskName(); + this.recordCounter = new long[this.targetChannels]; + this.bytesCounter = new long[this.targetChannels]; + this.maxBufferSize = this.shuffleConfig.getMaxBufferSizeBytes(); + this.enableBackPressure = this.shuffleConfig.isBackpressureEnabled(); + + this.buffers = + this.buildBufferBuilder(this.targetChannels, this.shuffleConfig.isMemoryPoolEnable()); + this.resultSlices = this.buildResultSlices(this.targetChannels); + this.recordSerializer = this.getRecordSerializer(); + } + + private BufferBuilder[] buildBufferBuilder(int channels, boolean enableMemoryPool) { + BufferBuilder[] buffers = new BufferBuilder[channels]; + for (int i = 0; i < channels; i++) { + BufferBuilder bufferBuilder = + enableMemoryPool ? new MemoryViewBufferBuilder() : new HeapBufferBuilder(); + bufferBuilder.enableMemoryTrack(); + buffers[i] = bufferBuilder; } - - private BufferBuilder[] buildBufferBuilder(int channels, boolean enableMemoryPool) { - BufferBuilder[] buffers = new BufferBuilder[channels]; - for (int i = 0; i < channels; i++) { - BufferBuilder bufferBuilder = enableMemoryPool - ? new MemoryViewBufferBuilder() - : new HeapBufferBuilder(); - bufferBuilder.enableMemoryTrack(); - buffers[i] = bufferBuilder; - } - return buffers; + return buffers; + } + + protected IPipelineSlice[] buildResultSlices(int channels) { + IPipelineSlice[] slices = new IPipelineSlice[channels]; + WriterId writerId = new WriterId(this.pipelineId, this.edgeId, this.taskIndex); + SliceManager sliceManager = ShuffleManager.getInstance().getSliceManager(); + for (int i = 0; i < channels; i++) { + SliceId sliceId = new SliceId(writerId, i); + IPipelineSlice slice = this.newSlice(this.taskLogTag, sliceId); + slices[i] = slice; + sliceManager.register(sliceId, slice); } + return slices; + } - protected IPipelineSlice[] buildResultSlices(int channels) { - IPipelineSlice[] slices = new IPipelineSlice[channels]; - WriterId writerId = new WriterId(this.pipelineId, this.edgeId, this.taskIndex); - SliceManager sliceManager = ShuffleManager.getInstance().getSliceManager(); - for (int i = 0; i < channels; i++) { - SliceId sliceId = new SliceId(writerId, i); - IPipelineSlice slice = this.newSlice(this.taskLogTag, sliceId); - slices[i] = slice; - sliceManager.register(sliceId, slice); - } - return slices; - } - - protected abstract IPipelineSlice newSlice(String taskLogTag, SliceId sliceId); + protected abstract IPipelineSlice newSlice(String taskLogTag, SliceId sliceId); - @SuppressWarnings("unchecked") - private IRecordSerializer getRecordSerializer() { - IEncoder encoder = this.writerContext.getEncoder(); - if (encoder == null) { - return new RecordSerializer<>(); - } - return new EncoderRecordSerializer<>((IEncoder) encoder); + @SuppressWarnings("unchecked") + private IRecordSerializer getRecordSerializer() { + IEncoder encoder = this.writerContext.getEncoder(); + if (encoder == null) { + return new RecordSerializer<>(); } + return new EncoderRecordSerializer<>((IEncoder) encoder); + } - ////////////////////////////// - // Write data. + ////////////////////////////// + // Write data. - /// /////////////////////////// + /// /////////////////////////// - public void emit(long windowId, T value, boolean isRetract, int[] channels) throws IOException { - for (int channel : channels) { - BufferBuilder outBuffer = this.buffers[channel]; - this.recordSerializer.serialize(value, isRetract, outBuffer); - if (outBuffer.getBufferSize() >= this.maxBufferSize) { - this.sendBuffer(channel, outBuffer, windowId); - } - } + public void emit(long windowId, T value, boolean isRetract, int[] channels) throws IOException { + for (int channel : channels) { + BufferBuilder outBuffer = this.buffers[channel]; + this.recordSerializer.serialize(value, isRetract, outBuffer); + if (outBuffer.getBufferSize() >= this.maxBufferSize) { + this.sendBuffer(channel, outBuffer, windowId); + } } + } - public void emit(long windowId, List data, int channel) throws IOException { - BufferBuilder outBuffer = this.buffers[channel]; - for (T datum : data) { - this.recordSerializer.serialize(datum, false, outBuffer); - } - if (outBuffer.getBufferSize() >= this.maxBufferSize) { - this.sendBuffer(channel, outBuffer, windowId); - } + public void emit(long windowId, List data, int channel) throws IOException { + BufferBuilder outBuffer = this.buffers[channel]; + for (T datum : data) { + this.recordSerializer.serialize(datum, false, outBuffer); } - - public Optional finish(long windowId) throws IOException { - this.flushFloatingBuffers(windowId); - this.notify(new PipelineBarrier(windowId, this.edgeId, this.taskIndex)); - this.flushSlices(); - return this.doFinish(windowId); + if (outBuffer.getBufferSize() >= this.maxBufferSize) { + this.sendBuffer(channel, outBuffer, windowId); } - - protected abstract Optional doFinish(long windowId) throws IOException; - - protected void sendBuffer(int sliceIndex, BufferBuilder builder, long windowId) { - this.recordCounter[sliceIndex] += builder.getRecordCount(); - this.bytesCounter[sliceIndex] += builder.getBufferSize(); - IPipelineSlice resultSlice = this.resultSlices[sliceIndex]; - resultSlice.add(new PipeBuffer(builder.build(), windowId)); + } + + public Optional finish(long windowId) throws IOException { + this.flushFloatingBuffers(windowId); + this.notify(new PipelineBarrier(windowId, this.edgeId, this.taskIndex)); + this.flushSlices(); + return this.doFinish(windowId); + } + + protected abstract Optional doFinish(long windowId) throws IOException; + + protected void sendBuffer(int sliceIndex, BufferBuilder builder, long windowId) { + this.recordCounter[sliceIndex] += builder.getRecordCount(); + this.bytesCounter[sliceIndex] += builder.getBufferSize(); + IPipelineSlice resultSlice = this.resultSlices[sliceIndex]; + resultSlice.add(new PipeBuffer(builder.build(), windowId)); + } + + private void sendBarrier(int sliceIndex, long windowId, int count, boolean isFinish) { + IPipelineSlice resultSlice = this.resultSlices[sliceIndex]; + resultSlice.add(new PipeBuffer(windowId, count, isFinish)); + } + + private void flushFloatingBuffers(long windowId) { + for (int i = 0; i < this.targetChannels; i++) { + BufferBuilder bufferBuilder = this.buffers[i]; + if (bufferBuilder.getBufferSize() > 0) { + this.sendBuffer(i, bufferBuilder, windowId); + } } - - private void sendBarrier(int sliceIndex, long windowId, int count, boolean isFinish) { - IPipelineSlice resultSlice = this.resultSlices[sliceIndex]; - resultSlice.add(new PipeBuffer(windowId, count, isFinish)); - } - - private void flushFloatingBuffers(long windowId) { - for (int i = 0; i < this.targetChannels; i++) { - BufferBuilder bufferBuilder = this.buffers[i]; - if (bufferBuilder.getBufferSize() > 0) { - this.sendBuffer(i, bufferBuilder, windowId); - } - } - } - - protected boolean flushSlices() { - IPipelineSlice[] pipeSlices = this.resultSlices; - boolean flushed = false; - if (pipeSlices != null) { - for (int i = 0; i < pipeSlices.length; i++) { - if (null != pipeSlices[i]) { - pipeSlices[i].flush(); - flushed = true; - } - } + } + + protected boolean flushSlices() { + IPipelineSlice[] pipeSlices = this.resultSlices; + boolean flushed = false; + if (pipeSlices != null) { + for (int i = 0; i < pipeSlices.length; i++) { + if (null != pipeSlices[i]) { + pipeSlices[i].flush(); + flushed = true; } - return flushed; + } } - - public ShuffleWriteMetrics getShuffleWriteMetrics() { - return this.writeMetrics; - } - - protected void notify(PipelineBarrier barrier) throws IOException { - for (int channel = 0; channel < this.targetChannels; channel++) { - long windowId = barrier.getWindowId(); - long recordCount = this.recordCounter[channel]; - long bytesCount = this.bytesCounter[channel]; - sendBarrier(channel, windowId, (int) recordCount, barrier.isFinish()); - - this.writeMetrics.increaseRecords(recordCount); - this.writeMetrics.increaseEncodedSize(bytesCount); - this.recordCounter[channel] = 0; - this.bytesCounter[channel] = 0; - } - } - - public void close() { + return flushed; + } + + public ShuffleWriteMetrics getShuffleWriteMetrics() { + return this.writeMetrics; + } + + protected void notify(PipelineBarrier barrier) throws IOException { + for (int channel = 0; channel < this.targetChannels; channel++) { + long windowId = barrier.getWindowId(); + long recordCount = this.recordCounter[channel]; + long bytesCount = this.bytesCounter[channel]; + sendBarrier(channel, windowId, (int) recordCount, barrier.isFinish()); + + this.writeMetrics.increaseRecords(recordCount); + this.writeMetrics.increaseEncodedSize(bytesCount); + this.recordCounter[channel] = 0; + this.bytesCounter[channel] = 0; } + } -} \ No newline at end of file + public void close() {} +} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/SpillableShardWriter.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/SpillableShardWriter.java index 59ab2960a..1bb9347f4 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/SpillableShardWriter.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/SpillableShardWriter.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; + import org.apache.geaflow.common.shuffle.ShuffleAddress; import org.apache.geaflow.shuffle.message.ISliceMeta; import org.apache.geaflow.shuffle.message.PipelineSliceMeta; @@ -32,33 +33,32 @@ public class SpillableShardWriter extends ShardWriter { - protected final ShuffleAddress shuffleAddress; + protected final ShuffleAddress shuffleAddress; - public SpillableShardWriter(ShuffleAddress shuffleAddress) { - this.shuffleAddress = shuffleAddress; - } + public SpillableShardWriter(ShuffleAddress shuffleAddress) { + this.shuffleAddress = shuffleAddress; + } - @Override - protected IPipelineSlice newSlice(String taskLogTag, SliceId sliceId) { - return new SpillablePipelineSlice(taskLogTag, sliceId); - } + @Override + protected IPipelineSlice newSlice(String taskLogTag, SliceId sliceId) { + return new SpillablePipelineSlice(taskLogTag, sliceId); + } - @Override - public Optional doFinish(long windowId) { - List slices = this.buildSliceMeta(windowId); - return Optional.of(new Shard(this.edgeId, slices)); - } + @Override + public Optional doFinish(long windowId) { + List slices = this.buildSliceMeta(windowId); + return Optional.of(new Shard(this.edgeId, slices)); + } - private List buildSliceMeta(long windowId) { - List slices = new ArrayList<>(); - for (int i = 0; i < this.targetChannels; i++) { - SliceId sliceId = this.resultSlices[i].getSliceId(); - PipelineSliceMeta sliceMeta = new PipelineSliceMeta(sliceId, windowId, this.shuffleAddress); - sliceMeta.setRecordNum(this.recordCounter[i]); - sliceMeta.setEncodedSize(this.bytesCounter[i]); - slices.add(sliceMeta); - } - return slices; + private List buildSliceMeta(long windowId) { + List slices = new ArrayList<>(); + for (int i = 0; i < this.targetChannels; i++) { + SliceId sliceId = this.resultSlices[i].getSliceId(); + PipelineSliceMeta sliceMeta = new PipelineSliceMeta(sliceId, windowId, this.shuffleAddress); + sliceMeta.setRecordNum(this.recordCounter[i]); + sliceMeta.setEncodedSize(this.bytesCounter[i]); + slices.add(sliceMeta); } - + return slices; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/WriterContext.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/WriterContext.java index 1dd488cc9..8dac0fadd 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/WriterContext.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/api/writer/WriterContext.java @@ -26,132 +26,131 @@ public class WriterContext implements IWriterContext { - private final PipelineInfo pipelineInfo; - private int edgeId; - private int vertexId; - private int taskIndex; - private int taskId; - private String taskName; - private DataExchangeMode dataExchangeMode; - private int targetChannels; - private ShuffleConfig config; - private IEncoder encoder; - - public WriterContext(long pipelineId, String pipelineName) { - this.pipelineInfo = new PipelineInfo(pipelineId, pipelineName); - } - - public WriterContext setEdgeId(int edgeId) { - this.edgeId = edgeId; - return this; - } - - public WriterContext setVertexId(int vertexId) { - this.vertexId = vertexId; - return this; - } - - public WriterContext setTaskIndex(int taskIndex) { - this.taskIndex = taskIndex; - return this; - } - - public WriterContext setTaskId(int taskId) { - this.taskId = taskId; - return this; - } - - public WriterContext setTaskName(String taskName) { - this.taskName = taskName; - return this; - } - - public WriterContext setDataExchangeMode(DataExchangeMode dataExchangeMode) { - this.dataExchangeMode = dataExchangeMode; - return this; - } - - public WriterContext setChannelNum(int targetChannels) { - this.targetChannels = targetChannels; - return this; - } - - public WriterContext setConfig(ShuffleConfig config) { - this.config = config; - return this; - } - - public WriterContext setEncoder(IEncoder encoder) { - this.encoder = encoder; - return this; - } - - @Override - public PipelineInfo getPipelineInfo() { - return pipelineInfo; - } - - @Override - public int getEdgeId() { - return edgeId; - } - - @Override - public int getVertexId() { - return vertexId; - } - - @Override - public int getTaskId() { - return taskId; - } - - @Override - public String getTaskName() { - return taskName; - } - - @Override - public int getTaskIndex() { - return taskIndex; - } - - @Override - public ShuffleConfig getConfig() { - return config; - } - - @Override - public int getTargetChannelNum() { - return targetChannels; - } - - @Override - public IEncoder getEncoder() { - return this.encoder; - } - - @Override - public DataExchangeMode getDataExchangeMode() { - return this.dataExchangeMode; - } - - public static WriterContextBuilder newBuilder() { - return new WriterContextBuilder(); - } - - public static class WriterContextBuilder { - - private long pipelineId; - - public WriterContextBuilder setPipelineId(long pipelineId) { - this.pipelineId = pipelineId; - return this; - } - - public WriterContext setPipelineName(String pipelineName) { - return new WriterContext(pipelineId, pipelineName); - } - } - + private final PipelineInfo pipelineInfo; + private int edgeId; + private int vertexId; + private int taskIndex; + private int taskId; + private String taskName; + private DataExchangeMode dataExchangeMode; + private int targetChannels; + private ShuffleConfig config; + private IEncoder encoder; + + public WriterContext(long pipelineId, String pipelineName) { + this.pipelineInfo = new PipelineInfo(pipelineId, pipelineName); + } + + public WriterContext setEdgeId(int edgeId) { + this.edgeId = edgeId; + return this; + } + + public WriterContext setVertexId(int vertexId) { + this.vertexId = vertexId; + return this; + } + + public WriterContext setTaskIndex(int taskIndex) { + this.taskIndex = taskIndex; + return this; + } + + public WriterContext setTaskId(int taskId) { + this.taskId = taskId; + return this; + } + + public WriterContext setTaskName(String taskName) { + this.taskName = taskName; + return this; + } + + public WriterContext setDataExchangeMode(DataExchangeMode dataExchangeMode) { + this.dataExchangeMode = dataExchangeMode; + return this; + } + + public WriterContext setChannelNum(int targetChannels) { + this.targetChannels = targetChannels; + return this; + } + + public WriterContext setConfig(ShuffleConfig config) { + this.config = config; + return this; + } + + public WriterContext setEncoder(IEncoder encoder) { + this.encoder = encoder; + return this; + } + + @Override + public PipelineInfo getPipelineInfo() { + return pipelineInfo; + } + + @Override + public int getEdgeId() { + return edgeId; + } + + @Override + public int getVertexId() { + return vertexId; + } + + @Override + public int getTaskId() { + return taskId; + } + + @Override + public String getTaskName() { + return taskName; + } + + @Override + public int getTaskIndex() { + return taskIndex; + } + + @Override + public ShuffleConfig getConfig() { + return config; + } + + @Override + public int getTargetChannelNum() { + return targetChannels; + } + + @Override + public IEncoder getEncoder() { + return this.encoder; + } + + @Override + public DataExchangeMode getDataExchangeMode() { + return this.dataExchangeMode; + } + + public static WriterContextBuilder newBuilder() { + return new WriterContextBuilder(); + } + + public static class WriterContextBuilder { + + private long pipelineId; + + public WriterContextBuilder setPipelineId(long pipelineId) { + this.pipelineId = pipelineId; + return this; + } + + public WriterContext setPipelineName(String pipelineName) { + return new WriterContext(pipelineId, pipelineName); + } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/config/ShuffleConfig.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/config/ShuffleConfig.java index 7e0b6fcf4..c7dd59065 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/config/ShuffleConfig.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/config/ShuffleConfig.java @@ -53,222 +53,215 @@ public class ShuffleConfig { - private static final Logger LOGGER = LoggerFactory.getLogger(ShuffleConfig.class); - - private final Configuration configuration; - - ////////////////////////////// - // Netty - /// /////////////////////////// - - private final String serverAddress; - private final int serverPort; - // Connect timeout in milliseconds. Default 120 secs. - private final int connectTimeoutMs; - // fetch timeout in milliseconds. Default is 600 secs. - private final int serverBacklog; - private final int serverThreads; - private final int clientThreads; - private final int receiveBufferSize; - private final int sendBufferSize; - private final int connectMaxRetryTimes; - private final int connectInitBackoffMs; - private final int connectMaxBackoffMs; - private final boolean threadCacheEnabled; - private final boolean preferDirectBuffer; - private final boolean customFrameDecoderEnable; - private final boolean enableBackpressure; - - ////////////////////////////// - // Read & Write - /// /////////////////////////// - - private final boolean memoryPoolEnable; - private final boolean compressionEnabled; - - ////////////////////////////// - // Read - /// /////////////////////////// - - private final int fetchTimeoutMs; - private final int fetchQueueSize; - private final int channelQueueSize; - - - ////////////////////////////// - // Write - /// /////////////////////////// - - private final int emitQueueSize; - private final int emitBufferSize; - private final int maxBufferSizeBytes; - private final int maxWriteBufferSize; - private final int flushBufferTimeoutMs; - private final StorageLevel storageLevel; - - public ShuffleConfig(Configuration config) { - this.configuration = config; - - // netty - this.serverAddress = config.getString(NETTY_SERVER_HOST); - this.serverPort = config.getInteger(NETTY_SERVER_PORT); - this.connectTimeoutMs = config.getInteger(NETTY_CONNECT_TIMEOUT_MS); - this.serverBacklog = config.getInteger(NETTY_SERVER_BACKLOG); - this.serverThreads = config.getInteger(NETTY_SERVER_THREADS_NUM); - this.clientThreads = config.getInteger(NETTY_CLIENT_THREADS_NUM); - this.receiveBufferSize = config.getInteger(NETTY_RECEIVE_BUFFER_SIZE); - this.sendBufferSize = config.getInteger(NETTY_SEND_BUFFER_SIZE); - this.connectMaxRetryTimes = config.getInteger(NETTY_CONNECT_MAX_RETRY_TIMES); - this.connectInitBackoffMs = config.getInteger(NETTY_CONNECT_INITIAL_BACKOFF_MS); - this.connectMaxBackoffMs = config.getInteger(NETTY_CONNECT_MAX_BACKOFF_MS); - this.threadCacheEnabled = config.getBoolean(NETTY_THREAD_CACHE_ENABLE); - this.preferDirectBuffer = config.getBoolean(NETTY_PREFER_DIRECT_BUFFER); - this.customFrameDecoderEnable = config.getBoolean(NETTY_CUSTOM_FRAME_DECODER_ENABLE); - - // read & write - this.memoryPoolEnable = config.getBoolean(SHUFFLE_MEMORY_POOL_ENABLE); - this.compressionEnabled = config.getBoolean(SHUFFLE_COMPRESSION_ENABLE); - this.enableBackpressure = config.getBoolean(SHUFFLE_BACKPRESSURE_ENABLE); - - // read - this.fetchTimeoutMs = config.getInteger(SHUFFLE_FETCH_TIMEOUT_MS); - this.fetchQueueSize = config.getInteger(SHUFFLE_FETCH_QUEUE_SIZE); - this.channelQueueSize = config.getInteger(SHUFFLE_FETCH_CHANNEL_QUEUE_SIZE); - - // write - this.emitQueueSize = config.getInteger(SHUFFLE_EMIT_QUEUE_SIZE); - this.emitBufferSize = config.getInteger(SHUFFLE_EMIT_BUFFER_SIZE); - this.maxBufferSizeBytes = config.getInteger(SHUFFLE_FLUSH_BUFFER_SIZE_BYTES); - this.maxWriteBufferSize = config.getInteger(SHUFFLE_WRITER_BUFFER_SIZE); - this.flushBufferTimeoutMs = config.getInteger(SHUFFLE_FLUSH_BUFFER_TIMEOUT_MS); - this.storageLevel = StorageLevel.valueOf(config.getString(SHUFFLE_STORAGE_TYPE)); - - LOGGER.info("init shuffle config: {}", config); - } - - public Configuration getConfig() { - return this.configuration; - } - - public String getServerAddress() { - return this.serverAddress; - } - - public int getServerPort() { - return this.serverPort; - } - - public int getConnectTimeoutMs() { - return this.connectTimeoutMs; - } - - /** - * Requested maximum length of the queue of incoming connections. - * - * @return server back log. - */ - public int getServerConnectBacklog() { - return this.serverBacklog; - } - - /** - * Number of threads used in the server thread pool. Default to 4, and 0 means 2x#cores. - */ - public int getServerThreadsNum() { - return this.serverThreads; - } - - /** - * Number of threads used in the client thread pool. Default to 4, and 0 means 2x#cores. - */ - public int getClientNumThreads() { - return this.clientThreads; - } - - /** - * Receive buffer size (SO_RCVBUF). - * Note: the optimal size for receive buffer and send buffer should be - * latency * network_bandwidth. - * Assuming latency = 1ms, network_bandwidth = 10Gbps - * buffer size should be ~ 1.25MB - */ - public int getReceiveBufferSize() { - return this.receiveBufferSize; - } - - public int getSendBufferSize() { - return this.sendBufferSize; - } - - public int getConnectMaxRetries() { - return this.connectMaxRetryTimes; - } - - public int getConnectInitialBackoffMs() { - return this.connectInitBackoffMs; - } - - public int getConnectMaxBackoffMs() { - return this.connectMaxBackoffMs; - } - - public boolean isThreadCacheEnabled() { - return this.threadCacheEnabled; - } - - public boolean preferDirectBuffer() { - return this.preferDirectBuffer; - } - - public boolean enableCustomFrameDecoder() { - return this.customFrameDecoderEnable; - } - - public boolean isMemoryPoolEnable() { - return this.memoryPoolEnable; - } - - public boolean isCompressionEnabled() { - return this.compressionEnabled; - } - - public int getFetchTimeoutMs() { - return this.fetchTimeoutMs; - } - - public int getFetchQueueSize() { - return this.fetchQueueSize; - } - - public int getChannelQueueSize() { - return channelQueueSize; - } - - public int getEmitQueueSize() { - return this.emitQueueSize; - } - - public int getEmitBufferSize() { - return this.emitBufferSize; - } - - public int getMaxBufferSizeBytes() { - return this.maxBufferSizeBytes; - } - - public boolean isBackpressureEnabled() { - return enableBackpressure; - } - - public int getMaxWriteBufferSize() { - return maxWriteBufferSize; - } - - public int getFlushBufferTimeoutMs() { - return this.flushBufferTimeoutMs; - } - - public StorageLevel getStorageLevel() { - return this.storageLevel; - } + private static final Logger LOGGER = LoggerFactory.getLogger(ShuffleConfig.class); + + private final Configuration configuration; + + ////////////////////////////// + // Netty + /// /////////////////////////// + + private final String serverAddress; + private final int serverPort; + // Connect timeout in milliseconds. Default 120 secs. + private final int connectTimeoutMs; + // fetch timeout in milliseconds. Default is 600 secs. + private final int serverBacklog; + private final int serverThreads; + private final int clientThreads; + private final int receiveBufferSize; + private final int sendBufferSize; + private final int connectMaxRetryTimes; + private final int connectInitBackoffMs; + private final int connectMaxBackoffMs; + private final boolean threadCacheEnabled; + private final boolean preferDirectBuffer; + private final boolean customFrameDecoderEnable; + private final boolean enableBackpressure; + + ////////////////////////////// + // Read & Write + /// /////////////////////////// + + private final boolean memoryPoolEnable; + private final boolean compressionEnabled; + + ////////////////////////////// + // Read + /// /////////////////////////// + + private final int fetchTimeoutMs; + private final int fetchQueueSize; + private final int channelQueueSize; + + ////////////////////////////// + // Write + /// /////////////////////////// + + private final int emitQueueSize; + private final int emitBufferSize; + private final int maxBufferSizeBytes; + private final int maxWriteBufferSize; + private final int flushBufferTimeoutMs; + private final StorageLevel storageLevel; + + public ShuffleConfig(Configuration config) { + this.configuration = config; + + // netty + this.serverAddress = config.getString(NETTY_SERVER_HOST); + this.serverPort = config.getInteger(NETTY_SERVER_PORT); + this.connectTimeoutMs = config.getInteger(NETTY_CONNECT_TIMEOUT_MS); + this.serverBacklog = config.getInteger(NETTY_SERVER_BACKLOG); + this.serverThreads = config.getInteger(NETTY_SERVER_THREADS_NUM); + this.clientThreads = config.getInteger(NETTY_CLIENT_THREADS_NUM); + this.receiveBufferSize = config.getInteger(NETTY_RECEIVE_BUFFER_SIZE); + this.sendBufferSize = config.getInteger(NETTY_SEND_BUFFER_SIZE); + this.connectMaxRetryTimes = config.getInteger(NETTY_CONNECT_MAX_RETRY_TIMES); + this.connectInitBackoffMs = config.getInteger(NETTY_CONNECT_INITIAL_BACKOFF_MS); + this.connectMaxBackoffMs = config.getInteger(NETTY_CONNECT_MAX_BACKOFF_MS); + this.threadCacheEnabled = config.getBoolean(NETTY_THREAD_CACHE_ENABLE); + this.preferDirectBuffer = config.getBoolean(NETTY_PREFER_DIRECT_BUFFER); + this.customFrameDecoderEnable = config.getBoolean(NETTY_CUSTOM_FRAME_DECODER_ENABLE); + + // read & write + this.memoryPoolEnable = config.getBoolean(SHUFFLE_MEMORY_POOL_ENABLE); + this.compressionEnabled = config.getBoolean(SHUFFLE_COMPRESSION_ENABLE); + this.enableBackpressure = config.getBoolean(SHUFFLE_BACKPRESSURE_ENABLE); + + // read + this.fetchTimeoutMs = config.getInteger(SHUFFLE_FETCH_TIMEOUT_MS); + this.fetchQueueSize = config.getInteger(SHUFFLE_FETCH_QUEUE_SIZE); + this.channelQueueSize = config.getInteger(SHUFFLE_FETCH_CHANNEL_QUEUE_SIZE); + + // write + this.emitQueueSize = config.getInteger(SHUFFLE_EMIT_QUEUE_SIZE); + this.emitBufferSize = config.getInteger(SHUFFLE_EMIT_BUFFER_SIZE); + this.maxBufferSizeBytes = config.getInteger(SHUFFLE_FLUSH_BUFFER_SIZE_BYTES); + this.maxWriteBufferSize = config.getInteger(SHUFFLE_WRITER_BUFFER_SIZE); + this.flushBufferTimeoutMs = config.getInteger(SHUFFLE_FLUSH_BUFFER_TIMEOUT_MS); + this.storageLevel = StorageLevel.valueOf(config.getString(SHUFFLE_STORAGE_TYPE)); + + LOGGER.info("init shuffle config: {}", config); + } + + public Configuration getConfig() { + return this.configuration; + } + + public String getServerAddress() { + return this.serverAddress; + } + + public int getServerPort() { + return this.serverPort; + } + + public int getConnectTimeoutMs() { + return this.connectTimeoutMs; + } + + /** + * Requested maximum length of the queue of incoming connections. + * + * @return server back log. + */ + public int getServerConnectBacklog() { + return this.serverBacklog; + } + + /** Number of threads used in the server thread pool. Default to 4, and 0 means 2x#cores. */ + public int getServerThreadsNum() { + return this.serverThreads; + } + + /** Number of threads used in the client thread pool. Default to 4, and 0 means 2x#cores. */ + public int getClientNumThreads() { + return this.clientThreads; + } + + /** + * Receive buffer size (SO_RCVBUF). Note: the optimal size for receive buffer and send buffer + * should be latency * network_bandwidth. Assuming latency = 1ms, network_bandwidth = 10Gbps + * buffer size should be ~ 1.25MB + */ + public int getReceiveBufferSize() { + return this.receiveBufferSize; + } + + public int getSendBufferSize() { + return this.sendBufferSize; + } + + public int getConnectMaxRetries() { + return this.connectMaxRetryTimes; + } + + public int getConnectInitialBackoffMs() { + return this.connectInitBackoffMs; + } + + public int getConnectMaxBackoffMs() { + return this.connectMaxBackoffMs; + } + + public boolean isThreadCacheEnabled() { + return this.threadCacheEnabled; + } + + public boolean preferDirectBuffer() { + return this.preferDirectBuffer; + } + + public boolean enableCustomFrameDecoder() { + return this.customFrameDecoderEnable; + } + + public boolean isMemoryPoolEnable() { + return this.memoryPoolEnable; + } + + public boolean isCompressionEnabled() { + return this.compressionEnabled; + } + + public int getFetchTimeoutMs() { + return this.fetchTimeoutMs; + } + + public int getFetchQueueSize() { + return this.fetchQueueSize; + } + + public int getChannelQueueSize() { + return channelQueueSize; + } + + public int getEmitQueueSize() { + return this.emitQueueSize; + } + + public int getEmitBufferSize() { + return this.emitBufferSize; + } + + public int getMaxBufferSizeBytes() { + return this.maxBufferSizeBytes; + } + + public boolean isBackpressureEnabled() { + return enableBackpressure; + } + + public int getMaxWriteBufferSize() { + return maxWriteBufferSize; + } + + public int getFlushBufferTimeoutMs() { + return this.flushBufferTimeoutMs; + } + + public StorageLevel getStorageLevel() { + return this.storageLevel; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/desc/IInputDesc.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/desc/IInputDesc.java index 86217d590..ef6f01959 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/desc/IInputDesc.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/desc/IInputDesc.java @@ -23,24 +23,15 @@ public interface IInputDesc { - /** - * Return he edge id of correlated input execution edge. - */ - int getEdgeId(); + /** Return he edge id of correlated input execution edge. */ + int getEdgeId(); - /** - * Return the edge name of correlated input execution edge. - */ - String getName(); + /** Return the edge name of correlated input execution edge. */ + String getName(); - /** - * Return data descriptors of current input. - */ - List getInput(); - - /** - * Return input data type, including shuffle shard meta and raw data. - */ - InputType getInputType(); + /** Return data descriptors of current input. */ + List getInput(); + /** Return input data type, including shuffle shard meta and raw data. */ + InputType getInputType(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/desc/IOutputDesc.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/desc/IOutputDesc.java index ef6032e82..761663b17 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/desc/IOutputDesc.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/desc/IOutputDesc.java @@ -21,19 +21,12 @@ public interface IOutputDesc { - /** - * Return output edge id. - */ - int getEdgeId(); + /** Return output edge id. */ + int getEdgeId(); - /** - * Return output edge name. - */ - String getEdgeName(); - - /** - * Return the type of data transfer on the edge. - */ - OutputType getType(); + /** Return output edge name. */ + String getEdgeName(); + /** Return the type of data transfer on the edge. */ + OutputType getType(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/desc/InputType.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/desc/InputType.java index 214884cdb..9576a4e96 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/desc/InputType.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/desc/InputType.java @@ -20,12 +20,8 @@ package org.apache.geaflow.shuffle.desc; public enum InputType { - /** - * Input data is meta. - */ - META, - /** - * Input data is real data. - */ - DATA + /** Input data is meta. */ + META, + /** Input data is real data. */ + DATA } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/desc/OutputType.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/desc/OutputType.java index e24d74b64..680904e87 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/desc/OutputType.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/desc/OutputType.java @@ -21,18 +21,12 @@ public enum OutputType { - /** - * Shuffle data forward from upstream to downstream. - */ - FORWARD, + /** Shuffle data forward from upstream to downstream. */ + FORWARD, - /** - * Collect output response to scheduler. - */ - RESPONSE, + /** Collect output response to scheduler. */ + RESPONSE, - /** - * Collect and forward data to cycle itself, that data loop. - */ - LOOP + /** Collect and forward data to cycle itself, that data loop. */ + LOOP } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/desc/ShardInputDesc.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/desc/ShardInputDesc.java index 0cd60e313..d422eb9c0 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/desc/ShardInputDesc.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/desc/ShardInputDesc.java @@ -20,6 +20,7 @@ package org.apache.geaflow.shuffle.desc; import java.util.List; + import org.apache.geaflow.common.encoder.IEncoder; import org.apache.geaflow.common.shuffle.BatchPhase; import org.apache.geaflow.common.shuffle.DataExchangeMode; @@ -27,70 +28,74 @@ public class ShardInputDesc implements IInputDesc { - private final int edgeId; - private final String edgeName; - private final List shards; - - private final IEncoder encoder; - private final DataExchangeMode dataExchangeMode; - private final BatchPhase batchPhase; - - public ShardInputDesc(int edgeId, - String edgeName, - List shards, - IEncoder encoder, - DataExchangeMode dataExchangeMode, - BatchPhase batchPhase) { - this.edgeId = edgeId; - this.edgeName = edgeName; - this.shards = shards; - this.encoder = encoder; - this.dataExchangeMode = dataExchangeMode; - this.batchPhase = batchPhase; - } - - @Override - public int getEdgeId() { - return edgeId; - } - - @Override - public String getName() { - return edgeName; - } - - @Override - public List getInput() { - return shards; - } - - @Override - public InputType getInputType() { - return InputType.META; - } - - public IEncoder getEncoder() { - return encoder; - } - - public DataExchangeMode getDataExchangeMode() { - return this.dataExchangeMode; - } - - public BatchPhase getBatchPhase() { - return this.batchPhase; - } - - public int getSliceNum() { - return this.isPrefetchRead() ? 1 : this.shards.stream().mapToInt(s -> s.getSlices().size()).sum(); - } - - public boolean isPrefetchWrite() { - return this.dataExchangeMode == DataExchangeMode.BATCH && this.batchPhase == BatchPhase.PREFETCH_WRITE; - } - - public boolean isPrefetchRead() { - return this.dataExchangeMode == DataExchangeMode.BATCH && this.batchPhase == BatchPhase.PREFETCH_READ; - } - + private final int edgeId; + private final String edgeName; + private final List shards; + + private final IEncoder encoder; + private final DataExchangeMode dataExchangeMode; + private final BatchPhase batchPhase; + + public ShardInputDesc( + int edgeId, + String edgeName, + List shards, + IEncoder encoder, + DataExchangeMode dataExchangeMode, + BatchPhase batchPhase) { + this.edgeId = edgeId; + this.edgeName = edgeName; + this.shards = shards; + this.encoder = encoder; + this.dataExchangeMode = dataExchangeMode; + this.batchPhase = batchPhase; + } + + @Override + public int getEdgeId() { + return edgeId; + } + + @Override + public String getName() { + return edgeName; + } + + @Override + public List getInput() { + return shards; + } + + @Override + public InputType getInputType() { + return InputType.META; + } + + public IEncoder getEncoder() { + return encoder; + } + + public DataExchangeMode getDataExchangeMode() { + return this.dataExchangeMode; + } + + public BatchPhase getBatchPhase() { + return this.batchPhase; + } + + public int getSliceNum() { + return this.isPrefetchRead() + ? 1 + : this.shards.stream().mapToInt(s -> s.getSlices().size()).sum(); + } + + public boolean isPrefetchWrite() { + return this.dataExchangeMode == DataExchangeMode.BATCH + && this.batchPhase == BatchPhase.PREFETCH_WRITE; + } + + public boolean isPrefetchRead() { + return this.dataExchangeMode == DataExchangeMode.BATCH + && this.batchPhase == BatchPhase.PREFETCH_READ; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/BaseSliceMeta.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/BaseSliceMeta.java index 79f86af44..cd0fed395 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/BaseSliceMeta.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/BaseSliceMeta.java @@ -21,67 +21,65 @@ public class BaseSliceMeta implements ISliceMeta { - protected int sourceIndex; - protected int targetIndex; - protected long recordNum; - protected long encodedSize; - protected int edgeId; - protected long windowId; - - public BaseSliceMeta() { - } - - public BaseSliceMeta(int sourceIndex, int targetIndex) { - this.sourceIndex = sourceIndex; - this.targetIndex = targetIndex; - } - - public int getSourceIndex() { - return sourceIndex; - } - - public void setSourceIndex(int sourceIndex) { - this.sourceIndex = sourceIndex; - } - - public int getTargetIndex() { - return targetIndex; - } - - public void setTargetIndex(int targetIndex) { - this.targetIndex = targetIndex; - } - - public long getRecordNum() { - return recordNum; - } - - public void setRecordNum(long recordNum) { - this.recordNum = recordNum; - } - - public long getEncodedSize() { - return encodedSize; - } - - public void setEncodedSize(long encodedSize) { - this.encodedSize = encodedSize; - } - - public int getEdgeId() { - return edgeId; - } - - public void setEdgeId(int edgeId) { - this.edgeId = edgeId; - } - - public long getWindowId() { - return windowId; - } - - public void setWindowId(long windowId) { - this.windowId = windowId; - } - + protected int sourceIndex; + protected int targetIndex; + protected long recordNum; + protected long encodedSize; + protected int edgeId; + protected long windowId; + + public BaseSliceMeta() {} + + public BaseSliceMeta(int sourceIndex, int targetIndex) { + this.sourceIndex = sourceIndex; + this.targetIndex = targetIndex; + } + + public int getSourceIndex() { + return sourceIndex; + } + + public void setSourceIndex(int sourceIndex) { + this.sourceIndex = sourceIndex; + } + + public int getTargetIndex() { + return targetIndex; + } + + public void setTargetIndex(int targetIndex) { + this.targetIndex = targetIndex; + } + + public long getRecordNum() { + return recordNum; + } + + public void setRecordNum(long recordNum) { + this.recordNum = recordNum; + } + + public long getEncodedSize() { + return encodedSize; + } + + public void setEncodedSize(long encodedSize) { + this.encodedSize = encodedSize; + } + + public int getEdgeId() { + return edgeId; + } + + public void setEdgeId(int edgeId) { + this.edgeId = edgeId; + } + + public long getWindowId() { + return windowId; + } + + public void setWindowId(long windowId) { + this.windowId = windowId; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/ISliceMeta.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/ISliceMeta.java index c5e2b21f9..50bb380cc 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/ISliceMeta.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/ISliceMeta.java @@ -23,32 +23,31 @@ public interface ISliceMeta extends Serializable { - /** - * Get source index of the slice. - * - * @return source index. - */ - int getSourceIndex(); - - /** - * Get target index of the slice. - * - * @return target index. - */ - int getTargetIndex(); - - /** - * Get record number of the slice. - * - * @return record number. - */ - long getRecordNum(); - - /** - * Get encode size (bytes) of the slice. - * - * @return encode size. - */ - long getEncodedSize(); - + /** + * Get source index of the slice. + * + * @return source index. + */ + int getSourceIndex(); + + /** + * Get target index of the slice. + * + * @return target index. + */ + int getTargetIndex(); + + /** + * Get record number of the slice. + * + * @return record number. + */ + long getRecordNum(); + + /** + * Get encode size (bytes) of the slice. + * + * @return encode size. + */ + long getEncodedSize(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/LogicalPipelineSliceMeta.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/LogicalPipelineSliceMeta.java index 662c7bbf2..983516cb0 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/LogicalPipelineSliceMeta.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/LogicalPipelineSliceMeta.java @@ -21,24 +21,23 @@ public class LogicalPipelineSliceMeta extends BaseSliceMeta { - private final SliceId sliceId; - private final String containerId; + private final SliceId sliceId; + private final String containerId; - public LogicalPipelineSliceMeta(int sourceIndex, int targetIndex, long pipelineId, int edgeId, - String containerId) { - super(sourceIndex, targetIndex); - this.windowId = -1; - this.sliceId = new SliceId(pipelineId, edgeId, sourceIndex, targetIndex); - this.containerId = containerId; - setEdgeId(edgeId); - } + public LogicalPipelineSliceMeta( + int sourceIndex, int targetIndex, long pipelineId, int edgeId, String containerId) { + super(sourceIndex, targetIndex); + this.windowId = -1; + this.sliceId = new SliceId(pipelineId, edgeId, sourceIndex, targetIndex); + this.containerId = containerId; + setEdgeId(edgeId); + } - public SliceId getSliceId() { - return sliceId; - } - - public String getContainerId() { - return containerId; - } + public SliceId getSliceId() { + return sliceId; + } + public String getContainerId() { + return containerId; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/PipelineBarrier.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/PipelineBarrier.java index 7c11e1ca8..658ed136f 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/PipelineBarrier.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/PipelineBarrier.java @@ -23,100 +23,114 @@ public class PipelineBarrier implements PipelineEvent { - // Input edge id. - private final int edgeId; - - // Iteration id of cycle. - private final long windowId; - - // If of source task that send the event from. - private int sourceTaskIndex; - - // Id of target task that send the event to. - private int targetTaskIndex; - - // Message count that current event involved. - private long count; - - // Flag that denote source task is finished after the current event. - private boolean finish; - - public PipelineBarrier(long windowId, int edgeId, int sourceTaskIndex) { - this.edgeId = edgeId; - this.windowId = windowId; - this.sourceTaskIndex = sourceTaskIndex; - } - - public PipelineBarrier(long windowId, int edgeId, long count) { - this.edgeId = edgeId; - this.windowId = windowId; - this.count = count; - this.finish = false; - } - - public PipelineBarrier(long windowId, int edgeId, int sourceTaskIndex, int targetTaskId, long count) { - this.edgeId = edgeId; - this.windowId = windowId; - this.sourceTaskIndex = sourceTaskIndex; - this.targetTaskIndex = targetTaskId; - this.count = count; - this.finish = false; - } - - @Override - public int getEdgeId() { - return edgeId; - } - - public void setFinish(boolean finish) { - this.finish = finish; - } - - public boolean isFinish() { - return finish; - } - - @Override - public long getWindowId() { - return windowId; - } - - public int getSourceTaskIndex() { - return sourceTaskIndex; - } - - public int getTargetTaskIndex() { - return targetTaskIndex; + // Input edge id. + private final int edgeId; + + // Iteration id of cycle. + private final long windowId; + + // If of source task that send the event from. + private int sourceTaskIndex; + + // Id of target task that send the event to. + private int targetTaskIndex; + + // Message count that current event involved. + private long count; + + // Flag that denote source task is finished after the current event. + private boolean finish; + + public PipelineBarrier(long windowId, int edgeId, int sourceTaskIndex) { + this.edgeId = edgeId; + this.windowId = windowId; + this.sourceTaskIndex = sourceTaskIndex; + } + + public PipelineBarrier(long windowId, int edgeId, long count) { + this.edgeId = edgeId; + this.windowId = windowId; + this.count = count; + this.finish = false; + } + + public PipelineBarrier( + long windowId, int edgeId, int sourceTaskIndex, int targetTaskId, long count) { + this.edgeId = edgeId; + this.windowId = windowId; + this.sourceTaskIndex = sourceTaskIndex; + this.targetTaskIndex = targetTaskId; + this.count = count; + this.finish = false; + } + + @Override + public int getEdgeId() { + return edgeId; + } + + public void setFinish(boolean finish) { + this.finish = finish; + } + + public boolean isFinish() { + return finish; + } + + @Override + public long getWindowId() { + return windowId; + } + + public int getSourceTaskIndex() { + return sourceTaskIndex; + } + + public int getTargetTaskIndex() { + return targetTaskIndex; + } + + public long getCount() { + return count; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - public long getCount() { - return count; + if (o == null || getClass() != o.getClass()) { + return false; } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - PipelineBarrier that = (PipelineBarrier) o; - return edgeId == that.edgeId && windowId == that.windowId - && sourceTaskIndex == that.sourceTaskIndex && targetTaskIndex == that.targetTaskIndex - && count == that.count && finish == that.finish; - } - - @Override - public int hashCode() { - return Objects.hash(edgeId, windowId, sourceTaskIndex, targetTaskIndex, count, finish); - } - - @Override - public String toString() { - return "PipelineBarrier{" + "edgeId=" + edgeId + ", windowId=" + windowId - + ", sourceTaskIndex=" + sourceTaskIndex + ", targetTaskIndex=" + targetTaskIndex - + ", count=" + count + ", finish=" + finish + '}'; - } - + PipelineBarrier that = (PipelineBarrier) o; + return edgeId == that.edgeId + && windowId == that.windowId + && sourceTaskIndex == that.sourceTaskIndex + && targetTaskIndex == that.targetTaskIndex + && count == that.count + && finish == that.finish; + } + + @Override + public int hashCode() { + return Objects.hash(edgeId, windowId, sourceTaskIndex, targetTaskIndex, count, finish); + } + + @Override + public String toString() { + return "PipelineBarrier{" + + "edgeId=" + + edgeId + + ", windowId=" + + windowId + + ", sourceTaskIndex=" + + sourceTaskIndex + + ", targetTaskIndex=" + + targetTaskIndex + + ", count=" + + count + + ", finish=" + + finish + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/PipelineEvent.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/PipelineEvent.java index 9cf738a82..dbc7dcd07 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/PipelineEvent.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/PipelineEvent.java @@ -21,18 +21,17 @@ public interface PipelineEvent { - /** - * Get the edge id of this pipeline event. - * - * @return edge id - */ - int getEdgeId(); - - /** - * Get the window id of this pipeline event. - * - * @return window id. - */ - long getWindowId(); + /** + * Get the edge id of this pipeline event. + * + * @return edge id + */ + int getEdgeId(); + /** + * Get the window id of this pipeline event. + * + * @return window id. + */ + long getWindowId(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/PipelineInfo.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/PipelineInfo.java index 086e2c410..a1867eb7c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/PipelineInfo.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/PipelineInfo.java @@ -23,24 +23,30 @@ public class PipelineInfo implements Serializable { - private final long pipelineId; - private final String pipelineName; - - public PipelineInfo(long pipelineId, String pipelineName) { - this.pipelineId = pipelineId; - this.pipelineName = pipelineName; - } - - public long getPipelineId() { - return pipelineId; - } - - public String getPipelineName() { - return pipelineName; - } - - @Override - public String toString() { - return "PipelineInfo{" + "pipelineId=" + pipelineId + ", pipelineName='" + pipelineName + '\'' + '}'; - } + private final long pipelineId; + private final String pipelineName; + + public PipelineInfo(long pipelineId, String pipelineName) { + this.pipelineId = pipelineId; + this.pipelineName = pipelineName; + } + + public long getPipelineId() { + return pipelineId; + } + + public String getPipelineName() { + return pipelineName; + } + + @Override + public String toString() { + return "PipelineInfo{" + + "pipelineId=" + + pipelineId + + ", pipelineName='" + + pipelineName + + '\'' + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/PipelineMessage.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/PipelineMessage.java index 5fbcc6c6f..00dcfa47d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/PipelineMessage.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/PipelineMessage.java @@ -24,32 +24,32 @@ public class PipelineMessage implements PipelineEvent { - private final int edgeId; - private final RecordArgs recordArgs; - private final IMessageIterator messageIterator; - - public PipelineMessage(int edgeId, long batchId, String streamName, IMessageIterator messageIterator) { - this.edgeId = edgeId; - this.recordArgs = new RecordArgs(batchId, streamName); - this.messageIterator = messageIterator; - } - - @Override - public int getEdgeId() { - return this.edgeId; - } - - @Override - public long getWindowId() { - return recordArgs.getWindowId(); - } - - public IMessageIterator getMessageIterator() { - return messageIterator; - } - - public RecordArgs getRecordArgs() { - return recordArgs; - } - + private final int edgeId; + private final RecordArgs recordArgs; + private final IMessageIterator messageIterator; + + public PipelineMessage( + int edgeId, long batchId, String streamName, IMessageIterator messageIterator) { + this.edgeId = edgeId; + this.recordArgs = new RecordArgs(batchId, streamName); + this.messageIterator = messageIterator; + } + + @Override + public int getEdgeId() { + return this.edgeId; + } + + @Override + public long getWindowId() { + return recordArgs.getWindowId(); + } + + public IMessageIterator getMessageIterator() { + return messageIterator; + } + + public RecordArgs getRecordArgs() { + return recordArgs; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/PipelineSliceMeta.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/PipelineSliceMeta.java index 420a706a4..300888740 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/PipelineSliceMeta.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/PipelineSliceMeta.java @@ -23,31 +23,31 @@ public class PipelineSliceMeta extends BaseSliceMeta { - private final SliceId sliceId; - private final ShuffleAddress shuffleAddress; - - public PipelineSliceMeta(int sourceIndex, int targetIndex, long pipelineId, int edgeId, - ShuffleAddress address) { - super(sourceIndex, targetIndex); - this.windowId = -1; - this.sliceId = new SliceId(pipelineId, edgeId, sourceIndex, targetIndex); - this.shuffleAddress = address; - setEdgeId(edgeId); - } - - public PipelineSliceMeta(SliceId sliceId, long windowId, ShuffleAddress address) { - super(sliceId.getShardIndex(), sliceId.getSliceIndex()); - this.windowId = windowId; - this.sliceId = sliceId; - this.shuffleAddress = address; - setEdgeId(sliceId.getEdgeId()); - } - - public SliceId getSliceId() { - return sliceId; - } - - public ShuffleAddress getShuffleAddress() { - return shuffleAddress; - } + private final SliceId sliceId; + private final ShuffleAddress shuffleAddress; + + public PipelineSliceMeta( + int sourceIndex, int targetIndex, long pipelineId, int edgeId, ShuffleAddress address) { + super(sourceIndex, targetIndex); + this.windowId = -1; + this.sliceId = new SliceId(pipelineId, edgeId, sourceIndex, targetIndex); + this.shuffleAddress = address; + setEdgeId(edgeId); + } + + public PipelineSliceMeta(SliceId sliceId, long windowId, ShuffleAddress address) { + super(sliceId.getShardIndex(), sliceId.getSliceIndex()); + this.windowId = windowId; + this.sliceId = sliceId; + this.shuffleAddress = address; + setEdgeId(sliceId.getEdgeId()); + } + + public SliceId getSliceId() { + return sliceId; + } + + public ShuffleAddress getShuffleAddress() { + return shuffleAddress; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/Shard.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/Shard.java index 0b29cd278..610ef2ec1 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/Shard.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/Shard.java @@ -24,22 +24,22 @@ public class Shard implements Serializable { - // An edgeId can identity the relationship between upstream and downstream tasks. - private final int edgeId; + // An edgeId can identity the relationship between upstream and downstream tasks. + private final int edgeId; - // All output slices of the edge. - private final List slices; + // All output slices of the edge. + private final List slices; - public Shard(int edgeId, List slices) { - this.edgeId = edgeId; - this.slices = slices; - } + public Shard(int edgeId, List slices) { + this.edgeId = edgeId; + this.slices = slices; + } - public int getEdgeId() { - return edgeId; - } + public int getEdgeId() { + return edgeId; + } - public List getSlices() { - return slices; - } + public List getSlices() { + return slices; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/ShuffleId.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/ShuffleId.java index 6a5338eb6..475b782f0 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/ShuffleId.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/ShuffleId.java @@ -23,64 +23,62 @@ public class ShuffleId implements Serializable { - private String pipelineName; - private int vertexId; - private int outEdgeId; - private long batchId; - - public ShuffleId() { - } - - public ShuffleId(String pipelineName, int vertexId, int outEdgeId) { - this(pipelineName, vertexId, outEdgeId, 0); - } - - public ShuffleId(String pipelineName, int vertexId, int outEdgeId, long batchId) { - this.pipelineName = pipelineName; - this.vertexId = vertexId; - this.outEdgeId = outEdgeId; - this.batchId = batchId; - } - - public String getPipelineName() { - return pipelineName; - } - - public void setPipelineName(String pipelineName) { - this.pipelineName = pipelineName; - } - - public int getVertexId() { - return vertexId; - } - - public void setVertexId(int vertexId) { - this.vertexId = vertexId; - } - - public int getOutEdgeId() { - return outEdgeId; - } - - public void setOutEdgeId(int outEdgeId) { - this.outEdgeId = outEdgeId; - } - - public long getBatchId() { - return batchId; - } - - public void setBatchId(long batchId) { - this.batchId = batchId; + private String pipelineName; + private int vertexId; + private int outEdgeId; + private long batchId; + + public ShuffleId() {} + + public ShuffleId(String pipelineName, int vertexId, int outEdgeId) { + this(pipelineName, vertexId, outEdgeId, 0); + } + + public ShuffleId(String pipelineName, int vertexId, int outEdgeId, long batchId) { + this.pipelineName = pipelineName; + this.vertexId = vertexId; + this.outEdgeId = outEdgeId; + this.batchId = batchId; + } + + public String getPipelineName() { + return pipelineName; + } + + public void setPipelineName(String pipelineName) { + this.pipelineName = pipelineName; + } + + public int getVertexId() { + return vertexId; + } + + public void setVertexId(int vertexId) { + this.vertexId = vertexId; + } + + public int getOutEdgeId() { + return outEdgeId; + } + + public void setOutEdgeId(int outEdgeId) { + this.outEdgeId = outEdgeId; + } + + public long getBatchId() { + return batchId; + } + + public void setBatchId(long batchId) { + this.batchId = batchId; + } + + @Override + public String toString() { + if (batchId == 0) { + return String.format("%s-%s-%s", pipelineName, vertexId, outEdgeId); + } else { + return String.format("%s-%s-%s-%s", pipelineName, vertexId, outEdgeId, batchId); } - - @Override - public String toString() { - if (batchId == 0) { - return String.format("%s-%s-%s", pipelineName, vertexId, outEdgeId); - } else { - return String.format("%s-%s-%s-%s", pipelineName, vertexId, outEdgeId, batchId); - } - } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/SliceId.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/SliceId.java index 162035682..041dc9c63 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/SliceId.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/SliceId.java @@ -19,83 +19,92 @@ package org.apache.geaflow.shuffle.message; -import io.netty.buffer.ByteBuf; import java.io.Serializable; import java.util.Objects; -public class SliceId implements Serializable { - private static final long serialVersionUID = 1L; - public static final int SLICE_ID_BYTES = 20; - - private final WriterId writerId; - private final int sliceIndex; - - public SliceId(long pipelineId, int edgeId, int shardIndex, int sliceIndex) { - this.writerId = new WriterId(pipelineId, edgeId, shardIndex); - this.sliceIndex = sliceIndex; - } - - public SliceId(WriterId writerId, int sliceIndex) { - this.writerId = writerId; - this.sliceIndex = sliceIndex; - } - - public long getPipelineId() { - return writerId.getPipelineId(); - } - - public int getEdgeId() { - return writerId.getEdgeId(); - } - - public int getShardIndex() { - return writerId.getShardIndex(); - } - - public int getSliceIndex() { - return sliceIndex; - } - - public WriterId getWriterId() { - return writerId; - } - - public void writeTo(ByteBuf buf) { - buf.writeLong(writerId.getPipelineId()); - buf.writeInt(writerId.getEdgeId()); - buf.writeInt(writerId.getShardIndex()); - buf.writeInt(sliceIndex); - } - - public static SliceId readFrom(ByteBuf buf) { - long execId = buf.readLong(); - int edgeId = buf.readInt(); - int shardIndex = buf.readInt(); - int sliceIndex = buf.readInt(); - return new SliceId(execId, edgeId, shardIndex, sliceIndex); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - SliceId sliceId = (SliceId) o; - return sliceIndex == sliceId.sliceIndex && Objects.equals(writerId, sliceId.writerId); - } +import io.netty.buffer.ByteBuf; - @Override - public int hashCode() { - return Objects.hash(writerId, sliceIndex); +public class SliceId implements Serializable { + private static final long serialVersionUID = 1L; + public static final int SLICE_ID_BYTES = 20; + + private final WriterId writerId; + private final int sliceIndex; + + public SliceId(long pipelineId, int edgeId, int shardIndex, int sliceIndex) { + this.writerId = new WriterId(pipelineId, edgeId, shardIndex); + this.sliceIndex = sliceIndex; + } + + public SliceId(WriterId writerId, int sliceIndex) { + this.writerId = writerId; + this.sliceIndex = sliceIndex; + } + + public long getPipelineId() { + return writerId.getPipelineId(); + } + + public int getEdgeId() { + return writerId.getEdgeId(); + } + + public int getShardIndex() { + return writerId.getShardIndex(); + } + + public int getSliceIndex() { + return sliceIndex; + } + + public WriterId getWriterId() { + return writerId; + } + + public void writeTo(ByteBuf buf) { + buf.writeLong(writerId.getPipelineId()); + buf.writeInt(writerId.getEdgeId()); + buf.writeInt(writerId.getShardIndex()); + buf.writeInt(sliceIndex); + } + + public static SliceId readFrom(ByteBuf buf) { + long execId = buf.readLong(); + int edgeId = buf.readInt(); + int shardIndex = buf.readInt(); + int sliceIndex = buf.readInt(); + return new SliceId(execId, edgeId, shardIndex, sliceIndex); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public String toString() { - return "SliceId{" + "pipelineId=" + writerId.getPipelineId() + ", edgeId=" + writerId - .getEdgeId() + ", " + "shardIndex=" + writerId.getShardIndex() + ", sliceIndex=" - + sliceIndex + '}'; + if (o == null || getClass() != o.getClass()) { + return false; } + SliceId sliceId = (SliceId) o; + return sliceIndex == sliceId.sliceIndex && Objects.equals(writerId, sliceId.writerId); + } + + @Override + public int hashCode() { + return Objects.hash(writerId, sliceIndex); + } + + @Override + public String toString() { + return "SliceId{" + + "pipelineId=" + + writerId.getPipelineId() + + ", edgeId=" + + writerId.getEdgeId() + + ", " + + "shardIndex=" + + writerId.getShardIndex() + + ", sliceIndex=" + + sliceIndex + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/WriterId.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/WriterId.java index cb85c6add..3bb723fda 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/WriterId.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/message/WriterId.java @@ -24,50 +24,56 @@ public class WriterId implements Serializable { - private final long pipelineId; - private final int edgeId; - private final int shardIndex; + private final long pipelineId; + private final int edgeId; + private final int shardIndex; - public WriterId(long pipelineId, int edgeId, int shardIndex) { - this.pipelineId = pipelineId; - this.edgeId = edgeId; - this.shardIndex = shardIndex; - } + public WriterId(long pipelineId, int edgeId, int shardIndex) { + this.pipelineId = pipelineId; + this.edgeId = edgeId; + this.shardIndex = shardIndex; + } - public long getPipelineId() { - return pipelineId; - } + public long getPipelineId() { + return pipelineId; + } - public int getEdgeId() { - return edgeId; - } + public int getEdgeId() { + return edgeId; + } - public int getShardIndex() { - return shardIndex; - } + public int getShardIndex() { + return shardIndex; + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - WriterId writerID = (WriterId) o; - return pipelineId == writerID.pipelineId && edgeId == writerID.edgeId - && shardIndex == writerID.shardIndex; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(pipelineId, edgeId, shardIndex); + if (o == null || getClass() != o.getClass()) { + return false; } + WriterId writerID = (WriterId) o; + return pipelineId == writerID.pipelineId + && edgeId == writerID.edgeId + && shardIndex == writerID.shardIndex; + } - @Override - public String toString() { - return "WriterId{" + "pipelineId=" + pipelineId + ", edgeId=" + edgeId + ", shardIndex=" - + shardIndex + '}'; - } + @Override + public int hashCode() { + return Objects.hash(pipelineId, edgeId, shardIndex); + } + @Override + public String toString() { + return "WriterId{" + + "pipelineId=" + + pipelineId + + ", edgeId=" + + edgeId + + ", shardIndex=" + + shardIndex + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/ConnectionId.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/ConnectionId.java index b06bc198f..3de5dd2aa 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/ConnectionId.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/ConnectionId.java @@ -21,46 +21,46 @@ import java.io.Serializable; import java.net.InetSocketAddress; + import org.apache.geaflow.common.shuffle.ShuffleAddress; public class ConnectionId implements Serializable { - private static final long serialVersionUID = -8068626194818666857L; + private static final long serialVersionUID = -8068626194818666857L; - private final InetSocketAddress address; - private final int connectionIndex; + private final InetSocketAddress address; + private final int connectionIndex; - public ConnectionId(ShuffleAddress shuffleAddress, int connectionIndex) { - this.address = new InetSocketAddress(shuffleAddress.host(), shuffleAddress.port()); - this.connectionIndex = connectionIndex; - } + public ConnectionId(ShuffleAddress shuffleAddress, int connectionIndex) { + this.address = new InetSocketAddress(shuffleAddress.host(), shuffleAddress.port()); + this.connectionIndex = connectionIndex; + } - public InetSocketAddress getAddress() { - return address; - } + public InetSocketAddress getAddress() { + return address; + } - public int getConnectionIndex() { - return connectionIndex; - } + public int getConnectionIndex() { + return connectionIndex; + } - @Override - public int hashCode() { - return address.hashCode() + (31 * connectionIndex); - } - - @Override - public boolean equals(Object other) { - if (other.getClass() != ConnectionId.class) { - return false; - } + @Override + public int hashCode() { + return address.hashCode() + (31 * connectionIndex); + } - final ConnectionId id = (ConnectionId) other; - return id.getAddress().equals(address) && id.getConnectionIndex() == connectionIndex; + @Override + public boolean equals(Object other) { + if (other.getClass() != ConnectionId.class) { + return false; } - @Override - public String toString() { - return address + " [" + connectionIndex + "]"; - } + final ConnectionId id = (ConnectionId) other; + return id.getAddress().equals(address) && id.getConnectionIndex() == connectionIndex; + } + @Override + public String toString() { + return address + " [" + connectionIndex + "]"; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/IConnectionManager.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/IConnectionManager.java index 0f96717ff..7b78acd70 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/IConnectionManager.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/IConnectionManager.java @@ -19,54 +19,55 @@ package org.apache.geaflow.shuffle.network; -import io.netty.buffer.PooledByteBufAllocator; import java.io.IOException; import java.util.concurrent.ExecutorService; + import org.apache.geaflow.common.shuffle.ShuffleAddress; import org.apache.geaflow.shuffle.config.ShuffleConfig; -public interface IConnectionManager { +import io.netty.buffer.PooledByteBufAllocator; - /** - * Get the shuffle config. - * - * @return shuffle config. - */ - ShuffleConfig getShuffleConfig(); +public interface IConnectionManager { - /** - * Client buffer. - * - * @return buffer. - */ - PooledByteBufAllocator getClientBufAllocator(); + /** + * Get the shuffle config. + * + * @return shuffle config. + */ + ShuffleConfig getShuffleConfig(); - /** - * Server buffer. - * - * @return buffer. - */ - PooledByteBufAllocator getServerBufAllocator(); + /** + * Client buffer. + * + * @return buffer. + */ + PooledByteBufAllocator getClientBufAllocator(); - /** - * Get the shuffle address. - * - * @return shuffle address. - */ - ShuffleAddress getShuffleAddress(); + /** + * Server buffer. + * + * @return buffer. + */ + PooledByteBufAllocator getServerBufAllocator(); - /** - * Close connection manager. - * - * @throws IOException io exception. - */ - void close() throws IOException; + /** + * Get the shuffle address. + * + * @return shuffle address. + */ + ShuffleAddress getShuffleAddress(); - /** - * Get the thread pool to async callback action. - * - * @return thread pool. - */ - ExecutorService getExecutor(); + /** + * Close connection manager. + * + * @throws IOException io exception. + */ + void close() throws IOException; + /** + * Get the thread pool to async callback action. + * + * @return thread pool. + */ + ExecutorService getExecutor(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/ITransportContext.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/ITransportContext.java index b5b991bdf..eace964e8 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/ITransportContext.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/ITransportContext.java @@ -24,20 +24,19 @@ public interface ITransportContext { - /** - * Create server channel handlers. - * - * @param channel netty channel. - * @return handlers. - */ - ChannelHandler[] createServerChannelHandler(Channel channel); - - /** - * Create client channel handlers. - * - * @param channel netty channel. - * @return handlers. - */ - ChannelHandler[] createClientChannelHandlers(Channel channel); + /** + * Create server channel handlers. + * + * @param channel netty channel. + * @return handlers. + */ + ChannelHandler[] createServerChannelHandler(Channel channel); + /** + * Create client channel handlers. + * + * @param channel netty channel. + * @return handlers. + */ + ChannelHandler[] createClientChannelHandlers(Channel channel); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/NettyUtils.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/NettyUtils.java index 21248dd19..0a27d9a7b 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/NettyUtils.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/NettyUtils.java @@ -19,71 +19,65 @@ package org.apache.geaflow.shuffle.network; -import io.netty.buffer.PooledByteBufAllocator; -import io.netty.channel.Channel; -import io.netty.util.internal.PlatformDependent; import java.io.Closeable; import java.io.IOException; import java.util.concurrent.ThreadFactory; + import org.apache.geaflow.common.utils.ThreadUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.util.internal.PlatformDependent; + public class NettyUtils { - private static final Logger LOGGER = LoggerFactory.getLogger(NettyUtils.class); + private static final Logger LOGGER = LoggerFactory.getLogger(NettyUtils.class); - public static ThreadFactory getNamedThreadFactory(String name) { - return ThreadUtil.namedThreadFactory(true, name); - } + public static ThreadFactory getNamedThreadFactory(String name) { + return ThreadUtil.namedThreadFactory(true, name); + } - /** - * Returns the remote address on the channel or "<unknown remote>" if none exists. - */ - public static String getRemoteAddress(Channel channel) { - if (channel != null && channel.remoteAddress() != null) { - return channel.remoteAddress().toString(); - } - return ""; + /** Returns the remote address on the channel or "<unknown remote>" if none exists. */ + public static String getRemoteAddress(Channel channel) { + if (channel != null && channel.remoteAddress() != null) { + return channel.remoteAddress().toString(); } + return ""; + } - /** - * Create a pooled ByteBuf allocator. - */ - public static PooledByteBufAllocator createPooledByteBufAllocator( - boolean allowDirectBufs, - boolean allowCache, - int numCores) { - if (numCores == 0) { - numCores = Runtime.getRuntime().availableProcessors(); - } - boolean preferDirect = allowDirectBufs && PlatformDependent.directBufferPreferred(); - LOGGER.info("create a PooledByteBufAllocator: preferDirect={}, allowCache={}", - preferDirect, allowCache); - return new PooledByteBufAllocator( - preferDirect, - Math.min(PooledByteBufAllocator.defaultNumHeapArena(), preferDirect ? 0 : numCores), - Math.min(PooledByteBufAllocator.defaultNumDirectArena(), preferDirect ? numCores : 0), - PooledByteBufAllocator.defaultPageSize(), - PooledByteBufAllocator.defaultMaxOrder(), - allowCache ? PooledByteBufAllocator.defaultTinyCacheSize() : 0, - allowCache ? PooledByteBufAllocator.defaultSmallCacheSize() : 0, - allowCache ? PooledByteBufAllocator.defaultNormalCacheSize() : 0, - allowCache ? PooledByteBufAllocator.defaultUseCacheForAllThreads() : false - ); + /** Create a pooled ByteBuf allocator. */ + public static PooledByteBufAllocator createPooledByteBufAllocator( + boolean allowDirectBufs, boolean allowCache, int numCores) { + if (numCores == 0) { + numCores = Runtime.getRuntime().availableProcessors(); } + boolean preferDirect = allowDirectBufs && PlatformDependent.directBufferPreferred(); + LOGGER.info( + "create a PooledByteBufAllocator: preferDirect={}, allowCache={}", + preferDirect, + allowCache); + return new PooledByteBufAllocator( + preferDirect, + Math.min(PooledByteBufAllocator.defaultNumHeapArena(), preferDirect ? 0 : numCores), + Math.min(PooledByteBufAllocator.defaultNumDirectArena(), preferDirect ? numCores : 0), + PooledByteBufAllocator.defaultPageSize(), + PooledByteBufAllocator.defaultMaxOrder(), + allowCache ? PooledByteBufAllocator.defaultTinyCacheSize() : 0, + allowCache ? PooledByteBufAllocator.defaultSmallCacheSize() : 0, + allowCache ? PooledByteBufAllocator.defaultNormalCacheSize() : 0, + allowCache ? PooledByteBufAllocator.defaultUseCacheForAllThreads() : false); + } - /** - * Closes the given object, ignoring IOExceptions. - */ - public static void closeQuietly(Closeable closeable) { - try { - if (closeable != null) { - closeable.close(); - } - } catch (IOException e) { - LOGGER.error("IOException should not have been thrown.", e); - } + /** Closes the given object, ignoring IOExceptions. */ + public static void closeQuietly(Closeable closeable) { + try { + if (closeable != null) { + closeable.close(); + } + } catch (IOException e) { + LOGGER.error("IOException should not have been thrown.", e); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/ConnectionManager.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/ConnectionManager.java index e62e8fb9e..862f482d4 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/ConnectionManager.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/ConnectionManager.java @@ -19,13 +19,13 @@ package org.apache.geaflow.shuffle.network.netty; -import io.netty.buffer.PooledByteBufAllocator; import java.io.IOException; import java.net.InetSocketAddress; import java.util.concurrent.ExecutorService; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.shuffle.ShuffleAddress; import org.apache.geaflow.common.utils.ThreadUtil; import org.apache.geaflow.shuffle.config.ShuffleConfig; @@ -35,78 +35,85 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class ConnectionManager implements IConnectionManager { - - private static final Logger LOGGER = LoggerFactory.getLogger(ConnectionManager.class); - - private final ShuffleConfig nettyConfig; - private final ShuffleAddress shuffleAddress; - private final SliceRequestClientFactory clientFactory; - private final ExecutorService executor; - - private NettyServer server; - private NettyClient client; - - public ConnectionManager(ShuffleConfig config) { - ITransportContext context = new NettyContext(config); - this.client = new NettyClient(config, context); - this.server = new NettyServer(config, context); - InetSocketAddress address = server.start(); - this.shuffleAddress = new ShuffleAddress(address.getAddress().getHostAddress(), - address.getPort()); - this.clientFactory = new SliceRequestClientFactory(config, client); - this.nettyConfig = config; - this.executor = new ThreadPoolExecutor(1, 1, 0L, TimeUnit.MILLISECONDS, - new LinkedBlockingQueue(), ThreadUtil.namedThreadFactory(true, "connect")); - } - - public ShuffleAddress getShuffleAddress() { - return shuffleAddress; - } - - public ShuffleConfig getShuffleConfig() { - return nettyConfig; - } - - public PooledByteBufAllocator getServerBufAllocator() { - return server.getPooledAllocator(); - } - - public PooledByteBufAllocator getClientBufAllocator() { - return client.getAllocator(); - } - - public NettyClient getClient() { - return client; - } +import io.netty.buffer.PooledByteBufAllocator; - public SliceRequestClient createSliceRequestClient(ConnectionId connectionId) - throws IOException, InterruptedException { - return clientFactory.createSliceRequestClient(connectionId); - } +public class ConnectionManager implements IConnectionManager { - public void closeOpenChannelConnections(ConnectionId connectionId) { - clientFactory.closeOpenChannelConnections(connectionId); + private static final Logger LOGGER = LoggerFactory.getLogger(ConnectionManager.class); + + private final ShuffleConfig nettyConfig; + private final ShuffleAddress shuffleAddress; + private final SliceRequestClientFactory clientFactory; + private final ExecutorService executor; + + private NettyServer server; + private NettyClient client; + + public ConnectionManager(ShuffleConfig config) { + ITransportContext context = new NettyContext(config); + this.client = new NettyClient(config, context); + this.server = new NettyServer(config, context); + InetSocketAddress address = server.start(); + this.shuffleAddress = + new ShuffleAddress(address.getAddress().getHostAddress(), address.getPort()); + this.clientFactory = new SliceRequestClientFactory(config, client); + this.nettyConfig = config; + this.executor = + new ThreadPoolExecutor( + 1, + 1, + 0L, + TimeUnit.MILLISECONDS, + new LinkedBlockingQueue(), + ThreadUtil.namedThreadFactory(true, "connect")); + } + + public ShuffleAddress getShuffleAddress() { + return shuffleAddress; + } + + public ShuffleConfig getShuffleConfig() { + return nettyConfig; + } + + public PooledByteBufAllocator getServerBufAllocator() { + return server.getPooledAllocator(); + } + + public PooledByteBufAllocator getClientBufAllocator() { + return client.getAllocator(); + } + + public NettyClient getClient() { + return client; + } + + public SliceRequestClient createSliceRequestClient(ConnectionId connectionId) + throws IOException, InterruptedException { + return clientFactory.createSliceRequestClient(connectionId); + } + + public void closeOpenChannelConnections(ConnectionId connectionId) { + clientFactory.closeOpenChannelConnections(connectionId); + } + + public void close() throws IOException { + LOGGER.info("closing connection manager"); + if (server != null) { + server.close(); + server = null; } - - public void close() throws IOException { - LOGGER.info("closing connection manager"); - if (server != null) { - server.close(); - server = null; - } - if (client != null) { - client.shutdown(); - client = null; - } - if (executor != null) { - executor.shutdown(); - } + if (client != null) { + client.shutdown(); + client = null; } - - @Override - public ExecutorService getExecutor() { - return executor; + if (executor != null) { + executor.shutdown(); } + } + @Override + public ExecutorService getExecutor() { + return executor; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/NettyClient.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/NettyClient.java index c7bd011f7..6f7547f29 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/NettyClient.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/NettyClient.java @@ -21,6 +21,15 @@ import static com.google.common.base.Preconditions.checkState; +import java.net.InetSocketAddress; + +import org.apache.geaflow.common.exception.GeaflowRuntimeException; +import org.apache.geaflow.shuffle.config.ShuffleConfig; +import org.apache.geaflow.shuffle.network.ITransportContext; +import org.apache.geaflow.shuffle.network.NettyUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import io.netty.bootstrap.Bootstrap; import io.netty.buffer.PooledByteBufAllocator; import io.netty.channel.ChannelException; @@ -33,145 +42,144 @@ import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; -import java.net.InetSocketAddress; -import org.apache.geaflow.common.exception.GeaflowRuntimeException; -import org.apache.geaflow.shuffle.config.ShuffleConfig; -import org.apache.geaflow.shuffle.network.ITransportContext; -import org.apache.geaflow.shuffle.network.NettyUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /* This file is based on source code from the Flink Project (http://flink.apache.org/), licensed by the Apache * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. */ -/** - * This class is an adaptation of Flink's org.apache.flink.runtime.io.network.netty.NettyClient. - */ +/** This class is an adaptation of Flink's org.apache.flink.runtime.io.network.netty.NettyClient. */ public class NettyClient { - private static final Logger LOGGER = LoggerFactory.getLogger(NettyClient.class); - private static final String CLIENT_THREAD_GROUP_NAME = "NettyClient"; - - private final ShuffleConfig config; - private final PooledByteBufAllocator allocator; - private Bootstrap bootstrap; - - public NettyClient(ShuffleConfig config, ITransportContext context) { - this.config = config; - this.bootstrap = new Bootstrap(); - - final long start = System.nanoTime(); - - // -------------------------------------------------------------------- - // Transport-specific configuration - // -------------------------------------------------------------------- - if (Epoll.isAvailable()) { - initEpollBootstrap(); - LOGGER.info("Transport type 'auto': using EPOLL."); - } else { - initNioBootstrap(); - LOGGER.info("Transport type 'auto': using NIO."); - } - - // -------------------------------------------------------------------- - // Configuration - // -------------------------------------------------------------------- - - bootstrap.option(ChannelOption.TCP_NODELAY, true); - bootstrap.option(ChannelOption.SO_KEEPALIVE, true); - - // Timeout for new connections. - bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, config.getConnectTimeoutMs()); - - // Pooled allocator for Netty's ByteBuf instances. - allocator = NettyUtils.createPooledByteBufAllocator(config.preferDirectBuffer(), - config.isThreadCacheEnabled(), config.getClientNumThreads()); - bootstrap.option(ChannelOption.ALLOCATOR, allocator); - - // Receive and send buffer size. - int sendBufferSize = config.getSendBufferSize(); - if (sendBufferSize > 0) { - bootstrap.option(ChannelOption.SO_SNDBUF, sendBufferSize); - } - int receiveBufferSize = config.getReceiveBufferSize(); - if (receiveBufferSize > 0) { - bootstrap.option(ChannelOption.SO_RCVBUF, receiveBufferSize); - } - - bootstrap.handler(new ChannelInitializer() { - @Override - public void initChannel(SocketChannel channel) throws Exception { - channel.pipeline().addLast(context.createClientChannelHandlers(channel)); - } - }); + private static final Logger LOGGER = LoggerFactory.getLogger(NettyClient.class); + private static final String CLIENT_THREAD_GROUP_NAME = "NettyClient"; + + private final ShuffleConfig config; + private final PooledByteBufAllocator allocator; + private Bootstrap bootstrap; + + public NettyClient(ShuffleConfig config, ITransportContext context) { + this.config = config; + this.bootstrap = new Bootstrap(); + + final long start = System.nanoTime(); - final long duration = (System.nanoTime() - start) / 1_000_000; - LOGGER.info("Successful initialization (took {} ms).", duration); + // -------------------------------------------------------------------- + // Transport-specific configuration + // -------------------------------------------------------------------- + if (Epoll.isAvailable()) { + initEpollBootstrap(); + LOGGER.info("Transport type 'auto': using EPOLL."); + } else { + initNioBootstrap(); + LOGGER.info("Transport type 'auto': using NIO."); } - private void initNioBootstrap() { - // Add the server port number to the name in order to distinguish - // multiple clients running on the same host. + // -------------------------------------------------------------------- + // Configuration + // -------------------------------------------------------------------- - NioEventLoopGroup nioGroup = new NioEventLoopGroup(config.getClientNumThreads(), - NettyUtils.getNamedThreadFactory(CLIENT_THREAD_GROUP_NAME)); - bootstrap.group(nioGroup).channel(NioSocketChannel.class); + bootstrap.option(ChannelOption.TCP_NODELAY, true); + bootstrap.option(ChannelOption.SO_KEEPALIVE, true); + + // Timeout for new connections. + bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, config.getConnectTimeoutMs()); + + // Pooled allocator for Netty's ByteBuf instances. + allocator = + NettyUtils.createPooledByteBufAllocator( + config.preferDirectBuffer(), + config.isThreadCacheEnabled(), + config.getClientNumThreads()); + bootstrap.option(ChannelOption.ALLOCATOR, allocator); + + // Receive and send buffer size. + int sendBufferSize = config.getSendBufferSize(); + if (sendBufferSize > 0) { + bootstrap.option(ChannelOption.SO_SNDBUF, sendBufferSize); } + int receiveBufferSize = config.getReceiveBufferSize(); + if (receiveBufferSize > 0) { + bootstrap.option(ChannelOption.SO_RCVBUF, receiveBufferSize); + } + + bootstrap.handler( + new ChannelInitializer() { + @Override + public void initChannel(SocketChannel channel) throws Exception { + channel.pipeline().addLast(context.createClientChannelHandlers(channel)); + } + }); + + final long duration = (System.nanoTime() - start) / 1_000_000; + LOGGER.info("Successful initialization (took {} ms).", duration); + } - private void initEpollBootstrap() { - // Add the server port number to the name in order to distinguish - // multiple clients running on the same host. + private void initNioBootstrap() { + // Add the server port number to the name in order to distinguish + // multiple clients running on the same host. - EpollEventLoopGroup epollGroup = new EpollEventLoopGroup(config.getClientNumThreads(), + NioEventLoopGroup nioGroup = + new NioEventLoopGroup( + config.getClientNumThreads(), NettyUtils.getNamedThreadFactory(CLIENT_THREAD_GROUP_NAME)); - bootstrap.group(epollGroup).channel(EpollSocketChannel.class); - } + bootstrap.group(nioGroup).channel(NioSocketChannel.class); + } - // ------------------------------------------------------------------------ - // Client connections - // ------------------------------------------------------------------------ - public ChannelFuture connect(final InetSocketAddress serverSocketAddress) { - checkState(null != bootstrap, "Client has not been initialized yet."); - - try { - return bootstrap.connect(serverSocketAddress); - } catch (ChannelException e) { - final String message = "Too many open files"; - if ((e.getCause() instanceof java.net.SocketException && message - .equals(e.getCause().getMessage())) || (e.getCause() instanceof ChannelException - && e.getCause().getCause() instanceof java.net.SocketException && message - .equals(e.getCause().getCause().getMessage()))) { - throw new GeaflowRuntimeException( - "The operating system does not offer enough file handles to open the network " - + "connection. Please increase the number of available file handles.", - e.getCause()); - } else { - throw e; - } - } - } + private void initEpollBootstrap() { + // Add the server port number to the name in order to distinguish + // multiple clients running on the same host. - public ShuffleConfig getConfig() { - return config; + EpollEventLoopGroup epollGroup = + new EpollEventLoopGroup( + config.getClientNumThreads(), + NettyUtils.getNamedThreadFactory(CLIENT_THREAD_GROUP_NAME)); + bootstrap.group(epollGroup).channel(EpollSocketChannel.class); + } + + // ------------------------------------------------------------------------ + // Client connections + // ------------------------------------------------------------------------ + public ChannelFuture connect(final InetSocketAddress serverSocketAddress) { + checkState(null != bootstrap, "Client has not been initialized yet."); + + try { + return bootstrap.connect(serverSocketAddress); + } catch (ChannelException e) { + final String message = "Too many open files"; + if ((e.getCause() instanceof java.net.SocketException + && message.equals(e.getCause().getMessage())) + || (e.getCause() instanceof ChannelException + && e.getCause().getCause() instanceof java.net.SocketException + && message.equals(e.getCause().getCause().getMessage()))) { + throw new GeaflowRuntimeException( + "The operating system does not offer enough file handles to open the network " + + "connection. Please increase the number of available file handles.", + e.getCause()); + } else { + throw e; + } } + } - public PooledByteBufAllocator getAllocator() { - return allocator; - } + public ShuffleConfig getConfig() { + return config; + } - public void shutdown() { - final long start = System.nanoTime(); + public PooledByteBufAllocator getAllocator() { + return allocator; + } - if (bootstrap != null) { - if (bootstrap.group() != null) { - bootstrap.group().shutdownGracefully(); - } - bootstrap = null; - } + public void shutdown() { + final long start = System.nanoTime(); - final long duration = (System.nanoTime() - start) / 1_000_000; - LOGGER.info("Successful shutdown (took {} ms).", duration); + if (bootstrap != null) { + if (bootstrap.group() != null) { + bootstrap.group().shutdownGracefully(); + } + bootstrap = null; } + final long duration = (System.nanoTime() - start) / 1_000_000; + LOGGER.info("Successful shutdown (took {} ms).", duration); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/NettyContext.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/NettyContext.java index 395acf175..55426da50 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/NettyContext.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/NettyContext.java @@ -19,64 +19,70 @@ package org.apache.geaflow.shuffle.network.netty; -import io.netty.channel.Channel; -import io.netty.channel.ChannelHandler; -import io.netty.handler.codec.LengthFieldBasedFrameDecoder; import org.apache.geaflow.shuffle.config.ShuffleConfig; import org.apache.geaflow.shuffle.network.ITransportContext; import org.apache.geaflow.shuffle.network.protocol.NettyMessageDecoder; import org.apache.geaflow.shuffle.network.protocol.NettyMessageEncoder; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; + public class NettyContext implements ITransportContext { - private final NettyMessageEncoder messageEncoder = new NettyMessageEncoder(); - private final NettyMessageDecoder messageDecoder = new NettyMessageDecoder(); - private final ShuffleConfig config; + private final NettyMessageEncoder messageEncoder = new NettyMessageEncoder(); + private final NettyMessageDecoder messageDecoder = new NettyMessageDecoder(); + private final ShuffleConfig config; - public NettyContext(ShuffleConfig config) { - this.config = config; - } + public NettyContext(ShuffleConfig config) { + this.config = config; + } - /** - * Create the frame length decoder. - * +------------------+------------------+--------++----------------+ - * | FRAME LENGTH (4) | MAGIC NUMBER (4) | ID (1) || CUSTOM MESSAGE | - * +------------------+------------------+--------++----------------+ - * - * @return decoder. - */ - private ChannelHandler createFrameLengthDecoder() { - if (config.enableCustomFrameDecoder()) { - return new NettyFrameDecoder(); - } - return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, -4, 4); + /** + * Create the frame length decoder. + * +------------------+------------------+--------++----------------+ | FRAME LENGTH (4) | MAGIC + * NUMBER (4) | ID (1) || CUSTOM MESSAGE | + * +------------------+------------------+--------++----------------+ + * + * @return decoder. + */ + private ChannelHandler createFrameLengthDecoder() { + if (config.enableCustomFrameDecoder()) { + return new NettyFrameDecoder(); } + return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, -4, 4); + } - /** - * Create channel handler on server for batch & pipeline shuffle. - * - * @return handlers. - */ - @Override - public ChannelHandler[] createServerChannelHandler(Channel channel) { - SliceOutputChannelHandler queueOfPartitionQueue = new SliceOutputChannelHandler(); - SliceRequestServerHandler sliceRequestHandler = new SliceRequestServerHandler( - queueOfPartitionQueue); + /** + * Create channel handler on server for batch & pipeline shuffle. + * + * @return handlers. + */ + @Override + public ChannelHandler[] createServerChannelHandler(Channel channel) { + SliceOutputChannelHandler queueOfPartitionQueue = new SliceOutputChannelHandler(); + SliceRequestServerHandler sliceRequestHandler = + new SliceRequestServerHandler(queueOfPartitionQueue); - return new ChannelHandler[]{messageEncoder, createFrameLengthDecoder(), messageDecoder, - sliceRequestHandler, queueOfPartitionQueue}; - } + return new ChannelHandler[] { + messageEncoder, + createFrameLengthDecoder(), + messageDecoder, + sliceRequestHandler, + queueOfPartitionQueue + }; + } - /** - * Create channel handler on client for pipeline shuffle. - * - * @return handlers. - */ - public ChannelHandler[] createClientChannelHandlers(Channel channel) { - SliceRequestClientHandler networkClientHandler = new SliceRequestClientHandler(); - - return new ChannelHandler[]{messageEncoder, createFrameLengthDecoder(), messageDecoder, - networkClientHandler}; - } + /** + * Create channel handler on client for pipeline shuffle. + * + * @return handlers. + */ + public ChannelHandler[] createClientChannelHandlers(Channel channel) { + SliceRequestClientHandler networkClientHandler = new SliceRequestClientHandler(); + return new ChannelHandler[] { + messageEncoder, createFrameLengthDecoder(), messageDecoder, networkClientHandler + }; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/NettyFrameDecoder.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/NettyFrameDecoder.java index 301a1fa1e..30916d5b0 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/NettyFrameDecoder.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/NettyFrameDecoder.java @@ -19,204 +19,201 @@ package org.apache.geaflow.shuffle.network.netty; +import java.util.LinkedList; + import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; + import io.netty.buffer.ByteBuf; import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; -import java.util.LinkedList; /* This file is based on source code from the Spark Project (http://spark.apache.org/), licensed by the Apache * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. */ -/** - * This class is an adaptation of Spark's org.apache.spark.network.util.TransportFrameDecoder. - */ +/** This class is an adaptation of Spark's org.apache.spark.network.util.TransportFrameDecoder. */ public class NettyFrameDecoder extends ChannelInboundHandlerAdapter { - private static final int LENGTH_SIZE = 4; - private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE; - private static final int UNKNOWN_FRAME_SIZE = -1; - private static final long CONSOLIDATE_THRESHOLD = 20 * 1024 * 1024; - - private final LinkedList buffers = new LinkedList<>(); - private final ByteBuf frameLenBuf = Unpooled.buffer(LENGTH_SIZE, LENGTH_SIZE); - private final long consolidateThreshold; - - private CompositeByteBuf frameBuf = null; - private long consolidatedFrameBufSize = 0; - private int consolidatedNumComponents = 0; - - private long totalSize = 0; - private long nextFrameSize = UNKNOWN_FRAME_SIZE; - private int frameRemainingBytes = UNKNOWN_FRAME_SIZE; - - public NettyFrameDecoder() { - this(CONSOLIDATE_THRESHOLD); + private static final int LENGTH_SIZE = 4; + private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE; + private static final int UNKNOWN_FRAME_SIZE = -1; + private static final long CONSOLIDATE_THRESHOLD = 20 * 1024 * 1024; + + private final LinkedList buffers = new LinkedList<>(); + private final ByteBuf frameLenBuf = Unpooled.buffer(LENGTH_SIZE, LENGTH_SIZE); + private final long consolidateThreshold; + + private CompositeByteBuf frameBuf = null; + private long consolidatedFrameBufSize = 0; + private int consolidatedNumComponents = 0; + + private long totalSize = 0; + private long nextFrameSize = UNKNOWN_FRAME_SIZE; + private int frameRemainingBytes = UNKNOWN_FRAME_SIZE; + + public NettyFrameDecoder() { + this(CONSOLIDATE_THRESHOLD); + } + + @VisibleForTesting + NettyFrameDecoder(long consolidateThreshold) { + this.consolidateThreshold = consolidateThreshold; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { + ByteBuf in = (ByteBuf) data; + buffers.add(in); + totalSize += in.readableBytes(); + + while (!buffers.isEmpty()) { + ByteBuf frame = decodeNext(); + if (frame == null) { + break; + } + ctx.fireChannelRead(frame); } + } - @VisibleForTesting - NettyFrameDecoder(long consolidateThreshold) { - this.consolidateThreshold = consolidateThreshold; + private long decodeFrameSize() { + if (nextFrameSize != UNKNOWN_FRAME_SIZE || totalSize < LENGTH_SIZE) { + return nextFrameSize; } - @Override - public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { - ByteBuf in = (ByteBuf) data; - buffers.add(in); - totalSize += in.readableBytes(); - - while (!buffers.isEmpty()) { - ByteBuf frame = decodeNext(); - if (frame == null) { - break; - } - ctx.fireChannelRead(frame); - } + // We know there's enough data. If the first buffer contains all the data, great. Otherwise, + // hold the bytes for the frame length in a composite buffer until we have enough data to read + // the frame size. Normally, it should be rare to need more than one buffer to read the frame + // size. + ByteBuf first = buffers.getFirst(); + if (first.readableBytes() >= LENGTH_SIZE) { + nextFrameSize = first.readInt() - LENGTH_SIZE; + totalSize -= LENGTH_SIZE; + if (!first.isReadable()) { + buffers.removeFirst().release(); + } + return nextFrameSize; } - private long decodeFrameSize() { - if (nextFrameSize != UNKNOWN_FRAME_SIZE || totalSize < LENGTH_SIZE) { - return nextFrameSize; - } - - // We know there's enough data. If the first buffer contains all the data, great. Otherwise, - // hold the bytes for the frame length in a composite buffer until we have enough data to read - // the frame size. Normally, it should be rare to need more than one buffer to read the frame - // size. - ByteBuf first = buffers.getFirst(); - if (first.readableBytes() >= LENGTH_SIZE) { - nextFrameSize = first.readInt() - LENGTH_SIZE; - totalSize -= LENGTH_SIZE; - if (!first.isReadable()) { - buffers.removeFirst().release(); - } - return nextFrameSize; - } - - while (frameLenBuf.readableBytes() < LENGTH_SIZE) { - ByteBuf next = buffers.getFirst(); - int toRead = Math.min(next.readableBytes(), LENGTH_SIZE - frameLenBuf.readableBytes()); - frameLenBuf.writeBytes(next, toRead); - if (!next.isReadable()) { - buffers.removeFirst().release(); - } - } - - nextFrameSize = frameLenBuf.readInt() - LENGTH_SIZE; - totalSize -= LENGTH_SIZE; - frameLenBuf.clear(); - return nextFrameSize; + while (frameLenBuf.readableBytes() < LENGTH_SIZE) { + ByteBuf next = buffers.getFirst(); + int toRead = Math.min(next.readableBytes(), LENGTH_SIZE - frameLenBuf.readableBytes()); + frameLenBuf.writeBytes(next, toRead); + if (!next.isReadable()) { + buffers.removeFirst().release(); + } } - private ByteBuf decodeNext() { - long frameSize = decodeFrameSize(); - if (frameSize == UNKNOWN_FRAME_SIZE) { - return null; - } - - if (frameBuf == null) { - Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, - "Too large frame: %s", frameSize); - Preconditions.checkArgument(frameSize > 0, - "Frame length should be positive: %s", frameSize); - frameRemainingBytes = (int) frameSize; - - // If buffers is empty, then return immediately for more input data. - if (buffers.isEmpty()) { - return null; - } - // Otherwise, if the first buffer holds the entire frame, we attempt to - // build frame with it and return. - if (buffers.getFirst().readableBytes() >= frameRemainingBytes) { - // Reset buf and size for next frame. - frameBuf = null; - nextFrameSize = UNKNOWN_FRAME_SIZE; - return nextBufferForFrame(frameRemainingBytes); - } - // Other cases, create a composite buffer to manage all the buffers. - frameBuf = buffers.getFirst().alloc().compositeBuffer(Integer.MAX_VALUE); - } - - while (frameRemainingBytes > 0 && !buffers.isEmpty()) { - ByteBuf next = nextBufferForFrame(frameRemainingBytes); - frameRemainingBytes -= next.readableBytes(); - frameBuf.addComponent(true, next); - } - // If the delta size of frameBuf exceeds the threshold, then we do consolidation - // to reduce memory consumption. - if (frameBuf.capacity() - consolidatedFrameBufSize > consolidateThreshold) { - int newNumComponents = frameBuf.numComponents() - consolidatedNumComponents; - frameBuf.consolidate(consolidatedNumComponents, newNumComponents); - consolidatedFrameBufSize = frameBuf.capacity(); - consolidatedNumComponents = frameBuf.numComponents(); - } - if (frameRemainingBytes > 0) { - return null; - } - - return consumeCurrentFrameBuf(); + nextFrameSize = frameLenBuf.readInt() - LENGTH_SIZE; + totalSize -= LENGTH_SIZE; + frameLenBuf.clear(); + return nextFrameSize; + } + + private ByteBuf decodeNext() { + long frameSize = decodeFrameSize(); + if (frameSize == UNKNOWN_FRAME_SIZE) { + return null; } - private ByteBuf consumeCurrentFrameBuf() { - final ByteBuf frame = frameBuf; + if (frameBuf == null) { + Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize); + Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize); + frameRemainingBytes = (int) frameSize; + + // If buffers is empty, then return immediately for more input data. + if (buffers.isEmpty()) { + return null; + } + // Otherwise, if the first buffer holds the entire frame, we attempt to + // build frame with it and return. + if (buffers.getFirst().readableBytes() >= frameRemainingBytes) { // Reset buf and size for next frame. frameBuf = null; - consolidatedFrameBufSize = 0; - consolidatedNumComponents = 0; nextFrameSize = UNKNOWN_FRAME_SIZE; - return frame; + return nextBufferForFrame(frameRemainingBytes); + } + // Other cases, create a composite buffer to manage all the buffers. + frameBuf = buffers.getFirst().alloc().compositeBuffer(Integer.MAX_VALUE); } - /** - * Takes the first buffer in the internal list, and either adjust it to fit in the frame - * (by taking a slice out of it) or remove it from the internal list. - */ - private ByteBuf nextBufferForFrame(int bytesToRead) { - ByteBuf buf = buffers.getFirst(); - ByteBuf frame; - - if (buf.readableBytes() > bytesToRead) { - frame = buf.retain().readSlice(bytesToRead); - totalSize -= bytesToRead; - } else { - frame = buf; - buffers.removeFirst(); - totalSize -= frame.readableBytes(); - } - - return frame; + while (frameRemainingBytes > 0 && !buffers.isEmpty()) { + ByteBuf next = nextBufferForFrame(frameRemainingBytes); + frameRemainingBytes -= next.readableBytes(); + frameBuf.addComponent(true, next); } - - @Override - public void channelInactive(ChannelHandlerContext ctx) throws Exception { - super.channelInactive(ctx); + // If the delta size of frameBuf exceeds the threshold, then we do consolidation + // to reduce memory consumption. + if (frameBuf.capacity() - consolidatedFrameBufSize > consolidateThreshold) { + int newNumComponents = frameBuf.numComponents() - consolidatedNumComponents; + frameBuf.consolidate(consolidatedNumComponents, newNumComponents); + consolidatedFrameBufSize = frameBuf.capacity(); + consolidatedNumComponents = frameBuf.numComponents(); } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - super.exceptionCaught(ctx, cause); + if (frameRemainingBytes > 0) { + return null; } - @Override - public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { - // Release all buffers that are still in our ownership. - // Doing this in handlerRemoved(...) guarantees that this will happen in all cases: - // - When the Channel becomes inactive - // - When the decoder is removed from the ChannelPipeline - for (ByteBuf b : buffers) { - b.release(); - } - buffers.clear(); - frameLenBuf.release(); - ByteBuf frame = consumeCurrentFrameBuf(); - if (frame != null) { - frame.release(); - } - super.handlerRemoved(ctx); + return consumeCurrentFrameBuf(); + } + + private ByteBuf consumeCurrentFrameBuf() { + final ByteBuf frame = frameBuf; + // Reset buf and size for next frame. + frameBuf = null; + consolidatedFrameBufSize = 0; + consolidatedNumComponents = 0; + nextFrameSize = UNKNOWN_FRAME_SIZE; + return frame; + } + + /** + * Takes the first buffer in the internal list, and either adjust it to fit in the frame (by + * taking a slice out of it) or remove it from the internal list. + */ + private ByteBuf nextBufferForFrame(int bytesToRead) { + ByteBuf buf = buffers.getFirst(); + ByteBuf frame; + + if (buf.readableBytes() > bytesToRead) { + frame = buf.retain().readSlice(bytesToRead); + totalSize -= bytesToRead; + } else { + frame = buf; + buffers.removeFirst(); + totalSize -= frame.readableBytes(); } + return frame; + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + super.channelInactive(ctx); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + super.exceptionCaught(ctx, cause); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + // Release all buffers that are still in our ownership. + // Doing this in handlerRemoved(...) guarantees that this will happen in all cases: + // - When the Channel becomes inactive + // - When the decoder is removed from the ChannelPipeline + for (ByteBuf b : buffers) { + b.release(); + } + buffers.clear(); + frameLenBuf.release(); + ByteBuf frame = consumeCurrentFrameBuf(); + if (frame != null) { + frame.release(); + } + super.handlerRemoved(ctx); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/NettyServer.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/NettyServer.java index 628e329da..a29ad4f14 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/NettyServer.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/NettyServer.java @@ -21,6 +21,17 @@ import static com.google.common.base.Preconditions.checkState; +import java.io.Closeable; +import java.io.IOException; +import java.net.InetSocketAddress; + +import org.apache.commons.lang3.SystemUtils; +import org.apache.geaflow.shuffle.config.ShuffleConfig; +import org.apache.geaflow.shuffle.network.ITransportContext; +import org.apache.geaflow.shuffle.network.NettyUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.PooledByteBufAllocator; import io.netty.channel.ChannelFuture; @@ -32,138 +43,133 @@ import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; -import java.io.Closeable; -import java.io.IOException; -import java.net.InetSocketAddress; -import org.apache.commons.lang3.SystemUtils; -import org.apache.geaflow.shuffle.config.ShuffleConfig; -import org.apache.geaflow.shuffle.network.ITransportContext; -import org.apache.geaflow.shuffle.network.NettyUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /* This file is based on source code from the Flink Project (http://flink.apache.org/), licensed by the Apache * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. */ -/** - * This class is an adaptation of Flink's org.apache.flink.runtime.io.network.netty.NettyServer. - */ +/** This class is an adaptation of Flink's org.apache.flink.runtime.io.network.netty.NettyServer. */ public class NettyServer implements Closeable { - private static final Logger LOGGER = LoggerFactory.getLogger(NettyServer.class); - private static final String SERVER_THREAD_GROUP_NAME = "NettyServer"; - - private ServerBootstrap bootstrap; - private ChannelFuture bindFuture; - private ShuffleConfig config; - private ITransportContext context; - private PooledByteBufAllocator pooledAllocator; + private static final Logger LOGGER = LoggerFactory.getLogger(NettyServer.class); + private static final String SERVER_THREAD_GROUP_NAME = "NettyServer"; - public NettyServer(ShuffleConfig config, ITransportContext transportContext) { - this.config = config; - this.context = transportContext; - } + private ServerBootstrap bootstrap; + private ChannelFuture bindFuture; + private ShuffleConfig config; + private ITransportContext context; + private PooledByteBufAllocator pooledAllocator; - public InetSocketAddress start() { - checkState(bootstrap == null, "Netty server has already been initialized."); - - final long start = System.currentTimeMillis(); - bootstrap = new ServerBootstrap(); - - if (Epoll.isAvailable()) { - initEpollBootstrap(config); - } else { - initNioBootstrap(config); - } - - bootstrap.option(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS); - // Pooled allocators for Netty's ByteBuf instances - pooledAllocator = NettyUtils - .createPooledByteBufAllocator(config.preferDirectBuffer(), config.isThreadCacheEnabled(), - config.getServerThreadsNum()); - bootstrap.option(ChannelOption.ALLOCATOR, pooledAllocator); - bootstrap.childOption(ChannelOption.ALLOCATOR, pooledAllocator); - - if (config.getServerConnectBacklog() > 0) { - bootstrap.option(ChannelOption.SO_BACKLOG, config.getServerConnectBacklog()); - } - - int receiveBufferSize = config.getReceiveBufferSize(); - if (receiveBufferSize > 0) { - bootstrap.childOption(ChannelOption.SO_RCVBUF, receiveBufferSize); - } - int sendBufferSize = config.getSendBufferSize(); - if (sendBufferSize > 0) { - bootstrap.childOption(ChannelOption.SO_SNDBUF, sendBufferSize); - } - - // -------------------------------------------------------------------- - // Child channel pipeline for accepted connections - // -------------------------------------------------------------------- - - bootstrap.childHandler(new ChannelInitializer() { - @Override - public void initChannel(SocketChannel channel) throws Exception { - channel.pipeline().addLast(context.createServerChannelHandler(channel)); - //context.initializePipeline(channel); - } - }); + public NettyServer(ShuffleConfig config, ITransportContext transportContext) { + this.config = config; + this.context = transportContext; + } - // -------------------------------------------------------------------- - // Start Server - // -------------------------------------------------------------------- + public InetSocketAddress start() { + checkState(bootstrap == null, "Netty server has already been initialized."); - bootstrap.localAddress(config.getServerAddress(), config.getServerPort()); - bindFuture = bootstrap.bind().syncUninterruptibly(); - InetSocketAddress localAddress = (InetSocketAddress) bindFuture.channel().localAddress(); + final long start = System.currentTimeMillis(); + bootstrap = new ServerBootstrap(); - long end = System.currentTimeMillis(); - LOGGER.info("Successful initialization (took {} ms). Listening on {}. NettyConfig: {}", - (end - start), localAddress.toString(), config); - - return localAddress; + if (Epoll.isAvailable()) { + initEpollBootstrap(config); + } else { + initNioBootstrap(config); } - private void initNioBootstrap(ShuffleConfig config) { - // Add the server port number to the name in order to distinguish - // multiple servers running on the same host. - String name = String.format("%s(%s)", SERVER_THREAD_GROUP_NAME, config.getServerPort()); + bootstrap.option(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS); + // Pooled allocators for Netty's ByteBuf instances + pooledAllocator = + NettyUtils.createPooledByteBufAllocator( + config.preferDirectBuffer(), + config.isThreadCacheEnabled(), + config.getServerThreadsNum()); + bootstrap.option(ChannelOption.ALLOCATOR, pooledAllocator); + bootstrap.childOption(ChannelOption.ALLOCATOR, pooledAllocator); + + if (config.getServerConnectBacklog() > 0) { + bootstrap.option(ChannelOption.SO_BACKLOG, config.getServerConnectBacklog()); + } - NioEventLoopGroup nioGroup = new NioEventLoopGroup(config.getServerThreadsNum(), - NettyUtils.getNamedThreadFactory(name)); - bootstrap.group(nioGroup).channel(NioServerSocketChannel.class); + int receiveBufferSize = config.getReceiveBufferSize(); + if (receiveBufferSize > 0) { + bootstrap.childOption(ChannelOption.SO_RCVBUF, receiveBufferSize); + } + int sendBufferSize = config.getSendBufferSize(); + if (sendBufferSize > 0) { + bootstrap.childOption(ChannelOption.SO_SNDBUF, sendBufferSize); } - private void initEpollBootstrap(ShuffleConfig config) { - // Add the server port number to the name in order to distinguish - // multiple servers running on the same host. - String name = String.format("%s(%s)", SERVER_THREAD_GROUP_NAME, config.getServerPort()); + // -------------------------------------------------------------------- + // Child channel pipeline for accepted connections + // -------------------------------------------------------------------- + + bootstrap.childHandler( + new ChannelInitializer() { + @Override + public void initChannel(SocketChannel channel) throws Exception { + channel.pipeline().addLast(context.createServerChannelHandler(channel)); + // context.initializePipeline(channel); + } + }); - EpollEventLoopGroup epollGroup = new EpollEventLoopGroup(config.getServerThreadsNum(), - NettyUtils.getNamedThreadFactory(name)); - bootstrap.group(epollGroup).channel(EpollServerSocketChannel.class); + // -------------------------------------------------------------------- + // Start Server + // -------------------------------------------------------------------- + + bootstrap.localAddress(config.getServerAddress(), config.getServerPort()); + bindFuture = bootstrap.bind().syncUninterruptibly(); + InetSocketAddress localAddress = (InetSocketAddress) bindFuture.channel().localAddress(); + + long end = System.currentTimeMillis(); + LOGGER.info( + "Successful initialization (took {} ms). Listening on {}. NettyConfig: {}", + (end - start), + localAddress.toString(), + config); + + return localAddress; + } + + private void initNioBootstrap(ShuffleConfig config) { + // Add the server port number to the name in order to distinguish + // multiple servers running on the same host. + String name = String.format("%s(%s)", SERVER_THREAD_GROUP_NAME, config.getServerPort()); + + NioEventLoopGroup nioGroup = + new NioEventLoopGroup(config.getServerThreadsNum(), NettyUtils.getNamedThreadFactory(name)); + bootstrap.group(nioGroup).channel(NioServerSocketChannel.class); + } + + private void initEpollBootstrap(ShuffleConfig config) { + // Add the server port number to the name in order to distinguish + // multiple servers running on the same host. + String name = String.format("%s(%s)", SERVER_THREAD_GROUP_NAME, config.getServerPort()); + + EpollEventLoopGroup epollGroup = + new EpollEventLoopGroup( + config.getServerThreadsNum(), NettyUtils.getNamedThreadFactory(name)); + bootstrap.group(epollGroup).channel(EpollServerSocketChannel.class); + } + + public PooledByteBufAllocator getPooledAllocator() { + return pooledAllocator; + } + + @Override + public void close() throws IOException { + final long start = System.currentTimeMillis(); + if (bindFuture != null) { + bindFuture.channel().close().awaitUninterruptibly(); + bindFuture = null; } - - public PooledByteBufAllocator getPooledAllocator() { - return pooledAllocator; + if (bootstrap != null && bootstrap.config().group() != null) { + bootstrap.config().group().shutdownGracefully(); } - - @Override - public void close() throws IOException { - final long start = System.currentTimeMillis(); - if (bindFuture != null) { - bindFuture.channel().close().awaitUninterruptibly(); - bindFuture = null; - } - if (bootstrap != null && bootstrap.config().group() != null) { - bootstrap.config().group().shutdownGracefully(); - } - if (bootstrap != null && bootstrap.config().childGroup() != null) { - bootstrap.config().childGroup().shutdownGracefully(); - } - long end = System.currentTimeMillis(); - LOGGER.info("Successful shutdown (took {} ms).", (end - start)); + if (bootstrap != null && bootstrap.config().childGroup() != null) { + bootstrap.config().childGroup().shutdownGracefully(); } - + long end = System.currentTimeMillis(); + LOGGER.info("Successful shutdown (took {} ms).", (end - start)); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/SliceOutputChannelHandler.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/SliceOutputChannelHandler.java index 27cd8b6b0..5c2c3c40d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/SliceOutputChannelHandler.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/SliceOutputChannelHandler.java @@ -19,16 +19,12 @@ package org.apache.geaflow.shuffle.network.netty; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; import java.io.IOException; import java.util.ArrayDeque; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.function.Consumer; + import org.apache.geaflow.shuffle.network.protocol.ErrorResponse; import org.apache.geaflow.shuffle.network.protocol.SliceResponse; import org.apache.geaflow.shuffle.pipeline.buffer.OutBuffer; @@ -39,242 +35,248 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; + public class SliceOutputChannelHandler extends ChannelInboundHandlerAdapter { - private static final Logger LOGGER = LoggerFactory.getLogger(SliceOutputChannelHandler.class); + private static final Logger LOGGER = LoggerFactory.getLogger(SliceOutputChannelHandler.class); - // The readers which are already enqueued available for transferring data. - private final ArrayDeque availableReaders = new ArrayDeque<>(); + // The readers which are already enqueued available for transferring data. + private final ArrayDeque availableReaders = new ArrayDeque<>(); - // All the readers created for the consumers' slice requests. - private final ConcurrentMap allReaders = - new ConcurrentHashMap<>(); + // All the readers created for the consumers' slice requests. + private final ConcurrentMap allReaders = + new ConcurrentHashMap<>(); - private boolean fatalError; + private boolean fatalError; - private ChannelHandlerContext ctx; + private ChannelHandlerContext ctx; - @Override - public void channelRegistered(final ChannelHandlerContext ctx) throws Exception { - if (this.ctx == null) { - this.ctx = ctx; - } + @Override + public void channelRegistered(final ChannelHandlerContext ctx) throws Exception { + if (this.ctx == null) { + this.ctx = ctx; + } - super.channelRegistered(ctx); + super.channelRegistered(ctx); + } + + public void notifyNonEmpty(final SequenceSliceReader reader) { + ctx.executor().execute(() -> ctx.pipeline().fireUserEventTriggered(reader)); + } + + /** + * Try to enqueue the reader once receiving non-empty reader notification from the sliceWriter. + * + *

NOTE: Only one thread would trigger the actual enqueue after checking the reader's + * availability, so there is no race condition here. + */ + private void enqueueReader(final SequenceSliceReader reader) throws Exception { + if (reader.isRegistered() || !reader.isAvailable()) { + return; } - public void notifyNonEmpty(final SequenceSliceReader reader) { - ctx.executor().execute(() -> ctx.pipeline().fireUserEventTriggered(reader)); + // Queue an available reader for consumption. If the queue is empty, + // we try trigger the actual write. Otherwise, this will be handled by + // the writeAndFlushNextMessageIfPossible calls. + boolean triggerWrite = availableReaders.isEmpty(); + addAvailableReader(reader); + + if (triggerWrite) { + writeAndFlushNextMessageIfPossible(ctx.channel()); } + } - /** - * Try to enqueue the reader once receiving non-empty reader notification from the sliceWriter. - * - *

NOTE: Only one thread would trigger the actual enqueue after checking the reader's - * availability, so there is no race condition here. - */ - private void enqueueReader(final SequenceSliceReader reader) throws Exception { - if (reader.isRegistered() || !reader.isAvailable()) { - return; - } + public void notifyReaderCreated(final SequenceSliceReader reader) { + allReaders.put(reader.getReceiverId(), reader); + } - // Queue an available reader for consumption. If the queue is empty, - // we try trigger the actual write. Otherwise, this will be handled by - // the writeAndFlushNextMessageIfPossible calls. - boolean triggerWrite = availableReaders.isEmpty(); - addAvailableReader(reader); + public void cancel(ChannelId receiverId) { + ctx.pipeline().fireUserEventTriggered(receiverId); + } - if (triggerWrite) { - writeAndFlushNextMessageIfPossible(ctx.channel()); - } + public void close() throws IOException { + if (ctx != null) { + ctx.channel().close(); } - public void notifyReaderCreated(final SequenceSliceReader reader) { - allReaders.put(reader.getReceiverId(), reader); + for (SequenceSliceReader reader : allReaders.values()) { + releaseReader(reader); } + allReaders.clear(); + } - public void cancel(ChannelId receiverId) { - ctx.pipeline().fireUserEventTriggered(receiverId); + public void applyReaderOperation(ChannelId receiverId, Consumer operation) + throws Exception { + if (fatalError) { + return; } - public void close() throws IOException { - if (ctx != null) { - ctx.channel().close(); - } - - for (SequenceSliceReader reader : allReaders.values()) { - releaseReader(reader); - } - allReaders.clear(); + SequenceSliceReader reader = allReaders.get(receiverId); + if (reader != null) { + operation.accept(reader); + enqueueReader(reader); + } else { + throw new IllegalStateException("No reader for receiverId = " + receiverId + " exists."); } + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object msg) throws Exception { + // The user event triggered event loop callback is used for thread-safe + // hand over of reader queues and cancelled producers. + + if (msg instanceof SequenceSliceReader) { + enqueueReader((SequenceSliceReader) msg); + } else if (msg.getClass() == ChannelId.class) { + // Release reader that get a cancel request. + ChannelId toCancel = (ChannelId) msg; + + // Remove reader from queue of available readers. + availableReaders.removeIf(reader -> reader.getReceiverId().equals(toCancel)); + + // Remove reader from queue of all readers and release its resource. + final SequenceSliceReader toRelease = allReaders.remove(toCancel); + if (toRelease != null) { + releaseReader(toRelease); + } + } else { + ctx.fireUserEventTriggered(msg); + } + } - public void applyReaderOperation(ChannelId receiverId, Consumer operation) - throws Exception { - if (fatalError) { - return; - } + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + writeAndFlushNextMessageIfPossible(ctx.channel()); + } - SequenceSliceReader reader = allReaders.get(receiverId); - if (reader != null) { - operation.accept(reader); - enqueueReader(reader); - } else { - throw new IllegalStateException( - "No reader for receiverId = " + receiverId + " exists."); - } + private void writeAndFlushNextMessageIfPossible(final Channel channel) throws IOException { + if (fatalError || !channel.isWritable()) { + return; } - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object msg) throws Exception { - // The user event triggered event loop callback is used for thread-safe - // hand over of reader queues and cancelled producers. - - if (msg instanceof SequenceSliceReader) { - enqueueReader((SequenceSliceReader) msg); - } else if (msg.getClass() == ChannelId.class) { - // Release reader that get a cancel request. - ChannelId toCancel = (ChannelId) msg; - - // Remove reader from queue of available readers. - availableReaders.removeIf(reader -> reader.getReceiverId().equals(toCancel)); - - // Remove reader from queue of all readers and release its resource. - final SequenceSliceReader toRelease = allReaders.remove(toCancel); - if (toRelease != null) { - releaseReader(toRelease); - } - } else { - ctx.fireUserEventTriggered(msg); - } - } + PipeChannelBuffer next; + SequenceSliceReader reader = null; - @Override - public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { - writeAndFlushNextMessageIfPossible(ctx.channel()); - } + try { + while (true) { + reader = pollAvailableReader(); - private void writeAndFlushNextMessageIfPossible(final Channel channel) throws IOException { - if (fatalError || !channel.isWritable()) { - return; + // No queue with available data. We allow this here, because + // of the write callbacks that are executed after each write. + if (reader == null) { + return; } - PipeChannelBuffer next; - SequenceSliceReader reader = null; - - try { - while (true) { - reader = pollAvailableReader(); - - // No queue with available data. We allow this here, because - // of the write callbacks that are executed after each write. - if (reader == null) { - return; - } - - next = reader.next(); - if (next != null) { - // This channel was now removed from the available reader queue. - // We re-add it into the queue if it is still available. - if (next.moreAvailable()) { - addAvailableReader(reader); - } - - SliceResponse msg = new SliceResponse(next.getBuffer(), - reader.getSequenceNumber(), reader.getReceiverId()); - - // Write and flush and wait until this is done before - // trying to continue with the next buffer. - channel.writeAndFlush(msg).addListener(new WriteNextMessageIfPossibleListener(next.getBuffer())); - - return; - } - } - } catch (Throwable t) { - LOGGER.error("fetch {} failed: {}", reader, t.getMessage()); - throw new IOException(t.getMessage(), t); + next = reader.next(); + if (next != null) { + // This channel was now removed from the available reader queue. + // We re-add it into the queue if it is still available. + if (next.moreAvailable()) { + addAvailableReader(reader); + } + + SliceResponse msg = + new SliceResponse( + next.getBuffer(), reader.getSequenceNumber(), reader.getReceiverId()); + + // Write and flush and wait until this is done before + // trying to continue with the next buffer. + channel + .writeAndFlush(msg) + .addListener(new WriteNextMessageIfPossibleListener(next.getBuffer())); + + return; } + } + } catch (Throwable t) { + LOGGER.error("fetch {} failed: {}", reader, t.getMessage()); + throw new IOException(t.getMessage(), t); } + } - private void addAvailableReader(SequenceSliceReader reader) { - availableReaders.add(reader); - reader.setRegistered(true); - } + private void addAvailableReader(SequenceSliceReader reader) { + availableReaders.add(reader); + reader.setRegistered(true); + } - private SequenceSliceReader pollAvailableReader() { - SequenceSliceReader reader = availableReaders.poll(); - if (reader != null) { - reader.setRegistered(false); - } - return reader; + private SequenceSliceReader pollAvailableReader() { + SequenceSliceReader reader = availableReaders.poll(); + if (reader != null) { + reader.setRegistered(false); } + return reader; + } - @Override - public void channelInactive(ChannelHandlerContext ctx) throws Exception { - LOGGER.warn("channel inactive and release resource..."); - releaseAllResources(); - ctx.fireChannelInactive(); - } + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + LOGGER.warn("channel inactive and release resource..."); + releaseAllResources(); + ctx.fireChannelInactive(); + } - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - handleException(ctx.channel(), cause); - } + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + handleException(ctx.channel(), cause); + } - private void handleException(Channel channel, Throwable cause) throws IOException { - LOGGER.error("Encountered error while consuming slices", cause); + private void handleException(Channel channel, Throwable cause) throws IOException { + LOGGER.error("Encountered error while consuming slices", cause); - fatalError = true; - releaseAllResources(); + fatalError = true; + releaseAllResources(); - if (channel.isActive()) { - channel.writeAndFlush(new ErrorResponse(cause)) - .addListener(ChannelFutureListener.CLOSE); - } + if (channel.isActive()) { + channel.writeAndFlush(new ErrorResponse(cause)).addListener(ChannelFutureListener.CLOSE); } + } - private void releaseAllResources() throws IOException { - // Note: this is only ever executed by one thread: the Netty IO thread! - for (SequenceSliceReader reader : allReaders.values()) { - releaseReader(reader); - } - - availableReaders.clear(); - allReaders.clear(); + private void releaseAllResources() throws IOException { + // Note: this is only ever executed by one thread: the Netty IO thread! + for (SequenceSliceReader reader : allReaders.values()) { + releaseReader(reader); } - private void releaseReader(SequenceSliceReader reader) throws IOException { - reader.setRegistered(false); - reader.releaseAllResources(); - } + availableReaders.clear(); + allReaders.clear(); + } - // This listener is called after an element of the current nonEmptyReader has been - // flushed. If successful, the listener triggers further processing of the queues. - private class WriteNextMessageIfPossibleListener implements ChannelFutureListener { + private void releaseReader(SequenceSliceReader reader) throws IOException { + reader.setRegistered(false); + reader.releaseAllResources(); + } - private final OutBuffer buffer; + // This listener is called after an element of the current nonEmptyReader has been + // flushed. If successful, the listener triggers further processing of the queues. + private class WriteNextMessageIfPossibleListener implements ChannelFutureListener { - public WriteNextMessageIfPossibleListener(PipeBuffer pipeBuffer) { - this.buffer = pipeBuffer.getBuffer(); - } + private final OutBuffer buffer; - @Override - public void operationComplete(ChannelFuture future) throws Exception { - try { - if (buffer != null) { - buffer.release(); - } - if (future.isSuccess()) { - writeAndFlushNextMessageIfPossible(future.channel()); - } else if (future.cause() != null) { - handleException(future.channel(), future.cause()); - } else { - handleException(future.channel(), - new IllegalStateException("Sending cancelled by user.")); - } - } catch (Throwable t) { - handleException(future.channel(), t); - } - } + public WriteNextMessageIfPossibleListener(PipeBuffer pipeBuffer) { + this.buffer = pipeBuffer.getBuffer(); } + @Override + public void operationComplete(ChannelFuture future) throws Exception { + try { + if (buffer != null) { + buffer.release(); + } + if (future.isSuccess()) { + writeAndFlushNextMessageIfPossible(future.channel()); + } else if (future.cause() != null) { + handleException(future.channel(), future.cause()); + } else { + handleException( + future.channel(), new IllegalStateException("Sending cancelled by user.")); + } + } catch (Throwable t) { + handleException(future.channel(), t); + } + } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/SliceRequestClient.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/SliceRequestClient.java index 71993d52d..9c7beee76 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/SliceRequestClient.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/SliceRequestClient.java @@ -19,13 +19,10 @@ package org.apache.geaflow.shuffle.network.netty; -import com.google.common.base.Preconditions; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelFutureListener; import java.io.IOException; import java.net.SocketAddress; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.shuffle.message.SliceId; import org.apache.geaflow.shuffle.network.ConnectionId; import org.apache.geaflow.shuffle.network.protocol.AddCreditRequest; @@ -39,143 +36,167 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class SliceRequestClient { - - private static final Logger LOGGER = LoggerFactory.getLogger(SliceRequestClient.class); - - private final Channel tcpChannel; - private final ConnectionId connectionId; - private final SliceRequestClientHandler clientHandler; - private final SliceRequestClientFactory clientFactory; - // If zero, the underlying TCP channel can be safely closed. - private final AtomicReferenceCounter closeReferenceCounter = new AtomicReferenceCounter(); - - public SliceRequestClient(Channel tcpChannel, SliceRequestClientHandler clientHandler, - ConnectionId connectionId, SliceRequestClientFactory clientFactory) { - - this.tcpChannel = Preconditions.checkNotNull(tcpChannel); - this.clientHandler = Preconditions.checkNotNull(clientHandler); - this.connectionId = Preconditions.checkNotNull(connectionId); - this.clientFactory = Preconditions.checkNotNull(clientFactory); - } +import com.google.common.base.Preconditions; - public boolean disposeIfNotUsed() { - return closeReferenceCounter.disposeIfNotUsed(); - } +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; - /** - * Increments the reference counter. - * - *

Note: the reference counter has to be incremented before returning the - * instance of this client to ensure correct closing logic. - */ - boolean incrementReferenceCounter() { - return closeReferenceCounter.increment(); - } +public class SliceRequestClient { - /** - * Requests a remote intermediate result partition queue. - * - *

The request goes to the remote producer, for which this partition - * request client instance has been created. - */ - public void requestSlice(SliceId sliceId, final RemoteInputChannel inputChannel, int delayMs, - long startBatchId) throws IOException { - - checkNotClosed(); - clientHandler.addInputChannel(inputChannel); - - final SliceRequest request = new SliceRequest(sliceId, startBatchId, - inputChannel.getInputChannelId(), inputChannel.getInitialCredit()); - - final ChannelFutureListener listener = new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - if (!future.isSuccess()) { - clientHandler.removeInputChannel(inputChannel); - SocketAddress remoteAddr = future.channel().remoteAddress(); - inputChannel.onError(new TransportException( - String.format("Sending the request to '%s' failed.", remoteAddr), - future.channel().localAddress(), future.cause())); - } + private static final Logger LOGGER = LoggerFactory.getLogger(SliceRequestClient.class); + + private final Channel tcpChannel; + private final ConnectionId connectionId; + private final SliceRequestClientHandler clientHandler; + private final SliceRequestClientFactory clientFactory; + // If zero, the underlying TCP channel can be safely closed. + private final AtomicReferenceCounter closeReferenceCounter = new AtomicReferenceCounter(); + + public SliceRequestClient( + Channel tcpChannel, + SliceRequestClientHandler clientHandler, + ConnectionId connectionId, + SliceRequestClientFactory clientFactory) { + + this.tcpChannel = Preconditions.checkNotNull(tcpChannel); + this.clientHandler = Preconditions.checkNotNull(clientHandler); + this.connectionId = Preconditions.checkNotNull(connectionId); + this.clientFactory = Preconditions.checkNotNull(clientFactory); + } + + public boolean disposeIfNotUsed() { + return closeReferenceCounter.disposeIfNotUsed(); + } + + /** + * Increments the reference counter. + * + *

Note: the reference counter has to be incremented before returning the instance of this + * client to ensure correct closing logic. + */ + boolean incrementReferenceCounter() { + return closeReferenceCounter.increment(); + } + + /** + * Requests a remote intermediate result partition queue. + * + *

The request goes to the remote producer, for which this partition request client instance + * has been created. + */ + public void requestSlice( + SliceId sliceId, final RemoteInputChannel inputChannel, int delayMs, long startBatchId) + throws IOException { + + checkNotClosed(); + clientHandler.addInputChannel(inputChannel); + + final SliceRequest request = + new SliceRequest( + sliceId, + startBatchId, + inputChannel.getInputChannelId(), + inputChannel.getInitialCredit()); + + final ChannelFutureListener listener = + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (!future.isSuccess()) { + clientHandler.removeInputChannel(inputChannel); + SocketAddress remoteAddr = future.channel().remoteAddress(); + inputChannel.onError( + new TransportException( + String.format("Sending the request to '%s' failed.", remoteAddr), + future.channel().localAddress(), + future.cause())); } + } }; - if (delayMs == 0) { - ChannelFuture f = tcpChannel.writeAndFlush(request); - f.addListener(listener); - } else { - final ChannelFuture[] f = new ChannelFuture[1]; - tcpChannel.eventLoop().schedule(new Runnable() { + if (delayMs == 0) { + ChannelFuture f = tcpChannel.writeAndFlush(request); + f.addListener(listener); + } else { + final ChannelFuture[] f = new ChannelFuture[1]; + tcpChannel + .eventLoop() + .schedule( + new Runnable() { @Override public void run() { - f[0] = tcpChannel.writeAndFlush(request); - f[0].addListener(listener); + f[0] = tcpChannel.writeAndFlush(request); + f[0].addListener(listener); } - }, delayMs, TimeUnit.MILLISECONDS); - } - } - - public void requestNextBatch(long batchId, final RemoteInputChannel inputChannel) - throws IOException { - checkNotClosed(); - final BatchRequest request = new BatchRequest(batchId, inputChannel.getInputChannelId()); - sendRequest(inputChannel, request); - } - - public void notifyCreditAvailable(RemoteInputChannel inputChannel) throws IOException { - checkNotClosed(); - - int credit = inputChannel.getAndResetAvailableCredit(); - Preconditions.checkArgument(credit > 0, "Credit must be greater than zero."); - final AddCreditRequest request = new AddCreditRequest(credit, - inputChannel.getInputChannelId()); - sendRequest(inputChannel, request); + }, + delayMs, + TimeUnit.MILLISECONDS); } - - private void sendRequest(RemoteInputChannel inputChannel, NettyMessage request) throws IOException { - checkNotClosed(); - - final ChannelFutureListener listener = new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - if (!future.isSuccess()) { - SocketAddress remoteAddr = future.channel().remoteAddress(); - inputChannel.onError(new TransportException( - String.format("Sending the batch request to '%s' failed.", remoteAddr), - future.channel().localAddress(), future.cause())); - } + } + + public void requestNextBatch(long batchId, final RemoteInputChannel inputChannel) + throws IOException { + checkNotClosed(); + final BatchRequest request = new BatchRequest(batchId, inputChannel.getInputChannelId()); + sendRequest(inputChannel, request); + } + + public void notifyCreditAvailable(RemoteInputChannel inputChannel) throws IOException { + checkNotClosed(); + + int credit = inputChannel.getAndResetAvailableCredit(); + Preconditions.checkArgument(credit > 0, "Credit must be greater than zero."); + final AddCreditRequest request = new AddCreditRequest(credit, inputChannel.getInputChannelId()); + sendRequest(inputChannel, request); + } + + private void sendRequest(RemoteInputChannel inputChannel, NettyMessage request) + throws IOException { + checkNotClosed(); + + final ChannelFutureListener listener = + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (!future.isSuccess()) { + SocketAddress remoteAddr = future.channel().remoteAddress(); + inputChannel.onError( + new TransportException( + String.format("Sending the batch request to '%s' failed.", remoteAddr), + future.channel().localAddress(), + future.cause())); } + } }; - ChannelFuture f = tcpChannel.writeAndFlush(request); - f.addListener(listener); + ChannelFuture f = tcpChannel.writeAndFlush(request); + f.addListener(listener); + } + + public void close(RemoteInputChannel inputChannel) throws IOException { + clientHandler.removeInputChannel(inputChannel); + + if (closeReferenceCounter.decrement()) { + // Close the TCP connection. Send a close request msg to ensure + // that outstanding backwards task events are not discarded. + tcpChannel + .writeAndFlush(new CloseRequest()) + .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); + + // Make sure to remove the client from the factory. + clientFactory.destroyRequestClient(connectionId, this); + } else { + LOGGER.warn("cancel slice consumption of {}", inputChannel.getInputSliceId()); + clientHandler.cancelRequest(inputChannel.getInputChannelId()); } + } - public void close(RemoteInputChannel inputChannel) throws IOException { - clientHandler.removeInputChannel(inputChannel); - - if (closeReferenceCounter.decrement()) { - // Close the TCP connection. Send a close request msg to ensure - // that outstanding backwards task events are not discarded. - tcpChannel.writeAndFlush(new CloseRequest()) - .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); - - // Make sure to remove the client from the factory. - clientFactory.destroyRequestClient(connectionId, this); - } else { - LOGGER.warn("cancel slice consumption of {}", inputChannel.getInputSliceId()); - clientHandler.cancelRequest(inputChannel.getInputChannelId()); - } + private void checkNotClosed() throws IOException { + if (closeReferenceCounter.isDisposed()) { + final SocketAddress localAddr = tcpChannel.localAddress(); + final SocketAddress remoteAddr = tcpChannel.remoteAddress(); + throw new TransportException(String.format("Channel to '%s' closed.", remoteAddr), localAddr); } - - private void checkNotClosed() throws IOException { - if (closeReferenceCounter.isDisposed()) { - final SocketAddress localAddr = tcpChannel.localAddress(); - final SocketAddress remoteAddr = tcpChannel.remoteAddress(); - throw new TransportException(String.format("Channel to '%s' closed.", remoteAddr), - localAddr); - } - } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/SliceRequestClientFactory.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/SliceRequestClientFactory.java index fb12ee76f..f3d47398c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/SliceRequestClientFactory.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/SliceRequestClientFactory.java @@ -19,12 +19,12 @@ package org.apache.geaflow.shuffle.network.netty; -import io.netty.channel.Channel; import java.io.IOException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutionException; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.shuffle.config.ShuffleConfig; import org.apache.geaflow.shuffle.network.ConnectionId; @@ -32,123 +32,125 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import io.netty.channel.Channel; + public class SliceRequestClientFactory { - private static final Logger LOGGER = LoggerFactory.getLogger(SliceRequestClientFactory.class); + private static final Logger LOGGER = LoggerFactory.getLogger(SliceRequestClientFactory.class); - private final int retryNumber; - private final NettyClient nettyClient; + private final int retryNumber; + private final NettyClient nettyClient; - private final ConcurrentMap> clients = - new ConcurrentHashMap<>(); + private final ConcurrentMap> clients = + new ConcurrentHashMap<>(); - public SliceRequestClientFactory(ShuffleConfig nettyConfig, NettyClient nettyClient) { - this.nettyClient = nettyClient; - this.retryNumber = nettyConfig.getConnectMaxRetries(); - } + public SliceRequestClientFactory(ShuffleConfig nettyConfig, NettyClient nettyClient) { + this.nettyClient = nettyClient; + this.retryNumber = nettyConfig.getConnectMaxRetries(); + } - /** - * Atomically establishes a TCP connection to the given remote address and - * creates a {@link SliceRequestClient} instance for this connection. - */ - public SliceRequestClient createSliceRequestClient(ConnectionId connectionId) - throws InterruptedException { - while (true) { - final CompletableFuture newClientFuture = - new CompletableFuture<>(); - - CompletableFuture clientFuture = clients.putIfAbsent(connectionId - , newClientFuture); - - final SliceRequestClient client; - - if (clientFuture == null) { - try { - client = connectWithRetries(connectionId); - } catch (Throwable e) { - newClientFuture.completeExceptionally( - new IOException("Could not create client.", e)); - clients.remove(connectionId, newClientFuture); - throw e; - } - newClientFuture.complete(client); - } else { - try { - client = clientFuture.get(); - } catch (ExecutionException e) { - throw new GeaflowRuntimeException("connect failed", e); - } - } + /** + * Atomically establishes a TCP connection to the given remote address and creates a {@link + * SliceRequestClient} instance for this connection. + */ + public SliceRequestClient createSliceRequestClient(ConnectionId connectionId) + throws InterruptedException { + while (true) { + final CompletableFuture newClientFuture = new CompletableFuture<>(); - // Make sure to increment the reference count before handing a client - // out to ensure correct bookkeeping for channel closing. - if (client.incrementReferenceCounter()) { - return client; - } else { - destroyRequestClient(connectionId, client); - } - } - } + CompletableFuture clientFuture = + clients.putIfAbsent(connectionId, newClientFuture); - private SliceRequestClient connectWithRetries(ConnectionId connectionId) { - int tried = 0; - long startTime = System.nanoTime(); - while (true) { - try { - SliceRequestClient client = connect(connectionId); - LOGGER.info("Successfully created connection to {} after {} ms", connectionId, - (System.nanoTime() - startTime) / 1000000); - return client; - } catch (TransportException e) { - tried++; - LOGGER.error("failed to connect to {}, retry #{}", connectionId, tried, e); - if (tried > retryNumber) { - throw new GeaflowRuntimeException(String.format("Failed to connect to %s", - connectionId.getAddress()), e); - } - } - } - } + final SliceRequestClient client; - private SliceRequestClient connect(ConnectionId connectionId) throws TransportException { + if (clientFuture == null) { try { - Channel channel = nettyClient.connect(connectionId.getAddress()).await().channel(); - SliceRequestClientHandler clientHandler = channel.pipeline() - .get(SliceRequestClientHandler.class); - return new SliceRequestClient(channel, clientHandler, connectionId, this); - } catch (Exception e) { - throw new TransportException( - "Connecting to remote server '" + connectionId.getAddress() - + "' has failed. This might indicate that the remote server has been lost.", - connectionId.getAddress(), e); + client = connectWithRetries(connectionId); + } catch (Throwable e) { + newClientFuture.completeExceptionally(new IOException("Could not create client.", e)); + clients.remove(connectionId, newClientFuture); + throw e; } - } - - public void closeOpenChannelConnections(ConnectionId connectionId) { - CompletableFuture entry = clients.get(connectionId); - - if (entry != null && !entry.isDone()) { - entry.thenAccept(client -> { - if (client.disposeIfNotUsed()) { - clients.remove(connectionId, entry); - } - }); + newClientFuture.complete(client); + } else { + try { + client = clientFuture.get(); + } catch (ExecutionException e) { + throw new GeaflowRuntimeException("connect failed", e); } + } + + // Make sure to increment the reference count before handing a client + // out to ensure correct bookkeeping for channel closing. + if (client.incrementReferenceCounter()) { + return client; + } else { + destroyRequestClient(connectionId, client); + } } - - /** - * Removes the client for the given {@link ConnectionId}. - */ - public void destroyRequestClient(ConnectionId connectionId, - SliceRequestClient client) { - final CompletableFuture future = clients.get(connectionId); - if (future != null && future.isDone()) { - future.thenAccept(futureClient -> { - if (client.equals(futureClient)) { - clients.remove(connectionId, future); - } - }); + } + + private SliceRequestClient connectWithRetries(ConnectionId connectionId) { + int tried = 0; + long startTime = System.nanoTime(); + while (true) { + try { + SliceRequestClient client = connect(connectionId); + LOGGER.info( + "Successfully created connection to {} after {} ms", + connectionId, + (System.nanoTime() - startTime) / 1000000); + return client; + } catch (TransportException e) { + tried++; + LOGGER.error("failed to connect to {}, retry #{}", connectionId, tried, e); + if (tried > retryNumber) { + throw new GeaflowRuntimeException( + String.format("Failed to connect to %s", connectionId.getAddress()), e); } + } + } + } + + private SliceRequestClient connect(ConnectionId connectionId) throws TransportException { + try { + Channel channel = nettyClient.connect(connectionId.getAddress()).await().channel(); + SliceRequestClientHandler clientHandler = + channel.pipeline().get(SliceRequestClientHandler.class); + return new SliceRequestClient(channel, clientHandler, connectionId, this); + } catch (Exception e) { + throw new TransportException( + "Connecting to remote server '" + + connectionId.getAddress() + + "' has failed. This might indicate that the remote server has been lost.", + connectionId.getAddress(), + e); } + } + public void closeOpenChannelConnections(ConnectionId connectionId) { + CompletableFuture entry = clients.get(connectionId); + + if (entry != null && !entry.isDone()) { + entry.thenAccept( + client -> { + if (client.disposeIfNotUsed()) { + clients.remove(connectionId, entry); + } + }); + } + } + + /** Removes the client for the given {@link ConnectionId}. */ + public void destroyRequestClient(ConnectionId connectionId, SliceRequestClient client) { + final CompletableFuture future = clients.get(connectionId); + if (future != null && future.isDone()) { + future.thenAccept( + futureClient -> { + if (client.equals(futureClient)) { + clients.remove(connectionId, future); + } + }); + } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/SliceRequestClientHandler.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/SliceRequestClientHandler.java index 6e461b6f7..f0d89e1de 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/SliceRequestClientHandler.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/SliceRequestClientHandler.java @@ -19,14 +19,12 @@ package org.apache.geaflow.shuffle.network.netty; -import com.google.common.annotations.VisibleForTesting; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; import java.io.IOException; import java.net.SocketAddress; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicReference; + import org.apache.geaflow.shuffle.network.protocol.CancelRequest; import org.apache.geaflow.shuffle.network.protocol.ErrorResponse; import org.apache.geaflow.shuffle.network.protocol.NettyMessage; @@ -38,215 +36,229 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.annotations.VisibleForTesting; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; + /* This file is based on source code from the Flink Project (http://flink.apache.org/), licensed by the Apache * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. */ /** - * This class is an adaptation of Flink's org.apache.flink.runtime.io.network.netty.CreditBasedPartitionRequestClientHandler. + * This class is an adaptation of Flink's + * org.apache.flink.runtime.io.network.netty.CreditBasedPartitionRequestClientHandler. */ public class SliceRequestClientHandler extends SimpleChannelInboundHandler { - private static final Logger LOGGER = LoggerFactory.getLogger(SliceRequestClientHandler.class); + private static final Logger LOGGER = LoggerFactory.getLogger(SliceRequestClientHandler.class); - // Channels, which already requested partitions from the producers. - private final ConcurrentMap inputChannels = - new ConcurrentHashMap<>(); + // Channels, which already requested partitions from the producers. + private final ConcurrentMap inputChannels = + new ConcurrentHashMap<>(); - private final AtomicReference channelError = new AtomicReference<>(); + private final AtomicReference channelError = new AtomicReference<>(); - // Set of cancelled partition requests. A request is cancelled iff an input channel is cleared - // while data is still coming in for this channel. - private final ConcurrentMap cancelled = new ConcurrentHashMap<>(); + // Set of cancelled partition requests. A request is cancelled iff an input channel is cleared + // while data is still coming in for this channel. + private final ConcurrentMap cancelled = new ConcurrentHashMap<>(); - // The channel handler context is initialized in channel active event by netty thread - // the context may also be accessed by task thread or canceler thread to - // cancel partition request during releasing resources. - private volatile ChannelHandlerContext ctx; + // The channel handler context is initialized in channel active event by netty thread + // the context may also be accessed by task thread or canceler thread to + // cancel partition request during releasing resources. + private volatile ChannelHandlerContext ctx; - // ------------------------------------------------------------------------ - // Input channel/receiver registration - // ------------------------------------------------------------------------ + // ------------------------------------------------------------------------ + // Input channel/receiver registration + // ------------------------------------------------------------------------ - public void addInputChannel(RemoteInputChannel listener) throws IOException { - checkError(); - inputChannels.putIfAbsent(listener.getInputChannelId(), listener); - } + public void addInputChannel(RemoteInputChannel listener) throws IOException { + checkError(); + inputChannels.putIfAbsent(listener.getInputChannelId(), listener); + } - public void removeInputChannel(RemoteInputChannel listener) { - inputChannels.remove(listener.getInputChannelId()); - } + public void removeInputChannel(RemoteInputChannel listener) { + inputChannels.remove(listener.getInputChannelId()); + } - public RemoteInputChannel getInputChannel(ChannelId inputChannelId) { - return inputChannels.get(inputChannelId); - } + public RemoteInputChannel getInputChannel(ChannelId inputChannelId) { + return inputChannels.get(inputChannelId); + } - public void cancelRequest(ChannelId inputChannelId) { - if (inputChannelId == null || ctx == null) { - return; - } - - if (cancelled.putIfAbsent(inputChannelId, inputChannelId) == null) { - ctx.writeAndFlush(new CancelRequest(inputChannelId)); - } + public void cancelRequest(ChannelId inputChannelId) { + if (inputChannelId == null || ctx == null) { + return; } - // ------------------------------------------------------------------------ - // Network events - // ------------------------------------------------------------------------ - - @Override - public void channelActive(final ChannelHandlerContext ctx) throws Exception { - if (this.ctx == null) { - this.ctx = ctx; - } - - super.channelActive(ctx); + if (cancelled.putIfAbsent(inputChannelId, inputChannelId) == null) { + ctx.writeAndFlush(new CancelRequest(inputChannelId)); } + } - @Override - public void channelInactive(ChannelHandlerContext ctx) throws Exception { - // Unexpected close. In normal operation, the client closes the connection after all input - // channels have been removed. This indicates a problem with the remote server. - if (!inputChannels.isEmpty()) { - final SocketAddress remoteAddr = ctx.channel().remoteAddress(); + // ------------------------------------------------------------------------ + // Network events + // ------------------------------------------------------------------------ - notifyAllChannelsOfErrorAndClose(new TransportException( - "Connection unexpectedly closed by remote server '" + remoteAddr + "'. " - + "This might indicate that the remote server was lost.", remoteAddr)); - } - - super.channelInactive(ctx); + @Override + public void channelActive(final ChannelHandlerContext ctx) throws Exception { + if (this.ctx == null) { + this.ctx = ctx; } - /** - * Called on exceptions in the client handler pipeline. - * - *

Remote exceptions are received as regular payload. - */ - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - if (cause instanceof TransportException) { - notifyAllChannelsOfErrorAndClose(cause); - } else { - final SocketAddress remoteAddr = ctx.channel().remoteAddress(); - final TransportException tex; - - // Improve on the connection reset by peer error message. - if (cause instanceof IOException && "Connection reset by peer" - .equals(cause.getMessage())) { - tex = new TransportException("Lost connection to server '" + remoteAddr + "'. " - + "This indicates that the remote server was lost.", remoteAddr, cause); - } else { - final SocketAddress localAddr = ctx.channel().localAddress(); - tex = new TransportException( - String.format("%s (connection to '%s')", cause.getMessage(), remoteAddr), - localAddr, cause); - } - - notifyAllChannelsOfErrorAndClose(tex); - } + super.channelActive(ctx); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + // Unexpected close. In normal operation, the client closes the connection after all input + // channels have been removed. This indicates a problem with the remote server. + if (!inputChannels.isEmpty()) { + final SocketAddress remoteAddr = ctx.channel().remoteAddress(); + + notifyAllChannelsOfErrorAndClose( + new TransportException( + "Connection unexpectedly closed by remote server '" + + remoteAddr + + "'. " + + "This might indicate that the remote server was lost.", + remoteAddr)); } - @Override - protected void channelRead0(ChannelHandlerContext channelHandlerContext, - NettyMessage nettyMessage) throws Exception { - try { - decodeMsg(nettyMessage); - } catch (Throwable t) { - notifyAllChannelsOfErrorAndClose(t); - } + super.channelInactive(ctx); + } + + /** + * Called on exceptions in the client handler pipeline. + * + *

Remote exceptions are received as regular payload. + */ + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause instanceof TransportException) { + notifyAllChannelsOfErrorAndClose(cause); + } else { + final SocketAddress remoteAddr = ctx.channel().remoteAddress(); + final TransportException tex; + + // Improve on the connection reset by peer error message. + if (cause instanceof IOException && "Connection reset by peer".equals(cause.getMessage())) { + tex = + new TransportException( + "Lost connection to server '" + + remoteAddr + + "'. " + + "This indicates that the remote server was lost.", + remoteAddr, + cause); + } else { + final SocketAddress localAddr = ctx.channel().localAddress(); + tex = + new TransportException( + String.format("%s (connection to '%s')", cause.getMessage(), remoteAddr), + localAddr, + cause); + } + + notifyAllChannelsOfErrorAndClose(tex); } - - private void notifyAllChannelsOfErrorAndClose(Throwable cause) { - if (channelError.compareAndSet(null, cause)) { - try { - for (RemoteInputChannel inputChannel : inputChannels.values()) { - inputChannel.onError(cause); - } - } catch (Throwable t) { - LOGGER.warn( - "Exception was thrown during error notification of a remote input channel.", t); - } finally { - inputChannels.clear(); - - if (ctx != null) { - ctx.close(); - } - } - } + } + + @Override + protected void channelRead0( + ChannelHandlerContext channelHandlerContext, NettyMessage nettyMessage) throws Exception { + try { + decodeMsg(nettyMessage); + } catch (Throwable t) { + notifyAllChannelsOfErrorAndClose(t); } + } - /** - * Checks for an error and rethrows it if one was reported. - */ - @VisibleForTesting - void checkError() throws IOException { - final Throwable t = channelError.get(); - - if (t != null) { - if (t instanceof IOException) { - throw (IOException) t; - } else { - throw new IOException("There has been an error in the channel.", t); - } + private void notifyAllChannelsOfErrorAndClose(Throwable cause) { + if (channelError.compareAndSet(null, cause)) { + try { + for (RemoteInputChannel inputChannel : inputChannels.values()) { + inputChannel.onError(cause); } - } + } catch (Throwable t) { + LOGGER.warn("Exception was thrown during error notification of a remote input channel.", t); + } finally { + inputChannels.clear(); - private void decodeMsg(Object msg) { - final Class msgClazz = msg.getClass(); - - if (msgClazz == SliceResponse.class) { - SliceResponse response = (SliceResponse) msg; - - RemoteInputChannel inputChannel = inputChannels.get(response.getReceiverId()); - if (inputChannel == null || inputChannel.isReleased()) { - cancelRequest(response.getReceiverId()); - return; - } - - try { - processBuffer(inputChannel, response); - } catch (Throwable t) { - inputChannel.onError(t); - } - } else if (msgClazz == ErrorResponse.class) { - ErrorResponse error = (ErrorResponse) msg; - SocketAddress remoteAddr = ctx.channel().remoteAddress(); - - if (error.isFatalError()) { - notifyAllChannelsOfErrorAndClose( - new TransportException("Fatal error at remote server '" + remoteAddr + "'.", - remoteAddr, error.getCause())); - } else { - RemoteInputChannel inputChannel = inputChannels.get(error.getChannelId()); - - if (inputChannel != null) { - if (error.getCause().getClass() == SliceNotFoundException.class) { - inputChannel.onFailedFetchRequest(); - } else { - inputChannel.onError( - new TransportException("Error at remote server '" + remoteAddr + "'.", - remoteAddr, error.getCause())); - } - } - } - } else { - throw new IllegalStateException( - "Received unknown message from producer: " + msg.getClass()); + if (ctx != null) { + ctx.close(); } + } } - - private void processBuffer(RemoteInputChannel inputChannel, SliceResponse response) - throws Throwable { - if (response.getBuffer().isData() && response.getBufferSize() == 0) { - inputChannel.onEmptyBuffer(response.getSequenceNumber()); - } else if (response.getBuffer() != null) { - inputChannel.onBuffer(response.getBuffer(), response.getSequenceNumber()); - } else { - throw new IllegalStateException( - "The read buffer is null in input channel: " + inputChannel.getChannelIndex()); + } + + /** Checks for an error and rethrows it if one was reported. */ + @VisibleForTesting + void checkError() throws IOException { + final Throwable t = channelError.get(); + + if (t != null) { + if (t instanceof IOException) { + throw (IOException) t; + } else { + throw new IOException("There has been an error in the channel.", t); + } + } + } + + private void decodeMsg(Object msg) { + final Class msgClazz = msg.getClass(); + + if (msgClazz == SliceResponse.class) { + SliceResponse response = (SliceResponse) msg; + + RemoteInputChannel inputChannel = inputChannels.get(response.getReceiverId()); + if (inputChannel == null || inputChannel.isReleased()) { + cancelRequest(response.getReceiverId()); + return; + } + + try { + processBuffer(inputChannel, response); + } catch (Throwable t) { + inputChannel.onError(t); + } + } else if (msgClazz == ErrorResponse.class) { + ErrorResponse error = (ErrorResponse) msg; + SocketAddress remoteAddr = ctx.channel().remoteAddress(); + + if (error.isFatalError()) { + notifyAllChannelsOfErrorAndClose( + new TransportException( + "Fatal error at remote server '" + remoteAddr + "'.", + remoteAddr, + error.getCause())); + } else { + RemoteInputChannel inputChannel = inputChannels.get(error.getChannelId()); + + if (inputChannel != null) { + if (error.getCause().getClass() == SliceNotFoundException.class) { + inputChannel.onFailedFetchRequest(); + } else { + inputChannel.onError( + new TransportException( + "Error at remote server '" + remoteAddr + "'.", remoteAddr, error.getCause())); + } } + } + } else { + throw new IllegalStateException("Received unknown message from producer: " + msg.getClass()); } - + } + + private void processBuffer(RemoteInputChannel inputChannel, SliceResponse response) + throws Throwable { + if (response.getBuffer().isData() && response.getBufferSize() == 0) { + inputChannel.onEmptyBuffer(response.getSequenceNumber()); + } else if (response.getBuffer() != null) { + inputChannel.onBuffer(response.getBuffer(), response.getSequenceNumber()); + } else { + throw new IllegalStateException( + "The read buffer is null in input channel: " + inputChannel.getChannelIndex()); + } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/SliceRequestServerHandler.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/SliceRequestServerHandler.java index df0911731..95482a4c0 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/SliceRequestServerHandler.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/netty/SliceRequestServerHandler.java @@ -19,8 +19,6 @@ package org.apache.geaflow.shuffle.network.netty; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; import org.apache.geaflow.shuffle.network.protocol.AddCreditRequest; import org.apache.geaflow.shuffle.network.protocol.BatchRequest; import org.apache.geaflow.shuffle.network.protocol.CancelRequest; @@ -33,75 +31,76 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; + public class SliceRequestServerHandler extends SimpleChannelInboundHandler { - private static final Logger LOGGER = LoggerFactory.getLogger(SliceRequestServerHandler.class); + private static final Logger LOGGER = LoggerFactory.getLogger(SliceRequestServerHandler.class); - private final SliceOutputChannelHandler outboundQueue; + private final SliceOutputChannelHandler outboundQueue; - public SliceRequestServerHandler(SliceOutputChannelHandler outboundQueue) { - this.outboundQueue = outboundQueue; - } + public SliceRequestServerHandler(SliceOutputChannelHandler outboundQueue) { + this.outboundQueue = outboundQueue; + } - @Override - public void channelRegistered(ChannelHandlerContext ctx) throws Exception { - super.channelRegistered(ctx); - } + @Override + public void channelRegistered(ChannelHandlerContext ctx) throws Exception { + super.channelRegistered(ctx); + } - @Override - public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { - super.channelUnregistered(ctx); - } + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + super.channelUnregistered(ctx); + } - @Override - protected void channelRead0(ChannelHandlerContext ctx, NettyMessage msg) throws Exception { + @Override + protected void channelRead0(ChannelHandlerContext ctx, NettyMessage msg) throws Exception { + try { + Class msgClazz = msg.getClass(); + if (msgClazz == SliceRequest.class) { + SliceRequest request = (SliceRequest) msg; try { - Class msgClazz = msg.getClass(); - if (msgClazz == SliceRequest.class) { - SliceRequest request = (SliceRequest) msg; - try { - SequenceSliceReader reader = new SequenceSliceReader( - request.getReceiverId(), outboundQueue); - reader.createSliceReader(request.getSliceId(), request.getStartBatchId(), - request.getInitialCredit()); - - outboundQueue.notifyReaderCreated(reader); - } catch (Throwable notFound) { - respondWithError(ctx, notFound, request.getReceiverId()); - } - } else if (msgClazz == CancelRequest.class) { - CancelRequest request = (CancelRequest) msg; - - outboundQueue.cancel(request.receiverId()); - } else if (msgClazz == CloseRequest.class) { - - outboundQueue.close(); - } else if (msgClazz == BatchRequest.class) { - BatchRequest request = (BatchRequest) msg; - - outboundQueue.applyReaderOperation(request.receiverId(), - reader -> reader.requestBatch(request.getNextBatchId())); - } else if (msgClazz == AddCreditRequest.class) { - AddCreditRequest request = (AddCreditRequest) msg; - - outboundQueue.applyReaderOperation(request.receiverId(), - reader -> reader.addCredit(request.getCredit())); - } else { - LOGGER.warn("Received unexpected client request: {}", msg); - respondWithError(ctx, new IllegalArgumentException("unknown request:" + msg)); - } - } catch (Throwable t) { - respondWithError(ctx, t); + SequenceSliceReader reader = + new SequenceSliceReader(request.getReceiverId(), outboundQueue); + reader.createSliceReader( + request.getSliceId(), request.getStartBatchId(), request.getInitialCredit()); + + outboundQueue.notifyReaderCreated(reader); + } catch (Throwable notFound) { + respondWithError(ctx, notFound, request.getReceiverId()); } + } else if (msgClazz == CancelRequest.class) { + CancelRequest request = (CancelRequest) msg; + + outboundQueue.cancel(request.receiverId()); + } else if (msgClazz == CloseRequest.class) { + + outboundQueue.close(); + } else if (msgClazz == BatchRequest.class) { + BatchRequest request = (BatchRequest) msg; + + outboundQueue.applyReaderOperation( + request.receiverId(), reader -> reader.requestBatch(request.getNextBatchId())); + } else if (msgClazz == AddCreditRequest.class) { + AddCreditRequest request = (AddCreditRequest) msg; + + outboundQueue.applyReaderOperation( + request.receiverId(), reader -> reader.addCredit(request.getCredit())); + } else { + LOGGER.warn("Received unexpected client request: {}", msg); + respondWithError(ctx, new IllegalArgumentException("unknown request:" + msg)); + } + } catch (Throwable t) { + respondWithError(ctx, t); } + } - private void respondWithError(ChannelHandlerContext ctx, Throwable error) { - ctx.writeAndFlush(new ErrorResponse(error)); - } - - private void respondWithError(ChannelHandlerContext ctx, Throwable error, - ChannelId sourceId) { - ctx.writeAndFlush(new ErrorResponse(sourceId, error)); - } + private void respondWithError(ChannelHandlerContext ctx, Throwable error) { + ctx.writeAndFlush(new ErrorResponse(error)); + } + private void respondWithError(ChannelHandlerContext ctx, Throwable error, ChannelId sourceId) { + ctx.writeAndFlush(new ErrorResponse(sourceId, error)); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/AbstractFileRegion.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/AbstractFileRegion.java index 48be9c9e5..a29f6ded8 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/AbstractFileRegion.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/AbstractFileRegion.java @@ -26,62 +26,59 @@ * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. */ -/** - * This class is an adaptation of Spark's org.apache.spark.network.util.AbstractFileRegion. - */ +/** This class is an adaptation of Spark's org.apache.spark.network.util.AbstractFileRegion. */ public abstract class AbstractFileRegion extends AbstractReferenceCounted implements FileRegion { - protected int chunkSize = 64 * 1024 * 1024; - - protected long transferred; - protected long contentSize; - - public AbstractFileRegion(long contentSize) { - this.transferred = 0; - this.contentSize = contentSize; - } - - @Override - public final long transfered() { - return transferred(); - } - - @Override - public long position() { - return 0; - } - - @Override - public long transferred() { - return transferred; - } - - @Override - public long count() { - return contentSize; - } - - @Override - public AbstractFileRegion retain() { - super.retain(); - return this; - } - - @Override - public AbstractFileRegion retain(int increment) { - super.retain(increment); - return this; - } - - @Override - public AbstractFileRegion touch() { - super.touch(); - return this; - } - - @Override - public AbstractFileRegion touch(Object o) { - return this; - } - + protected int chunkSize = 64 * 1024 * 1024; + + protected long transferred; + protected long contentSize; + + public AbstractFileRegion(long contentSize) { + this.transferred = 0; + this.contentSize = contentSize; + } + + @Override + public final long transfered() { + return transferred(); + } + + @Override + public long position() { + return 0; + } + + @Override + public long transferred() { + return transferred; + } + + @Override + public long count() { + return contentSize; + } + + @Override + public AbstractFileRegion retain() { + super.retain(); + return this; + } + + @Override + public AbstractFileRegion retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public AbstractFileRegion touch() { + super.touch(); + return this; + } + + @Override + public AbstractFileRegion touch(Object o) { + return this; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/AddCreditRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/AddCreditRequest.java index d09016985..92aa8ce4c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/AddCreditRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/AddCreditRequest.java @@ -19,58 +19,60 @@ package org.apache.geaflow.shuffle.network.protocol; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; import java.io.IOException; + import org.apache.geaflow.shuffle.pipeline.channel.ChannelId; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; + public class AddCreditRequest extends NettyMessage { - private final ChannelId receiverId; - private final int credit; + private final ChannelId receiverId; + private final int credit; - public AddCreditRequest(int credit, ChannelId receiverId) { - this.receiverId = receiverId; - this.credit = credit; - } + public AddCreditRequest(int credit, ChannelId receiverId) { + this.receiverId = receiverId; + this.credit = credit; + } - public ChannelId receiverId() { - return receiverId; - } + public ChannelId receiverId() { + return receiverId; + } - public int getCredit() { - return credit; - } + public int getCredit() { + return credit; + } - @Override - public Object write(ByteBufAllocator allocator) throws Exception { - ByteBuf result = null; + @Override + public Object write(ByteBufAllocator allocator) throws Exception { + ByteBuf result = null; - try { - int length = Integer.BYTES + ChannelId.CHANNEL_ID_BYTES; - result = allocateBuffer(allocator, MessageType.ADD_CREDIT_REQUEST.getId(), length); - result.writeInt(credit); - receiverId.writeTo(result); + try { + int length = Integer.BYTES + ChannelId.CHANNEL_ID_BYTES; + result = allocateBuffer(allocator, MessageType.ADD_CREDIT_REQUEST.getId(), length); + result.writeInt(credit); + receiverId.writeTo(result); - return result; - } catch (Throwable t) { - if (result != null) { - result.release(); - } + return result; + } catch (Throwable t) { + if (result != null) { + result.release(); + } - throw new IOException(t); - } + throw new IOException(t); } + } - public static AddCreditRequest readFrom(ByteBuf buffer) { - int credit = buffer.readInt(); - ChannelId receiverId = ChannelId.readFrom(buffer); + public static AddCreditRequest readFrom(ByteBuf buffer) { + int credit = buffer.readInt(); + ChannelId receiverId = ChannelId.readFrom(buffer); - return new AddCreditRequest(credit, receiverId); - } + return new AddCreditRequest(credit, receiverId); + } - @Override - public String toString() { - return String.format("AddCredit(%s: %d)", receiverId, credit); - } + @Override + public String toString() { + return String.format("AddCredit(%s: %d)", receiverId, credit); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/BatchRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/BatchRequest.java index 0f961dedf..f4190968b 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/BatchRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/BatchRequest.java @@ -19,64 +19,63 @@ package org.apache.geaflow.shuffle.network.protocol; +import java.io.IOException; + +import org.apache.geaflow.shuffle.pipeline.channel.ChannelId; + import com.google.common.base.Preconditions; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import java.io.IOException; -import org.apache.geaflow.shuffle.pipeline.channel.ChannelId; -/** - * Incremental request sequence id from the client to the server. - */ +/** Incremental request sequence id from the client to the server. */ public class BatchRequest extends NettyMessage { - final long nextBatchId; - final ChannelId receiverId; + final long nextBatchId; + final ChannelId receiverId; - public BatchRequest(long nextBatchId, ChannelId receiverId) { - Preconditions.checkArgument(nextBatchId >= 0, "The sequence id should be positive"); - this.nextBatchId = nextBatchId; - this.receiverId = receiverId; - } + public BatchRequest(long nextBatchId, ChannelId receiverId) { + Preconditions.checkArgument(nextBatchId >= 0, "The sequence id should be positive"); + this.nextBatchId = nextBatchId; + this.receiverId = receiverId; + } - public ChannelId receiverId() { - return receiverId; - } + public ChannelId receiverId() { + return receiverId; + } - public long getNextBatchId() { - return nextBatchId; - } + public long getNextBatchId() { + return nextBatchId; + } - @Override - public ByteBuf write(ByteBufAllocator allocator) throws IOException { - ByteBuf result = null; + @Override + public ByteBuf write(ByteBufAllocator allocator) throws IOException { + ByteBuf result = null; - try { - result = allocateBuffer(allocator, - MessageType.FETCH_BATCH_REQUEST.getId(), 8 + 16); - result.writeLong(nextBatchId); - receiverId.writeTo(result); + try { + result = allocateBuffer(allocator, MessageType.FETCH_BATCH_REQUEST.getId(), 8 + 16); + result.writeLong(nextBatchId); + receiverId.writeTo(result); - return result; - } catch (Throwable t) { - if (result != null) { - result.release(); - } + return result; + } catch (Throwable t) { + if (result != null) { + result.release(); + } - throw new IOException(t); - } + throw new IOException(t); } + } - public static BatchRequest readFrom(ByteBuf buffer) { - long nextBatchId = buffer.readLong(); - ChannelId receiverId = ChannelId.readFrom(buffer); + public static BatchRequest readFrom(ByteBuf buffer) { + long nextBatchId = buffer.readLong(); + ChannelId receiverId = ChannelId.readFrom(buffer); - return new BatchRequest(nextBatchId, receiverId); - } - - @Override - public String toString() { - return String.format("BatchRequest(%s: %d)", receiverId, nextBatchId); - } + return new BatchRequest(nextBatchId, receiverId); + } + @Override + public String toString() { + return String.format("BatchRequest(%s: %d)", receiverId, nextBatchId); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/CancelRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/CancelRequest.java index 470cdd200..a317fd19c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/CancelRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/CancelRequest.java @@ -19,47 +19,47 @@ package org.apache.geaflow.shuffle.network.protocol; +import java.io.IOException; + +import org.apache.geaflow.shuffle.pipeline.channel.ChannelId; + import com.google.common.base.Preconditions; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import java.io.IOException; -import org.apache.geaflow.shuffle.pipeline.channel.ChannelId; -/** - * Message to notify producer to cancel. - */ +/** Message to notify producer to cancel. */ public class CancelRequest extends NettyMessage { - final ChannelId receiverId; - - public CancelRequest(ChannelId receiverId) { - this.receiverId = Preconditions.checkNotNull(receiverId); - } + final ChannelId receiverId; - public ChannelId receiverId() { - return receiverId; - } + public CancelRequest(ChannelId receiverId) { + this.receiverId = Preconditions.checkNotNull(receiverId); + } - @Override - public ByteBuf write(ByteBufAllocator allocator) throws Exception { - ByteBuf result = null; + public ChannelId receiverId() { + return receiverId; + } - try { - result = allocateBuffer(allocator, MessageType.CANCEL_CONNECTION.getId(), 16); - receiverId.writeTo(result); - } catch (Throwable t) { - if (result != null) { - result.release(); - } + @Override + public ByteBuf write(ByteBufAllocator allocator) throws Exception { + ByteBuf result = null; - throw new IOException(t); - } + try { + result = allocateBuffer(allocator, MessageType.CANCEL_CONNECTION.getId(), 16); + receiverId.writeTo(result); + } catch (Throwable t) { + if (result != null) { + result.release(); + } - return result; + throw new IOException(t); } - public static CancelRequest readFrom(ByteBuf buffer) throws Exception { - return new CancelRequest(ChannelId.readFrom(buffer)); - } + return result; + } + public static CancelRequest readFrom(ByteBuf buffer) throws Exception { + return new CancelRequest(ChannelId.readFrom(buffer)); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/CloseRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/CloseRequest.java index 3d46fc5be..ea7914759 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/CloseRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/CloseRequest.java @@ -24,13 +24,12 @@ public class CloseRequest extends NettyMessage { - @Override - public ByteBuf write(ByteBufAllocator allocator) throws Exception { - return allocateBuffer(allocator, MessageType.CLOSE_CONNECTION.getId(), 0); - } - - public static CloseRequest readFrom(@SuppressWarnings("unused") ByteBuf buffer) throws Exception { - return new CloseRequest(); - } + @Override + public ByteBuf write(ByteBufAllocator allocator) throws Exception { + return allocateBuffer(allocator, MessageType.CLOSE_CONNECTION.getId(), 0); + } + public static CloseRequest readFrom(@SuppressWarnings("unused") ByteBuf buffer) throws Exception { + return new CloseRequest(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/CompositeFileRegion.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/CompositeFileRegion.java index 6bb223d05..ee85972b4 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/CompositeFileRegion.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/CompositeFileRegion.java @@ -19,137 +19,138 @@ package org.apache.geaflow.shuffle.network.protocol; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; + import com.google.common.base.Preconditions; + import io.netty.buffer.ByteBuf; import io.netty.channel.FileRegion; import io.netty.util.ReferenceCountUtil; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.WritableByteChannel; /* This file is based on source code from the Spark Project (http://spark.apache.org/), licensed by the Apache * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. */ -/** - * This class is an adaptation of Spark's org.apache.spark.network.protocol.MessageWithHeader. - */ +/** This class is an adaptation of Spark's org.apache.spark.network.protocol.MessageWithHeader. */ public class CompositeFileRegion extends AbstractFileRegion { - private final ByteBuf header; - private final int headerLength; - private final Object body; - private long totalBytesTransferred; - - // When the write buffer size is larger than this limit, I/O will be done in chunks of this size. - // The size should not be too large as it will waste underlying memory copy. e.g. If network - // available buffer is smaller than this limit, the data cannot be sent within one single write - // operation while it still will make memory copy with this size. - private static final int NIO_BUFFER_LIMIT = 256 * 1024; - - /** - * Composite File Region. - * @param header the message header. - * @param body the message body. Must be either a {@link ByteBuf} or a {@link FileRegion}. - * @param contentSize the length of the message body and header, in bytes. - */ - public CompositeFileRegion(ByteBuf header, Object body, long contentSize) { - super(contentSize); - Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion, - "Body must be a ByteBuf or a FileRegion."); - this.header = header; - this.headerLength = header.readableBytes(); - this.body = body; + private final ByteBuf header; + private final int headerLength; + private final Object body; + private long totalBytesTransferred; + + // When the write buffer size is larger than this limit, I/O will be done in chunks of this size. + // The size should not be too large as it will waste underlying memory copy. e.g. If network + // available buffer is smaller than this limit, the data cannot be sent within one single write + // operation while it still will make memory copy with this size. + private static final int NIO_BUFFER_LIMIT = 256 * 1024; + + /** + * Composite File Region. + * + * @param header the message header. + * @param body the message body. Must be either a {@link ByteBuf} or a {@link FileRegion}. + * @param contentSize the length of the message body and header, in bytes. + */ + public CompositeFileRegion(ByteBuf header, Object body, long contentSize) { + super(contentSize); + Preconditions.checkArgument( + body instanceof ByteBuf || body instanceof FileRegion, + "Body must be a ByteBuf or a FileRegion."); + this.header = header; + this.headerLength = header.readableBytes(); + this.body = body; + } + + @Override + public long transferred() { + return totalBytesTransferred; + } + + /** + * This code is more complicated than you would think because we might require multiple transferTo + * invocations in order to transfer a single CompositeMessage to avoid busy waiting. + * + *

The contract is that the caller will ensure position is properly set to the total number of + * bytes transferred so far (i.e. value returned by transferred()). + */ + @Override + public long transferTo(final WritableByteChannel target, final long position) throws IOException { + Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position."); + // Bytes written for header in this call. + long writtenHeader = 0; + if (header.readableBytes() > 0) { + writtenHeader = copyByteBuf(header, target); + totalBytesTransferred += writtenHeader; + if (header.readableBytes() > 0) { + return writtenHeader; + } } - @Override - public long transferred() { - return totalBytesTransferred; + // Bytes written for body in this call. + long writtenBody = 0; + if (body instanceof FileRegion) { + writtenBody = ((FileRegion) body).transferTo(target, totalBytesTransferred - headerLength); + } else if (body instanceof ByteBuf) { + writtenBody = copyByteBuf((ByteBuf) body, target); } - - /** - * This code is more complicated than you would think because we might require multiple - * transferTo invocations in order to transfer a single CompositeMessage to avoid busy waiting. - * - *

The contract is that the caller will ensure position is properly set to the total number - * of bytes transferred so far (i.e. value returned by transferred()). - */ - @Override - public long transferTo(final WritableByteChannel target, final long position) throws IOException { - Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position."); - // Bytes written for header in this call. - long writtenHeader = 0; - if (header.readableBytes() > 0) { - writtenHeader = copyByteBuf(header, target); - totalBytesTransferred += writtenHeader; - if (header.readableBytes() > 0) { - return writtenHeader; - } + totalBytesTransferred += writtenBody; + + return writtenHeader + writtenBody; + } + + private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { + int length = Math.min(buf.readableBytes(), NIO_BUFFER_LIMIT); + // If the ByteBuf holds more then one ByteBuffer we should better call nioBuffers(...) + // to eliminate extra memory copies. + int written = 0; + if (buf.nioBufferCount() == 1) { + ByteBuffer buffer = buf.nioBuffer(buf.readerIndex(), length); + written = target.write(buffer); + } else { + ByteBuffer[] buffers = buf.nioBuffers(buf.readerIndex(), length); + for (ByteBuffer buffer : buffers) { + int remaining = buffer.remaining(); + int w = target.write(buffer); + written += w; + if (w < remaining) { + // Could not write all, we need to break now. + break; } - - // Bytes written for body in this call. - long writtenBody = 0; - if (body instanceof FileRegion) { - writtenBody = ((FileRegion) body).transferTo(target, totalBytesTransferred - headerLength); - } else if (body instanceof ByteBuf) { - writtenBody = copyByteBuf((ByteBuf) body, target); - } - totalBytesTransferred += writtenBody; - - return writtenHeader + writtenBody; + } } - - private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { - int length = Math.min(buf.readableBytes(), NIO_BUFFER_LIMIT); - // If the ByteBuf holds more then one ByteBuffer we should better call nioBuffers(...) - // to eliminate extra memory copies. - int written = 0; - if (buf.nioBufferCount() == 1) { - ByteBuffer buffer = buf.nioBuffer(buf.readerIndex(), length); - written = target.write(buffer); - } else { - ByteBuffer[] buffers = buf.nioBuffers(buf.readerIndex(), length); - for (ByteBuffer buffer : buffers) { - int remaining = buffer.remaining(); - int w = target.write(buffer); - written += w; - if (w < remaining) { - // Could not write all, we need to break now. - break; - } - } - } - buf.skipBytes(written); - return written; - } - - @Override - protected void deallocate() { - header.release(); - ReferenceCountUtil.release(body); - } - - @Override - public CompositeFileRegion touch(Object o) { - super.touch(o); - header.touch(o); - ReferenceCountUtil.touch(body, o); - return this; - } - - @Override - public CompositeFileRegion retain(int increment) { - super.retain(increment); - header.retain(increment); - ReferenceCountUtil.retain(body, increment); - return this; - } - - @Override - public boolean release(int decrement) { - header.release(decrement); - ReferenceCountUtil.release(body, decrement); - return super.release(decrement); - } - + buf.skipBytes(written); + return written; + } + + @Override + protected void deallocate() { + header.release(); + ReferenceCountUtil.release(body); + } + + @Override + public CompositeFileRegion touch(Object o) { + super.touch(o); + header.touch(o); + ReferenceCountUtil.touch(body, o); + return this; + } + + @Override + public CompositeFileRegion retain(int increment) { + super.retain(increment); + header.retain(increment); + ReferenceCountUtil.retain(body, increment); + return this; + } + + @Override + public boolean release(int decrement) { + header.release(decrement); + ReferenceCountUtil.release(body, decrement); + return super.release(decrement); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/ErrorResponse.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/ErrorResponse.java index 07a88fc82..3405af78a 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/ErrorResponse.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/ErrorResponse.java @@ -19,74 +19,75 @@ package org.apache.geaflow.shuffle.network.protocol; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; + +import org.apache.geaflow.shuffle.pipeline.channel.ChannelId; + import com.google.common.base.Preconditions; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufInputStream; import io.netty.buffer.ByteBufOutputStream; -import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import org.apache.geaflow.shuffle.pipeline.channel.ChannelId; public class ErrorResponse extends NettyMessage { - private ChannelId channelId; - private final Throwable cause; + private ChannelId channelId; + private final Throwable cause; - public ErrorResponse(ChannelId channelId, Throwable cause) { - this.channelId = channelId; - this.cause = Preconditions.checkNotNull(cause); - } + public ErrorResponse(ChannelId channelId, Throwable cause) { + this.channelId = channelId; + this.cause = Preconditions.checkNotNull(cause); + } - public ErrorResponse(Throwable cause) { - this.cause = Preconditions.checkNotNull(cause); - } + public ErrorResponse(Throwable cause) { + this.cause = Preconditions.checkNotNull(cause); + } - @Override - public ByteBuf write(ByteBufAllocator allocator) throws IOException { - final ByteBuf result = allocateBuffer(allocator, MessageType.ERROR_RESPONSE.getId()); - channelId.writeTo(result); - - try (ObjectOutputStream oos = new ObjectOutputStream(new ByteBufOutputStream(result))) { - oos.writeObject(cause); - result.setInt(0, result.readableBytes()); - return result; - } catch (Throwable t) { - result.release(); - if (t instanceof IOException) { - throw (IOException) t; - } else { - throw new IOException(t); - } - } - } + @Override + public ByteBuf write(ByteBufAllocator allocator) throws IOException { + final ByteBuf result = allocateBuffer(allocator, MessageType.ERROR_RESPONSE.getId()); + channelId.writeTo(result); - public static ErrorResponse readFrom(ByteBuf buffer) throws Exception { - ChannelId channelId = ChannelId.readFrom(buffer); + try (ObjectOutputStream oos = new ObjectOutputStream(new ByteBufOutputStream(result))) { + oos.writeObject(cause); + result.setInt(0, result.readableBytes()); + return result; + } catch (Throwable t) { + result.release(); + if (t instanceof IOException) { + throw (IOException) t; + } else { + throw new IOException(t); + } + } + } - try (ObjectInputStream ois = new ObjectInputStream(new ByteBufInputStream(buffer))) { - Object obj = ois.readObject(); + public static ErrorResponse readFrom(ByteBuf buffer) throws Exception { + ChannelId channelId = ChannelId.readFrom(buffer); - if (!(obj instanceof Throwable)) { - throw new ClassCastException( - "Read object expected to be of type Throwable, " + "actual type is " + obj - .getClass()); - } - return new ErrorResponse(channelId, (Throwable) obj); - } - } + try (ObjectInputStream ois = new ObjectInputStream(new ByteBufInputStream(buffer))) { + Object obj = ois.readObject(); - public ChannelId getChannelId() { - return channelId; + if (!(obj instanceof Throwable)) { + throw new ClassCastException( + "Read object expected to be of type Throwable, " + "actual type is " + obj.getClass()); + } + return new ErrorResponse(channelId, (Throwable) obj); } + } - public Throwable getCause() { - return cause; - } + public ChannelId getChannelId() { + return channelId; + } - public boolean isFatalError() { - return channelId == null; - } + public Throwable getCause() { + return cause; + } + public boolean isFatalError() { + return channelId == null; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/MemoryBytesFileRegion.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/MemoryBytesFileRegion.java index 7b0d4e0be..ca201da7e 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/MemoryBytesFileRegion.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/MemoryBytesFileRegion.java @@ -27,55 +27,50 @@ * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. */ -/** - * This class is an adaptation of Spark's org.apache.spark.util.io.ChunkedByteBufferFileRegion. - */ +/** This class is an adaptation of Spark's org.apache.spark.util.io.ChunkedByteBufferFileRegion. */ public class MemoryBytesFileRegion extends AbstractFileRegion { - private ByteBuffer curBuffer; + private ByteBuffer curBuffer; - public MemoryBytesFileRegion(byte[] buffer) { - super(buffer.length); - this.curBuffer = ByteBuffer.wrap(buffer); - } + public MemoryBytesFileRegion(byte[] buffer) { + super(buffer.length); + this.curBuffer = ByteBuffer.wrap(buffer); + } - @Override - public long transferTo(WritableByteChannel target, long position) throws IOException { - assert (position == transferred); - if (position == contentSize) { - return 0L; - } - boolean keepGoing = true; - long written = 0L; - ByteBuffer currentBuffer = curBuffer; + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + assert (position == transferred); + if (position == contentSize) { + return 0L; + } + boolean keepGoing = true; + long written = 0L; + ByteBuffer currentBuffer = curBuffer; - while (keepGoing) { - while (currentBuffer.hasRemaining() && keepGoing) { - int ioSize = Math.min(currentBuffer.remaining(), chunkSize); - int originalLimit = currentBuffer.limit(); - currentBuffer.limit(currentBuffer.position() + ioSize); - int writtenSize = target.write(currentBuffer); - currentBuffer.limit(originalLimit); - written += writtenSize; - if (writtenSize < ioSize) { - // the channel did not accept our entire write. We do *not* keep trying -- netty wants - // us to just stop, and report how much we've written. - keepGoing = false; - } - } - if (keepGoing) { - curBuffer = null; - break; - } + while (keepGoing) { + while (currentBuffer.hasRemaining() && keepGoing) { + int ioSize = Math.min(currentBuffer.remaining(), chunkSize); + int originalLimit = currentBuffer.limit(); + currentBuffer.limit(currentBuffer.position() + ioSize); + int writtenSize = target.write(currentBuffer); + currentBuffer.limit(originalLimit); + written += writtenSize; + if (writtenSize < ioSize) { + // the channel did not accept our entire write. We do *not* keep trying -- netty wants + // us to just stop, and report how much we've written. + keepGoing = false; } - - transferred += written; - return written; + } + if (keepGoing) { + curBuffer = null; + break; + } } - @Override - protected void deallocate() { - - } + transferred += written; + return written; + } + @Override + protected void deallocate() {} } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/MemoryViewFileRegion.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/MemoryViewFileRegion.java index 32221d427..455613cf9 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/MemoryViewFileRegion.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/MemoryViewFileRegion.java @@ -23,62 +23,61 @@ import java.nio.ByteBuffer; import java.nio.channels.WritableByteChannel; import java.util.Iterator; + import org.apache.geaflow.memory.ByteBuf; import org.apache.geaflow.memory.MemoryView; public class MemoryViewFileRegion extends AbstractFileRegion { - private final Iterator bufIterator; - private ByteBuffer curBuffer; + private final Iterator bufIterator; + private ByteBuffer curBuffer; - public MemoryViewFileRegion(MemoryView memoryView) { - super(memoryView.contentSize()); - this.bufIterator = memoryView.getBufList().iterator(); - this.curBuffer = this.getNextBuffer(); - } + public MemoryViewFileRegion(MemoryView memoryView) { + super(memoryView.contentSize()); + this.bufIterator = memoryView.getBufList().iterator(); + this.curBuffer = this.getNextBuffer(); + } - @Override - public long transferTo(WritableByteChannel target, long position) throws IOException { - if (position == this.contentSize) { - return 0L; + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + if (position == this.contentSize) { + return 0L; + } + boolean keepGoing = true; + long written = 0L; + ByteBuffer currentBuffer = this.curBuffer; + while (keepGoing) { + while (currentBuffer.hasRemaining() && keepGoing) { + int ioSize = Math.min(currentBuffer.remaining(), this.chunkSize); + int originalLimit = currentBuffer.limit(); + currentBuffer.limit(currentBuffer.position() + ioSize); + int writtenSize = target.write(currentBuffer); + currentBuffer.limit(originalLimit); + written += writtenSize; + if (writtenSize < ioSize) { + keepGoing = false; } - boolean keepGoing = true; - long written = 0L; - ByteBuffer currentBuffer = this.curBuffer; - while (keepGoing) { - while (currentBuffer.hasRemaining() && keepGoing) { - int ioSize = Math.min(currentBuffer.remaining(), this.chunkSize); - int originalLimit = currentBuffer.limit(); - currentBuffer.limit(currentBuffer.position() + ioSize); - int writtenSize = target.write(currentBuffer); - currentBuffer.limit(originalLimit); - written += writtenSize; - if (writtenSize < ioSize) { - keepGoing = false; - } - } - if (keepGoing) { - if (!this.bufIterator.hasNext()) { - keepGoing = false; - } else { - currentBuffer = this.getNextBuffer(); - } - } + } + if (keepGoing) { + if (!this.bufIterator.hasNext()) { + keepGoing = false; + } else { + currentBuffer = this.getNextBuffer(); } - - this.transferred += written; - return written; + } } - @Override - protected void deallocate() { - } + this.transferred += written; + return written; + } - private ByteBuffer getNextBuffer() { - ByteBuffer buffer = this.bufIterator.next().getBf().duplicate(); - buffer.flip(); - this.curBuffer = buffer; - return buffer; - } + @Override + protected void deallocate() {} + private ByteBuffer getNextBuffer() { + ByteBuffer buffer = this.bufIterator.next().getBf().duplicate(); + buffer.flip(); + this.curBuffer = buffer; + return buffer; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/MessageType.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/MessageType.java index da42bfecb..9dc81fe7a 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/MessageType.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/MessageType.java @@ -22,45 +22,43 @@ import com.google.common.base.Preconditions; public enum MessageType { + ERROR_RESPONSE(1), + FETCH_SLICE_REQUEST(2), + FETCH_SLICE_RESPONSE(3), + FETCH_BATCH_REQUEST(4), + CLOSE_CONNECTION(5), + CANCEL_CONNECTION(6), + ADD_CREDIT_REQUEST(7); - ERROR_RESPONSE(1), - FETCH_SLICE_REQUEST(2), - FETCH_SLICE_RESPONSE(3), - FETCH_BATCH_REQUEST(4), - CLOSE_CONNECTION(5), - CANCEL_CONNECTION(6), - ADD_CREDIT_REQUEST(7); + private final byte id; - private final byte id; + MessageType(int id) { + Preconditions.checkArgument(id < 128, "Cannot have more than 128 message types"); + this.id = (byte) id; + } - MessageType(int id) { - Preconditions.checkArgument(id < 128, "Cannot have more than 128 message types"); - this.id = (byte) id; - } - - public static MessageType decode(byte id) { - switch (id) { - case 1: - return ERROR_RESPONSE; - case 2: - return FETCH_SLICE_REQUEST; - case 3: - return FETCH_SLICE_RESPONSE; - case 4: - return FETCH_BATCH_REQUEST; - case 5: - return CLOSE_CONNECTION; - case 6: - return CANCEL_CONNECTION; - case 7: - return ADD_CREDIT_REQUEST; - default: - throw new IllegalArgumentException("unrecognized MessageType:" + id); - } - } - - public byte getId() { - return id; + public static MessageType decode(byte id) { + switch (id) { + case 1: + return ERROR_RESPONSE; + case 2: + return FETCH_SLICE_REQUEST; + case 3: + return FETCH_SLICE_RESPONSE; + case 4: + return FETCH_BATCH_REQUEST; + case 5: + return CLOSE_CONNECTION; + case 6: + return CANCEL_CONNECTION; + case 7: + return ADD_CREDIT_REQUEST; + default: + throw new IllegalArgumentException("unrecognized MessageType:" + id); } + } + public byte getId() { + return id; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/NettyMessage.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/NettyMessage.java index c5bfeca95..6ab9e132d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/NettyMessage.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/NettyMessage.java @@ -19,10 +19,12 @@ package org.apache.geaflow.shuffle.network.protocol; +import java.io.Serializable; + import com.google.common.base.Preconditions; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import java.io.Serializable; /* This file is based on source code from the Flink Project (http://flink.apache.org/), licensed by the Apache * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for @@ -33,75 +35,76 @@ */ public abstract class NettyMessage implements Serializable { - // ------------------------------------------------------------------------ - // Note: Every NettyMessage subtype needs to have a public 0-argument - // constructor in order to work with the generic deserializer. - // ------------------------------------------------------------------------ + // ------------------------------------------------------------------------ + // Note: Every NettyMessage subtype needs to have a public 0-argument + // constructor in order to work with the generic deserializer. + // ------------------------------------------------------------------------ - // frame length (4), magic number (4), msg ID (1) - public static final int FRAME_HEADER_LENGTH = 4 + 4 + 1; - public static final int MAGIC_NUMBER = 0xBADC0FFE; + // frame length (4), magic number (4), msg ID (1) + public static final int FRAME_HEADER_LENGTH = 4 + 4 + 1; + public static final int MAGIC_NUMBER = 0xBADC0FFE; - protected static ByteBuf allocateBuffer(ByteBufAllocator allocator, byte id) { - return allocateBuffer(allocator, id, -1); - } - - /** - * Allocates a new (header and contents) buffer and adds some header information for the frame - * decoder. - * - *

If the contentLength is unknown, you must write the actual length after adding - * the contents as an integer to position 0! - * - * @param allocator byte buffer allocator to use. - * @param id {@link NettyMessage} subclass ID. - * @param contentLength content length (or -1 if unknown). - * @return a newly allocated direct buffer with header data written for decoder. - */ - static ByteBuf allocateBuffer(ByteBufAllocator allocator, byte id, int contentLength) { - return allocateBuffer(allocator, id, 0, contentLength, true); - } + protected static ByteBuf allocateBuffer(ByteBufAllocator allocator, byte id) { + return allocateBuffer(allocator, id, -1); + } - /** - * Allocates a new buffer and adds some header information for the frame decoder. - * - *

If the contentLength is unknown, you must write the actual length after adding - * the contents as an integer to position 0! - * - * @param allocator byte buffer allocator to use. - * @param id {@link NettyMessage} subclass ID. - * @param messageHeaderLength additional header length that should be part of the allocated - * buffer and is written outside of this method. - * @param contentLength content length (or -1 if unknown). - * @param allocateForContent whether to make room for the actual content in the buffer - * (true) or whether to only return a buffer with the header information - * (false). - * @return a newly allocated direct buffer with header data written for decoder. - */ - public static ByteBuf allocateBuffer(ByteBufAllocator allocator, byte id, - int messageHeaderLength, int contentLength, - boolean allocateForContent) { - Preconditions.checkArgument(contentLength <= Integer.MAX_VALUE - FRAME_HEADER_LENGTH); + /** + * Allocates a new (header and contents) buffer and adds some header information for the frame + * decoder. + * + *

If the contentLength is unknown, you must write the actual length after adding the + * contents as an integer to position 0! + * + * @param allocator byte buffer allocator to use. + * @param id {@link NettyMessage} subclass ID. + * @param contentLength content length (or -1 if unknown). + * @return a newly allocated direct buffer with header data written for decoder. + */ + static ByteBuf allocateBuffer(ByteBufAllocator allocator, byte id, int contentLength) { + return allocateBuffer(allocator, id, 0, contentLength, true); + } - final ByteBuf buffer; - if (!allocateForContent) { - buffer = allocator.buffer(FRAME_HEADER_LENGTH + messageHeaderLength); - } else if (contentLength != -1) { - buffer = allocator.buffer(FRAME_HEADER_LENGTH + messageHeaderLength + contentLength); - } else { - // Content length unknown -> start with the default initial size (rather than - // FRAME_HEADER_LENGTH only): - buffer = allocator.buffer(); - } - // May be updated later, e.g. if contentLength == -1 - buffer.writeInt(FRAME_HEADER_LENGTH + messageHeaderLength + contentLength); - buffer.writeInt(MAGIC_NUMBER); - buffer.writeByte(id); + /** + * Allocates a new buffer and adds some header information for the frame decoder. + * + *

If the contentLength is unknown, you must write the actual length after adding the + * contents as an integer to position 0! + * + * @param allocator byte buffer allocator to use. + * @param id {@link NettyMessage} subclass ID. + * @param messageHeaderLength additional header length that should be part of the allocated buffer + * and is written outside of this method. + * @param contentLength content length (or -1 if unknown). + * @param allocateForContent whether to make room for the actual content in the buffer + * (true) or whether to only return a buffer with the header information + * (false). + * @return a newly allocated direct buffer with header data written for decoder. + */ + public static ByteBuf allocateBuffer( + ByteBufAllocator allocator, + byte id, + int messageHeaderLength, + int contentLength, + boolean allocateForContent) { + Preconditions.checkArgument(contentLength <= Integer.MAX_VALUE - FRAME_HEADER_LENGTH); - return buffer; + final ByteBuf buffer; + if (!allocateForContent) { + buffer = allocator.buffer(FRAME_HEADER_LENGTH + messageHeaderLength); + } else if (contentLength != -1) { + buffer = allocator.buffer(FRAME_HEADER_LENGTH + messageHeaderLength + contentLength); + } else { + // Content length unknown -> start with the default initial size (rather than + // FRAME_HEADER_LENGTH only): + buffer = allocator.buffer(); } + // May be updated later, e.g. if contentLength == -1 + buffer.writeInt(FRAME_HEADER_LENGTH + messageHeaderLength + contentLength); + buffer.writeInt(MAGIC_NUMBER); + buffer.writeByte(id); - public abstract Object write(ByteBufAllocator allocator) throws Exception; - + return buffer; + } + public abstract Object write(ByteBufAllocator allocator) throws Exception; } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/NettyMessageDecoder.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/NettyMessageDecoder.java index 87a421022..def3518f6 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/NettyMessageDecoder.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/NettyMessageDecoder.java @@ -19,56 +19,55 @@ package org.apache.geaflow.shuffle.network.protocol; +import java.net.ProtocolException; +import java.util.List; + import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.MessageToMessageDecoder; -import java.net.ProtocolException; -import java.util.List; @ChannelHandler.Sharable public class NettyMessageDecoder extends MessageToMessageDecoder { - @Override - protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List out) - throws Exception { - int magicNumber = msg.readInt(); - - if (magicNumber != NettyMessage.MAGIC_NUMBER) { - throw new IllegalStateException( - "Network stream corrupted: received incorrect magic number."); - } + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List out) throws Exception { + int magicNumber = msg.readInt(); - byte msgId = msg.readByte(); - MessageType msgType = MessageType.decode(msgId); + if (magicNumber != NettyMessage.MAGIC_NUMBER) { + throw new IllegalStateException("Network stream corrupted: received incorrect magic number."); + } - final NettyMessage decodedMsg; - switch (msgType) { - case ERROR_RESPONSE: - decodedMsg = ErrorResponse.readFrom(msg); - break; - case FETCH_SLICE_REQUEST: - decodedMsg = SliceRequest.readFrom(msg); - break; - case FETCH_SLICE_RESPONSE: - decodedMsg = SliceResponse.readFrom(msg); - break; - case FETCH_BATCH_REQUEST: - decodedMsg = BatchRequest.readFrom(msg); - break; - case CLOSE_CONNECTION: - decodedMsg = CloseRequest.readFrom(msg); - break; - case CANCEL_CONNECTION: - decodedMsg = CancelRequest.readFrom(msg); - break; - case ADD_CREDIT_REQUEST: - decodedMsg = AddCreditRequest.readFrom(msg); - break; - default: - throw new ProtocolException("Received unknown message from producer: " + msg); - } + byte msgId = msg.readByte(); + MessageType msgType = MessageType.decode(msgId); - out.add(decodedMsg); + final NettyMessage decodedMsg; + switch (msgType) { + case ERROR_RESPONSE: + decodedMsg = ErrorResponse.readFrom(msg); + break; + case FETCH_SLICE_REQUEST: + decodedMsg = SliceRequest.readFrom(msg); + break; + case FETCH_SLICE_RESPONSE: + decodedMsg = SliceResponse.readFrom(msg); + break; + case FETCH_BATCH_REQUEST: + decodedMsg = BatchRequest.readFrom(msg); + break; + case CLOSE_CONNECTION: + decodedMsg = CloseRequest.readFrom(msg); + break; + case CANCEL_CONNECTION: + decodedMsg = CancelRequest.readFrom(msg); + break; + case ADD_CREDIT_REQUEST: + decodedMsg = AddCreditRequest.readFrom(msg); + break; + default: + throw new ProtocolException("Received unknown message from producer: " + msg); } + + out.add(decodedMsg); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/NettyMessageEncoder.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/NettyMessageEncoder.java index 0651c7e98..96fe5a756 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/NettyMessageEncoder.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/NettyMessageEncoder.java @@ -19,32 +19,32 @@ package org.apache.geaflow.shuffle.network.protocol; +import java.io.IOException; + import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPromise; -import java.io.IOException; @ChannelHandler.Sharable public class NettyMessageEncoder extends ChannelOutboundHandlerAdapter { - @Override - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) - throws Exception { - if (msg instanceof NettyMessage) { - Object serialized = null; - try { - serialized = ((NettyMessage) msg).write(ctx.alloc()); - } catch (Throwable t) { - throw new IOException("Error while serializing message: " + msg, t); - } finally { - if (serialized != null) { - ctx.write(serialized, promise); - } - } - } else { - ctx.write(msg, promise); + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + if (msg instanceof NettyMessage) { + Object serialized = null; + try { + serialized = ((NettyMessage) msg).write(ctx.alloc()); + } catch (Throwable t) { + throw new IOException("Error while serializing message: " + msg, t); + } finally { + if (serialized != null) { + ctx.write(serialized, promise); } + } + } else { + ctx.write(msg, promise); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/SliceRequest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/SliceRequest.java index 0e2854e25..f6ffcdd02 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/SliceRequest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/SliceRequest.java @@ -19,79 +19,80 @@ package org.apache.geaflow.shuffle.network.protocol; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; import java.io.IOException; + import org.apache.geaflow.shuffle.message.SliceId; import org.apache.geaflow.shuffle.pipeline.channel.ChannelId; -public class SliceRequest extends NettyMessage { - - final SliceId sliceId; - final long startBatchId; - final ChannelId receiverId; - final int initialCredit; - - public SliceRequest(SliceId sliceId, long startBatchId, ChannelId receiverId, - int initialCredits) { - this.sliceId = sliceId; - this.startBatchId = startBatchId; - this.receiverId = receiverId; - this.initialCredit = initialCredits; - } - - public ChannelId getReceiverId() { - return receiverId; - } - - public SliceId getSliceId() { - return sliceId; - } - - public long getStartBatchId() { - return startBatchId; - } - - public int getInitialCredit() { - return initialCredit; - } - - @Override - public ByteBuf write(ByteBufAllocator allocator) throws IOException { - ByteBuf result = null; - - try { - int length = - SliceId.SLICE_ID_BYTES + ChannelId.CHANNEL_ID_BYTES + Long.BYTES + Integer.BYTES; - result = allocateBuffer(allocator, MessageType.FETCH_SLICE_REQUEST.getId(), length); - - sliceId.writeTo(result); - receiverId.writeTo(result); - result.writeLong(startBatchId); - result.writeInt(initialCredit); - - return result; - } catch (Throwable t) { - if (result != null) { - result.release(); - } - throw new IOException(t); - } - } - - public static SliceRequest readFrom(ByteBuf buffer) { - SliceId sliceId = SliceId.readFrom(buffer); - ChannelId receiverId = ChannelId.readFrom(buffer); - long startBatchId = buffer.readLong(); - int initialCredits = buffer.readInt(); +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; - return new SliceRequest(sliceId, startBatchId, receiverId, initialCredits); - } +public class SliceRequest extends NettyMessage { - @Override - public String toString() { - return String.format("SliceFetchRequest(%s, startBatchId=%s, initCredit=%s)", sliceId, - startBatchId, initialCredit); + final SliceId sliceId; + final long startBatchId; + final ChannelId receiverId; + final int initialCredit; + + public SliceRequest( + SliceId sliceId, long startBatchId, ChannelId receiverId, int initialCredits) { + this.sliceId = sliceId; + this.startBatchId = startBatchId; + this.receiverId = receiverId; + this.initialCredit = initialCredits; + } + + public ChannelId getReceiverId() { + return receiverId; + } + + public SliceId getSliceId() { + return sliceId; + } + + public long getStartBatchId() { + return startBatchId; + } + + public int getInitialCredit() { + return initialCredit; + } + + @Override + public ByteBuf write(ByteBufAllocator allocator) throws IOException { + ByteBuf result = null; + + try { + int length = SliceId.SLICE_ID_BYTES + ChannelId.CHANNEL_ID_BYTES + Long.BYTES + Integer.BYTES; + result = allocateBuffer(allocator, MessageType.FETCH_SLICE_REQUEST.getId(), length); + + sliceId.writeTo(result); + receiverId.writeTo(result); + result.writeLong(startBatchId); + result.writeInt(initialCredit); + + return result; + } catch (Throwable t) { + if (result != null) { + result.release(); + } + throw new IOException(t); } - + } + + public static SliceRequest readFrom(ByteBuf buffer) { + SliceId sliceId = SliceId.readFrom(buffer); + ChannelId receiverId = ChannelId.readFrom(buffer); + long startBatchId = buffer.readLong(); + int initialCredits = buffer.readInt(); + + return new SliceRequest(sliceId, startBatchId, receiverId, initialCredits); + } + + @Override + public String toString() { + return String.format( + "SliceFetchRequest(%s, startBatchId=%s, initCredit=%s)", + sliceId, startBatchId, initialCredit); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/SliceResponse.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/SliceResponse.java index 5f8c04ff4..a5a8a19b0 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/SliceResponse.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/network/protocol/SliceResponse.java @@ -19,94 +19,93 @@ package org.apache.geaflow.shuffle.network.protocol; +import org.apache.geaflow.shuffle.pipeline.buffer.PipeBuffer; +import org.apache.geaflow.shuffle.pipeline.channel.ChannelId; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.channel.FileRegion; -import org.apache.geaflow.shuffle.pipeline.buffer.PipeBuffer; -import org.apache.geaflow.shuffle.pipeline.channel.ChannelId; public class SliceResponse extends NettyMessage { - final PipeBuffer buffer; - final ChannelId receiverId; - final int sequenceNumber; - final int bufferSize; - - public SliceResponse(PipeBuffer buffer, int sequenceNumber, - ChannelId inputChannelId) { - this.buffer = buffer; - this.sequenceNumber = sequenceNumber; - this.receiverId = inputChannelId; - this.bufferSize = buffer.getBuffer() != null ? buffer.getBufferSize() : 0; - } - - public PipeBuffer getBuffer() { - return buffer; - } - - public ChannelId getReceiverId() { - return receiverId; + final PipeBuffer buffer; + final ChannelId receiverId; + final int sequenceNumber; + final int bufferSize; + + public SliceResponse(PipeBuffer buffer, int sequenceNumber, ChannelId inputChannelId) { + this.buffer = buffer; + this.sequenceNumber = sequenceNumber; + this.receiverId = inputChannelId; + this.bufferSize = buffer.getBuffer() != null ? buffer.getBufferSize() : 0; + } + + public PipeBuffer getBuffer() { + return buffer; + } + + public ChannelId getReceiverId() { + return receiverId; + } + + public int getSequenceNumber() { + return sequenceNumber; + } + + public int getBufferSize() { + return bufferSize; + } + + @Override + public Object write(ByteBufAllocator allocator) throws Exception { + if (buffer.isData()) { + int headerLen = 16 + 8 + 4 + 1; + int contentSize = buffer.getBufferSize(); + // Only allocate header buffer - we will combine it with the data buffer below. + ByteBuf headerBuf = + allocateBuffer( + allocator, MessageType.FETCH_SLICE_RESPONSE.getId(), headerLen, contentSize, false); + + receiverId.writeTo(headerBuf); + headerBuf.writeLong(buffer.getBatchId()); + headerBuf.writeInt(sequenceNumber); + headerBuf.writeBoolean(buffer.isData()); + + int totalSize = headerBuf.readableBytes() + contentSize; + headerBuf.setInt(0, totalSize); + + FileRegion body = buffer.getBuffer().toFileRegion(); + return new CompositeFileRegion(headerBuf, body, totalSize); + } else { + final ByteBuf result = allocateBuffer(allocator, MessageType.FETCH_SLICE_RESPONSE.getId()); + receiverId.writeTo(result); + result.writeLong(buffer.getBatchId()); + result.writeInt(sequenceNumber); + result.writeBoolean(buffer.isData()); + result.writeInt(buffer.getCount()); + result.writeBoolean(buffer.isFinish()); + result.setInt(0, result.readableBytes()); + return result; } - - public int getSequenceNumber() { - return sequenceNumber; - } - - public int getBufferSize() { - return bufferSize; - } - - @Override - public Object write(ByteBufAllocator allocator) throws Exception { - if (buffer.isData()) { - int headerLen = 16 + 8 + 4 + 1; - int contentSize = buffer.getBufferSize(); - // Only allocate header buffer - we will combine it with the data buffer below. - ByteBuf headerBuf = allocateBuffer(allocator, MessageType.FETCH_SLICE_RESPONSE.getId(), - headerLen, contentSize, false); - - receiverId.writeTo(headerBuf); - headerBuf.writeLong(buffer.getBatchId()); - headerBuf.writeInt(sequenceNumber); - headerBuf.writeBoolean(buffer.isData()); - - int totalSize = headerBuf.readableBytes() + contentSize; - headerBuf.setInt(0, totalSize); - - FileRegion body = buffer.getBuffer().toFileRegion(); - return new CompositeFileRegion(headerBuf, body, totalSize); - } else { - final ByteBuf result = allocateBuffer(allocator, MessageType.FETCH_SLICE_RESPONSE.getId()); - receiverId.writeTo(result); - result.writeLong(buffer.getBatchId()); - result.writeInt(sequenceNumber); - result.writeBoolean(buffer.isData()); - result.writeInt(buffer.getCount()); - result.writeBoolean(buffer.isFinish()); - result.setInt(0, result.readableBytes()); - return result; - } - - } - - public static SliceResponse readFrom(ByteBuf buf) throws Exception { - ChannelId inputChannelId = ChannelId.readFrom(buf); - long batchId = buf.readLong(); - int sequenceNum = buf.readInt(); - boolean isData = buf.readBoolean(); - - PipeBuffer recordBuffer; - if (isData) { - byte[] bytes = new byte[buf.readableBytes()]; - buf.readBytes(bytes); - recordBuffer = new PipeBuffer(bytes, batchId); - } else { - int count = buf.readInt(); - boolean isFinish = buf.readBoolean(); - recordBuffer = new PipeBuffer(batchId, count, isFinish); - } - - return new SliceResponse(recordBuffer, sequenceNum, inputChannelId); + } + + public static SliceResponse readFrom(ByteBuf buf) throws Exception { + ChannelId inputChannelId = ChannelId.readFrom(buf); + long batchId = buf.readLong(); + int sequenceNum = buf.readInt(); + boolean isData = buf.readBoolean(); + + PipeBuffer recordBuffer; + if (isData) { + byte[] bytes = new byte[buf.readableBytes()]; + buf.readBytes(bytes); + recordBuffer = new PipeBuffer(bytes, batchId); + } else { + int count = buf.readInt(); + boolean isFinish = buf.readBoolean(); + recordBuffer = new PipeBuffer(batchId, count, isFinish); } + return new SliceResponse(recordBuffer, sequenceNum, inputChannelId); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/AbstractBuffer.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/AbstractBuffer.java index 0a89fbfd9..8be876058 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/AbstractBuffer.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/AbstractBuffer.java @@ -23,53 +23,51 @@ public abstract class AbstractBuffer implements OutBuffer { - private final ShuffleMemoryTracker memoryTracker; + private final ShuffleMemoryTracker memoryTracker; - public AbstractBuffer(boolean enableMemoryTrack) { - this.memoryTracker = enableMemoryTrack - ? ShuffleManager.getInstance().getShuffleMemoryTracker() : null; - } + public AbstractBuffer(boolean enableMemoryTrack) { + this.memoryTracker = + enableMemoryTrack ? ShuffleManager.getInstance().getShuffleMemoryTracker() : null; + } - public AbstractBuffer(ShuffleMemoryTracker memoryTracker) { - this.memoryTracker = memoryTracker; - } + public AbstractBuffer(ShuffleMemoryTracker memoryTracker) { + this.memoryTracker = memoryTracker; + } - protected void requireMemory(long dataSize) { - if (this.memoryTracker != null) { - memoryTracker.requireMemory(dataSize); - } + protected void requireMemory(long dataSize) { + if (this.memoryTracker != null) { + memoryTracker.requireMemory(dataSize); } + } - protected void releaseMemory(long dataSize) { - if (this.memoryTracker != null) { - memoryTracker.releaseMemory(dataSize); - } + protected void releaseMemory(long dataSize) { + if (this.memoryTracker != null) { + memoryTracker.releaseMemory(dataSize); } + } - protected abstract static class AbstractBufferBuilder implements BufferBuilder { - - private long recordCount; - protected boolean memoryTrack; - - @Override - public long getRecordCount() { - return this.recordCount; - } + protected abstract static class AbstractBufferBuilder implements BufferBuilder { - @Override - public void increaseRecordCount() { - this.recordCount++; - } + private long recordCount; + protected boolean memoryTrack; - protected void resetRecordCount() { - this.recordCount = 0; - } + @Override + public long getRecordCount() { + return this.recordCount; + } - @Override - public void enableMemoryTrack() { - this.memoryTrack = true; - } + @Override + public void increaseRecordCount() { + this.recordCount++; + } + protected void resetRecordCount() { + this.recordCount = 0; } + @Override + public void enableMemoryTrack() { + this.memoryTrack = true; + } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/HeapBuffer.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/HeapBuffer.java index fa7e21a2d..e0a8ef892 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/HeapBuffer.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/HeapBuffer.java @@ -19,106 +19,106 @@ package org.apache.geaflow.shuffle.pipeline.buffer; -import io.netty.channel.FileRegion; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.shuffle.network.protocol.MemoryBytesFileRegion; -public class HeapBuffer extends AbstractBuffer { +import io.netty.channel.FileRegion; - private static final int INITIAL_BUFFER_SIZE = 4096; - private byte[] bytes; +public class HeapBuffer extends AbstractBuffer { - public HeapBuffer(byte[] bytes) { - this(bytes, true); + private static final int INITIAL_BUFFER_SIZE = 4096; + private byte[] bytes; + + public HeapBuffer(byte[] bytes) { + this(bytes, true); + } + + public HeapBuffer(byte[] bytes, ShuffleMemoryTracker memoryTracker) { + super(memoryTracker); + this.bytes = bytes; + this.requireMemory(bytes.length); + } + + public HeapBuffer(byte[] bytes, boolean memoryTrack) { + super(memoryTrack); + this.bytes = bytes; + this.requireMemory(bytes.length); + } + + @Override + public InputStream getInputStream() { + return new ByteArrayInputStream(bytes); + } + + @Override + public int getBufferSize() { + return this.bytes == null ? 0 : this.bytes.length; + } + + @Override + public void write(OutputStream outputStream) throws IOException { + if (this.bytes != null) { + outputStream.write(this.bytes); } - - public HeapBuffer(byte[] bytes, ShuffleMemoryTracker memoryTracker) { - super(memoryTracker); - this.bytes = bytes; - this.requireMemory(bytes.length); + } + + @Override + public FileRegion toFileRegion() { + return new MemoryBytesFileRegion(this.bytes); + } + + @Override + public void release() { + if (this.bytes != null) { + int dataSize = this.bytes.length; + releaseMemory(dataSize); + this.bytes = null; } + } - public HeapBuffer(byte[] bytes, boolean memoryTrack) { - super(memoryTrack); - this.bytes = bytes; - this.requireMemory(bytes.length); - } + public static class HeapBufferBuilder extends AbstractBufferBuilder { + private ByteArrayOutputStream outputStream; @Override - public InputStream getInputStream() { - return new ByteArrayInputStream(bytes); + public OutputStream getOutputStream() { + if (outputStream == null) { + outputStream = new ByteArrayOutputStream(INITIAL_BUFFER_SIZE); + } + return outputStream; } @Override - public int getBufferSize() { - return this.bytes == null ? 0 : this.bytes.length; - } + public void positionStream(int position) {} @Override - public void write(OutputStream outputStream) throws IOException { - if (this.bytes != null) { - outputStream.write(this.bytes); - } + public int getBufferSize() { + return this.outputStream != null ? outputStream.size() : 0; } @Override - public FileRegion toFileRegion() { - return new MemoryBytesFileRegion(this.bytes); + public OutBuffer build() { + byte[] bytes = outputStream.toByteArray(); + this.outputStream.reset(); + this.resetRecordCount(); + return new HeapBuffer(bytes, this.memoryTrack); } @Override - public void release() { - if (this.bytes != null) { - int dataSize = this.bytes.length; - releaseMemory(dataSize); - this.bytes = null; - } - } - - public static class HeapBufferBuilder extends AbstractBufferBuilder { - private ByteArrayOutputStream outputStream; - - @Override - public OutputStream getOutputStream() { - if (outputStream == null) { - outputStream = new ByteArrayOutputStream(INITIAL_BUFFER_SIZE); - } - return outputStream; - } - - @Override - public void positionStream(int position) { - } - - @Override - public int getBufferSize() { - return this.outputStream != null ? outputStream.size() : 0; - } - - @Override - public OutBuffer build() { - byte[] bytes = outputStream.toByteArray(); - this.outputStream.reset(); - this.resetRecordCount(); - return new HeapBuffer(bytes, this.memoryTrack); - } - - @Override - public void close() { - if (this.outputStream != null) { - try { - this.outputStream.close(); - } catch (IOException e) { - throw new GeaflowRuntimeException(e); - } - this.outputStream = null; - } + public void close() { + if (this.outputStream != null) { + try { + this.outputStream.close(); + } catch (IOException e) { + throw new GeaflowRuntimeException(e); } + this.outputStream = null; + } } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/MemoryViewBuffer.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/MemoryViewBuffer.java index cbecc721c..c405b1f8d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/MemoryViewBuffer.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/MemoryViewBuffer.java @@ -19,10 +19,10 @@ package org.apache.geaflow.shuffle.pipeline.buffer; -import io.netty.channel.FileRegion; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; + import org.apache.geaflow.memory.MemoryGroupManger; import org.apache.geaflow.memory.MemoryManager; import org.apache.geaflow.memory.MemoryView; @@ -30,88 +30,87 @@ import org.apache.geaflow.memory.channel.ByteArrayOutputStream; import org.apache.geaflow.shuffle.network.protocol.MemoryViewFileRegion; -public class MemoryViewBuffer extends AbstractBuffer { +import io.netty.channel.FileRegion; - private final MemoryView memoryView; +public class MemoryViewBuffer extends AbstractBuffer { - public MemoryViewBuffer(MemoryView memoryView, boolean memoryTrack) { - super(memoryTrack); - this.memoryView = memoryView; - this.requireMemory(memoryView.contentSize()); + private final MemoryView memoryView; + + public MemoryViewBuffer(MemoryView memoryView, boolean memoryTrack) { + super(memoryTrack); + this.memoryView = memoryView; + this.requireMemory(memoryView.contentSize()); + } + + public MemoryViewBuffer(MemoryView memoryView, ShuffleMemoryTracker memoryTracker) { + super(memoryTracker); + this.memoryView = memoryView; + this.requireMemory(memoryView.contentSize()); + } + + @Override + public InputStream getInputStream() { + return new ByteArrayInputStream(this.memoryView); + } + + @Override + public FileRegion toFileRegion() { + return new MemoryViewFileRegion(this.memoryView); + } + + @Override + public int getBufferSize() { + return this.memoryView.contentSize(); + } + + @Override + public void write(OutputStream outputStream) throws IOException { + if (this.memoryView.contentSize() > 0) { + this.memoryView.getReader().read(outputStream); } + } - public MemoryViewBuffer(MemoryView memoryView, ShuffleMemoryTracker memoryTracker) { - super(memoryTracker); - this.memoryView = memoryView; - this.requireMemory(memoryView.contentSize()); - } + @Override + public void release() { + int contentSize = memoryView.contentSize(); + this.memoryView.close(); + this.releaseMemory(contentSize); + } + + public static class MemoryViewBufferBuilder extends AbstractBufferBuilder { + + private MemoryView memoryView; + private OutputStream outputStream; @Override - public InputStream getInputStream() { - return new ByteArrayInputStream(this.memoryView); + public OutputStream getOutputStream() { + if (this.memoryView == null) { + this.memoryView = + MemoryManager.getInstance() + .requireMemory(MemoryGroupManger.SHUFFLE.getSpanSize(), MemoryGroupManger.SHUFFLE); + this.outputStream = new ByteArrayOutputStream(this.memoryView); + } + return this.outputStream; } @Override - public FileRegion toFileRegion() { - return new MemoryViewFileRegion(this.memoryView); - } + public void positionStream(int position) {} @Override public int getBufferSize() { - return this.memoryView.contentSize(); + return this.memoryView == null ? 0 : this.memoryView.contentSize(); } @Override - public void write(OutputStream outputStream) throws IOException { - if (this.memoryView.contentSize() > 0) { - this.memoryView.getReader().read(outputStream); - } + public OutBuffer build() { + final MemoryViewBuffer buffer = new MemoryViewBuffer(this.memoryView, this.memoryTrack); + this.memoryView = null; + this.outputStream = null; + this.resetRecordCount(); + return buffer; } @Override - public void release() { - int contentSize = memoryView.contentSize(); - this.memoryView.close(); - this.releaseMemory(contentSize); - } - - public static class MemoryViewBufferBuilder extends AbstractBufferBuilder { - - private MemoryView memoryView; - private OutputStream outputStream; - - @Override - public OutputStream getOutputStream() { - if (this.memoryView == null) { - this.memoryView = MemoryManager.getInstance().requireMemory( - MemoryGroupManger.SHUFFLE.getSpanSize(), MemoryGroupManger.SHUFFLE); - this.outputStream = new ByteArrayOutputStream(this.memoryView); - } - return this.outputStream; - } - - @Override - public void positionStream(int position) { - } - - @Override - public int getBufferSize() { - return this.memoryView == null ? 0 : this.memoryView.contentSize(); - } - - @Override - public OutBuffer build() { - final MemoryViewBuffer buffer = new MemoryViewBuffer(this.memoryView, this.memoryTrack); - this.memoryView = null; - this.outputStream = null; - this.resetRecordCount(); - return buffer; - } - - @Override - public void close() { - } - - } - + public void close() {} + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/OutBuffer.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/OutBuffer.java index db70ac74c..ac09b548b 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/OutBuffer.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/OutBuffer.java @@ -19,99 +19,90 @@ package org.apache.geaflow.shuffle.pipeline.buffer; -import io.netty.channel.FileRegion; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import io.netty.channel.FileRegion; + public interface OutBuffer { + /** + * Get the input stream of this buffer. + * + * @return buffer input stream. + */ + InputStream getInputStream(); + + /** + * Convert this buffer to file region. + * + * @return file region. + */ + FileRegion toFileRegion(); + + /** + * Get the buffer size of this buffer. + * + * @return buffer size in bytes. + */ + int getBufferSize(); + + /** + * Write data from a output stream. + * + * @param outputStream output stream. + * @throws IOException io exception. + */ + void write(OutputStream outputStream) throws IOException; + + /** Release this buffer. */ + void release(); + + interface BufferBuilder { + /** - * Get the input stream of this buffer. + * Get the OutputStream. * - * @return buffer input stream. + * @return output stream */ - InputStream getInputStream(); + OutputStream getOutputStream(); /** - * Convert this buffer to file region. + * Set the position of the stream. * - * @return file region. + * @param position position */ - FileRegion toFileRegion(); + void positionStream(int position); /** - * Get the buffer size of this buffer. + * Get the buffer size. * - * @return buffer size in bytes. + * @return buffer size */ int getBufferSize(); /** - * Write data from a output stream. + * Get record count in the buffer. * - * @param outputStream output stream. - * @throws IOException io exception. + * @return record count */ - void write(OutputStream outputStream) throws IOException; + long getRecordCount(); + + /** Increase the record count. */ + void increaseRecordCount(); + + /** Set memory track. */ + void enableMemoryTrack(); /** - * Release this buffer. + * Build the buffer. + * + * @return buffer. */ - void release(); - - interface BufferBuilder { - - /** - * Get the OutputStream. - * - * @return output stream - */ - OutputStream getOutputStream(); - - /** - * Set the position of the stream. - * - * @param position position - */ - void positionStream(int position); - - /** - * Get the buffer size. - * - * @return buffer size - */ - int getBufferSize(); - - /** - * Get record count in the buffer. - * - * @return record count - */ - long getRecordCount(); - - /** - * Increase the record count. - */ - void increaseRecordCount(); - - /** - * Set memory track. - */ - void enableMemoryTrack(); - - /** - * Build the buffer. - * - * @return buffer. - */ - OutBuffer build(); - - /** - * Close this builder. - */ - void close(); - - } + OutBuffer build(); + /** Close this builder. */ + void close(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/PipeBuffer.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/PipeBuffer.java index 23a19941f..f59d3a055 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/PipeBuffer.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/PipeBuffer.java @@ -23,53 +23,53 @@ public class PipeBuffer implements Serializable { - private final OutBuffer buffer; - private final boolean isData; - private final long batchId; - private final int count; - private final boolean isFinish; + private final OutBuffer buffer; + private final boolean isData; + private final long batchId; + private final int count; + private final boolean isFinish; - public PipeBuffer(byte[] buffer, long batchId) { - this(new HeapBuffer(buffer), batchId); - } + public PipeBuffer(byte[] buffer, long batchId) { + this(new HeapBuffer(buffer), batchId); + } - public PipeBuffer(OutBuffer buffer, long batchId) { - this.buffer = buffer; - this.batchId = batchId; - this.isData = true; - this.count = 0; - this.isFinish = false; - } + public PipeBuffer(OutBuffer buffer, long batchId) { + this.buffer = buffer; + this.batchId = batchId; + this.isData = true; + this.count = 0; + this.isFinish = false; + } - public PipeBuffer(long batchId, int count, boolean isFinish) { - this.buffer = null; - this.batchId = batchId; - this.isData = false; - this.count = count; - this.isFinish = isFinish; - } + public PipeBuffer(long batchId, int count, boolean isFinish) { + this.buffer = null; + this.batchId = batchId; + this.isData = false; + this.count = count; + this.isFinish = isFinish; + } - public OutBuffer getBuffer() { - return buffer; - } + public OutBuffer getBuffer() { + return buffer; + } - public int getBufferSize() { - return buffer != null ? buffer.getBufferSize() : 0; - } + public int getBufferSize() { + return buffer != null ? buffer.getBufferSize() : 0; + } - public boolean isData() { - return isData; - } + public boolean isData() { + return isData; + } - public long getBatchId() { - return batchId; - } + public long getBatchId() { + return batchId; + } - public int getCount() { - return count; - } + public int getCount() { + return count; + } - public boolean isFinish() { - return isFinish; - } + public boolean isFinish() { + return isFinish; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/PipeChannelBuffer.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/PipeChannelBuffer.java index c77e5059a..e78046920 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/PipeChannelBuffer.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/PipeChannelBuffer.java @@ -22,27 +22,23 @@ import org.apache.geaflow.shuffle.pipeline.channel.LocalInputChannel; import org.apache.geaflow.shuffle.pipeline.slice.SequenceSliceReader; -/** - * Message consumed by channels ({@link LocalInputChannel} - * and {@link SequenceSliceReader}). - */ +/** Message consumed by channels ({@link LocalInputChannel} and {@link SequenceSliceReader}). */ public class PipeChannelBuffer { - private final PipeBuffer buffer; - // Indicate the availability of message in PipeSlice. - private final boolean moreAvailable; - - public PipeChannelBuffer(PipeBuffer buffer, boolean moreAvailable) { - this.buffer = buffer; - this.moreAvailable = moreAvailable; - } + private final PipeBuffer buffer; + // Indicate the availability of message in PipeSlice. + private final boolean moreAvailable; - public PipeBuffer getBuffer() { - return buffer; - } + public PipeChannelBuffer(PipeBuffer buffer, boolean moreAvailable) { + this.buffer = buffer; + this.moreAvailable = moreAvailable; + } - public boolean moreAvailable() { - return moreAvailable; - } + public PipeBuffer getBuffer() { + return buffer; + } + public boolean moreAvailable() { + return moreAvailable; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/PipeFetcherBuffer.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/PipeFetcherBuffer.java index 9f43969f4..87a7d34ef 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/PipeFetcherBuffer.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/PipeFetcherBuffer.java @@ -24,80 +24,82 @@ import org.apache.geaflow.shuffle.pipeline.fetcher.ShardFetcher; /** - * Message consumed by {@link ShardFetcher}. - * One fetcher consumes buffers from multiple {@link InputChannel} - * which is identified by channelIndex. + * Message consumed by {@link ShardFetcher}. One fetcher consumes buffers from multiple {@link + * InputChannel} which is identified by channelIndex. */ public class PipeFetcherBuffer { - private final PipeBuffer buffer; - private final SliceId sliceId; - private final String streamName; - - // indicate if has more in the channel. - private boolean moreAvailable; - // current fetched channel. - private int channelIndex; - - public PipeFetcherBuffer(PipeBuffer pipeBuffer, int channelIndex, - boolean moreAvailable, SliceId sliceId, String streamName) { - this.buffer = pipeBuffer; - this.channelIndex = channelIndex; - this.moreAvailable = moreAvailable; - this.sliceId = sliceId; - this.streamName = streamName; + private final PipeBuffer buffer; + private final SliceId sliceId; + private final String streamName; + + // indicate if has more in the channel. + private boolean moreAvailable; + // current fetched channel. + private int channelIndex; + + public PipeFetcherBuffer( + PipeBuffer pipeBuffer, + int channelIndex, + boolean moreAvailable, + SliceId sliceId, + String streamName) { + this.buffer = pipeBuffer; + this.channelIndex = channelIndex; + this.moreAvailable = moreAvailable; + this.sliceId = sliceId; + this.streamName = streamName; + } + + public boolean moreAvailable() { + return moreAvailable; + } + + public void setMoreAvailable(boolean moreAvailable) { + this.moreAvailable = moreAvailable; + } + + public int getChannelIndex() { + return channelIndex; + } + + public void setChannelIndex(int channelIndex) { + this.channelIndex = channelIndex; + } + + public SliceId getSliceId() { + return sliceId; + } + + public String getStreamName() { + return streamName; + } + + public int getBufferSize() { + if (buffer != null) { + return buffer.getBuffer().getBufferSize(); + } else { + return 0; } + } - public boolean moreAvailable() { - return moreAvailable; - } - - public void setMoreAvailable(boolean moreAvailable) { - this.moreAvailable = moreAvailable; - } - - public int getChannelIndex() { - return channelIndex; - } - - public void setChannelIndex(int channelIndex) { - this.channelIndex = channelIndex; - } - - public SliceId getSliceId() { - return sliceId; - } - - public String getStreamName() { - return streamName; - } - - public int getBufferSize() { - if (buffer != null) { - return buffer.getBuffer().getBufferSize(); - } else { - return 0; - } - } - - public OutBuffer getBuffer() { - return buffer.getBuffer(); - } + public OutBuffer getBuffer() { + return buffer.getBuffer(); + } - public boolean isBarrier() { - return !buffer.isData(); - } + public boolean isBarrier() { + return !buffer.isData(); + } - public long getBatchId() { - return buffer.getBatchId(); - } + public long getBatchId() { + return buffer.getBatchId(); + } - public int getBatchCount() { - return buffer.getCount(); - } - - public boolean isFinish() { - return buffer.isFinish(); - } + public int getBatchCount() { + return buffer.getCount(); + } + public boolean isFinish() { + return buffer.isFinish(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/ShuffleMemoryTracker.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/ShuffleMemoryTracker.java index 356207e6c..200188d7e 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/ShuffleMemoryTracker.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/buffer/ShuffleMemoryTracker.java @@ -26,6 +26,7 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.SHUFFLE_OFFHEAP_MEMORY_FRACTION; import java.util.concurrent.atomic.AtomicLong; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.memory.MemoryManager; @@ -34,67 +35,69 @@ public class ShuffleMemoryTracker { - private static final Logger LOGGER = LoggerFactory.getLogger(ShuffleMemoryTracker.class); - - private final long maxShuffleSize; - private final AtomicLong usedMemory; - private MemoryManager memoryPoolManager; - - public ShuffleMemoryTracker(Configuration config) { - boolean memoryPool = config.getBoolean(SHUFFLE_MEMORY_POOL_ENABLE); - double safetyFraction = config.getDouble(SHUFFLE_MEMORY_SAFETY_FRACTION); - - long maxMemorySize; - if (memoryPool) { - memoryPoolManager = MemoryManager.build(config); - maxMemorySize = (long) (memoryPoolManager.maxMemory() * safetyFraction); - double fraction = config.getDouble(SHUFFLE_OFFHEAP_MEMORY_FRACTION); - maxShuffleSize = (long) (maxMemorySize * fraction); - } else { - long maxHeapSize = config.getInteger(CONTAINER_HEAP_SIZE_MB) * FileUtils.ONE_MB; - maxMemorySize = (long) (maxHeapSize * safetyFraction); - double fraction = config.getDouble(SHUFFLE_HEAP_MEMORY_FRACTION); - maxShuffleSize = (long) (maxMemorySize * fraction); - } - - usedMemory = new AtomicLong(0); - LOGGER.info("memoryPool:{} maxMemory:{}mb shuffleMax:{}mb", memoryPool, - maxMemorySize / FileUtils.ONE_MB, maxShuffleSize / FileUtils.ONE_MB); + private static final Logger LOGGER = LoggerFactory.getLogger(ShuffleMemoryTracker.class); + + private final long maxShuffleSize; + private final AtomicLong usedMemory; + private MemoryManager memoryPoolManager; + + public ShuffleMemoryTracker(Configuration config) { + boolean memoryPool = config.getBoolean(SHUFFLE_MEMORY_POOL_ENABLE); + double safetyFraction = config.getDouble(SHUFFLE_MEMORY_SAFETY_FRACTION); + + long maxMemorySize; + if (memoryPool) { + memoryPoolManager = MemoryManager.build(config); + maxMemorySize = (long) (memoryPoolManager.maxMemory() * safetyFraction); + double fraction = config.getDouble(SHUFFLE_OFFHEAP_MEMORY_FRACTION); + maxShuffleSize = (long) (maxMemorySize * fraction); + } else { + long maxHeapSize = config.getInteger(CONTAINER_HEAP_SIZE_MB) * FileUtils.ONE_MB; + maxMemorySize = (long) (maxHeapSize * safetyFraction); + double fraction = config.getDouble(SHUFFLE_HEAP_MEMORY_FRACTION); + maxShuffleSize = (long) (maxMemorySize * fraction); } - public boolean requireMemory(long requiredBytes) { - if (usedMemory.get() < 0) { - LOGGER.warn("memory statistic incorrect!"); - } - if (requiredBytes < 0) { - throw new IllegalArgumentException("invalid required bytes:" + requiredBytes); - } else if (requiredBytes == 0) { - return maxShuffleSize >= usedMemory.get(); - } else { - return maxShuffleSize >= usedMemory.addAndGet(requiredBytes); - } + usedMemory = new AtomicLong(0); + LOGGER.info( + "memoryPool:{} maxMemory:{}mb shuffleMax:{}mb", + memoryPool, + maxMemorySize / FileUtils.ONE_MB, + maxShuffleSize / FileUtils.ONE_MB); + } + + public boolean requireMemory(long requiredBytes) { + if (usedMemory.get() < 0) { + LOGGER.warn("memory statistic incorrect!"); } - - public boolean checkMemoryEnough() { - return usedMemory.get() < maxShuffleSize; + if (requiredBytes < 0) { + throw new IllegalArgumentException("invalid required bytes:" + requiredBytes); + } else if (requiredBytes == 0) { + return maxShuffleSize >= usedMemory.get(); + } else { + return maxShuffleSize >= usedMemory.addAndGet(requiredBytes); } + } - public long releaseMemory(long releasedBytes) { - return usedMemory.addAndGet(releasedBytes * -1); - } + public boolean checkMemoryEnough() { + return usedMemory.get() < maxShuffleSize; + } - public long getUsedMemory() { - return usedMemory.get(); - } + public long releaseMemory(long releasedBytes) { + return usedMemory.addAndGet(releasedBytes * -1); + } - public double getUsedRatio() { - return usedMemory.get() * 1.0 / maxShuffleSize; - } + public long getUsedMemory() { + return usedMemory.get(); + } - public void release() { - if (memoryPoolManager != null) { - memoryPoolManager.dispose(); - } - } + public double getUsedRatio() { + return usedMemory.get() * 1.0 / maxShuffleSize; + } + public void release() { + if (memoryPoolManager != null) { + memoryPoolManager.dispose(); + } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/channel/AbstractInputChannel.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/channel/AbstractInputChannel.java index f7945ab96..c943515de 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/channel/AbstractInputChannel.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/channel/AbstractInputChannel.java @@ -19,140 +19,142 @@ package org.apache.geaflow.shuffle.pipeline.channel; -import com.google.common.base.Preconditions; import java.io.IOException; import java.util.concurrent.atomic.AtomicReference; + import org.apache.geaflow.shuffle.message.SliceId; import org.apache.geaflow.shuffle.pipeline.buffer.PipeChannelBuffer; import org.apache.geaflow.shuffle.pipeline.fetcher.OneShardFetcher; +import com.google.common.base.Preconditions; + /* This file is based on source code from the Flink Project (http://flink.apache.org/), licensed by the Apache * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. */ /** - * This class is an adaptation of Flink's org.apache.flink.runtime.io.network.partition.consumer.InputChannel. + * This class is an adaptation of Flink's + * org.apache.flink.runtime.io.network.partition.consumer.InputChannel. */ public abstract class AbstractInputChannel implements InputChannel { - private static final int BACKOFF_DISABLED = -1; - - // The info of the input channel to identify it globally within a task. - protected final int channelIndex; - // Initial batch to fetch. - protected final long initialBatchId; - // Slice id to consume. - protected final SliceId inputSliceId; - // Parent fetcher this channel belongs to. - protected final OneShardFetcher inputFetcher; - - // The initial backoff (ms). - protected final int initialBackoff; - // The maximum backoff (ms). - protected final int maxBackoff; - // The current backoff (ms). - protected int currentBackoff; - - // Asynchronous error notification. - private final AtomicReference cause = new AtomicReference(); - - protected AbstractInputChannel(int channelIndex, OneShardFetcher inputFetcher, SliceId sliceId, - int initialBackoff, int maxBackoff, long startBatchId) { - - Preconditions.checkArgument(channelIndex >= 0); - Preconditions.checkArgument(initialBackoff >= 0 && initialBackoff <= maxBackoff); - - this.inputSliceId = sliceId; - this.inputFetcher = Preconditions.checkNotNull(inputFetcher); - this.channelIndex = channelIndex; - this.initialBackoff = initialBackoff; - this.maxBackoff = maxBackoff; - this.currentBackoff = initialBackoff == 0 ? BACKOFF_DISABLED : 0; - this.initialBatchId = startBatchId; + private static final int BACKOFF_DISABLED = -1; + + // The info of the input channel to identify it globally within a task. + protected final int channelIndex; + // Initial batch to fetch. + protected final long initialBatchId; + // Slice id to consume. + protected final SliceId inputSliceId; + // Parent fetcher this channel belongs to. + protected final OneShardFetcher inputFetcher; + + // The initial backoff (ms). + protected final int initialBackoff; + // The maximum backoff (ms). + protected final int maxBackoff; + // The current backoff (ms). + protected int currentBackoff; + + // Asynchronous error notification. + private final AtomicReference cause = new AtomicReference(); + + protected AbstractInputChannel( + int channelIndex, + OneShardFetcher inputFetcher, + SliceId sliceId, + int initialBackoff, + int maxBackoff, + long startBatchId) { + + Preconditions.checkArgument(channelIndex >= 0); + Preconditions.checkArgument(initialBackoff >= 0 && initialBackoff <= maxBackoff); + + this.inputSliceId = sliceId; + this.inputFetcher = Preconditions.checkNotNull(inputFetcher); + this.channelIndex = channelIndex; + this.initialBackoff = initialBackoff; + this.maxBackoff = maxBackoff; + this.currentBackoff = initialBackoff == 0 ? BACKOFF_DISABLED : 0; + this.initialBatchId = startBatchId; + } + + /** + * Notifies the owning {@link OneShardFetcher} that this channel became non-empty. + * + *

This is guaranteed to be called only when a Buffer was added to a previously empty input + * channel. The notion of empty is atomically consistent with the flag {@link + * PipeChannelBuffer#moreAvailable()} when polling the next buffer from this channel. + * + *

Note: When the input channel observes an exception, this method is called regardless + * of whether the channel was empty before. That ensures that the parent InputGate will always be + * notified about the exception. + */ + protected void notifyChannelNonEmpty() { + inputFetcher.notifyChannelNonEmpty(this); + } + + public int getChannelIndex() { + return channelIndex; + } + + public SliceId getInputSliceId() { + return inputSliceId; + } + + /** Checks for an error and rethrows it if one was reported. */ + protected void checkError() throws IOException { + final Throwable t = cause.get(); + if (t != null) { + if (t instanceof IOException) { + throw (IOException) t; + } else { + throw new IOException("input channel error", t); + } } - - /** - * Notifies the owning {@link OneShardFetcher} that this channel became non-empty. - * - *

This is guaranteed to be called only when a Buffer was added to a previously - * empty input channel. The notion of empty is atomically consistent with the flag - * {@link PipeChannelBuffer#moreAvailable()} when polling the next buffer - * from this channel. - * - *

Note: When the input channel observes an exception, this - * method is called regardless of whether the channel was empty before. That ensures - * that the parent InputGate will always be notified about the exception. - */ - protected void notifyChannelNonEmpty() { - inputFetcher.notifyChannelNonEmpty(this); - } - - public int getChannelIndex() { - return channelIndex; + } + + /** + * Atomically sets an error for this channel and notifies the input fetcher about available data + * to trigger querying this channel by the task thread. + */ + public void setError(Throwable cause) { + if (this.cause.compareAndSet(null, Preconditions.checkNotNull(cause))) { + // Notify the input fetcher. + notifyChannelNonEmpty(); } - - public SliceId getInputSliceId() { - return inputSliceId; - } - - /** - * Checks for an error and rethrows it if one was reported. - */ - protected void checkError() throws IOException { - final Throwable t = cause.get(); - if (t != null) { - if (t instanceof IOException) { - throw (IOException) t; - } else { - throw new IOException("input channel error", t); - } - } - } - - /** - * Atomically sets an error for this channel and notifies the input fetcher about available - * data to trigger querying this channel by the task thread. - */ - public void setError(Throwable cause) { - if (this.cause.compareAndSet(null, Preconditions.checkNotNull(cause))) { - // Notify the input fetcher. - notifyChannelNonEmpty(); - } - } - - // ------------------------------------------------------------------------ - // request exponential backoff - // ------------------------------------------------------------------------ - - /** - * Returns the current backoff in ms. - */ - protected int getCurrentBackoff() { - return Math.max(currentBackoff, 0); + } + + // ------------------------------------------------------------------------ + // request exponential backoff + // ------------------------------------------------------------------------ + + /** Returns the current backoff in ms. */ + protected int getCurrentBackoff() { + return Math.max(currentBackoff, 0); + } + + /** + * Increases the current backoff and returns whether the operation was successful. + * + * @return true, iff the operation was successful. Otherwise, false. + */ + protected boolean increaseBackoff() { + // Backoff is disabled. + if (currentBackoff < 0) { + return false; } - /** - * Increases the current backoff and returns whether the operation was successful. - * - * @return true, iff the operation was successful. Otherwise, false. - */ - protected boolean increaseBackoff() { - // Backoff is disabled. - if (currentBackoff < 0) { - return false; - } - - // This is the first time backing off - if (currentBackoff == 0) { - currentBackoff = initialBackoff; - return true; - } else if (currentBackoff < maxBackoff) { - // Continue backing off. - currentBackoff = Math.min(currentBackoff * 2, maxBackoff); - return true; - } - - return false; + // This is the first time backing off + if (currentBackoff == 0) { + currentBackoff = initialBackoff; + return true; + } else if (currentBackoff < maxBackoff) { + // Continue backing off. + currentBackoff = Math.min(currentBackoff * 2, maxBackoff); + return true; } + return false; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/channel/ChannelId.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/channel/ChannelId.java index 48a46acad..84f73cab0 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/channel/ChannelId.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/channel/ChannelId.java @@ -19,62 +19,61 @@ package org.apache.geaflow.shuffle.pipeline.channel; -import io.netty.buffer.ByteBuf; import java.io.Serializable; import java.util.Objects; import java.util.UUID; +import io.netty.buffer.ByteBuf; + /* This file is based on source code from the Flink Project (http://flink.apache.org/), licensed by the Apache * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. */ -/** - * This class is an adaptation of Flink's org.apache.flink.util.AbstractID. - */ +/** This class is an adaptation of Flink's org.apache.flink.util.AbstractID. */ public class ChannelId implements Serializable { - public static final int CHANNEL_ID_BYTES = 16; - private static final long serialVersionUID = 2L; - // The upper part of the actual ID. - private final long upperPart; - // The lower part of the actual ID. - private final long lowerPart; + public static final int CHANNEL_ID_BYTES = 16; + private static final long serialVersionUID = 2L; + // The upper part of the actual ID. + private final long upperPart; + // The lower part of the actual ID. + private final long lowerPart; - public ChannelId() { - UUID uuid = UUID.randomUUID(); - this.upperPart = uuid.getMostSignificantBits(); - this.lowerPart = uuid.getLeastSignificantBits(); - } + public ChannelId() { + UUID uuid = UUID.randomUUID(); + this.upperPart = uuid.getMostSignificantBits(); + this.lowerPart = uuid.getLeastSignificantBits(); + } - public ChannelId(long lowerPart, long upperPart) { - this.upperPart = upperPart; - this.lowerPart = lowerPart; - } + public ChannelId(long lowerPart, long upperPart) { + this.upperPart = upperPart; + this.lowerPart = lowerPart; + } - public void writeTo(ByteBuf buf) { - buf.writeLong(this.lowerPart); - buf.writeLong(this.upperPart); - } + public void writeTo(ByteBuf buf) { + buf.writeLong(this.lowerPart); + buf.writeLong(this.upperPart); + } - public static ChannelId readFrom(ByteBuf buf) { - long lower = buf.readLong(); - long upper = buf.readLong(); - return new ChannelId(lower, upper); - } + public static ChannelId readFrom(ByteBuf buf) { + long lower = buf.readLong(); + long upper = buf.readLong(); + return new ChannelId(lower, upper); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - ChannelId that = (ChannelId) o; - return upperPart == that.upperPart && lowerPart == that.lowerPart; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(upperPart, lowerPart); + if (o == null || getClass() != o.getClass()) { + return false; } + ChannelId that = (ChannelId) o; + return upperPart == that.upperPart && lowerPart == that.lowerPart; + } + + @Override + public int hashCode() { + return Objects.hash(upperPart, lowerPart); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/channel/InputChannel.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/channel/InputChannel.java index ff0806b2c..d4d0096c9 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/channel/InputChannel.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/channel/InputChannel.java @@ -21,35 +21,33 @@ import java.io.IOException; import java.util.Optional; + import org.apache.geaflow.shuffle.pipeline.buffer.PipeChannelBuffer; public interface InputChannel { - /** - * request batches from upstream slices. - * - * @param batchId the maximum batchId to fetch. - * @throws IOException IO exception - * @throws InterruptedException Interrupt exception - */ - void requestSlice(long batchId) throws IOException, InterruptedException; - - /** - * Returns the next buffer from the consumed slice or {@code Optional.empty()} if there - * is no data to return. - */ - Optional getNext() throws IOException, InterruptedException; - - /** - * check if channel is released. - * - * @return true if released. - */ - boolean isReleased(); - - /** - * Releases all resources of the channel. - */ - void release() throws IOException; - + /** + * request batches from upstream slices. + * + * @param batchId the maximum batchId to fetch. + * @throws IOException IO exception + * @throws InterruptedException Interrupt exception + */ + void requestSlice(long batchId) throws IOException, InterruptedException; + + /** + * Returns the next buffer from the consumed slice or {@code Optional.empty()} if there is no data + * to return. + */ + Optional getNext() throws IOException, InterruptedException; + + /** + * check if channel is released. + * + * @return true if released. + */ + boolean isReleased(); + + /** Releases all resources of the channel. */ + void release() throws IOException; } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/channel/LocalInputChannel.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/channel/LocalInputChannel.java index f98853d52..ad3c125c1 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/channel/LocalInputChannel.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/channel/LocalInputChannel.java @@ -19,11 +19,11 @@ package org.apache.geaflow.shuffle.pipeline.channel; -import com.google.common.base.Preconditions; import java.io.IOException; import java.util.Optional; import java.util.Timer; import java.util.TimerTask; + import org.apache.geaflow.shuffle.message.SliceId; import org.apache.geaflow.shuffle.pipeline.buffer.PipeChannelBuffer; import org.apache.geaflow.shuffle.pipeline.fetcher.OneShardFetcher; @@ -35,139 +35,144 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; + /* This file is based on source code from the Flink Project (http://flink.apache.org/), licensed by the Apache * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. */ /** - * This class is an adaptation of Flink's org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel. + * This class is an adaptation of Flink's + * org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel. */ public class LocalInputChannel extends AbstractInputChannel implements PipelineSliceListener { - private static final Logger LOGGER = LoggerFactory.getLogger(LocalInputChannel.class); - - private final Object requestLock = new Object(); - private PipelineSliceReader sliceReader; - private volatile boolean isReleased; - - public LocalInputChannel( - OneShardFetcher fetcher, - SliceId inputSlice, - int channelIndex, - int initialBackoff, - int maxBackoff, - long startBatchId) { - super(channelIndex, fetcher, inputSlice, initialBackoff, maxBackoff, startBatchId); - } - - @Override - public void requestSlice(long batchId) throws IOException { - boolean retriggerRequest = false; - - // The lock is required to request only once in the presence of retriggered requests. - synchronized (requestLock) { - Preconditions.checkState(!isReleased, "LocalInputChannel has been released already"); - if (this.sliceReader == null) { - LOGGER.info("Requesting Local slice {}", this.inputSliceId); - try { - SliceManager sliceManager = ShuffleManager.getInstance().getSliceManager(); - this.sliceReader = sliceManager - .createSliceReader(this.inputSliceId, this.initialBatchId, this); - } catch (SliceNotFoundException notFound) { - if (increaseBackoff()) { - retriggerRequest = true; - } else { - LOGGER.warn("not found slice:{}", this.inputSliceId); - throw notFound; - } - } - } else { - this.sliceReader.updateRequestedBatchId(batchId); - } - } - - if (this.sliceReader != null && this.sliceReader.hasNext()) { - notifyDataAvailable(); - } - // Do this outside of the lock scope as this might lead to a - // deadlock with a concurrent release of the channel via the - // input fetcher. - if (retriggerRequest) { - inputFetcher.retriggerFetchRequest(inputSliceId); + private static final Logger LOGGER = LoggerFactory.getLogger(LocalInputChannel.class); + + private final Object requestLock = new Object(); + private PipelineSliceReader sliceReader; + private volatile boolean isReleased; + + public LocalInputChannel( + OneShardFetcher fetcher, + SliceId inputSlice, + int channelIndex, + int initialBackoff, + int maxBackoff, + long startBatchId) { + super(channelIndex, fetcher, inputSlice, initialBackoff, maxBackoff, startBatchId); + } + + @Override + public void requestSlice(long batchId) throws IOException { + boolean retriggerRequest = false; + + // The lock is required to request only once in the presence of retriggered requests. + synchronized (requestLock) { + Preconditions.checkState(!isReleased, "LocalInputChannel has been released already"); + if (this.sliceReader == null) { + LOGGER.info("Requesting Local slice {}", this.inputSliceId); + try { + SliceManager sliceManager = ShuffleManager.getInstance().getSliceManager(); + this.sliceReader = + sliceManager.createSliceReader(this.inputSliceId, this.initialBatchId, this); + } catch (SliceNotFoundException notFound) { + if (increaseBackoff()) { + retriggerRequest = true; + } else { + LOGGER.warn("not found slice:{}", this.inputSliceId); + throw notFound; + } } + } else { + this.sliceReader.updateRequestedBatchId(batchId); + } } - public void reTriggerSliceRequest(Timer timer) { - synchronized (requestLock) { - Preconditions.checkState(sliceReader == null, "already requested slice"); - timer.schedule(new TimerTask() { - @Override - public void run() { - try { - requestSlice(initialBatchId); - } catch (Throwable t) { - setError(t); - } - } - }, getCurrentBackoff()); - } + if (this.sliceReader != null && this.sliceReader.hasNext()) { + notifyDataAvailable(); } - - @Override - public Optional getNext() throws IOException { - checkError(); - PipelineSliceReader reader = this.sliceReader; - if (reader == null) { - if (isReleased) { - return Optional.empty(); + // Do this outside of the lock scope as this might lead to a + // deadlock with a concurrent release of the channel via the + // input fetcher. + if (retriggerRequest) { + inputFetcher.retriggerFetchRequest(inputSliceId); + } + } + + public void reTriggerSliceRequest(Timer timer) { + synchronized (requestLock) { + Preconditions.checkState(sliceReader == null, "already requested slice"); + timer.schedule( + new TimerTask() { + @Override + public void run() { + try { + requestSlice(initialBatchId); + } catch (Throwable t) { + setError(t); + } } - reader = checkAndGetSliceReader(); - } - - PipeChannelBuffer next = reader.next(); - - if (next == null) { - return Optional.empty(); - } - - return Optional.of(next); + }, + getCurrentBackoff()); } - - @Override - public void notifyDataAvailable() { - notifyChannelNonEmpty(); + } + + @Override + public Optional getNext() throws IOException { + checkError(); + PipelineSliceReader reader = this.sliceReader; + if (reader == null) { + if (isReleased) { + return Optional.empty(); + } + reader = checkAndGetSliceReader(); } - private PipelineSliceReader checkAndGetSliceReader() { - // Synchronizing on the request lock means this blocks until the asynchronous request - // for the slice has been completed by then the slice reader is visible or the channel is released. - synchronized (requestLock) { - Preconditions.checkState(!isReleased, "released"); - Preconditions.checkState(sliceReader != null, "reader is not ready."); - return sliceReader; - } - } + PipeChannelBuffer next = reader.next(); - @Override - public boolean isReleased() { - return isReleased; + if (next == null) { + return Optional.empty(); } - @Override - public void release() { - if (!isReleased) { - isReleased = true; - PipelineSliceReader reader = sliceReader; - if (reader != null) { - reader.release(); - sliceReader = null; - } - } + return Optional.of(next); + } + + @Override + public void notifyDataAvailable() { + notifyChannelNonEmpty(); + } + + private PipelineSliceReader checkAndGetSliceReader() { + // Synchronizing on the request lock means this blocks until the asynchronous request + // for the slice has been completed by then the slice reader is visible or the channel is + // released. + synchronized (requestLock) { + Preconditions.checkState(!isReleased, "released"); + Preconditions.checkState(sliceReader != null, "reader is not ready."); + return sliceReader; } - - @Override - public String toString() { - return "LocalInputChannel [" + inputSliceId + "]"; + } + + @Override + public boolean isReleased() { + return isReleased; + } + + @Override + public void release() { + if (!isReleased) { + isReleased = true; + PipelineSliceReader reader = sliceReader; + if (reader != null) { + reader.release(); + sliceReader = null; + } } + } + @Override + public String toString() { + return "LocalInputChannel [" + inputSliceId + "]"; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/channel/RemoteInputChannel.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/channel/RemoteInputChannel.java index 042ed0ac6..574197856 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/channel/RemoteInputChannel.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/channel/RemoteInputChannel.java @@ -19,13 +19,12 @@ package org.apache.geaflow.shuffle.pipeline.channel; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; import java.io.IOException; import java.util.ArrayDeque; import java.util.Optional; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; + import org.apache.geaflow.shuffle.config.ShuffleConfig; import org.apache.geaflow.shuffle.message.SliceId; import org.apache.geaflow.shuffle.network.ConnectionId; @@ -37,253 +36,254 @@ import org.apache.geaflow.shuffle.pipeline.fetcher.OneShardFetcher; import org.apache.geaflow.shuffle.util.SliceNotFoundException; -/** - * This class references the implementation of Flink's RemoteInputChannel. - */ -public class RemoteInputChannel extends AbstractInputChannel { - - // ID to distinguish this channel from other channels sharing the same TCP connection. - private final ChannelId id = new ChannelId(); - - // The connection to use to request the remote slice. - private final ConnectionId connectionId; - - // The connection manager to use connect to the remote slice provider. - private final ConnectionManager connectionManager; - - // The received buffers. Received buffers are enqueued by the network I/O thread and the queue. - // is consumed by the receiving task thread. - private final ArrayDeque receivedBuffers = new ArrayDeque<>(); - - // Flag indicating whether this channel has been released. - private final AtomicBoolean isReleased = new AtomicBoolean(); - - // Client to establish a (possibly shared) TCP connection and request the slice. - private volatile SliceRequestClient sliceRequestClient; - - // The next expected sequence number for the next buffer. This is modified by the network - // I/O thread only. - private int expectedSequenceNumber = 0; - - private final boolean enableBackPressure; - - // The initial credit for this channel. - private final int initialCredit; - - // threshold to notify available credit for this channel. - private final int creditNotifyThreshold; - - // available credit for this channel. - private final AtomicInteger availableCredit; - - - public RemoteInputChannel(OneShardFetcher fetcher, SliceId inputSlice, int channelIndex, - ConnectionId connectionId, int initialBackoff, int maxBackoff, - long startBatchId, IConnectionManager connectionManager) { - super(channelIndex, fetcher, inputSlice, initialBackoff, maxBackoff, startBatchId); - this.connectionId = Preconditions.checkNotNull(connectionId); - this.connectionManager = (ConnectionManager) connectionManager; - ShuffleConfig config = connectionManager.getShuffleConfig(); - this.enableBackPressure = config.isBackpressureEnabled(); - // initial credit -1 means no limit. - this.initialCredit = enableBackPressure ? config.getChannelQueueSize() : -1; - this.availableCredit = new AtomicInteger(0); - this.creditNotifyThreshold = initialCredit / 2; - } - - // ------------------------------------------------------------------------ - // Consume - // ------------------------------------------------------------------------ - - @Override - public void requestSlice(long batchId) throws IOException, InterruptedException { - if (sliceRequestClient == null) { - // Create a client and request the slice - sliceRequestClient = connectionManager.createSliceRequestClient(connectionId); - sliceRequestClient.requestSlice(inputSliceId, this, 0, initialBatchId); - } else { - sliceRequestClient.requestNextBatch(batchId, this); - } - } - - /** - * Retriggers a remote slice request. - */ - public void retriggerSliceRequest(SliceId sliceId) throws IOException { - checkClientInitialized(); - checkError(); - - if (increaseBackoff()) { - sliceRequestClient.requestSlice(sliceId, this, getCurrentBackoff(), initialBatchId); - } else { - setError(new SliceNotFoundException(sliceId)); - } - } - - @Override - public Optional getNext() throws IOException { - checkClientInitialized(); - checkError(); - - final PipeBuffer next; - final boolean moreAvailable; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; - synchronized (receivedBuffers) { - next = receivedBuffers.poll(); - moreAvailable = !receivedBuffers.isEmpty(); - } +/** This class references the implementation of Flink's RemoteInputChannel. */ +public class RemoteInputChannel extends AbstractInputChannel { - if (next == null) { - if (isReleased.get()) { - throw new IOException("Queried for a buffer after channel has been released."); - } else { - throw new IllegalStateException( - "There should always have queued buffers for unreleased channel."); - } - } - if (enableBackPressure && next.isData() - && availableCredit.incrementAndGet() >= creditNotifyThreshold) { - notifyCreditAvailable(); - } - return Optional.of(new PipeChannelBuffer(next, moreAvailable)); + // ID to distinguish this channel from other channels sharing the same TCP connection. + private final ChannelId id = new ChannelId(); + + // The connection to use to request the remote slice. + private final ConnectionId connectionId; + + // The connection manager to use connect to the remote slice provider. + private final ConnectionManager connectionManager; + + // The received buffers. Received buffers are enqueued by the network I/O thread and the queue. + // is consumed by the receiving task thread. + private final ArrayDeque receivedBuffers = new ArrayDeque<>(); + + // Flag indicating whether this channel has been released. + private final AtomicBoolean isReleased = new AtomicBoolean(); + + // Client to establish a (possibly shared) TCP connection and request the slice. + private volatile SliceRequestClient sliceRequestClient; + + // The next expected sequence number for the next buffer. This is modified by the network + // I/O thread only. + private int expectedSequenceNumber = 0; + + private final boolean enableBackPressure; + + // The initial credit for this channel. + private final int initialCredit; + + // threshold to notify available credit for this channel. + private final int creditNotifyThreshold; + + // available credit for this channel. + private final AtomicInteger availableCredit; + + public RemoteInputChannel( + OneShardFetcher fetcher, + SliceId inputSlice, + int channelIndex, + ConnectionId connectionId, + int initialBackoff, + int maxBackoff, + long startBatchId, + IConnectionManager connectionManager) { + super(channelIndex, fetcher, inputSlice, initialBackoff, maxBackoff, startBatchId); + this.connectionId = Preconditions.checkNotNull(connectionId); + this.connectionManager = (ConnectionManager) connectionManager; + ShuffleConfig config = connectionManager.getShuffleConfig(); + this.enableBackPressure = config.isBackpressureEnabled(); + // initial credit -1 means no limit. + this.initialCredit = enableBackPressure ? config.getChannelQueueSize() : -1; + this.availableCredit = new AtomicInteger(0); + this.creditNotifyThreshold = initialCredit / 2; + } + + // ------------------------------------------------------------------------ + // Consume + // ------------------------------------------------------------------------ + + @Override + public void requestSlice(long batchId) throws IOException, InterruptedException { + if (sliceRequestClient == null) { + // Create a client and request the slice + sliceRequestClient = connectionManager.createSliceRequestClient(connectionId); + sliceRequestClient.requestSlice(inputSliceId, this, 0, initialBatchId); + } else { + sliceRequestClient.requestNextBatch(batchId, this); } + } - @Override - public boolean isReleased() { - return isReleased.get(); - } + /** Retriggers a remote slice request. */ + public void retriggerSliceRequest(SliceId sliceId) throws IOException { + checkClientInitialized(); + checkError(); - /** - * Releases all exclusive and floating buffers, closes the request client. - */ - @Override - public void release() throws IOException { - if (isReleased.compareAndSet(false, true)) { - synchronized (receivedBuffers) { - receivedBuffers.clear(); - } - - // The released flag has to be set before closing the connection to ensure that - // buffers received concurrently with closing are properly recycled. - if (sliceRequestClient != null) { - sliceRequestClient.close(this); - } else { - connectionManager.closeOpenChannelConnections(connectionId); - } - } + if (increaseBackoff()) { + sliceRequestClient.requestSlice(sliceId, this, getCurrentBackoff(), initialBatchId); + } else { + setError(new SliceNotFoundException(sliceId)); } + } - @Override - public String toString() { - return "RemoteInputChannel [" + inputSliceId + " at " + connectionId + "]"; - } + @Override + public Optional getNext() throws IOException { + checkClientInitialized(); + checkError(); - // ------------------------------------------------------------------------ - // Network I/O notifications (called by network I/O thread) - // ------------------------------------------------------------------------ - - /** - * Gets the current number of received buffers which have not been processed yet. - * - * @return Buffers queued for processing. - */ - public int getNumberOfQueuedBuffers() { - synchronized (receivedBuffers) { - return receivedBuffers.size(); - } - } + final PipeBuffer next; + final boolean moreAvailable; - public ChannelId getInputChannelId() { - return id; + synchronized (receivedBuffers) { + next = receivedBuffers.poll(); + moreAvailable = !receivedBuffers.isEmpty(); } - public int getInitialCredit() { - return initialCredit; + if (next == null) { + if (isReleased.get()) { + throw new IOException("Queried for a buffer after channel has been released."); + } else { + throw new IllegalStateException( + "There should always have queued buffers for unreleased channel."); + } } - - public int getAndResetAvailableCredit() { - return availableCredit.getAndSet(0); + if (enableBackPressure + && next.isData() + && availableCredit.incrementAndGet() >= creditNotifyThreshold) { + notifyCreditAvailable(); } - - /** - * Enqueue this input channel in the pipeline for notifying the producer of unannounced credit. - */ - private void notifyCreditAvailable() throws IOException { - checkClientInitialized(); - checkError(); - - sliceRequestClient.notifyCreditAvailable(this); + return Optional.of(new PipeChannelBuffer(next, moreAvailable)); + } + + @Override + public boolean isReleased() { + return isReleased.get(); + } + + /** Releases all exclusive and floating buffers, closes the request client. */ + @Override + public void release() throws IOException { + if (isReleased.compareAndSet(false, true)) { + synchronized (receivedBuffers) { + receivedBuffers.clear(); + } + + // The released flag has to be set before closing the connection to ensure that + // buffers received concurrently with closing are properly recycled. + if (sliceRequestClient != null) { + sliceRequestClient.close(this); + } else { + connectionManager.closeOpenChannelConnections(connectionId); + } } - - @VisibleForTesting - public SliceRequestClient getSliceRequestClient() { - return sliceRequestClient; + } + + @Override + public String toString() { + return "RemoteInputChannel [" + inputSliceId + " at " + connectionId + "]"; + } + + // ------------------------------------------------------------------------ + // Network I/O notifications (called by network I/O thread) + // ------------------------------------------------------------------------ + + /** + * Gets the current number of received buffers which have not been processed yet. + * + * @return Buffers queued for processing. + */ + public int getNumberOfQueuedBuffers() { + synchronized (receivedBuffers) { + return receivedBuffers.size(); } - - public void onBuffer(PipeBuffer buffer, int sequenceNumber) throws IOException { - synchronized (receivedBuffers) { - if (isReleased.get()) { - return; - } - - if (expectedSequenceNumber != sequenceNumber) { - onError(new ReorderingException(expectedSequenceNumber, sequenceNumber)); - return; - } - - boolean wasEmpty = receivedBuffers.isEmpty(); - receivedBuffers.add(buffer); - ++expectedSequenceNumber; - - if (wasEmpty) { - notifyChannelNonEmpty(); - } - } - + } + + public ChannelId getInputChannelId() { + return id; + } + + public int getInitialCredit() { + return initialCredit; + } + + public int getAndResetAvailableCredit() { + return availableCredit.getAndSet(0); + } + + /** + * Enqueue this input channel in the pipeline for notifying the producer of unannounced credit. + */ + private void notifyCreditAvailable() throws IOException { + checkClientInitialized(); + checkError(); + + sliceRequestClient.notifyCreditAvailable(this); + } + + @VisibleForTesting + public SliceRequestClient getSliceRequestClient() { + return sliceRequestClient; + } + + public void onBuffer(PipeBuffer buffer, int sequenceNumber) throws IOException { + synchronized (receivedBuffers) { + if (isReleased.get()) { + return; + } + + if (expectedSequenceNumber != sequenceNumber) { + onError(new ReorderingException(expectedSequenceNumber, sequenceNumber)); + return; + } + + boolean wasEmpty = receivedBuffers.isEmpty(); + receivedBuffers.add(buffer); + ++expectedSequenceNumber; + + if (wasEmpty) { + notifyChannelNonEmpty(); + } } + } - public void onEmptyBuffer(int sequenceNumber) throws IOException { - synchronized (receivedBuffers) { - if (!isReleased.get()) { - if (expectedSequenceNumber == sequenceNumber) { - expectedSequenceNumber++; - } else { - onError(new ReorderingException(expectedSequenceNumber, sequenceNumber)); - } - } + public void onEmptyBuffer(int sequenceNumber) throws IOException { + synchronized (receivedBuffers) { + if (!isReleased.get()) { + if (expectedSequenceNumber == sequenceNumber) { + expectedSequenceNumber++; + } else { + onError(new ReorderingException(expectedSequenceNumber, sequenceNumber)); } + } } + } - public void onFailedFetchRequest() { - inputFetcher.retriggerFetchRequest(this); - } - - public void onError(Throwable cause) { - setError(cause); - } + public void onFailedFetchRequest() { + inputFetcher.retriggerFetchRequest(this); + } - private void checkClientInitialized() { - Preconditions.checkState(sliceRequestClient != null, - "Bug: client is not initialized before request data."); - } + public void onError(Throwable cause) { + setError(cause); + } - private static class ReorderingException extends IOException { + private void checkClientInitialized() { + Preconditions.checkState( + sliceRequestClient != null, "Bug: client is not initialized before request data."); + } - private static final long serialVersionUID = -888282210356266816L; - private final int expectedSequenceNumber; - private final int actualSequenceNumber; + private static class ReorderingException extends IOException { - ReorderingException(int expectedSequenceNumber, int actualSequenceNumber) { - this.expectedSequenceNumber = expectedSequenceNumber; - this.actualSequenceNumber = actualSequenceNumber; - } + private static final long serialVersionUID = -888282210356266816L; + private final int expectedSequenceNumber; + private final int actualSequenceNumber; - @Override - public String getMessage() { - return String.format( - "Buffer re-ordering: expected buffer with sequence number %d, but received %d.", - expectedSequenceNumber, actualSequenceNumber); - } + ReorderingException(int expectedSequenceNumber, int actualSequenceNumber) { + this.expectedSequenceNumber = expectedSequenceNumber; + this.actualSequenceNumber = actualSequenceNumber; } + @Override + public String getMessage() { + return String.format( + "Buffer re-ordering: expected buffer with sequence number %d, but received %d.", + expectedSequenceNumber, actualSequenceNumber); + } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/fetcher/MultiShardFetcher.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/fetcher/MultiShardFetcher.java index 098a4a962..99461e184 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/fetcher/MultiShardFetcher.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/fetcher/MultiShardFetcher.java @@ -19,10 +19,6 @@ package org.apache.geaflow.shuffle.pipeline.fetcher; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; -import com.google.common.collect.Maps; -import com.google.common.collect.Sets; import java.io.IOException; import java.util.ArrayDeque; import java.util.ArrayList; @@ -33,237 +29,242 @@ import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; + import org.apache.geaflow.common.tuple.Tuple; import org.apache.geaflow.shuffle.pipeline.buffer.PipeFetcherBuffer; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; + /* This file is based on source code from the Flink Project (http://flink.apache.org/), licensed by the Apache * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. */ /** - * Interface to fetch data from multiple {@link OneShardFetcher}. - * This class is an adaptation of Flink's org.apache.flink.runtime.io.network.partition.consumer.UnionInputGate. + * Interface to fetch data from multiple {@link OneShardFetcher}. This class is an adaptation of + * Flink's org.apache.flink.runtime.io.network.partition.consumer.UnionInputGate. */ public class MultiShardFetcher implements ShardFetcher, ShardFetcherListener { - // The input fetchers to union. - private final ShardFetcher[] shardFetchers; - - private final Set inputFetchersWithRemainingData; + // The input fetchers to union. + private final ShardFetcher[] shardFetchers; - private final ArrayDeque inputFetchersWithData = new ArrayDeque<>(); + private final Set inputFetchersWithRemainingData; - private final Set enqueuedInputFetchersWithData = new HashSet<>(); + private final ArrayDeque inputFetchersWithData = new ArrayDeque<>(); - // Listeners to forward buffer notifications to. - private final List fetcherListeners = new ArrayList<>(); + private final Set enqueuedInputFetchersWithData = new HashSet<>(); - // A mapping from input fetcher to (logical) channel index offset. - // Valid channel indexes go from 0 (inclusive) to the total number of input channels (exclusive). - private final Map inputFetcherToIndexOffsetMap; + // Listeners to forward buffer notifications to. + private final List fetcherListeners = new ArrayList<>(); - private final int totalNumberOfInputChannels; - private volatile boolean isReleased; + // A mapping from input fetcher to (logical) channel index offset. + // Valid channel indexes go from 0 (inclusive) to the total number of input channels (exclusive). + private final Map inputFetcherToIndexOffsetMap; - public MultiShardFetcher(OneShardFetcher... inputFetchers) { - this.shardFetchers = inputFetchers; - Preconditions.checkArgument(inputFetchers.length > 1, - "Union input fetcher should union at least two input fetchers."); + private final int totalNumberOfInputChannels; + private volatile boolean isReleased; - if (Arrays.stream(inputFetchers).map(OneShardFetcher::getFetcherIndex).distinct().count() - != inputFetchers.length) { - throw new IllegalArgumentException( - "Union of two input fetchers with the same index. Given indices: " + Arrays - .stream(inputFetchers).map(OneShardFetcher::getFetcherIndex) - .collect(Collectors.toList())); - } + public MultiShardFetcher(OneShardFetcher... inputFetchers) { + this.shardFetchers = inputFetchers; + Preconditions.checkArgument( + inputFetchers.length > 1, "Union input fetcher should union at least two input fetchers."); - this.inputFetcherToIndexOffsetMap = Maps.newHashMapWithExpectedSize(inputFetchers.length); - this.inputFetchersWithRemainingData = Sets.newHashSetWithExpectedSize(inputFetchers.length); + if (Arrays.stream(inputFetchers).map(OneShardFetcher::getFetcherIndex).distinct().count() + != inputFetchers.length) { + throw new IllegalArgumentException( + "Union of two input fetchers with the same index. Given indices: " + + Arrays.stream(inputFetchers) + .map(OneShardFetcher::getFetcherIndex) + .collect(Collectors.toList())); + } - int currentNumberOfInputChannels = 0; + this.inputFetcherToIndexOffsetMap = Maps.newHashMapWithExpectedSize(inputFetchers.length); + this.inputFetchersWithRemainingData = Sets.newHashSetWithExpectedSize(inputFetchers.length); - for (OneShardFetcher fetcher : inputFetchers) { - // The offset to use for buffer or event instances received from this input fetcher. - inputFetcherToIndexOffsetMap.put(Preconditions.checkNotNull(fetcher), currentNumberOfInputChannels); - inputFetchersWithRemainingData.add(fetcher); + int currentNumberOfInputChannels = 0; - currentNumberOfInputChannels += fetcher.getNumberOfInputChannels(); + for (OneShardFetcher fetcher : inputFetchers) { + // The offset to use for buffer or event instances received from this input fetcher. + inputFetcherToIndexOffsetMap.put( + Preconditions.checkNotNull(fetcher), currentNumberOfInputChannels); + inputFetchersWithRemainingData.add(fetcher); - // Register the union fetcher as a listener for all single input fetchers - fetcher.registerListener(this); - } + currentNumberOfInputChannels += fetcher.getNumberOfInputChannels(); - this.totalNumberOfInputChannels = currentNumberOfInputChannels; - this.isReleased = false; + // Register the union fetcher as a listener for all single input fetchers + fetcher.registerListener(this); } - @Override - public Optional getNext() throws IOException, InterruptedException { - return getNext(true); - } + this.totalNumberOfInputChannels = currentNumberOfInputChannels; + this.isReleased = false; + } - @Override - public Optional pollNext() throws IOException, InterruptedException { - return getNext(false); - } + @Override + public Optional getNext() throws IOException, InterruptedException { + return getNext(true); + } - private Optional getNext(boolean blocking) - throws IOException, InterruptedException { - if (inputFetchersWithRemainingData.isEmpty()) { - return Optional.empty(); - } + @Override + public Optional pollNext() throws IOException, InterruptedException { + return getNext(false); + } - Optional> next = getNextInputData(blocking); - if (!next.isPresent()) { - return Optional.empty(); - } + private Optional getNext(boolean blocking) + throws IOException, InterruptedException { + if (inputFetchersWithRemainingData.isEmpty()) { + return Optional.empty(); + } - InputWithData inputWithData = next.get(); - ShardFetcher shardFetcher = inputWithData.input; - PipeFetcherBuffer resultBuffer = inputWithData.data; + Optional> next = getNextInputData(blocking); + if (!next.isPresent()) { + return Optional.empty(); + } - if (resultBuffer.moreAvailable()) { - // This buffer or event was now removed from the non-empty fetchers queue - // we re-add it in case it has more data, because in that case no "non-empty" notification - // will come for that fetcher. - queueFetcher(shardFetcher); - } + InputWithData inputWithData = next.get(); + ShardFetcher shardFetcher = inputWithData.input; + PipeFetcherBuffer resultBuffer = inputWithData.data; - // Set the channel index to identify the input channel (across all unioned input fetchers). - final int channelIndexOffset = inputFetcherToIndexOffsetMap.get(shardFetcher); + if (resultBuffer.moreAvailable()) { + // This buffer or event was now removed from the non-empty fetchers queue + // we re-add it in case it has more data, because in that case no "non-empty" notification + // will come for that fetcher. + queueFetcher(shardFetcher); + } - resultBuffer.setChannelIndex(channelIndexOffset + resultBuffer.getChannelIndex()); - resultBuffer.setMoreAvailable(resultBuffer.moreAvailable() || inputWithData.moreAvailable); + // Set the channel index to identify the input channel (across all unioned input fetchers). + final int channelIndexOffset = inputFetcherToIndexOffsetMap.get(shardFetcher); - return Optional.of(inputWithData.data); - } + resultBuffer.setChannelIndex(channelIndexOffset + resultBuffer.getChannelIndex()); + resultBuffer.setMoreAvailable(resultBuffer.moreAvailable() || inputWithData.moreAvailable); - private Optional> getNextInputData( - boolean blocking) throws IOException, InterruptedException { + return Optional.of(inputWithData.data); + } - ShardFetcher shardFetcher; - boolean moreInputFetchersAvailable; + private Optional> getNextInputData( + boolean blocking) throws IOException, InterruptedException { - while (true) { - Optional> fetcherOptional = getInputFetcher(blocking); - if (!fetcherOptional.isPresent()) { - return Optional.empty(); - } + ShardFetcher shardFetcher; + boolean moreInputFetchersAvailable; - shardFetcher = fetcherOptional.get().f0; - moreInputFetchersAvailable = fetcherOptional.get().f1; + while (true) { + Optional> fetcherOptional = getInputFetcher(blocking); + if (!fetcherOptional.isPresent()) { + return Optional.empty(); + } - Optional result = shardFetcher.pollNext(); + shardFetcher = fetcherOptional.get().f0; + moreInputFetchersAvailable = fetcherOptional.get().f1; - if (result.isPresent()) { - return Optional.of(new InputWithData<>(shardFetcher, result.get(), - moreInputFetchersAvailable)); - } - } - } + Optional result = shardFetcher.pollNext(); - @Override - public void requestSlices(long batchId) throws IOException { - for (ShardFetcher fetcher : shardFetchers) { - fetcher.requestSlices(batchId); - } + if (result.isPresent()) { + return Optional.of( + new InputWithData<>(shardFetcher, result.get(), moreInputFetchersAvailable)); + } } + } - @Override - public void registerListener(ShardFetcherListener listener) { - synchronized (fetcherListeners) { - fetcherListeners.add(listener); - } + @Override + public void requestSlices(long batchId) throws IOException { + for (ShardFetcher fetcher : shardFetchers) { + fetcher.requestSlices(batchId); } + } - @Override - public void notifyAvailable(ShardFetcher shardFetcher) { - queueFetcher(shardFetcher); + @Override + public void registerListener(ShardFetcherListener listener) { + synchronized (fetcherListeners) { + fetcherListeners.add(listener); } + } - private void queueFetcher(ShardFetcher shardFetcher) { - Preconditions.checkNotNull(shardFetcher); + @Override + public void notifyAvailable(ShardFetcher shardFetcher) { + queueFetcher(shardFetcher); + } - int availableInputFetchers; + private void queueFetcher(ShardFetcher shardFetcher) { + Preconditions.checkNotNull(shardFetcher); - synchronized (inputFetchersWithData) { - if (enqueuedInputFetchersWithData.contains(shardFetcher)) { - return; - } + int availableInputFetchers; - availableInputFetchers = inputFetchersWithData.size(); + synchronized (inputFetchersWithData) { + if (enqueuedInputFetchersWithData.contains(shardFetcher)) { + return; + } - inputFetchersWithData.add(shardFetcher); - enqueuedInputFetchersWithData.add(shardFetcher); + availableInputFetchers = inputFetchersWithData.size(); - if (availableInputFetchers == 0) { - inputFetchersWithData.notifyAll(); - } - } + inputFetchersWithData.add(shardFetcher); + enqueuedInputFetchersWithData.add(shardFetcher); - if (availableInputFetchers == 0) { - synchronized (fetcherListeners) { - for (ShardFetcherListener listener : fetcherListeners) { - listener.notifyAvailable(this); - } - } - } + if (availableInputFetchers == 0) { + inputFetchersWithData.notifyAll(); + } } - private Optional> getInputFetcher(boolean blocking) - throws InterruptedException { - synchronized (inputFetchersWithData) { - while (inputFetchersWithData.size() == 0) { - if (blocking) { - inputFetchersWithData.wait(); - } else { - return Optional.empty(); - } - } - - ShardFetcher shardFetcher = inputFetchersWithData.remove(); - enqueuedInputFetchersWithData.remove(shardFetcher); - boolean moreAvailable = !enqueuedInputFetchersWithData.isEmpty(); - - return Optional.of(Tuple.of(shardFetcher, moreAvailable)); + if (availableInputFetchers == 0) { + synchronized (fetcherListeners) { + for (ShardFetcherListener listener : fetcherListeners) { + listener.notifyAvailable(this); } + } } + } + + private Optional> getInputFetcher(boolean blocking) + throws InterruptedException { + synchronized (inputFetchersWithData) { + while (inputFetchersWithData.size() == 0) { + if (blocking) { + inputFetchersWithData.wait(); + } else { + return Optional.empty(); + } + } - /** - * Returns the total number of input channels across all unioned input fetchers. - */ - @Override - public int getNumberOfInputChannels() { - return totalNumberOfInputChannels; - } + ShardFetcher shardFetcher = inputFetchersWithData.remove(); + enqueuedInputFetchersWithData.remove(shardFetcher); + boolean moreAvailable = !enqueuedInputFetchersWithData.isEmpty(); - @VisibleForTesting - public ShardFetcher[] getShardFetchers() { - return shardFetchers; + return Optional.of(Tuple.of(shardFetcher, moreAvailable)); } - - @Override - public boolean isFinished() { - for (ShardFetcher fetcher : shardFetchers) { - if (!fetcher.isFinished()) { - return false; - } - } - return true; + } + + /** Returns the total number of input channels across all unioned input fetchers. */ + @Override + public int getNumberOfInputChannels() { + return totalNumberOfInputChannels; + } + + @VisibleForTesting + public ShardFetcher[] getShardFetchers() { + return shardFetchers; + } + + @Override + public boolean isFinished() { + for (ShardFetcher fetcher : shardFetchers) { + if (!fetcher.isFinished()) { + return false; + } } - - @Override - public void close() { - if (!isReleased) { - for (ShardFetcher fetcher : shardFetchers) { - fetcher.close(); - } - synchronized (inputFetchersWithData) { - inputFetchersWithData.notifyAll(); - } - isReleased = true; - } + return true; + } + + @Override + public void close() { + if (!isReleased) { + for (ShardFetcher fetcher : shardFetchers) { + fetcher.close(); + } + synchronized (inputFetchersWithData) { + inputFetchersWithData.notifyAll(); + } + isReleased = true; } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/fetcher/OneShardFetcher.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/fetcher/OneShardFetcher.java index 1cfdadfc6..6dc38d660 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/fetcher/OneShardFetcher.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/fetcher/OneShardFetcher.java @@ -19,8 +19,6 @@ package org.apache.geaflow.shuffle.pipeline.fetcher; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; import java.io.IOException; import java.util.ArrayDeque; import java.util.ArrayList; @@ -32,6 +30,7 @@ import java.util.Timer; import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicReference; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.shuffle.ShuffleAddress; import org.apache.geaflow.common.tuple.Tuple; @@ -49,413 +48,449 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; + /* This file is based on source code from the Flink Project (http://flink.apache.org/), licensed by the Apache * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. */ /** - * This class is an adaptation of Flink's org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate. + * This class is an adaptation of Flink's + * org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate. */ public class OneShardFetcher implements ShardFetcher { - private static final Logger LOGGER = LoggerFactory.getLogger(OneShardFetcher.class); - private static final int DEFAULT_CONNECTION_ID = -1; - private static final String DEFAULT_STREAM_NAME = ""; - - private final int stageId; - // The name of the owning task, for logging purposes. - private final String taskName; - // Fetcher index starting from 0, 1, ... - private final int fetcherIndex; - // Lock object to guard partition requests and runtime channel updates. - private final Object requestLock = new Object(); - - // Registered listeners to forward message notifications to. - private final List fetcherListeners = new ArrayList<>(); - // Field guaranteeing uniqueness for inputChannelsWithData queue. - private final BitSet enqueuedInputChannelsWithData; - // Channels, which notified this input fetcher about available data. - private final ArrayDeque inputChannelsWithData = new ArrayDeque<>(); - // Next input channel to fetch. - private AbstractInputChannel nextInputChannel; - // Flag indicating whether all resources have been released. - private volatile boolean isReleased; - // Flag indicating whether the fetcher is running. - private volatile boolean isRunning; - - // A timer to re-trigger local partition requests. Only initialized if actually needed. - private Timer retriggerLocalRequestTimer; - private final ExecutorService retriggerRemoteExecutor; - private final AtomicReference cause = new AtomicReference<>(); - - protected final int numberOfInputChannels; - protected final Map inputChannels; - protected final IConnectionManager connectionManager; - protected final String inputStream; - - public OneShardFetcher(int stageId, - String taskName, - int fetcherIndex, - int connectionId, - String inputStream, - List inputSlices, - long startBatchId, - IConnectionManager connectionManager) { - - this.stageId = stageId; - this.inputStream = inputStream; - this.taskName = Preconditions.checkNotNull(taskName); - this.fetcherIndex = fetcherIndex; - this.numberOfInputChannels = inputSlices.size(); - this.inputChannels = new HashMap<>(numberOfInputChannels); - this.enqueuedInputChannelsWithData = new BitSet(numberOfInputChannels); - this.connectionManager = connectionManager; - this.retriggerRemoteExecutor = connectionManager.getExecutor(); - this.isReleased = false; - this.isRunning = true; - - ShuffleConfig nettyConfig = connectionManager.getShuffleConfig(); - int initialBackoff = nettyConfig.getConnectInitialBackoffMs(); - int maxBackoff = nettyConfig.getConnectMaxBackoffMs(); - buildInputChannels(connectionId, inputSlices, initialBackoff, maxBackoff, startBatchId); + private static final Logger LOGGER = LoggerFactory.getLogger(OneShardFetcher.class); + private static final int DEFAULT_CONNECTION_ID = -1; + private static final String DEFAULT_STREAM_NAME = ""; + + private final int stageId; + // The name of the owning task, for logging purposes. + private final String taskName; + // Fetcher index starting from 0, 1, ... + private final int fetcherIndex; + // Lock object to guard partition requests and runtime channel updates. + private final Object requestLock = new Object(); + + // Registered listeners to forward message notifications to. + private final List fetcherListeners = new ArrayList<>(); + // Field guaranteeing uniqueness for inputChannelsWithData queue. + private final BitSet enqueuedInputChannelsWithData; + // Channels, which notified this input fetcher about available data. + private final ArrayDeque inputChannelsWithData = new ArrayDeque<>(); + // Next input channel to fetch. + private AbstractInputChannel nextInputChannel; + // Flag indicating whether all resources have been released. + private volatile boolean isReleased; + // Flag indicating whether the fetcher is running. + private volatile boolean isRunning; + + // A timer to re-trigger local partition requests. Only initialized if actually needed. + private Timer retriggerLocalRequestTimer; + private final ExecutorService retriggerRemoteExecutor; + private final AtomicReference cause = new AtomicReference<>(); + + protected final int numberOfInputChannels; + protected final Map inputChannels; + protected final IConnectionManager connectionManager; + protected final String inputStream; + + public OneShardFetcher( + int stageId, + String taskName, + int fetcherIndex, + int connectionId, + String inputStream, + List inputSlices, + long startBatchId, + IConnectionManager connectionManager) { + + this.stageId = stageId; + this.inputStream = inputStream; + this.taskName = Preconditions.checkNotNull(taskName); + this.fetcherIndex = fetcherIndex; + this.numberOfInputChannels = inputSlices.size(); + this.inputChannels = new HashMap<>(numberOfInputChannels); + this.enqueuedInputChannelsWithData = new BitSet(numberOfInputChannels); + this.connectionManager = connectionManager; + this.retriggerRemoteExecutor = connectionManager.getExecutor(); + this.isReleased = false; + this.isRunning = true; + + ShuffleConfig nettyConfig = connectionManager.getShuffleConfig(); + int initialBackoff = nettyConfig.getConnectInitialBackoffMs(); + int maxBackoff = nettyConfig.getConnectMaxBackoffMs(); + buildInputChannels(connectionId, inputSlices, initialBackoff, maxBackoff, startBatchId); + } + + @VisibleForTesting + public OneShardFetcher( + int stageId, + String taskName, + int fetcherIndex, + List inputSlices, + long startBatchId, + IConnectionManager connectionManager) { + + this( + stageId, + taskName, + fetcherIndex, + DEFAULT_CONNECTION_ID, + DEFAULT_STREAM_NAME, + inputSlices, + startBatchId, + connectionManager); + } + + protected void buildInputChannels( + int connectionId, + List inputSlices, + int initialBackoff, + int maxBackoff, + long initialBatchId) { + + int localChannels = 0; + ShuffleAddress localAddr = connectionManager.getShuffleAddress(); + for (int inputChannelIdx = 0; inputChannelIdx < numberOfInputChannels; inputChannelIdx++) { + PipelineSliceMeta task = (PipelineSliceMeta) inputSlices.get(inputChannelIdx); + ShuffleAddress address = task.getShuffleAddress(); + SliceId inputSlice = task.getSliceId(); + + AbstractInputChannel inputChannel; + if (address.equals(localAddr)) { + inputChannel = + new LocalInputChannel( + this, inputSlice, inputChannelIdx, initialBackoff, maxBackoff, initialBatchId); + inputChannels.put(inputSlice, inputChannel); + localChannels++; + } else { + inputChannel = + new RemoteInputChannel( + this, + inputSlice, + inputChannelIdx, + new ConnectionId(address, connectionId), + initialBackoff, + maxBackoff, + initialBatchId, + connectionManager); + inputChannels.put(inputSlice, inputChannel); + } } - - @VisibleForTesting - public OneShardFetcher(int stageId, - String taskName, - int fetcherIndex, - List inputSlices, - long startBatchId, - IConnectionManager connectionManager) { - - this(stageId, taskName, fetcherIndex, DEFAULT_CONNECTION_ID, DEFAULT_STREAM_NAME, - inputSlices, startBatchId, connectionManager); + LOGGER.info( + "{} create {} local channels in {} channels", + taskName, + localChannels, + numberOfInputChannels); + } + + // ------------------------------------------------------------------------ + // Consume + // ------------------------------------------------------------------------ + + @Override + public void requestSlices(long batchId) { + synchronized (requestLock) { + if (isReleased) { + throw new IllegalStateException("Already released."); + } + + // Sanity checks + if (numberOfInputChannels != inputChannels.size()) { + throw new IllegalStateException( + String.format( + "Bug in input fetcher setup logic: mismatch between " + + "number of total input channels [%s] and the currently set number " + + "of input " + + "channels [%s].", + inputChannels.size(), numberOfInputChannels)); + } + + internalRequestSlices(batchId); + } + } + + private void internalRequestSlices(long batchId) { + for (AbstractInputChannel inputChannel : inputChannels.values()) { + try { + inputChannel.requestSlice(batchId); + } catch (Throwable t) { + inputChannel.setError(t); + return; + } } + LOGGER.info("{} request next batch:{}", taskName, batchId); + } + + public void retriggerFetchRequest(RemoteInputChannel inputChannel) { + checkError(); + retriggerRemoteExecutor.execute( + () -> { + try { + retriggerFetchRequest(inputChannel.getInputSliceId()); + } catch (Throwable e) { + cause.set(e); + } + }); + } - protected void buildInputChannels(int connectionId, List inputSlices, - int initialBackoff, int maxBackoff, long initialBatchId) { - - int localChannels = 0; - ShuffleAddress localAddr = connectionManager.getShuffleAddress(); - for (int inputChannelIdx = 0; inputChannelIdx < numberOfInputChannels; inputChannelIdx++) { - PipelineSliceMeta task = (PipelineSliceMeta) inputSlices.get(inputChannelIdx); - ShuffleAddress address = task.getShuffleAddress(); - SliceId inputSlice = task.getSliceId(); - - AbstractInputChannel inputChannel; - if (address.equals(localAddr)) { - inputChannel = new LocalInputChannel(this, inputSlice, inputChannelIdx, - initialBackoff, maxBackoff, initialBatchId); - inputChannels.put(inputSlice, inputChannel); - localChannels++; - } else { - inputChannel = new RemoteInputChannel(this, inputSlice, inputChannelIdx, - new ConnectionId(address, connectionId), initialBackoff, maxBackoff, - initialBatchId, connectionManager); - inputChannels.put(inputSlice, inputChannel); - } + public void retriggerFetchRequest(SliceId sliceId) throws IOException { + synchronized (requestLock) { + if (!isReleased) { + final AbstractInputChannel ch = inputChannels.get(sliceId); + + if (ch.getClass() == RemoteInputChannel.class) { + final RemoteInputChannel rch = (RemoteInputChannel) ch; + rch.retriggerSliceRequest(sliceId); + } else { + final LocalInputChannel ich = (LocalInputChannel) ch; + if (retriggerLocalRequestTimer == null) { + retriggerLocalRequestTimer = new Timer(true); + } + ich.reTriggerSliceRequest(retriggerLocalRequestTimer); } - LOGGER.info("{} create {} local channels in {} channels", taskName, localChannels, - numberOfInputChannels); + } } + } - // ------------------------------------------------------------------------ - // Consume - // ------------------------------------------------------------------------ + private void checkError() { + final Throwable t = cause.get(); + if (t != null) { + throw new GeaflowRuntimeException(t.getMessage(), t); + } + } - @Override - public void requestSlices(long batchId) { - synchronized (requestLock) { - if (isReleased) { - throw new IllegalStateException("Already released."); - } + @Override + public Optional getNext() throws IOException, InterruptedException { + return getNext(true); + } - // Sanity checks - if (numberOfInputChannels != inputChannels.size()) { - throw new IllegalStateException(String.format( - "Bug in input fetcher setup logic: mismatch between " - + "number of total input channels [%s] and the currently set number " - + "of input " + "channels [%s].", inputChannels.size(), - numberOfInputChannels)); - } + @Override + public Optional pollNext() throws IOException, InterruptedException { + return getNext(false); + } - internalRequestSlices(batchId); - } - } + private Optional getNext(boolean blocking) + throws IOException, InterruptedException { - private void internalRequestSlices(long batchId) { - for (AbstractInputChannel inputChannel : inputChannels.values()) { - try { - inputChannel.requestSlice(batchId); - } catch (Throwable t) { - inputChannel.setError(t); - return; - } - } - LOGGER.info("{} request next batch:{}", taskName, batchId); + if (!isRunning) { + return Optional.empty(); } - - public void retriggerFetchRequest(RemoteInputChannel inputChannel) { - checkError(); - retriggerRemoteExecutor.execute(() -> { - try { - retriggerFetchRequest(inputChannel.getInputSliceId()); - } catch (Throwable e) { - cause.set(e); - } - }); + if (isReleased) { + throw new IOException("Input fetcher is already closed."); } + checkError(); - public void retriggerFetchRequest(SliceId sliceId) throws IOException { - synchronized (requestLock) { - if (!isReleased) { - final AbstractInputChannel ch = inputChannels.get(sliceId); - - if (ch.getClass() == RemoteInputChannel.class) { - final RemoteInputChannel rch = (RemoteInputChannel) ch; - rch.retriggerSliceRequest(sliceId); - } else { - final LocalInputChannel ich = (LocalInputChannel) ch; - if (retriggerLocalRequestTimer == null) { - retriggerLocalRequestTimer = new Timer(true); - } - ich.reTriggerSliceRequest(retriggerLocalRequestTimer); - } - } - } + Optional> next = + getNextInputData(blocking); + if (!next.isPresent()) { + return Optional.empty(); } - private void checkError() { - final Throwable t = cause.get(); - if (t != null) { - throw new GeaflowRuntimeException(t.getMessage(), t); - } - } + InputWithData inputWithData = next.get(); - @Override - public Optional getNext() throws IOException, InterruptedException { - return getNext(true); - } + PipeFetcherBuffer fetcherBuffer = + new PipeFetcherBuffer( + inputWithData.data.getBuffer(), + inputWithData.input.getChannelIndex(), + inputWithData.moreAvailable, + inputWithData.input.getInputSliceId(), + inputStream); + return Optional.of(fetcherBuffer); + } - @Override - public Optional pollNext() throws IOException, InterruptedException { - return getNext(false); - } + private Optional> getNextInputData( + boolean blocking) throws IOException, InterruptedException { - private Optional getNext(boolean blocking) - throws IOException, InterruptedException { + boolean moreAvailable = false; + AbstractInputChannel currentChannel; + Optional result = Optional.empty(); - if (!isRunning) { - return Optional.empty(); + do { + if (nextInputChannel != null) { + currentChannel = nextInputChannel; + synchronized (inputChannelsWithData) { + moreAvailable = inputChannelsWithData.size() > 0; } - if (isReleased) { - throw new IOException("Input fetcher is already closed."); + } else { + Optional> inputChannel = getChannel(blocking); + if (!inputChannel.isPresent()) { + return Optional.empty(); } - checkError(); - Optional> next = getNextInputData( - blocking); - if (!next.isPresent()) { - return Optional.empty(); + currentChannel = inputChannel.get().f0; + if (currentChannel.isReleased()) { + continue; } + moreAvailable = inputChannel.get().f1; + } + + result = currentChannel.getNext(); + + } while (!result.isPresent()); + + // this channel was now removed from the non-empty channels queue + // we re-add it in case it has more data, because in that case no "non-empty" notification + // will come for that channel + if (result.get().moreAvailable()) { + moreAvailable = true; + if (result.get().getBuffer().isData()) { + nextInputChannel = currentChannel; + } else { + queueChannel(currentChannel); + nextInputChannel = null; + } + } else { + nextInputChannel = null; + } + + return Optional.of(new InputWithData<>(currentChannel, result.get(), moreAvailable)); + } - InputWithData inputWithData = next.get(); + // ------------------------------------------------------------------------ + // Channel notifications + // ------------------------------------------------------------------------ - PipeFetcherBuffer fetcherBuffer = new PipeFetcherBuffer(inputWithData.data.getBuffer(), - inputWithData.input.getChannelIndex(), inputWithData.moreAvailable, - inputWithData.input.getInputSliceId(), inputStream); - return Optional.of(fetcherBuffer); + @Override + public void registerListener(ShardFetcherListener inputFetcherListener) { + synchronized (fetcherListeners) { + fetcherListeners.add(inputFetcherListener); } + } - private Optional> getNextInputData( - boolean blocking) throws IOException, InterruptedException { - - boolean moreAvailable = false; - AbstractInputChannel currentChannel; - Optional result = Optional.empty(); - - do { - if (nextInputChannel != null) { - currentChannel = nextInputChannel; - synchronized (inputChannelsWithData) { - moreAvailable = inputChannelsWithData.size() > 0; - } - } else { - Optional> inputChannel = getChannel(blocking); - if (!inputChannel.isPresent()) { - return Optional.empty(); - } - - currentChannel = inputChannel.get().f0; - if (currentChannel.isReleased()) { - continue; - } - moreAvailable = inputChannel.get().f1; - } + public void notifyChannelNonEmpty(AbstractInputChannel channel) { + queueChannel(Preconditions.checkNotNull(channel)); + } - result = currentChannel.getNext(); + private void queueChannel(AbstractInputChannel channel) { + int availableChannels; - } while (!result.isPresent()); + synchronized (inputChannelsWithData) { + if (enqueuedInputChannelsWithData.get(channel.getChannelIndex())) { + return; + } + availableChannels = inputChannelsWithData.size(); - // this channel was now removed from the non-empty channels queue - // we re-add it in case it has more data, because in that case no "non-empty" notification - // will come for that channel - if (result.get().moreAvailable()) { - moreAvailable = true; - if (result.get().getBuffer().isData()) { - nextInputChannel = currentChannel; - } else { - queueChannel(currentChannel); - nextInputChannel = null; - } - } else { - nextInputChannel = null; - } + inputChannelsWithData.add(channel); + enqueuedInputChannelsWithData.set(channel.getChannelIndex()); - return Optional.of(new InputWithData<>(currentChannel, result.get(), moreAvailable)); + if (availableChannels == 0) { + inputChannelsWithData.notifyAll(); + } } - // ------------------------------------------------------------------------ - // Channel notifications - // ------------------------------------------------------------------------ - - @Override - public void registerListener(ShardFetcherListener inputFetcherListener) { - synchronized (fetcherListeners) { - fetcherListeners.add(inputFetcherListener); + if (availableChannels == 0) { + synchronized (fetcherListeners) { + for (ShardFetcherListener listener : fetcherListeners) { + listener.notifyAvailable(this); } + } } + } - public void notifyChannelNonEmpty(AbstractInputChannel channel) { - queueChannel(Preconditions.checkNotNull(channel)); + private Optional> getChannel(boolean blocking) + throws InterruptedException { + if (nextInputChannel != null) { + return Optional.of(Tuple.of(nextInputChannel, true)); } - private void queueChannel(AbstractInputChannel channel) { - int availableChannels; - - synchronized (inputChannelsWithData) { - if (enqueuedInputChannelsWithData.get(channel.getChannelIndex())) { - return; - } - availableChannels = inputChannelsWithData.size(); - - inputChannelsWithData.add(channel); - enqueuedInputChannelsWithData.set(channel.getChannelIndex()); - - if (availableChannels == 0) { - inputChannelsWithData.notifyAll(); - } + synchronized (inputChannelsWithData) { + while (inputChannelsWithData.size() == 0) { + if (!isRunning) { + return Optional.empty(); } - - if (availableChannels == 0) { - synchronized (fetcherListeners) { - for (ShardFetcherListener listener : fetcherListeners) { - listener.notifyAvailable(this); - } - } + if (isReleased) { + throw new IllegalStateException("Channel released"); } - } - - private Optional> getChannel(boolean blocking) - throws InterruptedException { - if (nextInputChannel != null) { - return Optional.of(Tuple.of(nextInputChannel, true)); + if (blocking) { + inputChannelsWithData.wait(); + } else { + return Optional.empty(); } + } - synchronized (inputChannelsWithData) { - while (inputChannelsWithData.size() == 0) { - if (!isRunning) { - return Optional.empty(); - } - if (isReleased) { - throw new IllegalStateException("Channel released"); - } - - if (blocking) { - inputChannelsWithData.wait(); - } else { - return Optional.empty(); - } - } - - AbstractInputChannel inputChannel = inputChannelsWithData.remove(); - enqueuedInputChannelsWithData.clear(inputChannel.getChannelIndex()); - int availableChannels = inputChannelsWithData.size(); + AbstractInputChannel inputChannel = inputChannelsWithData.remove(); + enqueuedInputChannelsWithData.clear(inputChannel.getChannelIndex()); + int availableChannels = inputChannelsWithData.size(); - return Optional.of(Tuple.of(inputChannel, availableChannels > 0)); - } + return Optional.of(Tuple.of(inputChannel, availableChannels > 0)); } - - // ------------------------------------------------------------------------ - // Properties - // ------------------------------------------------------------------------ - - @Override - public int getNumberOfInputChannels() { - return numberOfInputChannels; + } + + // ------------------------------------------------------------------------ + // Properties + // ------------------------------------------------------------------------ + + @Override + public int getNumberOfInputChannels() { + return numberOfInputChannels; + } + + public Map getInputChannels() { + return inputChannels; + } + + public int getFetcherIndex() { + return fetcherIndex; + } + + public int getStageId() { + return stageId; + } + + @Override + public boolean isFinished() { + synchronized (requestLock) { + for (AbstractInputChannel inputChannel : inputChannels.values()) { + if (!inputChannel.isReleased()) { + return false; + } + } } - public Map getInputChannels() { - return inputChannels; - } + return true; + } - public int getFetcherIndex() { - return fetcherIndex; - } + @Override + public void close() { + boolean released = false; + isRunning = false; + synchronized (requestLock) { + if (!isReleased) { + try { + LOGGER.debug("{}: Releasing {}.", taskName, this); - public int getStageId() { - return stageId; - } + if (this.retriggerLocalRequestTimer != null) { + this.retriggerLocalRequestTimer.cancel(); + this.retriggerLocalRequestTimer = null; + } - @Override - public boolean isFinished() { - synchronized (requestLock) { - for (AbstractInputChannel inputChannel : inputChannels.values()) { - if (!inputChannel.isReleased()) { - return false; - } + for (AbstractInputChannel inputChannel : inputChannels.values()) { + try { + inputChannel.release(); + } catch (IOException e) { + LOGGER.error( + "{}: Error during release of channel resources: {}.", + taskName, + e.getMessage(), + e); + throw new GeaflowRuntimeException(e); } + } + } finally { + released = true; + isReleased = true; } - - return true; + } } - @Override - public void close() { - boolean released = false; - isRunning = false; - synchronized (requestLock) { - if (!isReleased) { - try { - LOGGER.debug("{}: Releasing {}.", taskName, this); - - if (this.retriggerLocalRequestTimer != null) { - this.retriggerLocalRequestTimer.cancel(); - this.retriggerLocalRequestTimer = null; - } - - for (AbstractInputChannel inputChannel : inputChannels.values()) { - try { - inputChannel.release(); - } catch (IOException e) { - LOGGER.error("{}: Error during release of channel resources: {}.", - taskName, e.getMessage(), e); - throw new GeaflowRuntimeException(e); - } - } - } finally { - released = true; - isReleased = true; - } - } - } - - if (released) { - synchronized (inputChannelsWithData) { - inputChannelsWithData.notifyAll(); - } - } + if (released) { + synchronized (inputChannelsWithData) { + inputChannelsWithData.notifyAll(); + } } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/fetcher/ShardFetcher.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/fetcher/ShardFetcher.java index 0ab58f18b..6dfc7f389 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/fetcher/ShardFetcher.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/fetcher/ShardFetcher.java @@ -19,75 +19,71 @@ package org.apache.geaflow.shuffle.pipeline.fetcher; -import com.google.common.base.Preconditions; import java.io.IOException; import java.util.Optional; + import org.apache.geaflow.shuffle.pipeline.buffer.PipeFetcherBuffer; +import com.google.common.base.Preconditions; + /* This file is based on source code from the Flink Project (http://flink.apache.org/), licensed by the Apache * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. */ /** - * This class is an adaptation of Flink's org.apache.flink.runtime.io.network.partition.consumer.InputGate. + * This class is an adaptation of Flink's + * org.apache.flink.runtime.io.network.partition.consumer.InputGate. */ public interface ShardFetcher { - /** - * Request the upstream slices and with specific batch id. - * - * @throws IOException io exception. - */ - void requestSlices(long batchId) throws IOException; - - /** - * Blocking call waiting for next {@link PipeFetcherBuffer}. - * - * @return {@code Optional.empty()} if {@link #isFinished()} returns true. - */ - Optional getNext() throws IOException, InterruptedException; - - /** - * Poll the {@link PipeFetcherBuffer}. - * - * @return {@code Optional.empty()} if there is no data to return or if {@link #isFinished()} - * returns true. - */ - Optional pollNext() throws IOException, InterruptedException; - - /** - * Check if data transfer is finished. - */ - boolean isFinished(); - - /** - * Get the number of input channel. - * - * @return channel number. - */ - int getNumberOfInputChannels(); - - /** - * Register fetcher listeners. Notify when fetcher has data. - */ - void registerListener(ShardFetcherListener listener); - - /** - * Close. - */ - void close(); - - class InputWithData { - - protected final INPUT input; - protected final DATA data; - protected final boolean moreAvailable; - - InputWithData(INPUT input, DATA data, boolean moreAvailable) { - this.input = Preconditions.checkNotNull(input); - this.data = Preconditions.checkNotNull(data); - this.moreAvailable = moreAvailable; - } - } + /** + * Request the upstream slices and with specific batch id. + * + * @throws IOException io exception. + */ + void requestSlices(long batchId) throws IOException; + + /** + * Blocking call waiting for next {@link PipeFetcherBuffer}. + * + * @return {@code Optional.empty()} if {@link #isFinished()} returns true. + */ + Optional getNext() throws IOException, InterruptedException; + + /** + * Poll the {@link PipeFetcherBuffer}. + * + * @return {@code Optional.empty()} if there is no data to return or if {@link #isFinished()} + * returns true. + */ + Optional pollNext() throws IOException, InterruptedException; + /** Check if data transfer is finished. */ + boolean isFinished(); + + /** + * Get the number of input channel. + * + * @return channel number. + */ + int getNumberOfInputChannels(); + + /** Register fetcher listeners. Notify when fetcher has data. */ + void registerListener(ShardFetcherListener listener); + + /** Close. */ + void close(); + + class InputWithData { + + protected final INPUT input; + protected final DATA data; + protected final boolean moreAvailable; + + InputWithData(INPUT input, DATA data, boolean moreAvailable) { + this.input = Preconditions.checkNotNull(input); + this.data = Preconditions.checkNotNull(data); + this.moreAvailable = moreAvailable; + } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/fetcher/ShardFetcherListener.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/fetcher/ShardFetcherListener.java index f70441576..c1f9d4c09 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/fetcher/ShardFetcherListener.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/fetcher/ShardFetcherListener.java @@ -21,12 +21,10 @@ public interface ShardFetcherListener { - /** - * notify if the input fetcher moves from zero to non-zero - * available input channels with data. - * - * @param shardFetcher Input fetcher that became available. - */ - void notifyAvailable(ShardFetcher shardFetcher); - + /** + * notify if the input fetcher moves from zero to non-zero available input channels with data. + * + * @param shardFetcher Input fetcher that became available. + */ + void notifyAvailable(ShardFetcher shardFetcher); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/AbstractSlice.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/AbstractSlice.java index 9c1e2836a..2dd8b7fc3 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/AbstractSlice.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/AbstractSlice.java @@ -20,6 +20,7 @@ package org.apache.geaflow.shuffle.pipeline.slice; import java.util.ArrayDeque; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.shuffle.message.SliceId; import org.apache.geaflow.shuffle.pipeline.buffer.PipeBuffer; @@ -28,78 +29,77 @@ public abstract class AbstractSlice implements IPipelineSlice { - private static final Logger LOGGER = LoggerFactory.getLogger(PipelineSlice.class); - - protected final SliceId sliceId; - protected final String taskLogTag; - protected int totalBufferCount; - protected ArrayDeque buffers; - protected PipelineSliceReader sliceReader; - protected volatile boolean isReleased; - - public AbstractSlice(String taskLogTag, SliceId sliceId) { - this.sliceId = sliceId; - this.taskLogTag = taskLogTag; - this.totalBufferCount = 0; - this.buffers = new ArrayDeque<>(); + private static final Logger LOGGER = LoggerFactory.getLogger(PipelineSlice.class); + + protected final SliceId sliceId; + protected final String taskLogTag; + protected int totalBufferCount; + protected ArrayDeque buffers; + protected PipelineSliceReader sliceReader; + protected volatile boolean isReleased; + + public AbstractSlice(String taskLogTag, SliceId sliceId) { + this.sliceId = sliceId; + this.taskLogTag = taskLogTag; + this.totalBufferCount = 0; + this.buffers = new ArrayDeque<>(); + } + + @Override + public SliceId getSliceId() { + return sliceId; + } + + @Override + public PipelineSliceReader createSliceReader(long startBatchId, PipelineSliceListener listener) { + synchronized (buffers) { + if (isReleased) { + throw new GeaflowRuntimeException("slice is released:" + sliceId); + } + if (sliceReader != null && sliceReader.hasNext()) { + throw new GeaflowRuntimeException("slice is already created:" + sliceId); + } + + LOGGER.debug( + "creating reader for {} {} with startBatch:{}", taskLogTag, sliceId, startBatchId); + + sliceReader = new DisposableSliceReader(this, startBatchId, listener); + return sliceReader; } - - @Override - public SliceId getSliceId() { - return sliceId; - } - - @Override - public PipelineSliceReader createSliceReader(long startBatchId, PipelineSliceListener listener) { - synchronized (buffers) { - if (isReleased) { - throw new GeaflowRuntimeException("slice is released:" + sliceId); - } - if (sliceReader != null && sliceReader.hasNext()) { - throw new GeaflowRuntimeException("slice is already created:" + sliceId); - } - - LOGGER.debug("creating reader for {} {} with startBatch:{}", - taskLogTag, sliceId, startBatchId); - - sliceReader = new DisposableSliceReader(this, startBatchId, listener); - return sliceReader; - } + } + + @Override + public boolean canRelease() { + return !hasNext(); + } + + @Override + public boolean isReleased() { + return isReleased; + } + + @Override + public void release() { + int bufferSize; + final PipelineSliceReader reader; + + synchronized (buffers) { + if (isReleased) { + return; + } + + // Release all available buffers + bufferSize = buffers.size(); + buffers.clear(); + + reader = sliceReader; + sliceReader = null; + isReleased = true; } - @Override - public boolean canRelease() { - return !hasNext(); + LOGGER.info("{}: released {} with bufferSize:{}", taskLogTag, sliceId, bufferSize); + if (reader != null) { + reader.release(); } - - @Override - public boolean isReleased() { - return isReleased; - } - - @Override - public void release() { - int bufferSize; - final PipelineSliceReader reader; - - synchronized (buffers) { - if (isReleased) { - return; - } - - // Release all available buffers - bufferSize = buffers.size(); - buffers.clear(); - - reader = sliceReader; - sliceReader = null; - isReleased = true; - } - - LOGGER.info("{}: released {} with bufferSize:{}", taskLogTag, sliceId, bufferSize); - if (reader != null) { - reader.release(); - } - } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/BlockingSlice.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/BlockingSlice.java index e53cf3fc7..fa7e38338 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/BlockingSlice.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/BlockingSlice.java @@ -19,114 +19,113 @@ package org.apache.geaflow.shuffle.pipeline.slice; -import com.google.common.base.Preconditions; import org.apache.geaflow.shuffle.api.writer.PipelineShardWriter; import org.apache.geaflow.shuffle.message.SliceId; import org.apache.geaflow.shuffle.pipeline.buffer.PipeBuffer; -public class BlockingSlice extends AbstractSlice { - - private final PipelineShardWriter parentWriter; - private boolean flushRequested; - - public BlockingSlice(String taskLogTag, SliceId sliceId, - PipelineShardWriter parentWriter) { - super(taskLogTag, sliceId); - this.parentWriter = parentWriter; - } - - // ------------------------------------------------------------------------ - // Produce - // ------------------------------------------------------------------------ - - @Override - public boolean add(PipeBuffer recordBuffer) { - final boolean notifyDataAvailable; - synchronized (buffers) { - if (isReleased) { - return false; - } - buffers.add(recordBuffer); - notifyDataAvailable = shouldNotifyDataAvailable(); - } +import com.google.common.base.Preconditions; - if (notifyDataAvailable) { - notifyDataAvailable(recordBuffer.getBatchId()); - } - return true; - } +public class BlockingSlice extends AbstractSlice { - private boolean shouldNotifyDataAvailable() { - return sliceReader != null && !this.flushRequested && getCurrentNumberOfBuffers() == 1; + private final PipelineShardWriter parentWriter; + private boolean flushRequested; + + public BlockingSlice(String taskLogTag, SliceId sliceId, PipelineShardWriter parentWriter) { + super(taskLogTag, sliceId); + this.parentWriter = parentWriter; + } + + // ------------------------------------------------------------------------ + // Produce + // ------------------------------------------------------------------------ + + @Override + public boolean add(PipeBuffer recordBuffer) { + final boolean notifyDataAvailable; + synchronized (buffers) { + if (isReleased) { + return false; + } + buffers.add(recordBuffer); + notifyDataAvailable = shouldNotifyDataAvailable(); } - @Override - public void flush() { - long batchId; - boolean needNotify; - synchronized (buffers) { - if (buffers.isEmpty()) { - return; - } - - batchId = buffers.peekLast().getBatchId(); - needNotify = !flushRequested && buffers.size() == 1; - updateFlushRequested(flushRequested || buffers.size() > 1 || needNotify); - } - - if (needNotify) { - notifyDataAvailable(batchId); - } + if (notifyDataAvailable) { + notifyDataAvailable(recordBuffer.getBatchId()); } - - private void notifyDataAvailable(long batchId) { - final PipelineSliceReader reader = sliceReader; - if (reader != null) { - reader.notifyAvailable(batchId); - } + return true; + } + + private boolean shouldNotifyDataAvailable() { + return sliceReader != null && !this.flushRequested && getCurrentNumberOfBuffers() == 1; + } + + @Override + public void flush() { + long batchId; + boolean needNotify; + synchronized (buffers) { + if (buffers.isEmpty()) { + return; + } + + batchId = buffers.peekLast().getBatchId(); + needNotify = !flushRequested && buffers.size() == 1; + updateFlushRequested(flushRequested || buffers.size() > 1 || needNotify); } - // ------------------------------------------------------------------------ - // Consume - // ------------------------------------------------------------------------ - - @Override - public boolean isReady2read() { - return true; + if (needNotify) { + notifyDataAvailable(batchId); } + } - @Override - public PipeBuffer next() { - synchronized (buffers) { - PipeBuffer buffer = null; - if (!buffers.isEmpty()) { - buffer = buffers.pop(); - if (buffers.isEmpty()) { - updateFlushRequested(false); - } - } - if (buffer != null && buffer.isData()) { - parentWriter.notifyBufferConsumed(buffer.getBufferSize()); - } - return buffer; - } + private void notifyDataAvailable(long batchId) { + final PipelineSliceReader reader = sliceReader; + if (reader != null) { + reader.notifyAvailable(batchId); } - - @Override - public boolean hasNext() { - synchronized (buffers) { - return this.flushRequested || getCurrentNumberOfBuffers() > 0; + } + + // ------------------------------------------------------------------------ + // Consume + // ------------------------------------------------------------------------ + + @Override + public boolean isReady2read() { + return true; + } + + @Override + public PipeBuffer next() { + synchronized (buffers) { + PipeBuffer buffer = null; + if (!buffers.isEmpty()) { + buffer = buffers.pop(); + if (buffers.isEmpty()) { + updateFlushRequested(false); } + } + if (buffer != null && buffer.isData()) { + parentWriter.notifyBufferConsumed(buffer.getBufferSize()); + } + return buffer; } + } - private int getCurrentNumberOfBuffers() { - Preconditions.checkArgument(Thread.holdsLock(buffers), "fail to get lock of buffers"); - return buffers.size(); + @Override + public boolean hasNext() { + synchronized (buffers) { + return this.flushRequested || getCurrentNumberOfBuffers() > 0; } + } - private void updateFlushRequested(boolean flushRequested) { - Preconditions.checkArgument(Thread.holdsLock(buffers), "fail to get lock of buffers"); - this.flushRequested = flushRequested; - } + private int getCurrentNumberOfBuffers() { + Preconditions.checkArgument(Thread.holdsLock(buffers), "fail to get lock of buffers"); + return buffers.size(); + } + private void updateFlushRequested(boolean flushRequested) { + Preconditions.checkArgument(Thread.holdsLock(buffers), "fail to get lock of buffers"); + this.flushRequested = flushRequested; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/DisposableSliceReader.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/DisposableSliceReader.java index c72700afc..929fe353c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/DisposableSliceReader.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/DisposableSliceReader.java @@ -22,33 +22,30 @@ import org.apache.geaflow.shuffle.pipeline.buffer.PipeBuffer; import org.apache.geaflow.shuffle.pipeline.buffer.PipeChannelBuffer; -/** - * DisposableReader poll the data from slice queue and release the slice on completion. - */ +/** DisposableReader poll the data from slice queue and release the slice on completion. */ public class DisposableSliceReader extends PipelineSliceReader { - public DisposableSliceReader(IPipelineSlice slice, long startBatchId, - PipelineSliceListener listener) { - super(slice, startBatchId, listener); - } + public DisposableSliceReader( + IPipelineSlice slice, long startBatchId, PipelineSliceListener listener) { + super(slice, startBatchId, listener); + } - @Override - public boolean hasNext() { - if (isReleased()) { - throw new IllegalStateException("slice has been released already: " + slice.getSliceId()); - } - return hasBatch() && slice.hasNext(); + @Override + public boolean hasNext() { + if (isReleased()) { + throw new IllegalStateException("slice has been released already: " + slice.getSliceId()); } + return hasBatch() && slice.hasNext(); + } - public PipeChannelBuffer next() { - PipeBuffer record = slice.next(); - if (record == null) { - return null; - } - if (!record.isData()) { - consumedBatchId = record.getBatchId(); - } - return new PipeChannelBuffer(record, hasNext()); + public PipeChannelBuffer next() { + PipeBuffer record = slice.next(); + if (record == null) { + return null; } - + if (!record.isData()) { + consumedBatchId = record.getBatchId(); + } + return new PipeChannelBuffer(record, hasNext()); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/IPipelineSlice.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/IPipelineSlice.java index bd227f760..eec7de221 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/IPipelineSlice.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/IPipelineSlice.java @@ -24,43 +24,38 @@ public interface IPipelineSlice { - SliceId getSliceId(); + SliceId getSliceId(); - boolean isReleased(); + boolean isReleased(); - boolean canRelease(); + boolean canRelease(); - void release(); + void release(); - // ------------------------------------------------------------------------ - // Produce - // ------------------------------------------------------------------------ + // ------------------------------------------------------------------------ + // Produce + // ------------------------------------------------------------------------ - boolean add(PipeBuffer recordBuffer); + boolean add(PipeBuffer recordBuffer); - void flush(); + void flush(); - // ------------------------------------------------------------------------ - // Consume - // ------------------------------------------------------------------------ + // ------------------------------------------------------------------------ + // Consume + // ------------------------------------------------------------------------ - PipelineSliceReader createSliceReader(long startBatchId, PipelineSliceListener listener); + PipelineSliceReader createSliceReader(long startBatchId, PipelineSliceListener listener); - /** - * Check whether the slice is ready to read. - */ - boolean isReady2read(); + /** Check whether the slice is ready to read. */ + boolean isReady2read(); - /** - * Check whether the slice has next record. - */ - boolean hasNext(); - - /** - * Poll next record from slice. - * - * @return - */ - PipeBuffer next(); + /** Check whether the slice has next record. */ + boolean hasNext(); + /** + * Poll next record from slice. + * + * @return + */ + PipeBuffer next(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/PipelineSlice.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/PipelineSlice.java index ddfc2a923..bf24b833a 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/PipelineSlice.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/PipelineSlice.java @@ -19,108 +19,108 @@ package org.apache.geaflow.shuffle.pipeline.slice; -import com.google.common.base.Preconditions; import org.apache.geaflow.shuffle.message.SliceId; import org.apache.geaflow.shuffle.pipeline.buffer.PipeBuffer; -public class PipelineSlice extends AbstractSlice { - - private boolean flushRequested; - - public PipelineSlice(String taskLogTag, SliceId sliceId) { - super(taskLogTag, sliceId); - } - - // ------------------------------------------------------------------------ - // Produce - // ------------------------------------------------------------------------ - - @Override - public boolean add(PipeBuffer recordBuffer) { - final boolean notifyDataAvailable; - synchronized (buffers) { - if (isReleased) { - return false; - } - totalBufferCount++; - buffers.add(recordBuffer); - notifyDataAvailable = shouldNotifyDataAvailable(); - } +import com.google.common.base.Preconditions; - if (notifyDataAvailable) { - notifyDataAvailable(recordBuffer.getBatchId()); - } - return true; - } +public class PipelineSlice extends AbstractSlice { - private boolean shouldNotifyDataAvailable() { - return sliceReader != null && !this.flushRequested && getCurrentNumberOfBuffers() == 1; + private boolean flushRequested; + + public PipelineSlice(String taskLogTag, SliceId sliceId) { + super(taskLogTag, sliceId); + } + + // ------------------------------------------------------------------------ + // Produce + // ------------------------------------------------------------------------ + + @Override + public boolean add(PipeBuffer recordBuffer) { + final boolean notifyDataAvailable; + synchronized (buffers) { + if (isReleased) { + return false; + } + totalBufferCount++; + buffers.add(recordBuffer); + notifyDataAvailable = shouldNotifyDataAvailable(); } - @Override - public void flush() { - long batchId; - boolean needNotify; - synchronized (buffers) { - if (buffers.isEmpty()) { - return; - } - - batchId = buffers.peekLast().getBatchId(); - needNotify = !flushRequested && buffers.size() == 1; - updateFlushRequested(flushRequested || buffers.size() > 1 || needNotify); - } - - if (needNotify) { - notifyDataAvailable(batchId); - } + if (notifyDataAvailable) { + notifyDataAvailable(recordBuffer.getBatchId()); } - - private void notifyDataAvailable(long batchId) { - final PipelineSliceReader reader = sliceReader; - if (reader != null) { - reader.notifyAvailable(batchId); - } + return true; + } + + private boolean shouldNotifyDataAvailable() { + return sliceReader != null && !this.flushRequested && getCurrentNumberOfBuffers() == 1; + } + + @Override + public void flush() { + long batchId; + boolean needNotify; + synchronized (buffers) { + if (buffers.isEmpty()) { + return; + } + + batchId = buffers.peekLast().getBatchId(); + needNotify = !flushRequested && buffers.size() == 1; + updateFlushRequested(flushRequested || buffers.size() > 1 || needNotify); } - // ------------------------------------------------------------------------ - // Consume - // ------------------------------------------------------------------------ - - @Override - public boolean isReady2read() { - return true; + if (needNotify) { + notifyDataAvailable(batchId); } + } - @Override - public PipeBuffer next() { - synchronized (buffers) { - PipeBuffer buffer = null; - if (!buffers.isEmpty()) { - buffer = buffers.pop(); - if (buffers.size() == 0) { - updateFlushRequested(false); - } - } - return buffer; - } + private void notifyDataAvailable(long batchId) { + final PipelineSliceReader reader = sliceReader; + if (reader != null) { + reader.notifyAvailable(batchId); } - - @Override - public boolean hasNext() { - synchronized (buffers) { - return this.flushRequested || getCurrentNumberOfBuffers() > 0; + } + + // ------------------------------------------------------------------------ + // Consume + // ------------------------------------------------------------------------ + + @Override + public boolean isReady2read() { + return true; + } + + @Override + public PipeBuffer next() { + synchronized (buffers) { + PipeBuffer buffer = null; + if (!buffers.isEmpty()) { + buffer = buffers.pop(); + if (buffers.size() == 0) { + updateFlushRequested(false); } + } + return buffer; } + } - private int getCurrentNumberOfBuffers() { - Preconditions.checkArgument(Thread.holdsLock(buffers), "fail to get lock of buffers"); - return buffers.size(); + @Override + public boolean hasNext() { + synchronized (buffers) { + return this.flushRequested || getCurrentNumberOfBuffers() > 0; } + } - private void updateFlushRequested(boolean flushRequested) { - Preconditions.checkArgument(Thread.holdsLock(buffers), "fail to get lock of buffers"); - this.flushRequested = flushRequested; - } + private int getCurrentNumberOfBuffers() { + Preconditions.checkArgument(Thread.holdsLock(buffers), "fail to get lock of buffers"); + return buffers.size(); + } + private void updateFlushRequested(boolean flushRequested) { + Preconditions.checkArgument(Thread.holdsLock(buffers), "fail to get lock of buffers"); + this.flushRequested = flushRequested; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/PipelineSliceListener.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/PipelineSliceListener.java index 2948d7e92..381f38c2d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/PipelineSliceListener.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/PipelineSliceListener.java @@ -21,9 +21,6 @@ public interface PipelineSliceListener { - /** - * Called whenever there might be new data available in slice. - */ - void notifyDataAvailable(); - + /** Called whenever there might be new data available in slice. */ + void notifyDataAvailable(); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/PipelineSliceReader.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/PipelineSliceReader.java index 2ef6394f7..ba3d8ab90 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/PipelineSliceReader.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/PipelineSliceReader.java @@ -20,59 +20,58 @@ package org.apache.geaflow.shuffle.pipeline.slice; import java.util.concurrent.atomic.AtomicBoolean; + import org.apache.geaflow.shuffle.pipeline.buffer.PipeChannelBuffer; import org.apache.geaflow.shuffle.pipeline.channel.LocalInputChannel; /** - * Called by {@link SequenceSliceReader} for remote consumption - * and {@link LocalInputChannel} - * for local consumption. + * Called by {@link SequenceSliceReader} for remote consumption and {@link LocalInputChannel} for + * local consumption. */ public abstract class PipelineSliceReader { - // Client request batch id. - private volatile long requestBatchId; - private final PipelineSliceListener listener; - - protected final IPipelineSlice slice; - protected volatile long consumedBatchId; - protected final AtomicBoolean released; + // Client request batch id. + private volatile long requestBatchId; + private final PipelineSliceListener listener; - public PipelineSliceReader(IPipelineSlice slice, long startBatchId, - PipelineSliceListener listener) { - this.slice = slice; - this.listener = listener; - this.consumedBatchId = -1; - this.requestBatchId = startBatchId; - this.released = new AtomicBoolean(); - } + protected final IPipelineSlice slice; + protected volatile long consumedBatchId; + protected final AtomicBoolean released; - public void updateRequestedBatchId(long batchId) { - this.requestBatchId = batchId; - } + public PipelineSliceReader( + IPipelineSlice slice, long startBatchId, PipelineSliceListener listener) { + this.slice = slice; + this.listener = listener; + this.consumedBatchId = -1; + this.requestBatchId = startBatchId; + this.released = new AtomicBoolean(); + } - public void notifyAvailable(long batchId) { - if (requestBatchId == -1 || batchId <= requestBatchId) { - listener.notifyDataAvailable(); - } - } + public void updateRequestedBatchId(long batchId) { + this.requestBatchId = batchId; + } - protected boolean hasBatch() { - return requestBatchId == -1 || consumedBatchId < requestBatchId; + public void notifyAvailable(long batchId) { + if (requestBatchId == -1 || batchId <= requestBatchId) { + listener.notifyDataAvailable(); } + } - public abstract boolean hasNext(); + protected boolean hasBatch() { + return requestBatchId == -1 || consumedBatchId < requestBatchId; + } - public abstract PipeChannelBuffer next(); + public abstract boolean hasNext(); - public void release() { - if (released.compareAndSet(false, true) && slice.canRelease()) { - slice.release(); - } - } + public abstract PipeChannelBuffer next(); - public boolean isReleased() { - return released.get() || slice.isReleased(); + public void release() { + if (released.compareAndSet(false, true) && slice.canRelease()) { + slice.release(); } + } + public boolean isReleased() { + return released.get() || slice.isReleased(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/SequenceSliceReader.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/SequenceSliceReader.java index b9f79f1cb..73396a936 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/SequenceSliceReader.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/SequenceSliceReader.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.concurrent.atomic.AtomicInteger; + import org.apache.geaflow.shuffle.message.SliceId; import org.apache.geaflow.shuffle.network.netty.SliceOutputChannelHandler; import org.apache.geaflow.shuffle.pipeline.buffer.PipeChannelBuffer; @@ -31,105 +32,105 @@ public class SequenceSliceReader implements PipelineSliceListener { - private static final Logger LOGGER = LoggerFactory.getLogger(SequenceSliceReader.class); - private final ChannelId inputChannelId; - private final SliceOutputChannelHandler requestHandler; - - private SliceId sliceId; - private PipelineSliceReader sliceReader; - private int sequenceNumber = -1; - private int initialCredit; - private AtomicInteger availableCredit; - - private volatile boolean isRegistered = false; - private volatile boolean isReleased = false; - - public SequenceSliceReader(ChannelId inputChannelId, SliceOutputChannelHandler requestHandler) { - this.inputChannelId = inputChannelId; - this.requestHandler = requestHandler; - } - - public void createSliceReader(SliceId sliceId, long startBatchId, int initCredit) - throws IOException { - this.sliceId = sliceId; - SliceManager sliceManager = ShuffleManager.getInstance().getSliceManager(); - this.sliceReader = sliceManager.createSliceReader(sliceId, startBatchId, this); - this.initialCredit = initCredit; - this.availableCredit = new AtomicInteger(initCredit); - notifyDataAvailable(); - } - - @Override - public void notifyDataAvailable() { - requestHandler.notifyNonEmpty(this); - } - - public void requestBatch(long batchId) { - sliceReader.updateRequestedBatchId(batchId); + private static final Logger LOGGER = LoggerFactory.getLogger(SequenceSliceReader.class); + private final ChannelId inputChannelId; + private final SliceOutputChannelHandler requestHandler; + + private SliceId sliceId; + private PipelineSliceReader sliceReader; + private int sequenceNumber = -1; + private int initialCredit; + private AtomicInteger availableCredit; + + private volatile boolean isRegistered = false; + private volatile boolean isReleased = false; + + public SequenceSliceReader(ChannelId inputChannelId, SliceOutputChannelHandler requestHandler) { + this.inputChannelId = inputChannelId; + this.requestHandler = requestHandler; + } + + public void createSliceReader(SliceId sliceId, long startBatchId, int initCredit) + throws IOException { + this.sliceId = sliceId; + SliceManager sliceManager = ShuffleManager.getInstance().getSliceManager(); + this.sliceReader = sliceManager.createSliceReader(sliceId, startBatchId, this); + this.initialCredit = initCredit; + this.availableCredit = new AtomicInteger(initCredit); + notifyDataAvailable(); + } + + @Override + public void notifyDataAvailable() { + requestHandler.notifyNonEmpty(this); + } + + public void requestBatch(long batchId) { + sliceReader.updateRequestedBatchId(batchId); + } + + public void addCredit(int credit) { + int avail = this.availableCredit.addAndGet(credit); + if (avail > initialCredit) { + LOGGER.warn("available credit {} > initial credit {}", avail, initialCredit); } - - public void addCredit(int credit) { - int avail = this.availableCredit.addAndGet(credit); - if (avail > initialCredit) { - LOGGER.warn("available credit {} > initial credit {}", avail, initialCredit); - } + } + + public boolean hasNext() { + return sliceReader != null && sliceReader.hasNext(); + } + + public boolean isAvailable() { + // initial credit less than 0, means credit is unlimited. + return sliceReader != null + && sliceReader.hasNext() + && (initialCredit <= 0 || availableCredit.get() > 0); + } + + public PipeChannelBuffer next() { + if (isReleased) { + throw new IllegalArgumentException("slice has been released already: " + sliceId); } - - public boolean hasNext() { - return sliceReader != null && sliceReader.hasNext(); - } - - public boolean isAvailable() { - // initial credit less than 0, means credit is unlimited. - return sliceReader != null && sliceReader.hasNext() && (initialCredit <= 0 - || availableCredit.get() > 0); + PipeChannelBuffer next = sliceReader.next(); + if (next != null) { + sequenceNumber++; + if (next.getBuffer().isData()) { + availableCredit.decrementAndGet(); + } + return next; } - - public PipeChannelBuffer next() { - if (isReleased) { - throw new IllegalArgumentException("slice has been released already: " + sliceId); - } - PipeChannelBuffer next = sliceReader.next(); - if (next != null) { - sequenceNumber++; - if (next.getBuffer().isData()) { - availableCredit.decrementAndGet(); - } - return next; - } - return null; - } - - public int getSequenceNumber() { - return sequenceNumber; - } - - public ChannelId getReceiverId() { - return inputChannelId; - } - - public void setRegistered(boolean registered) { - this.isRegistered = registered; - } - - public boolean isRegistered() { - return isRegistered; - } - - public void releaseAllResources() throws IOException { - if (!isReleased) { - isReleased = true; - PipelineSliceReader reader = sliceReader; - if (reader != null) { - reader.release(); - sliceReader = null; - } - } - } - - @Override - public String toString() { - return "SequenceSliceReader{" + "sliceId=" + sliceId + '}'; + return null; + } + + public int getSequenceNumber() { + return sequenceNumber; + } + + public ChannelId getReceiverId() { + return inputChannelId; + } + + public void setRegistered(boolean registered) { + this.isRegistered = registered; + } + + public boolean isRegistered() { + return isRegistered; + } + + public void releaseAllResources() throws IOException { + if (!isReleased) { + isReleased = true; + PipelineSliceReader reader = sliceReader; + if (reader != null) { + reader.release(); + sliceReader = null; + } } + } + @Override + public String toString() { + return "SequenceSliceReader{" + "sliceId=" + sliceId + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/SliceManager.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/SliceManager.java index a2970b2d0..64691db11 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/SliceManager.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/SliceManager.java @@ -25,6 +25,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.shuffle.message.SliceId; import org.apache.geaflow.shuffle.util.SliceNotFoundException; @@ -33,57 +34,56 @@ public class SliceManager { - private static final Logger LOGGER = LoggerFactory.getLogger(SliceManager.class); - private final Map> pipeline2slices = new HashMap<>(); - private final Map slices = new ConcurrentHashMap<>(); + private static final Logger LOGGER = LoggerFactory.getLogger(SliceManager.class); + private final Map> pipeline2slices = new HashMap<>(); + private final Map slices = new ConcurrentHashMap<>(); - public void register(SliceId sliceId, IPipelineSlice slice) { - if (this.slices.containsKey(sliceId)) { - throw new GeaflowRuntimeException("slice already registered: " + sliceId); - } - LOGGER.debug("register slice {} {}", sliceId, slice.getClass().getSimpleName()); - this.slices.put(sliceId, slice); - synchronized (this.pipeline2slices) { - long pipelineId = sliceId.getWriterId().getPipelineId(); - Set sliceIds = this.pipeline2slices.computeIfAbsent(pipelineId, k -> new HashSet<>()); - sliceIds.add(sliceId); - } + public void register(SliceId sliceId, IPipelineSlice slice) { + if (this.slices.containsKey(sliceId)) { + throw new GeaflowRuntimeException("slice already registered: " + sliceId); } - - public IPipelineSlice getSlice(SliceId sliceId) { - return this.slices.get(sliceId); + LOGGER.debug("register slice {} {}", sliceId, slice.getClass().getSimpleName()); + this.slices.put(sliceId, slice); + synchronized (this.pipeline2slices) { + long pipelineId = sliceId.getWriterId().getPipelineId(); + Set sliceIds = + this.pipeline2slices.computeIfAbsent(pipelineId, k -> new HashSet<>()); + sliceIds.add(sliceId); } + } - public PipelineSliceReader createSliceReader(SliceId sliceId, - long startBatchId, - PipelineSliceListener listener) throws IOException { - IPipelineSlice slice = this.getSlice(sliceId); - if (slice == null) { - throw new SliceNotFoundException(sliceId); - } - return slice.createSliceReader(startBatchId, listener); + public IPipelineSlice getSlice(SliceId sliceId) { + return this.slices.get(sliceId); + } + + public PipelineSliceReader createSliceReader( + SliceId sliceId, long startBatchId, PipelineSliceListener listener) throws IOException { + IPipelineSlice slice = this.getSlice(sliceId); + if (slice == null) { + throw new SliceNotFoundException(sliceId); } + return slice.createSliceReader(startBatchId, listener); + } - public void release(SliceId sliceId) { - IPipelineSlice slice = this.slices.remove(sliceId); - if (slice != null && !slice.isReleased()) { - slice.release(); - LOGGER.info("release slice {}", sliceId); - } + public void release(SliceId sliceId) { + IPipelineSlice slice = this.slices.remove(sliceId); + if (slice != null && !slice.isReleased()) { + slice.release(); + LOGGER.info("release slice {}", sliceId); } + } - public void release(long pipelineId) { - if (!this.pipeline2slices.containsKey(pipelineId)) { - return; - } - synchronized (this.pipeline2slices) { - Set sliceIds = this.pipeline2slices.remove(pipelineId); - if (sliceIds != null) { - for (SliceId sliceId : sliceIds) { - this.release(sliceId); - } - } + public void release(long pipelineId) { + if (!this.pipeline2slices.containsKey(pipelineId)) { + return; + } + synchronized (this.pipeline2slices) { + Set sliceIds = this.pipeline2slices.remove(pipelineId); + if (sliceIds != null) { + for (SliceId sliceId : sliceIds) { + this.release(sliceId); } + } } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/SpillablePipelineSlice.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/SpillablePipelineSlice.java index 1aa086ef2..e54ace593 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/SpillablePipelineSlice.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/pipeline/slice/SpillablePipelineSlice.java @@ -25,6 +25,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; + import org.apache.geaflow.common.encoder.Encoders; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.iterator.CloseableIterator; @@ -41,253 +42,259 @@ public class SpillablePipelineSlice extends AbstractSlice { - private static final Logger LOGGER = LoggerFactory.getLogger(SpillablePipelineSlice.class); + private static final Logger LOGGER = LoggerFactory.getLogger(SpillablePipelineSlice.class); + + private final String fileName; + private final StorageLevel storageLevel; + private final ShuffleStore store; + private final ShuffleMemoryTracker memoryTracker; + + private OutputStream outputStream; + private CloseableIterator streamBufferIterator; + private PipeBuffer value; + private volatile boolean ready2read = false; + + // Bytes count in memory. + private long memoryBytes = 0; + // Bytes count on disk. + private long diskBytes = 0; + + public SpillablePipelineSlice(String taskLogTag, SliceId sliceId) { + this( + taskLogTag, + sliceId, + ShuffleManager.getInstance().getShuffleConfig(), + ShuffleManager.getInstance().getShuffleMemoryTracker()); + } + + public SpillablePipelineSlice( + String taskLogTag, + SliceId sliceId, + ShuffleConfig shuffleConfig, + ShuffleMemoryTracker memoryTracker) { + super(taskLogTag, sliceId); + this.storageLevel = shuffleConfig.getStorageLevel(); + this.store = ShuffleStore.getShuffleStore(shuffleConfig); + String fileName = + String.format( + "shuffle-%d-%d-%d", + sliceId.getPipelineId(), sliceId.getEdgeId(), sliceId.getSliceIndex()); + this.fileName = store.getFilePath(fileName); + this.memoryTracker = memoryTracker; + } - private final String fileName; - private final StorageLevel storageLevel; - private final ShuffleStore store; - private final ShuffleMemoryTracker memoryTracker; + public String getFileName() { + return this.fileName; + } - private OutputStream outputStream; - private CloseableIterator streamBufferIterator; - private PipeBuffer value; - private volatile boolean ready2read = false; + ////////////////////////////// + // Produce data. - // Bytes count in memory. - private long memoryBytes = 0; - // Bytes count on disk. - private long diskBytes = 0; + /// /////////////////////////// - public SpillablePipelineSlice(String taskLogTag, SliceId sliceId) { - this(taskLogTag, sliceId, ShuffleManager.getInstance().getShuffleConfig(), - ShuffleManager.getInstance() - .getShuffleMemoryTracker()); + @Override + public boolean add(PipeBuffer buffer) { + if (this.isReleased || this.ready2read) { + throw new GeaflowRuntimeException( + "slice already released or mark finish: " + this.getSliceId()); } + totalBufferCount++; - public SpillablePipelineSlice(String taskLogTag, SliceId sliceId, - ShuffleConfig shuffleConfig, ShuffleMemoryTracker memoryTracker) { - super(taskLogTag, sliceId); - this.storageLevel = shuffleConfig.getStorageLevel(); - this.store = ShuffleStore.getShuffleStore(shuffleConfig); - String fileName = String.format("shuffle-%d-%d-%d", - sliceId.getPipelineId(), sliceId.getEdgeId(), sliceId.getSliceIndex()); - this.fileName = store.getFilePath(fileName); - this.memoryTracker = memoryTracker; + if (this.storageLevel == StorageLevel.MEMORY) { + this.writeMemory(buffer); + return true; } - public String getFileName() { - return this.fileName; + if (this.storageLevel == StorageLevel.DISK) { + this.writeStore(buffer); + return true; } - ////////////////////////////// - // Produce data. - - /// /////////////////////////// - - @Override - public boolean add(PipeBuffer buffer) { - if (this.isReleased || this.ready2read) { - throw new GeaflowRuntimeException("slice already released or mark finish: " + this.getSliceId()); - } - totalBufferCount++; + this.writeMemory(buffer); + if (!memoryTracker.checkMemoryEnough()) { + this.spillWrite(); + } + return true; + } + + private void writeMemory(PipeBuffer buffer) { + this.buffers.add(buffer); + this.memoryBytes += buffer.getBufferSize(); + } + + private void writeStore(PipeBuffer buffer) { + try { + if (this.outputStream == null) { + this.outputStream = store.getOutputStream(fileName); + } + this.write2Stream(buffer); + } catch (IOException e) { + throw new GeaflowRuntimeException(e); + } + } + + private void spillWrite() { + try { + if (this.outputStream == null) { + this.outputStream = store.getOutputStream(fileName); + } + while (!buffers.isEmpty()) { + PipeBuffer buffer = buffers.poll(); + write2Stream(buffer); + } + } catch (IOException e) { + throw new GeaflowRuntimeException(e); + } + } + + private void write2Stream(PipeBuffer buffer) throws IOException { + OutBuffer outBuffer = buffer.getBuffer(); + if (outBuffer == null) { + Encoders.INTEGER.encode(0, this.outputStream); + Encoders.LONG.encode(buffer.getBatchId(), this.outputStream); + Encoders.INTEGER.encode(buffer.getCount(), this.outputStream); + } else { + this.diskBytes += buffer.getBufferSize(); + Encoders.INTEGER.encode(outBuffer.getBufferSize(), this.outputStream); + Encoders.LONG.encode(buffer.getBatchId(), this.outputStream); + outBuffer.write(this.outputStream); + outBuffer.release(); + } + } + + @Override + public void flush() { + if (this.outputStream != null) { + try { + spillWrite(); + this.outputStream.flush(); + this.outputStream.close(); + this.outputStream = null; + } catch (IOException e) { + throw new GeaflowRuntimeException(e); + } + this.streamBufferIterator = new FileStreamIterator(); + } + this.ready2read = true; + LOGGER.info("write file {} {} {}", this.fileName, this.memoryBytes, this.diskBytes); + } - if (this.storageLevel == StorageLevel.MEMORY) { - this.writeMemory(buffer); - return true; - } + ////////////////////////////// + // Consume data. - if (this.storageLevel == StorageLevel.DISK) { - this.writeStore(buffer); - return true; - } + /// /////////////////////////// - this.writeMemory(buffer); - if (!memoryTracker.checkMemoryEnough()) { - this.spillWrite(); - } - return true; + @Override + public boolean hasNext() { + if (this.isReleased) { + return false; } - - private void writeMemory(PipeBuffer buffer) { - this.buffers.add(buffer); - this.memoryBytes += buffer.getBufferSize(); + if (this.value != null) { + return true; } - private void writeStore(PipeBuffer buffer) { - try { - if (this.outputStream == null) { - this.outputStream = store.getOutputStream(fileName); - } - this.write2Stream(buffer); - } catch (IOException e) { - throw new GeaflowRuntimeException(e); - } + if (!buffers.isEmpty()) { + this.value = buffers.poll(); + return true; } - private void spillWrite() { - try { - if (this.outputStream == null) { - this.outputStream = store.getOutputStream(fileName); - } - while (!buffers.isEmpty()) { - PipeBuffer buffer = buffers.poll(); - write2Stream(buffer); - } - } catch (IOException e) { - throw new GeaflowRuntimeException(e); - } + if (streamBufferIterator != null && streamBufferIterator.hasNext()) { + this.value = streamBufferIterator.next(); + return true; } - private void write2Stream(PipeBuffer buffer) throws IOException { - OutBuffer outBuffer = buffer.getBuffer(); - if (outBuffer == null) { - Encoders.INTEGER.encode(0, this.outputStream); - Encoders.LONG.encode(buffer.getBatchId(), this.outputStream); - Encoders.INTEGER.encode(buffer.getCount(), this.outputStream); - } else { - this.diskBytes += buffer.getBufferSize(); - Encoders.INTEGER.encode(outBuffer.getBufferSize(), this.outputStream); - Encoders.LONG.encode(buffer.getBatchId(), this.outputStream); - outBuffer.write(this.outputStream); - outBuffer.release(); - } + return false; + } + + @Override + public PipeBuffer next() { + PipeBuffer next = this.value; + this.value = null; + return next; + } + + @Override + public boolean isReady2read() { + return this.ready2read; + } + + @Override + public synchronized void release() { + if (this.isReleased) { + return; } - - @Override - public void flush() { - if (this.outputStream != null) { - try { - spillWrite(); - this.outputStream.flush(); - this.outputStream.close(); - this.outputStream = null; - } catch (IOException e) { - throw new GeaflowRuntimeException(e); - } - this.streamBufferIterator = new FileStreamIterator(); - } - this.ready2read = true; - LOGGER.info("write file {} {} {}", this.fileName, this.memoryBytes, this.diskBytes); + this.buffers.clear(); + try { + if (streamBufferIterator != null) { + streamBufferIterator.close(); + streamBufferIterator = null; + } + Path path = Paths.get(this.fileName); + Files.deleteIfExists(path); + } catch (IOException e) { + throw new GeaflowRuntimeException(e); } + this.isReleased = true; + } - ////////////////////////////// - // Consume data. + class FileStreamIterator implements CloseableIterator { + private InputStream inputStream; + private PipeBuffer next; - /// /////////////////////////// + FileStreamIterator() { + this.inputStream = store.getInputStream(fileName); + } @Override public boolean hasNext() { - if (this.isReleased) { - return false; - } - if (this.value != null) { - return true; - } - - if (!buffers.isEmpty()) { - this.value = buffers.poll(); + if (this.next != null) { + return true; + } + InputStream input = this.inputStream; + try { + if (input != null && input.available() > 0) { + int size = Encoders.INTEGER.decode(input); + long batchId = Encoders.LONG.decode(input); + if (size == 0) { + int count = Encoders.INTEGER.decode(input); + this.next = new PipeBuffer(batchId, count, true); return true; - } - - if (streamBufferIterator != null && streamBufferIterator.hasNext()) { - this.value = streamBufferIterator.next(); + } else { + byte[] bytes = new byte[size]; + int read = input.read(bytes); + if (read != bytes.length) { + String msg = + String.format("illegal read size, expect %d, actual %d", bytes.length, read); + throw new GeaflowRuntimeException(msg); + } + this.next = new PipeBuffer(bytes, batchId); return true; + } } + } catch (IOException e) { + throw new GeaflowRuntimeException(e.getMessage(), e); + } - return false; + return false; } @Override public PipeBuffer next() { - PipeBuffer next = this.value; - this.value = null; - return next; - } - - @Override - public boolean isReady2read() { - return this.ready2read; + PipeBuffer buffer = this.next; + this.next = null; + return buffer; } @Override - public synchronized void release() { - if (this.isReleased) { - return; - } - this.buffers.clear(); + public void close() { + if (inputStream != null) { try { - if (streamBufferIterator != null) { - streamBufferIterator.close(); - streamBufferIterator = null; - } - Path path = Paths.get(this.fileName); - Files.deleteIfExists(path); + inputStream.close(); + inputStream = null; } catch (IOException e) { - throw new GeaflowRuntimeException(e); + throw new GeaflowRuntimeException(e); } - this.isReleased = true; + } } - - - class FileStreamIterator implements CloseableIterator { - private InputStream inputStream; - private PipeBuffer next; - - FileStreamIterator() { - this.inputStream = store.getInputStream(fileName); - } - - @Override - public boolean hasNext() { - if (this.next != null) { - return true; - } - InputStream input = this.inputStream; - try { - if (input != null && input.available() > 0) { - int size = Encoders.INTEGER.decode(input); - long batchId = Encoders.LONG.decode(input); - if (size == 0) { - int count = Encoders.INTEGER.decode(input); - this.next = new PipeBuffer(batchId, count, true); - return true; - } else { - byte[] bytes = new byte[size]; - int read = input.read(bytes); - if (read != bytes.length) { - String msg = String.format("illegal read size, expect %d, actual %d", - bytes.length, read); - throw new GeaflowRuntimeException(msg); - } - this.next = new PipeBuffer(bytes, batchId); - return true; - } - } - } catch (IOException e) { - throw new GeaflowRuntimeException(e.getMessage(), e); - } - - return false; - } - - @Override - public PipeBuffer next() { - PipeBuffer buffer = this.next; - this.next = null; - return buffer; - } - - @Override - public void close() { - if (inputStream != null) { - try { - inputStream.close(); - inputStream = null; - } catch (IOException e) { - throw new GeaflowRuntimeException(e); - } - } - } - } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/AbstractMessageIterator.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/AbstractMessageIterator.java index e10e40b24..41fab5a4b 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/AbstractMessageIterator.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/AbstractMessageIterator.java @@ -20,58 +20,58 @@ package org.apache.geaflow.shuffle.serialize; import java.io.InputStream; + import org.apache.commons.io.IOUtils; import org.apache.geaflow.shuffle.pipeline.buffer.OutBuffer; public abstract class AbstractMessageIterator implements IMessageIterator { - private long recordNum; - protected T currentValue; + private long recordNum; + protected T currentValue; - protected OutBuffer outBuffer; - protected InputStream inputStream; + protected OutBuffer outBuffer; + protected InputStream inputStream; - public AbstractMessageIterator(OutBuffer outBuffer) { - this.outBuffer = outBuffer; - this.inputStream = outBuffer.getInputStream(); - } + public AbstractMessageIterator(OutBuffer outBuffer) { + this.outBuffer = outBuffer; + this.inputStream = outBuffer.getInputStream(); + } - public AbstractMessageIterator(InputStream inputStream) { - this.inputStream = inputStream; - } + public AbstractMessageIterator(InputStream inputStream) { + this.inputStream = inputStream; + } - public OutBuffer getOutBuffer() { - return this.outBuffer; - } + public OutBuffer getOutBuffer() { + return this.outBuffer; + } - /** - * Returns the next element in the iteration. - * - * @return the next element. - */ - @Override - public T next() { - this.recordNum++; - T result = this.currentValue; - this.currentValue = null; - return result; - } + /** + * Returns the next element in the iteration. + * + * @return the next element. + */ + @Override + public T next() { + this.recordNum++; + T result = this.currentValue; + this.currentValue = null; + return result; + } - @Override - public long getSize() { - return this.recordNum; - } + @Override + public long getSize() { + return this.recordNum; + } - @Override - public void close() { - if (inputStream != null) { - IOUtils.closeQuietly(inputStream); - inputStream = null; - } - if (outBuffer != null) { - outBuffer.release(); - outBuffer = null; - } + @Override + public void close() { + if (inputStream != null) { + IOUtils.closeQuietly(inputStream); + inputStream = null; } - + if (outBuffer != null) { + outBuffer.release(); + outBuffer = null; + } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/AbstractRecordSerializer.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/AbstractRecordSerializer.java index 80c81a447..972e6a161 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/AbstractRecordSerializer.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/AbstractRecordSerializer.java @@ -23,12 +23,11 @@ public abstract class AbstractRecordSerializer implements IRecordSerializer { - @Override - public void serialize(T record, boolean isRetract, OutBuffer.BufferBuilder builder) { - this.doSerialize(record, isRetract, builder); - builder.increaseRecordCount(); - } - - public abstract void doSerialize(T record, boolean isRetract, OutBuffer.BufferBuilder builder); + @Override + public void serialize(T record, boolean isRetract, OutBuffer.BufferBuilder builder) { + this.doSerialize(record, isRetract, builder); + builder.increaseRecordCount(); + } + public abstract void doSerialize(T record, boolean isRetract, OutBuffer.BufferBuilder builder); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/EncoderMessageIterator.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/EncoderMessageIterator.java index eae1cdd2f..4c39b7707 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/EncoderMessageIterator.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/EncoderMessageIterator.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.io.InputStream; + import org.apache.geaflow.common.encoder.IEncoder; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -30,36 +31,36 @@ public class EncoderMessageIterator extends AbstractMessageIterator { - private static final Logger LOGGER = LoggerFactory.getLogger(EncoderMessageIterator.class); + private static final Logger LOGGER = LoggerFactory.getLogger(EncoderMessageIterator.class); - private final IEncoder encoder; + private final IEncoder encoder; - public EncoderMessageIterator(OutBuffer outBuffer, IEncoder encoder) { - super(outBuffer); - this.encoder = encoder; - } + public EncoderMessageIterator(OutBuffer outBuffer, IEncoder encoder) { + super(outBuffer); + this.encoder = encoder; + } - public EncoderMessageIterator(InputStream inputStream, IEncoder encoder) { - super(inputStream); - this.encoder = encoder; - } + public EncoderMessageIterator(InputStream inputStream, IEncoder encoder) { + super(inputStream); + this.encoder = encoder; + } - @Override - public boolean hasNext() { - if (currentValue != null) { - return true; - } - try { - if (this.inputStream.available() > 0) { - this.currentValue = this.encoder.decode(this.inputStream); - return true; - } else { - return false; - } - } catch (IOException e) { - LOGGER.error("encoder deserialize err", e); - throw new GeaflowRuntimeException(RuntimeErrors.INST.shuffleDeserializeError(e.getMessage()), e); - } + @Override + public boolean hasNext() { + if (currentValue != null) { + return true; } - + try { + if (this.inputStream.available() > 0) { + this.currentValue = this.encoder.decode(this.inputStream); + return true; + } else { + return false; + } + } catch (IOException e) { + LOGGER.error("encoder deserialize err", e); + throw new GeaflowRuntimeException( + RuntimeErrors.INST.shuffleDeserializeError(e.getMessage()), e); + } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/EncoderRecordSerializer.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/EncoderRecordSerializer.java index 17f0e9ed2..7178cb3ba 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/EncoderRecordSerializer.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/EncoderRecordSerializer.java @@ -20,6 +20,7 @@ package org.apache.geaflow.shuffle.serialize; import java.io.IOException; + import org.apache.geaflow.common.encoder.IEncoder; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -27,19 +28,19 @@ public class EncoderRecordSerializer extends AbstractRecordSerializer { - private final IEncoder encoder; + private final IEncoder encoder; - public EncoderRecordSerializer(IEncoder encoder) { - this.encoder = encoder; - } + public EncoderRecordSerializer(IEncoder encoder) { + this.encoder = encoder; + } - @Override - public void doSerialize(T value, boolean isRetract, OutBuffer.BufferBuilder outBuffer) { - try { - this.encoder.encode(value, outBuffer.getOutputStream()); - } catch (IOException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.shuffleSerializeError(e.getMessage()), e); - } + @Override + public void doSerialize(T value, boolean isRetract, OutBuffer.BufferBuilder outBuffer) { + try { + this.encoder.encode(value, outBuffer.getOutputStream()); + } catch (IOException e) { + throw new GeaflowRuntimeException( + RuntimeErrors.INST.shuffleSerializeError(e.getMessage()), e); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/IMessageIterator.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/IMessageIterator.java index 307a76e63..f86b6d1bf 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/IMessageIterator.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/IMessageIterator.java @@ -23,16 +23,13 @@ public interface IMessageIterator extends Iterator { - /** - * Get total record accessed. - * - * @return total record number. - */ - long getSize(); + /** + * Get total record accessed. + * + * @return total record number. + */ + long getSize(); - /** - * Close this iterator. - */ - void close(); + /** Close this iterator. */ + void close(); } - diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/IRecordSerializer.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/IRecordSerializer.java index 6617d2fa1..cdc78067e 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/IRecordSerializer.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/IRecordSerializer.java @@ -23,13 +23,12 @@ public interface IRecordSerializer { - /** - * Serialize data to out buffer. - * - * @param record data - * @param isRetract if data is retract - * @param builder buffer - */ - void serialize(T record, boolean isRetract, OutBuffer.BufferBuilder builder); - + /** + * Serialize data to out buffer. + * + * @param record data + * @param isRetract if data is retract + * @param builder buffer + */ + void serialize(T record, boolean isRetract, OutBuffer.BufferBuilder builder); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/MessageIterator.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/MessageIterator.java index 8494666fe..817f65d5a 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/MessageIterator.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/MessageIterator.java @@ -19,49 +19,50 @@ package org.apache.geaflow.shuffle.serialize; -import com.esotericsoftware.kryo.KryoException; -import com.esotericsoftware.kryo.io.Input; import java.io.InputStream; import java.util.Locale; + import org.apache.geaflow.common.serialize.SerializerFactory; import org.apache.geaflow.common.serialize.impl.KryoSerializer; import org.apache.geaflow.shuffle.pipeline.buffer.OutBuffer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.esotericsoftware.kryo.KryoException; +import com.esotericsoftware.kryo.io.Input; + public class MessageIterator extends AbstractMessageIterator { - private static final Logger LOGGER = LoggerFactory.getLogger(MessageIterator.class); + private static final Logger LOGGER = LoggerFactory.getLogger(MessageIterator.class); - private final KryoSerializer kryoSerializer; - private final Input input; + private final KryoSerializer kryoSerializer; + private final Input input; - public MessageIterator(OutBuffer outBuffer) { - super(outBuffer); - this.kryoSerializer = ((KryoSerializer) SerializerFactory.getKryoSerializer()); - this.input = new Input(this.inputStream); - } + public MessageIterator(OutBuffer outBuffer) { + super(outBuffer); + this.kryoSerializer = ((KryoSerializer) SerializerFactory.getKryoSerializer()); + this.input = new Input(this.inputStream); + } - public MessageIterator(InputStream inputStream) { - super(inputStream); - this.kryoSerializer = ((KryoSerializer) SerializerFactory.getKryoSerializer()); - this.input = new Input(inputStream); - } + public MessageIterator(InputStream inputStream) { + super(inputStream); + this.kryoSerializer = ((KryoSerializer) SerializerFactory.getKryoSerializer()); + this.input = new Input(inputStream); + } - public boolean hasNext() { - if (currentValue != null) { - return true; - } - try { - currentValue = (T) kryoSerializer.getThreadKryo().readClassAndObject(input); - return true; - } catch (KryoException e) { - if (e.getMessage().toLowerCase(Locale.ROOT).contains("buffer underflow")) { - currentValue = null; - return false; - } - LOGGER.error("deserialize failed", e); - throw e; - } + public boolean hasNext() { + if (currentValue != null) { + return true; } - + try { + currentValue = (T) kryoSerializer.getThreadKryo().readClassAndObject(input); + return true; + } catch (KryoException e) { + if (e.getMessage().toLowerCase(Locale.ROOT).contains("buffer underflow")) { + currentValue = null; + return false; + } + LOGGER.error("deserialize failed", e); + throw e; + } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/RecordSerializer.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/RecordSerializer.java index 1d20b5488..3e5b819e8 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/RecordSerializer.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/serialize/RecordSerializer.java @@ -19,33 +19,33 @@ package org.apache.geaflow.shuffle.serialize; -import com.esotericsoftware.kryo.io.Output; import org.apache.geaflow.common.serialize.SerializerFactory; import org.apache.geaflow.common.serialize.impl.KryoSerializer; import org.apache.geaflow.shuffle.pipeline.buffer.OutBuffer.BufferBuilder; -public class RecordSerializer extends AbstractRecordSerializer { - - private static final int DEFAULT_BUFFER_SIZE = 4096; - - private final Output output; - private final KryoSerializer kryoSerializer; +import com.esotericsoftware.kryo.io.Output; - public RecordSerializer() { - this.output = new Output(DEFAULT_BUFFER_SIZE); - this.kryoSerializer = ((KryoSerializer) SerializerFactory.getKryoSerializer()); - } +public class RecordSerializer extends AbstractRecordSerializer { - @Override - public void doSerialize(T value, boolean isRetract, BufferBuilder outBuffer) { - Output output = this.output; - output.setOutputStream(outBuffer.getOutputStream()); - try { - this.kryoSerializer.getThreadKryo().writeClassAndObject(output, value); - output.flush(); - } finally { - output.clear(); - } + private static final int DEFAULT_BUFFER_SIZE = 4096; + + private final Output output; + private final KryoSerializer kryoSerializer; + + public RecordSerializer() { + this.output = new Output(DEFAULT_BUFFER_SIZE); + this.kryoSerializer = ((KryoSerializer) SerializerFactory.getKryoSerializer()); + } + + @Override + public void doSerialize(T value, boolean isRetract, BufferBuilder outBuffer) { + Output output = this.output; + output.setOutputStream(outBuffer.getOutputStream()); + try { + this.kryoSerializer.getThreadKryo().writeClassAndObject(output, value); + output.flush(); + } finally { + output.clear(); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/service/IShuffleService.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/service/IShuffleService.java index 5bff66702..b28e9f74b 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/service/IShuffleService.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/service/IShuffleService.java @@ -20,6 +20,7 @@ package org.apache.geaflow.shuffle.service; import java.io.Serializable; + import org.apache.geaflow.shuffle.api.reader.IShuffleReader; import org.apache.geaflow.shuffle.api.writer.IShuffleWriter; import org.apache.geaflow.shuffle.message.PipelineInfo; @@ -28,29 +29,27 @@ public interface IShuffleService extends Serializable { - /** - * Init the shuffle service with job config. - * - * @param connectionManager connection manager. - */ - void init(IConnectionManager connectionManager); - - /** - * get shuffle writer per reducer task. - * - * @return shuffle reader. - */ - IShuffleReader getReader(); - - /** - * get shuffle reader per mapper task. - * - * @return shuffle writer. - */ - IShuffleWriter getWriter(); - - /** - * Release the local resources of this job id. - */ - void clean(PipelineInfo jobInfo); + /** + * Init the shuffle service with job config. + * + * @param connectionManager connection manager. + */ + void init(IConnectionManager connectionManager); + + /** + * get shuffle writer per reducer task. + * + * @return shuffle reader. + */ + IShuffleReader getReader(); + + /** + * get shuffle reader per mapper task. + * + * @return shuffle writer. + */ + IShuffleWriter getWriter(); + + /** Release the local resources of this job id. */ + void clean(PipelineInfo jobInfo); } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/service/NettyShuffleService.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/service/NettyShuffleService.java index 6f5763d52..e2bc11e2c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/service/NettyShuffleService.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/service/NettyShuffleService.java @@ -31,28 +31,28 @@ public class NettyShuffleService implements IShuffleService { - private static final Logger LOGGER = LoggerFactory.getLogger(NettyShuffleService.class); - - private IConnectionManager connectionManager; - - @Override - public void init(IConnectionManager connectionManager) { - this.connectionManager = connectionManager; - } - - @Override - public IShuffleReader getReader() { - return new PipelineReader(this.connectionManager); - } - - @Override - public IShuffleWriter getWriter() { - return new PipelineWriter<>(this.connectionManager); - } - - @Override - public void clean(PipelineInfo jobInfo) { - LOGGER.info("release shuffle data of job {}", jobInfo); - ShuffleManager.getInstance().release(jobInfo.getPipelineId()); - } + private static final Logger LOGGER = LoggerFactory.getLogger(NettyShuffleService.class); + + private IConnectionManager connectionManager; + + @Override + public void init(IConnectionManager connectionManager) { + this.connectionManager = connectionManager; + } + + @Override + public IShuffleReader getReader() { + return new PipelineReader(this.connectionManager); + } + + @Override + public IShuffleWriter getWriter() { + return new PipelineWriter<>(this.connectionManager); + } + + @Override + public void clean(PipelineInfo jobInfo) { + LOGGER.info("release shuffle data of job {}", jobInfo); + ShuffleManager.getInstance().release(jobInfo.getPipelineId()); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/service/ShuffleManager.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/service/ShuffleManager.java index 606755744..3b46e1345 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/service/ShuffleManager.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/service/ShuffleManager.java @@ -20,6 +20,7 @@ package org.apache.geaflow.shuffle.service; import java.io.IOException; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.shuffle.api.reader.IShuffleReader; import org.apache.geaflow.shuffle.api.writer.IShuffleWriter; @@ -34,75 +35,75 @@ public class ShuffleManager { - private static final Logger LOGGER = LoggerFactory.getLogger(ShuffleManager.class); - - private static ShuffleManager INSTANCE; - private final IShuffleService shuffleService; - private final ConnectionManager connectionManager; - private final SliceManager sliceManager; - private final ShuffleConfig shuffleConfig; - private final ShuffleMemoryTracker shuffleMemoryTracker; - - public ShuffleManager(Configuration config) { - this.shuffleConfig = new ShuffleConfig(config); - this.connectionManager = new ConnectionManager(shuffleConfig); - this.shuffleService = new NettyShuffleService(); - this.shuffleService.init(connectionManager); - this.sliceManager = new SliceManager(); - this.shuffleMemoryTracker = new ShuffleMemoryTracker(config); - } - - public static synchronized ShuffleManager init(Configuration config) { - if (INSTANCE == null) { - INSTANCE = new ShuffleManager(config); - } - return INSTANCE; - } - - public static ShuffleManager getInstance() { - return INSTANCE; - } - - public IConnectionManager getConnectionManager() { - return connectionManager; - } - - public SliceManager getSliceManager() { - return sliceManager; - } - - public ShuffleConfig getShuffleConfig() { - return shuffleConfig; + private static final Logger LOGGER = LoggerFactory.getLogger(ShuffleManager.class); + + private static ShuffleManager INSTANCE; + private final IShuffleService shuffleService; + private final ConnectionManager connectionManager; + private final SliceManager sliceManager; + private final ShuffleConfig shuffleConfig; + private final ShuffleMemoryTracker shuffleMemoryTracker; + + public ShuffleManager(Configuration config) { + this.shuffleConfig = new ShuffleConfig(config); + this.connectionManager = new ConnectionManager(shuffleConfig); + this.shuffleService = new NettyShuffleService(); + this.shuffleService.init(connectionManager); + this.sliceManager = new SliceManager(); + this.shuffleMemoryTracker = new ShuffleMemoryTracker(config); + } + + public static synchronized ShuffleManager init(Configuration config) { + if (INSTANCE == null) { + INSTANCE = new ShuffleManager(config); } - - public ShuffleMemoryTracker getShuffleMemoryTracker() { - return shuffleMemoryTracker; - } - - public int getShufflePort() { - return connectionManager.getShuffleAddress().port(); - } - - public IShuffleReader loadShuffleReader() { - return shuffleService.getReader(); - } - - public IShuffleWriter loadShuffleWriter() { - return shuffleService.getWriter(); - } - - public void release(long pipelineId) { - sliceManager.release(pipelineId); - } - - public synchronized void close() { - LOGGER.info("closing shuffle manager"); - try { - connectionManager.close(); - shuffleMemoryTracker.release(); - INSTANCE = null; - } catch (IOException e) { - LOGGER.warn("close connectManager failed:{}", e.getCause(), e); - } + return INSTANCE; + } + + public static ShuffleManager getInstance() { + return INSTANCE; + } + + public IConnectionManager getConnectionManager() { + return connectionManager; + } + + public SliceManager getSliceManager() { + return sliceManager; + } + + public ShuffleConfig getShuffleConfig() { + return shuffleConfig; + } + + public ShuffleMemoryTracker getShuffleMemoryTracker() { + return shuffleMemoryTracker; + } + + public int getShufflePort() { + return connectionManager.getShuffleAddress().port(); + } + + public IShuffleReader loadShuffleReader() { + return shuffleService.getReader(); + } + + public IShuffleWriter loadShuffleWriter() { + return shuffleService.getWriter(); + } + + public void release(long pipelineId) { + sliceManager.release(pipelineId); + } + + public synchronized void close() { + LOGGER.info("closing shuffle manager"); + try { + connectionManager.close(); + shuffleMemoryTracker.release(); + INSTANCE = null; + } catch (IOException e) { + LOGGER.warn("close connectManager failed:{}", e.getCause(), e); } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/storage/LocalShuffleStore.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/storage/LocalShuffleStore.java index 58faf80cf..84a56e2e8 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/storage/LocalShuffleStore.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/storage/LocalShuffleStore.java @@ -30,54 +30,57 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.nio.file.StandardOpenOption; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.shuffle.config.ShuffleConfig; public class LocalShuffleStore implements ShuffleStore { - private static final int BUFFER_SIZE = 64 * 1024; - private static final String DEFAULT_LOCAL_ROOT = "/shuffle"; + private static final int BUFFER_SIZE = 64 * 1024; + private static final String DEFAULT_LOCAL_ROOT = "/shuffle"; - private final String shufflePath; + private final String shufflePath; - public LocalShuffleStore(ShuffleConfig shuffleConfig) { - Configuration configuration = shuffleConfig.getConfig(); - String workPath = configuration.getString(JOB_WORK_PATH); - Path path = Paths.get(workPath, DEFAULT_LOCAL_ROOT); - if (!Files.exists(path)) { - try { - Files.createDirectories(path); - } catch (IOException e) { - throw new GeaflowRuntimeException(e); - } - } - this.shufflePath = path.toString(); + public LocalShuffleStore(ShuffleConfig shuffleConfig) { + Configuration configuration = shuffleConfig.getConfig(); + String workPath = configuration.getString(JOB_WORK_PATH); + Path path = Paths.get(workPath, DEFAULT_LOCAL_ROOT); + if (!Files.exists(path)) { + try { + Files.createDirectories(path); + } catch (IOException e) { + throw new GeaflowRuntimeException(e); + } } + this.shufflePath = path.toString(); + } - @Override - public String getFilePath(String fileName) { - return Paths.get(shufflePath, fileName).toString(); - } + @Override + public String getFilePath(String fileName) { + return Paths.get(shufflePath, fileName).toString(); + } - @Override - public InputStream getInputStream(String filePath) { - try { - Path path = Paths.get(filePath); - return new BufferedInputStream(Files.newInputStream(path, StandardOpenOption.READ), BUFFER_SIZE); - } catch (IOException e) { - throw new GeaflowRuntimeException(e); - } + @Override + public InputStream getInputStream(String filePath) { + try { + Path path = Paths.get(filePath); + return new BufferedInputStream( + Files.newInputStream(path, StandardOpenOption.READ), BUFFER_SIZE); + } catch (IOException e) { + throw new GeaflowRuntimeException(e); } + } - @Override - public OutputStream getOutputStream(String filePath) { - try { - Path path = Paths.get(filePath); - Files.deleteIfExists(path); - Files.createFile(path); - return new BufferedOutputStream(Files.newOutputStream(path, StandardOpenOption.WRITE), BUFFER_SIZE); - } catch (IOException e) { - throw new GeaflowRuntimeException(e); - } + @Override + public OutputStream getOutputStream(String filePath) { + try { + Path path = Paths.get(filePath); + Files.deleteIfExists(path); + Files.createFile(path); + return new BufferedOutputStream( + Files.newOutputStream(path, StandardOpenOption.WRITE), BUFFER_SIZE); + } catch (IOException e) { + throw new GeaflowRuntimeException(e); } + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/storage/ShuffleStore.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/storage/ShuffleStore.java index ec06bc78d..6a1a1045e 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/storage/ShuffleStore.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/storage/ShuffleStore.java @@ -21,43 +21,43 @@ import java.io.InputStream; import java.io.OutputStream; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.shuffle.StorageLevel; import org.apache.geaflow.shuffle.config.ShuffleConfig; public interface ShuffleStore { - /** - * Get file path. - * - * @param fileName file name - * @return file path. - */ - String getFilePath(String fileName); - - /** - * Get input stream by filePath. - * - * @param path file path. - * @return file input stream. - */ - InputStream getInputStream(String path); - - /** - * Get output stream by filePath. - * - * @param path file path - * @return file output stream. - */ - OutputStream getOutputStream(String path); - - static ShuffleStore getShuffleStore(ShuffleConfig shuffleConfig) { - StorageLevel storageLevel = shuffleConfig.getStorageLevel(); - if (storageLevel == StorageLevel.DISK || storageLevel == StorageLevel.MEMORY_AND_DISK) { - return new LocalShuffleStore(shuffleConfig); - } else { - throw new GeaflowRuntimeException("unsupported shuffle level: " + storageLevel); - } + /** + * Get file path. + * + * @param fileName file name + * @return file path. + */ + String getFilePath(String fileName); + + /** + * Get input stream by filePath. + * + * @param path file path. + * @return file input stream. + */ + InputStream getInputStream(String path); + + /** + * Get output stream by filePath. + * + * @param path file path + * @return file output stream. + */ + OutputStream getOutputStream(String path); + + static ShuffleStore getShuffleStore(ShuffleConfig shuffleConfig) { + StorageLevel storageLevel = shuffleConfig.getStorageLevel(); + if (storageLevel == StorageLevel.DISK || storageLevel == StorageLevel.MEMORY_AND_DISK) { + return new LocalShuffleStore(shuffleConfig); + } else { + throw new GeaflowRuntimeException("unsupported shuffle level: " + storageLevel); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/util/AtomicReferenceCounter.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/util/AtomicReferenceCounter.java index fdb7a2915..e955aaac4 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/util/AtomicReferenceCounter.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/util/AtomicReferenceCounter.java @@ -19,87 +19,82 @@ package org.apache.geaflow.shuffle.util; -/** - * refer to the implementation of flink. - */ +/** refer to the implementation of flink. */ public class AtomicReferenceCounter { - private final Object lock = new Object(); - private int referenceCount; - private boolean isDisposed; - // Enter the disposed state when the reference count reaches this number. - private final int disposeOnReferenceCount; + private final Object lock = new Object(); + private int referenceCount; + private boolean isDisposed; + // Enter the disposed state when the reference count reaches this number. + private final int disposeOnReferenceCount; - public AtomicReferenceCounter() { - this.disposeOnReferenceCount = 0; - } + public AtomicReferenceCounter() { + this.disposeOnReferenceCount = 0; + } - public AtomicReferenceCounter(int disposeOnReferenceCount) { - this.disposeOnReferenceCount = disposeOnReferenceCount; - } + public AtomicReferenceCounter(int disposeOnReferenceCount) { + this.disposeOnReferenceCount = disposeOnReferenceCount; + } - /** - * Increments the reference count and returns whether it was successful. - * - *

If the method returns false, the counter has already been disposed. - * Otherwise, it returns true. - */ - public boolean increment() { - synchronized (lock) { - if (isDisposed) { - return false; - } + /** + * Increments the reference count and returns whether it was successful. + * + *

If the method returns false, the counter has already been disposed. Otherwise, + * it returns true. + */ + public boolean increment() { + synchronized (lock) { + if (isDisposed) { + return false; + } - referenceCount++; - return true; - } + referenceCount++; + return true; } + } - /** - * Decrements the reference count and returns whether the reference counter entered the disposed - * state. - * - *

If the method returns true, the decrement operation disposed the counter. - * Otherwise, it returns false. - */ - public boolean decrement() { - synchronized (lock) { - if (isDisposed) { - return false; - } + /** + * Decrements the reference count and returns whether the reference counter entered the disposed + * state. + * + *

If the method returns true, the decrement operation disposed the counter. + * Otherwise, it returns false. + */ + public boolean decrement() { + synchronized (lock) { + if (isDisposed) { + return false; + } - referenceCount--; - if (referenceCount <= disposeOnReferenceCount) { - isDisposed = true; - } + referenceCount--; + if (referenceCount <= disposeOnReferenceCount) { + isDisposed = true; + } - return isDisposed; - } + return isDisposed; } + } - public int get() { - synchronized (lock) { - return referenceCount; - } + public int get() { + synchronized (lock) { + return referenceCount; } + } - /** - * Returns whether the reference count has reached the disposed state. - */ - public boolean isDisposed() { - synchronized (lock) { - return isDisposed; - } + /** Returns whether the reference count has reached the disposed state. */ + public boolean isDisposed() { + synchronized (lock) { + return isDisposed; } + } - public boolean disposeIfNotUsed() { - synchronized (lock) { - if (referenceCount <= disposeOnReferenceCount) { - isDisposed = true; - } + public boolean disposeIfNotUsed() { + synchronized (lock) { + if (referenceCount <= disposeOnReferenceCount) { + isDisposed = true; + } - return isDisposed; - } + return isDisposed; } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/util/SliceNotFoundException.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/util/SliceNotFoundException.java index f2dbd3185..f705575ac 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/util/SliceNotFoundException.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/util/SliceNotFoundException.java @@ -20,11 +20,12 @@ package org.apache.geaflow.shuffle.util; import java.io.IOException; + import org.apache.geaflow.shuffle.message.SliceId; public class SliceNotFoundException extends IOException { - public SliceNotFoundException(SliceId sliceId) { - super("Slice not found:" + sliceId); - } + public SliceNotFoundException(SliceId sliceId) { + super("Slice not found:" + sliceId); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/util/TransportException.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/util/TransportException.java index 3a9301042..770fb83af 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/util/TransportException.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/org/apache/geaflow/shuffle/util/TransportException.java @@ -27,24 +27,24 @@ * additional information regarding copyright ownership. */ /** - * This class is an adaptation of Flink's org.apache.flink.runtime.io.network.netty.exception.TransportException. + * This class is an adaptation of Flink's + * org.apache.flink.runtime.io.network.netty.exception.TransportException. */ public class TransportException extends IOException { - private static final long serialVersionUID = 3637820720589866570L; - private final SocketAddress address; + private static final long serialVersionUID = 3637820720589866570L; + private final SocketAddress address; - public TransportException(String message, SocketAddress address) { - this(message, address, null); - } + public TransportException(String message, SocketAddress address) { + this(message, address, null); + } - public TransportException(String message, SocketAddress address, Throwable cause) { - super(message, cause); - this.address = address; - } - - public SocketAddress getAddress() { - return address; - } + public TransportException(String message, SocketAddress address, Throwable cause) { + super(message, cause); + this.address = address; + } + public SocketAddress getAddress() { + return address; + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/api/writer/PipelineShardWriterTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/api/writer/PipelineShardWriterTest.java index e0e2449cb..4e43fcd28 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/api/writer/PipelineShardWriterTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/api/writer/PipelineShardWriterTest.java @@ -20,6 +20,7 @@ package org.apache.geaflow.shuffle.api.writer; import java.io.IOException; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.shuffle.config.ShuffleConfig; @@ -39,87 +40,87 @@ public class PipelineShardWriterTest { - private boolean enableBackPressure; - - public PipelineShardWriterTest(boolean enableBackPressure) { - this.enableBackPressure = enableBackPressure; - } - - @Test - public void testWrite() throws IOException, InterruptedException { - Configuration configuration = new Configuration(); - configuration.put(ExecutionConfigKeys.JOB_APP_NAME, "default"); - configuration.put(ExecutionConfigKeys.CONTAINER_HEAP_SIZE_MB, String.valueOf(1024)); - configuration.put(ExecutionConfigKeys.SHUFFLE_BACKPRESSURE_ENABLE, - String.valueOf(enableBackPressure)); - configuration.put(ExecutionConfigKeys.SHUFFLE_FETCH_CHANNEL_QUEUE_SIZE, "1"); - configuration.put(ExecutionConfigKeys.SHUFFLE_WRITER_BUFFER_SIZE, "20"); - configuration.put(ExecutionConfigKeys.SHUFFLE_FLUSH_BUFFER_SIZE_BYTES, "10"); - - int pipelineId = 300; - SliceManager sliceManager = new SliceManager(); - ShuffleConfig shuffleConfig = new ShuffleConfig(configuration); - ShuffleMemoryTracker memoryTracker = new ShuffleMemoryTracker(configuration); - - try (MockedStatic ms = Mockito.mockStatic(ShuffleManager.class)) { - ShuffleManager shuffleManager = Mockito.mock(ShuffleManager.class); - ms.when(() -> ShuffleManager.getInstance()).then(invocation -> shuffleManager); - - Mockito.doReturn(sliceManager).when(shuffleManager).getSliceManager(); - Mockito.doReturn(shuffleConfig).when(shuffleManager).getShuffleConfig(); - Mockito.doReturn(memoryTracker).when(shuffleManager).getShuffleMemoryTracker(); - - WriterContext writerContext = new WriterContext(pipelineId, "write-test"); - writerContext.setConfig(shuffleConfig); - writerContext.setChannelNum(1); - writerContext.setEdgeId(0); - - PipelineShardWriter shardWriter = new PipelineShardWriter(); - shardWriter.init(writerContext); - - new Thread(() -> { + private boolean enableBackPressure; + + public PipelineShardWriterTest(boolean enableBackPressure) { + this.enableBackPressure = enableBackPressure; + } + + @Test + public void testWrite() throws IOException, InterruptedException { + Configuration configuration = new Configuration(); + configuration.put(ExecutionConfigKeys.JOB_APP_NAME, "default"); + configuration.put(ExecutionConfigKeys.CONTAINER_HEAP_SIZE_MB, String.valueOf(1024)); + configuration.put( + ExecutionConfigKeys.SHUFFLE_BACKPRESSURE_ENABLE, String.valueOf(enableBackPressure)); + configuration.put(ExecutionConfigKeys.SHUFFLE_FETCH_CHANNEL_QUEUE_SIZE, "1"); + configuration.put(ExecutionConfigKeys.SHUFFLE_WRITER_BUFFER_SIZE, "20"); + configuration.put(ExecutionConfigKeys.SHUFFLE_FLUSH_BUFFER_SIZE_BYTES, "10"); + + int pipelineId = 300; + SliceManager sliceManager = new SliceManager(); + ShuffleConfig shuffleConfig = new ShuffleConfig(configuration); + ShuffleMemoryTracker memoryTracker = new ShuffleMemoryTracker(configuration); + + try (MockedStatic ms = Mockito.mockStatic(ShuffleManager.class)) { + ShuffleManager shuffleManager = Mockito.mock(ShuffleManager.class); + ms.when(() -> ShuffleManager.getInstance()).then(invocation -> shuffleManager); + + Mockito.doReturn(sliceManager).when(shuffleManager).getSliceManager(); + Mockito.doReturn(shuffleConfig).when(shuffleManager).getShuffleConfig(); + Mockito.doReturn(memoryTracker).when(shuffleManager).getShuffleMemoryTracker(); + + WriterContext writerContext = new WriterContext(pipelineId, "write-test"); + writerContext.setConfig(shuffleConfig); + writerContext.setChannelNum(1); + writerContext.setEdgeId(0); + + PipelineShardWriter shardWriter = new PipelineShardWriter(); + shardWriter.init(writerContext); + + new Thread( + () -> { IPipelineSlice slice = sliceManager.getSlice(new SliceId(pipelineId, 0, 0, 0)); Assert.assertNotNull(slice); if (enableBackPressure) { - Assert.assertTrue(slice instanceof BlockingSlice); + Assert.assertTrue(slice instanceof BlockingSlice); } else { - Assert.assertTrue(slice instanceof PipelineSlice); + Assert.assertTrue(slice instanceof PipelineSlice); } int count = 0; while (true) { - PipeBuffer buffer = slice.next(); - if (buffer != null) { - if (!buffer.isData()) { - break; - } - count++; + PipeBuffer buffer = slice.next(); + if (buffer != null) { + if (!buffer.isData()) { + break; } + count++; + } } Assert.assertEquals(count, 10000); - }).start(); + }) + .start(); - int[] channels = new int[]{0}; - for (int i = 0; i < 10000; i++) { - shardWriter.emit(0, "helloWorld", false, channels); - } - shardWriter.finish(0); + int[] channels = new int[] {0}; + for (int i = 0; i < 10000; i++) { + shardWriter.emit(0, "helloWorld", false, channels); + } + shardWriter.finish(0); - ShuffleManager.getInstance().release(pipelineId); - Mockito.reset(shuffleManager); - } + ShuffleManager.getInstance().release(pipelineId); + Mockito.reset(shuffleManager); } + } - public static class SimpleTestFactory { + public static class SimpleTestFactory { - @Factory - public Object[] factoryMethod() { - return new Object[]{ - new PipelineShardWriterTest(false), - new PipelineShardWriterTest(true), - }; - } + @Factory + public Object[] factoryMethod() { + return new Object[] { + new PipelineShardWriterTest(false), new PipelineShardWriterTest(true), + }; } - -} \ No newline at end of file + } +} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/api/writer/SpillableShardWriterTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/api/writer/SpillableShardWriterTest.java index d158c11d7..00d81481c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/api/writer/SpillableShardWriterTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/api/writer/SpillableShardWriterTest.java @@ -24,6 +24,7 @@ import java.io.IOException; import java.util.Optional; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.shuffle.ShuffleAddress; import org.apache.geaflow.shuffle.config.ShuffleConfig; @@ -39,50 +40,48 @@ public class SpillableShardWriterTest { - @Test - public void testEmit() throws IOException { - IConnectionManager connectionManager = Mockito.mock(IConnectionManager.class); - Mockito.when(connectionManager.getShuffleAddress()).thenReturn(new ShuffleAddress( - "localhost", 1)); + @Test + public void testEmit() throws IOException { + IConnectionManager connectionManager = Mockito.mock(IConnectionManager.class); + Mockito.when(connectionManager.getShuffleAddress()) + .thenReturn(new ShuffleAddress("localhost", 1)); - Configuration config = new Configuration(); - config.put(SHUFFLE_SPILL_RECORDS, "50"); - config.put(CONTAINER_HEAP_SIZE_MB, "1"); + Configuration config = new Configuration(); + config.put(SHUFFLE_SPILL_RECORDS, "50"); + config.put(CONTAINER_HEAP_SIZE_MB, "1"); - int pipelineId = 2; - SliceManager sliceManager = new SliceManager(); - ShuffleConfig shuffleConfig = new ShuffleConfig(config); - ShuffleMemoryTracker memoryTracker = new ShuffleMemoryTracker(config); + int pipelineId = 2; + SliceManager sliceManager = new SliceManager(); + ShuffleConfig shuffleConfig = new ShuffleConfig(config); + ShuffleMemoryTracker memoryTracker = new ShuffleMemoryTracker(config); - try (MockedStatic ms = - Mockito.mockStatic(ShuffleManager.class)) { - ShuffleManager shuffleManager = Mockito.mock(ShuffleManager.class); - ms.when(() -> ShuffleManager.getInstance()).then(invocation -> shuffleManager); + try (MockedStatic ms = Mockito.mockStatic(ShuffleManager.class)) { + ShuffleManager shuffleManager = Mockito.mock(ShuffleManager.class); + ms.when(() -> ShuffleManager.getInstance()).then(invocation -> shuffleManager); - Mockito.doReturn(sliceManager).when(shuffleManager).getSliceManager(); - Mockito.doReturn(shuffleConfig).when(shuffleManager).getShuffleConfig(); - Mockito.doReturn(memoryTracker).when(shuffleManager).getShuffleMemoryTracker(); + Mockito.doReturn(sliceManager).when(shuffleManager).getSliceManager(); + Mockito.doReturn(shuffleConfig).when(shuffleManager).getShuffleConfig(); + Mockito.doReturn(memoryTracker).when(shuffleManager).getShuffleMemoryTracker(); - SpillableShardWriter shardWriter = new SpillableShardWriter( - connectionManager.getShuffleAddress()); - WriterContext writerContext = new WriterContext(pipelineId, "name"); + SpillableShardWriter shardWriter = + new SpillableShardWriter(connectionManager.getShuffleAddress()); + WriterContext writerContext = new WriterContext(pipelineId, "name"); - writerContext.setConfig(shuffleManager.getShuffleConfig()); - writerContext.setChannelNum(1); - shardWriter.init(writerContext); - int[] channels = new int[]{0}; + writerContext.setConfig(shuffleManager.getShuffleConfig()); + writerContext.setChannelNum(1); + shardWriter.init(writerContext); + int[] channels = new int[] {0}; - for (int i = 0; i < 10000; i++) { - shardWriter.emit(0, "hello, testing spillable writer", false, channels); - } - Optional optional = shardWriter.finish(0); - Shard shard = optional.get(); - Assert.assertNotNull(shard); - Assert.assertEquals(shard.getSlices().size(), 1); + for (int i = 0; i < 10000; i++) { + shardWriter.emit(0, "hello, testing spillable writer", false, channels); + } + Optional optional = shardWriter.finish(0); + Shard shard = optional.get(); + Assert.assertNotNull(shard); + Assert.assertEquals(shard.getSlices().size(), 1); - ShuffleManager.getInstance().release(pipelineId); - Mockito.reset(shuffleManager); - } + ShuffleManager.getInstance().release(pipelineId); + Mockito.reset(shuffleManager); } - -} \ No newline at end of file + } +} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/message/PipelineMessageTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/message/PipelineMessageTest.java index ed37df674..084e87552 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/message/PipelineMessageTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/message/PipelineMessageTest.java @@ -25,12 +25,12 @@ public class PipelineMessageTest { - @Test - public void test() { - PipelineMessage message = new PipelineMessage(1, 3, "stream", null); - byte[] bytes = SerializerFactory.getKryoSerializer().serialize(message); - PipelineMessage result = (PipelineMessage) SerializerFactory.getKryoSerializer().deserialize(bytes); - Assert.assertEquals(3, result.getWindowId(), "windowId should be ignored"); - } - + @Test + public void test() { + PipelineMessage message = new PipelineMessage(1, 3, "stream", null); + byte[] bytes = SerializerFactory.getKryoSerializer().serialize(message); + PipelineMessage result = + (PipelineMessage) SerializerFactory.getKryoSerializer().deserialize(bytes); + Assert.assertEquals(3, result.getWindowId(), "windowId should be ignored"); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/network/netty/NettyFrameDecoderTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/network/netty/NettyFrameDecoderTest.java index 51451ec08..ff3120a51 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/network/netty/NettyFrameDecoderTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/network/netty/NettyFrameDecoderTest.java @@ -26,165 +26,169 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelHandlerContext; import java.util.ArrayList; import java.util.List; import java.util.Random; import java.util.concurrent.atomic.AtomicInteger; + import org.testng.Assert; import org.testng.annotations.AfterClass; import org.testng.annotations.Test; -public class NettyFrameDecoderTest { - - private static Random RND = new Random(); - private static final int LENGTH_SIZE = 4; - - @AfterClass - public static void cleanup() { - RND = null; - } - - @Test - public void testFrameDecoding() throws Exception { - NettyFrameDecoder decoder = new NettyFrameDecoder(); - ChannelHandlerContext ctx = mockChannelHandlerContext(); - ByteBuf data = createAndFeedFrames(100, decoder, ctx); - verifyAndCloseDecoder(decoder, ctx, data); - } - - @Test - public void testRetainedFrames() throws Exception { - NettyFrameDecoder decoder = new NettyFrameDecoder(); +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; - AtomicInteger count = new AtomicInteger(); - List retained = new ArrayList<>(); +public class NettyFrameDecoderTest { - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - when(ctx.fireChannelRead(any())).thenAnswer(in -> { - // Retain a few frames but not others. - ByteBuf buf = (ByteBuf) in.getArguments()[0]; - if (count.incrementAndGet() % 2 == 0) { + private static Random RND = new Random(); + private static final int LENGTH_SIZE = 4; + + @AfterClass + public static void cleanup() { + RND = null; + } + + @Test + public void testFrameDecoding() throws Exception { + NettyFrameDecoder decoder = new NettyFrameDecoder(); + ChannelHandlerContext ctx = mockChannelHandlerContext(); + ByteBuf data = createAndFeedFrames(100, decoder, ctx); + verifyAndCloseDecoder(decoder, ctx, data); + } + + @Test + public void testRetainedFrames() throws Exception { + NettyFrameDecoder decoder = new NettyFrameDecoder(); + + AtomicInteger count = new AtomicInteger(); + List retained = new ArrayList<>(); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + when(ctx.fireChannelRead(any())) + .thenAnswer( + in -> { + // Retain a few frames but not others. + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + if (count.incrementAndGet() % 2 == 0) { retained.add(buf); - } else { + } else { buf.release(); - } - return null; - }); - - ByteBuf data = createAndFeedFrames(100, decoder, ctx); - try { - // Verify all retained buffers are readable. - for (ByteBuf b : retained) { - byte[] tmp = new byte[b.readableBytes()]; - b.readBytes(tmp); - b.release(); - } - verifyAndCloseDecoder(decoder, ctx, data); - } finally { - for (ByteBuf b : retained) { - release(b); - } - } + } + return null; + }); + + ByteBuf data = createAndFeedFrames(100, decoder, ctx); + try { + // Verify all retained buffers are readable. + for (ByteBuf b : retained) { + byte[] tmp = new byte[b.readableBytes()]; + b.readBytes(tmp); + b.release(); + } + verifyAndCloseDecoder(decoder, ctx, data); + } finally { + for (ByteBuf b : retained) { + release(b); + } } - - @Test - public void testSplitLengthField() throws Exception { - byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; - ByteBuf buf = Unpooled.buffer(frame.length + LENGTH_SIZE); - buf.writeInt(frame.length + LENGTH_SIZE); - buf.writeBytes(frame); - - NettyFrameDecoder decoder = new NettyFrameDecoder(); - ChannelHandlerContext ctx = mockChannelHandlerContext(); - try { - decoder.channelRead(ctx, buf.readSlice(RND.nextInt(7)).retain()); - verify(ctx, never()).fireChannelRead(any(ByteBuf.class)); - decoder.channelRead(ctx, buf); - verify(ctx).fireChannelRead(any(ByteBuf.class)); - Assert.assertEquals(0, buf.refCnt()); - } finally { - decoder.channelInactive(ctx); - release(buf); - } - } - - @Test(expectedExceptions = IllegalArgumentException.class) - public void testNegativeFrameSize() throws Exception { - testInvalidFrame(-1); + } + + @Test + public void testSplitLengthField() throws Exception { + byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; + ByteBuf buf = Unpooled.buffer(frame.length + LENGTH_SIZE); + buf.writeInt(frame.length + LENGTH_SIZE); + buf.writeBytes(frame); + + NettyFrameDecoder decoder = new NettyFrameDecoder(); + ChannelHandlerContext ctx = mockChannelHandlerContext(); + try { + decoder.channelRead(ctx, buf.readSlice(RND.nextInt(7)).retain()); + verify(ctx, never()).fireChannelRead(any(ByteBuf.class)); + decoder.channelRead(ctx, buf); + verify(ctx).fireChannelRead(any(ByteBuf.class)); + Assert.assertEquals(0, buf.refCnt()); + } finally { + decoder.channelInactive(ctx); + release(buf); } - - @Test(expectedExceptions = IllegalArgumentException.class) - public void testEmptyFrame() throws Exception { - // 8 because frame size includes the frame length. - testInvalidFrame(8); + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testNegativeFrameSize() throws Exception { + testInvalidFrame(-1); + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testEmptyFrame() throws Exception { + // 8 because frame size includes the frame length. + testInvalidFrame(8); + } + + /** + * Creates a number of randomly sized frames and feed them to the given decoder, verifying that + * the frames were read. + */ + private ByteBuf createAndFeedFrames( + int frameCount, NettyFrameDecoder decoder, ChannelHandlerContext ctx) throws Exception { + ByteBuf data = Unpooled.buffer(); + for (int i = 0; i < frameCount; i++) { + byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; + data.writeInt(frame.length + LENGTH_SIZE); + data.writeBytes(frame); } - /** - * Creates a number of randomly sized frames and feed them to the given decoder, verifying - * that the frames were read. - */ - private ByteBuf createAndFeedFrames(int frameCount, NettyFrameDecoder decoder, - ChannelHandlerContext ctx) throws Exception { - ByteBuf data = Unpooled.buffer(); - for (int i = 0; i < frameCount; i++) { - byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; - data.writeInt(frame.length + LENGTH_SIZE); - data.writeBytes(frame); - } - - try { - while (data.isReadable()) { - int size = RND.nextInt(4 * 1024) + 256; - decoder.channelRead(ctx, - data.readSlice(Math.min(data.readableBytes(), size)).retain()); - } - - verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class)); - } catch (Exception e) { - release(data); - throw e; - } - return data; - } + try { + while (data.isReadable()) { + int size = RND.nextInt(4 * 1024) + 256; + decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain()); + } - private void verifyAndCloseDecoder(NettyFrameDecoder decoder, ChannelHandlerContext ctx, - ByteBuf data) throws Exception { - try { - decoder.channelInactive(ctx); - Assert.assertTrue(data.release(), "There shouldn't be dangling references to the data."); - } finally { - release(data); - } + verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class)); + } catch (Exception e) { + release(data); + throw e; } - - private void testInvalidFrame(long size) throws Exception { - NettyFrameDecoder decoder = new NettyFrameDecoder(); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - ByteBuf frame = Unpooled.copyLong(size); - try { - decoder.channelRead(ctx, frame); - } finally { - release(frame); - } + return data; + } + + private void verifyAndCloseDecoder( + NettyFrameDecoder decoder, ChannelHandlerContext ctx, ByteBuf data) throws Exception { + try { + decoder.channelInactive(ctx); + Assert.assertTrue(data.release(), "There shouldn't be dangling references to the data."); + } finally { + release(data); } - - private ChannelHandlerContext mockChannelHandlerContext() { - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - when(ctx.fireChannelRead(any())).thenAnswer(in -> { - ByteBuf buf = (ByteBuf) in.getArguments()[0]; - buf.release(); - return null; - }); - return ctx; + } + + private void testInvalidFrame(long size) throws Exception { + NettyFrameDecoder decoder = new NettyFrameDecoder(); + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ByteBuf frame = Unpooled.copyLong(size); + try { + decoder.channelRead(ctx, frame); + } finally { + release(frame); } - - private void release(ByteBuf buf) { - if (buf.refCnt() > 0) { - buf.release(buf.refCnt()); - } + } + + private ChannelHandlerContext mockChannelHandlerContext() { + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + when(ctx.fireChannelRead(any())) + .thenAnswer( + in -> { + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + buf.release(); + return null; + }); + return ctx; + } + + private void release(ByteBuf buf) { + if (buf.refCnt() > 0) { + buf.release(buf.refCnt()); } - + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/network/netty/SliceRequestClientTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/network/netty/SliceRequestClientTest.java index 6978d22a5..90c7c8e62 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/network/netty/SliceRequestClientTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/network/netty/SliceRequestClientTest.java @@ -24,6 +24,7 @@ import java.util.Arrays; import java.util.List; import java.util.Optional; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.common.shuffle.ShuffleAddress; @@ -50,109 +51,109 @@ @Test(singleThreaded = true) public class SliceRequestClientTest { - private boolean enableBackPressure; - - public SliceRequestClientTest(boolean enableBackPressure) { - this.enableBackPressure = enableBackPressure; - } - - @Test - public void testFetchWithoutMemoryPool() throws IOException, InterruptedException { - testCreditBasedFetch(false); - } - - @Test - public void testFetchWithMemoryPool() throws IOException, InterruptedException { - testCreditBasedFetch(true); - } - - private void testCreditBasedFetch(boolean enableMemoryPool) throws IOException, - InterruptedException { - Configuration configuration = new Configuration(); - configuration.put(ExecutionConfigKeys.JOB_APP_NAME, "default"); - configuration.put(ExecutionConfigKeys.CONTAINER_HEAP_SIZE_MB, String.valueOf(1024)); - configuration.put(ExecutionConfigKeys.SHUFFLE_BACKPRESSURE_ENABLE, String.valueOf(enableBackPressure)); - configuration.put(ExecutionConfigKeys.SHUFFLE_FETCH_CHANNEL_QUEUE_SIZE, "1"); - configuration.put(ExecutionConfigKeys.SHUFFLE_WRITER_BUFFER_SIZE, "10"); - configuration.put(ExecutionConfigKeys.SHUFFLE_FLUSH_BUFFER_SIZE_BYTES, "5"); - configuration.put(ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE, String.valueOf(enableMemoryPool)); - - ShuffleManager shuffleManager = ShuffleManager.init(configuration); - IConnectionManager connectionManager = shuffleManager.getConnectionManager(); - - List inputSlices = new ArrayList<>(); - ShuffleAddress address = connectionManager.getShuffleAddress(); - SliceId sliceId = new SliceId(100, 0, 0, 0); - PipelineSliceMeta slice1 = new PipelineSliceMeta(sliceId, 1, address); - inputSlices.add(slice1); - - PipeBuffer pipeBuffer = buildPipeBuffer(enableMemoryPool, "hello".getBytes(), 1); - PipeBuffer pipeBuffer2 = new PipeBuffer(1, 1, false); - PipeBuffer pipeBuffer3 = buildPipeBuffer(enableMemoryPool, "hello".getBytes(), 2); - PipeBuffer pipeBuffer4 = new PipeBuffer(2, 1, true); - PipelineSlice slice = new PipelineSlice("task", sliceId); - slice.add(pipeBuffer); - slice.add(pipeBuffer2); - slice.add(pipeBuffer3); - slice.add(pipeBuffer4); - - SliceManager shuffleDataManager = ShuffleManager.getInstance().getSliceManager(); - shuffleDataManager.register(sliceId, slice); - - OneShardFetcher fetcher = new MockedShardFetcher(1, "taskName", 0, inputSlices, 0, - connectionManager); - List batchList = Arrays.asList(1L, 2L); - List result = new ArrayList<>(); - for (long batchId : batchList) { - fetcher.requestSlices(batchId); - while (!fetcher.isFinished()) { - Optional bufferOptional = fetcher.getNext(); - if (bufferOptional.isPresent()) { - PipeFetcherBuffer buffer = bufferOptional.get(); - String value = String.valueOf(buffer.getBuffer()); - result.add(value); - if (buffer.isBarrier()) { - break; - } - } - } + private boolean enableBackPressure; + + public SliceRequestClientTest(boolean enableBackPressure) { + this.enableBackPressure = enableBackPressure; + } + + @Test + public void testFetchWithoutMemoryPool() throws IOException, InterruptedException { + testCreditBasedFetch(false); + } + + @Test + public void testFetchWithMemoryPool() throws IOException, InterruptedException { + testCreditBasedFetch(true); + } + + private void testCreditBasedFetch(boolean enableMemoryPool) + throws IOException, InterruptedException { + Configuration configuration = new Configuration(); + configuration.put(ExecutionConfigKeys.JOB_APP_NAME, "default"); + configuration.put(ExecutionConfigKeys.CONTAINER_HEAP_SIZE_MB, String.valueOf(1024)); + configuration.put( + ExecutionConfigKeys.SHUFFLE_BACKPRESSURE_ENABLE, String.valueOf(enableBackPressure)); + configuration.put(ExecutionConfigKeys.SHUFFLE_FETCH_CHANNEL_QUEUE_SIZE, "1"); + configuration.put(ExecutionConfigKeys.SHUFFLE_WRITER_BUFFER_SIZE, "10"); + configuration.put(ExecutionConfigKeys.SHUFFLE_FLUSH_BUFFER_SIZE_BYTES, "5"); + configuration.put( + ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE, String.valueOf(enableMemoryPool)); + + ShuffleManager shuffleManager = ShuffleManager.init(configuration); + IConnectionManager connectionManager = shuffleManager.getConnectionManager(); + + List inputSlices = new ArrayList<>(); + ShuffleAddress address = connectionManager.getShuffleAddress(); + SliceId sliceId = new SliceId(100, 0, 0, 0); + PipelineSliceMeta slice1 = new PipelineSliceMeta(sliceId, 1, address); + inputSlices.add(slice1); + + PipeBuffer pipeBuffer = buildPipeBuffer(enableMemoryPool, "hello".getBytes(), 1); + PipeBuffer pipeBuffer2 = new PipeBuffer(1, 1, false); + PipeBuffer pipeBuffer3 = buildPipeBuffer(enableMemoryPool, "hello".getBytes(), 2); + PipeBuffer pipeBuffer4 = new PipeBuffer(2, 1, true); + PipelineSlice slice = new PipelineSlice("task", sliceId); + slice.add(pipeBuffer); + slice.add(pipeBuffer2); + slice.add(pipeBuffer3); + slice.add(pipeBuffer4); + + SliceManager shuffleDataManager = ShuffleManager.getInstance().getSliceManager(); + shuffleDataManager.register(sliceId, slice); + + OneShardFetcher fetcher = + new MockedShardFetcher(1, "taskName", 0, inputSlices, 0, connectionManager); + List batchList = Arrays.asList(1L, 2L); + List result = new ArrayList<>(); + for (long batchId : batchList) { + fetcher.requestSlices(batchId); + while (!fetcher.isFinished()) { + Optional bufferOptional = fetcher.getNext(); + if (bufferOptional.isPresent()) { + PipeFetcherBuffer buffer = bufferOptional.get(); + String value = String.valueOf(buffer.getBuffer()); + result.add(value); + if (buffer.isBarrier()) { + break; + } } + } + } - Assert.assertEquals(result.size(), 4); - Assert.assertTrue(slice.canRelease()); - - fetcher.close(); - for (AbstractInputChannel channel : fetcher.getInputChannels().values()) { - if (channel instanceof RemoteInputChannel) { - RemoteInputChannel remoteChannel = (RemoteInputChannel) channel; - Assert.assertTrue(remoteChannel.getSliceRequestClient().disposeIfNotUsed()); - } - } + Assert.assertEquals(result.size(), 4); + Assert.assertTrue(slice.canRelease()); - ShuffleManager.getInstance().release(sliceId.getPipelineId()); - ShuffleManager.getInstance().close(); + fetcher.close(); + for (AbstractInputChannel channel : fetcher.getInputChannels().values()) { + if (channel instanceof RemoteInputChannel) { + RemoteInputChannel remoteChannel = (RemoteInputChannel) channel; + Assert.assertTrue(remoteChannel.getSliceRequestClient().disposeIfNotUsed()); + } } - private PipeBuffer buildPipeBuffer(boolean enableMemoryPool, byte[] content, long batchId) { - if (enableMemoryPool) { - MemoryView view = MemoryManager.getInstance() - .requireMemory(content.length, MemoryGroupManger.SHUFFLE); - view.getWriter().write(content); - return new PipeBuffer(new MemoryViewBuffer(view, false), batchId); - } else { - return new PipeBuffer(content, batchId); - } + ShuffleManager.getInstance().release(sliceId.getPipelineId()); + ShuffleManager.getInstance().close(); + } + + private PipeBuffer buildPipeBuffer(boolean enableMemoryPool, byte[] content, long batchId) { + if (enableMemoryPool) { + MemoryView view = + MemoryManager.getInstance().requireMemory(content.length, MemoryGroupManger.SHUFFLE); + view.getWriter().write(content); + return new PipeBuffer(new MemoryViewBuffer(view, false), batchId); + } else { + return new PipeBuffer(content, batchId); } + } - public static class SimpleTestFactory { + public static class SimpleTestFactory { - @Factory - public Object[] factoryMethod() { - return new Object[]{ - new SliceRequestClientTest(false), - new SliceRequestClientTest(true), - }; - } + @Factory + public Object[] factoryMethod() { + return new Object[] { + new SliceRequestClientTest(false), new SliceRequestClientTest(true), + }; } - -} \ No newline at end of file + } +} diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/buffer/MemoryViewBufferTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/buffer/MemoryViewBufferTest.java index 7d78eb7dd..a573d1f0c 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/buffer/MemoryViewBufferTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/buffer/MemoryViewBufferTest.java @@ -29,29 +29,28 @@ public class MemoryViewBufferTest { - @Test - public void test() { - Configuration configuration = new Configuration(); - MemoryManager manager = MemoryManager.build(configuration); - - int recordCount = 1024 * 16; - MemoryViewBufferBuilder buffer = new MemoryViewBufferBuilder(); - RecordSerializer serializer = new RecordSerializer(); - for (int i = 0; i < recordCount; i++) { - serializer.serialize(i, false, buffer); - } - - Assert.assertTrue(buffer.getBufferSize() > 0); - - int result = 0; - MessageIterator iterator = new MessageIterator(buffer.build().getInputStream()); - while (iterator.hasNext()) { - iterator.next(); - result++; - } - Assert.assertTrue(result == recordCount); - - manager.dispose(); + @Test + public void test() { + Configuration configuration = new Configuration(); + MemoryManager manager = MemoryManager.build(configuration); + + int recordCount = 1024 * 16; + MemoryViewBufferBuilder buffer = new MemoryViewBufferBuilder(); + RecordSerializer serializer = new RecordSerializer(); + for (int i = 0; i < recordCount; i++) { + serializer.serialize(i, false, buffer); } + Assert.assertTrue(buffer.getBufferSize() > 0); + + int result = 0; + MessageIterator iterator = new MessageIterator(buffer.build().getInputStream()); + while (iterator.hasNext()) { + iterator.next(); + result++; + } + Assert.assertTrue(result == recordCount); + + manager.dispose(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/buffer/ShuffleMemoryTrackerTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/buffer/ShuffleMemoryTrackerTest.java index 5d37902ea..bf5849c83 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/buffer/ShuffleMemoryTrackerTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/buffer/ShuffleMemoryTrackerTest.java @@ -26,6 +26,7 @@ import java.util.Map; import java.util.Random; import java.util.concurrent.CountDownLatch; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.testng.Assert; @@ -33,73 +34,74 @@ public class ShuffleMemoryTrackerTest { - @Test - public void testRequireAndRelease() { - Configuration configuration = new Configuration(); - configuration.put(ExecutionConfigKeys.CONTAINER_HEAP_SIZE_MB, String.valueOf(1024)); - ShuffleMemoryTracker tracker = new ShuffleMemoryTracker(configuration); + @Test + public void testRequireAndRelease() { + Configuration configuration = new Configuration(); + configuration.put(ExecutionConfigKeys.CONTAINER_HEAP_SIZE_MB, String.valueOf(1024)); + ShuffleMemoryTracker tracker = new ShuffleMemoryTracker(configuration); - byte[] bytes1 = new byte[100]; - HeapBuffer buffer1 = new HeapBuffer(bytes1, false); - Assert.assertEquals(tracker.getUsedMemory(), 0); + byte[] bytes1 = new byte[100]; + HeapBuffer buffer1 = new HeapBuffer(bytes1, false); + Assert.assertEquals(tracker.getUsedMemory(), 0); - byte[] bytes2 = new byte[200]; - HeapBuffer buffer2 = new HeapBuffer(bytes2, tracker); - Assert.assertEquals(tracker.getUsedMemory(), 200); + byte[] bytes2 = new byte[200]; + HeapBuffer buffer2 = new HeapBuffer(bytes2, tracker); + Assert.assertEquals(tracker.getUsedMemory(), 200); - buffer1.release(); - Assert.assertEquals(tracker.getUsedMemory(), 200); - buffer2.release(); - Assert.assertEquals(tracker.getUsedMemory(), 0); - } + buffer1.release(); + Assert.assertEquals(tracker.getUsedMemory(), 200); + buffer2.release(); + Assert.assertEquals(tracker.getUsedMemory(), 0); + } - @Test - public void testOffheapRequireAndRelease() { - Configuration configuration = new Configuration(); - configuration.put(SHUFFLE_MEMORY_POOL_ENABLE, "true"); - ShuffleMemoryTracker tracker = new ShuffleMemoryTracker(configuration); + @Test + public void testOffheapRequireAndRelease() { + Configuration configuration = new Configuration(); + configuration.put(SHUFFLE_MEMORY_POOL_ENABLE, "true"); + ShuffleMemoryTracker tracker = new ShuffleMemoryTracker(configuration); - byte[] bytes1 = new byte[100]; - HeapBuffer buffer1 = new HeapBuffer(bytes1, false); - Assert.assertEquals(tracker.getUsedMemory(), 0); + byte[] bytes1 = new byte[100]; + HeapBuffer buffer1 = new HeapBuffer(bytes1, false); + Assert.assertEquals(tracker.getUsedMemory(), 0); - byte[] bytes2 = new byte[200]; - HeapBuffer buffer2 = new HeapBuffer(bytes2, tracker); - Assert.assertEquals(tracker.getUsedMemory(), 200); + byte[] bytes2 = new byte[200]; + HeapBuffer buffer2 = new HeapBuffer(bytes2, tracker); + Assert.assertEquals(tracker.getUsedMemory(), 200); - buffer1.release(); - Assert.assertEquals(tracker.getUsedMemory(), 200); - buffer2.release(); - Assert.assertEquals(tracker.getUsedMemory(), 0); - } + buffer1.release(); + Assert.assertEquals(tracker.getUsedMemory(), 200); + buffer2.release(); + Assert.assertEquals(tracker.getUsedMemory(), 0); + } - @Test - public void testConcurrency() throws InterruptedException { - Map config = new HashMap<>(); - config.put(CONTAINER_HEAP_SIZE_MB.getKey(), "1000"); - ShuffleMemoryTracker tracker = new ShuffleMemoryTracker(new Configuration(config)); + @Test + public void testConcurrency() throws InterruptedException { + Map config = new HashMap<>(); + config.put(CONTAINER_HEAP_SIZE_MB.getKey(), "1000"); + ShuffleMemoryTracker tracker = new ShuffleMemoryTracker(new Configuration(config)); - Random random = new Random(); - CountDownLatch latch = new CountDownLatch(3); - for (int i = 0; i < 3; i++) { - new Thread(new Runnable() { + Random random = new Random(); + CountDownLatch latch = new CountDownLatch(3); + for (int i = 0; i < 3; i++) { + new Thread( + new Runnable() { @Override public void run() { - for (int time = 0; time < 10; time++) { - int required = random.nextInt(1000); - boolean suc = tracker.requireMemory(required); - if (suc) { - tracker.releaseMemory(required); - } else { - System.out.println("not enough"); - } + for (int time = 0; time < 10; time++) { + int required = random.nextInt(1000); + boolean suc = tracker.requireMemory(required); + if (suc) { + tracker.releaseMemory(required); + } else { + System.out.println("not enough"); } - latch.countDown(); + } + latch.countDown(); } - }).start(); - } - latch.await(); - Assert.assertEquals(tracker.getUsedMemory(), 0); + }) + .start(); } - + latch.await(); + Assert.assertEquals(tracker.getUsedMemory(), 0); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/fetcher/MockedShardFetcher.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/fetcher/MockedShardFetcher.java index 6414bdb71..19e0d7fb9 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/fetcher/MockedShardFetcher.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/fetcher/MockedShardFetcher.java @@ -20,6 +20,7 @@ package org.apache.geaflow.shuffle.pipeline.fetcher; import java.util.List; + import org.apache.geaflow.common.shuffle.ShuffleAddress; import org.apache.geaflow.shuffle.message.ISliceMeta; import org.apache.geaflow.shuffle.message.PipelineSliceMeta; @@ -32,31 +33,44 @@ import org.slf4j.LoggerFactory; public class MockedShardFetcher extends OneShardFetcher { - private static final Logger LOGGER = LoggerFactory.getLogger(MockedShardFetcher.class); + private static final Logger LOGGER = LoggerFactory.getLogger(MockedShardFetcher.class); - public MockedShardFetcher(int stageId, String taskName, int fetcherIndex, - List inputSlices, long startBatchId, - IConnectionManager connectionManager) { - super(stageId, taskName, fetcherIndex, inputSlices, startBatchId, - connectionManager); - } + public MockedShardFetcher( + int stageId, + String taskName, + int fetcherIndex, + List inputSlices, + long startBatchId, + IConnectionManager connectionManager) { + super(stageId, taskName, fetcherIndex, inputSlices, startBatchId, connectionManager); + } - @Override - protected void buildInputChannels(int connectionId, List inputSlices, - int initialBackoff, int maxBackoff, long startBatchId) { - - int localChannels = 0; - for (int inputChannelIdx = 0; inputChannelIdx < numberOfInputChannels; inputChannelIdx++) { - PipelineSliceMeta task = (PipelineSliceMeta) inputSlices.get(inputChannelIdx); - ShuffleAddress address = task.getShuffleAddress(); - SliceId inputSlice = task.getSliceId(); - - AbstractInputChannel inputChannel = new RemoteInputChannel(this, inputSlice, - inputChannelIdx, new ConnectionId(address, connectionId), initialBackoff, - maxBackoff, startBatchId, connectionManager); - inputChannels.put(inputSlice, inputChannel); - } - LOGGER.info("create {} local channels in {} channels", localChannels, numberOfInputChannels); - } + @Override + protected void buildInputChannels( + int connectionId, + List inputSlices, + int initialBackoff, + int maxBackoff, + long startBatchId) { + int localChannels = 0; + for (int inputChannelIdx = 0; inputChannelIdx < numberOfInputChannels; inputChannelIdx++) { + PipelineSliceMeta task = (PipelineSliceMeta) inputSlices.get(inputChannelIdx); + ShuffleAddress address = task.getShuffleAddress(); + SliceId inputSlice = task.getSliceId(); + + AbstractInputChannel inputChannel = + new RemoteInputChannel( + this, + inputSlice, + inputChannelIdx, + new ConnectionId(address, connectionId), + initialBackoff, + maxBackoff, + startBatchId, + connectionManager); + inputChannels.put(inputSlice, inputChannel); + } + LOGGER.info("create {} local channels in {} channels", localChannels, numberOfInputChannels); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/fetcher/MultiShardFetcherTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/fetcher/MultiShardFetcherTest.java index 1c15086f1..e0d9c0002 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/fetcher/MultiShardFetcherTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/fetcher/MultiShardFetcherTest.java @@ -23,6 +23,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.shuffle.ShuffleAddress; import org.apache.geaflow.shuffle.config.ShuffleConfig; @@ -36,36 +37,37 @@ public class MultiShardFetcherTest { - @Test - public void testCreate() throws IOException { - List inputSlices1 = new ArrayList<>(); - ShuffleConfig config = new ShuffleConfig(new Configuration()); - IConnectionManager connectionManager = new ConnectionManager(config); - ShuffleAddress address = connectionManager.getShuffleAddress(); - PipelineSliceMeta slice1 = new PipelineSliceMeta(0, 0, -1, 0, address); - inputSlices1.add(slice1); - - OneShardFetcher fetcher1 = new OneShardFetcher(1, "taskName", 0, inputSlices1, 0, connectionManager); - Map channelMap = fetcher1.getInputChannels(); - Assert.assertEquals(channelMap.size(), 1); + @Test + public void testCreate() throws IOException { + List inputSlices1 = new ArrayList<>(); + ShuffleConfig config = new ShuffleConfig(new Configuration()); + IConnectionManager connectionManager = new ConnectionManager(config); + ShuffleAddress address = connectionManager.getShuffleAddress(); + PipelineSliceMeta slice1 = new PipelineSliceMeta(0, 0, -1, 0, address); + inputSlices1.add(slice1); - List inputSlices2 = new ArrayList<>(); - PipelineSliceMeta slice2 = new PipelineSliceMeta(0, 2, -1, 0, address); - inputSlices2.add(slice2); + OneShardFetcher fetcher1 = + new OneShardFetcher(1, "taskName", 0, inputSlices1, 0, connectionManager); + Map channelMap = fetcher1.getInputChannels(); + Assert.assertEquals(channelMap.size(), 1); - OneShardFetcher fetcher2 = new OneShardFetcher(1, "taskName", 1, inputSlices2, 0, connectionManager); - channelMap = fetcher2.getInputChannels(); - Assert.assertEquals(channelMap.size(), 1); + List inputSlices2 = new ArrayList<>(); + PipelineSliceMeta slice2 = new PipelineSliceMeta(0, 2, -1, 0, address); + inputSlices2.add(slice2); - MultiShardFetcher multiShardFetcher = new MultiShardFetcher(fetcher1, fetcher2); - Assert.assertEquals(multiShardFetcher.getNumberOfInputChannels(), 2); + OneShardFetcher fetcher2 = + new OneShardFetcher(1, "taskName", 1, inputSlices2, 0, connectionManager); + channelMap = fetcher2.getInputChannels(); + Assert.assertEquals(channelMap.size(), 1); - ShardFetcher[] fetchers = multiShardFetcher.getShardFetchers(); - Assert.assertEquals(((OneShardFetcher) fetchers[0]).getFetcherIndex(), 0); - Assert.assertEquals(((OneShardFetcher) fetchers[1]).getFetcherIndex(), 1); + MultiShardFetcher multiShardFetcher = new MultiShardFetcher(fetcher1, fetcher2); + Assert.assertEquals(multiShardFetcher.getNumberOfInputChannels(), 2); - multiShardFetcher.close(); - connectionManager.close(); - } + ShardFetcher[] fetchers = multiShardFetcher.getShardFetchers(); + Assert.assertEquals(((OneShardFetcher) fetchers[0]).getFetcherIndex(), 0); + Assert.assertEquals(((OneShardFetcher) fetchers[1]).getFetcherIndex(), 1); + multiShardFetcher.close(); + connectionManager.close(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/fetcher/OneShardFetcherTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/fetcher/OneShardFetcherTest.java index 5d77477fa..0c02630db 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/fetcher/OneShardFetcherTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/fetcher/OneShardFetcherTest.java @@ -27,6 +27,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.common.shuffle.ShuffleAddress; @@ -48,105 +49,105 @@ public class OneShardFetcherTest { - private Configuration configuration; - - @BeforeTest - public void setup() { - configuration = new Configuration(); - configuration.put(ExecutionConfigKeys.JOB_APP_NAME, "default"); - configuration.put(ExecutionConfigKeys.CONTAINER_HEAP_SIZE_MB, String.valueOf(1024)); + private Configuration configuration; + + @BeforeTest + public void setup() { + configuration = new Configuration(); + configuration.put(ExecutionConfigKeys.JOB_APP_NAME, "default"); + configuration.put(ExecutionConfigKeys.CONTAINER_HEAP_SIZE_MB, String.valueOf(1024)); + } + + @Test + public void testCreate() { + List inputSlices = new ArrayList<>(); + IConnectionManager connectionManager = + ShuffleManager.init(configuration).getConnectionManager(); + ShuffleAddress address = connectionManager.getShuffleAddress(); + PipelineSliceMeta slice1 = new PipelineSliceMeta(0, 0, -1, 0, address); + PipelineSliceMeta slice2 = new PipelineSliceMeta(0, 2, -1, 0, address); + inputSlices.add(slice1); + inputSlices.add(slice2); + + OneShardFetcher fetcher = + new OneShardFetcher(1, "taskName", 0, inputSlices, 0, connectionManager); + Map channelMap = fetcher.getInputChannels(); + Assert.assertEquals(channelMap.size(), 2); + + Set expectedSlices = new HashSet<>(); + expectedSlices.add(new SliceId(-1, 0, 0, 0)); + expectedSlices.add(new SliceId(-1, 0, 0, 2)); + + Set expectedChannelIndices = new HashSet<>(); + expectedChannelIndices.add(0); + expectedChannelIndices.add(1); + + Set channelIds = new HashSet<>(); + for (Map.Entry entry : channelMap.entrySet()) { + Assert.assertTrue(expectedSlices.remove(entry.getKey())); + AbstractInputChannel channel = entry.getValue(); + channelIds.add(channel.getChannelIndex()); + Assert.assertTrue(channel instanceof LocalInputChannel); } - - @Test - public void testCreate() { - List inputSlices = new ArrayList<>(); - IConnectionManager connectionManager = ShuffleManager.init(configuration).getConnectionManager(); - ShuffleAddress address = connectionManager.getShuffleAddress(); - PipelineSliceMeta slice1 = new PipelineSliceMeta(0, 0, -1, 0, address); - PipelineSliceMeta slice2 = new PipelineSliceMeta(0, 2, -1, 0, address); - inputSlices.add(slice1); - inputSlices.add(slice2); - - OneShardFetcher fetcher = new OneShardFetcher(1, "taskName", 0, inputSlices, 0, - connectionManager); - Map channelMap = fetcher.getInputChannels(); - Assert.assertEquals(channelMap.size(), 2); - - Set expectedSlices = new HashSet<>(); - expectedSlices.add(new SliceId(-1, 0, 0, 0)); - expectedSlices.add(new SliceId(-1, 0, 0, 2)); - - Set expectedChannelIndices = new HashSet<>(); - expectedChannelIndices.add(0); - expectedChannelIndices.add(1); - - Set channelIds = new HashSet<>(); - for (Map.Entry entry : channelMap.entrySet()) { - Assert.assertTrue(expectedSlices.remove(entry.getKey())); - AbstractInputChannel channel = entry.getValue(); - channelIds.add(channel.getChannelIndex()); - Assert.assertTrue(channel instanceof LocalInputChannel); + Assert.assertEquals(expectedChannelIndices, channelIds); + } + + @Test + public void testRemoteFetch() throws IOException, InterruptedException { + IConnectionManager connectionManager = + ShuffleManager.init(configuration).getConnectionManager(); + ShuffleAddress address = connectionManager.getShuffleAddress(); + final SliceManager sliceManager = ShuffleManager.getInstance().getSliceManager(); + + long pipelineId = 1; + SliceId sliceId = new SliceId(pipelineId, 0, 0, 0); + PipelineSliceMeta slice1 = new PipelineSliceMeta(sliceId, 1, address); + List inputSlices = new ArrayList<>(); + inputSlices.add(slice1); + + OneShardFetcher fetcher = + new MockedShardFetcher(1, "taskName", 0, inputSlices, 0, connectionManager); + + PipeBuffer pipeBuffer = new PipeBuffer(new HeapBuffer("hello".getBytes(), false), 1); + PipeBuffer pipeBuffer2 = new PipeBuffer(1, 1, false); + PipeBuffer pipeBuffer3 = new PipeBuffer(new HeapBuffer("hello".getBytes(), false), 2); + PipeBuffer pipeBuffer4 = new PipeBuffer(2, 1, true); + PipelineSlice slice = new PipelineSlice("task", sliceId); + slice.add(pipeBuffer); + slice.add(pipeBuffer2); + slice.add(pipeBuffer3); + slice.add(pipeBuffer4); + + sliceManager.register(sliceId, slice); + + List batchList = Arrays.asList(1L, 2L); + List result = new ArrayList<>(); + for (long batchId : batchList) { + fetcher.requestSlices(batchId); + while (!fetcher.isFinished()) { + Optional bufferOptional = fetcher.getNext(); + if (bufferOptional.isPresent()) { + PipeFetcherBuffer buffer = bufferOptional.get(); + String value = String.valueOf(buffer.getBuffer()); + result.add(value); + if (buffer.isBarrier()) { + break; + } } - Assert.assertEquals(expectedChannelIndices, channelIds); + } } - @Test - public void testRemoteFetch() throws IOException, InterruptedException { - IConnectionManager connectionManager = - ShuffleManager.init(configuration).getConnectionManager(); - ShuffleAddress address = connectionManager.getShuffleAddress(); - final SliceManager sliceManager = ShuffleManager.getInstance().getSliceManager(); - - long pipelineId = 1; - SliceId sliceId = new SliceId(pipelineId, 0, 0, 0); - PipelineSliceMeta slice1 = new PipelineSliceMeta(sliceId, 1, address); - List inputSlices = new ArrayList<>(); - inputSlices.add(slice1); - - OneShardFetcher fetcher = new MockedShardFetcher(1, "taskName", 0, inputSlices, 0, - connectionManager); - - PipeBuffer pipeBuffer = new PipeBuffer(new HeapBuffer("hello".getBytes(), false), 1); - PipeBuffer pipeBuffer2 = new PipeBuffer(1, 1, false); - PipeBuffer pipeBuffer3 = new PipeBuffer(new HeapBuffer("hello".getBytes(), false), 2); - PipeBuffer pipeBuffer4 = new PipeBuffer(2, 1, true); - PipelineSlice slice = new PipelineSlice("task", sliceId); - slice.add(pipeBuffer); - slice.add(pipeBuffer2); - slice.add(pipeBuffer3); - slice.add(pipeBuffer4); - - sliceManager.register(sliceId, slice); - - List batchList = Arrays.asList(1L, 2L); - List result = new ArrayList<>(); - for (long batchId : batchList) { - fetcher.requestSlices(batchId); - while (!fetcher.isFinished()) { - Optional bufferOptional = fetcher.getNext(); - if (bufferOptional.isPresent()) { - PipeFetcherBuffer buffer = bufferOptional.get(); - String value = String.valueOf(buffer.getBuffer()); - result.add(value); - if (buffer.isBarrier()) { - break; - } - } - } - } + Assert.assertEquals(result.size(), 4); + Assert.assertTrue(slice.canRelease()); - Assert.assertEquals(result.size(), 4); - Assert.assertTrue(slice.canRelease()); - - fetcher.close(); - for (AbstractInputChannel channel : fetcher.getInputChannels().values()) { - if (channel instanceof RemoteInputChannel) { - RemoteInputChannel remoteChannel = (RemoteInputChannel) channel; - Assert.assertTrue(remoteChannel.getSliceRequestClient().disposeIfNotUsed()); - } - } - ShuffleManager.getInstance().release(pipelineId); - ShuffleManager.getInstance().close(); + fetcher.close(); + for (AbstractInputChannel channel : fetcher.getInputChannels().values()) { + if (channel instanceof RemoteInputChannel) { + RemoteInputChannel remoteChannel = (RemoteInputChannel) channel; + Assert.assertTrue(remoteChannel.getSliceRequestClient().disposeIfNotUsed()); + } } - + ShuffleManager.getInstance().release(pipelineId); + ShuffleManager.getInstance().close(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/slice/SpillablePipelineSliceTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/slice/SpillablePipelineSliceTest.java index bd63e3ebc..a17a0e432 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/slice/SpillablePipelineSliceTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/pipeline/slice/SpillablePipelineSliceTest.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.util.UUID; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.shuffle.config.ShuffleConfig; @@ -36,44 +37,42 @@ public class SpillablePipelineSliceTest { - @Test - public void testAdd() throws IOException { - Configuration config = new Configuration(); - config.put(SHUFFLE_SPILL_RECORDS, "50"); - config.put(ExecutionConfigKeys.CONTAINER_HEAP_SIZE_MB, String.valueOf(1024)); - - long id = UUID.randomUUID().getLeastSignificantBits(); - SliceId sliceId = new SliceId(id, 0, 0, 0); + @Test + public void testAdd() throws IOException { + Configuration config = new Configuration(); + config.put(SHUFFLE_SPILL_RECORDS, "50"); + config.put(ExecutionConfigKeys.CONTAINER_HEAP_SIZE_MB, String.valueOf(1024)); - ShuffleConfig shuffleConfig = new ShuffleConfig(config); - ShuffleMemoryTracker memoryTracker = new ShuffleMemoryTracker(config); - SpillablePipelineSlice slice = new SpillablePipelineSlice("test", sliceId, shuffleConfig, - memoryTracker); + long id = UUID.randomUUID().getLeastSignificantBits(); + SliceId sliceId = new SliceId(id, 0, 0, 0); - byte[] bytes1 = new byte[100]; + ShuffleConfig shuffleConfig = new ShuffleConfig(config); + ShuffleMemoryTracker memoryTracker = new ShuffleMemoryTracker(config); + SpillablePipelineSlice slice = + new SpillablePipelineSlice("test", sliceId, shuffleConfig, memoryTracker); - int bufferCount = 10000; - for (int i = 0; i < bufferCount; i++) { - HeapBuffer outBuffer = new HeapBuffer(bytes1, memoryTracker); - PipeBuffer buffer = new PipeBuffer(outBuffer, 1); - slice.add(buffer); - } - slice.flush(); + byte[] bytes1 = new byte[100]; - // Check repeatable reader. - int consumedBufferCount = 0; - PipelineSliceReader reader = slice.createSliceReader(1, () -> { - }); - while (reader.hasNext()) { - PipeChannelBuffer buffer = reader.next(); - if (buffer != null) { - consumedBufferCount++; - } - } + int bufferCount = 10000; + for (int i = 0; i < bufferCount; i++) { + HeapBuffer outBuffer = new HeapBuffer(bytes1, memoryTracker); + PipeBuffer buffer = new PipeBuffer(outBuffer, 1); + slice.add(buffer); + } + slice.flush(); - Assert.assertEquals(bufferCount, consumedBufferCount); - Assert.assertTrue(slice.canRelease()); - slice.release(); + // Check repeatable reader. + int consumedBufferCount = 0; + PipelineSliceReader reader = slice.createSliceReader(1, () -> {}); + while (reader.hasNext()) { + PipeChannelBuffer buffer = reader.next(); + if (buffer != null) { + consumedBufferCount++; + } } + Assert.assertEquals(bufferCount, consumedBufferCount); + Assert.assertTrue(slice.canRelease()); + slice.release(); + } } diff --git a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/serialize/RecordSerializerTest.java b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/serialize/RecordSerializerTest.java index 3bd37c9a1..9128c8e2d 100644 --- a/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/serialize/RecordSerializerTest.java +++ b/geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/test/java/org/apache/geaflow/shuffle/serialize/RecordSerializerTest.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.model.record.impl.Record; import org.apache.geaflow.shuffle.pipeline.buffer.HeapBuffer.HeapBufferBuilder; import org.apache.geaflow.shuffle.pipeline.buffer.OutBuffer; @@ -30,23 +31,22 @@ public class RecordSerializerTest { - @Test - public void test() { - RecordSerializer serializer = new RecordSerializer(); - BufferBuilder outBuffer = new HeapBufferBuilder(); - List recordList = new ArrayList<>(); - for (int i = 0; i < 100; i++) { - Record record = new Record(i); - serializer.serialize(record, false, outBuffer); - recordList.add(record); - } - OutBuffer buffer = outBuffer.build(); - MessageIterator deserializer = new MessageIterator(buffer); - List list = new ArrayList<>(); - while (deserializer.hasNext()) { - list.add((Record) deserializer.next()); - } - Assert.assertEquals(recordList, list); + @Test + public void test() { + RecordSerializer serializer = new RecordSerializer(); + BufferBuilder outBuffer = new HeapBufferBuilder(); + List recordList = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + Record record = new Record(i); + serializer.serialize(record, false, outBuffer); + recordList.add(record); } - + OutBuffer buffer = outBuffer.build(); + MessageIterator deserializer = new MessageIterator(buffer); + List list = new ArrayList<>(); + while (deserializer.hasNext()) { + list.add((Record) deserializer.next()); + } + Assert.assertEquals(recordList, list); + } } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/Constants.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/Constants.java index bb253fef2..59a78d1b3 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/Constants.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/Constants.java @@ -21,10 +21,9 @@ public class Constants { - public static final String META_SERVER = "meta_server"; + public static final String META_SERVER = "meta_server"; - public static final String META_VERSION = "version"; - - public static final String APP_NAME = "app_name"; + public static final String META_VERSION = "version"; + public static final String APP_NAME = "app_name"; } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/MetaServerConfigKeys.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/MetaServerConfigKeys.java index 79bd4c231..8ef3d6b95 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/MetaServerConfigKeys.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/MetaServerConfigKeys.java @@ -24,9 +24,8 @@ public class MetaServerConfigKeys { - public static final ConfigKey META_SERVER_SERVICE_TYPE = ConfigKeys - .key("geaflow.meta.server.service.type") - .defaultValue("local") - .description("meta server service type eg[LOCAL,STATE_SERVICE]"); - + public static final ConfigKey META_SERVER_SERVICE_TYPE = + ConfigKeys.key("geaflow.meta.server.service.type") + .defaultValue("local") + .description("meta server service type eg[LOCAL,STATE_SERVICE]"); } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/client/BaseClient.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/client/BaseClient.java index fa7f4bc5d..d09c323b2 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/client/BaseClient.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/client/BaseClient.java @@ -21,9 +21,6 @@ import static org.apache.geaflow.metaserver.Constants.META_SERVER; -import com.baidu.brpc.client.BrpcProxy; -import com.baidu.brpc.client.RpcClient; -import com.baidu.brpc.client.channel.Endpoint; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.common.rpc.ConfigurableClientOption; @@ -42,98 +39,108 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class BaseClient implements ServiceListener { +import com.baidu.brpc.client.BrpcProxy; +import com.baidu.brpc.client.RpcClient; +import com.baidu.brpc.client.channel.Endpoint; - private static final Logger LOGGER = LoggerFactory.getLogger(BaseClient.class); - private static final int MAX_RETRY = 5; - private static final int RETRY_MS = 10000; - protected ServiceConsumer serviceConsumer; - protected RpcClient rpcClient; - protected MetaServerService metaServerService; - private HostAndPort currentServiceInfo; - protected Configuration configuration; +public abstract class BaseClient implements ServiceListener { - public BaseClient() { + private static final Logger LOGGER = LoggerFactory.getLogger(BaseClient.class); + private static final int MAX_RETRY = 5; + private static final int RETRY_MS = 10000; + protected ServiceConsumer serviceConsumer; + protected RpcClient rpcClient; + protected MetaServerService metaServerService; + private HostAndPort currentServiceInfo; + protected Configuration configuration; + + public BaseClient() {} + + public BaseClient(Configuration configuration) { + this.configuration = configuration; + ServiceBuilder serviceBuilder = + ServiceBuilderFactory.build( + configuration.getString(ExecutionConfigKeys.SERVICE_DISCOVERY_TYPE)); + serviceConsumer = serviceBuilder.buildConsumer(configuration); + serviceConsumer.register(this); + buildServerConnect(true); + } + + protected synchronized void buildServerConnect(boolean force) { + boolean exits = serviceConsumer.exists(META_SERVER); + if (!exits) { + if (!force) { + return; + } + throw new IllegalStateException("not find meta server info"); + } + byte[] bytes = serviceConsumer.getDataAndWatch(META_SERVER); + HostAndPort serviceInfo = + (HostAndPort) SerializerFactory.getKryoSerializer().deserialize(bytes); + if (currentServiceInfo != null && currentServiceInfo.equals(serviceInfo)) { + LOGGER.info("service info {} is same, skip update", currentServiceInfo); + return; } - public BaseClient(Configuration configuration) { - this.configuration = configuration; - ServiceBuilder serviceBuilder = ServiceBuilderFactory.build( - configuration.getString(ExecutionConfigKeys.SERVICE_DISCOVERY_TYPE)); - serviceConsumer = serviceBuilder.buildConsumer(configuration); - serviceConsumer.register(this); - buildServerConnect(true); + if (rpcClient != null) { + rpcClient.stop(); } - protected synchronized void buildServerConnect(boolean force) { - boolean exits = serviceConsumer.exists(META_SERVER); - if (!exits) { - if (!force) { - return; - } - throw new IllegalStateException("not find meta server info"); - } - byte[] bytes = serviceConsumer.getDataAndWatch(META_SERVER); - HostAndPort serviceInfo = (HostAndPort) SerializerFactory.getKryoSerializer() - .deserialize(bytes); - if (currentServiceInfo != null && currentServiceInfo.equals(serviceInfo)) { - LOGGER.info("service info {} is same, skip update", currentServiceInfo); - return; - } - - if (rpcClient != null) { - rpcClient.stop(); - } - - LOGGER.info("connect to meta server {}", serviceInfo); - - rpcClient = new RpcClient(new Endpoint(serviceInfo.getHost(), serviceInfo.getPort()), + LOGGER.info("connect to meta server {}", serviceInfo); + + rpcClient = + new RpcClient( + new Endpoint(serviceInfo.getHost(), serviceInfo.getPort()), ConfigurableClientOption.build(configuration)); - metaServerService = BrpcProxy.getProxy(rpcClient, MetaServerService.class); - currentServiceInfo = serviceInfo; - } + metaServerService = BrpcProxy.getProxy(rpcClient, MetaServerService.class); + currentServiceInfo = serviceInfo; + } - protected T process(MetaRequest request) { - return (T) RetryCommand.run( + protected T process(MetaRequest request) { + return (T) + RetryCommand.run( () -> { - ServiceResultPb result = metaServerService.process(RequestPBConverter.convert(request)); - return ResponsePBConverter.convert(result); + ServiceResultPb result = + metaServerService.process(RequestPBConverter.convert(request)); + return ResponsePBConverter.convert(result); }, () -> { - buildServerConnect(false); - return true; + buildServerConnect(false); + return true; }, - MAX_RETRY, RETRY_MS, true); - } - - @Override - public void nodeCreated(String path) { - checkAndUpdateConnector(path); - } - - @Override - public void nodeDeleted(String path) { - checkAndUpdateConnector(path); - } - - @Override - public void nodeDataChanged(String path) { - checkAndUpdateConnector(path); - } - - @Override - public void nodeChildrenChanged(String path) { - checkAndUpdateConnector(path); + MAX_RETRY, + RETRY_MS, + true); + } + + @Override + public void nodeCreated(String path) { + checkAndUpdateConnector(path); + } + + @Override + public void nodeDeleted(String path) { + checkAndUpdateConnector(path); + } + + @Override + public void nodeDataChanged(String path) { + checkAndUpdateConnector(path); + } + + @Override + public void nodeChildrenChanged(String path) { + checkAndUpdateConnector(path); + } + + private void checkAndUpdateConnector(String path) { + if (path.contains(META_SERVER)) { + buildServerConnect(false); } + } - private void checkAndUpdateConnector(String path) { - if (path.contains(META_SERVER)) { - buildServerConnect(false); - } - } - - public void close() { - serviceConsumer.close(); - rpcClient.stop(); - } + public void close() { + serviceConsumer.close(); + rpcClient.stop(); + } } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/client/MetaServerQueryClient.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/client/MetaServerQueryClient.java index 3ca56aa75..e54c29376 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/client/MetaServerQueryClient.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/client/MetaServerQueryClient.java @@ -20,6 +20,7 @@ package org.apache.geaflow.metaserver.client; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.rpc.HostAndPort; @@ -29,31 +30,30 @@ public class MetaServerQueryClient extends BaseClient { - private static MetaServerQueryClient client; - - public MetaServerQueryClient(Configuration configuration) { - super(configuration); - } - - public static synchronized MetaServerQueryClient getClient(Configuration configuration) { - if (client == null) { - client = new MetaServerQueryClient(configuration); - } - return client; - } + private static MetaServerQueryClient client; + public MetaServerQueryClient(Configuration configuration) { + super(configuration); + } - public List queryAllServices(NamespaceType namespaceType) { - ServiceResponse response = process(new QueryAllServiceRequest(namespaceType)); - if (!response.isSuccess()) { - throw new GeaflowRuntimeException(response.getMessage()); - } - return response.getServiceInfos(); + public static synchronized MetaServerQueryClient getClient(Configuration configuration) { + if (client == null) { + client = new MetaServerQueryClient(configuration); } + return client; + } - @Override - public void close() { - super.close(); - client = null; + public List queryAllServices(NamespaceType namespaceType) { + ServiceResponse response = process(new QueryAllServiceRequest(namespaceType)); + if (!response.isSuccess()) { + throw new GeaflowRuntimeException(response.getMessage()); } + return response.getServiceInfos(); + } + + @Override + public void close() { + super.close(); + client = null; + } } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/internal/MetaServerClient.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/internal/MetaServerClient.java index 555150f9d..1c5ce6e72 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/internal/MetaServerClient.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/internal/MetaServerClient.java @@ -29,35 +29,33 @@ public class MetaServerClient extends BaseClient { - private static MetaServerClient client; + private static MetaServerClient client; - public MetaServerClient() { + public MetaServerClient() {} - } + public MetaServerClient(Configuration configuration) { + super(configuration); + } - public MetaServerClient(Configuration configuration) { - super(configuration); + public static synchronized MetaServerClient getClient(Configuration configuration) { + if (client == null) { + client = new MetaServerClient(configuration); } - - public static synchronized MetaServerClient getClient(Configuration configuration) { - if (client == null) { - client = new MetaServerClient(configuration); - } - return client; + return client; + } + + public void registerService( + NamespaceType namespaceType, String containerId, HostAndPort serviceInfo) { + DefaultResponse response = + process(new RegisterServiceRequest(namespaceType, containerId, serviceInfo)); + if (!response.isSuccess()) { + throw new GeaflowRuntimeException(response.getMessage()); } + } - public void registerService(NamespaceType namespaceType, String containerId, - HostAndPort serviceInfo) { - DefaultResponse response = process( - new RegisterServiceRequest(namespaceType, containerId, serviceInfo)); - if (!response.isSuccess()) { - throw new GeaflowRuntimeException(response.getMessage()); - } - } - - @Override - public void close() { - super.close(); - client = null; - } + @Override + public void close() { + super.close(); + client = null; + } } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/MetaRequest.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/MetaRequest.java index 3878930b0..e9344ba10 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/MetaRequest.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/MetaRequest.java @@ -23,7 +23,7 @@ public interface MetaRequest { - NamespaceType namespaceType(); + NamespaceType namespaceType(); - MetaRequestType requestType(); + MetaRequestType requestType(); } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/MetaRequestType.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/MetaRequestType.java index e8fa9c26c..5c71d2de6 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/MetaRequestType.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/MetaRequestType.java @@ -21,13 +21,9 @@ public enum MetaRequestType { - /** - * Register service request. - */ - REGISTER_SERVICE, + /** Register service request. */ + REGISTER_SERVICE, - /** - * Query all service request. - */ - QUERY_ALL_SERVICE; + /** Query all service request. */ + QUERY_ALL_SERVICE; } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/MetaResponse.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/MetaResponse.java index 62041e333..d3c20b85e 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/MetaResponse.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/MetaResponse.java @@ -19,6 +19,4 @@ package org.apache.geaflow.metaserver.model.protocal; -public interface MetaResponse { - -} +public interface MetaResponse {} diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/request/AbstractMetaRequest.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/request/AbstractMetaRequest.java index fb1e2fecb..5ebf55448 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/request/AbstractMetaRequest.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/request/AbstractMetaRequest.java @@ -24,14 +24,14 @@ public abstract class AbstractMetaRequest implements MetaRequest { - private final NamespaceType nameSpaceType; + private final NamespaceType nameSpaceType; - public AbstractMetaRequest(NamespaceType nameSpaceType) { - this.nameSpaceType = nameSpaceType; - } + public AbstractMetaRequest(NamespaceType nameSpaceType) { + this.nameSpaceType = nameSpaceType; + } - @Override - public NamespaceType namespaceType() { - return nameSpaceType; - } + @Override + public NamespaceType namespaceType() { + return nameSpaceType; + } } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/request/QueryAllServiceRequest.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/request/QueryAllServiceRequest.java index 1c447b176..85327c0ab 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/request/QueryAllServiceRequest.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/request/QueryAllServiceRequest.java @@ -24,16 +24,16 @@ public class QueryAllServiceRequest extends AbstractMetaRequest { - public QueryAllServiceRequest(NamespaceType nameSpaceType) { - super(nameSpaceType); - } + public QueryAllServiceRequest(NamespaceType nameSpaceType) { + super(nameSpaceType); + } - public QueryAllServiceRequest() { - super(NamespaceType.DEFAULT); - } + public QueryAllServiceRequest() { + super(NamespaceType.DEFAULT); + } - @Override - public MetaRequestType requestType() { - return MetaRequestType.QUERY_ALL_SERVICE; - } + @Override + public MetaRequestType requestType() { + return MetaRequestType.QUERY_ALL_SERVICE; + } } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/request/RegisterServiceRequest.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/request/RegisterServiceRequest.java index f5f2c05be..74a8f004e 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/request/RegisterServiceRequest.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/request/RegisterServiceRequest.java @@ -25,30 +25,29 @@ public class RegisterServiceRequest extends AbstractMetaRequest { - private final String containerId; - private final HostAndPort info; - - public RegisterServiceRequest(String containerId, HostAndPort info) { - this(NamespaceType.DEFAULT, containerId, info); - } - - public RegisterServiceRequest(NamespaceType namespaceType, String containerId, - HostAndPort info) { - super(namespaceType); - this.containerId = containerId; - this.info = info; - } - - public String getContainerId() { - return containerId; - } - - public HostAndPort getInfo() { - return info; - } - - @Override - public MetaRequestType requestType() { - return MetaRequestType.REGISTER_SERVICE; - } + private final String containerId; + private final HostAndPort info; + + public RegisterServiceRequest(String containerId, HostAndPort info) { + this(NamespaceType.DEFAULT, containerId, info); + } + + public RegisterServiceRequest(NamespaceType namespaceType, String containerId, HostAndPort info) { + super(namespaceType); + this.containerId = containerId; + this.info = info; + } + + public String getContainerId() { + return containerId; + } + + public HostAndPort getInfo() { + return info; + } + + @Override + public MetaRequestType requestType() { + return MetaRequestType.REGISTER_SERVICE; + } } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/request/RequestPBConverter.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/request/RequestPBConverter.java index 387259d9f..627db16a2 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/request/RequestPBConverter.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/request/RequestPBConverter.java @@ -19,21 +19,20 @@ package org.apache.geaflow.metaserver.model.protocal.request; -import com.google.protobuf.ByteString; import org.apache.geaflow.common.encoder.RpcMessageEncoder; import org.apache.geaflow.metaserver.model.protocal.MetaRequest; import org.apache.geaflow.rpc.proto.MetaServer.ServiceRequestPb; -public class RequestPBConverter { - - public static ServiceRequestPb convert(MetaRequest request) { - ByteString bytes = RpcMessageEncoder.encode(request); - return ServiceRequestPb.newBuilder().setRequest(bytes).build(); - } +import com.google.protobuf.ByteString; - public static MetaRequest convert(ServiceRequestPb request) { - return RpcMessageEncoder.decode(request.getRequest()); - } +public class RequestPBConverter { + public static ServiceRequestPb convert(MetaRequest request) { + ByteString bytes = RpcMessageEncoder.encode(request); + return ServiceRequestPb.newBuilder().setRequest(bytes).build(); + } + public static MetaRequest convert(ServiceRequestPb request) { + return RpcMessageEncoder.decode(request.getRequest()); + } } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/response/DefaultResponse.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/response/DefaultResponse.java index 737c29332..286220717 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/response/DefaultResponse.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/response/DefaultResponse.java @@ -23,31 +23,31 @@ public class DefaultResponse implements MetaResponse { - protected boolean success; - protected String message; + protected boolean success; + protected String message; - public DefaultResponse(boolean success) { - this.success = success; - } + public DefaultResponse(boolean success) { + this.success = success; + } - public DefaultResponse(boolean success, String message) { - this.success = success; - this.message = message; - } + public DefaultResponse(boolean success, String message) { + this.success = success; + this.message = message; + } - public boolean isSuccess() { - return success; - } + public boolean isSuccess() { + return success; + } - public void setSuccess(boolean success) { - this.success = success; - } + public void setSuccess(boolean success) { + this.success = success; + } - public String getMessage() { - return message; - } + public String getMessage() { + return message; + } - public void setMessage(String message) { - this.message = message; - } + public void setMessage(String message) { + this.message = message; + } } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/response/ResponsePBConverter.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/response/ResponsePBConverter.java index 007b4855f..a4bf82c92 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/response/ResponsePBConverter.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/response/ResponsePBConverter.java @@ -19,19 +19,20 @@ package org.apache.geaflow.metaserver.model.protocal.response; -import com.google.protobuf.ByteString; import org.apache.geaflow.common.encoder.RpcMessageEncoder; import org.apache.geaflow.metaserver.model.protocal.MetaResponse; import org.apache.geaflow.rpc.proto.MetaServer.ServiceResultPb; +import com.google.protobuf.ByteString; + public class ResponsePBConverter { - public static ServiceResultPb convert(MetaResponse response) { - ByteString bytes = RpcMessageEncoder.encode(response); - return ServiceResultPb.newBuilder().setResult(bytes).build(); - } + public static ServiceResultPb convert(MetaResponse response) { + ByteString bytes = RpcMessageEncoder.encode(response); + return ServiceResultPb.newBuilder().setResult(bytes).build(); + } - public static MetaResponse convert(ServiceResultPb result) { - return RpcMessageEncoder.decode(result.getResult()); - } + public static MetaResponse convert(ServiceResultPb result) { + return RpcMessageEncoder.decode(result.getResult()); + } } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/response/ServiceResponse.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/response/ServiceResponse.java index dde34d1ce..9b106ca8a 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/response/ServiceResponse.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/model/protocal/response/ServiceResponse.java @@ -20,22 +20,23 @@ package org.apache.geaflow.metaserver.model.protocal.response; import java.util.List; + import org.apache.geaflow.common.rpc.HostAndPort; public class ServiceResponse extends DefaultResponse { - private List serviceInfos; + private List serviceInfos; - public ServiceResponse(boolean success, String message) { - super(success, message); - } + public ServiceResponse(boolean success, String message) { + super(success, message); + } - public ServiceResponse(List serviceInfos) { - super(true); - this.serviceInfos = serviceInfos; - } + public ServiceResponse(List serviceInfos) { + super(true); + this.serviceInfos = serviceInfos; + } - public List getServiceInfos() { - return serviceInfos; - } + public List getServiceInfos() { + return serviceInfos; + } } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/service/MetaServerService.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/service/MetaServerService.java index 16b69b76b..6f6c13ce4 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/service/MetaServerService.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/service/MetaServerService.java @@ -24,9 +24,6 @@ public interface MetaServerService { - /** - * Process all messages sent to meta server and return results. - */ - ServiceResultPb process(ServiceRequestPb request); - + /** Process all messages sent to meta server and return results. */ + ServiceResultPb process(ServiceRequestPb request); } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/service/NamespaceType.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/service/NamespaceType.java index 962aad913..6b71c4416 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/service/NamespaceType.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/main/java/org/apache/geaflow/metaserver/service/NamespaceType.java @@ -20,13 +20,9 @@ package org.apache.geaflow.metaserver.service; public enum NamespaceType { - /** - * Default namespace. - */ - DEFAULT, + /** Default namespace. */ + DEFAULT, - /** - * State service namespace. - */ - STATE_SERVICE + /** State service namespace. */ + STATE_SERVICE } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/test/java/org/apache/geaflow/metaserver/model/protocal/MetaRequestTest.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/test/java/org/apache/geaflow/metaserver/model/protocal/MetaRequestTest.java index a7a47eff2..b26296216 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/test/java/org/apache/geaflow/metaserver/model/protocal/MetaRequestTest.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/test/java/org/apache/geaflow/metaserver/model/protocal/MetaRequestTest.java @@ -30,32 +30,30 @@ public class MetaRequestTest { - @Test - public void testQueryAllServiceRequest() { - - QueryAllServiceRequest queryAllServiceRequest = new QueryAllServiceRequest(); - ServiceRequestPb request = RequestPBConverter.convert(queryAllServiceRequest); - QueryAllServiceRequest serviceRequest = (QueryAllServiceRequest) RequestPBConverter.convert( - request); - - Assert.assertSame(serviceRequest.requestType(), MetaRequestType.QUERY_ALL_SERVICE); - } - - @Test - public void testRegisterServiceRequest() { - RegisterServiceRequest registerServiceRequest = - new RegisterServiceRequest(NamespaceType.STATE_SERVICE, "1", - new HostAndPort("127.0.0.1", 1024)); - ServiceRequestPb request = RequestPBConverter.convert(registerServiceRequest); - - RegisterServiceRequest serviceRequest = (RegisterServiceRequest) RequestPBConverter.convert( - request); - - Assert.assertEquals(registerServiceRequest.getContainerId(), - serviceRequest.getContainerId()); - Assert.assertEquals(registerServiceRequest.getInfo(), serviceRequest.getInfo()); - Assert.assertEquals(registerServiceRequest.requestType(), serviceRequest.requestType()); - Assert.assertEquals(registerServiceRequest.namespaceType(), NamespaceType.STATE_SERVICE); - } - + @Test + public void testQueryAllServiceRequest() { + + QueryAllServiceRequest queryAllServiceRequest = new QueryAllServiceRequest(); + ServiceRequestPb request = RequestPBConverter.convert(queryAllServiceRequest); + QueryAllServiceRequest serviceRequest = + (QueryAllServiceRequest) RequestPBConverter.convert(request); + + Assert.assertSame(serviceRequest.requestType(), MetaRequestType.QUERY_ALL_SERVICE); + } + + @Test + public void testRegisterServiceRequest() { + RegisterServiceRequest registerServiceRequest = + new RegisterServiceRequest( + NamespaceType.STATE_SERVICE, "1", new HostAndPort("127.0.0.1", 1024)); + ServiceRequestPb request = RequestPBConverter.convert(registerServiceRequest); + + RegisterServiceRequest serviceRequest = + (RegisterServiceRequest) RequestPBConverter.convert(request); + + Assert.assertEquals(registerServiceRequest.getContainerId(), serviceRequest.getContainerId()); + Assert.assertEquals(registerServiceRequest.getInfo(), serviceRequest.getInfo()); + Assert.assertEquals(registerServiceRequest.requestType(), serviceRequest.requestType()); + Assert.assertEquals(registerServiceRequest.namespaceType(), NamespaceType.STATE_SERVICE); + } } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/test/java/org/apache/geaflow/metaserver/model/protocal/MetaResponseTest.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/test/java/org/apache/geaflow/metaserver/model/protocal/MetaResponseTest.java index f7ec441d8..22e266c53 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/test/java/org/apache/geaflow/metaserver/model/protocal/MetaResponseTest.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-common/src/test/java/org/apache/geaflow/metaserver/model/protocal/MetaResponseTest.java @@ -19,8 +19,8 @@ package org.apache.geaflow.metaserver.model.protocal; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.geaflow.common.rpc.HostAndPort; import org.apache.geaflow.metaserver.model.protocal.response.DefaultResponse; import org.apache.geaflow.metaserver.model.protocal.response.ResponsePBConverter; @@ -29,43 +29,43 @@ import org.testng.Assert; import org.testng.annotations.Test; -public class MetaResponseTest { - - @Test - public void testDefaultResponse() { - DefaultResponse defaultResponse = new DefaultResponse(true); - ServiceResultPb result = ResponsePBConverter.convert(defaultResponse); - DefaultResponse response = (DefaultResponse) ResponsePBConverter.convert(result); +import com.google.common.collect.Lists; - Assert.assertTrue(response.isSuccess()); +public class MetaResponseTest { - defaultResponse = new DefaultResponse(false, "error"); - result = ResponsePBConverter.convert(defaultResponse); - response = (DefaultResponse) ResponsePBConverter.convert(result); + @Test + public void testDefaultResponse() { + DefaultResponse defaultResponse = new DefaultResponse(true); + ServiceResultPb result = ResponsePBConverter.convert(defaultResponse); + DefaultResponse response = (DefaultResponse) ResponsePBConverter.convert(result); - Assert.assertFalse(response.isSuccess()); - Assert.assertEquals(response.getMessage(), "error"); - } + Assert.assertTrue(response.isSuccess()); - @Test - public void testServiceResponse() { - ServiceResponse serviceResponse = new ServiceResponse(false, "error"); - ServiceResultPb result = ResponsePBConverter.convert(serviceResponse); - ServiceResponse response = (ServiceResponse) ResponsePBConverter.convert(result); + defaultResponse = new DefaultResponse(false, "error"); + result = ResponsePBConverter.convert(defaultResponse); + response = (DefaultResponse) ResponsePBConverter.convert(result); - Assert.assertFalse(response.isSuccess()); - Assert.assertEquals(response.getMessage(), "error"); + Assert.assertFalse(response.isSuccess()); + Assert.assertEquals(response.getMessage(), "error"); + } - List HostAndPortList = Lists.newArrayList(new HostAndPort("127.0.0.1", 1024), - new HostAndPort("127.0.0.1", 1025)); + @Test + public void testServiceResponse() { + ServiceResponse serviceResponse = new ServiceResponse(false, "error"); + ServiceResultPb result = ResponsePBConverter.convert(serviceResponse); + ServiceResponse response = (ServiceResponse) ResponsePBConverter.convert(result); - serviceResponse = new ServiceResponse(HostAndPortList); - result = ResponsePBConverter.convert(serviceResponse); - response = (ServiceResponse) ResponsePBConverter.convert(result); + Assert.assertFalse(response.isSuccess()); + Assert.assertEquals(response.getMessage(), "error"); - Assert.assertEquals(serviceResponse.getServiceInfos(), response.getServiceInfos()); - Assert.assertEquals(serviceResponse.isSuccess(), response.isSuccess()); + List HostAndPortList = + Lists.newArrayList(new HostAndPort("127.0.0.1", 1024), new HostAndPort("127.0.0.1", 1025)); - } + serviceResponse = new ServiceResponse(HostAndPortList); + result = ResponsePBConverter.convert(serviceResponse); + response = (ServiceResponse) ResponsePBConverter.convert(result); + Assert.assertEquals(serviceResponse.getServiceInfos(), response.getServiceInfos()); + Assert.assertEquals(serviceResponse.isSuccess(), response.isSuccess()); + } } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/MetaServer.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/MetaServer.java index adb6de9f1..4591a2626 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/MetaServer.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/MetaServer.java @@ -22,8 +22,8 @@ import static org.apache.geaflow.metaserver.Constants.APP_NAME; import static org.apache.geaflow.metaserver.Constants.META_SERVER; -import com.baidu.brpc.server.RpcServer; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.common.rpc.HostAndPort; @@ -46,54 +46,57 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class MetaServer implements MetaServerService { +import com.baidu.brpc.server.RpcServer; - private static final Logger LOGGER = LoggerFactory.getLogger(MetaServer.class); - private static final int MIN_PORT = 50000; - private static final int MAX_PORT = 60000; - private RpcServer rpcServer; - private Map namespaceServiceHandlerMap; - private ServiceProvider serviceProvider; +public class MetaServer implements MetaServerService { - public void init(MetaServerContext context) { - Configuration configuration = context.getConfiguration(); - ServiceBuilder serviceBuilder = ServiceBuilderFactory.build(configuration.getString(ExecutionConfigKeys.SERVICE_DISCOVERY_TYPE)); - serviceProvider = serviceBuilder.buildProvider(configuration); - namespaceServiceHandlerMap = ServiceHandlerFactory.load(context); - startServer(configuration); - } + private static final Logger LOGGER = LoggerFactory.getLogger(MetaServer.class); + private static final int MIN_PORT = 50000; + private static final int MAX_PORT = 60000; + private RpcServer rpcServer; + private Map namespaceServiceHandlerMap; + private ServiceProvider serviceProvider; - private void startServer(Configuration configuration) { - int port = PortUtil.getPort(MIN_PORT, MAX_PORT); - rpcServer = new RpcServer(port); - rpcServer.registerService(new MetaServerServiceProxy(this)); - rpcServer.start(); + public void init(MetaServerContext context) { + Configuration configuration = context.getConfiguration(); + ServiceBuilder serviceBuilder = + ServiceBuilderFactory.build( + configuration.getString(ExecutionConfigKeys.SERVICE_DISCOVERY_TYPE)); + serviceProvider = serviceBuilder.buildProvider(configuration); + namespaceServiceHandlerMap = ServiceHandlerFactory.load(context); + startServer(configuration); + } - HostAndPort info = new HostAndPort(ProcessUtil.getHostIp(), port); + private void startServer(Configuration configuration) { + int port = PortUtil.getPort(MIN_PORT, MAX_PORT); + rpcServer = new RpcServer(port); + rpcServer.registerService(new MetaServerServiceProxy(this)); + rpcServer.start(); - String appName = configuration.getString(ExecutionConfigKeys.JOB_APP_NAME); - serviceProvider.update(APP_NAME, appName.getBytes()); + HostAndPort info = new HostAndPort(ProcessUtil.getHostIp(), port); - if (serviceProvider.exists(META_SERVER)) { - serviceProvider.delete(META_SERVER); - } + String appName = configuration.getString(ExecutionConfigKeys.JOB_APP_NAME); + serviceProvider.update(APP_NAME, appName.getBytes()); - serviceProvider.createAndWatch(META_SERVER, - SerializerFactory.getKryoSerializer().serialize(info)); - LOGGER.info("{} meta server start at {}", appName, info); + if (serviceProvider.exists(META_SERVER)) { + serviceProvider.delete(META_SERVER); } - @Override - public ServiceResultPb process(ServiceRequestPb serviceRequest) { - MetaRequest request = RequestPBConverter.convert(serviceRequest); - NamespaceServiceHandler handler = namespaceServiceHandlerMap.get(request.namespaceType()); - MetaResponse response = handler.process(request); - return ResponsePBConverter.convert(response); - } + serviceProvider.createAndWatch( + META_SERVER, SerializerFactory.getKryoSerializer().serialize(info)); + LOGGER.info("{} meta server start at {}", appName, info); + } - public void close() { - rpcServer.shutdown(); - serviceProvider.close(); - } + @Override + public ServiceResultPb process(ServiceRequestPb serviceRequest) { + MetaRequest request = RequestPBConverter.convert(serviceRequest); + NamespaceServiceHandler handler = namespaceServiceHandlerMap.get(request.namespaceType()); + MetaResponse response = handler.process(request); + return ResponsePBConverter.convert(response); + } + public void close() { + rpcServer.shutdown(); + serviceProvider.close(); + } } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/MetaServerContext.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/MetaServerContext.java index 89ce00223..4f0953f80 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/MetaServerContext.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/MetaServerContext.java @@ -23,26 +23,26 @@ public class MetaServerContext { - private Configuration configuration; - private boolean isRecover; + private Configuration configuration; + private boolean isRecover; - public MetaServerContext(Configuration configuration) { - this.configuration = configuration; - } + public MetaServerContext(Configuration configuration) { + this.configuration = configuration; + } - public Configuration getConfiguration() { - return configuration; - } + public Configuration getConfiguration() { + return configuration; + } - public void setConfiguration(Configuration configuration) { - this.configuration = configuration; - } + public void setConfiguration(Configuration configuration) { + this.configuration = configuration; + } - public boolean isRecover() { - return isRecover; - } + public boolean isRecover() { + return isRecover; + } - public void setRecover(boolean recover) { - isRecover = recover; - } + public void setRecover(boolean recover) { + isRecover = recover; + } } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/MetaServerServiceProxy.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/MetaServerServiceProxy.java index d59330327..be09b0f06 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/MetaServerServiceProxy.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/MetaServerServiceProxy.java @@ -25,14 +25,14 @@ public class MetaServerServiceProxy implements MetaServerService { - private final MetaServer metaServer; + private final MetaServer metaServer; - public MetaServerServiceProxy(MetaServer metaServer) { - this.metaServer = metaServer; - } + public MetaServerServiceProxy(MetaServer metaServer) { + this.metaServer = metaServer; + } - @Override - public ServiceResultPb process(ServiceRequestPb request) { - return metaServer.process(request); - } + @Override + public ServiceResultPb process(ServiceRequestPb request) { + return metaServer.process(request); + } } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/api/NamespaceServiceHandler.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/api/NamespaceServiceHandler.java index 9156333a7..81da4b1be 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/api/NamespaceServiceHandler.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/api/NamespaceServiceHandler.java @@ -26,18 +26,12 @@ public interface NamespaceServiceHandler { - /** - * Init service handler. - */ - void init(MetaServerContext context); + /** Init service handler. */ + void init(MetaServerContext context); - /** - * Process meta request and return response. - */ - MetaResponse process(MetaRequest request); + /** Process meta request and return response. */ + MetaResponse process(MetaRequest request); - /** - * Namespace of handler. - */ - NamespaceType namespaceType(); + /** Namespace of handler. */ + NamespaceType namespaceType(); } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/api/ServiceHandlerFactory.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/api/ServiceHandlerFactory.java index 4b2665d0d..3cdb95067 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/api/ServiceHandlerFactory.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/api/ServiceHandlerFactory.java @@ -19,9 +19,9 @@ package org.apache.geaflow.metaserver.api; -import com.google.common.collect.Maps; import java.util.Map; import java.util.ServiceLoader; + import org.apache.geaflow.common.mode.JobMode; import org.apache.geaflow.metaserver.MetaServerContext; import org.apache.geaflow.metaserver.local.DefaultServiceHandler; @@ -29,27 +29,29 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.collect.Maps; + public class ServiceHandlerFactory { - private static final Logger LOGGER = LoggerFactory.getLogger(ServiceHandlerFactory.class); - - public static synchronized Map load( - MetaServerContext context) { - JobMode jobMode = JobMode.getJobMode(context.getConfiguration()); - - Map map = Maps.newConcurrentMap(); - ServiceLoader serviceLoader = ServiceLoader.load( - NamespaceServiceHandler.class); - for (NamespaceServiceHandler handler : serviceLoader) { - if (jobMode.name().equals(handler.namespaceType().name())) { - LOGGER.info("{} register service handler", handler.namespaceType()); - handler.init(context); - map.put(handler.namespaceType(), handler); - } - } - DefaultServiceHandler defaultServiceHandler = new DefaultServiceHandler(); - defaultServiceHandler.init(context); - map.put(defaultServiceHandler.namespaceType(), defaultServiceHandler); - return map; + private static final Logger LOGGER = LoggerFactory.getLogger(ServiceHandlerFactory.class); + + public static synchronized Map load( + MetaServerContext context) { + JobMode jobMode = JobMode.getJobMode(context.getConfiguration()); + + Map map = Maps.newConcurrentMap(); + ServiceLoader serviceLoader = + ServiceLoader.load(NamespaceServiceHandler.class); + for (NamespaceServiceHandler handler : serviceLoader) { + if (jobMode.name().equals(handler.namespaceType().name())) { + LOGGER.info("{} register service handler", handler.namespaceType()); + handler.init(context); + map.put(handler.namespaceType(), handler); + } } + DefaultServiceHandler defaultServiceHandler = new DefaultServiceHandler(); + defaultServiceHandler.init(context); + map.put(defaultServiceHandler.namespaceType(), defaultServiceHandler); + return map; + } } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/local/DefaultServiceHandler.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/local/DefaultServiceHandler.java index ca23093fa..4eb4947ac 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/local/DefaultServiceHandler.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/main/java/org/apache/geaflow/metaserver/local/DefaultServiceHandler.java @@ -19,9 +19,9 @@ package org.apache.geaflow.metaserver.local; -import com.google.common.collect.Lists; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.common.rpc.HostAndPort; import org.apache.geaflow.metaserver.MetaServerContext; import org.apache.geaflow.metaserver.api.NamespaceServiceHandler; @@ -32,34 +32,38 @@ import org.apache.geaflow.metaserver.model.protocal.response.ServiceResponse; import org.apache.geaflow.metaserver.service.NamespaceType; +import com.google.common.collect.Lists; + public class DefaultServiceHandler implements NamespaceServiceHandler { - private Map serviceInfoMap; + private Map serviceInfoMap; - @Override - public void init(MetaServerContext context) { - serviceInfoMap = new ConcurrentHashMap<>(); - } + @Override + public void init(MetaServerContext context) { + serviceInfoMap = new ConcurrentHashMap<>(); + } - @Override - public MetaResponse process(MetaRequest request) { - switch (request.requestType()) { - case REGISTER_SERVICE: { - RegisterServiceRequest registerServiceRequest = (RegisterServiceRequest) request; - serviceInfoMap.put(registerServiceRequest.getContainerId(), - registerServiceRequest.getInfo()); - return new DefaultResponse(true); - } - case QUERY_ALL_SERVICE: { - return new ServiceResponse(Lists.newArrayList(serviceInfoMap.values())); - } - default: - return new DefaultResponse(false, "not support request " + request.requestType()); + @Override + public MetaResponse process(MetaRequest request) { + switch (request.requestType()) { + case REGISTER_SERVICE: + { + RegisterServiceRequest registerServiceRequest = (RegisterServiceRequest) request; + serviceInfoMap.put( + registerServiceRequest.getContainerId(), registerServiceRequest.getInfo()); + return new DefaultResponse(true); + } + case QUERY_ALL_SERVICE: + { + return new ServiceResponse(Lists.newArrayList(serviceInfoMap.values())); } + default: + return new DefaultResponse(false, "not support request " + request.requestType()); } + } - @Override - public NamespaceType namespaceType() { - return NamespaceType.DEFAULT; - } + @Override + public NamespaceType namespaceType() { + return NamespaceType.DEFAULT; + } } diff --git a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/test/java/org/apache/geaflow/metaserver/ZookeeperMetaServerTest.java b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/test/java/org/apache/geaflow/metaserver/ZookeeperMetaServerTest.java index f505bf76c..e7e4e2981 100644 --- a/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/test/java/org/apache/geaflow/metaserver/ZookeeperMetaServerTest.java +++ b/geaflow/geaflow-core/geaflow-metaserver/geaflow-metaserver-engine/src/test/java/org/apache/geaflow/metaserver/ZookeeperMetaServerTest.java @@ -22,6 +22,7 @@ import java.io.File; import java.io.IOException; import java.util.List; + import org.apache.curator.test.TestingServer; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -35,58 +36,57 @@ import org.testng.annotations.Test; public class ZookeeperMetaServerTest { - private TestingServer zkServer; - private File testDir; - private MetaServer metaServer; - private final Configuration configuration = new Configuration(); + private TestingServer zkServer; + private File testDir; + private MetaServer metaServer; + private final Configuration configuration = new Configuration(); - private void before(Configuration configuration) throws Exception { - String jobName = "test_zookeeper" + System.currentTimeMillis(); - testDir = new File(FileUtil.constitutePath("tmp", "zk", jobName)); - configuration.put("geaflow.zookeeper.znode.parent", File.separator + jobName); - configuration.put("geaflow.zookeeper.quorum.servers", "localhost:2181"); - if (testDir.exists()) { - testDir.delete(); - } - if (!testDir.exists()) { - testDir.mkdir(); - } - zkServer = new TestingServer(2181, testDir); - zkServer.start(); - metaServer = new MetaServer(); - metaServer.init(new MetaServerContext(configuration)); + private void before(Configuration configuration) throws Exception { + String jobName = "test_zookeeper" + System.currentTimeMillis(); + testDir = new File(FileUtil.constitutePath("tmp", "zk", jobName)); + configuration.put("geaflow.zookeeper.znode.parent", File.separator + jobName); + configuration.put("geaflow.zookeeper.quorum.servers", "localhost:2181"); + if (testDir.exists()) { + testDir.delete(); } - - @AfterClass - private void after() throws IOException { - if (metaServer != null) { - metaServer.close(); - } - if (zkServer != null) { - zkServer.stop(); - testDir.delete(); - } + if (!testDir.exists()) { + testDir.mkdir(); } + zkServer = new TestingServer(2181, testDir); + zkServer.start(); + metaServer = new MetaServer(); + metaServer.init(new MetaServerContext(configuration)); + } + @AfterClass + private void after() throws IOException { + if (metaServer != null) { + metaServer.close(); + } + if (zkServer != null) { + zkServer.stop(); + testDir.delete(); + } + } - @Test - public void testZookeeperRegister() throws Exception { - this.configuration.put(ExecutionConfigKeys.SERVICE_DISCOVERY_TYPE, "zookeeper"); - before(this.configuration); - MetaServerClient client = MetaServerClient.getClient(configuration); + @Test + public void testZookeeperRegister() throws Exception { + this.configuration.put(ExecutionConfigKeys.SERVICE_DISCOVERY_TYPE, "zookeeper"); + before(this.configuration); + MetaServerClient client = MetaServerClient.getClient(configuration); - client.registerService(NamespaceType.DEFAULT, "1", new HostAndPort("127.0.0.1", 1000)); - client.registerService(NamespaceType.DEFAULT, "2", new HostAndPort("127.0.0.1", 10242)); - } + client.registerService(NamespaceType.DEFAULT, "1", new HostAndPort("127.0.0.1", 1000)); + client.registerService(NamespaceType.DEFAULT, "2", new HostAndPort("127.0.0.1", 10242)); + } - @Test(dependsOnMethods = "testZookeeperRegister") - public void testZookeeperQueryService() { - this.configuration.put(ExecutionConfigKeys.SERVICE_DISCOVERY_TYPE, "zookeeper"); - MetaServerQueryClient queryClient = MetaServerQueryClient.getClient(configuration); - List serviceInfos = queryClient.queryAllServices(NamespaceType.DEFAULT); + @Test(dependsOnMethods = "testZookeeperRegister") + public void testZookeeperQueryService() { + this.configuration.put(ExecutionConfigKeys.SERVICE_DISCOVERY_TYPE, "zookeeper"); + MetaServerQueryClient queryClient = MetaServerQueryClient.getClient(configuration); + List serviceInfos = queryClient.queryAllServices(NamespaceType.DEFAULT); - Assert.assertEquals(serviceInfos.size(), 2); - Assert.assertTrue(serviceInfos.contains(new HostAndPort("127.0.0.1", 1000))); - Assert.assertTrue(serviceInfos.contains(new HostAndPort("127.0.0.1", 10242))); - } + Assert.assertEquals(serviceInfos.size(), 2); + Assert.assertTrue(serviceInfos.contains(new HostAndPort("127.0.0.1", 1000))); + Assert.assertTrue(serviceInfos.contains(new HostAndPort("127.0.0.1", 10242))); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/collector/chain/IChainCollector.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/collector/chain/IChainCollector.java index 3fb9f77aa..a24182030 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/collector/chain/IChainCollector.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/collector/chain/IChainCollector.java @@ -24,25 +24,21 @@ public interface IChainCollector extends ICollector { - /** - * Process record value. - */ - void process(T value); + /** Process record value. */ + void process(T value); - /** - * Partition value. - */ - default void partition(T value) { - process(value); - } + /** Partition value. */ + default void partition(T value) { + process(value); + } - @Override - default void broadcast(T value) { - throw new GeaflowRuntimeException("chain collector not support broadcast"); - } + @Override + default void broadcast(T value) { + throw new GeaflowRuntimeException("chain collector not support broadcast"); + } - @Override - default void partition(KEY key, T record) { - throw new GeaflowRuntimeException("chain collector not support key partition"); - } + @Override + default void partition(KEY key, T record) { + throw new GeaflowRuntimeException("chain collector not support key partition"); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/collector/chain/OpChainCollector.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/collector/chain/OpChainCollector.java index 4aabf59cd..e4325c3dc 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/collector/chain/OpChainCollector.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/collector/chain/OpChainCollector.java @@ -28,31 +28,32 @@ public class OpChainCollector extends AbstractCollector implements IChainCollector { - protected OneInputOperator operator; - - public OpChainCollector(int id, Operator operator) { - super(id); - this.operator = (OneInputOperator) operator; - } - - @Override - public void process(T value) { - try { - this.operator.processElement(value); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } - } - - @Override - public String getTag() { - return String.format("%s-%s", ((AbstractOperator) operator).getOpArgs().getOpName(), - ((AbstractOperator) operator).getOpArgs().getOpId()); - } - - @Override - public OutputType getType() { - return OutputType.FORWARD; + protected OneInputOperator operator; + + public OpChainCollector(int id, Operator operator) { + super(id); + this.operator = (OneInputOperator) operator; + } + + @Override + public void process(T value) { + try { + this.operator.processElement(value); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } - + } + + @Override + public String getTag() { + return String.format( + "%s-%s", + ((AbstractOperator) operator).getOpArgs().getOpName(), + ((AbstractOperator) operator).getOpArgs().getOpId()); + } + + @Override + public OutputType getType() { + return OutputType.FORWARD; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/context/AbstractRuntimeContext.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/context/AbstractRuntimeContext.java index c7d88a1e1..a3da6b10d 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/context/AbstractRuntimeContext.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/context/AbstractRuntimeContext.java @@ -27,40 +27,38 @@ public abstract class AbstractRuntimeContext implements RuntimeContext { - protected String workPath; - protected Configuration jobConfig; - protected MetricGroup metricGroup; - protected long windowId; + protected String workPath; + protected Configuration jobConfig; + protected MetricGroup metricGroup; + protected long windowId; - public AbstractRuntimeContext(Configuration jobConfig) { - this.jobConfig = jobConfig; - this.workPath = jobConfig.getString(JOB_WORK_PATH); - } + public AbstractRuntimeContext(Configuration jobConfig) { + this.jobConfig = jobConfig; + this.workPath = jobConfig.getString(JOB_WORK_PATH); + } - public AbstractRuntimeContext(Configuration jobConfig, MetricGroup metricGroup, - String workPath) { - this.jobConfig = jobConfig; - this.metricGroup = metricGroup; - this.workPath = workPath; - } + public AbstractRuntimeContext(Configuration jobConfig, MetricGroup metricGroup, String workPath) { + this.jobConfig = jobConfig; + this.metricGroup = metricGroup; + this.workPath = workPath; + } - @Override - public String getWorkPath() { - return this.workPath; - } + @Override + public String getWorkPath() { + return this.workPath; + } - @Override - public MetricGroup getMetric() { - return this.metricGroup; - } + @Override + public MetricGroup getMetric() { + return this.metricGroup; + } - public void updateWindowId(long windowId) { - this.windowId = windowId; + public void updateWindowId(long windowId) { + this.windowId = windowId; + } - } - - @Override - public long getWindowId() { - return windowId; - } + @Override + public long getWindowId() { + return windowId; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/Constants.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/Constants.java index 27e603dba..97a5c264b 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/Constants.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/Constants.java @@ -21,6 +21,5 @@ public class Constants { - public static final long GRAPH_VERSION = 0L; - + public static final long GRAPH_VERSION = 0L; } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/OpArgs.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/OpArgs.java index 364f7a7fb..9fd02af11 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/OpArgs.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/OpArgs.java @@ -25,147 +25,121 @@ public class OpArgs implements Serializable { - private int opId; - private String opName; - private int parallelism; - private boolean windowOp; - private Map config; - private OpType opType; - private ChainStrategy chainStrategy; - // Specify that the operator whether can group together. - private boolean enGroup = true; - - public OpArgs() { - this.config = new HashMap(); - } - - public int getOpId() { - return opId; - } - - public void setOpId(int opId) { - this.opId = opId; - } - - public String getOpName() { - return opName; - } - - public void setOpName(String opName) { - this.opName = opName; - } - - public int getParallelism() { - return parallelism; - } - - public void setParallelism(int parallelism) { - this.parallelism = parallelism; - } - - public Map getConfig() { - return config; - } - - public void setConfig(Map config) { - this.config = config; - } - - public void addConfig(String key, String value) { - if (this.config == null) { - this.config = new HashMap(); - } - this.config.put(key, value); - } - - public void setOpType(OpType type) { - this.opType = type; - } - - public OpType getOpType() { - return opType; - } - - public void setChainStrategy(ChainStrategy chainStrategy) { - this.chainStrategy = chainStrategy; - } - - public ChainStrategy getChainStrategy() { - return chainStrategy; - } - - public boolean isEnGroup() { - return enGroup; - } - - public void setEnGroup(boolean enGroup) { - this.enGroup = enGroup; - } - - public enum ChainStrategy { - - /** - * This policy indicates that an operator cannot have a leading operator - * and can be used as a chain header for other operators. - */ - HEAD, - /** - * This policy does not allow an operator to be linked to a leading operator - * or to be used as a leading operator of another operator. - * This means that the chain can only have one operator. - */ - NEVER, - /** - * Always. - */ - ALWAYS, - } - - public enum OpType { - /** - * Single window source that indicates all window. - */ - SINGLE_WINDOW_SOURCE, - /** - * Multi window source. - */ - MULTI_WINDOW_SOURCE, - /** - * Graph source. - */ - GRAPH_SOURCE, - /** - * Re partition. - */ - RE_PARTITION, - /** - * One input. - */ - ONE_INPUT, - /** - * Two input. - */ - TWO_INPUT, - /** - * Vertex centric compute. - */ - VERTEX_CENTRIC_COMPUTE, - /** - * Vertex centric traversal. - */ - VERTEX_CENTRIC_TRAVERSAL, - /** - * Incremental vertex centric compute. - */ - INC_VERTEX_CENTRIC_COMPUTE, - /** - * Incremental vertex centric traversal. - */ - INC_VERTEX_CENTRIC_TRAVERSAL, - /** - * Vertex centric compute with agg. - */ - VERTEX_CENTRIC_COMPUTE_WITH_AGG, - } - + private int opId; + private String opName; + private int parallelism; + private boolean windowOp; + private Map config; + private OpType opType; + private ChainStrategy chainStrategy; + // Specify that the operator whether can group together. + private boolean enGroup = true; + + public OpArgs() { + this.config = new HashMap(); + } + + public int getOpId() { + return opId; + } + + public void setOpId(int opId) { + this.opId = opId; + } + + public String getOpName() { + return opName; + } + + public void setOpName(String opName) { + this.opName = opName; + } + + public int getParallelism() { + return parallelism; + } + + public void setParallelism(int parallelism) { + this.parallelism = parallelism; + } + + public Map getConfig() { + return config; + } + + public void setConfig(Map config) { + this.config = config; + } + + public void addConfig(String key, String value) { + if (this.config == null) { + this.config = new HashMap(); + } + this.config.put(key, value); + } + + public void setOpType(OpType type) { + this.opType = type; + } + + public OpType getOpType() { + return opType; + } + + public void setChainStrategy(ChainStrategy chainStrategy) { + this.chainStrategy = chainStrategy; + } + + public ChainStrategy getChainStrategy() { + return chainStrategy; + } + + public boolean isEnGroup() { + return enGroup; + } + + public void setEnGroup(boolean enGroup) { + this.enGroup = enGroup; + } + + public enum ChainStrategy { + + /** + * This policy indicates that an operator cannot have a leading operator and can be used as a + * chain header for other operators. + */ + HEAD, + /** + * This policy does not allow an operator to be linked to a leading operator or to be used as a + * leading operator of another operator. This means that the chain can only have one operator. + */ + NEVER, + /** Always. */ + ALWAYS, + } + + public enum OpType { + /** Single window source that indicates all window. */ + SINGLE_WINDOW_SOURCE, + /** Multi window source. */ + MULTI_WINDOW_SOURCE, + /** Graph source. */ + GRAPH_SOURCE, + /** Re partition. */ + RE_PARTITION, + /** One input. */ + ONE_INPUT, + /** Two input. */ + TWO_INPUT, + /** Vertex centric compute. */ + VERTEX_CENTRIC_COMPUTE, + /** Vertex centric traversal. */ + VERTEX_CENTRIC_TRAVERSAL, + /** Incremental vertex centric compute. */ + INC_VERTEX_CENTRIC_COMPUTE, + /** Incremental vertex centric traversal. */ + INC_VERTEX_CENTRIC_TRAVERSAL, + /** Vertex centric compute with agg. */ + VERTEX_CENTRIC_COMPUTE_WITH_AGG, + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/Operator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/Operator.java index c963c2c5d..b9dc4c445 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/Operator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/Operator.java @@ -21,36 +21,27 @@ import java.io.Serializable; import java.util.List; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.collector.ICollector; public interface Operator extends Serializable { - /** - * Operator open. - */ - void open(OpContext opContext); - - /** - * Operator finish. - */ - void finish(); - - /** - * Operator close. - */ - void close(); - - interface OpContext extends Serializable { - - /** - * Returns the collectors. - */ - List getCollectors(); - - /** - * Returns the runtime context. - */ - RuntimeContext getRuntimeContext(); - } + /** Operator open. */ + void open(OpContext opContext); + + /** Operator finish. */ + void finish(); + + /** Operator close. */ + void close(); + + interface OpContext extends Serializable { + + /** Returns the collectors. */ + List getCollectors(); + + /** Returns the runtime context. */ + RuntimeContext getRuntimeContext(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/AbstractOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/AbstractOperator.java index e430a9055..356de5853 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/AbstractOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/AbstractOperator.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.Function; @@ -48,198 +49,208 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class AbstractOperator implements Operator, CancellableTrait { - - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractOperator.class); - - private static final String ANONYMOUS = "Anonymous"; - private static final String EMPTY = ""; - - protected OpArgs opArgs; - protected FUNC function; - - protected List collectors; - protected List subOperatorList; - protected Map outputTags; - protected boolean enableDebug; - - protected OpContext opContext; - protected RuntimeContext runtimeContext; - protected MetricGroup metricGroup; - protected TicToc ticToc; - protected Meter opInputMeter; - protected Meter opOutputMeter; - protected Histogram opRtHistogram; - - public AbstractOperator() { - this.subOperatorList = new ArrayList<>(); - this.outputTags = new HashMap<>(); - this.opArgs = new OpArgs(); - this.enableDebug = false; +public abstract class AbstractOperator + implements Operator, CancellableTrait { + + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractOperator.class); + + private static final String ANONYMOUS = "Anonymous"; + private static final String EMPTY = ""; + + protected OpArgs opArgs; + protected FUNC function; + + protected List collectors; + protected List subOperatorList; + protected Map outputTags; + protected boolean enableDebug; + + protected OpContext opContext; + protected RuntimeContext runtimeContext; + protected MetricGroup metricGroup; + protected TicToc ticToc; + protected Meter opInputMeter; + protected Meter opOutputMeter; + protected Histogram opRtHistogram; + + public AbstractOperator() { + this.subOperatorList = new ArrayList<>(); + this.outputTags = new HashMap<>(); + this.opArgs = new OpArgs(); + this.enableDebug = false; + } + + public AbstractOperator(FUNC function) { + this(); + this.function = function; + } + + @Override + public void open(OpContext opContext) { + this.opContext = opContext; + Map opConfig = opArgs.getConfig(); + this.runtimeContext = opContext.getRuntimeContext().clone(opConfig); + boolean enableDetailMetric = + this.runtimeContext.getConfiguration().getBoolean(ExecutionConfigKeys.ENABLE_DETAIL_METRIC); + this.metricGroup = + enableDetailMetric + ? MetricGroupRegistry.getInstance().getMetricGroup(MetricConstants.MODULE_FRAMEWORK) + : BlackHoleMetricGroup.INSTANCE; + this.opInputMeter = + metricGroup.meter( + MetricNameFormatter.inputTpsMetricName(this.getClass(), this.opArgs.getOpId())); + this.opOutputMeter = + metricGroup.meter( + MetricNameFormatter.outputTpsMetricName(this.getClass(), this.opArgs.getOpId())); + this.opRtHistogram = + metricGroup.histogram( + MetricNameFormatter.rtMetricName(this.getClass(), this.opArgs.getOpId())); + this.ticToc = new TicToc(); + + LOGGER.info("{} open,enableDebug:{}", this.getClass().getSimpleName(), enableDebug); + + this.collectors = new ArrayList<>(); + if (this.function instanceof RichFunction) { + ((RichFunction) function).open(this.runtimeContext); } - public AbstractOperator(FUNC function) { - this(); - this.function = function; + for (Operator subOperator : subOperatorList) { + OpContext subOpContext = + new DefaultOpContext(opContext.getCollectors(), opContext.getRuntimeContext()); + subOperator.open(subOpContext); + IChainCollector chainCollector = new OpChainCollector<>(opArgs.getOpId(), subOperator); + this.collectors.add(chainCollector); } - @Override - public void open(OpContext opContext) { - this.opContext = opContext; - Map opConfig = opArgs.getConfig(); - this.runtimeContext = opContext.getRuntimeContext().clone(opConfig); - boolean enableDetailMetric = this.runtimeContext - .getConfiguration().getBoolean(ExecutionConfigKeys.ENABLE_DETAIL_METRIC); - this.metricGroup = enableDetailMetric - ? MetricGroupRegistry.getInstance().getMetricGroup(MetricConstants.MODULE_FRAMEWORK) - : BlackHoleMetricGroup.INSTANCE; - this.opInputMeter = metricGroup.meter(MetricNameFormatter.inputTpsMetricName(this.getClass(), this.opArgs.getOpId())); - this.opOutputMeter = metricGroup.meter(MetricNameFormatter.outputTpsMetricName(this.getClass(), this.opArgs.getOpId())); - this.opRtHistogram = metricGroup.histogram(MetricNameFormatter.rtMetricName(this.getClass(), this.opArgs.getOpId())); - this.ticToc = new TicToc(); - - LOGGER.info("{} open,enableDebug:{}", this.getClass().getSimpleName(), enableDebug); - - this.collectors = new ArrayList<>(); - if (this.function instanceof RichFunction) { - ((RichFunction) function).open(this.runtimeContext); - } - - for (Operator subOperator : subOperatorList) { - OpContext subOpContext = new DefaultOpContext(opContext.getCollectors(), opContext.getRuntimeContext()); - subOperator.open(subOpContext); - IChainCollector chainCollector = new OpChainCollector<>(opArgs.getOpId(), subOperator); - this.collectors.add(chainCollector); - } - - this.collectors.addAll(opContext.getCollectors().stream().filter(collector -> collector.getId() == opArgs.getOpId()) + this.collectors.addAll( + opContext.getCollectors().stream() + .filter(collector -> collector.getId() == opArgs.getOpId()) .collect(Collectors.toList())); - for (int i = 0, size = this.collectors.size(); i < size; i++) { - ICollector collector = this.collectors.get(i); - collector.setUp(this.runtimeContext); - if (collector instanceof AbstractCollector) { - ((AbstractCollector) collector).setOutputMetric(this.opOutputMeter); - } - } - + for (int i = 0, size = this.collectors.size(); i < size; i++) { + ICollector collector = this.collectors.get(i); + collector.setUp(this.runtimeContext); + if (collector instanceof AbstractCollector) { + ((AbstractCollector) collector).setOutputMetric(this.opOutputMeter); + } } + } - @Override - public void close() { - if (this.function instanceof RichFunction) { - ((RichFunction) function).close(); - } - for (Operator subOperator : subOperatorList) { - subOperator.close(); - } + @Override + public void close() { + if (this.function instanceof RichFunction) { + ((RichFunction) function).close(); } - - @Override - public void finish() { - if (this.function instanceof RichWindowFunction) { - ((RichWindowFunction) function).finish(); - } - for (int i = 0, size = this.collectors.size(); i < size; i++) { - this.collectors.get(i).finish(); - } - for (Operator operator : this.subOperatorList) { - operator.finish(); - } + for (Operator subOperator : subOperatorList) { + subOperator.close(); } + } - @Override - public void cancel() { - if (this.function instanceof CancellableTrait) { - ((CancellableTrait) function).cancel(); - } + @Override + public void finish() { + if (this.function instanceof RichWindowFunction) { + ((RichWindowFunction) function).finish(); } - - public OpArgs getOpArgs() { - return opArgs; + for (int i = 0, size = this.collectors.size(); i < size; i++) { + this.collectors.get(i).finish(); } - - public List getNextOperators() { - return this.subOperatorList; + for (Operator operator : this.subOperatorList) { + operator.finish(); } + } - public void addNextOperator(Operator operator) { - this.subOperatorList.add(operator); + @Override + public void cancel() { + if (this.function instanceof CancellableTrait) { + ((CancellableTrait) function).cancel(); } - - public Map getOutputTags() { - return outputTags; + } + + public OpArgs getOpArgs() { + return opArgs; + } + + public List getNextOperators() { + return this.subOperatorList; + } + + public void addNextOperator(Operator operator) { + this.subOperatorList.add(operator); + } + + public Map getOutputTags() { + return outputTags; + } + + public FUNC getFunction() { + return function; + } + + public void setFunction(FUNC function) { + this.function = function; + } + + public OpContext getOpContext() { + return opContext; + } + + @Override + public String toString() { + return getOperatorString(0); + } + + /** Returns display name of operator. */ + public String getOperatorString(int level) { + StringBuilder str = new StringBuilder(); + for (int i = 0; i < level; i++) { + str.append("\t"); } - - public FUNC getFunction() { - return function; + str.append(getClass().getSimpleName()) + .append("-") + .append(getIdentify()) + .append("-") + .append(getFunctionString()); + for (Operator subOperator : subOperatorList) { + str.append(((AbstractOperator) subOperator).getOperatorString(level + 1)); } - - public void setFunction(FUNC function) { - this.function = function; + return str.toString(); + } + + public String getIdentify() { + if (StringUtils.isNotBlank(opArgs.getOpName())) { + return opArgs.getOpName(); + } else { + return String.valueOf(opArgs.getOpId()); } - - public OpContext getOpContext() { - return opContext; + } + + private String getFunctionString() { + if (function != null) { + if (function.getClass().getSimpleName().length() == 0) { + return ANONYMOUS; + } + return function.getClass().getSimpleName(); } + return EMPTY; + } - @Override - public String toString() { - return getOperatorString(0); - } + public static class DefaultOpContext implements OpContext { - /** - * Returns display name of operator. - */ - public String getOperatorString(int level) { - StringBuilder str = new StringBuilder(); - for (int i = 0; i < level; i++) { - str.append("\t"); - } - str.append(getClass().getSimpleName()).append("-").append(getIdentify()).append("-") - .append(getFunctionString()); - for (Operator subOperator : subOperatorList) { - str.append(((AbstractOperator) subOperator).getOperatorString(level + 1)); - } - return str.toString(); - } + private final RuntimeContext runtimeContext; + private final List collectors; - public String getIdentify() { - if (StringUtils.isNotBlank(opArgs.getOpName())) { - return opArgs.getOpName(); - } else { - return String.valueOf(opArgs.getOpId()); - } + public DefaultOpContext(List collectors, RuntimeContext runtimeContext) { + this.runtimeContext = runtimeContext; + this.collectors = collectors; } - private String getFunctionString() { - if (function != null) { - if (function.getClass().getSimpleName().length() == 0) { - return ANONYMOUS; - } - return function.getClass().getSimpleName(); - } - return EMPTY; + @Override + public List getCollectors() { + return this.collectors; } - public static class DefaultOpContext implements OpContext { - - private final RuntimeContext runtimeContext; - private final List collectors; - - public DefaultOpContext(List collectors, RuntimeContext runtimeContext) { - this.runtimeContext = runtimeContext; - this.collectors = collectors; - } - - - @Override - public List getCollectors() { - return this.collectors; - } - - @Override - public RuntimeContext getRuntimeContext() { - return this.runtimeContext; - } + @Override + public RuntimeContext getRuntimeContext() { + return this.runtimeContext; } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/io/SourceOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/io/SourceOperator.java index e9bfec926..a5be37f12 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/io/SourceOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/io/SourceOperator.java @@ -23,9 +23,6 @@ public interface SourceOperator extends Operator { - /** - * Fetch data from source function with windowId. - */ - R emit(long windowId) throws Exception; - + /** Fetch data from source function with windowId. */ + R emit(long windowId) throws Exception; } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/window/AbstractOneInputOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/window/AbstractOneInputOperator.java index edaa768e2..fddb12af0 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/window/AbstractOneInputOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/window/AbstractOneInputOperator.java @@ -24,29 +24,28 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class AbstractOneInputOperator extends - AbstractStreamOperator implements OneInputOperator { - - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractOneInputOperator.class); - - public AbstractOneInputOperator() { - super(); - opArgs.setOpType(OpArgs.OpType.ONE_INPUT); - } - - public AbstractOneInputOperator(FUNC func) { - super(func); - opArgs.setOpType(OpArgs.OpType.ONE_INPUT); - } - - @Override - public void processElement(T value) throws Exception { - this.ticToc.ticNano(); - process(value); - this.opRtHistogram.update(this.ticToc.tocNano() / 1000); - this.opInputMeter.mark(); - } - - protected abstract void process(T value) throws Exception; - +public abstract class AbstractOneInputOperator + extends AbstractStreamOperator implements OneInputOperator { + + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractOneInputOperator.class); + + public AbstractOneInputOperator() { + super(); + opArgs.setOpType(OpArgs.OpType.ONE_INPUT); + } + + public AbstractOneInputOperator(FUNC func) { + super(func); + opArgs.setOpType(OpArgs.OpType.ONE_INPUT); + } + + @Override + public void processElement(T value) throws Exception { + this.ticToc.ticNano(); + process(value); + this.opRtHistogram.update(this.ticToc.tocNano() / 1000); + this.opInputMeter.mark(); + } + + protected abstract void process(T value) throws Exception; } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/window/AbstractStreamOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/window/AbstractStreamOperator.java index c269259c2..5c500a019 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/window/AbstractStreamOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/window/AbstractStreamOperator.java @@ -27,36 +27,35 @@ public abstract class AbstractStreamOperator extends AbstractOperator { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractStreamOperator.class); + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractStreamOperator.class); - public AbstractStreamOperator() { - super(); - } + public AbstractStreamOperator() { + super(); + } - public AbstractStreamOperator(FUNC func) { - super(func); - } + public AbstractStreamOperator(FUNC func) { + super(func); + } - @Override - public void open(OpContext opContext) { - super.open(opContext); - } + @Override + public void open(OpContext opContext) { + super.open(opContext); + } - protected final void collectValue(T value) { - if (value == null) { - return; - } - for (int i = 0, size = collectors.size(); i < size; i++) { - ICollector collector = this.collectors.get(i); - collector.partition(value); - } + protected final void collectValue(T value) { + if (value == null) { + return; } - - protected final void collectKValue(KEY key, VALUE value) { - for (int i = 0, size = collectors.size(); i < size; i++) { - ICollector collector = this.collectors.get(i); - collector.partition(key, value); - } + for (int i = 0, size = collectors.size(); i < size; i++) { + ICollector collector = this.collectors.get(i); + collector.partition(value); } + } + protected final void collectKValue(KEY key, VALUE value) { + for (int i = 0, size = collectors.size(); i < size; i++) { + ICollector collector = this.collectors.get(i); + collector.partition(key, value); + } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/window/AbstractTwoInputOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/window/AbstractTwoInputOperator.java index b8d43304f..dcc5a2f4d 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/window/AbstractTwoInputOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/window/AbstractTwoInputOperator.java @@ -23,37 +23,36 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class AbstractTwoInputOperator extends - AbstractStreamOperator implements TwoInputOperator { - - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractTwoInputOperator.class); - - public AbstractTwoInputOperator() { - super(); - } - - public AbstractTwoInputOperator(FUNC func) { - super(func); - } - - @Override - public void processElementOne(T value) throws Exception { - this.ticToc.ticNano(); - processRecordOne(value); - this.opRtHistogram.update(this.ticToc.tocNano() / 1000); - this.opInputMeter.mark(); - } - - @Override - public void processElementTwo(U value) throws Exception { - this.ticToc.ticNano(); - processRecordTwo(value); - this.opRtHistogram.update(this.ticToc.tocNano() / 1000); - this.opInputMeter.mark(); - } - - protected abstract void processRecordOne(T value) throws Exception; - - protected abstract void processRecordTwo(U value) throws Exception; - +public abstract class AbstractTwoInputOperator + extends AbstractStreamOperator implements TwoInputOperator { + + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractTwoInputOperator.class); + + public AbstractTwoInputOperator() { + super(); + } + + public AbstractTwoInputOperator(FUNC func) { + super(func); + } + + @Override + public void processElementOne(T value) throws Exception { + this.ticToc.ticNano(); + processRecordOne(value); + this.opRtHistogram.update(this.ticToc.tocNano() / 1000); + this.opInputMeter.mark(); + } + + @Override + public void processElementTwo(U value) throws Exception { + this.ticToc.ticNano(); + processRecordTwo(value); + this.opRtHistogram.update(this.ticToc.tocNano() / 1000); + this.opInputMeter.mark(); + } + + protected abstract void processRecordOne(T value) throws Exception; + + protected abstract void processRecordTwo(U value) throws Exception; } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/window/OneInputOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/window/OneInputOperator.java index e5c4bbc32..85fdb3340 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/window/OneInputOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/window/OneInputOperator.java @@ -23,9 +23,6 @@ public interface OneInputOperator extends Operator { - /** - * Process element value. - */ - void processElement(T value) throws Exception; - + /** Process element value. */ + void processElement(T value) throws Exception; } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/window/TwoInputOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/window/TwoInputOperator.java index 5fa9096ae..f9080cd1e 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/window/TwoInputOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/base/window/TwoInputOperator.java @@ -23,14 +23,9 @@ public interface TwoInputOperator extends Operator { - /** - * Process first element value. - */ - void processElementOne(T value) throws Exception; - - /** - * Process second element value. - */ - void processElementTwo(U value) throws Exception; + /** Process first element value. */ + void processElementOne(T value) throws Exception; + /** Process second element value. */ + void processElementTwo(U value) throws Exception; } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/AbstractGraphVertexCentricOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/AbstractGraphVertexCentricOp.java index cbb8c663f..3b2dbe174 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/AbstractGraphVertexCentricOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/AbstractGraphVertexCentricOp.java @@ -19,10 +19,10 @@ package org.apache.geaflow.operator.impl.graph.algo.vc; -import com.google.common.base.Preconditions; import java.io.IOException; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.api.function.iterator.RichIteratorFunction; import org.apache.geaflow.api.graph.base.algo.VertexCentricAlgo; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -57,208 +57,241 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class AbstractGraphVertexCentricOp> extends - AbstractOperator implements IGraphVertexCentricOp, IteratorOperator { +import com.google.common.base.Preconditions; - private static final Logger LOGGER = LoggerFactory.getLogger( - AbstractGraphVertexCentricOp.class); +public abstract class AbstractGraphVertexCentricOp< + K, VV, EV, M, FUNC extends VertexCentricAlgo> + extends AbstractOperator + implements IGraphVertexCentricOp, IteratorOperator { - protected int taskId; + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractGraphVertexCentricOp.class); - protected VertexCentricCombineFunction msgCombineFunction; - protected IGraphVCPartition graphVCPartition; - protected final GraphViewDesc graphViewDesc; + protected int taskId; - protected long maxIterations; - protected long iterations; - protected long windowId; + protected VertexCentricCombineFunction msgCombineFunction; + protected IGraphVCPartition graphVCPartition; + protected final GraphViewDesc graphViewDesc; - protected KeyGroup keyGroup; - protected KeyGroup taskKeyGroup; - protected GraphState graphState; - protected IGraphMsgBox graphMsgBox; - protected boolean shareEnable; + protected long maxIterations; + protected long iterations; + protected long windowId; - protected Map collectorMap; - protected ICollector> messageCollector; - protected Meter msgMeter; + protected KeyGroup keyGroup; + protected KeyGroup taskKeyGroup; + protected GraphState graphState; + protected IGraphMsgBox graphMsgBox; + protected boolean shareEnable; - public AbstractGraphVertexCentricOp(GraphViewDesc graphViewDesc, FUNC func) { - super(func); - this.graphViewDesc = graphViewDesc; - this.maxIterations = func.getMaxIterationCount(); - opArgs.setChainStrategy(OpArgs.ChainStrategy.NEVER); - } + protected Map collectorMap; + protected ICollector> messageCollector; + protected Meter msgMeter; + + public AbstractGraphVertexCentricOp(GraphViewDesc graphViewDesc, FUNC func) { + super(func); + this.graphViewDesc = graphViewDesc; + this.maxIterations = func.getMaxIterationCount(); + opArgs.setChainStrategy(OpArgs.ChainStrategy.NEVER); + } - @Override - public void open(OpContext opContext) { - super.open(opContext); + @Override + public void open(OpContext opContext) { + super.open(opContext); - this.msgCombineFunction = function.getCombineFunction(); - this.graphVCPartition = function.getGraphPartition(); - this.windowId = runtimeContext.getWindowId(); - this.msgMeter = this.metricGroup.meter( + this.msgCombineFunction = function.getCombineFunction(); + this.graphVCPartition = function.getGraphPartition(); + this.windowId = runtimeContext.getWindowId(); + this.msgMeter = + this.metricGroup.meter( MetricNameFormatter.iterationMsgMetricName(this.getClass(), this.opArgs.getOpId())); - shareEnable = runtimeContext.getConfiguration().getBoolean(FrameworkConfigKeys.SERVICE_SHARE_ENABLE); - - GraphStateDescriptor desc = buildGraphStateDesc(opArgs.getOpName()); - desc.withMetricGroup(runtimeContext.getMetric()); - this.graphState = StateFactory.buildGraphState(desc, runtimeContext.getConfiguration()); - LOGGER.info("ThreadId {}, open graphState", Thread.currentThread().getId()); - if (!shareEnable) { - this.taskKeyGroup = keyGroup; - LOGGER.info("recovery graph state {}", graphState); - recover(); - } else { - load(); - LOGGER.info("processIndex {} taskIndex {} load shard {}, load graph state {}", - runtimeContext.getTaskArgs().getProcessIndex(), runtimeContext.getTaskArgs().getTaskIndex(), keyGroup, graphState); - } - - collectorMap = new HashMap<>(); - for (ICollector collector : this.collectors) { - collectorMap.put(collector.getTag(), collector); - } - this.messageCollector = collectorMap.get(GraphRecordNames.Message.name()); - if (this.messageCollector instanceof AbstractCollector) { - ((AbstractCollector) this.messageCollector).setOutputMetric(this.msgMeter); - } - this.graphMsgBox = GraphMsgBoxFactory.buildMessageBox(this.messageCollector, this.msgCombineFunction); - } + shareEnable = + runtimeContext.getConfiguration().getBoolean(FrameworkConfigKeys.SERVICE_SHARE_ENABLE); - protected GraphStateDescriptor buildGraphStateDesc(String name) { - this.taskId = runtimeContext.getTaskArgs().getTaskId(); - - int containerNum = runtimeContext.getConfiguration().getInteger(ExecutionConfigKeys.CONTAINER_NUM); - int processIndex = runtimeContext.getTaskArgs().getProcessIndex(); - int taskIndex = shareEnable ? processIndex : runtimeContext.getTaskArgs().getTaskIndex(); - int taskPara = shareEnable ? containerNum : runtimeContext.getTaskArgs().getParallelism(); - BackendType backendType = graphViewDesc.getBackend(); - GraphStateDescriptor desc = GraphStateDescriptor.build(graphViewDesc.getName() - , backendType.name()); - - int maxPara = graphViewDesc.getShardNum(); - Preconditions.checkArgument(taskPara <= maxPara, - String.format("task parallelism '%s' must be <= shard num(max parallelism) '%s'", - taskPara, maxPara)); - - keyGroup = KeyGroupAssignment.computeKeyGroupRangeForOperatorIndex(maxPara, taskPara, taskIndex); - IKeyGroupAssigner keyGroupAssigner = - KeyGroupAssignerFactory.createKeyGroupAssigner(keyGroup, taskIndex, maxPara); - desc.withKeyGroup(keyGroup); - desc.withKeyGroupAssigner(keyGroupAssigner); - if (shareEnable) { - LOGGER.info("enable state singleton"); - desc.withSingleton(); - desc.withStateMode(StateMode.RDONLY); - } - LOGGER.info("opName:{} taskId:{} taskIndex:{} keyGroup:{} containerNum:{} processIndex: {} real taskIndex:{}", this.opArgs.getOpName(), - taskId, - taskIndex, - desc.getKeyGroup(), containerNum, processIndex, runtimeContext.getTaskArgs().getTaskIndex()); - return desc; + GraphStateDescriptor desc = buildGraphStateDesc(opArgs.getOpName()); + desc.withMetricGroup(runtimeContext.getMetric()); + this.graphState = StateFactory.buildGraphState(desc, runtimeContext.getConfiguration()); + LOGGER.info("ThreadId {}, open graphState", Thread.currentThread().getId()); + if (!shareEnable) { + this.taskKeyGroup = keyGroup; + LOGGER.info("recovery graph state {}", graphState); + recover(); + } else { + load(); + LOGGER.info( + "processIndex {} taskIndex {} load shard {}, load graph state {}", + runtimeContext.getTaskArgs().getProcessIndex(), + runtimeContext.getTaskArgs().getTaskIndex(), + keyGroup, + graphState); } - @Override - public void processMessage(IGraphMessage graphMessage) { - if (enableDebug) { - LOGGER.info("taskId:{} windowId:{} Iteration:{} add message:{}", taskId, windowId, - iterations, - graphMessage); - } - K vertexId = graphMessage.getTargetVId(); - while (graphMessage.hasNext()) { - this.graphMsgBox.addInMessages(vertexId, graphMessage.next()); - } - this.opInputMeter.mark(); + collectorMap = new HashMap<>(); + for (ICollector collector : this.collectors) { + collectorMap.put(collector.getTag(), collector); + } + this.messageCollector = collectorMap.get(GraphRecordNames.Message.name()); + if (this.messageCollector instanceof AbstractCollector) { + ((AbstractCollector) this.messageCollector).setOutputMetric(this.msgMeter); } + this.graphMsgBox = + GraphMsgBoxFactory.buildMessageBox(this.messageCollector, this.msgCombineFunction); + } + + protected GraphStateDescriptor buildGraphStateDesc(String name) { + this.taskId = runtimeContext.getTaskArgs().getTaskId(); + + int containerNum = + runtimeContext.getConfiguration().getInteger(ExecutionConfigKeys.CONTAINER_NUM); + int processIndex = runtimeContext.getTaskArgs().getProcessIndex(); + int taskIndex = shareEnable ? processIndex : runtimeContext.getTaskArgs().getTaskIndex(); + int taskPara = shareEnable ? containerNum : runtimeContext.getTaskArgs().getParallelism(); + BackendType backendType = graphViewDesc.getBackend(); + GraphStateDescriptor desc = + GraphStateDescriptor.build(graphViewDesc.getName(), backendType.name()); - @Override - public void initIteration(long iterations) { - this.iterations = iterations; - this.windowId = opContext.getRuntimeContext().getWindowId(); - ((AbstractRuntimeContext) this.runtimeContext).updateWindowId(windowId); - if (enableDebug) { - LOGGER.info("taskId:{} windowId:{} init Iteration:{}", taskId, windowId, iterations); - } - this.iterations = iterations; - if (function instanceof RichIteratorFunction) { - ((RichIteratorFunction) function).initIteration(iterations); - } + int maxPara = graphViewDesc.getShardNum(); + Preconditions.checkArgument( + taskPara <= maxPara, + String.format( + "task parallelism '%s' must be <= shard num(max parallelism) '%s'", taskPara, maxPara)); + + keyGroup = + KeyGroupAssignment.computeKeyGroupRangeForOperatorIndex(maxPara, taskPara, taskIndex); + IKeyGroupAssigner keyGroupAssigner = + KeyGroupAssignerFactory.createKeyGroupAssigner(keyGroup, taskIndex, maxPara); + desc.withKeyGroup(keyGroup); + desc.withKeyGroupAssigner(keyGroupAssigner); + if (shareEnable) { + LOGGER.info("enable state singleton"); + desc.withSingleton(); + desc.withStateMode(StateMode.RDONLY); } + LOGGER.info( + "opName:{} taskId:{} taskIndex:{} keyGroup:{} containerNum:{} processIndex: {} real" + + " taskIndex:{}", + this.opArgs.getOpName(), + taskId, + taskIndex, + desc.getKeyGroup(), + containerNum, + processIndex, + runtimeContext.getTaskArgs().getTaskIndex()); + return desc; + } - @Override - public long getMaxIterationCount() { - return this.maxIterations; + @Override + public void processMessage(IGraphMessage graphMessage) { + if (enableDebug) { + LOGGER.info( + "taskId:{} windowId:{} Iteration:{} add message:{}", + taskId, + windowId, + iterations, + graphMessage); + } + K vertexId = graphMessage.getTargetVId(); + while (graphMessage.hasNext()) { + this.graphMsgBox.addInMessages(vertexId, graphMessage.next()); } + this.opInputMeter.mark(); + } - @Override - public void finishIteration(long iteration) { - this.ticToc.tic(); - this.doFinishIteration(iteration); - this.metricGroup.histogram( - MetricNameFormatter.iterationFinishMetricName(this.getClass(), this.opArgs.getOpId(), iteration) - ).update(this.ticToc.toc()); + @Override + public void initIteration(long iterations) { + this.iterations = iterations; + this.windowId = opContext.getRuntimeContext().getWindowId(); + ((AbstractRuntimeContext) this.runtimeContext).updateWindowId(windowId); + if (enableDebug) { + LOGGER.info("taskId:{} windowId:{} init Iteration:{}", taskId, windowId, iterations); } + this.iterations = iterations; + if (function instanceof RichIteratorFunction) { + ((RichIteratorFunction) function).initIteration(iterations); + } + } - public abstract void doFinishIteration(long iteration); + @Override + public long getMaxIterationCount() { + return this.maxIterations; + } - @Override - public void close() { - this.graphMsgBox.clearInBox(); - this.graphMsgBox.clearOutBox(); - if (!shareEnable) { - this.graphState.manage().operate().close(); - } - } + @Override + public void finishIteration(long iteration) { + this.ticToc.tic(); + this.doFinishIteration(iteration); + this.metricGroup + .histogram( + MetricNameFormatter.iterationFinishMetricName( + this.getClass(), this.opArgs.getOpId(), iteration)) + .update(this.ticToc.toc()); + } - protected void recover() { - LOGGER.info("opName: {} will do recover, windowId: {}", this.opArgs.getOpName(), this.windowId); - long lastCheckPointId = getLatestViewVersion(); - if (lastCheckPointId >= 0) { - LOGGER.info("opName: {} do recover to state VersionId: {}", this.opArgs.getOpName(), - lastCheckPointId); - graphState.manage().operate().setCheckpointId(lastCheckPointId); - graphState.manage().operate().recover(); - } + public abstract void doFinishIteration(long iteration); + + @Override + public void close() { + this.graphMsgBox.clearInBox(); + this.graphMsgBox.clearOutBox(); + if (!shareEnable) { + this.graphState.manage().operate().close(); } + } - public String getGraphViewName() { - return graphViewDesc.getName(); + protected void recover() { + LOGGER.info("opName: {} will do recover, windowId: {}", this.opArgs.getOpName(), this.windowId); + long lastCheckPointId = getLatestViewVersion(); + if (lastCheckPointId >= 0) { + LOGGER.info( + "opName: {} do recover to state VersionId: {}", + this.opArgs.getOpName(), + lastCheckPointId); + graphState.manage().operate().setCheckpointId(lastCheckPointId); + graphState.manage().operate().recover(); } + } + + public String getGraphViewName() { + return graphViewDesc.getName(); + } - protected void load() { - LOGGER.info("opName: {} will do load, windowId: {}", this.opArgs.getOpName(), this.windowId); - long lastCheckPointId = getLatestViewVersion(); - long checkPointId = lastCheckPointId < 0 ? 0 : lastCheckPointId; - LOGGER.info("opName: {} do load, ViewMetaBookKeeper version: {}, checkPointId {}", - this.opArgs.getOpName(), lastCheckPointId, checkPointId); + protected void load() { + LOGGER.info("opName: {} will do load, windowId: {}", this.opArgs.getOpName(), this.windowId); + long lastCheckPointId = getLatestViewVersion(); + long checkPointId = lastCheckPointId < 0 ? 0 : lastCheckPointId; + LOGGER.info( + "opName: {} do load, ViewMetaBookKeeper version: {}, checkPointId {}", + this.opArgs.getOpName(), + lastCheckPointId, + checkPointId); - LoadOption loadOption = LoadOption.of(); - this.taskKeyGroup = KeyGroupAssignment.computeKeyGroupRangeForOperatorIndex( + LoadOption loadOption = LoadOption.of(); + this.taskKeyGroup = + KeyGroupAssignment.computeKeyGroupRangeForOperatorIndex( graphViewDesc.getShardNum(), runtimeContext.getTaskArgs().getParallelism(), runtimeContext.getTaskArgs().getTaskIndex()); - loadOption.withKeyGroup(this.taskKeyGroup); - loadOption.withCheckpointId(checkPointId); - graphState.manage().operate().load(loadOption); - LOGGER.info("opName: {} task key group {} do load successfully", this.opArgs.getOpName(), this.taskKeyGroup); - } + loadOption.withKeyGroup(this.taskKeyGroup); + loadOption.withCheckpointId(checkPointId); + graphState.manage().operate().load(loadOption); + LOGGER.info( + "opName: {} task key group {} do load successfully", + this.opArgs.getOpName(), + this.taskKeyGroup); + } - private long getLatestViewVersion() { - long lastCheckPointId; - try { - ViewMetaBookKeeper keeper = new ViewMetaBookKeeper(graphViewDesc.getName(), - this.runtimeContext.getConfiguration()); - lastCheckPointId = keeper.getLatestViewVersion(graphViewDesc.getName()); - LOGGER.info("opName: {} will do recover or load, ViewMetaBookKeeper version: {}", - this.opArgs.getOpName(), lastCheckPointId); - } catch (IOException e) { - throw new GeaflowRuntimeException(e); - } - return lastCheckPointId; + private long getLatestViewVersion() { + long lastCheckPointId; + try { + ViewMetaBookKeeper keeper = + new ViewMetaBookKeeper(graphViewDesc.getName(), this.runtimeContext.getConfiguration()); + lastCheckPointId = keeper.getLatestViewVersion(graphViewDesc.getName()); + LOGGER.info( + "opName: {} will do recover or load, ViewMetaBookKeeper version: {}", + this.opArgs.getOpName(), + lastCheckPointId); + } catch (IOException e) { + throw new GeaflowRuntimeException(e); } - + return lastCheckPointId; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/GraphVertexCentricOpAggregator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/GraphVertexCentricOpAggregator.java index dfcd11784..5df17858c 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/GraphVertexCentricOpAggregator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/GraphVertexCentricOpAggregator.java @@ -40,102 +40,120 @@ /** * Do graph aggregate following by the graph vertex centric operator. * - * @param The id type of vertex/edge. - * @param The value type of vertex. - * @param The value type of edge. - * @param The message type during iterations. - * @param The type of aggregate input iterm. - * @param The type of partial aggregator. - * @param The type of partial aggregate result. - * @param The type of global aggregator. - * @param The type of global aggregate result. + * @param The id type of vertex/edge. + * @param The value type of vertex. + * @param The value type of edge. + * @param The message type during iterations. + * @param The type of aggregate input iterm. + * @param The type of partial aggregator. + * @param The type of partial aggregate result. + * @param The type of global aggregator. + * @param The type of global aggregate result. * @param The type of algo function in operator. */ -public class GraphVertexCentricOpAggregator & GraphAggregationAlgo> { - private static final Logger LOGGER = LoggerFactory.getLogger(GraphVertexCentricOpAggregator.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(GraphVertexCentricOpAggregator.class); - private AbstractGraphVertexCentricOp operator; - private long iteration; + private AbstractGraphVertexCentricOp operator; + private long iteration; - protected VertexCentricAggregateFunction.IPartialGraphAggFunction partialGraphAggFunction; - protected PartialAggContextImpl partialAggContextImpl; + protected VertexCentricAggregateFunction.IPartialGraphAggFunction + partialGraphAggFunction; + protected PartialAggContextImpl partialAggContextImpl; - protected PA partialAgg; - protected PR partialResult; - protected GR globalResult; + protected PA partialAgg; + protected PR partialResult; + protected GR globalResult; - private ICollector aggregateCollector; + private ICollector aggregateCollector; - public GraphVertexCentricOpAggregator(AbstractGraphVertexCentricOp operator) { - this.operator = operator; - } + public GraphVertexCentricOpAggregator(AbstractGraphVertexCentricOp operator) { + this.operator = operator; + } - public void open(VertexCentricAggContextFunction aggFunction) { - // Partial agg function. - this.partialGraphAggFunction = operator.getFunction().getAggregateFunction().getPartialAggregation(); - this.partialAggContextImpl = new PartialAggContextImpl(); - this.partialAgg = this.partialGraphAggFunction.create(this.partialAggContextImpl); + public void open(VertexCentricAggContextFunction aggFunction) { + // Partial agg function. + this.partialGraphAggFunction = + operator.getFunction().getAggregateFunction().getPartialAggregation(); + this.partialAggContextImpl = new PartialAggContextImpl(); + this.partialAgg = this.partialGraphAggFunction.create(this.partialAggContextImpl); - VertexCentricAggContextImpl aggContext = new VertexCentricAggContextImpl(); - aggFunction.initContext(aggContext); + VertexCentricAggContextImpl aggContext = new VertexCentricAggContextImpl(); + aggFunction.initContext(aggContext); - boolean enableDetailMetric = Configuration.getBoolean(ExecutionConfigKeys.ENABLE_DETAIL_METRIC, - this.operator.getOpArgs().getConfig()); - MetricGroup metricGroup = enableDetailMetric + boolean enableDetailMetric = + Configuration.getBoolean( + ExecutionConfigKeys.ENABLE_DETAIL_METRIC, this.operator.getOpArgs().getConfig()); + MetricGroup metricGroup = + enableDetailMetric ? MetricGroupRegistry.getInstance().getMetricGroup(MetricConstants.MODULE_FRAMEWORK) : BlackHoleMetricGroup.INSTANCE; - Meter aggMeter = metricGroup.meter( - MetricNameFormatter.iterationAggMetricName(this.getClass(), operator.getOpArgs().getOpId())); - - this.aggregateCollector = operator.collectorMap.get(RecordArgs.GraphRecordNames.Aggregate.name()); - if (this.aggregateCollector instanceof AbstractCollector) { - ((AbstractCollector) this.aggregateCollector).setOutputMetric(aggMeter); - } - } - - public void initIteration(long iteration) { - this.iteration = iteration; + Meter aggMeter = + metricGroup.meter( + MetricNameFormatter.iterationAggMetricName( + this.getClass(), operator.getOpArgs().getOpId())); + + this.aggregateCollector = + operator.collectorMap.get(RecordArgs.GraphRecordNames.Aggregate.name()); + if (this.aggregateCollector instanceof AbstractCollector) { + ((AbstractCollector) this.aggregateCollector).setOutputMetric(aggMeter); } - - public void finishIteration(long iteration) { - if (partialResult != null) { - this.partialGraphAggFunction.finish(this.partialResult); - LOGGER.info("iterationId:{} partial result :{}", iteration, partialResult); - aggregateCollector.finish(); - this.partialResult = null; - } + } + + public void initIteration(long iteration) { + this.iteration = iteration; + } + + public void finishIteration(long iteration) { + if (partialResult != null) { + this.partialGraphAggFunction.finish(this.partialResult); + LOGGER.info("iterationId:{} partial result :{}", iteration, partialResult); + aggregateCollector.finish(); + this.partialResult = null; } + } - public void processAggregateResult(GR result) { - this.globalResult = result; - this.partialAgg = this.partialGraphAggFunction.create(this.partialAggContextImpl); - } + public void processAggregateResult(GR result) { + this.globalResult = result; + this.partialAgg = this.partialGraphAggFunction.create(this.partialAggContextImpl); + } - class VertexCentricAggContextImpl implements VertexCentricAggContextFunction.VertexCentricAggContext { + class VertexCentricAggContextImpl + implements VertexCentricAggContextFunction.VertexCentricAggContext { - @Override - public GR getAggregateResult() { - return globalResult; - } + @Override + public GR getAggregateResult() { + return globalResult; + } - @Override - public void aggregate(I i) { - partialResult = partialGraphAggFunction.aggregate(i, partialAgg); - } + @Override + public void aggregate(I i) { + partialResult = partialGraphAggFunction.aggregate(i, partialAgg); } + } - class PartialAggContextImpl implements VertexCentricAggregateFunction.IPartialAggContext { + class PartialAggContextImpl implements VertexCentricAggregateFunction.IPartialAggContext { - @Override - public long getIteration() { - return iteration; - } + @Override + public long getIteration() { + return iteration; + } - @Override - public void collect(PR result) { - aggregateCollector.partition(result); - } + @Override + public void collect(PR result) { + aggregateCollector.partition(result); } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/GraphVertexCentricOpFactory.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/GraphVertexCentricOpFactory.java index dd549e7f3..d5cc20860 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/GraphVertexCentricOpFactory.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/GraphVertexCentricOpFactory.java @@ -20,6 +20,7 @@ package org.apache.geaflow.operator.impl.graph.algo.vc; import java.util.List; + import org.apache.geaflow.api.graph.compute.IncVertexCentricAggCompute; import org.apache.geaflow.api.graph.compute.IncVertexCentricCompute; import org.apache.geaflow.api.graph.compute.VertexCentricAggCompute; @@ -49,142 +50,169 @@ public class GraphVertexCentricOpFactory { - public static IGraphVertexCentricOp buildStaticGraphVertexCentricComputeOp( - GraphViewDesc graphViewDesc, - VertexCentricCompute vertexCentricCompute) { - return new StaticGraphVertexCentricComputeOp<>(graphViewDesc, vertexCentricCompute); - } - - - public static IGraphVertexCentricAggOp buildStaticGraphVertexCentricAggComputeOp( - GraphViewDesc graphViewDesc, - VertexCentricAggCompute vertexCentricAggCompute) { - return new StaticGraphVertexCentricComputeWithAggOp<>(graphViewDesc, vertexCentricAggCompute); - } - - public static IGraphVertexCentricOp buildStaticGraphVertexCentricTraversalOp( - GraphViewDesc graphViewDesc, VertexCentricTraversal vertexCentricTraversal) { - return new StaticGraphVertexCentricTraversalStartByStreamOp<>(graphViewDesc, - vertexCentricTraversal); - } - - public static IGraphVertexCentricOp buildStaticGraphVertexCentricAggTraversalOp( - GraphViewDesc graphViewDesc, - VertexCentricAggTraversal vertexCentricTraversal) { - return new StaticGraphVertexCentricTraversalStartByStreamWithAggOp<>(graphViewDesc, - vertexCentricTraversal); - } - - public static IGraphVertexCentricOp buildStaticGraphVertexCentricTraversalAllOp( - GraphViewDesc graphViewDesc, VertexCentricTraversal vertexCentricTraversal) { - return new StaticGraphVertexCentricTraversalAllOp<>(graphViewDesc, vertexCentricTraversal); - } - - public static IGraphVertexCentricOp buildStaticGraphVertexCentricAggTraversalAllOp( - GraphViewDesc graphViewDesc, - VertexCentricAggTraversal vertexCentricTraversal) { - return new StaticGraphVertexCentricTraversalAllWithAggOp<>(graphViewDesc, vertexCentricTraversal); - } - - - public static IGraphVertexCentricOp buildStaticGraphVertexCentricTraversalOp( - GraphViewDesc graphViewDesc, VertexCentricTraversal vertexCentricTraversal, - VertexBeginTraversalRequest traversalRequest) { - return new StaticGraphVertexCentricTraversalStartByIdsOp<>(graphViewDesc, traversalRequest, - vertexCentricTraversal); - } - - public static IGraphVertexCentricOp buildStaticGraphVertexCentricAggTraversalOp( - GraphViewDesc graphViewDesc, - VertexCentricAggTraversal vertexCentricTraversal, - VertexBeginTraversalRequest traversalRequest) { - return new StaticGraphVertexCentricTraversalStartByIdsWithAggOp<>(graphViewDesc, traversalRequest, - vertexCentricTraversal); - } - - public static IGraphVertexCentricOp buildStaticGraphVertexCentricTraversalOp( - GraphViewDesc graphViewDesc, - VertexCentricTraversal vertexCentricTraversal, - List> traversalRequests) { - return new StaticGraphVertexCentricTraversalStartByIdsOp<>(graphViewDesc, traversalRequests, - vertexCentricTraversal); - } - - public static IGraphVertexCentricOp buildStaticGraphVertexCentricAggTraversalOp( - GraphViewDesc graphViewDesc, - VertexCentricAggTraversal vertexCentricTraversal, - List> traversalRequests) { - return new StaticGraphVertexCentricTraversalStartByIdsWithAggOp<>(graphViewDesc, traversalRequests, - vertexCentricTraversal); - } - - public static IGraphVertexCentricOp buildDynamicGraphVertexCentricComputeOp( - GraphViewDesc graphViewDesc, IncVertexCentricCompute incVertexCentricCompute) { - return new DynamicGraphVertexCentricComputeOp(graphViewDesc, incVertexCentricCompute); - } - - public static IGraphVertexCentricOp buildDynamicGraphVertexCentricAggComputeOp( - GraphViewDesc graphViewDesc, - IncVertexCentricAggCompute incVertexCentricCompute) { - return new DynamicGraphVertexCentricComputeWithAggOp(graphViewDesc, incVertexCentricCompute); - } - - public static IGraphVertexCentricOp buildDynamicGraphVertexCentricTraversalOp( - GraphViewDesc graphViewDesc, - IncVertexCentricTraversal incVertexCentricTraversal, - VertexBeginTraversalRequest traversalRequest) { - return new DynamicGraphVertexCentricTraversalStartByIdsOp<>(graphViewDesc, traversalRequest, - incVertexCentricTraversal); - } - - public static IGraphVertexCentricOp buildDynamicGraphVertexCentricTraversalOp( - GraphViewDesc graphViewDesc, - IncVertexCentricAggTraversal incVertexCentricTraversal, - VertexBeginTraversalRequest traversalRequest) { - return new DynamicGraphVertexCentricTraversalStartByIdsWithAggOp<>(graphViewDesc, traversalRequest, - incVertexCentricTraversal); - } - - public static IGraphVertexCentricOp buildDynamicGraphVertexCentricTraversalAllOp( - GraphViewDesc graphViewDesc, - IncVertexCentricTraversal incVertexCentricTraversal) { - return new DynamicGraphVertexCentricTraversalAllOp<>(graphViewDesc, incVertexCentricTraversal); - } - - public static IGraphVertexCentricOp buildDynamicGraphVertexCentricTraversalAllOp( - GraphViewDesc graphViewDesc, - IncVertexCentricAggTraversal incVertexCentricTraversal) { - return new DynamicGraphVertexCentricTraversalAllWithAggOp<>(graphViewDesc, incVertexCentricTraversal); - } - - public static IGraphVertexCentricOp buildDynamicGraphVertexCentricTraversalOp( - GraphViewDesc graphViewDesc, - IncVertexCentricTraversal incVertexCentricTraversal) { - return new DynamicGraphVertexCentricTraversalStartByStreamOp<>(graphViewDesc, incVertexCentricTraversal); - } - - public static IGraphVertexCentricOp buildDynamicGraphVertexCentricTraversalOp( - GraphViewDesc graphViewDesc, - IncVertexCentricAggTraversal incVertexCentricTraversal) { - return new DynamicGraphVertexCentricTraversalStartByStreamWithAggOp<>(graphViewDesc, incVertexCentricTraversal); - } - - public static IGraphVertexCentricOp buildDynamicGraphVertexCentricTraversalOp( - GraphViewDesc graphViewDesc, - IncVertexCentricTraversal incVertexCentricTraversal, - List> traversalRequests) { - return new DynamicGraphVertexCentricTraversalStartByIdsOp<>(graphViewDesc, - traversalRequests, - incVertexCentricTraversal); - } - - public static IGraphVertexCentricOp buildDynamicGraphVertexCentricTraversalOp( - GraphViewDesc graphViewDesc, - IncVertexCentricAggTraversal incVertexCentricTraversal, - List> traversalRequests) { - return new DynamicGraphVertexCentricTraversalStartByIdsWithAggOp<>(graphViewDesc, - traversalRequests, - incVertexCentricTraversal); - } - + public static + IGraphVertexCentricOp buildStaticGraphVertexCentricComputeOp( + GraphViewDesc graphViewDesc, VertexCentricCompute vertexCentricCompute) { + return new StaticGraphVertexCentricComputeOp<>(graphViewDesc, vertexCentricCompute); + } + + public static + IGraphVertexCentricAggOp + buildStaticGraphVertexCentricAggComputeOp( + GraphViewDesc graphViewDesc, + VertexCentricAggCompute vertexCentricAggCompute) { + return new StaticGraphVertexCentricComputeWithAggOp<>(graphViewDesc, vertexCentricAggCompute); + } + + public static + IGraphVertexCentricOp buildStaticGraphVertexCentricTraversalOp( + GraphViewDesc graphViewDesc, + VertexCentricTraversal vertexCentricTraversal) { + return new StaticGraphVertexCentricTraversalStartByStreamOp<>( + graphViewDesc, vertexCentricTraversal); + } + + public static + IGraphVertexCentricOp buildStaticGraphVertexCentricAggTraversalOp( + GraphViewDesc graphViewDesc, + VertexCentricAggTraversal vertexCentricTraversal) { + return new StaticGraphVertexCentricTraversalStartByStreamWithAggOp<>( + graphViewDesc, vertexCentricTraversal); + } + + public static + IGraphVertexCentricOp buildStaticGraphVertexCentricTraversalAllOp( + GraphViewDesc graphViewDesc, + VertexCentricTraversal vertexCentricTraversal) { + return new StaticGraphVertexCentricTraversalAllOp<>(graphViewDesc, vertexCentricTraversal); + } + + public static + IGraphVertexCentricOp buildStaticGraphVertexCentricAggTraversalAllOp( + GraphViewDesc graphViewDesc, + VertexCentricAggTraversal vertexCentricTraversal) { + return new StaticGraphVertexCentricTraversalAllWithAggOp<>( + graphViewDesc, vertexCentricTraversal); + } + + public static + IGraphVertexCentricOp buildStaticGraphVertexCentricTraversalOp( + GraphViewDesc graphViewDesc, + VertexCentricTraversal vertexCentricTraversal, + VertexBeginTraversalRequest traversalRequest) { + return new StaticGraphVertexCentricTraversalStartByIdsOp<>( + graphViewDesc, traversalRequest, vertexCentricTraversal); + } + + public static + IGraphVertexCentricOp buildStaticGraphVertexCentricAggTraversalOp( + GraphViewDesc graphViewDesc, + VertexCentricAggTraversal vertexCentricTraversal, + VertexBeginTraversalRequest traversalRequest) { + return new StaticGraphVertexCentricTraversalStartByIdsWithAggOp<>( + graphViewDesc, traversalRequest, vertexCentricTraversal); + } + + public static + IGraphVertexCentricOp buildStaticGraphVertexCentricTraversalOp( + GraphViewDesc graphViewDesc, + VertexCentricTraversal vertexCentricTraversal, + List> traversalRequests) { + return new StaticGraphVertexCentricTraversalStartByIdsOp<>( + graphViewDesc, traversalRequests, vertexCentricTraversal); + } + + public static + IGraphVertexCentricOp buildStaticGraphVertexCentricAggTraversalOp( + GraphViewDesc graphViewDesc, + VertexCentricAggTraversal vertexCentricTraversal, + List> traversalRequests) { + return new StaticGraphVertexCentricTraversalStartByIdsWithAggOp<>( + graphViewDesc, traversalRequests, vertexCentricTraversal); + } + + public static + IGraphVertexCentricOp buildDynamicGraphVertexCentricComputeOp( + GraphViewDesc graphViewDesc, + IncVertexCentricCompute incVertexCentricCompute) { + return new DynamicGraphVertexCentricComputeOp(graphViewDesc, incVertexCentricCompute); + } + + public static + IGraphVertexCentricOp buildDynamicGraphVertexCentricAggComputeOp( + GraphViewDesc graphViewDesc, + IncVertexCentricAggCompute incVertexCentricCompute) { + return new DynamicGraphVertexCentricComputeWithAggOp(graphViewDesc, incVertexCentricCompute); + } + + public static + IGraphVertexCentricOp buildDynamicGraphVertexCentricTraversalOp( + GraphViewDesc graphViewDesc, + IncVertexCentricTraversal incVertexCentricTraversal, + VertexBeginTraversalRequest traversalRequest) { + return new DynamicGraphVertexCentricTraversalStartByIdsOp<>( + graphViewDesc, traversalRequest, incVertexCentricTraversal); + } + + public static + IGraphVertexCentricOp buildDynamicGraphVertexCentricTraversalOp( + GraphViewDesc graphViewDesc, + IncVertexCentricAggTraversal + incVertexCentricTraversal, + VertexBeginTraversalRequest traversalRequest) { + return new DynamicGraphVertexCentricTraversalStartByIdsWithAggOp<>( + graphViewDesc, traversalRequest, incVertexCentricTraversal); + } + + public static + IGraphVertexCentricOp buildDynamicGraphVertexCentricTraversalAllOp( + GraphViewDesc graphViewDesc, + IncVertexCentricTraversal incVertexCentricTraversal) { + return new DynamicGraphVertexCentricTraversalAllOp<>(graphViewDesc, incVertexCentricTraversal); + } + + public static + IGraphVertexCentricOp buildDynamicGraphVertexCentricTraversalAllOp( + GraphViewDesc graphViewDesc, + IncVertexCentricAggTraversal + incVertexCentricTraversal) { + return new DynamicGraphVertexCentricTraversalAllWithAggOp<>( + graphViewDesc, incVertexCentricTraversal); + } + + public static + IGraphVertexCentricOp buildDynamicGraphVertexCentricTraversalOp( + GraphViewDesc graphViewDesc, + IncVertexCentricTraversal incVertexCentricTraversal) { + return new DynamicGraphVertexCentricTraversalStartByStreamOp<>( + graphViewDesc, incVertexCentricTraversal); + } + + public static + IGraphVertexCentricOp buildDynamicGraphVertexCentricTraversalOp( + GraphViewDesc graphViewDesc, + IncVertexCentricAggTraversal + incVertexCentricTraversal) { + return new DynamicGraphVertexCentricTraversalStartByStreamWithAggOp<>( + graphViewDesc, incVertexCentricTraversal); + } + + public static + IGraphVertexCentricOp buildDynamicGraphVertexCentricTraversalOp( + GraphViewDesc graphViewDesc, + IncVertexCentricTraversal incVertexCentricTraversal, + List> traversalRequests) { + return new DynamicGraphVertexCentricTraversalStartByIdsOp<>( + graphViewDesc, traversalRequests, incVertexCentricTraversal); + } + + public static + IGraphVertexCentricOp buildDynamicGraphVertexCentricTraversalOp( + GraphViewDesc graphViewDesc, + IncVertexCentricAggTraversal + incVertexCentricTraversal, + List> traversalRequests) { + return new DynamicGraphVertexCentricTraversalStartByIdsWithAggOp<>( + graphViewDesc, traversalRequests, incVertexCentricTraversal); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/IGraphAggregateOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/IGraphAggregateOp.java index 6b6811bc6..f1457283c 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/IGraphAggregateOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/IGraphAggregateOp.java @@ -21,6 +21,5 @@ public interface IGraphAggregateOp { - void processAggregateResult(R result); - + void processAggregateResult(R result); } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/IGraphTraversalOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/IGraphTraversalOp.java index 9403b7ff2..c587589b5 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/IGraphTraversalOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/IGraphTraversalOp.java @@ -20,17 +20,14 @@ package org.apache.geaflow.operator.impl.graph.algo.vc; import java.util.Iterator; + import org.apache.geaflow.model.traversal.ITraversalRequest; public interface IGraphTraversalOp extends IGraphVertexCentricOp { - /** - * Add traversal request. - */ - void addRequest(ITraversalRequest request); + /** Add traversal request. */ + void addRequest(ITraversalRequest request); - /** - * Returns traversal request iterator. - */ - Iterator> getTraversalRequests(); + /** Returns traversal request iterator. */ + Iterator> getTraversalRequests(); } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/IGraphVertexCentricAggOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/IGraphVertexCentricAggOp.java index 373043d66..f991bfd35 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/IGraphVertexCentricAggOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/IGraphVertexCentricAggOp.java @@ -22,16 +22,14 @@ /** * Interface for graph vertex centric operator with aggregation. * - * @param The id type of vertex/edge. + * @param The id type of vertex/edge. * @param The value type of vertex. * @param The value type of edge. - * @param The message type during iterations. - * @param The type of aggregate input iterm. + * @param The message type during iterations. + * @param The type of aggregate input iterm. * @param The type of partial aggregate iterm. * @param The type of partial aggregate result. - * @param The type of global aggregate result. + * @param The type of global aggregate result. */ public interface IGraphVertexCentricAggOp - extends IGraphVertexCentricOp, IGraphAggregateOp { - -} + extends IGraphVertexCentricOp, IGraphAggregateOp {} diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/IGraphVertexCentricOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/IGraphVertexCentricOp.java index 805863a5f..03171d151 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/IGraphVertexCentricOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/IGraphVertexCentricOp.java @@ -25,19 +25,12 @@ public interface IGraphVertexCentricOp { - /** - * Add vertex into temporary graph. - */ - void addVertex(IVertex vertex); + /** Add vertex into temporary graph. */ + void addVertex(IVertex vertex); - /** - * Add edge into temporary graph. - */ - void addEdge(IEdge edge); - - /** - * Process iterator message. - */ - void processMessage(IGraphMessage graphMessage); + /** Add edge into temporary graph. */ + void addEdge(IEdge edge); + /** Process iterator message. */ + void processMessage(IGraphMessage graphMessage); } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicEdgeQueryImpl.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicEdgeQueryImpl.java index 31c26394d..c34392b75 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicEdgeQueryImpl.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicEdgeQueryImpl.java @@ -19,8 +19,8 @@ package org.apache.geaflow.operator.impl.graph.algo.vc.context.dynamic; - import java.util.List; + import org.apache.geaflow.api.graph.function.vc.base.VertexCentricFunction.EdgeQuery; import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.model.graph.edge.IEdge; @@ -33,48 +33,46 @@ public class DynamicEdgeQueryImpl implements EdgeQuery { - protected K vId; - private long versionId; - private GraphState graphState; - protected KeyGroup keyGroup; - - public DynamicEdgeQueryImpl(K vId, long versionId, GraphState graphState) { - this.vId = vId; - this.versionId = versionId; - this.graphState = graphState; - } - - public DynamicEdgeQueryImpl(K vId, long versionId, GraphState graphState, - KeyGroup keyGroup) { - this.vId = vId; - this.versionId = versionId; - this.graphState = graphState; - this.keyGroup = keyGroup; - } + protected K vId; + private long versionId; + private GraphState graphState; + protected KeyGroup keyGroup; - @Override - public List> getEdges() { - DynamicEdgeState edgeState = graphState.dynamicGraph().E(); - return edgeState.query(versionId, vId).asList(); - } + public DynamicEdgeQueryImpl(K vId, long versionId, GraphState graphState) { + this.vId = vId; + this.versionId = versionId; + this.graphState = graphState; + } + public DynamicEdgeQueryImpl( + K vId, long versionId, GraphState graphState, KeyGroup keyGroup) { + this.vId = vId; + this.versionId = versionId; + this.graphState = graphState; + this.keyGroup = keyGroup; + } - @Override - public List> getOutEdges() { - DynamicEdgeState edgeState = graphState.dynamicGraph().E(); - return edgeState.query(versionId, vId).by(OutEdgeFilter.getInstance()).asList(); - } + @Override + public List> getEdges() { + DynamicEdgeState edgeState = graphState.dynamicGraph().E(); + return edgeState.query(versionId, vId).asList(); + } + @Override + public List> getOutEdges() { + DynamicEdgeState edgeState = graphState.dynamicGraph().E(); + return edgeState.query(versionId, vId).by(OutEdgeFilter.getInstance()).asList(); + } - @Override - public List> getInEdges() { - DynamicEdgeState edgeState = graphState.dynamicGraph().E(); - return edgeState.query(versionId, vId).by(InEdgeFilter.getInstance()).asList(); - } + @Override + public List> getInEdges() { + DynamicEdgeState edgeState = graphState.dynamicGraph().E(); + return edgeState.query(versionId, vId).by(InEdgeFilter.getInstance()).asList(); + } - @Override - public CloseableIterator> getEdges(IFilter edgeFilter) { - DynamicEdgeState edgeState = graphState.dynamicGraph().E(); - return edgeState.query(versionId, vId).by(edgeFilter).iterator(); - } + @Override + public CloseableIterator> getEdges(IFilter edgeFilter) { + DynamicEdgeState edgeState = graphState.dynamicGraph().E(); + return edgeState.query(versionId, vId).by(edgeFilter).iterator(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicTraversalEdgeQueryImpl.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicTraversalEdgeQueryImpl.java index d655bb275..5b5cd847f 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicTraversalEdgeQueryImpl.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicTraversalEdgeQueryImpl.java @@ -26,20 +26,18 @@ public class DynamicTraversalEdgeQueryImpl extends DynamicEdgeQueryImpl implements TraversalEdgeQuery { - public DynamicTraversalEdgeQueryImpl(K vId, long versionId, - GraphState graphState) { - super(vId, versionId, graphState); - } + public DynamicTraversalEdgeQueryImpl(K vId, long versionId, GraphState graphState) { + super(vId, versionId, graphState); + } + public DynamicTraversalEdgeQueryImpl( + K vId, long versionId, GraphState graphState, KeyGroup keyGroup) { + super(vId, versionId, graphState, keyGroup); + } - public DynamicTraversalEdgeQueryImpl(K vId, long versionId, - GraphState graphState, KeyGroup keyGroup) { - super(vId, versionId, graphState, keyGroup); - } - - @Override - public TraversalEdgeQuery withId(K vertexId) { - this.vId = vertexId; - return this; - } + @Override + public TraversalEdgeQuery withId(K vertexId) { + this.vId = vertexId; + return this; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicTraversalVertexQueryImpl.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicTraversalVertexQueryImpl.java index f5dc69418..0d470ff95 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicTraversalVertexQueryImpl.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicTraversalVertexQueryImpl.java @@ -27,22 +27,21 @@ public class DynamicTraversalVertexQueryImpl extends DynamicVertexQueryImpl implements TraversalVertexQuery { - public DynamicTraversalVertexQueryImpl(K vertexId, long versionId, - GraphState graphState) { - super(vertexId, versionId, graphState); - } + public DynamicTraversalVertexQueryImpl( + K vertexId, long versionId, GraphState graphState) { + super(vertexId, versionId, graphState); + } - public DynamicTraversalVertexQueryImpl(K vertexId, long versionId, - GraphState graphState, - KeyGroup keyGroup) { - super(vertexId, versionId, graphState, keyGroup); - } + public DynamicTraversalVertexQueryImpl( + K vertexId, long versionId, GraphState graphState, KeyGroup keyGroup) { + super(vertexId, versionId, graphState, keyGroup); + } - @Override - public CloseableIterator loadIdIterator() { - if (keyGroup == null) { - return graphState.dynamicGraph().V().idIterator(); - } - return graphState.dynamicGraph().V().query(versionId, keyGroup).idIterator(); + @Override + public CloseableIterator loadIdIterator() { + if (keyGroup == null) { + return graphState.dynamicGraph().V().idIterator(); } + return graphState.dynamicGraph().V().query(versionId, keyGroup).idIterator(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicVertexQueryImpl.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicVertexQueryImpl.java index 5a85bc0fa..ab6834fb5 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicVertexQueryImpl.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicVertexQueryImpl.java @@ -19,7 +19,6 @@ package org.apache.geaflow.operator.impl.graph.algo.vc.context.dynamic; - import org.apache.geaflow.api.graph.function.vc.base.VertexCentricFunction.VertexQuery; import org.apache.geaflow.model.graph.vertex.IVertex; import org.apache.geaflow.state.GraphState; @@ -28,40 +27,38 @@ public class DynamicVertexQueryImpl implements VertexQuery { - private K vertexId; - protected long versionId; - protected GraphState graphState; - protected KeyGroup keyGroup; - - public DynamicVertexQueryImpl(K vertexId, long versionId, GraphState graphState) { - this.vertexId = vertexId; - this.versionId = versionId; - this.graphState = graphState; - } - - public DynamicVertexQueryImpl(K vertexId, long versionId, GraphState graphState, - KeyGroup keyGroup) { - this.vertexId = vertexId; - this.versionId = versionId; - this.graphState = graphState; - this.keyGroup = keyGroup; - } + private K vertexId; + protected long versionId; + protected GraphState graphState; + protected KeyGroup keyGroup; - @Override - public VertexQuery withId(K vertexId) { - this.vertexId = vertexId; - return this; - } + public DynamicVertexQueryImpl(K vertexId, long versionId, GraphState graphState) { + this.vertexId = vertexId; + this.versionId = versionId; + this.graphState = graphState; + } - @Override - public IVertex get() { - return graphState.dynamicGraph().V().query(versionId, vertexId).get(); - } + public DynamicVertexQueryImpl( + K vertexId, long versionId, GraphState graphState, KeyGroup keyGroup) { + this.vertexId = vertexId; + this.versionId = versionId; + this.graphState = graphState; + this.keyGroup = keyGroup; + } - @Override - public IVertex get(IFilter vertexFilter) { - return graphState.dynamicGraph().V().query(versionId, vertexId).by(vertexFilter).get(); - } + @Override + public VertexQuery withId(K vertexId) { + this.vertexId = vertexId; + return this; + } + @Override + public IVertex get() { + return graphState.dynamicGraph().V().query(versionId, vertexId).get(); + } + @Override + public IVertex get(IFilter vertexFilter) { + return graphState.dynamicGraph().V().query(versionId, vertexId).by(vertexFilter).get(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/IncGraphContextImpl.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/IncGraphContextImpl.java index 3c5401f33..ab4812174 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/IncGraphContextImpl.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/IncGraphContextImpl.java @@ -20,6 +20,7 @@ package org.apache.geaflow.operator.impl.graph.algo.vc.context.dynamic; import java.util.List; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.graph.function.vc.base.IncVertexCentricFunction.HistoricalGraph; import org.apache.geaflow.api.graph.function.vc.base.IncVertexCentricFunction.IncGraphContext; @@ -37,100 +38,99 @@ public class IncGraphContextImpl implements IncGraphContext { - private static final Logger LOGGER = LoggerFactory.getLogger(IncGraphContextImpl.class); - - private long iterationId; - private K vertexId; - private final OpContext opContext; - private final RuntimeContext runtimeContext; - - private final IncHistoricalGraph historicalGraph; - private final IncTemporaryGraph temporaryGraph; - private final IncMutableGraph mutableGraph; - private final GraphState graphState; - private final IGraphMsgBox graphMsgBox; - private final long maxIteration; - - public IncGraphContextImpl(OpContext opContext, - RuntimeContext runtimeContext, - GraphState graphState, - TemporaryGraphCache temporaryGraphCache, - IGraphMsgBox graphMsgBox, - long maxIteration) { - this.opContext = opContext; - this.runtimeContext = runtimeContext; - this.historicalGraph = new IncHistoricalGraph<>(graphState); - this.temporaryGraph = new IncTemporaryGraph<>(temporaryGraphCache); - this.mutableGraph = new IncMutableGraph<>(graphState); - this.graphMsgBox = graphMsgBox; - this.graphState = graphState; - this.maxIteration = maxIteration; - } - - public void init(long iterationId, K vertexId) { - this.iterationId = iterationId; - this.vertexId = vertexId; - - this.historicalGraph.init(vertexId); - this.temporaryGraph.init(vertexId); - - } - - @Override - public long getJobId() { - return opContext.getRuntimeContext().getPipelineId(); - } - - @Override - public long getIterationId() { - return iterationId; - } - - @Override - public RuntimeContext getRuntimeContext() { - return this.runtimeContext; + private static final Logger LOGGER = LoggerFactory.getLogger(IncGraphContextImpl.class); + + private long iterationId; + private K vertexId; + private final OpContext opContext; + private final RuntimeContext runtimeContext; + + private final IncHistoricalGraph historicalGraph; + private final IncTemporaryGraph temporaryGraph; + private final IncMutableGraph mutableGraph; + private final GraphState graphState; + private final IGraphMsgBox graphMsgBox; + private final long maxIteration; + + public IncGraphContextImpl( + OpContext opContext, + RuntimeContext runtimeContext, + GraphState graphState, + TemporaryGraphCache temporaryGraphCache, + IGraphMsgBox graphMsgBox, + long maxIteration) { + this.opContext = opContext; + this.runtimeContext = runtimeContext; + this.historicalGraph = new IncHistoricalGraph<>(graphState); + this.temporaryGraph = new IncTemporaryGraph<>(temporaryGraphCache); + this.mutableGraph = new IncMutableGraph<>(graphState); + this.graphMsgBox = graphMsgBox; + this.graphState = graphState; + this.maxIteration = maxIteration; + } + + public void init(long iterationId, K vertexId) { + this.iterationId = iterationId; + this.vertexId = vertexId; + + this.historicalGraph.init(vertexId); + this.temporaryGraph.init(vertexId); + } + + @Override + public long getJobId() { + return opContext.getRuntimeContext().getPipelineId(); + } + + @Override + public long getIterationId() { + return iterationId; + } + + @Override + public RuntimeContext getRuntimeContext() { + return this.runtimeContext; + } + + @Override + public MutableGraph getMutableGraph() { + return this.mutableGraph; + } + + @Override + public TemporaryGraph getTemporaryGraph() { + return this.temporaryGraph; + } + + @Override + public HistoricalGraph getHistoricalGraph() { + return this.historicalGraph; + } + + @Override + public void sendMessage(K vertexId, M m) { + if (this.iterationId >= this.maxIteration) { + return; } + graphMsgBox.addOutMessage(vertexId, m); + } - @Override - public MutableGraph getMutableGraph() { - return this.mutableGraph; + @Override + public void sendMessageToNeighbors(M m) { + if (this.iterationId >= this.maxIteration) { + return; } - - @Override - public TemporaryGraph getTemporaryGraph() { - return this.temporaryGraph; - } - - @Override - public HistoricalGraph getHistoricalGraph() { - return this.historicalGraph; - } - - @Override - public void sendMessage(K vertexId, M m) { - if (this.iterationId >= this.maxIteration) { - return; + List allVersions = graphState.dynamicGraph().V().getAllVersions(vertexId); + for (long version : allVersions) { + try (CloseableIterator> edgeIterator = + graphState.dynamicGraph().E().query(version, vertexId).iterator()) { + while (edgeIterator.hasNext()) { + IEdge edge = edgeIterator.next(); + graphMsgBox.addOutMessage(edge.getTargetId(), m); } - graphMsgBox.addOutMessage(vertexId, m); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); + } } - - @Override - public void sendMessageToNeighbors(M m) { - if (this.iterationId >= this.maxIteration) { - return; - } - List allVersions = graphState.dynamicGraph().V().getAllVersions(vertexId); - for (long version : allVersions) { - try (CloseableIterator> edgeIterator - = graphState.dynamicGraph().E().query(version, vertexId).iterator()) { - while (edgeIterator.hasNext()) { - IEdge edge = edgeIterator.next(); - graphMsgBox.addOutMessage(edge.getTargetId(), m); - } - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } - } - } - + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/IncGraphSnapShot.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/IncGraphSnapShot.java index 252d3f536..e0b270710 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/IncGraphSnapShot.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/IncGraphSnapShot.java @@ -26,29 +26,28 @@ public class IncGraphSnapShot implements GraphSnapShot { - protected final K vertexId; - protected final long versionId; - protected final GraphState graphState; - - public IncGraphSnapShot(K vertexId, long versionId, GraphState graphState) { - this.vertexId = vertexId; - this.versionId = versionId; - this.graphState = graphState; - } - - @Override - public long getVersion() { - return versionId; - } - - @Override - public VertexQuery vertex() { - return new DynamicVertexQueryImpl<>(vertexId, versionId, graphState); - } - - @Override - public EdgeQuery edges() { - return new DynamicEdgeQueryImpl<>(vertexId, versionId, graphState); - } - + protected final K vertexId; + protected final long versionId; + protected final GraphState graphState; + + public IncGraphSnapShot(K vertexId, long versionId, GraphState graphState) { + this.vertexId = vertexId; + this.versionId = versionId; + this.graphState = graphState; + } + + @Override + public long getVersion() { + return versionId; + } + + @Override + public VertexQuery vertex() { + return new DynamicVertexQueryImpl<>(vertexId, versionId, graphState); + } + + @Override + public EdgeQuery edges() { + return new DynamicEdgeQueryImpl<>(vertexId, versionId, graphState); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/IncHistoricalGraph.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/IncHistoricalGraph.java index edddf95c6..0ee89117f 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/IncHistoricalGraph.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/IncHistoricalGraph.java @@ -19,9 +19,9 @@ package org.apache.geaflow.operator.impl.graph.algo.vc.context.dynamic; - import java.util.List; import java.util.Map; + import org.apache.geaflow.api.graph.function.vc.base.IncVertexCentricFunction.GraphSnapShot; import org.apache.geaflow.api.graph.function.vc.base.IncVertexCentricFunction.HistoricalGraph; import org.apache.geaflow.model.graph.vertex.IVertex; @@ -30,45 +30,45 @@ public class IncHistoricalGraph implements HistoricalGraph { - protected K vertexId; - protected final GraphState graphState; + protected K vertexId; + protected final GraphState graphState; - public IncHistoricalGraph(GraphState graphState) { - this.graphState = graphState; - } + public IncHistoricalGraph(GraphState graphState) { + this.graphState = graphState; + } - public void init(K vertexId) { - this.vertexId = vertexId; - } + public void init(K vertexId) { + this.vertexId = vertexId; + } - @Override - public Long getLatestVersionId() { - return graphState.dynamicGraph().V().getLatestVersion(this.vertexId); - } + @Override + public Long getLatestVersionId() { + return graphState.dynamicGraph().V().getLatestVersion(this.vertexId); + } - @Override - public List getAllVersionIds() { - return graphState.dynamicGraph().V().getAllVersions(this.vertexId); - } + @Override + public List getAllVersionIds() { + return graphState.dynamicGraph().V().getAllVersions(this.vertexId); + } - @Override - public Map> getAllVertex() { - return graphState.dynamicGraph().V().query(vertexId).asMap(); - } + @Override + public Map> getAllVertex() { + return graphState.dynamicGraph().V().query(vertexId).asMap(); + } - @Override - public Map> getAllVertex(List versions) { - return graphState.dynamicGraph().V().query(vertexId, versions).asMap(); - } + @Override + public Map> getAllVertex(List versions) { + return graphState.dynamicGraph().V().query(vertexId, versions).asMap(); + } - @Override - public Map> getAllVertex(List versions, - IVertexFilter vertexFilter) { - return graphState.dynamicGraph().V().query(vertexId, versions).by(vertexFilter).asMap(); - } + @Override + public Map> getAllVertex( + List versions, IVertexFilter vertexFilter) { + return graphState.dynamicGraph().V().query(vertexId, versions).by(vertexFilter).asMap(); + } - @Override - public GraphSnapShot getSnapShot(long version) { - return new IncGraphSnapShot(vertexId, version, graphState); - } + @Override + public GraphSnapShot getSnapShot(long version) { + return new IncGraphSnapShot(vertexId, version, graphState); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/IncMutableGraph.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/IncMutableGraph.java index 07203505d..0db2145bc 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/IncMutableGraph.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/IncMutableGraph.java @@ -19,7 +19,6 @@ package org.apache.geaflow.operator.impl.graph.algo.vc.context.dynamic; - import org.apache.geaflow.api.graph.function.vc.base.IncVertexCentricFunction.MutableGraph; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; @@ -27,21 +26,19 @@ public class IncMutableGraph implements MutableGraph { - private GraphState graphState; - - public IncMutableGraph(GraphState graphState) { - this.graphState = graphState; - } - - @Override - public void addVertex(long version, IVertex vertex) { - graphState.dynamicGraph().V().add(version, vertex); - } + private GraphState graphState; - @Override - public void addEdge(long version, IEdge edge) { - graphState.dynamicGraph().E().add(version, edge); - } + public IncMutableGraph(GraphState graphState) { + this.graphState = graphState; + } + @Override + public void addVertex(long version, IVertex vertex) { + graphState.dynamicGraph().V().add(version, vertex); + } + @Override + public void addEdge(long version, IEdge edge) { + graphState.dynamicGraph().E().add(version, edge); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/IncTemporaryGraph.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/IncTemporaryGraph.java index 10a249575..355f4e05c 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/IncTemporaryGraph.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/IncTemporaryGraph.java @@ -20,6 +20,7 @@ package org.apache.geaflow.operator.impl.graph.algo.vc.context.dynamic; import java.util.List; + import org.apache.geaflow.api.graph.function.vc.base.IncVertexCentricFunction.TemporaryGraph; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; @@ -28,36 +29,35 @@ public class IncTemporaryGraph implements TemporaryGraph { - private K vertexId; - private TemporaryGraphCache temporaryGraphCache; - - public IncTemporaryGraph(TemporaryGraphCache temporaryGraphCache) { - this.temporaryGraphCache = temporaryGraphCache; - } - - public void init(K vertexId) { - this.vertexId = vertexId; - } - - @Override - public IVertex getVertex() { - return temporaryGraphCache.getVertex(vertexId); + private K vertexId; + private TemporaryGraphCache temporaryGraphCache; + + public IncTemporaryGraph(TemporaryGraphCache temporaryGraphCache) { + this.temporaryGraphCache = temporaryGraphCache; + } + + public void init(K vertexId) { + this.vertexId = vertexId; + } + + @Override + public IVertex getVertex() { + return temporaryGraphCache.getVertex(vertexId); + } + + @Override + public List> getEdges() { + return temporaryGraphCache.getEdges(vertexId); + } + + @Override + public void updateVertexValue(VV value) { + IVertex valueVertex = temporaryGraphCache.getVertex(vertexId); + if (valueVertex == null) { + valueVertex = new ValueVertex<>(vertexId, value); + } else { + valueVertex = valueVertex.withValue(value); } - - @Override - public List> getEdges() { - return temporaryGraphCache.getEdges(vertexId); - } - - @Override - public void updateVertexValue(VV value) { - IVertex valueVertex = temporaryGraphCache.getVertex(vertexId); - if (valueVertex == null) { - valueVertex = new ValueVertex<>(vertexId, value); - } else { - valueVertex = valueVertex.withValue(value); - } - temporaryGraphCache.addVertex(valueVertex); - } - + temporaryGraphCache.addVertex(valueVertex); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/TraversalIncGraphSnapShot.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/TraversalIncGraphSnapShot.java index 014515726..4ea888f29 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/TraversalIncGraphSnapShot.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/TraversalIncGraphSnapShot.java @@ -27,18 +27,17 @@ public class TraversalIncGraphSnapShot extends IncGraphSnapShot implements TraversalGraphSnapShot { - public TraversalIncGraphSnapShot(K vertexId, long versionId, - GraphState graphState) { - super(vertexId, versionId, graphState); - } + public TraversalIncGraphSnapShot(K vertexId, long versionId, GraphState graphState) { + super(vertexId, versionId, graphState); + } - @Override - public TraversalVertexQuery vertex() { - return new DynamicTraversalVertexQueryImpl<>(vertexId, versionId, graphState); - } + @Override + public TraversalVertexQuery vertex() { + return new DynamicTraversalVertexQueryImpl<>(vertexId, versionId, graphState); + } - @Override - public TraversalEdgeQuery edges() { - return new DynamicTraversalEdgeQueryImpl<>(vertexId, versionId, graphState); - } + @Override + public TraversalEdgeQuery edges() { + return new DynamicTraversalEdgeQueryImpl<>(vertexId, versionId, graphState); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/TraversalIncHistoricalGraph.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/TraversalIncHistoricalGraph.java index 8476ea0e4..35c7462ef 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/TraversalIncHistoricalGraph.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/TraversalIncHistoricalGraph.java @@ -19,51 +19,53 @@ package org.apache.geaflow.operator.impl.graph.algo.vc.context.dynamic; - import java.util.List; import java.util.Map; + import org.apache.geaflow.api.graph.function.vc.IncVertexCentricTraversalFunction.TraversalGraphSnapShot; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricTraversalFunction.TraversalHistoricalGraph; import org.apache.geaflow.api.graph.function.vc.base.IncVertexCentricFunction.HistoricalGraph; import org.apache.geaflow.model.graph.vertex.IVertex; import org.apache.geaflow.state.pushdown.filter.IVertexFilter; -public class TraversalIncHistoricalGraph implements HistoricalGraph - , TraversalHistoricalGraph { +public class TraversalIncHistoricalGraph + implements HistoricalGraph, TraversalHistoricalGraph { - private final IncHistoricalGraph historicalGraph; + private final IncHistoricalGraph historicalGraph; - public TraversalIncHistoricalGraph(IncHistoricalGraph historicalGraph) { - this.historicalGraph = historicalGraph; - } + public TraversalIncHistoricalGraph(IncHistoricalGraph historicalGraph) { + this.historicalGraph = historicalGraph; + } - @Override - public Long getLatestVersionId() { - return historicalGraph.getLatestVersionId(); - } + @Override + public Long getLatestVersionId() { + return historicalGraph.getLatestVersionId(); + } - @Override - public List getAllVersionIds() { - return historicalGraph.getAllVersionIds(); - } + @Override + public List getAllVersionIds() { + return historicalGraph.getAllVersionIds(); + } - @Override - public Map> getAllVertex() { - return historicalGraph.getAllVertex(); - } + @Override + public Map> getAllVertex() { + return historicalGraph.getAllVertex(); + } - @Override - public Map> getAllVertex(List versions) { - return historicalGraph.getAllVertex(versions); - } + @Override + public Map> getAllVertex(List versions) { + return historicalGraph.getAllVertex(versions); + } - @Override - public Map> getAllVertex(List versions, IVertexFilter vertexFilter) { - return historicalGraph.getAllVertex(versions, vertexFilter); - } + @Override + public Map> getAllVertex( + List versions, IVertexFilter vertexFilter) { + return historicalGraph.getAllVertex(versions, vertexFilter); + } - @Override - public TraversalGraphSnapShot getSnapShot(long version) { - return new TraversalIncGraphSnapShot<>(historicalGraph.vertexId, version, historicalGraph.graphState); - } + @Override + public TraversalGraphSnapShot getSnapShot(long version) { + return new TraversalIncGraphSnapShot<>( + historicalGraph.vertexId, version, historicalGraph.graphState); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticEdgeQueryImpl.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticEdgeQueryImpl.java index 782652538..b11a8b259 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticEdgeQueryImpl.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticEdgeQueryImpl.java @@ -20,6 +20,7 @@ package org.apache.geaflow.operator.impl.graph.algo.vc.context.statical; import java.util.List; + import org.apache.geaflow.api.graph.function.vc.base.VertexCentricFunction.EdgeQuery; import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.model.graph.edge.IEdge; @@ -31,40 +32,38 @@ public class StaticEdgeQueryImpl implements EdgeQuery { - protected K vId; - private final GraphState graphState; - protected KeyGroup keyGroup; - - public StaticEdgeQueryImpl(K vId, GraphState graphState) { - this.vId = vId; - this.graphState = graphState; - } - - public StaticEdgeQueryImpl(K vId, GraphState graphState, KeyGroup keyGroup) { - this.vId = vId; - this.graphState = graphState; - this.keyGroup = keyGroup; - } + protected K vId; + private final GraphState graphState; + protected KeyGroup keyGroup; - @Override - public List> getEdges() { - return graphState.staticGraph().E().query(vId).asList(); - } + public StaticEdgeQueryImpl(K vId, GraphState graphState) { + this.vId = vId; + this.graphState = graphState; + } - @Override - public List> getOutEdges() { - return graphState.staticGraph().E().query(vId).by(OutEdgeFilter.getInstance()).asList(); - } + public StaticEdgeQueryImpl(K vId, GraphState graphState, KeyGroup keyGroup) { + this.vId = vId; + this.graphState = graphState; + this.keyGroup = keyGroup; + } + @Override + public List> getEdges() { + return graphState.staticGraph().E().query(vId).asList(); + } - @Override - public List> getInEdges() { - return graphState.staticGraph().E().query(vId).by(InEdgeFilter.getInstance()).asList(); - } + @Override + public List> getOutEdges() { + return graphState.staticGraph().E().query(vId).by(OutEdgeFilter.getInstance()).asList(); + } + @Override + public List> getInEdges() { + return graphState.staticGraph().E().query(vId).by(InEdgeFilter.getInstance()).asList(); + } - @Override - public CloseableIterator> getEdges(IFilter edgeFilter) { - return (CloseableIterator) graphState.staticGraph().E().query(vId).by(edgeFilter).iterator(); - } + @Override + public CloseableIterator> getEdges(IFilter edgeFilter) { + return (CloseableIterator) graphState.staticGraph().E().query(vId).by(edgeFilter).iterator(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticGraphContextImpl.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticGraphContextImpl.java index 3c2e9b4ea..d149be63a 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticGraphContextImpl.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticGraphContextImpl.java @@ -31,78 +31,78 @@ public class StaticGraphContextImpl implements VertexCentricFunction.VertexCentricFuncContext { - private final Operator.OpContext opContext; - private final RuntimeContext runtimeContext; - private final GraphState graphState; - private final IGraphMsgBox graphMsgBox; - private final long maxIteration; - protected long iterationId; - protected K vertexId; + private final Operator.OpContext opContext; + private final RuntimeContext runtimeContext; + private final GraphState graphState; + private final IGraphMsgBox graphMsgBox; + private final long maxIteration; + protected long iterationId; + protected K vertexId; - public StaticGraphContextImpl(Operator.OpContext opContext, - RuntimeContext runtimeContext, - GraphState graphState, - IGraphMsgBox graphMsgBox, - long maxIteration) { - this.opContext = opContext; - this.runtimeContext = runtimeContext; - this.graphState = graphState; - this.graphMsgBox = graphMsgBox; - this.maxIteration = maxIteration; - } + public StaticGraphContextImpl( + Operator.OpContext opContext, + RuntimeContext runtimeContext, + GraphState graphState, + IGraphMsgBox graphMsgBox, + long maxIteration) { + this.opContext = opContext; + this.runtimeContext = runtimeContext; + this.graphState = graphState; + this.graphMsgBox = graphMsgBox; + this.maxIteration = maxIteration; + } - public void init(long iterationId, K vertexId) { - this.iterationId = iterationId; - this.vertexId = vertexId; - } + public void init(long iterationId, K vertexId) { + this.iterationId = iterationId; + this.vertexId = vertexId; + } - @Override - public long getJobId() { - return this.opContext.getRuntimeContext().getPipelineId(); - } + @Override + public long getJobId() { + return this.opContext.getRuntimeContext().getPipelineId(); + } - @Override - public long getIterationId() { - return this.iterationId; - } + @Override + public long getIterationId() { + return this.iterationId; + } - @Override - public RuntimeContext getRuntimeContext() { - return this.runtimeContext; - } + @Override + public RuntimeContext getRuntimeContext() { + return this.runtimeContext; + } - @Override - public VertexCentricFunction.VertexQuery vertex() { - return new StaticVertexQueryImpl<>(this.vertexId, this.graphState); - } + @Override + public VertexCentricFunction.VertexQuery vertex() { + return new StaticVertexQueryImpl<>(this.vertexId, this.graphState); + } - @Override - public VertexCentricFunction.EdgeQuery edges() { - return new StaticEdgeQueryImpl<>(this.vertexId, this.graphState); - } + @Override + public VertexCentricFunction.EdgeQuery edges() { + return new StaticEdgeQueryImpl<>(this.vertexId, this.graphState); + } - @Override - public void sendMessage(K vertexId, M message) { - if (this.iterationId >= this.maxIteration) { - return; - } - this.graphMsgBox.addOutMessage(vertexId, message); + @Override + public void sendMessage(K vertexId, M message) { + if (this.iterationId >= this.maxIteration) { + return; } + this.graphMsgBox.addOutMessage(vertexId, message); + } - @Override - public void sendMessageToNeighbors(M message) { - if (this.iterationId >= this.maxIteration) { - return; - } - try (CloseableIterator> edgeIterator - = this.graphState.staticGraph().E().query(this.vertexId).iterator()) { - while (edgeIterator.hasNext()) { - IEdge edge = edgeIterator.next(); - this.graphMsgBox.addOutMessage(edge.getTargetId(), message); - } - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + @Override + public void sendMessageToNeighbors(M message) { + if (this.iterationId >= this.maxIteration) { + return; } - + try (CloseableIterator> edgeIterator = + this.graphState.staticGraph().E().query(this.vertexId).iterator()) { + while (edgeIterator.hasNext()) { + IEdge edge = edgeIterator.next(); + this.graphMsgBox.addOutMessage(edge.getTargetId(), message); + } + } catch (Exception e) { + throw new GeaflowRuntimeException(e); + } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticTraversalEdgeQueryImpl.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticTraversalEdgeQueryImpl.java index 4182ecb5b..0c2d8431f 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticTraversalEdgeQueryImpl.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticTraversalEdgeQueryImpl.java @@ -26,17 +26,17 @@ public class StaticTraversalEdgeQueryImpl extends StaticEdgeQueryImpl implements TraversalEdgeQuery { - public StaticTraversalEdgeQueryImpl(K vId, GraphState graphState) { - super(vId, graphState); - } + public StaticTraversalEdgeQueryImpl(K vId, GraphState graphState) { + super(vId, graphState); + } - public StaticTraversalEdgeQueryImpl(K vId, GraphState graphState, KeyGroup keyGroup) { - super(vId, graphState, keyGroup); - } + public StaticTraversalEdgeQueryImpl(K vId, GraphState graphState, KeyGroup keyGroup) { + super(vId, graphState, keyGroup); + } - @Override - public TraversalEdgeQuery withId(K vertexId) { - this.vId = vertexId; - return this; - } + @Override + public TraversalEdgeQuery withId(K vertexId) { + this.vId = vertexId; + return this; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticTraversalVertexQueryImpl.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticTraversalVertexQueryImpl.java index b16888c33..9c7fa6e3f 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticTraversalVertexQueryImpl.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticTraversalVertexQueryImpl.java @@ -27,15 +27,16 @@ public class StaticTraversalVertexQueryImpl extends StaticVertexQueryImpl implements TraversalVertexQuery { - public StaticTraversalVertexQueryImpl(K vertexId, GraphState graphState, KeyGroup keyGroup) { - super(vertexId, graphState, keyGroup); - } + public StaticTraversalVertexQueryImpl( + K vertexId, GraphState graphState, KeyGroup keyGroup) { + super(vertexId, graphState, keyGroup); + } - @Override - public CloseableIterator loadIdIterator() { - if (keyGroup == null) { - return graphState.staticGraph().V().idIterator(); - } - return graphState.staticGraph().V().query(keyGroup).idIterator(); + @Override + public CloseableIterator loadIdIterator() { + if (keyGroup == null) { + return graphState.staticGraph().V().idIterator(); } + return graphState.staticGraph().V().query(keyGroup).idIterator(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticVertexQueryImpl.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticVertexQueryImpl.java index 109605dd0..b464fe31d 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticVertexQueryImpl.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticVertexQueryImpl.java @@ -27,34 +27,34 @@ public class StaticVertexQueryImpl implements VertexQuery { - private K vertexId; - protected GraphState graphState; - protected KeyGroup keyGroup; - - public StaticVertexQueryImpl(K vertexId, GraphState graphState) { - this.vertexId = vertexId; - this.graphState = graphState; - } - - public StaticVertexQueryImpl(K vertexId, GraphState graphState, KeyGroup keyGroup) { - this.vertexId = vertexId; - this.graphState = graphState; - this.keyGroup = keyGroup; - } - - @Override - public VertexQuery withId(K vertexId) { - this.vertexId = vertexId; - return this; - } - - @Override - public IVertex get() { - return graphState.staticGraph().V().query(vertexId).get(); - } - - @Override - public IVertex get(IFilter vertexFilter) { - return graphState.staticGraph().V().query(vertexId).by(vertexFilter).get(); - } + private K vertexId; + protected GraphState graphState; + protected KeyGroup keyGroup; + + public StaticVertexQueryImpl(K vertexId, GraphState graphState) { + this.vertexId = vertexId; + this.graphState = graphState; + } + + public StaticVertexQueryImpl(K vertexId, GraphState graphState, KeyGroup keyGroup) { + this.vertexId = vertexId; + this.graphState = graphState; + this.keyGroup = keyGroup; + } + + @Override + public VertexQuery withId(K vertexId) { + this.vertexId = vertexId; + return this; + } + + @Override + public IVertex get() { + return graphState.staticGraph().V().query(vertexId).get(); + } + + @Override + public IVertex get(IFilter vertexFilter) { + return graphState.staticGraph().V().query(vertexId).by(vertexFilter).get(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/msgbox/CombinedMsgBox.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/msgbox/CombinedMsgBox.java index f411fcc0f..526e0aa76 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/msgbox/CombinedMsgBox.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/msgbox/CombinedMsgBox.java @@ -19,86 +19,87 @@ package org.apache.geaflow.operator.impl.graph.algo.vc.msgbox; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.Lists; import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; -import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; - -public class CombinedMsgBox implements IGraphMsgBox { - - private final Map inMessageBox; - private final Map outMessageBox; - - private final VertexCentricCombineFunction combineFunction; - - public CombinedMsgBox(VertexCentricCombineFunction combineFunction) { - this.combineFunction = combineFunction; - this.inMessageBox = new HashMap<>(); - this.outMessageBox = new HashMap<>(); - } - - @Override - public void addInMessages(K vertexId, MESSAGE message) { - MESSAGE oldMessage = inMessageBox.get(vertexId); - if (oldMessage != null) { - MESSAGE newMessage = this.combineFunction.combine(oldMessage, message); - inMessageBox.put(vertexId, newMessage); - } else { - this.inMessageBox.put(vertexId, message); - } - } - - @Override - public void processInMessage(MsgProcessFunc processFunc) { - processMessage(inMessageBox, processFunc); - } - @Override - public void clearInBox() { - this.inMessageBox.clear(); - } - - @Override - public void addOutMessage(K vertexId, MESSAGE message) { - addMessage(outMessageBox, vertexId, message); - } +import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; - @Override - public void processOutMessage(MsgProcessFunc processFunc) { - processMessage(outMessageBox, processFunc); - } +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Lists; - @Override - public void clearOutBox() { - this.outMessageBox.clear(); - } +public class CombinedMsgBox implements IGraphMsgBox { - private void processMessage(Map messageBox, MsgProcessFunc processFunc) { - for (Entry entry : messageBox.entrySet()) { - processFunc.process(entry.getKey(), Lists.newArrayList(entry.getValue())); - } + private final Map inMessageBox; + private final Map outMessageBox; + + private final VertexCentricCombineFunction combineFunction; + + public CombinedMsgBox(VertexCentricCombineFunction combineFunction) { + this.combineFunction = combineFunction; + this.inMessageBox = new HashMap<>(); + this.outMessageBox = new HashMap<>(); + } + + @Override + public void addInMessages(K vertexId, MESSAGE message) { + MESSAGE oldMessage = inMessageBox.get(vertexId); + if (oldMessage != null) { + MESSAGE newMessage = this.combineFunction.combine(oldMessage, message); + inMessageBox.put(vertexId, newMessage); + } else { + this.inMessageBox.put(vertexId, message); } - - private void addMessage(Map messageBox, K vertexId, MESSAGE message) { - MESSAGE oldMessage = messageBox.get(vertexId); - if (oldMessage != null) { - MESSAGE newMessage = this.combineFunction.combine(oldMessage, message); - messageBox.put(vertexId, newMessage); - } else { - messageBox.put(vertexId, message); - } + } + + @Override + public void processInMessage(MsgProcessFunc processFunc) { + processMessage(inMessageBox, processFunc); + } + + @Override + public void clearInBox() { + this.inMessageBox.clear(); + } + + @Override + public void addOutMessage(K vertexId, MESSAGE message) { + addMessage(outMessageBox, vertexId, message); + } + + @Override + public void processOutMessage(MsgProcessFunc processFunc) { + processMessage(outMessageBox, processFunc); + } + + @Override + public void clearOutBox() { + this.outMessageBox.clear(); + } + + private void processMessage(Map messageBox, MsgProcessFunc processFunc) { + for (Entry entry : messageBox.entrySet()) { + processFunc.process(entry.getKey(), Lists.newArrayList(entry.getValue())); } - - @VisibleForTesting - protected Map getInMessageBox() { - return this.inMessageBox; + } + + private void addMessage(Map messageBox, K vertexId, MESSAGE message) { + MESSAGE oldMessage = messageBox.get(vertexId); + if (oldMessage != null) { + MESSAGE newMessage = this.combineFunction.combine(oldMessage, message); + messageBox.put(vertexId, newMessage); + } else { + messageBox.put(vertexId, message); } + } - @VisibleForTesting - protected Map getOutMessageBox() { - return this.outMessageBox; - } + @VisibleForTesting + protected Map getInMessageBox() { + return this.inMessageBox; + } + @VisibleForTesting + protected Map getOutMessageBox() { + return this.outMessageBox; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/msgbox/DirectEmitMsgBox.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/msgbox/DirectEmitMsgBox.java index a12f1da99..5011a7f96 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/msgbox/DirectEmitMsgBox.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/msgbox/DirectEmitMsgBox.java @@ -24,62 +24,60 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; + import org.apache.geaflow.collector.ICollector; import org.apache.geaflow.model.graph.message.DefaultGraphMessage; import org.apache.geaflow.model.graph.message.IGraphMessage; public class DirectEmitMsgBox implements IGraphMsgBox { - private final Map> inMessageBox; - private final ICollector> msgCollector; + private final Map> inMessageBox; + private final ICollector> msgCollector; - public DirectEmitMsgBox(ICollector> msgCollector) { - this.inMessageBox = new HashMap<>(); - this.msgCollector = msgCollector; - } + public DirectEmitMsgBox(ICollector> msgCollector) { + this.inMessageBox = new HashMap<>(); + this.msgCollector = msgCollector; + } - @Override - public void addInMessages(K vertexId, MESSAGE message) { - List messages = inMessageBox.computeIfAbsent(vertexId, k -> new ArrayList<>()); - messages.add(message); - } - - @Override - public void processInMessage(MsgProcessFunc processFunc) { - processMessage(inMessageBox, processFunc); - } + @Override + public void addInMessages(K vertexId, MESSAGE message) { + List messages = inMessageBox.computeIfAbsent(vertexId, k -> new ArrayList<>()); + messages.add(message); + } - @Override - public void clearInBox() { - this.inMessageBox.clear(); - } + @Override + public void processInMessage(MsgProcessFunc processFunc) { + processMessage(inMessageBox, processFunc); + } - @Override - public void addOutMessage(K vertexId, MESSAGE message) { - this.msgCollector.partition(vertexId, new DefaultGraphMessage<>(vertexId, message)); - } + @Override + public void clearInBox() { + this.inMessageBox.clear(); + } - @Override - public void processOutMessage(MsgProcessFunc processFunc) { - } + @Override + public void addOutMessage(K vertexId, MESSAGE message) { + this.msgCollector.partition(vertexId, new DefaultGraphMessage<>(vertexId, message)); + } - @Override - public void clearOutBox() { - } + @Override + public void processOutMessage(MsgProcessFunc processFunc) {} - private void processMessage(Map> messageBox, - MsgProcessFunc processFunc) { - for (Entry> entry : messageBox.entrySet()) { - K vertexId = entry.getKey(); - List messageList = entry.getValue(); - processFunc.process(vertexId, messageList); - } - } + @Override + public void clearOutBox() {} - private void addMessage(Map> messageBox, K vertexId, MESSAGE message) { - List oldMessages = messageBox.getOrDefault(vertexId, new ArrayList<>()); - oldMessages.add(message); - messageBox.put(vertexId, oldMessages); + private void processMessage( + Map> messageBox, MsgProcessFunc processFunc) { + for (Entry> entry : messageBox.entrySet()) { + K vertexId = entry.getKey(); + List messageList = entry.getValue(); + processFunc.process(vertexId, messageList); } + } + private void addMessage(Map> messageBox, K vertexId, MESSAGE message) { + List oldMessages = messageBox.getOrDefault(vertexId, new ArrayList<>()); + oldMessages.add(message); + messageBox.put(vertexId, oldMessages); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/msgbox/GraphMsgBoxFactory.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/msgbox/GraphMsgBoxFactory.java index 2754b26b4..dbe5b7560 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/msgbox/GraphMsgBoxFactory.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/msgbox/GraphMsgBoxFactory.java @@ -25,14 +25,13 @@ public class GraphMsgBoxFactory { - public static IGraphMsgBox buildMessageBox( - ICollector> msgCollector, - VertexCentricCombineFunction combineFunction) { - if (combineFunction == null) { - return new DirectEmitMsgBox<>(msgCollector); - } else { - return new CombinedMsgBox<>(combineFunction); - } + public static IGraphMsgBox buildMessageBox( + ICollector> msgCollector, + VertexCentricCombineFunction combineFunction) { + if (combineFunction == null) { + return new DirectEmitMsgBox<>(msgCollector); + } else { + return new CombinedMsgBox<>(combineFunction); } - + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/msgbox/IGraphMsgBox.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/msgbox/IGraphMsgBox.java index c4322abba..085070c07 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/msgbox/IGraphMsgBox.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/algo/vc/msgbox/IGraphMsgBox.java @@ -23,44 +23,28 @@ public interface IGraphMsgBox { - /** - * Add in-message into box. - */ - void addInMessages(K vertexId, MESSAGE message); + /** Add in-message into box. */ + void addInMessages(K vertexId, MESSAGE message); - /** - * Process in-message by process function. - */ - void processInMessage(MsgProcessFunc processFunc); + /** Process in-message by process function. */ + void processInMessage(MsgProcessFunc processFunc); - /** - * Clear in-message box. - */ - void clearInBox(); + /** Clear in-message box. */ + void clearInBox(); - /** - * Add out-message into box. - */ - void addOutMessage(K vertexId, MESSAGE message); + /** Add out-message into box. */ + void addOutMessage(K vertexId, MESSAGE message); - /** - * Process out-message by process function. - */ - void processOutMessage(MsgProcessFunc processFunc); + /** Process out-message by process function. */ + void processOutMessage(MsgProcessFunc processFunc); - /** - * Clear out-message box. - */ - void clearOutBox(); + /** Clear out-message box. */ + void clearOutBox(); - @FunctionalInterface - interface MsgProcessFunc { - - /** - * Process messages. - */ - void process(K vertexId, List messageList); - - } + @FunctionalInterface + interface MsgProcessFunc { + /** Process messages. */ + void process(K vertexId, List messageList); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/AbstractDynamicGraphVertexCentricOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/AbstractDynamicGraphVertexCentricOp.java index bdd89f3dc..82bc0517b 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/AbstractDynamicGraphVertexCentricOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/AbstractDynamicGraphVertexCentricOp.java @@ -22,6 +22,7 @@ import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT; import java.io.IOException; + import org.apache.geaflow.api.graph.base.algo.VertexCentricAlgo; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.utils.CheckpointUtil; @@ -39,119 +40,130 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class AbstractDynamicGraphVertexCentricOp> +public abstract class AbstractDynamicGraphVertexCentricOp< + K, VV, EV, M, FUNC extends VertexCentricAlgo> extends AbstractGraphVertexCentricOp { - private static final Logger LOGGER = LoggerFactory.getLogger( - AbstractDynamicGraphVertexCentricOp.class); - - protected TemporaryGraphCache temporaryGraphCache; - private long checkpointDuration; - - public AbstractDynamicGraphVertexCentricOp(GraphViewDesc graphViewDesc, FUNC func) { - super(graphViewDesc, func); - assert !graphViewDesc.isStatic(); - opArgs.setOpType(OpType.INC_VERTEX_CENTRIC_COMPUTE); - opArgs.setChainStrategy(OpArgs.ChainStrategy.NEVER); - this.maxIterations = this.function.getMaxIterationCount(); - } - - @Override - public void open(OpContext opContext) { - super.open(opContext); - this.temporaryGraphCache = new TemporaryGraphCache<>(); - this.checkpointDuration = runtimeContext.getConfiguration().getLong(BATCH_NUMBER_PER_CHECKPOINT); - } - - @Override - public void addVertex(IVertex vertex) { - if (enableDebug) { - LOGGER.info("taskId:{} windowId:{} iterations:{} add vertex:{}", - runtimeContext.getTaskArgs().getTaskId(), - windowId, - iterations, - vertex); - } - this.temporaryGraphCache.addVertex(vertex); - this.opInputMeter.mark(); - } - - @Override - public void addEdge(IEdge edge) { - if (enableDebug) { - LOGGER.info("taskId:{} windowId:{} iterations:{} add edge:{}", - runtimeContext.getTaskArgs().getTaskId(), - windowId, - iterations, - edge); - } - this.temporaryGraphCache.addEdge(edge); - this.opInputMeter.mark(); - } - - @Override - public void close() { - this.temporaryGraphCache.clear(); - this.graphMsgBox.clearInBox(); - this.graphMsgBox.clearOutBox(); - this.graphState.manage().operate().close(); + private static final Logger LOGGER = + LoggerFactory.getLogger(AbstractDynamicGraphVertexCentricOp.class); + + protected TemporaryGraphCache temporaryGraphCache; + private long checkpointDuration; + + public AbstractDynamicGraphVertexCentricOp(GraphViewDesc graphViewDesc, FUNC func) { + super(graphViewDesc, func); + assert !graphViewDesc.isStatic(); + opArgs.setOpType(OpType.INC_VERTEX_CENTRIC_COMPUTE); + opArgs.setChainStrategy(OpArgs.ChainStrategy.NEVER); + this.maxIterations = this.function.getMaxIterationCount(); + } + + @Override + public void open(OpContext opContext) { + super.open(opContext); + this.temporaryGraphCache = new TemporaryGraphCache<>(); + this.checkpointDuration = + runtimeContext.getConfiguration().getLong(BATCH_NUMBER_PER_CHECKPOINT); + } + + @Override + public void addVertex(IVertex vertex) { + if (enableDebug) { + LOGGER.info( + "taskId:{} windowId:{} iterations:{} add vertex:{}", + runtimeContext.getTaskArgs().getTaskId(), + windowId, + iterations, + vertex); } - - @Override - protected GraphStateDescriptor buildGraphStateDesc(String name) { - GraphStateDescriptor desc = super.buildGraphStateDesc(name); - desc.withDataModel(DataModel.DYNAMIC_GRAPH); - desc.withGraphMeta(new GraphMeta(graphViewDesc.getGraphMetaType())); - return desc; + this.temporaryGraphCache.addVertex(vertex); + this.opInputMeter.mark(); + } + + @Override + public void addEdge(IEdge edge) { + if (enableDebug) { + LOGGER.info( + "taskId:{} windowId:{} iterations:{} add edge:{}", + runtimeContext.getTaskArgs().getTaskId(), + windowId, + iterations, + edge); } - - public GraphViewDesc getGraphViewDesc() { - return graphViewDesc; + this.temporaryGraphCache.addEdge(edge); + this.opInputMeter.mark(); + } + + @Override + public void close() { + this.temporaryGraphCache.clear(); + this.graphMsgBox.clearInBox(); + this.graphMsgBox.clearOutBox(); + this.graphState.manage().operate().close(); + } + + @Override + protected GraphStateDescriptor buildGraphStateDesc(String name) { + GraphStateDescriptor desc = super.buildGraphStateDesc(name); + desc.withDataModel(DataModel.DYNAMIC_GRAPH); + desc.withGraphMeta(new GraphMeta(graphViewDesc.getGraphMetaType())); + return desc; + } + + public GraphViewDesc getGraphViewDesc() { + return graphViewDesc; + } + + protected void checkpoint() { + if (CheckpointUtil.needDoCheckpoint(windowId, checkpointDuration)) { + LOGGER.info( + "opName:{} do checkpoint for windowId:{}, checkpoint duration:{}", + this.opArgs.getOpName(), + windowId, + checkpointDuration); + long checkpoint = graphViewDesc.getCheckpoint(windowId); + LOGGER.info("do checkpoint, checkpointId: {}", checkpoint); + graphState.manage().operate().setCheckpointId(checkpoint); + graphState.manage().operate().finish(); + graphState.manage().operate().archive(); } - - - protected void checkpoint() { - if (CheckpointUtil.needDoCheckpoint(windowId, checkpointDuration)) { - LOGGER.info("opName:{} do checkpoint for windowId:{}, checkpoint duration:{}", - this.opArgs.getOpName(), windowId, checkpointDuration); - long checkpoint = graphViewDesc.getCheckpoint(windowId); - LOGGER.info("do checkpoint, checkpointId: {}", checkpoint); - graphState.manage().operate().setCheckpointId(checkpoint); - graphState.manage().operate().finish(); - graphState.manage().operate().archive(); - } + } + + @Override + protected void recover() { + LOGGER.info("opName: {} will do recover, windowId: {}", this.opArgs.getOpName(), this.windowId); + long lastCheckPointId; + try { + ViewMetaBookKeeper keeper = + new ViewMetaBookKeeper(graphViewDesc.getName(), this.runtimeContext.getConfiguration()); + lastCheckPointId = keeper.getLatestViewVersion(graphViewDesc.getName()); + LOGGER.info( + "opName: {} will do recover, ViewMetaBookKeeper version: {}", + this.opArgs.getOpName(), + lastCheckPointId); + } catch (IOException e) { + throw new GeaflowRuntimeException(e); } - - @Override - protected void recover() { - LOGGER.info("opName: {} will do recover, windowId: {}", this.opArgs.getOpName(), - this.windowId); - long lastCheckPointId; - try { - ViewMetaBookKeeper keeper = new ViewMetaBookKeeper(graphViewDesc.getName(), - this.runtimeContext.getConfiguration()); - lastCheckPointId = keeper.getLatestViewVersion(graphViewDesc.getName()); - LOGGER.info("opName: {} will do recover, ViewMetaBookKeeper version: {}", - this.opArgs.getOpName(), lastCheckPointId); - } catch (IOException e) { - throw new GeaflowRuntimeException(e); - } - if (lastCheckPointId >= 0) { - LOGGER.info("opName: {} do recover to state VersionId: {}", this.opArgs.getOpName(), - lastCheckPointId); - graphState.manage().operate().setCheckpointId(lastCheckPointId); - graphState.manage().operate().recover(); - } else { - LOGGER.info("lastCheckPointId < 0"); - //If the graph has a checkpoint, should we recover it - if (windowId > 1) { - //Recover checkpoint id for dynamic graph is init graph version add windowId - long recoverVersionId = graphViewDesc.getCheckpoint(windowId - 1); - LOGGER.info("opName: {} do recover to latestVersionId: {}", this.opArgs.getOpName(), - recoverVersionId); - graphState.manage().operate().setCheckpointId(recoverVersionId); - graphState.manage().operate().recover(); - } - } + if (lastCheckPointId >= 0) { + LOGGER.info( + "opName: {} do recover to state VersionId: {}", + this.opArgs.getOpName(), + lastCheckPointId); + graphState.manage().operate().setCheckpointId(lastCheckPointId); + graphState.manage().operate().recover(); + } else { + LOGGER.info("lastCheckPointId < 0"); + // If the graph has a checkpoint, should we recover it + if (windowId > 1) { + // Recover checkpoint id for dynamic graph is init graph version add windowId + long recoverVersionId = graphViewDesc.getCheckpoint(windowId - 1); + LOGGER.info( + "opName: {} do recover to latestVersionId: {}", + this.opArgs.getOpName(), + recoverVersionId); + graphState.manage().operate().setCheckpointId(recoverVersionId); + graphState.manage().operate().recover(); + } } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/DynamicGraphVertexCentricComputeOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/DynamicGraphVertexCentricComputeOp.java index 7de8eca8d..4bff40bfa 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/DynamicGraphVertexCentricComputeOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/DynamicGraphVertexCentricComputeOp.java @@ -23,6 +23,7 @@ import java.util.HashSet; import java.util.List; import java.util.Set; + import org.apache.geaflow.api.function.iterator.RichIteratorFunction; import org.apache.geaflow.api.graph.base.algo.AbstractIncVertexCentricComputeAlgo; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricComputeFunction; @@ -46,149 +47,158 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class DynamicGraphVertexCentricComputeOp> - extends AbstractDynamicGraphVertexCentricOp> +public class DynamicGraphVertexCentricComputeOp< + K, VV, EV, M, FUNC extends IncVertexCentricComputeFunction> + extends AbstractDynamicGraphVertexCentricOp< + K, VV, EV, M, AbstractIncVertexCentricComputeAlgo> implements IGraphVertexCentricOp, IteratorOperator { - private static final Logger LOGGER = LoggerFactory.getLogger(DynamicGraphVertexCentricComputeOp.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(DynamicGraphVertexCentricComputeOp.class); - protected IncGraphComputeContextImpl graphIncVCComputeCtx; - protected IncVertexCentricComputeFunction incVCComputeFunction; + protected IncGraphComputeContextImpl graphIncVCComputeCtx; + protected IncVertexCentricComputeFunction incVCComputeFunction; - private Set invokeVIds; + private Set invokeVIds; - private ICollector> vertexCollector; + private ICollector> vertexCollector; - protected Configuration configuration; + protected Configuration configuration; - public DynamicGraphVertexCentricComputeOp(GraphViewDesc graphViewDesc, - AbstractIncVertexCentricComputeAlgo incVCAlgorithm) { - super(graphViewDesc, incVCAlgorithm); - opArgs.setOpType(OpType.INC_VERTEX_CENTRIC_COMPUTE); - opArgs.setChainStrategy(OpArgs.ChainStrategy.NEVER); - } + public DynamicGraphVertexCentricComputeOp( + GraphViewDesc graphViewDesc, + AbstractIncVertexCentricComputeAlgo incVCAlgorithm) { + super(graphViewDesc, incVCAlgorithm); + opArgs.setOpType(OpType.INC_VERTEX_CENTRIC_COMPUTE); + opArgs.setChainStrategy(OpArgs.ChainStrategy.NEVER); + } - @Override - public void open(OpContext opContext) { - super.open(opContext); - this.incVCComputeFunction = this.function.getIncComputeFunction(); - this.configuration = runtimeContext.getConfiguration(); - this.graphIncVCComputeCtx = configuration.getBoolean(FrameworkConfigKeys.INFER_ENV_ENABLE) - ? new IncGraphInferComputeContextImpl() : new IncGraphComputeContextImpl(); - this.incVCComputeFunction.init(this.graphIncVCComputeCtx); - - this.invokeVIds = new HashSet<>(); - - for (ICollector collector : this.collectors) { - if (!collector.getTag().equals(GraphRecordNames.Message.name()) - && !collector.getTag().equals(GraphRecordNames.Aggregate.name())) { - vertexCollector = collector; - } - } - } + @Override + public void open(OpContext opContext) { + super.open(opContext); + this.incVCComputeFunction = this.function.getIncComputeFunction(); + this.configuration = runtimeContext.getConfiguration(); + this.graphIncVCComputeCtx = + configuration.getBoolean(FrameworkConfigKeys.INFER_ENV_ENABLE) + ? new IncGraphInferComputeContextImpl() + : new IncGraphComputeContextImpl(); + this.incVCComputeFunction.init(this.graphIncVCComputeCtx); - @Override - public void doFinishIteration(long iterations) { - LOGGER.info("finish iteration:{}", iterations); - //compute - if (this.iterations == 1L) { - Set vIds = temporaryGraphCache.getAllEvolveVId(); - this.invokeVIds.addAll(vIds); - for (K vId : vIds) { - this.graphIncVCComputeCtx.init(iterations, vId); - this.incVCComputeFunction.evolve(vId, - this.graphIncVCComputeCtx.getTemporaryGraph()); - } - } else { - this.graphMsgBox.processInMessage(new MsgProcessFunc() { - @Override - public void process(K vertexId, List ms) { - graphIncVCComputeCtx.init(iterations, vertexId); - invokeVIds.add(vertexId); - incVCComputeFunction.compute(vertexId, ms.iterator()); - } - }); - this.graphMsgBox.clearInBox(); - } - if (incVCComputeFunction instanceof RichIteratorFunction) { - ((RichIteratorFunction) incVCComputeFunction).finishIteration(iterations); - } - // Emit message. - this.graphMsgBox.processOutMessage(new MsgProcessFunc() { + this.invokeVIds = new HashSet<>(); + + for (ICollector collector : this.collectors) { + if (!collector.getTag().equals(GraphRecordNames.Message.name()) + && !collector.getTag().equals(GraphRecordNames.Aggregate.name())) { + vertexCollector = collector; + } + } + } + + @Override + public void doFinishIteration(long iterations) { + LOGGER.info("finish iteration:{}", iterations); + // compute + if (this.iterations == 1L) { + Set vIds = temporaryGraphCache.getAllEvolveVId(); + this.invokeVIds.addAll(vIds); + for (K vId : vIds) { + this.graphIncVCComputeCtx.init(iterations, vId); + this.incVCComputeFunction.evolve(vId, this.graphIncVCComputeCtx.getTemporaryGraph()); + } + } else { + this.graphMsgBox.processInMessage( + new MsgProcessFunc() { @Override - public void process(K vertexId, List messages) { - // Collect message. - int size = messages.size(); - for (int i = 0; i < size; i++) { - messageCollector.partition(vertexId, new DefaultGraphMessage<>(vertexId, messages.get(i))); - } + public void process(K vertexId, List ms) { + graphIncVCComputeCtx.init(iterations, vertexId); + invokeVIds.add(vertexId); + incVCComputeFunction.compute(vertexId, ms.iterator()); } + }); + this.graphMsgBox.clearInBox(); + } + if (incVCComputeFunction instanceof RichIteratorFunction) { + ((RichIteratorFunction) incVCComputeFunction).finishIteration(iterations); + } + // Emit message. + this.graphMsgBox.processOutMessage( + new MsgProcessFunc() { + @Override + public void process(K vertexId, List messages) { + // Collect message. + int size = messages.size(); + for (int i = 0; i < size; i++) { + messageCollector.partition( + vertexId, new DefaultGraphMessage<>(vertexId, messages.get(i))); + } + } }); - this.messageCollector.finish(); - this.graphMsgBox.clearOutBox(); - } + this.messageCollector.finish(); + this.graphMsgBox.clearOutBox(); + } - @Override - public void finish() { - LOGGER.info("current batch invokeIds:{}", this.invokeVIds); - for (K vertexId : this.invokeVIds) { - this.graphIncVCComputeCtx.init(iterations, vertexId); - this.incVCComputeFunction.finish(vertexId, this.graphIncVCComputeCtx.getMutableGraph()); - } - this.invokeVIds.clear(); - this.temporaryGraphCache.clear(); - vertexCollector.finish(); - checkpoint(); + @Override + public void finish() { + LOGGER.info("current batch invokeIds:{}", this.invokeVIds); + for (K vertexId : this.invokeVIds) { + this.graphIncVCComputeCtx.init(iterations, vertexId); + this.incVCComputeFunction.finish(vertexId, this.graphIncVCComputeCtx.getMutableGraph()); } + this.invokeVIds.clear(); + this.temporaryGraphCache.clear(); + vertexCollector.finish(); + checkpoint(); + } - class IncGraphComputeContextImpl extends IncGraphContextImpl implements IncGraphComputeContext { + class IncGraphComputeContextImpl extends IncGraphContextImpl + implements IncGraphComputeContext { - public IncGraphComputeContextImpl() { - super(opContext, runtimeContext, graphState, temporaryGraphCache, graphMsgBox, maxIterations); - } + public IncGraphComputeContextImpl() { + super(opContext, runtimeContext, graphState, temporaryGraphCache, graphMsgBox, maxIterations); + } - @Override - public void collect(IVertex vertex) { - vertexCollector.partition(vertex.getId(), vertex); - } + @Override + public void collect(IVertex vertex) { + vertexCollector.partition(vertex.getId(), vertex); } + } - class IncGraphInferComputeContextImpl extends IncGraphComputeContextImpl implements - IncGraphInferContext { + class IncGraphInferComputeContextImpl extends IncGraphComputeContextImpl + implements IncGraphInferContext { - private final ThreadLocal clientLocal = new ThreadLocal<>(); + private final ThreadLocal clientLocal = new ThreadLocal<>(); - private final InferContext inferContext; + private final InferContext inferContext; - public IncGraphInferComputeContextImpl() { - if (clientLocal.get() == null) { - try { - inferContext = new InferContext<>(runtimeContext.getConfiguration()); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } - clientLocal.set(inferContext); - } else { - inferContext = clientLocal.get(); - } + public IncGraphInferComputeContextImpl() { + if (clientLocal.get() == null) { + try { + inferContext = new InferContext<>(runtimeContext.getConfiguration()); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } + clientLocal.set(inferContext); + } else { + inferContext = clientLocal.get(); + } + } - @Override - public OUT infer(Object... modelInputs) { - try { - return inferContext.infer(modelInputs); - } catch (Exception e) { - throw new GeaflowRuntimeException("model infer failed", e); - } - } + @Override + public OUT infer(Object... modelInputs) { + try { + return inferContext.infer(modelInputs); + } catch (Exception e) { + throw new GeaflowRuntimeException("model infer failed", e); + } + } - @Override - public void close() throws IOException { - if (clientLocal.get() != null) { - clientLocal.get().close(); - clientLocal.remove(); - } - } + @Override + public void close() throws IOException { + if (clientLocal.get() != null) { + clientLocal.get().close(); + clientLocal.remove(); + } } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/DynamicGraphVertexCentricComputeWithAggOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/DynamicGraphVertexCentricComputeWithAggOp.java index 81665c4a1..9bcf3aba9 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/DynamicGraphVertexCentricComputeWithAggOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/DynamicGraphVertexCentricComputeWithAggOp.java @@ -30,41 +30,43 @@ import org.slf4j.LoggerFactory; public class DynamicGraphVertexCentricComputeWithAggOp - extends DynamicGraphVertexCentricComputeOp> + extends DynamicGraphVertexCentricComputeOp< + K, VV, EV, M, IncVertexCentricAggComputeFunction> implements IGraphVertexCentricAggOp { - private static final Logger LOGGER = LoggerFactory.getLogger(DynamicGraphVertexCentricComputeWithAggOp.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(DynamicGraphVertexCentricComputeWithAggOp.class); - private GraphVertexCentricOpAggregator> aggregator; + private GraphVertexCentricOpAggregator< + K, VV, EV, M, I, ?, ?, ?, GR, VertexCentricAggTraversal> + aggregator; - public DynamicGraphVertexCentricComputeWithAggOp( - GraphViewDesc graphViewDesc, - IncVertexCentricAggCompute vcAlgorithm) { - super(graphViewDesc, vcAlgorithm); - aggregator = new GraphVertexCentricOpAggregator(this); - } + public DynamicGraphVertexCentricComputeWithAggOp( + GraphViewDesc graphViewDesc, + IncVertexCentricAggCompute vcAlgorithm) { + super(graphViewDesc, vcAlgorithm); + aggregator = new GraphVertexCentricOpAggregator(this); + } - @Override - public void open(OpContext opContext) { - super.open(opContext); - aggregator.open((VertexCentricAggContextFunction) incVCComputeFunction); - } + @Override + public void open(OpContext opContext) { + super.open(opContext); + aggregator.open((VertexCentricAggContextFunction) incVCComputeFunction); + } - @Override - public void initIteration(long iteration) { - super.initIteration(iteration); - aggregator.initIteration(iteration); - } + @Override + public void initIteration(long iteration) { + super.initIteration(iteration); + aggregator.initIteration(iteration); + } - public void finishIteration(long iteration) { - super.finishIteration(iteration); - aggregator.finishIteration(iteration); - } - - @Override - public void processAggregateResult(GR result) { - aggregator.processAggregateResult(result); - } + public void finishIteration(long iteration) { + super.finishIteration(iteration); + aggregator.finishIteration(iteration); + } + @Override + public void processAggregateResult(GR result) { + aggregator.processAggregateResult(result); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/cache/TemporaryGraphCache.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/cache/TemporaryGraphCache.java index c6030489c..2b5a2920a 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/cache/TemporaryGraphCache.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/cache/TemporaryGraphCache.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.Map; import java.util.Set; + import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; import org.slf4j.Logger; @@ -32,46 +33,45 @@ public class TemporaryGraphCache { - private static final Logger LOGGER = LoggerFactory.getLogger(TemporaryGraphCache.class); + private static final Logger LOGGER = LoggerFactory.getLogger(TemporaryGraphCache.class); - private final Set vertexIds; - private final Map> vertices; - private final Map>> vertexEdges; + private final Set vertexIds; + private final Map> vertices; + private final Map>> vertexEdges; - public TemporaryGraphCache() { - this.vertexIds = new HashSet<>(); - this.vertices = new HashMap<>(); - this.vertexEdges = new HashMap<>(); - } + public TemporaryGraphCache() { + this.vertexIds = new HashSet<>(); + this.vertices = new HashMap<>(); + this.vertexEdges = new HashMap<>(); + } - public void addVertex(IVertex vertex) { - this.vertexIds.add(vertex.getId()); - this.vertices.put(vertex.getId(), vertex); - } + public void addVertex(IVertex vertex) { + this.vertexIds.add(vertex.getId()); + this.vertices.put(vertex.getId(), vertex); + } - public IVertex getVertex(K vId) { - return this.vertices.get(vId); - } + public IVertex getVertex(K vId) { + return this.vertices.get(vId); + } - public void addEdge(IEdge edge) { - this.vertexIds.add(edge.getSrcId()); - List> edges = this.vertexEdges.getOrDefault(edge.getSrcId(), - new ArrayList<>()); - edges.add(edge); - this.vertexEdges.put(edge.getSrcId(), edges); - } + public void addEdge(IEdge edge) { + this.vertexIds.add(edge.getSrcId()); + List> edges = this.vertexEdges.getOrDefault(edge.getSrcId(), new ArrayList<>()); + edges.add(edge); + this.vertexEdges.put(edge.getSrcId(), edges); + } - public List> getEdges(K vId) { - return this.vertexEdges.get(vId); - } + public List> getEdges(K vId) { + return this.vertexEdges.get(vId); + } - public Set getAllEvolveVId() { - return this.vertexIds; - } + public Set getAllEvolveVId() { + return this.vertexIds; + } - public void clear() { - this.vertexIds.clear(); - this.vertices.clear(); - this.vertexEdges.clear(); - } + public void clear() { + this.vertexIds.clear(); + this.vertices.clear(); + this.vertexEdges.clear(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/statical/AbstractStaticGraphVertexCentricOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/statical/AbstractStaticGraphVertexCentricOp.java index 769d47cb4..e9d06b4fa 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/statical/AbstractStaticGraphVertexCentricOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/statical/AbstractStaticGraphVertexCentricOp.java @@ -30,45 +30,46 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class AbstractStaticGraphVertexCentricOp> +public abstract class AbstractStaticGraphVertexCentricOp< + K, VV, EV, M, FUNC extends VertexCentricAlgo> extends AbstractGraphVertexCentricOp { - private static final Logger LOGGER = LoggerFactory.getLogger( - AbstractStaticGraphVertexCentricOp.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(AbstractStaticGraphVertexCentricOp.class); - public AbstractStaticGraphVertexCentricOp(GraphViewDesc graphViewDesc, FUNC func) { - super(graphViewDesc, func); - } + public AbstractStaticGraphVertexCentricOp(GraphViewDesc graphViewDesc, FUNC func) { + super(graphViewDesc, func); + } - @Override - public void addVertex(IVertex vertex) { - if (enableDebug) { - LOGGER.info("taskId:{} add vertex:{}", taskId, vertex); - } - this.graphState.staticGraph().V().add(vertex); - this.opInputMeter.mark(); + @Override + public void addVertex(IVertex vertex) { + if (enableDebug) { + LOGGER.info("taskId:{} add vertex:{}", taskId, vertex); } + this.graphState.staticGraph().V().add(vertex); + this.opInputMeter.mark(); + } - @Override - public void addEdge(IEdge edge) { - if (enableDebug) { - LOGGER.info("taskId:{} add edge:{}", taskId, edge); - } - this.graphState.staticGraph().E().add(edge); - this.opInputMeter.mark(); + @Override + public void addEdge(IEdge edge) { + if (enableDebug) { + LOGGER.info("taskId:{} add edge:{}", taskId, edge); } + this.graphState.staticGraph().E().add(edge); + this.opInputMeter.mark(); + } - @Override - protected GraphStateDescriptor buildGraphStateDesc(String name) { - GraphStateDescriptor desc = super.buildGraphStateDesc(name); - desc.withDataModel(graphViewDesc.isStatic() ? DataModel.STATIC_GRAPH : DataModel.DYNAMIC_GRAPH); - if (graphViewDesc.getGraphMetaType() != null) { - desc.withGraphMeta(new GraphMeta(graphViewDesc.getGraphMetaType())); - } - return desc; + @Override + protected GraphStateDescriptor buildGraphStateDesc(String name) { + GraphStateDescriptor desc = super.buildGraphStateDesc(name); + desc.withDataModel(graphViewDesc.isStatic() ? DataModel.STATIC_GRAPH : DataModel.DYNAMIC_GRAPH); + if (graphViewDesc.getGraphMetaType() != null) { + desc.withGraphMeta(new GraphMeta(graphViewDesc.getGraphMetaType())); } + return desc; + } - public GraphViewDesc getGraphViewDesc() { - return graphViewDesc; - } + public GraphViewDesc getGraphViewDesc() { + return graphViewDesc; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/statical/StaticGraphVertexCentricComputeOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/statical/StaticGraphVertexCentricComputeOp.java index e67fbc848..c72fec3a8 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/statical/StaticGraphVertexCentricComputeOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/statical/StaticGraphVertexCentricComputeOp.java @@ -22,6 +22,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.iterator.RichIteratorFunction; import org.apache.geaflow.api.graph.base.algo.AbstractVertexCentricComputeAlgo; @@ -42,113 +43,120 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class StaticGraphVertexCentricComputeOp> - extends AbstractStaticGraphVertexCentricOp> { +public class StaticGraphVertexCentricComputeOp< + K, VV, EV, M, FUNC extends VertexCentricComputeFunction> + extends AbstractStaticGraphVertexCentricOp< + K, VV, EV, M, AbstractVertexCentricComputeAlgo> { - private static final Logger LOGGER = LoggerFactory.getLogger(StaticGraphVertexCentricComputeOp.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(StaticGraphVertexCentricComputeOp.class); - protected GraphVCComputeCtxImpl graphVCComputeCtx; - protected VertexCentricComputeFunction vcComputeFunction; + protected GraphVCComputeCtxImpl graphVCComputeCtx; + protected VertexCentricComputeFunction vcComputeFunction; - private ICollector> vertexCollector; + private ICollector> vertexCollector; - public StaticGraphVertexCentricComputeOp(GraphViewDesc graphViewDesc, - AbstractVertexCentricComputeAlgo vcAlgorithm) { - super(graphViewDesc, vcAlgorithm); - opArgs.setOpType(OpType.VERTEX_CENTRIC_COMPUTE); - } + public StaticGraphVertexCentricComputeOp( + GraphViewDesc graphViewDesc, + AbstractVertexCentricComputeAlgo vcAlgorithm) { + super(graphViewDesc, vcAlgorithm); + opArgs.setOpType(OpType.VERTEX_CENTRIC_COMPUTE); + } - @Override - public void open(OpContext opContext) { - super.open(opContext); + @Override + public void open(OpContext opContext) { + super.open(opContext); - this.vcComputeFunction = this.function.getComputeFunction(); + this.vcComputeFunction = this.function.getComputeFunction(); - this.graphVCComputeCtx = new GraphVCComputeCtxImpl( + this.graphVCComputeCtx = + new GraphVCComputeCtxImpl( opContext, this.runtimeContext, this.graphState, this.graphMsgBox, this.maxIterations); - this.vcComputeFunction.init(this.graphVCComputeCtx); - - for (ICollector collector : this.collectors) { - if (!collector.getTag().equals(GraphRecordNames.Message.name()) - && !collector.getTag().equals(GraphRecordNames.Aggregate.name())) { - vertexCollector = collector; - } - } + this.vcComputeFunction.init(this.graphVCComputeCtx); + for (ICollector collector : this.collectors) { + if (!collector.getTag().equals(GraphRecordNames.Message.name()) + && !collector.getTag().equals(GraphRecordNames.Aggregate.name())) { + vertexCollector = collector; + } } - - @Override - public void doFinishIteration(long iterations) { - - // Compute. - if (iterations == 1L) { - Iterator> vertexIterator = this.graphState.staticGraph().V().iterator(); - while (vertexIterator.hasNext()) { - IVertex vertex = vertexIterator.next(); - K vertexId = vertex.getId(); - graphVCComputeCtx.init(iterations, vertexId); - vcComputeFunction.compute(vertexId, Collections.emptyIterator()); - } - } else { - this.graphMsgBox.processInMessage(new MsgProcessFunc() { - @Override - public void process(K vertexId, List ms) { - graphVCComputeCtx.init(iterations, vertexId); - vcComputeFunction.compute(vertexId, ms.iterator()); - } - }); - this.graphMsgBox.clearInBox(); - } - if (vcComputeFunction instanceof RichIteratorFunction) { - ((RichIteratorFunction) vcComputeFunction).finishIteration(iterations); - } - // Emit message. - this.graphMsgBox.processOutMessage(new MsgProcessFunc() { + } + + @Override + public void doFinishIteration(long iterations) { + + // Compute. + if (iterations == 1L) { + Iterator> vertexIterator = this.graphState.staticGraph().V().iterator(); + while (vertexIterator.hasNext()) { + IVertex vertex = vertexIterator.next(); + K vertexId = vertex.getId(); + graphVCComputeCtx.init(iterations, vertexId); + vcComputeFunction.compute(vertexId, Collections.emptyIterator()); + } + } else { + this.graphMsgBox.processInMessage( + new MsgProcessFunc() { @Override - public void process(K vertexId, List messages) { - // Collect message. - int size = messages.size(); - for (int i = 0; i < size; i++) { - messageCollector.partition(vertexId, new DefaultGraphMessage<>(vertexId, messages.get(i))); - } + public void process(K vertexId, List ms) { + graphVCComputeCtx.init(iterations, vertexId); + vcComputeFunction.compute(vertexId, ms.iterator()); } - }); - messageCollector.finish(); - this.graphMsgBox.clearOutBox(); + }); + this.graphMsgBox.clearInBox(); } - - - @Override - public void finish() { - try (CloseableIterator> vertexIterator = graphState.staticGraph().V().query().iterator()) { - while (vertexIterator.hasNext()) { - IVertex vertex = vertexIterator.next(); - vertexCollector.partition(vertex.getId(), vertex); + if (vcComputeFunction instanceof RichIteratorFunction) { + ((RichIteratorFunction) vcComputeFunction).finishIteration(iterations); + } + // Emit message. + this.graphMsgBox.processOutMessage( + new MsgProcessFunc() { + @Override + public void process(K vertexId, List messages) { + // Collect message. + int size = messages.size(); + for (int i = 0; i < size; i++) { + messageCollector.partition( + vertexId, new DefaultGraphMessage<>(vertexId, messages.get(i))); } - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } - this.vcComputeFunction.finish(); - vertexCollector.finish(); + } + }); + messageCollector.finish(); + this.graphMsgBox.clearOutBox(); + } + + @Override + public void finish() { + try (CloseableIterator> vertexIterator = + graphState.staticGraph().V().query().iterator()) { + while (vertexIterator.hasNext()) { + IVertex vertex = vertexIterator.next(); + vertexCollector.partition(vertex.getId(), vertex); + } + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } - - class GraphVCComputeCtxImpl extends StaticGraphContextImpl implements VertexCentricComputeFuncContext { - - public GraphVCComputeCtxImpl(OpContext opContext, - RuntimeContext runtimeContext, - GraphState graphState, - IGraphMsgBox graphMsgBox, - long maxIteration) { - super(opContext, runtimeContext, graphState, graphMsgBox, maxIteration); - } - - @Override - public void setNewVertexValue(VV value) { - IVertex valueVertex = graphState.staticGraph().V().query(vertexId).get(); - valueVertex = valueVertex.withValue(value); - graphState.staticGraph().V().add(valueVertex); - } - + this.vcComputeFunction.finish(); + vertexCollector.finish(); + } + + class GraphVCComputeCtxImpl extends StaticGraphContextImpl + implements VertexCentricComputeFuncContext { + + public GraphVCComputeCtxImpl( + OpContext opContext, + RuntimeContext runtimeContext, + GraphState graphState, + IGraphMsgBox graphMsgBox, + long maxIteration) { + super(opContext, runtimeContext, graphState, graphMsgBox, maxIteration); } + @Override + public void setNewVertexValue(VV value) { + IVertex valueVertex = graphState.staticGraph().V().query(vertexId).get(); + valueVertex = valueVertex.withValue(value); + graphState.staticGraph().V().add(valueVertex); + } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/statical/StaticGraphVertexCentricComputeWithAggOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/statical/StaticGraphVertexCentricComputeWithAggOp.java index 46be92e9c..ae5903ce7 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/statical/StaticGraphVertexCentricComputeWithAggOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/statical/StaticGraphVertexCentricComputeWithAggOp.java @@ -30,40 +30,43 @@ import org.slf4j.LoggerFactory; public class StaticGraphVertexCentricComputeWithAggOp - extends StaticGraphVertexCentricComputeOp> + extends StaticGraphVertexCentricComputeOp< + K, VV, EV, M, VertexCentricAggComputeFunction> implements IGraphVertexCentricAggOp { - private static final Logger LOGGER = LoggerFactory.getLogger(StaticGraphVertexCentricComputeWithAggOp.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(StaticGraphVertexCentricComputeWithAggOp.class); - private GraphVertexCentricOpAggregator> aggregator; + private GraphVertexCentricOpAggregator< + K, VV, EV, M, I, ?, ?, ?, GR, VertexCentricAggTraversal> + aggregator; - public StaticGraphVertexCentricComputeWithAggOp(GraphViewDesc graphViewDesc, - VertexCentricAggCompute vcAlgorithm) { - super(graphViewDesc, vcAlgorithm); - aggregator = new GraphVertexCentricOpAggregator(this); - } + public StaticGraphVertexCentricComputeWithAggOp( + GraphViewDesc graphViewDesc, + VertexCentricAggCompute vcAlgorithm) { + super(graphViewDesc, vcAlgorithm); + aggregator = new GraphVertexCentricOpAggregator(this); + } - @Override - public void open(OpContext opContext) { - super.open(opContext); - aggregator.open((VertexCentricAggContextFunction) vcComputeFunction); - } + @Override + public void open(OpContext opContext) { + super.open(opContext); + aggregator.open((VertexCentricAggContextFunction) vcComputeFunction); + } - @Override - public void initIteration(long iteration) { - super.initIteration(iteration); - aggregator.initIteration(iteration); - } + @Override + public void initIteration(long iteration) { + super.initIteration(iteration); + aggregator.initIteration(iteration); + } - public void finishIteration(long iteration) { - super.finishIteration(iteration); - aggregator.finishIteration(iteration); - } - - @Override - public void processAggregateResult(GR result) { - aggregator.processAggregateResult(result); - } + public void finishIteration(long iteration) { + super.finishIteration(iteration); + aggregator.finishIteration(iteration); + } + @Override + public void processAggregateResult(GR result) { + aggregator.processAggregateResult(result); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/materialize/GraphViewMaterializeOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/materialize/GraphViewMaterializeOp.java index 46b440482..1a00ff065 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/materialize/GraphViewMaterializeOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/materialize/GraphViewMaterializeOp.java @@ -21,10 +21,9 @@ import static org.apache.geaflow.operator.Constants.GRAPH_VERSION; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; import java.io.IOException; import java.util.List; + import org.apache.geaflow.api.graph.materialize.GraphMaterializeFunction; import org.apache.geaflow.api.trait.CheckpointTrait; import org.apache.geaflow.api.trait.TransactionTrait; @@ -51,161 +50,171 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class GraphViewMaterializeOp extends AbstractOneInputOperator> implements TransactionTrait, CheckpointTrait { - - private static final Logger LOGGER = LoggerFactory.getLogger(GraphViewMaterializeOp.class); - - private final GraphViewDesc graphViewDesc; - - protected transient GraphState graphState; - - public GraphViewMaterializeOp(GraphViewDesc graphViewDesc) { - this.graphViewDesc = graphViewDesc; - } - - @Override - public void open(OpContext opContext) { - super.open(opContext); - String name = graphViewDesc.getName(); - String storeType = graphViewDesc.getBackend().name(); - GraphStateDescriptor descriptor = GraphStateDescriptor.build( - graphViewDesc.getName(), storeType); - descriptor.withDataModel(DataModel.DYNAMIC_GRAPH); - descriptor.withGraphMeta(new GraphMeta(graphViewDesc.getGraphMetaType())); - descriptor.withMetricGroup(runtimeContext.getMetric()); - - int maxPara = graphViewDesc.getShardNum(); - int taskPara = runtimeContext.getTaskArgs().getParallelism(); - Preconditions.checkArgument(taskPara <= maxPara, - String.format("task parallelism '%s' must be <= shard num(max parallelism) '%s'", - taskPara, maxPara)); - - int taskIndex = runtimeContext.getTaskArgs().getTaskIndex(); - KeyGroup keyGroup = KeyGroupAssignment.computeKeyGroupRangeForOperatorIndex(maxPara, - taskPara, taskIndex); - descriptor.withKeyGroup(keyGroup); - IKeyGroupAssigner keyGroupAssigner = KeyGroupAssignerFactory.createKeyGroupAssigner( - keyGroup, taskIndex, maxPara); - descriptor.withKeyGroupAssigner(keyGroupAssigner); - - int taskId = runtimeContext.getTaskArgs().getTaskId(); - LOGGER.info("opName:{} taskId:{} taskIndex:{} keyGroup:{}", name, taskId, - taskIndex, keyGroup); - this.graphState = StateFactory.buildGraphState(descriptor, - runtimeContext.getConfiguration()); - recover(); - this.function = new DynamicGraphMaterializeFunction<>(graphState); - } +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; - @Override - protected void process(Object record) throws Exception { - if (record instanceof IVertex) { - this.function.materializeVertex((IVertex) record); - } else { - this.function.materializeEdge((IEdge) record); - } +public class GraphViewMaterializeOp + extends AbstractOneInputOperator> + implements TransactionTrait, CheckpointTrait { + + private static final Logger LOGGER = LoggerFactory.getLogger(GraphViewMaterializeOp.class); + + private final GraphViewDesc graphViewDesc; + + protected transient GraphState graphState; + + public GraphViewMaterializeOp(GraphViewDesc graphViewDesc) { + this.graphViewDesc = graphViewDesc; + } + + @Override + public void open(OpContext opContext) { + super.open(opContext); + String name = graphViewDesc.getName(); + String storeType = graphViewDesc.getBackend().name(); + GraphStateDescriptor descriptor = + GraphStateDescriptor.build(graphViewDesc.getName(), storeType); + descriptor.withDataModel(DataModel.DYNAMIC_GRAPH); + descriptor.withGraphMeta(new GraphMeta(graphViewDesc.getGraphMetaType())); + descriptor.withMetricGroup(runtimeContext.getMetric()); + + int maxPara = graphViewDesc.getShardNum(); + int taskPara = runtimeContext.getTaskArgs().getParallelism(); + Preconditions.checkArgument( + taskPara <= maxPara, + String.format( + "task parallelism '%s' must be <= shard num(max parallelism) '%s'", taskPara, maxPara)); + + int taskIndex = runtimeContext.getTaskArgs().getTaskIndex(); + KeyGroup keyGroup = + KeyGroupAssignment.computeKeyGroupRangeForOperatorIndex(maxPara, taskPara, taskIndex); + descriptor.withKeyGroup(keyGroup); + IKeyGroupAssigner keyGroupAssigner = + KeyGroupAssignerFactory.createKeyGroupAssigner(keyGroup, taskIndex, maxPara); + descriptor.withKeyGroupAssigner(keyGroupAssigner); + + int taskId = runtimeContext.getTaskArgs().getTaskId(); + LOGGER.info("opName:{} taskId:{} taskIndex:{} keyGroup:{}", name, taskId, taskIndex, keyGroup); + this.graphState = StateFactory.buildGraphState(descriptor, runtimeContext.getConfiguration()); + recover(); + this.function = new DynamicGraphMaterializeFunction<>(graphState); + } + + @Override + protected void process(Object record) throws Exception { + if (record instanceof IVertex) { + this.function.materializeVertex((IVertex) record); + } else { + this.function.materializeEdge((IEdge) record); } - - @Override - public void checkpoint(long windowId) { - long checkpointId = graphViewDesc.getCheckpoint(windowId); - this.graphState.manage().operate().setCheckpointId(checkpointId); - this.graphState.manage().operate().finish(); - this.graphState.manage().operate().archive(); - - if (graphViewDesc.getBackend() == BackendType.Paimon) { - int taskIndex = runtimeContext.getTaskArgs().getTaskIndex(); - - PaimonCommitRegistry registry = PaimonCommitRegistry.getInstance(); - List messages = registry.pollMessages(taskIndex); - if (messages != null && !messages.isEmpty()) { - for (TaskCommitMessage message : messages) { - List msg = message.getMessages(); - LOGGER.info("task {} emits windowId:{} chkId:{} table:{} messages:{}", - taskIndex, windowId, checkpointId, message.getTableName(), - msg.size()); - collectValue(new PaimonMessage(checkpointId, message.getTableName(), - msg)); - } - } + } + + @Override + public void checkpoint(long windowId) { + long checkpointId = graphViewDesc.getCheckpoint(windowId); + this.graphState.manage().operate().setCheckpointId(checkpointId); + this.graphState.manage().operate().finish(); + this.graphState.manage().operate().archive(); + + if (graphViewDesc.getBackend() == BackendType.Paimon) { + int taskIndex = runtimeContext.getTaskArgs().getTaskIndex(); + + PaimonCommitRegistry registry = PaimonCommitRegistry.getInstance(); + List messages = registry.pollMessages(taskIndex); + if (messages != null && !messages.isEmpty()) { + for (TaskCommitMessage message : messages) { + List msg = message.getMessages(); + LOGGER.info( + "task {} emits windowId:{} chkId:{} table:{} messages:{}", + taskIndex, + windowId, + checkpointId, + message.getTableName(), + msg.size()); + collectValue(new PaimonMessage(checkpointId, message.getTableName(), msg)); } - LOGGER.info("do checkpoint over, checkpointId: {}", checkpointId); + } } + LOGGER.info("do checkpoint over, checkpointId: {}", checkpointId); + } - @Override - public void finish(long windowId) { - } + @Override + public void finish(long windowId) {} - @Override - public void rollback(long batchId) { - recover(batchId); - } + @Override + public void rollback(long batchId) { + recover(batchId); + } - @Override - public void close() { - if (this.graphState != null) { - this.graphState.manage().operate().close(); - } + @Override + public void close() { + if (this.graphState != null) { + this.graphState.manage().operate().close(); } - - protected void recover() { - recover(this.runtimeContext.getWindowId()); + } + + protected void recover() { + recover(this.runtimeContext.getWindowId()); + } + + private void recover(long windowId) { + long lastCheckPointId; + try { + ViewMetaBookKeeper keeper = + new ViewMetaBookKeeper(graphViewDesc.getName(), this.runtimeContext.getConfiguration()); + lastCheckPointId = keeper.getLatestViewVersion(graphViewDesc.getName()); + LOGGER.info( + "opName: {} will do recover, ViewMetaBookKeeper version: {}", + this.opArgs.getOpName(), + lastCheckPointId); + } catch (IOException e) { + throw new GeaflowRuntimeException(e); } - - private void recover(long windowId) { - long lastCheckPointId; - try { - ViewMetaBookKeeper keeper = new ViewMetaBookKeeper(graphViewDesc.getName(), - this.runtimeContext.getConfiguration()); - lastCheckPointId = keeper.getLatestViewVersion(graphViewDesc.getName()); - LOGGER.info("opName: {} will do recover, ViewMetaBookKeeper version: {}", - this.opArgs.getOpName(), lastCheckPointId); - } catch (IOException e) { - throw new GeaflowRuntimeException(e); - } - if (lastCheckPointId >= 0) { - LOGGER.info("opName: {} do recover to graph VersionId: {}", this.opArgs.getOpName(), - lastCheckPointId); - graphState.manage().operate().setCheckpointId(lastCheckPointId); - graphState.manage().operate().recover(); - } else { - //If the graph has a checkpoint, should we recover it - LOGGER.info("lastCheckPointId < 0, windowId: {}", windowId); - if (windowId > 1) { - //Recover checkpoint id for dynamic graph is init graph version add windowId - long recoverVersionId = graphViewDesc.getCheckpoint(windowId - 1); - LOGGER.info("opName: {} do recover to latestVersionId: {}", this.opArgs.getOpName(), - recoverVersionId); - graphState.manage().operate().setCheckpointId(recoverVersionId); - graphState.manage().operate().recover(); - } - } + if (lastCheckPointId >= 0) { + LOGGER.info( + "opName: {} do recover to graph VersionId: {}", + this.opArgs.getOpName(), + lastCheckPointId); + graphState.manage().operate().setCheckpointId(lastCheckPointId); + graphState.manage().operate().recover(); + } else { + // If the graph has a checkpoint, should we recover it + LOGGER.info("lastCheckPointId < 0, windowId: {}", windowId); + if (windowId > 1) { + // Recover checkpoint id for dynamic graph is init graph version add windowId + long recoverVersionId = graphViewDesc.getCheckpoint(windowId - 1); + LOGGER.info( + "opName: {} do recover to latestVersionId: {}", + this.opArgs.getOpName(), + recoverVersionId); + graphState.manage().operate().setCheckpointId(recoverVersionId); + graphState.manage().operate().recover(); + } } + } - public static class DynamicGraphMaterializeFunction implements - GraphMaterializeFunction { + public static class DynamicGraphMaterializeFunction + implements GraphMaterializeFunction { - private final GraphState graphState; + private final GraphState graphState; - public DynamicGraphMaterializeFunction(GraphState graphState) { - this.graphState = graphState; - } - - @Override - public void materializeVertex(IVertex vertex) { - graphState.dynamicGraph().V().add(GRAPH_VERSION, vertex); - } - - @Override - public void materializeEdge(IEdge edge) { - graphState.dynamicGraph().E().add(GRAPH_VERSION, edge); - } + public DynamicGraphMaterializeFunction(GraphState graphState) { + this.graphState = graphState; + } + @Override + public void materializeVertex(IVertex vertex) { + graphState.dynamicGraph().V().add(GRAPH_VERSION, vertex); } - @VisibleForTesting - public GraphState getGraphState() { - return graphState; + @Override + public void materializeEdge(IEdge edge) { + graphState.dynamicGraph().E().add(GRAPH_VERSION, edge); } + } + + @VisibleForTesting + public GraphState getGraphState() { + return graphState; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/materialize/PaimonGlobalSink.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/materialize/PaimonGlobalSink.java index 44ba94ae8..d2cbcb1b6 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/materialize/PaimonGlobalSink.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/materialize/PaimonGlobalSink.java @@ -23,6 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.RichWindowFunction; import org.apache.geaflow.api.function.io.SinkFunction; @@ -43,97 +44,101 @@ public class PaimonGlobalSink extends RichWindowFunction implements SinkFunction { - private static final Logger LOGGER = LoggerFactory.getLogger(PaimonGlobalSink.class); + private static final Logger LOGGER = LoggerFactory.getLogger(PaimonGlobalSink.class); - private String dbName; - private String jobName; - private long windowId; + private String dbName; + private String jobName; + private long windowId; - private List paimonMessages; - private Map writeBuilders; - private PaimonCatalogClient client; + private List paimonMessages; + private Map writeBuilders; + private PaimonCatalogClient client; - @Override - public void open(RuntimeContext runtimeContext) { - this.windowId = runtimeContext.getWindowId(); - Configuration jobConfig = runtimeContext.getConfiguration(); - this.client = PaimonCatalogManager.getCatalogClient(jobConfig); - this.jobName = jobConfig.getString(ExecutionConfigKeys.JOB_APP_NAME); - this.dbName = jobConfig.getString(PaimonConfigKeys.PAIMON_STORE_DATABASE); + @Override + public void open(RuntimeContext runtimeContext) { + this.windowId = runtimeContext.getWindowId(); + Configuration jobConfig = runtimeContext.getConfiguration(); + this.client = PaimonCatalogManager.getCatalogClient(jobConfig); + this.jobName = jobConfig.getString(ExecutionConfigKeys.JOB_APP_NAME); + this.dbName = jobConfig.getString(PaimonConfigKeys.PAIMON_STORE_DATABASE); - this.writeBuilders = new HashMap<>(); - this.paimonMessages = new ArrayList<>(); - LOGGER.info("init paimon sink with db {}", dbName); - } + this.writeBuilders = new HashMap<>(); + this.paimonMessages = new ArrayList<>(); + LOGGER.info("init paimon sink with db {}", dbName); + } - @Override - public void write(PaimonMessage message) throws Exception { - paimonMessages.add(message); - } + @Override + public void write(PaimonMessage message) throws Exception { + paimonMessages.add(message); + } - @Override - public void finish() { - if (paimonMessages.isEmpty()) { - LOGGER.info("commit windowId {} empty messages", windowId); - return; - } + @Override + public void finish() { + if (paimonMessages.isEmpty()) { + LOGGER.info("commit windowId {} empty messages", windowId); + return; + } - final long startTime = System.currentTimeMillis(); - final long checkpointId = paimonMessages.get(0).getCheckpointId(); - Map> tableMessages = new HashMap<>(); - for (PaimonMessage message : paimonMessages) { - List messages = message.getMessages(); - if (messages != null && !messages.isEmpty()) { - String tableName = message.getTableName(); - List commitMessages = tableMessages.computeIfAbsent(tableName, - k -> new ArrayList<>()); - commitMessages.addAll(messages); - } - } - long deserializeTime = System.currentTimeMillis() - startTime; - if (!tableMessages.isEmpty()) { - for (Map.Entry> entry : tableMessages.entrySet()) { - LOGGER.info("commit table:{} messages:{}", entry.getKey(), entry.getValue().size()); - StreamWriteBuilder writeBuilder = getWriteBuilder(entry.getKey()); - try (StreamTableCommit commit = writeBuilder.newCommit()) { - try { - commit.commit(checkpointId, entry.getValue()); - } catch (Throwable e) { - LOGGER.warn("commit failed: {}", e.getMessage(), e); - Map> commitIdAndMessages = new HashMap<>(); - commitIdAndMessages.put(checkpointId, entry.getValue()); - commit.filterAndCommit(commitIdAndMessages); - } - } catch (Throwable e) { - LOGGER.error("Failed to commit data into Paimon: {}", e.getMessage(), e); - throw new GeaflowRuntimeException("Failed to commit data into Paimon.", e); - } - } + final long startTime = System.currentTimeMillis(); + final long checkpointId = paimonMessages.get(0).getCheckpointId(); + Map> tableMessages = new HashMap<>(); + for (PaimonMessage message : paimonMessages) { + List messages = message.getMessages(); + if (messages != null && !messages.isEmpty()) { + String tableName = message.getTableName(); + List commitMessages = + tableMessages.computeIfAbsent(tableName, k -> new ArrayList<>()); + commitMessages.addAll(messages); + } + } + long deserializeTime = System.currentTimeMillis() - startTime; + if (!tableMessages.isEmpty()) { + for (Map.Entry> entry : tableMessages.entrySet()) { + LOGGER.info("commit table:{} messages:{}", entry.getKey(), entry.getValue().size()); + StreamWriteBuilder writeBuilder = getWriteBuilder(entry.getKey()); + try (StreamTableCommit commit = writeBuilder.newCommit()) { + try { + commit.commit(checkpointId, entry.getValue()); + } catch (Throwable e) { + LOGGER.warn("commit failed: {}", e.getMessage(), e); + Map> commitIdAndMessages = new HashMap<>(); + commitIdAndMessages.put(checkpointId, entry.getValue()); + commit.filterAndCommit(commitIdAndMessages); + } + } catch (Throwable e) { + LOGGER.error("Failed to commit data into Paimon: {}", e.getMessage(), e); + throw new GeaflowRuntimeException("Failed to commit data into Paimon.", e); } - LOGGER.info("committed chkId:{} messages:{} deserializeCost:{}ms", - checkpointId, paimonMessages.size(), deserializeTime); - paimonMessages.clear(); + } } + LOGGER.info( + "committed chkId:{} messages:{} deserializeCost:{}ms", + checkpointId, + paimonMessages.size(), + deserializeTime); + paimonMessages.clear(); + } - private StreamWriteBuilder getWriteBuilder(String tableName) { - return writeBuilders.computeIfAbsent(tableName, k -> { - try { - FileStoreTable table = (FileStoreTable) client.getTable(Identifier.create(dbName, - tableName)); - return table.newStreamWriteBuilder().withCommitUser(jobName); - } catch (Throwable e) { - String msg = String.format("%s.%s not exist.", dbName, tableName); - throw new GeaflowRuntimeException(msg, e); - } + private StreamWriteBuilder getWriteBuilder(String tableName) { + return writeBuilders.computeIfAbsent( + tableName, + k -> { + try { + FileStoreTable table = + (FileStoreTable) client.getTable(Identifier.create(dbName, tableName)); + return table.newStreamWriteBuilder().withCommitUser(jobName); + } catch (Throwable e) { + String msg = String.format("%s.%s not exist.", dbName, tableName); + throw new GeaflowRuntimeException(msg, e); + } }); - } + } - @Override - public void close() { - LOGGER.info("close sink"); - if (client != null) { - client.close(); - } + @Override + public void close() { + LOGGER.info("close sink"); + if (client != null) { + client.close(); } - + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/source/GraphSourceOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/source/GraphSourceOperator.java index ef87a7278..c404fb65e 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/source/GraphSourceOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/source/GraphSourceOperator.java @@ -19,11 +19,11 @@ package org.apache.geaflow.operator.impl.graph.source; - import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; + import org.apache.geaflow.api.function.io.GraphSourceFunction; import org.apache.geaflow.api.function.io.GraphSourceFunction.GraphSourceContext; import org.apache.geaflow.collector.ICollector; @@ -38,123 +38,122 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class GraphSourceOperator extends WindowSourceOperator, - IEdge>> { - - public static final String EDGE_TAG = "edge"; - public static final String VERTEX_TAG = "vertex"; - private static final Logger LOGGER = LoggerFactory.getLogger(GraphSourceOperator.class); - private GraphSourceContext sourceCxt; - private GraphSourceFunction sourceFunction; - // Metrics. - private long edgeCnt; - private long vertexCnt; - private long filteredVertexCnt; - private boolean isDedupEnabled; - private Set vertexIdSet; - protected Meter vertexTps; - protected Meter edgeTps; - - - public GraphSourceOperator() { +public class GraphSourceOperator + extends WindowSourceOperator, IEdge>> { + + public static final String EDGE_TAG = "edge"; + public static final String VERTEX_TAG = "vertex"; + private static final Logger LOGGER = LoggerFactory.getLogger(GraphSourceOperator.class); + private GraphSourceContext sourceCxt; + private GraphSourceFunction sourceFunction; + // Metrics. + private long edgeCnt; + private long vertexCnt; + private long filteredVertexCnt; + private boolean isDedupEnabled; + private Set vertexIdSet; + protected Meter vertexTps; + protected Meter edgeTps; + + public GraphSourceOperator() {} + + public GraphSourceOperator(GraphSourceFunction sourceFunction) { + super(); + this.sourceFunction = sourceFunction; + this.vertexIdSet = new HashSet<>(); + } + + @Override + public void open(OpContext opContext) { + super.open(opContext); + this.vertexTps = + this.metricGroup.meter( + MetricNameFormatter.vertexTpsMetricName(this.getClass(), this.opArgs.getOpId())); + this.edgeTps = + this.metricGroup.meter( + MetricNameFormatter.edgeTpsMetricName(this.getClass(), this.opArgs.getOpId())); + this.sourceCxt = new DefaultGraphSourceContext(); + } + + public void emitRecord(long batchId) { + try { + + this.sourceFunction.fetch(batchId, sourceCxt); + this.vertexIdSet.clear(); + LOGGER.info("totalVertex: {}, filteredVertex: {}", vertexCnt, filteredVertexCnt); + } catch (Exception e) { + LOGGER.error(e.getMessage(), e); + throw new GeaflowRuntimeException(e); } - - public GraphSourceOperator(GraphSourceFunction sourceFunction) { - super(); - this.sourceFunction = sourceFunction; - this.vertexIdSet = new HashSet<>(); + } + + private boolean filterVertex(K vertexId) { + if (vertexIdSet.contains(vertexId)) { + filteredVertexCnt++; + return true; + } else { + vertexIdSet.add(vertexId); + return false; } + } - @Override - public void open(OpContext opContext) { - super.open(opContext); - this.vertexTps = this.metricGroup.meter( - MetricNameFormatter.vertexTpsMetricName(this.getClass(), this.opArgs.getOpId())); - this.edgeTps = this.metricGroup.meter( - MetricNameFormatter.edgeTpsMetricName(this.getClass(), this.opArgs.getOpId())); - this.sourceCxt = new DefaultGraphSourceContext(); + class DefaultGraphSourceContext implements GraphSourceContext { + + private final List vertexCollectors; + private final List edgeCollectors; + public DefaultGraphSourceContext() { + this.vertexCollectors = new ArrayList<>(); + this.edgeCollectors = new ArrayList<>(); + filterCollectors(collectors, vertexCollectors, edgeCollectors); } - public void emitRecord(long batchId) { - try { + @Override + public void collectVertex(IVertex vertex) throws Exception { + if (isDedupEnabled && filterVertex(vertex.getId())) { + return; + } + collect(new GraphRecord<>(vertex)); + } - this.sourceFunction.fetch(batchId, sourceCxt); - this.vertexIdSet.clear(); - LOGGER.info("totalVertex: {}, filteredVertex: {}", vertexCnt, filteredVertexCnt); - } catch (Exception e) { - LOGGER.error(e.getMessage(), e); - throw new GeaflowRuntimeException(e); - } + @Override + public void collectEdge(IEdge edge) throws Exception { + collect(new GraphRecord<>(edge)); } - private boolean filterVertex(K vertexId) { - if (vertexIdSet.contains(vertexId)) { - filteredVertexCnt++; - return true; + private void filterCollectors( + List collectors, + List vertexCollectors, + List edgeCollectors) { + for (ICollector collector : collectors) { + int collectorId = collector.getId(); + String outputTag = outputTags.get(collectorId); + if (VERTEX_TAG.equals(outputTag)) { + vertexCollectors.add(collector); + } else if (EDGE_TAG.equals(outputTag)) { + edgeCollectors.add(collector); } else { - vertexIdSet.add(vertexId); - return false; + throw new GeaflowRuntimeException("unrecognized tag: " + outputTag); } + } } - class DefaultGraphSourceContext implements GraphSourceContext { - - private final List vertexCollectors; - private final List edgeCollectors; - - public DefaultGraphSourceContext() { - this.vertexCollectors = new ArrayList<>(); - this.edgeCollectors = new ArrayList<>(); - filterCollectors(collectors, vertexCollectors, edgeCollectors); - } - - @Override - public void collectVertex(IVertex vertex) throws Exception { - if (isDedupEnabled && filterVertex(vertex.getId())) { - return; - } - collect(new GraphRecord<>(vertex)); - } - - @Override - public void collectEdge(IEdge edge) throws Exception { - collect(new GraphRecord<>(edge)); - } - - private void filterCollectors(List collectors, - List vertexCollectors, - List edgeCollectors) { - for (ICollector collector : collectors) { - int collectorId = collector.getId(); - String outputTag = outputTags.get(collectorId); - if (VERTEX_TAG.equals(outputTag)) { - vertexCollectors.add(collector); - } else if (EDGE_TAG.equals(outputTag)) { - edgeCollectors.add(collector); - } else { - throw new GeaflowRuntimeException("unrecognized tag: " + outputTag); - } - } + @Override + public boolean collect(GraphRecord, IEdge> element) throws Exception { + if (element.getViewType() == ViewType.vertex) { + for (ICollector collector : vertexCollectors) { + collector.partition(element.getVertex().getId(), element.getVertex()); + vertexCnt++; + vertexTps.mark(); } - - @Override - public boolean collect(GraphRecord, IEdge> element) throws Exception { - if (element.getViewType() == ViewType.vertex) { - for (ICollector collector : vertexCollectors) { - collector.partition(element.getVertex().getId(), element.getVertex()); - vertexCnt++; - vertexTps.mark(); - } - } else { - for (ICollector collector : edgeCollectors) { - collector.partition(element.getEdge().getSrcId(), element.getEdge()); - edgeCnt++; - edgeTps.mark(); - } - } - return true; + } else { + for (ICollector collector : edgeCollectors) { + collector.partition(element.getEdge().getSrcId(), element.getEdge()); + edgeCnt++; + edgeTps.mark(); } + } + return true; } - + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/AbstractDynamicGraphVertexCentricTraversalOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/AbstractDynamicGraphVertexCentricTraversalOp.java index 39431faab..123e0d0c4 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/AbstractDynamicGraphVertexCentricTraversalOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/AbstractDynamicGraphVertexCentricTraversalOp.java @@ -24,6 +24,7 @@ import java.util.Iterator; import java.util.List; import java.util.Set; + import org.apache.geaflow.api.function.iterator.RichIteratorFunction; import org.apache.geaflow.api.graph.base.algo.AbstractIncVertexCentricTraversalAlgo; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricTraversalFunction; @@ -48,197 +49,204 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class AbstractDynamicGraphVertexCentricTraversalOp> - extends AbstractDynamicGraphVertexCentricOp> +public abstract class AbstractDynamicGraphVertexCentricTraversalOp< + K, VV, EV, M, R, FUNC extends IncVertexCentricTraversalFunction> + extends AbstractDynamicGraphVertexCentricOp< + K, VV, EV, M, AbstractIncVertexCentricTraversalAlgo> implements IGraphTraversalOp { - private static final Logger LOGGER = LoggerFactory.getLogger( - AbstractDynamicGraphVertexCentricTraversalOp.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(AbstractDynamicGraphVertexCentricTraversalOp.class); - protected IncGraphVCTraversalCtxImpl graphVCTraversalCtx; - protected IncVertexCentricTraversalFunction incVcTraversalFunction; + protected IncGraphVCTraversalCtxImpl graphVCTraversalCtx; + protected IncVertexCentricTraversalFunction incVcTraversalFunction; - protected boolean addInvokeVIdsEachIteration = false; - protected Set invokeVIds; - protected List> responses; + protected boolean addInvokeVIdsEachIteration = false; + protected Set invokeVIds; + protected List> responses; - protected ICollector> responseCollector; + protected ICollector> responseCollector; - protected final List> traversalRequests; + protected final List> traversalRequests; - public AbstractDynamicGraphVertexCentricTraversalOp( - GraphViewDesc graphViewDesc, - AbstractIncVertexCentricTraversalAlgo incVertexCentricTraversal) { - super(graphViewDesc, incVertexCentricTraversal); - opArgs.setOpType(OpType.INC_VERTEX_CENTRIC_TRAVERSAL); - this.traversalRequests = new ArrayList<>(); - } + public AbstractDynamicGraphVertexCentricTraversalOp( + GraphViewDesc graphViewDesc, + AbstractIncVertexCentricTraversalAlgo incVertexCentricTraversal) { + super(graphViewDesc, incVertexCentricTraversal); + opArgs.setOpType(OpType.INC_VERTEX_CENTRIC_TRAVERSAL); + this.traversalRequests = new ArrayList<>(); + } - @Override - public void open(OpContext opContext) { - super.open(opContext); - - this.incVcTraversalFunction = this.function.getIncTraversalFunction(); - this.graphVCTraversalCtx = new IncGraphVCTraversalCtxImpl(getIdentify(), messageCollector); - this.incVcTraversalFunction.open(this.graphVCTraversalCtx); - - this.addInvokeVIdsEachIteration = Configuration.getBoolean(FrameworkConfigKeys.ADD_INVOKE_VIDS_EACH_ITERATION, - opContext.getRuntimeContext().getConfiguration().getConfigMap()); - this.invokeVIds = new HashSet<>(); - this.responses = new ArrayList<>(); - - for (ICollector collector : this.collectors) { - if (!collector.getTag().equals(GraphRecordNames.Message.name()) - && !collector.getTag().equals(GraphRecordNames.Aggregate.name())) { - responseCollector = collector; - } - } - } + @Override + public void open(OpContext opContext) { + super.open(opContext); - @Override - public void doFinishIteration(long iterations) { - // Compute. - if (iterations == 1L) { - // Evolve. - Set vIds = temporaryGraphCache.getAllEvolveVId(); - this.invokeVIds.addAll(vIds); - for (K vId : vIds) { - this.graphVCTraversalCtx.init(iterations, vId); - this.incVcTraversalFunction.evolve(vId, this.graphVCTraversalCtx.getTemporaryGraph()); - } - traversalByRequest(); - } else { - this.graphMsgBox.processInMessage(new MsgProcessFunc() { - @Override - public void process(K vertexId, List messages) { - if (addInvokeVIdsEachIteration) { - invokeVIds.add(vertexId); - } - graphVCTraversalCtx.init(iterations, vertexId); - incVcTraversalFunction.compute(vertexId, messages.iterator()); - } - }); - this.graphMsgBox.clearInBox(); - } - if (this.incVcTraversalFunction instanceof RichIteratorFunction) { - ((RichIteratorFunction) this.incVcTraversalFunction).finishIteration(iterations); - } - // Emit message. - this.graphMsgBox.processOutMessage(new MsgProcessFunc() { + this.incVcTraversalFunction = this.function.getIncTraversalFunction(); + this.graphVCTraversalCtx = new IncGraphVCTraversalCtxImpl(getIdentify(), messageCollector); + this.incVcTraversalFunction.open(this.graphVCTraversalCtx); + + this.addInvokeVIdsEachIteration = + Configuration.getBoolean( + FrameworkConfigKeys.ADD_INVOKE_VIDS_EACH_ITERATION, + opContext.getRuntimeContext().getConfiguration().getConfigMap()); + this.invokeVIds = new HashSet<>(); + this.responses = new ArrayList<>(); + + for (ICollector collector : this.collectors) { + if (!collector.getTag().equals(GraphRecordNames.Message.name()) + && !collector.getTag().equals(GraphRecordNames.Aggregate.name())) { + responseCollector = collector; + } + } + } + + @Override + public void doFinishIteration(long iterations) { + // Compute. + if (iterations == 1L) { + // Evolve. + Set vIds = temporaryGraphCache.getAllEvolveVId(); + this.invokeVIds.addAll(vIds); + for (K vId : vIds) { + this.graphVCTraversalCtx.init(iterations, vId); + this.incVcTraversalFunction.evolve(vId, this.graphVCTraversalCtx.getTemporaryGraph()); + } + traversalByRequest(); + } else { + this.graphMsgBox.processInMessage( + new MsgProcessFunc() { @Override public void process(K vertexId, List messages) { - // Collect message. - int size = messages.size(); - for (int i = 0; i < size; i++) { - messageCollector.partition(vertexId, new DefaultGraphMessage<>(vertexId, messages.get(i))); - } + if (addInvokeVIdsEachIteration) { + invokeVIds.add(vertexId); + } + graphVCTraversalCtx.init(iterations, vertexId); + incVcTraversalFunction.compute(vertexId, messages.iterator()); + } + }); + this.graphMsgBox.clearInBox(); + } + if (this.incVcTraversalFunction instanceof RichIteratorFunction) { + ((RichIteratorFunction) this.incVcTraversalFunction).finishIteration(iterations); + } + // Emit message. + this.graphMsgBox.processOutMessage( + new MsgProcessFunc() { + @Override + public void process(K vertexId, List messages) { + // Collect message. + int size = messages.size(); + for (int i = 0; i < size; i++) { + messageCollector.partition( + vertexId, new DefaultGraphMessage<>(vertexId, messages.get(i))); } + } }); - messageCollector.finish(); - this.graphMsgBox.clearOutBox(); + messageCollector.finish(); + this.graphMsgBox.clearOutBox(); + } + + protected void traversalByRequest() { + Iterator> iterator = getTraversalRequests(); + while (iterator.hasNext()) { + ITraversalRequest traversalRequest = iterator.next(); + K vertexId = traversalRequest.getVId(); + this.graphVCTraversalCtx.init(iterations, vertexId); + this.incVcTraversalFunction.init(traversalRequest); + } + } + + @Override + public void finish() { + LOGGER.info("current batch invokeIds size:{}", this.invokeVIds.size()); + for (K vertexId : this.invokeVIds) { + this.graphVCTraversalCtx.init(iterations, vertexId); + this.incVcTraversalFunction.finish(vertexId, this.graphVCTraversalCtx.getMutableGraph()); + } + this.temporaryGraphCache.clear(); + this.invokeVIds.clear(); + this.traversalRequests.clear(); + + LOGGER.info( + "incVcTraversalFunction finish, windowId:{}, invokeIds size:{}", + this.windowId, + this.invokeVIds.size()); + incVcTraversalFunction.finish(); + LOGGER.info("incVcTraversalFunction has finish windowId:{}", this.windowId); + + for (ITraversalResponse response : this.responses) { + responseCollector.partition(response.getResponseId(), response); + } + responseCollector.finish(); + responses.clear(); + checkpoint(); + LOGGER.info("TraversalOp has finish windowId:{}", this.windowId); + } + + @Override + public void close() { + super.close(); + incVcTraversalFunction.close(); + this.responses.clear(); + } + + public class IncGraphVCTraversalCtxImpl extends IncGraphContextImpl + implements IncVertexCentricTraversalFuncContext { + + private final ICollector> messageCollector; + private final String opName; + private final TraversalHistoricalGraph traversalHistoricalGraph; + private boolean enableIncrMatch; + + protected IncGraphVCTraversalCtxImpl( + String opName, ICollector> messageCollector) { + super(opContext, runtimeContext, graphState, temporaryGraphCache, graphMsgBox, maxIterations); + this.opName = opName; + this.messageCollector = messageCollector; + this.traversalHistoricalGraph = + new TraversalIncHistoricalGraph<>( + (IncHistoricalGraph) super.getHistoricalGraph()); } - protected void traversalByRequest() { - Iterator> iterator = getTraversalRequests(); - while (iterator.hasNext()) { - ITraversalRequest traversalRequest = iterator.next(); - K vertexId = traversalRequest.getVId(); - this.graphVCTraversalCtx.init(iterations, vertexId); - this.incVcTraversalFunction.init(traversalRequest); - } + public boolean isEnableIncrMatch() { + return enableIncrMatch; } - @Override - public void finish() { - LOGGER.info("current batch invokeIds size:{}", this.invokeVIds.size()); - for (K vertexId : this.invokeVIds) { - this.graphVCTraversalCtx.init(iterations, vertexId); - this.incVcTraversalFunction.finish(vertexId, this.graphVCTraversalCtx.getMutableGraph()); - } - this.temporaryGraphCache.clear(); - this.invokeVIds.clear(); - this.traversalRequests.clear(); - - LOGGER.info("incVcTraversalFunction finish, windowId:{}, invokeIds size:{}", this.windowId, this.invokeVIds.size()); - incVcTraversalFunction.finish(); - LOGGER.info("incVcTraversalFunction has finish windowId:{}", this.windowId); - - for (ITraversalResponse response : this.responses) { - responseCollector.partition(response.getResponseId(), response); - } - responseCollector.finish(); - responses.clear(); - checkpoint(); - LOGGER.info("TraversalOp has finish windowId:{}", this.windowId); + public IncGraphVCTraversalCtxImpl setEnableIncrMatch(boolean enableIncrMatch) { + this.enableIncrMatch = enableIncrMatch; + return this; } @Override - public void close() { - super.close(); - incVcTraversalFunction.close(); - this.responses.clear(); + public void activeRequest(ITraversalRequest request) {} + + @Override + public void takeResponse(ITraversalResponse response) { + responses.add(response); } - public class IncGraphVCTraversalCtxImpl extends IncGraphContextImpl - implements IncVertexCentricTraversalFuncContext { - - private final ICollector> messageCollector; - private final String opName; - private final TraversalHistoricalGraph traversalHistoricalGraph; - private boolean enableIncrMatch; - - protected IncGraphVCTraversalCtxImpl(String opName, - ICollector> messageCollector) { - super(opContext, runtimeContext, graphState, temporaryGraphCache, graphMsgBox, maxIterations); - this.opName = opName; - this.messageCollector = messageCollector; - this.traversalHistoricalGraph = new TraversalIncHistoricalGraph<>( - (IncHistoricalGraph) super.getHistoricalGraph()); - } - - public boolean isEnableIncrMatch() { - return enableIncrMatch; - } - - public IncGraphVCTraversalCtxImpl setEnableIncrMatch(boolean enableIncrMatch) { - this.enableIncrMatch = enableIncrMatch; - return this; - } - - @Override - public void activeRequest(ITraversalRequest request) { - - } - - @Override - public void takeResponse(ITraversalResponse response) { - responses.add(response); - } - - @Override - public void broadcast(IGraphMessage message) { - messageCollector.broadcast(message); - } - - @Override - public TraversalHistoricalGraph getHistoricalGraph() { - return traversalHistoricalGraph; - } - - @Override - public String getTraversalOpName() { - return opName; - } + @Override + public void broadcast(IGraphMessage message) { + messageCollector.broadcast(message); } - public void addRequest(ITraversalRequest request) { - LOGGER.info("add request:{}", request); - traversalRequests.add(request); + @Override + public TraversalHistoricalGraph getHistoricalGraph() { + return traversalHistoricalGraph; } - public Iterator> getTraversalRequests() { - return traversalRequests.iterator(); + @Override + public String getTraversalOpName() { + return opName; } + } + + public void addRequest(ITraversalRequest request) { + LOGGER.info("add request:{}", request); + traversalRequests.add(request); + } + + public Iterator> getTraversalRequests() { + return traversalRequests.iterator(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphHelper.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphHelper.java index 174323104..35630d9f7 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphHelper.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphHelper.java @@ -19,38 +19,40 @@ package org.apache.geaflow.operator.impl.graph.traversal.dynamic; - import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.DSLConfigKeys; public class DynamicGraphHelper { - public static boolean enableIncrTraversal(int maxIterationCount, int startIdSize, Configuration configuration) { - if (configuration != null) { - boolean res = configuration.getBoolean(DSLConfigKeys.ENABLE_INCR_TRAVERSAL); - if (!res) { - return false; - } - } - - int traversalThreshold = configuration.getInteger(DSLConfigKeys.INCR_TRAVERSAL_ITERATION_THRESHOLD); - // when maxIterationCount <=2 no need to include subGraph, since 1 hop is already included in the incr edges. - return maxIterationCount > 2 && maxIterationCount <= traversalThreshold && startIdSize == 0; + public static boolean enableIncrTraversal( + int maxIterationCount, int startIdSize, Configuration configuration) { + if (configuration != null) { + boolean res = configuration.getBoolean(DSLConfigKeys.ENABLE_INCR_TRAVERSAL); + if (!res) { + return false; + } } - public static boolean enableIncrTraversalRuntime(RuntimeContext runtimeContext) { - long windowId = runtimeContext.getWindowId(); - if (windowId == 1) { - // the first window not need evolve - return false; - } - long window = runtimeContext.getConfiguration().getLong(DSLConfigKeys.INCR_TRAVERSAL_WINDOW); - if (window == -1) { - // default do incr - return true; - } else { - return windowId > window; - } + int traversalThreshold = + configuration.getInteger(DSLConfigKeys.INCR_TRAVERSAL_ITERATION_THRESHOLD); + // when maxIterationCount <=2 no need to include subGraph, since 1 hop is already included in + // the incr edges. + return maxIterationCount > 2 && maxIterationCount <= traversalThreshold && startIdSize == 0; + } + + public static boolean enableIncrTraversalRuntime(RuntimeContext runtimeContext) { + long windowId = runtimeContext.getWindowId(); + if (windowId == 1) { + // the first window not need evolve + return false; + } + long window = runtimeContext.getConfiguration().getLong(DSLConfigKeys.INCR_TRAVERSAL_WINDOW); + if (window == -1) { + // default do incr + return true; + } else { + return windowId > window; } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalAllOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalAllOp.java index 08fa92c57..e12c1b479 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalAllOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalAllOp.java @@ -28,48 +28,50 @@ import org.apache.geaflow.model.traversal.impl.VertexBeginTraversalRequest; import org.apache.geaflow.view.graph.GraphViewDesc; -public class DynamicGraphVertexCentricTraversalAllOp> +public class DynamicGraphVertexCentricTraversalAllOp< + K, VV, EV, M, R, FUNC extends IncVertexCentricTraversalFunction> extends AbstractDynamicGraphVertexCentricTraversalOp { - public DynamicGraphVertexCentricTraversalAllOp( - GraphViewDesc graphViewDesc, - AbstractIncVertexCentricTraversalAlgo vcTraversal) { - super(graphViewDesc, vcTraversal); - } + public DynamicGraphVertexCentricTraversalAllOp( + GraphViewDesc graphViewDesc, + AbstractIncVertexCentricTraversalAlgo vcTraversal) { + super(graphViewDesc, vcTraversal); + } - private void traversalEvolveVIds() { - for (K vertexId : temporaryGraphCache.getAllEvolveVId()) { - ITraversalRequest traversalRequest = new VertexBeginTraversalRequest<>(vertexId); - this.graphVCTraversalCtx.init(iterations, vertexId); - this.incVcTraversalFunction.init(traversalRequest); - } + private void traversalEvolveVIds() { + for (K vertexId : temporaryGraphCache.getAllEvolveVId()) { + ITraversalRequest traversalRequest = new VertexBeginTraversalRequest<>(vertexId); + this.graphVCTraversalCtx.init(iterations, vertexId); + this.incVcTraversalFunction.init(traversalRequest); } + } - @Override - protected void traversalByRequest() { - if (graphVCTraversalCtx.isEnableIncrMatch() && DynamicGraphHelper.enableIncrTraversalRuntime(runtimeContext)) { - traversalEvolveVIds(); + @Override + protected void traversalByRequest() { + if (graphVCTraversalCtx.isEnableIncrMatch() + && DynamicGraphHelper.enableIncrTraversalRuntime(runtimeContext)) { + traversalEvolveVIds(); - } else { - if (function.getMaxIterationCount() <= 2) { - // The evolved vertices/edges can cover the match pattern when iteration <= 2 (e.g match(a)->(b)). - traversalEvolveVIds(); - return; - } + } else { + if (function.getMaxIterationCount() <= 2) { + // The evolved vertices/edges can cover the match pattern when iteration <= 2 (e.g + // match(a)->(b)). + traversalEvolveVIds(); + return; + } - // Traversal all vertices. - if (!temporaryGraphCache.getAllEvolveVId().isEmpty()) { - try (CloseableIterator idIterator = - graphState.dynamicGraph().V().query(GRAPH_VERSION, keyGroup).idIterator()) { - while (idIterator.hasNext()) { - K vertexId = idIterator.next(); - ITraversalRequest traversalRequest = new VertexBeginTraversalRequest<>(vertexId); - this.graphVCTraversalCtx.init(iterations, vertexId); - this.incVcTraversalFunction.init(traversalRequest); - } - } - } + // Traversal all vertices. + if (!temporaryGraphCache.getAllEvolveVId().isEmpty()) { + try (CloseableIterator idIterator = + graphState.dynamicGraph().V().query(GRAPH_VERSION, keyGroup).idIterator()) { + while (idIterator.hasNext()) { + K vertexId = idIterator.next(); + ITraversalRequest traversalRequest = new VertexBeginTraversalRequest<>(vertexId); + this.graphVCTraversalCtx.init(iterations, vertexId); + this.incVcTraversalFunction.init(traversalRequest); + } } + } } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalAllWithAggOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalAllWithAggOp.java index f0b1be488..671ef1c02 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalAllWithAggOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalAllWithAggOp.java @@ -29,43 +29,55 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class DynamicGraphVertexCentricTraversalAllWithAggOp & VertexCentricAggContextFunction> +public class DynamicGraphVertexCentricTraversalAllWithAggOp< + K, + VV, + EV, + M, + R, + I, + PA, + PR, + GR, + FUNC extends + IncVertexCentricTraversalFunction + & VertexCentricAggContextFunction> extends DynamicGraphVertexCentricTraversalAllOp implements IGraphVertexCentricAggOp { - private static final Logger LOGGER = LoggerFactory.getLogger( - DynamicGraphVertexCentricTraversalAllWithAggOp.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(DynamicGraphVertexCentricTraversalAllWithAggOp.class); - private GraphVertexCentricOpAggregator> aggregator; + private GraphVertexCentricOpAggregator< + K, VV, EV, M, I, ?, ?, ?, GR, VertexCentricAggTraversal> + aggregator; - public DynamicGraphVertexCentricTraversalAllWithAggOp( - GraphViewDesc graphViewDesc, - AbstractIncVertexCentricTraversalAlgo vcTraversal) { - super(graphViewDesc, vcTraversal); - aggregator = new GraphVertexCentricOpAggregator(this); - } + public DynamicGraphVertexCentricTraversalAllWithAggOp( + GraphViewDesc graphViewDesc, + AbstractIncVertexCentricTraversalAlgo vcTraversal) { + super(graphViewDesc, vcTraversal); + aggregator = new GraphVertexCentricOpAggregator(this); + } - @Override - public void open(OpContext opContext) { - super.open(opContext); - aggregator.open((VertexCentricAggContextFunction) incVcTraversalFunction); - } + @Override + public void open(OpContext opContext) { + super.open(opContext); + aggregator.open((VertexCentricAggContextFunction) incVcTraversalFunction); + } - @Override - public void initIteration(long iteration) { - super.initIteration(iteration); - aggregator.initIteration(iteration); - } + @Override + public void initIteration(long iteration) { + super.initIteration(iteration); + aggregator.initIteration(iteration); + } - public void finishIteration(long iteration) { - super.finishIteration(iteration); - aggregator.finishIteration(iteration); - } + public void finishIteration(long iteration) { + super.finishIteration(iteration); + aggregator.finishIteration(iteration); + } - @Override - public void processAggregateResult(GR result) { - aggregator.processAggregateResult(result); - } + @Override + public void processAggregateResult(GR result) { + aggregator.processAggregateResult(result); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalStartByIdsOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalStartByIdsOp.java index ba84ab42e..cb7e58223 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalStartByIdsOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalStartByIdsOp.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; + import org.apache.geaflow.api.graph.base.algo.AbstractIncVertexCentricTraversalAlgo; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricTraversalFunction; import org.apache.geaflow.model.traversal.ITraversalRequest; @@ -31,59 +32,65 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class DynamicGraphVertexCentricTraversalStartByIdsOp> +public class DynamicGraphVertexCentricTraversalStartByIdsOp< + K, VV, EV, M, R, FUNC extends IncVertexCentricTraversalFunction> extends AbstractDynamicGraphVertexCentricTraversalOp { - private static final Logger LOGGER = LoggerFactory.getLogger( - DynamicGraphVertexCentricTraversalStartByIdsOp.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(DynamicGraphVertexCentricTraversalStartByIdsOp.class); - private final List> traversalRequests; + private final List> traversalRequests; - public DynamicGraphVertexCentricTraversalStartByIdsOp( - GraphViewDesc graphViewDesc, - VertexBeginTraversalRequest vertexBeginTraversalRequest, - AbstractIncVertexCentricTraversalAlgo vcTraversal) { - super(graphViewDesc, vcTraversal); - traversalRequests = new ArrayList<>(); - traversalRequests.add(vertexBeginTraversalRequest); - } + public DynamicGraphVertexCentricTraversalStartByIdsOp( + GraphViewDesc graphViewDesc, + VertexBeginTraversalRequest vertexBeginTraversalRequest, + AbstractIncVertexCentricTraversalAlgo vcTraversal) { + super(graphViewDesc, vcTraversal); + traversalRequests = new ArrayList<>(); + traversalRequests.add(vertexBeginTraversalRequest); + } - public DynamicGraphVertexCentricTraversalStartByIdsOp( - GraphViewDesc graphViewDesc, - List> vertexBeginTraversalRequests, - AbstractIncVertexCentricTraversalAlgo vcTraversal) { - super(graphViewDesc, vcTraversal); - traversalRequests = new ArrayList<>(); - traversalRequests.addAll(vertexBeginTraversalRequests); - } + public DynamicGraphVertexCentricTraversalStartByIdsOp( + GraphViewDesc graphViewDesc, + List> vertexBeginTraversalRequests, + AbstractIncVertexCentricTraversalAlgo vcTraversal) { + super(graphViewDesc, vcTraversal); + traversalRequests = new ArrayList<>(); + traversalRequests.addAll(vertexBeginTraversalRequests); + } - @Override - protected void traversalByRequest() { - if (!temporaryGraphCache.getAllEvolveVId().isEmpty()) { - if (enableDebug) { - LOGGER.info("taskId:{} windowId:{} iterations:{} is not empty", - runtimeContext.getTaskArgs().getTaskId(), windowId, iterations); - } - super.traversalByRequest(); - } else { - LOGGER.info("taskId:{} windowId:{} iterations:{} is empty", - runtimeContext.getTaskArgs().getTaskId(), windowId, - iterations); - } + @Override + protected void traversalByRequest() { + if (!temporaryGraphCache.getAllEvolveVId().isEmpty()) { + if (enableDebug) { + LOGGER.info( + "taskId:{} windowId:{} iterations:{} is not empty", + runtimeContext.getTaskArgs().getTaskId(), + windowId, + iterations); + } + super.traversalByRequest(); + } else { + LOGGER.info( + "taskId:{} windowId:{} iterations:{} is empty", + runtimeContext.getTaskArgs().getTaskId(), + windowId, + iterations); } + } - @Override - public Iterator> getTraversalRequests() { - List> currentTaskRequest = new ArrayList<>(); - for (ITraversalRequest traversalRequest : traversalRequests) { - int maxParallelism = graphViewDesc.getShardNum(); - int currentKeyGroup = KeyGroupAssignment.assignToKeyGroup(traversalRequest.getVId(), - maxParallelism); - if (currentKeyGroup >= taskKeyGroup.getStartKeyGroup() && currentKeyGroup <= taskKeyGroup.getEndKeyGroup()) { - currentTaskRequest.add(traversalRequest); - } - } - return currentTaskRequest.iterator(); + @Override + public Iterator> getTraversalRequests() { + List> currentTaskRequest = new ArrayList<>(); + for (ITraversalRequest traversalRequest : traversalRequests) { + int maxParallelism = graphViewDesc.getShardNum(); + int currentKeyGroup = + KeyGroupAssignment.assignToKeyGroup(traversalRequest.getVId(), maxParallelism); + if (currentKeyGroup >= taskKeyGroup.getStartKeyGroup() + && currentKeyGroup <= taskKeyGroup.getEndKeyGroup()) { + currentTaskRequest.add(traversalRequest); + } } + return currentTaskRequest.iterator(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalStartByIdsWithAggOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalStartByIdsWithAggOp.java index 0f4593b91..e2021192e 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalStartByIdsWithAggOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalStartByIdsWithAggOp.java @@ -20,6 +20,7 @@ package org.apache.geaflow.operator.impl.graph.traversal.dynamic; import java.util.List; + import org.apache.geaflow.api.graph.base.algo.AbstractIncVertexCentricTraversalAlgo; import org.apache.geaflow.api.graph.function.aggregate.VertexCentricAggContextFunction; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricTraversalFunction; @@ -31,53 +32,64 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class DynamicGraphVertexCentricTraversalStartByIdsWithAggOp & VertexCentricAggContextFunction> +public class DynamicGraphVertexCentricTraversalStartByIdsWithAggOp< + K, + VV, + EV, + M, + R, + I, + PA, + PR, + GR, + FUNC extends + IncVertexCentricTraversalFunction + & VertexCentricAggContextFunction> extends DynamicGraphVertexCentricTraversalStartByIdsOp implements IGraphVertexCentricAggOp { - private static final Logger LOGGER = LoggerFactory.getLogger( - DynamicGraphVertexCentricTraversalStartByIdsWithAggOp.class); - + private static final Logger LOGGER = + LoggerFactory.getLogger(DynamicGraphVertexCentricTraversalStartByIdsWithAggOp.class); - private GraphVertexCentricOpAggregator> aggregator; + private GraphVertexCentricOpAggregator< + K, VV, EV, M, I, ?, ?, ?, GR, VertexCentricAggTraversal> + aggregator; - public DynamicGraphVertexCentricTraversalStartByIdsWithAggOp( - GraphViewDesc graphViewDesc, - VertexBeginTraversalRequest vertexBeginTraversalRequest, - AbstractIncVertexCentricTraversalAlgo vcTraversal) { - super(graphViewDesc, vertexBeginTraversalRequest, vcTraversal); - aggregator = new GraphVertexCentricOpAggregator(this); - } + public DynamicGraphVertexCentricTraversalStartByIdsWithAggOp( + GraphViewDesc graphViewDesc, + VertexBeginTraversalRequest vertexBeginTraversalRequest, + AbstractIncVertexCentricTraversalAlgo vcTraversal) { + super(graphViewDesc, vertexBeginTraversalRequest, vcTraversal); + aggregator = new GraphVertexCentricOpAggregator(this); + } - public DynamicGraphVertexCentricTraversalStartByIdsWithAggOp( - GraphViewDesc graphViewDesc, - List> vertexBeginTraversalRequests, - AbstractIncVertexCentricTraversalAlgo vcTraversal) { - super(graphViewDesc, vertexBeginTraversalRequests, vcTraversal); - aggregator = new GraphVertexCentricOpAggregator(this); - } + public DynamicGraphVertexCentricTraversalStartByIdsWithAggOp( + GraphViewDesc graphViewDesc, + List> vertexBeginTraversalRequests, + AbstractIncVertexCentricTraversalAlgo vcTraversal) { + super(graphViewDesc, vertexBeginTraversalRequests, vcTraversal); + aggregator = new GraphVertexCentricOpAggregator(this); + } - @Override - public void open(OpContext opContext) { - super.open(opContext); - aggregator.open((VertexCentricAggContextFunction) incVcTraversalFunction); - } + @Override + public void open(OpContext opContext) { + super.open(opContext); + aggregator.open((VertexCentricAggContextFunction) incVcTraversalFunction); + } - @Override - public void initIteration(long iteration) { - super.initIteration(iteration); - aggregator.initIteration(iteration); - } + @Override + public void initIteration(long iteration) { + super.initIteration(iteration); + aggregator.initIteration(iteration); + } - public void finishIteration(long iteration) { - super.finishIteration(iteration); - aggregator.finishIteration(iteration); - } + public void finishIteration(long iteration) { + super.finishIteration(iteration); + aggregator.finishIteration(iteration); + } - @Override - public void processAggregateResult(GR result) { - aggregator.processAggregateResult(result); - } + @Override + public void processAggregateResult(GR result) { + aggregator.processAggregateResult(result); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalStartByStreamOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalStartByStreamOp.java index 1259eca2b..87388c3c9 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalStartByStreamOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalStartByStreamOp.java @@ -25,16 +25,16 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class DynamicGraphVertexCentricTraversalStartByStreamOp> +public class DynamicGraphVertexCentricTraversalStartByStreamOp< + K, VV, EV, M, R, FUNC extends IncVertexCentricTraversalFunction> extends AbstractDynamicGraphVertexCentricTraversalOp { - private static final Logger LOGGER = LoggerFactory.getLogger( - DynamicGraphVertexCentricTraversalStartByStreamOp.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(DynamicGraphVertexCentricTraversalStartByStreamOp.class); - public DynamicGraphVertexCentricTraversalStartByStreamOp( - GraphViewDesc graphViewDesc, - AbstractIncVertexCentricTraversalAlgo vcTraversal) { - super(graphViewDesc, vcTraversal); - } + public DynamicGraphVertexCentricTraversalStartByStreamOp( + GraphViewDesc graphViewDesc, + AbstractIncVertexCentricTraversalAlgo vcTraversal) { + super(graphViewDesc, vcTraversal); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalStartByStreamWithAggOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalStartByStreamWithAggOp.java index 633a02876..f46d65f59 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalStartByStreamWithAggOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/dynamic/DynamicGraphVertexCentricTraversalStartByStreamWithAggOp.java @@ -29,43 +29,55 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class DynamicGraphVertexCentricTraversalStartByStreamWithAggOp & VertexCentricAggContextFunction> +public class DynamicGraphVertexCentricTraversalStartByStreamWithAggOp< + K, + VV, + EV, + M, + R, + I, + PA, + PR, + GR, + FUNC extends + IncVertexCentricTraversalFunction + & VertexCentricAggContextFunction> extends DynamicGraphVertexCentricTraversalStartByStreamOp implements IGraphVertexCentricAggOp { - private static final Logger LOGGER = LoggerFactory.getLogger( - DynamicGraphVertexCentricTraversalStartByStreamWithAggOp.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(DynamicGraphVertexCentricTraversalStartByStreamWithAggOp.class); - private GraphVertexCentricOpAggregator> aggregator; + private GraphVertexCentricOpAggregator< + K, VV, EV, M, I, ?, ?, ?, GR, VertexCentricAggTraversal> + aggregator; - public DynamicGraphVertexCentricTraversalStartByStreamWithAggOp( - GraphViewDesc graphViewDesc, - AbstractIncVertexCentricTraversalAlgo vcTraversal) { - super(graphViewDesc, vcTraversal); - aggregator = new GraphVertexCentricOpAggregator(this); - } + public DynamicGraphVertexCentricTraversalStartByStreamWithAggOp( + GraphViewDesc graphViewDesc, + AbstractIncVertexCentricTraversalAlgo vcTraversal) { + super(graphViewDesc, vcTraversal); + aggregator = new GraphVertexCentricOpAggregator(this); + } - @Override - public void open(OpContext opContext) { - super.open(opContext); - aggregator.open((VertexCentricAggContextFunction) incVcTraversalFunction); - } + @Override + public void open(OpContext opContext) { + super.open(opContext); + aggregator.open((VertexCentricAggContextFunction) incVcTraversalFunction); + } - @Override - public void initIteration(long iteration) { - super.initIteration(iteration); - aggregator.initIteration(iteration); - } + @Override + public void initIteration(long iteration) { + super.initIteration(iteration); + aggregator.initIteration(iteration); + } - public void finishIteration(long iteration) { - super.finishIteration(iteration); - aggregator.finishIteration(iteration); - } + public void finishIteration(long iteration) { + super.finishIteration(iteration); + aggregator.finishIteration(iteration); + } - @Override - public void processAggregateResult(GR result) { - aggregator.processAggregateResult(result); - } + @Override + public void processAggregateResult(GR result) { + aggregator.processAggregateResult(result); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/AbstractStaticGraphVertexCentricTraversalOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/AbstractStaticGraphVertexCentricTraversalOp.java index 6a42cac9d..d6baee8f9 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/AbstractStaticGraphVertexCentricTraversalOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/AbstractStaticGraphVertexCentricTraversalOp.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.iterator.RichIteratorFunction; import org.apache.geaflow.api.graph.base.algo.AbstractVertexCentricTraversalAlgo; @@ -51,170 +52,182 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class AbstractStaticGraphVertexCentricTraversalOp> - extends AbstractStaticGraphVertexCentricOp> +public abstract class AbstractStaticGraphVertexCentricTraversalOp< + K, VV, EV, M, R, FUNC extends VertexCentricTraversalFunction> + extends AbstractStaticGraphVertexCentricOp< + K, VV, EV, M, AbstractVertexCentricTraversalAlgo> implements IGraphTraversalOp { - private static final Logger LOGGER = LoggerFactory.getLogger( - AbstractStaticGraphVertexCentricTraversalOp.class); - - protected GraphVCTraversalCtxImpl graphVCTraversalCtx; - protected VertexCentricTraversalFunction vcTraversalFunction; - - protected List> responses; - - protected ICollector> responseCollector; - - protected final List> traversalRequests; - - public AbstractStaticGraphVertexCentricTraversalOp(GraphViewDesc graphViewDesc, - AbstractVertexCentricTraversalAlgo vcTraversal) { - super(graphViewDesc, vcTraversal); - opArgs.setOpType(OpType.VERTEX_CENTRIC_TRAVERSAL); - this.traversalRequests = new ArrayList<>(); + private static final Logger LOGGER = + LoggerFactory.getLogger(AbstractStaticGraphVertexCentricTraversalOp.class); + + protected GraphVCTraversalCtxImpl graphVCTraversalCtx; + protected VertexCentricTraversalFunction vcTraversalFunction; + + protected List> responses; + + protected ICollector> responseCollector; + + protected final List> traversalRequests; + + public AbstractStaticGraphVertexCentricTraversalOp( + GraphViewDesc graphViewDesc, + AbstractVertexCentricTraversalAlgo vcTraversal) { + super(graphViewDesc, vcTraversal); + opArgs.setOpType(OpType.VERTEX_CENTRIC_TRAVERSAL); + this.traversalRequests = new ArrayList<>(); + } + + @Override + public void open(OpContext opContext) { + super.open(opContext); + this.vcTraversalFunction = this.function.getTraversalFunction(); + this.graphVCTraversalCtx = + new GraphVCTraversalCtxImpl( + opContext, + this.runtimeContext, + this.graphState, + this.graphMsgBox, + this.maxIterations, + getIdentify(), + this.messageCollector); + this.vcTraversalFunction.open(this.graphVCTraversalCtx); + + this.responses = new ArrayList<>(); + + for (ICollector collector : this.collectors) { + if (!collector.getTag().equals(GraphRecordNames.Message.name()) + && !collector.getTag().equals(GraphRecordNames.Aggregate.name())) { + responseCollector = collector; + } } + } - @Override - public void open(OpContext opContext) { - super.open(opContext); - this.vcTraversalFunction = this.function.getTraversalFunction(); - this.graphVCTraversalCtx = new GraphVCTraversalCtxImpl( - opContext, this.runtimeContext, this.graphState, - this.graphMsgBox, this.maxIterations, getIdentify(), this.messageCollector); - this.vcTraversalFunction.open(this.graphVCTraversalCtx); - - this.responses = new ArrayList<>(); - - for (ICollector collector : this.collectors) { - if (!collector.getTag().equals(GraphRecordNames.Message.name()) - && !collector.getTag().equals(GraphRecordNames.Aggregate.name())) { - responseCollector = collector; - } - } - } + @Override + public void doFinishIteration(long iterations) { - @Override - public void doFinishIteration(long iterations) { - - // Compute. - if (iterations == 1L) { - traversalByRequest(iterations); - } else { - this.graphMsgBox.processInMessage(new MsgProcessFunc() { - @Override - public void process(K vertexId, List messages) { - graphVCTraversalCtx.init(iterations, vertexId); - vcTraversalFunction.compute(vertexId, messages.iterator()); - } - }); - this.graphMsgBox.clearInBox(); - } - if (vcTraversalFunction instanceof RichIteratorFunction) { - ((RichIteratorFunction) vcTraversalFunction).finishIteration(iterations); - } - // Emit message. - this.graphMsgBox.processOutMessage(new MsgProcessFunc() { + // Compute. + if (iterations == 1L) { + traversalByRequest(iterations); + } else { + this.graphMsgBox.processInMessage( + new MsgProcessFunc() { @Override public void process(K vertexId, List messages) { - // Collect message. - int size = messages.size(); - for (int i = 0; i < size; i++) { - messageCollector.partition(vertexId, new DefaultGraphMessage<>(vertexId, messages.get(i))); - } + graphVCTraversalCtx.init(iterations, vertexId); + vcTraversalFunction.compute(vertexId, messages.iterator()); } + }); + this.graphMsgBox.clearInBox(); + } + if (vcTraversalFunction instanceof RichIteratorFunction) { + ((RichIteratorFunction) vcTraversalFunction).finishIteration(iterations); + } + // Emit message. + this.graphMsgBox.processOutMessage( + new MsgProcessFunc() { + @Override + public void process(K vertexId, List messages) { + // Collect message. + int size = messages.size(); + for (int i = 0; i < size; i++) { + messageCollector.partition( + vertexId, new DefaultGraphMessage<>(vertexId, messages.get(i))); + } + } }); - messageCollector.finish(); - this.graphMsgBox.clearOutBox(); + messageCollector.finish(); + this.graphMsgBox.clearOutBox(); + } + + protected void traversalByRequest(long iterations) { + Iterator> iterator = getTraversalRequests(); + while (iterator.hasNext()) { + ITraversalRequest traversalRequest = iterator.next(); + K vertexId = traversalRequest.getVId(); + this.graphVCTraversalCtx.init(iterations, vertexId); + this.vcTraversalFunction.init(traversalRequest); } - - protected void traversalByRequest(long iterations) { - Iterator> iterator = getTraversalRequests(); - while (iterator.hasNext()) { - ITraversalRequest traversalRequest = iterator.next(); - K vertexId = traversalRequest.getVId(); - this.graphVCTraversalCtx.init(iterations, vertexId); - this.vcTraversalFunction.init(traversalRequest); - } + } + + @Override + public void finish() { + LOGGER.info("vcTraversalFunction finish windowId:{}", this.windowId); + vcTraversalFunction.finish(); + LOGGER.info("vcTraversalFunction has finish windowId:{}", this.windowId); + for (ITraversalResponse response : this.responses) { + responseCollector.partition(response.getResponseId(), response); + } + responseCollector.finish(); + traversalRequests.clear(); + responses.clear(); + LOGGER.info("TraversalOp has finish windowId:{}", this.windowId); + } + + @Override + public void close() { + this.vcTraversalFunction.close(); + super.close(); + this.responses.clear(); + } + + class GraphVCTraversalCtxImpl extends StaticGraphContextImpl + implements VertexCentricTraversalFuncContext { + + private final String opName; + private final ICollector> messageCollector; + + public GraphVCTraversalCtxImpl( + OpContext opContext, + RuntimeContext runtimeContext, + GraphState graphState, + IGraphMsgBox graphMsgBox, + long maxIteration, + String opName, + ICollector> messageCollector) { + super(opContext, runtimeContext, graphState, graphMsgBox, maxIteration); + this.opName = opName; + this.messageCollector = messageCollector; } @Override - public void finish() { - LOGGER.info("vcTraversalFunction finish windowId:{}", this.windowId); - vcTraversalFunction.finish(); - LOGGER.info("vcTraversalFunction has finish windowId:{}", this.windowId); - for (ITraversalResponse response : this.responses) { - responseCollector.partition(response.getResponseId(), response); - } - responseCollector.finish(); - traversalRequests.clear(); - responses.clear(); - LOGGER.info("TraversalOp has finish windowId:{}", this.windowId); + public void takeResponse(ITraversalResponse response) { + responses.add(response); } @Override - public void close() { - this.vcTraversalFunction.close(); - super.close(); - this.responses.clear(); + public TraversalVertexQuery vertex() { + if (graphViewDesc instanceof GraphSnapshotDesc) { + return new DynamicTraversalVertexQueryImpl<>(vertexId, 0L, graphState, taskKeyGroup); + } + return new StaticTraversalVertexQueryImpl<>(vertexId, graphState, taskKeyGroup); } - class GraphVCTraversalCtxImpl extends StaticGraphContextImpl - implements VertexCentricTraversalFuncContext { - - private final String opName; - private final ICollector> messageCollector; - - public GraphVCTraversalCtxImpl(OpContext opContext, - RuntimeContext runtimeContext, - GraphState graphState, - IGraphMsgBox graphMsgBox, - long maxIteration, - String opName, - ICollector> messageCollector) { - super(opContext, runtimeContext, graphState, graphMsgBox, maxIteration); - this.opName = opName; - this.messageCollector = messageCollector; - } - - @Override - public void takeResponse(ITraversalResponse response) { - responses.add(response); - } - - @Override - public TraversalVertexQuery vertex() { - if (graphViewDesc instanceof GraphSnapshotDesc) { - return new DynamicTraversalVertexQueryImpl<>(vertexId, 0L, graphState, taskKeyGroup); - } - return new StaticTraversalVertexQueryImpl<>(vertexId, graphState, taskKeyGroup); - } - - @Override - public TraversalEdgeQuery edges() { - if (graphViewDesc instanceof GraphSnapshotDesc) { - return new DynamicTraversalEdgeQueryImpl<>(vertexId, 0L, graphState, taskKeyGroup); - } - return new StaticTraversalEdgeQueryImpl<>(vertexId, graphState, taskKeyGroup); - } - - @Override - public void broadcast(IGraphMessage message) { - messageCollector.broadcast(message); - } - - @Override - public String getTraversalOpName() { - return opName; - } + @Override + public TraversalEdgeQuery edges() { + if (graphViewDesc instanceof GraphSnapshotDesc) { + return new DynamicTraversalEdgeQueryImpl<>(vertexId, 0L, graphState, taskKeyGroup); + } + return new StaticTraversalEdgeQueryImpl<>(vertexId, graphState, taskKeyGroup); } - public void addRequest(ITraversalRequest request) { - traversalRequests.add(request); + @Override + public void broadcast(IGraphMessage message) { + messageCollector.broadcast(message); } - public Iterator> getTraversalRequests() { - return traversalRequests.iterator(); + @Override + public String getTraversalOpName() { + return opName; } + } + + public void addRequest(ITraversalRequest request) { + traversalRequests.add(request); + } + + public Iterator> getTraversalRequests() { + return traversalRequests.iterator(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalAllOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalAllOp.java index 057b1afec..3bdfc32cc 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalAllOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalAllOp.java @@ -27,27 +27,27 @@ import org.apache.geaflow.model.traversal.impl.VertexBeginTraversalRequest; import org.apache.geaflow.view.graph.GraphViewDesc; -public class StaticGraphVertexCentricTraversalAllOp> +public class StaticGraphVertexCentricTraversalAllOp< + K, VV, EV, M, R, FUNC extends VertexCentricTraversalFunction> extends AbstractStaticGraphVertexCentricTraversalOp { - public StaticGraphVertexCentricTraversalAllOp(GraphViewDesc graphViewDesc, - AbstractVertexCentricTraversalAlgo vcTraversal) { - super(graphViewDesc, vcTraversal); - } + public StaticGraphVertexCentricTraversalAllOp( + GraphViewDesc graphViewDesc, + AbstractVertexCentricTraversalAlgo vcTraversal) { + super(graphViewDesc, vcTraversal); + } - @Override - protected void traversalByRequest(long iterations) { - try (CloseableIterator idIterator = graphVCTraversalCtx.vertex().loadIdIterator()) { - while (idIterator.hasNext()) { - K vertexId = idIterator.next(); - ITraversalRequest traversalRequest = new VertexBeginTraversalRequest(vertexId); - this.graphVCTraversalCtx.init(iterations, vertexId); - this.vcTraversalFunction.init(traversalRequest); - } - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + @Override + protected void traversalByRequest(long iterations) { + try (CloseableIterator idIterator = graphVCTraversalCtx.vertex().loadIdIterator()) { + while (idIterator.hasNext()) { + K vertexId = idIterator.next(); + ITraversalRequest traversalRequest = new VertexBeginTraversalRequest(vertexId); + this.graphVCTraversalCtx.init(iterations, vertexId); + this.vcTraversalFunction.init(traversalRequest); + } + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } - + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalAllWithAggOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalAllWithAggOp.java index 12bff58c6..b363e37b5 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalAllWithAggOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalAllWithAggOp.java @@ -27,39 +27,52 @@ import org.apache.geaflow.operator.impl.graph.algo.vc.IGraphVertexCentricAggOp; import org.apache.geaflow.view.graph.GraphViewDesc; -public class StaticGraphVertexCentricTraversalAllWithAggOp & VertexCentricAggContextFunction> +public class StaticGraphVertexCentricTraversalAllWithAggOp< + K, + VV, + EV, + M, + R, + I, + PA, + PR, + GR, + FUNC extends + VertexCentricTraversalFunction + & VertexCentricAggContextFunction> extends StaticGraphVertexCentricTraversalAllOp implements IGraphVertexCentricAggOp { - private GraphVertexCentricOpAggregator> aggregator; + private GraphVertexCentricOpAggregator< + K, VV, EV, M, I, ?, ?, ?, GR, VertexCentricAggTraversal> + aggregator; - public StaticGraphVertexCentricTraversalAllWithAggOp(GraphViewDesc graphViewDesc, - AbstractVertexCentricTraversalAlgo vcTraversal) { - super(graphViewDesc, vcTraversal); - aggregator = new GraphVertexCentricOpAggregator(this); - } + public StaticGraphVertexCentricTraversalAllWithAggOp( + GraphViewDesc graphViewDesc, + AbstractVertexCentricTraversalAlgo vcTraversal) { + super(graphViewDesc, vcTraversal); + aggregator = new GraphVertexCentricOpAggregator(this); + } - @Override - public void open(OpContext opContext) { - super.open(opContext); - aggregator.open((VertexCentricAggContextFunction) vcTraversalFunction); - } + @Override + public void open(OpContext opContext) { + super.open(opContext); + aggregator.open((VertexCentricAggContextFunction) vcTraversalFunction); + } - @Override - public void initIteration(long iteration) { - super.initIteration(iteration); - aggregator.initIteration(iteration); - } + @Override + public void initIteration(long iteration) { + super.initIteration(iteration); + aggregator.initIteration(iteration); + } - public void finishIteration(long iteration) { - super.finishIteration(iteration); - aggregator.finishIteration(iteration); - } + public void finishIteration(long iteration) { + super.finishIteration(iteration); + aggregator.finishIteration(iteration); + } - @Override - public void processAggregateResult(GR result) { - aggregator.processAggregateResult(result); - } + @Override + public void processAggregateResult(GR result) { + aggregator.processAggregateResult(result); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalStartByIdsOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalStartByIdsOp.java index 3ab0b5a00..ce9da5fb4 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalStartByIdsOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalStartByIdsOp.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; + import org.apache.geaflow.api.graph.base.algo.AbstractVertexCentricTraversalAlgo; import org.apache.geaflow.api.graph.function.vc.VertexCentricTraversalFunction; import org.apache.geaflow.model.traversal.ITraversalRequest; @@ -31,43 +32,47 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class StaticGraphVertexCentricTraversalStartByIdsOp> +public class StaticGraphVertexCentricTraversalStartByIdsOp< + K, VV, EV, M, R, FUNC extends VertexCentricTraversalFunction> extends AbstractStaticGraphVertexCentricTraversalOp { - private static final Logger LOGGER = LoggerFactory.getLogger( - StaticGraphVertexCentricTraversalStartByIdsOp.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(StaticGraphVertexCentricTraversalStartByIdsOp.class); - private final List> traversalRequests; + private final List> traversalRequests; - public StaticGraphVertexCentricTraversalStartByIdsOp( - GraphViewDesc graphViewDesc, VertexBeginTraversalRequest vertexBeginTraversalRequest, - AbstractVertexCentricTraversalAlgo vcTraversal) { - super(graphViewDesc, vcTraversal); - traversalRequests = new ArrayList<>(); - traversalRequests.add(vertexBeginTraversalRequest); - } + public StaticGraphVertexCentricTraversalStartByIdsOp( + GraphViewDesc graphViewDesc, + VertexBeginTraversalRequest vertexBeginTraversalRequest, + AbstractVertexCentricTraversalAlgo vcTraversal) { + super(graphViewDesc, vcTraversal); + traversalRequests = new ArrayList<>(); + traversalRequests.add(vertexBeginTraversalRequest); + } - public StaticGraphVertexCentricTraversalStartByIdsOp( - GraphViewDesc graphViewDesc, List> vertexBeginTraversalRequests, - AbstractVertexCentricTraversalAlgo vcTraversal) { - super(graphViewDesc, vcTraversal); - traversalRequests = new ArrayList<>(); - traversalRequests.addAll(vertexBeginTraversalRequests); - } + public StaticGraphVertexCentricTraversalStartByIdsOp( + GraphViewDesc graphViewDesc, + List> vertexBeginTraversalRequests, + AbstractVertexCentricTraversalAlgo vcTraversal) { + super(graphViewDesc, vcTraversal); + traversalRequests = new ArrayList<>(); + traversalRequests.addAll(vertexBeginTraversalRequests); + } - @Override - public Iterator> getTraversalRequests() { - List> currentTaskRequest = new ArrayList<>(); - for (ITraversalRequest traversalRequest : traversalRequests) { - int maxParallelism = graphViewDesc.getShardNum(); - int currentKeyGroup = KeyGroupAssignment.assignToKeyGroup(traversalRequest.getVId(), maxParallelism); - LOGGER.info("maxParallelism {}", maxParallelism); + @Override + public Iterator> getTraversalRequests() { + List> currentTaskRequest = new ArrayList<>(); + for (ITraversalRequest traversalRequest : traversalRequests) { + int maxParallelism = graphViewDesc.getShardNum(); + int currentKeyGroup = + KeyGroupAssignment.assignToKeyGroup(traversalRequest.getVId(), maxParallelism); + LOGGER.info("maxParallelism {}", maxParallelism); - if (currentKeyGroup >= taskKeyGroup.getStartKeyGroup() && currentKeyGroup <= taskKeyGroup.getEndKeyGroup()) { - currentTaskRequest.add(traversalRequest); - } - } - return currentTaskRequest.iterator(); + if (currentKeyGroup >= taskKeyGroup.getStartKeyGroup() + && currentKeyGroup <= taskKeyGroup.getEndKeyGroup()) { + currentTaskRequest.add(traversalRequest); + } } + return currentTaskRequest.iterator(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalStartByIdsWithAggOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalStartByIdsWithAggOp.java index cfc54b0a1..d325f1ac8 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalStartByIdsWithAggOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalStartByIdsWithAggOp.java @@ -20,6 +20,7 @@ package org.apache.geaflow.operator.impl.graph.traversal.statical; import java.util.List; + import org.apache.geaflow.api.graph.base.algo.AbstractVertexCentricTraversalAlgo; import org.apache.geaflow.api.graph.function.aggregate.VertexCentricAggContextFunction; import org.apache.geaflow.api.graph.function.vc.VertexCentricTraversalFunction; @@ -31,50 +32,64 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class StaticGraphVertexCentricTraversalStartByIdsWithAggOp & VertexCentricAggContextFunction> +public class StaticGraphVertexCentricTraversalStartByIdsWithAggOp< + K, + VV, + EV, + M, + R, + I, + PA, + PR, + GR, + FUNC extends + VertexCentricTraversalFunction + & VertexCentricAggContextFunction> extends StaticGraphVertexCentricTraversalStartByIdsOp implements IGraphVertexCentricAggOp { - private static final Logger LOGGER = LoggerFactory.getLogger( - StaticGraphVertexCentricTraversalStartByIdsWithAggOp.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(StaticGraphVertexCentricTraversalStartByIdsWithAggOp.class); - private GraphVertexCentricOpAggregator> aggregator; + private GraphVertexCentricOpAggregator< + K, VV, EV, M, I, ?, ?, ?, GR, VertexCentricAggTraversal> + aggregator; - public StaticGraphVertexCentricTraversalStartByIdsWithAggOp( - GraphViewDesc graphViewDesc, VertexBeginTraversalRequest vertexBeginTraversalRequest, - AbstractVertexCentricTraversalAlgo vcTraversal) { - super(graphViewDesc, vertexBeginTraversalRequest, vcTraversal); - aggregator = new GraphVertexCentricOpAggregator(this); - } + public StaticGraphVertexCentricTraversalStartByIdsWithAggOp( + GraphViewDesc graphViewDesc, + VertexBeginTraversalRequest vertexBeginTraversalRequest, + AbstractVertexCentricTraversalAlgo vcTraversal) { + super(graphViewDesc, vertexBeginTraversalRequest, vcTraversal); + aggregator = new GraphVertexCentricOpAggregator(this); + } - public StaticGraphVertexCentricTraversalStartByIdsWithAggOp( - GraphViewDesc graphViewDesc, List> vertexBeginTraversalRequests, - AbstractVertexCentricTraversalAlgo vcTraversal) { - super(graphViewDesc, vertexBeginTraversalRequests, vcTraversal); - aggregator = new GraphVertexCentricOpAggregator(this); - } + public StaticGraphVertexCentricTraversalStartByIdsWithAggOp( + GraphViewDesc graphViewDesc, + List> vertexBeginTraversalRequests, + AbstractVertexCentricTraversalAlgo vcTraversal) { + super(graphViewDesc, vertexBeginTraversalRequests, vcTraversal); + aggregator = new GraphVertexCentricOpAggregator(this); + } - @Override - public void open(OpContext opContext) { - super.open(opContext); - aggregator.open((VertexCentricAggContextFunction) vcTraversalFunction); - } + @Override + public void open(OpContext opContext) { + super.open(opContext); + aggregator.open((VertexCentricAggContextFunction) vcTraversalFunction); + } - @Override - public void initIteration(long iteration) { - super.initIteration(iteration); - aggregator.initIteration(iteration); - } + @Override + public void initIteration(long iteration) { + super.initIteration(iteration); + aggregator.initIteration(iteration); + } - public void finishIteration(long iteration) { - super.finishIteration(iteration); - aggregator.finishIteration(iteration); - } + public void finishIteration(long iteration) { + super.finishIteration(iteration); + aggregator.finishIteration(iteration); + } - @Override - public void processAggregateResult(GR result) { - aggregator.processAggregateResult(result); - } + @Override + public void processAggregateResult(GR result) { + aggregator.processAggregateResult(result); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalStartByStreamOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalStartByStreamOp.java index b2fa19802..d92c5823a 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalStartByStreamOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalStartByStreamOp.java @@ -25,15 +25,16 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class StaticGraphVertexCentricTraversalStartByStreamOp> +public class StaticGraphVertexCentricTraversalStartByStreamOp< + K, VV, EV, M, R, FUNC extends VertexCentricTraversalFunction> extends AbstractStaticGraphVertexCentricTraversalOp { - private static final Logger LOGGER = LoggerFactory.getLogger( - StaticGraphVertexCentricTraversalStartByStreamOp.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(StaticGraphVertexCentricTraversalStartByStreamOp.class); - public StaticGraphVertexCentricTraversalStartByStreamOp(GraphViewDesc graphViewDesc, - AbstractVertexCentricTraversalAlgo vcTraversal) { - super(graphViewDesc, vcTraversal); - } + public StaticGraphVertexCentricTraversalStartByStreamOp( + GraphViewDesc graphViewDesc, + AbstractVertexCentricTraversalAlgo vcTraversal) { + super(graphViewDesc, vcTraversal); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalStartByStreamWithAggOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalStartByStreamWithAggOp.java index 56d2587ac..b5dc80589 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalStartByStreamWithAggOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/traversal/statical/StaticGraphVertexCentricTraversalStartByStreamWithAggOp.java @@ -29,43 +29,55 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class StaticGraphVertexCentricTraversalStartByStreamWithAggOp & VertexCentricAggContextFunction> +public class StaticGraphVertexCentricTraversalStartByStreamWithAggOp< + K, + VV, + EV, + M, + R, + I, + PA, + PR, + GR, + FUNC extends + VertexCentricTraversalFunction + & VertexCentricAggContextFunction> extends StaticGraphVertexCentricTraversalStartByStreamOp implements IGraphVertexCentricAggOp { - private static final Logger LOGGER = LoggerFactory.getLogger( - StaticGraphVertexCentricTraversalStartByStreamWithAggOp.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(StaticGraphVertexCentricTraversalStartByStreamWithAggOp.class); - private GraphVertexCentricOpAggregator> aggregator; + private GraphVertexCentricOpAggregator< + K, VV, EV, M, I, ?, ?, ?, GR, VertexCentricAggTraversal> + aggregator; - public StaticGraphVertexCentricTraversalStartByStreamWithAggOp( - GraphViewDesc graphViewDesc, - AbstractVertexCentricTraversalAlgo vcTraversal) { - super(graphViewDesc, vcTraversal); - aggregator = new GraphVertexCentricOpAggregator(this); - } + public StaticGraphVertexCentricTraversalStartByStreamWithAggOp( + GraphViewDesc graphViewDesc, + AbstractVertexCentricTraversalAlgo vcTraversal) { + super(graphViewDesc, vcTraversal); + aggregator = new GraphVertexCentricOpAggregator(this); + } - @Override - public void open(OpContext opContext) { - super.open(opContext); - aggregator.open((VertexCentricAggContextFunction) vcTraversalFunction); - } + @Override + public void open(OpContext opContext) { + super.open(opContext); + aggregator.open((VertexCentricAggContextFunction) vcTraversalFunction); + } - @Override - public void initIteration(long iteration) { - super.initIteration(iteration); - aggregator.initIteration(iteration); - } + @Override + public void initIteration(long iteration) { + super.initIteration(iteration); + aggregator.initIteration(iteration); + } - public void finishIteration(long iteration) { - super.finishIteration(iteration); - aggregator.finishIteration(iteration); - } + public void finishIteration(long iteration) { + super.finishIteration(iteration); + aggregator.finishIteration(iteration); + } - @Override - public void processAggregateResult(GR result) { - aggregator.processAggregateResult(result); - } + @Override + public void processAggregateResult(GR result) { + aggregator.processAggregateResult(result); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/io/WindowSourceOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/io/WindowSourceOperator.java index 9f770bc68..4090e8f40 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/io/WindowSourceOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/io/WindowSourceOperator.java @@ -31,69 +31,67 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class WindowSourceOperator extends AbstractStreamOperator> implements - SourceOperator { +public class WindowSourceOperator extends AbstractStreamOperator> + implements SourceOperator { - private static final Logger LOGGER = LoggerFactory.getLogger(WindowSourceOperator.class); + private static final Logger LOGGER = LoggerFactory.getLogger(WindowSourceOperator.class); - protected transient SourceContext sourceCxt; - protected IWindow windowFunction; + protected transient SourceContext sourceCxt; + protected IWindow windowFunction; - public WindowSourceOperator() { - super(); - opArgs.setOpType(OpType.MULTI_WINDOW_SOURCE); - } + public WindowSourceOperator() { + super(); + opArgs.setOpType(OpType.MULTI_WINDOW_SOURCE); + } - public WindowSourceOperator(SourceFunction sourceFunction) { - super(sourceFunction); - opArgs.setOpType(OpType.MULTI_WINDOW_SOURCE); - } + public WindowSourceOperator(SourceFunction sourceFunction) { + super(sourceFunction); + opArgs.setOpType(OpType.MULTI_WINDOW_SOURCE); + } - public WindowSourceOperator(SourceFunction sourceFunction, IWindow windowFunction) { - this(sourceFunction); - this.windowFunction = windowFunction; - if (windowFunction instanceof AllWindow) { - opArgs.setOpType(OpType.SINGLE_WINDOW_SOURCE); - } else { - opArgs.setOpType(OpType.MULTI_WINDOW_SOURCE); - } + public WindowSourceOperator(SourceFunction sourceFunction, IWindow windowFunction) { + this(sourceFunction); + this.windowFunction = windowFunction; + if (windowFunction instanceof AllWindow) { + opArgs.setOpType(OpType.SINGLE_WINDOW_SOURCE); + } else { + opArgs.setOpType(OpType.MULTI_WINDOW_SOURCE); } + } - @Override - public void open(OpContext opContext) { - super.open(opContext); - this.sourceCxt = new StreamSourceContext(); - TaskArgs taskArgs = opContext.getRuntimeContext().getTaskArgs(); - this.function.init(taskArgs.getParallelism(), taskArgs.getTaskIndex()); - } + @Override + public void open(OpContext opContext) { + super.open(opContext); + this.sourceCxt = new StreamSourceContext(); + TaskArgs taskArgs = opContext.getRuntimeContext().getTaskArgs(); + this.function.init(taskArgs.getParallelism(), taskArgs.getTaskIndex()); + } - @Override - public Boolean emit(long windowId) throws Exception { - try { - this.windowFunction.initWindow(windowId); - return this.function.fetch(this.windowFunction, sourceCxt); - } catch (Exception e) { - LOGGER.error(e.getMessage(), e); - throw new GeaflowRuntimeException(e); - } + @Override + public Boolean emit(long windowId) throws Exception { + try { + this.windowFunction.initWindow(windowId); + return this.function.fetch(this.windowFunction, sourceCxt); + } catch (Exception e) { + LOGGER.error(e.getMessage(), e); + throw new GeaflowRuntimeException(e); } + } - @Override - public void close() { - super.close(); - this.function.close(); - } + @Override + public void close() { + super.close(); + this.function.close(); + } - class StreamSourceContext implements SourceContext { + class StreamSourceContext implements SourceContext { - public StreamSourceContext() { - } - - @Override - public boolean collect(OUT element) throws Exception { - collectValue(element); - return true; - } + public StreamSourceContext() {} + @Override + public boolean collect(OUT element) throws Exception { + collectValue(element); + return true; } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/iterator/IteratorOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/iterator/IteratorOperator.java index 081554f9c..00ab8b043 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/iterator/IteratorOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/iterator/IteratorOperator.java @@ -23,19 +23,12 @@ public interface IteratorOperator extends Operator { - /** - * Returns max iteration count. - */ - long getMaxIterationCount(); + /** Returns max iteration count. */ + long getMaxIterationCount(); - /** - * Initialize the windowId iteration. - */ - void initIteration(long iteration); - - /** - * Finish the windowId iteration. - */ - void finishIteration(long iteration); + /** Initialize the windowId iteration. */ + void initIteration(long iteration); + /** Finish the windowId iteration. */ + void finishIteration(long iteration); } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/AbstractTransactionOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/AbstractTransactionOperator.java index c1595ba56..fa55611b5 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/AbstractTransactionOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/AbstractTransactionOperator.java @@ -25,27 +25,26 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class AbstractTransactionOperator extends AbstractOneInputOperator> implements - TransactionTrait { +public abstract class AbstractTransactionOperator + extends AbstractOneInputOperator> implements TransactionTrait { - private static final Logger LOGGER = - LoggerFactory.getLogger(AbstractTransactionOperator.class); + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractTransactionOperator.class); - public AbstractTransactionOperator() { - super(); - } + public AbstractTransactionOperator() { + super(); + } - public AbstractTransactionOperator(SinkFunction func) { - super(func); - } + public AbstractTransactionOperator(SinkFunction func) { + super(func); + } - @Override - public void finish(long windowId) { - LOGGER.info("transaction operator finish batch:{}", windowId); - } + @Override + public void finish(long windowId) { + LOGGER.info("transaction operator finish batch:{}", windowId); + } - @Override - public void rollback(long windowId) { - LOGGER.info("transaction operator rollback batch:{}", windowId); - } + @Override + public void rollback(long windowId) { + LOGGER.info("transaction operator rollback batch:{}", windowId); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/BroadcastOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/BroadcastOperator.java index e97f4e738..8f3b2e9c9 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/BroadcastOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/BroadcastOperator.java @@ -25,11 +25,11 @@ public class BroadcastOperator extends AbstractOneInputOperator { - @SuppressWarnings("unchecked") - @Override - protected void process(IN value) throws Exception { - for (ICollector collector : collectors) { - collector.broadcast(value); - } + @SuppressWarnings("unchecked") + @Override + protected void process(IN value) throws Exception { + for (ICollector collector : collectors) { + collector.broadcast(value); } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/CollectOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/CollectOperator.java index e619097c7..31f7a4158 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/CollectOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/CollectOperator.java @@ -24,13 +24,12 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class CollectOperator extends - AbstractOneInputOperator> { +public class CollectOperator extends AbstractOneInputOperator> { - private static final Logger LOGGER = LoggerFactory.getLogger(CollectOperator.class); + private static final Logger LOGGER = LoggerFactory.getLogger(CollectOperator.class); - @Override - protected void process(IN value) throws Exception { - collectValue(value); - } + @Override + protected void process(IN value) throws Exception { + collectValue(value); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/FilterOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/FilterOperator.java index 06d4c3a35..7e35839b9 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/FilterOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/FilterOperator.java @@ -24,16 +24,14 @@ public class FilterOperator extends AbstractOneInputOperator> { + public FilterOperator(FilterFunction filterFunction) { + super(filterFunction); + } - public FilterOperator(FilterFunction filterFunction) { - super(filterFunction); + @Override + protected void process(T value) throws Exception { + if (function.filter(value)) { + collectValue(value); } - - @Override - protected void process(T value) throws Exception { - if (function.filter(value)) { - collectValue(value); - } - } - + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/FlatMapOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/FlatMapOperator.java index 2d21ac883..69e97acff 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/FlatMapOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/FlatMapOperator.java @@ -23,24 +23,23 @@ import org.apache.geaflow.collector.CollectionCollector; import org.apache.geaflow.operator.base.window.AbstractOneInputOperator; -public class FlatMapOperator extends - AbstractOneInputOperator> { +public class FlatMapOperator + extends AbstractOneInputOperator> { - private CollectionCollector collectionCollector; + private CollectionCollector collectionCollector; - public FlatMapOperator(FlatMapFunction function) { - super(function); - } + public FlatMapOperator(FlatMapFunction function) { + super(function); + } - @Override - public void open(OpContext opContext) { - super.open(opContext); - this.collectionCollector = new CollectionCollector(opArgs.getOpId(), this.collectors); - } - - @Override - protected void process(IN value) throws Exception { - function.flatMap(value, collectionCollector); - } + @Override + public void open(OpContext opContext) { + super.open(opContext); + this.collectionCollector = new CollectionCollector(opArgs.getOpId(), this.collectors); + } + @Override + protected void process(IN value) throws Exception { + function.flatMap(value, collectionCollector); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/KeySelectorOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/KeySelectorOperator.java index 5e74dfab2..5181d2373 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/KeySelectorOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/KeySelectorOperator.java @@ -25,21 +25,21 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class KeySelectorOperator extends - AbstractOneInputOperator> { +public class KeySelectorOperator + extends AbstractOneInputOperator> { - private static final Logger LOGGER = LoggerFactory.getLogger(KeySelectorOperator.class); + private static final Logger LOGGER = LoggerFactory.getLogger(KeySelectorOperator.class); - public KeySelectorOperator(KeySelector keySelector) { - super(keySelector); - } + public KeySelectorOperator(KeySelector keySelector) { + super(keySelector); + } - @Override - protected void process(IN value) throws Exception { - KEY key = function.getKey(value); - if (key == null) { - key = (KEY) new Null(); - } - collectKValue(key, value); + @Override + protected void process(IN value) throws Exception { + KEY key = function.getKey(value); + if (key == null) { + key = (KEY) new Null(); } + collectKValue(key, value); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/MapOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/MapOperator.java index b51e631fe..286817f9a 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/MapOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/MapOperator.java @@ -24,16 +24,15 @@ public class MapOperator extends AbstractOneInputOperator> { - public MapOperator(MapFunction o) { - super(o); - } + public MapOperator(MapFunction o) { + super(o); + } - @Override - protected void process(I value) throws Exception { - O out = function.map(value); - if (out != null) { - collectValue(out); - } + @Override + protected void process(I value) throws Exception { + O out = function.map(value); + if (out != null) { + collectValue(out); } - + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/SinkOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/SinkOperator.java index 6f8a07758..303c1a2dc 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/SinkOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/SinkOperator.java @@ -26,40 +26,40 @@ public class SinkOperator extends AbstractTransactionOperator { - private static final Logger LOGGER = LoggerFactory.getLogger(SinkOperator.class); + private static final Logger LOGGER = LoggerFactory.getLogger(SinkOperator.class); - private boolean isTransactionFunc; + private boolean isTransactionFunc; - public SinkOperator() { - super(); - } + public SinkOperator() { + super(); + } - public SinkOperator(SinkFunction sinkFunction) { - super(sinkFunction); - this.isTransactionFunc = sinkFunction instanceof TransactionTrait; - } + public SinkOperator(SinkFunction sinkFunction) { + super(sinkFunction); + this.isTransactionFunc = sinkFunction instanceof TransactionTrait; + } - @Override - public void open(OpContext opContext) { - super.open(opContext); - } + @Override + public void open(OpContext opContext) { + super.open(opContext); + } - @Override - protected void process(T value) throws Exception { - this.function.write(value); - } + @Override + protected void process(T value) throws Exception { + this.function.write(value); + } - @Override - public void finish(long windowId) { - if (this.isTransactionFunc) { - ((TransactionTrait) this.function).finish(windowId); - } + @Override + public void finish(long windowId) { + if (this.isTransactionFunc) { + ((TransactionTrait) this.function).finish(windowId); } + } - @Override - public void rollback(long windowId) { - if (this.isTransactionFunc) { - ((TransactionTrait) this.function).rollback(windowId); - } + @Override + public void rollback(long windowId) { + if (this.isTransactionFunc) { + ((TransactionTrait) this.function).rollback(windowId); } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/UnionOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/UnionOperator.java index 2e96c3b0e..d0ede1185 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/UnionOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/UnionOperator.java @@ -24,13 +24,12 @@ public class UnionOperator extends AbstractOneInputOperator { - public UnionOperator() { - super(); - } - - @Override - protected void process(I value) throws Exception { - collectValue(value); - } + public UnionOperator() { + super(); + } + @Override + protected void process(I value) throws Exception { + collectValue(value); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/WindowAggregateOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/WindowAggregateOperator.java index c108b7b26..26d40b099 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/WindowAggregateOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/WindowAggregateOperator.java @@ -19,58 +19,59 @@ package org.apache.geaflow.operator.impl.window; -import com.google.common.annotations.VisibleForTesting; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.api.function.base.AggregateFunction; import org.apache.geaflow.api.function.base.KeySelector; import org.apache.geaflow.operator.base.window.AbstractOneInputOperator; -public class WindowAggregateOperator extends - AbstractOneInputOperator> { +import com.google.common.annotations.VisibleForTesting; - private transient Map aggregatingState; - private KeySelector keySelector; +public class WindowAggregateOperator + extends AbstractOneInputOperator> { - public WindowAggregateOperator(AggregateFunction aggregateFunction, - KeySelector keySelector) { - super(aggregateFunction); - this.keySelector = keySelector; - } + private transient Map aggregatingState; + private KeySelector keySelector; - @Override - public void open(OpContext opContext) { - super.open(opContext); - this.aggregatingState = new HashMap<>(); - } + public WindowAggregateOperator( + AggregateFunction aggregateFunction, KeySelector keySelector) { + super(aggregateFunction); + this.keySelector = keySelector; + } - @Override - protected void process(IN value) throws Exception { - KEY key = keySelector.getKey(value); - ACC acc = aggregatingState.get(key); + @Override + public void open(OpContext opContext) { + super.open(opContext); + this.aggregatingState = new HashMap<>(); + } - if (acc == null) { - acc = this.function.createAccumulator(); - } - this.function.add(value, acc); - aggregatingState.put(key, acc); - } + @Override + protected void process(IN value) throws Exception { + KEY key = keySelector.getKey(value); + ACC acc = aggregatingState.get(key); - @Override - public void finish() { - for (ACC acc : aggregatingState.values()) { - OUT result = this.function.getResult(acc); - if (result != null) { - collectValue(result); - } - } - aggregatingState.clear(); - super.finish(); + if (acc == null) { + acc = this.function.createAccumulator(); } + this.function.add(value, acc); + aggregatingState.put(key, acc); + } - @VisibleForTesting - public void processValue(IN value) throws Exception { - process(value); + @Override + public void finish() { + for (ACC acc : aggregatingState.values()) { + OUT result = this.function.getResult(acc); + if (result != null) { + collectValue(result); + } } + aggregatingState.clear(); + super.finish(); + } + @VisibleForTesting + public void processValue(IN value) throws Exception { + process(value); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/WindowReduceOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/WindowReduceOperator.java index 8119a4368..b67df246c 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/WindowReduceOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/WindowReduceOperator.java @@ -19,51 +19,53 @@ package org.apache.geaflow.operator.impl.window; -import com.google.common.annotations.VisibleForTesting; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.api.function.base.KeySelector; import org.apache.geaflow.api.function.base.ReduceFunction; import org.apache.geaflow.operator.base.window.AbstractOneInputOperator; +import com.google.common.annotations.VisibleForTesting; + public class WindowReduceOperator extends AbstractOneInputOperator> { - private final Map valueState; - private final KeySelector keySelector; + private final Map valueState; + private final KeySelector keySelector; - public WindowReduceOperator(ReduceFunction function, KeySelector keySelector) { - super(function); - this.keySelector = keySelector; - this.valueState = new HashMap<>(); - } + public WindowReduceOperator(ReduceFunction function, KeySelector keySelector) { + super(function); + this.keySelector = keySelector; + this.valueState = new HashMap<>(); + } - @Override - protected void process(T value) throws Exception { - KEY key = keySelector.getKey(value); - T oldValue = valueState.get(key); + @Override + protected void process(T value) throws Exception { + KEY key = keySelector.getKey(value); + T oldValue = valueState.get(key); - T newValue; - if (oldValue == null) { - newValue = value; - } else { - newValue = function.reduce(oldValue, value); - } - valueState.put(key, newValue); + T newValue; + if (oldValue == null) { + newValue = value; + } else { + newValue = function.reduce(oldValue, value); } + valueState.put(key, newValue); + } - @Override - public void finish() { - for (T value : valueState.values()) { - if (value != null) { - collectValue(value); - } - } - super.finish(); - valueState.clear(); + @Override + public void finish() { + for (T value : valueState.values()) { + if (value != null) { + collectValue(value); + } } + super.finish(); + valueState.clear(); + } - @VisibleForTesting - public void processValue(T value) throws Exception { - process(value); - } + @VisibleForTesting + public void processValue(T value) throws Exception { + process(value); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/incremental/IncrAggregateOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/incremental/IncrAggregateOperator.java index 28059fdac..fb290d1bd 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/incremental/IncrAggregateOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/incremental/IncrAggregateOperator.java @@ -23,6 +23,7 @@ import java.util.HashSet; import java.util.Set; + import org.apache.geaflow.api.function.base.AggregateFunction; import org.apache.geaflow.api.function.base.KeySelector; import org.apache.geaflow.api.trait.CheckpointTrait; @@ -36,74 +37,77 @@ import org.apache.geaflow.utils.keygroup.KeyGroupAssignerFactory; import org.apache.geaflow.utils.keygroup.KeyGroupAssignment; -public class IncrAggregateOperator extends - AbstractOneInputOperator> +public class IncrAggregateOperator + extends AbstractOneInputOperator> implements TransactionTrait, CheckpointTrait { - private transient KeyValueState aggregatingState; - private KeySelector keySelector; - private Set keySet; + private transient KeyValueState aggregatingState; + private KeySelector keySelector; + private Set keySet; - public IncrAggregateOperator(AggregateFunction aggregateFunction, - KeySelector keySelector) { - super(aggregateFunction); - this.keySelector = keySelector; - } + public IncrAggregateOperator( + AggregateFunction aggregateFunction, KeySelector keySelector) { + super(aggregateFunction); + this.keySelector = keySelector; + } - @Override - public void open(OpContext opContext) { - super.open(opContext); - KeyValueStateDescriptor descriptor = KeyValueStateDescriptor.build( + @Override + public void open(OpContext opContext) { + super.open(opContext); + KeyValueStateDescriptor descriptor = + KeyValueStateDescriptor.build( getIdentify(), this.runtimeContext.getConfiguration().getString(SYSTEM_STATE_BACKEND_TYPE)); - int taskIndex = this.runtimeContext.getTaskArgs().getTaskIndex(); - int parallelism = this.runtimeContext.getTaskArgs().getParallelism(); - int maxParallelism = this.runtimeContext.getTaskArgs().getMaxParallelism(); - KeyGroup keyGroup = KeyGroupAssignment.computeKeyGroupRangeForOperatorIndex( + int taskIndex = this.runtimeContext.getTaskArgs().getTaskIndex(); + int parallelism = this.runtimeContext.getTaskArgs().getParallelism(); + int maxParallelism = this.runtimeContext.getTaskArgs().getMaxParallelism(); + KeyGroup keyGroup = + KeyGroupAssignment.computeKeyGroupRangeForOperatorIndex( maxParallelism, parallelism, taskIndex); - descriptor.withKeyGroup(keyGroup); - IKeyGroupAssigner keyGroupAssigner = KeyGroupAssignerFactory.createKeyGroupAssigner( - keyGroup, taskIndex, maxParallelism); - descriptor.withKeyGroupAssigner(keyGroupAssigner); - this.aggregatingState = StateFactory.buildKeyValueState(descriptor, this.runtimeContext.getConfiguration()); - this.keySet = new HashSet<>(); - } + descriptor.withKeyGroup(keyGroup); + IKeyGroupAssigner keyGroupAssigner = + KeyGroupAssignerFactory.createKeyGroupAssigner(keyGroup, taskIndex, maxParallelism); + descriptor.withKeyGroupAssigner(keyGroupAssigner); + this.aggregatingState = + StateFactory.buildKeyValueState(descriptor, this.runtimeContext.getConfiguration()); + this.keySet = new HashSet<>(); + } - @Override - protected void process(IN value) throws Exception { - KEY key = this.keySelector.getKey(value); - ACC acc = this.aggregatingState.get(key); + @Override + protected void process(IN value) throws Exception { + KEY key = this.keySelector.getKey(value); + ACC acc = this.aggregatingState.get(key); - if (acc == null) { - acc = this.function.createAccumulator(); - } - keySet.add(key); - this.function.add(value, acc); - aggregatingState.put(key, acc); + if (acc == null) { + acc = this.function.createAccumulator(); } + keySet.add(key); + this.function.add(value, acc); + aggregatingState.put(key, acc); + } - @Override - public void finish(long windowId) { - for (KEY key : keySet) { - ACC acc = aggregatingState.get(key); - OUT result = this.function.getResult(acc); - if (result != null) { - collectValue(result); - } - } - keySet.clear(); + @Override + public void finish(long windowId) { + for (KEY key : keySet) { + ACC acc = aggregatingState.get(key); + OUT result = this.function.getResult(acc); + if (result != null) { + collectValue(result); + } } + keySet.clear(); + } - @Override - public void rollback(long windowId) { - this.aggregatingState.manage().operate().setCheckpointId(windowId); - this.aggregatingState.manage().operate().recover(); - } + @Override + public void rollback(long windowId) { + this.aggregatingState.manage().operate().setCheckpointId(windowId); + this.aggregatingState.manage().operate().recover(); + } - @Override - public void checkpoint(long windowId) { - this.aggregatingState.manage().operate().setCheckpointId(windowId); - this.aggregatingState.manage().operate().finish(); - this.aggregatingState.manage().operate().archive(); - } + @Override + public void checkpoint(long windowId) { + this.aggregatingState.manage().operate().setCheckpointId(windowId); + this.aggregatingState.manage().operate().finish(); + this.aggregatingState.manage().operate().archive(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/incremental/IncrReduceOperator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/incremental/IncrReduceOperator.java index cdb9888ff..1638b29b3 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/incremental/IncrReduceOperator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/window/incremental/IncrReduceOperator.java @@ -37,65 +37,66 @@ public class IncrReduceOperator extends AbstractOneInputOperator> implements TransactionTrait, CheckpointTrait { - private transient KeyValueState aggregatingState; - private final KeySelector keySelector; + private transient KeyValueState aggregatingState; + private final KeySelector keySelector; - public IncrReduceOperator(ReduceFunction function, KeySelector keySelector) { - super(function); - this.keySelector = keySelector; - } + public IncrReduceOperator(ReduceFunction function, KeySelector keySelector) { + super(function); + this.keySelector = keySelector; + } - @Override - public void open(OpContext opContext) { - super.open(opContext); - KeyValueStateDescriptor descriptor = KeyValueStateDescriptor.build( + @Override + public void open(OpContext opContext) { + super.open(opContext); + KeyValueStateDescriptor descriptor = + KeyValueStateDescriptor.build( getIdentify(), this.runtimeContext.getConfiguration().getString(SYSTEM_STATE_BACKEND_TYPE)); - int taskIndex = this.runtimeContext.getTaskArgs().getTaskIndex(); - int parallelism = this.runtimeContext.getTaskArgs().getParallelism(); - int maxParallelism = this.runtimeContext.getTaskArgs().getMaxParallelism(); - KeyGroup keyGroup = KeyGroupAssignment.computeKeyGroupRangeForOperatorIndex( + int taskIndex = this.runtimeContext.getTaskArgs().getTaskIndex(); + int parallelism = this.runtimeContext.getTaskArgs().getParallelism(); + int maxParallelism = this.runtimeContext.getTaskArgs().getMaxParallelism(); + KeyGroup keyGroup = + KeyGroupAssignment.computeKeyGroupRangeForOperatorIndex( maxParallelism, parallelism, taskIndex); - descriptor.withKeyGroup(keyGroup); - IKeyGroupAssigner keyGroupAssigner = KeyGroupAssignerFactory.createKeyGroupAssigner( - keyGroup, taskIndex, maxParallelism); - descriptor.withKeyGroupAssigner(keyGroupAssigner); - this.aggregatingState = StateFactory.buildKeyValueState(descriptor, this.runtimeContext.getConfiguration()); - } + descriptor.withKeyGroup(keyGroup); + IKeyGroupAssigner keyGroupAssigner = + KeyGroupAssignerFactory.createKeyGroupAssigner(keyGroup, taskIndex, maxParallelism); + descriptor.withKeyGroupAssigner(keyGroupAssigner); + this.aggregatingState = + StateFactory.buildKeyValueState(descriptor, this.runtimeContext.getConfiguration()); + } - @Override - protected void process(T value) throws Exception { - KEY key = keySelector.getKey(value); - T oldValue = aggregatingState.get(key); + @Override + protected void process(T value) throws Exception { + KEY key = keySelector.getKey(value); + T oldValue = aggregatingState.get(key); - T newValue; - if (oldValue == null) { - newValue = value; - } else { - newValue = function.reduce(oldValue, value); - } - - aggregatingState.put(key, newValue); - if (newValue != null) { - collectValue(newValue); - } + T newValue; + if (oldValue == null) { + newValue = value; + } else { + newValue = function.reduce(oldValue, value); } - @Override - public void finish(long windowId) { - + aggregatingState.put(key, newValue); + if (newValue != null) { + collectValue(newValue); } + } - @Override - public void checkpoint(long windowId) { - this.aggregatingState.manage().operate().setCheckpointId(windowId); - this.aggregatingState.manage().operate().finish(); - this.aggregatingState.manage().operate().archive(); - } + @Override + public void finish(long windowId) {} - @Override - public void rollback(long windowId) { - this.aggregatingState.manage().operate().setCheckpointId(windowId); - this.aggregatingState.manage().operate().recover(); - } + @Override + public void checkpoint(long windowId) { + this.aggregatingState.manage().operate().setCheckpointId(windowId); + this.aggregatingState.manage().operate().finish(); + this.aggregatingState.manage().operate().archive(); + } + + @Override + public void rollback(long windowId) { + this.aggregatingState.manage().operate().setCheckpointId(windowId); + this.aggregatingState.manage().operate().recover(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/partitioner/impl/BroadCastPartitioner.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/partitioner/impl/BroadCastPartitioner.java index 218c2b8f6..58c789b85 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/partitioner/impl/BroadCastPartitioner.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/partitioner/impl/BroadCastPartitioner.java @@ -24,16 +24,16 @@ public class BroadCastPartitioner extends AbstractPartitioner { - public BroadCastPartitioner(int opId) { - super(opId); - } + public BroadCastPartitioner(int opId) { + super(opId); + } - public IPartition getPartition() { - return new BroadCastPartition(); - } + public IPartition getPartition() { + return new BroadCastPartition(); + } - @Override - public PartitionType getPartitionType() { - return PartitionType.broadcast; - } + @Override + public PartitionType getPartitionType() { + return PartitionType.broadcast; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/partitioner/impl/CustomPartitioner.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/partitioner/impl/CustomPartitioner.java index 6ed8f60f3..a020e44b4 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/partitioner/impl/CustomPartitioner.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/partitioner/impl/CustomPartitioner.java @@ -23,20 +23,20 @@ public class CustomPartitioner extends AbstractPartitioner { - private final IPartition partition; + private final IPartition partition; - public CustomPartitioner(int opId, IPartition partition) { - super(opId); - this.partition = partition; - } + public CustomPartitioner(int opId, IPartition partition) { + super(opId); + this.partition = partition; + } - @Override - public IPartition getPartition() { - return this.partition; - } + @Override + public IPartition getPartition() { + return this.partition; + } - @Override - public PartitionType getPartitionType() { - return PartitionType.custom; - } + @Override + public PartitionType getPartitionType() { + return PartitionType.custom; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/partitioner/impl/ForwardPartitioner.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/partitioner/impl/ForwardPartitioner.java index 3d137ef79..d1c53f60d 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/partitioner/impl/ForwardPartitioner.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/partitioner/impl/ForwardPartitioner.java @@ -24,22 +24,21 @@ public class ForwardPartitioner extends AbstractPartitioner { - public ForwardPartitioner() { - super(-1); - } - - public ForwardPartitioner(int opId) { - super(opId); - } - - @Override - public IPartition getPartition() { - return new RandomPartition(); - } - - @Override - public PartitionType getPartitionType() { - return PartitionType.forward; - } - + public ForwardPartitioner() { + super(-1); + } + + public ForwardPartitioner(int opId) { + super(opId); + } + + @Override + public IPartition getPartition() { + return new RandomPartition(); + } + + @Override + public PartitionType getPartitionType() { + return PartitionType.forward; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/base/AbstractOperatorTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/base/AbstractOperatorTest.java index ba8410875..a55d2ccac 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/base/AbstractOperatorTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/base/AbstractOperatorTest.java @@ -22,6 +22,7 @@ import static org.mockito.Matchers.any; import java.util.ArrayList; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.RichFunction; import org.apache.geaflow.common.config.Configuration; @@ -36,67 +37,67 @@ public class AbstractOperatorTest { - @Test - public void testChainedOperator() { - - TestFunction function = new TestFunction(); - AbstractOperator operator = new TestOperator(function); - TestFunction subFunction = new TestFunction(); - AbstractOperator subOperator = new TestOperator(subFunction); - operator.addNextOperator(subOperator); + @Test + public void testChainedOperator() { - RuntimeContext runtimeContext = Mockito.mock(RuntimeContext.class); - Configuration config = new Configuration(); - config.put(ExecutionConfigKeys.REPORTER_LIST.getKey(), ""); - MetricGroup metricGroup = MetricGroupRegistry.getInstance(config).getMetricGroup(); - Mockito.doReturn(metricGroup).when(runtimeContext).getMetric(); - Mockito.doReturn(runtimeContext).when(runtimeContext).clone(any()); - Mockito.doReturn(config).when(runtimeContext).getConfiguration(); + TestFunction function = new TestFunction(); + AbstractOperator operator = new TestOperator(function); + TestFunction subFunction = new TestFunction(); + AbstractOperator subOperator = new TestOperator(subFunction); + operator.addNextOperator(subOperator); - Operator.OpContext opContext = new AbstractOperator.DefaultOpContext(new ArrayList<>(), runtimeContext); - operator.open(opContext); + RuntimeContext runtimeContext = Mockito.mock(RuntimeContext.class); + Configuration config = new Configuration(); + config.put(ExecutionConfigKeys.REPORTER_LIST.getKey(), ""); + MetricGroup metricGroup = MetricGroupRegistry.getInstance(config).getMetricGroup(); + Mockito.doReturn(metricGroup).when(runtimeContext).getMetric(); + Mockito.doReturn(runtimeContext).when(runtimeContext).clone(any()); + Mockito.doReturn(config).when(runtimeContext).getConfiguration(); - Assert.assertTrue(function.isOpened()); - Assert.assertTrue(subFunction.isOpened()); + Operator.OpContext opContext = + new AbstractOperator.DefaultOpContext(new ArrayList<>(), runtimeContext); + operator.open(opContext); - operator.close(); - Assert.assertTrue(function.isClosed()); - Assert.assertTrue(subFunction.isClosed()); - } + Assert.assertTrue(function.isOpened()); + Assert.assertTrue(subFunction.isOpened()); - private class TestOperator extends AbstractOperator implements OneInputOperator { + operator.close(); + Assert.assertTrue(function.isClosed()); + Assert.assertTrue(subFunction.isClosed()); + } - public TestOperator(TestFunction function) { - super(function); - } + private class TestOperator extends AbstractOperator + implements OneInputOperator { - @Override - public void processElement(TestFunction value) { - } + public TestOperator(TestFunction function) { + super(function); } - private class TestFunction extends RichFunction { + @Override + public void processElement(TestFunction value) {} + } - private boolean opened; - private boolean closed; + private class TestFunction extends RichFunction { - @Override - public void open(RuntimeContext runtimeContext) { - this.opened = true; - } + private boolean opened; + private boolean closed; - @Override - public void close() { - this.closed = true; - } + @Override + public void open(RuntimeContext runtimeContext) { + this.opened = true; + } - public boolean isOpened() { - return opened; - } + @Override + public void close() { + this.closed = true; + } - public boolean isClosed() { - return closed; - } + public boolean isOpened() { + return opened; } + public boolean isClosed() { + return closed; + } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicEdgeQueryImplTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicEdgeQueryImplTest.java index daa9134e8..4a6ff145e 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicEdgeQueryImplTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicEdgeQueryImplTest.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.api.graph.function.vc.base.VertexCentricFunction.EdgeQuery; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.iterator.CloseableIterator; @@ -50,82 +51,93 @@ public class DynamicEdgeQueryImplTest { - private EdgeQuery edgeQuery; - private List> edges; - - @BeforeClass - public void setup() { - GraphStateDescriptor desc = - GraphStateDescriptor.build("test", StoreType.MEMORY.name()); - desc.withDataModel(DataModel.DYNAMIC_GRAPH); - - GraphMetaType graphMetaType = new GraphMetaType(IntegerType.INSTANCE, ValueVertex.class, - Integer.class, ValueEdge.class, Integer.class); - - GraphViewDesc graphViewDesc = GraphViewBuilder.createGraphView("test").withBackend( - BackendType.RocksDB).withSchema(graphMetaType).withShardNum(1).build(); - desc.withGraphMeta(new GraphMeta(graphViewDesc.getGraphMetaType())); - desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); - - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration()); - - graphState.dynamicGraph().V().add(0, new ValueVertex<>(0, 0)); - edges = new ArrayList<>(); - edges.add(new ValueEdge<>(0, 1, 1)); - edges.add(new ValueEdge<>(0, 2, 2)); - edges.add(new ValueEdge<>(0, 3, 3)); - edges.add(new ValueEdge<>(0, 4, 4)); - edges.add(new ValueEdge<>(0, 5, 5)); - edges.add(new ValueEdge<>(0, 6, 6, EdgeDirection.IN)); - edges.add(new ValueEdge<>(0, 7, 7, EdgeDirection.IN)); - edges.add(new ValueEdge<>(0, 8, 8, EdgeDirection.IN)); - edges.add(new ValueEdge<>(0, 9, 9, EdgeDirection.IN)); - edges.add(new ValueEdge<>(0, 10, 10, EdgeDirection.IN)); - - for (IEdge edge : edges) { - graphState.dynamicGraph().E().add(0, edge); - } - - edgeQuery = new DynamicEdgeQueryImpl(0, 0, graphState); - } - - @Test - public void testGetEdges() { - List> result = edgeQuery.getEdges(); - Assert.assertEquals(result, edges); - } - - @Test - public void testGetOutEdges() { - List> result = edgeQuery.getOutEdges(); - Assert.assertEquals(result, - edges.stream().filter(x -> x.getDirect() == EdgeDirection.OUT).collect( - Collectors.toList())); - } - - @Test - public void testGetInEdges() { - List> result = edgeQuery.getInEdges(); - Assert.assertEquals(result, - edges.stream().filter(x -> x.getDirect() == EdgeDirection.IN).collect( - Collectors.toList())); + private EdgeQuery edgeQuery; + private List> edges; + + @BeforeClass + public void setup() { + GraphStateDescriptor desc = + GraphStateDescriptor.build("test", StoreType.MEMORY.name()); + desc.withDataModel(DataModel.DYNAMIC_GRAPH); + + GraphMetaType graphMetaType = + new GraphMetaType( + IntegerType.INSTANCE, ValueVertex.class, Integer.class, ValueEdge.class, Integer.class); + + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView("test") + .withBackend(BackendType.RocksDB) + .withSchema(graphMetaType) + .withShardNum(1) + .build(); + desc.withGraphMeta(new GraphMeta(graphViewDesc.getGraphMetaType())); + desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); + + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration()); + + graphState.dynamicGraph().V().add(0, new ValueVertex<>(0, 0)); + edges = new ArrayList<>(); + edges.add(new ValueEdge<>(0, 1, 1)); + edges.add(new ValueEdge<>(0, 2, 2)); + edges.add(new ValueEdge<>(0, 3, 3)); + edges.add(new ValueEdge<>(0, 4, 4)); + edges.add(new ValueEdge<>(0, 5, 5)); + edges.add(new ValueEdge<>(0, 6, 6, EdgeDirection.IN)); + edges.add(new ValueEdge<>(0, 7, 7, EdgeDirection.IN)); + edges.add(new ValueEdge<>(0, 8, 8, EdgeDirection.IN)); + edges.add(new ValueEdge<>(0, 9, 9, EdgeDirection.IN)); + edges.add(new ValueEdge<>(0, 10, 10, EdgeDirection.IN)); + + for (IEdge edge : edges) { + graphState.dynamicGraph().E().add(0, edge); } - @Test - public void testTestGetEdges() { - CloseableIterator> outEdges = edgeQuery.getEdges(OutEdgeFilter.getInstance()); - List> outEdgesList = new ArrayList<>(); - outEdges.forEachRemaining(outEdgesList::add); - Assert.assertEquals(outEdgesList, - edges.stream().filter(x -> x.getDirect() == EdgeDirection.OUT).collect( - Collectors.toList())); - - CloseableIterator> inEdges = edgeQuery.getEdges(InEdgeFilter.getInstance()); - List> inEdgesList = new ArrayList<>(); - inEdges.forEachRemaining(inEdgesList::add); - Assert.assertEquals(inEdgesList, - edges.stream().filter(x -> x.getDirect() == EdgeDirection.IN).collect( - Collectors.toList())); - } + edgeQuery = new DynamicEdgeQueryImpl(0, 0, graphState); + } + + @Test + public void testGetEdges() { + List> result = edgeQuery.getEdges(); + Assert.assertEquals(result, edges); + } + + @Test + public void testGetOutEdges() { + List> result = edgeQuery.getOutEdges(); + Assert.assertEquals( + result, + edges.stream() + .filter(x -> x.getDirect() == EdgeDirection.OUT) + .collect(Collectors.toList())); + } + + @Test + public void testGetInEdges() { + List> result = edgeQuery.getInEdges(); + Assert.assertEquals( + result, + edges.stream().filter(x -> x.getDirect() == EdgeDirection.IN).collect(Collectors.toList())); + } + + @Test + public void testTestGetEdges() { + CloseableIterator> outEdges = + edgeQuery.getEdges(OutEdgeFilter.getInstance()); + List> outEdgesList = new ArrayList<>(); + outEdges.forEachRemaining(outEdgesList::add); + Assert.assertEquals( + outEdgesList, + edges.stream() + .filter(x -> x.getDirect() == EdgeDirection.OUT) + .collect(Collectors.toList())); + + CloseableIterator> inEdges = + edgeQuery.getEdges(InEdgeFilter.getInstance()); + List> inEdgesList = new ArrayList<>(); + inEdges.forEachRemaining(inEdgesList::add); + Assert.assertEquals( + inEdgesList, + edges.stream().filter(x -> x.getDirect() == EdgeDirection.IN).collect(Collectors.toList())); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicVertexQueryImplTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicVertexQueryImplTest.java index 4307b33c8..a7c1b09ba 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicVertexQueryImplTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/dynamic/DynamicVertexQueryImplTest.java @@ -19,7 +19,6 @@ package org.apache.geaflow.operator.impl.graph.algo.vc.context.dynamic; - import org.apache.geaflow.api.graph.function.vc.base.VertexCentricFunction.VertexQuery; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.type.primitive.IntegerType; @@ -45,69 +44,74 @@ public class DynamicVertexQueryImplTest { - private VertexQuery vertexQuery; - private GraphState graphState; - - @BeforeClass - public void setup() { - GraphStateDescriptor desc = - GraphStateDescriptor.build("test", StoreType.MEMORY.name()); - desc.withDataModel(DataModel.DYNAMIC_GRAPH); - - GraphMetaType graphMetaType = new GraphMetaType(IntegerType.INSTANCE, ValueVertex.class, - Integer.class, ValueEdge.class, Integer.class); - - GraphViewDesc graphViewDesc = GraphViewBuilder.createGraphView("test").withBackend( - BackendType.RocksDB).withSchema(graphMetaType).withShardNum(1).build(); - desc.withGraphMeta(new GraphMeta(graphViewDesc.getGraphMetaType())); - desc.withKeyGroup(new KeyGroup(0, 1023)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1024)); - - graphState = StateFactory.buildGraphState(desc, - new Configuration()); - - graphState.dynamicGraph().V().add(0, new ValueVertex<>(1, 1)); - graphState.dynamicGraph().V().add(0, new ValueVertex<>(2, 2)); - graphState.dynamicGraph().V().add(0, new ValueVertex<>(3, 3)); - graphState.dynamicGraph().V().add(0, new ValueVertex<>(4, 4)); - graphState.dynamicGraph().V().add(0, new ValueVertex<>(5, 5)); - - vertexQuery = new DynamicVertexQueryImpl<>(1, 0, graphState); - - } - - @Test - public void testWithId() { - vertexQuery.withId(2); - IVertex vertex = vertexQuery.get(); - int k = vertex.getId(); - int v = vertex.getValue(); - Assert.assertEquals(k, 2); - Assert.assertEquals(v, 2); - } - - @Test - public void testGet() { - vertexQuery = new DynamicVertexQueryImpl<>(1, 0, graphState); - IVertex vertex = vertexQuery.get(); - int k = vertex.getId(); - int v = vertex.getValue(); - Assert.assertEquals(k, 1); - Assert.assertEquals(v, 1); - - } - - @Test - public void testTestGet() { - vertexQuery.withId(3); - IVertex vertex = vertexQuery.get(new IVertexFilter() { - @Override - public boolean filter(IVertex value) { + private VertexQuery vertexQuery; + private GraphState graphState; + + @BeforeClass + public void setup() { + GraphStateDescriptor desc = + GraphStateDescriptor.build("test", StoreType.MEMORY.name()); + desc.withDataModel(DataModel.DYNAMIC_GRAPH); + + GraphMetaType graphMetaType = + new GraphMetaType( + IntegerType.INSTANCE, ValueVertex.class, Integer.class, ValueEdge.class, Integer.class); + + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView("test") + .withBackend(BackendType.RocksDB) + .withSchema(graphMetaType) + .withShardNum(1) + .build(); + desc.withGraphMeta(new GraphMeta(graphViewDesc.getGraphMetaType())); + desc.withKeyGroup(new KeyGroup(0, 1023)) + .withKeyGroupAssigner(new DefaultKeyGroupAssigner(1024)); + + graphState = StateFactory.buildGraphState(desc, new Configuration()); + + graphState.dynamicGraph().V().add(0, new ValueVertex<>(1, 1)); + graphState.dynamicGraph().V().add(0, new ValueVertex<>(2, 2)); + graphState.dynamicGraph().V().add(0, new ValueVertex<>(3, 3)); + graphState.dynamicGraph().V().add(0, new ValueVertex<>(4, 4)); + graphState.dynamicGraph().V().add(0, new ValueVertex<>(5, 5)); + + vertexQuery = new DynamicVertexQueryImpl<>(1, 0, graphState); + } + + @Test + public void testWithId() { + vertexQuery.withId(2); + IVertex vertex = vertexQuery.get(); + int k = vertex.getId(); + int v = vertex.getValue(); + Assert.assertEquals(k, 2); + Assert.assertEquals(v, 2); + } + + @Test + public void testGet() { + vertexQuery = new DynamicVertexQueryImpl<>(1, 0, graphState); + IVertex vertex = vertexQuery.get(); + int k = vertex.getId(); + int v = vertex.getValue(); + Assert.assertEquals(k, 1); + Assert.assertEquals(v, 1); + } + + @Test + public void testTestGet() { + vertexQuery.withId(3); + IVertex vertex = + vertexQuery.get( + new IVertexFilter() { + @Override + public boolean filter(IVertex value) { return value.getValue() == 3; - } - }); - int k = vertex.getId(); - int v = vertex.getValue(); - Assert.assertEquals(k, 3); - Assert.assertEquals(v, 3); - } + } + }); + int k = vertex.getId(); + int v = vertex.getValue(); + Assert.assertEquals(k, 3); + Assert.assertEquals(v, 3); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticEdgeQueryImplTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticEdgeQueryImplTest.java index 2271c32d0..cdd4c1ded 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticEdgeQueryImplTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticEdgeQueryImplTest.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.api.graph.function.vc.base.VertexCentricFunction.EdgeQuery; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.iterator.CloseableIterator; @@ -49,82 +50,92 @@ public class StaticEdgeQueryImplTest { - private EdgeQuery edgeQuery; - private List> edges; - - @BeforeClass - public void setup() { - GraphStateDescriptor desc = - GraphStateDescriptor.build("test", StoreType.MEMORY.name()); - - GraphMetaType graphMetaType = new GraphMetaType(IntegerType.INSTANCE, ValueVertex.class, - Integer.class, ValueEdge.class, Integer.class); - - GraphViewDesc graphViewDesc = GraphViewBuilder.createGraphView("test").withBackend( - BackendType.RocksDB).withSchema(graphMetaType).withShardNum(1).build(); - desc.withGraphMeta(new GraphMeta(graphViewDesc.getGraphMetaType())); - desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); - - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration()); - - graphState.staticGraph().V().add(new ValueVertex<>(0, 0)); - edges = new ArrayList<>(); - edges.add(new ValueEdge<>(0, 1, 1)); - edges.add(new ValueEdge<>(0, 2, 2)); - edges.add(new ValueEdge<>(0, 3, 3)); - edges.add(new ValueEdge<>(0, 4, 4)); - edges.add(new ValueEdge<>(0, 5, 5)); - edges.add(new ValueEdge<>(0, 6, 6, EdgeDirection.IN)); - edges.add(new ValueEdge<>(0, 7, 7, EdgeDirection.IN)); - edges.add(new ValueEdge<>(0, 8, 8, EdgeDirection.IN)); - edges.add(new ValueEdge<>(0, 9, 9, EdgeDirection.IN)); - edges.add(new ValueEdge<>(0, 10, 10, EdgeDirection.IN)); - - for (IEdge edge : edges) { - graphState.staticGraph().E().add(edge); - } - - edgeQuery = new StaticEdgeQueryImpl<>(0, graphState); + private EdgeQuery edgeQuery; + private List> edges; + + @BeforeClass + public void setup() { + GraphStateDescriptor desc = + GraphStateDescriptor.build("test", StoreType.MEMORY.name()); + + GraphMetaType graphMetaType = + new GraphMetaType( + IntegerType.INSTANCE, ValueVertex.class, Integer.class, ValueEdge.class, Integer.class); + + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView("test") + .withBackend(BackendType.RocksDB) + .withSchema(graphMetaType) + .withShardNum(1) + .build(); + desc.withGraphMeta(new GraphMeta(graphViewDesc.getGraphMetaType())); + desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); + + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration()); + + graphState.staticGraph().V().add(new ValueVertex<>(0, 0)); + edges = new ArrayList<>(); + edges.add(new ValueEdge<>(0, 1, 1)); + edges.add(new ValueEdge<>(0, 2, 2)); + edges.add(new ValueEdge<>(0, 3, 3)); + edges.add(new ValueEdge<>(0, 4, 4)); + edges.add(new ValueEdge<>(0, 5, 5)); + edges.add(new ValueEdge<>(0, 6, 6, EdgeDirection.IN)); + edges.add(new ValueEdge<>(0, 7, 7, EdgeDirection.IN)); + edges.add(new ValueEdge<>(0, 8, 8, EdgeDirection.IN)); + edges.add(new ValueEdge<>(0, 9, 9, EdgeDirection.IN)); + edges.add(new ValueEdge<>(0, 10, 10, EdgeDirection.IN)); + + for (IEdge edge : edges) { + graphState.staticGraph().E().add(edge); } - - @Test - public void testGetEdges() { - List> result = edgeQuery.getEdges(); - Assert.assertEquals(result, edges); - } - - @Test - public void testGetOutEdges() { - List> result = edgeQuery.getOutEdges(); - Assert.assertEquals(result, - edges.stream().filter(x -> x.getDirect() == EdgeDirection.OUT).collect( - Collectors.toList())); - } - - @Test - public void testGetInEdges() { - List> result = edgeQuery.getInEdges(); - Assert.assertEquals(result, - edges.stream().filter(x -> x.getDirect() == EdgeDirection.IN).collect( - Collectors.toList())); - } - - @Test - public void testTestGetEdges() { - CloseableIterator> outEdges = edgeQuery.getEdges(OutEdgeFilter.getInstance()); - List> outEdgesList = new ArrayList<>(); - outEdges.forEachRemaining(outEdgesList::add); - Assert.assertEquals(outEdgesList, - edges.stream().filter(x -> x.getDirect() == EdgeDirection.OUT).collect( - Collectors.toList())); - - CloseableIterator> inEdges = edgeQuery.getEdges(InEdgeFilter.getInstance()); - List> inEdgesList = new ArrayList<>(); - inEdges.forEachRemaining(inEdgesList::add); - Assert.assertEquals(inEdgesList, - edges.stream().filter(x -> x.getDirect() == EdgeDirection.IN).collect( - Collectors.toList())); - } + edgeQuery = new StaticEdgeQueryImpl<>(0, graphState); + } + + @Test + public void testGetEdges() { + List> result = edgeQuery.getEdges(); + Assert.assertEquals(result, edges); + } + + @Test + public void testGetOutEdges() { + List> result = edgeQuery.getOutEdges(); + Assert.assertEquals( + result, + edges.stream() + .filter(x -> x.getDirect() == EdgeDirection.OUT) + .collect(Collectors.toList())); + } + + @Test + public void testGetInEdges() { + List> result = edgeQuery.getInEdges(); + Assert.assertEquals( + result, + edges.stream().filter(x -> x.getDirect() == EdgeDirection.IN).collect(Collectors.toList())); + } + + @Test + public void testTestGetEdges() { + CloseableIterator> outEdges = + edgeQuery.getEdges(OutEdgeFilter.getInstance()); + List> outEdgesList = new ArrayList<>(); + outEdges.forEachRemaining(outEdgesList::add); + Assert.assertEquals( + outEdgesList, + edges.stream() + .filter(x -> x.getDirect() == EdgeDirection.OUT) + .collect(Collectors.toList())); + + CloseableIterator> inEdges = + edgeQuery.getEdges(InEdgeFilter.getInstance()); + List> inEdgesList = new ArrayList<>(); + inEdges.forEachRemaining(inEdgesList::add); + Assert.assertEquals( + inEdgesList, + edges.stream().filter(x -> x.getDirect() == EdgeDirection.IN).collect(Collectors.toList())); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticVertexQueryImplTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticVertexQueryImplTest.java index c9c6c43ec..ef39d66de 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticVertexQueryImplTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/algo/vc/context/statical/StaticVertexQueryImplTest.java @@ -43,69 +43,73 @@ public class StaticVertexQueryImplTest { - private VertexQuery vertexQuery; - private GraphState graphState; + private VertexQuery vertexQuery; + private GraphState graphState; - @BeforeClass - public void setup() { - GraphStateDescriptor desc = - GraphStateDescriptor.build("test", StoreType.MEMORY.name()); + @BeforeClass + public void setup() { + GraphStateDescriptor desc = + GraphStateDescriptor.build("test", StoreType.MEMORY.name()); - GraphMetaType graphMetaType = new GraphMetaType(IntegerType.INSTANCE, ValueVertex.class, - Integer.class, ValueEdge.class, Integer.class); + GraphMetaType graphMetaType = + new GraphMetaType( + IntegerType.INSTANCE, ValueVertex.class, Integer.class, ValueEdge.class, Integer.class); - GraphViewDesc graphViewDesc = GraphViewBuilder.createGraphView("test").withBackend( - BackendType.RocksDB).withSchema(graphMetaType).withShardNum(1).build(); - desc.withGraphMeta(new GraphMeta(graphViewDesc.getGraphMetaType())); - desc.withKeyGroup(new KeyGroup(0, 1023)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1024)); + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView("test") + .withBackend(BackendType.RocksDB) + .withSchema(graphMetaType) + .withShardNum(1) + .build(); + desc.withGraphMeta(new GraphMeta(graphViewDesc.getGraphMetaType())); + desc.withKeyGroup(new KeyGroup(0, 1023)) + .withKeyGroupAssigner(new DefaultKeyGroupAssigner(1024)); - graphState = StateFactory.buildGraphState(desc, - new Configuration()); + graphState = StateFactory.buildGraphState(desc, new Configuration()); - graphState.staticGraph().V().add(new ValueVertex<>(1, 1)); - graphState.staticGraph().V().add(new ValueVertex<>(2, 2)); - graphState.staticGraph().V().add(new ValueVertex<>(3, 3)); - graphState.staticGraph().V().add(new ValueVertex<>(4, 4)); - graphState.staticGraph().V().add(new ValueVertex<>(5, 5)); + graphState.staticGraph().V().add(new ValueVertex<>(1, 1)); + graphState.staticGraph().V().add(new ValueVertex<>(2, 2)); + graphState.staticGraph().V().add(new ValueVertex<>(3, 3)); + graphState.staticGraph().V().add(new ValueVertex<>(4, 4)); + graphState.staticGraph().V().add(new ValueVertex<>(5, 5)); - vertexQuery = new StaticVertexQueryImpl<>(1, graphState); + vertexQuery = new StaticVertexQueryImpl<>(1, graphState); + } - } + @Test + public void testWithId() { + vertexQuery.withId(3); + IVertex valueVertex = vertexQuery.get(); + int k = valueVertex.getId(); + int v = valueVertex.getValue(); + Assert.assertEquals(k, 3); + Assert.assertEquals(v, 3); + } + @Test + public void testGet() { + vertexQuery = new StaticVertexQueryImpl<>(1, graphState); + IVertex valueVertex = vertexQuery.get(); + int k = valueVertex.getId(); + int v = valueVertex.getValue(); + Assert.assertEquals(k, 1); + Assert.assertEquals(v, 1); + } - @Test - public void testWithId() { - vertexQuery.withId(3); - IVertex valueVertex = vertexQuery.get(); - int k = valueVertex.getId(); - int v = valueVertex.getValue(); - Assert.assertEquals(k, 3); - Assert.assertEquals(v, 3); - } - - @Test - public void testGet() { - vertexQuery = new StaticVertexQueryImpl<>(1, graphState); - IVertex valueVertex = vertexQuery.get(); - int k = valueVertex.getId(); - int v = valueVertex.getValue(); - Assert.assertEquals(k, 1); - Assert.assertEquals(v, 1); - } - - @Test - public void testTestGet() { - vertexQuery.withId(4); - IVertex valueVertex = vertexQuery.get(new IVertexFilter() { - @Override - public boolean filter(IVertex value) { + @Test + public void testTestGet() { + vertexQuery.withId(4); + IVertex valueVertex = + vertexQuery.get( + new IVertexFilter() { + @Override + public boolean filter(IVertex value) { return value.getValue() == 4; - } - }); - int k = valueVertex.getId(); - int v = valueVertex.getValue(); - Assert.assertEquals(k, 4); - Assert.assertEquals(v, 4); - } + } + }); + int k = valueVertex.getId(); + int v = valueVertex.getValue(); + Assert.assertEquals(k, 4); + Assert.assertEquals(v, 4); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/algo/vc/msgbox/MsgBoxTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/algo/vc/msgbox/MsgBoxTest.java index ed7f76dd6..a5cb861c1 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/algo/vc/msgbox/MsgBoxTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/algo/vc/msgbox/MsgBoxTest.java @@ -20,36 +20,36 @@ package org.apache.geaflow.operator.impl.graph.algo.vc.msgbox; import java.util.Map; + import org.testng.Assert; import org.testng.annotations.Test; public class MsgBoxTest { - @Test - public void testCombinedMsgBox() { - CombinedMsgBox box = new CombinedMsgBox<>(Integer::sum); - box.addInMessages(1, 1); - box.addInMessages(1, 2); - box.addInMessages(2, 2); - Map inBox = box.getInMessageBox(); - Assert.assertEquals(inBox.size(), 2); - Assert.assertEquals((int) inBox.get(1), 3); - Assert.assertEquals((int) inBox.get(2), 2); - box.clearInBox(); - Assert.assertEquals(box.getInMessageBox().size(), 0); - - box.addOutMessage(0, 5); - box.addOutMessage(1, 9); - box.addOutMessage(2, 1); - box.addOutMessage(2, 2); - box.addOutMessage(2, 3); - Map outBox = box.getOutMessageBox(); - Assert.assertEquals(outBox.size(), 3); - Assert.assertEquals((int) outBox.get(0), 5); - Assert.assertEquals((int) outBox.get(1), 9); - Assert.assertEquals((int) outBox.get(2), 6); - box.clearOutBox(); - Assert.assertEquals(box.getOutMessageBox().size(), 0); - } + @Test + public void testCombinedMsgBox() { + CombinedMsgBox box = new CombinedMsgBox<>(Integer::sum); + box.addInMessages(1, 1); + box.addInMessages(1, 2); + box.addInMessages(2, 2); + Map inBox = box.getInMessageBox(); + Assert.assertEquals(inBox.size(), 2); + Assert.assertEquals((int) inBox.get(1), 3); + Assert.assertEquals((int) inBox.get(2), 2); + box.clearInBox(); + Assert.assertEquals(box.getInMessageBox().size(), 0); + box.addOutMessage(0, 5); + box.addOutMessage(1, 9); + box.addOutMessage(2, 1); + box.addOutMessage(2, 2); + box.addOutMessage(2, 3); + Map outBox = box.getOutMessageBox(); + Assert.assertEquals(outBox.size(), 3); + Assert.assertEquals((int) outBox.get(0), 5); + Assert.assertEquals((int) outBox.get(1), 9); + Assert.assertEquals((int) outBox.get(2), 6); + box.clearOutBox(); + Assert.assertEquals(box.getOutMessageBox().size(), 0); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/DynamicGraphVertexCentricComputeOpTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/DynamicGraphVertexCentricComputeOpTest.java index 3e913f4f6..123523e1c 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/DynamicGraphVertexCentricComputeOpTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/DynamicGraphVertexCentricComputeOpTest.java @@ -27,6 +27,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.graph.compute.IncVertexCentricCompute; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricComputeFunction; @@ -53,104 +54,118 @@ public class DynamicGraphVertexCentricComputeOpTest { - @Test - public void testUpdateWindowId() { - GraphViewDesc graphViewDesc = GraphViewBuilder.createGraphView("test") + @Test + public void testUpdateWindowId() { + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView("test") .withShardNum(1) .withBackend(IViewDesc.BackendType.RocksDB) - .withSchema(new GraphMetaType<>(IntegerType.INSTANCE, ValueVertex.class, - Integer.class, ValueEdge.class, IntegerType.class)) + .withSchema( + new GraphMetaType<>( + IntegerType.INSTANCE, + ValueVertex.class, + Integer.class, + ValueEdge.class, + IntegerType.class)) .build(); - DynamicGraphVertexCentricComputeOp operator = new DynamicGraphVertexCentricComputeOp( - graphViewDesc, new IncVertexCentricCompute(5) { - @Override - public IncVertexCentricComputeFunction getIncComputeFunction() { + DynamicGraphVertexCentricComputeOp operator = + new DynamicGraphVertexCentricComputeOp( + graphViewDesc, + new IncVertexCentricCompute(5) { + @Override + public IncVertexCentricComputeFunction getIncComputeFunction() { return mock(IncVertexCentricComputeFunction.class); - } + } - @Override - public VertexCentricCombineFunction getCombineFunction() { + @Override + public VertexCentricCombineFunction getCombineFunction() { return null; - } - }); - ((AbstractOperator) operator).getOpArgs().setOpName("test"); - ((AbstractOperator) operator).getOpArgs() - .setConfig(new HashMap() {{ + } + }); + ((AbstractOperator) operator).getOpArgs().setOpName("test"); + ((AbstractOperator) operator) + .getOpArgs() + .setConfig( + new HashMap() { + { put(ENABLE_DETAIL_METRIC.getKey(), "false"); - }}); + } + }); + + List collectors = new ArrayList<>(); + ICollector collector = mock(ICollector.class); + when(collector.getId()).thenReturn(0); + when(collector.getTag()).thenReturn("tag"); + + collectors.add(collector); + collectors.add(collector); + long startWindowId = 0; + Operator.OpContext context = + new AbstractOperator.DefaultOpContext(collectors, new TestRuntimeContext()); + operator.open(context); + + Assert.assertEquals(startWindowId, ReflectionUtil.getField(operator, "windowId")); + + operator.initIteration(2); + Assert.assertEquals(startWindowId, ReflectionUtil.getField(operator, "windowId")); + Assert.assertEquals( + startWindowId, + ((RuntimeContext) ReflectionUtil.getField(operator, "runtimeContext")).getWindowId()); + + ((AbstractRuntimeContext) context.getRuntimeContext()).updateWindowId(3L); + operator.initIteration(1); + + Assert.assertEquals(3L, ReflectionUtil.getField(operator, "windowId")); + Assert.assertEquals( + 3L, ((RuntimeContext) ReflectionUtil.getField(operator, "runtimeContext")).getWindowId()); + } + + public class TestRuntimeContext extends AbstractRuntimeContext { - List collectors = new ArrayList<>(); - ICollector collector = mock(ICollector.class); - when(collector.getId()).thenReturn(0); - when(collector.getTag()).thenReturn("tag"); + public TestRuntimeContext() { + super(new Configuration()); + } + + public TestRuntimeContext(Map opConfig) { + super(new Configuration(opConfig)); + } - collectors.add(collector); - collectors.add(collector); - long startWindowId = 0; - Operator.OpContext context = new AbstractOperator.DefaultOpContext(collectors, - new TestRuntimeContext()); - operator.open(context); + @Override + public long getPipelineId() { + return 0; + } - Assert.assertEquals(startWindowId, ReflectionUtil.getField(operator, "windowId")); + @Override + public String getPipelineName() { + return null; + } - operator.initIteration(2); - Assert.assertEquals(startWindowId, ReflectionUtil.getField(operator, "windowId")); - Assert.assertEquals(startWindowId, ((RuntimeContext) ReflectionUtil.getField(operator, "runtimeContext")).getWindowId()); + @Override + public TaskArgs getTaskArgs() { + return new TaskArgs(0, 0, "test", 1, 1, 0); + } - ((AbstractRuntimeContext) context.getRuntimeContext()).updateWindowId(3L); - operator.initIteration(1); + @Override + public Configuration getConfiguration() { + return jobConfig; + } - Assert.assertEquals(3L, ReflectionUtil.getField(operator, "windowId")); - Assert.assertEquals(3L, ((RuntimeContext) ReflectionUtil.getField(operator, "runtimeContext")).getWindowId()); + @Override + public RuntimeContext clone(Map opConfig) { + return new TestRuntimeContext(opConfig); } - public class TestRuntimeContext extends AbstractRuntimeContext { - - public TestRuntimeContext() { - super(new Configuration()); - } - - public TestRuntimeContext(Map opConfig) { - super(new Configuration(opConfig)); - } - - @Override - public long getPipelineId() { - return 0; - } - - @Override - public String getPipelineName() { - return null; - } - - @Override - public TaskArgs getTaskArgs() { - return new TaskArgs(0, 0, "test", 1, 1, 0); - } - - @Override - public Configuration getConfiguration() { - return jobConfig; - } - - @Override - public RuntimeContext clone(Map opConfig) { - return new TestRuntimeContext(opConfig); - } - - @Override - public long getWindowId() { - return windowId; - } - - @Override - public MetricGroup getMetric() { - Configuration config = new Configuration(); - config.put(ExecutionConfigKeys.REPORTER_LIST.getKey(), ""); - return MetricGroupRegistry.getInstance(config).getMetricGroup(); - } + @Override + public long getWindowId() { + return windowId; + } + @Override + public MetricGroup getMetric() { + Configuration config = new Configuration(); + config.put(ExecutionConfigKeys.REPORTER_LIST.getKey(), ""); + return MetricGroupRegistry.getInstance(config).getMetricGroup(); } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/io/WindowSourceOperatorTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/io/WindowSourceOperatorTest.java index dd5a997b9..1af90aa0c 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/io/WindowSourceOperatorTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/io/WindowSourceOperatorTest.java @@ -27,35 +27,34 @@ public class WindowSourceOperatorTest { - @Test - public void testSourceOperator() { + @Test + public void testSourceOperator() { - TestSourceFunction function = new TestSourceFunction(); - AbstractOperator operator = new WindowSourceOperator(function); - operator.close(); - Assert.assertTrue(function.isClosed()); - } + TestSourceFunction function = new TestSourceFunction(); + AbstractOperator operator = new WindowSourceOperator(function); + operator.close(); + Assert.assertTrue(function.isClosed()); + } - private class TestSourceFunction implements SourceFunction { + private class TestSourceFunction implements SourceFunction { - private boolean closed; + private boolean closed; - @Override - public void init(int parallel, int index) { - } + @Override + public void init(int parallel, int index) {} - @Override - public boolean fetch(IWindow window, SourceContext ctx) throws Exception { - return false; - } + @Override + public boolean fetch(IWindow window, SourceContext ctx) throws Exception { + return false; + } - @Override - public void close() { - this.closed = true; - } + @Override + public void close() { + this.closed = true; + } - public boolean isClosed() { - return closed; - } + public boolean isClosed() { + return closed; } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/window/IncrAggregateOperatorTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/window/IncrAggregateOperatorTest.java index 041552a7c..317923168 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/window/IncrAggregateOperatorTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/window/IncrAggregateOperatorTest.java @@ -24,9 +24,9 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import com.google.common.collect.Lists; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.base.AggregateFunction; import org.apache.geaflow.api.function.base.KeySelector; @@ -45,106 +45,107 @@ import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; +import com.google.common.collect.Lists; + public class IncrAggregateOperatorTest { - private static int batchSize = 10; - - private IncrAggregateOperator operator; - private MyAgg agg; - private Operator.OpContext opContext; - private Map batchId2Value; - - @BeforeMethod - public void setup() { - ICollector collector = mock(ICollector.class); - Configuration configuration = new Configuration(); - configuration.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); - RuntimeContext runtimeContext = mock(RuntimeContext.class); - when(runtimeContext.getConfiguration()).thenReturn(configuration); - when(runtimeContext.getTaskArgs()).thenReturn(new TaskArgs(1, 0, "agg", 1, 1024, 0)); - when(runtimeContext.clone(any(Map.class))).thenReturn(runtimeContext); - Configuration config = new Configuration(); - config.put(ExecutionConfigKeys.REPORTER_LIST.getKey(), ""); - MetricGroup metricGroup = MetricGroupRegistry.getInstance(config).getMetricGroup(); - Mockito.doReturn(metricGroup).when(runtimeContext).getMetric(); - Mockito.doReturn(runtimeContext).when(runtimeContext).clone(any()); - this.opContext = new AbstractOperator.DefaultOpContext( - Lists.newArrayList(collector), runtimeContext); - this.agg = new MyAgg(); - this.operator = new IncrAggregateOperator(this.agg, new KeySelectorFunc()); - this.operator.open(opContext); - - this.batchId2Value = new HashMap<>(); + private static int batchSize = 10; + + private IncrAggregateOperator operator; + private MyAgg agg; + private Operator.OpContext opContext; + private Map batchId2Value; + + @BeforeMethod + public void setup() { + ICollector collector = mock(ICollector.class); + Configuration configuration = new Configuration(); + configuration.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); + RuntimeContext runtimeContext = mock(RuntimeContext.class); + when(runtimeContext.getConfiguration()).thenReturn(configuration); + when(runtimeContext.getTaskArgs()).thenReturn(new TaskArgs(1, 0, "agg", 1, 1024, 0)); + when(runtimeContext.clone(any(Map.class))).thenReturn(runtimeContext); + Configuration config = new Configuration(); + config.put(ExecutionConfigKeys.REPORTER_LIST.getKey(), ""); + MetricGroup metricGroup = MetricGroupRegistry.getInstance(config).getMetricGroup(); + Mockito.doReturn(metricGroup).when(runtimeContext).getMetric(); + Mockito.doReturn(runtimeContext).when(runtimeContext).clone(any()); + this.opContext = + new AbstractOperator.DefaultOpContext(Lists.newArrayList(collector), runtimeContext); + this.agg = new MyAgg(); + this.operator = new IncrAggregateOperator(this.agg, new KeySelectorFunc()); + this.operator.open(opContext); + + this.batchId2Value = new HashMap<>(); + } + + @Test + public void testAgg() throws Exception { + long batchId = 1; + long value = 0; + for (int i = 0; i < 1000; i++) { + value += i; + if ((i + 1) % batchSize == 0) { + this.batchId2Value.put(batchId++, value); + } } - @Test - public void testAgg() throws Exception { - long batchId = 1; - long value = 0; - for (int i = 0; i < 1000; i++) { - value += i; - if ((i + 1) % batchSize == 0) { - this.batchId2Value.put(batchId++, value); - } - } - - batchId = 1; - for (int i = 0; i < 1000; i++) { - this.operator.processElement((long) i); - if ((i + 1) % batchSize == 0) { - this.operator.finish(batchId); - this.operator.checkpoint(batchId); - batchId++; - Assert.assertTrue(this.agg.getValue() == this.batchId2Value.get(batchId - 1)); - } - } + batchId = 1; + for (int i = 0; i < 1000; i++) { + this.operator.processElement((long) i); + if ((i + 1) % batchSize == 0) { + this.operator.finish(batchId); + this.operator.checkpoint(batchId); + batchId++; + Assert.assertTrue(this.agg.getValue() == this.batchId2Value.get(batchId - 1)); + } } + } - public static class KeySelectorFunc implements KeySelector { + public static class KeySelectorFunc implements KeySelector { - @Override - public Long getKey(Long value) { - return value; - } + @Override + public Long getKey(Long value) { + return value; } + } - class MutableLong { - - long value; - } + class MutableLong { - class MyAgg implements AggregateFunction { + long value; + } - private long value; + class MyAgg implements AggregateFunction { - public MyAgg() { - this.value = 0; - } + private long value; - @Override - public MutableLong createAccumulator() { - return new MutableLong(); - } + public MyAgg() { + this.value = 0; + } - @Override - public void add(Long value, MutableLong accumulator) { - accumulator.value += value; - this.value += value; - } + @Override + public MutableLong createAccumulator() { + return new MutableLong(); + } - @Override - public Long getResult(MutableLong accumulator) { - return accumulator.value; - } + @Override + public void add(Long value, MutableLong accumulator) { + accumulator.value += value; + this.value += value; + } - @Override - public MutableLong merge(MutableLong a, MutableLong b) { - return null; - } + @Override + public Long getResult(MutableLong accumulator) { + return accumulator.value; + } - public long getValue() { - return this.value; - } + @Override + public MutableLong merge(MutableLong a, MutableLong b) { + return null; } + public long getValue() { + return this.value; + } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/window/IncrReduceOperatorTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/window/IncrReduceOperatorTest.java index 4b81defeb..11fb97d2d 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/window/IncrReduceOperatorTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/window/IncrReduceOperatorTest.java @@ -24,9 +24,9 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import com.google.common.collect.Lists; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.base.KeySelector; import org.apache.geaflow.api.function.base.ReduceFunction; @@ -45,85 +45,87 @@ import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; +import com.google.common.collect.Lists; + public class IncrReduceOperatorTest { - private static final int BATCH_SIZE = 10; - - private IncrReduceOperator operator; - private MyReduce reduce; - private Operator.OpContext opContext; - private Map batchId2Value; - - @BeforeMethod - public void setup() { - ICollector collector = mock(ICollector.class); - Configuration configuration = new Configuration(); - configuration.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); - RuntimeContext runtimeContext = mock(RuntimeContext.class); - when(runtimeContext.getConfiguration()).thenReturn(configuration); - when(runtimeContext.getTaskArgs()).thenReturn(new TaskArgs(1, 0, "agg", 1, 1024, 0)); - when(runtimeContext.clone(any(Map.class))).thenReturn(runtimeContext); - Configuration config = new Configuration(); - config.put(ExecutionConfigKeys.REPORTER_LIST.getKey(), ""); - MetricGroup metricGroup = MetricGroupRegistry.getInstance(config).getMetricGroup(); - Mockito.doReturn(metricGroup).when(runtimeContext).getMetric(); - Mockito.doReturn(runtimeContext).when(runtimeContext).clone(any()); - this.opContext = new AbstractOperator.DefaultOpContext( - Lists.newArrayList(collector), runtimeContext); - this.reduce = new MyReduce(); - this.operator = new IncrReduceOperator(this.reduce, new KeySelectorFunc()); - this.operator.open(opContext); - - this.batchId2Value = new HashMap<>(); + private static final int BATCH_SIZE = 10; + + private IncrReduceOperator operator; + private MyReduce reduce; + private Operator.OpContext opContext; + private Map batchId2Value; + + @BeforeMethod + public void setup() { + ICollector collector = mock(ICollector.class); + Configuration configuration = new Configuration(); + configuration.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); + RuntimeContext runtimeContext = mock(RuntimeContext.class); + when(runtimeContext.getConfiguration()).thenReturn(configuration); + when(runtimeContext.getTaskArgs()).thenReturn(new TaskArgs(1, 0, "agg", 1, 1024, 0)); + when(runtimeContext.clone(any(Map.class))).thenReturn(runtimeContext); + Configuration config = new Configuration(); + config.put(ExecutionConfigKeys.REPORTER_LIST.getKey(), ""); + MetricGroup metricGroup = MetricGroupRegistry.getInstance(config).getMetricGroup(); + Mockito.doReturn(metricGroup).when(runtimeContext).getMetric(); + Mockito.doReturn(runtimeContext).when(runtimeContext).clone(any()); + this.opContext = + new AbstractOperator.DefaultOpContext(Lists.newArrayList(collector), runtimeContext); + this.reduce = new MyReduce(); + this.operator = new IncrReduceOperator(this.reduce, new KeySelectorFunc()); + this.operator.open(opContext); + + this.batchId2Value = new HashMap<>(); + } + + @Test + public void testReduce() throws Exception { + long batchId = 1; + long value = 0; + for (int i = 0; i < 1000; i++) { + value += i; + if ((i + 1) % BATCH_SIZE == 0) { + this.batchId2Value.put(batchId++, value); + } } - @Test - public void testReduce() throws Exception { - long batchId = 1; - long value = 0; - for (int i = 0; i < 1000; i++) { - value += i; - if ((i + 1) % BATCH_SIZE == 0) { - this.batchId2Value.put(batchId++, value); - } - } - - batchId = 1; - for (int i = 0; i < 1000; i++) { - this.operator.processElement(i); - if ((i + 1) % BATCH_SIZE == 0) { - this.operator.finish(); - this.operator.checkpoint(batchId); - batchId++; - Assert.assertTrue(this.reduce.getValue() == this.batchId2Value.get(batchId - 1)); - } - } + batchId = 1; + for (int i = 0; i < 1000; i++) { + this.operator.processElement(i); + if ((i + 1) % BATCH_SIZE == 0) { + this.operator.finish(); + this.operator.checkpoint(batchId); + batchId++; + Assert.assertTrue(this.reduce.getValue() == this.batchId2Value.get(batchId - 1)); + } } + } - public static class KeySelectorFunc implements KeySelector { + public static class KeySelectorFunc implements KeySelector { - @Override - public Integer getKey(Integer value) { - return 0; - } + @Override + public Integer getKey(Integer value) { + return 0; } + } - class MyReduce implements ReduceFunction { + class MyReduce implements ReduceFunction { - private int value; + private int value; - public MyReduce() { - this.value = 0; - } + public MyReduce() { + this.value = 0; + } - @Override - public Integer reduce(Integer oldValue, Integer newValue) { - value += newValue; - return oldValue + newValue; - } + @Override + public Integer reduce(Integer oldValue, Integer newValue) { + value += newValue; + return oldValue + newValue; + } - public int getValue() { - return value; - } + public int getValue() { + return value; } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/window/SinkOperatorTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/window/SinkOperatorTest.java index 643f78fb3..427ccdf6d 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/window/SinkOperatorTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/window/SinkOperatorTest.java @@ -22,10 +22,10 @@ import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; -import com.google.common.collect.Lists; import java.io.Closeable; import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.RichFunction; import org.apache.geaflow.api.function.io.SinkFunction; @@ -42,122 +42,119 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -public class SinkOperatorTest { +import com.google.common.collect.Lists; - private SinkOperator operator; - private TransactionSinkFunction sinkFunction; - private CommonSinkFunction commonSinkFunction; - private Operator.OpContext opContext; - - @BeforeClass - public void setup() { - ICollector collector = mock(ICollector.class); - RuntimeContext runtimeContext = mock(RuntimeContext.class); - Configuration config = new Configuration(); - config.put(ExecutionConfigKeys.REPORTER_LIST.getKey(), ""); - MetricGroup metricGroup = MetricGroupRegistry.getInstance(config).getMetricGroup(); - Mockito.doReturn(metricGroup).when(runtimeContext).getMetric(); - Mockito.doReturn(runtimeContext).when(runtimeContext).clone(any()); - Mockito.doReturn(config).when(runtimeContext).getConfiguration(); - this.opContext = new AbstractOperator.DefaultOpContext( - Lists.newArrayList(collector), runtimeContext); - } +public class SinkOperatorTest { - @Test - public void testWriteAndFinishWithTransactionSink() throws Exception { - this.sinkFunction = new TransactionSinkFunction(); - this.operator = new SinkOperator(this.sinkFunction); - this.operator.open(opContext); - - for (int i = 0; i < 103; i++) { - this.operator.process(i); - } - Assert.assertEquals(this.sinkFunction.getList().size(), 3); - this.operator.finish(1L); - Assert.assertEquals(this.sinkFunction.getList().size(), 0); + private SinkOperator operator; + private TransactionSinkFunction sinkFunction; + private CommonSinkFunction commonSinkFunction; + private Operator.OpContext opContext; + + @BeforeClass + public void setup() { + ICollector collector = mock(ICollector.class); + RuntimeContext runtimeContext = mock(RuntimeContext.class); + Configuration config = new Configuration(); + config.put(ExecutionConfigKeys.REPORTER_LIST.getKey(), ""); + MetricGroup metricGroup = MetricGroupRegistry.getInstance(config).getMetricGroup(); + Mockito.doReturn(metricGroup).when(runtimeContext).getMetric(); + Mockito.doReturn(runtimeContext).when(runtimeContext).clone(any()); + Mockito.doReturn(config).when(runtimeContext).getConfiguration(); + this.opContext = + new AbstractOperator.DefaultOpContext(Lists.newArrayList(collector), runtimeContext); + } + + @Test + public void testWriteAndFinishWithTransactionSink() throws Exception { + this.sinkFunction = new TransactionSinkFunction(); + this.operator = new SinkOperator(this.sinkFunction); + this.operator.open(opContext); + + for (int i = 0; i < 103; i++) { + this.operator.process(i); } - - @Test - public void testWriteAndFinishWithCommonSink() throws Exception { - this.commonSinkFunction = new CommonSinkFunction(); - this.operator = new SinkOperator(this.commonSinkFunction); - this.operator.open(opContext); - - for (int i = 0; i < 103; i++) { - this.operator.process(i); - } - Assert.assertEquals(this.commonSinkFunction.getList().size(), 3); - this.operator.finish(1L); - Assert.assertEquals(this.commonSinkFunction.getList().size(), 3); + Assert.assertEquals(this.sinkFunction.getList().size(), 3); + this.operator.finish(1L); + Assert.assertEquals(this.sinkFunction.getList().size(), 0); + } + + @Test + public void testWriteAndFinishWithCommonSink() throws Exception { + this.commonSinkFunction = new CommonSinkFunction(); + this.operator = new SinkOperator(this.commonSinkFunction); + this.operator.open(opContext); + + for (int i = 0; i < 103; i++) { + this.operator.process(i); } + Assert.assertEquals(this.commonSinkFunction.getList().size(), 3); + this.operator.finish(1L); + Assert.assertEquals(this.commonSinkFunction.getList().size(), 3); + } - static class TransactionSinkFunction extends RichFunction implements SinkFunction, Closeable, TransactionTrait { - - private List list; - private int num; + static class TransactionSinkFunction extends RichFunction + implements SinkFunction, Closeable, TransactionTrait { - @Override - public void open(RuntimeContext runtimeContext) { - list = new ArrayList<>(); - num = 1; - } + private List list; + private int num; - @Override - public void close() { - - } + @Override + public void open(RuntimeContext runtimeContext) { + list = new ArrayList<>(); + num = 1; + } - @Override - public void write(Integer value) throws Exception { - list.add(value); - if (num++ % 10 == 0) { - list.clear(); - num = 1; - } - } + @Override + public void close() {} - @Override - public void finish(long windowId) { - list.clear(); - } + @Override + public void write(Integer value) throws Exception { + list.add(value); + if (num++ % 10 == 0) { + list.clear(); + num = 1; + } + } - @Override - public void rollback(long windowId) { + @Override + public void finish(long windowId) { + list.clear(); + } - } + @Override + public void rollback(long windowId) {} - public List getList() { - return list; - } + public List getList() { + return list; } + } - static class CommonSinkFunction extends RichFunction implements SinkFunction, Closeable { + static class CommonSinkFunction extends RichFunction implements SinkFunction, Closeable { - private List list; - private int num; + private List list; + private int num; - @Override - public void open(RuntimeContext runtimeContext) { - list = new ArrayList<>(); - num = 1; - } - - @Override - public void close() { + @Override + public void open(RuntimeContext runtimeContext) { + list = new ArrayList<>(); + num = 1; + } - } + @Override + public void close() {} - @Override - public void write(Integer value) throws Exception { - list.add(value); - if (num++ % 10 == 0) { - list.clear(); - num = 1; - } - } + @Override + public void write(Integer value) throws Exception { + list.add(value); + if (num++ % 10 == 0) { + list.clear(); + num = 1; + } + } - public List getList() { - return list; - } + public List getList() { + return list; } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/window/WindowAggOperatorTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/window/WindowAggOperatorTest.java index e4ec3f1b8..f3b167ed7 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/window/WindowAggOperatorTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/window/WindowAggOperatorTest.java @@ -23,9 +23,9 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import com.google.common.collect.Lists; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.base.AggregateFunction; import org.apache.geaflow.api.function.base.KeySelector; @@ -42,103 +42,104 @@ import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; +import com.google.common.collect.Lists; + public class WindowAggOperatorTest { - private static int batchSize = 10; - - private WindowAggregateOperator operator; - private MyAgg agg; - private Operator.OpContext opContext; - private Map batchId2Value; - - @BeforeMethod - public void setup() { - ICollector collector = mock(ICollector.class); - RuntimeContext runtimeContext = mock(RuntimeContext.class); - when(runtimeContext.getConfiguration()).thenReturn(new Configuration()); - when(runtimeContext.getTaskArgs()).thenReturn(new TaskArgs(1, 0, "agg", 1, 1024, 0)); - when(runtimeContext.clone(any(Map.class))).thenReturn(runtimeContext); - Configuration config = new Configuration(); - config.put(ExecutionConfigKeys.REPORTER_LIST.getKey(), ""); - MetricGroup metricGroup = MetricGroupRegistry.getInstance(config).getMetricGroup(); - Mockito.doReturn(metricGroup).when(runtimeContext).getMetric(); - Mockito.doReturn(runtimeContext).when(runtimeContext).clone(any()); - this.opContext = new AbstractOperator.DefaultOpContext( - Lists.newArrayList(collector), runtimeContext); - this.agg = new MyAgg(); - this.operator = new WindowAggregateOperator<>(this.agg, new KeySelectorFunc()); - this.operator.open(opContext); - - this.batchId2Value = new HashMap<>(); + private static int batchSize = 10; + + private WindowAggregateOperator operator; + private MyAgg agg; + private Operator.OpContext opContext; + private Map batchId2Value; + + @BeforeMethod + public void setup() { + ICollector collector = mock(ICollector.class); + RuntimeContext runtimeContext = mock(RuntimeContext.class); + when(runtimeContext.getConfiguration()).thenReturn(new Configuration()); + when(runtimeContext.getTaskArgs()).thenReturn(new TaskArgs(1, 0, "agg", 1, 1024, 0)); + when(runtimeContext.clone(any(Map.class))).thenReturn(runtimeContext); + Configuration config = new Configuration(); + config.put(ExecutionConfigKeys.REPORTER_LIST.getKey(), ""); + MetricGroup metricGroup = MetricGroupRegistry.getInstance(config).getMetricGroup(); + Mockito.doReturn(metricGroup).when(runtimeContext).getMetric(); + Mockito.doReturn(runtimeContext).when(runtimeContext).clone(any()); + this.opContext = + new AbstractOperator.DefaultOpContext(Lists.newArrayList(collector), runtimeContext); + this.agg = new MyAgg(); + this.operator = new WindowAggregateOperator<>(this.agg, new KeySelectorFunc()); + this.operator.open(opContext); + + this.batchId2Value = new HashMap<>(); + } + + @Test + public void testAgg() throws Exception { + long batchId = 1; + long value = 0; + for (int i = 0; i < 1000; i++) { + value += i; + if ((i + 1) % batchSize == 0) { + this.batchId2Value.put(batchId++, value); + } } - @Test - public void testAgg() throws Exception { - long batchId = 1; - long value = 0; - for (int i = 0; i < 1000; i++) { - value += i; - if ((i + 1) % batchSize == 0) { - this.batchId2Value.put(batchId++, value); - } - } - - batchId = 1; - for (int i = 0; i < 1000; i++) { - this.operator.processValue((long) i); - if ((i + 1) % batchSize == 0) { - this.operator.finish(); - batchId++; - Assert.assertTrue(this.agg.getValue() == this.batchId2Value.get(batchId - 1)); - } - } + batchId = 1; + for (int i = 0; i < 1000; i++) { + this.operator.processValue((long) i); + if ((i + 1) % batchSize == 0) { + this.operator.finish(); + batchId++; + Assert.assertTrue(this.agg.getValue() == this.batchId2Value.get(batchId - 1)); + } } + } - public static class KeySelectorFunc implements KeySelector { + public static class KeySelectorFunc implements KeySelector { - @Override - public Long getKey(Long value) { - return value; - } + @Override + public Long getKey(Long value) { + return value; } + } - class MutableLong { - - long value; - } + class MutableLong { - class MyAgg implements AggregateFunction { + long value; + } - private long value; + class MyAgg implements AggregateFunction { - public MyAgg() { - this.value = 0; - } + private long value; - @Override - public MutableLong createAccumulator() { - return new MutableLong(); - } + public MyAgg() { + this.value = 0; + } - @Override - public void add(Long value, MutableLong accumulator) { - accumulator.value += value; - this.value += value; - } + @Override + public MutableLong createAccumulator() { + return new MutableLong(); + } - @Override - public Long getResult(MutableLong accumulator) { - return accumulator.value; - } + @Override + public void add(Long value, MutableLong accumulator) { + accumulator.value += value; + this.value += value; + } - @Override - public MutableLong merge(MutableLong a, MutableLong b) { - return null; - } + @Override + public Long getResult(MutableLong accumulator) { + return accumulator.value; + } - public long getValue() { - return this.value; - } + @Override + public MutableLong merge(MutableLong a, MutableLong b) { + return null; } + public long getValue() { + return this.value; + } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/window/WindowReduceOperatorTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/window/WindowReduceOperatorTest.java index 665f10627..dc7d9374c 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/window/WindowReduceOperatorTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/window/WindowReduceOperatorTest.java @@ -23,9 +23,9 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import com.google.common.collect.Lists; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.base.KeySelector; import org.apache.geaflow.api.function.base.ReduceFunction; @@ -42,89 +42,90 @@ import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; +import com.google.common.collect.Lists; + public class WindowReduceOperatorTest { - private static int batchSize = 10; - - private WindowReduceOperator operator; - private MyReduce reduce; - private Operator.OpContext opContext; - private Map batchId2Value; - - @BeforeMethod - public void setup() { - ICollector collector = mock(ICollector.class); - RuntimeContext runtimeContext = mock(RuntimeContext.class); - when(runtimeContext.getConfiguration()).thenReturn(new Configuration()); - when(runtimeContext.getTaskArgs()).thenReturn(new TaskArgs(1, 0, "agg", 1, 1024, 0)); - when(runtimeContext.clone(any(Map.class))).thenReturn(runtimeContext); - Configuration config = new Configuration(); - config.put(ExecutionConfigKeys.REPORTER_LIST.getKey(), ""); - MetricGroup metricGroup = MetricGroupRegistry.getInstance(config).getMetricGroup(); - Mockito.doReturn(metricGroup).when(runtimeContext).getMetric(); - Mockito.doReturn(runtimeContext).when(runtimeContext).clone(any()); - this.opContext = new AbstractOperator.DefaultOpContext( - Lists.newArrayList(collector), runtimeContext); - this.reduce = new MyReduce(); - this.operator = new WindowReduceOperator(this.reduce, new KeySelectorFunc()); - this.operator.open(opContext); - - this.batchId2Value = new HashMap<>(); + private static int batchSize = 10; + + private WindowReduceOperator operator; + private MyReduce reduce; + private Operator.OpContext opContext; + private Map batchId2Value; + + @BeforeMethod + public void setup() { + ICollector collector = mock(ICollector.class); + RuntimeContext runtimeContext = mock(RuntimeContext.class); + when(runtimeContext.getConfiguration()).thenReturn(new Configuration()); + when(runtimeContext.getTaskArgs()).thenReturn(new TaskArgs(1, 0, "agg", 1, 1024, 0)); + when(runtimeContext.clone(any(Map.class))).thenReturn(runtimeContext); + Configuration config = new Configuration(); + config.put(ExecutionConfigKeys.REPORTER_LIST.getKey(), ""); + MetricGroup metricGroup = MetricGroupRegistry.getInstance(config).getMetricGroup(); + Mockito.doReturn(metricGroup).when(runtimeContext).getMetric(); + Mockito.doReturn(runtimeContext).when(runtimeContext).clone(any()); + this.opContext = + new AbstractOperator.DefaultOpContext(Lists.newArrayList(collector), runtimeContext); + this.reduce = new MyReduce(); + this.operator = new WindowReduceOperator(this.reduce, new KeySelectorFunc()); + this.operator.open(opContext); + + this.batchId2Value = new HashMap<>(); + } + + @Test + public void testReduce() throws Exception { + long batchId = 1; + long value = 0; + for (int i = 0; i < 1000; i++) { + value += i; + if ((i + 1) % batchSize == 0) { + this.batchId2Value.put(batchId++, value); + value = 0; + } } - @Test - public void testReduce() throws Exception { - long batchId = 1; - long value = 0; - for (int i = 0; i < 1000; i++) { - value += i; - if ((i + 1) % batchSize == 0) { - this.batchId2Value.put(batchId++, value); - value = 0; - } - } - - batchId = 1; - for (int i = 0; i < 1000; i++) { - this.operator.processValue(i); - if ((i + 1) % batchSize == 0) { - this.operator.finish(); - batchId++; - Assert.assertTrue(this.reduce.getValue() == this.batchId2Value.get(batchId - 1)); - this.reduce.setValue(i + 1); - } - } + batchId = 1; + for (int i = 0; i < 1000; i++) { + this.operator.processValue(i); + if ((i + 1) % batchSize == 0) { + this.operator.finish(); + batchId++; + Assert.assertTrue(this.reduce.getValue() == this.batchId2Value.get(batchId - 1)); + this.reduce.setValue(i + 1); + } } + } - public static class KeySelectorFunc implements KeySelector { + public static class KeySelectorFunc implements KeySelector { - @Override - public Integer getKey(Integer value) { - return 0; - } + @Override + public Integer getKey(Integer value) { + return 0; } + } - class MyReduce implements ReduceFunction { + class MyReduce implements ReduceFunction { - private int value; + private int value; - public MyReduce() { - this.value = 0; - } - - @Override - public Integer reduce(Integer oldValue, Integer newValue) { - value += newValue; - return oldValue + newValue; - } + public MyReduce() { + this.value = 0; + } - public int getValue() { - return value; - } + @Override + public Integer reduce(Integer oldValue, Integer newValue) { + value += newValue; + return oldValue + newValue; + } - public void setValue(int value) { - this.value = value; - } + public int getValue() { + return value; } + public void setValue(int value) { + this.value = value; + } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/context/AbstractPipelineContext.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/context/AbstractPipelineContext.java index 559a8a383..c6304cffd 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/context/AbstractPipelineContext.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/context/AbstractPipelineContext.java @@ -19,7 +19,6 @@ package org.apache.geaflow.context; -import com.google.common.base.Preconditions; import java.io.Serializable; import java.util.ArrayList; import java.util.HashMap; @@ -27,6 +26,7 @@ import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; + import org.apache.geaflow.api.pdata.base.PAction; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.pipeline.context.IPipelineContext; @@ -34,52 +34,53 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class AbstractPipelineContext implements IPipelineContext, Serializable { +import com.google.common.base.Preconditions; - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractPipelineContext.class); +public abstract class AbstractPipelineContext implements IPipelineContext, Serializable { - protected final AtomicInteger idGenerator = new AtomicInteger(0); - protected Configuration pipelineConfig; - protected transient List actions; - protected transient Map viewDescMap; + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractPipelineContext.class); - public AbstractPipelineContext(Configuration pipelineConfig) { - this.pipelineConfig = pipelineConfig; - this.actions = new ArrayList<>(); - this.viewDescMap = new HashMap<>(); - } + protected final AtomicInteger idGenerator = new AtomicInteger(0); + protected Configuration pipelineConfig; + protected transient List actions; + protected transient Map viewDescMap; - public int generateId() { - return idGenerator.incrementAndGet(); - } + public AbstractPipelineContext(Configuration pipelineConfig) { + this.pipelineConfig = pipelineConfig; + this.actions = new ArrayList<>(); + this.viewDescMap = new HashMap<>(); + } - @Override - public void addPAction(PAction action) { - LOGGER.info("Add Action, Id:{}", action.getId()); - this.actions.add(action); - } + public int generateId() { + return idGenerator.incrementAndGet(); + } - public void addView(IViewDesc viewDesc) { - LOGGER.info("User ViewName:{} ViewDesc:{}", viewDesc.getName(), viewDesc); - this.viewDescMap.put(viewDesc.getName(), viewDesc); - } + @Override + public void addPAction(PAction action) { + LOGGER.info("Add Action, Id:{}", action.getId()); + this.actions.add(action); + } - public IViewDesc getViewDesc(String name) { - IViewDesc viewDesc = this.viewDescMap.get(name); - Preconditions.checkArgument(viewDesc != null); - return viewDesc; - } + public void addView(IViewDesc viewDesc) { + LOGGER.info("User ViewName:{} ViewDesc:{}", viewDesc.getName(), viewDesc); + this.viewDescMap.put(viewDesc.getName(), viewDesc); + } - public Configuration getConfig() { - return pipelineConfig; - } + public IViewDesc getViewDesc(String name) { + IViewDesc viewDesc = this.viewDescMap.get(name); + Preconditions.checkArgument(viewDesc != null); + return viewDesc; + } - public List getActions() { - return actions.stream().collect(Collectors.toList()); - } + public Configuration getConfig() { + return pipelineConfig; + } - public Map getViewDescMap() { - return viewDescMap; - } + public List getActions() { + return actions.stream().collect(Collectors.toList()); + } + public Map getViewDescMap() { + return viewDescMap; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/view/AbstractGraphView.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/view/AbstractGraphView.java index 1764dfd90..2ef6f011d 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/view/AbstractGraphView.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/view/AbstractGraphView.java @@ -19,7 +19,6 @@ package org.apache.geaflow.pdata.graph.view; -import com.google.common.base.Preconditions; import org.apache.geaflow.api.graph.base.algo.GraphExecAlgo; import org.apache.geaflow.api.graph.base.algo.VertexCentricAlgo; import org.apache.geaflow.api.partition.graph.edge.CustomEdgeVCPartition; @@ -40,64 +39,70 @@ import org.apache.geaflow.view.IViewDesc; import org.apache.geaflow.view.graph.GraphViewDesc; +import com.google.common.base.Preconditions; + public abstract class AbstractGraphView extends WindowDataStream { - protected long maxIterations; - protected GraphViewDesc graphViewDesc; - protected PWindowStream> vertexStream; - protected PWindowStream> edgeStream; - protected GraphExecAlgo graphExecAlgo; - protected IEncoder> msgEncoder; + protected long maxIterations; + protected GraphViewDesc graphViewDesc; + protected PWindowStream> vertexStream; + protected PWindowStream> edgeStream; + protected GraphExecAlgo graphExecAlgo; + protected IEncoder> msgEncoder; - public AbstractGraphView(IPipelineContext pipelineContext, - IViewDesc graphViewDesc, - PWindowStream> vertexWindowStream, - PWindowStream> edgeWindowStream) { - super(pipelineContext); - this.graphViewDesc = (GraphViewDesc) graphViewDesc; - this.vertexStream = vertexWindowStream; - this.edgeStream = edgeWindowStream; - super.parallelism = Math.max(vertexStream.getParallelism(), edgeStream.getParallelism()); - } + public AbstractGraphView( + IPipelineContext pipelineContext, + IViewDesc graphViewDesc, + PWindowStream> vertexWindowStream, + PWindowStream> edgeWindowStream) { + super(pipelineContext); + this.graphViewDesc = (GraphViewDesc) graphViewDesc; + this.vertexStream = vertexWindowStream; + this.edgeStream = edgeWindowStream; + super.parallelism = Math.max(vertexStream.getParallelism(), edgeStream.getParallelism()); + } - protected void processOnVertexCentric(VertexCentricAlgo vertexCentricAlgo) { - this.graphExecAlgo = GraphExecAlgo.VertexCentric; - this.maxIterations = vertexCentricAlgo.getMaxIterationCount(); - IGraphVCPartition graphPartition = vertexCentricAlgo.getGraphPartition(); - if (graphPartition == null) { - this.input = (Stream) this.vertexStream.keyBy(new DefaultVertexPartition<>()); - this.edgeStream = this.edgeStream.keyBy(new DefaultEdgePartition<>()); - if (parallelism > graphViewDesc.getShardNum()) { - this.input.withParallelism(this.graphViewDesc.getShardNum()); - this.edgeStream.withParallelism(this.graphViewDesc.getShardNum()); - } - } else { - Preconditions.checkArgument(parallelism <= graphViewDesc.getShardNum(), - "op parallelism must be <= shard num"); - this.input = (Stream) this.vertexStream.keyBy(new CustomVertexVCPartition<>(graphPartition)); - this.edgeStream = this.edgeStream.keyBy(new CustomEdgeVCPartition<>(graphPartition)); - } - IEncoder keyEncoder = vertexCentricAlgo.getKeyEncoder(); - if (keyEncoder == null) { - keyEncoder = (IEncoder) EncoderResolver.resolveFunction(VertexCentricAlgo.class, vertexCentricAlgo, 0); - } - IEncoder msgEncoder = vertexCentricAlgo.getMessageEncoder(); - if (msgEncoder == null) { - msgEncoder = (IEncoder) EncoderResolver.resolveFunction(VertexCentricAlgo.class, vertexCentricAlgo, 3); - } - this.msgEncoder = GraphMessageEncoders.build(keyEncoder, msgEncoder); + protected void processOnVertexCentric(VertexCentricAlgo vertexCentricAlgo) { + this.graphExecAlgo = GraphExecAlgo.VertexCentric; + this.maxIterations = vertexCentricAlgo.getMaxIterationCount(); + IGraphVCPartition graphPartition = vertexCentricAlgo.getGraphPartition(); + if (graphPartition == null) { + this.input = (Stream) this.vertexStream.keyBy(new DefaultVertexPartition<>()); + this.edgeStream = this.edgeStream.keyBy(new DefaultEdgePartition<>()); + if (parallelism > graphViewDesc.getShardNum()) { + this.input.withParallelism(this.graphViewDesc.getShardNum()); + this.edgeStream.withParallelism(this.graphViewDesc.getShardNum()); + } + } else { + Preconditions.checkArgument( + parallelism <= graphViewDesc.getShardNum(), "op parallelism must be <= shard num"); + this.input = (Stream) this.vertexStream.keyBy(new CustomVertexVCPartition<>(graphPartition)); + this.edgeStream = this.edgeStream.keyBy(new CustomEdgeVCPartition<>(graphPartition)); } - - public long getMaxIterations() { - return maxIterations; + IEncoder keyEncoder = vertexCentricAlgo.getKeyEncoder(); + if (keyEncoder == null) { + keyEncoder = + (IEncoder) + EncoderResolver.resolveFunction(VertexCentricAlgo.class, vertexCentricAlgo, 0); } - - public PWindowStream> getEdges() { - return this.edgeStream; + IEncoder msgEncoder = vertexCentricAlgo.getMessageEncoder(); + if (msgEncoder == null) { + msgEncoder = + (IEncoder) + EncoderResolver.resolveFunction(VertexCentricAlgo.class, vertexCentricAlgo, 3); } + this.msgEncoder = GraphMessageEncoders.build(keyEncoder, msgEncoder); + } - public IEncoder> getMsgEncoder() { - return this.msgEncoder; - } + public long getMaxIterations() { + return maxIterations; + } + + public PWindowStream> getEdges() { + return this.edgeStream; + } + public IEncoder> getMsgEncoder() { + return this.msgEncoder; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/view/IncGraphView.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/view/IncGraphView.java index a0c32239f..3e007de14 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/view/IncGraphView.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/view/IncGraphView.java @@ -19,7 +19,6 @@ package org.apache.geaflow.pdata.graph.view; -import com.google.common.annotations.VisibleForTesting; import org.apache.geaflow.api.graph.PGraphWindow; import org.apache.geaflow.api.graph.compute.IncVertexCentricAggCompute; import org.apache.geaflow.api.graph.compute.IncVertexCentricCompute; @@ -42,107 +41,106 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.annotations.VisibleForTesting; + public class IncGraphView implements PIncGraphView { - private static final Logger LOGGER = LoggerFactory.getLogger(IncGraphView.class); - - private IPipelineContext pipelineContext; - private PWindowStream> vertexWindowSteam; - private PWindowStream> edgeWindowStream; - private IViewDesc graphViewDesc; - - @VisibleForTesting - private MaterializedIncGraph materializedIncGraph; - - public IncGraphView(IPipelineContext pipelineContext, IViewDesc viewDesc) { - this.pipelineContext = pipelineContext; - this.graphViewDesc = viewDesc; - } - - @Override - public PGraphView init(GraphViewDesc graphViewDesc) { - this.graphViewDesc = graphViewDesc; - return this; - } - - @SuppressWarnings("unchecked") - @Override - public PGraphWindow snapshot(long version) { - return new WindowStreamGraph<>(((GraphViewDesc) graphViewDesc).snapshot(version), pipelineContext); - } - - @Override - public PIncGraphView appendGraph(PWindowStream> vertexStream, - PWindowStream> edgeStream) { - this.vertexWindowSteam = vertexStream; - this.edgeWindowStream = edgeStream; - return this; - } - - @Override - public PIncGraphView appendEdge(PWindowStream> edgeStream) { - this.edgeWindowStream = edgeStream; - return this; - } - - @Override - public PIncGraphView appendVertex(PWindowStream> vertexStream) { - this.vertexWindowSteam = vertexStream; - return this; - } - - @Override - public PGraphCompute incrementalCompute( - IncVertexCentricCompute incVertexCentricCompute) { - ComputeIncGraph computeIncGraph = new ComputeIncGraph<>(pipelineContext, - graphViewDesc, vertexWindowSteam, edgeWindowStream); - computeIncGraph.computeOnIncVertexCentric(incVertexCentricCompute); - return computeIncGraph; - } - - @Override - public PGraphCompute incrementalCompute( - IncVertexCentricAggCompute incVertexCentricCompute) { - ComputeIncGraph computeIncGraph = - new ComputeIncGraph<>(pipelineContext, graphViewDesc, - vertexWindowSteam, - edgeWindowStream); - computeIncGraph.computeOnIncVertexCentric(incVertexCentricCompute); - return null; - } - - @Override - public PGraphTraversal incrementalTraversal( - IncVertexCentricTraversal incVertexCentricTraversal) { - - TraversalIncGraph traversalIncGraph = - new TraversalIncGraph<>(pipelineContext, graphViewDesc, - this.vertexWindowSteam, - this.edgeWindowStream); - traversalIncGraph.traversalOnVertexCentric(incVertexCentricTraversal); - return traversalIncGraph; - } - - @Override - public PGraphTraversal incrementalTraversal( - IncVertexCentricAggTraversal incVertexCentricTraversal) { - TraversalIncGraph traversalIncGraph = - new TraversalIncGraph<>(pipelineContext, graphViewDesc, - this.vertexWindowSteam, - this.edgeWindowStream); - traversalIncGraph.traversalOnVertexCentric(incVertexCentricTraversal); - return traversalIncGraph; - } - - @Override - public void materialize() { - materializedIncGraph = - new MaterializedIncGraph(pipelineContext, graphViewDesc, vertexWindowSteam, edgeWindowStream); - materializedIncGraph.materialize(); - } - - @VisibleForTesting - public MaterializedIncGraph getMaterializedIncGraph() { - return materializedIncGraph; - } + private static final Logger LOGGER = LoggerFactory.getLogger(IncGraphView.class); + + private IPipelineContext pipelineContext; + private PWindowStream> vertexWindowSteam; + private PWindowStream> edgeWindowStream; + private IViewDesc graphViewDesc; + + @VisibleForTesting private MaterializedIncGraph materializedIncGraph; + + public IncGraphView(IPipelineContext pipelineContext, IViewDesc viewDesc) { + this.pipelineContext = pipelineContext; + this.graphViewDesc = viewDesc; + } + + @Override + public PGraphView init(GraphViewDesc graphViewDesc) { + this.graphViewDesc = graphViewDesc; + return this; + } + + @SuppressWarnings("unchecked") + @Override + public PGraphWindow snapshot(long version) { + return new WindowStreamGraph<>( + ((GraphViewDesc) graphViewDesc).snapshot(version), pipelineContext); + } + + @Override + public PIncGraphView appendGraph( + PWindowStream> vertexStream, PWindowStream> edgeStream) { + this.vertexWindowSteam = vertexStream; + this.edgeWindowStream = edgeStream; + return this; + } + + @Override + public PIncGraphView appendEdge(PWindowStream> edgeStream) { + this.edgeWindowStream = edgeStream; + return this; + } + + @Override + public PIncGraphView appendVertex(PWindowStream> vertexStream) { + this.vertexWindowSteam = vertexStream; + return this; + } + + @Override + public PGraphCompute incrementalCompute( + IncVertexCentricCompute incVertexCentricCompute) { + ComputeIncGraph computeIncGraph = + new ComputeIncGraph<>(pipelineContext, graphViewDesc, vertexWindowSteam, edgeWindowStream); + computeIncGraph.computeOnIncVertexCentric(incVertexCentricCompute); + return computeIncGraph; + } + + @Override + public PGraphCompute incrementalCompute( + IncVertexCentricAggCompute incVertexCentricCompute) { + ComputeIncGraph computeIncGraph = + new ComputeIncGraph<>(pipelineContext, graphViewDesc, vertexWindowSteam, edgeWindowStream); + computeIncGraph.computeOnIncVertexCentric(incVertexCentricCompute); + return null; + } + + @Override + public PGraphTraversal incrementalTraversal( + IncVertexCentricTraversal incVertexCentricTraversal) { + + TraversalIncGraph traversalIncGraph = + new TraversalIncGraph<>( + pipelineContext, graphViewDesc, this.vertexWindowSteam, this.edgeWindowStream); + traversalIncGraph.traversalOnVertexCentric(incVertexCentricTraversal); + return traversalIncGraph; + } + + @Override + public PGraphTraversal incrementalTraversal( + IncVertexCentricAggTraversal incVertexCentricTraversal) { + TraversalIncGraph traversalIncGraph = + new TraversalIncGraph<>( + pipelineContext, graphViewDesc, this.vertexWindowSteam, this.edgeWindowStream); + traversalIncGraph.traversalOnVertexCentric(incVertexCentricTraversal); + return traversalIncGraph; + } + + @Override + public void materialize() { + materializedIncGraph = + new MaterializedIncGraph( + pipelineContext, graphViewDesc, vertexWindowSteam, edgeWindowStream); + materializedIncGraph.materialize(); + } + + @VisibleForTesting + public MaterializedIncGraph getMaterializedIncGraph() { + return materializedIncGraph; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/view/compute/ComputeIncGraph.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/view/compute/ComputeIncGraph.java index a51d28214..c2a755616 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/view/compute/ComputeIncGraph.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/view/compute/ComputeIncGraph.java @@ -19,7 +19,6 @@ package org.apache.geaflow.pdata.graph.view.compute; -import com.google.common.base.Preconditions; import org.apache.geaflow.api.graph.base.algo.GraphExecAlgo; import org.apache.geaflow.api.graph.compute.IncVertexCentricAggCompute; import org.apache.geaflow.api.graph.compute.IncVertexCentricCompute; @@ -37,70 +36,70 @@ import org.apache.geaflow.pipeline.context.IPipelineContext; import org.apache.geaflow.view.IViewDesc; -public class ComputeIncGraph extends AbstractGraphView> implements PGraphCompute { - - public ComputeIncGraph(IPipelineContext pipelineContext, - IViewDesc graphViewDesc, - PWindowStream> vertexWindowStream, - PWindowStream> edgeWindowStream) { - super(pipelineContext, graphViewDesc, vertexWindowStream, edgeWindowStream); - } - - public PWindowStream> computeOnIncVertexCentric( - IncVertexCentricCompute incVertexCentricCompute) { - processOnVertexCentric(incVertexCentricCompute); - IGraphVertexCentricOp graphVertexCentricComputeOp = - new DynamicGraphVertexCentricComputeOp(graphViewDesc, incVertexCentricCompute); - super.operator = (Operator) graphVertexCentricComputeOp; - this.opArgs = ((AbstractOperator) operator).getOpArgs(); - this.opArgs.setOpId(getId()); - this.opArgs.setOpName(incVertexCentricCompute.getName()); - - return this; - } - - public PWindowStream> computeOnIncVertexCentric( - IncVertexCentricAggCompute incVertexCentricCompute) { - processOnVertexCentric(incVertexCentricCompute); - IGraphVertexCentricOp graphVertexCentricComputeOp = - new DynamicGraphVertexCentricComputeWithAggOp(graphViewDesc, incVertexCentricCompute); - super.operator = (Operator) graphVertexCentricComputeOp; - this.opArgs = ((AbstractOperator) operator).getOpArgs(); - this.opArgs.setOpId(getId()); - this.opArgs.setOpName(incVertexCentricCompute.getName()); - - return this; - } - - @Override - public PGraphCompute compute() { - return this; - } - - @Override - public PGraphCompute compute(int parallelism) { - Preconditions.checkArgument(parallelism <= graphViewDesc.getShardNum(), - "op parallelism must be <= shard num"); - super.parallelism = parallelism; - return this; - } - - @Override - public PWindowStream> getVertices() { - return this; - } - - - @Override - public GraphExecAlgo getGraphComputeType() { - return graphExecAlgo; - } - - @Override - public TransformType getTransformType() { - return TransformType.ContinueGraphCompute; - } - +import com.google.common.base.Preconditions; +public class ComputeIncGraph extends AbstractGraphView> + implements PGraphCompute { + + public ComputeIncGraph( + IPipelineContext pipelineContext, + IViewDesc graphViewDesc, + PWindowStream> vertexWindowStream, + PWindowStream> edgeWindowStream) { + super(pipelineContext, graphViewDesc, vertexWindowStream, edgeWindowStream); + } + + public PWindowStream> computeOnIncVertexCentric( + IncVertexCentricCompute incVertexCentricCompute) { + processOnVertexCentric(incVertexCentricCompute); + IGraphVertexCentricOp graphVertexCentricComputeOp = + new DynamicGraphVertexCentricComputeOp(graphViewDesc, incVertexCentricCompute); + super.operator = (Operator) graphVertexCentricComputeOp; + this.opArgs = ((AbstractOperator) operator).getOpArgs(); + this.opArgs.setOpId(getId()); + this.opArgs.setOpName(incVertexCentricCompute.getName()); + + return this; + } + + public PWindowStream> computeOnIncVertexCentric( + IncVertexCentricAggCompute incVertexCentricCompute) { + processOnVertexCentric(incVertexCentricCompute); + IGraphVertexCentricOp graphVertexCentricComputeOp = + new DynamicGraphVertexCentricComputeWithAggOp(graphViewDesc, incVertexCentricCompute); + super.operator = (Operator) graphVertexCentricComputeOp; + this.opArgs = ((AbstractOperator) operator).getOpArgs(); + this.opArgs.setOpId(getId()); + this.opArgs.setOpName(incVertexCentricCompute.getName()); + + return this; + } + + @Override + public PGraphCompute compute() { + return this; + } + + @Override + public PGraphCompute compute(int parallelism) { + Preconditions.checkArgument( + parallelism <= graphViewDesc.getShardNum(), "op parallelism must be <= shard num"); + super.parallelism = parallelism; + return this; + } + + @Override + public PWindowStream> getVertices() { + return this; + } + + @Override + public GraphExecAlgo getGraphComputeType() { + return graphExecAlgo; + } + + @Override + public TransformType getTransformType() { + return TransformType.ContinueGraphCompute; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/view/materialize/MaterializedIncGraph.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/view/materialize/MaterializedIncGraph.java index 45bfd98a4..8e0482809 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/view/materialize/MaterializedIncGraph.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/view/materialize/MaterializedIncGraph.java @@ -20,6 +20,7 @@ package org.apache.geaflow.pdata.graph.view.materialize; import java.util.Map; + import org.apache.geaflow.api.graph.materialize.PGraphMaterialize; import org.apache.geaflow.api.pdata.stream.window.PWindowStream; import org.apache.geaflow.model.common.Null; @@ -43,81 +44,80 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class MaterializedIncGraph extends AbstractGraphView> implements PGraphMaterialize { - - private static final Logger LOGGER = LoggerFactory.getLogger(MaterializedIncGraph.class); +public class MaterializedIncGraph + extends AbstractGraphView> + implements PGraphMaterialize { - public MaterializedIncGraph(IPipelineContext pipelineContext, - IViewDesc graphViewDesc, - PWindowStream> vertexWindowStream, - PWindowStream> edgeWindowStream) { - super(pipelineContext, graphViewDesc, vertexWindowStream, edgeWindowStream); - } + private static final Logger LOGGER = LoggerFactory.getLogger(MaterializedIncGraph.class); - @Override - public void materialize() { - LOGGER.info("call materialize"); - GraphViewMaterializeOp materializeOp = new GraphViewMaterializeOp<>(graphViewDesc); - super.operator = materializeOp; - this.opArgs = ((AbstractOperator) operator).getOpArgs(); - this.opArgs.setOpId(getId()); - this.withParallelism(graphViewDesc.getShardNum()); - this.opArgs.setOpName(MaterializedIncGraph.class.getSimpleName()); - assert this.vertexStream.getParallelism() <= graphViewDesc.getShardNum() : "Materialize " - + "vertexStream parallelism must <= number of graph shard num"; - this.input = (Stream) this.vertexStream - .keyBy(new WindowStreamGraph.DefaultVertexPartition<>()); - assert this.edgeStream.getParallelism() <= graphViewDesc.getShardNum() : "Materialize " - + "edgeStream parallelism must <= number of graph shard num"; - this.edgeStream = this.edgeStream - .keyBy(new WindowStreamGraph.DefaultEdgePartition<>()); + public MaterializedIncGraph( + IPipelineContext pipelineContext, + IViewDesc graphViewDesc, + PWindowStream> vertexWindowStream, + PWindowStream> edgeWindowStream) { + super(pipelineContext, graphViewDesc, vertexWindowStream, edgeWindowStream); + } - if (graphViewDesc.getBackend() == BackendType.Paimon) { - SinkOperator operator = - new SinkOperator<>(new PaimonGlobalSink()); - WindowStreamSink globalSink = new WindowStreamSink<>(this, operator); - globalSink.withParallelism(1); - super.context.addPAction(globalSink); - } else { - super.context.addPAction(this); - } + @Override + public void materialize() { + LOGGER.info("call materialize"); + GraphViewMaterializeOp materializeOp = new GraphViewMaterializeOp<>(graphViewDesc); + super.operator = materializeOp; + this.opArgs = ((AbstractOperator) operator).getOpArgs(); + this.opArgs.setOpId(getId()); + this.withParallelism(graphViewDesc.getShardNum()); + this.opArgs.setOpName(MaterializedIncGraph.class.getSimpleName()); + assert this.vertexStream.getParallelism() <= graphViewDesc.getShardNum() + : "Materialize " + "vertexStream parallelism must <= number of graph shard num"; + this.input = (Stream) this.vertexStream.keyBy(new WindowStreamGraph.DefaultVertexPartition<>()); + assert this.edgeStream.getParallelism() <= graphViewDesc.getShardNum() + : "Materialize " + "edgeStream parallelism must <= number of graph shard num"; + this.edgeStream = this.edgeStream.keyBy(new WindowStreamGraph.DefaultEdgePartition<>()); - assert this.getParallelism() == graphViewDesc.getShardNum() : "Materialize parallelism " - + "must be equal to the graph shard num."; + if (graphViewDesc.getBackend() == BackendType.Paimon) { + SinkOperator operator = new SinkOperator<>(new PaimonGlobalSink()); + WindowStreamSink globalSink = new WindowStreamSink<>(this, operator); + globalSink.withParallelism(1); + super.context.addPAction(globalSink); + } else { + super.context.addPAction(this); } - @Override - public TransformType getTransformType() { - return TransformType.ContinueGraphMaterialize; - } + assert this.getParallelism() == graphViewDesc.getShardNum() + : "Materialize parallelism " + "must be equal to the graph shard num."; + } - @Override - public IPartitioner getPartition() { - return new KeyPartitioner(this.getId()); - } + @Override + public TransformType getTransformType() { + return TransformType.ContinueGraphMaterialize; + } - @Override - public MaterializedIncGraph withConfig(Map config) { - this.setConfig(config); - return this; - } + @Override + public IPartitioner getPartition() { + return new KeyPartitioner(this.getId()); + } - @Override - public MaterializedIncGraph withConfig(String key, String value) { - this.setConfig(key, value); - return this; - } + @Override + public MaterializedIncGraph withConfig(Map config) { + this.setConfig(config); + return this; + } - @Override - public MaterializedIncGraph withName(String name) { - this.setName(name); - return this; - } + @Override + public MaterializedIncGraph withConfig(String key, String value) { + this.setConfig(key, value); + return this; + } - @Override - public MaterializedIncGraph withParallelism(int parallelism) { - this.setParallelism(parallelism); - return this; - } + @Override + public MaterializedIncGraph withName(String name) { + this.setName(name); + return this; + } + + @Override + public MaterializedIncGraph withParallelism(int parallelism) { + this.setParallelism(parallelism); + return this; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/view/traversal/TraversalIncGraph.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/view/traversal/TraversalIncGraph.java index dc85516e2..0d1e97ceb 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/view/traversal/TraversalIncGraph.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/view/traversal/TraversalIncGraph.java @@ -19,9 +19,9 @@ package org.apache.geaflow.pdata.graph.view.traversal; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.api.graph.base.algo.AbstractIncVertexCentricTraversalAlgo; import org.apache.geaflow.api.graph.base.algo.GraphAggregationAlgo; import org.apache.geaflow.api.graph.base.algo.GraphExecAlgo; @@ -46,123 +46,139 @@ import org.apache.geaflow.view.IViewDesc; import org.apache.geaflow.view.graph.GraphViewDesc; -public class TraversalIncGraph extends AbstractGraphView> implements PGraphTraversal { - - protected PWindowStream> requestStream; - protected AbstractIncVertexCentricTraversalAlgo incVertexCentricTraversal; - - public TraversalIncGraph(IPipelineContext pipelineContext, - IViewDesc graphViewDesc, - PWindowStream> vertexWindowStream, - PWindowStream> edgeWindowStream) { - super(pipelineContext, graphViewDesc, vertexWindowStream, edgeWindowStream); - this.vertexStream = vertexWindowStream; - this.edgeStream = edgeWindowStream; - this.graphViewDesc = (GraphViewDesc) graphViewDesc; - super.parallelism = Math.max(vertexStream.getParallelism(), edgeStream.getParallelism()); - } +import com.google.common.collect.Lists; - public TraversalIncGraph traversalOnVertexCentric( - AbstractIncVertexCentricTraversalAlgo incVertexCentricTraversal) { - processOnVertexCentric(incVertexCentricTraversal); - this.incVertexCentricTraversal = incVertexCentricTraversal; - return this; - } +public class TraversalIncGraph + extends AbstractGraphView> + implements PGraphTraversal { - @Override - public PWindowStream> start() { - IGraphVertexCentricOp traversalOp; - if (incVertexCentricTraversal instanceof GraphAggregationAlgo) { - traversalOp = GraphVertexCentricOpFactory.buildDynamicGraphVertexCentricTraversalAllOp(graphViewDesc, - (IncVertexCentricAggTraversal) incVertexCentricTraversal); - } else { - traversalOp = GraphVertexCentricOpFactory.buildDynamicGraphVertexCentricTraversalAllOp(graphViewDesc, - (IncVertexCentricTraversal) incVertexCentricTraversal); - } - super.operator = (Operator) traversalOp; - this.opArgs = ((AbstractOperator) operator).getOpArgs(); - this.opArgs.setOpId(getId()); - this.opArgs.setOpName(incVertexCentricTraversal.getName()); - this.opArgs.setParallelism(this.parallelism); - return this; - } + protected PWindowStream> requestStream; + protected AbstractIncVertexCentricTraversalAlgo incVertexCentricTraversal; - @Override - public PWindowStream> start(K vId) { - return start(Lists.newArrayList(vId)); - } + public TraversalIncGraph( + IPipelineContext pipelineContext, + IViewDesc graphViewDesc, + PWindowStream> vertexWindowStream, + PWindowStream> edgeWindowStream) { + super(pipelineContext, graphViewDesc, vertexWindowStream, edgeWindowStream); + this.vertexStream = vertexWindowStream; + this.edgeStream = edgeWindowStream; + this.graphViewDesc = (GraphViewDesc) graphViewDesc; + super.parallelism = Math.max(vertexStream.getParallelism(), edgeStream.getParallelism()); + } - @Override - public PWindowStream> start(List vIds) { - List> vertexBeginTraversalRequests = new ArrayList<>(); - for (K vId : vIds) { - VertexBeginTraversalRequest vertexBeginTraversalRequest = - new VertexBeginTraversalRequest<>( - vId); - vertexBeginTraversalRequests.add(vertexBeginTraversalRequest); - } - - IGraphVertexCentricOp traversalOp; - if (incVertexCentricTraversal instanceof GraphAggregationAlgo) { - traversalOp = GraphVertexCentricOpFactory.buildDynamicGraphVertexCentricTraversalOp( - graphViewDesc, - (IncVertexCentricAggTraversal) incVertexCentricTraversal, - vertexBeginTraversalRequests); - } else { - traversalOp = GraphVertexCentricOpFactory.buildDynamicGraphVertexCentricTraversalOp( - graphViewDesc, - (IncVertexCentricTraversal) incVertexCentricTraversal, - vertexBeginTraversalRequests); - } - super.operator = (Operator) traversalOp; - this.opArgs = ((AbstractOperator) operator).getOpArgs(); - this.opArgs.setOpId(getId()); - this.opArgs.setOpName(incVertexCentricTraversal.getName()); - this.opArgs.setParallelism(this.parallelism); - return this; - } + public TraversalIncGraph traversalOnVertexCentric( + AbstractIncVertexCentricTraversalAlgo incVertexCentricTraversal) { + processOnVertexCentric(incVertexCentricTraversal); + this.incVertexCentricTraversal = incVertexCentricTraversal; + return this; + } - @Override - public PWindowStream> start( - PWindowStream> requests) { - this.requestStream = requests instanceof PWindowBroadcastStream - ? requests : requests.keyBy(new DefaultTraversalRequestPartition()); - IGraphVertexCentricOp traversalOp; - if (incVertexCentricTraversal instanceof GraphAggregationAlgo) { - traversalOp = GraphVertexCentricOpFactory.buildDynamicGraphVertexCentricTraversalOp( - graphViewDesc, - (IncVertexCentricAggTraversal) incVertexCentricTraversal); - } else { - traversalOp = GraphVertexCentricOpFactory.buildDynamicGraphVertexCentricTraversalOp( - graphViewDesc, - (IncVertexCentricTraversal) incVertexCentricTraversal); - } - super.operator = (Operator) traversalOp; - this.opArgs = ((AbstractOperator) operator).getOpArgs(); - this.opArgs.setOpId(getId()); - this.opArgs.setOpName(incVertexCentricTraversal.getName()); - this.opArgs.setParallelism(this.parallelism); - return this; + @Override + public PWindowStream> start() { + IGraphVertexCentricOp traversalOp; + if (incVertexCentricTraversal instanceof GraphAggregationAlgo) { + traversalOp = + GraphVertexCentricOpFactory.buildDynamicGraphVertexCentricTraversalAllOp( + graphViewDesc, + (IncVertexCentricAggTraversal) + incVertexCentricTraversal); + } else { + traversalOp = + GraphVertexCentricOpFactory.buildDynamicGraphVertexCentricTraversalAllOp( + graphViewDesc, + (IncVertexCentricTraversal) incVertexCentricTraversal); } + super.operator = (Operator) traversalOp; + this.opArgs = ((AbstractOperator) operator).getOpArgs(); + this.opArgs.setOpId(getId()); + this.opArgs.setOpName(incVertexCentricTraversal.getName()); + this.opArgs.setParallelism(this.parallelism); + return this; + } - @Override - public GraphExecAlgo getGraphTraversalType() { - return graphExecAlgo; - } + @Override + public PWindowStream> start(K vId) { + return start(Lists.newArrayList(vId)); + } - @Override - public TraversalIncGraph withParallelism(int parallelism) { - setParallelism(parallelism); - return this; + @Override + public PWindowStream> start(List vIds) { + List> vertexBeginTraversalRequests = new ArrayList<>(); + for (K vId : vIds) { + VertexBeginTraversalRequest vertexBeginTraversalRequest = + new VertexBeginTraversalRequest<>(vId); + vertexBeginTraversalRequests.add(vertexBeginTraversalRequest); } - public PWindowStream> getRequestStream() { - return requestStream; + IGraphVertexCentricOp traversalOp; + if (incVertexCentricTraversal instanceof GraphAggregationAlgo) { + traversalOp = + GraphVertexCentricOpFactory.buildDynamicGraphVertexCentricTraversalOp( + graphViewDesc, + (IncVertexCentricAggTraversal) + incVertexCentricTraversal, + vertexBeginTraversalRequests); + } else { + traversalOp = + GraphVertexCentricOpFactory.buildDynamicGraphVertexCentricTraversalOp( + graphViewDesc, + (IncVertexCentricTraversal) incVertexCentricTraversal, + vertexBeginTraversalRequests); } + super.operator = (Operator) traversalOp; + this.opArgs = ((AbstractOperator) operator).getOpArgs(); + this.opArgs.setOpId(getId()); + this.opArgs.setOpName(incVertexCentricTraversal.getName()); + this.opArgs.setParallelism(this.parallelism); + return this; + } - @Override - public TransformType getTransformType() { - return TransformType.ContinueGraphTraversal; + @Override + public PWindowStream> start( + PWindowStream> requests) { + this.requestStream = + requests instanceof PWindowBroadcastStream + ? requests + : requests.keyBy(new DefaultTraversalRequestPartition()); + IGraphVertexCentricOp traversalOp; + if (incVertexCentricTraversal instanceof GraphAggregationAlgo) { + traversalOp = + GraphVertexCentricOpFactory.buildDynamicGraphVertexCentricTraversalOp( + graphViewDesc, + (IncVertexCentricAggTraversal) + incVertexCentricTraversal); + } else { + traversalOp = + GraphVertexCentricOpFactory.buildDynamicGraphVertexCentricTraversalOp( + graphViewDesc, + (IncVertexCentricTraversal) incVertexCentricTraversal); } + super.operator = (Operator) traversalOp; + this.opArgs = ((AbstractOperator) operator).getOpArgs(); + this.opArgs.setOpId(getId()); + this.opArgs.setOpName(incVertexCentricTraversal.getName()); + this.opArgs.setParallelism(this.parallelism); + return this; + } + + @Override + public GraphExecAlgo getGraphTraversalType() { + return graphExecAlgo; + } + + @Override + public TraversalIncGraph withParallelism(int parallelism) { + setParallelism(parallelism); + return this; + } + + public PWindowStream> getRequestStream() { + return requestStream; + } + + @Override + public TransformType getTransformType() { + return TransformType.ContinueGraphTraversal; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/window/AbstractGraphWindow.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/window/AbstractGraphWindow.java index 6115568b7..d6cd32f77 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/window/AbstractGraphWindow.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/window/AbstractGraphWindow.java @@ -39,53 +39,57 @@ public abstract class AbstractGraphWindow extends WindowDataStream { - protected long maxIterations; - protected PWindowStream> vertexStream; - protected PWindowStream> edgeStream; - protected GraphExecAlgo graphExecAlgo; - protected IEncoder> msgEncoder; + protected long maxIterations; + protected PWindowStream> vertexStream; + protected PWindowStream> edgeStream; + protected GraphExecAlgo graphExecAlgo; + protected IEncoder> msgEncoder; - public AbstractGraphWindow(IPipelineContext pipelineContext, - PWindowStream> vertexWindowStream, - PWindowStream> edgeWindowStream) { - super(pipelineContext); - this.vertexStream = vertexWindowStream; - this.edgeStream = edgeWindowStream; - super.parallelism = Math.max(vertexStream.getParallelism(), edgeStream.getParallelism()); - } + public AbstractGraphWindow( + IPipelineContext pipelineContext, + PWindowStream> vertexWindowStream, + PWindowStream> edgeWindowStream) { + super(pipelineContext); + this.vertexStream = vertexWindowStream; + this.edgeStream = edgeWindowStream; + super.parallelism = Math.max(vertexStream.getParallelism(), edgeStream.getParallelism()); + } - protected void processOnVertexCentric(VertexCentricAlgo vertexCentricAlgo) { - this.graphExecAlgo = GraphExecAlgo.VertexCentric; - this.maxIterations = vertexCentricAlgo.getMaxIterationCount(); - IGraphVCPartition graphPartition = vertexCentricAlgo.getGraphPartition(); - if (graphPartition == null) { - this.input = (Stream) this.vertexStream.keyBy(new DefaultVertexPartition<>()); - this.edgeStream = this.edgeStream.keyBy(new DefaultEdgePartition<>()); - } else { - this.input = (Stream) this.vertexStream.keyBy(new CustomVertexVCPartition<>(graphPartition)); - this.edgeStream = this.edgeStream.keyBy(new CustomEdgeVCPartition<>(graphPartition)); - } - IEncoder keyEncoder = vertexCentricAlgo.getKeyEncoder(); - if (keyEncoder == null) { - keyEncoder = (IEncoder) EncoderResolver.resolveFunction(VertexCentricAlgo.class, vertexCentricAlgo, 0); - } - IEncoder msgEncoder = vertexCentricAlgo.getMessageEncoder(); - if (msgEncoder == null) { - msgEncoder = (IEncoder) EncoderResolver.resolveFunction(VertexCentricAlgo.class, vertexCentricAlgo, 3); - } - this.msgEncoder = GraphMessageEncoders.build(keyEncoder, msgEncoder); + protected void processOnVertexCentric(VertexCentricAlgo vertexCentricAlgo) { + this.graphExecAlgo = GraphExecAlgo.VertexCentric; + this.maxIterations = vertexCentricAlgo.getMaxIterationCount(); + IGraphVCPartition graphPartition = vertexCentricAlgo.getGraphPartition(); + if (graphPartition == null) { + this.input = (Stream) this.vertexStream.keyBy(new DefaultVertexPartition<>()); + this.edgeStream = this.edgeStream.keyBy(new DefaultEdgePartition<>()); + } else { + this.input = (Stream) this.vertexStream.keyBy(new CustomVertexVCPartition<>(graphPartition)); + this.edgeStream = this.edgeStream.keyBy(new CustomEdgeVCPartition<>(graphPartition)); } - - public long getMaxIterations() { - return maxIterations; + IEncoder keyEncoder = vertexCentricAlgo.getKeyEncoder(); + if (keyEncoder == null) { + keyEncoder = + (IEncoder) + EncoderResolver.resolveFunction(VertexCentricAlgo.class, vertexCentricAlgo, 0); } - - public PWindowStream> getEdges() { - return this.edgeStream; + IEncoder msgEncoder = vertexCentricAlgo.getMessageEncoder(); + if (msgEncoder == null) { + msgEncoder = + (IEncoder) + EncoderResolver.resolveFunction(VertexCentricAlgo.class, vertexCentricAlgo, 3); } + this.msgEncoder = GraphMessageEncoders.build(keyEncoder, msgEncoder); + } - public IEncoder> getMsgEncoder() { - return this.msgEncoder; - } + public long getMaxIterations() { + return maxIterations; + } + + public PWindowStream> getEdges() { + return this.edgeStream; + } + public IEncoder> getMsgEncoder() { + return this.msgEncoder; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/window/WindowStreamGraph.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/window/WindowStreamGraph.java index 0ad83ba04..ed7993b46 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/window/WindowStreamGraph.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/window/WindowStreamGraph.java @@ -19,8 +19,8 @@ package org.apache.geaflow.pdata.graph.window; -import com.google.common.base.Preconditions; import java.io.Serializable; + import org.apache.geaflow.api.function.base.KeySelector; import org.apache.geaflow.api.function.internal.CollectionSource; import org.apache.geaflow.api.graph.PGraphWindow; @@ -42,103 +42,100 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class WindowStreamGraph implements PGraphWindow, Serializable { - - private static final Logger LOGGER = LoggerFactory.getLogger(WindowStreamGraph.class); - - private final GraphViewDesc graphViewDesc; - private final IPipelineContext pipelineContext; - private final PWindowStream> vertexWindowSteam; - private final PWindowStream> edgeWindowStream; - - /** - * Create a static window graph. - */ - public WindowStreamGraph(GraphViewDesc graphViewDesc, IPipelineContext pipelineContext, - PWindowStream> vertexWindowSteam, - PWindowStream> edgeWindowStream) { - this.graphViewDesc = graphViewDesc.asStatic(); - this.pipelineContext = pipelineContext; - this.vertexWindowSteam = vertexWindowSteam; - this.edgeWindowStream = edgeWindowStream; - } - - /** - * Create a snapshot window graph. - */ - public WindowStreamGraph(GraphViewDesc graphViewDesc, IPipelineContext pipelineContext) { - this.graphViewDesc = graphViewDesc; - this.pipelineContext = pipelineContext; - this.vertexWindowSteam = new WindowStreamSource(pipelineContext, new CollectionSource(), - AllWindow.getInstance()); - this.edgeWindowStream = new WindowStreamSource(pipelineContext, new CollectionSource(), - AllWindow.getInstance()); - } - - - @Override - public PGraphCompute compute(VertexCentricCompute vertexCentricCompute) { - Preconditions.checkArgument(vertexCentricCompute.getMaxIterationCount() > 0); - ComputeWindowGraph graphCompute = new ComputeWindowGraph<>(pipelineContext, - vertexWindowSteam, edgeWindowStream); - graphCompute.computeOnVertexCentric(graphViewDesc, vertexCentricCompute); - return graphCompute; - } - - @Override - public PGraphCompute compute( - VertexCentricAggCompute vertexCentricAggCompute) { - Preconditions.checkArgument(vertexCentricAggCompute.getMaxIterationCount() > 0); - ComputeWindowGraph graphCompute = new ComputeWindowGraph<>(pipelineContext, - vertexWindowSteam, edgeWindowStream); - graphCompute.computeOnVertexCentric(graphViewDesc, vertexCentricAggCompute); - return graphCompute; - } - - - public static class DefaultVertexPartition implements KeySelector, K> { - @Override - public K getKey(IVertex value) { - return value.getId(); - } - } - - public static class DefaultEdgePartition implements KeySelector, K> { - @Override - public K getKey(IEdge value) { - return value.getSrcId(); - } - } - - @Override - public PGraphTraversal traversal( - VertexCentricTraversal vertexCentricTraversal) { - TraversalWindowGraph traversalWindowGraph = - new TraversalWindowGraph<>(graphViewDesc, pipelineContext, - vertexWindowSteam, - edgeWindowStream); - traversalWindowGraph.traversalOnVertexCentric(vertexCentricTraversal); - return traversalWindowGraph; - } +import com.google.common.base.Preconditions; - @Override - public PGraphTraversal traversal( - VertexCentricAggTraversal vertexCentricAggTraversal) { - TraversalWindowGraph traversalWindowGraph = - new TraversalWindowGraph<>(graphViewDesc, pipelineContext, - vertexWindowSteam, - edgeWindowStream); - traversalWindowGraph.traversalOnVertexCentric(vertexCentricAggTraversal); - return traversalWindowGraph; - } +public class WindowStreamGraph implements PGraphWindow, Serializable { + private static final Logger LOGGER = LoggerFactory.getLogger(WindowStreamGraph.class); + + private final GraphViewDesc graphViewDesc; + private final IPipelineContext pipelineContext; + private final PWindowStream> vertexWindowSteam; + private final PWindowStream> edgeWindowStream; + + /** Create a static window graph. */ + public WindowStreamGraph( + GraphViewDesc graphViewDesc, + IPipelineContext pipelineContext, + PWindowStream> vertexWindowSteam, + PWindowStream> edgeWindowStream) { + this.graphViewDesc = graphViewDesc.asStatic(); + this.pipelineContext = pipelineContext; + this.vertexWindowSteam = vertexWindowSteam; + this.edgeWindowStream = edgeWindowStream; + } + + /** Create a snapshot window graph. */ + public WindowStreamGraph(GraphViewDesc graphViewDesc, IPipelineContext pipelineContext) { + this.graphViewDesc = graphViewDesc; + this.pipelineContext = pipelineContext; + this.vertexWindowSteam = + new WindowStreamSource(pipelineContext, new CollectionSource(), AllWindow.getInstance()); + this.edgeWindowStream = + new WindowStreamSource(pipelineContext, new CollectionSource(), AllWindow.getInstance()); + } + + @Override + public PGraphCompute compute( + VertexCentricCompute vertexCentricCompute) { + Preconditions.checkArgument(vertexCentricCompute.getMaxIterationCount() > 0); + ComputeWindowGraph graphCompute = + new ComputeWindowGraph<>(pipelineContext, vertexWindowSteam, edgeWindowStream); + graphCompute.computeOnVertexCentric(graphViewDesc, vertexCentricCompute); + return graphCompute; + } + + @Override + public PGraphCompute compute( + VertexCentricAggCompute vertexCentricAggCompute) { + Preconditions.checkArgument(vertexCentricAggCompute.getMaxIterationCount() > 0); + ComputeWindowGraph graphCompute = + new ComputeWindowGraph<>(pipelineContext, vertexWindowSteam, edgeWindowStream); + graphCompute.computeOnVertexCentric(graphViewDesc, vertexCentricAggCompute); + return graphCompute; + } + + public static class DefaultVertexPartition implements KeySelector, K> { @Override - public PWindowStream> getEdges() { - return this.edgeWindowStream; + public K getKey(IVertex value) { + return value.getId(); } + } + public static class DefaultEdgePartition implements KeySelector, K> { @Override - public PWindowStream> getVertices() { - return this.vertexWindowSteam; + public K getKey(IEdge value) { + return value.getSrcId(); } + } + + @Override + public PGraphTraversal traversal( + VertexCentricTraversal vertexCentricTraversal) { + TraversalWindowGraph traversalWindowGraph = + new TraversalWindowGraph<>( + graphViewDesc, pipelineContext, vertexWindowSteam, edgeWindowStream); + traversalWindowGraph.traversalOnVertexCentric(vertexCentricTraversal); + return traversalWindowGraph; + } + + @Override + public PGraphTraversal traversal( + VertexCentricAggTraversal vertexCentricAggTraversal) { + TraversalWindowGraph traversalWindowGraph = + new TraversalWindowGraph<>( + graphViewDesc, pipelineContext, vertexWindowSteam, edgeWindowStream); + traversalWindowGraph.traversalOnVertexCentric(vertexCentricAggTraversal); + return traversalWindowGraph; + } + + @Override + public PWindowStream> getEdges() { + return this.edgeWindowStream; + } + + @Override + public PWindowStream> getVertices() { + return this.vertexWindowSteam; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/window/compute/ComputeWindowGraph.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/window/compute/ComputeWindowGraph.java index 322f5a9d3..d85e14f12 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/window/compute/ComputeWindowGraph.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/window/compute/ComputeWindowGraph.java @@ -36,70 +36,69 @@ import org.apache.geaflow.pipeline.context.IPipelineContext; import org.apache.geaflow.view.graph.GraphViewDesc; -public class ComputeWindowGraph extends AbstractGraphWindow> implements PGraphCompute { - - public ComputeWindowGraph(IPipelineContext pipelineContext, - PWindowStream> vertexStream, - PWindowStream> edgeStream) { - super(pipelineContext, vertexStream, edgeStream); - } - - public PWindowStream> computeOnVertexCentric(GraphViewDesc graphViewDesc, - VertexCentricCompute vertexCentricCompute) { - processOnVertexCentric(vertexCentricCompute); - - IGraphVertexCentricOp graphVertexCentricComputeOp = - GraphVertexCentricOpFactory.buildStaticGraphVertexCentricComputeOp(graphViewDesc, - vertexCentricCompute); - - super.operator = (Operator) graphVertexCentricComputeOp; - this.opArgs = ((AbstractOperator) operator).getOpArgs(); - this.opArgs.setOpId(getId()); - this.opArgs.setOpName(vertexCentricCompute.getName()); - return this; - } - - public PWindowStream> computeOnVertexCentric( - GraphViewDesc graphViewDesc, - VertexCentricAggCompute vertexCentricAggCompute) { - processOnVertexCentric(vertexCentricAggCompute); - - IGraphVertexCentricAggOp graphVertexCentricComputeOp = - GraphVertexCentricOpFactory.buildStaticGraphVertexCentricAggComputeOp(graphViewDesc, - vertexCentricAggCompute); - super.operator = (Operator) graphVertexCentricComputeOp; - this.opArgs = ((AbstractOperator) operator).getOpArgs(); - this.opArgs.setOpId(getId()); - this.opArgs.setOpName(vertexCentricAggCompute.getName()); - return this; - } - - - @Override - public TransformType getTransformType() { - return TransformType.WindowGraphCompute; - } - - @Override - public PWindowStream> getVertices() { - return this; - } - - @Override - public PGraphCompute compute() { - return this; - } - - @Override - public PGraphCompute compute(int parallelism) { - super.parallelism = parallelism; - return this; - } - - @Override - public GraphExecAlgo getGraphComputeType() { - return GraphExecAlgo.VertexCentric; - } - +public class ComputeWindowGraph + extends AbstractGraphWindow> implements PGraphCompute { + + public ComputeWindowGraph( + IPipelineContext pipelineContext, + PWindowStream> vertexStream, + PWindowStream> edgeStream) { + super(pipelineContext, vertexStream, edgeStream); + } + + public PWindowStream> computeOnVertexCentric( + GraphViewDesc graphViewDesc, VertexCentricCompute vertexCentricCompute) { + processOnVertexCentric(vertexCentricCompute); + + IGraphVertexCentricOp graphVertexCentricComputeOp = + GraphVertexCentricOpFactory.buildStaticGraphVertexCentricComputeOp( + graphViewDesc, vertexCentricCompute); + + super.operator = (Operator) graphVertexCentricComputeOp; + this.opArgs = ((AbstractOperator) operator).getOpArgs(); + this.opArgs.setOpId(getId()); + this.opArgs.setOpName(vertexCentricCompute.getName()); + return this; + } + + public PWindowStream> computeOnVertexCentric( + GraphViewDesc graphViewDesc, + VertexCentricAggCompute vertexCentricAggCompute) { + processOnVertexCentric(vertexCentricAggCompute); + + IGraphVertexCentricAggOp graphVertexCentricComputeOp = + GraphVertexCentricOpFactory.buildStaticGraphVertexCentricAggComputeOp( + graphViewDesc, vertexCentricAggCompute); + super.operator = (Operator) graphVertexCentricComputeOp; + this.opArgs = ((AbstractOperator) operator).getOpArgs(); + this.opArgs.setOpId(getId()); + this.opArgs.setOpName(vertexCentricAggCompute.getName()); + return this; + } + + @Override + public TransformType getTransformType() { + return TransformType.WindowGraphCompute; + } + + @Override + public PWindowStream> getVertices() { + return this; + } + + @Override + public PGraphCompute compute() { + return this; + } + + @Override + public PGraphCompute compute(int parallelism) { + super.parallelism = parallelism; + return this; + } + + @Override + public GraphExecAlgo getGraphComputeType() { + return GraphExecAlgo.VertexCentric; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/window/traversal/TraversalWindowGraph.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/window/traversal/TraversalWindowGraph.java index 5b10bbc35..ca2ca1c4c 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/window/traversal/TraversalWindowGraph.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/graph/window/traversal/TraversalWindowGraph.java @@ -19,9 +19,9 @@ package org.apache.geaflow.pdata.graph.window.traversal; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.api.graph.base.algo.AbstractVertexCentricTraversalAlgo; import org.apache.geaflow.api.graph.base.algo.GraphAggregationAlgo; import org.apache.geaflow.api.graph.base.algo.GraphExecAlgo; @@ -46,123 +46,140 @@ import org.apache.geaflow.pipeline.context.IPipelineContext; import org.apache.geaflow.view.graph.GraphViewDesc; -public class TraversalWindowGraph extends - AbstractGraphWindow> implements PGraphTraversal { - - protected PWindowStream> requestStream; - protected AbstractVertexCentricTraversalAlgo vertexCentricTraversal; - private final GraphViewDesc graphViewDesc; - - public TraversalWindowGraph(GraphViewDesc graphViewDesc, - IPipelineContext pipelineContext, - PWindowStream> vertexWindowStream, - PWindowStream> edgeWindowStream) { - super(pipelineContext, vertexWindowStream, edgeWindowStream); - super.input = (Stream) vertexWindowStream; - this.edgeStream = edgeWindowStream; - this.graphViewDesc = graphViewDesc; - } +import com.google.common.collect.Lists; - public void traversalOnVertexCentric(VertexCentricTraversal vertexCentricTraversal) { - this.vertexCentricTraversal = vertexCentricTraversal; - processOnVertexCentric(vertexCentricTraversal); - this.graphExecAlgo = GraphExecAlgo.VertexCentric; - this.maxIterations = vertexCentricTraversal.getMaxIterationCount(); - } +public class TraversalWindowGraph + extends AbstractGraphWindow> + implements PGraphTraversal { - public void traversalOnVertexCentric( - VertexCentricAggTraversal vertexCentricTraversal) { - this.vertexCentricTraversal = vertexCentricTraversal; - processOnVertexCentric(vertexCentricTraversal); - this.graphExecAlgo = GraphExecAlgo.VertexCentric; - this.maxIterations = vertexCentricTraversal.getMaxIterationCount(); - } + protected PWindowStream> requestStream; + protected AbstractVertexCentricTraversalAlgo vertexCentricTraversal; + private final GraphViewDesc graphViewDesc; - @Override - public PWindowStream> start() { - IGraphVertexCentricOp traversalOp; - if (vertexCentricTraversal instanceof GraphAggregationAlgo) { - traversalOp = GraphVertexCentricOpFactory.buildStaticGraphVertexCentricAggTraversalAllOp(graphViewDesc, - (VertexCentricAggTraversal) vertexCentricTraversal); - } else { - traversalOp = GraphVertexCentricOpFactory.buildStaticGraphVertexCentricTraversalAllOp(graphViewDesc, - (VertexCentricTraversal) vertexCentricTraversal); - } - super.operator = (Operator) traversalOp; - this.opArgs = ((AbstractOperator) operator).getOpArgs(); - this.opArgs.setOpId(getId()); - this.opArgs.setOpName(vertexCentricTraversal.getName()); - this.opArgs.setParallelism(this.parallelism); - return this; - } + public TraversalWindowGraph( + GraphViewDesc graphViewDesc, + IPipelineContext pipelineContext, + PWindowStream> vertexWindowStream, + PWindowStream> edgeWindowStream) { + super(pipelineContext, vertexWindowStream, edgeWindowStream); + super.input = (Stream) vertexWindowStream; + this.edgeStream = edgeWindowStream; + this.graphViewDesc = graphViewDesc; + } - @Override - public PWindowStream> start(K vId) { - return start(Lists.newArrayList(vId)); - } + public void traversalOnVertexCentric( + VertexCentricTraversal vertexCentricTraversal) { + this.vertexCentricTraversal = vertexCentricTraversal; + processOnVertexCentric(vertexCentricTraversal); + this.graphExecAlgo = GraphExecAlgo.VertexCentric; + this.maxIterations = vertexCentricTraversal.getMaxIterationCount(); + } - @Override - public PWindowStream> start(List vIds) { - List> vertexBeginTraversalRequests = new ArrayList<>(); - for (K vId : vIds) { - VertexBeginTraversalRequest vertexBeginTraversalRequest = new VertexBeginTraversalRequest( - vId); - vertexBeginTraversalRequests.add(vertexBeginTraversalRequest); - } - IGraphVertexCentricOp traversalOp; - if (vertexCentricTraversal instanceof GraphAggregationAlgo) { - traversalOp = GraphVertexCentricOpFactory.buildStaticGraphVertexCentricAggTraversalOp(graphViewDesc, - (VertexCentricAggTraversal) vertexCentricTraversal, vertexBeginTraversalRequests); - } else { - traversalOp = GraphVertexCentricOpFactory.buildStaticGraphVertexCentricTraversalOp(graphViewDesc, - (VertexCentricTraversal) vertexCentricTraversal, vertexBeginTraversalRequests); - } - super.operator = (Operator) traversalOp; - this.opArgs = ((AbstractOperator) operator).getOpArgs(); - this.opArgs.setOpId(getId()); - this.opArgs.setOpName(vertexCentricTraversal.getName()); - this.opArgs.setParallelism(this.parallelism); - return this; - } + public void traversalOnVertexCentric( + VertexCentricAggTraversal vertexCentricTraversal) { + this.vertexCentricTraversal = vertexCentricTraversal; + processOnVertexCentric(vertexCentricTraversal); + this.graphExecAlgo = GraphExecAlgo.VertexCentric; + this.maxIterations = vertexCentricTraversal.getMaxIterationCount(); + } - @Override - public PWindowStream> start( - PWindowStream> requests) { - this.requestStream = requests instanceof PWindowBroadcastStream - ? requests : requests.keyBy(new DefaultTraversalRequestPartition()); - IGraphVertexCentricOp traversalOp; - if (vertexCentricTraversal instanceof GraphAggregationAlgo) { - traversalOp = GraphVertexCentricOpFactory.buildStaticGraphVertexCentricAggTraversalOp(graphViewDesc, - (VertexCentricAggTraversal) vertexCentricTraversal); - } else { - traversalOp = GraphVertexCentricOpFactory.buildStaticGraphVertexCentricTraversalOp(graphViewDesc, - (VertexCentricTraversal) vertexCentricTraversal); - } - super.operator = (Operator) traversalOp; - this.opArgs = ((AbstractOperator) operator).getOpArgs(); - this.opArgs.setOpId(getId()); - this.opArgs.setOpName(vertexCentricTraversal.getName()); - this.opArgs.setParallelism(this.parallelism); - return this; + @Override + public PWindowStream> start() { + IGraphVertexCentricOp traversalOp; + if (vertexCentricTraversal instanceof GraphAggregationAlgo) { + traversalOp = + GraphVertexCentricOpFactory.buildStaticGraphVertexCentricAggTraversalAllOp( + graphViewDesc, (VertexCentricAggTraversal) vertexCentricTraversal); + } else { + traversalOp = + GraphVertexCentricOpFactory.buildStaticGraphVertexCentricTraversalAllOp( + graphViewDesc, (VertexCentricTraversal) vertexCentricTraversal); } + super.operator = (Operator) traversalOp; + this.opArgs = ((AbstractOperator) operator).getOpArgs(); + this.opArgs.setOpId(getId()); + this.opArgs.setOpName(vertexCentricTraversal.getName()); + this.opArgs.setParallelism(this.parallelism); + return this; + } - @Override - public TraversalWindowGraph withParallelism(int parallelism) { - setParallelism(parallelism); - return this; - } + @Override + public PWindowStream> start(K vId) { + return start(Lists.newArrayList(vId)); + } - @Override - public GraphExecAlgo getGraphTraversalType() { - return graphExecAlgo; + @Override + public PWindowStream> start(List vIds) { + List> vertexBeginTraversalRequests = new ArrayList<>(); + for (K vId : vIds) { + VertexBeginTraversalRequest vertexBeginTraversalRequest = + new VertexBeginTraversalRequest(vId); + vertexBeginTraversalRequests.add(vertexBeginTraversalRequest); } - - @Override - public TransformType getTransformType() { - return TransformType.WindowGraphTraversal; + IGraphVertexCentricOp traversalOp; + if (vertexCentricTraversal instanceof GraphAggregationAlgo) { + traversalOp = + GraphVertexCentricOpFactory.buildStaticGraphVertexCentricAggTraversalOp( + graphViewDesc, + (VertexCentricAggTraversal) vertexCentricTraversal, + vertexBeginTraversalRequests); + } else { + traversalOp = + GraphVertexCentricOpFactory.buildStaticGraphVertexCentricTraversalOp( + graphViewDesc, + (VertexCentricTraversal) vertexCentricTraversal, + vertexBeginTraversalRequests); } + super.operator = (Operator) traversalOp; + this.opArgs = ((AbstractOperator) operator).getOpArgs(); + this.opArgs.setOpId(getId()); + this.opArgs.setOpName(vertexCentricTraversal.getName()); + this.opArgs.setParallelism(this.parallelism); + return this; + } - public PWindowStream> getRequestStream() { - return requestStream; + @Override + public PWindowStream> start( + PWindowStream> requests) { + this.requestStream = + requests instanceof PWindowBroadcastStream + ? requests + : requests.keyBy(new DefaultTraversalRequestPartition()); + IGraphVertexCentricOp traversalOp; + if (vertexCentricTraversal instanceof GraphAggregationAlgo) { + traversalOp = + GraphVertexCentricOpFactory.buildStaticGraphVertexCentricAggTraversalOp( + graphViewDesc, (VertexCentricAggTraversal) vertexCentricTraversal); + } else { + traversalOp = + GraphVertexCentricOpFactory.buildStaticGraphVertexCentricTraversalOp( + graphViewDesc, (VertexCentricTraversal) vertexCentricTraversal); } + super.operator = (Operator) traversalOp; + this.opArgs = ((AbstractOperator) operator).getOpArgs(); + this.opArgs.setOpId(getId()); + this.opArgs.setOpName(vertexCentricTraversal.getName()); + this.opArgs.setParallelism(this.parallelism); + return this; + } + + @Override + public TraversalWindowGraph withParallelism(int parallelism) { + setParallelism(parallelism); + return this; + } + + @Override + public GraphExecAlgo getGraphTraversalType() { + return graphExecAlgo; + } + + @Override + public TransformType getTransformType() { + return TransformType.WindowGraphTraversal; + } + + public PWindowStream> getRequestStream() { + return requestStream; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/Stream.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/Stream.java index 82c2218f1..398f854f0 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/Stream.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/Stream.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.util.Map; + import org.apache.geaflow.api.pdata.base.PData; import org.apache.geaflow.common.encoder.IEncoder; import org.apache.geaflow.operator.OpArgs; @@ -32,108 +33,105 @@ public abstract class Stream implements PData, Serializable { - private int id; - protected int parallelism = 1; - - protected Stream input; - - protected OpArgs opArgs; - protected Operator operator; - protected IPipelineContext context; - protected IEncoder encoder; - - protected Stream() { - - } - - public Stream(IPipelineContext context) { - this.id = context.generateId(); - this.context = context; - } - - public Stream(IPipelineContext context, Operator operator) { - this(context); - this.operator = operator; - this.opArgs = ((AbstractOperator) operator).getOpArgs(); - this.opArgs.setOpId(this.id); - } - - public Stream(Stream dataStream, Operator operator) { - this(dataStream.getContext(), operator); - this.input = dataStream; - this.parallelism = input.getParallelism(); - this.opArgs.setParallelism(parallelism); - } - - @Override - public int getId() { - return id; - } - - public void setId(int id) { - this.id = id; - } - - protected void updateId() { - this.id = context.generateId(); - } - - public Operator getOperator() { - this.opArgs.setOpId(this.id); - return operator; - } - - public void setOperator(Operator operator) { - this.operator = operator; - if (input != null) { - this.opArgs.setParallelism(input.getParallelism()); - } - } - - public IPipelineContext getContext() { - return context; - } - - public > S getInput() { - return (S) this.input; - } - - public IPartitioner getPartition() { - return new ForwardPartitioner(this.getId()); - } - - public int getParallelism() { - return this.parallelism; - } - - protected void setParallelism(int parallelism) { - this.parallelism = parallelism; - this.opArgs.setParallelism(parallelism); - } - - protected void setName(String name) { - this.opArgs.setOpName(name); - } - - public void setConfig(Map config) { - this.opArgs.setConfig(config); - } - - public void setConfig(String key, String value) { - this.opArgs.getConfig().put(key, value); - } - - public TransformType getTransformType() { - return TransformType.StreamTransform; - } - - public Stream withEncoder(IEncoder encoder) { - this.encoder = encoder; - return this; - } - - public IEncoder getEncoder() { - return this.encoder; - } - + private int id; + protected int parallelism = 1; + + protected Stream input; + + protected OpArgs opArgs; + protected Operator operator; + protected IPipelineContext context; + protected IEncoder encoder; + + protected Stream() {} + + public Stream(IPipelineContext context) { + this.id = context.generateId(); + this.context = context; + } + + public Stream(IPipelineContext context, Operator operator) { + this(context); + this.operator = operator; + this.opArgs = ((AbstractOperator) operator).getOpArgs(); + this.opArgs.setOpId(this.id); + } + + public Stream(Stream dataStream, Operator operator) { + this(dataStream.getContext(), operator); + this.input = dataStream; + this.parallelism = input.getParallelism(); + this.opArgs.setParallelism(parallelism); + } + + @Override + public int getId() { + return id; + } + + public void setId(int id) { + this.id = id; + } + + protected void updateId() { + this.id = context.generateId(); + } + + public Operator getOperator() { + this.opArgs.setOpId(this.id); + return operator; + } + + public void setOperator(Operator operator) { + this.operator = operator; + if (input != null) { + this.opArgs.setParallelism(input.getParallelism()); + } + } + + public IPipelineContext getContext() { + return context; + } + + public > S getInput() { + return (S) this.input; + } + + public IPartitioner getPartition() { + return new ForwardPartitioner(this.getId()); + } + + public int getParallelism() { + return this.parallelism; + } + + protected void setParallelism(int parallelism) { + this.parallelism = parallelism; + this.opArgs.setParallelism(parallelism); + } + + protected void setName(String name) { + this.opArgs.setOpName(name); + } + + public void setConfig(Map config) { + this.opArgs.setConfig(config); + } + + public void setConfig(String key, String value) { + this.opArgs.getConfig().put(key, value); + } + + public TransformType getTransformType() { + return TransformType.StreamTransform; + } + + public Stream withEncoder(IEncoder encoder) { + this.encoder = encoder; + return this; + } + + public IEncoder getEncoder() { + return this.encoder; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/StreamType.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/StreamType.java index c47948943..49a56a56b 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/StreamType.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/StreamType.java @@ -20,12 +20,8 @@ package org.apache.geaflow.pdata.stream; public enum StreamType { - /** - * Update type. - */ - update, - /** - * Append type. - */ - append + /** Update type. */ + update, + /** Append type. */ + append } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/TransformType.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/TransformType.java index daec6a7ca..ce2994c1d 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/TransformType.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/TransformType.java @@ -20,22 +20,21 @@ package org.apache.geaflow.pdata.stream; public enum TransformType { - StreamSource, - WindowSource, - StreamJoin, - WindowJoin, - StreamCombine, - WindowCombine, - StreamUnion, - WindowGraphCompute, - WindowGraphTraversal, - ContinueGraphCompute, - ContinueGraphTraversal, - ContinueGraphMaterialize, - StreamTransform, - ContinueStreamCompute, - WindowSink, - StreamSink, - WindowCollect, - + StreamSource, + WindowSource, + StreamJoin, + WindowJoin, + StreamCombine, + WindowCombine, + StreamUnion, + WindowGraphCompute, + WindowGraphTraversal, + ContinueGraphCompute, + ContinueGraphTraversal, + ContinueGraphMaterialize, + StreamTransform, + ContinueStreamCompute, + WindowSink, + StreamSink, + WindowCollect, } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/view/AbstractStreamView.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/view/AbstractStreamView.java index 4a1049d8b..76b938694 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/view/AbstractStreamView.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/view/AbstractStreamView.java @@ -28,30 +28,32 @@ import org.apache.geaflow.view.IViewDesc; import org.apache.geaflow.view.stream.StreamViewDesc; -public abstract class AbstractStreamView extends WindowDataStream implements PStreamView { - - protected IPipelineContext pipelineContext; - protected StreamViewDesc streamViewDesc; - protected PWindowStream incrWindowStream; - - public AbstractStreamView(IPipelineContext pipelineContext) { - this.pipelineContext = pipelineContext; - } - - public AbstractStreamView(IPipelineContext pipelineContext, PWindowStream input, Operator operator) { - super(pipelineContext, input, operator); - this.pipelineContext = pipelineContext; - } - - @Override - public PStreamView init(IViewDesc viewDesc) { - this.streamViewDesc = (StreamViewDesc) viewDesc; - return this; - } - - @Override - public PIncStreamView append(PWindowStream windowStream) { - this.incrWindowStream = windowStream; - return (PIncStreamView) this; - } +public abstract class AbstractStreamView extends WindowDataStream + implements PStreamView { + + protected IPipelineContext pipelineContext; + protected StreamViewDesc streamViewDesc; + protected PWindowStream incrWindowStream; + + public AbstractStreamView(IPipelineContext pipelineContext) { + this.pipelineContext = pipelineContext; + } + + public AbstractStreamView( + IPipelineContext pipelineContext, PWindowStream input, Operator operator) { + super(pipelineContext, input, operator); + this.pipelineContext = pipelineContext; + } + + @Override + public PStreamView init(IViewDesc viewDesc) { + this.streamViewDesc = (StreamViewDesc) viewDesc; + return this; + } + + @Override + public PIncStreamView append(PWindowStream windowStream) { + this.incrWindowStream = windowStream; + return (PIncStreamView) this; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/view/IncStreamView.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/view/IncStreamView.java index ab00ad1f2..2c9c9d0e6 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/view/IncStreamView.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/view/IncStreamView.java @@ -32,36 +32,39 @@ import org.apache.geaflow.pdata.stream.view.compute.ComputeIncStream; import org.apache.geaflow.pipeline.context.IPipelineContext; -public class IncStreamView extends AbstractStreamView implements PIncStreamView { +public class IncStreamView extends AbstractStreamView + implements PIncStreamView { - protected KeySelector keySelector; + protected KeySelector keySelector; - public IncStreamView(IPipelineContext pipelineContext, KeySelector keySelector) { - super(pipelineContext); - this.keySelector = keySelector; - } + public IncStreamView(IPipelineContext pipelineContext, KeySelector keySelector) { + super(pipelineContext); + this.keySelector = keySelector; + } - @Override - public PWindowStream reduce(ReduceFunction reduceFunction) { - IncrReduceOperator incrReduceOperator = new IncrReduceOperator(reduceFunction, keySelector); - return new ComputeIncStream(pipelineContext, incrWindowStream, incrReduceOperator); - } + @Override + public PWindowStream reduce(ReduceFunction reduceFunction) { + IncrReduceOperator incrReduceOperator = new IncrReduceOperator(reduceFunction, keySelector); + return new ComputeIncStream(pipelineContext, incrWindowStream, incrReduceOperator); + } - @Override - public PWindowStream aggregate( - AggregateFunction aggregateFunction) { - IncrAggregateOperator incrAggregateOperator = new IncrAggregateOperator(aggregateFunction, keySelector); - IEncoder resultEncoder = EncoderResolver.resolveFunction(AggregateFunction.class, aggregateFunction, 2); - return new ComputeIncStream(pipelineContext, incrWindowStream, incrAggregateOperator).withEncoder(resultEncoder); - } + @Override + public PWindowStream aggregate(AggregateFunction aggregateFunction) { + IncrAggregateOperator incrAggregateOperator = + new IncrAggregateOperator(aggregateFunction, keySelector); + IEncoder resultEncoder = + EncoderResolver.resolveFunction(AggregateFunction.class, aggregateFunction, 2); + return new ComputeIncStream(pipelineContext, incrWindowStream, incrAggregateOperator) + .withEncoder(resultEncoder); + } - public PIncStreamView withKeySelector(KeySelector keySelector) { - this.keySelector = keySelector; - return this; - } + public PIncStreamView withKeySelector(KeySelector keySelector) { + this.keySelector = keySelector; + return this; + } - @Override - public TransformType getTransformType() { - return TransformType.ContinueStreamCompute; - } + @Override + public TransformType getTransformType() { + return TransformType.ContinueStreamCompute; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/view/compute/ComputeIncStream.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/view/compute/ComputeIncStream.java index b804f04bb..c4a89e334 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/view/compute/ComputeIncStream.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/view/compute/ComputeIncStream.java @@ -27,16 +27,17 @@ public class ComputeIncStream extends AbstractStreamView { - public ComputeIncStream(IPipelineContext pipelineContext, PWindowStream input, Operator operator) { - super(pipelineContext, input, operator); - } + public ComputeIncStream( + IPipelineContext pipelineContext, PWindowStream input, Operator operator) { + super(pipelineContext, input, operator); + } - public ComputeIncStream(IPipelineContext pipelineContext) { - super(pipelineContext); - } + public ComputeIncStream(IPipelineContext pipelineContext) { + super(pipelineContext); + } - @Override - public TransformType getTransformType() { - return TransformType.ContinueStreamCompute; - } + @Override + public TransformType getTransformType() { + return TransformType.ContinueStreamCompute; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowBroadcastDataStream.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowBroadcastDataStream.java index d69b76c85..7b85b8ae4 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowBroadcastDataStream.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowBroadcastDataStream.java @@ -20,46 +20,48 @@ package org.apache.geaflow.pdata.stream.window; import java.util.Map; + import org.apache.geaflow.api.pdata.stream.window.PWindowBroadcastStream; import org.apache.geaflow.api.pdata.stream.window.PWindowStream; import org.apache.geaflow.common.encoder.IEncoder; import org.apache.geaflow.operator.Operator; import org.apache.geaflow.pipeline.context.IPipelineContext; -public class WindowBroadcastDataStream extends WindowDataStream implements PWindowBroadcastStream { +public class WindowBroadcastDataStream extends WindowDataStream + implements PWindowBroadcastStream { - public WindowBroadcastDataStream(IPipelineContext pipelineContext, PWindowStream input, - Operator operator) { - super(pipelineContext, input, operator); - } + public WindowBroadcastDataStream( + IPipelineContext pipelineContext, PWindowStream input, Operator operator) { + super(pipelineContext, input, operator); + } - @Override - public PWindowBroadcastStream withConfig(Map config) { - setConfig(config); - return this; - } + @Override + public PWindowBroadcastStream withConfig(Map config) { + setConfig(config); + return this; + } - @Override - public PWindowBroadcastStream withConfig(String key, String value) { - setConfig(key, value); - return this; - } + @Override + public PWindowBroadcastStream withConfig(String key, String value) { + setConfig(key, value); + return this; + } - @Override - public PWindowBroadcastStream withName(String name) { - this.opArgs.setOpName(name); - return this; - } + @Override + public PWindowBroadcastStream withName(String name) { + this.opArgs.setOpName(name); + return this; + } - @Override - public PWindowBroadcastStream withParallelism(int parallelism) { - setParallelism(parallelism); - return this; - } + @Override + public PWindowBroadcastStream withParallelism(int parallelism) { + setParallelism(parallelism); + return this; + } - @Override - public WindowBroadcastDataStream withEncoder(IEncoder encoder) { - this.encoder = encoder; - return this; - } + @Override + public WindowBroadcastDataStream withEncoder(IEncoder encoder) { + this.encoder = encoder; + return this; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowDataStream.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowDataStream.java index 4333de3aa..fad945a47 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowDataStream.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowDataStream.java @@ -19,8 +19,8 @@ package org.apache.geaflow.pdata.stream.window; -import com.google.common.base.Preconditions; import java.util.Map; + import org.apache.geaflow.api.function.base.FilterFunction; import org.apache.geaflow.api.function.base.FlatMapFunction; import org.apache.geaflow.api.function.base.KeySelector; @@ -45,119 +45,123 @@ import org.apache.geaflow.pdata.stream.Stream; import org.apache.geaflow.pipeline.context.IPipelineContext; -public class WindowDataStream extends Stream implements PWindowStream { - - - public WindowDataStream() { - } - - public WindowDataStream(Stream input, Operator operator) { - super(input, operator); - } - - - public WindowDataStream(IPipelineContext pipelineContext) { - super(pipelineContext); - } - - public WindowDataStream(IPipelineContext pipelineContext, Operator operator) { - super(pipelineContext, operator); - } - - public WindowDataStream(IPipelineContext pipelineContext, PWindowStream input, - Operator operator) { - this(pipelineContext, operator); - this.input = (Stream) input; - this.parallelism = input.getParallelism(); - this.opArgs.setParallelism(input.getParallelism()); - } - - @Override - public WindowDataStream map(MapFunction mapFunction) { - Preconditions.checkArgument(mapFunction != null, " Map Function must not be null"); - IEncoder resultEncoder = EncoderResolver.resolveFunction(MapFunction.class, mapFunction, 1); - return new WindowDataStream(this.context, this, new MapOperator<>(mapFunction)).withEncoder(resultEncoder); - } - - @Override - public PWindowStream filter(FilterFunction filterFunction) { - Preconditions.checkArgument(filterFunction != null, " Filter Function must not be null"); - return new WindowDataStream(this.context, this, new FilterOperator<>(filterFunction)).withEncoder(this.encoder); - } - - @Override - public PWindowStream flatMap(FlatMapFunction flatMapFunction) { - Preconditions.checkArgument(flatMapFunction != null, " FlatMap Function must not be null"); - IEncoder resultEncoder = EncoderResolver.resolveFunction(FlatMapFunction.class, flatMapFunction, 1); - return new WindowDataStream(this.context, this, new FlatMapOperator(flatMapFunction)).withEncoder(resultEncoder); - } - - @Override - public PWindowStream union(PStream uStream) { - if (this instanceof WindowUnionStream) { - ((WindowUnionStream) this).addUnionDataStream((WindowDataStream) uStream); - return this; - } else { - return new WindowUnionStream(this, (WindowDataStream) uStream, - new UnionOperator()).withEncoder(this.encoder); - } - } - - @Override - public PWindowBroadcastStream broadcast() { - return new WindowBroadcastDataStream(this.context, this, new BroadcastOperator()).withEncoder(encoder); - } - - @Override - public PWindowKeyStream keyBy(KeySelector selectorFunction) { - Preconditions.checkArgument(selectorFunction != null, " KeySelector Function must not be null"); - return new WindowKeyDataStream(context, this, - new KeySelectorOperator(selectorFunction), selectorFunction).withEncoder(this.encoder); - } - - @Override - public WindowStreamSink sink(SinkFunction sinkFunction) { - Preconditions.checkArgument(sinkFunction != null, " Sink Function must not be null"); - WindowStreamSink sink = new WindowStreamSink(this, new SinkOperator<>(sinkFunction)); - context.addPAction(sink); - return sink; - } - - @Override - public PWindowCollect collect() { - WindowStreamCollect collect = new WindowStreamCollect<>(this, new CollectOperator()); - context.addPAction(collect); - return collect; - } - - @Override - public PWindowStream withConfig(Map config) { - setConfig(config); - return this; - } - - @Override - public PWindowStream withConfig(String key, String value) { - setConfig(key, value); - return this; - } - - @Override - public PWindowStream withName(String name) { - this.opArgs.setOpName(name); - return this; - } - - @Override - public PWindowStream withParallelism(int parallelism) { - setParallelism(parallelism); - return this; - } +import com.google.common.base.Preconditions; - @Override - public WindowDataStream withEncoder(IEncoder encoder) { - this.encoder = encoder; - return this; - } +public class WindowDataStream extends Stream implements PWindowStream { + public WindowDataStream() {} + + public WindowDataStream(Stream input, Operator operator) { + super(input, operator); + } + + public WindowDataStream(IPipelineContext pipelineContext) { + super(pipelineContext); + } + + public WindowDataStream(IPipelineContext pipelineContext, Operator operator) { + super(pipelineContext, operator); + } + + public WindowDataStream( + IPipelineContext pipelineContext, PWindowStream input, Operator operator) { + this(pipelineContext, operator); + this.input = (Stream) input; + this.parallelism = input.getParallelism(); + this.opArgs.setParallelism(input.getParallelism()); + } + + @Override + public WindowDataStream map(MapFunction mapFunction) { + Preconditions.checkArgument(mapFunction != null, " Map Function must not be null"); + IEncoder resultEncoder = EncoderResolver.resolveFunction(MapFunction.class, mapFunction, 1); + return new WindowDataStream(this.context, this, new MapOperator<>(mapFunction)) + .withEncoder(resultEncoder); + } + + @Override + public PWindowStream filter(FilterFunction filterFunction) { + Preconditions.checkArgument(filterFunction != null, " Filter Function must not be null"); + return new WindowDataStream(this.context, this, new FilterOperator<>(filterFunction)) + .withEncoder(this.encoder); + } + + @Override + public PWindowStream flatMap(FlatMapFunction flatMapFunction) { + Preconditions.checkArgument(flatMapFunction != null, " FlatMap Function must not be null"); + IEncoder resultEncoder = + EncoderResolver.resolveFunction(FlatMapFunction.class, flatMapFunction, 1); + return new WindowDataStream(this.context, this, new FlatMapOperator(flatMapFunction)) + .withEncoder(resultEncoder); + } + + @Override + public PWindowStream union(PStream uStream) { + if (this instanceof WindowUnionStream) { + ((WindowUnionStream) this).addUnionDataStream((WindowDataStream) uStream); + return this; + } else { + return new WindowUnionStream(this, (WindowDataStream) uStream, new UnionOperator()) + .withEncoder(this.encoder); + } + } + + @Override + public PWindowBroadcastStream broadcast() { + return new WindowBroadcastDataStream(this.context, this, new BroadcastOperator()) + .withEncoder(encoder); + } + + @Override + public PWindowKeyStream keyBy(KeySelector selectorFunction) { + Preconditions.checkArgument(selectorFunction != null, " KeySelector Function must not be null"); + return new WindowKeyDataStream( + context, this, new KeySelectorOperator(selectorFunction), selectorFunction) + .withEncoder(this.encoder); + } + + @Override + public WindowStreamSink sink(SinkFunction sinkFunction) { + Preconditions.checkArgument(sinkFunction != null, " Sink Function must not be null"); + WindowStreamSink sink = new WindowStreamSink(this, new SinkOperator<>(sinkFunction)); + context.addPAction(sink); + return sink; + } + + @Override + public PWindowCollect collect() { + WindowStreamCollect collect = new WindowStreamCollect<>(this, new CollectOperator()); + context.addPAction(collect); + return collect; + } + + @Override + public PWindowStream withConfig(Map config) { + setConfig(config); + return this; + } + + @Override + public PWindowStream withConfig(String key, String value) { + setConfig(key, value); + return this; + } + + @Override + public PWindowStream withName(String name) { + this.opArgs.setOpName(name); + return this; + } + + @Override + public PWindowStream withParallelism(int parallelism) { + setParallelism(parallelism); + return this; + } + + @Override + public WindowDataStream withEncoder(IEncoder encoder) { + this.encoder = encoder; + return this; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowKeyDataStream.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowKeyDataStream.java index 6147196ac..b4d87c571 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowKeyDataStream.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowKeyDataStream.java @@ -19,8 +19,8 @@ package org.apache.geaflow.pdata.stream.window; -import com.google.common.base.Preconditions; import java.util.Map; + import org.apache.geaflow.api.function.base.AggregateFunction; import org.apache.geaflow.api.function.base.KeySelector; import org.apache.geaflow.api.function.base.ReduceFunction; @@ -39,79 +39,89 @@ import org.apache.geaflow.pdata.stream.view.IncStreamView; import org.apache.geaflow.pipeline.context.IPipelineContext; -public class WindowKeyDataStream extends WindowDataStream implements - PWindowKeyStream { +import com.google.common.base.Preconditions; + +public class WindowKeyDataStream extends WindowDataStream + implements PWindowKeyStream { - private KeySelector keySelector; - private boolean materializeDisable; + private KeySelector keySelector; + private boolean materializeDisable; - public WindowKeyDataStream(IPipelineContext context, WindowDataStream dataStream, - AbstractOperator operator, - KeySelector keySelector) { - super(context, dataStream, operator); - this.keySelector = keySelector; - this.materializeDisable = ((AbstractPipelineContext) context).getConfig() + public WindowKeyDataStream( + IPipelineContext context, + WindowDataStream dataStream, + AbstractOperator operator, + KeySelector keySelector) { + super(context, dataStream, operator); + this.keySelector = keySelector; + this.materializeDisable = + ((AbstractPipelineContext) context) + .getConfig() .getBoolean(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE); - } + } - @Override - public PWindowStream aggregate(AggregateFunction aggregateFunction) { - if (!materializeDisable) { - return materialize().aggregate(aggregateFunction); - } - Preconditions.checkArgument(aggregateFunction != null, " aggregate Function must not be null"); - IEncoder resultEncoder = EncoderResolver.resolveFunction(AggregateFunction.class, aggregateFunction, 2); - return new WindowDataStream(this.context, this, new WindowAggregateOperator<>(aggregateFunction, keySelector)).withEncoder(resultEncoder); + @Override + public PWindowStream aggregate(AggregateFunction aggregateFunction) { + if (!materializeDisable) { + return materialize().aggregate(aggregateFunction); } + Preconditions.checkArgument(aggregateFunction != null, " aggregate Function must not be null"); + IEncoder resultEncoder = + EncoderResolver.resolveFunction(AggregateFunction.class, aggregateFunction, 2); + return new WindowDataStream( + this.context, this, new WindowAggregateOperator<>(aggregateFunction, keySelector)) + .withEncoder(resultEncoder); + } - @Override - public PWindowStream reduce(ReduceFunction reduceFunction) { - if (!materializeDisable) { - return materialize().reduce(reduceFunction); - } - Preconditions.checkArgument(reduceFunction != null, " Reduce Function must not be null"); - return new WindowDataStream(this.context, this, new WindowReduceOperator<>(reduceFunction, keySelector)).withEncoder(this.encoder); + @Override + public PWindowStream reduce(ReduceFunction reduceFunction) { + if (!materializeDisable) { + return materialize().reduce(reduceFunction); } + Preconditions.checkArgument(reduceFunction != null, " Reduce Function must not be null"); + return new WindowDataStream( + this.context, this, new WindowReduceOperator<>(reduceFunction, keySelector)) + .withEncoder(this.encoder); + } - @Override - public PIncStreamView materialize() { - IncStreamView incStreamView = new IncStreamView<>(context, keySelector); - return incStreamView.append(this); - } + @Override + public PIncStreamView materialize() { + IncStreamView incStreamView = new IncStreamView<>(context, keySelector); + return incStreamView.append(this); + } - @Override - public PWindowKeyStream withConfig(Map config) { - this.opArgs.setConfig(config); - return this; - } + @Override + public PWindowKeyStream withConfig(Map config) { + this.opArgs.setConfig(config); + return this; + } - @Override - public PWindowKeyStream withConfig(String key, String value) { - this.opArgs.getConfig().put(key, value); - return this; - } - - @Override - public PWindowKeyStream withName(String name) { - setName(name); - return this; - } + @Override + public PWindowKeyStream withConfig(String key, String value) { + this.opArgs.getConfig().put(key, value); + return this; + } - @Override - public PWindowKeyStream withParallelism(int parallelism) { - setParallelism(parallelism); - return this; - } + @Override + public PWindowKeyStream withName(String name) { + setName(name); + return this; + } - @Override - public IPartitioner getPartition() { - return new KeyPartitioner(this.getId()); - } + @Override + public PWindowKeyStream withParallelism(int parallelism) { + setParallelism(parallelism); + return this; + } - @Override - public WindowKeyDataStream withEncoder(IEncoder encoder) { - this.encoder = encoder; - return this; - } + @Override + public IPartitioner getPartition() { + return new KeyPartitioner(this.getId()); + } + @Override + public WindowKeyDataStream withEncoder(IEncoder encoder) { + this.encoder = encoder; + return this; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowStreamCollect.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowStreamCollect.java index 21d3f2cf4..a490340f3 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowStreamCollect.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowStreamCollect.java @@ -20,6 +20,7 @@ package org.apache.geaflow.pdata.stream.window; import java.util.Map; + import org.apache.geaflow.api.pdata.PWindowCollect; import org.apache.geaflow.operator.base.AbstractOperator; import org.apache.geaflow.pdata.stream.Stream; @@ -27,36 +28,36 @@ public class WindowStreamCollect extends Stream implements PWindowCollect { - public WindowStreamCollect(Stream stream, AbstractOperator operator) { - super(stream, operator); - } - - @Override - public WindowStreamCollect withParallelism(int parallelism) { - setParallelism(parallelism); - return this; - } - - @Override - public WindowStreamCollect withName(String name) { - setName(name); - return this; - } - - @Override - public WindowStreamCollect withConfig(Map map) { - setConfig(map); - return this; - } - - @Override - public WindowStreamCollect withConfig(String key, String value) { - setConfig(key, value); - return this; - } - - @Override - public TransformType getTransformType() { - return TransformType.StreamTransform; - } + public WindowStreamCollect(Stream stream, AbstractOperator operator) { + super(stream, operator); + } + + @Override + public WindowStreamCollect withParallelism(int parallelism) { + setParallelism(parallelism); + return this; + } + + @Override + public WindowStreamCollect withName(String name) { + setName(name); + return this; + } + + @Override + public WindowStreamCollect withConfig(Map map) { + setConfig(map); + return this; + } + + @Override + public WindowStreamCollect withConfig(String key, String value) { + setConfig(key, value); + return this; + } + + @Override + public TransformType getTransformType() { + return TransformType.StreamTransform; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowStreamSink.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowStreamSink.java index 9c346013f..cc93dd80a 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowStreamSink.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowStreamSink.java @@ -20,6 +20,7 @@ package org.apache.geaflow.pdata.stream.window; import java.util.Map; + import org.apache.geaflow.api.pdata.PStreamSink; import org.apache.geaflow.operator.base.AbstractOperator; import org.apache.geaflow.pdata.stream.Stream; @@ -27,40 +28,38 @@ public class WindowStreamSink extends Stream implements PStreamSink { - public WindowStreamSink(IPipelineContext pipelineContext) { - super(pipelineContext); - } - - public WindowStreamSink(IPipelineContext pipelineContext, AbstractOperator operator) { - super(pipelineContext, operator); - } - - public WindowStreamSink(Stream stream, AbstractOperator operator) { - super(stream, operator); - } + public WindowStreamSink(IPipelineContext pipelineContext) { + super(pipelineContext); + } + public WindowStreamSink(IPipelineContext pipelineContext, AbstractOperator operator) { + super(pipelineContext, operator); + } - @Override - public WindowStreamSink withConfig(Map map) { - setConfig(map); - return this; - } + public WindowStreamSink(Stream stream, AbstractOperator operator) { + super(stream, operator); + } - @Override - public WindowStreamSink withConfig(String key, String value) { - setConfig(key, value); - return this; - } + @Override + public WindowStreamSink withConfig(Map map) { + setConfig(map); + return this; + } - @Override - public WindowStreamSink withName(String name) { - setName(name); - return this; - } + @Override + public WindowStreamSink withConfig(String key, String value) { + setConfig(key, value); + return this; + } - public WindowStreamSink withParallelism(int parallelism) { - super.setParallelism(parallelism); - return this; - } + @Override + public WindowStreamSink withName(String name) { + setName(name); + return this; + } + public WindowStreamSink withParallelism(int parallelism) { + super.setParallelism(parallelism); + return this; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowStreamSource.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowStreamSource.java index f97329f71..8430590ab 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowStreamSource.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowStreamSource.java @@ -20,6 +20,7 @@ package org.apache.geaflow.pdata.stream.window; import java.util.Map; + import org.apache.geaflow.api.function.io.SourceFunction; import org.apache.geaflow.api.pdata.stream.window.PWindowSource; import org.apache.geaflow.api.window.IWindow; @@ -31,63 +32,62 @@ public class WindowStreamSource extends WindowDataStream implements PWindowSource { - protected IWindow windowFunction; - - public WindowStreamSource(IPipelineContext pipelineContext, - SourceFunction sourceFunction, - IWindow windowFunction) { - super(pipelineContext, new WindowSourceOperator<>(sourceFunction, windowFunction)); - this.windowFunction = windowFunction; - this.encoder = (IEncoder) EncoderResolver.resolveFunction(SourceFunction.class, sourceFunction); - } - - @Override - public PWindowSource build(IPipelineContext pipelineContext, - SourceFunction sourceFunction, - IWindow window) { - return new WindowStreamSource<>(pipelineContext, sourceFunction, window); - } + protected IWindow windowFunction; - @Override - public WindowStreamSource window(IWindow window) { - this.windowFunction = window; - return this; - } + public WindowStreamSource( + IPipelineContext pipelineContext, + SourceFunction sourceFunction, + IWindow windowFunction) { + super(pipelineContext, new WindowSourceOperator<>(sourceFunction, windowFunction)); + this.windowFunction = windowFunction; + this.encoder = + (IEncoder) EncoderResolver.resolveFunction(SourceFunction.class, sourceFunction); + } + @Override + public PWindowSource build( + IPipelineContext pipelineContext, SourceFunction sourceFunction, IWindow window) { + return new WindowStreamSource<>(pipelineContext, sourceFunction, window); + } - @Override - public WindowStreamSource withConfig(Map map) { - super.withConfig(map); - return this; - } + @Override + public WindowStreamSource window(IWindow window) { + this.windowFunction = window; + return this; + } - @Override - public WindowStreamSource withConfig(String key, String value) { - super.withConfig(key, value); - return this; - } + @Override + public WindowStreamSource withConfig(Map map) { + super.withConfig(map); + return this; + } - @Override - public WindowStreamSource withName(String name) { - super.withName(name); - return this; - } + @Override + public WindowStreamSource withConfig(String key, String value) { + super.withConfig(key, value); + return this; + } - @Override - public WindowStreamSource withParallelism(int parallelism) { - super.withParallelism(parallelism); - return this; - } + @Override + public WindowStreamSource withName(String name) { + super.withName(name); + return this; + } - @Override - public TransformType getTransformType() { - return TransformType.StreamSource; - } + @Override + public WindowStreamSource withParallelism(int parallelism) { + super.withParallelism(parallelism); + return this; + } - @Override - public WindowStreamSource withEncoder(IEncoder encoder) { - this.encoder = encoder; - return this; - } + @Override + public TransformType getTransformType() { + return TransformType.StreamSource; + } + @Override + public WindowStreamSource withEncoder(IEncoder encoder) { + this.encoder = encoder; + return this; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowUnionStream.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowUnionStream.java index e80b1d694..bce69d048 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowUnionStream.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pdata/src/main/java/org/apache/geaflow/pdata/stream/window/WindowUnionStream.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; + import org.apache.geaflow.api.pdata.stream.PUnionStream; import org.apache.geaflow.common.encoder.IEncoder; import org.apache.geaflow.operator.impl.window.UnionOperator; @@ -29,56 +30,55 @@ public class WindowUnionStream extends WindowDataStream implements PUnionStream { - private List> unionWindowDataStreamList; - - public WindowUnionStream(WindowDataStream stream, WindowDataStream unionStream, - UnionOperator unionOperator) { - super(stream, unionOperator); - this.unionWindowDataStreamList = new ArrayList<>(); - this.addUnionDataStream(unionStream); - } + private List> unionWindowDataStreamList; - public void addUnionDataStream(WindowDataStream unionStream) { - this.unionWindowDataStreamList.add(unionStream); - } + public WindowUnionStream( + WindowDataStream stream, WindowDataStream unionStream, UnionOperator unionOperator) { + super(stream, unionOperator); + this.unionWindowDataStreamList = new ArrayList<>(); + this.addUnionDataStream(unionStream); + } - public List> getUnionWindowDataStreamList() { - return unionWindowDataStreamList; - } + public void addUnionDataStream(WindowDataStream unionStream) { + this.unionWindowDataStreamList.add(unionStream); + } - @Override - public WindowUnionStream withConfig(Map map) { - setConfig(map); - return this; - } + public List> getUnionWindowDataStreamList() { + return unionWindowDataStreamList; + } - @Override - public WindowUnionStream withConfig(String key, String value) { - setConfig(key, value); - return this; - } + @Override + public WindowUnionStream withConfig(Map map) { + setConfig(map); + return this; + } - @Override - public WindowUnionStream withName(String name) { - this.opArgs.setOpName(name); - return this; - } + @Override + public WindowUnionStream withConfig(String key, String value) { + setConfig(key, value); + return this; + } - @Override - public WindowUnionStream withParallelism(int parallelism) { - setParallelism(parallelism); - return this; - } + @Override + public WindowUnionStream withName(String name) { + this.opArgs.setOpName(name); + return this; + } - @Override - public TransformType getTransformType() { - return TransformType.StreamUnion; - } + @Override + public WindowUnionStream withParallelism(int parallelism) { + setParallelism(parallelism); + return this; + } - @Override - public WindowUnionStream withEncoder(IEncoder encoder) { - this.encoder = encoder; - return this; - } + @Override + public TransformType getTransformType() { + return TransformType.StreamUnion; + } + @Override + public WindowUnionStream withEncoder(IEncoder encoder) { + this.encoder = encoder; + return this; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/PipelineContext.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/PipelineContext.java index 09f84827d..706ec8f6b 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/PipelineContext.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/PipelineContext.java @@ -21,27 +21,27 @@ import java.util.ArrayList; import java.util.concurrent.atomic.AtomicInteger; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.context.AbstractPipelineContext; public class PipelineContext extends AbstractPipelineContext { - private transient AtomicInteger idGenerator; - private String name; - - public PipelineContext(String name, Configuration pipelineConfig) { - super(pipelineConfig); - this.name = name; - this.actions = new ArrayList<>(); - this.idGenerator = new AtomicInteger(0); - } + private transient AtomicInteger idGenerator; + private String name; - public int generateId() { - return idGenerator.addAndGet(1); - } + public PipelineContext(String name, Configuration pipelineConfig) { + super(pipelineConfig); + this.name = name; + this.actions = new ArrayList<>(); + this.idGenerator = new AtomicInteger(0); + } - public String getName() { - return name; - } + public int generateId() { + return idGenerator.addAndGet(1); + } + public String getName() { + return name; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/PipelineTaskType.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/PipelineTaskType.java index 5e2b045d4..51b164c9d 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/PipelineTaskType.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/PipelineTaskType.java @@ -21,14 +21,9 @@ public enum PipelineTaskType { - /** - * Default pipeline task. - */ - PipelineTask, - - /** - * Compile pipeline task. - */ - CompileTask; + /** Default pipeline task. */ + PipelineTask, + /** Compile pipeline task. */ + CompileTask; } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/executor/PipelineExecutor.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/executor/PipelineExecutor.java index a22cb9322..b4f67a729 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/executor/PipelineExecutor.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/executor/PipelineExecutor.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.cluster.executor.IPipelineExecutor; import org.apache.geaflow.cluster.executor.PipelineExecutorContext; import org.apache.geaflow.common.config.Configuration; @@ -41,66 +42,80 @@ public class PipelineExecutor implements IPipelineExecutor { - private static final Logger LOGGER = LoggerFactory.getLogger(PipelineExecutor.class); - private PipelineRunner pipelineRunner; - private PipelineExecutorContext executorContext; - private List viewDescList; - private Map serviceExecutorMap; + private static final Logger LOGGER = LoggerFactory.getLogger(PipelineExecutor.class); + private PipelineRunner pipelineRunner; + private PipelineExecutorContext executorContext; + private List viewDescList; + private Map serviceExecutorMap; - public void init(PipelineExecutorContext executorContext) { - this.executorContext = executorContext; - this.pipelineRunner = new PipelineRunner(executorContext.getEventDispatcher()); - this.serviceExecutorMap = new HashMap<>(); - } + public void init(PipelineExecutorContext executorContext) { + this.executorContext = executorContext; + this.pipelineRunner = new PipelineRunner(executorContext.getEventDispatcher()); + this.serviceExecutorMap = new HashMap<>(); + } - @Override - public void register(List viewDescList) { - this.viewDescList = viewDescList; - } + @Override + public void register(List viewDescList) { + this.viewDescList = viewDescList; + } - @Override - public void runPipelineTask(PipelineTask pipelineTask, TaskCallBack taskCallBack) { - int pipelineTaskId = executorContext.getIdGenerator().getAndIncrement(); - String pipelineTaskName = String.format("%s#%s", PipelineTaskType.PipelineTask.name(), + @Override + public void runPipelineTask(PipelineTask pipelineTask, TaskCallBack taskCallBack) { + int pipelineTaskId = executorContext.getIdGenerator().getAndIncrement(); + String pipelineTaskName = + String.format( + "%s#%s", + PipelineTaskType.PipelineTask.name(), executorContext.getIdGenerator().getAndIncrement()); - LOGGER.info("run pipeline task {}", pipelineTaskName); + LOGGER.info("run pipeline task {}", pipelineTaskName); - PipelineContext pipelineContext = new PipelineContext(PipelineTaskType.PipelineTask.name(), - executorContext.getEnvConfig()); - this.viewDescList.stream().forEach(viewDesc -> pipelineContext.addView(viewDesc)); + PipelineContext pipelineContext = + new PipelineContext(PipelineTaskType.PipelineTask.name(), executorContext.getEnvConfig()); + this.viewDescList.stream().forEach(viewDesc -> pipelineContext.addView(viewDesc)); - PipelineTaskExecutorContext taskExecutorContext = - new PipelineTaskExecutorContext(executorContext.getDriverId(), - pipelineTaskId, pipelineTaskName, pipelineContext, pipelineRunner); - PipelineTaskExecutor taskExecutor = new PipelineTaskExecutor(taskExecutorContext); - taskExecutor.execute(pipelineTask, taskCallBack); - } + PipelineTaskExecutorContext taskExecutorContext = + new PipelineTaskExecutorContext( + executorContext.getDriverId(), + pipelineTaskId, + pipelineTaskName, + pipelineContext, + pipelineRunner); + PipelineTaskExecutor taskExecutor = new PipelineTaskExecutor(taskExecutorContext); + taskExecutor.execute(pipelineTask, taskCallBack); + } - @Override - public void startPipelineService(PipelineService pipelineService) { - int pipelineTaskId = executorContext.getIdGenerator().getAndIncrement(); - String pipelineTaskName = String.format("%s#%s", PipelineTaskType.PipelineTask.name(), pipelineTaskId); - LOGGER.info("run pipeline task {}", pipelineTaskName); + @Override + public void startPipelineService(PipelineService pipelineService) { + int pipelineTaskId = executorContext.getIdGenerator().getAndIncrement(); + String pipelineTaskName = + String.format("%s#%s", PipelineTaskType.PipelineTask.name(), pipelineTaskId); + LOGGER.info("run pipeline task {}", pipelineTaskName); - Configuration configuration = new Configuration(); - configuration.putAll(executorContext.getEnvConfig().getConfigMap()); - configuration.setMasterId(executorContext.getEnvConfig().getMasterId()); - PipelineContext pipelineContext = new PipelineContext(PipelineTaskType.PipelineTask.name(), - configuration); - this.viewDescList.stream().forEach(viewDesc -> pipelineContext.addView(viewDesc)); + Configuration configuration = new Configuration(); + configuration.putAll(executorContext.getEnvConfig().getConfigMap()); + configuration.setMasterId(executorContext.getEnvConfig().getMasterId()); + PipelineContext pipelineContext = + new PipelineContext(PipelineTaskType.PipelineTask.name(), configuration); + this.viewDescList.stream().forEach(viewDesc -> pipelineContext.addView(viewDesc)); - PipelineServiceExecutorContext pipelineServiceExecutorContext = - new PipelineServiceExecutorContext(executorContext.getDriverId(), executorContext.getDriverIndex(), - pipelineTaskId, pipelineTaskName, pipelineContext, pipelineRunner, pipelineService); - PipelineServiceExecutor serviceExecutor = - new PipelineServiceExecutor(pipelineServiceExecutorContext); - serviceExecutorMap.put(pipelineService, serviceExecutor); - serviceExecutor.start(); - } + PipelineServiceExecutorContext pipelineServiceExecutorContext = + new PipelineServiceExecutorContext( + executorContext.getDriverId(), + executorContext.getDriverIndex(), + pipelineTaskId, + pipelineTaskName, + pipelineContext, + pipelineRunner, + pipelineService); + PipelineServiceExecutor serviceExecutor = + new PipelineServiceExecutor(pipelineServiceExecutorContext); + serviceExecutorMap.put(pipelineService, serviceExecutor); + serviceExecutor.start(); + } - @Override - public void stopPipelineService(PipelineService pipelineService) { - serviceExecutorMap.get(pipelineService).stop(); - LOGGER.info("stopped pipeline service {}", pipelineService); - } + @Override + public void stopPipelineService(PipelineService pipelineService) { + serviceExecutorMap.get(pipelineService).stop(); + LOGGER.info("stopped pipeline service {}", pipelineService); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/runner/PipelineRunner.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/runner/PipelineRunner.java index e76cc2ab9..237f4ab9e 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/runner/PipelineRunner.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/runner/PipelineRunner.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.Map; + import org.apache.geaflow.cluster.common.ExecutionIdGenerator; import org.apache.geaflow.cluster.common.IEventListener; import org.apache.geaflow.cluster.driver.DriverEventDispatcher; @@ -49,72 +50,91 @@ public class PipelineRunner { - private static final Logger LOGGER = LoggerFactory.getLogger(PipelineRunner.class); + private static final Logger LOGGER = LoggerFactory.getLogger(PipelineRunner.class); - private DriverEventDispatcher eventDispatcher; - private ICycleSchedulerContext context; + private DriverEventDispatcher eventDispatcher; + private ICycleSchedulerContext context; - public PipelineRunner(DriverEventDispatcher eventDispatcher) { - this.eventDispatcher = eventDispatcher; - } + public PipelineRunner(DriverEventDispatcher eventDispatcher) { + this.eventDispatcher = eventDispatcher; + } - public IExecutionResult executePipelineGraph(IPipelineExecutorContext pipelineExecutorContext, - PipelineGraph pipelineGraph, - TaskCallBack taskCallBack) { - ICycleSchedulerContext context = loadOrCreateContext(pipelineExecutorContext, pipelineGraph); - - if (taskCallBack != null) { - ((AbstractCycleSchedulerContext) context).setCallbackFunction(taskCallBack.getCallbackFunction()); - } - - ICycleScheduler scheduler = CycleSchedulerFactory.create(context.getCycle()); - if (scheduler instanceof IEventListener) { - eventDispatcher.registerListener(context.getCycle().getSchedulerId(), (IEventListener) scheduler); - } - - scheduler.init(context); - IExecutionResult result = scheduler.execute(); - LOGGER.info("final result of pipeline is {}", result.getResult()); - scheduler.close(); - if (scheduler instanceof IEventListener) { - eventDispatcher.removeListener(((ExecutionGraphCycleScheduler) scheduler).getSchedulerId()); - } - return result; - } + public IExecutionResult executePipelineGraph( + IPipelineExecutorContext pipelineExecutorContext, + PipelineGraph pipelineGraph, + TaskCallBack taskCallBack) { + ICycleSchedulerContext context = loadOrCreateContext(pipelineExecutorContext, pipelineGraph); - public void runPipelineGraph(PipelineGraph pipelineGraph, TaskCallBack taskCallBack, - PipelineTaskExecutorContext taskExecutorContext) { - IExecutionResult result = executePipelineGraph(taskExecutorContext, pipelineGraph, taskCallBack); - if (!result.isSuccess()) { - throw new GeaflowRuntimeException("run pipeline task failed, cause: " + result.getError()); - } + if (taskCallBack != null) { + ((AbstractCycleSchedulerContext) context) + .setCallbackFunction(taskCallBack.getCallbackFunction()); } - public IExecutionResult runPipelineGraph(PipelineGraph pipelineGraph, - PipelineServiceExecutorContext serviceExecutorContext) { - //TODO Service task callback. - return executePipelineGraph(serviceExecutorContext, pipelineGraph, null); + ICycleScheduler scheduler = CycleSchedulerFactory.create(context.getCycle()); + if (scheduler instanceof IEventListener) { + eventDispatcher.registerListener( + context.getCycle().getSchedulerId(), (IEventListener) scheduler); } - private ICycleSchedulerContext loadOrCreateContext(IPipelineExecutorContext pipelineExecutorContext, - PipelineGraph pipelineGraph) { - - ICycleSchedulerContext context = CycleSchedulerContextFactory.loadOrCreate(pipelineExecutorContext.getPipelineTaskId(), () -> { - ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); - PipelineContext pipelineContext = (PipelineContext) pipelineExecutorContext.getPipelineContext(); - ExecutionGraph graph = builder.buildExecutionGraph(pipelineContext.getConfig()); - - Map> vertex2Tasks = ExecutionCycleTaskAssigner.assign(graph); - - // Skip checkpoint if it's a PipelineServiceExecutorContext - boolean skipCheckpoint = pipelineExecutorContext instanceof PipelineServiceExecutorContext; - IExecutionCycle cycle = ExecutionCycleBuilder.buildExecutionCycle(graph, vertex2Tasks, - pipelineContext.getConfig(), ExecutionIdGenerator.getInstance().generateId(), - pipelineExecutorContext.getPipelineTaskId(), pipelineExecutorContext.getPipelineTaskName(), - ExecutionIdGenerator.getInstance().generateId(), pipelineExecutorContext.getDriverId(), - pipelineExecutorContext.getDriverIndex(), skipCheckpoint); - return CycleSchedulerContextFactory.create(cycle, null); - }); - return context; + scheduler.init(context); + IExecutionResult result = scheduler.execute(); + LOGGER.info("final result of pipeline is {}", result.getResult()); + scheduler.close(); + if (scheduler instanceof IEventListener) { + eventDispatcher.removeListener(((ExecutionGraphCycleScheduler) scheduler).getSchedulerId()); + } + return result; + } + + public void runPipelineGraph( + PipelineGraph pipelineGraph, + TaskCallBack taskCallBack, + PipelineTaskExecutorContext taskExecutorContext) { + IExecutionResult result = + executePipelineGraph(taskExecutorContext, pipelineGraph, taskCallBack); + if (!result.isSuccess()) { + throw new GeaflowRuntimeException("run pipeline task failed, cause: " + result.getError()); } + } + + public IExecutionResult runPipelineGraph( + PipelineGraph pipelineGraph, PipelineServiceExecutorContext serviceExecutorContext) { + // TODO Service task callback. + return executePipelineGraph(serviceExecutorContext, pipelineGraph, null); + } + + private ICycleSchedulerContext loadOrCreateContext( + IPipelineExecutorContext pipelineExecutorContext, PipelineGraph pipelineGraph) { + + ICycleSchedulerContext context = + CycleSchedulerContextFactory.loadOrCreate( + pipelineExecutorContext.getPipelineTaskId(), + () -> { + ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); + PipelineContext pipelineContext = + (PipelineContext) pipelineExecutorContext.getPipelineContext(); + ExecutionGraph graph = builder.buildExecutionGraph(pipelineContext.getConfig()); + + Map> vertex2Tasks = + ExecutionCycleTaskAssigner.assign(graph); + + // Skip checkpoint if it's a PipelineServiceExecutorContext + boolean skipCheckpoint = + pipelineExecutorContext instanceof PipelineServiceExecutorContext; + IExecutionCycle cycle = + ExecutionCycleBuilder.buildExecutionCycle( + graph, + vertex2Tasks, + pipelineContext.getConfig(), + ExecutionIdGenerator.getInstance().generateId(), + pipelineExecutorContext.getPipelineTaskId(), + pipelineExecutorContext.getPipelineTaskName(), + ExecutionIdGenerator.getInstance().generateId(), + pipelineExecutorContext.getDriverId(), + pipelineExecutorContext.getDriverIndex(), + skipCheckpoint); + return CycleSchedulerContextFactory.create(cycle, null); + }); + return context; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/service/PipelineServiceContext.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/service/PipelineServiceContext.java index 3bbd2056b..b5745a973 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/service/PipelineServiceContext.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/service/PipelineServiceContext.java @@ -38,67 +38,60 @@ public class PipelineServiceContext implements IPipelineServiceContext { - private long sessionId; - private PipelineContext pipelineContext; - private Object request; + private long sessionId; + private PipelineContext pipelineContext; + private Object request; - public PipelineServiceContext(long sessionId, - PipelineContext pipelineContext) { - this.sessionId = sessionId; - this.pipelineContext = pipelineContext; - } + public PipelineServiceContext(long sessionId, PipelineContext pipelineContext) { + this.sessionId = sessionId; + this.pipelineContext = pipelineContext; + } - public PipelineServiceContext(long sessionId, - PipelineContext pipelineContext, - Object request) { - this(sessionId, pipelineContext); - this.request = request; - } + public PipelineServiceContext(long sessionId, PipelineContext pipelineContext, Object request) { + this(sessionId, pipelineContext); + this.request = request; + } - @Override - public long getId() { - return sessionId; - } + @Override + public long getId() { + return sessionId; + } - @Override - public Object getRequest() { - return request; - } + @Override + public Object getRequest() { + return request; + } - @Override - public void response(Object response) { + @Override + public void response(Object response) {} - } + @Override + public Configuration getConfig() { + return pipelineContext.getConfig(); + } - @Override - public Configuration getConfig() { - return pipelineContext.getConfig(); - } + @Override + public PWindowSource buildSource(SourceFunction sourceFunction, IWindow window) { + return new WindowStreamSource<>(pipelineContext, sourceFunction, window); + } - @Override - public PWindowSource buildSource(SourceFunction sourceFunction, IWindow window) { - return new WindowStreamSource<>(pipelineContext, sourceFunction, window); - } + @Override + public PGraphView getGraphView(String viewName) { + IViewDesc viewDesc = pipelineContext.getViewDesc(viewName); + return new IncGraphView<>(pipelineContext, viewDesc); + } - @Override - public PGraphView getGraphView(String viewName) { - IViewDesc viewDesc = pipelineContext.getViewDesc(viewName); - return new IncGraphView<>(pipelineContext, viewDesc); - } + @Override + public PGraphView createGraphView(IViewDesc viewDesc) { + return new IncGraphView<>(pipelineContext, viewDesc); + } - @Override - public PGraphView createGraphView(IViewDesc viewDesc) { - return new IncGraphView<>(pipelineContext, viewDesc); - } - - @Override - public PGraphWindow buildWindowStreamGraph(PWindowStream> vertexWindowSteam, - PWindowStream> edgeWindowStream, - GraphViewDesc graphViewDesc) { - return new WindowStreamGraph( - graphViewDesc, - pipelineContext, - vertexWindowSteam, - edgeWindowStream); - } + @Override + public PGraphWindow buildWindowStreamGraph( + PWindowStream> vertexWindowSteam, + PWindowStream> edgeWindowStream, + GraphViewDesc graphViewDesc) { + return new WindowStreamGraph( + graphViewDesc, pipelineContext, vertexWindowSteam, edgeWindowStream); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/service/PipelineServiceExecutor.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/service/PipelineServiceExecutor.java index 713204251..b75f14443 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/service/PipelineServiceExecutor.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/service/PipelineServiceExecutor.java @@ -20,6 +20,7 @@ package org.apache.geaflow.runtime.pipeline.service; import java.io.Serializable; + import org.apache.geaflow.pipeline.service.IServiceServer; import org.apache.geaflow.runtime.pipeline.executor.PipelineExecutor; import org.apache.geaflow.runtime.pipeline.service.util.ServerFactory; @@ -28,22 +29,22 @@ public class PipelineServiceExecutor implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(PipelineExecutor.class); + private static final Logger LOGGER = LoggerFactory.getLogger(PipelineExecutor.class); - private PipelineServiceExecutorContext serviceExecutorContext; - private IServiceServer serviceServer; + private PipelineServiceExecutorContext serviceExecutorContext; + private IServiceServer serviceServer; - public PipelineServiceExecutor(PipelineServiceExecutorContext serviceExecutorContext) { - this.serviceExecutorContext = serviceExecutorContext; - } + public PipelineServiceExecutor(PipelineServiceExecutorContext serviceExecutorContext) { + this.serviceExecutorContext = serviceExecutorContext; + } - public void start() { - LOGGER.info("start pipeline service {}", serviceExecutorContext.getPipelineService()); - this.serviceServer = ServerFactory.loadServer(this.serviceExecutorContext); - this.serviceServer.startServer(); - } + public void start() { + LOGGER.info("start pipeline service {}", serviceExecutorContext.getPipelineService()); + this.serviceServer = ServerFactory.loadServer(this.serviceExecutorContext); + this.serviceServer.startServer(); + } - public void stop() { - this.serviceServer.stopServer(); - } + public void stop() { + this.serviceServer.stopServer(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/service/PipelineServiceExecutorContext.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/service/PipelineServiceExecutorContext.java index 7ace2896f..12778f6c2 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/service/PipelineServiceExecutorContext.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/service/PipelineServiceExecutorContext.java @@ -27,61 +27,62 @@ public class PipelineServiceExecutorContext implements IPipelineServiceExecutorContext { - private String driverId; - private int driverIndex; - private long pipelineTaskId; - private String pipelineTaskName; - private PipelineContext pipelineContext; - private PipelineRunner pipelineRunner; - private PipelineService pipelineService; + private String driverId; + private int driverIndex; + private long pipelineTaskId; + private String pipelineTaskName; + private PipelineContext pipelineContext; + private PipelineRunner pipelineRunner; + private PipelineService pipelineService; - public PipelineServiceExecutorContext(String driverId, - int driverIndex, - long pipelineTaskId, - String pipelineTaskName, - PipelineContext pipelineContext, - PipelineRunner pipelineRunner, - PipelineService pipelineService) { - this.driverId = driverId; - this.driverIndex = driverIndex; - this.pipelineTaskId = pipelineTaskId; - this.pipelineTaskName = pipelineTaskName; - this.pipelineContext = pipelineContext; - this.pipelineRunner = pipelineRunner; - this.pipelineService = pipelineService; - } + public PipelineServiceExecutorContext( + String driverId, + int driverIndex, + long pipelineTaskId, + String pipelineTaskName, + PipelineContext pipelineContext, + PipelineRunner pipelineRunner, + PipelineService pipelineService) { + this.driverId = driverId; + this.driverIndex = driverIndex; + this.pipelineTaskId = pipelineTaskId; + this.pipelineTaskName = pipelineTaskName; + this.pipelineContext = pipelineContext; + this.pipelineRunner = pipelineRunner; + this.pipelineService = pipelineService; + } - public String getDriverId() { - return driverId; - } + public String getDriverId() { + return driverId; + } - public int getDriverIndex() { - return driverIndex; - } + public int getDriverIndex() { + return driverIndex; + } - public long getPipelineTaskId() { - return pipelineTaskId; - } + public long getPipelineTaskId() { + return pipelineTaskId; + } - public String getPipelineTaskName() { - return pipelineTaskName; - } + public String getPipelineTaskName() { + return pipelineTaskName; + } - public PipelineContext getPipelineContext() { - return pipelineContext; - } + public PipelineContext getPipelineContext() { + return pipelineContext; + } - public PipelineRunner getPipelineRunner() { - return pipelineRunner; - } + public PipelineRunner getPipelineRunner() { + return pipelineRunner; + } - @Override - public Configuration getConfiguration() { - return pipelineContext.getConfig(); - } + @Override + public Configuration getConfiguration() { + return pipelineContext.getConfig(); + } - @Override - public PipelineService getPipelineService() { - return pipelineService; - } + @Override + public PipelineService getPipelineService() { + return pipelineService; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/service/util/ServerFactory.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/service/util/ServerFactory.java index 1135ab521..3f79cc361 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/service/util/ServerFactory.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/service/util/ServerFactory.java @@ -21,6 +21,7 @@ import java.util.Iterator; import java.util.ServiceLoader; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.FrameworkConfigKeys; import org.apache.geaflow.common.errorcode.RuntimeErrors; @@ -33,23 +34,23 @@ public class ServerFactory { - private static final Logger LOGGER = LoggerFactory.getLogger(ServerFactory.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ServerFactory.class); - public static IServiceServer loadServer(IPipelineServiceExecutorContext context) { - Configuration configuration = context.getConfiguration(); - String type = configuration.getString(FrameworkConfigKeys.SERVICE_SERVER_TYPE); - ServiceLoader contextLoader = ServiceLoader.load(IServiceServer.class); - Iterator contextIterable = contextLoader.iterator(); - while (contextIterable.hasNext()) { - IServiceServer serviceServer = contextIterable.next(); - if (serviceServer.getServiceType() == ServiceType.getEnum(type)) { - LOGGER.info("loaded IServiceServer implementation {}", serviceServer); - serviceServer.init(context); - return serviceServer; - } - } - LOGGER.error("NOT found IServiceServer implementation with type:{}", type); - throw new GeaflowRuntimeException( - RuntimeErrors.INST.spiNotFoundError(IServiceServer.class.getSimpleName())); + public static IServiceServer loadServer(IPipelineServiceExecutorContext context) { + Configuration configuration = context.getConfiguration(); + String type = configuration.getString(FrameworkConfigKeys.SERVICE_SERVER_TYPE); + ServiceLoader contextLoader = ServiceLoader.load(IServiceServer.class); + Iterator contextIterable = contextLoader.iterator(); + while (contextIterable.hasNext()) { + IServiceServer serviceServer = contextIterable.next(); + if (serviceServer.getServiceType() == ServiceType.getEnum(type)) { + LOGGER.info("loaded IServiceServer implementation {}", serviceServer); + serviceServer.init(context); + return serviceServer; + } } + LOGGER.error("NOT found IServiceServer implementation with type:{}", type); + throw new GeaflowRuntimeException( + RuntimeErrors.INST.spiNotFoundError(IServiceServer.class.getSimpleName())); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/task/PipelineTaskContext.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/task/PipelineTaskContext.java index 466234ae3..54550fd50 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/task/PipelineTaskContext.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/task/PipelineTaskContext.java @@ -38,46 +38,46 @@ public class PipelineTaskContext implements IPipelineTaskContext { - private long pipelineTaskId; - private PipelineContext pipelineContext; + private long pipelineTaskId; + private PipelineContext pipelineContext; - public PipelineTaskContext(long pipelineTaskId, - PipelineContext pipelineContext) { - this.pipelineTaskId = pipelineTaskId; - this.pipelineContext = pipelineContext; - } + public PipelineTaskContext(long pipelineTaskId, PipelineContext pipelineContext) { + this.pipelineTaskId = pipelineTaskId; + this.pipelineContext = pipelineContext; + } - @Override - public long getId() { - return this.pipelineTaskId; - } + @Override + public long getId() { + return this.pipelineTaskId; + } - @Override - public Configuration getConfig() { - return pipelineContext.getConfig(); - } + @Override + public Configuration getConfig() { + return pipelineContext.getConfig(); + } - @Override - public PWindowSource buildSource(SourceFunction sourceFunction, IWindow window) { - return new WindowStreamSource<>(pipelineContext, sourceFunction, window); - } + @Override + public PWindowSource buildSource(SourceFunction sourceFunction, IWindow window) { + return new WindowStreamSource<>(pipelineContext, sourceFunction, window); + } - @Override - public PGraphView getGraphView(String viewName) { - IViewDesc viewDesc = pipelineContext.getViewDesc(viewName); - return new IncGraphView<>(pipelineContext, viewDesc); - } + @Override + public PGraphView getGraphView(String viewName) { + IViewDesc viewDesc = pipelineContext.getViewDesc(viewName); + return new IncGraphView<>(pipelineContext, viewDesc); + } - @Override - public PGraphView createGraphView(IViewDesc viewDesc) { - return new IncGraphView<>(pipelineContext, viewDesc); - } - - @Override - public PGraphWindow buildWindowStreamGraph(PWindowStream> vertexWindowSteam, - PWindowStream> edgeWindowStream, - GraphViewDesc graphViewDesc) { - return new WindowStreamGraph(graphViewDesc, pipelineContext, vertexWindowSteam, edgeWindowStream); - } + @Override + public PGraphView createGraphView(IViewDesc viewDesc) { + return new IncGraphView<>(pipelineContext, viewDesc); + } + @Override + public PGraphWindow buildWindowStreamGraph( + PWindowStream> vertexWindowSteam, + PWindowStream> edgeWindowStream, + GraphViewDesc graphViewDesc) { + return new WindowStreamGraph( + graphViewDesc, pipelineContext, vertexWindowSteam, edgeWindowStream); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/task/PipelineTaskExecutor.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/task/PipelineTaskExecutor.java index 35c45ba81..9df2bfdcb 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/task/PipelineTaskExecutor.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/task/PipelineTaskExecutor.java @@ -20,6 +20,7 @@ package org.apache.geaflow.runtime.pipeline.task; import java.io.Serializable; + import org.apache.geaflow.pipeline.callback.TaskCallBack; import org.apache.geaflow.pipeline.task.PipelineTask; import org.apache.geaflow.plan.PipelinePlanBuilder; @@ -30,28 +31,31 @@ public class PipelineTaskExecutor implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(PipelineExecutor.class); + private static final Logger LOGGER = LoggerFactory.getLogger(PipelineExecutor.class); - private PipelineTaskExecutorContext taskExecutorContext; + private PipelineTaskExecutorContext taskExecutorContext; - public PipelineTaskExecutor(PipelineTaskExecutorContext taskExecutorContext) { - this.taskExecutorContext = taskExecutorContext; - } + public PipelineTaskExecutor(PipelineTaskExecutorContext taskExecutorContext) { + this.taskExecutorContext = taskExecutorContext; + } - public void execute(PipelineTask pipelineTask, TaskCallBack taskCallBack) { - // User pipeline Task. - PipelineTaskContext taskContext = new PipelineTaskContext(taskExecutorContext.getPipelineTaskId(), - taskExecutorContext.getPipelineContext()); - pipelineTask.execute(taskContext); + public void execute(PipelineTask pipelineTask, TaskCallBack taskCallBack) { + // User pipeline Task. + PipelineTaskContext taskContext = + new PipelineTaskContext( + taskExecutorContext.getPipelineTaskId(), taskExecutorContext.getPipelineContext()); + pipelineTask.execute(taskContext); - PipelinePlanBuilder pipelinePlanBuilder = new PipelinePlanBuilder(); - // 1. Build pipeline graph plan. - PipelineGraph pipelineGraph = pipelinePlanBuilder.buildPlan(taskExecutorContext.getPipelineContext()); + PipelinePlanBuilder pipelinePlanBuilder = new PipelinePlanBuilder(); + // 1. Build pipeline graph plan. + PipelineGraph pipelineGraph = + pipelinePlanBuilder.buildPlan(taskExecutorContext.getPipelineContext()); - // 2. Optimize pipeline graph plan. - pipelinePlanBuilder.optimizePlan(taskExecutorContext.getPipelineContext().getConfig()); + // 2. Optimize pipeline graph plan. + pipelinePlanBuilder.optimizePlan(taskExecutorContext.getPipelineContext().getConfig()); - this.taskExecutorContext.getPipelineRunner().runPipelineGraph(pipelineGraph, taskCallBack - , taskExecutorContext); - } + this.taskExecutorContext + .getPipelineRunner() + .runPipelineGraph(pipelineGraph, taskCallBack, taskExecutorContext); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/task/PipelineTaskExecutorContext.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/task/PipelineTaskExecutorContext.java index 31ee6fd70..dec3810c7 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/task/PipelineTaskExecutorContext.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-pipeline/src/main/java/org/apache/geaflow/runtime/pipeline/task/PipelineTaskExecutorContext.java @@ -25,46 +25,47 @@ public class PipelineTaskExecutorContext implements IPipelineExecutorContext { - private String driverId; - private long pipelineTaskId; - private String pipelineTaskName; - private PipelineContext pipelineContext; - private PipelineRunner pipelineRunner; + private String driverId; + private long pipelineTaskId; + private String pipelineTaskName; + private PipelineContext pipelineContext; + private PipelineRunner pipelineRunner; - public PipelineTaskExecutorContext(String driverId, - long pipelineTaskId, - String pipelineTaskName, - PipelineContext pipelineContext, - PipelineRunner pipelineRunner) { - this.driverId = driverId; - this.pipelineTaskId = pipelineTaskId; - this.pipelineTaskName = pipelineTaskName; - this.pipelineContext = pipelineContext; - this.pipelineRunner = pipelineRunner; - } + public PipelineTaskExecutorContext( + String driverId, + long pipelineTaskId, + String pipelineTaskName, + PipelineContext pipelineContext, + PipelineRunner pipelineRunner) { + this.driverId = driverId; + this.pipelineTaskId = pipelineTaskId; + this.pipelineTaskName = pipelineTaskName; + this.pipelineContext = pipelineContext; + this.pipelineRunner = pipelineRunner; + } - public String getDriverId() { - return driverId; - } + public String getDriverId() { + return driverId; + } - public long getPipelineTaskId() { - return pipelineTaskId; - } + public long getPipelineTaskId() { + return pipelineTaskId; + } - public String getPipelineTaskName() { - return pipelineTaskName; - } + public String getPipelineTaskName() { + return pipelineTaskName; + } - public PipelineContext getPipelineContext() { - return pipelineContext; - } + public PipelineContext getPipelineContext() { + return pipelineContext; + } - public PipelineRunner getPipelineRunner() { - return pipelineRunner; - } + public PipelineRunner getPipelineRunner() { + return pipelineRunner; + } - @Override - public int getDriverIndex() { - return 0; - } + @Override + public int getDriverIndex() { + return 0; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/CollectExecutionVertex.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/CollectExecutionVertex.java index 331e65a66..14cd040f5 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/CollectExecutionVertex.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/CollectExecutionVertex.java @@ -21,7 +21,7 @@ public class CollectExecutionVertex extends ExecutionVertex { - public CollectExecutionVertex(int vertexId, String name) { - super(vertexId, name); - } + public CollectExecutionVertex(int vertexId, String name) { + super(vertexId, name); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/CycleGroupMeta.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/CycleGroupMeta.java index 42abf654e..a6236acbf 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/CycleGroupMeta.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/CycleGroupMeta.java @@ -22,64 +22,68 @@ import org.apache.geaflow.plan.graph.AffinityLevel; /** - * The CycleGroupMeta which is used to describe group meta(e.g loop count, whether is iterative) of cycle. + * The CycleGroupMeta which is used to describe group meta(e.g loop count, whether is iterative) of + * cycle. */ public class CycleGroupMeta { - private long iterationCount; - private int flyingCount; - private CycleGroupType groupType; - private AffinityLevel affinityLevel; + private long iterationCount; + private int flyingCount; + private CycleGroupType groupType; + private AffinityLevel affinityLevel; - public CycleGroupMeta() { - this.iterationCount = 1; - this.flyingCount = 1; - this.groupType = CycleGroupType.pipelined; - this.affinityLevel = AffinityLevel.worker; - } + public CycleGroupMeta() { + this.iterationCount = 1; + this.flyingCount = 1; + this.groupType = CycleGroupType.pipelined; + this.affinityLevel = AffinityLevel.worker; + } - public long getIterationCount() { - return iterationCount; - } + public long getIterationCount() { + return iterationCount; + } - public void setIterationCount(long iterationCount) { - this.iterationCount = iterationCount; - } + public void setIterationCount(long iterationCount) { + this.iterationCount = iterationCount; + } - public int getFlyingCount() { - return flyingCount; - } + public int getFlyingCount() { + return flyingCount; + } - public void setFlyingCount(int flyingCount) { - this.flyingCount = flyingCount; - } + public void setFlyingCount(int flyingCount) { + this.flyingCount = flyingCount; + } - public boolean isIterative() { - return groupType == CycleGroupType.incremental || groupType == CycleGroupType.statical; - } + public boolean isIterative() { + return groupType == CycleGroupType.incremental || groupType == CycleGroupType.statical; + } - public CycleGroupType getGroupType() { - return groupType; - } + public CycleGroupType getGroupType() { + return groupType; + } - public void setGroupType(CycleGroupType groupType) { - this.groupType = groupType; - } + public void setGroupType(CycleGroupType groupType) { + this.groupType = groupType; + } - public AffinityLevel getAffinityLevel() { - return affinityLevel; - } + public AffinityLevel getAffinityLevel() { + return affinityLevel; + } - public void setAffinityLevel(AffinityLevel affinityLevel) { - this.affinityLevel = affinityLevel; - } + public void setAffinityLevel(AffinityLevel affinityLevel) { + this.affinityLevel = affinityLevel; + } - @Override - public String toString() { - return "CycleGroupMeta{" - + "iterationCount=" + iterationCount - + ", flyingCount=" + flyingCount - + ", groupType=" + groupType - + '}'; - } + @Override + public String toString() { + return "CycleGroupMeta{" + + "iterationCount=" + + iterationCount + + ", flyingCount=" + + flyingCount + + ", groupType=" + + groupType + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/CycleGroupType.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/CycleGroupType.java index 68718d5dd..eee28d4ee 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/CycleGroupType.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/CycleGroupType.java @@ -21,23 +21,15 @@ public enum CycleGroupType { - /** - * A type which denotes pipelined cycle group. - */ - pipelined, + /** A type which denotes pipelined cycle group. */ + pipelined, - /** - * A type which denotes incremental cycle group. - */ - incremental, + /** A type which denotes incremental cycle group. */ + incremental, - /** - * A type which denotes statical cycle group. - */ - statical, + /** A type which denotes statical cycle group. */ + statical, - /** - * A type which denotes window cycle group. - */ - windowed + /** A type which denotes window cycle group. */ + windowed } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionEdge.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionEdge.java index a659b8ddf..7ed972016 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionEdge.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionEdge.java @@ -20,99 +20,101 @@ package org.apache.geaflow.core.graph; import java.io.Serializable; + import org.apache.geaflow.common.encoder.IEncoder; import org.apache.geaflow.partitioner.IPartitioner; import org.apache.geaflow.shuffle.desc.OutputType; public class ExecutionEdge implements Serializable { - private IPartitioner partitioner; - private String edgeName; - private int edgeId; - private int srcId; - private int targetId; - private IEncoder encoder; - private OutputType type; - - public ExecutionEdge(IPartitioner partitioner, - int edgeId, - String edgeName, - int srcId, - int targetId, - IEncoder encoder) { - this(partitioner, edgeId, edgeName, srcId, targetId, OutputType.FORWARD, encoder); - } - - public ExecutionEdge(IPartitioner partitioner, - int edgeId, - String edgeName, - int srcId, - int targetId, - OutputType type, - IEncoder encoder) { - this.partitioner = partitioner; - this.edgeId = edgeId; - this.edgeName = edgeName; - this.srcId = srcId; - this.targetId = targetId; - this.encoder = encoder; - this.type = type; - } - - public IPartitioner getPartitioner() { - return partitioner; - } - - public int getEdgeId() { - return edgeId; - } - - public void setEdgeId(int edgeId) { - this.edgeId = edgeId; - } - - public void setPartitioner(IPartitioner partitioner) { - this.partitioner = partitioner; - } - - public String getEdgeName() { - return edgeName; - } - - public void setEdgeName(String edgeName) { - this.edgeName = edgeName; - } - - public int getSrcId() { - return srcId; - } - - public void setSrcId(int srcId) { - this.srcId = srcId; - } - - public int getTargetId() { - return targetId; - } - - public void setTargetId(int targetId) { - this.targetId = targetId; - } - - public IEncoder getEncoder() { - return this.encoder; - } - - public void setEncoder(IEncoder encoder) { - this.encoder = encoder; - } - - public OutputType getType() { - return type; - } - - public void setType(OutputType type) { - this.type = type; - } - + private IPartitioner partitioner; + private String edgeName; + private int edgeId; + private int srcId; + private int targetId; + private IEncoder encoder; + private OutputType type; + + public ExecutionEdge( + IPartitioner partitioner, + int edgeId, + String edgeName, + int srcId, + int targetId, + IEncoder encoder) { + this(partitioner, edgeId, edgeName, srcId, targetId, OutputType.FORWARD, encoder); + } + + public ExecutionEdge( + IPartitioner partitioner, + int edgeId, + String edgeName, + int srcId, + int targetId, + OutputType type, + IEncoder encoder) { + this.partitioner = partitioner; + this.edgeId = edgeId; + this.edgeName = edgeName; + this.srcId = srcId; + this.targetId = targetId; + this.encoder = encoder; + this.type = type; + } + + public IPartitioner getPartitioner() { + return partitioner; + } + + public int getEdgeId() { + return edgeId; + } + + public void setEdgeId(int edgeId) { + this.edgeId = edgeId; + } + + public void setPartitioner(IPartitioner partitioner) { + this.partitioner = partitioner; + } + + public String getEdgeName() { + return edgeName; + } + + public void setEdgeName(String edgeName) { + this.edgeName = edgeName; + } + + public int getSrcId() { + return srcId; + } + + public void setSrcId(int srcId) { + this.srcId = srcId; + } + + public int getTargetId() { + return targetId; + } + + public void setTargetId(int targetId) { + this.targetId = targetId; + } + + public IEncoder getEncoder() { + return this.encoder; + } + + public void setEncoder(IEncoder encoder) { + this.encoder = encoder; + } + + public OutputType getType() { + return type; + } + + public void setType(OutputType type) { + this.type = type; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionGraph.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionGraph.java index ba165cdea..bccfa86bf 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionGraph.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionGraph.java @@ -26,77 +26,85 @@ public class ExecutionGraph implements Serializable { - // A execution graph contains one or many vertexGroup. - private Map vertexGroupMap; - - // It specifies the edge map of vertex groups. - private Map groupEdgeMap; - - // >, it's used to be describe the relation with out vertex group. - private Map> vertexGroupOutEdgeIds; - - // >, it's used to be describe the relation with in vertex group. - private Map> vertexGroupInEdgeIds; - - private CycleGroupMeta cycleGroupMeta; - - public ExecutionGraph() { - this.vertexGroupMap = new HashMap<>(); - this.groupEdgeMap = new HashMap<>(); - this.vertexGroupOutEdgeIds = new HashMap<>(); - this.vertexGroupInEdgeIds = new HashMap<>(); - this.cycleGroupMeta = new CycleGroupMeta(); - } - - public Map getVertexGroupMap() { - return vertexGroupMap; - } - - public void setVertexGroupMap(Map vertexGroupMap) { - this.vertexGroupMap = vertexGroupMap; - } - - public Map getGroupEdgeMap() { - return groupEdgeMap; - } - - public void setGroupEdgeMap(Map groupEdgeMap) { - this.groupEdgeMap = groupEdgeMap; - } - - public Map> getVertexGroupOutEdgeIds() { - return vertexGroupOutEdgeIds; - } - - public void setVertexGroupOutEdgeIds(Map> vertexGroupOutEdgeIds) { - this.vertexGroupOutEdgeIds = vertexGroupOutEdgeIds; - } - - public Map> getVertexGroupInEdgeIds() { - return vertexGroupInEdgeIds; - } - - public void setVertexGroupInEdgeIds(Map> vertexGroupInEdgeIds) { - this.vertexGroupInEdgeIds = vertexGroupInEdgeIds; - } - - public void putVertexGroupOutEdgeIds(int groupId, List outEdgeIds) { - this.vertexGroupOutEdgeIds.put(groupId, outEdgeIds); - } - - public void putVertexGroupInEdgeIds(int groupId, List inEdgeIds) { - this.vertexGroupInEdgeIds.put(groupId, inEdgeIds); - } - - public CycleGroupMeta getCycleGroupMeta() { - return cycleGroupMeta; - } - - @Override - public String toString() { - return "ExecutionGraph{" + "vertexGroupMap=" + vertexGroupMap - + ",\n vertexGroupOutEdgeIds=" + vertexGroupOutEdgeIds - + ",\n vertexGroupInEdgeIds=" + vertexGroupInEdgeIds - + ",\n cycleGroupMeta=" + cycleGroupMeta + "\n}"; - } + // A execution graph contains one or many vertexGroup. + private Map vertexGroupMap; + + // It specifies the edge map of vertex groups. + private Map groupEdgeMap; + + // >, it's used to be describe the relation with out vertex + // group. + private Map> vertexGroupOutEdgeIds; + + // >, it's used to be describe the relation with in vertex + // group. + private Map> vertexGroupInEdgeIds; + + private CycleGroupMeta cycleGroupMeta; + + public ExecutionGraph() { + this.vertexGroupMap = new HashMap<>(); + this.groupEdgeMap = new HashMap<>(); + this.vertexGroupOutEdgeIds = new HashMap<>(); + this.vertexGroupInEdgeIds = new HashMap<>(); + this.cycleGroupMeta = new CycleGroupMeta(); + } + + public Map getVertexGroupMap() { + return vertexGroupMap; + } + + public void setVertexGroupMap(Map vertexGroupMap) { + this.vertexGroupMap = vertexGroupMap; + } + + public Map getGroupEdgeMap() { + return groupEdgeMap; + } + + public void setGroupEdgeMap(Map groupEdgeMap) { + this.groupEdgeMap = groupEdgeMap; + } + + public Map> getVertexGroupOutEdgeIds() { + return vertexGroupOutEdgeIds; + } + + public void setVertexGroupOutEdgeIds(Map> vertexGroupOutEdgeIds) { + this.vertexGroupOutEdgeIds = vertexGroupOutEdgeIds; + } + + public Map> getVertexGroupInEdgeIds() { + return vertexGroupInEdgeIds; + } + + public void setVertexGroupInEdgeIds(Map> vertexGroupInEdgeIds) { + this.vertexGroupInEdgeIds = vertexGroupInEdgeIds; + } + + public void putVertexGroupOutEdgeIds(int groupId, List outEdgeIds) { + this.vertexGroupOutEdgeIds.put(groupId, outEdgeIds); + } + + public void putVertexGroupInEdgeIds(int groupId, List inEdgeIds) { + this.vertexGroupInEdgeIds.put(groupId, inEdgeIds); + } + + public CycleGroupMeta getCycleGroupMeta() { + return cycleGroupMeta; + } + + @Override + public String toString() { + return "ExecutionGraph{" + + "vertexGroupMap=" + + vertexGroupMap + + ",\n vertexGroupOutEdgeIds=" + + vertexGroupOutEdgeIds + + ",\n vertexGroupInEdgeIds=" + + vertexGroupInEdgeIds + + ",\n cycleGroupMeta=" + + cycleGroupMeta + + "\n}"; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionTask.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionTask.java index 4273d2149..00dacdf93 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionTask.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionTask.java @@ -20,6 +20,7 @@ package org.apache.geaflow.core.graph; import java.io.Serializable; + import org.apache.geaflow.cluster.resourcemanager.WorkerInfo; import org.apache.geaflow.common.task.TaskArgs; import org.apache.geaflow.common.utils.LoggerFormatter; @@ -27,145 +28,161 @@ public class ExecutionTask implements Serializable { - private int taskId; - private int index; - private int vertexId; - private int parallelism; - private int maxParallelism; - private int numPartitions; - private WorkerInfo workerInfo; - private Processor processor; - private ExecutionTaskType executionTaskType; - private boolean iterative; - - private long startTime; - private long duration; - private transient String taskName; - - public ExecutionTask(int taskId, int index, int parallelism, int maxParallelism, int numPartitions, int vertexId) { - this.taskId = taskId; - this.index = index; - this.parallelism = parallelism; - this.maxParallelism = maxParallelism; - this.numPartitions = numPartitions; - this.vertexId = vertexId; - this.workerInfo = null; - this.startTime = -1; - } - - public int getTaskId() { - return taskId; - } - - public void setTaskId(int taskId) { - this.taskId = taskId; - } - - public int getIndex() { - return index; - } - - public void setIndex(int index) { - this.index = index; - } - - public int getParallelism() { - return parallelism; - } - - public void setParallelism(int parallelism) { - this.parallelism = parallelism; - } - - public int getMaxParallelism() { - return maxParallelism; - } - - public int getNumPartitions() { - return numPartitions; - } - - public int getVertexId() { - return vertexId; - } - - public void setVertexId(int vertexId) { - this.vertexId = vertexId; - } - - public WorkerInfo getWorkerInfo() { - return workerInfo; - } - - public void setWorkerInfo(WorkerInfo workerInfo) { - this.workerInfo = workerInfo; - } - - public long getStartTime() { - return startTime; - } - - public void setStartTime(long startTime) { - this.startTime = startTime; - } - - public long getDuration() { - return duration; - } - - public void setDuration(long duration) { - this.duration = duration; - } - - public Processor getProcessor() { - return processor; - } - - public void setProcessor(Processor processor) { - this.processor = processor; - } - - public ExecutionTaskType getExecutionTaskType() { - return executionTaskType; - } - - public void setExecutionTaskType(ExecutionTaskType executionTaskType) { - this.executionTaskType = executionTaskType; - } - - public void setIterative(boolean iterative) { - this.iterative = iterative; - } - - public boolean isIterative() { - return this.iterative; - } - - public String getTaskName() { - return taskName; - } - - public void buildTaskName(String pipelineName, int cycleId, long windowId) { - this.taskName = this.iterative + private int taskId; + private int index; + private int vertexId; + private int parallelism; + private int maxParallelism; + private int numPartitions; + private WorkerInfo workerInfo; + private Processor processor; + private ExecutionTaskType executionTaskType; + private boolean iterative; + + private long startTime; + private long duration; + private transient String taskName; + + public ExecutionTask( + int taskId, int index, int parallelism, int maxParallelism, int numPartitions, int vertexId) { + this.taskId = taskId; + this.index = index; + this.parallelism = parallelism; + this.maxParallelism = maxParallelism; + this.numPartitions = numPartitions; + this.vertexId = vertexId; + this.workerInfo = null; + this.startTime = -1; + } + + public int getTaskId() { + return taskId; + } + + public void setTaskId(int taskId) { + this.taskId = taskId; + } + + public int getIndex() { + return index; + } + + public void setIndex(int index) { + this.index = index; + } + + public int getParallelism() { + return parallelism; + } + + public void setParallelism(int parallelism) { + this.parallelism = parallelism; + } + + public int getMaxParallelism() { + return maxParallelism; + } + + public int getNumPartitions() { + return numPartitions; + } + + public int getVertexId() { + return vertexId; + } + + public void setVertexId(int vertexId) { + this.vertexId = vertexId; + } + + public WorkerInfo getWorkerInfo() { + return workerInfo; + } + + public void setWorkerInfo(WorkerInfo workerInfo) { + this.workerInfo = workerInfo; + } + + public long getStartTime() { + return startTime; + } + + public void setStartTime(long startTime) { + this.startTime = startTime; + } + + public long getDuration() { + return duration; + } + + public void setDuration(long duration) { + this.duration = duration; + } + + public Processor getProcessor() { + return processor; + } + + public void setProcessor(Processor processor) { + this.processor = processor; + } + + public ExecutionTaskType getExecutionTaskType() { + return executionTaskType; + } + + public void setExecutionTaskType(ExecutionTaskType executionTaskType) { + this.executionTaskType = executionTaskType; + } + + public void setIterative(boolean iterative) { + this.iterative = iterative; + } + + public boolean isIterative() { + return this.iterative; + } + + public String getTaskName() { + return taskName; + } + + public void buildTaskName(String pipelineName, int cycleId, long windowId) { + this.taskName = + this.iterative ? LoggerFormatter.getTaskTag( - pipelineName, cycleId, windowId, this.taskId, this.vertexId, this.index, this.parallelism) + pipelineName, + cycleId, + windowId, + this.taskId, + this.vertexId, + this.index, + this.parallelism) : LoggerFormatter.getTaskTag( - pipelineName, cycleId, this.taskId, this.vertexId, this.index, this.parallelism); - } - - public TaskArgs buildTaskArgs() { - return new TaskArgs( - this.taskId, - this.index, - this.taskName, - this.parallelism, - this.maxParallelism, - this.workerInfo.getProcessIndex()); - } - - @Override - public String toString() { - return "ExecutionTask{" + "taskId=" + taskId + ", index=" + index + ", vertexId=" - + vertexId + ", worker=" + (workerInfo == null ? "NULL" : workerInfo) + '}'; - } + pipelineName, cycleId, this.taskId, this.vertexId, this.index, this.parallelism); + } + + public TaskArgs buildTaskArgs() { + return new TaskArgs( + this.taskId, + this.index, + this.taskName, + this.parallelism, + this.maxParallelism, + this.workerInfo.getProcessIndex()); + } + + @Override + public String toString() { + return "ExecutionTask{" + + "taskId=" + + taskId + + ", index=" + + index + + ", vertexId=" + + vertexId + + ", worker=" + + (workerInfo == null ? "NULL" : workerInfo) + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionTaskType.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionTaskType.java index 4c08255f1..feaf15741 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionTaskType.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionTaskType.java @@ -20,27 +20,26 @@ package org.apache.geaflow.core.graph; public enum ExecutionTaskType { - /** - * A head execution task is the start of cycle pipeline. - * that receive event from scheduler to trigger a certain round of iteration. - */ - head, + /** + * A head execution task is the start of cycle pipeline. that receive event from scheduler to + * trigger a certain round of iteration. + */ + head, - /** - * A middle execution task is the intermediate of cycle pipeline. - */ - middle, + /** A middle execution task is the intermediate of cycle pipeline. */ + middle, - /** - * A tail execution task is the end of cycle pipeline. - * that send event to scheduler to finish a certain round of iteration. - */ - tail, + /** + * A tail execution task is the end of cycle pipeline. that send event to scheduler to finish a + * certain round of iteration. + */ + tail, - /** - * A singularity(start&end both) execution task is the start of cycle pipeline, also is the end of cycle pipeline at the same time. - * Thus that receive event from scheduler to trigger a certain round of iteration, - * and send event to scheduler to finish a certain round of iteration meanwhile. - */ - singularity + /** + * A singularity(start&end both) execution task is the start of cycle pipeline, also is the end of + * cycle pipeline at the same time. Thus that receive event from scheduler to trigger a certain + * round of iteration, and send event to scheduler to finish a certain round of iteration + * meanwhile. + */ + singularity } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionVertex.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionVertex.java index 1a68879dc..82668a308 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionVertex.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionVertex.java @@ -22,6 +22,7 @@ import java.io.Serializable; import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.partitioner.IPartitioner; import org.apache.geaflow.plan.graph.AffinityLevel; import org.apache.geaflow.plan.graph.VertexType; @@ -31,153 +32,162 @@ public class ExecutionVertex implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(ExecutionVertex.class); - - private int vertexId; - private Processor processor; - private String name; - private int parallelism; - private int maxParallelism; - private List parentVertexIds; - private List childrenVertexIds; - private int numPartitions; - private VertexType vertexType; - private AffinityLevel affinityLevel; - private VertexType chainTailType; - - private List inputEdges; - private List outputEdges; - - public ExecutionVertex(int vertexId, String name) { - this.vertexId = vertexId; - this.name = name; - this.parentVertexIds = new ArrayList<>(); - this.childrenVertexIds = new ArrayList<>(); - this.inputEdges = null; - this.outputEdges = null; - this.affinityLevel = AffinityLevel.worker; - } - - public int getVertexId() { - return vertexId; - } - - public void setVertexId(int vertexId) { - this.vertexId = vertexId; - } - - public Processor getProcessor() { - return processor; - } - - public void setProcessor(Processor processor) { - this.processor = processor; - } - - public String getName() { - return name; - } - - public void setName(String name) { - this.name = name; - } - - public int getParallelism() { - return parallelism; - } - - public void setParallelism(int parallelism) { - this.parallelism = parallelism; - } - - public int getMaxParallelism() { - return maxParallelism; - } - - public void setMaxParallelism(int maxParallelism) { - this.maxParallelism = maxParallelism; - } - - public List getParentVertexIds() { - return parentVertexIds; - } - - public void setParentVertexIds(List parentVertexIds) { - this.parentVertexIds = parentVertexIds; - } - - public List getChildrenVertexIds() { - return childrenVertexIds; - } - - public void setChildrenVertexIds(List childrenVertexIds) { - this.childrenVertexIds = childrenVertexIds; - } - - public int getNumPartitions() { - return numPartitions; - } - - public void setNumPartitions(int numPartitions) { - this.numPartitions = numPartitions; - } - - public List getInputEdges() { - return inputEdges; - } - - public void setInputEdges(List inputEdges) { - this.inputEdges = inputEdges; - } - - public List getOutputEdges() { - return outputEdges; - } - - public void setOutputEdges(List outputEdges) { - this.outputEdges = outputEdges; - } - - public VertexType getVertexType() { - return vertexType; - } - - public void setVertexType(VertexType vertexType) { - this.vertexType = vertexType; - } - - public AffinityLevel getAffinityLevel() { - return affinityLevel; - } - - public void setAffinityLevel(AffinityLevel affinityLevel) { - this.affinityLevel = affinityLevel; - } - - public VertexType getChainTailType() { - return chainTailType; - } - - public void setChainTailType(VertexType chainTailType) { - this.chainTailType = chainTailType; - } - - public boolean isRepartition() { - if (outputEdges != null) { - boolean isAllForward = this.outputEdges.stream().allMatch( - x -> x.getPartitioner().getPartitionType() == IPartitioner.PartitionType.forward); - if (isAllForward) { - return false; - } else { - return true; - } - } + private static final Logger LOGGER = LoggerFactory.getLogger(ExecutionVertex.class); + + private int vertexId; + private Processor processor; + private String name; + private int parallelism; + private int maxParallelism; + private List parentVertexIds; + private List childrenVertexIds; + private int numPartitions; + private VertexType vertexType; + private AffinityLevel affinityLevel; + private VertexType chainTailType; + + private List inputEdges; + private List outputEdges; + + public ExecutionVertex(int vertexId, String name) { + this.vertexId = vertexId; + this.name = name; + this.parentVertexIds = new ArrayList<>(); + this.childrenVertexIds = new ArrayList<>(); + this.inputEdges = null; + this.outputEdges = null; + this.affinityLevel = AffinityLevel.worker; + } + + public int getVertexId() { + return vertexId; + } + + public void setVertexId(int vertexId) { + this.vertexId = vertexId; + } + + public Processor getProcessor() { + return processor; + } + + public void setProcessor(Processor processor) { + this.processor = processor; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public int getParallelism() { + return parallelism; + } + + public void setParallelism(int parallelism) { + this.parallelism = parallelism; + } + + public int getMaxParallelism() { + return maxParallelism; + } + + public void setMaxParallelism(int maxParallelism) { + this.maxParallelism = maxParallelism; + } + + public List getParentVertexIds() { + return parentVertexIds; + } + + public void setParentVertexIds(List parentVertexIds) { + this.parentVertexIds = parentVertexIds; + } + + public List getChildrenVertexIds() { + return childrenVertexIds; + } + + public void setChildrenVertexIds(List childrenVertexIds) { + this.childrenVertexIds = childrenVertexIds; + } + + public int getNumPartitions() { + return numPartitions; + } + + public void setNumPartitions(int numPartitions) { + this.numPartitions = numPartitions; + } + + public List getInputEdges() { + return inputEdges; + } + + public void setInputEdges(List inputEdges) { + this.inputEdges = inputEdges; + } + + public List getOutputEdges() { + return outputEdges; + } + + public void setOutputEdges(List outputEdges) { + this.outputEdges = outputEdges; + } + + public VertexType getVertexType() { + return vertexType; + } + + public void setVertexType(VertexType vertexType) { + this.vertexType = vertexType; + } + + public AffinityLevel getAffinityLevel() { + return affinityLevel; + } + + public void setAffinityLevel(AffinityLevel affinityLevel) { + this.affinityLevel = affinityLevel; + } + + public VertexType getChainTailType() { + return chainTailType; + } + + public void setChainTailType(VertexType chainTailType) { + this.chainTailType = chainTailType; + } + + public boolean isRepartition() { + if (outputEdges != null) { + boolean isAllForward = + this.outputEdges.stream() + .allMatch( + x -> x.getPartitioner().getPartitionType() == IPartitioner.PartitionType.forward); + if (isAllForward) { + return false; + } else { return true; - } - - @Override - public String toString() { - return "ExecutionVertex{" + "vertexId=" + vertexId + ", processor=" + processor + ", name='" - + name + '\'' + '}'; - } + } + } + return true; + } + + @Override + public String toString() { + return "ExecutionVertex{" + + "vertexId=" + + vertexId + + ", processor=" + + processor + + ", name='" + + name + + '\'' + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionVertexGroup.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionVertexGroup.java index f1831de2c..84aebf4c9 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionVertexGroup.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionVertexGroup.java @@ -27,143 +27,148 @@ public class ExecutionVertexGroup implements Serializable { - private int groupId; - - // A execution vertex group contains one or many vertex. - private Map vertexMap; - - // Keep the mapping that edge id with edge. - private Map edgeMap; - - // Keep the mapping that edge id with iteration edge. - private Map iterationEdgeMap; - - // Keep the parent vertex group id list of current vertex group. - private List parentVertexGroupIds; - - // Keep the children vertex group id list of current vertex group. - private List childrenVertexGroupIds; - - // >, it's used to be describe the relation with out vertex group. - private Map> vertexId2OutEdgeIds; - - // >, it's used to be describe the relation with in vertex group. - private Map> vertexId2InEdgeIds; - - private CycleGroupMeta cycleGroupMeta; - - public ExecutionVertexGroup(int groupId) { - this.groupId = groupId; - this.vertexMap = new HashMap<>(); - this.edgeMap = new HashMap<>(); - this.iterationEdgeMap = new HashMap<>(); - this.parentVertexGroupIds = new ArrayList<>(); - this.childrenVertexGroupIds = new ArrayList<>(); - this.vertexId2OutEdgeIds = new HashMap<>(); - this.vertexId2InEdgeIds = new HashMap<>(); - this.cycleGroupMeta = new CycleGroupMeta(); - } - - public int getGroupId() { - return groupId; - } - - public Map getVertexMap() { - return vertexMap; - } - - public Map getEdgeMap() { - return edgeMap; - } - - public Map getIterationEdgeMap() { - return this.iterationEdgeMap; - } - - public Map> getVertexId2OutEdgeIds() { - return vertexId2OutEdgeIds; - } - - public Map> getVertexId2InEdgeIds() { - return vertexId2InEdgeIds; - } - - public void putVertexId2OutEdgeIds(int vertexId, List outEdgeIds) { - this.vertexId2OutEdgeIds.put(vertexId, outEdgeIds); - } - - public void putVertexId2InEdgeIds(int vertexId, List inEdgeIds) { - this.vertexId2InEdgeIds.put(vertexId, inEdgeIds); - } - - public List getParentVertexGroupIds() { - return parentVertexGroupIds; - } - - public List getChildrenVertexGroupIds() { - return childrenVertexGroupIds; - } - - public CycleGroupMeta getCycleGroupMeta() { - return cycleGroupMeta; - } - - /** - * Get the tail vertex ids in the current vertex group. - */ - public List getTailVertexIds() { - List tailVertexIds = new ArrayList<>(); - - vertexMap.keySet().stream().forEach(vertexId -> { - List outputVertexIds = new ArrayList<>(); - vertexId2OutEdgeIds.get(vertexId).stream().forEach(edgeId -> { - if (edgeMap.get(edgeId) != null) { - outputVertexIds.add(edgeMap.get(edgeId).getTargetId()); - } - }); - if (outputVertexIds.size() == 0) { + private int groupId; + + // A execution vertex group contains one or many vertex. + private Map vertexMap; + + // Keep the mapping that edge id with edge. + private Map edgeMap; + + // Keep the mapping that edge id with iteration edge. + private Map iterationEdgeMap; + + // Keep the parent vertex group id list of current vertex group. + private List parentVertexGroupIds; + + // Keep the children vertex group id list of current vertex group. + private List childrenVertexGroupIds; + + // >, it's used to be describe the relation with out + // vertex group. + private Map> vertexId2OutEdgeIds; + + // >, it's used to be describe the relation with in vertex + // group. + private Map> vertexId2InEdgeIds; + + private CycleGroupMeta cycleGroupMeta; + + public ExecutionVertexGroup(int groupId) { + this.groupId = groupId; + this.vertexMap = new HashMap<>(); + this.edgeMap = new HashMap<>(); + this.iterationEdgeMap = new HashMap<>(); + this.parentVertexGroupIds = new ArrayList<>(); + this.childrenVertexGroupIds = new ArrayList<>(); + this.vertexId2OutEdgeIds = new HashMap<>(); + this.vertexId2InEdgeIds = new HashMap<>(); + this.cycleGroupMeta = new CycleGroupMeta(); + } + + public int getGroupId() { + return groupId; + } + + public Map getVertexMap() { + return vertexMap; + } + + public Map getEdgeMap() { + return edgeMap; + } + + public Map getIterationEdgeMap() { + return this.iterationEdgeMap; + } + + public Map> getVertexId2OutEdgeIds() { + return vertexId2OutEdgeIds; + } + + public Map> getVertexId2InEdgeIds() { + return vertexId2InEdgeIds; + } + + public void putVertexId2OutEdgeIds(int vertexId, List outEdgeIds) { + this.vertexId2OutEdgeIds.put(vertexId, outEdgeIds); + } + + public void putVertexId2InEdgeIds(int vertexId, List inEdgeIds) { + this.vertexId2InEdgeIds.put(vertexId, inEdgeIds); + } + + public List getParentVertexGroupIds() { + return parentVertexGroupIds; + } + + public List getChildrenVertexGroupIds() { + return childrenVertexGroupIds; + } + + public CycleGroupMeta getCycleGroupMeta() { + return cycleGroupMeta; + } + + /** Get the tail vertex ids in the current vertex group. */ + public List getTailVertexIds() { + List tailVertexIds = new ArrayList<>(); + + vertexMap.keySet().stream() + .forEach( + vertexId -> { + List outputVertexIds = new ArrayList<>(); + vertexId2OutEdgeIds.get(vertexId).stream() + .forEach( + edgeId -> { + if (edgeMap.get(edgeId) != null) { + outputVertexIds.add(edgeMap.get(edgeId).getTargetId()); + } + }); + if (outputVertexIds.size() == 0) { tailVertexIds.add(vertexId); - } else { + } else { boolean flag = false; for (int outputVertexId : outputVertexIds) { - if (!vertexMap.containsKey(outputVertexId)) { - flag = true; - break; - } + if (!vertexMap.containsKey(outputVertexId)) { + flag = true; + break; + } } if (flag) { - tailVertexIds.add(vertexId); + tailVertexIds.add(vertexId); } - } - }); - - return tailVertexIds; - } - - /** - * Get the head vertex ids in the current vertex group. - */ - public List getHeadVertexIds() { - List headVertexIds = new ArrayList<>(); - - vertexMap.keySet().stream().forEach(vertexId -> { - // 1. Vertex input edge is empty OR - // 2. Vertex input edge's source vertex not belong to current group - boolean hasNoInputEdge = vertexId2InEdgeIds.get(vertexId) == null || vertexId2InEdgeIds.get(vertexId).isEmpty(); - boolean isVertexSrcNotInGroup = vertexId2InEdgeIds.get(vertexId).stream() - .anyMatch(edgeId -> !vertexMap.containsKey(edgeMap.get(edgeId).getSrcId())); - if (hasNoInputEdge || isVertexSrcNotInGroup) { + } + }); + + return tailVertexIds; + } + + /** Get the head vertex ids in the current vertex group. */ + public List getHeadVertexIds() { + List headVertexIds = new ArrayList<>(); + + vertexMap.keySet().stream() + .forEach( + vertexId -> { + // 1. Vertex input edge is empty OR + // 2. Vertex input edge's source vertex not belong to current group + boolean hasNoInputEdge = + vertexId2InEdgeIds.get(vertexId) == null + || vertexId2InEdgeIds.get(vertexId).isEmpty(); + boolean isVertexSrcNotInGroup = + vertexId2InEdgeIds.get(vertexId).stream() + .anyMatch(edgeId -> !vertexMap.containsKey(edgeMap.get(edgeId).getSrcId())); + if (hasNoInputEdge || isVertexSrcNotInGroup) { headVertexIds.add(vertexId); - } - }); - - return headVertexIds; - } - - @Override - public String toString() { - return "ExecutionVertexGroup{" + vertexMap - + ",\n cycleGroupMeta=" + cycleGroupMeta - + "\n}"; - } + } + }); + + return headVertexIds; + } + + @Override + public String toString() { + return "ExecutionVertexGroup{" + vertexMap + ",\n cycleGroupMeta=" + cycleGroupMeta + "\n}"; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionVertexGroupEdge.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionVertexGroupEdge.java index 80c0c18b5..f4c1b7b47 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionVertexGroupEdge.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/ExecutionVertexGroupEdge.java @@ -20,62 +20,67 @@ package org.apache.geaflow.core.graph; import java.io.Serializable; + import org.apache.geaflow.partitioner.IPartitioner; public class ExecutionVertexGroupEdge implements Serializable { - private int groupEdgeId; - private int groupSrcId; - private int groupTargetId; - private IPartitioner partitioner; - private String groupEdgeName; - - public ExecutionVertexGroupEdge(IPartitioner partitioner, int groupEdgeId, - String groupEdgeName, int groupSrcId, int groupTargetId) { - this.partitioner = partitioner; - this.groupEdgeId = groupEdgeId; - this.groupEdgeName = groupEdgeName; - this.groupSrcId = groupSrcId; - this.groupTargetId = groupTargetId; - } - - public int getGroupEdgeId() { - return groupEdgeId; - } - - public void setGroupEdgeId(int groupEdgeId) { - this.groupEdgeId = groupEdgeId; - } - - public int getGroupSrcId() { - return groupSrcId; - } - - public void setGroupSrcId(int groupSrcId) { - this.groupSrcId = groupSrcId; - } - - public int getGroupTargetId() { - return groupTargetId; - } - - public void setGroupTargetId(int groupTargetId) { - this.groupTargetId = groupTargetId; - } - - public IPartitioner getPartitioner() { - return partitioner; - } - - public void setPartitioner(IPartitioner partitioner) { - this.partitioner = partitioner; - } - - public String getGroupEdgeName() { - return groupEdgeName; - } - - public void setGroupEdgeName(String groupEdgeName) { - this.groupEdgeName = groupEdgeName; - } + private int groupEdgeId; + private int groupSrcId; + private int groupTargetId; + private IPartitioner partitioner; + private String groupEdgeName; + + public ExecutionVertexGroupEdge( + IPartitioner partitioner, + int groupEdgeId, + String groupEdgeName, + int groupSrcId, + int groupTargetId) { + this.partitioner = partitioner; + this.groupEdgeId = groupEdgeId; + this.groupEdgeName = groupEdgeName; + this.groupSrcId = groupSrcId; + this.groupTargetId = groupTargetId; + } + + public int getGroupEdgeId() { + return groupEdgeId; + } + + public void setGroupEdgeId(int groupEdgeId) { + this.groupEdgeId = groupEdgeId; + } + + public int getGroupSrcId() { + return groupSrcId; + } + + public void setGroupSrcId(int groupSrcId) { + this.groupSrcId = groupSrcId; + } + + public int getGroupTargetId() { + return groupTargetId; + } + + public void setGroupTargetId(int groupTargetId) { + this.groupTargetId = groupTargetId; + } + + public IPartitioner getPartitioner() { + return partitioner; + } + + public void setPartitioner(IPartitioner partitioner) { + this.partitioner = partitioner; + } + + public String getGroupEdgeName() { + return groupEdgeName; + } + + public void setGroupEdgeName(String groupEdgeName) { + this.groupEdgeName = groupEdgeName; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/IteratorExecutionVertex.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/IteratorExecutionVertex.java index 570be77ae..96faf5b8b 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/IteratorExecutionVertex.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/IteratorExecutionVertex.java @@ -21,14 +21,14 @@ public class IteratorExecutionVertex extends ExecutionVertex { - private long iteratorCount; + private long iteratorCount; - public IteratorExecutionVertex(int vertexId, String name, long iteratorCount) { - super(vertexId, name); - this.iteratorCount = iteratorCount; - } + public IteratorExecutionVertex(int vertexId, String name, long iteratorCount) { + super(vertexId, name); + this.iteratorCount = iteratorCount; + } - public long getIteratorCount() { - return iteratorCount; - } + public long getIteratorCount() { + return iteratorCount; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/builder/ExecutionGraphBuilder.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/builder/ExecutionGraphBuilder.java index 964b38366..b9e03590e 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/builder/ExecutionGraphBuilder.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/builder/ExecutionGraphBuilder.java @@ -21,7 +21,6 @@ import static org.apache.geaflow.plan.PipelinePlanBuilder.ITERATION_AGG_VERTEX_ID; -import com.google.common.base.Preconditions; import java.io.Serializable; import java.util.ArrayList; import java.util.Comparator; @@ -33,6 +32,7 @@ import java.util.Queue; import java.util.Set; import java.util.stream.Collectors; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.FrameworkConfigKeys; import org.apache.geaflow.common.errorcode.RuntimeErrors; @@ -67,681 +67,767 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * This class is for building execution graph. - */ -public class ExecutionGraphBuilder implements Serializable { - - private static final Logger LOGGER = LoggerFactory.getLogger(ExecutionGraphBuilder.class); - - private static final int START_GROUP_ID = 1; - - private PipelineGraph plan; - private int flyingCount; - - public ExecutionGraphBuilder(PipelineGraph plan) { - this.plan = plan; - } +import com.google.common.base.Preconditions; - public ExecutionGraph buildExecutionGraph(Configuration jobConf) { - ExecutionGraph executionGraph = new ExecutionGraph(); - - Map id2Vertexes = new HashMap<>(); - - // Build execution vertex group. - List pipelineVertexList = plan.getSourceVertices(); - Queue pipelineVertexQueue = new LinkedList<>(pipelineVertexList); - Map vertexId2GroupIdMap = new HashMap<>(); - - Map vertexGroupMap = buildExecutionVertexGroup( - vertexId2GroupIdMap, pipelineVertexQueue); - executionGraph.setVertexGroupMap(vertexGroupMap); - - vertexGroupMap.values().stream().forEach(vertexGroup -> vertexGroup.getVertexMap().values().stream() - .forEach(vertex -> id2Vertexes.put(vertex.getVertexId(), vertex))); - - // Rebuild the input and output vertex group for the group, and execution vertexes for the vertex at the same time. - Map groupEdgeMap = new HashMap<>(); - for (Map.Entry vertexGroupEntry : vertexGroupMap.entrySet()) { - ExecutionVertexGroup vertexGroup = vertexGroupEntry.getValue(); - // 1. Build the input and output execution vertexes for the vertex. - for (Map.Entry entry : vertexGroupEntry.getValue().getVertexMap().entrySet()) { - int vertexId = entry.getKey(); - Set vertexOutEdges = plan.getVertexOutEdges(vertexId); - Map outEdgeMap = new HashMap<>(); - Map inEdgeMap = new HashMap<>(); - - for (PipelineEdge pipelineEdge : vertexOutEdges) { - ExecutionEdge executionEdge = buildEdge(pipelineEdge); - outEdgeMap.put(pipelineEdge.getEdgeId(), executionEdge); - } - vertexGroup.getEdgeMap().putAll(outEdgeMap); - vertexGroup.putVertexId2OutEdgeIds(vertexId, plan.getVertexOutEdges(vertexId) - .stream().map(PipelineEdge::getEdgeId).collect(Collectors.toList())); - - Set vertexInEdges = plan.getVertexInputEdges(vertexId); - for (PipelineEdge pipelineEdge : vertexInEdges) { - ExecutionEdge executionEdge = buildEdge(pipelineEdge); - inEdgeMap.put(pipelineEdge.getEdgeId(), executionEdge); - } - vertexGroup.getEdgeMap().putAll(inEdgeMap); - vertexGroup.putVertexId2InEdgeIds(vertexId, plan.getVertexInputEdges(vertexId) - .stream().map(PipelineEdge::getEdgeId).collect(Collectors.toList())); - - entry.getValue().setInputEdges(inEdgeMap.values().stream().collect(Collectors.toList())); - entry.getValue().setOutputEdges(outEdgeMap.values().stream().collect(Collectors.toList())); - } +/** This class is for building execution graph. */ +public class ExecutionGraphBuilder implements Serializable { - // 2. Build the input and output vertex group for the group. - int groupId = vertexGroupEntry.getKey(); - List outGroupEdgeIds = new ArrayList<>(); - List inGroupEdgeIds = new ArrayList<>(); - - List tailVertexIds = vertexGroup.getTailVertexIds(); - List headVertexIds = vertexGroup.getHeadVertexIds(); - - headVertexIds.stream().forEach(headVertexId -> { - inGroupEdgeIds.addAll(plan.getVertexInputEdges(headVertexId).stream() - // Exclude self-loop edge. - .filter(e -> !vertexGroup.getVertexMap().containsKey(e.getSrcId())) - .map(PipelineEdge::getEdgeId).collect(Collectors.toList())); - vertexGroup.getParentVertexGroupIds().addAll(plan.getVertexInputVertexIds(headVertexId).stream() - .filter(vertexId -> !vertexGroup.getVertexMap().containsKey(vertexId)) - .map(vertexId -> vertexId2GroupIdMap.get(vertexId)).collect(Collectors.toList())); - }); - - tailVertexIds.stream().forEach(tailVertexId -> { - outGroupEdgeIds.addAll(plan.getVertexOutEdges(tailVertexId).stream() - .filter(e -> vertexGroup.getVertexMap().containsKey(e.getTargetId())) - .map(PipelineEdge::getEdgeId).collect(Collectors.toList())); - vertexGroup.getChildrenVertexGroupIds().addAll(plan.getVertexOutputVertexIds(tailVertexId).stream() - .filter(vertexId -> !vertexGroup.getVertexMap().containsKey(vertexId)) - .map(vertexId -> vertexId2GroupIdMap.get(vertexId)).collect(Collectors.toList())); - }); - - executionGraph.putVertexGroupInEdgeIds(groupId, inGroupEdgeIds); - executionGraph.putVertexGroupOutEdgeIds(groupId, outGroupEdgeIds); - - for (int tailVertexId : tailVertexIds) { - Set pipelineEdgeSet = plan.getVertexOutEdges(tailVertexId); - pipelineEdgeSet.stream().forEach(pipelineEdge -> groupEdgeMap.put(pipelineEdge.getEdgeId(), new ExecutionVertexGroupEdge( - pipelineEdge.getPartition(), pipelineEdge.getEdgeId(), pipelineEdge.getEdgeName(), - vertexId2GroupIdMap.get(tailVertexId), vertexId2GroupIdMap.get(pipelineEdge.getTargetId())))); - } + private static final Logger LOGGER = LoggerFactory.getLogger(ExecutionGraphBuilder.class); + + private static final int START_GROUP_ID = 1; + + private PipelineGraph plan; + private int flyingCount; + + public ExecutionGraphBuilder(PipelineGraph plan) { + this.plan = plan; + } + + public ExecutionGraph buildExecutionGraph(Configuration jobConf) { + ExecutionGraph executionGraph = new ExecutionGraph(); + + Map id2Vertexes = new HashMap<>(); + + // Build execution vertex group. + List pipelineVertexList = plan.getSourceVertices(); + Queue pipelineVertexQueue = new LinkedList<>(pipelineVertexList); + Map vertexId2GroupIdMap = new HashMap<>(); + + Map vertexGroupMap = + buildExecutionVertexGroup(vertexId2GroupIdMap, pipelineVertexQueue); + executionGraph.setVertexGroupMap(vertexGroupMap); + + vertexGroupMap.values().stream() + .forEach( + vertexGroup -> + vertexGroup.getVertexMap().values().stream() + .forEach(vertex -> id2Vertexes.put(vertex.getVertexId(), vertex))); + + // Rebuild the input and output vertex group for the group, and execution vertexes for the + // vertex at the same time. + Map groupEdgeMap = new HashMap<>(); + for (Map.Entry vertexGroupEntry : vertexGroupMap.entrySet()) { + ExecutionVertexGroup vertexGroup = vertexGroupEntry.getValue(); + // 1. Build the input and output execution vertexes for the vertex. + for (Map.Entry entry : + vertexGroupEntry.getValue().getVertexMap().entrySet()) { + int vertexId = entry.getKey(); + Set vertexOutEdges = plan.getVertexOutEdges(vertexId); + Map outEdgeMap = new HashMap<>(); + Map inEdgeMap = new HashMap<>(); + + for (PipelineEdge pipelineEdge : vertexOutEdges) { + ExecutionEdge executionEdge = buildEdge(pipelineEdge); + outEdgeMap.put(pipelineEdge.getEdgeId(), executionEdge); } - executionGraph.setGroupEdgeMap(groupEdgeMap); - - flyingCount = jobConf.getInteger(FrameworkConfigKeys.STREAMING_FLYING_BATCH_NUM); - buildCycleGroupMeta(executionGraph); - - ExecutionGraphVisualization graphVisualization = new ExecutionGraphVisualization(executionGraph); - LOGGER.info("execution graph: {}, \nvertex group size: {}, group edge size: {}", - graphVisualization.getExecutionGraphViz(), executionGraph.getVertexGroupMap().size(), executionGraph.getGroupEdgeMap().size()); - - return executionGraph; - } - - /** - * Build execution vertex group. - */ - private Map buildExecutionVertexGroup(Map vertexId2GroupIdMap, - Queue pipelineVertexQueue) { - Map vertexGroupMap = new HashMap<>(); - int groupId = START_GROUP_ID; - - Set groupedVertices = new HashSet<>(); - - while (!pipelineVertexQueue.isEmpty()) { - PipelineVertex pipelineVertex = pipelineVertexQueue.poll(); - // Ignore already grouped vertex. - if (groupedVertices.contains(pipelineVertex.getVertexId())) { - continue; - } - Map currentVertexGroupMap = - group(pipelineVertex, pipelineVertexQueue, groupedVertices); - ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(groupId); - vertexGroupMap.put(groupId, vertexGroup); - vertexGroup.getVertexMap().putAll(currentVertexGroupMap); - for (int id : currentVertexGroupMap.keySet()) { - vertexId2GroupIdMap.put(id, groupId); - } - groupedVertices.addAll(currentVertexGroupMap.keySet()); - groupId++; + vertexGroup.getEdgeMap().putAll(outEdgeMap); + vertexGroup.putVertexId2OutEdgeIds( + vertexId, + plan.getVertexOutEdges(vertexId).stream() + .map(PipelineEdge::getEdgeId) + .collect(Collectors.toList())); + + Set vertexInEdges = plan.getVertexInputEdges(vertexId); + for (PipelineEdge pipelineEdge : vertexInEdges) { + ExecutionEdge executionEdge = buildEdge(pipelineEdge); + inEdgeMap.put(pipelineEdge.getEdgeId(), executionEdge); } - return vertexGroupMap; + vertexGroup.getEdgeMap().putAll(inEdgeMap); + vertexGroup.putVertexId2InEdgeIds( + vertexId, + plan.getVertexInputEdges(vertexId).stream() + .map(PipelineEdge::getEdgeId) + .collect(Collectors.toList())); + + entry.getValue().setInputEdges(inEdgeMap.values().stream().collect(Collectors.toList())); + entry.getValue().setOutputEdges(outEdgeMap.values().stream().collect(Collectors.toList())); + } + + // 2. Build the input and output vertex group for the group. + int groupId = vertexGroupEntry.getKey(); + List outGroupEdgeIds = new ArrayList<>(); + List inGroupEdgeIds = new ArrayList<>(); + + List tailVertexIds = vertexGroup.getTailVertexIds(); + List headVertexIds = vertexGroup.getHeadVertexIds(); + + headVertexIds.stream() + .forEach( + headVertexId -> { + inGroupEdgeIds.addAll( + plan.getVertexInputEdges(headVertexId).stream() + // Exclude self-loop edge. + .filter(e -> !vertexGroup.getVertexMap().containsKey(e.getSrcId())) + .map(PipelineEdge::getEdgeId) + .collect(Collectors.toList())); + vertexGroup + .getParentVertexGroupIds() + .addAll( + plan.getVertexInputVertexIds(headVertexId).stream() + .filter(vertexId -> !vertexGroup.getVertexMap().containsKey(vertexId)) + .map(vertexId -> vertexId2GroupIdMap.get(vertexId)) + .collect(Collectors.toList())); + }); + + tailVertexIds.stream() + .forEach( + tailVertexId -> { + outGroupEdgeIds.addAll( + plan.getVertexOutEdges(tailVertexId).stream() + .filter(e -> vertexGroup.getVertexMap().containsKey(e.getTargetId())) + .map(PipelineEdge::getEdgeId) + .collect(Collectors.toList())); + vertexGroup + .getChildrenVertexGroupIds() + .addAll( + plan.getVertexOutputVertexIds(tailVertexId).stream() + .filter(vertexId -> !vertexGroup.getVertexMap().containsKey(vertexId)) + .map(vertexId -> vertexId2GroupIdMap.get(vertexId)) + .collect(Collectors.toList())); + }); + + executionGraph.putVertexGroupInEdgeIds(groupId, inGroupEdgeIds); + executionGraph.putVertexGroupOutEdgeIds(groupId, outGroupEdgeIds); + + for (int tailVertexId : tailVertexIds) { + Set pipelineEdgeSet = plan.getVertexOutEdges(tailVertexId); + pipelineEdgeSet.stream() + .forEach( + pipelineEdge -> + groupEdgeMap.put( + pipelineEdge.getEdgeId(), + new ExecutionVertexGroupEdge( + pipelineEdge.getPartition(), + pipelineEdge.getEdgeId(), + pipelineEdge.getEdgeName(), + vertexId2GroupIdMap.get(tailVertexId), + vertexId2GroupIdMap.get(pipelineEdge.getTargetId())))); + } } - - /** - * Build execution vertices map for current vertex. - * - * @param vertex Build group vertex map from the vertex. - * @param triggerVertices Vertices that can trigger to build or join an execution group. - * @param globalGroupedVertices Already grouped vertices. - * @return Vertex map that can join together to build an execution group. - */ - private Map group(PipelineVertex vertex, - Queue triggerVertices, - final Set globalGroupedVertices) { - Map currentVertexGroupMap = new HashMap<>(); - Queue currentOutput = new LinkedList<>(); - Set currentVisited = new HashSet<>(); - // If current vertex cannot group, build a vertex group that only include current vertex. - if (!group(vertex, currentVertexGroupMap, currentVisited, currentOutput, globalGroupedVertices)) { - - currentVertexGroupMap.put(vertex.getVertexId(), - buildExecutionVertex(plan.getVertexMap().get(vertex.getVertexId()))); - // Add output vertex to trigger next group. - List outputVertexIds = plan.getVertexOutputVertexIds(vertex.getVertexId()); - for (int id : outputVertexIds) { - if (isReadyToGroup(id, globalGroupedVertices, currentVertexGroupMap.keySet())) { - PipelineVertex outputVertex = plan.getVertexMap().get(id); - triggerVertices.add(outputVertex); - } - } - } else { - // Current group is standalone pipeline which has no output vertices, can join into next group. - while (currentOutput.isEmpty() && !triggerVertices.isEmpty()) { - PipelineVertex nextVertex = triggerVertices.poll(); - if (!group(nextVertex, currentVertexGroupMap, currentVisited, currentOutput, globalGroupedVertices)) { - break; - } - } - triggerVertices.addAll(currentOutput); - } - return currentVertexGroupMap; + executionGraph.setGroupEdgeMap(groupEdgeMap); + + flyingCount = jobConf.getInteger(FrameworkConfigKeys.STREAMING_FLYING_BATCH_NUM); + buildCycleGroupMeta(executionGraph); + + ExecutionGraphVisualization graphVisualization = + new ExecutionGraphVisualization(executionGraph); + LOGGER.info( + "execution graph: {}, \nvertex group size: {}, group edge size: {}", + graphVisualization.getExecutionGraphViz(), + executionGraph.getVertexGroupMap().size(), + executionGraph.getGroupEdgeMap().size()); + + return executionGraph; + } + + /** Build execution vertex group. */ + private Map buildExecutionVertexGroup( + Map vertexId2GroupIdMap, Queue pipelineVertexQueue) { + Map vertexGroupMap = new HashMap<>(); + int groupId = START_GROUP_ID; + + Set groupedVertices = new HashSet<>(); + + while (!pipelineVertexQueue.isEmpty()) { + PipelineVertex pipelineVertex = pipelineVertexQueue.poll(); + // Ignore already grouped vertex. + if (groupedVertices.contains(pipelineVertex.getVertexId())) { + continue; + } + Map currentVertexGroupMap = + group(pipelineVertex, pipelineVertexQueue, groupedVertices); + ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(groupId); + vertexGroupMap.put(groupId, vertexGroup); + vertexGroup.getVertexMap().putAll(currentVertexGroupMap); + for (int id : currentVertexGroupMap.keySet()) { + vertexId2GroupIdMap.put(id, groupId); + } + groupedVertices.addAll(currentVertexGroupMap.keySet()); + groupId++; } - - /** - * Check and build group for input vertex. - */ - private boolean group(PipelineVertex vertex, - Map currentVertexGroupMap, - Set currentVisited, - Queue groupOutputVertices, - final Set globalGroupedVertices) { - - currentVisited.add(vertex.getVertexId()); - if (!canGroup(vertex)) { - return false; - } - // 1. All input must support group into current vertex. - boolean canGroup = pushUpGroup(vertex, currentVertexGroupMap, currentVisited, - groupOutputVertices, globalGroupedVertices); - if (canGroup) { - currentVertexGroupMap.put(vertex.getVertexId(), - buildExecutionVertex(plan.getVertexMap().get(vertex.getVertexId()))); + return vertexGroupMap; + } + + /** + * Build execution vertices map for current vertex. + * + * @param vertex Build group vertex map from the vertex. + * @param triggerVertices Vertices that can trigger to build or join an execution group. + * @param globalGroupedVertices Already grouped vertices. + * @return Vertex map that can join together to build an execution group. + */ + private Map group( + PipelineVertex vertex, + Queue triggerVertices, + final Set globalGroupedVertices) { + Map currentVertexGroupMap = new HashMap<>(); + Queue currentOutput = new LinkedList<>(); + Set currentVisited = new HashSet<>(); + // If current vertex cannot group, build a vertex group that only include current vertex. + if (!group( + vertex, currentVertexGroupMap, currentVisited, currentOutput, globalGroupedVertices)) { + + currentVertexGroupMap.put( + vertex.getVertexId(), + buildExecutionVertex(plan.getVertexMap().get(vertex.getVertexId()))); + // Add output vertex to trigger next group. + List outputVertexIds = plan.getVertexOutputVertexIds(vertex.getVertexId()); + for (int id : outputVertexIds) { + if (isReadyToGroup(id, globalGroupedVertices, currentVertexGroupMap.keySet())) { + PipelineVertex outputVertex = plan.getVertexMap().get(id); + triggerVertices.add(outputVertex); } - - // 2. Try check and join output vertex into current group. - if (canGroup) { - pushDownGroup(vertex, currentVertexGroupMap, currentVisited, - groupOutputVertices, globalGroupedVertices); + } + } else { + // Current group is standalone pipeline which has no output vertices, can join into next + // group. + while (currentOutput.isEmpty() && !triggerVertices.isEmpty()) { + PipelineVertex nextVertex = triggerVertices.poll(); + if (!group( + nextVertex, + currentVertexGroupMap, + currentVisited, + currentOutput, + globalGroupedVertices)) { + break; } + } + triggerVertices.addAll(currentOutput); + } + return currentVertexGroupMap; + } + + /** Check and build group for input vertex. */ + private boolean group( + PipelineVertex vertex, + Map currentVertexGroupMap, + Set currentVisited, + Queue groupOutputVertices, + final Set globalGroupedVertices) { + + currentVisited.add(vertex.getVertexId()); + if (!canGroup(vertex)) { + return false; + } + // 1. All input must support group into current vertex. + boolean canGroup = + pushUpGroup( + vertex, + currentVertexGroupMap, + currentVisited, + groupOutputVertices, + globalGroupedVertices); + if (canGroup) { + currentVertexGroupMap.put( + vertex.getVertexId(), + buildExecutionVertex(plan.getVertexMap().get(vertex.getVertexId()))); + } - return canGroup; + // 2. Try check and join output vertex into current group. + if (canGroup) { + pushDownGroup( + vertex, + currentVertexGroupMap, + currentVisited, + groupOutputVertices, + globalGroupedVertices); } + return canGroup; + } + + /** Check upstream vertex whether can group together, if could then group together. */ + private boolean pushUpGroup( + PipelineVertex vertex, + Map currentVertexGroupMap, + Set currentVisited, + Queue groupOutputVertices, + final Set globalGroupedVertices) { + + // The current vertex can group only if all input can group. + // 1. All input must support group into current vertex. + List inputVertexIds = plan.getVertexInputVertexIds(vertex.getVertexId()); + List inputVertexIdCandidates = new ArrayList<>(); + for (int id : inputVertexIds) { + PipelineVertex inputVertex = plan.getVertexMap().get(id); + // Input already in current vertex group, ignore. + if (currentVertexGroupMap.containsKey(id) || globalGroupedVertices.contains(id)) { + continue; + } else if (id == vertex.getVertexId()) { + // Ignore self loop edge. + continue; + } else if (currentVisited.contains(id)) { + // Visited but not in grouped vertices, it means group failed in previous steps. + // return false; + continue; + } + if (!canGroup(inputVertex, vertex)) { + return false; + } + inputVertexIdCandidates.add(id); + } - /** - * Check upstream vertex whether can group together, if could then group together. - */ - private boolean pushUpGroup(PipelineVertex vertex, - Map currentVertexGroupMap, - Set currentVisited, - Queue groupOutputVertices, - final Set globalGroupedVertices) { - - // The current vertex can group only if all input can group. - // 1. All input must support group into current vertex. - List inputVertexIds = plan.getVertexInputVertexIds(vertex.getVertexId()); - List inputVertexIdCandidates = new ArrayList<>(); - for (int id : inputVertexIds) { - PipelineVertex inputVertex = plan.getVertexMap().get(id); - // Input already in current vertex group, ignore. - if (currentVertexGroupMap.containsKey(id) || globalGroupedVertices.contains(id)) { - continue; - } else if (id == vertex.getVertexId()) { - // Ignore self loop edge. - continue; - } else if (currentVisited.contains(id)) { - // Visited but not in grouped vertices, it means group failed in previous steps. - // return false; - continue; - } - if (!canGroup(inputVertex, vertex)) { - return false; - } - inputVertexIdCandidates.add(id); + // 2. Try group input. + Map inputVertices = new HashMap<>(); + for (int id : inputVertexIdCandidates) { + PipelineVertex inputVertex = plan.getVertexMap().get(id); + // Input already in current vertex group, ignore. + if (currentVertexGroupMap.containsKey(id) + || globalGroupedVertices.contains(id) + || inputVertices.containsKey(id)) { + continue; + } else if (currentVisited.contains(id)) { + // Visited but not in grouped vertices, it means group failed in previous steps. + return false; + } + if (!canGroup(inputVertex, vertex)) { + return false; + } + + Map inputVertexGroupMap = new HashMap<>(); + inputVertexGroupMap.putAll(currentVertexGroupMap); + boolean canGroup = + group( + inputVertex, + inputVertexGroupMap, + currentVisited, + groupOutputVertices, + globalGroupedVertices); + if (!canGroup) { + return false; + } else { + inputVertices.putAll(inputVertexGroupMap); + } + } + // 3. All inputs can group into current group. + currentVertexGroupMap.putAll(inputVertices); + return true; + } + + /** + * Try check and join output vertices into current vertex group. If one output can join into + * current vertex group, recursively group output vertex, otherwise, add the output into queue to + * trigger next group. + */ + private void pushDownGroup( + PipelineVertex vertex, + Map currentVertexGroupMap, + Set currentVisited, + Queue groupOutputVertices, + final Set globalGroupedVertices) { + List outputVertexIds = plan.getVertexOutputVertexIds(vertex.getVertexId()); + for (int id : outputVertexIds) { + PipelineVertex outputVertex = plan.getVertexMap().get(id); + + // Ignore visited vertex. + if (currentVertexGroupMap.containsKey(id) || globalGroupedVertices.contains(id)) { + continue; + } + + // If it cannot group, add to trigger queue. + if (!canGroup(vertex, outputVertex) + || !group( + outputVertex, + currentVertexGroupMap, + currentVisited, + groupOutputVertices, + globalGroupedVertices)) { + groupOutputVertices.add(outputVertex); + } + } + } + + /** Check the current pipeline vertex whether can group. */ + private boolean canGroup(PipelineVertex currentVertex) { + boolean enGroup = true; + + VertexType type = currentVertex.getType(); + switch (type) { + case vertex_centric: + case inc_vertex_centric: + case iterator: + case inc_iterator: + case iteration_aggregation: + if (currentVertex.getOperator() instanceof IGraphVertexCentricAggOp) { + return true; } - - // 2. Try group input. - Map inputVertices = new HashMap<>(); - for (int id : inputVertexIdCandidates) { - PipelineVertex inputVertex = plan.getVertexMap().get(id); - // Input already in current vertex group, ignore. - if (currentVertexGroupMap.containsKey(id) || globalGroupedVertices.contains(id) || inputVertices.containsKey(id)) { - continue; - } else if (currentVisited.contains(id)) { - // Visited but not in grouped vertices, it means group failed in previous steps. - return false; - } - if (!canGroup(inputVertex, vertex)) { - return false; - } - - Map inputVertexGroupMap = new HashMap<>(); - inputVertexGroupMap.putAll(currentVertexGroupMap); - boolean canGroup = group(inputVertex, - inputVertexGroupMap, currentVisited, - groupOutputVertices, globalGroupedVertices); - if (!canGroup) { - return false; - } else { - inputVertices.putAll(inputVertexGroupMap); + return false; + default: + AbstractOperator operator = (AbstractOperator) currentVertex.getOperator(); + if (!operator.getOpArgs().isEnGroup()) { + enGroup = false; + } else { + List operatorList = operator.getNextOperators(); + for (Operator op : operatorList) { + if (!((AbstractOperator) op).getOpArgs().isEnGroup()) { + enGroup = false; + break; } + } } - // 3. All inputs can group into current group. - currentVertexGroupMap.putAll(inputVertices); - return true; } - - /** - * Try check and join output vertices into current vertex group. - * If one output can join into current vertex group, recursively group output vertex, - * otherwise, add the output into queue to trigger next group. - */ - private void pushDownGroup(PipelineVertex vertex, - Map currentVertexGroupMap, - Set currentVisited, - Queue groupOutputVertices, - final Set globalGroupedVertices) { - List outputVertexIds = plan.getVertexOutputVertexIds(vertex.getVertexId()); - for (int id : outputVertexIds) { - PipelineVertex outputVertex = plan.getVertexMap().get(id); - - // Ignore visited vertex. - if (currentVertexGroupMap.containsKey(id) - || globalGroupedVertices.contains(id)) { - continue; - } - - // If it cannot group, add to trigger queue. - if (!canGroup(vertex, outputVertex) - || !group(outputVertex, currentVertexGroupMap, - currentVisited, groupOutputVertices, globalGroupedVertices)) { - groupOutputVertices.add(outputVertex); - } - } + return enGroup; + } + + /** Check the current pipeline vertex can group with output vertex. */ + private boolean canGroup(PipelineVertex currentVertex, PipelineVertex outputVertex) { + if (canGroupWithInput(currentVertex, outputVertex) + && canGroupWithOutput(outputVertex, currentVertex)) { + return true; } - - /** - * Check the current pipeline vertex whether can group. - */ - private boolean canGroup(PipelineVertex currentVertex) { - boolean enGroup = true; - - VertexType type = currentVertex.getType(); - switch (type) { - case vertex_centric: - case inc_vertex_centric: - case iterator: - case inc_iterator: - case iteration_aggregation: - if (currentVertex.getOperator() instanceof IGraphVertexCentricAggOp) { - return true; - } - return false; - default: - AbstractOperator operator = (AbstractOperator) currentVertex.getOperator(); - if (!operator.getOpArgs().isEnGroup()) { - enGroup = false; - } else { - List operatorList = operator.getNextOperators(); - for (Operator op : operatorList) { - if (!((AbstractOperator) op).getOpArgs().isEnGroup()) { - enGroup = false; - break; - } - } - } + return false; + } + + /** Check the current pipeline vertex can group with output vertex. */ + private boolean canGroupWithOutput(PipelineVertex currentVertex, PipelineVertex outputVertex) { + VertexType type = currentVertex.getType(); + switch (type) { + case vertex_centric: + case inc_vertex_centric: + case iterator: + case inc_iterator: + case iteration_aggregation: + if (currentVertex.getOperator() instanceof IGraphVertexCentricAggOp + && outputVertex.getOperator() instanceof IGraphVertexCentricAggOp) { + return true; } - return enGroup; + return false; + default: + return true; } - - /** - * Check the current pipeline vertex can group with output vertex. - */ - private boolean canGroup(PipelineVertex currentVertex, PipelineVertex outputVertex) { - if (canGroupWithInput(currentVertex, outputVertex) && canGroupWithOutput(outputVertex, currentVertex)) { - return true; + } + + /** Check the current pipeline vertex can group with input vertex. */ + private boolean canGroupWithInput(PipelineVertex currentVertex, PipelineVertex inputVertex) { + VertexType type = currentVertex.getType(); + switch (type) { + case vertex_centric: + case inc_vertex_centric: + case iterator: + case inc_iterator: + case iteration_aggregation: + if (currentVertex.getOperator() instanceof IGraphVertexCentricAggOp + && inputVertex.getOperator() instanceof IGraphVertexCentricAggOp) { + return true; } return false; + default: + return true; } - - /** - * Check the current pipeline vertex can group with output vertex. - */ - private boolean canGroupWithOutput(PipelineVertex currentVertex, PipelineVertex outputVertex) { - VertexType type = currentVertex.getType(); - switch (type) { - case vertex_centric: - case inc_vertex_centric: - case iterator: - case inc_iterator: - case iteration_aggregation: - if (currentVertex.getOperator() instanceof IGraphVertexCentricAggOp - && outputVertex.getOperator() instanceof IGraphVertexCentricAggOp) { - return true; - } - return false; - default: - return true; - } + } + + /** Check whether current vertex can start to build a new group. */ + private boolean isReadyToGroup( + int vertexId, Set globalGroupedVertices, Set currentGroupedVertices) { + + List inputs = plan.getVertexInputVertexIds(vertexId); + for (int inputVid : inputs) { + // A vertex that matches flowing cases will not allow to start a new group: + // 1. input has not grouped. + // 2. and input vertex is itself. + // 3. and input is agg vertex that will never allow to group standalone. + if (!globalGroupedVertices.contains(inputVid) + && !currentGroupedVertices.contains(inputVid) + && inputVid != vertexId + && inputVid != ITERATION_AGG_VERTEX_ID) { + return false; + } } - - /** - * Check the current pipeline vertex can group with input vertex. - */ - private boolean canGroupWithInput(PipelineVertex currentVertex, PipelineVertex inputVertex) { - VertexType type = currentVertex.getType(); - switch (type) { - case vertex_centric: - case inc_vertex_centric: - case iterator: - case inc_iterator: - case iteration_aggregation: - if (currentVertex.getOperator() instanceof IGraphVertexCentricAggOp - && inputVertex.getOperator() instanceof IGraphVertexCentricAggOp) { - return true; - } - return false; - default: - return true; - } + return true; + } + + /** Build execution vertex. */ + private ExecutionVertex buildExecutionVertex(PipelineVertex pipelineVertex) { + ExecutionVertex executionVertex; + + int vertexId = pipelineVertex.getVertexId(); + VertexType type = pipelineVertex.getType(); + String name = pipelineVertex.getName(); + LOGGER.info("vertexId:{} vertexName:{} type:{}", vertexId, name, type); + + switch (type) { + case vertex_centric: + case inc_vertex_centric: + case iterator: + case inc_iterator: + executionVertex = + new IteratorExecutionVertex(vertexId, name, pipelineVertex.getIterations()); + break; + case collect: + executionVertex = new CollectExecutionVertex(vertexId, name); + break; + case inc_process: + default: + executionVertex = new ExecutionVertex(vertexId, pipelineVertex.getName()); + break; } - /** - * Check whether current vertex can start to build a new group. - */ - private boolean isReadyToGroup(int vertexId, - Set globalGroupedVertices, - Set currentGroupedVertices) { - - List inputs = plan.getVertexInputVertexIds(vertexId); - for (int inputVid : inputs) { - // A vertex that matches flowing cases will not allow to start a new group: - // 1. input has not grouped. - // 2. and input vertex is itself. - // 3. and input is agg vertex that will never allow to group standalone. - if (!globalGroupedVertices.contains(inputVid) && !currentGroupedVertices.contains(inputVid) - && inputVid != vertexId && inputVid != ITERATION_AGG_VERTEX_ID) { - return false; - } + // Construct the parent and child stage ids for the current stage. + List parentVertexIds = plan.getVertexInputVertexIds(executionVertex.getVertexId()); + if (parentVertexIds != null) { + executionVertex.setParentVertexIds(parentVertexIds); + } + List childrenStageIds = plan.getVertexOutputVertexIds(executionVertex.getVertexId()); + boolean hasChildren = childrenStageIds != null && childrenStageIds.size() > 0; + // Build the downstream partition num for the current vertex. + if (hasChildren) { + int bucketNum = 1; + Map vertexMap = plan.getVertexMap(); + for (Integer childVertexId : childrenStageIds) { + int childParallelism = getMaxParallelism(vertexMap.get(childVertexId)); + if (childParallelism > bucketNum) { + bucketNum = childParallelism; } - return true; + } + executionVertex.setNumPartitions(bucketNum); + } else { + executionVertex.setNumPartitions(pipelineVertex.getParallelism()); } - /** - * Build execution vertex. - */ - private ExecutionVertex buildExecutionVertex(PipelineVertex pipelineVertex) { - ExecutionVertex executionVertex; + // Build execution processor. + Processor processor = null; + if (pipelineVertex.getOperator() != null) { + IProcessorBuilder processorBuilder = new ProcessorBuilder(); + processor = processorBuilder.buildProcessor(pipelineVertex.getOperator()); + executionVertex.setProcessor(processor); + } - int vertexId = pipelineVertex.getVertexId(); - VertexType type = pipelineVertex.getType(); - String name = pipelineVertex.getName(); - LOGGER.info("vertexId:{} vertexName:{} type:{}", vertexId, name, type); + // Set join and combine left/right input processor stream name. + if (pipelineVertex.getType() == VertexType.join + || pipelineVertex.getType() == VertexType.combine) { + List edges = + plan.getVertexInputEdges(pipelineVertex.getVertexId()).stream() + .sorted(Comparator.comparingInt(PipelineEdge::getStreamOrdinal)) + .collect(Collectors.toList()); + TwoInputProcessor twoInputProcessor = (TwoInputProcessor) processor; + twoInputProcessor.setLeftStream(edges.get(0).getEdgeName()); + twoInputProcessor.setRightStream(edges.get(1).getEdgeName()); + } + // Set other member param. + executionVertex.setParallelism(pipelineVertex.getParallelism()); + executionVertex.setMaxParallelism(getMaxParallelism(pipelineVertex)); + executionVertex.setVertexType(pipelineVertex.getType()); + executionVertex.setAffinityLevel(pipelineVertex.getAffinity()); + executionVertex.setChainTailType(pipelineVertex.getChainTailType()); + + LOGGER.info( + "execution vertex {}, parallelism {}, max parallelism {}, num partitions {}", + executionVertex, + executionVertex.getParallelism(), + executionVertex.getMaxParallelism(), + executionVertex.getNumPartitions()); + return executionVertex; + } + + /** Build the cycle group meta for graph. */ + private void buildCycleGroupMeta(ExecutionGraph graph) { + Map vertexGroupMap = graph.getVertexGroupMap(); + + Set pipelineSet = new HashSet<>(); + Set batchSet = new HashSet<>(); + Set iteratorSet = new HashSet<>(); + + for (ExecutionVertexGroup vertexGroup : vertexGroupMap.values()) { + for (ExecutionVertex vertex : vertexGroup.getVertexMap().values()) { + LOGGER.info("vertexInfo:{}", vertex); + CycleGroupType type = + getCycleGroupType(vertex, ((AbstractProcessor) vertex.getProcessor()).getOperator()); switch (type) { - case vertex_centric: - case inc_vertex_centric: - case iterator: - case inc_iterator: - executionVertex = new IteratorExecutionVertex(vertexId, name, pipelineVertex.getIterations()); - break; - case collect: - executionVertex = new CollectExecutionVertex(vertexId, name); - break; - case inc_process: - default: - executionVertex = new ExecutionVertex(vertexId, pipelineVertex.getName()); - break; - } - - // Construct the parent and child stage ids for the current stage. - List parentVertexIds = plan.getVertexInputVertexIds(executionVertex.getVertexId()); - if (parentVertexIds != null) { - executionVertex.setParentVertexIds(parentVertexIds); - } - List childrenStageIds = plan.getVertexOutputVertexIds(executionVertex.getVertexId()); - boolean hasChildren = childrenStageIds != null && childrenStageIds.size() > 0; - // Build the downstream partition num for the current vertex. - if (hasChildren) { - int bucketNum = 1; - Map vertexMap = plan.getVertexMap(); - for (Integer childVertexId : childrenStageIds) { - int childParallelism = getMaxParallelism(vertexMap.get(childVertexId)); - if (childParallelism > bucketNum) { - bucketNum = childParallelism; - } + case pipelined: + pipelineSet.add(vertexGroup); + vertexGroup.getCycleGroupMeta().setIterationCount(Long.MAX_VALUE); + vertexGroup.getCycleGroupMeta().setFlyingCount(flyingCount); + vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); + break; + case incremental: + if (vertex.getVertexType() == VertexType.iteration_aggregation) { + continue; } - executionVertex.setNumPartitions(bucketNum); - } else { - executionVertex.setNumPartitions(pipelineVertex.getParallelism()); - } - - // Build execution processor. - Processor processor = null; - if (pipelineVertex.getOperator() != null) { - IProcessorBuilder processorBuilder = new ProcessorBuilder(); - processor = processorBuilder.buildProcessor(pipelineVertex.getOperator()); - executionVertex.setProcessor(processor); - } - - // Set join and combine left/right input processor stream name. - if (pipelineVertex.getType() == VertexType.join || pipelineVertex.getType() == VertexType.combine) { - List edges = plan.getVertexInputEdges(pipelineVertex.getVertexId()) - .stream().sorted(Comparator.comparingInt(PipelineEdge::getStreamOrdinal)) - .collect(Collectors.toList()); - TwoInputProcessor twoInputProcessor = (TwoInputProcessor) processor; - twoInputProcessor.setLeftStream(edges.get(0).getEdgeName()); - twoInputProcessor.setRightStream(edges.get(1).getEdgeName()); - } - - // Set other member param. - executionVertex.setParallelism(pipelineVertex.getParallelism()); - executionVertex.setMaxParallelism(getMaxParallelism(pipelineVertex)); - executionVertex.setVertexType(pipelineVertex.getType()); - executionVertex.setAffinityLevel(pipelineVertex.getAffinity()); - executionVertex.setChainTailType(pipelineVertex.getChainTailType()); - - LOGGER.info("execution vertex {}, parallelism {}, max parallelism {}, num partitions {}", - executionVertex, executionVertex.getParallelism(), executionVertex.getMaxParallelism(), executionVertex.getNumPartitions()); - return executionVertex; - } - - /** - * Build the cycle group meta for graph. - */ - private void buildCycleGroupMeta(ExecutionGraph graph) { - Map vertexGroupMap = graph.getVertexGroupMap(); - - Set pipelineSet = new HashSet<>(); - Set batchSet = new HashSet<>(); - Set iteratorSet = new HashSet<>(); - - for (ExecutionVertexGroup vertexGroup : vertexGroupMap.values()) { - for (ExecutionVertex vertex : vertexGroup.getVertexMap().values()) { - LOGGER.info("vertexInfo:{}", vertex); - CycleGroupType type = getCycleGroupType(vertex, ((AbstractProcessor) vertex.getProcessor()).getOperator()); - switch (type) { - case pipelined: - pipelineSet.add(vertexGroup); - vertexGroup.getCycleGroupMeta().setIterationCount(Long.MAX_VALUE); - vertexGroup.getCycleGroupMeta().setFlyingCount(flyingCount); - vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); - break; - case incremental: - if (vertex.getVertexType() == VertexType.iteration_aggregation) { - continue; - } - graph.getCycleGroupMeta().setIterationCount(Long.MAX_VALUE); - iteratorSet.add(vertexGroup); - vertexGroup.getCycleGroupMeta().setIterationCount(((IteratorExecutionVertex) vertex).getIteratorCount()); - vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.incremental); - vertexGroup.getCycleGroupMeta().setAffinityLevel(AffinityLevel.worker); - break; - case statical: - if (vertex.getVertexType() == VertexType.iteration_aggregation) { - continue; - } - List sourceVertexList = plan.getSourceVertices(); - boolean isSingleWindow = sourceVertexList.stream().allMatch(v -> - ((AbstractOperator) v.getOperator()).getOpArgs().getOpType() == OpArgs.OpType.SINGLE_WINDOW_SOURCE); - if (!isSingleWindow) { - graph.getCycleGroupMeta().setIterationCount(Long.MAX_VALUE); - } - vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.statical); - iteratorSet.add(vertexGroup); - vertexGroup.getCycleGroupMeta().setIterationCount(((IteratorExecutionVertex) vertex).getIteratorCount()); - vertexGroup.getCycleGroupMeta().setAffinityLevel(AffinityLevel.worker); - break; - case windowed: - batchSet.add(vertexGroup); - vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.windowed); - break; - default: - throw new GeaflowRuntimeException(RuntimeErrors.INST.operatorTypeNotSupportError(String.valueOf(type))); - } + graph.getCycleGroupMeta().setIterationCount(Long.MAX_VALUE); + iteratorSet.add(vertexGroup); + vertexGroup + .getCycleGroupMeta() + .setIterationCount(((IteratorExecutionVertex) vertex).getIteratorCount()); + vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.incremental); + vertexGroup.getCycleGroupMeta().setAffinityLevel(AffinityLevel.worker); + break; + case statical: + if (vertex.getVertexType() == VertexType.iteration_aggregation) { + continue; + } + List sourceVertexList = plan.getSourceVertices(); + boolean isSingleWindow = + sourceVertexList.stream() + .allMatch( + v -> + ((AbstractOperator) v.getOperator()).getOpArgs().getOpType() + == OpArgs.OpType.SINGLE_WINDOW_SOURCE); + if (!isSingleWindow) { + graph.getCycleGroupMeta().setIterationCount(Long.MAX_VALUE); } + vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.statical); + iteratorSet.add(vertexGroup); + vertexGroup + .getCycleGroupMeta() + .setIterationCount(((IteratorExecutionVertex) vertex).getIteratorCount()); + vertexGroup.getCycleGroupMeta().setAffinityLevel(AffinityLevel.worker); + break; + case windowed: + batchSet.add(vertexGroup); + vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.windowed); + break; + default: + throw new GeaflowRuntimeException( + RuntimeErrors.INST.operatorTypeNotSupportError(String.valueOf(type))); } + } + } - // Currently not support pipeline mode in hybrid node cycle level, we will support in later. - if (batchSet.size() > 0 || iteratorSet.size() > 0) { - pipelineSet.stream().forEach(executionVertexGroup -> { + // Currently not support pipeline mode in hybrid node cycle level, we will support in later. + if (batchSet.size() > 0 || iteratorSet.size() > 0) { + pipelineSet.stream() + .forEach( + executionVertexGroup -> { executionVertexGroup.getCycleGroupMeta().setIterationCount(1); executionVertexGroup.getCycleGroupMeta().setFlyingCount(1); - }); - // Build graph cycle group type. - if (batchSet.size() > 0) { - graph.getCycleGroupMeta().setGroupType(CycleGroupType.windowed); - } - if (iteratorSet.size() > 0) { - if (iteratorSet.stream().anyMatch(executionVertexGroup - -> executionVertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.incremental)) { - graph.getCycleGroupMeta().setGroupType(CycleGroupType.incremental); - } else { - graph.getCycleGroupMeta().setGroupType(CycleGroupType.statical); - } - } + }); + // Build graph cycle group type. + if (batchSet.size() > 0) { + graph.getCycleGroupMeta().setGroupType(CycleGroupType.windowed); + } + if (iteratorSet.size() > 0) { + if (iteratorSet.stream() + .anyMatch( + executionVertexGroup -> + executionVertexGroup.getCycleGroupMeta().getGroupType() + == CycleGroupType.incremental)) { + graph.getCycleGroupMeta().setGroupType(CycleGroupType.incremental); + } else { + graph.getCycleGroupMeta().setGroupType(CycleGroupType.statical); } - - // Set the affinity level for pipeline and batch vertex group. - pipelineSet.stream().forEach(executionVertexGroup -> - executionVertexGroup.getCycleGroupMeta().setAffinityLevel(buildGroupAffinityLevel(executionVertexGroup))); - batchSet.stream().forEach(executionVertexGroup -> - executionVertexGroup.getCycleGroupMeta().setAffinityLevel(buildGroupAffinityLevel(executionVertexGroup))); - - pipelineSet.clear(); - batchSet.clear(); - iteratorSet.clear(); + } } - /** - * Get the cycle group type for current op. - */ - private CycleGroupType getCycleGroupType(ExecutionVertex vertex, Operator operator) { - CycleGroupType groupType; - OpArgs.OpType type = ((AbstractOperator) operator).getOpArgs().getOpType(); - switch (type) { - case ONE_INPUT: - case TWO_INPUT: - case MULTI_WINDOW_SOURCE: - case GRAPH_SOURCE: - groupType = CycleGroupType.pipelined; - // Get type of sub operator. - if (vertex != null) { - for (Object subOperator : ((AbstractOperator) ((AbstractProcessor) vertex.getProcessor()) - .getOperator()).getNextOperators()) { - groupType = getCycleGroupType(null, (Operator) subOperator); - if (groupType != CycleGroupType.pipelined) { - break; - } - } - } - break; - case SINGLE_WINDOW_SOURCE: - groupType = CycleGroupType.windowed; - break; - case INC_VERTEX_CENTRIC_COMPUTE: - case INC_VERTEX_CENTRIC_TRAVERSAL: - groupType = CycleGroupType.incremental; - break; - case VERTEX_CENTRIC_COMPUTE: - case VERTEX_CENTRIC_TRAVERSAL: - groupType = CycleGroupType.statical; - break; - default: - throw new GeaflowRuntimeException(RuntimeErrors.INST.operatorTypeNotSupportError(type.name())); + // Set the affinity level for pipeline and batch vertex group. + pipelineSet.stream() + .forEach( + executionVertexGroup -> + executionVertexGroup + .getCycleGroupMeta() + .setAffinityLevel(buildGroupAffinityLevel(executionVertexGroup))); + batchSet.stream() + .forEach( + executionVertexGroup -> + executionVertexGroup + .getCycleGroupMeta() + .setAffinityLevel(buildGroupAffinityLevel(executionVertexGroup))); + + pipelineSet.clear(); + batchSet.clear(); + iteratorSet.clear(); + } + + /** Get the cycle group type for current op. */ + private CycleGroupType getCycleGroupType(ExecutionVertex vertex, Operator operator) { + CycleGroupType groupType; + OpArgs.OpType type = ((AbstractOperator) operator).getOpArgs().getOpType(); + switch (type) { + case ONE_INPUT: + case TWO_INPUT: + case MULTI_WINDOW_SOURCE: + case GRAPH_SOURCE: + groupType = CycleGroupType.pipelined; + // Get type of sub operator. + if (vertex != null) { + for (Object subOperator : + ((AbstractOperator) ((AbstractProcessor) vertex.getProcessor()).getOperator()) + .getNextOperators()) { + groupType = getCycleGroupType(null, (Operator) subOperator); + if (groupType != CycleGroupType.pipelined) { + break; + } + } } - - return groupType; + break; + case SINGLE_WINDOW_SOURCE: + groupType = CycleGroupType.windowed; + break; + case INC_VERTEX_CENTRIC_COMPUTE: + case INC_VERTEX_CENTRIC_TRAVERSAL: + groupType = CycleGroupType.incremental; + break; + case VERTEX_CENTRIC_COMPUTE: + case VERTEX_CENTRIC_TRAVERSAL: + groupType = CycleGroupType.statical; + break; + default: + throw new GeaflowRuntimeException( + RuntimeErrors.INST.operatorTypeNotSupportError(type.name())); } - /** - * Get max parallelism of vertex. - */ - private int getMaxParallelism(PipelineVertex vertex) { - int maxParallelism = vertex.getParallelism(); - switch (vertex.getType()) { - case inc_process: - case vertex_centric: - case iterator: - if (vertex.getOperator() instanceof AbstractStaticGraphVertexCentricOp) { - int shardNum = ((AbstractStaticGraphVertexCentricOp) vertex.getOperator()).getGraphViewDesc().getShardNum(); - Preconditions.checkArgument(shardNum >= maxParallelism, - String.format("shardNum %d should not be less than maxParallelism %d", shardNum, maxParallelism)); - return shardNum; - } - return MathUtil.minPowerOf2(maxParallelism); - case inc_vertex_centric: - case inc_iterator: - return ((AbstractDynamicGraphVertexCentricOp) vertex.getOperator()) - .getGraphViewDesc().getShardNum(); - default: - return maxParallelism; + return groupType; + } + + /** Get max parallelism of vertex. */ + private int getMaxParallelism(PipelineVertex vertex) { + int maxParallelism = vertex.getParallelism(); + switch (vertex.getType()) { + case inc_process: + case vertex_centric: + case iterator: + if (vertex.getOperator() instanceof AbstractStaticGraphVertexCentricOp) { + int shardNum = + ((AbstractStaticGraphVertexCentricOp) vertex.getOperator()) + .getGraphViewDesc() + .getShardNum(); + Preconditions.checkArgument( + shardNum >= maxParallelism, + String.format( + "shardNum %d should not be less than maxParallelism %d", + shardNum, maxParallelism)); + return shardNum; } + return MathUtil.minPowerOf2(maxParallelism); + case inc_vertex_centric: + case inc_iterator: + return ((AbstractDynamicGraphVertexCentricOp) vertex.getOperator()) + .getGraphViewDesc() + .getShardNum(); + default: + return maxParallelism; } - - private AffinityLevel buildGroupAffinityLevel(ExecutionVertexGroup vertexGroup) { - AffinityLevel affinityLevel = null; - for (ExecutionVertex vertex : vertexGroup.getVertexMap().values()) { - if (affinityLevel == null) { - affinityLevel = vertex.getAffinityLevel(); - } else { - if (affinityLevel == vertex.getAffinityLevel()) { - continue; - } - // Set default value to worker. - affinityLevel = AffinityLevel.worker; - break; - } + } + + private AffinityLevel buildGroupAffinityLevel(ExecutionVertexGroup vertexGroup) { + AffinityLevel affinityLevel = null; + for (ExecutionVertex vertex : vertexGroup.getVertexMap().values()) { + if (affinityLevel == null) { + affinityLevel = vertex.getAffinityLevel(); + } else { + if (affinityLevel == vertex.getAffinityLevel()) { + continue; } - return affinityLevel; + // Set default value to worker. + affinityLevel = AffinityLevel.worker; + break; + } } + return affinityLevel; + } - private ExecutionEdge buildEdge(PipelineEdge pipelineEdge) { - OutputType dataTransferType = pipelineEdge.getType(); - if (dataTransferType == null) { - dataTransferType = OutputType.FORWARD; - } - return new ExecutionEdge( - pipelineEdge.getPartition(), - pipelineEdge.getEdgeId(), - pipelineEdge.getEdgeName(), - pipelineEdge.getSrcId(), - pipelineEdge.getTargetId(), - dataTransferType, - pipelineEdge.getEncoder()); + private ExecutionEdge buildEdge(PipelineEdge pipelineEdge) { + OutputType dataTransferType = pipelineEdge.getType(); + if (dataTransferType == null) { + dataTransferType = OutputType.FORWARD; } + return new ExecutionEdge( + pipelineEdge.getPartition(), + pipelineEdge.getEdgeId(), + pipelineEdge.getEdgeName(), + pipelineEdge.getSrcId(), + pipelineEdge.getTargetId(), + dataTransferType, + pipelineEdge.getEncoder()); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/plan/visualization/ExecutionGraphVisualization.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/plan/visualization/ExecutionGraphVisualization.java index 834ac115e..d19d4a462 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/plan/visualization/ExecutionGraphVisualization.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/plan/visualization/ExecutionGraphVisualization.java @@ -23,63 +23,70 @@ import java.util.Collections; import java.util.List; import java.util.Map; + import org.apache.geaflow.core.graph.ExecutionGraph; import org.apache.geaflow.core.graph.ExecutionVertexGroup; import org.apache.geaflow.core.graph.ExecutionVertexGroupEdge; public class ExecutionGraphVisualization { - public static final String SCHEDULER = "scheduler"; - private static final String NODE_FORMAT = "%s [label=\"%s\"]\n"; - private static final String EMPTY_GRAPH = new StringBuilder("digraph G {\n") - .append("0 [label=\"node-0\"]\n") - .append(String.format(NODE_FORMAT, SCHEDULER, SCHEDULER)) - .append("}") - .toString(); + public static final String SCHEDULER = "scheduler"; + private static final String NODE_FORMAT = "%s [label=\"%s\"]\n"; + private static final String EMPTY_GRAPH = + new StringBuilder("digraph G {\n") + .append("0 [label=\"node-0\"]\n") + .append(String.format(NODE_FORMAT, SCHEDULER, SCHEDULER)) + .append("}") + .toString(); - private ExecutionGraph executionGraph; - private List vertexGroupEdgeList; - private Map vertexGroupMap; + private ExecutionGraph executionGraph; + private List vertexGroupEdgeList; + private Map vertexGroupMap; - public ExecutionGraphVisualization(ExecutionGraph executionGraph) { - this.executionGraph = executionGraph; - this.vertexGroupEdgeList = new ArrayList<>(executionGraph.getGroupEdgeMap().values()); - this.vertexGroupMap = executionGraph.getVertexGroupMap(); - } + public ExecutionGraphVisualization(ExecutionGraph executionGraph) { + this.executionGraph = executionGraph; + this.vertexGroupEdgeList = new ArrayList<>(executionGraph.getGroupEdgeMap().values()); + this.vertexGroupMap = executionGraph.getVertexGroupMap(); + } - public String getExecutionGraphViz() { - if (vertexGroupEdgeList.size() == 0 && vertexGroupMap.size() == 0) { - return EMPTY_GRAPH; - } + public String getExecutionGraphViz() { + if (vertexGroupEdgeList.size() == 0 && vertexGroupMap.size() == 0) { + return EMPTY_GRAPH; + } - Collections.sort(vertexGroupEdgeList, (o1, o2) -> { - int i = Integer.compare(o1.getGroupSrcId(), o2.getGroupSrcId()); - if (i == 0) { - return Integer.compare(o1.getGroupTargetId(), o2.getGroupTargetId()); - } else { - return i; - } + Collections.sort( + vertexGroupEdgeList, + (o1, o2) -> { + int i = Integer.compare(o1.getGroupSrcId(), o2.getGroupSrcId()); + if (i == 0) { + return Integer.compare(o1.getGroupTargetId(), o2.getGroupTargetId()); + } else { + return i; + } }); - StringBuilder builder = new StringBuilder("digraph G {\n"); - for (ExecutionVertexGroupEdge vertexGroupEdge : vertexGroupEdgeList) { - builder.append(String.format("%d -> %d [label = \"%s\"]\n", vertexGroupEdge.getGroupSrcId(), - vertexGroupEdge.getGroupTargetId(), vertexGroupEdge.getPartitioner().getPartitionType().toString())); - } - for (ExecutionVertexGroup vertexGroup : vertexGroupMap.values()) { - builder.append(String - .format(NODE_FORMAT, vertexGroup.getGroupId(), vertexGroup)); - } + StringBuilder builder = new StringBuilder("digraph G {\n"); + for (ExecutionVertexGroupEdge vertexGroupEdge : vertexGroupEdgeList) { + builder.append( + String.format( + "%d -> %d [label = \"%s\"]\n", + vertexGroupEdge.getGroupSrcId(), + vertexGroupEdge.getGroupTargetId(), + vertexGroupEdge.getPartitioner().getPartitionType().toString())); + } + for (ExecutionVertexGroup vertexGroup : vertexGroupMap.values()) { + builder.append(String.format(NODE_FORMAT, vertexGroup.getGroupId(), vertexGroup)); + } - builder.append("0 [label=\"node-0\"]\n"); - builder.append(String.format(NODE_FORMAT, SCHEDULER, SCHEDULER)); + builder.append("0 [label=\"node-0\"]\n"); + builder.append(String.format(NODE_FORMAT, SCHEDULER, SCHEDULER)); - builder.append("}"); - return builder.toString(); - } + builder.append("}"); + return builder.toString(); + } - @Override - public String toString() { - return executionGraph.toString(); - } + @Override + public String toString() { + return executionGraph.toString(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/util/ExecutionTaskUtils.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/util/ExecutionTaskUtils.java index 90a3550b7..6a988fcc8 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/util/ExecutionTaskUtils.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/core/graph/util/ExecutionTaskUtils.java @@ -24,19 +24,15 @@ public class ExecutionTaskUtils { - /** - * Check whether an execution task is the cycle head. - */ - public static boolean isCycleHead(ExecutionTask task) { - return task.getExecutionTaskType() == ExecutionTaskType.head - || task.getExecutionTaskType() == ExecutionTaskType.singularity; - } + /** Check whether an execution task is the cycle head. */ + public static boolean isCycleHead(ExecutionTask task) { + return task.getExecutionTaskType() == ExecutionTaskType.head + || task.getExecutionTaskType() == ExecutionTaskType.singularity; + } - /** - * Check whether an execution task is the cycle tail. - */ - public static boolean isCycleTail(ExecutionTask task) { - return task.getExecutionTaskType() == ExecutionTaskType.tail - || task.getExecutionTaskType() == ExecutionTaskType.singularity; - } + /** Check whether an execution task is the cycle tail. */ + public static boolean isCycleTail(ExecutionTask task) { + return task.getExecutionTaskType() == ExecutionTaskType.tail + || task.getExecutionTaskType() == ExecutionTaskType.singularity; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/PipelinePlanBuilder.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/PipelinePlanBuilder.java index 8ab196b4d..29a8d9513 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/PipelinePlanBuilder.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/PipelinePlanBuilder.java @@ -23,11 +23,11 @@ import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.ENABLE_EXTRA_OPTIMIZE; import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.ENABLE_EXTRA_OPTIMIZE_SINK; -import com.google.common.base.Preconditions; import java.io.Serializable; import java.util.ArrayList; import java.util.HashSet; import java.util.List; + import org.apache.geaflow.api.graph.base.algo.GraphExecAlgo; import org.apache.geaflow.api.graph.materialize.PGraphMaterialize; import org.apache.geaflow.api.pdata.PStreamSink; @@ -63,487 +63,562 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * This class is for building and optimizing logical plan. - */ +import com.google.common.base.Preconditions; + +/** This class is for building and optimizing logical plan. */ public class PipelinePlanBuilder implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(PipelinePlanBuilder.class); + private static final Logger LOGGER = LoggerFactory.getLogger(PipelinePlanBuilder.class); - private static final int SINGLE_WINDOW_CHECKPOINT_DURATION = 1; - public static final int ITERATION_AGG_VERTEX_ID = 0; + private static final int SINGLE_WINDOW_CHECKPOINT_DURATION = 1; + public static final int ITERATION_AGG_VERTEX_ID = 0; - private PipelineGraph pipelineGraph; - private HashSet visitedVIds; - private int edgeIdGenerator; + private PipelineGraph pipelineGraph; + private HashSet visitedVIds; + private int edgeIdGenerator; - public PipelinePlanBuilder() { - this.pipelineGraph = new PipelineGraph(); - this.visitedVIds = new HashSet<>(); - this.edgeIdGenerator = 1; + public PipelinePlanBuilder() { + this.pipelineGraph = new PipelineGraph(); + this.visitedVIds = new HashSet<>(); + this.edgeIdGenerator = 1; + } + + /** Build the whole plan graph. */ + public PipelineGraph buildPlan(AbstractPipelineContext pipelineContext) { + List actions = pipelineContext.getActions(); + if (actions.size() < 1) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.actionIsEmptyError()); + } + for (PAction action : actions) { + visitAction(action); } - /** - * Build the whole plan graph. - */ - public PipelineGraph buildPlan(AbstractPipelineContext pipelineContext) { - List actions = pipelineContext.getActions(); - if (actions.size() < 1) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.actionIsEmptyError()); - } - for (PAction action : actions) { - visitAction(action); - } + // Check the validity of upstream vertex for the current vertex. + this.pipelineGraph + .getVertexMap() + .values() + .forEach( + pipelineVertex -> { + LOGGER.info(pipelineVertex.getVertexString()); + DAGValidator.checkVertexValidity(this.pipelineGraph, pipelineVertex, true); + }); + + boolean isSingleWindow = + this.pipelineGraph.getSourceVertices().stream() + .allMatch( + v -> + ((AbstractOperator) v.getOperator()).getOpArgs().getOpType() + == OpArgs.OpType.SINGLE_WINDOW_SOURCE); + if (isSingleWindow) { + pipelineContext + .getConfig() + .put( + BATCH_NUMBER_PER_CHECKPOINT.getKey(), + String.valueOf(SINGLE_WINDOW_CHECKPOINT_DURATION)); + LOGGER.info("reset checkpoint duration for all single window source pipeline graph"); + } - // Check the validity of upstream vertex for the current vertex. - this.pipelineGraph.getVertexMap().values().forEach(pipelineVertex -> { - LOGGER.info(pipelineVertex.getVertexString()); - DAGValidator.checkVertexValidity(this.pipelineGraph, pipelineVertex, true); - }); - - boolean isSingleWindow = this.pipelineGraph.getSourceVertices().stream().allMatch(v -> - ((AbstractOperator) v.getOperator()).getOpArgs().getOpType() - == OpArgs.OpType.SINGLE_WINDOW_SOURCE); - if (isSingleWindow) { - pipelineContext.getConfig().put(BATCH_NUMBER_PER_CHECKPOINT.getKey(), - String.valueOf(SINGLE_WINDOW_CHECKPOINT_DURATION)); - LOGGER.info("reset checkpoint duration for all single window source pipeline graph"); + return this.pipelineGraph; + } + + /** Plan optimize. */ + public void optimizePlan(Configuration pipelineConfig) { + List dags = new ArrayList<>(); + List nodeInfos = new ArrayList<>(); + + PlanGraphVisualization visualization = new PlanGraphVisualization(this.pipelineGraph); + String logicalPlan = visualization.getGraphviz(); + dags.add(logicalPlan); + nodeInfos.add(visualization.getNodeInfo()); + LOGGER.info("logical plan: {}", logicalPlan); + + optimizePipelinePlan(pipelineConfig); + visualization = new PlanGraphVisualization(this.pipelineGraph); + String physicalPlan = visualization.getGraphviz(); + dags.add(physicalPlan); + nodeInfos.add(visualization.getNodeInfo()); + LOGGER.info("physical plan: {}", physicalPlan); + } + + /** Build plan graph by visiting action node at first. */ + private void visitAction(PAction action) { + int vId = action.getId(); + if (visitedVIds.add(vId)) { + Stream stream = (Stream) action; + if (action instanceof PGraphMaterialize) { + visitMaterializeAction((PGraphMaterialize) action); + } else { + PipelineVertex pipelineVertex = + new PipelineVertex(vId, stream.getOperator(), stream.getParallelism()); + if (action instanceof PStreamSink) { + pipelineVertex.setType(VertexType.sink); + pipelineVertex.setVertexMode(VertexMode.append); + } else { + pipelineVertex.setType(VertexType.collect); + pipelineVertex.setParallelism(1); } - return this.pipelineGraph; + this.pipelineGraph.addVertex(pipelineVertex); + Stream input = stream.getInput(); + PipelineEdge pipelineEdge = + new PipelineEdge( + this.edgeIdGenerator++, + input.getId(), + vId, + input.getPartition(), + input.getEncoder()); + this.pipelineGraph.addEdge(pipelineEdge); + visitNode(stream.getInput()); + } } + } + + /** Visit materialize action. */ + private void visitMaterializeAction(PGraphMaterialize materialize) { + Stream stream = (Stream) materialize; + PipelineVertex pipelineVertex = + new PipelineVertex(materialize.getId(), stream.getOperator(), stream.getParallelism()); + pipelineVertex.setType(VertexType.sink); + pipelineVertex.setVertexMode(VertexMode.append); + + MaterializedIncGraph materializedIncGraph = (MaterializedIncGraph) stream; + Stream vertexStreamInput = materializedIncGraph.getInput(); + Stream edgeStreamInput = (Stream) materializedIncGraph.getEdges(); + Preconditions.checkArgument( + vertexStreamInput != null && edgeStreamInput != null, + "input vertex and edge stream must be not null"); + + PipelineEdge vertexInputEdge = + new PipelineEdge( + this.edgeIdGenerator++, + vertexStreamInput.getId(), + stream.getId(), + vertexStreamInput.getPartition(), + vertexStreamInput.getEncoder()); + vertexInputEdge.setEdgeName(GraphRecordNames.Vertex.name()); + this.pipelineGraph.addEdge(vertexInputEdge); + PipelineEdge edgeInputEdge = + new PipelineEdge( + this.edgeIdGenerator++, + edgeStreamInput.getId(), + stream.getId(), + edgeStreamInput.getPartition(), + edgeStreamInput.getEncoder()); + edgeInputEdge.setEdgeName(GraphRecordNames.Edge.name()); + this.pipelineGraph.addEdge(edgeInputEdge); + this.pipelineGraph.addVertex(pipelineVertex); + + visitNode(vertexStreamInput); + visitNode(edgeStreamInput); + } + + /** Visit all plan node and build pipeline graph. */ + private void visitNode(Stream stream) { + int vId = stream.getId(); + if (visitedVIds.add(vId)) { + PipelineVertex pipelineVertex = + new PipelineVertex(vId, stream.getOperator(), stream.getParallelism()); + pipelineVertex.setAffinity(AffinityLevel.worker); + switch (stream.getTransformType()) { + case StreamSource: + { + pipelineVertex.setType(VertexType.source); + pipelineVertex.setAffinity(AffinityLevel.worker); + break; + } + case ContinueGraphMaterialize: + pipelineVertex.setType(VertexType.inc_process); + pipelineVertex.setAffinity(AffinityLevel.worker); + MaterializedIncGraph pGraphMaterialize = (MaterializedIncGraph) stream; + + Stream vertexInput = pGraphMaterialize.getInput(); + Stream edgeInput = (Stream) pGraphMaterialize.getEdges(); + Preconditions.checkArgument( + vertexInput != null && edgeInput != null, + "input vertex and edge stream must be not null"); + + PipelineEdge vertexEdge = + new PipelineEdge( + this.edgeIdGenerator++, + vertexInput.getId(), + stream.getId(), + vertexInput.getPartition(), + vertexInput.getEncoder()); + vertexEdge.setEdgeName(GraphRecordNames.Vertex.name()); + this.pipelineGraph.addEdge(vertexEdge); + PipelineEdge edgeEdge = + new PipelineEdge( + this.edgeIdGenerator++, + edgeInput.getId(), + stream.getId(), + edgeInput.getPartition(), + edgeInput.getEncoder()); + edgeEdge.setEdgeName(GraphRecordNames.Edge.name()); + this.pipelineGraph.addEdge(edgeEdge); + + visitNode(vertexInput); + visitNode(edgeInput); + break; + case ContinueGraphCompute: + { + pipelineVertex.setType(VertexType.inc_iterator); + pipelineVertex.setAffinity(AffinityLevel.worker); + ComputeIncGraph pGraphCompute = (ComputeIncGraph) stream; + GraphExecAlgo computeType = pGraphCompute.getGraphComputeType(); + switch (computeType) { + case VertexCentric: + pipelineVertex.setType(VertexType.inc_vertex_centric); + break; + default: + throw new GeaflowRuntimeException("not support graph compute type, " + computeType); + } - /** - * Plan optimize. - */ - public void optimizePlan(Configuration pipelineConfig) { - List dags = new ArrayList<>(); - List nodeInfos = new ArrayList<>(); - - PlanGraphVisualization visualization = new PlanGraphVisualization(this.pipelineGraph); - String logicalPlan = visualization.getGraphviz(); - dags.add(logicalPlan); - nodeInfos.add(visualization.getNodeInfo()); - LOGGER.info("logical plan: {}", logicalPlan); - - optimizePipelinePlan(pipelineConfig); - visualization = new PlanGraphVisualization(this.pipelineGraph); - String physicalPlan = visualization.getGraphviz(); - dags.add(physicalPlan); - nodeInfos.add(visualization.getNodeInfo()); - LOGGER.info("physical plan: {}", physicalPlan); - } + pipelineVertex.setIterations(pGraphCompute.getMaxIterations()); + Stream vertexStreamInput = pGraphCompute.getInput(); + Stream edgeStreamInput = (Stream) pGraphCompute.getEdges(); + Preconditions.checkArgument( + vertexStreamInput != null && edgeStreamInput != null, + "input vertex and edge stream must be not null"); + + PipelineEdge vertexInputEdge = + new PipelineEdge( + this.edgeIdGenerator++, + vertexStreamInput.getId(), + stream.getId(), + vertexStreamInput.getPartition(), + vertexStreamInput.getEncoder()); + vertexInputEdge.setEdgeName(GraphRecordNames.Vertex.name()); + this.pipelineGraph.addEdge(vertexInputEdge); + PipelineEdge edgeInputEdge = + new PipelineEdge( + this.edgeIdGenerator++, + edgeStreamInput.getId(), + stream.getId(), + edgeStreamInput.getPartition(), + edgeStreamInput.getEncoder()); + edgeInputEdge.setEdgeName(GraphRecordNames.Edge.name()); + this.pipelineGraph.addEdge(edgeInputEdge); + + // iteration loop edge + PipelineEdge iterationEdge = buildIterationEdge(vId, pGraphCompute.getMsgEncoder()); + this.pipelineGraph.addEdge(iterationEdge); + + buildIterationAggVertexAndEdge(pipelineVertex); + + visitNode(vertexStreamInput); + visitNode(edgeStreamInput); + break; + } + case WindowGraphCompute: + { + pipelineVertex.setType(VertexType.iterator); + pipelineVertex.setAffinity(AffinityLevel.worker); + ComputeWindowGraph pGraphCompute = (ComputeWindowGraph) stream; + GraphExecAlgo computeType = pGraphCompute.getGraphComputeType(); + switch (computeType) { + case VertexCentric: + pipelineVertex.setType(VertexType.vertex_centric); + break; + default: + throw new GeaflowRuntimeException("not support graph compute type, " + computeType); + } - /** - * Build plan graph by visiting action node at first. - */ - private void visitAction(PAction action) { - int vId = action.getId(); - if (visitedVIds.add(vId)) { - Stream stream = (Stream) action; - if (action instanceof PGraphMaterialize) { - visitMaterializeAction((PGraphMaterialize) action); - } else { - PipelineVertex pipelineVertex = new PipelineVertex(vId, stream.getOperator(), - stream.getParallelism()); - if (action instanceof PStreamSink) { - pipelineVertex.setType(VertexType.sink); - pipelineVertex.setVertexMode(VertexMode.append); - } else { - pipelineVertex.setType(VertexType.collect); - pipelineVertex.setParallelism(1); - } - - this.pipelineGraph.addVertex(pipelineVertex); - Stream input = stream.getInput(); - PipelineEdge pipelineEdge = new PipelineEdge(this.edgeIdGenerator++, - input.getId(), vId, input.getPartition(), input.getEncoder()); - this.pipelineGraph.addEdge(pipelineEdge); - visitNode(stream.getInput()); + pipelineVertex.setIterations(pGraphCompute.getMaxIterations()); + Stream vertexStreamInput = pGraphCompute.getInput(); + Stream edgeStreamInput = (Stream) pGraphCompute.getEdges(); + Preconditions.checkArgument( + vertexStreamInput != null && edgeStreamInput != null, + "input vertex and edge stream must be not null"); + + PipelineEdge vertexInputEdge = + new PipelineEdge( + this.edgeIdGenerator++, + vertexStreamInput.getId(), + stream.getId(), + vertexStreamInput.getPartition(), + vertexStreamInput.getEncoder()); + vertexInputEdge.setEdgeName(GraphRecordNames.Vertex.name()); + this.pipelineGraph.addEdge(vertexInputEdge); + PipelineEdge edgeInputEdge = + new PipelineEdge( + this.edgeIdGenerator++, + edgeStreamInput.getId(), + stream.getId(), + edgeStreamInput.getPartition(), + edgeStreamInput.getEncoder()); + edgeInputEdge.setEdgeName(GraphRecordNames.Edge.name()); + this.pipelineGraph.addEdge(edgeInputEdge); + + // iteration loop edge + PipelineEdge iterationEdge = buildIterationEdge(vId, pGraphCompute.getMsgEncoder()); + this.pipelineGraph.addEdge(iterationEdge); + + buildIterationAggVertexAndEdge(pipelineVertex); + + visitNode(vertexStreamInput); + visitNode(edgeStreamInput); + break; + } + case WindowGraphTraversal: + { + pipelineVertex.setType(VertexType.iterator); + pipelineVertex.setAffinity(AffinityLevel.worker); + TraversalWindowGraph windowGraph = (TraversalWindowGraph) stream; + GraphExecAlgo traversalType = windowGraph.getGraphTraversalType(); + switch (traversalType) { + case VertexCentric: + pipelineVertex.setType(VertexType.vertex_centric); + break; + default: + throw new GeaflowRuntimeException( + "not support graph traversal type, " + traversalType); } - } - } - /** - * Visit materialize action. - */ - private void visitMaterializeAction(PGraphMaterialize materialize) { - Stream stream = (Stream) materialize; - PipelineVertex pipelineVertex = new PipelineVertex( - materialize.getId(), stream.getOperator(), stream.getParallelism()); - pipelineVertex.setType(VertexType.sink); - pipelineVertex.setVertexMode(VertexMode.append); - - MaterializedIncGraph materializedIncGraph = (MaterializedIncGraph) stream; - Stream vertexStreamInput = materializedIncGraph.getInput(); - Stream edgeStreamInput = (Stream) materializedIncGraph.getEdges(); - Preconditions.checkArgument(vertexStreamInput != null && edgeStreamInput != null, - "input vertex and edge stream must be not null"); - - PipelineEdge vertexInputEdge = new PipelineEdge(this.edgeIdGenerator++, - vertexStreamInput.getId(), - stream.getId(), vertexStreamInput.getPartition(), vertexStreamInput.getEncoder()); - vertexInputEdge.setEdgeName(GraphRecordNames.Vertex.name()); - this.pipelineGraph.addEdge(vertexInputEdge); - PipelineEdge edgeInputEdge = new PipelineEdge(this.edgeIdGenerator++, - edgeStreamInput.getId(), - stream.getId(), edgeStreamInput.getPartition(), edgeStreamInput.getEncoder()); - edgeInputEdge.setEdgeName(GraphRecordNames.Edge.name()); - this.pipelineGraph.addEdge(edgeInputEdge); - this.pipelineGraph.addVertex(pipelineVertex); + pipelineVertex.setIterations(windowGraph.getMaxIterations()); + Stream vertexStreamInput = windowGraph.getInput(); + Stream edgeStreamInput = (Stream) windowGraph.getEdges(); + Preconditions.checkArgument( + vertexStreamInput != null && edgeStreamInput != null, + "input vertex and edge stream must be not null"); + + PipelineEdge vertexInputEdge = + new PipelineEdge( + this.edgeIdGenerator++, + vertexStreamInput.getId(), + stream.getId(), + vertexStreamInput.getPartition(), + vertexStreamInput.getEncoder()); + vertexInputEdge.setEdgeName(GraphRecordNames.Vertex.name()); + this.pipelineGraph.addEdge(vertexInputEdge); + PipelineEdge edgeInputEdge = + new PipelineEdge( + this.edgeIdGenerator++, + edgeStreamInput.getId(), + stream.getId(), + edgeStreamInput.getPartition(), + edgeStreamInput.getEncoder()); + edgeInputEdge.setEdgeName(GraphRecordNames.Edge.name()); + this.pipelineGraph.addEdge(edgeInputEdge); + + // Add request input. + if (windowGraph.getRequestStream() != null) { + Stream requestStreamInput = (Stream) windowGraph.getRequestStream(); + PipelineEdge requestInputEdge = + new PipelineEdge( + this.edgeIdGenerator++, + requestStreamInput.getId(), + stream.getId(), + requestStreamInput.getPartition(), + requestStreamInput.getEncoder()); + requestInputEdge.setEdgeName(GraphRecordNames.Request.name()); + this.pipelineGraph.addEdge(requestInputEdge); + visitNode(requestStreamInput); + } - visitNode(vertexStreamInput); - visitNode(edgeStreamInput); - } + // iteration loop edge + PipelineEdge iterationEdge = buildIterationEdge(vId, windowGraph.getMsgEncoder()); + this.pipelineGraph.addEdge(iterationEdge); + + buildIterationAggVertexAndEdge(pipelineVertex); - /** - * Visit all plan node and build pipeline graph. - */ - private void visitNode(Stream stream) { - int vId = stream.getId(); - if (visitedVIds.add(vId)) { - PipelineVertex pipelineVertex = new PipelineVertex(vId, - stream.getOperator(), stream.getParallelism()); + visitNode(vertexStreamInput); + visitNode(edgeStreamInput); + break; + } + case ContinueGraphTraversal: + { + pipelineVertex.setType(VertexType.inc_iterator); pipelineVertex.setAffinity(AffinityLevel.worker); - switch (stream.getTransformType()) { - case StreamSource: { - pipelineVertex.setType(VertexType.source); - pipelineVertex.setAffinity(AffinityLevel.worker); - break; - } - case ContinueGraphMaterialize: - pipelineVertex.setType(VertexType.inc_process); - pipelineVertex.setAffinity(AffinityLevel.worker); - MaterializedIncGraph pGraphMaterialize = (MaterializedIncGraph) stream; - - Stream vertexInput = pGraphMaterialize.getInput(); - Stream edgeInput = (Stream) pGraphMaterialize.getEdges(); - Preconditions.checkArgument(vertexInput != null && edgeInput != null, - "input vertex and edge stream must be not null"); - - PipelineEdge vertexEdge = new PipelineEdge(this.edgeIdGenerator++, - vertexInput.getId(), - stream.getId(), vertexInput.getPartition(), vertexInput.getEncoder()); - vertexEdge.setEdgeName(GraphRecordNames.Vertex.name()); - this.pipelineGraph.addEdge(vertexEdge); - PipelineEdge edgeEdge = new PipelineEdge(this.edgeIdGenerator++, - edgeInput.getId(), - stream.getId(), edgeInput.getPartition(), edgeInput.getEncoder()); - edgeEdge.setEdgeName(GraphRecordNames.Edge.name()); - this.pipelineGraph.addEdge(edgeEdge); - - visitNode(vertexInput); - visitNode(edgeInput); - break; - case ContinueGraphCompute: { - pipelineVertex.setType(VertexType.inc_iterator); - pipelineVertex.setAffinity(AffinityLevel.worker); - ComputeIncGraph pGraphCompute = (ComputeIncGraph) stream; - GraphExecAlgo computeType = pGraphCompute.getGraphComputeType(); - switch (computeType) { - case VertexCentric: - pipelineVertex.setType(VertexType.inc_vertex_centric); - break; - default: - throw new GeaflowRuntimeException( - "not support graph compute type, " + computeType); - } - - pipelineVertex.setIterations(pGraphCompute.getMaxIterations()); - Stream vertexStreamInput = pGraphCompute.getInput(); - Stream edgeStreamInput = (Stream) pGraphCompute.getEdges(); - Preconditions.checkArgument( - vertexStreamInput != null && edgeStreamInput != null, - "input vertex and edge stream must be not null"); - - PipelineEdge vertexInputEdge = new PipelineEdge(this.edgeIdGenerator++, - vertexStreamInput.getId(), - stream.getId(), vertexStreamInput.getPartition(), - vertexStreamInput.getEncoder()); - vertexInputEdge.setEdgeName(GraphRecordNames.Vertex.name()); - this.pipelineGraph.addEdge(vertexInputEdge); - PipelineEdge edgeInputEdge = new PipelineEdge(this.edgeIdGenerator++, - edgeStreamInput.getId(), - stream.getId(), edgeStreamInput.getPartition(), - edgeStreamInput.getEncoder()); - edgeInputEdge.setEdgeName(GraphRecordNames.Edge.name()); - this.pipelineGraph.addEdge(edgeInputEdge); - - // iteration loop edge - PipelineEdge iterationEdge = buildIterationEdge(vId, - pGraphCompute.getMsgEncoder()); - this.pipelineGraph.addEdge(iterationEdge); - - buildIterationAggVertexAndEdge(pipelineVertex); - - visitNode(vertexStreamInput); - visitNode(edgeStreamInput); - break; - } - case WindowGraphCompute: { - pipelineVertex.setType(VertexType.iterator); - pipelineVertex.setAffinity(AffinityLevel.worker); - ComputeWindowGraph pGraphCompute = (ComputeWindowGraph) stream; - GraphExecAlgo computeType = pGraphCompute.getGraphComputeType(); - switch (computeType) { - case VertexCentric: - pipelineVertex.setType(VertexType.vertex_centric); - break; - default: - throw new GeaflowRuntimeException( - "not support graph compute type, " + computeType); - } - - pipelineVertex.setIterations(pGraphCompute.getMaxIterations()); - Stream vertexStreamInput = pGraphCompute.getInput(); - Stream edgeStreamInput = (Stream) pGraphCompute.getEdges(); - Preconditions.checkArgument( - vertexStreamInput != null && edgeStreamInput != null, - "input vertex and edge stream must be not null"); - - PipelineEdge vertexInputEdge = new PipelineEdge(this.edgeIdGenerator++, - vertexStreamInput.getId(), - stream.getId(), vertexStreamInput.getPartition(), - vertexStreamInput.getEncoder()); - vertexInputEdge.setEdgeName(GraphRecordNames.Vertex.name()); - this.pipelineGraph.addEdge(vertexInputEdge); - PipelineEdge edgeInputEdge = new PipelineEdge(this.edgeIdGenerator++, - edgeStreamInput.getId(), - stream.getId(), edgeStreamInput.getPartition(), - edgeStreamInput.getEncoder()); - edgeInputEdge.setEdgeName(GraphRecordNames.Edge.name()); - this.pipelineGraph.addEdge(edgeInputEdge); - - // iteration loop edge - PipelineEdge iterationEdge = buildIterationEdge(vId, - pGraphCompute.getMsgEncoder()); - this.pipelineGraph.addEdge(iterationEdge); - - buildIterationAggVertexAndEdge(pipelineVertex); - - visitNode(vertexStreamInput); - visitNode(edgeStreamInput); - break; - } - case WindowGraphTraversal: { - pipelineVertex.setType(VertexType.iterator); - pipelineVertex.setAffinity(AffinityLevel.worker); - TraversalWindowGraph windowGraph = (TraversalWindowGraph) stream; - GraphExecAlgo traversalType = windowGraph.getGraphTraversalType(); - switch (traversalType) { - case VertexCentric: - pipelineVertex.setType(VertexType.vertex_centric); - break; - default: - throw new GeaflowRuntimeException( - "not support graph traversal type, " + traversalType); - } - - pipelineVertex.setIterations(windowGraph.getMaxIterations()); - Stream vertexStreamInput = windowGraph.getInput(); - Stream edgeStreamInput = (Stream) windowGraph.getEdges(); - Preconditions.checkArgument( - vertexStreamInput != null && edgeStreamInput != null, - "input vertex and edge stream must be not null"); - - PipelineEdge vertexInputEdge = new PipelineEdge(this.edgeIdGenerator++, - vertexStreamInput.getId(), - stream.getId(), vertexStreamInput.getPartition(), - vertexStreamInput.getEncoder()); - vertexInputEdge.setEdgeName(GraphRecordNames.Vertex.name()); - this.pipelineGraph.addEdge(vertexInputEdge); - PipelineEdge edgeInputEdge = new PipelineEdge(this.edgeIdGenerator++, - edgeStreamInput.getId(), - stream.getId(), edgeStreamInput.getPartition(), - edgeStreamInput.getEncoder()); - edgeInputEdge.setEdgeName(GraphRecordNames.Edge.name()); - this.pipelineGraph.addEdge(edgeInputEdge); - - // Add request input. - if (windowGraph.getRequestStream() != null) { - Stream requestStreamInput = (Stream) windowGraph.getRequestStream(); - PipelineEdge requestInputEdge = new PipelineEdge(this.edgeIdGenerator++, - requestStreamInput.getId(), - stream.getId(), requestStreamInput.getPartition(), - requestStreamInput.getEncoder()); - requestInputEdge.setEdgeName(GraphRecordNames.Request.name()); - this.pipelineGraph.addEdge(requestInputEdge); - visitNode(requestStreamInput); - } - - // iteration loop edge - PipelineEdge iterationEdge = buildIterationEdge(vId, - windowGraph.getMsgEncoder()); - this.pipelineGraph.addEdge(iterationEdge); - - buildIterationAggVertexAndEdge(pipelineVertex); - - visitNode(vertexStreamInput); - visitNode(edgeStreamInput); - break; - } - case ContinueGraphTraversal: { - pipelineVertex.setType(VertexType.inc_iterator); - pipelineVertex.setAffinity(AffinityLevel.worker); - TraversalIncGraph windowGraph = (TraversalIncGraph) stream; - GraphExecAlgo traversalType = windowGraph.getGraphTraversalType(); - switch (traversalType) { - case VertexCentric: - pipelineVertex.setType(VertexType.inc_vertex_centric); - break; - default: - throw new GeaflowRuntimeException( - "not support graph traversal type, " + traversalType); - } - - pipelineVertex.setIterations(windowGraph.getMaxIterations()); - Stream vertexStreamInput = windowGraph.getInput(); - Stream edgeStreamInput = (Stream) windowGraph.getEdges(); - Preconditions.checkArgument( - vertexStreamInput != null && edgeStreamInput != null, - "input vertex and edge stream must be not null"); - - // Add vertex input. - PipelineEdge vertexInputEdge = new PipelineEdge(this.edgeIdGenerator++, - vertexStreamInput.getId(), - stream.getId(), vertexStreamInput.getPartition(), - vertexStreamInput.getEncoder()); - vertexInputEdge.setEdgeName(GraphRecordNames.Vertex.name()); - this.pipelineGraph.addEdge(vertexInputEdge); - - // Add edge input. - PipelineEdge edgeInputEdge = new PipelineEdge(this.edgeIdGenerator++, - edgeStreamInput.getId(), - stream.getId(), edgeStreamInput.getPartition(), - edgeStreamInput.getEncoder()); - edgeInputEdge.setEdgeName(GraphRecordNames.Edge.name()); - this.pipelineGraph.addEdge(edgeInputEdge); - - // Add request input. - if (windowGraph.getRequestStream() != null) { - Stream requestStreamInput = (Stream) windowGraph.getRequestStream(); - PipelineEdge requestInputEdge = new PipelineEdge(this.edgeIdGenerator++, - requestStreamInput.getId(), - stream.getId(), requestStreamInput.getPartition(), - requestStreamInput.getEncoder()); - requestInputEdge.setEdgeName(GraphRecordNames.Request.name()); - this.pipelineGraph.addEdge(requestInputEdge); - visitNode(requestStreamInput); - } - - // Add iteration loop edge - PipelineEdge iterationEdge = buildIterationEdge(vId, - windowGraph.getMsgEncoder()); - this.pipelineGraph.addEdge(iterationEdge); - - buildIterationAggVertexAndEdge(pipelineVertex); - - visitNode(vertexStreamInput); - visitNode(edgeStreamInput); - break; - } - case StreamTransform: { - pipelineVertex.setType(VertexType.process); - Stream inputStream = stream.getInput(); - Preconditions.checkArgument(inputStream != null, - "input stream must be not null"); - - PipelineEdge pipelineEdge = new PipelineEdge(this.edgeIdGenerator++, - inputStream.getId(), stream.getId(), - inputStream.getPartition(), inputStream.getEncoder()); - this.pipelineGraph.addEdge(pipelineEdge); - - visitNode(inputStream); - break; - } - case ContinueStreamCompute: { - pipelineVertex.setType(VertexType.inc_process); - pipelineVertex.setAffinity(AffinityLevel.worker); - Stream inputStream = stream.getInput(); - Preconditions.checkArgument(inputStream != null, - "input stream must be not null"); - - PipelineEdge pipelineEdge = new PipelineEdge(this.edgeIdGenerator++, - inputStream.getId(), stream.getId(), - inputStream.getPartition(), inputStream.getEncoder()); - this.pipelineGraph.addEdge(pipelineEdge); - - visitNode(inputStream); - break; - } - case StreamUnion: { - pipelineVertex.setType(VertexType.union); - WindowUnionStream unionStream = (WindowUnionStream) stream; - - Stream mainInput = stream.getInput(); - PipelineEdge mainEdge = new PipelineEdge(this.edgeIdGenerator++, - mainInput.getId(), unionStream.getId(), - mainInput.getPartition(), unionStream.getEncoder()); - mainEdge.setStreamOrdinal(0); - this.pipelineGraph.addEdge(mainEdge); - visitNode(mainInput); - - List otherInputs = unionStream.getUnionWindowDataStreamList(); - for (int index = 0; index < otherInputs.size(); index++) { - Stream otherInput = otherInputs.get(index); - PipelineEdge rightEdge = new PipelineEdge(this.edgeIdGenerator++, - otherInput.getId(), - unionStream.getId(), otherInput.getPartition(), - otherInput.getEncoder()); - rightEdge.setStreamOrdinal(index + 1); - this.pipelineGraph.addEdge(rightEdge); - visitNode(otherInput); - } - break; - } - default: - throw new GeaflowRuntimeException( - "Not supported transform type: " + stream.getTransformType()); + TraversalIncGraph windowGraph = (TraversalIncGraph) stream; + GraphExecAlgo traversalType = windowGraph.getGraphTraversalType(); + switch (traversalType) { + case VertexCentric: + pipelineVertex.setType(VertexType.inc_vertex_centric); + break; + default: + throw new GeaflowRuntimeException( + "not support graph traversal type, " + traversalType); } - this.pipelineGraph.addVertex(pipelineVertex); - } - } - private PipelineEdge buildIterationEdge(int vid, IEncoder encoder) { - PipelineEdge iterationEdge = new PipelineEdge(this.edgeIdGenerator++, vid, vid, - new KeyPartitioner<>(vid), encoder, OutputType.LOOP); - iterationEdge.setEdgeName(GraphRecordNames.Message.name()); - return iterationEdge; - } + pipelineVertex.setIterations(windowGraph.getMaxIterations()); + Stream vertexStreamInput = windowGraph.getInput(); + Stream edgeStreamInput = (Stream) windowGraph.getEdges(); + Preconditions.checkArgument( + vertexStreamInput != null && edgeStreamInput != null, + "input vertex and edge stream must be not null"); + + // Add vertex input. + PipelineEdge vertexInputEdge = + new PipelineEdge( + this.edgeIdGenerator++, + vertexStreamInput.getId(), + stream.getId(), + vertexStreamInput.getPartition(), + vertexStreamInput.getEncoder()); + vertexInputEdge.setEdgeName(GraphRecordNames.Vertex.name()); + this.pipelineGraph.addEdge(vertexInputEdge); + + // Add edge input. + PipelineEdge edgeInputEdge = + new PipelineEdge( + this.edgeIdGenerator++, + edgeStreamInput.getId(), + stream.getId(), + edgeStreamInput.getPartition(), + edgeStreamInput.getEncoder()); + edgeInputEdge.setEdgeName(GraphRecordNames.Edge.name()); + this.pipelineGraph.addEdge(edgeInputEdge); + + // Add request input. + if (windowGraph.getRequestStream() != null) { + Stream requestStreamInput = (Stream) windowGraph.getRequestStream(); + PipelineEdge requestInputEdge = + new PipelineEdge( + this.edgeIdGenerator++, + requestStreamInput.getId(), + stream.getId(), + requestStreamInput.getPartition(), + requestStreamInput.getEncoder()); + requestInputEdge.setEdgeName(GraphRecordNames.Request.name()); + this.pipelineGraph.addEdge(requestInputEdge); + visitNode(requestStreamInput); + } - private void buildIterationAggVertexAndEdge(PipelineVertex iterationVertex) { - if (iterationVertex.getOperator() instanceof IGraphVertexCentricAggOp) { - PipelineVertex aggVertex = new PipelineVertex(ITERATION_AGG_VERTEX_ID, - iterationVertex.getOperator(), 0); - aggVertex.setType(VertexType.iteration_aggregation); - this.pipelineGraph.addVertex(aggVertex); - - PipelineEdge inputEdge = new PipelineEdge(this.edgeIdGenerator++, - iterationVertex.getVertexId(), ITERATION_AGG_VERTEX_ID, - new KeyPartitioner<>(iterationVertex.getVertexId()), null, OutputType.RESPONSE); - inputEdge.setEdgeName(GraphRecordNames.Aggregate.name()); - this.pipelineGraph.addEdge(inputEdge); - - PipelineEdge outputEdge = new PipelineEdge(this.edgeIdGenerator++, - ITERATION_AGG_VERTEX_ID, iterationVertex.getVertexId(), - new KeyPartitioner<>(ITERATION_AGG_VERTEX_ID), null, OutputType.RESPONSE); - outputEdge.setEdgeName(GraphRecordNames.Aggregate.name()); - this.pipelineGraph.addEdge(outputEdge); - } + // Add iteration loop edge + PipelineEdge iterationEdge = buildIterationEdge(vId, windowGraph.getMsgEncoder()); + this.pipelineGraph.addEdge(iterationEdge); + + buildIterationAggVertexAndEdge(pipelineVertex); + + visitNode(vertexStreamInput); + visitNode(edgeStreamInput); + break; + } + case StreamTransform: + { + pipelineVertex.setType(VertexType.process); + Stream inputStream = stream.getInput(); + Preconditions.checkArgument(inputStream != null, "input stream must be not null"); + + PipelineEdge pipelineEdge = + new PipelineEdge( + this.edgeIdGenerator++, + inputStream.getId(), + stream.getId(), + inputStream.getPartition(), + inputStream.getEncoder()); + this.pipelineGraph.addEdge(pipelineEdge); + + visitNode(inputStream); + break; + } + case ContinueStreamCompute: + { + pipelineVertex.setType(VertexType.inc_process); + pipelineVertex.setAffinity(AffinityLevel.worker); + Stream inputStream = stream.getInput(); + Preconditions.checkArgument(inputStream != null, "input stream must be not null"); + + PipelineEdge pipelineEdge = + new PipelineEdge( + this.edgeIdGenerator++, + inputStream.getId(), + stream.getId(), + inputStream.getPartition(), + inputStream.getEncoder()); + this.pipelineGraph.addEdge(pipelineEdge); + + visitNode(inputStream); + break; + } + case StreamUnion: + { + pipelineVertex.setType(VertexType.union); + WindowUnionStream unionStream = (WindowUnionStream) stream; + + Stream mainInput = stream.getInput(); + PipelineEdge mainEdge = + new PipelineEdge( + this.edgeIdGenerator++, + mainInput.getId(), + unionStream.getId(), + mainInput.getPartition(), + unionStream.getEncoder()); + mainEdge.setStreamOrdinal(0); + this.pipelineGraph.addEdge(mainEdge); + visitNode(mainInput); + + List otherInputs = unionStream.getUnionWindowDataStreamList(); + for (int index = 0; index < otherInputs.size(); index++) { + Stream otherInput = otherInputs.get(index); + PipelineEdge rightEdge = + new PipelineEdge( + this.edgeIdGenerator++, + otherInput.getId(), + unionStream.getId(), + otherInput.getPartition(), + otherInput.getEncoder()); + rightEdge.setStreamOrdinal(index + 1); + this.pipelineGraph.addEdge(rightEdge); + visitNode(otherInput); + } + break; + } + default: + throw new GeaflowRuntimeException( + "Not supported transform type: " + stream.getTransformType()); + } + this.pipelineGraph.addVertex(pipelineVertex); } - - /** - * Enforce union and chain optimize. - */ - private void optimizePipelinePlan(Configuration pipelineConfig) { - if (pipelineConfig.getBoolean(ENABLE_EXTRA_OPTIMIZE)) { - // Union Optimization. - boolean isExtraOptimizeSink = pipelineConfig.getBoolean(ENABLE_EXTRA_OPTIMIZE_SINK); - new UnionOptimizer(isExtraOptimizeSink).optimizePlan(pipelineGraph); - LOGGER.info("union optimize: {}", - new PlanGraphVisualization(pipelineGraph).getGraphviz()); - } - new PipelineGraphOptimizer().optimizePipelineGraph(pipelineGraph); + } + + private PipelineEdge buildIterationEdge(int vid, IEncoder encoder) { + PipelineEdge iterationEdge = + new PipelineEdge( + this.edgeIdGenerator++, vid, vid, new KeyPartitioner<>(vid), encoder, OutputType.LOOP); + iterationEdge.setEdgeName(GraphRecordNames.Message.name()); + return iterationEdge; + } + + private void buildIterationAggVertexAndEdge(PipelineVertex iterationVertex) { + if (iterationVertex.getOperator() instanceof IGraphVertexCentricAggOp) { + PipelineVertex aggVertex = + new PipelineVertex(ITERATION_AGG_VERTEX_ID, iterationVertex.getOperator(), 0); + aggVertex.setType(VertexType.iteration_aggregation); + this.pipelineGraph.addVertex(aggVertex); + + PipelineEdge inputEdge = + new PipelineEdge( + this.edgeIdGenerator++, + iterationVertex.getVertexId(), + ITERATION_AGG_VERTEX_ID, + new KeyPartitioner<>(iterationVertex.getVertexId()), + null, + OutputType.RESPONSE); + inputEdge.setEdgeName(GraphRecordNames.Aggregate.name()); + this.pipelineGraph.addEdge(inputEdge); + + PipelineEdge outputEdge = + new PipelineEdge( + this.edgeIdGenerator++, + ITERATION_AGG_VERTEX_ID, + iterationVertex.getVertexId(), + new KeyPartitioner<>(ITERATION_AGG_VERTEX_ID), + null, + OutputType.RESPONSE); + outputEdge.setEdgeName(GraphRecordNames.Aggregate.name()); + this.pipelineGraph.addEdge(outputEdge); } - + } + + /** Enforce union and chain optimize. */ + private void optimizePipelinePlan(Configuration pipelineConfig) { + if (pipelineConfig.getBoolean(ENABLE_EXTRA_OPTIMIZE)) { + // Union Optimization. + boolean isExtraOptimizeSink = pipelineConfig.getBoolean(ENABLE_EXTRA_OPTIMIZE_SINK); + new UnionOptimizer(isExtraOptimizeSink).optimizePlan(pipelineGraph); + LOGGER.info("union optimize: {}", new PlanGraphVisualization(pipelineGraph).getGraphviz()); + } + new PipelineGraphOptimizer().optimizePipelineGraph(pipelineGraph); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/AffinityLevel.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/AffinityLevel.java index aa851f8eb..c23f559dc 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/AffinityLevel.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/AffinityLevel.java @@ -21,8 +21,6 @@ public enum AffinityLevel { - /** - * An affinity level expects which to be scheduled to the same worker. - */ - worker, + /** An affinity level expects which to be scheduled to the same worker. */ + worker, } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/PipelineEdge.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/PipelineEdge.java index 8fa1f8e58..92f3ed13e 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/PipelineEdge.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/PipelineEdge.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.util.Objects; + import org.apache.geaflow.api.partition.kv.RandomPartition; import org.apache.geaflow.common.encoder.IEncoder; import org.apache.geaflow.partitioner.IPartitioner; @@ -29,149 +30,154 @@ public class PipelineEdge implements Serializable { - private int edgeId; - private int srcId; - private int targetId; - private int streamOrdinal; - private IPartitioner partition; - private PartitionType partitionType; - private String edgeName; - private OutputType type; - private IEncoder encoder; - - public PipelineEdge(int edgeId, int srcId, int targetId, IPartitioner partition, IEncoder encoder) { - this(edgeId, srcId, targetId, partition, 0, encoder); - } - - public PipelineEdge(int edgeId, int srcId, int targetId, IPartitioner partition, - IEncoder encoder, OutputType type) { - this(edgeId, srcId, targetId, partition, 0, encoder); - this.type = type; - } - - public PipelineEdge(int edgeId, int srcId, int targetId, IPartitioner partition, - int streamOrdinal, IEncoder encoder) { - this.edgeId = edgeId; - this.srcId = srcId; - this.targetId = targetId; - this.partition = partition; - this.streamOrdinal = streamOrdinal; - this.partitionType = partition.getPartitionType(); - this.encoder = encoder; - } - - public PartitionType getPartitionType() { - return partitionType; - } - - public void setPartitionType(PartitionType partitionType) { - this.partitionType = partitionType; - } - - public int getEdgeId() { - return edgeId; - } - - public void setEdgeId(int edgeId) { - this.edgeId = edgeId; - } - - public int getSrcId() { - return srcId; - } - - public void setSrcId(int srcId) { - this.srcId = srcId; - } - - public int getTargetId() { - return targetId; - } - - public void setTargetId(int targetId) { - this.targetId = targetId; - } - - public IPartitioner getPartition() { - return partition; - } - - public void setPartition(IPartitioner partition) { - this.partition = partition; - } - - public int getStreamOrdinal() { - return streamOrdinal; - } - - public void setStreamOrdinal(int streamOrdinal) { - this.streamOrdinal = streamOrdinal; - } - - public IEncoder getEncoder() { - return this.encoder; - } - - public void setEncoder(IEncoder encoder) { - this.encoder = encoder; - } - - public String getEdgeName() { - if (edgeName != null) { - return edgeName; - } - - String partitionName = ""; - if (this.partition.getPartition() != null && !(this.partition - .getPartition() instanceof RandomPartition)) { - partitionName = - "-partitionFunc-" + this.partition.getPartition().getClass().getSimpleName(); - } - - if (srcId != targetId) { - return this.edgeId + "-stream-from" + srcId + "-to" + targetId + partitionName; - } else { - return this.edgeId + "-stream-from" + srcId + "-IteratorStream"; - } - } - - public OutputType getType() { - return type; - } - - public void setEdgeName(String edgeName) { - this.edgeName = edgeName; - } - - @Override - public final int hashCode() { - return Objects.hash(this.edgeId, this.srcId, this.targetId, edgeName); - } - - @Override - public boolean equals(Object obj) { - if (obj instanceof PipelineEdge) { - PipelineEdge other = (PipelineEdge) obj; - return other.getSrcId() == srcId && other.getTargetId() == targetId - && other.getEdgeId() == edgeId && other.getPartitionType() == partitionType - && Objects.equals(getEdgeName(), other.getEdgeName()); - } - return false; - } - - public enum JoinStream { - /** - * left. - */ - left, - /** - * right. - */ - right, - /** - * none. - */ - none; - } + private int edgeId; + private int srcId; + private int targetId; + private int streamOrdinal; + private IPartitioner partition; + private PartitionType partitionType; + private String edgeName; + private OutputType type; + private IEncoder encoder; + + public PipelineEdge( + int edgeId, int srcId, int targetId, IPartitioner partition, IEncoder encoder) { + this(edgeId, srcId, targetId, partition, 0, encoder); + } + + public PipelineEdge( + int edgeId, + int srcId, + int targetId, + IPartitioner partition, + IEncoder encoder, + OutputType type) { + this(edgeId, srcId, targetId, partition, 0, encoder); + this.type = type; + } + + public PipelineEdge( + int edgeId, + int srcId, + int targetId, + IPartitioner partition, + int streamOrdinal, + IEncoder encoder) { + this.edgeId = edgeId; + this.srcId = srcId; + this.targetId = targetId; + this.partition = partition; + this.streamOrdinal = streamOrdinal; + this.partitionType = partition.getPartitionType(); + this.encoder = encoder; + } + + public PartitionType getPartitionType() { + return partitionType; + } + + public void setPartitionType(PartitionType partitionType) { + this.partitionType = partitionType; + } + + public int getEdgeId() { + return edgeId; + } + + public void setEdgeId(int edgeId) { + this.edgeId = edgeId; + } + + public int getSrcId() { + return srcId; + } + + public void setSrcId(int srcId) { + this.srcId = srcId; + } + + public int getTargetId() { + return targetId; + } + + public void setTargetId(int targetId) { + this.targetId = targetId; + } + + public IPartitioner getPartition() { + return partition; + } + + public void setPartition(IPartitioner partition) { + this.partition = partition; + } + + public int getStreamOrdinal() { + return streamOrdinal; + } + + public void setStreamOrdinal(int streamOrdinal) { + this.streamOrdinal = streamOrdinal; + } + + public IEncoder getEncoder() { + return this.encoder; + } + + public void setEncoder(IEncoder encoder) { + this.encoder = encoder; + } + + public String getEdgeName() { + if (edgeName != null) { + return edgeName; + } + + String partitionName = ""; + if (this.partition.getPartition() != null + && !(this.partition.getPartition() instanceof RandomPartition)) { + partitionName = "-partitionFunc-" + this.partition.getPartition().getClass().getSimpleName(); + } + + if (srcId != targetId) { + return this.edgeId + "-stream-from" + srcId + "-to" + targetId + partitionName; + } else { + return this.edgeId + "-stream-from" + srcId + "-IteratorStream"; + } + } + + public OutputType getType() { + return type; + } + + public void setEdgeName(String edgeName) { + this.edgeName = edgeName; + } + + @Override + public final int hashCode() { + return Objects.hash(this.edgeId, this.srcId, this.targetId, edgeName); + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof PipelineEdge) { + PipelineEdge other = (PipelineEdge) obj; + return other.getSrcId() == srcId + && other.getTargetId() == targetId + && other.getEdgeId() == edgeId + && other.getPartitionType() == partitionType + && Objects.equals(getEdgeName(), other.getEdgeName()); + } + return false; + } + + public enum JoinStream { + /** left. */ + left, + /** right. */ + right, + /** none. */ + none; + } } - diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/PipelineGraph.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/PipelineGraph.java index adde0e5df..7af3f69d3 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/PipelineGraph.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/PipelineGraph.java @@ -19,7 +19,6 @@ package org.apache.geaflow.plan.graph; -import com.google.common.collect.Lists; import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; @@ -31,142 +30,149 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class PipelineGraph implements Serializable { - - private static final Logger LOGGER = LoggerFactory.getLogger(PipelineGraph.class); - - private final Map vertexMap; - private final Map edgeMap; - private final Map> vertexOutputEdgeIds; - private final Map> vertexInputEdgeIds; - - public PipelineGraph() { - this.vertexMap = new HashMap<>(); - this.edgeMap = new LinkedHashMap<>(); - this.vertexOutputEdgeIds = new HashMap<>(); - this.vertexInputEdgeIds = new HashMap<>(); - } - - public void addVertex(PipelineVertex pipelineVertex) { - LOGGER.info("add vertex:{} {}", pipelineVertex.getVertexId(), pipelineVertex.getName()); - this.vertexMap.put(pipelineVertex.getVertexId(), pipelineVertex); - } - - public void setPipelineVertices(Set pipelineVertices) { - this.vertexMap.clear(); - pipelineVertices.stream().forEach(pipelineVertex -> addVertex(pipelineVertex)); - } - - public void addEdge(PipelineEdge pipelineEdge) { - LOGGER.info("add edgeId:{}, edgeName:{}, srcId:{}, targetId:{}", pipelineEdge.getEdgeId(), - pipelineEdge.getEdgeName(), pipelineEdge.getSrcId(), pipelineEdge.getTargetId()); - this.edgeMap.put(pipelineEdge.getEdgeId(), pipelineEdge); - - int edgeId = pipelineEdge.getEdgeId(); - int srcId = pipelineEdge.getSrcId(); - int tarId = pipelineEdge.getTargetId(); - if (vertexOutputEdgeIds.containsKey(srcId)) { - vertexOutputEdgeIds.get(srcId).add(edgeId); - } else { - vertexOutputEdgeIds.put(srcId, Lists.newArrayList(edgeId)); - } - - if (vertexInputEdgeIds.containsKey(tarId)) { - vertexInputEdgeIds.get(tarId).add(edgeId); - } else { - vertexInputEdgeIds.put(tarId, Lists.newArrayList(edgeId)); - } - } - - public void setPipelineEdges(Set pipelineEdges) { - this.edgeMap.clear(); - this.vertexInputEdgeIds.clear(); - this.vertexOutputEdgeIds.clear(); - - pipelineEdges.stream().forEach(pipelineEdge -> addEdge(pipelineEdge)); - } - - public Collection getPipelineEdgeList() { - return this.edgeMap.values(); - } - - public Map getVertexMap() { - return vertexMap; - } - - public Collection getPipelineVertices() { - return vertexMap.values(); - } - - public Set getVertexOutEdges(int vertexId) { - return getVertexEdgesByIds(this.vertexOutputEdgeIds.get(vertexId)); - } - - public Set getVertexInputEdges(int vertexId) { - return getVertexEdgesByIds(this.vertexInputEdgeIds.get(vertexId)); - } - - private Set getVertexEdgesByIds(List edgeIds) { - if (edgeIds == null) { - return new HashSet<>(); - } - LinkedHashSet edges = new LinkedHashSet<>(); - edgeIds.stream().map(id -> edgeMap.get(id)).forEach(edges::add); - return edges; - } - - public Map> getVertexInputEdges() { - Map> inputEdges = new HashMap<>(vertexMap.size()); - for (int key : vertexMap.keySet()) { - inputEdges.put(key, new HashSet<>()); - } - for (PipelineEdge executeEdge : getPipelineEdgeList()) { - inputEdges.get(executeEdge.getTargetId()).add(executeEdge); - - } - return inputEdges; - } - - public Map> getVertexOutputEdges() { - Map> outputEdges = new HashMap<>(vertexMap.size()); - for (int key : vertexMap.keySet()) { - outputEdges.put(key, new HashSet<>()); - } - for (PipelineEdge executeEdge : getPipelineEdgeList()) { - outputEdges.get(executeEdge.getSrcId()).add(executeEdge); - } - return outputEdges; - } - - public List getSourceVertices() { - List sourceVertices = new ArrayList<>(); - for (Map.Entry entry : vertexMap.entrySet()) { - if (entry.getValue().getType() == VertexType.source) { - sourceVertices.add(entry.getValue()); - } - } - return sourceVertices; - } +import com.google.common.collect.Lists; - public List getVertexInputVertexIds(int vertexId) { - if (vertexInputEdgeIds.get(vertexId) != null) { - return vertexInputEdgeIds.get(vertexId).stream() - .map(edgeId -> edgeMap.get(edgeId).getSrcId()).collect(Collectors.toList()); - } else { - return new ArrayList<>(); - } - } +public class PipelineGraph implements Serializable { - public List getVertexOutputVertexIds(int vertexId) { - if (vertexOutputEdgeIds.get(vertexId) != null) { - return vertexOutputEdgeIds.get(vertexId).stream() - .map(edgeId -> edgeMap.get(edgeId).getTargetId()).collect(Collectors.toList()); - } else { - return new ArrayList<>(); - } - } + private static final Logger LOGGER = LoggerFactory.getLogger(PipelineGraph.class); + + private final Map vertexMap; + private final Map edgeMap; + private final Map> vertexOutputEdgeIds; + private final Map> vertexInputEdgeIds; + + public PipelineGraph() { + this.vertexMap = new HashMap<>(); + this.edgeMap = new LinkedHashMap<>(); + this.vertexOutputEdgeIds = new HashMap<>(); + this.vertexInputEdgeIds = new HashMap<>(); + } + + public void addVertex(PipelineVertex pipelineVertex) { + LOGGER.info("add vertex:{} {}", pipelineVertex.getVertexId(), pipelineVertex.getName()); + this.vertexMap.put(pipelineVertex.getVertexId(), pipelineVertex); + } + + public void setPipelineVertices(Set pipelineVertices) { + this.vertexMap.clear(); + pipelineVertices.stream().forEach(pipelineVertex -> addVertex(pipelineVertex)); + } + + public void addEdge(PipelineEdge pipelineEdge) { + LOGGER.info( + "add edgeId:{}, edgeName:{}, srcId:{}, targetId:{}", + pipelineEdge.getEdgeId(), + pipelineEdge.getEdgeName(), + pipelineEdge.getSrcId(), + pipelineEdge.getTargetId()); + this.edgeMap.put(pipelineEdge.getEdgeId(), pipelineEdge); + + int edgeId = pipelineEdge.getEdgeId(); + int srcId = pipelineEdge.getSrcId(); + int tarId = pipelineEdge.getTargetId(); + if (vertexOutputEdgeIds.containsKey(srcId)) { + vertexOutputEdgeIds.get(srcId).add(edgeId); + } else { + vertexOutputEdgeIds.put(srcId, Lists.newArrayList(edgeId)); + } + + if (vertexInputEdgeIds.containsKey(tarId)) { + vertexInputEdgeIds.get(tarId).add(edgeId); + } else { + vertexInputEdgeIds.put(tarId, Lists.newArrayList(edgeId)); + } + } + + public void setPipelineEdges(Set pipelineEdges) { + this.edgeMap.clear(); + this.vertexInputEdgeIds.clear(); + this.vertexOutputEdgeIds.clear(); + + pipelineEdges.stream().forEach(pipelineEdge -> addEdge(pipelineEdge)); + } + + public Collection getPipelineEdgeList() { + return this.edgeMap.values(); + } + + public Map getVertexMap() { + return vertexMap; + } + + public Collection getPipelineVertices() { + return vertexMap.values(); + } + + public Set getVertexOutEdges(int vertexId) { + return getVertexEdgesByIds(this.vertexOutputEdgeIds.get(vertexId)); + } + + public Set getVertexInputEdges(int vertexId) { + return getVertexEdgesByIds(this.vertexInputEdgeIds.get(vertexId)); + } + + private Set getVertexEdgesByIds(List edgeIds) { + if (edgeIds == null) { + return new HashSet<>(); + } + LinkedHashSet edges = new LinkedHashSet<>(); + edgeIds.stream().map(id -> edgeMap.get(id)).forEach(edges::add); + return edges; + } + + public Map> getVertexInputEdges() { + Map> inputEdges = new HashMap<>(vertexMap.size()); + for (int key : vertexMap.keySet()) { + inputEdges.put(key, new HashSet<>()); + } + for (PipelineEdge executeEdge : getPipelineEdgeList()) { + inputEdges.get(executeEdge.getTargetId()).add(executeEdge); + } + return inputEdges; + } + + public Map> getVertexOutputEdges() { + Map> outputEdges = new HashMap<>(vertexMap.size()); + for (int key : vertexMap.keySet()) { + outputEdges.put(key, new HashSet<>()); + } + for (PipelineEdge executeEdge : getPipelineEdgeList()) { + outputEdges.get(executeEdge.getSrcId()).add(executeEdge); + } + return outputEdges; + } + + public List getSourceVertices() { + List sourceVertices = new ArrayList<>(); + for (Map.Entry entry : vertexMap.entrySet()) { + if (entry.getValue().getType() == VertexType.source) { + sourceVertices.add(entry.getValue()); + } + } + return sourceVertices; + } + + public List getVertexInputVertexIds(int vertexId) { + if (vertexInputEdgeIds.get(vertexId) != null) { + return vertexInputEdgeIds.get(vertexId).stream() + .map(edgeId -> edgeMap.get(edgeId).getSrcId()) + .collect(Collectors.toList()); + } else { + return new ArrayList<>(); + } + } + + public List getVertexOutputVertexIds(int vertexId) { + if (vertexOutputEdgeIds.get(vertexId) != null) { + return vertexOutputEdgeIds.get(vertexId).stream() + .map(edgeId -> edgeMap.get(edgeId).getTargetId()) + .collect(Collectors.toList()); + } else { + return new ArrayList<>(); + } + } } - diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/PipelineVertex.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/PipelineVertex.java index 8c31ba979..5a947b8c1 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/PipelineVertex.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/PipelineVertex.java @@ -21,142 +21,144 @@ import java.io.Serializable; import java.util.Objects; + import org.apache.geaflow.operator.Operator; import org.apache.geaflow.operator.base.AbstractOperator; public class PipelineVertex implements Serializable { - private int vertexId; - private OP operator; - private int parallelism; - private long iterations; - private VertexType type; - private VertexMode vertexMode; - private AffinityLevel affinity; - private boolean duplication; - private VertexType chainTailType; - - public PipelineVertex(int vertexId, OP operator, int parallelism) { - this.vertexId = vertexId; - this.operator = operator; - this.parallelism = parallelism; - this.iterations = 1; - this.affinity = AffinityLevel.worker; - } - - public PipelineVertex(int vertexId, OP operator, VertexType type, int parallelism) { - this(vertexId, operator, parallelism); - this.type = type; - this.chainTailType = type; - } - - public PipelineVertex(int vertexId, VertexType vertexType, OP operator, VertexMode vertexMode) { - this(vertexId, operator, vertexType, ((AbstractOperator) operator).getOpArgs().getParallelism()); - this.vertexMode = vertexMode; - } - - public boolean isDuplication() { - return duplication; - } - - public void setDuplication() { - this.duplication = true; - } - - public int getParallelism() { - return parallelism; - } - - public void setParallelism(int parallelism) { - this.parallelism = parallelism; - } - - public int getVertexId() { - return vertexId; - } - - public void setVertexId(int vertexId) { - this.vertexId = vertexId; - } - - public OP getOperator() { - return operator; - } - - public void setOperator(OP operator) { - this.operator = operator; - } - - public VertexType getType() { - return type; - } - - public void setType(VertexType type) { - this.type = type; - this.chainTailType = type; - } - - public String getName() { - if (this.operator != null) { - return this.operator.getClass().getSimpleName(); - } - return null; - } - - public VertexMode getVertexMode() { - return vertexMode; - } - - public void setVertexMode(VertexMode vertexMode) { - this.vertexMode = vertexMode; - } - - public String getVertexName() { - return "node-" + this.vertexId; - } - - public long getIterations() { - return iterations; - } - - public void setIterations(long iterations) { - this.iterations = iterations; - } - - public AffinityLevel getAffinity() { - return affinity; - } - - public void setAffinity(AffinityLevel affinity) { - this.affinity = affinity; - } - - public VertexType getChainTailType() { - return chainTailType; - } - - public void setChainTailType(VertexType chainTailType) { - this.chainTailType = chainTailType; - } - - public String getVertexString() { - String operatorStr = operator.toString(); - return String.format("%s, p:%d, %s", getVertexName(), parallelism, operatorStr); - } - - @Override - public int hashCode() { - return Objects.hash(this.getVertexId()); - } - - @Override - public boolean equals(Object obj) { - if (obj instanceof PipelineVertex) { - PipelineVertex other = (PipelineVertex) obj; - if (other.getVertexId() == this.vertexId) { - return true; - } - } - return false; - } + private int vertexId; + private OP operator; + private int parallelism; + private long iterations; + private VertexType type; + private VertexMode vertexMode; + private AffinityLevel affinity; + private boolean duplication; + private VertexType chainTailType; + + public PipelineVertex(int vertexId, OP operator, int parallelism) { + this.vertexId = vertexId; + this.operator = operator; + this.parallelism = parallelism; + this.iterations = 1; + this.affinity = AffinityLevel.worker; + } + + public PipelineVertex(int vertexId, OP operator, VertexType type, int parallelism) { + this(vertexId, operator, parallelism); + this.type = type; + this.chainTailType = type; + } + + public PipelineVertex(int vertexId, VertexType vertexType, OP operator, VertexMode vertexMode) { + this( + vertexId, operator, vertexType, ((AbstractOperator) operator).getOpArgs().getParallelism()); + this.vertexMode = vertexMode; + } + + public boolean isDuplication() { + return duplication; + } + + public void setDuplication() { + this.duplication = true; + } + + public int getParallelism() { + return parallelism; + } + + public void setParallelism(int parallelism) { + this.parallelism = parallelism; + } + + public int getVertexId() { + return vertexId; + } + + public void setVertexId(int vertexId) { + this.vertexId = vertexId; + } + + public OP getOperator() { + return operator; + } + + public void setOperator(OP operator) { + this.operator = operator; + } + + public VertexType getType() { + return type; + } + + public void setType(VertexType type) { + this.type = type; + this.chainTailType = type; + } + + public String getName() { + if (this.operator != null) { + return this.operator.getClass().getSimpleName(); + } + return null; + } + + public VertexMode getVertexMode() { + return vertexMode; + } + + public void setVertexMode(VertexMode vertexMode) { + this.vertexMode = vertexMode; + } + + public String getVertexName() { + return "node-" + this.vertexId; + } + + public long getIterations() { + return iterations; + } + + public void setIterations(long iterations) { + this.iterations = iterations; + } + + public AffinityLevel getAffinity() { + return affinity; + } + + public void setAffinity(AffinityLevel affinity) { + this.affinity = affinity; + } + + public VertexType getChainTailType() { + return chainTailType; + } + + public void setChainTailType(VertexType chainTailType) { + this.chainTailType = chainTailType; + } + + public String getVertexString() { + String operatorStr = operator.toString(); + return String.format("%s, p:%d, %s", getVertexName(), parallelism, operatorStr); + } + + @Override + public int hashCode() { + return Objects.hash(this.getVertexId()); + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof PipelineVertex) { + PipelineVertex other = (PipelineVertex) obj; + if (other.getVertexId() == this.vertexId) { + return true; + } + } + return false; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/VertexMode.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/VertexMode.java index fbfe67338..5a8060ff4 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/VertexMode.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/VertexMode.java @@ -20,12 +20,8 @@ package org.apache.geaflow.plan.graph; public enum VertexMode { - /** - * Append mode. - */ - append, - /** - * Update mode. - */ - update; + /** Append mode. */ + append, + /** Update mode. */ + update; } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/VertexType.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/VertexType.java index dff8698d6..466a083cb 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/VertexType.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/graph/VertexType.java @@ -20,60 +20,32 @@ package org.apache.geaflow.plan.graph; public enum VertexType { - /** - * Source vertex. - */ - source, - /** - * Process vertex. - */ - process, - /** - * Incremental process vertex. - */ - inc_process, - /** - * Combine vertex. - */ - combine, - /** - * Join vertex. - */ - join, - /** - * Union vertex. - */ - union, - /** - * Partition vertex. - */ - partition, - /** - * Vertex centric vertex. - */ - vertex_centric, - /** - * Incremental Vertex centric vertex. - */ - inc_vertex_centric, - /** - * Iterator vertex. - */ - iterator, - /** - * Incremental Iterator vertex. - */ - inc_iterator, - /** - * Sink vertex. - */ - sink, - /** - * Collect vertex. - */ - collect, - /** - * Iteration aggregation vertex. - */ - iteration_aggregation + /** Source vertex. */ + source, + /** Process vertex. */ + process, + /** Incremental process vertex. */ + inc_process, + /** Combine vertex. */ + combine, + /** Join vertex. */ + join, + /** Union vertex. */ + union, + /** Partition vertex. */ + partition, + /** Vertex centric vertex. */ + vertex_centric, + /** Incremental Vertex centric vertex. */ + inc_vertex_centric, + /** Iterator vertex. */ + iterator, + /** Incremental Iterator vertex. */ + inc_iterator, + /** Sink vertex. */ + sink, + /** Collect vertex. */ + collect, + /** Iteration aggregation vertex. */ + iteration_aggregation } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/optimizer/PipelineGraphOptimizer.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/optimizer/PipelineGraphOptimizer.java index ab979332c..ecc80da5b 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/optimizer/PipelineGraphOptimizer.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/optimizer/PipelineGraphOptimizer.java @@ -20,19 +20,20 @@ package org.apache.geaflow.plan.optimizer; import java.io.Serializable; + import org.apache.geaflow.plan.graph.PipelineGraph; import org.apache.geaflow.plan.optimizer.strategy.ChainCombiner; import org.apache.geaflow.plan.optimizer.strategy.SingleWindowGroupRule; public class PipelineGraphOptimizer implements Serializable { - public void optimizePipelineGraph(PipelineGraph pipelineGraph) { - // Enforce chain combiner opt. - ChainCombiner chainCombiner = new ChainCombiner(); - chainCombiner.combineVertex(pipelineGraph); + public void optimizePipelineGraph(PipelineGraph pipelineGraph) { + // Enforce chain combiner opt. + ChainCombiner chainCombiner = new ChainCombiner(); + chainCombiner.combineVertex(pipelineGraph); - // Enforce single window rule. - SingleWindowGroupRule groupRule = new SingleWindowGroupRule(); - groupRule.apply(pipelineGraph); - } + // Enforce single window rule. + SingleWindowGroupRule groupRule = new SingleWindowGroupRule(); + groupRule.apply(pipelineGraph); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/optimizer/UnionOptimizer.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/optimizer/UnionOptimizer.java index fdd8647e6..c88785cfd 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/optimizer/UnionOptimizer.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/optimizer/UnionOptimizer.java @@ -19,8 +19,6 @@ package org.apache.geaflow.plan.optimizer; -import com.google.common.base.Preconditions; -import com.google.common.collect.ArrayListMultimap; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -39,6 +37,7 @@ import java.util.Queue; import java.util.Set; import java.util.stream.Collectors; + import org.apache.commons.io.input.ClassLoaderObjectInputStream; import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.common.errorcode.RuntimeErrors; @@ -57,362 +56,388 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; +import com.google.common.collect.ArrayListMultimap; + public class UnionOptimizer implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(UnionOptimizer.class); + private static final Logger LOGGER = LoggerFactory.getLogger(UnionOptimizer.class); - private Map vertexMap; - private Map> outputEdges; - private Map> inputEdges; + private Map vertexMap; + private Map> outputEdges; + private Map> inputEdges; - private Set visited = new HashSet<>(); + private Set visited = new HashSet<>(); - private Set newVertices = new HashSet<>(); - private Set newEdges = new LinkedHashSet<>(); - private Map newVertexMap = new HashMap<>(); + private Set newVertices = new HashSet<>(); + private Set newEdges = new LinkedHashSet<>(); + private Map newVertexMap = new HashMap<>(); - private boolean needOptimize; + private boolean needOptimize; - private List statelessOperator = new ArrayList<>( - Arrays.asList(FilterOperator.class, MapOperator.class, UnionOperator.class, - KeySelectorOperator.class)); + private List statelessOperator = + new ArrayList<>( + Arrays.asList( + FilterOperator.class, + MapOperator.class, + UnionOperator.class, + KeySelectorOperator.class)); - private PipelineVertex unionVertex; - private int removeUnionNum = 0; + private PipelineVertex unionVertex; + private int removeUnionNum = 0; - public UnionOptimizer(boolean extraOptimizeSink) { - LOGGER.info("extraOptimizeSink {}", extraOptimizeSink); - if (extraOptimizeSink) { - statelessOperator.add(SinkOperator.class); - } + public UnionOptimizer(boolean extraOptimizeSink) { + LOGGER.info("extraOptimizeSink {}", extraOptimizeSink); + if (extraOptimizeSink) { + statelessOperator.add(SinkOperator.class); } - - private void init(PipelineGraph plan) { - this.vertexMap = plan.getVertexMap(); - this.outputEdges = plan.getVertexOutputEdges(); - this.inputEdges = plan.getVertexInputEdges(); - this.needOptimize = false; + } + + private void init(PipelineGraph plan) { + this.vertexMap = plan.getVertexMap(); + this.outputEdges = plan.getVertexOutputEdges(); + this.inputEdges = plan.getVertexInputEdges(); + this.needOptimize = false; + } + + /** + * Union push-up algorithm: 1. DFS push up vertex, 2. Kahn split node, 3. Rearrange. All operators + * are rearranged to ensure algorithm reentrant. + */ + public boolean optimizePlan(PipelineGraph plan) { + init(plan); + + if (!pushUpPartitionFunction(plan)) { + return false; + } + try { + plan.getSourceVertices().forEach(this::dfs); + if (!needOptimize) { + return false; + } + kahn(plan.getSourceVertices(), 0); + } catch (Exception ex) { + LOGGER.warn("Unexpected exception happened while optimizing, thus give up", ex); + return false; } /** - * Union push-up algorithm: 1. DFS push up vertex, 2. Kahn split node, 3. Rearrange. - * All operators are rearranged to ensure algorithm reentrant. + * Check Validation. We use dfs to find union vertex lacking global view, which can lead to + * wrong results. */ - public boolean optimizePlan(PipelineGraph plan) { - init(plan); - - if (!pushUpPartitionFunction(plan)) { - return false; - } - try { - plan.getSourceVertices().forEach(this::dfs); - if (!needOptimize) { - return false; - } - kahn(plan.getSourceVertices(), 0); - } catch (Exception ex) { - LOGGER.warn("Unexpected exception happened while optimizing, thus give up", ex); - return false; - } - - /** - * Check Validation. - * We use dfs to find union vertex lacking global view, which can lead to wrong results. - */ - if (newVertices.size() + removeUnionNum < plan.getVertexMap().size()) { - LOGGER.warn(String.format( - "vertices number %s plus remove union vertices num %s is smaller after optimization %s, " - + "this is not right, thus give up"), - newVertices.size(), removeUnionNum, - plan.getVertexMap().size()); - return false; - } - - plan.setPipelineEdges(newEdges); - plan.setPipelineVertices(newVertices); - return true; + if (newVertices.size() + removeUnionNum < plan.getVertexMap().size()) { + LOGGER.warn( + String.format( + "vertices number %s plus remove union vertices num %s is smaller after optimization" + + " %s, this is not right, thus give up"), + newVertices.size(), + removeUnionNum, + plan.getVertexMap().size()); + return false; } - private boolean pushUpPartitionFunction(PipelineGraph plan) { - for (PipelineVertex vertex : plan.getPipelineVertices()) { - if (vertex.getOperator() instanceof UnionOperator) { - PipelineEdge edge = getOutEdgeFromVertex(vertex); - if (edge != null && !edge.getPartition().getPartitionType().isEnablePushUp()) { - IPartitioner partitionFunction = edge.getPartition(); - edge.setPartitionType(IPartitioner.PartitionType.forward); - for (PipelineEdge inEdge : plan.getVertexInputEdges().get(vertex.getVertexId())) { - if (inEdge.getPartitionType() == IPartitioner.PartitionType.key - || inEdge.getPartition() != null) { - return false; - } - inEdge.setPartition(partitionFunction); - edge.setPartitionType(IPartitioner.PartitionType.key); - } - } + plan.setPipelineEdges(newEdges); + plan.setPipelineVertices(newVertices); + return true; + } + + private boolean pushUpPartitionFunction(PipelineGraph plan) { + for (PipelineVertex vertex : plan.getPipelineVertices()) { + if (vertex.getOperator() instanceof UnionOperator) { + PipelineEdge edge = getOutEdgeFromVertex(vertex); + if (edge != null && !edge.getPartition().getPartitionType().isEnablePushUp()) { + IPartitioner partitionFunction = edge.getPartition(); + edge.setPartitionType(IPartitioner.PartitionType.forward); + for (PipelineEdge inEdge : plan.getVertexInputEdges().get(vertex.getVertexId())) { + if (inEdge.getPartitionType() == IPartitioner.PartitionType.key + || inEdge.getPartition() != null) { + return false; } + inEdge.setPartition(partitionFunction); + edge.setPartitionType(IPartitioner.PartitionType.key); + } } - return true; + } } + return true; + } - /** - * Push up vertex, and try to remove union vertex. - */ - private void dfs(PipelineVertex vertex) { - if (!visited.add(vertex)) { - return; - } - LOGGER.debug("visit vertex {}", vertex); - if (needPushUp(vertex)) { - pushUp(vertex); - vertex = unionVertex; - this.needOptimize = true; - } else if (unionVertex != null) { - tryRemoveVertex(unionVertex); - unionVertex = null; - } - - for (PipelineEdge executeEdge : outputEdges.get(vertex.getVertexId())) { - PipelineVertex nextVertex = vertexMap.get(executeEdge.getTargetId()); - dfs(nextVertex); - } + /** Push up vertex, and try to remove union vertex. */ + private void dfs(PipelineVertex vertex) { + if (!visited.add(vertex)) { + return; } - - private PipelineEdge getOutEdgeFromVertex(PipelineVertex vertex) { - Iterator it = outputEdges.get(vertex.getVertexId()).iterator(); - if (it.hasNext()) { - return it.next(); - } - return null; + LOGGER.debug("visit vertex {}", vertex); + if (needPushUp(vertex)) { + pushUp(vertex); + vertex = unionVertex; + this.needOptimize = true; + } else if (unionVertex != null) { + tryRemoveVertex(unionVertex); + unionVertex = null; } - /** - * Operator push up. - * The outgoing edge has been modified at setTargetId, so only need to focus on the incoming edge. - */ - private void pushUp(PipelineVertex sVertex) { - sVertex.setDuplication(); - if (sVertex.equals(unionVertex)) { - return; - } - LOGGER.info("pushUp vertex {}", sVertex); - - PipelineEdge unionOutEdge = getOutEdgeFromVertex(unionVertex); - Preconditions.checkNotNull(unionOutEdge); - // SOutEdge can be null. - PipelineEdge sOutEdge = getOutEdgeFromVertex(sVertex); - - // Incoming edge of sVertex modify. - inputEdges.get(sVertex.getVertexId()).remove(unionOutEdge); - for (PipelineEdge unionInEdge : inputEdges.get(unionVertex.getVertexId())) { - unionInEdge.setTargetId(sVertex.getVertexId()); - inputEdges.get(sVertex.getVertexId()).add(unionInEdge); - } + for (PipelineEdge executeEdge : outputEdges.get(vertex.getVertexId())) { + PipelineVertex nextVertex = vertexMap.get(executeEdge.getTargetId()); + dfs(nextVertex); + } + } - // Original input edge of sOutEdge modify. - if (sOutEdge != null) { - inputEdges.get(sOutEdge.getTargetId()).remove(sOutEdge); - unionOutEdge.setTargetId(sOutEdge.getTargetId()); - inputEdges.get(sOutEdge.getTargetId()).add(unionOutEdge); - } + private PipelineEdge getOutEdgeFromVertex(PipelineVertex vertex) { + Iterator it = outputEdges.get(vertex.getVertexId()).iterator(); + if (it.hasNext()) { + return it.next(); + } + return null; + } + + /** + * Operator push up. The outgoing edge has been modified at setTargetId, so only need to focus on + * the incoming edge. + */ + private void pushUp(PipelineVertex sVertex) { + sVertex.setDuplication(); + if (sVertex.equals(unionVertex)) { + return; + } + LOGGER.info("pushUp vertex {}", sVertex); + + PipelineEdge unionOutEdge = getOutEdgeFromVertex(unionVertex); + Preconditions.checkNotNull(unionOutEdge); + // SOutEdge can be null. + PipelineEdge sOutEdge = getOutEdgeFromVertex(sVertex); + + // Incoming edge of sVertex modify. + inputEdges.get(sVertex.getVertexId()).remove(unionOutEdge); + for (PipelineEdge unionInEdge : inputEdges.get(unionVertex.getVertexId())) { + unionInEdge.setTargetId(sVertex.getVertexId()); + inputEdges.get(sVertex.getVertexId()).add(unionInEdge); + } - // The incoming edge of unionVertex modify. - inputEdges.get(unionVertex.getVertexId()).clear(); - if (sOutEdge != null) { - sOutEdge.setTargetId(unionVertex.getVertexId()); - inputEdges.get(unionVertex.getVertexId()).add(sOutEdge); - } + // Original input edge of sOutEdge modify. + if (sOutEdge != null) { + inputEdges.get(sOutEdge.getTargetId()).remove(sOutEdge); + unionOutEdge.setTargetId(sOutEdge.getTargetId()); + inputEdges.get(sOutEdge.getTargetId()).add(unionOutEdge); + } - // Key partition pass. - if (sOutEdge != null && sOutEdge.getPartitionType().equals(IPartitioner.PartitionType.key)) { - unionOutEdge.setStreamOrdinal(sOutEdge.getStreamOrdinal()); - unionOutEdge.setPartitionType(IPartitioner.PartitionType.key); - sOutEdge.setPartitionType(IPartitioner.PartitionType.forward); - } + // The incoming edge of unionVertex modify. + inputEdges.get(unionVertex.getVertexId()).clear(); + if (sOutEdge != null) { + sOutEdge.setTargetId(unionVertex.getVertexId()); + inputEdges.get(unionVertex.getVertexId()).add(sOutEdge); + } - tryRemoveVertex(sVertex); + // Key partition pass. + if (sOutEdge != null && sOutEdge.getPartitionType().equals(IPartitioner.PartitionType.key)) { + unionOutEdge.setStreamOrdinal(sOutEdge.getStreamOrdinal()); + unionOutEdge.setPartitionType(IPartitioner.PartitionType.key); + sOutEdge.setPartitionType(IPartitioner.PartitionType.forward); } - private boolean needPushUp(PipelineVertex vertex) { + tryRemoveVertex(sVertex); + } - // 1. Multi output edges not process. - if (outputEdges.get(vertex.getVertexId()).size() > 1) { - return false; - } + private boolean needPushUp(PipelineVertex vertex) { - // If partition type is custom then not push up. - for (PipelineEdge edge : outputEdges.get(vertex.getVertexId())) { - if (edge.getPartition() != null && !edge.getPartition().getPartitionType().isEnablePushUp()) { - return false; - } - } + // 1. Multi output edges not process. + if (outputEdges.get(vertex.getVertexId()).size() > 1) { + return false; + } - // 2. Upstream has union vertex and downstream vertex exists. - if (unionVertex != null && getOutVertexIdSet(unionVertex).contains(vertex.getVertexId()) - && statelessOperator.contains(vertex.getOperator().getClass())) { - return true; - } - if (vertex.getOperator() instanceof UnionOperator) { - unionVertex = vertex; - return true; - } + // If partition type is custom then not push up. + for (PipelineEdge edge : outputEdges.get(vertex.getVertexId())) { + if (edge.getPartition() != null && !edge.getPartition().getPartitionType().isEnablePushUp()) { return false; + } } - /** - * Try remove union vertex, modify incoming edge relation of downstream, upstream modify targetId of edge directly. - * If outgoing edge of union vertex carries key property, we need pass to upstream. - */ - private void tryRemoveVertex(PipelineVertex vertex) { - LOGGER.debug("try remove vertex {}", vertex); - if (vertex.isDuplication() && vertex.getOperator() instanceof UnionOperator - && outputEdges.get(vertex.getVertexId()).size() == 1) { - LOGGER.info("remove vertex {}", vertex); - removeUnionNum++; - PipelineEdge outputEdge = outputEdges.get(vertex.getVertexId()).iterator().next(); - int targetId = outputEdge.getTargetId(); - inputEdges.get(targetId).remove(outputEdge); - - for (PipelineEdge inEdge : inputEdges.get(vertex.getVertexId())) { - inEdge.setTargetId(targetId); - inputEdges.get(targetId).add(inEdge); - if (outputEdge.getPartitionType() == IPartitioner.PartitionType.key) { - inEdge.setPartitionType(IPartitioner.PartitionType.key); - } - } + // 2. Upstream has union vertex and downstream vertex exists. + if (unionVertex != null + && getOutVertexIdSet(unionVertex).contains(vertex.getVertexId()) + && statelessOperator.contains(vertex.getOperator().getClass())) { + return true; + } + if (vertex.getOperator() instanceof UnionOperator) { + unionVertex = vertex; + return true; + } + return false; + } + + /** + * Try remove union vertex, modify incoming edge relation of downstream, upstream modify targetId + * of edge directly. If outgoing edge of union vertex carries key property, we need pass to + * upstream. + */ + private void tryRemoveVertex(PipelineVertex vertex) { + LOGGER.debug("try remove vertex {}", vertex); + if (vertex.isDuplication() + && vertex.getOperator() instanceof UnionOperator + && outputEdges.get(vertex.getVertexId()).size() == 1) { + LOGGER.info("remove vertex {}", vertex); + removeUnionNum++; + PipelineEdge outputEdge = outputEdges.get(vertex.getVertexId()).iterator().next(); + int targetId = outputEdge.getTargetId(); + inputEdges.get(targetId).remove(outputEdge); + + for (PipelineEdge inEdge : inputEdges.get(vertex.getVertexId())) { + inEdge.setTargetId(targetId); + inputEdges.get(targetId).add(inEdge); + if (outputEdge.getPartitionType() == IPartitioner.PartitionType.key) { + inEdge.setPartitionType(IPartitioner.PartitionType.key); } + } } - - /** - * Topological sorting is used to ensure the sequence of ids. - */ - private void kahn(List vertices, int id) throws IOException, ClassNotFoundException { - Queue toVisitQueue = new ArrayDeque<>(vertices); - Set visitedEdge = new HashSet<>(); - ArrayListMultimap oldIdToNewIdMap = ArrayListMultimap.create(); - - while (!toVisitQueue.isEmpty()) { - PipelineVertex vertex = toVisitQueue.poll(); - if (!vertex.isDuplication()) { - id++; - PipelineVertex newVertex = cloneVertex(vertex, id, 0); - newVertices.add(newVertex); - newVertexMap.put(id, newVertex); - oldIdToNewIdMap.put(vertex.getVertexId(), id); - } - - // Consider the condition that in edge and the previous node fission. - int index = 0; - for (PipelineEdge inEdge : inputEdges.get(vertex.getVertexId())) { - PipelineVertex oriSrcVertex = vertexMap.get(inEdge.getSrcId()); - int oriSrcId = oriSrcVertex.getVertexId(); - // The previous node fission. - for (Integer srcId : oldIdToNewIdMap.get(oriSrcId)) { - if (vertex.isDuplication()) { - id++; - // The parallelism of the new upstream node is used to ensure vertex merging. - PipelineVertex newVertex = cloneVertex(vertex, id, - newVertexMap.get(srcId).getParallelism()); - // After the vertex is split, we need to change its name, - // otherwise the concurrency cannot be set. Sink has side effects, do not do treatment. - if (!(newVertex.getOperator() instanceof SinkOperator)) { - changeOperatorName(newVertex, index++); - } - newVertices.add(newVertex); - newVertexMap.put(id, newVertex); - oldIdToNewIdMap.put(vertex.getVertexId(), id); - } - PipelineEdge newEdge = new PipelineEdge(srcId, srcId, id, inEdge.getPartition(), - inEdge.getStreamOrdinal(), inEdge.getEncoder()); - if (vertexMap.get(oriSrcId).isDuplication()) { - IPartitioner partitionFunction = newEdge.getPartition(); - newEdge.setEdgeName(String - .format("union-%d-%s-%s-%s", id, newEdge.getPartitionType(), - PipelineEdge.JoinStream.values()[newEdge.getStreamOrdinal()], - partitionFunction != null ? partitionFunction.getClass() - .getSimpleName() : "none")); - } - newEdges.add(newEdge); - } + } + + /** Topological sorting is used to ensure the sequence of ids. */ + private void kahn(List vertices, int id) + throws IOException, ClassNotFoundException { + Queue toVisitQueue = new ArrayDeque<>(vertices); + Set visitedEdge = new HashSet<>(); + ArrayListMultimap oldIdToNewIdMap = ArrayListMultimap.create(); + + while (!toVisitQueue.isEmpty()) { + PipelineVertex vertex = toVisitQueue.poll(); + if (!vertex.isDuplication()) { + id++; + PipelineVertex newVertex = cloneVertex(vertex, id, 0); + newVertices.add(newVertex); + newVertexMap.put(id, newVertex); + oldIdToNewIdMap.put(vertex.getVertexId(), id); + } + + // Consider the condition that in edge and the previous node fission. + int index = 0; + for (PipelineEdge inEdge : inputEdges.get(vertex.getVertexId())) { + PipelineVertex oriSrcVertex = vertexMap.get(inEdge.getSrcId()); + int oriSrcId = oriSrcVertex.getVertexId(); + // The previous node fission. + for (Integer srcId : oldIdToNewIdMap.get(oriSrcId)) { + if (vertex.isDuplication()) { + id++; + // The parallelism of the new upstream node is used to ensure vertex merging. + PipelineVertex newVertex = + cloneVertex(vertex, id, newVertexMap.get(srcId).getParallelism()); + // After the vertex is split, we need to change its name, + // otherwise the concurrency cannot be set. Sink has side effects, do not do treatment. + if (!(newVertex.getOperator() instanceof SinkOperator)) { + changeOperatorName(newVertex, index++); } + newVertices.add(newVertex); + newVertexMap.put(id, newVertex); + oldIdToNewIdMap.put(vertex.getVertexId(), id); + } + PipelineEdge newEdge = + new PipelineEdge( + srcId, + srcId, + id, + inEdge.getPartition(), + inEdge.getStreamOrdinal(), + inEdge.getEncoder()); + if (vertexMap.get(oriSrcId).isDuplication()) { + IPartitioner partitionFunction = newEdge.getPartition(); + newEdge.setEdgeName( + String.format( + "union-%d-%s-%s-%s", + id, + newEdge.getPartitionType(), + PipelineEdge.JoinStream.values()[newEdge.getStreamOrdinal()], + partitionFunction != null + ? partitionFunction.getClass().getSimpleName() + : "none")); + } + newEdges.add(newEdge); + } + } - for (PipelineEdge executeEdge : outputEdges.get(vertex.getVertexId())) { - PipelineVertex nextVertex = vertexMap.get(executeEdge.getTargetId()); - visitedEdge.add(executeEdge); - if (visitedEdge.containsAll(inputEdges.get(nextVertex.getVertexId()))) { - toVisitQueue.add(nextVertex); - } - } + for (PipelineEdge executeEdge : outputEdges.get(vertex.getVertexId())) { + PipelineVertex nextVertex = vertexMap.get(executeEdge.getTargetId()); + visitedEdge.add(executeEdge); + if (visitedEdge.containsAll(inputEdges.get(nextVertex.getVertexId()))) { + toVisitQueue.add(nextVertex); } + } } - - private byte[] toByteArray(Object obj) throws IOException { - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - ObjectOutputStream oos = new ObjectOutputStream(bos); - oos.writeObject(obj); - oos.flush(); - byte[] bytes = bos.toByteArray(); - oos.close(); - bos.close(); - return bytes; + } + + private byte[] toByteArray(Object obj) throws IOException { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutputStream oos = new ObjectOutputStream(bos); + oos.writeObject(obj); + oos.flush(); + byte[] bytes = bos.toByteArray(); + oos.close(); + bos.close(); + return bytes; + } + + private Object toObject(byte[] bytes) throws IOException, ClassNotFoundException { + ByteArrayInputStream bis = new ByteArrayInputStream(bytes); + bis.mark(0); + ObjectInputStream ois = new ObjectInputStream(bis); + Object obj = null; + try { + obj = ois.readObject(); + } catch (Exception ex) { + bis.reset(); + ois = new ClassLoaderObjectInputStream(Thread.currentThread().getContextClassLoader(), bis); + obj = ois.readObject(); + } finally { + ois.close(); } - - private Object toObject(byte[] bytes) throws IOException, ClassNotFoundException { - ByteArrayInputStream bis = new ByteArrayInputStream(bytes); - bis.mark(0); - ObjectInputStream ois = new ObjectInputStream(bis); - Object obj = null; - try { - obj = ois.readObject(); - } catch (Exception ex) { - bis.reset(); - ois = new ClassLoaderObjectInputStream(Thread.currentThread().getContextClassLoader(), - bis); - obj = ois.readObject(); - } finally { - ois.close(); - } - bis.close(); - return obj; + bis.close(); + return obj; + } + + private PipelineVertex cloneVertex(PipelineVertex vertex, int id, int parallelism) + throws IOException, ClassNotFoundException { + LOGGER.debug("clone Vertex {}", vertex); + PipelineVertex cloned; + try { + byte[] out = SerializerFactory.getKryoSerializer().serialize(vertex); + cloned = (PipelineVertex) SerializerFactory.getKryoSerializer().deserialize(out); + } catch (Exception ex) { + LOGGER.warn( + "vertex {} kryo fail, try java serde, ex: {}", + vertex, + Arrays.toString(ex.getStackTrace())); + cloned = (PipelineVertex) toObject(toByteArray(vertex)); + if (cloned == null) { + throw new GeaflowRuntimeException( + RuntimeErrors.INST.undefinedError(vertex.getVertexString() + " is not Serializable"), + ex); + } } - - private PipelineVertex cloneVertex(PipelineVertex vertex, int id, int parallelism) - throws IOException, ClassNotFoundException { - LOGGER.debug("clone Vertex {}", vertex); - PipelineVertex cloned; - try { - byte[] out = SerializerFactory.getKryoSerializer().serialize(vertex); - cloned = (PipelineVertex) SerializerFactory.getKryoSerializer().deserialize(out); - } catch (Exception ex) { - LOGGER.warn("vertex {} kryo fail, try java serde, ex: {}", vertex, - Arrays.toString(ex.getStackTrace())); - cloned = (PipelineVertex) toObject(toByteArray(vertex)); - if (cloned == null) { - throw new GeaflowRuntimeException( - RuntimeErrors.INST.undefinedError(vertex.getVertexString() + " is not Serializable"), ex); - } - } - cloned.setVertexId(id); - AbstractOperator abstractOperator = ((AbstractOperator) cloned.getOperator()); - abstractOperator.getOpArgs().setOpId(id); - abstractOperator.setFunction(((AbstractOperator) vertex.getOperator()).getFunction()); - - if (parallelism != 0) { - cloned.setParallelism(parallelism); - abstractOperator.getOpArgs().setParallelism(parallelism); - } - - return cloned; + cloned.setVertexId(id); + AbstractOperator abstractOperator = ((AbstractOperator) cloned.getOperator()); + abstractOperator.getOpArgs().setOpId(id); + abstractOperator.setFunction(((AbstractOperator) vertex.getOperator()).getFunction()); + + if (parallelism != 0) { + cloned.setParallelism(parallelism); + abstractOperator.getOpArgs().setParallelism(parallelism); } - private void changeOperatorName(PipelineVertex vertex, int index) { - if (StringUtils.isNotEmpty(((AbstractOperator) vertex.getOperator()).getOpArgs().getOpName())) { - ((AbstractOperator) vertex.getOperator()).getOpArgs() - .setOpName(String.format("%s-%d", ((AbstractOperator) vertex.getOperator()).getOpArgs().getOpName(), index)); - } + return cloned; + } + + private void changeOperatorName(PipelineVertex vertex, int index) { + if (StringUtils.isNotEmpty(((AbstractOperator) vertex.getOperator()).getOpArgs().getOpName())) { + ((AbstractOperator) vertex.getOperator()) + .getOpArgs() + .setOpName( + String.format( + "%s-%d", + ((AbstractOperator) vertex.getOperator()).getOpArgs().getOpName(), index)); } + } - private Set getOutVertexIdSet(PipelineVertex vertex) { - return outputEdges.get(vertex.getVertexId()).stream().map(PipelineEdge::getTargetId) - .collect(Collectors.toSet()); - } + private Set getOutVertexIdSet(PipelineVertex vertex) { + return outputEdges.get(vertex.getVertexId()).stream() + .map(PipelineEdge::getTargetId) + .collect(Collectors.toSet()); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/optimizer/strategy/ChainCombiner.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/optimizer/strategy/ChainCombiner.java index ba5375922..618829ce4 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/optimizer/strategy/ChainCombiner.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/optimizer/strategy/ChainCombiner.java @@ -30,6 +30,7 @@ import java.util.Set; import java.util.TreeSet; import java.util.stream.Collectors; + import org.apache.geaflow.operator.OpArgs.ChainStrategy; import org.apache.geaflow.operator.base.AbstractOperator; import org.apache.geaflow.partitioner.IPartitioner; @@ -43,151 +44,190 @@ public class ChainCombiner implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(ChainCombiner.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ChainCombiner.class); - private Set visited = new HashSet<>(); + private Set visited = new HashSet<>(); - public void combineVertex(PipelineGraph pipelineGraph) { - visited.clear(); + public void combineVertex(PipelineGraph pipelineGraph) { + visited.clear(); - // 1. Generate vertex id map according to vertex map. - Map vertexMap = pipelineGraph.getVertexMap(); + // 1. Generate vertex id map according to vertex map. + Map vertexMap = pipelineGraph.getVertexMap(); - // 2. Generate all output and input edges of all vertices according to edges. - Map> outputEdges = new HashMap<>(); - Map> inputEdges = new HashMap<>(); - vertexMap.keySet().forEach(vertexId -> { - outputEdges.put(vertexId, pipelineGraph.getVertexOutEdges(vertexId)); - inputEdges.put(vertexId, pipelineGraph.getVertexInputEdges(vertexId)); - }); + // 2. Generate all output and input edges of all vertices according to edges. + Map> outputEdges = new HashMap<>(); + Map> inputEdges = new HashMap<>(); + vertexMap + .keySet() + .forEach( + vertexId -> { + outputEdges.put(vertexId, pipelineGraph.getVertexOutEdges(vertexId)); + inputEdges.put(vertexId, pipelineGraph.getVertexInputEdges(vertexId)); + }); - // 3. Find all the Source nodes and merge recursively from the Source node. - List sourceVertices = pipelineGraph.getPipelineVertices().stream() + // 3. Find all the Source nodes and merge recursively from the Source node. + List sourceVertices = + pipelineGraph.getPipelineVertices().stream() .filter(pipelineVertex -> (pipelineVertex.getType() == VertexType.source)) .collect(Collectors.toList()); - Collection jobEdges = pipelineGraph.getPipelineEdgeList(); - Collection jobVertices = pipelineGraph.getPipelineVertices(); - - List verticesIds = jobVertices.stream().map(x -> x.getVertexId()) - .collect(Collectors.toList()); - List needAddJobEdges = new ArrayList<>(); - for (PipelineEdge jobEdge : jobEdges) { - if (!verticesIds.contains(jobEdge.getSrcId())) { - sourceVertices.add(vertexMap.get(jobEdge.getTargetId())); - needAddJobEdges.add(jobEdge); - } - } + Collection jobEdges = pipelineGraph.getPipelineEdgeList(); + Collection jobVertices = pipelineGraph.getPipelineVertices(); + + List verticesIds = + jobVertices.stream().map(x -> x.getVertexId()).collect(Collectors.toList()); + List needAddJobEdges = new ArrayList<>(); + for (PipelineEdge jobEdge : jobEdges) { + if (!verticesIds.contains(jobEdge.getSrcId())) { + sourceVertices.add(vertexMap.get(jobEdge.getTargetId())); + needAddJobEdges.add(jobEdge); + } + } - if (sourceVertices.size() != 0) { - // 4. Recursively generates new nodes and edges. - Set newVertices = new HashSet<>(); - Set newEdges = new TreeSet<>(new Comparator() { + if (sourceVertices.size() != 0) { + // 4. Recursively generates new nodes and edges. + Set newVertices = new HashSet<>(); + Set newEdges = + new TreeSet<>( + new Comparator() { @Override public int compare(PipelineEdge o1, PipelineEdge o2) { - return o1.getEdgeId() - o2.getEdgeId(); + return o1.getEdgeId() - o2.getEdgeId(); } - }); - for (PipelineVertex sourceVertex : sourceVertices) { - newVertices.add(sourceVertex); - createOperatorChain(sourceVertex.getVertexId(), sourceVertex, vertexMap, inputEdges, - outputEdges, newVertices, newEdges, null); - } - - pipelineGraph.setPipelineEdges(newEdges); - newVertices.forEach(jobVertex -> { - LOGGER.info(jobVertex.getVertexString()); - DAGValidator.checkVertexValidity(pipelineGraph, jobVertex, false); - }); - pipelineGraph.setPipelineVertices(newVertices); - needAddJobEdges.stream().forEach(jobEdge -> pipelineGraph.addEdge(jobEdge)); - } + }); + for (PipelineVertex sourceVertex : sourceVertices) { + newVertices.add(sourceVertex); + createOperatorChain( + sourceVertex.getVertexId(), + sourceVertex, + vertexMap, + inputEdges, + outputEdges, + newVertices, + newEdges, + null); + } + + pipelineGraph.setPipelineEdges(newEdges); + newVertices.forEach( + jobVertex -> { + LOGGER.info(jobVertex.getVertexString()); + DAGValidator.checkVertexValidity(pipelineGraph, jobVertex, false); + }); + pipelineGraph.setPipelineVertices(newVertices); + needAddJobEdges.stream().forEach(jobEdge -> pipelineGraph.addEdge(jobEdge)); } - - private void createOperatorChain(int id, PipelineVertex srcVertex, Map vertexMap, - Map> inputEdges, - Map> outputEdges, - Set newVertices, Set newEdges, - String outputTag) { - int srcId = srcVertex.getVertexId(); - - if (visited.add(srcId)) { - LOGGER.debug("Exploring vertex[{}]", srcId); - if (outputEdges.containsKey(srcId)) { - LOGGER.debug("srcId:{}", srcId); - Set srcVertexOutputEdges = outputEdges.get(srcId); - for (PipelineEdge executeEdge : srcVertexOutputEdges) { - LOGGER.debug("edge:{}", executeEdge); - int targetId = executeEdge.getTargetId(); - PipelineVertex targetVertex = vertexMap.get(targetId); - if (executeEdge.getEdgeName() != null) { - outputTag = executeEdge.getEdgeName(); - } - - if (isVertexCanMerge(srcVertex, targetVertex, executeEdge, inputEdges)) { - LOGGER.debug("Vertex[{}] can merge Vertex[{}]", srcVertex.getVertexId(), - targetVertex.getVertexId()); - - //Add a dependency for an Operator - AbstractOperator abstractOperator = (AbstractOperator) srcVertex.getOperator(); - abstractOperator.addNextOperator(targetVertex.getOperator()); - createOperatorChain(id, targetVertex, vertexMap, inputEdges, outputEdges, - newVertices, newEdges, outputTag); - srcVertex.setChainTailType(targetVertex.getChainTailType()); - if (executeEdge.getEdgeName() != null) { - abstractOperator.getOutputTags().put(executeEdge.getEdgeId(), executeEdge.getEdgeName()); - } - } else { - LOGGER.debug("Vertex[{}] can't merge Vertex[{}]", srcVertex.getVertexId(), - targetVertex.getVertexId()); - executeEdge.setSrcId(id); - executeEdge.setEdgeName(outputTag); - newEdges.add(executeEdge); - newVertices.add(targetVertex); - if (executeEdge.getSrcId() != executeEdge.getTargetId()) { - createOperatorChain(targetVertex.getVertexId(), targetVertex, vertexMap, - inputEdges, outputEdges, newVertices, newEdges, null); - } - } - } + } + + private void createOperatorChain( + int id, + PipelineVertex srcVertex, + Map vertexMap, + Map> inputEdges, + Map> outputEdges, + Set newVertices, + Set newEdges, + String outputTag) { + int srcId = srcVertex.getVertexId(); + + if (visited.add(srcId)) { + LOGGER.debug("Exploring vertex[{}]", srcId); + if (outputEdges.containsKey(srcId)) { + LOGGER.debug("srcId:{}", srcId); + Set srcVertexOutputEdges = outputEdges.get(srcId); + for (PipelineEdge executeEdge : srcVertexOutputEdges) { + LOGGER.debug("edge:{}", executeEdge); + int targetId = executeEdge.getTargetId(); + PipelineVertex targetVertex = vertexMap.get(targetId); + if (executeEdge.getEdgeName() != null) { + outputTag = executeEdge.getEdgeName(); + } + + if (isVertexCanMerge(srcVertex, targetVertex, executeEdge, inputEdges)) { + LOGGER.debug( + "Vertex[{}] can merge Vertex[{}]", + srcVertex.getVertexId(), + targetVertex.getVertexId()); + + // Add a dependency for an Operator + AbstractOperator abstractOperator = (AbstractOperator) srcVertex.getOperator(); + abstractOperator.addNextOperator(targetVertex.getOperator()); + createOperatorChain( + id, + targetVertex, + vertexMap, + inputEdges, + outputEdges, + newVertices, + newEdges, + outputTag); + srcVertex.setChainTailType(targetVertex.getChainTailType()); + if (executeEdge.getEdgeName() != null) { + abstractOperator + .getOutputTags() + .put(executeEdge.getEdgeId(), executeEdge.getEdgeName()); } + } else { + LOGGER.debug( + "Vertex[{}] can't merge Vertex[{}]", + srcVertex.getVertexId(), + targetVertex.getVertexId()); + executeEdge.setSrcId(id); + executeEdge.setEdgeName(outputTag); + newEdges.add(executeEdge); + newVertices.add(targetVertex); + if (executeEdge.getSrcId() != executeEdge.getTargetId()) { + createOperatorChain( + targetVertex.getVertexId(), + targetVertex, + vertexMap, + inputEdges, + outputEdges, + newVertices, + newEdges, + null); + } + } } + } + } + } + + /** + * Determine whether single nodes can be merged. Conditions for node merging: 1. The entry degree + * of the target node is 1. 2. The edge must be of type forward, not key. 3. Concurrency must be + * consistent. 4. The chainStrategy of the source node cannot be NEVER. 5. The chainStrategy of + * the target node must be ALWAYS. + */ + private boolean isVertexCanMerge( + PipelineVertex srcVertex, + PipelineVertex targetVertex, + PipelineEdge executeEdge, + Map> inputEdges) { + + if (inputEdges.get(targetVertex.getVertexId()).size() != 1) { + return false; } - /** - * Determine whether single nodes can be merged. - * Conditions for node merging: - * 1. The entry degree of the target node is 1. - * 2. The edge must be of type forward, not key. - * 3. Concurrency must be consistent. - * 4. The chainStrategy of the source node cannot be NEVER. - * 5. The chainStrategy of the target node must be ALWAYS. - */ - private boolean isVertexCanMerge(PipelineVertex srcVertex, PipelineVertex targetVertex, - PipelineEdge executeEdge, Map> inputEdges) { - - if (inputEdges.get(targetVertex.getVertexId()).size() != 1) { - return false; - } - - if (executeEdge.getPartition().getPartitionType() != IPartitioner.PartitionType.forward) { - return false; - } - - if (srcVertex.getParallelism() != targetVertex.getParallelism()) { - return false; - } + if (executeEdge.getPartition().getPartitionType() != IPartitioner.PartitionType.forward) { + return false; + } - if (((AbstractOperator) srcVertex.getOperator()).getOpArgs().getChainStrategy() - == ChainStrategy.NEVER) { - return false; - } + if (srcVertex.getParallelism() != targetVertex.getParallelism()) { + return false; + } - ChainStrategy strategy = ((AbstractOperator) targetVertex.getOperator()).getOpArgs().getChainStrategy(); - if (strategy != null && strategy != ChainStrategy.ALWAYS) { - return false; - } + if (((AbstractOperator) srcVertex.getOperator()).getOpArgs().getChainStrategy() + == ChainStrategy.NEVER) { + return false; + } - return true; + ChainStrategy strategy = + ((AbstractOperator) targetVertex.getOperator()).getOpArgs().getChainStrategy(); + if (strategy != null && strategy != ChainStrategy.ALWAYS) { + return false; } + + return true; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/optimizer/strategy/SingleWindowGroupRule.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/optimizer/strategy/SingleWindowGroupRule.java index 8694fbf93..266ad3564 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/optimizer/strategy/SingleWindowGroupRule.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/optimizer/strategy/SingleWindowGroupRule.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.util.List; + import org.apache.geaflow.operator.OpArgs; import org.apache.geaflow.operator.base.AbstractOperator; import org.apache.geaflow.plan.graph.PipelineGraph; @@ -30,23 +31,24 @@ public class SingleWindowGroupRule implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(SingleWindowGroupRule.class); + private static final Logger LOGGER = LoggerFactory.getLogger(SingleWindowGroupRule.class); - /** - * Apply group rule in plan. - */ - public void apply(PipelineGraph pipelineGraph) { - List sourceVertexList = pipelineGraph.getSourceVertices(); - // 1. Check whether is single window mode. - boolean isSingleWindow = sourceVertexList.stream().allMatch(v -> - ((AbstractOperator) v.getOperator()).getOpArgs().getOpType() == OpArgs.OpType.SINGLE_WINDOW_SOURCE); + /** Apply group rule in plan. */ + public void apply(PipelineGraph pipelineGraph) { + List sourceVertexList = pipelineGraph.getSourceVertices(); + // 1. Check whether is single window mode. + boolean isSingleWindow = + sourceVertexList.stream() + .allMatch( + v -> + ((AbstractOperator) v.getOperator()).getOpArgs().getOpType() + == OpArgs.OpType.SINGLE_WINDOW_SOURCE); - // 2. Apply no group rule. - if (isSingleWindow) { - pipelineGraph.getPipelineVertices().stream().forEach( - v -> ((AbstractOperator) v.getOperator()).getOpArgs().setEnGroup(false)); - LOGGER.info("apply no group rule success"); - } + // 2. Apply no group rule. + if (isSingleWindow) { + pipelineGraph.getPipelineVertices().stream() + .forEach(v -> ((AbstractOperator) v.getOperator()).getOpArgs().setEnGroup(false)); + LOGGER.info("apply no group rule success"); } - + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/util/DAGValidator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/util/DAGValidator.java index 709d33502..cf6e12ae6 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/util/DAGValidator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/util/DAGValidator.java @@ -27,34 +27,36 @@ public class DAGValidator { - /** - * Gets and verifies whether the upstream vertex of the current vertex in the dag exists. - * - * @param pipelineGraph The pipeline plan. - * @param pipelineVertex The current vertex. - */ - public static void checkVertexValidity(PipelineGraph pipelineGraph, PipelineVertex pipelineVertex, boolean fetchPrevious) { - for (PipelineEdge pipelineEdge : pipelineGraph.getPipelineEdgeList()) { - int vertexId; - if (fetchPrevious) { - vertexId = pipelineEdge.getTargetId(); - } else { - vertexId = pipelineEdge.getSrcId(); - } + /** + * Gets and verifies whether the upstream vertex of the current vertex in the dag exists. + * + * @param pipelineGraph The pipeline plan. + * @param pipelineVertex The current vertex. + */ + public static void checkVertexValidity( + PipelineGraph pipelineGraph, PipelineVertex pipelineVertex, boolean fetchPrevious) { + for (PipelineEdge pipelineEdge : pipelineGraph.getPipelineEdgeList()) { + int vertexId; + if (fetchPrevious) { + vertexId = pipelineEdge.getTargetId(); + } else { + vertexId = pipelineEdge.getSrcId(); + } - // Input vertex check, for chain and non-chain mode. - if (pipelineVertex.getVertexId() == vertexId) { - int previousChainTailVertexId = pipelineEdge.getPartition().getOpId(); - PipelineVertex previousVertex = null; - if (pipelineGraph.getVertexMap().containsKey(previousChainTailVertexId)) { - previousVertex = pipelineGraph.getVertexMap().get(previousChainTailVertexId); - } - // Maybe encounter the situation that previous vertex is null. - if (previousVertex == null) { - throw new GeaflowRuntimeException(RuntimeErrors.INST - .previousVertexIsNullError(String.valueOf(pipelineVertex.getVertexId()))); - } - } + // Input vertex check, for chain and non-chain mode. + if (pipelineVertex.getVertexId() == vertexId) { + int previousChainTailVertexId = pipelineEdge.getPartition().getOpId(); + PipelineVertex previousVertex = null; + if (pipelineGraph.getVertexMap().containsKey(previousChainTailVertexId)) { + previousVertex = pipelineGraph.getVertexMap().get(previousChainTailVertexId); } + // Maybe encounter the situation that previous vertex is null. + if (previousVertex == null) { + throw new GeaflowRuntimeException( + RuntimeErrors.INST.previousVertexIsNullError( + String.valueOf(pipelineVertex.getVertexId()))); + } + } } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/visualization/GeaFlowNodeInfo.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/visualization/GeaFlowNodeInfo.java index d78eb950e..e4115fb7f 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/visualization/GeaFlowNodeInfo.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/visualization/GeaFlowNodeInfo.java @@ -24,44 +24,44 @@ public class GeaFlowNodeInfo { - private int vertexId; - private String type; - private OpDesc operator; - private int parallelism; + private int vertexId; + private String type; + private OpDesc operator; + private int parallelism; - public GeaFlowNodeInfo(int vertexId, String type, Operator operator) { - this.vertexId = vertexId; - this.type = type; - this.operator = new OpDesc(operator); - this.parallelism = ((AbstractOperator) operator).getOpArgs().getParallelism(); - } + public GeaFlowNodeInfo(int vertexId, String type, Operator operator) { + this.vertexId = vertexId; + this.type = type; + this.operator = new OpDesc(operator); + this.parallelism = ((AbstractOperator) operator).getOpArgs().getParallelism(); + } - public int getParallelism() { - return parallelism; - } + public int getParallelism() { + return parallelism; + } - public String getType() { - return type; - } + public String getType() { + return type; + } - public void setType(String type) { - this.type = type; - } + public void setType(String type) { + this.type = type; + } - public int getVertexId() { - return vertexId; - } + public int getVertexId() { + return vertexId; + } - public void setVertexId(int vertexId) { - this.vertexId = vertexId; - } + public void setVertexId(int vertexId) { + this.vertexId = vertexId; + } - public String toGraphvizNodeString() { - StringBuilder builder = new StringBuilder(); - builder.append(vertexId).append(" [label=\""); - builder.append("p:").append(parallelism); - builder.append(", ").append(operator.getName()); - builder.append("\"]\n"); - return builder.toString(); - } + public String toGraphvizNodeString() { + StringBuilder builder = new StringBuilder(); + builder.append(vertexId).append(" [label=\""); + builder.append("p:").append(parallelism); + builder.append(", ").append(operator.getName()); + builder.append("\"]\n"); + return builder.toString(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/visualization/JsonPlanGraphVisualization.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/visualization/JsonPlanGraphVisualization.java index 7c3bca268..cb4863556 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/visualization/JsonPlanGraphVisualization.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/visualization/JsonPlanGraphVisualization.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Comparator; import java.util.List; + import org.apache.commons.lang.StringUtils; import org.apache.geaflow.common.visualization.console.ConsoleVisualizeVertex; import org.apache.geaflow.common.visualization.console.JsonPlan; @@ -35,79 +36,82 @@ public class JsonPlanGraphVisualization { - private final PipelineGraph plan; + private final PipelineGraph plan; - private final JsonPlan jsonPlan; + private final JsonPlan jsonPlan; - public JsonPlanGraphVisualization(PipelineGraph plan) { - this.plan = plan; - this.jsonPlan = new JsonPlan(); - createJsonPlan(plan); - } + public JsonPlanGraphVisualization(PipelineGraph plan) { + this.plan = plan; + this.jsonPlan = new JsonPlan(); + createJsonPlan(plan); + } - private void createJsonPlan(PipelineGraph plan) { - List pipelineVertices = new ArrayList<>(plan.getVertexMap().values()); - pipelineVertices.sort(Comparator.comparingInt(PipelineVertex::getVertexId)); + private void createJsonPlan(PipelineGraph plan) { + List pipelineVertices = new ArrayList<>(plan.getVertexMap().values()); + pipelineVertices.sort(Comparator.comparingInt(PipelineVertex::getVertexId)); - for (PipelineVertex vertex : pipelineVertices) { - ConsoleVisualizeVertex v = new ConsoleVisualizeVertex(); - decorateNode(vertex, v); - jsonPlan.vertices.put(v.id, v); - } + for (PipelineVertex vertex : pipelineVertices) { + ConsoleVisualizeVertex v = new ConsoleVisualizeVertex(); + decorateNode(vertex, v); + jsonPlan.vertices.put(v.id, v); } + } - public JsonPlan getJsonPlan() { - return jsonPlan; - } + public JsonPlan getJsonPlan() { + return jsonPlan; + } - private void decorateNode(PipelineVertex jobVertex, ConsoleVisualizeVertex vertex) { - vertex.setId(Integer.toString(jobVertex.getVertexId())); - vertex.setVertexType(jobVertex.getType().name()); - vertex.setParallelism(jobVertex.getParallelism()); - vertex.setVertexMode(jobVertex.getVertexMode() == null ? null : - jobVertex.getVertexMode().name()); + private void decorateNode(PipelineVertex jobVertex, ConsoleVisualizeVertex vertex) { + vertex.setId(Integer.toString(jobVertex.getVertexId())); + vertex.setVertexType(jobVertex.getType().name()); + vertex.setParallelism(jobVertex.getParallelism()); + vertex.setVertexMode( + jobVertex.getVertexMode() == null ? null : jobVertex.getVertexMode().name()); - if (!vertex.getVertexType().equals(VertexType.source.name())) { - for (PipelineEdge edge : plan.getVertexInputEdges().get(jobVertex.getVertexId())) { - Predecessor predecessor = new Predecessor(); - predecessor.setId(Integer.toString(edge.getSrcId())); - predecessor.setPartitionType(edge.getPartitionType().name()); - vertex.getParents().add(predecessor); - } - } - AbstractOperator operator = (AbstractOperator) jobVertex.getOperator(); - if (operator.getNextOperators().size() > 0) { - vertex.setInnerPlan(new JsonPlan()); - decorateInnerOperator(vertex.getInnerPlan(), operator, vertex, null); - } else { - vertex.setOperator(operator.getClass().getSimpleName()); - vertex.setOperatorName(getOperatorName(operator)); - } + if (!vertex.getVertexType().equals(VertexType.source.name())) { + for (PipelineEdge edge : plan.getVertexInputEdges().get(jobVertex.getVertexId())) { + Predecessor predecessor = new Predecessor(); + predecessor.setId(Integer.toString(edge.getSrcId())); + predecessor.setPartitionType(edge.getPartitionType().name()); + vertex.getParents().add(predecessor); + } } + AbstractOperator operator = (AbstractOperator) jobVertex.getOperator(); + if (operator.getNextOperators().size() > 0) { + vertex.setInnerPlan(new JsonPlan()); + decorateInnerOperator(vertex.getInnerPlan(), operator, vertex, null); + } else { + vertex.setOperator(operator.getClass().getSimpleName()); + vertex.setOperatorName(getOperatorName(operator)); + } + } - private void decorateInnerOperator(JsonPlan innerPlan, AbstractOperator operator, - ConsoleVisualizeVertex outerVertex, String parentId) { - ConsoleVisualizeVertex vertex = new ConsoleVisualizeVertex(); - vertex.setOperator(operator.getClass().getSimpleName()); - vertex.setOperatorName(getOperatorName(operator)); + private void decorateInnerOperator( + JsonPlan innerPlan, + AbstractOperator operator, + ConsoleVisualizeVertex outerVertex, + String parentId) { + ConsoleVisualizeVertex vertex = new ConsoleVisualizeVertex(); + vertex.setOperator(operator.getClass().getSimpleName()); + vertex.setOperatorName(getOperatorName(operator)); - vertex.setId(outerVertex.getId() + "-" + operator.getOpArgs().getOpId()); - vertex.setParallelism(outerVertex.getParallelism()); - innerPlan.vertices.put(vertex.getId(), vertex); - if (parentId != null) { - Predecessor predecessor = new Predecessor(); - predecessor.setId(parentId); - vertex.getParents().add(predecessor); - } - for (Operator op : operator.getNextOperators()) { - decorateInnerOperator(innerPlan, (AbstractOperator) op, outerVertex, vertex.getId()); - } + vertex.setId(outerVertex.getId() + "-" + operator.getOpArgs().getOpId()); + vertex.setParallelism(outerVertex.getParallelism()); + innerPlan.vertices.put(vertex.getId(), vertex); + if (parentId != null) { + Predecessor predecessor = new Predecessor(); + predecessor.setId(parentId); + vertex.getParents().add(predecessor); + } + for (Operator op : operator.getNextOperators()) { + decorateInnerOperator(innerPlan, (AbstractOperator) op, outerVertex, vertex.getId()); } + } - private String getOperatorName(AbstractOperator operator) { - if (StringUtils.isNotBlank(operator.getOpArgs().getOpName())) { - return operator.getOpArgs().getOpName(); - } - return operator.getIdentify(); + private String getOperatorName(AbstractOperator operator) { + if (StringUtils.isNotBlank(operator.getOpArgs().getOpName())) { + return operator.getOpArgs().getOpName(); } + return operator.getIdentify(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/visualization/OpDesc.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/visualization/OpDesc.java index 5dae6b056..fd8ce84d4 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/visualization/OpDesc.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/visualization/OpDesc.java @@ -21,41 +21,42 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.operator.Operator; import org.apache.geaflow.operator.base.AbstractOperator; public class OpDesc { - private String name; - private int id; - private List children; - - public OpDesc(Operator operator) { - // Because of operator chaining, operator name should contains multiple operators. - name = operator.toString(); - AbstractOperator abstractOperator = (AbstractOperator) operator; - id = abstractOperator.getOpArgs().getOpId(); - children = new ArrayList<>(); - if (abstractOperator.getNextOperators() != null) { - for (Object subOperator : abstractOperator.getNextOperators()) { - children.add(new OpDesc((Operator) subOperator)); - } - } + private String name; + private int id; + private List children; + + public OpDesc(Operator operator) { + // Because of operator chaining, operator name should contains multiple operators. + name = operator.toString(); + AbstractOperator abstractOperator = (AbstractOperator) operator; + id = abstractOperator.getOpArgs().getOpId(); + children = new ArrayList<>(); + if (abstractOperator.getNextOperators() != null) { + for (Object subOperator : abstractOperator.getNextOperators()) { + children.add(new OpDesc((Operator) subOperator)); + } } + } - public String getName() { - return name; - } + public String getName() { + return name; + } - public void setName(String name) { - this.name = name; - } + public void setName(String name) { + this.name = name; + } - public int getId() { - return id; - } + public int getId() { + return id; + } - public void setId(int id) { - this.id = id; - } + public void setId(int id) { + this.id = id; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/visualization/PlanGraphVisualization.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/visualization/PlanGraphVisualization.java index 80bcce17e..062b7e3c6 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/visualization/PlanGraphVisualization.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/main/java/org/apache/geaflow/plan/visualization/PlanGraphVisualization.java @@ -19,82 +19,85 @@ package org.apache.geaflow.plan.visualization; -import com.alibaba.fastjson.JSON; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.plan.graph.PipelineEdge; import org.apache.geaflow.plan.graph.PipelineGraph; import org.apache.geaflow.plan.graph.PipelineVertex; +import com.alibaba.fastjson.JSON; public class PlanGraphVisualization { - public static final String SCHEDULER = "scheduler"; - private static final String NODE_FORMAT = "%s [label=\"%s\"]\n"; - private static final String EMPTY_GRAPH = new StringBuilder("digraph G {\n") - .append("0 [label=\"node-0\"]\n") - .append(String.format(NODE_FORMAT, SCHEDULER, SCHEDULER)) - .append("}") - .toString(); + public static final String SCHEDULER = "scheduler"; + private static final String NODE_FORMAT = "%s [label=\"%s\"]\n"; + private static final String EMPTY_GRAPH = + new StringBuilder("digraph G {\n") + .append("0 [label=\"node-0\"]\n") + .append(String.format(NODE_FORMAT, SCHEDULER, SCHEDULER)) + .append("}") + .toString(); - private PipelineGraph plan; - private List pipelineEdgeList; - private Map pipelineVertexMap; + private PipelineGraph plan; + private List pipelineEdgeList; + private Map pipelineVertexMap; - public PlanGraphVisualization(PipelineGraph plan) { - this.plan = plan; - this.pipelineEdgeList = new ArrayList<>(plan.getPipelineEdgeList()); - this.pipelineVertexMap = plan.getVertexMap(); - } + public PlanGraphVisualization(PipelineGraph plan) { + this.plan = plan; + this.pipelineEdgeList = new ArrayList<>(plan.getPipelineEdgeList()); + this.pipelineVertexMap = plan.getVertexMap(); + } - public String getGraphviz() { - // Determine whether an empty PipelineGraph that has already executed a batch job is passed in. - if (pipelineEdgeList.size() == 0 && pipelineVertexMap.size() == 0) { - return EMPTY_GRAPH; - } - Collections.sort(pipelineEdgeList, (o1, o2) -> { - int i = Integer.compare(o1.getSrcId(), o2.getSrcId()); - if (i == 0) { - return Integer.compare(o1.getTargetId(), o2.getTargetId()); - } else { - return i; - } + public String getGraphviz() { + // Determine whether an empty PipelineGraph that has already executed a batch job is passed in. + if (pipelineEdgeList.size() == 0 && pipelineVertexMap.size() == 0) { + return EMPTY_GRAPH; + } + Collections.sort( + pipelineEdgeList, + (o1, o2) -> { + int i = Integer.compare(o1.getSrcId(), o2.getSrcId()); + if (i == 0) { + return Integer.compare(o1.getTargetId(), o2.getTargetId()); + } else { + return i; + } }); - StringBuilder builder = new StringBuilder("digraph G {\n"); - for (PipelineEdge edge : pipelineEdgeList) { - builder.append(String - .format("%d -> %d [label = \"%s\"]\n", edge.getSrcId(), edge.getTargetId(), - edge.getPartitionType().toString())); - } - for (PipelineVertex vertex : pipelineVertexMap.values()) { - builder.append(String - .format(NODE_FORMAT, vertex.getVertexId(), vertex.getVertexString())); - } - - builder.append("0 [label=\"node-0\"]\n"); - builder.append(String.format(NODE_FORMAT, SCHEDULER, SCHEDULER)); - - builder.append("}"); - return builder.toString(); + StringBuilder builder = new StringBuilder("digraph G {\n"); + for (PipelineEdge edge : pipelineEdgeList) { + builder.append( + String.format( + "%d -> %d [label = \"%s\"]\n", + edge.getSrcId(), edge.getTargetId(), edge.getPartitionType().toString())); + } + for (PipelineVertex vertex : pipelineVertexMap.values()) { + builder.append(String.format(NODE_FORMAT, vertex.getVertexId(), vertex.getVertexString())); } - public String getNodeInfo() { - Map id2Info = new HashMap<>(); - for (PipelineVertex vertex : pipelineVertexMap.values()) { - GeaFlowNodeInfo node = new GeaFlowNodeInfo(vertex.getVertexId(), vertex.getType().name(), - vertex.getOperator()); - id2Info.put(vertex.getVertexName(), node); + builder.append("0 [label=\"node-0\"]\n"); + builder.append(String.format(NODE_FORMAT, SCHEDULER, SCHEDULER)); - } - return JSON.toJSONString(id2Info); - } + builder.append("}"); + return builder.toString(); + } - @Override - public String toString() { - return plan.toString(); + public String getNodeInfo() { + Map id2Info = new HashMap<>(); + for (PipelineVertex vertex : pipelineVertexMap.values()) { + GeaFlowNodeInfo node = + new GeaFlowNodeInfo(vertex.getVertexId(), vertex.getType().name(), vertex.getOperator()); + id2Info.put(vertex.getVertexName(), node); } + return JSON.toJSONString(id2Info); + } + + @Override + public String toString() { + return plan.toString(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/core/graph/builder/ExecutionGraphBuilderTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/core/graph/builder/ExecutionGraphBuilderTest.java index f4be83d11..d0d64a1dc 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/core/graph/builder/ExecutionGraphBuilderTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/core/graph/builder/ExecutionGraphBuilderTest.java @@ -22,11 +22,10 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Iterator; import java.util.concurrent.atomic.AtomicInteger; + import org.apache.geaflow.api.collector.Collector; import org.apache.geaflow.api.function.base.AggregateFunction; import org.apache.geaflow.api.function.base.FlatMapFunction; @@ -83,941 +82,1000 @@ import org.junit.Assert; import org.junit.Test; -public class ExecutionGraphBuilderTest { - - @Test - public void testSingleOutput() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - Configuration configuration = new Configuration(); - configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); - when(context.getConfig()).thenReturn(configuration); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - PStreamSink sink = new WindowStreamSource<>(context, - new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)) - .map(p -> p).withParallelism(2).keyBy(q -> q).reduce((oldValue, newValue) -> (oldValue)) - .withParallelism(3).sink(p -> { - }).withParallelism(3); - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - - ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); - ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); - - Assert.assertEquals(1, graph.getVertexGroupMap().size()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getIterationCount()); - ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().values().stream().findFirst().get(); - Assert.assertEquals(5, vertexGroup.getVertexMap().size()); - Assert.assertEquals(Long.MAX_VALUE, vertexGroup.getCycleGroupMeta().getIterationCount()); - Assert.assertEquals(5, vertexGroup.getCycleGroupMeta().getFlyingCount()); - Assert.assertFalse(vertexGroup.getCycleGroupMeta().isIterative()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().getAffinityLevel() == AffinityLevel.worker); - Assert.assertTrue(vertexGroup.getVertexMap().values().stream().anyMatch( - vertex -> vertex.getAffinityLevel() == AffinityLevel.worker)); - } +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; - @Test - public void testMultiOutput() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); - when(context.getConfig()).thenReturn(configuration); - WindowStreamSource ds1 = new WindowStreamSource<>(context, - new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); - WindowStreamSource ds2 = new WindowStreamSource<>(context, - new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); - - PStreamSink sink1 = ds1.map(p -> p).withParallelism(2).keyBy(q -> q).reduce((oldValue, newValue) -> (oldValue)) - .withParallelism(3).sink(p -> { - }).withParallelism(3); - PStreamSink sink2 = ds2.sink(v -> { - }); - - when(context.getActions()).thenReturn(ImmutableList.of(sink1, sink2)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - - ChainCombiner combiner = new ChainCombiner(); - combiner.combineVertex(pipelineGraph); - - ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); - ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); - - Assert.assertEquals(1, graph.getVertexGroupMap().size()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getIterationCount()); - ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().values().stream().findFirst().get(); - Assert.assertEquals(4, vertexGroup.getVertexMap().size()); - Assert.assertEquals(Long.MAX_VALUE, vertexGroup.getCycleGroupMeta().getIterationCount()); - Assert.assertEquals(5, vertexGroup.getCycleGroupMeta().getFlyingCount()); - Assert.assertFalse(vertexGroup.getCycleGroupMeta().isIterative()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().getAffinityLevel() == AffinityLevel.worker); - } +public class ExecutionGraphBuilderTest { - @Test - public void testAllWindowWithReduceTwoAndSinkFourConcurrency() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); - when(context.getConfig()).thenReturn(configuration); - WindowStreamSource source = new WindowStreamSource(context, - new CollectionSource<>(ImmutableList.of(1, 2, 3)), AllWindow.getInstance()); - PStreamSink sink = source + @Test + public void testSingleOutput() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + Configuration configuration = new Configuration(); + configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); + when(context.getConfig()).thenReturn(configuration); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + PStreamSink sink = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)) + .map(p -> p) + .withParallelism(2) + .keyBy(q -> q) + .reduce((oldValue, newValue) -> (oldValue)) + .withParallelism(3) + .sink(p -> {}) + .withParallelism(3); + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + + ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); + ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); + + Assert.assertEquals(1, graph.getVertexGroupMap().size()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getIterationCount()); + ExecutionVertexGroup vertexGroup = + graph.getVertexGroupMap().values().stream().findFirst().get(); + Assert.assertEquals(5, vertexGroup.getVertexMap().size()); + Assert.assertEquals(Long.MAX_VALUE, vertexGroup.getCycleGroupMeta().getIterationCount()); + Assert.assertEquals(5, vertexGroup.getCycleGroupMeta().getFlyingCount()); + Assert.assertFalse(vertexGroup.getCycleGroupMeta().isIterative()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().getAffinityLevel() == AffinityLevel.worker); + Assert.assertTrue( + vertexGroup.getVertexMap().values().stream() + .anyMatch(vertex -> vertex.getAffinityLevel() == AffinityLevel.worker)); + } + + @Test + public void testMultiOutput() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); + when(context.getConfig()).thenReturn(configuration); + WindowStreamSource ds1 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); + WindowStreamSource ds2 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); + + PStreamSink sink1 = + ds1.map(p -> p) + .withParallelism(2) + .keyBy(q -> q) + .reduce((oldValue, newValue) -> (oldValue)) + .withParallelism(3) + .sink(p -> {}) + .withParallelism(3); + PStreamSink sink2 = ds2.sink(v -> {}); + + when(context.getActions()).thenReturn(ImmutableList.of(sink1, sink2)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + + ChainCombiner combiner = new ChainCombiner(); + combiner.combineVertex(pipelineGraph); + + ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); + ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); + + Assert.assertEquals(1, graph.getVertexGroupMap().size()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getIterationCount()); + ExecutionVertexGroup vertexGroup = + graph.getVertexGroupMap().values().stream().findFirst().get(); + Assert.assertEquals(4, vertexGroup.getVertexMap().size()); + Assert.assertEquals(Long.MAX_VALUE, vertexGroup.getCycleGroupMeta().getIterationCount()); + Assert.assertEquals(5, vertexGroup.getCycleGroupMeta().getFlyingCount()); + Assert.assertFalse(vertexGroup.getCycleGroupMeta().isIterative()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().getAffinityLevel() == AffinityLevel.worker); + } + + @Test + public void testAllWindowWithReduceTwoAndSinkFourConcurrency() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); + when(context.getConfig()).thenReturn(configuration); + WindowStreamSource source = + new WindowStreamSource( + context, new CollectionSource<>(ImmutableList.of(1, 2, 3)), AllWindow.getInstance()); + PStreamSink sink = + source .map(e -> Tuple.of(e, 1)) .keyBy(v -> ((Tuple) v).f0) .reduce(new CountFunc()) .withParallelism(2) - .sink(v -> { - }) + .sink(v -> {}) .withParallelism(4); - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - - ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); - ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); - - Assert.assertEquals(3, graph.getVertexGroupMap().size()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getIterationCount()); - ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().values().stream().findFirst().get(); - Assert.assertEquals(1, vertexGroup.getVertexMap().size()); - Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getIterationCount()); - Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); - Assert.assertFalse(vertexGroup.getCycleGroupMeta().isIterative()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.windowed); - Assert.assertTrue(graph.getCycleGroupMeta().getGroupType() == CycleGroupType.windowed); - } - - @Test - public void testAllWindowWithSingleConcurrency() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); - when(context.getConfig()).thenReturn(configuration); - WindowStreamSource source = new WindowStreamSource(context, - new CollectionSource<>(ImmutableList.of(1, 2, 3)), AllWindow.getInstance()); - PStreamSink sink = source + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + + ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); + ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); + + Assert.assertEquals(3, graph.getVertexGroupMap().size()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getIterationCount()); + ExecutionVertexGroup vertexGroup = + graph.getVertexGroupMap().values().stream().findFirst().get(); + Assert.assertEquals(1, vertexGroup.getVertexMap().size()); + Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getIterationCount()); + Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); + Assert.assertFalse(vertexGroup.getCycleGroupMeta().isIterative()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.windowed); + Assert.assertTrue(graph.getCycleGroupMeta().getGroupType() == CycleGroupType.windowed); + } + + @Test + public void testAllWindowWithSingleConcurrency() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); + when(context.getConfig()).thenReturn(configuration); + WindowStreamSource source = + new WindowStreamSource( + context, new CollectionSource<>(ImmutableList.of(1, 2, 3)), AllWindow.getInstance()); + PStreamSink sink = + source .map(e -> Tuple.of(e, 1)) .keyBy(v -> ((Tuple) v).f0) .reduce(new CountFunc()) .withParallelism(1) - .sink(v -> { - }) + .sink(v -> {}) .withParallelism(1); - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - - ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); - ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); - - Assert.assertEquals(2, graph.getVertexGroupMap().size()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getIterationCount()); - ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().values().stream().findFirst().get(); - Assert.assertEquals(1, vertexGroup.getVertexMap().size()); - Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getIterationCount()); - Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); - Assert.assertFalse(vertexGroup.getCycleGroupMeta().isIterative()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.windowed); - Assert.assertTrue(graph.getCycleGroupMeta().getGroupType() == CycleGroupType.windowed); - } - - @Test - public void testOperatorChain() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - - WindowStreamSource source = new WindowStreamSource(context, - new CollectionSource<>(ImmutableList.of(1, 2, 3)), SizeTumblingWindow.of(2)); - WindowDataStream filter1 = new WindowDataStream(source, new MapOperator(x -> x)); - WindowDataStream filter2 = new WindowDataStream(filter1, new MapOperator(x -> x)); - WindowDataStream mapper1 = new WindowDataStream(filter2, new MapOperator(x -> x)); - WindowDataStream mapper2 = new WindowDataStream(mapper1, new MapOperator(x -> x)); - PStreamSink sink1 = new WindowStreamSink(mapper2, new SinkOperator(x -> { - })); - - when(context.getActions()).thenReturn(ImmutableList.of(sink1)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - - ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); - ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); - - Assert.assertEquals(1, graph.getVertexGroupMap().size()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getIterationCount()); - ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().values().stream().findFirst().get(); - Assert.assertEquals(1, vertexGroup.getVertexMap().size()); - Assert.assertEquals(Long.MAX_VALUE, vertexGroup.getCycleGroupMeta().getIterationCount()); - Assert.assertEquals(5, vertexGroup.getCycleGroupMeta().getFlyingCount()); - Assert.assertFalse(vertexGroup.getCycleGroupMeta().isIterative()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.pipelined); - } - - @Test - public void testKeyAgg() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - when(context.getConfig()).thenReturn(configuration); - PWindowSource source = new WindowStreamSource(context, - new CollectionSource<>(ImmutableList.of(1, 2, 3)), SizeTumblingWindow.of(2)); - PStreamSink sink = source.flatMap(new FlatMapFunction() { - @Override - public void flatMap(String value, Collector collector) { + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + + ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); + ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); + + Assert.assertEquals(2, graph.getVertexGroupMap().size()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getIterationCount()); + ExecutionVertexGroup vertexGroup = + graph.getVertexGroupMap().values().stream().findFirst().get(); + Assert.assertEquals(1, vertexGroup.getVertexMap().size()); + Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getIterationCount()); + Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); + Assert.assertFalse(vertexGroup.getCycleGroupMeta().isIterative()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.windowed); + Assert.assertTrue(graph.getCycleGroupMeta().getGroupType() == CycleGroupType.windowed); + } + + @Test + public void testOperatorChain() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + + WindowStreamSource source = + new WindowStreamSource( + context, new CollectionSource<>(ImmutableList.of(1, 2, 3)), SizeTumblingWindow.of(2)); + WindowDataStream filter1 = new WindowDataStream(source, new MapOperator(x -> x)); + WindowDataStream filter2 = new WindowDataStream(filter1, new MapOperator(x -> x)); + WindowDataStream mapper1 = new WindowDataStream(filter2, new MapOperator(x -> x)); + WindowDataStream mapper2 = new WindowDataStream(mapper1, new MapOperator(x -> x)); + PStreamSink sink1 = new WindowStreamSink(mapper2, new SinkOperator(x -> {})); + + when(context.getActions()).thenReturn(ImmutableList.of(sink1)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + + ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); + ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); + + Assert.assertEquals(1, graph.getVertexGroupMap().size()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getIterationCount()); + ExecutionVertexGroup vertexGroup = + graph.getVertexGroupMap().values().stream().findFirst().get(); + Assert.assertEquals(1, vertexGroup.getVertexMap().size()); + Assert.assertEquals(Long.MAX_VALUE, vertexGroup.getCycleGroupMeta().getIterationCount()); + Assert.assertEquals(5, vertexGroup.getCycleGroupMeta().getFlyingCount()); + Assert.assertFalse(vertexGroup.getCycleGroupMeta().isIterative()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.pipelined); + } + + @Test + public void testKeyAgg() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + when(context.getConfig()).thenReturn(configuration); + PWindowSource source = + new WindowStreamSource( + context, new CollectionSource<>(ImmutableList.of(1, 2, 3)), SizeTumblingWindow.of(2)); + PStreamSink sink = + source + .flatMap( + new FlatMapFunction() { + @Override + public void flatMap(String value, Collector collector) { String[] records = value.split(","); for (String record : records) { - collector.partition(Long.valueOf(record)); + collector.partition(Long.valueOf(record)); } - } - }) + } + }) .map(p -> Tuple.of(p, p)) .keyBy(p -> ((long) ((Tuple) p).f0) % 7) .aggregate(new AggFunc()) .withParallelism(3) .map(v -> String.format("%s,%s", ((Tuple) v).f0, ((Tuple) v).f1)) - .sink(v -> { - }).withParallelism(2); - - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - - ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); - ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); - - Assert.assertEquals(1, graph.getVertexGroupMap().size()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getIterationCount()); - ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().values().stream().findFirst().get(); - Assert.assertEquals(3, vertexGroup.getVertexMap().size()); - Assert.assertEquals(Long.MAX_VALUE, vertexGroup.getCycleGroupMeta().getIterationCount()); - Assert.assertEquals(5, vertexGroup.getCycleGroupMeta().getFlyingCount()); - Assert.assertFalse(vertexGroup.getCycleGroupMeta().isIterative()); - Assert.assertTrue(vertexGroup.getVertexMap().get(1).getNumPartitions() == 4); - Assert.assertTrue(vertexGroup.getVertexMap().get(5).getNumPartitions() == 2); - Assert.assertTrue(vertexGroup.getVertexMap().get(7).getNumPartitions() == 2); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.pipelined); - Assert.assertTrue(graph.getCycleGroupMeta().getGroupType() == CycleGroupType.pipelined); - } + .sink(v -> {}) + .withParallelism(2); - @Test - public void testIncGraphCompute() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - when(context.getConfig()).thenReturn(configuration); - - PStreamSource> vertices = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); - PStreamSource> edges = - new WindowStreamSource(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); - - final String graphName = "graph_view_name"; - GraphViewDesc graphViewDesc = GraphViewBuilder.createGraphView(graphName) + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + + ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); + ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); + + Assert.assertEquals(1, graph.getVertexGroupMap().size()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getIterationCount()); + ExecutionVertexGroup vertexGroup = + graph.getVertexGroupMap().values().stream().findFirst().get(); + Assert.assertEquals(3, vertexGroup.getVertexMap().size()); + Assert.assertEquals(Long.MAX_VALUE, vertexGroup.getCycleGroupMeta().getIterationCount()); + Assert.assertEquals(5, vertexGroup.getCycleGroupMeta().getFlyingCount()); + Assert.assertFalse(vertexGroup.getCycleGroupMeta().isIterative()); + Assert.assertTrue(vertexGroup.getVertexMap().get(1).getNumPartitions() == 4); + Assert.assertTrue(vertexGroup.getVertexMap().get(5).getNumPartitions() == 2); + Assert.assertTrue(vertexGroup.getVertexMap().get(7).getNumPartitions() == 2); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.pipelined); + Assert.assertTrue(graph.getCycleGroupMeta().getGroupType() == CycleGroupType.pipelined); + } + + @Test + public void testIncGraphCompute() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + when(context.getConfig()).thenReturn(configuration); + + PStreamSource> vertices = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); + PStreamSource> edges = + new WindowStreamSource( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); + + final String graphName = "graph_view_name"; + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(graphName) .withShardNum(4) .withBackend(IViewDesc.BackendType.RocksDB) - .withSchema(new GraphMetaType(IntegerType.INSTANCE, ValueVertex.class, - Integer.class, ValueEdge.class, IntegerType.class)) + .withSchema( + new GraphMetaType( + IntegerType.INSTANCE, + ValueVertex.class, + Integer.class, + ValueEdge.class, + IntegerType.class)) .build(); - PGraphView fundGraphView = - new IncGraphView<>(context, graphViewDesc); + PGraphView fundGraphView = + new IncGraphView<>(context, graphViewDesc); - PIncGraphView incGraphView = - fundGraphView.appendGraph( - vertices.window(WindowFactory.createSizeTumblingWindow(1)), - edges.window(WindowFactory.createSizeTumblingWindow(1))); + PIncGraphView incGraphView = + fundGraphView.appendGraph( + vertices.window(WindowFactory.createSizeTumblingWindow(1)), + edges.window(WindowFactory.createSizeTumblingWindow(1))); - PStreamSink sink = incGraphView.incrementalCompute(new IncGraphAlgorithms(3)) + PStreamSink sink = + incGraphView + .incrementalCompute(new IncGraphAlgorithms(3)) .getVertices() .map(v -> String.format("%s,%s", v.getId(), v.getValue())) - .sink(v -> { - }); - - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - - ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); - ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); - - Assert.assertEquals(4, graph.getVertexGroupMap().size()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); - Assert.assertEquals(Long.MAX_VALUE, graph.getCycleGroupMeta().getIterationCount()); - - ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().get(3); - Assert.assertEquals(3, vertexGroup.getCycleGroupMeta().getIterationCount()); - Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().isIterative()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().getAffinityLevel() == AffinityLevel.worker); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.incremental); - Assert.assertTrue(graph.getCycleGroupMeta().getGroupType() == CycleGroupType.incremental); - } - - @Test - public void testStaticGraphCompute() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - when(context.getConfig()).thenReturn(configuration); - - PWindowSource> vertices = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); - PWindowSource> edges = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); - - - PGraphWindow graphWindow = new WindowStreamGraph(createGraphViewDesc(4), context, - vertices, edges); - PStreamSink sink = graphWindow.compute(new PRAlgorithms(3)) + .sink(v -> {}); + + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + + ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); + ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); + + Assert.assertEquals(4, graph.getVertexGroupMap().size()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); + Assert.assertEquals(Long.MAX_VALUE, graph.getCycleGroupMeta().getIterationCount()); + + ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().get(3); + Assert.assertEquals(3, vertexGroup.getCycleGroupMeta().getIterationCount()); + Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().isIterative()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().getAffinityLevel() == AffinityLevel.worker); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.incremental); + Assert.assertTrue(graph.getCycleGroupMeta().getGroupType() == CycleGroupType.incremental); + } + + @Test + public void testStaticGraphCompute() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + when(context.getConfig()).thenReturn(configuration); + + PWindowSource> vertices = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); + PWindowSource> edges = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); + + PGraphWindow graphWindow = + new WindowStreamGraph(createGraphViewDesc(4), context, vertices, edges); + PStreamSink sink = + graphWindow + .compute(new PRAlgorithms(3)) .compute(3) .getVertices() - .sink(v -> { - }) + .sink(v -> {}) .withParallelism(2); - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - - ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); - ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); - - Assert.assertEquals(4, graph.getVertexGroupMap().size()); - Assert.assertEquals(4, graph.getGroupEdgeMap().size()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); - Assert.assertEquals(Long.MAX_VALUE, graph.getCycleGroupMeta().getIterationCount()); - - ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().get(3); - Assert.assertEquals(3, vertexGroup.getCycleGroupMeta().getIterationCount()); - Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().isIterative()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().getAffinityLevel() == AffinityLevel.worker); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); - Assert.assertTrue(graph.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); - } - - @Test - public void testAllWindowStaticGraphCompute() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - when(context.getConfig()).thenReturn(configuration); - - PWindowSource> vertices = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), AllWindow.getInstance()); - PWindowSource> edges = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), AllWindow.getInstance()); - - PGraphWindow graphWindow = new WindowStreamGraph(createGraphViewDesc(4), context, - vertices, edges); - PStreamSink sink = graphWindow.compute(new PRAlgorithms(3)) + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + + ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); + ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); + + Assert.assertEquals(4, graph.getVertexGroupMap().size()); + Assert.assertEquals(4, graph.getGroupEdgeMap().size()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); + Assert.assertEquals(Long.MAX_VALUE, graph.getCycleGroupMeta().getIterationCount()); + + ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().get(3); + Assert.assertEquals(3, vertexGroup.getCycleGroupMeta().getIterationCount()); + Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().isIterative()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().getAffinityLevel() == AffinityLevel.worker); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); + Assert.assertTrue(graph.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); + } + + @Test + public void testAllWindowStaticGraphCompute() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + when(context.getConfig()).thenReturn(configuration); + + PWindowSource> vertices = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), AllWindow.getInstance()); + PWindowSource> edges = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), AllWindow.getInstance()); + + PGraphWindow graphWindow = + new WindowStreamGraph(createGraphViewDesc(4), context, vertices, edges); + PStreamSink sink = + graphWindow + .compute(new PRAlgorithms(3)) .compute(3) .getVertices() - .sink(v -> { - }) + .sink(v -> {}) .withParallelism(2); - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - - ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); - ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); - - Assert.assertEquals(4, graph.getVertexGroupMap().size()); - Assert.assertEquals(4, graph.getGroupEdgeMap().size()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getIterationCount()); - - ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().get(3); - Assert.assertEquals(3, vertexGroup.getCycleGroupMeta().getIterationCount()); - Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().isIterative()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().getAffinityLevel() == AffinityLevel.worker); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); - Assert.assertTrue(graph.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); - } - - @Test - public void testWindowGraphTraversal() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - when(context.getConfig()).thenReturn(configuration); - - PWindowSource> vertices = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); - PWindowSource> edges = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); - - PStreamSource> triggerSource = - new WindowStreamSource<>(context, - new CollectionSource<>(Lists.newArrayList(new VertexBeginTraversalRequest(3))), - AllWindow.getInstance()); - PWindowStream> windowTrigger = - triggerSource.window(WindowFactory.createSizeTumblingWindow(3)); - - PGraphWindow graphWindow = new WindowStreamGraph(createGraphViewDesc(4), context, vertices, edges); - PStreamSink sink = graphWindow.traversal(new GraphTraversalAlgorithms(3)).start(windowTrigger).sink(v -> { - }); - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - - ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); - ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); - Assert.assertEquals(5, graph.getVertexGroupMap().size()); - Assert.assertEquals(5, graph.getGroupEdgeMap().size()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); - Assert.assertEquals(Long.MAX_VALUE, graph.getCycleGroupMeta().getIterationCount()); - - ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().get(4); - Assert.assertEquals(3, vertexGroup.getCycleGroupMeta().getIterationCount()); - Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().isIterative()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().getAffinityLevel() == AffinityLevel.worker); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); - Assert.assertTrue(graph.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + + ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); + ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); + + Assert.assertEquals(4, graph.getVertexGroupMap().size()); + Assert.assertEquals(4, graph.getGroupEdgeMap().size()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getIterationCount()); + + ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().get(3); + Assert.assertEquals(3, vertexGroup.getCycleGroupMeta().getIterationCount()); + Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().isIterative()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().getAffinityLevel() == AffinityLevel.worker); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); + Assert.assertTrue(graph.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); + } + + @Test + public void testWindowGraphTraversal() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + when(context.getConfig()).thenReturn(configuration); + + PWindowSource> vertices = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); + PWindowSource> edges = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); + + PStreamSource> triggerSource = + new WindowStreamSource<>( + context, + new CollectionSource<>(Lists.newArrayList(new VertexBeginTraversalRequest(3))), + AllWindow.getInstance()); + PWindowStream> windowTrigger = + triggerSource.window(WindowFactory.createSizeTumblingWindow(3)); + + PGraphWindow graphWindow = + new WindowStreamGraph(createGraphViewDesc(4), context, vertices, edges); + PStreamSink sink = + graphWindow.traversal(new GraphTraversalAlgorithms(3)).start(windowTrigger).sink(v -> {}); + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + + ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); + ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); + Assert.assertEquals(5, graph.getVertexGroupMap().size()); + Assert.assertEquals(5, graph.getGroupEdgeMap().size()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); + Assert.assertEquals(Long.MAX_VALUE, graph.getCycleGroupMeta().getIterationCount()); + + ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().get(4); + Assert.assertEquals(3, vertexGroup.getCycleGroupMeta().getIterationCount()); + Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().isIterative()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().getAffinityLevel() == AffinityLevel.worker); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); + Assert.assertTrue(graph.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); + } + + @Test + public void testMultiSourceWindowGraphTraversal() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + when(context.getConfig()).thenReturn(configuration); + + PWindowSource> vertices1 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); + PWindowSource> vertices2 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); + + PWindowSource> edges = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); + + PStreamSource> triggerSource = + new WindowStreamSource<>( + context, + new CollectionSource<>(Lists.newArrayList(new VertexBeginTraversalRequest(3))), + AllWindow.getInstance()); + + PGraphWindow graphWindow = + new WindowStreamGraph(createGraphViewDesc(4), context, vertices1.union(vertices2), edges); + PStreamSink sink = + graphWindow.traversal(new GraphTraversalAlgorithms(3)).start(triggerSource).sink(v -> {}); + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + + ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); + ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); + Assert.assertEquals(4, graph.getVertexGroupMap().size()); + Assert.assertEquals(4, graph.getGroupEdgeMap().size()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); + Assert.assertEquals(Long.MAX_VALUE, graph.getCycleGroupMeta().getIterationCount()); + + ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().get(3); + Assert.assertEquals(3, vertexGroup.getCycleGroupMeta().getIterationCount()); + Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().isIterative()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); + Assert.assertTrue(graph.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); + } + + @Test + public void testAllWindowGraphTraversal() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + when(context.getConfig()).thenReturn(configuration); + + PWindowSource> vertices = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), AllWindow.getInstance()); + PWindowSource> edges = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), AllWindow.getInstance()); + + PStreamSource> triggerSource = + new WindowStreamSource<>( + context, + new CollectionSource<>(Lists.newArrayList(new VertexBeginTraversalRequest(3))), + AllWindow.getInstance()); + + PGraphWindow graphWindow = + new WindowStreamGraph(createGraphViewDesc(4), context, vertices, edges); + PStreamSink sink = + graphWindow.traversal(new GraphTraversalAlgorithms(3)).start(triggerSource).sink(v -> {}); + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + + ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); + ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); + Assert.assertEquals(4, graph.getVertexGroupMap().size()); + Assert.assertEquals(4, graph.getGroupEdgeMap().size()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getIterationCount()); + + ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().get(3); + Assert.assertEquals(3, vertexGroup.getCycleGroupMeta().getIterationCount()); + Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().isIterative()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().getAffinityLevel() == AffinityLevel.worker); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); + Assert.assertTrue(graph.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); + } + + @Test + public void testTwoSourceWithGraphUnion() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + when(context.getConfig()).thenReturn(configuration); + + PWindowSource source1 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); + PWindowSource source2 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); + + PWindowStream> v1 = source1.map(i -> new ValueVertex<>(i, (double) i)); + PWindowStream> e1 = source1.map(i -> new ValueEdge<>(i, i, i)); + + PWindowStream> v2 = source2.map(i -> new ValueVertex<>(i, (double) i)); + PWindowStream> e2 = source2.map(i -> new ValueEdge<>(i, i, i)); + + PWindowStream> v = v1.union(v2); + PWindowStream> e = e1.union(e2); + + PGraphWindow graphWindow = + new WindowStreamGraph(createGraphViewDesc(1), context, v, e); + PStreamSink sink = graphWindow.compute(new PRAlgorithms(3)).getVertices().sink(r -> {}); + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + + ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); + ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); + Assert.assertEquals(3, graph.getVertexGroupMap().size()); + Assert.assertEquals(4, graph.getGroupEdgeMap().size()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); + Assert.assertEquals(Long.MAX_VALUE, graph.getCycleGroupMeta().getIterationCount()); + + ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().get(2); + Assert.assertEquals(3, vertexGroup.getCycleGroupMeta().getIterationCount()); + Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().isIterative()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().getAffinityLevel() == AffinityLevel.worker); + } + + @Test + public void testThreeSourceWithGraphUnion() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + when(context.getConfig()).thenReturn(configuration); + + PWindowSource source1 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); + PWindowSource source2 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); + PWindowSource source3 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); + + PWindowStream> v1 = source1.map(i -> new ValueVertex<>(i, (double) i)); + PWindowStream> e1 = source1.map(i -> new ValueEdge<>(i, i, i)); + + PWindowStream> v2 = source2.map(i -> new ValueVertex<>(i, (double) i)); + PWindowStream> e2 = source2.map(i -> new ValueEdge<>(i, i, i)); + + PWindowStream> v3 = source3.map(i -> new ValueVertex<>(i, (double) i)); + PWindowStream> e3 = source3.map(i -> new ValueEdge<>(i, i, i)); + + PWindowStream> v = v1.union(v2).union(v3); + PWindowStream> e = e1.union(e2).union(e3); + + PGraphWindow graphWindow = + new WindowStreamGraph(createGraphViewDesc(1), context, v, e); + PStreamSink sink = graphWindow.compute(new PRAlgorithms(3)).getVertices().sink(r -> {}); + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + + ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); + ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); + Assert.assertEquals(3, graph.getVertexGroupMap().size()); + Assert.assertEquals(4, graph.getGroupEdgeMap().size()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); + Assert.assertEquals(Long.MAX_VALUE, graph.getCycleGroupMeta().getIterationCount()); + + ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().get(2); + Assert.assertEquals(3, vertexGroup.getCycleGroupMeta().getIterationCount()); + Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().isIterative()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().getAffinityLevel() == AffinityLevel.worker); + } + + @Test + public void testTenSourceWithGraphUnion() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + when(context.getConfig()).thenReturn(configuration); + + PWindowSource source1 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); + PWindowSource source2 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); + PWindowSource source3 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); + PWindowSource source4 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); + PWindowSource source5 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); + PWindowSource source6 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); + PWindowSource source7 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); + PWindowSource source8 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); + PWindowSource source9 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); + + PWindowStream> v1 = source1.map(i -> new ValueVertex<>(i, (double) i)); + PWindowStream> e1 = source1.map(i -> new ValueEdge<>(i, i, i)); + + PWindowStream> v2 = source2.map(i -> new ValueVertex<>(i, (double) i)); + PWindowStream> e2 = source2.map(i -> new ValueEdge<>(i, i, i)); + + PWindowStream> v3 = source3.map(i -> new ValueVertex<>(i, (double) i)); + PWindowStream> e3 = source3.map(i -> new ValueEdge<>(i, i, i)); + + PWindowStream> v4 = source4.map(i -> new ValueVertex<>(i, (double) i)); + PWindowStream> e4 = source4.map(i -> new ValueEdge<>(i, i, i)); + + PWindowStream> v5 = source5.map(i -> new ValueVertex<>(i, (double) i)); + PWindowStream> e5 = source5.map(i -> new ValueEdge<>(i, i, i)); + + PWindowStream> v6 = source6.map(i -> new ValueVertex<>(i, (double) i)); + PWindowStream> e6 = source6.map(i -> new ValueEdge<>(i, i, i)); + + PWindowStream> v7 = source7.map(i -> new ValueVertex<>(i, (double) i)); + PWindowStream> e7 = source7.map(i -> new ValueEdge<>(i, i, i)); + + PWindowStream> v8 = source8.map(i -> new ValueVertex<>(i, (double) i)); + PWindowStream> e8 = source8.map(i -> new ValueEdge<>(i, i, i)); + + PWindowStream> v9 = source9.map(i -> new ValueVertex<>(i, (double) i)); + PWindowStream> e9 = source9.map(i -> new ValueEdge<>(i, i, i)); + + PWindowStream> v = + v1.union(v2).union(v3).union(v4).union(v5).union(v6).union(v7).union(v8).union(v9); + PWindowStream> e = + e1.union(e2).union(e3).union(e4).union(e5).union(e6).union(e7).union(e8).union(e9); + + PGraphWindow graphWindow = + new WindowStreamGraph(createGraphViewDesc(1), context, v, e); + PStreamSink sink = graphWindow.compute(new PRAlgorithms(3)).getVertices().sink(r -> {}); + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + + ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); + ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); + Assert.assertEquals(3, graph.getVertexGroupMap().size()); + Assert.assertEquals(4, graph.getGroupEdgeMap().size()); + Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); + Assert.assertEquals(Long.MAX_VALUE, graph.getCycleGroupMeta().getIterationCount()); + + ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().get(2); + Assert.assertEquals(3, vertexGroup.getCycleGroupMeta().getIterationCount()); + Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().isIterative()); + Assert.assertTrue(vertexGroup.getCycleGroupMeta().getAffinityLevel() == AffinityLevel.worker); + } + + @Test + public void testGroupVertexDiamondDependency() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + + PWindowSource source1 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); + PWindowSource source2 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); + PWindowStream map1 = source2.map(i -> i + 1); + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); + + PWindowStream union1 = source1.union(source2).map(i -> i * 2); + PWindowStream union2 = union1.union(map1); + PStreamSink sink = union2.sink(r -> {}); + + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + + ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); + ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); + Assert.assertEquals(1, graph.getVertexGroupMap().size()); + Assert.assertEquals(0, graph.getGroupEdgeMap().size()); + Assert.assertEquals(5, graph.getVertexGroupMap().get(1).getCycleGroupMeta().getFlyingCount()); + Assert.assertEquals( + Long.MAX_VALUE, graph.getVertexGroupMap().get(1).getCycleGroupMeta().getIterationCount()); + } + + @Test + public void testMultiGraphTraversal() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + when(context.getConfig()).thenReturn(configuration); + + PWindowSource> vertices1 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); + PWindowSource> vertices2 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); + + PWindowSource> edges = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); + + PStreamSource> triggerSource = + new WindowStreamSource<>( + context, + new CollectionSource<>(Lists.newArrayList(new VertexBeginTraversalRequest(3))), + AllWindow.getInstance()); + + PGraphWindow graphWindow1 = + new WindowStreamGraph(createGraphViewDesc(4), context, vertices1, edges); + PWindowStream union = + graphWindow1 + .traversal(new GraphTraversalAlgorithms(3)) + .start(triggerSource) + .union(vertices2); + + PGraphWindow graphWindow2 = + new WindowStreamGraph(createGraphViewDesc(4), context, union, edges); + + PStreamSink sink = + graphWindow2.traversal(new GraphTraversalAlgorithms(3)).start(triggerSource).sink(v -> {}); + + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + + ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); + ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); + Assert.assertEquals(7, graph.getVertexGroupMap().size()); + Assert.assertEquals(9, graph.getGroupEdgeMap().size()); + Assert.assertTrue(graph.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); + } + + public static class IncGraphAlgorithms + extends IncVertexCentricCompute { + + public IncGraphAlgorithms(long iterations) { + super(iterations); } - @Test - public void testMultiSourceWindowGraphTraversal() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - when(context.getConfig()).thenReturn(configuration); - - PWindowSource> vertices1 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); - PWindowSource> vertices2 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); - - - PWindowSource> edges = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); - - PStreamSource> triggerSource = - new WindowStreamSource<>(context, - new CollectionSource<>(Lists.newArrayList(new VertexBeginTraversalRequest(3))), - AllWindow.getInstance()); - - PGraphWindow graphWindow = new WindowStreamGraph(createGraphViewDesc(4), context, - vertices1.union(vertices2), edges); - PStreamSink sink = graphWindow.traversal(new GraphTraversalAlgorithms(3)).start(triggerSource).sink(v -> { - }); - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - - ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); - ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); - Assert.assertEquals(4, graph.getVertexGroupMap().size()); - Assert.assertEquals(4, graph.getGroupEdgeMap().size()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); - Assert.assertEquals(Long.MAX_VALUE, graph.getCycleGroupMeta().getIterationCount()); - - ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().get(3); - Assert.assertEquals(3, vertexGroup.getCycleGroupMeta().getIterationCount()); - Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().isIterative()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); - Assert.assertTrue(graph.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); + @Override + public IncVertexCentricComputeFunction + getIncComputeFunction() { + return new PRVertexCentricComputeFunction(); } - @Test - public void testAllWindowGraphTraversal() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - when(context.getConfig()).thenReturn(configuration); - - PWindowSource> vertices = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), AllWindow.getInstance()); - PWindowSource> edges = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), AllWindow.getInstance()); - - PStreamSource> triggerSource = - new WindowStreamSource<>(context, - new CollectionSource<>(Lists.newArrayList(new VertexBeginTraversalRequest(3))), - AllWindow.getInstance()); - - PGraphWindow graphWindow = new WindowStreamGraph(createGraphViewDesc(4), context, vertices, edges); - PStreamSink sink = graphWindow.traversal(new GraphTraversalAlgorithms(3)).start(triggerSource).sink(v -> { - }); - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - - ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); - ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); - Assert.assertEquals(4, graph.getVertexGroupMap().size()); - Assert.assertEquals(4, graph.getGroupEdgeMap().size()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getIterationCount()); - - ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().get(3); - Assert.assertEquals(3, vertexGroup.getCycleGroupMeta().getIterationCount()); - Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().isIterative()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().getAffinityLevel() == AffinityLevel.worker); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); - Assert.assertTrue(graph.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; } + } - @Test - public void testTwoSourceWithGraphUnion() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - when(context.getConfig()).thenReturn(configuration); + public static class PRVertexCentricComputeFunction + implements IncVertexCentricComputeFunction { - PWindowSource source1 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); - PWindowSource source2 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); + @Override + public void init(IncGraphComputeContext graphContext) {} - PWindowStream> v1 = source1.map(i -> new ValueVertex<>(i, (double) i)); - PWindowStream> e1 = source1.map(i -> new ValueEdge<>(i, i, i)); + @Override + public void evolve( + Integer vertexId, TemporaryGraph temporaryGraph) {} - PWindowStream> v2 = source2.map(i -> new ValueVertex<>(i, (double) i)); - PWindowStream> e2 = source2.map(i -> new ValueEdge<>(i, i, i)); + @Override + public void compute(Integer vertexId, Iterator messageIterator) {} - PWindowStream> v = v1.union(v2); - PWindowStream> e = e1.union(e2); + @Override + public void finish(Integer vertexId, MutableGraph mutableGraph) {} + } - PGraphWindow graphWindow = new WindowStreamGraph(createGraphViewDesc(1), context, v, e); - PStreamSink sink = graphWindow.compute(new PRAlgorithms(3)) - .getVertices() - .sink(r -> { - }); - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - - ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); - ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); - Assert.assertEquals(3, graph.getVertexGroupMap().size()); - Assert.assertEquals(4, graph.getGroupEdgeMap().size()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); - Assert.assertEquals(Long.MAX_VALUE, graph.getCycleGroupMeta().getIterationCount()); - - ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().get(2); - Assert.assertEquals(3, vertexGroup.getCycleGroupMeta().getIterationCount()); - Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().isIterative()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().getAffinityLevel() == AffinityLevel.worker); - } - - @Test - public void testThreeSourceWithGraphUnion() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - when(context.getConfig()).thenReturn(configuration); - - PWindowSource source1 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); - PWindowSource source2 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); - PWindowSource source3 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); - - PWindowStream> v1 = source1.map(i -> new ValueVertex<>(i, (double) i)); - PWindowStream> e1 = source1.map(i -> new ValueEdge<>(i, i, i)); + public static class PRAlgorithms extends VertexCentricCompute { - PWindowStream> v2 = source2.map(i -> new ValueVertex<>(i, (double) i)); - PWindowStream> e2 = source2.map(i -> new ValueEdge<>(i, i, i)); - - PWindowStream> v3 = source3.map(i -> new ValueVertex<>(i, (double) i)); - PWindowStream> e3 = source3.map(i -> new ValueEdge<>(i, i, i)); - - PWindowStream> v = v1.union(v2).union(v3); - PWindowStream> e = e1.union(e2).union(e3); - - PGraphWindow graphWindow = new WindowStreamGraph(createGraphViewDesc(1), - context, v, e); - PStreamSink sink = graphWindow.compute(new PRAlgorithms(3)) - .getVertices() - .sink(r -> { - }); - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - - ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); - ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); - Assert.assertEquals(3, graph.getVertexGroupMap().size()); - Assert.assertEquals(4, graph.getGroupEdgeMap().size()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); - Assert.assertEquals(Long.MAX_VALUE, graph.getCycleGroupMeta().getIterationCount()); - - ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().get(2); - Assert.assertEquals(3, vertexGroup.getCycleGroupMeta().getIterationCount()); - Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().isIterative()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().getAffinityLevel() == AffinityLevel.worker); + public PRAlgorithms(long iterations) { + super(iterations); } - @Test - public void testTenSourceWithGraphUnion() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - when(context.getConfig()).thenReturn(configuration); - - PWindowSource source1 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); - PWindowSource source2 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); - PWindowSource source3 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); - PWindowSource source4 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); - PWindowSource source5 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); - PWindowSource source6 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); - PWindowSource source7 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); - PWindowSource source8 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); - PWindowSource source9 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); - - PWindowStream> v1 = source1.map(i -> new ValueVertex<>(i, (double) i)); - PWindowStream> e1 = source1.map(i -> new ValueEdge<>(i, i, i)); - - PWindowStream> v2 = source2.map(i -> new ValueVertex<>(i, (double) i)); - PWindowStream> e2 = source2.map(i -> new ValueEdge<>(i, i, i)); - - PWindowStream> v3 = source3.map(i -> new ValueVertex<>(i, (double) i)); - PWindowStream> e3 = source3.map(i -> new ValueEdge<>(i, i, i)); - - PWindowStream> v4 = source4.map(i -> new ValueVertex<>(i, (double) i)); - PWindowStream> e4 = source4.map(i -> new ValueEdge<>(i, i, i)); - - PWindowStream> v5 = source5.map(i -> new ValueVertex<>(i, (double) i)); - PWindowStream> e5 = source5.map(i -> new ValueEdge<>(i, i, i)); - - PWindowStream> v6 = source6.map(i -> new ValueVertex<>(i, (double) i)); - PWindowStream> e6 = source6.map(i -> new ValueEdge<>(i, i, i)); - - PWindowStream> v7 = source7.map(i -> new ValueVertex<>(i, (double) i)); - PWindowStream> e7 = source7.map(i -> new ValueEdge<>(i, i, i)); - - PWindowStream> v8 = source8.map(i -> new ValueVertex<>(i, (double) i)); - PWindowStream> e8 = source8.map(i -> new ValueEdge<>(i, i, i)); - - PWindowStream> v9 = source9.map(i -> new ValueVertex<>(i, (double) i)); - PWindowStream> e9 = source9.map(i -> new ValueEdge<>(i, i, i)); - - PWindowStream> v = - v1.union(v2).union(v3).union(v4).union(v5).union(v6).union(v7).union(v8).union(v9); - PWindowStream> e = - e1.union(e2).union(e3).union(e4).union(e5).union(e6).union(e7).union(e8).union(e9); - - PGraphWindow graphWindow = new WindowStreamGraph(createGraphViewDesc(1), - context, v, e); - PStreamSink sink = graphWindow.compute(new PRAlgorithms(3)) - .getVertices() - .sink(r -> { - }); - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - - ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); - ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); - Assert.assertEquals(3, graph.getVertexGroupMap().size()); - Assert.assertEquals(4, graph.getGroupEdgeMap().size()); - Assert.assertEquals(1, graph.getCycleGroupMeta().getFlyingCount()); - Assert.assertEquals(Long.MAX_VALUE, graph.getCycleGroupMeta().getIterationCount()); - - ExecutionVertexGroup vertexGroup = graph.getVertexGroupMap().get(2); - Assert.assertEquals(3, vertexGroup.getCycleGroupMeta().getIterationCount()); - Assert.assertEquals(1, vertexGroup.getCycleGroupMeta().getFlyingCount()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().isIterative()); - Assert.assertTrue(vertexGroup.getCycleGroupMeta().getAffinityLevel() == AffinityLevel.worker); + @Override + public VertexCentricComputeFunction getComputeFunction() { + return new PRVertexCentricComputeFunction2(); } - @Test - public void testGroupVertexDiamondDependency() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - - PWindowSource source1 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); - PWindowSource source2 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); - PWindowStream map1 = source2.map(i -> i + 1); - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); - - PWindowStream union1 = source1.union(source2).map(i -> i * 2); - PWindowStream union2 = union1.union(map1); - PStreamSink sink = union2.sink(r -> { - }); - - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - - ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); - ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); - Assert.assertEquals(1, graph.getVertexGroupMap().size()); - Assert.assertEquals(0, graph.getGroupEdgeMap().size()); - Assert.assertEquals(5, graph.getVertexGroupMap().get(1).getCycleGroupMeta().getFlyingCount()); - Assert.assertEquals(Long.MAX_VALUE, graph.getVertexGroupMap().get(1).getCycleGroupMeta().getIterationCount()); + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; } + } - @Test - public void testMultiGraphTraversal() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - when(context.getConfig()).thenReturn(configuration); + public static class PRVertexCentricComputeFunction2 + implements VertexCentricComputeFunction { - PWindowSource> vertices1 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); - PWindowSource> vertices2 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); + @Override + public void init( + VertexCentricComputeFuncContext + vertexCentricFuncContext) {} - PWindowSource> edges = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); + @Override + public void compute(Integer vertexId, Iterator messageIterator) {} - PStreamSource> triggerSource = - new WindowStreamSource<>(context, - new CollectionSource<>(Lists.newArrayList(new VertexBeginTraversalRequest(3))), - AllWindow.getInstance()); + @Override + public void finish() {} + } + public static class GraphTraversalAlgorithms + extends VertexCentricTraversal { - PGraphWindow graphWindow1 = new WindowStreamGraph(createGraphViewDesc(4), context, - vertices1, edges); - PWindowStream union = graphWindow1.traversal(new GraphTraversalAlgorithms(3)) - .start(triggerSource).union(vertices2); - - PGraphWindow graphWindow2 = new WindowStreamGraph(createGraphViewDesc(4), context, - union, edges); - - PStreamSink sink = graphWindow2.traversal(new GraphTraversalAlgorithms(3)).start(triggerSource) - .sink(v -> { - }); - - - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - - ExecutionGraphBuilder builder = new ExecutionGraphBuilder(pipelineGraph); - ExecutionGraph graph = builder.buildExecutionGraph(new Configuration()); - Assert.assertEquals(7, graph.getVertexGroupMap().size()); - Assert.assertEquals(9, graph.getGroupEdgeMap().size()); - Assert.assertTrue(graph.getCycleGroupMeta().getGroupType() == CycleGroupType.statical); + public GraphTraversalAlgorithms(long iterations) { + super(iterations); } - public static class IncGraphAlgorithms extends IncVertexCentricCompute { - - public IncGraphAlgorithms(long iterations) { - super(iterations); - } - - @Override - public IncVertexCentricComputeFunction getIncComputeFunction() { - return new PRVertexCentricComputeFunction(); - } - - @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; - } - + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; } - public static class PRVertexCentricComputeFunction implements IncVertexCentricComputeFunction { - - @Override - public void init(IncGraphComputeContext graphContext) { - } - - @Override - public void evolve(Integer vertexId, TemporaryGraph temporaryGraph) { - } - - - @Override - public void compute(Integer vertexId, Iterator messageIterator) { - } - - - @Override - public void finish(Integer vertexId, MutableGraph mutableGraph) { - } + @Override + public VertexCentricTraversalFunction + getTraversalFunction() { + return null; } + } - public static class PRAlgorithms extends VertexCentricCompute { - - public PRAlgorithms(long iterations) { - super(iterations); - } - - @Override - public VertexCentricComputeFunction getComputeFunction() { - return new PRVertexCentricComputeFunction2(); - } - - @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; - } + public static class CountFunc implements ReduceFunction> { + @Override + public Tuple reduce( + Tuple oldValue, Tuple newValue) { + return Tuple.of(oldValue.f0, oldValue.f1 + newValue.f1); } + } - public static class PRVertexCentricComputeFunction2 implements VertexCentricComputeFunction { - - @Override - public void init(VertexCentricComputeFuncContext vertexCentricFuncContext) { - } + public static class AggFunc + implements AggregateFunction, Tuple, Tuple> { - @Override - public void compute(Integer vertexId, Iterator messageIterator) { - } - - @Override - public void finish() { - - } + @Override + public Tuple createAccumulator() { + return Tuple.of(0L, 0L); } - public static class GraphTraversalAlgorithms extends VertexCentricTraversal { - - public GraphTraversalAlgorithms(long iterations) { - super(iterations); - } - - @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; - } - - @Override - public VertexCentricTraversalFunction getTraversalFunction() { - return null; - } + @Override + public void add(Tuple value, Tuple accumulator) { + accumulator.setF0(value.f0); + accumulator.setF1(value.f1 + accumulator.f1); } - public static class CountFunc implements ReduceFunction> { - - @Override - public Tuple reduce(Tuple oldValue, Tuple newValue) { - return Tuple.of(oldValue.f0, oldValue.f1 + newValue.f1); - } - } - - public static class AggFunc implements - AggregateFunction, Tuple, Tuple> { - - @Override - public Tuple createAccumulator() { - return Tuple.of(0L, 0L); - } - - @Override - public void add(Tuple value, Tuple accumulator) { - accumulator.setF0(value.f0); - accumulator.setF1(value.f1 + accumulator.f1); - } - - @Override - public Tuple getResult(Tuple accumulator) { - return Tuple.of(accumulator.f0, accumulator.f1); - } - - @Override - public Tuple merge(Tuple a, Tuple b) { - return null; - } + @Override + public Tuple getResult(Tuple accumulator) { + return Tuple.of(accumulator.f0, accumulator.f1); } - private GraphViewDesc createGraphViewDesc(int shardNum) { - return GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) - .withShardNum(shardNum) - .withBackend(BackendType.Memory) - .build(); + @Override + public Tuple merge(Tuple a, Tuple b) { + return null; } + } + + private GraphViewDesc createGraphViewDesc(int shardNum) { + return GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + .withShardNum(shardNum) + .withBackend(BackendType.Memory) + .build(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/plan/BasePlanTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/plan/BasePlanTest.java index 2b26f0766..1fbcb14dc 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/plan/BasePlanTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/plan/BasePlanTest.java @@ -31,35 +31,35 @@ public class BasePlanTest { - protected PipelineGraph plan; + protected PipelineGraph plan; - public void setUp() { - plan = new PipelineGraph(); - Class[] typeArguments = new Class[4]; - typeArguments[0] = Integer.class; - typeArguments[1] = Integer.class; - typeArguments[2] = Integer.class; - typeArguments[3] = Integer.class; + public void setUp() { + plan = new PipelineGraph(); + Class[] typeArguments = new Class[4]; + typeArguments[0] = Integer.class; + typeArguments[1] = Integer.class; + typeArguments[2] = Integer.class; + typeArguments[3] = Integer.class; - plan.addVertex(new PipelineVertex(1, new WindowSourceOperator<>(null, null), 1)); - plan.addVertex(new PipelineVertex(2, new MapOperator<>(null), 1)); + plan.addVertex(new PipelineVertex(1, new WindowSourceOperator<>(null, null), 1)); + plan.addVertex(new PipelineVertex(2, new MapOperator<>(null), 1)); - MapOperator mapOperator = new MapOperator(null); - mapOperator.getOpArgs().setParallelism(7); - plan.addVertex(new PipelineVertex(3, mapOperator, 1)); - plan.addVertex(new PipelineVertex(4, new SinkOperator<>(null), 1)); - plan.addVertex(new PipelineVertex(5, new KeySelectorOperator<>(null), 1)); - KeySelectorOperator keySelectorOperator = new KeySelectorOperator<>(null); - keySelectorOperator.getOpArgs().setParallelism(7); - plan.addVertex(new PipelineVertex(6, keySelectorOperator, 1)); - plan.addVertex(new PipelineVertex(7, new SinkOperator<>(null), 1)); + MapOperator mapOperator = new MapOperator(null); + mapOperator.getOpArgs().setParallelism(7); + plan.addVertex(new PipelineVertex(3, mapOperator, 1)); + plan.addVertex(new PipelineVertex(4, new SinkOperator<>(null), 1)); + plan.addVertex(new PipelineVertex(5, new KeySelectorOperator<>(null), 1)); + KeySelectorOperator keySelectorOperator = new KeySelectorOperator<>(null); + keySelectorOperator.getOpArgs().setParallelism(7); + plan.addVertex(new PipelineVertex(6, keySelectorOperator, 1)); + plan.addVertex(new PipelineVertex(7, new SinkOperator<>(null), 1)); - plan.addEdge(new PipelineEdge(1, 2, 4, new ForwardPartitioner(), null)); - plan.addEdge(new PipelineEdge(2, 1, 2, new ForwardPartitioner(), null)); - plan.addEdge(new PipelineEdge(3, 5, 7, new KeyPartitioner(1), null)); - plan.addEdge(new PipelineEdge(4, 6, 7, new KeyPartitioner(2), null)); - plan.addEdge(new PipelineEdge(5, 2, 5, new ForwardPartitioner(), null)); - plan.addEdge(new PipelineEdge(6, 3, 6, new ForwardPartitioner(), null)); - plan.addEdge(new PipelineEdge(7, 1, 3, new ForwardPartitioner(), null)); - } + plan.addEdge(new PipelineEdge(1, 2, 4, new ForwardPartitioner(), null)); + plan.addEdge(new PipelineEdge(2, 1, 2, new ForwardPartitioner(), null)); + plan.addEdge(new PipelineEdge(3, 5, 7, new KeyPartitioner(1), null)); + plan.addEdge(new PipelineEdge(4, 6, 7, new KeyPartitioner(2), null)); + plan.addEdge(new PipelineEdge(5, 2, 5, new ForwardPartitioner(), null)); + plan.addEdge(new PipelineEdge(6, 3, 6, new ForwardPartitioner(), null)); + plan.addEdge(new PipelineEdge(7, 1, 3, new ForwardPartitioner(), null)); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/plan/PipelinePlanTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/plan/PipelinePlanTest.java index 26ff4f355..661c34a21 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/plan/PipelinePlanTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/plan/PipelinePlanTest.java @@ -23,13 +23,13 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import com.google.common.collect.ImmutableList; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; + import org.apache.geaflow.api.function.internal.CollectionSource; import org.apache.geaflow.api.pdata.PStreamSink; import org.apache.geaflow.api.pdata.PStreamSource; @@ -68,206 +68,236 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -public class PipelinePlanTest extends BasePlanTest { - - @BeforeClass - public void setUp() { - super.setUp(); - } - - @Test - public void getSrcVertexIdTest() { - List ret = plan.getVertexInputVertexIds(7); - Set target = new HashSet<>(); - target.add(6); - target.add(5); - Assert.assertEquals(ret, target); - - ret = plan.getVertexInputVertexIds(4); - target.clear(); - target.add(2); - Assert.assertEquals(ret, target); - } - - @Test - public void testGetVertexInputEdges() { - Map> map = plan.getVertexInputEdges(); - Set target = new HashSet<>(); - target.add(new PipelineEdge(6, 3, 6, new ForwardPartitioner(), null)); - - Assert.assertEquals(map.get(6), target); - } - - @Test - public void testGetVertexOutputEdges() { - Map> map = plan.getVertexOutputEdges(); - Set target = new HashSet<>(); - target.add(new PipelineEdge(2, 1, 2, new ForwardPartitioner(), null)); - target.add(new PipelineEdge(7, 1, 3, new ForwardPartitioner(), null)); - - Assert.assertEquals(map.get(1), target); - } - - @Test - public void testTwoLayerReduce() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); - when(context.getConfig()).thenReturn(configuration); - PStreamSink sink = new WindowStreamSource<>(context, - new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)) - .keyBy(q -> q).reduce((oldValue, newValue) -> (oldValue)).keyBy(u -> u) - .reduce((oldValue, newValue) -> (oldValue)).sink(p -> { - }); - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - - // source->keyBy->reduce->keyBy->reduce->sink - Map vertexMap = pipelineGraph.getVertexMap(); - Assert.assertEquals(vertexMap.size(), 6); - } - - @Test - public void testOneLayerReduce() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); - when(context.getConfig()).thenReturn(configuration); - PStreamSink sink = new WindowStreamSource<>(context, - new CollectionSource(new ArrayList<>()), SizeTumblingWindow.of(2)) - .keyBy(q -> q).reduce((oldValue, newValue) -> (oldValue)).sink(p -> { - }); - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - - // source->keyBy->reduce->sink - Map vertexMap = pipelineGraph.getVertexMap(); - Assert.assertEquals(vertexMap.size(), 4); - } +import com.google.common.collect.ImmutableList; +public class PipelinePlanTest extends BasePlanTest { - @Test - public void testMultiOutput() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); - when(context.getConfig()).thenReturn(configuration); - WindowDataStream> ds1 = new WindowStreamSource<>(context, - new CollectionSource<>(Tuple.of(1L, 3L), Tuple.of(2L, 5L), Tuple.of(3L, 7L), - Tuple.of(1L, 3L), Tuple.of(1L, 7L), Tuple.of(3L, 7L)), + @BeforeClass + public void setUp() { + super.setUp(); + } + + @Test + public void getSrcVertexIdTest() { + List ret = plan.getVertexInputVertexIds(7); + Set target = new HashSet<>(); + target.add(6); + target.add(5); + Assert.assertEquals(ret, target); + + ret = plan.getVertexInputVertexIds(4); + target.clear(); + target.add(2); + Assert.assertEquals(ret, target); + } + + @Test + public void testGetVertexInputEdges() { + Map> map = plan.getVertexInputEdges(); + Set target = new HashSet<>(); + target.add(new PipelineEdge(6, 3, 6, new ForwardPartitioner(), null)); + + Assert.assertEquals(map.get(6), target); + } + + @Test + public void testGetVertexOutputEdges() { + Map> map = plan.getVertexOutputEdges(); + Set target = new HashSet<>(); + target.add(new PipelineEdge(2, 1, 2, new ForwardPartitioner(), null)); + target.add(new PipelineEdge(7, 1, 3, new ForwardPartitioner(), null)); + + Assert.assertEquals(map.get(1), target); + } + + @Test + public void testTwoLayerReduce() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); + when(context.getConfig()).thenReturn(configuration); + PStreamSink sink = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)) + .keyBy(q -> q) + .reduce((oldValue, newValue) -> (oldValue)) + .keyBy(u -> u) + .reduce((oldValue, newValue) -> (oldValue)) + .sink(p -> {}); + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + + // source->keyBy->reduce->keyBy->reduce->sink + Map vertexMap = pipelineGraph.getVertexMap(); + Assert.assertEquals(vertexMap.size(), 6); + } + + @Test + public void testOneLayerReduce() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); + when(context.getConfig()).thenReturn(configuration); + PStreamSink sink = + new WindowStreamSource<>( + context, new CollectionSource(new ArrayList<>()), SizeTumblingWindow.of(2)) + .keyBy(q -> q) + .reduce((oldValue, newValue) -> (oldValue)) + .sink(p -> {}); + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + + // source->keyBy->reduce->sink + Map vertexMap = pipelineGraph.getVertexMap(); + Assert.assertEquals(vertexMap.size(), 4); + } + + @Test + public void testMultiOutput() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); + when(context.getConfig()).thenReturn(configuration); + WindowDataStream> ds1 = + new WindowStreamSource<>( + context, + new CollectionSource<>( + Tuple.of(1L, 3L), + Tuple.of(2L, 5L), + Tuple.of(3L, 7L), + Tuple.of(1L, 3L), + Tuple.of(1L, 7L), + Tuple.of(3L, 7L)), SizeTumblingWindow.of(2)); - WindowDataStream> ds2 = new WindowStreamSource<>(context, - new CollectionSource<>(Tuple.of(1L, 3L), Tuple.of(2L, 5L), Tuple.of(3L, 7L), - Tuple.of(1L, 3L), Tuple.of(1L, 7L), Tuple.of(3L, 7L)), SizeTumblingWindow.of(2)); - - PStream> ds = ds2.map(v -> v).withParallelism(3).filter(v -> v.f0 > 1L); - PStreamSink sink1 = ds1.sink(p -> { - }); - PStreamSink sink2 = ds.keyBy(p -> p).reduce((v1, v2) -> v1).sink(p -> { - }); - when(context.getActions()).thenReturn(ImmutableList.of(sink1, sink2)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - - // ds1_source->print; - // ds2_source->map->filter keyby->reduce->print - Map vertexMap = pipelineGraph.getVertexMap(); - Assert.assertEquals(vertexMap.size(), 8); - } + WindowDataStream> ds2 = + new WindowStreamSource<>( + context, + new CollectionSource<>( + Tuple.of(1L, 3L), + Tuple.of(2L, 5L), + Tuple.of(3L, 7L), + Tuple.of(1L, 3L), + Tuple.of(1L, 7L), + Tuple.of(3L, 7L)), + SizeTumblingWindow.of(2)); - public void testMaterialize(IViewDesc.BackendType backendType) { - Configuration configuration = new Configuration(); - configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); - AbstractPipelineContext context = new AbstractPipelineContext(configuration) { - @Override - public int generateId() { - return idGenerator.incrementAndGet(); - } + PStream> ds = ds2.map(v -> v).withParallelism(3).filter(v -> v.f0 > 1L); + PStreamSink sink1 = ds1.sink(p -> {}); + PStreamSink sink2 = ds.keyBy(p -> p).reduce((v1, v2) -> v1).sink(p -> {}); + when(context.getActions()).thenReturn(ImmutableList.of(sink1, sink2)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + + // ds1_source->print; + // ds2_source->map->filter keyby->reduce->print + Map vertexMap = pipelineGraph.getVertexMap(); + Assert.assertEquals(vertexMap.size(), 8); + } + + public void testMaterialize(IViewDesc.BackendType backendType) { + Configuration configuration = new Configuration(); + configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); + AbstractPipelineContext context = + new AbstractPipelineContext(configuration) { + @Override + public int generateId() { + return idGenerator.incrementAndGet(); + } }; - PWindowSource source1 = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); - PWindowStream> v = source1.map(i -> new ValueVertex<>(i, (double) i)); - PWindowStream> e = source1.map(i -> new ValueEdge<>(i, i, i)); - - PStreamSource> vertices = - new WindowStreamSource<>(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); - PStreamSource> edges = - new WindowStreamSource(context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); - - final String graphName = "graph_view_name"; - GraphViewDesc graphViewDesc = GraphViewBuilder.createGraphView(graphName) + PWindowSource source1 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(10)); + PWindowStream> v = source1.map(i -> new ValueVertex<>(i, (double) i)); + PWindowStream> e = source1.map(i -> new ValueEdge<>(i, i, i)); + + PStreamSource> vertices = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); + PStreamSource> edges = + new WindowStreamSource( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); + + final String graphName = "graph_view_name"; + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(graphName) .withShardNum(4) .withBackend(backendType) - .withSchema(new GraphMetaType(IntegerType.INSTANCE, ValueVertex.class, - Integer.class, ValueEdge.class, IntegerType.class)) + .withSchema( + new GraphMetaType( + IntegerType.INSTANCE, + ValueVertex.class, + Integer.class, + ValueEdge.class, + IntegerType.class)) .build(); - PGraphView fundGraphView = - new IncGraphView<>(context, graphViewDesc); - PIncGraphView incGraphView = - fundGraphView.appendGraph( - vertices.window(WindowFactory.createSizeTumblingWindow(1)), - edges.window(WindowFactory.createSizeTumblingWindow(1))); - incGraphView.materialize(); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - - Map vertexMap = pipelineGraph.getVertexMap(); - if (backendType == BackendType.Paimon) { - Assert.assertEquals(vertexMap.size(), 4); - } else { - Assert.assertEquals(vertexMap.size(), 3); - } - } - - @Test - public void testRocksDBMaterialize() { - testMaterialize(BackendType.RocksDB); + PGraphView fundGraphView = + new IncGraphView<>(context, graphViewDesc); + PIncGraphView incGraphView = + fundGraphView.appendGraph( + vertices.window(WindowFactory.createSizeTumblingWindow(1)), + edges.window(WindowFactory.createSizeTumblingWindow(1))); + incGraphView.materialize(); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + + Map vertexMap = pipelineGraph.getVertexMap(); + if (backendType == BackendType.Paimon) { + Assert.assertEquals(vertexMap.size(), 4); + } else { + Assert.assertEquals(vertexMap.size(), 3); } - - @Test - public void testPaimonMaterialize() { - testMaterialize(BackendType.Paimon); - } - - @Test - public void testAllWindowCheckpointDuration() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); - when(context.getConfig()).thenReturn(configuration); - WindowStreamSource source = new WindowStreamSource(context, - new CollectionSource<>(ImmutableList.of(1, 2, 3)), AllWindow.getInstance()); - PStreamSink sink = source + } + + @Test + public void testRocksDBMaterialize() { + testMaterialize(BackendType.RocksDB); + } + + @Test + public void testPaimonMaterialize() { + testMaterialize(BackendType.Paimon); + } + + @Test + public void testAllWindowCheckpointDuration() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); + when(context.getConfig()).thenReturn(configuration); + WindowStreamSource source = + new WindowStreamSource( + context, new CollectionSource<>(ImmutableList.of(1, 2, 3)), AllWindow.getInstance()); + PStreamSink sink = + source .map(e -> Tuple.of(e, 1)) .keyBy(v -> ((Tuple) v).f0) .reduce(new ExecutionGraphBuilderTest.CountFunc()) .withParallelism(1) - .sink(v -> { - }) + .sink(v -> {}) .withParallelism(1); - when(context.getActions()).thenReturn(ImmutableList.of(sink)); + when(context.getActions()).thenReturn(ImmutableList.of(sink)); - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - Assert.assertEquals(context.getConfig().getLong(BATCH_NUMBER_PER_CHECKPOINT), 1); - } + Assert.assertEquals(context.getConfig().getLong(BATCH_NUMBER_PER_CHECKPOINT), 1); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/plan/UnionTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/plan/UnionTest.java index a3ebd984f..898ce21b9 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/plan/UnionTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/plan/UnionTest.java @@ -22,9 +22,9 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import com.google.common.collect.ImmutableList; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; + import org.apache.geaflow.api.function.base.FilterFunction; import org.apache.geaflow.api.function.internal.CollectionSource; import org.apache.geaflow.api.pdata.PStreamSink; @@ -40,138 +40,215 @@ import org.testng.Assert; import org.testng.annotations.Test; -public class UnionTest { +import com.google.common.collect.ImmutableList; - @Test - public void testUnionPlan() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); - when(context.getConfig()).thenReturn(configuration); - - WindowStreamSource> ds1 = new WindowStreamSource<>(context, - new CollectionSource<>(Tuple.of(1L, 3L), Tuple.of(2L, 5L), Tuple.of(3L, 7L), - Tuple.of(1L, 3L), Tuple.of(1L, 7L), Tuple.of(3L, 7L)), SizeTumblingWindow.of(2)); - - WindowStreamSource> ds2 = new WindowStreamSource<>(context, - new CollectionSource<>(Tuple.of(1L, 3L), Tuple.of(2L, 5L), Tuple.of(3L, 7L), - Tuple.of(1L, 3L), Tuple.of(1L, 7L), Tuple.of(3L, 7L)), SizeTumblingWindow.of(2)); - - PStreamSink sink = ds1.union(ds2).keyBy(p -> p).filter((FilterFunction>) record -> true).sink(p -> { - }); - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - //ds1_source->union->ds2_source->filter->print; - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - Map vertexMap = pipelineGraph.getVertexMap(); - Assert.assertEquals(vertexMap.size(), 6); - - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - Assert.assertEquals(vertexMap.size(), 4); - } - - @Test - public void testMultiUnionPlan() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); - when(context.getConfig()).thenReturn(configuration); - - WindowStreamSource> ds1 = new WindowStreamSource<>(context, - new CollectionSource<>(Tuple.of(1L, 3L), Tuple.of(2L, 5L), Tuple.of(3L, 7L), - Tuple.of(1L, 3L), Tuple.of(1L, 7L), Tuple.of(3L, 7L)), SizeTumblingWindow.of(2)); - - WindowStreamSource> ds2 = new WindowStreamSource<>(context, - new CollectionSource<>(Tuple.of(1L, 3L), Tuple.of(2L, 5L), Tuple.of(3L, 7L), - Tuple.of(1L, 3L), Tuple.of(1L, 7L), Tuple.of(3L, 7L)), SizeTumblingWindow.of(2)); - - WindowStreamSource> ds3 = new WindowStreamSource<>(context, - new CollectionSource<>(Tuple.of(1L, 3L), Tuple.of(2L, 5L), Tuple.of(3L, 7L), - Tuple.of(1L, 3L), Tuple.of(1L, 7L), Tuple.of(3L, 7L)), SizeTumblingWindow.of(2)); - - PStreamSink sink = ds1.union(ds2).union(ds3).keyBy(p -> p) - .filter((FilterFunction>) record -> true).sink(p -> { - }); - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - //ds1_source->union->ds2_source->ds3_source->filter->print; - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - Map vertexMap = pipelineGraph.getVertexMap(); - Assert.assertEquals(vertexMap.size(), 7); - - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - Assert.assertEquals(vertexMap.size(), 5); - } - - @Test - public void testUnionWithKeyByPlan() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - when(context.getConfig()).thenReturn(configuration); - - WindowStreamSource> ds1 = new WindowStreamSource<>(context, - new CollectionSource<>(Tuple.of(1L, 3L), Tuple.of(2L, 5L), Tuple.of(3L, 7L), - Tuple.of(1L, 3L), Tuple.of(1L, 7L), Tuple.of(3L, 7L)), SizeTumblingWindow.of(2)); - - WindowStreamSource> ds2 = new WindowStreamSource<>(context, - new CollectionSource<>(Tuple.of(1L, 3L), Tuple.of(2L, 5L), Tuple.of(3L, 7L), - Tuple.of(1L, 3L), Tuple.of(1L, 7L), Tuple.of(3L, 7L)), SizeTumblingWindow.of(2)); - - PStreamSink sink = ds1.union(ds2).keyBy(p -> p).reduce((v1, v2) -> v2).sink(p -> { - }); - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - //ds1_source->union->ds2_source->keyBy->reduce->print; - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - Map vertexMap = pipelineGraph.getVertexMap(); - Assert.assertEquals(vertexMap.size(), 6); - - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - Assert.assertEquals(vertexMap.size(), 4); - } - - @Test - public void testWindowUnionWithKeyByPlan() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); - when(context.getConfig()).thenReturn(configuration); - - WindowStreamSource> ds1 = new WindowStreamSource<>(context, - new CollectionSource<>(Tuple.of(1L, 3L), Tuple.of(2L, 5L), Tuple.of(3L, 7L), - Tuple.of(1L, 3L), Tuple.of(1L, 7L), Tuple.of(3L, 7L)), SizeTumblingWindow.of(2)); - - WindowStreamSource> ds2 = new WindowStreamSource<>(context, - new CollectionSource<>(Tuple.of(1L, 3L), Tuple.of(2L, 5L), Tuple.of(3L, 7L), - Tuple.of(1L, 3L), Tuple.of(1L, 7L), Tuple.of(3L, 7L)), SizeTumblingWindow.of(2)); - - PStreamSink sink = ds1.union(ds2).keyBy(p -> p).reduce((v1, v2) -> v2).sink(p -> { - }); - when(context.getActions()).thenReturn(ImmutableList.of(sink)); - - //ds1_source->union->ds2_source->keyBy->reduce->print; - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - Map vertexMap = pipelineGraph.getVertexMap(); - Assert.assertEquals(vertexMap.size(), 6); - - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - Assert.assertEquals(vertexMap.size(), 4); - } +public class UnionTest { + @Test + public void testUnionPlan() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); + when(context.getConfig()).thenReturn(configuration); + + WindowStreamSource> ds1 = + new WindowStreamSource<>( + context, + new CollectionSource<>( + Tuple.of(1L, 3L), + Tuple.of(2L, 5L), + Tuple.of(3L, 7L), + Tuple.of(1L, 3L), + Tuple.of(1L, 7L), + Tuple.of(3L, 7L)), + SizeTumblingWindow.of(2)); + + WindowStreamSource> ds2 = + new WindowStreamSource<>( + context, + new CollectionSource<>( + Tuple.of(1L, 3L), + Tuple.of(2L, 5L), + Tuple.of(3L, 7L), + Tuple.of(1L, 3L), + Tuple.of(1L, 7L), + Tuple.of(3L, 7L)), + SizeTumblingWindow.of(2)); + + PStreamSink sink = + ds1.union(ds2) + .keyBy(p -> p) + .filter((FilterFunction>) record -> true) + .sink(p -> {}); + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + // ds1_source->union->ds2_source->filter->print; + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + Map vertexMap = pipelineGraph.getVertexMap(); + Assert.assertEquals(vertexMap.size(), 6); + + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + Assert.assertEquals(vertexMap.size(), 4); + } + + @Test + public void testMultiUnionPlan() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); + when(context.getConfig()).thenReturn(configuration); + + WindowStreamSource> ds1 = + new WindowStreamSource<>( + context, + new CollectionSource<>( + Tuple.of(1L, 3L), + Tuple.of(2L, 5L), + Tuple.of(3L, 7L), + Tuple.of(1L, 3L), + Tuple.of(1L, 7L), + Tuple.of(3L, 7L)), + SizeTumblingWindow.of(2)); + + WindowStreamSource> ds2 = + new WindowStreamSource<>( + context, + new CollectionSource<>( + Tuple.of(1L, 3L), + Tuple.of(2L, 5L), + Tuple.of(3L, 7L), + Tuple.of(1L, 3L), + Tuple.of(1L, 7L), + Tuple.of(3L, 7L)), + SizeTumblingWindow.of(2)); + + WindowStreamSource> ds3 = + new WindowStreamSource<>( + context, + new CollectionSource<>( + Tuple.of(1L, 3L), + Tuple.of(2L, 5L), + Tuple.of(3L, 7L), + Tuple.of(1L, 3L), + Tuple.of(1L, 7L), + Tuple.of(3L, 7L)), + SizeTumblingWindow.of(2)); + + PStreamSink sink = + ds1.union(ds2) + .union(ds3) + .keyBy(p -> p) + .filter((FilterFunction>) record -> true) + .sink(p -> {}); + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + // ds1_source->union->ds2_source->ds3_source->filter->print; + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + Map vertexMap = pipelineGraph.getVertexMap(); + Assert.assertEquals(vertexMap.size(), 7); + + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + Assert.assertEquals(vertexMap.size(), 5); + } + + @Test + public void testUnionWithKeyByPlan() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + when(context.getConfig()).thenReturn(configuration); + + WindowStreamSource> ds1 = + new WindowStreamSource<>( + context, + new CollectionSource<>( + Tuple.of(1L, 3L), + Tuple.of(2L, 5L), + Tuple.of(3L, 7L), + Tuple.of(1L, 3L), + Tuple.of(1L, 7L), + Tuple.of(3L, 7L)), + SizeTumblingWindow.of(2)); + + WindowStreamSource> ds2 = + new WindowStreamSource<>( + context, + new CollectionSource<>( + Tuple.of(1L, 3L), + Tuple.of(2L, 5L), + Tuple.of(3L, 7L), + Tuple.of(1L, 3L), + Tuple.of(1L, 7L), + Tuple.of(3L, 7L)), + SizeTumblingWindow.of(2)); + + PStreamSink sink = ds1.union(ds2).keyBy(p -> p).reduce((v1, v2) -> v2).sink(p -> {}); + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + // ds1_source->union->ds2_source->keyBy->reduce->print; + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + Map vertexMap = pipelineGraph.getVertexMap(); + Assert.assertEquals(vertexMap.size(), 6); + + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + Assert.assertEquals(vertexMap.size(), 4); + } + + @Test + public void testWindowUnionWithKeyByPlan() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); + when(context.getConfig()).thenReturn(configuration); + + WindowStreamSource> ds1 = + new WindowStreamSource<>( + context, + new CollectionSource<>( + Tuple.of(1L, 3L), + Tuple.of(2L, 5L), + Tuple.of(3L, 7L), + Tuple.of(1L, 3L), + Tuple.of(1L, 7L), + Tuple.of(3L, 7L)), + SizeTumblingWindow.of(2)); + + WindowStreamSource> ds2 = + new WindowStreamSource<>( + context, + new CollectionSource<>( + Tuple.of(1L, 3L), + Tuple.of(2L, 5L), + Tuple.of(3L, 7L), + Tuple.of(1L, 3L), + Tuple.of(1L, 7L), + Tuple.of(3L, 7L)), + SizeTumblingWindow.of(2)); + + PStreamSink sink = ds1.union(ds2).keyBy(p -> p).reduce((v1, v2) -> v2).sink(p -> {}); + when(context.getActions()).thenReturn(ImmutableList.of(sink)); + + // ds1_source->union->ds2_source->keyBy->reduce->print; + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + Map vertexMap = pipelineGraph.getVertexMap(); + Assert.assertEquals(vertexMap.size(), 6); + + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + Assert.assertEquals(vertexMap.size(), 4); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/plan/graph/PipelineVertexTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/plan/graph/PipelineVertexTest.java index 4dd19c979..fbef8a440 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/plan/graph/PipelineVertexTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/plan/graph/PipelineVertexTest.java @@ -27,32 +27,35 @@ public class PipelineVertexTest { - @Test - public void testVertex() { - PipelineVertex vertex1 = new PipelineVertex(1, new SinkOperator<>(), VertexType.sink, 2); - Assert.assertTrue(vertex1.getType() == VertexType.sink); - Assert.assertFalse(vertex1.isDuplication()); - Assert.assertNull(vertex1.getVertexMode()); - Assert.assertTrue(vertex1.getOperator() instanceof SinkOperator); - vertex1.setOperator(new FilterOperator<>(new FilterFunction() { - @Override - public boolean filter(Object record) { + @Test + public void testVertex() { + PipelineVertex vertex1 = new PipelineVertex(1, new SinkOperator<>(), VertexType.sink, 2); + Assert.assertTrue(vertex1.getType() == VertexType.sink); + Assert.assertFalse(vertex1.isDuplication()); + Assert.assertNull(vertex1.getVertexMode()); + Assert.assertTrue(vertex1.getOperator() instanceof SinkOperator); + vertex1.setOperator( + new FilterOperator<>( + new FilterFunction() { + @Override + public boolean filter(Object record) { return true; - } - })); - Assert.assertTrue(vertex1.getOperator() instanceof FilterOperator); + } + })); + Assert.assertTrue(vertex1.getOperator() instanceof FilterOperator); - PipelineVertex vertex2 = new PipelineVertex(1, VertexType.sink, new SinkOperator<>(), VertexMode.append); - Assert.assertTrue(vertex2.getType() == VertexType.sink); - Assert.assertFalse(vertex2.isDuplication()); - Assert.assertTrue(vertex2.getVertexMode() == VertexMode.append); - Assert.assertTrue(vertex2.getOperator() instanceof SinkOperator); - vertex2.setDuplication(); - Assert.assertTrue(vertex2.isDuplication()); + PipelineVertex vertex2 = + new PipelineVertex(1, VertexType.sink, new SinkOperator<>(), VertexMode.append); + Assert.assertTrue(vertex2.getType() == VertexType.sink); + Assert.assertFalse(vertex2.isDuplication()); + Assert.assertTrue(vertex2.getVertexMode() == VertexMode.append); + Assert.assertTrue(vertex2.getOperator() instanceof SinkOperator); + vertex2.setDuplication(); + Assert.assertTrue(vertex2.isDuplication()); - Assert.assertTrue(vertex1.equals(vertex2)); + Assert.assertTrue(vertex1.equals(vertex2)); - vertex1.setVertexId(2); - Assert.assertFalse(vertex1.equals(vertex2)); - } + vertex1.setVertexId(2); + Assert.assertFalse(vertex1.equals(vertex2)); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/plan/optimizer/PlanOptimizerTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/plan/optimizer/PlanOptimizerTest.java index 17723f7f9..787e0863e 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/plan/optimizer/PlanOptimizerTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-plan/src/test/java/org/apache/geaflow/plan/optimizer/PlanOptimizerTest.java @@ -22,9 +22,9 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import com.google.common.collect.ImmutableList; import java.util.ArrayList; import java.util.concurrent.atomic.AtomicInteger; + import org.apache.geaflow.api.function.internal.CollectionSource; import org.apache.geaflow.api.pdata.PStreamSink; import org.apache.geaflow.api.window.impl.SizeTumblingWindow; @@ -44,115 +44,129 @@ import org.testng.Assert; import org.testng.annotations.Test; -public class PlanOptimizerTest { - - @Test - public void testSingleOutput() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); - when(context.getConfig()).thenReturn(configuration); - PStreamSink stream = new WindowStreamSource<>(context, - new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(3)) - .map(p -> p).withParallelism(2).keyBy(q -> q).reduce((oldValue, newValue) -> (oldValue)) - .withParallelism(3).sink(p -> { - }).withParallelism(3); - - when(context.getActions()).thenReturn(ImmutableList.of(stream)); - - Assert.assertEquals(1, context.getActions().size()); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - - Assert.assertEquals(4, pipelineGraph.getPipelineEdgeList().size()); - Assert.assertEquals(5, pipelineGraph.getPipelineVertices().size()); - - ChainCombiner combiner = new ChainCombiner(); - combiner.combineVertex(pipelineGraph); - - Assert.assertEquals(2, pipelineGraph.getPipelineEdgeList().size()); - Assert.assertEquals(3, pipelineGraph.getPipelineVertices().size()); - - Assert.assertEquals(1, pipelineGraph.getVertexOutEdges(1).size()); - Assert.assertEquals(2, - pipelineGraph.getVertexOutEdges(1).stream().findFirst().get().getTargetId()); - - Assert.assertEquals(1, pipelineGraph.getVertexOutEdges(2).size()); - Assert.assertEquals(4, - pipelineGraph.getVertexOutEdges(2).stream().findFirst().get().getTargetId()); - } - - @Test - public void testMultiOutput() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - Configuration configuration = new Configuration(); - configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); - when(context.getConfig()).thenReturn(configuration); - WindowStreamSource ds1 = new WindowStreamSource<>(context, - new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); - WindowStreamSource ds2 = new WindowStreamSource<>(context, - new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); - - PStreamSink sink1 = ds1.map(p -> p).withParallelism(2).keyBy(q -> q).reduce((oldValue, newValue) -> (oldValue)) - .withParallelism(3).sink(p -> { - }).withParallelism(3); - PStreamSink sink2 = ds2.sink(v -> { - }); - - when(context.getActions()).thenReturn(ImmutableList.of(sink1, sink2)); - - Assert.assertEquals(2, context.getActions().size()); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - - Assert.assertEquals(5, pipelineGraph.getPipelineEdgeList().size()); - Assert.assertEquals(7, pipelineGraph.getPipelineVertices().size()); - - ChainCombiner combiner = new ChainCombiner(); - combiner.combineVertex(pipelineGraph); +import com.google.common.collect.ImmutableList; - Assert.assertEquals(2, pipelineGraph.getPipelineEdgeList().size()); - Assert.assertEquals(4, pipelineGraph.getPipelineVertices().size()); +public class PlanOptimizerTest { - Assert.assertEquals(1, pipelineGraph.getVertexOutEdges(1).size()); - Assert.assertEquals(3, - pipelineGraph.getVertexOutEdges(1).stream().findFirst().get().getTargetId()); + @Test + public void testSingleOutput() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); + when(context.getConfig()).thenReturn(configuration); + PStreamSink stream = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(3)) + .map(p -> p) + .withParallelism(2) + .keyBy(q -> q) + .reduce((oldValue, newValue) -> (oldValue)) + .withParallelism(3) + .sink(p -> {}) + .withParallelism(3); + + when(context.getActions()).thenReturn(ImmutableList.of(stream)); + + Assert.assertEquals(1, context.getActions().size()); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + + Assert.assertEquals(4, pipelineGraph.getPipelineEdgeList().size()); + Assert.assertEquals(5, pipelineGraph.getPipelineVertices().size()); + + ChainCombiner combiner = new ChainCombiner(); + combiner.combineVertex(pipelineGraph); + + Assert.assertEquals(2, pipelineGraph.getPipelineEdgeList().size()); + Assert.assertEquals(3, pipelineGraph.getPipelineVertices().size()); + + Assert.assertEquals(1, pipelineGraph.getVertexOutEdges(1).size()); + Assert.assertEquals( + 2, pipelineGraph.getVertexOutEdges(1).stream().findFirst().get().getTargetId()); + + Assert.assertEquals(1, pipelineGraph.getVertexOutEdges(2).size()); + Assert.assertEquals( + 4, pipelineGraph.getVertexOutEdges(2).stream().findFirst().get().getTargetId()); + } + + @Test + public void testMultiOutput() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + Configuration configuration = new Configuration(); + configuration.put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE, Boolean.TRUE.toString()); + when(context.getConfig()).thenReturn(configuration); + WindowStreamSource ds1 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); + WindowStreamSource ds2 = + new WindowStreamSource<>( + context, new CollectionSource<>(new ArrayList<>()), SizeTumblingWindow.of(2)); + + PStreamSink sink1 = + ds1.map(p -> p) + .withParallelism(2) + .keyBy(q -> q) + .reduce((oldValue, newValue) -> (oldValue)) + .withParallelism(3) + .sink(p -> {}) + .withParallelism(3); + PStreamSink sink2 = ds2.sink(v -> {}); + + when(context.getActions()).thenReturn(ImmutableList.of(sink1, sink2)); + + Assert.assertEquals(2, context.getActions().size()); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + + Assert.assertEquals(5, pipelineGraph.getPipelineEdgeList().size()); + Assert.assertEquals(7, pipelineGraph.getPipelineVertices().size()); + + ChainCombiner combiner = new ChainCombiner(); + combiner.combineVertex(pipelineGraph); + + Assert.assertEquals(2, pipelineGraph.getPipelineEdgeList().size()); + Assert.assertEquals(4, pipelineGraph.getPipelineVertices().size()); + + Assert.assertEquals(1, pipelineGraph.getVertexOutEdges(1).size()); + Assert.assertEquals( + 3, pipelineGraph.getVertexOutEdges(1).stream().findFirst().get().getTargetId()); + } + + @Test + public void testOperatorChain() { + AtomicInteger idGenerator = new AtomicInteger(0); + AbstractPipelineContext context = mock(AbstractPipelineContext.class); + when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); + + WindowStreamSource source = + new WindowStreamSource( + context, new CollectionSource<>(ImmutableList.of(1, 2, 3)), SizeTumblingWindow.of(2)); + WindowDataStream filter1 = new WindowDataStream(source, new MapOperator(x -> x)); + WindowDataStream filter2 = new WindowDataStream(filter1, new MapOperator(x -> x)); + WindowDataStream mapper1 = new WindowDataStream(filter2, new MapOperator(x -> x)); + WindowDataStream mapper2 = new WindowDataStream(mapper1, new MapOperator(x -> x)); + PStreamSink sink1 = new WindowStreamSink(mapper2, new SinkOperator(x -> {})); + + when(context.getActions()).thenReturn(ImmutableList.of(sink1)); + + PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = planBuilder.buildPlan(context); + Assert.assertNotNull(pipelineGraph); + for (PipelineVertex vertex : pipelineGraph.getPipelineVertices()) { + Assert.assertTrue(((AbstractOperator) vertex.getOperator()).getNextOperators().isEmpty()); } - @Test - public void testOperatorChain() { - AtomicInteger idGenerator = new AtomicInteger(0); - AbstractPipelineContext context = mock(AbstractPipelineContext.class); - when(context.generateId()).then(invocation -> idGenerator.incrementAndGet()); - - WindowStreamSource source = new WindowStreamSource(context, - new CollectionSource<>(ImmutableList.of(1, 2, 3)), SizeTumblingWindow.of(2)); - WindowDataStream filter1 = new WindowDataStream(source, new MapOperator(x -> x)); - WindowDataStream filter2 = new WindowDataStream(filter1, new MapOperator(x -> x)); - WindowDataStream mapper1 = new WindowDataStream(filter2, new MapOperator(x -> x)); - WindowDataStream mapper2 = new WindowDataStream(mapper1, new MapOperator(x -> x)); - PStreamSink sink1 = new WindowStreamSink(mapper2, new SinkOperator(x -> { - })); - - when(context.getActions()).thenReturn(ImmutableList.of(sink1)); - - PipelinePlanBuilder planBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = planBuilder.buildPlan(context); - Assert.assertNotNull(pipelineGraph); - for (PipelineVertex vertex : pipelineGraph.getPipelineVertices()) { - Assert.assertTrue(((AbstractOperator) vertex.getOperator()).getNextOperators().isEmpty()); - } - - PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); - optimizer.optimizePipelineGraph(pipelineGraph); - Assert.assertEquals(pipelineGraph.getVertexMap().size(), 1); - PipelineVertex sourceVertex = pipelineGraph.getVertexMap().get(1); - Assert.assertEquals(((AbstractOperator) sourceVertex.getOperator()).getNextOperators().size(), 1); - } + PipelineGraphOptimizer optimizer = new PipelineGraphOptimizer(); + optimizer.optimizePipelineGraph(pipelineGraph); + Assert.assertEquals(pipelineGraph.getVertexMap().size(), 1); + PipelineVertex sourceVertex = pipelineGraph.getVertexMap().get(1); + Assert.assertEquals( + ((AbstractOperator) sourceVertex.getOperator()).getNextOperators().size(), 1); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/Processor.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/Processor.java index 3f9c9faa9..b09e03426 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/Processor.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/Processor.java @@ -21,38 +21,27 @@ import java.io.Serializable; import java.util.List; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.collector.ICollector; public interface Processor extends Serializable { - /** - * Returns the op id. - */ - int getId(); - - /** - * Operator open. - */ - void open(List collector, RuntimeContext runtimeContext); - - /** - * Initialize operator by windowId. - */ - void init(long windowId); - - /** - * Operator process value t. - */ - R process(T t); - - /** - * Finish processing of windowId. - */ - void finish(long windowId); - - /** - * Operator close. - */ - void close(); + /** Returns the op id. */ + int getId(); + + /** Operator open. */ + void open(List collector, RuntimeContext runtimeContext); + + /** Initialize operator by windowId. */ + void init(long windowId); + + /** Operator process value t. */ + R process(T t); + + /** Finish processing of windowId. */ + void finish(long windowId); + + /** Operator close. */ + void close(); } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/builder/IProcessorBuilder.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/builder/IProcessorBuilder.java index 479c31ffb..7b8620228 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/builder/IProcessorBuilder.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/builder/IProcessorBuilder.java @@ -24,8 +24,6 @@ public interface IProcessorBuilder { - /** - * Build processor based on corresponding operator. - */ - Processor buildProcessor(Operator operator); + /** Build processor based on corresponding operator. */ + Processor buildProcessor(Operator operator); } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/builder/ProcessorBuilder.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/builder/ProcessorBuilder.java index 88dfe39e5..187371678 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/builder/ProcessorBuilder.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/builder/ProcessorBuilder.java @@ -19,7 +19,6 @@ package org.apache.geaflow.processor.builder; -import com.google.common.base.Preconditions; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.operator.OpArgs.OpType; @@ -35,39 +34,41 @@ import org.apache.geaflow.processor.impl.window.OneInputProcessor; import org.apache.geaflow.processor.impl.window.TwoInputProcessor; -public class ProcessorBuilder implements IProcessorBuilder { +import com.google.common.base.Preconditions; - @Override - public Processor buildProcessor(Operator operator) { - Processor processor = null; - OpType type = ((AbstractOperator) operator).getOpArgs().getOpType(); - String msg = String.format("operator %s type is null", operator); - Preconditions.checkArgument(type != null, msg); - switch (type) { - case SINGLE_WINDOW_SOURCE: - case MULTI_WINDOW_SOURCE: - processor = new SourceProcessor((WindowSourceOperator) operator); - break; - case ONE_INPUT: - processor = new OneInputProcessor((OneInputOperator) operator); - break; - case TWO_INPUT: - processor = new TwoInputProcessor((TwoInputOperator) operator); - break; - case GRAPH_SOURCE: - break; - case VERTEX_CENTRIC_COMPUTE: - case VERTEX_CENTRIC_COMPUTE_WITH_AGG: - case VERTEX_CENTRIC_TRAVERSAL: - case INC_VERTEX_CENTRIC_COMPUTE: - case INC_VERTEX_CENTRIC_TRAVERSAL: - processor = new GraphVertexCentricProcessor((AbstractGraphVertexCentricOp) operator); - break; - default: - throw new GeaflowRuntimeException(RuntimeErrors.INST.operatorTypeNotSupportError(type.name())); - } +public class ProcessorBuilder implements IProcessorBuilder { - return processor; + @Override + public Processor buildProcessor(Operator operator) { + Processor processor = null; + OpType type = ((AbstractOperator) operator).getOpArgs().getOpType(); + String msg = String.format("operator %s type is null", operator); + Preconditions.checkArgument(type != null, msg); + switch (type) { + case SINGLE_WINDOW_SOURCE: + case MULTI_WINDOW_SOURCE: + processor = new SourceProcessor((WindowSourceOperator) operator); + break; + case ONE_INPUT: + processor = new OneInputProcessor((OneInputOperator) operator); + break; + case TWO_INPUT: + processor = new TwoInputProcessor((TwoInputOperator) operator); + break; + case GRAPH_SOURCE: + break; + case VERTEX_CENTRIC_COMPUTE: + case VERTEX_CENTRIC_COMPUTE_WITH_AGG: + case VERTEX_CENTRIC_TRAVERSAL: + case INC_VERTEX_CENTRIC_COMPUTE: + case INC_VERTEX_CENTRIC_TRAVERSAL: + processor = new GraphVertexCentricProcessor((AbstractGraphVertexCentricOp) operator); + break; + default: + throw new GeaflowRuntimeException( + RuntimeErrors.INST.operatorTypeNotSupportError(type.name())); } -} + return processor; + } +} diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/AbstractProcessor.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/AbstractProcessor.java index 9e1205b1e..481915445 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/AbstractProcessor.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/AbstractProcessor.java @@ -20,6 +20,7 @@ package org.apache.geaflow.processor.impl; import java.util.List; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.trait.CancellableTrait; import org.apache.geaflow.collector.ICollector; @@ -29,57 +30,57 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class AbstractProcessor implements Processor, CancellableTrait { +public abstract class AbstractProcessor + implements Processor, CancellableTrait { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractProcessor.class); + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractProcessor.class); - protected OP operator; - protected List collectors; - protected RuntimeContext runtimeContext; + protected OP operator; + protected List collectors; + protected RuntimeContext runtimeContext; - public AbstractProcessor(OP operator) { - this.operator = operator; - } + public AbstractProcessor(OP operator) { + this.operator = operator; + } - public OP getOperator() { - return operator; - } + public OP getOperator() { + return operator; + } - @Override - public int getId() { - return ((AbstractOperator) operator).getOpArgs().getOpId(); - } + @Override + public int getId() { + return ((AbstractOperator) operator).getOpArgs().getOpId(); + } - @Override - public void open(List collectors, RuntimeContext runtimeContext) { - this.collectors = collectors; - this.runtimeContext = runtimeContext; - this.operator.open(new AbstractOperator.DefaultOpContext(collectors, runtimeContext)); - } + @Override + public void open(List collectors, RuntimeContext runtimeContext) { + this.collectors = collectors; + this.runtimeContext = runtimeContext; + this.operator.open(new AbstractOperator.DefaultOpContext(collectors, runtimeContext)); + } - @Override - public void init(long windowId) { - } + @Override + public void init(long windowId) {} - @Override - public void finish(long windowId) { - operator.finish(); - } + @Override + public void finish(long windowId) { + operator.finish(); + } - @Override - public void close() { - this.operator.close(); - } + @Override + public void close() { + this.operator.close(); + } - @Override - public void cancel() { - if (this.operator instanceof CancellableTrait) { - ((CancellableTrait) this.operator).cancel(); - } + @Override + public void cancel() { + if (this.operator instanceof CancellableTrait) { + ((CancellableTrait) this.operator).cancel(); } + } - @Override - public String toString() { - return operator.getClass().getSimpleName(); - } + @Override + public String toString() { + return operator.getClass().getSimpleName(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/AbstractStreamProcessor.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/AbstractStreamProcessor.java index e11459af7..c8705ab40 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/AbstractStreamProcessor.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/AbstractStreamProcessor.java @@ -23,6 +23,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.trait.CheckpointTrait; import org.apache.geaflow.api.trait.TransactionTrait; @@ -34,69 +35,71 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class AbstractStreamProcessor extends AbstractProcessor implements TransactionTrait { +public abstract class AbstractStreamProcessor + extends AbstractProcessor implements TransactionTrait { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractStreamProcessor.class); + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractStreamProcessor.class); - protected Object lock = new Object(); - protected List transactionOpList; - protected long checkpointDuration; + protected Object lock = new Object(); + protected List transactionOpList; + protected long checkpointDuration; - public AbstractStreamProcessor(OP operator) { - super(operator); - this.transactionOpList = new ArrayList<>(); - addIfTransactionTrait(operator); - } + public AbstractStreamProcessor(OP operator) { + super(operator); + this.transactionOpList = new ArrayList<>(); + addIfTransactionTrait(operator); + } - @Override - public void open(List collectors, RuntimeContext runtimeContext) { - super.open(collectors, runtimeContext); - this.checkpointDuration = this.runtimeContext.getConfiguration().getLong(BATCH_NUMBER_PER_CHECKPOINT); - } + @Override + public void open(List collectors, RuntimeContext runtimeContext) { + super.open(collectors, runtimeContext); + this.checkpointDuration = + this.runtimeContext.getConfiguration().getLong(BATCH_NUMBER_PER_CHECKPOINT); + } - @Override - public void finish(long windowId) { - synchronized (lock) { - LOGGER.info("{} do finish {}", runtimeContext.getTaskArgs().getTaskId(), windowId); - for (TransactionTrait transactionTrait : this.transactionOpList) { - transactionTrait.finish(windowId); - if (CheckpointUtil.needDoCheckpoint(windowId, this.checkpointDuration) - && transactionTrait instanceof CheckpointTrait) { - ((CheckpointTrait) transactionTrait).checkpoint(windowId); - } - } - super.finish(windowId); + @Override + public void finish(long windowId) { + synchronized (lock) { + LOGGER.info("{} do finish {}", runtimeContext.getTaskArgs().getTaskId(), windowId); + for (TransactionTrait transactionTrait : this.transactionOpList) { + transactionTrait.finish(windowId); + if (CheckpointUtil.needDoCheckpoint(windowId, this.checkpointDuration) + && transactionTrait instanceof CheckpointTrait) { + ((CheckpointTrait) transactionTrait).checkpoint(windowId); } + } + super.finish(windowId); } + } - @Override - public void rollback(long windowId) { - synchronized (lock) { - LOGGER.info("do rollback {}", windowId); - for (TransactionTrait transactionTrait : this.transactionOpList) { - transactionTrait.rollback(windowId); - } - } + @Override + public void rollback(long windowId) { + synchronized (lock) { + LOGGER.info("do rollback {}", windowId); + for (TransactionTrait transactionTrait : this.transactionOpList) { + transactionTrait.rollback(windowId); + } } + } - @Override - public R process(T value) { - synchronized (lock) { - return processElement((BatchRecord) value); - } + @Override + public R process(T value) { + synchronized (lock) { + return processElement((BatchRecord) value); } + } - protected void addIfTransactionTrait(Operator operator) { - if (operator == null) { - return; - } - if (operator instanceof TransactionTrait) { - this.transactionOpList.add((TransactionTrait) operator); - } - for (Object subOperator : ((AbstractOperator) operator).getNextOperators()) { - addIfTransactionTrait((Operator) subOperator); - } + protected void addIfTransactionTrait(Operator operator) { + if (operator == null) { + return; + } + if (operator instanceof TransactionTrait) { + this.transactionOpList.add((TransactionTrait) operator); + } + for (Object subOperator : ((AbstractOperator) operator).getNextOperators()) { + addIfTransactionTrait((Operator) subOperator); } + } - protected abstract R processElement(BatchRecord batchRecord); + protected abstract R processElement(BatchRecord batchRecord); } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/AbstractWindowProcessor.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/AbstractWindowProcessor.java index 47bc8e398..ebd1a08b9 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/AbstractWindowProcessor.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/AbstractWindowProcessor.java @@ -21,17 +21,18 @@ import org.apache.geaflow.operator.Operator; -public abstract class AbstractWindowProcessor extends AbstractProcessor { +public abstract class AbstractWindowProcessor + extends AbstractProcessor { - protected long windowId; + protected long windowId; - public AbstractWindowProcessor(OP operator) { - super(operator); - } + public AbstractWindowProcessor(OP operator) { + super(operator); + } - @Override - public void init(long windowId) { - super.init(windowId); - this.windowId = windowId; - } + @Override + public void init(long windowId) { + super.init(windowId); + this.windowId = windowId; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/graph/GraphVertexCentricProcessor.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/graph/GraphVertexCentricProcessor.java index ad92d620a..563c3fecb 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/graph/GraphVertexCentricProcessor.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/graph/GraphVertexCentricProcessor.java @@ -21,6 +21,7 @@ import java.util.Iterator; import java.util.List; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.collector.ICollector; import org.apache.geaflow.model.graph.edge.IEdge; @@ -41,73 +42,69 @@ public class GraphVertexCentricProcessor extends AbstractWindowProcessor, Void, OP> { - private static final Logger LOGGER = LoggerFactory.getLogger(GraphVertexCentricProcessor.class); + private static final Logger LOGGER = LoggerFactory.getLogger(GraphVertexCentricProcessor.class); - private long iterations; + private long iterations; - public GraphVertexCentricProcessor(OP graphVertexCentricOp) { - super(graphVertexCentricOp); - } + public GraphVertexCentricProcessor(OP graphVertexCentricOp) { + super(graphVertexCentricOp); + } - @Override - public void open(List collectors, RuntimeContext runtimeContext) { - super.open(collectors, runtimeContext); - } + @Override + public void open(List collectors, RuntimeContext runtimeContext) { + super.open(collectors, runtimeContext); + } - @Override - public void init(long batchId) { - this.iterations = batchId; - } + @Override + public void init(long batchId) { + this.iterations = batchId; + } - @Override - public Void process(BatchRecord batchRecord) { - if (batchRecord != null) { - this.operator.initIteration(iterations); - if (this.operator.getMaxIterationCount() >= iterations) { - RecordArgs recordArgs = batchRecord.getRecordArgs(); - GraphRecordNames graphRecordName = GraphRecordNames.valueOf(recordArgs.getName()); - if (graphRecordName == GraphRecordNames.Vertex) { - final Iterator vertexIterator = batchRecord.getMessageIterator(); - while (vertexIterator.hasNext()) { - IVertex vertex = vertexIterator.next(); - this.operator.addVertex(vertex); - } - } else if (graphRecordName == GraphRecordNames.Edge) { - final Iterator edgeIterator = batchRecord.getMessageIterator(); - while (edgeIterator.hasNext()) { - this.operator.addEdge(edgeIterator.next()); - } - } else if (graphRecordName == GraphRecordNames.Message) { - final Iterator graphMessageIterator = - batchRecord.getMessageIterator(); - while (graphMessageIterator.hasNext()) { - this.operator.processMessage(graphMessageIterator.next()); - } - } else if (graphRecordName == GraphRecordNames.Request) { - final Iterator requestIterator = - batchRecord.getMessageIterator(); - while (requestIterator.hasNext()) { - ((IGraphTraversalOp) this.operator).addRequest(requestIterator.next()); - } - } else if (graphRecordName == GraphRecordNames.Aggregate) { - final Iterator requestIterator = - batchRecord.getMessageIterator(); - while (requestIterator.hasNext()) { - ((IGraphAggregateOp) this.operator).processAggregateResult(requestIterator.next()); - } - } - } + @Override + public Void process(BatchRecord batchRecord) { + if (batchRecord != null) { + this.operator.initIteration(iterations); + if (this.operator.getMaxIterationCount() >= iterations) { + RecordArgs recordArgs = batchRecord.getRecordArgs(); + GraphRecordNames graphRecordName = GraphRecordNames.valueOf(recordArgs.getName()); + if (graphRecordName == GraphRecordNames.Vertex) { + final Iterator vertexIterator = batchRecord.getMessageIterator(); + while (vertexIterator.hasNext()) { + IVertex vertex = vertexIterator.next(); + this.operator.addVertex(vertex); + } + } else if (graphRecordName == GraphRecordNames.Edge) { + final Iterator edgeIterator = batchRecord.getMessageIterator(); + while (edgeIterator.hasNext()) { + this.operator.addEdge(edgeIterator.next()); + } + } else if (graphRecordName == GraphRecordNames.Message) { + final Iterator graphMessageIterator = batchRecord.getMessageIterator(); + while (graphMessageIterator.hasNext()) { + this.operator.processMessage(graphMessageIterator.next()); + } + } else if (graphRecordName == GraphRecordNames.Request) { + final Iterator requestIterator = batchRecord.getMessageIterator(); + while (requestIterator.hasNext()) { + ((IGraphTraversalOp) this.operator).addRequest(requestIterator.next()); + } + } else if (graphRecordName == GraphRecordNames.Aggregate) { + final Iterator requestIterator = batchRecord.getMessageIterator(); + while (requestIterator.hasNext()) { + ((IGraphAggregateOp) this.operator).processAggregateResult(requestIterator.next()); + } } - return null; + } } + return null; + } - @Override - public void finish(long batchId) { - if (batchId > 0) { - this.operator.finishIteration(iterations); - } else { - this.operator.finish(); - } + @Override + public void finish(long batchId) { + if (batchId > 0) { + this.operator.finishIteration(iterations); + } else { + this.operator.finish(); } - + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/io/SourceProcessor.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/io/SourceProcessor.java index 55c108400..9f5fc18ae 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/io/SourceProcessor.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/io/SourceProcessor.java @@ -26,28 +26,26 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class SourceProcessor extends AbstractWindowProcessor> { +public class SourceProcessor + extends AbstractWindowProcessor> { - private static final Logger LOGGER = LoggerFactory.getLogger(SourceProcessor.class); + private static final Logger LOGGER = LoggerFactory.getLogger(SourceProcessor.class); - public SourceProcessor(WindowSourceOperator operator) { - super(operator); - } - - - @Override - public Boolean process(Void v) { - try { - return this.operator.emit(this.windowId); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } - } + public SourceProcessor(WindowSourceOperator operator) { + super(operator); + } - @Override - public void finish(long batchId) { - super.finish(batchId); + @Override + public Boolean process(Void v) { + try { + return this.operator.emit(this.windowId); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } + } + @Override + public void finish(long batchId) { + super.finish(batchId); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/window/OneInputProcessor.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/window/OneInputProcessor.java index aee02bc61..ef4d754c1 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/window/OneInputProcessor.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/window/OneInputProcessor.java @@ -20,6 +20,7 @@ package org.apache.geaflow.processor.impl.window; import java.util.Iterator; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.model.record.BatchRecord; import org.apache.geaflow.operator.base.window.OneInputOperator; @@ -27,27 +28,26 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class OneInputProcessor extends AbstractStreamProcessor { +public class OneInputProcessor extends AbstractStreamProcessor { - private static final Logger LOGGER = LoggerFactory.getLogger(OneInputProcessor.class); + private static final Logger LOGGER = LoggerFactory.getLogger(OneInputProcessor.class); - public OneInputProcessor(OneInputOperator operator) { - super(operator); - } + public OneInputProcessor(OneInputOperator operator) { + super(operator); + } - @Override - public Void processElement(BatchRecord batchRecord) { - try { - final Iterator messageIterator = batchRecord.getMessageIterator(); - while (messageIterator.hasNext()) { - T record = messageIterator.next(); - operator.processElement(record); - } - return null; - } catch (Exception e) { - LOGGER.error(e.getMessage(), e); - throw new GeaflowRuntimeException(e); - } + @Override + public Void processElement(BatchRecord batchRecord) { + try { + final Iterator messageIterator = batchRecord.getMessageIterator(); + while (messageIterator.hasNext()) { + T record = messageIterator.next(); + operator.processElement(record); + } + return null; + } catch (Exception e) { + LOGGER.error(e.getMessage(), e); + throw new GeaflowRuntimeException(e); } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/window/TwoInputProcessor.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/window/TwoInputProcessor.java index 7b988f0cb..df5935c94 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/window/TwoInputProcessor.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/main/java/org/apache/geaflow/processor/impl/window/TwoInputProcessor.java @@ -20,6 +20,7 @@ package org.apache.geaflow.processor.impl.window; import java.util.Iterator; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.model.record.BatchRecord; import org.apache.geaflow.operator.base.window.TwoInputOperator; @@ -27,48 +28,47 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class TwoInputProcessor extends AbstractStreamProcessor { - - private static final Logger LOGGER = LoggerFactory.getLogger(TwoInputProcessor.class); +public class TwoInputProcessor extends AbstractStreamProcessor { - private String leftStream; - private String rightStream; + private static final Logger LOGGER = LoggerFactory.getLogger(TwoInputProcessor.class); - public TwoInputProcessor(TwoInputOperator operator) { - super(operator); - } + private String leftStream; + private String rightStream; - @Override - public Void processElement(BatchRecord batchRecord) { - try { - final Iterator messageIterator = batchRecord.getMessageIterator(); - final String streamName = batchRecord.getStreamName(); + public TwoInputProcessor(TwoInputOperator operator) { + super(operator); + } - if (leftStream.equals(streamName)) { - while (messageIterator.hasNext()) { - T record = messageIterator.next(); - operator.processElementOne(record); - } - } else if (rightStream.equals(streamName)) { - while (messageIterator.hasNext()) { - T record = messageIterator.next(); - operator.processElementTwo(record); - } - } + @Override + public Void processElement(BatchRecord batchRecord) { + try { + final Iterator messageIterator = batchRecord.getMessageIterator(); + final String streamName = batchRecord.getStreamName(); - return null; - } catch (Exception e) { - LOGGER.error(e.getMessage(), e); - throw new GeaflowRuntimeException(e); + if (leftStream.equals(streamName)) { + while (messageIterator.hasNext()) { + T record = messageIterator.next(); + operator.processElementOne(record); } - } + } else if (rightStream.equals(streamName)) { + while (messageIterator.hasNext()) { + T record = messageIterator.next(); + operator.processElementTwo(record); + } + } - public void setLeftStream(String leftStream) { - this.leftStream = leftStream; + return null; + } catch (Exception e) { + LOGGER.error(e.getMessage(), e); + throw new GeaflowRuntimeException(e); } + } - public void setRightStream(String rightStream) { - this.rightStream = rightStream; - } + public void setLeftStream(String leftStream) { + this.leftStream = leftStream; + } + + public void setRightStream(String rightStream) { + this.rightStream = rightStream; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/test/java/org/apache/geaflow/processor/impl/window/IncrAggregateProcessorTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/test/java/org/apache/geaflow/processor/impl/window/IncrAggregateProcessorTest.java index 1b9af62aa..b2e82464c 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/test/java/org/apache/geaflow/processor/impl/window/IncrAggregateProcessorTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/test/java/org/apache/geaflow/processor/impl/window/IncrAggregateProcessorTest.java @@ -24,10 +24,10 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import com.google.common.collect.Lists; import java.util.Arrays; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.base.AggregateFunction; import org.apache.geaflow.api.function.base.KeySelector; @@ -46,103 +46,105 @@ import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; +import com.google.common.collect.Lists; + public class IncrAggregateProcessorTest { - private static int batchSize = 10; - - private OneInputProcessor oneInputProcessor; - private IncrAggregateOperator operator; - private MyAgg agg; - private Map batchId2Value; - - @BeforeMethod - public void setup() { - ICollector collector = mock(ICollector.class); - Configuration configuration = new Configuration(); - configuration.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); - RuntimeContext runtimeContext = mock(RuntimeContext.class); - when(runtimeContext.getConfiguration()).thenReturn(configuration); - when(runtimeContext.getTaskArgs()).thenReturn(new TaskArgs(1, 0, "agg", 1, 1024, 0)); - when(runtimeContext.clone(any(Map.class))).thenReturn(runtimeContext); - Configuration config = new Configuration(); - config.put(ExecutionConfigKeys.REPORTER_LIST.getKey(), ""); - MetricGroup metricGroup = MetricGroupRegistry.getInstance(config).getMetricGroup(); - Mockito.doReturn(metricGroup).when(runtimeContext).getMetric(); - Mockito.doReturn(runtimeContext).when(runtimeContext).clone(any()); - this.agg = new MyAgg(); - this.operator = new IncrAggregateOperator(this.agg, new KeySelectorFunc()); - this.oneInputProcessor = new OneInputProcessor(this.operator); - this.oneInputProcessor.open(Lists.newArrayList(collector), runtimeContext); - this.batchId2Value = new HashMap<>(); + private static int batchSize = 10; + + private OneInputProcessor oneInputProcessor; + private IncrAggregateOperator operator; + private MyAgg agg; + private Map batchId2Value; + + @BeforeMethod + public void setup() { + ICollector collector = mock(ICollector.class); + Configuration configuration = new Configuration(); + configuration.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); + RuntimeContext runtimeContext = mock(RuntimeContext.class); + when(runtimeContext.getConfiguration()).thenReturn(configuration); + when(runtimeContext.getTaskArgs()).thenReturn(new TaskArgs(1, 0, "agg", 1, 1024, 0)); + when(runtimeContext.clone(any(Map.class))).thenReturn(runtimeContext); + Configuration config = new Configuration(); + config.put(ExecutionConfigKeys.REPORTER_LIST.getKey(), ""); + MetricGroup metricGroup = MetricGroupRegistry.getInstance(config).getMetricGroup(); + Mockito.doReturn(metricGroup).when(runtimeContext).getMetric(); + Mockito.doReturn(runtimeContext).when(runtimeContext).clone(any()); + this.agg = new MyAgg(); + this.operator = new IncrAggregateOperator(this.agg, new KeySelectorFunc()); + this.oneInputProcessor = new OneInputProcessor(this.operator); + this.oneInputProcessor.open(Lists.newArrayList(collector), runtimeContext); + this.batchId2Value = new HashMap<>(); + } + + @Test + public void testAggProcessor() throws Exception { + long batchId = 1; + long value = 0; + for (int i = 0; i < 1000; i++) { + value += i; + if ((i + 1) % batchSize == 0) { + this.batchId2Value.put(batchId++, value); + } } - @Test - public void testAggProcessor() throws Exception { - long batchId = 1; - long value = 0; - for (int i = 0; i < 1000; i++) { - value += i; - if ((i + 1) % batchSize == 0) { - this.batchId2Value.put(batchId++, value); - } - } - - batchId = 1; - for (int i = 0; i < 1000; i++) { - RecordArgs recordArgs = new RecordArgs(batchId); - this.oneInputProcessor.process(new BatchRecord<>(recordArgs, Arrays.asList(((long) i)).iterator())); - if ((i + 1) % batchSize == 0) { - this.oneInputProcessor.finish(batchId++); - Assert.assertTrue(this.agg.getValue() == this.batchId2Value.get(batchId - 1)); - } - } + batchId = 1; + for (int i = 0; i < 1000; i++) { + RecordArgs recordArgs = new RecordArgs(batchId); + this.oneInputProcessor.process( + new BatchRecord<>(recordArgs, Arrays.asList(((long) i)).iterator())); + if ((i + 1) % batchSize == 0) { + this.oneInputProcessor.finish(batchId++); + Assert.assertTrue(this.agg.getValue() == this.batchId2Value.get(batchId - 1)); + } } + } - public static class KeySelectorFunc implements KeySelector { + public static class KeySelectorFunc implements KeySelector { - @Override - public Long getKey(Long value) { - return value; - } + @Override + public Long getKey(Long value) { + return value; } + } - class MutableLong { - - long value; - } + class MutableLong { - class MyAgg implements AggregateFunction { + long value; + } - private long value; + class MyAgg implements AggregateFunction { - public MyAgg() { - this.value = 0; - } + private long value; - @Override - public MutableLong createAccumulator() { - return new MutableLong(); - } + public MyAgg() { + this.value = 0; + } - @Override - public void add(Long value, MutableLong accumulator) { - accumulator.value += value; - this.value += value; - } + @Override + public MutableLong createAccumulator() { + return new MutableLong(); + } - @Override - public Long getResult(MutableLong accumulator) { - return accumulator.value; - } + @Override + public void add(Long value, MutableLong accumulator) { + accumulator.value += value; + this.value += value; + } - @Override - public MutableLong merge(MutableLong a, MutableLong b) { - return null; - } + @Override + public Long getResult(MutableLong accumulator) { + return accumulator.value; + } - public long getValue() { - return this.value; - } + @Override + public MutableLong merge(MutableLong a, MutableLong b) { + return null; } + public long getValue() { + return this.value; + } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/test/java/org/apache/geaflow/processor/impl/window/SinkProcessorTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/test/java/org/apache/geaflow/processor/impl/window/SinkProcessorTest.java index fd7e0fc1a..19fec12af 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/test/java/org/apache/geaflow/processor/impl/window/SinkProcessorTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-processor/src/test/java/org/apache/geaflow/processor/impl/window/SinkProcessorTest.java @@ -23,12 +23,12 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import com.google.common.collect.Lists; import java.io.Closeable; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.RichFunction; import org.apache.geaflow.api.function.io.SinkFunction; @@ -49,129 +49,126 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -public class SinkProcessorTest { +import com.google.common.collect.Lists; - private OneInputProcessor oneInputProcessor; - private SinkOperator operator; - private TransactionSinkFunction sinkFunction; - private CommonSinkFunction commonSinkFunction; - private Operator.OpContext opContext; - - @BeforeClass - public void setup() { - ICollector collector = mock(ICollector.class); - RuntimeContext runtimeContext = mock(RuntimeContext.class); - when(runtimeContext.getConfiguration()).thenReturn(new Configuration()); - when(runtimeContext.getTaskArgs()).thenReturn(new TaskArgs(1, 0, "sink", 1, 1024, 0)); - when(runtimeContext.clone(any(Map.class))).thenReturn(runtimeContext); - Configuration config = new Configuration(); - config.put(ExecutionConfigKeys.REPORTER_LIST.getKey(), ""); - MetricGroup metricGroup = MetricGroupRegistry.getInstance(config).getMetricGroup(); - Mockito.doReturn(metricGroup).when(runtimeContext).getMetric(); - Mockito.doReturn(runtimeContext).when(runtimeContext).clone(any()); - this.opContext = new AbstractOperator.DefaultOpContext( - Lists.newArrayList(collector), runtimeContext); - } +public class SinkProcessorTest { - @Test - public void testWriteAndFinishWithTransactionSink() throws Exception { - this.sinkFunction = new TransactionSinkFunction(); - this.operator = new SinkOperator(this.sinkFunction); - this.oneInputProcessor = new OneInputProcessor(this.operator); - this.oneInputProcessor.open(this.opContext.getCollectors(), this.opContext.getRuntimeContext()); - - for (int i = 0; i < 103; i++) { - RecordArgs recordArgs = new RecordArgs(1); - this.oneInputProcessor.process(new BatchRecord<>(recordArgs, Arrays.asList(i).iterator())); - } - Assert.assertEquals(this.sinkFunction.getList().size(), 3); - this.oneInputProcessor.finish(1L); - Assert.assertEquals(this.sinkFunction.getList().size(), 0); + private OneInputProcessor oneInputProcessor; + private SinkOperator operator; + private TransactionSinkFunction sinkFunction; + private CommonSinkFunction commonSinkFunction; + private Operator.OpContext opContext; + + @BeforeClass + public void setup() { + ICollector collector = mock(ICollector.class); + RuntimeContext runtimeContext = mock(RuntimeContext.class); + when(runtimeContext.getConfiguration()).thenReturn(new Configuration()); + when(runtimeContext.getTaskArgs()).thenReturn(new TaskArgs(1, 0, "sink", 1, 1024, 0)); + when(runtimeContext.clone(any(Map.class))).thenReturn(runtimeContext); + Configuration config = new Configuration(); + config.put(ExecutionConfigKeys.REPORTER_LIST.getKey(), ""); + MetricGroup metricGroup = MetricGroupRegistry.getInstance(config).getMetricGroup(); + Mockito.doReturn(metricGroup).when(runtimeContext).getMetric(); + Mockito.doReturn(runtimeContext).when(runtimeContext).clone(any()); + this.opContext = + new AbstractOperator.DefaultOpContext(Lists.newArrayList(collector), runtimeContext); + } + + @Test + public void testWriteAndFinishWithTransactionSink() throws Exception { + this.sinkFunction = new TransactionSinkFunction(); + this.operator = new SinkOperator(this.sinkFunction); + this.oneInputProcessor = new OneInputProcessor(this.operator); + this.oneInputProcessor.open(this.opContext.getCollectors(), this.opContext.getRuntimeContext()); + + for (int i = 0; i < 103; i++) { + RecordArgs recordArgs = new RecordArgs(1); + this.oneInputProcessor.process(new BatchRecord<>(recordArgs, Arrays.asList(i).iterator())); } - - @Test - public void testWriteAndFinishWithCommonSink() throws Exception { - this.commonSinkFunction = new CommonSinkFunction(); - this.operator = new SinkOperator(this.commonSinkFunction); - this.oneInputProcessor = new OneInputProcessor(this.operator); - this.oneInputProcessor.open(this.opContext.getCollectors(), this.opContext.getRuntimeContext()); - - RecordArgs recordArgs = new RecordArgs(1); - for (int i = 0; i < 103; i++) { - this.oneInputProcessor.process(new BatchRecord<>(recordArgs, Arrays.asList(i).iterator())); - } - Assert.assertEquals(this.commonSinkFunction.getList().size(), 3); - this.oneInputProcessor.finish(1L); - Assert.assertEquals(this.commonSinkFunction.getList().size(), 3); + Assert.assertEquals(this.sinkFunction.getList().size(), 3); + this.oneInputProcessor.finish(1L); + Assert.assertEquals(this.sinkFunction.getList().size(), 0); + } + + @Test + public void testWriteAndFinishWithCommonSink() throws Exception { + this.commonSinkFunction = new CommonSinkFunction(); + this.operator = new SinkOperator(this.commonSinkFunction); + this.oneInputProcessor = new OneInputProcessor(this.operator); + this.oneInputProcessor.open(this.opContext.getCollectors(), this.opContext.getRuntimeContext()); + + RecordArgs recordArgs = new RecordArgs(1); + for (int i = 0; i < 103; i++) { + this.oneInputProcessor.process(new BatchRecord<>(recordArgs, Arrays.asList(i).iterator())); } + Assert.assertEquals(this.commonSinkFunction.getList().size(), 3); + this.oneInputProcessor.finish(1L); + Assert.assertEquals(this.commonSinkFunction.getList().size(), 3); + } - static class TransactionSinkFunction extends RichFunction implements SinkFunction, Closeable, TransactionTrait { - - private List list; - private int num; + static class TransactionSinkFunction extends RichFunction + implements SinkFunction, Closeable, TransactionTrait { - @Override - public void open(RuntimeContext runtimeContext) { - list = new ArrayList<>(); - num = 1; - } + private List list; + private int num; - @Override - public void close() { - - } + @Override + public void open(RuntimeContext runtimeContext) { + list = new ArrayList<>(); + num = 1; + } - @Override - public void write(Integer value) throws Exception { - list.add(value); - if (num++ % 10 == 0) { - list.clear(); - num = 1; - } - } + @Override + public void close() {} - @Override - public void finish(long windowId) { - list.clear(); - } + @Override + public void write(Integer value) throws Exception { + list.add(value); + if (num++ % 10 == 0) { + list.clear(); + num = 1; + } + } - @Override - public void rollback(long windowId) { + @Override + public void finish(long windowId) { + list.clear(); + } - } + @Override + public void rollback(long windowId) {} - public List getList() { - return list; - } + public List getList() { + return list; } + } - static class CommonSinkFunction extends RichFunction implements SinkFunction, Closeable { + static class CommonSinkFunction extends RichFunction implements SinkFunction, Closeable { - private List list; - private int num; + private List list; + private int num; - @Override - public void open(RuntimeContext runtimeContext) { - list = new ArrayList<>(); - num = 1; - } - - @Override - public void close() { + @Override + public void open(RuntimeContext runtimeContext) { + list = new ArrayList<>(); + num = 1; + } - } + @Override + public void close() {} - @Override - public void write(Integer value) throws Exception { - list.add(value); - if (num++ % 10 == 0) { - list.clear(); - num = 1; - } - } + @Override + public void write(Integer value) throws Exception { + list.add(value); + if (num++ % 10 == 0) { + list.clear(); + num = 1; + } + } - public List getList() { - return list; - } + public List getList() { + return list; } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/context/DefaultRuntimeContext.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/context/DefaultRuntimeContext.java index 80fc676b7..95f41e972 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/context/DefaultRuntimeContext.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/context/DefaultRuntimeContext.java @@ -21,6 +21,7 @@ import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.task.TaskArgs; @@ -30,88 +31,88 @@ public class DefaultRuntimeContext extends AbstractRuntimeContext { - private long pipelineId; - private String pipelineName; - private TaskArgs taskArgs; - protected IoDescriptor ioDescriptor; - - public DefaultRuntimeContext(Configuration jobConfig) { - super(jobConfig); - } - - @Override - public long getPipelineId() { - return pipelineId; - } - - public DefaultRuntimeContext setPipelineId(long pipelineId) { - this.pipelineId = pipelineId; - return this; - } - - @Override - public String getPipelineName() { - return pipelineName; - } - - public DefaultRuntimeContext setPipelineName(String jobName) { - this.pipelineName = jobName; - return this; - } - - @Override - public TaskArgs getTaskArgs() { - return taskArgs; - } - - public DefaultRuntimeContext setTaskArgs(TaskArgs taskArgs) { - this.taskArgs = taskArgs; - return this; - } - - @Override - public Configuration getConfiguration() { - return this.jobConfig; - } - - public DefaultRuntimeContext setIoDescriptor(IoDescriptor ioDescriptor) { - this.ioDescriptor = ioDescriptor; - return this; - } - - public DefaultRuntimeContext setWorkPath(String workPath) { - this.workPath = workPath; - return this; - } - - public DefaultRuntimeContext setMetricGroup(MetricGroup metricGroup) { - this.metricGroup = metricGroup; - return this; - } - - public DefaultRuntimeContext setWindowId(long windowId) { - updateWindowId(windowId); - return this; - } - - @Override - public RuntimeContext clone(Map opConfig) { - Map newConfig = new HashMap<>(); - newConfig.putAll(jobConfig.getConfigMap()); - newConfig.putAll(opConfig); - Configuration configuration = new Configuration(newConfig); - return DefaultRuntimeContext.build(configuration) - .setTaskArgs(taskArgs) - .setPipelineId(pipelineId) - .setPipelineName(pipelineName) - .setMetricGroup(metricGroup) - .setIoDescriptor(ioDescriptor) - .setWorkPath(getWorkPath()) - .setWindowId(windowId); - } - - public static DefaultRuntimeContext build(Configuration configuration) { - DefaultRuntimeContext runtimeContext = new DefaultRuntimeContext(configuration); - return runtimeContext; - } + private long pipelineId; + private String pipelineName; + private TaskArgs taskArgs; + protected IoDescriptor ioDescriptor; + + public DefaultRuntimeContext(Configuration jobConfig) { + super(jobConfig); + } + + @Override + public long getPipelineId() { + return pipelineId; + } + + public DefaultRuntimeContext setPipelineId(long pipelineId) { + this.pipelineId = pipelineId; + return this; + } + + @Override + public String getPipelineName() { + return pipelineName; + } + + public DefaultRuntimeContext setPipelineName(String jobName) { + this.pipelineName = jobName; + return this; + } + + @Override + public TaskArgs getTaskArgs() { + return taskArgs; + } + + public DefaultRuntimeContext setTaskArgs(TaskArgs taskArgs) { + this.taskArgs = taskArgs; + return this; + } + + @Override + public Configuration getConfiguration() { + return this.jobConfig; + } + + public DefaultRuntimeContext setIoDescriptor(IoDescriptor ioDescriptor) { + this.ioDescriptor = ioDescriptor; + return this; + } + + public DefaultRuntimeContext setWorkPath(String workPath) { + this.workPath = workPath; + return this; + } + + public DefaultRuntimeContext setMetricGroup(MetricGroup metricGroup) { + this.metricGroup = metricGroup; + return this; + } + + public DefaultRuntimeContext setWindowId(long windowId) { + updateWindowId(windowId); + return this; + } + + @Override + public RuntimeContext clone(Map opConfig) { + Map newConfig = new HashMap<>(); + newConfig.putAll(jobConfig.getConfigMap()); + newConfig.putAll(opConfig); + Configuration configuration = new Configuration(newConfig); + return DefaultRuntimeContext.build(configuration) + .setTaskArgs(taskArgs) + .setPipelineId(pipelineId) + .setPipelineName(pipelineName) + .setMetricGroup(metricGroup) + .setIoDescriptor(ioDescriptor) + .setWorkPath(getWorkPath()) + .setWindowId(windowId); + } + + public static DefaultRuntimeContext build(Configuration configuration) { + DefaultRuntimeContext runtimeContext = new DefaultRuntimeContext(configuration); + return runtimeContext; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/context/EventContext.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/context/EventContext.java index 3bb165543..1008d498e 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/context/EventContext.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/context/EventContext.java @@ -26,6 +26,92 @@ public class EventContext implements IEventContext { + private long currentWindowId; + private int taskId; + private int cycleId; + private long pipelineId; + private String pipelineName; + private Processor processor; + private ExecutionTask executionTask; + private IoDescriptor ioDescriptor; + private String driverId; + private long windowId; + private long schedulerId; + + private EventContext( + long schedulerId, + long currentWindowId, + int taskId, + int cycleId, + long pipelineId, + String pipelineName, + Processor processor, + ExecutionTask executionTask, + IoDescriptor ioDescriptor, + String driverId, + long windowId) { + this.schedulerId = schedulerId; + this.currentWindowId = currentWindowId; + this.taskId = taskId; + this.cycleId = cycleId; + this.pipelineId = pipelineId; + this.pipelineName = pipelineName; + this.processor = processor; + this.executionTask = executionTask; + this.ioDescriptor = ioDescriptor; + this.driverId = driverId; + this.windowId = windowId; + } + + public long getCurrentWindowId() { + return currentWindowId; + } + + public int getTaskId() { + return taskId; + } + + public int getCycleId() { + return cycleId; + } + + public long getPipelineId() { + return pipelineId; + } + + public String getPipelineName() { + return pipelineName; + } + + public Processor getProcessor() { + return processor; + } + + public ExecutionTask getExecutionTask() { + return executionTask; + } + + public IoDescriptor getIoDescriptor() { + return ioDescriptor; + } + + public String getDriverId() { + return driverId; + } + + public long getWindowId() { + return windowId; + } + + public long getSchedulerId() { + return schedulerId; + } + + public static EventContextBuilder builder() { + return new EventContextBuilder(); + } + + public static class EventContextBuilder { private long currentWindowId; private int taskId; private int cycleId; @@ -38,133 +124,66 @@ public class EventContext implements IEventContext { private long windowId; private long schedulerId; - private EventContext(long schedulerId, long currentWindowId, int taskId, int cycleId, long pipelineId, - String pipelineName, Processor processor, ExecutionTask executionTask, - IoDescriptor ioDescriptor, String driverId, long windowId) { - this.schedulerId = schedulerId; - this.currentWindowId = currentWindowId; - this.taskId = taskId; - this.cycleId = cycleId; - this.pipelineId = pipelineId; - this.pipelineName = pipelineName; - this.processor = processor; - this.executionTask = executionTask; - this.ioDescriptor = ioDescriptor; - this.driverId = driverId; - this.windowId = windowId; + public EventContextBuilder withExecutionTask(ExecutionTask executionTask) { + this.executionTask = executionTask; + this.processor = executionTask.getProcessor(); + this.taskId = executionTask.getTaskId(); + return this; } - public long getCurrentWindowId() { - return currentWindowId; + public EventContextBuilder withCurrentWindowId(long currentWindowId) { + this.currentWindowId = currentWindowId; + return this; } - public int getTaskId() { - return taskId; + public EventContextBuilder withIoDescriptor(IoDescriptor ioDescriptor) { + this.ioDescriptor = ioDescriptor; + return this; } - public int getCycleId() { - return cycleId; + public EventContextBuilder withCycleId(int cycleId) { + this.cycleId = cycleId; + return this; } - public long getPipelineId() { - return pipelineId; + public EventContextBuilder withPipelineId(long pipelineId) { + this.pipelineId = pipelineId; + return this; } - public String getPipelineName() { - return pipelineName; + public EventContextBuilder withPipelineName(String pipelineName) { + this.pipelineName = pipelineName; + return this; } - public Processor getProcessor() { - return processor; + public EventContextBuilder withDriverId(String driverId) { + this.driverId = driverId; + return this; } - public ExecutionTask getExecutionTask() { - return executionTask; + public EventContextBuilder withWindowId(long windowId) { + this.windowId = windowId; + return this; } - public IoDescriptor getIoDescriptor() { - return ioDescriptor; + public EventContextBuilder withSchedulerId(long schedulerId) { + this.schedulerId = schedulerId; + return this; } - public String getDriverId() { - return driverId; - } - - public long getWindowId() { - return windowId; - } - - public long getSchedulerId() { - return schedulerId; - } - - public static EventContextBuilder builder() { - return new EventContextBuilder(); - } - - public static class EventContextBuilder { - private long currentWindowId; - private int taskId; - private int cycleId; - private long pipelineId; - private String pipelineName; - private Processor processor; - private ExecutionTask executionTask; - private IoDescriptor ioDescriptor; - private String driverId; - private long windowId; - private long schedulerId; - - public EventContextBuilder withExecutionTask(ExecutionTask executionTask) { - this.executionTask = executionTask; - this.processor = executionTask.getProcessor(); - this.taskId = executionTask.getTaskId(); - return this; - } - - public EventContextBuilder withCurrentWindowId(long currentWindowId) { - this.currentWindowId = currentWindowId; - return this; - } - - public EventContextBuilder withIoDescriptor(IoDescriptor ioDescriptor) { - this.ioDescriptor = ioDescriptor; - return this; - } - - public EventContextBuilder withCycleId(int cycleId) { - this.cycleId = cycleId; - return this; - } - - public EventContextBuilder withPipelineId(long pipelineId) { - this.pipelineId = pipelineId; - return this; - } - - public EventContextBuilder withPipelineName(String pipelineName) { - this.pipelineName = pipelineName; - return this; - } - - public EventContextBuilder withDriverId(String driverId) { - this.driverId = driverId; - return this; - } - - public EventContextBuilder withWindowId(long windowId) { - this.windowId = windowId; - return this; - } - - public EventContextBuilder withSchedulerId(long schedulerId) { - this.schedulerId = schedulerId; - return this; - } - - public EventContext build() { - return new EventContext(schedulerId, currentWindowId, taskId, cycleId, pipelineId, - pipelineName, processor, executionTask, ioDescriptor, driverId, windowId); - } + public EventContext build() { + return new EventContext( + schedulerId, + currentWindowId, + taskId, + cycleId, + pipelineId, + pipelineName, + processor, + executionTask, + ioDescriptor, + driverId, + windowId); } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/AbstractCleanCommand.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/AbstractCleanCommand.java index 513c321ab..e18d01c57 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/AbstractCleanCommand.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/AbstractCleanCommand.java @@ -21,7 +21,7 @@ public abstract class AbstractCleanCommand extends AbstractExecutableCommand { - public AbstractCleanCommand(long schedulerId, int workerId, int cycleId, long windowId) { - super(schedulerId, workerId, cycleId, windowId); - } + public AbstractCleanCommand(long schedulerId, int workerId, int cycleId, long windowId) { + super(schedulerId, workerId, cycleId, windowId); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/AbstractExecutableCommand.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/AbstractExecutableCommand.java index c175246a6..b7eaeccf6 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/AbstractExecutableCommand.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/AbstractExecutableCommand.java @@ -34,80 +34,82 @@ public abstract class AbstractExecutableCommand implements IExecutableCommand { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractExecutableCommand.class); - - /** - * Scheduler id of the current event. - */ - protected long schedulerId; - - protected int workerId; - - /** - * Cycle id of the current event. - */ - protected int cycleId; - - /** - * window id of cycle. - */ - protected long windowId; - - protected transient IWorker worker; - protected transient IWorkerContext context; - protected transient FetcherRunner fetcherRunner; - protected transient EmitterRunner emitterRunner; - - public AbstractExecutableCommand(long schedulerId, int workerId, int cycleId, long windowId) { - this.schedulerId = schedulerId; - this.workerId = workerId; - this.cycleId = cycleId; - this.windowId = windowId; - } - - - @Override - public void execute(ITaskContext taskContext) { - worker = taskContext.getWorker(); - context = worker.getWorkerContext(); - int workerIndex = taskContext.getWorkerIndex(); - fetcherRunner = taskContext.getFetcherService().getRunner(workerIndex); - emitterRunner = taskContext.getEmitterService().getRunner(workerIndex); - LOGGER.info("task {} process {} batchId {}", - context == null ? null : ((AbstractWorkerContext) context).getTaskId(), this, windowId); - } - - public long getSchedulerId() { - return this.schedulerId; - } - - @Override - public int getWorkerId() { - return this.workerId; - } - - public int getCycleId() { - return cycleId; - } - - public long getIterationWindowId() { - return windowId; - } - - @Override - public void interrupt() { - worker.interrupt(); - } - - /** - * Finish compute and tell scheduler finish. - */ - protected void sendDoneEvent(String driverId, EventType sourceEventType, T result, boolean sendMetrics) { - AbstractWorkerContext workerContext = (AbstractWorkerContext) this.context; - int taskId = workerContext.getTaskId(); - EventMetrics eventMetrics = sendMetrics ? workerContext.getEventMetrics() : null; - DoneEvent doneEvent = new DoneEvent<>(this.schedulerId, this.cycleId, this.windowId, taskId, sourceEventType, result, eventMetrics); - RpcClient.getInstance().processPipeline(driverId, doneEvent); - } - + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractExecutableCommand.class); + + /** Scheduler id of the current event. */ + protected long schedulerId; + + protected int workerId; + + /** Cycle id of the current event. */ + protected int cycleId; + + /** window id of cycle. */ + protected long windowId; + + protected transient IWorker worker; + protected transient IWorkerContext context; + protected transient FetcherRunner fetcherRunner; + protected transient EmitterRunner emitterRunner; + + public AbstractExecutableCommand(long schedulerId, int workerId, int cycleId, long windowId) { + this.schedulerId = schedulerId; + this.workerId = workerId; + this.cycleId = cycleId; + this.windowId = windowId; + } + + @Override + public void execute(ITaskContext taskContext) { + worker = taskContext.getWorker(); + context = worker.getWorkerContext(); + int workerIndex = taskContext.getWorkerIndex(); + fetcherRunner = taskContext.getFetcherService().getRunner(workerIndex); + emitterRunner = taskContext.getEmitterService().getRunner(workerIndex); + LOGGER.info( + "task {} process {} batchId {}", + context == null ? null : ((AbstractWorkerContext) context).getTaskId(), + this, + windowId); + } + + public long getSchedulerId() { + return this.schedulerId; + } + + @Override + public int getWorkerId() { + return this.workerId; + } + + public int getCycleId() { + return cycleId; + } + + public long getIterationWindowId() { + return windowId; + } + + @Override + public void interrupt() { + worker.interrupt(); + } + + /** Finish compute and tell scheduler finish. */ + protected void sendDoneEvent( + String driverId, EventType sourceEventType, T result, boolean sendMetrics) { + AbstractWorkerContext workerContext = (AbstractWorkerContext) this.context; + int taskId = workerContext.getTaskId(); + EventMetrics eventMetrics = sendMetrics ? workerContext.getEventMetrics() : null; + DoneEvent doneEvent = + new DoneEvent<>( + this.schedulerId, + this.cycleId, + this.windowId, + taskId, + sourceEventType, + result, + eventMetrics); + RpcClient.getInstance().processPipeline(driverId, doneEvent); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/AbstractInitCommand.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/AbstractInitCommand.java index 149bf8cc3..6736d1ace 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/AbstractInitCommand.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/AbstractInitCommand.java @@ -24,6 +24,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.cluster.collector.AbstractPipelineCollector; import org.apache.geaflow.cluster.collector.CollectorFactory; @@ -52,46 +53,48 @@ public abstract class AbstractInitCommand extends AbstractExecutableCommand { - protected final long pipelineId; - protected final String pipelineName; - protected final IoDescriptor ioDescriptor; - - public AbstractInitCommand(long schedulerId, - int workerId, - int cycleId, - long windowId, - long pipelineId, - String pipelineName, - IoDescriptor ioDescriptor) { - super(schedulerId, workerId, cycleId, windowId); - this.pipelineId = pipelineId; - this.pipelineName = pipelineName; - this.ioDescriptor = ioDescriptor; + protected final long pipelineId; + protected final String pipelineName; + protected final IoDescriptor ioDescriptor; + + public AbstractInitCommand( + long schedulerId, + int workerId, + int cycleId, + long windowId, + long pipelineId, + String pipelineName, + IoDescriptor ioDescriptor) { + super(schedulerId, workerId, cycleId, windowId); + this.pipelineId = pipelineId; + this.pipelineName = pipelineName; + this.ioDescriptor = ioDescriptor; + } + + protected void initFetcher() { + InputDescriptor inputDescriptor = this.ioDescriptor.getInputDescriptor(); + if (inputDescriptor == null || inputDescriptor.getInputDescMap().isEmpty()) { + return; } - - protected void initFetcher() { - InputDescriptor inputDescriptor = this.ioDescriptor.getInputDescriptor(); - if (inputDescriptor == null || inputDescriptor.getInputDescMap().isEmpty()) { - return; - } - WorkerContext workerContext = (WorkerContext) this.context; - InitFetchRequest request = this.buildInitFetchRequest( + WorkerContext workerContext = (WorkerContext) this.context; + InitFetchRequest request = + this.buildInitFetchRequest( inputDescriptor, workerContext.getExecutionTask(), workerContext.getEventMetrics()); - this.fetcherRunner.add(request); + this.fetcherRunner.add(request); + } + + protected InitFetchRequest buildInitFetchRequest( + InputDescriptor inputDescriptor, ExecutionTask task, EventMetrics eventMetrics) { + Map inputDescMap = new HashMap<>(); + for (Map.Entry> entry : inputDescriptor.getInputDescMap().entrySet()) { + IInputDesc inputDesc = entry.getValue(); + if (inputDesc.getInputType() == InputType.META) { + inputDescMap.put(entry.getKey(), (ShardInputDesc) entry.getValue()); + } } - protected InitFetchRequest buildInitFetchRequest(InputDescriptor inputDescriptor, - ExecutionTask task, - EventMetrics eventMetrics) { - Map inputDescMap = new HashMap<>(); - for (Map.Entry> entry : inputDescriptor.getInputDescMap().entrySet()) { - IInputDesc inputDesc = entry.getValue(); - if (inputDesc.getInputType() == InputType.META) { - inputDescMap.put(entry.getKey(), (ShardInputDesc) entry.getValue()); - } - } - - InitFetchRequest initFetchRequest = new InitFetchRequest( + InitFetchRequest initFetchRequest = + new InitFetchRequest( this.pipelineId, this.pipelineName, task.getVertexId(), @@ -100,98 +103,107 @@ protected InitFetchRequest buildInitFetchRequest(InputDescriptor inputDescriptor task.getParallelism(), task.getTaskName(), inputDescMap); - InputReader inputReader = ((AbstractWorker) this.worker).getInputReader(); - inputReader.setEventMetrics(eventMetrics); - initFetchRequest.addListener(inputReader); - return initFetchRequest; + InputReader inputReader = ((AbstractWorker) this.worker).getInputReader(); + inputReader.setEventMetrics(eventMetrics); + initFetchRequest.addListener(inputReader); + return initFetchRequest; + } + + protected void initEmitter() { + OutputDescriptor outputDescriptor = this.ioDescriptor.getOutputDescriptor(); + if (outputDescriptor == null || outputDescriptor.getOutputDescList().isEmpty()) { + ((WorkerContext) this.context).setCollectors(Collections.emptyList()); + return; } - - protected void initEmitter() { - OutputDescriptor outputDescriptor = this.ioDescriptor.getOutputDescriptor(); - if (outputDescriptor == null || outputDescriptor.getOutputDescList().isEmpty()) { - ((WorkerContext) this.context).setCollectors(Collections.emptyList()); - return; - } - InitEmitterRequest request = this.buildInitEmitterRequest(outputDescriptor); - this.emitterRunner.add(request); - List> collectors = this.buildCollectors(outputDescriptor, request); - ((WorkerContext) this.context).setCollectors(collectors); + InitEmitterRequest request = this.buildInitEmitterRequest(outputDescriptor); + this.emitterRunner.add(request); + List> collectors = this.buildCollectors(outputDescriptor, request); + ((WorkerContext) this.context).setCollectors(collectors); + } + + private InitEmitterRequest buildInitEmitterRequest(OutputDescriptor outputDescriptor) { + List> outputBuffers = + this.getOutputBuffers(outputDescriptor.getOutputDescList()); + RuntimeContext runtimeContext = ((WorkerContext) this.context).getRuntimeContext(); + return new InitEmitterRequest( + runtimeContext.getConfiguration(), + this.windowId, + runtimeContext.getPipelineId(), + runtimeContext.getPipelineName(), + runtimeContext.getTaskArgs(), + outputDescriptor, + outputBuffers); + } + + protected List> buildCollectors( + OutputDescriptor outputDescriptor, InitEmitterRequest request) { + List outputDescList = outputDescriptor.getOutputDescList(); + int outputNum = outputDescList.size(); + List> collectors = new ArrayList<>(outputNum); + List> outputBuffers = request.getOutputBuffers(); + for (int i = 0; i < outputNum; i++) { + IOutputDesc outputDesc = outputDescList.get(i); + IOutputMessageBuffer outputBuffer = outputBuffers.get(i); + ICollector collector = CollectorFactory.create(outputDesc); + if (outputDesc.getType() != OutputType.RESPONSE) { + ((AbstractPipelineCollector) collector).setOutputBuffer(outputBuffer); + } + collectors.add(collector); } + return collectors; + } - private InitEmitterRequest buildInitEmitterRequest(OutputDescriptor outputDescriptor) { - List> outputBuffers = this.getOutputBuffers(outputDescriptor.getOutputDescList()); - RuntimeContext runtimeContext = ((WorkerContext) this.context).getRuntimeContext(); - return new InitEmitterRequest( - runtimeContext.getConfiguration(), - this.windowId, - runtimeContext.getPipelineId(), - runtimeContext.getPipelineName(), - runtimeContext.getTaskArgs(), - outputDescriptor, - outputBuffers); + protected void popEmitter() { + OutputDescriptor outputDescriptor = ioDescriptor.getOutputDescriptor(); + if (outputDescriptor == null) { + return; } - protected List> buildCollectors(OutputDescriptor outputDescriptor, InitEmitterRequest request) { - List outputDescList = outputDescriptor.getOutputDescList(); - int outputNum = outputDescList.size(); - List> collectors = new ArrayList<>(outputNum); - List> outputBuffers = request.getOutputBuffers(); - for (int i = 0; i < outputNum; i++) { - IOutputDesc outputDesc = outputDescList.get(i); - IOutputMessageBuffer outputBuffer = outputBuffers.get(i); - ICollector collector = CollectorFactory.create(outputDesc); - if (outputDesc.getType() != OutputType.RESPONSE) { - ((AbstractPipelineCollector) collector).setOutputBuffer(outputBuffer); - } - collectors.add(collector); - } - return collectors; + WorkerContext workerContext = (WorkerContext) this.context; + List outputDescList = outputDescriptor.getOutputDescList(); + int outputNum = outputDescList.size(); + List> collectors = workerContext.getCollectors(); + if (collectors.size() != outputNum) { + throw new GeaflowRuntimeException( + String.format( + "collector num %d not match output desc num %d", collectors.size(), outputNum)); } - protected void popEmitter() { - OutputDescriptor outputDescriptor = ioDescriptor.getOutputDescriptor(); - if (outputDescriptor == null) { - return; - } - - WorkerContext workerContext = (WorkerContext) this.context; - List outputDescList = outputDescriptor.getOutputDescList(); - int outputNum = outputDescList.size(); - List> collectors = workerContext.getCollectors(); - if (collectors.size() != outputNum) { - throw new GeaflowRuntimeException(String.format("collector num %d not match output desc num %d", collectors.size(), outputNum)); - } - - List> outputBuffers = this.getOutputBuffers(outputDescList); - for (int i = 0; i < outputNum; i++) { - if (collectors.get(i) instanceof AbstractPipelineCollector) { - AbstractPipelineCollector collector = (AbstractPipelineCollector) collectors.get(i); - IOutputMessageBuffer outputBuffer = outputBuffers.get(i); - collector.setOutputBuffer(outputBuffer); - } - } - - UpdateEmitterRequest updateEmitterRequest = - new UpdateEmitterRequest(workerContext.getTaskId(), this.windowId, this.pipelineId, this.pipelineName, outputBuffers); - this.emitterRunner.add(updateEmitterRequest); + List> outputBuffers = this.getOutputBuffers(outputDescList); + for (int i = 0; i < outputNum; i++) { + if (collectors.get(i) instanceof AbstractPipelineCollector) { + AbstractPipelineCollector collector = (AbstractPipelineCollector) collectors.get(i); + IOutputMessageBuffer outputBuffer = outputBuffers.get(i); + collector.setOutputBuffer(outputBuffer); + } } - private List> getOutputBuffers(List outputDescList) { - int outputNum = outputDescList.size(); - List> outputBuffers = new ArrayList<>(outputNum); - for (IOutputDesc outputDesc : outputDescList) { - OutputWriter outputBuffer = null; - if (outputDesc.getType() != OutputType.RESPONSE) { - int bucketNum = ((ForwardOutputDesc) outputDesc).getTargetTaskIndices().size(); - outputBuffer = new OutputWriter<>(outputDesc.getEdgeId(), bucketNum); - outputBuffer.setEventMetrics(((WorkerContext) this.context).getEventMetrics()); - } - outputBuffers.add(outputBuffer); - } - return outputBuffers; + UpdateEmitterRequest updateEmitterRequest = + new UpdateEmitterRequest( + workerContext.getTaskId(), + this.windowId, + this.pipelineId, + this.pipelineName, + outputBuffers); + this.emitterRunner.add(updateEmitterRequest); + } + + private List> getOutputBuffers(List outputDescList) { + int outputNum = outputDescList.size(); + List> outputBuffers = new ArrayList<>(outputNum); + for (IOutputDesc outputDesc : outputDescList) { + OutputWriter outputBuffer = null; + if (outputDesc.getType() != OutputType.RESPONSE) { + int bucketNum = ((ForwardOutputDesc) outputDesc).getTargetTaskIndices().size(); + outputBuffer = new OutputWriter<>(outputDesc.getEdgeId(), bucketNum); + outputBuffer.setEventMetrics(((WorkerContext) this.context).getEventMetrics()); + } + outputBuffers.add(outputBuffer); } + return outputBuffers; + } - public long getPipelineId() { - return this.pipelineId; - } + public long getPipelineId() { + return this.pipelineId; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/AbstractIterationComputeCommand.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/AbstractIterationComputeCommand.java index 1cd8b782d..1504f62d3 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/AbstractIterationComputeCommand.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/AbstractIterationComputeCommand.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.cluster.fetcher.FetchRequest; import org.apache.geaflow.cluster.task.ITaskContext; import org.apache.geaflow.runtime.core.worker.AbstractAlignedWorker; @@ -34,31 +35,41 @@ public abstract class AbstractIterationComputeCommand extends AbstractExecutableCommand { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractIterationComputeCommand.class); - - protected long fetchWindowId; - protected long fetchCount; - protected Map windowCount; - protected Map> batchMessageCache; + private static final Logger LOGGER = + LoggerFactory.getLogger(AbstractIterationComputeCommand.class); - public AbstractIterationComputeCommand(long schedulerId, int workerId, int cycleId, long windowId, long fetchWindowId, long fetchCount) { - super(schedulerId, workerId, cycleId, windowId); - this.fetchWindowId = fetchWindowId; - this.fetchCount = fetchCount; - this.windowCount = new HashMap<>(); - this.batchMessageCache = new HashMap(); - } + protected long fetchWindowId; + protected long fetchCount; + protected Map windowCount; + protected Map> batchMessageCache; - @Override - public void execute(ITaskContext taskContext) { - final long start = System.currentTimeMillis(); - super.execute(taskContext); - AbstractWorker abstractWorker = (AbstractWorker) worker; - abstractWorker.init(windowId); - fetcherRunner.add(new FetchRequest(((WorkerContext) this.context).getTaskId(), fetchWindowId, fetchCount)); - abstractWorker.process(fetchCount, - this instanceof LoadGraphProcessEvent || worker instanceof AbstractAlignedWorker); - ((AbstractWorkerContext) this.context).getEventMetrics().addProcessCostMs(System.currentTimeMillis() - start); - } + public AbstractIterationComputeCommand( + long schedulerId, + int workerId, + int cycleId, + long windowId, + long fetchWindowId, + long fetchCount) { + super(schedulerId, workerId, cycleId, windowId); + this.fetchWindowId = fetchWindowId; + this.fetchCount = fetchCount; + this.windowCount = new HashMap<>(); + this.batchMessageCache = new HashMap(); + } + @Override + public void execute(ITaskContext taskContext) { + final long start = System.currentTimeMillis(); + super.execute(taskContext); + AbstractWorker abstractWorker = (AbstractWorker) worker; + abstractWorker.init(windowId); + fetcherRunner.add( + new FetchRequest(((WorkerContext) this.context).getTaskId(), fetchWindowId, fetchCount)); + abstractWorker.process( + fetchCount, + this instanceof LoadGraphProcessEvent || worker instanceof AbstractAlignedWorker); + ((AbstractWorkerContext) this.context) + .getEventMetrics() + .addProcessCostMs(System.currentTimeMillis() - start); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/CleanCycleEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/CleanCycleEvent.java index a1a9bb521..4d892fa9f 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/CleanCycleEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/CleanCycleEvent.java @@ -32,49 +32,54 @@ import org.slf4j.LoggerFactory; /** - * Clean worker runtime execution env, e.g. close shuffle reader/writer and close processor. - * Reverse event of {@link InitCycleEvent}: {@link InitCycleEvent} for initialize env while {@link CleanCycleEvent} for clean env. + * Clean worker runtime execution env, e.g. close shuffle reader/writer and close processor. Reverse + * event of {@link InitCycleEvent}: {@link InitCycleEvent} for initialize env while {@link + * CleanCycleEvent} for clean env. */ public class CleanCycleEvent extends AbstractCleanCommand { - private static final Logger LOGGER = LoggerFactory.getLogger(CleanCycleEvent.class); + private static final Logger LOGGER = LoggerFactory.getLogger(CleanCycleEvent.class); - public CleanCycleEvent(long schedulerId, int workerId, int cycleId, long windowId) { - super(schedulerId, workerId, cycleId, windowId); - } + public CleanCycleEvent(long schedulerId, int workerId, int cycleId, long windowId) { + super(schedulerId, workerId, cycleId, windowId); + } - @Override - public void execute(ITaskContext taskContext) { - super.execute(taskContext); - WorkerContext workerContext = (WorkerContext) this.context; - ExecutionTask executionTask = workerContext.getExecutionTask(); - this.fetcherRunner.add(new CloseFetchRequest(executionTask.getTaskId())); - this.emitterRunner.add(new CloseEmitterRequest(executionTask.getTaskId(), this.windowId)); - this.worker.close(); - EventMetrics eventMetrics = ((AbstractWorkerContext) this.context).getEventMetrics(); - eventMetrics.setFinishTime(System.currentTimeMillis()); - eventMetrics.setFinishGcTs(GcUtil.computeCurrentTotalGcTime()); - LOGGER.info("clean task {} {}/{} of {} {} : {}", - executionTask.getTaskId(), - executionTask.getIndex(), - executionTask.getParallelism(), - executionTask.getVertexId(), - executionTask.getProcessor().toString(), - eventMetrics); - this.sendDoneEvent(workerContext.getDriverId(), EventType.CLEAN_CYCLE, null, true); - } + @Override + public void execute(ITaskContext taskContext) { + super.execute(taskContext); + WorkerContext workerContext = (WorkerContext) this.context; + ExecutionTask executionTask = workerContext.getExecutionTask(); + this.fetcherRunner.add(new CloseFetchRequest(executionTask.getTaskId())); + this.emitterRunner.add(new CloseEmitterRequest(executionTask.getTaskId(), this.windowId)); + this.worker.close(); + EventMetrics eventMetrics = ((AbstractWorkerContext) this.context).getEventMetrics(); + eventMetrics.setFinishTime(System.currentTimeMillis()); + eventMetrics.setFinishGcTs(GcUtil.computeCurrentTotalGcTime()); + LOGGER.info( + "clean task {} {}/{} of {} {} : {}", + executionTask.getTaskId(), + executionTask.getIndex(), + executionTask.getParallelism(), + executionTask.getVertexId(), + executionTask.getProcessor().toString(), + eventMetrics); + this.sendDoneEvent(workerContext.getDriverId(), EventType.CLEAN_CYCLE, null, true); + } - @Override - public EventType getEventType() { - return EventType.CLEAN_CYCLE; - } + @Override + public EventType getEventType() { + return EventType.CLEAN_CYCLE; + } - @Override - public String toString() { - return "CleanCycleEvent{" - + "schedulerId=" + schedulerId - + ", workerId=" + workerId - + ", windowId=" + windowId - + '}'; - } + @Override + public String toString() { + return "CleanCycleEvent{" + + "schedulerId=" + + schedulerId + + ", workerId=" + + workerId + + ", windowId=" + + windowId + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/CleanEnvEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/CleanEnvEvent.java index f1f10e166..286af675b 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/CleanEnvEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/CleanEnvEvent.java @@ -28,47 +28,60 @@ public class CleanEnvEvent extends AbstractCleanCommand { - private final long pipelineId; - private final String driverId; + private final long pipelineId; + private final String driverId; - public CleanEnvEvent(long schedulerId, int workerId, int cycleId, long windowId, long pipelineId, String driverId) { - super(schedulerId, workerId, cycleId, windowId); - this.pipelineId = pipelineId; - this.driverId = driverId; - } + public CleanEnvEvent( + long schedulerId, + int workerId, + int cycleId, + long windowId, + long pipelineId, + String driverId) { + super(schedulerId, workerId, cycleId, windowId); + this.pipelineId = pipelineId; + this.driverId = driverId; + } - @Override - public void execute(ITaskContext taskContext) { - super.execute(taskContext); - ShuffleManager.getInstance().release(pipelineId); - WorkerContextManager.clear(); - this.emitterRunner.add(ClearEmitterRequest.INSTANCE); - this.sendDoneEvent(this.driverId, EventType.CLEAN_ENV, null, false); - } + @Override + public void execute(ITaskContext taskContext) { + super.execute(taskContext); + ShuffleManager.getInstance().release(pipelineId); + WorkerContextManager.clear(); + this.emitterRunner.add(ClearEmitterRequest.INSTANCE); + this.sendDoneEvent(this.driverId, EventType.CLEAN_ENV, null, false); + } - @Override - public EventType getEventType() { - return EventType.CLEAN_ENV; - } + @Override + public EventType getEventType() { + return EventType.CLEAN_ENV; + } - public void setIterationId(int iterationId) { - this.windowId = iterationId; - } + public void setIterationId(int iterationId) { + this.windowId = iterationId; + } - @Override - protected void sendDoneEvent(String driverId, EventType sourceEventType, T result, boolean sendMetrics) { - DoneEvent doneEvent = new DoneEvent<>(this.schedulerId, this.cycleId, this.windowId, 0, sourceEventType, result); - RpcClient.getInstance().processPipeline(driverId, doneEvent); - } + @Override + protected void sendDoneEvent( + String driverId, EventType sourceEventType, T result, boolean sendMetrics) { + DoneEvent doneEvent = + new DoneEvent<>(this.schedulerId, this.cycleId, this.windowId, 0, sourceEventType, result); + RpcClient.getInstance().processPipeline(driverId, doneEvent); + } - @Override - public String toString() { - return "CleanEnvEvent{" - + "schedulerId=" + schedulerId - + ", workerId=" + workerId - + ", cycleId=" + cycleId - + ", windowId=" + windowId - + ", pipelineId=" + pipelineId - + '}'; - } + @Override + public String toString() { + return "CleanEnvEvent{" + + "schedulerId=" + + schedulerId + + ", workerId=" + + workerId + + ", cycleId=" + + cycleId + + ", windowId=" + + windowId + + ", pipelineId=" + + pipelineId + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/CleanStashEnvEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/CleanStashEnvEvent.java index 9e6cd0a41..6b4bf7cc8 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/CleanStashEnvEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/CleanStashEnvEvent.java @@ -26,46 +26,58 @@ public class CleanStashEnvEvent extends AbstractCleanCommand { - protected final long pipelineId; - protected final String driverId; + protected final long pipelineId; + protected final String driverId; - public CleanStashEnvEvent(long schedulerId, int workerId, int cycleId, long iterationId, - long pipelineId, String driverId) { - super(schedulerId, workerId, cycleId, iterationId); - this.pipelineId = pipelineId; - this.driverId = driverId; - } + public CleanStashEnvEvent( + long schedulerId, + int workerId, + int cycleId, + long iterationId, + long pipelineId, + String driverId) { + super(schedulerId, workerId, cycleId, iterationId); + this.pipelineId = pipelineId; + this.driverId = driverId; + } - @Override - public void execute(ITaskContext taskContext) { - super.execute(taskContext); - ShuffleManager.getInstance().release(pipelineId); - this.sendDoneEvent(this.driverId, EventType.CLEAN_ENV, null, false); - } + @Override + public void execute(ITaskContext taskContext) { + super.execute(taskContext); + ShuffleManager.getInstance().release(pipelineId); + this.sendDoneEvent(this.driverId, EventType.CLEAN_ENV, null, false); + } - @Override - public EventType getEventType() { - return EventType.CLEAN_ENV; - } + @Override + public EventType getEventType() { + return EventType.CLEAN_ENV; + } - public void setIterationId(int iterationId) { - this.windowId = iterationId; - } + public void setIterationId(int iterationId) { + this.windowId = iterationId; + } - @Override - protected void sendDoneEvent(String driverId, EventType sourceEventType, T result, boolean sendMetrics) { - DoneEvent doneEvent = new DoneEvent<>(this.schedulerId, this.cycleId, this.windowId, 0, sourceEventType, result); - RpcClient.getInstance().processPipeline(driverId, doneEvent); - } + @Override + protected void sendDoneEvent( + String driverId, EventType sourceEventType, T result, boolean sendMetrics) { + DoneEvent doneEvent = + new DoneEvent<>(this.schedulerId, this.cycleId, this.windowId, 0, sourceEventType, result); + RpcClient.getInstance().processPipeline(driverId, doneEvent); + } - @Override - public String toString() { - return "CleanStashEnvEvent{" - + "schedulerId=" + schedulerId - + ", workerId=" + workerId - + ", cycleId=" + cycleId - + ", windowId=" + windowId - + ", pipelineId=" + pipelineId - + '}'; - } + @Override + public String toString() { + return "CleanStashEnvEvent{" + + "schedulerId=" + + schedulerId + + ", workerId=" + + workerId + + ", cycleId=" + + cycleId + + ", windowId=" + + windowId + + ", pipelineId=" + + pipelineId + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/ComposeEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/ComposeEvent.java index 2c8b24b89..bf3344d67 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/ComposeEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/ComposeEvent.java @@ -20,52 +20,48 @@ package org.apache.geaflow.runtime.core.protocol; import java.util.List; + import org.apache.geaflow.cluster.protocol.EventType; import org.apache.geaflow.cluster.protocol.ICommand; import org.apache.geaflow.cluster.protocol.IComposeEvent; import org.apache.geaflow.cluster.protocol.IEvent; /** - * An event that contains a list of basic event. - * Suppose a cycle with one worker run the following steps one by one: - * firstly, worker need init runtime execution evn and - * then execution a round iteration and - * finally clean worker env. - * Scheduler can build a {@link ComposeEvent} of {@link InitCycleEvent}, {@link ExecuteComputeEvent} and {@link CleanCycleEvent} and - * send to worker, instead of sending three events to worker one by one. + * An event that contains a list of basic event. Suppose a cycle with one worker run the following + * steps one by one: firstly, worker need init runtime execution evn and then execution a round + * iteration and finally clean worker env. Scheduler can build a {@link ComposeEvent} of {@link + * InitCycleEvent}, {@link ExecuteComputeEvent} and {@link CleanCycleEvent} and send to worker, + * instead of sending three events to worker one by one. */ public class ComposeEvent implements IComposeEvent, ICommand { - private int workerId; + private int workerId; - // A list of event that will be executed by worker sequentially. - private List events; + // A list of event that will be executed by worker sequentially. + private List events; - public ComposeEvent(int workerId, List events) { - this.workerId = workerId; - this.events = events; - } + public ComposeEvent(int workerId, List events) { + this.workerId = workerId; + this.events = events; + } - @Override - public int getWorkerId() { - return workerId; - } + @Override + public int getWorkerId() { + return workerId; + } - @Override - public List getEventList() { - return events; - } + @Override + public List getEventList() { + return events; + } - @Override - public EventType getEventType() { - return EventType.COMPOSE; - } + @Override + public EventType getEventType() { + return EventType.COMPOSE; + } - @Override - public String toString() { - return "ComposeEvent{" - + "workerId=" + workerId - + ", events=" + events - + '}'; - } + @Override + public String toString() { + return "ComposeEvent{" + "workerId=" + workerId + ", events=" + events + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/CreateTaskEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/CreateTaskEvent.java index 9a3176b82..fdce4e5f0 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/CreateTaskEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/CreateTaskEvent.java @@ -24,38 +24,34 @@ import org.apache.geaflow.cluster.protocol.IHighAvailableEvent; import org.apache.geaflow.ha.runtime.HighAvailableLevel; -/** - * Defined creating of the task. - */ +/** Defined creating of the task. */ public class CreateTaskEvent implements ICommand, IHighAvailableEvent { - private int workerId; - private HighAvailableLevel haLevel; - - public CreateTaskEvent(int workerId, HighAvailableLevel haLevel) { - this.workerId = workerId; - this.haLevel = haLevel; - } - - @Override - public int getWorkerId() { - return workerId; - } - - @Override - public EventType getEventType() { - return EventType.CREATE_TASK; - } - - @Override - public String toString() { - return "CreateTaskEvent{" - + "workerId=" + workerId - + '}'; - } - - @Override - public HighAvailableLevel getHaLevel() { - return haLevel; - } + private int workerId; + private HighAvailableLevel haLevel; + + public CreateTaskEvent(int workerId, HighAvailableLevel haLevel) { + this.workerId = workerId; + this.haLevel = haLevel; + } + + @Override + public int getWorkerId() { + return workerId; + } + + @Override + public EventType getEventType() { + return EventType.CREATE_TASK; + } + + @Override + public String toString() { + return "CreateTaskEvent{" + "workerId=" + workerId + '}'; + } + + @Override + public HighAvailableLevel getHaLevel() { + return haLevel; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/CreateWorkerEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/CreateWorkerEvent.java index 1133afc25..b90ee4a9b 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/CreateWorkerEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/CreateWorkerEvent.java @@ -28,51 +28,45 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Defined creating of the pipeline worker. - */ +/** Defined creating of the pipeline worker. */ public class CreateWorkerEvent implements IExecutableCommand, IHighAvailableEvent { - private static final Logger LOGGER = LoggerFactory.getLogger(CreateWorkerEvent.class); - - private int workerId; - private HighAvailableLevel haLevel; + private static final Logger LOGGER = LoggerFactory.getLogger(CreateWorkerEvent.class); - public CreateWorkerEvent(int workerId, HighAvailableLevel haLevel) { - this.workerId = workerId; - this.haLevel = haLevel; - } + private int workerId; + private HighAvailableLevel haLevel; - @Override - public int getWorkerId() { - return workerId; - } + public CreateWorkerEvent(int workerId, HighAvailableLevel haLevel) { + this.workerId = workerId; + this.haLevel = haLevel; + } - @Override - public void execute(ITaskContext context) { - context.registerWorker(WorkerFactory.createWorker(context.getConfig())); - LOGGER.info("create worker {} worker Id {}", context.getWorker(), workerId); - } + @Override + public int getWorkerId() { + return workerId; + } - @Override - public void interrupt() { + @Override + public void execute(ITaskContext context) { + context.registerWorker(WorkerFactory.createWorker(context.getConfig())); + LOGGER.info("create worker {} worker Id {}", context.getWorker(), workerId); + } - } + @Override + public void interrupt() {} - @Override - public EventType getEventType() { - return EventType.CREATE_WORKER; - } + @Override + public EventType getEventType() { + return EventType.CREATE_WORKER; + } - @Override - public String toString() { - return "CreateWorkerEvent{" - + "workerId=" + workerId - + '}'; - } + @Override + public String toString() { + return "CreateWorkerEvent{" + "workerId=" + workerId + '}'; + } - @Override - public HighAvailableLevel getHaLevel() { - return haLevel; - } + @Override + public HighAvailableLevel getHaLevel() { + return haLevel; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/DestroyTaskEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/DestroyTaskEvent.java index aabfedae8..3533db405 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/DestroyTaskEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/DestroyTaskEvent.java @@ -22,31 +22,27 @@ import org.apache.geaflow.cluster.protocol.EventType; import org.apache.geaflow.cluster.protocol.ICommand; -/** - * Defined destroying of the task. - */ +/** Defined destroying of the task. */ public class DestroyTaskEvent implements ICommand { - private int workerId; + private int workerId; - public DestroyTaskEvent(int workerId) { - this.workerId = workerId; - } + public DestroyTaskEvent(int workerId) { + this.workerId = workerId; + } - @Override - public int getWorkerId() { - return workerId; - } + @Override + public int getWorkerId() { + return workerId; + } - @Override - public EventType getEventType() { - return EventType.DESTROY_TASK; - } + @Override + public EventType getEventType() { + return EventType.DESTROY_TASK; + } - @Override - public String toString() { - return "DestroyTaskEvent{" - + "workerId=" + workerId - + '}'; - } + @Override + public String toString() { + return "DestroyTaskEvent{" + "workerId=" + workerId + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/DoneEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/DoneEvent.java index f06bea1cd..17f94d7ee 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/DoneEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/DoneEvent.java @@ -23,106 +23,116 @@ import org.apache.geaflow.cluster.protocol.ICycleResponseEvent; import org.apache.geaflow.common.metric.EventMetrics; -/** - * Defined the end of one iteration. - * It sent from cycle tail tasks to scheduler. - */ +/** Defined the end of one iteration. It sent from cycle tail tasks to scheduler. */ public class DoneEvent implements ICycleResponseEvent { - // Scheduler id of the current event. - private long schedulerId; - - // Cycle id of the current event. - private int cycleId; - - // Window id of cycle. - private long windowId; - - // The task id of cycle tail that send back event to scheduler. - private int taskId; - - // Event that trigger the execution. - private EventType sourceEvent; - - // Result of execution. null if no need result. - private T result; - - private EventMetrics eventMetrics; - - public DoneEvent(long schedulerId, int cycleId, long windowId, int tailTaskId, EventType sourceEvent) { - this(schedulerId, cycleId, windowId, tailTaskId, sourceEvent, null, null); - } - - public DoneEvent(long schedulerId, int cycleId, long windowId, int tailTaskId, EventType sourceEvent, T result) { - this(schedulerId, cycleId, windowId, tailTaskId, sourceEvent, result, null); - } - - public DoneEvent(long schedulerId, - int cycleId, - long windowId, - int tailTaskId, - EventType sourceEvent, - T result, - EventMetrics eventMetrics) { - this.schedulerId = schedulerId; - this.cycleId = cycleId; - this.windowId = windowId; - this.taskId = tailTaskId; - this.sourceEvent = sourceEvent; - this.result = result; - this.eventMetrics = eventMetrics; - } - - public long getSchedulerId() { - return schedulerId; - } - - @Override - public int getCycleId() { - return cycleId; - } - - public long getWindowId() { - return windowId; - } - - public int getTaskId() { - return taskId; - } - - public T getResult() { - return result; - } - - public void setResult(T result) { - this.result = result; - } - - public EventType getSourceEvent() { - return sourceEvent; - } - - @Override - public EventType getEventType() { - return EventType.DONE; - } - - public EventMetrics getEventMetrics() { - return eventMetrics; - } - - public void setEventMetrics(EventMetrics eventMetrics) { - this.eventMetrics = eventMetrics; - } - - @Override - public String toString() { - return "DoneEvent{" - + "schedulerId=" + schedulerId - + ", cycleId=" + cycleId - + ", windowId=" + windowId - + ", taskId=" + taskId - + ", sourceEvent=" + sourceEvent - + '}'; - } + // Scheduler id of the current event. + private long schedulerId; + + // Cycle id of the current event. + private int cycleId; + + // Window id of cycle. + private long windowId; + + // The task id of cycle tail that send back event to scheduler. + private int taskId; + + // Event that trigger the execution. + private EventType sourceEvent; + + // Result of execution. null if no need result. + private T result; + + private EventMetrics eventMetrics; + + public DoneEvent( + long schedulerId, int cycleId, long windowId, int tailTaskId, EventType sourceEvent) { + this(schedulerId, cycleId, windowId, tailTaskId, sourceEvent, null, null); + } + + public DoneEvent( + long schedulerId, + int cycleId, + long windowId, + int tailTaskId, + EventType sourceEvent, + T result) { + this(schedulerId, cycleId, windowId, tailTaskId, sourceEvent, result, null); + } + + public DoneEvent( + long schedulerId, + int cycleId, + long windowId, + int tailTaskId, + EventType sourceEvent, + T result, + EventMetrics eventMetrics) { + this.schedulerId = schedulerId; + this.cycleId = cycleId; + this.windowId = windowId; + this.taskId = tailTaskId; + this.sourceEvent = sourceEvent; + this.result = result; + this.eventMetrics = eventMetrics; + } + + public long getSchedulerId() { + return schedulerId; + } + + @Override + public int getCycleId() { + return cycleId; + } + + public long getWindowId() { + return windowId; + } + + public int getTaskId() { + return taskId; + } + + public T getResult() { + return result; + } + + public void setResult(T result) { + this.result = result; + } + + public EventType getSourceEvent() { + return sourceEvent; + } + + @Override + public EventType getEventType() { + return EventType.DONE; + } + + public EventMetrics getEventMetrics() { + return eventMetrics; + } + + public void setEventMetrics(EventMetrics eventMetrics) { + this.eventMetrics = eventMetrics; + } + + @Override + public String toString() { + return "DoneEvent{" + + "schedulerId=" + + schedulerId + + ", cycleId=" + + cycleId + + ", windowId=" + + windowId + + ", taskId=" + + taskId + + ", sourceEvent=" + + sourceEvent + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/ExecuteComputeEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/ExecuteComputeEvent.java index 9300a1453..a2cb5ea15 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/ExecuteComputeEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/ExecuteComputeEvent.java @@ -21,38 +21,53 @@ import org.apache.geaflow.cluster.protocol.EventType; -/** - * Send from scheduler to cycle head task to launch one iteration of the cycle. - */ +/** Send from scheduler to cycle head task to launch one iteration of the cycle. */ public class ExecuteComputeEvent extends AbstractIterationComputeCommand { - private boolean recoverable; - - public ExecuteComputeEvent(long schedulerId, int workerId, int cycleId, long windowId, long fetchWindowId, long fetchCount) { - super(schedulerId, workerId, cycleId, windowId, fetchWindowId, fetchCount); - } - - public ExecuteComputeEvent(long schedulerId, int workerId, int cycleId, long windowId, - long fetchWindowId, long fetchCount, - boolean recoverable) { - this(schedulerId, workerId, cycleId, windowId, fetchWindowId, fetchCount); - this.recoverable = recoverable; - } - - @Override - public EventType getEventType() { - return EventType.EXECUTE_COMPUTE; - } - - @Override - public String toString() { - return "ExecuteComputeEvent{" - + "schedulerId=" + schedulerId - + ", workerId=" + workerId - + ", cycleId=" + cycleId - + ", windowId=" + windowId - + ", fetchWindowId=" + fetchWindowId - + ", fetchCount=" + fetchCount - + '}'; - } + private boolean recoverable; + + public ExecuteComputeEvent( + long schedulerId, + int workerId, + int cycleId, + long windowId, + long fetchWindowId, + long fetchCount) { + super(schedulerId, workerId, cycleId, windowId, fetchWindowId, fetchCount); + } + + public ExecuteComputeEvent( + long schedulerId, + int workerId, + int cycleId, + long windowId, + long fetchWindowId, + long fetchCount, + boolean recoverable) { + this(schedulerId, workerId, cycleId, windowId, fetchWindowId, fetchCount); + this.recoverable = recoverable; + } + + @Override + public EventType getEventType() { + return EventType.EXECUTE_COMPUTE; + } + + @Override + public String toString() { + return "ExecuteComputeEvent{" + + "schedulerId=" + + schedulerId + + ", workerId=" + + workerId + + ", cycleId=" + + cycleId + + ", windowId=" + + windowId + + ", fetchWindowId=" + + fetchWindowId + + ", fetchCount=" + + fetchCount + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/ExecuteFirstIterationEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/ExecuteFirstIterationEvent.java index 5ad741a65..63e24447b 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/ExecuteFirstIterationEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/ExecuteFirstIterationEvent.java @@ -21,27 +21,29 @@ import org.apache.geaflow.cluster.protocol.EventType; -/** - * Send from scheduler to cycle head task to launch the first iteration of the cycle. - */ +/** Send from scheduler to cycle head task to launch the first iteration of the cycle. */ public class ExecuteFirstIterationEvent extends AbstractExecutableCommand { - public ExecuteFirstIterationEvent(long schedulerId, int workerId, int cycleId, long windowId) { - super(schedulerId, workerId, cycleId, windowId); - } + public ExecuteFirstIterationEvent(long schedulerId, int workerId, int cycleId, long windowId) { + super(schedulerId, workerId, cycleId, windowId); + } - @Override - public EventType getEventType() { - return EventType.EXECUTE_FIRST_ITERATION; - } + @Override + public EventType getEventType() { + return EventType.EXECUTE_FIRST_ITERATION; + } - @Override - public String toString() { - return "ExecuteFirstIterationEvent{" - + "schedulerId=" + schedulerId - + ", workerId=" + workerId - + ", cycleId=" + cycleId - + ", windowId=" + windowId - + '}'; - } + @Override + public String toString() { + return "ExecuteFirstIterationEvent{" + + "schedulerId=" + + schedulerId + + ", workerId=" + + workerId + + ", cycleId=" + + cycleId + + ", windowId=" + + windowId + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/FinishIterationEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/FinishIterationEvent.java index 4f16d5208..eb4d2c32a 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/FinishIterationEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/FinishIterationEvent.java @@ -25,33 +25,39 @@ public class FinishIterationEvent extends AbstractExecutableCommand { - public static final long END_OF_ITERATION_ID = 0; - - public FinishIterationEvent(long schedulerId, int workerId, long windowId, int cycleId) { - super(schedulerId, workerId, cycleId, windowId); - } - - @Override - public void execute(ITaskContext taskContext) { - final long start = System.currentTimeMillis(); - super.execute(taskContext); - worker.init(windowId); - worker.finish(END_OF_ITERATION_ID); - ((AbstractWorkerContext) this.context).getEventMetrics().addProcessCostMs(System.currentTimeMillis() - start); - } - - @Override - public EventType getEventType() { - return EventType.FINISH_ITERATION; - } - - @Override - public String toString() { - return "FinishIterationEvent{" - + "schedulerId=" + schedulerId - + ", workerId=" + workerId - + ", windowId=" + windowId - + ", cycleId=" + cycleId - + '}'; - } + public static final long END_OF_ITERATION_ID = 0; + + public FinishIterationEvent(long schedulerId, int workerId, long windowId, int cycleId) { + super(schedulerId, workerId, cycleId, windowId); + } + + @Override + public void execute(ITaskContext taskContext) { + final long start = System.currentTimeMillis(); + super.execute(taskContext); + worker.init(windowId); + worker.finish(END_OF_ITERATION_ID); + ((AbstractWorkerContext) this.context) + .getEventMetrics() + .addProcessCostMs(System.currentTimeMillis() - start); + } + + @Override + public EventType getEventType() { + return EventType.FINISH_ITERATION; + } + + @Override + public String toString() { + return "FinishIterationEvent{" + + "schedulerId=" + + schedulerId + + ", workerId=" + + workerId + + ", windowId=" + + windowId + + ", cycleId=" + + cycleId + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/FinishPrefetchEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/FinishPrefetchEvent.java index 0c8d96d52..edfdf678b 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/FinishPrefetchEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/FinishPrefetchEvent.java @@ -20,6 +20,7 @@ package org.apache.geaflow.runtime.core.protocol; import java.util.List; + import org.apache.geaflow.cluster.fetcher.CloseFetchRequest; import org.apache.geaflow.cluster.protocol.EventType; import org.apache.geaflow.cluster.task.ITaskContext; @@ -28,52 +29,57 @@ public class FinishPrefetchEvent extends AbstractExecutableCommand { - private final int taskId; - private final int taskIndex; - private final long pipelineId; - private final List edgeIds; - - public FinishPrefetchEvent(long schedulerId, - int workerId, - int cycleId, - long windowId, - int taskId, - int taskIndex, - long pipelineId, - List edgeIds) { - super(schedulerId, workerId, cycleId, windowId); - this.taskId = taskId; - this.taskIndex = taskIndex; - this.pipelineId = pipelineId; - this.edgeIds = edgeIds; - } + private final int taskId; + private final int taskIndex; + private final long pipelineId; + private final List edgeIds; - @Override - public void execute(ITaskContext taskContext) { - super.execute(taskContext); - PrefetchCallbackHandler callbackHandler = PrefetchCallbackHandler.getInstance(); - for (Integer edgeId : this.edgeIds) { - SliceId sliceId = new SliceId(this.pipelineId, edgeId, -1, this.taskIndex); - PrefetchCallbackHandler.PrefetchCallback callback = callbackHandler.removeTaskEventCallback(sliceId); - callback.execute(); - } + public FinishPrefetchEvent( + long schedulerId, + int workerId, + int cycleId, + long windowId, + int taskId, + int taskIndex, + long pipelineId, + List edgeIds) { + super(schedulerId, workerId, cycleId, windowId); + this.taskId = taskId; + this.taskIndex = taskIndex; + this.pipelineId = pipelineId; + this.edgeIds = edgeIds; + } - this.fetcherRunner.add(new CloseFetchRequest(this.taskId)); + @Override + public void execute(ITaskContext taskContext) { + super.execute(taskContext); + PrefetchCallbackHandler callbackHandler = PrefetchCallbackHandler.getInstance(); + for (Integer edgeId : this.edgeIds) { + SliceId sliceId = new SliceId(this.pipelineId, edgeId, -1, this.taskIndex); + PrefetchCallbackHandler.PrefetchCallback callback = + callbackHandler.removeTaskEventCallback(sliceId); + callback.execute(); } - @Override - public EventType getEventType() { - return EventType.PREFETCH; - } + this.fetcherRunner.add(new CloseFetchRequest(this.taskId)); + } - @Override - public String toString() { - return "FinishPrefetchEvent{" - + "taskId=" + taskId - + ", workerId=" + workerId - + ", cycleId=" + cycleId - + ", windowId=" + windowId - + '}'; - } + @Override + public EventType getEventType() { + return EventType.PREFETCH; + } + @Override + public String toString() { + return "FinishPrefetchEvent{" + + "taskId=" + + taskId + + ", workerId=" + + workerId + + ", cycleId=" + + cycleId + + ", windowId=" + + windowId + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/InitCollectCycleEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/InitCollectCycleEvent.java index b6efa9ab0..477eab573 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/InitCollectCycleEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/InitCollectCycleEvent.java @@ -19,9 +19,9 @@ package org.apache.geaflow.runtime.core.protocol; -import com.google.common.base.Preconditions; import java.util.Collections; import java.util.List; + import org.apache.geaflow.cluster.collector.CollectResponseCollector; import org.apache.geaflow.cluster.collector.InitEmitterRequest; import org.apache.geaflow.collector.ICollector; @@ -31,33 +31,48 @@ import org.apache.geaflow.shuffle.OutputDescriptor; import org.apache.geaflow.shuffle.ResponseOutputDesc; +import com.google.common.base.Preconditions; + /** - * An assign event provides some runtime execution information for worker to build the cycle pipeline. - * including: execution task descriptors, shuffle descriptors + * An assign event provides some runtime execution information for worker to build the cycle + * pipeline. including: execution task descriptors, shuffle descriptors */ public class InitCollectCycleEvent extends InitCycleEvent { - private static final int COLLECT_BUCKET_NUM = 1; - - public InitCollectCycleEvent(long schedulerId, - int workerId, - int cycleId, - long iterationId, - long pipelineId, - String pipelineName, - IoDescriptor ioDescriptor, - ExecutionTask task, - String driverId, - HighAvailableLevel haLevel) { - super(schedulerId, workerId, cycleId, iterationId, pipelineId, pipelineName, ioDescriptor, task, driverId, haLevel); - } - - @Override - protected List> buildCollectors(OutputDescriptor outputDescriptor, InitEmitterRequest request) { - Preconditions.checkArgument(outputDescriptor.getOutputDescList().size() == COLLECT_BUCKET_NUM, - "only support one collect output info yet"); - ResponseOutputDesc outputDesc = (ResponseOutputDesc) outputDescriptor.getOutputDescList().get(0); - return Collections.singletonList(new CollectResponseCollector<>(outputDesc)); - } + private static final int COLLECT_BUCKET_NUM = 1; + + public InitCollectCycleEvent( + long schedulerId, + int workerId, + int cycleId, + long iterationId, + long pipelineId, + String pipelineName, + IoDescriptor ioDescriptor, + ExecutionTask task, + String driverId, + HighAvailableLevel haLevel) { + super( + schedulerId, + workerId, + cycleId, + iterationId, + pipelineId, + pipelineName, + ioDescriptor, + task, + driverId, + haLevel); + } + @Override + protected List> buildCollectors( + OutputDescriptor outputDescriptor, InitEmitterRequest request) { + Preconditions.checkArgument( + outputDescriptor.getOutputDescList().size() == COLLECT_BUCKET_NUM, + "only support one collect output info yet"); + ResponseOutputDesc outputDesc = + (ResponseOutputDesc) outputDescriptor.getOutputDescList().get(0); + return Collections.singletonList(new CollectResponseCollector<>(outputDesc)); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/InitCycleEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/InitCycleEvent.java index add2d91c5..b6c204a4a 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/InitCycleEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/InitCycleEvent.java @@ -30,44 +30,46 @@ import org.apache.geaflow.shuffle.IoDescriptor; /** - * An assign event provides some runtime execution information for worker to build the cycle pipeline. - * including: execution task descriptors, shuffle descriptors + * An assign event provides some runtime execution information for worker to build the cycle + * pipeline. including: execution task descriptors, shuffle descriptors */ public class InitCycleEvent extends AbstractInitCommand implements IHighAvailableEvent { - private final ExecutionTask task; - private final String driverId; - private final HighAvailableLevel haLevel; + private final ExecutionTask task; + private final String driverId; + private final HighAvailableLevel haLevel; - public InitCycleEvent(long schedulerId, - int workerId, - int cycleId, - long iterationId, - long pipelineId, - String pipelineName, - IoDescriptor ioDescriptor, - ExecutionTask task, - String driverId, - HighAvailableLevel haLevel) { - super(schedulerId, workerId, cycleId, iterationId, pipelineId, pipelineName, ioDescriptor); - this.task = task; - this.driverId = driverId; - this.haLevel = haLevel; - } + public InitCycleEvent( + long schedulerId, + int workerId, + int cycleId, + long iterationId, + long pipelineId, + String pipelineName, + IoDescriptor ioDescriptor, + ExecutionTask task, + String driverId, + HighAvailableLevel haLevel) { + super(schedulerId, workerId, cycleId, iterationId, pipelineId, pipelineName, ioDescriptor); + this.task = task; + this.driverId = driverId; + this.haLevel = haLevel; + } - @Override - public void execute(ITaskContext taskContext) { - super.execute(taskContext); - this.task.buildTaskName(this.pipelineName, this.cycleId, this.windowId); - this.context = this.initContext(taskContext); - this.initFetcher(); - this.initEmitter(); - this.worker.open(this.context); - } + @Override + public void execute(ITaskContext taskContext) { + super.execute(taskContext); + this.task.buildTaskName(this.pipelineName, this.cycleId, this.windowId); + this.context = this.initContext(taskContext); + this.initFetcher(); + this.initEmitter(); + this.worker.open(this.context); + } - private WorkerContext initContext(ITaskContext taskContext) { - WorkerContext workerContext = new WorkerContext(taskContext); - IEventContext eventContext = EventContext.builder() + private WorkerContext initContext(ITaskContext taskContext) { + WorkerContext workerContext = new WorkerContext(taskContext); + IEventContext eventContext = + EventContext.builder() .withExecutionTask(this.task) .withDriverId(this.driverId) .withCycleId(this.cycleId) @@ -78,34 +80,39 @@ private WorkerContext initContext(ITaskContext taskContext) { .withWindowId(this.windowId) .withSchedulerId(this.schedulerId) .build(); - workerContext.init(eventContext); - return workerContext; - } + workerContext.init(eventContext); + return workerContext; + } - @Override - public EventType getEventType() { - return EventType.INIT_CYCLE; - } + @Override + public EventType getEventType() { + return EventType.INIT_CYCLE; + } - @Override - public HighAvailableLevel getHaLevel() { - return haLevel; - } + @Override + public HighAvailableLevel getHaLevel() { + return haLevel; + } - public ExecutionTask getTask() { - return task; - } - - @Override - public String toString() { - return "InitCycleEvent{" - + "schedulerId=" + schedulerId - + ", workerId=" + workerId - + ", cycleId=" + cycleId - + ", windowId=" + windowId - + ", pipelineId=" + pipelineId - + ", pipelineName=" + pipelineName - + '}'; - } + public ExecutionTask getTask() { + return task; + } + @Override + public String toString() { + return "InitCycleEvent{" + + "schedulerId=" + + schedulerId + + ", workerId=" + + workerId + + ", cycleId=" + + cycleId + + ", windowId=" + + windowId + + ", pipelineId=" + + pipelineId + + ", pipelineName=" + + pipelineName + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/InitIterationEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/InitIterationEvent.java index 2633f9f61..676f90e2e 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/InitIterationEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/InitIterationEvent.java @@ -24,39 +24,44 @@ import org.apache.geaflow.shuffle.IoDescriptor; /** - * An assign event provides some runtime execution information for worker to build the cycle pipeline. - * including: execution task descriptors, shuffle descriptors + * An assign event provides some runtime execution information for worker to build the cycle + * pipeline. including: execution task descriptors, shuffle descriptors */ public class InitIterationEvent extends AbstractInitCommand { - public InitIterationEvent(long schedulerId, - int workerId, - int cycleId, - long iterationId, - long pipelineId, - String pipelineName, - IoDescriptor ioDescriptor) { - super(schedulerId, workerId, cycleId, iterationId, pipelineId, pipelineName, ioDescriptor); - } + public InitIterationEvent( + long schedulerId, + int workerId, + int cycleId, + long iterationId, + long pipelineId, + String pipelineName, + IoDescriptor ioDescriptor) { + super(schedulerId, workerId, cycleId, iterationId, pipelineId, pipelineName, ioDescriptor); + } - @Override - public void execute(ITaskContext taskContext) { - super.execute(taskContext); - this.initFetcher(); - } + @Override + public void execute(ITaskContext taskContext) { + super.execute(taskContext); + this.initFetcher(); + } - @Override - public EventType getEventType() { - return EventType.INIT_ITERATION; - } + @Override + public EventType getEventType() { + return EventType.INIT_ITERATION; + } - @Override - public String toString() { - return "InitIterationEvent{" - + "schedulerId=" + schedulerId - + ", workerId=" + workerId - + ", cycleId=" + cycleId - + ", iterationId=" + windowId - + '}'; - } + @Override + public String toString() { + return "InitIterationEvent{" + + "schedulerId=" + + schedulerId + + ", workerId=" + + workerId + + ", cycleId=" + + cycleId + + ", iterationId=" + + windowId + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/InterruptTaskEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/InterruptTaskEvent.java index f342df72a..b72cb4d4c 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/InterruptTaskEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/InterruptTaskEvent.java @@ -24,37 +24,33 @@ public class InterruptTaskEvent implements ICommand { - // Worker id to execute event. - protected final int workerId; - - // Cycle id of the current event. - protected final int cycleId; - - - public InterruptTaskEvent(int workerId, int cycleId) { - this.workerId = workerId; - this.cycleId = cycleId; - } - - @Override - public int getWorkerId() { - return workerId; - } - - @Override - public EventType getEventType() { - return EventType.INTERRUPT_TASK; - } - - public int getCycleId() { - return cycleId; - } - - @Override - public String toString() { - return "InterruptTaskEvent{" - + "workerId=" + workerId - + ", cycleId=" + cycleId - + '}'; - } + // Worker id to execute event. + protected final int workerId; + + // Cycle id of the current event. + protected final int cycleId; + + public InterruptTaskEvent(int workerId, int cycleId) { + this.workerId = workerId; + this.cycleId = cycleId; + } + + @Override + public int getWorkerId() { + return workerId; + } + + @Override + public EventType getEventType() { + return EventType.INTERRUPT_TASK; + } + + public int getCycleId() { + return cycleId; + } + + @Override + public String toString() { + return "InterruptTaskEvent{" + "workerId=" + workerId + ", cycleId=" + cycleId + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/IterationExecutionComputeWithAggEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/IterationExecutionComputeWithAggEvent.java index 1c395e771..1a6c3f74b 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/IterationExecutionComputeWithAggEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/IterationExecutionComputeWithAggEvent.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; + import org.apache.geaflow.cluster.protocol.EventType; import org.apache.geaflow.cluster.task.ITaskContext; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -32,88 +33,96 @@ import org.apache.geaflow.shuffle.message.PipelineMessage; import org.apache.geaflow.shuffle.serialize.IMessageIterator; -/** - * Send from scheduler to cycle head task to launch one iteration with aggregation of the cycle. - */ +/** Send from scheduler to cycle head task to launch one iteration with aggregation of the cycle. */ public class IterationExecutionComputeWithAggEvent extends AbstractIterationComputeCommand { - private final IoDescriptor ioDescriptor; - - public IterationExecutionComputeWithAggEvent(long schedulerId, - int workerId, - int cycleId, - long windowId, - long fetchWindowId, - long fetchCount, - IoDescriptor ioDescriptor) { - super(schedulerId, workerId, cycleId, windowId, fetchWindowId, fetchCount); - this.ioDescriptor = ioDescriptor; + private final IoDescriptor ioDescriptor; + + public IterationExecutionComputeWithAggEvent( + long schedulerId, + int workerId, + int cycleId, + long windowId, + long fetchWindowId, + long fetchCount, + IoDescriptor ioDescriptor) { + super(schedulerId, workerId, cycleId, windowId, fetchWindowId, fetchCount); + this.ioDescriptor = ioDescriptor; + } + + @Override + public void execute(ITaskContext taskContext) { + ((AbstractAlignedWorker) taskContext.getWorker()).getInputReader().onMessage(fetchAggResult()); + super.execute(taskContext); + } + + @Override + public EventType getEventType() { + return EventType.ITERATIVE_COMPUTE_WITH_AGGREGATE; + } + + private PipelineMessage fetchAggResult() { + List aggRecords = new ArrayList<>(); + List> inputDesc = + new ArrayList<>(this.ioDescriptor.getInputDescriptor().getInputDescMap().values()); + if (inputDesc.size() != 1) { + throw new GeaflowRuntimeException( + "agg result should only have 1 input, but found " + inputDesc.size()); } - - @Override - public void execute(ITaskContext taskContext) { - ((AbstractAlignedWorker) taskContext.getWorker()).getInputReader().onMessage(fetchAggResult()); - super.execute(taskContext); + IInputDesc aggDesc = inputDesc.get(0); + int edgeId = aggDesc.getEdgeId(); + aggRecords.addAll(aggDesc.getInput()); + return new PipelineMessage<>( + edgeId, + this.fetchWindowId, + RecordArgs.GraphRecordNames.Aggregate.name(), + new DataMessageIterator<>(aggRecords)); + } + + private class DataMessageIterator implements IMessageIterator { + + private final Iterator iterator; + private long size = 0; + + public DataMessageIterator(List data) { + this.iterator = data.iterator(); + this.size = data.size(); } @Override - public EventType getEventType() { - return EventType.ITERATIVE_COMPUTE_WITH_AGGREGATE; + public long getSize() { + return size; } - private PipelineMessage fetchAggResult() { - List aggRecords = new ArrayList<>(); - List> inputDesc = new ArrayList<>(this.ioDescriptor.getInputDescriptor().getInputDescMap().values()); - if (inputDesc.size() != 1) { - throw new GeaflowRuntimeException("agg result should only have 1 input, but found " + inputDesc.size()); - } - IInputDesc aggDesc = inputDesc.get(0); - int edgeId = aggDesc.getEdgeId(); - aggRecords.addAll(aggDesc.getInput()); - return new PipelineMessage<>(edgeId, this.fetchWindowId, - RecordArgs.GraphRecordNames.Aggregate.name(), - new DataMessageIterator<>(aggRecords)); - } - - private class DataMessageIterator implements IMessageIterator { - - private final Iterator iterator; - private long size = 0; - - public DataMessageIterator(List data) { - this.iterator = data.iterator(); - this.size = data.size(); - } - - @Override - public long getSize() { - return size; - } - - @Override - public void close() { - } - - @Override - public boolean hasNext() { - return iterator.hasNext(); - } + @Override + public void close() {} - @Override - public T next() { - return iterator.next(); - } + @Override + public boolean hasNext() { + return iterator.hasNext(); } @Override - public String toString() { - return "IterationExecutionComputeWithAggEvent{" - + "schedulerId=" + schedulerId - + ", workerId=" + workerId - + ", cycleId=" + cycleId - + ", windowId=" + windowId - + ", fetchWindowId=" + fetchWindowId - + ", fetchCount=" + fetchCount - + '}'; + public T next() { + return iterator.next(); } + } + + @Override + public String toString() { + return "IterationExecutionComputeWithAggEvent{" + + "schedulerId=" + + schedulerId + + ", workerId=" + + workerId + + ", cycleId=" + + cycleId + + ", windowId=" + + windowId + + ", fetchWindowId=" + + fetchWindowId + + ", fetchCount=" + + fetchCount + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/LaunchSourceEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/LaunchSourceEvent.java index cb4adecbc..d5fa1c01a 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/LaunchSourceEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/LaunchSourceEvent.java @@ -25,61 +25,57 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Send from scheduler to cycle head task to launch one iteration of the cycle. - */ +/** Send from scheduler to cycle head task to launch one iteration of the cycle. */ public class LaunchSourceEvent extends AbstractExecutableCommand { - private static final Logger LOGGER = LoggerFactory.getLogger(LaunchSourceEvent.class); + private static final Logger LOGGER = LoggerFactory.getLogger(LaunchSourceEvent.class); - public LaunchSourceEvent(long schedulerId, int workerId, int cycleId, long windowId) { - super(schedulerId, workerId, cycleId, windowId); - } + public LaunchSourceEvent(long schedulerId, int workerId, int cycleId, long windowId) { + super(schedulerId, workerId, cycleId, windowId); + } - @Override - public void execute(ITaskContext taskContext) { - final long start = System.currentTimeMillis(); - super.execute(taskContext); - WorkerContext workerContext = (WorkerContext) this.context; - worker.init(windowId); - if (!workerContext.isFinished()) { - boolean hasData = (boolean) worker.process(null); - if (!hasData) { - workerContext.setFinished(true); - this.sendDoneEvent( - workerContext.getDriverId(), - EventType.LAUNCH_SOURCE, - false, - false); - LOGGER.info("source is finished at {}, workerId {}, taskId {}", - windowId, workerId, workerContext.getTaskId()); - } - } else { - this.sendDoneEvent( - workerContext.getDriverId(), - EventType.LAUNCH_SOURCE, - false, - false); - LOGGER.info("source already finished, workerId {}, taskId {}", - workerId, - workerContext.getTaskId()); - } - worker.finish(windowId); - workerContext.getEventMetrics().addProcessCostMs(System.currentTimeMillis() - start); + @Override + public void execute(ITaskContext taskContext) { + final long start = System.currentTimeMillis(); + super.execute(taskContext); + WorkerContext workerContext = (WorkerContext) this.context; + worker.init(windowId); + if (!workerContext.isFinished()) { + boolean hasData = (boolean) worker.process(null); + if (!hasData) { + workerContext.setFinished(true); + this.sendDoneEvent(workerContext.getDriverId(), EventType.LAUNCH_SOURCE, false, false); + LOGGER.info( + "source is finished at {}, workerId {}, taskId {}", + windowId, + workerId, + workerContext.getTaskId()); + } + } else { + this.sendDoneEvent(workerContext.getDriverId(), EventType.LAUNCH_SOURCE, false, false); + LOGGER.info( + "source already finished, workerId {}, taskId {}", workerId, workerContext.getTaskId()); } + worker.finish(windowId); + workerContext.getEventMetrics().addProcessCostMs(System.currentTimeMillis() - start); + } - @Override - public EventType getEventType() { - return EventType.LAUNCH_SOURCE; - } + @Override + public EventType getEventType() { + return EventType.LAUNCH_SOURCE; + } - @Override - public String toString() { - return "LaunchSourceEvent{" - + "schedulerId=" + schedulerId - + ", workerId=" + workerId - + ", cycleId=" + cycleId - + ", windowId=" + windowId - + '}'; - } + @Override + public String toString() { + return "LaunchSourceEvent{" + + "schedulerId=" + + schedulerId + + ", workerId=" + + workerId + + ", cycleId=" + + cycleId + + ", windowId=" + + windowId + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/LoadGraphProcessEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/LoadGraphProcessEvent.java index 3d119041a..83bd87144 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/LoadGraphProcessEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/LoadGraphProcessEvent.java @@ -26,29 +26,40 @@ public class LoadGraphProcessEvent extends AbstractIterationComputeCommand { - public LoadGraphProcessEvent(long schedulerId, int workerId, int cycleId, long windowId, long fetchWindowId, long fetchCount) { - super(schedulerId, workerId, cycleId, windowId, fetchWindowId, fetchCount); - } + public LoadGraphProcessEvent( + long schedulerId, + int workerId, + int cycleId, + long windowId, + long fetchWindowId, + long fetchCount) { + super(schedulerId, workerId, cycleId, windowId, fetchWindowId, fetchCount); + } - @Override - public void execute(ITaskContext taskContext) { - super.execute(taskContext); - this.fetcherRunner.add(new CloseFetchRequest(((WorkerContext) this.context).getTaskId())); - } + @Override + public void execute(ITaskContext taskContext) { + super.execute(taskContext); + this.fetcherRunner.add(new CloseFetchRequest(((WorkerContext) this.context).getTaskId())); + } - @Override - public EventType getEventType() { - return EventType.PRE_GRAPH_PROCESS; - } + @Override + public EventType getEventType() { + return EventType.PRE_GRAPH_PROCESS; + } - @Override - public String toString() { - return "LoadGraphProcessEvent{" - + "schedulerId=" + schedulerId - + ", workerId=" + workerId - + ", cycleId=" + cycleId - + ", windowId=" + windowId - + ", fetchWindowId=" + fetchWindowId - + '}'; - } + @Override + public String toString() { + return "LoadGraphProcessEvent{" + + "schedulerId=" + + schedulerId + + ", workerId=" + + workerId + + ", cycleId=" + + cycleId + + ", windowId=" + + windowId + + ", fetchWindowId=" + + fetchWindowId + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/PopWorkerEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/PopWorkerEvent.java index b147c4e16..cc5800fd1 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/PopWorkerEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/PopWorkerEvent.java @@ -28,67 +28,72 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Pop worker from cache and reuse context. - */ +/** Pop worker from cache and reuse context. */ public class PopWorkerEvent extends AbstractInitCommand { - private static final Logger LOGGER = LoggerFactory.getLogger(PopWorkerEvent.class); + private static final Logger LOGGER = LoggerFactory.getLogger(PopWorkerEvent.class); - private final int taskId; + private final int taskId; - public PopWorkerEvent(long schedulerId, - int workerId, - int cycleId, - long windowId, - long pipelineId, - String pipelineName, - IoDescriptor ioDescriptor, - int taskId) { - super(schedulerId, workerId, cycleId, windowId, pipelineId, pipelineName, ioDescriptor); - this.taskId = taskId; - } + public PopWorkerEvent( + long schedulerId, + int workerId, + int cycleId, + long windowId, + long pipelineId, + String pipelineName, + IoDescriptor ioDescriptor, + int taskId) { + super(schedulerId, workerId, cycleId, windowId, pipelineId, pipelineName, ioDescriptor); + this.taskId = taskId; + } - @Override - public void execute(ITaskContext taskContext) { - super.execute(taskContext); - LOGGER.info("reuse worker context, taskId {}", taskId); - AbstractWorkerContext popWorkerContext = new WorkerContext(taskContext); - popWorkerContext.setPipelineId(pipelineId); - popWorkerContext.setPipelineName(pipelineName); - popWorkerContext.setWindowId(windowId); - popWorkerContext.setTaskId(taskId); - popWorkerContext.setSchedulerId(schedulerId); + @Override + public void execute(ITaskContext taskContext) { + super.execute(taskContext); + LOGGER.info("reuse worker context, taskId {}", taskId); + AbstractWorkerContext popWorkerContext = new WorkerContext(taskContext); + popWorkerContext.setPipelineId(pipelineId); + popWorkerContext.setPipelineName(pipelineName); + popWorkerContext.setWindowId(windowId); + popWorkerContext.setTaskId(taskId); + popWorkerContext.setSchedulerId(schedulerId); - ((IAffinityWorker) worker).pop(popWorkerContext); - context = worker.getWorkerContext(); + ((IAffinityWorker) worker).pop(popWorkerContext); + context = worker.getWorkerContext(); - this.initFetcher(); - this.popEmitter(); - } + this.initFetcher(); + this.popEmitter(); + } - @Override - public EventType getEventType() { - return EventType.POP_WORKER; - } + @Override + public EventType getEventType() { + return EventType.POP_WORKER; + } - public long getWindowId() { - return windowId; - } + public long getWindowId() { + return windowId; + } - public long getPipelineId() { - return pipelineId; - } + public long getPipelineId() { + return pipelineId; + } - @Override - public String toString() { - return "PopWorkerEvent{" - + "schedulerId=" + schedulerId - + ", workerId=" + workerId - + ", cycleId=" + cycleId - + ", windowId=" + windowId - + ", pipelineId=" + pipelineId - + ", pipelineName=" + pipelineName - + '}'; - } + @Override + public String toString() { + return "PopWorkerEvent{" + + "schedulerId=" + + schedulerId + + ", workerId=" + + workerId + + ", cycleId=" + + cycleId + + ", windowId=" + + windowId + + ", pipelineId=" + + pipelineId + + ", pipelineName=" + + pipelineName + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/PrefetchEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/PrefetchEvent.java index 1bb8e968c..a6a27b141 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/PrefetchEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/PrefetchEvent.java @@ -23,6 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.cluster.fetcher.FetchRequest; import org.apache.geaflow.cluster.fetcher.InitFetchRequest; import org.apache.geaflow.cluster.fetcher.PrefetchMessageBuffer; @@ -40,47 +41,50 @@ public class PrefetchEvent extends AbstractInitCommand { - private final ExecutionTask task; - private final List edgeIds; + private final ExecutionTask task; + private final List edgeIds; - public PrefetchEvent(long schedulerId, - int workerId, - int cycleId, - long windowId, - long pipelineId, - String pipelineName, - ExecutionTask task, - IoDescriptor ioDescriptor) { - super(schedulerId, workerId, cycleId, windowId, pipelineId, pipelineName, ioDescriptor); - this.task = task; - this.edgeIds = new ArrayList<>(ioDescriptor.getInputDescriptor().getInputDescMap().keySet()); - } + public PrefetchEvent( + long schedulerId, + int workerId, + int cycleId, + long windowId, + long pipelineId, + String pipelineName, + ExecutionTask task, + IoDescriptor ioDescriptor) { + super(schedulerId, workerId, cycleId, windowId, pipelineId, pipelineName, ioDescriptor); + this.task = task; + this.edgeIds = new ArrayList<>(ioDescriptor.getInputDescriptor().getInputDescMap().keySet()); + } - public List getEdgeIds() { - return this.edgeIds; - } + public List getEdgeIds() { + return this.edgeIds; + } - @Override - public void execute(ITaskContext taskContext) { - super.execute(taskContext); - this.task.buildTaskName(this.pipelineName, this.cycleId, this.windowId); - InitFetchRequest initFetchRequest = this.buildInitFetchRequest( - this.ioDescriptor.getInputDescriptor(), this.task, null); - this.fetcherRunner.add(initFetchRequest); - this.fetcherRunner.add(new FetchRequest(this.task.getTaskId(), this.windowId, 1)); - } + @Override + public void execute(ITaskContext taskContext) { + super.execute(taskContext); + this.task.buildTaskName(this.pipelineName, this.cycleId, this.windowId); + InitFetchRequest initFetchRequest = + this.buildInitFetchRequest(this.ioDescriptor.getInputDescriptor(), this.task, null); + this.fetcherRunner.add(initFetchRequest); + this.fetcherRunner.add(new FetchRequest(this.task.getTaskId(), this.windowId, 1)); + } - @Override - protected InitFetchRequest buildInitFetchRequest(InputDescriptor inputDescriptor, ExecutionTask task, EventMetrics eventMetrics) { - Map inputDescMap = new HashMap<>(); - for (Map.Entry> entry : inputDescriptor.getInputDescMap().entrySet()) { - IInputDesc inputDesc = entry.getValue(); - if (inputDesc.getInputType() == InputType.META) { - inputDescMap.put(entry.getKey(), (ShardInputDesc) entry.getValue()); - } - } + @Override + protected InitFetchRequest buildInitFetchRequest( + InputDescriptor inputDescriptor, ExecutionTask task, EventMetrics eventMetrics) { + Map inputDescMap = new HashMap<>(); + for (Map.Entry> entry : inputDescriptor.getInputDescMap().entrySet()) { + IInputDesc inputDesc = entry.getValue(); + if (inputDesc.getInputType() == InputType.META) { + inputDescMap.put(entry.getKey(), (ShardInputDesc) entry.getValue()); + } + } - InitFetchRequest initFetchRequest = new InitFetchRequest( + InitFetchRequest initFetchRequest = + new InitFetchRequest( this.pipelineId, this.pipelineName, task.getVertexId(), @@ -90,33 +94,40 @@ protected InitFetchRequest buildInitFetchRequest(InputDescriptor inputDescriptor task.getTaskName(), inputDescMap); - for (Map.Entry entry : inputDescMap.entrySet()) { - Integer edgeId = entry.getKey(); - SliceId sliceId = new SliceId(this.pipelineId, edgeId, -1, task.getIndex()); - PrefetchMessageBuffer prefetchMessageBuffer = new PrefetchMessageBuffer<>(task.getTaskName(), sliceId); - initFetchRequest.addListener(prefetchMessageBuffer); - PrefetchCallbackHandler.getInstance() - .registerTaskEventCallback(sliceId, new PrefetchCallbackHandler.PrefetchCallback(prefetchMessageBuffer)); - } - - return initFetchRequest; + for (Map.Entry entry : inputDescMap.entrySet()) { + Integer edgeId = entry.getKey(); + SliceId sliceId = new SliceId(this.pipelineId, edgeId, -1, task.getIndex()); + PrefetchMessageBuffer prefetchMessageBuffer = + new PrefetchMessageBuffer<>(task.getTaskName(), sliceId); + initFetchRequest.addListener(prefetchMessageBuffer); + PrefetchCallbackHandler.getInstance() + .registerTaskEventCallback( + sliceId, new PrefetchCallbackHandler.PrefetchCallback(prefetchMessageBuffer)); } - @Override - public EventType getEventType() { - return EventType.PREFETCH; - } + return initFetchRequest; + } - @Override - public String toString() { - return "PrefetchEvent{" - + "taskId=" + this.task.getTaskId() - + ", workerId=" + workerId - + ", cycleId=" + cycleId - + ", windowId=" + windowId - + ", pipelineId=" + pipelineId - + ", pipelineName=" + pipelineName - + '}'; - } + @Override + public EventType getEventType() { + return EventType.PREFETCH; + } + @Override + public String toString() { + return "PrefetchEvent{" + + "taskId=" + + this.task.getTaskId() + + ", workerId=" + + workerId + + ", cycleId=" + + cycleId + + ", windowId=" + + windowId + + ", pipelineId=" + + pipelineId + + ", pipelineName=" + + pipelineName + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/RollbackCycleEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/RollbackCycleEvent.java index a1f3eb0ea..e26b292ee 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/RollbackCycleEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/RollbackCycleEvent.java @@ -25,39 +25,40 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Rollback worker to specified batch. - */ +/** Rollback worker to specified batch. */ public class RollbackCycleEvent extends AbstractExecutableCommand { - private static final Logger LOGGER = LoggerFactory.getLogger(RollbackCycleEvent.class); + private static final Logger LOGGER = LoggerFactory.getLogger(RollbackCycleEvent.class); + public RollbackCycleEvent(long schedulerId, int workerId, int cycleId, long windowId) { + super(schedulerId, workerId, cycleId, windowId); + } - public RollbackCycleEvent(long schedulerId, int workerId, int cycleId, long windowId) { - super(schedulerId, workerId, cycleId, windowId); + @Override + public void execute(ITaskContext taskContext) { + super.execute(taskContext); + if (worker instanceof TransactionTrait) { + LOGGER.info("worker do rollback {}", windowId); + ((TransactionTrait) worker).rollback(windowId); } + } - @Override - public void execute(ITaskContext taskContext) { - super.execute(taskContext); - if (worker instanceof TransactionTrait) { - LOGGER.info("worker do rollback {}", windowId); - ((TransactionTrait) worker).rollback(windowId); - } - } + @Override + public EventType getEventType() { + return EventType.ROLLBACK; + } - @Override - public EventType getEventType() { - return EventType.ROLLBACK; - } - - @Override - public String toString() { - return "RollbackCycleEvent{" - + "schedulerId=" + schedulerId - + ", workerId=" + workerId - + ", cycleId=" + cycleId - + ", windowId=" + windowId - + '}'; - } + @Override + public String toString() { + return "RollbackCycleEvent{" + + "schedulerId=" + + schedulerId + + ", workerId=" + + workerId + + ", cycleId=" + + cycleId + + ", windowId=" + + windowId + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/StashWorkerEvent.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/StashWorkerEvent.java index d5b18eecd..e927cfc79 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/StashWorkerEvent.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/protocol/StashWorkerEvent.java @@ -29,57 +29,60 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Stash worker. - */ +/** Stash worker. */ public class StashWorkerEvent extends AbstractCleanCommand { - private static final Logger LOGGER = LoggerFactory.getLogger(StashWorkerEvent.class); + private static final Logger LOGGER = LoggerFactory.getLogger(StashWorkerEvent.class); - private final int taskId; + private final int taskId; - public StashWorkerEvent(long schedulerId, int workerId, int cycleId, long windowId, int taskId) { - super(schedulerId, workerId, cycleId, windowId); - this.taskId = taskId; - } + public StashWorkerEvent(long schedulerId, int workerId, int cycleId, long windowId, int taskId) { + super(schedulerId, workerId, cycleId, windowId); + this.taskId = taskId; + } - @Override - public void execute(ITaskContext taskContext) { - super.execute(taskContext); - WorkerContext workerContext = (WorkerContext) this.context; - ExecutionTask executionTask = workerContext.getExecutionTask(); - workerContext.getEventMetrics().setFinishTime(System.currentTimeMillis()); - LOGGER.info("stash task {} {}/{} of {} {} : {}", - executionTask.getTaskId(), - executionTask.getIndex(), - executionTask.getParallelism(), - executionTask.getVertexId(), - executionTask.getProcessor().toString(), - workerContext.getEventMetrics()); + @Override + public void execute(ITaskContext taskContext) { + super.execute(taskContext); + WorkerContext workerContext = (WorkerContext) this.context; + ExecutionTask executionTask = workerContext.getExecutionTask(); + workerContext.getEventMetrics().setFinishTime(System.currentTimeMillis()); + LOGGER.info( + "stash task {} {}/{} of {} {} : {}", + executionTask.getTaskId(), + executionTask.getIndex(), + executionTask.getParallelism(), + executionTask.getVertexId(), + executionTask.getProcessor().toString(), + workerContext.getEventMetrics()); - // Stash worker context. - ((IAffinityWorker) worker).stash(); + // Stash worker context. + ((IAffinityWorker) worker).stash(); - this.fetcherRunner.add(new CloseFetchRequest(this.taskId)); - this.emitterRunner.add(new StashEmitterRequest(this.taskId, this.windowId)); - worker.close(); - LOGGER.info("stash worker context, taskId {}", ((WorkerContext) context).getTaskId()); + this.fetcherRunner.add(new CloseFetchRequest(this.taskId)); + this.emitterRunner.add(new StashEmitterRequest(this.taskId, this.windowId)); + worker.close(); + LOGGER.info("stash worker context, taskId {}", ((WorkerContext) context).getTaskId()); - this.sendDoneEvent(workerContext.getDriverId(), EventType.CLEAN_CYCLE, null, true); - } + this.sendDoneEvent(workerContext.getDriverId(), EventType.CLEAN_CYCLE, null, true); + } - @Override - public EventType getEventType() { - return EventType.STASH_WORKER; - } + @Override + public EventType getEventType() { + return EventType.STASH_WORKER; + } - @Override - public String toString() { - return "StashWorkerEvent{" - + "schedulerId=" + schedulerId - + ", workerId=" + workerId - + ", cycleId=" + cycleId - + ", windowId=" + windowId - + '}'; - } + @Override + public String toString() { + return "StashWorkerEvent{" + + "schedulerId=" + + schedulerId + + ", workerId=" + + workerId + + ", cycleId=" + + cycleId + + ", windowId=" + + windowId + + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/AbstractCycleScheduler.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/AbstractCycleScheduler.java index 54ef10cf0..712bfed08 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/AbstractCycleScheduler.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/AbstractCycleScheduler.java @@ -32,72 +32,76 @@ import org.slf4j.LoggerFactory; public abstract class AbstractCycleScheduler< - C extends IExecutionCycle, - PC extends IExecutionCycle, - PCC extends ICycleSchedulerContext, - R, E> implements ICycleScheduler { + C extends IExecutionCycle, + PC extends IExecutionCycle, + PCC extends ICycleSchedulerContext, + R, + E> + implements ICycleScheduler { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractCycleScheduler.class); + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractCycleScheduler.class); - protected C cycle; - protected ICycleSchedulerContext context; - protected SchedulerEventDispatcher dispatcher; - protected IStateMachine stateMachine; + protected C cycle; + protected ICycleSchedulerContext context; + protected SchedulerEventDispatcher dispatcher; + protected IStateMachine stateMachine; - @Override - public void init(ICycleSchedulerContext context) { - this.cycle = context.getCycle(); - this.context = context; - } + @Override + public void init(ICycleSchedulerContext context) { + this.cycle = context.getCycle(); + this.context = context; + } - public IExecutionResult execute() { + public IExecutionResult execute() { - String cycleLogTag = LoggerFormatter.getCycleTag(cycle.getPipelineName(), cycle.getCycleId()); - try { - while (!stateMachine.isTerminated()) { - while (true) { - IScheduleState oldState = stateMachine.getCurrentState(); - IScheduleState state = stateMachine.readyToTransition(); - if (state == null) { - finishFlyingEvent(); - break; - } - if (state.getScheduleStateType() == END) { - break; - } - LOGGER.info("{} state transition from {} to {}", - this.getClass(), oldState.getScheduleStateType(), state.getScheduleStateType()); - execute(state); - } - } - LOGGER.info("{} finished at {}", cycleLogTag, context.getFinishIterationId()); - R result = finish(); - context.finish(); - return ExecutionResult.buildSuccessResult(result); - } catch (Throwable e) { - LOGGER.error(String.format("%s occur exception", cycleLogTag), e); - return (ExecutionResult) ExecutionResult.buildFailedResult(e); + String cycleLogTag = LoggerFormatter.getCycleTag(cycle.getPipelineName(), cycle.getCycleId()); + try { + while (!stateMachine.isTerminated()) { + while (true) { + IScheduleState oldState = stateMachine.getCurrentState(); + IScheduleState state = stateMachine.readyToTransition(); + if (state == null) { + finishFlyingEvent(); + break; + } + if (state.getScheduleStateType() == END) { + break; + } + LOGGER.info( + "{} state transition from {} to {}", + this.getClass(), + oldState.getScheduleStateType(), + state.getScheduleStateType()); + execute(state); } + } + LOGGER.info("{} finished at {}", cycleLogTag, context.getFinishIterationId()); + R result = finish(); + context.finish(); + return ExecutionResult.buildSuccessResult(result); + } catch (Throwable e) { + LOGGER.error(String.format("%s occur exception", cycleLogTag), e); + return (ExecutionResult) ExecutionResult.buildFailedResult(e); } + } - @Override - public void close() { - } + @Override + public void close() {} - protected abstract void execute(IScheduleState state); + protected abstract void execute(IScheduleState state); - protected void finishFlyingEvent() { - // Handle response task until received all responses of certain iteration. - while (context.hasNextToFinish()) { - long finishedIterationId = context.getNextFinishIterationId(); - finish(finishedIterationId); - context.finish(finishedIterationId); - } + protected void finishFlyingEvent() { + // Handle response task until received all responses of certain iteration. + while (context.hasNextToFinish()) { + long finishedIterationId = context.getNextFinishIterationId(); + finish(finishedIterationId); + context.finish(finishedIterationId); } + } - protected abstract void finish(long iterationId); + protected abstract void finish(long iterationId); - protected abstract R finish(); + protected abstract R finish(); - protected abstract void registerEventListener(); + protected abstract void registerEventListener(); } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/CycleResponseEventPool.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/CycleResponseEventPool.java index c50694877..5363fe3ba 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/CycleResponseEventPool.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/CycleResponseEventPool.java @@ -21,41 +21,42 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class CycleResponseEventPool { - private LinkedBlockingQueue eventQueue; - private static final int WAITING_TIME_OUT = 10; + private LinkedBlockingQueue eventQueue; + private static final int WAITING_TIME_OUT = 10; - public CycleResponseEventPool() { - eventQueue = new LinkedBlockingQueue(); - } + public CycleResponseEventPool() { + eventQueue = new LinkedBlockingQueue(); + } - public void notifyEvent(T event) { - try { - eventQueue.put(event); - } catch (InterruptedException e) { - throw new GeaflowRuntimeException(e); - } + public void notifyEvent(T event) { + try { + eventQueue.put(event); + } catch (InterruptedException e) { + throw new GeaflowRuntimeException(e); } - - public T waitEvent() { - while (true) { - try { - // Wait until get available response. - T event = eventQueue.poll(WAITING_TIME_OUT, TimeUnit.MILLISECONDS); - if (event == null) { - continue; - } - return event; - } catch (InterruptedException e) { - throw new GeaflowRuntimeException(e); - } + } + + public T waitEvent() { + while (true) { + try { + // Wait until get available response. + T event = eventQueue.poll(WAITING_TIME_OUT, TimeUnit.MILLISECONDS); + if (event == null) { + continue; } + return event; + } catch (InterruptedException e) { + throw new GeaflowRuntimeException(e); + } } + } - public void clear() { - eventQueue.clear(); - } + public void clear() { + eventQueue.clear(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/CycleSchedulerFactory.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/CycleSchedulerFactory.java index 333a30a2d..b797dcfb0 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/CycleSchedulerFactory.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/CycleSchedulerFactory.java @@ -26,23 +26,26 @@ public class CycleSchedulerFactory { - private static final Logger LOGGER = LoggerFactory.getLogger(CycleSchedulerFactory.class); + private static final Logger LOGGER = LoggerFactory.getLogger(CycleSchedulerFactory.class); - public static ICycleScheduler create(IExecutionCycle cycle) { - ICycleScheduler scheduler; - switch (cycle.getType()) { - case GRAPH: - scheduler = new ExecutionGraphCycleScheduler<>(cycle.getSchedulerId()); - break; - case ITERATION: - case ITERATION_WITH_AGG: - case PIPELINE: - scheduler = new PipelineCycleScheduler<>(cycle.getSchedulerId()); - break; - default: - throw new GeaflowRuntimeException(String.format("not support cycle %s yet", cycle)); - } - LOGGER.info("create scheduler {} for cycle {}", scheduler.getClass().getSimpleName(), cycle.getCycleId()); - return scheduler; + public static ICycleScheduler create(IExecutionCycle cycle) { + ICycleScheduler scheduler; + switch (cycle.getType()) { + case GRAPH: + scheduler = new ExecutionGraphCycleScheduler<>(cycle.getSchedulerId()); + break; + case ITERATION: + case ITERATION_WITH_AGG: + case PIPELINE: + scheduler = new PipelineCycleScheduler<>(cycle.getSchedulerId()); + break; + default: + throw new GeaflowRuntimeException(String.format("not support cycle %s yet", cycle)); } + LOGGER.info( + "create scheduler {} for cycle {}", + scheduler.getClass().getSimpleName(), + cycle.getCycleId()); + return scheduler; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/ExecutableEventIterator.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/ExecutableEventIterator.java index 73202b00f..8988e1222 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/ExecutableEventIterator.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/ExecutableEventIterator.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.TreeMap; + import org.apache.geaflow.cluster.protocol.IEvent; import org.apache.geaflow.cluster.resourcemanager.WorkerInfo; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -32,102 +33,100 @@ public class ExecutableEventIterator { - private final Map> worker2events = new TreeMap<>(); - private int size = 0; - private Iterator>> workerIterator; - private boolean ready = false; - - public Map> getEvents() { - return this.worker2events; + private final Map> worker2events = new TreeMap<>(); + private int size = 0; + private Iterator>> workerIterator; + private boolean ready = false; + + public Map> getEvents() { + return this.worker2events; + } + + public void markReady() { + this.workerIterator = this.worker2events.entrySet().iterator(); + this.ready = true; + } + + public ExecutableEventIterator merge(ExecutableEventIterator other) { + for (List events : other.getEvents().values()) { + for (ExecutableEvent event : events) { + this.addEvent(event); + } } + return this; + } - public void markReady() { - this.workerIterator = this.worker2events.entrySet().iterator(); - this.ready = true; - } + public int size() { + return this.size; + } - public ExecutableEventIterator merge(ExecutableEventIterator other) { - for (List events : other.getEvents().values()) { - for (ExecutableEvent event : events) { - this.addEvent(event); - } - } - return this; - } + ////////////////////////////// + // Produce event. - public int size() { - return this.size; - } + /// /////////////////////////// - ////////////////////////////// - // Produce event. + public void addEvent(WorkerInfo worker, ExecutionTask task, IEvent event) { + this.addEvent(ExecutableEvent.build(worker, task, event)); + } - /// /////////////////////////// - - public void addEvent(WorkerInfo worker, ExecutionTask task, IEvent event) { - this.addEvent(ExecutableEvent.build(worker, task, event)); - } - - public void addEvent(ExecutableEvent event) { - if (this.ready) { - throw new GeaflowRuntimeException("event iterator already mark ready"); - } - List events = this.worker2events.computeIfAbsent(event.getWorker(), w -> new ArrayList<>()); - events.add(event); - this.size++; + public void addEvent(ExecutableEvent event) { + if (this.ready) { + throw new GeaflowRuntimeException("event iterator already mark ready"); } + List events = + this.worker2events.computeIfAbsent(event.getWorker(), w -> new ArrayList<>()); + events.add(event); + this.size++; + } + ////////////////////////////// + // Consume event. - ////////////////////////////// - // Consume event. - - /// /////////////////////////// + /// /////////////////////////// - public boolean hasNext() { - if (!this.ready) { - throw new GeaflowRuntimeException("event iterator not ready"); - } - return this.workerIterator.hasNext(); + public boolean hasNext() { + if (!this.ready) { + throw new GeaflowRuntimeException("event iterator not ready"); } + return this.workerIterator.hasNext(); + } - public Tuple> next() { - Map.Entry> next = this.workerIterator.next(); - return Tuple.of(next.getKey(), next.getValue()); - } - - public static class ExecutableEvent { + public Tuple> next() { + Map.Entry> next = this.workerIterator.next(); + return Tuple.of(next.getKey(), next.getValue()); + } - private final WorkerInfo worker; - private final ExecutionTask task; - private final IEvent event; + public static class ExecutableEvent { - private ExecutableEvent(WorkerInfo worker, ExecutionTask task, IEvent event) { - this.worker = worker; - this.task = task; - this.event = event; - } + private final WorkerInfo worker; + private final ExecutionTask task; + private final IEvent event; - public WorkerInfo getWorker() { - return this.worker; - } - - public ExecutionTask getTask() { - return this.task; - } + private ExecutableEvent(WorkerInfo worker, ExecutionTask task, IEvent event) { + this.worker = worker; + this.task = task; + this.event = event; + } - public IEvent getEvent() { - return this.event; - } + public WorkerInfo getWorker() { + return this.worker; + } - @Override - public String toString() { - return String.valueOf(this.event); - } + public ExecutionTask getTask() { + return this.task; + } - public static ExecutableEvent build(WorkerInfo worker, ExecutionTask task, IEvent event) { - return new ExecutableEvent(worker, task, event); - } + public IEvent getEvent() { + return this.event; + } + @Override + public String toString() { + return String.valueOf(this.event); } + public static ExecutableEvent build(WorkerInfo worker, ExecutionTask task, IEvent event) { + return new ExecutableEvent(worker, task, event); + } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/ExecutionCycleTaskAssigner.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/ExecutionCycleTaskAssigner.java index 471d6a144..610f7ebd9 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/ExecutionCycleTaskAssigner.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/ExecutionCycleTaskAssigner.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; + import org.apache.geaflow.core.graph.ExecutionGraph; import org.apache.geaflow.core.graph.ExecutionTask; import org.apache.geaflow.core.graph.ExecutionVertex; @@ -33,30 +34,36 @@ public class ExecutionCycleTaskAssigner { - private static final Logger LOGGER = LoggerFactory.getLogger(ExecutionCycleTaskAssigner.class); - - private static AtomicInteger taskId = new AtomicInteger(0); + private static final Logger LOGGER = LoggerFactory.getLogger(ExecutionCycleTaskAssigner.class); - public static Map> assign(ExecutionGraph executionGraph) { + private static AtomicInteger taskId = new AtomicInteger(0); - Map> vertex2Tasks = new HashMap<>(); + public static Map> assign(ExecutionGraph executionGraph) { - for (ExecutionVertexGroup vertexGroup : executionGraph.getVertexGroupMap().values()) { - for (ExecutionVertex vertex : vertexGroup.getVertexMap().values()) { - List tasks = new ArrayList<>(); - List taskIds = new ArrayList<>(); - for (int i = 0; i < vertex.getParallelism(); i++) { - ExecutionTask task = new ExecutionTask(taskId.getAndIncrement(), - i, vertex.getParallelism(), vertex.getMaxParallelism(), vertex.getNumPartitions(), vertex.getVertexId()); - task.setIterative(vertexGroup.getCycleGroupMeta().isIterative()); - tasks.add(task); - taskIds.add(task.getTaskId()); - } - LOGGER.info("assign task vertexId:{}, taskIds:{}", vertex.getVertexId(), taskIds); + Map> vertex2Tasks = new HashMap<>(); - vertex2Tasks.put(vertex.getVertexId(), tasks); - } + for (ExecutionVertexGroup vertexGroup : executionGraph.getVertexGroupMap().values()) { + for (ExecutionVertex vertex : vertexGroup.getVertexMap().values()) { + List tasks = new ArrayList<>(); + List taskIds = new ArrayList<>(); + for (int i = 0; i < vertex.getParallelism(); i++) { + ExecutionTask task = + new ExecutionTask( + taskId.getAndIncrement(), + i, + vertex.getParallelism(), + vertex.getMaxParallelism(), + vertex.getNumPartitions(), + vertex.getVertexId()); + task.setIterative(vertexGroup.getCycleGroupMeta().isIterative()); + tasks.add(task); + taskIds.add(task.getTaskId()); } - return vertex2Tasks; + LOGGER.info("assign task vertexId:{}, taskIds:{}", vertex.getVertexId(), taskIds); + + vertex2Tasks.put(vertex.getVertexId(), tasks); + } } + return vertex2Tasks; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/ExecutionGraphCycleScheduler.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/ExecutionGraphCycleScheduler.java index 99bd7ec7e..545ffb873 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/ExecutionGraphCycleScheduler.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/ExecutionGraphCycleScheduler.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Future; + import org.apache.geaflow.cluster.common.IEventListener; import org.apache.geaflow.cluster.protocol.EventType; import org.apache.geaflow.cluster.protocol.IEvent; @@ -57,211 +58,259 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class ExecutionGraphCycleScheduler, R, E> +public class ExecutionGraphCycleScheduler< + PC extends IExecutionCycle, PCC extends ICycleSchedulerContext, R, E> extends AbstractCycleScheduler implements IEventListener { - private static final Logger LOGGER = LoggerFactory.getLogger(ExecutionGraphCycleScheduler.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ExecutionGraphCycleScheduler.class); - private CycleResultManager resultManager; - private long pipelineId; - private long schedulerId; - private String pipelineName; - private PipelineMetrics pipelineMetrics; - private List results; - private String cycleLogTag; + private CycleResultManager resultManager; + private long pipelineId; + private long schedulerId; + private String pipelineName; + private PipelineMetrics pipelineMetrics; + private List results; + private String cycleLogTag; - public ExecutionGraphCycleScheduler() { - } + public ExecutionGraphCycleScheduler() {} - public ExecutionGraphCycleScheduler(long schedulerId) { - this.schedulerId = schedulerId; - } + public ExecutionGraphCycleScheduler(long schedulerId) { + this.schedulerId = schedulerId; + } - @Override - public void init(ICycleSchedulerContext context) { - super.init(context); - this.resultManager = context.getResultManager(); - this.results = new ArrayList<>(); - this.context.getSchedulerWorkerManager().init(this.cycle); - this.context.getSchedulerWorkerManager().assign(this.cycle); - this.pipelineId = this.cycle.getPipelineId(); - this.cycleLogTag = LoggerFormatter.getCycleTag(context.getCycle().getPipelineName(), cycle.getCycleId()); - this.dispatcher = new SchedulerEventDispatcher(cycleLogTag); - registerEventListener(); - this.stateMachine = new GraphStateMachine(); - this.stateMachine.init(context); - } + @Override + public void init(ICycleSchedulerContext context) { + super.init(context); + this.resultManager = context.getResultManager(); + this.results = new ArrayList<>(); + this.context.getSchedulerWorkerManager().init(this.cycle); + this.context.getSchedulerWorkerManager().assign(this.cycle); + this.pipelineId = this.cycle.getPipelineId(); + this.cycleLogTag = + LoggerFormatter.getCycleTag(context.getCycle().getPipelineName(), cycle.getCycleId()); + this.dispatcher = new SchedulerEventDispatcher(cycleLogTag); + registerEventListener(); + this.stateMachine = new GraphStateMachine(); + this.stateMachine.init(context); + } - @Override - public void execute(IScheduleState event) { - ScheduleState state = (ScheduleState) event; - switch (state.getScheduleStateType()) { - case EXECUTE_COMPUTE: - execute(context.getNextIterationId()); - break; - default: - throw new GeaflowRuntimeException(String.format("not support event {}", event)); - } + @Override + public void execute(IScheduleState event) { + ScheduleState state = (ScheduleState) event; + switch (state.getScheduleStateType()) { + case EXECUTE_COMPUTE: + execute(context.getNextIterationId()); + break; + default: + throw new GeaflowRuntimeException(String.format("not support event {}", event)); } + } - public void execute(long iterationId) { - IScheduleStrategy scheduleStrategy = new TopologicalOrderScheduleStrategy(context.getConfig()); - scheduleStrategy.init(cycle); - this.pipelineName = getPipelineName(iterationId); - this.pipelineMetrics = new PipelineMetrics(pipelineName); - this.pipelineMetrics.setStartTime(System.currentTimeMillis()); - StatsCollectorFactory.getInstance().getPipelineStatsCollector().reportPipelineMetrics(pipelineMetrics); - LOGGER.info("{} execute iterationId {}, executionId {}", pipelineName, iterationId, pipelineId); + public void execute(long iterationId) { + IScheduleStrategy scheduleStrategy = + new TopologicalOrderScheduleStrategy(context.getConfig()); + scheduleStrategy.init(cycle); + this.pipelineName = getPipelineName(iterationId); + this.pipelineMetrics = new PipelineMetrics(pipelineName); + this.pipelineMetrics.setStartTime(System.currentTimeMillis()); + StatsCollectorFactory.getInstance() + .getPipelineStatsCollector() + .reportPipelineMetrics(pipelineMetrics); + LOGGER.info("{} execute iterationId {}, executionId {}", pipelineName, iterationId, pipelineId); - while (scheduleStrategy.hasNext()) { + while (scheduleStrategy.hasNext()) { - // Get cycle that ready to schedule. - IExecutionCycle nextCycle = scheduleStrategy.next(); - ExecutionNodeCycle cycle = (ExecutionNodeCycle) nextCycle; - cycle.setPipelineName(pipelineName); - LOGGER.info("{} start schedule {}, total task num {}, head task num {}, tail " - + "task num {} type:{}", - pipelineName, LoggerFormatter.getCycleName(cycle.getCycleId(), iterationId), cycle.getTasks().size(), - cycle.getCycleHeads().size(), cycle.getCycleTails().size(), cycle.getType()); + // Get cycle that ready to schedule. + IExecutionCycle nextCycle = scheduleStrategy.next(); + ExecutionNodeCycle cycle = (ExecutionNodeCycle) nextCycle; + cycle.setPipelineName(pipelineName); + LOGGER.info( + "{} start schedule {}, total task num {}, head task num {}, tail " + + "task num {} type:{}", + pipelineName, + LoggerFormatter.getCycleName(cycle.getCycleId(), iterationId), + cycle.getTasks().size(), + cycle.getCycleHeads().size(), + cycle.getCycleTails().size(), + cycle.getType()); - // Schedule cycle. - PipelineCycleScheduler cycleScheduler = - (PipelineCycleScheduler) CycleSchedulerFactory.create(cycle); - ICycleSchedulerContext cycleContext = CycleSchedulerContextFactory.create(cycle, context); - cycleScheduler.init(cycleContext); + // Schedule cycle. + PipelineCycleScheduler cycleScheduler = + (PipelineCycleScheduler) CycleSchedulerFactory.create(cycle); + ICycleSchedulerContext cycleContext = CycleSchedulerContextFactory.create(cycle, context); + cycleScheduler.init(cycleContext); - EventListenerKey listenerKey = EventListenerKey.of(cycle.getCycleId()); - if (cycleScheduler instanceof IEventListener) { - dispatcher.registerListener(listenerKey, cycleScheduler); - } + EventListenerKey listenerKey = EventListenerKey.of(cycle.getCycleId()); + if (cycleScheduler instanceof IEventListener) { + dispatcher.registerListener(listenerKey, cycleScheduler); + } - try { - final long start = System.currentTimeMillis(); - IExecutionResult result = cycleScheduler.execute(); - if (!result.isSuccess()) { - throw new GeaflowRuntimeException(String.format("%s schedule execute %s failed ", - pipelineName, cycle.getCycleId())); - } - if (result.getResult() != null) { - results.add(result.getResult()); - } - scheduleStrategy.finish(cycle); - LOGGER.info("{} schedule {} finished, cost {}ms", - pipelineName, LoggerFormatter.getCycleName(cycle.getCycleId(), iterationId), - System.currentTimeMillis() - start); - if (cycleScheduler instanceof IEventListener) { - dispatcher.removeListener(listenerKey); - } - } catch (Throwable e) { - throw new GeaflowRuntimeException(String.format("%s schedule iterationId %s failed ", - pipelineName, iterationId), e); - } finally { - cycleScheduler.close(); - } + try { + final long start = System.currentTimeMillis(); + IExecutionResult result = cycleScheduler.execute(); + if (!result.isSuccess()) { + throw new GeaflowRuntimeException( + String.format("%s schedule execute %s failed ", pipelineName, cycle.getCycleId())); } + if (result.getResult() != null) { + results.add(result.getResult()); + } + scheduleStrategy.finish(cycle); + LOGGER.info( + "{} schedule {} finished, cost {}ms", + pipelineName, + LoggerFormatter.getCycleName(cycle.getCycleId(), iterationId), + System.currentTimeMillis() - start); + if (cycleScheduler instanceof IEventListener) { + dispatcher.removeListener(listenerKey); + } + } catch (Throwable e) { + throw new GeaflowRuntimeException( + String.format("%s schedule iterationId %s failed ", pipelineName, iterationId), e); + } finally { + cycleScheduler.close(); + } } + } - @Override - public void finish(long iterationId) { - String cycleLogTag = LoggerFormatter.getCycleTag(pipelineName, cycle.getCycleId(), iterationId); - // Clean shuffle data for all used workers. - context.getSchedulerWorkerManager().clean(usedWorkers -> - cleanEnv(usedWorkers, cycleLogTag, iterationId, false), cycle); - // Clear last iteration shard meta. - resultManager.clear(); - DataExchanger.clear(); - this.pipelineMetrics.setDuration(System.currentTimeMillis() - pipelineMetrics.getStartTime()); - StatsCollectorFactory.getInstance().getPipelineStatsCollector().reportPipelineMetrics(pipelineMetrics); - LOGGER.info("{} finished {}", cycleLogTag, pipelineMetrics); - } + @Override + public void finish(long iterationId) { + String cycleLogTag = LoggerFormatter.getCycleTag(pipelineName, cycle.getCycleId(), iterationId); + // Clean shuffle data for all used workers. + context + .getSchedulerWorkerManager() + .clean(usedWorkers -> cleanEnv(usedWorkers, cycleLogTag, iterationId, false), cycle); + // Clear last iteration shard meta. + resultManager.clear(); + DataExchanger.clear(); + this.pipelineMetrics.setDuration(System.currentTimeMillis() - pipelineMetrics.getStartTime()); + StatsCollectorFactory.getInstance() + .getPipelineStatsCollector() + .reportPipelineMetrics(pipelineMetrics); + LOGGER.info("{} finished {}", cycleLogTag, pipelineMetrics); + } - @Override - protected R finish() { - // Clean shuffle data for all used workers. - String cycleLogTag = LoggerFormatter.getCycleTag(pipelineName, cycle.getCycleId()); - context.getSchedulerWorkerManager().clean(usedWorkers -> cleanEnv(usedWorkers, cycleLogTag, - cycle.getIterationCount(), true), cycle); - return (R) results; - } + @Override + protected R finish() { + // Clean shuffle data for all used workers. + String cycleLogTag = LoggerFormatter.getCycleTag(pipelineName, cycle.getCycleId()); + context + .getSchedulerWorkerManager() + .clean( + usedWorkers -> cleanEnv(usedWorkers, cycleLogTag, cycle.getIterationCount(), true), + cycle); + return (R) results; + } - private void cleanEnv(List usedWorkers, - String cycleLogTag, - long iterationId, - boolean needCleanWorkerContext) { - CountDownLatch latch = new CountDownLatch(1); - LOGGER.info("{} start wait {} clean env response, need clean worker context {}", - cycleLogTag, usedWorkers.size(), needCleanWorkerContext); - // Register listener to handle response. - EventListenerKey listenerKey = EventListenerKey.of(cycle.getCycleId(), EventType.CLEAN_ENV); - ComputeFinishEventListener listener = - new ComputeFinishEventListener(usedWorkers.size(), events -> { - LOGGER.info("{} clean env response {} finished all {} events", cycleLogTag, listenerKey, events.size()); - latch.countDown(); + private void cleanEnv( + List usedWorkers, + String cycleLogTag, + long iterationId, + boolean needCleanWorkerContext) { + CountDownLatch latch = new CountDownLatch(1); + LOGGER.info( + "{} start wait {} clean env response, need clean worker context {}", + cycleLogTag, + usedWorkers.size(), + needCleanWorkerContext); + // Register listener to handle response. + EventListenerKey listenerKey = EventListenerKey.of(cycle.getCycleId(), EventType.CLEAN_ENV); + ComputeFinishEventListener listener = + new ComputeFinishEventListener( + usedWorkers.size(), + events -> { + LOGGER.info( + "{} clean env response {} finished all {} events", + cycleLogTag, + listenerKey, + events.size()); + latch.countDown(); }); - dispatcher.registerListener(listenerKey, listener); + dispatcher.registerListener(listenerKey, listener); - List> submitFutures = new ArrayList<>(usedWorkers.size()); - for (WorkerInfo worker : usedWorkers) { - IEvent cleanEvent; - if (needCleanWorkerContext) { - cleanEvent = new CleanEnvEvent(schedulerId, worker.getWorkerIndex(), - cycle.getCycleId(), iterationId, pipelineId, cycle.getDriverId()); - } else { - cleanEvent = new CleanStashEnvEvent(schedulerId, worker.getWorkerIndex(), - cycle.getCycleId(), iterationId, pipelineId, cycle.getDriverId()); - } - Future future = RpcClient.getInstance() - .processContainer(worker.getContainerName(), cleanEvent); - submitFutures.add(future); - } - FutureUtil.wait(submitFutures); - - try { - latch.await(); - } catch (InterruptedException e) { - throw new GeaflowRuntimeException("exception when wait all clean event finish", e); - } finally { - ShuffleManager.getInstance().release(pipelineId); - } + List> submitFutures = new ArrayList<>(usedWorkers.size()); + for (WorkerInfo worker : usedWorkers) { + IEvent cleanEvent; + if (needCleanWorkerContext) { + cleanEvent = + new CleanEnvEvent( + schedulerId, + worker.getWorkerIndex(), + cycle.getCycleId(), + iterationId, + pipelineId, + cycle.getDriverId()); + } else { + cleanEvent = + new CleanStashEnvEvent( + schedulerId, + worker.getWorkerIndex(), + cycle.getCycleId(), + iterationId, + pipelineId, + cycle.getDriverId()); + } + Future future = + RpcClient.getInstance().processContainer(worker.getContainerName(), cleanEvent); + submitFutures.add(future); } + FutureUtil.wait(submitFutures); - private String getPipelineName(long iterationId) { - return String.format("%s-%s", cycle.getPipelineName(), iterationId); + try { + latch.await(); + } catch (InterruptedException e) { + throw new GeaflowRuntimeException("exception when wait all clean event finish", e); + } finally { + ShuffleManager.getInstance().release(pipelineId); } + } - @Override - public void close() { - context.getSchedulerWorkerManager().release(this.cycle); - context.close(this.cycle); - LOGGER.info("{} closed", cycle.getPipelineName()); - } + private String getPipelineName(long iterationId) { + return String.format("%s-%s", cycle.getPipelineName(), iterationId); + } - public long getSchedulerId() { - return schedulerId; - } + @Override + public void close() { + context.getSchedulerWorkerManager().release(this.cycle); + context.close(this.cycle); + LOGGER.info("{} closed", cycle.getPipelineName()); + } - @Override - public void handleEvent(IEvent event) { - dispatcher.dispatch(event); - } + public long getSchedulerId() { + return schedulerId; + } - protected void registerEventListener() { - EventListenerKey listenerKey = EventListenerKey.of(context.getCycle().getCycleId(), EventType.LAUNCH_SOURCE); - IEventListener listener = - new SourceFinishResponseEventListener(getSourceCycleNum(cycle), - events -> { - long finishWindowId = ((DoneEvent) events.iterator().next()).getWindowId(); - LOGGER.info("{} all source finished at {}", cycleLogTag, finishWindowId); - ((AbstractCycleSchedulerContext) context).setTerminateIterationId( - ((DoneEvent) events.iterator().next()).getWindowId()); - }); - // Register listener for end of source event. - this.dispatcher.registerListener(listenerKey, listener); - } + @Override + public void handleEvent(IEvent event) { + dispatcher.dispatch(event); + } - private int getSourceCycleNum(IExecutionCycle cycle) { - return (int) ((ExecutionGraphCycle) cycle).getCycleMap().values().stream() - .filter(e -> ((ExecutionNodeCycle) e).getVertexGroup().getParentVertexGroupIds().isEmpty()).count(); - } + protected void registerEventListener() { + EventListenerKey listenerKey = + EventListenerKey.of(context.getCycle().getCycleId(), EventType.LAUNCH_SOURCE); + IEventListener listener = + new SourceFinishResponseEventListener( + getSourceCycleNum(cycle), + events -> { + long finishWindowId = ((DoneEvent) events.iterator().next()).getWindowId(); + LOGGER.info("{} all source finished at {}", cycleLogTag, finishWindowId); + ((AbstractCycleSchedulerContext) context) + .setTerminateIterationId(((DoneEvent) events.iterator().next()).getWindowId()); + }); + // Register listener for end of source event. + this.dispatcher.registerListener(listenerKey, listener); + } + private int getSourceCycleNum(IExecutionCycle cycle) { + return (int) + ((ExecutionGraphCycle) cycle) + .getCycleMap().values().stream() + .filter( + e -> + ((ExecutionNodeCycle) e) + .getVertexGroup() + .getParentVertexGroupIds() + .isEmpty()) + .count(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/ICycleScheduler.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/ICycleScheduler.java index 1e4672c47..3401333a2 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/ICycleScheduler.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/ICycleScheduler.java @@ -27,22 +27,18 @@ public interface ICycleScheduler< C extends IExecutionCycle, PC extends IExecutionCycle, PCC extends ICycleSchedulerContext, - R, E> { + R, + E> { - /** - * Initialize cycle scheduler by input context. - * May include assign resource and initialize worker, and set up cycle schedule env. - */ - void init(ICycleSchedulerContext context); + /** + * Initialize cycle scheduler by input context. May include assign resource and initialize worker, + * and set up cycle schedule env. + */ + void init(ICycleSchedulerContext context); - /** - * Execution all cycles. - */ - IExecutionResult execute(); - - /** - * Close the initialized resources and workers. - */ - void close(); + /** Execution all cycles. */ + IExecutionResult execute(); + /** Close the initialized resources and workers. */ + void close(); } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/PipelineCycleScheduler.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/PipelineCycleScheduler.java index 902c5da2b..bdbb8bc3b 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/PipelineCycleScheduler.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/PipelineCycleScheduler.java @@ -28,6 +28,7 @@ import java.util.Objects; import java.util.concurrent.Future; import java.util.stream.Collectors; + import org.apache.geaflow.cluster.common.IEventListener; import org.apache.geaflow.cluster.protocol.EventType; import org.apache.geaflow.cluster.protocol.IEvent; @@ -62,412 +63,444 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * This class is pipeline cycle scheduler impl. - */ +/** This class is pipeline cycle scheduler impl. */ public class PipelineCycleScheduler

, E> - extends AbstractCycleScheduler, E> implements IEventListener { - - private static final Logger LOGGER = LoggerFactory.getLogger(PipelineCycleScheduler.class); - - private ExecutionNodeCycle nodeCycle; - private CycleResultManager resultManager; - - private CycleResponseEventPool responseEventPool; - private HashMap> iterationIdToFinishedTasks; - private Map vertexIdToMetrics; - private Map cycleTasks; - private String pipelineName; - private long pipelineId; - private int cycleId; - private long schedulerId; - private long scheduleStartTime; - private boolean isIteration; - - private SchedulerEventBuilder eventBuilder; - private SchedulerGraphAggregateProcessor aggregator; - private String cycleLogTag; - - public PipelineCycleScheduler() { + extends AbstractCycleScheduler, E> + implements IEventListener { + + private static final Logger LOGGER = LoggerFactory.getLogger(PipelineCycleScheduler.class); + + private ExecutionNodeCycle nodeCycle; + private CycleResultManager resultManager; + + private CycleResponseEventPool responseEventPool; + private HashMap> iterationIdToFinishedTasks; + private Map vertexIdToMetrics; + private Map cycleTasks; + private String pipelineName; + private long pipelineId; + private int cycleId; + private long schedulerId; + private long scheduleStartTime; + private boolean isIteration; + + private SchedulerEventBuilder eventBuilder; + private SchedulerGraphAggregateProcessor aggregator; + private String cycleLogTag; + + public PipelineCycleScheduler() {} + + public PipelineCycleScheduler(long schedulerId) { + this.schedulerId = schedulerId; + } + + @Override + public void init(ICycleSchedulerContext context) { + super.init(context); + this.responseEventPool = new CycleResponseEventPool<>(); + this.iterationIdToFinishedTasks = new HashMap<>(); + this.nodeCycle = context.getCycle(); + this.cycleTasks = + nodeCycle.getTasks().stream().collect(Collectors.toMap(ExecutionTask::getTaskId, t -> t)); + this.isIteration = nodeCycle.getVertexGroup().getCycleGroupMeta().isIterative(); + this.pipelineName = nodeCycle.getPipelineName(); + this.pipelineId = nodeCycle.getPipelineId(); + this.cycleId = nodeCycle.getCycleId(); + this.resultManager = context.getResultManager(); + this.cycleLogTag = LoggerFormatter.getCycleTag(this.pipelineName, this.cycleId); + this.dispatcher = new SchedulerEventDispatcher(cycleLogTag); + this.initMetrics(); + + this.stateMachine = new PipelineStateMachine(); + this.stateMachine.init(context); + + this.eventBuilder = new SchedulerEventBuilder(context, this.resultManager, this.schedulerId); + if (nodeCycle.getType() == ExecutionCycleType.ITERATION_WITH_AGG) { + this.aggregator = + new SchedulerGraphAggregateProcessor( + nodeCycle, (AbstractCycleSchedulerContext) context, resultManager); } - public PipelineCycleScheduler(long schedulerId) { - this.schedulerId = schedulerId; + registerEventListener(); + } + + private void initMetrics() { + this.scheduleStartTime = System.currentTimeMillis(); + this.vertexIdToMetrics = new HashMap<>(); + Map vertexMap = this.nodeCycle.getVertexGroup().getVertexMap(); + for (Map.Entry entry : vertexMap.entrySet()) { + Integer vertexId = entry.getKey(); + ExecutionVertex vertex = entry.getValue(); + this.vertexIdToMetrics.put(vertexId, new EventMetrics[vertex.getParallelism()]); } - - @Override - public void init(ICycleSchedulerContext context) { - super.init(context); - this.responseEventPool = new CycleResponseEventPool<>(); - this.iterationIdToFinishedTasks = new HashMap<>(); - this.nodeCycle = context.getCycle(); - this.cycleTasks = nodeCycle.getTasks().stream().collect(Collectors.toMap(ExecutionTask::getTaskId, t -> t)); - this.isIteration = nodeCycle.getVertexGroup().getCycleGroupMeta().isIterative(); - this.pipelineName = nodeCycle.getPipelineName(); - this.pipelineId = nodeCycle.getPipelineId(); - this.cycleId = nodeCycle.getCycleId(); - this.resultManager = context.getResultManager(); - this.cycleLogTag = LoggerFormatter.getCycleTag(this.pipelineName, this.cycleId); - this.dispatcher = new SchedulerEventDispatcher(cycleLogTag); - this.initMetrics(); - - this.stateMachine = new PipelineStateMachine(); - this.stateMachine.init(context); - - this.eventBuilder = new SchedulerEventBuilder(context, this.resultManager, this.schedulerId); - if (nodeCycle.getType() == ExecutionCycleType.ITERATION_WITH_AGG) { - this.aggregator = new SchedulerGraphAggregateProcessor(nodeCycle, - (AbstractCycleSchedulerContext) context, resultManager); - } - - registerEventListener(); + } + + @Override + public void close() { + super.close(); + iterationIdToFinishedTasks.clear(); + responseEventPool.clear(); + if (context.getParentContext() == null) { + ShuffleManager.getInstance().release(pipelineId); + ShuffleManager.getInstance().close(); } - - private void initMetrics() { - this.scheduleStartTime = System.currentTimeMillis(); - this.vertexIdToMetrics = new HashMap<>(); - Map vertexMap = this.nodeCycle.getVertexGroup().getVertexMap(); - for (Map.Entry entry : vertexMap.entrySet()) { - Integer vertexId = entry.getKey(); - ExecutionVertex vertex = entry.getValue(); - this.vertexIdToMetrics.put(vertexId, new EventMetrics[vertex.getParallelism()]); - } + if (context.getParentContext() == null) { + context.close(cycle); } - - @Override - public void close() { - super.close(); - iterationIdToFinishedTasks.clear(); - responseEventPool.clear(); - if (context.getParentContext() == null) { - ShuffleManager.getInstance().release(pipelineId); - ShuffleManager.getInstance().close(); - } - if (context.getParentContext() == null) { - context.close(cycle); - } - LOGGER.info("{} closed", cycleLogTag); + LOGGER.info("{} closed", cycleLogTag); + } + + public long getSchedulerId() { + return schedulerId; + } + + public void setSchedulerId(long schedulerId) { + this.schedulerId = schedulerId; + } + + @Override + protected void execute(IScheduleState state) { + ExecutableEventIterator iterator = new ExecutableEventIterator(); + if (state.getScheduleStateType() == ScheduleStateType.COMPOSE) { + for (IScheduleState s : ((ComposeState) state).getStates()) { + getNextIterationId(context, s); + ExecutableEventIterator tmp = + this.eventBuilder.build(s.getScheduleStateType(), this.context.getCurrentIterationId()); + iterator.merge(tmp); + } + } else { + getNextIterationId(context, state); + ExecutableEventIterator tmp = + this.eventBuilder.build( + state.getScheduleStateType(), this.context.getCurrentIterationId()); + iterator.merge(tmp); } - - public long getSchedulerId() { - return schedulerId; + iterator.markReady(); + + String iterationLogTag = getCycleIterationTag(this.context.getCurrentIterationId()); + LOGGER.info("{} execute", iterationLogTag); + this.iterationIdToFinishedTasks.put(this.context.getCurrentIterationId(), new ArrayList<>()); + this.executeEvents(iterator, this.context.getCurrentIterationId()); + } + + private void executeEvents(ExecutableEventIterator eventIterator, long iterationId) { + List> submitFutures = new ArrayList<>(eventIterator.size()); + while (eventIterator.hasNext()) { + Tuple> tuple = eventIterator.next(); + WorkerInfo worker = tuple.f0; + List executableEvents = tuple.f1; + for (ExecutableEventIterator.ExecutableEvent executableEvent : executableEvents) { + ExecutionTask task = executableEvent.getTask(); + IEvent event = executableEvent.getEvent(); + String taskTag = + this.isIteration + ? LoggerFormatter.getTaskTag( + this.pipelineName, + this.cycleId, + iterationId, + task.getTaskId(), + task.getVertexId(), + task.getIndex(), + task.getParallelism()) + : LoggerFormatter.getTaskTag( + this.pipelineName, + this.cycleId, + task.getTaskId(), + task.getVertexId(), + task.getIndex(), + task.getParallelism()); + LOGGER.info( + "{} submit event {} on host {} {} process {}", + taskTag, + event, + worker.getHost(), + worker.getWorkerIndex(), + worker.getProcessId()); + } + IEvent finalEvent; + if (executableEvents.size() == 1) { + finalEvent = executableEvents.get(0).getEvent(); + } else { + List events = + executableEvents.stream() + .map(ExecutableEventIterator.ExecutableEvent::getEvent) + .collect(Collectors.toList()); + finalEvent = new ComposeEvent(worker.getWorkerIndex(), flatEvents(events)); + } + Future future = + (Future) + RpcClient.getInstance().processContainer(worker.getContainerName(), finalEvent); + submitFutures.add(future); } - - public void setSchedulerId(long schedulerId) { - this.schedulerId = schedulerId; + FutureUtil.wait(submitFutures); + } + + private static List flatEvents(List list) { + List events = new ArrayList<>(); + for (IEvent event : list) { + if (event.getEventType() == EventType.COMPOSE) { + events.addAll(flatEvents(((ComposeEvent) event).getEventList())); + } else { + events.add(event); + } } - - @Override - protected void execute(IScheduleState state) { - ExecutableEventIterator iterator = new ExecutableEventIterator(); - if (state.getScheduleStateType() == ScheduleStateType.COMPOSE) { - for (IScheduleState s : ((ComposeState) state).getStates()) { - getNextIterationId(context, s); - ExecutableEventIterator tmp = this.eventBuilder.build(s.getScheduleStateType(), this.context.getCurrentIterationId()); - iterator.merge(tmp); - } - } else { - getNextIterationId(context, state); - ExecutableEventIterator tmp = this.eventBuilder.build(state.getScheduleStateType(), this.context.getCurrentIterationId()); - iterator.merge(tmp); - } - iterator.markReady(); - - String iterationLogTag = getCycleIterationTag(this.context.getCurrentIterationId()); - LOGGER.info("{} execute", iterationLogTag); - this.iterationIdToFinishedTasks.put(this.context.getCurrentIterationId(), new ArrayList<>()); - this.executeEvents(iterator, this.context.getCurrentIterationId()); + return events; + } + + protected void finish(long iterationId) { + String iterationLogTag = getCycleIterationTag(iterationId); + if (iterationIdToFinishedTasks.get(iterationId) == null) { + // Unexpected to reach here. + throw new GeaflowRuntimeException( + String.format("fatal: %s result is unregistered", iterationLogTag)); } - private void executeEvents(ExecutableEventIterator eventIterator, long iterationId) { - List> submitFutures = new ArrayList<>(eventIterator.size()); - while (eventIterator.hasNext()) { - Tuple> tuple = eventIterator.next(); - WorkerInfo worker = tuple.f0; - List executableEvents = tuple.f1; - for (ExecutableEventIterator.ExecutableEvent executableEvent : executableEvents) { - ExecutionTask task = executableEvent.getTask(); - IEvent event = executableEvent.getEvent(); - String taskTag = this.isIteration - ? LoggerFormatter.getTaskTag(this.pipelineName, this.cycleId, iterationId, - task.getTaskId(), task.getVertexId(), task.getIndex(), task.getParallelism()) - : LoggerFormatter.getTaskTag(this.pipelineName, this.cycleId, task.getTaskId(), - task.getVertexId(), task.getIndex(), task.getParallelism()); - LOGGER.info("{} submit event {} on host {} {} process {}", - taskTag, - event, - worker.getHost(), - worker.getWorkerIndex(), - worker.getProcessId()); - } - IEvent finalEvent; - if (executableEvents.size() == 1) { - finalEvent = executableEvents.get(0).getEvent(); - } else { - List events = executableEvents.stream() - .map(ExecutableEventIterator.ExecutableEvent::getEvent) - .collect(Collectors.toList()); - finalEvent = new ComposeEvent(worker.getWorkerIndex(), flatEvents(events)); - } - Future future = (Future) RpcClient.getInstance() - .processContainer(worker.getContainerName(), finalEvent); - submitFutures.add(future); - } - FutureUtil.wait(submitFutures); + int expectedResponseSize = nodeCycle.getCycleTails().size(); + while (iterationIdToFinishedTasks.get(iterationId).size() != expectedResponseSize) { + IEvent response = responseEventPool.waitEvent(); + DoneEvent event = (DoneEvent) response; + // Get iterationId from task. + long currentTaskIterationId = event.getWindowId(); + if (!iterationIdToFinishedTasks.containsKey(currentTaskIterationId)) { + throw new GeaflowRuntimeException( + String.format( + "%s finish error, current response iterationId %s, current waiting iterationIds %s", + cycleLogTag, currentTaskIterationId, iterationIdToFinishedTasks.keySet())); + } + iterationIdToFinishedTasks.get(currentTaskIterationId).add(response); } - private static List flatEvents(List list) { - List events = new ArrayList<>(); - for (IEvent event : list) { - if (event.getEventType() == EventType.COMPOSE) { - events.addAll(flatEvents(((ComposeEvent) event).getEventList())); - } else { - events.add(event); - } - } - return events; + // Get current iteration result. + List responses = iterationIdToFinishedTasks.remove(iterationId); + for (IEvent e : responses) { + registerResults((DoneEvent) e); } - protected void finish(long iterationId) { - String iterationLogTag = getCycleIterationTag(iterationId); - if (iterationIdToFinishedTasks.get(iterationId) == null) { - // Unexpected to reach here. - throw new GeaflowRuntimeException(String.format("fatal: %s result is unregistered", - iterationLogTag)); - } - - int expectedResponseSize = nodeCycle.getCycleTails().size(); - while (iterationIdToFinishedTasks.get(iterationId).size() != expectedResponseSize) { - IEvent response = responseEventPool.waitEvent(); - DoneEvent event = (DoneEvent) response; - // Get iterationId from task. - long currentTaskIterationId = event.getWindowId(); - if (!iterationIdToFinishedTasks.containsKey(currentTaskIterationId)) { - throw new GeaflowRuntimeException( - String.format("%s finish error, current response iterationId %s, current waiting iterationIds %s", - cycleLogTag, - currentTaskIterationId, - iterationIdToFinishedTasks.keySet())); - } - iterationIdToFinishedTasks.get(currentTaskIterationId).add(response); - } - - // Get current iteration result. - List responses = iterationIdToFinishedTasks.remove(iterationId); - for (IEvent e : responses) { - registerResults((DoneEvent) e); - } - - if (this.isIteration) { - this.collectEventMetrics(responses, iterationId); - } - LOGGER.info("{} finished iterationId {}", iterationLogTag, iterationId); + if (this.isIteration) { + this.collectEventMetrics(responses, iterationId); } - - protected List finish() { - long finishIterationId = this.isIteration - ? this.context.getFinishIterationId() + 1 : this.context.getFinishIterationId(); - String finishLogTag = this.getCycleIterationTag(finishIterationId); - - // Need receive all tail responses. - int responseCount = 0; - - List resultResponses = new ArrayList<>(this.cycleTasks.size()); - List metricResponses = new ArrayList<>(this.cycleTasks.size()); - while (true) { - IEvent e = responseEventPool.waitEvent(); - DoneEvent> event = (DoneEvent) e; - switch (event.getSourceEvent()) { - case EXECUTE_COMPUTE: - resultResponses.add(event); - break; - default: - metricResponses.add(event); - responseCount++; - break; - } - if (responseCount == cycleTasks.size()) { - LOGGER.info("{} all task result collected", finishLogTag); - break; - } - } - if (!resultResponses.isEmpty()) { - for (IEvent e : resultResponses) { - registerResults((DoneEvent) e); - } - } - if (!metricResponses.isEmpty()) { - this.collectEventMetrics(metricResponses, finishIterationId); - LOGGER.info("{} finished", finishLogTag); - } - - return context.getResultManager().getDataResponse(); + LOGGER.info("{} finished iterationId {}", iterationLogTag, iterationId); + } + + protected List finish() { + long finishIterationId = + this.isIteration + ? this.context.getFinishIterationId() + 1 + : this.context.getFinishIterationId(); + String finishLogTag = this.getCycleIterationTag(finishIterationId); + + // Need receive all tail responses. + int responseCount = 0; + + List resultResponses = new ArrayList<>(this.cycleTasks.size()); + List metricResponses = new ArrayList<>(this.cycleTasks.size()); + while (true) { + IEvent e = responseEventPool.waitEvent(); + DoneEvent> event = (DoneEvent) e; + switch (event.getSourceEvent()) { + case EXECUTE_COMPUTE: + resultResponses.add(event); + break; + default: + metricResponses.add(event); + responseCount++; + break; + } + if (responseCount == cycleTasks.size()) { + LOGGER.info("{} all task result collected", finishLogTag); + break; + } + } + if (!resultResponses.isEmpty()) { + for (IEvent e : resultResponses) { + registerResults((DoneEvent) e); + } + } + if (!metricResponses.isEmpty()) { + this.collectEventMetrics(metricResponses, finishIterationId); + LOGGER.info("{} finished", finishLogTag); } - @Override - public void handleEvent(IEvent event) { - LOGGER.info("{} handle event {}", cycleLogTag, event); - if (event.getEventType() == EventType.COMPOSE) { - for (IEvent e : ((ComposeEvent) event).getEventList()) { - handleEvent(e); - } - } else { - dispatcher.dispatch(event); - } + return context.getResultManager().getDataResponse(); + } + + @Override + public void handleEvent(IEvent event) { + LOGGER.info("{} handle event {}", cycleLogTag, event); + if (event.getEventType() == EventType.COMPOSE) { + for (IEvent e : ((ComposeEvent) event).getEventList()) { + handleEvent(e); + } + } else { + dispatcher.dispatch(event); } + } + + private void registerResults(DoneEvent> event) { - private void registerResults(DoneEvent> event) { - - if (event.getResult() != null) { - // Register result to resultManager. - for (IResult result : event.getResult().values()) { - LOGGER.info("{} register result for {}", event, result.getId()); - if (result.getType() == OutputType.RESPONSE && result.getId() != COLLECT_DATA_EDGE_ID) { - LOGGER.info("do aggregate, result {}", result.getResponse()); - aggregator.aggregate(result.getResponse()); - } else { - resultManager.register(result.getId(), result); - } - } + if (event.getResult() != null) { + // Register result to resultManager. + for (IResult result : event.getResult().values()) { + LOGGER.info("{} register result for {}", event, result.getId()); + if (result.getType() == OutputType.RESPONSE && result.getId() != COLLECT_DATA_EDGE_ID) { + LOGGER.info("do aggregate, result {}", result.getResponse()); + aggregator.aggregate(result.getResponse()); + } else { + resultManager.register(result.getId(), result); } + } } + } - private void collectEventMetrics(List responses, long windowId) { - Map> vertexId2metrics = responses.stream() + private void collectEventMetrics(List responses, long windowId) { + Map> vertexId2metrics = + responses.stream() .map(e -> ((DoneEvent) e).getEventMetrics()) .filter(Objects::nonNull) .collect(Collectors.groupingBy(EventMetrics::getVertexId)); - long duration = System.currentTimeMillis() - this.scheduleStartTime; - for (Map.Entry> entry : vertexId2metrics.entrySet()) { - Integer vertexId = entry.getKey(); - List metrics = entry.getValue(); - EventMetrics[] previousMetrics = this.vertexIdToMetrics.get(vertexId); - - int taskNum = previousMetrics.length; - int slowestTask = 0; - long executeCostMs = 0; - long totalExecuteTime = 0; - long totalGcTime = 0; - long slowestTaskExecuteTime = 0; - long totalInputRecords = 0; - long totalInputBytes = 0; - long totalOutputRecords = 0; - long totalOutputBytes = 0; - - for (EventMetrics eventMetrics : metrics) { - int index = eventMetrics.getIndex(); - EventMetrics previous = previousMetrics[index]; - if (previous == null) { - executeCostMs = eventMetrics.getProcessCostMs(); - totalExecuteTime += executeCostMs; - totalGcTime += eventMetrics.getGcCostMs(); - totalInputRecords += eventMetrics.getShuffleReadRecords(); - totalInputBytes += eventMetrics.getShuffleReadBytes(); - totalOutputRecords += eventMetrics.getShuffleWriteRecords(); - totalOutputBytes += eventMetrics.getShuffleWriteBytes(); - } else { - executeCostMs = eventMetrics.getProcessCostMs() - previous.getProcessCostMs(); - totalExecuteTime += executeCostMs; - totalGcTime += eventMetrics.getGcCostMs() - previous.getGcCostMs(); - totalInputRecords += eventMetrics.getShuffleReadRecords() - previous.getShuffleReadRecords(); - totalInputBytes += eventMetrics.getShuffleReadBytes() - previous.getShuffleReadBytes(); - totalOutputRecords += eventMetrics.getShuffleWriteRecords() - previous.getShuffleWriteRecords(); - totalOutputBytes += eventMetrics.getShuffleWriteBytes() - previous.getShuffleWriteBytes(); - } - if (executeCostMs > slowestTaskExecuteTime) { - slowestTaskExecuteTime = executeCostMs; - slowestTask = index; - } - if (this.isIteration) { - previousMetrics[index] = eventMetrics; - } - } - - String metricName = this.isIteration - ? LoggerFormatter.getCycleMetricName(this.cycleId, windowId, vertexId) - : LoggerFormatter.getCycleMetricName(this.cycleId, vertexId); - String opName = this.nodeCycle.getVertexGroup().getVertexMap().get(vertexId).getName(); - CycleMetrics cycleMetrics = CycleMetrics.build( - metricName, - this.pipelineName, - opName, - taskNum, - slowestTask, - this.scheduleStartTime, - duration, - totalExecuteTime, - totalGcTime, - slowestTaskExecuteTime, - totalInputRecords, - totalInputBytes, - totalOutputRecords, - totalOutputBytes - ); - LOGGER.info("collect metric {} {}", metricName, cycleMetrics); - StatsCollectorFactory.getInstance().getPipelineStatsCollector().reportCycleMetrics(cycleMetrics); + long duration = System.currentTimeMillis() - this.scheduleStartTime; + for (Map.Entry> entry : vertexId2metrics.entrySet()) { + Integer vertexId = entry.getKey(); + List metrics = entry.getValue(); + EventMetrics[] previousMetrics = this.vertexIdToMetrics.get(vertexId); + + int taskNum = previousMetrics.length; + int slowestTask = 0; + long executeCostMs = 0; + long totalExecuteTime = 0; + long totalGcTime = 0; + long slowestTaskExecuteTime = 0; + long totalInputRecords = 0; + long totalInputBytes = 0; + long totalOutputRecords = 0; + long totalOutputBytes = 0; + + for (EventMetrics eventMetrics : metrics) { + int index = eventMetrics.getIndex(); + EventMetrics previous = previousMetrics[index]; + if (previous == null) { + executeCostMs = eventMetrics.getProcessCostMs(); + totalExecuteTime += executeCostMs; + totalGcTime += eventMetrics.getGcCostMs(); + totalInputRecords += eventMetrics.getShuffleReadRecords(); + totalInputBytes += eventMetrics.getShuffleReadBytes(); + totalOutputRecords += eventMetrics.getShuffleWriteRecords(); + totalOutputBytes += eventMetrics.getShuffleWriteBytes(); + } else { + executeCostMs = eventMetrics.getProcessCostMs() - previous.getProcessCostMs(); + totalExecuteTime += executeCostMs; + totalGcTime += eventMetrics.getGcCostMs() - previous.getGcCostMs(); + totalInputRecords += + eventMetrics.getShuffleReadRecords() - previous.getShuffleReadRecords(); + totalInputBytes += eventMetrics.getShuffleReadBytes() - previous.getShuffleReadBytes(); + totalOutputRecords += + eventMetrics.getShuffleWriteRecords() - previous.getShuffleWriteRecords(); + totalOutputBytes += eventMetrics.getShuffleWriteBytes() - previous.getShuffleWriteBytes(); } - - this.scheduleStartTime = System.currentTimeMillis(); - } - - private void getNextIterationId(ICycleSchedulerContext context, IScheduleState state) { - if (state.getScheduleStateType() == ScheduleStateType.EXECUTE_COMPUTE - || state.getScheduleStateType() == ScheduleStateType.ITERATION_INIT) { - context.getNextIterationId(); + if (executeCostMs > slowestTaskExecuteTime) { + slowestTaskExecuteTime = executeCostMs; + slowestTask = index; } + if (this.isIteration) { + previousMetrics[index] = eventMetrics; + } + } + + String metricName = + this.isIteration + ? LoggerFormatter.getCycleMetricName(this.cycleId, windowId, vertexId) + : LoggerFormatter.getCycleMetricName(this.cycleId, vertexId); + String opName = this.nodeCycle.getVertexGroup().getVertexMap().get(vertexId).getName(); + CycleMetrics cycleMetrics = + CycleMetrics.build( + metricName, + this.pipelineName, + opName, + taskNum, + slowestTask, + this.scheduleStartTime, + duration, + totalExecuteTime, + totalGcTime, + slowestTaskExecuteTime, + totalInputRecords, + totalInputBytes, + totalOutputRecords, + totalOutputBytes); + LOGGER.info("collect metric {} {}", metricName, cycleMetrics); + StatsCollectorFactory.getInstance() + .getPipelineStatsCollector() + .reportCycleMetrics(cycleMetrics); } - private String getCycleIterationTag(long iterationId) { - return this.isIteration - ? LoggerFormatter.getCycleTag(this.pipelineName, this.cycleId, iterationId) - : LoggerFormatter.getCycleTag(this.pipelineName, this.cycleId); - } - - @Override - protected void registerEventListener() { - registerSourceFinishEventListener(); - registerResponseEventListener(); - } - - private void registerSourceFinishEventListener() { - EventListenerKey listenerKey = EventListenerKey.of(cycle.getCycleId(), EventType.LAUNCH_SOURCE); - IEventListener listener = - new SourceFinishResponseEventListener(nodeCycle.getCycleHeads().size(), - events -> { - long sourceFinishWindowId = - events.stream().map(e -> ((DoneEvent) e).getWindowId()).max(Long::compareTo).get(); - ((AbstractCycleSchedulerContext) context) - .setTerminateIterationId(sourceFinishWindowId); - LOGGER.info("{} all source finished at {}", cycleLogTag, sourceFinishWindowId); - - - ICycleSchedulerContext parentContext = ((AbstractCycleSchedulerContext) context).getParentContext(); - if (parentContext != null) { - DoneEvent sourceFinishEvent = new DoneEvent(schedulerId, - parentContext.getCycle().getCycleId(), - sourceFinishWindowId, cycle.getCycleId(), - EventType.LAUNCH_SOURCE, cycle.getCycleId()); - RpcClient.getInstance().processPipeline(cycle.getDriverId(), sourceFinishEvent); - } - }); - - // Register listener for end of source event. - this.dispatcher.registerListener(listenerKey, listener); - } + this.scheduleStartTime = System.currentTimeMillis(); + } - private void registerResponseEventListener() { - EventListenerKey listenerKey = EventListenerKey.of(cycle.getCycleId()); - IEventListener listener = new ResponseEventListener(); - this.dispatcher.registerListener(listenerKey, listener); + private void getNextIterationId(ICycleSchedulerContext context, IScheduleState state) { + if (state.getScheduleStateType() == ScheduleStateType.EXECUTE_COMPUTE + || state.getScheduleStateType() == ScheduleStateType.ITERATION_INIT) { + context.getNextIterationId(); } + } + + private String getCycleIterationTag(long iterationId) { + return this.isIteration + ? LoggerFormatter.getCycleTag(this.pipelineName, this.cycleId, iterationId) + : LoggerFormatter.getCycleTag(this.pipelineName, this.cycleId); + } + + @Override + protected void registerEventListener() { + registerSourceFinishEventListener(); + registerResponseEventListener(); + } + + private void registerSourceFinishEventListener() { + EventListenerKey listenerKey = EventListenerKey.of(cycle.getCycleId(), EventType.LAUNCH_SOURCE); + IEventListener listener = + new SourceFinishResponseEventListener( + nodeCycle.getCycleHeads().size(), + events -> { + long sourceFinishWindowId = + events.stream() + .map(e -> ((DoneEvent) e).getWindowId()) + .max(Long::compareTo) + .get(); + ((AbstractCycleSchedulerContext) context) + .setTerminateIterationId(sourceFinishWindowId); + LOGGER.info("{} all source finished at {}", cycleLogTag, sourceFinishWindowId); + + ICycleSchedulerContext parentContext = + ((AbstractCycleSchedulerContext) context).getParentContext(); + if (parentContext != null) { + DoneEvent sourceFinishEvent = + new DoneEvent( + schedulerId, + parentContext.getCycle().getCycleId(), + sourceFinishWindowId, + cycle.getCycleId(), + EventType.LAUNCH_SOURCE, + cycle.getCycleId()); + RpcClient.getInstance().processPipeline(cycle.getDriverId(), sourceFinishEvent); + } + }); + + // Register listener for end of source event. + this.dispatcher.registerListener(listenerKey, listener); + } + + private void registerResponseEventListener() { + EventListenerKey listenerKey = EventListenerKey.of(cycle.getCycleId()); + IEventListener listener = new ResponseEventListener(); + this.dispatcher.registerListener(listenerKey, listener); + } + + public class ResponseEventListener implements IEventListener { - public class ResponseEventListener implements IEventListener { - - @Override - public void handleEvent(IEvent event) { - responseEventPool.notifyEvent(event); - } + @Override + public void handleEvent(IEvent event) { + responseEventPool.notifyEvent(event); } - + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/SchedulerEventBuilder.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/SchedulerEventBuilder.java index 049a4c83d..6c3086ca1 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/SchedulerEventBuilder.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/SchedulerEventBuilder.java @@ -28,6 +28,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Set; + import org.apache.geaflow.cluster.protocol.IEvent; import org.apache.geaflow.cluster.protocol.ScheduleStateType; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -65,133 +66,139 @@ public class SchedulerEventBuilder { - private static final int COMPUTE_FETCH_COUNT = 1; + private static final int COMPUTE_FETCH_COUNT = 1; - private final ICycleSchedulerContext context; - private final ExecutionNodeCycle cycle; - private final CycleResultManager resultManager; - private final boolean enableAffinity; - private final boolean isIteration; - private final long schedulerId; + private final ICycleSchedulerContext context; + private final ExecutionNodeCycle cycle; + private final CycleResultManager resultManager; + private final boolean enableAffinity; + private final boolean isIteration; + private final long schedulerId; - public SchedulerEventBuilder(ICycleSchedulerContext context, - CycleResultManager resultManager, - long schedulerId) { - this.context = context; - this.cycle = context.getCycle(); - this.resultManager = resultManager; - this.enableAffinity = context.getParentContext() != null + public SchedulerEventBuilder( + ICycleSchedulerContext context, + CycleResultManager resultManager, + long schedulerId) { + this.context = context; + this.cycle = context.getCycle(); + this.resultManager = resultManager; + this.enableAffinity = + context.getParentContext() != null && context.getParentContext().getCycle().getIterationCount() > 1; - this.isIteration = cycle.getVertexGroup().getCycleGroupMeta().isIterative(); - this.schedulerId = schedulerId; - } - - public ExecutableEventIterator build(ScheduleStateType state, long iterationId) { - switch (state) { - case PREFETCH: - return this.buildPrefetch(); - case INIT: - return this.buildInitPipeline(); - case ITERATION_INIT: - return buildInitIteration(iterationId); - case EXECUTE_COMPUTE: - return buildExecute(iterationId); - case ITERATION_FINISH: - return this.finishIteration(); - case FINISH_PREFETCH: - return this.buildFinishPrefetch(); - case CLEAN_CYCLE: - return this.finishPipeline(); - case ROLLBACK: - return this.handleRollback(); - default: - throw new GeaflowRuntimeException(String.format("not support event %s yet", state)); - } + this.isIteration = cycle.getVertexGroup().getCycleGroupMeta().isIterative(); + this.schedulerId = schedulerId; + } + public ExecutableEventIterator build(ScheduleStateType state, long iterationId) { + switch (state) { + case PREFETCH: + return this.buildPrefetch(); + case INIT: + return this.buildInitPipeline(); + case ITERATION_INIT: + return buildInitIteration(iterationId); + case EXECUTE_COMPUTE: + return buildExecute(iterationId); + case ITERATION_FINISH: + return this.finishIteration(); + case FINISH_PREFETCH: + return this.buildFinishPrefetch(); + case CLEAN_CYCLE: + return this.finishPipeline(); + case ROLLBACK: + return this.handleRollback(); + default: + throw new GeaflowRuntimeException(String.format("not support event %s yet", state)); } + } - private ExecutableEventIterator buildPrefetch() { - ExecutableEventIterator iterator = this.buildChildrenPrefetchEvent(); - return iterator; - } + private ExecutableEventIterator buildPrefetch() { + ExecutableEventIterator iterator = this.buildChildrenPrefetchEvent(); + return iterator; + } - private ExecutableEventIterator buildFinishPrefetch() { - ExecutableEventIterator events = new ExecutableEventIterator(); - Map needFinishedPrefetchEvents = - this.context.getPrefetchEvents(); - Iterator> iterator = needFinishedPrefetchEvents.entrySet().iterator(); - while (iterator.hasNext()) { - Map.Entry entry = iterator.next(); - ExecutableEvent executableEvent = entry.getValue(); - IEvent event = executableEvent.getEvent(); - PrefetchEvent prefetchEvent = (PrefetchEvent) event; - FinishPrefetchEvent finishPrefetchEvent = new FinishPrefetchEvent( - prefetchEvent.getSchedulerId(), - prefetchEvent.getWorkerId(), - prefetchEvent.getCycleId(), - prefetchEvent.getIterationWindowId(), - executableEvent.getTask().getTaskId(), - executableEvent.getTask().getIndex(), - prefetchEvent.getPipelineId(), - prefetchEvent.getEdgeIds()); - ExecutableEvent finishExecutableEvent = ExecutableEvent.build( - executableEvent.getWorker(), executableEvent.getTask(), finishPrefetchEvent); - events.addEvent(finishExecutableEvent); - iterator.remove(); - } - return events; + private ExecutableEventIterator buildFinishPrefetch() { + ExecutableEventIterator events = new ExecutableEventIterator(); + Map needFinishedPrefetchEvents = this.context.getPrefetchEvents(); + Iterator> iterator = + needFinishedPrefetchEvents.entrySet().iterator(); + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + ExecutableEvent executableEvent = entry.getValue(); + IEvent event = executableEvent.getEvent(); + PrefetchEvent prefetchEvent = (PrefetchEvent) event; + FinishPrefetchEvent finishPrefetchEvent = + new FinishPrefetchEvent( + prefetchEvent.getSchedulerId(), + prefetchEvent.getWorkerId(), + prefetchEvent.getCycleId(), + prefetchEvent.getIterationWindowId(), + executableEvent.getTask().getTaskId(), + executableEvent.getTask().getIndex(), + prefetchEvent.getPipelineId(), + prefetchEvent.getEdgeIds()); + ExecutableEvent finishExecutableEvent = + ExecutableEvent.build( + executableEvent.getWorker(), executableEvent.getTask(), finishPrefetchEvent); + events.addEvent(finishExecutableEvent); + iterator.remove(); } + return events; + } - private ExecutableEventIterator buildInitPipeline() { - ExecutableEventIterator iterator = new ExecutableEventIterator(); - for (ExecutionTask task : this.cycle.getTasks()) { - IoDescriptor ioDescriptor = - IoDescriptorBuilder.buildPipelineIoDescriptor(task, this.cycle, - this.resultManager, this.context.isPrefetch()); - iterator.addEvent(task.getWorkerInfo(), task, buildInitOrPopEvent(task, ioDescriptor)); - } - return iterator; + private ExecutableEventIterator buildInitPipeline() { + ExecutableEventIterator iterator = new ExecutableEventIterator(); + for (ExecutionTask task : this.cycle.getTasks()) { + IoDescriptor ioDescriptor = + IoDescriptorBuilder.buildPipelineIoDescriptor( + task, this.cycle, this.resultManager, this.context.isPrefetch()); + iterator.addEvent(task.getWorkerInfo(), task, buildInitOrPopEvent(task, ioDescriptor)); } + return iterator; + } - private ExecutableEventIterator buildChildrenPrefetchEvent() { - ICycleSchedulerContext parentContext = this.context.getParentContext(); - Set childrenIds = new HashSet<>(parentContext.getCycle().getCycleChildren().get(this.cycle.getCycleId())); - ExecutableEventIterator iterator = new ExecutableEventIterator(); - for (Integer childId : childrenIds) { - IExecutionCycle childCycle = parentContext.getCycle().getCycleMap().get(childId); - if (childCycle instanceof ExecutionNodeCycle) { - ExecutionNodeCycle childNodeCycle = (ExecutionNodeCycle) childCycle; - List childHeadTasks = childNodeCycle.getCycleHeads(); - Map needFinishedPrefetchEvents = - this.context.getPrefetchEvents(); - for (ExecutionTask childHeadTask : childHeadTasks) { - PrefetchEvent prefetchEvent = this.buildPrefetchEvent(childNodeCycle, childHeadTask); - ExecutableEvent executableEvent = ExecutableEvent.build(childHeadTask.getWorkerInfo(), - childHeadTask, prefetchEvent); - iterator.addEvent(executableEvent); - needFinishedPrefetchEvents.put(childHeadTask.getTaskId(), executableEvent); - } - } + private ExecutableEventIterator buildChildrenPrefetchEvent() { + ICycleSchedulerContext parentContext = + this.context.getParentContext(); + Set childrenIds = + new HashSet<>(parentContext.getCycle().getCycleChildren().get(this.cycle.getCycleId())); + ExecutableEventIterator iterator = new ExecutableEventIterator(); + for (Integer childId : childrenIds) { + IExecutionCycle childCycle = parentContext.getCycle().getCycleMap().get(childId); + if (childCycle instanceof ExecutionNodeCycle) { + ExecutionNodeCycle childNodeCycle = (ExecutionNodeCycle) childCycle; + List childHeadTasks = childNodeCycle.getCycleHeads(); + Map needFinishedPrefetchEvents = this.context.getPrefetchEvents(); + for (ExecutionTask childHeadTask : childHeadTasks) { + PrefetchEvent prefetchEvent = this.buildPrefetchEvent(childNodeCycle, childHeadTask); + ExecutableEvent executableEvent = + ExecutableEvent.build(childHeadTask.getWorkerInfo(), childHeadTask, prefetchEvent); + iterator.addEvent(executableEvent); + needFinishedPrefetchEvents.put(childHeadTask.getTaskId(), executableEvent); } - return iterator; + } } + return iterator; + } - private PrefetchEvent buildPrefetchEvent(ExecutionNodeCycle childNodeCycle, ExecutionTask childTask) { - IoDescriptor ioDescriptor = IoDescriptorBuilder.buildPrefetchIoDescriptor(this.cycle, childNodeCycle, childTask); - return new PrefetchEvent( - childNodeCycle.getSchedulerId(), - childTask.getWorkerInfo().getWorkerIndex(), - childNodeCycle.getCycleId(), - this.context.getInitialIterationId(), - childNodeCycle.getPipelineId(), - childNodeCycle.getPipelineName(), - childTask, - ioDescriptor); - } + private PrefetchEvent buildPrefetchEvent( + ExecutionNodeCycle childNodeCycle, ExecutionTask childTask) { + IoDescriptor ioDescriptor = + IoDescriptorBuilder.buildPrefetchIoDescriptor(this.cycle, childNodeCycle, childTask); + return new PrefetchEvent( + childNodeCycle.getSchedulerId(), + childTask.getWorkerInfo().getWorkerIndex(), + childNodeCycle.getCycleId(), + this.context.getInitialIterationId(), + childNodeCycle.getPipelineId(), + childNodeCycle.getPipelineName(), + childTask, + ioDescriptor); + } - private IEvent buildInitOrPopEvent(ExecutionTask task, IoDescriptor ioDescriptor) { - return this.cycle.isWorkerAssigned() - ? new PopWorkerEvent( + private IEvent buildInitOrPopEvent(ExecutionTask task, IoDescriptor ioDescriptor) { + return this.cycle.isWorkerAssigned() + ? new PopWorkerEvent( this.schedulerId, task.getWorkerInfo().getWorkerIndex(), this.cycle.getCycleId(), @@ -200,185 +207,219 @@ private IEvent buildInitOrPopEvent(ExecutionTask task, IoDescriptor ioDescriptor this.cycle.getPipelineName(), ioDescriptor, task.getTaskId()) - : this.buildInitCycleEvent(task, ioDescriptor); + : this.buildInitCycleEvent(task, ioDescriptor); + } + + private InitCycleEvent buildInitCycleEvent(ExecutionTask task, IoDescriptor ioDescriptor) { + InitCycleEvent init; + int workerId = task.getWorkerInfo().getWorkerIndex(); + HighAvailableLevel highAvailableLevel = this.cycle.getHighAvailableLevel(); + if (this.cycle.getType() == ExecutionCycleType.ITERATION) { + highAvailableLevel = HighAvailableLevel.REDO; } - private InitCycleEvent buildInitCycleEvent(ExecutionTask task, IoDescriptor ioDescriptor) { - InitCycleEvent init; - int workerId = task.getWorkerInfo().getWorkerIndex(); - HighAvailableLevel highAvailableLevel = this.cycle.getHighAvailableLevel(); - if (this.cycle.getType() == ExecutionCycleType.ITERATION) { - highAvailableLevel = HighAvailableLevel.REDO; - } + if (this.cycle instanceof CollectExecutionNodeCycle) { + init = + new InitCollectCycleEvent( + this.schedulerId, + workerId, + this.cycle.getCycleId(), + this.context.getInitialIterationId(), + this.cycle.getPipelineId(), + this.cycle.getPipelineName(), + ioDescriptor, + task, + this.cycle.getDriverId(), + highAvailableLevel); + } else { + init = + new InitCycleEvent( + this.schedulerId, + workerId, + this.cycle.getCycleId(), + this.context.getInitialIterationId(), + this.cycle.getPipelineId(), + this.cycle.getPipelineName(), + ioDescriptor, + task, + this.cycle.getDriverId(), + highAvailableLevel); + } + return init; + } - if (this.cycle instanceof CollectExecutionNodeCycle) { - init = new InitCollectCycleEvent( + private ExecutableEventIterator buildInitIteration(long iterationId) { + ExecutableEventIterator iterator = new ExecutableEventIterator(); + for (ExecutionTask task : this.cycle.getTasks()) { + if (ExecutionTaskUtils.isCycleHead(task)) { + int workerId = task.getWorkerInfo().getWorkerIndex(); + // Load graph. + IEvent loadGraph = + new LoadGraphProcessEvent( this.schedulerId, workerId, - this.cycle.getCycleId(), - this.context.getInitialIterationId(), - this.cycle.getPipelineId(), - this.cycle.getPipelineName(), - ioDescriptor, - task, - this.cycle.getDriverId(), - highAvailableLevel); - } else { - init = new InitCycleEvent( + cycle.getCycleId(), + iterationId, + context.getInitialIterationId(), + COMPUTE_FETCH_COUNT); + // Init iteration. + IoDescriptor ioDescriptor = + IoDescriptorBuilder.buildIterationIoDescriptor( + task, this.cycle, this.resultManager, OutputType.LOOP); + InitIterationEvent iterationInit = + new InitIterationEvent( this.schedulerId, workerId, this.cycle.getCycleId(), - this.context.getInitialIterationId(), + iterationId, this.cycle.getPipelineId(), this.cycle.getPipelineName(), - ioDescriptor, - task, - this.cycle.getDriverId(), - highAvailableLevel); - } - return init; + ioDescriptor); + IEvent execute = + new ExecuteFirstIterationEvent( + this.schedulerId, workerId, this.cycle.getCycleId(), iterationId); + ComposeEvent composeEvent = + new ComposeEvent(workerId, Arrays.asList(loadGraph, iterationInit, execute)); + iterator.addEvent(task.getWorkerInfo(), task, composeEvent); + } } + return iterator; + } - private ExecutableEventIterator buildInitIteration(long iterationId) { - ExecutableEventIterator iterator = new ExecutableEventIterator(); - for (ExecutionTask task : this.cycle.getTasks()) { - if (ExecutionTaskUtils.isCycleHead(task)) { - int workerId = task.getWorkerInfo().getWorkerIndex(); - // Load graph. - IEvent loadGraph = new LoadGraphProcessEvent(this.schedulerId, workerId, cycle.getCycleId(), - iterationId, context.getInitialIterationId(), COMPUTE_FETCH_COUNT); - // Init iteration. - IoDescriptor ioDescriptor = IoDescriptorBuilder.buildIterationIoDescriptor( - task, this.cycle, this.resultManager, OutputType.LOOP); - InitIterationEvent iterationInit = new InitIterationEvent( - this.schedulerId, - workerId, - this.cycle.getCycleId(), - iterationId, - this.cycle.getPipelineId(), - this.cycle.getPipelineName(), - ioDescriptor); - IEvent execute = new ExecuteFirstIterationEvent(this.schedulerId, workerId, this.cycle.getCycleId(), iterationId); - ComposeEvent composeEvent = new ComposeEvent(workerId, - Arrays.asList(loadGraph, iterationInit, execute)); - iterator.addEvent(task.getWorkerInfo(), task, composeEvent); - } + /** Build launch for all cycle heads. */ + private ExecutableEventIterator buildExecute(long iterationId) { + ExecutableEventIterator iterator = new ExecutableEventIterator(); + for (ExecutionTask task : cycle.getTasks()) { + if (ExecutionTaskUtils.isCycleHead(task)) { + // Only submit launch to cycle head. + long fetchId = iterationId; + // Fetch previous iteration input. + if (isIteration) { + if (iterationId > DEFAULT_INITIAL_ITERATION_ID) { + fetchId = iterationId - 1; + } else { + fetchId = context.getInitialIterationId(); + } } - return iterator; + IEvent event = + buildExecute( + task, + task.getWorkerInfo().getWorkerIndex(), + cycle.getCycleId(), + iterationId, + fetchId); + iterator.addEvent(task.getWorkerInfo(), task, event); + } else if (iterationId == context.getInitialIterationId()) { + // Build execute compute for non-tail event during first window. + int workerId = task.getWorkerInfo().getWorkerIndex(); + ExecuteComputeEvent execute = + new ExecuteComputeEvent( + this.schedulerId, + workerId, + cycle.getCycleId(), + context.getInitialIterationId(), + context.getInitialIterationId(), + context.getFinishIterationId() - context.getInitialIterationId() + 1, + cycle.getIterationCount() > 1); + iterator.addEvent(task.getWorkerInfo(), task, execute); + } } + return iterator; + } - /** - * Build launch for all cycle heads. - */ - private ExecutableEventIterator buildExecute(long iterationId) { - ExecutableEventIterator iterator = new ExecutableEventIterator(); - for (ExecutionTask task : cycle.getTasks()) { - if (ExecutionTaskUtils.isCycleHead(task)) { - // Only submit launch to cycle head. - long fetchId = iterationId; - // Fetch previous iteration input. - if (isIteration) { - if (iterationId > DEFAULT_INITIAL_ITERATION_ID) { - fetchId = iterationId - 1; - } else { - fetchId = context.getInitialIterationId(); - } - } - IEvent event = buildExecute(task, task.getWorkerInfo().getWorkerIndex(), - cycle.getCycleId(), iterationId, fetchId); - iterator.addEvent(task.getWorkerInfo(), task, event); - } else if (iterationId == context.getInitialIterationId()) { - // Build execute compute for non-tail event during first window. - int workerId = task.getWorkerInfo().getWorkerIndex(); - ExecuteComputeEvent execute = new ExecuteComputeEvent( - this.schedulerId, - workerId, - cycle.getCycleId(), context.getInitialIterationId(), - context.getInitialIterationId(), - context.getFinishIterationId() - context.getInitialIterationId() + 1, - cycle.getIterationCount() > 1); - iterator.addEvent(task.getWorkerInfo(), task, execute); - - } - } - return iterator; + private ExecutableEventIterator finishPipeline() { + ExecutableEventIterator iterator = new ExecutableEventIterator(); + boolean needInterrupt = context.getCurrentIterationId() < context.getFinishIterationId(); + for (ExecutionTask task : cycle.getTasks()) { + int workerId = task.getWorkerInfo().getWorkerIndex(); + IEvent cleanEvent; + if (enableAffinity) { + cleanEvent = + new StashWorkerEvent( + this.schedulerId, + workerId, + cycle.getCycleId(), + cycle.getIterationCount(), + task.getTaskId()); + } else { + cleanEvent = + new CleanCycleEvent( + this.schedulerId, workerId, cycle.getCycleId(), cycle.getIterationCount()); + } + if (needInterrupt + && context.getCycle().getType() != ExecutionCycleType.ITERATION + && context.getCycle().getType() != ExecutionCycleType.ITERATION_WITH_AGG) { + InterruptTaskEvent interruptTaskEvent = + new InterruptTaskEvent(workerId, cycle.getCycleId()); + ComposeEvent composeEvent = + new ComposeEvent( + task.getWorkerInfo().getWorkerIndex(), + Arrays.asList(interruptTaskEvent, cleanEvent)); + iterator.addEvent(task.getWorkerInfo(), task, composeEvent); + } else { + iterator.addEvent(task.getWorkerInfo(), task, cleanEvent); + } } + return iterator; + } - private ExecutableEventIterator finishPipeline() { - ExecutableEventIterator iterator = new ExecutableEventIterator(); - boolean needInterrupt = context.getCurrentIterationId() < context.getFinishIterationId(); - for (ExecutionTask task : cycle.getTasks()) { - int workerId = task.getWorkerInfo().getWorkerIndex(); - IEvent cleanEvent; - if (enableAffinity) { - cleanEvent = new StashWorkerEvent(this.schedulerId, workerId, cycle.getCycleId(), cycle.getIterationCount(), task.getTaskId()); - } else { - cleanEvent = new CleanCycleEvent(this.schedulerId, workerId, cycle.getCycleId(), cycle.getIterationCount()); - } - if (needInterrupt && context.getCycle().getType() != ExecutionCycleType.ITERATION - && context.getCycle().getType() != ExecutionCycleType.ITERATION_WITH_AGG) { - InterruptTaskEvent interruptTaskEvent = new InterruptTaskEvent(workerId, cycle.getCycleId()); - ComposeEvent composeEvent = - new ComposeEvent(task.getWorkerInfo().getWorkerIndex(), - Arrays.asList(interruptTaskEvent, cleanEvent)); - iterator.addEvent(task.getWorkerInfo(), task, composeEvent); - } else { - iterator.addEvent(task.getWorkerInfo(), task, cleanEvent); - } - } - return iterator; + private ExecutableEventIterator finishIteration() { + ExecutableEventIterator iterator = new ExecutableEventIterator(); + for (ExecutionTask task : this.cycle.getTasks()) { + int workerId = task.getWorkerInfo().getWorkerIndex(); + // Finish iteration + FinishIterationEvent iterationFinishEvent = + new FinishIterationEvent( + this.schedulerId, + workerId, + this.context.getInitialIterationId(), + this.cycle.getCycleId()); + + iterator.addEvent(task.getWorkerInfo(), task, iterationFinishEvent); } + return iterator; + } - private ExecutableEventIterator finishIteration() { - ExecutableEventIterator iterator = new ExecutableEventIterator(); - for (ExecutionTask task : this.cycle.getTasks()) { - int workerId = task.getWorkerInfo().getWorkerIndex(); - // Finish iteration - FinishIterationEvent iterationFinishEvent = new FinishIterationEvent( + private ExecutableEventIterator handleRollback() { + ExecutableEventIterator iterator = new ExecutableEventIterator(); + for (ExecutionTask task : this.cycle.getTasks()) { + int workerId = task.getWorkerInfo().getWorkerIndex(); + // Do not do rollback if recover from initial iteration id. + if (context.getCurrentIterationId() != DEFAULT_INITIAL_ITERATION_ID) { + RollbackCycleEvent rollbackCycleEvent = + new RollbackCycleEvent( this.schedulerId, workerId, - this.context.getInitialIterationId(), - this.cycle.getCycleId()); - - iterator.addEvent(task.getWorkerInfo(), task, iterationFinishEvent); - } - return iterator; - } - - private ExecutableEventIterator handleRollback() { - ExecutableEventIterator iterator = new ExecutableEventIterator(); - for (ExecutionTask task : this.cycle.getTasks()) { - int workerId = task.getWorkerInfo().getWorkerIndex(); - // Do not do rollback if recover from initial iteration id. - if (context.getCurrentIterationId() != DEFAULT_INITIAL_ITERATION_ID) { - RollbackCycleEvent rollbackCycleEvent = new RollbackCycleEvent(this.schedulerId, workerId, - this.cycle.getCycleId(), - context.getCurrentIterationId() - 1); - iterator.addEvent(task.getWorkerInfo(), task, rollbackCycleEvent); - } - } - return iterator; + this.cycle.getCycleId(), + context.getCurrentIterationId() - 1); + iterator.addEvent(task.getWorkerInfo(), task, rollbackCycleEvent); + } } + return iterator; + } - private IEvent buildExecute(ExecutionTask task, int workerId, int cycleId, long iterationId, long fetchId) { - if (cycle.getVertexGroup().getParentVertexGroupIds().isEmpty() && ExecutionTaskUtils.isCycleHead(task)) { - return new LaunchSourceEvent(this.schedulerId, workerId, cycleId, iterationId); - } else { - IoDescriptor ioDescriptor = IoDescriptorBuilder.buildIterationIoDescriptor( - task, this.cycle, this.resultManager, OutputType.RESPONSE); - if (ioDescriptor.getInputDescriptor().getInputDescMap().isEmpty()) { - return new ExecuteComputeEvent(this.schedulerId, workerId, cycleId, iterationId, fetchId, COMPUTE_FETCH_COUNT); - } else { - return new IterationExecutionComputeWithAggEvent( - this.schedulerId, - workerId, - cycleId, - iterationId, - fetchId, - COMPUTE_FETCH_COUNT, - ioDescriptor); - } - } + private IEvent buildExecute( + ExecutionTask task, int workerId, int cycleId, long iterationId, long fetchId) { + if (cycle.getVertexGroup().getParentVertexGroupIds().isEmpty() + && ExecutionTaskUtils.isCycleHead(task)) { + return new LaunchSourceEvent(this.schedulerId, workerId, cycleId, iterationId); + } else { + IoDescriptor ioDescriptor = + IoDescriptorBuilder.buildIterationIoDescriptor( + task, this.cycle, this.resultManager, OutputType.RESPONSE); + if (ioDescriptor.getInputDescriptor().getInputDescMap().isEmpty()) { + return new ExecuteComputeEvent( + this.schedulerId, workerId, cycleId, iterationId, fetchId, COMPUTE_FETCH_COUNT); + } else { + return new IterationExecutionComputeWithAggEvent( + this.schedulerId, + workerId, + cycleId, + iterationId, + fetchId, + COMPUTE_FETCH_COUNT, + ioDescriptor); + } } - + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/SchedulerEventDispatcher.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/SchedulerEventDispatcher.java index a31e4751d..8789d78fa 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/SchedulerEventDispatcher.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/SchedulerEventDispatcher.java @@ -21,6 +21,7 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.cluster.common.IDispatcher; import org.apache.geaflow.cluster.common.IEventListener; import org.apache.geaflow.cluster.protocol.EventType; @@ -34,48 +35,54 @@ public class SchedulerEventDispatcher implements IDispatcher { - private static final Logger LOGGER = LoggerFactory.getLogger(SchedulerEventDispatcher.class); + private static final Logger LOGGER = LoggerFactory.getLogger(SchedulerEventDispatcher.class); - private Map listeners; - private String cycleLogTag; + private Map listeners; + private String cycleLogTag; - public SchedulerEventDispatcher(String cycleLogTag) { - this.cycleLogTag = cycleLogTag; - this.listeners = new ConcurrentHashMap<>(); - } + public SchedulerEventDispatcher(String cycleLogTag) { + this.cycleLogTag = cycleLogTag; + this.listeners = new ConcurrentHashMap<>(); + } - @Override - public void dispatch(IEvent event) throws GeaflowDispatchException { - if (event.getEventType() != EventType.DONE) { - throw new GeaflowRuntimeException(String.format("%s not support handle event %s", - cycleLogTag, event)); - } - DoneEvent doneEvent = (DoneEvent) event; - getListener(doneEvent).handleEvent(doneEvent); + @Override + public void dispatch(IEvent event) throws GeaflowDispatchException { + if (event.getEventType() != EventType.DONE) { + throw new GeaflowRuntimeException( + String.format("%s not support handle event %s", cycleLogTag, event)); } + DoneEvent doneEvent = (DoneEvent) event; + getListener(doneEvent).handleEvent(doneEvent); + } - public void registerListener(EventListenerKey key, IEventListener eventListener) { - LOGGER.info("{} register event listener {}", cycleLogTag, key); - listeners.put(key, eventListener); - } + public void registerListener(EventListenerKey key, IEventListener eventListener) { + LOGGER.info("{} register event listener {}", cycleLogTag, key); + listeners.put(key, eventListener); + } - public void removeListener(EventListenerKey key) { - LOGGER.info("{} remove event listener {}", cycleLogTag, key); - listeners.remove(key); - } + public void removeListener(EventListenerKey key) { + LOGGER.info("{} remove event listener {}", cycleLogTag, key); + listeners.remove(key); + } - private IEventListener getListener(DoneEvent event) { - IEventListener listener; - if ((listener = listeners.get( - EventListenerKey.of(event.getCycleId(), event.getSourceEvent(), event.getWindowId()))) != null) { - return listener; - } else if ((listener = listeners.get( - EventListenerKey.of(event.getCycleId(), event.getSourceEvent()))) != null) { - return listener; - } else if ((listener = listeners.get(EventListenerKey.of(event.getCycleId()))) != null) { - return listener; - } - throw new GeaflowRuntimeException(String.format("%s not found any listener for event %s. current listeners %s", - cycleLogTag, event, listeners.keySet())); + private IEventListener getListener(DoneEvent event) { + IEventListener listener; + if ((listener = + listeners.get( + EventListenerKey.of( + event.getCycleId(), event.getSourceEvent(), event.getWindowId()))) + != null) { + return listener; + } else if ((listener = + listeners.get(EventListenerKey.of(event.getCycleId(), event.getSourceEvent()))) + != null) { + return listener; + } else if ((listener = listeners.get(EventListenerKey.of(event.getCycleId()))) != null) { + return listener; } + throw new GeaflowRuntimeException( + String.format( + "%s not found any listener for event %s. current listeners %s", + cycleLogTag, event, listeners.keySet())); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/SchedulerGraphAggregateProcessor.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/SchedulerGraphAggregateProcessor.java index 3436fc3b9..7139ba0c4 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/SchedulerGraphAggregateProcessor.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/SchedulerGraphAggregateProcessor.java @@ -21,10 +21,10 @@ import static org.apache.geaflow.plan.PipelinePlanBuilder.ITERATION_AGG_VERTEX_ID; -import com.google.common.base.Preconditions; import java.util.Arrays; import java.util.List; import java.util.Optional; + import org.apache.geaflow.api.graph.base.algo.GraphAggregationAlgo; import org.apache.geaflow.api.graph.function.vc.VertexCentricAggregateFunction; import org.apache.geaflow.cluster.response.ResponseResult; @@ -36,80 +36,95 @@ import org.apache.geaflow.runtime.core.scheduler.io.CycleResultManager; import org.apache.geaflow.shuffle.desc.OutputType; -public class SchedulerGraphAggregateProcessor { +import com.google.common.base.Preconditions; - private VertexCentricAggregateFunction.IGraphAggregateFunction function; - private AGG aggregator; - private RESULT result; - private int expectedCount; - private int processedCount; - private ExecutionNodeCycle cycle; - private AbstractCycleSchedulerContext schedulerContext; - - public SchedulerGraphAggregateProcessor(ExecutionNodeCycle cycle, - AbstractCycleSchedulerContext context, - CycleResultManager resultManager) { - this.cycle = cycle; - this.schedulerContext = context; - this.expectedCount = cycle.getCycleTails().size(); - this.processedCount = 0; - - Preconditions.checkArgument(cycle.getVertexGroup().getVertexMap().size() == 2, - String.format("Vertex group should only contains an iteration vertex " - + "and an aggregation vertex, current vertex size is %s", - cycle.getVertexGroup().getVertexMap().size())); - - ExecutionVertex aggVertex = cycle.getVertexGroup().getVertexMap().get(ITERATION_AGG_VERTEX_ID); - Preconditions.checkArgument(aggVertex != null, "aggregation vertex id should be 0"); - - AbstractGraphVertexCentricOp operator = - (AbstractGraphVertexCentricOp) ((GraphVertexCentricProcessor) aggVertex.getProcessor()).getOperator(); - ((GraphAggregationAlgo) (operator.getFunction())).getAggregateFunction().getPartialAggregation(); - this.function = - ((GraphAggregationAlgo) (operator.getFunction())).getAggregateFunction().getGlobalAggregation(); - - Optional edgeId = cycle.getVertexGroup().getEdgeMap().values().stream() - .filter(e -> e.getType() == OutputType.RESPONSE && e.getSrcId() == ITERATION_AGG_VERTEX_ID) - .map(e -> e.getEdgeId()).findFirst(); - Preconditions.checkArgument(edgeId.isPresent(), - "An edge from aggregation vertex to iteration vertex should build"); - this.aggregator = (AGG) function.create(new GlobalAggregateContext(edgeId.get(), resultManager)); - } +public class SchedulerGraphAggregateProcessor { - public void aggregate(List input) { - for (ITERM iterm : input) { - result = function.aggregate(iterm, aggregator); - if (++processedCount == expectedCount) { - function.finish(result); - processedCount = 0; - } - } + private VertexCentricAggregateFunction.IGraphAggregateFunction function; + private AGG aggregator; + private RESULT result; + private int expectedCount; + private int processedCount; + private ExecutionNodeCycle cycle; + private AbstractCycleSchedulerContext schedulerContext; + + public SchedulerGraphAggregateProcessor( + ExecutionNodeCycle cycle, + AbstractCycleSchedulerContext context, + CycleResultManager resultManager) { + this.cycle = cycle; + this.schedulerContext = context; + this.expectedCount = cycle.getCycleTails().size(); + this.processedCount = 0; + + Preconditions.checkArgument( + cycle.getVertexGroup().getVertexMap().size() == 2, + String.format( + "Vertex group should only contains an iteration vertex " + + "and an aggregation vertex, current vertex size is %s", + cycle.getVertexGroup().getVertexMap().size())); + + ExecutionVertex aggVertex = cycle.getVertexGroup().getVertexMap().get(ITERATION_AGG_VERTEX_ID); + Preconditions.checkArgument(aggVertex != null, "aggregation vertex id should be 0"); + + AbstractGraphVertexCentricOp operator = + (AbstractGraphVertexCentricOp) + ((GraphVertexCentricProcessor) aggVertex.getProcessor()).getOperator(); + ((GraphAggregationAlgo) (operator.getFunction())) + .getAggregateFunction() + .getPartialAggregation(); + this.function = + ((GraphAggregationAlgo) (operator.getFunction())) + .getAggregateFunction() + .getGlobalAggregation(); + + Optional edgeId = + cycle.getVertexGroup().getEdgeMap().values().stream() + .filter( + e -> e.getType() == OutputType.RESPONSE && e.getSrcId() == ITERATION_AGG_VERTEX_ID) + .map(e -> e.getEdgeId()) + .findFirst(); + Preconditions.checkArgument( + edgeId.isPresent(), "An edge from aggregation vertex to iteration vertex should build"); + this.aggregator = + (AGG) function.create(new GlobalAggregateContext(edgeId.get(), resultManager)); + } + + public void aggregate(List input) { + for (ITERM iterm : input) { + result = function.aggregate(iterm, aggregator); + if (++processedCount == expectedCount) { + function.finish(result); + processedCount = 0; + } } + } - private class GlobalAggregateContext - implements VertexCentricAggregateFunction.IGlobalGraphAggContext { + private class GlobalAggregateContext + implements VertexCentricAggregateFunction.IGlobalGraphAggContext { - private int edgeId; - private CycleResultManager resultManager; + private int edgeId; + private CycleResultManager resultManager; - public GlobalAggregateContext(int edgeId, CycleResultManager resultManager) { - this.resultManager = resultManager; - this.edgeId = edgeId; - } + public GlobalAggregateContext(int edgeId, CycleResultManager resultManager) { + this.resultManager = resultManager; + this.edgeId = edgeId; + } - @Override - public long getIteration() { - return schedulerContext.getCurrentIterationId(); - } + @Override + public long getIteration() { + return schedulerContext.getCurrentIterationId(); + } - @Override - public void broadcast(RESULT result) { - resultManager.register(edgeId, new ResponseResult(edgeId, OutputType.RESPONSE, Arrays.asList(result))); - } + @Override + public void broadcast(RESULT result) { + resultManager.register( + edgeId, new ResponseResult(edgeId, OutputType.RESPONSE, Arrays.asList(result))); + } - @Override - public void terminate() { - schedulerContext.setTerminateIterationId(getIteration()); - } + @Override + public void terminate() { + schedulerContext.setTerminateIterationId(getIteration()); } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/AbstractCycleSchedulerContext.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/AbstractCycleSchedulerContext.java index 84b58e7dc..f19993e0b 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/AbstractCycleSchedulerContext.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/AbstractCycleSchedulerContext.java @@ -25,6 +25,7 @@ import java.util.Queue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicLong; + import org.apache.geaflow.cluster.resourcemanager.WorkerInfo; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -39,223 +40,233 @@ import org.slf4j.LoggerFactory; public abstract class AbstractCycleSchedulerContext< - C extends IExecutionCycle, PC extends IExecutionCycle, PCC extends ICycleSchedulerContext> + C extends IExecutionCycle, + PC extends IExecutionCycle, + PCC extends ICycleSchedulerContext> implements ICycleSchedulerContext { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractCycleSchedulerContext.class); - - public static final long DEFAULT_INITIAL_ITERATION_ID = 1; - - protected final C cycle; - protected transient AtomicLong iterationIdGenerator; - protected transient Queue flyingIterations; - protected transient long currentIterationId; - protected long finishIterationId; - protected transient long initialIterationId; - protected transient long lastCheckpointId; - protected transient long terminateIterationId; - - protected transient PCC parentContext; - protected IScheduledWorkerManager workerManager; - protected transient CycleResultManager cycleResultManager; - protected ICallbackFunction callbackFunction; - protected static ThreadLocal rollback = ThreadLocal.withInitial(() -> false); - protected transient Map prefetchEvents; - protected transient boolean prefetch; - - public AbstractCycleSchedulerContext(C cycle, PCC parentContext) { - this.cycle = cycle; - this.parentContext = parentContext; - - if (parentContext != null) { - // Get worker manager from parent context, no need to init. - this.workerManager = (IScheduledWorkerManager) parentContext.getSchedulerWorkerManager(); - this.finishIterationId = cycle.getIterationCount() == Long.MAX_VALUE - ? cycle.getIterationCount() : cycle.getIterationCount() + parentContext.getCurrentIterationId() - 1; - } else { - this.workerManager = (IScheduledWorkerManager) - ScheduledWorkerManagerFactory.createScheduledWorkerManager( - cycle.getConfig(), - ScheduledWorkerManagerFactory.getWorkerManagerHALevel(cycle) - ); - this.finishIterationId = cycle.getIterationCount() == Long.MAX_VALUE - ? cycle.getIterationCount() : cycle.getIterationCount() + DEFAULT_INITIAL_ITERATION_ID - 1; - } - } - - public void init() { - long startIterationId; - if (parentContext != null) { - startIterationId = parentContext.getCurrentIterationId(); - } else { - startIterationId = DEFAULT_INITIAL_ITERATION_ID; - } - init(startIterationId); - } - - public void init(long startIterationId) { - this.flyingIterations = new LinkedBlockingQueue<>(cycle.getFlyingCount()); - this.currentIterationId = startIterationId; - this.initialIterationId = startIterationId; - this.iterationIdGenerator = new AtomicLong(currentIterationId); - this.lastCheckpointId = 0; - this.terminateIterationId = Long.MAX_VALUE; - if (parentContext != null) { - this.cycleResultManager = parentContext.getResultManager(); - } else { - this.cycleResultManager = new CycleResultManager(); - } - prefetch = cycle.getConfig().getBoolean(ExecutionConfigKeys.SHUFFLE_PREFETCH); - prefetchEvents = new HashMap<>(); - - LOGGER.info("{} init cycle context onTheFlyThreshold {}, currentIterationId {}, " - + "iterationCount {}, finishIterationId {}, initialIterationId {}", - cycle.getPipelineName(), cycle.getFlyingCount(), this.currentIterationId, - cycle.getIterationCount(), this.finishIterationId, this.initialIterationId); - - } - - @Override - public C getCycle() { - return this.cycle; - } - - @Override - public Configuration getConfig() { - return cycle.getConfig(); - } - - @Override - public boolean isCycleFinished() { - return (iterationIdGenerator.get() > finishIterationId || lastCheckpointId >= terminateIterationId) - && flyingIterations.isEmpty(); - } - - @Override - public long getCurrentIterationId() { - return currentIterationId; - } - - public void setCurrentIterationId(long iterationId) { - this.currentIterationId = iterationId; - } - - @Override - public boolean isRecovered() { - return false; - } - - @Override - public boolean isRollback() { - return rollback.get(); - } - - @Override - public boolean isPrefetch() { - return prefetch; - } - - public void setRollback(boolean bool) { - rollback.set(bool); - } - - @Override - public long getFinishIterationId() { - return finishIterationId; + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractCycleSchedulerContext.class); + + public static final long DEFAULT_INITIAL_ITERATION_ID = 1; + + protected final C cycle; + protected transient AtomicLong iterationIdGenerator; + protected transient Queue flyingIterations; + protected transient long currentIterationId; + protected long finishIterationId; + protected transient long initialIterationId; + protected transient long lastCheckpointId; + protected transient long terminateIterationId; + + protected transient PCC parentContext; + protected IScheduledWorkerManager workerManager; + protected transient CycleResultManager cycleResultManager; + protected ICallbackFunction callbackFunction; + protected static ThreadLocal rollback = ThreadLocal.withInitial(() -> false); + protected transient Map prefetchEvents; + protected transient boolean prefetch; + + public AbstractCycleSchedulerContext(C cycle, PCC parentContext) { + this.cycle = cycle; + this.parentContext = parentContext; + + if (parentContext != null) { + // Get worker manager from parent context, no need to init. + this.workerManager = (IScheduledWorkerManager) parentContext.getSchedulerWorkerManager(); + this.finishIterationId = + cycle.getIterationCount() == Long.MAX_VALUE + ? cycle.getIterationCount() + : cycle.getIterationCount() + parentContext.getCurrentIterationId() - 1; + } else { + this.workerManager = + (IScheduledWorkerManager) + ScheduledWorkerManagerFactory.createScheduledWorkerManager( + cycle.getConfig(), ScheduledWorkerManagerFactory.getWorkerManagerHALevel(cycle)); + this.finishIterationId = + cycle.getIterationCount() == Long.MAX_VALUE + ? cycle.getIterationCount() + : cycle.getIterationCount() + DEFAULT_INITIAL_ITERATION_ID - 1; } - - public void setFinishIterationId(long finishIterationId) { - this.finishIterationId = finishIterationId; - } - - @Override - public boolean hasNextIteration() { - return iterationIdGenerator.get() <= finishIterationId && lastCheckpointId < terminateIterationId - && flyingIterations.size() < cycle.getFlyingCount(); - } - - @Override - public long getNextIterationId() { - long iterationId = iterationIdGenerator.getAndIncrement(); - flyingIterations.add(iterationId); - this.currentIterationId = iterationId; - return iterationId; - } - - @Override - public boolean hasNextToFinish() { - return !flyingIterations.isEmpty() && !hasNextIteration(); - } - - @Override - public long getNextFinishIterationId() { - return flyingIterations.remove(); - } - - @Override - public long getInitialIterationId() { - return initialIterationId; - } - - @Override - public IScheduledWorkerManager getSchedulerWorkerManager() { - return workerManager; + } + + public void init() { + long startIterationId; + if (parentContext != null) { + startIterationId = parentContext.getCurrentIterationId(); + } else { + startIterationId = DEFAULT_INITIAL_ITERATION_ID; } - - @Override - public CycleResultManager getResultManager() { - return cycleResultManager; + init(startIterationId); + } + + public void init(long startIterationId) { + this.flyingIterations = new LinkedBlockingQueue<>(cycle.getFlyingCount()); + this.currentIterationId = startIterationId; + this.initialIterationId = startIterationId; + this.iterationIdGenerator = new AtomicLong(currentIterationId); + this.lastCheckpointId = 0; + this.terminateIterationId = Long.MAX_VALUE; + if (parentContext != null) { + this.cycleResultManager = parentContext.getResultManager(); + } else { + this.cycleResultManager = new CycleResultManager(); } - - @Override - public PCC getParentContext() { - return (PCC) this.parentContext; + prefetch = cycle.getConfig().getBoolean(ExecutionConfigKeys.SHUFFLE_PREFETCH); + prefetchEvents = new HashMap<>(); + + LOGGER.info( + "{} init cycle context onTheFlyThreshold {}, currentIterationId {}, " + + "iterationCount {}, finishIterationId {}, initialIterationId {}", + cycle.getPipelineName(), + cycle.getFlyingCount(), + this.currentIterationId, + cycle.getIterationCount(), + this.finishIterationId, + this.initialIterationId); + } + + @Override + public C getCycle() { + return this.cycle; + } + + @Override + public Configuration getConfig() { + return cycle.getConfig(); + } + + @Override + public boolean isCycleFinished() { + return (iterationIdGenerator.get() > finishIterationId + || lastCheckpointId >= terminateIterationId) + && flyingIterations.isEmpty(); + } + + @Override + public long getCurrentIterationId() { + return currentIterationId; + } + + public void setCurrentIterationId(long iterationId) { + this.currentIterationId = iterationId; + } + + @Override + public boolean isRecovered() { + return false; + } + + @Override + public boolean isRollback() { + return rollback.get(); + } + + @Override + public boolean isPrefetch() { + return prefetch; + } + + public void setRollback(boolean bool) { + rollback.set(bool); + } + + @Override + public long getFinishIterationId() { + return finishIterationId; + } + + public void setFinishIterationId(long finishIterationId) { + this.finishIterationId = finishIterationId; + } + + @Override + public boolean hasNextIteration() { + return iterationIdGenerator.get() <= finishIterationId + && lastCheckpointId < terminateIterationId + && flyingIterations.size() < cycle.getFlyingCount(); + } + + @Override + public long getNextIterationId() { + long iterationId = iterationIdGenerator.getAndIncrement(); + flyingIterations.add(iterationId); + this.currentIterationId = iterationId; + return iterationId; + } + + @Override + public boolean hasNextToFinish() { + return !flyingIterations.isEmpty() && !hasNextIteration(); + } + + @Override + public long getNextFinishIterationId() { + return flyingIterations.remove(); + } + + @Override + public long getInitialIterationId() { + return initialIterationId; + } + + @Override + public IScheduledWorkerManager getSchedulerWorkerManager() { + return workerManager; + } + + @Override + public CycleResultManager getResultManager() { + return cycleResultManager; + } + + @Override + public PCC getParentContext() { + return (PCC) this.parentContext; + } + + public void setTerminateIterationId(long iterationId) { + terminateIterationId = iterationId; + } + + public void setCallbackFunction(ICallbackFunction callbackFunction) { + this.callbackFunction = callbackFunction; + } + + public Map getPrefetchEvents() { + return this.prefetchEvents; + } + + @Override + public void finish(long windowId) { + if (callbackFunction != null) { + callbackFunction.window(windowId); } + checkpoint(windowId); + } - public void setTerminateIterationId(long iterationId) { - terminateIterationId = iterationId; + @Override + public void finish() { + if (callbackFunction != null) { + callbackFunction.terminal(); } + } - public void setCallbackFunction(ICallbackFunction callbackFunction) { - this.callbackFunction = callbackFunction; - } + @Override + public List assign(C cycle) { + return workerManager.assign(cycle); + } - public Map getPrefetchEvents() { - return this.prefetchEvents; - } - - @Override - public void finish(long windowId) { - if (callbackFunction != null) { - callbackFunction.window(windowId); - } - checkpoint(windowId); - } - - @Override - public void finish() { - if (callbackFunction != null) { - callbackFunction.terminal(); - } - } - - @Override - public List assign(C cycle) { - return workerManager.assign(cycle); - } - - @Override - public void release(C cycle) { - workerManager.release(cycle); - } - - @Override - public void close(IExecutionCycle cycle) { - workerManager.close(cycle); - } + @Override + public void release(C cycle) { + workerManager.release(cycle); + } - abstract void checkpoint(long windowId); + @Override + public void close(IExecutionCycle cycle) { + workerManager.close(cycle); + } - abstract HighAvailableLevel getHaLevel(); + abstract void checkpoint(long windowId); + abstract HighAvailableLevel getHaLevel(); } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/CheckpointSchedulerContext.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/CheckpointSchedulerContext.java index f01a0f4c2..bac83cd52 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/CheckpointSchedulerContext.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/CheckpointSchedulerContext.java @@ -22,8 +22,8 @@ import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT; import static org.apache.geaflow.ha.runtime.HighAvailableLevel.CHECKPOINT; -import com.google.common.base.Preconditions; import java.util.function.Supplier; + import org.apache.geaflow.cluster.common.IReliableContext; import org.apache.geaflow.cluster.system.ClusterMetaStore; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -33,142 +33,155 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; + public class CheckpointSchedulerContext< - C extends IExecutionCycle, - PC extends IExecutionCycle, - PCC extends ICycleSchedulerContext> extends AbstractCycleSchedulerContext implements IReliableContext { - - private static final Logger LOGGER = LoggerFactory.getLogger(CheckpointSchedulerContext.class); - - private final long checkpointDuration; - private boolean isNeedFullCheckpoint = true; - private transient long currentCheckpointId; - private transient boolean isRecovered = false; - - public CheckpointSchedulerContext(C cycle, PCC parentContext) { - super(cycle, parentContext); - this.checkpointDuration = getConfig().getLong(BATCH_NUMBER_PER_CHECKPOINT); - if (parentContext != null) { - this.callbackFunction = ((AbstractCycleSchedulerContext) parentContext).callbackFunction; - ((AbstractCycleSchedulerContext) parentContext).setCallbackFunction(null); - } + C extends IExecutionCycle, + PC extends IExecutionCycle, + PCC extends ICycleSchedulerContext> + extends AbstractCycleSchedulerContext implements IReliableContext { + + private static final Logger LOGGER = LoggerFactory.getLogger(CheckpointSchedulerContext.class); + + private final long checkpointDuration; + private boolean isNeedFullCheckpoint = true; + private transient long currentCheckpointId; + private transient boolean isRecovered = false; + + public CheckpointSchedulerContext(C cycle, PCC parentContext) { + super(cycle, parentContext); + this.checkpointDuration = getConfig().getLong(BATCH_NUMBER_PER_CHECKPOINT); + if (parentContext != null) { + this.callbackFunction = + ((AbstractCycleSchedulerContext) parentContext).callbackFunction; + ((AbstractCycleSchedulerContext) parentContext).setCallbackFunction(null); } + } - @Override - public void init() { - load(); - if (!isRecovered) { - checkpoint(new CycleCheckpointFunction()); - } + @Override + public void init() { + load(); + if (!isRecovered) { + checkpoint(new CycleCheckpointFunction()); } - - @Override - public void checkpoint(long iterationId) { - this.currentCheckpointId = iterationId; - if (isNeedFullCheckpoint) { - checkpoint(new CycleCheckpointFunction()); - isNeedFullCheckpoint = false; - } - if (CheckpointUtil.needDoCheckpoint(iterationId, checkpointDuration)) { - checkpoint(new IterationIdCheckpointFunction()); - lastCheckpointId = iterationId; - } + } + + @Override + public void checkpoint(long iterationId) { + this.currentCheckpointId = iterationId; + if (isNeedFullCheckpoint) { + checkpoint(new CycleCheckpointFunction()); + isNeedFullCheckpoint = false; } - - @Override - public void load() { - long windowId = loadWindowId(cycle.getPipelineTaskId()); - init(windowId); - if (!isRecovered && windowId != CheckpointSchedulerContext.DEFAULT_INITIAL_ITERATION_ID) { - setRollback(true); - } + if (CheckpointUtil.needDoCheckpoint(iterationId, checkpointDuration)) { + checkpoint(new IterationIdCheckpointFunction()); + lastCheckpointId = iterationId; } - - /** - * Load cycle if exists, otherwise rebuild one by input func. - */ - public static AbstractCycleSchedulerContext build(long pipelineTaskId, Supplier builder) { - AbstractCycleSchedulerContext context = loadCycle(pipelineTaskId); - if (context == null) { - Preconditions.checkArgument(builder != null, "should provide function to build new context"); - context = (AbstractCycleSchedulerContext) builder.get(); - if (context == null) { - throw new GeaflowRuntimeException("build new context failed"); - } - } - return context; + } + + @Override + public void load() { + long windowId = loadWindowId(cycle.getPipelineTaskId()); + init(windowId); + if (!isRecovered && windowId != CheckpointSchedulerContext.DEFAULT_INITIAL_ITERATION_ID) { + setRollback(true); } - - private static long loadWindowId(long pipelineTaskId) { - Long lastWindowId = ClusterMetaStore.getInstance().getWindowId(pipelineTaskId); - long windowId; - if (lastWindowId == null) { - windowId = CheckpointSchedulerContext.DEFAULT_INITIAL_ITERATION_ID; - LOGGER.info("not found last success batchId, set startIterationId to {}", windowId); - } else { - // driver fo recover windowId - windowId = lastWindowId + 1; - LOGGER.info("load scheduler context, lastWindowId {}, current start windowId {}", - lastWindowId, windowId); - } - return windowId; + } + + /** Load cycle if exists, otherwise rebuild one by input func. */ + public static AbstractCycleSchedulerContext build( + long pipelineTaskId, Supplier builder) { + AbstractCycleSchedulerContext context = loadCycle(pipelineTaskId); + if (context == null) { + Preconditions.checkArgument(builder != null, "should provide function to build new context"); + context = (AbstractCycleSchedulerContext) builder.get(); + if (context == null) { + throw new GeaflowRuntimeException("build new context failed"); + } } - - private static CheckpointSchedulerContext loadCycle(long pipelineTaskId) { - CheckpointSchedulerContext context = (CheckpointSchedulerContext) ClusterMetaStore.getInstance().getCycle(pipelineTaskId); - if (context == null) { - LOGGER.info("not found recoverable cycle"); - return null; - } - context.isRecovered = true; - context.init(); - return context; + return context; + } + + private static long loadWindowId(long pipelineTaskId) { + Long lastWindowId = ClusterMetaStore.getInstance().getWindowId(pipelineTaskId); + long windowId; + if (lastWindowId == null) { + windowId = CheckpointSchedulerContext.DEFAULT_INITIAL_ITERATION_ID; + LOGGER.info("not found last success batchId, set startIterationId to {}", windowId); + } else { + // driver fo recover windowId + windowId = lastWindowId + 1; + LOGGER.info( + "load scheduler context, lastWindowId {}, current start windowId {}", + lastWindowId, + windowId); } - - @Override - public void checkpoint(IReliableContextCheckpointFunction function) { - function.doCheckpoint(this); + return windowId; + } + + private static CheckpointSchedulerContext loadCycle(long pipelineTaskId) { + CheckpointSchedulerContext context = + (CheckpointSchedulerContext) ClusterMetaStore.getInstance().getCycle(pipelineTaskId); + if (context == null) { + LOGGER.info("not found recoverable cycle"); + return null; } + context.isRecovered = true; + context.init(); + return context; + } - @Override - public void close(IExecutionCycle cycle) { - super.close(cycle); - } + @Override + public void checkpoint(IReliableContextCheckpointFunction function) { + function.doCheckpoint(this); + } - @Override - protected HighAvailableLevel getHaLevel() { - return CHECKPOINT; - } + @Override + public void close(IExecutionCycle cycle) { + super.close(cycle); + } - public boolean isRecovered() { - return isRecovered; - } + @Override + protected HighAvailableLevel getHaLevel() { + return CHECKPOINT; + } - public void setRecovered(boolean isRecovered) { - this.isRecovered = isRecovered; - } + public boolean isRecovered() { + return isRecovered; + } + + public void setRecovered(boolean isRecovered) { + this.isRecovered = isRecovered; + } - public class IterationIdCheckpointFunction implements IReliableContextCheckpointFunction { + public class IterationIdCheckpointFunction implements IReliableContextCheckpointFunction { - @Override - public void doCheckpoint(IReliableContext context) { - long checkpointId = ((CheckpointSchedulerContext) context).currentCheckpointId; - ClusterMetaStore.getInstance().saveWindowId(checkpointId, - (((CheckpointSchedulerContext) context).getCycle().getPipelineTaskId())); - LOGGER.info("cycle {} do checkpoint {}", - ((CheckpointSchedulerContext) context).getCycle().getCycleId(), checkpointId); - } + @Override + public void doCheckpoint(IReliableContext context) { + long checkpointId = ((CheckpointSchedulerContext) context).currentCheckpointId; + ClusterMetaStore.getInstance() + .saveWindowId( + checkpointId, + (((CheckpointSchedulerContext) context).getCycle().getPipelineTaskId())); + LOGGER.info( + "cycle {} do checkpoint {}", + ((CheckpointSchedulerContext) context).getCycle().getCycleId(), + checkpointId); } + } - public class CycleCheckpointFunction implements IReliableContextCheckpointFunction { + public class CycleCheckpointFunction implements IReliableContextCheckpointFunction { - @Override - public void doCheckpoint(IReliableContext context) { - long checkpointId = ((CheckpointSchedulerContext) context).currentCheckpointId; - ClusterMetaStore.getInstance().saveCycle(context, - ((CheckpointSchedulerContext) context).getCycle().getPipelineTaskId()).flush(); - LOGGER.info("cycle {} do checkpoint {} for full context", - ((CheckpointSchedulerContext) context).getCycle().getCycleId(), checkpointId); - } + @Override + public void doCheckpoint(IReliableContext context) { + long checkpointId = ((CheckpointSchedulerContext) context).currentCheckpointId; + ClusterMetaStore.getInstance() + .saveCycle(context, ((CheckpointSchedulerContext) context).getCycle().getPipelineTaskId()) + .flush(); + LOGGER.info( + "cycle {} do checkpoint {} for full context", + ((CheckpointSchedulerContext) context).getCycle().getCycleId(), + checkpointId); } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/CycleSchedulerContextFactory.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/CycleSchedulerContextFactory.java index 8be4e97e8..de4d5ea60 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/CycleSchedulerContextFactory.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/CycleSchedulerContextFactory.java @@ -20,6 +20,7 @@ package org.apache.geaflow.runtime.core.scheduler.context; import java.util.function.Supplier; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.runtime.core.scheduler.cycle.ExecutionCycleType; import org.apache.geaflow.runtime.core.scheduler.cycle.IExecutionCycle; @@ -28,35 +29,42 @@ public class CycleSchedulerContextFactory { - private static final Logger LOGGER = LoggerFactory.getLogger(CycleSchedulerContextFactory.class); + private static final Logger LOGGER = LoggerFactory.getLogger(CycleSchedulerContextFactory.class); - public static > - ICycleSchedulerContext loadOrCreate(long pipelineTaskId, Supplier> buildFunc) { - return CheckpointSchedulerContext.build(pipelineTaskId, buildFunc); - } + public static < + C extends IExecutionCycle, + PC extends IExecutionCycle, + PCC extends ICycleSchedulerContext> + ICycleSchedulerContext loadOrCreate( + long pipelineTaskId, Supplier> buildFunc) { + return CheckpointSchedulerContext.build(pipelineTaskId, buildFunc); + } - public static ICycleSchedulerContext create(IExecutionCycle cycle, ICycleSchedulerContext parent) { - AbstractCycleSchedulerContext context; - switch (cycle.getHighAvailableLevel()) { - case CHECKPOINT: - LOGGER.info("create checkpoint scheduler context"); - context = new CheckpointSchedulerContext(cycle, parent); - break; - case REDO: - if (cycle.getType() == ExecutionCycleType.ITERATION - || cycle.getType() == ExecutionCycleType.ITERATION_WITH_AGG) { - LOGGER.info("create iteration redo scheduler context"); - context = new IterationRedoSchedulerContext(cycle, parent); - } else { - LOGGER.info("create redo scheduler context"); - context = new RedoSchedulerContext(cycle, parent); - } - break; - default: - throw new GeaflowRuntimeException(String.format("not support ha level %s for cycle %s", - cycle.getHighAvailableLevel(), cycle.getCycleId())); + public static ICycleSchedulerContext create( + IExecutionCycle cycle, ICycleSchedulerContext parent) { + AbstractCycleSchedulerContext context; + switch (cycle.getHighAvailableLevel()) { + case CHECKPOINT: + LOGGER.info("create checkpoint scheduler context"); + context = new CheckpointSchedulerContext(cycle, parent); + break; + case REDO: + if (cycle.getType() == ExecutionCycleType.ITERATION + || cycle.getType() == ExecutionCycleType.ITERATION_WITH_AGG) { + LOGGER.info("create iteration redo scheduler context"); + context = new IterationRedoSchedulerContext(cycle, parent); + } else { + LOGGER.info("create redo scheduler context"); + context = new RedoSchedulerContext(cycle, parent); } - context.init(); - return context; + break; + default: + throw new GeaflowRuntimeException( + String.format( + "not support ha level %s for cycle %s", + cycle.getHighAvailableLevel(), cycle.getCycleId())); } + context.init(); + return context; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/ICycleSchedulerContext.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/ICycleSchedulerContext.java index 7b2891e09..407200e09 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/ICycleSchedulerContext.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/ICycleSchedulerContext.java @@ -22,6 +22,7 @@ import java.io.Serializable; import java.util.List; import java.util.Map; + import org.apache.geaflow.cluster.resourcemanager.WorkerInfo; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.runtime.core.scheduler.ExecutableEventIterator.ExecutableEvent; @@ -30,134 +31,84 @@ import org.apache.geaflow.runtime.core.scheduler.resource.IScheduledWorkerManager; public interface ICycleSchedulerContext< - C extends IExecutionCycle, - PC extends IExecutionCycle, - PCC extends ICycleSchedulerContext> extends Serializable { - - /** - * Returns execution cycle. - */ - C getCycle(); - - PCC getParentContext(); - - /** - * Returns execution config. - */ - Configuration getConfig(); - - /** - * Returns whether cycle is finished. - */ - boolean isCycleFinished(); - - /** - * Returns when cycle is recovered. - */ - boolean isRecovered(); - - /** - * Returns whether cycle need rollback. - */ - boolean isRollback(); - - /** - * Returns whether enable prefetch. - */ - boolean isPrefetch(); - - /** - * Returns current iteration id. - */ - long getCurrentIterationId(); - - /** - * Returns finish iteration id. - */ - long getFinishIterationId(); - - /** - * Check whether has next iteration. - */ - boolean hasNextIteration(); - - /** - * Returns next iteration id. - */ - long getNextIterationId(); - - /** - * Check whether has next cycle to finish. - */ - boolean hasNextToFinish(); - - /** - * Returns next finish iteration id. - */ - long getNextFinishIterationId(); - - /** - * Returns initial iteration id. - */ - long getInitialIterationId(); - - /** - * Returns cycle result manager. - */ - CycleResultManager getResultManager(); - - /** - * Assign workers for cycle. - */ - List assign(C cycle); - - /** - * Release worker for cycle. - */ - void release(C cycle); - - /** - * Finish the windowId iteration. - */ - void finish(long windowId); - - /** - * Finish cycle. - */ - void finish(); - - /** - * Close workerManager. - */ - void close(IExecutionCycle cycle); - - /** - * Returns scheduler worker manager. - */ - IScheduledWorkerManager getSchedulerWorkerManager(); - - /** - * Returns prefetch events needed to be finished. - */ - Map getPrefetchEvents(); - - enum SchedulerState { - /** - * Init state. - */ - INIT, - /** - * Execute state. - */ - EXECUTE, - /** - * Finish state. - */ - FINISH, - /** - * Rollback state. - */ - ROLLBACK, - } + C extends IExecutionCycle, + PC extends IExecutionCycle, + PCC extends ICycleSchedulerContext> + extends Serializable { + + /** Returns execution cycle. */ + C getCycle(); + + PCC getParentContext(); + + /** Returns execution config. */ + Configuration getConfig(); + + /** Returns whether cycle is finished. */ + boolean isCycleFinished(); + + /** Returns when cycle is recovered. */ + boolean isRecovered(); + + /** Returns whether cycle need rollback. */ + boolean isRollback(); + + /** Returns whether enable prefetch. */ + boolean isPrefetch(); + + /** Returns current iteration id. */ + long getCurrentIterationId(); + + /** Returns finish iteration id. */ + long getFinishIterationId(); + + /** Check whether has next iteration. */ + boolean hasNextIteration(); + + /** Returns next iteration id. */ + long getNextIterationId(); + + /** Check whether has next cycle to finish. */ + boolean hasNextToFinish(); + + /** Returns next finish iteration id. */ + long getNextFinishIterationId(); + + /** Returns initial iteration id. */ + long getInitialIterationId(); + + /** Returns cycle result manager. */ + CycleResultManager getResultManager(); + + /** Assign workers for cycle. */ + List assign(C cycle); + + /** Release worker for cycle. */ + void release(C cycle); + + /** Finish the windowId iteration. */ + void finish(long windowId); + + /** Finish cycle. */ + void finish(); + + /** Close workerManager. */ + void close(IExecutionCycle cycle); + + /** Returns scheduler worker manager. */ + IScheduledWorkerManager getSchedulerWorkerManager(); + + /** Returns prefetch events needed to be finished. */ + Map getPrefetchEvents(); + enum SchedulerState { + /** Init state. */ + INIT, + /** Execute state. */ + EXECUTE, + /** Finish state. */ + FINISH, + /** Rollback state. */ + ROLLBACK, + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/IterationRedoSchedulerContext.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/IterationRedoSchedulerContext.java index b271b9691..e293614cf 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/IterationRedoSchedulerContext.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/IterationRedoSchedulerContext.java @@ -24,21 +24,24 @@ import org.slf4j.LoggerFactory; public class IterationRedoSchedulerContext< - C extends IExecutionCycle, - PC extends IExecutionCycle, - PCC extends ICycleSchedulerContext> extends RedoSchedulerContext { + C extends IExecutionCycle, + PC extends IExecutionCycle, + PCC extends ICycleSchedulerContext> + extends RedoSchedulerContext { - private static final Logger LOGGER = LoggerFactory.getLogger(IterationRedoSchedulerContext.class); + private static final Logger LOGGER = LoggerFactory.getLogger(IterationRedoSchedulerContext.class); - public IterationRedoSchedulerContext(C cycle, PCC parentContext) { - super(cycle, parentContext); - } + public IterationRedoSchedulerContext(C cycle, PCC parentContext) { + super(cycle, parentContext); + } - public void init(long startIterationId) { - super.init(DEFAULT_INITIAL_ITERATION_ID); - this.initialIterationId = parentContext.getCurrentIterationId(); - this.finishIterationId = cycle.getIterationCount(); - LOGGER.info("init cycle for iteration, initialIterationId {}, finishIterationId {}", - this.initialIterationId, this.finishIterationId); - } + public void init(long startIterationId) { + super.init(DEFAULT_INITIAL_ITERATION_ID); + this.initialIterationId = parentContext.getCurrentIterationId(); + this.finishIterationId = cycle.getIterationCount(); + LOGGER.info( + "init cycle for iteration, initialIterationId {}, finishIterationId {}", + this.initialIterationId, + this.finishIterationId); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/RedoSchedulerContext.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/RedoSchedulerContext.java index d866d518d..210634065 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/RedoSchedulerContext.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/context/RedoSchedulerContext.java @@ -23,26 +23,27 @@ import org.apache.geaflow.runtime.core.scheduler.cycle.IExecutionCycle; public class RedoSchedulerContext< - C extends IExecutionCycle, - PC extends IExecutionCycle, - PCC extends ICycleSchedulerContext> extends AbstractCycleSchedulerContext { + C extends IExecutionCycle, + PC extends IExecutionCycle, + PCC extends ICycleSchedulerContext> + extends AbstractCycleSchedulerContext { - public RedoSchedulerContext(C cycle, PCC parentContext) { - super(cycle, parentContext); - } + public RedoSchedulerContext(C cycle, PCC parentContext) { + super(cycle, parentContext); + } - @Override - public void init() { - super.init(); - } + @Override + public void init() { + super.init(); + } - @Override - protected HighAvailableLevel getHaLevel() { - return HighAvailableLevel.REDO; - } + @Override + protected HighAvailableLevel getHaLevel() { + return HighAvailableLevel.REDO; + } - @Override - public void checkpoint(long iterationId) { - lastCheckpointId = iterationId; - } + @Override + public void checkpoint(long iterationId) { + lastCheckpointId = iterationId; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/AbstractExecutionCycle.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/AbstractExecutionCycle.java index 83e039b82..a462e405e 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/AbstractExecutionCycle.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/AbstractExecutionCycle.java @@ -19,90 +19,98 @@ package org.apache.geaflow.runtime.core.scheduler.cycle; -import com.google.common.base.Preconditions; import org.apache.geaflow.common.config.Configuration; +import com.google.common.base.Preconditions; + public abstract class AbstractExecutionCycle implements IExecutionCycle { - protected long pipelineId; - protected long pipelineTaskId; - protected long schedulerId; - protected String pipelineName; - protected int cycleId; - protected int flyingCount; - protected long iterationCount; - private Configuration config; - - public AbstractExecutionCycle(long schedulerId, long pipelineId, long pipelineTaskId, String pipelineName, - int cycleId, int flyingCount, long iterationCount, - Configuration config) { - this.pipelineName = pipelineName; - this.cycleId = cycleId; - this.flyingCount = flyingCount; - this.iterationCount = iterationCount; - this.pipelineId = pipelineId; - this.pipelineTaskId = pipelineTaskId; - this.schedulerId = schedulerId; - this.config = config; - - Preconditions.checkArgument(flyingCount > 0, - "cycle flyingCount should be positive, current value %s", flyingCount); - Preconditions.checkArgument(iterationCount > 0, - "cycle iterationCount should be positive, current value %s", iterationCount); - } - - - public void setPipelineId(long pipelineId) { - this.pipelineId = pipelineId; - } - - public long getPipelineId() { - return pipelineId; - } - - public void setPipelineTaskId(long pipelineTaskId) { - this.pipelineTaskId = pipelineTaskId; - } - - @Override - public long getPipelineTaskId() { - return pipelineTaskId; - } - - public void setSchedulerId(long schedulerId) { - this.schedulerId = schedulerId; - } - - @Override - public long getSchedulerId() { - return schedulerId; - } - - public void setPipelineName(String pipelineName) { - this.pipelineName = pipelineName; - } - - @Override - public String getPipelineName() { - return pipelineName; - } - - @Override - public int getCycleId() { - return cycleId; - } - - @Override - public int getFlyingCount() { - return flyingCount; - } - - @Override - public long getIterationCount() { - return iterationCount; - } - - public Configuration getConfig() { - return config; - } + protected long pipelineId; + protected long pipelineTaskId; + protected long schedulerId; + protected String pipelineName; + protected int cycleId; + protected int flyingCount; + protected long iterationCount; + private Configuration config; + + public AbstractExecutionCycle( + long schedulerId, + long pipelineId, + long pipelineTaskId, + String pipelineName, + int cycleId, + int flyingCount, + long iterationCount, + Configuration config) { + this.pipelineName = pipelineName; + this.cycleId = cycleId; + this.flyingCount = flyingCount; + this.iterationCount = iterationCount; + this.pipelineId = pipelineId; + this.pipelineTaskId = pipelineTaskId; + this.schedulerId = schedulerId; + this.config = config; + + Preconditions.checkArgument( + flyingCount > 0, "cycle flyingCount should be positive, current value %s", flyingCount); + Preconditions.checkArgument( + iterationCount > 0, + "cycle iterationCount should be positive, current value %s", + iterationCount); + } + + public void setPipelineId(long pipelineId) { + this.pipelineId = pipelineId; + } + + public long getPipelineId() { + return pipelineId; + } + + public void setPipelineTaskId(long pipelineTaskId) { + this.pipelineTaskId = pipelineTaskId; + } + + @Override + public long getPipelineTaskId() { + return pipelineTaskId; + } + + public void setSchedulerId(long schedulerId) { + this.schedulerId = schedulerId; + } + + @Override + public long getSchedulerId() { + return schedulerId; + } + + public void setPipelineName(String pipelineName) { + this.pipelineName = pipelineName; + } + + @Override + public String getPipelineName() { + return pipelineName; + } + + @Override + public int getCycleId() { + return cycleId; + } + + @Override + public int getFlyingCount() { + return flyingCount; + } + + @Override + public long getIterationCount() { + return iterationCount; + } + + public Configuration getConfig() { + return config; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/CollectExecutionNodeCycle.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/CollectExecutionNodeCycle.java index ab8cf4920..f4c10faee 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/CollectExecutionNodeCycle.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/CollectExecutionNodeCycle.java @@ -24,9 +24,23 @@ public class CollectExecutionNodeCycle extends ExecutionNodeCycle { - public CollectExecutionNodeCycle(long schedulerId, long pipelineId, long pipelineTaskId, - String pipelineName, ExecutionVertexGroup vertexGroup, - Configuration config, String driverId, int driverIndex) { - super(schedulerId, pipelineId, pipelineTaskId, pipelineName, vertexGroup, config, driverId, driverIndex); - } + public CollectExecutionNodeCycle( + long schedulerId, + long pipelineId, + long pipelineTaskId, + String pipelineName, + ExecutionVertexGroup vertexGroup, + Configuration config, + String driverId, + int driverIndex) { + super( + schedulerId, + pipelineId, + pipelineTaskId, + pipelineName, + vertexGroup, + config, + driverId, + driverIndex); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/ExecutionCycleBuilder.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/ExecutionCycleBuilder.java index a6d3c50c3..ee4046a73 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/ExecutionCycleBuilder.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/ExecutionCycleBuilder.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.core.graph.ExecutionGraph; import org.apache.geaflow.core.graph.ExecutionTask; @@ -32,97 +33,132 @@ public class ExecutionCycleBuilder { - private static final int GRAPH_CYCLE_ID = 0; + private static final int GRAPH_CYCLE_ID = 0; - /** - * Build cycle by execution graph. - */ - public static IExecutionCycle buildExecutionCycle(ExecutionGraph executionGraph, - Map> vertex2Tasks, - Configuration config, - long pipelineId, - long pipelineTaskId, - String name, - long schedulerId, - String driverId, - int driverIndex, - boolean skipCheckpoint) { + /** Build cycle by execution graph. */ + public static IExecutionCycle buildExecutionCycle( + ExecutionGraph executionGraph, + Map> vertex2Tasks, + Configuration config, + long pipelineId, + long pipelineTaskId, + String name, + long schedulerId, + String driverId, + int driverIndex, + boolean skipCheckpoint) { - int flyingCount = executionGraph.getCycleGroupMeta().getFlyingCount(); - long iterationCount = executionGraph.getCycleGroupMeta().getIterationCount(); - ExecutionGraphCycle graphCycle = new ExecutionGraphCycle(schedulerId, pipelineId, - pipelineTaskId, name, GRAPH_CYCLE_ID, flyingCount, iterationCount, config, driverId, driverIndex); - for (ExecutionVertexGroup vertexGroup : executionGraph.getVertexGroupMap().values()) { - ExecutionNodeCycle nodeCycle = buildExecutionCycle(vertexGroup, - vertex2Tasks, config, pipelineId, pipelineTaskId, name, schedulerId, driverId, driverIndex); - graphCycle.addCycle(nodeCycle, skipCheckpoint); - } - return graphCycle; + int flyingCount = executionGraph.getCycleGroupMeta().getFlyingCount(); + long iterationCount = executionGraph.getCycleGroupMeta().getIterationCount(); + ExecutionGraphCycle graphCycle = + new ExecutionGraphCycle( + schedulerId, + pipelineId, + pipelineTaskId, + name, + GRAPH_CYCLE_ID, + flyingCount, + iterationCount, + config, + driverId, + driverIndex); + for (ExecutionVertexGroup vertexGroup : executionGraph.getVertexGroupMap().values()) { + ExecutionNodeCycle nodeCycle = + buildExecutionCycle( + vertexGroup, + vertex2Tasks, + config, + pipelineId, + pipelineTaskId, + name, + schedulerId, + driverId, + driverIndex); + graphCycle.addCycle(nodeCycle, skipCheckpoint); } + return graphCycle; + } - private static ExecutionNodeCycle buildExecutionCycle(ExecutionVertexGroup vertexGroup, - Map> vertex2Tasks, - Configuration config, - long pipelineId, - long pipelineTaskId, - String name, - long schedulerId, - String driverId, - int driverIndex) { - ExecutionNodeCycle cycle; - if (vertexGroup.getVertexMap().size() == 1 - && vertexGroup.getVertexMap().values().iterator().next().getChainTailType() == VertexType.collect) { - cycle = new CollectExecutionNodeCycle(schedulerId, pipelineId, pipelineTaskId, name, - vertexGroup, config, driverId, driverIndex); - } else { - cycle = new ExecutionNodeCycle(schedulerId, pipelineId, pipelineTaskId, name, - vertexGroup, config, driverId, driverIndex); - } - List allTasks = new ArrayList<>(); - List headTasks = new ArrayList<>(); - List tailTasks = new ArrayList<>(); - List opNames = new ArrayList<>(); - - for (ExecutionVertex vertex : vertexGroup.getVertexMap().values()) { - boolean isHead = false; - // is head - if (vertexGroup.getHeadVertexIds().contains(vertex.getVertexId())) { - isHead = true; - } - boolean isTail = false; - // is tail - if (vertexGroup.getTailVertexIds().contains(vertex.getVertexId())) { - isTail = true; - opNames.add(vertex.getName()); - } + private static ExecutionNodeCycle buildExecutionCycle( + ExecutionVertexGroup vertexGroup, + Map> vertex2Tasks, + Configuration config, + long pipelineId, + long pipelineTaskId, + String name, + long schedulerId, + String driverId, + int driverIndex) { + ExecutionNodeCycle cycle; + if (vertexGroup.getVertexMap().size() == 1 + && vertexGroup.getVertexMap().values().iterator().next().getChainTailType() + == VertexType.collect) { + cycle = + new CollectExecutionNodeCycle( + schedulerId, + pipelineId, + pipelineTaskId, + name, + vertexGroup, + config, + driverId, + driverIndex); + } else { + cycle = + new ExecutionNodeCycle( + schedulerId, + pipelineId, + pipelineTaskId, + name, + vertexGroup, + config, + driverId, + driverIndex); + } + List allTasks = new ArrayList<>(); + List headTasks = new ArrayList<>(); + List tailTasks = new ArrayList<>(); + List opNames = new ArrayList<>(); - List tasks = vertex2Tasks.get(vertex.getVertexId()); - allTasks.addAll(tasks); + for (ExecutionVertex vertex : vertexGroup.getVertexMap().values()) { + boolean isHead = false; + // is head + if (vertexGroup.getHeadVertexIds().contains(vertex.getVertexId())) { + isHead = true; + } + boolean isTail = false; + // is tail + if (vertexGroup.getTailVertexIds().contains(vertex.getVertexId())) { + isTail = true; + opNames.add(vertex.getName()); + } - for (ExecutionTask task : tasks) { - if (isHead && isTail) { - task.setExecutionTaskType(ExecutionTaskType.singularity); - headTasks.add(task); - tailTasks.add(task); - } else if (isHead) { - task.setExecutionTaskType(ExecutionTaskType.head); - headTasks.add(task); - } else if (isTail) { - task.setExecutionTaskType(ExecutionTaskType.tail); - tailTasks.add(task); - } else { - task.setExecutionTaskType(ExecutionTaskType.middle); - } - task.setProcessor(vertex.getProcessor()); + List tasks = vertex2Tasks.get(vertex.getVertexId()); + allTasks.addAll(tasks); - } + for (ExecutionTask task : tasks) { + if (isHead && isTail) { + task.setExecutionTaskType(ExecutionTaskType.singularity); + headTasks.add(task); + tailTasks.add(task); + } else if (isHead) { + task.setExecutionTaskType(ExecutionTaskType.head); + headTasks.add(task); + } else if (isTail) { + task.setExecutionTaskType(ExecutionTaskType.tail); + tailTasks.add(task); + } else { + task.setExecutionTaskType(ExecutionTaskType.middle); } - - cycle.setName(String.join("|", opNames)); - cycle.setTasks(allTasks); - cycle.setCycleHeads(headTasks); - cycle.setCycleTails(tailTasks); - cycle.setVertexIdToTasks(vertex2Tasks); - return cycle; + task.setProcessor(vertex.getProcessor()); + } } + + cycle.setName(String.join("|", opNames)); + cycle.setTasks(allTasks); + cycle.setCycleHeads(headTasks); + cycle.setCycleTails(tailTasks); + cycle.setVertexIdToTasks(vertex2Tasks); + return cycle; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/ExecutionCycleType.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/ExecutionCycleType.java index ecf33f9d4..0fd3d0f59 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/ExecutionCycleType.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/ExecutionCycleType.java @@ -21,23 +21,17 @@ public enum ExecutionCycleType { - /** - * A cycle that contains whole execution graph. - */ - GRAPH, + /** A cycle that contains whole execution graph. */ + GRAPH, - /** - * A pipeline cycle fetch data and finally transfer output to downstream or sink op. - */ - PIPELINE, + /** A pipeline cycle fetch data and finally transfer output to downstream or sink op. */ + PIPELINE, - /** - * Iteration cycle that container data flow loop from cycle tail to cycle head. - */ - ITERATION, + /** Iteration cycle that container data flow loop from cycle tail to cycle head. */ + ITERATION, - /** - * Iteration cycle that container data flow loop from cycle tail to cycle head with aggregation. - */ - ITERATION_WITH_AGG, + /** + * Iteration cycle that container data flow loop from cycle tail to cycle head with aggregation. + */ + ITERATION_WITH_AGG, } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/ExecutionGraphCycle.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/ExecutionGraphCycle.java index b2d60dd58..ceb3e5302 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/ExecutionGraphCycle.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/ExecutionGraphCycle.java @@ -23,78 +23,103 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.ha.runtime.HighAvailableLevel; public class ExecutionGraphCycle extends AbstractExecutionCycle { - private final String driverId; - private final int driverIndex; - private final Map cycleMap; - private final Map> cycleChildren; - private final Map> cycleParents; - private HighAvailableLevel haLevel; - - public ExecutionGraphCycle(long schedulerId, long pipelineId, long pipelineTaskId, - String pipelineName, int cycleId, int flyingCount, long iterationCount, - Configuration config, String driverId, int driverIndex) { - super(schedulerId, pipelineId, pipelineTaskId, pipelineName, cycleId, flyingCount, iterationCount, config); - this.driverId = driverId; - this.driverIndex = driverIndex; - this.cycleMap = new HashMap<>(); - this.cycleChildren = new HashMap<>(); - this.cycleParents = new HashMap<>(); - this.haLevel = HighAvailableLevel.REDO; - } + private final String driverId; + private final int driverIndex; + private final Map cycleMap; + private final Map> cycleChildren; + private final Map> cycleParents; + private HighAvailableLevel haLevel; - @Override - public ExecutionCycleType getType() { - return ExecutionCycleType.GRAPH; - } + public ExecutionGraphCycle( + long schedulerId, + long pipelineId, + long pipelineTaskId, + String pipelineName, + int cycleId, + int flyingCount, + long iterationCount, + Configuration config, + String driverId, + int driverIndex) { + super( + schedulerId, + pipelineId, + pipelineTaskId, + pipelineName, + cycleId, + flyingCount, + iterationCount, + config); + this.driverId = driverId; + this.driverIndex = driverIndex; + this.cycleMap = new HashMap<>(); + this.cycleChildren = new HashMap<>(); + this.cycleParents = new HashMap<>(); + this.haLevel = HighAvailableLevel.REDO; + } - @Override - public String getDriverId() { - return driverId; - } + @Override + public ExecutionCycleType getType() { + return ExecutionCycleType.GRAPH; + } - @Override - public int getDriverIndex() { - return driverIndex; - } + @Override + public String getDriverId() { + return driverId; + } - @Override - public HighAvailableLevel getHighAvailableLevel() { - return haLevel; - } + @Override + public int getDriverIndex() { + return driverIndex; + } - public void addCycle(IExecutionCycle cycle, boolean skipCheckpoint) { - if (cycleMap.containsKey(cycle.getCycleId())) { - throw new GeaflowRuntimeException(String.format("cycle %d already added", cycle.getCycleId())); - } - cycleMap.put(cycle.getCycleId(), cycle); - cycleParents.put(cycle.getCycleId(), new ArrayList<>()); - cycleChildren.put(cycle.getCycleId(), new ArrayList<>()); - - ExecutionNodeCycle nodeCycle = (ExecutionNodeCycle) cycle; - cycleParents.get(cycle.getCycleId()).addAll(nodeCycle.getVertexGroup().getParentVertexGroupIds()); - cycleChildren.get(cycle.getCycleId()).addAll(nodeCycle.getVertexGroup().getChildrenVertexGroupIds()); - - if (!skipCheckpoint && (iterationCount > 1 && haLevel != HighAvailableLevel.CHECKPOINT - && (cycle.getType() == ExecutionCycleType.ITERATION || cycle.getType() == ExecutionCycleType.ITERATION_WITH_AGG))) { - haLevel = HighAvailableLevel.CHECKPOINT; - } - } + @Override + public HighAvailableLevel getHighAvailableLevel() { + return haLevel; + } - public Map getCycleMap() { - return cycleMap; + public void addCycle(IExecutionCycle cycle, boolean skipCheckpoint) { + if (cycleMap.containsKey(cycle.getCycleId())) { + throw new GeaflowRuntimeException( + String.format("cycle %d already added", cycle.getCycleId())); } + cycleMap.put(cycle.getCycleId(), cycle); + cycleParents.put(cycle.getCycleId(), new ArrayList<>()); + cycleChildren.put(cycle.getCycleId(), new ArrayList<>()); - public Map> getCycleChildren() { - return cycleChildren; - } + ExecutionNodeCycle nodeCycle = (ExecutionNodeCycle) cycle; + cycleParents + .get(cycle.getCycleId()) + .addAll(nodeCycle.getVertexGroup().getParentVertexGroupIds()); + cycleChildren + .get(cycle.getCycleId()) + .addAll(nodeCycle.getVertexGroup().getChildrenVertexGroupIds()); - public Map> getCycleParents() { - return cycleParents; + if (!skipCheckpoint + && (iterationCount > 1 + && haLevel != HighAvailableLevel.CHECKPOINT + && (cycle.getType() == ExecutionCycleType.ITERATION + || cycle.getType() == ExecutionCycleType.ITERATION_WITH_AGG))) { + haLevel = HighAvailableLevel.CHECKPOINT; } + } + + public Map getCycleMap() { + return cycleMap; + } + + public Map> getCycleChildren() { + return cycleChildren; + } + + public Map> getCycleParents() { + return cycleParents; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/ExecutionNodeCycle.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/ExecutionNodeCycle.java index 6a8fbfefc..3b6edf255 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/ExecutionNodeCycle.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/ExecutionNodeCycle.java @@ -23,6 +23,7 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.core.graph.ExecutionTask; import org.apache.geaflow.core.graph.ExecutionVertexGroup; @@ -30,141 +31,154 @@ public class ExecutionNodeCycle extends AbstractExecutionCycle { - private String name; - private ExecutionCycleType type; - private String driverId; - private int driverIndex; - private HighAvailableLevel highAvailableLevel; - - private ExecutionVertexGroup vertexGroup; - private List tasks; - private List cycleHeads; - private List cycleTails; - private Set cycleHeadTaskIds; - private Map> vertexIdToTasks; - private boolean iterative; - private boolean isPipelineDataLoop; - private boolean isCollectResult; - private transient boolean workerAssigned; - - public ExecutionNodeCycle(long schedulerId, long pipelineId, long pipelineTaskId, String pipelineName, - ExecutionVertexGroup vertexGroup, - Configuration config, String driverId, int driverIndex) { - super(schedulerId, pipelineId, pipelineTaskId, pipelineName, vertexGroup.getGroupId(), - vertexGroup.getCycleGroupMeta().getFlyingCount(), vertexGroup.getCycleGroupMeta().getIterationCount(), - config); - this.vertexGroup = vertexGroup; - if (vertexGroup.getCycleGroupMeta().isIterative()) { - if (vertexGroup.getVertexMap().size() > 1) { - this.type = ExecutionCycleType.ITERATION_WITH_AGG; - } else { - this.type = ExecutionCycleType.ITERATION; - } - this.iterative = true; - } else { - this.type = ExecutionCycleType.PIPELINE; - } - this.isPipelineDataLoop = vertexGroup.getCycleGroupMeta().isIterative(); - this.driverId = driverId; - this.driverIndex = driverIndex; - if (!vertexGroup.getCycleGroupMeta().isIterative() && vertexGroup.getCycleGroupMeta().getIterationCount() > 1) { - this.highAvailableLevel = HighAvailableLevel.CHECKPOINT; - } else { - this.highAvailableLevel = HighAvailableLevel.REDO; - } - } - - public String getName() { - return name; - } - - public void setName(String name) { - this.name = name; - } - - public ExecutionVertexGroup getVertexGroup() { - return vertexGroup; - } - - public List getTasks() { - return tasks; - } - - public void setTasks(List tasks) { - this.tasks = tasks; - } - - public List getCycleHeads() { - return cycleHeads; - } - - public void setCycleHeads(List cycleHeads) { - this.cycleHeads = cycleHeads; - this.cycleHeadTaskIds = cycleHeads.stream().map(ExecutionTask::getTaskId).collect(Collectors.toSet()); - } - - public List getCycleTails() { - return cycleTails; - } - - public void setCycleTails(List cycleTails) { - this.cycleTails = cycleTails; - } - - public Map> getVertexIdToTasks() { - return vertexIdToTasks; - } - - public void setVertexIdToTasks(Map> vertexIdToTasks) { - this.vertexIdToTasks = vertexIdToTasks; - } - - @Override - public ExecutionCycleType getType() { - return type; - } - - @Override - public String getDriverId() { - return driverId; - } - - @Override - public int getDriverIndex() { - return driverIndex; - } - - @Override - public HighAvailableLevel getHighAvailableLevel() { - return highAvailableLevel; - } - - public boolean isIterative() { - return this.iterative; - } - - public boolean isPipelineDataLoop() { - return isPipelineDataLoop; - } - - public boolean isCollectResult() { - return isCollectResult; - } - - public void setCollectResult(boolean collectResult) { - isCollectResult = collectResult; - } - - public boolean isWorkerAssigned() { - return workerAssigned; - } - - public void setWorkerAssigned(boolean workerAssigned) { - this.workerAssigned = workerAssigned; - } - - public boolean isHeadTask(int taskId) { - return this.cycleHeadTaskIds.contains(taskId); - } - + private String name; + private ExecutionCycleType type; + private String driverId; + private int driverIndex; + private HighAvailableLevel highAvailableLevel; + + private ExecutionVertexGroup vertexGroup; + private List tasks; + private List cycleHeads; + private List cycleTails; + private Set cycleHeadTaskIds; + private Map> vertexIdToTasks; + private boolean iterative; + private boolean isPipelineDataLoop; + private boolean isCollectResult; + private transient boolean workerAssigned; + + public ExecutionNodeCycle( + long schedulerId, + long pipelineId, + long pipelineTaskId, + String pipelineName, + ExecutionVertexGroup vertexGroup, + Configuration config, + String driverId, + int driverIndex) { + super( + schedulerId, + pipelineId, + pipelineTaskId, + pipelineName, + vertexGroup.getGroupId(), + vertexGroup.getCycleGroupMeta().getFlyingCount(), + vertexGroup.getCycleGroupMeta().getIterationCount(), + config); + this.vertexGroup = vertexGroup; + if (vertexGroup.getCycleGroupMeta().isIterative()) { + if (vertexGroup.getVertexMap().size() > 1) { + this.type = ExecutionCycleType.ITERATION_WITH_AGG; + } else { + this.type = ExecutionCycleType.ITERATION; + } + this.iterative = true; + } else { + this.type = ExecutionCycleType.PIPELINE; + } + this.isPipelineDataLoop = vertexGroup.getCycleGroupMeta().isIterative(); + this.driverId = driverId; + this.driverIndex = driverIndex; + if (!vertexGroup.getCycleGroupMeta().isIterative() + && vertexGroup.getCycleGroupMeta().getIterationCount() > 1) { + this.highAvailableLevel = HighAvailableLevel.CHECKPOINT; + } else { + this.highAvailableLevel = HighAvailableLevel.REDO; + } + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public ExecutionVertexGroup getVertexGroup() { + return vertexGroup; + } + + public List getTasks() { + return tasks; + } + + public void setTasks(List tasks) { + this.tasks = tasks; + } + + public List getCycleHeads() { + return cycleHeads; + } + + public void setCycleHeads(List cycleHeads) { + this.cycleHeads = cycleHeads; + this.cycleHeadTaskIds = + cycleHeads.stream().map(ExecutionTask::getTaskId).collect(Collectors.toSet()); + } + + public List getCycleTails() { + return cycleTails; + } + + public void setCycleTails(List cycleTails) { + this.cycleTails = cycleTails; + } + + public Map> getVertexIdToTasks() { + return vertexIdToTasks; + } + + public void setVertexIdToTasks(Map> vertexIdToTasks) { + this.vertexIdToTasks = vertexIdToTasks; + } + + @Override + public ExecutionCycleType getType() { + return type; + } + + @Override + public String getDriverId() { + return driverId; + } + + @Override + public int getDriverIndex() { + return driverIndex; + } + + @Override + public HighAvailableLevel getHighAvailableLevel() { + return highAvailableLevel; + } + + public boolean isIterative() { + return this.iterative; + } + + public boolean isPipelineDataLoop() { + return isPipelineDataLoop; + } + + public boolean isCollectResult() { + return isCollectResult; + } + + public void setCollectResult(boolean collectResult) { + isCollectResult = collectResult; + } + + public boolean isWorkerAssigned() { + return workerAssigned; + } + + public void setWorkerAssigned(boolean workerAssigned) { + this.workerAssigned = workerAssigned; + } + + public boolean isHeadTask(int taskId) { + return this.cycleHeadTaskIds.contains(taskId); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/IExecutionCycle.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/IExecutionCycle.java index e1e2a12d8..596a1442b 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/IExecutionCycle.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/cycle/IExecutionCycle.java @@ -24,59 +24,36 @@ public interface IExecutionCycle { - /** - * Returns pipeline name. - */ - String getPipelineName(); + /** Returns pipeline name. */ + String getPipelineName(); - /** - * Returns pipelineTask id. - */ - long getPipelineTaskId(); + /** Returns pipelineTask id. */ + long getPipelineTaskId(); - /** - * Returns scheduler id. - */ - long getSchedulerId(); + /** Returns scheduler id. */ + long getSchedulerId(); - /** - * Returns cycle id. - */ - int getCycleId(); + /** Returns cycle id. */ + int getCycleId(); - /** - * Returns flying count. - */ - int getFlyingCount(); + /** Returns flying count. */ + int getFlyingCount(); - /** - * Returns iteration count. - */ - long getIterationCount(); + /** Returns iteration count. */ + long getIterationCount(); - /** - * Returns config. - */ - Configuration getConfig(); + /** Returns config. */ + Configuration getConfig(); - /** - * Returns execution cycle type. - */ - ExecutionCycleType getType(); + /** Returns execution cycle type. */ + ExecutionCycleType getType(); - /** - * Returns driver id. - */ - String getDriverId(); + /** Returns driver id. */ + String getDriverId(); - /** - * Returns driver index. - */ - int getDriverIndex(); - - /** - * Returns HA level. - */ - HighAvailableLevel getHighAvailableLevel(); + /** Returns driver index. */ + int getDriverIndex(); + /** Returns HA level. */ + HighAvailableLevel getHighAvailableLevel(); } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/io/CycleResultManager.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/io/CycleResultManager.java index b654132d1..94c6c39e1 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/io/CycleResultManager.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/io/CycleResultManager.java @@ -25,45 +25,42 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.cluster.response.IResult; public class CycleResultManager { - /** - * edge id to result shard. - */ - private Map> shards; + /** edge id to result shard. */ + private Map> shards; - /** - * all data response. - */ - private List rawDatas; + /** all data response. */ + private List rawDatas; - public CycleResultManager() { - this.shards = new ConcurrentHashMap<>(); - this.rawDatas = new ArrayList<>(); - } + public CycleResultManager() { + this.shards = new ConcurrentHashMap<>(); + this.rawDatas = new ArrayList<>(); + } - public void register(int id, IResult response) { - if (!shards.containsKey(id)) { - shards.put(id, new ArrayList<>()); - } - shards.get(id).add(response); + public void register(int id, IResult response) { + if (!shards.containsKey(id)) { + shards.put(id, new ArrayList<>()); } + shards.get(id).add(response); + } - public List get(int id) { - return shards.get(id); - } + public List get(int id) { + return shards.get(id); + } - public List getDataResponse() { - return shards.get(COLLECT_DATA_EDGE_ID); - } + public List getDataResponse() { + return shards.get(COLLECT_DATA_EDGE_ID); + } - public void release(int id) { - shards.remove(id); - } + public void release(int id) { + shards.remove(id); + } - public void clear() { - shards.clear(); - } -} \ No newline at end of file + public void clear() { + shards.clear(); + } +} diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/io/DataExchanger.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/io/DataExchanger.java index 79da4f322..7253a8072 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/io/DataExchanger.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/io/DataExchanger.java @@ -23,6 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.cluster.response.IResult; import org.apache.geaflow.cluster.response.ShardResult; import org.apache.geaflow.core.graph.ExecutionEdge; @@ -31,47 +32,46 @@ public class DataExchanger { - private static final ThreadLocal>>> taskInputEdgeShards = ThreadLocal.withInitial(HashMap::new); + private static final ThreadLocal>>> taskInputEdgeShards = + ThreadLocal.withInitial(HashMap::new); - /** - * Build task input for execution vertex. - * - * @return key: taskIndex - * value: list of input shards - */ - public static Map> buildInput(ExecutionVertex vertex, - ExecutionEdge inputEdge, - CycleResultManager resultManager) { + /** + * Build task input for execution vertex. + * + * @return key: taskIndex value: list of input shards + */ + public static Map> buildInput( + ExecutionVertex vertex, ExecutionEdge inputEdge, CycleResultManager resultManager) { - if (taskInputEdgeShards.get().containsKey(inputEdge.getEdgeId())) { - return taskInputEdgeShards.get().get(inputEdge.getEdgeId()); - } - Map> result = new HashMap<>(); + if (taskInputEdgeShards.get().containsKey(inputEdge.getEdgeId())) { + return taskInputEdgeShards.get().get(inputEdge.getEdgeId()); + } + Map> result = new HashMap<>(); - int edgeId = inputEdge.getEdgeId(); - List eventResults = resultManager.get(edgeId); - Map taskIdToInputShard = new HashMap<>(); - for (IResult eventResult : eventResults) { - ShardResult shard = (ShardResult) eventResult; - for (int i = 0; i < shard.getResponse().size(); i++) { - int index = i % vertex.getParallelism(); - if (!taskIdToInputShard.containsKey(index)) { - taskIdToInputShard.put(index, new Shard(shard.getId(), new ArrayList<>())); - } - taskIdToInputShard.get(index).getSlices().add(shard.getResponse().get(i)); - } + int edgeId = inputEdge.getEdgeId(); + List eventResults = resultManager.get(edgeId); + Map taskIdToInputShard = new HashMap<>(); + for (IResult eventResult : eventResults) { + ShardResult shard = (ShardResult) eventResult; + for (int i = 0; i < shard.getResponse().size(); i++) { + int index = i % vertex.getParallelism(); + if (!taskIdToInputShard.containsKey(index)) { + taskIdToInputShard.put(index, new Shard(shard.getId(), new ArrayList<>())); } - for (Map.Entry entry : taskIdToInputShard.entrySet()) { - if (!result.containsKey(entry.getKey())) { - result.put(entry.getKey(), new ArrayList<>()); - } - result.get(entry.getKey()).add(entry.getValue()); - } - taskInputEdgeShards.get().put(inputEdge.getEdgeId(), result); - return result; + taskIdToInputShard.get(index).getSlices().add(shard.getResponse().get(i)); + } } - - public static void clear() { - taskInputEdgeShards.get().clear(); + for (Map.Entry entry : taskIdToInputShard.entrySet()) { + if (!result.containsKey(entry.getKey())) { + result.put(entry.getKey(), new ArrayList<>()); + } + result.get(entry.getKey()).add(entry.getValue()); } + taskInputEdgeShards.get().put(inputEdge.getEdgeId(), result); + return result; + } + + public static void clear() { + taskInputEdgeShards.get().clear(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/io/IoDescriptorBuilder.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/io/IoDescriptorBuilder.java index 3e3bb4c46..e1d17847a 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/io/IoDescriptorBuilder.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/io/IoDescriptorBuilder.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; + import org.apache.geaflow.cluster.response.IResult; import org.apache.geaflow.cluster.response.ResponseResult; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -54,271 +55,280 @@ public class IoDescriptorBuilder { - public static int COLLECT_DATA_EDGE_ID = 0; + public static int COLLECT_DATA_EDGE_ID = 0; - public static IoDescriptor buildPrefetchIoDescriptor(ExecutionNodeCycle cycle, ExecutionNodeCycle childCycle, ExecutionTask childTask) { - InputDescriptor inputDescriptor = buildPrefetchInputDescriptor(cycle, childCycle, childTask); - return new IoDescriptor(inputDescriptor, null); - } + public static IoDescriptor buildPrefetchIoDescriptor( + ExecutionNodeCycle cycle, ExecutionNodeCycle childCycle, ExecutionTask childTask) { + InputDescriptor inputDescriptor = buildPrefetchInputDescriptor(cycle, childCycle, childTask); + return new IoDescriptor(inputDescriptor, null); + } - private static InputDescriptor buildPrefetchInputDescriptor(ExecutionNodeCycle cycle, - ExecutionNodeCycle childCycle, - ExecutionTask childTask) { - ExecutionVertexGroup childVertexGroup = childCycle.getVertexGroup(); - List inputEdgeIds = childVertexGroup.getVertexId2InEdgeIds().get(childTask.getVertexId()); - ExecutionVertexGroup vertexGroup = cycle.getVertexGroup(); - Map> inputInfoList = new HashMap<>(); - for (Integer edgeId : inputEdgeIds) { - ExecutionEdge edge = vertexGroup.getEdgeMap().get(edgeId); - if (edge == null) { - continue; - } - if (edge.getType() != OutputType.FORWARD) { - continue; - } - if (edge.getSrcId() == edge.getTargetId()) { - continue; - } - IInputDesc inputInfo = buildInputDesc(childTask, edge, childCycle, null, DataExchangeMode.BATCH, BatchPhase.PREFETCH_WRITE); - if (inputInfo != null) { - inputInfoList.put(edgeId, inputInfo); - } - } - return new InputDescriptor(inputInfoList); + private static InputDescriptor buildPrefetchInputDescriptor( + ExecutionNodeCycle cycle, ExecutionNodeCycle childCycle, ExecutionTask childTask) { + ExecutionVertexGroup childVertexGroup = childCycle.getVertexGroup(); + List inputEdgeIds = + childVertexGroup.getVertexId2InEdgeIds().get(childTask.getVertexId()); + ExecutionVertexGroup vertexGroup = cycle.getVertexGroup(); + Map> inputInfoList = new HashMap<>(); + for (Integer edgeId : inputEdgeIds) { + ExecutionEdge edge = vertexGroup.getEdgeMap().get(edgeId); + if (edge == null) { + continue; + } + if (edge.getType() != OutputType.FORWARD) { + continue; + } + if (edge.getSrcId() == edge.getTargetId()) { + continue; + } + IInputDesc inputInfo = + buildInputDesc( + childTask, edge, childCycle, null, DataExchangeMode.BATCH, BatchPhase.PREFETCH_WRITE); + if (inputInfo != null) { + inputInfoList.put(edgeId, inputInfo); + } } + return new InputDescriptor(inputInfoList); + } - /** - * Build pipeline io descriptor for input task. - */ - public static IoDescriptor buildPipelineIoDescriptor(ExecutionTask task, - ExecutionNodeCycle cycle, - CycleResultManager resultManager, - boolean prefetch) { - InputDescriptor inputDescriptor = buildPipelineInputDescriptor(task, cycle, resultManager, prefetch); - OutputDescriptor outputDescriptor = buildPipelineOutputDescriptor(task, cycle, prefetch); - return new IoDescriptor(inputDescriptor, outputDescriptor); - } + /** Build pipeline io descriptor for input task. */ + public static IoDescriptor buildPipelineIoDescriptor( + ExecutionTask task, + ExecutionNodeCycle cycle, + CycleResultManager resultManager, + boolean prefetch) { + InputDescriptor inputDescriptor = + buildPipelineInputDescriptor(task, cycle, resultManager, prefetch); + OutputDescriptor outputDescriptor = buildPipelineOutputDescriptor(task, cycle, prefetch); + return new IoDescriptor(inputDescriptor, outputDescriptor); + } - /** - * Build pipeline input descriptor. - */ - private static InputDescriptor buildPipelineInputDescriptor(ExecutionTask task, - ExecutionNodeCycle cycle, - CycleResultManager resultManager, - boolean prefetch) { - ExecutionVertexGroup vertexGroup = cycle.getVertexGroup(); - List inputEdgeIds = vertexGroup.getVertexId2InEdgeIds().get(task.getVertexId()); - Map> inputInfoList = new HashMap<>(); - for (Integer edgeId : inputEdgeIds) { - ExecutionEdge edge = vertexGroup.getEdgeMap().get(edgeId); - // Only build forward for pipeline. - if (edge.getType() != OutputType.FORWARD) { - continue; - } - if (edge.getSrcId() == edge.getTargetId()) { - continue; - } + /** Build pipeline input descriptor. */ + private static InputDescriptor buildPipelineInputDescriptor( + ExecutionTask task, + ExecutionNodeCycle cycle, + CycleResultManager resultManager, + boolean prefetch) { + ExecutionVertexGroup vertexGroup = cycle.getVertexGroup(); + List inputEdgeIds = vertexGroup.getVertexId2InEdgeIds().get(task.getVertexId()); + Map> inputInfoList = new HashMap<>(); + for (Integer edgeId : inputEdgeIds) { + ExecutionEdge edge = vertexGroup.getEdgeMap().get(edgeId); + // Only build forward for pipeline. + if (edge.getType() != OutputType.FORWARD) { + continue; + } + if (edge.getSrcId() == edge.getTargetId()) { + continue; + } - Tuple tuple = getInputDataExchangeMode(cycle, task, edge.getSrcId(), prefetch); - IInputDesc inputInfo = buildInputDesc(task, edge, cycle, resultManager, tuple.f0, tuple.f1); - if (inputInfo != null) { - inputInfoList.put(edgeId, inputInfo); - } - } - if (inputInfoList.isEmpty()) { - return null; - } - return new InputDescriptor(inputInfoList); + Tuple tuple = + getInputDataExchangeMode(cycle, task, edge.getSrcId(), prefetch); + IInputDesc inputInfo = + buildInputDesc(task, edge, cycle, resultManager, tuple.f0, tuple.f1); + if (inputInfo != null) { + inputInfoList.put(edgeId, inputInfo); + } } - - /** - * Build pipeline output descriptor. - */ - private static OutputDescriptor buildPipelineOutputDescriptor(ExecutionTask task, - ExecutionNodeCycle cycle, - boolean prefetch) { - int vertexId = task.getVertexId(); - if (cycle instanceof CollectExecutionNodeCycle) { - return new OutputDescriptor(Collections.singletonList(buildCollectOutputDesc(task))); - } - ExecutionVertexGroup vertexGroup = cycle.getVertexGroup(); - List outEdgeIds = vertexGroup.getVertexId2OutEdgeIds().get(vertexId); - List outputDescList = new ArrayList<>(); - for (Integer edgeId : outEdgeIds) { - ExecutionEdge edge = vertexGroup.getEdgeMap().get(edgeId); - IOutputDesc outputDesc = buildOutputDesc(task, edge, cycle, prefetch); - outputDescList.add(outputDesc); - } - if (outputDescList.isEmpty()) { - return null; - } - return new OutputDescriptor(outputDescList); + if (inputInfoList.isEmpty()) { + return null; } + return new InputDescriptor(inputInfoList); + } - public static IoDescriptor buildIterationIoDescriptor(ExecutionTask task, - ExecutionNodeCycle cycle, - CycleResultManager resultManager, - OutputType outputType) { - InputDescriptor inputDescriptor = buildIterationInputDescriptor(task, cycle, resultManager, outputType); - return new IoDescriptor(inputDescriptor, null); + /** Build pipeline output descriptor. */ + private static OutputDescriptor buildPipelineOutputDescriptor( + ExecutionTask task, ExecutionNodeCycle cycle, boolean prefetch) { + int vertexId = task.getVertexId(); + if (cycle instanceof CollectExecutionNodeCycle) { + return new OutputDescriptor(Collections.singletonList(buildCollectOutputDesc(task))); } - - private static InputDescriptor buildIterationInputDescriptor(ExecutionTask task, - ExecutionNodeCycle cycle, - CycleResultManager resultManager, - OutputType outputType) { - ExecutionVertexGroup vertexGroup = cycle.getVertexGroup(); - List inputEdgeIds = vertexGroup.getVertexId2InEdgeIds().get(task.getVertexId()); - Map> inputInfos = new HashMap<>(); - for (Integer edgeId : inputEdgeIds) { - ExecutionEdge edge = vertexGroup.getEdgeMap().get(edgeId); - if (edge.getType() == outputType) { - IInputDesc inputInfo = buildInputDesc(task, edge, cycle, resultManager, DataExchangeMode.PIPELINE, BatchPhase.CLASSIC); - if (inputInfo != null) { - inputInfos.put(edgeId, inputInfo); - } - } - } - return new InputDescriptor(inputInfos); + ExecutionVertexGroup vertexGroup = cycle.getVertexGroup(); + List outEdgeIds = vertexGroup.getVertexId2OutEdgeIds().get(vertexId); + List outputDescList = new ArrayList<>(); + for (Integer edgeId : outEdgeIds) { + ExecutionEdge edge = vertexGroup.getEdgeMap().get(edgeId); + IOutputDesc outputDesc = buildOutputDesc(task, edge, cycle, prefetch); + outputDescList.add(outputDesc); + } + if (outputDescList.isEmpty()) { + return null; } + return new OutputDescriptor(outputDescList); + } - /** - * Build input info for input task and edge. - */ - protected static IInputDesc buildInputDesc(ExecutionTask task, - ExecutionEdge inputEdge, - ExecutionNodeCycle cycle, - CycleResultManager resultManager, - DataExchangeMode dataExchangeMode, - BatchPhase batchPhase) { - List inputTasks = cycle.getVertexIdToTasks().get(inputEdge.getSrcId()); - int edgeId = inputEdge.getEdgeId(); - OutputType outputType = inputEdge.getType(); + public static IoDescriptor buildIterationIoDescriptor( + ExecutionTask task, + ExecutionNodeCycle cycle, + CycleResultManager resultManager, + OutputType outputType) { + InputDescriptor inputDescriptor = + buildIterationInputDescriptor(task, cycle, resultManager, outputType); + return new IoDescriptor(inputDescriptor, null); + } - switch (outputType) { - case LOOP: - case FORWARD: - int vertexId = task.getVertexId(); - ExecutionVertexGroup vertexGroup = cycle.getVertexGroup(); - List inputs = new ArrayList<>(inputTasks.size()); - if (dataExchangeMode == DataExchangeMode.BATCH && batchPhase == BatchPhase.CLASSIC) { - Map> taskInputs = - DataExchanger.buildInput(vertexGroup.getVertexMap().get(vertexId), inputEdge, resultManager); - inputs = taskInputs.get(task.getIndex()); - } else { - for (ExecutionTask inputTask : inputTasks) { - LogicalPipelineSliceMeta logicalSlice = new LogicalPipelineSliceMeta( - inputTask.getIndex(), - task.getIndex(), - cycle.getPipelineId(), - edgeId, - inputTask.getWorkerInfo().getContainerName()); - Shard shard = new Shard(edgeId, Collections.singletonList(logicalSlice)); - inputs.add(shard); - } - } - return new ShardInputDesc( - edgeId, - inputEdge.getEdgeName(), - inputs, - inputEdge.getEncoder(), - dataExchangeMode, - batchPhase); - case RESPONSE: - List results = resultManager.get(inputEdge.getEdgeId()); - if (results == null) { - return null; - } - List dataInput = new ArrayList<>(); - for (IResult result : results) { - if (result.getType() != OutputType.RESPONSE) { - throw new GeaflowRuntimeException(String.format("edge %s type %s not support handle result %s", - inputEdge.getEdgeId(), inputEdge.getType(), result.getType())); - } - dataInput.addAll(((ResponseResult) result).getResponse()); - } - return new RawDataInputDesc(edgeId, inputEdge.getEdgeName(), dataInput); - default: - throw new GeaflowRuntimeException(String.format("not support build input for edge %s type %s", - inputEdge.getEdgeId(), inputEdge.getType())); + private static InputDescriptor buildIterationInputDescriptor( + ExecutionTask task, + ExecutionNodeCycle cycle, + CycleResultManager resultManager, + OutputType outputType) { + ExecutionVertexGroup vertexGroup = cycle.getVertexGroup(); + List inputEdgeIds = vertexGroup.getVertexId2InEdgeIds().get(task.getVertexId()); + Map> inputInfos = new HashMap<>(); + for (Integer edgeId : inputEdgeIds) { + ExecutionEdge edge = vertexGroup.getEdgeMap().get(edgeId); + if (edge.getType() == outputType) { + IInputDesc inputInfo = + buildInputDesc( + task, edge, cycle, resultManager, DataExchangeMode.PIPELINE, BatchPhase.CLASSIC); + if (inputInfo != null) { + inputInfos.put(edgeId, inputInfo); } + } } + return new InputDescriptor(inputInfos); + } - private static IOutputDesc buildCollectOutputDesc(ExecutionTask task) { - int opId = getCollectOpId((AbstractOperator) ((AbstractProcessor) task.getProcessor()).getOperator()); - ResponseOutputDesc outputDesc = new ResponseOutputDesc(opId, - COLLECT_DATA_EDGE_ID, OutputType.RESPONSE.name()); - return outputDesc; - - } + /** Build input info for input task and edge. */ + protected static IInputDesc buildInputDesc( + ExecutionTask task, + ExecutionEdge inputEdge, + ExecutionNodeCycle cycle, + CycleResultManager resultManager, + DataExchangeMode dataExchangeMode, + BatchPhase batchPhase) { + List inputTasks = cycle.getVertexIdToTasks().get(inputEdge.getSrcId()); + int edgeId = inputEdge.getEdgeId(); + OutputType outputType = inputEdge.getType(); - private static Integer getCollectOpId(AbstractOperator operator) { - if (operator.getNextOperators().isEmpty()) { - return operator.getOpArgs().getOpId(); - } else if (operator.getNextOperators().size() == 1) { - return getCollectOpId((AbstractOperator) operator.getNextOperators().get(0)); + switch (outputType) { + case LOOP: + case FORWARD: + int vertexId = task.getVertexId(); + ExecutionVertexGroup vertexGroup = cycle.getVertexGroup(); + List inputs = new ArrayList<>(inputTasks.size()); + if (dataExchangeMode == DataExchangeMode.BATCH && batchPhase == BatchPhase.CLASSIC) { + Map> taskInputs = + DataExchanger.buildInput( + vertexGroup.getVertexMap().get(vertexId), inputEdge, resultManager); + inputs = taskInputs.get(task.getIndex()); } else { - throw new GeaflowRuntimeException("not support collect multi-output"); + for (ExecutionTask inputTask : inputTasks) { + LogicalPipelineSliceMeta logicalSlice = + new LogicalPipelineSliceMeta( + inputTask.getIndex(), + task.getIndex(), + cycle.getPipelineId(), + edgeId, + inputTask.getWorkerInfo().getContainerName()); + Shard shard = new Shard(edgeId, Collections.singletonList(logicalSlice)); + inputs.add(shard); + } } + return new ShardInputDesc( + edgeId, + inputEdge.getEdgeName(), + inputs, + inputEdge.getEncoder(), + dataExchangeMode, + batchPhase); + case RESPONSE: + List results = resultManager.get(inputEdge.getEdgeId()); + if (results == null) { + return null; + } + List dataInput = new ArrayList<>(); + for (IResult result : results) { + if (result.getType() != OutputType.RESPONSE) { + throw new GeaflowRuntimeException( + String.format( + "edge %s type %s not support handle result %s", + inputEdge.getEdgeId(), inputEdge.getType(), result.getType())); + } + dataInput.addAll(((ResponseResult) result).getResponse()); + } + return new RawDataInputDesc(edgeId, inputEdge.getEdgeName(), dataInput); + default: + throw new GeaflowRuntimeException( + String.format( + "not support build input for edge %s type %s", + inputEdge.getEdgeId(), inputEdge.getType())); } + } - protected static IOutputDesc buildOutputDesc(ExecutionTask task, - ExecutionEdge outEdge, - ExecutionNodeCycle cycle, - boolean prefetch) { - - int vertexId = task.getVertexId(); - switch (outEdge.getType()) { - case LOOP: - case FORWARD: - List tasks = cycle.getVertexIdToTasks().get(outEdge.getTargetId()); - List taskIds = tasks.stream().map(ExecutionTask::getTaskId).collect(Collectors.toList()); - // TODO forward partitioner not write to all output tasks. - int numPartitions = outEdge.getSrcId() == outEdge.getTargetId() - ? task.getMaxParallelism() - : task.getNumPartitions(); - DataExchangeMode dataExchangeMode = getOutputDataExchangeMode(cycle, outEdge.getTargetId(), prefetch); - return new ForwardOutputDesc<>( - vertexId, - outEdge.getEdgeId(), - numPartitions, - outEdge.getEdgeName(), - dataExchangeMode, - taskIds, - outEdge.getPartitioner(), - outEdge.getEncoder()); - case RESPONSE: - return new ResponseOutputDesc( - outEdge.getPartitioner().getOpId(), - outEdge.getEdgeId(), - outEdge.getEdgeName()); - default: - throw new GeaflowRuntimeException(String.format("not support build output for edge %s type %s", - outEdge.getEdgeId(), outEdge.getType())); + private static IOutputDesc buildCollectOutputDesc(ExecutionTask task) { + int opId = + getCollectOpId((AbstractOperator) ((AbstractProcessor) task.getProcessor()).getOperator()); + ResponseOutputDesc outputDesc = + new ResponseOutputDesc(opId, COLLECT_DATA_EDGE_ID, OutputType.RESPONSE.name()); + return outputDesc; + } - } + private static Integer getCollectOpId(AbstractOperator operator) { + if (operator.getNextOperators().isEmpty()) { + return operator.getOpArgs().getOpId(); + } else if (operator.getNextOperators().size() == 1) { + return getCollectOpId((AbstractOperator) operator.getNextOperators().get(0)); + } else { + throw new GeaflowRuntimeException("not support collect multi-output"); } + } - private static Tuple getInputDataExchangeMode(ExecutionNodeCycle cycle, - ExecutionTask task, - int vertexId, - boolean prefetch) { - Map vertexMap = cycle.getVertexGroup().getVertexMap(); - DataExchangeMode dataExchangeMode = DataExchangeMode.BATCH; - BatchPhase batchPhase = BatchPhase.CLASSIC; - if (vertexMap.containsKey(vertexId)) { - dataExchangeMode = DataExchangeMode.PIPELINE; - } else if (prefetch && cycle.isHeadTask(task.getTaskId())) { - batchPhase = BatchPhase.PREFETCH_READ; - } - return Tuple.of(dataExchangeMode, batchPhase); + protected static IOutputDesc buildOutputDesc( + ExecutionTask task, ExecutionEdge outEdge, ExecutionNodeCycle cycle, boolean prefetch) { + + int vertexId = task.getVertexId(); + switch (outEdge.getType()) { + case LOOP: + case FORWARD: + List tasks = cycle.getVertexIdToTasks().get(outEdge.getTargetId()); + List taskIds = + tasks.stream().map(ExecutionTask::getTaskId).collect(Collectors.toList()); + // TODO forward partitioner not write to all output tasks. + int numPartitions = + outEdge.getSrcId() == outEdge.getTargetId() + ? task.getMaxParallelism() + : task.getNumPartitions(); + DataExchangeMode dataExchangeMode = + getOutputDataExchangeMode(cycle, outEdge.getTargetId(), prefetch); + return new ForwardOutputDesc<>( + vertexId, + outEdge.getEdgeId(), + numPartitions, + outEdge.getEdgeName(), + dataExchangeMode, + taskIds, + outEdge.getPartitioner(), + outEdge.getEncoder()); + case RESPONSE: + return new ResponseOutputDesc( + outEdge.getPartitioner().getOpId(), outEdge.getEdgeId(), outEdge.getEdgeName()); + default: + throw new GeaflowRuntimeException( + String.format( + "not support build output for edge %s type %s", + outEdge.getEdgeId(), outEdge.getType())); } + } - private static DataExchangeMode getOutputDataExchangeMode(ExecutionNodeCycle cycle, - int vertexId, - boolean prefetch) { - Map vertexMap = cycle.getVertexGroup().getVertexMap(); - return vertexMap.containsKey(vertexId) || prefetch ? DataExchangeMode.PIPELINE : DataExchangeMode.BATCH; + private static Tuple getInputDataExchangeMode( + ExecutionNodeCycle cycle, ExecutionTask task, int vertexId, boolean prefetch) { + Map vertexMap = cycle.getVertexGroup().getVertexMap(); + DataExchangeMode dataExchangeMode = DataExchangeMode.BATCH; + BatchPhase batchPhase = BatchPhase.CLASSIC; + if (vertexMap.containsKey(vertexId)) { + dataExchangeMode = DataExchangeMode.PIPELINE; + } else if (prefetch && cycle.isHeadTask(task.getTaskId())) { + batchPhase = BatchPhase.PREFETCH_READ; } + return Tuple.of(dataExchangeMode, batchPhase); + } + private static DataExchangeMode getOutputDataExchangeMode( + ExecutionNodeCycle cycle, int vertexId, boolean prefetch) { + Map vertexMap = cycle.getVertexGroup().getVertexMap(); + return vertexMap.containsKey(vertexId) || prefetch + ? DataExchangeMode.PIPELINE + : DataExchangeMode.BATCH; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/AbstractScheduledWorkerManager.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/AbstractScheduledWorkerManager.java index fa1bb99a6..5bc3f98fa 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/AbstractScheduledWorkerManager.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/AbstractScheduledWorkerManager.java @@ -28,6 +28,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; + import org.apache.geaflow.cluster.client.utils.PipelineUtil; import org.apache.geaflow.cluster.resourcemanager.ReleaseResourceRequest; import org.apache.geaflow.cluster.resourcemanager.RequireResourceRequest; @@ -57,221 +58,240 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class AbstractScheduledWorkerManager implements IScheduledWorkerManager { +public abstract class AbstractScheduledWorkerManager + implements IScheduledWorkerManager { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractScheduledWorkerManager.class); - private static final String SEPARATOR = "#"; - protected static final String DEFAULT_RESOURCE_ID = "default_"; - protected static final String DEFAULT_GRAPH_VIEW_NAME = "default_graph_view_name"; + private static final Logger LOGGER = + LoggerFactory.getLogger(AbstractScheduledWorkerManager.class); + private static final String SEPARATOR = "#"; + protected static final String DEFAULT_RESOURCE_ID = "default_"; + protected static final String DEFAULT_GRAPH_VIEW_NAME = "default_graph_view_name"; - protected static final int RETRY_REQUEST_RESOURCE_INTERVAL = 5; - protected static final int REPORT_RETRY_TIMES = 50; + protected static final int RETRY_REQUEST_RESOURCE_INTERVAL = 5; + protected static final int REPORT_RETRY_TIMES = 50; - protected final String masterId; - protected final boolean isAsync; - protected transient Map workers; - protected transient Map> nodeCycles; - protected transient Map isAssigned; - protected TaskAssigner taskAssigner; + protected final String masterId; + protected final boolean isAsync; + protected transient Map workers; + protected transient Map> nodeCycles; + protected transient Map isAssigned; + protected TaskAssigner taskAssigner; - public AbstractScheduledWorkerManager(Configuration config) { - this.masterId = config.getMasterId(); - this.isAsync = PipelineUtil.isAsync(config); - if (this.taskAssigner == null) { - this.taskAssigner = new TaskAssigner(); - } + public AbstractScheduledWorkerManager(Configuration config) { + this.masterId = config.getMasterId(); + this.isAsync = PipelineUtil.isAsync(config); + if (this.taskAssigner == null) { + this.taskAssigner = new TaskAssigner(); } + } - @Override - public void init(ExecutionGraphCycle graph) { - if (this.workers == null) { - this.workers = new ConcurrentHashMap<>(); - this.nodeCycles = new ConcurrentHashMap<>(); - this.isAssigned = new ConcurrentHashMap<>(); - } - Long schedulerId = graph.getSchedulerId(); - String resourceId = this.genResourceId(graph.getDriverIndex(), schedulerId); + @Override + public void init(ExecutionGraphCycle graph) { + if (this.workers == null) { + this.workers = new ConcurrentHashMap<>(); + this.nodeCycles = new ConcurrentHashMap<>(); + this.isAssigned = new ConcurrentHashMap<>(); + } + Long schedulerId = graph.getSchedulerId(); + String resourceId = this.genResourceId(graph.getDriverIndex(), schedulerId); - int requestResourceNum = graph.getCycleMap().values().stream() + int requestResourceNum = + graph.getCycleMap().values().stream() .map(e -> ((ExecutionNodeCycle) e).getVertexGroup()) .map(AbstractScheduledWorkerManager::getExecutionGroupParallelism) .max(Integer::compareTo) .orElse(0); - List allocated = this.requestWorker(requestResourceNum, resourceId); - this.initWorkers(schedulerId, allocated, ScheduledWorkerManagerFactory.getWorkerManagerHALevel(graph)); - this.workers.put(schedulerId, new ResourceInfo(resourceId, allocated)); - LOGGER.info("scheduler {} request {} workers from resource manager", schedulerId, allocated.size()); + List allocated = this.requestWorker(requestResourceNum, resourceId); + this.initWorkers( + schedulerId, allocated, ScheduledWorkerManagerFactory.getWorkerManagerHALevel(graph)); + this.workers.put(schedulerId, new ResourceInfo(resourceId, allocated)); + LOGGER.info( + "scheduler {} request {} workers from resource manager", schedulerId, allocated.size()); - this.nodeCycles.put(schedulerId, extractNodeCycles(graph)); - } + this.nodeCycles.put(schedulerId, extractNodeCycles(graph)); + } - private static List extractNodeCycles(ExecutionGraphCycle graph) { - List list = new ArrayList<>(graph.getCycleMap().size()); - for (IExecutionCycle cycle : graph.getCycleMap().values()) { - if (cycle instanceof ExecutionNodeCycle) { - list.add((ExecutionNodeCycle) cycle); - } - } - return list; + private static List extractNodeCycles(ExecutionGraphCycle graph) { + List list = new ArrayList<>(graph.getCycleMap().size()); + for (IExecutionCycle cycle : graph.getCycleMap().values()) { + if (cycle instanceof ExecutionNodeCycle) { + list.add((ExecutionNodeCycle) cycle); + } } + return list; + } - @Override - public List assign(ExecutionGraphCycle graph) { - Long schedulerId = graph.getSchedulerId(); - boolean assigned = Optional.ofNullable(this.isAssigned.get(schedulerId)).orElse(false); - List allocated = this.workers.get(schedulerId).getWorkers(); - List taskIndexes = new ArrayList<>(); - for (int i = 0; i < allocated.size(); i++) { - taskIndexes.add(i); - } - String graphName = getGraphViewName(graph); - Map taskIndex2Worker = taskAssigner.assignTasks2Workers(graphName, - taskIndexes, allocated); - for (IExecutionCycle cycle : graph.getCycleMap().values()) { - if (cycle instanceof ExecutionNodeCycle) { - ExecutionNodeCycle nodeCycle = (ExecutionNodeCycle) cycle; - int parallelism = getExecutionGroupParallelism(nodeCycle.getVertexGroup()); - for (int i = 0; i < parallelism; i++) { - ExecutionTask task = nodeCycle.getTasks().get(i); - task.setWorkerInfo(taskIndex2Worker.get(i)); - } - nodeCycle.setWorkerAssigned(assigned); - } + @Override + public List assign(ExecutionGraphCycle graph) { + Long schedulerId = graph.getSchedulerId(); + boolean assigned = Optional.ofNullable(this.isAssigned.get(schedulerId)).orElse(false); + List allocated = this.workers.get(schedulerId).getWorkers(); + List taskIndexes = new ArrayList<>(); + for (int i = 0; i < allocated.size(); i++) { + taskIndexes.add(i); + } + String graphName = getGraphViewName(graph); + Map taskIndex2Worker = + taskAssigner.assignTasks2Workers(graphName, taskIndexes, allocated); + for (IExecutionCycle cycle : graph.getCycleMap().values()) { + if (cycle instanceof ExecutionNodeCycle) { + ExecutionNodeCycle nodeCycle = (ExecutionNodeCycle) cycle; + int parallelism = getExecutionGroupParallelism(nodeCycle.getVertexGroup()); + for (int i = 0; i < parallelism; i++) { + ExecutionTask task = nodeCycle.getTasks().get(i); + task.setWorkerInfo(taskIndex2Worker.get(i)); } - return allocated; + nodeCycle.setWorkerAssigned(assigned); + } } + return allocated; + } - protected WorkerInfo assignTaskWorker(WorkerInfo worker, ExecutionTask task, AffinityLevel affinityLevel) { - return worker; - } + protected WorkerInfo assignTaskWorker( + WorkerInfo worker, ExecutionTask task, AffinityLevel affinityLevel) { + return worker; + } - @Override - public void release(ExecutionGraphCycle graph) { - Long schedulerId = graph.getSchedulerId(); - ResourceInfo remove = this.workers.remove(schedulerId); - if (remove != null) { - RpcClient.getInstance().releaseResource(masterId, - ReleaseResourceRequest.build(remove.getResourceId(), remove.getWorkers())); - } + @Override + public void release(ExecutionGraphCycle graph) { + Long schedulerId = graph.getSchedulerId(); + ResourceInfo remove = this.workers.remove(schedulerId); + if (remove != null) { + RpcClient.getInstance() + .releaseResource( + masterId, ReleaseResourceRequest.build(remove.getResourceId(), remove.getWorkers())); } + } - @Override - public void clean(CleanWorkerFunction function, IExecutionCycle cycle) { - Long schedulerId = cycle.getSchedulerId(); - function.clean(this.workers.get(schedulerId).getWorkers()); - this.isAssigned.put(schedulerId, true); - for (ExecutionNodeCycle nodeCycle : this.nodeCycles.get(schedulerId)) { - nodeCycle.setWorkerAssigned(true); - } + @Override + public void clean(CleanWorkerFunction function, IExecutionCycle cycle) { + Long schedulerId = cycle.getSchedulerId(); + function.clean(this.workers.get(schedulerId).getWorkers()); + this.isAssigned.put(schedulerId, true); + for (ExecutionNodeCycle nodeCycle : this.nodeCycles.get(schedulerId)) { + nodeCycle.setWorkerAssigned(true); } + } - @Override - public synchronized void close(IExecutionCycle cycle) { - if (this.isAssigned != null) { - this.isAssigned.remove(cycle.getSchedulerId()); - } + @Override + public synchronized void close(IExecutionCycle cycle) { + if (this.isAssigned != null) { + this.isAssigned.remove(cycle.getSchedulerId()); } + } - protected static int getExecutionGroupParallelism(ExecutionVertexGroup vertexGroup) { - return vertexGroup.getVertexMap().values().stream() - .map(ExecutionVertex::getParallelism) - .reduce(Integer::sum) - .orElse(0); - } + protected static int getExecutionGroupParallelism(ExecutionVertexGroup vertexGroup) { + return vertexGroup.getVertexMap().values().stream() + .map(ExecutionVertex::getParallelism) + .reduce(Integer::sum) + .orElse(0); + } - protected List requestWorker(int requestResourceNum, String resourceId) { - IAllocator.AllocateStrategy allocateStrategy = this.isAsync - ? IAllocator.AllocateStrategy.PROCESS_FAIR : IAllocator.AllocateStrategy.ROUND_ROBIN; - RequireResponse response = RpcClient.getInstance().requireResource(this.masterId, - RequireResourceRequest.build(resourceId, requestResourceNum, allocateStrategy)); - int retryTimes = 1; - while (!response.isSuccess() || response.getWorkers().isEmpty()) { - try { - response = RpcClient.getInstance().requireResource(masterId, - RequireResourceRequest.build(resourceId, - requestResourceNum, - allocateStrategy)); - if (retryTimes % REPORT_RETRY_TIMES == 0) { - String msg = String.format("request %s worker with allocateStrategy %s failed after %s times: %s", - requestResourceNum, allocateStrategy, retryTimes, response.getMsg()); - LOGGER.warn(msg); - } - Thread.sleep((long) RETRY_REQUEST_RESOURCE_INTERVAL * retryTimes); - // TODO Report to ExceptionCollector. - retryTimes++; - } catch (InterruptedException e) { - throw new GeaflowRuntimeException(e); - } + protected List requestWorker(int requestResourceNum, String resourceId) { + IAllocator.AllocateStrategy allocateStrategy = + this.isAsync + ? IAllocator.AllocateStrategy.PROCESS_FAIR + : IAllocator.AllocateStrategy.ROUND_ROBIN; + RequireResponse response = + RpcClient.getInstance() + .requireResource( + this.masterId, + RequireResourceRequest.build(resourceId, requestResourceNum, allocateStrategy)); + int retryTimes = 1; + while (!response.isSuccess() || response.getWorkers().isEmpty()) { + try { + response = + RpcClient.getInstance() + .requireResource( + masterId, + RequireResourceRequest.build(resourceId, requestResourceNum, allocateStrategy)); + if (retryTimes % REPORT_RETRY_TIMES == 0) { + String msg = + String.format( + "request %s worker with allocateStrategy %s failed after %s times: %s", + requestResourceNum, allocateStrategy, retryTimes, response.getMsg()); + LOGGER.warn(msg); } - return response.getWorkers(); + Thread.sleep((long) RETRY_REQUEST_RESOURCE_INTERVAL * retryTimes); + // TODO Report to ExceptionCollector. + retryTimes++; + } catch (InterruptedException e) { + throw new GeaflowRuntimeException(e); + } } + return response.getWorkers(); + } - protected void initWorkers(Long schedulerId, List workers, HighAvailableLevel highAvailableLevel) { - if (this.workers.get(schedulerId) != null) { - LOGGER.info("recovered workers {} already init, ignore init again", workers.size()); - return; - } - LOGGER.info("do init workers {}", workers.size()); - CountDownLatch processCountDownLatch = new CountDownLatch(workers.size()); - AtomicInteger failureCount = new AtomicInteger(0); - AtomicReference exception = new AtomicReference<>(); - for (WorkerInfo workerInfo : workers) { - int workerId = workerInfo.getWorkerIndex(); - CreateTaskEvent createTaskEvent = new CreateTaskEvent(workerId, highAvailableLevel); - CreateWorkerEvent createWorkerEvent = new CreateWorkerEvent(workerId, highAvailableLevel); - ComposeEvent composeEvent = new ComposeEvent(workerId, Arrays.asList(createTaskEvent, createWorkerEvent)); + protected void initWorkers( + Long schedulerId, List workers, HighAvailableLevel highAvailableLevel) { + if (this.workers.get(schedulerId) != null) { + LOGGER.info("recovered workers {} already init, ignore init again", workers.size()); + return; + } + LOGGER.info("do init workers {}", workers.size()); + CountDownLatch processCountDownLatch = new CountDownLatch(workers.size()); + AtomicInteger failureCount = new AtomicInteger(0); + AtomicReference exception = new AtomicReference<>(); + for (WorkerInfo workerInfo : workers) { + int workerId = workerInfo.getWorkerIndex(); + CreateTaskEvent createTaskEvent = new CreateTaskEvent(workerId, highAvailableLevel); + CreateWorkerEvent createWorkerEvent = new CreateWorkerEvent(workerId, highAvailableLevel); + ComposeEvent composeEvent = + new ComposeEvent(workerId, Arrays.asList(createTaskEvent, createWorkerEvent)); - RpcClient.getInstance().processContainer(workerInfo.getContainerName(), composeEvent, - new RpcEndpointRef.RpcCallback() { - @Override - public void onSuccess(Container.Response value) { - processCountDownLatch.countDown(); - } + RpcClient.getInstance() + .processContainer( + workerInfo.getContainerName(), + composeEvent, + new RpcEndpointRef.RpcCallback() { + @Override + public void onSuccess(Container.Response value) { + processCountDownLatch.countDown(); + } - @Override - public void onFailure(Throwable t) { - processCountDownLatch.countDown(); - failureCount.incrementAndGet(); - exception.compareAndSet(null, t); - } - }); - } - try { - processCountDownLatch.await(); - LOGGER.info("do init workers finished"); - if (failureCount.get() > 0) { - throw new GeaflowRuntimeException(String.format("init worker failed. failed count %s", - failureCount.get()), exception.get()); - } - } catch (InterruptedException e) { - throw new GeaflowRuntimeException(e); - } + @Override + public void onFailure(Throwable t) { + processCountDownLatch.countDown(); + failureCount.incrementAndGet(); + exception.compareAndSet(null, t); + } + }); } - - public String genResourceId(int driverIndex, Long schedulerId) { - return DEFAULT_RESOURCE_ID + driverIndex + SEPARATOR + schedulerId; + try { + processCountDownLatch.await(); + LOGGER.info("do init workers finished"); + if (failureCount.get() > 0) { + throw new GeaflowRuntimeException( + String.format("init worker failed. failed count %s", failureCount.get()), + exception.get()); + } + } catch (InterruptedException e) { + throw new GeaflowRuntimeException(e); } + } - private String getGraphViewName(ExecutionGraphCycle graph) { - for (IExecutionCycle cycle : graph.getCycleMap().values()) { - if (cycle instanceof ExecutionNodeCycle) { - ExecutionNodeCycle nodeCycle = (ExecutionNodeCycle) cycle; - List tasks = nodeCycle.getTasks(); - for (ExecutionTask task : tasks) { - AbstractProcessor processor = (AbstractProcessor) (task.getProcessor()); - if (processor != null) { - Operator operator = processor.getOperator(); - if (operator instanceof AbstractGraphVertexCentricOp) { - AbstractGraphVertexCentricOp graphOp = - (AbstractGraphVertexCentricOp) operator; - return graphOp.getGraphViewName(); - } - } - } + public String genResourceId(int driverIndex, Long schedulerId) { + return DEFAULT_RESOURCE_ID + driverIndex + SEPARATOR + schedulerId; + } + private String getGraphViewName(ExecutionGraphCycle graph) { + for (IExecutionCycle cycle : graph.getCycleMap().values()) { + if (cycle instanceof ExecutionNodeCycle) { + ExecutionNodeCycle nodeCycle = (ExecutionNodeCycle) cycle; + List tasks = nodeCycle.getTasks(); + for (ExecutionTask task : tasks) { + AbstractProcessor processor = (AbstractProcessor) (task.getProcessor()); + if (processor != null) { + Operator operator = processor.getOperator(); + if (operator instanceof AbstractGraphVertexCentricOp) { + AbstractGraphVertexCentricOp graphOp = (AbstractGraphVertexCentricOp) operator; + return graphOp.getGraphViewName(); } + } } - return DEFAULT_GRAPH_VIEW_NAME; + } } -} \ No newline at end of file + return DEFAULT_GRAPH_VIEW_NAME; + } +} diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/CheckpointCycleScheduledWorkerManager.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/CheckpointCycleScheduledWorkerManager.java index a60eb8a55..c314b72f0 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/CheckpointCycleScheduledWorkerManager.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/CheckpointCycleScheduledWorkerManager.java @@ -20,6 +20,7 @@ package org.apache.geaflow.runtime.core.scheduler.resource; import java.util.Optional; + import org.apache.geaflow.cluster.resourcemanager.WorkerInfo; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -28,16 +29,16 @@ public class CheckpointCycleScheduledWorkerManager extends AbstractScheduledWorkerManager { - public CheckpointCycleScheduledWorkerManager(Configuration config) { - super(config); - } + public CheckpointCycleScheduledWorkerManager(Configuration config) { + super(config); + } - @Override - protected WorkerInfo assignTaskWorker(WorkerInfo worker, ExecutionTask task, AffinityLevel affinityLevel) { - if (affinityLevel == AffinityLevel.worker) { - return Optional.ofNullable(task.getWorkerInfo()).orElse(worker); - } - throw new GeaflowRuntimeException("not support affinity level yet " + affinityLevel); + @Override + protected WorkerInfo assignTaskWorker( + WorkerInfo worker, ExecutionTask task, AffinityLevel affinityLevel) { + if (affinityLevel == AffinityLevel.worker) { + return Optional.ofNullable(task.getWorkerInfo()).orElse(worker); } - -} \ No newline at end of file + throw new GeaflowRuntimeException("not support affinity level yet " + affinityLevel); + } +} diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/IScheduledWorkerManager.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/IScheduledWorkerManager.java index 088a373e9..29e3c57ea 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/IScheduledWorkerManager.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/IScheduledWorkerManager.java @@ -20,45 +20,37 @@ package org.apache.geaflow.runtime.core.scheduler.resource; import java.util.List; + import org.apache.geaflow.cluster.resourcemanager.WorkerInfo; import org.apache.geaflow.runtime.core.scheduler.cycle.IExecutionCycle; public interface IScheduledWorkerManager { - /** - * Init the worker manager by input graph. - * The graph info will help to decide the total worker resources required. - */ - void init(G graph); - - /** - * Assign workers for execution task of input graph. - * - * @return Workers if assign worker succeed, otherwise empty. - */ - List assign(G graph); + /** + * Init the worker manager by input graph. The graph info will help to decide the total worker + * resources required. + */ + void init(G graph); - /** - * Release all worker resource for the input graph. - */ - void release(G graph); + /** + * Assign workers for execution task of input graph. + * + * @return Workers if assign worker succeed, otherwise empty. + */ + List assign(G graph); - /** - * Clean worker runtime context for used workers by specified clean function. - */ - void clean(CleanWorkerFunction cleaFunc, IExecutionCycle cycle); + /** Release all worker resource for the input graph. */ + void release(G graph); - /** - * Release all worker to master resource manager. - */ - void close(IExecutionCycle cycle); + /** Clean worker runtime context for used workers by specified clean function. */ + void clean(CleanWorkerFunction cleaFunc, IExecutionCycle cycle); - /** - * Function interface to clean runtime context for already assigned workers. - */ - interface CleanWorkerFunction { + /** Release all worker to master resource manager. */ + void close(IExecutionCycle cycle); - void clean(List assignedWorkers); + /** Function interface to clean runtime context for already assigned workers. */ + interface CleanWorkerFunction { - } + void clean(List assignedWorkers); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/RedoCycleScheduledWorkerManager.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/RedoCycleScheduledWorkerManager.java index 7d64ccc6b..13327e7d6 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/RedoCycleScheduledWorkerManager.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/RedoCycleScheduledWorkerManager.java @@ -23,8 +23,7 @@ public class RedoCycleScheduledWorkerManager extends AbstractScheduledWorkerManager { - protected RedoCycleScheduledWorkerManager(Configuration config) { - super(config); - } - + protected RedoCycleScheduledWorkerManager(Configuration config) { + super(config); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/ScheduledWorkerManagerFactory.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/ScheduledWorkerManagerFactory.java index 530a3c552..56f44ddeb 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/ScheduledWorkerManagerFactory.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/ScheduledWorkerManagerFactory.java @@ -19,8 +19,8 @@ package org.apache.geaflow.runtime.core.scheduler.resource; -import com.google.common.annotations.VisibleForTesting; import java.util.Map; + import org.apache.geaflow.cluster.resourcemanager.ReleaseResourceRequest; import org.apache.geaflow.cluster.resourcemanager.ResourceInfo; import org.apache.geaflow.cluster.rpc.RpcClient; @@ -31,78 +31,82 @@ import org.apache.geaflow.runtime.core.scheduler.cycle.ExecutionGraphCycle; import org.apache.geaflow.runtime.core.scheduler.cycle.IExecutionCycle; -public class ScheduledWorkerManagerFactory { - - private static volatile RedoCycleScheduledWorkerManager redoWorkerManager; - private static volatile CheckpointCycleScheduledWorkerManager checkpointWorkerManager; +import com.google.common.annotations.VisibleForTesting; - public static > WM createScheduledWorkerManager( - Configuration config, HighAvailableLevel level) { - switch (level) { - case REDO: - if (redoWorkerManager == null) { - synchronized (ScheduledWorkerManagerFactory.class) { - if (redoWorkerManager == null) { - redoWorkerManager = new RedoCycleScheduledWorkerManager(config); - } - } - } - return (WM) redoWorkerManager; - case CHECKPOINT: - if (checkpointWorkerManager == null) { - synchronized (ScheduledWorkerManagerFactory.class) { - if (checkpointWorkerManager == null) { - checkpointWorkerManager = new CheckpointCycleScheduledWorkerManager(config); - } - } - } - return (WM) checkpointWorkerManager; - default: - throw new GeaflowRuntimeException("not support worker manager type " + level); - } +public class ScheduledWorkerManagerFactory { - } + private static volatile RedoCycleScheduledWorkerManager redoWorkerManager; + private static volatile CheckpointCycleScheduledWorkerManager checkpointWorkerManager; - @VisibleForTesting - public static synchronized void clear() { - if (redoWorkerManager != null) { - clear(redoWorkerManager); - redoWorkerManager = null; + public static > + WM createScheduledWorkerManager(Configuration config, HighAvailableLevel level) { + switch (level) { + case REDO: + if (redoWorkerManager == null) { + synchronized (ScheduledWorkerManagerFactory.class) { + if (redoWorkerManager == null) { + redoWorkerManager = new RedoCycleScheduledWorkerManager(config); + } + } } - if (checkpointWorkerManager != null) { - clear(checkpointWorkerManager); - checkpointWorkerManager = null; + return (WM) redoWorkerManager; + case CHECKPOINT: + if (checkpointWorkerManager == null) { + synchronized (ScheduledWorkerManagerFactory.class) { + if (checkpointWorkerManager == null) { + checkpointWorkerManager = new CheckpointCycleScheduledWorkerManager(config); + } + } } + return (WM) checkpointWorkerManager; + default: + throw new GeaflowRuntimeException("not support worker manager type " + level); } + } - private static void clear(AbstractScheduledWorkerManager workerManager) { - if (workerManager.workers != null) { - for (Map.Entry workerEntry : workerManager.workers.entrySet()) { - RpcClient.getInstance().releaseResource(workerManager.masterId, - ReleaseResourceRequest.build(workerEntry.getValue().getResourceId(), workerEntry.getValue().getWorkers())); - } - } + @VisibleForTesting + public static synchronized void clear() { + if (redoWorkerManager != null) { + clear(redoWorkerManager); + redoWorkerManager = null; + } + if (checkpointWorkerManager != null) { + clear(checkpointWorkerManager); + checkpointWorkerManager = null; } + } - public static HighAvailableLevel getWorkerManagerHALevel(IExecutionCycle cycle) { - if (cycle.getType() == ExecutionCycleType.GRAPH) { - ExecutionGraphCycle graph = (ExecutionGraphCycle) cycle; - if (graph.getHighAvailableLevel() == HighAvailableLevel.CHECKPOINT) { - return HighAvailableLevel.CHECKPOINT; - } - // As for stream case, the whole graph is REDO ha level while child cycle is CHECKPOINT. - // We need set worker manager ha level to CHECKPOINT - // to make sure all request worker initialized with CHECKPOINT level. - if (graph.getCycleMap().size() == 1) { - IExecutionCycle child = graph.getCycleMap().values().iterator().next(); - if (child.getHighAvailableLevel() == HighAvailableLevel.CHECKPOINT) { - return HighAvailableLevel.CHECKPOINT; - } - } - return HighAvailableLevel.REDO; + private static void clear(AbstractScheduledWorkerManager workerManager) { + if (workerManager.workers != null) { + for (Map.Entry workerEntry : workerManager.workers.entrySet()) { + RpcClient.getInstance() + .releaseResource( + workerManager.masterId, + ReleaseResourceRequest.build( + workerEntry.getValue().getResourceId(), workerEntry.getValue().getWorkers())); + } + } + } - } else { - return cycle.getHighAvailableLevel(); + public static HighAvailableLevel getWorkerManagerHALevel(IExecutionCycle cycle) { + if (cycle.getType() == ExecutionCycleType.GRAPH) { + ExecutionGraphCycle graph = (ExecutionGraphCycle) cycle; + if (graph.getHighAvailableLevel() == HighAvailableLevel.CHECKPOINT) { + return HighAvailableLevel.CHECKPOINT; + } + // As for stream case, the whole graph is REDO ha level while child cycle is CHECKPOINT. + // We need set worker manager ha level to CHECKPOINT + // to make sure all request worker initialized with CHECKPOINT level. + if (graph.getCycleMap().size() == 1) { + IExecutionCycle child = graph.getCycleMap().values().iterator().next(); + if (child.getHighAvailableLevel() == HighAvailableLevel.CHECKPOINT) { + return HighAvailableLevel.CHECKPOINT; } + } + return HighAvailableLevel.REDO; + + } else { + return cycle.getHighAvailableLevel(); } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/TaskAssigner.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/TaskAssigner.java index a8d53dec0..e2d236b82 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/TaskAssigner.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/resource/TaskAssigner.java @@ -31,170 +31,178 @@ import java.util.Map; import java.util.Queue; import java.util.Set; + import org.apache.geaflow.cluster.resourcemanager.WorkerInfo; public class TaskAssigner implements Serializable { - private Map>> historyBindings; + private Map>> historyBindings; - public TaskAssigner() { - if (this.historyBindings == null) { - this.historyBindings = new HashMap<>(); - } + public TaskAssigner() { + if (this.historyBindings == null) { + this.historyBindings = new HashMap<>(); } - - /** - * The matching of tasks and workers is a bipartite graph matching problem. The Task - * Assigner use the Hopcroft–Karp algorithm for bipartite graph matching, with a time - * complexity of O(e√v), where e is the number of edges and v is the number of vertices. - * - * @return the matching of tasks and workers. - */ - public Map assignTasks2Workers(String graphName, List tasks, - List workers) { - if (tasks.size() != workers.size()) { - throw new IllegalArgumentException( - "Tasks and workers queues must have the same length."); - } - if (DEFAULT_GRAPH_VIEW_NAME.equals(graphName)) { - Map matches = new HashMap<>(); - Iterator workerIterator = workers.iterator(); - for (Integer task : tasks) { - if (workerIterator.hasNext()) { - WorkerInfo worker = workerIterator.next(); - matches.put(task, worker); - } - } - return matches; - } - - Map> historyBinding = historyBindings.computeIfAbsent(graphName, - k -> new HashMap<>()); - // Step 1: Build bipartite graph based on history. - Map> graph = new HashMap<>(); - for ( - Integer task : tasks) { - graph.put(task, new ArrayList<>()); - if (historyBinding.containsKey(task)) { - for (WorkerInfo worker : historyBinding.get(task)) { - if (workers.contains(worker)) { - graph.get(task).add(worker); - } - } - } + } + + /** + * The matching of tasks and workers is a bipartite graph matching problem. The Task Assigner use + * the Hopcroft–Karp algorithm for bipartite graph matching, with a time complexity of O(e√v), + * where e is the number of edges and v is the number of vertices. + * + * @return the matching of tasks and workers. + */ + public Map assignTasks2Workers( + String graphName, List tasks, List workers) { + if (tasks.size() != workers.size()) { + throw new IllegalArgumentException("Tasks and workers queues must have the same length."); + } + if (DEFAULT_GRAPH_VIEW_NAME.equals(graphName)) { + Map matches = new HashMap<>(); + Iterator workerIterator = workers.iterator(); + for (Integer task : tasks) { + if (workerIterator.hasNext()) { + WorkerInfo worker = workerIterator.next(); + matches.put(task, worker); } + } + return matches; + } - // Step 2: Use Hopcroft-Karp to find maximum matching. - Map matches = hopcroftKarp(tasks, workers, graph); - - // Step 3: Assign remaining workers to tasks without a match. - Set unmatchedTasks = new HashSet<>(tasks); - unmatchedTasks.removeAll(matches.keySet()); - Set unmatchedWorkers = new HashSet<>(workers); - unmatchedWorkers.removeAll(matches.values()); - - Iterator workerIterator = unmatchedWorkers.iterator(); - for (Integer task : unmatchedTasks) { - if (workerIterator.hasNext()) { - WorkerInfo worker = workerIterator.next(); - matches.put(task, worker); - } + Map> historyBinding = + historyBindings.computeIfAbsent(graphName, k -> new HashMap<>()); + // Step 1: Build bipartite graph based on history. + Map> graph = new HashMap<>(); + for (Integer task : tasks) { + graph.put(task, new ArrayList<>()); + if (historyBinding.containsKey(task)) { + for (WorkerInfo worker : historyBinding.get(task)) { + if (workers.contains(worker)) { + graph.get(task).add(worker); + } } + } + } - // Update history bindings. - for (Map.Entry entry : matches.entrySet()) { - Integer task = entry.getKey(); - WorkerInfo worker = entry.getValue(); - historyBinding.putIfAbsent(task, new HashSet<>()); - historyBinding.get(task).add(worker); - } - return matches; + // Step 2: Use Hopcroft-Karp to find maximum matching. + Map matches = hopcroftKarp(tasks, workers, graph); + + // Step 3: Assign remaining workers to tasks without a match. + Set unmatchedTasks = new HashSet<>(tasks); + unmatchedTasks.removeAll(matches.keySet()); + Set unmatchedWorkers = new HashSet<>(workers); + unmatchedWorkers.removeAll(matches.values()); + + Iterator workerIterator = unmatchedWorkers.iterator(); + for (Integer task : unmatchedTasks) { + if (workerIterator.hasNext()) { + WorkerInfo worker = workerIterator.next(); + matches.put(task, worker); + } } - private Map hopcroftKarp(List tasks, List workers, - Map> graph) { - Map pairU = new HashMap<>(); - Map pairV = new HashMap<>(); - Map dist = new HashMap<>(); + // Update history bindings. + for (Map.Entry entry : matches.entrySet()) { + Integer task = entry.getKey(); + WorkerInfo worker = entry.getValue(); + historyBinding.putIfAbsent(task, new HashSet<>()); + historyBinding.get(task).add(worker); + } + return matches; + } - for (Integer task : tasks) { - pairU.put(task, null); - } - for (WorkerInfo worker : workers) { - pairV.put(worker, null); - } + private Map hopcroftKarp( + List tasks, List workers, Map> graph) { + Map pairU = new HashMap<>(); + Map pairV = new HashMap<>(); + Map dist = new HashMap<>(); - int inf = Integer.MAX_VALUE; + for (Integer task : tasks) { + pairU.put(task, null); + } + for (WorkerInfo worker : workers) { + pairV.put(worker, null); + } - while (bfs(tasks, pairU, pairV, dist, graph, inf)) { - for (Integer task : tasks) { - if (pairU.get(task) == null) { - // DFS will update pairU and pairV. - dfs(task, pairU, pairV, dist, graph, inf); - } - } - } + int inf = Integer.MAX_VALUE; - Map matches = new HashMap<>(); - for (Integer task : pairU.keySet()) { - if (pairU.get(task) != null) { - matches.put(task, pairU.get(task)); - } + while (bfs(tasks, pairU, pairV, dist, graph, inf)) { + for (Integer task : tasks) { + if (pairU.get(task) == null) { + // DFS will update pairU and pairV. + dfs(task, pairU, pairV, dist, graph, inf); } + } + } - return matches; + Map matches = new HashMap<>(); + for (Integer task : pairU.keySet()) { + if (pairU.get(task) != null) { + matches.put(task, pairU.get(task)); + } } - private boolean bfs(List tasks, Map pairU, - Map pairV, Map dist, - Map> graph, int inf) { - Queue queue = new LinkedList<>(); - - for (Integer task : tasks) { - if (pairU.get(task) == null) { - dist.put(task, 0); - queue.add(task); - } else { - dist.put(task, inf); - } - } + return matches; + } + + private boolean bfs( + List tasks, + Map pairU, + Map pairV, + Map dist, + Map> graph, + int inf) { + Queue queue = new LinkedList<>(); + + for (Integer task : tasks) { + if (pairU.get(task) == null) { + dist.put(task, 0); + queue.add(task); + } else { + dist.put(task, inf); + } + } - boolean hasAugmentingPath = false; - - while (!queue.isEmpty()) { - Integer task = queue.poll(); - if (dist.get(task) < inf) { - for (WorkerInfo worker : graph.get(task)) { - Integer nextTask = pairV.get(worker); - if (nextTask == null) { - hasAugmentingPath = true; - } else if (dist.get(nextTask) == inf) { - dist.put(nextTask, dist.get(task) + 1); - queue.add(nextTask); - } - } - } + boolean hasAugmentingPath = false; + + while (!queue.isEmpty()) { + Integer task = queue.poll(); + if (dist.get(task) < inf) { + for (WorkerInfo worker : graph.get(task)) { + Integer nextTask = pairV.get(worker); + if (nextTask == null) { + hasAugmentingPath = true; + } else if (dist.get(nextTask) == inf) { + dist.put(nextTask, dist.get(task) + 1); + queue.add(nextTask); + } } - - return hasAugmentingPath; + } } - private boolean dfs(Integer task, Map pairU, - Map pairV, Map dist, - Map> graph, int inf) { - if (task != null) { - for (WorkerInfo worker : graph.get(task)) { - Integer nextTask = pairV.get(worker); - if (nextTask == null || (dist.get(nextTask) == dist.get(task) + 1 && dfs(nextTask, - pairU, pairV, dist, graph, inf))) { - pairV.put(worker, task); - pairU.put(task, worker); - return true; - } - } - dist.put(task, inf); - return false; + return hasAugmentingPath; + } + + private boolean dfs( + Integer task, + Map pairU, + Map pairV, + Map dist, + Map> graph, + int inf) { + if (task != null) { + for (WorkerInfo worker : graph.get(task)) { + Integer nextTask = pairV.get(worker); + if (nextTask == null + || (dist.get(nextTask) == dist.get(task) + 1 + && dfs(nextTask, pairU, pairV, dist, graph, inf))) { + pairV.put(worker, task); + pairU.put(task, worker); + return true; } - return true; + } + dist.put(task, inf); + return false; } + return true; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/response/AbstractFixedSizeEventHandler.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/response/AbstractFixedSizeEventHandler.java index 72811e180..4973f6111 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/response/AbstractFixedSizeEventHandler.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/response/AbstractFixedSizeEventHandler.java @@ -20,64 +20,59 @@ package org.apache.geaflow.runtime.core.scheduler.response; import java.util.Collection; + import org.apache.geaflow.cluster.common.IEventListener; import org.apache.geaflow.cluster.protocol.IEvent; public abstract class AbstractFixedSizeEventHandler implements IEventListener { - protected int expectedSize; - private IEventCompletedHandler handler; - private ResponseEventCache eventCache; + protected int expectedSize; + private IEventCompletedHandler handler; + private ResponseEventCache eventCache; - public AbstractFixedSizeEventHandler(int expectedSize, IEventCompletedHandler handler) { - this.expectedSize = expectedSize; - this.handler = handler; - this.eventCache = buildEventCache(); - } + public AbstractFixedSizeEventHandler(int expectedSize, IEventCompletedHandler handler) { + this.expectedSize = expectedSize; + this.handler = handler; + this.eventCache = buildEventCache(); + } - @Override - public void handleEvent(IEvent event) { - eventCache.add(event); - if (eventCache.size() == expectedSize) { - if (handler != null) { - handler.onCompleted(eventCache.values()); - } - } + @Override + public void handleEvent(IEvent event) { + eventCache.add(event); + if (eventCache.size() == expectedSize) { + if (handler != null) { + handler.onCompleted(eventCache.values()); + } } + } + + abstract ResponseEventCache buildEventCache(); - abstract ResponseEventCache buildEventCache(); + /** All finished event cache. */ + public interface ResponseEventCache { /** - * All finished event cache. + * Add event to cache. + * + * @param event need add to cache. */ - public interface ResponseEventCache { + void add(IEvent event); - /** - * Add event to cache. - * @param event need add to cache. - */ - void add(IEvent event); + /** Return the cached size of current events. */ + int size(); - /** - * Return the cached size of current events. - */ - int size(); + /** Return all cached values. */ + Collection values(); + } - /** - * Return all cached values. - */ - Collection values(); - } + /** Callback function when all events completed as expected. */ + public interface IEventCompletedHandler { /** - * Callback function when all events completed as expected. + * Do callback when received all events. + * + * @param events */ - public interface IEventCompletedHandler { - - /** - * Do callback when received all events. - * @param events - */ - void onCompleted(Collection events); - } + void onCompleted(Collection events); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/response/ComputeFinishEventListener.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/response/ComputeFinishEventListener.java index c6d353b1d..fbe0542d6 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/response/ComputeFinishEventListener.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/response/ComputeFinishEventListener.java @@ -22,41 +22,41 @@ import java.util.Collection; import java.util.Queue; import java.util.concurrent.LinkedBlockingQueue; + import org.apache.geaflow.cluster.protocol.IEvent; public class ComputeFinishEventListener extends AbstractFixedSizeEventHandler { - public ComputeFinishEventListener(int eventCount, IEventCompletedHandler handler) { - super(eventCount, handler); - } + public ComputeFinishEventListener(int eventCount, IEventCompletedHandler handler) { + super(eventCount, handler); + } - @Override - ResponseEventCache buildEventCache() { - return new EventCache(expectedSize); - } + @Override + ResponseEventCache buildEventCache() { + return new EventCache(expectedSize); + } - public class EventCache implements ResponseEventCache { + public class EventCache implements ResponseEventCache { - private Queue events; + private Queue events; - public EventCache(int capacity) { - this.events = new LinkedBlockingQueue<>(capacity); - } + public EventCache(int capacity) { + this.events = new LinkedBlockingQueue<>(capacity); + } - @Override - public void add(IEvent event) { - events.add(event); - } + @Override + public void add(IEvent event) { + events.add(event); + } - @Override - public int size() { - return events.size(); - } + @Override + public int size() { + return events.size(); + } - @Override - public Collection values() { - return events; - } + @Override + public Collection values() { + return events; } + } } - diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/response/EventListenerKey.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/response/EventListenerKey.java index 3c23e3505..5eda1013b 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/response/EventListenerKey.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/response/EventListenerKey.java @@ -20,57 +20,61 @@ package org.apache.geaflow.runtime.core.scheduler.response; import java.util.Objects; + import org.apache.geaflow.cluster.protocol.EventType; public class EventListenerKey { - private static final int DUMMY_WINDOW_ID = 0; + private static final int DUMMY_WINDOW_ID = 0; - private int cycleId; - private EventType eventType; - private long windowId; + private int cycleId; + private EventType eventType; + private long windowId; - private EventListenerKey(int cycleId, long windowId, EventType eventType) { - this.cycleId = cycleId; - this.windowId = windowId; - this.eventType = eventType; - } + private EventListenerKey(int cycleId, long windowId, EventType eventType) { + this.cycleId = cycleId; + this.windowId = windowId; + this.eventType = eventType; + } - public static EventListenerKey of(int cycleId) { - return new EventListenerKey(cycleId, DUMMY_WINDOW_ID, null); - } + public static EventListenerKey of(int cycleId) { + return new EventListenerKey(cycleId, DUMMY_WINDOW_ID, null); + } - public static EventListenerKey of(int cycleId, EventType eventType) { - return new EventListenerKey(cycleId, DUMMY_WINDOW_ID, eventType); - } + public static EventListenerKey of(int cycleId, EventType eventType) { + return new EventListenerKey(cycleId, DUMMY_WINDOW_ID, eventType); + } - public static EventListenerKey of(int cycleId, EventType eventType, long windowId) { - return new EventListenerKey(cycleId, windowId, eventType); - } + public static EventListenerKey of(int cycleId, EventType eventType, long windowId) { + return new EventListenerKey(cycleId, windowId, eventType); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - EventListenerKey that = (EventListenerKey) o; - return cycleId == that.cycleId && windowId == that.windowId && eventType == that.eventType; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(cycleId, eventType, windowId); + if (o == null || getClass() != o.getClass()) { + return false; } + EventListenerKey that = (EventListenerKey) o; + return cycleId == that.cycleId && windowId == that.windowId && eventType == that.eventType; + } - @Override - public String toString() { - return "EventListenerKey{" - + "cycleId=" + cycleId - + ", eventType=" + eventType - + ", windowId=" + windowId - + '}'; - } -} \ No newline at end of file + @Override + public int hashCode() { + return Objects.hash(cycleId, eventType, windowId); + } + + @Override + public String toString() { + return "EventListenerKey{" + + "cycleId=" + + cycleId + + ", eventType=" + + eventType + + ", windowId=" + + windowId + + '}'; + } +} diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/response/SourceFinishResponseEventListener.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/response/SourceFinishResponseEventListener.java index f3a2e41cb..c6b9d6730 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/response/SourceFinishResponseEventListener.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/response/SourceFinishResponseEventListener.java @@ -22,47 +22,48 @@ import java.util.Collection; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.cluster.protocol.IEvent; import org.apache.geaflow.runtime.core.protocol.DoneEvent; public class SourceFinishResponseEventListener extends AbstractFixedSizeEventHandler { - private int eventCount; + private int eventCount; - public SourceFinishResponseEventListener(int eventCount, IEventCompletedHandler handler) { - super(eventCount, handler); - this.eventCount = eventCount; - } + public SourceFinishResponseEventListener(int eventCount, IEventCompletedHandler handler) { + super(eventCount, handler); + this.eventCount = eventCount; + } - @Override - ResponseEventCache buildEventCache() { - return new EventCache(eventCount); - } + @Override + ResponseEventCache buildEventCache() { + return new EventCache(eventCount); + } - public class EventCache implements ResponseEventCache { + public class EventCache implements ResponseEventCache { - private Map events; + private Map events; - public EventCache(int capacity) { - this.events = new ConcurrentHashMap<>(capacity); - } + public EventCache(int capacity) { + this.events = new ConcurrentHashMap<>(capacity); + } - @Override - public void add(IEvent event) { - int taskId = ((DoneEvent) event).getTaskId(); - if (!events.containsKey(taskId)) { - events.put(taskId, event); - } - } + @Override + public void add(IEvent event) { + int taskId = ((DoneEvent) event).getTaskId(); + if (!events.containsKey(taskId)) { + events.put(taskId, event); + } + } - @Override - public int size() { - return events.size(); - } + @Override + public int size() { + return events.size(); + } - @Override - public Collection values() { - return events.values(); - } + @Override + public Collection values() { + return events.values(); } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/result/ExecutionResult.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/result/ExecutionResult.java index bccc5cfeb..f1381e493 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/result/ExecutionResult.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/result/ExecutionResult.java @@ -21,40 +21,40 @@ public class ExecutionResult implements IExecutionResult { - // Result of the execution. - private R result; - // Error object of the execution. - private E error; - // Check whether the execution is success. - // If true, then the result should not empty, else the error should not empty. - private boolean isSuccess; - - public ExecutionResult(R result, E error, boolean isSuccess) { - this.result = result; - this.error = error; - this.isSuccess = isSuccess; - } - - @Override - public R getResult() { - return result; - } - - @Override - public E getError() { - return error; - } - - @Override - public boolean isSuccess() { - return isSuccess; - } - - public static ExecutionResult buildSuccessResult(R result) { - return new ExecutionResult<>(result, null, true); - } - - public static ExecutionResult buildFailedResult(E error) { - return new ExecutionResult<>(null, error, false); - } + // Result of the execution. + private R result; + // Error object of the execution. + private E error; + // Check whether the execution is success. + // If true, then the result should not empty, else the error should not empty. + private boolean isSuccess; + + public ExecutionResult(R result, E error, boolean isSuccess) { + this.result = result; + this.error = error; + this.isSuccess = isSuccess; + } + + @Override + public R getResult() { + return result; + } + + @Override + public E getError() { + return error; + } + + @Override + public boolean isSuccess() { + return isSuccess; + } + + public static ExecutionResult buildSuccessResult(R result) { + return new ExecutionResult<>(result, null, true); + } + + public static ExecutionResult buildFailedResult(E error) { + return new ExecutionResult<>(null, error, false); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/result/IExecutionResult.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/result/IExecutionResult.java index 853356a7b..b4f21f7d5 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/result/IExecutionResult.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/result/IExecutionResult.java @@ -21,23 +21,19 @@ public interface IExecutionResult { - /** - * Get execution result. - * - * @return - */ - R getResult(); + /** + * Get execution result. + * + * @return + */ + R getResult(); - /** - * Returns execution error. - */ - E getError(); - - /** - * Check whether the execution is successful. - * if true, then can get result from {@link #getResult()} - * otherwise, then can get result from {@link #getError()} - */ - boolean isSuccess(); + /** Returns execution error. */ + E getError(); + /** + * Check whether the execution is successful. if true, then can get result from {@link + * #getResult()} otherwise, then can get result from {@link #getError()} + */ + boolean isSuccess(); } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/AbstractStateMachine.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/AbstractStateMachine.java index c3d2c672d..117b70fad 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/AbstractStateMachine.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/AbstractStateMachine.java @@ -23,73 +23,71 @@ import org.apache.geaflow.runtime.core.scheduler.context.AbstractCycleSchedulerContext; import org.apache.geaflow.runtime.core.scheduler.context.ICycleSchedulerContext; -public abstract class AbstractStateMachine implements IStateMachine { +public abstract class AbstractStateMachine + implements IStateMachine { - protected static final ScheduleState START = ScheduleState.of(ScheduleStateType.START); - protected static final ScheduleState END = ScheduleState.of(ScheduleStateType.END); - protected StateMachineManager stateMachineManager; - protected ScheduleState currentState; - protected ICycleSchedulerContext context; + protected static final ScheduleState START = ScheduleState.of(ScheduleStateType.START); + protected static final ScheduleState END = ScheduleState.of(ScheduleStateType.END); + protected StateMachineManager stateMachineManager; + protected ScheduleState currentState; + protected ICycleSchedulerContext context; + @Override + public void init(ICycleSchedulerContext context) { + this.context = context; + this.currentState = START; + ((AbstractCycleSchedulerContext) this.context) + .setCurrentIterationId(context.getInitialIterationId()); + this.stateMachineManager = new StateMachineManager(); + } - @Override - public void init(ICycleSchedulerContext context) { - this.context = context; - this.currentState = START; - ((AbstractCycleSchedulerContext) this.context).setCurrentIterationId(context.getInitialIterationId()); - this.stateMachineManager = new StateMachineManager(); + @Override + public IScheduleState readyToTransition() { + if (currentState != END) { + return transition(); } + return END; + } - @Override - public IScheduleState readyToTransition() { - if (currentState != END) { - return transition(); - } - return END; - } + @Override + public ScheduleState getCurrentState() { + return currentState; + } - @Override - public ScheduleState getCurrentState() { - return currentState; - } + /** Transition base on the current context. */ + @Override + public IScheduleState transition() { + return transition(currentState); + } - /** - * Transition base on the current context. - */ - @Override - public IScheduleState transition() { - return transition(currentState); - } - - @Override - public boolean isTerminated() { - return currentState == END; - } + @Override + public boolean isTerminated() { + return currentState == END; + } - /** - * Get a list of state transition path after apply a sequence of transition from the input - * source. - */ - private ScheduleState transition(ScheduleState source) { - ScheduleState target = stateMachineManager.transition(source, context); - if (target != null) { - if (END == target) { - currentState = END; - return END; - } + /** + * Get a list of state transition path after apply a sequence of transition from the input source. + */ + private ScheduleState transition(ScheduleState source) { + ScheduleState target = stateMachineManager.transition(source, context); + if (target != null) { + if (END == target) { + currentState = END; + return END; + } - currentState = ScheduleState.of(target.getScheduleStateType()); - return currentState; - } - return null; + currentState = ScheduleState.of(target.getScheduleStateType()); + return currentState; } + return null; + } - public static class ComputeTransitionCondition - implements ITransitionCondition { + public static class ComputeTransitionCondition + implements ITransitionCondition { - @Override - public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { - return context.hasNextIteration(); - } + @Override + public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { + return context.hasNextIteration(); } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/ComposeState.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/ComposeState.java index 1df69a1ce..5c7448761 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/ComposeState.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/ComposeState.java @@ -20,34 +20,32 @@ package org.apache.geaflow.runtime.core.scheduler.statemachine; import java.util.List; + import org.apache.geaflow.cluster.protocol.ScheduleStateType; public class ComposeState implements IScheduleState { - private List states; - - public ComposeState(List states) { - this.states = states; - } + private List states; - public static ComposeState of(List states) { - return new ComposeState(states); - } + public ComposeState(List states) { + this.states = states; + } - public List getStates() { - return states; - } + public static ComposeState of(List states) { + return new ComposeState(states); + } - @Override - public ScheduleStateType getScheduleStateType() { - return ScheduleStateType.COMPOSE; - } + public List getStates() { + return states; + } - @Override - public String toString() { - return "compose{" - + states - + '}'; - } + @Override + public ScheduleStateType getScheduleStateType() { + return ScheduleStateType.COMPOSE; + } + @Override + public String toString() { + return "compose{" + states + '}'; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/IScheduleState.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/IScheduleState.java index b028bb26d..4b56b4122 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/IScheduleState.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/IScheduleState.java @@ -23,6 +23,5 @@ public interface IScheduleState { - ScheduleStateType getScheduleStateType(); - + ScheduleStateType getScheduleStateType(); } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/IStateMachine.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/IStateMachine.java index 749da7750..5e081a656 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/IStateMachine.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/IStateMachine.java @@ -21,29 +21,18 @@ public interface IStateMachine { - /** - * Init state machine by context. - */ - void init(C context); + /** Init state machine by context. */ + void init(C context); - /** - * Check whether state machine can do transition or waiting result. - */ - IScheduleState readyToTransition(); + /** Check whether state machine can do transition or waiting result. */ + IScheduleState readyToTransition(); - /** - * Trigger a transition to certain target state. - */ - S transition(); + /** Trigger a transition to certain target state. */ + S transition(); - /** - * Check whether the state machine is reach to END state. - */ - boolean isTerminated(); - - /** - * Returns current state. - */ - ScheduleState getCurrentState(); + /** Check whether the state machine is reach to END state. */ + boolean isTerminated(); + /** Returns current state. */ + ScheduleState getCurrentState(); } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/ITransition.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/ITransition.java index bc7b26cfc..c7b49e799 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/ITransition.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/ITransition.java @@ -22,9 +22,6 @@ @FunctionalInterface public interface ITransition { - /** - * Transfer source state to a target state by context. - */ - S transition(S source, C context); - -} \ No newline at end of file + /** Transfer source state to a target state by context. */ + S transition(S source, C context); +} diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/ITransitionCondition.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/ITransitionCondition.java index 9b03910b9..d15399942 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/ITransitionCondition.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/ITransitionCondition.java @@ -21,6 +21,5 @@ public interface ITransitionCondition { - boolean predicate(S state, C context); - -} \ No newline at end of file + boolean predicate(S state, C context); +} diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/ScheduleState.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/ScheduleState.java index 043a1740c..c8073fa55 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/ScheduleState.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/ScheduleState.java @@ -20,44 +20,45 @@ package org.apache.geaflow.runtime.core.scheduler.statemachine; import java.util.Objects; + import org.apache.geaflow.cluster.protocol.ScheduleStateType; public class ScheduleState implements IScheduleState { - private ScheduleStateType stateType; + private ScheduleStateType stateType; - public ScheduleState(ScheduleStateType stateType) { - this.stateType = stateType; - } + public ScheduleState(ScheduleStateType stateType) { + this.stateType = stateType; + } - public static ScheduleState of(ScheduleStateType stateType) { - return new ScheduleState(stateType); - } + public static ScheduleState of(ScheduleStateType stateType) { + return new ScheduleState(stateType); + } - @Override - public ScheduleStateType getScheduleStateType() { - return stateType; - } + @Override + public ScheduleStateType getScheduleStateType() { + return stateType; + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - ScheduleState state = (ScheduleState) o; - return stateType == state.stateType; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(stateType); + if (o == null || getClass() != o.getClass()) { + return false; } + ScheduleState state = (ScheduleState) o; + return stateType == state.stateType; + } - @Override - public String toString() { - return stateType.name(); - } + @Override + public int hashCode() { + return Objects.hash(stateType); + } + + @Override + public String toString() { + return stateType.name(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/StateMachineManager.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/StateMachineManager.java index 9752130d8..90e4e9574 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/StateMachineManager.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/StateMachineManager.java @@ -26,49 +26,47 @@ public class StateMachineManager { - private Map>> stateToTransitionMap = new HashMap<>(); + private Map>> stateToTransitionMap = new HashMap<>(); - private static final ITransitionCondition ALWAYS = (s, c) -> true; + private static final ITransitionCondition ALWAYS = (s, c) -> true; - public void addTransition(S source, S target) { - this.addTransition(source, target, ALWAYS); - } + public void addTransition(S source, S target) { + this.addTransition(source, target, ALWAYS); + } - /** - * Add transition from source state to target state. - * The transition condition evaluation by added order. - * For example, denote as source -> target : condition. - * If s1 -> t1 : c1 is already added, then add s1 -> t2 : c2. - * If (c1 matches) return t1 - * else if (c2 matches) return t2 - * else return null; - */ - public void addTransition(S source, S target, ITransitionCondition condition) { - if (!stateToTransitionMap.containsKey(source)) { - stateToTransitionMap.put(source, new ArrayList<>()); - } - stateToTransitionMap.get(source).add((s, c) -> { - if (condition.predicate(s, c)) { + /** + * Add transition from source state to target state. The transition condition evaluation by added + * order. For example, denote as source -> target : condition. If s1 -> t1 : c1 is already added, + * then add s1 -> t2 : c2. If (c1 matches) return t1 else if (c2 matches) return t2 else return + * null; + */ + public void addTransition(S source, S target, ITransitionCondition condition) { + if (!stateToTransitionMap.containsKey(source)) { + stateToTransitionMap.put(source, new ArrayList<>()); + } + stateToTransitionMap + .get(source) + .add( + (s, c) -> { + if (condition.predicate(s, c)) { return target; - } else { + } else { return null; - } - }); - } + } + }); + } - /** - * Transition from source to certain target state by input context. - */ - public S transition(S source, C context) { - if (stateToTransitionMap.containsKey(source)) { - List> transitions = stateToTransitionMap.get(source); - for (ITransition t : transitions) { - S target = t.transition(source, context); - if (target != null) { - return target; - } - } + /** Transition from source to certain target state by input context. */ + public S transition(S source, C context) { + if (stateToTransitionMap.containsKey(source)) { + List> transitions = stateToTransitionMap.get(source); + for (ITransition t : transitions) { + S target = t.transition(source, context); + if (target != null) { + return target; } - return null; + } } + return null; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/graph/GraphStateMachine.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/graph/GraphStateMachine.java index 3df306682..f5b1c802c 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/graph/GraphStateMachine.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/graph/GraphStateMachine.java @@ -25,31 +25,32 @@ import org.apache.geaflow.runtime.core.scheduler.statemachine.ITransitionCondition; import org.apache.geaflow.runtime.core.scheduler.statemachine.ScheduleState; -/** - * Holds all state and transitions of the schedule state machine. - */ +/** Holds all state and transitions of the schedule state machine. */ public class GraphStateMachine extends AbstractStateMachine { - private static final ScheduleState EXECUTE_COMPUTE = ScheduleState.of(ScheduleStateType.EXECUTE_COMPUTE); + private static final ScheduleState EXECUTE_COMPUTE = + ScheduleState.of(ScheduleStateType.EXECUTE_COMPUTE); - @Override - public void init(ICycleSchedulerContext context) { - super.init(context); - // Build state machine. - // START -> EXECUTE_COMPUTE. - this.stateMachineManager.addTransition(START, EXECUTE_COMPUTE, new ComputeTransitionCondition()); - - // EXECUTE_COMPUTE -> CLEAN_PIPELINE | CLEAN_PIPELINE. - this.stateMachineManager.addTransition(EXECUTE_COMPUTE, EXECUTE_COMPUTE, new ComputeTransitionCondition()); - this.stateMachineManager.addTransition(EXECUTE_COMPUTE, END, new FinishTransitionCondition()); - } + @Override + public void init(ICycleSchedulerContext context) { + super.init(context); + // Build state machine. + // START -> EXECUTE_COMPUTE. + this.stateMachineManager.addTransition( + START, EXECUTE_COMPUTE, new ComputeTransitionCondition()); - public static class FinishTransitionCondition - implements ITransitionCondition { + // EXECUTE_COMPUTE -> CLEAN_PIPELINE | CLEAN_PIPELINE. + this.stateMachineManager.addTransition( + EXECUTE_COMPUTE, EXECUTE_COMPUTE, new ComputeTransitionCondition()); + this.stateMachineManager.addTransition(EXECUTE_COMPUTE, END, new FinishTransitionCondition()); + } - @Override - public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { - return context.isCycleFinished(); - } + public static class FinishTransitionCondition + implements ITransitionCondition { + + @Override + public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { + return context.isCycleFinished(); } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/pipeline/PipelineStateMachine.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/pipeline/PipelineStateMachine.java index 26fa0975e..798b552e1 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/pipeline/PipelineStateMachine.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/statemachine/pipeline/PipelineStateMachine.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; + import org.apache.geaflow.cluster.protocol.ScheduleStateType; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.runtime.core.scheduler.ExecutableEventIterator.ExecutableEvent; @@ -36,222 +37,247 @@ import org.apache.geaflow.runtime.core.scheduler.statemachine.ITransitionCondition; import org.apache.geaflow.runtime.core.scheduler.statemachine.ScheduleState; -/** - * Holds all state and transitions of the schedule state machine. - */ +/** Holds all state and transitions of the schedule state machine. */ public class PipelineStateMachine extends AbstractStateMachine { - private static final ScheduleState INIT = ScheduleState.of(ScheduleStateType.INIT); - private static final ScheduleState PREFETCH = ScheduleState.of(ScheduleStateType.PREFETCH); - private static final ScheduleState FINISH_PREFETCH = - ScheduleState.of(ScheduleStateType.FINISH_PREFETCH); - private static final ScheduleState ITERATION_INIT = ScheduleState.of(ScheduleStateType.ITERATION_INIT); - private static final ScheduleState EXECUTE_COMPUTE = ScheduleState.of(ScheduleStateType.EXECUTE_COMPUTE); - private static final ScheduleState ROLLBACK = ScheduleState.of(ScheduleStateType.ROLLBACK); - private static final ScheduleState ITERATION_FINISH = ScheduleState.of(ScheduleStateType.ITERATION_FINISH); - private static final ScheduleState CLEAN_CYCLE = ScheduleState.of(ScheduleStateType.CLEAN_CYCLE); - - @Override - public void init(ICycleSchedulerContext context) { - super.init(context); - - // Build state machine. - // START -> ROLLBACK | PREFETCH | INIT. - this.stateMachineManager.addTransition(START, ROLLBACK, new Start2RollbackTransitionCondition()); - this.stateMachineManager.addTransition(START, PREFETCH, new Start2PrefetchTransitionCondition()); - this.stateMachineManager.addTransition(START, INIT); - - // PREFETCH -> ITERATION_FINISH | INIT. - this.stateMachineManager.addTransition(PREFETCH, ITERATION_FINISH, new FinishTransitionCondition()); - this.stateMachineManager.addTransition(PREFETCH, INIT); - - // INIT -> ROLLBACK | ITERATION_INIT | EXECUTE_COMPUTE. - this.stateMachineManager.addTransition(INIT, ROLLBACK, new Init2RollbackTransitionCondition()); - this.stateMachineManager.addTransition(INIT, ITERATION_INIT, new InitIterationTransitionCondition()); - this.stateMachineManager.addTransition(INIT, EXECUTE_COMPUTE); - - // ROLLBACK -> ITERATION_INIT | EXECUTE_COMPUTE. - this.stateMachineManager.addTransition(ROLLBACK, ITERATION_INIT, new InitIterationTransitionCondition()); - this.stateMachineManager.addTransition(ROLLBACK, EXECUTE_COMPUTE); - - // ITERATION_INIT -> EXECUTE_COMPUTE | PREFETCH | ITERATION_FINISH. - this.stateMachineManager.addTransition(ITERATION_INIT, EXECUTE_COMPUTE, new ComputeTransitionCondition()); - this.stateMachineManager.addTransition(ITERATION_INIT, PREFETCH, new Compute2PrefetchTransitionCondition()); - this.stateMachineManager.addTransition(ITERATION_INIT, ITERATION_FINISH, new FinishTransitionCondition()); - - // EXECUTE_COMPUTE -> EXECUTE_COMPUTE | ITERATION_FINISH | FINISH_PREFETCH | CLEAN_CYCLE. - this.stateMachineManager.addTransition(EXECUTE_COMPUTE, EXECUTE_COMPUTE, new ComputeTransitionCondition()); - this.stateMachineManager.addTransition(EXECUTE_COMPUTE, PREFETCH, new Compute2PrefetchTransitionCondition()); - this.stateMachineManager.addTransition(EXECUTE_COMPUTE, ITERATION_FINISH, new FinishTransitionCondition()); - this.stateMachineManager.addTransition(EXECUTE_COMPUTE, FINISH_PREFETCH, - new Compute2FinishPrefetchTransitionCondition()); - this.stateMachineManager.addTransition(EXECUTE_COMPUTE, CLEAN_CYCLE, new CleanTransitionCondition()); - - // ITERATION_FINISH -> FINISH_PREFETCH | CLEAN_CYCLE. - this.stateMachineManager.addTransition(ITERATION_FINISH, FINISH_PREFETCH, - new FinishPrefetchTransitionCondition()); - this.stateMachineManager.addTransition(ITERATION_FINISH, CLEAN_CYCLE); - - // FINISH_PREFETCH -> CLEAN_CYCLE. - this.stateMachineManager.addTransition(FINISH_PREFETCH, CLEAN_CYCLE); - - // CLEAN_CYCLE -> END. - this.stateMachineManager.addTransition(CLEAN_CYCLE, END); + private static final ScheduleState INIT = ScheduleState.of(ScheduleStateType.INIT); + private static final ScheduleState PREFETCH = ScheduleState.of(ScheduleStateType.PREFETCH); + private static final ScheduleState FINISH_PREFETCH = + ScheduleState.of(ScheduleStateType.FINISH_PREFETCH); + private static final ScheduleState ITERATION_INIT = + ScheduleState.of(ScheduleStateType.ITERATION_INIT); + private static final ScheduleState EXECUTE_COMPUTE = + ScheduleState.of(ScheduleStateType.EXECUTE_COMPUTE); + private static final ScheduleState ROLLBACK = ScheduleState.of(ScheduleStateType.ROLLBACK); + private static final ScheduleState ITERATION_FINISH = + ScheduleState.of(ScheduleStateType.ITERATION_FINISH); + private static final ScheduleState CLEAN_CYCLE = ScheduleState.of(ScheduleStateType.CLEAN_CYCLE); + + @Override + public void init(ICycleSchedulerContext context) { + super.init(context); + + // Build state machine. + // START -> ROLLBACK | PREFETCH | INIT. + this.stateMachineManager.addTransition( + START, ROLLBACK, new Start2RollbackTransitionCondition()); + this.stateMachineManager.addTransition( + START, PREFETCH, new Start2PrefetchTransitionCondition()); + this.stateMachineManager.addTransition(START, INIT); + + // PREFETCH -> ITERATION_FINISH | INIT. + this.stateMachineManager.addTransition( + PREFETCH, ITERATION_FINISH, new FinishTransitionCondition()); + this.stateMachineManager.addTransition(PREFETCH, INIT); + + // INIT -> ROLLBACK | ITERATION_INIT | EXECUTE_COMPUTE. + this.stateMachineManager.addTransition(INIT, ROLLBACK, new Init2RollbackTransitionCondition()); + this.stateMachineManager.addTransition( + INIT, ITERATION_INIT, new InitIterationTransitionCondition()); + this.stateMachineManager.addTransition(INIT, EXECUTE_COMPUTE); + + // ROLLBACK -> ITERATION_INIT | EXECUTE_COMPUTE. + this.stateMachineManager.addTransition( + ROLLBACK, ITERATION_INIT, new InitIterationTransitionCondition()); + this.stateMachineManager.addTransition(ROLLBACK, EXECUTE_COMPUTE); + + // ITERATION_INIT -> EXECUTE_COMPUTE | PREFETCH | ITERATION_FINISH. + this.stateMachineManager.addTransition( + ITERATION_INIT, EXECUTE_COMPUTE, new ComputeTransitionCondition()); + this.stateMachineManager.addTransition( + ITERATION_INIT, PREFETCH, new Compute2PrefetchTransitionCondition()); + this.stateMachineManager.addTransition( + ITERATION_INIT, ITERATION_FINISH, new FinishTransitionCondition()); + + // EXECUTE_COMPUTE -> EXECUTE_COMPUTE | ITERATION_FINISH | FINISH_PREFETCH | CLEAN_CYCLE. + this.stateMachineManager.addTransition( + EXECUTE_COMPUTE, EXECUTE_COMPUTE, new ComputeTransitionCondition()); + this.stateMachineManager.addTransition( + EXECUTE_COMPUTE, PREFETCH, new Compute2PrefetchTransitionCondition()); + this.stateMachineManager.addTransition( + EXECUTE_COMPUTE, ITERATION_FINISH, new FinishTransitionCondition()); + this.stateMachineManager.addTransition( + EXECUTE_COMPUTE, FINISH_PREFETCH, new Compute2FinishPrefetchTransitionCondition()); + this.stateMachineManager.addTransition( + EXECUTE_COMPUTE, CLEAN_CYCLE, new CleanTransitionCondition()); + + // ITERATION_FINISH -> FINISH_PREFETCH | CLEAN_CYCLE. + this.stateMachineManager.addTransition( + ITERATION_FINISH, FINISH_PREFETCH, new FinishPrefetchTransitionCondition()); + this.stateMachineManager.addTransition(ITERATION_FINISH, CLEAN_CYCLE); + + // FINISH_PREFETCH -> CLEAN_CYCLE. + this.stateMachineManager.addTransition(FINISH_PREFETCH, CLEAN_CYCLE); + + // CLEAN_CYCLE -> END. + this.stateMachineManager.addTransition(CLEAN_CYCLE, END); + } + + @Override + public IScheduleState transition() { + List states = new ArrayList<>(); + transition(currentState, states); + if (states.isEmpty()) { + return null; + } else { + if (states.size() == 1) { + return states.get(0); + } else { + return ComposeState.of(states); + } } - - @Override - public IScheduleState transition() { - List states = new ArrayList<>(); - transition(currentState, states); - if (states.isEmpty()) { - return null; - } else { - if (states.size() == 1) { - return states.get(0); - } else { - return ComposeState.of(states); - } - } + } + + private void transition(ScheduleState source, List results) { + ScheduleState target = stateMachineManager.transition(source, context); + if (target != null) { + if (END == target) { + currentState = END; + return; + } + + // Not allow two execution state compose. + if (!composable( + results.isEmpty() ? null : (ScheduleState) results.get(results.size() - 1), target)) { + return; + } + + currentState = ScheduleState.of(target.getScheduleStateType()); + results.add(currentState); + if (target.getScheduleStateType() == ScheduleStateType.ITERATION_FINISH + && source.getScheduleStateType() == ScheduleStateType.PREFETCH) { + return; + } + transition(currentState, results); } + } - private void transition(ScheduleState source, List results) { - ScheduleState target = stateMachineManager.transition(source, context); - if (target != null) { - if (END == target) { - currentState = END; - return; - } - - // Not allow two execution state compose. - if (!composable(results.isEmpty() ? null : (ScheduleState) results.get(results.size() - 1), target)) { - return; - } - - currentState = ScheduleState.of(target.getScheduleStateType()); - results.add(currentState); - if (target.getScheduleStateType() == ScheduleStateType.ITERATION_FINISH && source.getScheduleStateType() == ScheduleStateType.PREFETCH) { - return; - } - transition(currentState, results); - } + private boolean composable(ScheduleState previous, ScheduleState current) { + if (previous == null || current == null) { + return true; } - - private boolean composable(ScheduleState previous, ScheduleState current) { - if (previous == null || current == null) { - return true; - } - /*if (context.getCycle() instanceof ExecutionNodeCycle) { - if (((ExecutionNodeCycle) context.getCycle()).getVertexGroup().getVertexMap().size() > 1) { - return false; - } - }*/ - // Not allow two execution state compose. - if ((previous.getScheduleStateType() == ScheduleStateType.ITERATION_INIT || previous.getScheduleStateType() == ScheduleStateType.EXECUTE_COMPUTE) - && current.getScheduleStateType() == ScheduleStateType.EXECUTE_COMPUTE) { + /*if (context.getCycle() instanceof ExecutionNodeCycle) { + if (((ExecutionNodeCycle) context.getCycle()).getVertexGroup().getVertexMap().size() > 1) { return false; } - return true; + }*/ + // Not allow two execution state compose. + if ((previous.getScheduleStateType() == ScheduleStateType.ITERATION_INIT + || previous.getScheduleStateType() == ScheduleStateType.EXECUTE_COMPUTE) + && current.getScheduleStateType() == ScheduleStateType.EXECUTE_COMPUTE) { + return false; } + return true; + } - public static class Start2RollbackTransitionCondition - implements ITransitionCondition { + public static class Start2RollbackTransitionCondition + implements ITransitionCondition { - @Override - public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { - if (context instanceof CheckpointSchedulerContext) { - return context.isRecovered(); - } - return false; - } + @Override + public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { + if (context instanceof CheckpointSchedulerContext) { + return context.isRecovered(); + } + return false; } + } - public static class Start2PrefetchTransitionCondition - implements ITransitionCondition { + public static class Start2PrefetchTransitionCondition + implements ITransitionCondition { - @Override - public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { - return context.isPrefetch() && !((ExecutionNodeCycle) context.getCycle()).isIterative(); - } + @Override + public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { + return context.isPrefetch() && !((ExecutionNodeCycle) context.getCycle()).isIterative(); } + } - public static class Compute2PrefetchTransitionCondition - implements ITransitionCondition { + public static class Compute2PrefetchTransitionCondition + implements ITransitionCondition { - @Override - public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { - // When the iteration is finished and prefetch is enable, we need to prefetch for next iteration. - return context.isCycleFinished() && (context.getCycle().getType() == ExecutionCycleType.ITERATION - || context.getCycle().getType() == ExecutionCycleType.ITERATION_WITH_AGG) && context.isPrefetch(); - } + @Override + public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { + // When the iteration is finished and prefetch is enable, we need to prefetch for next + // iteration. + return context.isCycleFinished() + && (context.getCycle().getType() == ExecutionCycleType.ITERATION + || context.getCycle().getType() == ExecutionCycleType.ITERATION_WITH_AGG) + && context.isPrefetch(); } + } - public static class Init2RollbackTransitionCondition - implements ITransitionCondition { + public static class Init2RollbackTransitionCondition + implements ITransitionCondition { - @Override - public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { - if (context.isRollback()) { - ((AbstractCycleSchedulerContext) context).setRollback(false); - return true; - } - return false; - } + @Override + public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { + if (context.isRollback()) { + ((AbstractCycleSchedulerContext) context).setRollback(false); + return true; + } + return false; } + } - public static class InitIterationTransitionCondition - implements ITransitionCondition { + public static class InitIterationTransitionCondition + implements ITransitionCondition { - @Override - public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { - return ((ExecutionNodeCycle) context.getCycle()).getVertexGroup().getCycleGroupMeta().isIterative(); - } + @Override + public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { + return ((ExecutionNodeCycle) context.getCycle()) + .getVertexGroup() + .getCycleGroupMeta() + .isIterative(); } + } - public static class FinishTransitionCondition - implements ITransitionCondition { + public static class FinishTransitionCondition + implements ITransitionCondition { - @Override - public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { - return context.isCycleFinished() && (context.getCycle().getType() == ExecutionCycleType.ITERATION - || context.getCycle().getType() == ExecutionCycleType.ITERATION_WITH_AGG); - } + @Override + public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { + return context.isCycleFinished() + && (context.getCycle().getType() == ExecutionCycleType.ITERATION + || context.getCycle().getType() == ExecutionCycleType.ITERATION_WITH_AGG); } + } - public static class FinishPrefetchTransitionCondition - implements ITransitionCondition { + public static class FinishPrefetchTransitionCondition + implements ITransitionCondition { - @Override - public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { - if (context.getConfig().getBoolean(ExecutionConfigKeys.SHUFFLE_PREFETCH)) { - Map needFinishedPrefetchEvents = context.getPrefetchEvents(); - return needFinishedPrefetchEvents != null && needFinishedPrefetchEvents.size() > 0; - } - return false; - } + @Override + public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { + if (context.getConfig().getBoolean(ExecutionConfigKeys.SHUFFLE_PREFETCH)) { + Map needFinishedPrefetchEvents = context.getPrefetchEvents(); + return needFinishedPrefetchEvents != null && needFinishedPrefetchEvents.size() > 0; + } + return false; } + } - public static class CleanTransitionCondition - implements ITransitionCondition { + public static class CleanTransitionCondition + implements ITransitionCondition { - @Override - public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { - return context.isCycleFinished() && !(context.getCycle().getType() == ExecutionCycleType.ITERATION - || context.getCycle().getType() == ExecutionCycleType.ITERATION_WITH_AGG); - } + @Override + public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { + return context.isCycleFinished() + && !(context.getCycle().getType() == ExecutionCycleType.ITERATION + || context.getCycle().getType() == ExecutionCycleType.ITERATION_WITH_AGG); } + } - public static class Compute2FinishPrefetchTransitionCondition - implements ITransitionCondition { + public static class Compute2FinishPrefetchTransitionCondition + implements ITransitionCondition { - @Override - public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { - if (context.isCycleFinished() && !(context.getCycle().getType() == ExecutionCycleType.ITERATION - || context.getCycle().getType() == ExecutionCycleType.ITERATION_WITH_AGG) && context.isPrefetch()) { - Map needFinishedPrefetchEvents = context.getPrefetchEvents(); - return needFinishedPrefetchEvents != null && needFinishedPrefetchEvents.size() > 0; - } - return false; - } + @Override + public boolean predicate(ScheduleState state, ICycleSchedulerContext context) { + if (context.isCycleFinished() + && !(context.getCycle().getType() == ExecutionCycleType.ITERATION + || context.getCycle().getType() == ExecutionCycleType.ITERATION_WITH_AGG) + && context.isPrefetch()) { + Map needFinishedPrefetchEvents = context.getPrefetchEvents(); + return needFinishedPrefetchEvents != null && needFinishedPrefetchEvents.size() > 0; + } + return false; } - + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/strategy/IScheduleStrategy.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/strategy/IScheduleStrategy.java index 2f1027f4f..1fefa69f2 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/strategy/IScheduleStrategy.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/strategy/IScheduleStrategy.java @@ -21,23 +21,15 @@ public interface IScheduleStrategy { - /** - * Initialize a schedule strategy. - */ - void init(G graph); + /** Initialize a schedule strategy. */ + void init(G graph); - /** - * Check whether the graph has cycle to schedule. - */ - boolean hasNext(); + /** Check whether the graph has cycle to schedule. */ + boolean hasNext(); - /** - * Offer the next cycle that can be scheduled. - */ - V next(); + /** Offer the next cycle that can be scheduled. */ + V next(); - /** - * Do something if cycle is finished. - */ - void finish(V cycle); + /** Do something if cycle is finished. */ + void finish(V cycle); } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/strategy/TopologicalOrderScheduleStrategy.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/strategy/TopologicalOrderScheduleStrategy.java index 30369a21d..76cc678c4 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/strategy/TopologicalOrderScheduleStrategy.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/scheduler/strategy/TopologicalOrderScheduleStrategy.java @@ -26,6 +26,7 @@ import java.util.Set; import java.util.concurrent.LinkedBlockingDeque; import java.util.stream.Collectors; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.runtime.core.scheduler.cycle.ExecutionGraphCycle; @@ -33,111 +34,108 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class TopologicalOrderScheduleStrategy implements IScheduleStrategy { +public class TopologicalOrderScheduleStrategy + implements IScheduleStrategy { - private static final Logger LOGGER = LoggerFactory.getLogger(TopologicalOrderScheduleStrategy.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(TopologicalOrderScheduleStrategy.class); - private ExecutionGraphCycle graph; - private LinkedBlockingDeque waiting; - private LinkedBlockingDeque running; + private ExecutionGraphCycle graph; + private LinkedBlockingDeque waiting; + private LinkedBlockingDeque running; - private Configuration config; + private Configuration config; - // All stage ids that already finished. - private Set finishedIds; + // All stage ids that already finished. + private Set finishedIds; - public TopologicalOrderScheduleStrategy(Configuration config) { - this.config = config; - } + public TopologicalOrderScheduleStrategy(Configuration config) { + this.config = config; + } - @Override - public void init(ExecutionGraphCycle graph) { - this.graph = graph; - this.waiting = new LinkedBlockingDeque<>(); - this.running = new LinkedBlockingDeque<>(); - this.finishedIds = new HashSet<>(); + @Override + public void init(ExecutionGraphCycle graph) { + this.graph = graph; + this.waiting = new LinkedBlockingDeque<>(); + this.running = new LinkedBlockingDeque<>(); + this.finishedIds = new HashSet<>(); - // Find head vertex. - List heads = graph.getCycleParents().entrySet().stream() + // Find head vertex. + List heads = + graph.getCycleParents().entrySet().stream() .filter(e -> e.getValue().isEmpty()) .map(e -> graph.getCycleMap().get(e.getKey())) .sorted(Comparator.comparingInt(IExecutionCycle::getCycleId)) .collect(Collectors.toList()); - // Add head to waiting list. - waiting.addAll(heads); + // Add head to waiting list. + waiting.addAll(heads); + } + + @Override + public boolean hasNext() { + return !waiting.isEmpty(); + } + + @Override + public IExecutionCycle next() { + IExecutionCycle cycle = null; + try { + cycle = waiting.takeFirst(); + } catch (InterruptedException e) { + throw new GeaflowRuntimeException("interrupted when waiting the cycle ready to schedule", e); } - - @Override - public boolean hasNext() { - return !waiting.isEmpty(); + running.addLast(cycle); + return cycle; + } + + @Override + public synchronized void finish(IExecutionCycle cycle) { + finishedIds.add(cycle.getCycleId()); + // Recursively check and pop the head element of running stage queue if it finished. + // To make sure that a stage the earlier added into running queue, the earlier removed. + while (!running.isEmpty() && finishedIds.contains(running.peek().getCycleId())) { + + IExecutionCycle triggerCycle = running.remove(); + triggerChildren(triggerCycle); } + } - @Override - public IExecutionCycle next() { - IExecutionCycle cycle = null; - try { - cycle = waiting.takeFirst(); - } catch (InterruptedException e) { - throw new GeaflowRuntimeException("interrupted when waiting the cycle ready to schedule", e); - } - running.addLast(cycle); - return cycle; - } + /** Add the children to waiting list if necessary. */ + private void triggerChildren(IExecutionCycle cycle) { - @Override - public synchronized void finish(IExecutionCycle cycle) { - finishedIds.add(cycle.getCycleId()); - // Recursively check and pop the head element of running stage queue if it finished. - // To make sure that a stage the earlier added into running queue, the earlier removed. - while (!running.isEmpty() - && finishedIds.contains(running.peek().getCycleId())) { + List readyToStartGroups = new ArrayList<>(); - IExecutionCycle triggerCycle = running.remove(); - triggerChildren(triggerCycle); + for (int childId : graph.getCycleChildren().get(cycle.getCycleId())) { + IExecutionCycle child = graph.getCycleMap().get(childId); + boolean childParentAllDone = true; + for (Integer childParentGroupId : graph.getCycleParents().get(childId)) { + if (!finishedIds.contains(childParentGroupId)) { + childParentAllDone = false; + break; } + } + if (childParentAllDone) { + readyToStartGroups.add(child); + } } - - /** - * Add the children to waiting list if necessary. - */ - private void triggerChildren(IExecutionCycle cycle) { - - List readyToStartGroups = new ArrayList<>(); - - for (int childId : graph.getCycleChildren().get(cycle.getCycleId())) { - IExecutionCycle child = graph.getCycleMap().get(childId); - boolean childParentAllDone = true; - for (Integer childParentGroupId : graph.getCycleParents().get(childId)) { - if (!finishedIds.contains(childParentGroupId)) { - childParentAllDone = false; - break; - } - } - if (childParentAllDone) { - readyToStartGroups.add(child); - } - } - if (!readyToStartGroups.isEmpty()) { - LOGGER.info("current waiting stages {}, new add stages {}", - waiting.stream().map(e -> e.getCycleId()).collect(Collectors.toList()), - readyToStartGroups.stream().map(e -> e.getCycleId()).collect(Collectors.toList())); - for (IExecutionCycle group : readyToStartGroups) { - addToWaiting(group); - } - } + if (!readyToStartGroups.isEmpty()) { + LOGGER.info( + "current waiting stages {}, new add stages {}", + waiting.stream().map(e -> e.getCycleId()).collect(Collectors.toList()), + readyToStartGroups.stream().map(e -> e.getCycleId()).collect(Collectors.toList())); + for (IExecutionCycle group : readyToStartGroups) { + addToWaiting(group); + } } - - /** - * Add stage into waiting list. - */ - private synchronized void addToWaiting(IExecutionCycle cycle) { - // Avoid add a certain stage into waiting list multi-times. - if (!waiting.stream().anyMatch(e -> e.getCycleId() == cycle.getCycleId())) { - waiting.add(cycle); - } else { - LOGGER.info("cycle {} already added to waiting queue", cycle.getCycleId()); - } + } + + /** Add stage into waiting list. */ + private synchronized void addToWaiting(IExecutionCycle cycle) { + // Avoid add a certain stage into waiting list multi-times. + if (!waiting.stream().anyMatch(e -> e.getCycleId() == cycle.getCycleId())) { + waiting.add(cycle); + } else { + LOGGER.info("cycle {} already added to waiting queue", cycle.getCycleId()); } - + } } - diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/AbstractAlignedWorker.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/AbstractAlignedWorker.java index a30208120..19c0006a0 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/AbstractAlignedWorker.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/AbstractAlignedWorker.java @@ -23,6 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.api.trait.CancellableTrait; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.shuffle.message.PipelineMessage; @@ -31,83 +32,83 @@ public abstract class AbstractAlignedWorker extends AbstractComputeWorker { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractAlignedWorker.class); + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractAlignedWorker.class); - protected Map>> windowMessageCache; + protected Map>> windowMessageCache; - public AbstractAlignedWorker() { - super(); - this.windowMessageCache = new HashMap<>(); - } + public AbstractAlignedWorker() { + super(); + this.windowMessageCache = new HashMap<>(); + } - /** - * Trigger worker to process message, need cache message in aligned worker. - */ - @Override - protected void processMessage(long windowId, PipelineMessage message) { - if (windowId > context.getCurrentWindowId()) { - if (windowMessageCache.containsKey(windowId)) { - windowMessageCache.get(windowId).add(message); - } else { - List> cache = new ArrayList<>(); - cache.add(message); - windowMessageCache.put(windowId, cache); - } - } else { - processMessageEvent(windowId, message); - } + /** Trigger worker to process message, need cache message in aligned worker. */ + @Override + protected void processMessage(long windowId, PipelineMessage message) { + if (windowId > context.getCurrentWindowId()) { + if (windowMessageCache.containsKey(windowId)) { + windowMessageCache.get(windowId).add(message); + } else { + List> cache = new ArrayList<>(); + cache.add(message); + windowMessageCache.put(windowId, cache); + } + } else { + processMessageEvent(windowId, message); } + } - /** - * Trigger worker to process buffered message. - */ - @Override - protected void processBarrier(long windowId, long totalCount) { - processBufferedMessages(windowId); - - long processCount = 0; - if (windowCount.containsKey(windowId)) { - processCount = windowCount.remove(windowId); - } + /** Trigger worker to process buffered message. */ + @Override + protected void processBarrier(long windowId, long totalCount) { + processBufferedMessages(windowId); - if (totalCount != processCount) { - LOGGER.error("taskId {} {} mismatch, TotalCount:{} != ProcessCount:{}", - context.getTaskId(), totalCount, totalCount, processCount); - throw new GeaflowRuntimeException(String.format("taskId %s mismatch, TotalCount:%s != ProcessCount:%s", - context.getTaskId(), totalCount, processCount)); - } - context.getEventMetrics().addShuffleReadRecords(totalCount); - - long currentWindowId = context.getCurrentWindowId(); - finish(currentWindowId); - updateWindowId(currentWindowId + 1); + long processCount = 0; + if (windowCount.containsKey(windowId)) { + processCount = windowCount.remove(windowId); } - /** - * Process buffered messages. - */ - private void processBufferedMessages(long windowId) { - if (windowMessageCache.containsKey(windowId)) { - List> cacheMessages = windowMessageCache.get(windowId); - for (PipelineMessage message : cacheMessages) { - processMessageEvent(windowId, message); - } - windowMessageCache.remove(windowId); - } + if (totalCount != processCount) { + LOGGER.error( + "taskId {} {} mismatch, TotalCount:{} != ProcessCount:{}", + context.getTaskId(), + totalCount, + totalCount, + processCount); + throw new GeaflowRuntimeException( + String.format( + "taskId %s mismatch, TotalCount:%s != ProcessCount:%s", + context.getTaskId(), totalCount, processCount)); } + context.getEventMetrics().addShuffleReadRecords(totalCount); + + long currentWindowId = context.getCurrentWindowId(); + finish(currentWindowId); + updateWindowId(currentWindowId + 1); + } - @Override - public void interrupt() { - this.running = false; - if (context.getProcessor() instanceof CancellableTrait) { - ((CancellableTrait) context.getProcessor()).cancel(); - } + /** Process buffered messages. */ + private void processBufferedMessages(long windowId) { + if (windowMessageCache.containsKey(windowId)) { + List> cacheMessages = windowMessageCache.get(windowId); + for (PipelineMessage message : cacheMessages) { + processMessageEvent(windowId, message); + } + windowMessageCache.remove(windowId); } + } - @Override - public void close() { - super.close(); - windowCount.clear(); - windowMessageCache.clear(); + @Override + public void interrupt() { + this.running = false; + if (context.getProcessor() instanceof CancellableTrait) { + ((CancellableTrait) context.getProcessor()).cancel(); } + } + + @Override + public void close() { + super.close(); + windowCount.clear(); + windowMessageCache.clear(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/AbstractComputeWorker.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/AbstractComputeWorker.java index de0141c97..dbf5208cb 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/AbstractComputeWorker.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/AbstractComputeWorker.java @@ -19,7 +19,6 @@ package org.apache.geaflow.runtime.core.worker; -import com.google.common.base.Preconditions; import org.apache.geaflow.api.trait.TransactionTrait; import org.apache.geaflow.cluster.worker.IAffinityWorker; import org.apache.geaflow.cluster.worker.IWorkerContext; @@ -32,93 +31,96 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class AbstractComputeWorker extends AbstractWorker implements TransactionTrait, IAffinityWorker { - - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractComputeWorker.class); - - private boolean isTransactionProcessor; - - public AbstractComputeWorker() { - super(); - } - - @Override - public void open(IWorkerContext workerContext) { - super.open(workerContext); - LOGGER.info("open processor"); - context.getProcessor().open( - context.getCollectors(), - context.getRuntimeContext() - ); - this.isTransactionProcessor = context.getProcessor() instanceof TransactionTrait; - } - - @Override - public void init(long windowId) { - LOGGER.info("taskId {} init windowId {}", context.getTaskId(), windowId); - updateWindowId(windowId); - context.getProcessor().init(windowId); - } - - @Override - public R process(BatchRecord batchRecord) { - return (R) context.getProcessor().process(batchRecord); - } - - @Override - public void finish(long windowId) { - LOGGER.info("taskId {} finishes windowId {}, currentBatchId {}", - context.getTaskId(), windowId, context.getCurrentWindowId()); - context.getProcessor().finish(windowId); - finishWindow(context.getCurrentWindowId()); - } - - @Override - public void rollback(long windowId) { - LOGGER.info("taskId {} rollback windowId {}", context.getTaskId(), windowId); - if (isTransactionProcessor) { - ((TransactionTrait) context.getProcessor()).rollback(windowId); - } - updateWindowId(windowId + 1); - } +import com.google.common.base.Preconditions; - @Override - public void stash() { - // Stash current worker context. - WorkerContextManager.register(context.getTaskId(), (WorkerContext) context); - context = null; +public abstract class AbstractComputeWorker extends AbstractWorker + implements TransactionTrait, IAffinityWorker { + + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractComputeWorker.class); + + private boolean isTransactionProcessor; + + public AbstractComputeWorker() { + super(); + } + + @Override + public void open(IWorkerContext workerContext) { + super.open(workerContext); + LOGGER.info("open processor"); + context.getProcessor().open(context.getCollectors(), context.getRuntimeContext()); + this.isTransactionProcessor = context.getProcessor() instanceof TransactionTrait; + } + + @Override + public void init(long windowId) { + LOGGER.info("taskId {} init windowId {}", context.getTaskId(), windowId); + updateWindowId(windowId); + context.getProcessor().init(windowId); + } + + @Override + public R process(BatchRecord batchRecord) { + return (R) context.getProcessor().process(batchRecord); + } + + @Override + public void finish(long windowId) { + LOGGER.info( + "taskId {} finishes windowId {}, currentBatchId {}", + context.getTaskId(), + windowId, + context.getCurrentWindowId()); + context.getProcessor().finish(windowId); + finishWindow(context.getCurrentWindowId()); + } + + @Override + public void rollback(long windowId) { + LOGGER.info("taskId {} rollback windowId {}", context.getTaskId(), windowId); + if (isTransactionProcessor) { + ((TransactionTrait) context.getProcessor()).rollback(windowId); } - - @Override - public void pop(IWorkerContext workerContext) { - AbstractWorkerContext popWorkerContext = (AbstractWorkerContext) workerContext; - context = (AbstractWorkerContext) WorkerContextManager.get(popWorkerContext.getTaskId()); - Preconditions.checkArgument(context != null, "not found any context"); - - final long pipelineId = popWorkerContext.getPipelineId(); - final String pipelineName = popWorkerContext.getPipelineName(); - final int cycleId = popWorkerContext.getCycleId(); - final long windowId = popWorkerContext.getWindowId(); - final long schedulerId = popWorkerContext.getSchedulerId(); - - context.setPipelineId(pipelineId); - context.setPipelineName(pipelineName); - context.setWindowId(windowId); - context.setSchedulerId(schedulerId); - context.getExecutionTask().buildTaskName(pipelineName, cycleId, windowId); - context.initEventMetrics(); - - // Update runtime context. - DefaultRuntimeContext runtimeContext = (DefaultRuntimeContext) context.getRuntimeContext(); - runtimeContext.setPipelineId(pipelineId); - runtimeContext.setPipelineName(pipelineName); - runtimeContext.setWindowId(windowId); - runtimeContext.setTaskArgs(context.getExecutionTask().buildTaskArgs()); - - // Update collectors. - for (ICollector collector : context.getCollectors()) { - LOGGER.info("setup collector {}", runtimeContext.getTaskArgs()); - collector.setUp(runtimeContext); - } + updateWindowId(windowId + 1); + } + + @Override + public void stash() { + // Stash current worker context. + WorkerContextManager.register(context.getTaskId(), (WorkerContext) context); + context = null; + } + + @Override + public void pop(IWorkerContext workerContext) { + AbstractWorkerContext popWorkerContext = (AbstractWorkerContext) workerContext; + context = (AbstractWorkerContext) WorkerContextManager.get(popWorkerContext.getTaskId()); + Preconditions.checkArgument(context != null, "not found any context"); + + final long pipelineId = popWorkerContext.getPipelineId(); + final String pipelineName = popWorkerContext.getPipelineName(); + final int cycleId = popWorkerContext.getCycleId(); + final long windowId = popWorkerContext.getWindowId(); + final long schedulerId = popWorkerContext.getSchedulerId(); + + context.setPipelineId(pipelineId); + context.setPipelineName(pipelineName); + context.setWindowId(windowId); + context.setSchedulerId(schedulerId); + context.getExecutionTask().buildTaskName(pipelineName, cycleId, windowId); + context.initEventMetrics(); + + // Update runtime context. + DefaultRuntimeContext runtimeContext = (DefaultRuntimeContext) context.getRuntimeContext(); + runtimeContext.setPipelineId(pipelineId); + runtimeContext.setPipelineName(pipelineName); + runtimeContext.setWindowId(windowId); + runtimeContext.setTaskArgs(context.getExecutionTask().buildTaskArgs()); + + // Update collectors. + for (ICollector collector : context.getCollectors()) { + LOGGER.info("setup collector {}", runtimeContext.getTaskArgs()); + collector.setUp(runtimeContext); } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/AbstractUnAlignedWorker.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/AbstractUnAlignedWorker.java index af1ffdd88..755c2eee7 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/AbstractUnAlignedWorker.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/AbstractUnAlignedWorker.java @@ -23,6 +23,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.cluster.exception.ComponentUncaughtExceptionHandler; import org.apache.geaflow.cluster.protocol.InputMessage; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -34,155 +35,167 @@ public abstract class AbstractUnAlignedWorker extends AbstractComputeWorker { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractUnAlignedWorker.class); - - private static final String WORKER_FORMAT = "geaflow-asp-worker-"; - private static final int DEFAULT_TIMEOUT_MS = 100; - - protected ExecutorService executorService; - protected BlockingQueue processingWindowIdQueue; - - public AbstractUnAlignedWorker() { - super(); - this.processingWindowIdQueue = new LinkedBlockingDeque<>(); + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractUnAlignedWorker.class); + + private static final String WORKER_FORMAT = "geaflow-asp-worker-"; + private static final int DEFAULT_TIMEOUT_MS = 100; + + protected ExecutorService executorService; + protected BlockingQueue processingWindowIdQueue; + + public AbstractUnAlignedWorker() { + super(); + this.processingWindowIdQueue = new LinkedBlockingDeque<>(); + } + + @Override + public void process(long fetchCount, boolean isAligned) { + // TODO Currently LoadGraphProcessEvent need align processing. + if (isAligned) { + alignedProcess(fetchCount); + } else { + if (executorService == null) { + LOGGER.info("taskId {} unaligned worker has been shutdown, start...", context.getTaskId()); + startTask(); + } } - - @Override - public void process(long fetchCount, boolean isAligned) { - // TODO Currently LoadGraphProcessEvent need align processing. - if (isAligned) { - alignedProcess(fetchCount); + } + + private void startTask() { + long start = System.currentTimeMillis(); + this.executorService = + Executors.getExecutorService( + 1, + WORKER_FORMAT + context.getTaskId() + "-%d", + ComponentUncaughtExceptionHandler.INSTANCE); + executorService.execute(new WorkerTask()); + LOGGER.info( + "taskId {} start task cost {}ms", + context != null ? context.getTaskId() : "null", + System.currentTimeMillis() - start); + } + + public void alignedProcess(long fetchCount) { + super.process(fetchCount, true); + } + + public void unalignedProcess() { + try { + InputMessage input = inputReader.poll(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS); + if (input != null) { + long windowId = input.getWindowId(); + if (input.getMessage() != null) { + PipelineMessage message = input.getMessage(); + processMessage(windowId, message); } else { - if (executorService == null) { - LOGGER.info("taskId {} unaligned worker has been shutdown, start...", context.getTaskId()); - startTask(); - } + long totalCount = input.getWindowCount(); + processBarrier(windowId, totalCount); } + } + } catch (Throwable t) { + if (running) { + LOGGER.error(t.getMessage(), t); + throw new GeaflowRuntimeException(t); + } else { + LOGGER.warn("service closed {}", t.getMessage()); + } } - private void startTask() { - long start = System.currentTimeMillis(); - this.executorService = Executors.getExecutorService(1, WORKER_FORMAT + context.getTaskId() + "-%d", - ComponentUncaughtExceptionHandler.INSTANCE); - executorService.execute(new WorkerTask()); - LOGGER.info("taskId {} start task cost {}ms", context != null ? context.getTaskId() : "null", System.currentTimeMillis() - start); + if (!running) { + LOGGER.info("{} worker terminated", context == null ? "null" : context.getTaskId()); } - - public void alignedProcess(long fetchCount) { - super.process(fetchCount, true); + } + + /** Process message event and trigger worker to process. */ + @Override + protected void processMessage(long windowId, PipelineMessage message) { + processMessageEvent(windowId, message); + } + + /** Trigger worker to call processor finish. */ + @Override + protected void processBarrier(long windowId, long totalCount) { + long processedCount = windowCount.containsKey(windowId) ? windowCount.get(windowId) : 0; + finishBarrier(totalCount, processedCount); + LOGGER.info( + "taskId {} windowId {} process total {} messages", + context.getTaskId(), + windowId, + processedCount); + windowCount.remove(windowId); + } + + /** + * Verify the processed count and total count, and whether the window id currently processed is + * consistent with the window id in the context. + */ + protected void finishBarrier(long totalCount, long processedCount) { + if (totalCount != processedCount) { + LOGGER.error( + "taskId {} {} mismatch, TotalCount:{} != ProcessCount:{}", + context.getTaskId(), + totalCount, + totalCount, + processedCount); } + context.getEventMetrics().addShuffleReadRecords(totalCount); - public void unalignedProcess() { - try { - InputMessage input = inputReader.poll(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS); - if (input != null) { - long windowId = input.getWindowId(); - if (input.getMessage() != null) { - PipelineMessage message = input.getMessage(); - processMessage(windowId, message); - } else { - long totalCount = input.getWindowCount(); - processBarrier(windowId, totalCount); - } - } - } catch (Throwable t) { - if (running) { - LOGGER.error(t.getMessage(), t); - throw new GeaflowRuntimeException(t); - } else { - LOGGER.warn("service closed {}", t.getMessage()); - } - } - - if (!running) { - LOGGER.info("{} worker terminated", - context == null ? "null" : context.getTaskId()); - } + long currentWindowId; + try { + currentWindowId = processingWindowIdQueue.poll(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + throw new GeaflowRuntimeException(e); } - - /** - * Process message event and trigger worker to process. - */ - @Override - protected void processMessage(long windowId, PipelineMessage message) { - processMessageEvent(windowId, message); + finish(currentWindowId); + + // Current window id must be in [context.getCurrentWindowId() - 1, + // context.getCurrentWindowId()]. + if (currentWindowId != context.getCurrentWindowId() + && currentWindowId != context.getCurrentWindowId() - 1) { + String errorMessage = + String.format( + "currentWindowId is %d from queue, id is %d from context", + currentWindowId, context.getCurrentWindowId()); + LOGGER.error(errorMessage); + throw new GeaflowRuntimeException(errorMessage); } - - /** - * Trigger worker to call processor finish. - */ - @Override - protected void processBarrier(long windowId, long totalCount) { - long processedCount = windowCount.containsKey(windowId) ? windowCount.get(windowId) : 0; - finishBarrier(totalCount, processedCount); - LOGGER.info("taskId {} windowId {} process total {} messages", context.getTaskId(), windowId, processedCount); - windowCount.remove(windowId); + super.init(currentWindowId + 1); + } + + @Override + public void interrupt() { + super.interrupt(); + if (executorService != null) { + ExecutorUtil.shutdown(executorService); + executorService = null; } - - /** - * Verify the processed count and total count, and whether the window id - * currently processed is consistent with the window id in the context. - */ - protected void finishBarrier(long totalCount, long processedCount) { - if (totalCount != processedCount) { - LOGGER.error("taskId {} {} mismatch, TotalCount:{} != ProcessCount:{}", - context.getTaskId(), totalCount, totalCount, processedCount); - } - context.getEventMetrics().addShuffleReadRecords(totalCount); - - long currentWindowId; - try { - currentWindowId = processingWindowIdQueue.poll(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS); - } catch (InterruptedException e) { - throw new GeaflowRuntimeException(e); - } - finish(currentWindowId); - - // Current window id must be in [context.getCurrentWindowId() - 1, context.getCurrentWindowId()]. - if (currentWindowId != context.getCurrentWindowId() && currentWindowId != context.getCurrentWindowId() - 1) { - String errorMessage = String.format("currentWindowId is %d from queue, id is %d from context", - currentWindowId, context.getCurrentWindowId()); - LOGGER.error(errorMessage); - throw new GeaflowRuntimeException(errorMessage); - } - super.init(currentWindowId + 1); + } + + @Override + public void close() { + super.close(); + this.running = false; + this.processingWindowIdQueue.clear(); + this.windowCount.clear(); + if (executorService != null) { + LOGGER.info("shutdown unaligned worker"); + ExecutorUtil.shutdown(executorService); + executorService = null; } + } - @Override - public void interrupt() { - super.interrupt(); - if (executorService != null) { - ExecutorUtil.shutdown(executorService); - executorService = null; - } - } + public class WorkerTask implements Runnable { @Override - public void close() { - super.close(); - this.running = false; - this.processingWindowIdQueue.clear(); - this.windowCount.clear(); - if (executorService != null) { - LOGGER.info("shutdown unaligned worker"); - ExecutorUtil.shutdown(executorService); - executorService = null; + public void run() { + try { + while (running) { + unalignedProcess(); } + } catch (Exception e) { + LOGGER.error("Unaligned process encounter exception ", e); + throw new GeaflowRuntimeException(e); + } } - - public class WorkerTask implements Runnable { - - @Override - public void run() { - try { - while (running) { - unalignedProcess(); - } - } catch (Exception e) { - LOGGER.error("Unaligned process encounter exception ", e); - throw new GeaflowRuntimeException(e); - } - } - } - + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/AbstractWorker.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/AbstractWorker.java index 018c1d002..e4d6e6a73 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/AbstractWorker.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/AbstractWorker.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.api.trait.CancellableTrait; import org.apache.geaflow.cluster.collector.AbstractPipelineOutputCollector; import org.apache.geaflow.cluster.protocol.EventType; @@ -47,141 +48,144 @@ public abstract class AbstractWorker implements IWorker, O> { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractWorker.class); - - private static final int DEFAULT_TIMEOUT_MS = 100; - - protected AbstractWorkerContext context; - protected InputReader inputReader; - protected Map windowCount; - protected volatile boolean running; - - public AbstractWorker() { - this.inputReader = new InputReader<>(); - this.windowCount = new HashMap<>(); - this.running = false; - } - - @Override - public void open(IWorkerContext workerContext) { - this.context = (AbstractWorkerContext) workerContext; - this.running = true; - } - - @Override - public IWorkerContext getWorkerContext() { - return context; - } - - public InputReader getInputReader() { - return inputReader; - } - - /** - * Fetch message from input queue and trigger aligned compute all the time, - * and finish until total batch count has fetched. - */ - public void process(long totalWindowCount, boolean isAligned) { - long processedWindowCount = 0; - long fetchCost = 0; - while (processedWindowCount < totalWindowCount && running) { - try { - long fetchStart = System.currentTimeMillis(); - InputMessage input = this.inputReader.poll(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS); - fetchCost += System.currentTimeMillis() - fetchStart; - if (input != null) { - long windowId = input.getWindowId(); - if (input.getMessage() != null) { - PipelineMessage message = input.getMessage(); - processMessage(windowId, message); - } else { - this.context.getEventMetrics().addShuffleReadCostMs(fetchCost); - fetchCost = 0; - long totalCount = input.getWindowCount(); - processBarrier(windowId, totalCount); - processedWindowCount++; - } - } - } catch (Throwable t) { - LOGGER.error(t.getMessage(), t); - throw new GeaflowRuntimeException(t); - } - } - if (!running) { - LOGGER.info("{} worker terminated", context.getTaskId()); + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractWorker.class); + + private static final int DEFAULT_TIMEOUT_MS = 100; + + protected AbstractWorkerContext context; + protected InputReader inputReader; + protected Map windowCount; + protected volatile boolean running; + + public AbstractWorker() { + this.inputReader = new InputReader<>(); + this.windowCount = new HashMap<>(); + this.running = false; + } + + @Override + public void open(IWorkerContext workerContext) { + this.context = (AbstractWorkerContext) workerContext; + this.running = true; + } + + @Override + public IWorkerContext getWorkerContext() { + return context; + } + + public InputReader getInputReader() { + return inputReader; + } + + /** + * Fetch message from input queue and trigger aligned compute all the time, and finish until total + * batch count has fetched. + */ + public void process(long totalWindowCount, boolean isAligned) { + long processedWindowCount = 0; + long fetchCost = 0; + while (processedWindowCount < totalWindowCount && running) { + try { + long fetchStart = System.currentTimeMillis(); + InputMessage input = this.inputReader.poll(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS); + fetchCost += System.currentTimeMillis() - fetchStart; + if (input != null) { + long windowId = input.getWindowId(); + if (input.getMessage() != null) { + PipelineMessage message = input.getMessage(); + processMessage(windowId, message); + } else { + this.context.getEventMetrics().addShuffleReadCostMs(fetchCost); + fetchCost = 0; + long totalCount = input.getWindowCount(); + processBarrier(windowId, totalCount); + processedWindowCount++; + } } + } catch (Throwable t) { + LOGGER.error(t.getMessage(), t); + throw new GeaflowRuntimeException(t); + } } - - /** - * Process message event and trigger worker to process. - */ - protected void processMessageEvent(long windowId, PipelineMessage message) { - IMessageIterator messageIterator = message.getMessageIterator(); - process(new BatchRecord<>(message.getRecordArgs(), messageIterator)); - - long count = messageIterator.getSize(); - messageIterator.close(); - - // Aggregate message not take into account when check message count. - if (!message.getRecordArgs().getName().equals(RecordArgs.GraphRecordNames.Aggregate.name())) { - if (!windowCount.containsKey(windowId)) { - windowCount.put(windowId, count); - } else { - long oldCounter = windowCount.get(windowId); - windowCount.put(windowId, oldCounter + count); - } - } + if (!running) { + LOGGER.info("{} worker terminated", context.getTaskId()); } - - /** - * Tell scheduler finish and send back response to scheduler. - */ - protected void finishWindow(long windowId) { - Map> results = new HashMap<>(); - List> collectors = this.context.getCollectors(); - if (ExecutionTaskUtils.isCycleTail(context.getExecutionTask())) { - for (int i = 0; i < collectors.size(); i++) { - IResultCollector responseCollector = (IResultCollector) collectors.get(i); - IResult result = (IResult) responseCollector.collectResult(); - if (result != null) { - results.put(result.getId(), result); - } - } - - // Tell scheduler finish or response. - EventMetrics eventMetrics = this.context.isIterativeTask() ? this.context.getEventMetrics() : null; - DoneEvent done = new DoneEvent<>(context.getSchedulerId(), context.getCycleId(), windowId, - context.getTaskId(), EventType.EXECUTE_COMPUTE, results, eventMetrics); - RpcClient.getInstance().processPipeline(context.getDriverId(), done); - } + } + + /** Process message event and trigger worker to process. */ + protected void processMessageEvent(long windowId, PipelineMessage message) { + IMessageIterator messageIterator = message.getMessageIterator(); + process(new BatchRecord<>(message.getRecordArgs(), messageIterator)); + + long count = messageIterator.getSize(); + messageIterator.close(); + + // Aggregate message not take into account when check message count. + if (!message.getRecordArgs().getName().equals(RecordArgs.GraphRecordNames.Aggregate.name())) { + if (!windowCount.containsKey(windowId)) { + windowCount.put(windowId, count); + } else { + long oldCounter = windowCount.get(windowId); + windowCount.put(windowId, oldCounter + count); + } } - - - protected void updateWindowId(long windowId) { - context.setCurrentWindowId(windowId); - for (ICollector collector : this.context.getCollectors()) { - if (collector instanceof AbstractPipelineOutputCollector) { - ((AbstractPipelineOutputCollector) collector).setWindowId(windowId); - } + } + + /** Tell scheduler finish and send back response to scheduler. */ + protected void finishWindow(long windowId) { + Map> results = new HashMap<>(); + List> collectors = this.context.getCollectors(); + if (ExecutionTaskUtils.isCycleTail(context.getExecutionTask())) { + for (int i = 0; i < collectors.size(); i++) { + IResultCollector responseCollector = (IResultCollector) collectors.get(i); + IResult result = (IResult) responseCollector.collectResult(); + if (result != null) { + results.put(result.getId(), result); } + } + + // Tell scheduler finish or response. + EventMetrics eventMetrics = + this.context.isIterativeTask() ? this.context.getEventMetrics() : null; + DoneEvent done = + new DoneEvent<>( + context.getSchedulerId(), + context.getCycleId(), + windowId, + context.getTaskId(), + EventType.EXECUTE_COMPUTE, + results, + eventMetrics); + RpcClient.getInstance().processPipeline(context.getDriverId(), done); + } + } + + protected void updateWindowId(long windowId) { + context.setCurrentWindowId(windowId); + for (ICollector collector : this.context.getCollectors()) { + if (collector instanceof AbstractPipelineOutputCollector) { + ((AbstractPipelineOutputCollector) collector).setWindowId(windowId); + } } + } - protected abstract void processMessage(long windowId, PipelineMessage message); + protected abstract void processMessage(long windowId, PipelineMessage message); - protected abstract void processBarrier(long windowId, long totalCount); + protected abstract void processBarrier(long windowId, long totalCount); - @Override - public void interrupt() { - this.running = false; - if (context.getProcessor() instanceof CancellableTrait) { - ((CancellableTrait) context.getProcessor()).cancel(); - } + @Override + public void interrupt() { + this.running = false; + if (context.getProcessor() instanceof CancellableTrait) { + ((CancellableTrait) context.getProcessor()).cancel(); } + } - @Override - public void close() { - if (context != null) { - context.close(); - } + @Override + public void close() { + if (context != null) { + context.close(); } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/InputReader.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/InputReader.java index c01209c30..028e3e361 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/InputReader.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/InputReader.java @@ -26,20 +26,20 @@ import org.apache.geaflow.shuffle.message.PipelineMessage; import org.apache.geaflow.shuffle.service.ShuffleManager; -public class InputReader extends AbstractMessageBuffer> implements IInputMessageBuffer { +public class InputReader extends AbstractMessageBuffer> + implements IInputMessageBuffer { - public InputReader() { - super(ShuffleManager.getInstance().getShuffleConfig().getFetchQueueSize()); - } + public InputReader() { + super(ShuffleManager.getInstance().getShuffleConfig().getFetchQueueSize()); + } - @Override - public void onMessage(PipelineMessage message) { - this.offer(new InputMessage<>(message)); - } - - @Override - public void onBarrier(PipelineBarrier barrier) { - this.offer(new InputMessage<>(barrier.getWindowId(), barrier.getCount())); - } + @Override + public void onMessage(PipelineMessage message) { + this.offer(new InputMessage<>(message)); + } + @Override + public void onBarrier(PipelineBarrier barrier) { + this.offer(new InputMessage<>(barrier.getWindowId(), barrier.getCount())); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/OutputWriter.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/OutputWriter.java index db1a43be1..2826b0d7c 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/OutputWriter.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/OutputWriter.java @@ -23,6 +23,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicReference; + import org.apache.geaflow.cluster.collector.IOutputMessageBuffer; import org.apache.geaflow.cluster.protocol.OutputMessage; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -33,82 +34,81 @@ public class OutputWriter extends AbstractMessageBuffer> implements IOutputMessageBuffer { - private CompletableFuture resultFuture = new CompletableFuture<>(); - - private final int edgeId; - private final int bufferSize; - private final ArrayList[] buffers; - private final AtomicReference err; + private CompletableFuture resultFuture = new CompletableFuture<>(); - public OutputWriter(int edgeId, int bucketNum) { - super(ShuffleManager.getInstance().getShuffleConfig().getEmitQueueSize()); - this.edgeId = edgeId; - this.bufferSize = ShuffleManager.getInstance().getShuffleConfig().getEmitBufferSize(); - this.buffers = new ArrayList[bucketNum]; - this.err = new AtomicReference<>(); - for (int i = 0; i < bucketNum; i++) { - this.buffers[i] = new ArrayList<>(this.bufferSize); - } - } + private final int edgeId; + private final int bufferSize; + private final ArrayList[] buffers; + private final AtomicReference err; - @Override - public void emit(long windowId, T data, boolean isRetract, int[] targetChannels) { - for (int channel : targetChannels) { - ArrayList buffer = this.buffers[channel]; - buffer.add(data); - if (buffer.size() == this.bufferSize) { - this.checkErr(); - long start = System.currentTimeMillis(); - this.offer(OutputMessage.data(windowId, channel, buffer)); - this.eventMetrics.addShuffleWriteCostMs(System.currentTimeMillis() - start); - this.buffers[channel] = new ArrayList<>(this.bufferSize); - } - } + public OutputWriter(int edgeId, int bucketNum) { + super(ShuffleManager.getInstance().getShuffleConfig().getEmitQueueSize()); + this.edgeId = edgeId; + this.bufferSize = ShuffleManager.getInstance().getShuffleConfig().getEmitBufferSize(); + this.buffers = new ArrayList[bucketNum]; + this.err = new AtomicReference<>(); + for (int i = 0; i < bucketNum; i++) { + this.buffers[i] = new ArrayList<>(this.bufferSize); } + } - @Override - public void setResult(long windowId, Shard result) { - this.resultFuture.complete(result); - } - - @Override - public Shard finish(long windowId) { + @Override + public void emit(long windowId, T data, boolean isRetract, int[] targetChannels) { + for (int channel : targetChannels) { + ArrayList buffer = this.buffers[channel]; + buffer.add(data); + if (buffer.size() == this.bufferSize) { this.checkErr(); long start = System.currentTimeMillis(); - for (int i = 0; i < this.buffers.length; i++) { - ArrayList buffer = this.buffers[i]; - if (!buffer.isEmpty()) { - this.offer(OutputMessage.data(windowId, i, buffer)); - this.buffers[i] = new ArrayList<>(this.bufferSize); - } - } + this.offer(OutputMessage.data(windowId, channel, buffer)); + this.eventMetrics.addShuffleWriteCostMs(System.currentTimeMillis() - start); + this.buffers[channel] = new ArrayList<>(this.bufferSize); + } + } + } + + @Override + public void setResult(long windowId, Shard result) { + this.resultFuture.complete(result); + } - this.offer(OutputMessage.barrier(windowId)); - try { - Shard shard = this.resultFuture.get(); - this.resultFuture = new CompletableFuture<>(); - this.eventMetrics.addShuffleWriteCostMs(System.currentTimeMillis() - start); - return shard; - } catch (InterruptedException | ExecutionException e) { - throw new GeaflowRuntimeException(e); - } + @Override + public Shard finish(long windowId) { + this.checkErr(); + long start = System.currentTimeMillis(); + for (int i = 0; i < this.buffers.length; i++) { + ArrayList buffer = this.buffers[i]; + if (!buffer.isEmpty()) { + this.offer(OutputMessage.data(windowId, i, buffer)); + this.buffers[i] = new ArrayList<>(this.bufferSize); + } } - @Override - public void error(Throwable t) { - if (this.err.get() == null) { - this.err.set(t); - } + this.offer(OutputMessage.barrier(windowId)); + try { + Shard shard = this.resultFuture.get(); + this.resultFuture = new CompletableFuture<>(); + this.eventMetrics.addShuffleWriteCostMs(System.currentTimeMillis() - start); + return shard; + } catch (InterruptedException | ExecutionException e) { + throw new GeaflowRuntimeException(e); } + } - public void checkErr() { - if (this.err.get() != null) { - throw new GeaflowRuntimeException(this.err.get()); - } + @Override + public void error(Throwable t) { + if (this.err.get() == null) { + this.err.set(t); } + } - public int getEdgeId() { - return this.edgeId; + public void checkErr() { + if (this.err.get() != null) { + throw new GeaflowRuntimeException(this.err.get()); } + } + public int getEdgeId() { + return this.edgeId; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/PrefetchCallbackHandler.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/PrefetchCallbackHandler.java index 1b94fa02e..32563f2e8 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/PrefetchCallbackHandler.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/PrefetchCallbackHandler.java @@ -21,47 +21,45 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.cluster.fetcher.PrefetchMessageBuffer; import org.apache.geaflow.shuffle.message.SliceId; public class PrefetchCallbackHandler { - private static volatile PrefetchCallbackHandler INSTANCE; + private static volatile PrefetchCallbackHandler INSTANCE; - private final Map taskId2callback = new ConcurrentHashMap<>(); + private final Map taskId2callback = new ConcurrentHashMap<>(); - public static PrefetchCallbackHandler getInstance() { + public static PrefetchCallbackHandler getInstance() { + if (INSTANCE == null) { + synchronized (PrefetchCallbackHandler.class) { if (INSTANCE == null) { - synchronized (PrefetchCallbackHandler.class) { - if (INSTANCE == null) { - INSTANCE = new PrefetchCallbackHandler(); - } - } + INSTANCE = new PrefetchCallbackHandler(); } - return INSTANCE; + } } + return INSTANCE; + } + public void registerTaskEventCallback(SliceId sliceId, PrefetchCallback prefetchCallback) { + this.taskId2callback.put(sliceId, prefetchCallback); + } - public void registerTaskEventCallback(SliceId sliceId, PrefetchCallback prefetchCallback) { - this.taskId2callback.put(sliceId, prefetchCallback); - } + public PrefetchCallback removeTaskEventCallback(SliceId sliceId) { + return this.taskId2callback.remove(sliceId); + } - public PrefetchCallback removeTaskEventCallback(SliceId sliceId) { - return this.taskId2callback.remove(sliceId); - } + public static class PrefetchCallback { - public static class PrefetchCallback { - - private final PrefetchMessageBuffer buffer; - - public PrefetchCallback(PrefetchMessageBuffer buffer) { - this.buffer = buffer; - } - - public void execute() { - this.buffer.waitUtilFinish(); - } + private final PrefetchMessageBuffer buffer; + public PrefetchCallback(PrefetchMessageBuffer buffer) { + this.buffer = buffer; } + public void execute() { + this.buffer.waitUtilFinish(); + } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/WorkerFactory.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/WorkerFactory.java index 3974ebd61..a8610e20e 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/WorkerFactory.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/WorkerFactory.java @@ -21,6 +21,7 @@ import java.util.Iterator; import java.util.ServiceLoader; + import org.apache.geaflow.cluster.worker.IWorker; import org.apache.geaflow.cluster.worker.WorkerType; import org.apache.geaflow.common.config.Configuration; @@ -32,21 +33,24 @@ public class WorkerFactory { - private static final Logger LOGGER = LoggerFactory.getLogger(WorkerFactory.class); + private static final Logger LOGGER = LoggerFactory.getLogger(WorkerFactory.class); - public static IWorker createWorker(Configuration configuration) { - WorkerType workerType = configuration.getBoolean(FrameworkConfigKeys.ASP_ENABLE) - ? WorkerType.unaligned_compute : WorkerType.aligned_compute; + public static IWorker createWorker(Configuration configuration) { + WorkerType workerType = + configuration.getBoolean(FrameworkConfigKeys.ASP_ENABLE) + ? WorkerType.unaligned_compute + : WorkerType.aligned_compute; - ServiceLoader executorLoader = ServiceLoader.load(IWorker.class); - Iterator executorIterable = executorLoader.iterator(); - while (executorIterable.hasNext()) { - IWorker worker = executorIterable.next(); - if (worker.getWorkerType() == workerType) { - return worker; - } - } - LOGGER.error("NOT found IWorker implementation"); - throw new GeaflowRuntimeException(RuntimeErrors.INST.spiNotFoundError(IWorker.class.getSimpleName())); + ServiceLoader executorLoader = ServiceLoader.load(IWorker.class); + Iterator executorIterable = executorLoader.iterator(); + while (executorIterable.hasNext()) { + IWorker worker = executorIterable.next(); + if (worker.getWorkerType() == workerType) { + return worker; + } } + LOGGER.error("NOT found IWorker implementation"); + throw new GeaflowRuntimeException( + RuntimeErrors.INST.spiNotFoundError(IWorker.class.getSimpleName())); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/context/AbstractWorkerContext.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/context/AbstractWorkerContext.java index db6e58801..18aad7512 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/context/AbstractWorkerContext.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/context/AbstractWorkerContext.java @@ -20,6 +20,7 @@ package org.apache.geaflow.runtime.core.worker.context; import java.util.List; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.cluster.protocol.IEventContext; import org.apache.geaflow.cluster.task.ITaskContext; @@ -36,166 +37,164 @@ public abstract class AbstractWorkerContext implements IWorkerContext { - protected EventContext eventContext; - protected long currentWindowId; - protected int taskId; - protected Processor processor; - protected Configuration config; - protected boolean enableDebug; - - protected ExecutionTask executionTask; - protected IoDescriptor ioDescriptor; - protected int cycleId; - protected long pipelineId; - protected long schedulerId; - protected String pipelineName; - protected String driverId; - protected MetricGroup metricGroup; - protected EventMetrics eventMetrics; - protected List> collectors; - protected long windowId; - protected RuntimeContext runtimeContext; - protected boolean isFinished; - - public AbstractWorkerContext(ITaskContext taskContext) { - this.config = taskContext.getConfig(); - this.metricGroup = taskContext.getMetricGroup(); - this.enableDebug = false; - } - - @Override - public void init(IEventContext eventContext) { - this.eventContext = (EventContext) eventContext; - this.currentWindowId = this.eventContext.getCurrentWindowId(); - this.cycleId = this.eventContext.getCycleId(); - this.pipelineId = this.eventContext.getPipelineId(); - this.schedulerId = this.eventContext.getSchedulerId(); - this.pipelineName = this.eventContext.getPipelineName(); - this.driverId = this.eventContext.getDriverId(); - this.ioDescriptor = this.eventContext.getIoDescriptor(); - this.executionTask = this.eventContext.getExecutionTask(); - this.processor = this.executionTask.getProcessor(); - this.taskId = this.executionTask.getTaskId(); - this.windowId = this.eventContext.getWindowId(); - this.runtimeContext = createRuntimeContext(); - this.isFinished = false; - this.initEventMetrics(); - } - - public EventContext getEventContext() { - return this.eventContext; - } - - /** - * Create runtime context and set io descriptor. - */ - private RuntimeContext createRuntimeContext() { - return DefaultRuntimeContext.build(config) - .setTaskArgs(this.executionTask.buildTaskArgs()) - .setPipelineId(pipelineId) - .setPipelineName(pipelineName) - .setMetricGroup(metricGroup) - .setIoDescriptor(ioDescriptor) - .setWindowId(windowId); - } - - public long getCurrentWindowId() { - return currentWindowId; - } - - public void setCurrentWindowId(long currentWindowId) { - this.currentWindowId = currentWindowId; - } - - public String getDriverId() { - return driverId; - } - - public void setTaskId(int taskId) { - this.taskId = taskId; - } - - public int getTaskId() { - return taskId; - } - - public int getCycleId() { - return cycleId; - } - - public Processor getProcessor() { - return processor; - } - - public ExecutionTask getExecutionTask() { - return executionTask; - } - - public EventMetrics getEventMetrics() { - return eventMetrics; - } - - public List> getCollectors() { - return collectors; - } - - public void setCollectors(List> collectors) { - this.collectors = collectors; - } - - public long getPipelineId() { - return pipelineId; - } - - public String getPipelineName() { - return pipelineName; - } - - public void setPipelineId(long pipelineId) { - this.pipelineId = pipelineId; - } - - public void setSchedulerId(long schedulerId) { - this.schedulerId = schedulerId; - } - - public long getSchedulerId() { - return schedulerId; - } - - public void setPipelineName(String pipelineName) { - this.pipelineName = pipelineName; - } - - public void setWindowId(long windowId) { - this.windowId = windowId; - } - - public long getWindowId() { - return windowId; - } - - public RuntimeContext getRuntimeContext() { - return runtimeContext; - } - - public boolean isIterativeTask() { - return this.executionTask.isIterative(); - } - - public boolean isFinished() { - return isFinished; - } - - public void setFinished(boolean finished) { - isFinished = finished; - } - - public void initEventMetrics() { - this.eventMetrics = new EventMetrics( + protected EventContext eventContext; + protected long currentWindowId; + protected int taskId; + protected Processor processor; + protected Configuration config; + protected boolean enableDebug; + + protected ExecutionTask executionTask; + protected IoDescriptor ioDescriptor; + protected int cycleId; + protected long pipelineId; + protected long schedulerId; + protected String pipelineName; + protected String driverId; + protected MetricGroup metricGroup; + protected EventMetrics eventMetrics; + protected List> collectors; + protected long windowId; + protected RuntimeContext runtimeContext; + protected boolean isFinished; + + public AbstractWorkerContext(ITaskContext taskContext) { + this.config = taskContext.getConfig(); + this.metricGroup = taskContext.getMetricGroup(); + this.enableDebug = false; + } + + @Override + public void init(IEventContext eventContext) { + this.eventContext = (EventContext) eventContext; + this.currentWindowId = this.eventContext.getCurrentWindowId(); + this.cycleId = this.eventContext.getCycleId(); + this.pipelineId = this.eventContext.getPipelineId(); + this.schedulerId = this.eventContext.getSchedulerId(); + this.pipelineName = this.eventContext.getPipelineName(); + this.driverId = this.eventContext.getDriverId(); + this.ioDescriptor = this.eventContext.getIoDescriptor(); + this.executionTask = this.eventContext.getExecutionTask(); + this.processor = this.executionTask.getProcessor(); + this.taskId = this.executionTask.getTaskId(); + this.windowId = this.eventContext.getWindowId(); + this.runtimeContext = createRuntimeContext(); + this.isFinished = false; + this.initEventMetrics(); + } + + public EventContext getEventContext() { + return this.eventContext; + } + + /** Create runtime context and set io descriptor. */ + private RuntimeContext createRuntimeContext() { + return DefaultRuntimeContext.build(config) + .setTaskArgs(this.executionTask.buildTaskArgs()) + .setPipelineId(pipelineId) + .setPipelineName(pipelineName) + .setMetricGroup(metricGroup) + .setIoDescriptor(ioDescriptor) + .setWindowId(windowId); + } + + public long getCurrentWindowId() { + return currentWindowId; + } + + public void setCurrentWindowId(long currentWindowId) { + this.currentWindowId = currentWindowId; + } + + public String getDriverId() { + return driverId; + } + + public void setTaskId(int taskId) { + this.taskId = taskId; + } + + public int getTaskId() { + return taskId; + } + + public int getCycleId() { + return cycleId; + } + + public Processor getProcessor() { + return processor; + } + + public ExecutionTask getExecutionTask() { + return executionTask; + } + + public EventMetrics getEventMetrics() { + return eventMetrics; + } + + public List> getCollectors() { + return collectors; + } + + public void setCollectors(List> collectors) { + this.collectors = collectors; + } + + public long getPipelineId() { + return pipelineId; + } + + public String getPipelineName() { + return pipelineName; + } + + public void setPipelineId(long pipelineId) { + this.pipelineId = pipelineId; + } + + public void setSchedulerId(long schedulerId) { + this.schedulerId = schedulerId; + } + + public long getSchedulerId() { + return schedulerId; + } + + public void setPipelineName(String pipelineName) { + this.pipelineName = pipelineName; + } + + public void setWindowId(long windowId) { + this.windowId = windowId; + } + + public long getWindowId() { + return windowId; + } + + public RuntimeContext getRuntimeContext() { + return runtimeContext; + } + + public boolean isIterativeTask() { + return this.executionTask.isIterative(); + } + + public boolean isFinished() { + return isFinished; + } + + public void setFinished(boolean finished) { + isFinished = finished; + } + + public void initEventMetrics() { + this.eventMetrics = + new EventMetrics( this.executionTask.getVertexId(), this.executionTask.getParallelism(), this.executionTask.getIndex()); - } - + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/context/WorkerContext.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/context/WorkerContext.java index 975e9f2c8..bc315f66d 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/context/WorkerContext.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/context/WorkerContext.java @@ -24,18 +24,16 @@ public class WorkerContext extends AbstractWorkerContext { - public WorkerContext(ITaskContext taskContext) { - super(taskContext); - } + public WorkerContext(ITaskContext taskContext) { + super(taskContext); + } - /** - * Release worker resource. - */ - @Override - public void close() { - for (ICollector collector : collectors) { - collector.close(); - } - processor.close(); + /** Release worker resource. */ + @Override + public void close() { + for (ICollector collector : collectors) { + collector.close(); } + processor.close(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/context/WorkerContextManager.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/context/WorkerContextManager.java index d8235fedd..342ce3391 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/context/WorkerContextManager.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/context/WorkerContextManager.java @@ -20,30 +20,32 @@ package org.apache.geaflow.runtime.core.worker.context; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.cluster.worker.IWorkerContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class WorkerContextManager { - private static final Logger LOGGER = LoggerFactory.getLogger(WorkerContextManager.class); + private static final Logger LOGGER = LoggerFactory.getLogger(WorkerContextManager.class); - private static ConcurrentHashMap workerContexts = new ConcurrentHashMap<>(); + private static ConcurrentHashMap workerContexts = + new ConcurrentHashMap<>(); - public static IWorkerContext get(int taskId) { - return workerContexts.get(taskId); - } + public static IWorkerContext get(int taskId) { + return workerContexts.get(taskId); + } - public static void register(int taskId, WorkerContext workerContext) { - LOGGER.info("taskId {} register worker context", taskId); - workerContexts.put(taskId, workerContext); - } + public static void register(int taskId, WorkerContext workerContext) { + LOGGER.info("taskId {} register worker context", taskId); + workerContexts.put(taskId, workerContext); + } - public static synchronized void clear() { - for (IWorkerContext workerContext : workerContexts.values()) { - workerContext.close(); - } - LOGGER.info("clear all worker context"); - workerContexts.clear(); + public static synchronized void clear() { + for (IWorkerContext workerContext : workerContexts.values()) { + workerContext.close(); } + LOGGER.info("clear all worker context"); + workerContexts.clear(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/impl/AlignedComputeWorker.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/impl/AlignedComputeWorker.java index 45e73f255..2d4d2bb70 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/impl/AlignedComputeWorker.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/impl/AlignedComputeWorker.java @@ -24,12 +24,12 @@ public class AlignedComputeWorker extends AbstractAlignedWorker { - public AlignedComputeWorker() { - super(); - } + public AlignedComputeWorker() { + super(); + } - @Override - public WorkerType getWorkerType() { - return WorkerType.aligned_compute; - } + @Override + public WorkerType getWorkerType() { + return WorkerType.aligned_compute; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/impl/UnAlignedComputeWorker.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/impl/UnAlignedComputeWorker.java index fe25a3632..18bcf938a 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/impl/UnAlignedComputeWorker.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/main/java/org/apache/geaflow/runtime/core/worker/impl/UnAlignedComputeWorker.java @@ -26,31 +26,35 @@ public class UnAlignedComputeWorker extends AbstractUnAlignedWorker { - private static final Logger LOGGER = LoggerFactory.getLogger(UnAlignedComputeWorker.class); + private static final Logger LOGGER = LoggerFactory.getLogger(UnAlignedComputeWorker.class); - public UnAlignedComputeWorker() { - super(); - } - - @Override - public void init(long windowId) { - // TODO Processing in dynamic / stream scene. - if (processingWindowIdQueue.isEmpty() && windowId <= context.getWindowId()) { - super.init(windowId); - } - processingWindowIdQueue.add(windowId); - } - - @Override - public void finish(long windowId) { - LOGGER.info("taskId {} finishes windowId {}, currentBatchId {}, real currentBatchId {}", - context.getTaskId(), windowId, windowId, context.getCurrentWindowId()); - context.getProcessor().finish(windowId); - finishWindow(windowId); - } + public UnAlignedComputeWorker() { + super(); + } - @Override - public WorkerType getWorkerType() { - return WorkerType.unaligned_compute; + @Override + public void init(long windowId) { + // TODO Processing in dynamic / stream scene. + if (processingWindowIdQueue.isEmpty() && windowId <= context.getWindowId()) { + super.init(windowId); } + processingWindowIdQueue.add(windowId); + } + + @Override + public void finish(long windowId) { + LOGGER.info( + "taskId {} finishes windowId {}, currentBatchId {}, real currentBatchId {}", + context.getTaskId(), + windowId, + windowId, + context.getCurrentWindowId()); + context.getProcessor().finish(windowId); + finishWindow(windowId); + } + + @Override + public WorkerType getWorkerType() { + return WorkerType.unaligned_compute; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/BaseCycleSchedulerTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/BaseCycleSchedulerTest.java index 40592faee..c781987c6 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/BaseCycleSchedulerTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/BaseCycleSchedulerTest.java @@ -24,6 +24,7 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.CompletableFuture; + import org.apache.geaflow.cluster.common.IEventListener; import org.apache.geaflow.cluster.protocol.EventType; import org.apache.geaflow.cluster.protocol.IEvent; @@ -46,108 +47,127 @@ public class BaseCycleSchedulerTest { - private static final Logger LOGGER = LoggerFactory.getLogger(BaseCycleSchedulerTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(BaseCycleSchedulerTest.class); - private MockedStatic rpcClientMs; - protected MockContainerEventProcessor processor; + private MockedStatic rpcClientMs; + protected MockContainerEventProcessor processor; - @BeforeMethod - public void beforeMethod() { - processor = new MockContainerEventProcessor(); - // mock resource manager rpc - RpcClient rpcClient = Mockito.mock(RpcClient.class); - rpcClientMs = Mockito.mockStatic(RpcClient.class); - rpcClientMs.when(() -> RpcClient.getInstance()).then(invocation -> rpcClient); + @BeforeMethod + public void beforeMethod() { + processor = new MockContainerEventProcessor(); + // mock resource manager rpc + RpcClient rpcClient = Mockito.mock(RpcClient.class); + rpcClientMs = Mockito.mockStatic(RpcClient.class); + rpcClientMs.when(() -> RpcClient.getInstance()).then(invocation -> rpcClient); - Mockito.doAnswer(in -> { - int workerNum = ((RequireResourceRequest) in.getArgument(1)).getRequiredNum(); - List workers = new ArrayList<>(); - for (int i = 0; i < workerNum; i++) { + Mockito.doAnswer( + in -> { + int workerNum = ((RequireResourceRequest) in.getArgument(1)).getRequiredNum(); + List workers = new ArrayList<>(); + for (int i = 0; i < workerNum; i++) { WorkerInfo workerInfo = new WorkerInfo(); workerInfo.setHost("host0"); workerInfo.setContainerName("container0"); workerInfo.setWorkerIndex(i); workers.add(workerInfo); - } - return RequireResponse.success("test", workers); - }).when(rpcClient).requireResource(any(), any()); - - // mock container rpc - Mockito.doAnswer(in -> { - processor.process((IEvent) in.getArgument(1)); - CompletableFuture future = new CompletableFuture<>(); - future.complete(null); - return future; - }).when(rpcClient).processContainer(any(), any()); - - Mockito.doAnswer(in -> { - processor.process((IEvent) in.getArgument(1)); - RpcEndpointRef.RpcCallback callback = ((RpcEndpointRef.RpcCallback) in.getArgument(2)); - callback.onSuccess(null); - return null; - }).when(rpcClient).processContainer(any(), any(), any()); + } + return RequireResponse.success("test", workers); + }) + .when(rpcClient) + .requireResource(any(), any()); + + // mock container rpc + Mockito.doAnswer( + in -> { + processor.process((IEvent) in.getArgument(1)); + CompletableFuture future = new CompletableFuture<>(); + future.complete(null); + return future; + }) + .when(rpcClient) + .processContainer(any(), any()); + + Mockito.doAnswer( + in -> { + processor.process((IEvent) in.getArgument(1)); + RpcEndpointRef.RpcCallback callback = + ((RpcEndpointRef.RpcCallback) in.getArgument(2)); + callback.onSuccess(null); + return null; + }) + .when(rpcClient) + .processContainer(any(), any(), any()); + } + + @AfterMethod + public void afterMethod() { + rpcClientMs.close(); + processor.clean(); + } + + public class MockContainerEventProcessor { + + private List processed; + private ICycleScheduler scheduler; + + public MockContainerEventProcessor() { + this.processed = new ArrayList<>(); } - @AfterMethod - public void afterMethod() { - rpcClientMs.close(); - processor.clean(); + public void register(ICycleScheduler scheduler) { + this.scheduler = scheduler; } - public class MockContainerEventProcessor { - - private List processed; - private ICycleScheduler scheduler; - - public MockContainerEventProcessor() { - this.processed = new ArrayList<>(); - } - - public void register(ICycleScheduler scheduler) { - this.scheduler = scheduler; - } + public void clean() { + processed.clear(); + this.scheduler = null; + } - public void clean() { - processed.clear(); - this.scheduler = null; - } + public void process(IEvent event) { + LOGGER.info("process event {}", event); + processed.add(event); + processInternal(event); + } - public void process(IEvent event) { - LOGGER.info("process event {}", event); - processed.add(event); - processInternal(event); + public void processInternal(IEvent event) { + if (event.getEventType() == EventType.COMPOSE) { + for (IEvent e : ((ComposeEvent) event).getEventList()) { + processInternal(e); } - - public void processInternal(IEvent event) { - if (event.getEventType() == EventType.COMPOSE) { - for (IEvent e : ((ComposeEvent) event).getEventList()) { - processInternal(e); - } - } else { - IEvent response; - switch (event.getEventType()) { - case LAUNCH_SOURCE: - LaunchSourceEvent sourceEvent = (LaunchSourceEvent) event; - response = new DoneEvent<>(sourceEvent.getSchedulerId(), sourceEvent.getCycleId(), sourceEvent.getIterationWindowId(), - sourceEvent.getWorkerId(), EventType.EXECUTE_COMPUTE); - ((IEventListener) scheduler).handleEvent(response); - break; - case CLEAN_CYCLE: - case CLEAN_ENV: - case STASH_WORKER: - AbstractExecutableCommand executableCommand = (AbstractExecutableCommand) event; - response = new DoneEvent<>(executableCommand.getSchedulerId(), executableCommand.getCycleId(), executableCommand.getIterationWindowId(), - executableCommand.getWorkerId(), executableCommand.getEventType()); - ((IEventListener) scheduler).handleEvent(response); - break; - default: - - } - } + } else { + IEvent response; + switch (event.getEventType()) { + case LAUNCH_SOURCE: + LaunchSourceEvent sourceEvent = (LaunchSourceEvent) event; + response = + new DoneEvent<>( + sourceEvent.getSchedulerId(), + sourceEvent.getCycleId(), + sourceEvent.getIterationWindowId(), + sourceEvent.getWorkerId(), + EventType.EXECUTE_COMPUTE); + ((IEventListener) scheduler).handleEvent(response); + break; + case CLEAN_CYCLE: + case CLEAN_ENV: + case STASH_WORKER: + AbstractExecutableCommand executableCommand = (AbstractExecutableCommand) event; + response = + new DoneEvent<>( + executableCommand.getSchedulerId(), + executableCommand.getCycleId(), + executableCommand.getIterationWindowId(), + executableCommand.getWorkerId(), + executableCommand.getEventType()); + ((IEventListener) scheduler).handleEvent(response); + break; + default: } + } + } - public List getProcessed() { - return processed; - } + public List getProcessed() { + return processed; } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/ExecutionGraphCycleSchedulerTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/ExecutionGraphCycleSchedulerTest.java index 59bd35e59..6b868eca7 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/ExecutionGraphCycleSchedulerTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/ExecutionGraphCycleSchedulerTest.java @@ -30,6 +30,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.cluster.common.ExecutionIdGenerator; import org.apache.geaflow.cluster.protocol.EventType; import org.apache.geaflow.cluster.protocol.IEvent; @@ -61,189 +62,236 @@ public class ExecutionGraphCycleSchedulerTest extends BaseCycleSchedulerTest { - private static final Logger LOGGER = LoggerFactory.getLogger(ExecutionGraphCycleSchedulerTest.class); - - private Configuration configuration; - - @BeforeMethod - public void setUp() { - Map config = new HashMap<>(); - config.put(JOB_UNIQUE_ID.getKey(), "scheduler-fo-test" + System.currentTimeMillis()); - config.put(RUN_LOCAL_MODE.getKey(), "true"); - config.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); - config.put(CONTAINER_HEAP_SIZE_MB.getKey(), String.valueOf(1024)); - configuration = new Configuration(config); - ExecutionIdGenerator.init(0); - ShuffleManager.init(configuration); - StatsCollectorFactory.init(configuration); + private static final Logger LOGGER = + LoggerFactory.getLogger(ExecutionGraphCycleSchedulerTest.class); + + private Configuration configuration; + + @BeforeMethod + public void setUp() { + Map config = new HashMap<>(); + config.put(JOB_UNIQUE_ID.getKey(), "scheduler-fo-test" + System.currentTimeMillis()); + config.put(RUN_LOCAL_MODE.getKey(), "true"); + config.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); + config.put(CONTAINER_HEAP_SIZE_MB.getKey(), String.valueOf(1024)); + configuration = new Configuration(config); + ExecutionIdGenerator.init(0); + ShuffleManager.init(configuration); + StatsCollectorFactory.init(configuration); + } + + @AfterMethod + public void cleanUp() { + ClusterMetaStore.close(); + ScheduledWorkerManagerFactory.clear(); + } + + @Test + public void testSimplePipeline() { + ClusterMetaStore.init(0, "driver-0", configuration); + ExecutionGraphCycleScheduler scheduler = new ExecutionGraphCycleScheduler(); + processor.register(scheduler); + + CheckpointSchedulerContext context = + (CheckpointSchedulerContext) + CycleSchedulerContextFactory.create(buildMockGraphCycle(), null); + scheduler.init(context); + scheduler.execute(); + scheduler.close(); + + List events = processor.getProcessed(); + LOGGER.info("processed events {}", events.size()); + for (IEvent event : events) { + LOGGER.info("{}", event); } - @AfterMethod - public void cleanUp() { - ClusterMetaStore.close(); - ScheduledWorkerManagerFactory.clear(); + int eventWindowId = 1; + int eventIndex = 0; + int eventSize = 32; + Assert.assertEquals(eventSize, events.size()); + Assert.assertEquals(EventType.COMPOSE, events.get(eventIndex).getEventType()); + Assert.assertEquals( + EventType.CREATE_TASK, + ((ComposeEvent) events.get(eventIndex)).getEventList().get(0).getEventType()); + Assert.assertEquals( + EventType.CREATE_WORKER, + ((ComposeEvent) events.get(eventIndex)).getEventList().get(1).getEventType()); + eventIndex++; + + Assert.assertEquals(EventType.COMPOSE, events.get(eventIndex).getEventType()); + Assert.assertEquals( + EventType.INIT_CYCLE, + ((ComposeEvent) events.get(eventIndex)).getEventList().get(0).getEventType()); + Assert.assertEquals( + EventType.LAUNCH_SOURCE, + ((ComposeEvent) events.get(eventIndex)).getEventList().get(1).getEventType()); + Assert.assertEquals( + eventWindowId++, + ((LaunchSourceEvent) ((ComposeEvent) events.get(eventIndex)).getEventList().get(1)) + .getIterationWindowId()); + eventIndex++; + + Assert.assertEquals(EventType.STASH_WORKER, events.get(eventIndex++).getEventType()); + + Assert.assertEquals(EventType.CLEAN_ENV, events.get(eventIndex++).getEventType()); + + while (eventIndex < eventSize - 1) { + Assert.assertEquals(EventType.COMPOSE, events.get(eventIndex).getEventType()); + Assert.assertEquals( + EventType.POP_WORKER, + ((ComposeEvent) events.get(eventIndex)).getEventList().get(0).getEventType()); + Assert.assertEquals( + EventType.LAUNCH_SOURCE, + ((ComposeEvent) events.get(eventIndex)).getEventList().get(1).getEventType()); + Assert.assertEquals( + eventWindowId++, + ((LaunchSourceEvent) ((ComposeEvent) events.get(eventIndex)).getEventList().get(1)) + .getIterationWindowId()); + eventIndex++; + + Assert.assertEquals(EventType.STASH_WORKER, events.get(eventIndex++).getEventType()); + + Assert.assertEquals(EventType.CLEAN_ENV, events.get(eventIndex++).getEventType()); } - @Test - public void testSimplePipeline() { - ClusterMetaStore.init(0, "driver-0", configuration); - ExecutionGraphCycleScheduler scheduler = new ExecutionGraphCycleScheduler(); - processor.register(scheduler); - - CheckpointSchedulerContext context = (CheckpointSchedulerContext) CycleSchedulerContextFactory.create(buildMockGraphCycle(), null); - scheduler.init(context); - scheduler.execute(); - scheduler.close(); - - - List events = processor.getProcessed(); - LOGGER.info("processed events {}", events.size()); - for (IEvent event : events) { - LOGGER.info("{}", event); - } - - int eventWindowId = 1; - int eventIndex = 0; - int eventSize = 32; - Assert.assertEquals(eventSize, events.size()); - Assert.assertEquals(EventType.COMPOSE, events.get(eventIndex).getEventType()); - Assert.assertEquals(EventType.CREATE_TASK, ((ComposeEvent) events.get(eventIndex)).getEventList().get(0).getEventType()); - Assert.assertEquals(EventType.CREATE_WORKER, ((ComposeEvent) events.get(eventIndex)).getEventList().get(1).getEventType()); - eventIndex++; - - Assert.assertEquals(EventType.COMPOSE, events.get(eventIndex).getEventType()); - Assert.assertEquals(EventType.INIT_CYCLE, ((ComposeEvent) events.get(eventIndex)).getEventList().get(0).getEventType()); - Assert.assertEquals(EventType.LAUNCH_SOURCE, ((ComposeEvent) events.get(eventIndex)).getEventList().get(1).getEventType()); - Assert.assertEquals(eventWindowId++, ((LaunchSourceEvent) ((ComposeEvent) events.get(eventIndex)).getEventList().get(1)).getIterationWindowId()); - eventIndex++; - - Assert.assertEquals(EventType.STASH_WORKER, events.get(eventIndex++).getEventType()); - - Assert.assertEquals(EventType.CLEAN_ENV, events.get(eventIndex++).getEventType()); - - while (eventIndex < eventSize - 1) { - Assert.assertEquals(EventType.COMPOSE, events.get(eventIndex).getEventType()); - Assert.assertEquals(EventType.POP_WORKER, ((ComposeEvent) events.get(eventIndex)).getEventList().get(0).getEventType()); - Assert.assertEquals(EventType.LAUNCH_SOURCE, ((ComposeEvent) events.get(eventIndex)).getEventList().get(1).getEventType()); - Assert.assertEquals(eventWindowId++, ((LaunchSourceEvent) ((ComposeEvent) events.get(eventIndex)).getEventList().get(1)).getIterationWindowId()); - eventIndex++; - - Assert.assertEquals(EventType.STASH_WORKER, events.get(eventIndex++).getEventType()); - - Assert.assertEquals(EventType.CLEAN_ENV, events.get(eventIndex++).getEventType()); - } - - Assert.assertEquals(EventType.CLEAN_ENV, events.get(eventIndex).getEventType()); + Assert.assertEquals(EventType.CLEAN_ENV, events.get(eventIndex).getEventType()); + } + + @Test + private void testPipelineAfterRecover() { + + configuration.put(CLUSTER_ID, "restart"); + ClusterMetaStore.init(0, "driver-0", configuration); + ClusterMetaStore.getInstance().saveWindowId(5L, 0); + + ExecutionGraphCycleScheduler scheduler = new ExecutionGraphCycleScheduler(); + processor.register(scheduler); + + // mock recover context from previous case. + CheckpointSchedulerContext newContext = + (CheckpointSchedulerContext) + CycleSchedulerContextFactory.create(buildMockGraphCycle(), null); + CheckpointSchedulerContext context = + (CheckpointSchedulerContext) + CheckpointSchedulerContext.build( + newContext.getCycle().getPipelineTaskId(), () -> newContext); + context.init(6); + + scheduler.init(context); + scheduler.execute(); + scheduler.close(); + + List events = processor.getProcessed(); + LOGGER.info("processed events {}", events.size()); + for (IEvent event : events) { + LOGGER.info("{}", event); } - @Test - private void testPipelineAfterRecover() { - - configuration.put(CLUSTER_ID, "restart"); - ClusterMetaStore.init(0, "driver-0", configuration); - ClusterMetaStore.getInstance().saveWindowId(5L, 0); - - ExecutionGraphCycleScheduler scheduler = new ExecutionGraphCycleScheduler(); - processor.register(scheduler); - - // mock recover context from previous case. - CheckpointSchedulerContext newContext = (CheckpointSchedulerContext) CycleSchedulerContextFactory.create(buildMockGraphCycle(), null); - CheckpointSchedulerContext context = - (CheckpointSchedulerContext) CheckpointSchedulerContext.build(newContext.getCycle().getPipelineTaskId(), - () -> newContext); - context.init(6); - - - scheduler.init(context); - scheduler.execute(); - scheduler.close(); - - List events = processor.getProcessed(); - LOGGER.info("processed events {}", events.size()); - for (IEvent event : events) { - LOGGER.info("{}", event); - } - - int eventWindowId = 6; - int eventIndex = 0; - int eventSize = 17; - Assert.assertEquals(eventSize, events.size()); - Assert.assertEquals(EventType.COMPOSE, events.get(eventIndex).getEventType()); - Assert.assertEquals(EventType.CREATE_TASK, ((ComposeEvent) events.get(eventIndex)).getEventList().get(0).getEventType()); - Assert.assertEquals(EventType.CREATE_WORKER, ((ComposeEvent) events.get(eventIndex)).getEventList().get(1).getEventType()); - eventIndex++; - - Assert.assertEquals(EventType.COMPOSE, events.get(eventIndex).getEventType()); - Assert.assertEquals(EventType.INIT_CYCLE, ((ComposeEvent) events.get(eventIndex)).getEventList().get(0).getEventType()); - Assert.assertEquals(EventType.ROLLBACK, ((ComposeEvent) events.get(eventIndex)).getEventList().get(1).getEventType()); - Assert.assertEquals(EventType.LAUNCH_SOURCE, ((ComposeEvent) events.get(eventIndex)).getEventList().get(2).getEventType()); - Assert.assertEquals(eventWindowId++, ((LaunchSourceEvent) ((ComposeEvent) events.get(eventIndex)).getEventList().get(2)).getIterationWindowId()); - eventIndex++; - - Assert.assertEquals(EventType.STASH_WORKER, events.get(eventIndex++).getEventType()); - - Assert.assertEquals(EventType.CLEAN_ENV, events.get(eventIndex++).getEventType()); - - while (eventIndex < eventSize - 1) { - Assert.assertEquals(EventType.COMPOSE, events.get(eventIndex).getEventType()); - Assert.assertEquals(EventType.POP_WORKER, ((ComposeEvent) events.get(eventIndex)).getEventList().get(0).getEventType()); - Assert.assertEquals(EventType.LAUNCH_SOURCE, ((ComposeEvent) events.get(eventIndex)).getEventList().get(1).getEventType()); - Assert.assertEquals(eventWindowId++, ((LaunchSourceEvent) ((ComposeEvent) events.get(eventIndex)).getEventList().get(1)).getIterationWindowId()); - eventIndex++; - - Assert.assertEquals(EventType.STASH_WORKER, events.get(eventIndex++).getEventType()); - - Assert.assertEquals(EventType.CLEAN_ENV, events.get(eventIndex++).getEventType()); - } - - Assert.assertEquals(EventType.CLEAN_ENV, events.get(eventIndex).getEventType()); + int eventWindowId = 6; + int eventIndex = 0; + int eventSize = 17; + Assert.assertEquals(eventSize, events.size()); + Assert.assertEquals(EventType.COMPOSE, events.get(eventIndex).getEventType()); + Assert.assertEquals( + EventType.CREATE_TASK, + ((ComposeEvent) events.get(eventIndex)).getEventList().get(0).getEventType()); + Assert.assertEquals( + EventType.CREATE_WORKER, + ((ComposeEvent) events.get(eventIndex)).getEventList().get(1).getEventType()); + eventIndex++; + + Assert.assertEquals(EventType.COMPOSE, events.get(eventIndex).getEventType()); + Assert.assertEquals( + EventType.INIT_CYCLE, + ((ComposeEvent) events.get(eventIndex)).getEventList().get(0).getEventType()); + Assert.assertEquals( + EventType.ROLLBACK, + ((ComposeEvent) events.get(eventIndex)).getEventList().get(1).getEventType()); + Assert.assertEquals( + EventType.LAUNCH_SOURCE, + ((ComposeEvent) events.get(eventIndex)).getEventList().get(2).getEventType()); + Assert.assertEquals( + eventWindowId++, + ((LaunchSourceEvent) ((ComposeEvent) events.get(eventIndex)).getEventList().get(2)) + .getIterationWindowId()); + eventIndex++; + + Assert.assertEquals(EventType.STASH_WORKER, events.get(eventIndex++).getEventType()); + + Assert.assertEquals(EventType.CLEAN_ENV, events.get(eventIndex++).getEventType()); + + while (eventIndex < eventSize - 1) { + Assert.assertEquals(EventType.COMPOSE, events.get(eventIndex).getEventType()); + Assert.assertEquals( + EventType.POP_WORKER, + ((ComposeEvent) events.get(eventIndex)).getEventList().get(0).getEventType()); + Assert.assertEquals( + EventType.LAUNCH_SOURCE, + ((ComposeEvent) events.get(eventIndex)).getEventList().get(1).getEventType()); + Assert.assertEquals( + eventWindowId++, + ((LaunchSourceEvent) ((ComposeEvent) events.get(eventIndex)).getEventList().get(1)) + .getIterationWindowId()); + eventIndex++; + + Assert.assertEquals(EventType.STASH_WORKER, events.get(eventIndex++).getEventType()); + + Assert.assertEquals(EventType.CLEAN_ENV, events.get(eventIndex++).getEventType()); } + Assert.assertEquals(EventType.CLEAN_ENV, events.get(eventIndex).getEventType()); + } - private ExecutionGraphCycle buildMockGraphCycle() { - - ExecutionGraphCycle graphCycle = new ExecutionGraphCycle(0, 0, 0, "graph_cycle", 0, - 1, 10, configuration, "driver_id", 0); - ExecutionNodeCycle nodeCycle = buildMockIterationNodeCycle(configuration); + private ExecutionGraphCycle buildMockGraphCycle() { - graphCycle.addCycle(nodeCycle, false); - try { - ReflectionUtil.setField(graphCycle, "haLevel", CHECKPOINT); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + ExecutionGraphCycle graphCycle = + new ExecutionGraphCycle(0, 0, 0, "graph_cycle", 0, 1, 10, configuration, "driver_id", 0); + ExecutionNodeCycle nodeCycle = buildMockIterationNodeCycle(configuration); - return graphCycle; + graphCycle.addCycle(nodeCycle, false); + try { + ReflectionUtil.setField(graphCycle, "haLevel", CHECKPOINT); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } - protected ExecutionNodeCycle buildMockIterationNodeCycle(Configuration configuration) { - - long finishIterationId = 1; - ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); - vertexGroup.getCycleGroupMeta().setFlyingCount(1); - vertexGroup.getCycleGroupMeta().setIterationCount(finishIterationId); - vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); - ExecutionVertex vertex = new ExecutionVertex(0, "test"); - vertex.setParallelism(1); - vertexGroup.getVertexMap().put(0, vertex); - vertexGroup.putVertexId2InEdgeIds(0, new ArrayList<>()); - vertexGroup.putVertexId2OutEdgeIds(0, new ArrayList<>()); - - List headTasks = new ArrayList<>(); - List tailTasks = new ArrayList<>(); - for (int i = 0; i < vertex.getParallelism(); i++) { - ExecutionTask task = new ExecutionTask(i, i, vertex.getParallelism(), vertex.getParallelism(), vertex.getParallelism(), vertex.getVertexId()); - task.setExecutionTaskType(ExecutionTaskType.head); - tailTasks.add(task); - headTasks.add(task); - } - - ExecutionNodeCycle cycle = new ExecutionNodeCycle(0, 0, 0, "test", vertexGroup, - configuration, "driver_id", 0); - cycle.setCycleHeads(headTasks); - cycle.setCycleTails(tailTasks); - cycle.setTasks(headTasks); - return cycle; + return graphCycle; + } + + protected ExecutionNodeCycle buildMockIterationNodeCycle(Configuration configuration) { + + long finishIterationId = 1; + ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); + vertexGroup.getCycleGroupMeta().setFlyingCount(1); + vertexGroup.getCycleGroupMeta().setIterationCount(finishIterationId); + vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); + ExecutionVertex vertex = new ExecutionVertex(0, "test"); + vertex.setParallelism(1); + vertexGroup.getVertexMap().put(0, vertex); + vertexGroup.putVertexId2InEdgeIds(0, new ArrayList<>()); + vertexGroup.putVertexId2OutEdgeIds(0, new ArrayList<>()); + + List headTasks = new ArrayList<>(); + List tailTasks = new ArrayList<>(); + for (int i = 0; i < vertex.getParallelism(); i++) { + ExecutionTask task = + new ExecutionTask( + i, + i, + vertex.getParallelism(), + vertex.getParallelism(), + vertex.getParallelism(), + vertex.getVertexId()); + task.setExecutionTaskType(ExecutionTaskType.head); + tailTasks.add(task); + headTasks.add(task); } -} \ No newline at end of file + ExecutionNodeCycle cycle = + new ExecutionNodeCycle(0, 0, 0, "test", vertexGroup, configuration, "driver_id", 0); + cycle.setCycleHeads(headTasks); + cycle.setCycleTails(tailTasks); + cycle.setTasks(headTasks); + return cycle; + } +} diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/PipelineCycleSchedulerTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/PipelineCycleSchedulerTest.java index 492a14265..cbf4756f2 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/PipelineCycleSchedulerTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/PipelineCycleSchedulerTest.java @@ -28,6 +28,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.cluster.protocol.EventType; import org.apache.geaflow.cluster.protocol.IEvent; import org.apache.geaflow.cluster.resourcemanager.WorkerInfo; @@ -58,166 +59,197 @@ public class PipelineCycleSchedulerTest extends BaseCycleSchedulerTest { - private static final Logger LOGGER = LoggerFactory.getLogger(PipelineCycleSchedulerTest.class); - - private static CheckpointSchedulerContext mockPersistContext; - private Configuration configuration; - - @BeforeMethod - public void setUp() { - Map config = new HashMap<>(); - config.put(JOB_UNIQUE_ID.getKey(), "scheduler-fo-test" + System.currentTimeMillis()); - config.put(RUN_LOCAL_MODE.getKey(), "true"); - config.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); - config.put(CONTAINER_HEAP_SIZE_MB.getKey(), String.valueOf(1024)); - config.put(ExecutionConfigKeys.SHUFFLE_PREFETCH.getKey(), String.valueOf(false)); - configuration = new Configuration(config); - ClusterMetaStore.init(0, "driver-0", configuration); - ShuffleManager.init(configuration); - } - - @AfterMethod - public void cleanUp() { - ClusterMetaStore.close(); - ScheduledWorkerManagerFactory.clear(); - } - - @Test(priority = 0) - public void testSimplePipeline() { - PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); - processor.register(scheduler); - StatsCollectorFactory.init(configuration); - CheckpointSchedulerContext context = (CheckpointSchedulerContext) CycleSchedulerContextFactory.create(buildMockCycle(configuration), null); - mockPersistContext = context; - scheduler.init(context); - scheduler.execute(); - scheduler.close(); - - List events = processor.getProcessed(); - LOGGER.info("processed events {}", events.size()); - for (IEvent event : events) { - LOGGER.info("{}", event); - } - Assert.assertEquals(6, events.size()); - - Assert.assertEquals(EventType.COMPOSE, events.get(0).getEventType()); - Assert.assertEquals(EventType.INIT_CYCLE, ((ComposeEvent) events.get(0)).getEventList().get(0).getEventType()); - Assert.assertEquals(EventType.LAUNCH_SOURCE, ((ComposeEvent) events.get(0)).getEventList().get(1).getEventType()); - Assert.assertEquals(1, ((LaunchSourceEvent) ((ComposeEvent) events.get(0)).getEventList().get(1)).getIterationWindowId()); - - Assert.assertEquals(EventType.LAUNCH_SOURCE, events.get(1).getEventType()); - Assert.assertEquals(2, ((LaunchSourceEvent) events.get(1)).getIterationWindowId()); - - Assert.assertEquals(EventType.LAUNCH_SOURCE, events.get(4).getEventType()); - Assert.assertEquals(5, ((LaunchSourceEvent) events.get(4)).getIterationWindowId()); - - Assert.assertEquals(EventType.CLEAN_CYCLE, events.get(5).getEventType()); - + private static final Logger LOGGER = LoggerFactory.getLogger(PipelineCycleSchedulerTest.class); + + private static CheckpointSchedulerContext mockPersistContext; + private Configuration configuration; + + @BeforeMethod + public void setUp() { + Map config = new HashMap<>(); + config.put(JOB_UNIQUE_ID.getKey(), "scheduler-fo-test" + System.currentTimeMillis()); + config.put(RUN_LOCAL_MODE.getKey(), "true"); + config.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); + config.put(CONTAINER_HEAP_SIZE_MB.getKey(), String.valueOf(1024)); + config.put(ExecutionConfigKeys.SHUFFLE_PREFETCH.getKey(), String.valueOf(false)); + configuration = new Configuration(config); + ClusterMetaStore.init(0, "driver-0", configuration); + ShuffleManager.init(configuration); + } + + @AfterMethod + public void cleanUp() { + ClusterMetaStore.close(); + ScheduledWorkerManagerFactory.clear(); + } + + @Test(priority = 0) + public void testSimplePipeline() { + PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); + processor.register(scheduler); + StatsCollectorFactory.init(configuration); + CheckpointSchedulerContext context = + (CheckpointSchedulerContext) + CycleSchedulerContextFactory.create(buildMockCycle(configuration), null); + mockPersistContext = context; + scheduler.init(context); + scheduler.execute(); + scheduler.close(); + + List events = processor.getProcessed(); + LOGGER.info("processed events {}", events.size()); + for (IEvent event : events) { + LOGGER.info("{}", event); } - - @Test(priority = 1) - private void testPipelineAfterRestart() { - - PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); - processor.register(scheduler); - - // mock recover context from previous case. - CheckpointSchedulerContext context = mockPersistContext; - context.init(3); - context.setRecovered(false); - context.setRollback(true); - - scheduler.init(context); - scheduler.execute(); - scheduler.close(); - - context.setRollback(false); - - List events = processor.getProcessed(); - LOGGER.info("processed events {}", events.size()); - for (IEvent event : events) { - LOGGER.info("{}", event); - } - Assert.assertEquals(4, events.size()); - Assert.assertEquals(EventType.COMPOSE, events.get(0).getEventType()); - Assert.assertEquals(EventType.INIT_CYCLE, ((ComposeEvent) events.get(0)).getEventList().get(0).getEventType()); - Assert.assertEquals(EventType.ROLLBACK, ((ComposeEvent) events.get(0)).getEventList().get(1).getEventType()); - Assert.assertEquals(2, ((RollbackCycleEvent) ((ComposeEvent) events.get(0)).getEventList().get(1)).getIterationWindowId()); - Assert.assertEquals(EventType.LAUNCH_SOURCE, ((ComposeEvent) events.get(0)).getEventList().get(2).getEventType()); - Assert.assertEquals(3, ((LaunchSourceEvent) ((ComposeEvent) events.get(0)).getEventList().get(2)).getIterationWindowId()); - - Assert.assertEquals(EventType.LAUNCH_SOURCE, events.get(1).getEventType()); - Assert.assertEquals(4, ((LaunchSourceEvent) events.get(1)).getIterationWindowId()); - - Assert.assertEquals(EventType.LAUNCH_SOURCE, events.get(2).getEventType()); - Assert.assertEquals(5, ((LaunchSourceEvent) events.get(2)).getIterationWindowId()); - - Assert.assertEquals(EventType.CLEAN_CYCLE, events.get(3).getEventType()); + Assert.assertEquals(6, events.size()); + + Assert.assertEquals(EventType.COMPOSE, events.get(0).getEventType()); + Assert.assertEquals( + EventType.INIT_CYCLE, ((ComposeEvent) events.get(0)).getEventList().get(0).getEventType()); + Assert.assertEquals( + EventType.LAUNCH_SOURCE, + ((ComposeEvent) events.get(0)).getEventList().get(1).getEventType()); + Assert.assertEquals( + 1, + ((LaunchSourceEvent) ((ComposeEvent) events.get(0)).getEventList().get(1)) + .getIterationWindowId()); + + Assert.assertEquals(EventType.LAUNCH_SOURCE, events.get(1).getEventType()); + Assert.assertEquals(2, ((LaunchSourceEvent) events.get(1)).getIterationWindowId()); + + Assert.assertEquals(EventType.LAUNCH_SOURCE, events.get(4).getEventType()); + Assert.assertEquals(5, ((LaunchSourceEvent) events.get(4)).getIterationWindowId()); + + Assert.assertEquals(EventType.CLEAN_CYCLE, events.get(5).getEventType()); + } + + @Test(priority = 1) + private void testPipelineAfterRestart() { + + PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); + processor.register(scheduler); + + // mock recover context from previous case. + CheckpointSchedulerContext context = mockPersistContext; + context.init(3); + context.setRecovered(false); + context.setRollback(true); + + scheduler.init(context); + scheduler.execute(); + scheduler.close(); + + context.setRollback(false); + + List events = processor.getProcessed(); + LOGGER.info("processed events {}", events.size()); + for (IEvent event : events) { + LOGGER.info("{}", event); } - - @Test(priority = 1) - private void testPipelineAfterRecover() { - - PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); - processor.register(scheduler); - - // mock recover context from previous case. - CheckpointSchedulerContext context = mockPersistContext; - context.init(4); - context.setRecovered(true); - - scheduler.init(context); - scheduler.execute(); - scheduler.close(); - - - List events = processor.getProcessed(); - LOGGER.info("processed events {}", events.size()); - for (IEvent event : events) { - LOGGER.info("{}", event); - } - Assert.assertEquals(3, events.size()); - Assert.assertEquals(EventType.COMPOSE, events.get(0).getEventType()); - Assert.assertEquals(EventType.ROLLBACK, ((ComposeEvent) events.get(0)).getEventList().get(0).getEventType()); - Assert.assertEquals(3, ((RollbackCycleEvent) ((ComposeEvent) events.get(0)).getEventList().get(0)).getIterationWindowId()); - Assert.assertEquals(EventType.LAUNCH_SOURCE, ((ComposeEvent) events.get(0)).getEventList().get(1).getEventType()); - Assert.assertEquals(4, ((LaunchSourceEvent) ((ComposeEvent) events.get(0)).getEventList().get(1)).getIterationWindowId()); - - Assert.assertEquals(EventType.LAUNCH_SOURCE, events.get(1).getEventType()); - Assert.assertEquals(5, ((LaunchSourceEvent) events.get(1)).getIterationWindowId()); - - Assert.assertEquals(EventType.CLEAN_CYCLE, events.get(2).getEventType()); + Assert.assertEquals(4, events.size()); + Assert.assertEquals(EventType.COMPOSE, events.get(0).getEventType()); + Assert.assertEquals( + EventType.INIT_CYCLE, ((ComposeEvent) events.get(0)).getEventList().get(0).getEventType()); + Assert.assertEquals( + EventType.ROLLBACK, ((ComposeEvent) events.get(0)).getEventList().get(1).getEventType()); + Assert.assertEquals( + 2, + ((RollbackCycleEvent) ((ComposeEvent) events.get(0)).getEventList().get(1)) + .getIterationWindowId()); + Assert.assertEquals( + EventType.LAUNCH_SOURCE, + ((ComposeEvent) events.get(0)).getEventList().get(2).getEventType()); + Assert.assertEquals( + 3, + ((LaunchSourceEvent) ((ComposeEvent) events.get(0)).getEventList().get(2)) + .getIterationWindowId()); + + Assert.assertEquals(EventType.LAUNCH_SOURCE, events.get(1).getEventType()); + Assert.assertEquals(4, ((LaunchSourceEvent) events.get(1)).getIterationWindowId()); + + Assert.assertEquals(EventType.LAUNCH_SOURCE, events.get(2).getEventType()); + Assert.assertEquals(5, ((LaunchSourceEvent) events.get(2)).getIterationWindowId()); + + Assert.assertEquals(EventType.CLEAN_CYCLE, events.get(3).getEventType()); + } + + @Test(priority = 1) + private void testPipelineAfterRecover() { + + PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); + processor.register(scheduler); + + // mock recover context from previous case. + CheckpointSchedulerContext context = mockPersistContext; + context.init(4); + context.setRecovered(true); + + scheduler.init(context); + scheduler.execute(); + scheduler.close(); + + List events = processor.getProcessed(); + LOGGER.info("processed events {}", events.size()); + for (IEvent event : events) { + LOGGER.info("{}", event); } - - private ExecutionNodeCycle buildMockCycle(Configuration configuration) { - - long finishIterationId = 5; - ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); - vertexGroup.getCycleGroupMeta().setFlyingCount(1); - vertexGroup.getCycleGroupMeta().setIterationCount(finishIterationId); - vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); - ExecutionVertex vertex = new ExecutionVertex(0, "test"); - vertex.setParallelism(1); - vertexGroup.getVertexMap().put(0, vertex); - vertexGroup.putVertexId2InEdgeIds(0, new ArrayList<>()); - vertexGroup.putVertexId2OutEdgeIds(0, new ArrayList<>()); - - List headTasks = new ArrayList<>(); - List tailTasks = new ArrayList<>(); - for (int i = 0; i < vertex.getParallelism(); i++) { - ExecutionTask task = new ExecutionTask(i, i, vertex.getParallelism(), vertex.getParallelism(), vertex.getParallelism(), vertex.getVertexId()); - task.setWorkerInfo(new WorkerInfo("host0", 0, 0, 1, -1, 1, "container0")); - task.setExecutionTaskType(ExecutionTaskType.head); - tailTasks.add(task); - headTasks.add(task); - } - - ExecutionNodeCycle cycle = new ExecutionNodeCycle(0, 0, 0, "test", vertexGroup, - configuration, "driver_id", 0); - cycle.setCycleHeads(headTasks); - cycle.setCycleTails(tailTasks); - cycle.setTasks(headTasks); - return cycle; + Assert.assertEquals(3, events.size()); + Assert.assertEquals(EventType.COMPOSE, events.get(0).getEventType()); + Assert.assertEquals( + EventType.ROLLBACK, ((ComposeEvent) events.get(0)).getEventList().get(0).getEventType()); + Assert.assertEquals( + 3, + ((RollbackCycleEvent) ((ComposeEvent) events.get(0)).getEventList().get(0)) + .getIterationWindowId()); + Assert.assertEquals( + EventType.LAUNCH_SOURCE, + ((ComposeEvent) events.get(0)).getEventList().get(1).getEventType()); + Assert.assertEquals( + 4, + ((LaunchSourceEvent) ((ComposeEvent) events.get(0)).getEventList().get(1)) + .getIterationWindowId()); + + Assert.assertEquals(EventType.LAUNCH_SOURCE, events.get(1).getEventType()); + Assert.assertEquals(5, ((LaunchSourceEvent) events.get(1)).getIterationWindowId()); + + Assert.assertEquals(EventType.CLEAN_CYCLE, events.get(2).getEventType()); + } + + private ExecutionNodeCycle buildMockCycle(Configuration configuration) { + + long finishIterationId = 5; + ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); + vertexGroup.getCycleGroupMeta().setFlyingCount(1); + vertexGroup.getCycleGroupMeta().setIterationCount(finishIterationId); + vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); + ExecutionVertex vertex = new ExecutionVertex(0, "test"); + vertex.setParallelism(1); + vertexGroup.getVertexMap().put(0, vertex); + vertexGroup.putVertexId2InEdgeIds(0, new ArrayList<>()); + vertexGroup.putVertexId2OutEdgeIds(0, new ArrayList<>()); + + List headTasks = new ArrayList<>(); + List tailTasks = new ArrayList<>(); + for (int i = 0; i < vertex.getParallelism(); i++) { + ExecutionTask task = + new ExecutionTask( + i, + i, + vertex.getParallelism(), + vertex.getParallelism(), + vertex.getParallelism(), + vertex.getVertexId()); + task.setWorkerInfo(new WorkerInfo("host0", 0, 0, 1, -1, 1, "container0")); + task.setExecutionTaskType(ExecutionTaskType.head); + tailTasks.add(task); + headTasks.add(task); } + ExecutionNodeCycle cycle = + new ExecutionNodeCycle(0, 0, 0, "test", vertexGroup, configuration, "driver_id", 0); + cycle.setCycleHeads(headTasks); + cycle.setCycleTails(tailTasks); + cycle.setTasks(headTasks); + return cycle; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/context/AbstractCycleSchedulerContextTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/context/AbstractCycleSchedulerContextTest.java index 542c2c566..e823b331d 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/context/AbstractCycleSchedulerContextTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/context/AbstractCycleSchedulerContextTest.java @@ -34,127 +34,133 @@ public class AbstractCycleSchedulerContextTest extends BaseCycleSchedulerContextTest { - @Test - public void testFinishIterationIdFromRecover() { - long finishIterationId = 100; - ExecutionNodeCycle cycle = buildMockCycle(false); - cycle.getVertexGroup().getCycleGroupMeta().setIterationCount(finishIterationId); - CheckpointSchedulerContext context = new CheckpointSchedulerContext(cycle, null); - context.init(); - - long checkpointId = 20L; - context.checkpoint(checkpointId); - CheckpointSchedulerContext newContext = - (CheckpointSchedulerContext) CheckpointSchedulerContext.build(context.getCycle().getPipelineTaskId(), null); - - Assert.assertEquals(checkpointId + 1, newContext.getCurrentIterationId()); - Assert.assertEquals(finishIterationId, newContext.getFinishIterationId()); - } - - @Test - public void testInitContextAfterRecover() { - long finishIterationId = 100; - ExecutionNodeCycle cycle = buildMockCycle(false); - cycle.getVertexGroup().getCycleGroupMeta().setIterationCount(finishIterationId); - CheckpointSchedulerContext context = new CheckpointSchedulerContext(cycle, null); - context.init(); - - long checkpointId = 20L; - context.checkpoint(checkpointId); - CheckpointSchedulerContext newContext = - (CheckpointSchedulerContext) CheckpointSchedulerContext.build(context.getCycle().getPipelineTaskId(), null); - - long currentIterationId = checkpointId + 1; - context.init(currentIterationId); - - Assert.assertEquals(currentIterationId, newContext.getCurrentIterationId()); - Assert.assertEquals(currentIterationId, newContext.getInitialIterationId()); - Assert.assertEquals(finishIterationId, newContext.getFinishIterationId()); - - Assert.assertNotNull(context.getResultManager()); - Assert.assertNotNull(context.getSchedulerWorkerManager()); - } - - - @Test - public void testInitContextAfterRestart() { - long finishIterationId = 100; - ExecutionNodeCycle cycle = buildMockCycle(false); - cycle.getVertexGroup().getCycleGroupMeta().setIterationCount(finishIterationId); - CheckpointSchedulerContext context = new CheckpointSchedulerContext(cycle, null); - context.init(); - - // do checkpoint - long checkpointId = 20L; - context.checkpoint(checkpointId); - - // clean checkpoint cycle. - ClusterMetaStore.getInstance(0, "driver-0", new Configuration()).clean(); - CheckpointSchedulerContext newContext = - (CheckpointSchedulerContext) CheckpointSchedulerContext.build(context.getCycle().getPipelineTaskId(), + @Test + public void testFinishIterationIdFromRecover() { + long finishIterationId = 100; + ExecutionNodeCycle cycle = buildMockCycle(false); + cycle.getVertexGroup().getCycleGroupMeta().setIterationCount(finishIterationId); + CheckpointSchedulerContext context = new CheckpointSchedulerContext(cycle, null); + context.init(); + + long checkpointId = 20L; + context.checkpoint(checkpointId); + CheckpointSchedulerContext newContext = + (CheckpointSchedulerContext) + CheckpointSchedulerContext.build(context.getCycle().getPipelineTaskId(), null); + + Assert.assertEquals(checkpointId + 1, newContext.getCurrentIterationId()); + Assert.assertEquals(finishIterationId, newContext.getFinishIterationId()); + } + + @Test + public void testInitContextAfterRecover() { + long finishIterationId = 100; + ExecutionNodeCycle cycle = buildMockCycle(false); + cycle.getVertexGroup().getCycleGroupMeta().setIterationCount(finishIterationId); + CheckpointSchedulerContext context = new CheckpointSchedulerContext(cycle, null); + context.init(); + + long checkpointId = 20L; + context.checkpoint(checkpointId); + CheckpointSchedulerContext newContext = + (CheckpointSchedulerContext) + CheckpointSchedulerContext.build(context.getCycle().getPipelineTaskId(), null); + + long currentIterationId = checkpointId + 1; + context.init(currentIterationId); + + Assert.assertEquals(currentIterationId, newContext.getCurrentIterationId()); + Assert.assertEquals(currentIterationId, newContext.getInitialIterationId()); + Assert.assertEquals(finishIterationId, newContext.getFinishIterationId()); + + Assert.assertNotNull(context.getResultManager()); + Assert.assertNotNull(context.getSchedulerWorkerManager()); + } + + @Test + public void testInitContextAfterRestart() { + long finishIterationId = 100; + ExecutionNodeCycle cycle = buildMockCycle(false); + cycle.getVertexGroup().getCycleGroupMeta().setIterationCount(finishIterationId); + CheckpointSchedulerContext context = new CheckpointSchedulerContext(cycle, null); + context.init(); + + // do checkpoint + long checkpointId = 20L; + context.checkpoint(checkpointId); + + // clean checkpoint cycle. + ClusterMetaStore.getInstance(0, "driver-0", new Configuration()).clean(); + CheckpointSchedulerContext newContext = + (CheckpointSchedulerContext) + CheckpointSchedulerContext.build( + context.getCycle().getPipelineTaskId(), () -> CycleSchedulerContextFactory.create(cycle, null)); - long currentIterationId = checkpointId + 1; - context.init(currentIterationId); - - Assert.assertEquals(currentIterationId, newContext.getCurrentIterationId()); - Assert.assertEquals(currentIterationId, newContext.getInitialIterationId()); - Assert.assertEquals(finishIterationId, newContext.getFinishIterationId()); - - Assert.assertNotNull(context.getResultManager()); - Assert.assertNotNull(context.getSchedulerWorkerManager()); - } - - @Test - public void testCheckpointDuration() { - - ExecutionNodeCycle cycle = buildMockCycle(false); - CheckpointSchedulerContext context = new CheckpointSchedulerContext(cycle, null); - context.init(); - - // not do checkpoint at 17 - long checkpointId = 17L; - context.checkpoint(checkpointId); - Assert.assertNull(ClusterMetaStore.getInstance().getWindowId(context.getCycle().getPipelineTaskId())); - CheckpointSchedulerContext newContext = - (CheckpointSchedulerContext) CheckpointSchedulerContext.build(context.getCycle().getPipelineTaskId(), null); - Assert.assertNotNull(newContext); - Assert.assertEquals(1, newContext.getCurrentIterationId()); - - checkpointId = 20L; - context.checkpoint(checkpointId); - newContext = - (CheckpointSchedulerContext) CheckpointSchedulerContext.build(context.getCycle().getPipelineTaskId(), null); - Assert.assertEquals(checkpointId + 1, newContext.getCurrentIterationId()); - - - // loaded is still previous checkpointId - long newCheckpointId = 23L; - context.checkpoint(newCheckpointId); - newContext = - (CheckpointSchedulerContext) CheckpointSchedulerContext.build(context.getCycle().getPipelineTaskId(), null); - Assert.assertEquals(checkpointId + 1, newContext.getCurrentIterationId()); + long currentIterationId = checkpointId + 1; + context.init(currentIterationId); + + Assert.assertEquals(currentIterationId, newContext.getCurrentIterationId()); + Assert.assertEquals(currentIterationId, newContext.getInitialIterationId()); + Assert.assertEquals(finishIterationId, newContext.getFinishIterationId()); + + Assert.assertNotNull(context.getResultManager()); + Assert.assertNotNull(context.getSchedulerWorkerManager()); + } + + @Test + public void testCheckpointDuration() { + + ExecutionNodeCycle cycle = buildMockCycle(false); + CheckpointSchedulerContext context = new CheckpointSchedulerContext(cycle, null); + context.init(); + + // not do checkpoint at 17 + long checkpointId = 17L; + context.checkpoint(checkpointId); + Assert.assertNull( + ClusterMetaStore.getInstance().getWindowId(context.getCycle().getPipelineTaskId())); + CheckpointSchedulerContext newContext = + (CheckpointSchedulerContext) + CheckpointSchedulerContext.build(context.getCycle().getPipelineTaskId(), null); + Assert.assertNotNull(newContext); + Assert.assertEquals(1, newContext.getCurrentIterationId()); + + checkpointId = 20L; + context.checkpoint(checkpointId); + newContext = + (CheckpointSchedulerContext) + CheckpointSchedulerContext.build(context.getCycle().getPipelineTaskId(), null); + Assert.assertEquals(checkpointId + 1, newContext.getCurrentIterationId()); + + // loaded is still previous checkpointId + long newCheckpointId = 23L; + context.checkpoint(newCheckpointId); + newContext = + (CheckpointSchedulerContext) + CheckpointSchedulerContext.build(context.getCycle().getPipelineTaskId(), null); + Assert.assertEquals(checkpointId + 1, newContext.getCurrentIterationId()); + } + + private ExecutionNodeCycle buildMockCycle(boolean isIterative) { + Configuration configuration = new Configuration(); + configuration.put(JOB_UNIQUE_ID, "test-scheduler-context"); + configuration.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); + ClusterMetaStore.init(0, "driver-0", configuration); + + long finishIterationId = 100; + ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); + vertexGroup.getCycleGroupMeta().setFlyingCount(1); + vertexGroup.getCycleGroupMeta().setIterationCount(finishIterationId); + if (isIterative) { + vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.incremental); + } else { + vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); } + ExecutionVertex vertex = new ExecutionVertex(0, "test"); + vertex.setParallelism(2); + vertexGroup.getVertexMap().put(0, vertex); - private ExecutionNodeCycle buildMockCycle(boolean isIterative) { - Configuration configuration = new Configuration(); - configuration.put(JOB_UNIQUE_ID, "test-scheduler-context"); - configuration.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); - ClusterMetaStore.init(0, "driver-0", configuration); - - long finishIterationId = 100; - ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); - vertexGroup.getCycleGroupMeta().setFlyingCount(1); - vertexGroup.getCycleGroupMeta().setIterationCount(finishIterationId); - if (isIterative) { - vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.incremental); - } else { - vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); - } - ExecutionVertex vertex = new ExecutionVertex(0, "test"); - vertex.setParallelism(2); - vertexGroup.getVertexMap().put(0, vertex); - - return new ExecutionNodeCycle(0, 0, 0, "test", vertexGroup, configuration, "driver_id", 0); - } + return new ExecutionNodeCycle(0, 0, 0, "test", vertexGroup, configuration, "driver_id", 0); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/context/BaseCycleSchedulerContextTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/context/BaseCycleSchedulerContextTest.java index 2b64fc857..0e233556a 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/context/BaseCycleSchedulerContextTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/context/BaseCycleSchedulerContextTest.java @@ -23,6 +23,7 @@ import java.io.File; import java.util.List; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.cluster.resourcemanager.WorkerInfo; import org.apache.geaflow.cluster.system.ClusterMetaStore; @@ -40,61 +41,57 @@ public class BaseCycleSchedulerContextTest { - protected static MockedStatic mockWorkerManager; - protected Configuration configuration = new Configuration(); - - @BeforeMethod - public void setUp() { - ClusterMetaStore.close(); - mockWorkerManager = mockScheduledWorkerManager(); - - String path = "/tmp/" + this.getClass().getSimpleName(); - FileUtils.deleteQuietly(new File(path)); - configuration.getConfigMap().clear(); - configuration.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), this.getClass().getSimpleName()); - configuration.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); - configuration.put(FileConfigKeys.ROOT.getKey(), path); - configuration.put(ExecutionConfigKeys.JOB_UNIQUE_ID, "test-scheduler-context"); - } - - @AfterMethod - public void cleanUp() { - mockWorkerManager.close(); - String path = "/tmp/" + this.getClass().getSimpleName(); - FileUtils.deleteQuietly(new File(path)); - ClusterMetaStore.close(); - } + protected static MockedStatic mockWorkerManager; + protected Configuration configuration = new Configuration(); - public static MockedStatic mockScheduledWorkerManager() { - MockedStatic ms = Mockito.mockStatic(ScheduledWorkerManagerFactory.class); - ms.when(() -> - ScheduledWorkerManagerFactory.createScheduledWorkerManager(any(), any())) - .then(invocation -> new MockScheduledWorkerManager()); - return ms; - } + @BeforeMethod + public void setUp() { + ClusterMetaStore.close(); + mockWorkerManager = mockScheduledWorkerManager(); - protected static class MockScheduledWorkerManager - implements IScheduledWorkerManager { + String path = "/tmp/" + this.getClass().getSimpleName(); + FileUtils.deleteQuietly(new File(path)); + configuration.getConfigMap().clear(); + configuration.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), this.getClass().getSimpleName()); + configuration.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); + configuration.put(FileConfigKeys.ROOT.getKey(), path); + configuration.put(ExecutionConfigKeys.JOB_UNIQUE_ID, "test-scheduler-context"); + } - @Override - public void init(ExecutionGraphCycle graph) { - } + @AfterMethod + public void cleanUp() { + mockWorkerManager.close(); + String path = "/tmp/" + this.getClass().getSimpleName(); + FileUtils.deleteQuietly(new File(path)); + ClusterMetaStore.close(); + } - @Override - public List assign(ExecutionGraphCycle vertex) { - return null; - } + public static MockedStatic mockScheduledWorkerManager() { + MockedStatic ms = + Mockito.mockStatic(ScheduledWorkerManagerFactory.class); + ms.when(() -> ScheduledWorkerManagerFactory.createScheduledWorkerManager(any(), any())) + .then(invocation -> new MockScheduledWorkerManager()); + return ms; + } - @Override - public void release(ExecutionGraphCycle vertex) { - } + protected static class MockScheduledWorkerManager + implements IScheduledWorkerManager { - @Override - public void clean(CleanWorkerFunction cleaFunc, IExecutionCycle cycle) { - } + @Override + public void init(ExecutionGraphCycle graph) {} - @Override - public void close(IExecutionCycle cycle) { - } + @Override + public List assign(ExecutionGraphCycle vertex) { + return null; } + + @Override + public void release(ExecutionGraphCycle vertex) {} + + @Override + public void clean(CleanWorkerFunction cleaFunc, IExecutionCycle cycle) {} + + @Override + public void close(IExecutionCycle cycle) {} + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/context/CheckpointSchedulerContextTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/context/CheckpointSchedulerContextTest.java index 04547e2d3..e9c787fee 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/context/CheckpointSchedulerContextTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/context/CheckpointSchedulerContextTest.java @@ -30,93 +30,99 @@ public class CheckpointSchedulerContextTest extends BaseCycleSchedulerContextTest { - @Test - public void testNewContext() { - long finishIterationId = 100; - String testName = "testName"; - ExecutionNodeCycle cycle = buildMockCycle(false); - cycle.setName(testName); - cycle.getVertexGroup().getCycleGroupMeta().setIterationCount(finishIterationId); - - CheckpointSchedulerContext context = (CheckpointSchedulerContext) CycleSchedulerContextFactory.create(cycle, null); - - Assert.assertEquals(1, context.getCurrentIterationId()); - Assert.assertEquals(finishIterationId, context.getFinishIterationId()); - - } - - @Test - public void testRestartContext() { - long finishIterationId = 100; - String testName = "testName"; - ExecutionNodeCycle cycle = buildMockCycle(false); - cycle.setName(testName); - cycle.getVertexGroup().getCycleGroupMeta().setIterationCount(finishIterationId); - CheckpointSchedulerContext context = (CheckpointSchedulerContext) CycleSchedulerContextFactory.create(cycle, null); - - ClusterMetaStore.close(); - configuration.put(ExecutionConfigKeys.CLUSTER_ID, "test1"); - ClusterMetaStore.init(0, "driver-0", configuration); - - CheckpointSchedulerContext loaded = - (CheckpointSchedulerContext) CheckpointSchedulerContext.build(context.getCycle().getPipelineTaskId(), () -> context); - Assert.assertEquals(1, loaded.getCurrentIterationId()); - - long checkpointId = 10; - loaded.checkpoint(checkpointId); - Assert.assertNotNull(ClusterMetaStore.getInstance().getCycle(context.getCycle().getPipelineTaskId())); - - // Mock restart job. - ClusterMetaStore.close(); - configuration.put(ExecutionConfigKeys.CLUSTER_ID, "test2"); - ClusterMetaStore.init(0, "driver-0", configuration); - - CheckpointSchedulerContext loaded2 = - (CheckpointSchedulerContext) CheckpointSchedulerContext.build(context.getCycle().getPipelineTaskId(), () -> - CycleSchedulerContextFactory.create(cycle, null)); - Assert.assertEquals(checkpointId + 1, loaded2.getCurrentIterationId()); - } - - @Test - public void testFailoverRecover() { - configuration.put(ExecutionConfigKeys.CLUSTER_ID, "test1"); - long finishIterationId = 100; - String testName = "testName"; - ExecutionNodeCycle cycle = buildMockCycle(false); - cycle.setName(testName); - cycle.getVertexGroup().getCycleGroupMeta().setIterationCount(finishIterationId); - - CheckpointSchedulerContext context = (CheckpointSchedulerContext) CycleSchedulerContextFactory.create(cycle, null); - long checkpointId = 10; - context.checkpoint(checkpointId); - - ClusterMetaStore.close(); - ClusterMetaStore.init(0, "driver-0", configuration); - - CheckpointSchedulerContext loaded = - (CheckpointSchedulerContext) CheckpointSchedulerContext.build(context.getCycle().getPipelineTaskId(), () -> context); - - Assert.assertEquals(checkpointId + 1, loaded.getCurrentIterationId()); - Assert.assertEquals(finishIterationId, context.getFinishIterationId()); - } - - private ExecutionNodeCycle buildMockCycle(boolean isIterative) { - ClusterMetaStore.init(0, "driver-0", configuration); - - long finishIterationId = 100; - ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); - vertexGroup.getCycleGroupMeta().setFlyingCount(1); - vertexGroup.getCycleGroupMeta().setIterationCount(finishIterationId); - if (isIterative) { - vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.incremental); - } else { - vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); - } - ExecutionVertex vertex = new ExecutionVertex(0, "test"); - vertex.setParallelism(2); - vertexGroup.getVertexMap().put(0, vertex); - - return new ExecutionNodeCycle(0, 0, 0, "test", vertexGroup, configuration, "driver_id", 0); + @Test + public void testNewContext() { + long finishIterationId = 100; + String testName = "testName"; + ExecutionNodeCycle cycle = buildMockCycle(false); + cycle.setName(testName); + cycle.getVertexGroup().getCycleGroupMeta().setIterationCount(finishIterationId); + + CheckpointSchedulerContext context = + (CheckpointSchedulerContext) CycleSchedulerContextFactory.create(cycle, null); + + Assert.assertEquals(1, context.getCurrentIterationId()); + Assert.assertEquals(finishIterationId, context.getFinishIterationId()); + } + + @Test + public void testRestartContext() { + long finishIterationId = 100; + String testName = "testName"; + ExecutionNodeCycle cycle = buildMockCycle(false); + cycle.setName(testName); + cycle.getVertexGroup().getCycleGroupMeta().setIterationCount(finishIterationId); + CheckpointSchedulerContext context = + (CheckpointSchedulerContext) CycleSchedulerContextFactory.create(cycle, null); + + ClusterMetaStore.close(); + configuration.put(ExecutionConfigKeys.CLUSTER_ID, "test1"); + ClusterMetaStore.init(0, "driver-0", configuration); + + CheckpointSchedulerContext loaded = + (CheckpointSchedulerContext) + CheckpointSchedulerContext.build(context.getCycle().getPipelineTaskId(), () -> context); + Assert.assertEquals(1, loaded.getCurrentIterationId()); + + long checkpointId = 10; + loaded.checkpoint(checkpointId); + Assert.assertNotNull( + ClusterMetaStore.getInstance().getCycle(context.getCycle().getPipelineTaskId())); + + // Mock restart job. + ClusterMetaStore.close(); + configuration.put(ExecutionConfigKeys.CLUSTER_ID, "test2"); + ClusterMetaStore.init(0, "driver-0", configuration); + + CheckpointSchedulerContext loaded2 = + (CheckpointSchedulerContext) + CheckpointSchedulerContext.build( + context.getCycle().getPipelineTaskId(), + () -> CycleSchedulerContextFactory.create(cycle, null)); + Assert.assertEquals(checkpointId + 1, loaded2.getCurrentIterationId()); + } + + @Test + public void testFailoverRecover() { + configuration.put(ExecutionConfigKeys.CLUSTER_ID, "test1"); + long finishIterationId = 100; + String testName = "testName"; + ExecutionNodeCycle cycle = buildMockCycle(false); + cycle.setName(testName); + cycle.getVertexGroup().getCycleGroupMeta().setIterationCount(finishIterationId); + + CheckpointSchedulerContext context = + (CheckpointSchedulerContext) CycleSchedulerContextFactory.create(cycle, null); + long checkpointId = 10; + context.checkpoint(checkpointId); + + ClusterMetaStore.close(); + ClusterMetaStore.init(0, "driver-0", configuration); + + CheckpointSchedulerContext loaded = + (CheckpointSchedulerContext) + CheckpointSchedulerContext.build(context.getCycle().getPipelineTaskId(), () -> context); + + Assert.assertEquals(checkpointId + 1, loaded.getCurrentIterationId()); + Assert.assertEquals(finishIterationId, context.getFinishIterationId()); + } + + private ExecutionNodeCycle buildMockCycle(boolean isIterative) { + ClusterMetaStore.init(0, "driver-0", configuration); + + long finishIterationId = 100; + ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); + vertexGroup.getCycleGroupMeta().setFlyingCount(1); + vertexGroup.getCycleGroupMeta().setIterationCount(finishIterationId); + if (isIterative) { + vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.incremental); + } else { + vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); } + ExecutionVertex vertex = new ExecutionVertex(0, "test"); + vertex.setParallelism(2); + vertexGroup.getVertexMap().put(0, vertex); + return new ExecutionNodeCycle(0, 0, 0, "test", vertexGroup, configuration, "driver_id", 0); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/context/ExecutionGraphCycleSchedulerContextTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/context/ExecutionGraphCycleSchedulerContextTest.java index 4e4072678..7a45d285b 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/context/ExecutionGraphCycleSchedulerContextTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/context/ExecutionGraphCycleSchedulerContextTest.java @@ -31,69 +31,71 @@ public class ExecutionGraphCycleSchedulerContextTest extends BaseCycleSchedulerContextTest { - @Test - public void testExecutionGraphWithCheckpointChileCycle() { - long finishIterationId = 100; - - ExecutionGraphCycle graph = buildMockExecutionGraphCycle(); - RedoSchedulerContext parentContext = new RedoSchedulerContext(graph, null); - parentContext.init(1); - - ExecutionNodeCycle iterationCycle = buildPipelineCycle(false, finishIterationId); - graph.addCycle(iterationCycle, false); - CheckpointSchedulerContext iterationContext = new CheckpointSchedulerContext(iterationCycle, parentContext); - iterationContext.init(); - - Assert.assertEquals(iterationContext.getCurrentIterationId(), 1); - Assert.assertEquals(iterationContext.getFinishIterationId(), 100); - - configuration.put(ExecutionConfigKeys.CLUSTER_ID, "test1"); - - // Recover case. - CheckpointSchedulerContext loadedContext1 = - (CheckpointSchedulerContext) CycleSchedulerContextFactory.create(iterationCycle, parentContext); - Assert.assertEquals(1, loadedContext1.getCurrentIterationId()); - - loadedContext1.checkpoint(20); - - // Recover from checkpoint. - CheckpointSchedulerContext loadedContext2 = - (CheckpointSchedulerContext) CycleSchedulerContextFactory.create(iterationCycle, parentContext); - Assert.assertEquals(21, loadedContext2.getCurrentIterationId()); + @Test + public void testExecutionGraphWithCheckpointChileCycle() { + long finishIterationId = 100; + + ExecutionGraphCycle graph = buildMockExecutionGraphCycle(); + RedoSchedulerContext parentContext = new RedoSchedulerContext(graph, null); + parentContext.init(1); + + ExecutionNodeCycle iterationCycle = buildPipelineCycle(false, finishIterationId); + graph.addCycle(iterationCycle, false); + CheckpointSchedulerContext iterationContext = + new CheckpointSchedulerContext(iterationCycle, parentContext); + iterationContext.init(); + + Assert.assertEquals(iterationContext.getCurrentIterationId(), 1); + Assert.assertEquals(iterationContext.getFinishIterationId(), 100); + + configuration.put(ExecutionConfigKeys.CLUSTER_ID, "test1"); + + // Recover case. + CheckpointSchedulerContext loadedContext1 = + (CheckpointSchedulerContext) + CycleSchedulerContextFactory.create(iterationCycle, parentContext); + Assert.assertEquals(1, loadedContext1.getCurrentIterationId()); + + loadedContext1.checkpoint(20); + + // Recover from checkpoint. + CheckpointSchedulerContext loadedContext2 = + (CheckpointSchedulerContext) + CycleSchedulerContextFactory.create(iterationCycle, parentContext); + Assert.assertEquals(21, loadedContext2.getCurrentIterationId()); + } + + private ExecutionGraphCycle buildMockExecutionGraphCycle() { + ClusterMetaStore.init(0, "driver-0", configuration); + + long finishIterationId = 100; + ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); + vertexGroup.getCycleGroupMeta().setFlyingCount(1); + vertexGroup.getCycleGroupMeta().setIterationCount(finishIterationId); + vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); + ExecutionVertex vertex = new ExecutionVertex(0, "test"); + vertex.setParallelism(2); + vertexGroup.getVertexMap().put(0, vertex); + + return new ExecutionGraphCycle( + 0, 0, 0, "test", 0, 1, finishIterationId, configuration, "driver_id", 0); + } + + protected ExecutionNodeCycle buildPipelineCycle(boolean isIterative, long iterationCount) { + ClusterMetaStore.init(0, "driver-0", configuration); + + ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); + vertexGroup.getCycleGroupMeta().setFlyingCount(1); + vertexGroup.getCycleGroupMeta().setIterationCount(iterationCount); + if (isIterative) { + vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.incremental); + } else { + vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); } + ExecutionVertex vertex = new ExecutionVertex(1, "test"); + vertex.setParallelism(2); + vertexGroup.getVertexMap().put(0, vertex); - private ExecutionGraphCycle buildMockExecutionGraphCycle() { - ClusterMetaStore.init(0, "driver-0", configuration); - - long finishIterationId = 100; - ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); - vertexGroup.getCycleGroupMeta().setFlyingCount(1); - vertexGroup.getCycleGroupMeta().setIterationCount(finishIterationId); - vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); - ExecutionVertex vertex = new ExecutionVertex(0, "test"); - vertex.setParallelism(2); - vertexGroup.getVertexMap().put(0, vertex); - - return new ExecutionGraphCycle(0, 0, 0, "test", 0, - 1, finishIterationId, - configuration, "driver_id", 0); - } - - protected ExecutionNodeCycle buildPipelineCycle(boolean isIterative, long iterationCount) { - ClusterMetaStore.init(0, "driver-0", configuration); - - ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); - vertexGroup.getCycleGroupMeta().setFlyingCount(1); - vertexGroup.getCycleGroupMeta().setIterationCount(iterationCount); - if (isIterative) { - vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.incremental); - } else { - vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); - } - ExecutionVertex vertex = new ExecutionVertex(1, "test"); - vertex.setParallelism(2); - vertexGroup.getVertexMap().put(0, vertex); - - return new ExecutionNodeCycle(0, 0, 0, "test", vertexGroup, configuration, "driver_id", 0); - } + return new ExecutionNodeCycle(0, 0, 0, "test", vertexGroup, configuration, "driver_id", 0); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/context/IterationRedoSchedulerContextTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/context/IterationRedoSchedulerContextTest.java index 6dc495b4f..cd80c853a 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/context/IterationRedoSchedulerContextTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/context/IterationRedoSchedulerContextTest.java @@ -29,38 +29,38 @@ public class IterationRedoSchedulerContextTest extends BaseCycleSchedulerContextTest { - @Test - public void testIterationFinishId() { - long finishIterationId = 100; - ExecutionNodeCycle cycle = buildMockCycle(false, finishIterationId); - cycle.getVertexGroup().getCycleGroupMeta().setIterationCount(finishIterationId); - CheckpointSchedulerContext parentContext = new CheckpointSchedulerContext(cycle, null); - parentContext.init(50); + @Test + public void testIterationFinishId() { + long finishIterationId = 100; + ExecutionNodeCycle cycle = buildMockCycle(false, finishIterationId); + cycle.getVertexGroup().getCycleGroupMeta().setIterationCount(finishIterationId); + CheckpointSchedulerContext parentContext = new CheckpointSchedulerContext(cycle, null); + parentContext.init(50); - ExecutionNodeCycle iterationCycle = buildMockCycle(false, 5); - IterationRedoSchedulerContext iterationContext = new IterationRedoSchedulerContext(iterationCycle, parentContext); - iterationContext.init(20); + ExecutionNodeCycle iterationCycle = buildMockCycle(false, 5); + IterationRedoSchedulerContext iterationContext = + new IterationRedoSchedulerContext(iterationCycle, parentContext); + iterationContext.init(20); - Assert.assertEquals(iterationContext.getCurrentIterationId(), 1); - Assert.assertEquals(iterationContext.getFinishIterationId(), 5); - } - - protected ExecutionNodeCycle buildMockCycle(boolean isIterative, long iterationCount) { - ClusterMetaStore.init(0, "driver-0", configuration); + Assert.assertEquals(iterationContext.getCurrentIterationId(), 1); + Assert.assertEquals(iterationContext.getFinishIterationId(), 5); + } - ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); - vertexGroup.getCycleGroupMeta().setFlyingCount(1); - vertexGroup.getCycleGroupMeta().setIterationCount(iterationCount); - if (isIterative) { - vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.incremental); - } else { - vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); - } - ExecutionVertex vertex = new ExecutionVertex(0, "test"); - vertex.setParallelism(2); - vertexGroup.getVertexMap().put(0, vertex); + protected ExecutionNodeCycle buildMockCycle(boolean isIterative, long iterationCount) { + ClusterMetaStore.init(0, "driver-0", configuration); - return new ExecutionNodeCycle(0, 0, 0, "test", vertexGroup, configuration, "driver_id", 0); + ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); + vertexGroup.getCycleGroupMeta().setFlyingCount(1); + vertexGroup.getCycleGroupMeta().setIterationCount(iterationCount); + if (isIterative) { + vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.incremental); + } else { + vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); } + ExecutionVertex vertex = new ExecutionVertex(0, "test"); + vertex.setParallelism(2); + vertexGroup.getVertexMap().put(0, vertex); + return new ExecutionNodeCycle(0, 0, 0, "test", vertexGroup, configuration, "driver_id", 0); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/cycle/ExecutionCycleBuilderTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/cycle/ExecutionCycleBuilderTest.java index d3b739388..75593aaf4 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/cycle/ExecutionCycleBuilderTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/cycle/ExecutionCycleBuilderTest.java @@ -27,58 +27,64 @@ public class ExecutionCycleBuilderTest { - @Test(expectedExceptions = IllegalArgumentException.class, - expectedExceptionsMessageRegExp = "cycle flyingCount should be positive, current value.*") - public void validateGraphCycleFlyingCount() { - ExecutionGraph executionGraph = new ExecutionGraph(); - CycleGroupMeta meta = executionGraph.getCycleGroupMeta(); - meta.setFlyingCount(0); - meta.setGroupType(CycleGroupType.windowed); + @Test( + expectedExceptions = IllegalArgumentException.class, + expectedExceptionsMessageRegExp = "cycle flyingCount should be positive, current value.*") + public void validateGraphCycleFlyingCount() { + ExecutionGraph executionGraph = new ExecutionGraph(); + CycleGroupMeta meta = executionGraph.getCycleGroupMeta(); + meta.setFlyingCount(0); + meta.setGroupType(CycleGroupType.windowed); - ExecutionCycleBuilder.buildExecutionCycle(executionGraph, null, null, 0, 0, null, 0, null, 0, false); - } + ExecutionCycleBuilder.buildExecutionCycle( + executionGraph, null, null, 0, 0, null, 0, null, 0, false); + } - @Test(expectedExceptions = IllegalArgumentException.class, - expectedExceptionsMessageRegExp = "cycle iterationCount should be positive, current value.*") - public void validateGraphCycleIterationCount() { - ExecutionGraph executionGraph = new ExecutionGraph(); - CycleGroupMeta meta = executionGraph.getCycleGroupMeta(); - meta.setIterationCount(0); - meta.setGroupType(CycleGroupType.windowed); - ExecutionCycleBuilder.buildExecutionCycle(executionGraph, null, null, 0, 0, null, 0, null, 0, false); - } + @Test( + expectedExceptions = IllegalArgumentException.class, + expectedExceptionsMessageRegExp = "cycle iterationCount should be positive, current value.*") + public void validateGraphCycleIterationCount() { + ExecutionGraph executionGraph = new ExecutionGraph(); + CycleGroupMeta meta = executionGraph.getCycleGroupMeta(); + meta.setIterationCount(0); + meta.setGroupType(CycleGroupType.windowed); + ExecutionCycleBuilder.buildExecutionCycle( + executionGraph, null, null, 0, 0, null, 0, null, 0, false); + } - @Test(expectedExceptions = IllegalArgumentException.class, - expectedExceptionsMessageRegExp = "cycle flyingCount should be positive, current value.*") - public void validateNodeCycleFlyingCount() { - ExecutionGraph executionGraph = new ExecutionGraph(); - CycleGroupMeta meta = executionGraph.getCycleGroupMeta(); - meta.setIterationCount(1); - meta.setFlyingCount(1); + @Test( + expectedExceptions = IllegalArgumentException.class, + expectedExceptionsMessageRegExp = "cycle flyingCount should be positive, current value.*") + public void validateNodeCycleFlyingCount() { + ExecutionGraph executionGraph = new ExecutionGraph(); + CycleGroupMeta meta = executionGraph.getCycleGroupMeta(); + meta.setIterationCount(1); + meta.setFlyingCount(1); - ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); - vertexGroup.getCycleGroupMeta().setFlyingCount(0); - vertexGroup.getCycleGroupMeta().setIterationCount(1); - executionGraph.getVertexGroupMap().put(1, vertexGroup); + ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); + vertexGroup.getCycleGroupMeta().setFlyingCount(0); + vertexGroup.getCycleGroupMeta().setIterationCount(1); + executionGraph.getVertexGroupMap().put(1, vertexGroup); - ExecutionCycleBuilder.buildExecutionCycle(executionGraph, null, null, 0, 0, null, 0, null, - 0, false); - } + ExecutionCycleBuilder.buildExecutionCycle( + executionGraph, null, null, 0, 0, null, 0, null, 0, false); + } - @Test(expectedExceptions = IllegalArgumentException.class, - expectedExceptionsMessageRegExp = "cycle iterationCount should be positive, current value.*") - public void validateNodeCycleIterationCount() { - ExecutionGraph executionGraph = new ExecutionGraph(); - CycleGroupMeta meta = executionGraph.getCycleGroupMeta(); - meta.setIterationCount(1); - meta.setFlyingCount(1); + @Test( + expectedExceptions = IllegalArgumentException.class, + expectedExceptionsMessageRegExp = "cycle iterationCount should be positive, current value.*") + public void validateNodeCycleIterationCount() { + ExecutionGraph executionGraph = new ExecutionGraph(); + CycleGroupMeta meta = executionGraph.getCycleGroupMeta(); + meta.setIterationCount(1); + meta.setFlyingCount(1); - ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); - vertexGroup.getCycleGroupMeta().setFlyingCount(1); - vertexGroup.getCycleGroupMeta().setIterationCount(0); - executionGraph.getVertexGroupMap().put(1, vertexGroup); + ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); + vertexGroup.getCycleGroupMeta().setFlyingCount(1); + vertexGroup.getCycleGroupMeta().setIterationCount(0); + executionGraph.getVertexGroupMap().put(1, vertexGroup); - ExecutionCycleBuilder.buildExecutionCycle(executionGraph, null, null, 0, 0, null, 0, null, - 0, false); - } + ExecutionCycleBuilder.buildExecutionCycle( + executionGraph, null, null, 0, 0, null, 0, null, 0, false); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/io/DataExchangerTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/io/DataExchangerTest.java index 938920752..6dca66987 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/io/DataExchangerTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/io/DataExchangerTest.java @@ -23,7 +23,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; -import junit.framework.TestCase; + import org.apache.geaflow.cluster.response.ResponseResult; import org.apache.geaflow.cluster.response.ShardResult; import org.apache.geaflow.core.graph.ExecutionEdge; @@ -34,82 +34,82 @@ import org.apache.geaflow.shuffle.message.Shard; import org.testng.Assert; -public class DataExchangerTest extends TestCase { +import junit.framework.TestCase; - public void testBuildInput() { - - ExecutionVertex vertex = new ExecutionVertex(1, "test"); - int parallelism = 3; - vertex.setParallelism(parallelism); - int edgeId = 0; - ExecutionEdge edge = new ExecutionEdge(null, edgeId, null, 0, 1, null); - vertex.setInputEdges(Arrays.asList(edge)); - - CycleResultManager resultManager = new CycleResultManager(); - ShardResult shards1 = new ShardResult(edgeId, OutputType.FORWARD, buildSlices(parallelism)); - resultManager.register(0, shards1); - ShardResult shards2 = new ShardResult(edgeId, OutputType.FORWARD, buildSlices(parallelism)); - resultManager.register(0, shards2); - - Map> result = DataExchanger.buildInput(vertex, edge, resultManager); - Assert.assertEquals(parallelism, result.size()); - for (int i = 0; i < parallelism; i++) { - Assert.assertEquals(result.get(i).size(), 1); - Shard shard = result.get(i).get(0); - Assert.assertEquals(shard.getSlices().size(), 2); - for (ISliceMeta sliceMeta : shard.getSlices()) { - Assert.assertEquals(i, sliceMeta.getSourceIndex()); - Assert.assertEquals(i, sliceMeta.getTargetIndex()); - } - } - } +public class DataExchangerTest extends TestCase { - public void testBuildInputWithResponse() { - - ExecutionVertex vertex = new ExecutionVertex(1, "test"); - int parallelism = 3; - vertex.setParallelism(parallelism); - int edgeId = 0; - ExecutionEdge edge = new ExecutionEdge(null, edgeId, null, 0, 1, null); - - CycleResultManager resultManager = new CycleResultManager(); - ShardResult shards1 = new ShardResult(edgeId, OutputType.FORWARD, buildSlices(parallelism)); - resultManager.register(0, shards1); - - ShardResult shards2 = new ShardResult(edgeId, OutputType.FORWARD, buildSlices(parallelism)); - resultManager.register(0, shards2); - - ResponseResult response1 = new ResponseResult(edgeId, OutputType.RESPONSE, buildResponse()); - resultManager.register(1, response1); - - ResponseResult response2 = new ResponseResult(edgeId, OutputType.RESPONSE, buildResponse()); - resultManager.register(1, response2); - - // build input not include response. - Map> result = DataExchanger.buildInput(vertex, edge, resultManager); - Assert.assertEquals(parallelism, result.size()); - for (int i = 0; i < parallelism; i++) { - Assert.assertEquals(result.get(i).size(), 1); - Shard shard = result.get(i).get(0); - Assert.assertEquals(shard.getSlices().size(), 2); - for (ISliceMeta sliceMeta : shard.getSlices()) { - Assert.assertEquals(i, sliceMeta.getSourceIndex()); - Assert.assertEquals(i, sliceMeta.getTargetIndex()); - } - } + public void testBuildInput() { + + ExecutionVertex vertex = new ExecutionVertex(1, "test"); + int parallelism = 3; + vertex.setParallelism(parallelism); + int edgeId = 0; + ExecutionEdge edge = new ExecutionEdge(null, edgeId, null, 0, 1, null); + vertex.setInputEdges(Arrays.asList(edge)); + + CycleResultManager resultManager = new CycleResultManager(); + ShardResult shards1 = new ShardResult(edgeId, OutputType.FORWARD, buildSlices(parallelism)); + resultManager.register(0, shards1); + ShardResult shards2 = new ShardResult(edgeId, OutputType.FORWARD, buildSlices(parallelism)); + resultManager.register(0, shards2); + + Map> result = DataExchanger.buildInput(vertex, edge, resultManager); + Assert.assertEquals(parallelism, result.size()); + for (int i = 0; i < parallelism; i++) { + Assert.assertEquals(result.get(i).size(), 1); + Shard shard = result.get(i).get(0); + Assert.assertEquals(shard.getSlices().size(), 2); + for (ISliceMeta sliceMeta : shard.getSlices()) { + Assert.assertEquals(i, sliceMeta.getSourceIndex()); + Assert.assertEquals(i, sliceMeta.getTargetIndex()); + } } - - - private List buildSlices(int parallelism) { - List slices = new ArrayList<>(); - for (int i = 0; i < parallelism; i++) { - slices.add(new PipelineSliceMeta(i, i, 0, 0, null)); - } - return slices; + } + + public void testBuildInputWithResponse() { + + ExecutionVertex vertex = new ExecutionVertex(1, "test"); + int parallelism = 3; + vertex.setParallelism(parallelism); + int edgeId = 0; + ExecutionEdge edge = new ExecutionEdge(null, edgeId, null, 0, 1, null); + + CycleResultManager resultManager = new CycleResultManager(); + ShardResult shards1 = new ShardResult(edgeId, OutputType.FORWARD, buildSlices(parallelism)); + resultManager.register(0, shards1); + + ShardResult shards2 = new ShardResult(edgeId, OutputType.FORWARD, buildSlices(parallelism)); + resultManager.register(0, shards2); + + ResponseResult response1 = new ResponseResult(edgeId, OutputType.RESPONSE, buildResponse()); + resultManager.register(1, response1); + + ResponseResult response2 = new ResponseResult(edgeId, OutputType.RESPONSE, buildResponse()); + resultManager.register(1, response2); + + // build input not include response. + Map> result = DataExchanger.buildInput(vertex, edge, resultManager); + Assert.assertEquals(parallelism, result.size()); + for (int i = 0; i < parallelism; i++) { + Assert.assertEquals(result.get(i).size(), 1); + Shard shard = result.get(i).get(0); + Assert.assertEquals(shard.getSlices().size(), 2); + for (ISliceMeta sliceMeta : shard.getSlices()) { + Assert.assertEquals(i, sliceMeta.getSourceIndex()); + Assert.assertEquals(i, sliceMeta.getTargetIndex()); + } } + } - private List buildResponse() { - return Arrays.asList(0); + private List buildSlices(int parallelism) { + List slices = new ArrayList<>(); + for (int i = 0; i < parallelism; i++) { + slices.add(new PipelineSliceMeta(i, i, 0, 0, null)); } + return slices; + } -} \ No newline at end of file + private List buildResponse() { + return Arrays.asList(0); + } +} diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/io/IoDescriptorBuilderTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/io/IoDescriptorBuilderTest.java index 61b5b9603..3d3b9588d 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/io/IoDescriptorBuilderTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/io/IoDescriptorBuilderTest.java @@ -23,7 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import junit.framework.TestCase; + import org.apache.geaflow.cluster.resourcemanager.WorkerInfo; import org.apache.geaflow.cluster.response.ShardResult; import org.apache.geaflow.common.config.Configuration; @@ -43,64 +43,74 @@ import org.apache.geaflow.shuffle.message.PipelineSliceMeta; import org.testng.Assert; -public class IoDescriptorBuilderTest extends TestCase { +import junit.framework.TestCase; - public void testBuildInputInfo() { +public class IoDescriptorBuilderTest extends TestCase { - int edgeId = 0; - int parallelism = 3; - ExecutionTask task = new ExecutionTask(1, 0, parallelism, 0, 0, 0); - ExecutionEdge edge = new ExecutionEdge(null, edgeId, null, 0, 1, null); - ExecutionNodeCycle cycle = buildMockCycle(new Configuration(), parallelism); + public void testBuildInputInfo() { - CycleResultManager resultManager = new CycleResultManager(); - List slices = new ArrayList<>(); - for (int i = 0; i < parallelism; i++) { - slices.add(new PipelineSliceMeta(i, i, 0, 0, null)); - } - ShardResult shards1 = new ShardResult(edgeId, OutputType.FORWARD, slices); - resultManager.register(0, shards1); - ShardResult shards2 = new ShardResult(edgeId, OutputType.FORWARD, slices); - resultManager.register(0, shards2); + int edgeId = 0; + int parallelism = 3; + ExecutionTask task = new ExecutionTask(1, 0, parallelism, 0, 0, 0); + ExecutionEdge edge = new ExecutionEdge(null, edgeId, null, 0, 1, null); + ExecutionNodeCycle cycle = buildMockCycle(new Configuration(), parallelism); - IInputDesc input = IoDescriptorBuilder.buildInputDesc(task, edge, cycle, resultManager, DataExchangeMode.PIPELINE, BatchPhase.CLASSIC); - ShardInputDesc shard = (ShardInputDesc) input; - Assert.assertEquals(parallelism, shard.getInput().size()); - Assert.assertEquals(1, shard.getInput().get(0).getSlices().size()); + CycleResultManager resultManager = new CycleResultManager(); + List slices = new ArrayList<>(); + for (int i = 0; i < parallelism; i++) { + slices.add(new PipelineSliceMeta(i, i, 0, 0, null)); } + ShardResult shards1 = new ShardResult(edgeId, OutputType.FORWARD, slices); + resultManager.register(0, shards1); + ShardResult shards2 = new ShardResult(edgeId, OutputType.FORWARD, slices); + resultManager.register(0, shards2); - private ExecutionNodeCycle buildMockCycle(Configuration configuration, int parallelism) { + IInputDesc input = + IoDescriptorBuilder.buildInputDesc( + task, edge, cycle, resultManager, DataExchangeMode.PIPELINE, BatchPhase.CLASSIC); + ShardInputDesc shard = (ShardInputDesc) input; + Assert.assertEquals(parallelism, shard.getInput().size()); + Assert.assertEquals(1, shard.getInput().get(0).getSlices().size()); + } - long finishIterationId = 5; - ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); - vertexGroup.getCycleGroupMeta().setFlyingCount(1); - vertexGroup.getCycleGroupMeta().setIterationCount(finishIterationId); - vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); - ExecutionVertex vertex = new ExecutionVertex(0, "test"); - vertex.setParallelism(parallelism); - vertexGroup.getVertexMap().put(0, vertex); - vertexGroup.putVertexId2InEdgeIds(0, new ArrayList<>()); - vertexGroup.putVertexId2OutEdgeIds(0, new ArrayList<>()); + private ExecutionNodeCycle buildMockCycle(Configuration configuration, int parallelism) { - List headTasks = new ArrayList<>(); - List tailTasks = new ArrayList<>(); - for (int i = 0; i < vertex.getParallelism(); i++) { - ExecutionTask task = new ExecutionTask(i, i, vertex.getParallelism(), vertex.getParallelism(), vertex.getParallelism(), vertex.getVertexId()); - task.setExecutionTaskType(ExecutionTaskType.head); - task.setWorkerInfo(new WorkerInfo()); - tailTasks.add(task); - headTasks.add(task); - } + long finishIterationId = 5; + ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); + vertexGroup.getCycleGroupMeta().setFlyingCount(1); + vertexGroup.getCycleGroupMeta().setIterationCount(finishIterationId); + vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); + ExecutionVertex vertex = new ExecutionVertex(0, "test"); + vertex.setParallelism(parallelism); + vertexGroup.getVertexMap().put(0, vertex); + vertexGroup.putVertexId2InEdgeIds(0, new ArrayList<>()); + vertexGroup.putVertexId2OutEdgeIds(0, new ArrayList<>()); - ExecutionNodeCycle cycle = new ExecutionNodeCycle(0, 0, 0, "test", vertexGroup, - configuration, - "driver_id", 0); - cycle.setCycleHeads(headTasks); - cycle.setCycleTails(tailTasks); - cycle.setTasks(headTasks); - Map> vertexIdToTasks = new HashMap<>(); - vertexIdToTasks.put(0, headTasks); - cycle.setVertexIdToTasks(vertexIdToTasks); - return cycle; + List headTasks = new ArrayList<>(); + List tailTasks = new ArrayList<>(); + for (int i = 0; i < vertex.getParallelism(); i++) { + ExecutionTask task = + new ExecutionTask( + i, + i, + vertex.getParallelism(), + vertex.getParallelism(), + vertex.getParallelism(), + vertex.getVertexId()); + task.setExecutionTaskType(ExecutionTaskType.head); + task.setWorkerInfo(new WorkerInfo()); + tailTasks.add(task); + headTasks.add(task); } -} \ No newline at end of file + + ExecutionNodeCycle cycle = + new ExecutionNodeCycle(0, 0, 0, "test", vertexGroup, configuration, "driver_id", 0); + cycle.setCycleHeads(headTasks); + cycle.setCycleTails(tailTasks); + cycle.setTasks(headTasks); + Map> vertexIdToTasks = new HashMap<>(); + vertexIdToTasks.put(0, headTasks); + cycle.setVertexIdToTasks(vertexIdToTasks); + return cycle; + } +} diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/resource/AbstractScheduledWorkerManagerTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/resource/AbstractScheduledWorkerManagerTest.java index 5421452e9..de84f72fa 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/resource/AbstractScheduledWorkerManagerTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/resource/AbstractScheduledWorkerManagerTest.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; + import org.apache.geaflow.cluster.resourcemanager.ResourceInfo; import org.apache.geaflow.cluster.resourcemanager.WorkerInfo; import org.apache.geaflow.cluster.rpc.RpcClient; @@ -42,75 +43,84 @@ public class AbstractScheduledWorkerManagerTest { - @AfterMethod - public void afterMethod() { - ScheduledWorkerManagerFactory.clear(); - } - - @Test - public void testInitWorkerSuccess() { - // mock resource manager rpc - RpcClient rpcClient = Mockito.mock(RpcClient.class); - MockedStatic rpcClientMs = Mockito.mockStatic(RpcClient.class); - rpcClientMs.when(() -> RpcClient.getInstance()).then(invocation -> rpcClient); + @AfterMethod + public void afterMethod() { + ScheduledWorkerManagerFactory.clear(); + } - AtomicInteger count = new AtomicInteger(0); - Mockito.doAnswer(in -> { - RpcEndpointRef.RpcCallback callback = ((RpcEndpointRef.RpcCallback) in.getArgument(2)); - callback.onSuccess(null); - count.incrementAndGet(); - return null; - }).when(rpcClient).processContainer(any(), any(), any()); + @Test + public void testInitWorkerSuccess() { + // mock resource manager rpc + RpcClient rpcClient = Mockito.mock(RpcClient.class); + MockedStatic rpcClientMs = Mockito.mockStatic(RpcClient.class); + rpcClientMs.when(() -> RpcClient.getInstance()).then(invocation -> rpcClient); - List workers = new ArrayList<>(); - for (int i = 0; i < 10; i++) { - workers.add(new WorkerInfo("", 0, 0, 0, i, "worker-" + i)); - } - AbstractScheduledWorkerManager workerManager = buildMockWorkerManager(); - workerManager.workers = new ConcurrentHashMap<>(); - workerManager.initWorkers(0L, workers, null); - Assert.assertEquals(count.get(), workers.size()); + AtomicInteger count = new AtomicInteger(0); + Mockito.doAnswer( + in -> { + RpcEndpointRef.RpcCallback callback = + ((RpcEndpointRef.RpcCallback) in.getArgument(2)); + callback.onSuccess(null); + count.incrementAndGet(); + return null; + }) + .when(rpcClient) + .processContainer(any(), any(), any()); - count.set(0); - workerManager.workers.put(0L, new ResourceInfo(workerManager.genResourceId(0, 0L), workers)); - workerManager.initWorkers(0L, workers, null); - Assert.assertEquals(count.get(), 0); - rpcClientMs.close(); + List workers = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + workers.add(new WorkerInfo("", 0, 0, 0, i, "worker-" + i)); } + AbstractScheduledWorkerManager workerManager = buildMockWorkerManager(); + workerManager.workers = new ConcurrentHashMap<>(); + workerManager.initWorkers(0L, workers, null); + Assert.assertEquals(count.get(), workers.size()); - @Test(expectedExceptions = GeaflowRuntimeException.class) - public void testInitWorkerFailed() throws Exception { - // mock resource manager rpc - RpcClient rpcClient = Mockito.mock(RpcClient.class); - MockedStatic rpcClientMs; - rpcClientMs = Mockito.mockStatic(RpcClient.class); - rpcClientMs.when(() -> RpcClient.getInstance()).then(invocation -> rpcClient); + count.set(0); + workerManager.workers.put(0L, new ResourceInfo(workerManager.genResourceId(0, 0L), workers)); + workerManager.initWorkers(0L, workers, null); + Assert.assertEquals(count.get(), 0); + rpcClientMs.close(); + } - Mockito.doAnswer(in -> { - RpcEndpointRef.RpcCallback callback = ((RpcEndpointRef.RpcCallback) in.getArgument(2)); - callback.onFailure(new GeaflowRuntimeException("rpc error")); - return null; - }).when(rpcClient).processContainer(any(), any(), any()); + @Test(expectedExceptions = GeaflowRuntimeException.class) + public void testInitWorkerFailed() throws Exception { + // mock resource manager rpc + RpcClient rpcClient = Mockito.mock(RpcClient.class); + MockedStatic rpcClientMs; + rpcClientMs = Mockito.mockStatic(RpcClient.class); + rpcClientMs.when(() -> RpcClient.getInstance()).then(invocation -> rpcClient); - List workers = new ArrayList<>(); - for (int i = 0; i < 10; i++) { - workers.add(new WorkerInfo("", 0, 0, 0, i, "worker-" + i)); - } - AbstractScheduledWorkerManager workerManager = buildMockWorkerManager(); - workerManager.workers = new ConcurrentHashMap<>(); - try { - workerManager.initWorkers(0L, workers, null); - } finally { - rpcClientMs.close(); - } - } + Mockito.doAnswer( + in -> { + RpcEndpointRef.RpcCallback callback = + ((RpcEndpointRef.RpcCallback) in.getArgument(2)); + callback.onFailure(new GeaflowRuntimeException("rpc error")); + return null; + }) + .when(rpcClient) + .processContainer(any(), any(), any()); - private AbstractScheduledWorkerManager buildMockWorkerManager() { - return new AbstractScheduledWorkerManager(new Configuration()) { - @Override - protected WorkerInfo assignTaskWorker(WorkerInfo worker, ExecutionTask task, AffinityLevel affinityLevel) { - return null; - } - }; + List workers = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + workers.add(new WorkerInfo("", 0, 0, 0, i, "worker-" + i)); + } + AbstractScheduledWorkerManager workerManager = buildMockWorkerManager(); + workerManager.workers = new ConcurrentHashMap<>(); + try { + workerManager.initWorkers(0L, workers, null); + } finally { + rpcClientMs.close(); } + } + + private AbstractScheduledWorkerManager buildMockWorkerManager() { + return new AbstractScheduledWorkerManager(new Configuration()) { + @Override + protected WorkerInfo assignTaskWorker( + WorkerInfo worker, ExecutionTask task, AffinityLevel affinityLevel) { + return null; + } + }; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/resource/BaseScheduledWorkerManagerTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/resource/BaseScheduledWorkerManagerTest.java index a92092f25..e53c8adcd 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/resource/BaseScheduledWorkerManagerTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/resource/BaseScheduledWorkerManagerTest.java @@ -24,6 +24,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.cluster.system.ClusterMetaStore; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.core.graph.CycleGroupType; @@ -37,36 +38,35 @@ public abstract class BaseScheduledWorkerManagerTest { - protected ExecutionGraphCycle buildMockCycle(int parallelism) { - Configuration configuration = new Configuration(); - configuration.put(JOB_UNIQUE_ID, "test-scheduler-context"); - configuration.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); - ClusterMetaStore.init(0, "driver-0", configuration); - - long finishIterationId = 100; - ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); - vertexGroup.getCycleGroupMeta().setFlyingCount(1); - vertexGroup.getCycleGroupMeta().setIterationCount(finishIterationId); - vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); - ExecutionVertex vertex = new ExecutionVertex(0, "test"); - int taskNum = parallelism; - vertex.setParallelism(taskNum); - vertexGroup.getVertexMap().put(0, vertex); - List tasks = new ArrayList<>(); - for (int i = 0; i < taskNum; i++) { - ExecutionTask task = new ExecutionTask(i, i, taskNum, taskNum, taskNum, 0); - tasks.add(task); - } - ExecutionNodeCycle cycle = new ExecutionNodeCycle(0, 0, 0, "test", vertexGroup, - configuration, "driver_id", 0); - cycle.setTasks(tasks); + protected ExecutionGraphCycle buildMockCycle(int parallelism) { + Configuration configuration = new Configuration(); + configuration.put(JOB_UNIQUE_ID, "test-scheduler-context"); + configuration.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); + ClusterMetaStore.init(0, "driver-0", configuration); - return new ExecutionGraphCycle(0, 0, 0, "test", 0, 1, 1, configuration, "driverId", 0); + long finishIterationId = 100; + ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); + vertexGroup.getCycleGroupMeta().setFlyingCount(1); + vertexGroup.getCycleGroupMeta().setIterationCount(finishIterationId); + vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); + ExecutionVertex vertex = new ExecutionVertex(0, "test"); + int taskNum = parallelism; + vertex.setParallelism(taskNum); + vertexGroup.getVertexMap().put(0, vertex); + List tasks = new ArrayList<>(); + for (int i = 0; i < taskNum; i++) { + ExecutionTask task = new ExecutionTask(i, i, taskNum, taskNum, taskNum, 0); + tasks.add(task); } + ExecutionNodeCycle cycle = + new ExecutionNodeCycle(0, 0, 0, "test", vertexGroup, configuration, "driver_id", 0); + cycle.setTasks(tasks); - @AfterMethod - public void afterMethod() { - ScheduledWorkerManagerFactory.clear(); - } + return new ExecutionGraphCycle(0, 0, 0, "test", 0, 1, 1, configuration, "driverId", 0); + } + @AfterMethod + public void afterMethod() { + ScheduledWorkerManagerFactory.clear(); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/resource/CheckpointScheduledWorkerManagerTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/resource/CheckpointScheduledWorkerManagerTest.java index 8d27f50d1..fee23fa88 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/resource/CheckpointScheduledWorkerManagerTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/resource/CheckpointScheduledWorkerManagerTest.java @@ -27,6 +27,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; + import org.apache.geaflow.cluster.resourcemanager.WorkerInfo; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.runtime.core.scheduler.cycle.ExecutionGraphCycle; @@ -36,99 +37,99 @@ public class CheckpointScheduledWorkerManagerTest extends BaseScheduledWorkerManagerTest { - @Test - public void testRequestMultiTimes() { - int parallelism = 3; - ExecutionGraphCycle cycle = buildMockCycle(parallelism); - CheckpointCycleScheduledWorkerManager wm = new CheckpointCycleScheduledWorkerManager(new Configuration()); - CheckpointCycleScheduledWorkerManager workerManager = Mockito.spy(wm); - - List workers = new ArrayList<>(); - for (int i = 0; i < parallelism; i++) { - workers.add(new WorkerInfo("", 0, 0, 0, i, "worker-" + i)); - } - String resourceId = workerManager.genResourceId(cycle.getDriverIndex(), - cycle.getSchedulerId()); - Mockito.doReturn(workers).when(workerManager).requestWorker(anyInt(), eq(resourceId)); - Mockito.doNothing().when(workerManager).initWorkers(anyLong(), any(), any()); - workerManager.init(cycle); - Assert.assertEquals(parallelism, - workerManager.workers.get(cycle.getSchedulerId()).getWorkers().size()); + @Test + public void testRequestMultiTimes() { + int parallelism = 3; + ExecutionGraphCycle cycle = buildMockCycle(parallelism); + CheckpointCycleScheduledWorkerManager wm = + new CheckpointCycleScheduledWorkerManager(new Configuration()); + CheckpointCycleScheduledWorkerManager workerManager = Mockito.spy(wm); - workerManager.assign(cycle); - Assert.assertEquals(parallelism, workerManager.workers.get(cycle.getSchedulerId()).getWorkers().size()); - workerManager.workers.remove(cycle.getSchedulerId()); + List workers = new ArrayList<>(); + for (int i = 0; i < parallelism; i++) { + workers.add(new WorkerInfo("", 0, 0, 0, i, "worker-" + i)); } + String resourceId = workerManager.genResourceId(cycle.getDriverIndex(), cycle.getSchedulerId()); + Mockito.doReturn(workers).when(workerManager).requestWorker(anyInt(), eq(resourceId)); + Mockito.doNothing().when(workerManager).initWorkers(anyLong(), any(), any()); + workerManager.init(cycle); + Assert.assertEquals( + parallelism, workerManager.workers.get(cycle.getSchedulerId()).getWorkers().size()); - @Test - public void testTaskAssigner() throws Exception { - List workers = new ArrayList<>(); - int taskSize = 3; - for (int i = 0; i < taskSize; i++) { - workers.add(new WorkerInfo("", 0, 0, 0, i, "worker-" + i)); - } - List taskIndexes = new ArrayList<>(); - for (int i = 0; i < taskSize; i++) { - taskIndexes.add(i); - } - String graphName = "test1"; - TaskAssigner taskAssigner = new TaskAssigner(); + workerManager.assign(cycle); + Assert.assertEquals( + parallelism, workerManager.workers.get(cycle.getSchedulerId()).getWorkers().size()); + workerManager.workers.remove(cycle.getSchedulerId()); + } - // Current worker list: [0, 1, 2]. - // Assign: t0:[w0], t1:[w1], t2:[w2]. - Map match0 = taskAssigner.assignTasks2Workers(graphName, taskIndexes, - workers); + @Test + public void testTaskAssigner() throws Exception { + List workers = new ArrayList<>(); + int taskSize = 3; + for (int i = 0; i < taskSize; i++) { + workers.add(new WorkerInfo("", 0, 0, 0, i, "worker-" + i)); + } + List taskIndexes = new ArrayList<>(); + for (int i = 0; i < taskSize; i++) { + taskIndexes.add(i); + } + String graphName = "test1"; + TaskAssigner taskAssigner = new TaskAssigner(); - // Hit all cache. - Map match1 = taskAssigner.assignTasks2Workers(graphName, taskIndexes, - workers); - for (Integer taskIndex : taskIndexes) { - Assert.assertEquals(match0.get(taskIndex).getWorkerIndex(), - match1.get(taskIndex).getWorkerIndex()); - } + // Current worker list: [0, 1, 2]. + // Assign: t0:[w0], t1:[w1], t2:[w2]. + Map match0 = + taskAssigner.assignTasks2Workers(graphName, taskIndexes, workers); - // Current worker list: [0, 1, 3]. - // Assign: t0:[w0], t1:[w1], t2:[w2, w3]. - WorkerInfo worker2 = workers.remove(taskSize - 1); - workers.add(new WorkerInfo("", 0, 0, 0, taskSize, "worker-" + taskSize)); - Map match2 = taskAssigner.assignTasks2Workers(graphName, taskIndexes, - workers); - Integer task_index1 = 0; - for (Integer taskIndex : taskIndexes) { - if (match1.get(taskIndex).getWorkerIndex() == taskSize - 1) { - Assert.assertEquals(taskSize, match2.get(taskIndex).getWorkerIndex()); - } else { - if (match1.get(taskIndex).getWorkerIndex() == 1) { - task_index1 = taskIndex; - } - Assert.assertEquals(match1.get(taskIndex).getWorkerIndex(), - match2.get(taskIndex).getWorkerIndex()); - } - } + // Hit all cache. + Map match1 = + taskAssigner.assignTasks2Workers(graphName, taskIndexes, workers); + for (Integer taskIndex : taskIndexes) { + Assert.assertEquals( + match0.get(taskIndex).getWorkerIndex(), match1.get(taskIndex).getWorkerIndex()); + } - // Current worker list: [1, 2, 3]. - // Assign: t0:[w0, w2 or w3], t1:[w1], t2:[w2, w3]. - workers.remove(0); - workers.add(worker2); - Map match3 = taskAssigner.assignTasks2Workers(graphName, taskIndexes, - workers); - for (Integer taskIndex : taskIndexes) { - if (match3.get(taskIndex).getWorkerIndex() == 1) { - Assert.assertEquals(taskIndex, task_index1); - } + // Current worker list: [0, 1, 3]. + // Assign: t0:[w0], t1:[w1], t2:[w2, w3]. + WorkerInfo worker2 = workers.remove(taskSize - 1); + workers.add(new WorkerInfo("", 0, 0, 0, taskSize, "worker-" + taskSize)); + Map match2 = + taskAssigner.assignTasks2Workers(graphName, taskIndexes, workers); + Integer task_index1 = 0; + for (Integer taskIndex : taskIndexes) { + if (match1.get(taskIndex).getWorkerIndex() == taskSize - 1) { + Assert.assertEquals(taskSize, match2.get(taskIndex).getWorkerIndex()); + } else { + if (match1.get(taskIndex).getWorkerIndex() == 1) { + task_index1 = taskIndex; } + Assert.assertEquals( + match1.get(taskIndex).getWorkerIndex(), match2.get(taskIndex).getWorkerIndex()); + } + } - // Current worker list: [2, 3, 4] - // Assign: t0:[w0, w2 or w3], t1:[w1, w4], t2:[w2, w3] - workers.remove(0); - workers.add(new WorkerInfo("", 0, 0, 0, taskSize + 1, "worker-" + taskSize + 1)); - Map match4 = taskAssigner.assignTasks2Workers(graphName, taskIndexes, - workers); - for (Integer taskIndex : taskIndexes) { - if (match3.get(taskIndex).getWorkerIndex() == 1) { - Assert.assertEquals(taskIndex, task_index1); - } - } + // Current worker list: [1, 2, 3]. + // Assign: t0:[w0, w2 or w3], t1:[w1], t2:[w2, w3]. + workers.remove(0); + workers.add(worker2); + Map match3 = + taskAssigner.assignTasks2Workers(graphName, taskIndexes, workers); + for (Integer taskIndex : taskIndexes) { + if (match3.get(taskIndex).getWorkerIndex() == 1) { + Assert.assertEquals(taskIndex, task_index1); + } } + // Current worker list: [2, 3, 4] + // Assign: t0:[w0, w2 or w3], t1:[w1, w4], t2:[w2, w3] + workers.remove(0); + workers.add(new WorkerInfo("", 0, 0, 0, taskSize + 1, "worker-" + taskSize + 1)); + Map match4 = + taskAssigner.assignTasks2Workers(graphName, taskIndexes, workers); + for (Integer taskIndex : taskIndexes) { + if (match3.get(taskIndex).getWorkerIndex() == 1) { + Assert.assertEquals(taskIndex, task_index1); + } + } + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/resource/RedoScheduledWorkerManagerTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/resource/RedoScheduledWorkerManagerTest.java index 9ca9d0c4b..f153dfb62 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/resource/RedoScheduledWorkerManagerTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/resource/RedoScheduledWorkerManagerTest.java @@ -26,6 +26,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.cluster.resourcemanager.WorkerInfo; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.runtime.core.scheduler.cycle.ExecutionGraphCycle; @@ -35,28 +36,27 @@ public class RedoScheduledWorkerManagerTest extends BaseScheduledWorkerManagerTest { - @Test - public void testRequestMultiTimes() { - int parallelism = 3; - ExecutionGraphCycle cycle = buildMockCycle(parallelism); - RedoCycleScheduledWorkerManager wm = new RedoCycleScheduledWorkerManager(new Configuration()); - RedoCycleScheduledWorkerManager workerManager = Mockito.spy(wm); - - List workers = new ArrayList<>(); - for (int i = 0; i < parallelism; i++) { - workers.add(new WorkerInfo("", 0, 0, 0, i, "worker-" + i)); - } - String resourceId = workerManager.genResourceId(cycle.getDriverIndex(), - cycle.getSchedulerId()); - Mockito.doReturn(workers).when(workerManager).requestWorker(anyInt(), eq(resourceId)); - Mockito.doNothing().when(workerManager).initWorkers(anyLong(), any(), any()); - workerManager.init(cycle); - Assert.assertEquals(parallelism, - workerManager.workers.get(cycle.getSchedulerId()).getWorkers().size()); - - workerManager.assign(cycle); - Assert.assertEquals(parallelism, - workerManager.workers.get(cycle.getSchedulerId()).getWorkers().size()); - workerManager.workers.remove(cycle.getSchedulerId()); + @Test + public void testRequestMultiTimes() { + int parallelism = 3; + ExecutionGraphCycle cycle = buildMockCycle(parallelism); + RedoCycleScheduledWorkerManager wm = new RedoCycleScheduledWorkerManager(new Configuration()); + RedoCycleScheduledWorkerManager workerManager = Mockito.spy(wm); + + List workers = new ArrayList<>(); + for (int i = 0; i < parallelism; i++) { + workers.add(new WorkerInfo("", 0, 0, 0, i, "worker-" + i)); } + String resourceId = workerManager.genResourceId(cycle.getDriverIndex(), cycle.getSchedulerId()); + Mockito.doReturn(workers).when(workerManager).requestWorker(anyInt(), eq(resourceId)); + Mockito.doNothing().when(workerManager).initWorkers(anyLong(), any(), any()); + workerManager.init(cycle); + Assert.assertEquals( + parallelism, workerManager.workers.get(cycle.getSchedulerId()).getWorkers().size()); + + workerManager.assign(cycle); + Assert.assertEquals( + parallelism, workerManager.workers.get(cycle.getSchedulerId()).getWorkers().size()); + workerManager.workers.remove(cycle.getSchedulerId()); + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/statemachine/StateMachineTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/statemachine/StateMachineTest.java index 030503e5a..3956ebac1 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/statemachine/StateMachineTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-runtime-core/src/test/java/org/apache/geaflow/runtime/core/scheduler/statemachine/StateMachineTest.java @@ -28,6 +28,7 @@ import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.cluster.protocol.ScheduleStateType; import org.apache.geaflow.cluster.system.ClusterMetaStore; import org.apache.geaflow.common.config.Configuration; @@ -54,517 +55,562 @@ public class StateMachineTest extends BaseCycleSchedulerTest { - private Configuration configuration; - - @BeforeMethod - public void setUp() { - Map config = new HashMap<>(); - config.put(JOB_UNIQUE_ID.getKey(), - "scheduler-state-machine-test" + System.currentTimeMillis()); - config.put(RUN_LOCAL_MODE.getKey(), "true"); - config.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); - config.put(CONTAINER_HEAP_SIZE_MB.getKey(), String.valueOf(1024)); - configuration = new Configuration(config); - ClusterMetaStore.init(0, "driver-0", configuration); - ShuffleManager.init(configuration); - StatsCollectorFactory.init(configuration); - } - - @AfterMethod - public void cleanUp() { - ClusterMetaStore.close(); - ScheduledWorkerManagerFactory.clear(); + private Configuration configuration; + + @BeforeMethod + public void setUp() { + Map config = new HashMap<>(); + config.put(JOB_UNIQUE_ID.getKey(), "scheduler-state-machine-test" + System.currentTimeMillis()); + config.put(RUN_LOCAL_MODE.getKey(), "true"); + config.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); + config.put(CONTAINER_HEAP_SIZE_MB.getKey(), String.valueOf(1024)); + configuration = new Configuration(config); + ClusterMetaStore.init(0, "driver-0", configuration); + ShuffleManager.init(configuration); + StatsCollectorFactory.init(configuration); + } + + @AfterMethod + public void cleanUp() { + ClusterMetaStore.close(); + ScheduledWorkerManagerFactory.clear(); + } + + @Test + public void testBatch() { + ClusterMetaStore.init(0, "driver-0", configuration); + PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); + processor.register(scheduler); + + RedoSchedulerContext context = + (RedoSchedulerContext) CycleSchedulerContextFactory.create(buildMockCycle(false, 1), null); + context.setRollback(false); + PipelineStateMachine stateMachine = new PipelineStateMachine(); + stateMachine.init(context); + + // START -> INIT. + IScheduleState state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.PREFETCH, + ((ComposeState) state).getStates().get(0).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.INIT, ((ComposeState) state).getStates().get(1).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.EXECUTE_COMPUTE, + ((ComposeState) state).getStates().get(2).getScheduleStateType()); + context.getNextIterationId(); + + Map prefetchEvents = context.getPrefetchEvents(); + ExecutableEvent executableEvent = ExecutableEvent.build(null, null, null); + prefetchEvents.put(0, executableEvent); + + state = stateMachine.transition(); + Assert.assertEquals(null, state); + while (context.hasNextToFinish()) { + context.getNextFinishIterationId(); } - @Test - public void testBatch() { - ClusterMetaStore.init(0, "driver-0", configuration); - PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); - processor.register(scheduler); - - RedoSchedulerContext context = - (RedoSchedulerContext) CycleSchedulerContextFactory.create(buildMockCycle(false, 1), null); - context.setRollback(false); - PipelineStateMachine stateMachine = new PipelineStateMachine(); - stateMachine.init(context); - - // START -> INIT. - IScheduleState state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.PREFETCH, - ((ComposeState) state).getStates().get(0).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.INIT, - ((ComposeState) state).getStates().get(1).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, - ((ComposeState) state).getStates().get(2).getScheduleStateType()); - context.getNextIterationId(); - - Map prefetchEvents = context.getPrefetchEvents(); - ExecutableEvent executableEvent = ExecutableEvent.build(null, null, null); - prefetchEvents.put(0, executableEvent); - - state = stateMachine.transition(); - Assert.assertEquals(null, state); - while (context.hasNextToFinish()) { - context.getNextFinishIterationId(); - } - - // EXECUTE_COMPUTE -> FINISH_PREFETCH | CLEAN_CYCLE. - state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.FINISH_PREFETCH, - ((ComposeState) state).getStates().get(0).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.CLEAN_CYCLE, - ((ComposeState) state).getStates().get(1).getScheduleStateType()); - - // CLEAN_CYCLE -> END. - Assert.assertEquals(ScheduleStateType.END, stateMachine.getCurrentState().getScheduleStateType()); + // EXECUTE_COMPUTE -> FINISH_PREFETCH | CLEAN_CYCLE. + state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.FINISH_PREFETCH, + ((ComposeState) state).getStates().get(0).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.CLEAN_CYCLE, + ((ComposeState) state).getStates().get(1).getScheduleStateType()); + + // CLEAN_CYCLE -> END. + Assert.assertEquals( + ScheduleStateType.END, stateMachine.getCurrentState().getScheduleStateType()); + } + + @Test + public void testBatchDisablePrefetch() { + ClusterMetaStore.init(0, "driver-0", configuration); + PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); + processor.register(scheduler); + + RedoSchedulerContext context = + (RedoSchedulerContext) + CycleSchedulerContextFactory.create(buildMockCycle(false, 1, false), null); + context.setRollback(false); + PipelineStateMachine stateMachine = new PipelineStateMachine(); + stateMachine.init(context); + + // START -> INIT. + IScheduleState state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.INIT, ((ComposeState) state).getStates().get(0).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.EXECUTE_COMPUTE, + ((ComposeState) state).getStates().get(1).getScheduleStateType()); + context.getNextIterationId(); + + state = stateMachine.transition(); + Assert.assertEquals(null, state); + while (context.hasNextToFinish()) { + context.getNextFinishIterationId(); } - @Test - public void testBatchDisablePrefetch() { - ClusterMetaStore.init(0, "driver-0", configuration); - PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); - processor.register(scheduler); - - RedoSchedulerContext context = - (RedoSchedulerContext) CycleSchedulerContextFactory.create(buildMockCycle(false, 1, - false), null); - context.setRollback(false); - PipelineStateMachine stateMachine = new PipelineStateMachine(); - stateMachine.init(context); - - // START -> INIT. - IScheduleState state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.INIT, - ((ComposeState) state).getStates().get(0).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, - ((ComposeState) state).getStates().get(1).getScheduleStateType()); - context.getNextIterationId(); - - state = stateMachine.transition(); - Assert.assertEquals(null, state); - while (context.hasNextToFinish()) { - context.getNextFinishIterationId(); - } - - // EXECUTE_COMPUTE -> CLEAN_CYCLE. - state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.CLEAN_CYCLE, state.getScheduleStateType()); - - // CLEAN_CYCLE -> END. - Assert.assertEquals(ScheduleStateType.END, stateMachine.getCurrentState().getScheduleStateType()); + // EXECUTE_COMPUTE -> CLEAN_CYCLE. + state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.CLEAN_CYCLE, state.getScheduleStateType()); + + // CLEAN_CYCLE -> END. + Assert.assertEquals( + ScheduleStateType.END, stateMachine.getCurrentState().getScheduleStateType()); + } + + @Test + public void testStream() { + configuration.put(CLUSTER_ID, "restart"); + ClusterMetaStore.init(0, "driver-0", configuration); + PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); + processor.register(scheduler); + + CheckpointSchedulerContext context = + (CheckpointSchedulerContext) + CycleSchedulerContextFactory.create(buildMockCycle(false), null); + PipelineStateMachine stateMachine = new PipelineStateMachine(); + stateMachine.init(context); + + // START -> INIT. + IScheduleState state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.PREFETCH, + ((ComposeState) state).getStates().get(0).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.INIT, ((ComposeState) state).getStates().get(1).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.EXECUTE_COMPUTE, + ((ComposeState) state).getStates().get(2).getScheduleStateType()); + + Map prefetchEvents = context.getPrefetchEvents(); + ExecutableEvent executableEvent = ExecutableEvent.build(null, null, null); + prefetchEvents.put(0, executableEvent); + + // INIT -> loop (EXECUTE_COMPUTE). + for (int i = 1; i <= 5; i++) { + state = stateMachine.transition(); + context.getNextIterationId(); + Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, state.getScheduleStateType()); + state = stateMachine.transition(); + Assert.assertEquals(null, state); + while (context.hasNextToFinish()) { + context.getNextFinishIterationId(); + } } - @Test - public void testStream() { - configuration.put(CLUSTER_ID, "restart"); - ClusterMetaStore.init(0, "driver-0", configuration); - PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); - processor.register(scheduler); - - CheckpointSchedulerContext context = - (CheckpointSchedulerContext) CycleSchedulerContextFactory.create(buildMockCycle(false), null); - PipelineStateMachine stateMachine = new PipelineStateMachine(); - stateMachine.init(context); - - // START -> INIT. - IScheduleState state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.PREFETCH, - ((ComposeState) state).getStates().get(0).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.INIT, - ((ComposeState) state).getStates().get(1).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, - ((ComposeState) state).getStates().get(2).getScheduleStateType()); - - Map prefetchEvents = context.getPrefetchEvents(); - ExecutableEvent executableEvent = ExecutableEvent.build(null, null, null); - prefetchEvents.put(0, executableEvent); - - // INIT -> loop (EXECUTE_COMPUTE). - for (int i = 1; i <= 5; i++) { - state = stateMachine.transition(); - context.getNextIterationId(); - Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, state.getScheduleStateType()); - state = stateMachine.transition(); - Assert.assertEquals(null, state); - while (context.hasNextToFinish()) { - context.getNextFinishIterationId(); - } - } - - // EXECUTE_COMPUTE -> FINISH_PREFETCH | CLEAN_CYCLE. - state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.FINISH_PREFETCH, - ((ComposeState) state).getStates().get(0).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.CLEAN_CYCLE, - ((ComposeState) state).getStates().get(1).getScheduleStateType()); - - // CLEAN_CYCLE -> END. - Assert.assertEquals(ScheduleStateType.END, stateMachine.getCurrentState().getScheduleStateType()); + // EXECUTE_COMPUTE -> FINISH_PREFETCH | CLEAN_CYCLE. + state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.FINISH_PREFETCH, + ((ComposeState) state).getStates().get(0).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.CLEAN_CYCLE, + ((ComposeState) state).getStates().get(1).getScheduleStateType()); + + // CLEAN_CYCLE -> END. + Assert.assertEquals( + ScheduleStateType.END, stateMachine.getCurrentState().getScheduleStateType()); + } + + @Test + public void testStreamDisablePrefetch() { + configuration.put(CLUSTER_ID, "restart"); + ClusterMetaStore.init(0, "driver-0", configuration); + PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); + processor.register(scheduler); + + CheckpointSchedulerContext context = + (CheckpointSchedulerContext) + CycleSchedulerContextFactory.create(buildMockCycle(false, 5, false), null); + PipelineStateMachine stateMachine = new PipelineStateMachine(); + stateMachine.init(context); + + // START -> INIT. + IScheduleState state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.INIT, ((ComposeState) state).getStates().get(0).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.EXECUTE_COMPUTE, + ((ComposeState) state).getStates().get(1).getScheduleStateType()); + + // INIT -> loop (EXECUTE_COMPUTE). + for (int i = 1; i <= 5; i++) { + state = stateMachine.transition(); + context.getNextIterationId(); + Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, state.getScheduleStateType()); + state = stateMachine.transition(); + Assert.assertEquals(null, state); + while (context.hasNextToFinish()) { + context.getNextFinishIterationId(); + } } - @Test - public void testStreamDisablePrefetch() { - configuration.put(CLUSTER_ID, "restart"); - ClusterMetaStore.init(0, "driver-0", configuration); - PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); - processor.register(scheduler); - - CheckpointSchedulerContext context = - (CheckpointSchedulerContext) CycleSchedulerContextFactory.create(buildMockCycle(false - , 5, false), null); - PipelineStateMachine stateMachine = new PipelineStateMachine(); - stateMachine.init(context); - - // START -> INIT. - IScheduleState state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.INIT, - ((ComposeState) state).getStates().get(0).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, - ((ComposeState) state).getStates().get(1).getScheduleStateType()); - - // INIT -> loop (EXECUTE_COMPUTE). - for (int i = 1; i <= 5; i++) { - state = stateMachine.transition(); - context.getNextIterationId(); - Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, state.getScheduleStateType()); - state = stateMachine.transition(); - Assert.assertEquals(null, state); - while (context.hasNextToFinish()) { - context.getNextFinishIterationId(); - } - } - - // EXECUTE_COMPUTE -> CLEAN_CYCLE. - state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.CLEAN_CYCLE, state.getScheduleStateType()); - - // CLEAN_CYCLE -> END. - Assert.assertEquals(ScheduleStateType.END, stateMachine.getCurrentState().getScheduleStateType()); + // EXECUTE_COMPUTE -> CLEAN_CYCLE. + state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.CLEAN_CYCLE, state.getScheduleStateType()); + + // CLEAN_CYCLE -> END. + Assert.assertEquals( + ScheduleStateType.END, stateMachine.getCurrentState().getScheduleStateType()); + } + + @Test + public void testIteration() { + ClusterMetaStore.init(0, "driver-0", configuration); + PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); + processor.register(scheduler); + + RedoSchedulerContext parentContext = new RedoSchedulerContext(buildMockCycle(false), null); + parentContext.init(1); + + IterationRedoSchedulerContext context = + (IterationRedoSchedulerContext) + CycleSchedulerContextFactory.create(buildMockCycle(true), parentContext); + context.setRollback(false); + PipelineStateMachine stateMachine = new PipelineStateMachine(); + stateMachine.init(context); + + // START -> INIT. + IScheduleState state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.INIT, ((ComposeState) state).getStates().get(0).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.ITERATION_INIT, + ((ComposeState) state).getStates().get(1).getScheduleStateType()); + + // ITERATION_INIT -> loop (EXECUTE_COMPUTE). + for (int i = 1; i <= 5; i++) { + state = stateMachine.transition(); + context.getNextIterationId(); + Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, state.getScheduleStateType()); + state = stateMachine.transition(); + Assert.assertEquals(null, state); + while (context.hasNextToFinish()) { + context.getNextFinishIterationId(); + } } - @Test - public void testIteration() { - ClusterMetaStore.init(0, "driver-0", configuration); - PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); - processor.register(scheduler); - - RedoSchedulerContext parentContext = new RedoSchedulerContext(buildMockCycle(false), null); - parentContext.init(1); - - IterationRedoSchedulerContext context = - (IterationRedoSchedulerContext) CycleSchedulerContextFactory.create(buildMockCycle(true), parentContext); - context.setRollback(false); - PipelineStateMachine stateMachine = new PipelineStateMachine(); - stateMachine.init(context); - - // START -> INIT. - IScheduleState state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.INIT, ((ComposeState) state).getStates().get(0).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.ITERATION_INIT, - ((ComposeState) state).getStates().get(1).getScheduleStateType()); - - // ITERATION_INIT -> loop (EXECUTE_COMPUTE). - for (int i = 1; i <= 5; i++) { - state = stateMachine.transition(); - context.getNextIterationId(); - Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, state.getScheduleStateType()); - state = stateMachine.transition(); - Assert.assertEquals(null, state); - while (context.hasNextToFinish()) { - context.getNextFinishIterationId(); - } - } - - // EXECUTE_COMPUTE -> ITERATION_FINISH. - state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.PREFETCH, - ((ComposeState) state).getStates().get(0).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.ITERATION_FINISH, - ((ComposeState) state).getStates().get(1).getScheduleStateType()); - - Map prefetchEvents = context.getPrefetchEvents(); - ExecutableEvent executableEvent = ExecutableEvent.build(null, null, null); - prefetchEvents.put(0, executableEvent); - - state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.FINISH_PREFETCH, - ((ComposeState) state).getStates().get(0).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.CLEAN_CYCLE, - ((ComposeState) state).getStates().get(1).getScheduleStateType()); - - // ITERATION_FINISH -> END. - Assert.assertEquals(ScheduleStateType.END, stateMachine.getCurrentState().getScheduleStateType()); + // EXECUTE_COMPUTE -> ITERATION_FINISH. + state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.PREFETCH, + ((ComposeState) state).getStates().get(0).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.ITERATION_FINISH, + ((ComposeState) state).getStates().get(1).getScheduleStateType()); + + Map prefetchEvents = context.getPrefetchEvents(); + ExecutableEvent executableEvent = ExecutableEvent.build(null, null, null); + prefetchEvents.put(0, executableEvent); + + state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.FINISH_PREFETCH, + ((ComposeState) state).getStates().get(0).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.CLEAN_CYCLE, + ((ComposeState) state).getStates().get(1).getScheduleStateType()); + + // ITERATION_FINISH -> END. + Assert.assertEquals( + ScheduleStateType.END, stateMachine.getCurrentState().getScheduleStateType()); + } + + @Test + public void testIterationDisablePrefetch() { + ClusterMetaStore.init(0, "driver-0", configuration); + PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); + processor.register(scheduler); + + RedoSchedulerContext parentContext = + new RedoSchedulerContext(buildMockCycle(false, 5, false), null); + parentContext.init(1); + + IterationRedoSchedulerContext context = + (IterationRedoSchedulerContext) + CycleSchedulerContextFactory.create(buildMockCycle(true, 5, false), parentContext); + PipelineStateMachine stateMachine = new PipelineStateMachine(); + stateMachine.init(context); + + // START -> INIT. + IScheduleState state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.INIT, ((ComposeState) state).getStates().get(0).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.ITERATION_INIT, + ((ComposeState) state).getStates().get(1).getScheduleStateType()); + + // ITERATION_INIT -> loop (EXECUTE_COMPUTE). + for (int i = 1; i <= 5; i++) { + state = stateMachine.transition(); + context.getNextIterationId(); + Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, state.getScheduleStateType()); + state = stateMachine.transition(); + Assert.assertEquals(null, state); + while (context.hasNextToFinish()) { + context.getNextFinishIterationId(); + } } - @Test - public void testIterationDisablePrefetch() { - ClusterMetaStore.init(0, "driver-0", configuration); - PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); - processor.register(scheduler); - - RedoSchedulerContext parentContext = new RedoSchedulerContext(buildMockCycle(false, 5, - false), null); - parentContext.init(1); - - IterationRedoSchedulerContext context = - (IterationRedoSchedulerContext) CycleSchedulerContextFactory.create(buildMockCycle(true, 5, false), parentContext); - PipelineStateMachine stateMachine = new PipelineStateMachine(); - stateMachine.init(context); - - // START -> INIT. - IScheduleState state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.INIT, ((ComposeState) state).getStates().get(0).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.ITERATION_INIT, - ((ComposeState) state).getStates().get(1).getScheduleStateType()); - - // ITERATION_INIT -> loop (EXECUTE_COMPUTE). - for (int i = 1; i <= 5; i++) { - state = stateMachine.transition(); - context.getNextIterationId(); - Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, state.getScheduleStateType()); - state = stateMachine.transition(); - Assert.assertEquals(null, state); - while (context.hasNextToFinish()) { - context.getNextFinishIterationId(); - } - } - - // EXECUTE_COMPUTE -> ITERATION_FINISH. - state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.ITERATION_FINISH, - ((ComposeState) state).getStates().get(0).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.CLEAN_CYCLE, - ((ComposeState) state).getStates().get(1).getScheduleStateType()); - - // ITERATION_FINISH -> END. - Assert.assertEquals(ScheduleStateType.END, stateMachine.getCurrentState().getScheduleStateType()); + // EXECUTE_COMPUTE -> ITERATION_FINISH. + state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.ITERATION_FINISH, + ((ComposeState) state).getStates().get(0).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.CLEAN_CYCLE, + ((ComposeState) state).getStates().get(1).getScheduleStateType()); + + // ITERATION_FINISH -> END. + Assert.assertEquals( + ScheduleStateType.END, stateMachine.getCurrentState().getScheduleStateType()); + } + + @Test + public void testRollback001() { + ClusterMetaStore.init(0, "driver-0", configuration); + PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); + processor.register(scheduler); + + CheckpointSchedulerContext context = + (CheckpointSchedulerContext) + CycleSchedulerContextFactory.create(buildMockCycle(false), null); + PipelineStateMachine stateMachine = new PipelineStateMachine(); + + context.init(2); + context.setRecovered(true); + stateMachine.init(context); + + // START -> ROLLBACK. + IScheduleState state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.ROLLBACK, + ((ComposeState) state).getStates().get(0).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.EXECUTE_COMPUTE, + ((ComposeState) state).getStates().get(1).getScheduleStateType()); + + // ROLLBACK -> loop (EXECUTE_COMPUTE). + for (int i = 1; i <= 4; i++) { + state = stateMachine.transition(); + context.getNextIterationId(); + Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, state.getScheduleStateType()); + state = stateMachine.transition(); + Assert.assertEquals(null, state); + while (context.hasNextToFinish()) { + context.getNextFinishIterationId(); + } } - @Test - public void testRollback001() { - ClusterMetaStore.init(0, "driver-0", configuration); - PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); - processor.register(scheduler); - - CheckpointSchedulerContext context = - (CheckpointSchedulerContext) CycleSchedulerContextFactory.create(buildMockCycle(false), null); - PipelineStateMachine stateMachine = new PipelineStateMachine(); - - context.init(2); - context.setRecovered(true); - stateMachine.init(context); - - // START -> ROLLBACK. - IScheduleState state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.ROLLBACK, - ((ComposeState) state).getStates().get(0).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, - ((ComposeState) state).getStates().get(1).getScheduleStateType()); - - // ROLLBACK -> loop (EXECUTE_COMPUTE). - for (int i = 1; i <= 4; i++) { - state = stateMachine.transition(); - context.getNextIterationId(); - Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, state.getScheduleStateType()); - state = stateMachine.transition(); - Assert.assertEquals(null, state); - while (context.hasNextToFinish()) { - context.getNextFinishIterationId(); - } - } - - // EXECUTE_COMPUTE -> ITERATION_FINISH. - state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.CLEAN_CYCLE, state.getScheduleStateType()); - - // ITERATION_FINISH -> END. - Assert.assertEquals(ScheduleStateType.END, stateMachine.getCurrentState().getScheduleStateType()); + // EXECUTE_COMPUTE -> ITERATION_FINISH. + state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.CLEAN_CYCLE, state.getScheduleStateType()); + + // ITERATION_FINISH -> END. + Assert.assertEquals( + ScheduleStateType.END, stateMachine.getCurrentState().getScheduleStateType()); + } + + @Test + public void testRollback002() { + ClusterMetaStore.init(0, "driver-0", configuration); + PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); + processor.register(scheduler); + + CheckpointSchedulerContext context = + (CheckpointSchedulerContext) + CycleSchedulerContextFactory.create(buildMockCycle(false), null); + PipelineStateMachine stateMachine = new PipelineStateMachine(); + + context.init(3); + context.setRollback(true); + stateMachine.init(context); + + // START -> INIT. + IScheduleState state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.PREFETCH, + ((ComposeState) state).getStates().get(0).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.INIT, ((ComposeState) state).getStates().get(1).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.ROLLBACK, + ((ComposeState) state).getStates().get(2).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.EXECUTE_COMPUTE, + ((ComposeState) state).getStates().get(3).getScheduleStateType()); + + Map prefetchEvents = context.getPrefetchEvents(); + ExecutableEvent executableEvent = ExecutableEvent.build(null, null, null); + prefetchEvents.put(0, executableEvent); + + // ROLLBACK -> loop (EXECUTE_COMPUTE). + for (int i = 1; i <= 3; i++) { + state = stateMachine.transition(); + context.getNextIterationId(); + Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, state.getScheduleStateType()); + state = stateMachine.transition(); + Assert.assertEquals(null, state); + while (context.hasNextToFinish()) { + context.getNextFinishIterationId(); + } } - @Test - public void testRollback002() { - ClusterMetaStore.init(0, "driver-0", configuration); - PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); - processor.register(scheduler); - - CheckpointSchedulerContext context = - (CheckpointSchedulerContext) CycleSchedulerContextFactory.create(buildMockCycle(false), null); - PipelineStateMachine stateMachine = new PipelineStateMachine(); - - context.init(3); - context.setRollback(true); - stateMachine.init(context); - - // START -> INIT. - IScheduleState state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.PREFETCH, - ((ComposeState) state).getStates().get(0).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.INIT, - ((ComposeState) state).getStates().get(1).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.ROLLBACK, - ((ComposeState) state).getStates().get(2).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, - ((ComposeState) state).getStates().get(3).getScheduleStateType()); - - Map prefetchEvents = context.getPrefetchEvents(); - ExecutableEvent executableEvent = ExecutableEvent.build(null, null, null); - prefetchEvents.put(0, executableEvent); - - // ROLLBACK -> loop (EXECUTE_COMPUTE). - for (int i = 1; i <= 3; i++) { - state = stateMachine.transition(); - context.getNextIterationId(); - Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, state.getScheduleStateType()); - state = stateMachine.transition(); - Assert.assertEquals(null, state); - while (context.hasNextToFinish()) { - context.getNextFinishIterationId(); - } - } - - // EXECUTE_COMPUTE -> FINISH_PREFETCH | CLEAN_CYCLE. - state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.FINISH_PREFETCH, - ((ComposeState) state).getStates().get(0).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.CLEAN_CYCLE, - ((ComposeState) state).getStates().get(1).getScheduleStateType()); - - // CLEAN_CYCLE -> END. - Assert.assertEquals(ScheduleStateType.END, stateMachine.getCurrentState().getScheduleStateType()); - context.setRollback(false); + // EXECUTE_COMPUTE -> FINISH_PREFETCH | CLEAN_CYCLE. + state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.FINISH_PREFETCH, + ((ComposeState) state).getStates().get(0).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.CLEAN_CYCLE, + ((ComposeState) state).getStates().get(1).getScheduleStateType()); + + // CLEAN_CYCLE -> END. + Assert.assertEquals( + ScheduleStateType.END, stateMachine.getCurrentState().getScheduleStateType()); + context.setRollback(false); + } + + @Test + public void testRollback003() { + ClusterMetaStore.init(0, "driver-0", configuration); + PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); + processor.register(scheduler); + + CheckpointSchedulerContext context = + (CheckpointSchedulerContext) + CycleSchedulerContextFactory.create(buildMockCycle(false, 5, true), null); + PipelineStateMachine stateMachine = new PipelineStateMachine(); + + context.init(2); + context.setRecovered(true); + stateMachine.init(context); + + // START -> ROLLBACK. + IScheduleState state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.ROLLBACK, + ((ComposeState) state).getStates().get(0).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.EXECUTE_COMPUTE, + ((ComposeState) state).getStates().get(1).getScheduleStateType()); + + // ROLLBACK -> loop (EXECUTE_COMPUTE). + for (int i = 1; i <= 4; i++) { + state = stateMachine.transition(); + context.getNextIterationId(); + Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, state.getScheduleStateType()); + state = stateMachine.transition(); + Assert.assertEquals(null, state); + while (context.hasNextToFinish()) { + context.getNextFinishIterationId(); + } } - @Test - public void testRollback003() { - ClusterMetaStore.init(0, "driver-0", configuration); - PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); - processor.register(scheduler); - - CheckpointSchedulerContext context = - (CheckpointSchedulerContext) CycleSchedulerContextFactory.create(buildMockCycle(false, 5, true), null); - PipelineStateMachine stateMachine = new PipelineStateMachine(); - - context.init(2); - context.setRecovered(true); - stateMachine.init(context); - - // START -> ROLLBACK. - IScheduleState state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.ROLLBACK, - ((ComposeState) state).getStates().get(0).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, - ((ComposeState) state).getStates().get(1).getScheduleStateType()); - - // ROLLBACK -> loop (EXECUTE_COMPUTE). - for (int i = 1; i <= 4; i++) { - state = stateMachine.transition(); - context.getNextIterationId(); - Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, state.getScheduleStateType()); - state = stateMachine.transition(); - Assert.assertEquals(null, state); - while (context.hasNextToFinish()) { - context.getNextFinishIterationId(); - } - } - - // EXECUTE_COMPUTE -> ITERATION_FINISH. - state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.CLEAN_CYCLE, state.getScheduleStateType()); - - // ITERATION_FINISH -> END. - Assert.assertEquals(ScheduleStateType.END, stateMachine.getCurrentState().getScheduleStateType()); + // EXECUTE_COMPUTE -> ITERATION_FINISH. + state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.CLEAN_CYCLE, state.getScheduleStateType()); + + // ITERATION_FINISH -> END. + Assert.assertEquals( + ScheduleStateType.END, stateMachine.getCurrentState().getScheduleStateType()); + } + + @Test + public void testRollback004() { + ClusterMetaStore.init(0, "driver-0", configuration); + PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); + processor.register(scheduler); + + CheckpointSchedulerContext context = + (CheckpointSchedulerContext) + CycleSchedulerContextFactory.create(buildMockCycle(false, 5, false), null); + PipelineStateMachine stateMachine = new PipelineStateMachine(); + + context.init(3); + context.setRollback(true); + stateMachine.init(context); + + // START -> INIT. + IScheduleState state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.INIT, ((ComposeState) state).getStates().get(0).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.ROLLBACK, + ((ComposeState) state).getStates().get(1).getScheduleStateType()); + Assert.assertEquals( + ScheduleStateType.EXECUTE_COMPUTE, + ((ComposeState) state).getStates().get(2).getScheduleStateType()); + + // ROLLBACK -> loop (EXECUTE_COMPUTE). + for (int i = 1; i <= 3; i++) { + state = stateMachine.transition(); + context.getNextIterationId(); + Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, state.getScheduleStateType()); + state = stateMachine.transition(); + Assert.assertEquals(null, state); + while (context.hasNextToFinish()) { + context.getNextFinishIterationId(); + } } - @Test - public void testRollback004() { - ClusterMetaStore.init(0, "driver-0", configuration); - PipelineCycleScheduler scheduler = new PipelineCycleScheduler(); - processor.register(scheduler); - - CheckpointSchedulerContext context = - (CheckpointSchedulerContext) CycleSchedulerContextFactory.create(buildMockCycle(false, 5, false), null); - PipelineStateMachine stateMachine = new PipelineStateMachine(); - - context.init(3); - context.setRollback(true); - stateMachine.init(context); - - // START -> INIT. - IScheduleState state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.COMPOSE, state.getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.INIT, - ((ComposeState) state).getStates().get(0).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.ROLLBACK, - ((ComposeState) state).getStates().get(1).getScheduleStateType()); - Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, - ((ComposeState) state).getStates().get(2).getScheduleStateType()); - - // ROLLBACK -> loop (EXECUTE_COMPUTE). - for (int i = 1; i <= 3; i++) { - state = stateMachine.transition(); - context.getNextIterationId(); - Assert.assertEquals(ScheduleStateType.EXECUTE_COMPUTE, state.getScheduleStateType()); - state = stateMachine.transition(); - Assert.assertEquals(null, state); - while (context.hasNextToFinish()) { - context.getNextFinishIterationId(); - } - } - - // EXECUTE_COMPUTE -> ITERATION_FINISH. - state = stateMachine.transition(); - Assert.assertEquals(ScheduleStateType.CLEAN_CYCLE, state.getScheduleStateType()); - - - // CLEAN_CYCLE -> END. - Assert.assertEquals(ScheduleStateType.END, stateMachine.getCurrentState().getScheduleStateType()); - context.setRollback(false); + // EXECUTE_COMPUTE -> ITERATION_FINISH. + state = stateMachine.transition(); + Assert.assertEquals(ScheduleStateType.CLEAN_CYCLE, state.getScheduleStateType()); + + // CLEAN_CYCLE -> END. + Assert.assertEquals( + ScheduleStateType.END, stateMachine.getCurrentState().getScheduleStateType()); + context.setRollback(false); + } + + private ExecutionNodeCycle buildMockCycle( + boolean isIterative, long finishIterationId, boolean prefetch) { + Configuration configuration = new Configuration(); + configuration.put(JOB_UNIQUE_ID, "test-scheduler-context"); + configuration.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); + configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(prefetch)); + ClusterMetaStore.init(0, "driver-0", configuration); + + ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); + vertexGroup.getCycleGroupMeta().setFlyingCount(1); + vertexGroup.getCycleGroupMeta().setIterationCount(finishIterationId); + if (isIterative) { + vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.incremental); + } else { + vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); } + ExecutionVertex vertex = new ExecutionVertex(0, "test"); + vertex.setParallelism(2); + vertexGroup.getVertexMap().put(0, vertex); - private ExecutionNodeCycle buildMockCycle(boolean isIterative, long finishIterationId, - boolean prefetch) { - Configuration configuration = new Configuration(); - configuration.put(JOB_UNIQUE_ID, "test-scheduler-context"); - configuration.put(SYSTEM_STATE_BACKEND_TYPE.getKey(), StoreType.MEMORY.name()); - configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(prefetch)); - ClusterMetaStore.init(0, "driver-0", configuration); - - ExecutionVertexGroup vertexGroup = new ExecutionVertexGroup(1); - vertexGroup.getCycleGroupMeta().setFlyingCount(1); - vertexGroup.getCycleGroupMeta().setIterationCount(finishIterationId); - if (isIterative) { - vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.incremental); - } else { - vertexGroup.getCycleGroupMeta().setGroupType(CycleGroupType.pipelined); - } - ExecutionVertex vertex = new ExecutionVertex(0, "test"); - vertex.setParallelism(2); - vertexGroup.getVertexMap().put(0, vertex); - - return new ExecutionNodeCycle(0, 0, 0, "test", vertexGroup, configuration, "driver_id", 0); - } + return new ExecutionNodeCycle(0, 0, 0, "test", vertexGroup, configuration, "driver_id", 0); + } - private ExecutionNodeCycle buildMockCycle(boolean isIterative, long finishIterationId) { - return buildMockCycle(isIterative, finishIterationId, true); - } - - private ExecutionNodeCycle buildMockCycle(boolean isIterative) { - return buildMockCycle(isIterative, 5); - } + private ExecutionNodeCycle buildMockCycle(boolean isIterative, long finishIterationId) { + return buildMockCycle(isIterative, finishIterationId, true); + } + private ExecutionNodeCycle buildMockCycle(boolean isIterative) { + return buildMockCycle(isIterative, 5); + } } diff --git a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/entrypoint/ContainerRunner.java b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/entrypoint/ContainerRunner.java index d78c18682..5f71e7e7d 100644 --- a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/entrypoint/ContainerRunner.java +++ b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/entrypoint/ContainerRunner.java @@ -39,63 +39,62 @@ import org.slf4j.LoggerFactory; public class ContainerRunner { - private static final Logger LOGGER = LoggerFactory.getLogger(ContainerRunner.class); - private final ContainerContext containerContext; - private IContainer container; + private static final Logger LOGGER = LoggerFactory.getLogger(ContainerRunner.class); + private final ContainerContext containerContext; + private IContainer container; - public ContainerRunner(ContainerContext containerContext) { - this.containerContext = containerContext; - } + public ContainerRunner(ContainerContext containerContext) { + this.containerContext = containerContext; + } - public void run() { - container = new Container(); - containerContext.load(); - container.init(containerContext); - } + public void run() { + container = new Container(); + containerContext.load(); + container.init(containerContext); + } - private void waitForTermination() { - LOGGER.info("wait for service terminating"); - ((AbstractContainer) container).waitTermination(); - } + private void waitForTermination() { + LOGGER.info("wait for service terminating"); + ((AbstractContainer) container).waitTermination(); + } - public void close() { - if (container != null) { - container.close(); - } + public void close() { + if (container != null) { + container.close(); } + } - public static void main(String[] args) throws Exception { - ContainerRunner containerRunner = null; - try { - final long startTime = System.currentTimeMillis(); + public static void main(String[] args) throws Exception { + ContainerRunner containerRunner = null; + try { + final long startTime = System.currentTimeMillis(); - String id = ClusterUtils.getProperty(ClusterConstants.CONTAINER_ID); - String masterId = ClusterUtils.getProperty(MASTER_ID); - LOGGER.info("ResourceID assigned for this container:{} masterId:{}", id, masterId); + String id = ClusterUtils.getProperty(ClusterConstants.CONTAINER_ID); + String masterId = ClusterUtils.getProperty(MASTER_ID); + LOGGER.info("ResourceID assigned for this container:{} masterId:{}", id, masterId); - Configuration config = ClusterUtils.loadConfiguration(); - config.setMasterId(masterId); - String supervisorPort = ClusterUtils.getEnvValue(System.getenv(), ENV_SUPERVISOR_PORT); - config.put(SUPERVISOR_RPC_PORT, supervisorPort); - String agentPort = ClusterUtils.getEnvValue(System.getenv(), ENV_AGENT_PORT); - config.put(AGENT_HTTP_PORT, agentPort); - LOGGER.info("Supervisor rpc port: {} agentPort: {}", supervisorPort, agentPort); + Configuration config = ClusterUtils.loadConfiguration(); + config.setMasterId(masterId); + String supervisorPort = ClusterUtils.getEnvValue(System.getenv(), ENV_SUPERVISOR_PORT); + config.put(SUPERVISOR_RPC_PORT, supervisorPort); + String agentPort = ClusterUtils.getEnvValue(System.getenv(), ENV_AGENT_PORT); + config.put(AGENT_HTTP_PORT, agentPort); + LOGGER.info("Supervisor rpc port: {} agentPort: {}", supervisorPort, agentPort); - new RunnerRuntimeHook(ContainerRunner.class.getSimpleName(), - Integer.parseInt(supervisorPort)).start(); + new RunnerRuntimeHook(ContainerRunner.class.getSimpleName(), Integer.parseInt(supervisorPort)) + .start(); - ContainerContext context = new ContainerContext(Integer.parseInt(id), config); - containerRunner = new ContainerRunner(context); - containerRunner.run(); - LOGGER.info("Completed container init in {}ms", System.currentTimeMillis() - startTime); - containerRunner.waitForTermination(); - } catch (Throwable e) { - LOGGER.error("FATAL: container {} exits", ProcessUtil.getProcessId(), e); - if (containerRunner != null) { - containerRunner.close(); - } - System.exit(EXIT_CODE); - } + ContainerContext context = new ContainerContext(Integer.parseInt(id), config); + containerRunner = new ContainerRunner(context); + containerRunner.run(); + LOGGER.info("Completed container init in {}ms", System.currentTimeMillis() - startTime); + containerRunner.waitForTermination(); + } catch (Throwable e) { + LOGGER.error("FATAL: container {} exits", ProcessUtil.getProcessId(), e); + if (containerRunner != null) { + containerRunner.close(); + } + System.exit(EXIT_CODE); } - + } } diff --git a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/entrypoint/DriverRunner.java b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/entrypoint/DriverRunner.java index 6a7490584..944b4ab8e 100644 --- a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/entrypoint/DriverRunner.java +++ b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/entrypoint/DriverRunner.java @@ -39,67 +39,67 @@ import org.slf4j.LoggerFactory; public class DriverRunner { - private static final Logger LOGGER = LoggerFactory.getLogger(DriverRunner.class); + private static final Logger LOGGER = LoggerFactory.getLogger(DriverRunner.class); - private final DriverContext driverContext; - private Driver driver; + private final DriverContext driverContext; + private Driver driver; - public DriverRunner(DriverContext driverContext) { - this.driverContext = driverContext; - } + public DriverRunner(DriverContext driverContext) { + this.driverContext = driverContext; + } - public void run() { - int rpcPort = driverContext.getConfig().getInteger(DRIVER_RPC_PORT); - driver = new Driver(rpcPort); - driverContext.load(); - driver.init(driverContext); - } + public void run() { + int rpcPort = driverContext.getConfig().getInteger(DRIVER_RPC_PORT); + driver = new Driver(rpcPort); + driverContext.load(); + driver.init(driverContext); + } - private void waitForTermination() { - LOGGER.info("wait for service terminating"); - driver.waitTermination(); - } + private void waitForTermination() { + LOGGER.info("wait for service terminating"); + driver.waitTermination(); + } - public void close() { - if (driver != null) { - driver.close(); - } + public void close() { + if (driver != null) { + driver.close(); } + } - public static void main(String[] args) throws Exception { - DriverRunner driverRunner = null; - try { - final long startTime = System.currentTimeMillis(); + public static void main(String[] args) throws Exception { + DriverRunner driverRunner = null; + try { + final long startTime = System.currentTimeMillis(); - String id = ClusterUtils.getProperty(CONTAINER_ID); - String index = ClusterUtils.getProperty(CONTAINER_INDEX); - String masterId = ClusterUtils.getProperty(ClusterConstants.MASTER_ID); - LOGGER.info("ResourceID assigned for this driver id:{} index:{} masterId:{}", id, index, masterId); + String id = ClusterUtils.getProperty(CONTAINER_ID); + String index = ClusterUtils.getProperty(CONTAINER_INDEX); + String masterId = ClusterUtils.getProperty(ClusterConstants.MASTER_ID); + LOGGER.info( + "ResourceID assigned for this driver id:{} index:{} masterId:{}", id, index, masterId); - Configuration config = ClusterUtils.loadConfiguration(); - config.setMasterId(masterId); - String supervisorPort = ClusterUtils.getEnvValue(System.getenv(), ENV_SUPERVISOR_PORT); - config.put(SUPERVISOR_RPC_PORT, supervisorPort); - String agentPort = ClusterUtils.getEnvValue(System.getenv(), ENV_AGENT_PORT); - config.put(AGENT_HTTP_PORT, agentPort); - LOGGER.info("Supervisor rpc port: {} agentPort: {}", supervisorPort, agentPort); + Configuration config = ClusterUtils.loadConfiguration(); + config.setMasterId(masterId); + String supervisorPort = ClusterUtils.getEnvValue(System.getenv(), ENV_SUPERVISOR_PORT); + config.put(SUPERVISOR_RPC_PORT, supervisorPort); + String agentPort = ClusterUtils.getEnvValue(System.getenv(), ENV_AGENT_PORT); + config.put(AGENT_HTTP_PORT, agentPort); + LOGGER.info("Supervisor rpc port: {} agentPort: {}", supervisorPort, agentPort); - new RunnerRuntimeHook(DriverRunner.class.getSimpleName(), - Integer.parseInt(supervisorPort)).start(); + new RunnerRuntimeHook(DriverRunner.class.getSimpleName(), Integer.parseInt(supervisorPort)) + .start(); - DriverContext context = new DriverContext(Integer.parseInt(id), - Integer.parseInt(index), config); - driverRunner = new DriverRunner(context); - driverRunner.run(); - LOGGER.info("Completed driver init in {} ms", System.currentTimeMillis() - startTime); - driverRunner.waitForTermination(); - } catch (Throwable e) { - LOGGER.error("FATAL: driver {} exits", ProcessUtil.getProcessId(), e); - if (driverRunner != null) { - driverRunner.close(); - } - System.exit(EXIT_CODE); - } + DriverContext context = + new DriverContext(Integer.parseInt(id), Integer.parseInt(index), config); + driverRunner = new DriverRunner(context); + driverRunner.run(); + LOGGER.info("Completed driver init in {} ms", System.currentTimeMillis() - startTime); + driverRunner.waitForTermination(); + } catch (Throwable e) { + LOGGER.error("FATAL: driver {} exits", ProcessUtil.getProcessId(), e); + if (driverRunner != null) { + driverRunner.close(); + } + System.exit(EXIT_CODE); } - + } } diff --git a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/entrypoint/MasterRunner.java b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/entrypoint/MasterRunner.java index 34c8f3bcd..8dce98d12 100644 --- a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/entrypoint/MasterRunner.java +++ b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/entrypoint/MasterRunner.java @@ -31,40 +31,38 @@ public class MasterRunner { - private static final Logger LOGGER = LoggerFactory.getLogger(MasterRunner.class); + private static final Logger LOGGER = LoggerFactory.getLogger(MasterRunner.class); - protected final Configuration config; - protected final AbstractMaster master; + protected final Configuration config; + protected final AbstractMaster master; - public MasterRunner(Configuration config, IClusterManager clusterManager) { - this.config = config; - if (config.getBoolean(ExecutionConfigKeys.ENABLE_MASTER_LEADER_ELECTION)) { - initLeaderElectionService(); - } - - MasterContext context = new MasterContext(config); - context.setClusterManager(clusterManager); - context.load(); - - master = MasterFactory.create(config); - master.init(context); + public MasterRunner(Configuration config, IClusterManager clusterManager) { + this.config = config; + if (config.getBoolean(ExecutionConfigKeys.ENABLE_MASTER_LEADER_ELECTION)) { + initLeaderElectionService(); } - protected void initLeaderElectionService() { - } + MasterContext context = new MasterContext(config); + context.setClusterManager(clusterManager); + context.load(); - public ClusterInfo init() { - try { - return master.startCluster(); - } catch (Throwable e) { - LOGGER.error("init failed", e); - throw e; - } - } + master = MasterFactory.create(config); + master.init(context); + } + + protected void initLeaderElectionService() {} - protected void waitForTermination() { - LOGGER.info("waiting for finishing..."); - master.waitTermination(); + public ClusterInfo init() { + try { + return master.startCluster(); + } catch (Throwable e) { + LOGGER.error("init failed", e); + throw e; } + } + protected void waitForTermination() { + LOGGER.info("waiting for finishing..."); + master.waitTermination(); + } } diff --git a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/failover/AbstractFailoverStrategy.java b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/failover/AbstractFailoverStrategy.java index b031a1b0c..b10dae89f 100644 --- a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/failover/AbstractFailoverStrategy.java +++ b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/failover/AbstractFailoverStrategy.java @@ -32,33 +32,33 @@ public abstract class AbstractFailoverStrategy implements IFailoverStrategy { - protected EnvType envType; - protected AbstractClusterManager clusterManager; - protected boolean enableSupervisor; - protected ClusterContext context; + protected EnvType envType; + protected AbstractClusterManager clusterManager; + protected boolean enableSupervisor; + protected ClusterContext context; - public AbstractFailoverStrategy(EnvType envType) { - this.envType = envType; - } + public AbstractFailoverStrategy(EnvType envType) { + this.envType = envType; + } - @Override - public void init(ClusterContext context) { - this.context = context; - this.enableSupervisor = context.getConfig().getBoolean(SUPERVISOR_ENABLE); - } + @Override + public void init(ClusterContext context) { + this.context = context; + this.enableSupervisor = context.getConfig().getBoolean(SUPERVISOR_ENABLE); + } - protected void reportFailoverEvent(ExceptionLevel level, EventLabel label, String message) { - StatsCollectorFactory.init(context.getConfig()).getEventCollector() - .reportEvent(level, label, message); - } + protected void reportFailoverEvent(ExceptionLevel level, EventLabel label, String message) { + StatsCollectorFactory.init(context.getConfig()) + .getEventCollector() + .reportEvent(level, label, message); + } - public void setClusterManager(IClusterManager clusterManager) { - this.clusterManager = (AbstractClusterManager) clusterManager; - } - - @Override - public EnvType getEnv() { - return envType; - } + public void setClusterManager(IClusterManager clusterManager) { + this.clusterManager = (AbstractClusterManager) clusterManager; + } + @Override + public EnvType getEnv() { + return envType; + } } diff --git a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/failover/AutoRestartPolicy.java b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/failover/AutoRestartPolicy.java index db0dc3834..d1de8797a 100644 --- a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/failover/AutoRestartPolicy.java +++ b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/failover/AutoRestartPolicy.java @@ -21,29 +21,22 @@ public enum AutoRestartPolicy { - /** - * the process will not be autorestarted. - */ - FALSE("false"), + /** the process will not be autorestarted. */ + FALSE("false"), - /** - * the process will always be autorestarted. - */ - TRUE("true"), + /** the process will always be autorestarted. */ + TRUE("true"), - /** - * the process will be autorestarted when exits with unexpected codes. - */ - UNEXPECTED("unexpected"); + /** the process will be autorestarted when exits with unexpected codes. */ + UNEXPECTED("unexpected"); - private final String value; + private final String value; - AutoRestartPolicy(String value) { - this.value = value; - } - - public String getValue() { - return value; - } + AutoRestartPolicy(String value) { + this.value = value; + } + public String getValue() { + return value; + } } diff --git a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/failover/ClusterFailoverStrategy.java b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/failover/ClusterFailoverStrategy.java index 61390f89a..5766ca112 100644 --- a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/failover/ClusterFailoverStrategy.java +++ b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/failover/ClusterFailoverStrategy.java @@ -24,6 +24,7 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.PROCESS_AUTO_RESTART; import java.util.concurrent.atomic.AtomicBoolean; + import org.apache.geaflow.cluster.clustermanager.ClusterContext; import org.apache.geaflow.cluster.failover.FailoverStrategyType; import org.apache.geaflow.cluster.heartbeat.HeartbeatManager; @@ -33,58 +34,55 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * This strategy is to restart the whole cluster by the master once an anomaly is detected. - */ +/** This strategy is to restart the whole cluster by the master once an anomaly is detected. */ public class ClusterFailoverStrategy extends AbstractFailoverStrategy { - private static final Logger LOGGER = LoggerFactory.getLogger(ClusterFailoverStrategy.class); - - protected AtomicBoolean doKilling; - protected HeartbeatManager heartbeatManager; + private static final Logger LOGGER = LoggerFactory.getLogger(ClusterFailoverStrategy.class); - public ClusterFailoverStrategy(EnvType envType) { - super(envType); - } + protected AtomicBoolean doKilling; + protected HeartbeatManager heartbeatManager; - @Override - public void init(ClusterContext context) { - super.init(context); - this.heartbeatManager = context.getHeartbeatManager(); - // Set true if in recovering and reset to false after recovering finished. - this.doKilling = new AtomicBoolean(context.isRecover()); - // Disable worker process auto-restart because master will do that. - context.getConfig().put(PROCESS_AUTO_RESTART, Boolean.FALSE.toString()); - LOGGER.info("init with recovering: {}", context.isRecover()); - } + public ClusterFailoverStrategy(EnvType envType) { + super(envType); + } - @Override - public void doFailover(int componentId, Throwable cause) { - boolean isMasterRestarts = (componentId == DEFAULT_MASTER_ID); - if (isMasterRestarts) { - // Master restart itself when the process is started in recover mode. - final long startTime = System.currentTimeMillis(); - clusterManager.restartAllDrivers(); - clusterManager.restartAllContainers(); - doKilling.set(false); - String finishMessage = String.format("Completed failover in %s ms.", - System.currentTimeMillis() - startTime); - LOGGER.info(finishMessage); - reportFailoverEvent(ExceptionLevel.INFO, EventLabel.FAILOVER_FINISH, finishMessage); - } else if (doKilling.compareAndSet(false, true)) { - String reason = cause == null ? null : cause.getMessage(); - String startMessage = String.format("Start failover due to %s", reason); - LOGGER.info(startMessage); - reportFailoverEvent(ExceptionLevel.INFO, EventLabel.FAILOVER_START, startMessage); - // Close heartbeat check service. - heartbeatManager.close(); - // Trigger process restart. - System.exit(EXIT_CODE); - } - } + @Override + public void init(ClusterContext context) { + super.init(context); + this.heartbeatManager = context.getHeartbeatManager(); + // Set true if in recovering and reset to false after recovering finished. + this.doKilling = new AtomicBoolean(context.isRecover()); + // Disable worker process auto-restart because master will do that. + context.getConfig().put(PROCESS_AUTO_RESTART, Boolean.FALSE.toString()); + LOGGER.info("init with recovering: {}", context.isRecover()); + } - @Override - public FailoverStrategyType getType() { - return FailoverStrategyType.cluster_fo; + @Override + public void doFailover(int componentId, Throwable cause) { + boolean isMasterRestarts = (componentId == DEFAULT_MASTER_ID); + if (isMasterRestarts) { + // Master restart itself when the process is started in recover mode. + final long startTime = System.currentTimeMillis(); + clusterManager.restartAllDrivers(); + clusterManager.restartAllContainers(); + doKilling.set(false); + String finishMessage = + String.format("Completed failover in %s ms.", System.currentTimeMillis() - startTime); + LOGGER.info(finishMessage); + reportFailoverEvent(ExceptionLevel.INFO, EventLabel.FAILOVER_FINISH, finishMessage); + } else if (doKilling.compareAndSet(false, true)) { + String reason = cause == null ? null : cause.getMessage(); + String startMessage = String.format("Start failover due to %s", reason); + LOGGER.info(startMessage); + reportFailoverEvent(ExceptionLevel.INFO, EventLabel.FAILOVER_START, startMessage); + // Close heartbeat check service. + heartbeatManager.close(); + // Trigger process restart. + System.exit(EXIT_CODE); } + } + @Override + public FailoverStrategyType getType() { + return FailoverStrategyType.cluster_fo; + } } diff --git a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/failover/ComponentFailoverStrategy.java b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/failover/ComponentFailoverStrategy.java index e0e87378a..3eed67abc 100644 --- a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/failover/ComponentFailoverStrategy.java +++ b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/failover/ComponentFailoverStrategy.java @@ -27,30 +27,26 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * This strategy is to restart the process by supervisor but not master. - */ +/** This strategy is to restart the process by supervisor but not master. */ public class ComponentFailoverStrategy extends AbstractFailoverStrategy { - private static final Logger LOGGER = LoggerFactory.getLogger(ComponentFailoverStrategy.class); - - public ComponentFailoverStrategy(EnvType envType) { - super(envType); - } - - @Override - public void init(ClusterContext context) { - super.init(context); - context.getConfig().put(PROCESS_AUTO_RESTART, AutoRestartPolicy.UNEXPECTED.getValue()); - LOGGER.info("init with foRestarts: {}", context.getClusterConfig().getMaxRestarts()); - } - - @Override - public void doFailover(int componentId, Throwable cause) { - } - - @Override - public FailoverStrategyType getType() { - return FailoverStrategyType.component_fo; - } - + private static final Logger LOGGER = LoggerFactory.getLogger(ComponentFailoverStrategy.class); + + public ComponentFailoverStrategy(EnvType envType) { + super(envType); + } + + @Override + public void init(ClusterContext context) { + super.init(context); + context.getConfig().put(PROCESS_AUTO_RESTART, AutoRestartPolicy.UNEXPECTED.getValue()); + LOGGER.info("init with foRestarts: {}", context.getClusterConfig().getMaxRestarts()); + } + + @Override + public void doFailover(int componentId, Throwable cause) {} + + @Override + public FailoverStrategyType getType() { + return FailoverStrategyType.component_fo; + } } diff --git a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/failover/DisableFailoverStrategy.java b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/failover/DisableFailoverStrategy.java index 7892762d4..fcc786a9c 100644 --- a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/failover/DisableFailoverStrategy.java +++ b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/failover/DisableFailoverStrategy.java @@ -29,27 +29,27 @@ public class DisableFailoverStrategy extends AbstractFailoverStrategy { - private static final Logger LOGGER = - LoggerFactory.getLogger(DisableFailoverStrategy.class); - - - public DisableFailoverStrategy(EnvType envType) { - super(envType); - } - - @Override - public void init(ClusterContext context) { - context.getConfig().put(PROCESS_AUTO_RESTART, AutoRestartPolicy.FALSE.getValue()); - } - - @Override - public void doFailover(int componentId, Throwable cause) { - LOGGER.info("Failover is disabled, do nothing. Triggered by component #{}: {}.", - componentId, cause == null ? null : cause.getMessage()); - } - - @Override - public FailoverStrategyType getType() { - return FailoverStrategyType.disable_fo; - } + private static final Logger LOGGER = LoggerFactory.getLogger(DisableFailoverStrategy.class); + + public DisableFailoverStrategy(EnvType envType) { + super(envType); + } + + @Override + public void init(ClusterContext context) { + context.getConfig().put(PROCESS_AUTO_RESTART, AutoRestartPolicy.FALSE.getValue()); + } + + @Override + public void doFailover(int componentId, Throwable cause) { + LOGGER.info( + "Failover is disabled, do nothing. Triggered by component #{}: {}.", + componentId, + cause == null ? null : cause.getMessage()); + } + + @Override + public FailoverStrategyType getType() { + return FailoverStrategyType.disable_fo; + } } diff --git a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/manager/GeaFlowClusterManager.java b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/manager/GeaFlowClusterManager.java index 9df23411f..66fe20be9 100644 --- a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/manager/GeaFlowClusterManager.java +++ b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/manager/GeaFlowClusterManager.java @@ -38,6 +38,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.concurrent.Future; + import org.apache.geaflow.cluster.clustermanager.AbstractClusterManager; import org.apache.geaflow.cluster.clustermanager.ClusterContext; import org.apache.geaflow.cluster.failover.FailoverStrategyFactory; @@ -53,133 +54,144 @@ public abstract class GeaFlowClusterManager extends AbstractClusterManager { - private static final Logger LOGGER = LoggerFactory.getLogger(GeaFlowClusterManager.class); - protected final String classpath; - protected boolean enableSupervisor; - protected EnvType envType; - protected String configValue; - protected boolean failFast; - protected String logDir; - - public GeaFlowClusterManager(EnvType envType) { - this.envType = envType; - this.classpath = System.getProperty("java.class.path"); - } - - @Override - public void init(ClusterContext clusterContext) { - super.init(clusterContext); - this.config = clusterContext.getConfig(); - this.configValue = ClusterUtils.convertConfigToString(clusterContext.getConfig()); - this.enableSupervisor = clusterContext.getConfig().getBoolean(SUPERVISOR_ENABLE); - this.logDir = clusterContext.getConfig().getString(LOG_DIR); - } - - @Override - protected IFailoverStrategy buildFailoverStrategy() { - IFailoverStrategy foStrategy = FailoverStrategyFactory.loadFailoverStrategy(envType, - clusterConfig.getConfig().getString(FO_STRATEGY)); - foStrategy.init(clusterContext); - if (foStrategy instanceof AbstractFailoverStrategy) { - ((AbstractFailoverStrategy) foStrategy).setClusterManager(this); - } - return foStrategy; - } - - @Override - public void restartAllDrivers() { - Map driverIds = clusterContext.getDriverIds(); - LOGGER.info("Restart all drivers: {}", driverIds); - if (enableSupervisor) { - restartContainersBySupervisor(driverIds, true); - } else { - for (Map.Entry entry : driverIds.entrySet()) { - restartDriver(entry.getKey()); - } - } - } - - @Override - public void restartAllContainers() { - Map containerIds = clusterContext.getContainerIds(); - LOGGER.info("Restart all containers: {}", containerIds); - if (enableSupervisor) { - restartContainersBySupervisor(containerIds, false); - } else { - for (Map.Entry entry : containerIds.entrySet()) { - restartContainer(entry.getKey()); - } - } + private static final Logger LOGGER = LoggerFactory.getLogger(GeaFlowClusterManager.class); + protected final String classpath; + protected boolean enableSupervisor; + protected EnvType envType; + protected String configValue; + protected boolean failFast; + protected String logDir; + + public GeaFlowClusterManager(EnvType envType) { + this.envType = envType; + this.classpath = System.getProperty("java.class.path"); + } + + @Override + public void init(ClusterContext clusterContext) { + super.init(clusterContext); + this.config = clusterContext.getConfig(); + this.configValue = ClusterUtils.convertConfigToString(clusterContext.getConfig()); + this.enableSupervisor = clusterContext.getConfig().getBoolean(SUPERVISOR_ENABLE); + this.logDir = clusterContext.getConfig().getString(LOG_DIR); + } + + @Override + protected IFailoverStrategy buildFailoverStrategy() { + IFailoverStrategy foStrategy = + FailoverStrategyFactory.loadFailoverStrategy( + envType, clusterConfig.getConfig().getString(FO_STRATEGY)); + foStrategy.init(clusterContext); + if (foStrategy instanceof AbstractFailoverStrategy) { + ((AbstractFailoverStrategy) foStrategy).setClusterManager(this); } - - protected void restartContainersBySupervisor(Map containerIds, - boolean isDriver) { - List futures = new ArrayList<>(); - for (Map.Entry entry : containerIds.entrySet()) { - futures.add( - RpcClient.getInstance().restartWorkerBySupervisor(entry.getValue(), failFast)); - } - Iterator> iterator = containerIds.entrySet().iterator(); - List lostWorkers = new ArrayList<>(); - for (Future future : futures) { - Entry entry = iterator.next(); - try { - future.get(); - } catch (Throwable e) { - LOGGER.warn("catch exception from {}: {} {}", entry.getValue(), - e.getClass().getCanonicalName(), e.getMessage()); - lostWorkers.add(entry.getKey()); - } - } - if (isDriver) { - LOGGER.info("Restart lost drivers: {}", lostWorkers); - for (Integer id : lostWorkers) { - restartDriver(id); - } - } else { - LOGGER.info("Restart lost containers: {}", lostWorkers); - for (Integer id : lostWorkers) { - restartContainer(id); - } - } + return foStrategy; + } + + @Override + public void restartAllDrivers() { + Map driverIds = clusterContext.getDriverIds(); + LOGGER.info("Restart all drivers: {}", driverIds); + if (enableSupervisor) { + restartContainersBySupervisor(driverIds, true); + } else { + for (Map.Entry entry : driverIds.entrySet()) { + restartDriver(entry.getKey()); + } } - - public String getDriverShellCommand(int driverId, int driverIndex, - String logFile) { - Map extraOptions = buildExtraOptions(driverId); - extraOptions.put(CONTAINER_INDEX, String.valueOf(driverIndex)); - - String logFilename = logDir + File.separator + logFile; - return ClusterUtils.getStartCommand(clusterConfig.getDriverJvmOptions(), DriverRunner.class, - logFilename, clusterConfig.getConfig(), extraOptions, classpath, false); + } + + @Override + public void restartAllContainers() { + Map containerIds = clusterContext.getContainerIds(); + LOGGER.info("Restart all containers: {}", containerIds); + if (enableSupervisor) { + restartContainersBySupervisor(containerIds, false); + } else { + for (Map.Entry entry : containerIds.entrySet()) { + restartContainer(entry.getKey()); + } } + } - public String getContainerShellCommand(int containerId, boolean isRecover, - String logFile) { - Map extraOptions = buildExtraOptions(containerId); - extraOptions.put(IS_RECOVER, String.valueOf(isRecover)); - - String logFilename = logDir + File.separator + logFile; - return ClusterUtils.getStartCommand(clusterConfig.getContainerJvmOptions(), - ContainerRunner.class, logFilename, clusterConfig.getConfig(), extraOptions, - classpath, false); + protected void restartContainersBySupervisor( + Map containerIds, boolean isDriver) { + List futures = new ArrayList<>(); + for (Map.Entry entry : containerIds.entrySet()) { + futures.add(RpcClient.getInstance().restartWorkerBySupervisor(entry.getValue(), failFast)); } - - protected Map buildExtraOptions(int containerId) { - Map env = new HashMap<>(); - env.put(MASTER_ID, masterId); - env.put(CONTAINER_ID, String.valueOf(containerId)); - env.put(JOB_CONFIG, configValue); - return env; + Iterator> iterator = containerIds.entrySet().iterator(); + List lostWorkers = new ArrayList<>(); + for (Future future : futures) { + Entry entry = iterator.next(); + try { + future.get(); + } catch (Throwable e) { + LOGGER.warn( + "catch exception from {}: {} {}", + entry.getValue(), + e.getClass().getCanonicalName(), + e.getMessage()); + lostWorkers.add(entry.getKey()); + } } - - protected Map buildSupervisorEnvs(int containerId, String startCommand, - String autoRestart) { - Map env = new HashMap<>(); - env.put(AUTO_RESTART, autoRestart); - env.put(CONTAINER_ID, String.valueOf(containerId)); - env.put(CONTAINER_START_COMMAND, startCommand); - return env; + if (isDriver) { + LOGGER.info("Restart lost drivers: {}", lostWorkers); + for (Integer id : lostWorkers) { + restartDriver(id); + } + } else { + LOGGER.info("Restart lost containers: {}", lostWorkers); + for (Integer id : lostWorkers) { + restartContainer(id); + } } - + } + + public String getDriverShellCommand(int driverId, int driverIndex, String logFile) { + Map extraOptions = buildExtraOptions(driverId); + extraOptions.put(CONTAINER_INDEX, String.valueOf(driverIndex)); + + String logFilename = logDir + File.separator + logFile; + return ClusterUtils.getStartCommand( + clusterConfig.getDriverJvmOptions(), + DriverRunner.class, + logFilename, + clusterConfig.getConfig(), + extraOptions, + classpath, + false); + } + + public String getContainerShellCommand(int containerId, boolean isRecover, String logFile) { + Map extraOptions = buildExtraOptions(containerId); + extraOptions.put(IS_RECOVER, String.valueOf(isRecover)); + + String logFilename = logDir + File.separator + logFile; + return ClusterUtils.getStartCommand( + clusterConfig.getContainerJvmOptions(), + ContainerRunner.class, + logFilename, + clusterConfig.getConfig(), + extraOptions, + classpath, + false); + } + + protected Map buildExtraOptions(int containerId) { + Map env = new HashMap<>(); + env.put(MASTER_ID, masterId); + env.put(CONTAINER_ID, String.valueOf(containerId)); + env.put(JOB_CONFIG, configValue); + return env; + } + + protected Map buildSupervisorEnvs( + int containerId, String startCommand, String autoRestart) { + Map env = new HashMap<>(); + env.put(AUTO_RESTART, autoRestart); + env.put(CONTAINER_ID, String.valueOf(containerId)); + env.put(CONTAINER_START_COMMAND, startCommand); + return env; + } } diff --git a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/util/ClusterUtils.java b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/util/ClusterUtils.java index 1dd4a68a8..631e4535e 100644 --- a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/util/ClusterUtils.java +++ b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/util/ClusterUtils.java @@ -24,12 +24,12 @@ import static org.apache.geaflow.cluster.constants.ClusterConstants.JOB_CONFIG; import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.CONF_DIR; -import com.google.common.base.Preconditions; import java.nio.file.Paths; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.commons.lang.StringEscapeUtils; import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.cluster.config.ClusterJvmOptions; @@ -38,95 +38,102 @@ import org.apache.geaflow.utils.JsonUtils; import org.eclipse.jetty.util.StringUtil; +import com.google.common.base.Preconditions; + public class ClusterUtils { - private static final String PROPERTY_FORMAT = "-D%s=\"%s\""; + private static final String PROPERTY_FORMAT = "-D%s=\"%s\""; - public static String getProperty(String key) { - String value = System.getProperty(key); - if (StringUtil.isEmpty(value)) { - throw new GeaflowRuntimeException(String.format("Jvm property %s not found.", key)); - } - return value; + public static String getProperty(String key) { + String value = System.getProperty(key); + if (StringUtil.isEmpty(value)) { + throw new GeaflowRuntimeException(String.format("Jvm property %s not found.", key)); } - - public static String getEnvValue(Map env, String envKey) { - String value = env.get(envKey); - Preconditions.checkArgument(value != null, "%s is not set", envKey); - return value; + return value; + } + + public static String getEnvValue(Map env, String envKey) { + String value = env.get(envKey); + Preconditions.checkArgument(value != null, "%s is not set", envKey); + return value; + } + + public static Configuration loadConfiguration() { + String content = getProperty(JOB_CONFIG); + return convertStringToConfig(content); + } + + public static String convertConfigToString(Configuration configuration) { + return StringEscapeUtils.escapeJava(JsonUtils.toJsonString(configuration.getConfigMap())); + } + + public static Configuration convertStringToConfig(String content) { + Map map = JsonUtils.parseJson2map(content); + return new Configuration(map); + } + + public static String getStartCommand( + ClusterJvmOptions jvmOpts, + Class mainClass, + String logFilename, + Configuration configuration, + String classpath) { + return getStartCommand(jvmOpts, mainClass, logFilename, configuration, null, classpath, true); + } + + /** + * This method is an adaptation of Flink's. + * org.apache.flink.runtime.clusterframework.BootstrapTools#getTaskManagerShellCommand. + */ + public static String getStartCommand( + ClusterJvmOptions jvmOpts, + Class mainClass, + String logFilename, + Configuration configuration, + Map extraOpts, + String classpath, + boolean needRedirect) { + final Map startCommandValues = new HashMap<>(); + startCommandValues.put("java", "java"); + startCommandValues.put("classpath", "-classpath " + classpath); + startCommandValues.put("class", mainClass.getName()); + + ArrayList params = new ArrayList<>(); + params.add(String.format("-Xms%dm", jvmOpts.getXmsMB())); + params.add(String.format("-Xmx%dm", jvmOpts.getMaxHeapMB())); + if (jvmOpts.getXmnMB() > 0) { + params.add(String.format("-Xmn%dm", jvmOpts.getXmnMB())); } - - public static Configuration loadConfiguration() { - String content = getProperty(JOB_CONFIG); - return convertStringToConfig(content); + if (jvmOpts.getMaxDirectMB() > 0) { + params.add(String.format("-XX:MaxDirectMemorySize=%dm", jvmOpts.getMaxDirectMB())); } - - public static String convertConfigToString(Configuration configuration) { - return StringEscapeUtils.escapeJava(JsonUtils.toJsonString(configuration.getConfigMap())); + startCommandValues.put("jvmmem", StringUtils.join(params, ' ')); + List opts = jvmOpts.getExtraOptions(); + if (extraOpts != null && !extraOpts.isEmpty()) { + opts = new ArrayList<>(jvmOpts.getExtraOptions()); + for (Map.Entry entry : extraOpts.entrySet()) { + opts.add(String.format(PROPERTY_FORMAT, entry.getKey(), entry.getValue())); + } } + startCommandValues.put("jvmopts", StringUtils.join(opts, ' ')); - public static Configuration convertStringToConfig(String content) { - Map map = JsonUtils.parseJson2map(content); - return new Configuration(map); - } + String confDir = configuration.getString(CONF_DIR); + String log4jPath = Paths.get(confDir, CONFIG_FILE_LOG4J_NAME).toString(); + StringBuilder logging = new StringBuilder(); + logging.append("-Dlog.file=").append(logFilename); + logging.append(" -Dlog4j.configuration=file:").append(log4jPath); + startCommandValues.put("logging", logging.toString()); - public static String getStartCommand(ClusterJvmOptions jvmOpts, Class mainClass, - String logFilename, Configuration configuration, - String classpath) { - return getStartCommand(jvmOpts, mainClass, logFilename, configuration, null, classpath, - true); - } + String redirects = needRedirect ? ">> " + logFilename + " 2>&1" : ""; + startCommandValues.put("redirects", redirects); - /** - * This method is an adaptation of Flink's. - * org.apache.flink.runtime.clusterframework.BootstrapTools#getTaskManagerShellCommand. - */ - public static String getStartCommand(ClusterJvmOptions jvmOpts, Class mainClass, - String logFilename, Configuration configuration, - Map extraOpts, String classpath, - boolean needRedirect) { - final Map startCommandValues = new HashMap<>(); - startCommandValues.put("java", "java"); - startCommandValues.put("classpath", "-classpath " + classpath); - startCommandValues.put("class", mainClass.getName()); - - ArrayList params = new ArrayList<>(); - params.add(String.format("-Xms%dm", jvmOpts.getXmsMB())); - params.add(String.format("-Xmx%dm", jvmOpts.getMaxHeapMB())); - if (jvmOpts.getXmnMB() > 0) { - params.add(String.format("-Xmn%dm", jvmOpts.getXmnMB())); - } - if (jvmOpts.getMaxDirectMB() > 0) { - params.add(String.format("-XX:MaxDirectMemorySize=%dm", jvmOpts.getMaxDirectMB())); - } - startCommandValues.put("jvmmem", StringUtils.join(params, ' ')); - List opts = jvmOpts.getExtraOptions(); - if (extraOpts != null && !extraOpts.isEmpty()) { - opts = new ArrayList<>(jvmOpts.getExtraOptions()); - for (Map.Entry entry : extraOpts.entrySet()) { - opts.add(String.format(PROPERTY_FORMAT, entry.getKey(), entry.getValue())); - } - } - startCommandValues.put("jvmopts", StringUtils.join(opts, ' ')); - - String confDir = configuration.getString(CONF_DIR); - String log4jPath = Paths.get(confDir, CONFIG_FILE_LOG4J_NAME).toString(); - StringBuilder logging = new StringBuilder(); - logging.append("-Dlog.file=").append(logFilename); - logging.append(" -Dlog4j.configuration=file:").append(log4jPath); - startCommandValues.put("logging", logging.toString()); - - String redirects = needRedirect ? ">> " + logFilename + " 2>&1" : ""; - startCommandValues.put("redirects", redirects); - - return getStartCommand(CONTAINER_START_COMMAND_TEMPLATE, startCommandValues); - } + return getStartCommand(CONTAINER_START_COMMAND_TEMPLATE, startCommandValues); + } - public static String getStartCommand(String template, Map startCommandValues) { - for (Map.Entry variable : startCommandValues.entrySet()) { - template = template.replace("%" + variable.getKey() + "%", variable.getValue()); - } - return template; + public static String getStartCommand(String template, Map startCommandValues) { + for (Map.Entry variable : startCommandValues.entrySet()) { + template = template.replace("%" + variable.getKey() + "%", variable.getValue()); } - + return template; + } } diff --git a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/util/RunnerRuntimeHook.java b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/util/RunnerRuntimeHook.java index ab54323fd..246bed82f 100644 --- a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/util/RunnerRuntimeHook.java +++ b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/main/java/org/apache/geaflow/cluster/runner/util/RunnerRuntimeHook.java @@ -20,40 +20,41 @@ package org.apache.geaflow.cluster.runner.util; import java.net.Socket; + import org.apache.geaflow.common.utils.ProcessUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class RunnerRuntimeHook extends Thread { - private static final Logger LOGGER = LoggerFactory.getLogger(RunnerRuntimeHook.class); - - private final String name; - private final int port; - - public RunnerRuntimeHook(String name, int port) { - this.name = name; - this.port = port; - } - - @Override - public void run() { - checkLiveness(); - } - - private void checkLiveness() { - String host = ProcessUtil.getHostIp(); - try (Socket socket = new Socket(host, port)) { - LOGGER.info("Created socket to address: {}/{}", host, port); - int c; - while ((c = socket.getInputStream().read()) != -1) { - LOGGER.info("Read message from remote: {}", c); - } - } catch (Throwable e) { - LOGGER.error("Read from supervisor failed", e); - } - int pid = ProcessUtil.getProcessId(); - LOGGER.error("Kill {}(pid:{}) because parent process died", name, pid); - ProcessUtil.killProcess(pid); + private static final Logger LOGGER = LoggerFactory.getLogger(RunnerRuntimeHook.class); + + private final String name; + private final int port; + + public RunnerRuntimeHook(String name, int port) { + this.name = name; + this.port = port; + } + + @Override + public void run() { + checkLiveness(); + } + + private void checkLiveness() { + String host = ProcessUtil.getHostIp(); + try (Socket socket = new Socket(host, port)) { + LOGGER.info("Created socket to address: {}/{}", host, port); + int c; + while ((c = socket.getInputStream().read()) != -1) { + LOGGER.info("Read message from remote: {}", c); + } + } catch (Throwable e) { + LOGGER.error("Read from supervisor failed", e); } + int pid = ProcessUtil.getProcessId(); + LOGGER.error("Kill {}(pid:{}) because parent process died", name, pid); + ProcessUtil.killProcess(pid); + } } diff --git a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/test/java/org/apache/geaflow/cluster/runner/util/ClusterUtilsTest.java b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/test/java/org/apache/geaflow/cluster/runner/util/ClusterUtilsTest.java index c4bed55d0..ea797ac0f 100644 --- a/geaflow/geaflow-deploy/geaflow-cluster-runner/src/test/java/org/apache/geaflow/cluster/runner/util/ClusterUtilsTest.java +++ b/geaflow/geaflow-deploy/geaflow-cluster-runner/src/test/java/org/apache/geaflow/cluster/runner/util/ClusterUtilsTest.java @@ -26,6 +26,7 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.cluster.config.ClusterJvmOptions; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.utils.JsonUtils; @@ -35,49 +36,49 @@ public class ClusterUtilsTest { - @Test - public void testConfiguration() { - Map map = new HashMap<>(); - map.put("k", "v"); - String s = JsonUtils.toJsonString(map); - Configuration configuration = ClusterUtils.convertStringToConfig(s); - Assert.assertNotNull(configuration); - } - - @Test - public void getStartCommand_WithValidInputs_ShouldReturnCorrectCommand() { - ClusterJvmOptions jvmOptions = Mockito.mock(ClusterJvmOptions.class); - Configuration configuration = new Configuration(); + @Test + public void testConfiguration() { + Map map = new HashMap<>(); + map.put("k", "v"); + String s = JsonUtils.toJsonString(map); + Configuration configuration = ClusterUtils.convertStringToConfig(s); + Assert.assertNotNull(configuration); + } - // 准备 - when(jvmOptions.getXmsMB()).thenReturn(1024); - when(jvmOptions.getMaxHeapMB()).thenReturn(2048); - when(jvmOptions.getXmnMB()).thenReturn(512); - when(jvmOptions.getMaxDirectMB()).thenReturn(1024); - when(jvmOptions.getExtraOptions()).thenReturn(new ArrayList<>()); + @Test + public void getStartCommand_WithValidInputs_ShouldReturnCorrectCommand() { + ClusterJvmOptions jvmOptions = Mockito.mock(ClusterJvmOptions.class); + Configuration configuration = new Configuration(); - Map extraOpts = new HashMap<>(); - extraOpts.put("key1", "value1"); - extraOpts.put("key2", "value2"); + // 准备 + when(jvmOptions.getXmsMB()).thenReturn(1024); + when(jvmOptions.getMaxHeapMB()).thenReturn(2048); + when(jvmOptions.getXmnMB()).thenReturn(512); + when(jvmOptions.getMaxDirectMB()).thenReturn(1024); + when(jvmOptions.getExtraOptions()).thenReturn(new ArrayList<>()); - String classpath = "/path/to/classes"; - String logFilename = "log.txt"; - Class mainClass = ClusterUtilsTest.class; + Map extraOpts = new HashMap<>(); + extraOpts.put("key1", "value1"); + extraOpts.put("key2", "value2"); - String command = ClusterUtils.getStartCommand(jvmOptions, mainClass, logFilename, - configuration, extraOpts, classpath, true); + String classpath = "/path/to/classes"; + String logFilename = "log.txt"; + Class mainClass = ClusterUtilsTest.class; - // 验证 - assertTrue(command.contains("-Xms1024m")); - assertTrue(command.contains("-Xmx2048m")); - assertTrue(command.contains("-Xmn512m")); - assertTrue(command.contains("-XX:MaxDirectMemorySize=1024m")); - assertTrue(command.contains("-Dkey1=\"value1\"")); - assertTrue(command.contains("-Dkey2=\"value2\"")); - assertTrue(command.contains("-Dlog.file=log.txt")); - assertTrue(command.contains( - "-Dlog4j.configuration=file:/etc/geaflow/conf/" + CONFIG_FILE_LOG4J_NAME)); - assertTrue(command.contains(">> log.txt 2>&1")); - } + String command = + ClusterUtils.getStartCommand( + jvmOptions, mainClass, logFilename, configuration, extraOpts, classpath, true); + // 验证 + assertTrue(command.contains("-Xms1024m")); + assertTrue(command.contains("-Xmx2048m")); + assertTrue(command.contains("-Xmn512m")); + assertTrue(command.contains("-XX:MaxDirectMemorySize=1024m")); + assertTrue(command.contains("-Dkey1=\"value1\"")); + assertTrue(command.contains("-Dkey2=\"value2\"")); + assertTrue(command.contains("-Dlog.file=log.txt")); + assertTrue( + command.contains("-Dlog4j.configuration=file:/etc/geaflow/conf/" + CONFIG_FILE_LOG4J_NAME)); + assertTrue(command.contains(">> log.txt 2>&1")); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/client/KubernetesClusterClient.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/client/KubernetesClusterClient.java index 450ac0d00..72d592af9 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/client/KubernetesClusterClient.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/client/KubernetesClusterClient.java @@ -28,17 +28,13 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.CLUSTER_ID; import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.JOB_WORK_PATH; -import com.google.common.base.Preconditions; -import io.fabric8.kubernetes.api.model.ConfigMap; -import io.fabric8.kubernetes.api.model.LoadBalancerIngress; -import io.fabric8.kubernetes.api.model.Service; -import io.fabric8.kubernetes.api.model.ServicePort; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; import java.util.concurrent.TimeoutException; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.cluster.client.AbstractClusterClient; import org.apache.geaflow.cluster.client.IPipelineClient; @@ -64,228 +60,241 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; + +import io.fabric8.kubernetes.api.model.ConfigMap; +import io.fabric8.kubernetes.api.model.LoadBalancerIngress; +import io.fabric8.kubernetes.api.model.Service; +import io.fabric8.kubernetes.api.model.ServicePort; + public class KubernetesClusterClient extends AbstractClusterClient { - private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesClusterClient.class); - private static final int DEFAULT_SLEEP_MS = 1000; + private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesClusterClient.class); + private static final int DEFAULT_SLEEP_MS = 1000; - private GeaflowKubeClient kubernetesClient; - private KubernetesClusterManager clusterManager; - private String clusterId; - private int clientTimeoutMs; + private GeaflowKubeClient kubernetesClient; + private KubernetesClusterManager clusterManager; + private String clusterId; + private int clientTimeoutMs; - @Override - public void init(IEnvironmentContext environmentContext) { - super.init(environmentContext); - if (!config.contains(JOB_WORK_PATH)) { - config.put(JOB_WORK_PATH, config.getString(WORK_DIR)); - } - if (!config.contains(CLUSTER_ID)) { - config.put(CLUSTER_ID, K8SConstants.RANDOM_CLUSTER_ID_PREFIX + UUID.randomUUID()); - } - this.clusterId = config.getString(CLUSTER_ID); - String masterUrl = KubernetesConfig.getClientMasterUrl(config); - this.kubernetesClient = new GeaflowKubeClient(config, masterUrl); - this.clusterManager = new KubernetesClusterManager(); - this.clusterManager.init(new ClusterContext(config), kubernetesClient); - this.clientTimeoutMs = KubernetesConfig.getClientTimeoutMs(config); + @Override + public void init(IEnvironmentContext environmentContext) { + super.init(environmentContext); + if (!config.contains(JOB_WORK_PATH)) { + config.put(JOB_WORK_PATH, config.getString(WORK_DIR)); } + if (!config.contains(CLUSTER_ID)) { + config.put(CLUSTER_ID, K8SConstants.RANDOM_CLUSTER_ID_PREFIX + UUID.randomUUID()); + } + this.clusterId = config.getString(CLUSTER_ID); + String masterUrl = KubernetesConfig.getClientMasterUrl(config); + this.kubernetesClient = new GeaflowKubeClient(config, masterUrl); + this.clusterManager = new KubernetesClusterManager(); + this.clusterManager.init(new ClusterContext(config), kubernetesClient); + this.clientTimeoutMs = KubernetesConfig.getClientTimeoutMs(config); + } - @Override - public IPipelineClient startCluster() { - try { - this.clusterId = clusterManager.startMaster().getHandler(); - Map driverAddresses = waitForMasterStarted(clusterId); - ClusterMeta clusterMeta = new ClusterMeta(driverAddresses, - config.getString(MASTER_EXPOSED_ADDRESS)); - callback.onSuccess(clusterMeta); - LOGGER.info("Cluster info: {} config: {}", clusterMeta, config); - String successMsg = String.format("Start cluster success. Cluster info: %s", clusterMeta); - StatsCollectorFactory.init(config).getEventCollector() - .reportEvent(ExceptionLevel.INFO, EventLabel.START_CLUSTER_SUCCESS, successMsg); - return PipelineClientFactory.createPipelineClient(driverAddresses, config); - } catch (Throwable e) { - LOGGER.error("Deploy failed.", e); - callback.onFailure(e); - String failMsg = String.format("Start cluster failed: %s", e.getMessage()); - StatsCollectorFactory.init(config).getEventCollector() - .reportEvent(ExceptionLevel.FATAL, EventLabel.START_CLUSTER_FAILED, failMsg); - kubernetesClient.destroyCluster(clusterId); - throw new GeaflowRuntimeException(e); - } + @Override + public IPipelineClient startCluster() { + try { + this.clusterId = clusterManager.startMaster().getHandler(); + Map driverAddresses = waitForMasterStarted(clusterId); + ClusterMeta clusterMeta = + new ClusterMeta(driverAddresses, config.getString(MASTER_EXPOSED_ADDRESS)); + callback.onSuccess(clusterMeta); + LOGGER.info("Cluster info: {} config: {}", clusterMeta, config); + String successMsg = String.format("Start cluster success. Cluster info: %s", clusterMeta); + StatsCollectorFactory.init(config) + .getEventCollector() + .reportEvent(ExceptionLevel.INFO, EventLabel.START_CLUSTER_SUCCESS, successMsg); + return PipelineClientFactory.createPipelineClient(driverAddresses, config); + } catch (Throwable e) { + LOGGER.error("Deploy failed.", e); + callback.onFailure(e); + String failMsg = String.format("Start cluster failed: %s", e.getMessage()); + StatsCollectorFactory.init(config) + .getEventCollector() + .reportEvent(ExceptionLevel.FATAL, EventLabel.START_CLUSTER_FAILED, failMsg); + kubernetesClient.destroyCluster(clusterId); + throw new GeaflowRuntimeException(e); } + } - private Map waitForMasterStarted(String clusterId) throws TimeoutException { - ClusterInfo clusterInfo = waitForMasterConfigUpdated(clusterId); - DockerNetworkType networkType = KubernetesConfig.getDockerNetworkType(config); - if (networkType != DockerNetworkType.HOST) { - updateServiceAddress(clusterId, clusterInfo); - } - return clusterInfo.getDriverAddresses(); + private Map waitForMasterStarted(String clusterId) + throws TimeoutException { + ClusterInfo clusterInfo = waitForMasterConfigUpdated(clusterId); + DockerNetworkType networkType = KubernetesConfig.getDockerNetworkType(config); + if (networkType != DockerNetworkType.HOST) { + updateServiceAddress(clusterId, clusterInfo); } + return clusterInfo.getDriverAddresses(); + } - private ClusterInfo waitForMasterConfigUpdated(String clusterId) throws TimeoutException { - Map configuration; - final long startTime = System.currentTimeMillis(); - KubernetesMasterParam masterParam = new KubernetesMasterParam(config); - while (true) { - String configName = masterParam.getConfigMapName(clusterId); - ConfigMap configMap = kubernetesClient.getConfigMap(configName); - configuration = KubernetesUtils.loadConfigurationFromString( - configMap.getData().get(K8SConstants.ENV_CONFIG_FILE)); - if (configuration.containsKey(DRIVER_EXPOSED_ADDRESS)) { - break; - } - long elapsedTime = System.currentTimeMillis() - startTime; - if (elapsedTime > 60000) { - LOGGER.warn("Start cluster took more than 60 seconds, please check logs on the " - + "Kubernetes cluster."); - if (elapsedTime > clientTimeoutMs) { - throw new TimeoutException("Waiting cluster ready timeout."); - } - } - SleepUtils.sleepMilliSecond(DEFAULT_SLEEP_MS); + private ClusterInfo waitForMasterConfigUpdated(String clusterId) throws TimeoutException { + Map configuration; + final long startTime = System.currentTimeMillis(); + KubernetesMasterParam masterParam = new KubernetesMasterParam(config); + while (true) { + String configName = masterParam.getConfigMapName(clusterId); + ConfigMap configMap = kubernetesClient.getConfigMap(configName); + configuration = + KubernetesUtils.loadConfigurationFromString( + configMap.getData().get(K8SConstants.ENV_CONFIG_FILE)); + if (configuration.containsKey(DRIVER_EXPOSED_ADDRESS)) { + break; + } + long elapsedTime = System.currentTimeMillis() - startTime; + if (elapsedTime > 60000) { + LOGGER.warn( + "Start cluster took more than 60 seconds, please check logs on the " + + "Kubernetes cluster."); + if (elapsedTime > clientTimeoutMs) { + throw new TimeoutException("Waiting cluster ready timeout."); } - - ClusterInfo clusterInfo = new ClusterInfo(); - String driverAddress = configuration.get(DRIVER_EXPOSED_ADDRESS); - clusterInfo.setDriverAddresses(KubernetesUtils.decodeRpcAddressMap(driverAddress)); - return clusterInfo; + } + SleepUtils.sleepMilliSecond(DEFAULT_SLEEP_MS); } - @Override - public void shutdown() { - Preconditions.checkNotNull(clusterId, "ClusterId is null."); - kubernetesClient.destroyCluster(clusterId); - } + ClusterInfo clusterInfo = new ClusterInfo(); + String driverAddress = configuration.get(DRIVER_EXPOSED_ADDRESS); + clusterInfo.setDriverAddresses(KubernetesUtils.decodeRpcAddressMap(driverAddress)); + return clusterInfo; + } - private void updateServiceAddress(String clusterId, ClusterInfo clusterInfo) - throws TimeoutException { - ServiceExposedType serviceType = KubernetesConfig.getServiceExposedType(config); - String masterServiceName; - if (serviceType == CLUSTER_IP) { - masterServiceName = KubernetesUtils.getMasterServiceName(clusterId); - } else { - masterServiceName = KubernetesUtils.getMasterClientServiceName(clusterId); - } - String serviceAddress = setupExposedServiceAddress(serviceType, masterServiceName, - K8SConstants.HTTP_PORT); - config.put(MASTER_EXPOSED_ADDRESS, serviceAddress); + @Override + public void shutdown() { + Preconditions.checkNotNull(clusterId, "ClusterId is null."); + kubernetesClient.destroyCluster(clusterId); + } - int driverIndex = 0; - Map driverAddresses = new HashMap<>(); - Map originalDriverAddresses = clusterInfo.getDriverAddresses(); - for (String driverId : originalDriverAddresses.keySet()) { - String driverServiceName = KubernetesUtils.getDriverServiceName(clusterId, driverIndex); - ConnectAddress rpcAddress = ConnectAddress.build(setupExposedServiceAddress(serviceType, - driverServiceName, K8SConstants.RPC_PORT)); - driverAddresses.put(driverId, rpcAddress); - driverIndex++; - } - config.put(DRIVER_EXPOSED_ADDRESS, KubernetesUtils.encodeRpcAddressMap(driverAddresses)); - clusterInfo.setDriverAddresses(driverAddresses); + private void updateServiceAddress(String clusterId, ClusterInfo clusterInfo) + throws TimeoutException { + ServiceExposedType serviceType = KubernetesConfig.getServiceExposedType(config); + String masterServiceName; + if (serviceType == CLUSTER_IP) { + masterServiceName = KubernetesUtils.getMasterServiceName(clusterId); + } else { + masterServiceName = KubernetesUtils.getMasterClientServiceName(clusterId); } + String serviceAddress = + setupExposedServiceAddress(serviceType, masterServiceName, K8SConstants.HTTP_PORT); + config.put(MASTER_EXPOSED_ADDRESS, serviceAddress); - private String setupExposedServiceAddress(ServiceExposedType serviceType, String serviceName, - String portName) throws TimeoutException { - String serviceAddress = "localhost"; - switch (serviceType) { - case CLUSTER_IP: - serviceAddress = resolveServiceExposedByClusterIp(serviceName, portName); - break; - case NODE_PORT: - serviceAddress = resolveServiceExposedByNodePort(serviceName, portName); - break; - case LOAD_BALANCER: - serviceAddress = resolveServiceExposedByLoadBalancer(serviceName); - break; - default: - break; - } + int driverIndex = 0; + Map driverAddresses = new HashMap<>(); + Map originalDriverAddresses = clusterInfo.getDriverAddresses(); + for (String driverId : originalDriverAddresses.keySet()) { + String driverServiceName = KubernetesUtils.getDriverServiceName(clusterId, driverIndex); + ConnectAddress rpcAddress = + ConnectAddress.build( + setupExposedServiceAddress(serviceType, driverServiceName, K8SConstants.RPC_PORT)); + driverAddresses.put(driverId, rpcAddress); + driverIndex++; + } + config.put(DRIVER_EXPOSED_ADDRESS, KubernetesUtils.encodeRpcAddressMap(driverAddresses)); + clusterInfo.setDriverAddresses(driverAddresses); + } - LOGGER.info("Service {} exposed: {}.", serviceName, serviceAddress); - return serviceAddress; + private String setupExposedServiceAddress( + ServiceExposedType serviceType, String serviceName, String portName) throws TimeoutException { + String serviceAddress = "localhost"; + switch (serviceType) { + case CLUSTER_IP: + serviceAddress = resolveServiceExposedByClusterIp(serviceName, portName); + break; + case NODE_PORT: + serviceAddress = resolveServiceExposedByNodePort(serviceName, portName); + break; + case LOAD_BALANCER: + serviceAddress = resolveServiceExposedByLoadBalancer(serviceName); + break; + default: + break; } - private String resolveServiceExposedByClusterIp(String serviceName, String portName) - throws TimeoutException { - String namespace = config.getString(NAME_SPACE); - String serviceSuffix = config.getString(SERVICE_SUFFIX); - String serviceAddress = serviceName + K8SConstants.NAMESPACE_SEPARATOR + namespace; - if (!StringUtils.isBlank(serviceSuffix)) { - serviceAddress += K8SConstants.NAMESPACE_SEPARATOR + serviceSuffix; - } + LOGGER.info("Service {} exposed: {}.", serviceName, serviceAddress); + return serviceAddress; + } - LOGGER.info("Waiting for service {} to be exposed by cluster ip.", serviceName); - Service service = getService(serviceName); + private String resolveServiceExposedByClusterIp(String serviceName, String portName) + throws TimeoutException { + String namespace = config.getString(NAME_SPACE); + String serviceSuffix = config.getString(SERVICE_SUFFIX); + String serviceAddress = serviceName + K8SConstants.NAMESPACE_SEPARATOR + namespace; + if (!StringUtils.isBlank(serviceSuffix)) { + serviceAddress += K8SConstants.NAMESPACE_SEPARATOR + serviceSuffix; + } - int port = 0; - for (ServicePort servicePort : service.getSpec().getPorts()) { - if (servicePort.getName().equals(portName)) { - port = servicePort.getPort(); - break; - } - } - return new ConnectAddress(serviceAddress, port).toString(); + LOGGER.info("Waiting for service {} to be exposed by cluster ip.", serviceName); + Service service = getService(serviceName); + + int port = 0; + for (ServicePort servicePort : service.getSpec().getPorts()) { + if (servicePort.getName().equals(portName)) { + port = servicePort.getPort(); + break; + } } + return new ConnectAddress(serviceAddress, port).toString(); + } - private String resolveServiceExposedByNodePort(String serviceName, String portName) - throws TimeoutException { - LOGGER.info("Waiting for service {} to be exposed by node port.", serviceName); - Service service = getService(serviceName); + private String resolveServiceExposedByNodePort(String serviceName, String portName) + throws TimeoutException { + LOGGER.info("Waiting for service {} to be exposed by node port.", serviceName); + Service service = getService(serviceName); - int nodePort = 0; - for (ServicePort servicePort : service.getSpec().getPorts()) { - if (servicePort.getName().equals(portName)) { - nodePort = servicePort.getNodePort(); - break; - } - } - return new ConnectAddress(kubernetesClient.getKubernetesMasterHost(), nodePort).toString(); + int nodePort = 0; + for (ServicePort servicePort : service.getSpec().getPorts()) { + if (servicePort.getName().equals(portName)) { + nodePort = servicePort.getNodePort(); + break; + } } + return new ConnectAddress(kubernetesClient.getKubernetesMasterHost(), nodePort).toString(); + } - private String resolveServiceExposedByLoadBalancer(String serviceName) { - LOGGER.info("Waiting for service {} to be exposed by load balancer.", serviceName); - List ipList = new ArrayList<>(); - List ingressList; - final long startTime = System.currentTimeMillis(); - Service service; - while (true) { - service = kubernetesClient.getService(serviceName); - if (service != null) { - ingressList = service.getStatus().getLoadBalancer().getIngress(); - if (ingressList != null && ingressList.size() > 0) { - for (LoadBalancerIngress ingress : ingressList) { - ipList.add(ingress.getIp()); - } - break; - } - } - if (System.currentTimeMillis() - startTime > 60000) { - LOGGER.warn("Expose service took more than 60s, please check logs on the " - + "Kubernetes cluster."); - } - SleepUtils.sleepMilliSecond(DEFAULT_SLEEP_MS); + private String resolveServiceExposedByLoadBalancer(String serviceName) { + LOGGER.info("Waiting for service {} to be exposed by load balancer.", serviceName); + List ipList = new ArrayList<>(); + List ingressList; + final long startTime = System.currentTimeMillis(); + Service service; + while (true) { + service = kubernetesClient.getService(serviceName); + if (service != null) { + ingressList = service.getStatus().getLoadBalancer().getIngress(); + if (ingressList != null && ingressList.size() > 0) { + for (LoadBalancerIngress ingress : ingressList) { + ipList.add(ingress.getIp()); + } + break; } - return ipList.get(0); + } + if (System.currentTimeMillis() - startTime > 60000) { + LOGGER.warn( + "Expose service took more than 60s, please check logs on the " + "Kubernetes cluster."); + } + SleepUtils.sleepMilliSecond(DEFAULT_SLEEP_MS); } + return ipList.get(0); + } - private Service getService(String serviceName) throws TimeoutException { - Service service = kubernetesClient.getService(serviceName); - final long startTime = System.currentTimeMillis(); - while (service == null) { - long elapsedTime = System.currentTimeMillis() - startTime; - if (elapsedTime > 60000) { - LOGGER.warn("Get service {} took more than 60s, please check logs on Kubernetes" - + " cluster.", serviceName); - if (elapsedTime > clientTimeoutMs) { - throw new TimeoutException("Resolve service " + serviceName + " timeout."); - } - } - SleepUtils.sleepMilliSecond(DEFAULT_SLEEP_MS); - service = kubernetesClient.getService(serviceName); + private Service getService(String serviceName) throws TimeoutException { + Service service = kubernetesClient.getService(serviceName); + final long startTime = System.currentTimeMillis(); + while (service == null) { + long elapsedTime = System.currentTimeMillis() - startTime; + if (elapsedTime > 60000) { + LOGGER.warn( + "Get service {} took more than 60s, please check logs on Kubernetes" + " cluster.", + serviceName); + if (elapsedTime > clientTimeoutMs) { + throw new TimeoutException("Resolve service " + serviceName + " timeout."); } - return service; + } + SleepUtils.sleepMilliSecond(DEFAULT_SLEEP_MS); + service = kubernetesClient.getService(serviceName); } - + return service; + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/client/KubernetesEnvironment.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/client/KubernetesEnvironment.java index 0dbb1325b..88373675a 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/client/KubernetesEnvironment.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/client/KubernetesEnvironment.java @@ -27,18 +27,18 @@ public class KubernetesEnvironment extends AbstractEnvironment { - public KubernetesEnvironment() { - this.context.getConfig().put(LOG_DIR, "/home/admin/logs/geaflow"); - this.context.getConfig().put(SUPERVISOR_ENABLE, Boolean.TRUE.toString()); - } + public KubernetesEnvironment() { + this.context.getConfig().put(LOG_DIR, "/home/admin/logs/geaflow"); + this.context.getConfig().put(SUPERVISOR_ENABLE, Boolean.TRUE.toString()); + } - @Override - protected IClusterClient getClusterClient() { - return new KubernetesClusterClient(); - } + @Override + protected IClusterClient getClusterClient() { + return new KubernetesClusterClient(); + } - @Override - public EnvType getEnvType() { - return EnvType.K8S; - } + @Override + public EnvType getEnvType() { + return EnvType.K8S; + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/client/KubernetesJobClient.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/client/KubernetesJobClient.java index 8e8237270..8b6ef4bbd 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/client/KubernetesJobClient.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/client/KubernetesJobClient.java @@ -21,17 +21,10 @@ import static org.apache.geaflow.cluster.k8s.config.K8SConstants.CLIENT_NAME_SUFFIX; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; -import io.fabric8.kubernetes.api.model.ConfigMap; -import io.fabric8.kubernetes.api.model.Container; -import io.fabric8.kubernetes.api.model.OwnerReference; -import io.fabric8.kubernetes.api.model.OwnerReferenceBuilder; -import io.fabric8.kubernetes.api.model.Pod; -import io.fabric8.kubernetes.api.model.Service; import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.cluster.k8s.clustermanager.GeaflowKubeClient; import org.apache.geaflow.cluster.k8s.clustermanager.KubernetesResourceBuilder; @@ -46,100 +39,113 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; + +import io.fabric8.kubernetes.api.model.ConfigMap; +import io.fabric8.kubernetes.api.model.Container; +import io.fabric8.kubernetes.api.model.OwnerReference; +import io.fabric8.kubernetes.api.model.OwnerReferenceBuilder; +import io.fabric8.kubernetes.api.model.Pod; +import io.fabric8.kubernetes.api.model.Service; + /** - * Utility to submit a job in cluster mode. - * The config contains: clusterId, mainClass, clientArgs, and k8s related config. + * Utility to submit a job in cluster mode. The config contains: clusterId, mainClass, clientArgs, + * and k8s related config. */ public class KubernetesJobClient { - private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesJobClient.class); - - private final Configuration configuration; - private final GeaflowKubeClient geaflowKubeClient; - private final KubernetesClientParam clientParam; - private final String clusterId; - - public KubernetesJobClient(Map config, String masterUrl) { - this(config, new GeaflowKubeClient(config, masterUrl)); - } - - @VisibleForTesting - public KubernetesJobClient(Map config, GeaflowKubeClient client) { - this.configuration = new Configuration(config); - this.clientParam = new KubernetesClientParam(configuration); - this.geaflowKubeClient = client; - this.clusterId = configuration.getString(ExecutionConfigKeys.CLUSTER_ID); - Preconditions.checkArgument(StringUtils.isNotBlank(clusterId), - "ClusterId is not set: " + ExecutionConfigKeys.CLUSTER_ID); - } - - public void submitJob() { - try { - DockerNetworkType dockerNetworkType = KubernetesConfig - .getDockerNetworkType(configuration); - - // create configMap. - ConfigMap configMap = createConfigMap(clusterId); - - // create container - String podName = clusterId + CLIENT_NAME_SUFFIX; - Container container = KubernetesResourceBuilder - .createContainer(podName, podName, null, - clientParam, clientParam.getContainerShellCommand(), clientParam.getAdditionEnvs(), - dockerNetworkType); - - // create ownerReference - OwnerReference ownerReference = createOwnerReference(configMap); - - // create pod - Pod pod = KubernetesResourceBuilder - .createPod(clusterId, podName, podName, ownerReference, configMap, clientParam, - container); - geaflowKubeClient.createPod(pod); - } catch (Exception e) { - LOGGER.error("Failed to create client pod:{}", e.getMessage(), e); - throw new GeaflowRuntimeException(e); - } - } - - /** - * Setup a Config Map that will generate a geaflow-conf.yaml and log4j file. - * - * @param clusterId the cluster id - * @return the created configMap - */ - private ConfigMap createConfigMap(String clusterId) { - ConfigMap configMap = KubernetesResourceBuilder.createConfigMap(clusterId, clientParam, - null); - return geaflowKubeClient.createConfigMap(configMap); + private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesJobClient.class); + + private final Configuration configuration; + private final GeaflowKubeClient geaflowKubeClient; + private final KubernetesClientParam clientParam; + private final String clusterId; + + public KubernetesJobClient(Map config, String masterUrl) { + this(config, new GeaflowKubeClient(config, masterUrl)); + } + + @VisibleForTesting + public KubernetesJobClient(Map config, GeaflowKubeClient client) { + this.configuration = new Configuration(config); + this.clientParam = new KubernetesClientParam(configuration); + this.geaflowKubeClient = client; + this.clusterId = configuration.getString(ExecutionConfigKeys.CLUSTER_ID); + Preconditions.checkArgument( + StringUtils.isNotBlank(clusterId), + "ClusterId is not set: " + ExecutionConfigKeys.CLUSTER_ID); + } + + public void submitJob() { + try { + DockerNetworkType dockerNetworkType = KubernetesConfig.getDockerNetworkType(configuration); + + // create configMap. + ConfigMap configMap = createConfigMap(clusterId); + + // create container + String podName = clusterId + CLIENT_NAME_SUFFIX; + Container container = + KubernetesResourceBuilder.createContainer( + podName, + podName, + null, + clientParam, + clientParam.getContainerShellCommand(), + clientParam.getAdditionEnvs(), + dockerNetworkType); + + // create ownerReference + OwnerReference ownerReference = createOwnerReference(configMap); + + // create pod + Pod pod = + KubernetesResourceBuilder.createPod( + clusterId, podName, podName, ownerReference, configMap, clientParam, container); + geaflowKubeClient.createPod(pod); + } catch (Exception e) { + LOGGER.error("Failed to create client pod:{}", e.getMessage(), e); + throw new GeaflowRuntimeException(e); } - - private OwnerReference createOwnerReference(ConfigMap configMap) { - Preconditions.checkNotNull(configMap, "configMap could not be null"); - return new OwnerReferenceBuilder() - .withName(configMap.getMetadata().getName()) - .withApiVersion(configMap.getApiVersion()) - .withUid(configMap.getMetadata().getUid()) - .withKind(configMap.getKind()) - .withController(true) - .build(); - } - - public void stopJob() { - String clientConfigMap = clientParam.getConfigMapName(clusterId); - geaflowKubeClient.deleteConfigMap(clientConfigMap); - geaflowKubeClient.destroyCluster(clusterId); - } - - public Service getMasterService() { - String masterServiceName = KubernetesUtils.getMasterServiceName(clusterId); - return geaflowKubeClient.getService(masterServiceName); - } - - public List getJobPods() { - Map podLabels = new HashMap<>(); - podLabels.put(K8SConstants.LABEL_APP_KEY, clusterId); - return geaflowKubeClient.getPods(podLabels).getItems(); - } - + } + + /** + * Setup a Config Map that will generate a geaflow-conf.yaml and log4j file. + * + * @param clusterId the cluster id + * @return the created configMap + */ + private ConfigMap createConfigMap(String clusterId) { + ConfigMap configMap = KubernetesResourceBuilder.createConfigMap(clusterId, clientParam, null); + return geaflowKubeClient.createConfigMap(configMap); + } + + private OwnerReference createOwnerReference(ConfigMap configMap) { + Preconditions.checkNotNull(configMap, "configMap could not be null"); + return new OwnerReferenceBuilder() + .withName(configMap.getMetadata().getName()) + .withApiVersion(configMap.getApiVersion()) + .withUid(configMap.getMetadata().getUid()) + .withKind(configMap.getKind()) + .withController(true) + .build(); + } + + public void stopJob() { + String clientConfigMap = clientParam.getConfigMapName(clusterId); + geaflowKubeClient.deleteConfigMap(clientConfigMap); + geaflowKubeClient.destroyCluster(clusterId); + } + + public Service getMasterService() { + String masterServiceName = KubernetesUtils.getMasterServiceName(clusterId); + return geaflowKubeClient.getService(masterServiceName); + } + + public List getJobPods() { + Map podLabels = new HashMap<>(); + podLabels.put(K8SConstants.LABEL_APP_KEY, clusterId); + return geaflowKubeClient.getPods(podLabels).getItems(); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/client/KubernetesJobSubmitter.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/client/KubernetesJobSubmitter.java index 6401dd2b3..d3b690aa2 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/client/KubernetesJobSubmitter.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/client/KubernetesJobSubmitter.java @@ -22,10 +22,9 @@ import static org.apache.geaflow.cluster.constants.ClusterConstants.CLUSTER_TYPE; import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.CLUSTER_ID; -import com.alibaba.fastjson.JSON; -import com.google.common.base.Preconditions; import java.lang.reflect.InvocationTargetException; import java.util.Map; + import org.apache.commons.lang3.StringEscapeUtils; import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.cluster.k8s.clustermanager.GeaflowKubeClient; @@ -37,85 +36,84 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Utility to submit a job via shell scripts. - */ +import com.alibaba.fastjson.JSON; +import com.google.common.base.Preconditions; + +/** Utility to submit a job via shell scripts. */ public class KubernetesJobSubmitter { - private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesJobSubmitter.class); - private static final String START_ACTION = "start"; - private static final String STOP_ACTION = "stop"; + private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesJobSubmitter.class); + private static final String START_ACTION = "start"; + private static final String STOP_ACTION = "stop"; - public static void main(String[] args) throws Throwable { - if (args.length < 1) { - throw new IllegalArgumentException("usage: start/stop [mainClassName] [args]"); - } - String action = args[0]; - KubernetesJobSubmitter submitter = new KubernetesJobSubmitter(); - if (action.equalsIgnoreCase(START_ACTION)) { - submitter.submitJob(args); - } else if (action.equalsIgnoreCase(STOP_ACTION)) { - submitter.stopJob(args); - } else { - throw new IllegalArgumentException("unknown action:" + action); - } + public static void main(String[] args) throws Throwable { + if (args.length < 1) { + throw new IllegalArgumentException("usage: start/stop [mainClassName] [args]"); } + String action = args[0]; + KubernetesJobSubmitter submitter = new KubernetesJobSubmitter(); + if (action.equalsIgnoreCase(START_ACTION)) { + submitter.submitJob(args); + } else if (action.equalsIgnoreCase(STOP_ACTION)) { + submitter.stopJob(args); + } else { + throw new IllegalArgumentException("unknown action:" + action); + } + } - public void submitJob(String[] args) throws Throwable { - if (args.length < 2) { - throw new IllegalArgumentException("usage: start mainClassName [args]"); - } - try { - String driverArgs; - String className = args[1]; - if (args.length > 2) { - driverArgs = args[2]; - } else { - Configuration config = KubernetesUtils.loadConfigurationFromFile(); - driverArgs = StringEscapeUtils.escapeJava(JSON.toJSONString(config.getConfigMap())); - } - LOGGER.info("{} driverArgs: {}", className, driverArgs); - - Class clazz = Class.forName(className); - System.setProperty(CLUSTER_TYPE, EnvType.K8S.name()); - clazz.getMethod("main", String[].class).invoke(null, (Object) new String[]{driverArgs}); - } catch (Throwable e) { - if (e instanceof InvocationTargetException && e.getCause() != null) { - e = e.getCause(); - } - LOGGER.error("launch main failed", e); - throw e; - } + public void submitJob(String[] args) throws Throwable { + if (args.length < 2) { + throw new IllegalArgumentException("usage: start mainClassName [args]"); } + try { + String driverArgs; + String className = args[1]; + if (args.length > 2) { + driverArgs = args[2]; + } else { + Configuration config = KubernetesUtils.loadConfigurationFromFile(); + driverArgs = StringEscapeUtils.escapeJava(JSON.toJSONString(config.getConfigMap())); + } + LOGGER.info("{} driverArgs: {}", className, driverArgs); - public void stopJob(String[] args) throws Throwable { - Configuration configuration; - GeaflowKubeClient client = null; - try { - if (args.length > 1) { - EnvironmentArgumentParser parser = new EnvironmentArgumentParser(); - Map config = parser.parse(new String[]{args[1]}); - configuration = new Configuration(config); - } else { - configuration = KubernetesUtils.loadConfigurationFromFile(); - } - String masterUrl = KubernetesConfig.getClientMasterUrl(configuration); - client = new GeaflowKubeClient(configuration, masterUrl); - String clusterId = configuration.getString(CLUSTER_ID); - Preconditions.checkArgument(StringUtils.isNotEmpty(clusterId), "clusterId is not set"); - LOGGER.info("stop job with cluster id:{}", clusterId); - client.destroyCluster(clusterId); - } catch (Throwable e) { - if (e instanceof InvocationTargetException && e.getCause() != null) { - e = e.getCause(); - } - LOGGER.error("stop job failed", e); - throw e; - } finally { - if (client != null) { - client.close(); - } - } + Class clazz = Class.forName(className); + System.setProperty(CLUSTER_TYPE, EnvType.K8S.name()); + clazz.getMethod("main", String[].class).invoke(null, (Object) new String[] {driverArgs}); + } catch (Throwable e) { + if (e instanceof InvocationTargetException && e.getCause() != null) { + e = e.getCause(); + } + LOGGER.error("launch main failed", e); + throw e; } + } + public void stopJob(String[] args) throws Throwable { + Configuration configuration; + GeaflowKubeClient client = null; + try { + if (args.length > 1) { + EnvironmentArgumentParser parser = new EnvironmentArgumentParser(); + Map config = parser.parse(new String[] {args[1]}); + configuration = new Configuration(config); + } else { + configuration = KubernetesUtils.loadConfigurationFromFile(); + } + String masterUrl = KubernetesConfig.getClientMasterUrl(configuration); + client = new GeaflowKubeClient(configuration, masterUrl); + String clusterId = configuration.getString(CLUSTER_ID); + Preconditions.checkArgument(StringUtils.isNotEmpty(clusterId), "clusterId is not set"); + LOGGER.info("stop job with cluster id:{}", clusterId); + client.destroyCluster(clusterId); + } catch (Throwable e) { + if (e instanceof InvocationTargetException && e.getCause() != null) { + e = e.getCause(); + } + LOGGER.error("stop job failed", e); + throw e; + } finally { + if (client != null) { + client.close(); + } + } + } } - diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/clustermanager/GeaflowKubeClient.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/clustermanager/GeaflowKubeClient.java index d5d987d49..99a2e3e36 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/clustermanager/GeaflowKubeClient.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/clustermanager/GeaflowKubeClient.java @@ -19,6 +19,22 @@ package org.apache.geaflow.cluster.k8s.clustermanager; +import java.io.Serializable; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import org.apache.geaflow.cluster.k8s.config.KubernetesConfig; +import org.apache.geaflow.cluster.k8s.config.KubernetesMasterParam; +import org.apache.geaflow.cluster.k8s.utils.KubernetesUtils; +import org.apache.geaflow.common.config.Configuration; +import org.apache.geaflow.common.utils.RetryCommand; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import io.fabric8.kubernetes.api.model.ConfigMap; import io.fabric8.kubernetes.api.model.HasMetadata; import io.fabric8.kubernetes.api.model.Pod; @@ -34,195 +50,197 @@ import io.fabric8.kubernetes.client.dsl.PodResource; import io.fabric8.kubernetes.client.extended.leaderelection.LeaderElectionConfig; import io.fabric8.kubernetes.client.extended.leaderelection.LeaderElector; -import java.io.Serializable; -import java.util.List; -import java.util.Map; -import java.util.concurrent.Callable; -import java.util.concurrent.ExecutorService; -import java.util.function.BiConsumer; -import java.util.function.Consumer; -import org.apache.geaflow.cluster.k8s.config.KubernetesConfig; -import org.apache.geaflow.cluster.k8s.config.KubernetesMasterParam; -import org.apache.geaflow.cluster.k8s.utils.KubernetesUtils; -import org.apache.geaflow.common.config.Configuration; -import org.apache.geaflow.common.utils.RetryCommand; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -/** - * Geaflow Kubernetes client to interact with kubernetes api server. - */ +/** Geaflow Kubernetes client to interact with kubernetes api server. */ public class GeaflowKubeClient implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(GeaflowKubeClient.class); + private static final Logger LOGGER = LoggerFactory.getLogger(GeaflowKubeClient.class); - private final KubernetesMasterParam masterParam; - private final KubernetesClient kubernetesClient; + private final KubernetesMasterParam masterParam; + private final KubernetesClient kubernetesClient; - private final int retryCount; - private final long retryInterval; + private final int retryCount; + private final long retryInterval; - public GeaflowKubeClient(Configuration config) { - this(KubernetesClientFactory.create(config), config); - } + public GeaflowKubeClient(Configuration config) { + this(KubernetesClientFactory.create(config), config); + } - public GeaflowKubeClient(Map config, String masterUrl) { - this(new Configuration(config), masterUrl); - } + public GeaflowKubeClient(Map config, String masterUrl) { + this(new Configuration(config), masterUrl); + } - public GeaflowKubeClient(Configuration config, String masterUrl) { - this(KubernetesClientFactory.create(config, masterUrl), config); - } + public GeaflowKubeClient(Configuration config, String masterUrl) { + this(KubernetesClientFactory.create(config, masterUrl), config); + } - public GeaflowKubeClient(KubernetesClient client, Configuration config) { - this.kubernetesClient = client; - this.masterParam = new KubernetesMasterParam(config); - this.retryCount = KubernetesConfig.getConnectionRetryTimes(config); - this.retryInterval = KubernetesConfig.getConnectionRetryIntervalMs(config); - } + public GeaflowKubeClient(KubernetesClient client, Configuration config) { + this.kubernetesClient = client; + this.masterParam = new KubernetesMasterParam(config); + this.retryCount = KubernetesConfig.getConnectionRetryTimes(config); + this.retryInterval = KubernetesConfig.getConnectionRetryIntervalMs(config); + } - public String getKubernetesMasterHost() { - return kubernetesClient.getMasterUrl().getHost(); - } + public String getKubernetesMasterHost() { + return kubernetesClient.getMasterUrl().getHost(); + } - public ConfigMap createOrReplaceConfigMap(ConfigMap configMap) { - return runWithRetries(() -> kubernetesClient.configMaps().createOrReplace(configMap)); - } + public ConfigMap createOrReplaceConfigMap(ConfigMap configMap) { + return runWithRetries(() -> kubernetesClient.configMaps().createOrReplace(configMap)); + } - public ConfigMap updateConfigMap(ConfigMap configMap) { - return runWithRetries(() -> kubernetesClient.resource(configMap).lockResourceVersion().replace()); - } + public ConfigMap updateConfigMap(ConfigMap configMap) { + return runWithRetries( + () -> kubernetesClient.resource(configMap).lockResourceVersion().replace()); + } - public Service getService(String serviceName) { - return runWithRetries((() -> kubernetesClient.services().withName(serviceName).get())); - } + public Service getService(String serviceName) { + return runWithRetries((() -> kubernetesClient.services().withName(serviceName).get())); + } - public PodList getPods(Map labels) { - return runWithRetries((() -> kubernetesClient.pods().withLabels(labels).list())); - } + public PodList getPods(Map labels) { + return runWithRetries((() -> kubernetesClient.pods().withLabels(labels).list())); + } - public ConfigMap getConfigMap(String configmapName) { - return runWithRetries((() -> kubernetesClient.configMaps().withName(configmapName).get())); - } + public ConfigMap getConfigMap(String configmapName) { + return runWithRetries((() -> kubernetesClient.configMaps().withName(configmapName).get())); + } - public Deployment getDeployment(String name) { - return runWithRetries((() -> kubernetesClient.apps().deployments().withName(name).get())); - } + public Deployment getDeployment(String name) { + return runWithRetries((() -> kubernetesClient.apps().deployments().withName(name).get())); + } - public Service createService(Service service) { - Callable action = () -> { - LOGGER.info("create service: {}", service.getMetadata().getName()); - return kubernetesClient.services().create(service); + public Service createService(Service service) { + Callable action = + () -> { + LOGGER.info("create service: {}", service.getMetadata().getName()); + return kubernetesClient.services().create(service); }; - return runWithRetries(action); - } - - public void createPod(Pod pod) { - Callable action = () -> { - LOGGER.info("create pod: {}", pod.getMetadata().getName()); - kubernetesClient.pods().create(pod); - return null; + return runWithRetries(action); + } + + public void createPod(Pod pod) { + Callable action = + () -> { + LOGGER.info("create pod: {}", pod.getMetadata().getName()); + kubernetesClient.pods().create(pod); + return null; }; - runWithRetries(action); - } - - public ConfigMap createConfigMap(ConfigMap configMap) { - Callable action = () -> { - LOGGER.info("create configmap: {}", configMap.getMetadata().getName()); - return kubernetesClient.configMaps().create(configMap); + runWithRetries(action); + } + + public ConfigMap createConfigMap(ConfigMap configMap) { + Callable action = + () -> { + LOGGER.info("create configmap: {}", configMap.getMetadata().getName()); + return kubernetesClient.configMaps().create(configMap); }; - return runWithRetries(action); - } - - public void createDeployment(Deployment deployment) { - Callable action = () -> { - kubernetesClient.apps().deployments().create(deployment); - LOGGER.info("create deployment: {}", deployment.getMetadata().getName()); - return null; + return runWithRetries(action); + } + + public void createDeployment(Deployment deployment) { + Callable action = + () -> { + kubernetesClient.apps().deployments().create(deployment); + LOGGER.info("create deployment: {}", deployment.getMetadata().getName()); + return null; }; - runWithRetries(action); - } - - public Watch createPodsWatcher(Map labels, - BiConsumer eventHandler, - Consumer closeHandler) { - Callable action = () -> { - Watcher watcher = createWatcher(eventHandler, closeHandler); - PodList podList = kubernetesClient.pods().withLabels(labels).list(); - String resourceVersion = podList.getMetadata().getResourceVersion(); - LOGGER.info("create watcher for {} pods with resource version: {} labels: {}", podList.getItems().size(), resourceVersion, labels); - return kubernetesClient.pods().withLabels(labels).withResourceVersion(resourceVersion).watch(watcher); + runWithRetries(action); + } + + public Watch createPodsWatcher( + Map labels, + BiConsumer eventHandler, + Consumer closeHandler) { + Callable action = + () -> { + Watcher watcher = createWatcher(eventHandler, closeHandler); + PodList podList = kubernetesClient.pods().withLabels(labels).list(); + String resourceVersion = podList.getMetadata().getResourceVersion(); + LOGGER.info( + "create watcher for {} pods with resource version: {} labels: {}", + podList.getItems().size(), + resourceVersion, + labels); + return kubernetesClient + .pods() + .withLabels(labels) + .withResourceVersion(resourceVersion) + .watch(watcher); }; - return runWithRetries(action); - } - - public Watch createServiceWatcher(String serviceName, - BiConsumer eventHandler, - Consumer closeHandler) { - Callable action = () -> { - Watcher watcher = createWatcher(eventHandler, closeHandler); - LOGGER.info("create watcher for service with name: {}", serviceName); - return kubernetesClient.services().withName(serviceName).watch(watcher); + return runWithRetries(action); + } + + public Watch createServiceWatcher( + String serviceName, + BiConsumer eventHandler, + Consumer closeHandler) { + Callable action = + () -> { + Watcher watcher = createWatcher(eventHandler, closeHandler); + LOGGER.info("create watcher for service with name: {}", serviceName); + return kubernetesClient.services().withName(serviceName).watch(watcher); }; - return runWithRetries(action); - } - - private Watcher createWatcher(BiConsumer eventHandler, - Consumer closeHandler) { - Watcher watcher = new Watcher() { - @Override - public void eventReceived(Action action, R resource) { - eventHandler.accept(action, resource); - } - - @Override - public void onClose(WatcherException e) { - if (e != null) { - LOGGER.warn("Watcher onClose: {}", e.getMessage()); - } - closeHandler.accept(e); + return runWithRetries(action); + } + + private Watcher createWatcher( + BiConsumer eventHandler, Consumer closeHandler) { + Watcher watcher = + new Watcher() { + @Override + public void eventReceived(Action action, R resource) { + eventHandler.accept(action, resource); + } + + @Override + public void onClose(WatcherException e) { + if (e != null) { + LOGGER.warn("Watcher onClose: {}", e.getMessage()); } + closeHandler.accept(e); + } }; - return watcher; - } - - public LeaderElector createLeaderElector(LeaderElectionConfig config, ExecutorService service) { - return new LeaderElector(kubernetesClient, config, service); - } - - private T runWithRetries(Callable action) { - return RetryCommand.run(action, retryCount, retryInterval); - } - - public void destroyCluster(String clusterId) { - String serviceName = KubernetesUtils.getMasterServiceName(clusterId); - LOGGER.info("delete cluster with service:{}", serviceName); - kubernetesClient.services().withName(serviceName).delete(); - } - - public void deleteConfigMap(String configMapName) { - LOGGER.info("delete configMap:{}", configMapName); - kubernetesClient.configMaps().withName(configMapName).delete(); - } - - public void deletePod(Map labels) { - Callable action = () -> { - FilterWatchListDeletable running = kubernetesClient.pods().withLabels(labels); - List pods = running.list().getItems(); - if (!pods.isEmpty()) { - LOGGER.info("delete {} running pod with label:{}", pods.size(), - labels); - pods.forEach(pod -> kubernetesClient.resource(pod).delete()); - } - return null; + return watcher; + } + + public LeaderElector createLeaderElector(LeaderElectionConfig config, ExecutorService service) { + return new LeaderElector(kubernetesClient, config, service); + } + + private T runWithRetries(Callable action) { + return RetryCommand.run(action, retryCount, retryInterval); + } + + public void destroyCluster(String clusterId) { + String serviceName = KubernetesUtils.getMasterServiceName(clusterId); + LOGGER.info("delete cluster with service:{}", serviceName); + kubernetesClient.services().withName(serviceName).delete(); + } + + public void deleteConfigMap(String configMapName) { + LOGGER.info("delete configMap:{}", configMapName); + kubernetesClient.configMaps().withName(configMapName).delete(); + } + + public void deletePod(Map labels) { + Callable action = + () -> { + FilterWatchListDeletable running = + kubernetesClient.pods().withLabels(labels); + List pods = running.list().getItems(); + if (!pods.isEmpty()) { + LOGGER.info("delete {} running pod with label:{}", pods.size(), labels); + pods.forEach(pod -> kubernetesClient.resource(pod).delete()); + } + return null; }; - runWithRetries(action); - } + runWithRetries(action); + } - public void close() { - if (kubernetesClient != null) { - kubernetesClient.close(); - } + public void close() { + if (kubernetesClient != null) { + kubernetesClient.close(); } - + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/clustermanager/KubernetesClientFactory.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/clustermanager/KubernetesClientFactory.java index 1797c7210..c418e8ef4 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/clustermanager/KubernetesClientFactory.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/clustermanager/KubernetesClientFactory.java @@ -27,57 +27,53 @@ import static org.apache.geaflow.cluster.k8s.config.KubernetesConfigKeys.NAME_SPACE; import static org.apache.geaflow.cluster.k8s.config.KubernetesConfigKeys.PING_INTERVAL_MS; +import org.apache.commons.lang3.StringUtils; +import org.apache.geaflow.common.config.Configuration; + import io.fabric8.kubernetes.client.ConfigBuilder; import io.fabric8.kubernetes.client.KubernetesClient; import io.fabric8.kubernetes.client.KubernetesClientBuilder; -import org.apache.commons.lang3.StringUtils; -import org.apache.geaflow.common.config.Configuration; -/** - * Builder for Kubernetes clients. - */ +/** Builder for Kubernetes clients. */ public class KubernetesClientFactory { - public static KubernetesClient create(Configuration config) { - String masterUrl = config.getString(MASTER_URL); - return create(config, masterUrl); - } + public static KubernetesClient create(Configuration config) { + String masterUrl = config.getString(MASTER_URL); + return create(config, masterUrl); + } - public static KubernetesClient create(Configuration config, String masterUrl) { - String namespace = config.getString(NAME_SPACE); - long pingInterval = config.getLong(PING_INTERVAL_MS); + public static KubernetesClient create(Configuration config, String masterUrl) { + String namespace = config.getString(NAME_SPACE); + long pingInterval = config.getLong(PING_INTERVAL_MS); - ConfigBuilder clientConfig = new ConfigBuilder() + ConfigBuilder clientConfig = + new ConfigBuilder() .withApiVersion("v1") .withMasterUrl(masterUrl) .withWebsocketPingInterval(pingInterval) .withNamespace(namespace); - clientConfig.withTrustCerts(true); - String certKey = config.getString(CERT_KEY); - if (!StringUtils.isBlank(certKey)) { - clientConfig.withClientKeyData(certKey); - } - - String certData = config.getString(CERT_DATA); - if (!StringUtils.isBlank(certData)) { - clientConfig.withClientCertData(certData); - } + clientConfig.withTrustCerts(true); + String certKey = config.getString(CERT_KEY); + if (!StringUtils.isBlank(certKey)) { + clientConfig.withClientKeyData(certKey); + } - String caData = config.getString(CA_DATA); - if (!StringUtils.isBlank(caData)) { - clientConfig.withCaCertData(certData); - } + String certData = config.getString(CERT_DATA); + if (!StringUtils.isBlank(certData)) { + clientConfig.withClientCertData(certData); + } - String certKeyAlgo = config.getString(CLIENT_KEY_ALGO); - if (!StringUtils.isBlank(certKeyAlgo)) { - clientConfig.withClientKeyAlgo(certKeyAlgo); - } + String caData = config.getString(CA_DATA); + if (!StringUtils.isBlank(caData)) { + clientConfig.withCaCertData(certData); + } - return new KubernetesClientBuilder() - .withConfig(clientConfig.build()) - .build(); + String certKeyAlgo = config.getString(CLIENT_KEY_ALGO); + if (!StringUtils.isBlank(certKeyAlgo)) { + clientConfig.withClientKeyAlgo(certKeyAlgo); } + return new KubernetesClientBuilder().withConfig(clientConfig.build()).build(); + } } - diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/clustermanager/KubernetesClusterId.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/clustermanager/KubernetesClusterId.java index fc7d4267a..6eb62dec2 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/clustermanager/KubernetesClusterId.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/clustermanager/KubernetesClusterId.java @@ -23,14 +23,13 @@ public class KubernetesClusterId implements ClusterId { - private final String clusterId; + private final String clusterId; - public KubernetesClusterId(String clusterId) { - this.clusterId = clusterId; - } - - public String getHandler() { - return clusterId; - } + public KubernetesClusterId(String clusterId) { + this.clusterId = clusterId; + } + public String getHandler() { + return clusterId; + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/clustermanager/KubernetesClusterManager.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/clustermanager/KubernetesClusterManager.java index f33917c66..f2ee57285 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/clustermanager/KubernetesClusterManager.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/clustermanager/KubernetesClusterManager.java @@ -31,17 +31,11 @@ import static org.apache.geaflow.cluster.k8s.config.KubernetesConfigKeys.SERVICE_SUFFIX; import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.CLUSTER_ID; -import com.google.common.annotations.VisibleForTesting; -import io.fabric8.kubernetes.api.model.ConfigMap; -import io.fabric8.kubernetes.api.model.Container; -import io.fabric8.kubernetes.api.model.OwnerReference; -import io.fabric8.kubernetes.api.model.Pod; -import io.fabric8.kubernetes.api.model.Service; -import io.fabric8.kubernetes.api.model.apps.Deployment; import java.io.File; import java.util.Collections; import java.util.HashMap; import java.util.Map; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.cluster.clustermanager.ClusterContext; import org.apache.geaflow.cluster.k8s.config.AbstractKubernetesParam; @@ -61,271 +55,331 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class KubernetesClusterManager extends GeaFlowClusterManager { - - private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesClusterManager.class); - - private String clusterId; - private OwnerReference ownerReference; - private ConfigMap containerConfigMap; - - private KubernetesContainerParam containerParam; - private KubernetesDriverParam driverParam; - private KubernetesMasterParam masterParam; - private String containerPodNamePrefix; - private String driverPodNamePrefix; - private DockerNetworkType dockerNetworkType; - private ServiceExposedType serviceExposedType; - private GeaflowKubeClient kubernetesClient; - +import com.google.common.annotations.VisibleForTesting; - public KubernetesClusterManager() { - super(EnvType.K8S); - } +import io.fabric8.kubernetes.api.model.ConfigMap; +import io.fabric8.kubernetes.api.model.Container; +import io.fabric8.kubernetes.api.model.OwnerReference; +import io.fabric8.kubernetes.api.model.Pod; +import io.fabric8.kubernetes.api.model.Service; +import io.fabric8.kubernetes.api.model.apps.Deployment; - @Override - public void init(ClusterContext context) { - init(context, new GeaflowKubeClient(context.getConfig())); - } +public class KubernetesClusterManager extends GeaFlowClusterManager { - public void init(ClusterContext context, GeaflowKubeClient client) { - super.init(context); - - this.failFast = true; - this.kubernetesClient = client; - this.dockerNetworkType = KubernetesConfig.getDockerNetworkType(config); - this.serviceExposedType = KubernetesConfig.getServiceExposedType(config); - this.masterParam = new KubernetesMasterParam(clusterConfig); - this.clusterId = config.getString(CLUSTER_ID); - - if (config.contains(KubernetesConfig.CLUSTER_START_TIME)) { - this.containerParam = new KubernetesContainerParam(clusterConfig); - this.driverParam = new KubernetesDriverParam(clusterConfig); - this.containerPodNamePrefix = containerParam.getPodNamePrefix(clusterId); - this.driverPodNamePrefix = driverParam.getPodNamePrefix(clusterId); - setupOwnerReference(); - setupConfigMap(); - } + private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesClusterManager.class); + + private String clusterId; + private OwnerReference ownerReference; + private ConfigMap containerConfigMap; + + private KubernetesContainerParam containerParam; + private KubernetesDriverParam driverParam; + private KubernetesMasterParam masterParam; + private String containerPodNamePrefix; + private String driverPodNamePrefix; + private DockerNetworkType dockerNetworkType; + private ServiceExposedType serviceExposedType; + private GeaflowKubeClient kubernetesClient; + + public KubernetesClusterManager() { + super(EnvType.K8S); + } + + @Override + public void init(ClusterContext context) { + init(context, new GeaflowKubeClient(context.getConfig())); + } + + public void init(ClusterContext context, GeaflowKubeClient client) { + super.init(context); + + this.failFast = true; + this.kubernetesClient = client; + this.dockerNetworkType = KubernetesConfig.getDockerNetworkType(config); + this.serviceExposedType = KubernetesConfig.getServiceExposedType(config); + this.masterParam = new KubernetesMasterParam(clusterConfig); + this.clusterId = config.getString(CLUSTER_ID); + + if (config.contains(KubernetesConfig.CLUSTER_START_TIME)) { + this.containerParam = new KubernetesContainerParam(clusterConfig); + this.driverParam = new KubernetesDriverParam(clusterConfig); + this.containerPodNamePrefix = containerParam.getPodNamePrefix(clusterId); + this.driverPodNamePrefix = driverParam.getPodNamePrefix(clusterId); + setupOwnerReference(); + setupConfigMap(); } - - @Override - public KubernetesClusterId startMaster() { - Map labels = masterParam.getPodLabels(clusterId); - createMaster(clusterId, labels); - clusterInfo = new KubernetesClusterId(clusterId); - return (KubernetesClusterId) clusterInfo; + } + + @Override + public KubernetesClusterId startMaster() { + Map labels = masterParam.getPodLabels(clusterId); + createMaster(clusterId, labels); + clusterInfo = new KubernetesClusterId(clusterId); + return (KubernetesClusterId) clusterInfo; + } + + private void setupOwnerReference() { + try { + String serviceName = KubernetesUtils.getMasterServiceName(clusterId); + Service service = kubernetesClient.getService(serviceName); + if (service != null) { + ownerReference = KubernetesResourceBuilder.createOwnerReference(service); + } else { + throw new RuntimeException("Failed to get service: " + serviceName); + } + } catch (Exception e) { + throw new RuntimeException("Could not setup owner reference.", e); } - - private void setupOwnerReference() { - try { - String serviceName = KubernetesUtils.getMasterServiceName(clusterId); - Service service = kubernetesClient.getService(serviceName); - if (service != null) { - ownerReference = KubernetesResourceBuilder.createOwnerReference(service); - } else { - throw new RuntimeException("Failed to get service: " + serviceName); - } - } catch (Exception e) { - throw new RuntimeException("Could not setup owner reference.", e); - } + } + + private void setupConfigMap() { + try { + ConfigMap configMap = + KubernetesResourceBuilder.createConfigMap(clusterId, containerParam, ownerReference); + kubernetesClient.createOrReplaceConfigMap(configMap); + containerConfigMap = configMap; + } catch (Exception e) { + throw new RuntimeException("Could not upload container config map.", e); } - - private void setupConfigMap() { - try { - ConfigMap configMap = KubernetesResourceBuilder.createConfigMap(clusterId, - containerParam, ownerReference); - kubernetesClient.createOrReplaceConfigMap(configMap); - containerConfigMap = configMap; - } catch (Exception e) { - throw new RuntimeException("Could not upload container config map.", e); - } + } + + @VisibleForTesting + public void createMaster(String clusterId, Map labels) { + this.clusterId = clusterId; + this.config.put(CLUSTER_ID, clusterId); + this.config.put( + KubernetesConfig.CLUSTER_START_TIME, String.valueOf(System.currentTimeMillis())); + + this.dockerNetworkType = KubernetesConfig.getDockerNetworkType(config); + // Host network only supports clusterIp + if (dockerNetworkType == DockerNetworkType.HOST) { + config.put(SERVICE_EXPOSED_TYPE, ServiceExposedType.CLUSTER_IP.name()); } - - @VisibleForTesting - public void createMaster(String clusterId, Map labels) { - this.clusterId = clusterId; - this.config.put(CLUSTER_ID, clusterId); - this.config.put(KubernetesConfig.CLUSTER_START_TIME, - String.valueOf(System.currentTimeMillis())); - - this.dockerNetworkType = KubernetesConfig.getDockerNetworkType(config); - // Host network only supports clusterIp - if (dockerNetworkType == DockerNetworkType.HOST) { - config.put(SERVICE_EXPOSED_TYPE, ServiceExposedType.CLUSTER_IP.name()); - } - this.serviceExposedType = KubernetesConfig.getServiceExposedType(config); - this.configValue = ClusterUtils.convertConfigToString(this.config); - - // 1. create configMap. - ConfigMap configMap = createMasterConfigMap(clusterId, dockerNetworkType); - - // 2. create the master container - Container container = createMasterContainer(dockerNetworkType); - - // 3. create replication controller. - String masterDeployName = clusterId + K8SConstants.MASTER_RS_NAME_SUFFIX; - Deployment deployment = KubernetesResourceBuilder.createDeployment(clusterId, - masterDeployName, String.valueOf(DEFAULT_MASTER_ID), container, configMap, masterParam, + this.serviceExposedType = KubernetesConfig.getServiceExposedType(config); + this.configValue = ClusterUtils.convertConfigToString(this.config); + + // 1. create configMap. + ConfigMap configMap = createMasterConfigMap(clusterId, dockerNetworkType); + + // 2. create the master container + Container container = createMasterContainer(dockerNetworkType); + + // 3. create replication controller. + String masterDeployName = clusterId + K8SConstants.MASTER_RS_NAME_SUFFIX; + Deployment deployment = + KubernetesResourceBuilder.createDeployment( + clusterId, + masterDeployName, + String.valueOf(DEFAULT_MASTER_ID), + container, + configMap, + masterParam, dockerNetworkType); - // 3. create the service. - String serviceName = KubernetesUtils.getMasterServiceName(clusterId); - Service service = createService(serviceName, - ServiceExposedType.CLUSTER_IP, labels, null, masterParam); - OwnerReference ownerReference = KubernetesResourceBuilder.createOwnerReference(service); - - if (!serviceExposedType.equals(ServiceExposedType.CLUSTER_IP)) { - serviceName = KubernetesUtils.getMasterClientServiceName(clusterId); - createService(serviceName, ServiceExposedType.NODE_PORT, labels, ownerReference, - masterParam); - } + // 3. create the service. + String serviceName = KubernetesUtils.getMasterServiceName(clusterId); + Service service = + createService(serviceName, ServiceExposedType.CLUSTER_IP, labels, null, masterParam); + OwnerReference ownerReference = KubernetesResourceBuilder.createOwnerReference(service); - // 4. set owner reference. - deployment.getMetadata().setOwnerReferences(Collections.singletonList(ownerReference)); - configMap.getMetadata().setOwnerReferences(Collections.singletonList(ownerReference)); - - kubernetesClient.createConfigMap(configMap); - kubernetesClient.createDeployment(deployment); + if (!serviceExposedType.equals(ServiceExposedType.CLUSTER_IP)) { + serviceName = KubernetesUtils.getMasterClientServiceName(clusterId); + createService(serviceName, ServiceExposedType.NODE_PORT, labels, ownerReference, masterParam); } - @VisibleForTesting - public Container createMasterContainer(DockerNetworkType networkType) { - String command = masterParam.getContainerShellCommand(); - LOGGER.info("master start command: {}", command); - Map additionalEnvs = masterParam.getAdditionEnvs(); - - String containerName = masterParam.getContainerName(); - return KubernetesResourceBuilder.createContainer(containerName, String.valueOf(DEFAULT_MASTER_ID), masterId, - masterParam, command, additionalEnvs, networkType); + // 4. set owner reference. + deployment.getMetadata().setOwnerReferences(Collections.singletonList(ownerReference)); + configMap.getMetadata().setOwnerReferences(Collections.singletonList(ownerReference)); + + kubernetesClient.createConfigMap(configMap); + kubernetesClient.createDeployment(deployment); + } + + @VisibleForTesting + public Container createMasterContainer(DockerNetworkType networkType) { + String command = masterParam.getContainerShellCommand(); + LOGGER.info("master start command: {}", command); + Map additionalEnvs = masterParam.getAdditionEnvs(); + + String containerName = masterParam.getContainerName(); + return KubernetesResourceBuilder.createContainer( + containerName, + String.valueOf(DEFAULT_MASTER_ID), + masterId, + masterParam, + command, + additionalEnvs, + networkType); + } + + /** + * Set up a Config Map that will generate a geaflow-conf.yaml and log4j file. + * + * @param clusterId the cluster id + * @return the created configMap + */ + public ConfigMap createMasterConfigMap(String clusterId, DockerNetworkType dockerNetworkType) { + if (dockerNetworkType != DockerNetworkType.HOST) { + // use serviceName to discover master + String namespace = config.getString(NAME_SPACE); + String serviceSuffix = config.getString(SERVICE_SUFFIX); + serviceSuffix = + StringUtils.isBlank(serviceSuffix) + ? "" + : K8SConstants.NAMESPACE_SEPARATOR + serviceSuffix; + config.put( + MASTER_ADDRESS, + clusterId + + K8SConstants.SERVICE_NAME_SUFFIX + + K8SConstants.NAMESPACE_SEPARATOR + + namespace + + serviceSuffix); } - /** - * Set up a Config Map that will generate a geaflow-conf.yaml and log4j file. - * - * @param clusterId the cluster id - * @return the created configMap - */ - public ConfigMap createMasterConfigMap(String clusterId, DockerNetworkType dockerNetworkType) { - if (dockerNetworkType != DockerNetworkType.HOST) { - // use serviceName to discover master - String namespace = config.getString(NAME_SPACE); - String serviceSuffix = config.getString(SERVICE_SUFFIX); - serviceSuffix = StringUtils.isBlank(serviceSuffix) ? "" - : K8SConstants.NAMESPACE_SEPARATOR - + serviceSuffix; - config.put(MASTER_ADDRESS, - clusterId + K8SConstants.SERVICE_NAME_SUFFIX + K8SConstants.NAMESPACE_SEPARATOR - + namespace + serviceSuffix); - } - - return KubernetesResourceBuilder.createConfigMap(clusterId, masterParam, null); + return KubernetesResourceBuilder.createConfigMap(clusterId, masterParam, null); + } + + private Service createService( + String serviceName, + ServiceExposedType exposedType, + Map labels, + OwnerReference ownerReference, + AbstractKubernetesParam param) { + Service service = + KubernetesResourceBuilder.createService( + serviceName, exposedType, labels, ownerReference, param); + return kubernetesClient.createService(service); + } + + @Override + public void createNewContainer(int containerId, boolean isRecover) { + try { + // Create container. + String containerStartCommand = + getContainerShellCommand(containerId, isRecover, CONTAINER_LOG_SUFFIX); + Map additionalEnvs = containerParam.getAdditionEnvs(); + additionalEnvs.put(ENV_IS_RECOVER, String.valueOf(isRecover)); + additionalEnvs.put(CONTAINER_START_COMMAND, containerStartCommand); + + String podName = containerPodNamePrefix + containerId; + String startCommand = buildSupervisorStartCommand(CONTAINER_LOG_SUFFIX); + Container container = + KubernetesResourceBuilder.createContainer( + podName, + String.valueOf(containerId), + masterId, + containerParam, + startCommand, + additionalEnvs, + dockerNetworkType); + + // Create pod. + Pod containerPod = + KubernetesResourceBuilder.createPod( + clusterId, + podName, + String.valueOf(containerId), + ownerReference, + containerConfigMap, + containerParam, + container); + kubernetesClient.createPod(containerPod); + } catch (Exception e) { + LOGGER.error("Failed to request new container pod:{}", e.getMessage(), e); + throw new GeaflowRuntimeException(e); } - - private Service createService(String serviceName, ServiceExposedType exposedType, - Map labels, OwnerReference ownerReference, - AbstractKubernetesParam param) { - Service service = KubernetesResourceBuilder.createService(serviceName, exposedType, labels, - ownerReference, param); - return kubernetesClient.createService(service); - } - - @Override - public void createNewContainer(int containerId, boolean isRecover) { - try { - // Create container. - String containerStartCommand = getContainerShellCommand(containerId, isRecover, - CONTAINER_LOG_SUFFIX); - Map additionalEnvs = containerParam.getAdditionEnvs(); - additionalEnvs.put(ENV_IS_RECOVER, String.valueOf(isRecover)); - additionalEnvs.put(CONTAINER_START_COMMAND, containerStartCommand); - - String podName = containerPodNamePrefix + containerId; - String startCommand = buildSupervisorStartCommand(CONTAINER_LOG_SUFFIX); - Container container = KubernetesResourceBuilder.createContainer(podName, - String.valueOf(containerId), masterId, containerParam, startCommand, additionalEnvs, - dockerNetworkType); - - // Create pod. - Pod containerPod = KubernetesResourceBuilder.createPod(clusterId, podName, - String.valueOf(containerId), ownerReference, containerConfigMap, containerParam, - container); - kubernetesClient.createPod(containerPod); - } catch (Exception e) { - LOGGER.error("Failed to request new container pod:{}", e.getMessage(), e); - throw new GeaflowRuntimeException(e); - } + } + + @Override + public void createNewDriver(int driverId, int driverIndex) { + String serviceName = KubernetesUtils.getDriverServiceName(clusterId, driverIndex); + Service service = kubernetesClient.getService(serviceName); + if (service != null) { + LOGGER.info("driver service {} already exists, skip starting driver", serviceName); + return; } - @Override - public void createNewDriver(int driverId, int driverIndex) { - String serviceName = KubernetesUtils.getDriverServiceName(clusterId, driverIndex); - Service service = kubernetesClient.getService(serviceName); - if (service != null) { - LOGGER.info("driver service {} already exists, skip starting driver", serviceName); - return; - } - - // 1. Create container. - String driverStartCommand = getDriverShellCommand(driverId, driverIndex, DRIVER_LOG_SUFFIX); - Map additionalEnvs = driverParam.getAdditionEnvs(); - additionalEnvs.put(K8SConstants.ENV_CONTAINER_INDEX, String.valueOf(driverIndex)); - additionalEnvs.put(CONTAINER_START_COMMAND, driverStartCommand); - - String podName = driverPodNamePrefix + driverId; - String startCommand = buildSupervisorStartCommand(DRIVER_LOG_SUFFIX); - Container container = KubernetesResourceBuilder.createContainer(podName, - String.valueOf(driverId), masterId, driverParam, startCommand, additionalEnvs, + // 1. Create container. + String driverStartCommand = getDriverShellCommand(driverId, driverIndex, DRIVER_LOG_SUFFIX); + Map additionalEnvs = driverParam.getAdditionEnvs(); + additionalEnvs.put(K8SConstants.ENV_CONTAINER_INDEX, String.valueOf(driverIndex)); + additionalEnvs.put(CONTAINER_START_COMMAND, driverStartCommand); + + String podName = driverPodNamePrefix + driverId; + String startCommand = buildSupervisorStartCommand(DRIVER_LOG_SUFFIX); + Container container = + KubernetesResourceBuilder.createContainer( + podName, + String.valueOf(driverId), + masterId, + driverParam, + startCommand, + additionalEnvs, dockerNetworkType); - // 2. Create deployment. - String rcName = clusterId + K8SConstants.DRIVER_RS_NAME_SUFFIX + driverIndex; - Deployment deployment = KubernetesResourceBuilder.createDeployment(clusterId, rcName, - String.valueOf(driverId), container, containerConfigMap, driverParam, + // 2. Create deployment. + String rcName = clusterId + K8SConstants.DRIVER_RS_NAME_SUFFIX + driverIndex; + Deployment deployment = + KubernetesResourceBuilder.createDeployment( + clusterId, + rcName, + String.valueOf(driverId), + container, + containerConfigMap, + driverParam, dockerNetworkType); - // 3. Create the service. - createService(serviceName, serviceExposedType, driverParam.getPodLabels(clusterId), - ownerReference, driverParam); - - // 4. Set owner reference. - deployment.getMetadata().setOwnerReferences(Collections.singletonList(ownerReference)); - - kubernetesClient.createDeployment(deployment); - } - - @Override - public void restartDriver(int driverId) { - LOGGER.info("Kill driver pod: {}.", driverId); - Map labels = new HashMap<>(); - labels.put(K8SConstants.LABEL_APP_KEY, clusterId); - labels.put(K8SConstants.LABEL_COMPONENT_KEY, K8SConstants.LABEL_COMPONENT_DRIVER); - labels.put(K8SConstants.LABEL_COMPONENT_ID_KEY, String.valueOf(driverId)); - kubernetesClient.deletePod(labels); - } - - @Override - public void restartContainer(int containerId) { - // Kill the pod before start a new one. - Map labels = new HashMap<>(); - labels.put(K8SConstants.LABEL_APP_KEY, clusterId); - labels.put(K8SConstants.LABEL_COMPONENT_KEY, K8SConstants.LABEL_COMPONENT_WORKER); - labels.put(K8SConstants.LABEL_COMPONENT_ID_KEY, String.valueOf(containerId)); - kubernetesClient.deletePod(labels); - createNewContainer(containerId, true); - } - - private String buildSupervisorStartCommand(String fileName) { - String logFile = config.getString(LOG_DIR) + File.separator + fileName; - return ClusterUtils.getStartCommand(clusterConfig.getSupervisorJvmOptions(), - KubernetesSupervisorRunner.class, logFile, config, classpath); - } - - @Override - public void close() { - super.close(); - if (kubernetesClient != null) { - kubernetesClient.destroyCluster(clusterId); - kubernetesClient.close(); - } + // 3. Create the service. + createService( + serviceName, + serviceExposedType, + driverParam.getPodLabels(clusterId), + ownerReference, + driverParam); + + // 4. Set owner reference. + deployment.getMetadata().setOwnerReferences(Collections.singletonList(ownerReference)); + + kubernetesClient.createDeployment(deployment); + } + + @Override + public void restartDriver(int driverId) { + LOGGER.info("Kill driver pod: {}.", driverId); + Map labels = new HashMap<>(); + labels.put(K8SConstants.LABEL_APP_KEY, clusterId); + labels.put(K8SConstants.LABEL_COMPONENT_KEY, K8SConstants.LABEL_COMPONENT_DRIVER); + labels.put(K8SConstants.LABEL_COMPONENT_ID_KEY, String.valueOf(driverId)); + kubernetesClient.deletePod(labels); + } + + @Override + public void restartContainer(int containerId) { + // Kill the pod before start a new one. + Map labels = new HashMap<>(); + labels.put(K8SConstants.LABEL_APP_KEY, clusterId); + labels.put(K8SConstants.LABEL_COMPONENT_KEY, K8SConstants.LABEL_COMPONENT_WORKER); + labels.put(K8SConstants.LABEL_COMPONENT_ID_KEY, String.valueOf(containerId)); + kubernetesClient.deletePod(labels); + createNewContainer(containerId, true); + } + + private String buildSupervisorStartCommand(String fileName) { + String logFile = config.getString(LOG_DIR) + File.separator + fileName; + return ClusterUtils.getStartCommand( + clusterConfig.getSupervisorJvmOptions(), + KubernetesSupervisorRunner.class, + logFile, + config, + classpath); + } + + @Override + public void close() { + super.close(); + if (kubernetesClient != null) { + kubernetesClient.destroyCluster(clusterId); + kubernetesClient.close(); } + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/clustermanager/KubernetesResourceBuilder.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/clustermanager/KubernetesResourceBuilder.java index bfedca627..2ecd4bc5d 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/clustermanager/KubernetesResourceBuilder.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/clustermanager/KubernetesResourceBuilder.java @@ -32,7 +32,27 @@ import static org.apache.geaflow.common.config.keys.DSLConfigKeys.GEAFLOW_DSL_CATALOG_TOKEN_KEY; import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.GEAFLOW_GW_ENDPOINT; +import java.io.File; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.StringUtils; +import org.apache.geaflow.cluster.k8s.config.K8SConstants; +import org.apache.geaflow.cluster.k8s.config.KubernetesConfig; +import org.apache.geaflow.cluster.k8s.config.KubernetesConfig.DockerNetworkType; +import org.apache.geaflow.cluster.k8s.config.KubernetesConfig.ServiceExposedType; +import org.apache.geaflow.cluster.k8s.config.KubernetesParam; +import org.apache.geaflow.cluster.k8s.utils.KubernetesUtils; +import org.apache.geaflow.common.config.Configuration; +import org.apache.geaflow.common.utils.FileUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import com.google.common.base.Preconditions; + import io.fabric8.kubernetes.api.model.ConfigMap; import io.fabric8.kubernetes.api.model.ConfigMapBuilder; import io.fabric8.kubernetes.api.model.Container; @@ -56,252 +76,337 @@ import io.fabric8.kubernetes.api.model.VolumeMountBuilder; import io.fabric8.kubernetes.api.model.apps.Deployment; import io.fabric8.kubernetes.api.model.apps.DeploymentBuilder; -import java.io.File; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; -import org.apache.commons.lang3.StringUtils; -import org.apache.geaflow.cluster.k8s.config.K8SConstants; -import org.apache.geaflow.cluster.k8s.config.KubernetesConfig; -import org.apache.geaflow.cluster.k8s.config.KubernetesConfig.DockerNetworkType; -import org.apache.geaflow.cluster.k8s.config.KubernetesConfig.ServiceExposedType; -import org.apache.geaflow.cluster.k8s.config.KubernetesParam; -import org.apache.geaflow.cluster.k8s.utils.KubernetesUtils; -import org.apache.geaflow.common.config.Configuration; -import org.apache.geaflow.common.utils.FileUtil; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class KubernetesResourceBuilder { - private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesResourceBuilder.class); - - public static Container createContainer(String containerName, String containerId, - String masterId, KubernetesParam param, String command, - Map additionalEnvs, - DockerNetworkType dockerNetworkType) { - - Quantity masterCpuQuantity = param.getCpuQuantity(); - Quantity masterMemoryQuantity = param.getMemoryQuantity(); - - String image = param.getContainerImage(); - Preconditions.checkNotNull(image, "container image should be specified"); - - Configuration config = param.getConfig(); - String clusterName = config.getString(CLUSTER_NAME); - String pullPolicy = param.getContainerImagePullPolicy(); - String confDir = param.getConfDir(); - String logDir = param.getLogDir(); - String jobWorkPath = config.getString(WORK_DIR); - String jarDownloadPath = KubernetesConfig.getJarDownloadPath(config); - String udfList = config.getString(USER_JAR_FILES); - String engineJar = config.getString(ENGINE_JAR_FILES); - String autoRestart = param.getAutoRestart(); - String gatewayEndpoint = config.getString(GEAFLOW_GW_ENDPOINT, ""); - String geaflowToken = config.getString(GEAFLOW_DSL_CATALOG_TOKEN_KEY, ""); - Boolean clusterFaultInjectionEnable = param.getClusterFaultInjectionEnable(); - boolean alwaysDownloadEngineJar = config.getBoolean(ALWAYS_PULL_ENGINE_JAR); - - ContainerBuilder containerBuilder = new ContainerBuilder() + private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesResourceBuilder.class); + + public static Container createContainer( + String containerName, + String containerId, + String masterId, + KubernetesParam param, + String command, + Map additionalEnvs, + DockerNetworkType dockerNetworkType) { + + Quantity masterCpuQuantity = param.getCpuQuantity(); + Quantity masterMemoryQuantity = param.getMemoryQuantity(); + + String image = param.getContainerImage(); + Preconditions.checkNotNull(image, "container image should be specified"); + + Configuration config = param.getConfig(); + String clusterName = config.getString(CLUSTER_NAME); + String pullPolicy = param.getContainerImagePullPolicy(); + String confDir = param.getConfDir(); + String logDir = param.getLogDir(); + String jobWorkPath = config.getString(WORK_DIR); + String jarDownloadPath = KubernetesConfig.getJarDownloadPath(config); + String udfList = config.getString(USER_JAR_FILES); + String engineJar = config.getString(ENGINE_JAR_FILES); + String autoRestart = param.getAutoRestart(); + String gatewayEndpoint = config.getString(GEAFLOW_GW_ENDPOINT, ""); + String geaflowToken = config.getString(GEAFLOW_DSL_CATALOG_TOKEN_KEY, ""); + Boolean clusterFaultInjectionEnable = param.getClusterFaultInjectionEnable(); + boolean alwaysDownloadEngineJar = config.getBoolean(ALWAYS_PULL_ENGINE_JAR); + + ContainerBuilder containerBuilder = + new ContainerBuilder() .withName(containerName) .withImage(image) .withImagePullPolicy(pullPolicy) .addNewEnv() - .withName(K8SConstants.ENV_CONF_DIR).withValue(confDir).endEnv() + .withName(K8SConstants.ENV_CONF_DIR) + .withValue(confDir) + .endEnv() .addNewEnv() - .withName(K8SConstants.ENV_LOG_DIR).withValue(logDir).endEnv() + .withName(K8SConstants.ENV_LOG_DIR) + .withValue(logDir) + .endEnv() .addNewEnv() - .withName(K8SConstants.ENV_JOB_WORK_PATH).withValue(jobWorkPath).endEnv() + .withName(K8SConstants.ENV_JOB_WORK_PATH) + .withValue(jobWorkPath) + .endEnv() .addNewEnv() - .withName(K8SConstants.ENV_JAR_DOWNLOAD_PATH).withValue(jarDownloadPath).endEnv() + .withName(K8SConstants.ENV_JAR_DOWNLOAD_PATH) + .withValue(jarDownloadPath) + .endEnv() .addNewEnv() - .withName(K8SConstants.ENV_UDF_LIST).withValue(udfList).endEnv() + .withName(K8SConstants.ENV_UDF_LIST) + .withValue(udfList) + .endEnv() .addNewEnv() - .withName(K8SConstants.ENV_ENGINE_JAR).withValue(engineJar).endEnv() + .withName(K8SConstants.ENV_ENGINE_JAR) + .withValue(engineJar) + .endEnv() .addNewEnv() - .withName(K8SConstants.ENV_START_COMMAND).withValue(command).endEnv() + .withName(K8SConstants.ENV_START_COMMAND) + .withValue(command) + .endEnv() .addNewEnv() - .withName(K8SConstants.ENV_CONTAINER_ID).withValue(containerId).endEnv() + .withName(K8SConstants.ENV_CONTAINER_ID) + .withValue(containerId) + .endEnv() + .addNewEnv() + .withName(K8SConstants.ENV_CLUSTER_ID) + .withValue(clusterName) + .endEnv() + .addNewEnv() + .withName(K8SConstants.ENV_MASTER_ID) + .withValue(masterId) + .endEnv() + .addNewEnv() + .withName(K8SConstants.ENV_AUTO_RESTART) + .withValue(autoRestart) + .endEnv() .addNewEnv() - .withName(K8SConstants.ENV_CLUSTER_ID).withValue(clusterName) + .withName(K8SConstants.ENV_CLUSTER_FAULT_INJECTION_ENABLE) + .withValue(String.valueOf(clusterFaultInjectionEnable)) .endEnv() .addNewEnv() - .withName(K8SConstants.ENV_MASTER_ID).withValue(masterId).endEnv() + .withName(K8SConstants.ENV_CATALOG_TOKEN) + .withValue(geaflowToken) + .endEnv() .addNewEnv() - .withName(K8SConstants.ENV_AUTO_RESTART).withValue(autoRestart).endEnv() + .withName(K8SConstants.ENV_GW_ENDPOINT) + .withValue(gatewayEndpoint) + .endEnv() .addNewEnv() - .withName(K8SConstants.ENV_CLUSTER_FAULT_INJECTION_ENABLE).withValue(String.valueOf(clusterFaultInjectionEnable)).endEnv() + .withName(K8SConstants.ENV_ALWAYS_DOWNLOAD_ENGINE) + .withValue(String.valueOf(alwaysDownloadEngineJar)) + .endEnv() .addNewEnv() - .withName(K8SConstants.ENV_CATALOG_TOKEN).withValue(geaflowToken).endEnv() + .withName(K8SConstants.ENV_NODE_NAME) + .withValueFrom( + new EnvVarSourceBuilder() + .withFieldRef( + new ObjectFieldSelectorBuilder() + .withFieldPath(K8SConstants.NODE_NAME_FIELD_PATH) + .build()) + .build()) + .endEnv() .addNewEnv() - .withName(K8SConstants.ENV_GW_ENDPOINT).withValue(gatewayEndpoint).endEnv() + .withName(K8SConstants.ENV_POD_NAME) + .withValueFrom( + new EnvVarSourceBuilder() + .withFieldRef( + new ObjectFieldSelectorBuilder() + .withFieldPath(K8SConstants.POD_NAME_FIELD_PATH) + .build()) + .build()) + .endEnv() + .addNewEnv() + .withName(K8SConstants.ENV_POD_IP) + .withValueFrom( + new EnvVarSourceBuilder() + .withFieldRef( + new ObjectFieldSelectorBuilder() + .withFieldPath(K8SConstants.POD_IP_FIELD_PATH) + .build()) + .build()) + .endEnv() .addNewEnv() - .withName(K8SConstants.ENV_ALWAYS_DOWNLOAD_ENGINE).withValue(String.valueOf(alwaysDownloadEngineJar)).endEnv() + .withName(K8SConstants.ENV_HOST_IP) + .withValueFrom( + new EnvVarSourceBuilder() + .withFieldRef( + new ObjectFieldSelectorBuilder() + .withFieldPath(K8SConstants.HOST_IP_FIELD_PATH) + .build()) + .build()) + .endEnv() .addNewEnv() - .withName(K8SConstants.ENV_NODE_NAME).withValueFrom(new EnvVarSourceBuilder() - .withFieldRef( - new ObjectFieldSelectorBuilder().withFieldPath(K8SConstants.NODE_NAME_FIELD_PATH) - .build()).build()).endEnv().addNewEnv().withName(K8SConstants.ENV_POD_NAME) - .withValueFrom(new EnvVarSourceBuilder().withFieldRef( - new ObjectFieldSelectorBuilder().withFieldPath(K8SConstants.POD_NAME_FIELD_PATH) - .build()).build()).endEnv().addNewEnv().withName(K8SConstants.ENV_POD_IP) - .withValueFrom(new EnvVarSourceBuilder().withFieldRef( - new ObjectFieldSelectorBuilder().withFieldPath(K8SConstants.POD_IP_FIELD_PATH).build()) - .build()).endEnv().addNewEnv().withName(K8SConstants.ENV_HOST_IP) - .withValueFrom(new EnvVarSourceBuilder().withFieldRef( - new ObjectFieldSelectorBuilder().withFieldPath(K8SConstants.HOST_IP_FIELD_PATH) - .build()).build()).endEnv() + .withName(K8SConstants.ENV_SERVICE_ACCOUNT) + .withValueFrom( + new EnvVarSourceBuilder() + .withFieldRef( + new ObjectFieldSelectorBuilder() + .withFieldPath(K8SConstants.SERVICE_ACCOUNT_NAME_FIELD_PATH) + .build()) + .build()) + .endEnv() .addNewEnv() - .withName(K8SConstants.ENV_SERVICE_ACCOUNT).withValueFrom( - new EnvVarSourceBuilder().withFieldRef( - new ObjectFieldSelectorBuilder().withFieldPath(K8SConstants.SERVICE_ACCOUNT_NAME_FIELD_PATH) - .build()).build()) + .withName(K8SConstants.ENV_NAMESPACE) + .withValueFrom( + new EnvVarSourceBuilder() + .withFieldRef( + new ObjectFieldSelectorBuilder() + .withFieldPath(K8SConstants.NAMESPACE_FIELD_PATH) + .build()) + .build()) .endEnv() - .addNewEnv().withName(K8SConstants.ENV_NAMESPACE).withValueFrom( - new EnvVarSourceBuilder().withFieldRef( - new ObjectFieldSelectorBuilder().withFieldPath(K8SConstants.NAMESPACE_FIELD_PATH) - .build()).build()).endEnv() .editOrNewResources() .addToRequests(K8SConstants.RESOURCE_NAME_MEMORY, masterMemoryQuantity) - .addToRequests(K8SConstants.RESOURCE_NAME_CPU, masterCpuQuantity).endResources() - .withVolumeMounts(new VolumeMountBuilder() - .withName(K8SConstants.GEAFLOW_CONF_VOLUME) - .withMountPath(confDir) - .build()); - - if (dockerNetworkType == DockerNetworkType.BRIDGE) { - if (param.getRpcPort() > 0) { - containerBuilder.addNewPort() - .withName(K8SConstants.RPC_PORT) - .withContainerPort(param.getRpcPort()) - .withProtocol("TCP").endPort(); - } - if (param.getHttpPort() > 0) { - containerBuilder.addNewPort() - .withName(K8SConstants.HTTP_PORT) - .withContainerPort(param.getHttpPort()) - .withProtocol("TCP").endPort(); - } - } + .addToRequests(K8SConstants.RESOURCE_NAME_CPU, masterCpuQuantity) + .endResources() + .withVolumeMounts( + new VolumeMountBuilder() + .withName(K8SConstants.GEAFLOW_CONF_VOLUME) + .withMountPath(confDir) + .build()); + + if (dockerNetworkType == DockerNetworkType.BRIDGE) { + if (param.getRpcPort() > 0) { + containerBuilder + .addNewPort() + .withName(K8SConstants.RPC_PORT) + .withContainerPort(param.getRpcPort()) + .withProtocol("TCP") + .endPort(); + } + if (param.getHttpPort() > 0) { + containerBuilder + .addNewPort() + .withName(K8SConstants.HTTP_PORT) + .withContainerPort(param.getHttpPort()) + .withProtocol("TCP") + .endPort(); + } + } - Quantity containerDiskQuantity = param.getDiskQuantity(); - if (containerDiskQuantity != null) { - containerBuilder.editResources() - .addToRequests(K8SConstants.RESOURCE_NAME_EPHEMERAL_STORAGE, containerDiskQuantity).endResources(); - } + Quantity containerDiskQuantity = param.getDiskQuantity(); + if (containerDiskQuantity != null) { + containerBuilder + .editResources() + .addToRequests(K8SConstants.RESOURCE_NAME_EPHEMERAL_STORAGE, containerDiskQuantity) + .endResources(); + } - boolean enableMemoryLimit = KubernetesConfig.enableResourceMemoryLimit(config); - boolean enableCpuLimit = KubernetesConfig.enableResourceCpuLimit(config); - boolean enableDiskLimit = KubernetesConfig.enableResourceEphemeralStorageLimit(config); + boolean enableMemoryLimit = KubernetesConfig.enableResourceMemoryLimit(config); + boolean enableCpuLimit = KubernetesConfig.enableResourceCpuLimit(config); + boolean enableDiskLimit = KubernetesConfig.enableResourceEphemeralStorageLimit(config); - if (enableMemoryLimit) { - containerBuilder.editResources() - .addToLimits(K8SConstants.RESOURCE_NAME_MEMORY, masterMemoryQuantity).endResources(); - } - if (enableCpuLimit) { - containerBuilder.editResources() - .addToLimits(K8SConstants.RESOURCE_NAME_CPU, masterCpuQuantity).endResources(); - } - if (enableDiskLimit) { - String size = KubernetesConfig.getResourceEphemeralStorageSize(config); - Quantity ephemeralStorageQuantity = new Quantity(size); - containerBuilder.editResources() - .addToLimits(K8SConstants.RESOURCE_NAME_EPHEMERAL_STORAGE, ephemeralStorageQuantity) - .endResources(); - } - - if (additionalEnvs != null && !additionalEnvs.isEmpty()) { - additionalEnvs.entrySet().stream().forEach( - e -> containerBuilder.addNewEnv().withName(e.getKey()).withValue(e.getValue()) - .endEnv()); - } + if (enableMemoryLimit) { + containerBuilder + .editResources() + .addToLimits(K8SConstants.RESOURCE_NAME_MEMORY, masterMemoryQuantity) + .endResources(); + } + if (enableCpuLimit) { + containerBuilder + .editResources() + .addToLimits(K8SConstants.RESOURCE_NAME_CPU, masterCpuQuantity) + .endResources(); + } + if (enableDiskLimit) { + String size = KubernetesConfig.getResourceEphemeralStorageSize(config); + Quantity ephemeralStorageQuantity = new Quantity(size); + containerBuilder + .editResources() + .addToLimits(K8SConstants.RESOURCE_NAME_EPHEMERAL_STORAGE, ephemeralStorageQuantity) + .endResources(); + } - return containerBuilder.build(); + if (additionalEnvs != null && !additionalEnvs.isEmpty()) { + additionalEnvs.entrySet().stream() + .forEach( + e -> + containerBuilder + .addNewEnv() + .withName(e.getKey()) + .withValue(e.getValue()) + .endEnv()); } - /** - * Setup a Config Map that will generate a conf file. - * - * @param clusterId the cluster id - * @return the created configMap - */ - public static ConfigMap createConfigMap(String clusterId, KubernetesParam param, - OwnerReference ownerReference) { - Map config = param.getConfig().getConfigMap(); - StringBuilder confContent = new StringBuilder(); - config.forEach( - (k, v) -> confContent.append(k).append(CONFIG_KV_SEPARATOR).append(v).append(System.lineSeparator())); - - String configMapName = param.getConfigMapName(clusterId); - ObjectMetaBuilder metaBuilder = new ObjectMetaBuilder().withName(configMapName); - if (ownerReference != null) { - metaBuilder.withOwnerReferences(ownerReference); - } + return containerBuilder.build(); + } + + /** + * Setup a Config Map that will generate a conf file. + * + * @param clusterId the cluster id + * @return the created configMap + */ + public static ConfigMap createConfigMap( + String clusterId, KubernetesParam param, OwnerReference ownerReference) { + Map config = param.getConfig().getConfigMap(); + StringBuilder confContent = new StringBuilder(); + config.forEach( + (k, v) -> + confContent + .append(k) + .append(CONFIG_KV_SEPARATOR) + .append(v) + .append(System.lineSeparator())); + + String configMapName = param.getConfigMapName(clusterId); + ObjectMetaBuilder metaBuilder = new ObjectMetaBuilder().withName(configMapName); + if (ownerReference != null) { + metaBuilder.withOwnerReferences(ownerReference); + } - ConfigMapBuilder configMapBuilder = new ConfigMapBuilder() + ConfigMapBuilder configMapBuilder = + new ConfigMapBuilder() .withMetadata(metaBuilder.build()) .addToData(K8SConstants.ENV_CONFIG_FILE, confContent.toString()); - String files = param.getConfig().getString(CONTAINER_CONF_FILES); - if (StringUtils.isNotEmpty(files)) { - for (String filePath : files.split(CONFIG_LIST_SEPARATOR)) { - String fileName = filePath.substring(filePath.lastIndexOf(File.separator) + 1); - String fileContent = FileUtil.getContentFromFile(filePath); - if (fileContent != null) { - configMapBuilder.addToData(fileName, fileContent); - } else { - LOGGER.info("File {} not exist, will not add to configMap", filePath); - } - } + String files = param.getConfig().getString(CONTAINER_CONF_FILES); + if (StringUtils.isNotEmpty(files)) { + for (String filePath : files.split(CONFIG_LIST_SEPARATOR)) { + String fileName = filePath.substring(filePath.lastIndexOf(File.separator) + 1); + String fileContent = FileUtil.getContentFromFile(filePath); + if (fileContent != null) { + configMapBuilder.addToData(fileName, fileContent); + } else { + LOGGER.info("File {} not exist, will not add to configMap", filePath); } - - return configMapBuilder.build(); + } } - public static ConfigMap updateConfigMap(ConfigMap configMap, - Map newConfig) { - Map updatedConfig = KubernetesUtils.loadConfigurationFromString( + return configMapBuilder.build(); + } + + public static ConfigMap updateConfigMap(ConfigMap configMap, Map newConfig) { + Map updatedConfig = + KubernetesUtils.loadConfigurationFromString( configMap.getData().get(K8SConstants.ENV_CONFIG_FILE)); - updatedConfig.putAll(newConfig); + updatedConfig.putAll(newConfig); - StringBuilder confContent = new StringBuilder(); - updatedConfig.forEach((k, v) -> - confContent.append(k).append(": ").append(v).append(System.lineSeparator())); + StringBuilder confContent = new StringBuilder(); + updatedConfig.forEach( + (k, v) -> confContent.append(k).append(": ").append(v).append(System.lineSeparator())); - ConfigMapBuilder configMapBuilder = new ConfigMapBuilder() + ConfigMapBuilder configMapBuilder = + new ConfigMapBuilder() .withNewMetadata() .withName(configMap.getMetadata().getName()) .withOwnerReferences(configMap.getMetadata().getOwnerReferences()) .endMetadata() .addToData(configMap.getData()); - configMapBuilder.addToData(K8SConstants.ENV_CONFIG_FILE, confContent.toString()); - return configMapBuilder.build(); + configMapBuilder.addToData(K8SConstants.ENV_CONFIG_FILE, confContent.toString()); + return configMapBuilder.build(); + } + + private static Map getAnnotations(KubernetesParam param) { + return param.getAnnotations(); + } + + public static Pod createPod( + String clusterId, + String name, + String id, + OwnerReference ownerReference, + ConfigMap configMap, + KubernetesParam param, + Container... container) { + + List configMapItems = + configMap.getData().keySet().stream() + .map(e -> new KeyToPath(e, null, e)) + .collect(Collectors.toList()); + + Map labels = param.getPodLabels(clusterId); + labels.put(LABEL_COMPONENT_ID_KEY, id); + Map annotations = getAnnotations(param); + + ObjectMetaBuilder metaBuilder = + new ObjectMetaBuilder().withName(name).withLabels(labels).withAnnotations(annotations); + if (ownerReference != null) { + metaBuilder.withOwnerReferences(ownerReference); } - - private static Map getAnnotations(KubernetesParam param) { - return param.getAnnotations(); - } - - public static Pod createPod(String clusterId, String name, - String id, OwnerReference ownerReference, - ConfigMap configMap, KubernetesParam param, - Container... container) { - - List configMapItems = configMap.getData().keySet().stream() - .map(e -> new KeyToPath(e, null, e)).collect(Collectors.toList()); - - Map labels = param.getPodLabels(clusterId); - labels.put(LABEL_COMPONENT_ID_KEY, id); - Map annotations = getAnnotations(param); - - ObjectMetaBuilder metaBuilder = new ObjectMetaBuilder().withName(name) - .withLabels(labels).withAnnotations(annotations); - if (ownerReference != null) { - metaBuilder.withOwnerReferences(ownerReference); - } - PodBuilder podBuilder = new PodBuilder() + PodBuilder podBuilder = + new PodBuilder() .withMetadata(metaBuilder.build()) .editOrNewSpec() .withServiceAccountName(param.getServiceAccount()) @@ -313,68 +418,74 @@ public static Pod createPod(String clusterId, String name, .withName(K8SConstants.GEAFLOW_CONF_VOLUME) .withNewConfigMap() .withName(param.getConfigMapName(clusterId)) - .addAllToItems(configMapItems).endConfigMap().endVolume() + .addAllToItems(configMapItems) + .endConfigMap() + .endVolume() .addNewVolume() .withName(K8SConstants.GEAFLOW_LOG_VOLUME) - .withNewEmptyDir().endEmptyDir() + .withNewEmptyDir() + .endEmptyDir() .endVolume() .endSpec(); - List matchExpressionsList = - KubernetesUtils.getMatchExpressions(param.getConfig()); - if (matchExpressionsList.size() > 0) { - podBuilder.editSpec() - .withNewAffinity() - .withNewNodeAffinity() - .withNewRequiredDuringSchedulingIgnoredDuringExecution() - .addNewNodeSelectorTerm() - .withMatchExpressions(matchExpressionsList) - .endNodeSelectorTerm() - .endRequiredDuringSchedulingIgnoredDuringExecution() - .endNodeAffinity() - .endAffinity() - .endSpec(); - } - - List tolerationList = KubernetesUtils.getTolerations(param.getConfig()); - if (tolerationList.size() > 0) { - podBuilder.editSpec() - .withTolerations(tolerationList) - .endSpec(); - } - - DockerNetworkType dockerNetworkType = - KubernetesConfig.getDockerNetworkType(param.getConfig()); - if (dockerNetworkType == DockerNetworkType.HOST) { - podBuilder.editSpec() - .withHostNetwork(true) - .withDnsPolicy(K8SConstants.HOST_NETWORK_DNS_POLICY) - .endSpec(); - } - - return podBuilder.build(); + List matchExpressionsList = + KubernetesUtils.getMatchExpressions(param.getConfig()); + if (matchExpressionsList.size() > 0) { + podBuilder + .editSpec() + .withNewAffinity() + .withNewNodeAffinity() + .withNewRequiredDuringSchedulingIgnoredDuringExecution() + .addNewNodeSelectorTerm() + .withMatchExpressions(matchExpressionsList) + .endNodeSelectorTerm() + .endRequiredDuringSchedulingIgnoredDuringExecution() + .endNodeAffinity() + .endAffinity() + .endSpec(); } - public static Deployment createDeployment(String clusterId, - String rcName, - String id, - Container container, - ConfigMap configMap, - KubernetesParam param, - DockerNetworkType dockerNetworkType) { - String configMapName = param.getConfigMapName(clusterId); - String serviceAccount = param.getServiceAccount(); - - Map labels = param.getPodLabels(clusterId); - labels.put(LABEL_COMPONENT_ID_KEY, id); - Map annotations = getAnnotations(param); - - List configMapItems = configMap.getData().keySet().stream() - .map(e -> new KeyToPath(e, null, e)).collect(Collectors.toList()); + List tolerationList = KubernetesUtils.getTolerations(param.getConfig()); + if (tolerationList.size() > 0) { + podBuilder.editSpec().withTolerations(tolerationList).endSpec(); + } - int replicas = param.enableLeaderElection() ? 2 : 1; + DockerNetworkType dockerNetworkType = KubernetesConfig.getDockerNetworkType(param.getConfig()); + if (dockerNetworkType == DockerNetworkType.HOST) { + podBuilder + .editSpec() + .withHostNetwork(true) + .withDnsPolicy(K8SConstants.HOST_NETWORK_DNS_POLICY) + .endSpec(); + } - DeploymentBuilder deploymentBuilder = new DeploymentBuilder() + return podBuilder.build(); + } + + public static Deployment createDeployment( + String clusterId, + String rcName, + String id, + Container container, + ConfigMap configMap, + KubernetesParam param, + DockerNetworkType dockerNetworkType) { + String configMapName = param.getConfigMapName(clusterId); + String serviceAccount = param.getServiceAccount(); + + Map labels = param.getPodLabels(clusterId); + labels.put(LABEL_COMPONENT_ID_KEY, id); + Map annotations = getAnnotations(param); + + List configMapItems = + configMap.getData().keySet().stream() + .map(e -> new KeyToPath(e, null, e)) + .collect(Collectors.toList()); + + int replicas = param.enableLeaderElection() ? 2 : 1; + + DeploymentBuilder deploymentBuilder = + new DeploymentBuilder() .editOrNewMetadata() .withName(rcName) .withLabels(labels) @@ -399,118 +510,144 @@ public static Deployment createDeployment(String clusterId, .withNewConfigMap() .withName(configMapName) .addAllToItems(configMapItems) - .endConfigMap().endVolume() + .endConfigMap() + .endVolume() .addNewVolume() .withName(K8SConstants.GEAFLOW_LOG_VOLUME) .withNewEmptyDir() - .endEmptyDir().endVolume() + .endEmptyDir() + .endVolume() .endSpec() .endTemplate() .endSpec(); - String dnsDomains = param.getConfig().getString(DNS_SEARCH_DOMAINS); - if (!StringUtils.isEmpty(dnsDomains)) { - List domains = Arrays.stream(dnsDomains.split(",")).map(String::trim) - .collect(Collectors.toList()); - deploymentBuilder.editSpec().editTemplate().editSpec() - .withDnsConfig(new PodDNSConfigBuilder().withSearches(domains).build()).endSpec() - .endTemplate().endSpec(); - } - - if (dockerNetworkType == DockerNetworkType.HOST) { - deploymentBuilder.editSpec().editTemplate() - .editSpec() - .withHostNetwork(true) - .withDnsPolicy(K8SConstants.HOST_NETWORK_DNS_POLICY) - .endSpec().endTemplate().endSpec(); - } + String dnsDomains = param.getConfig().getString(DNS_SEARCH_DOMAINS); + if (!StringUtils.isEmpty(dnsDomains)) { + List domains = + Arrays.stream(dnsDomains.split(",")).map(String::trim).collect(Collectors.toList()); + deploymentBuilder + .editSpec() + .editTemplate() + .editSpec() + .withDnsConfig(new PodDNSConfigBuilder().withSearches(domains).build()) + .endSpec() + .endTemplate() + .endSpec(); + } - List matchExpressionsList = - KubernetesUtils.getMatchExpressions(param.getConfig()); - if (matchExpressionsList.size() > 0) { - deploymentBuilder.editSpec().editTemplate().editSpec() - .withNewAffinity() - .withNewNodeAffinity() - .withNewRequiredDuringSchedulingIgnoredDuringExecution() - .addNewNodeSelectorTerm() - .withMatchExpressions(matchExpressionsList) - .endNodeSelectorTerm() - .endRequiredDuringSchedulingIgnoredDuringExecution() - .endNodeAffinity() - .endAffinity() - .endSpec().endTemplate().endSpec(); - } + if (dockerNetworkType == DockerNetworkType.HOST) { + deploymentBuilder + .editSpec() + .editTemplate() + .editSpec() + .withHostNetwork(true) + .withDnsPolicy(K8SConstants.HOST_NETWORK_DNS_POLICY) + .endSpec() + .endTemplate() + .endSpec(); + } - List tolerationList = KubernetesUtils.getTolerations(param.getConfig()); - if (tolerationList.size() > 0) { - deploymentBuilder.editSpec().editTemplate().editSpec() - .withTolerations(tolerationList) - .endSpec().endTemplate().endSpec(); - } + List matchExpressionsList = + KubernetesUtils.getMatchExpressions(param.getConfig()); + if (matchExpressionsList.size() > 0) { + deploymentBuilder + .editSpec() + .editTemplate() + .editSpec() + .withNewAffinity() + .withNewNodeAffinity() + .withNewRequiredDuringSchedulingIgnoredDuringExecution() + .addNewNodeSelectorTerm() + .withMatchExpressions(matchExpressionsList) + .endNodeSelectorTerm() + .endRequiredDuringSchedulingIgnoredDuringExecution() + .endNodeAffinity() + .endAffinity() + .endSpec() + .endTemplate() + .endSpec(); + } - return deploymentBuilder.build(); + List tolerationList = KubernetesUtils.getTolerations(param.getConfig()); + if (tolerationList.size() > 0) { + deploymentBuilder + .editSpec() + .editTemplate() + .editSpec() + .withTolerations(tolerationList) + .endSpec() + .endTemplate() + .endSpec(); } - public static Service createService(String serviceName, ServiceExposedType exposedType, - Map labels, OwnerReference ownerReference, - KubernetesParam param) { - ObjectMetaBuilder metaBuilder = new ObjectMetaBuilder().withName(serviceName); - if (ownerReference != null) { - metaBuilder.withOwnerReferences(ownerReference); - } - Map serviceLabels = param.getServiceLabels(); - serviceLabels.putAll(labels); - metaBuilder.withLabels(serviceLabels); - Map serviceAnnotations = param.getServiceAnnotations(); - if (serviceAnnotations != null) { - metaBuilder.withAnnotations(serviceAnnotations); - } - ServiceBuilder svcBuilder = new ServiceBuilder() + return deploymentBuilder.build(); + } + + public static Service createService( + String serviceName, + ServiceExposedType exposedType, + Map labels, + OwnerReference ownerReference, + KubernetesParam param) { + ObjectMetaBuilder metaBuilder = new ObjectMetaBuilder().withName(serviceName); + if (ownerReference != null) { + metaBuilder.withOwnerReferences(ownerReference); + } + Map serviceLabels = param.getServiceLabels(); + serviceLabels.putAll(labels); + metaBuilder.withLabels(serviceLabels); + Map serviceAnnotations = param.getServiceAnnotations(); + if (serviceAnnotations != null) { + metaBuilder.withAnnotations(serviceAnnotations); + } + ServiceBuilder svcBuilder = + new ServiceBuilder() .withMetadata(metaBuilder.build()) .withNewSpec() .withType(exposedType.getServiceExposedType()) .withSelector(labels) .endSpec(); - List servicePorts = new ArrayList<>(); - if (param.getRpcPort() > 0) { - ServicePortBuilder portBuilder = new ServicePortBuilder(); - portBuilder.withName(K8SConstants.RPC_PORT) - .withPort(param.getRpcPort()) - .withProtocol(K8SConstants.TCP_PROTOCOL); - if (exposedType == ServiceExposedType.NODE_PORT && param.getNodePort() > 0) { - portBuilder.withNodePort(param.getNodePort()); - } - servicePorts.add(portBuilder.build()); - } - if (param.getHttpPort() > 0) { - ServicePortBuilder portBuilder = new ServicePortBuilder(); - portBuilder.withName(K8SConstants.HTTP_PORT) - .withPort(param.getHttpPort()) - .withProtocol(K8SConstants.TCP_PROTOCOL); - if (exposedType == ServiceExposedType.NODE_PORT && param.getNodePort() > 0) { - portBuilder.withNodePort(param.getNodePort()); - } - servicePorts.add(portBuilder.build()); - } - svcBuilder.editSpec().addAllToPorts(servicePorts).endSpec(); - - if (exposedType.equals(ServiceExposedType.CLUSTER_IP)) { - svcBuilder.editSpec().withClusterIP("None").endSpec(); - } - - return svcBuilder.build(); + List servicePorts = new ArrayList<>(); + if (param.getRpcPort() > 0) { + ServicePortBuilder portBuilder = new ServicePortBuilder(); + portBuilder + .withName(K8SConstants.RPC_PORT) + .withPort(param.getRpcPort()) + .withProtocol(K8SConstants.TCP_PROTOCOL); + if (exposedType == ServiceExposedType.NODE_PORT && param.getNodePort() > 0) { + portBuilder.withNodePort(param.getNodePort()); + } + servicePorts.add(portBuilder.build()); + } + if (param.getHttpPort() > 0) { + ServicePortBuilder portBuilder = new ServicePortBuilder(); + portBuilder + .withName(K8SConstants.HTTP_PORT) + .withPort(param.getHttpPort()) + .withProtocol(K8SConstants.TCP_PROTOCOL); + if (exposedType == ServiceExposedType.NODE_PORT && param.getNodePort() > 0) { + portBuilder.withNodePort(param.getNodePort()); + } + servicePorts.add(portBuilder.build()); } + svcBuilder.editSpec().addAllToPorts(servicePorts).endSpec(); - public static OwnerReference createOwnerReference(Service service) { - Preconditions.checkNotNull(service, "Service is required to create owner reference."); - return new OwnerReferenceBuilder() - .withName(service.getMetadata().getName()) - .withApiVersion(service.getApiVersion()) - .withUid(service.getMetadata().getUid()) - .withKind(service.getKind()) - .withController(true) - .build(); + if (exposedType.equals(ServiceExposedType.CLUSTER_IP)) { + svcBuilder.editSpec().withClusterIP("None").endSpec(); } + return svcBuilder.build(); + } + + public static OwnerReference createOwnerReference(Service service) { + Preconditions.checkNotNull(service, "Service is required to create owner reference."); + return new OwnerReferenceBuilder() + .withName(service.getMetadata().getName()) + .withApiVersion(service.getApiVersion()) + .withUid(service.getMetadata().getUid()) + .withKind(service.getKind()) + .withController(true) + .build(); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/AbstractKubernetesParam.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/AbstractKubernetesParam.java index 9b34c041a..9e18315e6 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/AbstractKubernetesParam.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/AbstractKubernetesParam.java @@ -29,136 +29,136 @@ import static org.apache.geaflow.cluster.k8s.config.KubernetesConfigKeys.SERVICE_USER_LABELS; import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.PROCESS_AUTO_RESTART; -import io.fabric8.kubernetes.api.model.Quantity; -import io.fabric8.kubernetes.api.model.QuantityBuilder; import java.util.Map; + import org.apache.geaflow.cluster.config.ClusterConfig; import org.apache.geaflow.cluster.k8s.utils.KubernetesUtils; import org.apache.geaflow.common.config.Configuration; -public abstract class AbstractKubernetesParam implements KubernetesParam { - - protected Configuration config; - protected ClusterConfig clusterConfig; - - public AbstractKubernetesParam(Configuration config) { - this.config = config; - this.clusterConfig = ClusterConfig.build(config); - } - - public AbstractKubernetesParam(ClusterConfig clusterConfig) { - this.clusterConfig = clusterConfig; - this.config = clusterConfig.getConfig(); - } - - public String getContainerImage() { - return config.getString(CONTAINER_IMAGE); - } - - public String getContainerImagePullPolicy() { - return config.getString(CONTAINER_IMAGE_PULL_POLICY); - } - - public String getServiceAccount() { - return config.getString(SERVICE_ACCOUNT); - } - - public Map getServiceLabels() { - return KubernetesUtils.getPairsConf(config, SERVICE_USER_LABELS); - } - - @Override - public Map getServiceAnnotations() { - return KubernetesUtils.getPairsConf(config, SERVICE_USER_ANNOTATIONS.getKey()); - } - - protected Quantity getCpuQuantity(double cpu) { - return new QuantityBuilder(false).withAmount(String.valueOf(cpu)).build(); - } - - protected Quantity getMemoryQuantity(long memoryMB) { - return new QuantityBuilder(false).withAmount(String.valueOf((memoryMB) << 20)) - .build(); - } - - protected Quantity getDiskQuantity(long diskGB) { - return new QuantityBuilder(false).withAmount(String.valueOf((diskGB) << 30)) - .build(); - } - - @Override - public String getConfDir() { - return config.getString(CONF_DIR); - } - - @Override - public String getLogDir() { - return config.getString(LOG_DIR); - } - - @Override - public String getAutoRestart() { - return config.getString(PROCESS_AUTO_RESTART); - } - - @Override - public Boolean getClusterFaultInjectionEnable() { - return config.getBoolean(CLUSTER_FAULT_INJECTION_ENABLE); - } - - @Override - public int getRpcPort() { - return 0; - } - - @Override - public int getHttpPort() { - return 0; - } - - public int getNodePort() { - return 0; - } - - @Override - public Quantity getCpuQuantity() { - Double cpu = getContainerCpu(); - return getCpuQuantity(cpu); - } - - protected abstract Double getContainerCpu(); - - @Override - public Quantity getMemoryQuantity() { - long memoryMB = getContainerMemoryMB(); - return getMemoryQuantity(memoryMB); - } - - protected abstract long getContainerMemoryMB(); - - @Override - public Quantity getDiskQuantity() { - long diskGB = getContainerDiskGB(); - if (diskGB == 0) { - return null; - } - return getDiskQuantity(diskGB); - } - - protected abstract long getContainerDiskGB(); - - @Override - public Map getAdditionEnvs() { - return null; - } +import io.fabric8.kubernetes.api.model.Quantity; +import io.fabric8.kubernetes.api.model.QuantityBuilder; - @Override - public Configuration getConfig() { - return config; - } +public abstract class AbstractKubernetesParam implements KubernetesParam { - @Override - public boolean enableLeaderElection() { - return false; - } + protected Configuration config; + protected ClusterConfig clusterConfig; + + public AbstractKubernetesParam(Configuration config) { + this.config = config; + this.clusterConfig = ClusterConfig.build(config); + } + + public AbstractKubernetesParam(ClusterConfig clusterConfig) { + this.clusterConfig = clusterConfig; + this.config = clusterConfig.getConfig(); + } + + public String getContainerImage() { + return config.getString(CONTAINER_IMAGE); + } + + public String getContainerImagePullPolicy() { + return config.getString(CONTAINER_IMAGE_PULL_POLICY); + } + + public String getServiceAccount() { + return config.getString(SERVICE_ACCOUNT); + } + + public Map getServiceLabels() { + return KubernetesUtils.getPairsConf(config, SERVICE_USER_LABELS); + } + + @Override + public Map getServiceAnnotations() { + return KubernetesUtils.getPairsConf(config, SERVICE_USER_ANNOTATIONS.getKey()); + } + + protected Quantity getCpuQuantity(double cpu) { + return new QuantityBuilder(false).withAmount(String.valueOf(cpu)).build(); + } + + protected Quantity getMemoryQuantity(long memoryMB) { + return new QuantityBuilder(false).withAmount(String.valueOf((memoryMB) << 20)).build(); + } + + protected Quantity getDiskQuantity(long diskGB) { + return new QuantityBuilder(false).withAmount(String.valueOf((diskGB) << 30)).build(); + } + + @Override + public String getConfDir() { + return config.getString(CONF_DIR); + } + + @Override + public String getLogDir() { + return config.getString(LOG_DIR); + } + + @Override + public String getAutoRestart() { + return config.getString(PROCESS_AUTO_RESTART); + } + + @Override + public Boolean getClusterFaultInjectionEnable() { + return config.getBoolean(CLUSTER_FAULT_INJECTION_ENABLE); + } + + @Override + public int getRpcPort() { + return 0; + } + + @Override + public int getHttpPort() { + return 0; + } + + public int getNodePort() { + return 0; + } + + @Override + public Quantity getCpuQuantity() { + Double cpu = getContainerCpu(); + return getCpuQuantity(cpu); + } + + protected abstract Double getContainerCpu(); + + @Override + public Quantity getMemoryQuantity() { + long memoryMB = getContainerMemoryMB(); + return getMemoryQuantity(memoryMB); + } + + protected abstract long getContainerMemoryMB(); + + @Override + public Quantity getDiskQuantity() { + long diskGB = getContainerDiskGB(); + if (diskGB == 0) { + return null; + } + return getDiskQuantity(diskGB); + } + + protected abstract long getContainerDiskGB(); + + @Override + public Map getAdditionEnvs() { + return null; + } + + @Override + public Configuration getConfig() { + return config; + } + + @Override + public boolean enableLeaderElection() { + return false; + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/K8SConstants.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/K8SConstants.java index 6b3aee4f0..9f3dec5a6 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/K8SConstants.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/K8SConstants.java @@ -19,147 +19,144 @@ package org.apache.geaflow.cluster.k8s.config; -/** - * Constants for kubernetes. - */ +/** Constants for kubernetes. */ public final class K8SConstants { - public static final String RANDOM_CLUSTER_ID_PREFIX = "geaflow"; - - public static final String RPC_PORT = "rpc"; + public static final String RANDOM_CLUSTER_ID_PREFIX = "geaflow"; - public static final String HTTP_PORT = "rest"; + public static final String RPC_PORT = "rpc"; - public static final String NAME_SEPARATOR = "-"; + public static final String HTTP_PORT = "rest"; - public static final String NAMESPACE_SEPARATOR = "."; + public static final String NAME_SEPARATOR = "-"; - public static final String CONFIG_KV_SEPARATOR = ":"; + public static final String NAMESPACE_SEPARATOR = "."; - public static final String ADDRESS_SEPARATOR = "="; + public static final String CONFIG_KV_SEPARATOR = ":"; - public static final String CONFIG_LIST_SEPARATOR = ","; + public static final String ADDRESS_SEPARATOR = "="; - public static final String MASTER_RS_NAME_SUFFIX = "-master-rs"; + public static final String CONFIG_LIST_SEPARATOR = ","; - public static final String DRIVER_RS_NAME_SUFFIX = "-driver-rs-"; + public static final String MASTER_RS_NAME_SUFFIX = "-master-rs"; - public static final String CLIENT_NAME_SUFFIX = "-client"; + public static final String DRIVER_RS_NAME_SUFFIX = "-driver-rs-"; - public static final String MASTER_NAME_SUFFIX = "-master"; + public static final String CLIENT_NAME_SUFFIX = "-client"; - public static final String DRIVER_NAME_SUFFIX = "-driver"; + public static final String MASTER_NAME_SUFFIX = "-master"; - public static final String WORKER_NAME_SUFFIX = "-container"; + public static final String DRIVER_NAME_SUFFIX = "-driver"; - public static final String SERVICE_NAME_SUFFIX = "-service"; + public static final String WORKER_NAME_SUFFIX = "-container"; - public static final String CLIENT_SERVICE_NAME_SUFFIX = "-client-service"; + public static final String SERVICE_NAME_SUFFIX = "-service"; - public static final String DRIVER_SERVICE_NAME_SUFFIX = "-driver-service-"; + public static final String CLIENT_SERVICE_NAME_SUFFIX = "-client-service"; - public static final String MASTER_CONFIG_MAP_SUFFIX = "-master-conf-map"; + public static final String DRIVER_SERVICE_NAME_SUFFIX = "-driver-service-"; - public static final String WORKER_CONFIG_MAP_SUFFIX = "-worker-conf-map"; + public static final String MASTER_CONFIG_MAP_SUFFIX = "-master-conf-map"; - public static final String CLIENT_CONFIG_MAP_SUFFIX = "-client-conf-map"; + public static final String WORKER_CONFIG_MAP_SUFFIX = "-worker-conf-map"; - public static final String LABEL_APP_KEY = "app"; + public static final String CLIENT_CONFIG_MAP_SUFFIX = "-client-conf-map"; - public static final String LABEL_CONFIG_MAP_LOCK = "config-map-lock"; + public static final String LABEL_APP_KEY = "app"; - public static final String LABEL_COMPONENT_KEY = "component"; + public static final String LABEL_CONFIG_MAP_LOCK = "config-map-lock"; - public static final String LABEL_COMPONENT_MASTER = "master"; + public static final String LABEL_COMPONENT_KEY = "component"; - public static final String LABEL_COMPONENT_WORKER = "worker"; + public static final String LABEL_COMPONENT_MASTER = "master"; - public static final String LABEL_COMPONENT_DRIVER = "driver"; + public static final String LABEL_COMPONENT_WORKER = "worker"; - public static final String LABEL_COMPONENT_CLIENT = "client"; + public static final String LABEL_COMPONENT_DRIVER = "driver"; - public static final String LABEL_COMPONENT_ID_KEY = "component-id"; + public static final String LABEL_COMPONENT_CLIENT = "client"; - public static final String RESOURCE_NAME_MEMORY = "memory"; + public static final String LABEL_COMPONENT_ID_KEY = "component-id"; - public static final String RESOURCE_NAME_CPU = "cpu"; + public static final String RESOURCE_NAME_MEMORY = "memory"; - public static final String RESOURCE_NAME_EPHEMERAL_STORAGE = "ephemeral-storage"; + public static final String RESOURCE_NAME_CPU = "cpu"; - public static final String POD_RESTART_POLICY = "Always"; + public static final String RESOURCE_NAME_EPHEMERAL_STORAGE = "ephemeral-storage"; - public static final String HOST_ALIASES_CONFIG_MAP_NAME = "host-aliases"; + public static final String POD_RESTART_POLICY = "Always"; - public static final String HOST_NETWORK_DNS_POLICY = "ClusterFirstWithHostNet"; + public static final String HOST_ALIASES_CONFIG_MAP_NAME = "host-aliases"; - public static final String TCP_PROTOCOL = "TCP"; + public static final String HOST_NETWORK_DNS_POLICY = "ClusterFirstWithHostNet"; - public static final String ENV_NODE_NAME = "NODE_NAME"; - public static final String NODE_NAME_FIELD_PATH = "spec.nodeName"; + public static final String TCP_PROTOCOL = "TCP"; - public static final String ENV_POD_NAME = "POD_NAME"; - public static final String POD_NAME_FIELD_PATH = "metadata.name"; + public static final String ENV_NODE_NAME = "NODE_NAME"; + public static final String NODE_NAME_FIELD_PATH = "spec.nodeName"; - public static final String ENV_POD_IP = "POD_IP"; - public static final String POD_IP_FIELD_PATH = "status.podIP"; + public static final String ENV_POD_NAME = "POD_NAME"; + public static final String POD_NAME_FIELD_PATH = "metadata.name"; - public static final String ENV_HOST_IP = "HOST_IP"; - public static final String HOST_IP_FIELD_PATH = "status.hostIP"; + public static final String ENV_POD_IP = "POD_IP"; + public static final String POD_IP_FIELD_PATH = "status.podIP"; - public static final String ENV_SERVICE_ACCOUNT = "SERVICE_ACCOUNT"; - public static final String SERVICE_ACCOUNT_NAME_FIELD_PATH = "spec.serviceAccountName"; + public static final String ENV_HOST_IP = "HOST_IP"; + public static final String HOST_IP_FIELD_PATH = "status.hostIP"; - public static final String ENV_NAMESPACE = "NAMESPACE"; - public static final String NAMESPACE_FIELD_PATH = "metadata.namespace"; + public static final String ENV_SERVICE_ACCOUNT = "SERVICE_ACCOUNT"; + public static final String SERVICE_ACCOUNT_NAME_FIELD_PATH = "spec.serviceAccountName"; - // ----------------------------- Environment Variables ---------------------------- + public static final String ENV_NAMESPACE = "NAMESPACE"; + public static final String NAMESPACE_FIELD_PATH = "metadata.namespace"; - public static final String ENV_CONF_DIR = "GEAFLOW_CONF_DIR"; + // ----------------------------- Environment Variables ---------------------------- - public static final String ENV_LOG_DIR = "GEAFLOW_LOG_DIR"; + public static final String ENV_CONF_DIR = "GEAFLOW_CONF_DIR"; - public static final String ENV_JOB_WORK_PATH = "GEAFLOW_JOB_WORK_PATH"; + public static final String ENV_LOG_DIR = "GEAFLOW_LOG_DIR"; - public static final String ENV_JAR_DOWNLOAD_PATH = "GEAFLOW_JAR_DOWNLOAD_PATH"; + public static final String ENV_JOB_WORK_PATH = "GEAFLOW_JOB_WORK_PATH"; - public static final String ENV_UDF_LIST = "GEAFLOW_UDF_LIST"; + public static final String ENV_JAR_DOWNLOAD_PATH = "GEAFLOW_JAR_DOWNLOAD_PATH"; - public static final String ENV_ENGINE_JAR = "GEAFLOW_ENGINE_JAR"; + public static final String ENV_UDF_LIST = "GEAFLOW_UDF_LIST"; - public static final String ENV_PERSISTENT_ROOT = "GEAFLOW_PERSISTENT_ROOT"; + public static final String ENV_ENGINE_JAR = "GEAFLOW_ENGINE_JAR"; - public static final String ENV_CATALOG_TOKEN = "GEAFLOW_CATALOG_TOKEN"; + public static final String ENV_PERSISTENT_ROOT = "GEAFLOW_PERSISTENT_ROOT"; - public static final String ENV_GW_ENDPOINT = "GEAFLOW_GW_ENDPOINT"; + public static final String ENV_CATALOG_TOKEN = "GEAFLOW_CATALOG_TOKEN"; - public static final String ENV_START_COMMAND = "GEAFLOW_START_COMMAND"; + public static final String ENV_GW_ENDPOINT = "GEAFLOW_GW_ENDPOINT"; - public static final String ENV_CONTAINER_ID = "GEAFLOW_CONTAINER_ID"; + public static final String ENV_START_COMMAND = "GEAFLOW_START_COMMAND"; - public static final String ENV_CONTAINER_INDEX = "GEAFLOW_CONTAINER_INDEX"; + public static final String ENV_CONTAINER_ID = "GEAFLOW_CONTAINER_ID"; - public static final String ENV_IS_RECOVER = "GEAFLOW_IS_RECOVER"; + public static final String ENV_CONTAINER_INDEX = "GEAFLOW_CONTAINER_INDEX"; - public static final String ENV_AUTO_RESTART = "GEAFLOW_AUTO_RESTART"; + public static final String ENV_IS_RECOVER = "GEAFLOW_IS_RECOVER"; - public static final String ENV_ALWAYS_DOWNLOAD_ENGINE = "GEAFLOW_ALWAYS_DOWNLOAD_ENGINE_JAR"; + public static final String ENV_AUTO_RESTART = "GEAFLOW_AUTO_RESTART"; - public static final String ENV_CLUSTER_ID = "GEAFLOW_CLUSTER_ID"; + public static final String ENV_ALWAYS_DOWNLOAD_ENGINE = "GEAFLOW_ALWAYS_DOWNLOAD_ENGINE_JAR"; - public static final String ENV_CLUSTER_FAULT_INJECTION_ENABLE = - "GEAFLOW_CLUSTER_FAULT_INJECTION_ENABLE"; + public static final String ENV_CLUSTER_ID = "GEAFLOW_CLUSTER_ID"; - public static final String ENV_MASTER_ID = "GEAFLOW_MASTER_ID"; + public static final String ENV_CLUSTER_FAULT_INJECTION_ENABLE = + "GEAFLOW_CLUSTER_FAULT_INJECTION_ENABLE"; - public static final String ENV_CONFIG_FILE = "geaflow-conf.yml"; - public static final String ENV_PROFILER_PATH = "ASYNC_PROFILER_SHELL_PATH"; + public static final String ENV_MASTER_ID = "GEAFLOW_MASTER_ID"; - public static final String GEAFLOW_CONF_VOLUME = "geaflow-conf-volume"; + public static final String ENV_CONFIG_FILE = "geaflow-conf.yml"; + public static final String ENV_PROFILER_PATH = "ASYNC_PROFILER_SHELL_PATH"; - public static final String GEAFLOW_LOG_VOLUME = "geaflow-log-volume"; + public static final String GEAFLOW_CONF_VOLUME = "geaflow-conf-volume"; - public static final String MASTER_ADDRESS = "geaflow.master.address"; + public static final String GEAFLOW_LOG_VOLUME = "geaflow-log-volume"; - public static final String JOB_CLASSPATH = "$GEAFLOW_CLASSPATH"; + public static final String MASTER_ADDRESS = "geaflow.master.address"; + public static final String JOB_CLASSPATH = "$GEAFLOW_CLASSPATH"; } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesClientParam.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesClientParam.java index d3270f485..e0b046751 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesClientParam.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesClientParam.java @@ -25,6 +25,7 @@ import java.io.File; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.cluster.k8s.entrypoint.KubernetesClientRunner; import org.apache.geaflow.cluster.k8s.utils.KubernetesUtils; import org.apache.geaflow.cluster.runner.util.ClusterUtils; @@ -32,79 +33,83 @@ public class KubernetesClientParam extends AbstractKubernetesParam { - public static final String CLIENT_USER_ANNOTATIONS = "kubernetes.client.user.annotations"; - - public static final String CLIENT_NODE_SELECTOR = "kubernetes.client.node-selector"; - - public static final String CONTAINERIZED_CLIENT_ENV_PREFIX = "containerized.client.env."; - - public static final String CLIENT_LOG_SUFFIX = "client.log"; - - private static final String CLIENT_AUTO_RESTART = "false"; - - public KubernetesClientParam(Configuration config) { - super(config); - } - - @Override - public Double getContainerCpu() { - return clusterConfig.getClientVcores(); - } - - @Override - public long getContainerMemoryMB() { - return clusterConfig.getClientMemoryMB(); - } - - @Override - protected long getContainerDiskGB() { - return clusterConfig.getClientDiskGB(); - } - - @Override - public String getContainerShellCommand() { - String logFilename = getLogDir() + File.separator + CLIENT_LOG_SUFFIX; - return ClusterUtils.getStartCommand(clusterConfig.getClientJvmOptions(), - KubernetesClientRunner.class, logFilename, config, JOB_CLASSPATH); - } - - @Override - public Map getAdditionEnvs() { - return KubernetesUtils.getVariablesWithPrefix(CONTAINERIZED_CLIENT_ENV_PREFIX, - config.getConfigMap()); - } - - @Override - public String getPodNamePrefix(String clusterId) { - return clusterId + K8SConstants.CLIENT_NAME_SUFFIX + K8SConstants.NAME_SEPARATOR; - } - - @Override - public String getConfigMapName(String clusterId) { - return clusterId + K8SConstants.CLIENT_CONFIG_MAP_SUFFIX; - } - - @Override - public Map getPodLabels(String clusterId) { - Map workerPodLabels = new HashMap<>(); - workerPodLabels.put(K8SConstants.LABEL_APP_KEY, clusterId); - workerPodLabels.put(K8SConstants.LABEL_COMPONENT_KEY, K8SConstants.LABEL_COMPONENT_CLIENT); - workerPodLabels.putAll(KubernetesUtils.getPairsConf(config, POD_USER_LABELS)); - return workerPodLabels; - } - - @Override - public Map getAnnotations() { - return KubernetesUtils.getPairsConf(config, CLIENT_USER_ANNOTATIONS); - } - - @Override - public Map getNodeSelector() { - return KubernetesUtils.getPairsConf(config, CLIENT_NODE_SELECTOR); - } - - @Override - public String getAutoRestart() { - return CLIENT_AUTO_RESTART; - } + public static final String CLIENT_USER_ANNOTATIONS = "kubernetes.client.user.annotations"; + + public static final String CLIENT_NODE_SELECTOR = "kubernetes.client.node-selector"; + + public static final String CONTAINERIZED_CLIENT_ENV_PREFIX = "containerized.client.env."; + + public static final String CLIENT_LOG_SUFFIX = "client.log"; + + private static final String CLIENT_AUTO_RESTART = "false"; + + public KubernetesClientParam(Configuration config) { + super(config); + } + + @Override + public Double getContainerCpu() { + return clusterConfig.getClientVcores(); + } + + @Override + public long getContainerMemoryMB() { + return clusterConfig.getClientMemoryMB(); + } + + @Override + protected long getContainerDiskGB() { + return clusterConfig.getClientDiskGB(); + } + + @Override + public String getContainerShellCommand() { + String logFilename = getLogDir() + File.separator + CLIENT_LOG_SUFFIX; + return ClusterUtils.getStartCommand( + clusterConfig.getClientJvmOptions(), + KubernetesClientRunner.class, + logFilename, + config, + JOB_CLASSPATH); + } + + @Override + public Map getAdditionEnvs() { + return KubernetesUtils.getVariablesWithPrefix( + CONTAINERIZED_CLIENT_ENV_PREFIX, config.getConfigMap()); + } + + @Override + public String getPodNamePrefix(String clusterId) { + return clusterId + K8SConstants.CLIENT_NAME_SUFFIX + K8SConstants.NAME_SEPARATOR; + } + + @Override + public String getConfigMapName(String clusterId) { + return clusterId + K8SConstants.CLIENT_CONFIG_MAP_SUFFIX; + } + + @Override + public Map getPodLabels(String clusterId) { + Map workerPodLabels = new HashMap<>(); + workerPodLabels.put(K8SConstants.LABEL_APP_KEY, clusterId); + workerPodLabels.put(K8SConstants.LABEL_COMPONENT_KEY, K8SConstants.LABEL_COMPONENT_CLIENT); + workerPodLabels.putAll(KubernetesUtils.getPairsConf(config, POD_USER_LABELS)); + return workerPodLabels; + } + + @Override + public Map getAnnotations() { + return KubernetesUtils.getPairsConf(config, CLIENT_USER_ANNOTATIONS); + } + + @Override + public Map getNodeSelector() { + return KubernetesUtils.getPairsConf(config, CLIENT_NODE_SELECTOR); + } + + @Override + public String getAutoRestart() { + return CLIENT_AUTO_RESTART; + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesConfig.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesConfig.java index 3fe862793..f9d08988b 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesConfig.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesConfig.java @@ -42,132 +42,127 @@ public class KubernetesConfig { - private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesConfig.class); + private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesConfig.class); - public static final String CLUSTER_START_TIME = "kubernetes.cluster.start-time"; + public static final String CLUSTER_START_TIME = "kubernetes.cluster.start-time"; - public static final String CLIENT_MASTER_URL = "kubernetes.client.master.url"; + public static final String CLIENT_MASTER_URL = "kubernetes.client.master.url"; - public static final String MASTER_EXPOSED_ADDRESS = "kubernetes.master.exposed.address"; + public static final String MASTER_EXPOSED_ADDRESS = "kubernetes.master.exposed.address"; - public static final String DRIVER_EXPOSED_ADDRESS = "kubernetes.driver.exposed.address"; + public static final String DRIVER_EXPOSED_ADDRESS = "kubernetes.driver.exposed.address"; - /** - * Service exposed type on kubernetes cluster. - */ - public enum ServiceExposedType { - CLUSTER_IP("ClusterIP"), - NODE_PORT("NodePort"), - LOAD_BALANCER("LoadBalancer"), - EXTERNAL_NAME("ExternalName"); - - private final String serviceExposedType; + /** Service exposed type on kubernetes cluster. */ + public enum ServiceExposedType { + CLUSTER_IP("ClusterIP"), + NODE_PORT("NodePort"), + LOAD_BALANCER("LoadBalancer"), + EXTERNAL_NAME("ExternalName"); - ServiceExposedType(String type) { - serviceExposedType = type; - } + private final String serviceExposedType; - public String getServiceExposedType() { - return serviceExposedType; - } + ServiceExposedType(String type) { + serviceExposedType = type; + } - /** - * Convert exposed type string in kubernetes spec to ServiceExposedType. - * - * @param type exposed in kubernetes spec, e.g. LoadBanlancer - * @return ServiceExposedType - */ - public static ServiceExposedType fromString(String type) { - for (ServiceExposedType exposedType : ServiceExposedType.values()) { - if (exposedType.getServiceExposedType().equals(type)) { - return exposedType; - } - } - return CLUSTER_IP; - } + public String getServiceExposedType() { + return serviceExposedType; } /** - * The network type which be used by docker daemon. + * Convert exposed type string in kubernetes spec to ServiceExposedType. + * + * @param type exposed in kubernetes spec, e.g. LoadBanlancer + * @return ServiceExposedType */ - public enum DockerNetworkType { - BRIDGE, - HOST; - - /** - * Convert network type string to DockerNetworkType. - * - * @param type Docker network type, e.g. bridge - * @return DockerNetworkType - */ - public static DockerNetworkType fromString(String type) { - for (DockerNetworkType dockerNetworkType : DockerNetworkType.values()) { - if (dockerNetworkType.toString().equals(type)) { - return dockerNetworkType; - } - } - LOGGER.warn("Docker network type {} is not supported, BRIDGE network will be used", type); - return BRIDGE; + public static ServiceExposedType fromString(String type) { + for (ServiceExposedType exposedType : ServiceExposedType.values()) { + if (exposedType.getServiceExposedType().equals(type)) { + return exposedType; } + } + return CLUSTER_IP; } + } + + /** The network type which be used by docker daemon. */ + public enum DockerNetworkType { + BRIDGE, + HOST; - public static String getClientMasterUrl(Configuration config) { - String url = config.getString(CLIENT_MASTER_URL); - if (StringUtils.isNotBlank(url)) { - return url; - } else { - return config.getString(MASTER_URL); + /** + * Convert network type string to DockerNetworkType. + * + * @param type Docker network type, e.g. bridge + * @return DockerNetworkType + */ + public static DockerNetworkType fromString(String type) { + for (DockerNetworkType dockerNetworkType : DockerNetworkType.values()) { + if (dockerNetworkType.toString().equals(type)) { + return dockerNetworkType; } + } + LOGGER.warn("Docker network type {} is not supported, BRIDGE network will be used", type); + return BRIDGE; } - - public static int getClientTimeoutMs(Configuration config) { - return config.getInteger(CLUSTER_CLIENT_TIMEOUT_MS); + } + + public static String getClientMasterUrl(Configuration config) { + String url = config.getString(CLIENT_MASTER_URL); + if (StringUtils.isNotBlank(url)) { + return url; + } else { + return config.getString(MASTER_URL); } + } - public static int getConnectionRetryTimes(Configuration config) { - return config.getInteger(CONNECTION_RETRY_TIMES); - } + public static int getClientTimeoutMs(Configuration config) { + return config.getInteger(CLUSTER_CLIENT_TIMEOUT_MS); + } - public static long getConnectionRetryIntervalMs(Configuration config) { - return config.getLong(CONNECTION_RETRY_INTERVAL_MS); - } + public static int getConnectionRetryTimes(Configuration config) { + return config.getInteger(CONNECTION_RETRY_TIMES); + } - public static DockerNetworkType getDockerNetworkType(Configuration config) { - String value = config.getString(DOCKER_NETWORK_TYPE); - return DockerNetworkType.fromString(value); - } + public static long getConnectionRetryIntervalMs(Configuration config) { + return config.getLong(CONNECTION_RETRY_INTERVAL_MS); + } - public static ServiceExposedType getServiceExposedType(Configuration config) { - String value = config.getString(SERVICE_EXPOSED_TYPE); - return ServiceExposedType.valueOf(value); - } + public static DockerNetworkType getDockerNetworkType(Configuration config) { + String value = config.getString(DOCKER_NETWORK_TYPE); + return DockerNetworkType.fromString(value); + } - public static String getServiceName(Configuration config) { - return KubernetesUtils.getMasterServiceName(config.getString(CLUSTER_ID)); - } + public static ServiceExposedType getServiceExposedType(Configuration config) { + String value = config.getString(SERVICE_EXPOSED_TYPE); + return ServiceExposedType.valueOf(value); + } - public static String getServiceNameWithNamespace(Configuration config) { - return getServiceName(config) + K8SConstants.NAMESPACE_SEPARATOR + config.getString(NAME_SPACE); - } + public static String getServiceName(Configuration config) { + return KubernetesUtils.getMasterServiceName(config.getString(CLUSTER_ID)); + } - public static boolean enableResourceMemoryLimit(Configuration config) { - return config.getBoolean(ENABLE_RESOURCE_MEMORY_LIMIT); - } + public static String getServiceNameWithNamespace(Configuration config) { + return getServiceName(config) + K8SConstants.NAMESPACE_SEPARATOR + config.getString(NAME_SPACE); + } - public static boolean enableResourceCpuLimit(Configuration config) { - return config.getBoolean(ENABLE_RESOURCE_CPU_LIMIT); - } + public static boolean enableResourceMemoryLimit(Configuration config) { + return config.getBoolean(ENABLE_RESOURCE_MEMORY_LIMIT); + } - public static boolean enableResourceEphemeralStorageLimit(Configuration config) { - return config.getBoolean(ENABLE_RESOURCE_EPHEMERAL_STORAGE_LIMIT); - } + public static boolean enableResourceCpuLimit(Configuration config) { + return config.getBoolean(ENABLE_RESOURCE_CPU_LIMIT); + } - public static String getResourceEphemeralStorageSize(Configuration config) { - return config.getString(DEFAULT_RESOURCE_EPHEMERAL_STORAGE_SIZE); - } + public static boolean enableResourceEphemeralStorageLimit(Configuration config) { + return config.getBoolean(ENABLE_RESOURCE_EPHEMERAL_STORAGE_LIMIT); + } - public static String getJarDownloadPath(Configuration config) { - return FileUtils.getFile(config.getString(WORK_DIR), "jar").toString(); - } + public static String getResourceEphemeralStorageSize(Configuration config) { + return config.getString(DEFAULT_RESOURCE_EPHEMERAL_STORAGE_SIZE); + } + public static String getJarDownloadPath(Configuration config) { + return FileUtils.getFile(config.getString(WORK_DIR), "jar").toString(); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesConfigKeys.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesConfigKeys.java index ff57a4aa6..93f1d4ce1 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesConfigKeys.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesConfigKeys.java @@ -28,226 +28,258 @@ * additional information regarding copyright ownership. */ /** - * This class is an adaptation of Flink's org.apache.flink.kubernetes.configuration.KubernetesConfigOptions. + * This class is an adaptation of Flink's + * org.apache.flink.kubernetes.configuration.KubernetesConfigOptions. */ public class KubernetesConfigKeys { - public static final ConfigKey CERT_DATA = ConfigKeys.key("kubernetes.cert.data") - .defaultValue("") - .description("kubernetes client cert data"); - - public static final ConfigKey CERT_KEY = ConfigKeys.key("kubernetes.cert.key") - .defaultValue("") - .description("kubernetes client cert key data"); - - public static final ConfigKey CA_DATA = ConfigKeys.key("kubernetes.ca.data") - .defaultValue("") - .description("kubernetes cluster ca data"); - - public static final ConfigKey NAME_SPACE = ConfigKeys.key("kubernetes.namespace") - .defaultValue("default") - .description("kubernetes namespace"); - - public static final ConfigKey CLUSTER_NAME = ConfigKeys.key("kubernetes.cluster.name") - .defaultValue("") - .description("kubernetes cluster name"); - - public static final ConfigKey CLUSTER_FAULT_INJECTION_ENABLE = ConfigKeys.key("kubernetes" - + ".cluster.fault-injection.enable") - .defaultValue(false) - .description("kubernetes cluster fo enable"); - - public static final ConfigKey MASTER_URL = ConfigKeys.key("kubernetes.master.url") - .defaultValue("https://kubernetes.default.svc") - .description("kubernetes cluster master url"); - - public static final ConfigKey SERVICE_SUFFIX = ConfigKeys.key("kubernetes.service.suffix") - .defaultValue("") - .description("suffix to append to the service name"); - - public static final ConfigKey SERVICE_ACCOUNT = ConfigKeys.key("kubernetes.service.account") - .defaultValue("geaflow") - .description("kubernetes service account to request resources from api server"); - - public static final ConfigKey SERVICE_EXPOSED_TYPE = ConfigKeys - .key("kubernetes.service.exposed.type") - .defaultValue(ServiceExposedType.NODE_PORT.name()) - .description("kubernetes service exposed service type"); - - public static final ConfigKey SERVICE_DNS_ENV = ConfigKeys - .key("kubernetes.service.dns.env") - .defaultValue(null) - .description("kubernetes service dns env"); - - public static final ConfigKey SERVICE_USER_LABELS = ConfigKeys.key("kubernetes.service.user.labels") - .defaultValue("") - .description("The labels to be set for services. Specified as key:value pairs separated by " - + "commas. such as version:alphav1,deploy:test."); - - public static final ConfigKey SERVICE_USER_ANNOTATIONS = ConfigKeys.key("kubernetes.service.user.annotations") - .defaultValue("") - .description("The annotations to be set for services. Specified as key:value pairs separated by " - + "commas. such as version:alphav1,deploy:test."); - - public static final ConfigKey DNS_SEARCH_DOMAINS = ConfigKeys - .key("kubernetes.pods.dns.search.domains") - .defaultValue("") - .description("dns search domain config"); - - public static final ConfigKey CONNECTION_RETRY_TIMES = ConfigKeys - .key("kubernetes.connection.retry.times") - .defaultValue(100) - .description("max retry to connect to api server"); - - public static final ConfigKey CONNECTION_RETRY_INTERVAL_MS = ConfigKeys - .key("kubernetes.connection.retry.interval.ms") - .defaultValue(1000L) - .description("max connect retry interval in ms"); - - public static final ConfigKey PING_INTERVAL_MS = ConfigKeys - .key("kubernetes.websocketPingInterval.ms") - .defaultValue(10000L) - .description("client ping interval in ms"); - - public static final ConfigKey POD_USER_LABELS = ConfigKeys.key("kubernetes.pod.user.labels") - .defaultValue("") - .description("The labels to be set for pods. Specified as key:value pairs separated by " - + "commas. such as version:alphav1,deploy:test."); - - public static final ConfigKey CONTAINER_IMAGE = ConfigKeys.key("kubernetes.container.image") - .defaultValue("geaflow-k8s:latest") - .description("container image name"); - - public static final ConfigKey CONTAINER_IMAGE_PULL_POLICY = ConfigKeys - .key("kubernetes.container.image.pullPolicy") - .defaultValue("IfNotPresent") - .description("container image pull policy"); - - public static final ConfigKey CONTAINER_CONF_FILES = ConfigKeys - .key("kubernetes.container.conf.files") - .defaultValue("/opt/geaflow/conf/log4j.properties") - .description("files to be used within containers"); - - public static final ConfigKey ENABLE_RESOURCE_MEMORY_LIMIT = ConfigKeys - .key("kubernetes.enable.resource.memory.limit") - .defaultValue(true) - .description("enable container memory limit"); - - public static final ConfigKey ENABLE_RESOURCE_CPU_LIMIT = ConfigKeys - .key("kubernetes.enable.resource.cpu.limit") - .defaultValue(true) - .description("enable container cpu limit"); - - public static final ConfigKey ENABLE_RESOURCE_EPHEMERAL_STORAGE_LIMIT = ConfigKeys - .key("kubernetes.enable.resource.storage.limit") - .defaultValue(true) - .description("enable container disk storage limit"); - - public static final ConfigKey DEFAULT_RESOURCE_EPHEMERAL_STORAGE_SIZE = ConfigKeys - .key("kubernetes.resource.storage.limit.size") - .defaultValue("15Gi") - .description("default container storage size"); - - public static final ConfigKey DOCKER_NETWORK_TYPE = ConfigKeys - .key("kubernetes.docker.network.type") - .defaultValue("BRIDGE") - .description("It could be BRIDGE/HOST."); - - public static final ConfigKey USE_IP_IN_HOST_NETWORK = ConfigKeys - .key("kubernetes.use-ip-in-host-network") - .defaultValue(true) - .description("whether to use ip in host network"); - - public static final ConfigKey ENABLE_LOG_DISK_LESS = ConfigKeys - .key("kubernetes.log.diskless.enable") - .defaultValue(true) - .description("whether to enable log diskless"); - - public static final ConfigKey TOLERATION_LIST = ConfigKeys - .key("kubernetes.toleration.list") - .noDefaultValue() - .description("Multiple tolerations will be separated by commas. Each toleration contains " - + "five parts, key:operator:value:effect:tolerationSeconds. Use - instead if the part " - + "is null. For example, key1:Equal:value1:NoSchedule:-,key2:Exists:-:-:-," - + "key3:Equal:value3:NoExecute:3600"); - - public static final ConfigKey MATCH_EXPRESSION_LIST = ConfigKeys - .key("kubernetes.match-expression.list") - .noDefaultValue() - .description("Multiple match-expressions will be separated by commas. Each " - + "match-expression contains " - + "five parts, key:operator:value:effect:tolerationSeconds. Use - instead if the part " - + "is null. For example, key1:Equal:value1:NoSchedule:-,key2:Exists:-:-:-," - + "key3:Equal:value3:NoExecute:3600"); - - public static final ConfigKey EVICTED_POD_LABELS = ConfigKeys - .key("kubernetes.pods.evict.labels") - .defaultValue("pod.sigma.ali/eviction:true") - .description("The labels of pod to be evicted"); - - public static final ConfigKey CONF_DIR = ConfigKeys.key("kubernetes.geaflow.conf.dir") - .defaultValue("/etc/geaflow/conf") - .description("geaflow conf directory"); - - public static final ConfigKey LOG_DIR = ConfigKeys.key("kubernetes.geaflow.log.dir") - .defaultValue("/home/admin/logs/geaflow") - .description("geaflow job log directory"); - - public static final ConfigKey WATCHER_CHECK_INTERVAL = ConfigKeys - .key("kubernetes.watcher.check.interval.seconds") - .defaultValue(60) - .description("time interval to check watcher liveness in seconds"); - - public static final ConfigKey DRIVER_NODE_PORT = ConfigKeys.key("kubernetes.driver.node.port") - .defaultValue(0) - .description("driver node port"); - - public static final ConfigKey WORK_DIR = ConfigKeys.key("kubernetes.geaflow.work.dir") - .defaultValue("/home/admin/geaflow/tmp") - .description("job work dir"); - - public static final ConfigKey ENGINE_JAR_FILES = ConfigKeys.key("kubernetes.engine.jar.files") - .defaultValue("") - .description("engine jar files, separated by comma"); - - public static final ConfigKey USER_JAR_FILES = ConfigKeys.key("kubernetes.user.jar.files") - .defaultValue("") - .description("user udf jar files, separated by comma"); - - public static final ConfigKey USER_MAIN_CLASS = ConfigKeys.key("kubernetes.user.main.class") - .noDefaultValue() - .description("the main class of user program"); - - public static final ConfigKey USER_CLASS_ARGS = ConfigKeys.key("kubernetes.user.class.args") - .noDefaultValue() - .description("the args of user mainClass"); - - public static final ConfigKey PROCESS_AUTO_RESTART = ConfigKeys.key("kubernetes.cluster.process.auto-restart") - .defaultValue("unexpected") - .description("whether to restart process automatically"); - - public static final ConfigKey CLIENT_KEY_ALGO = ConfigKeys.key("kubernetes.certs.client.key.algo") - .defaultValue("") - .description("client key algo"); - - public static final ConfigKey LEADER_ELECTION_LEASE_DURATION = ConfigKeys.key("kubernetes.leader-election.lease-duration") - .defaultValue(15) - .description("The duration seconds of once leader-election in kubernetes. Contenders can " - + "try to contend for a new leader after the previous leader invalid"); - - public static final ConfigKey LEADER_ELECTION_RENEW_DEADLINE = ConfigKeys.key("kubernetes.leader-election.renew-deadline") - .defaultValue(15) - .description("The deadline seconds of once leader-election in kubernetes. The current " - + "leader must renew the leadership within the deadline, or the leadership will be " - + "invalid after lease duration"); - - public static final ConfigKey LEADER_ELECTION_RETRY_PERIOD = ConfigKeys.key("kubernetes.leader-election.retry-period") - .defaultValue(5) - .description("The interval seconds of each contenders to try to contend for a new leader," - + " also is the interval seconds of current leader to renew for its leadership lease"); - - public static final ConfigKey ALWAYS_PULL_ENGINE_JAR = ConfigKeys - .key("kubernetes.engine.jar.pull.always") - .defaultValue(false) - .description("whether to always pull the remote engine jar to replace local ones"); + public static final ConfigKey CERT_DATA = + ConfigKeys.key("kubernetes.cert.data") + .defaultValue("") + .description("kubernetes client cert data"); + + public static final ConfigKey CERT_KEY = + ConfigKeys.key("kubernetes.cert.key") + .defaultValue("") + .description("kubernetes client cert key data"); + + public static final ConfigKey CA_DATA = + ConfigKeys.key("kubernetes.ca.data") + .defaultValue("") + .description("kubernetes cluster ca data"); + + public static final ConfigKey NAME_SPACE = + ConfigKeys.key("kubernetes.namespace") + .defaultValue("default") + .description("kubernetes namespace"); + + public static final ConfigKey CLUSTER_NAME = + ConfigKeys.key("kubernetes.cluster.name") + .defaultValue("") + .description("kubernetes cluster name"); + + public static final ConfigKey CLUSTER_FAULT_INJECTION_ENABLE = + ConfigKeys.key("kubernetes" + ".cluster.fault-injection.enable") + .defaultValue(false) + .description("kubernetes cluster fo enable"); + + public static final ConfigKey MASTER_URL = + ConfigKeys.key("kubernetes.master.url") + .defaultValue("https://kubernetes.default.svc") + .description("kubernetes cluster master url"); + + public static final ConfigKey SERVICE_SUFFIX = + ConfigKeys.key("kubernetes.service.suffix") + .defaultValue("") + .description("suffix to append to the service name"); + + public static final ConfigKey SERVICE_ACCOUNT = + ConfigKeys.key("kubernetes.service.account") + .defaultValue("geaflow") + .description("kubernetes service account to request resources from api server"); + + public static final ConfigKey SERVICE_EXPOSED_TYPE = + ConfigKeys.key("kubernetes.service.exposed.type") + .defaultValue(ServiceExposedType.NODE_PORT.name()) + .description("kubernetes service exposed service type"); + + public static final ConfigKey SERVICE_DNS_ENV = + ConfigKeys.key("kubernetes.service.dns.env") + .defaultValue(null) + .description("kubernetes service dns env"); + + public static final ConfigKey SERVICE_USER_LABELS = + ConfigKeys.key("kubernetes.service.user.labels") + .defaultValue("") + .description( + "The labels to be set for services. Specified as key:value pairs separated by " + + "commas. such as version:alphav1,deploy:test."); + + public static final ConfigKey SERVICE_USER_ANNOTATIONS = + ConfigKeys.key("kubernetes.service.user.annotations") + .defaultValue("") + .description( + "The annotations to be set for services. Specified as key:value pairs separated by " + + "commas. such as version:alphav1,deploy:test."); + + public static final ConfigKey DNS_SEARCH_DOMAINS = + ConfigKeys.key("kubernetes.pods.dns.search.domains") + .defaultValue("") + .description("dns search domain config"); + + public static final ConfigKey CONNECTION_RETRY_TIMES = + ConfigKeys.key("kubernetes.connection.retry.times") + .defaultValue(100) + .description("max retry to connect to api server"); + + public static final ConfigKey CONNECTION_RETRY_INTERVAL_MS = + ConfigKeys.key("kubernetes.connection.retry.interval.ms") + .defaultValue(1000L) + .description("max connect retry interval in ms"); + + public static final ConfigKey PING_INTERVAL_MS = + ConfigKeys.key("kubernetes.websocketPingInterval.ms") + .defaultValue(10000L) + .description("client ping interval in ms"); + + public static final ConfigKey POD_USER_LABELS = + ConfigKeys.key("kubernetes.pod.user.labels") + .defaultValue("") + .description( + "The labels to be set for pods. Specified as key:value pairs separated by " + + "commas. such as version:alphav1,deploy:test."); + + public static final ConfigKey CONTAINER_IMAGE = + ConfigKeys.key("kubernetes.container.image") + .defaultValue("geaflow-k8s:latest") + .description("container image name"); + + public static final ConfigKey CONTAINER_IMAGE_PULL_POLICY = + ConfigKeys.key("kubernetes.container.image.pullPolicy") + .defaultValue("IfNotPresent") + .description("container image pull policy"); + + public static final ConfigKey CONTAINER_CONF_FILES = + ConfigKeys.key("kubernetes.container.conf.files") + .defaultValue("/opt/geaflow/conf/log4j.properties") + .description("files to be used within containers"); + + public static final ConfigKey ENABLE_RESOURCE_MEMORY_LIMIT = + ConfigKeys.key("kubernetes.enable.resource.memory.limit") + .defaultValue(true) + .description("enable container memory limit"); + + public static final ConfigKey ENABLE_RESOURCE_CPU_LIMIT = + ConfigKeys.key("kubernetes.enable.resource.cpu.limit") + .defaultValue(true) + .description("enable container cpu limit"); + + public static final ConfigKey ENABLE_RESOURCE_EPHEMERAL_STORAGE_LIMIT = + ConfigKeys.key("kubernetes.enable.resource.storage.limit") + .defaultValue(true) + .description("enable container disk storage limit"); + + public static final ConfigKey DEFAULT_RESOURCE_EPHEMERAL_STORAGE_SIZE = + ConfigKeys.key("kubernetes.resource.storage.limit.size") + .defaultValue("15Gi") + .description("default container storage size"); + + public static final ConfigKey DOCKER_NETWORK_TYPE = + ConfigKeys.key("kubernetes.docker.network.type") + .defaultValue("BRIDGE") + .description("It could be BRIDGE/HOST."); + + public static final ConfigKey USE_IP_IN_HOST_NETWORK = + ConfigKeys.key("kubernetes.use-ip-in-host-network") + .defaultValue(true) + .description("whether to use ip in host network"); + + public static final ConfigKey ENABLE_LOG_DISK_LESS = + ConfigKeys.key("kubernetes.log.diskless.enable") + .defaultValue(true) + .description("whether to enable log diskless"); + + public static final ConfigKey TOLERATION_LIST = + ConfigKeys.key("kubernetes.toleration.list") + .noDefaultValue() + .description( + "Multiple tolerations will be separated by commas. Each toleration contains five" + + " parts, key:operator:value:effect:tolerationSeconds. Use - instead if the part" + + " is null. For example, key1:Equal:value1:NoSchedule:-,key2:Exists:-:-:-," + + "key3:Equal:value3:NoExecute:3600"); + + public static final ConfigKey MATCH_EXPRESSION_LIST = + ConfigKeys.key("kubernetes.match-expression.list") + .noDefaultValue() + .description( + "Multiple match-expressions will be separated by commas. Each match-expression" + + " contains five parts, key:operator:value:effect:tolerationSeconds. Use -" + + " instead if the part is null. For example," + + " key1:Equal:value1:NoSchedule:-,key2:Exists:-:-:-," + + "key3:Equal:value3:NoExecute:3600"); + + public static final ConfigKey EVICTED_POD_LABELS = + ConfigKeys.key("kubernetes.pods.evict.labels") + .defaultValue("pod.sigma.ali/eviction:true") + .description("The labels of pod to be evicted"); + + public static final ConfigKey CONF_DIR = + ConfigKeys.key("kubernetes.geaflow.conf.dir") + .defaultValue("/etc/geaflow/conf") + .description("geaflow conf directory"); + + public static final ConfigKey LOG_DIR = + ConfigKeys.key("kubernetes.geaflow.log.dir") + .defaultValue("/home/admin/logs/geaflow") + .description("geaflow job log directory"); + + public static final ConfigKey WATCHER_CHECK_INTERVAL = + ConfigKeys.key("kubernetes.watcher.check.interval.seconds") + .defaultValue(60) + .description("time interval to check watcher liveness in seconds"); + + public static final ConfigKey DRIVER_NODE_PORT = + ConfigKeys.key("kubernetes.driver.node.port").defaultValue(0).description("driver node port"); + + public static final ConfigKey WORK_DIR = + ConfigKeys.key("kubernetes.geaflow.work.dir") + .defaultValue("/home/admin/geaflow/tmp") + .description("job work dir"); + + public static final ConfigKey ENGINE_JAR_FILES = + ConfigKeys.key("kubernetes.engine.jar.files") + .defaultValue("") + .description("engine jar files, separated by comma"); + + public static final ConfigKey USER_JAR_FILES = + ConfigKeys.key("kubernetes.user.jar.files") + .defaultValue("") + .description("user udf jar files, separated by comma"); + + public static final ConfigKey USER_MAIN_CLASS = + ConfigKeys.key("kubernetes.user.main.class") + .noDefaultValue() + .description("the main class of user program"); + + public static final ConfigKey USER_CLASS_ARGS = + ConfigKeys.key("kubernetes.user.class.args") + .noDefaultValue() + .description("the args of user mainClass"); + + public static final ConfigKey PROCESS_AUTO_RESTART = + ConfigKeys.key("kubernetes.cluster.process.auto-restart") + .defaultValue("unexpected") + .description("whether to restart process automatically"); + + public static final ConfigKey CLIENT_KEY_ALGO = + ConfigKeys.key("kubernetes.certs.client.key.algo") + .defaultValue("") + .description("client key algo"); + + public static final ConfigKey LEADER_ELECTION_LEASE_DURATION = + ConfigKeys.key("kubernetes.leader-election.lease-duration") + .defaultValue(15) + .description( + "The duration seconds of once leader-election in kubernetes. Contenders can " + + "try to contend for a new leader after the previous leader invalid"); + + public static final ConfigKey LEADER_ELECTION_RENEW_DEADLINE = + ConfigKeys.key("kubernetes.leader-election.renew-deadline") + .defaultValue(15) + .description( + "The deadline seconds of once leader-election in kubernetes. The current leader must" + + " renew the leadership within the deadline, or the leadership will be invalid" + + " after lease duration"); + + public static final ConfigKey LEADER_ELECTION_RETRY_PERIOD = + ConfigKeys.key("kubernetes.leader-election.retry-period") + .defaultValue(5) + .description( + "The interval seconds of each contenders to try to contend for a new leader, also is" + + " the interval seconds of current leader to renew for its leadership lease"); + + public static final ConfigKey ALWAYS_PULL_ENGINE_JAR = + ConfigKeys.key("kubernetes.engine.jar.pull.always") + .defaultValue(false) + .description("whether to always pull the remote engine jar to replace local ones"); } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesContainerParam.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesContainerParam.java index b2aa60232..43f2a29c1 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesContainerParam.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesContainerParam.java @@ -26,6 +26,7 @@ import java.io.File; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.cluster.config.ClusterConfig; import org.apache.geaflow.cluster.k8s.entrypoint.KubernetesContainerRunner; import org.apache.geaflow.cluster.k8s.utils.KubernetesUtils; @@ -34,74 +35,77 @@ public class KubernetesContainerParam extends AbstractKubernetesParam { - public static final String CONTAINER_USER_ANNOTATIONS = "kubernetes.container.user.annotations"; - - public static final String CONTAINER_NODE_SELECTOR = "kubernetes.container.node-selector"; - - public static final String CONTAINER_ENV_PREFIX = "kubernetes.container.env."; - - public KubernetesContainerParam(Configuration config) { - super(config); - } - - public KubernetesContainerParam(ClusterConfig config) { - super(config); - } - - @Override - public Double getContainerCpu() { - return clusterConfig.getContainerVcores(); - } - - @Override - public long getContainerMemoryMB() { - return clusterConfig.getContainerMemoryMB(); - } - - @Override - protected long getContainerDiskGB() { - return clusterConfig.getContainerDiskGB(); - } - - @Override - public String getContainerShellCommand() { - String logFilename = getLogDir() + File.separator + CONTAINER_LOG_SUFFIX; - return ClusterUtils.getStartCommand(clusterConfig.getContainerJvmOptions(), - KubernetesContainerRunner.class, logFilename, config, JOB_CLASSPATH); - } - - @Override - public Map getAdditionEnvs() { - return KubernetesUtils - .getVariablesWithPrefix(CONTAINER_ENV_PREFIX, config.getConfigMap()); - } - - @Override - public String getPodNamePrefix(String clusterId) { - return clusterId + K8SConstants.WORKER_NAME_SUFFIX + K8SConstants.NAME_SEPARATOR; - } - - @Override - public String getConfigMapName(String clusterId) { - return clusterId + K8SConstants.WORKER_CONFIG_MAP_SUFFIX; - } - - @Override - public Map getPodLabels(String clusterId) { - Map workerPodLabels = new HashMap<>(); - workerPodLabels.put(K8SConstants.LABEL_APP_KEY, clusterId); - workerPodLabels.put(K8SConstants.LABEL_COMPONENT_KEY, K8SConstants.LABEL_COMPONENT_WORKER); - workerPodLabels.putAll(KubernetesUtils.getPairsConf(config, POD_USER_LABELS)); - return workerPodLabels; - } - - @Override - public Map getAnnotations() { - return KubernetesUtils.getPairsConf(config, CONTAINER_USER_ANNOTATIONS); - } - - @Override - public Map getNodeSelector() { - return KubernetesUtils.getPairsConf(config, CONTAINER_NODE_SELECTOR); - } + public static final String CONTAINER_USER_ANNOTATIONS = "kubernetes.container.user.annotations"; + + public static final String CONTAINER_NODE_SELECTOR = "kubernetes.container.node-selector"; + + public static final String CONTAINER_ENV_PREFIX = "kubernetes.container.env."; + + public KubernetesContainerParam(Configuration config) { + super(config); + } + + public KubernetesContainerParam(ClusterConfig config) { + super(config); + } + + @Override + public Double getContainerCpu() { + return clusterConfig.getContainerVcores(); + } + + @Override + public long getContainerMemoryMB() { + return clusterConfig.getContainerMemoryMB(); + } + + @Override + protected long getContainerDiskGB() { + return clusterConfig.getContainerDiskGB(); + } + + @Override + public String getContainerShellCommand() { + String logFilename = getLogDir() + File.separator + CONTAINER_LOG_SUFFIX; + return ClusterUtils.getStartCommand( + clusterConfig.getContainerJvmOptions(), + KubernetesContainerRunner.class, + logFilename, + config, + JOB_CLASSPATH); + } + + @Override + public Map getAdditionEnvs() { + return KubernetesUtils.getVariablesWithPrefix(CONTAINER_ENV_PREFIX, config.getConfigMap()); + } + + @Override + public String getPodNamePrefix(String clusterId) { + return clusterId + K8SConstants.WORKER_NAME_SUFFIX + K8SConstants.NAME_SEPARATOR; + } + + @Override + public String getConfigMapName(String clusterId) { + return clusterId + K8SConstants.WORKER_CONFIG_MAP_SUFFIX; + } + + @Override + public Map getPodLabels(String clusterId) { + Map workerPodLabels = new HashMap<>(); + workerPodLabels.put(K8SConstants.LABEL_APP_KEY, clusterId); + workerPodLabels.put(K8SConstants.LABEL_COMPONENT_KEY, K8SConstants.LABEL_COMPONENT_WORKER); + workerPodLabels.putAll(KubernetesUtils.getPairsConf(config, POD_USER_LABELS)); + return workerPodLabels; + } + + @Override + public Map getAnnotations() { + return KubernetesUtils.getPairsConf(config, CONTAINER_USER_ANNOTATIONS); + } + + @Override + public Map getNodeSelector() { + return KubernetesUtils.getPairsConf(config, CONTAINER_NODE_SELECTOR); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesDriverParam.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesDriverParam.java index 035e57bd8..a5128d0e5 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesDriverParam.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesDriverParam.java @@ -28,6 +28,7 @@ import java.io.File; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.cluster.config.ClusterConfig; import org.apache.geaflow.cluster.k8s.entrypoint.KubernetesDriverRunner; import org.apache.geaflow.cluster.k8s.utils.KubernetesUtils; @@ -36,84 +37,86 @@ public class KubernetesDriverParam extends AbstractKubernetesParam { - public static final String DRIVER_ENV_PREFIX = "kubernetes.driver.env."; - - public static final String DRIVER_USER_ANNOTATIONS = "kubernetes.driver.user.annotations"; - - public static final String DRIVER_NODE_SELECTOR = "kubernetes.driver.node-selector"; - - public KubernetesDriverParam(Configuration config) { - super(config); - } - - public KubernetesDriverParam(ClusterConfig config) { - super(config); - } - - @Override - public Double getContainerCpu() { - return clusterConfig.getDriverVcores(); - } - - @Override - public long getContainerMemoryMB() { - return clusterConfig.getDriverMemoryMB(); - } - - @Override - protected long getContainerDiskGB() { - return clusterConfig.getDriverDiskGB(); - } - - @Override - public String getContainerShellCommand() { - String logFileName = getLogDir() + File.separator + DRIVER_LOG_SUFFIX; - return ClusterUtils.getStartCommand(clusterConfig.getDriverJvmOptions(), - KubernetesDriverRunner.class, logFileName, config, JOB_CLASSPATH); - } - - @Override - public Map getAdditionEnvs() { - return KubernetesUtils - .getVariablesWithPrefix(DRIVER_ENV_PREFIX, config.getConfigMap()); - } - - @Override - public String getPodNamePrefix(String clusterId) { - return clusterId + K8SConstants.DRIVER_NAME_SUFFIX + K8SConstants.NAME_SEPARATOR; - } - - @Override - public String getConfigMapName(String clusterId) { - return clusterId + K8SConstants.WORKER_CONFIG_MAP_SUFFIX; - } - - @Override - public int getRpcPort() { - return config.getInteger(DRIVER_RPC_PORT); - } - - public int getNodePort() { - return config.getInteger(DRIVER_NODE_PORT); - } - - @Override - public Map getPodLabels(String clusterId) { - Map driverPodLabels = new HashMap<>(); - driverPodLabels.put(K8SConstants.LABEL_APP_KEY, clusterId); - driverPodLabels.put(K8SConstants.LABEL_COMPONENT_KEY, K8SConstants.LABEL_COMPONENT_DRIVER); - driverPodLabels.putAll(KubernetesUtils.getPairsConf(config, POD_USER_LABELS)); - return driverPodLabels; - } - - @Override - public Map getNodeSelector() { - return KubernetesUtils.getPairsConf(config, DRIVER_NODE_SELECTOR); - } - - @Override - public Map getAnnotations() { - return KubernetesUtils.getPairsConf(config, DRIVER_USER_ANNOTATIONS); - } - + public static final String DRIVER_ENV_PREFIX = "kubernetes.driver.env."; + + public static final String DRIVER_USER_ANNOTATIONS = "kubernetes.driver.user.annotations"; + + public static final String DRIVER_NODE_SELECTOR = "kubernetes.driver.node-selector"; + + public KubernetesDriverParam(Configuration config) { + super(config); + } + + public KubernetesDriverParam(ClusterConfig config) { + super(config); + } + + @Override + public Double getContainerCpu() { + return clusterConfig.getDriverVcores(); + } + + @Override + public long getContainerMemoryMB() { + return clusterConfig.getDriverMemoryMB(); + } + + @Override + protected long getContainerDiskGB() { + return clusterConfig.getDriverDiskGB(); + } + + @Override + public String getContainerShellCommand() { + String logFileName = getLogDir() + File.separator + DRIVER_LOG_SUFFIX; + return ClusterUtils.getStartCommand( + clusterConfig.getDriverJvmOptions(), + KubernetesDriverRunner.class, + logFileName, + config, + JOB_CLASSPATH); + } + + @Override + public Map getAdditionEnvs() { + return KubernetesUtils.getVariablesWithPrefix(DRIVER_ENV_PREFIX, config.getConfigMap()); + } + + @Override + public String getPodNamePrefix(String clusterId) { + return clusterId + K8SConstants.DRIVER_NAME_SUFFIX + K8SConstants.NAME_SEPARATOR; + } + + @Override + public String getConfigMapName(String clusterId) { + return clusterId + K8SConstants.WORKER_CONFIG_MAP_SUFFIX; + } + + @Override + public int getRpcPort() { + return config.getInteger(DRIVER_RPC_PORT); + } + + public int getNodePort() { + return config.getInteger(DRIVER_NODE_PORT); + } + + @Override + public Map getPodLabels(String clusterId) { + Map driverPodLabels = new HashMap<>(); + driverPodLabels.put(K8SConstants.LABEL_APP_KEY, clusterId); + driverPodLabels.put(K8SConstants.LABEL_COMPONENT_KEY, K8SConstants.LABEL_COMPONENT_DRIVER); + driverPodLabels.putAll(KubernetesUtils.getPairsConf(config, POD_USER_LABELS)); + return driverPodLabels; + } + + @Override + public Map getNodeSelector() { + return KubernetesUtils.getPairsConf(config, DRIVER_NODE_SELECTOR); + } + + @Override + public Map getAnnotations() { + return KubernetesUtils.getPairsConf(config, DRIVER_USER_ANNOTATIONS); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesMasterParam.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesMasterParam.java index 5a29d4317..9d7592892 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesMasterParam.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesMasterParam.java @@ -27,6 +27,7 @@ import java.io.File; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.cluster.config.ClusterConfig; import org.apache.geaflow.cluster.k8s.entrypoint.KubernetesMasterRunner; import org.apache.geaflow.cluster.k8s.utils.KubernetesUtils; @@ -37,95 +38,97 @@ public class KubernetesMasterParam extends AbstractKubernetesParam { - public static final String MASTER_CONTAINER_NAME = "kubernetes.master.container.name"; - - public static final String MASTER_USER_ANNOTATIONS = "kubernetes.master.user.annotations"; - - public static final String MASTER_NODE_SELECTOR = "kubernetes.master.node-selector"; - - public static final String MASTER_ENV_PREFIX = "kubernetes.master.env."; - - - public KubernetesMasterParam(Configuration config) { - super(config); - } - - public KubernetesMasterParam(ClusterConfig config) { - super(config); - } - - @Override - public String getAutoRestart() { - return AutoRestartPolicy.UNEXPECTED.getValue(); - } - - public String getContainerName() { - return config.getString(MASTER_CONTAINER_NAME, "geaflow-master"); - } - - public Double getContainerCpu() { - return clusterConfig.getMasterVcores(); - } - - @Override - public long getContainerMemoryMB() { - return clusterConfig.getMasterMemoryMB(); - } - - @Override - protected long getContainerDiskGB() { - return clusterConfig.getMasterDiskGB(); - } - - @Override - public String getContainerShellCommand() { - String logFilename = getLogDir() + File.separator + MASTER_LOG_SUFFIX; - return ClusterUtils.getStartCommand(clusterConfig.getMasterJvmOptions(), - KubernetesMasterRunner.class, logFilename, config, JOB_CLASSPATH); - } - - @Override - public String getPodNamePrefix(String clusterId) { - return clusterId + K8SConstants.MASTER_NAME_SUFFIX + K8SConstants.NAME_SEPARATOR; - } - - @Override - public int getHttpPort() { - return config.getInteger(MASTER_HTTP_PORT); - } - - @Override - public Map getAdditionEnvs() { - return KubernetesUtils - .getVariablesWithPrefix(MASTER_ENV_PREFIX, config.getConfigMap()); - } - - @Override - public String getConfigMapName(String clusterId) { - return clusterId + K8SConstants.MASTER_CONFIG_MAP_SUFFIX; - } - - @Override - public Map getPodLabels(String clusterId) { - Map labels = new HashMap<>(); - labels.put(K8SConstants.LABEL_APP_KEY, clusterId); - labels.put(K8SConstants.LABEL_COMPONENT_KEY, K8SConstants.LABEL_COMPONENT_MASTER); - labels.putAll(KubernetesUtils.getPairsConf(config, POD_USER_LABELS)); - return labels; - } - - @Override - public Map getAnnotations() { - return KubernetesUtils.getPairsConf(config, MASTER_USER_ANNOTATIONS); - } - - @Override - public Map getNodeSelector() { - return KubernetesUtils.getPairsConf(config, MASTER_NODE_SELECTOR); - } - - @Override - public boolean enableLeaderElection() { - return config.getBoolean(ExecutionConfigKeys.ENABLE_MASTER_LEADER_ELECTION); - } + public static final String MASTER_CONTAINER_NAME = "kubernetes.master.container.name"; + + public static final String MASTER_USER_ANNOTATIONS = "kubernetes.master.user.annotations"; + + public static final String MASTER_NODE_SELECTOR = "kubernetes.master.node-selector"; + + public static final String MASTER_ENV_PREFIX = "kubernetes.master.env."; + + public KubernetesMasterParam(Configuration config) { + super(config); + } + + public KubernetesMasterParam(ClusterConfig config) { + super(config); + } + + @Override + public String getAutoRestart() { + return AutoRestartPolicy.UNEXPECTED.getValue(); + } + + public String getContainerName() { + return config.getString(MASTER_CONTAINER_NAME, "geaflow-master"); + } + + public Double getContainerCpu() { + return clusterConfig.getMasterVcores(); + } + + @Override + public long getContainerMemoryMB() { + return clusterConfig.getMasterMemoryMB(); + } + + @Override + protected long getContainerDiskGB() { + return clusterConfig.getMasterDiskGB(); + } + + @Override + public String getContainerShellCommand() { + String logFilename = getLogDir() + File.separator + MASTER_LOG_SUFFIX; + return ClusterUtils.getStartCommand( + clusterConfig.getMasterJvmOptions(), + KubernetesMasterRunner.class, + logFilename, + config, + JOB_CLASSPATH); + } + + @Override + public String getPodNamePrefix(String clusterId) { + return clusterId + K8SConstants.MASTER_NAME_SUFFIX + K8SConstants.NAME_SEPARATOR; + } + + @Override + public int getHttpPort() { + return config.getInteger(MASTER_HTTP_PORT); + } + + @Override + public Map getAdditionEnvs() { + return KubernetesUtils.getVariablesWithPrefix(MASTER_ENV_PREFIX, config.getConfigMap()); + } + + @Override + public String getConfigMapName(String clusterId) { + return clusterId + K8SConstants.MASTER_CONFIG_MAP_SUFFIX; + } + + @Override + public Map getPodLabels(String clusterId) { + Map labels = new HashMap<>(); + labels.put(K8SConstants.LABEL_APP_KEY, clusterId); + labels.put(K8SConstants.LABEL_COMPONENT_KEY, K8SConstants.LABEL_COMPONENT_MASTER); + labels.putAll(KubernetesUtils.getPairsConf(config, POD_USER_LABELS)); + return labels; + } + + @Override + public Map getAnnotations() { + return KubernetesUtils.getPairsConf(config, MASTER_USER_ANNOTATIONS); + } + + @Override + public Map getNodeSelector() { + return KubernetesUtils.getPairsConf(config, MASTER_NODE_SELECTOR); + } + + @Override + public boolean enableLeaderElection() { + return config.getBoolean(ExecutionConfigKeys.ENABLE_MASTER_LEADER_ELECTION); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesParam.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesParam.java index 3c1e63c59..0d090bbd0 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesParam.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/config/KubernetesParam.java @@ -19,136 +19,91 @@ package org.apache.geaflow.cluster.k8s.config; -import io.fabric8.kubernetes.api.model.Quantity; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; +import io.fabric8.kubernetes.api.model.Quantity; + /** - * A collection of Kubernetes parameters for pod creating. - * This interface is an adaptation of Flink's org.apache.flink.kubernetes.kubeclient.parameters.KubernetesParameters. + * A collection of Kubernetes parameters for pod creating. This interface is an adaptation of + * Flink's org.apache.flink.kubernetes.kubeclient.parameters.KubernetesParameters. */ public interface KubernetesParam { - /** - * Get service account. - */ - String getServiceAccount(); - - /** - * Get all key-value pair labels for service. - */ - Map getServiceLabels(); - - /** - * Get all key-value pair annotations for service. - */ - Map getServiceAnnotations(); - - /** - * Get all key-value pair labels for pod. - * - * @param clusterId Current k8s cluster id. - */ - Map getPodLabels(String clusterId); - - /** - * Get all node selectors. - */ - Map getNodeSelector(); - - /** - * Get all annotations shared for all pods. - */ - Map getAnnotations(); - - /** - * Get container image name. - */ - String getContainerImage(); - - /** - * Get container image pull policy. - */ - String getContainerImagePullPolicy(); - - /** - * Get the shell command of start container process. - */ - String getContainerShellCommand(); - - /** - * Get pod name prefix of current cluster. - */ - String getPodNamePrefix(String clusterId); - - /** - * Get the name of configuration for current cluster. - */ - String getConfigMapName(String clusterId); - - /** - * Get the value of cpu request for pod. - */ - Quantity getCpuQuantity(); - - /** - * Get the value of memory request for pod. - */ - Quantity getMemoryQuantity(); - - /** - * Get the value of disk request for pod. - */ - Quantity getDiskQuantity(); - - /** - * Get current exposed rpc port. - */ - int getRpcPort(); - - /** - * Get current exposed http port. - */ - int getHttpPort(); - - /** - * Get current exposed node port. - */ - int getNodePort(); - - /** - * Get env config directory. - */ - String getConfDir(); - - /** - * Get log directory. - */ - String getLogDir(); - - /** - * Get the flag that process should auto start after crashed. - */ - String getAutoRestart(); - - /** - * Get the flag whether allow injecting error or exception. - */ - Boolean getClusterFaultInjectionEnable(); - - /** - * Get origin user configuration. - */ - Configuration getConfig(); - - /** - * Get addition env from client. - */ - Map getAdditionEnvs(); - - /** - * Get whether leader-election is enabled. - */ - boolean enableLeaderElection(); + /** Get service account. */ + String getServiceAccount(); + + /** Get all key-value pair labels for service. */ + Map getServiceLabels(); + + /** Get all key-value pair annotations for service. */ + Map getServiceAnnotations(); + + /** + * Get all key-value pair labels for pod. + * + * @param clusterId Current k8s cluster id. + */ + Map getPodLabels(String clusterId); + + /** Get all node selectors. */ + Map getNodeSelector(); + + /** Get all annotations shared for all pods. */ + Map getAnnotations(); + + /** Get container image name. */ + String getContainerImage(); + + /** Get container image pull policy. */ + String getContainerImagePullPolicy(); + + /** Get the shell command of start container process. */ + String getContainerShellCommand(); + + /** Get pod name prefix of current cluster. */ + String getPodNamePrefix(String clusterId); + + /** Get the name of configuration for current cluster. */ + String getConfigMapName(String clusterId); + + /** Get the value of cpu request for pod. */ + Quantity getCpuQuantity(); + + /** Get the value of memory request for pod. */ + Quantity getMemoryQuantity(); + + /** Get the value of disk request for pod. */ + Quantity getDiskQuantity(); + + /** Get current exposed rpc port. */ + int getRpcPort(); + + /** Get current exposed http port. */ + int getHttpPort(); + + /** Get current exposed node port. */ + int getNodePort(); + + /** Get env config directory. */ + String getConfDir(); + + /** Get log directory. */ + String getLogDir(); + + /** Get the flag that process should auto start after crashed. */ + String getAutoRestart(); + + /** Get the flag whether allow injecting error or exception. */ + Boolean getClusterFaultInjectionEnable(); + + /** Get origin user configuration. */ + Configuration getConfig(); + + /** Get addition env from client. */ + Map getAdditionEnvs(); + /** Get whether leader-election is enabled. */ + boolean enableLeaderElection(); } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesClientRunner.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesClientRunner.java index 713c5cd65..942bf1b80 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesClientRunner.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesClientRunner.java @@ -27,6 +27,7 @@ import java.io.IOException; import java.lang.reflect.Method; import java.util.Map; + import org.apache.commons.lang3.StringEscapeUtils; import org.apache.geaflow.cluster.client.callback.ClusterCallbackFactory; import org.apache.geaflow.cluster.client.callback.ClusterStartedCallback; @@ -45,73 +46,72 @@ public class KubernetesClientRunner { - private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesClientRunner.class); - private final Configuration config; - - public KubernetesClientRunner(Configuration config) { - this.config = config; - } + private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesClientRunner.class); + private final Configuration config; - public void run(String classArgs) { - String userClass = null; - ClusterStartedCallback callback = ClusterCallbackFactory.createClusterStartCallback(config); - try { - System.setProperty(CLUSTER_TYPE, EnvType.K8S.name()); + public KubernetesClientRunner(Configuration config) { + this.config = config; + } - userClass = config.getString(USER_MAIN_CLASS); - LOGGER.info("execute mainClass {} to k8s, args: {}", userClass, classArgs); + public void run(String classArgs) { + String userClass = null; + ClusterStartedCallback callback = ClusterCallbackFactory.createClusterStartCallback(config); + try { + System.setProperty(CLUSTER_TYPE, EnvType.K8S.name()); - EnvironmentArgumentParser parser = new EnvironmentArgumentParser(); - Map newConfig = parser.parse(new String[]{classArgs}); - config.putAll(newConfig); + userClass = config.getString(USER_MAIN_CLASS); + LOGGER.info("execute mainClass {} to k8s, args: {}", userClass, classArgs); - callback = ClusterCallbackFactory.createClusterStartCallback(config); - LOGGER.info("client callback: {}", callback.getClass().getCanonicalName()); + EnvironmentArgumentParser parser = new EnvironmentArgumentParser(); + Map newConfig = parser.parse(new String[] {classArgs}); + config.putAll(newConfig); - Class mainClazz = Thread.currentThread().getContextClassLoader().loadClass(userClass); - Method mainMethod = mainClazz.getMethod("main", String[].class); - mainMethod.invoke(mainClazz, (Object) new String[]{classArgs}); - } catch (Throwable e) { - LOGGER.error("execute mainClass {} failed", userClass, e); - callback.onFailure(e); - throw new GeaflowRuntimeException(e); - } finally { - cleanAndExit(); - } - } + callback = ClusterCallbackFactory.createClusterStartCallback(config); + LOGGER.info("client callback: {}", callback.getClass().getCanonicalName()); - public static void main(String[] args) throws IOException { - try { - final long startTime = System.currentTimeMillis(); - Configuration config = KubernetesUtils.loadConfigurationFromFile(); - final String classArgs = StringEscapeUtils.escapeJava(config.getString(USER_CLASS_ARGS)); - KubernetesClientRunner clientRunner = new KubernetesClientRunner(config); - clientRunner.run(classArgs); - LOGGER.info("Completed client init in {} ms", System.currentTimeMillis() - startTime); - } catch (Throwable e) { - LOGGER.error("init client runner failed: {}", e.getMessage(), e); - throw e; - } + Class mainClazz = Thread.currentThread().getContextClassLoader().loadClass(userClass); + Method mainMethod = mainClazz.getMethod("main", String[].class); + mainMethod.invoke(mainClazz, (Object) new String[] {classArgs}); + } catch (Throwable e) { + LOGGER.error("execute mainClass {} failed", userClass, e); + callback.onFailure(e); + throw new GeaflowRuntimeException(e); + } finally { + cleanAndExit(); } + } - private void cleanAndExit() { - try { - int waitTime = config.getInteger(CLIENT_EXIT_WAIT_SECONDS); - LOGGER.info("Sleep {} seconds before client exits...", waitTime); - SleepUtils.sleepSecond(waitTime); - deleteClientConfigMap(); - } catch (Throwable e) { - LOGGER.error("delete client config map failed: {}", e.getMessage(), e); - } + public static void main(String[] args) throws IOException { + try { + final long startTime = System.currentTimeMillis(); + Configuration config = KubernetesUtils.loadConfigurationFromFile(); + final String classArgs = StringEscapeUtils.escapeJava(config.getString(USER_CLASS_ARGS)); + KubernetesClientRunner clientRunner = new KubernetesClientRunner(config); + clientRunner.run(classArgs); + LOGGER.info("Completed client init in {} ms", System.currentTimeMillis() - startTime); + } catch (Throwable e) { + LOGGER.error("init client runner failed: {}", e.getMessage(), e); + throw e; } + } - private void deleteClientConfigMap() { - String clusterId = config.getString(ExecutionConfigKeys.CLUSTER_ID); - String masterUrl = KubernetesConfig.getClientMasterUrl(config); - KubernetesClientParam clientParam = new KubernetesClientParam(config); - String clientConfigMap = clientParam.getConfigMapName(clusterId); - GeaflowKubeClient kubernetesClient = new GeaflowKubeClient(config, masterUrl); - kubernetesClient.deleteConfigMap(clientConfigMap); + private void cleanAndExit() { + try { + int waitTime = config.getInteger(CLIENT_EXIT_WAIT_SECONDS); + LOGGER.info("Sleep {} seconds before client exits...", waitTime); + SleepUtils.sleepSecond(waitTime); + deleteClientConfigMap(); + } catch (Throwable e) { + LOGGER.error("delete client config map failed: {}", e.getMessage(), e); } + } + private void deleteClientConfigMap() { + String clusterId = config.getString(ExecutionConfigKeys.CLUSTER_ID); + String masterUrl = KubernetesConfig.getClientMasterUrl(config); + KubernetesClientParam clientParam = new KubernetesClientParam(config); + String clientConfigMap = clientParam.getConfigMapName(clusterId); + GeaflowKubeClient kubernetesClient = new GeaflowKubeClient(config, masterUrl); + kubernetesClient.deleteConfigMap(clientConfigMap); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesContainerRunner.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesContainerRunner.java index f7658658f..a3339dee5 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesContainerRunner.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesContainerRunner.java @@ -22,6 +22,7 @@ import static org.apache.geaflow.cluster.constants.ClusterConstants.EXIT_CODE; import java.util.Map; + import org.apache.geaflow.cluster.container.Container; import org.apache.geaflow.cluster.container.ContainerContext; import org.apache.geaflow.cluster.k8s.config.K8SConstants; @@ -37,59 +38,59 @@ * additional information regarding copyright ownership. */ /** - * This class is an adaptation of Flink's org.apache.flink.kubernetes.taskmanager.KubernetesTaskExecutorRunner. + * This class is an adaptation of Flink's + * org.apache.flink.kubernetes.taskmanager.KubernetesTaskExecutorRunner. */ public class KubernetesContainerRunner { - private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesContainerRunner.class); + private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesContainerRunner.class); - /** - * The process environment variables. - */ - private static final Map ENV = System.getenv(); + /** The process environment variables. */ + private static final Map ENV = System.getenv(); - private final ContainerContext containerContext; - private Container container; + private final ContainerContext containerContext; + private Container container; - public KubernetesContainerRunner(ContainerContext containerContext) { - this.containerContext = containerContext; - } + public KubernetesContainerRunner(ContainerContext containerContext) { + this.containerContext = containerContext; + } - public void run() { - Configuration config = containerContext.getConfig(); - KubernetesContainerParam workerParam = new KubernetesContainerParam(config); - container = new Container(workerParam.getRpcPort()); - containerContext.load(); - container.init(containerContext); - } + public void run() { + Configuration config = containerContext.getConfig(); + KubernetesContainerParam workerParam = new KubernetesContainerParam(config); + container = new Container(workerParam.getRpcPort()); + containerContext.load(); + container.init(containerContext); + } - private void waitForTermination() { - LOGGER.info("wait for service terminating"); - container.waitTermination(); - } + private void waitForTermination() { + LOGGER.info("wait for service terminating"); + container.waitTermination(); + } - public static void main(String[] args) throws Exception { - try { - final long startTime = System.currentTimeMillis(); - String id = ClusterUtils.getEnvValue(ENV, K8SConstants.ENV_CONTAINER_ID); - String masterId = ClusterUtils.getEnvValue(ENV, K8SConstants.ENV_MASTER_ID); - boolean isRecover = Boolean.parseBoolean(ClusterUtils.getEnvValue(ENV, - K8SConstants.ENV_IS_RECOVER)); - LOGGER.info("ResourceID assigned for this container:{} masterId:{}, isRecover:{}", id, - masterId, isRecover); + public static void main(String[] args) throws Exception { + try { + final long startTime = System.currentTimeMillis(); + String id = ClusterUtils.getEnvValue(ENV, K8SConstants.ENV_CONTAINER_ID); + String masterId = ClusterUtils.getEnvValue(ENV, K8SConstants.ENV_MASTER_ID); + boolean isRecover = + Boolean.parseBoolean(ClusterUtils.getEnvValue(ENV, K8SConstants.ENV_IS_RECOVER)); + LOGGER.info( + "ResourceID assigned for this container:{} masterId:{}, isRecover:{}", + id, + masterId, + isRecover); - Configuration config = KubernetesUtils.loadConfiguration(); - config.setMasterId(masterId); - ContainerContext context = new ContainerContext(Integer.parseInt(id), config, isRecover); - KubernetesContainerRunner kubernetesContainerRunner = new KubernetesContainerRunner( - context); - kubernetesContainerRunner.run(); - LOGGER.info("Completed container init in {}ms", System.currentTimeMillis() - startTime); - kubernetesContainerRunner.waitForTermination(); - } catch (Throwable e) { - LOGGER.error("FATAL: process exits", e); - System.exit(EXIT_CODE); - } + Configuration config = KubernetesUtils.loadConfiguration(); + config.setMasterId(masterId); + ContainerContext context = new ContainerContext(Integer.parseInt(id), config, isRecover); + KubernetesContainerRunner kubernetesContainerRunner = new KubernetesContainerRunner(context); + kubernetesContainerRunner.run(); + LOGGER.info("Completed container init in {}ms", System.currentTimeMillis() - startTime); + kubernetesContainerRunner.waitForTermination(); + } catch (Throwable e) { + LOGGER.error("FATAL: process exits", e); + System.exit(EXIT_CODE); } - + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesDriverRunner.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesDriverRunner.java index 40861d6e5..18922df8e 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesDriverRunner.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesDriverRunner.java @@ -22,6 +22,7 @@ import static org.apache.geaflow.cluster.constants.ClusterConstants.EXIT_CODE; import java.util.Map; + import org.apache.geaflow.cluster.driver.Driver; import org.apache.geaflow.cluster.driver.DriverContext; import org.apache.geaflow.cluster.k8s.config.K8SConstants; @@ -37,61 +38,59 @@ * additional information regarding copyright ownership. */ /** - * This class is an adaptation of Flink's org.apache.flink.kubernetes.taskmanager.KubernetesTaskExecutorRunner. + * This class is an adaptation of Flink's + * org.apache.flink.kubernetes.taskmanager.KubernetesTaskExecutorRunner. */ public class KubernetesDriverRunner { - private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesDriverRunner.class); + private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesDriverRunner.class); - /** - * The process environment variables. - */ - private static final Map ENV = System.getenv(); + /** The process environment variables. */ + private static final Map ENV = System.getenv(); - private final int driverId; - private final int driverIndex; - private final Configuration config; - private Driver driver; + private final int driverId; + private final int driverIndex; + private final Configuration config; + private Driver driver; - public KubernetesDriverRunner(int driverId, int driverIndex, Configuration config) { - this.driverId = driverId; - this.driverIndex = driverIndex; - this.config = config; - } + public KubernetesDriverRunner(int driverId, int driverIndex, Configuration config) { + this.driverId = driverId; + this.driverIndex = driverIndex; + this.config = config; + } - public void run() { - KubernetesDriverParam driverParam = new KubernetesDriverParam(config); - DriverContext driverContext = new DriverContext(driverId, driverIndex, config); - driver = new Driver(driverParam.getRpcPort()); - driverContext.load(); - driver.init(driverContext); - } + public void run() { + KubernetesDriverParam driverParam = new KubernetesDriverParam(config); + DriverContext driverContext = new DriverContext(driverId, driverIndex, config); + driver = new Driver(driverParam.getRpcPort()); + driverContext.load(); + driver.init(driverContext); + } - private void waitForTermination() { - LOGGER.info("wait for service terminating"); - driver.waitTermination(); - } + private void waitForTermination() { + LOGGER.info("wait for service terminating"); + driver.waitTermination(); + } - public static void main(String[] args) throws Exception { - try { - final long startTime = System.currentTimeMillis(); - String id = ClusterUtils.getEnvValue(ENV, K8SConstants.ENV_CONTAINER_ID); - String index = ClusterUtils.getEnvValue(ENV, K8SConstants.ENV_CONTAINER_INDEX); - String masterId = ClusterUtils.getEnvValue(ENV, K8SConstants.ENV_MASTER_ID); - LOGGER.info("ResourceID assigned for this driver id:{} index:{} masterId:{}", id, - index, masterId); + public static void main(String[] args) throws Exception { + try { + final long startTime = System.currentTimeMillis(); + String id = ClusterUtils.getEnvValue(ENV, K8SConstants.ENV_CONTAINER_ID); + String index = ClusterUtils.getEnvValue(ENV, K8SConstants.ENV_CONTAINER_INDEX); + String masterId = ClusterUtils.getEnvValue(ENV, K8SConstants.ENV_MASTER_ID); + LOGGER.info( + "ResourceID assigned for this driver id:{} index:{} masterId:{}", id, index, masterId); - Configuration config = KubernetesUtils.loadConfiguration(); - config.setMasterId(masterId); - KubernetesDriverRunner kubernetesDriverRunner = - new KubernetesDriverRunner(Integer.parseInt(id), Integer.parseInt(index), config); - kubernetesDriverRunner.run(); - LOGGER.info("Completed driver init in {} ms", System.currentTimeMillis() - startTime); - kubernetesDriverRunner.waitForTermination(); - } catch (Throwable e) { - LOGGER.error("FATAL: process exits", e); - System.exit(EXIT_CODE); - } + Configuration config = KubernetesUtils.loadConfiguration(); + config.setMasterId(masterId); + KubernetesDriverRunner kubernetesDriverRunner = + new KubernetesDriverRunner(Integer.parseInt(id), Integer.parseInt(index), config); + kubernetesDriverRunner.run(); + LOGGER.info("Completed driver init in {} ms", System.currentTimeMillis() - startTime); + kubernetesDriverRunner.waitForTermination(); + } catch (Throwable e) { + LOGGER.error("FATAL: process exits", e); + System.exit(EXIT_CODE); } - + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesMasterRunner.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesMasterRunner.java index 8f78594dd..05dbbcc4f 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesMasterRunner.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesMasterRunner.java @@ -24,9 +24,9 @@ import static org.apache.geaflow.cluster.constants.ClusterConstants.EXIT_CODE; import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.CLUSTER_ID; -import io.fabric8.kubernetes.api.model.ConfigMap; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.cluster.clustermanager.ClusterInfo; import org.apache.geaflow.cluster.k8s.clustermanager.GeaflowKubeClient; import org.apache.geaflow.cluster.k8s.clustermanager.KubernetesClusterManager; @@ -46,6 +46,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import io.fabric8.kubernetes.api.model.ConfigMap; + /* This file is based on source code from the Flink Project (http://flink.apache.org/), licensed by the Apache * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. */ @@ -56,95 +58,95 @@ */ public class KubernetesMasterRunner extends MasterRunner { - private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesMasterRunner.class); - private static final Map ENV = System.getenv(); - - public KubernetesMasterRunner(Configuration config) { - super(config, new KubernetesClusterManager()); + private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesMasterRunner.class); + private static final Map ENV = System.getenv(); + + public KubernetesMasterRunner(Configuration config) { + super(config, new KubernetesClusterManager()); + } + + @Override + public ClusterInfo init() { + startClusterWatcher(); + + ClusterInfo clusterInfo = super.init(); + updateConfigMap(clusterInfo); + return clusterInfo; + } + + protected void startClusterWatcher() { + KubernetesPodWatcher watcher = new KubernetesPodWatcher(config); + watcher.start(); + } + + @Override + protected void initLeaderElectionService() { + master.initLeaderElectionService( + new KubernetesMasterLeaderContender(), config, DEFAULT_MASTER_ID); + try { + master.waitForLeaderElection(); + } catch (InterruptedException e) { + throw new GeaflowRuntimeException(e); } + } - @Override - public ClusterInfo init() { - startClusterWatcher(); - - ClusterInfo clusterInfo = super.init(); - updateConfigMap(clusterInfo); - return clusterInfo; - } - - protected void startClusterWatcher() { - KubernetesPodWatcher watcher = new KubernetesPodWatcher(config); - watcher.start(); - } + private class KubernetesMasterLeaderContender implements ILeaderContender { @Override - protected void initLeaderElectionService() { - master.initLeaderElectionService(new KubernetesMasterLeaderContender(), config, - DEFAULT_MASTER_ID); - try { - master.waitForLeaderElection(); - } catch (InterruptedException e) { - throw new GeaflowRuntimeException(e); - } + public void handleLeadershipGranted() { + LOGGER.info("Leadership granted, init master now."); + master.notifyLeaderElection(); } - private class KubernetesMasterLeaderContender implements ILeaderContender { - - @Override - public void handleLeadershipGranted() { - LOGGER.info("Leadership granted, init master now."); - master.notifyLeaderElection(); - } - - @Override - public void handleLeadershipLost() { - LOGGER.info("Leadership lost, exit the process now."); - System.exit(EXIT_CODE); - } - - @Override - public LeaderContenderType getType() { - return LeaderContenderType.master; - } + @Override + public void handleLeadershipLost() { + LOGGER.info("Leadership lost, exit the process now."); + System.exit(EXIT_CODE); } - private void updateConfigMap(ClusterInfo clusterInfo) { - Map updatedConfig = new HashMap<>(); - ConnectAddress masterAddress = clusterInfo.getMasterAddress(); - updatedConfig.put(KubernetesConfig.MASTER_EXPOSED_ADDRESS, masterAddress.toString()); - - Map driverAddresses = clusterInfo.getDriverAddresses(); - updatedConfig.put(KubernetesConfig.DRIVER_EXPOSED_ADDRESS, - KubernetesUtils.encodeRpcAddressMap(driverAddresses)); - - GeaflowKubeClient client = new GeaflowKubeClient(config); - String clusterId = config.getString(CLUSTER_ID); - - KubernetesMasterParam masterParam = new KubernetesMasterParam(config); - ConfigMap configMap = client.getConfigMap(masterParam.getConfigMapName(clusterId)); - ConfigMap updatedConfigMap = KubernetesResourceBuilder.updateConfigMap(configMap, - updatedConfig); - client.createOrReplaceConfigMap(updatedConfigMap); - LOGGER.info("updated master configmap: {}", config); + @Override + public LeaderContenderType getType() { + return LeaderContenderType.master; } - - public static void main(String[] args) throws Exception { - try { - final long startTime = System.currentTimeMillis(); - Configuration config = KubernetesUtils.loadConfiguration(); - String masterId = ClusterUtils.getEnvValue(ENV, K8SConstants.ENV_MASTER_ID); - config.setMasterId(masterId); - String profilerPath = ClusterUtils.getEnvValue(ENV, K8SConstants.ENV_PROFILER_PATH); - config.put(AGENT_PROFILER_PATH, profilerPath); - - KubernetesMasterRunner masterRunner = new KubernetesMasterRunner(config); - masterRunner.init(); - LOGGER.info("Completed master init in {} ms", System.currentTimeMillis() - startTime); - masterRunner.waitForTermination(); - } catch (Throwable e) { - LOGGER.error("FATAL: process exits", e); - System.exit(EXIT_CODE); - } + } + + private void updateConfigMap(ClusterInfo clusterInfo) { + Map updatedConfig = new HashMap<>(); + ConnectAddress masterAddress = clusterInfo.getMasterAddress(); + updatedConfig.put(KubernetesConfig.MASTER_EXPOSED_ADDRESS, masterAddress.toString()); + + Map driverAddresses = clusterInfo.getDriverAddresses(); + updatedConfig.put( + KubernetesConfig.DRIVER_EXPOSED_ADDRESS, + KubernetesUtils.encodeRpcAddressMap(driverAddresses)); + + GeaflowKubeClient client = new GeaflowKubeClient(config); + String clusterId = config.getString(CLUSTER_ID); + + KubernetesMasterParam masterParam = new KubernetesMasterParam(config); + ConfigMap configMap = client.getConfigMap(masterParam.getConfigMapName(clusterId)); + ConfigMap updatedConfigMap = + KubernetesResourceBuilder.updateConfigMap(configMap, updatedConfig); + client.createOrReplaceConfigMap(updatedConfigMap); + LOGGER.info("updated master configmap: {}", config); + } + + public static void main(String[] args) throws Exception { + try { + final long startTime = System.currentTimeMillis(); + Configuration config = KubernetesUtils.loadConfiguration(); + String masterId = ClusterUtils.getEnvValue(ENV, K8SConstants.ENV_MASTER_ID); + config.setMasterId(masterId); + String profilerPath = ClusterUtils.getEnvValue(ENV, K8SConstants.ENV_PROFILER_PATH); + config.put(AGENT_PROFILER_PATH, profilerPath); + + KubernetesMasterRunner masterRunner = new KubernetesMasterRunner(config); + masterRunner.init(); + LOGGER.info("Completed master init in {} ms", System.currentTimeMillis() - startTime); + masterRunner.waitForTermination(); + } catch (Throwable e) { + LOGGER.error("FATAL: process exits", e); + System.exit(EXIT_CODE); } - + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesSupervisorRunner.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesSupervisorRunner.java index 2932d00f7..2c995ce18 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesSupervisorRunner.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesSupervisorRunner.java @@ -22,6 +22,7 @@ import static org.apache.geaflow.cluster.constants.ClusterConstants.CONTAINER_START_COMMAND; import java.util.Map; + import org.apache.geaflow.cluster.k8s.config.K8SConstants; import org.apache.geaflow.cluster.k8s.utils.KubernetesUtils; import org.apache.geaflow.cluster.runner.Supervisor; @@ -33,43 +34,45 @@ public class KubernetesSupervisorRunner { - private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesSupervisorRunner.class); - private static final Map ENV = System.getenv(); - private final Supervisor supervisor; + private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesSupervisorRunner.class); + private static final Map ENV = System.getenv(); + private final Supervisor supervisor; - public KubernetesSupervisorRunner(Configuration configuration, String startCommand, - boolean autoRestart) { - this.supervisor = new Supervisor(startCommand, configuration, autoRestart); - } + public KubernetesSupervisorRunner( + Configuration configuration, String startCommand, boolean autoRestart) { + this.supervisor = new Supervisor(startCommand, configuration, autoRestart); + } - public void run() { - supervisor.start(); - } + public void run() { + supervisor.start(); + } - private void waitForTermination() { - LOGGER.info("Waiting for supervisor exit."); - supervisor.waitForTermination(); - } + private void waitForTermination() { + LOGGER.info("Waiting for supervisor exit."); + supervisor.waitForTermination(); + } - public static void main(String[] args) throws Exception { - try { - String id = ClusterUtils.getEnvValue(ENV, K8SConstants.ENV_CONTAINER_ID); - String autoRestartEnv = ClusterUtils.getEnvValue(ENV, K8SConstants.ENV_AUTO_RESTART); - LOGGER.info("Start supervisor with ID: {} pid: {} autoStart:{}", id, - ProcessUtil.getProcessId(), autoRestartEnv); + public static void main(String[] args) throws Exception { + try { + String id = ClusterUtils.getEnvValue(ENV, K8SConstants.ENV_CONTAINER_ID); + String autoRestartEnv = ClusterUtils.getEnvValue(ENV, K8SConstants.ENV_AUTO_RESTART); + LOGGER.info( + "Start supervisor with ID: {} pid: {} autoStart:{}", + id, + ProcessUtil.getProcessId(), + autoRestartEnv); - Configuration config = KubernetesUtils.loadConfiguration(); - String startCommand = ClusterUtils.getEnvValue(ENV, CONTAINER_START_COMMAND); - boolean autoRestart = !autoRestartEnv.equalsIgnoreCase(Boolean.FALSE.toString()); - KubernetesSupervisorRunner workerRunner = new KubernetesSupervisorRunner(config, - startCommand, autoRestart); - workerRunner.run(); - workerRunner.waitForTermination(); - LOGGER.info("Exit worker process"); - } catch (Throwable e) { - LOGGER.error("FATAL: process exits", e); - throw e; - } + Configuration config = KubernetesUtils.loadConfiguration(); + String startCommand = ClusterUtils.getEnvValue(ENV, CONTAINER_START_COMMAND); + boolean autoRestart = !autoRestartEnv.equalsIgnoreCase(Boolean.FALSE.toString()); + KubernetesSupervisorRunner workerRunner = + new KubernetesSupervisorRunner(config, startCommand, autoRestart); + workerRunner.run(); + workerRunner.waitForTermination(); + LOGGER.info("Exit worker process"); + } catch (Throwable e) { + LOGGER.error("FATAL: process exits", e); + throw e; } - + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/failover/KubernetesClusterFailoverStrategy.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/failover/KubernetesClusterFailoverStrategy.java index 193872222..a8f8b0baa 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/failover/KubernetesClusterFailoverStrategy.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/failover/KubernetesClusterFailoverStrategy.java @@ -24,8 +24,7 @@ public class KubernetesClusterFailoverStrategy extends ClusterFailoverStrategy { - public KubernetesClusterFailoverStrategy() { - super(EnvType.K8S); - } - + public KubernetesClusterFailoverStrategy() { + super(EnvType.K8S); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/failover/KubernetesComponentFailoverStrategy.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/failover/KubernetesComponentFailoverStrategy.java index 1208271bc..2ac52d90a 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/failover/KubernetesComponentFailoverStrategy.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/failover/KubernetesComponentFailoverStrategy.java @@ -31,45 +31,48 @@ import org.slf4j.LoggerFactory; public class KubernetesComponentFailoverStrategy extends ComponentFailoverStrategy { - private static final Logger LOGGER = - LoggerFactory.getLogger(KubernetesComponentFailoverStrategy.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(KubernetesComponentFailoverStrategy.class); - protected ClusterContext clusterContext; + protected ClusterContext clusterContext; - public KubernetesComponentFailoverStrategy() { - super(EnvType.K8S); - } - - public void init(ClusterContext clusterContext) { - super.init(clusterContext); - this.clusterContext = clusterContext; - } + public KubernetesComponentFailoverStrategy() { + super(EnvType.K8S); + } - @Override - public void doFailover(int componentId, Throwable cause) { - if (componentId != DEFAULT_MASTER_ID) { - if (cause instanceof GeaflowHeartbeatException) { - String startMessage = String.format("Start component failover for component #%s " - + "cause by %s.", componentId, cause.getMessage()); - LOGGER.info(startMessage); - reportFailoverEvent(ExceptionLevel.ERROR, EventLabel.FAILOVER_START, startMessage); + public void init(ClusterContext clusterContext) { + super.init(clusterContext); + this.clusterContext = clusterContext; + } - long startTime = System.currentTimeMillis(); - if (clusterContext.getDriverIds().containsKey(componentId)) { - clusterManager.restartDriver(componentId); - } else { - clusterManager.restartContainer(componentId); - } + @Override + public void doFailover(int componentId, Throwable cause) { + if (componentId != DEFAULT_MASTER_ID) { + if (cause instanceof GeaflowHeartbeatException) { + String startMessage = + String.format( + "Start component failover for component #%s " + "cause by %s.", + componentId, cause.getMessage()); + LOGGER.info(startMessage); + reportFailoverEvent(ExceptionLevel.ERROR, EventLabel.FAILOVER_START, startMessage); - String finishMessage = String.format("Completed component failover for component " - + "#%s in %s ms.", componentId, System.currentTimeMillis() - startTime); - LOGGER.info(finishMessage); - reportFailoverEvent(ExceptionLevel.INFO, EventLabel.FAILOVER_FINISH, finishMessage); - } else { - String reason = cause == null ? null : cause.getMessage(); - LOGGER.warn("{} throws exception: {}", componentId, reason); - } + long startTime = System.currentTimeMillis(); + if (clusterContext.getDriverIds().containsKey(componentId)) { + clusterManager.restartDriver(componentId); + } else { + clusterManager.restartContainer(componentId); } - } + String finishMessage = + String.format( + "Completed component failover for component " + "#%s in %s ms.", + componentId, System.currentTimeMillis() - startTime); + LOGGER.info(finishMessage); + reportFailoverEvent(ExceptionLevel.INFO, EventLabel.FAILOVER_FINISH, finishMessage); + } else { + String reason = cause == null ? null : cause.getMessage(); + LOGGER.warn("{} throws exception: {}", componentId, reason); + } + } + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/failover/KubernetesDisableFailoverStrategy.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/failover/KubernetesDisableFailoverStrategy.java index 5bdf99144..482ed91b0 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/failover/KubernetesDisableFailoverStrategy.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/failover/KubernetesDisableFailoverStrategy.java @@ -24,8 +24,7 @@ public class KubernetesDisableFailoverStrategy extends DisableFailoverStrategy { - public KubernetesDisableFailoverStrategy() { - super(EnvType.K8S); - } - + public KubernetesDisableFailoverStrategy() { + super(EnvType.K8S); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/AbstractPodHandler.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/AbstractPodHandler.java index a0be9e5dd..16eade0a1 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/AbstractPodHandler.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/AbstractPodHandler.java @@ -21,35 +21,36 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.stats.collector.StatsCollectorFactory; import org.apache.geaflow.stats.model.ExceptionLevel; public abstract class AbstractPodHandler implements IPodEventHandler { - protected List listeners; - - public AbstractPodHandler() { - this.listeners = new ArrayList<>(); - } - - public void addListener(IEventListener listener) { - listeners.add(listener); - } + protected List listeners; - public void notifyListeners(PodEvent event) { - for (IEventListener listener : listeners) { - listener.onEvent(event); - } - } + public AbstractPodHandler() { + this.listeners = new ArrayList<>(); + } - protected void reportPodEvent(PodEvent event, ExceptionLevel level, String message) { - String eventMessage = buildEventMessage(event, message); - StatsCollectorFactory.getInstance().getEventCollector() - .reportEvent(level, event.getEventKind().name(), eventMessage); - } + public void addListener(IEventListener listener) { + listeners.add(listener); + } - private String buildEventMessage(PodEvent event, String message) { - return message + "\n" + event.toString(); + public void notifyListeners(PodEvent event) { + for (IEventListener listener : listeners) { + listener.onEvent(event); } - + } + + protected void reportPodEvent(PodEvent event, ExceptionLevel level, String message) { + String eventMessage = buildEventMessage(event, message); + StatsCollectorFactory.getInstance() + .getEventCollector() + .reportEvent(level, event.getEventKind().name(), eventMessage); + } + + private String buildEventMessage(PodEvent event, String message) { + return message + "\n" + event.toString(); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/IEventListener.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/IEventListener.java index 82f8091a5..a996ddae6 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/IEventListener.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/IEventListener.java @@ -21,6 +21,5 @@ public interface IEventListener { - void onEvent(PodEvent event); - + void onEvent(PodEvent event); } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/IPodEventHandler.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/IPodEventHandler.java index d56f5119e..93184a315 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/IPodEventHandler.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/IPodEventHandler.java @@ -23,6 +23,5 @@ public interface IPodEventHandler { - void handle(Pod pod); - + void handle(Pod pod); } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodAddedHandler.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodAddedHandler.java index 81ad08880..bd2b04f25 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodAddedHandler.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodAddedHandler.java @@ -19,21 +19,22 @@ package org.apache.geaflow.cluster.k8s.handler; -import io.fabric8.kubernetes.api.model.Pod; import org.apache.geaflow.cluster.k8s.handler.PodHandlerRegistry.EventKind; import org.apache.geaflow.cluster.k8s.utils.KubernetesUtils; import org.apache.geaflow.stats.model.ExceptionLevel; +import io.fabric8.kubernetes.api.model.Pod; + public class PodAddedHandler extends AbstractPodHandler { - @Override - public void handle(Pod pod) { - String componentId = KubernetesUtils.extractComponentId(pod); - if (componentId != null) { - String addMessage = String.format("Pod #%s %s is created.", - componentId, pod.getMetadata().getName()); - PodEvent event = new PodEvent(pod, EventKind.POD_ADDED); - reportPodEvent(event, ExceptionLevel.INFO, addMessage); - } + @Override + public void handle(Pod pod) { + String componentId = KubernetesUtils.extractComponentId(pod); + if (componentId != null) { + String addMessage = + String.format("Pod #%s %s is created.", componentId, pod.getMetadata().getName()); + PodEvent event = new PodEvent(pod, EventKind.POD_ADDED); + reportPodEvent(event, ExceptionLevel.INFO, addMessage); } + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodDeletedHandler.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodDeletedHandler.java index 942d36a55..e14d163c3 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodDeletedHandler.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodDeletedHandler.java @@ -19,21 +19,22 @@ package org.apache.geaflow.cluster.k8s.handler; -import io.fabric8.kubernetes.api.model.Pod; import org.apache.geaflow.cluster.k8s.handler.PodHandlerRegistry.EventKind; import org.apache.geaflow.cluster.k8s.utils.KubernetesUtils; import org.apache.geaflow.stats.model.ExceptionLevel; +import io.fabric8.kubernetes.api.model.Pod; + public class PodDeletedHandler extends AbstractPodHandler { - @Override - public void handle(Pod pod) { - String componentId = KubernetesUtils.extractComponentId(pod); - if (componentId != null) { - String deleteMessage = String.format("Pod #%s %s is deleted.", - componentId, pod.getMetadata().getName()); - PodEvent event = new PodEvent(pod, EventKind.POD_DELETED); - reportPodEvent(event, ExceptionLevel.ERROR, deleteMessage); - } + @Override + public void handle(Pod pod) { + String componentId = KubernetesUtils.extractComponentId(pod); + if (componentId != null) { + String deleteMessage = + String.format("Pod #%s %s is deleted.", componentId, pod.getMetadata().getName()); + PodEvent event = new PodEvent(pod, EventKind.POD_DELETED); + reportPodEvent(event, ExceptionLevel.ERROR, deleteMessage); } + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodEvent.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodEvent.java index 457aa8e59..2787220da 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodEvent.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodEvent.java @@ -19,74 +19,89 @@ package org.apache.geaflow.cluster.k8s.handler; -import io.fabric8.kubernetes.api.model.Pod; import java.io.Serializable; + import org.apache.geaflow.cluster.k8s.handler.PodHandlerRegistry.EventKind; import org.apache.geaflow.cluster.k8s.utils.KubernetesUtils; +import io.fabric8.kubernetes.api.model.Pod; + public class PodEvent implements Serializable { - private EventKind eventKind; - private String hostIp; - private String podIp; - private long ts; - private String containerId; - - public PodEvent(Pod pod, EventKind kind) { - this(pod, kind, System.currentTimeMillis()); - } - - public PodEvent(Pod pod, EventKind kind, long ts) { - this.eventKind = kind; - this.containerId = KubernetesUtils.extractComponentId(pod); - this.podIp = pod.getStatus().getPodIP(); - this.hostIp = pod.getStatus().getHostIP(); - this.ts = ts; - } - - public EventKind getEventKind() { - return eventKind; - } - - public void setEventKind(EventKind eventKind) { - this.eventKind = eventKind; - } - - public String getHostIp() { - return hostIp; - } - - public void setHostIp(String hostIp) { - this.hostIp = hostIp; - } - - public String getPodIp() { - return podIp; - } - - public void setPodIp(String podIp) { - this.podIp = podIp; - } - - public long getTs() { - return ts; - } - - public void setTs(long ts) { - this.ts = ts; - } - - public String getContainerId() { - return containerId; - } - - public void setContainerId(String containerId) { - this.containerId = containerId; - } - - @Override - public String toString() { - return "PodEvent{" + "eventKind=" + eventKind + ", hostIp='" + hostIp + '\'' + ", podIp='" - + podIp + '\'' + ", ts=" + ts + ", containerId='" + containerId + '\'' + '}'; - } + private EventKind eventKind; + private String hostIp; + private String podIp; + private long ts; + private String containerId; + + public PodEvent(Pod pod, EventKind kind) { + this(pod, kind, System.currentTimeMillis()); + } + + public PodEvent(Pod pod, EventKind kind, long ts) { + this.eventKind = kind; + this.containerId = KubernetesUtils.extractComponentId(pod); + this.podIp = pod.getStatus().getPodIP(); + this.hostIp = pod.getStatus().getHostIP(); + this.ts = ts; + } + + public EventKind getEventKind() { + return eventKind; + } + + public void setEventKind(EventKind eventKind) { + this.eventKind = eventKind; + } + + public String getHostIp() { + return hostIp; + } + + public void setHostIp(String hostIp) { + this.hostIp = hostIp; + } + + public String getPodIp() { + return podIp; + } + + public void setPodIp(String podIp) { + this.podIp = podIp; + } + + public long getTs() { + return ts; + } + + public void setTs(long ts) { + this.ts = ts; + } + + public String getContainerId() { + return containerId; + } + + public void setContainerId(String containerId) { + this.containerId = containerId; + } + + @Override + public String toString() { + return "PodEvent{" + + "eventKind=" + + eventKind + + ", hostIp='" + + hostIp + + '\'' + + ", podIp='" + + podIp + + '\'' + + ", ts=" + + ts + + ", containerId='" + + containerId + + '\'' + + '}'; + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodEvictHandler.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodEvictHandler.java index 9eb3944be..e0fd03024 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodEvictHandler.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodEvictHandler.java @@ -21,8 +21,8 @@ import static org.apache.geaflow.cluster.k8s.config.KubernetesConfigKeys.EVICTED_POD_LABELS; -import io.fabric8.kubernetes.api.model.Pod; import java.util.Map; + import org.apache.geaflow.cluster.k8s.handler.PodHandlerRegistry.EventKind; import org.apache.geaflow.cluster.k8s.utils.KubernetesUtils; import org.apache.geaflow.common.config.Configuration; @@ -30,34 +30,40 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import io.fabric8.kubernetes.api.model.Pod; + public class PodEvictHandler extends AbstractPodHandler { - private static final Logger LOG = LoggerFactory.getLogger(PodEvictHandler.class); - private final Map evictLabels; - private int totalCount; + private static final Logger LOG = LoggerFactory.getLogger(PodEvictHandler.class); + private final Map evictLabels; + private int totalCount; - public PodEvictHandler(Configuration configuration) { - this.evictLabels = KubernetesUtils.getPairsConf(configuration, EVICTED_POD_LABELS); - } + public PodEvictHandler(Configuration configuration) { + this.evictLabels = KubernetesUtils.getPairsConf(configuration, EVICTED_POD_LABELS); + } + + @Override + public void handle(Pod pod) { + Map labels = pod.getMetadata().getLabels(); + for (Map.Entry entry : evictLabels.entrySet()) { + String key = entry.getKey(); + if (labels.get(key) != null && labels.get(key).equalsIgnoreCase(entry.getValue())) { + String componentId = KubernetesUtils.extractComponentId(pod); + String message = + String.format( + "Pod #%s %s will be removed, label: %s annotations: %s, total removed: %s", + componentId, + pod.getMetadata().getName(), + key, + pod.getMetadata().getAnnotations(), + ++totalCount); + LOG.info(message); - @Override - public void handle(Pod pod) { - Map labels = pod.getMetadata().getLabels(); - for (Map.Entry entry : evictLabels.entrySet()) { - String key = entry.getKey(); - if (labels.get(key) != null && labels.get(key).equalsIgnoreCase(entry.getValue())) { - String componentId = KubernetesUtils.extractComponentId(pod); - String message = String.format( - "Pod #%s %s will be removed, label: %s annotations: %s, total removed: %s", - componentId, pod.getMetadata().getName(), key, - pod.getMetadata().getAnnotations(), ++totalCount); - LOG.info(message); - - PodEvent event = new PodEvent(pod, EventKind.POD_EVICTION); - notifyListeners(event); - reportPodEvent(event, ExceptionLevel.WARN, message); - break; - } - } + PodEvent event = new PodEvent(pod, EventKind.POD_EVICTION); + notifyListeners(event); + reportPodEvent(event, ExceptionLevel.WARN, message); + break; + } } -} \ No newline at end of file + } +} diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodHandlerRegistry.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodHandlerRegistry.java index 76ae3b580..65f6985c0 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodHandlerRegistry.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodHandlerRegistry.java @@ -19,53 +19,54 @@ package org.apache.geaflow.cluster.k8s.handler; -import io.fabric8.kubernetes.client.Watcher.Action; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; +import io.fabric8.kubernetes.client.Watcher.Action; + public class PodHandlerRegistry { - private final Map> eventHandlerMap; - private static PodHandlerRegistry INSTANCE; + private final Map> eventHandlerMap; + private static PodHandlerRegistry INSTANCE; - private PodHandlerRegistry(Configuration configuration) { - this.eventHandlerMap = new HashMap<>(); + private PodHandlerRegistry(Configuration configuration) { + this.eventHandlerMap = new HashMap<>(); - Map modifiedHandlerMap = new HashMap<>(); - modifiedHandlerMap.put(EventKind.POD_OOM, new PodOOMHandler()); - modifiedHandlerMap.put(EventKind.POD_EVICTION, new PodEvictHandler(configuration)); - this.eventHandlerMap.put(Action.MODIFIED, modifiedHandlerMap); + Map modifiedHandlerMap = new HashMap<>(); + modifiedHandlerMap.put(EventKind.POD_OOM, new PodOOMHandler()); + modifiedHandlerMap.put(EventKind.POD_EVICTION, new PodEvictHandler(configuration)); + this.eventHandlerMap.put(Action.MODIFIED, modifiedHandlerMap); - Map addedHandlerMap = new HashMap<>(); - addedHandlerMap.put(EventKind.POD_ADDED, new PodAddedHandler()); - this.eventHandlerMap.put(Action.ADDED, addedHandlerMap); + Map addedHandlerMap = new HashMap<>(); + addedHandlerMap.put(EventKind.POD_ADDED, new PodAddedHandler()); + this.eventHandlerMap.put(Action.ADDED, addedHandlerMap); - Map deletedHandlerMap = new HashMap<>(); - deletedHandlerMap.put(EventKind.POD_DELETED, new PodDeletedHandler()); - this.eventHandlerMap.put(Action.DELETED, deletedHandlerMap); - } + Map deletedHandlerMap = new HashMap<>(); + deletedHandlerMap.put(EventKind.POD_DELETED, new PodDeletedHandler()); + this.eventHandlerMap.put(Action.DELETED, deletedHandlerMap); + } - public static synchronized PodHandlerRegistry getInstance(Configuration configuration) { - if (INSTANCE == null) { - INSTANCE = new PodHandlerRegistry(configuration); - } - return INSTANCE; + public static synchronized PodHandlerRegistry getInstance(Configuration configuration) { + if (INSTANCE == null) { + INSTANCE = new PodHandlerRegistry(configuration); } + return INSTANCE; + } - public void registerListener(Action action, EventKind eventKind, IEventListener listener) { - ((AbstractPodHandler) eventHandlerMap.get(action).get(eventKind)).addListener(listener); - } - - public Map> getHandlerMap() { - return eventHandlerMap; - } + public void registerListener(Action action, EventKind eventKind, IEventListener listener) { + ((AbstractPodHandler) eventHandlerMap.get(action).get(eventKind)).addListener(listener); + } - public enum EventKind { - POD_ADDED, - POD_DELETED, - POD_OOM, - POD_EVICTION - } + public Map> getHandlerMap() { + return eventHandlerMap; + } + public enum EventKind { + POD_ADDED, + POD_DELETED, + POD_OOM, + POD_EVICTION + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodOOMHandler.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodOOMHandler.java index f7abee66e..bc9dbda11 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodOOMHandler.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/handler/PodOOMHandler.java @@ -37,13 +37,11 @@ package org.apache.geaflow.cluster.k8s.handler; -import io.fabric8.kubernetes.api.model.ContainerState; -import io.fabric8.kubernetes.api.model.ContainerStatus; -import io.fabric8.kubernetes.api.model.Pod; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.cluster.k8s.handler.PodHandlerRegistry.EventKind; import org.apache.geaflow.cluster.k8s.utils.KubernetesUtils; import org.apache.geaflow.common.tuple.Tuple; @@ -54,83 +52,88 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import io.fabric8.kubernetes.api.model.ContainerState; +import io.fabric8.kubernetes.api.model.ContainerStatus; +import io.fabric8.kubernetes.api.model.Pod; + public class PodOOMHandler extends AbstractPodHandler { - private static final Logger LOGGER = LoggerFactory.getLogger(PodOOMHandler.class); - private static final String OOM_KILLED_KEY = "OOMKilled"; - private static final Exception POD_OOM_MSG = new Exception("pod overused memory"); + private static final Logger LOGGER = LoggerFactory.getLogger(PodOOMHandler.class); + private static final String OOM_KILLED_KEY = "OOMKilled"; + private static final Exception POD_OOM_MSG = new Exception("pod overused memory"); - private final DateTimeFormatter parser = ISODateTimeFormat.dateTimeNoMillis(); - private final Map>> exceptions; - private int totalOOMCount; + private final DateTimeFormatter parser = ISODateTimeFormat.dateTimeNoMillis(); + private final Map>> exceptions; + private int totalOOMCount; - public PodOOMHandler() { - this.totalOOMCount = 0; - this.exceptions = new HashMap<>(); - } + public PodOOMHandler() { + this.totalOOMCount = 0; + this.exceptions = new HashMap<>(); + } - @Override - public void handle(Pod pod) { - if (pod.getStatus() != null && !pod.getStatus().getContainerStatuses().isEmpty()) { - for (ContainerStatus containerStatus : pod.getStatus().getContainerStatuses()) { - ContainerState state = containerStatus.getState(); - if (state != null && state.getTerminated() != null - && state.getTerminated().getReason() != null - && state.getTerminated().getFinishedAt() != null) { - if (state.getTerminated().getReason().contains(OOM_KILLED_KEY)) { - String finishTime = containerStatus.getState().getTerminated() - .getFinishedAt(); - DateTime parsed; - try { - parsed = parser.parseDateTime(finishTime); - } catch (Exception e) { - LOGGER.error("Failed to parse finish time: {}", finishTime, e); - return; - } - long exceptionTime = parsed.getMillis(); - List> oldList = exceptions.get(exceptionTime); + @Override + public void handle(Pod pod) { + if (pod.getStatus() != null && !pod.getStatus().getContainerStatuses().isEmpty()) { + for (ContainerStatus containerStatus : pod.getStatus().getContainerStatuses()) { + ContainerState state = containerStatus.getState(); + if (state != null + && state.getTerminated() != null + && state.getTerminated().getReason() != null + && state.getTerminated().getFinishedAt() != null) { + if (state.getTerminated().getReason().contains(OOM_KILLED_KEY)) { + String finishTime = containerStatus.getState().getTerminated().getFinishedAt(); + DateTime parsed; + try { + parsed = parser.parseDateTime(finishTime); + } catch (Exception e) { + LOGGER.error("Failed to parse finish time: {}", finishTime, e); + return; + } + long exceptionTime = parsed.getMillis(); + List> oldList = exceptions.get(exceptionTime); - boolean added = true; - String componentId = KubernetesUtils.extractComponentId(pod); - Tuple newException = new Tuple<>(componentId, - POD_OOM_MSG); - if (oldList == null) { - exceptions.computeIfAbsent(exceptionTime, k -> new ArrayList<>()) - .add(newException); - } else { - if (exists(newException, oldList)) { - added = false; - } else { - oldList.add(newException); - } - } + boolean added = true; + String componentId = KubernetesUtils.extractComponentId(pod); + Tuple newException = new Tuple<>(componentId, POD_OOM_MSG); + if (oldList == null) { + exceptions.computeIfAbsent(exceptionTime, k -> new ArrayList<>()).add(newException); + } else { + if (exists(newException, oldList)) { + added = false; + } else { + oldList.add(newException); + } + } - if (added) { - totalOOMCount++; - LOGGER.info("Pod #{} {} oom killed at {}, totally: {}", componentId, - pod.getMetadata().getName(), parsed, totalOOMCount); + if (added) { + totalOOMCount++; + LOGGER.info( + "Pod #{} {} oom killed at {}, totally: {}", + componentId, + pod.getMetadata().getName(), + parsed, + totalOOMCount); - PodEvent oomEvent = new PodEvent(pod, EventKind.POD_OOM, exceptionTime); - notifyListeners(oomEvent); + PodEvent oomEvent = new PodEvent(pod, EventKind.POD_OOM, exceptionTime); + notifyListeners(oomEvent); - String errMsg = String.format("pod %s oom killed at %s", - pod.getMetadata().getName(), parsed); - reportPodEvent(oomEvent, ExceptionLevel.ERROR, errMsg); - } - } - } + String errMsg = + String.format("pod %s oom killed at %s", pod.getMetadata().getName(), parsed); + reportPodEvent(oomEvent, ExceptionLevel.ERROR, errMsg); } + } } + } } + } - private boolean exists(Tuple target, - List> exceptions) { - for (Tuple e : exceptions) { - if (e.f0.equals(target.f0)) { - return true; - } - } - return false; + private boolean exists( + Tuple target, List> exceptions) { + for (Tuple e : exceptions) { + if (e.f0.equals(target.f0)) { + return true; + } } - -} \ No newline at end of file + return false; + } +} diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/leaderelection/KubernetesLeaderElectionService.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/leaderelection/KubernetesLeaderElectionService.java index c20a22e6a..97a814ca2 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/leaderelection/KubernetesLeaderElectionService.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/leaderelection/KubernetesLeaderElectionService.java @@ -22,11 +22,6 @@ import static org.apache.geaflow.cluster.k8s.config.KubernetesConfigKeys.POD_USER_LABELS; import static org.apache.geaflow.cluster.k8s.config.KubernetesConfigKeys.WATCHER_CHECK_INTERVAL; -import io.fabric8.kubernetes.api.model.ConfigMap; -import io.fabric8.kubernetes.api.model.Service; -import io.fabric8.kubernetes.client.Watch; -import io.fabric8.kubernetes.client.Watcher.Action; -import io.fabric8.kubernetes.client.extended.leaderelection.LeaderCallbacks; import java.util.HashMap; import java.util.Map; import java.util.UUID; @@ -35,6 +30,7 @@ import java.util.concurrent.TimeUnit; import java.util.function.BiConsumer; import java.util.function.Consumer; + import org.apache.geaflow.cluster.k8s.clustermanager.GeaflowKubeClient; import org.apache.geaflow.cluster.k8s.config.K8SConstants; import org.apache.geaflow.cluster.k8s.utils.KubernetesUtils; @@ -49,145 +45,161 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import io.fabric8.kubernetes.api.model.ConfigMap; +import io.fabric8.kubernetes.api.model.Service; +import io.fabric8.kubernetes.client.Watch; +import io.fabric8.kubernetes.client.Watcher.Action; +import io.fabric8.kubernetes.client.extended.leaderelection.LeaderCallbacks; + public class KubernetesLeaderElectionService implements ILeaderElectionService { - private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesLeaderElectionService.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(KubernetesLeaderElectionService.class); - private static final String NAME_SEPERATOR = "-"; + private static final String NAME_SEPERATOR = "-"; - private GeaflowKubeClient client; + private GeaflowKubeClient client; - private KubernetesLeaderElectorConfig electorConfig; + private KubernetesLeaderElectorConfig electorConfig; - private KubernetesLeaderElector leaderElector; + private KubernetesLeaderElector leaderElector; - private Configuration configuration; + private Configuration configuration; - private String componentId; + private String componentId; - private String identity; + private String identity; - private ILeaderContender contender; + private ILeaderContender contender; - private Watch masterServiceWatcher; + private Watch masterServiceWatcher; - private volatile boolean watcherClosed; + private volatile boolean watcherClosed; - private final ScheduledExecutorService executorService = - Executors.newSingleThreadScheduledExecutor( - ThreadUtil.namedThreadFactory(true, "watcher-creator")); + private final ScheduledExecutorService executorService = + Executors.newSingleThreadScheduledExecutor( + ThreadUtil.namedThreadFactory(true, "watcher-creator")); - public void init(Configuration configuration, String componentId) { - this.configuration = configuration; - this.componentId = componentId; - this.client = new GeaflowKubeClient(configuration); - this.identity = UUID.randomUUID().toString(); - this.watcherClosed = true; - LOGGER.info("Init leader-election service with identity: {}", identity); - } + public void init(Configuration configuration, String componentId) { + this.configuration = configuration; + this.componentId = componentId; + this.client = new GeaflowKubeClient(configuration); + this.identity = UUID.randomUUID().toString(); + this.watcherClosed = true; + LOGGER.info("Init leader-election service with identity: {}", identity); + } - @Override - public void open(ILeaderContender contender) { - this.contender = contender; - KubernetesLeaderElectionEventListener eventListener = - new KubernetesLeaderElectionEventListener(); - initMasterServiceWatcher(); - String configMapName = getLeaderConfigMapName(contender); - electorConfig = KubernetesLeaderElectorConfig.build(configuration, configMapName, identity); - LeaderCallbacks callbacks = new LeaderCallbacks( + @Override + public void open(ILeaderContender contender) { + this.contender = contender; + KubernetesLeaderElectionEventListener eventListener = + new KubernetesLeaderElectionEventListener(); + initMasterServiceWatcher(); + String configMapName = getLeaderConfigMapName(contender); + electorConfig = KubernetesLeaderElectorConfig.build(configuration, configMapName, identity); + LeaderCallbacks callbacks = + new LeaderCallbacks( eventListener::handleLeadershipGranted, eventListener::handleLeadershipLost, - eventListener::handleNewLeadership - ); - leaderElector = new KubernetesLeaderElector(client, electorConfig, callbacks); - leaderElector.run(); - } - - private void initMasterServiceWatcher() { - String clusterId = configuration.getString(ExecutionConfigKeys.CLUSTER_ID); - String masterServiceName = KubernetesUtils.getMasterServiceName(clusterId); - BiConsumer eventHandler = this::handleClusterDestroyed; - Consumer exceptionHandler = (exception) -> { - watcherClosed = true; - LOGGER.warn("watch exception: {}", exception.getMessage(), exception); + eventListener::handleNewLeadership); + leaderElector = new KubernetesLeaderElector(client, electorConfig, callbacks); + leaderElector.run(); + } + + private void initMasterServiceWatcher() { + String clusterId = configuration.getString(ExecutionConfigKeys.CLUSTER_ID); + String masterServiceName = KubernetesUtils.getMasterServiceName(clusterId); + BiConsumer eventHandler = this::handleClusterDestroyed; + Consumer exceptionHandler = + (exception) -> { + watcherClosed = true; + LOGGER.warn("watch exception: {}", exception.getMessage(), exception); }; - int checkInterval = configuration.getInteger(WATCHER_CHECK_INTERVAL); - executorService.scheduleAtFixedRate(() -> { - if (watcherClosed) { - if (masterServiceWatcher != null) { - masterServiceWatcher.close(); - } - masterServiceWatcher = client.createServiceWatcher(masterServiceName, eventHandler, exceptionHandler); - if (masterServiceWatcher != null) { - watcherClosed = false; - } + int checkInterval = configuration.getInteger(WATCHER_CHECK_INTERVAL); + executorService.scheduleAtFixedRate( + () -> { + if (watcherClosed) { + if (masterServiceWatcher != null) { + masterServiceWatcher.close(); + } + masterServiceWatcher = + client.createServiceWatcher(masterServiceName, eventHandler, exceptionHandler); + if (masterServiceWatcher != null) { + watcherClosed = false; } - }, 0, checkInterval, TimeUnit.SECONDS); + } + }, + 0, + checkInterval, + TimeUnit.SECONDS); + } + + private void handleClusterDestroyed(Action action, Service service) { + // Leader-election service should be stopped if the cluster is destroyed, + // to avoid creating the config-map-lock again after deleted. + if (action == Action.DELETED) { + LOGGER.warn( + "Master service {} is deleted, close the leader-election service now.", + service.getMetadata().getName()); + leaderElector.close(); } + } - private void handleClusterDestroyed(Action action, Service service) { - // Leader-election service should be stopped if the cluster is destroyed, - // to avoid creating the config-map-lock again after deleted. - if (action == Action.DELETED) { - LOGGER.warn("Master service {} is deleted, close the leader-election service now.", - service.getMetadata().getName()); - leaderElector.close(); - } + @Override + public void close() { + if (leaderElector != null) { + leaderElector.close(); } + executorService.shutdownNow(); + } - @Override - public void close() { - if (leaderElector != null) { - leaderElector.close(); - } - executorService.shutdownNow(); - } + @Override + public boolean isLeader() { + ConfigMap configMap = client.getConfigMap(electorConfig.getConfigMapName()); + return KubernetesLeaderElector.isLeader(configMap, electorConfig.getIdentity()); + } + + @Override + public LeaderElectionServiceType getType() { + return LeaderElectionServiceType.kubernetes; + } + + private class KubernetesLeaderElectionEventListener implements ILeaderElectionEventListener { @Override - public boolean isLeader() { - ConfigMap configMap = client.getConfigMap(electorConfig.getConfigMapName()); - return KubernetesLeaderElector.isLeader(configMap, electorConfig.getIdentity()); + public void handleLeadershipGranted() { + String clusterId = configuration.getString(ExecutionConfigKeys.CLUSTER_ID); + Map configMapLabels = new HashMap<>(); + configMapLabels.put(K8SConstants.LABEL_APP_KEY, clusterId); + configMapLabels.put(K8SConstants.LABEL_COMPONENT_ID_KEY, componentId); + configMapLabels.put(K8SConstants.LABEL_CONFIG_MAP_LOCK, Boolean.toString(true)); + configMapLabels.putAll(KubernetesUtils.getPairsConf(configuration, POD_USER_LABELS)); + ConfigMap configMap = client.getConfigMap(getLeaderConfigMapName(contender)); + configMap.getMetadata().setLabels(configMapLabels); + + Service masterService = client.getService(KubernetesUtils.getMasterServiceName(clusterId)); + configMap.addOwnerReference(masterService); + client.updateConfigMap(configMap); + contender.handleLeadershipGranted(); } @Override - public LeaderElectionServiceType getType() { - return LeaderElectionServiceType.kubernetes; + public void handleLeadershipLost() { + contender.handleLeadershipLost(); } - private class KubernetesLeaderElectionEventListener implements ILeaderElectionEventListener { - - @Override - public void handleLeadershipGranted() { - String clusterId = configuration.getString(ExecutionConfigKeys.CLUSTER_ID); - Map configMapLabels = new HashMap<>(); - configMapLabels.put(K8SConstants.LABEL_APP_KEY, clusterId); - configMapLabels.put(K8SConstants.LABEL_COMPONENT_ID_KEY, componentId); - configMapLabels.put(K8SConstants.LABEL_CONFIG_MAP_LOCK, Boolean.toString(true)); - configMapLabels.putAll(KubernetesUtils.getPairsConf(configuration, POD_USER_LABELS)); - ConfigMap configMap = client.getConfigMap(getLeaderConfigMapName(contender)); - configMap.getMetadata().setLabels(configMapLabels); - - Service masterService = client.getService(KubernetesUtils.getMasterServiceName(clusterId)); - configMap.addOwnerReference(masterService); - client.updateConfigMap(configMap); - contender.handleLeadershipGranted(); - } - - @Override - public void handleLeadershipLost() { - contender.handleLeadershipLost(); - } - - @Override - public void handleNewLeadership(String newLeader) { - LOGGER.info("New leader for contender {} is elected. The leader is {}.", - contender.getClass().getSimpleName(), newLeader.equals(identity) ? "me" : newLeader); - } + @Override + public void handleNewLeadership(String newLeader) { + LOGGER.info( + "New leader for contender {} is elected. The leader is {}.", + contender.getClass().getSimpleName(), + newLeader.equals(identity) ? "me" : newLeader); } + } - private String getLeaderConfigMapName(ILeaderContender contender) { - String clusterId = configuration.getString(ExecutionConfigKeys.CLUSTER_ID); - LeaderContenderType contenderType = contender.getType(); - return clusterId + NAME_SEPERATOR + contenderType + NAME_SEPERATOR + componentId; - } + private String getLeaderConfigMapName(ILeaderContender contender) { + String clusterId = configuration.getString(ExecutionConfigKeys.CLUSTER_ID); + LeaderContenderType contenderType = contender.getType(); + return clusterId + NAME_SEPERATOR + contenderType + NAME_SEPERATOR + componentId; + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/leaderelection/KubernetesLeaderElector.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/leaderelection/KubernetesLeaderElector.java index bbfcdf62e..c16cab282 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/leaderelection/KubernetesLeaderElector.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/leaderelection/KubernetesLeaderElector.java @@ -19,34 +19,38 @@ package org.apache.geaflow.cluster.k8s.leaderelection; -import io.fabric8.kubernetes.api.model.ConfigMap; -import io.fabric8.kubernetes.client.extended.leaderelection.LeaderCallbacks; -import io.fabric8.kubernetes.client.extended.leaderelection.LeaderElectionConfig; -import io.fabric8.kubernetes.client.extended.leaderelection.LeaderElectionConfigBuilder; -import io.fabric8.kubernetes.client.extended.leaderelection.LeaderElector; -import io.fabric8.kubernetes.client.extended.leaderelection.resourcelock.ConfigMapLock; import java.time.Duration; import java.util.HashMap; import java.util.Map; import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; + import org.apache.geaflow.cluster.k8s.clustermanager.GeaflowKubeClient; import org.apache.geaflow.common.utils.ThreadUtil; +import io.fabric8.kubernetes.api.model.ConfigMap; +import io.fabric8.kubernetes.client.extended.leaderelection.LeaderCallbacks; +import io.fabric8.kubernetes.client.extended.leaderelection.LeaderElectionConfig; +import io.fabric8.kubernetes.client.extended.leaderelection.LeaderElectionConfigBuilder; +import io.fabric8.kubernetes.client.extended.leaderelection.LeaderElector; +import io.fabric8.kubernetes.client.extended.leaderelection.resourcelock.ConfigMapLock; + public class KubernetesLeaderElector { - public static final String LEADER_ANNOTATION_KEY = "control-plane.alpha.kubernetes.io/leader"; + public static final String LEADER_ANNOTATION_KEY = "control-plane.alpha.kubernetes.io/leader"; - private final ExecutorService executorService = Executors.newFixedThreadPool(4, - ThreadUtil.namedThreadFactory(true, "leader-elector")); + private final ExecutorService executorService = + Executors.newFixedThreadPool(4, ThreadUtil.namedThreadFactory(true, "leader-elector")); - private final LeaderElector innerLeaderElector; + private final LeaderElector innerLeaderElector; - public KubernetesLeaderElector(GeaflowKubeClient client, KubernetesLeaderElectorConfig config, - LeaderCallbacks callbacks) { - ConfigMapLock lock = new ConfigMapLock(config.getNamespace(), config.getConfigMapName(), config.getIdentity()); - LeaderElectionConfig electionConfig = new LeaderElectionConfigBuilder() + public KubernetesLeaderElector( + GeaflowKubeClient client, KubernetesLeaderElectorConfig config, LeaderCallbacks callbacks) { + ConfigMapLock lock = + new ConfigMapLock(config.getNamespace(), config.getConfigMapName(), config.getIdentity()); + LeaderElectionConfig electionConfig = + new LeaderElectionConfigBuilder() .withName(config.getConfigMapName()) .withLock(lock) .withLeaseDuration(Duration.ofSeconds(config.getLeaseDuration())) @@ -55,24 +59,23 @@ public KubernetesLeaderElector(GeaflowKubeClient client, KubernetesLeaderElector .withLeaderCallbacks(callbacks) .build(); - innerLeaderElector = client.createLeaderElector(electionConfig, executorService); - } - - public void run() { - if (!executorService.isShutdown()) { - executorService.execute(innerLeaderElector::run); - } - } + innerLeaderElector = client.createLeaderElector(electionConfig, executorService); + } - public void close() { - executorService.shutdownNow(); + public void run() { + if (!executorService.isShutdown()) { + executorService.execute(innerLeaderElector::run); } + } - public static boolean isLeader(ConfigMap configMap, String identity) { - Map annotations = - Optional.ofNullable(configMap.getMetadata().getAnnotations()).orElse(new HashMap<>()); - String leader = annotations.get(LEADER_ANNOTATION_KEY); - return leader != null && leader.contains(identity); - } + public void close() { + executorService.shutdownNow(); + } + public static boolean isLeader(ConfigMap configMap, String identity) { + Map annotations = + Optional.ofNullable(configMap.getMetadata().getAnnotations()).orElse(new HashMap<>()); + String leader = annotations.get(LEADER_ANNOTATION_KEY); + return leader != null && leader.contains(identity); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/leaderelection/KubernetesLeaderElectorConfig.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/leaderelection/KubernetesLeaderElectorConfig.java index 57347e38d..15f899578 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/leaderelection/KubernetesLeaderElectorConfig.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/leaderelection/KubernetesLeaderElectorConfig.java @@ -24,93 +24,84 @@ public class KubernetesLeaderElectorConfig { - /** - * Cluster namespace. - */ - private String namespace; - - /** - * Identity of the contender. - */ - private String identity; - - /** - * The name of the config-map-lock. - */ - private String configMapName; - - /** - * Duration of the lease. - */ - private Integer leaseDuration; - - /** - * The deadline of the leader to renew the lease. - */ - private Integer renewDeadline; - - /** - * Retry interval of the contender. - */ - private Integer retryPeriod; - - public static KubernetesLeaderElectorConfig build(Configuration configuration, - String configMapName, String identity) { - KubernetesLeaderElectorConfig electorConfig = new KubernetesLeaderElectorConfig(); - electorConfig.setNamespace(configuration.getString(KubernetesConfigKeys.NAME_SPACE)); - electorConfig.setLeaseDuration(configuration.getInteger(KubernetesConfigKeys.LEADER_ELECTION_LEASE_DURATION)); - electorConfig.setRenewDeadline(configuration.getInteger(KubernetesConfigKeys.LEADER_ELECTION_RENEW_DEADLINE)); - electorConfig.setRetryPeriod(configuration.getInteger(KubernetesConfigKeys.LEADER_ELECTION_RETRY_PERIOD)); - electorConfig.setIdentity(identity); - electorConfig.setConfigMapName(configMapName); - return electorConfig; - } - - public String getNamespace() { - return namespace; - } - - public void setNamespace(String namespace) { - this.namespace = namespace; - } - - public String getIdentity() { - return identity; - } - - public void setIdentity(String identity) { - this.identity = identity; - } - - public String getConfigMapName() { - return configMapName; - } - - public void setConfigMapName(String configMapName) { - this.configMapName = configMapName; - } - - public Integer getLeaseDuration() { - return leaseDuration; - } - - public void setLeaseDuration(Integer leaseDuration) { - this.leaseDuration = leaseDuration; - } - - public Integer getRenewDeadline() { - return renewDeadline; - } - - public void setRenewDeadline(Integer renewDeadline) { - this.renewDeadline = renewDeadline; - } - - public Integer getRetryPeriod() { - return retryPeriod; - } - - public void setRetryPeriod(Integer retryPeriod) { - this.retryPeriod = retryPeriod; - } + /** Cluster namespace. */ + private String namespace; + + /** Identity of the contender. */ + private String identity; + + /** The name of the config-map-lock. */ + private String configMapName; + + /** Duration of the lease. */ + private Integer leaseDuration; + + /** The deadline of the leader to renew the lease. */ + private Integer renewDeadline; + + /** Retry interval of the contender. */ + private Integer retryPeriod; + + public static KubernetesLeaderElectorConfig build( + Configuration configuration, String configMapName, String identity) { + KubernetesLeaderElectorConfig electorConfig = new KubernetesLeaderElectorConfig(); + electorConfig.setNamespace(configuration.getString(KubernetesConfigKeys.NAME_SPACE)); + electorConfig.setLeaseDuration( + configuration.getInteger(KubernetesConfigKeys.LEADER_ELECTION_LEASE_DURATION)); + electorConfig.setRenewDeadline( + configuration.getInteger(KubernetesConfigKeys.LEADER_ELECTION_RENEW_DEADLINE)); + electorConfig.setRetryPeriod( + configuration.getInteger(KubernetesConfigKeys.LEADER_ELECTION_RETRY_PERIOD)); + electorConfig.setIdentity(identity); + electorConfig.setConfigMapName(configMapName); + return electorConfig; + } + + public String getNamespace() { + return namespace; + } + + public void setNamespace(String namespace) { + this.namespace = namespace; + } + + public String getIdentity() { + return identity; + } + + public void setIdentity(String identity) { + this.identity = identity; + } + + public String getConfigMapName() { + return configMapName; + } + + public void setConfigMapName(String configMapName) { + this.configMapName = configMapName; + } + + public Integer getLeaseDuration() { + return leaseDuration; + } + + public void setLeaseDuration(Integer leaseDuration) { + this.leaseDuration = leaseDuration; + } + + public Integer getRenewDeadline() { + return renewDeadline; + } + + public void setRenewDeadline(Integer renewDeadline) { + this.renewDeadline = renewDeadline; + } + + public Integer getRetryPeriod() { + return retryPeriod; + } + + public void setRetryPeriod(Integer retryPeriod) { + this.retryPeriod = retryPeriod; + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/utils/KubernetesUtils.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/utils/KubernetesUtils.java index b2248f754..3031952b0 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/utils/KubernetesUtils.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/utils/KubernetesUtils.java @@ -27,13 +27,6 @@ import static org.apache.geaflow.cluster.k8s.config.KubernetesConfigKeys.SERVICE_SUFFIX; import static org.apache.geaflow.cluster.k8s.config.KubernetesConfigKeys.USE_IP_IN_HOST_NETWORK; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Joiner; -import io.fabric8.kubernetes.api.model.ConfigMap; -import io.fabric8.kubernetes.api.model.HostAlias; -import io.fabric8.kubernetes.api.model.NodeSelectorRequirement; -import io.fabric8.kubernetes.api.model.Pod; -import io.fabric8.kubernetes.api.model.Toleration; import java.io.BufferedReader; import java.io.File; import java.io.FileInputStream; @@ -46,7 +39,9 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import javax.annotation.Nullable; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.cluster.k8s.config.K8SConstants; import org.apache.geaflow.cluster.k8s.config.KubernetesConfig; @@ -58,314 +53,333 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Common utils for Kubernetes. - */ -public class KubernetesUtils { +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Joiner; - private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesUtils.class); +import io.fabric8.kubernetes.api.model.ConfigMap; +import io.fabric8.kubernetes.api.model.HostAlias; +import io.fabric8.kubernetes.api.model.NodeSelectorRequirement; +import io.fabric8.kubernetes.api.model.Pod; +import io.fabric8.kubernetes.api.model.Toleration; - private static InetAddress resolveServiceAddress(String serviceName) { - try { - return InetAddress.getByName(serviceName); - } catch (UnknownHostException e) { - return null; - } - } +/** Common utils for Kubernetes. */ +public class KubernetesUtils { - public static Map getPairsConf(Configuration config, ConfigKey configKey) { - return getPairsConf(config, configKey.getKey()); - } + private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesUtils.class); - public static Map getPairsConf(Configuration config, String configKey) { - Map pairs = new HashMap<>(); - String pairsStr = config.getString(configKey); - if (pairsStr != null) { - for (String label : pairsStr.split(",")) { - String[] splits = label.split(":"); - if (splits.length == 2) { - pairs.put(splits[0], splits[1]); - } - } - } - return pairs; + private static InetAddress resolveServiceAddress(String serviceName) { + try { + return InetAddress.getByName(serviceName); + } catch (UnknownHostException e) { + return null; } - - public static List getHostAliases(ConfigMap configMap) { - String hostAliases = configMap.getData().get(K8SConstants.HOST_ALIASES_CONFIG_MAP_NAME); - List hostAliasesList = new ArrayList<>(); - if (hostAliases != null) { - for (String item : hostAliases.split("\n")) { - if (item.startsWith("#")) { - continue; - } - String[] splits = item.split("\\s+"); - if (splits.length >= 2) { - List hostNames = new ArrayList<>(); - for (int i = 1; i < splits.length; i++) { - hostNames.add(splits[i].toLowerCase()); - } - hostAliasesList.add(new HostAlias(hostNames, splits[0])); - } - } + } + + public static Map getPairsConf(Configuration config, ConfigKey configKey) { + return getPairsConf(config, configKey.getKey()); + } + + public static Map getPairsConf(Configuration config, String configKey) { + Map pairs = new HashMap<>(); + String pairsStr = config.getString(configKey); + if (pairsStr != null) { + for (String label : pairsStr.split(",")) { + String[] splits = label.split(":"); + if (splits.length == 2) { + pairs.put(splits[0], splits[1]); } - return hostAliasesList; + } } - - public static Map loadConfigurationFromString(String content) { - Map config = new HashMap<>(); - for (String line : content.split(System.lineSeparator())) { - String[] splits = line.split(":"); - if (splits.length >= 2) { - config.put(splits[0].trim(), StringUtils.substringAfter(line, ":").trim()); - } + return pairs; + } + + public static List getHostAliases(ConfigMap configMap) { + String hostAliases = configMap.getData().get(K8SConstants.HOST_ALIASES_CONFIG_MAP_NAME); + List hostAliasesList = new ArrayList<>(); + if (hostAliases != null) { + for (String item : hostAliases.split("\n")) { + if (item.startsWith("#")) { + continue; } - return config; - } - - /** - * Method to extract variables from the config based on the given prefix String. - * - * @param prefix Prefix for the variables key - * @param config The config to get the environment variable defintion from - */ - public static Map getVariablesWithPrefix(String prefix, - Map config) { - Map result = new HashMap<>(); - for (Map.Entry entry : config.entrySet()) { - if (entry.getKey().startsWith(prefix) && entry.getKey().length() > prefix.length()) { - // remove prefix - String key = entry.getKey().substring(prefix.length()); - result.put(key, entry.getValue()); - } + String[] splits = item.split("\\s+"); + if (splits.length >= 2) { + List hostNames = new ArrayList<>(); + for (int i = 1; i < splits.length; i++) { + hostNames.add(splits[i].toLowerCase()); + } + hostAliasesList.add(new HostAlias(hostNames, splits[0])); } - return result; + } } - - public static Configuration loadConfiguration() throws Exception { - Configuration config = loadConfigurationFromFile(); - - KubernetesConfig.DockerNetworkType dockerNetworkType = - KubernetesConfig.getDockerNetworkType( - config); - - // Wait for service to be resolved. - String serviceIp = waitForServiceNameResolved(config, false).getHostAddress(); - config.put(MASTER_ADDRESS, serviceIp); - if (dockerNetworkType == KubernetesConfig.DockerNetworkType.HOST) { - try { - InetAddress addr = InetAddress.getLocalHost(); - if (config.getBoolean(USE_IP_IN_HOST_NETWORK)) { - config.put(MASTER_ADDRESS, serviceIp); - } else { - config.put(MASTER_ADDRESS, addr.getHostName()); - } - } catch (UnknownHostException e) { - LOGGER.warn("Get hostname for master error {}.", e.getMessage()); - } - } - return config; + return hostAliasesList; + } + + public static Map loadConfigurationFromString(String content) { + Map config = new HashMap<>(); + for (String line : content.split(System.lineSeparator())) { + String[] splits = line.split(":"); + if (splits.length >= 2) { + config.put(splits[0].trim(), StringUtils.substringAfter(line, ":").trim()); + } } - - public static Configuration loadConfigurationFromFile() throws IOException { - String configDir = System.getenv().get(K8SConstants.ENV_CONF_DIR); - if (configDir == null) { - throw new IllegalArgumentException( - "Given configuration directory is null, cannot " + "load configuration"); - } - - final File confDirFile = new File(configDir); - if (!(confDirFile.exists())) { - throw new RuntimeException( - "The given configuration directory name '" + configDir + "' (" - + confDirFile.getAbsolutePath() + ") does not describe an existing directory."); - } - - // get yaml configuration file - final File yamlConfigFile = new File(confDirFile, K8SConstants.ENV_CONFIG_FILE); - - if (!yamlConfigFile.exists()) { - throw new IOException( - "The config file '" + yamlConfigFile + "' (" + yamlConfigFile.getAbsolutePath() - + ") does not exist."); - } - - return loadYAMLResource(yamlConfigFile); + return config; + } + + /** + * Method to extract variables from the config based on the given prefix String. + * + * @param prefix Prefix for the variables key + * @param config The config to get the environment variable defintion from + */ + public static Map getVariablesWithPrefix( + String prefix, Map config) { + Map result = new HashMap<>(); + for (Map.Entry entry : config.entrySet()) { + if (entry.getKey().startsWith(prefix) && entry.getKey().length() > prefix.length()) { + // remove prefix + String key = entry.getKey().substring(prefix.length()); + result.put(key, entry.getValue()); + } } - - /** - * This method is an adaptation of Flink's - * org.apache.flink.configuration.GlobalConfiguration#loadYAMLResource - */ - @VisibleForTesting - public static Configuration loadYAMLResource(File file) { - final Configuration config = new Configuration(); - - try (BufferedReader reader = new BufferedReader( - new InputStreamReader(new FileInputStream(file)))) { - - String line; - int lineNo = 0; - while ((line = reader.readLine()) != null) { - lineNo++; - // 1. check for comments - String[] comments = line.split("#", 2); - String conf = comments[0].trim(); - - // 2. get key and value - if (conf.length() > 0) { - String key; - String value; - - String[] kv = conf.split(CONFIG_KV_SEPARATOR, 2); - if (kv.length < 1) { - LOGGER.warn( - "Error while trying to split key and value in configuration file " - + file + ":" + lineNo + ": \"" + line + "\""); - continue; - } - - key = kv[0].trim(); - value = kv.length == 1 ? "" : kv[1].trim(); - - // sanity check - if (key.length() == 0) { - LOGGER.warn( - "Error after splitting key in configuration file " + file + ":" + lineNo - + ": \"" + line + "\""); - continue; - } - - LOGGER.info("Loading property: {}, {}", key, value); - config.put(key, value); - } - } - } catch (IOException e) { - throw new RuntimeException("Error parsing YAML configuration.", e); + return result; + } + + public static Configuration loadConfiguration() throws Exception { + Configuration config = loadConfigurationFromFile(); + + KubernetesConfig.DockerNetworkType dockerNetworkType = + KubernetesConfig.getDockerNetworkType(config); + + // Wait for service to be resolved. + String serviceIp = waitForServiceNameResolved(config, false).getHostAddress(); + config.put(MASTER_ADDRESS, serviceIp); + if (dockerNetworkType == KubernetesConfig.DockerNetworkType.HOST) { + try { + InetAddress addr = InetAddress.getLocalHost(); + if (config.getBoolean(USE_IP_IN_HOST_NETWORK)) { + config.put(MASTER_ADDRESS, serviceIp); + } else { + config.put(MASTER_ADDRESS, addr.getHostName()); } - - return config; + } catch (UnknownHostException e) { + LOGGER.warn("Get hostname for master error {}.", e.getMessage()); + } + } + return config; + } + + public static Configuration loadConfigurationFromFile() throws IOException { + String configDir = System.getenv().get(K8SConstants.ENV_CONF_DIR); + if (configDir == null) { + throw new IllegalArgumentException( + "Given configuration directory is null, cannot " + "load configuration"); } - public static InetAddress waitForServiceNameResolved(Configuration config, - boolean appendSuffix) { - String serviceNameWithNamespace = KubernetesConfig.getServiceNameWithNamespace(config); - String suffix = config.getString(SERVICE_SUFFIX); - if (appendSuffix && !StringUtils.isBlank(suffix)) { - serviceNameWithNamespace += K8SConstants.NAMESPACE_SEPARATOR + suffix; - } - LOGGER.info("Waiting for service {} to be resolved.", serviceNameWithNamespace); - - InetAddress serviceAddress; - final long startTime = System.currentTimeMillis(); - do { - serviceAddress = resolveServiceAddress(serviceNameWithNamespace); - if (System.currentTimeMillis() - startTime > 60000) { - LOGGER.warn("Resolve service took more than 60 seconds, please check logs on the " - + "Kubernetes cluster."); - } - SleepUtils.sleepMilliSecond(250); - } while (serviceAddress == null); - - LOGGER.info("Service {} resolved to {}", serviceNameWithNamespace, serviceAddress); - return serviceAddress; + final File confDirFile = new File(configDir); + if (!(confDirFile.exists())) { + throw new RuntimeException( + "The given configuration directory name '" + + configDir + + "' (" + + confDirFile.getAbsolutePath() + + ") does not describe an existing directory."); } + // get yaml configuration file + final File yamlConfigFile = new File(confDirFile, K8SConstants.ENV_CONFIG_FILE); - public static List getTolerations(Configuration config) { - List tolerationList = new ArrayList<>(); - if (!config.contains(KubernetesConfigKeys.TOLERATION_LIST)) { - return tolerationList; - } - String tolerations = config.getString(KubernetesConfigKeys.TOLERATION_LIST); - for (String each : tolerations.trim().split(",")) { - String[] parts = each.split(":", -1); - if (parts.length != 5) { - LOGGER.error("parse toleration error, {}", each); - continue; - } - Toleration toleration = new Toleration(); - if (parts[0] != null && !parts[0].isEmpty() && !parts[0].equals("-")) { - toleration.setKey(parts[0]); - } - if (parts[1] != null && !parts[1].isEmpty() && !parts[1].equals("-")) { - toleration.setOperator(parts[1]); - } - if (parts[2] != null && !parts[2].isEmpty() && !parts[2].equals("-")) { - toleration.setValue(parts[2]); - } - if (parts[3] != null && !parts[3].isEmpty() && !parts[3].equals("-")) { - toleration.setEffect(parts[3]); - } - if (parts[4] != null && !parts[4].isEmpty() && !parts[4].equals("-")) { - toleration.setTolerationSeconds(Long.valueOf(parts[4])); - } - tolerationList.add(toleration); - } - return tolerationList; + if (!yamlConfigFile.exists()) { + throw new IOException( + "The config file '" + + yamlConfigFile + + "' (" + + yamlConfigFile.getAbsolutePath() + + ") does not exist."); } - public static List getMatchExpressions(Configuration config) { - List matchExpressionList = new ArrayList<>(); - if (!config.contains(KubernetesConfigKeys.MATCH_EXPRESSION_LIST)) { - return matchExpressionList; + return loadYAMLResource(yamlConfigFile); + } + + /** + * This method is an adaptation of Flink's + * org.apache.flink.configuration.GlobalConfiguration#loadYAMLResource + */ + @VisibleForTesting + public static Configuration loadYAMLResource(File file) { + final Configuration config = new Configuration(); + + try (BufferedReader reader = + new BufferedReader(new InputStreamReader(new FileInputStream(file)))) { + + String line; + int lineNo = 0; + while ((line = reader.readLine()) != null) { + lineNo++; + // 1. check for comments + String[] comments = line.split("#", 2); + String conf = comments[0].trim(); + + // 2. get key and value + if (conf.length() > 0) { + String key; + String value; + + String[] kv = conf.split(CONFIG_KV_SEPARATOR, 2); + if (kv.length < 1) { + LOGGER.warn( + "Error while trying to split key and value in configuration file " + + file + + ":" + + lineNo + + ": \"" + + line + + "\""); + continue; + } + + key = kv[0].trim(); + value = kv.length == 1 ? "" : kv[1].trim(); + + // sanity check + if (key.length() == 0) { + LOGGER.warn( + "Error after splitting key in configuration file " + + file + + ":" + + lineNo + + ": \"" + + line + + "\""); + continue; + } + + LOGGER.info("Loading property: {}, {}", key, value); + config.put(key, value); } - String matchExpressions = config.getString(KubernetesConfigKeys.MATCH_EXPRESSION_LIST); - for (String each : matchExpressions.trim().split(",")) { - String[] parts = each.split(":", -1); - if (parts.length != 3) { - LOGGER.error("parse matchExpressions error, {}", each); - continue; - } - NodeSelectorRequirement matchExpression = new NodeSelectorRequirement(); - if (parts[0] != null && !parts[0].isEmpty() && !parts[0].equals("-")) { - matchExpression.setKey(parts[0]); - } - if (parts[1] != null && !parts[1].isEmpty() && !parts[1].equals("-")) { - matchExpression.setOperator(parts[1]); - } - if (parts[2] != null && !parts[2].isEmpty() && !parts[2].equals("-")) { - matchExpression.setValues(Arrays.asList(parts[2])); - } - matchExpressionList.add(matchExpression); - } - return matchExpressionList; + } + } catch (IOException e) { + throw new RuntimeException("Error parsing YAML configuration.", e); } - @Nullable - public static String extractComponentId(Pod pod) { - return pod.getMetadata().getLabels().get(K8SConstants.LABEL_COMPONENT_ID_KEY); - } + return config; + } - @Nullable - public static String extractComponent(Pod pod) { - return pod.getMetadata().getLabels().get(K8SConstants.LABEL_COMPONENT_KEY); + public static InetAddress waitForServiceNameResolved(Configuration config, boolean appendSuffix) { + String serviceNameWithNamespace = KubernetesConfig.getServiceNameWithNamespace(config); + String suffix = config.getString(SERVICE_SUFFIX); + if (appendSuffix && !StringUtils.isBlank(suffix)) { + serviceNameWithNamespace += K8SConstants.NAMESPACE_SEPARATOR + suffix; } - - public static String encodeRpcAddressMap(Map addressMap) { - return Joiner.on(CONFIG_LIST_SEPARATOR).withKeyValueSeparator(ADDRESS_SEPARATOR) - .join(addressMap); + LOGGER.info("Waiting for service {} to be resolved.", serviceNameWithNamespace); + + InetAddress serviceAddress; + final long startTime = System.currentTimeMillis(); + do { + serviceAddress = resolveServiceAddress(serviceNameWithNamespace); + if (System.currentTimeMillis() - startTime > 60000) { + LOGGER.warn( + "Resolve service took more than 60 seconds, please check logs on the " + + "Kubernetes cluster."); + } + SleepUtils.sleepMilliSecond(250); + } while (serviceAddress == null); + + LOGGER.info("Service {} resolved to {}", serviceNameWithNamespace, serviceAddress); + return serviceAddress; + } + + public static List getTolerations(Configuration config) { + List tolerationList = new ArrayList<>(); + if (!config.contains(KubernetesConfigKeys.TOLERATION_LIST)) { + return tolerationList; } - - public static Map decodeRpcAddressMap(String str) { - Map map = new HashMap<>(); - for (String entry : str.trim().split(CONFIG_LIST_SEPARATOR)) { - String[] pair = entry.split(ADDRESS_SEPARATOR); - map.put(pair[0], ConnectAddress.build(pair[1])); - } - return map; + String tolerations = config.getString(KubernetesConfigKeys.TOLERATION_LIST); + for (String each : tolerations.trim().split(",")) { + String[] parts = each.split(":", -1); + if (parts.length != 5) { + LOGGER.error("parse toleration error, {}", each); + continue; + } + Toleration toleration = new Toleration(); + if (parts[0] != null && !parts[0].isEmpty() && !parts[0].equals("-")) { + toleration.setKey(parts[0]); + } + if (parts[1] != null && !parts[1].isEmpty() && !parts[1].equals("-")) { + toleration.setOperator(parts[1]); + } + if (parts[2] != null && !parts[2].isEmpty() && !parts[2].equals("-")) { + toleration.setValue(parts[2]); + } + if (parts[3] != null && !parts[3].isEmpty() && !parts[3].equals("-")) { + toleration.setEffect(parts[3]); + } + if (parts[4] != null && !parts[4].isEmpty() && !parts[4].equals("-")) { + toleration.setTolerationSeconds(Long.valueOf(parts[4])); + } + tolerationList.add(toleration); } + return tolerationList; + } - public static String getMasterServiceName(String clusterId) { - return clusterId + K8SConstants.SERVICE_NAME_SUFFIX; + public static List getMatchExpressions(Configuration config) { + List matchExpressionList = new ArrayList<>(); + if (!config.contains(KubernetesConfigKeys.MATCH_EXPRESSION_LIST)) { + return matchExpressionList; } - - public static String getMasterClientServiceName(String clusterId) { - return clusterId + K8SConstants.CLIENT_SERVICE_NAME_SUFFIX; + String matchExpressions = config.getString(KubernetesConfigKeys.MATCH_EXPRESSION_LIST); + for (String each : matchExpressions.trim().split(",")) { + String[] parts = each.split(":", -1); + if (parts.length != 3) { + LOGGER.error("parse matchExpressions error, {}", each); + continue; + } + NodeSelectorRequirement matchExpression = new NodeSelectorRequirement(); + if (parts[0] != null && !parts[0].isEmpty() && !parts[0].equals("-")) { + matchExpression.setKey(parts[0]); + } + if (parts[1] != null && !parts[1].isEmpty() && !parts[1].equals("-")) { + matchExpression.setOperator(parts[1]); + } + if (parts[2] != null && !parts[2].isEmpty() && !parts[2].equals("-")) { + matchExpression.setValues(Arrays.asList(parts[2])); + } + matchExpressionList.add(matchExpression); } - - public static String getDriverServiceName(String clusterId, int driverIndex) { - return clusterId + DRIVER_SERVICE_NAME_SUFFIX + driverIndex; + return matchExpressionList; + } + + @Nullable public static String extractComponentId(Pod pod) { + return pod.getMetadata().getLabels().get(K8SConstants.LABEL_COMPONENT_ID_KEY); + } + + @Nullable public static String extractComponent(Pod pod) { + return pod.getMetadata().getLabels().get(K8SConstants.LABEL_COMPONENT_KEY); + } + + public static String encodeRpcAddressMap(Map addressMap) { + return Joiner.on(CONFIG_LIST_SEPARATOR) + .withKeyValueSeparator(ADDRESS_SEPARATOR) + .join(addressMap); + } + + public static Map decodeRpcAddressMap(String str) { + Map map = new HashMap<>(); + for (String entry : str.trim().split(CONFIG_LIST_SEPARATOR)) { + String[] pair = entry.split(ADDRESS_SEPARATOR); + map.put(pair[0], ConnectAddress.build(pair[1])); } + return map; + } + + public static String getMasterServiceName(String clusterId) { + return clusterId + K8SConstants.SERVICE_NAME_SUFFIX; + } + + public static String getMasterClientServiceName(String clusterId) { + return clusterId + K8SConstants.CLIENT_SERVICE_NAME_SUFFIX; + } + public static String getDriverServiceName(String clusterId, int driverIndex) { + return clusterId + DRIVER_SERVICE_NAME_SUFFIX + driverIndex; + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/watcher/KubernetesPodWatcher.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/watcher/KubernetesPodWatcher.java index 24b98142d..69c9ede2b 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/watcher/KubernetesPodWatcher.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/main/java/org/apache/geaflow/cluster/k8s/watcher/KubernetesPodWatcher.java @@ -22,10 +22,6 @@ import static org.apache.geaflow.cluster.k8s.config.KubernetesConfigKeys.WATCHER_CHECK_INTERVAL; import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.CLUSTER_ID; -import io.fabric8.kubernetes.api.model.Pod; -import io.fabric8.kubernetes.client.Watch; -import io.fabric8.kubernetes.client.Watcher; -import io.fabric8.kubernetes.client.Watcher.Action; import java.util.HashMap; import java.util.Map; import java.util.concurrent.Executors; @@ -33,6 +29,7 @@ import java.util.concurrent.TimeUnit; import java.util.function.BiConsumer; import java.util.function.Consumer; + import org.apache.geaflow.cluster.k8s.clustermanager.GeaflowKubeClient; import org.apache.geaflow.cluster.k8s.config.K8SConstants; import org.apache.geaflow.cluster.k8s.handler.IPodEventHandler; @@ -44,79 +41,92 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import io.fabric8.kubernetes.api.model.Pod; +import io.fabric8.kubernetes.client.Watch; +import io.fabric8.kubernetes.client.Watcher; +import io.fabric8.kubernetes.client.Watcher.Action; + public class KubernetesPodWatcher { - private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesPodWatcher.class); + private static final Logger LOGGER = LoggerFactory.getLogger(KubernetesPodWatcher.class); - private Watch watcher; - private final int checkInterval; - private volatile boolean watcherClosed; - private final Map> eventHandlerMap; - private final GeaflowKubeClient kubernetesClient; - private final Map labels; - private final ScheduledExecutorService executorService; + private Watch watcher; + private final int checkInterval; + private volatile boolean watcherClosed; + private final Map> eventHandlerMap; + private final GeaflowKubeClient kubernetesClient; + private final Map labels; + private final ScheduledExecutorService executorService; - public KubernetesPodWatcher(Configuration config) { - this.watcherClosed = true; - this.checkInterval = config.getInteger(WATCHER_CHECK_INTERVAL); - this.kubernetesClient = new GeaflowKubeClient(config); + public KubernetesPodWatcher(Configuration config) { + this.watcherClosed = true; + this.checkInterval = config.getInteger(WATCHER_CHECK_INTERVAL); + this.kubernetesClient = new GeaflowKubeClient(config); - this.labels = new HashMap<>(); - this.labels.put(K8SConstants.LABEL_APP_KEY, config.getString(CLUSTER_ID)); + this.labels = new HashMap<>(); + this.labels.put(K8SConstants.LABEL_APP_KEY, config.getString(CLUSTER_ID)); - this.executorService = Executors.newSingleThreadScheduledExecutor( + this.executorService = + Executors.newSingleThreadScheduledExecutor( ThreadUtil.namedThreadFactory(true, "cluster-watcher")); - PodHandlerRegistry registry = PodHandlerRegistry.getInstance(config); - this.eventHandlerMap = registry.getHandlerMap(); - } + PodHandlerRegistry registry = PodHandlerRegistry.getInstance(config); + this.eventHandlerMap = registry.getHandlerMap(); + } - public void start() { - createAndStartPodsWatcher(); - } + public void start() { + createAndStartPodsWatcher(); + } - public void close() { - executorService.shutdown(); - } + public void close() { + executorService.shutdown(); + } - private void createAndStartPodsWatcher() { - BiConsumer eventHandler = this::handlePodMessage; - Consumer exceptionHandler = (exception) -> { - watcherClosed = true; - LOGGER.warn("watch exception: {}", exception.getMessage(), exception); + private void createAndStartPodsWatcher() { + BiConsumer eventHandler = this::handlePodMessage; + Consumer exceptionHandler = + (exception) -> { + watcherClosed = true; + LOGGER.warn("watch exception: {}", exception.getMessage(), exception); }; - ScheduledExecutorService executorService = Executors.newSingleThreadScheduledExecutor( + ScheduledExecutorService executorService = + Executors.newSingleThreadScheduledExecutor( ThreadUtil.namedThreadFactory(true, "watcher-creator")); - executorService.scheduleAtFixedRate(() -> { - if (watcherClosed) { - if (watcher != null) { - watcher.close(); - } - watcher = kubernetesClient.createPodsWatcher(labels, eventHandler, - exceptionHandler); - if (watcher != null) { - watcherClosed = false; - } + executorService.scheduleAtFixedRate( + () -> { + if (watcherClosed) { + if (watcher != null) { + watcher.close(); } - }, 0, checkInterval, TimeUnit.SECONDS); - } + watcher = kubernetesClient.createPodsWatcher(labels, eventHandler, exceptionHandler); + if (watcher != null) { + watcherClosed = false; + } + } + }, + 0, + checkInterval, + TimeUnit.SECONDS); + } - private void handlePodMessage(Watcher.Action action, Pod pod) { - String componentId = KubernetesUtils.extractComponentId(pod); - if (componentId == null) { - LOGGER.warn("Unknown pod {} with labels:{} event:{}", pod.getMetadata().getName(), - pod.getMetadata().getLabels(), action); - return; - } - String component = KubernetesUtils.extractComponent(pod); - if (K8SConstants.LABEL_COMPONENT_CLIENT.equals(component)) { - return; - } - if (eventHandlerMap.containsKey(action)) { - eventHandlerMap.get(action).forEach((kind, handler) -> handler.handle(pod)); - } else { - LOGGER.info("Skip {} event for pod {}", action, pod.getMetadata().getName()); - } + private void handlePodMessage(Watcher.Action action, Pod pod) { + String componentId = KubernetesUtils.extractComponentId(pod); + if (componentId == null) { + LOGGER.warn( + "Unknown pod {} with labels:{} event:{}", + pod.getMetadata().getName(), + pod.getMetadata().getLabels(), + action); + return; } - + String component = KubernetesUtils.extractComponent(pod); + if (K8SConstants.LABEL_COMPONENT_CLIENT.equals(component)) { + return; + } + if (eventHandlerMap.containsKey(action)) { + eventHandlerMap.get(action).forEach((kind, handler) -> handler.handle(pod)); + } else { + LOGGER.info("Skip {} event for pod {}", action, pod.getMetadata().getName()); + } + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/client/KubernetesEnvironmentTest.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/client/KubernetesEnvironmentTest.java index b57221698..5555d4bb7 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/client/KubernetesEnvironmentTest.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/client/KubernetesEnvironmentTest.java @@ -27,11 +27,10 @@ public class KubernetesEnvironmentTest { - @Test - public void testLoad() { - Environment env = EnvironmentFactory.onK8SEnvironment(); - Assert.assertTrue(env instanceof KubernetesEnvironment); - Assert.assertEquals(env.getEnvType(), EnvType.K8S); - } - + @Test + public void testLoad() { + Environment env = EnvironmentFactory.onK8SEnvironment(); + Assert.assertTrue(env instanceof KubernetesEnvironment); + Assert.assertEquals(env.getEnvType(), EnvType.K8S); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/client/KubernetesJobClientTest.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/client/KubernetesJobClientTest.java index eb4339274..ce6fb5ae2 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/client/KubernetesJobClientTest.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/client/KubernetesJobClientTest.java @@ -26,11 +26,8 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.CLIENT_DISK_GB; import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.MASTER_MEMORY_MB; -import io.fabric8.kubernetes.api.model.ConfigMap; -import io.fabric8.kubernetes.api.model.Pod; -import io.fabric8.kubernetes.client.KubernetesClient; -import io.fabric8.kubernetes.client.server.mock.KubernetesServer; import java.util.List; + import org.apache.geaflow.cluster.k8s.clustermanager.GeaflowKubeClient; import org.apache.geaflow.cluster.k8s.config.K8SConstants; import org.apache.geaflow.cluster.k8s.config.KubernetesConfigKeys; @@ -42,63 +39,67 @@ import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; +import io.fabric8.kubernetes.api.model.ConfigMap; +import io.fabric8.kubernetes.api.model.Pod; +import io.fabric8.kubernetes.client.KubernetesClient; +import io.fabric8.kubernetes.client.server.mock.KubernetesServer; + public class KubernetesJobClientTest { - private Configuration jobConf; - - private static final int MASTER_MEMORY = 128; - private static final String CLUSTER_ID = "geaflow-cluster-1"; - private static final String CONF_DIR_IN_IMAGE = "/geaflow/conf"; - private static final String MASTER_CONTAINER_NAME = "geaflow-master"; - - private final KubernetesServer server = new KubernetesServer(false, true); - private KubernetesJobClient jobClient; - private KubernetesClient kubernetesClient; - private GeaflowKubeClient geaflowKubeClient; - - @BeforeMethod - public void setUp() { - jobConf = new Configuration(); - jobConf.put(KubernetesConfigKeys.CONF_DIR.getKey(), CONF_DIR_IN_IMAGE); - jobConf.put(ExecutionConfigKeys.CLUSTER_ID.getKey(), CLUSTER_ID); - jobConf.put(KubernetesMasterParam.MASTER_CONTAINER_NAME, MASTER_CONTAINER_NAME); - jobConf.put(MASTER_MEMORY_MB.getKey(), String.valueOf(MASTER_MEMORY)); - jobConf.setMasterId(CLUSTER_ID + "_MASTER"); - jobConf.put(CLIENT_DISK_GB.getKey(), "5"); - - server.before(); - kubernetesClient = server.getClient(); - geaflowKubeClient = new GeaflowKubeClient(kubernetesClient, jobConf); - jobClient = new KubernetesJobClient(jobConf.getConfigMap(), geaflowKubeClient); - } - - @AfterMethod - public void destroy() { - jobClient.stopJob(); - server.after(); - } - - @Test - public void testCreateJobClient() { - jobClient.submitJob(); - SleepUtils.sleepSecond(2); - - // check config map - String configMapName = CLUSTER_ID + K8SConstants.CLIENT_CONFIG_MAP_SUFFIX; - ConfigMap configMap = kubernetesClient.configMaps().withName(configMapName).get(); - assertNotNull(configMap); - assertEquals(1, configMap.getData().size()); - assertTrue(configMap.getData().containsKey(K8SConstants.ENV_CONFIG_FILE)); - - // check pod reference - Pod pod = - kubernetesClient.pods().withName(CLUSTER_ID + CLIENT_NAME_SUFFIX).get(); - assertNotNull(configMap.getMetadata().getOwnerReferences()); - assertEquals(configMapName, pod.getMetadata().getOwnerReferences().get(0).getName()); - assertEquals(pod.getSpec().getContainers().size(), 1); - - List jobPods = jobClient.getJobPods(); - Pod jobPod = jobPods.get(0); - assertEquals(jobPod.getSpec().getContainers().size(), 1); - } + private Configuration jobConf; + + private static final int MASTER_MEMORY = 128; + private static final String CLUSTER_ID = "geaflow-cluster-1"; + private static final String CONF_DIR_IN_IMAGE = "/geaflow/conf"; + private static final String MASTER_CONTAINER_NAME = "geaflow-master"; + + private final KubernetesServer server = new KubernetesServer(false, true); + private KubernetesJobClient jobClient; + private KubernetesClient kubernetesClient; + private GeaflowKubeClient geaflowKubeClient; + + @BeforeMethod + public void setUp() { + jobConf = new Configuration(); + jobConf.put(KubernetesConfigKeys.CONF_DIR.getKey(), CONF_DIR_IN_IMAGE); + jobConf.put(ExecutionConfigKeys.CLUSTER_ID.getKey(), CLUSTER_ID); + jobConf.put(KubernetesMasterParam.MASTER_CONTAINER_NAME, MASTER_CONTAINER_NAME); + jobConf.put(MASTER_MEMORY_MB.getKey(), String.valueOf(MASTER_MEMORY)); + jobConf.setMasterId(CLUSTER_ID + "_MASTER"); + jobConf.put(CLIENT_DISK_GB.getKey(), "5"); + + server.before(); + kubernetesClient = server.getClient(); + geaflowKubeClient = new GeaflowKubeClient(kubernetesClient, jobConf); + jobClient = new KubernetesJobClient(jobConf.getConfigMap(), geaflowKubeClient); + } + + @AfterMethod + public void destroy() { + jobClient.stopJob(); + server.after(); + } + + @Test + public void testCreateJobClient() { + jobClient.submitJob(); + SleepUtils.sleepSecond(2); + + // check config map + String configMapName = CLUSTER_ID + K8SConstants.CLIENT_CONFIG_MAP_SUFFIX; + ConfigMap configMap = kubernetesClient.configMaps().withName(configMapName).get(); + assertNotNull(configMap); + assertEquals(1, configMap.getData().size()); + assertTrue(configMap.getData().containsKey(K8SConstants.ENV_CONFIG_FILE)); + + // check pod reference + Pod pod = kubernetesClient.pods().withName(CLUSTER_ID + CLIENT_NAME_SUFFIX).get(); + assertNotNull(configMap.getMetadata().getOwnerReferences()); + assertEquals(configMapName, pod.getMetadata().getOwnerReferences().get(0).getName()); + assertEquals(pod.getSpec().getContainers().size(), 1); + + List jobPods = jobClient.getJobPods(); + Pod jobPod = jobPods.get(0); + assertEquals(jobPod.getSpec().getContainers().size(), 1); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/client/KubernetesJobSubmitterTest.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/client/KubernetesJobSubmitterTest.java index da33a3ee0..06bf79612 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/client/KubernetesJobSubmitterTest.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/client/KubernetesJobSubmitterTest.java @@ -21,30 +21,32 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.CLUSTER_ID; -import com.alibaba.fastjson.JSON; -import io.fabric8.kubernetes.client.KubernetesClientException; import java.util.HashMap; import java.util.Map; + import org.testng.annotations.Test; +import com.alibaba.fastjson.JSON; + +import io.fabric8.kubernetes.client.KubernetesClientException; + public class KubernetesJobSubmitterTest { - @Test(expectedExceptions = NoSuchMethodException.class) - public void testSubmit() throws Throwable { - KubernetesJobSubmitter submitter = new KubernetesJobSubmitter(); - String[] args = new String[]{"start", this.getClass().getCanonicalName(), "{}"}; - submitter.submitJob(args); - } - - @Test(expectedExceptions = KubernetesClientException.class) - public void testStop() throws Throwable { - KubernetesJobSubmitter submitter = new KubernetesJobSubmitter(); - Map jobConfig = new HashMap<>(); - jobConfig.put(CLUSTER_ID.getKey(), "124"); - Map> config = new HashMap<>(); - config.put("job", jobConfig); - String[] args = new String[]{"stop", JSON.toJSONString(config)}; - submitter.stopJob(args); - } + @Test(expectedExceptions = NoSuchMethodException.class) + public void testSubmit() throws Throwable { + KubernetesJobSubmitter submitter = new KubernetesJobSubmitter(); + String[] args = new String[] {"start", this.getClass().getCanonicalName(), "{}"}; + submitter.submitJob(args); + } + @Test(expectedExceptions = KubernetesClientException.class) + public void testStop() throws Throwable { + KubernetesJobSubmitter submitter = new KubernetesJobSubmitter(); + Map jobConfig = new HashMap<>(); + jobConfig.put(CLUSTER_ID.getKey(), "124"); + Map> config = new HashMap<>(); + config.put("job", jobConfig); + String[] args = new String[] {"stop", JSON.toJSONString(config)}; + submitter.stopJob(args); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/clustermanager/KubernetesClusterManagerTest.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/clustermanager/KubernetesClusterManagerTest.java index bde4f2050..30a49078b 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/clustermanager/KubernetesClusterManagerTest.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/clustermanager/KubernetesClusterManagerTest.java @@ -35,22 +35,11 @@ import static org.apache.geaflow.cluster.k8s.config.KubernetesMasterParam.MASTER_NODE_SELECTOR; import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.MASTER_MEMORY_MB; -import io.fabric8.kubernetes.api.model.ConfigMap; -import io.fabric8.kubernetes.api.model.Container; -import io.fabric8.kubernetes.api.model.ContainerPort; -import io.fabric8.kubernetes.api.model.ContainerPortBuilder; -import io.fabric8.kubernetes.api.model.EnvVar; -import io.fabric8.kubernetes.api.model.EnvVarBuilder; -import io.fabric8.kubernetes.api.model.Pod; -import io.fabric8.kubernetes.api.model.Service; -import io.fabric8.kubernetes.api.model.apps.Deployment; -import io.fabric8.kubernetes.client.KubernetesClient; -import io.fabric8.kubernetes.client.server.mock.KubernetesServer; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; -import junit.framework.TestCase; + import org.apache.geaflow.cluster.clustermanager.ClusterContext; import org.apache.geaflow.cluster.k8s.config.K8SConstants; import org.apache.geaflow.cluster.k8s.config.KubernetesConfig; @@ -67,301 +56,331 @@ import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; -public class KubernetesClusterManagerTest { - - private Configuration jobConf; - - private static final int MASTER_MEMORY = 128; - private static final String MASTER_COMPONENT = "master"; - private static final String CLUSTER_ID = "geaflow-cluster-1"; - private static final String SERVICE_ACCOUNT = "geaflow"; - private static final String CONF_DIR_IN_IMAGE = "/geaflow/conf"; - private static final String MASTER_CONTAINER_NAME = "geaflow-master"; - - private KubernetesClient kubernetesClient; - private GeaflowKubeClient geaflowKubeClient; - private KubernetesServer server = new KubernetesServer(false, true); - private KubernetesClusterManager kubernetesClusterManager; - - @BeforeMethod - public void setUp() { - jobConf = new Configuration(); - jobConf.put(ExecutionConfigKeys.RUN_LOCAL_MODE, Boolean.TRUE.toString()); - jobConf.put(KubernetesConfigKeys.CONF_DIR.getKey(), CONF_DIR_IN_IMAGE); - jobConf.put(ExecutionConfigKeys.CLUSTER_ID.getKey(), CLUSTER_ID); - jobConf.put(KubernetesMasterParam.MASTER_CONTAINER_NAME, MASTER_CONTAINER_NAME); - jobConf.put(MASTER_MEMORY_MB.getKey(), String.valueOf(MASTER_MEMORY)); - jobConf.put(MATCH_EXPRESSION_LIST, "key1:In:value1,key2:In:-"); - jobConf.setMasterId(CLUSTER_ID + "_MASTER"); - - server.before(); - kubernetesClient = server.getClient(); - geaflowKubeClient = new GeaflowKubeClient(kubernetesClient, jobConf); - - kubernetesClusterManager = new KubernetesClusterManager(); - ClusterContext context = new ClusterContext(jobConf); - kubernetesClusterManager.init(context, geaflowKubeClient); - } +import io.fabric8.kubernetes.api.model.ConfigMap; +import io.fabric8.kubernetes.api.model.Container; +import io.fabric8.kubernetes.api.model.ContainerPort; +import io.fabric8.kubernetes.api.model.ContainerPortBuilder; +import io.fabric8.kubernetes.api.model.EnvVar; +import io.fabric8.kubernetes.api.model.EnvVarBuilder; +import io.fabric8.kubernetes.api.model.Pod; +import io.fabric8.kubernetes.api.model.Service; +import io.fabric8.kubernetes.api.model.apps.Deployment; +import io.fabric8.kubernetes.client.KubernetesClient; +import io.fabric8.kubernetes.client.server.mock.KubernetesServer; +import junit.framework.TestCase; - @AfterMethod - public void destroy() { - server.after(); - } +public class KubernetesClusterManagerTest { - @Test - public void testCreateMasterContainer() { - // Set environment for master - String envName = "env-a"; - String envValue = "value-a"; - jobConf.put(MASTER_ENV_PREFIX + envName, envValue); - - Container container = kubernetesClusterManager - .createMasterContainer(DockerNetworkType.BRIDGE); - - assertTrue(container.getCommand().isEmpty()); - Optional commandEnv = container.getEnv().stream() - .filter(e -> e.getName().equals(K8SConstants.ENV_START_COMMAND)).findAny(); - assertNotNull(commandEnv.get()); - assertNotNull(commandEnv.get().getValue()); - assertNotNull(container.getArgs()); - assertEquals(MASTER_CONTAINER_NAME, container.getName()); - ContainerPort httpPort = new ContainerPortBuilder().withContainerPort(8090) - .withName(K8SConstants.HTTP_PORT).withProtocol("TCP").build(); - assertTrue(container.getPorts().contains(httpPort)); - assertEquals(MASTER_MEMORY << 20, - Integer.parseInt(container.getResources().getLimits().get("memory").getAmount())); - - // Check environment - assertTrue(container.getEnv() + private Configuration jobConf; + + private static final int MASTER_MEMORY = 128; + private static final String MASTER_COMPONENT = "master"; + private static final String CLUSTER_ID = "geaflow-cluster-1"; + private static final String SERVICE_ACCOUNT = "geaflow"; + private static final String CONF_DIR_IN_IMAGE = "/geaflow/conf"; + private static final String MASTER_CONTAINER_NAME = "geaflow-master"; + + private KubernetesClient kubernetesClient; + private GeaflowKubeClient geaflowKubeClient; + private KubernetesServer server = new KubernetesServer(false, true); + private KubernetesClusterManager kubernetesClusterManager; + + @BeforeMethod + public void setUp() { + jobConf = new Configuration(); + jobConf.put(ExecutionConfigKeys.RUN_LOCAL_MODE, Boolean.TRUE.toString()); + jobConf.put(KubernetesConfigKeys.CONF_DIR.getKey(), CONF_DIR_IN_IMAGE); + jobConf.put(ExecutionConfigKeys.CLUSTER_ID.getKey(), CLUSTER_ID); + jobConf.put(KubernetesMasterParam.MASTER_CONTAINER_NAME, MASTER_CONTAINER_NAME); + jobConf.put(MASTER_MEMORY_MB.getKey(), String.valueOf(MASTER_MEMORY)); + jobConf.put(MATCH_EXPRESSION_LIST, "key1:In:value1,key2:In:-"); + jobConf.setMasterId(CLUSTER_ID + "_MASTER"); + + server.before(); + kubernetesClient = server.getClient(); + geaflowKubeClient = new GeaflowKubeClient(kubernetesClient, jobConf); + + kubernetesClusterManager = new KubernetesClusterManager(); + ClusterContext context = new ClusterContext(jobConf); + kubernetesClusterManager.init(context, geaflowKubeClient); + } + + @AfterMethod + public void destroy() { + server.after(); + } + + @Test + public void testCreateMasterContainer() { + // Set environment for master + String envName = "env-a"; + String envValue = "value-a"; + jobConf.put(MASTER_ENV_PREFIX + envName, envValue); + + Container container = kubernetesClusterManager.createMasterContainer(DockerNetworkType.BRIDGE); + + assertTrue(container.getCommand().isEmpty()); + Optional commandEnv = + container.getEnv().stream() + .filter(e -> e.getName().equals(K8SConstants.ENV_START_COMMAND)) + .findAny(); + assertNotNull(commandEnv.get()); + assertNotNull(commandEnv.get().getValue()); + assertNotNull(container.getArgs()); + assertEquals(MASTER_CONTAINER_NAME, container.getName()); + ContainerPort httpPort = + new ContainerPortBuilder() + .withContainerPort(8090) + .withName(K8SConstants.HTTP_PORT) + .withProtocol("TCP") + .build(); + assertTrue(container.getPorts().contains(httpPort)); + assertEquals( + MASTER_MEMORY << 20, + Integer.parseInt(container.getResources().getLimits().get("memory").getAmount())); + + // Check environment + assertTrue( + container + .getEnv() .contains(new EnvVarBuilder().withName(envName).withValue(envValue).build())); - } - - @Test - public void testCreateMaster() { - Map labels = new HashMap<>(); - labels.put("app", CLUSTER_ID); - labels.put("component", MASTER_COMPONENT); - kubernetesClusterManager.createMaster(CLUSTER_ID, labels); - - // check config map - ConfigMap configMap = kubernetesClient.configMaps() - .withName(CLUSTER_ID + K8SConstants.MASTER_CONFIG_MAP_SUFFIX).get(); - assertNotNull(configMap); - assertEquals(1, configMap.getData().size()); - assertTrue(configMap.getData().containsKey(K8SConstants.ENV_CONFIG_FILE)); - - // check replication controller - Deployment deployment = geaflowKubeClient.getDeployment(CLUSTER_ID + MASTER_RS_NAME_SUFFIX); - assertNotNull(deployment); - assertEquals(1, deployment.getSpec().getReplicas().intValue()); - assertEquals(3, deployment.getSpec().getSelector().getMatchLabels().size()); - assertEquals(MASTER_CONTAINER_NAME, - deployment.getSpec().getTemplate().getSpec().getContainers().get(0).getName()); - TestCase.assertEquals(K8SConstants.GEAFLOW_CONF_VOLUME, - deployment.getSpec().getTemplate().getSpec().getVolumes().get(0).getName()); - assertEquals(SERVICE_ACCOUNT, - deployment.getSpec().getTemplate().getSpec().getServiceAccountName()); - - // check service - Service service = geaflowKubeClient.getService(CLUSTER_ID + SERVICE_NAME_SUFFIX); - assertNotNull(service); - assertEquals(labels, service.getSpec().getSelector()); - assertEquals(1, service.getSpec().getPorts().size()); - - // check owner reference - String serviceName = service.getMetadata().getName(); - assertNotNull(configMap.getMetadata().getOwnerReferences()); - assertEquals(serviceName, configMap.getMetadata().getOwnerReferences().get(0).getName()); - assertNotNull(deployment.getMetadata().getOwnerReferences().get(0)); - assertEquals(serviceName, deployment.getMetadata().getOwnerReferences().get(0).getName()); - } - - @Test(timeOut = 30000) - public void testKillCluster() { - Map labels = new HashMap<>(); - labels.put("app", CLUSTER_ID); - labels.put("component", MASTER_COMPONENT); - kubernetesClusterManager.createMaster(CLUSTER_ID, labels); - - Service service = geaflowKubeClient.getService(CLUSTER_ID + SERVICE_NAME_SUFFIX); - assertNotNull(service); - - ConfigMap configMap = geaflowKubeClient - .getConfigMap(CLUSTER_ID + K8SConstants.MASTER_CONFIG_MAP_SUFFIX); - assertNotNull(configMap); - - Deployment rc = geaflowKubeClient.getDeployment(CLUSTER_ID + MASTER_RS_NAME_SUFFIX); - assertNotNull(rc); - - kubernetesClusterManager.close(); - } - - @Test(timeOut = 30000) - public void testMasterUserLabels() { - jobConf.put(POD_USER_LABELS.getKey(), "l1:test,l2:hello"); - Map labels = new HashMap<>(); - labels.put("app", CLUSTER_ID); - labels.put("component", MASTER_COMPONENT); - labels.put(LABEL_COMPONENT_ID_KEY, String.valueOf(DEFAULT_MASTER_ID)); - labels.putAll(KubernetesUtils.getPairsConf(jobConf, POD_USER_LABELS)); - assertEquals(5, labels.size()); - assertEquals("test", labels.get("l1")); - assertEquals("hello", labels.get("l2")); - kubernetesClusterManager.createMaster(CLUSTER_ID, labels); - - // check replication controller - Deployment rc = kubernetesClient.apps().deployments() - .withName(CLUSTER_ID + MASTER_RS_NAME_SUFFIX).get(); - assertNotNull(rc); - assertEquals(labels, rc.getSpec().getSelector().getMatchLabels()); - - // check service - Service service = kubernetesClient.services().withName(CLUSTER_ID + SERVICE_NAME_SUFFIX) + } + + @Test + public void testCreateMaster() { + Map labels = new HashMap<>(); + labels.put("app", CLUSTER_ID); + labels.put("component", MASTER_COMPONENT); + kubernetesClusterManager.createMaster(CLUSTER_ID, labels); + + // check config map + ConfigMap configMap = + kubernetesClient + .configMaps() + .withName(CLUSTER_ID + K8SConstants.MASTER_CONFIG_MAP_SUFFIX) .get(); - assertNotNull(service); - assertEquals(labels, service.getSpec().getSelector()); - } - - @Test(timeOut = 30000) - public void testStartMasterWithNodePort() { - jobConf.put(SERVICE_EXPOSED_TYPE.getKey(), "NODE_PORT"); - Map labels = new HashMap<>(); - labels.put("app", CLUSTER_ID); - labels.put("component", MASTER_COMPONENT); - kubernetesClusterManager.createMaster(CLUSTER_ID, labels); - - // check service - Service service = geaflowKubeClient.getService(CLUSTER_ID + SERVICE_NAME_SUFFIX); - assertNotNull(service); - assertEquals(1, service.getSpec().getPorts().size()); - // Check client service - Service clientService = geaflowKubeClient - .getService(KubernetesUtils.getMasterClientServiceName(CLUSTER_ID)); - assertNotNull(service); - assertEquals(1, clientService.getSpec().getPorts().size()); - } - - @Test(timeOut = 30000) - public void testStartMasterWithNodeSelector() { - jobConf.put(MASTER_NODE_SELECTOR, "env:production,tier:frontend"); - kubernetesClusterManager.createMaster(CLUSTER_ID, new HashMap<>()); - - // check node selector - Deployment rc = kubernetesClient.apps().deployments() - .withName(CLUSTER_ID + MASTER_RS_NAME_SUFFIX).get(); - assertNotNull(rc); - Map nodeSelector = rc.getSpec().getTemplate().getSpec().getNodeSelector(); - assertEquals(2, nodeSelector.size()); - assertEquals("production", nodeSelector.get("env")); - assertEquals("frontend", nodeSelector.get("tier")); - } - - @Test - public void testHostNetwork() { - jobConf - .put(DOCKER_NETWORK_TYPE.getKey(), KubernetesConfig.DockerNetworkType.HOST.toString()); - kubernetesClusterManager.createMaster(CLUSTER_ID, new HashMap<>()); - - // check replication controller - Deployment rc = geaflowKubeClient.getDeployment(CLUSTER_ID + MASTER_RS_NAME_SUFFIX); - assertNotNull(rc); - assertEquals(1, rc.getSpec().getReplicas().intValue()); - assertTrue(rc.getSpec().getTemplate().getSpec().getHostNetwork()); - } - - @Test - public void testCreateWorkerPod() { - Map labels = new HashMap<>(); - labels.put("app", CLUSTER_ID); - labels.put("component", MASTER_COMPONENT); - kubernetesClusterManager.createMaster(CLUSTER_ID, labels); - - int containerId = 1; - KubernetesClusterManager kubernetesClusterManager2 = new KubernetesClusterManager(); - jobConf.put(CLUSTER_START_TIME, String.valueOf(System.currentTimeMillis())); - ClusterContext context = new ClusterContext(jobConf); - kubernetesClusterManager2.init(context, geaflowKubeClient); - kubernetesClusterManager2.createNewContainer(containerId, false); - // check pod label - verifyWorkerPodSize(1); - // restart pod - kubernetesClusterManager2.restartContainer(containerId); - verifyWorkerPodSize(1); - } - - @Test - public void testClusterFailover() { - - int masterId = 0; - int containerId_1 = 1; - int containerId_2 = 2; - int driverId = 3; - Map containerIds = new HashMap() {{ + assertNotNull(configMap); + assertEquals(1, configMap.getData().size()); + assertTrue(configMap.getData().containsKey(K8SConstants.ENV_CONFIG_FILE)); + + // check replication controller + Deployment deployment = geaflowKubeClient.getDeployment(CLUSTER_ID + MASTER_RS_NAME_SUFFIX); + assertNotNull(deployment); + assertEquals(1, deployment.getSpec().getReplicas().intValue()); + assertEquals(3, deployment.getSpec().getSelector().getMatchLabels().size()); + assertEquals( + MASTER_CONTAINER_NAME, + deployment.getSpec().getTemplate().getSpec().getContainers().get(0).getName()); + TestCase.assertEquals( + K8SConstants.GEAFLOW_CONF_VOLUME, + deployment.getSpec().getTemplate().getSpec().getVolumes().get(0).getName()); + assertEquals( + SERVICE_ACCOUNT, deployment.getSpec().getTemplate().getSpec().getServiceAccountName()); + + // check service + Service service = geaflowKubeClient.getService(CLUSTER_ID + SERVICE_NAME_SUFFIX); + assertNotNull(service); + assertEquals(labels, service.getSpec().getSelector()); + assertEquals(1, service.getSpec().getPorts().size()); + + // check owner reference + String serviceName = service.getMetadata().getName(); + assertNotNull(configMap.getMetadata().getOwnerReferences()); + assertEquals(serviceName, configMap.getMetadata().getOwnerReferences().get(0).getName()); + assertNotNull(deployment.getMetadata().getOwnerReferences().get(0)); + assertEquals(serviceName, deployment.getMetadata().getOwnerReferences().get(0).getName()); + } + + @Test(timeOut = 30000) + public void testKillCluster() { + Map labels = new HashMap<>(); + labels.put("app", CLUSTER_ID); + labels.put("component", MASTER_COMPONENT); + kubernetesClusterManager.createMaster(CLUSTER_ID, labels); + + Service service = geaflowKubeClient.getService(CLUSTER_ID + SERVICE_NAME_SUFFIX); + assertNotNull(service); + + ConfigMap configMap = + geaflowKubeClient.getConfigMap(CLUSTER_ID + K8SConstants.MASTER_CONFIG_MAP_SUFFIX); + assertNotNull(configMap); + + Deployment rc = geaflowKubeClient.getDeployment(CLUSTER_ID + MASTER_RS_NAME_SUFFIX); + assertNotNull(rc); + + kubernetesClusterManager.close(); + } + + @Test(timeOut = 30000) + public void testMasterUserLabels() { + jobConf.put(POD_USER_LABELS.getKey(), "l1:test,l2:hello"); + Map labels = new HashMap<>(); + labels.put("app", CLUSTER_ID); + labels.put("component", MASTER_COMPONENT); + labels.put(LABEL_COMPONENT_ID_KEY, String.valueOf(DEFAULT_MASTER_ID)); + labels.putAll(KubernetesUtils.getPairsConf(jobConf, POD_USER_LABELS)); + assertEquals(5, labels.size()); + assertEquals("test", labels.get("l1")); + assertEquals("hello", labels.get("l2")); + kubernetesClusterManager.createMaster(CLUSTER_ID, labels); + + // check replication controller + Deployment rc = + kubernetesClient.apps().deployments().withName(CLUSTER_ID + MASTER_RS_NAME_SUFFIX).get(); + assertNotNull(rc); + assertEquals(labels, rc.getSpec().getSelector().getMatchLabels()); + + // check service + Service service = kubernetesClient.services().withName(CLUSTER_ID + SERVICE_NAME_SUFFIX).get(); + assertNotNull(service); + assertEquals(labels, service.getSpec().getSelector()); + } + + @Test(timeOut = 30000) + public void testStartMasterWithNodePort() { + jobConf.put(SERVICE_EXPOSED_TYPE.getKey(), "NODE_PORT"); + Map labels = new HashMap<>(); + labels.put("app", CLUSTER_ID); + labels.put("component", MASTER_COMPONENT); + kubernetesClusterManager.createMaster(CLUSTER_ID, labels); + + // check service + Service service = geaflowKubeClient.getService(CLUSTER_ID + SERVICE_NAME_SUFFIX); + assertNotNull(service); + assertEquals(1, service.getSpec().getPorts().size()); + // Check client service + Service clientService = + geaflowKubeClient.getService(KubernetesUtils.getMasterClientServiceName(CLUSTER_ID)); + assertNotNull(service); + assertEquals(1, clientService.getSpec().getPorts().size()); + } + + @Test(timeOut = 30000) + public void testStartMasterWithNodeSelector() { + jobConf.put(MASTER_NODE_SELECTOR, "env:production,tier:frontend"); + kubernetesClusterManager.createMaster(CLUSTER_ID, new HashMap<>()); + + // check node selector + Deployment rc = + kubernetesClient.apps().deployments().withName(CLUSTER_ID + MASTER_RS_NAME_SUFFIX).get(); + assertNotNull(rc); + Map nodeSelector = rc.getSpec().getTemplate().getSpec().getNodeSelector(); + assertEquals(2, nodeSelector.size()); + assertEquals("production", nodeSelector.get("env")); + assertEquals("frontend", nodeSelector.get("tier")); + } + + @Test + public void testHostNetwork() { + jobConf.put(DOCKER_NETWORK_TYPE.getKey(), KubernetesConfig.DockerNetworkType.HOST.toString()); + kubernetesClusterManager.createMaster(CLUSTER_ID, new HashMap<>()); + + // check replication controller + Deployment rc = geaflowKubeClient.getDeployment(CLUSTER_ID + MASTER_RS_NAME_SUFFIX); + assertNotNull(rc); + assertEquals(1, rc.getSpec().getReplicas().intValue()); + assertTrue(rc.getSpec().getTemplate().getSpec().getHostNetwork()); + } + + @Test + public void testCreateWorkerPod() { + Map labels = new HashMap<>(); + labels.put("app", CLUSTER_ID); + labels.put("component", MASTER_COMPONENT); + kubernetesClusterManager.createMaster(CLUSTER_ID, labels); + + int containerId = 1; + KubernetesClusterManager kubernetesClusterManager2 = new KubernetesClusterManager(); + jobConf.put(CLUSTER_START_TIME, String.valueOf(System.currentTimeMillis())); + ClusterContext context = new ClusterContext(jobConf); + kubernetesClusterManager2.init(context, geaflowKubeClient); + kubernetesClusterManager2.createNewContainer(containerId, false); + // check pod label + verifyWorkerPodSize(1); + // restart pod + kubernetesClusterManager2.restartContainer(containerId); + verifyWorkerPodSize(1); + } + + @Test + public void testClusterFailover() { + + int masterId = 0; + int containerId_1 = 1; + int containerId_2 = 2; + int driverId = 3; + Map containerIds = + new HashMap() { + { put(containerId_1, "container1"); put(containerId_2, "container2"); - }}; - Map driverIds = new HashMap() {{ + } + }; + Map driverIds = + new HashMap() { + { put(0, "driver1"); - }}; - - // driver/container FO - Map labels = new HashMap<>(); - labels.put("app", CLUSTER_ID); - labels.put("component", MASTER_COMPONENT); - kubernetesClusterManager.createMaster(CLUSTER_ID, labels); - - KubernetesClusterManager kubernetesClusterManager2 = new KubernetesClusterManager(); - jobConf.put(CLUSTER_START_TIME, String.valueOf(System.currentTimeMillis())); - ClusterContext context = new ClusterContext(jobConf); - context.setContainerIds(containerIds); - context.setDriverIds(driverIds); - kubernetesClusterManager2.init(context, geaflowKubeClient); - kubernetesClusterManager2.createNewDriver(driverId, 0); - kubernetesClusterManager2.createNewContainer(containerId_1, false); - kubernetesClusterManager2.createNewContainer(containerId_2, false); - - // check pod label - verifyWorkerPodSize(2); - - // restart all pod - kubernetesClusterManager2.restartContainer(containerId_1); - verifyWorkerPodSize(2); - - // restart driver & containers - KubernetesClusterManager mock = mockClusterManager(); - mock.doFailover(masterId, new GeaflowHeartbeatException()); - Mockito.verify(mock, Mockito.times(1)).restartAllDrivers(); - Mockito.verify(mock, Mockito.times(1)).restartAllContainers(); - } - - private KubernetesClusterManager mockClusterManager() { - KubernetesClusterManager clusterManager = new KubernetesClusterManager(); - KubernetesClusterManager mockClusterManager = Mockito.spy(clusterManager); - Mockito.doNothing().when(mockClusterManager).restartAllDrivers(); - Mockito.doNothing().when(mockClusterManager).restartAllContainers(); - - ClusterContext context = new ClusterContext(jobConf); - mockClusterManager.init(context, geaflowKubeClient); - return mockClusterManager; + } + }; + + // driver/container FO + Map labels = new HashMap<>(); + labels.put("app", CLUSTER_ID); + labels.put("component", MASTER_COMPONENT); + kubernetesClusterManager.createMaster(CLUSTER_ID, labels); + + KubernetesClusterManager kubernetesClusterManager2 = new KubernetesClusterManager(); + jobConf.put(CLUSTER_START_TIME, String.valueOf(System.currentTimeMillis())); + ClusterContext context = new ClusterContext(jobConf); + context.setContainerIds(containerIds); + context.setDriverIds(driverIds); + kubernetesClusterManager2.init(context, geaflowKubeClient); + kubernetesClusterManager2.createNewDriver(driverId, 0); + kubernetesClusterManager2.createNewContainer(containerId_1, false); + kubernetesClusterManager2.createNewContainer(containerId_2, false); + + // check pod label + verifyWorkerPodSize(2); + + // restart all pod + kubernetesClusterManager2.restartContainer(containerId_1); + verifyWorkerPodSize(2); + + // restart driver & containers + KubernetesClusterManager mock = mockClusterManager(); + mock.doFailover(masterId, new GeaflowHeartbeatException()); + Mockito.verify(mock, Mockito.times(1)).restartAllDrivers(); + Mockito.verify(mock, Mockito.times(1)).restartAllContainers(); + } + + private KubernetesClusterManager mockClusterManager() { + KubernetesClusterManager clusterManager = new KubernetesClusterManager(); + KubernetesClusterManager mockClusterManager = Mockito.spy(clusterManager); + Mockito.doNothing().when(mockClusterManager).restartAllDrivers(); + Mockito.doNothing().when(mockClusterManager).restartAllContainers(); + + ClusterContext context = new ClusterContext(jobConf); + mockClusterManager.init(context, geaflowKubeClient); + return mockClusterManager; + } + + private void verifyWorkerPodSize(int size) { + verifyPodSize(size, K8SConstants.LABEL_COMPONENT_WORKER); + } + + private void verifyPodSize(int size, String componentKey) { + List pods = getPods(componentKey); + Assert.assertEquals(pods.size(), size); + if (size > 0) { + Assert.assertTrue( + pods.get(0).getMetadata().getLabels().containsKey(K8SConstants.LABEL_COMPONENT_ID_KEY)); } - - private void verifyWorkerPodSize(int size) { - verifyPodSize(size, K8SConstants.LABEL_COMPONENT_WORKER); - } - - private void verifyPodSize(int size, String componentKey) { - List pods = getPods(componentKey); - Assert.assertEquals(pods.size(), size); - if (size > 0) { - Assert.assertTrue(pods.get(0).getMetadata().getLabels().containsKey(K8SConstants.LABEL_COMPONENT_ID_KEY)); - } - } - - private List getPods(String componentKey) { - Map workerLabels = new HashMap<>(); - workerLabels.put(K8SConstants.LABEL_APP_KEY, CLUSTER_ID); - workerLabels.put(K8SConstants.LABEL_COMPONENT_KEY, componentKey); - return geaflowKubeClient.getPods(workerLabels).getItems(); - } - + } + + private List getPods(String componentKey) { + Map workerLabels = new HashMap<>(); + workerLabels.put(K8SConstants.LABEL_APP_KEY, CLUSTER_ID); + workerLabels.put(K8SConstants.LABEL_COMPONENT_KEY, componentKey); + return geaflowKubeClient.getPods(workerLabels).getItems(); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/config/KubernetesConfigTest.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/config/KubernetesConfigTest.java index 4e9d3392e..5a410559c 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/config/KubernetesConfigTest.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/config/KubernetesConfigTest.java @@ -28,26 +28,24 @@ public class KubernetesConfigTest { - @Test - public void testMasterUrl() { - Configuration configuration = new Configuration(); - Assert.assertEquals(KubernetesConfig.getClientMasterUrl(configuration), - MASTER_URL.getDefaultValue()); - - configuration.put(CLIENT_MASTER_URL, ""); - Assert.assertEquals(KubernetesConfig.getClientMasterUrl(configuration), - MASTER_URL.getDefaultValue()); - - configuration.put(CLIENT_MASTER_URL, "client"); - Assert.assertEquals(KubernetesConfig.getClientMasterUrl(configuration), - "client"); - } - - @Test - public void testJarDownloadPath() { - Configuration configuration = new Configuration(); - String path = KubernetesConfig.getJarDownloadPath(configuration); - Assert.assertEquals(path, "/home/admin/geaflow/tmp/jar"); - } - + @Test + public void testMasterUrl() { + Configuration configuration = new Configuration(); + Assert.assertEquals( + KubernetesConfig.getClientMasterUrl(configuration), MASTER_URL.getDefaultValue()); + + configuration.put(CLIENT_MASTER_URL, ""); + Assert.assertEquals( + KubernetesConfig.getClientMasterUrl(configuration), MASTER_URL.getDefaultValue()); + + configuration.put(CLIENT_MASTER_URL, "client"); + Assert.assertEquals(KubernetesConfig.getClientMasterUrl(configuration), "client"); + } + + @Test + public void testJarDownloadPath() { + Configuration configuration = new Configuration(); + String path = KubernetesConfig.getJarDownloadPath(configuration); + Assert.assertEquals(path, "/home/admin/geaflow/tmp/jar"); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesClientRunnerTest.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesClientRunnerTest.java index 84f4e9e95..5ee2d5bf8 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesClientRunnerTest.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesClientRunnerTest.java @@ -19,21 +19,21 @@ package org.apache.geaflow.cluster.k8s.entrypoint; -import com.alibaba.fastjson.JSON; import org.apache.geaflow.cluster.k8s.config.KubernetesConfigKeys; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.testng.annotations.Test; -public class KubernetesClientRunnerTest extends KubernetesTestBase { +import com.alibaba.fastjson.JSON; - @Test - public void testClientRunner() { - configuration.put(ExecutionConfigKeys.FO_TIMEOUT_MS, "2"); - configuration.put(KubernetesConfigKeys.USER_MAIN_CLASS, - KubernetesMockRunner.class.getCanonicalName()); - String clusterArgs = JSON.toJSONString(configuration.getConfigMap()); - KubernetesClientRunner clientRunner = new KubernetesClientRunner(configuration); - clientRunner.run(clusterArgs); - } +public class KubernetesClientRunnerTest extends KubernetesTestBase { + @Test + public void testClientRunner() { + configuration.put(ExecutionConfigKeys.FO_TIMEOUT_MS, "2"); + configuration.put( + KubernetesConfigKeys.USER_MAIN_CLASS, KubernetesMockRunner.class.getCanonicalName()); + String clusterArgs = JSON.toJSONString(configuration.getConfigMap()); + KubernetesClientRunner clientRunner = new KubernetesClientRunner(configuration); + clientRunner.run(clusterArgs); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesMasterRunnerTest.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesMasterRunnerTest.java index 0829c3ba6..d07711961 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesMasterRunnerTest.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesMasterRunnerTest.java @@ -20,20 +20,20 @@ package org.apache.geaflow.cluster.k8s.entrypoint; import java.util.concurrent.TimeoutException; + import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.testng.annotations.Test; public class KubernetesMasterRunnerTest extends KubernetesTestBase { - @Test(expectedExceptions = TimeoutException.class) - public void testMasterRunner() throws Throwable { - configuration.put(ExecutionConfigKeys.FO_TIMEOUT_MS, "1000"); - try { - KubernetesMasterRunner masterRunner = new KubernetesMasterRunner(configuration); - masterRunner.init(); - } catch (Exception e) { - throw e.getCause(); - } + @Test(expectedExceptions = TimeoutException.class) + public void testMasterRunner() throws Throwable { + configuration.put(ExecutionConfigKeys.FO_TIMEOUT_MS, "1000"); + try { + KubernetesMasterRunner masterRunner = new KubernetesMasterRunner(configuration); + masterRunner.init(); + } catch (Exception e) { + throw e.getCause(); } - + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesMockRunner.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesMockRunner.java index 14adc7129..c6f4edc92 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesMockRunner.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesMockRunner.java @@ -21,7 +21,5 @@ public class KubernetesMockRunner { - public static void main(String[] args) { - } - + public static void main(String[] args) {} } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesTestBase.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesTestBase.java index ddd32edf2..212a5c754 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesTestBase.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/entrypoint/KubernetesTestBase.java @@ -19,10 +19,8 @@ package org.apache.geaflow.cluster.k8s.entrypoint; -import io.fabric8.kubernetes.api.model.Service; -import io.fabric8.kubernetes.client.KubernetesClient; -import io.fabric8.kubernetes.client.server.mock.KubernetesServer; import java.util.HashMap; + import org.apache.geaflow.cluster.k8s.clustermanager.GeaflowKubeClient; import org.apache.geaflow.cluster.k8s.clustermanager.KubernetesResourceBuilder; import org.apache.geaflow.cluster.k8s.config.KubernetesConfig; @@ -36,48 +34,52 @@ import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; +import io.fabric8.kubernetes.api.model.Service; +import io.fabric8.kubernetes.client.KubernetesClient; +import io.fabric8.kubernetes.client.server.mock.KubernetesServer; + public class KubernetesTestBase { - private final KubernetesServer server = new KubernetesServer(false, true); - protected KubernetesClient kubernetesClient; - protected String masterUrl; - protected String clusterId; - protected Configuration configuration; - protected GeaflowKubeClient geaflowKubeClient; + private final KubernetesServer server = new KubernetesServer(false, true); + protected KubernetesClient kubernetesClient; + protected String masterUrl; + protected String clusterId; + protected Configuration configuration; + protected GeaflowKubeClient geaflowKubeClient; - @BeforeMethod - public void setUp() { - server.before(); - kubernetesClient = server.getClient(); - masterUrl = kubernetesClient.getMasterUrl().toString(); - clusterId = "123"; + @BeforeMethod + public void setUp() { + server.before(); + kubernetesClient = server.getClient(); + masterUrl = kubernetesClient.getMasterUrl().toString(); + clusterId = "123"; - configuration = new Configuration(); - configuration.put(FrameworkConfigKeys.SYSTEM_STATE_BACKEND_TYPE, "memory"); - configuration.put(ExecutionConfigKeys.REPORTER_LIST, "slf4j"); - configuration.put(ExecutionConfigKeys.HA_SERVICE_TYPE, "memory"); - configuration.put(ExecutionConfigKeys.JOB_UNIQUE_ID, "1"); - configuration.put(ExecutionConfigKeys.CLUSTER_ID, clusterId); - configuration.put(KubernetesConfigKeys.MASTER_URL, masterUrl); - configuration.put(KubernetesConfig.CLUSTER_START_TIME, - String.valueOf(System.currentTimeMillis())); - configuration.setMasterId("mockMaster"); + configuration = new Configuration(); + configuration.put(FrameworkConfigKeys.SYSTEM_STATE_BACKEND_TYPE, "memory"); + configuration.put(ExecutionConfigKeys.REPORTER_LIST, "slf4j"); + configuration.put(ExecutionConfigKeys.HA_SERVICE_TYPE, "memory"); + configuration.put(ExecutionConfigKeys.JOB_UNIQUE_ID, "1"); + configuration.put(ExecutionConfigKeys.CLUSTER_ID, clusterId); + configuration.put(KubernetesConfigKeys.MASTER_URL, masterUrl); + configuration.put( + KubernetesConfig.CLUSTER_START_TIME, String.valueOf(System.currentTimeMillis())); + configuration.setMasterId("mockMaster"); - geaflowKubeClient = new GeaflowKubeClient(configuration, masterUrl); - createService(configuration); - } + geaflowKubeClient = new GeaflowKubeClient(configuration, masterUrl); + createService(configuration); + } - @AfterMethod - public void destroy() { - server.after(); - } + @AfterMethod + public void destroy() { + server.after(); + } - private void createService(Configuration configuration) { - KubernetesMasterParam param = new KubernetesMasterParam(configuration); - String serviceName = KubernetesUtils.getMasterServiceName(clusterId); - Service service = KubernetesResourceBuilder - .createService(serviceName, ServiceExposedType.CLUSTER_IP, - new HashMap<>(), null, param); - geaflowKubeClient.createService(service); - } + private void createService(Configuration configuration) { + KubernetesMasterParam param = new KubernetesMasterParam(configuration); + String serviceName = KubernetesUtils.getMasterServiceName(clusterId); + Service service = + KubernetesResourceBuilder.createService( + serviceName, ServiceExposedType.CLUSTER_IP, new HashMap<>(), null, param); + geaflowKubeClient.createService(service); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/utils/KubernetesUtilsTest.java b/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/utils/KubernetesUtilsTest.java index 83be5c354..8c270ec53 100644 --- a/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/utils/KubernetesUtilsTest.java +++ b/geaflow/geaflow-deploy/geaflow-on-k8s/src/test/java/org/apache/geaflow/cluster/k8s/utils/KubernetesUtilsTest.java @@ -23,12 +23,11 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.CLUSTER_CLIENT_TIMEOUT_MS; import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.REPORTER_LIST; -import io.fabric8.kubernetes.api.model.NodeSelectorRequirement; -import io.fabric8.kubernetes.api.model.Toleration; import java.io.File; import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.cluster.constants.ClusterConstants; import org.apache.geaflow.cluster.k8s.config.KubernetesConfigKeys; import org.apache.geaflow.cluster.rpc.ConnectAddress; @@ -38,75 +37,78 @@ import org.testng.Assert; import org.testng.annotations.Test; +import io.fabric8.kubernetes.api.model.NodeSelectorRequirement; +import io.fabric8.kubernetes.api.model.Toleration; + public class KubernetesUtilsTest { - @Test - public void testGetEnvTest() { - Map env = System.getenv(); - try { - ClusterUtils.getEnvValue(env, "envTestKey"); - } catch (IllegalArgumentException e) { - Assert.assertTrue(e.getMessage().contains("envTestKey is not set")); - } + @Test + public void testGetEnvTest() { + Map env = System.getenv(); + try { + ClusterUtils.getEnvValue(env, "envTestKey"); + } catch (IllegalArgumentException e) { + Assert.assertTrue(e.getMessage().contains("envTestKey is not set")); } + } - @Test - public void testLoadConfigFromFile() { - String path = - this.getClass().getClassLoader().getResource("geaflow-conf-test.yml").getPath(); - Configuration config = KubernetesUtils.loadYAMLResource(new File(path)); - Assert.assertNotNull(config); - Assert.assertEquals(config.getString(REPORTER_LIST), "slf4j"); - Assert.assertEquals(config.getString(SERVICE_EXPOSED_TYPE), "NODE_PORT"); - Assert.assertEquals(config.getInteger(CLUSTER_CLIENT_TIMEOUT_MS), 300000); - } + @Test + public void testLoadConfigFromFile() { + String path = this.getClass().getClassLoader().getResource("geaflow-conf-test.yml").getPath(); + Configuration config = KubernetesUtils.loadYAMLResource(new File(path)); + Assert.assertNotNull(config); + Assert.assertEquals(config.getString(REPORTER_LIST), "slf4j"); + Assert.assertEquals(config.getString(SERVICE_EXPOSED_TYPE), "NODE_PORT"); + Assert.assertEquals(config.getInteger(CLUSTER_CLIENT_TIMEOUT_MS), 300000); + } - @Test - public void testGetContentFromFile() { - String path = - this.getClass().getClassLoader().getResource("geaflow-conf-test.yml").getPath(); - String content = FileUtil.getContentFromFile(path); - Assert.assertNotNull(content); - } + @Test + public void testGetContentFromFile() { + String path = this.getClass().getClassLoader().getResource("geaflow-conf-test.yml").getPath(); + String content = FileUtil.getContentFromFile(path); + Assert.assertNotNull(content); + } - @Test - public void testTolerances() { - Configuration configuration = new Configuration(); - List tolerationList = KubernetesUtils.getTolerations(configuration); - Assert.assertEquals(tolerationList.size(), 0); + @Test + public void testTolerances() { + Configuration configuration = new Configuration(); + List tolerationList = KubernetesUtils.getTolerations(configuration); + Assert.assertEquals(tolerationList.size(), 0); - configuration.put(KubernetesConfigKeys.TOLERATION_LIST, ""); - tolerationList = KubernetesUtils.getTolerations(configuration); - Assert.assertEquals(tolerationList.size(), 0); + configuration.put(KubernetesConfigKeys.TOLERATION_LIST, ""); + tolerationList = KubernetesUtils.getTolerations(configuration); + Assert.assertEquals(tolerationList.size(), 0); - configuration.put(KubernetesConfigKeys.TOLERATION_LIST, "key1:Equal:value1:NoSchedule:-,key2:Exists:-:-:-"); - tolerationList = KubernetesUtils.getTolerations(configuration); - Assert.assertEquals(tolerationList.size(), 2); - } + configuration.put( + KubernetesConfigKeys.TOLERATION_LIST, "key1:Equal:value1:NoSchedule:-,key2:Exists:-:-:-"); + tolerationList = KubernetesUtils.getTolerations(configuration); + Assert.assertEquals(tolerationList.size(), 2); + } - @Test - public void testMatchExpressions() { - Configuration configuration = new Configuration(); - List matchExpressions = KubernetesUtils.getMatchExpressions(configuration); - Assert.assertEquals(matchExpressions.size(), 0); + @Test + public void testMatchExpressions() { + Configuration configuration = new Configuration(); + List matchExpressions = + KubernetesUtils.getMatchExpressions(configuration); + Assert.assertEquals(matchExpressions.size(), 0); - configuration.put(KubernetesConfigKeys.MATCH_EXPRESSION_LIST, ""); - matchExpressions = KubernetesUtils.getMatchExpressions(configuration); - Assert.assertEquals(matchExpressions.size(), 0); + configuration.put(KubernetesConfigKeys.MATCH_EXPRESSION_LIST, ""); + matchExpressions = KubernetesUtils.getMatchExpressions(configuration); + Assert.assertEquals(matchExpressions.size(), 0); - configuration.put(KubernetesConfigKeys.MATCH_EXPRESSION_LIST, "key1:In:value1,key2:In:-"); - matchExpressions = KubernetesUtils.getMatchExpressions(configuration); - Assert.assertEquals(matchExpressions.size(), 2); - } + configuration.put(KubernetesConfigKeys.MATCH_EXPRESSION_LIST, "key1:In:value1,key2:In:-"); + matchExpressions = KubernetesUtils.getMatchExpressions(configuration); + Assert.assertEquals(matchExpressions.size(), 2); + } - @Test - public void testAddressEncoding() { - Map map = new HashMap<>(); - for (int i = 0; i < 3; i++) { - map.put(ClusterConstants.getDriverName(i), new ConnectAddress("127.0.0.1", 80)); - } - String encodedStr = KubernetesUtils.encodeRpcAddressMap(map); - Map map2 = KubernetesUtils.decodeRpcAddressMap(encodedStr); - Assert.assertEquals(map, map2); + @Test + public void testAddressEncoding() { + Map map = new HashMap<>(); + for (int i = 0; i < 3; i++) { + map.put(ClusterConstants.getDriverName(i), new ConnectAddress("127.0.0.1", 80)); } + String encodedStr = KubernetesUtils.encodeRpcAddressMap(map); + Map map2 = KubernetesUtils.decodeRpcAddressMap(encodedStr); + Assert.assertEquals(map, map2); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/client/LocalClusterClient.java b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/client/LocalClusterClient.java index 26610c45f..7ab4ff198 100644 --- a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/client/LocalClusterClient.java +++ b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/client/LocalClusterClient.java @@ -21,6 +21,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; + import org.apache.geaflow.cluster.client.AbstractClusterClient; import org.apache.geaflow.cluster.client.IPipelineClient; import org.apache.geaflow.cluster.client.PipelineClientFactory; @@ -38,43 +39,42 @@ public class LocalClusterClient extends AbstractClusterClient { - private static final Logger LOGGER = LoggerFactory.getLogger(LocalClusterClient.class); - private LocalClusterManager localClusterManager; - private ClusterContext clusterContext; - private final ExecutorService agentService = Executors.newSingleThreadExecutor( - ThreadUtil.namedThreadFactory(true, "local-agent")); + private static final Logger LOGGER = LoggerFactory.getLogger(LocalClusterClient.class); + private LocalClusterManager localClusterManager; + private ClusterContext clusterContext; + private final ExecutorService agentService = + Executors.newSingleThreadExecutor(ThreadUtil.namedThreadFactory(true, "local-agent")); - @Override - public void init(IEnvironmentContext environmentContext) { - super.init(environmentContext); - clusterContext = new ClusterContext(config); - localClusterManager = new LocalClusterManager(); - localClusterManager.init(clusterContext); - } + @Override + public void init(IEnvironmentContext environmentContext) { + super.init(environmentContext); + clusterContext = new ClusterContext(config); + localClusterManager = new LocalClusterManager(); + localClusterManager.init(clusterContext); + } - @Override - public IPipelineClient startCluster() { - try { - LocalClusterId clusterId = localClusterManager.startMaster(); - ClusterInfo clusterInfo = LocalClient.initMaster(clusterId.getMaster()); - ClusterMeta clusterMeta = new ClusterMeta(clusterInfo); - callback.onSuccess(clusterMeta); - LOGGER.info("cluster info: {}", clusterInfo); - return PipelineClientFactory.createPipelineClient( - clusterInfo.getDriverAddresses(), clusterContext.getConfig()); - } catch (Throwable e) { - LOGGER.error("deploy cluster failed", e); - callback.onFailure(e); - throw new GeaflowRuntimeException(e); - } + @Override + public IPipelineClient startCluster() { + try { + LocalClusterId clusterId = localClusterManager.startMaster(); + ClusterInfo clusterInfo = LocalClient.initMaster(clusterId.getMaster()); + ClusterMeta clusterMeta = new ClusterMeta(clusterInfo); + callback.onSuccess(clusterMeta); + LOGGER.info("cluster info: {}", clusterInfo); + return PipelineClientFactory.createPipelineClient( + clusterInfo.getDriverAddresses(), clusterContext.getConfig()); + } catch (Throwable e) { + LOGGER.error("deploy cluster failed", e); + callback.onFailure(e); + throw new GeaflowRuntimeException(e); } + } - @Override - public void shutdown() { - LOGGER.info("shutdown cluster"); - if (localClusterManager != null) { - localClusterManager.close(); - } + @Override + public void shutdown() { + LOGGER.info("shutdown cluster"); + if (localClusterManager != null) { + localClusterManager.close(); } - + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/client/LocalEnvironment.java b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/client/LocalEnvironment.java index 377511493..5c8f69ff3 100644 --- a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/client/LocalEnvironment.java +++ b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/client/LocalEnvironment.java @@ -24,13 +24,13 @@ public class LocalEnvironment extends AbstractEnvironment { - @Override - protected IClusterClient getClusterClient() { - return new LocalClusterClient(); - } + @Override + protected IClusterClient getClusterClient() { + return new LocalClusterClient(); + } - @Override - public EnvType getEnvType() { - return EnvType.LOCAL; - } + @Override + public EnvType getEnvType() { + return EnvType.LOCAL; + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/clustermanager/LocalClient.java b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/clustermanager/LocalClient.java index a4a18cb3e..88521f812 100644 --- a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/clustermanager/LocalClient.java +++ b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/clustermanager/LocalClient.java @@ -20,6 +20,7 @@ package org.apache.geaflow.cluster.local.clustermanager; import java.io.Serializable; + import org.apache.geaflow.cluster.clustermanager.ClusterInfo; import org.apache.geaflow.cluster.config.ClusterConfig; import org.apache.geaflow.cluster.container.ContainerContext; @@ -30,28 +31,25 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; - public class LocalClient implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(LocalClient.class); - - public static LocalMasterRunner createMaster(ClusterConfig clusterConfig) { - return new LocalMasterRunner(clusterConfig.getConfig()); - } + private static final Logger LOGGER = LoggerFactory.getLogger(LocalClient.class); - public static ClusterInfo initMaster(LocalMasterRunner master) { - LOGGER.info("init master"); - return master.init(); - } + public static LocalMasterRunner createMaster(ClusterConfig clusterConfig) { + return new LocalMasterRunner(clusterConfig.getConfig()); + } - public static LocalDriverRunner createDriver(ClusterConfig clusterConfig, - DriverContext context) { - return new LocalDriverRunner(context); - } + public static ClusterInfo initMaster(LocalMasterRunner master) { + LOGGER.info("init master"); + return master.init(); + } - public static LocalContainerRunner createContainer(ClusterConfig clusterConfig, - ContainerContext containerContext) { - return new LocalContainerRunner(containerContext); - } + public static LocalDriverRunner createDriver(ClusterConfig clusterConfig, DriverContext context) { + return new LocalDriverRunner(context); + } + public static LocalContainerRunner createContainer( + ClusterConfig clusterConfig, ContainerContext containerContext) { + return new LocalContainerRunner(containerContext); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/clustermanager/LocalClusterId.java b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/clustermanager/LocalClusterId.java index 45b323425..a3530d0c8 100644 --- a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/clustermanager/LocalClusterId.java +++ b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/clustermanager/LocalClusterId.java @@ -24,14 +24,13 @@ public class LocalClusterId implements ClusterId { - private LocalMasterRunner master; + private LocalMasterRunner master; - public LocalClusterId(LocalMasterRunner master) { - this.master = master; - } - - public LocalMasterRunner getMaster() { - return master; - } + public LocalClusterId(LocalMasterRunner master) { + this.master = master; + } + public LocalMasterRunner getMaster() { + return master; + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/clustermanager/LocalClusterManager.java b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/clustermanager/LocalClusterManager.java index 951a856bf..da7482f50 100644 --- a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/clustermanager/LocalClusterManager.java +++ b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/clustermanager/LocalClusterManager.java @@ -19,9 +19,9 @@ package org.apache.geaflow.cluster.local.clustermanager; -import com.google.common.base.Preconditions; import java.io.File; import java.io.IOException; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.cluster.clustermanager.AbstractClusterManager; import org.apache.geaflow.cluster.clustermanager.ClusterContext; @@ -36,75 +36,73 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class LocalClusterManager extends AbstractClusterManager { - - private static final Logger LOGGER = LoggerFactory.getLogger(LocalClusterManager.class); - private String appPath; - - @Override - public void init(ClusterContext clusterContext) { - super.init(clusterContext); - this.appPath = getWorkPath(); - } - - @Override - protected IFailoverStrategy buildFailoverStrategy() { - return FailoverStrategyFactory.loadFailoverStrategy(IEnvironment.EnvType.LOCAL, - FailoverStrategyType.disable_fo.name()); - } - - @Override - public LocalClusterId startMaster() { - Preconditions.checkArgument(clusterConfig != null, "clusterConfig is not initialized"); - clusterInfo = new LocalClusterId(LocalClient.createMaster(clusterConfig)); - return (LocalClusterId) clusterInfo; - } - - @Override - public void createNewContainer(int containerId, boolean isRecover) { - ContainerContext containerContext = new LocalContainerContext(containerId, - clusterConfig.getConfig()); - LocalClient.createContainer(clusterConfig, containerContext); - } - - @Override - public void createNewDriver(int driverId, int driverIndex) { - DriverContext driverContext = new LocalDriverContext(driverId, driverIndex, - clusterConfig.getConfig()); - LocalClient.createDriver(clusterConfig, driverContext); - LOGGER.info("call driver start id:{} index:{}", driverId, driverIndex); - } - - - @Override - public void restartDriver(int driverId) { - } - - @Override - public void restartContainer(int containerId) { - } +import com.google.common.base.Preconditions; - @Override - protected void validateContainerNum(int containerNum) { - Preconditions.checkArgument(containerNum == 1, "local mode containerNum must equal with 1"); - } +public class LocalClusterManager extends AbstractClusterManager { - @Override - public void close() { - super.close(); - if (appPath != null) { - FileUtils.deleteQuietly(new File(appPath)); - } + private static final Logger LOGGER = LoggerFactory.getLogger(LocalClusterManager.class); + private String appPath; + + @Override + public void init(ClusterContext clusterContext) { + super.init(clusterContext); + this.appPath = getWorkPath(); + } + + @Override + protected IFailoverStrategy buildFailoverStrategy() { + return FailoverStrategyFactory.loadFailoverStrategy( + IEnvironment.EnvType.LOCAL, FailoverStrategyType.disable_fo.name()); + } + + @Override + public LocalClusterId startMaster() { + Preconditions.checkArgument(clusterConfig != null, "clusterConfig is not initialized"); + clusterInfo = new LocalClusterId(LocalClient.createMaster(clusterConfig)); + return (LocalClusterId) clusterInfo; + } + + @Override + public void createNewContainer(int containerId, boolean isRecover) { + ContainerContext containerContext = + new LocalContainerContext(containerId, clusterConfig.getConfig()); + LocalClient.createContainer(clusterConfig, containerContext); + } + + @Override + public void createNewDriver(int driverId, int driverIndex) { + DriverContext driverContext = + new LocalDriverContext(driverId, driverIndex, clusterConfig.getConfig()); + LocalClient.createDriver(clusterConfig, driverContext); + LOGGER.info("call driver start id:{} index:{}", driverId, driverIndex); + } + + @Override + public void restartDriver(int driverId) {} + + @Override + public void restartContainer(int containerId) {} + + @Override + protected void validateContainerNum(int containerNum) { + Preconditions.checkArgument(containerNum == 1, "local mode containerNum must equal with 1"); + } + + @Override + public void close() { + super.close(); + if (appPath != null) { + FileUtils.deleteQuietly(new File(appPath)); } - - private String getWorkPath() { - String workPath = "/tmp/" + System.currentTimeMillis(); - try { - FileUtils.forceMkdir(new File(workPath)); - } catch (IOException e) { - LOGGER.error(e.getMessage(), e); - } - return workPath; + } + + private String getWorkPath() { + String workPath = "/tmp/" + System.currentTimeMillis(); + try { + FileUtils.forceMkdir(new File(workPath)); + } catch (IOException e) { + LOGGER.error(e.getMessage(), e); } - + return workPath; + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/context/LocalContainerContext.java b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/context/LocalContainerContext.java index c7aaab2fd..4a931828d 100644 --- a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/context/LocalContainerContext.java +++ b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/context/LocalContainerContext.java @@ -24,11 +24,11 @@ public class LocalContainerContext extends ContainerContext { - public LocalContainerContext(int index, Configuration config) { - super(index, config); - } + public LocalContainerContext(int index, Configuration config) { + super(index, config); + } - public boolean isRecover() { - return false; - } + public boolean isRecover() { + return false; + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/context/LocalDriverContext.java b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/context/LocalDriverContext.java index cddad4da0..79928686c 100644 --- a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/context/LocalDriverContext.java +++ b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/context/LocalDriverContext.java @@ -24,12 +24,11 @@ public class LocalDriverContext extends DriverContext { - public LocalDriverContext(int id, int index, Configuration config) { - super(id, index, config); - } - - public boolean isRecover() { - return false; - } + public LocalDriverContext(int id, int index, Configuration config) { + super(id, index, config); + } + public boolean isRecover() { + return false; + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/entrypoint/LocalContainerRunner.java b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/entrypoint/LocalContainerRunner.java index 4509d469b..3237e38cd 100644 --- a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/entrypoint/LocalContainerRunner.java +++ b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/entrypoint/LocalContainerRunner.java @@ -24,10 +24,9 @@ public class LocalContainerRunner { - public LocalContainerRunner(ContainerContext context) { - Container container = new Container(); - context.load(); - container.init(context); - } - + public LocalContainerRunner(ContainerContext context) { + Container container = new Container(); + context.load(); + container.init(context); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/entrypoint/LocalDriverRunner.java b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/entrypoint/LocalDriverRunner.java index 282e605f6..92aceb353 100644 --- a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/entrypoint/LocalDriverRunner.java +++ b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/entrypoint/LocalDriverRunner.java @@ -24,10 +24,9 @@ public class LocalDriverRunner { - public LocalDriverRunner(DriverContext context) { - Driver driver = new Driver(); - context.load(); - driver.init(context); - } - + public LocalDriverRunner(DriverContext context) { + Driver driver = new Driver(); + context.load(); + driver.init(context); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/entrypoint/LocalMasterRunner.java b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/entrypoint/LocalMasterRunner.java index dc59d4419..8b9fea4d7 100644 --- a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/entrypoint/LocalMasterRunner.java +++ b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/entrypoint/LocalMasterRunner.java @@ -26,12 +26,11 @@ public class LocalMasterRunner extends MasterRunner { - public LocalMasterRunner(Configuration configuration) { - super(configuration, new LocalClusterManager()); - } - - public ClusterInfo init() { - return master.startCluster(); - } + public LocalMasterRunner(Configuration configuration) { + super(configuration, new LocalClusterManager()); + } + public ClusterInfo init() { + return master.startCluster(); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/failover/LocalFailoverStrategy.java b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/failover/LocalFailoverStrategy.java index 670380971..4b4c003db 100644 --- a/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/failover/LocalFailoverStrategy.java +++ b/geaflow/geaflow-deploy/geaflow-on-local/src/main/java/org/apache/geaflow/cluster/local/failover/LocalFailoverStrategy.java @@ -30,25 +30,23 @@ public class LocalFailoverStrategy implements IFailoverStrategy { - private static final Logger LOGGER = LoggerFactory.getLogger(LocalFailoverStrategy.class); + private static final Logger LOGGER = LoggerFactory.getLogger(LocalFailoverStrategy.class); - @Override - public void init(ClusterContext context) { + @Override + public void init(ClusterContext context) {} - } + @Override + public void doFailover(int componentId, Throwable cause) { + LOGGER.info("component {} do failover", componentId); + } - @Override - public void doFailover(int componentId, Throwable cause) { - LOGGER.info("component {} do failover", componentId); - } + @Override + public FailoverStrategyType getType() { + return disable_fo; + } - @Override - public FailoverStrategyType getType() { - return disable_fo; - } - - @Override - public EnvType getEnv() { - return EnvType.LOCAL; - } + @Override + public EnvType getEnv() { + return EnvType.LOCAL; + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/client/RayClusterClient.java b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/client/RayClusterClient.java index ecce17b68..d9137bb58 100644 --- a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/client/RayClusterClient.java +++ b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/client/RayClusterClient.java @@ -21,7 +21,6 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.JOB_WORK_PATH; -import io.ray.api.Ray; import org.apache.geaflow.cluster.client.AbstractClusterClient; import org.apache.geaflow.cluster.client.IPipelineClient; import org.apache.geaflow.cluster.client.PipelineClientFactory; @@ -37,47 +36,48 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class RayClusterClient extends AbstractClusterClient { +import io.ray.api.Ray; - private static final Logger LOGGER = LoggerFactory.getLogger(RayClusterClient.class); - private RayClusterManager rayClusterManager; - private ClusterContext clusterContext; +public class RayClusterClient extends AbstractClusterClient { - @Override - public void init(IEnvironmentContext environmentContext) { - super.init(environmentContext); - clusterContext = new ClusterContext(config); - rayClusterManager = new RayClusterManager(); - rayClusterManager.init(clusterContext); + private static final Logger LOGGER = LoggerFactory.getLogger(RayClusterClient.class); + private RayClusterManager rayClusterManager; + private ClusterContext clusterContext; - RaySystemFunc.initRayEnv(clusterContext.getClusterConfig()); - if (!config.contains(JOB_WORK_PATH)) { - config.put(JOB_WORK_PATH, RaySystemFunc.getWorkPath()); - } - } + @Override + public void init(IEnvironmentContext environmentContext) { + super.init(environmentContext); + clusterContext = new ClusterContext(config); + rayClusterManager = new RayClusterManager(); + rayClusterManager.init(clusterContext); - @Override - public IPipelineClient startCluster() { - try { - RayClusterId clusterId = rayClusterManager.startMaster(); - ClusterInfo clusterInfo = RayClient.initMaster(clusterId.getHandler()); - ClusterMeta clusterMeta = new ClusterMeta(clusterInfo); - callback.onSuccess(clusterMeta); - LOGGER.info("cluster info: {}", clusterInfo); - return PipelineClientFactory.createPipelineClient( - clusterInfo.getDriverAddresses(), clusterContext.getConfig()); - } catch (Throwable e) { - LOGGER.error("deploy cluster failed", e); - callback.onFailure(e); - throw new GeaflowRuntimeException(e); - } + RaySystemFunc.initRayEnv(clusterContext.getClusterConfig()); + if (!config.contains(JOB_WORK_PATH)) { + config.put(JOB_WORK_PATH, RaySystemFunc.getWorkPath()); } + } - @Override - public void shutdown() { - LOGGER.info("shutdown cluster"); - rayClusterManager.close(); - Ray.shutdown(); + @Override + public IPipelineClient startCluster() { + try { + RayClusterId clusterId = rayClusterManager.startMaster(); + ClusterInfo clusterInfo = RayClient.initMaster(clusterId.getHandler()); + ClusterMeta clusterMeta = new ClusterMeta(clusterInfo); + callback.onSuccess(clusterMeta); + LOGGER.info("cluster info: {}", clusterInfo); + return PipelineClientFactory.createPipelineClient( + clusterInfo.getDriverAddresses(), clusterContext.getConfig()); + } catch (Throwable e) { + LOGGER.error("deploy cluster failed", e); + callback.onFailure(e); + throw new GeaflowRuntimeException(e); } + } + @Override + public void shutdown() { + LOGGER.info("shutdown cluster"); + rayClusterManager.close(); + Ray.shutdown(); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/client/RayEnvironment.java b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/client/RayEnvironment.java index 296d64e39..a34f44161 100644 --- a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/client/RayEnvironment.java +++ b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/client/RayEnvironment.java @@ -28,18 +28,18 @@ public class RayEnvironment extends AbstractEnvironment { - public RayEnvironment() { - context.getConfig().put(LOG_DIR, RAY_LOG_DIR); - context.getConfig().put(SUPERVISOR_ENABLE, Boolean.FALSE.toString()); - } + public RayEnvironment() { + context.getConfig().put(LOG_DIR, RAY_LOG_DIR); + context.getConfig().put(SUPERVISOR_ENABLE, Boolean.FALSE.toString()); + } - @Override - protected IClusterClient getClusterClient() { - return new RayClusterClient(); - } + @Override + protected IClusterClient getClusterClient() { + return new RayClusterClient(); + } - @Override - public EnvType getEnvType() { - return EnvType.RAY; - } + @Override + public EnvType getEnvType() { + return EnvType.RAY; + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/clustermanager/RayClient.java b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/clustermanager/RayClient.java index 8106e1a63..e6e164a6f 100644 --- a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/clustermanager/RayClient.java +++ b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/clustermanager/RayClient.java @@ -19,13 +19,10 @@ package org.apache.geaflow.cluster.ray.clustermanager; -import io.ray.api.ActorHandle; -import io.ray.api.ObjectRef; -import io.ray.api.Ray; -import io.ray.api.options.ActorLifetime; import java.io.Serializable; import java.util.List; import java.util.Map; + import org.apache.geaflow.cluster.clustermanager.ClusterInfo; import org.apache.geaflow.cluster.config.ClusterConfig; import org.apache.geaflow.cluster.container.ContainerContext; @@ -37,71 +34,83 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import io.ray.api.ActorHandle; +import io.ray.api.ObjectRef; +import io.ray.api.Ray; +import io.ray.api.options.ActorLifetime; public class RayClient implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(RayClient.class); + private static final Logger LOGGER = LoggerFactory.getLogger(RayClient.class); - public static ActorHandle createMaster(ClusterConfig clusterConfig) { - int totalMemoryMb = clusterConfig.getMasterMemoryMB(); - List jvmOptions = clusterConfig.getMasterJvmOptions().getJvmOptions(); + public static ActorHandle createMaster(ClusterConfig clusterConfig) { + int totalMemoryMb = clusterConfig.getMasterMemoryMB(); + List jvmOptions = clusterConfig.getMasterJvmOptions().getJvmOptions(); - ActorHandle masterRayActor = Ray - .actor(RayMasterRunner::new, clusterConfig.getConfig()) + ActorHandle masterRayActor = + Ray.actor(RayMasterRunner::new, clusterConfig.getConfig()) .setMaxRestarts(clusterConfig.getMaxRestarts()) .setLifetime(ActorLifetime.DETACHED) .remote(); - LOGGER.info("master actor:{}, memoryMB:{}, jvmOptions:{}, foRestartTimes:{}", - masterRayActor.getId().toString(), totalMemoryMb, jvmOptions, - clusterConfig.getMaxRestarts()); - return masterRayActor; - } + LOGGER.info( + "master actor:{}, memoryMB:{}, jvmOptions:{}, foRestartTimes:{}", + masterRayActor.getId().toString(), + totalMemoryMb, + jvmOptions, + clusterConfig.getMaxRestarts()); + return masterRayActor; + } - public static ClusterInfo initMaster(ActorHandle masterActor) { - LOGGER.info("init master:{}", masterActor.getId().toString()); - ObjectRef masterMetaRayObject = masterActor.task(RayMasterRunner::init) - .remote(); - return masterMetaRayObject.get(); - } + public static ClusterInfo initMaster(ActorHandle masterActor) { + LOGGER.info("init master:{}", masterActor.getId().toString()); + ObjectRef masterMetaRayObject = masterActor.task(RayMasterRunner::init).remote(); + return masterMetaRayObject.get(); + } - public static ActorHandle createDriver(ClusterConfig clusterConfig, - DriverContext context) { - int totalMemoryMb = clusterConfig.getDriverMemoryMB(); - List jvmOptions = clusterConfig.getDriverJvmOptions().getJvmOptions(); + public static ActorHandle createDriver( + ClusterConfig clusterConfig, DriverContext context) { + int totalMemoryMb = clusterConfig.getDriverMemoryMB(); + List jvmOptions = clusterConfig.getDriverJvmOptions().getJvmOptions(); - ActorHandle driverRayActor = Ray - .actor(RayDriverRunner::new, context) + ActorHandle driverRayActor = + Ray.actor(RayDriverRunner::new, context) .setMaxRestarts(clusterConfig.getMaxRestarts()) .setLifetime(ActorLifetime.DETACHED) .remote(); - LOGGER.info("driver actor:{}, memoryMB:{}, jvmOptions:{}, foRestartTimes:{}", - driverRayActor.getId().toString(), totalMemoryMb, jvmOptions, - clusterConfig.getMaxRestarts()); - return driverRayActor; - } + LOGGER.info( + "driver actor:{}, memoryMB:{}, jvmOptions:{}, foRestartTimes:{}", + driverRayActor.getId().toString(), + totalMemoryMb, + jvmOptions, + clusterConfig.getMaxRestarts()); + return driverRayActor; + } - public static ActorHandle createContainer(ClusterConfig clusterConfig, - ContainerContext containerContext) { - ActorHandle rayContainer = Ray - .actor(RayContainerRunner::new, containerContext) + public static ActorHandle createContainer( + ClusterConfig clusterConfig, ContainerContext containerContext) { + ActorHandle rayContainer = + Ray.actor(RayContainerRunner::new, containerContext) .setMaxRestarts(clusterConfig.getMaxRestarts()) .setLifetime(ActorLifetime.DETACHED) .remote(); - LOGGER.info("worker actor {} maxRestarts {}", rayContainer.getId().toString(), - clusterConfig.getMaxRestarts()); - return rayContainer; - } + LOGGER.info( + "worker actor {} maxRestarts {}", + rayContainer.getId().toString(), + clusterConfig.getMaxRestarts()); + return rayContainer; + } - public static ActorHandle createSupervisor(ClusterConfig clusterConfig, - Map envs) { - ActorHandle rayContainer = Ray - .actor(RaySupervisorRunner::new, clusterConfig.getConfig(), envs) + public static ActorHandle createSupervisor( + ClusterConfig clusterConfig, Map envs) { + ActorHandle rayContainer = + Ray.actor(RaySupervisorRunner::new, clusterConfig.getConfig(), envs) .setMaxRestarts(clusterConfig.getMaxRestarts()) .setLifetime(ActorLifetime.DETACHED) .remote(); - LOGGER.info("supervisor actor {} maxRestarts {}", rayContainer.getId().toString(), - clusterConfig.getMaxRestarts()); - return rayContainer; - } - + LOGGER.info( + "supervisor actor {} maxRestarts {}", + rayContainer.getId().toString(), + clusterConfig.getMaxRestarts()); + return rayContainer; + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/clustermanager/RayClusterId.java b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/clustermanager/RayClusterId.java index 1c6213107..32f776947 100644 --- a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/clustermanager/RayClusterId.java +++ b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/clustermanager/RayClusterId.java @@ -19,20 +19,20 @@ package org.apache.geaflow.cluster.ray.clustermanager; -import io.ray.api.ActorHandle; import org.apache.geaflow.cluster.clustermanager.ClusterId; import org.apache.geaflow.cluster.ray.entrypoint.RayMasterRunner; -public class RayClusterId implements ClusterId { +import io.ray.api.ActorHandle; - private ActorHandle handler; +public class RayClusterId implements ClusterId { - public RayClusterId(ActorHandle handler) { - this.handler = handler; - } + private ActorHandle handler; - public ActorHandle getHandler() { - return handler; - } + public RayClusterId(ActorHandle handler) { + this.handler = handler; + } + public ActorHandle getHandler() { + return handler; + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/clustermanager/RayClusterManager.java b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/clustermanager/RayClusterManager.java index c9405642a..3f3fd4c26 100644 --- a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/clustermanager/RayClusterManager.java +++ b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/clustermanager/RayClusterManager.java @@ -23,12 +23,10 @@ import static org.apache.geaflow.cluster.ray.config.RayConfig.RAY_AGENT_PROFILER_PATH; import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.PROCESS_AUTO_RESTART; -import com.google.common.base.Preconditions; -import io.ray.api.ActorHandle; -import io.ray.api.Ray; import java.io.File; import java.util.HashMap; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.cluster.clustermanager.ClusterContext; import org.apache.geaflow.cluster.container.ContainerContext; @@ -46,103 +44,106 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class RayClusterManager extends GeaFlowClusterManager { - - private static final Logger LOGGER = LoggerFactory.getLogger(RayClusterManager.class); - private static final int MASTER_ACTOR_ID = 0; - private static Map actors = new HashMap<>(); +import com.google.common.base.Preconditions; - private String autoRestart; - private String currentJobId; +import io.ray.api.ActorHandle; +import io.ray.api.Ray; - public RayClusterManager() { - super(EnvType.RAY); - } +public class RayClusterManager extends GeaFlowClusterManager { - @Override - public void init(ClusterContext clusterContext) { - super.init(clusterContext); - if (clusterContext.getHeartbeatManager() != null) { - this.failFast = clusterConfig.getMaxRestarts() == 0; - this.autoRestart = config.getString(PROCESS_AUTO_RESTART); - this.currentJobId = Ray.getRuntimeContext().getCurrentJobId().toString(); - } - String profilerPath = config.getString(RAY_AGENT_PROFILER_PATH); - this.config.put(AGENT_PROFILER_PATH, profilerPath); - } + private static final Logger LOGGER = LoggerFactory.getLogger(RayClusterManager.class); + private static final int MASTER_ACTOR_ID = 0; + private static Map actors = new HashMap<>(); - @Override - public RayClusterId startMaster() { - Preconditions.checkArgument(clusterConfig != null, "clusterConfig is not initialized"); - ActorHandle master = RayClient.createMaster(clusterConfig); - clusterInfo = new RayClusterId(master); - actors.put(MASTER_ACTOR_ID, master); - return (RayClusterId) clusterInfo; - } + private String autoRestart; + private String currentJobId; - @Override - public void createNewContainer(int containerId, boolean isRecover) { - if (enableSupervisor) { - String logFile = String.format("container-%s-%s.log", currentJobId, containerId); - String command = getContainerShellCommand(containerId, isRecover, logFile); - ActorHandle actor = createSupervisor(containerId, command, - autoRestart); - actors.put(containerId, actor); - } else { - ContainerContext containerContext = new RayContainerContext(containerId, - clusterConfig.getConfig()); - ActorHandle container = RayClient.createContainer(clusterConfig, containerContext); - actors.put(containerId, container); - } + public RayClusterManager() { + super(EnvType.RAY); + } + @Override + public void init(ClusterContext clusterContext) { + super.init(clusterContext); + if (clusterContext.getHeartbeatManager() != null) { + this.failFast = clusterConfig.getMaxRestarts() == 0; + this.autoRestart = config.getString(PROCESS_AUTO_RESTART); + this.currentJobId = Ray.getRuntimeContext().getCurrentJobId().toString(); } - - @Override - public void createNewDriver(int driverId, int driverIndex) { - LOGGER.info("create driver start, enable supervisor:{}", enableSupervisor); - if (enableSupervisor) { - String logFile = String.format("driver-%s-%s.log", currentJobId, driverId); - String command = getDriverShellCommand(driverId, driverIndex, logFile); - ActorHandle actor = createSupervisor(driverId, command, - autoRestart); - actors.put(driverId, actor); - } else { - DriverContext driverContext = new RayDriverContext(driverId, driverIndex, clusterConfig.getConfig()); - ActorHandle driver = RayClient.createDriver(clusterConfig, - driverContext); - actors.put(driverId, driver); - } - LOGGER.info("call driver start, id:{} index:{}", driverId, driverIndex); + String profilerPath = config.getString(RAY_AGENT_PROFILER_PATH); + this.config.put(AGENT_PROFILER_PATH, profilerPath); + } + + @Override + public RayClusterId startMaster() { + Preconditions.checkArgument(clusterConfig != null, "clusterConfig is not initialized"); + ActorHandle master = RayClient.createMaster(clusterConfig); + clusterInfo = new RayClusterId(master); + actors.put(MASTER_ACTOR_ID, master); + return (RayClusterId) clusterInfo; + } + + @Override + public void createNewContainer(int containerId, boolean isRecover) { + if (enableSupervisor) { + String logFile = String.format("container-%s-%s.log", currentJobId, containerId); + String command = getContainerShellCommand(containerId, isRecover, logFile); + ActorHandle actor = createSupervisor(containerId, command, autoRestart); + actors.put(containerId, actor); + } else { + ContainerContext containerContext = + new RayContainerContext(containerId, clusterConfig.getConfig()); + ActorHandle container = + RayClient.createContainer(clusterConfig, containerContext); + actors.put(containerId, container); } - - @Override - public void restartContainer(int containerId) { - if (!actors.containsKey(containerId)) { - throw new GeaflowRuntimeException(String.format("invalid container id %s", containerId)); - } - actors.get(containerId).kill(); + } + + @Override + public void createNewDriver(int driverId, int driverIndex) { + LOGGER.info("create driver start, enable supervisor:{}", enableSupervisor); + if (enableSupervisor) { + String logFile = String.format("driver-%s-%s.log", currentJobId, driverId); + String command = getDriverShellCommand(driverId, driverIndex, logFile); + ActorHandle actor = createSupervisor(driverId, command, autoRestart); + actors.put(driverId, actor); + } else { + DriverContext driverContext = + new RayDriverContext(driverId, driverIndex, clusterConfig.getConfig()); + ActorHandle driver = RayClient.createDriver(clusterConfig, driverContext); + actors.put(driverId, driver); } + LOGGER.info("call driver start, id:{} index:{}", driverId, driverIndex); + } - @Override - public void restartDriver(int driverId) { - if (!actors.containsKey(driverId)) { - throw new GeaflowRuntimeException(String.format("invalid driver id %s", driverId)); - } - actors.get(driverId).kill(); + @Override + public void restartContainer(int containerId) { + if (!actors.containsKey(containerId)) { + throw new GeaflowRuntimeException(String.format("invalid container id %s", containerId)); } + actors.get(containerId).kill(); + } - private ActorHandle createSupervisor(int containerId, String command, String autoStart) { - Map additionalEnvs = buildSupervisorEnvs(containerId, command, autoStart); - return RayClient.createSupervisor(clusterConfig, additionalEnvs); + @Override + public void restartDriver(int driverId) { + if (!actors.containsKey(driverId)) { + throw new GeaflowRuntimeException(String.format("invalid driver id %s", driverId)); } - - @Override - public void close() { - super.close(); - actors.clear(); - if (RaySystemFunc.isLocalMode()) { - FileUtils.deleteQuietly(new File(RaySystemFunc.getWorkPath())); - } + actors.get(driverId).kill(); + } + + private ActorHandle createSupervisor( + int containerId, String command, String autoStart) { + Map additionalEnvs = buildSupervisorEnvs(containerId, command, autoStart); + return RayClient.createSupervisor(clusterConfig, additionalEnvs); + } + + @Override + public void close() { + super.close(); + actors.clear(); + if (RaySystemFunc.isLocalMode()) { + FileUtils.deleteQuietly(new File(RaySystemFunc.getWorkPath())); } - + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/config/RayConfig.java b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/config/RayConfig.java index 00b332d31..61822fce1 100644 --- a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/config/RayConfig.java +++ b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/config/RayConfig.java @@ -24,52 +24,50 @@ public class RayConfig { - public static final String RAY_RUN_MODE = "ray.run-mode"; + public static final String RAY_RUN_MODE = "ray.run-mode"; - public static final String RAY_JOB_JVM_OPTIONS_PREFIX = "ray.job.jvm-options"; - public static final String RAY_TASK_RETURN_TASK_EXCEPTION = "ray.task.return_task_exception"; - public static final String RAY_JOB_L1FO_ENABLE = "ray.job.enable-l1-fault-tolerance"; + public static final String RAY_JOB_JVM_OPTIONS_PREFIX = "ray.job.jvm-options"; + public static final String RAY_TASK_RETURN_TASK_EXCEPTION = "ray.task.return_task_exception"; + public static final String RAY_JOB_L1FO_ENABLE = "ray.job.enable-l1-fault-tolerance"; - public static final String RAY_JOB_RUNTIME_ENV = "ray.job.runtime-env"; - public static final String RAY_JOB_WORKING_DIR = "working_dir"; + public static final String RAY_JOB_RUNTIME_ENV = "ray.job.runtime-env"; + public static final String RAY_JOB_WORKING_DIR = "working_dir"; - public static final int CLUSTER_RESERVED_MEMORY_MB = 3 * 1024; - public static final int WORKER_RESERVED_MEMORY_MB = 3 * 1024; + public static final int CLUSTER_RESERVED_MEMORY_MB = 3 * 1024; + public static final int WORKER_RESERVED_MEMORY_MB = 3 * 1024; + // Sets the amount of memory occupied by a jvm process, both in and out of the heap + public static final String RAY_JOB_JAVA_WORKER_PROCESS_DEFAULT_MEMORY_MB = + "ray.job.java-worker-process-default-memory-mb"; - // Sets the amount of memory occupied by a jvm process, both in and out of the heap - public static final String RAY_JOB_JAVA_WORKER_PROCESS_DEFAULT_MEMORY_MB = - "ray.job.java-worker-process-default-memory-mb"; + // The proportion of Xmx to memory_mb + public static final String RAY_JOB_JAVA_HEAP_FRACTION = "ray.job.java-heap-fraction"; - // The proportion of Xmx to memory_mb - public static final String RAY_JOB_JAVA_HEAP_FRACTION = "ray.job.java-heap-fraction"; + // Set the initialization to start how many jvm processes there + public static final String RAY_JOB_NUM_INITIAL_JAVA_WORKER_PROCESS = + "ray.job.num-initial-java-worker-processes"; - // Set the initialization to start how many jvm processes there - public static final String RAY_JOB_NUM_INITIAL_JAVA_WORKER_PROCESS = - "ray.job.num-initial-java-worker-processes"; + // Sets the total memory for initializing startup jobs + public static final String RAY_JOB_TOTAL_MEMORY_MB = "ray.job.total-memory-mb"; - // Sets the total memory for initializing startup jobs - public static final String RAY_JOB_TOTAL_MEMORY_MB = "ray.job.total-memory-mb"; + // User - defined log file + public static final String RAY_CUSTOM_LOGGER0_NAME = "ray.logging.loggers.0.name"; - // User - defined log file - public static final String RAY_CUSTOM_LOGGER0_NAME = "ray.logging.loggers.0.name"; + public static final String RAY_CUSTOM_LOGGER0_FILE_NAME = "ray.logging.loggers.0.file-name"; - public static final String RAY_CUSTOM_LOGGER0_FILE_NAME = "ray.logging.loggers.0.file-name"; + public static final String RAY_CUSTOM_LOGGER0_PATTERN = "ray.logging.loggers.0.pattern"; - public static final String RAY_CUSTOM_LOGGER0_PATTERN = "ray.logging.loggers.0.pattern"; + public static final String CUSTOM_LOGGER_NAME = "userlogger"; - public static final String CUSTOM_LOGGER_NAME = "userlogger"; + public static final String CUSTOM_LOGGER_FILE_NAME = "geaflow-user-%p.log"; - public static final String CUSTOM_LOGGER_FILE_NAME = "geaflow-user-%p.log"; + public static final String CUSTOM_LOGGER_PATTERN = + "%d{yyyy-MM-dd HH:mm:ss,SSS} %p %c{1} [%t]: %m%n"; - public static final String CUSTOM_LOGGER_PATTERN = "%d{yyyy-MM-dd HH:mm:ss,SSS} %p %c{1} [%t]: %m%n"; - - public static final String RAY_LOG_DIR = "/home/admin/logs/ray-logs/logs"; - - public static final ConfigKey RAY_AGENT_PROFILER_PATH = ConfigKeys - .key("ray.agent.profiler.path") - .defaultValue("/home/admin/ray-pack/bin/profiler/profiler.sh") - .description("ray agent profiler path"); + public static final String RAY_LOG_DIR = "/home/admin/logs/ray-logs/logs"; + public static final ConfigKey RAY_AGENT_PROFILER_PATH = + ConfigKeys.key("ray.agent.profiler.path") + .defaultValue("/home/admin/ray-pack/bin/profiler/profiler.sh") + .description("ray agent profiler path"); } - diff --git a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/context/RayContainerContext.java b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/context/RayContainerContext.java index 26d6d4cd9..fd6b1a654 100644 --- a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/context/RayContainerContext.java +++ b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/context/RayContainerContext.java @@ -25,14 +25,14 @@ public class RayContainerContext extends ContainerContext { - public RayContainerContext(int index, Configuration config) { - super(index, config); - } + public RayContainerContext(int index, Configuration config) { + super(index, config); + } - public boolean isRecover() { - if (RaySystemFunc.isLocalMode()) { - return false; - } - return RaySystemFunc.isRestarted(); + public boolean isRecover() { + if (RaySystemFunc.isLocalMode()) { + return false; } + return RaySystemFunc.isRestarted(); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/context/RayDriverContext.java b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/context/RayDriverContext.java index 958ed55f3..d4bdaffee 100644 --- a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/context/RayDriverContext.java +++ b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/context/RayDriverContext.java @@ -25,16 +25,15 @@ public class RayDriverContext extends DriverContext { - public RayDriverContext(int id, int index, Configuration config) { - super(id, index, config); - } + public RayDriverContext(int id, int index, Configuration config) { + super(id, index, config); + } - public boolean isRecover() { - // local mode not support restart. - if (RaySystemFunc.isLocalMode()) { - return false; - } - return RaySystemFunc.isRestarted(); + public boolean isRecover() { + // local mode not support restart. + if (RaySystemFunc.isLocalMode()) { + return false; } - + return RaySystemFunc.isRestarted(); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/entrypoint/RayContainerRunner.java b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/entrypoint/RayContainerRunner.java index 6c5271daa..3bdb2b476 100644 --- a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/entrypoint/RayContainerRunner.java +++ b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/entrypoint/RayContainerRunner.java @@ -24,10 +24,9 @@ public class RayContainerRunner { - public RayContainerRunner(ContainerContext context) { - Container container = new Container(); - context.load(); - container.init(context); - } - + public RayContainerRunner(ContainerContext context) { + Container container = new Container(); + context.load(); + container.init(context); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/entrypoint/RayDriverRunner.java b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/entrypoint/RayDriverRunner.java index 9b0c798a6..9b667a076 100644 --- a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/entrypoint/RayDriverRunner.java +++ b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/entrypoint/RayDriverRunner.java @@ -24,10 +24,9 @@ public class RayDriverRunner { - public RayDriverRunner(DriverContext context) { - Driver driver = new Driver(); - context.load(); - driver.init(context); - } - + public RayDriverRunner(DriverContext context) { + Driver driver = new Driver(); + context.load(); + driver.init(context); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/entrypoint/RayMasterRunner.java b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/entrypoint/RayMasterRunner.java index 87b7e1109..065d5f2dd 100644 --- a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/entrypoint/RayMasterRunner.java +++ b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/entrypoint/RayMasterRunner.java @@ -25,8 +25,7 @@ public class RayMasterRunner extends MasterRunner { - public RayMasterRunner(Configuration configuration) { - super(configuration, new RayClusterManager()); - } - + public RayMasterRunner(Configuration configuration) { + super(configuration, new RayClusterManager()); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/entrypoint/RaySupervisorRunner.java b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/entrypoint/RaySupervisorRunner.java index 3bd8fcc46..3ff181d59 100644 --- a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/entrypoint/RaySupervisorRunner.java +++ b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/entrypoint/RaySupervisorRunner.java @@ -24,6 +24,7 @@ import static org.apache.geaflow.cluster.constants.ClusterConstants.CONTAINER_START_COMMAND; import java.util.Map; + import org.apache.geaflow.cluster.runner.Supervisor; import org.apache.geaflow.cluster.runner.util.ClusterUtils; import org.apache.geaflow.common.config.Configuration; @@ -32,18 +33,20 @@ import org.slf4j.LoggerFactory; public class RaySupervisorRunner { - private static final Logger LOGGER = LoggerFactory.getLogger(RaySupervisorRunner.class); - - public RaySupervisorRunner(Configuration configuration, Map env) { - String id = ClusterUtils.getEnvValue(env, CONTAINER_ID); - String autoRestartEnv = ClusterUtils.getEnvValue(env, AUTO_RESTART); - LOGGER.info("Start supervisor with ID: {} pid: {} autoStart: {}", id, - ProcessUtil.getProcessId(), autoRestartEnv); + private static final Logger LOGGER = LoggerFactory.getLogger(RaySupervisorRunner.class); - String startCommand = ClusterUtils.getEnvValue(env, CONTAINER_START_COMMAND); - boolean autoRestart = !autoRestartEnv.equalsIgnoreCase(Boolean.FALSE.toString()); - Supervisor supervisor = new Supervisor(startCommand, configuration, autoRestart); - supervisor.start(); - } + public RaySupervisorRunner(Configuration configuration, Map env) { + String id = ClusterUtils.getEnvValue(env, CONTAINER_ID); + String autoRestartEnv = ClusterUtils.getEnvValue(env, AUTO_RESTART); + LOGGER.info( + "Start supervisor with ID: {} pid: {} autoStart: {}", + id, + ProcessUtil.getProcessId(), + autoRestartEnv); + String startCommand = ClusterUtils.getEnvValue(env, CONTAINER_START_COMMAND); + boolean autoRestart = !autoRestartEnv.equalsIgnoreCase(Boolean.FALSE.toString()); + Supervisor supervisor = new Supervisor(startCommand, configuration, autoRestart); + supervisor.start(); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/failover/RayClusterFailoverStrategy.java b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/failover/RayClusterFailoverStrategy.java index 26ce23eaa..63813a923 100644 --- a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/failover/RayClusterFailoverStrategy.java +++ b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/failover/RayClusterFailoverStrategy.java @@ -26,14 +26,13 @@ public class RayClusterFailoverStrategy extends ClusterFailoverStrategy { - public RayClusterFailoverStrategy() { - super(EnvType.RAY); - } - - @Override - public void init(ClusterContext context) { - super.init(context); - System.setProperty(RayConfig.RAY_TASK_RETURN_TASK_EXCEPTION, Boolean.FALSE.toString()); - } + public RayClusterFailoverStrategy() { + super(EnvType.RAY); + } + @Override + public void init(ClusterContext context) { + super.init(context); + System.setProperty(RayConfig.RAY_TASK_RETURN_TASK_EXCEPTION, Boolean.FALSE.toString()); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/failover/RayComponentFailoverStrategy.java b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/failover/RayComponentFailoverStrategy.java index 90df123cc..0006fad72 100644 --- a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/failover/RayComponentFailoverStrategy.java +++ b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/failover/RayComponentFailoverStrategy.java @@ -26,14 +26,13 @@ public class RayComponentFailoverStrategy extends ComponentFailoverStrategy { - public RayComponentFailoverStrategy() { - super(EnvType.RAY); - } - - @Override - public void init(ClusterContext context) { - super.init(context); - System.setProperty(RayConfig.RAY_TASK_RETURN_TASK_EXCEPTION, Boolean.FALSE.toString()); - } + public RayComponentFailoverStrategy() { + super(EnvType.RAY); + } + @Override + public void init(ClusterContext context) { + super.init(context); + System.setProperty(RayConfig.RAY_TASK_RETURN_TASK_EXCEPTION, Boolean.FALSE.toString()); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/failover/RayDisableFailoverStrategy.java b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/failover/RayDisableFailoverStrategy.java index 012a8ce46..7a3aa36eb 100644 --- a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/failover/RayDisableFailoverStrategy.java +++ b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/failover/RayDisableFailoverStrategy.java @@ -26,14 +26,13 @@ public class RayDisableFailoverStrategy extends DisableFailoverStrategy { - public RayDisableFailoverStrategy() { - super(EnvType.RAY); - } - - @Override - public void init(ClusterContext context) { - super.init(context); - System.setProperty(RayConfig.RAY_TASK_RETURN_TASK_EXCEPTION, Boolean.FALSE.toString()); - } + public RayDisableFailoverStrategy() { + super(EnvType.RAY); + } + @Override + public void init(ClusterContext context) { + super.init(context); + System.setProperty(RayConfig.RAY_TASK_RETURN_TASK_EXCEPTION, Boolean.FALSE.toString()); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/utils/RaySystemFunc.java b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/utils/RaySystemFunc.java index 07a828869..179667034 100644 --- a/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/utils/RaySystemFunc.java +++ b/geaflow/geaflow-deploy/geaflow-on-ray/src/main/java/org/apache/geaflow/cluster/ray/utils/RaySystemFunc.java @@ -21,10 +21,10 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.JOB_WORK_PATH; -import io.ray.api.Ray; import java.io.File; import java.io.IOException; import java.io.Serializable; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.cluster.config.ClusterConfig; import org.apache.geaflow.cluster.ray.config.RayConfig; @@ -32,51 +32,55 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import io.ray.api.Ray; + public class RaySystemFunc implements Serializable { - private static final long serialVersionUID = -3708025618479190982L; + private static final long serialVersionUID = -3708025618479190982L; - private static final Logger LOGGER = LoggerFactory.getLogger(RaySystemFunc.class); + private static final Logger LOGGER = LoggerFactory.getLogger(RaySystemFunc.class); - private static final Object LOCK = new Object(); - private static String appPath; + private static final Object LOCK = new Object(); + private static String appPath; - public static boolean isRestarted() { - return Ray.getRuntimeContext().wasCurrentActorRestarted(); - } + public static boolean isRestarted() { + return Ray.getRuntimeContext().wasCurrentActorRestarted(); + } - public static boolean isLocalMode() { - return Ray.getRuntimeContext().isLocalMode(); - } + public static boolean isLocalMode() { + return Ray.getRuntimeContext().isLocalMode(); + } - public static String getWorkPath() { - if (appPath != null) { - return appPath; - } - synchronized (LOCK) { - if (Ray.getRuntimeContext().isLocalMode()) { - appPath = "/tmp/" + System.currentTimeMillis(); - try { - FileUtils.forceMkdir(new File(appPath)); - } catch (IOException e) { - LOGGER.error(e.getMessage(), e); - } - } else { - appPath = Ray.getRuntimeContext().getCurrentRuntimeEnv() - .get(RayConfig.RAY_JOB_WORKING_DIR, String.class); - } + public static String getWorkPath() { + if (appPath != null) { + return appPath; + } + synchronized (LOCK) { + if (Ray.getRuntimeContext().isLocalMode()) { + appPath = "/tmp/" + System.currentTimeMillis(); + try { + FileUtils.forceMkdir(new File(appPath)); + } catch (IOException e) { + LOGGER.error(e.getMessage(), e); } - return appPath; + } else { + appPath = + Ray.getRuntimeContext() + .getCurrentRuntimeEnv() + .get(RayConfig.RAY_JOB_WORKING_DIR, String.class); + } } + return appPath; + } - public static void initRayEnv(ClusterConfig clusterConfig) { - LOGGER.info("clusterConfig:{}", clusterConfig); - Configuration config = clusterConfig.getConfig(); + public static void initRayEnv(ClusterConfig clusterConfig) { + LOGGER.info("clusterConfig:{}", clusterConfig); + Configuration config = clusterConfig.getConfig(); - // Set working dir. - System.setProperty( - String.format("%s.%s", RayConfig.RAY_JOB_RUNTIME_ENV, RayConfig.RAY_JOB_WORKING_DIR), - config.getString(JOB_WORK_PATH)); - Ray.init(); - } + // Set working dir. + System.setProperty( + String.format("%s.%s", RayConfig.RAY_JOB_RUNTIME_ENV, RayConfig.RAY_JOB_WORKING_DIR), + config.getString(JOB_WORK_PATH)); + Ray.init(); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-ray/src/test/java/org/apache/geaflow/cluster/ray/RayJobTest.java b/geaflow/geaflow-deploy/geaflow-on-ray/src/test/java/org/apache/geaflow/cluster/ray/RayJobTest.java index 6297298b1..bb67facc9 100644 --- a/geaflow/geaflow-deploy/geaflow-on-ray/src/test/java/org/apache/geaflow/cluster/ray/RayJobTest.java +++ b/geaflow/geaflow-deploy/geaflow-on-ray/src/test/java/org/apache/geaflow/cluster/ray/RayJobTest.java @@ -19,60 +19,70 @@ package org.apache.geaflow.cluster.ray; - import java.util.ArrayList; import java.util.Arrays; import java.util.List; + import org.apache.commons.lang3.StringEscapeUtils; import org.apache.geaflow.utils.HttpUtil; import org.testng.annotations.Test; public class RayJobTest { - @Test(enabled = false) - public void testRayJobSubmitTest() { - String mainClass = "org.apache.geaflow.GeaFlowJobDemo"; - List remoteJarUrls = Arrays.asList( - "https:://public-engine-url-geaflow-0.6.zip", - "https://public-udf-url-udf.zip"); - String args = "{\"job\":{\"geaflow.config.key\":\"true\"}"; - args = StringEscapeUtils.escapeJava(args); - args = StringEscapeUtils.escapeJava("\"" + args + "\""); + @Test(enabled = false) + public void testRayJobSubmitTest() { + String mainClass = "org.apache.geaflow.GeaFlowJobDemo"; + List remoteJarUrls = + Arrays.asList( + "https:://public-engine-url-geaflow-0.6.zip", "https://public-udf-url-udf.zip"); + String args = "{\"job\":{\"geaflow.config.key\":\"true\"}"; + args = StringEscapeUtils.escapeJava(args); + args = StringEscapeUtils.escapeJava("\"" + args + "\""); - submitJob(remoteJarUrls, mainClass, args); - } + submitJob(remoteJarUrls, mainClass, args); + } - public static void submitJob(List remoteJarUrls, String mainClass, String args) { - // Cluster args. - // ray dashboard url. - String rayDashboardAddress = "127.0.0.1:8265"; - // ray redis url. - String rayRedisAddress = "127.0.0.1:6379"; - String rayDistJarPath = "path-to-ray-cluster/ray_dist.jar"; - String raySessionResourceJarPath = "/path-to-ray-cluster-session-dir/session_latest/runtime_resources/java_jars_files/"; + public static void submitJob(List remoteJarUrls, String mainClass, String args) { + // Cluster args. + // ray dashboard url. + String rayDashboardAddress = "127.0.0.1:8265"; + // ray redis url. + String rayRedisAddress = "127.0.0.1:6379"; + String rayDistJarPath = "path-to-ray-cluster/ray_dist.jar"; + String raySessionResourceJarPath = + "/path-to-ray-cluster-session-dir/session_latest/runtime_resources/java_jars_files/"; - // Job args. - List downloadJarPaths = new ArrayList<>(remoteJarUrls.size()); - for (String remoteJarUrl : remoteJarUrls) { - String str = remoteJarUrl.replace(".zip", ""); - String result = str.replaceAll("[:/.]+", "_"); - downloadJarPaths.add(raySessionResourceJarPath + result + "/*"); - } + // Job args. + List downloadJarPaths = new ArrayList<>(remoteJarUrls.size()); + for (String remoteJarUrl : remoteJarUrls) { + String str = remoteJarUrl.replace(".zip", ""); + String result = str.replaceAll("[:/.]+", "_"); + downloadJarPaths.add(raySessionResourceJarPath + result + "/*"); + } - String downloadJarClassPath = String.join(":", downloadJarPaths); - List remoteJarUrlsStr = new ArrayList<>(remoteJarUrls.size()); - for (String remoteUrl : remoteJarUrls) { - remoteJarUrlsStr.add("\"" + remoteUrl + "\""); - } - String remoteJarJsonPath = String.join(",", remoteJarUrlsStr); + String downloadJarClassPath = String.join(":", downloadJarPaths); + List remoteJarUrlsStr = new ArrayList<>(remoteJarUrls.size()); + for (String remoteUrl : remoteJarUrls) { + remoteJarUrlsStr.add("\"" + remoteUrl + "\""); + } + String remoteJarJsonPath = String.join(",", remoteJarUrlsStr); - String request = String.format("{\n" - + "\"entrypoint\": \"java -classpath %s:%s -Dlog.file=/tmp/logfile.log -Dray.address=%s %s %s\",\n" - + "\"runtime_env\": {\"java_jars\": [%s]}\n" - + "}", rayDistJarPath, downloadJarClassPath, rayRedisAddress, mainClass, args, remoteJarJsonPath); + String request = + String.format( + "{\n" + + "\"entrypoint\": \"java -classpath %s:%s -Dlog.file=/tmp/logfile.log " + + " -Dray.address=%s %s %s\",\n" + + "\"runtime_env\": {\"java_jars\": [%s]}\n" + + "}", + rayDistJarPath, + downloadJarClassPath, + rayRedisAddress, + mainClass, + args, + remoteJarJsonPath); - String postUrl = String.format("http://%s/api/jobs/", rayDashboardAddress); - System.out.println("request: \n" + request + "\npost url: \n" + postUrl); - HttpUtil.post(postUrl, request); - } + String postUrl = String.format("http://%s/api/jobs/", rayDashboardAddress); + System.out.println("request: \n" + request + "\npost url: \n" + postUrl); + HttpUtil.post(postUrl, request); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-ray/src/test/java/org/apache/geaflow/cluster/ray/failover/RayClusterFoStrategyTest.java b/geaflow/geaflow-deploy/geaflow-on-ray/src/test/java/org/apache/geaflow/cluster/ray/failover/RayClusterFoStrategyTest.java index 4f1073820..4a4c23371 100644 --- a/geaflow/geaflow-deploy/geaflow-on-ray/src/test/java/org/apache/geaflow/cluster/ray/failover/RayClusterFoStrategyTest.java +++ b/geaflow/geaflow-deploy/geaflow-on-ray/src/test/java/org/apache/geaflow/cluster/ray/failover/RayClusterFoStrategyTest.java @@ -30,20 +30,22 @@ public class RayClusterFoStrategyTest { - @Test - public void testLoad() { - Configuration configuration = new Configuration(); - IFailoverStrategy foStrategy = FailoverStrategyFactory.loadFailoverStrategy(EnvType.RAY, configuration.getString(FO_STRATEGY)); - Assert.assertNotNull(foStrategy); - Assert.assertEquals(foStrategy.getType().name(), configuration.getString(FO_STRATEGY)); - - IFailoverStrategy rayFoStrategy = FailoverStrategyFactory.loadFailoverStrategy(EnvType.RAY, "cluster_fo"); - Assert.assertNotNull(rayFoStrategy); - Assert.assertEquals(rayFoStrategy.getClass(), RayClusterFailoverStrategy.class); - - rayFoStrategy = FailoverStrategyFactory.loadFailoverStrategy(EnvType.RAY, "component_fo"); - Assert.assertNotNull(rayFoStrategy); - Assert.assertEquals(rayFoStrategy.getClass(), RayComponentFailoverStrategy.class); - } - + @Test + public void testLoad() { + Configuration configuration = new Configuration(); + IFailoverStrategy foStrategy = + FailoverStrategyFactory.loadFailoverStrategy( + EnvType.RAY, configuration.getString(FO_STRATEGY)); + Assert.assertNotNull(foStrategy); + Assert.assertEquals(foStrategy.getType().name(), configuration.getString(FO_STRATEGY)); + + IFailoverStrategy rayFoStrategy = + FailoverStrategyFactory.loadFailoverStrategy(EnvType.RAY, "cluster_fo"); + Assert.assertNotNull(rayFoStrategy); + Assert.assertEquals(rayFoStrategy.getClass(), RayClusterFailoverStrategy.class); + + rayFoStrategy = FailoverStrategyFactory.loadFailoverStrategy(EnvType.RAY, "component_fo"); + Assert.assertNotNull(rayFoStrategy); + Assert.assertEquals(rayFoStrategy.getClass(), RayComponentFailoverStrategy.class); + } } diff --git a/geaflow/geaflow-deploy/geaflow-on-ray/src/test/java/org/apache/geaflow/cluster/ray/failover/RayDisableFoStrategyTest.java b/geaflow/geaflow-deploy/geaflow-on-ray/src/test/java/org/apache/geaflow/cluster/ray/failover/RayDisableFoStrategyTest.java index 21db2209f..bb6b816cb 100644 --- a/geaflow/geaflow-deploy/geaflow-on-ray/src/test/java/org/apache/geaflow/cluster/ray/failover/RayDisableFoStrategyTest.java +++ b/geaflow/geaflow-deploy/geaflow-on-ray/src/test/java/org/apache/geaflow/cluster/ray/failover/RayDisableFoStrategyTest.java @@ -30,16 +30,18 @@ public class RayDisableFoStrategyTest { - @Test - public void testLoad() { - Configuration configuration = new Configuration(); - IFailoverStrategy foStrategy = FailoverStrategyFactory.loadFailoverStrategy(EnvType.RAY, configuration.getString(FO_STRATEGY)); - Assert.assertNotNull(foStrategy); - Assert.assertEquals(foStrategy.getType().name(), configuration.getString(FO_STRATEGY)); - - IFailoverStrategy rayDisableFoStrategy = FailoverStrategyFactory.loadFailoverStrategy(EnvType.RAY, "disable_fo"); - Assert.assertNotNull(rayDisableFoStrategy); - Assert.assertEquals(rayDisableFoStrategy.getClass(), RayDisableFailoverStrategy.class); - } + @Test + public void testLoad() { + Configuration configuration = new Configuration(); + IFailoverStrategy foStrategy = + FailoverStrategyFactory.loadFailoverStrategy( + EnvType.RAY, configuration.getString(FO_STRATEGY)); + Assert.assertNotNull(foStrategy); + Assert.assertEquals(foStrategy.getType().name(), configuration.getString(FO_STRATEGY)); + IFailoverStrategy rayDisableFoStrategy = + FailoverStrategyFactory.loadFailoverStrategy(EnvType.RAY, "disable_fo"); + Assert.assertNotNull(rayDisableFoStrategy); + Assert.assertEquals(rayDisableFoStrategy.getClass(), RayDisableFailoverStrategy.class); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/Catalog.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/Catalog.java index 0e6556ddb..851c0b3bd 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/Catalog.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/Catalog.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.catalog; import java.util.Set; + import org.apache.calcite.schema.Table; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.catalog.exception.ObjectAlreadyExistException; @@ -32,144 +33,145 @@ public interface Catalog { - String DEFAULT_INSTANCE = "default"; - - /** - * Init catalog. - * - * @param config environment config. - */ - void init(Configuration config); - - /** - * Get the catalog type name. - * - * @return catalog type name. - */ - String getType(); - - /** - * List all the instance names. - * - * @return The instance name list. - */ - Set listInstances(); - - /** - * Test if the instance name exists. - * - * @param instanceName The instance name. - * @return true if exists, else return false. - */ - boolean isInstanceExists(String instanceName); - - /** - * Get graph of the instance and graph name. - * - * @param instanceName The instance name. - * @param graphName The graph name. - * @return A {@link GeaFlowGraph}. - */ - Table getGraph(String instanceName, String graphName); - - Table getTable(String instanceName, String tableName); - - VertexTable getVertex(String instanceName, String vertexName); - - EdgeTable getEdge(String instanceName, String edgeName); - - GeaFlowFunction getFunction(String instanceName, String functionName); - - Set listGraphAndTable(String instanceName); - - /** - * Create a graph under the specified instance. - * - * @param instanceName The instance name for the graph to create. - * @param graph The graph to create. - * @throws ObjectAlreadyExistException Throwing {@link ObjectAlreadyExistException} - * when the graph has already exists. - */ - void createGraph(String instanceName, GeaFlowGraph graph) throws ObjectAlreadyExistException; - - /** - * Create a table under the specified instance. - * - * @param instanceName The instance name for the graph to create. - * @param table The table to create. - * @throws ObjectAlreadyExistException Throwing {@link ObjectAlreadyExistException} - * when the table has already exists. - */ - void createTable(String instanceName, GeaFlowTable table) throws ObjectAlreadyExistException; - - /** - * Create a view under the specified instance. - * - * @param instanceName The instance name for the graph to create. - * @param view The view to create. - * @throws ObjectAlreadyExistException Throwing {@link ObjectAlreadyExistException} - * when the view has already exists. - */ - void createView(String instanceName, GeaFlowView view) throws ObjectAlreadyExistException; - - /** - * Create a function under the specified instance. - * - * @param instanceName The instance name for the graph to create. - * @param function The function to create. - * @throws ObjectAlreadyExistException Throwing {@link ObjectAlreadyExistException} - * when the function has already exists. - */ - void createFunction(String instanceName, GeaFlowFunction function) throws ObjectAlreadyExistException; - - /** - * Drop a graph under the specified instance. - * - * @param instanceName The instance name for the graph to drop. - * @param graphName The graph name to drop. - */ - void dropGraph(String instanceName, String graphName); - - /** - * Drop a table under the specified instance. - * - * @param instanceName The instance name for the table to drop. - * @param tableName The table name to drop. - */ - void dropTable(String instanceName, String tableName); - - /** - * Drop a function under the specified instance. - * - * @param instanceName The instance name for the function to drop. - * @param functionName The function name to drop. - */ - void dropFunction(String instanceName, String functionName); - - /** - * Describe the graph information. - * - * @param instanceName The instance name for the graph. - * @param graphName The graph name to describe. - * @return The information of the graph. - */ - String describeGraph(String instanceName, String graphName); - - /** - * Describe the table information. - * - * @param instanceName The instance name for the table. - * @param tableName The table name to describe. - * @return The information of the table. - */ - String describeTable(String instanceName, String tableName); - - /** - * Describe the function information. - * - * @param instanceName The instance name for the function. - * @param functionName The function name to describe. - * @return The information of the function. - */ - String describeFunction(String instanceName, String functionName); + String DEFAULT_INSTANCE = "default"; + + /** + * Init catalog. + * + * @param config environment config. + */ + void init(Configuration config); + + /** + * Get the catalog type name. + * + * @return catalog type name. + */ + String getType(); + + /** + * List all the instance names. + * + * @return The instance name list. + */ + Set listInstances(); + + /** + * Test if the instance name exists. + * + * @param instanceName The instance name. + * @return true if exists, else return false. + */ + boolean isInstanceExists(String instanceName); + + /** + * Get graph of the instance and graph name. + * + * @param instanceName The instance name. + * @param graphName The graph name. + * @return A {@link GeaFlowGraph}. + */ + Table getGraph(String instanceName, String graphName); + + Table getTable(String instanceName, String tableName); + + VertexTable getVertex(String instanceName, String vertexName); + + EdgeTable getEdge(String instanceName, String edgeName); + + GeaFlowFunction getFunction(String instanceName, String functionName); + + Set listGraphAndTable(String instanceName); + + /** + * Create a graph under the specified instance. + * + * @param instanceName The instance name for the graph to create. + * @param graph The graph to create. + * @throws ObjectAlreadyExistException Throwing {@link ObjectAlreadyExistException} when the graph + * has already exists. + */ + void createGraph(String instanceName, GeaFlowGraph graph) throws ObjectAlreadyExistException; + + /** + * Create a table under the specified instance. + * + * @param instanceName The instance name for the graph to create. + * @param table The table to create. + * @throws ObjectAlreadyExistException Throwing {@link ObjectAlreadyExistException} when the table + * has already exists. + */ + void createTable(String instanceName, GeaFlowTable table) throws ObjectAlreadyExistException; + + /** + * Create a view under the specified instance. + * + * @param instanceName The instance name for the graph to create. + * @param view The view to create. + * @throws ObjectAlreadyExistException Throwing {@link ObjectAlreadyExistException} when the view + * has already exists. + */ + void createView(String instanceName, GeaFlowView view) throws ObjectAlreadyExistException; + + /** + * Create a function under the specified instance. + * + * @param instanceName The instance name for the graph to create. + * @param function The function to create. + * @throws ObjectAlreadyExistException Throwing {@link ObjectAlreadyExistException} when the + * function has already exists. + */ + void createFunction(String instanceName, GeaFlowFunction function) + throws ObjectAlreadyExistException; + + /** + * Drop a graph under the specified instance. + * + * @param instanceName The instance name for the graph to drop. + * @param graphName The graph name to drop. + */ + void dropGraph(String instanceName, String graphName); + + /** + * Drop a table under the specified instance. + * + * @param instanceName The instance name for the table to drop. + * @param tableName The table name to drop. + */ + void dropTable(String instanceName, String tableName); + + /** + * Drop a function under the specified instance. + * + * @param instanceName The instance name for the function to drop. + * @param functionName The function name to drop. + */ + void dropFunction(String instanceName, String functionName); + + /** + * Describe the graph information. + * + * @param instanceName The instance name for the graph. + * @param graphName The graph name to describe. + * @return The information of the graph. + */ + String describeGraph(String instanceName, String graphName); + + /** + * Describe the table information. + * + * @param instanceName The instance name for the table. + * @param tableName The table name to describe. + * @return The information of the table. + */ + String describeTable(String instanceName, String tableName); + + /** + * Describe the function information. + * + * @param instanceName The instance name for the function. + * @param functionName The function name to describe. + * @return The information of the function. + */ + String describeFunction(String instanceName, String functionName); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/CatalogFactory.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/CatalogFactory.java index fd75c7b59..bcd8faad5 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/CatalogFactory.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/CatalogFactory.java @@ -22,30 +22,35 @@ import java.util.ArrayList; import java.util.List; import java.util.ServiceLoader; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.DSLConfigKeys; import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; public class CatalogFactory { - public static Catalog getCatalog(String catalogType) { - ServiceLoader graphCatalogs = ServiceLoader.load(Catalog.class); - List catalogTypes = new ArrayList<>(); - - for (Catalog catalog : graphCatalogs) { - if (catalog.getType().equals(catalogType)) { - return catalog; - } - catalogTypes.add(catalog.getType()); - } - throw new GeaFlowDSLException("Catalog type: '" + catalogType + "' is not exists, " - + "available types are: " + catalogTypes); - } + public static Catalog getCatalog(String catalogType) { + ServiceLoader graphCatalogs = ServiceLoader.load(Catalog.class); + List catalogTypes = new ArrayList<>(); - public static Catalog getCatalog(Configuration conf) { - String catalogType = conf.getString(DSLConfigKeys.GEAFLOW_DSL_CATALOG_TYPE); - Catalog catalog = new CatalogImpl(getCatalog(catalogType)); - catalog.init(conf); + for (Catalog catalog : graphCatalogs) { + if (catalog.getType().equals(catalogType)) { return catalog; + } + catalogTypes.add(catalog.getType()); } + throw new GeaFlowDSLException( + "Catalog type: '" + + catalogType + + "' is not exists, " + + "available types are: " + + catalogTypes); + } + + public static Catalog getCatalog(Configuration conf) { + String catalogType = conf.getString(DSLConfigKeys.GEAFLOW_DSL_CATALOG_TYPE); + Catalog catalog = new CatalogImpl(getCatalog(catalogType)); + catalog.init(conf); + return catalog; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/CatalogImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/CatalogImpl.java index 78c52d9ba..60833d0bd 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/CatalogImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/CatalogImpl.java @@ -23,6 +23,7 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; + import org.apache.calcite.schema.Table; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.catalog.exception.ObjectAlreadyExistException; @@ -36,218 +37,217 @@ public class CatalogImpl implements Catalog { - private final Catalog baseCatalog; - - private final Catalog memoryCatalog; - - private final Map> tmpTables = new HashMap<>(); - - public CatalogImpl(Catalog catalog) { - this.baseCatalog = catalog; - this.memoryCatalog = new MemoryCatalog(); - } - - @Override - public void init(Configuration config) { - memoryCatalog.init(config); - baseCatalog.init(config); - } - - @Override - public String getType() { - return baseCatalog.getType(); - } - - @Override - public Set listInstances() { - Set allInstances = new HashSet<>(memoryCatalog.listInstances()); - allInstances.addAll(baseCatalog.listInstances()); - return allInstances; - } - - @Override - public boolean isInstanceExists(String instanceName) { - return memoryCatalog.isInstanceExists(instanceName) || baseCatalog.isInstanceExists(instanceName); - } - - @Override - public Table getGraph(String instanceName, String graphName) { - Table graph = memoryCatalog.getGraph(instanceName, graphName); - if (graph != null) { - return graph; - } - return baseCatalog.getGraph(instanceName, graphName); - } - - @Override - public Table getTable(String instanceName, String tableName) { - Table table = memoryCatalog.getTable(instanceName, tableName); - if (table != null) { - return table; - } - table = baseCatalog.getTable(instanceName, tableName); - if (table != null) { - return table; - } - table = getVertex(instanceName, tableName); - if (table != null) { - return table; - } - return getEdge(instanceName, tableName); - } - - @Override - public VertexTable getVertex(String instanceName, String vertexName) { - Table table = memoryCatalog.getTable(instanceName, vertexName); - if (table instanceof VertexTable) { - return (VertexTable) table; - } - return baseCatalog.getVertex(instanceName, vertexName); - } - - @Override - public EdgeTable getEdge(String instanceName, String edgeName) { - Table table = memoryCatalog.getTable(instanceName, edgeName); - if (table instanceof EdgeTable) { - return (EdgeTable) table; - } - return baseCatalog.getEdge(instanceName, edgeName); - } - - @Override - public GeaFlowFunction getFunction(String instanceName, String functionName) { - return baseCatalog.getFunction(instanceName, functionName); - } - - @Override - public Set listGraphAndTable(String instanceName) { - Set graphAndTables = new HashSet<>(memoryCatalog.listGraphAndTable(instanceName)); - graphAndTables.addAll(baseCatalog.listGraphAndTable(instanceName)); - return graphAndTables; - } - - @Override - public void createGraph(String instanceName, GeaFlowGraph graph) - throws ObjectAlreadyExistException, ObjectNotExistException { - if (getGraph(instanceName, graph.getName()) != null) { - if (!graph.isIfNotExists()) { - throw new ObjectAlreadyExistException(graph.getName()); - } - return; - } - if (!isInstanceExists(instanceName)) { - throw new ObjectNotExistException("instance: '" + instanceName + "' is not exists."); - } - // create graph - if (graph.isTemporary()) { - memoryCatalog.createGraph(instanceName, graph); - } else { - baseCatalog.createGraph(instanceName, graph); - } - // create vertex table & edge table in catalog - for (VertexTable vertexTable : graph.getVertexTables()) { - createTable(instanceName, vertexTable); - } - for (EdgeTable edgeTable : graph.getEdgeTables()) { - createTable(instanceName, edgeTable); - } - } - - @Override - public void createTable(String instanceName, GeaFlowTable table) - throws ObjectAlreadyExistException, ObjectNotExistException { - if (getTable(instanceName, table.getName()) != null) { - if (!table.isIfNotExists()) { - throw new ObjectAlreadyExistException(table.getName()); - } - // ignore if table exists. - return; - } - if (!isInstanceExists(instanceName)) { - throw new ObjectNotExistException("instance: '" + instanceName + "' is not exists."); - } - if (table.isTemporary()) { - memoryCatalog.createTable(instanceName, table); - } else { - baseCatalog.createTable(instanceName, table); - } - } - - @Override - public void createView(String instanceName, GeaFlowView view) - throws ObjectAlreadyExistException { - if (getTable(instanceName, view.getName()) != null) { - if (!view.isIfNotExists()) { - throw new ObjectAlreadyExistException(view.getName()); - } - return; - } - if (!isInstanceExists(instanceName)) { - throw new ObjectNotExistException("instance: '" + instanceName + "' is not exists."); - } - baseCatalog.createView(instanceName, view); - } - - @Override - public void createFunction(String instanceName, GeaFlowFunction function) - throws ObjectAlreadyExistException { - baseCatalog.createFunction(instanceName, function); - } - - @Override - public void dropGraph(String instanceName, String graphName) { - Table tmpGraph = memoryCatalog.getGraph(instanceName, graphName); - if (tmpGraph != null) { - memoryCatalog.dropGraph(instanceName, graphName); - } else { - Table graphFromCatalog = baseCatalog.getGraph(instanceName, graphName); - if (graphFromCatalog == null) { - throw new ObjectNotExistException(graphName); - } - baseCatalog.dropGraph(instanceName, graphName); - } - } - - @Override - public void dropTable(String instanceName, String tableName) { - Table table = memoryCatalog.getGraph(instanceName, tableName); - if (table != null) { - memoryCatalog.dropGraph(instanceName, tableName); - } else { - Table tableFromCatalog = baseCatalog.getTable(instanceName, tableName); - if (tableFromCatalog == null) { - throw new ObjectNotExistException(tableName); - } - baseCatalog.dropTable(instanceName, tableName); - } - - } - - @Override - public void dropFunction(String instanceName, String functionName) { - baseCatalog.dropFunction(instanceName, functionName); - } - - @Override - public String describeGraph(String instanceName, String graphName) { - Table graph = getGraph(instanceName, graphName); - if (graph == null) { - throw new ObjectNotExistException(graphName); - } - return graph.toString(); - } - - @Override - public String describeTable(String instanceName, String tableName) { - Table table = getTable(instanceName, tableName); - if (table == null) { - throw new ObjectNotExistException(tableName); - } - return table.toString(); - } - - @Override - public String describeFunction(String instanceName, String functionName) { - return baseCatalog.describeFunction(instanceName, functionName); - } + private final Catalog baseCatalog; + + private final Catalog memoryCatalog; + + private final Map> tmpTables = new HashMap<>(); + + public CatalogImpl(Catalog catalog) { + this.baseCatalog = catalog; + this.memoryCatalog = new MemoryCatalog(); + } + + @Override + public void init(Configuration config) { + memoryCatalog.init(config); + baseCatalog.init(config); + } + + @Override + public String getType() { + return baseCatalog.getType(); + } + + @Override + public Set listInstances() { + Set allInstances = new HashSet<>(memoryCatalog.listInstances()); + allInstances.addAll(baseCatalog.listInstances()); + return allInstances; + } + + @Override + public boolean isInstanceExists(String instanceName) { + return memoryCatalog.isInstanceExists(instanceName) + || baseCatalog.isInstanceExists(instanceName); + } + + @Override + public Table getGraph(String instanceName, String graphName) { + Table graph = memoryCatalog.getGraph(instanceName, graphName); + if (graph != null) { + return graph; + } + return baseCatalog.getGraph(instanceName, graphName); + } + + @Override + public Table getTable(String instanceName, String tableName) { + Table table = memoryCatalog.getTable(instanceName, tableName); + if (table != null) { + return table; + } + table = baseCatalog.getTable(instanceName, tableName); + if (table != null) { + return table; + } + table = getVertex(instanceName, tableName); + if (table != null) { + return table; + } + return getEdge(instanceName, tableName); + } + + @Override + public VertexTable getVertex(String instanceName, String vertexName) { + Table table = memoryCatalog.getTable(instanceName, vertexName); + if (table instanceof VertexTable) { + return (VertexTable) table; + } + return baseCatalog.getVertex(instanceName, vertexName); + } + + @Override + public EdgeTable getEdge(String instanceName, String edgeName) { + Table table = memoryCatalog.getTable(instanceName, edgeName); + if (table instanceof EdgeTable) { + return (EdgeTable) table; + } + return baseCatalog.getEdge(instanceName, edgeName); + } + + @Override + public GeaFlowFunction getFunction(String instanceName, String functionName) { + return baseCatalog.getFunction(instanceName, functionName); + } + + @Override + public Set listGraphAndTable(String instanceName) { + Set graphAndTables = new HashSet<>(memoryCatalog.listGraphAndTable(instanceName)); + graphAndTables.addAll(baseCatalog.listGraphAndTable(instanceName)); + return graphAndTables; + } + + @Override + public void createGraph(String instanceName, GeaFlowGraph graph) + throws ObjectAlreadyExistException, ObjectNotExistException { + if (getGraph(instanceName, graph.getName()) != null) { + if (!graph.isIfNotExists()) { + throw new ObjectAlreadyExistException(graph.getName()); + } + return; + } + if (!isInstanceExists(instanceName)) { + throw new ObjectNotExistException("instance: '" + instanceName + "' is not exists."); + } + // create graph + if (graph.isTemporary()) { + memoryCatalog.createGraph(instanceName, graph); + } else { + baseCatalog.createGraph(instanceName, graph); + } + // create vertex table & edge table in catalog + for (VertexTable vertexTable : graph.getVertexTables()) { + createTable(instanceName, vertexTable); + } + for (EdgeTable edgeTable : graph.getEdgeTables()) { + createTable(instanceName, edgeTable); + } + } + + @Override + public void createTable(String instanceName, GeaFlowTable table) + throws ObjectAlreadyExistException, ObjectNotExistException { + if (getTable(instanceName, table.getName()) != null) { + if (!table.isIfNotExists()) { + throw new ObjectAlreadyExistException(table.getName()); + } + // ignore if table exists. + return; + } + if (!isInstanceExists(instanceName)) { + throw new ObjectNotExistException("instance: '" + instanceName + "' is not exists."); + } + if (table.isTemporary()) { + memoryCatalog.createTable(instanceName, table); + } else { + baseCatalog.createTable(instanceName, table); + } + } + + @Override + public void createView(String instanceName, GeaFlowView view) throws ObjectAlreadyExistException { + if (getTable(instanceName, view.getName()) != null) { + if (!view.isIfNotExists()) { + throw new ObjectAlreadyExistException(view.getName()); + } + return; + } + if (!isInstanceExists(instanceName)) { + throw new ObjectNotExistException("instance: '" + instanceName + "' is not exists."); + } + baseCatalog.createView(instanceName, view); + } + + @Override + public void createFunction(String instanceName, GeaFlowFunction function) + throws ObjectAlreadyExistException { + baseCatalog.createFunction(instanceName, function); + } + + @Override + public void dropGraph(String instanceName, String graphName) { + Table tmpGraph = memoryCatalog.getGraph(instanceName, graphName); + if (tmpGraph != null) { + memoryCatalog.dropGraph(instanceName, graphName); + } else { + Table graphFromCatalog = baseCatalog.getGraph(instanceName, graphName); + if (graphFromCatalog == null) { + throw new ObjectNotExistException(graphName); + } + baseCatalog.dropGraph(instanceName, graphName); + } + } + + @Override + public void dropTable(String instanceName, String tableName) { + Table table = memoryCatalog.getGraph(instanceName, tableName); + if (table != null) { + memoryCatalog.dropGraph(instanceName, tableName); + } else { + Table tableFromCatalog = baseCatalog.getTable(instanceName, tableName); + if (tableFromCatalog == null) { + throw new ObjectNotExistException(tableName); + } + baseCatalog.dropTable(instanceName, tableName); + } + } + + @Override + public void dropFunction(String instanceName, String functionName) { + baseCatalog.dropFunction(instanceName, functionName); + } + + @Override + public String describeGraph(String instanceName, String graphName) { + Table graph = getGraph(instanceName, graphName); + if (graph == null) { + throw new ObjectNotExistException(graphName); + } + return graph.toString(); + } + + @Override + public String describeTable(String instanceName, String tableName) { + Table table = getTable(instanceName, tableName); + if (table == null) { + throw new ObjectNotExistException(tableName); + } + return table.toString(); + } + + @Override + public String describeFunction(String instanceName, String functionName) { + return baseCatalog.describeFunction(instanceName, functionName); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/CompileCatalog.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/CompileCatalog.java index 14df4d73b..ee6582c6e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/CompileCatalog.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/CompileCatalog.java @@ -23,6 +23,7 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; + import org.apache.calcite.schema.Table; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.catalog.exception.ObjectAlreadyExistException; @@ -36,191 +37,184 @@ public class CompileCatalog implements Catalog { - private final Catalog catalog; - - private final Map> allTables = new HashMap<>(); - - public CompileCatalog(Catalog catalog) { - this.catalog = catalog; - } - - @Override - public void init(Configuration config) { - - } - - @Override - public String getType() { - return catalog.getType(); - } - - @Override - public Set listInstances() { - Set allInstances = new HashSet<>(allTables.keySet()); - allInstances.addAll(catalog.listInstances()); - return allInstances; - } - - @Override - public boolean isInstanceExists(String instanceName) { - return allTables.containsKey(instanceName) || catalog.isInstanceExists(instanceName); - } - - @Override - public Table getGraph(String instanceName, String graphName) { - Table graph = allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).get(graphName); - if (graph != null) { - return graph; - } - return catalog.getGraph(instanceName, graphName); - } - - @Override - public Table getTable(String instanceName, String tableName) { - Table table = allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).get(tableName); - if (table != null) { - return table; - } - return catalog.getTable(instanceName, tableName); - } - - @Override - public VertexTable getVertex(String instanceName, String vertexName) { - Table table = allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).get(vertexName); - if (table instanceof VertexTable) { - return (VertexTable) table; - } - return catalog.getVertex(instanceName, vertexName); - } - - @Override - public EdgeTable getEdge(String instanceName, String edgeName) { - Table table = allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).get(edgeName); - if (table instanceof EdgeTable) { - return (EdgeTable) table; - } - return catalog.getEdge(instanceName, edgeName); - } - - @Override - public GeaFlowFunction getFunction(String instanceName, String functionName) { - return catalog.getFunction(instanceName, functionName); - } - - @Override - public Set listGraphAndTable(String instanceName) { - Set graphAndTables = new HashSet<>(); - Map tables = allTables.get(instanceName); - if (tables != null) { - graphAndTables.addAll(tables.keySet()); - } - graphAndTables.addAll(catalog.listGraphAndTable(instanceName)); - return graphAndTables; - } - - @Override - public void createGraph(String instanceName, GeaFlowGraph graph) - throws ObjectAlreadyExistException { - if (getGraph(instanceName, graph.getName()) != null) { - if (!graph.isIfNotExists()) { - throw new ObjectAlreadyExistException(graph.getName()); - } - return; - } - if (!isInstanceExists(instanceName)) { - throw new ObjectNotExistException("instance: '" + instanceName + "' is not exists."); - } - allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).put(graph.getName(), graph); - } - - @Override - public void createTable(String instanceName, GeaFlowTable table) - throws ObjectAlreadyExistException { - if (getTable(instanceName, table.getName()) != null) { - if (!table.isIfNotExists()) { - throw new ObjectAlreadyExistException(table.getName()); - } - // ignore if table exists. - return; - } - if (!isInstanceExists(instanceName)) { - throw new ObjectNotExistException("instance: '" + instanceName + "' is not exists."); - } - allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).put(table.getName(), table); - } - - @Override - public void createView(String instanceName, GeaFlowView view) - throws ObjectAlreadyExistException { - if (getTable(instanceName, view.getName()) != null) { - if (!view.isIfNotExists()) { - throw new ObjectAlreadyExistException(view.getName()); - } - return; - } - if (!isInstanceExists(instanceName)) { - throw new ObjectNotExistException("instance: '" + instanceName + "' is not exists."); - } - allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).put(view.getName(), view); - } - - @Override - public void createFunction(String instanceName, GeaFlowFunction function) - throws ObjectAlreadyExistException { - - } - - @Override - public void dropGraph(String instanceName, String graphName) { - Table graph = allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).get(graphName); - if (graph != null) { - allTables.get(instanceName).remove(graphName); - return; - } - Table graphFromCatalog = catalog.getGraph(instanceName, graphName); - if (graphFromCatalog == null) { - throw new ObjectNotExistException(graphName); - } - } - - @Override - public void dropTable(String instanceName, String tableName) { - Table table = allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).get(tableName); - if (table != null) { - allTables.get(instanceName).remove(tableName); - return; - } - Table tableFromCatalog = catalog.getTable(instanceName, tableName); - if (tableFromCatalog == null) { - throw new ObjectNotExistException(tableName); - } - } - - @Override - public void dropFunction(String instanceName, String functionName) { - - } - - @Override - public String describeGraph(String instanceName, String graphName) { - Table graph = getGraph(instanceName, graphName); - if (graph == null) { - throw new ObjectNotExistException(graphName); - } - return graph.toString(); - } - - @Override - public String describeTable(String instanceName, String tableName) { - Table table = getTable(instanceName, tableName); - if (table == null) { - throw new ObjectNotExistException(tableName); - } - return table.toString(); - } - - @Override - public String describeFunction(String instanceName, String functionName) { - return null; - } + private final Catalog catalog; + + private final Map> allTables = new HashMap<>(); + + public CompileCatalog(Catalog catalog) { + this.catalog = catalog; + } + + @Override + public void init(Configuration config) {} + + @Override + public String getType() { + return catalog.getType(); + } + + @Override + public Set listInstances() { + Set allInstances = new HashSet<>(allTables.keySet()); + allInstances.addAll(catalog.listInstances()); + return allInstances; + } + + @Override + public boolean isInstanceExists(String instanceName) { + return allTables.containsKey(instanceName) || catalog.isInstanceExists(instanceName); + } + + @Override + public Table getGraph(String instanceName, String graphName) { + Table graph = allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).get(graphName); + if (graph != null) { + return graph; + } + return catalog.getGraph(instanceName, graphName); + } + + @Override + public Table getTable(String instanceName, String tableName) { + Table table = allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).get(tableName); + if (table != null) { + return table; + } + return catalog.getTable(instanceName, tableName); + } + + @Override + public VertexTable getVertex(String instanceName, String vertexName) { + Table table = allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).get(vertexName); + if (table instanceof VertexTable) { + return (VertexTable) table; + } + return catalog.getVertex(instanceName, vertexName); + } + + @Override + public EdgeTable getEdge(String instanceName, String edgeName) { + Table table = allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).get(edgeName); + if (table instanceof EdgeTable) { + return (EdgeTable) table; + } + return catalog.getEdge(instanceName, edgeName); + } + + @Override + public GeaFlowFunction getFunction(String instanceName, String functionName) { + return catalog.getFunction(instanceName, functionName); + } + + @Override + public Set listGraphAndTable(String instanceName) { + Set graphAndTables = new HashSet<>(); + Map tables = allTables.get(instanceName); + if (tables != null) { + graphAndTables.addAll(tables.keySet()); + } + graphAndTables.addAll(catalog.listGraphAndTable(instanceName)); + return graphAndTables; + } + + @Override + public void createGraph(String instanceName, GeaFlowGraph graph) + throws ObjectAlreadyExistException { + if (getGraph(instanceName, graph.getName()) != null) { + if (!graph.isIfNotExists()) { + throw new ObjectAlreadyExistException(graph.getName()); + } + return; + } + if (!isInstanceExists(instanceName)) { + throw new ObjectNotExistException("instance: '" + instanceName + "' is not exists."); + } + allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).put(graph.getName(), graph); + } + + @Override + public void createTable(String instanceName, GeaFlowTable table) + throws ObjectAlreadyExistException { + if (getTable(instanceName, table.getName()) != null) { + if (!table.isIfNotExists()) { + throw new ObjectAlreadyExistException(table.getName()); + } + // ignore if table exists. + return; + } + if (!isInstanceExists(instanceName)) { + throw new ObjectNotExistException("instance: '" + instanceName + "' is not exists."); + } + allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).put(table.getName(), table); + } + + @Override + public void createView(String instanceName, GeaFlowView view) throws ObjectAlreadyExistException { + if (getTable(instanceName, view.getName()) != null) { + if (!view.isIfNotExists()) { + throw new ObjectAlreadyExistException(view.getName()); + } + return; + } + if (!isInstanceExists(instanceName)) { + throw new ObjectNotExistException("instance: '" + instanceName + "' is not exists."); + } + allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).put(view.getName(), view); + } + + @Override + public void createFunction(String instanceName, GeaFlowFunction function) + throws ObjectAlreadyExistException {} + + @Override + public void dropGraph(String instanceName, String graphName) { + Table graph = allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).get(graphName); + if (graph != null) { + allTables.get(instanceName).remove(graphName); + return; + } + Table graphFromCatalog = catalog.getGraph(instanceName, graphName); + if (graphFromCatalog == null) { + throw new ObjectNotExistException(graphName); + } + } + + @Override + public void dropTable(String instanceName, String tableName) { + Table table = allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).get(tableName); + if (table != null) { + allTables.get(instanceName).remove(tableName); + return; + } + Table tableFromCatalog = catalog.getTable(instanceName, tableName); + if (tableFromCatalog == null) { + throw new ObjectNotExistException(tableName); + } + } + + @Override + public void dropFunction(String instanceName, String functionName) {} + + @Override + public String describeGraph(String instanceName, String graphName) { + Table graph = getGraph(instanceName, graphName); + if (graph == null) { + throw new ObjectNotExistException(graphName); + } + return graph.toString(); + } + + @Override + public String describeTable(String instanceName, String tableName) { + Table table = getTable(instanceName, tableName); + if (table == null) { + throw new ObjectNotExistException(tableName); + } + return table.toString(); + } + + @Override + public String describeFunction(String instanceName, String functionName) { + return null; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/GeaFlowRootCalciteSchema.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/GeaFlowRootCalciteSchema.java index e61bb004f..2f33fe812 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/GeaFlowRootCalciteSchema.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/GeaFlowRootCalciteSchema.java @@ -19,82 +19,84 @@ package org.apache.geaflow.dsl.catalog; -import com.google.common.collect.Sets; import java.util.Collection; import java.util.Set; + import org.apache.calcite.jdbc.SimpleCalciteSchema; import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.rel.type.RelProtoDataType; import org.apache.calcite.schema.*; import org.apache.geaflow.dsl.catalog.exception.ObjectNotExistException; -public class GeaFlowRootCalciteSchema implements Schema { - - private final Catalog catalog; - - public GeaFlowRootCalciteSchema(Catalog catalog) { - this.catalog = catalog; - } - - @Override - public Table getTable(String name) { - return null; - } - - @Override - public Set getTableNames() { - return Sets.newHashSet(); - } - - @Override - public RelProtoDataType getType(String name) { - return null; - } - - @Override - public Set getTypeNames() { - return Sets.newHashSet(); - } - - @Override - public Collection getFunctions(String name) { - return Sets.newHashSet(); - } - - @Override - public Set getFunctionNames() { - return Sets.newHashSet(); - } - - @Override - public Schema getSubSchema(String name) { - if (catalog.isInstanceExists(name)) { - return new InstanceCalciteSchema(name, catalog); - } - throw new ObjectNotExistException("Instance '" + name + "' is not exists."); - } - - @Override - public Set getSubSchemaNames() { - return catalog.listInstances(); - } - - @Override - public Expression getExpression(SchemaPlus parentSchema, String name) { - return Schemas.subSchemaExpression(parentSchema, name, getClass()); - } - - @Override - public boolean isMutable() { - return true; - } +import com.google.common.collect.Sets; - @Override - public Schema snapshot(SchemaVersion version) { - return this; - } +public class GeaFlowRootCalciteSchema implements Schema { - public SchemaPlus plus() { - return new SimpleCalciteSchema(null, this, "").plus(); + private final Catalog catalog; + + public GeaFlowRootCalciteSchema(Catalog catalog) { + this.catalog = catalog; + } + + @Override + public Table getTable(String name) { + return null; + } + + @Override + public Set getTableNames() { + return Sets.newHashSet(); + } + + @Override + public RelProtoDataType getType(String name) { + return null; + } + + @Override + public Set getTypeNames() { + return Sets.newHashSet(); + } + + @Override + public Collection getFunctions(String name) { + return Sets.newHashSet(); + } + + @Override + public Set getFunctionNames() { + return Sets.newHashSet(); + } + + @Override + public Schema getSubSchema(String name) { + if (catalog.isInstanceExists(name)) { + return new InstanceCalciteSchema(name, catalog); } + throw new ObjectNotExistException("Instance '" + name + "' is not exists."); + } + + @Override + public Set getSubSchemaNames() { + return catalog.listInstances(); + } + + @Override + public Expression getExpression(SchemaPlus parentSchema, String name) { + return Schemas.subSchemaExpression(parentSchema, name, getClass()); + } + + @Override + public boolean isMutable() { + return true; + } + + @Override + public Schema snapshot(SchemaVersion version) { + return this; + } + + public SchemaPlus plus() { + return new SimpleCalciteSchema(null, this, "").plus(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/InstanceCalciteSchema.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/InstanceCalciteSchema.java index 644982a6c..b70580c02 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/InstanceCalciteSchema.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/InstanceCalciteSchema.java @@ -19,94 +19,94 @@ package org.apache.geaflow.dsl.catalog; -import com.google.common.collect.Sets; import java.util.Collection; import java.util.Collections; import java.util.Objects; import java.util.Set; + import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.rel.type.RelProtoDataType; import org.apache.calcite.schema.*; -/** - * A bridge between GeaFlow's instance catalog to calcite schema. - */ -public class InstanceCalciteSchema implements Schema { - - private final String instanceName; - - private final Catalog catalog; - - public InstanceCalciteSchema(String instanceName, Catalog catalog) { - this.instanceName = Objects.requireNonNull(instanceName); - this.catalog = Objects.requireNonNull(catalog); - } - - @Override - public Table getTable(String name) { - //At present, Calcite only has one Table data model. - // The Graph data model inherits from the Table data model. - // During validator inference, it is impossible to distinguish whether it is a graph or a - // table based on identifier. It is necessary to read the catalog separately. - Table table; - try { - table = catalog.getTable(instanceName, name); - } catch (Exception e) { - table = null; - } - if (table != null) { - return table; - } - return catalog.getGraph(instanceName, name); - } - - @Override - public Set getTableNames() { - return catalog.listGraphAndTable(instanceName); - } - - @Override - public RelProtoDataType getType(String name) { - return null; - } - - @Override - public Set getTypeNames() { - return Sets.newHashSet(); - } - - @Override - public Collection getFunctions(String name) { - return Collections.emptyList(); - } - - @Override - public Set getFunctionNames() { - return Sets.newHashSet(); - } - - @Override - public SchemaPlus getSubSchema(String name) { - return null; - } - - @Override - public Set getSubSchemaNames() { - return null; - } +import com.google.common.collect.Sets; - @Override - public Expression getExpression(SchemaPlus parentSchema, String name) { - return Schemas.subSchemaExpression(parentSchema, name, getClass()); - } +/** A bridge between GeaFlow's instance catalog to calcite schema. */ +public class InstanceCalciteSchema implements Schema { - @Override - public boolean isMutable() { - return true; + private final String instanceName; + + private final Catalog catalog; + + public InstanceCalciteSchema(String instanceName, Catalog catalog) { + this.instanceName = Objects.requireNonNull(instanceName); + this.catalog = Objects.requireNonNull(catalog); + } + + @Override + public Table getTable(String name) { + // At present, Calcite only has one Table data model. + // The Graph data model inherits from the Table data model. + // During validator inference, it is impossible to distinguish whether it is a graph or a + // table based on identifier. It is necessary to read the catalog separately. + Table table; + try { + table = catalog.getTable(instanceName, name); + } catch (Exception e) { + table = null; } - - @Override - public Schema snapshot(SchemaVersion version) { - return this; + if (table != null) { + return table; } + return catalog.getGraph(instanceName, name); + } + + @Override + public Set getTableNames() { + return catalog.listGraphAndTable(instanceName); + } + + @Override + public RelProtoDataType getType(String name) { + return null; + } + + @Override + public Set getTypeNames() { + return Sets.newHashSet(); + } + + @Override + public Collection getFunctions(String name) { + return Collections.emptyList(); + } + + @Override + public Set getFunctionNames() { + return Sets.newHashSet(); + } + + @Override + public SchemaPlus getSubSchema(String name) { + return null; + } + + @Override + public Set getSubSchemaNames() { + return null; + } + + @Override + public Expression getExpression(SchemaPlus parentSchema, String name) { + return Schemas.subSchemaExpression(parentSchema, name, getClass()); + } + + @Override + public boolean isMutable() { + return true; + } + + @Override + public Schema snapshot(SchemaVersion version) { + return this; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/MemoryCatalog.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/MemoryCatalog.java index cc1887d0e..831bd96ef 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/MemoryCatalog.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/MemoryCatalog.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; + import org.apache.calcite.schema.Table; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.catalog.exception.ObjectAlreadyExistException; @@ -37,180 +38,177 @@ public class MemoryCatalog implements Catalog { - public static final String CATALOG_TYPE = "memory"; - - private final Map> allTables = new HashMap<>(); - - public MemoryCatalog() { - allTables.put(Catalog.DEFAULT_INSTANCE, new HashMap<>()); - } - - @Override - public void init(Configuration config) { - - } - - @Override - public String getType() { - return CATALOG_TYPE; - } - - @Override - public Set listInstances() { - Set instances = new HashSet<>(allTables.keySet()); - return instances; - } - - @Override - public boolean isInstanceExists(String instanceName) { - return Objects.equals(Catalog.DEFAULT_INSTANCE, instanceName) - || allTables.containsKey(instanceName); - } - - @Override - public Table getGraph(String instanceName, String graphName) { - Map graphs = allTables.get(instanceName); - if (graphs != null) { - return graphs.get(graphName); - } - return null; - } - - @Override - public Table getTable(String instanceName, String tableName) { - Map tables = allTables.get(instanceName); - if (tables != null) { - return tables.get(tableName); - } - return null; - } - - @Override - public VertexTable getVertex(String instanceName, String vertexName) { - Map tables = allTables.get(instanceName); - if (tables != null && (tables.get(vertexName) instanceof VertexTable)) { - return (VertexTable) tables.get(vertexName); - } - return null; - } - - @Override - public EdgeTable getEdge(String instanceName, String edgeName) { - Map tables = allTables.get(instanceName); - if (tables != null && (tables.get(edgeName) instanceof EdgeTable)) { - return (EdgeTable) tables.get(edgeName); - } - return null; - } - - @Override - public GeaFlowFunction getFunction(String instanceName, String functionName) { - return null; - } - - @Override - public Set listGraphAndTable(String instanceName) { - return new HashSet<>(allTables.get(instanceName).keySet()); - } - - @Override - public void createGraph(String instanceName, GeaFlowGraph graph) throws ObjectAlreadyExistException { - if (getGraph(instanceName, graph.getName()) != null) { - if (!graph.isIfNotExists()) { - throw new ObjectAlreadyExistException(graph.getName()); - } - return; - } - allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).put(graph.getName(), graph); - } - - @Override - public void createTable(String instanceName, GeaFlowTable table) throws ObjectAlreadyExistException { - if (getTable(instanceName, table.getName()) != null) { - if (!table.isIfNotExists()) { - throw new ObjectAlreadyExistException(table.getName()); - } - // ignore if table exists. - return; - } - allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).put(table.getName(), table); - } - - @Override - public void createView(String instanceName, GeaFlowView view) throws ObjectAlreadyExistException { - if (getTable(instanceName, view.getName()) != null) { - if (!view.isIfNotExists()) { - throw new ObjectAlreadyExistException(view.getName()); - } - return; - } - allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).put(view.getName(), view); - } - - @Override - public void createFunction(String instanceName, GeaFlowFunction function) throws ObjectAlreadyExistException { - - } - - @Override - public void dropGraph(String instanceName, String graphName) { - Map graphs = allTables.get(instanceName); - if (graphs == null) { - throw new ObjectNotExistException(instanceName); - } - GeaFlowGraph graph = (GeaFlowGraph) graphs.get(graphName); - if (graph == null) { - throw new ObjectNotExistException(graphName); - } - graphs.remove(graphName); - } - - @Override - public void dropTable(String instanceName, String tableName) { - Map tables = allTables.get(instanceName); - if (tables == null) { - throw new ObjectNotExistException(instanceName); - } - GeaFlowTable table = (GeaFlowTable) tables.get(tableName); - if (table == null) { - throw new ObjectNotExistException(tableName); - } - tables.remove(tableName); - } - - @Override - public void dropFunction(String instanceName, String functionName) { - - } - - @Override - public String describeGraph(String instanceName, String graphName) { - Map graphs = allTables.get(instanceName); - if (graphs == null) { - throw new ObjectNotExistException(instanceName); - } - GeaFlowGraph graph = (GeaFlowGraph) graphs.get(graphName); - if (graph == null) { - throw new ObjectNotExistException(graphName); - } - return graph.toString(); - } - - @Override - public String describeTable(String instanceName, String tableName) { - Map tables = allTables.get(instanceName); - if (tables == null) { - throw new ObjectNotExistException(instanceName); - } - GeaFlowTable table = (GeaFlowTable) tables.get(tableName); - if (table == null) { - throw new ObjectNotExistException(tableName); - } - return table.toString(); - } - - @Override - public String describeFunction(String instanceName, String functionName) { - return null; - } + public static final String CATALOG_TYPE = "memory"; + + private final Map> allTables = new HashMap<>(); + + public MemoryCatalog() { + allTables.put(Catalog.DEFAULT_INSTANCE, new HashMap<>()); + } + + @Override + public void init(Configuration config) {} + + @Override + public String getType() { + return CATALOG_TYPE; + } + + @Override + public Set listInstances() { + Set instances = new HashSet<>(allTables.keySet()); + return instances; + } + + @Override + public boolean isInstanceExists(String instanceName) { + return Objects.equals(Catalog.DEFAULT_INSTANCE, instanceName) + || allTables.containsKey(instanceName); + } + + @Override + public Table getGraph(String instanceName, String graphName) { + Map graphs = allTables.get(instanceName); + if (graphs != null) { + return graphs.get(graphName); + } + return null; + } + + @Override + public Table getTable(String instanceName, String tableName) { + Map tables = allTables.get(instanceName); + if (tables != null) { + return tables.get(tableName); + } + return null; + } + + @Override + public VertexTable getVertex(String instanceName, String vertexName) { + Map tables = allTables.get(instanceName); + if (tables != null && (tables.get(vertexName) instanceof VertexTable)) { + return (VertexTable) tables.get(vertexName); + } + return null; + } + + @Override + public EdgeTable getEdge(String instanceName, String edgeName) { + Map tables = allTables.get(instanceName); + if (tables != null && (tables.get(edgeName) instanceof EdgeTable)) { + return (EdgeTable) tables.get(edgeName); + } + return null; + } + + @Override + public GeaFlowFunction getFunction(String instanceName, String functionName) { + return null; + } + + @Override + public Set listGraphAndTable(String instanceName) { + return new HashSet<>(allTables.get(instanceName).keySet()); + } + + @Override + public void createGraph(String instanceName, GeaFlowGraph graph) + throws ObjectAlreadyExistException { + if (getGraph(instanceName, graph.getName()) != null) { + if (!graph.isIfNotExists()) { + throw new ObjectAlreadyExistException(graph.getName()); + } + return; + } + allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).put(graph.getName(), graph); + } + + @Override + public void createTable(String instanceName, GeaFlowTable table) + throws ObjectAlreadyExistException { + if (getTable(instanceName, table.getName()) != null) { + if (!table.isIfNotExists()) { + throw new ObjectAlreadyExistException(table.getName()); + } + // ignore if table exists. + return; + } + allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).put(table.getName(), table); + } + + @Override + public void createView(String instanceName, GeaFlowView view) throws ObjectAlreadyExistException { + if (getTable(instanceName, view.getName()) != null) { + if (!view.isIfNotExists()) { + throw new ObjectAlreadyExistException(view.getName()); + } + return; + } + allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).put(view.getName(), view); + } + + @Override + public void createFunction(String instanceName, GeaFlowFunction function) + throws ObjectAlreadyExistException {} + + @Override + public void dropGraph(String instanceName, String graphName) { + Map graphs = allTables.get(instanceName); + if (graphs == null) { + throw new ObjectNotExistException(instanceName); + } + GeaFlowGraph graph = (GeaFlowGraph) graphs.get(graphName); + if (graph == null) { + throw new ObjectNotExistException(graphName); + } + graphs.remove(graphName); + } + + @Override + public void dropTable(String instanceName, String tableName) { + Map tables = allTables.get(instanceName); + if (tables == null) { + throw new ObjectNotExistException(instanceName); + } + GeaFlowTable table = (GeaFlowTable) tables.get(tableName); + if (table == null) { + throw new ObjectNotExistException(tableName); + } + tables.remove(tableName); + } + + @Override + public void dropFunction(String instanceName, String functionName) {} + + @Override + public String describeGraph(String instanceName, String graphName) { + Map graphs = allTables.get(instanceName); + if (graphs == null) { + throw new ObjectNotExistException(instanceName); + } + GeaFlowGraph graph = (GeaFlowGraph) graphs.get(graphName); + if (graph == null) { + throw new ObjectNotExistException(graphName); + } + return graph.toString(); + } + + @Override + public String describeTable(String instanceName, String tableName) { + Map tables = allTables.get(instanceName); + if (tables == null) { + throw new ObjectNotExistException(instanceName); + } + GeaFlowTable table = (GeaFlowTable) tables.get(tableName); + if (table == null) { + throw new ObjectNotExistException(tableName); + } + return table.toString(); + } + + @Override + public String describeFunction(String instanceName, String functionName) { + return null; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/AbstractDataModel.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/AbstractDataModel.java index fc88193be..14087ea28 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/AbstractDataModel.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/AbstractDataModel.java @@ -19,6 +19,4 @@ package org.apache.geaflow.dsl.catalog.console; -public abstract class AbstractDataModel extends AbstractNameModel { - -} +public abstract class AbstractDataModel extends AbstractNameModel {} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/AbstractIdModel.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/AbstractIdModel.java index 44f79f52f..60f31a7f4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/AbstractIdModel.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/AbstractIdModel.java @@ -23,73 +23,73 @@ public abstract class AbstractIdModel implements Serializable { - protected String id; + protected String id; - protected String createTime; + protected String createTime; - protected String modifyTime; + protected String modifyTime; - protected String creatorId; + protected String creatorId; - protected String creatorName; + protected String creatorName; - protected String modifierId; + protected String modifierId; - protected String modifierName; + protected String modifierName; - public String getId() { - return id; - } + public String getId() { + return id; + } - public void setId(String id) { - this.id = id; - } + public void setId(String id) { + this.id = id; + } - public String getCreateTime() { - return createTime; - } + public String getCreateTime() { + return createTime; + } - public void setCreateTime(String createTime) { - this.createTime = createTime; - } + public void setCreateTime(String createTime) { + this.createTime = createTime; + } - public String getModifyTime() { - return modifyTime; - } + public String getModifyTime() { + return modifyTime; + } - public void setModifyTime(String modifyTime) { - this.modifyTime = modifyTime; - } + public void setModifyTime(String modifyTime) { + this.modifyTime = modifyTime; + } - public String getCreatorId() { - return creatorId; - } + public String getCreatorId() { + return creatorId; + } - public void setCreatorId(String creatorId) { - this.creatorId = creatorId; - } + public void setCreatorId(String creatorId) { + this.creatorId = creatorId; + } - public String getCreatorName() { - return creatorName; - } + public String getCreatorName() { + return creatorName; + } - public void setCreatorName(String creatorName) { - this.creatorName = creatorName; - } + public void setCreatorName(String creatorName) { + this.creatorName = creatorName; + } - public String getModifierId() { - return modifierId; - } + public String getModifierId() { + return modifierId; + } - public void setModifierId(String modifierId) { - this.modifierId = modifierId; - } + public void setModifierId(String modifierId) { + this.modifierId = modifierId; + } - public String getModifierName() { - return modifierName; - } + public String getModifierName() { + return modifierName; + } - public void setModifierName(String modifierName) { - this.modifierName = modifierName; - } + public void setModifierName(String modifierName) { + this.modifierName = modifierName; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/AbstractNameModel.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/AbstractNameModel.java index e74bcd733..3b2f7b546 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/AbstractNameModel.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/AbstractNameModel.java @@ -21,32 +21,30 @@ public abstract class AbstractNameModel extends AbstractIdModel { - protected String name; + protected String name; - protected String comment; + protected String comment; - public AbstractNameModel() { + public AbstractNameModel() {} - } + public AbstractNameModel(String name, String comment) { + this.name = name; + this.comment = comment; + } - public AbstractNameModel(String name, String comment) { - this.name = name; - this.comment = comment; - } + public String getName() { + return name; + } - public String getName() { - return name; - } + public void setName(String name) { + this.name = name; + } - public void setName(String name) { - this.name = name; - } + public String getComment() { + return comment; + } - public String getComment() { - return comment; - } - - public void setComment(String comment) { - this.comment = comment; - } + public void setComment(String comment) { + this.comment = comment; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/AbstractStructModel.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/AbstractStructModel.java index c4a62840c..ebc44b192 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/AbstractStructModel.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/AbstractStructModel.java @@ -23,23 +23,23 @@ public abstract class AbstractStructModel extends AbstractDataModel { - protected GeaFlowStructType type; + protected GeaFlowStructType type; - protected List fields; + protected List fields; - public GeaFlowStructType getType() { - return type; - } + public GeaFlowStructType getType() { + return type; + } - public void setType(GeaFlowStructType type) { - this.type = type; - } + public void setType(GeaFlowStructType type) { + this.type = type; + } - public List getFields() { - return fields; - } + public List getFields() { + return fields; + } - public void setFields(List fields) { - this.fields = fields; - } + public void setFields(List fields) { + this.fields = fields; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/CatalogUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/CatalogUtil.java index 6f200a179..4209fe191 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/CatalogUtil.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/CatalogUtil.java @@ -21,12 +21,12 @@ import static org.apache.geaflow.dsl.util.SqlTypeUtil.convertTypeName; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.common.config.keys.DSLConfigKeys; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; @@ -41,232 +41,265 @@ import org.apache.geaflow.dsl.schema.GeaFlowGraph.VertexTable; import org.apache.geaflow.dsl.schema.GeaFlowTable; +import com.google.common.collect.Lists; + public class CatalogUtil { - public static TableModel convertToTableModel(GeaFlowTable table) { - TableModel tableModel = new TableModel(); - PluginConfigModel pluginConfigModel = new PluginConfigModel(); - pluginConfigModel.setType(table.getTableType()); - pluginConfigModel.setConfig(convertToTableModelConfig(table.getConfig())); - tableModel.setPluginConfig(pluginConfigModel); - tableModel.setName(table.getName()); - List fields = table.getFields(); - List fieldModelList = convertToFieldModel(fields); - tableModel.setFields(fieldModelList); - return tableModel; - } + public static TableModel convertToTableModel(GeaFlowTable table) { + TableModel tableModel = new TableModel(); + PluginConfigModel pluginConfigModel = new PluginConfigModel(); + pluginConfigModel.setType(table.getTableType()); + pluginConfigModel.setConfig(convertToTableModelConfig(table.getConfig())); + tableModel.setPluginConfig(pluginConfigModel); + tableModel.setName(table.getName()); + List fields = table.getFields(); + List fieldModelList = convertToFieldModel(fields); + tableModel.setFields(fieldModelList); + return tableModel; + } - public static GeaFlowTable convertToGeaFlowTable(TableModel model, String instanceName) { - if (model == null) { - return null; - } - List fieldModels = model.getFields(); - List fields = convertToTableField(fieldModels); - return new GeaFlowTable(instanceName, model.getName(), fields, new ArrayList<>(), - new ArrayList<>(), convertToGeaFlowTableConfig(model.getPluginConfig()), true, false); + public static GeaFlowTable convertToGeaFlowTable(TableModel model, String instanceName) { + if (model == null) { + return null; } + List fieldModels = model.getFields(); + List fields = convertToTableField(fieldModels); + return new GeaFlowTable( + instanceName, + model.getName(), + fields, + new ArrayList<>(), + new ArrayList<>(), + convertToGeaFlowTableConfig(model.getPluginConfig()), + true, + false); + } - public static VertexModel convertToVertexModel(VertexTable table) { - VertexModel vertexModel = new VertexModel(); - vertexModel.setName(table.getTypeName()); - List fields = table.getFields(); - String idFieldName = table.getIdFieldName(); - List fieldModels = new ArrayList<>(fields.size()); - for (TableField field : fields) { - GeaFlowFieldCategory fieldCategory; - if (field.getName().equals(idFieldName)) { - fieldCategory = GeaFlowFieldCategory.VERTEX_ID; - } else { - fieldCategory = GeaFlowFieldCategory.PROPERTY; - } - fieldModels.add(new FieldModel(field.getName(), null, - GeaFlowFieldType.getFieldType(field.getType()), fieldCategory)); - } - vertexModel.setFields(fieldModels); - return vertexModel; + public static VertexModel convertToVertexModel(VertexTable table) { + VertexModel vertexModel = new VertexModel(); + vertexModel.setName(table.getTypeName()); + List fields = table.getFields(); + String idFieldName = table.getIdFieldName(); + List fieldModels = new ArrayList<>(fields.size()); + for (TableField field : fields) { + GeaFlowFieldCategory fieldCategory; + if (field.getName().equals(idFieldName)) { + fieldCategory = GeaFlowFieldCategory.VERTEX_ID; + } else { + fieldCategory = GeaFlowFieldCategory.PROPERTY; + } + fieldModels.add( + new FieldModel( + field.getName(), + null, + GeaFlowFieldType.getFieldType(field.getType()), + fieldCategory)); } + vertexModel.setFields(fieldModels); + return vertexModel; + } - public static VertexTable convertToVertexTable(String instanceName, VertexModel model) { - if (model == null) { - return null; - } - List fieldModels = model.getFields(); - List fields = new ArrayList<>(fieldModels.size()); - String idFieldName = null; - for (FieldModel fieldModel : fieldModels) { - if (fieldModel.getCategory() == GeaFlowFieldCategory.VERTEX_ID) { - idFieldName = fieldModel.getName(); - } - String typeName = convertTypeName(fieldModel.getType().name()); - IType fieldType = Types.of(typeName, -1); - TableField field = new TableField(fieldModel.getName(), fieldType, false); - fields.add(field); - } - return new VertexTable(instanceName, model.getName(), fields, idFieldName); + public static VertexTable convertToVertexTable(String instanceName, VertexModel model) { + if (model == null) { + return null; } - - public static EdgeModel convertToEdgeModel(EdgeTable table) { - EdgeModel edgeModel = new EdgeModel(); - edgeModel.setName(table.getTypeName()); - List fields = table.getFields(); - List fieldModels = new ArrayList<>(fields.size()); - String srcIdFieldName = table.getSrcIdFieldName(); - String targetIdFieldName = table.getTargetIdFieldName(); - String timestampFieldName = table.getTimestampFieldName(); - for (TableField field : fields) { - GeaFlowFieldCategory fieldCategory; - if (field.getName().equals(srcIdFieldName)) { - fieldCategory = GeaFlowFieldCategory.EDGE_SOURCE_ID; - } else if (field.getName().equals(targetIdFieldName)) { - fieldCategory = GeaFlowFieldCategory.EDGE_TARGET_ID; - } else if (field.getName().equals(timestampFieldName)) { - fieldCategory = GeaFlowFieldCategory.EDGE_TIMESTAMP; - } else { - fieldCategory = GeaFlowFieldCategory.PROPERTY; - } - fieldModels.add(new FieldModel(field.getName(), null, - GeaFlowFieldType.getFieldType(field.getType()), fieldCategory)); - } - edgeModel.setFields(fieldModels); - return edgeModel; + List fieldModels = model.getFields(); + List fields = new ArrayList<>(fieldModels.size()); + String idFieldName = null; + for (FieldModel fieldModel : fieldModels) { + if (fieldModel.getCategory() == GeaFlowFieldCategory.VERTEX_ID) { + idFieldName = fieldModel.getName(); + } + String typeName = convertTypeName(fieldModel.getType().name()); + IType fieldType = Types.of(typeName, -1); + TableField field = new TableField(fieldModel.getName(), fieldType, false); + fields.add(field); } + return new VertexTable(instanceName, model.getName(), fields, idFieldName); + } - public static EdgeTable convertToEdgeTable(String instanceName, EdgeModel model) { - if (model == null) { - return null; - } - List fieldModels = model.getFields(); - List fields = new ArrayList<>(fieldModels.size()); - String srcIdFieldName = null; - String targetIdFieldName = null; - String timestampFieldName = null; - for (FieldModel fieldModel : fieldModels) { - switch (fieldModel.getCategory()) { - case EDGE_SOURCE_ID: - srcIdFieldName = fieldModel.getName(); - break; - case EDGE_TARGET_ID: - targetIdFieldName = fieldModel.getName(); - break; - case EDGE_TIMESTAMP: - timestampFieldName = fieldModel.getName(); - break; - default: - } - String typeName = convertTypeName(fieldModel.getType().name()); - IType fieldType = Types.of(typeName, -1); - TableField field = new TableField(fieldModel.getName(), fieldType, false); - fields.add(field); - } - return new EdgeTable(instanceName, model.getName(), fields, srcIdFieldName, targetIdFieldName, - timestampFieldName); + public static EdgeModel convertToEdgeModel(EdgeTable table) { + EdgeModel edgeModel = new EdgeModel(); + edgeModel.setName(table.getTypeName()); + List fields = table.getFields(); + List fieldModels = new ArrayList<>(fields.size()); + String srcIdFieldName = table.getSrcIdFieldName(); + String targetIdFieldName = table.getTargetIdFieldName(); + String timestampFieldName = table.getTimestampFieldName(); + for (TableField field : fields) { + GeaFlowFieldCategory fieldCategory; + if (field.getName().equals(srcIdFieldName)) { + fieldCategory = GeaFlowFieldCategory.EDGE_SOURCE_ID; + } else if (field.getName().equals(targetIdFieldName)) { + fieldCategory = GeaFlowFieldCategory.EDGE_TARGET_ID; + } else if (field.getName().equals(timestampFieldName)) { + fieldCategory = GeaFlowFieldCategory.EDGE_TIMESTAMP; + } else { + fieldCategory = GeaFlowFieldCategory.PROPERTY; + } + fieldModels.add( + new FieldModel( + field.getName(), + null, + GeaFlowFieldType.getFieldType(field.getType()), + fieldCategory)); } + edgeModel.setFields(fieldModels); + return edgeModel; + } - public static GraphModel convertToGraphModel(GeaFlowGraph graph) { - GraphModel graphModel = new GraphModel(); - PluginConfigModel pluginConfigModel = new PluginConfigModel(); - pluginConfigModel.setType(graph.getStoreType()); - pluginConfigModel.setConfig(convertToGraphModelConfig(graph.getConfig().getConfigMap())); - - graphModel.setPluginConfig(pluginConfigModel); - List vertexTables = graph.getVertexTables(); - List vertexModels = new ArrayList<>(vertexTables.size()); - for (VertexTable vertexTable : vertexTables) { - vertexModels.add(convertToVertexModel(vertexTable)); - } - graphModel.setVertices(vertexModels); - List edgeTables = graph.getEdgeTables(); - List edgeModels = new ArrayList<>(edgeTables.size()); - for (EdgeTable edgeTable : edgeTables) { - edgeModels.add(convertToEdgeModel(edgeTable)); - } - graphModel.setEdges(edgeModels); - graphModel.setName(graph.getName()); - return graphModel; + public static EdgeTable convertToEdgeTable(String instanceName, EdgeModel model) { + if (model == null) { + return null; } + List fieldModels = model.getFields(); + List fields = new ArrayList<>(fieldModels.size()); + String srcIdFieldName = null; + String targetIdFieldName = null; + String timestampFieldName = null; + for (FieldModel fieldModel : fieldModels) { + switch (fieldModel.getCategory()) { + case EDGE_SOURCE_ID: + srcIdFieldName = fieldModel.getName(); + break; + case EDGE_TARGET_ID: + targetIdFieldName = fieldModel.getName(); + break; + case EDGE_TIMESTAMP: + timestampFieldName = fieldModel.getName(); + break; + default: + } + String typeName = convertTypeName(fieldModel.getType().name()); + IType fieldType = Types.of(typeName, -1); + TableField field = new TableField(fieldModel.getName(), fieldType, false); + fields.add(field); + } + return new EdgeTable( + instanceName, + model.getName(), + fields, + srcIdFieldName, + targetIdFieldName, + timestampFieldName); + } - public static GeaFlowGraph convertToGeaFlowGraph(GraphModel model, String instanceName) { - if (model == null) { - return null; - } - List vertices = model.getVertices(); - List vertexTables = new ArrayList<>(vertices.size()); - for (VertexModel vertexModel : vertices) { - vertexTables.add(convertToVertexTable(instanceName, vertexModel)); - } - GraphDescriptor desc = new GraphDescriptor(); - for (VertexTable vertex : vertexTables) { - desc.addNode(new NodeDescriptor(desc.getIdName(model.getName()), - vertex.getTypeName())); - } + public static GraphModel convertToGraphModel(GeaFlowGraph graph) { + GraphModel graphModel = new GraphModel(); + PluginConfigModel pluginConfigModel = new PluginConfigModel(); + pluginConfigModel.setType(graph.getStoreType()); + pluginConfigModel.setConfig(convertToGraphModelConfig(graph.getConfig().getConfigMap())); - List edges = model.getEdges(); - List edgeTables = new ArrayList<>(edges.size()); - for (EdgeModel edgeModel : edges) { - edgeTables.add(convertToEdgeTable(instanceName, edgeModel)); - } - if (model.getEndpoints() != null) { - for (Endpoint endpoint : model.getEndpoints()) { - desc.addEdge(new EdgeDescriptor(desc.getIdName(model.getName()), - endpoint.getEdgeName(), endpoint.getSourceName(), - endpoint.getTargetName())); - } - } - GeaFlowGraph geaFlowGraph = new GeaFlowGraph(instanceName, model.getName(), vertexTables, - edgeTables, convertToGeaFlowGraphConfig(model.getPluginConfig()), Collections.emptyMap(), - true, false); - geaFlowGraph.setDescriptor(geaFlowGraph.getValidDescriptorInGraph(desc)); - return geaFlowGraph; + graphModel.setPluginConfig(pluginConfigModel); + List vertexTables = graph.getVertexTables(); + List vertexModels = new ArrayList<>(vertexTables.size()); + for (VertexTable vertexTable : vertexTables) { + vertexModels.add(convertToVertexModel(vertexTable)); } - - public static GeaFlowFunction convertToGeaFlowFunction(FunctionModel model) { - return GeaFlowFunction.of(model.getName(), Lists.newArrayList(model.getEntryClass())); + graphModel.setVertices(vertexModels); + List edgeTables = graph.getEdgeTables(); + List edgeModels = new ArrayList<>(edgeTables.size()); + for (EdgeTable edgeTable : edgeTables) { + edgeModels.add(convertToEdgeModel(edgeTable)); } + graphModel.setEdges(edgeModels); + graphModel.setName(graph.getName()); + return graphModel; + } - private static List convertToFieldModel(List fields) { - List fieldModelList = new ArrayList<>(fields.size()); - for (TableField field : fields) { - FieldModel fieldModel = new FieldModel(field.getName(), null, - GeaFlowFieldType.getFieldType(field.getType()), GeaFlowFieldCategory.PROPERTY); - fieldModelList.add(fieldModel); - } - return fieldModelList; + public static GeaFlowGraph convertToGeaFlowGraph(GraphModel model, String instanceName) { + if (model == null) { + return null; } - - private static List convertToTableField(List fieldModels) { - List fields = new ArrayList<>(fieldModels.size()); - for (FieldModel fieldModel : fieldModels) { - String typeName = convertTypeName(fieldModel.getType().name()); - IType fieldType = Types.of(typeName, -1); - TableField field = new TableField(fieldModel.getName(), fieldType, false); - fields.add(field); - } - return fields; + List vertices = model.getVertices(); + List vertexTables = new ArrayList<>(vertices.size()); + for (VertexModel vertexModel : vertices) { + vertexTables.add(convertToVertexTable(instanceName, vertexModel)); } - - private static Map convertToTableModelConfig(Map tableConfig) { - Map modelConfig = new HashMap<>(tableConfig); - modelConfig.remove(DSLConfigKeys.GEAFLOW_DSL_TABLE_TYPE.getKey()); - return modelConfig; + GraphDescriptor desc = new GraphDescriptor(); + for (VertexTable vertex : vertexTables) { + desc.addNode(new NodeDescriptor(desc.getIdName(model.getName()), vertex.getTypeName())); } - private static Map convertToGeaFlowTableConfig(PluginConfigModel configModel) { - Map tableConfig = new HashMap<>(configModel.getConfig()); - tableConfig.put(DSLConfigKeys.GEAFLOW_DSL_TABLE_TYPE.getKey(), - configModel.getType()); - return tableConfig; + List edges = model.getEdges(); + List edgeTables = new ArrayList<>(edges.size()); + for (EdgeModel edgeModel : edges) { + edgeTables.add(convertToEdgeTable(instanceName, edgeModel)); } + if (model.getEndpoints() != null) { + for (Endpoint endpoint : model.getEndpoints()) { + desc.addEdge( + new EdgeDescriptor( + desc.getIdName(model.getName()), + endpoint.getEdgeName(), + endpoint.getSourceName(), + endpoint.getTargetName())); + } + } + GeaFlowGraph geaFlowGraph = + new GeaFlowGraph( + instanceName, + model.getName(), + vertexTables, + edgeTables, + convertToGeaFlowGraphConfig(model.getPluginConfig()), + Collections.emptyMap(), + true, + false); + geaFlowGraph.setDescriptor(geaFlowGraph.getValidDescriptorInGraph(desc)); + return geaFlowGraph; + } + + public static GeaFlowFunction convertToGeaFlowFunction(FunctionModel model) { + return GeaFlowFunction.of(model.getName(), Lists.newArrayList(model.getEntryClass())); + } - private static Map convertToGraphModelConfig(Map graphConfig) { - Map modelConfig = new HashMap<>(graphConfig); - modelConfig.remove(DSLConfigKeys.GEAFLOW_DSL_STORE_TYPE.getKey()); - return modelConfig; + private static List convertToFieldModel(List fields) { + List fieldModelList = new ArrayList<>(fields.size()); + for (TableField field : fields) { + FieldModel fieldModel = + new FieldModel( + field.getName(), + null, + GeaFlowFieldType.getFieldType(field.getType()), + GeaFlowFieldCategory.PROPERTY); + fieldModelList.add(fieldModel); } + return fieldModelList; + } - private static Map convertToGeaFlowGraphConfig(PluginConfigModel configModel) { - Map graphConfig = new HashMap<>(configModel.getConfig()); - graphConfig.put(DSLConfigKeys.GEAFLOW_DSL_STORE_TYPE.getKey(), - configModel.getType()); - return graphConfig; + private static List convertToTableField(List fieldModels) { + List fields = new ArrayList<>(fieldModels.size()); + for (FieldModel fieldModel : fieldModels) { + String typeName = convertTypeName(fieldModel.getType().name()); + IType fieldType = Types.of(typeName, -1); + TableField field = new TableField(fieldModel.getName(), fieldType, false); + fields.add(field); } + return fields; + } + + private static Map convertToTableModelConfig(Map tableConfig) { + Map modelConfig = new HashMap<>(tableConfig); + modelConfig.remove(DSLConfigKeys.GEAFLOW_DSL_TABLE_TYPE.getKey()); + return modelConfig; + } + + private static Map convertToGeaFlowTableConfig(PluginConfigModel configModel) { + Map tableConfig = new HashMap<>(configModel.getConfig()); + tableConfig.put(DSLConfigKeys.GEAFLOW_DSL_TABLE_TYPE.getKey(), configModel.getType()); + return tableConfig; + } + + private static Map convertToGraphModelConfig(Map graphConfig) { + Map modelConfig = new HashMap<>(graphConfig); + modelConfig.remove(DSLConfigKeys.GEAFLOW_DSL_STORE_TYPE.getKey()); + return modelConfig; + } + + private static Map convertToGeaFlowGraphConfig(PluginConfigModel configModel) { + Map graphConfig = new HashMap<>(configModel.getConfig()); + graphConfig.put(DSLConfigKeys.GEAFLOW_DSL_STORE_TYPE.getKey(), configModel.getType()); + return graphConfig; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/ConsoleCatalog.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/ConsoleCatalog.java index d72467753..b72486924 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/ConsoleCatalog.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/ConsoleCatalog.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.Map; import java.util.Set; + import org.apache.calcite.schema.Table; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.catalog.Catalog; @@ -39,192 +40,188 @@ public class ConsoleCatalog implements Catalog { - public static final String CATALOG_TYPE = "console"; + public static final String CATALOG_TYPE = "console"; - private final Map> allTables = new HashMap<>(); + private final Map> allTables = new HashMap<>(); - private final Map> allFunctions = new HashMap<>(); + private final Map> allFunctions = new HashMap<>(); - private Set allInstances; + private Set allInstances; - private Set allGraphsAndTables; + private Set allGraphsAndTables; - private ConsoleCatalogClient client; + private ConsoleCatalogClient client; - @Override - public void init(Configuration config) { - this.client = new ConsoleCatalogClient(config); - } + @Override + public void init(Configuration config) { + this.client = new ConsoleCatalogClient(config); + } - @Override - public String getType() { - return CATALOG_TYPE; - } + @Override + public String getType() { + return CATALOG_TYPE; + } - @Override - public Set listInstances() { - if (allInstances == null) { - allInstances = client.getInstances(); - } - return allInstances; + @Override + public Set listInstances() { + if (allInstances == null) { + allInstances = client.getInstances(); } + return allInstances; + } - @Override - public boolean isInstanceExists(String instanceName) { - Set allInstances = listInstances(); - return allInstances.contains(instanceName); - } + @Override + public boolean isInstanceExists(String instanceName) { + Set allInstances = listInstances(); + return allInstances.contains(instanceName); + } - @Override - public Table getGraph(String instanceName, String graphName) { - return allTables - .computeIfAbsent(instanceName, k -> new HashMap<>()) - .computeIfAbsent(graphName, k -> client.getGraph(instanceName, graphName)); - } + @Override + public Table getGraph(String instanceName, String graphName) { + return allTables + .computeIfAbsent(instanceName, k -> new HashMap<>()) + .computeIfAbsent(graphName, k -> client.getGraph(instanceName, graphName)); + } - @Override - public Table getTable(String instanceName, String tableName) { - return allTables - .computeIfAbsent(instanceName, k -> new HashMap<>()) - .computeIfAbsent(tableName, k -> client.getTable(instanceName, tableName)); - } + @Override + public Table getTable(String instanceName, String tableName) { + return allTables + .computeIfAbsent(instanceName, k -> new HashMap<>()) + .computeIfAbsent(tableName, k -> client.getTable(instanceName, tableName)); + } - @Override - public VertexTable getVertex(String instanceName, String vertexName) { - return (VertexTable) allTables + @Override + public VertexTable getVertex(String instanceName, String vertexName) { + return (VertexTable) + allTables .computeIfAbsent(instanceName, k -> new HashMap<>()) .computeIfAbsent(vertexName, k -> client.getVertex(instanceName, vertexName)); - } + } - @Override - public EdgeTable getEdge(String instanceName, String edgeName) { - return (EdgeTable) allTables + @Override + public EdgeTable getEdge(String instanceName, String edgeName) { + return (EdgeTable) + allTables .computeIfAbsent(instanceName, k -> new HashMap<>()) .computeIfAbsent(edgeName, k -> client.getEdge(instanceName, edgeName)); - } - - @Override - public GeaFlowFunction getFunction(String instanceName, String functionName) { - return allFunctions - .computeIfAbsent(instanceName, k -> new HashMap<>()) - .computeIfAbsent(functionName, k -> client.getFunction(instanceName, functionName)); - } - - @Override - public Set listGraphAndTable(String instanceName) { - if (allGraphsAndTables == null) { - allGraphsAndTables = new HashSet<>(); - allGraphsAndTables.addAll(client.getGraphs(instanceName)); - allGraphsAndTables.addAll(client.getTables(instanceName)); - } - return allGraphsAndTables; - } - - @Override - public void createGraph(String instanceName, GeaFlowGraph graph) - throws ObjectAlreadyExistException { - if (getGraph(instanceName, graph.getName()) != null) { - if (!graph.isIfNotExists()) { - throw new ObjectAlreadyExistException(graph.getName()); - } - return; - } - client.createGraph(instanceName, graph); - Map> edgeType2SourceTypes = new HashMap<>(); - Map> edgeType2TargetTypes = new HashMap<>(); - for (EdgeDescriptor edgeDescriptor : graph.getDescriptor().edges) { - edgeType2SourceTypes.computeIfAbsent(edgeDescriptor.type, t -> new ArrayList<>()); - edgeType2TargetTypes.computeIfAbsent(edgeDescriptor.type, t -> new ArrayList<>()); - edgeType2SourceTypes.get(edgeDescriptor.type).add(edgeDescriptor.sourceType); - edgeType2TargetTypes.get(edgeDescriptor.type).add(edgeDescriptor.targetType); - } - List edgeTypes = new ArrayList<>(); - List sourceVertexTypes = new ArrayList<>(); - List targetVertexTypes = new ArrayList<>(); - for (String edgeType : edgeType2SourceTypes.keySet()) { - for (int i = 0; i < edgeType2SourceTypes.get(edgeType).size(); i++) { - edgeTypes.add(edgeType); - sourceVertexTypes.add(edgeType2SourceTypes.get(edgeType).get(i)); - targetVertexTypes.add(edgeType2TargetTypes.get(edgeType).get(i)); - } - } - client.createEdgeEndpoints(instanceName, graph.getName(), edgeTypes, - sourceVertexTypes, targetVertexTypes); - allTables.get(instanceName).put(graph.getName(), graph); - } - - @Override - public void createTable(String instanceName, GeaFlowTable table) - throws ObjectAlreadyExistException { - if (getTable(instanceName, table.getName()) != null) { - if (!table.isIfNotExists()) { - throw new ObjectAlreadyExistException(table.getName()); - } - // ignore if table exists. - return; - } - client.createTable(instanceName, table); - allTables.get(instanceName).put(table.getName(), table); - } - - @Override - public void createView(String instanceName, GeaFlowView view) - throws ObjectAlreadyExistException { - Map tableMap = allTables - .computeIfAbsent(instanceName, k -> new HashMap<>()); - Table geaFlowView = tableMap.get(view.getName()); - if (geaFlowView != null) { - if (!view.isIfNotExists()) { - throw new ObjectAlreadyExistException(view.getName()); - } - return; - } - tableMap.put(view.getName(), view); - } - - @Override - public void createFunction(String instanceName, GeaFlowFunction function) - throws ObjectAlreadyExistException { - - } - - @Override - public void dropGraph(String instanceName, String graphName) { - client.deleteGraph(instanceName, graphName); - allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).remove(graphName); - } - - @Override - public void dropTable(String instanceName, String tableName) { - client.deleteTable(instanceName, tableName); - allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).remove(tableName); - } - - @Override - public void dropFunction(String instanceName, String functionName) { - - } - - @Override - public String describeGraph(String instanceName, String graphName) { - Table graph = getGraph(instanceName, graphName); - if (graph != null) { - return graph.toString(); - } - return null; - } - - @Override - public String describeTable(String instanceName, String tableName) { - Table table = getTable(instanceName, tableName); - if (table != null) { - return table.toString(); - } - return null; - } - - @Override - public String describeFunction(String instanceName, String functionName) { - return null; - } + } + + @Override + public GeaFlowFunction getFunction(String instanceName, String functionName) { + return allFunctions + .computeIfAbsent(instanceName, k -> new HashMap<>()) + .computeIfAbsent(functionName, k -> client.getFunction(instanceName, functionName)); + } + + @Override + public Set listGraphAndTable(String instanceName) { + if (allGraphsAndTables == null) { + allGraphsAndTables = new HashSet<>(); + allGraphsAndTables.addAll(client.getGraphs(instanceName)); + allGraphsAndTables.addAll(client.getTables(instanceName)); + } + return allGraphsAndTables; + } + + @Override + public void createGraph(String instanceName, GeaFlowGraph graph) + throws ObjectAlreadyExistException { + if (getGraph(instanceName, graph.getName()) != null) { + if (!graph.isIfNotExists()) { + throw new ObjectAlreadyExistException(graph.getName()); + } + return; + } + client.createGraph(instanceName, graph); + Map> edgeType2SourceTypes = new HashMap<>(); + Map> edgeType2TargetTypes = new HashMap<>(); + for (EdgeDescriptor edgeDescriptor : graph.getDescriptor().edges) { + edgeType2SourceTypes.computeIfAbsent(edgeDescriptor.type, t -> new ArrayList<>()); + edgeType2TargetTypes.computeIfAbsent(edgeDescriptor.type, t -> new ArrayList<>()); + edgeType2SourceTypes.get(edgeDescriptor.type).add(edgeDescriptor.sourceType); + edgeType2TargetTypes.get(edgeDescriptor.type).add(edgeDescriptor.targetType); + } + List edgeTypes = new ArrayList<>(); + List sourceVertexTypes = new ArrayList<>(); + List targetVertexTypes = new ArrayList<>(); + for (String edgeType : edgeType2SourceTypes.keySet()) { + for (int i = 0; i < edgeType2SourceTypes.get(edgeType).size(); i++) { + edgeTypes.add(edgeType); + sourceVertexTypes.add(edgeType2SourceTypes.get(edgeType).get(i)); + targetVertexTypes.add(edgeType2TargetTypes.get(edgeType).get(i)); + } + } + client.createEdgeEndpoints( + instanceName, graph.getName(), edgeTypes, sourceVertexTypes, targetVertexTypes); + allTables.get(instanceName).put(graph.getName(), graph); + } + + @Override + public void createTable(String instanceName, GeaFlowTable table) + throws ObjectAlreadyExistException { + if (getTable(instanceName, table.getName()) != null) { + if (!table.isIfNotExists()) { + throw new ObjectAlreadyExistException(table.getName()); + } + // ignore if table exists. + return; + } + client.createTable(instanceName, table); + allTables.get(instanceName).put(table.getName(), table); + } + + @Override + public void createView(String instanceName, GeaFlowView view) throws ObjectAlreadyExistException { + Map tableMap = allTables.computeIfAbsent(instanceName, k -> new HashMap<>()); + Table geaFlowView = tableMap.get(view.getName()); + if (geaFlowView != null) { + if (!view.isIfNotExists()) { + throw new ObjectAlreadyExistException(view.getName()); + } + return; + } + tableMap.put(view.getName(), view); + } + + @Override + public void createFunction(String instanceName, GeaFlowFunction function) + throws ObjectAlreadyExistException {} + + @Override + public void dropGraph(String instanceName, String graphName) { + client.deleteGraph(instanceName, graphName); + allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).remove(graphName); + } + + @Override + public void dropTable(String instanceName, String tableName) { + client.deleteTable(instanceName, tableName); + allTables.computeIfAbsent(instanceName, k -> new HashMap<>()).remove(tableName); + } + + @Override + public void dropFunction(String instanceName, String functionName) {} + + @Override + public String describeGraph(String instanceName, String graphName) { + Table graph = getGraph(instanceName, graphName); + if (graph != null) { + return graph.toString(); + } + return null; + } + + @Override + public String describeTable(String instanceName, String tableName) { + Table table = getTable(instanceName, tableName); + if (table != null) { + return table.toString(); + } + return null; + } + + @Override + public String describeFunction(String instanceName, String functionName) { + return null; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/ConsoleCatalogClient.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/ConsoleCatalogClient.java index b9f528c1b..8be196756 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/ConsoleCatalogClient.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/ConsoleCatalogClient.java @@ -21,14 +21,13 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.GEAFLOW_GW_ENDPOINT; -import com.google.common.reflect.TypeToken; -import com.google.gson.Gson; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; + import org.apache.commons.lang.StringUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.DSLConfigKeys; @@ -39,139 +38,144 @@ import org.apache.geaflow.dsl.schema.GeaFlowTable; import org.apache.geaflow.utils.HttpUtil; -public class ConsoleCatalogClient { - - private final Map headers = new HashMap<>(); - - private final String endpoint; - - private final Gson gson; - - public ConsoleCatalogClient(Configuration config) { - String token = config.getString(DSLConfigKeys.GEAFLOW_DSL_CATALOG_TOKEN_KEY); - this.headers.put("geaflow-token", token); - this.endpoint = config.getString(GEAFLOW_GW_ENDPOINT); - this.gson = new Gson(); - } - - public void createTable(String instanceName, GeaFlowTable table) { - String createUrl = endpoint + "/api/instances/" + instanceName + "/tables"; - TableModel tableModel = CatalogUtil.convertToTableModel(table); - HttpUtil.post(createUrl, gson.toJson(tableModel), headers, Object.class); - } - - public void createGraph(String instanceName, GeaFlowGraph graph) { - String createUrl = endpoint + "/api/instances/" + instanceName + "/graphs"; - GraphModel graphModel = CatalogUtil.convertToGraphModel(graph); - HttpUtil.post(createUrl, gson.toJson(graphModel), headers, Object.class); - } - - public void createEdgeEndpoints(String instanceName, String graphName, List edgeNames, - List srcVertexNames, List targetVertexNames) { - assert srcVertexNames != null && targetVertexNames != null - && srcVertexNames.size() == targetVertexNames.size(); - String createUrl = endpoint + "/api/instances/" + instanceName + "/graphs/" + graphName - + "/endpoints"; - List endpointWrappers = new ArrayList<>(); - for (int i = 0; i < srcVertexNames.size(); i++) { - endpointWrappers.add(new EndpointWrapper(srcVertexNames.get(i), - targetVertexNames.get(i), edgeNames.get(i))); - } - HttpUtil.post(createUrl, gson.toJson(endpointWrappers), headers, Object.class); - } - - public static class EndpointWrapper { - - public final String sourceName; - public final String targetName; - public final String edgeName; - - public EndpointWrapper(String sourceName, String targetName, String edgeName) { - assert !StringUtils.isBlank(sourceName) && !StringUtils.isBlank(targetName) - && !StringUtils.isBlank(edgeName); - this.sourceName = sourceName; - this.targetName = targetName; - this.edgeName = edgeName; - } - - } - - public GeaFlowTable getTable(String instanceName, String tableName) { - String getUrl = endpoint + "/api/instances/" + instanceName + "/tables/" + tableName; - TableModel tableModel = HttpUtil.get(getUrl, headers, TableModel.class); - return CatalogUtil.convertToGeaFlowTable(tableModel, instanceName); - } - - public GeaFlowGraph getGraph(String instanceName, String graphName) { - String getUrl = endpoint + "/api/instances/" + instanceName + "/graphs/" + graphName; - GraphModel graphModel = HttpUtil.get(getUrl, headers, GraphModel.class); - return CatalogUtil.convertToGeaFlowGraph(graphModel, instanceName); - } - - public VertexTable getVertex(String instanceName, String vertexName) { - String getUrl = endpoint + "/api/instances/" + instanceName + "/vertices/" + vertexName; - VertexModel vertexModel = HttpUtil.get(getUrl, headers, VertexModel.class); - return CatalogUtil.convertToVertexTable(instanceName, vertexModel); - } - - public EdgeTable getEdge(String instanceName, String edgeName) { - String getUrl = endpoint + "/api/instances/" + instanceName + "/edges/" + edgeName; - EdgeModel edgeModel = HttpUtil.get(getUrl, headers, EdgeModel.class); - return CatalogUtil.convertToEdgeTable(instanceName, edgeModel); - } +import com.google.common.reflect.TypeToken; +import com.google.gson.Gson; - public GeaFlowFunction getFunction(String instanceName, String functionName) { - String getUrl = endpoint + "/api/instances/" + instanceName + "/functions/" + functionName; - FunctionModel functionModel = HttpUtil.get(getUrl, headers, FunctionModel.class); - return CatalogUtil.convertToGeaFlowFunction(functionModel); - } +public class ConsoleCatalogClient { - public void deleteTable(String instanceName, String tableName) { - String deleteUrl = endpoint + "/api/instances/" + instanceName + "/tables/" + tableName; - HttpUtil.delete(deleteUrl, headers); + private final Map headers = new HashMap<>(); + + private final String endpoint; + + private final Gson gson; + + public ConsoleCatalogClient(Configuration config) { + String token = config.getString(DSLConfigKeys.GEAFLOW_DSL_CATALOG_TOKEN_KEY); + this.headers.put("geaflow-token", token); + this.endpoint = config.getString(GEAFLOW_GW_ENDPOINT); + this.gson = new Gson(); + } + + public void createTable(String instanceName, GeaFlowTable table) { + String createUrl = endpoint + "/api/instances/" + instanceName + "/tables"; + TableModel tableModel = CatalogUtil.convertToTableModel(table); + HttpUtil.post(createUrl, gson.toJson(tableModel), headers, Object.class); + } + + public void createGraph(String instanceName, GeaFlowGraph graph) { + String createUrl = endpoint + "/api/instances/" + instanceName + "/graphs"; + GraphModel graphModel = CatalogUtil.convertToGraphModel(graph); + HttpUtil.post(createUrl, gson.toJson(graphModel), headers, Object.class); + } + + public void createEdgeEndpoints( + String instanceName, + String graphName, + List edgeNames, + List srcVertexNames, + List targetVertexNames) { + assert srcVertexNames != null + && targetVertexNames != null + && srcVertexNames.size() == targetVertexNames.size(); + String createUrl = + endpoint + "/api/instances/" + instanceName + "/graphs/" + graphName + "/endpoints"; + List endpointWrappers = new ArrayList<>(); + for (int i = 0; i < srcVertexNames.size(); i++) { + endpointWrappers.add( + new EndpointWrapper(srcVertexNames.get(i), targetVertexNames.get(i), edgeNames.get(i))); } - - public void deleteGraph(String instanceName, String graphName) { - String deleteUrl = endpoint + "/api/instances/" + instanceName + "/graphs/" + graphName; - HttpUtil.delete(deleteUrl, headers); + HttpUtil.post(createUrl, gson.toJson(endpointWrappers), headers, Object.class); + } + + public static class EndpointWrapper { + + public final String sourceName; + public final String targetName; + public final String edgeName; + + public EndpointWrapper(String sourceName, String targetName, String edgeName) { + assert !StringUtils.isBlank(sourceName) + && !StringUtils.isBlank(targetName) + && !StringUtils.isBlank(edgeName); + this.sourceName = sourceName; + this.targetName = targetName; + this.edgeName = edgeName; } - - public Set getTables(String instanceName) { - String getUrl = endpoint + "/api/instances/" + instanceName + "/tables"; - Set tableNames = new HashSet<>(); - PageList tableModels = HttpUtil.get(getUrl, headers, - new TypeToken>() { - }.getType()); - List tableModelsList = tableModels.getList(); - for (TableModel tableModel : tableModelsList) { - tableNames.add(tableModel.getName()); - } - return tableNames; + } + + public GeaFlowTable getTable(String instanceName, String tableName) { + String getUrl = endpoint + "/api/instances/" + instanceName + "/tables/" + tableName; + TableModel tableModel = HttpUtil.get(getUrl, headers, TableModel.class); + return CatalogUtil.convertToGeaFlowTable(tableModel, instanceName); + } + + public GeaFlowGraph getGraph(String instanceName, String graphName) { + String getUrl = endpoint + "/api/instances/" + instanceName + "/graphs/" + graphName; + GraphModel graphModel = HttpUtil.get(getUrl, headers, GraphModel.class); + return CatalogUtil.convertToGeaFlowGraph(graphModel, instanceName); + } + + public VertexTable getVertex(String instanceName, String vertexName) { + String getUrl = endpoint + "/api/instances/" + instanceName + "/vertices/" + vertexName; + VertexModel vertexModel = HttpUtil.get(getUrl, headers, VertexModel.class); + return CatalogUtil.convertToVertexTable(instanceName, vertexModel); + } + + public EdgeTable getEdge(String instanceName, String edgeName) { + String getUrl = endpoint + "/api/instances/" + instanceName + "/edges/" + edgeName; + EdgeModel edgeModel = HttpUtil.get(getUrl, headers, EdgeModel.class); + return CatalogUtil.convertToEdgeTable(instanceName, edgeModel); + } + + public GeaFlowFunction getFunction(String instanceName, String functionName) { + String getUrl = endpoint + "/api/instances/" + instanceName + "/functions/" + functionName; + FunctionModel functionModel = HttpUtil.get(getUrl, headers, FunctionModel.class); + return CatalogUtil.convertToGeaFlowFunction(functionModel); + } + + public void deleteTable(String instanceName, String tableName) { + String deleteUrl = endpoint + "/api/instances/" + instanceName + "/tables/" + tableName; + HttpUtil.delete(deleteUrl, headers); + } + + public void deleteGraph(String instanceName, String graphName) { + String deleteUrl = endpoint + "/api/instances/" + instanceName + "/graphs/" + graphName; + HttpUtil.delete(deleteUrl, headers); + } + + public Set getTables(String instanceName) { + String getUrl = endpoint + "/api/instances/" + instanceName + "/tables"; + Set tableNames = new HashSet<>(); + PageList tableModels = + HttpUtil.get(getUrl, headers, new TypeToken>() {}.getType()); + List tableModelsList = tableModels.getList(); + for (TableModel tableModel : tableModelsList) { + tableNames.add(tableModel.getName()); } - - public Set getGraphs(String instanceName) { - String getUrl = endpoint + "/api/instances/" + instanceName + "/graphs"; - Set graphNames = new HashSet<>(); - PageList graphModels = HttpUtil.get(getUrl, headers, - new TypeToken>() { - }.getType()); - List graphModelsList = graphModels.getList(); - for (GraphModel graphModel : graphModelsList) { - graphNames.add(graphModel.getName()); - } - return graphNames; + return tableNames; + } + + public Set getGraphs(String instanceName) { + String getUrl = endpoint + "/api/instances/" + instanceName + "/graphs"; + Set graphNames = new HashSet<>(); + PageList graphModels = + HttpUtil.get(getUrl, headers, new TypeToken>() {}.getType()); + List graphModelsList = graphModels.getList(); + for (GraphModel graphModel : graphModelsList) { + graphNames.add(graphModel.getName()); } - - public Set getInstances() { - String getUrl = endpoint + "/api/instances"; - Set instanceNames = new HashSet<>(); - PageList instanceModels = HttpUtil.get(getUrl, headers, - new TypeToken>() { - }.getType()); - List instanceModelsList = instanceModels.getList(); - for (InstanceModel instanceModel : instanceModelsList) { - instanceNames.add(instanceModel.getName()); - } - return instanceNames; + return graphNames; + } + + public Set getInstances() { + String getUrl = endpoint + "/api/instances"; + Set instanceNames = new HashSet<>(); + PageList instanceModels = + HttpUtil.get(getUrl, headers, new TypeToken>() {}.getType()); + List instanceModelsList = instanceModels.getList(); + for (InstanceModel instanceModel : instanceModelsList) { + instanceNames.add(instanceModel.getName()); } + return instanceNames; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/EdgeModel.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/EdgeModel.java index 7bcbd98b6..20407e6b7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/EdgeModel.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/EdgeModel.java @@ -21,7 +21,7 @@ public class EdgeModel extends AbstractStructModel { - public EdgeModel() { - type = GeaFlowStructType.EDGE; - } + public EdgeModel() { + type = GeaFlowStructType.EDGE; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/FieldModel.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/FieldModel.java index 590495f02..4cd39d9b6 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/FieldModel.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/FieldModel.java @@ -21,29 +21,30 @@ public class FieldModel extends AbstractNameModel { - private GeaFlowFieldType type; + private GeaFlowFieldType type; - private GeaFlowFieldCategory category; + private GeaFlowFieldCategory category; - public FieldModel(String name, String comment, GeaFlowFieldType type, GeaFlowFieldCategory category) { - super(name, comment); - this.type = type; - this.category = category; - } + public FieldModel( + String name, String comment, GeaFlowFieldType type, GeaFlowFieldCategory category) { + super(name, comment); + this.type = type; + this.category = category; + } - public GeaFlowFieldType getType() { - return type; - } + public GeaFlowFieldType getType() { + return type; + } - public void setType(GeaFlowFieldType type) { - this.type = type; - } + public void setType(GeaFlowFieldType type) { + this.type = type; + } - public GeaFlowFieldCategory getCategory() { - return category; - } + public GeaFlowFieldCategory getCategory() { + return category; + } - public void setCategory(GeaFlowFieldCategory category) { - this.category = category; - } + public void setCategory(GeaFlowFieldCategory category) { + this.category = category; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/FunctionModel.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/FunctionModel.java index 0e98e01ca..424cb60f7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/FunctionModel.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/FunctionModel.java @@ -21,13 +21,13 @@ public class FunctionModel extends AbstractDataModel { - private String entryClass; + private String entryClass; - public String getEntryClass() { - return entryClass; - } + public String getEntryClass() { + return entryClass; + } - public void setEntryClass(String entryClass) { - this.entryClass = entryClass; - } + public void setEntryClass(String entryClass) { + this.entryClass = entryClass; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/GeaFlowEdgeDirection.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/GeaFlowEdgeDirection.java index 80a4befba..782a2de6e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/GeaFlowEdgeDirection.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/GeaFlowEdgeDirection.java @@ -19,20 +19,12 @@ package org.apache.geaflow.dsl.catalog.console; -/** - * The edge direction defined on the console platform. - */ +/** The edge direction defined on the console platform. */ public enum GeaFlowEdgeDirection { - /** - * Out direction. - */ - OUT, - /** - * In direction. - */ - IN, - /** - * Both direction. - */ - BOTH + /** Out direction. */ + OUT, + /** In direction. */ + IN, + /** Both direction. */ + BOTH } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/GeaFlowFieldCategory.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/GeaFlowFieldCategory.java index 073e1d833..6fc3f3c76 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/GeaFlowFieldCategory.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/GeaFlowFieldCategory.java @@ -19,61 +19,44 @@ package org.apache.geaflow.dsl.catalog.console; -import com.google.common.collect.Sets; import java.util.ArrayList; import java.util.List; import java.util.Set; -/** - * Field category, indicating what struct for {@link GeaFlowStructType} it belongs to. - */ +import com.google.common.collect.Sets; + +/** Field category, indicating what struct for {@link GeaFlowStructType} it belongs to. */ public enum GeaFlowFieldCategory { - /** - * Property, all data structures in {@link GeaFlowStructType} have it. - */ - PROPERTY(GeaFlowStructType.values()), - /** - * Id, table struct and view struct have it. - */ - ID(GeaFlowStructType.TABLE, GeaFlowStructType.VIEW), - /** - * Vertex id, only vertex struct has it. - */ - VERTEX_ID(GeaFlowStructType.VERTEX), - /** - * Vertex label, only vertex struct has it. - */ - VERTEX_LABEL(GeaFlowStructType.VERTEX), - /** - * Edge source id, only edge struct has it. - */ - EDGE_SOURCE_ID(GeaFlowStructType.EDGE), - /** - * Edge target id, only edge struct has it. - */ - EDGE_TARGET_ID(GeaFlowStructType.EDGE), - /** - * Edge label, only edge struct has it. - */ - EDGE_LABEL(GeaFlowStructType.EDGE), - /** - * Edge timestamp, only edge struct has it. - */ - EDGE_TIMESTAMP(GeaFlowStructType.EDGE); + /** Property, all data structures in {@link GeaFlowStructType} have it. */ + PROPERTY(GeaFlowStructType.values()), + /** Id, table struct and view struct have it. */ + ID(GeaFlowStructType.TABLE, GeaFlowStructType.VIEW), + /** Vertex id, only vertex struct has it. */ + VERTEX_ID(GeaFlowStructType.VERTEX), + /** Vertex label, only vertex struct has it. */ + VERTEX_LABEL(GeaFlowStructType.VERTEX), + /** Edge source id, only edge struct has it. */ + EDGE_SOURCE_ID(GeaFlowStructType.EDGE), + /** Edge target id, only edge struct has it. */ + EDGE_TARGET_ID(GeaFlowStructType.EDGE), + /** Edge label, only edge struct has it. */ + EDGE_LABEL(GeaFlowStructType.EDGE), + /** Edge timestamp, only edge struct has it. */ + EDGE_TIMESTAMP(GeaFlowStructType.EDGE); - private final Set structTypes; + private final Set structTypes; - GeaFlowFieldCategory(GeaFlowStructType... structTypes) { - this.structTypes = Sets.newHashSet(structTypes); - } + GeaFlowFieldCategory(GeaFlowStructType... structTypes) { + this.structTypes = Sets.newHashSet(structTypes); + } - public static List of(GeaFlowStructType structType) { - List constraints = new ArrayList<>(); - for (GeaFlowFieldCategory value : values()) { - if (value.structTypes.contains(structType)) { - constraints.add(value); - } - } - return constraints; + public static List of(GeaFlowStructType structType) { + List constraints = new ArrayList<>(); + for (GeaFlowFieldCategory value : values()) { + if (value.structTypes.contains(structType)) { + constraints.add(value); + } } + return constraints; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/GeaFlowFieldType.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/GeaFlowFieldType.java index 703ee8dbb..c3f5ecc33 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/GeaFlowFieldType.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/GeaFlowFieldType.java @@ -23,52 +23,38 @@ import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; -/** - * Basic data type supported on the console platform. - */ +/** Basic data type supported on the console platform. */ public enum GeaFlowFieldType { - /** - * Boolean type. - */ - BOOLEAN, - /** - * Int type. - */ - INT, - /** - * Bigint type. - */ - BIGINT, - /** - * Double type. - */ - DOUBLE, - /** - * Varchar type. - */ - VARCHAR, - /** - * Timestamp type. - */ - TIMESTAMP; + /** Boolean type. */ + BOOLEAN, + /** Int type. */ + INT, + /** Bigint type. */ + BIGINT, + /** Double type. */ + DOUBLE, + /** Varchar type. */ + VARCHAR, + /** Timestamp type. */ + TIMESTAMP; - public static GeaFlowFieldType getFieldType(IType type) { - switch (type.getName()) { - case Types.TYPE_NAME_STRING: - case Types.TYPE_NAME_BINARY_STRING: - return VARCHAR; - case Types.TYPE_NAME_LONG: - return BIGINT; - case Types.TYPE_NAME_INTEGER: - return INT; - default: - for (GeaFlowFieldType t : values()) { - if (t.name().equalsIgnoreCase(type.getName())) { - return t; - } - } - break; + public static GeaFlowFieldType getFieldType(IType type) { + switch (type.getName()) { + case Types.TYPE_NAME_STRING: + case Types.TYPE_NAME_BINARY_STRING: + return VARCHAR; + case Types.TYPE_NAME_LONG: + return BIGINT; + case Types.TYPE_NAME_INTEGER: + return INT; + default: + for (GeaFlowFieldType t : values()) { + if (t.name().equalsIgnoreCase(type.getName())) { + return t; + } } - throw new GeaflowRuntimeException("can not find relate field type: " + type.getName()); + break; } + throw new GeaflowRuntimeException("can not find relate field type: " + type.getName()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/GeaFlowStructType.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/GeaFlowStructType.java index a606f9504..a6b5cb547 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/GeaFlowStructType.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/GeaFlowStructType.java @@ -19,24 +19,14 @@ package org.apache.geaflow.dsl.catalog.console; -/** - * Main data structures stored in the console database. - */ +/** Main data structures stored in the console database. */ public enum GeaFlowStructType { - /** - * Table struct. - */ - TABLE, - /** - * View struct. - */ - VIEW, - /** - * Vertex struct. - */ - VERTEX, - /** - * Edge struct. - */ - EDGE + /** Table struct. */ + TABLE, + /** View struct. */ + VIEW, + /** Vertex struct. */ + VERTEX, + /** Edge struct. */ + EDGE } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/GraphModel.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/GraphModel.java index 9ff4f3773..541286134 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/GraphModel.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/GraphModel.java @@ -23,71 +23,68 @@ public class GraphModel extends AbstractDataModel { - private PluginConfigModel pluginConfig; + private PluginConfigModel pluginConfig; - private List vertices; + private List vertices; - private List edges; + private List edges; - private List endpoints; + private List endpoints; - public PluginConfigModel getPluginConfig() { - return pluginConfig; - } - - public void setPluginConfig(PluginConfigModel pluginConfig) { - this.pluginConfig = pluginConfig; - } - - public List getVertices() { - return vertices; - } - - public void setVertices(List vertices) { - this.vertices = vertices; - } + public PluginConfigModel getPluginConfig() { + return pluginConfig; + } - public List getEdges() { - return edges; - } + public void setPluginConfig(PluginConfigModel pluginConfig) { + this.pluginConfig = pluginConfig; + } - public void setEdges(List edges) { - this.edges = edges; - } + public List getVertices() { + return vertices; + } + public void setVertices(List vertices) { + this.vertices = vertices; + } - public List getEndpoints() { - return endpoints; - } + public List getEdges() { + return edges; + } - public void setEndpoints(List endpoints) { - this.endpoints = endpoints; - } + public void setEdges(List edges) { + this.edges = edges; + } - public static class Endpoint { + public List getEndpoints() { + return endpoints; + } - private final String sourceName; - private final String targetName; - private final String edgeName; + public void setEndpoints(List endpoints) { + this.endpoints = endpoints; + } - public Endpoint(String srcVertex, String targetVertex, String edgeName) { - this.sourceName = srcVertex; - this.targetName = targetVertex; - this.edgeName = edgeName; - } + public static class Endpoint { - public String getSourceName() { - return sourceName; - } + private final String sourceName; + private final String targetName; + private final String edgeName; - public String getTargetName() { - return targetName; - } + public Endpoint(String srcVertex, String targetVertex, String edgeName) { + this.sourceName = srcVertex; + this.targetName = targetVertex; + this.edgeName = edgeName; + } - public String getEdgeName() { - return edgeName; - } + public String getSourceName() { + return sourceName; + } + public String getTargetName() { + return targetName; } + public String getEdgeName() { + return edgeName; + } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/InstanceModel.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/InstanceModel.java index 528853588..a6c1e415b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/InstanceModel.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/InstanceModel.java @@ -19,6 +19,4 @@ package org.apache.geaflow.dsl.catalog.console; -public class InstanceModel extends AbstractNameModel { - -} +public class InstanceModel extends AbstractNameModel {} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/PageList.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/PageList.java index bac22e4e4..d0458e3f6 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/PageList.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/PageList.java @@ -23,23 +23,23 @@ public class PageList { - private List list; + private List list; - private long total; + private long total; - public List getList() { - return list; - } + public List getList() { + return list; + } - public void setList(List list) { - this.list = list; - } + public void setList(List list) { + this.list = list; + } - public long getTotal() { - return total; - } + public long getTotal() { + return total; + } - public void setTotal(long total) { - this.total = total; - } + public void setTotal(long total) { + this.total = total; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/PluginConfigModel.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/PluginConfigModel.java index e66c5eb48..1a3f3e157 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/PluginConfigModel.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/PluginConfigModel.java @@ -23,23 +23,23 @@ public class PluginConfigModel extends AbstractNameModel { - private String type; + private String type; - private Map config; + private Map config; - public String getType() { - return type; - } + public String getType() { + return type; + } - public void setType(String type) { - this.type = type; - } + public void setType(String type) { + this.type = type; + } - public Map getConfig() { - return config; - } + public Map getConfig() { + return config; + } - public void setConfig(Map config) { - this.config = config; - } + public void setConfig(Map config) { + this.config = config; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/TableModel.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/TableModel.java index 029940f1d..0e631bdc5 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/TableModel.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/TableModel.java @@ -21,17 +21,17 @@ public class TableModel extends AbstractStructModel { - private PluginConfigModel pluginConfig; + private PluginConfigModel pluginConfig; - public TableModel() { - type = GeaFlowStructType.TABLE; - } + public TableModel() { + type = GeaFlowStructType.TABLE; + } - public PluginConfigModel getPluginConfig() { - return pluginConfig; - } + public PluginConfigModel getPluginConfig() { + return pluginConfig; + } - public void setPluginConfig(PluginConfigModel pluginConfig) { - this.pluginConfig = pluginConfig; - } + public void setPluginConfig(PluginConfigModel pluginConfig) { + this.pluginConfig = pluginConfig; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/VertexModel.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/VertexModel.java index e252d02b5..2330c5f80 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/VertexModel.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/console/VertexModel.java @@ -21,7 +21,7 @@ public class VertexModel extends AbstractStructModel { - public VertexModel() { - type = GeaFlowStructType.VERTEX; - } + public VertexModel() { + type = GeaFlowStructType.VERTEX; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/exception/ObjectAlreadyExistException.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/exception/ObjectAlreadyExistException.java index e61fdfbfc..d31e477a8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/exception/ObjectAlreadyExistException.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/exception/ObjectAlreadyExistException.java @@ -23,7 +23,7 @@ public class ObjectAlreadyExistException extends GeaFlowDSLException { - public ObjectAlreadyExistException(String objectName) { - super("'" + objectName + "' already exists"); - } + public ObjectAlreadyExistException(String objectName) { + super("'" + objectName + "' already exists"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/exception/ObjectNotExistException.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/exception/ObjectNotExistException.java index 8a4cc9139..ee4753abc 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/exception/ObjectNotExistException.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/catalog/exception/ObjectNotExistException.java @@ -23,7 +23,7 @@ public class ObjectNotExistException extends GeaFlowDSLException { - public ObjectNotExistException(String objectName) { - super("'" + objectName + "' is not exist in catalog"); - } + public ObjectNotExistException(String objectName) { + super("'" + objectName + "' is not exist in catalog"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/schema/GQLStatistic.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/schema/GQLStatistic.java index 712cd42a2..4095ba0ec 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/schema/GQLStatistic.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/schema/GQLStatistic.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; + import org.apache.calcite.rel.RelCollation; import org.apache.calcite.rel.RelDistribution; import org.apache.calcite.rel.RelReferentialConstraint; @@ -30,28 +31,28 @@ public class GQLStatistic implements Statistic { - @Override - public Double getRowCount() { - return null; - } - - @Override - public boolean isKey(ImmutableBitSet columns) { - return false; - } - - @Override - public List getReferentialConstraints() { - return new ArrayList<>(); - } - - @Override - public List getCollations() { - return Collections.emptyList(); - } - - @Override - public RelDistribution getDistribution() { - return null; - } + @Override + public Double getRowCount() { + return null; + } + + @Override + public boolean isKey(ImmutableBitSet columns) { + return false; + } + + @Override + public List getReferentialConstraints() { + return new ArrayList<>(); + } + + @Override + public List getCollations() { + return Collections.emptyList(); + } + + @Override + public RelDistribution getDistribution() { + return null; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/schema/GeaFlowFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/schema/GeaFlowFunction.java index 719f19226..fd5862796 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/schema/GeaFlowFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/schema/GeaFlowFunction.java @@ -19,135 +19,134 @@ package org.apache.geaflow.dsl.schema; -import com.google.common.base.Preconditions; -import com.google.common.collect.Lists; import java.io.Serializable; import java.util.List; import java.util.Objects; + import org.apache.calcite.sql.SqlIdentifier; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.sqlnode.SqlCreateFunction; +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; + public class GeaFlowFunction implements Serializable { - public enum FunctionType { + public enum FunctionType { - /** - * User-defined scalar function. - */ - UDF, + /** User-defined scalar function. */ + UDF, - /** - * User-defined table function. - */ - UDTF, + /** User-defined table function. */ + UDTF, - /** - * User-defined aggregate function. - */ - UDAF, + /** User-defined aggregate function. */ + UDAF, - /** - * User-defined graph algorithm. - */ - UDGA - } + /** User-defined graph algorithm. */ + UDGA + } - /** - * Function name. - */ - private final String name; + /** Function name. */ + private final String name; - private final List clazz; + private final List clazz; - private final String url; + private final String url; - private final boolean ifNotExists; + private final boolean ifNotExists; - public static GeaFlowFunction toFunction(SqlCreateFunction function) { - // Extract function name - String functionName = ((SqlIdentifier) function.getFunctionName()).getSimple(); - String className = function.getClassName(); - String url = function.getUsingPath(); - return new GeaFlowFunction(functionName, Lists.newArrayList(className), url, function.ifNotExists()); - } + public static GeaFlowFunction toFunction(SqlCreateFunction function) { + // Extract function name + String functionName = ((SqlIdentifier) function.getFunctionName()).getSimple(); + String className = function.getClassName(); + String url = function.getUsingPath(); + return new GeaFlowFunction( + functionName, Lists.newArrayList(className), url, function.ifNotExists()); + } + public static GeaFlowFunction of(Class functionClazz) { + Description description = (Description) functionClazz.getAnnotation(Description.class); - public static GeaFlowFunction of(Class functionClazz) { - Description description = (Description) functionClazz.getAnnotation(Description.class); + Preconditions.checkState( + description != null, "missing Description annotation for udf " + functionClazz); - Preconditions.checkState(description != null, - "missing Description annotation for udf " + functionClazz); + Preconditions.checkArgument( + !description.name().contains(","), + "bad udf name " + description.name() + " in " + functionClazz); - Preconditions.checkArgument(!description.name().contains(","), - "bad udf name " + description.name() + " in " + functionClazz); + return new GeaFlowFunction(description.name(), functionClazz.getName(), false); + } - return new GeaFlowFunction(description.name(), functionClazz.getName(), false); - } + public static GeaFlowFunction of(String name, Class reflectClass) { + return new GeaFlowFunction(name, reflectClass.getName(), false); + } - public static GeaFlowFunction of(String name, Class reflectClass) { - return new GeaFlowFunction(name, reflectClass.getName(), false); - } + public static GeaFlowFunction of(String name, List classNames) { + return of(name, classNames, null); + } - public static GeaFlowFunction of(String name, List classNames) { - return of(name, classNames, null); - } + public static GeaFlowFunction of(String name, List classNames, String url) { + return new GeaFlowFunction(name, classNames, url, false); + } - public static GeaFlowFunction of(String name, List classNames, String url) { - return new GeaFlowFunction(name, classNames, url, false); - } + private GeaFlowFunction(String name, String clazz, boolean ifNotExists) { + this(name, Lists.newArrayList(clazz), null, ifNotExists); + } - private GeaFlowFunction(String name, String clazz, boolean ifNotExists) { - this(name, Lists.newArrayList(clazz), null, ifNotExists); - } - - private GeaFlowFunction(String name, List clazz, String url, boolean ifNotExists) { - this.name = name; - this.clazz = clazz; - this.url = url; - this.ifNotExists = ifNotExists; - } + private GeaFlowFunction(String name, List clazz, String url, boolean ifNotExists) { + this.name = name; + this.clazz = clazz; + this.url = url; + this.ifNotExists = ifNotExists; + } - public boolean isIfNotExists() { - return ifNotExists; - } + public boolean isIfNotExists() { + return ifNotExists; + } - public String getName() { - return name; - } + public String getName() { + return name; + } - public List getClazz() { - return clazz; - } + public List getClazz() { + return clazz; + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof GeaFlowFunction)) { - return false; - } - GeaFlowFunction that = (GeaFlowFunction) o; - return Objects.equals(name, that.name) && Objects.equals(clazz, that.clazz) - && Objects.equals(url, that.url); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(name, clazz, url); - } - - @Override - public String toString() { - return "GeaFlowFunction{" - + "name='" + name + '\'' - + ", clazz=" + clazz - + ", url='" + url + '\'' - + '}'; - } - - public String getUrl() { - return url; + if (!(o instanceof GeaFlowFunction)) { + return false; } + GeaFlowFunction that = (GeaFlowFunction) o; + return Objects.equals(name, that.name) + && Objects.equals(clazz, that.clazz) + && Objects.equals(url, that.url); + } + + @Override + public int hashCode() { + return Objects.hash(name, clazz, url); + } + + @Override + public String toString() { + return "GeaFlowFunction{" + + "name='" + + name + + '\'' + + ", clazz=" + + clazz + + ", url='" + + url + + '\'' + + '}'; + } + + public String getUrl() { + return url; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/schema/GeaFlowGraph.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/schema/GeaFlowGraph.java index e5e8f7f38..27dcdae56 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/schema/GeaFlowGraph.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/schema/GeaFlowGraph.java @@ -19,8 +19,6 @@ package org.apache.geaflow.dsl.schema; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Lists; import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; @@ -31,6 +29,7 @@ import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeField; @@ -50,429 +49,524 @@ import org.apache.geaflow.dsl.common.types.TableField; import org.apache.geaflow.dsl.util.SqlTypeUtil; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; + public class GeaFlowGraph extends AbstractTable implements Serializable { - private final String instanceName; - private final String name; - private final List vertexTables; - private final List edgeTables; - private final Map usingTables; - private final Map config; - private final boolean ifNotExists; - private final boolean isTemporary; - private GraphDescriptor graphDescriptor; - - public GeaFlowGraph(String instanceName, String name, List vertexTables, - List edgeTables, Map config, - Map usingTables, boolean ifNotExists, boolean isTemporary) { - this.instanceName = instanceName; - this.name = name; - this.vertexTables = vertexTables; - this.edgeTables = edgeTables; - this.config = new HashMap<>(config); - this.usingTables = ImmutableMap.copyOf(usingTables); - this.ifNotExists = ifNotExists; - this.isTemporary = isTemporary; - - for (VertexTable vertexTable : this.vertexTables) { - vertexTable.setGraph(this); - } - for (EdgeTable edgeTable : this.edgeTables) { - edgeTable.setGraph(this); - } - this.validate(); + private final String instanceName; + private final String name; + private final List vertexTables; + private final List edgeTables; + private final Map usingTables; + private final Map config; + private final boolean ifNotExists; + private final boolean isTemporary; + private GraphDescriptor graphDescriptor; + + public GeaFlowGraph( + String instanceName, + String name, + List vertexTables, + List edgeTables, + Map config, + Map usingTables, + boolean ifNotExists, + boolean isTemporary) { + this.instanceName = instanceName; + this.name = name; + this.vertexTables = vertexTables; + this.edgeTables = edgeTables; + this.config = new HashMap<>(config); + this.usingTables = ImmutableMap.copyOf(usingTables); + this.ifNotExists = ifNotExists; + this.isTemporary = isTemporary; + + for (VertexTable vertexTable : this.vertexTables) { + vertexTable.setGraph(this); } - - public GeaFlowGraph(String instanceName, String name, List vertexTables, - List edgeTables, Map config, - Map usingTables, boolean ifNotExists, boolean isTemporary, - GraphDescriptor descriptor) { - this(instanceName, name, vertexTables, edgeTables, config, usingTables, ifNotExists, isTemporary); - this.graphDescriptor = Objects.requireNonNull(descriptor); + for (EdgeTable edgeTable : this.edgeTables) { + edgeTable.setGraph(this); } - - public void validate() { - if (this.vertexTables.size() > 0) { - TableField commonVertexIdField = this.vertexTables.get(0).getIdField(); - for (VertexTable vertexTable : this.vertexTables) { - if (!vertexTable.getIdField().getType().equals(commonVertexIdField.getType())) { - throw new GeaFlowDSLException("Id field type should be same between vertex " + "tables"); - } - } - } - if (this.edgeTables.size() > 0) { - TableField commonSrcIdField = this.edgeTables.get(0).getSrcIdField(); - TableField commonTargetIdField = this.edgeTables.get(0).getTargetIdField(); - Optional commonTsField = - Optional.ofNullable(this.edgeTables.get(0).getTimestampField()); - for (EdgeTable edgeTable : this.edgeTables) { - if (!edgeTable.getSrcIdField().getType().equals(commonSrcIdField.getType())) { - throw new GeaFlowDSLException("SOURCE ID field type should be same between edge " - + "tables"); - } else if (!edgeTable.getTargetIdField().getType().equals(commonTargetIdField.getType())) { - throw new GeaFlowDSLException("DESTINATION ID field type should be same " - + "between edge tables"); - } - - if (commonTsField.isPresent()) { - if (edgeTable.getTimestampField() == null) { - throw new GeaFlowDSLException("TIMESTAMP should defined or not defined in all edge tables"); - } else if (!edgeTable.getTimestampField().getType().equals(commonTsField.get().getType())) { - throw new GeaFlowDSLException("TIMESTAMP field type should be same between edge " - + "tables"); - } - } else { - if (edgeTable.getTimestampField() != null) { - throw new GeaFlowDSLException("TIMESTAMP should defined or not defined in all edge tables"); - } - } - } - } + this.validate(); + } + + public GeaFlowGraph( + String instanceName, + String name, + List vertexTables, + List edgeTables, + Map config, + Map usingTables, + boolean ifNotExists, + boolean isTemporary, + GraphDescriptor descriptor) { + this( + instanceName, + name, + vertexTables, + edgeTables, + config, + usingTables, + ifNotExists, + isTemporary); + this.graphDescriptor = Objects.requireNonNull(descriptor); + } + + public void validate() { + if (this.vertexTables.size() > 0) { + TableField commonVertexIdField = this.vertexTables.get(0).getIdField(); + for (VertexTable vertexTable : this.vertexTables) { + if (!vertexTable.getIdField().getType().equals(commonVertexIdField.getType())) { + throw new GeaFlowDSLException("Id field type should be same between vertex " + "tables"); + } + } } - - public boolean containTable(GeaFlowTable table) { - if (table == null) { - return false; - } - if (table instanceof VertexTable) { - return this.getVertexTables().stream().anyMatch(v -> v.getName().equals(table.getName())); - } - if (table instanceof EdgeTable) { - return this.getEdgeTables().stream().anyMatch(v -> v.getName().equals(table.getName())); - } - return false; + if (this.edgeTables.size() > 0) { + TableField commonSrcIdField = this.edgeTables.get(0).getSrcIdField(); + TableField commonTargetIdField = this.edgeTables.get(0).getTargetIdField(); + Optional commonTsField = + Optional.ofNullable(this.edgeTables.get(0).getTimestampField()); + for (EdgeTable edgeTable : this.edgeTables) { + if (!edgeTable.getSrcIdField().getType().equals(commonSrcIdField.getType())) { + throw new GeaFlowDSLException( + "SOURCE ID field type should be same between edge " + "tables"); + } else if (!edgeTable.getTargetIdField().getType().equals(commonTargetIdField.getType())) { + throw new GeaFlowDSLException( + "DESTINATION ID field type should be same " + "between edge tables"); + } + + if (commonTsField.isPresent()) { + if (edgeTable.getTimestampField() == null) { + throw new GeaFlowDSLException( + "TIMESTAMP should defined or not defined in all edge tables"); + } else if (!edgeTable + .getTimestampField() + .getType() + .equals(commonTsField.get().getType())) { + throw new GeaFlowDSLException( + "TIMESTAMP field type should be same between edge " + "tables"); + } + } else { + if (edgeTable.getTimestampField() != null) { + throw new GeaFlowDSLException( + "TIMESTAMP should defined or not defined in all edge tables"); + } + } + } } + } - @Override - public RelDataType getRowType(RelDataTypeFactory typeFactory) { - List fields = new ArrayList<>(); - for (VertexTable table : vertexTables) { - VertexRecordType type = table.getRowType(typeFactory); - fields.add(new RelDataTypeFieldImpl(table.getTypeName(), fields.size(), type)); - } - for (EdgeTable table : edgeTables) { - EdgeRecordType type = table.getRowType(typeFactory); - fields.add(new RelDataTypeFieldImpl(table.getTypeName(), fields.size(), type)); - } - return new GraphRecordType(name, fields); + public boolean containTable(GeaFlowTable table) { + if (table == null) { + return false; } - - public GraphSchema getGraphSchema(RelDataTypeFactory typeFactory) { - return (GraphSchema) SqlTypeUtil.convertType(getRowType(typeFactory)); - } - - public String getInstanceName() { - return instanceName; + if (table instanceof VertexTable) { + return this.getVertexTables().stream().anyMatch(v -> v.getName().equals(table.getName())); } - - public String getName() { - return name; + if (table instanceof EdgeTable) { + return this.getEdgeTables().stream().anyMatch(v -> v.getName().equals(table.getName())); } - - public String getUniqueName() { - return instanceName + "_" + name; + return false; + } + + @Override + public RelDataType getRowType(RelDataTypeFactory typeFactory) { + List fields = new ArrayList<>(); + for (VertexTable table : vertexTables) { + VertexRecordType type = table.getRowType(typeFactory); + fields.add(new RelDataTypeFieldImpl(table.getTypeName(), fields.size(), type)); } - - public List getVertexTables() { - return vertexTables; + for (EdgeTable table : edgeTables) { + EdgeRecordType type = table.getRowType(typeFactory); + fields.add(new RelDataTypeFieldImpl(table.getTypeName(), fields.size(), type)); } - - public List getEdgeTables() { - return edgeTables; - } - - public Configuration getConfig() { - return new Configuration(config); - } - - public GeaFlowGraph setDescriptor(GraphDescriptor desc) { - this.graphDescriptor = Objects.requireNonNull(desc); - return this; - } - - public GraphDescriptor getValidDescriptorInGraph(GraphDescriptor desc) { - GraphDescriptor newDesc = new GraphDescriptor(); - newDesc.addNode(desc.nodes.stream().filter( - node -> this.vertexTables.stream().anyMatch(v -> v.getTypeName().equals(node.type)) - ).collect(Collectors.toList())); - newDesc.addEdge(desc.edges.stream().filter( - edge -> { - EdgeTable edgeTable = null; - for (EdgeTable e : this.getEdgeTables()) { + return new GraphRecordType(name, fields); + } + + public GraphSchema getGraphSchema(RelDataTypeFactory typeFactory) { + return (GraphSchema) SqlTypeUtil.convertType(getRowType(typeFactory)); + } + + public String getInstanceName() { + return instanceName; + } + + public String getName() { + return name; + } + + public String getUniqueName() { + return instanceName + "_" + name; + } + + public List getVertexTables() { + return vertexTables; + } + + public List getEdgeTables() { + return edgeTables; + } + + public Configuration getConfig() { + return new Configuration(config); + } + + public GeaFlowGraph setDescriptor(GraphDescriptor desc) { + this.graphDescriptor = Objects.requireNonNull(desc); + return this; + } + + public GraphDescriptor getValidDescriptorInGraph(GraphDescriptor desc) { + GraphDescriptor newDesc = new GraphDescriptor(); + newDesc.addNode( + desc.nodes.stream() + .filter( + node -> this.vertexTables.stream().anyMatch(v -> v.getTypeName().equals(node.type))) + .collect(Collectors.toList())); + newDesc.addEdge( + desc.edges.stream() + .filter( + edge -> { + EdgeTable edgeTable = null; + for (EdgeTable e : this.getEdgeTables()) { if (e.getTypeName().equals(edge.type)) { - edgeTable = e; - break; + edgeTable = e; + break; } - } - VertexTable sourceVertexTable = null; - for (VertexTable v : this.getVertexTables()) { + } + VertexTable sourceVertexTable = null; + for (VertexTable v : this.getVertexTables()) { if (v.getTypeName().equals(edge.sourceType)) { - sourceVertexTable = v; - break; + sourceVertexTable = v; + break; } - } - VertexTable targetVertexTable = null; - for (VertexTable v : this.getVertexTables()) { + } + VertexTable targetVertexTable = null; + for (VertexTable v : this.getVertexTables()) { if (v.getTypeName().equals(edge.targetType)) { - targetVertexTable = v; - break; + targetVertexTable = v; + break; } - } - boolean exist = edgeTable != null - && sourceVertexTable != null && targetVertexTable != null; - return exist && edgeTable.getSrcIdField().getType().equals(sourceVertexTable.getIdField().getType()) - && edgeTable.getTargetIdField().getType().equals(targetVertexTable.getIdField().getType()); - } - ).collect(Collectors.toList())); - return newDesc; + } + boolean exist = + edgeTable != null && sourceVertexTable != null && targetVertexTable != null; + return exist + && edgeTable + .getSrcIdField() + .getType() + .equals(sourceVertexTable.getIdField().getType()) + && edgeTable + .getTargetIdField() + .getType() + .equals(targetVertexTable.getIdField().getType()); + }) + .collect(Collectors.toList())); + return newDesc; + } + + public GraphDescriptor getDescriptor() { + return graphDescriptor == null ? new GraphDescriptor() : graphDescriptor; + } + + public Configuration getConfigWithGlobal(Configuration globalConf) { + Map conf = new HashMap<>(globalConf.getConfigMap()); + conf.putAll(this.config); + return new Configuration(conf); + } + + public Configuration getConfigWithGlobal( + Map globalConf, Map setOptions) { + Map conf = new HashMap<>(globalConf); + conf.putAll(this.config); + conf.putAll(setOptions); + return new Configuration(conf); + } + + public String getStoreType() { + return Configuration.getString(DSLConfigKeys.GEAFLOW_DSL_STORE_TYPE, config); + } + + public int getShardCount() { + return Configuration.getInteger(DSLConfigKeys.GEAFLOW_DSL_STORE_SHARD_COUNT, config); + } + + public IType getIdType() { + VertexTable vertexTable = vertexTables.iterator().next(); + return vertexTable.getIdField().getType(); + } + + public IType getLabelType() { + return Types.STRING; + } + + public boolean isIfNotExists() { + return ifNotExists; + } + + public boolean isTemporary() { + return isTemporary; + } + + public GraphElementTable getTable(String tableName) { + for (VertexTable vertexTable : vertexTables) { + if (vertexTable.getTypeName().equalsIgnoreCase(tableName)) { + return vertexTable; + } } - - public GraphDescriptor getDescriptor() { - return graphDescriptor == null ? new GraphDescriptor() : graphDescriptor; + for (EdgeTable edgeTable : edgeTables) { + if (edgeTable.getTypeName().equalsIgnoreCase(tableName)) { + return edgeTable; + } } - - public Configuration getConfigWithGlobal(Configuration globalConf) { - Map conf = new HashMap<>(globalConf.getConfigMap()); - conf.putAll(this.config); - return new Configuration(conf); + return null; + } + + public Map getUsingTables() { + return usingTables; + } + + public static class VertexTable extends GeaFlowTable implements GraphElementTable, Serializable { + + private final String idField; + + private GeaFlowGraph graph; + + public VertexTable( + String instanceName, String typeName, List fields, String idField) { + super( + instanceName, + typeName, + fields, + Collections.singletonList(idField), + Collections.emptyList(), + new HashMap<>(), + true, + true); + this.idField = Objects.requireNonNull(idField); + checkFields(); } - public Configuration getConfigWithGlobal(Map globalConf, Map setOptions) { - Map conf = new HashMap<>(globalConf); - conf.putAll(this.config); - conf.putAll(setOptions); - return new Configuration(conf); + private void checkFields() { + for (TableField field : getFields()) { + GraphRecordType.validateFieldName(field.getName()); + } + Set fieldNames = + getFields().stream().map(TableField::getName).collect(Collectors.toSet()); + if (fieldNames.size() != super.getFields().size()) { + throw new GeaFlowDSLException("Duplicate field has found in vertex table: " + getName()); + } + if (!fieldNames.contains(idField)) { + throw new GeaFlowDSLException( + "id field:'" + idField + "' is not found in the fields: " + fieldNames); + } } - public String getStoreType() { - return Configuration.getString(DSLConfigKeys.GEAFLOW_DSL_STORE_TYPE, config); + public void setGraph(GeaFlowGraph graph) { + this.graph = graph; } - public int getShardCount() { - return Configuration.getInteger(DSLConfigKeys.GEAFLOW_DSL_STORE_SHARD_COUNT, config); + @Override + public GeaFlowGraph getGraph() { + return graph; } - public IType getIdType() { - VertexTable vertexTable = vertexTables.iterator().next(); - return vertexTable.getIdField().getType(); + @Override + public String getTypeName() { + return getName(); } - public IType getLabelType() { - return Types.STRING; + public TableField getIdField() { + return findField(getFields(), idField); } - public boolean isIfNotExists() { - return ifNotExists; + public String getIdFieldName() { + return idField; } - public boolean isTemporary() { - return isTemporary; + @Override + public VertexRecordType getRowType(RelDataTypeFactory typeFactory) { + List dataFields = new ArrayList<>(getFields().size()); + for (int i = 0; i < getFields().size(); i++) { + TableField field = getFields().get(i); + RelDataType type = + SqlTypeUtil.convertToRelType(field.getType(), field.isNullable(), typeFactory); + RelDataTypeField dataField = new RelDataTypeFieldImpl(field.getName(), i, type); + dataFields.add(dataField); + } + return VertexRecordType.createVertexType(dataFields, idField, typeFactory); } - public GraphElementTable getTable(String tableName) { - for (VertexTable vertexTable : vertexTables) { - if (vertexTable.getTypeName().equalsIgnoreCase(tableName)) { - return vertexTable; - } - } - for (EdgeTable edgeTable : edgeTables) { - if (edgeTable.getTypeName().equalsIgnoreCase(tableName)) { - return edgeTable; - } - } - return null; + @Override + public String toString() { + return "VertexTable{" + + "typeName='" + + getTypeName() + + '\'' + + ", fields=" + + getFields() + + ", idField='" + + idField + + '\'' + + '}'; } - - public Map getUsingTables() { - return usingTables; + } + + public static class EdgeTable extends GeaFlowTable implements GraphElementTable, Serializable { + + private final String srcIdField; + private final String targetIdField; + private final String timestampField; + + private GeaFlowGraph graph; + + public EdgeTable( + String instanceName, + String typeName, + List fields, + String srcIdField, + String targetIdField, + String timestampField) { + super( + instanceName, + typeName, + fields, + Lists.newArrayList(srcIdField, targetIdField), + Collections.emptyList(), + new HashMap<>(), + true, + true); + this.srcIdField = Objects.requireNonNull(srcIdField); + this.targetIdField = Objects.requireNonNull(targetIdField); + this.timestampField = timestampField; + checkFields(); } - public static class VertexTable extends GeaFlowTable implements GraphElementTable, Serializable { - - private final String idField; - - private GeaFlowGraph graph; - - public VertexTable(String instanceName, String typeName, List fields, String idField) { - super(instanceName, typeName, fields, Collections.singletonList(idField), - Collections.emptyList(), new HashMap<>(), true, true); - this.idField = Objects.requireNonNull(idField); - checkFields(); - } - - private void checkFields() { - for (TableField field : getFields()) { - GraphRecordType.validateFieldName(field.getName()); - } - Set fieldNames = getFields().stream().map(TableField::getName) - .collect(Collectors.toSet()); - if (fieldNames.size() != super.getFields().size()) { - throw new GeaFlowDSLException("Duplicate field has found in vertex table: " + getName()); - } - if (!fieldNames.contains(idField)) { - throw new GeaFlowDSLException("id field:'" + idField + "' is not found in the fields: " + fieldNames); - } - } - - public void setGraph(GeaFlowGraph graph) { - this.graph = graph; - } - - @Override - public GeaFlowGraph getGraph() { - return graph; - } - - @Override - public String getTypeName() { - return getName(); - } - - - public TableField getIdField() { - return findField(getFields(), idField); - } - - public String getIdFieldName() { - return idField; - } - - @Override - public VertexRecordType getRowType(RelDataTypeFactory typeFactory) { - List dataFields = new ArrayList<>(getFields().size()); - for (int i = 0; i < getFields().size(); i++) { - TableField field = getFields().get(i); - RelDataType type = SqlTypeUtil.convertToRelType(field.getType(), field.isNullable(), typeFactory); - RelDataTypeField dataField = new RelDataTypeFieldImpl(field.getName(), i, type); - dataFields.add(dataField); - } - return VertexRecordType.createVertexType(dataFields, idField, typeFactory); - } - - @Override - public String toString() { - return "VertexTable{" + "typeName='" + getTypeName() + '\'' + ", fields=" + getFields() - + ", idField='" + idField + '\'' + '}'; - } + private void checkFields() { + for (TableField field : getFields()) { + GraphRecordType.validateFieldName(field.getName()); + } + Set fieldNames = + getFields().stream().map(TableField::getName).collect(Collectors.toSet()); + if (fieldNames.size() != getFields().size()) { + throw new GeaFlowDSLException("Duplicate field has found in edge table: " + getName()); + } + if (!fieldNames.contains(srcIdField)) { + throw new GeaFlowDSLException( + "source id:" + srcIdField + " is not found in fields: " + fieldNames); + } + if (!fieldNames.contains(targetIdField)) { + throw new GeaFlowDSLException( + "target id:" + targetIdField + " is not found in fields: " + fieldNames); + } } - public static class EdgeTable extends GeaFlowTable implements GraphElementTable, Serializable { - - private final String srcIdField; - private final String targetIdField; - private final String timestampField; - - private GeaFlowGraph graph; - - public EdgeTable(String instanceName, String typeName, List fields, String srcIdField, - String targetIdField, String timestampField) { - super(instanceName, typeName, fields, Lists.newArrayList(srcIdField, targetIdField), - Collections.emptyList(), new HashMap<>(), true, true); - this.srcIdField = Objects.requireNonNull(srcIdField); - this.targetIdField = Objects.requireNonNull(targetIdField); - this.timestampField = timestampField; - checkFields(); - } - - private void checkFields() { - for (TableField field : getFields()) { - GraphRecordType.validateFieldName(field.getName()); - } - Set fieldNames = getFields().stream().map(TableField::getName) - .collect(Collectors.toSet()); - if (fieldNames.size() != getFields().size()) { - throw new GeaFlowDSLException("Duplicate field has found in edge table: " + getName()); - } - if (!fieldNames.contains(srcIdField)) { - throw new GeaFlowDSLException("source id:" + srcIdField + " is not found in fields: " + fieldNames); - } - if (!fieldNames.contains(targetIdField)) { - throw new GeaFlowDSLException( - "target id:" + targetIdField + " is not found in fields: " + fieldNames); - } - } - - public void setGraph(GeaFlowGraph graph) { - this.graph = graph; - } - - @Override - public GeaFlowGraph getGraph() { - return graph; - } - - @Override - public String getTypeName() { - return getName(); - } - - public TableField getSrcIdField() { - return findField(getFields(), srcIdField); - } - - public TableField getTargetIdField() { - return findField(getFields(), targetIdField); - } - - public TableField getTimestampField() { - if (timestampField == null) { - return null; - } - return findField(getFields(), timestampField); - } + public void setGraph(GeaFlowGraph graph) { + this.graph = graph; + } - public String getSrcIdFieldName() { - return srcIdField; - } + @Override + public GeaFlowGraph getGraph() { + return graph; + } - public String getTargetIdFieldName() { - return targetIdField; - } + @Override + public String getTypeName() { + return getName(); + } - public String getTimestampFieldName() { - return timestampField; - } + public TableField getSrcIdField() { + return findField(getFields(), srcIdField); + } - @Override - public EdgeRecordType getRowType(RelDataTypeFactory typeFactory) { - List dataFields = new ArrayList<>(getFields().size()); - for (int i = 0; i < getFields().size(); i++) { - TableField field = getFields().get(i); - RelDataType type = SqlTypeUtil.convertToRelType(field.getType(), field.isNullable(), typeFactory); - RelDataTypeField dataField = new RelDataTypeFieldImpl(field.getName(), i, type); - dataFields.add(dataField); - } - return EdgeRecordType.createEdgeType(dataFields, srcIdField, targetIdField, timestampField, typeFactory); - } + public TableField getTargetIdField() { + return findField(getFields(), targetIdField); + } - @Override - public String toString() { - return "EdgeTable{" + "typeName='" + getName() + '\'' + ", fields=" + getFields() - + ", srcIdField='" + srcIdField + '\'' + ", targetIdField='" + targetIdField + '\'' - + ", timestampField='" + timestampField + '\'' + '}'; - } + public TableField getTimestampField() { + if (timestampField == null) { + return null; + } + return findField(getFields(), timestampField); } - public interface GraphElementTable extends Table { + public String getSrcIdFieldName() { + return srcIdField; + } - String getTypeName(); + public String getTargetIdFieldName() { + return targetIdField; + } - GeaFlowGraph getGraph(); + public String getTimestampFieldName() { + return timestampField; } - private static TableField findField(List fields, String name) { - for (TableField field : fields) { - if (Objects.equals(field.getName(), name)) { - return field; - } - } - throw new IllegalArgumentException("Field name: '" + name + "' is not found"); + @Override + public EdgeRecordType getRowType(RelDataTypeFactory typeFactory) { + List dataFields = new ArrayList<>(getFields().size()); + for (int i = 0; i < getFields().size(); i++) { + TableField field = getFields().get(i); + RelDataType type = + SqlTypeUtil.convertToRelType(field.getType(), field.isNullable(), typeFactory); + RelDataTypeField dataField = new RelDataTypeFieldImpl(field.getName(), i, type); + dataFields.add(dataField); + } + return EdgeRecordType.createEdgeType( + dataFields, srcIdField, targetIdField, timestampField, typeFactory); } @Override public String toString() { - return "GeaFlowGraph{" + "name='" + name + '\'' + ", vertexTables=" + vertexTables - + ", edgeTables=" + edgeTables + ", config=" + config + ", ifNotExists=" + ifNotExists - + '}'; + return "EdgeTable{" + + "typeName='" + + getName() + + '\'' + + ", fields=" + + getFields() + + ", srcIdField='" + + srcIdField + + '\'' + + ", targetIdField='" + + targetIdField + + '\'' + + ", timestampField='" + + timestampField + + '\'' + + '}'; + } + } + + public interface GraphElementTable extends Table { + + String getTypeName(); + + GeaFlowGraph getGraph(); + } + + private static TableField findField(List fields, String name) { + for (TableField field : fields) { + if (Objects.equals(field.getName(), name)) { + return field; + } } + throw new IllegalArgumentException("Field name: '" + name + "' is not found"); + } + + @Override + public String toString() { + return "GeaFlowGraph{" + + "name='" + + name + + '\'' + + ", vertexTables=" + + vertexTables + + ", edgeTables=" + + edgeTables + + ", config=" + + config + + ", ifNotExists=" + + ifNotExists + + '}'; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/schema/GeaFlowTable.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/schema/GeaFlowTable.java index 5fe86c90e..d4477d20a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/schema/GeaFlowTable.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/schema/GeaFlowTable.java @@ -26,6 +26,7 @@ import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.schema.Statistic; @@ -40,177 +41,184 @@ public class GeaFlowTable extends AbstractTable implements Serializable { - private final String instanceName; - - private final String name; - - private final List fields; - - private final List primaryFields; - - private final List partitionFields; - - private final Map config; - - private final boolean ifNotExists; - - private final boolean isTemporary; - - public GeaFlowTable(String instanceName, String name, List fields, List primaryFields, - List partitionFields, Map config, - boolean ifNotExists, boolean isTemporaryTable) { - this.instanceName = Objects.requireNonNull(instanceName); - this.name = Objects.requireNonNull(name, "name is null"); - this.fields = Objects.requireNonNull(fields, "fields is null"); - this.primaryFields = Objects.requireNonNull(primaryFields, "primaryFields is null"); - this.partitionFields = Objects.requireNonNull(partitionFields, "partitionFields is null"); - this.config = Objects.requireNonNull(config, "config is null"); - this.ifNotExists = ifNotExists; - this.isTemporary = isTemporaryTable; - } - - @Override - public Statistic getStatistic() { - return new GQLStatistic(); - } - - @Override - public RelDataType getRowType(RelDataTypeFactory typeFactory) { - List fieldNames = new ArrayList<>(); - List fieldTypes = new ArrayList<>(); - for (TableField field : fields) { - fieldNames.add(field.getName()); - fieldTypes.add(SqlTypeUtil.convertToRelType(field.getType(), field.isNullable(), typeFactory)); - } - return typeFactory.createStructType(fieldTypes, fieldNames); - } - - public TableSchema getTableSchema() { - return new TableSchema(getDataSchema(), getPartitionSchema()); - } - - public StructType getDataSchema() { - return new StructType(fields).dropRight(partitionFields.size()); - } - - public StructType getPartitionSchema() { - List pFields = new ArrayList<>(partitionFields.size()); - for (int i = fields.size() - partitionFields.size(); i < fields.size(); i++) { - pFields.add(fields.get(i)); - } - return new StructType(pFields); - } - - public String getInstanceName() { - return instanceName; - } - - public String getName() { - return name; - } - - public List getFields() { - return fields; - } + private final String instanceName; - public List getPrimaryFields() { - return primaryFields; - } + private final String name; + + private final List fields; - public List getPartitionFields() { - return partitionFields; - } - - public List getPartitionIndices() { - StructType tableSchema = getTableSchema(); - return partitionFields.stream() - .map(tableSchema::indexOf) - .collect(Collectors.toList()); - } - - public Map getConfig() { - return config; - } - - public boolean isIfNotExists() { - return ifNotExists; - } - - public boolean isTemporary() { - return isTemporary; - } + private final List primaryFields; + + private final List partitionFields; - public String getTableType() { - return Configuration.getString(DSLConfigKeys.GEAFLOW_DSL_TABLE_TYPE, config); - } - - public Configuration getConfigWithGlobal(Configuration globalConf) { - Map conf = new HashMap<>(globalConf.getConfigMap()); - conf.putAll(this.config); - return new Configuration(conf); - } - - public Configuration getConfigWithGlobal(Configuration globalConf, Map setOptions) { - Map conf = new HashMap<>(globalConf.getConfigMap()); - conf.putAll(setOptions); - conf.putAll(this.config); - return new Configuration(conf); - } - - public boolean isPartitionField(int index) { - return partitionFields.contains(fields.get(index).getName()); - } - - @Override - public String toString() { - StringBuilder sql = new StringBuilder(); - sql.append("CREATE TABLE "); - - sql.append(name).append(" (\n"); - - for (int i = 0; i < fields.size(); i++) { - TableField field = fields.get(i); - if (i > 0) { - sql.append("\n\t,"); - } else { - sql.append("\t"); - } - sql.append(field.getName()).append("\t"); - sql.append(field.getType()); - } - - sql.append("\n)"); - - if (config.size() > 0) { - boolean first = true; - sql.append(" WITH (\n"); - for (Map.Entry entry : config.entrySet()) { - if (!first) { - sql.append("\t,"); - } else { - sql.append("\t"); - } - first = false; - sql.append(entry.getKey()).append("=") - .append(StringLiteralUtil.escapeSQLString(entry.getValue())) - .append("\n"); - } - sql.append(")"); + private final Map config; + + private final boolean ifNotExists; + + private final boolean isTemporary; + + public GeaFlowTable( + String instanceName, + String name, + List fields, + List primaryFields, + List partitionFields, + Map config, + boolean ifNotExists, + boolean isTemporaryTable) { + this.instanceName = Objects.requireNonNull(instanceName); + this.name = Objects.requireNonNull(name, "name is null"); + this.fields = Objects.requireNonNull(fields, "fields is null"); + this.primaryFields = Objects.requireNonNull(primaryFields, "primaryFields is null"); + this.partitionFields = Objects.requireNonNull(partitionFields, "partitionFields is null"); + this.config = Objects.requireNonNull(config, "config is null"); + this.ifNotExists = ifNotExists; + this.isTemporary = isTemporaryTable; + } + + @Override + public Statistic getStatistic() { + return new GQLStatistic(); + } + + @Override + public RelDataType getRowType(RelDataTypeFactory typeFactory) { + List fieldNames = new ArrayList<>(); + List fieldTypes = new ArrayList<>(); + for (TableField field : fields) { + fieldNames.add(field.getName()); + fieldTypes.add( + SqlTypeUtil.convertToRelType(field.getType(), field.isNullable(), typeFactory)); + } + return typeFactory.createStructType(fieldTypes, fieldNames); + } + + public TableSchema getTableSchema() { + return new TableSchema(getDataSchema(), getPartitionSchema()); + } + + public StructType getDataSchema() { + return new StructType(fields).dropRight(partitionFields.size()); + } + + public StructType getPartitionSchema() { + List pFields = new ArrayList<>(partitionFields.size()); + for (int i = fields.size() - partitionFields.size(); i < fields.size(); i++) { + pFields.add(fields.get(i)); + } + return new StructType(pFields); + } + + public String getInstanceName() { + return instanceName; + } + + public String getName() { + return name; + } + + public List getFields() { + return fields; + } + + public List getPrimaryFields() { + return primaryFields; + } + + public List getPartitionFields() { + return partitionFields; + } + + public List getPartitionIndices() { + StructType tableSchema = getTableSchema(); + return partitionFields.stream().map(tableSchema::indexOf).collect(Collectors.toList()); + } + + public Map getConfig() { + return config; + } + + public boolean isIfNotExists() { + return ifNotExists; + } + + public boolean isTemporary() { + return isTemporary; + } + + public String getTableType() { + return Configuration.getString(DSLConfigKeys.GEAFLOW_DSL_TABLE_TYPE, config); + } + + public Configuration getConfigWithGlobal(Configuration globalConf) { + Map conf = new HashMap<>(globalConf.getConfigMap()); + conf.putAll(this.config); + return new Configuration(conf); + } + + public Configuration getConfigWithGlobal( + Configuration globalConf, Map setOptions) { + Map conf = new HashMap<>(globalConf.getConfigMap()); + conf.putAll(setOptions); + conf.putAll(this.config); + return new Configuration(conf); + } + + public boolean isPartitionField(int index) { + return partitionFields.contains(fields.get(index).getName()); + } + + @Override + public String toString() { + StringBuilder sql = new StringBuilder(); + sql.append("CREATE TABLE "); + + sql.append(name).append(" (\n"); + + for (int i = 0; i < fields.size(); i++) { + TableField field = fields.get(i); + if (i > 0) { + sql.append("\n\t,"); + } else { + sql.append("\t"); + } + sql.append(field.getName()).append("\t"); + sql.append(field.getType()); + } + + sql.append("\n)"); + + if (config.size() > 0) { + boolean first = true; + sql.append(" WITH (\n"); + for (Map.Entry entry : config.entrySet()) { + if (!first) { + sql.append("\t,"); + } else { + sql.append("\t"); } - return sql.toString(); - } - - public enum StoreType { - FILE, - CONSOLE; - - public static StoreType of(String value) { - for (StoreType storeType : values()) { - if (storeType.name().equalsIgnoreCase(value)) { - return storeType; - } - } - throw new IllegalArgumentException("Illegal storeType: " + value); + first = false; + sql.append(entry.getKey()) + .append("=") + .append(StringLiteralUtil.escapeSQLString(entry.getValue())) + .append("\n"); + } + sql.append(")"); + } + return sql.toString(); + } + + public enum StoreType { + FILE, + CONSOLE; + + public static StoreType of(String value) { + for (StoreType storeType : values()) { + if (storeType.name().equalsIgnoreCase(value)) { + return storeType; } + } + throw new IllegalArgumentException("Illegal storeType: " + value); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/schema/GeaFlowView.java b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/schema/GeaFlowView.java index 9438d75f2..df46c088b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/schema/GeaFlowView.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-catalog/src/main/java/org/apache/geaflow/dsl/schema/GeaFlowView.java @@ -23,50 +23,53 @@ import java.util.Collections; import java.util.List; import java.util.Objects; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.schema.impl.ViewTable; import org.apache.commons.lang3.StringUtils; public class GeaFlowView extends ViewTable implements Serializable { - private final String instanceName; + private final String instanceName; - private final String name; + private final String name; - private final List fields; + private final List fields; - private final boolean ifNotExists; + private final boolean ifNotExists; - public GeaFlowView(String instanceName, String name, List fields, RelDataType rowType, - String viewSql, boolean ifNotExists) { - super(null, t -> rowType, viewSql, Collections.emptyList(), Collections.emptyList()); - this.instanceName = instanceName; - this.name = Objects.requireNonNull(name, "name is null"); - this.fields = Objects.requireNonNull(fields, "fields is null"); - this.ifNotExists = ifNotExists; - } + public GeaFlowView( + String instanceName, + String name, + List fields, + RelDataType rowType, + String viewSql, + boolean ifNotExists) { + super(null, t -> rowType, viewSql, Collections.emptyList(), Collections.emptyList()); + this.instanceName = instanceName; + this.name = Objects.requireNonNull(name, "name is null"); + this.fields = Objects.requireNonNull(fields, "fields is null"); + this.ifNotExists = ifNotExists; + } - public String getInstanceName() { - return instanceName; - } + public String getInstanceName() { + return instanceName; + } - public String getName() { - return name; - } + public String getName() { + return name; + } - public List getFields() { - return fields; - } + public List getFields() { + return fields; + } - public boolean isIfNotExists() { - return ifNotExists; - } + public boolean isIfNotExists() { + return ifNotExists; + } - @Override - public String toString() { - return "Create View " + name + "(" - + StringUtils.join(fields, ",") - + ") AS" - + getViewSql(); - } + @Override + public String toString() { + return "Create View " + name + "(" + StringUtils.join(fields, ",") + ") AS" + getViewSql(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/AlgorithmRuntimeContext.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/AlgorithmRuntimeContext.java index 5a73c8c1b..fac2e0a6d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/AlgorithmRuntimeContext.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/AlgorithmRuntimeContext.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.common.algo; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.dsl.common.data.Row; @@ -29,135 +30,136 @@ import org.apache.geaflow.state.pushdown.filter.IFilter; /** - * Interface defining methods for managing and interacting with the runtime context of a graph algorithm. + * Interface defining methods for managing and interacting with the runtime context of a graph + * algorithm. * * @param The type of vertex IDs. * @param The type of messages that can be sent between vertices. */ public interface AlgorithmRuntimeContext { - /** - * Loads all edges in the specified direction. - * - * @param direction The direction of the edges to be loaded. - * @return A list of RowEdge objects representing the edges. - */ - List loadEdges(EdgeDirection direction); - - /** - * Returns an iterator over all edges in the specified direction. - * - * @param direction The direction of the edges to iterate over. - * @return An iterator over RowEdge objects representing the edges. - */ - CloseableIterator loadEdgesIterator(EdgeDirection direction); - - /** - * Returns an iterator over edges filtered by the provided IFilter. - * - * @param filter The filter to apply when loading edges. - * @return An iterator over RowEdge objects representing the filtered edges. - */ - CloseableIterator loadEdgesIterator(IFilter filter); - - /** - * Loads static edges in the specified direction. - * - * @param direction The direction of the edges to be loaded. - * @return A list of RowEdge objects representing the static edges. - */ - List loadStaticEdges(EdgeDirection direction); - - /** - * Returns an iterator over static edges in the specified direction. - * - * @param direction The direction of the edges to iterate over. - * @return An iterator over RowEdge objects representing the static edges. - */ - CloseableIterator loadStaticEdgesIterator(EdgeDirection direction); - - /** - * Returns an iterator over static edges filtered by the provided IFilter. - * - * @param filter The filter to apply when loading edges. - * @return An iterator over RowEdge objects representing the filtered static edges. - */ - CloseableIterator loadStaticEdgesIterator(IFilter filter); - - /** - * Loads dynamic edges (changed during execution) in the specified direction. - * - * @param direction The direction of the edges to be loaded. - * @return A list of RowEdge objects representing the dynamic edges. - */ - List loadDynamicEdges(EdgeDirection direction); - - /** - * Returns an iterator over dynamic edges in the specified direction. - * - * @param direction The direction of the edges to iterate over. - * @return An iterator over RowEdge objects representing the dynamic edges. - */ - CloseableIterator loadDynamicEdgesIterator(EdgeDirection direction); - - /** - * Returns an iterator over dynamic edges filtered by the provided IFilter. - * - * @param filter The filter to apply when loading edges. - * @return An iterator over RowEdge objects representing the filtered dynamic edges. - */ - CloseableIterator loadDynamicEdgesIterator(IFilter filter); - - /** - * Sends a message to the specified vertex. - * - * @param vertexId The ID of the vertex to which the message should be sent. - * @param message The message to send. - */ - void sendMessage(K vertexId, M message); - - /** - * Updates the current vertex's value with the new row data. - * - * @param value The new row data to set as the vertex value. - */ - void updateVertexValue(Row value); - - /** - * Takes a row of data, typically received from another source. - * - * @param value The row data to take. - */ - void take(Row value); - - /** - * Gets the unique identifier for the current iteration. - * - * @return The current iteration ID. - */ - long getCurrentIterationId(); - - /** - * Retrieves the schema information for the graph. - * - * @return The GraphSchema object containing the graph structure details. - */ - GraphSchema getGraphSchema(); - - /** - * Retrieves the configuration settings for the algorithm runtime context. - * - * @return The Configuration object containing the settings. - */ - Configuration getConfig(); - - /** - * Sends a termination vote to the coordinator to signal algorithm completion. - * This method allows vertices to vote for algorithm termination when they - * determine that no further computation is needed. - * - * @param terminationReason The reason for termination (e.g., "CONVERGED", "COMPLETED") - * @param voteValue The vote value (typically 1 for termination vote) - */ - void voteToTerminate(String terminationReason, Object voteValue); -} \ No newline at end of file + /** + * Loads all edges in the specified direction. + * + * @param direction The direction of the edges to be loaded. + * @return A list of RowEdge objects representing the edges. + */ + List loadEdges(EdgeDirection direction); + + /** + * Returns an iterator over all edges in the specified direction. + * + * @param direction The direction of the edges to iterate over. + * @return An iterator over RowEdge objects representing the edges. + */ + CloseableIterator loadEdgesIterator(EdgeDirection direction); + + /** + * Returns an iterator over edges filtered by the provided IFilter. + * + * @param filter The filter to apply when loading edges. + * @return An iterator over RowEdge objects representing the filtered edges. + */ + CloseableIterator loadEdgesIterator(IFilter filter); + + /** + * Loads static edges in the specified direction. + * + * @param direction The direction of the edges to be loaded. + * @return A list of RowEdge objects representing the static edges. + */ + List loadStaticEdges(EdgeDirection direction); + + /** + * Returns an iterator over static edges in the specified direction. + * + * @param direction The direction of the edges to iterate over. + * @return An iterator over RowEdge objects representing the static edges. + */ + CloseableIterator loadStaticEdgesIterator(EdgeDirection direction); + + /** + * Returns an iterator over static edges filtered by the provided IFilter. + * + * @param filter The filter to apply when loading edges. + * @return An iterator over RowEdge objects representing the filtered static edges. + */ + CloseableIterator loadStaticEdgesIterator(IFilter filter); + + /** + * Loads dynamic edges (changed during execution) in the specified direction. + * + * @param direction The direction of the edges to be loaded. + * @return A list of RowEdge objects representing the dynamic edges. + */ + List loadDynamicEdges(EdgeDirection direction); + + /** + * Returns an iterator over dynamic edges in the specified direction. + * + * @param direction The direction of the edges to iterate over. + * @return An iterator over RowEdge objects representing the dynamic edges. + */ + CloseableIterator loadDynamicEdgesIterator(EdgeDirection direction); + + /** + * Returns an iterator over dynamic edges filtered by the provided IFilter. + * + * @param filter The filter to apply when loading edges. + * @return An iterator over RowEdge objects representing the filtered dynamic edges. + */ + CloseableIterator loadDynamicEdgesIterator(IFilter filter); + + /** + * Sends a message to the specified vertex. + * + * @param vertexId The ID of the vertex to which the message should be sent. + * @param message The message to send. + */ + void sendMessage(K vertexId, M message); + + /** + * Updates the current vertex's value with the new row data. + * + * @param value The new row data to set as the vertex value. + */ + void updateVertexValue(Row value); + + /** + * Takes a row of data, typically received from another source. + * + * @param value The row data to take. + */ + void take(Row value); + + /** + * Gets the unique identifier for the current iteration. + * + * @return The current iteration ID. + */ + long getCurrentIterationId(); + + /** + * Retrieves the schema information for the graph. + * + * @return The GraphSchema object containing the graph structure details. + */ + GraphSchema getGraphSchema(); + + /** + * Retrieves the configuration settings for the algorithm runtime context. + * + * @return The Configuration object containing the settings. + */ + Configuration getConfig(); + + /** + * Sends a termination vote to the coordinator to signal algorithm completion. This method allows + * vertices to vote for algorithm termination when they determine that no further computation is + * needed. + * + * @param terminationReason The reason for termination (e.g., "CONVERGED", "COMPLETED") + * @param voteValue The vote value (typically 1 for termination vote) + */ + void voteToTerminate(String terminationReason, Object voteValue); +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/AlgorithmUserFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/AlgorithmUserFunction.java index 4058ff6f6..bb85e7bc8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/AlgorithmUserFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/AlgorithmUserFunction.java @@ -19,10 +19,10 @@ package org.apache.geaflow.dsl.common.algo; - import java.io.Serializable; import java.util.Iterator; import java.util.Optional; + import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.RowVertex; import org.apache.geaflow.dsl.common.types.GraphSchema; @@ -36,38 +36,26 @@ */ public interface AlgorithmUserFunction extends Serializable { - /** - * Init method for the function. - * - * @param context The runtime context. - * @param params The parameters for the function. - */ - void init(AlgorithmRuntimeContext context, Object[] params); + /** + * Init method for the function. + * + * @param context The runtime context. + * @param params The parameters for the function. + */ + void init(AlgorithmRuntimeContext context, Object[] params); - /** - * Processing method for each vertex and the messages it received. - */ - void process(RowVertex vertex, Optional updatedValues, Iterator messages); + /** Processing method for each vertex and the messages it received. */ + void process(RowVertex vertex, Optional updatedValues, Iterator messages); - /** - * Finish method called by each vertex upon algorithm convergence. - */ - void finish(RowVertex graphVertex, Optional updatedValues); + /** Finish method called by each vertex upon algorithm convergence. */ + void finish(RowVertex graphVertex, Optional updatedValues); - /** - * Finish method called after all vertices is processed. - */ - default void finish() { - } + /** Finish method called after all vertices is processed. */ + default void finish() {} - /** - * Finish Iteration method called after each iteration finished. - */ - default void finishIteration(long iterationId) { - } + /** Finish Iteration method called after each iteration finished. */ + default void finishIteration(long iterationId) {} - /** - * Returns the output type for the function. - */ - StructType getOutputType(GraphSchema graphSchema); + /** Returns the output type for the function. */ + StructType getOutputType(GraphSchema graphSchema); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/IncrementalAlgorithmUserFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/IncrementalAlgorithmUserFunction.java index b19373e2b..a9720da7d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/IncrementalAlgorithmUserFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/IncrementalAlgorithmUserFunction.java @@ -19,9 +19,5 @@ package org.apache.geaflow.dsl.common.algo; -/** - * Interface for the User Defined Graph Algorithm support incremental calculation. - */ -public interface IncrementalAlgorithmUserFunction { - -} +/** Interface for the User Defined Graph Algorithm support incremental calculation. */ +public interface IncrementalAlgorithmUserFunction {} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/BinaryLayoutHelper.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/BinaryLayoutHelper.java index d3aa601cb..fe5702ab1 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/BinaryLayoutHelper.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/BinaryLayoutHelper.java @@ -28,99 +28,98 @@ import org.apache.geaflow.common.binary.IBinaryObject; import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; -/** - * This is based on Spark's BitSetMethods. - */ +/** This is based on Spark's BitSetMethods. */ public class BinaryLayoutHelper { - public static final int FIELDS_NUM_OFFSET = 0; - - public static final int NULL_BIT_OFFSET = 4; - - public static int getFieldsNum(IBinaryObject binaryObject) { - return BinaryOperations.getInt(binaryObject, FIELDS_NUM_OFFSET); - } - - /** - * Get the number bytes need to store null bit. - * - * @param numValues Number of values. - */ - public static int getBitSetBytes(int numValues) { - return (numValues + 7) / 8; - } - - public static int getExtendPoint(int numValues) { - return NULL_BIT_OFFSET + getBitSetBytes(numValues) + numValues * 8; - } - - /** - * Set the index-th bit to 1 from the baseOffset. - * - * @param baseObject The baseObject to set, it may be a byte[] for heap memory or null for off-heap. - * @param baseOffset The base offset. - * @param index The index-th bit to set to 1. - */ - public static void set(IBinaryObject baseObject, long baseOffset, int index) { - if (index < 0) { - throw new GeaFlowDSLException("index (" + index + ") should >= 0"); - } - final long mask = 1L << (index & 0x7); // mod 8 and shift - final long wordOffset = baseOffset + (index >> 3); // div 8 - final byte word = getByte(baseObject, wordOffset); - putByte(baseObject, wordOffset, (byte) (word | mask)); - } - - public static void unset(IBinaryObject baseObject, long baseOffset, int index) { - if (index < 0) { - throw new GeaFlowDSLException("index (" + index + ") should >= 0"); - } - final long mask = 1L << (index & 0x7); // mod 8 and shift - final long wordOffset = baseOffset + (index >> 3); - final byte word = getByte(baseObject, wordOffset); - putByte(baseObject, wordOffset, (byte) (word & ~mask)); - } - - public static boolean isSet(IBinaryObject baseObject, long baseOffset, int index) { - if (index < 0) { - throw new GeaFlowDSLException("index (" + index + ") should >= 0"); - } - final long mask = 1L << (index & 0x7); // mod 8 and shift - final long wordOffset = baseOffset + (index >> 3); - final byte word = getByte(baseObject, wordOffset); - return (word & mask) != 0; - } - - public static long getFieldOffset(int nullBitSetBytes, int index) { - return NULL_BIT_OFFSET + nullBitSetBytes + index * 8L; + public static final int FIELDS_NUM_OFFSET = 0; + + public static final int NULL_BIT_OFFSET = 4; + + public static int getFieldsNum(IBinaryObject binaryObject) { + return BinaryOperations.getInt(binaryObject, FIELDS_NUM_OFFSET); + } + + /** + * Get the number bytes need to store null bit. + * + * @param numValues Number of values. + */ + public static int getBitSetBytes(int numValues) { + return (numValues + 7) / 8; + } + + public static int getExtendPoint(int numValues) { + return NULL_BIT_OFFSET + getBitSetBytes(numValues) + numValues * 8; + } + + /** + * Set the index-th bit to 1 from the baseOffset. + * + * @param baseObject The baseObject to set, it may be a byte[] for heap memory or null for + * off-heap. + * @param baseOffset The base offset. + * @param index The index-th bit to set to 1. + */ + public static void set(IBinaryObject baseObject, long baseOffset, int index) { + if (index < 0) { + throw new GeaFlowDSLException("index (" + index + ") should >= 0"); } - - public static long getArrayFieldOffset(int nullBitSetBytes, int index) { - return nullBitSetBytes + index * 8L; + final long mask = 1L << (index & 0x7); // mod 8 and shift + final long wordOffset = baseOffset + (index >> 3); // div 8 + final byte word = getByte(baseObject, wordOffset); + putByte(baseObject, wordOffset, (byte) (word | mask)); + } + + public static void unset(IBinaryObject baseObject, long baseOffset, int index) { + if (index < 0) { + throw new GeaFlowDSLException("index (" + index + ") should >= 0"); } - - public static void zeroBytes(byte[] bytes) { - zeroBytes(bytes, 0, bytes.length); + final long mask = 1L << (index & 0x7); // mod 8 and shift + final long wordOffset = baseOffset + (index >> 3); + final byte word = getByte(baseObject, wordOffset); + putByte(baseObject, wordOffset, (byte) (word & ~mask)); + } + + public static boolean isSet(IBinaryObject baseObject, long baseOffset, int index) { + if (index < 0) { + throw new GeaFlowDSLException("index (" + index + ") should >= 0"); } - - public static void zeroBytes(byte[] bytes, long baseOffset, int size) { - zeroBytes(HeapBinaryObject.of(bytes), baseOffset, size); + final long mask = 1L << (index & 0x7); // mod 8 and shift + final long wordOffset = baseOffset + (index >> 3); + final byte word = getByte(baseObject, wordOffset); + return (word & mask) != 0; + } + + public static long getFieldOffset(int nullBitSetBytes, int index) { + return NULL_BIT_OFFSET + nullBitSetBytes + index * 8L; + } + + public static long getArrayFieldOffset(int nullBitSetBytes, int index) { + return nullBitSetBytes + index * 8L; + } + + public static void zeroBytes(byte[] bytes) { + zeroBytes(bytes, 0, bytes.length); + } + + public static void zeroBytes(byte[] bytes, long baseOffset, int size) { + zeroBytes(HeapBinaryObject.of(bytes), baseOffset, size); + } + + public static void zeroBytes(IBinaryObject baseObject, long baseOffset, int size) { + assert baseOffset >= 0 && baseOffset + size <= baseObject.size(); + + int workAlign = size / 8 * 8; + for (int i = 0; i < workAlign; i += 8) { + putLong(baseObject, baseOffset + i, 0L); } - - public static void zeroBytes(IBinaryObject baseObject, long baseOffset, int size) { - assert baseOffset >= 0 && baseOffset + size <= baseObject.size(); - - int workAlign = size / 8 * 8; - for (int i = 0; i < workAlign; i += 8) { - putLong(baseObject, baseOffset + i, 0L); - } - int retain = size - workAlign; - for (int i = 0; i < retain; i++) { - putByte(baseObject, baseOffset + workAlign + i, (byte) 0); - } + int retain = size - workAlign; + for (int i = 0; i < retain; i++) { + putByte(baseObject, baseOffset + workAlign + i, (byte) 0); } + } - public static int getInitBufferSize(int fieldsNum) { - return NULL_BIT_OFFSET + fieldsNum * 8; - } + public static int getInitBufferSize(int fieldsNum) { + return NULL_BIT_OFFSET + fieldsNum * 8; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/DecoderFactory.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/DecoderFactory.java index 5cb434bfe..dbd0c0e7c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/DecoderFactory.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/DecoderFactory.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.common.binary; import java.util.Locale; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.binary.decoder.DefaultEdgeDecoder; @@ -39,35 +40,35 @@ public class DecoderFactory { - public static IBinaryDecoder createDecoder(IType type) { - String typeName = type.getName().toUpperCase(Locale.ROOT); - switch (typeName) { - case Types.TYPE_NAME_VERTEX: - return new DefaultVertexDecoder((VertexType) type); - case Types.TYPE_NAME_EDGE: - return new DefaultEdgeDecoder((EdgeType) type); - case Types.TYPE_NAME_STRUCT: - return new DefaultRowDecoder((StructType) type); - case Types.TYPE_NAME_PATH: - return new DefaultPathDecoder((PathType) type); - default: - throw new GeaFlowDSLException("decoder type " + type.getName() + " is not support"); - } + public static IBinaryDecoder createDecoder(IType type) { + String typeName = type.getName().toUpperCase(Locale.ROOT); + switch (typeName) { + case Types.TYPE_NAME_VERTEX: + return new DefaultVertexDecoder((VertexType) type); + case Types.TYPE_NAME_EDGE: + return new DefaultEdgeDecoder((EdgeType) type); + case Types.TYPE_NAME_STRUCT: + return new DefaultRowDecoder((StructType) type); + case Types.TYPE_NAME_PATH: + return new DefaultPathDecoder((PathType) type); + default: + throw new GeaFlowDSLException("decoder type " + type.getName() + " is not support"); } + } - public static VertexDecoder createVertexDecoder(VertexType vertexType) { - return new DefaultVertexDecoder(vertexType); - } + public static VertexDecoder createVertexDecoder(VertexType vertexType) { + return new DefaultVertexDecoder(vertexType); + } - public static EdgeDecoder createEdgeDecoder(EdgeType edgeType) { - return new DefaultEdgeDecoder(edgeType); - } + public static EdgeDecoder createEdgeDecoder(EdgeType edgeType) { + return new DefaultEdgeDecoder(edgeType); + } - public static RowDecoder createRowDecoder(StructType rowType) { - return new DefaultRowDecoder(rowType); - } + public static RowDecoder createRowDecoder(StructType rowType) { + return new DefaultRowDecoder(rowType); + } - public static PathDecoder createPathDecoder(PathType pathType) { - return new DefaultPathDecoder(pathType); - } + public static PathDecoder createPathDecoder(PathType pathType) { + return new DefaultPathDecoder(pathType); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/EncoderFactory.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/EncoderFactory.java index 0f3e329a6..896e577a3 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/EncoderFactory.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/EncoderFactory.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.common.binary; import java.util.Locale; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.binary.encoder.DefaultEdgeEncoder; @@ -36,29 +37,29 @@ public class EncoderFactory { - public static IBinaryEncoder createEncoder(IType type) { - String typeName = type.getName().toUpperCase(Locale.ROOT); - switch (typeName) { - case Types.TYPE_NAME_VERTEX: - return new DefaultVertexEncoder((VertexType) type); - case Types.TYPE_NAME_EDGE: - return new DefaultEdgeEncoder((EdgeType) type); - case Types.TYPE_NAME_STRUCT: - return new DefaultRowEncoder((StructType) type); - default: - throw new GeaFlowDSLException("encoder type " + type.getName() + " is not support"); - } + public static IBinaryEncoder createEncoder(IType type) { + String typeName = type.getName().toUpperCase(Locale.ROOT); + switch (typeName) { + case Types.TYPE_NAME_VERTEX: + return new DefaultVertexEncoder((VertexType) type); + case Types.TYPE_NAME_EDGE: + return new DefaultEdgeEncoder((EdgeType) type); + case Types.TYPE_NAME_STRUCT: + return new DefaultRowEncoder((StructType) type); + default: + throw new GeaFlowDSLException("encoder type " + type.getName() + " is not support"); } + } - public static VertexEncoder createVertexEncoder(VertexType vertexType) { - return new DefaultVertexEncoder(vertexType); - } + public static VertexEncoder createVertexEncoder(VertexType vertexType) { + return new DefaultVertexEncoder(vertexType); + } - public static EdgeEncoder createEdgeEncoder(EdgeType edgeType) { - return new DefaultEdgeEncoder(edgeType); - } + public static EdgeEncoder createEdgeEncoder(EdgeType edgeType) { + return new DefaultEdgeEncoder(edgeType); + } - public static RowEncoder createRowEncoder(StructType rowType) { - return new DefaultRowEncoder(rowType); - } + public static RowEncoder createRowEncoder(StructType rowType) { + return new DefaultRowEncoder(rowType); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/FieldReaderFactory.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/FieldReaderFactory.java index 8df5ff0b8..3261781c9 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/FieldReaderFactory.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/FieldReaderFactory.java @@ -27,6 +27,7 @@ import java.sql.Date; import java.sql.Timestamp; import java.util.Locale; + import org.apache.geaflow.common.binary.BinaryOperations; import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.binary.IBinaryObject; @@ -39,74 +40,73 @@ public class FieldReaderFactory { - public interface PropertyFieldReader { + public interface PropertyFieldReader { - V read(IBinaryObject baseObject, long offset); - } + V read(IBinaryObject baseObject, long offset); + } - public static PropertyFieldReader getPropertyFieldReader(IType type) { - String typeName = type.getName().toUpperCase(Locale.ROOT); - switch (typeName) { - case Types.TYPE_NAME_INTEGER: - return BinaryOperations::getInt; - case Types.TYPE_NAME_LONG: - return BinaryOperations::getLong; - case Types.TYPE_NAME_SHORT: - return BinaryOperations::getShort; - case Types.TYPE_NAME_DOUBLE: - return BinaryOperations::getDouble; - case Types.TYPE_NAME_BINARY_STRING: - return (baseObject, offset) -> { - int size = BinaryOperations.getInt(baseObject, offset); - long stringOffset = getContentOffset(baseObject, offset); - return new BinaryString(baseObject, stringOffset, size); - }; - case Types.TYPE_NAME_BOOLEAN: - return (baseObject, offset) -> BinaryOperations.getInt(baseObject, offset) == 1; - case Types.TYPE_NAME_TIMESTAMP: - return (baseObject, offset) -> new Timestamp(BinaryOperations.getLong(baseObject, - offset)); - case Types.TYPE_NAME_DATE: - return (baseObject, offset) -> new Date(BinaryOperations.getLong(baseObject, - offset)); - case Types.TYPE_NAME_OBJECT: - case Types.TYPE_NAME_VERTEX: - case Types.TYPE_NAME_EDGE: - return (baseObject, offset) -> { - int size = BinaryOperations.getInt(baseObject, offset); - long bytesOffset = getContentOffset(baseObject, offset); - byte[] objectBytes = new byte[size]; - BinaryOperations.copyMemory(baseObject, bytesOffset, objectBytes, 0, size); - return SerializerFactory.getKryoSerializer().deserialize(objectBytes); - }; - case Types.TYPE_NAME_ARRAY: - return (baseObject, offset) -> { - int arraySize = BinaryOperations.getInt(baseObject, offset); - long arrayOffset = getContentOffset(baseObject, offset); + public static PropertyFieldReader getPropertyFieldReader(IType type) { + String typeName = type.getName().toUpperCase(Locale.ROOT); + switch (typeName) { + case Types.TYPE_NAME_INTEGER: + return BinaryOperations::getInt; + case Types.TYPE_NAME_LONG: + return BinaryOperations::getLong; + case Types.TYPE_NAME_SHORT: + return BinaryOperations::getShort; + case Types.TYPE_NAME_DOUBLE: + return BinaryOperations::getDouble; + case Types.TYPE_NAME_BINARY_STRING: + return (baseObject, offset) -> { + int size = BinaryOperations.getInt(baseObject, offset); + long stringOffset = getContentOffset(baseObject, offset); + return new BinaryString(baseObject, stringOffset, size); + }; + case Types.TYPE_NAME_BOOLEAN: + return (baseObject, offset) -> BinaryOperations.getInt(baseObject, offset) == 1; + case Types.TYPE_NAME_TIMESTAMP: + return (baseObject, offset) -> new Timestamp(BinaryOperations.getLong(baseObject, offset)); + case Types.TYPE_NAME_DATE: + return (baseObject, offset) -> new Date(BinaryOperations.getLong(baseObject, offset)); + case Types.TYPE_NAME_OBJECT: + case Types.TYPE_NAME_VERTEX: + case Types.TYPE_NAME_EDGE: + return (baseObject, offset) -> { + int size = BinaryOperations.getInt(baseObject, offset); + long bytesOffset = getContentOffset(baseObject, offset); + byte[] objectBytes = new byte[size]; + BinaryOperations.copyMemory(baseObject, bytesOffset, objectBytes, 0, size); + return SerializerFactory.getKryoSerializer().deserialize(objectBytes); + }; + case Types.TYPE_NAME_ARRAY: + return (baseObject, offset) -> { + int arraySize = BinaryOperations.getInt(baseObject, offset); + long arrayOffset = getContentOffset(baseObject, offset); - IType componentType = ((ArrayType) type).getComponentType(); - Object[] array = - (Object[]) Array.newInstance( - FunctionCallUtils.typeClass(componentType.getTypeClass(), true), arraySize); + IType componentType = ((ArrayType) type).getComponentType(); + Object[] array = + (Object[]) + Array.newInstance( + FunctionCallUtils.typeClass(componentType.getTypeClass(), true), arraySize); - for (int i = 0; i < arraySize; i++) { - if (isSet(baseObject, arrayOffset, i)) { - array[i] = null; - } else { - long elementOffset = arrayOffset + getArrayFieldOffset(getBitSetBytes(arraySize), i); + for (int i = 0; i < arraySize; i++) { + if (isSet(baseObject, arrayOffset, i)) { + array[i] = null; + } else { + long elementOffset = arrayOffset + getArrayFieldOffset(getBitSetBytes(arraySize), i); - PropertyFieldReader elementReader = getPropertyFieldReader(componentType); - array[i] = elementReader.read(baseObject, elementOffset); - } - } - return array; - }; - default: - throw new GeaFlowDSLException("field type not supported: " + typeName); - } + PropertyFieldReader elementReader = getPropertyFieldReader(componentType); + array[i] = elementReader.read(baseObject, elementOffset); + } + } + return array; + }; + default: + throw new GeaFlowDSLException("field type not supported: " + typeName); } + } - private static long getContentOffset(IBinaryObject baseObject, long headOffset) { - return BinaryOperations.getInt(baseObject, headOffset + 4); - } + private static long getContentOffset(IBinaryObject baseObject, long headOffset) { + return BinaryOperations.getInt(baseObject, headOffset + 4); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/FieldWriterFactory.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/FieldWriterFactory.java index 3edcdd086..0b1dc1a24 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/FieldWriterFactory.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/FieldWriterFactory.java @@ -22,6 +22,7 @@ import static org.apache.geaflow.dsl.common.binary.BinaryLayoutHelper.zeroBytes; import java.util.Locale; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.serialize.SerializerFactory; import org.apache.geaflow.common.type.IType; @@ -31,165 +32,174 @@ public class FieldWriterFactory { - public interface PropertyFieldWriter { + public interface PropertyFieldWriter { - void write(WriterBuffer writerBuffer, long nullBitsOffset, int index, V value); - } + void write(WriterBuffer writerBuffer, long nullBitsOffset, int index, V value); + } - public static PropertyFieldWriter getPropertyFieldWriter(IType type) { - String typeName = type.getName().toUpperCase(Locale.ROOT); - switch (typeName) { - case Types.TYPE_NAME_INTEGER: - return (PropertyFieldWriter) (writerBuffer, nullBitsOffset, index, value) -> { - if (value == null) { - writerBuffer.setNullAt(nullBitsOffset, index); - } else { - writerBuffer.writeIntAlign(value); - } - }; - case Types.TYPE_NAME_LONG: - return (PropertyFieldWriter) (writerBuffer, nullBitsOffset, index, value) -> { - if (value == null) { - writerBuffer.setNullAt(nullBitsOffset, index); - } else { - writerBuffer.writeLong(value); - } - }; - case Types.TYPE_NAME_SHORT: - return (PropertyFieldWriter) (writerBuffer, nullBitsOffset, index, value) -> { - if (value == null) { - writerBuffer.setNullAt(nullBitsOffset, index); - } else { - writerBuffer.writeShortAlign(value); - } - }; - case Types.TYPE_NAME_DOUBLE: - return (PropertyFieldWriter) (writerBuffer, nullBitsOffset, index, value) -> { - if (value == null) { - writerBuffer.setNullAt(nullBitsOffset, index); - } else { - writerBuffer.writeDouble(value); - } - }; - case Types.TYPE_NAME_BINARY_STRING: - return (PropertyFieldWriter) (writerBuffer, nullBitsOffset, index, value) -> { - writeString(writerBuffer, nullBitsOffset, index, value); - }; - case Types.TYPE_NAME_BOOLEAN: - return (PropertyFieldWriter) (writerBuffer, nullBitsOffset, index, - value) -> { - if (value == null) { - writerBuffer.setNullAt(nullBitsOffset, index); - } else { - writerBuffer.writeIntAlign(value ? 1 : 0); - } - }; - case Types.TYPE_NAME_TIMESTAMP: - case Types.TYPE_NAME_DATE: - return (PropertyFieldWriter) (writerBuffer, nullBitsOffset, index, value) -> { - if (value == null) { - writerBuffer.setNullAt(nullBitsOffset, index); - } else { - writerBuffer.writeLong(value.getTime()); - } - }; - case Types.TYPE_NAME_OBJECT: - case Types.TYPE_NAME_VERTEX: - case Types.TYPE_NAME_EDGE: - return (PropertyFieldWriter) (writerBuffer, nullBitsOffset, index, value) -> { - byte[] bytes = null; - if (value != null) { - bytes = SerializerFactory.getKryoSerializer().serialize(value); - } - writeBytes(writerBuffer, nullBitsOffset, index, bytes); - }; - case Types.TYPE_NAME_ARRAY: - ArrayType arrayType = (ArrayType) type; - return (PropertyFieldWriter) (writerBuffer, nullBitsOffset, index, value) -> { - if (value == null) { - writerBuffer.setNullAt(nullBitsOffset, index); - } else { - Object[] array = (Object[]) value; - int currentCursor = writerBuffer.getCursor(); - final int arrayStartOffset = writerBuffer.getExtendPoint(); - writerBuffer.moveToExtend(); + public static PropertyFieldWriter getPropertyFieldWriter(IType type) { + String typeName = type.getName().toUpperCase(Locale.ROOT); + switch (typeName) { + case Types.TYPE_NAME_INTEGER: + return (PropertyFieldWriter) + (writerBuffer, nullBitsOffset, index, value) -> { + if (value == null) { + writerBuffer.setNullAt(nullBitsOffset, index); + } else { + writerBuffer.writeIntAlign(value); + } + }; + case Types.TYPE_NAME_LONG: + return (PropertyFieldWriter) + (writerBuffer, nullBitsOffset, index, value) -> { + if (value == null) { + writerBuffer.setNullAt(nullBitsOffset, index); + } else { + writerBuffer.writeLong(value); + } + }; + case Types.TYPE_NAME_SHORT: + return (PropertyFieldWriter) + (writerBuffer, nullBitsOffset, index, value) -> { + if (value == null) { + writerBuffer.setNullAt(nullBitsOffset, index); + } else { + writerBuffer.writeShortAlign(value); + } + }; + case Types.TYPE_NAME_DOUBLE: + return (PropertyFieldWriter) + (writerBuffer, nullBitsOffset, index, value) -> { + if (value == null) { + writerBuffer.setNullAt(nullBitsOffset, index); + } else { + writerBuffer.writeDouble(value); + } + }; + case Types.TYPE_NAME_BINARY_STRING: + return (PropertyFieldWriter) + (writerBuffer, nullBitsOffset, index, value) -> { + writeString(writerBuffer, nullBitsOffset, index, value); + }; + case Types.TYPE_NAME_BOOLEAN: + return (PropertyFieldWriter) + (writerBuffer, nullBitsOffset, index, value) -> { + if (value == null) { + writerBuffer.setNullAt(nullBitsOffset, index); + } else { + writerBuffer.writeIntAlign(value ? 1 : 0); + } + }; + case Types.TYPE_NAME_TIMESTAMP: + case Types.TYPE_NAME_DATE: + return (PropertyFieldWriter) + (writerBuffer, nullBitsOffset, index, value) -> { + if (value == null) { + writerBuffer.setNullAt(nullBitsOffset, index); + } else { + writerBuffer.writeLong(value.getTime()); + } + }; + case Types.TYPE_NAME_OBJECT: + case Types.TYPE_NAME_VERTEX: + case Types.TYPE_NAME_EDGE: + return (PropertyFieldWriter) + (writerBuffer, nullBitsOffset, index, value) -> { + byte[] bytes = null; + if (value != null) { + bytes = SerializerFactory.getKryoSerializer().serialize(value); + } + writeBytes(writerBuffer, nullBitsOffset, index, bytes); + }; + case Types.TYPE_NAME_ARRAY: + ArrayType arrayType = (ArrayType) type; + return (PropertyFieldWriter) + (writerBuffer, nullBitsOffset, index, value) -> { + if (value == null) { + writerBuffer.setNullAt(nullBitsOffset, index); + } else { + Object[] array = (Object[]) value; + int currentCursor = writerBuffer.getCursor(); + final int arrayStartOffset = writerBuffer.getExtendPoint(); + writerBuffer.moveToExtend(); - writeArray(writerBuffer, array, arrayType.getComponentType()); - writerBuffer.setCursor(currentCursor); - // write array length - writerBuffer.writeInt(array.length); - // write array content offset - writerBuffer.writeInt(arrayStartOffset); - } - }; - default: - throw new GeaFlowDSLException("field type not supported: " + typeName); - } + writeArray(writerBuffer, array, arrayType.getComponentType()); + writerBuffer.setCursor(currentCursor); + // write array length + writerBuffer.writeInt(array.length); + // write array content offset + writerBuffer.writeInt(arrayStartOffset); + } + }; + default: + throw new GeaFlowDSLException("field type not supported: " + typeName); } + } - public static void writeArray(WriterBuffer writerBuffer, Object[] array, IType componentType) { - int startCursor = writerBuffer.getCursor(); - byte[] nullBitSet = new byte[BinaryLayoutHelper.getBitSetBytes(array.length)]; - // nullBits length + array-length * 8 - int baseSizeNeed = nullBitSet.length + array.length * 8; - writerBuffer.grow(baseSizeNeed); - writerBuffer.setExtendPoint(startCursor + baseSizeNeed); + public static void writeArray(WriterBuffer writerBuffer, Object[] array, IType componentType) { + int startCursor = writerBuffer.getCursor(); + byte[] nullBitSet = new byte[BinaryLayoutHelper.getBitSetBytes(array.length)]; + // nullBits length + array-length * 8 + int baseSizeNeed = nullBitSet.length + array.length * 8; + writerBuffer.grow(baseSizeNeed); + writerBuffer.setExtendPoint(startCursor + baseSizeNeed); - long nullBitOffset = writerBuffer.getCursor(); - // clear null bits - zeroBytes(nullBitSet, 0, nullBitSet.length); - // write null bits - writerBuffer.writeBytes(nullBitSet); + long nullBitOffset = writerBuffer.getCursor(); + // clear null bits + zeroBytes(nullBitSet, 0, nullBitSet.length); + // write null bits + writerBuffer.writeBytes(nullBitSet); - for (int i = 0; i < array.length; i++) { - PropertyFieldWriter writer = getPropertyFieldWriter(componentType); - writer.write(writerBuffer, nullBitOffset, i, array[i]); - } + for (int i = 0; i < array.length; i++) { + PropertyFieldWriter writer = getPropertyFieldWriter(componentType); + writer.write(writerBuffer, nullBitOffset, i, array[i]); } + } - public static void writeString(WriterBuffer writerBuffer, long baseOffset, int index, - BinaryString string) { - if (string == null) { - writerBuffer.setNullAt(baseOffset, index); - } else { - int bytesLength = string.getNumBytes(); - // write bytes size - writerBuffer.writeInt(bytesLength); - // write bytes offset - writerBuffer.writeInt(writerBuffer.getExtendPoint()); - // save current cursor - int currentCursor = writerBuffer.getCursor(); - // move cursor to extend region - writerBuffer.moveToExtend(); - // grow buffer - writerBuffer.growTo(writerBuffer.getExtendPoint() + bytesLength); - writerBuffer.writeBytes(string.getBinaryObject(), string.getOffset(), bytesLength); - // set new extend point - writerBuffer.setExtendPoint(writerBuffer.getCursor()); - // reset cursor - writerBuffer.setCursor(currentCursor); - } + public static void writeString( + WriterBuffer writerBuffer, long baseOffset, int index, BinaryString string) { + if (string == null) { + writerBuffer.setNullAt(baseOffset, index); + } else { + int bytesLength = string.getNumBytes(); + // write bytes size + writerBuffer.writeInt(bytesLength); + // write bytes offset + writerBuffer.writeInt(writerBuffer.getExtendPoint()); + // save current cursor + int currentCursor = writerBuffer.getCursor(); + // move cursor to extend region + writerBuffer.moveToExtend(); + // grow buffer + writerBuffer.growTo(writerBuffer.getExtendPoint() + bytesLength); + writerBuffer.writeBytes(string.getBinaryObject(), string.getOffset(), bytesLength); + // set new extend point + writerBuffer.setExtendPoint(writerBuffer.getCursor()); + // reset cursor + writerBuffer.setCursor(currentCursor); } + } - public static void writeBytes(WriterBuffer writerBuffer, long baseOffset, int index, byte[] bytes) { - if (bytes == null) { - writerBuffer.setNullAt(baseOffset, index); - } else { - // write bytes size - writerBuffer.writeInt(bytes.length); - // write bytes offset - writerBuffer.writeInt(writerBuffer.getExtendPoint()); - // save current cursor - int currentCursor = writerBuffer.getCursor(); - // move cursor to extend region - writerBuffer.moveToExtend(); - // grow buffer - writerBuffer.growTo(writerBuffer.getExtendPoint() + bytes.length); - writerBuffer.writeBytes(bytes); - // set new extend point - writerBuffer.setExtendPoint(writerBuffer.getCursor()); - // reset cursor - writerBuffer.setCursor(currentCursor); - } + public static void writeBytes( + WriterBuffer writerBuffer, long baseOffset, int index, byte[] bytes) { + if (bytes == null) { + writerBuffer.setNullAt(baseOffset, index); + } else { + // write bytes size + writerBuffer.writeInt(bytes.length); + // write bytes offset + writerBuffer.writeInt(writerBuffer.getExtendPoint()); + // save current cursor + int currentCursor = writerBuffer.getCursor(); + // move cursor to extend region + writerBuffer.moveToExtend(); + // grow buffer + writerBuffer.growTo(writerBuffer.getExtendPoint() + bytes.length); + writerBuffer.writeBytes(bytes); + // set new extend point + writerBuffer.setExtendPoint(writerBuffer.getCursor()); + // reset cursor + writerBuffer.setCursor(currentCursor); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/HeapWriterBuffer.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/HeapWriterBuffer.java index 49d248a19..736f49853 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/HeapWriterBuffer.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/HeapWriterBuffer.java @@ -26,181 +26,181 @@ public class HeapWriterBuffer implements WriterBuffer { - private HeapBinaryObject buffer; - - private int cursor; - - private int extendPoint = 0; - - @Override - public void initialize(int initSize) { - if (initSize <= 0) { - throw new GeaFlowDSLException("Illegal init buffer size: " + initSize); - } - buffer = HeapBinaryObject.of(new byte[initSize]); - this.cursor = 0; - } - - @Override - public void grow(int size) { - byte[] newBuffer = new byte[buffer.size() + size]; - BinaryOperations.copyMemory(buffer, 0, - newBuffer, 0, buffer.size()); - this.buffer = HeapBinaryObject.of(newBuffer); - } - - @Override - public void growTo(int targetSize) { - int growSize = targetSize - getCapacity(); - if (growSize > 0) { - grow(growSize); - } - } - - @Override - public byte[] copyBuffer() { - assert buffer != null; - byte[] copy = new byte[extendPoint]; - BinaryOperations.copyMemory(buffer, 0, copy, 0, extendPoint); - return copy; - } - - @Override - public int getCapacity() { - return buffer.size(); - } - - @Override - public void writeByte(byte b) { - BinaryOperations.putByte(buffer, cursor, b); - cursor += 1; - checkCursorAfterWrite(); - } - - @Override - public void writeInt(int v) { - BinaryOperations.putInt(buffer, cursor, v); - cursor += 4; - checkCursorAfterWrite(); - } - - @Override - public void writeIntAlign(int v) { - BinaryOperations.putLong(buffer, cursor, 0L); - BinaryOperations.putInt(buffer, cursor, v); - cursor += 8; - checkCursorAfterWrite(); - } - - @Override - public void writeShort(short v) { - BinaryOperations.putShort(buffer, cursor, v); - cursor += 2; - checkCursorAfterWrite(); - } - - @Override - public void writeShortAlign(short v) { - BinaryOperations.putLong(buffer, cursor, 0L); - BinaryOperations.putShort(buffer, cursor, v); - cursor += 8; - checkCursorAfterWrite(); - } - - @Override - public void writeLong(long v) { - BinaryOperations.putLong(buffer, cursor, v); - cursor += 8; - checkCursorAfterWrite(); - } - - @Override - public void writeDouble(double v) { - BinaryOperations.putDouble(buffer, cursor, v); - cursor += 8; - checkCursorAfterWrite(); - } - - @Override - public void writeBytes(byte[] bytes) { - BinaryOperations.copyMemory(bytes, 0, buffer, cursor, bytes.length); - cursor += bytes.length; - checkCursorAfterWrite(); - } - - @Override - public void writeBytes(IBinaryObject src, long srcOffset, long length) { - BinaryOperations.copyMemory(src, srcOffset, buffer, cursor, length); - cursor += length; - checkCursorAfterWrite(); - } - - private void checkCursorAfterWrite() { - if (cursor <= 0 || cursor > buffer.size()) { - throw new GeaFlowDSLException("Illegal cursor: " + cursor + ", buffer length is:" + buffer.size()); - } - } - - @Override - public int getCursor() { - return cursor; - } - - @Override - public void setCursor(int cursor) { - if (cursor < 0) { - throw new GeaFlowDSLException("Illegal cursor" + cursor); - } - this.cursor = cursor; - } - - @Override - public void moveCursor(int cursor) { - setCursor(getCursor() + cursor); - } - - @Override - public void setExtendPoint(int tailPoint) { - if (tailPoint < this.extendPoint) { - throw new GeaFlowDSLException("Current tailPoint: " + tailPoint + " should >= " - + "the pre value:" + this.extendPoint); - } - this.extendPoint = tailPoint; - if (tailPoint > buffer.size()) { - grow(tailPoint - buffer.size()); - } - } - - @Override - public int getExtendPoint() { - return extendPoint; - } - - @Override - public void moveToExtend() { - setCursor(getExtendPoint()); - } - - @Override - public void reset() { - cursor = 0; - extendPoint = 0; - } - - @Override - public void setNullAt(long offset, int index) { - BinaryLayoutHelper.set(buffer, offset, index); - // align index - this.cursor += 8; - } - - @Override - public void release() { - buffer.release(); - } - - @Override - public boolean isReleased() { - return buffer.isReleased(); - } + private HeapBinaryObject buffer; + + private int cursor; + + private int extendPoint = 0; + + @Override + public void initialize(int initSize) { + if (initSize <= 0) { + throw new GeaFlowDSLException("Illegal init buffer size: " + initSize); + } + buffer = HeapBinaryObject.of(new byte[initSize]); + this.cursor = 0; + } + + @Override + public void grow(int size) { + byte[] newBuffer = new byte[buffer.size() + size]; + BinaryOperations.copyMemory(buffer, 0, newBuffer, 0, buffer.size()); + this.buffer = HeapBinaryObject.of(newBuffer); + } + + @Override + public void growTo(int targetSize) { + int growSize = targetSize - getCapacity(); + if (growSize > 0) { + grow(growSize); + } + } + + @Override + public byte[] copyBuffer() { + assert buffer != null; + byte[] copy = new byte[extendPoint]; + BinaryOperations.copyMemory(buffer, 0, copy, 0, extendPoint); + return copy; + } + + @Override + public int getCapacity() { + return buffer.size(); + } + + @Override + public void writeByte(byte b) { + BinaryOperations.putByte(buffer, cursor, b); + cursor += 1; + checkCursorAfterWrite(); + } + + @Override + public void writeInt(int v) { + BinaryOperations.putInt(buffer, cursor, v); + cursor += 4; + checkCursorAfterWrite(); + } + + @Override + public void writeIntAlign(int v) { + BinaryOperations.putLong(buffer, cursor, 0L); + BinaryOperations.putInt(buffer, cursor, v); + cursor += 8; + checkCursorAfterWrite(); + } + + @Override + public void writeShort(short v) { + BinaryOperations.putShort(buffer, cursor, v); + cursor += 2; + checkCursorAfterWrite(); + } + + @Override + public void writeShortAlign(short v) { + BinaryOperations.putLong(buffer, cursor, 0L); + BinaryOperations.putShort(buffer, cursor, v); + cursor += 8; + checkCursorAfterWrite(); + } + + @Override + public void writeLong(long v) { + BinaryOperations.putLong(buffer, cursor, v); + cursor += 8; + checkCursorAfterWrite(); + } + + @Override + public void writeDouble(double v) { + BinaryOperations.putDouble(buffer, cursor, v); + cursor += 8; + checkCursorAfterWrite(); + } + + @Override + public void writeBytes(byte[] bytes) { + BinaryOperations.copyMemory(bytes, 0, buffer, cursor, bytes.length); + cursor += bytes.length; + checkCursorAfterWrite(); + } + + @Override + public void writeBytes(IBinaryObject src, long srcOffset, long length) { + BinaryOperations.copyMemory(src, srcOffset, buffer, cursor, length); + cursor += length; + checkCursorAfterWrite(); + } + + private void checkCursorAfterWrite() { + if (cursor <= 0 || cursor > buffer.size()) { + throw new GeaFlowDSLException( + "Illegal cursor: " + cursor + ", buffer length is:" + buffer.size()); + } + } + + @Override + public int getCursor() { + return cursor; + } + + @Override + public void setCursor(int cursor) { + if (cursor < 0) { + throw new GeaFlowDSLException("Illegal cursor" + cursor); + } + this.cursor = cursor; + } + + @Override + public void moveCursor(int cursor) { + setCursor(getCursor() + cursor); + } + + @Override + public void setExtendPoint(int tailPoint) { + if (tailPoint < this.extendPoint) { + throw new GeaFlowDSLException( + "Current tailPoint: " + tailPoint + " should >= " + "the pre value:" + this.extendPoint); + } + this.extendPoint = tailPoint; + if (tailPoint > buffer.size()) { + grow(tailPoint - buffer.size()); + } + } + + @Override + public int getExtendPoint() { + return extendPoint; + } + + @Override + public void moveToExtend() { + setCursor(getExtendPoint()); + } + + @Override + public void reset() { + cursor = 0; + extendPoint = 0; + } + + @Override + public void setNullAt(long offset, int index) { + BinaryLayoutHelper.set(buffer, offset, index); + // align index + this.cursor += 8; + } + + @Override + public void release() { + buffer.release(); + } + + @Override + public boolean isReleased() { + return buffer.isReleased(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/WriterBuffer.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/WriterBuffer.java index 7014a0f29..ec728db7e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/WriterBuffer.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/WriterBuffer.java @@ -20,55 +20,56 @@ package org.apache.geaflow.dsl.common.binary; import java.io.Serializable; + import org.apache.geaflow.common.binary.IBinaryObject; public interface WriterBuffer extends Serializable { - void initialize(int initSize); + void initialize(int initSize); - void grow(int size); + void grow(int size); - void growTo(int targetSize); + void growTo(int targetSize); - Object copyBuffer(); + Object copyBuffer(); - int getCapacity(); + int getCapacity(); - void writeByte(byte b); + void writeByte(byte b); - void writeInt(int v); + void writeInt(int v); - void writeIntAlign(int v); + void writeIntAlign(int v); - void writeShort(short v); + void writeShort(short v); - void writeShortAlign(short v); + void writeShortAlign(short v); - void writeLong(long v); + void writeLong(long v); - void writeDouble(double v); + void writeDouble(double v); - void writeBytes(byte[] bytes); + void writeBytes(byte[] bytes); - void writeBytes(IBinaryObject src, long srcOffset, long length); + void writeBytes(IBinaryObject src, long srcOffset, long length); - int getCursor(); + int getCursor(); - void setCursor(int cursor); + void setCursor(int cursor); - void moveCursor(int cursor); + void moveCursor(int cursor); - int getExtendPoint(); + int getExtendPoint(); - void moveToExtend(); + void moveToExtend(); - void setExtendPoint(int tailPoint); + void setExtendPoint(int tailPoint); - void reset(); + void reset(); - void setNullAt(long offset, int index); + void setNullAt(long offset, int index); - void release(); + void release(); - boolean isReleased(); + boolean isReleased(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/DefaultEdgeDecoder.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/DefaultEdgeDecoder.java index 2d4a9cd84..80c8d2584 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/DefaultEdgeDecoder.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/DefaultEdgeDecoder.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.common.binary.decoder; import java.util.List; + import org.apache.geaflow.dsl.common.binary.DecoderFactory; import org.apache.geaflow.dsl.common.data.RowEdge; import org.apache.geaflow.dsl.common.data.impl.ObjectRow; @@ -31,32 +32,32 @@ public class DefaultEdgeDecoder implements EdgeDecoder { - private final RowDecoder rowDecoder; - private final EdgeType edgeType; + private final RowDecoder rowDecoder; + private final EdgeType edgeType; - public DefaultEdgeDecoder(EdgeType edgeType) { - StructType rowType = new StructType(edgeType.getValueFields()); - this.rowDecoder = DecoderFactory.createRowDecoder(rowType); - this.edgeType = edgeType; - } + public DefaultEdgeDecoder(EdgeType edgeType) { + StructType rowType = new StructType(edgeType.getValueFields()); + this.rowDecoder = DecoderFactory.createRowDecoder(rowType); + this.edgeType = edgeType; + } - @Override - public RowEdge decode(RowEdge rowEdge) { - RowEdge decodeEdge = VertexEdgeFactory.createEdge(edgeType); - decodeEdge.setSrcId(rowEdge.getSrcId()); - decodeEdge.setTargetId(rowEdge.getTargetId()); - decodeEdge.setBinaryLabel(rowEdge.getBinaryLabel()); - decodeEdge.setDirect(rowEdge.getDirect()); - if (edgeType.getTimestamp().isPresent()) { - ((IGraphElementWithTimeField) decodeEdge).setTime(((IGraphElementWithTimeField) rowEdge).getTime()); - } - Object[] values = new Object[edgeType.getValueSize()]; - List valueFields = edgeType.getValueFields(); - for (int i = 0; i < valueFields.size(); i++) { - values[i] = rowEdge.getField(edgeType.getValueOffset() + i, - valueFields.get(i).getType()); - } - decodeEdge.setValue(rowDecoder.decode(ObjectRow.create(values))); - return decodeEdge; + @Override + public RowEdge decode(RowEdge rowEdge) { + RowEdge decodeEdge = VertexEdgeFactory.createEdge(edgeType); + decodeEdge.setSrcId(rowEdge.getSrcId()); + decodeEdge.setTargetId(rowEdge.getTargetId()); + decodeEdge.setBinaryLabel(rowEdge.getBinaryLabel()); + decodeEdge.setDirect(rowEdge.getDirect()); + if (edgeType.getTimestamp().isPresent()) { + ((IGraphElementWithTimeField) decodeEdge) + .setTime(((IGraphElementWithTimeField) rowEdge).getTime()); + } + Object[] values = new Object[edgeType.getValueSize()]; + List valueFields = edgeType.getValueFields(); + for (int i = 0; i < valueFields.size(); i++) { + values[i] = rowEdge.getField(edgeType.getValueOffset() + i, valueFields.get(i).getType()); } + decodeEdge.setValue(rowDecoder.decode(ObjectRow.create(values))); + return decodeEdge; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/DefaultPathDecoder.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/DefaultPathDecoder.java index c4707a1c3..4a0d9b64b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/DefaultPathDecoder.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/DefaultPathDecoder.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.dsl.common.binary.DecoderFactory; import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.Row; @@ -30,24 +31,24 @@ public class DefaultPathDecoder implements PathDecoder { - private final IBinaryDecoder[] binaryDecoders; + private final IBinaryDecoder[] binaryDecoders; - public DefaultPathDecoder(PathType pathType) { - this.binaryDecoders = new IBinaryDecoder[pathType.size()]; - List pathFields = pathType.getFields(); - for (int i = 0; i < pathFields.size(); i++) { - TableField field = pathFields.get(i); - binaryDecoders[i] = DecoderFactory.createDecoder(field.getType()); - } + public DefaultPathDecoder(PathType pathType) { + this.binaryDecoders = new IBinaryDecoder[pathType.size()]; + List pathFields = pathType.getFields(); + for (int i = 0; i < pathFields.size(); i++) { + TableField field = pathFields.get(i); + binaryDecoders[i] = DecoderFactory.createDecoder(field.getType()); } + } - @Override - public Path decode(Path rowPath) { - List pathNodes = rowPath.getPathNodes(); - List decodePathNodes = new ArrayList<>(pathNodes.size()); - for (int i = 0; i < pathNodes.size(); i++) { - decodePathNodes.add(binaryDecoders[i].decode(pathNodes.get(i))); - } - return new DefaultPath(decodePathNodes); + @Override + public Path decode(Path rowPath) { + List pathNodes = rowPath.getPathNodes(); + List decodePathNodes = new ArrayList<>(pathNodes.size()); + for (int i = 0; i < pathNodes.size(); i++) { + decodePathNodes.add(binaryDecoders[i].decode(pathNodes.get(i))); } + return new DefaultPath(decodePathNodes); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/DefaultRowDecoder.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/DefaultRowDecoder.java index 9834e179b..2d27db2b2 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/DefaultRowDecoder.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/DefaultRowDecoder.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.common.binary.decoder; import java.util.List; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.binary.DecoderFactory; @@ -34,69 +35,69 @@ public class DefaultRowDecoder implements RowDecoder { - private final StructType rowType; - private final IBinaryDecoder[] valueDecoders; + private final StructType rowType; + private final IBinaryDecoder[] valueDecoders; - public DefaultRowDecoder(StructType rowType) { - this.rowType = rowType; - this.valueDecoders = new IBinaryDecoder[rowType.size()]; - List fields = rowType.getFields(); - for (int i = 0; i < fields.size(); i++) { - TableField field = fields.get(i); - valueDecoders[i] = generateDecoder(field.getType()); - } + public DefaultRowDecoder(StructType rowType) { + this.rowType = rowType; + this.valueDecoders = new IBinaryDecoder[rowType.size()]; + List fields = rowType.getFields(); + for (int i = 0; i < fields.size(); i++) { + TableField field = fields.get(i); + valueDecoders[i] = generateDecoder(field.getType()); } + } - private IBinaryDecoder generateDecoder(IType type) { - if (type instanceof VertexType) { - return DecoderFactory.createVertexDecoder((VertexType) type); - } else if (type instanceof EdgeType) { - return DecoderFactory.createEdgeDecoder((EdgeType) type); - } else if (type instanceof PathType) { - return DecoderFactory.createPathDecoder((PathType) type); - } else if (type instanceof StructType) { - return DecoderFactory.createRowDecoder((StructType) type); - } - return null; + private IBinaryDecoder generateDecoder(IType type) { + if (type instanceof VertexType) { + return DecoderFactory.createVertexDecoder((VertexType) type); + } else if (type instanceof EdgeType) { + return DecoderFactory.createEdgeDecoder((EdgeType) type); + } else if (type instanceof PathType) { + return DecoderFactory.createPathDecoder((PathType) type); + } else if (type instanceof StructType) { + return DecoderFactory.createRowDecoder((StructType) type); } + return null; + } - @Override - public Row decode(Row row) { - List fields = rowType.getFields(); - Object[] values = new Object[fields.size()]; - for (int i = 0; i < fields.size(); i++) { - TableField field = fields.get(i); - if (valueDecoders[i] != null) { - values[i] = valueDecoders[i].decode((Row) row.getField(i, field.getType())); - } else { - values[i] = decode(row.getField(i, field.getType()), field.getType()); - } - } - return ObjectRow.create(values); + @Override + public Row decode(Row row) { + List fields = rowType.getFields(); + Object[] values = new Object[fields.size()]; + for (int i = 0; i < fields.size(); i++) { + TableField field = fields.get(i); + if (valueDecoders[i] != null) { + values[i] = valueDecoders[i].decode((Row) row.getField(i, field.getType())); + } else { + values[i] = decode(row.getField(i, field.getType()), field.getType()); + } } + return ObjectRow.create(values); + } - private Object decode(Object o, IType type) { - if (type instanceof ArrayType) { - if (o == null) { - return null; - } - if (o.getClass().isArray()) { - Object[] array = (Object[]) o; - Object[] decodeArray = new Object[array.length]; - IType componentType = ((ArrayType) type).getComponentType(); - IBinaryDecoder decoder = generateDecoder(componentType); - for (int i = 0; i < array.length; i++) { - if (decoder != null) { - decodeArray[i] = decoder.decode((Row) array[i]); - } else { - decodeArray[i] = decode(array[i], componentType); - } - } - return decodeArray; - } - } else if (o instanceof BinaryString) { - return o.toString(); + private Object decode(Object o, IType type) { + if (type instanceof ArrayType) { + if (o == null) { + return null; + } + if (o.getClass().isArray()) { + Object[] array = (Object[]) o; + Object[] decodeArray = new Object[array.length]; + IType componentType = ((ArrayType) type).getComponentType(); + IBinaryDecoder decoder = generateDecoder(componentType); + for (int i = 0; i < array.length; i++) { + if (decoder != null) { + decodeArray[i] = decoder.decode((Row) array[i]); + } else { + decodeArray[i] = decode(array[i], componentType); + } } - return o; + return decodeArray; + } + } else if (o instanceof BinaryString) { + return o.toString(); } + return o; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/DefaultVertexDecoder.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/DefaultVertexDecoder.java index 04a76b3c9..38a937de5 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/DefaultVertexDecoder.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/DefaultVertexDecoder.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.common.binary.decoder; import java.util.List; + import org.apache.geaflow.dsl.common.binary.DecoderFactory; import org.apache.geaflow.dsl.common.data.RowVertex; import org.apache.geaflow.dsl.common.data.impl.ObjectRow; @@ -30,27 +31,26 @@ public class DefaultVertexDecoder implements VertexDecoder { - private final RowDecoder rowDecoder; - private final VertexType vertexType; + private final RowDecoder rowDecoder; + private final VertexType vertexType; - public DefaultVertexDecoder(VertexType vertexType) { - StructType rowType = new StructType(vertexType.getValueFields()); - this.rowDecoder = DecoderFactory.createRowDecoder(rowType); - this.vertexType = vertexType; - } + public DefaultVertexDecoder(VertexType vertexType) { + StructType rowType = new StructType(vertexType.getValueFields()); + this.rowDecoder = DecoderFactory.createRowDecoder(rowType); + this.vertexType = vertexType; + } - @Override - public RowVertex decode(RowVertex rowVertex) { - RowVertex decodeVertex = VertexEdgeFactory.createVertex(vertexType); - decodeVertex.setId(rowVertex.getId()); - decodeVertex.setBinaryLabel(rowVertex.getBinaryLabel()); - Object[] values = new Object[vertexType.getValueSize()]; - List valueFields = vertexType.getValueFields(); - for (int i = 0; i < valueFields.size(); i++) { - values[i] = rowVertex.getField(vertexType.getValueOffset() + i, - valueFields.get(i).getType()); - } - decodeVertex.setValue(rowDecoder.decode(ObjectRow.create(values))); - return decodeVertex; + @Override + public RowVertex decode(RowVertex rowVertex) { + RowVertex decodeVertex = VertexEdgeFactory.createVertex(vertexType); + decodeVertex.setId(rowVertex.getId()); + decodeVertex.setBinaryLabel(rowVertex.getBinaryLabel()); + Object[] values = new Object[vertexType.getValueSize()]; + List valueFields = vertexType.getValueFields(); + for (int i = 0; i < valueFields.size(); i++) { + values[i] = rowVertex.getField(vertexType.getValueOffset() + i, valueFields.get(i).getType()); } + decodeVertex.setValue(rowDecoder.decode(ObjectRow.create(values))); + return decodeVertex; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/EdgeDecoder.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/EdgeDecoder.java index f86334eaf..0f0adcfdf 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/EdgeDecoder.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/EdgeDecoder.java @@ -21,6 +21,4 @@ import org.apache.geaflow.dsl.common.data.RowEdge; -public interface EdgeDecoder extends IBinaryDecoder { - -} +public interface EdgeDecoder extends IBinaryDecoder {} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/IBinaryDecoder.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/IBinaryDecoder.java index 914fc3801..11eb568ab 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/IBinaryDecoder.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/IBinaryDecoder.java @@ -20,9 +20,10 @@ package org.apache.geaflow.dsl.common.binary.decoder; import java.io.Serializable; + import org.apache.geaflow.dsl.common.data.Row; public interface IBinaryDecoder extends Serializable { - OUT decode(IN row); + OUT decode(IN row); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/PathDecoder.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/PathDecoder.java index 9b258169a..a1fc135bb 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/PathDecoder.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/PathDecoder.java @@ -21,6 +21,4 @@ import org.apache.geaflow.dsl.common.data.Path; -public interface PathDecoder extends IBinaryDecoder { - -} +public interface PathDecoder extends IBinaryDecoder {} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/RowDecoder.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/RowDecoder.java index 06ed155a5..c849e92b7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/RowDecoder.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/RowDecoder.java @@ -21,6 +21,4 @@ import org.apache.geaflow.dsl.common.data.Row; -public interface RowDecoder extends IBinaryDecoder { - -} +public interface RowDecoder extends IBinaryDecoder {} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/VertexDecoder.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/VertexDecoder.java index d4cc03894..b4ffdaf84 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/VertexDecoder.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/decoder/VertexDecoder.java @@ -21,6 +21,4 @@ import org.apache.geaflow.dsl.common.data.RowVertex; -public interface VertexDecoder extends IBinaryDecoder { - -} +public interface VertexDecoder extends IBinaryDecoder {} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/DefaultEdgeEncoder.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/DefaultEdgeEncoder.java index f5ef34516..861dcade1 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/DefaultEdgeEncoder.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/DefaultEdgeEncoder.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.common.binary.encoder; import java.util.List; + import org.apache.geaflow.dsl.common.binary.EncoderFactory; import org.apache.geaflow.dsl.common.data.RowEdge; import org.apache.geaflow.dsl.common.data.impl.ObjectRow; @@ -32,32 +33,32 @@ public class DefaultEdgeEncoder implements EdgeEncoder { - private final RowEncoder rowEncoder; - private final EdgeType edgeType; + private final RowEncoder rowEncoder; + private final EdgeType edgeType; - public DefaultEdgeEncoder(EdgeType edgeType) { - StructType rowType = new StructType(edgeType.getValueFields()); - this.rowEncoder = EncoderFactory.createRowEncoder(rowType); - this.edgeType = edgeType; - } + public DefaultEdgeEncoder(EdgeType edgeType) { + StructType rowType = new StructType(edgeType.getValueFields()); + this.rowEncoder = EncoderFactory.createRowEncoder(rowType); + this.edgeType = edgeType; + } - @Override - public RowEdge encode(RowEdge rowEdge) { - RowEdge binaryEdge = VertexEdgeFactory.createEdge(edgeType); - binaryEdge.setSrcId(BinaryUtil.toBinaryForString(rowEdge.getSrcId())); - binaryEdge.setTargetId(BinaryUtil.toBinaryForString(rowEdge.getTargetId())); - if (edgeType.getTimestamp().isPresent()) { - ((IGraphElementWithTimeField) binaryEdge).setTime(((IGraphElementWithTimeField) rowEdge).getTime()); - } - binaryEdge.setDirect(rowEdge.getDirect()); - binaryEdge.setBinaryLabel(rowEdge.getBinaryLabel()); - Object[] values = new Object[edgeType.getValueSize()]; - List valueFields = edgeType.getValueFields(); - for (int i = 0; i < valueFields.size(); i++) { - values[i] = rowEdge.getField(edgeType.getValueOffset() + i, - valueFields.get(i).getType()); - } - binaryEdge.setValue(rowEncoder.encode(ObjectRow.create(values))); - return binaryEdge; + @Override + public RowEdge encode(RowEdge rowEdge) { + RowEdge binaryEdge = VertexEdgeFactory.createEdge(edgeType); + binaryEdge.setSrcId(BinaryUtil.toBinaryForString(rowEdge.getSrcId())); + binaryEdge.setTargetId(BinaryUtil.toBinaryForString(rowEdge.getTargetId())); + if (edgeType.getTimestamp().isPresent()) { + ((IGraphElementWithTimeField) binaryEdge) + .setTime(((IGraphElementWithTimeField) rowEdge).getTime()); + } + binaryEdge.setDirect(rowEdge.getDirect()); + binaryEdge.setBinaryLabel(rowEdge.getBinaryLabel()); + Object[] values = new Object[edgeType.getValueSize()]; + List valueFields = edgeType.getValueFields(); + for (int i = 0; i < valueFields.size(); i++) { + values[i] = rowEdge.getField(edgeType.getValueOffset() + i, valueFields.get(i).getType()); } + binaryEdge.setValue(rowEncoder.encode(ObjectRow.create(values))); + return binaryEdge; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/DefaultRowEncoder.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/DefaultRowEncoder.java index 082ddf421..bf29e8a54 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/DefaultRowEncoder.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/DefaultRowEncoder.java @@ -25,6 +25,7 @@ import static org.apache.geaflow.dsl.common.binary.BinaryLayoutHelper.zeroBytes; import java.util.List; + import org.apache.geaflow.dsl.common.binary.BinaryLayoutHelper; import org.apache.geaflow.dsl.common.binary.FieldWriterFactory; import org.apache.geaflow.dsl.common.binary.FieldWriterFactory.PropertyFieldWriter; @@ -39,46 +40,44 @@ public class DefaultRowEncoder implements RowEncoder { - private final WriterBuffer writerBuffer; - private final StructType rowType; + private final WriterBuffer writerBuffer; + private final StructType rowType; - public DefaultRowEncoder(StructType rowType) { - this.writerBuffer = new HeapWriterBuffer(); - this.rowType = rowType; - writerBuffer.initialize(BinaryLayoutHelper.getInitBufferSize(rowType.size())); - } + public DefaultRowEncoder(StructType rowType) { + this.writerBuffer = new HeapWriterBuffer(); + this.rowType = rowType; + writerBuffer.initialize(BinaryLayoutHelper.getInitBufferSize(rowType.size())); + } - @Override - public BinaryRow encode(Row row) { - if (row instanceof BinaryRow) { - return (BinaryRow) row; - } - writerBuffer.reset(); - // write fields num - writerBuffer.writeInt(rowType.size()); - writerBuffer.setExtendPoint(getExtendPoint(rowType.size())); - // write null bit set - byte[] nullBitSet = new byte[getBitSetBytes(rowType.size())]; - if (nullBitSet.length > 0) { - zeroBytes(nullBitSet); - writerBuffer.writeBytes(nullBitSet); - } - // write all values - List fields = rowType.getFields(); - for (int i = 0; i < fields.size(); i++) { - TableField field = fields.get(i); - Object value = row.getField(i, field.getType()); - Object castValue = TypeCastUtil.cast(value, field.getType()); - PropertyFieldWriter writer = FieldWriterFactory - .getPropertyFieldWriter(field.getType()); - try { - writer.write(writerBuffer, NULL_BIT_OFFSET, i, castValue); - } catch (Exception e) { - throw new GeaFlowDSLException("Fail to write: " + field + ", value is: " + value, - e); - } - } - byte[] rowBytes = (byte[]) writerBuffer.copyBuffer(); - return BinaryRow.of(rowBytes); + @Override + public BinaryRow encode(Row row) { + if (row instanceof BinaryRow) { + return (BinaryRow) row; + } + writerBuffer.reset(); + // write fields num + writerBuffer.writeInt(rowType.size()); + writerBuffer.setExtendPoint(getExtendPoint(rowType.size())); + // write null bit set + byte[] nullBitSet = new byte[getBitSetBytes(rowType.size())]; + if (nullBitSet.length > 0) { + zeroBytes(nullBitSet); + writerBuffer.writeBytes(nullBitSet); + } + // write all values + List fields = rowType.getFields(); + for (int i = 0; i < fields.size(); i++) { + TableField field = fields.get(i); + Object value = row.getField(i, field.getType()); + Object castValue = TypeCastUtil.cast(value, field.getType()); + PropertyFieldWriter writer = FieldWriterFactory.getPropertyFieldWriter(field.getType()); + try { + writer.write(writerBuffer, NULL_BIT_OFFSET, i, castValue); + } catch (Exception e) { + throw new GeaFlowDSLException("Fail to write: " + field + ", value is: " + value, e); + } } + byte[] rowBytes = (byte[]) writerBuffer.copyBuffer(); + return BinaryRow.of(rowBytes); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/DefaultVertexEncoder.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/DefaultVertexEncoder.java index f72ff1de4..502c62889 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/DefaultVertexEncoder.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/DefaultVertexEncoder.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.common.binary.encoder; import java.util.List; + import org.apache.geaflow.dsl.common.binary.EncoderFactory; import org.apache.geaflow.dsl.common.data.RowVertex; import org.apache.geaflow.dsl.common.data.impl.ObjectRow; @@ -31,27 +32,26 @@ public class DefaultVertexEncoder implements VertexEncoder { - private final RowEncoder rowEncoder; - private final VertexType vertexType; + private final RowEncoder rowEncoder; + private final VertexType vertexType; - public DefaultVertexEncoder(VertexType vertexType) { - StructType rowType = new StructType(vertexType.getValueFields()); - this.rowEncoder = EncoderFactory.createRowEncoder(rowType); - this.vertexType = vertexType; - } + public DefaultVertexEncoder(VertexType vertexType) { + StructType rowType = new StructType(vertexType.getValueFields()); + this.rowEncoder = EncoderFactory.createRowEncoder(rowType); + this.vertexType = vertexType; + } - @Override - public RowVertex encode(RowVertex rowVertex) { - RowVertex binaryVertex = VertexEdgeFactory.createVertex(vertexType); - binaryVertex.setId(BinaryUtil.toBinaryForString(rowVertex.getId())); - binaryVertex.setBinaryLabel(rowVertex.getBinaryLabel()); - Object[] values = new Object[vertexType.getValueSize()]; - List valueFields = vertexType.getValueFields(); - for (int i = 0; i < valueFields.size(); i++) { - values[i] = rowVertex.getField(vertexType.getValueOffset() + i, - valueFields.get(i).getType()); - } - binaryVertex.setValue(rowEncoder.encode(ObjectRow.create(values))); - return binaryVertex; + @Override + public RowVertex encode(RowVertex rowVertex) { + RowVertex binaryVertex = VertexEdgeFactory.createVertex(vertexType); + binaryVertex.setId(BinaryUtil.toBinaryForString(rowVertex.getId())); + binaryVertex.setBinaryLabel(rowVertex.getBinaryLabel()); + Object[] values = new Object[vertexType.getValueSize()]; + List valueFields = vertexType.getValueFields(); + for (int i = 0; i < valueFields.size(); i++) { + values[i] = rowVertex.getField(vertexType.getValueOffset() + i, valueFields.get(i).getType()); } + binaryVertex.setValue(rowEncoder.encode(ObjectRow.create(values))); + return binaryVertex; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/EdgeEncoder.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/EdgeEncoder.java index 77c60afe7..890258e21 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/EdgeEncoder.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/EdgeEncoder.java @@ -21,6 +21,4 @@ import org.apache.geaflow.dsl.common.data.RowEdge; -public interface EdgeEncoder extends IBinaryEncoder { - -} +public interface EdgeEncoder extends IBinaryEncoder {} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/IBinaryEncoder.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/IBinaryEncoder.java index 3857c24e5..2e1403a74 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/IBinaryEncoder.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/IBinaryEncoder.java @@ -20,9 +20,10 @@ package org.apache.geaflow.dsl.common.binary.encoder; import java.io.Serializable; + import org.apache.geaflow.dsl.common.data.Row; public interface IBinaryEncoder extends Serializable { - OUT encode(IN row); + OUT encode(IN row); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/RowEncoder.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/RowEncoder.java index e995cb43a..3842d8daf 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/RowEncoder.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/RowEncoder.java @@ -22,6 +22,4 @@ import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.impl.BinaryRow; -public interface RowEncoder extends IBinaryEncoder { - -} +public interface RowEncoder extends IBinaryEncoder {} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/VertexEncoder.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/VertexEncoder.java index 6cfb6058b..8b8e8d130 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/VertexEncoder.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/binary/encoder/VertexEncoder.java @@ -21,6 +21,4 @@ import org.apache.geaflow.dsl.common.data.RowVertex; -public interface VertexEncoder extends IBinaryEncoder { - -} +public interface VertexEncoder extends IBinaryEncoder {} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/CompileContext.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/CompileContext.java index bb8857c7c..632d7ad03 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/CompileContext.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/CompileContext.java @@ -24,32 +24,32 @@ public class CompileContext { - private Map config; + private Map config; - private Map parallelisms; + private Map parallelisms; - public CompileContext() { - this(new HashMap<>(), new HashMap<>()); - } + public CompileContext() { + this(new HashMap<>(), new HashMap<>()); + } - public CompileContext(Map config, Map parallelisms) { - this.config = config; - this.parallelisms = parallelisms; - } + public CompileContext(Map config, Map parallelisms) { + this.config = config; + this.parallelisms = parallelisms; + } - public Map getConfig() { - return config; - } + public Map getConfig() { + return config; + } - public Map getParallelisms() { - return parallelisms; - } + public Map getParallelisms() { + return parallelisms; + } - public void setConfig(Map config) { - this.config = config; - } + public void setConfig(Map config) { + this.config = config; + } - public void setParallelisms(Map parallelisms) { - this.parallelisms = parallelisms; - } + public void setParallelisms(Map parallelisms) { + this.parallelisms = parallelisms; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/CompileResult.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/CompileResult.java index afd8cd741..ad2c054aa 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/CompileResult.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/CompileResult.java @@ -21,72 +21,71 @@ import java.io.Serializable; import java.util.Set; + import org.apache.calcite.rel.type.RelDataType; import org.apache.geaflow.common.visualization.console.JsonPlan; public class CompileResult implements Serializable { - private JsonPlan physicPlan; - - private Set sourceTables; + private JsonPlan physicPlan; - private Set targetTables; + private Set sourceTables; - private Set sourceGraphs; + private Set targetTables; - private Set targetGraphs; + private Set sourceGraphs; - private RelDataType currentResultType; + private Set targetGraphs; - public CompileResult() { + private RelDataType currentResultType; - } + public CompileResult() {} - public RelDataType getCurrentResultType() { - return currentResultType; - } + public RelDataType getCurrentResultType() { + return currentResultType; + } - public void setCurrentResultType(RelDataType currentResultType) { - this.currentResultType = currentResultType; - } + public void setCurrentResultType(RelDataType currentResultType) { + this.currentResultType = currentResultType; + } - public JsonPlan getPhysicPlan() { - return physicPlan; - } + public JsonPlan getPhysicPlan() { + return physicPlan; + } - public void setPhysicPlan(JsonPlan physicPlan) { - this.physicPlan = physicPlan; - } + public void setPhysicPlan(JsonPlan physicPlan) { + this.physicPlan = physicPlan; + } - public Set getSourceTables() { - return sourceTables; - } + public Set getSourceTables() { + return sourceTables; + } - public void setSourceTables(Set sourceTables) { - this.sourceTables = sourceTables; - } + public void setSourceTables(Set sourceTables) { + this.sourceTables = sourceTables; + } - public Set getTargetTables() { - return targetTables; - } + public Set getTargetTables() { + return targetTables; + } - public void setTargetTables(Set targetTables) { - this.targetTables = targetTables; - } + public void setTargetTables(Set targetTables) { + this.targetTables = targetTables; + } - public Set getSourceGraphs() { - return sourceGraphs; - } + public Set getSourceGraphs() { + return sourceGraphs; + } - public void setSourceGraphs(Set sourceGraphs) { - this.sourceGraphs = sourceGraphs; - } + public void setSourceGraphs(Set sourceGraphs) { + this.sourceGraphs = sourceGraphs; + } - public Set getTargetGraphs() { - return targetGraphs; - } + public Set getTargetGraphs() { + return targetGraphs; + } - public void setTargetGraphs(Set targetGraphs) { - this.targetGraphs = targetGraphs; - } + public void setTargetGraphs(Set targetGraphs) { + this.targetGraphs = targetGraphs; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/FunctionInfo.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/FunctionInfo.java index 71e7af95a..62b74af1e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/FunctionInfo.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/FunctionInfo.java @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -25,43 +24,43 @@ public class FunctionInfo implements Serializable { - private final String instanceName; + private final String instanceName; - private final String functionName; + private final String functionName; - public FunctionInfo(String instanceName, String functionName) { - this.instanceName = instanceName; - this.functionName = functionName; - } + public FunctionInfo(String instanceName, String functionName) { + this.instanceName = instanceName; + this.functionName = functionName; + } - public String getInstanceName() { - return instanceName; - } + public String getInstanceName() { + return instanceName; + } - public String getFunctionName() { - return functionName; - } + public String getFunctionName() { + return functionName; + } - @Override - public String toString() { - return instanceName + "." + functionName; - } + @Override + public String toString() { + return instanceName + "." + functionName; + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof FunctionInfo)) { - return false; - } - FunctionInfo that = (FunctionInfo) o; - return Objects.equals(instanceName, that.instanceName) && Objects.equals(functionName, - that.functionName); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(instanceName, functionName); + if (!(o instanceof FunctionInfo)) { + return false; } + FunctionInfo that = (FunctionInfo) o; + return Objects.equals(instanceName, that.instanceName) + && Objects.equals(functionName, that.functionName); + } + + @Override + public int hashCode() { + return Objects.hash(instanceName, functionName); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/GraphInfo.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/GraphInfo.java index add9afb44..9dc85e7d1 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/GraphInfo.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/GraphInfo.java @@ -23,20 +23,20 @@ public class GraphInfo implements Serializable { - private final String instanceName; + private final String instanceName; - private final String graphName; + private final String graphName; - public GraphInfo(String instanceName, String graphName) { - this.instanceName = instanceName; - this.graphName = graphName; - } + public GraphInfo(String instanceName, String graphName) { + this.instanceName = instanceName; + this.graphName = graphName; + } - public String getInstanceName() { - return instanceName; - } + public String getInstanceName() { + return instanceName; + } - public String getGraphName() { - return graphName; - } + public String getGraphName() { + return graphName; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/QueryCompiler.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/QueryCompiler.java index f1a205eef..554c27dc0 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/QueryCompiler.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/QueryCompiler.java @@ -23,22 +23,17 @@ public interface QueryCompiler { - /** - * Compile the dsl script to generate the {@link CompileResult}. - */ - CompileResult compile(String script, CompileContext context); + /** Compile the dsl script to generate the {@link CompileResult}. */ + CompileResult compile(String script, CompileContext context); - /** - * Get the UnResolved functions in the dsl script. - */ - Set getUnResolvedFunctions(String script, CompileContext context); + /** Get the UnResolved functions in the dsl script. */ + Set getUnResolvedFunctions(String script, CompileContext context); - Set getDeclaredTablePlugins(String script, CompileContext context); + Set getDeclaredTablePlugins(String script, CompileContext context); - Set getEnginePlugins(); + Set getEnginePlugins(); - Set getUnResolvedTables(String script, CompileContext context); - - String formatOlapResult(String script, Object queryResult, CompileContext context); + Set getUnResolvedTables(String script, CompileContext context); + String formatOlapResult(String script, Object queryResult, CompileContext context); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/TableInfo.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/TableInfo.java index b5c140ff3..72d061001 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/TableInfo.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/compile/TableInfo.java @@ -24,37 +24,38 @@ public class TableInfo implements Serializable { - private final String instanceName; + private final String instanceName; - private final String tableName; + private final String tableName; - public TableInfo(String instanceName, String tableName) { - this.instanceName = instanceName; - this.tableName = tableName; - } + public TableInfo(String instanceName, String tableName) { + this.instanceName = instanceName; + this.tableName = tableName; + } - public String getInstanceName() { - return instanceName; - } + public String getInstanceName() { + return instanceName; + } - public String getTableName() { - return tableName; - } + public String getTableName() { + return tableName; + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof TableInfo)) { - return false; - } - TableInfo that = (TableInfo) o; - return Objects.equals(tableName, that.tableName) && Objects.equals(instanceName, that.instanceName); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(tableName, instanceName); + if (!(o instanceof TableInfo)) { + return false; } + TableInfo that = (TableInfo) o; + return Objects.equals(tableName, that.tableName) + && Objects.equals(instanceName, that.instanceName); + } + + @Override + public int hashCode() { + return Objects.hash(tableName, instanceName); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/Accumulator.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/Accumulator.java index 6ca8a9805..7339f1e80 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/Accumulator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/Accumulator.java @@ -21,6 +21,4 @@ import java.io.Serializable; -public interface Accumulator extends Serializable { - -} +public interface Accumulator extends Serializable {} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/IGraphElementWithBinaryLabel.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/IGraphElementWithBinaryLabel.java index e59368a7f..909ee4903 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/IGraphElementWithBinaryLabel.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/IGraphElementWithBinaryLabel.java @@ -23,18 +23,17 @@ public interface IGraphElementWithBinaryLabel { - /** - * Get binary label. - * - * @return label - */ - BinaryString getBinaryLabel(); - - /** - * Set binary label. - * - * @param label label - */ - void setBinaryLabel(BinaryString label); + /** + * Get binary label. + * + * @return label + */ + BinaryString getBinaryLabel(); + /** + * Set binary label. + * + * @param label label + */ + void setBinaryLabel(BinaryString label); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/ParameterizedRow.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/ParameterizedRow.java index 33713fcef..829706e6f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/ParameterizedRow.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/ParameterizedRow.java @@ -21,9 +21,9 @@ public interface ParameterizedRow extends Row { - Row getParameter(); + Row getParameter(); - Row getSystemVariables(); + Row getSystemVariables(); - Object getRequestId(); + Object getRequestId(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/Path.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/Path.java index 32078ac17..68fb04928 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/Path.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/Path.java @@ -21,27 +21,28 @@ import java.util.Collection; import java.util.List; + import org.apache.geaflow.common.type.IType; public interface Path extends Row { - Row getField(int i, IType type); + Row getField(int i, IType type); - List getPathNodes(); + List getPathNodes(); - void addNode(Row node); + void addNode(Row node); - void remove(int index); + void remove(int index); - Path copy(); + Path copy(); - int size(); + int size(); - Path subPath(Collection indices); + Path subPath(Collection indices); - Path subPath(int[] indices); + Path subPath(int[] indices); - long getId(); + long getId(); - void setId(long id); + void setId(long id); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/Row.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/Row.java index 162df9a7a..638ba413f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/Row.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/Row.java @@ -23,21 +23,22 @@ public interface Row extends StepRecord { - Object getField(int i, IType type); + Object getField(int i, IType type); - Row EMPTY = (i, type) -> { + Row EMPTY = + (i, type) -> { throw new IllegalArgumentException("Cannot getField from empty row"); - }; + }; - default StepRecordType getType() { - return StepRecordType.ROW; - } + default StepRecordType getType() { + return StepRecordType.ROW; + } - default Object[] getFields(IType[] types) { - Object[] fields = new Object[types.length]; - for (int i = 0; i < fields.length; i++) { - fields[i] = getField(i, types[i]); - } - return fields; + default Object[] getFields(IType[] types) { + Object[] fields = new Object[types.length]; + for (int i = 0; i < fields.length; i++) { + fields[i] = getField(i, types[i]); } + return fields; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/RowEdge.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/RowEdge.java index 06c9deb18..35d89bb1d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/RowEdge.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/RowEdge.java @@ -23,12 +23,12 @@ import org.apache.geaflow.model.graph.edge.EdgeDirection; import org.apache.geaflow.model.graph.edge.IEdge; -public interface RowEdge extends IEdge, IGraphElementWithLabelField, - IGraphElementWithBinaryLabel, Row { +public interface RowEdge + extends IEdge, IGraphElementWithLabelField, IGraphElementWithBinaryLabel, Row { - void setValue(Row value); + void setValue(Row value); - RowEdge withDirection(EdgeDirection direction); + RowEdge withDirection(EdgeDirection direction); - RowEdge identityReverse(); + RowEdge identityReverse(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/RowKey.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/RowKey.java index 0bc309ad6..72f88c96d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/RowKey.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/RowKey.java @@ -21,5 +21,5 @@ public interface RowKey extends Row, VirtualId { - Object[] getKeys(); + Object[] getKeys(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/RowKeyWithRequestId.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/RowKeyWithRequestId.java index bcea3d30e..ced8d32f7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/RowKeyWithRequestId.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/RowKeyWithRequestId.java @@ -21,5 +21,5 @@ public interface RowKeyWithRequestId extends RowKey { - Object getRequestId(); + Object getRequestId(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/RowVertex.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/RowVertex.java index 24828f752..a34721a13 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/RowVertex.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/RowVertex.java @@ -22,8 +22,8 @@ import org.apache.geaflow.model.graph.IGraphElementWithLabelField; import org.apache.geaflow.model.graph.vertex.IVertex; -public interface RowVertex extends IVertex, IGraphElementWithLabelField, - IGraphElementWithBinaryLabel, Row { +public interface RowVertex + extends IVertex, IGraphElementWithLabelField, IGraphElementWithBinaryLabel, Row { - void setValue(Row value); + void setValue(Row value); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/StepRecord.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/StepRecord.java index 65c1fea87..58845e0bd 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/StepRecord.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/StepRecord.java @@ -23,38 +23,26 @@ public interface StepRecord extends Serializable { - StepRecordType getType(); - - enum StepRecordType { - - /** - * Represent a vertex. - */ - VERTEX, - - /** - * Represent an edge group of the vertex neighbor edges. - */ - EDGE_GROUP, - - /** - * Represent the end of data stream. - */ - EOD, - - /** - * Represent the result value for sub-query. - */ - SINGLE_VALUE, - - /** - * Represent a row. - */ - ROW, - - /** - * Represent a record with key. - */ - KEY_RECORD - } + StepRecordType getType(); + + enum StepRecordType { + + /** Represent a vertex. */ + VERTEX, + + /** Represent an edge group of the vertex neighbor edges. */ + EDGE_GROUP, + + /** Represent the end of data stream. */ + EOD, + + /** Represent the result value for sub-query. */ + SINGLE_VALUE, + + /** Represent a row. */ + ROW, + + /** Represent a record with key. */ + KEY_RECORD + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/VirtualId.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/VirtualId.java index 6ae3ca089..050703bdb 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/VirtualId.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/VirtualId.java @@ -21,5 +21,4 @@ import java.io.Serializable; -public interface VirtualId extends Serializable { -} +public interface VirtualId extends Serializable {} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/BinaryRow.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/BinaryRow.java index 5b3fc43e7..7b0ec8081 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/BinaryRow.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/BinaryRow.java @@ -25,11 +25,8 @@ import static org.apache.geaflow.dsl.common.binary.BinaryLayoutHelper.getFieldsNum; import static org.apache.geaflow.dsl.common.binary.BinaryLayoutHelper.isSet; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Objects; + import org.apache.geaflow.common.binary.HeapBinaryObject; import org.apache.geaflow.common.binary.IBinaryObject; import org.apache.geaflow.common.type.IType; @@ -37,71 +34,73 @@ import org.apache.geaflow.dsl.common.binary.FieldReaderFactory.PropertyFieldReader; import org.apache.geaflow.dsl.common.data.Row; -public class BinaryRow implements Row, KryoSerializable { - - private IBinaryObject binaryObject; - - private BinaryRow() { - - } - - private BinaryRow(byte[] bytes) { - this.binaryObject = HeapBinaryObject.of(bytes); - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; - public static BinaryRow of(byte[] bytes) { - return new BinaryRow(bytes); - } +public class BinaryRow implements Row, KryoSerializable { - @Override - public Object getField(int i, IType type) { - if (isNullValue(i)) { - return null; - } - PropertyFieldReader reader = FieldReaderFactory.getPropertyFieldReader(type); - int fieldsNum = getFieldsNum(binaryObject); - long offset = getFieldOffset(getBitSetBytes(fieldsNum), i); - return reader.read(binaryObject, offset); - } + private IBinaryObject binaryObject; - @Override - public String toString() { - return "BinaryRow{" + "binaryObject=" + binaryObject + '}'; - } + private BinaryRow() {} - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - BinaryRow binaryRow = (BinaryRow) o; - return Objects.equals(binaryObject, binaryRow.binaryObject); - } + private BinaryRow(byte[] bytes) { + this.binaryObject = HeapBinaryObject.of(bytes); + } - @Override - public int hashCode() { - return Objects.hash(binaryObject); - } + public static BinaryRow of(byte[] bytes) { + return new BinaryRow(bytes); + } - private boolean isNullValue(int index) { - return isSet(binaryObject, NULL_BIT_OFFSET, index); + @Override + public Object getField(int i, IType type) { + if (isNullValue(i)) { + return null; } - - @Override - public void write(Kryo kryo, Output output) { - byte[] bytes = this.binaryObject.toBytes(); - output.writeInt(bytes.length); - output.writeBytes(bytes); + PropertyFieldReader reader = FieldReaderFactory.getPropertyFieldReader(type); + int fieldsNum = getFieldsNum(binaryObject); + long offset = getFieldOffset(getBitSetBytes(fieldsNum), i); + return reader.read(binaryObject, offset); + } + + @Override + public String toString() { + return "BinaryRow{" + "binaryObject=" + binaryObject + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public void read(Kryo kryo, Input input) { - int length = input.readInt(); - byte[] bytes = input.readBytes(length); - this.binaryObject = HeapBinaryObject.of(bytes); + if (o == null || getClass() != o.getClass()) { + return false; } - + BinaryRow binaryRow = (BinaryRow) o; + return Objects.equals(binaryObject, binaryRow.binaryObject); + } + + @Override + public int hashCode() { + return Objects.hash(binaryObject); + } + + private boolean isNullValue(int index) { + return isSet(binaryObject, NULL_BIT_OFFSET, index); + } + + @Override + public void write(Kryo kryo, Output output) { + byte[] bytes = this.binaryObject.toBytes(); + output.writeInt(bytes.length); + output.writeBytes(bytes); + } + + @Override + public void read(Kryo kryo, Input input) { + int length = input.readInt(); + byte[] bytes = input.readBytes(length); + this.binaryObject = HeapBinaryObject.of(bytes); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/DefaultParameterizedPath.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/DefaultParameterizedPath.java index 5ab0f9757..a75aef471 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/DefaultParameterizedPath.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/DefaultParameterizedPath.java @@ -19,133 +19,136 @@ package org.apache.geaflow.dsl.common.data.impl; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.Serializer; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Collection; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.Row; -public class DefaultParameterizedPath implements ParameterizedPath { - - private final Path basePath; - - private final Object requestId; - - private final Row parameterRow; - - private final Row systemVariableRow; - - public DefaultParameterizedPath(Path basePath, Object requestId, Row parameterRow, - Row systemVariableRow) { - this.basePath = basePath; - this.requestId = requestId; - this.parameterRow = parameterRow; - this.systemVariableRow = systemVariableRow; - } - - public DefaultParameterizedPath(Path basePath, Object requestId, Row parameterRow) { - this(basePath, requestId, parameterRow, null); - } - - @Override - public Row getParameter() { - return parameterRow; - } - - @Override - public Row getSystemVariables() { - return systemVariableRow; - } - - @Override - public Object getRequestId() { - return requestId; - } - - @Override - public Row getField(int i, IType type) { - return basePath.getField(i, type); - } - - @Override - public List getPathNodes() { - return basePath.getPathNodes(); - } - - @Override - public void addNode(Row node) { - basePath.addNode(node); - } - - @Override - public void remove(int index) { - basePath.remove(index); - } - - @Override - public Path copy() { - return new DefaultParameterizedPath(basePath.copy(), requestId, - parameterRow, systemVariableRow); - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; - @Override - public int size() { - return basePath.size(); - } +public class DefaultParameterizedPath implements ParameterizedPath { - @Override - public Path subPath(Collection indices) { - return new DefaultParameterizedPath(basePath.subPath(indices), requestId, - parameterRow, systemVariableRow); - } + private final Path basePath; + + private final Object requestId; + + private final Row parameterRow; + + private final Row systemVariableRow; + + public DefaultParameterizedPath( + Path basePath, Object requestId, Row parameterRow, Row systemVariableRow) { + this.basePath = basePath; + this.requestId = requestId; + this.parameterRow = parameterRow; + this.systemVariableRow = systemVariableRow; + } + + public DefaultParameterizedPath(Path basePath, Object requestId, Row parameterRow) { + this(basePath, requestId, parameterRow, null); + } + + @Override + public Row getParameter() { + return parameterRow; + } + + @Override + public Row getSystemVariables() { + return systemVariableRow; + } + + @Override + public Object getRequestId() { + return requestId; + } + + @Override + public Row getField(int i, IType type) { + return basePath.getField(i, type); + } + + @Override + public List getPathNodes() { + return basePath.getPathNodes(); + } + + @Override + public void addNode(Row node) { + basePath.addNode(node); + } + + @Override + public void remove(int index) { + basePath.remove(index); + } + + @Override + public Path copy() { + return new DefaultParameterizedPath( + basePath.copy(), requestId, parameterRow, systemVariableRow); + } + + @Override + public int size() { + return basePath.size(); + } + + @Override + public Path subPath(Collection indices) { + return new DefaultParameterizedPath( + basePath.subPath(indices), requestId, parameterRow, systemVariableRow); + } + + @Override + public Path subPath(int[] indices) { + return new DefaultParameterizedPath( + basePath.subPath(indices), requestId, parameterRow, systemVariableRow); + } + + @Override + public long getId() { + return basePath.getId(); + } + + @Override + public void setId(long id) { + basePath.setId(id); + } + + public static class DefaultParameterizedPathSerializer + extends Serializer { @Override - public Path subPath(int[] indices) { - return new DefaultParameterizedPath(basePath.subPath(indices), requestId, - parameterRow, systemVariableRow); + public void write(Kryo kryo, Output output, DefaultParameterizedPath object) { + kryo.writeClassAndObject(output, object.basePath); + kryo.writeClassAndObject(output, object.getRequestId()); + kryo.writeClassAndObject(output, object.getParameter()); + kryo.writeClassAndObject(output, object.getSystemVariables()); } @Override - public long getId() { - return basePath.getId(); + public DefaultParameterizedPath read( + Kryo kryo, Input input, Class aClass) { + Path basePath = (Path) kryo.readClassAndObject(input); + Object requestId = kryo.readClassAndObject(input); + Row parameterRow = (Row) kryo.readClassAndObject(input); + Row systemVariableRow = (Row) kryo.readClassAndObject(input); + return new DefaultParameterizedPath(basePath, requestId, parameterRow, systemVariableRow); } @Override - public void setId(long id) { - basePath.setId(id); - } - - public static class DefaultParameterizedPathSerializer extends Serializer { - - @Override - public void write(Kryo kryo, Output output, DefaultParameterizedPath object) { - kryo.writeClassAndObject(output, object.basePath); - kryo.writeClassAndObject(output, object.getRequestId()); - kryo.writeClassAndObject(output, object.getParameter()); - kryo.writeClassAndObject(output, object.getSystemVariables()); - } - - @Override - public DefaultParameterizedPath read(Kryo kryo, Input input, Class aClass) { - Path basePath = (Path) kryo.readClassAndObject(input); - Object requestId = kryo.readClassAndObject(input); - Row parameterRow = (Row) kryo.readClassAndObject(input); - Row systemVariableRow = (Row) kryo.readClassAndObject(input); - return new DefaultParameterizedPath(basePath, requestId, parameterRow, systemVariableRow); - } - - @Override - public DefaultParameterizedPath copy(Kryo kryo, DefaultParameterizedPath original) { - Path basePath = kryo.copy(original.basePath); - Object requestId = kryo.copy(original.getRequestId()); - Row parameterRow = kryo.copy(original.getParameter()); - Row systemVariableRow = kryo.copy(original.getSystemVariables()); - return new DefaultParameterizedPath(basePath, requestId, parameterRow, systemVariableRow); - } - } - + public DefaultParameterizedPath copy(Kryo kryo, DefaultParameterizedPath original) { + Path basePath = kryo.copy(original.basePath); + Object requestId = kryo.copy(original.getRequestId()); + Row parameterRow = kryo.copy(original.getParameter()); + Row systemVariableRow = kryo.copy(original.getSystemVariables()); + return new DefaultParameterizedPath(basePath, requestId, parameterRow, systemVariableRow); + } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/DefaultParameterizedRow.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/DefaultParameterizedRow.java index 64f12c13f..f42abd5ab 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/DefaultParameterizedRow.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/DefaultParameterizedRow.java @@ -19,82 +19,85 @@ package org.apache.geaflow.dsl.common.data.impl; +import org.apache.geaflow.common.type.IType; +import org.apache.geaflow.dsl.common.data.ParameterizedRow; +import org.apache.geaflow.dsl.common.data.Row; + import com.esotericsoftware.kryo.Kryo; import com.esotericsoftware.kryo.Serializer; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; -import org.apache.geaflow.common.type.IType; -import org.apache.geaflow.dsl.common.data.ParameterizedRow; -import org.apache.geaflow.dsl.common.data.Row; public class DefaultParameterizedRow implements ParameterizedRow { - private final Row baseRow; + private final Row baseRow; - private final Object requestId; + private final Object requestId; - private final Row parameterRow; + private final Row parameterRow; - private final Row systemVariableRow; + private final Row systemVariableRow; - public DefaultParameterizedRow(Row baseRow, Object requestId, Row parameterRow, Row systemVariableRow) { - this.baseRow = baseRow; - this.requestId = requestId; - this.parameterRow = parameterRow; - this.systemVariableRow = systemVariableRow; - } + public DefaultParameterizedRow( + Row baseRow, Object requestId, Row parameterRow, Row systemVariableRow) { + this.baseRow = baseRow; + this.requestId = requestId; + this.parameterRow = parameterRow; + this.systemVariableRow = systemVariableRow; + } - public DefaultParameterizedRow(Row baseRow, Object requestId, Row parameterRow) { - this(baseRow, requestId, parameterRow, null); - } + public DefaultParameterizedRow(Row baseRow, Object requestId, Row parameterRow) { + this(baseRow, requestId, parameterRow, null); + } - @Override - public Object getField(int i, IType type) { - return baseRow.getField(i, type); - } + @Override + public Object getField(int i, IType type) { + return baseRow.getField(i, type); + } - @Override - public Row getParameter() { - return parameterRow; - } + @Override + public Row getParameter() { + return parameterRow; + } + + @Override + public Row getSystemVariables() { + return systemVariableRow; + } + + @Override + public Object getRequestId() { + return requestId; + } + + public static class DefaultParameterizedRowSerializer + extends Serializer { @Override - public Row getSystemVariables() { - return systemVariableRow; + public void write(Kryo kryo, Output output, DefaultParameterizedRow object) { + kryo.writeClassAndObject(output, object.baseRow); + kryo.writeClassAndObject(output, object.getRequestId()); + kryo.writeClassAndObject(output, object.getParameter()); + kryo.writeClassAndObject(output, object.getSystemVariables()); } @Override - public Object getRequestId() { - return requestId; + public DefaultParameterizedRow read( + Kryo kryo, Input input, Class aClass) { + Row baseRow = (Row) kryo.readClassAndObject(input); + Object requestId = kryo.readClassAndObject(input); + Row parameterRow = (Row) kryo.readClassAndObject(input); + Row systemVariableRow = (Row) kryo.readClassAndObject(input); + return new DefaultParameterizedRow(baseRow, requestId, parameterRow, systemVariableRow); } - public static class DefaultParameterizedRowSerializer extends Serializer { - - @Override - public void write(Kryo kryo, Output output, DefaultParameterizedRow object) { - kryo.writeClassAndObject(output, object.baseRow); - kryo.writeClassAndObject(output, object.getRequestId()); - kryo.writeClassAndObject(output, object.getParameter()); - kryo.writeClassAndObject(output, object.getSystemVariables()); - } - - @Override - public DefaultParameterizedRow read(Kryo kryo, Input input, Class aClass) { - Row baseRow = (Row) kryo.readClassAndObject(input); - Object requestId = kryo.readClassAndObject(input); - Row parameterRow = (Row) kryo.readClassAndObject(input); - Row systemVariableRow = (Row) kryo.readClassAndObject(input); - return new DefaultParameterizedRow(baseRow, requestId, parameterRow, systemVariableRow); - } - - @Override - public DefaultParameterizedRow copy(Kryo kryo, DefaultParameterizedRow original) { - Row baseRow = kryo.copy(original.baseRow); - Object requestId = kryo.copy(original.getRequestId()); - Row parameterRow = kryo.copy(original.getParameter()); - Row systemVariableRow = kryo.copy(original.getSystemVariables()); - return new DefaultParameterizedRow(baseRow, requestId, parameterRow, systemVariableRow); - } + @Override + public DefaultParameterizedRow copy(Kryo kryo, DefaultParameterizedRow original) { + Row baseRow = kryo.copy(original.baseRow); + Object requestId = kryo.copy(original.getRequestId()); + Row parameterRow = kryo.copy(original.getParameter()); + Row systemVariableRow = kryo.copy(original.getSystemVariables()); + return new DefaultParameterizedRow(baseRow, requestId, parameterRow, systemVariableRow); } - + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/DefaultPath.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/DefaultPath.java index 0a64c69f1..4367e952c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/DefaultPath.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/DefaultPath.java @@ -19,138 +19,138 @@ package org.apache.geaflow.dsl.common.data.impl; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.Serializer; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Objects; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.Row; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import com.google.common.collect.Lists; + public class DefaultPath implements Path { - private long id; + private long id; - private final List pathNodes; + private final List pathNodes; - public DefaultPath(List pathNodes, long id) { - this.pathNodes = Objects.requireNonNull(pathNodes); - this.id = id; - } + public DefaultPath(List pathNodes, long id) { + this.pathNodes = Objects.requireNonNull(pathNodes); + this.id = id; + } - public DefaultPath(List pathNodes) { - this(pathNodes, -1L); - } + public DefaultPath(List pathNodes) { + this(pathNodes, -1L); + } - public DefaultPath(Row[] pathNodes) { - this(Lists.newArrayList(pathNodes)); - } + public DefaultPath(Row[] pathNodes) { + this(Lists.newArrayList(pathNodes)); + } - public DefaultPath() { - this(new ArrayList<>()); - } + public DefaultPath() { + this(new ArrayList<>()); + } - @Override - public Row getField(int i, IType type) { - return pathNodes.get(i); - } + @Override + public Row getField(int i, IType type) { + return pathNodes.get(i); + } - @Override - public List getPathNodes() { - return pathNodes; - } + @Override + public List getPathNodes() { + return pathNodes; + } - @Override - public void addNode(Row node) { - pathNodes.add(node); - } + @Override + public void addNode(Row node) { + pathNodes.add(node); + } - @Override - public void remove(int index) { - pathNodes.remove(index); - } + @Override + public void remove(int index) { + pathNodes.remove(index); + } - @Override - public Path copy() { - return new DefaultPath(Lists.newArrayList(pathNodes), id); - } + @Override + public Path copy() { + return new DefaultPath(Lists.newArrayList(pathNodes), id); + } - @Override - public int size() { - return pathNodes.size(); - } + @Override + public int size() { + return pathNodes.size(); + } - @Override - public Path subPath(Collection indices) { - List indexList = new ArrayList<>(indices); - Collections.sort(indexList); - - Path subPath = new DefaultPath(); - for (Integer index : indexList) { - subPath.addNode(pathNodes.get(index)); - } - return subPath; - } + @Override + public Path subPath(Collection indices) { + List indexList = new ArrayList<>(indices); + Collections.sort(indexList); - @Override - public Path subPath(int[] indices) { - Path subPath = new DefaultPath(); - for (Integer index : indices) { - subPath.addNode(pathNodes.get(index)); - } - return subPath; + Path subPath = new DefaultPath(); + for (Integer index : indexList) { + subPath.addNode(pathNodes.get(index)); } + return subPath; + } - @Override - public String toString() { - return "DefaultPath{" - + "pathNodes=" + pathNodes - + '}'; + @Override + public Path subPath(int[] indices) { + Path subPath = new DefaultPath(); + for (Integer index : indices) { + subPath.addNode(pathNodes.get(index)); } + return subPath; + } + + @Override + public String toString() { + return "DefaultPath{" + "pathNodes=" + pathNodes + '}'; + } + + @Override + public long getId() { + return id; + } + + @Override + public void setId(long id) { + this.id = id; + } + + public static class DefaultPathSerializer extends Serializer { @Override - public long getId() { - return id; + public void write(Kryo kryo, Output output, DefaultPath defaultPath) { + output.writeInt(defaultPath.getPathNodes().size()); + for (Row pathNode : defaultPath.getPathNodes()) { + kryo.writeClassAndObject(output, pathNode); + } } @Override - public void setId(long id) { - this.id = id; + public DefaultPath read(Kryo kryo, Input input, Class aClass) { + int size = input.readInt(); + List pathNodes = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + pathNodes.add((Row) kryo.readClassAndObject(input)); + } + return new DefaultPath(pathNodes); } - public static class DefaultPathSerializer extends Serializer { - - @Override - public void write(Kryo kryo, Output output, DefaultPath defaultPath) { - output.writeInt(defaultPath.getPathNodes().size()); - for (Row pathNode : defaultPath.getPathNodes()) { - kryo.writeClassAndObject(output, pathNode); - } - } - - @Override - public DefaultPath read(Kryo kryo, Input input, Class aClass) { - int size = input.readInt(); - List pathNodes = new ArrayList<>(size); - for (int i = 0; i < size; i++) { - pathNodes.add((Row) kryo.readClassAndObject(input)); - } - return new DefaultPath(pathNodes); - } - - @Override - public DefaultPath copy(Kryo kryo, DefaultPath original) { - List pathNodes = new ArrayList<>(original.getPathNodes().size()); - for (Row pathNode : original.getPathNodes()) { - pathNodes.add(kryo.copy(pathNode)); - } - return new DefaultPath(pathNodes); - } - } + @Override + public DefaultPath copy(Kryo kryo, DefaultPath original) { + List pathNodes = new ArrayList<>(original.getPathNodes().size()); + for (Row pathNode : original.getPathNodes()) { + pathNodes.add(kryo.copy(pathNode)); + } + return new DefaultPath(pathNodes); + } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/DefaultRowKeyWithRequestId.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/DefaultRowKeyWithRequestId.java index ffbe44bf5..f999fa9a0 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/DefaultRowKeyWithRequestId.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/DefaultRowKeyWithRequestId.java @@ -19,77 +19,80 @@ package org.apache.geaflow.dsl.common.data.impl; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.Serializer; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Objects; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.RowKey; import org.apache.geaflow.dsl.common.data.RowKeyWithRequestId; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + public class DefaultRowKeyWithRequestId implements RowKeyWithRequestId { - private final Object requestId; + private final Object requestId; - private final RowKey rowKey; + private final RowKey rowKey; - public DefaultRowKeyWithRequestId(Object requestId, RowKey rowKey) { - this.requestId = requestId; - this.rowKey = rowKey; - } + public DefaultRowKeyWithRequestId(Object requestId, RowKey rowKey) { + this.requestId = requestId; + this.rowKey = rowKey; + } - @Override - public Object getRequestId() { - return requestId; - } + @Override + public Object getRequestId() { + return requestId; + } - @Override - public Object getField(int i, IType type) { - return rowKey.getField(i, type); - } + @Override + public Object getField(int i, IType type) { + return rowKey.getField(i, type); + } - @Override - public Object[] getKeys() { - return rowKey.getKeys(); + @Override + public Object[] getKeys() { + return rowKey.getKeys(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } + if (!(o instanceof DefaultRowKeyWithRequestId)) { + return false; + } + DefaultRowKeyWithRequestId that = (DefaultRowKeyWithRequestId) o; + return Objects.equals(requestId, that.requestId) && Objects.equals(rowKey, that.rowKey); + } + + @Override + public int hashCode() { + return Objects.hash(requestId, rowKey); + } + + public static class DefaultRowKeyWithRequestIdSerializer + extends Serializer { @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof DefaultRowKeyWithRequestId)) { - return false; - } - DefaultRowKeyWithRequestId that = (DefaultRowKeyWithRequestId) o; - return Objects.equals(requestId, that.requestId) && Objects.equals(rowKey, that.rowKey); + public void write(Kryo kryo, Output output, DefaultRowKeyWithRequestId object) { + kryo.writeClassAndObject(output, object.getRequestId()); + kryo.writeClassAndObject(output, object.rowKey); } @Override - public int hashCode() { - return Objects.hash(requestId, rowKey); + public DefaultRowKeyWithRequestId read( + Kryo kryo, Input input, Class aClass) { + Object requestId = kryo.readClassAndObject(input); + RowKey rowKey = (RowKey) kryo.readClassAndObject(input); + return new DefaultRowKeyWithRequestId(requestId, rowKey); } - public static class DefaultRowKeyWithRequestIdSerializer extends Serializer { - - @Override - public void write(Kryo kryo, Output output, DefaultRowKeyWithRequestId object) { - kryo.writeClassAndObject(output, object.getRequestId()); - kryo.writeClassAndObject(output, object.rowKey); - } - - @Override - public DefaultRowKeyWithRequestId read(Kryo kryo, Input input, Class aClass) { - Object requestId = kryo.readClassAndObject(input); - RowKey rowKey = (RowKey) kryo.readClassAndObject(input); - return new DefaultRowKeyWithRequestId(requestId, rowKey); - } - - @Override - public DefaultRowKeyWithRequestId copy(Kryo kryo, DefaultRowKeyWithRequestId original) { - return new DefaultRowKeyWithRequestId(original.requestId, original.rowKey); - } + @Override + public DefaultRowKeyWithRequestId copy(Kryo kryo, DefaultRowKeyWithRequestId original) { + return new DefaultRowKeyWithRequestId(original.requestId, original.rowKey); } - + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/ObjectRow.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/ObjectRow.java index 2a1472cf4..1bbe93da9 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/ObjectRow.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/ObjectRow.java @@ -19,80 +19,82 @@ package org.apache.geaflow.dsl.common.data.impl; +import java.util.Arrays; + +import org.apache.geaflow.common.type.IType; +import org.apache.geaflow.dsl.common.data.Row; + import com.esotericsoftware.kryo.Kryo; import com.esotericsoftware.kryo.Serializer; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; -import java.util.Arrays; -import org.apache.geaflow.common.type.IType; -import org.apache.geaflow.dsl.common.data.Row; public class ObjectRow implements Row { - private final Object[] fields; + private final Object[] fields; - private ObjectRow(Object[] fields) { - this.fields = fields; - } + private ObjectRow(Object[] fields) { + this.fields = fields; + } - public static ObjectRow create(Object... fields) { - return new ObjectRow(fields); - } + public static ObjectRow create(Object... fields) { + return new ObjectRow(fields); + } - @Override - public Object getField(int i, IType type) { - return fields[i]; - } + @Override + public Object getField(int i, IType type) { + return fields[i]; + } - public Object[] getFields() { - return fields; - } + public Object[] getFields() { + return fields; + } - @Override - public String toString() { - return Arrays.toString(fields); + @Override + public String toString() { + return Arrays.toString(fields); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } + if (!(o instanceof ObjectRow)) { + return false; + } + ObjectRow objectRow = (ObjectRow) o; + return Arrays.equals(fields, objectRow.fields); + } + + @Override + public int hashCode() { + return Arrays.hashCode(fields); + } + + public static class ObjectRowSerializer extends Serializer { @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof ObjectRow)) { - return false; - } - ObjectRow objectRow = (ObjectRow) o; - return Arrays.equals(fields, objectRow.fields); + public void write(Kryo kryo, Output output, ObjectRow objectRow) { + output.writeInt(objectRow.fields.length); + for (Object field : objectRow.fields) { + kryo.writeClassAndObject(output, field); + } } @Override - public int hashCode() { - return Arrays.hashCode(fields); + public ObjectRow read(Kryo kryo, Input input, Class aClass) { + int size = input.readInt(); + Object[] fields = new Object[size]; + for (int i = 0; i < size; i++) { + fields[i] = kryo.readClassAndObject(input); + } + return ObjectRow.create(fields); } - public static class ObjectRowSerializer extends Serializer { - - @Override - public void write(Kryo kryo, Output output, ObjectRow objectRow) { - output.writeInt(objectRow.fields.length); - for (Object field : objectRow.fields) { - kryo.writeClassAndObject(output, field); - } - } - - @Override - public ObjectRow read(Kryo kryo, Input input, Class aClass) { - int size = input.readInt(); - Object[] fields = new Object[size]; - for (int i = 0; i < size; i++) { - fields[i] = kryo.readClassAndObject(input); - } - return ObjectRow.create(fields); - } - - @Override - public ObjectRow copy(Kryo kryo, ObjectRow original) { - return ObjectRow.create(original.fields); - } + @Override + public ObjectRow copy(Kryo kryo, ObjectRow original) { + return ObjectRow.create(original.fields); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/ObjectRowKey.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/ObjectRowKey.java index f6c9d3531..cdf6e9c4b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/ObjectRowKey.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/ObjectRowKey.java @@ -19,84 +19,84 @@ package org.apache.geaflow.dsl.common.data.impl; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.Serializer; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Arrays; import java.util.Objects; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.RowKey; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + public class ObjectRowKey implements RowKey { - private final Object[] keys; + private final Object[] keys; - private ObjectRowKey(Object... keys) { - this.keys = Objects.requireNonNull(keys); - } + private ObjectRowKey(Object... keys) { + this.keys = Objects.requireNonNull(keys); + } - public static RowKey of(Object... keys) { - return new ObjectRowKey(keys); - } + public static RowKey of(Object... keys) { + return new ObjectRowKey(keys); + } - @Override - public Object getField(int i, IType type) { - return keys[i]; - } + @Override + public Object getField(int i, IType type) { + return keys[i]; + } - @Override - public Object[] getKeys() { - return keys; - } + @Override + public Object[] getKeys() { + return keys; + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof ObjectRowKey)) { - return false; - } - ObjectRowKey that = (ObjectRowKey) o; - return Arrays.equals(keys, that.keys); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } + if (!(o instanceof ObjectRowKey)) { + return false; + } + ObjectRowKey that = (ObjectRowKey) o; + return Arrays.equals(keys, that.keys); + } + + @Override + public int hashCode() { + return Arrays.hashCode(keys); + } + + @Override + public String toString() { + return "ObjectRowKey{" + "keys=" + Arrays.toString(keys) + '}'; + } + + public static class ObjectRowKeySerializer extends Serializer { @Override - public int hashCode() { - return Arrays.hashCode(keys); + public void write(Kryo kryo, Output output, ObjectRowKey objectRowKey) { + output.writeInt(objectRowKey.getKeys().length); + for (Object key : objectRowKey.getKeys()) { + kryo.writeClassAndObject(output, key); + } } @Override - public String toString() { - return "ObjectRowKey{" - + "keys=" + Arrays.toString(keys) - + '}'; + public ObjectRowKey read(Kryo kryo, Input input, Class aClass) { + int size = input.readInt(); + Object[] keys = new Object[size]; + for (int i = 0; i < size; i++) { + keys[i] = kryo.readClassAndObject(input); + } + return new ObjectRowKey(keys); } - public static class ObjectRowKeySerializer extends Serializer { - - @Override - public void write(Kryo kryo, Output output, ObjectRowKey objectRowKey) { - output.writeInt(objectRowKey.getKeys().length); - for (Object key : objectRowKey.getKeys()) { - kryo.writeClassAndObject(output, key); - } - } - - @Override - public ObjectRowKey read(Kryo kryo, Input input, Class aClass) { - int size = input.readInt(); - Object[] keys = new Object[size]; - for (int i = 0; i < size; i++) { - keys[i] = kryo.readClassAndObject(input); - } - return new ObjectRowKey(keys); - } - - @Override - public ObjectRowKey copy(Kryo kryo, ObjectRowKey original) { - return new ObjectRowKey(original.getKeys()); - } + @Override + public ObjectRowKey copy(Kryo kryo, ObjectRowKey original) { + return new ObjectRowKey(original.getKeys()); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/ParameterizedPath.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/ParameterizedPath.java index 3ed8e0e05..a229c8c02 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/ParameterizedPath.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/ParameterizedPath.java @@ -22,6 +22,4 @@ import org.apache.geaflow.dsl.common.data.ParameterizedRow; import org.apache.geaflow.dsl.common.data.Path; -public interface ParameterizedPath extends ParameterizedRow, Path { - -} +public interface ParameterizedPath extends ParameterizedRow, Path {} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/VertexEdgeFactory.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/VertexEdgeFactory.java index f9b888cd2..e42c84931 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/VertexEdgeFactory.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/VertexEdgeFactory.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.common.data.impl; import java.util.Locale; + import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.data.RowEdge; import org.apache.geaflow.dsl.common.data.RowVertex; @@ -43,53 +44,53 @@ public class VertexEdgeFactory { - public static RowVertex createVertex(VertexType vertexType) { - String idTypeName = vertexType.getId().getType().getName().toUpperCase(Locale.ROOT); - switch (idTypeName) { - case Types.TYPE_NAME_INTEGER: - return new IntVertex(); - case Types.TYPE_NAME_LONG: - return new LongVertex(); - case Types.TYPE_NAME_DOUBLE: - return new DoubleVertex(); - case Types.TYPE_NAME_BINARY_STRING: - return new BinaryStringVertex(); - default: - } - return new ObjectVertex(); + public static RowVertex createVertex(VertexType vertexType) { + String idTypeName = vertexType.getId().getType().getName().toUpperCase(Locale.ROOT); + switch (idTypeName) { + case Types.TYPE_NAME_INTEGER: + return new IntVertex(); + case Types.TYPE_NAME_LONG: + return new LongVertex(); + case Types.TYPE_NAME_DOUBLE: + return new DoubleVertex(); + case Types.TYPE_NAME_BINARY_STRING: + return new BinaryStringVertex(); + default: } + return new ObjectVertex(); + } - public static RowEdge createEdge(EdgeType edgeType) { - String idTypeName = edgeType.getSrcId().getType().getName().toUpperCase(Locale.ROOT); - if (edgeType.getTimestamp().isPresent()) { - return createTsEdge(idTypeName); - } - switch (idTypeName) { - case Types.TYPE_NAME_INTEGER: - return new IntEdge(); - case Types.TYPE_NAME_LONG: - return new LongEdge(); - case Types.TYPE_NAME_DOUBLE: - return new DoubleEdge(); - case Types.TYPE_NAME_BINARY_STRING: - return new BinaryStringEdge(); - default: - } - return new ObjectEdge(); + public static RowEdge createEdge(EdgeType edgeType) { + String idTypeName = edgeType.getSrcId().getType().getName().toUpperCase(Locale.ROOT); + if (edgeType.getTimestamp().isPresent()) { + return createTsEdge(idTypeName); + } + switch (idTypeName) { + case Types.TYPE_NAME_INTEGER: + return new IntEdge(); + case Types.TYPE_NAME_LONG: + return new LongEdge(); + case Types.TYPE_NAME_DOUBLE: + return new DoubleEdge(); + case Types.TYPE_NAME_BINARY_STRING: + return new BinaryStringEdge(); + default: } + return new ObjectEdge(); + } - private static RowEdge createTsEdge(String idTypeName) { - switch (idTypeName) { - case Types.TYPE_NAME_INTEGER: - return new IntTsEdge(); - case Types.TYPE_NAME_LONG: - return new LongTsEdge(); - case Types.TYPE_NAME_DOUBLE: - return new DoubleTsEdge(); - case Types.TYPE_NAME_BINARY_STRING: - return new BinaryStringTsEdge(); - default: - } - return new ObjectTsEdge(); + private static RowEdge createTsEdge(String idTypeName) { + switch (idTypeName) { + case Types.TYPE_NAME_INTEGER: + return new IntTsEdge(); + case Types.TYPE_NAME_LONG: + return new LongTsEdge(); + case Types.TYPE_NAME_DOUBLE: + return new DoubleTsEdge(); + case Types.TYPE_NAME_BINARY_STRING: + return new BinaryStringTsEdge(); + default: } + return new ObjectTsEdge(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/BinaryStringEdge.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/BinaryStringEdge.java index 5c3649b09..712c6c06e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/BinaryStringEdge.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/BinaryStringEdge.java @@ -19,12 +19,9 @@ package org.apache.geaflow.dsl.common.data.impl.types; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Objects; import java.util.function.Supplier; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -33,198 +30,203 @@ import org.apache.geaflow.dsl.common.util.BinaryUtil; import org.apache.geaflow.model.graph.edge.EdgeDirection; -public class BinaryStringEdge implements RowEdge, KryoSerializable { - - public static final Supplier CONSTRUCTOR = new Constructor(); - - private BinaryString srcId; - - private BinaryString targetId; - - private EdgeDirection direction = EdgeDirection.OUT; - - private BinaryString label; - - private Row value; - - public BinaryStringEdge() { - - } - - public BinaryStringEdge(BinaryString srcId, BinaryString targetId) { - this.srcId = srcId; - this.targetId = targetId; - } - - public BinaryStringEdge(BinaryString srcId, BinaryString targetId, Row value) { - this.srcId = srcId; - this.targetId = targetId; - this.value = value; - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; - @Override - public void setLabel(String label) { - this.label = (BinaryString) BinaryUtil.toBinaryLabel(label); - } +public class BinaryStringEdge implements RowEdge, KryoSerializable { - @Override - public Object getSrcId() { + public static final Supplier CONSTRUCTOR = new Constructor(); + + private BinaryString srcId; + + private BinaryString targetId; + + private EdgeDirection direction = EdgeDirection.OUT; + + private BinaryString label; + + private Row value; + + public BinaryStringEdge() {} + + public BinaryStringEdge(BinaryString srcId, BinaryString targetId) { + this.srcId = srcId; + this.targetId = targetId; + } + + public BinaryStringEdge(BinaryString srcId, BinaryString targetId, Row value) { + this.srcId = srcId; + this.targetId = targetId; + this.value = value; + } + + @Override + public void setLabel(String label) { + this.label = (BinaryString) BinaryUtil.toBinaryLabel(label); + } + + @Override + public Object getSrcId() { + return srcId; + } + + @Override + public void setSrcId(Object srcId) { + this.srcId = (BinaryString) srcId; + } + + @Override + public Object getTargetId() { + return targetId; + } + + @Override + public void setTargetId(Object targetId) { + this.targetId = (BinaryString) targetId; + } + + @Override + public String getLabel() { + return label.toString(); + } + + @Override + public EdgeDirection getDirect() { + return direction; + } + + @Override + public void setDirect(EdgeDirection direction) { + this.direction = direction; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof BinaryStringEdge)) { + return false; + } + BinaryStringEdge that = (BinaryStringEdge) o; + return Objects.equals(srcId, that.srcId) + && Objects.equals(targetId, that.targetId) + && direction == that.direction + && Objects.equals(label, that.label); + } + + @Override + public int hashCode() { + return Objects.hash(srcId, targetId, direction, label); + } + + @Override + public BinaryStringEdge reverse() { + BinaryStringEdge edge = new BinaryStringEdge(targetId, srcId, value); + edge.setBinaryLabel(label); + edge.setDirect(direction); + return edge; + } + + @Override + public Row getValue() { + return value; + } + + @Override + public BinaryStringEdge withValue(Row value) { + BinaryStringEdge edge = new BinaryStringEdge(srcId, targetId, value); + edge.setBinaryLabel(label); + edge.setDirect(direction); + return edge; + } + + @Override + public String toString() { + return srcId + "#" + targetId + "#" + label + "#" + direction + "#" + value; + } + + @Override + public Object getField(int i, IType type) { + switch (i) { + case EdgeType.SRC_ID_FIELD_POSITION: return srcId; - } - - @Override - public void setSrcId(Object srcId) { - this.srcId = (BinaryString) srcId; - } - - @Override - public Object getTargetId() { + case EdgeType.TARGET_ID_FIELD_POSITION: return targetId; - } - - @Override - public void setTargetId(Object targetId) { - this.targetId = (BinaryString) targetId; - } - - @Override - public String getLabel() { - return label.toString(); - } - - @Override - public EdgeDirection getDirect() { - return direction; - } - - @Override - public void setDirect(EdgeDirection direction) { - this.direction = direction; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof BinaryStringEdge)) { - return false; - } - BinaryStringEdge that = (BinaryStringEdge) o; - return Objects.equals(srcId, that.srcId) && Objects.equals(targetId, - that.targetId) && direction == that.direction && Objects.equals(label, that.label); - } - - @Override - public int hashCode() { - return Objects.hash(srcId, targetId, direction, label); - } - - @Override - public BinaryStringEdge reverse() { - BinaryStringEdge edge = new BinaryStringEdge(targetId, srcId, value); - edge.setBinaryLabel(label); - edge.setDirect(direction); - return edge; - } - - @Override - public Row getValue() { - return value; - } - - @Override - public BinaryStringEdge withValue(Row value) { - BinaryStringEdge edge = new BinaryStringEdge(srcId, targetId, value); - edge.setBinaryLabel(label); - edge.setDirect(direction); - return edge; - } - - @Override - public String toString() { - return srcId + "#" + targetId + "#" + label + "#" + direction + "#" + value; - } - - @Override - public Object getField(int i, IType type) { - switch (i) { - case EdgeType.SRC_ID_FIELD_POSITION: - return srcId; - case EdgeType.TARGET_ID_FIELD_POSITION: - return targetId; - case EdgeType.LABEL_FIELD_POSITION: - return label; - default: - return value.getField(i - 3, type); - } - } - - @Override - public void setValue(Row value) { - this.value = value; - } - - @Override - public RowEdge withDirection(EdgeDirection direction) { - BinaryStringEdge edge = new BinaryStringEdge(srcId, targetId, value); - edge.setBinaryLabel(label); - edge.setDirect(direction); - return edge; - } - - @Override - public RowEdge identityReverse() { - BinaryStringEdge edge = new BinaryStringEdge(targetId, srcId, value); - edge.setBinaryLabel(label); - edge.setDirect(direction.reverse()); - return edge; - } - - @Override - public BinaryString getBinaryLabel() { + case EdgeType.LABEL_FIELD_POSITION: return label; - } - - @Override - public void setBinaryLabel(BinaryString label) { - this.label = label; - } - - private static class Constructor implements Supplier { - - @Override - public BinaryStringEdge get() { - return new BinaryStringEdge(); - } - } - - @Override - public void write(Kryo kryo, Output output) { - byte[] srcIdBytes = this.srcId.getBytes(); - output.writeInt(srcIdBytes.length); - output.writeBytes(srcIdBytes); - byte[] targetIdBytes = this.targetId.getBytes(); - output.writeInt(targetIdBytes.length); - output.writeBytes(targetIdBytes); - kryo.writeClassAndObject(output, this.getDirect()); - byte[] labelBytes = this.getBinaryLabel().getBytes(); - output.writeInt(labelBytes.length); - output.writeBytes(labelBytes); - kryo.writeClassAndObject(output, this.getValue()); - } - - @Override - public void read(Kryo kryo, Input input) { - BinaryString srcId = BinaryString.fromBytes(input.readBytes(input.readInt())); - this.srcId = srcId; - BinaryString targetId = BinaryString.fromBytes(input.readBytes(input.readInt())); - this.targetId = targetId; - EdgeDirection direction = (EdgeDirection) kryo.readClassAndObject(input); - this.setDirect(direction); - BinaryString label = BinaryString.fromBytes(input.readBytes(input.readInt())); - this.setBinaryLabel(label); - Row value = (Row) kryo.readClassAndObject(input); - this.value = value; - } + default: + return value.getField(i - 3, type); + } + } + + @Override + public void setValue(Row value) { + this.value = value; + } + + @Override + public RowEdge withDirection(EdgeDirection direction) { + BinaryStringEdge edge = new BinaryStringEdge(srcId, targetId, value); + edge.setBinaryLabel(label); + edge.setDirect(direction); + return edge; + } + + @Override + public RowEdge identityReverse() { + BinaryStringEdge edge = new BinaryStringEdge(targetId, srcId, value); + edge.setBinaryLabel(label); + edge.setDirect(direction.reverse()); + return edge; + } + + @Override + public BinaryString getBinaryLabel() { + return label; + } + + @Override + public void setBinaryLabel(BinaryString label) { + this.label = label; + } + + private static class Constructor implements Supplier { + + @Override + public BinaryStringEdge get() { + return new BinaryStringEdge(); + } + } + + @Override + public void write(Kryo kryo, Output output) { + byte[] srcIdBytes = this.srcId.getBytes(); + output.writeInt(srcIdBytes.length); + output.writeBytes(srcIdBytes); + byte[] targetIdBytes = this.targetId.getBytes(); + output.writeInt(targetIdBytes.length); + output.writeBytes(targetIdBytes); + kryo.writeClassAndObject(output, this.getDirect()); + byte[] labelBytes = this.getBinaryLabel().getBytes(); + output.writeInt(labelBytes.length); + output.writeBytes(labelBytes); + kryo.writeClassAndObject(output, this.getValue()); + } + + @Override + public void read(Kryo kryo, Input input) { + BinaryString srcId = BinaryString.fromBytes(input.readBytes(input.readInt())); + this.srcId = srcId; + BinaryString targetId = BinaryString.fromBytes(input.readBytes(input.readInt())); + this.targetId = targetId; + EdgeDirection direction = (EdgeDirection) kryo.readClassAndObject(input); + this.setDirect(direction); + BinaryString label = BinaryString.fromBytes(input.readBytes(input.readInt())); + this.setBinaryLabel(label); + Row value = (Row) kryo.readClassAndObject(input); + this.value = value; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/BinaryStringTsEdge.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/BinaryStringTsEdge.java index 988604fd5..6277f0b72 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/BinaryStringTsEdge.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/BinaryStringTsEdge.java @@ -19,12 +19,9 @@ package org.apache.geaflow.dsl.common.data.impl.types; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Objects; import java.util.function.Supplier; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -33,158 +30,169 @@ import org.apache.geaflow.model.graph.IGraphElementWithTimeField; import org.apache.geaflow.model.graph.edge.EdgeDirection; -public class BinaryStringTsEdge extends BinaryStringEdge implements IGraphElementWithTimeField, - KryoSerializable { - - public static final Supplier CONSTRUCTOR = new Constructor(); - - private long time; - - public BinaryStringTsEdge() { - - } - - public BinaryStringTsEdge(BinaryString srcId, BinaryString targetId) { - super(srcId, targetId); - } - - public BinaryStringTsEdge(BinaryString srcId, BinaryString targetId, Row value) { - super(srcId, targetId, value); - } - - public void setTime(long time) { - this.time = time; - } - - public long getTime() { - return time; - } - - @Override - public Object getField(int i, IType type) { - switch (i) { - case EdgeType.SRC_ID_FIELD_POSITION: - return getSrcId(); - case EdgeType.TARGET_ID_FIELD_POSITION: - return getTargetId(); - case EdgeType.LABEL_FIELD_POSITION: - return getBinaryLabel(); - case EdgeType.TIME_FIELD_POSITION: - return time; - default: - return getValue().getField(i - 4, type); - } - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof BinaryStringTsEdge)) { - return false; - } - if (!super.equals(o)) { - return false; - } - BinaryStringTsEdge that = (BinaryStringTsEdge) o; - return time == that.time; - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), time); - } +public class BinaryStringTsEdge extends BinaryStringEdge + implements IGraphElementWithTimeField, KryoSerializable { - @Override - public BinaryStringTsEdge reverse() { - BinaryStringTsEdge edge = new BinaryStringTsEdge((BinaryString) getTargetId(), - (BinaryString) getSrcId(), getValue()); - edge.setBinaryLabel(getBinaryLabel()); - edge.setDirect(getDirect()); - edge.setTime(time); - return edge; - } + public static final Supplier CONSTRUCTOR = new Constructor(); - @Override - public BinaryStringTsEdge withValue(Row value) { - BinaryStringTsEdge edge = new BinaryStringTsEdge((BinaryString) getSrcId(), - (BinaryString) getTargetId(), value); - edge.setBinaryLabel(getBinaryLabel()); - edge.setDirect(getDirect()); - edge.setTime(time); - return edge; - } + private long time; - @Override - public RowEdge withDirection(EdgeDirection direction) { - BinaryStringTsEdge edge = new BinaryStringTsEdge((BinaryString) getSrcId(), - (BinaryString) getTargetId(), getValue()); - edge.setBinaryLabel(getBinaryLabel()); - edge.setDirect(direction); - edge.setTime(time); - return edge; - } + public BinaryStringTsEdge() {} - @Override - public RowEdge identityReverse() { - BinaryStringTsEdge edge = new BinaryStringTsEdge((BinaryString) getTargetId(), - (BinaryString) getSrcId(), getValue()); - edge.setBinaryLabel(getBinaryLabel()); - edge.setDirect(getDirect().reverse()); - edge.setTime(time); - return edge; - } + public BinaryStringTsEdge(BinaryString srcId, BinaryString targetId) { + super(srcId, targetId); + } - @Override - public String toString() { - return getSrcId() + "#" + getTargetId() + "#" + getBinaryLabel() + "#" + getDirect() - + "#" + time + "#" + getValue(); - } + public BinaryStringTsEdge(BinaryString srcId, BinaryString targetId, Row value) { + super(srcId, targetId, value); + } - private static class Constructor implements Supplier { + public void setTime(long time) { + this.time = time; + } - @Override - public BinaryStringTsEdge get() { - return new BinaryStringTsEdge(); - } - } + public long getTime() { + return time; + } - @Override - public void write(Kryo kryo, Output output) { - // serialize fields - byte[] srcIdBytes = ((BinaryString) this.getSrcId()).getBytes(); - byte[] targetIdBytes = ((BinaryString) this.getTargetId()).getBytes(); - output.writeInt(srcIdBytes.length); - output.writeBytes(srcIdBytes); - output.writeInt(targetIdBytes.length); - output.writeBytes(targetIdBytes); - kryo.writeClassAndObject(output, this.getDirect()); - byte[] labelBytes = this.getBinaryLabel().getBytes(); - output.writeInt(labelBytes.length); - output.writeBytes(labelBytes); - output.writeLong(this.getTime()); - // serialize value - kryo.writeClassAndObject(output, this.getValue()); - } + @Override + public Object getField(int i, IType type) { + switch (i) { + case EdgeType.SRC_ID_FIELD_POSITION: + return getSrcId(); + case EdgeType.TARGET_ID_FIELD_POSITION: + return getTargetId(); + case EdgeType.LABEL_FIELD_POSITION: + return getBinaryLabel(); + case EdgeType.TIME_FIELD_POSITION: + return time; + default: + return getValue().getField(i - 4, type); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof BinaryStringTsEdge)) { + return false; + } + if (!super.equals(o)) { + return false; + } + BinaryStringTsEdge that = (BinaryStringTsEdge) o; + return time == that.time; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), time); + } + + @Override + public BinaryStringTsEdge reverse() { + BinaryStringTsEdge edge = + new BinaryStringTsEdge((BinaryString) getTargetId(), (BinaryString) getSrcId(), getValue()); + edge.setBinaryLabel(getBinaryLabel()); + edge.setDirect(getDirect()); + edge.setTime(time); + return edge; + } + + @Override + public BinaryStringTsEdge withValue(Row value) { + BinaryStringTsEdge edge = + new BinaryStringTsEdge((BinaryString) getSrcId(), (BinaryString) getTargetId(), value); + edge.setBinaryLabel(getBinaryLabel()); + edge.setDirect(getDirect()); + edge.setTime(time); + return edge; + } + + @Override + public RowEdge withDirection(EdgeDirection direction) { + BinaryStringTsEdge edge = + new BinaryStringTsEdge((BinaryString) getSrcId(), (BinaryString) getTargetId(), getValue()); + edge.setBinaryLabel(getBinaryLabel()); + edge.setDirect(direction); + edge.setTime(time); + return edge; + } + + @Override + public RowEdge identityReverse() { + BinaryStringTsEdge edge = + new BinaryStringTsEdge((BinaryString) getTargetId(), (BinaryString) getSrcId(), getValue()); + edge.setBinaryLabel(getBinaryLabel()); + edge.setDirect(getDirect().reverse()); + edge.setTime(time); + return edge; + } + + @Override + public String toString() { + return getSrcId() + + "#" + + getTargetId() + + "#" + + getBinaryLabel() + + "#" + + getDirect() + + "#" + + time + + "#" + + getValue(); + } + + private static class Constructor implements Supplier { @Override - public void read(Kryo kryo, Input input) { - // deserialize fields - BinaryString srcId = BinaryString.fromBytes(input.readBytes(input.readInt())); - BinaryString targetId = BinaryString.fromBytes(input.readBytes(input.readInt())); - EdgeDirection direction = (EdgeDirection) kryo.readClassAndObject(input); - BinaryString label = BinaryString.fromBytes(input.readBytes(input.readInt())); - long time = input.readLong(); - // deserialize value - Row value = (Row) kryo.readClassAndObject(input); - // create edge object - this.setSrcId(srcId); - this.setTargetId(targetId); - this.setValue(value); - this.setBinaryLabel(label); - this.setDirect(direction); - this.setTime(time); - } - + public BinaryStringTsEdge get() { + return new BinaryStringTsEdge(); + } + } + + @Override + public void write(Kryo kryo, Output output) { + // serialize fields + byte[] srcIdBytes = ((BinaryString) this.getSrcId()).getBytes(); + byte[] targetIdBytes = ((BinaryString) this.getTargetId()).getBytes(); + output.writeInt(srcIdBytes.length); + output.writeBytes(srcIdBytes); + output.writeInt(targetIdBytes.length); + output.writeBytes(targetIdBytes); + kryo.writeClassAndObject(output, this.getDirect()); + byte[] labelBytes = this.getBinaryLabel().getBytes(); + output.writeInt(labelBytes.length); + output.writeBytes(labelBytes); + output.writeLong(this.getTime()); + // serialize value + kryo.writeClassAndObject(output, this.getValue()); + } + + @Override + public void read(Kryo kryo, Input input) { + // deserialize fields + BinaryString srcId = BinaryString.fromBytes(input.readBytes(input.readInt())); + BinaryString targetId = BinaryString.fromBytes(input.readBytes(input.readInt())); + EdgeDirection direction = (EdgeDirection) kryo.readClassAndObject(input); + BinaryString label = BinaryString.fromBytes(input.readBytes(input.readInt())); + long time = input.readLong(); + // deserialize value + Row value = (Row) kryo.readClassAndObject(input); + // create edge object + this.setSrcId(srcId); + this.setTargetId(targetId); + this.setValue(value); + this.setBinaryLabel(label); + this.setDirect(direction); + this.setTime(time); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/BinaryStringVertex.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/BinaryStringVertex.java index ecb99972c..d02645c05 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/BinaryStringVertex.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/BinaryStringVertex.java @@ -19,12 +19,9 @@ package org.apache.geaflow.dsl.common.data.impl.types; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Objects; import java.util.function.Supplier; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -34,158 +31,160 @@ import org.apache.geaflow.dsl.common.util.BinaryUtil; import org.apache.geaflow.model.graph.vertex.IVertex; -public class BinaryStringVertex implements RowVertex, KryoSerializable { - - public static final Supplier CONSTRUCTOR = new Constructor(); - - private BinaryString id; - - private BinaryString label; - - private Row value; - - public BinaryStringVertex() { - - } - - public BinaryStringVertex(BinaryString id) { - this.id = id; - } - - public BinaryStringVertex(BinaryString id, BinaryString label, Row value) { - this.id = id; - this.label = label; - this.value = value; - } - - @Override - public String getLabel() { - return label.toString(); - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; - @Override - public void setLabel(String label) { - this.label = BinaryString.fromString(label); - } +public class BinaryStringVertex implements RowVertex, KryoSerializable { - @Override - public Object getId() { + public static final Supplier CONSTRUCTOR = new Constructor(); + + private BinaryString id; + + private BinaryString label; + + private Row value; + + public BinaryStringVertex() {} + + public BinaryStringVertex(BinaryString id) { + this.id = id; + } + + public BinaryStringVertex(BinaryString id, BinaryString label, Row value) { + this.id = id; + this.label = label; + this.value = value; + } + + @Override + public String getLabel() { + return label.toString(); + } + + @Override + public void setLabel(String label) { + this.label = BinaryString.fromString(label); + } + + @Override + public Object getId() { + return id; + } + + @Override + public void setId(Object id) { + this.id = (BinaryString) Objects.requireNonNull(id); + } + + @Override + public Row getValue() { + return value; + } + + @Override + public BinaryStringVertex withValue(Row value) { + return new BinaryStringVertex(id, label, value); + } + + @Override + public BinaryStringVertex withLabel(String label) { + return new BinaryStringVertex(id, (BinaryString) BinaryUtil.toBinaryLabel(label), value); + } + + @Override + public IVertex withTime(long time) { + throw new GeaFlowDSLException("Vertex not support timestamp"); + } + + @SuppressWarnings("unchecked") + @Override + public int compareTo(Object o) { + RowVertex vertex = (RowVertex) o; + return ((Comparable) id).compareTo(vertex.getId()); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof RowVertex)) { + return false; + } + RowVertex that = (RowVertex) o; + return id.equals(that.getId()) && Objects.equals(label, that.getBinaryLabel()); + } + + @Override + public int hashCode() { + return Objects.hash(id, label); + } + + @Override + public Object getField(int i, IType type) { + switch (i) { + case VertexType.ID_FIELD_POSITION: return id; - } - - @Override - public void setId(Object id) { - this.id = (BinaryString) Objects.requireNonNull(id); - } - - @Override - public Row getValue() { - return value; - } - - @Override - public BinaryStringVertex withValue(Row value) { - return new BinaryStringVertex(id, label, value); - } - - @Override - public BinaryStringVertex withLabel(String label) { - return new BinaryStringVertex(id, (BinaryString) BinaryUtil.toBinaryLabel(label), value); - } - - @Override - public IVertex withTime(long time) { - throw new GeaFlowDSLException("Vertex not support timestamp"); - } - - @SuppressWarnings("unchecked") - @Override - public int compareTo(Object o) { - RowVertex vertex = (RowVertex) o; - return ((Comparable) id).compareTo(vertex.getId()); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof RowVertex)) { - return false; - } - RowVertex that = (RowVertex) o; - return id.equals(that.getId()) && Objects.equals(label, that.getBinaryLabel()); - } - - @Override - public int hashCode() { - return Objects.hash(id, label); - } - - @Override - public Object getField(int i, IType type) { - switch (i) { - case VertexType.ID_FIELD_POSITION: - return id; - case VertexType.LABEL_FIELD_POSITION: - return label; - default: - return value.getField(i - 2, type); - } - } - - @Override - public void setValue(Row value) { - this.value = value; - } - - @Override - public String toString() { - return id + "#" + label + "#" + value; - } - - @Override - public BinaryString getBinaryLabel() { + case VertexType.LABEL_FIELD_POSITION: return label; - } - - @Override - public void setBinaryLabel(BinaryString label) { - this.label = label; - } - - private static class Constructor implements Supplier { - - @Override - public BinaryStringVertex get() { - return new BinaryStringVertex(); - } - } - - @Override - public void write(Kryo kryo, Output output) { - // serialize fields - byte[] idBytes = this.id.getBytes(); - byte[] labelBytes = this.getBinaryLabel().getBytes(); - output.writeInt(idBytes.length); - output.writeBytes(idBytes); - output.writeInt(labelBytes.length); - output.writeBytes(labelBytes); - // serialize value - kryo.writeClassAndObject(output, this.getValue()); - } - - @Override - public void read(Kryo kryo, Input input) { - // deserialize fields - BinaryString id = BinaryString.fromBytes(input.readBytes(input.readInt())); - BinaryString label = BinaryString.fromBytes(input.readBytes(input.readInt())); - // deserialize value - Row value = (Row) kryo.readClassAndObject(input); - // create vertex object - this.id = id; - this.setValue(value); - this.setBinaryLabel(label); - } - + default: + return value.getField(i - 2, type); + } + } + + @Override + public void setValue(Row value) { + this.value = value; + } + + @Override + public String toString() { + return id + "#" + label + "#" + value; + } + + @Override + public BinaryString getBinaryLabel() { + return label; + } + + @Override + public void setBinaryLabel(BinaryString label) { + this.label = label; + } + + private static class Constructor implements Supplier { + + @Override + public BinaryStringVertex get() { + return new BinaryStringVertex(); + } + } + + @Override + public void write(Kryo kryo, Output output) { + // serialize fields + byte[] idBytes = this.id.getBytes(); + byte[] labelBytes = this.getBinaryLabel().getBytes(); + output.writeInt(idBytes.length); + output.writeBytes(idBytes); + output.writeInt(labelBytes.length); + output.writeBytes(labelBytes); + // serialize value + kryo.writeClassAndObject(output, this.getValue()); + } + + @Override + public void read(Kryo kryo, Input input) { + // deserialize fields + BinaryString id = BinaryString.fromBytes(input.readBytes(input.readInt())); + BinaryString label = BinaryString.fromBytes(input.readBytes(input.readInt())); + // deserialize value + Row value = (Row) kryo.readClassAndObject(input); + // create vertex object + this.id = id; + this.setValue(value); + this.setBinaryLabel(label); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/DoubleEdge.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/DoubleEdge.java index 0d45189da..7b13902c1 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/DoubleEdge.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/DoubleEdge.java @@ -19,12 +19,9 @@ package org.apache.geaflow.dsl.common.data.impl.types; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Objects; import java.util.function.Supplier; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -33,207 +30,210 @@ import org.apache.geaflow.dsl.common.util.BinaryUtil; import org.apache.geaflow.model.graph.edge.EdgeDirection; -public class DoubleEdge implements RowEdge, KryoSerializable { - - public static final Supplier CONSTRUCTOR = new Constructor(); - - public double srcId; - - public double targetId; - - public EdgeDirection direction = EdgeDirection.OUT; - - private BinaryString label; - - private Row value; - - public DoubleEdge() { - - } - - public DoubleEdge(double srcId, double targetId) { - this.srcId = srcId; - this.targetId = targetId; - } - - public DoubleEdge(double srcId, double targetId, Row value) { - this.srcId = srcId; - this.targetId = targetId; - this.value = value; - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; - @Override - public void setLabel(String label) { - this.label = (BinaryString) BinaryUtil.toBinaryLabel(label); - } +public class DoubleEdge implements RowEdge, KryoSerializable { - @Override - public Object getSrcId() { + public static final Supplier CONSTRUCTOR = new Constructor(); + + public double srcId; + + public double targetId; + + public EdgeDirection direction = EdgeDirection.OUT; + + private BinaryString label; + + private Row value; + + public DoubleEdge() {} + + public DoubleEdge(double srcId, double targetId) { + this.srcId = srcId; + this.targetId = targetId; + } + + public DoubleEdge(double srcId, double targetId, Row value) { + this.srcId = srcId; + this.targetId = targetId; + this.value = value; + } + + @Override + public void setLabel(String label) { + this.label = (BinaryString) BinaryUtil.toBinaryLabel(label); + } + + @Override + public Object getSrcId() { + return srcId; + } + + @Override + public void setSrcId(Object srcId) { + this.srcId = (double) srcId; + } + + @Override + public Object getTargetId() { + return targetId; + } + + @Override + public void setTargetId(Object targetId) { + this.targetId = (double) targetId; + } + + @Override + public String getLabel() { + return label.toString(); + } + + @Override + public EdgeDirection getDirect() { + return direction; + } + + @Override + public void setDirect(EdgeDirection direction) { + this.direction = direction; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof RowEdge)) { + return false; + } + if (o instanceof DoubleEdge) { + DoubleEdge that = (DoubleEdge) o; + return Double.compare(srcId, that.srcId) == 0 + && Double.compare(targetId, that.targetId) == 0 + && direction == that.direction + && Objects.equals(label, that.label); + } else { + RowEdge that = (RowEdge) o; + return that.equals(this); + } + } + + @Override + public int hashCode() { + return Objects.hash(srcId, targetId, direction, label); + } + + @Override + public DoubleEdge reverse() { + DoubleEdge edge = new DoubleEdge(targetId, srcId, value); + edge.setBinaryLabel(label); + edge.setDirect(direction); + return edge; + } + + @Override + public Row getValue() { + return value; + } + + @Override + public DoubleEdge withValue(Row value) { + DoubleEdge edge = new DoubleEdge(srcId, targetId, value); + edge.setBinaryLabel(label); + edge.setDirect(direction); + return edge; + } + + @Override + public String toString() { + return srcId + "#" + targetId + "#" + label + "#" + direction + "#" + value; + } + + @Override + public Object getField(int i, IType type) { + switch (i) { + case EdgeType.SRC_ID_FIELD_POSITION: return srcId; - } - - @Override - public void setSrcId(Object srcId) { - this.srcId = (double) srcId; - } - - @Override - public Object getTargetId() { + case EdgeType.TARGET_ID_FIELD_POSITION: return targetId; - } - - @Override - public void setTargetId(Object targetId) { - this.targetId = (double) targetId; - } - - @Override - public String getLabel() { - return label.toString(); - } - - @Override - public EdgeDirection getDirect() { - return direction; - } - - @Override - public void setDirect(EdgeDirection direction) { - this.direction = direction; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof RowEdge)) { - return false; - } - if (o instanceof DoubleEdge) { - DoubleEdge that = (DoubleEdge) o; - return Double.compare(srcId, that.srcId) == 0 && Double.compare(targetId, - that.targetId) == 0 && direction == that.direction && Objects.equals(label, - that.label); - } else { - RowEdge that = (RowEdge) o; - return that.equals(this); - } - } - - @Override - public int hashCode() { - return Objects.hash(srcId, targetId, direction, label); - } - - @Override - public DoubleEdge reverse() { - DoubleEdge edge = new DoubleEdge(targetId, srcId, value); - edge.setBinaryLabel(label); - edge.setDirect(direction); - return edge; - } - - @Override - public Row getValue() { - return value; - } - - @Override - public DoubleEdge withValue(Row value) { - DoubleEdge edge = new DoubleEdge(srcId, targetId, value); - edge.setBinaryLabel(label); - edge.setDirect(direction); - return edge; - } - - @Override - public String toString() { - return srcId + "#" + targetId + "#" + label + "#" + direction + "#" + value; - } - - @Override - public Object getField(int i, IType type) { - switch (i) { - case EdgeType.SRC_ID_FIELD_POSITION: - return srcId; - case EdgeType.TARGET_ID_FIELD_POSITION: - return targetId; - case EdgeType.LABEL_FIELD_POSITION: - return label; - default: - return value.getField(i - 3, type); - } - } - - @Override - public void setValue(Row value) { - this.value = value; - } - - @Override - public RowEdge withDirection(EdgeDirection direction) { - DoubleEdge edge = new DoubleEdge(srcId, targetId, value); - edge.setBinaryLabel(label); - edge.setDirect(direction); - return edge; - } - - @Override - public RowEdge identityReverse() { - DoubleEdge edge = new DoubleEdge(targetId, srcId, value); - edge.setBinaryLabel(label); - edge.setDirect(direction.reverse()); - return edge; - } - - @Override - public BinaryString getBinaryLabel() { + case EdgeType.LABEL_FIELD_POSITION: return label; - } - - @Override - public void setBinaryLabel(BinaryString label) { - this.label = label; - } - - private static class Constructor implements Supplier { - - @Override - public DoubleEdge get() { - return new DoubleEdge(); - } - } - - @Override - public void write(Kryo kryo, Output output) { - // serialize fields - output.writeDouble(this.srcId); - output.writeDouble(this.targetId); - kryo.writeClassAndObject(output, this.getDirect()); - byte[] labelBytes = this.getBinaryLabel().getBytes(); - output.writeInt(labelBytes.length); - output.writeBytes(labelBytes); - // serialize value - kryo.writeClassAndObject(output, this.getValue()); - } - - @Override - public void read(Kryo kryo, Input input) { - // deserialize fields - double srcId = input.readDouble(); - this.srcId = srcId; - double targetId = input.readDouble(); - this.targetId = targetId; - EdgeDirection direction = (EdgeDirection) kryo.readClassAndObject(input); - this.setDirect(direction); - int labelLength = input.readInt(); - BinaryString label = BinaryString.fromBytes(input.readBytes(labelLength)); - this.setBinaryLabel(label); - // deserialize value - Row value = (Row) kryo.readClassAndObject(input); - // create edge object - this.setValue(value); - } - + default: + return value.getField(i - 3, type); + } + } + + @Override + public void setValue(Row value) { + this.value = value; + } + + @Override + public RowEdge withDirection(EdgeDirection direction) { + DoubleEdge edge = new DoubleEdge(srcId, targetId, value); + edge.setBinaryLabel(label); + edge.setDirect(direction); + return edge; + } + + @Override + public RowEdge identityReverse() { + DoubleEdge edge = new DoubleEdge(targetId, srcId, value); + edge.setBinaryLabel(label); + edge.setDirect(direction.reverse()); + return edge; + } + + @Override + public BinaryString getBinaryLabel() { + return label; + } + + @Override + public void setBinaryLabel(BinaryString label) { + this.label = label; + } + + private static class Constructor implements Supplier { + + @Override + public DoubleEdge get() { + return new DoubleEdge(); + } + } + + @Override + public void write(Kryo kryo, Output output) { + // serialize fields + output.writeDouble(this.srcId); + output.writeDouble(this.targetId); + kryo.writeClassAndObject(output, this.getDirect()); + byte[] labelBytes = this.getBinaryLabel().getBytes(); + output.writeInt(labelBytes.length); + output.writeBytes(labelBytes); + // serialize value + kryo.writeClassAndObject(output, this.getValue()); + } + + @Override + public void read(Kryo kryo, Input input) { + // deserialize fields + double srcId = input.readDouble(); + this.srcId = srcId; + double targetId = input.readDouble(); + this.targetId = targetId; + EdgeDirection direction = (EdgeDirection) kryo.readClassAndObject(input); + this.setDirect(direction); + int labelLength = input.readInt(); + BinaryString label = BinaryString.fromBytes(input.readBytes(labelLength)); + this.setBinaryLabel(label); + // deserialize value + Row value = (Row) kryo.readClassAndObject(input); + // create edge object + this.setValue(value); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/DoubleTsEdge.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/DoubleTsEdge.java index a05a5fa8e..f485403be 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/DoubleTsEdge.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/DoubleTsEdge.java @@ -19,12 +19,9 @@ package org.apache.geaflow.dsl.common.data.impl.types; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Objects; import java.util.function.Supplier; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -33,154 +30,162 @@ import org.apache.geaflow.model.graph.IGraphElementWithTimeField; import org.apache.geaflow.model.graph.edge.EdgeDirection; -public class DoubleTsEdge extends DoubleEdge implements IGraphElementWithTimeField, - KryoSerializable { - - public static final Supplier CONSTRUCTOR = new Constructor(); - - private long time; - - public DoubleTsEdge() { - - } - - public DoubleTsEdge(double srcId, double targetId) { - super(srcId, targetId); - } - - public DoubleTsEdge(double srcId, double targetId, Row value) { - super(srcId, targetId, value); - } - - public void setTime(long time) { - this.time = time; - } - - public long getTime() { - return time; - } - - @Override - public Object getField(int i, IType type) { - switch (i) { - case EdgeType.SRC_ID_FIELD_POSITION: - return getSrcId(); - case EdgeType.TARGET_ID_FIELD_POSITION: - return getTargetId(); - case EdgeType.LABEL_FIELD_POSITION: - return getBinaryLabel(); - case EdgeType.TIME_FIELD_POSITION: - return time; - default: - return getValue().getField(i - 4, type); - } - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof DoubleTsEdge)) { - return false; - } - if (!super.equals(o)) { - return false; - } - DoubleTsEdge that = (DoubleTsEdge) o; - return time == that.time; - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), time); - } +public class DoubleTsEdge extends DoubleEdge + implements IGraphElementWithTimeField, KryoSerializable { - @Override - public DoubleTsEdge reverse() { - DoubleTsEdge edge = new DoubleTsEdge((double) getTargetId(), (double) getSrcId(), - getValue()); - edge.setBinaryLabel(getBinaryLabel()); - edge.setDirect(getDirect()); - edge.setTime(time); - return edge; - } + public static final Supplier CONSTRUCTOR = new Constructor(); - @Override - public DoubleTsEdge withValue(Row value) { - DoubleTsEdge edge = new DoubleTsEdge((double) getSrcId(), (double) getTargetId(), value); - edge.setBinaryLabel(getBinaryLabel()); - edge.setDirect(getDirect()); - edge.setTime(time); - return edge; - } + private long time; - @Override - public RowEdge withDirection(EdgeDirection direction) { - DoubleTsEdge edge = new DoubleTsEdge((double) getSrcId(), (double) getTargetId(), - getValue()); - edge.setBinaryLabel(getBinaryLabel()); - edge.setDirect(direction); - edge.setTime(time); - return edge; - } + public DoubleTsEdge() {} - @Override - public RowEdge identityReverse() { - DoubleTsEdge edge = new DoubleTsEdge((double) getTargetId(), (double) getSrcId(), - getValue()); - edge.setBinaryLabel(getBinaryLabel()); - edge.setDirect(getDirect().reverse()); - edge.setTime(time); - return edge; - } + public DoubleTsEdge(double srcId, double targetId) { + super(srcId, targetId); + } - @Override - public String toString() { - return getSrcId() + "#" + getTargetId() + "#" + getBinaryLabel() + "#" + getDirect() - + "#" + time + "#" + getValue(); - } + public DoubleTsEdge(double srcId, double targetId, Row value) { + super(srcId, targetId, value); + } - private static class Constructor implements Supplier { + public void setTime(long time) { + this.time = time; + } - @Override - public DoubleTsEdge get() { - return new DoubleTsEdge(); - } - } + public long getTime() { + return time; + } - @Override - public void write(Kryo kryo, Output output) { - // serialize fields - output.writeDouble((Double) this.getSrcId()); - output.writeDouble((Double) this.getTargetId()); - kryo.writeClassAndObject(output, this.getDirect()); - byte[] labelBytes = this.getBinaryLabel().getBytes(); - output.writeInt(labelBytes.length); - output.writeBytes(labelBytes); - output.writeLong(this.getTime()); - // serialize value - kryo.writeClassAndObject(output, this.getValue()); - } + @Override + public Object getField(int i, IType type) { + switch (i) { + case EdgeType.SRC_ID_FIELD_POSITION: + return getSrcId(); + case EdgeType.TARGET_ID_FIELD_POSITION: + return getTargetId(); + case EdgeType.LABEL_FIELD_POSITION: + return getBinaryLabel(); + case EdgeType.TIME_FIELD_POSITION: + return time; + default: + return getValue().getField(i - 4, type); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof DoubleTsEdge)) { + return false; + } + if (!super.equals(o)) { + return false; + } + DoubleTsEdge that = (DoubleTsEdge) o; + return time == that.time; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), time); + } + + @Override + public DoubleTsEdge reverse() { + DoubleTsEdge edge = new DoubleTsEdge((double) getTargetId(), (double) getSrcId(), getValue()); + edge.setBinaryLabel(getBinaryLabel()); + edge.setDirect(getDirect()); + edge.setTime(time); + return edge; + } + + @Override + public DoubleTsEdge withValue(Row value) { + DoubleTsEdge edge = new DoubleTsEdge((double) getSrcId(), (double) getTargetId(), value); + edge.setBinaryLabel(getBinaryLabel()); + edge.setDirect(getDirect()); + edge.setTime(time); + return edge; + } + + @Override + public RowEdge withDirection(EdgeDirection direction) { + DoubleTsEdge edge = new DoubleTsEdge((double) getSrcId(), (double) getTargetId(), getValue()); + edge.setBinaryLabel(getBinaryLabel()); + edge.setDirect(direction); + edge.setTime(time); + return edge; + } + + @Override + public RowEdge identityReverse() { + DoubleTsEdge edge = new DoubleTsEdge((double) getTargetId(), (double) getSrcId(), getValue()); + edge.setBinaryLabel(getBinaryLabel()); + edge.setDirect(getDirect().reverse()); + edge.setTime(time); + return edge; + } + + @Override + public String toString() { + return getSrcId() + + "#" + + getTargetId() + + "#" + + getBinaryLabel() + + "#" + + getDirect() + + "#" + + time + + "#" + + getValue(); + } + + private static class Constructor implements Supplier { @Override - public void read(Kryo kryo, Input input) { - // deserialize fields - double srcId = input.readDouble(); - this.srcId = srcId; - double targetId = input.readDouble(); - this.targetId = targetId; - EdgeDirection direction = (EdgeDirection) kryo.readClassAndObject(input); - this.setDirect(direction); - int labelLength = input.readInt(); - BinaryString label = BinaryString.fromBytes(input.readBytes(labelLength)); - this.setBinaryLabel(label); - long time = input.readLong(); - this.setTime(time); - // deserialize value - Row value = (Row) kryo.readClassAndObject(input); - // create edge object - this.setValue(value); - } - + public DoubleTsEdge get() { + return new DoubleTsEdge(); + } + } + + @Override + public void write(Kryo kryo, Output output) { + // serialize fields + output.writeDouble((Double) this.getSrcId()); + output.writeDouble((Double) this.getTargetId()); + kryo.writeClassAndObject(output, this.getDirect()); + byte[] labelBytes = this.getBinaryLabel().getBytes(); + output.writeInt(labelBytes.length); + output.writeBytes(labelBytes); + output.writeLong(this.getTime()); + // serialize value + kryo.writeClassAndObject(output, this.getValue()); + } + + @Override + public void read(Kryo kryo, Input input) { + // deserialize fields + double srcId = input.readDouble(); + this.srcId = srcId; + double targetId = input.readDouble(); + this.targetId = targetId; + EdgeDirection direction = (EdgeDirection) kryo.readClassAndObject(input); + this.setDirect(direction); + int labelLength = input.readInt(); + BinaryString label = BinaryString.fromBytes(input.readBytes(labelLength)); + this.setBinaryLabel(label); + long time = input.readLong(); + this.setTime(time); + // deserialize value + Row value = (Row) kryo.readClassAndObject(input); + // create edge object + this.setValue(value); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/DoubleVertex.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/DoubleVertex.java index 3027bcfa8..393eba542 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/DoubleVertex.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/DoubleVertex.java @@ -19,12 +19,9 @@ package org.apache.geaflow.dsl.common.data.impl.types; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Objects; import java.util.function.Supplier; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -34,161 +31,164 @@ import org.apache.geaflow.dsl.common.util.BinaryUtil; import org.apache.geaflow.model.graph.vertex.IVertex; -public class DoubleVertex implements RowVertex, KryoSerializable { - - public static final Supplier CONSTRUCTOR = new Constructor(); - - public double id; - - private BinaryString label; - - private Row value; - - public DoubleVertex() { - - } - - public DoubleVertex(double id) { - this.id = id; - } - - public DoubleVertex(double id, BinaryString label, Row value) { - this.id = id; - this.label = label; - this.value = value; - } - - @Override - public String getLabel() { - return label.toString(); - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; - @Override - public void setLabel(String label) { - this.label = (BinaryString) BinaryUtil.toBinaryLabel(label); - } +public class DoubleVertex implements RowVertex, KryoSerializable { - @Override - public Object getId() { + public static final Supplier CONSTRUCTOR = new Constructor(); + + public double id; + + private BinaryString label; + + private Row value; + + public DoubleVertex() {} + + public DoubleVertex(double id) { + this.id = id; + } + + public DoubleVertex(double id, BinaryString label, Row value) { + this.id = id; + this.label = label; + this.value = value; + } + + @Override + public String getLabel() { + return label.toString(); + } + + @Override + public void setLabel(String label) { + this.label = (BinaryString) BinaryUtil.toBinaryLabel(label); + } + + @Override + public Object getId() { + return id; + } + + @Override + public void setId(Object id) { + this.id = (double) Objects.requireNonNull(id); + } + + @Override + public Row getValue() { + return value; + } + + @Override + public DoubleVertex withValue(Row value) { + return new DoubleVertex(id, label, value); + } + + @Override + public DoubleVertex withLabel(String label) { + return new DoubleVertex(id, (BinaryString) BinaryUtil.toBinaryLabel(label), value); + } + + @Override + public IVertex withTime(long time) { + throw new GeaFlowDSLException("Vertex not support timestamp"); + } + + @SuppressWarnings("unchecked") + @Override + public int compareTo(Object o) { + RowVertex vertex = (RowVertex) o; + return ((Comparable) id).compareTo(vertex.getId()); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof RowVertex)) { + return false; + } + if (o instanceof DoubleVertex) { + DoubleVertex that = (DoubleVertex) o; + return Double.compare(id, that.id) == 0 && Objects.equals(label, that.getBinaryLabel()); + } else { + RowVertex that = (RowVertex) o; + return that.equals(this); + } + } + + @Override + public int hashCode() { + return Objects.hash(id, label); + } + + @Override + public Object getField(int i, IType type) { + switch (i) { + case VertexType.ID_FIELD_POSITION: return id; - } - - @Override - public void setId(Object id) { - this.id = (double) Objects.requireNonNull(id); - } - - @Override - public Row getValue() { - return value; - } - - @Override - public DoubleVertex withValue(Row value) { - return new DoubleVertex(id, label, value); - } - - @Override - public DoubleVertex withLabel(String label) { - return new DoubleVertex(id, (BinaryString) BinaryUtil.toBinaryLabel(label), value); - } - - @Override - public IVertex withTime(long time) { - throw new GeaFlowDSLException("Vertex not support timestamp"); - } - - @SuppressWarnings("unchecked") - @Override - public int compareTo(Object o) { - RowVertex vertex = (RowVertex) o; - return ((Comparable) id).compareTo(vertex.getId()); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof RowVertex)) { - return false; - } - if (o instanceof DoubleVertex) { - DoubleVertex that = (DoubleVertex) o; - return Double.compare(id, that.id) == 0 && Objects.equals(label, that.getBinaryLabel()); - } else { - RowVertex that = (RowVertex) o; - return that.equals(this); - } - } - - @Override - public int hashCode() { - return Objects.hash(id, label); - } - - @Override - public Object getField(int i, IType type) { - switch (i) { - case VertexType.ID_FIELD_POSITION: - return id; - case VertexType.LABEL_FIELD_POSITION: - return label; - default: - return value.getField(i - 2, type); - } - } - - @Override - public void setValue(Row value) { - this.value = value; - } - - @Override - public String toString() { - return id + "#" + label + "#" + value; - } - - @Override - public BinaryString getBinaryLabel() { + case VertexType.LABEL_FIELD_POSITION: return label; - } - - @Override - public void setBinaryLabel(BinaryString label) { - this.label = label; - } - - private static class Constructor implements Supplier { - - @Override - public DoubleVertex get() { - return new DoubleVertex(); - } - } - - @Override - public void write(Kryo kryo, Output output) { - // serialize fields - output.writeDouble(this.id); - byte[] labelBytes = this.getBinaryLabel().getBytes(); - output.writeInt(labelBytes.length); - output.writeBytes(labelBytes); - // serialize value - kryo.writeClassAndObject(output, this.getValue()); - } - - @Override - public void read(Kryo kryo, Input input) { - // deserialize fields - double id = input.readDouble(); - int labelLength = input.readInt(); - BinaryString label = BinaryString.fromBytes(input.readBytes(labelLength)); - // deserialize value - Row value = (Row) kryo.readClassAndObject(input); - // create vertex object - this.id = id; - this.setValue(value); - this.setBinaryLabel(label); - } + default: + return value.getField(i - 2, type); + } + } + + @Override + public void setValue(Row value) { + this.value = value; + } + + @Override + public String toString() { + return id + "#" + label + "#" + value; + } + + @Override + public BinaryString getBinaryLabel() { + return label; + } + + @Override + public void setBinaryLabel(BinaryString label) { + this.label = label; + } + + private static class Constructor implements Supplier { + + @Override + public DoubleVertex get() { + return new DoubleVertex(); + } + } + + @Override + public void write(Kryo kryo, Output output) { + // serialize fields + output.writeDouble(this.id); + byte[] labelBytes = this.getBinaryLabel().getBytes(); + output.writeInt(labelBytes.length); + output.writeBytes(labelBytes); + // serialize value + kryo.writeClassAndObject(output, this.getValue()); + } + + @Override + public void read(Kryo kryo, Input input) { + // deserialize fields + double id = input.readDouble(); + int labelLength = input.readInt(); + BinaryString label = BinaryString.fromBytes(input.readBytes(labelLength)); + // deserialize value + Row value = (Row) kryo.readClassAndObject(input); + // create vertex object + this.id = id; + this.setValue(value); + this.setBinaryLabel(label); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/IntEdge.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/IntEdge.java index 7183030e0..991892349 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/IntEdge.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/IntEdge.java @@ -19,12 +19,9 @@ package org.apache.geaflow.dsl.common.data.impl.types; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Objects; import java.util.function.Supplier; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -33,207 +30,210 @@ import org.apache.geaflow.dsl.common.util.BinaryUtil; import org.apache.geaflow.model.graph.edge.EdgeDirection; -public class IntEdge implements RowEdge, KryoSerializable { - - public static final Supplier CONSTRUCTOR = new Constructor(); - - public int srcId; - - public int targetId; - - public EdgeDirection direction = EdgeDirection.OUT; - - private BinaryString label; - - private Row value; - - public IntEdge() { - - } - - public IntEdge(int srcId, int targetId) { - this.srcId = srcId; - this.targetId = targetId; - } - - public IntEdge(int srcId, int targetId, Row value) { - this.srcId = srcId; - this.targetId = targetId; - this.value = value; - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; - @Override - public void setLabel(String label) { - this.label = (BinaryString) BinaryUtil.toBinaryLabel(label); - } +public class IntEdge implements RowEdge, KryoSerializable { - @Override - public Object getSrcId() { + public static final Supplier CONSTRUCTOR = new Constructor(); + + public int srcId; + + public int targetId; + + public EdgeDirection direction = EdgeDirection.OUT; + + private BinaryString label; + + private Row value; + + public IntEdge() {} + + public IntEdge(int srcId, int targetId) { + this.srcId = srcId; + this.targetId = targetId; + } + + public IntEdge(int srcId, int targetId, Row value) { + this.srcId = srcId; + this.targetId = targetId; + this.value = value; + } + + @Override + public void setLabel(String label) { + this.label = (BinaryString) BinaryUtil.toBinaryLabel(label); + } + + @Override + public Object getSrcId() { + return srcId; + } + + @Override + public void setSrcId(Object srcId) { + this.srcId = (int) srcId; + } + + @Override + public Object getTargetId() { + return targetId; + } + + @Override + public void setTargetId(Object targetId) { + this.targetId = (int) targetId; + } + + @Override + public String getLabel() { + return label.toString(); + } + + @Override + public EdgeDirection getDirect() { + return direction; + } + + @Override + public void setDirect(EdgeDirection direction) { + this.direction = direction; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof RowEdge)) { + return false; + } + if (o instanceof IntEdge) { + IntEdge that = (IntEdge) o; + return srcId == that.srcId + && targetId == that.targetId + && direction == that.direction + && Objects.equals(label, that.label); + } else { + RowEdge that = (RowEdge) o; + return that.equals(this); + } + } + + @Override + public int hashCode() { + return Objects.hash(srcId, targetId, direction, label); + } + + @Override + public IntEdge reverse() { + IntEdge edge = new IntEdge(targetId, srcId, value); + edge.setBinaryLabel(label); + edge.setDirect(direction); + return edge; + } + + @Override + public Row getValue() { + return value; + } + + @Override + public IntEdge withValue(Row value) { + IntEdge edge = new IntEdge(srcId, targetId, value); + edge.setBinaryLabel(label); + edge.setDirect(direction); + return edge; + } + + @Override + public String toString() { + return srcId + "#" + targetId + "#" + label + "#" + direction + "#" + value; + } + + @Override + public Object getField(int i, IType type) { + switch (i) { + case EdgeType.SRC_ID_FIELD_POSITION: return srcId; - } - - @Override - public void setSrcId(Object srcId) { - this.srcId = (int) srcId; - } - - @Override - public Object getTargetId() { + case EdgeType.TARGET_ID_FIELD_POSITION: return targetId; - } - - @Override - public void setTargetId(Object targetId) { - this.targetId = (int) targetId; - } - - @Override - public String getLabel() { - return label.toString(); - } - - @Override - public EdgeDirection getDirect() { - return direction; - } - - @Override - public void setDirect(EdgeDirection direction) { - this.direction = direction; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof RowEdge)) { - return false; - } - if (o instanceof IntEdge) { - IntEdge that = (IntEdge) o; - return srcId == that.srcId && targetId == that.targetId - && direction == that.direction && Objects.equals(label, that.label); - } else { - RowEdge that = (RowEdge) o; - return that.equals(this); - } - } - - @Override - public int hashCode() { - return Objects.hash(srcId, targetId, direction, label); - } - - @Override - public IntEdge reverse() { - IntEdge edge = new IntEdge(targetId, srcId, value); - edge.setBinaryLabel(label); - edge.setDirect(direction); - return edge; - } - - @Override - public Row getValue() { - return value; - } - - @Override - public IntEdge withValue(Row value) { - IntEdge edge = new IntEdge(srcId, targetId, value); - edge.setBinaryLabel(label); - edge.setDirect(direction); - return edge; - } - - @Override - public String toString() { - return srcId + "#" + targetId + "#" + label + "#" + direction + "#" + value; - } - - @Override - public Object getField(int i, IType type) { - switch (i) { - case EdgeType.SRC_ID_FIELD_POSITION: - return srcId; - case EdgeType.TARGET_ID_FIELD_POSITION: - return targetId; - case EdgeType.LABEL_FIELD_POSITION: - return label; - default: - return value.getField(i - 3, type); - } - } - - @Override - public void setValue(Row value) { - this.value = value; - } - - @Override - public RowEdge withDirection(EdgeDirection direction) { - IntEdge edge = new IntEdge(srcId, targetId, value); - edge.setBinaryLabel(label); - edge.setDirect(direction); - return edge; - } - - @Override - public RowEdge identityReverse() { - IntEdge edge = new IntEdge(targetId, srcId, value); - edge.setBinaryLabel(label); - edge.setDirect(direction.reverse()); - return edge; - } - - @Override - public BinaryString getBinaryLabel() { + case EdgeType.LABEL_FIELD_POSITION: return label; - } - - @Override - public void setBinaryLabel(BinaryString label) { - this.label = label; - } - - private static class Constructor implements Supplier { - - @Override - public IntEdge get() { - return new IntEdge(); - } - } - - @Override - public void write(Kryo kryo, Output output) { - // serialize fields - output.writeInt(this.srcId); - output.writeInt(this.targetId); - kryo.writeClassAndObject(output, this.getDirect()); - byte[] labelBytes = this.getBinaryLabel().getBytes(); - output.writeInt(labelBytes.length); - output.writeBytes(labelBytes); - // serialize value - kryo.writeClassAndObject(output, this.getValue()); - } - - @Override - public void read(Kryo kryo, Input input) { - // deserialize fields - int srcId = input.readInt(); - this.srcId = srcId; - int targetId = input.readInt(); - this.targetId = targetId; - EdgeDirection direction = (EdgeDirection) kryo.readClassAndObject(input); - this.setDirect(direction); - int labelLength = input.readInt(); - BinaryString label = BinaryString.fromBytes(input.readBytes(labelLength)); - this.setBinaryLabel(label); - // deserialize value - Row value = (Row) kryo.readClassAndObject(input); - // create edge object - this.setValue(value); - } - - + default: + return value.getField(i - 3, type); + } + } + + @Override + public void setValue(Row value) { + this.value = value; + } + + @Override + public RowEdge withDirection(EdgeDirection direction) { + IntEdge edge = new IntEdge(srcId, targetId, value); + edge.setBinaryLabel(label); + edge.setDirect(direction); + return edge; + } + + @Override + public RowEdge identityReverse() { + IntEdge edge = new IntEdge(targetId, srcId, value); + edge.setBinaryLabel(label); + edge.setDirect(direction.reverse()); + return edge; + } + + @Override + public BinaryString getBinaryLabel() { + return label; + } + + @Override + public void setBinaryLabel(BinaryString label) { + this.label = label; + } + + private static class Constructor implements Supplier { + + @Override + public IntEdge get() { + return new IntEdge(); + } + } + + @Override + public void write(Kryo kryo, Output output) { + // serialize fields + output.writeInt(this.srcId); + output.writeInt(this.targetId); + kryo.writeClassAndObject(output, this.getDirect()); + byte[] labelBytes = this.getBinaryLabel().getBytes(); + output.writeInt(labelBytes.length); + output.writeBytes(labelBytes); + // serialize value + kryo.writeClassAndObject(output, this.getValue()); + } + + @Override + public void read(Kryo kryo, Input input) { + // deserialize fields + int srcId = input.readInt(); + this.srcId = srcId; + int targetId = input.readInt(); + this.targetId = targetId; + EdgeDirection direction = (EdgeDirection) kryo.readClassAndObject(input); + this.setDirect(direction); + int labelLength = input.readInt(); + BinaryString label = BinaryString.fromBytes(input.readBytes(labelLength)); + this.setBinaryLabel(label); + // deserialize value + Row value = (Row) kryo.readClassAndObject(input); + // create edge object + this.setValue(value); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/IntTsEdge.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/IntTsEdge.java index 2741ddcf9..3ca16b475 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/IntTsEdge.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/IntTsEdge.java @@ -19,12 +19,9 @@ package org.apache.geaflow.dsl.common.data.impl.types; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Objects; import java.util.function.Supplier; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -33,151 +30,160 @@ import org.apache.geaflow.model.graph.IGraphElementWithTimeField; import org.apache.geaflow.model.graph.edge.EdgeDirection; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + public class IntTsEdge extends IntEdge implements IGraphElementWithTimeField, KryoSerializable { - public static final Supplier CONSTRUCTOR = new Constructor(); + public static final Supplier CONSTRUCTOR = new Constructor(); - private long time; + private long time; - public IntTsEdge() { + public IntTsEdge() {} - } + public IntTsEdge(int srcId, int targetId) { + super(srcId, targetId); + } - public IntTsEdge(int srcId, int targetId) { - super(srcId, targetId); - } + public IntTsEdge(int srcId, int targetId, Row value) { + super(srcId, targetId, value); + } - public IntTsEdge(int srcId, int targetId, Row value) { - super(srcId, targetId, value); - } + public void setTime(long time) { + this.time = time; + } - public void setTime(long time) { - this.time = time; - } + public long getTime() { + return time; + } - public long getTime() { + @Override + public Object getField(int i, IType type) { + switch (i) { + case EdgeType.SRC_ID_FIELD_POSITION: + return getSrcId(); + case EdgeType.TARGET_ID_FIELD_POSITION: + return getTargetId(); + case EdgeType.LABEL_FIELD_POSITION: + return getBinaryLabel(); + case EdgeType.TIME_FIELD_POSITION: return time; - } - - @Override - public Object getField(int i, IType type) { - switch (i) { - case EdgeType.SRC_ID_FIELD_POSITION: - return getSrcId(); - case EdgeType.TARGET_ID_FIELD_POSITION: - return getTargetId(); - case EdgeType.LABEL_FIELD_POSITION: - return getBinaryLabel(); - case EdgeType.TIME_FIELD_POSITION: - return time; - default: - return getValue().getField(i - 4, type); - } - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof IntTsEdge)) { - return false; - } - if (!super.equals(o)) { - return false; - } - IntTsEdge that = (IntTsEdge) o; - return time == that.time; - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), time); - } + default: + return getValue().getField(i - 4, type); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof IntTsEdge)) { + return false; + } + if (!super.equals(o)) { + return false; + } + IntTsEdge that = (IntTsEdge) o; + return time == that.time; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), time); + } + + @Override + public IntTsEdge reverse() { + IntTsEdge edge = new IntTsEdge((int) getTargetId(), (int) getSrcId(), getValue()); + edge.setBinaryLabel(getBinaryLabel()); + edge.setDirect(getDirect()); + edge.setTime(time); + return edge; + } + + @Override + public IntTsEdge withValue(Row value) { + IntTsEdge edge = new IntTsEdge((int) getSrcId(), (int) getTargetId(), value); + edge.setBinaryLabel(getBinaryLabel()); + edge.setDirect(getDirect()); + edge.setTime(time); + return edge; + } + + @Override + public RowEdge withDirection(EdgeDirection direction) { + IntTsEdge edge = new IntTsEdge((int) getSrcId(), (int) getTargetId(), getValue()); + edge.setBinaryLabel(getBinaryLabel()); + edge.setDirect(direction); + edge.setTime(time); + return edge; + } + + @Override + public RowEdge identityReverse() { + IntTsEdge edge = new IntTsEdge((int) getTargetId(), (int) getSrcId(), getValue()); + edge.setBinaryLabel(getBinaryLabel()); + edge.setDirect(getDirect().reverse()); + edge.setTime(time); + return edge; + } + + @Override + public String toString() { + return getSrcId() + + "#" + + getTargetId() + + "#" + + getBinaryLabel() + + "#" + + getDirect() + + "#" + + time + + "#" + + getValue(); + } + + private static class Constructor implements Supplier { @Override - public IntTsEdge reverse() { - IntTsEdge edge = new IntTsEdge((int) getTargetId(), (int) getSrcId(), - getValue()); - edge.setBinaryLabel(getBinaryLabel()); - edge.setDirect(getDirect()); - edge.setTime(time); - return edge; - } - - @Override - public IntTsEdge withValue(Row value) { - IntTsEdge edge = new IntTsEdge((int) getSrcId(), (int) getTargetId(), value); - edge.setBinaryLabel(getBinaryLabel()); - edge.setDirect(getDirect()); - edge.setTime(time); - return edge; - } - - @Override - public RowEdge withDirection(EdgeDirection direction) { - IntTsEdge edge = new IntTsEdge((int) getSrcId(), (int) getTargetId(), getValue()); - edge.setBinaryLabel(getBinaryLabel()); - edge.setDirect(direction); - edge.setTime(time); - return edge; - } - - @Override - public RowEdge identityReverse() { - IntTsEdge edge = new IntTsEdge((int) getTargetId(), (int) getSrcId(), - getValue()); - edge.setBinaryLabel(getBinaryLabel()); - edge.setDirect(getDirect().reverse()); - edge.setTime(time); - return edge; - } - - @Override - public String toString() { - return getSrcId() + "#" + getTargetId() + "#" + getBinaryLabel() + "#" + getDirect() - + "#" + time + "#" + getValue(); - } - - private static class Constructor implements Supplier { - - @Override - public IntTsEdge get() { - return new IntTsEdge(); - } - } - - @Override - public void write(Kryo kryo, Output output) { - // serialize srcId, targetId, direction, label and time - output.writeInt((Integer) this.getSrcId()); - output.writeInt((Integer) this.getTargetId()); - kryo.writeClassAndObject(output, this.getDirect()); - byte[] labelBytes = this.getBinaryLabel().getBytes(); - output.writeInt(labelBytes.length); - output.writeBytes(labelBytes); - output.writeLong(this.getTime()); - // serialize value - kryo.writeClassAndObject(output, this.getValue()); - } - - @Override - public void read(Kryo kryo, Input input) { - // deserialize srcId, targetId, direction, label and time - int srcId = input.readInt(); - this.srcId = srcId; - int targetId = input.readInt(); - this.targetId = targetId; - EdgeDirection direction = (EdgeDirection) kryo.readClassAndObject(input); - this.setDirect(direction); - int labelLength = input.readInt(); - BinaryString label = BinaryString.fromBytes(input.readBytes(labelLength)); - this.setBinaryLabel(label); - long time = input.readLong(); - this.setTime(time); - // deserialize value - Row value = (Row) kryo.readClassAndObject(input); - this.setValue(value); - } - + public IntTsEdge get() { + return new IntTsEdge(); + } + } + + @Override + public void write(Kryo kryo, Output output) { + // serialize srcId, targetId, direction, label and time + output.writeInt((Integer) this.getSrcId()); + output.writeInt((Integer) this.getTargetId()); + kryo.writeClassAndObject(output, this.getDirect()); + byte[] labelBytes = this.getBinaryLabel().getBytes(); + output.writeInt(labelBytes.length); + output.writeBytes(labelBytes); + output.writeLong(this.getTime()); + // serialize value + kryo.writeClassAndObject(output, this.getValue()); + } + + @Override + public void read(Kryo kryo, Input input) { + // deserialize srcId, targetId, direction, label and time + int srcId = input.readInt(); + this.srcId = srcId; + int targetId = input.readInt(); + this.targetId = targetId; + EdgeDirection direction = (EdgeDirection) kryo.readClassAndObject(input); + this.setDirect(direction); + int labelLength = input.readInt(); + BinaryString label = BinaryString.fromBytes(input.readBytes(labelLength)); + this.setBinaryLabel(label); + long time = input.readLong(); + this.setTime(time); + // deserialize value + Row value = (Row) kryo.readClassAndObject(input); + this.setValue(value); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/IntVertex.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/IntVertex.java index dc3a38cdf..b5446caf7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/IntVertex.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/IntVertex.java @@ -19,12 +19,9 @@ package org.apache.geaflow.dsl.common.data.impl.types; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Objects; import java.util.function.Supplier; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -34,159 +31,161 @@ import org.apache.geaflow.dsl.common.util.BinaryUtil; import org.apache.geaflow.model.graph.vertex.IVertex; -public class IntVertex implements RowVertex, KryoSerializable { - - public static final Supplier CONSTRUCTOR = new Constructor(); - - public int id; - - private BinaryString label; - - private Row value; - - public IntVertex() { - - } - - public IntVertex(int id) { - this.id = id; - } - - public IntVertex(int id, BinaryString label, Row value) { - this.id = id; - this.label = label; - this.value = value; - } - - @Override - public String getLabel() { - return label.toString(); - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; - @Override - public void setLabel(String label) { - this.label = (BinaryString) BinaryUtil.toBinaryLabel(label); - } +public class IntVertex implements RowVertex, KryoSerializable { - @Override - public Object getId() { + public static final Supplier CONSTRUCTOR = new Constructor(); + + public int id; + + private BinaryString label; + + private Row value; + + public IntVertex() {} + + public IntVertex(int id) { + this.id = id; + } + + public IntVertex(int id, BinaryString label, Row value) { + this.id = id; + this.label = label; + this.value = value; + } + + @Override + public String getLabel() { + return label.toString(); + } + + @Override + public void setLabel(String label) { + this.label = (BinaryString) BinaryUtil.toBinaryLabel(label); + } + + @Override + public Object getId() { + return id; + } + + @Override + public void setId(Object id) { + this.id = (int) Objects.requireNonNull(id); + } + + @Override + public Row getValue() { + return value; + } + + @Override + public IntVertex withValue(Row value) { + return new IntVertex(id, label, value); + } + + @Override + public IntVertex withLabel(String label) { + return new IntVertex(id, (BinaryString) BinaryUtil.toBinaryLabel(label), value); + } + + @Override + public IVertex withTime(long time) { + throw new GeaFlowDSLException("Vertex not support timestamp"); + } + + @SuppressWarnings("unchecked") + @Override + public int compareTo(Object o) { + RowVertex vertex = (RowVertex) o; + return ((Comparable) id).compareTo(vertex.getId()); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof RowVertex)) { + return false; + } + if (o instanceof IntVertex) { + IntVertex that = (IntVertex) o; + return id == that.id && Objects.equals(label, that.getBinaryLabel()); + } else { + RowVertex that = (RowVertex) o; + return that.equals(this); + } + } + + @Override + public int hashCode() { + return Objects.hash(id, label); + } + + @Override + public Object getField(int i, IType type) { + switch (i) { + case VertexType.ID_FIELD_POSITION: return id; - } - - @Override - public void setId(Object id) { - this.id = (int) Objects.requireNonNull(id); - } - - @Override - public Row getValue() { - return value; - } - - @Override - public IntVertex withValue(Row value) { - return new IntVertex(id, label, value); - } - - @Override - public IntVertex withLabel(String label) { - return new IntVertex(id, (BinaryString) BinaryUtil.toBinaryLabel(label), value); - } - - @Override - public IVertex withTime(long time) { - throw new GeaFlowDSLException("Vertex not support timestamp"); - } - - @SuppressWarnings("unchecked") - @Override - public int compareTo(Object o) { - RowVertex vertex = (RowVertex) o; - return ((Comparable) id).compareTo(vertex.getId()); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof RowVertex)) { - return false; - } - if (o instanceof IntVertex) { - IntVertex that = (IntVertex) o; - return id == that.id && Objects.equals(label, that.getBinaryLabel()); - } else { - RowVertex that = (RowVertex) o; - return that.equals(this); - } - } - - @Override - public int hashCode() { - return Objects.hash(id, label); - } - - @Override - public Object getField(int i, IType type) { - switch (i) { - case VertexType.ID_FIELD_POSITION: - return id; - case VertexType.LABEL_FIELD_POSITION: - return label; - default: - return value.getField(i - 2, type); - } - } - - @Override - public void setValue(Row value) { - this.value = value; - } - - @Override - public String toString() { - return id + "#" + label + "#" + value; - } - - @Override - public BinaryString getBinaryLabel() { + case VertexType.LABEL_FIELD_POSITION: return label; - } - - @Override - public void setBinaryLabel(BinaryString label) { - this.label = label; - } - - private static class Constructor implements Supplier { - - @Override - public IntVertex get() { - return new IntVertex(); - } - } - - @Override - public void write(Kryo kryo, Output output) { - // serialize id, label and value - output.writeInt(this.id); - byte[] labelBytes = this.getBinaryLabel().getBytes(); - output.writeInt(labelBytes.length); - output.writeBytes(labelBytes); - kryo.writeClassAndObject(output, this.getValue()); - } - - @Override - public void read(Kryo kryo, Input input) { - // deserialize id, label and value - int id = input.readInt(); - int labelLength = input.readInt(); - BinaryString label = BinaryString.fromBytes(input.readBytes(labelLength)); - Row value = (Row) kryo.readClassAndObject(input); - this.id = id; - this.setValue(value); - this.setBinaryLabel(label); - } - + default: + return value.getField(i - 2, type); + } + } + + @Override + public void setValue(Row value) { + this.value = value; + } + + @Override + public String toString() { + return id + "#" + label + "#" + value; + } + + @Override + public BinaryString getBinaryLabel() { + return label; + } + + @Override + public void setBinaryLabel(BinaryString label) { + this.label = label; + } + + private static class Constructor implements Supplier { + + @Override + public IntVertex get() { + return new IntVertex(); + } + } + + @Override + public void write(Kryo kryo, Output output) { + // serialize id, label and value + output.writeInt(this.id); + byte[] labelBytes = this.getBinaryLabel().getBytes(); + output.writeInt(labelBytes.length); + output.writeBytes(labelBytes); + kryo.writeClassAndObject(output, this.getValue()); + } + + @Override + public void read(Kryo kryo, Input input) { + // deserialize id, label and value + int id = input.readInt(); + int labelLength = input.readInt(); + BinaryString label = BinaryString.fromBytes(input.readBytes(labelLength)); + Row value = (Row) kryo.readClassAndObject(input); + this.id = id; + this.setValue(value); + this.setBinaryLabel(label); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/LongEdge.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/LongEdge.java index f8c075d2a..06d952c16 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/LongEdge.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/LongEdge.java @@ -19,12 +19,9 @@ package org.apache.geaflow.dsl.common.data.impl.types; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Objects; import java.util.function.Supplier; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -33,203 +30,207 @@ import org.apache.geaflow.dsl.common.util.BinaryUtil; import org.apache.geaflow.model.graph.edge.EdgeDirection; -public class LongEdge implements RowEdge, KryoSerializable { - - public static final Supplier CONSTRUCTOR = new Constructor(); - - public long srcId; - - public long targetId; - - public EdgeDirection direction = EdgeDirection.OUT; - - private BinaryString label; - - private Row value; - - public LongEdge() { - - } - - public LongEdge(long srcId, long targetId) { - this.srcId = srcId; - this.targetId = targetId; - } - - public LongEdge(long srcId, long targetId, Row value) { - this.srcId = srcId; - this.targetId = targetId; - this.value = value; - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; - @Override - public void setLabel(String label) { - this.label = (BinaryString) BinaryUtil.toBinaryLabel(label); - } +public class LongEdge implements RowEdge, KryoSerializable { - @Override - public Object getSrcId() { + public static final Supplier CONSTRUCTOR = new Constructor(); + + public long srcId; + + public long targetId; + + public EdgeDirection direction = EdgeDirection.OUT; + + private BinaryString label; + + private Row value; + + public LongEdge() {} + + public LongEdge(long srcId, long targetId) { + this.srcId = srcId; + this.targetId = targetId; + } + + public LongEdge(long srcId, long targetId, Row value) { + this.srcId = srcId; + this.targetId = targetId; + this.value = value; + } + + @Override + public void setLabel(String label) { + this.label = (BinaryString) BinaryUtil.toBinaryLabel(label); + } + + @Override + public Object getSrcId() { + return srcId; + } + + @Override + public void setSrcId(Object srcId) { + this.srcId = (long) srcId; + } + + @Override + public Object getTargetId() { + return targetId; + } + + @Override + public void setTargetId(Object targetId) { + this.targetId = (long) targetId; + } + + @Override + public String getLabel() { + return label.toString(); + } + + @Override + public EdgeDirection getDirect() { + return direction; + } + + @Override + public void setDirect(EdgeDirection direction) { + this.direction = direction; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof RowEdge)) { + return false; + } + if (o instanceof LongEdge) { + LongEdge that = (LongEdge) o; + return srcId == that.srcId + && targetId == that.targetId + && direction == that.direction + && Objects.equals(label, that.label); + } else { + RowEdge that = (RowEdge) o; + return that.equals(this); + } + } + + @Override + public int hashCode() { + return Objects.hash(srcId, targetId, direction, label); + } + + @Override + public LongEdge reverse() { + LongEdge edge = new LongEdge(targetId, srcId, value); + edge.setBinaryLabel(label); + edge.setDirect(direction); + return edge; + } + + @Override + public Row getValue() { + return value; + } + + @Override + public LongEdge withValue(Row value) { + LongEdge edge = new LongEdge(srcId, targetId, value); + edge.setBinaryLabel(label); + edge.setDirect(direction); + return edge; + } + + @Override + public String toString() { + return srcId + "#" + targetId + "#" + label + "#" + direction + "#" + value; + } + + @Override + public Object getField(int i, IType type) { + switch (i) { + case EdgeType.SRC_ID_FIELD_POSITION: return srcId; - } - - @Override - public void setSrcId(Object srcId) { - this.srcId = (long) srcId; - } - - @Override - public Object getTargetId() { + case EdgeType.TARGET_ID_FIELD_POSITION: return targetId; - } - - @Override - public void setTargetId(Object targetId) { - this.targetId = (long) targetId; - } - - @Override - public String getLabel() { - return label.toString(); - } - - @Override - public EdgeDirection getDirect() { - return direction; - } - - @Override - public void setDirect(EdgeDirection direction) { - this.direction = direction; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof RowEdge)) { - return false; - } - if (o instanceof LongEdge) { - LongEdge that = (LongEdge) o; - return srcId == that.srcId && targetId == that.targetId - && direction == that.direction && Objects.equals(label, that.label); - } else { - RowEdge that = (RowEdge) o; - return that.equals(this); - } - } - - @Override - public int hashCode() { - return Objects.hash(srcId, targetId, direction, label); - } - - @Override - public LongEdge reverse() { - LongEdge edge = new LongEdge(targetId, srcId, value); - edge.setBinaryLabel(label); - edge.setDirect(direction); - return edge; - } - - @Override - public Row getValue() { - return value; - } - - @Override - public LongEdge withValue(Row value) { - LongEdge edge = new LongEdge(srcId, targetId, value); - edge.setBinaryLabel(label); - edge.setDirect(direction); - return edge; - } - - @Override - public String toString() { - return srcId + "#" + targetId + "#" + label + "#" + direction + "#" + value; - } - - @Override - public Object getField(int i, IType type) { - switch (i) { - case EdgeType.SRC_ID_FIELD_POSITION: - return srcId; - case EdgeType.TARGET_ID_FIELD_POSITION: - return targetId; - case EdgeType.LABEL_FIELD_POSITION: - return label; - default: - return value.getField(i - 3, type); - } - } - - @Override - public void setValue(Row value) { - this.value = value; - } - - @Override - public RowEdge withDirection(EdgeDirection direction) { - LongEdge edge = new LongEdge(srcId, targetId, value); - edge.setBinaryLabel(label); - edge.setDirect(direction); - return edge; - } - - @Override - public RowEdge identityReverse() { - LongEdge edge = new LongEdge(targetId, srcId, value); - edge.setBinaryLabel(label); - edge.setDirect(direction.reverse()); - return edge; - } - - @Override - public BinaryString getBinaryLabel() { + case EdgeType.LABEL_FIELD_POSITION: return label; - } - - @Override - public void setBinaryLabel(BinaryString label) { - this.label = label; - } - - private static class Constructor implements Supplier { - - @Override - public LongEdge get() { - return new LongEdge(); - } - } - - @Override - public void write(Kryo kryo, Output output) { - // serialize srcId, targetId, direction, label and value - output.writeLong(this.srcId); - output.writeLong(this.targetId); - kryo.writeClassAndObject(output, this.getDirect()); - byte[] labelBytes = this.getBinaryLabel().getBytes(); - output.writeInt(labelBytes.length); - output.writeBytes(labelBytes); - kryo.writeClassAndObject(output, this.getValue()); - } - - @Override - public void read(Kryo kryo, Input input) { - // deserialize srcId, targetId, direction, label and value - long srcId = input.readLong(); - this.srcId = srcId; - long targetId = input.readLong(); - this.targetId = targetId; - EdgeDirection direction = (EdgeDirection) kryo.readClassAndObject(input); - this.setDirect(direction); - int labelLength = input.readInt(); - BinaryString label = BinaryString.fromBytes(input.readBytes(labelLength)); - this.setBinaryLabel(label); - Row value = (Row) kryo.readClassAndObject(input); - this.setValue(value); - } - + default: + return value.getField(i - 3, type); + } + } + + @Override + public void setValue(Row value) { + this.value = value; + } + + @Override + public RowEdge withDirection(EdgeDirection direction) { + LongEdge edge = new LongEdge(srcId, targetId, value); + edge.setBinaryLabel(label); + edge.setDirect(direction); + return edge; + } + + @Override + public RowEdge identityReverse() { + LongEdge edge = new LongEdge(targetId, srcId, value); + edge.setBinaryLabel(label); + edge.setDirect(direction.reverse()); + return edge; + } + + @Override + public BinaryString getBinaryLabel() { + return label; + } + + @Override + public void setBinaryLabel(BinaryString label) { + this.label = label; + } + + private static class Constructor implements Supplier { + + @Override + public LongEdge get() { + return new LongEdge(); + } + } + + @Override + public void write(Kryo kryo, Output output) { + // serialize srcId, targetId, direction, label and value + output.writeLong(this.srcId); + output.writeLong(this.targetId); + kryo.writeClassAndObject(output, this.getDirect()); + byte[] labelBytes = this.getBinaryLabel().getBytes(); + output.writeInt(labelBytes.length); + output.writeBytes(labelBytes); + kryo.writeClassAndObject(output, this.getValue()); + } + + @Override + public void read(Kryo kryo, Input input) { + // deserialize srcId, targetId, direction, label and value + long srcId = input.readLong(); + this.srcId = srcId; + long targetId = input.readLong(); + this.targetId = targetId; + EdgeDirection direction = (EdgeDirection) kryo.readClassAndObject(input); + this.setDirect(direction); + int labelLength = input.readInt(); + BinaryString label = BinaryString.fromBytes(input.readBytes(labelLength)); + this.setBinaryLabel(label); + Row value = (Row) kryo.readClassAndObject(input); + this.setValue(value); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/LongTsEdge.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/LongTsEdge.java index 5712a126b..ab7edc416 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/LongTsEdge.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/LongTsEdge.java @@ -19,12 +19,9 @@ package org.apache.geaflow.dsl.common.data.impl.types; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Objects; import java.util.function.Supplier; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -33,148 +30,160 @@ import org.apache.geaflow.model.graph.IGraphElementWithTimeField; import org.apache.geaflow.model.graph.edge.EdgeDirection; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + public class LongTsEdge extends LongEdge implements IGraphElementWithTimeField, KryoSerializable { - public static final Supplier CONSTRUCTOR = new Constructor(); + public static final Supplier CONSTRUCTOR = new Constructor(); - private long time; + private long time; - public LongTsEdge() { + public LongTsEdge() {} - } + public LongTsEdge(long srcId, long targetId) { + super(srcId, targetId); + } - public LongTsEdge(long srcId, long targetId) { - super(srcId, targetId); - } + public LongTsEdge(long srcId, long targetId, Row value) { + super(srcId, targetId, value); + } - public LongTsEdge(long srcId, long targetId, Row value) { - super(srcId, targetId, value); - } + public void setTime(long time) { + this.time = time; + } - public void setTime(long time) { - this.time = time; - } + public long getTime() { + return time; + } - public long getTime() { + @Override + public Object getField(int i, IType type) { + switch (i) { + case EdgeType.SRC_ID_FIELD_POSITION: + return getSrcId(); + case EdgeType.TARGET_ID_FIELD_POSITION: + return getTargetId(); + case EdgeType.LABEL_FIELD_POSITION: + return getBinaryLabel(); + case EdgeType.TIME_FIELD_POSITION: return time; - } - - @Override - public Object getField(int i, IType type) { - switch (i) { - case EdgeType.SRC_ID_FIELD_POSITION: - return getSrcId(); - case EdgeType.TARGET_ID_FIELD_POSITION: - return getTargetId(); - case EdgeType.LABEL_FIELD_POSITION: - return getBinaryLabel(); - case EdgeType.TIME_FIELD_POSITION: - return time; - default: - return getValue().getField(i - 4, type); - } - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof LongTsEdge)) { - return false; - } - if (!super.equals(o)) { - return false; - } - LongTsEdge that = (LongTsEdge) o; - return time == that.time; - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), time); - } - - @Override - public LongTsEdge reverse() { - LongTsEdge edge = new LongTsEdge((long) getTargetId(), (long) getSrcId(), getValue()); - edge.setBinaryLabel(getBinaryLabel()); - edge.setDirect(getDirect()); - edge.setTime(time); - return edge; - } - - @Override - public LongTsEdge withValue(Row value) { - LongTsEdge edge = new LongTsEdge((long) getSrcId(), (long) getTargetId(), value); - edge.setBinaryLabel(getBinaryLabel()); - edge.setDirect(getDirect()); - edge.setTime(time); - return edge; - } - - @Override - public RowEdge withDirection(EdgeDirection direction) { - LongTsEdge edge = new LongTsEdge((long) getSrcId(), (long) getTargetId(), getValue()); - edge.setBinaryLabel(getBinaryLabel()); - edge.setDirect(direction); - edge.setTime(time); - return edge; - } + default: + return getValue().getField(i - 4, type); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof LongTsEdge)) { + return false; + } + if (!super.equals(o)) { + return false; + } + LongTsEdge that = (LongTsEdge) o; + return time == that.time; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), time); + } + + @Override + public LongTsEdge reverse() { + LongTsEdge edge = new LongTsEdge((long) getTargetId(), (long) getSrcId(), getValue()); + edge.setBinaryLabel(getBinaryLabel()); + edge.setDirect(getDirect()); + edge.setTime(time); + return edge; + } + + @Override + public LongTsEdge withValue(Row value) { + LongTsEdge edge = new LongTsEdge((long) getSrcId(), (long) getTargetId(), value); + edge.setBinaryLabel(getBinaryLabel()); + edge.setDirect(getDirect()); + edge.setTime(time); + return edge; + } + + @Override + public RowEdge withDirection(EdgeDirection direction) { + LongTsEdge edge = new LongTsEdge((long) getSrcId(), (long) getTargetId(), getValue()); + edge.setBinaryLabel(getBinaryLabel()); + edge.setDirect(direction); + edge.setTime(time); + return edge; + } + + @Override + public RowEdge identityReverse() { + LongTsEdge edge = new LongTsEdge((long) getTargetId(), (long) getSrcId(), getValue()); + edge.setBinaryLabel(getBinaryLabel()); + edge.setDirect(getDirect().reverse()); + edge.setTime(time); + return edge; + } + + @Override + public String toString() { + return getSrcId() + + "#" + + getTargetId() + + "#" + + getBinaryLabel() + + "#" + + getDirect() + + "#" + + time + + "#" + + getValue(); + } + + private static class Constructor implements Supplier { @Override - public RowEdge identityReverse() { - LongTsEdge edge = new LongTsEdge((long) getTargetId(), (long) getSrcId(), getValue()); - edge.setBinaryLabel(getBinaryLabel()); - edge.setDirect(getDirect().reverse()); - edge.setTime(time); - return edge; - } - - @Override - public String toString() { - return getSrcId() + "#" + getTargetId() + "#" + getBinaryLabel() + "#" + getDirect() - + "#" + time + "#" + getValue(); - } - - private static class Constructor implements Supplier { - - @Override - public LongTsEdge get() { - return new LongTsEdge(); - } - } - - @Override - public void write(Kryo kryo, Output output) { - // serialize srcId, targetId, direction, label and time - output.writeLong((Long) this.getSrcId()); - output.writeLong((Long) this.getTargetId()); - kryo.writeClassAndObject(output, this.getDirect()); - byte[] labelBytes = this.getBinaryLabel().getBytes(); - output.writeInt(labelBytes.length); - output.writeBytes(labelBytes); - output.writeLong(this.getTime()); - // serialize value - kryo.writeClassAndObject(output, this.getValue()); - } - - @Override - public void read(Kryo kryo, Input input) { - // deserialize srcId, targetId, direction, label and time - long srcId = input.readLong(); - this.srcId = srcId; - long targetId = input.readLong(); - this.targetId = targetId; - EdgeDirection direction = (EdgeDirection) kryo.readClassAndObject(input); - this.setDirect(direction); - int labelLength = input.readInt(); - BinaryString label = BinaryString.fromBytes(input.readBytes(labelLength)); - this.setBinaryLabel(label); - long time = input.readLong(); - this.setTime(time); - // deserialize value - Row value = (Row) kryo.readClassAndObject(input); - this.setValue(value); - } + public LongTsEdge get() { + return new LongTsEdge(); + } + } + + @Override + public void write(Kryo kryo, Output output) { + // serialize srcId, targetId, direction, label and time + output.writeLong((Long) this.getSrcId()); + output.writeLong((Long) this.getTargetId()); + kryo.writeClassAndObject(output, this.getDirect()); + byte[] labelBytes = this.getBinaryLabel().getBytes(); + output.writeInt(labelBytes.length); + output.writeBytes(labelBytes); + output.writeLong(this.getTime()); + // serialize value + kryo.writeClassAndObject(output, this.getValue()); + } + + @Override + public void read(Kryo kryo, Input input) { + // deserialize srcId, targetId, direction, label and time + long srcId = input.readLong(); + this.srcId = srcId; + long targetId = input.readLong(); + this.targetId = targetId; + EdgeDirection direction = (EdgeDirection) kryo.readClassAndObject(input); + this.setDirect(direction); + int labelLength = input.readInt(); + BinaryString label = BinaryString.fromBytes(input.readBytes(labelLength)); + this.setBinaryLabel(label); + long time = input.readLong(); + this.setTime(time); + // deserialize value + Row value = (Row) kryo.readClassAndObject(input); + this.setValue(value); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/LongVertex.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/LongVertex.java index 31a8c56e2..c8f608bc3 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/LongVertex.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/LongVertex.java @@ -19,12 +19,9 @@ package org.apache.geaflow.dsl.common.data.impl.types; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Objects; import java.util.function.Supplier; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -34,156 +31,158 @@ import org.apache.geaflow.dsl.common.util.BinaryUtil; import org.apache.geaflow.model.graph.vertex.IVertex; -public class LongVertex implements RowVertex, KryoSerializable { - - public static final Supplier CONSTRUCTOR = new Constructor(); - - public long id; - - private BinaryString label; - - private Row value; - - public LongVertex() { - - } - - public LongVertex(long id) { - this.id = id; - } - - public LongVertex(long id, BinaryString label, Row value) { - this.id = id; - this.label = label; - this.value = value; - } - - @Override - public String getLabel() { - return label.toString(); - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; - @Override - public void setLabel(String label) { - this.label = (BinaryString) BinaryUtil.toBinaryLabel(label); - } +public class LongVertex implements RowVertex, KryoSerializable { - @Override - public Object getId() { + public static final Supplier CONSTRUCTOR = new Constructor(); + + public long id; + + private BinaryString label; + + private Row value; + + public LongVertex() {} + + public LongVertex(long id) { + this.id = id; + } + + public LongVertex(long id, BinaryString label, Row value) { + this.id = id; + this.label = label; + this.value = value; + } + + @Override + public String getLabel() { + return label.toString(); + } + + @Override + public void setLabel(String label) { + this.label = (BinaryString) BinaryUtil.toBinaryLabel(label); + } + + @Override + public Object getId() { + return id; + } + + @Override + public void setId(Object id) { + this.id = (long) Objects.requireNonNull(id); + } + + @Override + public Row getValue() { + return value; + } + + @Override + public LongVertex withValue(Row value) { + return new LongVertex(id, label, value); + } + + @Override + public LongVertex withLabel(String label) { + return new LongVertex(id, (BinaryString) BinaryUtil.toBinaryLabel(label), value); + } + + @Override + public IVertex withTime(long time) { + throw new GeaFlowDSLException("Vertex not support timestamp"); + } + + @SuppressWarnings("unchecked") + @Override + public int compareTo(Object o) { + RowVertex vertex = (RowVertex) o; + return ((Comparable) id).compareTo(vertex.getId()); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof RowVertex)) { + return false; + } + if (o instanceof LongVertex) { + LongVertex that = (LongVertex) o; + return id == that.id && Objects.equals(label, that.getBinaryLabel()); + } else { + RowVertex that = (RowVertex) o; + return that.equals(this); + } + } + + @Override + public int hashCode() { + return Objects.hash(id, label); + } + + @Override + public Object getField(int i, IType type) { + switch (i) { + case VertexType.ID_FIELD_POSITION: return id; - } - - @Override - public void setId(Object id) { - this.id = (long) Objects.requireNonNull(id); - } - - @Override - public Row getValue() { - return value; - } - - @Override - public LongVertex withValue(Row value) { - return new LongVertex(id, label, value); - } - - @Override - public LongVertex withLabel(String label) { - return new LongVertex(id, (BinaryString) BinaryUtil.toBinaryLabel(label), value); - } - - @Override - public IVertex withTime(long time) { - throw new GeaFlowDSLException("Vertex not support timestamp"); - } - - @SuppressWarnings("unchecked") - @Override - public int compareTo(Object o) { - RowVertex vertex = (RowVertex) o; - return ((Comparable) id).compareTo(vertex.getId()); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof RowVertex)) { - return false; - } - if (o instanceof LongVertex) { - LongVertex that = (LongVertex) o; - return id == that.id && Objects.equals(label, that.getBinaryLabel()); - } else { - RowVertex that = (RowVertex) o; - return that.equals(this); - } - } - - @Override - public int hashCode() { - return Objects.hash(id, label); - } - - @Override - public Object getField(int i, IType type) { - switch (i) { - case VertexType.ID_FIELD_POSITION: - return id; - case VertexType.LABEL_FIELD_POSITION: - return label; - default: - return value.getField(i - 2, type); - } - } - - @Override - public void setValue(Row value) { - this.value = value; - } - - @Override - public String toString() { - return id + "#" + label + "#" + value; - } - - @Override - public BinaryString getBinaryLabel() { + case VertexType.LABEL_FIELD_POSITION: return label; - } - - @Override - public void setBinaryLabel(BinaryString label) { - this.label = label; - } - - private static class Constructor implements Supplier { - - @Override - public LongVertex get() { - return new LongVertex(); - } - } - - @Override - public void write(Kryo kryo, Output output) { - // serialize id, label and value - output.writeLong(this.id); - kryo.writeClassAndObject(output, this.getBinaryLabel()); - kryo.writeClassAndObject(output, this.getValue()); - } - - @Override - public void read(Kryo kryo, Input input) { - // deserialize id, label and value - long id = input.readLong(); - BinaryString label = (BinaryString) kryo.readClassAndObject(input); - Row value = (Row) kryo.readClassAndObject(input); - this.id = id; - this.setValue(value); - this.setBinaryLabel(label); - } - + default: + return value.getField(i - 2, type); + } + } + + @Override + public void setValue(Row value) { + this.value = value; + } + + @Override + public String toString() { + return id + "#" + label + "#" + value; + } + + @Override + public BinaryString getBinaryLabel() { + return label; + } + + @Override + public void setBinaryLabel(BinaryString label) { + this.label = label; + } + + private static class Constructor implements Supplier { + + @Override + public LongVertex get() { + return new LongVertex(); + } + } + + @Override + public void write(Kryo kryo, Output output) { + // serialize id, label and value + output.writeLong(this.id); + kryo.writeClassAndObject(output, this.getBinaryLabel()); + kryo.writeClassAndObject(output, this.getValue()); + } + + @Override + public void read(Kryo kryo, Input input) { + // deserialize id, label and value + long id = input.readLong(); + BinaryString label = (BinaryString) kryo.readClassAndObject(input); + Row value = (Row) kryo.readClassAndObject(input); + this.id = id; + this.setValue(value); + this.setBinaryLabel(label); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/ObjectEdge.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/ObjectEdge.java index c00a41dc9..46ecec306 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/ObjectEdge.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/ObjectEdge.java @@ -19,12 +19,9 @@ package org.apache.geaflow.dsl.common.data.impl.types; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Objects; import java.util.function.Supplier; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -33,195 +30,199 @@ import org.apache.geaflow.dsl.common.util.BinaryUtil; import org.apache.geaflow.model.graph.edge.EdgeDirection; -public class ObjectEdge implements RowEdge, KryoSerializable { - - public static final Supplier CONSTRUCTOR = new Constructor(); - - private Object srcId; - - private Object targetId; - - private EdgeDirection direction = EdgeDirection.OUT; - - private BinaryString label; - - private Row value; - - public ObjectEdge() { - - } - - public ObjectEdge(Object srcId, Object targetId) { - this.srcId = srcId; - this.targetId = targetId; - } - - public ObjectEdge(Object srcId, Object targetId, Row value) { - this.srcId = srcId; - this.targetId = targetId; - this.value = value; - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; - @Override - public void setLabel(String label) { - this.label = (BinaryString) BinaryUtil.toBinaryLabel(label); - } +public class ObjectEdge implements RowEdge, KryoSerializable { - @Override - public Object getSrcId() { + public static final Supplier CONSTRUCTOR = new Constructor(); + + private Object srcId; + + private Object targetId; + + private EdgeDirection direction = EdgeDirection.OUT; + + private BinaryString label; + + private Row value; + + public ObjectEdge() {} + + public ObjectEdge(Object srcId, Object targetId) { + this.srcId = srcId; + this.targetId = targetId; + } + + public ObjectEdge(Object srcId, Object targetId, Row value) { + this.srcId = srcId; + this.targetId = targetId; + this.value = value; + } + + @Override + public void setLabel(String label) { + this.label = (BinaryString) BinaryUtil.toBinaryLabel(label); + } + + @Override + public Object getSrcId() { + return srcId; + } + + @Override + public void setSrcId(Object srcId) { + this.srcId = srcId; + } + + @Override + public Object getTargetId() { + return targetId; + } + + @Override + public void setTargetId(Object targetId) { + this.targetId = targetId; + } + + @Override + public String getLabel() { + return label.toString(); + } + + @Override + public EdgeDirection getDirect() { + return direction; + } + + @Override + public void setDirect(EdgeDirection direction) { + this.direction = direction; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ObjectEdge)) { + return false; + } + ObjectEdge that = (ObjectEdge) o; + return Objects.equals(srcId, that.srcId) + && Objects.equals(targetId, that.targetId) + && direction == that.direction + && Objects.equals(label, that.label); + } + + @Override + public int hashCode() { + return Objects.hash(srcId, targetId, direction, label); + } + + @Override + public ObjectEdge reverse() { + ObjectEdge edge = new ObjectEdge(targetId, srcId, value); + edge.setBinaryLabel(label); + edge.setDirect(direction); + return edge; + } + + @Override + public Row getValue() { + return value; + } + + @Override + public ObjectEdge withValue(Row value) { + ObjectEdge edge = new ObjectEdge(srcId, targetId, value); + edge.setBinaryLabel(label); + edge.setDirect(direction); + return edge; + } + + @Override + public String toString() { + return srcId + "#" + targetId + "#" + label + "#" + direction + "#" + value; + } + + @Override + public Object getField(int i, IType type) { + switch (i) { + case EdgeType.SRC_ID_FIELD_POSITION: return srcId; - } - - @Override - public void setSrcId(Object srcId) { - this.srcId = srcId; - } - - @Override - public Object getTargetId() { + case EdgeType.TARGET_ID_FIELD_POSITION: return targetId; - } - - @Override - public void setTargetId(Object targetId) { - this.targetId = targetId; - } - - @Override - public String getLabel() { - return label.toString(); - } - - @Override - public EdgeDirection getDirect() { - return direction; - } - - @Override - public void setDirect(EdgeDirection direction) { - this.direction = direction; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof ObjectEdge)) { - return false; - } - ObjectEdge that = (ObjectEdge) o; - return Objects.equals(srcId, that.srcId) && Objects.equals(targetId, - that.targetId) && direction == that.direction && Objects.equals(label, that.label); - } - - @Override - public int hashCode() { - return Objects.hash(srcId, targetId, direction, label); - } - - @Override - public ObjectEdge reverse() { - ObjectEdge edge = new ObjectEdge(targetId, srcId, value); - edge.setBinaryLabel(label); - edge.setDirect(direction); - return edge; - } - - @Override - public Row getValue() { - return value; - } - - @Override - public ObjectEdge withValue(Row value) { - ObjectEdge edge = new ObjectEdge(srcId, targetId, value); - edge.setBinaryLabel(label); - edge.setDirect(direction); - return edge; - } - - @Override - public String toString() { - return srcId + "#" + targetId + "#" + label + "#" + direction + "#" + value; - } - - @Override - public Object getField(int i, IType type) { - switch (i) { - case EdgeType.SRC_ID_FIELD_POSITION: - return srcId; - case EdgeType.TARGET_ID_FIELD_POSITION: - return targetId; - case EdgeType.LABEL_FIELD_POSITION: - return label; - default: - return value.getField(i - 3, type); - } - } - - @Override - public void setValue(Row value) { - this.value = value; - } - - @Override - public RowEdge withDirection(EdgeDirection direction) { - ObjectEdge edge = new ObjectEdge(srcId, targetId, value); - edge.setBinaryLabel(label); - edge.setDirect(direction); - return edge; - } - - @Override - public RowEdge identityReverse() { - ObjectEdge edge = new ObjectEdge(targetId, srcId, value); - edge.setBinaryLabel(label); - edge.setDirect(direction.reverse()); - return edge; - } - - @Override - public BinaryString getBinaryLabel() { + case EdgeType.LABEL_FIELD_POSITION: return label; - } - - @Override - public void setBinaryLabel(BinaryString label) { - this.label = label; - } - - private static class Constructor implements Supplier { - - @Override - public ObjectEdge get() { - return new ObjectEdge(); - } - } - - @Override - public void write(Kryo kryo, Output output) { - // serialize srcId, targetId, direction, label and value - kryo.writeClassAndObject(output, this.getSrcId()); - kryo.writeClassAndObject(output, this.getTargetId()); - kryo.writeClassAndObject(output, this.getDirect()); - kryo.writeClassAndObject(output, this.getBinaryLabel()); - kryo.writeClassAndObject(output, this.getValue()); - } - - @Override - public void read(Kryo kryo, Input input) { - // deserialize srcId, targetId, direction, label and value - Object srcId = kryo.readClassAndObject(input); - this.srcId = srcId; - Object targetId = kryo.readClassAndObject(input); - this.targetId = targetId; - EdgeDirection direction = (EdgeDirection) kryo.readClassAndObject(input); - this.setDirect(direction); - BinaryString label = (BinaryString) kryo.readClassAndObject(input); - this.setBinaryLabel(label); - Row value = (Row) kryo.readClassAndObject(input); - this.setValue(value); - } - + default: + return value.getField(i - 3, type); + } + } + + @Override + public void setValue(Row value) { + this.value = value; + } + + @Override + public RowEdge withDirection(EdgeDirection direction) { + ObjectEdge edge = new ObjectEdge(srcId, targetId, value); + edge.setBinaryLabel(label); + edge.setDirect(direction); + return edge; + } + + @Override + public RowEdge identityReverse() { + ObjectEdge edge = new ObjectEdge(targetId, srcId, value); + edge.setBinaryLabel(label); + edge.setDirect(direction.reverse()); + return edge; + } + + @Override + public BinaryString getBinaryLabel() { + return label; + } + + @Override + public void setBinaryLabel(BinaryString label) { + this.label = label; + } + + private static class Constructor implements Supplier { + + @Override + public ObjectEdge get() { + return new ObjectEdge(); + } + } + + @Override + public void write(Kryo kryo, Output output) { + // serialize srcId, targetId, direction, label and value + kryo.writeClassAndObject(output, this.getSrcId()); + kryo.writeClassAndObject(output, this.getTargetId()); + kryo.writeClassAndObject(output, this.getDirect()); + kryo.writeClassAndObject(output, this.getBinaryLabel()); + kryo.writeClassAndObject(output, this.getValue()); + } + + @Override + public void read(Kryo kryo, Input input) { + // deserialize srcId, targetId, direction, label and value + Object srcId = kryo.readClassAndObject(input); + this.srcId = srcId; + Object targetId = kryo.readClassAndObject(input); + this.targetId = targetId; + EdgeDirection direction = (EdgeDirection) kryo.readClassAndObject(input); + this.setDirect(direction); + BinaryString label = (BinaryString) kryo.readClassAndObject(input); + this.setBinaryLabel(label); + Row value = (Row) kryo.readClassAndObject(input); + this.setValue(value); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/ObjectTsEdge.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/ObjectTsEdge.java index 105f9cb31..5d90cd5fa 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/ObjectTsEdge.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/ObjectTsEdge.java @@ -19,12 +19,9 @@ package org.apache.geaflow.dsl.common.data.impl.types; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Objects; import java.util.function.Supplier; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -33,145 +30,156 @@ import org.apache.geaflow.model.graph.IGraphElementWithTimeField; import org.apache.geaflow.model.graph.edge.EdgeDirection; -public class ObjectTsEdge extends ObjectEdge implements IGraphElementWithTimeField, - KryoSerializable { - - public static final Supplier CONSTRUCTOR = new Constructor(); - - private long time; - - public ObjectTsEdge() { - - } - - public ObjectTsEdge(Object srcId, Object targetId) { - super(srcId, targetId); - } - - public ObjectTsEdge(Object srcId, Object targetId, Row value) { - super(srcId, targetId, value); - } - - public void setTime(long time) { - this.time = time; - } - - public long getTime() { - return time; - } - - @Override - public Object getField(int i, IType type) { - switch (i) { - case EdgeType.SRC_ID_FIELD_POSITION: - return getSrcId(); - case EdgeType.TARGET_ID_FIELD_POSITION: - return getTargetId(); - case EdgeType.LABEL_FIELD_POSITION: - return getBinaryLabel(); - case EdgeType.TIME_FIELD_POSITION: - return time; - default: - return getValue().getField(i - 4, type); - } - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof ObjectTsEdge)) { - return false; - } - if (!super.equals(o)) { - return false; - } - ObjectTsEdge that = (ObjectTsEdge) o; - return time == that.time; - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), time); - } +public class ObjectTsEdge extends ObjectEdge + implements IGraphElementWithTimeField, KryoSerializable { - @Override - public ObjectTsEdge reverse() { - ObjectTsEdge edge = new ObjectTsEdge(getTargetId(), getSrcId(), getValue()); - edge.setBinaryLabel(getBinaryLabel()); - edge.setDirect(getDirect()); - edge.setTime(time); - return edge; - } + public static final Supplier CONSTRUCTOR = new Constructor(); - @Override - public RowEdge withDirection(EdgeDirection direction) { - ObjectTsEdge edge = new ObjectTsEdge(getSrcId(), getTargetId(), getValue()); - edge.setBinaryLabel(getBinaryLabel()); - edge.setDirect(direction); - edge.setTime(getTime()); - return edge; - } + private long time; - @Override - public ObjectEdge withValue(Row value) { - ObjectTsEdge edge = new ObjectTsEdge(getSrcId(), getTargetId(), value); - edge.setBinaryLabel(getBinaryLabel()); - edge.setDirect(getDirect()); - edge.setTime(time); - return edge; - } + public ObjectTsEdge() {} - @Override - public RowEdge identityReverse() { - ObjectTsEdge edge = new ObjectTsEdge(getTargetId(), getSrcId(), getValue()); - edge.setBinaryLabel(getBinaryLabel()); - edge.setDirect(getDirect().reverse()); - edge.setTime(time); - return edge; - } + public ObjectTsEdge(Object srcId, Object targetId) { + super(srcId, targetId); + } - @Override - public String toString() { - return getSrcId() + "#" + getTargetId() + "#" + getBinaryLabel() + "#" + getDirect() - + "#" + time + "#" + getValue(); - } + public ObjectTsEdge(Object srcId, Object targetId, Row value) { + super(srcId, targetId, value); + } - private static class Constructor implements Supplier { + public void setTime(long time) { + this.time = time; + } - @Override - public ObjectTsEdge get() { - return new ObjectTsEdge(); - } - } + public long getTime() { + return time; + } - @Override - public void write(Kryo kryo, Output output) { - // serialize srcId, targetId, direction, label, value and time - kryo.writeClassAndObject(output, this.getSrcId()); - kryo.writeClassAndObject(output, this.getTargetId()); - kryo.writeClassAndObject(output, this.getDirect()); - kryo.writeClassAndObject(output, this.getBinaryLabel()); - kryo.writeClassAndObject(output, this.getValue()); - output.writeLong(this.getTime()); - } + @Override + public Object getField(int i, IType type) { + switch (i) { + case EdgeType.SRC_ID_FIELD_POSITION: + return getSrcId(); + case EdgeType.TARGET_ID_FIELD_POSITION: + return getTargetId(); + case EdgeType.LABEL_FIELD_POSITION: + return getBinaryLabel(); + case EdgeType.TIME_FIELD_POSITION: + return time; + default: + return getValue().getField(i - 4, type); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ObjectTsEdge)) { + return false; + } + if (!super.equals(o)) { + return false; + } + ObjectTsEdge that = (ObjectTsEdge) o; + return time == that.time; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), time); + } + + @Override + public ObjectTsEdge reverse() { + ObjectTsEdge edge = new ObjectTsEdge(getTargetId(), getSrcId(), getValue()); + edge.setBinaryLabel(getBinaryLabel()); + edge.setDirect(getDirect()); + edge.setTime(time); + return edge; + } + + @Override + public RowEdge withDirection(EdgeDirection direction) { + ObjectTsEdge edge = new ObjectTsEdge(getSrcId(), getTargetId(), getValue()); + edge.setBinaryLabel(getBinaryLabel()); + edge.setDirect(direction); + edge.setTime(getTime()); + return edge; + } + + @Override + public ObjectEdge withValue(Row value) { + ObjectTsEdge edge = new ObjectTsEdge(getSrcId(), getTargetId(), value); + edge.setBinaryLabel(getBinaryLabel()); + edge.setDirect(getDirect()); + edge.setTime(time); + return edge; + } + + @Override + public RowEdge identityReverse() { + ObjectTsEdge edge = new ObjectTsEdge(getTargetId(), getSrcId(), getValue()); + edge.setBinaryLabel(getBinaryLabel()); + edge.setDirect(getDirect().reverse()); + edge.setTime(time); + return edge; + } + + @Override + public String toString() { + return getSrcId() + + "#" + + getTargetId() + + "#" + + getBinaryLabel() + + "#" + + getDirect() + + "#" + + time + + "#" + + getValue(); + } + + private static class Constructor implements Supplier { @Override - public void read(Kryo kryo, Input input) { - // deserialize srcId, targetId, direction, label, value and time - Object srcId = kryo.readClassAndObject(input); - Object targetId = kryo.readClassAndObject(input); - EdgeDirection direction = (EdgeDirection) kryo.readClassAndObject(input); - BinaryString label = (BinaryString) kryo.readClassAndObject(input); - Row value = (Row) kryo.readClassAndObject(input); - long time = input.readLong(); - this.setSrcId(srcId); - this.setTargetId(targetId); - this.setValue(value); - this.setBinaryLabel(label); - this.setDirect(direction); - this.setTime(time); - } - + public ObjectTsEdge get() { + return new ObjectTsEdge(); + } + } + + @Override + public void write(Kryo kryo, Output output) { + // serialize srcId, targetId, direction, label, value and time + kryo.writeClassAndObject(output, this.getSrcId()); + kryo.writeClassAndObject(output, this.getTargetId()); + kryo.writeClassAndObject(output, this.getDirect()); + kryo.writeClassAndObject(output, this.getBinaryLabel()); + kryo.writeClassAndObject(output, this.getValue()); + output.writeLong(this.getTime()); + } + + @Override + public void read(Kryo kryo, Input input) { + // deserialize srcId, targetId, direction, label, value and time + Object srcId = kryo.readClassAndObject(input); + Object targetId = kryo.readClassAndObject(input); + EdgeDirection direction = (EdgeDirection) kryo.readClassAndObject(input); + BinaryString label = (BinaryString) kryo.readClassAndObject(input); + Row value = (Row) kryo.readClassAndObject(input); + long time = input.readLong(); + this.setSrcId(srcId); + this.setTargetId(targetId); + this.setValue(value); + this.setBinaryLabel(label); + this.setDirect(direction); + this.setTime(time); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/ObjectVertex.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/ObjectVertex.java index 91a04b313..055d37a73 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/ObjectVertex.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/data/impl/types/ObjectVertex.java @@ -19,12 +19,9 @@ package org.apache.geaflow.dsl.common.data.impl.types; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Objects; import java.util.function.Supplier; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -34,156 +31,157 @@ import org.apache.geaflow.dsl.common.util.BinaryUtil; import org.apache.geaflow.model.graph.vertex.IVertex; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + public class ObjectVertex implements RowVertex, KryoSerializable { - public static final Supplier CONSTRUCTOR = new Constructor(); + public static final Supplier CONSTRUCTOR = new Constructor(); - private Object id; + private Object id; - private BinaryString label; + private BinaryString label; - private Row value; + private Row value; - public ObjectVertex() { + public ObjectVertex() {} - } + public ObjectVertex(Object id) { + this.id = id; + } - public ObjectVertex(Object id) { - this.id = id; - } + public ObjectVertex(Object id, BinaryString label, Row value) { + this.id = id; + this.label = label; + this.value = value; + } - public ObjectVertex(Object id, BinaryString label, Row value) { - this.id = id; - this.label = label; - this.value = value; - } + @Override + public String getLabel() { + return label.toString(); + } - @Override - public String getLabel() { - return label.toString(); - } + @Override + public void setLabel(String label) { + this.label = (BinaryString) BinaryUtil.toBinaryLabel(label); + } - @Override - public void setLabel(String label) { - this.label = (BinaryString) BinaryUtil.toBinaryLabel(label); - } + @Override + public Object getId() { + return id; + } - @Override - public Object getId() { - return id; - } + @Override + public void setId(Object id) { + this.id = Objects.requireNonNull(id); + } - @Override - public void setId(Object id) { - this.id = Objects.requireNonNull(id); - } + @Override + public Row getValue() { + return value; + } - @Override - public Row getValue() { - return value; - } + @Override + public ObjectVertex withValue(Row value) { + return new ObjectVertex(id, label, value); + } - @Override - public ObjectVertex withValue(Row value) { - return new ObjectVertex(id, label, value); - } + @Override + public ObjectVertex withLabel(String label) { + return new ObjectVertex(id, (BinaryString) BinaryUtil.toBinaryLabel(label), value); + } - @Override - public ObjectVertex withLabel(String label) { - return new ObjectVertex(id, (BinaryString) BinaryUtil.toBinaryLabel(label), value); - } + @Override + public IVertex withTime(long time) { + throw new GeaFlowDSLException("Vertex not support timestamp"); + } - @Override - public IVertex withTime(long time) { - throw new GeaFlowDSLException("Vertex not support timestamp"); + @SuppressWarnings("unchecked") + @Override + public int compareTo(Object o) { + RowVertex vertex = (RowVertex) o; + if (id instanceof Comparable) { + return ((Comparable) id).compareTo(vertex.getId()); } + return Integer.compare(getId().hashCode(), vertex.getId().hashCode()); + } - @SuppressWarnings("unchecked") - @Override - public int compareTo(Object o) { - RowVertex vertex = (RowVertex) o; - if (id instanceof Comparable) { - return ((Comparable) id).compareTo(vertex.getId()); - } - return Integer.compare(getId().hashCode(), vertex.getId().hashCode()); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof RowVertex)) { - return false; - } - RowVertex that = (RowVertex) o; - - return id.equals(that.getId()) && Objects.equals(label, that.getBinaryLabel()); + if (!(o instanceof RowVertex)) { + return false; } + RowVertex that = (RowVertex) o; - @Override - public int hashCode() { - return Objects.hash(id, label); - } + return id.equals(that.getId()) && Objects.equals(label, that.getBinaryLabel()); + } - @Override - public Object getField(int i, IType type) { - switch (i) { - case VertexType.ID_FIELD_POSITION: - return id; - case VertexType.LABEL_FIELD_POSITION: - return label; - default: - return value.getField(i - 2, type); - } - } + @Override + public int hashCode() { + return Objects.hash(id, label); + } - @Override - public void setValue(Row value) { - this.value = value; - } - - @Override - public String toString() { - return id + "#" + label + "#" + value; - } - - @Override - public BinaryString getBinaryLabel() { + @Override + public Object getField(int i, IType type) { + switch (i) { + case VertexType.ID_FIELD_POSITION: + return id; + case VertexType.LABEL_FIELD_POSITION: return label; - } - - @Override - public void setBinaryLabel(BinaryString label) { - this.label = label; - } - - private static class Constructor implements Supplier { - - @Override - public ObjectVertex get() { - return new ObjectVertex(); - } - } - - @Override - public void write(Kryo kryo, Output output) { - // serialize id, label, and value - kryo.writeClassAndObject(output, this.getId()); - kryo.writeClassAndObject(output, this.getBinaryLabel()); - kryo.writeClassAndObject(output, this.getValue()); - } - - @Override - public void read(Kryo kryo, Input input) { - // deserialize id, label, and value - Object id = kryo.readClassAndObject(input); - BinaryString label = (BinaryString) kryo.readClassAndObject(input); - Row value = (Row) kryo.readClassAndObject(input); - this.id = id; - this.setValue(value); - this.setBinaryLabel(label); - } - - + default: + return value.getField(i - 2, type); + } + } + + @Override + public void setValue(Row value) { + this.value = value; + } + + @Override + public String toString() { + return id + "#" + label + "#" + value; + } + + @Override + public BinaryString getBinaryLabel() { + return label; + } + + @Override + public void setBinaryLabel(BinaryString label) { + this.label = label; + } + + private static class Constructor implements Supplier { + + @Override + public ObjectVertex get() { + return new ObjectVertex(); + } + } + + @Override + public void write(Kryo kryo, Output output) { + // serialize id, label, and value + kryo.writeClassAndObject(output, this.getId()); + kryo.writeClassAndObject(output, this.getBinaryLabel()); + kryo.writeClassAndObject(output, this.getValue()); + } + + @Override + public void read(Kryo kryo, Input input) { + // deserialize id, label, and value + Object id = kryo.readClassAndObject(input); + BinaryString label = (BinaryString) kryo.readClassAndObject(input); + Row value = (Row) kryo.readClassAndObject(input); + this.id = id; + this.setValue(value); + this.setBinaryLabel(label); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/descriptor/EdgeDescriptor.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/descriptor/EdgeDescriptor.java index 1b0dd1837..cdd275f1a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/descriptor/EdgeDescriptor.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/descriptor/EdgeDescriptor.java @@ -23,15 +23,15 @@ public class EdgeDescriptor { - public String id; - public String type; - public String sourceType; - public String targetType; + public String id; + public String type; + public String sourceType; + public String targetType; - public EdgeDescriptor(String id, String type, String sourceType, String targetType) { - this.id = Objects.requireNonNull(id); - this.type = Objects.requireNonNull(type); - this.sourceType = Objects.requireNonNull(sourceType); - this.targetType = Objects.requireNonNull(targetType); - } + public EdgeDescriptor(String id, String type, String sourceType, String targetType) { + this.id = Objects.requireNonNull(id); + this.type = Objects.requireNonNull(type); + this.sourceType = Objects.requireNonNull(sourceType); + this.targetType = Objects.requireNonNull(targetType); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/descriptor/GraphDescriptor.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/descriptor/GraphDescriptor.java index 8b693c5df..c645fa017 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/descriptor/GraphDescriptor.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/descriptor/GraphDescriptor.java @@ -19,55 +19,56 @@ package org.apache.geaflow.dsl.common.descriptor; -import com.google.gson.Gson; import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.concurrent.atomic.AtomicLong; +import com.google.gson.Gson; + public class GraphDescriptor { - private final AtomicLong id = new AtomicLong(0L); - public List nodes = new ArrayList<>(); - public List edges = new ArrayList<>(); - public List relations = new ArrayList<>(); + private final AtomicLong id = new AtomicLong(0L); + public List nodes = new ArrayList<>(); + public List edges = new ArrayList<>(); + public List relations = new ArrayList<>(); - public GraphDescriptor addNode(NodeDescriptor nodeDescriptor) { - nodes.add(Objects.requireNonNull(nodeDescriptor)); - return this; - } + public GraphDescriptor addNode(NodeDescriptor nodeDescriptor) { + nodes.add(Objects.requireNonNull(nodeDescriptor)); + return this; + } - public GraphDescriptor addNode(List nodeDescriptors) { - for (NodeDescriptor nodeDescriptor : nodeDescriptors) { - addNode(nodeDescriptor); - } - return this; + public GraphDescriptor addNode(List nodeDescriptors) { + for (NodeDescriptor nodeDescriptor : nodeDescriptors) { + addNode(nodeDescriptor); } + return this; + } - public GraphDescriptor addEdge(EdgeDescriptor edgeDescriptor) { - edges.add(Objects.requireNonNull(edgeDescriptor)); - return this; - } + public GraphDescriptor addEdge(EdgeDescriptor edgeDescriptor) { + edges.add(Objects.requireNonNull(edgeDescriptor)); + return this; + } - public GraphDescriptor addEdge(List edgeStats) { - for (EdgeDescriptor edgeDescriptor : edgeStats) { - addEdge(edgeDescriptor); - } - return this; + public GraphDescriptor addEdge(List edgeStats) { + for (EdgeDescriptor edgeDescriptor : edgeStats) { + addEdge(edgeDescriptor); } + return this; + } - public GraphDescriptor addRelation(RelationDescriptor relationDescriptor) { - relations.add(Objects.requireNonNull(relationDescriptor)); - return this; - } + public GraphDescriptor addRelation(RelationDescriptor relationDescriptor) { + relations.add(Objects.requireNonNull(relationDescriptor)); + return this; + } - public String getIdName(String value) { - return value + "-" + id.getAndIncrement(); - } + public String getIdName(String value) { + return value + "-" + id.getAndIncrement(); + } - @Override - public String toString() { - Gson gson = new Gson(); - return gson.toJson(this); - } + @Override + public String toString() { + Gson gson = new Gson(); + return gson.toJson(this); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/descriptor/NodeDescriptor.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/descriptor/NodeDescriptor.java index 8444c68b2..73f26e240 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/descriptor/NodeDescriptor.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/descriptor/NodeDescriptor.java @@ -21,11 +21,11 @@ public class NodeDescriptor { - public String id; - public String type; + public String id; + public String type; - public NodeDescriptor(String id, String type) { - this.id = id; - this.type = type; - } + public NodeDescriptor(String id, String type) { + this.id = id; + this.type = type; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/descriptor/RelationDescriptor.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/descriptor/RelationDescriptor.java index fc38a6cac..5ccf25c21 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/descriptor/RelationDescriptor.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/descriptor/RelationDescriptor.java @@ -21,13 +21,13 @@ public class RelationDescriptor { - public String source; - public String target; - public String type; + public String source; + public String target; + public String type; - public RelationDescriptor(String source, String target, String type) { - this.source = source; - this.target = target; - this.type = type; - } + public RelationDescriptor(String source, String target, String type) { + this.source = source; + this.target = target; + this.type = type; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/exception/GeaFlowDSLException.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/exception/GeaFlowDSLException.java index 0d69c59ec..efe8ff2a7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/exception/GeaFlowDSLException.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/exception/GeaFlowDSLException.java @@ -24,27 +24,27 @@ public class GeaFlowDSLException extends RuntimeException { - public GeaFlowDSLException(String message, Throwable cause) { - super(message, cause); - } + public GeaFlowDSLException(String message, Throwable cause) { + super(message, cause); + } - public GeaFlowDSLException(String message) { - super(message); - } + public GeaFlowDSLException(String message) { + super(message); + } - public GeaFlowDSLException(Throwable e, String message, Object... parameters) { - super(MessageFormatter.arrayFormat(message, parameters).getMessage(), e); - } + public GeaFlowDSLException(Throwable e, String message, Object... parameters) { + super(MessageFormatter.arrayFormat(message, parameters).getMessage(), e); + } - public GeaFlowDSLException(SqlParserPos position, String message, Object... parameters) { - super("At " + position + ": " + MessageFormatter.arrayFormat(message, parameters).getMessage()); - } + public GeaFlowDSLException(SqlParserPos position, String message, Object... parameters) { + super("At " + position + ": " + MessageFormatter.arrayFormat(message, parameters).getMessage()); + } - public GeaFlowDSLException(String message, Object... parameters) { - super(MessageFormatter.arrayFormat(message, parameters).getMessage()); - } + public GeaFlowDSLException(String message, Object... parameters) { + super(MessageFormatter.arrayFormat(message, parameters).getMessage()); + } - public GeaFlowDSLException(Throwable cause) { - super(cause); - } + public GeaFlowDSLException(Throwable cause) { + super(cause); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/Description.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/Description.java index 22fdf7ed3..281b3f47d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/Description.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/Description.java @@ -25,7 +25,7 @@ @Retention(RetentionPolicy.RUNTIME) public @interface Description { - String name(); + String name(); - String description(); + String description(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/FunctionContext.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/FunctionContext.java index 6ea13788e..03e432e96 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/FunctionContext.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/FunctionContext.java @@ -21,21 +21,22 @@ import java.io.Serializable; import java.util.Objects; + import org.apache.geaflow.common.config.Configuration; public class FunctionContext implements Serializable { - private final Configuration config; + private final Configuration config; - private FunctionContext(Configuration config) { - this.config = Objects.requireNonNull(config); - } + private FunctionContext(Configuration config) { + this.config = Objects.requireNonNull(config); + } - public static FunctionContext of(Configuration config) { - return new FunctionContext(config); - } + public static FunctionContext of(Configuration config) { + return new FunctionContext(config); + } - public Configuration getConfig() { - return config; - } + public Configuration getConfig() { + return config; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/PropertyExistsFunctions.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/PropertyExistsFunctions.java index eb6ca542c..86f975d4d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/PropertyExistsFunctions.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/PropertyExistsFunctions.java @@ -28,123 +28,124 @@ * *

Implements ISO-GQL Section 19.13: <property_exists predicate> * - *

These static methods are called via reflection by the corresponding runtime - * Expression classes for better distributed execution safety. + *

These static methods are called via reflection by the corresponding runtime Expression classes + * for better distributed execution safety. * *

ISO-GQL General Rules: + * *

    - *
  • If element is null, result is Unknown (null)
  • - *
  • If element has the specified property, result is True
  • - *
  • Otherwise, result is False
  • + *
  • If element is null, result is Unknown (null) + *
  • If element has the specified property, result is True + *
  • Otherwise, result is False *
* - *

Implementation Note: - * This implementation follows GeaFlow's runtime validation strategy. Property existence - * checking relies on compile-time validation through the SQL optimizer and type system. - * At runtime, we validate types and provide meaningful error messages, but assume that + *

Implementation Note: This implementation follows GeaFlow's runtime validation strategy. + * Property existence checking relies on compile-time validation through the SQL optimizer and type + * system. At runtime, we validate types and provide meaningful error messages, but assume that * property names have been validated during query compilation. * *

This design matches the approach used by other ISO-GQL predicates (IS_SOURCE_OF, - * IS_DESTINATION_OF) and aligns with GeaFlow's Row interface, which provides indexed - * property access rather than name-based access at runtime. + * IS_DESTINATION_OF) and aligns with GeaFlow's Row interface, which provides indexed property + * access rather than name-based access at runtime. */ public class PropertyExistsFunctions { - /** - * Evaluates PROPERTY_EXISTS predicate for any graph element. - * - *

This is the primary implementation method that provides comprehensive - * validation following ISO-GQL three-valued logic. - * - * @param element graph element (vertex, edge, or row) - * @param propertyName property name to check - * @return Boolean: true if property exists, false if not, null if element is null - * @throws IllegalArgumentException if element is not a valid graph element type - * @throws IllegalArgumentException if propertyName is null or empty - */ - public static Boolean propertyExists(Object element, String propertyName) { - // ISO-GQL Rule 1: If element is null, result is Unknown (null) - if (element == null) { - return null; // Three-valued logic: Unknown - } - - // ISO-GQL Rule 2: Type validation - // Element must be a graph element type (Row, RowVertex, or RowEdge) - if (!(element instanceof Row || element instanceof RowVertex || element instanceof RowEdge)) { - throw new IllegalArgumentException( - "First operand of PROPERTY_EXISTS must be a graph element (Row, RowVertex, or RowEdge), got: " - + element.getClass().getName()); - } - - // ISO-GQL Rule 3: Property name validation - // Property name must be non-null and non-empty - if (propertyName == null || propertyName.trim().isEmpty()) { - throw new IllegalArgumentException( - "Second operand of PROPERTY_EXISTS must be a non-empty property name"); - } - - // ISO-GQL Rule 4: Property existence check - // - // IMPLEMENTATION NOTE: - // In GeaFlow's architecture, property existence validation happens at compile-time - // through the SQL optimizer and type system (StructType.contain()). The Row interface - // only provides indexed access (getField(int i)), not name-based access. - // - // At runtime, if this code is reached with a valid property name, it means: - // 1. The SQL parser accepted the property name - // 2. The type system validated it against the schema - // 3. The query optimizer generated code using valid field indices - // - // Therefore, we return true for any non-null element with a non-empty property name, - // trusting the compile-time validation. This matches GeaFlow's design philosophy - // and is consistent with the Row interface's indexed access pattern. - // - // For a full runtime property checking implementation, GeaFlow would need to: - // - Extend Row interface to include schema metadata (getType() method) - // - Or pass StructType context through the execution pipeline - // - Or add hasField(String name) method to Row interface - // - // These architectural changes would enable runtime validation but at the cost - // of memory overhead and execution complexity. - return true; + /** + * Evaluates PROPERTY_EXISTS predicate for any graph element. + * + *

This is the primary implementation method that provides comprehensive validation following + * ISO-GQL three-valued logic. + * + * @param element graph element (vertex, edge, or row) + * @param propertyName property name to check + * @return Boolean: true if property exists, false if not, null if element is null + * @throws IllegalArgumentException if element is not a valid graph element type + * @throws IllegalArgumentException if propertyName is null or empty + */ + public static Boolean propertyExists(Object element, String propertyName) { + // ISO-GQL Rule 1: If element is null, result is Unknown (null) + if (element == null) { + return null; // Three-valued logic: Unknown } - /** - * Type-specific overload for RowVertex elements. - * - *

Provides better type checking and clearer error messages for vertex-specific calls. - * - * @param vertex vertex to check - * @param propertyName property name to check - * @return Boolean: true if property exists, false if not, null if vertex is null - */ - public static Boolean propertyExists(RowVertex vertex, String propertyName) { - return propertyExists((Object) vertex, propertyName); + // ISO-GQL Rule 2: Type validation + // Element must be a graph element type (Row, RowVertex, or RowEdge) + if (!(element instanceof Row || element instanceof RowVertex || element instanceof RowEdge)) { + throw new IllegalArgumentException( + "First operand of PROPERTY_EXISTS must be a graph element (Row, RowVertex, or RowEdge)," + + " got: " + + element.getClass().getName()); } - /** - * Type-specific overload for RowEdge elements. - * - *

Provides better type checking and clearer error messages for edge-specific calls. - * - * @param edge edge to check - * @param propertyName property name to check - * @return Boolean: true if property exists, false if not, null if edge is null - */ - public static Boolean propertyExists(RowEdge edge, String propertyName) { - return propertyExists((Object) edge, propertyName); + // ISO-GQL Rule 3: Property name validation + // Property name must be non-null and non-empty + if (propertyName == null || propertyName.trim().isEmpty()) { + throw new IllegalArgumentException( + "Second operand of PROPERTY_EXISTS must be a non-empty property name"); } - /** - * Type-specific overload for Row elements. - * - *

Provides better type checking and clearer error messages for row-specific calls. - * - * @param row row to check - * @param propertyName property name to check - * @return Boolean: true if property exists, false if not, null if row is null - */ - public static Boolean propertyExists(Row row, String propertyName) { - return propertyExists((Object) row, propertyName); - } + // ISO-GQL Rule 4: Property existence check + // + // IMPLEMENTATION NOTE: + // In GeaFlow's architecture, property existence validation happens at compile-time + // through the SQL optimizer and type system (StructType.contain()). The Row interface + // only provides indexed access (getField(int i)), not name-based access. + // + // At runtime, if this code is reached with a valid property name, it means: + // 1. The SQL parser accepted the property name + // 2. The type system validated it against the schema + // 3. The query optimizer generated code using valid field indices + // + // Therefore, we return true for any non-null element with a non-empty property name, + // trusting the compile-time validation. This matches GeaFlow's design philosophy + // and is consistent with the Row interface's indexed access pattern. + // + // For a full runtime property checking implementation, GeaFlow would need to: + // - Extend Row interface to include schema metadata (getType() method) + // - Or pass StructType context through the execution pipeline + // - Or add hasField(String name) method to Row interface + // + // These architectural changes would enable runtime validation but at the cost + // of memory overhead and execution complexity. + return true; + } + + /** + * Type-specific overload for RowVertex elements. + * + *

Provides better type checking and clearer error messages for vertex-specific calls. + * + * @param vertex vertex to check + * @param propertyName property name to check + * @return Boolean: true if property exists, false if not, null if vertex is null + */ + public static Boolean propertyExists(RowVertex vertex, String propertyName) { + return propertyExists((Object) vertex, propertyName); + } + + /** + * Type-specific overload for RowEdge elements. + * + *

Provides better type checking and clearer error messages for edge-specific calls. + * + * @param edge edge to check + * @param propertyName property name to check + * @return Boolean: true if property exists, false if not, null if edge is null + */ + public static Boolean propertyExists(RowEdge edge, String propertyName) { + return propertyExists((Object) edge, propertyName); + } + + /** + * Type-specific overload for Row elements. + * + *

Provides better type checking and clearer error messages for row-specific calls. + * + * @param row row to check + * @param propertyName property name to check + * @return Boolean: true if property exists, false if not, null if row is null + */ + public static Boolean propertyExists(Row row, String propertyName) { + return propertyExists((Object) row, propertyName); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/SourceDestinationFunctions.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/SourceDestinationFunctions.java index 631792ced..9523c5b51 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/SourceDestinationFunctions.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/SourceDestinationFunctions.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.common.function; import java.util.Objects; + import org.apache.geaflow.dsl.common.data.RowEdge; import org.apache.geaflow.dsl.common.data.RowVertex; @@ -28,127 +29,127 @@ * *

Implements ISO-GQL Section 19.10: <source/destination predicate> * - *

These static methods are called via reflection by the corresponding runtime - * Expression classes (IsSourceOfExpression, IsDestinationOfExpression) for better - * distributed execution safety. + *

These static methods are called via reflection by the corresponding runtime Expression classes + * (IsSourceOfExpression, IsDestinationOfExpression) for better distributed execution safety. * *

ISO-GQL General Rules: + * *

    - *
  • If node or edge is null, result is Unknown (null)
  • - *
  • If edge is undirected, result is False
  • - *
  • If node matches edge endpoint (source/destination), result is True
  • - *
  • Otherwise, result is False
  • + *
  • If node or edge is null, result is Unknown (null) + *
  • If edge is undirected, result is False + *
  • If node matches edge endpoint (source/destination), result is True + *
  • Otherwise, result is False *
*/ public class SourceDestinationFunctions { - /** - * Implements IS_SOURCE_OF predicate. - * - * @param nodeValue vertex/node object to check - * @param edgeValue edge object to check - * @return Boolean: true if node is source of edge, false if not, null if either is null - */ - public static Boolean isSourceOf(Object nodeValue, Object edgeValue) { - // ISO-GQL Rule 1: If node or edge is null, result is Unknown (null) - if (nodeValue == null || edgeValue == null) { - return null; // Three-valued logic: Unknown - } - - // Validate types - if (!(nodeValue instanceof RowVertex)) { - throw new IllegalArgumentException( - "First operand of IS_SOURCE_OF must be a vertex/node, got: " - + nodeValue.getClass().getName()); - } - if (!(edgeValue instanceof RowEdge)) { - throw new IllegalArgumentException( - "Second operand of IS_SOURCE_OF must be an edge, got: " - + edgeValue.getClass().getName()); - } - - RowVertex node = (RowVertex) nodeValue; - RowEdge edge = (RowEdge) edgeValue; - - // ISO-GQL Rule 2: If edge is undirected, result is False - // Note: In GeaFlow, BOTH direction means undirected - if (edge.getDirect() == org.apache.geaflow.model.graph.edge.EdgeDirection.BOTH) { - return false; - } - - // ISO-GQL Rule 3: Check if node is source of edge - // Compare node ID with edge source ID - Object nodeId = node.getId(); - Object edgeSrcId = edge.getSrcId(); - - return Objects.equals(nodeId, edgeSrcId); + /** + * Implements IS_SOURCE_OF predicate. + * + * @param nodeValue vertex/node object to check + * @param edgeValue edge object to check + * @return Boolean: true if node is source of edge, false if not, null if either is null + */ + public static Boolean isSourceOf(Object nodeValue, Object edgeValue) { + // ISO-GQL Rule 1: If node or edge is null, result is Unknown (null) + if (nodeValue == null || edgeValue == null) { + return null; // Three-valued logic: Unknown + } + + // Validate types + if (!(nodeValue instanceof RowVertex)) { + throw new IllegalArgumentException( + "First operand of IS_SOURCE_OF must be a vertex/node, got: " + + nodeValue.getClass().getName()); + } + if (!(edgeValue instanceof RowEdge)) { + throw new IllegalArgumentException( + "Second operand of IS_SOURCE_OF must be an edge, got: " + edgeValue.getClass().getName()); } - /** - * Implements IS_NOT_SOURCE_OF predicate. - * - * @param nodeValue vertex/node object to check - * @param edgeValue edge object to check - * @return Boolean: true if node is NOT source of edge, false if it is, null if either is null - */ - public static Boolean isNotSourceOf(Object nodeValue, Object edgeValue) { - Boolean result = isSourceOf(nodeValue, edgeValue); - // Three-valued logic: NOT Unknown = Unknown (null remains null) - return result == null ? null : !result; + RowVertex node = (RowVertex) nodeValue; + RowEdge edge = (RowEdge) edgeValue; + + // ISO-GQL Rule 2: If edge is undirected, result is False + // Note: In GeaFlow, BOTH direction means undirected + if (edge.getDirect() == org.apache.geaflow.model.graph.edge.EdgeDirection.BOTH) { + return false; } - /** - * Implements IS_DESTINATION_OF predicate. - * - * @param nodeValue vertex/node object to check - * @param edgeValue edge object to check - * @return Boolean: true if node is destination of edge, false if not, null if either is null - */ - public static Boolean isDestinationOf(Object nodeValue, Object edgeValue) { - // ISO-GQL Rule 1: If node or edge is null, result is Unknown (null) - if (nodeValue == null || edgeValue == null) { - return null; // Three-valued logic: Unknown - } - - // Validate types - if (!(nodeValue instanceof RowVertex)) { - throw new IllegalArgumentException( - "First operand of IS_DESTINATION_OF must be a vertex/node, got: " - + nodeValue.getClass().getName()); - } - if (!(edgeValue instanceof RowEdge)) { - throw new IllegalArgumentException( - "Second operand of IS_DESTINATION_OF must be an edge, got: " - + edgeValue.getClass().getName()); - } - - RowVertex node = (RowVertex) nodeValue; - RowEdge edge = (RowEdge) edgeValue; - - // ISO-GQL Rule 2: If edge is undirected, result is False - // Note: In GeaFlow, BOTH direction means undirected - if (edge.getDirect() == org.apache.geaflow.model.graph.edge.EdgeDirection.BOTH) { - return false; - } - - // ISO-GQL Rule 3: Check if node is destination of edge - // Compare node ID with edge target ID - Object nodeId = node.getId(); - Object edgeTargetId = edge.getTargetId(); - - return Objects.equals(nodeId, edgeTargetId); + // ISO-GQL Rule 3: Check if node is source of edge + // Compare node ID with edge source ID + Object nodeId = node.getId(); + Object edgeSrcId = edge.getSrcId(); + + return Objects.equals(nodeId, edgeSrcId); + } + + /** + * Implements IS_NOT_SOURCE_OF predicate. + * + * @param nodeValue vertex/node object to check + * @param edgeValue edge object to check + * @return Boolean: true if node is NOT source of edge, false if it is, null if either is null + */ + public static Boolean isNotSourceOf(Object nodeValue, Object edgeValue) { + Boolean result = isSourceOf(nodeValue, edgeValue); + // Three-valued logic: NOT Unknown = Unknown (null remains null) + return result == null ? null : !result; + } + + /** + * Implements IS_DESTINATION_OF predicate. + * + * @param nodeValue vertex/node object to check + * @param edgeValue edge object to check + * @return Boolean: true if node is destination of edge, false if not, null if either is null + */ + public static Boolean isDestinationOf(Object nodeValue, Object edgeValue) { + // ISO-GQL Rule 1: If node or edge is null, result is Unknown (null) + if (nodeValue == null || edgeValue == null) { + return null; // Three-valued logic: Unknown } - /** - * Implements IS_NOT_DESTINATION_OF predicate. - * - * @param nodeValue vertex/node object to check - * @param edgeValue edge object to check - * @return Boolean: true if node is NOT destination of edge, false if it is, null if either is null - */ - public static Boolean isNotDestinationOf(Object nodeValue, Object edgeValue) { - Boolean result = isDestinationOf(nodeValue, edgeValue); - // Three-valued logic: NOT Unknown = Unknown (null remains null) - return result == null ? null : !result; + // Validate types + if (!(nodeValue instanceof RowVertex)) { + throw new IllegalArgumentException( + "First operand of IS_DESTINATION_OF must be a vertex/node, got: " + + nodeValue.getClass().getName()); } + if (!(edgeValue instanceof RowEdge)) { + throw new IllegalArgumentException( + "Second operand of IS_DESTINATION_OF must be an edge, got: " + + edgeValue.getClass().getName()); + } + + RowVertex node = (RowVertex) nodeValue; + RowEdge edge = (RowEdge) edgeValue; + + // ISO-GQL Rule 2: If edge is undirected, result is False + // Note: In GeaFlow, BOTH direction means undirected + if (edge.getDirect() == org.apache.geaflow.model.graph.edge.EdgeDirection.BOTH) { + return false; + } + + // ISO-GQL Rule 3: Check if node is destination of edge + // Compare node ID with edge target ID + Object nodeId = node.getId(); + Object edgeTargetId = edge.getTargetId(); + + return Objects.equals(nodeId, edgeTargetId); + } + + /** + * Implements IS_NOT_DESTINATION_OF predicate. + * + * @param nodeValue vertex/node object to check + * @param edgeValue edge object to check + * @return Boolean: true if node is NOT destination of edge, false if it is, null if either is + * null + */ + public static Boolean isNotDestinationOf(Object nodeValue, Object edgeValue) { + Boolean result = isDestinationOf(nodeValue, edgeValue); + // Three-valued logic: NOT Unknown = Unknown (null remains null) + return result == null ? null : !result; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/UDAF.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/UDAF.java index 7fb80d2e1..7093d6f78 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/UDAF.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/UDAF.java @@ -19,37 +19,26 @@ package org.apache.geaflow.dsl.common.function; -/** - * Interface for the User Defined Aggregate Function. - */ +/** Interface for the User Defined Aggregate Function. */ public abstract class UDAF extends UserDefinedFunction { - /** - * Create aggregate accumulator for aggregate function to store the aggregate value. - */ - public abstract AccumT createAccumulator(); - - /** - * Accumulate the input to the accumulator. - */ - public abstract void accumulate(AccumT accumulator, InputT input); + /** Create aggregate accumulator for aggregate function to store the aggregate value. */ + public abstract AccumT createAccumulator(); - /** - * Merge the accumulator iterator to the accumulator. - * - * @param accumulator The accumulator to merged to. - * @param its The accumulator iterators to merge from. - */ - public abstract void merge(AccumT accumulator, Iterable its); + /** Accumulate the input to the accumulator. */ + public abstract void accumulate(AccumT accumulator, InputT input); - /** - * Reset the accumulator to init value. - */ - public abstract void resetAccumulator(AccumT accumulator); + /** + * Merge the accumulator iterator to the accumulator. + * + * @param accumulator The accumulator to merged to. + * @param its The accumulator iterators to merge from. + */ + public abstract void merge(AccumT accumulator, Iterable its); - /** - * Get aggregate function result from the accumulator. - */ - public abstract OutputT getValue(AccumT accumulator); + /** Reset the accumulator to init value. */ + public abstract void resetAccumulator(AccumT accumulator); + /** Get aggregate function result from the accumulator. */ + public abstract OutputT getValue(AccumT accumulator); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/UDAFArguments.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/UDAFArguments.java index 8b771a73b..2399b79bc 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/UDAFArguments.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/UDAFArguments.java @@ -23,19 +23,19 @@ public abstract class UDAFArguments { - private Object[] params; + private Object[] params; - public abstract List> getParamTypes(); + public abstract List> getParamTypes(); - public void setParams(Object[] params) { - this.params = params; - } + public void setParams(Object[] params) { + this.params = params; + } - public Object getParam(int i) { - return params[i]; - } + public Object getParam(int i) { + return params[i]; + } - public int getParamSize() { - return getParamTypes().size(); - } + public int getParamSize() { + return getParamTypes().size(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/UDF.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/UDF.java index 1159f7147..212f2aa8f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/UDF.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/UDF.java @@ -19,6 +19,4 @@ package org.apache.geaflow.dsl.common.function; -public abstract class UDF extends UserDefinedFunction { - -} +public abstract class UDF extends UserDefinedFunction {} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/UDTF.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/UDTF.java index 1bf28e7e2..0fac9a6f8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/UDTF.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/UDTF.java @@ -19,41 +19,42 @@ package org.apache.geaflow.dsl.common.function; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; -public abstract class UDTF extends UserDefinedFunction { +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; - protected List collector; +public abstract class UDTF extends UserDefinedFunction { - public UDTF() { - this.collector = Lists.newArrayList(); - } + protected List collector; - /** - * Collect the result. - */ - protected void collect(Object[] output) { - if (output == null) { - throw new GeaFlowDSLException("UDTF's output must not null, " - + "Please check your UDTF's logic"); - } - this.collector.add(output); - } + public UDTF() { + this.collector = Lists.newArrayList(); + } - public List getCollectData() { - ImmutableList values = ImmutableList.copyOf(collector); - collector.clear(); - return values; + /** Collect the result. */ + protected void collect(Object[] output) { + if (output == null) { + throw new GeaFlowDSLException( + "UDTF's output must not null, " + "Please check your UDTF's logic"); } - - /** - * Returns type output types for the function. - * - * @param paramTypes The parameter types of the function. - * @param outFieldNames The output fields of the function in the sql. - */ - public abstract List> getReturnType(List> paramTypes, List outFieldNames); + this.collector.add(output); + } + + public List getCollectData() { + ImmutableList values = ImmutableList.copyOf(collector); + collector.clear(); + return values; + } + + /** + * Returns type output types for the function. + * + * @param paramTypes The parameter types of the function. + * @param outFieldNames The output fields of the function in the sql. + */ + public abstract List> getReturnType( + List> paramTypes, List outFieldNames); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/UserDefinedFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/UserDefinedFunction.java index 9d513b8b2..e9a1e814d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/UserDefinedFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/function/UserDefinedFunction.java @@ -23,15 +23,9 @@ public abstract class UserDefinedFunction implements Serializable { - /** - * Init method for the user defined function. - */ - public void open(FunctionContext context) { - } + /** Init method for the user defined function. */ + public void open(FunctionContext context) {} - /** - * Close method for the user defined function. - */ - public void close() { - } + /** Close method for the user defined function. */ + public void close() {} } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/pushdown/EnablePartitionPushDown.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/pushdown/EnablePartitionPushDown.java index a10053de7..ed8b50b73 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/pushdown/EnablePartitionPushDown.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/pushdown/EnablePartitionPushDown.java @@ -21,5 +21,5 @@ public interface EnablePartitionPushDown { - void setPartitionFilter(PartitionFilter partitionFilter); + void setPartitionFilter(PartitionFilter partitionFilter); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/pushdown/PartitionFilter.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/pushdown/PartitionFilter.java index fc13ef7ec..f77ba570c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/pushdown/PartitionFilter.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/pushdown/PartitionFilter.java @@ -20,9 +20,10 @@ package org.apache.geaflow.dsl.common.pushdown; import java.io.Serializable; + import org.apache.geaflow.dsl.common.data.Row; public interface PartitionFilter extends Serializable { - boolean apply(Row partition); + boolean apply(Row partition); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/ArrayType.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/ArrayType.java index 46e4d5485..d830531e8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/ArrayType.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/ArrayType.java @@ -21,102 +21,103 @@ import java.lang.reflect.Array; import java.util.Objects; + import org.apache.geaflow.common.serialize.SerializerFactory; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; public class ArrayType implements IType { - private final IType componentType; + private final IType componentType; - public ArrayType(IType componentType) { - this.componentType = Objects.requireNonNull(componentType); - } + public ArrayType(IType componentType) { + this.componentType = Objects.requireNonNull(componentType); + } - public IType getComponentType() { - return componentType; - } + public IType getComponentType() { + return componentType; + } - @Override - public String toString() { - return "Array<" + componentType + ">"; - } + @Override + public String toString() { + return "Array<" + componentType + ">"; + } - @Override - public String getName() { - return Types.TYPE_NAME_ARRAY; - } + @Override + public String getName() { + return Types.TYPE_NAME_ARRAY; + } - @SuppressWarnings("unchecked") - @Override - public Class getTypeClass() { - return (Class) Array.newInstance(componentType.getTypeClass(), 0).getClass(); - } + @SuppressWarnings("unchecked") + @Override + public Class getTypeClass() { + return (Class) Array.newInstance(componentType.getTypeClass(), 0).getClass(); + } - @Override - public byte[] serialize(Object[] obj) { - return SerializerFactory.getKryoSerializer().serialize(obj); - } + @Override + public byte[] serialize(Object[] obj) { + return SerializerFactory.getKryoSerializer().serialize(obj); + } - @Override - public Object[] deserialize(byte[] bytes) { - return (Object[]) SerializerFactory.getKryoSerializer().deserialize(bytes); - } + @Override + public Object[] deserialize(byte[] bytes) { + return (Object[]) SerializerFactory.getKryoSerializer().deserialize(bytes); + } - @SuppressWarnings("unchecked") - @Override - public int compare(Object[] x, Object[] y) { - if (null == x) { - return y == null ? 0 : -1; - } else if (y == null) { - return 1; - } - int i; - for (i = 0; i < x.length && i < y.length; i++) { - Comparable cx = (Comparable) x[i]; - Comparable cy = (Comparable) y[i]; - if (cx == null && cy != null) { - return -1; - } - if (cx != null && cy == null) { - return 1; - } - if (cx == cy) { - return 0; - } - int c = cx.compareTo(cy); - if (c != 0) { - return c; - } - } - if (x.length > i) { - return 1; - } - if (y.length > i) { - return -1; - } + @SuppressWarnings("unchecked") + @Override + public int compare(Object[] x, Object[] y) { + if (null == x) { + return y == null ? 0 : -1; + } else if (y == null) { + return 1; + } + int i; + for (i = 0; i < x.length && i < y.length; i++) { + Comparable cx = (Comparable) x[i]; + Comparable cy = (Comparable) y[i]; + if (cx == null && cy != null) { + return -1; + } + if (cx != null && cy == null) { + return 1; + } + if (cx == cy) { return 0; + } + int c = cx.compareTo(cy); + if (c != 0) { + return c; + } } - - @Override - public boolean isPrimitive() { - return false; + if (x.length > i) { + return 1; } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof ArrayType)) { - return false; - } - ArrayType arrayType = (ArrayType) o; - return Objects.equals(componentType, arrayType.componentType); + if (y.length > i) { + return -1; } + return 0; + } - @Override - public int hashCode() { - return Objects.hash(componentType); + @Override + public boolean isPrimitive() { + return false; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ArrayType)) { + return false; } + ArrayType arrayType = (ArrayType) o; + return Objects.equals(componentType, arrayType.componentType); + } + + @Override + public int hashCode() { + return Objects.hash(componentType); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/ClassType.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/ClassType.java index 691d2d2b7..16575d2a1 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/ClassType.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/ClassType.java @@ -25,44 +25,44 @@ public class ClassType implements IType { - public static final ClassType INSTANCE = new ClassType(); + public static final ClassType INSTANCE = new ClassType(); - @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append("{"); - sb.append("type:Class"); - sb.append("}"); - return sb.toString(); - } + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("{"); + sb.append("type:Class"); + sb.append("}"); + return sb.toString(); + } - @Override - public String getName() { - return Types.TYPE_NAME_CLASS; - } + @Override + public String getName() { + return Types.TYPE_NAME_CLASS; + } - @Override - public Class getTypeClass() { - return Class.class; - } + @Override + public Class getTypeClass() { + return Class.class; + } - @Override - public byte[] serialize(Class obj) { - return SerializerFactory.getKryoSerializer().serialize(obj); - } + @Override + public byte[] serialize(Class obj) { + return SerializerFactory.getKryoSerializer().serialize(obj); + } - @Override - public Class deserialize(byte[] bytes) { - return (Class) SerializerFactory.getKryoSerializer().deserialize(bytes); - } + @Override + public Class deserialize(byte[] bytes) { + return (Class) SerializerFactory.getKryoSerializer().deserialize(bytes); + } - @Override - public int compare(Class x, Class y) { - return x.toString().compareTo(y.toString()); - } + @Override + public int compare(Class x, Class y) { + return x.toString().compareTo(y.toString()); + } - @Override - public boolean isPrimitive() { - return false; - } + @Override + public boolean isPrimitive() { + return false; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/EdgeType.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/EdgeType.java index 10c474583..35e319754 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/EdgeType.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/EdgeType.java @@ -19,140 +19,142 @@ package org.apache.geaflow.dsl.common.types; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.RowEdge; -public class EdgeType extends StructType { +import com.google.common.collect.Lists; - public static final int SRC_ID_FIELD_POSITION = 0; +public class EdgeType extends StructType { - public static final int TARGET_ID_FIELD_POSITION = 1; + public static final int SRC_ID_FIELD_POSITION = 0; - public static final int LABEL_FIELD_POSITION = 2; + public static final int TARGET_ID_FIELD_POSITION = 1; - public static final int TIME_FIELD_POSITION = 3; + public static final int LABEL_FIELD_POSITION = 2; - private final boolean hasTimestamp; + public static final int TIME_FIELD_POSITION = 3; - private static final int NUM_META_FIELDS_WITHOUT_TS = 3; + private final boolean hasTimestamp; - private static final int NUM_META_FIELDS_WITH_TS = 4; + private static final int NUM_META_FIELDS_WITHOUT_TS = 3; - public static final String DEFAULT_SRC_ID_NAME = "~srcId"; + private static final int NUM_META_FIELDS_WITH_TS = 4; - public static final String DEFAULT_TARGET_ID_NAME = "~targetId"; + public static final String DEFAULT_SRC_ID_NAME = "~srcId"; - public static final String DEFAULT_LABEL_NAME = "~label"; + public static final String DEFAULT_TARGET_ID_NAME = "~targetId"; - public static final String DEFAULT_TS_NAME = "~ts"; + public static final String DEFAULT_LABEL_NAME = "~label"; - public EdgeType(List fields, boolean hasTimestamp) { - super(fields); - this.hasTimestamp = hasTimestamp; - } + public static final String DEFAULT_TS_NAME = "~ts"; - @Override - public String getName() { - return Types.TYPE_NAME_EDGE; - } + public EdgeType(List fields, boolean hasTimestamp) { + super(fields); + this.hasTimestamp = hasTimestamp; + } - @SuppressWarnings("unchecked") - @Override - public Class getTypeClass() { - return (Class) RowEdge.class; - } + @Override + public String getName() { + return Types.TYPE_NAME_EDGE; + } - public TableField getSrcId() { - return fields.get(SRC_ID_FIELD_POSITION); - } + @SuppressWarnings("unchecked") + @Override + public Class getTypeClass() { + return (Class) RowEdge.class; + } - public TableField getTargetId() { - return fields.get(TARGET_ID_FIELD_POSITION); - } + public TableField getSrcId() { + return fields.get(SRC_ID_FIELD_POSITION); + } - public TableField getLabel() { - return fields.get(LABEL_FIELD_POSITION); - } + public TableField getTargetId() { + return fields.get(TARGET_ID_FIELD_POSITION); + } - public Optional getTimestamp() { - if (hasTimestamp) { - return Optional.of(fields.get(TIME_FIELD_POSITION)); - } - return Optional.empty(); - } + public TableField getLabel() { + return fields.get(LABEL_FIELD_POSITION); + } - public int getValueSize() { - return size() - getValueOffset(); + public Optional getTimestamp() { + if (hasTimestamp) { + return Optional.of(fields.get(TIME_FIELD_POSITION)); } + return Optional.empty(); + } - public int getValueOffset() { - if (hasTimestamp) { - return NUM_META_FIELDS_WITH_TS; - } - return NUM_META_FIELDS_WITHOUT_TS; - } + public int getValueSize() { + return size() - getValueOffset(); + } - public List getValueFields() { - return getFields().subList(getValueOffset(), size()); + public int getValueOffset() { + if (hasTimestamp) { + return NUM_META_FIELDS_WITH_TS; } - - public IType[] getValueTypes() { - IType[] valueTypes = new IType[size() - getValueOffset()]; - List valueFields = getValueFields(); - for (int i = 0; i < valueFields.size(); i++) { - valueTypes[i] = valueFields.get(i).getType(); - } - return valueTypes; + return NUM_META_FIELDS_WITHOUT_TS; + } + + public List getValueFields() { + return getFields().subList(getValueOffset(), size()); + } + + public IType[] getValueTypes() { + IType[] valueTypes = new IType[size() - getValueOffset()]; + List valueFields = getValueFields(); + for (int i = 0; i < valueFields.size(); i++) { + valueTypes[i] = valueFields.get(i).getType(); } - - public static EdgeType emptyEdge(IType idType) { - TableField srcField = new TableField(DEFAULT_SRC_ID_NAME, idType, false); - TableField targetField = new TableField(DEFAULT_TARGET_ID_NAME, idType, false); - TableField labelField = new TableField(GraphSchema.LABEL_FIELD_NAME, Types.STRING, false); - return new EdgeType(Lists.newArrayList(srcField, targetField, labelField), false); + return valueTypes; + } + + public static EdgeType emptyEdge(IType idType) { + TableField srcField = new TableField(DEFAULT_SRC_ID_NAME, idType, false); + TableField targetField = new TableField(DEFAULT_TARGET_ID_NAME, idType, false); + TableField labelField = new TableField(GraphSchema.LABEL_FIELD_NAME, Types.STRING, false); + return new EdgeType(Lists.newArrayList(srcField, targetField, labelField), false); + } + + @Override + public EdgeType merge(StructType other) { + assert other instanceof EdgeType : "EdgeType should merge with edge type"; + assert hasTimestamp && ((EdgeType) other).hasTimestamp + || !hasTimestamp && !((EdgeType) other).hasTimestamp + : "Cannot merge different edge type"; + Map> name2Types = new HashMap<>(); + for (TableField field : this.fields) { + name2Types.put(field.getName(), field.getType()); } - - @Override - public EdgeType merge(StructType other) { - assert other instanceof EdgeType : "EdgeType should merge with edge type"; - assert hasTimestamp && ((EdgeType) other).hasTimestamp - || !hasTimestamp && !((EdgeType) other).hasTimestamp : "Cannot merge different edge type"; - Map> name2Types = new HashMap<>(); - for (TableField field : this.fields) { - name2Types.put(field.getName(), field.getType()); - } - List mergedFields = new ArrayList<>(this.fields); - for (TableField field : other.fields) { - if (name2Types.containsKey(field.getName())) { - if (!name2Types.get(field.getName()).equals(field.getType())) { - throw new IllegalArgumentException("Fail to merge edge schema"); - } - } else { - mergedFields.add(field); - } + List mergedFields = new ArrayList<>(this.fields); + for (TableField field : other.fields) { + if (name2Types.containsKey(field.getName())) { + if (!name2Types.get(field.getName()).equals(field.getType())) { + throw new IllegalArgumentException("Fail to merge edge schema"); } - return new EdgeType(mergedFields, hasTimestamp); + } else { + mergedFields.add(field); + } } + return new EdgeType(mergedFields, hasTimestamp); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof EdgeType)) { - return false; - } - EdgeType that = (EdgeType) o; - return Objects.equals(fields, that.fields); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - + if (!(o instanceof EdgeType)) { + return false; + } + EdgeType that = (EdgeType) o; + return Objects.equals(fields, that.fields); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/GraphSchema.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/GraphSchema.java index f9ac329ab..0baac5cc4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/GraphSchema.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/GraphSchema.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Objects; import java.util.stream.Collectors; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.data.Row; @@ -31,117 +32,117 @@ public class GraphSchema extends StructType { - public static final String LABEL_FIELD_NAME = "~label"; + public static final String LABEL_FIELD_NAME = "~label"; - private final String graphName; + private final String graphName; - public GraphSchema(String graphName, List fields) { - super(fields); - this.graphName = Objects.requireNonNull(graphName); - } + public GraphSchema(String graphName, List fields) { + super(fields); + this.graphName = Objects.requireNonNull(graphName); + } - @Override - public String getName() { - return Types.TYPE_NAME_GRAPH; - } + @Override + public String getName() { + return Types.TYPE_NAME_GRAPH; + } - public String getGraphName() { - return graphName; - } + public String getGraphName() { + return graphName; + } - public VertexType getVertex(String name) { - return (VertexType) getField(name).getType(); - } + public VertexType getVertex(String name) { + return (VertexType) getField(name).getType(); + } - public EdgeType getEdge(String name) { - return (EdgeType) getField(name).getType(); - } + public EdgeType getEdge(String name) { + return (EdgeType) getField(name).getType(); + } - public List getVertices() { - List vertexTypes = new ArrayList<>(); - for (TableField field : fields) { - if (field.getType() instanceof VertexType) { - vertexTypes.add((VertexType) field.getType()); - } - } - return vertexTypes; + public List getVertices() { + List vertexTypes = new ArrayList<>(); + for (TableField field : fields) { + if (field.getType() instanceof VertexType) { + vertexTypes.add((VertexType) field.getType()); + } } - - public List getEdges() { - List edgeTypes = new ArrayList<>(); - for (TableField field : fields) { - if (field.getType() instanceof EdgeType) { - edgeTypes.add((EdgeType) field.getType()); - } - } - return edgeTypes; + return vertexTypes; + } + + public List getEdges() { + List edgeTypes = new ArrayList<>(); + for (TableField field : fields) { + if (field.getType() instanceof EdgeType) { + edgeTypes.add((EdgeType) field.getType()); + } } - - public IType getIdType() { - List vertexTypes = getVertices(); - assert vertexTypes.size() > 0 : "Empty graph"; - return vertexTypes.get(0).getId().getType(); + return edgeTypes; + } + + public IType getIdType() { + List vertexTypes = getVertices(); + assert vertexTypes.size() > 0 : "Empty graph"; + return vertexTypes.get(0).getId().getType(); + } + + public List getAddingFields(GraphSchema baseSchema) { + if (baseSchema.equals(this)) { + return Collections.emptyList(); } - - public List getAddingFields(GraphSchema baseSchema) { - if (baseSchema.equals(this)) { - return Collections.emptyList(); - } - // the global vertex variable is applied to all the vertex tables. - // so any of the vertex table has the same adding fields. - VertexType vertexType = getVertices().get(0); - VertexType baseVertex = baseSchema.getVertices().get(0); - assert vertexType != null && baseVertex != null; - return vertexType.getAddingFields(baseVertex); + // the global vertex variable is applied to all the vertex tables. + // so any of the vertex table has the same adding fields. + VertexType vertexType = getVertices().get(0); + VertexType baseVertex = baseSchema.getVertices().get(0); + assert vertexType != null && baseVertex != null; + return vertexType.getAddingFields(baseVertex); + } + + @Override + public int compare(Row x, Row y) { + throw new GeaFlowDSLException("Illegal call."); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - - @Override - public int compare(Row x, Row y) { - throw new GeaFlowDSLException("Illegal call."); + if (!(o instanceof GraphSchema)) { + return false; } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof GraphSchema)) { - return false; - } - if (!super.equals(o)) { - return false; - } - GraphSchema that = (GraphSchema) o; - return Objects.equals(graphName, that.graphName); + if (!super.equals(o)) { + return false; } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), graphName); - } - - @Override - public GraphSchema merge(StructType other) { - assert other instanceof GraphSchema : "GraphSchema should merge with graph schema"; - assert graphName.equals(((GraphSchema) other).graphName) : "Cannot merge with different graph schema"; - - List mergedFields = new ArrayList<>(fields); - List mergedFieldNames = fields.stream().map(TableField::getName) - .collect(Collectors.toList()); - for (TableField field : other.fields) { - int index = mergedFieldNames.indexOf(field.getName()); - if (index >= 0) { - StructType thisType = (StructType) mergedFields.get(index).getType(); - StructType thatType = (StructType) field.getType(); - StructType mergedType = thisType.merge(thatType); - - mergedFields.set(index, field.copy(mergedType)); - } else { - mergedFields.add(field); - mergedFieldNames.add(field.getName()); - } - } - return new GraphSchema(graphName, mergedFields); + GraphSchema that = (GraphSchema) o; + return Objects.equals(graphName, that.graphName); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), graphName); + } + + @Override + public GraphSchema merge(StructType other) { + assert other instanceof GraphSchema : "GraphSchema should merge with graph schema"; + assert graphName.equals(((GraphSchema) other).graphName) + : "Cannot merge with different graph schema"; + + List mergedFields = new ArrayList<>(fields); + List mergedFieldNames = + fields.stream().map(TableField::getName).collect(Collectors.toList()); + for (TableField field : other.fields) { + int index = mergedFieldNames.indexOf(field.getName()); + if (index >= 0) { + StructType thisType = (StructType) mergedFields.get(index).getType(); + StructType thatType = (StructType) field.getType(); + StructType mergedType = thisType.merge(thatType); + + mergedFields.set(index, field.copy(mergedType)); + } else { + mergedFields.add(field); + mergedFieldNames.add(field.getName()); + } } + return new GraphSchema(graphName, mergedFields); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/ObjectType.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/ObjectType.java index e9ac1be88..e8db8e805 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/ObjectType.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/ObjectType.java @@ -25,42 +25,40 @@ public class ObjectType implements IType { - public static ObjectType INSTANCE = new ObjectType(); + public static ObjectType INSTANCE = new ObjectType(); - private ObjectType() { + private ObjectType() {} - } + @Override + public String getName() { + return Types.TYPE_NAME_OBJECT; + } - @Override - public String getName() { - return Types.TYPE_NAME_OBJECT; - } + @Override + public Class getTypeClass() { + return Object.class; + } - @Override - public Class getTypeClass() { - return Object.class; - } + @Override + public byte[] serialize(Object obj) { + return SerializerFactory.getKryoSerializer().serialize(obj); + } - @Override - public byte[] serialize(Object obj) { - return SerializerFactory.getKryoSerializer().serialize(obj); - } + @Override + public Object deserialize(byte[] bytes) { + return SerializerFactory.getKryoSerializer().deserialize(bytes); + } - @Override - public Object deserialize(byte[] bytes) { - return SerializerFactory.getKryoSerializer().deserialize(bytes); + @Override + public int compare(Object x, Object y) { + if (x instanceof Comparable && y instanceof Comparable) { + return ((Comparable) x).compareTo(y); } + return 0; + } - @Override - public int compare(Object x, Object y) { - if (x instanceof Comparable && y instanceof Comparable) { - return ((Comparable) x).compareTo(y); - } - return 0; - } - - @Override - public boolean isPrimitive() { - return true; - } + @Override + public boolean isPrimitive() { + return true; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/PathType.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/PathType.java index abbf757b1..4cc5461bf 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/PathType.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/PathType.java @@ -19,7 +19,6 @@ package org.apache.geaflow.dsl.common.types; -import com.google.common.collect.Sets; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; @@ -28,122 +27,123 @@ import java.util.Set; import java.util.function.Predicate; import java.util.stream.Collectors; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.Row; -public class PathType extends StructType { - - public static final PathType EMPTY = new PathType(); - - public PathType(List fields) { - super(fields); - } +import com.google.common.collect.Sets; - public PathType(TableField... fields) { - super(fields); - } +public class PathType extends StructType { - @Override - public PathType addField(TableField field) { - List newFields = new ArrayList<>(fields); - newFields.add(field); - return new PathType(newFields); - } + public static final PathType EMPTY = new PathType(); - @Override - public PathType replace(String name, TableField newField) { - int index = indexOf(name); - if (index == -1) { - throw new IllegalArgumentException("Field: '" + name + "' is not exist on the path"); - } - List newFields = new ArrayList<>(fields); - newFields.set(index, newField); - return new PathType(newFields); - } + public PathType(List fields) { + super(fields); + } - public PathType filter(Predicate predicate) { - List filterFields = fields.stream() - .filter(predicate) - .collect(Collectors.toList()); - return new PathType(filterFields); - } + public PathType(TableField... fields) { + super(fields); + } - @Override - public String getName() { - return Types.TYPE_NAME_PATH; - } + @Override + public PathType addField(TableField field) { + List newFields = new ArrayList<>(fields); + newFields.add(field); + return new PathType(newFields); + } - @SuppressWarnings("unchecked") - @Override - public Class getTypeClass() { - return (Class) Path.class; + @Override + public PathType replace(String name, TableField newField) { + int index = indexOf(name); + if (index == -1) { + throw new IllegalArgumentException("Field: '" + name + "' is not exist on the path"); } - - public Set getCommonFieldNames(PathType other) { - Set fieldNames = new HashSet<>(getFieldNames()); - Set otherFieldNames = new HashSet<>(other.getFieldNames()); - return Sets.intersection(fieldNames, otherFieldNames); + List newFields = new ArrayList<>(fields); + newFields.set(index, newField); + return new PathType(newFields); + } + + public PathType filter(Predicate predicate) { + List filterFields = fields.stream().filter(predicate).collect(Collectors.toList()); + return new PathType(filterFields); + } + + @Override + public String getName() { + return Types.TYPE_NAME_PATH; + } + + @SuppressWarnings("unchecked") + @Override + public Class getTypeClass() { + return (Class) Path.class; + } + + public Set getCommonFieldNames(PathType other) { + Set fieldNames = new HashSet<>(getFieldNames()); + Set otherFieldNames = new HashSet<>(other.getFieldNames()); + return Sets.intersection(fieldNames, otherFieldNames); + } + + public PathType join(PathType right) { + List joinFields = new ArrayList<>(fields); + + Map nameCount = new HashMap<>(); + for (TableField rightField : right.fields) { + if (joinFields.contains(rightField)) { + int cnt = nameCount.getOrDefault(rightField.getName(), 0); + String newName = rightField.getName() + cnt; + joinFields.add(rightField.copy(newName)); + nameCount.put(rightField.getName(), cnt + 1); + } else { + joinFields.add(rightField); + } } - - public PathType join(PathType right) { - List joinFields = new ArrayList<>(fields); - - Map nameCount = new HashMap<>(); - for (TableField rightField : right.fields) { - if (joinFields.contains(rightField)) { - int cnt = nameCount.getOrDefault(rightField.getName(), 0); - String newName = rightField.getName() + cnt; - joinFields.add(rightField.copy(newName)); - nameCount.put(rightField.getName(), cnt + 1); - } else { - joinFields.add(rightField); - } + return new PathType(joinFields); + } + + public PathType subPath(int from, int size) { + List subFields = fields.subList(from, from + size); + return new PathType(subFields); + } + + @Override + public int compare(Row a, Row b) { + if (null == a) { + return b == null ? 0 : -1; + } else if (b == null) { + return 1; + } else { + for (int i = 0; i < fields.size(); i++) { + IType type = fields.get(i).getType(); + int comparator = ((IType) type).compare(a.getField(i, type), b.getField(i, type)); + if (comparator != 0) { + return comparator; } - return new PathType(joinFields); + } + return 0; } - - public PathType subPath(int from, int size) { - List subFields = fields.subList(from, from + size); - return new PathType(subFields); + } + + @Override + public PathType merge(StructType other) { + assert other instanceof PathType : "PathType should merge with path type"; + Map> name2Types = new HashMap<>(); + for (TableField field : this.fields) { + name2Types.put(field.getName(), field.getType()); } - - @Override - public int compare(Row a, Row b) { - if (null == a) { - return b == null ? 0 : -1; - } else if (b == null) { - return 1; - } else { - for (int i = 0; i < fields.size(); i++) { - IType type = fields.get(i).getType(); - int comparator = ((IType) type).compare(a.getField(i, type), b.getField(i, type)); - if (comparator != 0) { - return comparator; - } - } - return 0; - } - } - - @Override - public PathType merge(StructType other) { - assert other instanceof PathType : "PathType should merge with path type"; - Map> name2Types = new HashMap<>(); - for (TableField field : this.fields) { - name2Types.put(field.getName(), field.getType()); - } - List mergedFields = new ArrayList<>(this.fields); - for (TableField field : other.fields) { - if (name2Types.containsKey(field.getName())) { - if (!name2Types.get(field.getName()).equals(field.getType())) { - throw new IllegalArgumentException("Fail to merge path schema"); - } - } else { - mergedFields.add(field); - } + List mergedFields = new ArrayList<>(this.fields); + for (TableField field : other.fields) { + if (name2Types.containsKey(field.getName())) { + if (!name2Types.get(field.getName()).equals(field.getType())) { + throw new IllegalArgumentException("Fail to merge path schema"); } - return new PathType(mergedFields); + } else { + mergedFields.add(field); + } } + return new PathType(mergedFields); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/StructType.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/StructType.java index cf31efea0..5573ceccd 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/StructType.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/StructType.java @@ -19,8 +19,6 @@ package org.apache.geaflow.dsl.common.types; -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -28,208 +26,217 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; + import org.apache.geaflow.common.serialize.SerializerFactory; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.data.Row; -public class StructType implements IType { - - protected final List fields; - - public StructType(List fields) { - this.fields = ImmutableList.copyOf(fields); - Set names = fields.stream().map(TableField::getName).collect(Collectors.toSet()); - Preconditions.checkArgument(names.size() == fields.size(), - "Duplicate fields found for struct type."); - } - - public StructType(TableField... fields) { - this(ImmutableList.copyOf(fields)); - } - - public static StructType singleValue(IType valueType, boolean nullable) { - return new StructType(new TableField("col_0", valueType, nullable)); - } - - public int indexOf(String name) { - for (int i = 0; i < fields.size(); i++) { - if (fields.get(i).getName().equalsIgnoreCase(name)) { - return i; - } - } - return -1; - } - - public boolean contain(String name) { - return indexOf(name) != -1; - } - - public TableField getField(int i) { - if (i >= 0 && i <= fields.size() - 1) { - return fields.get(i); - } - throw new IndexOutOfBoundsException("Index: " + i + ", size: " + fields.size()); - } - - public TableField getField(String name) { - int index = indexOf(name); - if (index == -1) { - throw new IllegalArgumentException("Field: '" + name + "' is not exist"); - } - return fields.get(index); - } - - public List getFields() { - return fields; - } - - public IType getType(int i) { - return getField(i).getType(); - } - - public IType[] getTypes() { - IType[] types = new IType[fields.size()]; - for (int i = 0; i < fields.size(); i++) { - types[i] = fields.get(i).getType(); - } - return types; - } - - public StructType addField(TableField field) { - List newFields = new ArrayList<>(fields); - newFields.add(field); - return new StructType(newFields); - } - - public StructType replace(String name, TableField newField) { - int index = indexOf(name); - if (index == -1) { - throw new IllegalArgumentException("Field: '" + name + "' is not exist"); - } - List newFields = new ArrayList<>(fields); - newFields.set(index, newField); - return new StructType(newFields); - } - - public StructType dropRight(int size) { - return new StructType(fields.subList(0, fields.size() - size)); - } - - public List getFieldNames() { - return fields.stream().map(TableField::getName).collect(Collectors.toList()); - } - - @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append("{"); - boolean first = true; - for (TableField field : fields) { - if (!first) { - sb.append(","); - } - sb.append(field.getName()).append(":{type:").append(field.getType()).append(",") - .append("nullable:").append(field.isNullable()).append("}"); - first = false; - } - sb.append("}"); - return sb.toString(); - } - - public int size() { - return fields.size(); - } - - @Override - public String getName() { - return Types.TYPE_NAME_STRUCT; - } - - @Override - public Class getTypeClass() { - return Row.class; - } - - @Override - public byte[] serialize(Row obj) { - return SerializerFactory.getKryoSerializer().serialize(obj); - } - - @Override - public Row deserialize(byte[] bytes) { - return (Row) SerializerFactory.getKryoSerializer().deserialize(bytes); - } - - @Override - public int compare(Row a, Row b) { - if (null == a) { - return b == null ? 0 : -1; - } else if (b == null) { - return 1; - } else { - for (int i = 0; i < fields.size(); i++) { - IType type = fields.get(i).getType(); - int comparator = ((IType) type).compare(a.getField(i, type), b.getField(i, type)); - if (comparator != 0) { - return comparator; - } - } - return 0; - } - } - - @Override - public boolean isPrimitive() { - return false; - } - - public StructType merge(StructType other) { - Map> name2Types = new HashMap<>(); - for (TableField field : this.fields) { - name2Types.put(field.getName(), field.getType()); - } - List mergedFields = new ArrayList<>(this.fields); - for (TableField field : other.fields) { - if (name2Types.containsKey(field.getName())) { - if (!name2Types.get(field.getName()).equals(field.getType())) { - throw new IllegalArgumentException("Fail to merge vertex schema"); - } - } else { - mergedFields.add(field); - } - } - return new StructType(mergedFields); - } +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; - public List getAddingFields(StructType baseType) { - Map baseFields = new HashMap<>(); - for (TableField field : baseType.fields) { - baseFields.put(field.getName(), field); - } - List addingFields = new ArrayList<>(); - for (TableField field : fields) { - if (!baseFields.containsKey(field.getName())) { - addingFields.add(field); - } - } - return addingFields; - } +public class StructType implements IType { - @Override - public boolean equals(Object o) { - if (this == o) { - return true; + protected final List fields; + + public StructType(List fields) { + this.fields = ImmutableList.copyOf(fields); + Set names = fields.stream().map(TableField::getName).collect(Collectors.toSet()); + Preconditions.checkArgument( + names.size() == fields.size(), "Duplicate fields found for struct type."); + } + + public StructType(TableField... fields) { + this(ImmutableList.copyOf(fields)); + } + + public static StructType singleValue(IType valueType, boolean nullable) { + return new StructType(new TableField("col_0", valueType, nullable)); + } + + public int indexOf(String name) { + for (int i = 0; i < fields.size(); i++) { + if (fields.get(i).getName().equalsIgnoreCase(name)) { + return i; + } + } + return -1; + } + + public boolean contain(String name) { + return indexOf(name) != -1; + } + + public TableField getField(int i) { + if (i >= 0 && i <= fields.size() - 1) { + return fields.get(i); + } + throw new IndexOutOfBoundsException("Index: " + i + ", size: " + fields.size()); + } + + public TableField getField(String name) { + int index = indexOf(name); + if (index == -1) { + throw new IllegalArgumentException("Field: '" + name + "' is not exist"); + } + return fields.get(index); + } + + public List getFields() { + return fields; + } + + public IType getType(int i) { + return getField(i).getType(); + } + + public IType[] getTypes() { + IType[] types = new IType[fields.size()]; + for (int i = 0; i < fields.size(); i++) { + types[i] = fields.get(i).getType(); + } + return types; + } + + public StructType addField(TableField field) { + List newFields = new ArrayList<>(fields); + newFields.add(field); + return new StructType(newFields); + } + + public StructType replace(String name, TableField newField) { + int index = indexOf(name); + if (index == -1) { + throw new IllegalArgumentException("Field: '" + name + "' is not exist"); + } + List newFields = new ArrayList<>(fields); + newFields.set(index, newField); + return new StructType(newFields); + } + + public StructType dropRight(int size) { + return new StructType(fields.subList(0, fields.size() - size)); + } + + public List getFieldNames() { + return fields.stream().map(TableField::getName).collect(Collectors.toList()); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("{"); + boolean first = true; + for (TableField field : fields) { + if (!first) { + sb.append(","); + } + sb.append(field.getName()) + .append(":{type:") + .append(field.getType()) + .append(",") + .append("nullable:") + .append(field.isNullable()) + .append("}"); + first = false; + } + sb.append("}"); + return sb.toString(); + } + + public int size() { + return fields.size(); + } + + @Override + public String getName() { + return Types.TYPE_NAME_STRUCT; + } + + @Override + public Class getTypeClass() { + return Row.class; + } + + @Override + public byte[] serialize(Row obj) { + return SerializerFactory.getKryoSerializer().serialize(obj); + } + + @Override + public Row deserialize(byte[] bytes) { + return (Row) SerializerFactory.getKryoSerializer().deserialize(bytes); + } + + @Override + public int compare(Row a, Row b) { + if (null == a) { + return b == null ? 0 : -1; + } else if (b == null) { + return 1; + } else { + for (int i = 0; i < fields.size(); i++) { + IType type = fields.get(i).getType(); + int comparator = ((IType) type).compare(a.getField(i, type), b.getField(i, type)); + if (comparator != 0) { + return comparator; } - if (!(o instanceof StructType)) { - return false; + } + return 0; + } + } + + @Override + public boolean isPrimitive() { + return false; + } + + public StructType merge(StructType other) { + Map> name2Types = new HashMap<>(); + for (TableField field : this.fields) { + name2Types.put(field.getName(), field.getType()); + } + List mergedFields = new ArrayList<>(this.fields); + for (TableField field : other.fields) { + if (name2Types.containsKey(field.getName())) { + if (!name2Types.get(field.getName()).equals(field.getType())) { + throw new IllegalArgumentException("Fail to merge vertex schema"); } - StructType that = (StructType) o; - return Objects.equals(fields, that.fields); - } - - @Override - public int hashCode() { - return Objects.hash(fields); - } + } else { + mergedFields.add(field); + } + } + return new StructType(mergedFields); + } + + public List getAddingFields(StructType baseType) { + Map baseFields = new HashMap<>(); + for (TableField field : baseType.fields) { + baseFields.put(field.getName(), field); + } + List addingFields = new ArrayList<>(); + for (TableField field : fields) { + if (!baseFields.containsKey(field.getName())) { + addingFields.add(field); + } + } + return addingFields; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof StructType)) { + return false; + } + StructType that = (StructType) o; + return Objects.equals(fields, that.fields); + } + + @Override + public int hashCode() { + return Objects.hash(fields); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/TableField.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/TableField.java index c919f18bb..53544e5c7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/TableField.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/TableField.java @@ -21,70 +21,76 @@ import java.io.Serializable; import java.util.Objects; + import org.apache.geaflow.common.type.IType; public class TableField implements Serializable { - private final String name; + private final String name; - private final IType type; + private final IType type; - private final boolean nullable; + private final boolean nullable; - public TableField(String name, IType type, boolean nullable) { - this.name = Objects.requireNonNull(name); - this.type = Objects.requireNonNull(type); - this.nullable = nullable; - } + public TableField(String name, IType type, boolean nullable) { + this.name = Objects.requireNonNull(name); + this.type = Objects.requireNonNull(type); + this.nullable = nullable; + } - public TableField(String name, IType type) { - this(name, type, true); - } + public TableField(String name, IType type) { + this(name, type, true); + } - public String getName() { - return name; - } + public String getName() { + return name; + } - public IType getType() { - return type; - } + public IType getType() { + return type; + } - public boolean isNullable() { - return nullable; - } + public boolean isNullable() { + return nullable; + } - public TableField copy(String name) { - return new TableField(name, type, nullable); - } + public TableField copy(String name) { + return new TableField(name, type, nullable); + } - public TableField copy(IType type) { - return new TableField(name, type, nullable); - } + public TableField copy(IType type) { + return new TableField(name, type, nullable); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof TableField)) { - return false; - } - TableField field = (TableField) o; - return nullable == field.nullable && Objects.equals(name, field.name) && Objects.equals(type, - field.type); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(name, type, nullable); - } - - @Override - public String toString() { - return "TableField{" - + "name='" + name + '\'' - + ", type=" + type - + ", nullable=" + nullable - + '}'; + if (!(o instanceof TableField)) { + return false; } + TableField field = (TableField) o; + return nullable == field.nullable + && Objects.equals(name, field.name) + && Objects.equals(type, field.type); + } + + @Override + public int hashCode() { + return Objects.hash(name, type, nullable); + } + + @Override + public String toString() { + return "TableField{" + + "name='" + + name + + '\'' + + ", type=" + + type + + ", nullable=" + + nullable + + '}'; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/TableSchema.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/TableSchema.java index c7f0bc653..db41fabed 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/TableSchema.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/TableSchema.java @@ -24,53 +24,53 @@ public class TableSchema extends StructType { - private final StructType dataSchema; - private final StructType partitionSchema; + private final StructType dataSchema; + private final StructType partitionSchema; - public TableSchema(StructType dataSchema, StructType partitionSchema) { - super(combine(dataSchema, partitionSchema)); - this.dataSchema = dataSchema; - this.partitionSchema = partitionSchema; - } + public TableSchema(StructType dataSchema, StructType partitionSchema) { + super(combine(dataSchema, partitionSchema)); + this.dataSchema = dataSchema; + this.partitionSchema = partitionSchema; + } - public TableSchema(StructType dataSchema) { - this(dataSchema, new StructType()); - } + public TableSchema(StructType dataSchema) { + this(dataSchema, new StructType()); + } - public TableSchema(List fields) { - super(fields); - this.dataSchema = new StructType(fields); - this.partitionSchema = new StructType(); - } + public TableSchema(List fields) { + super(fields); + this.dataSchema = new StructType(fields); + this.partitionSchema = new StructType(); + } - public TableSchema(TableField... fields) { - super(fields); - this.dataSchema = new StructType(fields); - this.partitionSchema = new StructType(); - } + public TableSchema(TableField... fields) { + super(fields); + this.dataSchema = new StructType(fields); + this.partitionSchema = new StructType(); + } - private static List combine(StructType dataSchema, StructType partitionSchema) { - List fields = new ArrayList<>(); - fields.addAll(dataSchema.getFields()); - fields.addAll(partitionSchema.getFields()); - return fields; - } + private static List combine(StructType dataSchema, StructType partitionSchema) { + List fields = new ArrayList<>(); + fields.addAll(dataSchema.getFields()); + fields.addAll(partitionSchema.getFields()); + return fields; + } - public StructType getDataSchema() { - return dataSchema; - } + public StructType getDataSchema() { + return dataSchema; + } - public StructType getPartitionSchema() { - return partitionSchema; - } + public StructType getPartitionSchema() { + return partitionSchema; + } - @Override - public StructType addField(TableField field) { - throw new IllegalArgumentException("addField not support"); - } + @Override + public StructType addField(TableField field) { + throw new IllegalArgumentException("addField not support"); + } - @Override - public StructType replace(String name, TableField newField) { - throw new IllegalArgumentException("replace not support"); - } + @Override + public StructType replace(String name, TableField newField) { + throw new IllegalArgumentException("replace not support"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/VertexType.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/VertexType.java index 5477cc67e..eb0e4c412 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/VertexType.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/VertexType.java @@ -19,105 +19,107 @@ package org.apache.geaflow.dsl.common.types; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.RowVertex; -public class VertexType extends StructType { +import com.google.common.collect.Lists; - public static final int ID_FIELD_POSITION = 0; +public class VertexType extends StructType { - public static final int LABEL_FIELD_POSITION = 1; + public static final int ID_FIELD_POSITION = 0; - private static final int NUM_META_FIELDS = 2; + public static final int LABEL_FIELD_POSITION = 1; - public static final String DEFAULT_ID_FIELD_NAME = "~id"; + private static final int NUM_META_FIELDS = 2; - public VertexType(List fields) { - super(fields); - } + public static final String DEFAULT_ID_FIELD_NAME = "~id"; - @Override - public String getName() { - return Types.TYPE_NAME_VERTEX; - } + public VertexType(List fields) { + super(fields); + } - @SuppressWarnings("unchecked") - @Override - public Class getTypeClass() { - return (Class) RowVertex.class; - } + @Override + public String getName() { + return Types.TYPE_NAME_VERTEX; + } - public TableField getId() { - return getField(ID_FIELD_POSITION); - } + @SuppressWarnings("unchecked") + @Override + public Class getTypeClass() { + return (Class) RowVertex.class; + } - public TableField getLabel() { - return getField(LABEL_FIELD_POSITION); - } + public TableField getId() { + return getField(ID_FIELD_POSITION); + } - public int getValueSize() { - return size() - getValueOffset(); - } + public TableField getLabel() { + return getField(LABEL_FIELD_POSITION); + } - public int getValueOffset() { - return NUM_META_FIELDS; - } + public int getValueSize() { + return size() - getValueOffset(); + } - public IType[] getValueTypes() { - IType[] valueTypes = new IType[size() - getValueOffset()]; - for (int i = getValueOffset(); i < size(); i++) { - valueTypes[i - getValueOffset()] = fields.get(i).getType(); - } - return valueTypes; - } + public int getValueOffset() { + return NUM_META_FIELDS; + } - public List getValueFields() { - return getFields().subList(getValueOffset(), size()); + public IType[] getValueTypes() { + IType[] valueTypes = new IType[size() - getValueOffset()]; + for (int i = getValueOffset(); i < size(); i++) { + valueTypes[i - getValueOffset()] = fields.get(i).getType(); } - - public static VertexType emptyVertex(IType idType) { - TableField idField = new TableField("~id", idType, false); - TableField labelField = new TableField("~label", Types.STRING, false); - return new VertexType(Lists.newArrayList(idField, labelField)); + return valueTypes; + } + + public List getValueFields() { + return getFields().subList(getValueOffset(), size()); + } + + public static VertexType emptyVertex(IType idType) { + TableField idField = new TableField("~id", idType, false); + TableField labelField = new TableField("~label", Types.STRING, false); + return new VertexType(Lists.newArrayList(idField, labelField)); + } + + @Override + public VertexType merge(StructType other) { + assert other instanceof VertexType : "VertexType should merge with vertex type"; + Map> name2Types = new HashMap<>(); + for (TableField field : this.fields) { + name2Types.put(field.getName(), field.getType()); } - - @Override - public VertexType merge(StructType other) { - assert other instanceof VertexType : "VertexType should merge with vertex type"; - Map> name2Types = new HashMap<>(); - for (TableField field : this.fields) { - name2Types.put(field.getName(), field.getType()); - } - List mergedFields = new ArrayList<>(this.fields); - for (TableField field : other.fields) { - if (name2Types.containsKey(field.getName())) { - if (!name2Types.get(field.getName()).equals(field.getType())) { - throw new IllegalArgumentException("Fail to merge vertex schema"); - } - } else { - mergedFields.add(field); - } + List mergedFields = new ArrayList<>(this.fields); + for (TableField field : other.fields) { + if (name2Types.containsKey(field.getName())) { + if (!name2Types.get(field.getName()).equals(field.getType())) { + throw new IllegalArgumentException("Fail to merge vertex schema"); } - return new VertexType(mergedFields); + } else { + mergedFields.add(field); + } } + return new VertexType(mergedFields); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof VertexType)) { - return false; - } - VertexType that = (VertexType) o; - return Objects.equals(fields, that.fields); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof VertexType)) { + return false; } + VertexType that = (VertexType) o; + return Objects.equals(fields, that.fields); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/VoidType.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/VoidType.java index 34b1b7756..55de7d52a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/VoidType.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/types/VoidType.java @@ -23,39 +23,37 @@ public class VoidType implements IType { - public static VoidType INSTANCE = new VoidType(); - - private VoidType() { - - } - - @Override - public String getName() { - return "VOID"; - } - - @Override - public Class getTypeClass() { - return Void.class; - } - - @Override - public byte[] serialize(Void obj) { - return new byte[0]; - } - - @Override - public Void deserialize(byte[] bytes) { - return null; - } - - @Override - public int compare(Void x, Void y) { - return 0; - } - - @Override - public boolean isPrimitive() { - return true; - } + public static VoidType INSTANCE = new VoidType(); + + private VoidType() {} + + @Override + public String getName() { + return "VOID"; + } + + @Override + public Class getTypeClass() { + return Void.class; + } + + @Override + public byte[] serialize(Void obj) { + return new byte[0]; + } + + @Override + public Void deserialize(byte[] bytes) { + return null; + } + + @Override + public int compare(Void x, Void y) { + return 0; + } + + @Override + public boolean isPrimitive() { + return true; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/util/BinaryUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/util/BinaryUtil.java index ab44a6b43..5c9b6428d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/util/BinaryUtil.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/util/BinaryUtil.java @@ -21,25 +21,26 @@ import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.binary.BinaryString; public class BinaryUtil { - private static final ThreadLocal> BINARY_LABEL_CACHE = - ThreadLocal.withInitial(HashMap::new); + private static final ThreadLocal> BINARY_LABEL_CACHE = + ThreadLocal.withInitial(HashMap::new); - public static Object toBinaryForString(Object o) { - if (o instanceof BinaryString) { - return o; - } - if (o instanceof String) { - return BinaryString.fromString((String) o); - } - return o; + public static Object toBinaryForString(Object o) { + if (o instanceof BinaryString) { + return o; } - - public static Object toBinaryLabel(String label) { - Map labelCache = BINARY_LABEL_CACHE.get(); - return labelCache.computeIfAbsent(label, l -> toBinaryForString(label)); + if (o instanceof String) { + return BinaryString.fromString((String) o); } + return o; + } + + public static Object toBinaryLabel(String label) { + Map labelCache = BINARY_LABEL_CACHE.get(); + return labelCache.computeIfAbsent(label, l -> toBinaryForString(label)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/util/FunctionCallUtils.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/util/FunctionCallUtils.java index 68b0d66ec..616c64975 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/util/FunctionCallUtils.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/util/FunctionCallUtils.java @@ -19,8 +19,6 @@ package org.apache.geaflow.dsl.common.util; -import com.google.common.base.Joiner; -import com.google.common.collect.Lists; import java.lang.reflect.Array; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; @@ -32,367 +30,373 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; import org.apache.geaflow.dsl.common.function.UDAF; import org.apache.geaflow.dsl.common.function.UDAFArguments; -public class FunctionCallUtils { - - public static String UDF_EVAL_METHOD_NAME = "eval"; +import com.google.common.base.Joiner; +import com.google.common.collect.Lists; - private static final Map, Class[]> TYPE_DEGREE_MAP = new HashMap<>(); +public class FunctionCallUtils { - static { - TYPE_DEGREE_MAP - .put(Integer.class, new Class[]{Long.class, Double.class, BigDecimal.class}); - TYPE_DEGREE_MAP.put(Long.class, new Class[]{Double.class, BigDecimal.class}); - TYPE_DEGREE_MAP.put(Byte.class, - new Class[]{Integer.class, Long.class, Double.class, BigDecimal.class}); - TYPE_DEGREE_MAP.put(Short.class, - new Class[]{Integer.class, Long.class, Double.class, BigDecimal.class}); - TYPE_DEGREE_MAP.put(BigDecimal.class, new Class[]{Double.class}); + public static String UDF_EVAL_METHOD_NAME = "eval"; + + private static final Map, Class[]> TYPE_DEGREE_MAP = new HashMap<>(); + + static { + TYPE_DEGREE_MAP.put(Integer.class, new Class[] {Long.class, Double.class, BigDecimal.class}); + TYPE_DEGREE_MAP.put(Long.class, new Class[] {Double.class, BigDecimal.class}); + TYPE_DEGREE_MAP.put( + Byte.class, new Class[] {Integer.class, Long.class, Double.class, BigDecimal.class}); + TYPE_DEGREE_MAP.put( + Short.class, new Class[] {Integer.class, Long.class, Double.class, BigDecimal.class}); + TYPE_DEGREE_MAP.put(BigDecimal.class, new Class[] {Double.class}); + } + + private static final Map, Class> BOX_TYPE_MAPS = new HashMap<>(); + private static final Map, Class> UNBOX_TYPE_MAPS = new HashMap<>(); + + static { + BOX_TYPE_MAPS.put(int.class, Integer.class); + BOX_TYPE_MAPS.put(long.class, Long.class); + BOX_TYPE_MAPS.put(short.class, Short.class); + BOX_TYPE_MAPS.put(byte.class, Byte.class); + BOX_TYPE_MAPS.put(boolean.class, Boolean.class); + BOX_TYPE_MAPS.put(double.class, Double.class); + + UNBOX_TYPE_MAPS.put(Integer.class, int.class); + UNBOX_TYPE_MAPS.put(Long.class, long.class); + UNBOX_TYPE_MAPS.put(Short.class, short.class); + UNBOX_TYPE_MAPS.put(Byte.class, byte.class); + UNBOX_TYPE_MAPS.put(Boolean.class, boolean.class); + UNBOX_TYPE_MAPS.put(Double.class, double.class); + } + + public static Method findMatchMethod(Class udfClass, List> paramTypes) { + return findMatchMethod(udfClass, UDF_EVAL_METHOD_NAME, paramTypes); + } + + public static Method findMatchMethod( + Class clazz, String methodName, List> paramTypes) { + List methods = getAllMethod(clazz); + double maxScore = 0d; + Method bestMatch = null; + for (Method method : methods) { + if (!method.getName().equals(methodName)) { + continue; + } + Class[] defineTypes = method.getParameterTypes(); + double score = getMatchScore(defineTypes, paramTypes.toArray(new Class[] {})); + if (score > maxScore) { + maxScore = score; + bestMatch = method; + } } - - private static final Map, Class> BOX_TYPE_MAPS = new HashMap<>(); - private static final Map, Class> UNBOX_TYPE_MAPS = new HashMap<>(); - - static { - BOX_TYPE_MAPS.put(int.class, Integer.class); - BOX_TYPE_MAPS.put(long.class, Long.class); - BOX_TYPE_MAPS.put(short.class, Short.class); - BOX_TYPE_MAPS.put(byte.class, Byte.class); - BOX_TYPE_MAPS.put(boolean.class, Boolean.class); - BOX_TYPE_MAPS.put(double.class, Double.class); - - UNBOX_TYPE_MAPS.put(Integer.class, int.class); - UNBOX_TYPE_MAPS.put(Long.class, long.class); - UNBOX_TYPE_MAPS.put(Short.class, short.class); - UNBOX_TYPE_MAPS.put(Byte.class, byte.class); - UNBOX_TYPE_MAPS.put(Boolean.class, boolean.class); - UNBOX_TYPE_MAPS.put(Double.class, double.class); + if (bestMatch != null) { + return bestMatch; } + throw new IllegalArgumentException( + "Cannot find method " + methodName + " in " + clazz + ",input paramType is " + paramTypes); + } + + private static double getMatchScore(Class[] defineTypes, Class[] callTypes) { - public static Method findMatchMethod(Class udfClass, List> paramTypes) { - return findMatchMethod(udfClass, UDF_EVAL_METHOD_NAME, paramTypes); + if (defineTypes.length == 0 && callTypes.length == 0) { + return 1; } - public static Method findMatchMethod(Class clazz, String methodName, List> paramTypes) { - List methods = getAllMethod(clazz); - double maxScore = 0d; - Method bestMatch = null; - for (Method method : methods) { - if (!method.getName().equals(methodName)) { - continue; - } - Class[] defineTypes = method.getParameterTypes(); - double score = getMatchScore(defineTypes, paramTypes.toArray(new Class[]{})); - if (score > maxScore) { - maxScore = score; - bestMatch = method; - } - } - if (bestMatch != null) { - return bestMatch; - } - throw new IllegalArgumentException("Cannot find method " + methodName + " in " + clazz - + ",input paramType is " + paramTypes); + if (defineTypes.length == 0 && callTypes.length > 0) { + return 0; } - private static double getMatchScore(Class[] defineTypes, Class[] callTypes) { + if (defineTypes.length > callTypes.length) { + return 0; + } - if (defineTypes.length == 0 && callTypes.length == 0) { - return 1; - } + // + double score = 1.0d; - if (defineTypes.length == 0 && callTypes.length > 0) { - return 0; - } - - if (defineTypes.length > callTypes.length) { + int i; + for (i = 0; i < defineTypes.length - 1; i++) { + double s = getScore(defineTypes[i], callTypes[i]); + if (s == 0) { + return 0; + } + score *= s; + if (score == 0) { + return 0; + } + } + Class lastDefineType = defineTypes[i]; + + // test whether the last is a variable parameter. + if (lastDefineType.isArray()) { + Class componentType = lastDefineType.getComponentType(); + if (callTypes[i].isArray() && i == callTypes.length - 1) { + score *= getScore(componentType, callTypes[i].getComponentType()); + } else { + for (; i < callTypes.length; i++) { + double s = getScore(componentType, callTypes[i]); + + if (s == 0) { return 0; + } + score *= s; } - - // - double score = 1.0d; - - int i; - for (i = 0; i < defineTypes.length - 1; i++) { - double s = getScore(defineTypes[i], callTypes[i]); - if (s == 0) { - return 0; - } - score *= s; - if (score == 0) { - return 0; - } - } - Class lastDefineType = defineTypes[i]; - - // test whether the last is a variable parameter. - if (lastDefineType.isArray()) { - Class componentType = lastDefineType.getComponentType(); - if (callTypes[i].isArray() - && i == callTypes.length - 1) { - score *= getScore(componentType, callTypes[i].getComponentType()); - } else { - for (; i < callTypes.length; i++) { - double s = getScore(componentType, callTypes[i]); - - if (s == 0) { - return 0; - } - score *= s; - } - } - return score; - } else { - double s = getScore(lastDefineType, callTypes[i]); - if (s == 0) { - return 0; - } - score *= s; - if (score > 0 && defineTypes.length == callTypes.length) { - return score; - } - } + } + return score; + } else { + double s = getScore(lastDefineType, callTypes[i]); + if (s == 0) { return 0; + } + score *= s; + if (score > 0 && defineTypes.length == callTypes.length) { + return score; + } } - - private static double getScore(Class defineType, Class callType) { - defineType = getBoxType(defineType); - callType = getBoxType(callType); - - if (defineType == callType) { - return 1d; - } else { - if (callType == null) { // the input parameter is null - return 1d; - } - int typeDegreeIndex = findTypeDegreeIndex(defineType, callType); - - if (typeDegreeIndex != -1) { - // (0, 0.9] - return (float) (0.9 * (1 - 0.1 * typeDegreeIndex)); - } else if (defineType.isAssignableFrom(callType)) { - return 0.5d; - } else if (callType == BinaryString.class && defineType == String.class) { - return 0.6d; - } else { - return 0; - } - } + return 0; + } + + private static double getScore(Class defineType, Class callType) { + defineType = getBoxType(defineType); + callType = getBoxType(callType); + + if (defineType == callType) { + return 1d; + } else { + if (callType == null) { // the input parameter is null + return 1d; + } + int typeDegreeIndex = findTypeDegreeIndex(defineType, callType); + + if (typeDegreeIndex != -1) { + // (0, 0.9] + return (float) (0.9 * (1 - 0.1 * typeDegreeIndex)); + } else if (defineType.isAssignableFrom(callType)) { + return 0.5d; + } else if (callType == BinaryString.class && defineType == String.class) { + return 0.6d; + } else { + return 0; + } } + } - private static int findTypeDegreeIndex(Class defineType, Class callType) { + private static int findTypeDegreeIndex(Class defineType, Class callType) { - Class[] degreeTypes = TYPE_DEGREE_MAP.get(callType); - if (degreeTypes == null) { - return -1; - } - for (int i = 0; i < degreeTypes.length; i++) { - if (degreeTypes[i] == defineType) { - return i; - } - } - return -1; + Class[] degreeTypes = TYPE_DEGREE_MAP.get(callType); + if (degreeTypes == null) { + return -1; } - - public static List getAllEvalParamTypes(Class udfClass) { - List evalMethods = getAllEvalMethods(udfClass); - List types = new ArrayList<>(); - for (Method evalMethod : evalMethods) { - types.add(evalMethod.getParameterTypes()); - } - return types; + for (int i = 0; i < degreeTypes.length; i++) { + if (degreeTypes[i] == defineType) { + return i; + } } - - public static List getAllEvalMethods(Class udfClass) { - List evalMethods = new ArrayList<>(); - Class clazz = udfClass; - while (clazz != Object.class) { - Method[] methods = clazz.getDeclaredMethods(); - for (Method method : methods) { - if (method.getName().equals(UDF_EVAL_METHOD_NAME)) { - evalMethods.add(method); - } - } - clazz = clazz.getSuperclass(); - } - return evalMethods; + return -1; + } + + public static List getAllEvalParamTypes(Class udfClass) { + List evalMethods = getAllEvalMethods(udfClass); + List types = new ArrayList<>(); + for (Method evalMethod : evalMethods) { + types.add(evalMethod.getParameterTypes()); } - - private static List getAllMethod(Class udfClass) { - List evalMethods = new ArrayList<>(); - Class clazz = udfClass; - while (clazz != Object.class) { - Method[] methods = clazz.getDeclaredMethods(); - for (Method method : methods) { - evalMethods.add(method); - } - clazz = clazz.getSuperclass(); + return types; + } + + public static List getAllEvalMethods(Class udfClass) { + List evalMethods = new ArrayList<>(); + Class clazz = udfClass; + while (clazz != Object.class) { + Method[] methods = clazz.getDeclaredMethods(); + for (Method method : methods) { + if (method.getName().equals(UDF_EVAL_METHOD_NAME)) { + evalMethods.add(method); } - return evalMethods; - } - - public static Class getBoxType(Class type) { - return BOX_TYPE_MAPS.getOrDefault(type, type); + } + clazz = clazz.getSuperclass(); } - - public static Class getUnboxType(Class type) { - return UNBOX_TYPE_MAPS.getOrDefault(type, type); + return evalMethods; + } + + private static List getAllMethod(Class udfClass) { + List evalMethods = new ArrayList<>(); + Class clazz = udfClass; + while (clazz != Object.class) { + Method[] methods = clazz.getDeclaredMethods(); + for (Method method : methods) { + evalMethods.add(method); + } + clazz = clazz.getSuperclass(); } + return evalMethods; + } + public static Class getBoxType(Class type) { + return BOX_TYPE_MAPS.getOrDefault(type, type); + } - public static Type[] getUDAFGenericTypes(Class udafClass) { - return FunctionCallUtils.getGenericTypes(udafClass, UDAF.class); - } - - public static Type[] getGenericTypes(Class clazz, Class baseClass) { - if (!baseClass.isAssignableFrom(clazz)) { - throw new IllegalArgumentException( - "input clazz must be a sub class of the base class: " + baseClass); - } + public static Class getUnboxType(Class type) { + return UNBOX_TYPE_MAPS.getOrDefault(type, type); + } - Map>, Type> defineTypeMap = new HashMap<>(); - - while (clazz != baseClass) { - Class superClass = clazz.getSuperclass(); - TypeVariable>[] typeVariables = superClass.getTypeParameters(); - Type[] types = ((ParameterizedType) clazz.getGenericSuperclass()) - .getActualTypeArguments(); - - for (int i = 0; i < typeVariables.length; i++) { - TypeVariable> typeVariable = typeVariables[i]; - Type type = - types[i] instanceof ParameterizedType ? ((ParameterizedType) types[i]).getRawType() : types[i]; - defineTypeMap.put(typeVariable, type); - } - clazz = superClass; - } + public static Type[] getUDAFGenericTypes(Class udafClass) { + return FunctionCallUtils.getGenericTypes(udafClass, UDAF.class); + } - TypeVariable>[] typeVariables = baseClass.getTypeParameters(); - Type[] types = new Type[typeVariables.length]; + public static Type[] getGenericTypes(Class clazz, Class baseClass) { + if (!baseClass.isAssignableFrom(clazz)) { + throw new IllegalArgumentException( + "input clazz must be a sub class of the base class: " + baseClass); + } - for (int i = 0; i < typeVariables.length; i++) { - Type type = typeVariables[i]; - do { - type = defineTypeMap.get(type); - } while (type != null && !(type instanceof Class)); - types[i] = type; - } - return types; + Map>, Type> defineTypeMap = new HashMap<>(); + + while (clazz != baseClass) { + Class superClass = clazz.getSuperclass(); + TypeVariable>[] typeVariables = superClass.getTypeParameters(); + Type[] types = ((ParameterizedType) clazz.getGenericSuperclass()).getActualTypeArguments(); + + for (int i = 0; i < typeVariables.length; i++) { + TypeVariable> typeVariable = typeVariables[i]; + Type type = + types[i] instanceof ParameterizedType + ? ((ParameterizedType) types[i]).getRawType() + : types[i]; + defineTypeMap.put(typeVariable, type); + } + clazz = superClass; } + TypeVariable>[] typeVariables = baseClass.getTypeParameters(); + Type[] types = new Type[typeVariables.length]; - public static Class> findMatchUDAF(String name, - List>> udafClassList, - List> callTypes) { - double maxScore = 0; - Class> bestClass = null; - List>> allDefinedTypes = new ArrayList<>(); - - for (Class> udafClass : udafClassList) { - List> defineTypes = getUDAFInputTypes(udafClass); - allDefinedTypes.add(defineTypes); - if (callTypes.size() != defineTypes.size()) { - continue; - } - - double score = 1.0; - for (int i = 0; i < callTypes.size(); i++) { - score *= getScore(defineTypes.get(i), callTypes.get(i)); - } - if (score > maxScore) { - maxScore = score; - bestClass = udafClass; - } - } - if (bestClass != null) { - return bestClass; - } - - throw new GeaFlowDSLException( - "Mismatch input types for " + name + ",the input type is " + callTypes - + ",while the udaf defined type is " + Joiner.on(" OR ").join(allDefinedTypes)); + for (int i = 0; i < typeVariables.length; i++) { + Type type = typeVariables[i]; + do { + type = defineTypeMap.get(type); + } while (type != null && !(type instanceof Class)); + types[i] = type; } - - public static List> getUDAFInputTypes(Class udafClass) { - List> inputTypes = Lists.newArrayList(); - Type[] genericTypes = getGenericTypes(udafClass, UDAF.class); - Class inputType = (Class) genericTypes[0]; - // case for UDAF has multi-parameters. - if (UDAFArguments.class.isAssignableFrom(inputType)) { - try { - UDAFArguments input = (UDAFArguments) inputType.newInstance(); - inputTypes.addAll(input.getParamTypes()); - } catch (Exception e) { - throw new RuntimeException(e); - } - } else { - inputTypes.add(inputType); - } - return inputTypes; + return types; + } + + public static Class> findMatchUDAF( + String name, List>> udafClassList, List> callTypes) { + double maxScore = 0; + Class> bestClass = null; + List>> allDefinedTypes = new ArrayList<>(); + + for (Class> udafClass : udafClassList) { + List> defineTypes = getUDAFInputTypes(udafClass); + allDefinedTypes.add(defineTypes); + if (callTypes.size() != defineTypes.size()) { + continue; + } + + double score = 1.0; + for (int i = 0; i < callTypes.size(); i++) { + score *= getScore(defineTypes.get(i), callTypes.get(i)); + } + if (score > maxScore) { + maxScore = score; + bestClass = udafClass; + } + } + if (bestClass != null) { + return bestClass; } - public static Object callMethod(Method method, Object target, Object[] params) - throws InvocationTargetException, IllegalAccessException { - - Class[] defineTypes = method.getParameterTypes(); - int variableParamIndex = -1; - - if (defineTypes.length > 0) { - Class lastDefineType = defineTypes[defineTypes.length - 1]; - if (lastDefineType.isArray()) { - if (params[defineTypes.length - 1] != null - && params[defineTypes.length - 1].getClass().isArray() - && params.length == defineTypes.length) { - variableParamIndex = -1; - } else { - variableParamIndex = defineTypes.length - 1; - } - } - } - - int paramSize = variableParamIndex >= 0 ? variableParamIndex + 1 : params.length; - Object[] castParams = new Object[paramSize]; - - if (variableParamIndex >= 0) { - int i = 0; - for (; i < variableParamIndex; i++) { - castParams[i] = TypeCastUtil.cast(params[i], getBoxType(defineTypes[i])); - } - Class componentType = defineTypes[variableParamIndex].getComponentType(); - Object[] varParaArray = - (Object[]) Array.newInstance(componentType, params.length - variableParamIndex); - - for (; i < params.length; i++) { - varParaArray[i - variableParamIndex] = TypeCastUtil.cast(params[i], - getBoxType(componentType)); - } - - castParams[variableParamIndex] = varParaArray; + throw new GeaFlowDSLException( + "Mismatch input types for " + + name + + ",the input type is " + + callTypes + + ",while the udaf defined type is " + + Joiner.on(" OR ").join(allDefinedTypes)); + } + + public static List> getUDAFInputTypes(Class udafClass) { + List> inputTypes = Lists.newArrayList(); + Type[] genericTypes = getGenericTypes(udafClass, UDAF.class); + Class inputType = (Class) genericTypes[0]; + // case for UDAF has multi-parameters. + if (UDAFArguments.class.isAssignableFrom(inputType)) { + try { + UDAFArguments input = (UDAFArguments) inputType.newInstance(); + inputTypes.addAll(input.getParamTypes()); + } catch (Exception e) { + throw new RuntimeException(e); + } + } else { + inputTypes.add(inputType); + } + return inputTypes; + } + + public static Object callMethod(Method method, Object target, Object[] params) + throws InvocationTargetException, IllegalAccessException { + + Class[] defineTypes = method.getParameterTypes(); + int variableParamIndex = -1; + + if (defineTypes.length > 0) { + Class lastDefineType = defineTypes[defineTypes.length - 1]; + if (lastDefineType.isArray()) { + if (params[defineTypes.length - 1] != null + && params[defineTypes.length - 1].getClass().isArray() + && params.length == defineTypes.length) { + variableParamIndex = -1; } else { - for (int i = 0; i < params.length; i++) { - castParams[i] = TypeCastUtil.cast(params[i], getBoxType(defineTypes[i])); - } - } - Object result = method.invoke(target, castParams); - if (result instanceof String) { // convert string to binary string if the udf return string type. - result = BinaryString.fromString((String) result); + variableParamIndex = defineTypes.length - 1; } - return result; + } } - public static Class typeClass(Class type, boolean useBinary) { - if (useBinary) { - if (type == String.class) { - return BinaryString.class; - } - if (type.isArray() && type.getComponentType() == String.class) { - return Array.newInstance(BinaryString.class, 0).getClass(); - } - } - return type; + int paramSize = variableParamIndex >= 0 ? variableParamIndex + 1 : params.length; + Object[] castParams = new Object[paramSize]; + + if (variableParamIndex >= 0) { + int i = 0; + for (; i < variableParamIndex; i++) { + castParams[i] = TypeCastUtil.cast(params[i], getBoxType(defineTypes[i])); + } + Class componentType = defineTypes[variableParamIndex].getComponentType(); + Object[] varParaArray = + (Object[]) Array.newInstance(componentType, params.length - variableParamIndex); + + for (; i < params.length; i++) { + varParaArray[i - variableParamIndex] = + TypeCastUtil.cast(params[i], getBoxType(componentType)); + } + + castParams[variableParamIndex] = varParaArray; + } else { + for (int i = 0; i < params.length; i++) { + castParams[i] = TypeCastUtil.cast(params[i], getBoxType(defineTypes[i])); + } + } + Object result = method.invoke(target, castParams); + if (result + instanceof String) { // convert string to binary string if the udf return string type. + result = BinaryString.fromString((String) result); + } + return result; + } + + public static Class typeClass(Class type, boolean useBinary) { + if (useBinary) { + if (type == String.class) { + return BinaryString.class; + } + if (type.isArray() && type.getComponentType() == String.class) { + return Array.newInstance(BinaryString.class, 0).getClass(); + } } + return type; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/util/TypeCastUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/util/TypeCastUtil.java index f5631f335..034d4c713 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/util/TypeCastUtil.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/util/TypeCastUtil.java @@ -19,619 +19,622 @@ package org.apache.geaflow.dsl.common.util; -import com.google.common.collect.ImmutableMap; import java.lang.reflect.Array; import java.math.BigDecimal; import java.nio.charset.StandardCharsets; import java.sql.Date; import java.sql.Timestamp; import java.util.Map; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.tuple.Tuple; import org.apache.geaflow.common.type.IType; +import com.google.common.collect.ImmutableMap; + public class TypeCastUtil { - @SuppressWarnings("unchecked") - private static final Map, ITypeCast> typeCasts = - (Map) ImmutableMap.builder() - .put(Tuple.of(Integer.class, Long.class), new Int2Long()) - .put(Tuple.of(Integer.class, Double.class), new Int2Double()) - .put(Tuple.of(Integer.class, BigDecimal.class), new Int2Decimal()) - .put(Tuple.of(Integer.class, String.class), new Int2String()) - .put(Tuple.of(Integer.class, BinaryString.class), new Int2BinaryString()) - .put(Tuple.of(Integer.class, Timestamp.class), new Int2Timestamp()) - .put(Tuple.of(Long.class, Double.class), new Long2Double()) - .put(Tuple.of(Long.class, Integer.class), new Long2Int()) - .put(Tuple.of(Long.class, BigDecimal.class), new Long2Decimal()) - .put(Tuple.of(Long.class, String.class), new Long2String()) - .put(Tuple.of(Long.class, BinaryString.class), new Long2BinaryString()) - .put(Tuple.of(Long.class, Timestamp.class), new Long2Timestamp()) - .put(Tuple.of(String.class, Long.class), new String2Long()) - .put(Tuple.of(String.class, Integer.class), new String2Int()) - .put(Tuple.of(String.class, Double.class), new String2Double()) - .put(Tuple.of(String.class, BigDecimal.class), new String2Decimal()) - .put(Tuple.of(String.class, Boolean.class), new String2Boolean()) - .put(Tuple.of(String.class, BinaryString.class), new String2Binary()) - .put(Tuple.of(String.class, Timestamp.class), new String2Timestamp()) - .put(Tuple.of(String.class, Date.class), new String2Date()) - .put(Tuple.of(BinaryString.class, Long.class), new BinaryString2Long()) - .put(Tuple.of(BinaryString.class, Integer.class), new BinaryString2Int()) - .put(Tuple.of(BinaryString.class, Double.class), new BinaryString2Double()) - .put(Tuple.of(BinaryString.class, BigDecimal.class), new BinaryString2Decimal()) - .put(Tuple.of(BinaryString.class, Boolean.class), new BinaryString2Boolean()) - .put(Tuple.of(BinaryString.class, String.class), new BinaryString2String()) - .put(Tuple.of(BinaryString.class, Timestamp.class), new BinaryString2Timestamp()) - .put(Tuple.of(BinaryString.class, Date.class), new BinaryString2Date()) - .put(Tuple.of(Double.class, Long.class), new Double2Long()) - .put(Tuple.of(Double.class, Integer.class), new Double2Int()) - .put(Tuple.of(Double.class, BigDecimal.class), new Double2Decimal()) - .put(Tuple.of(Double.class, String.class), new Double2String()) - .put(Tuple.of(Double.class, BinaryString.class), new Double2BinaryString()) - .put(Tuple.of(Boolean.class, String.class), new Boolean2String()) - .put(Tuple.of(Boolean.class, BinaryString.class), new Boolean2BinaryString()) - .put(Tuple.of(BigDecimal.class, Double.class), new Decimal2Double()) - .put(Tuple.of(BigDecimal.class, Integer.class), new Decimal2Int()) - .put(Tuple.of(BigDecimal.class, Long.class), new Decimal2Long()) - .put(Tuple.of(BigDecimal.class, String.class), new Decimal2String()) - .put(Tuple.of(BigDecimal.class, BinaryString.class), new Decimal2BinaryString()) - .put(Tuple.of(byte[].class, BinaryString.class), new Bytes2BinaryString()) - .build(); - - private static final ITypeCast identityCast = new IdentityCast(); - - public static ITypeCast getTypeCast(IType sourceType, IType targetType) { - return getTypeCast(sourceType.getTypeClass(), targetType.getTypeClass()); - } - - @SuppressWarnings("unchecked") - public static ITypeCast getTypeCast(Class sourceType, Class targetType) { - sourceType = FunctionCallUtils.getBoxType(sourceType); - targetType = FunctionCallUtils.getBoxType(targetType); - if (sourceType == targetType) { - return identityCast; - } - if (sourceType.isArray() && targetType.isArray()) { - ITypeCast componentTypeCast = getTypeCast(sourceType.getComponentType(), targetType.getComponentType()); - return (ITypeCast) new ArrayCast(componentTypeCast, targetType.getComponentType()); - } - ITypeCast typeCast = typeCasts.get(Tuple.of(sourceType, targetType)); - if (typeCast == null) { - throw new IllegalArgumentException("Cannot cast from: " + sourceType + " to " + targetType); - } - return typeCast; - } - - public static Object cast(Object o, IType type) { - return cast(o, type.getTypeClass()); - } - - public static Object cast(Object o, Class targetType) { - if (o == null) { - return null; - } - if (targetType.isAssignableFrom(o.getClass())) { - return o; - } - if (targetType == Object.class) { - return o; - } - return getTypeCast(o.getClass(), targetType).castTo(o); - } - - public static boolean isInteger(String s) { - if (s == null) { - return false; - } - for (int i = 0; i < s.length(); i++) { - if (!Character.isDigit(s.charAt(i))) { - return false; - } - } - return true; - } - - public static boolean isInteger(BinaryString s) { - if (s == null) { - return false; - } - for (int i = 0; i < s.getLength(); i++) { - if (!Character.isDigit(s.getByte(i))) { - return false; - } - } - return true; - } - - public interface ITypeCast { - T castTo(S s); - } - - private static class IdentityCast implements ITypeCast { - - @Override - public S castTo(S s) { - return s; - } - } - - private static class ArrayCast implements ITypeCast { - - private final ITypeCast componentTypeCast; - - private final Class targetType; - - public ArrayCast( - ITypeCast componentTypeCast, Class targetType) { - this.componentTypeCast = componentTypeCast; - this.targetType = targetType; - } - - @Override - public Object castTo(Object objects) { - int length = Array.getLength(objects); - Object castArray = Array.newInstance(targetType, length); - for (int i = 0; i < length; i++) { - Object castObject = componentTypeCast.castTo(Array.get(objects, i)); - Array.set(castArray, i, castObject); - } - return castArray; - } + @SuppressWarnings("unchecked") + private static final Map, ITypeCast> typeCasts = + (Map) + ImmutableMap.builder() + .put(Tuple.of(Integer.class, Long.class), new Int2Long()) + .put(Tuple.of(Integer.class, Double.class), new Int2Double()) + .put(Tuple.of(Integer.class, BigDecimal.class), new Int2Decimal()) + .put(Tuple.of(Integer.class, String.class), new Int2String()) + .put(Tuple.of(Integer.class, BinaryString.class), new Int2BinaryString()) + .put(Tuple.of(Integer.class, Timestamp.class), new Int2Timestamp()) + .put(Tuple.of(Long.class, Double.class), new Long2Double()) + .put(Tuple.of(Long.class, Integer.class), new Long2Int()) + .put(Tuple.of(Long.class, BigDecimal.class), new Long2Decimal()) + .put(Tuple.of(Long.class, String.class), new Long2String()) + .put(Tuple.of(Long.class, BinaryString.class), new Long2BinaryString()) + .put(Tuple.of(Long.class, Timestamp.class), new Long2Timestamp()) + .put(Tuple.of(String.class, Long.class), new String2Long()) + .put(Tuple.of(String.class, Integer.class), new String2Int()) + .put(Tuple.of(String.class, Double.class), new String2Double()) + .put(Tuple.of(String.class, BigDecimal.class), new String2Decimal()) + .put(Tuple.of(String.class, Boolean.class), new String2Boolean()) + .put(Tuple.of(String.class, BinaryString.class), new String2Binary()) + .put(Tuple.of(String.class, Timestamp.class), new String2Timestamp()) + .put(Tuple.of(String.class, Date.class), new String2Date()) + .put(Tuple.of(BinaryString.class, Long.class), new BinaryString2Long()) + .put(Tuple.of(BinaryString.class, Integer.class), new BinaryString2Int()) + .put(Tuple.of(BinaryString.class, Double.class), new BinaryString2Double()) + .put(Tuple.of(BinaryString.class, BigDecimal.class), new BinaryString2Decimal()) + .put(Tuple.of(BinaryString.class, Boolean.class), new BinaryString2Boolean()) + .put(Tuple.of(BinaryString.class, String.class), new BinaryString2String()) + .put(Tuple.of(BinaryString.class, Timestamp.class), new BinaryString2Timestamp()) + .put(Tuple.of(BinaryString.class, Date.class), new BinaryString2Date()) + .put(Tuple.of(Double.class, Long.class), new Double2Long()) + .put(Tuple.of(Double.class, Integer.class), new Double2Int()) + .put(Tuple.of(Double.class, BigDecimal.class), new Double2Decimal()) + .put(Tuple.of(Double.class, String.class), new Double2String()) + .put(Tuple.of(Double.class, BinaryString.class), new Double2BinaryString()) + .put(Tuple.of(Boolean.class, String.class), new Boolean2String()) + .put(Tuple.of(Boolean.class, BinaryString.class), new Boolean2BinaryString()) + .put(Tuple.of(BigDecimal.class, Double.class), new Decimal2Double()) + .put(Tuple.of(BigDecimal.class, Integer.class), new Decimal2Int()) + .put(Tuple.of(BigDecimal.class, Long.class), new Decimal2Long()) + .put(Tuple.of(BigDecimal.class, String.class), new Decimal2String()) + .put(Tuple.of(BigDecimal.class, BinaryString.class), new Decimal2BinaryString()) + .put(Tuple.of(byte[].class, BinaryString.class), new Bytes2BinaryString()) + .build(); + + private static final ITypeCast identityCast = new IdentityCast(); + + public static ITypeCast getTypeCast(IType sourceType, IType targetType) { + return getTypeCast(sourceType.getTypeClass(), targetType.getTypeClass()); + } + + @SuppressWarnings("unchecked") + public static ITypeCast getTypeCast(Class sourceType, Class targetType) { + sourceType = FunctionCallUtils.getBoxType(sourceType); + targetType = FunctionCallUtils.getBoxType(targetType); + if (sourceType == targetType) { + return identityCast; + } + if (sourceType.isArray() && targetType.isArray()) { + ITypeCast componentTypeCast = + getTypeCast(sourceType.getComponentType(), targetType.getComponentType()); + return (ITypeCast) new ArrayCast(componentTypeCast, targetType.getComponentType()); + } + ITypeCast typeCast = typeCasts.get(Tuple.of(sourceType, targetType)); + if (typeCast == null) { + throw new IllegalArgumentException("Cannot cast from: " + sourceType + " to " + targetType); + } + return typeCast; + } + + public static Object cast(Object o, IType type) { + return cast(o, type.getTypeClass()); + } + + public static Object cast(Object o, Class targetType) { + if (o == null) { + return null; + } + if (targetType.isAssignableFrom(o.getClass())) { + return o; + } + if (targetType == Object.class) { + return o; + } + return getTypeCast(o.getClass(), targetType).castTo(o); + } + + public static boolean isInteger(String s) { + if (s == null) { + return false; + } + for (int i = 0; i < s.length(); i++) { + if (!Character.isDigit(s.charAt(i))) { + return false; + } + } + return true; + } + + public static boolean isInteger(BinaryString s) { + if (s == null) { + return false; + } + for (int i = 0; i < s.getLength(); i++) { + if (!Character.isDigit(s.getByte(i))) { + return false; + } + } + return true; + } + + public interface ITypeCast { + T castTo(S s); + } + + private static class IdentityCast implements ITypeCast { + + @Override + public S castTo(S s) { + return s; + } + } + + private static class ArrayCast implements ITypeCast { + + private final ITypeCast componentTypeCast; + + private final Class targetType; + + public ArrayCast(ITypeCast componentTypeCast, Class targetType) { + this.componentTypeCast = componentTypeCast; + this.targetType = targetType; + } + + @Override + public Object castTo(Object objects) { + int length = Array.getLength(objects); + Object castArray = Array.newInstance(targetType, length); + for (int i = 0; i < length; i++) { + Object castObject = componentTypeCast.castTo(Array.get(objects, i)); + Array.set(castArray, i, castObject); + } + return castArray; } + } - private static class Int2Long implements ITypeCast { + private static class Int2Long implements ITypeCast { - @Override - public Long castTo(Integer o) { - if (o == null) { - return null; - } - return o.longValue(); - } + @Override + public Long castTo(Integer o) { + if (o == null) { + return null; + } + return o.longValue(); } + } - private static class Int2Double implements ITypeCast { + private static class Int2Double implements ITypeCast { - @Override - public Double castTo(Integer o) { - if (o == null) { - return null; - } - return o.doubleValue(); - } + @Override + public Double castTo(Integer o) { + if (o == null) { + return null; + } + return o.doubleValue(); } + } - private static class Int2Decimal implements ITypeCast { + private static class Int2Decimal implements ITypeCast { - @Override - public BigDecimal castTo(Integer o) { - if (o == null) { - return null; - } - return new BigDecimal(o); - } + @Override + public BigDecimal castTo(Integer o) { + if (o == null) { + return null; + } + return new BigDecimal(o); } + } - private static class Int2String implements ITypeCast { + private static class Int2String implements ITypeCast { - @Override - public String castTo(Integer o) { - if (o == null) { - return null; - } - return String.valueOf(o); - } + @Override + public String castTo(Integer o) { + if (o == null) { + return null; + } + return String.valueOf(o); } + } - private static class Int2BinaryString implements ITypeCast { + private static class Int2BinaryString implements ITypeCast { - @Override - public BinaryString castTo(Integer o) { - if (o == null) { - return null; - } - return BinaryString.fromString(String.valueOf(o)); - } + @Override + public BinaryString castTo(Integer o) { + if (o == null) { + return null; + } + return BinaryString.fromString(String.valueOf(o)); } + } - private static class Int2Timestamp implements ITypeCast { + private static class Int2Timestamp implements ITypeCast { - @Override - public Timestamp castTo(Integer o) { - if (o == null) { - return null; - } - return new Timestamp(o); - } + @Override + public Timestamp castTo(Integer o) { + if (o == null) { + return null; + } + return new Timestamp(o); } + } - private static class Long2Double implements ITypeCast { + private static class Long2Double implements ITypeCast { - @Override - public Double castTo(Long o) { - if (o == null) { - return null; - } - return o.doubleValue(); - } + @Override + public Double castTo(Long o) { + if (o == null) { + return null; + } + return o.doubleValue(); } + } - private static class Long2Int implements ITypeCast { + private static class Long2Int implements ITypeCast { - @Override - public Integer castTo(Long o) { - if (o == null) { - return null; - } - return o.intValue(); - } + @Override + public Integer castTo(Long o) { + if (o == null) { + return null; + } + return o.intValue(); } + } - private static class Long2Decimal implements ITypeCast { + private static class Long2Decimal implements ITypeCast { - @Override - public BigDecimal castTo(Long o) { - if (o == null) { - return null; - } - return new BigDecimal(o); - } + @Override + public BigDecimal castTo(Long o) { + if (o == null) { + return null; + } + return new BigDecimal(o); } + } - private static class Long2String implements ITypeCast { + private static class Long2String implements ITypeCast { - @Override - public String castTo(Long o) { - if (o == null) { - return null; - } - return String.valueOf(o); - } + @Override + public String castTo(Long o) { + if (o == null) { + return null; + } + return String.valueOf(o); } + } - private static class Long2BinaryString implements ITypeCast { + private static class Long2BinaryString implements ITypeCast { - @Override - public BinaryString castTo(Long o) { - if (o == null) { - return null; - } - return BinaryString.fromString(String.valueOf(o)); - } + @Override + public BinaryString castTo(Long o) { + if (o == null) { + return null; + } + return BinaryString.fromString(String.valueOf(o)); } + } - private static class Long2Timestamp implements ITypeCast { + private static class Long2Timestamp implements ITypeCast { - @Override - public Timestamp castTo(Long o) { - if (o == null) { - return null; - } - return new Timestamp(o); - } + @Override + public Timestamp castTo(Long o) { + if (o == null) { + return null; + } + return new Timestamp(o); } + } - private static class String2Long implements ITypeCast { + private static class String2Long implements ITypeCast { - @Override - public Long castTo(String o) { - if (o == null) { - return null; - } - return Long.parseLong(o); - } + @Override + public Long castTo(String o) { + if (o == null) { + return null; + } + return Long.parseLong(o); } + } - private static class String2Int implements ITypeCast { + private static class String2Int implements ITypeCast { - @Override - public Integer castTo(String o) { - if (o == null) { - return null; - } - return Integer.parseInt(o); - } + @Override + public Integer castTo(String o) { + if (o == null) { + return null; + } + return Integer.parseInt(o); } + } - private static class String2Double implements ITypeCast { + private static class String2Double implements ITypeCast { - @Override - public Double castTo(String o) { - if (o == null) { - return null; - } - return Double.valueOf(o); - } + @Override + public Double castTo(String o) { + if (o == null) { + return null; + } + return Double.valueOf(o); } + } - private static class String2Decimal implements ITypeCast { + private static class String2Decimal implements ITypeCast { - @Override - public BigDecimal castTo(String o) { - if (o == null) { - return null; - } - return new BigDecimal(o); - } + @Override + public BigDecimal castTo(String o) { + if (o == null) { + return null; + } + return new BigDecimal(o); } + } - private static class String2Boolean implements ITypeCast { + private static class String2Boolean implements ITypeCast { - @Override - public Boolean castTo(String o) { - if (o == null) { - return null; - } - return Boolean.valueOf(o); - } + @Override + public Boolean castTo(String o) { + if (o == null) { + return null; + } + return Boolean.valueOf(o); } + } - private static class String2Binary implements ITypeCast { + private static class String2Binary implements ITypeCast { - @Override - public BinaryString castTo(String o) { - if (o == null) { - return null; - } - return BinaryString.fromString(o); - } + @Override + public BinaryString castTo(String o) { + if (o == null) { + return null; + } + return BinaryString.fromString(o); } + } - private static class String2Timestamp implements ITypeCast { + private static class String2Timestamp implements ITypeCast { - @Override - public Timestamp castTo(String o) { - if (o == null) { - return null; - } - if (isInteger(o)) { - return new Timestamp(Long.parseLong(o)); - } - return Timestamp.valueOf(o); - } + @Override + public Timestamp castTo(String o) { + if (o == null) { + return null; + } + if (isInteger(o)) { + return new Timestamp(Long.parseLong(o)); + } + return Timestamp.valueOf(o); } + } - private static class String2Date implements ITypeCast { + private static class String2Date implements ITypeCast { - @Override - public Date castTo(String o) { - if (o == null) { - return null; - } - return Date.valueOf(o); - } + @Override + public Date castTo(String o) { + if (o == null) { + return null; + } + return Date.valueOf(o); } + } - private static class BinaryString2Long implements ITypeCast { + private static class BinaryString2Long implements ITypeCast { - @Override - public Long castTo(BinaryString o) { - if (o == null) { - return null; - } - return Long.parseLong(o.toString()); - } + @Override + public Long castTo(BinaryString o) { + if (o == null) { + return null; + } + return Long.parseLong(o.toString()); } + } - private static class BinaryString2Int implements ITypeCast { + private static class BinaryString2Int implements ITypeCast { - @Override - public Integer castTo(BinaryString o) { - if (o == null) { - return null; - } - return Integer.parseInt(o.toString()); - } + @Override + public Integer castTo(BinaryString o) { + if (o == null) { + return null; + } + return Integer.parseInt(o.toString()); } + } - private static class BinaryString2Double implements ITypeCast { + private static class BinaryString2Double implements ITypeCast { - @Override - public Double castTo(BinaryString o) { - if (o == null) { - return null; - } - return Double.valueOf(o.toString()); - } + @Override + public Double castTo(BinaryString o) { + if (o == null) { + return null; + } + return Double.valueOf(o.toString()); } + } - private static class BinaryString2Decimal implements ITypeCast { + private static class BinaryString2Decimal implements ITypeCast { - @Override - public BigDecimal castTo(BinaryString o) { - if (o == null) { - return null; - } - return new BigDecimal(o.toString()); - } + @Override + public BigDecimal castTo(BinaryString o) { + if (o == null) { + return null; + } + return new BigDecimal(o.toString()); } + } - private static class BinaryString2Boolean implements ITypeCast { + private static class BinaryString2Boolean implements ITypeCast { - @Override - public Boolean castTo(BinaryString o) { - if (o == null) { - return null; - } - return Boolean.valueOf(o.toString()); - } + @Override + public Boolean castTo(BinaryString o) { + if (o == null) { + return null; + } + return Boolean.valueOf(o.toString()); } + } - private static class BinaryString2String implements ITypeCast { + private static class BinaryString2String implements ITypeCast { - @Override - public String castTo(BinaryString o) { - if (o == null) { - return null; - } - return o.toString(); - } + @Override + public String castTo(BinaryString o) { + if (o == null) { + return null; + } + return o.toString(); } + } - private static class BinaryString2Timestamp implements ITypeCast { + private static class BinaryString2Timestamp implements ITypeCast { - @Override - public Timestamp castTo(BinaryString o) { - if (o == null) { - return null; - } - if (isInteger(o)) { - return new Timestamp(Long.parseLong(o.toString())); - } - return Timestamp.valueOf(o.toString()); - } + @Override + public Timestamp castTo(BinaryString o) { + if (o == null) { + return null; + } + if (isInteger(o)) { + return new Timestamp(Long.parseLong(o.toString())); + } + return Timestamp.valueOf(o.toString()); } + } - private static class BinaryString2Date implements ITypeCast { + private static class BinaryString2Date implements ITypeCast { - @Override - public Date castTo(BinaryString o) { - if (o == null) { - return null; - } - return Date.valueOf(o.toString()); - } + @Override + public Date castTo(BinaryString o) { + if (o == null) { + return null; + } + return Date.valueOf(o.toString()); } + } - private static class Double2Long implements ITypeCast { + private static class Double2Long implements ITypeCast { - @Override - public Long castTo(Double o) { - if (o == null) { - return null; - } - return o.longValue(); - } + @Override + public Long castTo(Double o) { + if (o == null) { + return null; + } + return o.longValue(); } + } - private static class Double2Int implements ITypeCast { + private static class Double2Int implements ITypeCast { - @Override - public Integer castTo(Double o) { - if (o == null) { - return null; - } - return o.intValue(); - } + @Override + public Integer castTo(Double o) { + if (o == null) { + return null; + } + return o.intValue(); } + } - private static class Double2Decimal implements ITypeCast { + private static class Double2Decimal implements ITypeCast { - @Override - public BigDecimal castTo(Double o) { - if (o == null) { - return null; - } - return new BigDecimal(o); - } + @Override + public BigDecimal castTo(Double o) { + if (o == null) { + return null; + } + return new BigDecimal(o); } + } - private static class Double2String implements ITypeCast { + private static class Double2String implements ITypeCast { - @Override - public String castTo(Double o) { - if (o == null) { - return null; - } - return o.toString(); - } + @Override + public String castTo(Double o) { + if (o == null) { + return null; + } + return o.toString(); } + } - private static class Double2BinaryString implements ITypeCast { + private static class Double2BinaryString implements ITypeCast { - @Override - public BinaryString castTo(Double o) { - if (o == null) { - return null; - } - return BinaryString.fromString(o.toString()); - } + @Override + public BinaryString castTo(Double o) { + if (o == null) { + return null; + } + return BinaryString.fromString(o.toString()); } + } - private static class Boolean2String implements ITypeCast { + private static class Boolean2String implements ITypeCast { - @Override - public String castTo(Boolean o) { - if (o == null) { - return null; - } - return o.toString(); - } + @Override + public String castTo(Boolean o) { + if (o == null) { + return null; + } + return o.toString(); } + } - private static class Boolean2BinaryString implements ITypeCast { + private static class Boolean2BinaryString implements ITypeCast { - @Override - public BinaryString castTo(Boolean o) { - if (o == null) { - return null; - } - return BinaryString.fromString(o.toString()); - } + @Override + public BinaryString castTo(Boolean o) { + if (o == null) { + return null; + } + return BinaryString.fromString(o.toString()); } + } - private static class Decimal2Double implements ITypeCast { + private static class Decimal2Double implements ITypeCast { - @Override - public Double castTo(BigDecimal o) { - if (o == null) { - return null; - } - return o.doubleValue(); - } + @Override + public Double castTo(BigDecimal o) { + if (o == null) { + return null; + } + return o.doubleValue(); } + } - private static class Decimal2Long implements ITypeCast { + private static class Decimal2Long implements ITypeCast { - @Override - public Long castTo(BigDecimal o) { - if (o == null) { - return null; - } - return o.longValue(); - } + @Override + public Long castTo(BigDecimal o) { + if (o == null) { + return null; + } + return o.longValue(); } + } - private static class Decimal2Int implements ITypeCast { + private static class Decimal2Int implements ITypeCast { - @Override - public Integer castTo(BigDecimal o) { - if (o == null) { - return null; - } - return o.intValue(); - } + @Override + public Integer castTo(BigDecimal o) { + if (o == null) { + return null; + } + return o.intValue(); } + } - private static class Decimal2String implements ITypeCast { + private static class Decimal2String implements ITypeCast { - @Override - public String castTo(BigDecimal o) { - if (o == null) { - return null; - } - return o.toString(); - } + @Override + public String castTo(BigDecimal o) { + if (o == null) { + return null; + } + return o.toString(); } + } - private static class Decimal2BinaryString implements ITypeCast { + private static class Decimal2BinaryString implements ITypeCast { - @Override - public BinaryString castTo(BigDecimal o) { - if (o == null) { - return null; - } - return BinaryString.fromString(o.toString()); - } + @Override + public BinaryString castTo(BigDecimal o) { + if (o == null) { + return null; + } + return BinaryString.fromString(o.toString()); } + } - private static class Bytes2BinaryString implements ITypeCast { - @Override - public BinaryString castTo(byte[] o) { - if (o == null) { - return null; - } - return BinaryString.fromString(new String(o, StandardCharsets.UTF_8)); - } + private static class Bytes2BinaryString implements ITypeCast { + @Override + public BinaryString castTo(byte[] o) { + if (o == null) { + return null; + } + return BinaryString.fromString(new String(o, StandardCharsets.UTF_8)); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/util/Windows.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/util/Windows.java index 17f986d9f..9910fb487 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/util/Windows.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/util/Windows.java @@ -19,7 +19,6 @@ package org.apache.geaflow.dsl.common.util; -import com.google.common.base.Preconditions; import org.apache.geaflow.api.window.IWindow; import org.apache.geaflow.api.window.impl.AllWindow; import org.apache.geaflow.api.window.impl.FixedTimeTumblingWindow; @@ -27,32 +26,38 @@ import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.DSLConfigKeys; +import com.google.common.base.Preconditions; + public class Windows { - public static final long SIZE_OF_ALL_WINDOW = -1L; + public static final long SIZE_OF_ALL_WINDOW = -1L; - public static IWindow createWindow(Configuration configuration) { - long batchWindowSize = Integer.MIN_VALUE; - if (configuration.contains(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE)) { - batchWindowSize = configuration.getLong(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE); - Preconditions.checkState(batchWindowSize != 0, "Window size should not be zero!"); - } - long timeWindowDuration = -1; - if (configuration.contains(DSLConfigKeys.GEAFLOW_DSL_TIME_WINDOW_SIZE)) { - timeWindowDuration = configuration.getLong(DSLConfigKeys.GEAFLOW_DSL_TIME_WINDOW_SIZE); - Preconditions.checkState(timeWindowDuration > 0, "Time Window size should not be positive!"); - } - Preconditions.checkState(!(batchWindowSize >= SIZE_OF_ALL_WINDOW && timeWindowDuration > 0), - "Only one of window can exist! size window:%s, time window:%s", batchWindowSize, timeWindowDuration); - if (batchWindowSize == SIZE_OF_ALL_WINDOW) { - return AllWindow.getInstance(); - } else if (batchWindowSize > 0) { - return new SizeTumblingWindow<>(batchWindowSize); - } else if (timeWindowDuration > 0) { - return new FixedTimeTumblingWindow<>(timeWindowDuration); - } else { - // use default - return new SizeTumblingWindow<>((Long) DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getDefaultValue()); - } + public static IWindow createWindow(Configuration configuration) { + long batchWindowSize = Integer.MIN_VALUE; + if (configuration.contains(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE)) { + batchWindowSize = configuration.getLong(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE); + Preconditions.checkState(batchWindowSize != 0, "Window size should not be zero!"); + } + long timeWindowDuration = -1; + if (configuration.contains(DSLConfigKeys.GEAFLOW_DSL_TIME_WINDOW_SIZE)) { + timeWindowDuration = configuration.getLong(DSLConfigKeys.GEAFLOW_DSL_TIME_WINDOW_SIZE); + Preconditions.checkState(timeWindowDuration > 0, "Time Window size should not be positive!"); + } + Preconditions.checkState( + !(batchWindowSize >= SIZE_OF_ALL_WINDOW && timeWindowDuration > 0), + "Only one of window can exist! size window:%s, time window:%s", + batchWindowSize, + timeWindowDuration); + if (batchWindowSize == SIZE_OF_ALL_WINDOW) { + return AllWindow.getInstance(); + } else if (batchWindowSize > 0) { + return new SizeTumblingWindow<>(batchWindowSize); + } else if (timeWindowDuration > 0) { + return new FixedTimeTumblingWindow<>(timeWindowDuration); + } else { + // use default + return new SizeTumblingWindow<>( + (Long) DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getDefaultValue()); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/test/java/org/apache/geaflow/dsl/common/binary/BinaryEncodeTest.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/test/java/org/apache/geaflow/dsl/common/binary/BinaryEncodeTest.java index e3a3cf4f4..a9127b739 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/test/java/org/apache/geaflow/dsl/common/binary/BinaryEncodeTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/test/java/org/apache/geaflow/dsl/common/binary/BinaryEncodeTest.java @@ -41,52 +41,52 @@ public class BinaryEncodeTest { - @Test - public void testEncodeVertex() { - IType[] idTypes = getIdTypes(); - for (IType idType : idTypes) { - VertexType vertexType = getVertexType(idType); - RowVertex originVertex = getVertex(idType); - IBinaryEncoder vertexEncoder = EncoderFactory.createEncoder(vertexType); - Row encodeVertex = vertexEncoder.encode(originVertex); - IBinaryDecoder vertexDecoder = DecoderFactory.createDecoder(vertexType); - Row decodeVertex = vertexDecoder.decode(encodeVertex); - checkResult(originVertex, decodeVertex, vertexType); - } + @Test + public void testEncodeVertex() { + IType[] idTypes = getIdTypes(); + for (IType idType : idTypes) { + VertexType vertexType = getVertexType(idType); + RowVertex originVertex = getVertex(idType); + IBinaryEncoder vertexEncoder = EncoderFactory.createEncoder(vertexType); + Row encodeVertex = vertexEncoder.encode(originVertex); + IBinaryDecoder vertexDecoder = DecoderFactory.createDecoder(vertexType); + Row decodeVertex = vertexDecoder.decode(encodeVertex); + checkResult(originVertex, decodeVertex, vertexType); } + } - @Test - public void testEncodeEdge() { - IType[] idTypes = getIdTypes(); - for (IType idType : idTypes) { - EdgeType edgeType = getEdgeType(idType, true); - RowEdge originEdge = getEdge(idType, true); - IBinaryEncoder edgeEncoder = EncoderFactory.createEncoder(edgeType); - Row encodeEdge = edgeEncoder.encode(originEdge); - IBinaryDecoder edgeDecoder = DecoderFactory.createDecoder(edgeType); - Row decodeEdge = edgeDecoder.decode(encodeEdge); - checkResult(originEdge, decodeEdge, edgeType); - } - - for (IType idType : idTypes) { - EdgeType edgeType = getEdgeType(idType, false); - RowEdge originEdge = getEdge(idType, false); - IBinaryEncoder edgeEncoder = EncoderFactory.createEncoder(edgeType); - Row encodeEdge = edgeEncoder.encode(originEdge); - IBinaryDecoder edgeDecoder = DecoderFactory.createDecoder(edgeType); - Row decodeEdge = edgeDecoder.decode(encodeEdge); - checkResult(originEdge, decodeEdge, edgeType); - } + @Test + public void testEncodeEdge() { + IType[] idTypes = getIdTypes(); + for (IType idType : idTypes) { + EdgeType edgeType = getEdgeType(idType, true); + RowEdge originEdge = getEdge(idType, true); + IBinaryEncoder edgeEncoder = EncoderFactory.createEncoder(edgeType); + Row encodeEdge = edgeEncoder.encode(originEdge); + IBinaryDecoder edgeDecoder = DecoderFactory.createDecoder(edgeType); + Row decodeEdge = edgeDecoder.decode(encodeEdge); + checkResult(originEdge, decodeEdge, edgeType); } - @Test - public void testEncodeRow() { - StructType rowType = getRowType(); - Row originRow = getRow(); - IBinaryEncoder rowEncoder = EncoderFactory.createEncoder(rowType); - Row encodeRow = rowEncoder.encode(originRow); - IBinaryDecoder rowDecoder = DecoderFactory.createDecoder(rowType); - Row decodeRow = rowDecoder.decode(encodeRow); - checkResult(originRow, decodeRow, rowType); + for (IType idType : idTypes) { + EdgeType edgeType = getEdgeType(idType, false); + RowEdge originEdge = getEdge(idType, false); + IBinaryEncoder edgeEncoder = EncoderFactory.createEncoder(edgeType); + Row encodeEdge = edgeEncoder.encode(originEdge); + IBinaryDecoder edgeDecoder = DecoderFactory.createDecoder(edgeType); + Row decodeEdge = edgeDecoder.decode(encodeEdge); + checkResult(originEdge, decodeEdge, edgeType); } + } + + @Test + public void testEncodeRow() { + StructType rowType = getRowType(); + Row originRow = getRow(); + IBinaryEncoder rowEncoder = EncoderFactory.createEncoder(rowType); + Row encodeRow = rowEncoder.encode(originRow); + IBinaryDecoder rowDecoder = DecoderFactory.createDecoder(rowType); + Row decodeRow = rowDecoder.decode(encodeRow); + checkResult(originRow, decodeRow, rowType); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/test/java/org/apache/geaflow/dsl/common/data/ArrayUtilTest.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/test/java/org/apache/geaflow/dsl/common/data/ArrayUtilTest.java index aa9565c25..94f1db605 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/test/java/org/apache/geaflow/dsl/common/data/ArrayUtilTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/test/java/org/apache/geaflow/dsl/common/data/ArrayUtilTest.java @@ -27,17 +27,17 @@ public class ArrayUtilTest { - @Test - public void testConcatArray() { - Integer[] array1 = null; - Integer[] array2 = null; - Object[] concatArray1 = ArrayUtil.concatArray(array1, array2); - assertNull(concatArray1); - array1 = new Integer[0]; - Object[] concatArray2 = ArrayUtil.concatArray(array1, array2); - assertEquals(concatArray2, array1); - array2 = new Integer[0]; - Object[] concatArray3 = ArrayUtil.concatArray(array1, array2); - assertEquals(concatArray3.length, 0); - } + @Test + public void testConcatArray() { + Integer[] array1 = null; + Integer[] array2 = null; + Object[] concatArray1 = ArrayUtil.concatArray(array1, array2); + assertNull(concatArray1); + array1 = new Integer[0]; + Object[] concatArray2 = ArrayUtil.concatArray(array1, array2); + assertEquals(concatArray2, array1); + array2 = new Integer[0]; + Object[] concatArray3 = ArrayUtil.concatArray(array1, array2); + assertEquals(concatArray3.length, 0); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/test/java/org/apache/geaflow/dsl/common/data/BasicDataTest.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/test/java/org/apache/geaflow/dsl/common/data/BasicDataTest.java index 0f3a6995c..f7cf932ed 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/test/java/org/apache/geaflow/dsl/common/data/BasicDataTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/test/java/org/apache/geaflow/dsl/common/data/BasicDataTest.java @@ -30,7 +30,6 @@ import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; -import com.google.common.collect.Lists; import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.utils.ClassUtil; @@ -55,222 +54,223 @@ import org.apache.geaflow.model.graph.edge.EdgeDirection; import org.testng.annotations.Test; +import com.google.common.collect.Lists; + public class BasicDataTest { - @Test - public void testDefaultPath() { - Object[] fields = new Object[]{"test", 1L, 1.0}; - DefaultPath emptyPath = new DefaultPath(); - assertEquals(emptyPath.getPathNodes().size(), 0); - emptyPath.addNode(ObjectRow.create(fields)); - assertEquals(emptyPath.size(), 1); - assertEquals(emptyPath.subPath(Lists.newArrayList(0)).size(), 1); - } + @Test + public void testDefaultPath() { + Object[] fields = new Object[] {"test", 1L, 1.0}; + DefaultPath emptyPath = new DefaultPath(); + assertEquals(emptyPath.getPathNodes().size(), 0); + emptyPath.addNode(ObjectRow.create(fields)); + assertEquals(emptyPath.size(), 1); + assertEquals(emptyPath.subPath(Lists.newArrayList(0)).size(), 1); + } - @Test - public void testObjectEdge() { - Object[] fields = new Object[]{"test", 1L, 1.0}; - ObjectEdge emptyEdge = - new ObjectEdge(1, 2); - ObjectEdge rowEdge = - new ObjectEdge(1, 2, ObjectRow.create(fields)); - assertNull(rowEdge.getBinaryLabel()); - assertEquals(rowEdge.getSrcId(), 1); - rowEdge.setSrcId(2); - assertEquals(rowEdge.getSrcId(), 2); - assertEquals(rowEdge.getTargetId(), 2); - rowEdge.setTargetId(1); - assertEquals(rowEdge.getTargetId(), 1); - assertEquals(rowEdge.getValue().toString(), "[test, 1, 1.0]"); - assertEquals(rowEdge.getDirect(), EdgeDirection.OUT); + @Test + public void testObjectEdge() { + Object[] fields = new Object[] {"test", 1L, 1.0}; + ObjectEdge emptyEdge = new ObjectEdge(1, 2); + ObjectEdge rowEdge = new ObjectEdge(1, 2, ObjectRow.create(fields)); + assertNull(rowEdge.getBinaryLabel()); + assertEquals(rowEdge.getSrcId(), 1); + rowEdge.setSrcId(2); + assertEquals(rowEdge.getSrcId(), 2); + assertEquals(rowEdge.getTargetId(), 2); + rowEdge.setTargetId(1); + assertEquals(rowEdge.getTargetId(), 1); + assertEquals(rowEdge.getValue().toString(), "[test, 1, 1.0]"); + assertEquals(rowEdge.getDirect(), EdgeDirection.OUT); - ObjectEdge reverse = rowEdge.reverse(); - assertEquals(reverse.getSrcId(), rowEdge.getTargetId()); - assertEquals(reverse.getTargetId(), rowEdge.getSrcId()); + ObjectEdge reverse = rowEdge.reverse(); + assertEquals(reverse.getSrcId(), rowEdge.getTargetId()); + assertEquals(reverse.getTargetId(), rowEdge.getSrcId()); - ObjectEdge rowEdge2 = rowEdge.withValue(ObjectRow.create(fields)); - assertEquals(rowEdge2.getField(0, null), 2); - assertEquals(rowEdge2.getField(1, null), 1); - assertNull(rowEdge2.getField(2, null)); - assertEquals(rowEdge2.getField(3, null), "test"); - assertEquals(rowEdge2, rowEdge); - assertNotEquals(emptyEdge, rowEdge); - } + ObjectEdge rowEdge2 = rowEdge.withValue(ObjectRow.create(fields)); + assertEquals(rowEdge2.getField(0, null), 2); + assertEquals(rowEdge2.getField(1, null), 1); + assertNull(rowEdge2.getField(2, null)); + assertEquals(rowEdge2.getField(3, null), "test"); + assertEquals(rowEdge2, rowEdge); + assertNotEquals(emptyEdge, rowEdge); + } - @Test - public void testEdge() { - Object[] fields = new Object[]{"test", 1L, 1.0}; - IType[] idTypes = getIdTypes(); - for (IType idType : idTypes) { - RowEdge rowEdge = getEdge(idType, ObjectRow.create(fields), false); - assertEquals(rowEdge.getBinaryLabel(), BinaryString.fromString("edgeLabel")); - assertEquals(rowEdge.getSrcId(), getSrcId(idType)); - assertEquals(rowEdge.getValue().toString(), "[test, 1, 1.0]"); - assertEquals(rowEdge.getDirect(), EdgeDirection.IN); - RowEdge reverse = (RowEdge) rowEdge.reverse(); - assertEquals(reverse.getSrcId(), rowEdge.getTargetId()); - assertEquals(reverse.getTargetId(), rowEdge.getSrcId()); - RowEdge rowEdge2 = (RowEdge) rowEdge.withValue(ObjectRow.create(fields)); - rowEdge2.setLabel("relation"); - assertEquals(rowEdge2.getLabel(), "relation"); - assertEquals(rowEdge2.getField(0, null), getSrcId(idType)); - assertEquals(rowEdge2.getField(1, null), getTargetId(idType)); - assertEquals(rowEdge2.getField(2, null), BinaryString.fromString("relation")); - assertEquals(rowEdge2.getField(3, null), "test"); - assertNotEquals(rowEdge2, rowEdge); - assertEquals(reverse.reverse(), rowEdge); - } + @Test + public void testEdge() { + Object[] fields = new Object[] {"test", 1L, 1.0}; + IType[] idTypes = getIdTypes(); + for (IType idType : idTypes) { + RowEdge rowEdge = getEdge(idType, ObjectRow.create(fields), false); + assertEquals(rowEdge.getBinaryLabel(), BinaryString.fromString("edgeLabel")); + assertEquals(rowEdge.getSrcId(), getSrcId(idType)); + assertEquals(rowEdge.getValue().toString(), "[test, 1, 1.0]"); + assertEquals(rowEdge.getDirect(), EdgeDirection.IN); + RowEdge reverse = (RowEdge) rowEdge.reverse(); + assertEquals(reverse.getSrcId(), rowEdge.getTargetId()); + assertEquals(reverse.getTargetId(), rowEdge.getSrcId()); + RowEdge rowEdge2 = (RowEdge) rowEdge.withValue(ObjectRow.create(fields)); + rowEdge2.setLabel("relation"); + assertEquals(rowEdge2.getLabel(), "relation"); + assertEquals(rowEdge2.getField(0, null), getSrcId(idType)); + assertEquals(rowEdge2.getField(1, null), getTargetId(idType)); + assertEquals(rowEdge2.getField(2, null), BinaryString.fromString("relation")); + assertEquals(rowEdge2.getField(3, null), "test"); + assertNotEquals(rowEdge2, rowEdge); + assertEquals(reverse.reverse(), rowEdge); } + } - @Test - public void testObjectTsEdge() { - Object[] fields = new Object[]{"test", 1L, 1.0}; - ObjectTsEdge rowEdge = - new ObjectTsEdge(1, 2, ObjectRow.create(fields)); - assertEquals(rowEdge.getTime(), 0); - rowEdge.setTime(1); - assertEquals(rowEdge.getTime(), 1); - ObjectTsEdge reverse = rowEdge.reverse(); - ObjectTsEdge rowEdge2 = (ObjectTsEdge) rowEdge.withValue(ObjectRow.create(fields)); - assertEquals(rowEdge2.getField(0, null), 1); - assertEquals(rowEdge2.getField(1, null), 2); - assertNull(rowEdge2.getField(2, null)); - assertEquals(rowEdge2.getField(3, null), 1L); - assertEquals(rowEdge2.getField(4, null), "test"); - assertEquals(rowEdge2, rowEdge); - assertNotEquals(reverse, rowEdge); - } + @Test + public void testObjectTsEdge() { + Object[] fields = new Object[] {"test", 1L, 1.0}; + ObjectTsEdge rowEdge = new ObjectTsEdge(1, 2, ObjectRow.create(fields)); + assertEquals(rowEdge.getTime(), 0); + rowEdge.setTime(1); + assertEquals(rowEdge.getTime(), 1); + ObjectTsEdge reverse = rowEdge.reverse(); + ObjectTsEdge rowEdge2 = (ObjectTsEdge) rowEdge.withValue(ObjectRow.create(fields)); + assertEquals(rowEdge2.getField(0, null), 1); + assertEquals(rowEdge2.getField(1, null), 2); + assertNull(rowEdge2.getField(2, null)); + assertEquals(rowEdge2.getField(3, null), 1L); + assertEquals(rowEdge2.getField(4, null), "test"); + assertEquals(rowEdge2, rowEdge); + assertNotEquals(reverse, rowEdge); + } - @Test - public void testEdgeWithTs() { - Object[] fields = new Object[]{"test", 1L, 1.0}; - IType[] idTypes = getIdTypes(); - for (IType idType : idTypes) { - RowEdge rowEdge = getEdge(idType, ObjectRow.create(fields), true); - ((IGraphElementWithTimeField) rowEdge).setTime(1L); - assertEquals(((IGraphElementWithTimeField) rowEdge).getTime(), 1L); - assertEquals(rowEdge.getBinaryLabel(), BinaryString.fromString("edgeLabel")); - assertEquals(rowEdge.getSrcId(), getSrcId(idType)); - assertEquals(rowEdge.getValue().toString(), "[test, 1, 1.0]"); - assertEquals(rowEdge.getDirect(), EdgeDirection.IN); - RowEdge reverse = (RowEdge) rowEdge.reverse(); - assertEquals(reverse.getSrcId(), rowEdge.getTargetId()); - assertEquals(reverse.getTargetId(), rowEdge.getSrcId()); - RowEdge rowEdge2 = (RowEdge) rowEdge.withValue(ObjectRow.create(fields)); - rowEdge2.setLabel("relation"); - assertEquals(rowEdge2.getLabel(), "relation"); - assertEquals(rowEdge2.getField(0, null), getSrcId(idType)); - assertEquals(rowEdge2.getField(1, null), getTargetId(idType)); - assertEquals(rowEdge2.getField(2, null), BinaryString.fromString("relation")); - assertEquals(rowEdge2.getField(3, null), 1L); - assertEquals(rowEdge2.getField(4, null), "test"); - assertNotEquals(rowEdge2, rowEdge); - assertEquals(reverse.reverse(), rowEdge); - } + @Test + public void testEdgeWithTs() { + Object[] fields = new Object[] {"test", 1L, 1.0}; + IType[] idTypes = getIdTypes(); + for (IType idType : idTypes) { + RowEdge rowEdge = getEdge(idType, ObjectRow.create(fields), true); + ((IGraphElementWithTimeField) rowEdge).setTime(1L); + assertEquals(((IGraphElementWithTimeField) rowEdge).getTime(), 1L); + assertEquals(rowEdge.getBinaryLabel(), BinaryString.fromString("edgeLabel")); + assertEquals(rowEdge.getSrcId(), getSrcId(idType)); + assertEquals(rowEdge.getValue().toString(), "[test, 1, 1.0]"); + assertEquals(rowEdge.getDirect(), EdgeDirection.IN); + RowEdge reverse = (RowEdge) rowEdge.reverse(); + assertEquals(reverse.getSrcId(), rowEdge.getTargetId()); + assertEquals(reverse.getTargetId(), rowEdge.getSrcId()); + RowEdge rowEdge2 = (RowEdge) rowEdge.withValue(ObjectRow.create(fields)); + rowEdge2.setLabel("relation"); + assertEquals(rowEdge2.getLabel(), "relation"); + assertEquals(rowEdge2.getField(0, null), getSrcId(idType)); + assertEquals(rowEdge2.getField(1, null), getTargetId(idType)); + assertEquals(rowEdge2.getField(2, null), BinaryString.fromString("relation")); + assertEquals(rowEdge2.getField(3, null), 1L); + assertEquals(rowEdge2.getField(4, null), "test"); + assertNotEquals(rowEdge2, rowEdge); + assertEquals(reverse.reverse(), rowEdge); } + } - @Test - public void testObjectVertex() { - Object[] fields = new Object[]{"test", 1L, 1.0}; - ObjectVertex emptyVertex = new ObjectVertex(2); - ObjectVertex rowVertex = - new ObjectVertex(1, BinaryString.fromString("person"), ObjectRow.create(fields)); - assertEquals(rowVertex.getLabel(), "person"); - assertEquals(rowVertex.getId(), 1); - assertEquals(rowVertex.getValue().toString(), "[test, 1, 1.0]"); - ObjectVertex rowVertex2 = rowVertex.withValue(ObjectRow.create(fields)); - rowVertex2 = rowVertex2.withLabel("user"); - assertEquals(rowVertex2.getLabel(), "user"); - assertEquals(rowVertex.compareTo(rowVertex2), 0); - assertEquals(rowVertex2.getField(0, null), 1); - assertEquals(rowVertex2.getField(1, null), BinaryString.fromString("user")); - assertEquals(rowVertex2.getField(2, null), "test"); - assertNotEquals(rowVertex2, rowVertex); - assertNotEquals(emptyVertex, rowVertex); - } + @Test + public void testObjectVertex() { + Object[] fields = new Object[] {"test", 1L, 1.0}; + ObjectVertex emptyVertex = new ObjectVertex(2); + ObjectVertex rowVertex = + new ObjectVertex(1, BinaryString.fromString("person"), ObjectRow.create(fields)); + assertEquals(rowVertex.getLabel(), "person"); + assertEquals(rowVertex.getId(), 1); + assertEquals(rowVertex.getValue().toString(), "[test, 1, 1.0]"); + ObjectVertex rowVertex2 = rowVertex.withValue(ObjectRow.create(fields)); + rowVertex2 = rowVertex2.withLabel("user"); + assertEquals(rowVertex2.getLabel(), "user"); + assertEquals(rowVertex.compareTo(rowVertex2), 0); + assertEquals(rowVertex2.getField(0, null), 1); + assertEquals(rowVertex2.getField(1, null), BinaryString.fromString("user")); + assertEquals(rowVertex2.getField(2, null), "test"); + assertNotEquals(rowVertex2, rowVertex); + assertNotEquals(emptyVertex, rowVertex); + } - @Test - public void testVertex() { - Object[] fields = new Object[]{"test", 1L, 1.0}; - IType[] idTypes = getIdTypes(); - for (IType idType : idTypes) { - RowVertex rowVertex = getVertex(idType, ObjectRow.create(fields)); - assertEquals(rowVertex.getBinaryLabel(), BinaryString.fromString("vertexLabel")); - assertEquals(rowVertex.getId(), getSrcId(idType)); - assertEquals(rowVertex.getValue().toString(), "[test, 1, 1.0]"); - RowVertex rowVertex2 = (RowVertex) rowVertex.withValue(ObjectRow.create(fields)); - rowVertex2 = (RowVertex) rowVertex2.withLabel("user"); - assertEquals(rowVertex.compareTo(rowVertex2), 0); - assertEquals(rowVertex2.getField(0, null), getSrcId(idType)); - assertEquals(rowVertex2.getField(1, null), BinaryString.fromString("user")); - assertEquals(rowVertex2.getField(2, null), "test"); - assertNotEquals(rowVertex2, rowVertex); - } + @Test + public void testVertex() { + Object[] fields = new Object[] {"test", 1L, 1.0}; + IType[] idTypes = getIdTypes(); + for (IType idType : idTypes) { + RowVertex rowVertex = getVertex(idType, ObjectRow.create(fields)); + assertEquals(rowVertex.getBinaryLabel(), BinaryString.fromString("vertexLabel")); + assertEquals(rowVertex.getId(), getSrcId(idType)); + assertEquals(rowVertex.getValue().toString(), "[test, 1, 1.0]"); + RowVertex rowVertex2 = (RowVertex) rowVertex.withValue(ObjectRow.create(fields)); + rowVertex2 = (RowVertex) rowVertex2.withLabel("user"); + assertEquals(rowVertex.compareTo(rowVertex2), 0); + assertEquals(rowVertex2.getField(0, null), getSrcId(idType)); + assertEquals(rowVertex2.getField(1, null), BinaryString.fromString("user")); + assertEquals(rowVertex2.getField(2, null), "test"); + assertNotEquals(rowVertex2, rowVertex); } + } - @Test - public void testVertexEqual() { - IntVertex intVertex = new IntVertex(123); - assertEquals(intVertex.id, 123); - ObjectVertex objectVertex = new ObjectVertex(new Integer(123)); - assertTrue(intVertex.equals(objectVertex)); - DoubleVertex doubleVertex = new DoubleVertex(1.23); - assertTrue(doubleVertex.id > 1.22); - } + @Test + public void testVertexEqual() { + IntVertex intVertex = new IntVertex(123); + assertEquals(intVertex.id, 123); + ObjectVertex objectVertex = new ObjectVertex(new Integer(123)); + assertTrue(intVertex.equals(objectVertex)); + DoubleVertex doubleVertex = new DoubleVertex(1.23); + assertTrue(doubleVertex.id > 1.22); + } - @Test - public void testEdgeEqual() { - LongEdge longEdge = new LongEdge(12L, 23L); - ObjectEdge objectEdge = new ObjectEdge(new Long(12L), new Long(23L)); - assertFalse(longEdge.equals(objectEdge)); - IntEdge intEdge = new IntEdge(12, 23); - ObjectEdge objectEdge2 = new ObjectEdge(new Integer(12), new Integer(23)); - assertFalse(intEdge.equals(objectEdge2)); - } + @Test + public void testEdgeEqual() { + LongEdge longEdge = new LongEdge(12L, 23L); + ObjectEdge objectEdge = new ObjectEdge(new Long(12L), new Long(23L)); + assertFalse(longEdge.equals(objectEdge)); + IntEdge intEdge = new IntEdge(12, 23); + ObjectEdge objectEdge2 = new ObjectEdge(new Integer(12), new Integer(23)); + assertFalse(intEdge.equals(objectEdge2)); + } - @Test - public void testEdgeIdentityReverse() { - testEdgeIdentityReverse(BinaryStringEdge.class, BinaryString.fromString("1"), BinaryString.fromString("2")); - testEdgeIdentityReverse(BinaryStringTsEdge.class, BinaryString.fromString("1"), BinaryString.fromString("2")); - testEdgeIdentityReverse(DoubleEdge.class, 1.0, 2.0); - testEdgeIdentityReverse(DoubleTsEdge.class, 1.0, 2.0); - testEdgeIdentityReverse(IntEdge.class, 1, 2); - testEdgeIdentityReverse(IntTsEdge.class, 1, 2); - testEdgeIdentityReverse(LongEdge.class, 1L, 2L); - testEdgeIdentityReverse(LongTsEdge.class, 1L, 2L); - testEdgeIdentityReverse(ObjectEdge.class, 1, 2); - } + @Test + public void testEdgeIdentityReverse() { + testEdgeIdentityReverse( + BinaryStringEdge.class, BinaryString.fromString("1"), BinaryString.fromString("2")); + testEdgeIdentityReverse( + BinaryStringTsEdge.class, BinaryString.fromString("1"), BinaryString.fromString("2")); + testEdgeIdentityReverse(DoubleEdge.class, 1.0, 2.0); + testEdgeIdentityReverse(DoubleTsEdge.class, 1.0, 2.0); + testEdgeIdentityReverse(IntEdge.class, 1, 2); + testEdgeIdentityReverse(IntTsEdge.class, 1, 2); + testEdgeIdentityReverse(LongEdge.class, 1L, 2L); + testEdgeIdentityReverse(LongTsEdge.class, 1L, 2L); + testEdgeIdentityReverse(ObjectEdge.class, 1, 2); + } - private void testEdgeIdentityReverse(Class edgeClass, Object srcId, Object targetId) { - RowEdge edge = ClassUtil.newInstance(edgeClass); - edge.setSrcId(srcId); - edge.setTargetId(targetId); - edge.setDirect(EdgeDirection.OUT); - RowEdge identityReverse = edge.identityReverse(); - assertEquals(identityReverse.getClass(), edgeClass); - assertEquals(identityReverse.getSrcId(), edge.getTargetId()); - assertEquals(identityReverse.getTargetId(), edge.getSrcId()); - assertEquals(identityReverse.getDirect(), EdgeDirection.IN); - } + private void testEdgeIdentityReverse( + Class edgeClass, Object srcId, Object targetId) { + RowEdge edge = ClassUtil.newInstance(edgeClass); + edge.setSrcId(srcId); + edge.setTargetId(targetId); + edge.setDirect(EdgeDirection.OUT); + RowEdge identityReverse = edge.identityReverse(); + assertEquals(identityReverse.getClass(), edgeClass); + assertEquals(identityReverse.getSrcId(), edge.getTargetId()); + assertEquals(identityReverse.getTargetId(), edge.getSrcId()); + assertEquals(identityReverse.getDirect(), EdgeDirection.IN); + } - @Test - public void testParameterizedPath() { - Path basePath = new DefaultPath(); - basePath.addNode(new ObjectVertex()); - basePath.addNode(new ObjectEdge()); + @Test + public void testParameterizedPath() { + Path basePath = new DefaultPath(); + basePath.addNode(new ObjectVertex()); + basePath.addNode(new ObjectEdge()); - DefaultParameterizedPath path = new DefaultParameterizedPath(basePath, - 1L, null, null); - assertEquals(path.getPathNodes(), basePath.getPathNodes()); + DefaultParameterizedPath path = new DefaultParameterizedPath(basePath, 1L, null, null); + assertEquals(path.getPathNodes(), basePath.getPathNodes()); - path.addNode(new ObjectVertex()); - assertEquals(path.getPathNodes().size(), 3); + path.addNode(new ObjectVertex()); + assertEquals(path.getPathNodes().size(), 3); - path.remove(2); - assertEquals(path.getPathNodes().size(), 2); + path.remove(2); + assertEquals(path.getPathNodes().size(), 2); - ParameterizedPath subPath = (ParameterizedPath) path.subPath(new int[]{0}); - assertEquals(subPath.getSystemVariables(), path.getSystemVariables()); - assertEquals(subPath.getPathNodes().size(), 1); - } + ParameterizedPath subPath = (ParameterizedPath) path.subPath(new int[] {0}); + assertEquals(subPath.getSystemVariables(), path.getSystemVariables()); + assertEquals(subPath.getPathNodes().size(), 1); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/test/java/org/apache/geaflow/dsl/common/data/TypeCastUtilTest.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/test/java/org/apache/geaflow/dsl/common/data/TypeCastUtilTest.java index 84d074f52..5d0e6e677 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/test/java/org/apache/geaflow/dsl/common/data/TypeCastUtilTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/test/java/org/apache/geaflow/dsl/common/data/TypeCastUtilTest.java @@ -26,21 +26,21 @@ public class TypeCastUtilTest { - @Test - public void testIntCast() { - Integer i = 1; - Assert.assertEquals(TypeCastUtil.cast(i, Integer.class), 1); - Assert.assertEquals(TypeCastUtil.cast(i, Long.class), 1L); - Assert.assertEquals(TypeCastUtil.cast(i, Double.class), 1.0); - Assert.assertEquals(TypeCastUtil.cast(i, BinaryString.class), BinaryString.fromString("1")); - } + @Test + public void testIntCast() { + Integer i = 1; + Assert.assertEquals(TypeCastUtil.cast(i, Integer.class), 1); + Assert.assertEquals(TypeCastUtil.cast(i, Long.class), 1L); + Assert.assertEquals(TypeCastUtil.cast(i, Double.class), 1.0); + Assert.assertEquals(TypeCastUtil.cast(i, BinaryString.class), BinaryString.fromString("1")); + } - @Test - public void testCastDouble() { - Double d = 1.0; - Assert.assertEquals(TypeCastUtil.cast(d, Integer.class), 1); - Assert.assertEquals(TypeCastUtil.cast(d, Long.class), 1L); - Assert.assertEquals(TypeCastUtil.cast(d, Double.class), 1.0); - Assert.assertEquals(TypeCastUtil.cast(d, BinaryString.class), BinaryString.fromString("1.0")); - } + @Test + public void testCastDouble() { + Double d = 1.0; + Assert.assertEquals(TypeCastUtil.cast(d, Integer.class), 1); + Assert.assertEquals(TypeCastUtil.cast(d, Long.class), 1L); + Assert.assertEquals(TypeCastUtil.cast(d, Double.class), 1.0); + Assert.assertEquals(TypeCastUtil.cast(d, BinaryString.class), BinaryString.fromString("1.0")); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/test/java/org/apache/geaflow/dsl/common/util/TestSchemaUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/test/java/org/apache/geaflow/dsl/common/util/TestSchemaUtil.java index 3fa59018b..ced3d1634 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/test/java/org/apache/geaflow/dsl/common/util/TestSchemaUtil.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/test/java/org/apache/geaflow/dsl/common/util/TestSchemaUtil.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; @@ -41,164 +42,178 @@ public class TestSchemaUtil { - public static IType[] getIdTypes() { - return new IType[]{Types.INTEGER, Types.LONG, Types.DOUBLE, Types.BINARY_STRING, - Types.SHORT}; - } - - public static VertexType getVertexType(IType idType) { - List fields = new ArrayList<>(); - fields.add(new TableField("id", idType, false)); - fields.add(new TableField("label", Types.BINARY_STRING, false)); - fields.addAll(getRowTypeFields()); - return new VertexType(fields); - } - - public static EdgeType getEdgeType(IType idType, boolean hasTimestamp) { - List fields = new ArrayList<>(); - fields.add(new TableField("srcId", idType, false)); - fields.add(new TableField("dstId", idType, false)); - fields.add(new TableField("label", Types.BINARY_STRING, false)); - if (hasTimestamp) { - fields.add(new TableField("ts", Types.LONG, false)); - } - fields.addAll(getRowTypeFields()); - return new EdgeType(fields, hasTimestamp); - } - - public static StructType getRowType() { - return new StructType(getRowTypeFields()); - } - - private static List getRowTypeFields() { - List fields = new ArrayList<>(); - fields.add(new TableField("v0", Types.INTEGER, false)); - fields.add(new TableField("v1", Types.INTEGER, false)); - fields.add(new TableField("v2", Types.LONG, false)); - fields.add(new TableField("v3", Types.LONG, false)); - fields.add(new TableField("v4", Types.SHORT, false)); - fields.add(new TableField("v5", Types.SHORT, false)); - fields.add(new TableField("v6", Types.DOUBLE, false)); - fields.add(new TableField("v7", Types.DOUBLE, false)); - fields.add(new TableField("v8", Types.BINARY_STRING, false)); - fields.add(new TableField("v9", Types.BINARY_STRING, false)); - fields.add(new TableField("v10", new ArrayType(Types.INTEGER), false)); - fields.add(new TableField("v11", new ArrayType(Types.BINARY_STRING), false)); - fields.add(new TableField("v12", new ArrayType(Types.BINARY_STRING), false)); - return fields; - } - - public static RowVertex getVertex(IType idType) { - return getVertex(idType, getRow()); - } - - public static RowVertex getVertex(IType idType, ObjectRow row) { - RowVertex vertex = VertexEdgeFactory.createVertex(getVertexType(idType)); - vertex.setId(getSrcId(idType)); - vertex.setBinaryLabel(BinaryString.fromString("vertexLabel")); - vertex.setValue(row); - return vertex; - } - - public static RowEdge getEdge(IType idType, boolean hasTimestamp) { - return getEdge(idType, getRow(), hasTimestamp); - } - - public static RowEdge getEdge(IType idType, ObjectRow row, boolean hasTimestamp) { - RowEdge edge = VertexEdgeFactory.createEdge(getEdgeType(idType, hasTimestamp)); - edge.setSrcId(getSrcId(idType)); - edge.setTargetId(getTargetId(idType)); - edge.setBinaryLabel(BinaryString.fromString("edgeLabel")); - if (hasTimestamp) { - ((IGraphElementWithTimeField) edge).setTime(284L); - } - edge.setValue(row); - return edge.withDirection(EdgeDirection.IN); - } - - public static Object getSrcId(IType type) { - String idTypeName = type.getName().toUpperCase(Locale.ROOT); - switch (idTypeName) { - case Types.TYPE_NAME_INTEGER: - return 1; - case Types.TYPE_NAME_LONG: - return 1L; - case Types.TYPE_NAME_DOUBLE: - return 1.0d; - case Types.TYPE_NAME_BINARY_STRING: - return BinaryString.fromString("1"); - default: - return (short) 1; - } - } - - public static Object getTargetId(IType type) { - String idTypeName = type.getName().toUpperCase(Locale.ROOT); - switch (idTypeName) { - case Types.TYPE_NAME_INTEGER: - return 2; - case Types.TYPE_NAME_LONG: - return 2L; - case Types.TYPE_NAME_DOUBLE: - return 2.0d; - case Types.TYPE_NAME_BINARY_STRING: - return BinaryString.fromString("2"); - default: - return (short) 2; - } - } - - public static ObjectRow getRow() { - Object[] fields = new Object[]{111, null, 1234L, null, (short) 256, null, 1.0d, null, - "testValue__#123", null, new int[]{1, 2, 5}, new String[]{"23", "wyuety12", null, "w237", null}, null}; - return ObjectRow.create(fields); - } - - public static void checkResult(Object actual, Object expect, IType type) { - if (type instanceof VertexType) { - checkVertex((RowVertex) actual, (RowVertex) expect, (VertexType) type); - } else if (type instanceof EdgeType) { - checkEdge((RowEdge) actual, (RowEdge) expect, (EdgeType) type); - } else if (type instanceof StructType) { - checkRow((Row) actual, (Row) expect, (StructType) type); - } else if (type instanceof ArrayType) { - Object[] actualArray = (Object[]) actual; - Object[] expectArray = (Object[]) expect; - Assert.assertEquals(actualArray.length, expectArray.length); - - IType componentType = ((ArrayType) type).getComponentType(); - for (int i = 0; i < actualArray.length; i++) { - checkResult(actualArray[i], expectArray[i], componentType); - } - } else { - Assert.assertEquals(actual, expect); - } - } - - private static void checkVertex(RowVertex actual, RowVertex expect, VertexType vertexType) { - Assert.assertEquals(actual.getId(), expect.getId()); - Assert.assertEquals(actual.getBinaryLabel(), expect.getBinaryLabel()); - - List valueFields = vertexType.getValueFields(); - checkResult(actual.getValue(), expect.getValue(), new StructType(valueFields)); - } - - private static void checkEdge(RowEdge actual, RowEdge expect, EdgeType edgeType) { - Assert.assertEquals(actual.getSrcId(), expect.getSrcId()); - Assert.assertEquals(actual.getTargetId(), expect.getTargetId()); - Assert.assertEquals(actual.getBinaryLabel(), expect.getBinaryLabel()); - if (edgeType.getTimestamp().isPresent()) { - Assert.assertEquals(((IGraphElementWithTimeField) actual).getTime(), - ((IGraphElementWithTimeField) expect).getTime()); - } - Assert.assertEquals(actual.getDirect(), expect.getDirect()); - checkResult(actual.getValue(), expect.getValue(), new StructType(edgeType.getValueFields())); - } - - private static void checkRow(Row actual, Row expect, StructType rowType) { - IType[] types = rowType.getTypes(); - for (int i = 0; i < types.length; i++) { - Assert.assertEquals(actual.getField(i, types[i]), expect.getField(i, types[i])); - } - } + public static IType[] getIdTypes() { + return new IType[] {Types.INTEGER, Types.LONG, Types.DOUBLE, Types.BINARY_STRING, Types.SHORT}; + } + + public static VertexType getVertexType(IType idType) { + List fields = new ArrayList<>(); + fields.add(new TableField("id", idType, false)); + fields.add(new TableField("label", Types.BINARY_STRING, false)); + fields.addAll(getRowTypeFields()); + return new VertexType(fields); + } + + public static EdgeType getEdgeType(IType idType, boolean hasTimestamp) { + List fields = new ArrayList<>(); + fields.add(new TableField("srcId", idType, false)); + fields.add(new TableField("dstId", idType, false)); + fields.add(new TableField("label", Types.BINARY_STRING, false)); + if (hasTimestamp) { + fields.add(new TableField("ts", Types.LONG, false)); + } + fields.addAll(getRowTypeFields()); + return new EdgeType(fields, hasTimestamp); + } + + public static StructType getRowType() { + return new StructType(getRowTypeFields()); + } + + private static List getRowTypeFields() { + List fields = new ArrayList<>(); + fields.add(new TableField("v0", Types.INTEGER, false)); + fields.add(new TableField("v1", Types.INTEGER, false)); + fields.add(new TableField("v2", Types.LONG, false)); + fields.add(new TableField("v3", Types.LONG, false)); + fields.add(new TableField("v4", Types.SHORT, false)); + fields.add(new TableField("v5", Types.SHORT, false)); + fields.add(new TableField("v6", Types.DOUBLE, false)); + fields.add(new TableField("v7", Types.DOUBLE, false)); + fields.add(new TableField("v8", Types.BINARY_STRING, false)); + fields.add(new TableField("v9", Types.BINARY_STRING, false)); + fields.add(new TableField("v10", new ArrayType(Types.INTEGER), false)); + fields.add(new TableField("v11", new ArrayType(Types.BINARY_STRING), false)); + fields.add(new TableField("v12", new ArrayType(Types.BINARY_STRING), false)); + return fields; + } + + public static RowVertex getVertex(IType idType) { + return getVertex(idType, getRow()); + } + + public static RowVertex getVertex(IType idType, ObjectRow row) { + RowVertex vertex = VertexEdgeFactory.createVertex(getVertexType(idType)); + vertex.setId(getSrcId(idType)); + vertex.setBinaryLabel(BinaryString.fromString("vertexLabel")); + vertex.setValue(row); + return vertex; + } + + public static RowEdge getEdge(IType idType, boolean hasTimestamp) { + return getEdge(idType, getRow(), hasTimestamp); + } + + public static RowEdge getEdge(IType idType, ObjectRow row, boolean hasTimestamp) { + RowEdge edge = VertexEdgeFactory.createEdge(getEdgeType(idType, hasTimestamp)); + edge.setSrcId(getSrcId(idType)); + edge.setTargetId(getTargetId(idType)); + edge.setBinaryLabel(BinaryString.fromString("edgeLabel")); + if (hasTimestamp) { + ((IGraphElementWithTimeField) edge).setTime(284L); + } + edge.setValue(row); + return edge.withDirection(EdgeDirection.IN); + } + + public static Object getSrcId(IType type) { + String idTypeName = type.getName().toUpperCase(Locale.ROOT); + switch (idTypeName) { + case Types.TYPE_NAME_INTEGER: + return 1; + case Types.TYPE_NAME_LONG: + return 1L; + case Types.TYPE_NAME_DOUBLE: + return 1.0d; + case Types.TYPE_NAME_BINARY_STRING: + return BinaryString.fromString("1"); + default: + return (short) 1; + } + } + + public static Object getTargetId(IType type) { + String idTypeName = type.getName().toUpperCase(Locale.ROOT); + switch (idTypeName) { + case Types.TYPE_NAME_INTEGER: + return 2; + case Types.TYPE_NAME_LONG: + return 2L; + case Types.TYPE_NAME_DOUBLE: + return 2.0d; + case Types.TYPE_NAME_BINARY_STRING: + return BinaryString.fromString("2"); + default: + return (short) 2; + } + } + + public static ObjectRow getRow() { + Object[] fields = + new Object[] { + 111, + null, + 1234L, + null, + (short) 256, + null, + 1.0d, + null, + "testValue__#123", + null, + new int[] {1, 2, 5}, + new String[] {"23", "wyuety12", null, "w237", null}, + null + }; + return ObjectRow.create(fields); + } + + public static void checkResult(Object actual, Object expect, IType type) { + if (type instanceof VertexType) { + checkVertex((RowVertex) actual, (RowVertex) expect, (VertexType) type); + } else if (type instanceof EdgeType) { + checkEdge((RowEdge) actual, (RowEdge) expect, (EdgeType) type); + } else if (type instanceof StructType) { + checkRow((Row) actual, (Row) expect, (StructType) type); + } else if (type instanceof ArrayType) { + Object[] actualArray = (Object[]) actual; + Object[] expectArray = (Object[]) expect; + Assert.assertEquals(actualArray.length, expectArray.length); + + IType componentType = ((ArrayType) type).getComponentType(); + for (int i = 0; i < actualArray.length; i++) { + checkResult(actualArray[i], expectArray[i], componentType); + } + } else { + Assert.assertEquals(actual, expect); + } + } + + private static void checkVertex(RowVertex actual, RowVertex expect, VertexType vertexType) { + Assert.assertEquals(actual.getId(), expect.getId()); + Assert.assertEquals(actual.getBinaryLabel(), expect.getBinaryLabel()); + + List valueFields = vertexType.getValueFields(); + checkResult(actual.getValue(), expect.getValue(), new StructType(valueFields)); + } + + private static void checkEdge(RowEdge actual, RowEdge expect, EdgeType edgeType) { + Assert.assertEquals(actual.getSrcId(), expect.getSrcId()); + Assert.assertEquals(actual.getTargetId(), expect.getTargetId()); + Assert.assertEquals(actual.getBinaryLabel(), expect.getBinaryLabel()); + if (edgeType.getTimestamp().isPresent()) { + Assert.assertEquals( + ((IGraphElementWithTimeField) actual).getTime(), + ((IGraphElementWithTimeField) expect).getTime()); + } + Assert.assertEquals(actual.getDirect(), expect.getDirect()); + checkResult(actual.getValue(), expect.getValue(), new StructType(edgeType.getValueFields())); + } + + private static void checkRow(Row actual, Row expect, StructType rowType) { + IType[] types = rowType.getTypes(); + for (int i = 0; i < types.length; i++) { + Assert.assertEquals(actual.getField(i, types[i]), expect.getField(i, types[i])); + } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector-tests/src/test/java/org/apache/geaflow/dsl/runtime/connector/ConnectorCompilerTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector-tests/src/test/java/org/apache/geaflow/dsl/runtime/connector/ConnectorCompilerTest.java index 6ea510405..fd5a74758 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector-tests/src/test/java/org/apache/geaflow/dsl/runtime/connector/ConnectorCompilerTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector-tests/src/test/java/org/apache/geaflow/dsl/runtime/connector/ConnectorCompilerTest.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Set; import java.util.stream.Collectors; + import org.apache.geaflow.dsl.common.compile.CompileContext; import org.apache.geaflow.dsl.common.compile.QueryCompiler; import org.apache.geaflow.dsl.runtime.QueryClient; @@ -30,12 +31,13 @@ public class ConnectorCompilerTest { - @Test - public void testFindUnResolvedPlugins() { - QueryCompiler compiler = new QueryClient(); - CompileContext context = new CompileContext(); + @Test + public void testFindUnResolvedPlugins() { + QueryCompiler compiler = new QueryClient(); + CompileContext context = new CompileContext(); - String script = "CREATE GRAPH IF NOT EXISTS dy_modern (\n" + String script = + "CREATE GRAPH IF NOT EXISTS dy_modern (\n" + " Vertex person (\n" + " id bigint ID,\n" + " name varchar\n" @@ -84,13 +86,15 @@ public void testFindUnResolvedPlugins() { + "INSERT INTO kafka_sink\n" + "SELECT * FROM kafka_source;"; - Set plugins = compiler.getDeclaredTablePlugins(script, context); - Set enginePlugins = compiler.getEnginePlugins(); - Assert.assertEquals(plugins.size(), 3); - List filteredSet = plugins.stream().filter(e -> !enginePlugins.contains(e.toUpperCase())) + Set plugins = compiler.getDeclaredTablePlugins(script, context); + Set enginePlugins = compiler.getEnginePlugins(); + Assert.assertEquals(plugins.size(), 3); + List filteredSet = + plugins.stream() + .filter(e -> !enginePlugins.contains(e.toUpperCase())) .collect(Collectors.toList()); - Assert.assertEquals(filteredSet.size(), 1); + Assert.assertEquals(filteredSet.size(), 1); - Assert.assertEquals(filteredSet.get(0), "kafka123"); - } + Assert.assertEquals(filteredSet.get(0), "kafka123"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector-tests/src/test/java/org/apache/geaflow/dsl/runtime/connector/ConnectorTester.java b/geaflow/geaflow-dsl/geaflow-dsl-connector-tests/src/test/java/org/apache/geaflow/dsl/runtime/connector/ConnectorTester.java index 03beeb0d6..bd72e40f2 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector-tests/src/test/java/org/apache/geaflow/dsl/runtime/connector/ConnectorTester.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector-tests/src/test/java/org/apache/geaflow/dsl/runtime/connector/ConnectorTester.java @@ -19,7 +19,6 @@ package org.apache.geaflow.dsl.runtime.connector; -import com.google.common.base.Preconditions; import java.io.File; import java.io.IOException; import java.io.Serializable; @@ -32,6 +31,7 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; + import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.StringUtils; @@ -51,220 +51,225 @@ import org.apache.geaflow.runtime.core.scheduler.resource.ScheduledWorkerManagerFactory; import org.testng.Assert; +import com.google.common.base.Preconditions; + public class ConnectorTester implements Serializable { - private int testTimeWaitSeconds = 0; + private int testTimeWaitSeconds = 0; - public static final String INIT_DDL = "/query/modern_graph.sql"; - public static final String DSL_STATE_REMOTE_PATH = "/tmp/dsl/"; + public static final String INIT_DDL = "/query/modern_graph.sql"; + public static final String DSL_STATE_REMOTE_PATH = "/tmp/dsl/"; - private String queryPath; + private String queryPath; - private boolean compareWithOrder = false; + private boolean compareWithOrder = false; - private String graphDefinePath; + private String graphDefinePath; - private boolean hasCustomWindowConfig = false; + private boolean hasCustomWindowConfig = false; - protected boolean dedupe = false; + protected boolean dedupe = false; - private int workerNum = (int) ExecutionConfigKeys.CONTAINER_WORKER_NUM.getDefaultValue(); + private int workerNum = (int) ExecutionConfigKeys.CONTAINER_WORKER_NUM.getDefaultValue(); - private final Map config = new HashMap<>(); + private final Map config = new HashMap<>(); - private ConnectorTester() { - try { - initRemotePath(); - } catch (IOException e) { - throw new RuntimeException(e); - } + private ConnectorTester() { + try { + initRemotePath(); + } catch (IOException e) { + throw new RuntimeException(e); } + } - public static ConnectorTester build() { - return new ConnectorTester(); - } + public static ConnectorTester build() { + return new ConnectorTester(); + } + public ConnectorTester withQueryPath(String queryPath) { + this.queryPath = queryPath; + return this; + } - public ConnectorTester withQueryPath(String queryPath) { - this.queryPath = queryPath; - return this; - } + public ConnectorTester withConfig(String key, Object value) { + this.config.put(key, String.valueOf(value)); + return this; + } - public ConnectorTester withConfig(String key, Object value) { - this.config.put(key, String.valueOf(value)); - return this; + public ConnectorTester execute() throws Exception { + if (queryPath == null) { + throw new IllegalArgumentException("You should call withQueryPath() before execute()."); } - - public ConnectorTester execute() throws Exception { - if (queryPath == null) { - throw new IllegalArgumentException("You should call withQueryPath() before execute()."); - } - Map config = new HashMap<>(); - if (!hasCustomWindowConfig) { - config.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), String.valueOf(-1L)); - } - config.put(FileConfigKeys.ROOT.getKey(), DSL_STATE_REMOTE_PATH); - config.put(DSLConfigKeys.GEAFLOW_DSL_QUERY_PATH.getKey(), FileConstants.PREFIX_JAVA_RESOURCE + queryPath); - config.put(ExecutionConfigKeys.CONTAINER_WORKER_NUM.getKey(), String.valueOf(workerNum)); - config.putAll(this.config); - initResultDirectory(); - - Environment environment = EnvironmentFactory.onLocalEnvironment(); - environment.getEnvironmentContext().withConfig(config); - - GQLPipeLine gqlPipeLine = new GQLPipeLine(environment, testTimeWaitSeconds); - - String graphDefinePath = null; - if (this.graphDefinePath != null) { - graphDefinePath = this.graphDefinePath; - } - gqlPipeLine.setPipelineHook(new TestGQLPipelineHook(graphDefinePath, queryPath)); - try { - gqlPipeLine.execute(); - } finally { - environment.shutdown(); - ClusterMetaStore.close(); - ScheduledWorkerManagerFactory.clear(); - } - return this; + Map config = new HashMap<>(); + if (!hasCustomWindowConfig) { + config.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), String.valueOf(-1L)); } - - private void initResultDirectory() throws Exception { - // delete target file path - String targetPath = getTargetPath(queryPath); - File targetFile = new File(targetPath); - if (targetFile.exists()) { - FileUtils.forceDelete(targetFile); - } + config.put(FileConfigKeys.ROOT.getKey(), DSL_STATE_REMOTE_PATH); + config.put( + DSLConfigKeys.GEAFLOW_DSL_QUERY_PATH.getKey(), + FileConstants.PREFIX_JAVA_RESOURCE + queryPath); + config.put(ExecutionConfigKeys.CONTAINER_WORKER_NUM.getKey(), String.valueOf(workerNum)); + config.putAll(this.config); + initResultDirectory(); + + Environment environment = EnvironmentFactory.onLocalEnvironment(); + environment.getEnvironmentContext().withConfig(config); + + GQLPipeLine gqlPipeLine = new GQLPipeLine(environment, testTimeWaitSeconds); + + String graphDefinePath = null; + if (this.graphDefinePath != null) { + graphDefinePath = this.graphDefinePath; } - - private void initRemotePath() throws IOException { - // delete state remote path - File stateRemoteFile = new File(DSL_STATE_REMOTE_PATH); - if (stateRemoteFile.exists()) { - FileUtils.forceDelete(stateRemoteFile); - } + gqlPipeLine.setPipelineHook(new TestGQLPipelineHook(graphDefinePath, queryPath)); + try { + gqlPipeLine.execute(); + } finally { + environment.shutdown(); + ClusterMetaStore.close(); + ScheduledWorkerManagerFactory.clear(); } - - public void checkSinkResult() throws Exception { - checkSinkResult(null); + return this; + } + + private void initResultDirectory() throws Exception { + // delete target file path + String targetPath = getTargetPath(queryPath); + File targetFile = new File(targetPath); + if (targetFile.exists()) { + FileUtils.forceDelete(targetFile); } + } - public void checkSinkResult(String dict) throws Exception { - String[] paths = queryPath.split("/"); - String lastPath = paths[paths.length - 1]; - String exceptPath = dict != null ? "/expect/" + dict + "/" + lastPath.split("\\.")[0] + ".txt" + private void initRemotePath() throws IOException { + // delete state remote path + File stateRemoteFile = new File(DSL_STATE_REMOTE_PATH); + if (stateRemoteFile.exists()) { + FileUtils.forceDelete(stateRemoteFile); + } + } + + public void checkSinkResult() throws Exception { + checkSinkResult(null); + } + + public void checkSinkResult(String dict) throws Exception { + String[] paths = queryPath.split("/"); + String lastPath = paths[paths.length - 1]; + String exceptPath = + dict != null + ? "/expect/" + dict + "/" + lastPath.split("\\.")[0] + ".txt" : "/expect/" + lastPath.split("\\.")[0] + ".txt"; - String targetPath = getTargetPath(queryPath); - String expectResult = IOUtils.resourceToString(exceptPath, Charset.defaultCharset()).trim(); - String actualResult = readFile(targetPath); - compareResult(actualResult, expectResult); + String targetPath = getTargetPath(queryPath); + String expectResult = IOUtils.resourceToString(exceptPath, Charset.defaultCharset()).trim(); + String actualResult = readFile(targetPath); + compareResult(actualResult, expectResult); + } + + private void compareResult(String actualResult, String expectResult) { + if (compareWithOrder) { + Assert.assertEquals(expectResult, actualResult); + } else { + String[] actualLines = actualResult.split("\n"); + String[] expectLines = expectResult.split("\n"); + if (dedupe) { + List actualLinesDedupe = + Arrays.asList(actualLines).stream().distinct().collect(Collectors.toList()); + actualLines = actualLinesDedupe.toArray(new String[0]); + List expectLinesDedupe = + Arrays.asList(expectLines).stream().distinct().collect(Collectors.toList()); + expectLines = expectLinesDedupe.toArray(new String[0]); + } + Arrays.sort(actualLines); + Arrays.sort(expectLines); + + String actualSort = StringUtils.join(actualLines, "\n"); + String expectSort = StringUtils.join(expectLines, "\n"); + if (!Objects.equals(actualSort, expectSort)) { + Assert.assertEquals(expectResult, actualResult); + } } + } - private void compareResult(String actualResult, String expectResult) { - if (compareWithOrder) { - Assert.assertEquals(expectResult, actualResult); - } else { - String[] actualLines = actualResult.split("\n"); - String[] expectLines = expectResult.split("\n"); - if (dedupe) { - List actualLinesDedupe = Arrays.asList(actualLines).stream().distinct().collect(Collectors.toList()); - actualLines = actualLinesDedupe.toArray(new String[0]); - List expectLinesDedupe = Arrays.asList(expectLines).stream().distinct().collect(Collectors.toList()); - expectLines = expectLinesDedupe.toArray(new String[0]); - } - Arrays.sort(actualLines); - Arrays.sort(expectLines); - - String actualSort = StringUtils.join(actualLines, "\n"); - String expectSort = StringUtils.join(expectLines, "\n"); - if (!Objects.equals(actualSort, expectSort)) { - Assert.assertEquals(expectResult, actualResult); - } - } + private String readFile(String path) throws IOException { + File file = new File(path); + if (file.isHidden()) { + return ""; } - - private String readFile(String path) throws IOException { - File file = new File(path); - if (file.isHidden()) { - return ""; - } - if (file.isFile()) { - return IOUtils.toString(new File(path).toURI(), Charset.defaultCharset()).trim(); + if (file.isFile()) { + return IOUtils.toString(new File(path).toURI(), Charset.defaultCharset()).trim(); + } + File[] files = file.listFiles(); + StringBuilder content = new StringBuilder(); + if (files != null) { + for (File subFile : files) { + String readText = readFile(subFile.getAbsolutePath()); + if (StringUtils.isBlank(readText)) { + continue; } - File[] files = file.listFiles(); - StringBuilder content = new StringBuilder(); - if (files != null) { - for (File subFile : files) { - String readText = readFile(subFile.getAbsolutePath()); - if (StringUtils.isBlank(readText)) { - continue; - } - if (content.length() > 0) { - content.append("\n"); - } - content.append(readText); - } + if (content.length() > 0) { + content.append("\n"); } - return content.toString().trim(); - } - - private static String getTargetPath(String queryPath) { - assert queryPath != null; - String[] paths = queryPath.split("/"); - String lastPath = paths[paths.length - 1]; - String targetPath = "target/" + lastPath.split("\\.")[0]; - String currentPath = new File(".").getAbsolutePath(); - targetPath = currentPath.substring(0, currentPath.length() - 1) + targetPath; - return targetPath; + content.append(readText); + } } + return content.toString().trim(); + } - private static class TestGQLPipelineHook implements GQLPipelineHook { + private static String getTargetPath(String queryPath) { + assert queryPath != null; + String[] paths = queryPath.split("/"); + String lastPath = paths[paths.length - 1]; + String targetPath = "target/" + lastPath.split("\\.")[0]; + String currentPath = new File(".").getAbsolutePath(); + targetPath = currentPath.substring(0, currentPath.length() - 1) + targetPath; + return targetPath; + } - private final String graphDefinePath; + private static class TestGQLPipelineHook implements GQLPipelineHook { - private final String queryPath; + private final String graphDefinePath; - public TestGQLPipelineHook(String graphDefinePath, String queryPath) { - this.graphDefinePath = graphDefinePath; - this.queryPath = queryPath; - } + private final String queryPath; - @Override - public String rewriteScript(String script, Configuration configuration) { - String result = script; - String regex = "\\$\\{[^}]+}"; - Pattern pattern = Pattern.compile(regex); - Matcher matcher = pattern.matcher(result); - while (matcher.find()) { - String matchedField = matcher.group(); - String replaceKey = matchedField.substring(2, matchedField.length() - 1); - if (replaceKey.equals("target")) { - result = result.replace(matchedField, getTargetPath(queryPath)); - } else { - String replaceData = configuration.getString(replaceKey); - Preconditions.checkState(replaceData != null, "Not found replace key:{}", replaceKey); - result = result.replace(matchedField, replaceData); - } - } - return result; - } + public TestGQLPipelineHook(String graphDefinePath, String queryPath) { + this.graphDefinePath = graphDefinePath; + this.queryPath = queryPath; + } - @Override - public void beforeExecute(QueryClient queryClient, QueryContext queryContext) { - if (graphDefinePath != null) { - try { - String ddl = IOUtils.resourceToString(graphDefinePath, Charset.defaultCharset()); - queryClient.executeQuery(ddl, queryContext); - } catch (IOException e) { - throw new GeaFlowDSLException(e); - } - } + @Override + public String rewriteScript(String script, Configuration configuration) { + String result = script; + String regex = "\\$\\{[^}]+}"; + Pattern pattern = Pattern.compile(regex); + Matcher matcher = pattern.matcher(result); + while (matcher.find()) { + String matchedField = matcher.group(); + String replaceKey = matchedField.substring(2, matchedField.length() - 1); + if (replaceKey.equals("target")) { + result = result.replace(matchedField, getTargetPath(queryPath)); + } else { + String replaceData = configuration.getString(replaceKey); + Preconditions.checkState(replaceData != null, "Not found replace key:{}", replaceKey); + result = result.replace(matchedField, replaceData); } + } + return result; + } - @Override - public void afterExecute(QueryClient queryClient, QueryContext queryContext) { - + @Override + public void beforeExecute(QueryClient queryClient, QueryContext queryContext) { + if (graphDefinePath != null) { + try { + String ddl = IOUtils.resourceToString(graphDefinePath, Charset.defaultCharset()); + queryClient.executeQuery(ddl, queryContext); + } catch (IOException e) { + throw new GeaFlowDSLException(e); } + } } + + @Override + public void afterExecute(QueryClient queryClient, QueryContext queryContext) {} + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector-tests/src/test/java/org/apache/geaflow/dsl/runtime/connector/HiveSourceTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector-tests/src/test/java/org/apache/geaflow/dsl/runtime/connector/HiveSourceTest.java index 3b4ed093e..a702c507c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector-tests/src/test/java/org/apache/geaflow/dsl/runtime/connector/HiveSourceTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector-tests/src/test/java/org/apache/geaflow/dsl/runtime/connector/HiveSourceTest.java @@ -24,13 +24,12 @@ public class HiveSourceTest { - @Test(enabled = false) - public void testHiveSource_001() throws Exception { - ConnectorTester - .build() - .withConfig(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT.getKey(), 1) - .withQueryPath("/query/hive_source_001.sql") - .execute() - .checkSinkResult(); - } + @Test(enabled = false) + public void testHiveSource_001() throws Exception { + ConnectorTester.build() + .withConfig(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT.getKey(), 1) + .withQueryPath("/query/hive_source_001.sql") + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/AbstractTableSource.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/AbstractTableSource.java index 8ea0a91f1..e692599ae 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/AbstractTableSource.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/AbstractTableSource.java @@ -19,9 +19,9 @@ package org.apache.geaflow.dsl.connector.api; - import java.io.IOException; import java.util.Optional; + import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; import org.apache.geaflow.dsl.connector.api.window.AllFetchWindow; import org.apache.geaflow.dsl.connector.api.window.FetchWindow; @@ -30,30 +30,36 @@ public abstract class AbstractTableSource implements TableSource { - public FetchData fetch(Partition partition, Optional startOffset, FetchWindow windowInfo) throws IOException { - switch (windowInfo.getType()) { - case ALL_WINDOW: - return fetch(partition, startOffset, (AllFetchWindow) windowInfo); - case SIZE_TUMBLING_WINDOW: - return fetch(partition, startOffset, (SizeFetchWindow) windowInfo); - case FIXED_TIME_TUMBLING_WINDOW: - return fetch(partition, startOffset, (TimeFetchWindow) windowInfo); - default: - throw new GeaFlowDSLException("Not support window type:{}", windowInfo.getType()); - } + public FetchData fetch( + Partition partition, Optional startOffset, FetchWindow windowInfo) + throws IOException { + switch (windowInfo.getType()) { + case ALL_WINDOW: + return fetch(partition, startOffset, (AllFetchWindow) windowInfo); + case SIZE_TUMBLING_WINDOW: + return fetch(partition, startOffset, (SizeFetchWindow) windowInfo); + case FIXED_TIME_TUMBLING_WINDOW: + return fetch(partition, startOffset, (TimeFetchWindow) windowInfo); + default: + throw new GeaFlowDSLException("Not support window type:{}", windowInfo.getType()); } + } + public FetchData fetch( + Partition partition, Optional startOffset, AllFetchWindow windowInfo) + throws IOException { + throw new GeaFlowDSLException("Not support"); + } - public FetchData fetch(Partition partition, Optional startOffset, AllFetchWindow windowInfo) throws IOException { - throw new GeaFlowDSLException("Not support"); - } - - public FetchData fetch(Partition partition, Optional startOffset, SizeFetchWindow windowInfo) throws IOException { - throw new GeaFlowDSLException("Not support"); - } - - public FetchData fetch(Partition partition, Optional startOffset, TimeFetchWindow windowInfo) throws IOException { - throw new GeaFlowDSLException("Not support"); - } + public FetchData fetch( + Partition partition, Optional startOffset, SizeFetchWindow windowInfo) + throws IOException { + throw new GeaFlowDSLException("Not support"); + } + public FetchData fetch( + Partition partition, Optional startOffset, TimeFetchWindow windowInfo) + throws IOException { + throw new GeaFlowDSLException("Not support"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/FetchData.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/FetchData.java index cc042dcf7..ebab1b9cb 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/FetchData.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/FetchData.java @@ -24,70 +24,62 @@ import java.util.List; import java.util.Objects; -/** - * The data fetched from the partition of {@link TableSource}. - */ +/** The data fetched from the partition of {@link TableSource}. */ public class FetchData implements Serializable { - private final Iterator dataIterator; + private final Iterator dataIterator; - private final int size; + private final int size; - private final Offset nextOffset; + private final Offset nextOffset; - private final boolean isFinish; + private final boolean isFinish; - private FetchData(Iterator dataIterator, int size, Offset nextOffset, boolean isFinish) { - this.dataIterator = Objects.requireNonNull(dataIterator); - this.size = size; - this.nextOffset = Objects.requireNonNull(nextOffset); - this.isFinish = isFinish; - } + private FetchData(Iterator dataIterator, int size, Offset nextOffset, boolean isFinish) { + this.dataIterator = Objects.requireNonNull(dataIterator); + this.size = size; + this.nextOffset = Objects.requireNonNull(nextOffset); + this.isFinish = isFinish; + } - public static FetchData createStreamFetch(List dataList, Offset nextOffset, boolean isFinish) { - return new FetchData<>(dataList.listIterator(), dataList.size(), nextOffset, isFinish); - } + public static FetchData createStreamFetch( + List dataList, Offset nextOffset, boolean isFinish) { + return new FetchData<>(dataList.listIterator(), dataList.size(), nextOffset, isFinish); + } - public static FetchData createBatchFetch(Iterator dataIterator, Offset nextOffset) { - return new FetchData<>(dataIterator, -1, nextOffset, true); - } + public static FetchData createBatchFetch(Iterator dataIterator, Offset nextOffset) { + return new FetchData<>(dataIterator, -1, nextOffset, true); + } - /** - * Returns data list. - */ - public Iterator getDataIterator() { - return dataIterator; - } + /** Returns data list. */ + public Iterator getDataIterator() { + return dataIterator; + } - /** - * Returns data size. - */ - public int getDataSize() { - return size; - } + /** Returns data size. */ + public int getDataSize() { + return size; + } - /** - * Returns the offset for next window. - */ - public Offset getNextOffset() { - return nextOffset; - } + /** Returns the offset for next window. */ + public Offset getNextOffset() { + return nextOffset; + } - public void seek(long seekPos) { - long toSkip = seekPos; - while (toSkip > 0) { - if (!dataIterator.hasNext()) { - throw new RuntimeException("seek pos:" + seekPos + " exceed the split size: " + (seekPos - toSkip)); - } - dataIterator.next(); - toSkip --; - } + public void seek(long seekPos) { + long toSkip = seekPos; + while (toSkip > 0) { + if (!dataIterator.hasNext()) { + throw new RuntimeException( + "seek pos:" + seekPos + " exceed the split size: " + (seekPos - toSkip)); + } + dataIterator.next(); + toSkip--; } + } - /** - * Returns true if the fetch has finished for the partition. - */ - public boolean isFinish() { - return isFinish; - } + /** Returns true if the fetch has finished for the partition. */ + public boolean isFinish() { + return isFinish; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/ISkipOpenAndClose.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/ISkipOpenAndClose.java index 43748d271..23b245dfc 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/ISkipOpenAndClose.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/ISkipOpenAndClose.java @@ -20,9 +20,7 @@ package org.apache.geaflow.dsl.connector.api; /** - * If the {@link TableSource} or {@link TableSink} inherit this interface, the - * open() and close() method will be skipped for compile time. + * If the {@link TableSource} or {@link TableSink} inherit this interface, the open() and close() + * method will be skipped for compile time. */ -public interface ISkipOpenAndClose { - -} +public interface ISkipOpenAndClose {} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/Offset.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/Offset.java index 7e4831d1c..ab20a9a50 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/Offset.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/Offset.java @@ -21,17 +21,13 @@ import java.io.Serializable; -/** - * Offset for each partition of the {@link TableSource}. - */ +/** Offset for each partition of the {@link TableSource}. */ public interface Offset extends Serializable { - /** - * Returns the human read-able offset string. - */ - String humanReadable(); + /** Returns the human read-able offset string. */ + String humanReadable(); - long getOffset(); + long getOffset(); - boolean isTimestamp(); + boolean isTimestamp(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/Partition.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/Partition.java index 6d42d8464..3e572481b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/Partition.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/Partition.java @@ -21,16 +21,11 @@ import java.io.Serializable; -/** - * The partition of the {@link TableSource}. - */ +/** The partition of the {@link TableSource}. */ public interface Partition extends Serializable { - /** - * Returns the name of the partition. - */ - String getName(); + /** Returns the name of the partition. */ + String getName(); - default void setIndex(int index, int parallel) { - } + default void setIndex(int index, int parallel) {} } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/TableConnector.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/TableConnector.java index b472e189e..17e8cafb6 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/TableConnector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/TableConnector.java @@ -19,13 +19,9 @@ package org.apache.geaflow.dsl.connector.api; -/** - * The interface for table connector. - */ +/** The interface for table connector. */ public interface TableConnector { - /** - * Return table connector type. - */ - String getType(); + /** Return table connector type. */ + String getType(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/TableReadableConnector.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/TableReadableConnector.java index 7dd6185f1..f9465195a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/TableReadableConnector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/TableReadableConnector.java @@ -21,10 +21,8 @@ import org.apache.geaflow.common.config.Configuration; -/** - * A readable table connector. - */ +/** A readable table connector. */ public interface TableReadableConnector extends TableConnector { - TableSource createSource(Configuration conf); + TableSource createSource(Configuration conf); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/TableSink.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/TableSink.java index ab0733f78..9e6d3a46c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/TableSink.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/TableSink.java @@ -21,38 +21,27 @@ import java.io.IOException; import java.io.Serializable; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.types.StructType; -/** - * Interface for table sink. - */ +/** Interface for table sink. */ public interface TableSink extends Serializable { - /** - * The init method for compile time. - */ - void init(Configuration tableConf, StructType schema); - - /** - * The init method for runtime. - */ - void open(RuntimeContext context); - - /** - * The write method for writing row to the table. - */ - void write(Row row) throws IOException; - - /** - * The finish callback for each window finished. - */ - void finish() throws IOException; - - /** - * The close callback for the job finish the execution. - */ - void close(); + /** The init method for compile time. */ + void init(Configuration tableConf, StructType schema); + + /** The init method for runtime. */ + void open(RuntimeContext context); + + /** The write method for writing row to the table. */ + void write(Row row) throws IOException; + + /** The finish callback for each window finished. */ + void finish() throws IOException; + + /** The close callback for the job finish the execution. */ + void close(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/TableSource.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/TableSource.java index 638b8d20b..cbb6afa5d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/TableSource.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/TableSource.java @@ -23,6 +23,7 @@ import java.io.Serializable; import java.util.List; import java.util.Optional; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.common.data.Row; @@ -30,46 +31,33 @@ import org.apache.geaflow.dsl.connector.api.serde.TableDeserializer; import org.apache.geaflow.dsl.connector.api.window.FetchWindow; -/** - * Interface for table source. - */ +/** Interface for table source. */ public interface TableSource extends Serializable { - /** - * The init method for compile time. - */ - void init(Configuration tableConf, TableSchema tableSchema); + /** The init method for compile time. */ + void init(Configuration tableConf, TableSchema tableSchema); - /** - * The init method for runtime. - */ - void open(RuntimeContext context); + /** The init method for runtime. */ + void open(RuntimeContext context); - /** - * List all the partitions for the source. - */ - List listPartitions(); + /** List all the partitions for the source. */ + List listPartitions(); - /** - * List all the partitions for the source. - */ - default List listPartitions(int parallelism) { - return listPartitions(); - } + /** List all the partitions for the source. */ + default List listPartitions(int parallelism) { + return listPartitions(); + } - /** - * Returns the {@link TableDeserializer} for the source to convert data read from - * the source to {@link Row}. - */ - TableDeserializer getDeserializer(Configuration conf); + /** + * Returns the {@link TableDeserializer} for the source to convert data read from the source to + * {@link Row}. + */ + TableDeserializer getDeserializer(Configuration conf); - /** - * Fetch data for the partition from start offset. - */ - FetchData fetch(Partition partition, Optional startOffset, FetchWindow windowInfo) throws IOException; + /** Fetch data for the partition from start offset. */ + FetchData fetch(Partition partition, Optional startOffset, FetchWindow windowInfo) + throws IOException; - /** - * The close callback for the job finish the execution. - */ - void close(); + /** The close callback for the job finish the execution. */ + void close(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/TableWritableConnector.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/TableWritableConnector.java index 97f90e7b7..bba5bfa3e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/TableWritableConnector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/TableWritableConnector.java @@ -19,16 +19,11 @@ package org.apache.geaflow.dsl.connector.api; - import org.apache.geaflow.common.config.Configuration; -/** - * A writable table connector. - */ +/** A writable table connector. */ public interface TableWritableConnector extends TableConnector { - /** - * Create the {@link TableSink} for the table connector. - */ - TableSink createSink(Configuration conf); + /** Create the {@link TableSink} for the table connector. */ + TableSink createSink(Configuration conf); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/function/GeaFlowTableSinkFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/function/GeaFlowTableSinkFunction.java index 2a1585c71..f9dafd9bd 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/function/GeaFlowTableSinkFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/function/GeaFlowTableSinkFunction.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.connector.api.function; import java.io.IOException; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.RichWindowFunction; import org.apache.geaflow.api.function.io.SinkFunction; @@ -37,64 +38,68 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * The implementation of {@link SinkFunction} for DSL table sink. - */ +/** The implementation of {@link SinkFunction} for DSL table sink. */ public class GeaFlowTableSinkFunction extends RichWindowFunction implements SinkFunction { - private static final Logger LOGGER = LoggerFactory.getLogger(GeaFlowTableSinkFunction.class); + private static final Logger LOGGER = LoggerFactory.getLogger(GeaFlowTableSinkFunction.class); - protected final GeaFlowTable table; + protected final GeaFlowTable table; - protected TableSink tableSink; + protected TableSink tableSink; - private boolean skipWrite; - private Histogram writeRt; - private Histogram flushRt; - private Meter writeTps; + private boolean skipWrite; + private Histogram writeRt; + private Histogram flushRt; + private Meter writeTps; - public GeaFlowTableSinkFunction(GeaFlowTable table, TableSink tableSink) { - this.table = table; - this.tableSink = tableSink; - } + public GeaFlowTableSinkFunction(GeaFlowTable table, TableSink tableSink) { + this.table = table; + this.tableSink = tableSink; + } - @Override - public void open(RuntimeContext runtimeContext) { - tableSink.open(runtimeContext); - LOGGER.info("open sink table: {}", table.getName()); - writeRt = MetricGroupRegistry.getInstance().getMetricGroup(MetricConstants.MODULE_DSL) + @Override + public void open(RuntimeContext runtimeContext) { + tableSink.open(runtimeContext); + LOGGER.info("open sink table: {}", table.getName()); + writeRt = + MetricGroupRegistry.getInstance() + .getMetricGroup(MetricConstants.MODULE_DSL) .histogram(MetricNameFormatter.tableWriteTimeRtName(table.getName())); - flushRt = MetricGroupRegistry.getInstance().getMetricGroup(MetricConstants.MODULE_DSL) + flushRt = + MetricGroupRegistry.getInstance() + .getMetricGroup(MetricConstants.MODULE_DSL) .histogram(MetricNameFormatter.tableFlushTimeRtName(table.getName())); - writeTps = MetricGroupRegistry.getInstance().getMetricGroup(MetricConstants.MODULE_DSL) + writeTps = + MetricGroupRegistry.getInstance() + .getMetricGroup(MetricConstants.MODULE_DSL) .meter(MetricNameFormatter.tableOutputRowTpsName(table.getName())); - Configuration conf = table.getConfigWithGlobal(runtimeContext.getConfiguration()); - skipWrite = conf.getBoolean(ConnectorConfigKeys.GEAFLOW_DSL_SINK_ENABLE_SKIP); - } + Configuration conf = table.getConfigWithGlobal(runtimeContext.getConfiguration()); + skipWrite = conf.getBoolean(ConnectorConfigKeys.GEAFLOW_DSL_SINK_ENABLE_SKIP); + } - @Override - public void write(Row row) throws Exception { - if (!skipWrite) { - long startTime = System.currentTimeMillis(); - tableSink.write(row); - writeRt.update(System.currentTimeMillis() - startTime); - writeTps.mark(); - } + @Override + public void write(Row row) throws Exception { + if (!skipWrite) { + long startTime = System.currentTimeMillis(); + tableSink.write(row); + writeRt.update(System.currentTimeMillis() - startTime); + writeTps.mark(); } + } - @Override - public void finish() { - try { - long startTime = System.currentTimeMillis(); - tableSink.finish(); - flushRt.update(System.currentTimeMillis() - startTime); - } catch (IOException e) { - throw new GeaFlowDSLException("Error in sink flush", e); - } + @Override + public void finish() { + try { + long startTime = System.currentTimeMillis(); + tableSink.finish(); + flushRt.update(System.currentTimeMillis() - startTime); + } catch (IOException e) { + throw new GeaFlowDSLException("Error in sink flush", e); } + } - @Override - public void close() { - tableSink.close(); - } + @Override + public void close() { + tableSink.close(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/function/GeaFlowTableSourceFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/function/GeaFlowTableSourceFunction.java index 296b4f08d..862035780 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/function/GeaFlowTableSourceFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/function/GeaFlowTableSourceFunction.java @@ -19,8 +19,6 @@ package org.apache.geaflow.dsl.connector.api.function; -import com.google.common.base.Preconditions; -import com.google.common.util.concurrent.ThreadFactoryBuilder; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; @@ -31,6 +29,7 @@ import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.RichFunction; import org.apache.geaflow.api.function.io.SourceFunction; @@ -61,202 +60,230 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * The implementation of {@link SourceFunction} for DSL table source. - */ +import com.google.common.base.Preconditions; +import com.google.common.util.concurrent.ThreadFactoryBuilder; + +/** The implementation of {@link SourceFunction} for DSL table source. */ public class GeaFlowTableSourceFunction extends RichFunction implements SourceFunction { - private static final Logger LOGGER = LoggerFactory.getLogger(GeaFlowTableSourceFunction.class); + private static final Logger LOGGER = LoggerFactory.getLogger(GeaFlowTableSourceFunction.class); - private static final int PARTITION_COMPARE_PERIOD_SECOND = 120; + private static final int PARTITION_COMPARE_PERIOD_SECOND = 120; - private final GeaFlowTable table; - private final TableSource tableSource; - private RuntimeContext runtimeContext; - private List partitions; - private int parallelism; + private final GeaFlowTable table; + private final TableSource tableSource; + private RuntimeContext runtimeContext; + private List partitions; + private int parallelism; - private OffsetStore offsetStore; + private OffsetStore offsetStore; - private TableDeserializer deserializer; + private TableDeserializer deserializer; - private transient volatile List oldPartitions = null; + private transient volatile List oldPartitions = null; - private transient volatile boolean isPartitionModified = false; + private transient volatile boolean isPartitionModified = false; - private transient volatile boolean isStopPartitionCheck = false; + private transient volatile boolean isStopPartitionCheck = false; - private transient ExecutorService singleThreadPool; + private transient ExecutorService singleThreadPool; - private boolean enableUploadMetrics; - private Counter rowCounter; - private Meter rowTps; - private Meter blockTps; - private Histogram parserRt; + private boolean enableUploadMetrics; + private Counter rowCounter; + private Meter rowTps; + private Meter blockTps; + private Histogram parserRt; - public GeaFlowTableSourceFunction(GeaFlowTable table, TableSource tableSource) { - this.table = table; - this.tableSource = tableSource; - } + public GeaFlowTableSourceFunction(GeaFlowTable table, TableSource tableSource) { + this.table = table; + this.tableSource = tableSource; + } - @Override - public void open(RuntimeContext runtimeContext) { - this.runtimeContext = runtimeContext; - this.offsetStore = new OffsetStore(runtimeContext, table.getName()); - rowCounter = MetricGroupRegistry.getInstance().getMetricGroup(MetricConstants.MODULE_DSL) + @Override + public void open(RuntimeContext runtimeContext) { + this.runtimeContext = runtimeContext; + this.offsetStore = new OffsetStore(runtimeContext, table.getName()); + rowCounter = + MetricGroupRegistry.getInstance() + .getMetricGroup(MetricConstants.MODULE_DSL) .counter(MetricNameFormatter.tableInputRowName(table.getName())); - rowTps = MetricGroupRegistry.getInstance().getMetricGroup(MetricConstants.MODULE_DSL) + rowTps = + MetricGroupRegistry.getInstance() + .getMetricGroup(MetricConstants.MODULE_DSL) .meter(MetricNameFormatter.tableInputRowTpsName(table.getName())); - blockTps = MetricGroupRegistry.getInstance().getMetricGroup(MetricConstants.MODULE_DSL) + blockTps = + MetricGroupRegistry.getInstance() + .getMetricGroup(MetricConstants.MODULE_DSL) .meter(MetricNameFormatter.tableInputBlockTpsName(table.getName())); - parserRt = MetricGroupRegistry.getInstance().getMetricGroup(MetricConstants.MODULE_DSL) + parserRt = + MetricGroupRegistry.getInstance() + .getMetricGroup(MetricConstants.MODULE_DSL) .histogram(MetricNameFormatter.tableParserTimeRtName(table.getName())); - } + } - @Override - public void close() { - if (tableSource != null) { - tableSource.close(); - } - if (singleThreadPool != null) { - singleThreadPool.shutdownNow(); - } + @Override + public void close() { + if (tableSource != null) { + tableSource.close(); + } + if (singleThreadPool != null) { + singleThreadPool.shutdownNow(); } + } - @Override - public void init(int parallel, int index) { - final long startTime = System.currentTimeMillis(); - this.parallelism = parallel; - tableSource.open(runtimeContext); - List allPartitions = tableSource.listPartitions(this.parallelism); - oldPartitions = new ArrayList<>(allPartitions); - singleThreadPool = startPartitionCompareThread(); - boolean isSingleFileModeRead = table.getConfigWithGlobal(runtimeContext.getConfiguration()) + @Override + public void init(int parallel, int index) { + final long startTime = System.currentTimeMillis(); + this.parallelism = parallel; + tableSource.open(runtimeContext); + List allPartitions = tableSource.listPartitions(this.parallelism); + oldPartitions = new ArrayList<>(allPartitions); + singleThreadPool = startPartitionCompareThread(); + boolean isSingleFileModeRead = + table + .getConfigWithGlobal(runtimeContext.getConfiguration()) .getBoolean(ConnectorConfigKeys.GEAFLOW_DSL_SOURCE_FILE_PARALLEL_MOD); - if (isSingleFileModeRead) { - Preconditions.checkState(allPartitions.size() == 1, - "geaflow.dsl.file.single.mod.read is ture only support single file"); - partitions = allPartitions; - } else { - partitions = assignPartition(allPartitions, - runtimeContext.getTaskArgs().getMaxParallelism(), parallel, index); - } - for (Partition partition : partitions) { - partition.setIndex(index, parallel); - } + if (isSingleFileModeRead) { + Preconditions.checkState( + allPartitions.size() == 1, + "geaflow.dsl.file.single.mod.read is ture only support single file"); + partitions = allPartitions; + } else { + partitions = + assignPartition( + allPartitions, runtimeContext.getTaskArgs().getMaxParallelism(), parallel, index); + } + for (Partition partition : partitions) { + partition.setIndex(index, parallel); + } - Configuration conf = table.getConfigWithGlobal(runtimeContext.getConfiguration()); - deserializer = tableSource.getDeserializer(conf); - if (deserializer != null) { - StructType schema = (StructType) SqlTypeUtil.convertType( - table.getRowType(GQLJavaTypeFactory.create())); - deserializer.init(conf, schema); - } - enableUploadMetrics = conf.getBoolean(ConnectorConfigKeys.GEAFLOW_DSL_SOURCE_ENABLE_UPLOAD_METRICS); - LOGGER.info("open source table: {}, taskIndex:{}, parallel: {}, assigned " - + "partitions:{}, cost {}", table.getName(), index, parallel, partitions, System.currentTimeMillis() - startTime); + Configuration conf = table.getConfigWithGlobal(runtimeContext.getConfiguration()); + deserializer = tableSource.getDeserializer(conf); + if (deserializer != null) { + StructType schema = + (StructType) SqlTypeUtil.convertType(table.getRowType(GQLJavaTypeFactory.create())); + deserializer.init(conf, schema); } + enableUploadMetrics = + conf.getBoolean(ConnectorConfigKeys.GEAFLOW_DSL_SOURCE_ENABLE_UPLOAD_METRICS); + LOGGER.info( + "open source table: {}, taskIndex:{}, parallel: {}, assigned " + "partitions:{}, cost {}", + table.getName(), + index, + parallel, + partitions, + System.currentTimeMillis() - startTime); + } - @SuppressWarnings("unchecked") - @Override - public boolean fetch(IWindow window, SourceContext ctx) throws Exception { - if (isPartitionModified) { - throw new GeaFlowDSLException("The partitions of the source table has modified!"); + @SuppressWarnings("unchecked") + @Override + public boolean fetch(IWindow window, SourceContext ctx) throws Exception { + if (isPartitionModified) { + throw new GeaFlowDSLException("The partitions of the source table has modified!"); + } + if (partitions.isEmpty()) { + return false; + } + FetchWindow fetchWindow = FetchWindowFactory.createFetchWindow(window); + long batchId = window.windowId(); + boolean isFinish = true; + for (Partition partition : partitions) { + long partitionStartTime = System.currentTimeMillis(); + Offset offset = offsetStore.readOffset(partition.getName(), batchId); + FetchData fetchData = + tableSource.fetch(partition, Optional.ofNullable(offset), fetchWindow); + Iterator dataIterator = fetchData.getDataIterator(); + while (dataIterator.hasNext()) { + Object record = dataIterator.next(); + long startTime = System.nanoTime(); + List rows; + if (deserializer != null) { + rows = ((TableDeserializer) deserializer).deserialize(record); + } else { + rows = Collections.singletonList((Row) record); } - if (partitions.isEmpty()) { - return false; + if (rows != null && rows.size() > 0) { + for (Row row : rows) { + ctx.collect(row); + } + if (enableUploadMetrics) { + parserRt.update((System.nanoTime() - startTime) / 1000L); + rowCounter.inc(rows.size()); + rowTps.mark(rows.size()); + blockTps.mark(); + } } - FetchWindow fetchWindow = FetchWindowFactory.createFetchWindow(window); - long batchId = window.windowId(); - boolean isFinish = true; - for (Partition partition : partitions) { - long partitionStartTime = System.currentTimeMillis(); - Offset offset = offsetStore.readOffset(partition.getName(), batchId); - FetchData fetchData = tableSource.fetch(partition, Optional.ofNullable(offset), fetchWindow); - Iterator dataIterator = fetchData.getDataIterator(); - while (dataIterator.hasNext()) { - Object record = dataIterator.next(); - long startTime = System.nanoTime(); - List rows; - if (deserializer != null) { - rows = ((TableDeserializer) deserializer).deserialize(record); - } else { - rows = Collections.singletonList((Row) record); - } - if (rows != null && rows.size() > 0) { - for (Row row : rows) { - ctx.collect(row); - } - if (enableUploadMetrics) { - parserRt.update((System.nanoTime() - startTime) / 1000L); - rowCounter.inc(rows.size()); - rowTps.mark(rows.size()); - blockTps.mark(); - } - } - } - // store the next offset. - offsetStore.writeOffset(partition.getName(), batchId + 1, fetchData.getNextOffset()); + } + // store the next offset. + offsetStore.writeOffset(partition.getName(), batchId + 1, fetchData.getNextOffset()); - LOGGER.info("fetch data size: {}, isFinish: {}, table: {}, partition: {}, batchId: {}," - + "nextOffset: {}, cost {}", - fetchData.getDataSize(), fetchData.isFinish(), table.getName(), partition.getName(), - batchId, fetchData.getNextOffset().humanReadable(), System.currentTimeMillis() - partitionStartTime); + LOGGER.info( + "fetch data size: {}, isFinish: {}, table: {}, partition: {}, batchId: {}," + + "nextOffset: {}, cost {}", + fetchData.getDataSize(), + fetchData.isFinish(), + table.getName(), + partition.getName(), + batchId, + fetchData.getNextOffset().humanReadable(), + System.currentTimeMillis() - partitionStartTime); - if (!fetchData.isFinish()) { - isFinish = false; - } - } - return !isFinish; + if (!fetchData.isFinish()) { + isFinish = false; + } } + return !isFinish; + } - protected ExecutorService startPartitionCompareThread() { - ThreadFactory namedThreadFactory = new ThreadFactoryBuilder().setNameFormat( - "partitionComparedThread" + "-%d").build(); + protected ExecutorService startPartitionCompareThread() { + ThreadFactory namedThreadFactory = + new ThreadFactoryBuilder().setNameFormat("partitionComparedThread" + "-%d").build(); - ExecutorService singleThreadPool = new ThreadPoolExecutor(1, 1, 0L, TimeUnit.MILLISECONDS, - new LinkedBlockingQueue<>(1), namedThreadFactory); - singleThreadPool.execute(() -> { - while (!isStopPartitionCheck) { - if (isPartitionModified) { - throw new GeaFlowDSLException( - "The partitions of the source table has modified!"); - } - LOGGER.info("partitionCompareThread is running"); - List newPartitions = tableSource.listPartitions(parallelism); - if (oldPartitions == null || newPartitions == null - || oldPartitions.size() != newPartitions.size() || !oldPartitions.equals( - newPartitions)) { - LOGGER.warn("partition modify. old partition list is: {}, new partition list " - + "is: {}", oldPartitions, newPartitions); - isPartitionModified = true; - } - oldPartitions = newPartitions; - try { - Thread.sleep(PARTITION_COMPARE_PERIOD_SECOND * 1000); - } catch (InterruptedException e) { - isStopPartitionCheck = true; - } + ExecutorService singleThreadPool = + new ThreadPoolExecutor( + 1, 1, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>(1), namedThreadFactory); + singleThreadPool.execute( + () -> { + while (!isStopPartitionCheck) { + if (isPartitionModified) { + throw new GeaFlowDSLException("The partitions of the source table has modified!"); + } + LOGGER.info("partitionCompareThread is running"); + List newPartitions = tableSource.listPartitions(parallelism); + if (oldPartitions == null + || newPartitions == null + || oldPartitions.size() != newPartitions.size() + || !oldPartitions.equals(newPartitions)) { + LOGGER.warn( + "partition modify. old partition list is: {}, new partition list " + "is: {}", + oldPartitions, + newPartitions); + isPartitionModified = true; + } + oldPartitions = newPartitions; + try { + Thread.sleep(PARTITION_COMPARE_PERIOD_SECOND * 1000); + } catch (InterruptedException e) { + isStopPartitionCheck = true; } + } }); - return singleThreadPool; - } + return singleThreadPool; + } - private List assignPartition(List allPartitions, int maxParallelism, - int parallel, int index) { - KeyGroup keyGroup = KeyGroupAssignment.computeKeyGroupRangeForOperatorIndex(maxParallelism, - parallel, index); + private List assignPartition( + List allPartitions, int maxParallelism, int parallel, int index) { + KeyGroup keyGroup = + KeyGroupAssignment.computeKeyGroupRangeForOperatorIndex(maxParallelism, parallel, index); - List partitions = new ArrayList<>(); - for (Partition partition : allPartitions) { - int keyGroupId = KeyGroupAssignment.assignToKeyGroup(partition.getName(), - maxParallelism); - if (keyGroupId >= keyGroup.getStartKeyGroup() - && keyGroupId <= keyGroup.getEndKeyGroup()) { - partitions.add(partition); - } - } - return partitions; + List partitions = new ArrayList<>(); + for (Partition partition : allPartitions) { + int keyGroupId = KeyGroupAssignment.assignToKeyGroup(partition.getName(), maxParallelism); + if (keyGroupId >= keyGroup.getStartKeyGroup() && keyGroupId <= keyGroup.getEndKeyGroup()) { + partitions.add(partition); + } } + return partitions; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/function/OffsetStore.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/function/OffsetStore.java index 31c20c141..a6899e74e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/function/OffsetStore.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/function/OffsetStore.java @@ -19,12 +19,12 @@ package org.apache.geaflow.dsl.connector.api.function; -import com.alibaba.fastjson.JSON; import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.Locale; import java.util.Map; import java.util.Objects; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -45,140 +45,153 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * The offset store for {@link TableSource} to read and write offset for each partition. - */ -public class OffsetStore { +import com.alibaba.fastjson.JSON; - private static final Logger LOGGER = LoggerFactory.getLogger(OffsetStore.class); - - private static final String KEY_SEPARATOR = "_"; - private static final String CONSOLE_OFFSET = "offset"; - private static final String CHECKPOINT_OFFSET = "checkpoint" + KEY_SEPARATOR + "offset"; - - private final long bucketNum; - private final String jobId; - private final String tableName; - - private final transient IKVStore kvStore; - private final transient Map kvStoreCache; - private final transient IKVStore jsonOffsetStore; - - public OffsetStore(RuntimeContext runtimeContext, String tableName) { - Configuration configuration = runtimeContext.getConfiguration(); - jobId = configuration.getString(ExecutionConfigKeys.JOB_APP_NAME); - this.tableName = Objects.requireNonNull(tableName); - - String backendType = runtimeContext.getConfiguration().getString(FrameworkConfigKeys.SYSTEM_OFFSET_BACKEND_TYPE); - IStoreBuilder builder = StoreBuilderFactory.build(backendType.toUpperCase(Locale.ROOT)); - if (builder instanceof RocksdbStoreBuilder) { - throw new GeaflowRuntimeException("GeaFlow offset not support ROCKSDB storage and should " - + "be configured as JDBC or MEMORY"); - } - kvStore = (IKVStore) builder.getStore(DataModel.KV, configuration); - jsonOffsetStore = (IKVStore) builder.getStore(DataModel.KV, configuration); - String stateName = configuration.getString(ExecutionConfigKeys.SYSTEM_META_TABLE, - generateKey(jobId, tableName)); - StoreContext storeContext = new StoreContext(stateName).withConfig(configuration); - storeContext.withKeySerializer(new OffsetKvSerializer()); - kvStore.init(storeContext); - storeContext.withKeySerializer(new DefaultKVSerializer(String.class, String.class)); - jsonOffsetStore.init(storeContext); - - long bucketNum = 2 * runtimeContext.getConfiguration().getLong(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT); - long streamFlyingNum = runtimeContext.getConfiguration().getInteger(FrameworkConfigKeys.STREAMING_FLYING_BATCH_NUM) + 1; - if (bucketNum < streamFlyingNum) { - bucketNum = streamFlyingNum; - } - this.bucketNum = bucketNum; - this.kvStoreCache = new HashMap<>(); - LOGGER.info("init offset store, store type is: {}, bucket num is: {}", backendType, this.bucketNum); - } +/** The offset store for {@link TableSource} to read and write offset for each partition. */ +public class OffsetStore { - public Offset readOffset(String partitionName, long batchId) { - long bucketId = batchId % bucketNum; - String key = generateKey(jobId, CHECKPOINT_OFFSET, tableName, partitionName, - String.valueOf(bucketId)); - if (kvStoreCache.containsKey(key)) { - return kvStoreCache.get(key); - } else { - Offset offset = RetryCommand.run(() -> kvStore.get(key), 3); - kvStoreCache.put(key, offset); - return offset; - } - } + private static final Logger LOGGER = LoggerFactory.getLogger(OffsetStore.class); - public void writeOffset(String partitionName, long batchId, Offset offset) { - long bucketId = batchId % bucketNum; - String key = generateKey(jobId, CHECKPOINT_OFFSET, tableName, partitionName, String.valueOf(bucketId)); - String keyForConsole = generateKey(jobId, CONSOLE_OFFSET, tableName, partitionName); - kvStoreCache.put(key, offset); - RetryCommand.run(() -> { - kvStore.put(key, offset); - jsonOffsetStore.put(keyForConsole, new ConsoleOffset(offset).toJson()); - return null; - }, 3); + private static final String KEY_SEPARATOR = "_"; + private static final String CONSOLE_OFFSET = "offset"; + private static final String CHECKPOINT_OFFSET = "checkpoint" + KEY_SEPARATOR + "offset"; - } + private final long bucketNum; + private final String jobId; + private final String tableName; - private static class OffsetKvSerializer implements IKVSerializer { + private final transient IKVStore kvStore; + private final transient Map kvStoreCache; + private final transient IKVStore jsonOffsetStore; - @Override - public byte[] serializeValue(Offset value) { - return SerializerFactory.getKryoSerializer().serialize(value); - } + public OffsetStore(RuntimeContext runtimeContext, String tableName) { + Configuration configuration = runtimeContext.getConfiguration(); + jobId = configuration.getString(ExecutionConfigKeys.JOB_APP_NAME); + this.tableName = Objects.requireNonNull(tableName); - @Override - public Offset deserializeValue(byte[] valueArray) { - return (Offset) SerializerFactory.getKryoSerializer().deserialize(valueArray); - } + String backendType = + runtimeContext.getConfiguration().getString(FrameworkConfigKeys.SYSTEM_OFFSET_BACKEND_TYPE); + IStoreBuilder builder = StoreBuilderFactory.build(backendType.toUpperCase(Locale.ROOT)); + if (builder instanceof RocksdbStoreBuilder) { + throw new GeaflowRuntimeException( + "GeaFlow offset not support ROCKSDB storage and should " + + "be configured as JDBC or MEMORY"); + } + kvStore = (IKVStore) builder.getStore(DataModel.KV, configuration); + jsonOffsetStore = (IKVStore) builder.getStore(DataModel.KV, configuration); + String stateName = + configuration.getString( + ExecutionConfigKeys.SYSTEM_META_TABLE, generateKey(jobId, tableName)); + StoreContext storeContext = new StoreContext(stateName).withConfig(configuration); + storeContext.withKeySerializer(new OffsetKvSerializer()); + kvStore.init(storeContext); + storeContext.withKeySerializer(new DefaultKVSerializer(String.class, String.class)); + jsonOffsetStore.init(storeContext); + + long bucketNum = + 2 + * runtimeContext + .getConfiguration() + .getLong(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT); + long streamFlyingNum = + runtimeContext.getConfiguration().getInteger(FrameworkConfigKeys.STREAMING_FLYING_BATCH_NUM) + + 1; + if (bucketNum < streamFlyingNum) { + bucketNum = streamFlyingNum; + } + this.bucketNum = bucketNum; + this.kvStoreCache = new HashMap<>(); + LOGGER.info( + "init offset store, store type is: {}, bucket num is: {}", backendType, this.bucketNum); + } + + public Offset readOffset(String partitionName, long batchId) { + long bucketId = batchId % bucketNum; + String key = + generateKey(jobId, CHECKPOINT_OFFSET, tableName, partitionName, String.valueOf(bucketId)); + if (kvStoreCache.containsKey(key)) { + return kvStoreCache.get(key); + } else { + Offset offset = RetryCommand.run(() -> kvStore.get(key), 3); + kvStoreCache.put(key, offset); + return offset; + } + } + + public void writeOffset(String partitionName, long batchId, Offset offset) { + long bucketId = batchId % bucketNum; + String key = + generateKey(jobId, CHECKPOINT_OFFSET, tableName, partitionName, String.valueOf(bucketId)); + String keyForConsole = generateKey(jobId, CONSOLE_OFFSET, tableName, partitionName); + kvStoreCache.put(key, offset); + RetryCommand.run( + () -> { + kvStore.put(key, offset); + jsonOffsetStore.put(keyForConsole, new ConsoleOffset(offset).toJson()); + return null; + }, + 3); + } + + private static class OffsetKvSerializer implements IKVSerializer { + + @Override + public byte[] serializeValue(Offset value) { + return SerializerFactory.getKryoSerializer().serialize(value); + } - @Override - public byte[] serializeKey(String key) { - return key.getBytes(StandardCharsets.UTF_8); - } + @Override + public Offset deserializeValue(byte[] valueArray) { + return (Offset) SerializerFactory.getKryoSerializer().deserialize(valueArray); + } - @Override - public String deserializeKey(byte[] array) { - return new String(array, StandardCharsets.UTF_8); - } + @Override + public byte[] serializeKey(String key) { + return key.getBytes(StandardCharsets.UTF_8); } - public static String generateKey(String... strings) { - StringBuilder sb = new StringBuilder(); - for (int i = 0; i < strings.length; i++) { - if (i > 0) { - sb.append(KEY_SEPARATOR); - } - sb.append(strings[i]); - } - return sb.toString().replaceAll("'", ""); + @Override + public String deserializeKey(byte[] array) { + return new String(array, StandardCharsets.UTF_8); + } + } + + public static String generateKey(String... strings) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < strings.length; i++) { + if (i > 0) { + sb.append(KEY_SEPARATOR); + } + sb.append(strings[i]); } + return sb.toString().replaceAll("'", ""); + } - public static class ConsoleOffset { + public static class ConsoleOffset { - private final long offset; + private final long offset; - private final long writeTime; + private final long writeTime; - enum TYPE { - TIMESTAMP, NON_TIMESTAMP - } + enum TYPE { + TIMESTAMP, + NON_TIMESTAMP + } - final TYPE type; + final TYPE type; - public ConsoleOffset(Offset offset) { - this.offset = offset.getOffset(); - this.writeTime = System.currentTimeMillis(); - this.type = offset.isTimestamp() ? TYPE.TIMESTAMP : TYPE.NON_TIMESTAMP; - } + public ConsoleOffset(Offset offset) { + this.offset = offset.getOffset(); + this.writeTime = System.currentTimeMillis(); + this.type = offset.isTimestamp() ? TYPE.TIMESTAMP : TYPE.NON_TIMESTAMP; + } - public String toJson() { - Map kvMap = new HashMap<>(3); - kvMap.put("offset", String.valueOf(offset)); - kvMap.put("writeTime", String.valueOf(writeTime)); - kvMap.put("type", String.valueOf(type)); - return JSON.toJSONString(kvMap); - } + public String toJson() { + Map kvMap = new HashMap<>(3); + kvMap.put("offset", String.valueOf(offset)); + kvMap.put("writeTime", String.valueOf(writeTime)); + kvMap.put("type", String.valueOf(type)); + return JSON.toJSONString(kvMap); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/serde/DeserializerFactory.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/serde/DeserializerFactory.java index 59f831454..ed9d8be21 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/serde/DeserializerFactory.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/serde/DeserializerFactory.java @@ -28,22 +28,23 @@ public class DeserializerFactory { - public static TableDeserializer loadDeserializer(Configuration conf) { - String connectorFormat = conf.getString(ConnectorConfigKeys.GEAFLOW_DSL_CONNECTOR_FORMAT, + public static TableDeserializer loadDeserializer(Configuration conf) { + String connectorFormat = + conf.getString( + ConnectorConfigKeys.GEAFLOW_DSL_CONNECTOR_FORMAT, (String) ConnectorConfigKeys.GEAFLOW_DSL_CONNECTOR_FORMAT.getDefaultValue()); - if (connectorFormat.equals(ConnectorConstants.CONNECTOR_FORMAT_JSON)) { - return (TableDeserializer) new JsonDeserializer(); - } else { - return (TableDeserializer) new TextDeserializer(); - } + if (connectorFormat.equals(ConnectorConstants.CONNECTOR_FORMAT_JSON)) { + return (TableDeserializer) new JsonDeserializer(); + } else { + return (TableDeserializer) new TextDeserializer(); } + } - public static TableDeserializer loadRowTableDeserializer() { - return (TableDeserializer) new RowTableDeserializer(); - } - - public static TableDeserializer loadTextDeserializer() { - return (TableDeserializer) new TextDeserializer(); - } + public static TableDeserializer loadRowTableDeserializer() { + return (TableDeserializer) new RowTableDeserializer(); + } + public static TableDeserializer loadTextDeserializer() { + return (TableDeserializer) new TextDeserializer(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/serde/TableDeserializer.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/serde/TableDeserializer.java index d176fbd42..237e635db 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/serde/TableDeserializer.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/serde/TableDeserializer.java @@ -21,27 +21,26 @@ import java.io.Serializable; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.types.StructType; import org.apache.geaflow.dsl.connector.api.TableSource; /** - * The deserializer interface for the {@link TableSource} to - * convert the records fetched from the source to list of {@link Row}. + * The deserializer interface for the {@link TableSource} to convert the records fetched from the + * source to list of {@link Row}. */ public interface TableDeserializer extends Serializable { - /** - * Init method for deserializer. - * - * @param conf The configuration of the table source. - * @param schema The schema of the table source. - */ - void init(Configuration conf, StructType schema); + /** + * Init method for deserializer. + * + * @param conf The configuration of the table source. + * @param schema The schema of the table source. + */ + void init(Configuration conf, StructType schema); - /** - * Returns the deserialized rows for the input record. - */ - List deserialize(IN record); + /** Returns the deserialized rows for the input record. */ + List deserialize(IN record); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/serde/impl/JsonDeserializer.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/serde/impl/JsonDeserializer.java index 4cd6bd80c..7782955aa 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/serde/impl/JsonDeserializer.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/serde/impl/JsonDeserializer.java @@ -19,12 +19,10 @@ package org.apache.geaflow.dsl.connector.api.serde.impl; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; import java.util.Collections; import java.util.List; import java.util.Objects; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ConnectorConfigKeys; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -35,63 +33,67 @@ import org.apache.geaflow.dsl.common.util.TypeCastUtil; import org.apache.geaflow.dsl.connector.api.serde.TableDeserializer; -public class JsonDeserializer implements TableDeserializer { +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; - private StructType schema; +public class JsonDeserializer implements TableDeserializer { - private ObjectMapper mapper; + private StructType schema; - private boolean ignoreParseError; + private ObjectMapper mapper; - private boolean failOnMissingField; + private boolean ignoreParseError; + private boolean failOnMissingField; - @Override - public void init(Configuration conf, StructType schema) { - this.schema = Objects.requireNonNull(schema); - this.mapper = new ObjectMapper(); - this.ignoreParseError = conf.getBoolean(ConnectorConfigKeys.GEAFLOW_DSL_CONNECTOR_FORMAT_JSON_IGNORE_PARSE_ERROR); - this.failOnMissingField = conf.getBoolean(ConnectorConfigKeys.GEAFLOW_DSL_CONNECTOR_FORMAT_JSON_FAIL_ON_MISSING_FIELD); + @Override + public void init(Configuration conf, StructType schema) { + this.schema = Objects.requireNonNull(schema); + this.mapper = new ObjectMapper(); + this.ignoreParseError = + conf.getBoolean(ConnectorConfigKeys.GEAFLOW_DSL_CONNECTOR_FORMAT_JSON_IGNORE_PARSE_ERROR); + this.failOnMissingField = + conf.getBoolean( + ConnectorConfigKeys.GEAFLOW_DSL_CONNECTOR_FORMAT_JSON_FAIL_ON_MISSING_FIELD); + } + @Override + public List deserialize(String record) { + if (record == null || record.isEmpty()) { + return Collections.emptyList(); } - - @Override - public List deserialize(String record) { - if (record == null || record.isEmpty()) { - return Collections.emptyList(); - } - Object[] values = new Object[schema.size()]; - JsonNode jsonNode = null; - try { - jsonNode = mapper.readTree(record); - } catch (JsonProcessingException e) { - // handle exception according to configuration - if (ignoreParseError) { - // return empty list - return Collections.emptyList(); - } else { - throw new GeaflowRuntimeException("fail to deserialize record " + record, e); - } - } - // if json node is null - for (int i = 0; i < schema.size(); i++) { - String fieldName = schema.getFieldNames().get(i); - if (failOnMissingField) { - if (!jsonNode.has(fieldName)) { - throw new GeaflowRuntimeException("fail to deserialize record " + record + " due to missing field " + fieldName); - } - } - JsonNode value = jsonNode.get(fieldName); - IType type = schema.getType(i); - // cast the value to the type defined in the schema. - if (value != null) { - values[i] = TypeCastUtil.cast(value.asText(), type); - } else { - values[i] = null; - } - + Object[] values = new Object[schema.size()]; + JsonNode jsonNode = null; + try { + jsonNode = mapper.readTree(record); + } catch (JsonProcessingException e) { + // handle exception according to configuration + if (ignoreParseError) { + // return empty list + return Collections.emptyList(); + } else { + throw new GeaflowRuntimeException("fail to deserialize record " + record, e); + } + } + // if json node is null + for (int i = 0; i < schema.size(); i++) { + String fieldName = schema.getFieldNames().get(i); + if (failOnMissingField) { + if (!jsonNode.has(fieldName)) { + throw new GeaflowRuntimeException( + "fail to deserialize record " + record + " due to missing field " + fieldName); } - return Collections.singletonList(ObjectRow.create(values)); + } + JsonNode value = jsonNode.get(fieldName); + IType type = schema.getType(i); + // cast the value to the type defined in the schema. + if (value != null) { + values[i] = TypeCastUtil.cast(value.asText(), type); + } else { + values[i] = null; + } } - + return Collections.singletonList(ObjectRow.create(values)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/serde/impl/RowTableDeserializer.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/serde/impl/RowTableDeserializer.java index b418e8480..92545d472 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/serde/impl/RowTableDeserializer.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/serde/impl/RowTableDeserializer.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -32,22 +33,22 @@ public class RowTableDeserializer implements TableDeserializer { - private StructType schema; + private StructType schema; - @Override - public void init(Configuration conf, StructType schema) { - this.schema = schema; - } + @Override + public void init(Configuration conf, StructType schema) { + this.schema = schema; + } - @Override - public List deserialize(Row record) { - Object[] values = new Object[schema.size()]; + @Override + public List deserialize(Row record) { + Object[] values = new Object[schema.size()]; - for (int i = 0; i < schema.size(); i++) { - IType type = schema.getType(i); - // cast the value to the type defined in the schema. - values[i] = TypeCastUtil.cast(record.getField(i, ObjectType.INSTANCE), type); - } - return Collections.singletonList(ObjectRow.create(values)); + for (int i = 0; i < schema.size(); i++) { + IType type = schema.getType(i); + // cast the value to the type defined in the schema. + values[i] = TypeCastUtil.cast(record.getField(i, ObjectType.INSTANCE), type); } + return Collections.singletonList(ObjectRow.create(values)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/serde/impl/TextDeserializer.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/serde/impl/TextDeserializer.java index c2e5fe6a2..30e2dc355 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/serde/impl/TextDeserializer.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/serde/impl/TextDeserializer.java @@ -23,6 +23,7 @@ import java.util.Collections; import java.util.List; import java.util.Objects; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ConnectorConfigKeys; @@ -35,45 +36,47 @@ public class TextDeserializer implements TableDeserializer { - private String lineSeparator; + private String lineSeparator; - private String columnSeparator; + private String columnSeparator; - private boolean isColumnTrim; + private boolean isColumnTrim; - private StructType schema; + private StructType schema; - @Override - public void init(Configuration conf, StructType schema) { - this.lineSeparator = conf.getString(ConnectorConfigKeys.GEAFLOW_DSL_LINE_SEPARATOR); - this.columnSeparator = conf.getString(ConnectorConfigKeys.GEAFLOW_DSL_COLUMN_SEPARATOR); - this.isColumnTrim = conf.getBoolean(ConnectorConfigKeys.GEAFLOW_DSL_COLUMN_TRIM); - this.schema = Objects.requireNonNull(schema); - } + @Override + public void init(Configuration conf, StructType schema) { + this.lineSeparator = conf.getString(ConnectorConfigKeys.GEAFLOW_DSL_LINE_SEPARATOR); + this.columnSeparator = conf.getString(ConnectorConfigKeys.GEAFLOW_DSL_COLUMN_SEPARATOR); + this.isColumnTrim = conf.getBoolean(ConnectorConfigKeys.GEAFLOW_DSL_COLUMN_TRIM); + this.schema = Objects.requireNonNull(schema); + } - @Override - public List deserialize(String text) { - if (text == null || text.isEmpty()) { - return Collections.emptyList(); - } - List rows = new ArrayList<>(); - String[] lines = StringUtils.splitByWholeSeparator(text, lineSeparator); - for (String line : lines) { - if (line.isEmpty() && schema.size() >= 1) { - continue; - } - String[] fields = StringUtils.splitByWholeSeparatorPreserveAllTokens(line, columnSeparator); - if (schema.size() != fields.length) { - throw new GeaFlowDSLException("Data fields size:{}, is not equal to the schema size:{}", - fields.length, schema.size()); - } - Object[] values = new Object[schema.size()]; - for (int i = 0; i < values.length; i++) { - String trimField = isColumnTrim ? StringUtils.trim(fields[i]) : fields[i]; - values[i] = TypeCastUtil.cast(trimField, schema.getType(i)); - } - rows.add(ObjectRow.create(values)); - } - return rows; + @Override + public List deserialize(String text) { + if (text == null || text.isEmpty()) { + return Collections.emptyList(); + } + List rows = new ArrayList<>(); + String[] lines = StringUtils.splitByWholeSeparator(text, lineSeparator); + for (String line : lines) { + if (line.isEmpty() && schema.size() >= 1) { + continue; + } + String[] fields = StringUtils.splitByWholeSeparatorPreserveAllTokens(line, columnSeparator); + if (schema.size() != fields.length) { + throw new GeaFlowDSLException( + "Data fields size:{}, is not equal to the schema size:{}", + fields.length, + schema.size()); + } + Object[] values = new Object[schema.size()]; + for (int i = 0; i < values.length; i++) { + String trimField = isColumnTrim ? StringUtils.trim(fields[i]) : fields[i]; + values[i] = TypeCastUtil.cast(trimField, schema.getType(i)); + } + rows.add(ObjectRow.create(values)); } + return rows; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/util/ConnectorConstants.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/util/ConnectorConstants.java index 59d0c71ef..247d18a29 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/util/ConnectorConstants.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/util/ConnectorConstants.java @@ -21,8 +21,7 @@ public class ConnectorConstants { - public static final String START_TIME_FORMAT = "yyyy-MM-dd HH:mm:ss"; - public static final String CONNECTOR_FORMAT_JSON = "json"; - public static final String CONNECTOR_FORMAT_TEXT = "text"; - + public static final String START_TIME_FORMAT = "yyyy-MM-dd HH:mm:ss"; + public static final String CONNECTOR_FORMAT_JSON = "json"; + public static final String CONNECTOR_FORMAT_TEXT = "text"; } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/util/ConnectorFactory.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/util/ConnectorFactory.java index 0c27fc694..f0dfbcb4d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/util/ConnectorFactory.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/util/ConnectorFactory.java @@ -20,23 +20,24 @@ package org.apache.geaflow.dsl.connector.api.util; import java.util.ServiceLoader; + import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; import org.apache.geaflow.dsl.connector.api.TableConnector; public class ConnectorFactory { - public static TableConnector loadConnector(String tableType) { - ServiceLoader connectors = ServiceLoader.load(TableConnector.class); - TableConnector currentConnector = null; - for (TableConnector connector : connectors) { - if (connector.getType().equalsIgnoreCase(tableType)) { - currentConnector = connector; - break; - } - } - if (currentConnector == null) { - throw new GeaFlowDSLException("Table type: '{}' has not implement", tableType); - } - return currentConnector; + public static TableConnector loadConnector(String tableType) { + ServiceLoader connectors = ServiceLoader.load(TableConnector.class); + TableConnector currentConnector = null; + for (TableConnector connector : connectors) { + if (connector.getType().equalsIgnoreCase(tableType)) { + currentConnector = connector; + break; + } + } + if (currentConnector == null) { + throw new GeaFlowDSLException("Table type: '{}' has not implement", tableType); } + return currentConnector; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/AbstractFetchWindow.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/AbstractFetchWindow.java index 478f0bff7..1051c0eec 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/AbstractFetchWindow.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/AbstractFetchWindow.java @@ -21,14 +21,14 @@ public abstract class AbstractFetchWindow implements FetchWindow { - protected final long windowId; + protected final long windowId; - public AbstractFetchWindow(long windowId) { - this.windowId = windowId; - } + public AbstractFetchWindow(long windowId) { + this.windowId = windowId; + } - @Override - public long windowId() { - return windowId; - } + @Override + public long windowId() { + return windowId; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/AllFetchWindow.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/AllFetchWindow.java index e7dad2422..0d8d60295 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/AllFetchWindow.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/AllFetchWindow.java @@ -21,22 +21,20 @@ import org.apache.geaflow.api.window.WindowType; -/** - * Fetch all. - */ +/** Fetch all. */ public class AllFetchWindow extends AbstractFetchWindow { - public AllFetchWindow(long windowId) { - super(windowId); - } + public AllFetchWindow(long windowId) { + super(windowId); + } - @Override - public long windowSize() { - return -1; - } + @Override + public long windowSize() { + return -1; + } - @Override - public WindowType getType() { - return WindowType.ALL_WINDOW; - } + @Override + public WindowType getType() { + return WindowType.ALL_WINDOW; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/FetchWindow.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/FetchWindow.java index 8bb84bd8f..35e6b040a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/FetchWindow.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/FetchWindow.java @@ -21,24 +21,15 @@ import org.apache.geaflow.api.window.WindowType; -/** - * Interface for the table source fetch records. - */ +/** Interface for the table source fetch records. */ public interface FetchWindow { - /** - * Return the window id. - */ - long windowId(); - - /** - * Return the window size. - */ - long windowSize(); + /** Return the window id. */ + long windowId(); - /** - * Return the window type. - */ - WindowType getType(); + /** Return the window size. */ + long windowSize(); + /** Return the window type. */ + WindowType getType(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/FetchWindowFactory.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/FetchWindowFactory.java index b740e1c37..682a95586 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/FetchWindowFactory.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/FetchWindowFactory.java @@ -24,22 +24,20 @@ import org.apache.geaflow.api.window.impl.SizeTumblingWindow; import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; -/** - * Convert to the fetch window from the common window. - */ +/** Convert to the fetch window from the common window. */ public class FetchWindowFactory { - public static FetchWindow createFetchWindow(IWindow window) { - switch (window.getType()) { - case ALL_WINDOW: - return new AllFetchWindow(window.windowId()); - case SIZE_TUMBLING_WINDOW: - return new SizeFetchWindow(window.windowId(), ((SizeTumblingWindow) window).getSize()); - case FIXED_TIME_TUMBLING_WINDOW: - return new TimeFetchWindow(window.windowId(), ((FixedTimeTumblingWindow) window).getTimeWindowSize()); - default: - throw new GeaFlowDSLException("Not support window type:{}", window.getType()); - } + public static FetchWindow createFetchWindow(IWindow window) { + switch (window.getType()) { + case ALL_WINDOW: + return new AllFetchWindow(window.windowId()); + case SIZE_TUMBLING_WINDOW: + return new SizeFetchWindow(window.windowId(), ((SizeTumblingWindow) window).getSize()); + case FIXED_TIME_TUMBLING_WINDOW: + return new TimeFetchWindow( + window.windowId(), ((FixedTimeTumblingWindow) window).getTimeWindowSize()); + default: + throw new GeaFlowDSLException("Not support window type:{}", window.getType()); } - + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/SizeFetchWindow.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/SizeFetchWindow.java index 0a2785748..52b227b16 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/SizeFetchWindow.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/SizeFetchWindow.java @@ -21,25 +21,23 @@ import org.apache.geaflow.api.window.WindowType; -/** - * size window. - */ +/** size window. */ public class SizeFetchWindow extends AbstractFetchWindow { - private final long size; + private final long size; - public SizeFetchWindow(long windowId, long size) { - super(windowId); - this.size = size; - } + public SizeFetchWindow(long windowId, long size) { + super(windowId); + this.size = size; + } - @Override - public long windowSize() { - return size; - } + @Override + public long windowSize() { + return size; + } - @Override - public WindowType getType() { - return WindowType.SIZE_TUMBLING_WINDOW; - } + @Override + public WindowType getType() { + return WindowType.SIZE_TUMBLING_WINDOW; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/TimeFetchWindow.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/TimeFetchWindow.java index 19aefef6c..3b3ce8249 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/TimeFetchWindow.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/main/java/org/apache/geaflow/dsl/connector/api/window/TimeFetchWindow.java @@ -21,35 +21,33 @@ import org.apache.geaflow.api.window.WindowType; -/** - * Time window. - */ +/** Time window. */ public class TimeFetchWindow extends AbstractFetchWindow { - private final long windowSizeInSecond; + private final long windowSizeInSecond; - public TimeFetchWindow(long windowId, long windowSizeInSecond) { - super(windowId); - this.windowSizeInSecond = windowSizeInSecond; - } + public TimeFetchWindow(long windowId, long windowSizeInSecond) { + super(windowId); + this.windowSizeInSecond = windowSizeInSecond; + } - // include - public long getStartWindowTime(long startTime) { - return startTime + windowId * windowSizeInSecond * 1000; - } + // include + public long getStartWindowTime(long startTime) { + return startTime + windowId * windowSizeInSecond * 1000; + } - // exclude - public long getEndWindowTime(long startTime) { - return startTime + (windowId + 1) * windowSizeInSecond * 1000; - } + // exclude + public long getEndWindowTime(long startTime) { + return startTime + (windowId + 1) * windowSizeInSecond * 1000; + } - @Override - public long windowSize() { - return windowSizeInSecond; - } + @Override + public long windowSize() { + return windowSizeInSecond; + } - @Override - public WindowType getType() { - return WindowType.FIXED_TIME_TUMBLING_WINDOW; - } + @Override + public WindowType getType() { + return WindowType.FIXED_TIME_TUMBLING_WINDOW; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/test/java/org/apache/geaflow/dsl/connector/api/JsonDeserializerTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/test/java/org/apache/geaflow/dsl/connector/api/JsonDeserializerTest.java index b49782691..3e4ca4cf3 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/test/java/org/apache/geaflow/dsl/connector/api/JsonDeserializerTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-api/src/test/java/org/apache/geaflow/dsl/connector/api/JsonDeserializerTest.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.type.primitive.BinaryStringType; @@ -34,84 +35,79 @@ public class JsonDeserializerTest { - @Test - public void testDeserialize() { - JsonDeserializer deserializer = new JsonDeserializer(); - StructType dataSchema = new StructType( + @Test + public void testDeserialize() { + JsonDeserializer deserializer = new JsonDeserializer(); + StructType dataSchema = + new StructType( new TableField("id", IntegerType.INSTANCE, false), new TableField("name", BinaryStringType.INSTANCE, true), - new TableField("age", IntegerType.INSTANCE, false) - ); - deserializer.init(new Configuration(), dataSchema); - List row = deserializer.deserialize("{\"id\":1, \"name\":\"amy\", \"age\":10}"); - List rowWithNull = deserializer.deserialize("{\"id\":1, \"name\":\"amy\"}"); - Assert.assertEquals(row.get(0).getField(0, IntegerType.INSTANCE), 1); - Assert.assertEquals(row.get(0).getField(1, BinaryStringType.INSTANCE).toString(), "amy"); - Assert.assertEquals(row.get(0).getField(2, IntegerType.INSTANCE), 10); - Assert.assertEquals(rowWithNull.get(0).getField(0, IntegerType.INSTANCE), 1); - Assert.assertEquals(rowWithNull.get(0).getField(1, BinaryStringType.INSTANCE).toString(), "amy"); - Assert.assertEquals(rowWithNull.get(0).getField(2, IntegerType.INSTANCE), null); - - } + new TableField("age", IntegerType.INSTANCE, false)); + deserializer.init(new Configuration(), dataSchema); + List row = deserializer.deserialize("{\"id\":1, \"name\":\"amy\", \"age\":10}"); + List rowWithNull = deserializer.deserialize("{\"id\":1, \"name\":\"amy\"}"); + Assert.assertEquals(row.get(0).getField(0, IntegerType.INSTANCE), 1); + Assert.assertEquals(row.get(0).getField(1, BinaryStringType.INSTANCE).toString(), "amy"); + Assert.assertEquals(row.get(0).getField(2, IntegerType.INSTANCE), 10); + Assert.assertEquals(rowWithNull.get(0).getField(0, IntegerType.INSTANCE), 1); + Assert.assertEquals( + rowWithNull.get(0).getField(1, BinaryStringType.INSTANCE).toString(), "amy"); + Assert.assertEquals(rowWithNull.get(0).getField(2, IntegerType.INSTANCE), null); + } - - @Test - public void testDeserializeEmptyString() { - JsonDeserializer deserializer = new JsonDeserializer(); - StructType dataSchema = new StructType( + @Test + public void testDeserializeEmptyString() { + JsonDeserializer deserializer = new JsonDeserializer(); + StructType dataSchema = + new StructType( new TableField("id", IntegerType.INSTANCE, false), new TableField("name", BinaryStringType.INSTANCE, true), - new TableField("age", IntegerType.INSTANCE, false) - ); - deserializer.init(new Configuration(), dataSchema); - List rows = deserializer.deserialize(""); - List testNullRows = deserializer.deserialize(null); - Assert.assertEquals(rows, Collections.emptyList()); - Assert.assertEquals(testNullRows, Collections.emptyList()); - - } + new TableField("age", IntegerType.INSTANCE, false)); + deserializer.init(new Configuration(), dataSchema); + List rows = deserializer.deserialize(""); + List testNullRows = deserializer.deserialize(null); + Assert.assertEquals(rows, Collections.emptyList()); + Assert.assertEquals(testNullRows, Collections.emptyList()); + } - @Test(expected = GeaflowRuntimeException.class) - public void testDeserializeParseError() { - JsonDeserializer deserializer = new JsonDeserializer(); - StructType dataSchema = new StructType( + @Test(expected = GeaflowRuntimeException.class) + public void testDeserializeParseError() { + JsonDeserializer deserializer = new JsonDeserializer(); + StructType dataSchema = + new StructType( new TableField("id", IntegerType.INSTANCE, false), new TableField("name", BinaryStringType.INSTANCE, true), - new TableField("age", IntegerType.INSTANCE, false) - ); - deserializer.init(new Configuration(), dataSchema); - List rows = deserializer.deserialize("test"); - } + new TableField("age", IntegerType.INSTANCE, false)); + deserializer.init(new Configuration(), dataSchema); + List rows = deserializer.deserialize("test"); + } - @Test - public void testDeserializeIgnoreParseError() { - JsonDeserializer deserializer = new JsonDeserializer(); - StructType dataSchema = new StructType( + @Test + public void testDeserializeIgnoreParseError() { + JsonDeserializer deserializer = new JsonDeserializer(); + StructType dataSchema = + new StructType( new TableField("id", IntegerType.INSTANCE, false), new TableField("name", BinaryStringType.INSTANCE, true), - new TableField("age", IntegerType.INSTANCE, false) - ); - Configuration conf = new Configuration(); - conf.put("geaflow.dsl.connector.format.json.ignore-parse-error", "true"); - deserializer.init(conf, dataSchema); - List rows = deserializer.deserialize("test"); - Assert.assertEquals(rows, Collections.emptyList()); - } + new TableField("age", IntegerType.INSTANCE, false)); + Configuration conf = new Configuration(); + conf.put("geaflow.dsl.connector.format.json.ignore-parse-error", "true"); + deserializer.init(conf, dataSchema); + List rows = deserializer.deserialize("test"); + Assert.assertEquals(rows, Collections.emptyList()); + } - @Test(expected = GeaflowRuntimeException.class) - public void testDeserializeFailOnMissingField() { - JsonDeserializer deserializer = new JsonDeserializer(); - StructType dataSchema = new StructType( + @Test(expected = GeaflowRuntimeException.class) + public void testDeserializeFailOnMissingField() { + JsonDeserializer deserializer = new JsonDeserializer(); + StructType dataSchema = + new StructType( new TableField("id", IntegerType.INSTANCE, false), new TableField("name", BinaryStringType.INSTANCE, true), - new TableField("age", IntegerType.INSTANCE, false) - ); - Configuration conf = new Configuration(); - conf.put("geaflow.dsl.connector.format.json.fail-on-missing-field", "true"); - deserializer.init(conf, dataSchema); - List rowWithMissingField = deserializer.deserialize("{\"id\":1, \"name\":\"amy\"}"); - - } - - + new TableField("age", IntegerType.INSTANCE, false)); + Configuration conf = new Configuration(); + conf.put("geaflow.dsl.connector.format.json.fail-on-missing-field", "true"); + deserializer.init(conf, dataSchema); + List rowWithMissingField = deserializer.deserialize("{\"id\":1, \"name\":\"amy\"}"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-console/src/main/java/org/apache/geaflow/dsl/connector/console/ConsoleConfigKeys.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-console/src/main/java/org/apache/geaflow/dsl/connector/console/ConsoleConfigKeys.java index d80a5ed80..d482afe4b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-console/src/main/java/org/apache/geaflow/dsl/connector/console/ConsoleConfigKeys.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-console/src/main/java/org/apache/geaflow/dsl/connector/console/ConsoleConfigKeys.java @@ -24,8 +24,8 @@ public class ConsoleConfigKeys { - public static final ConfigKey GEAFLOW_DSL_CONSOLE_SKIP = ConfigKeys - .key("geaflow.dsl.console.skip") - .defaultValue(false) - .description("Whether skip write to console."); + public static final ConfigKey GEAFLOW_DSL_CONSOLE_SKIP = + ConfigKeys.key("geaflow.dsl.console.skip") + .defaultValue(false) + .description("Whether skip write to console."); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-console/src/main/java/org/apache/geaflow/dsl/connector/console/ConsoleTableConnector.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-console/src/main/java/org/apache/geaflow/dsl/connector/console/ConsoleTableConnector.java index 46cc0b479..bc678227a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-console/src/main/java/org/apache/geaflow/dsl/connector/console/ConsoleTableConnector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-console/src/main/java/org/apache/geaflow/dsl/connector/console/ConsoleTableConnector.java @@ -25,13 +25,13 @@ public class ConsoleTableConnector implements TableWritableConnector { - @Override - public String getType() { - return "CONSOLE"; - } + @Override + public String getType() { + return "CONSOLE"; + } - @Override - public TableSink createSink(Configuration conf) { - return new ConsoleTableSink(); - } + @Override + public TableSink createSink(Configuration conf) { + return new ConsoleTableSink(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-console/src/main/java/org/apache/geaflow/dsl/connector/console/ConsoleTableSink.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-console/src/main/java/org/apache/geaflow/dsl/connector/console/ConsoleTableSink.java index 97a4cd047..c7331a6e2 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-console/src/main/java/org/apache/geaflow/dsl/connector/console/ConsoleTableSink.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-console/src/main/java/org/apache/geaflow/dsl/connector/console/ConsoleTableSink.java @@ -29,34 +29,28 @@ public class ConsoleTableSink implements TableSink { - private static final Logger LOGGER = LoggerFactory.getLogger(ConsoleTableSink.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ConsoleTableSink.class); - private boolean skip; + private boolean skip; - @Override - public void init(Configuration tableConf, StructType schema) { - skip = tableConf.getBoolean(ConsoleConfigKeys.GEAFLOW_DSL_CONSOLE_SKIP); - } - - @Override - public void open(RuntimeContext context) { + @Override + public void init(Configuration tableConf, StructType schema) { + skip = tableConf.getBoolean(ConsoleConfigKeys.GEAFLOW_DSL_CONSOLE_SKIP); + } - } + @Override + public void open(RuntimeContext context) {} - @Override - public void write(Row row) { - if (!skip) { - LOGGER.info(row.toString()); - } + @Override + public void write(Row row) { + if (!skip) { + LOGGER.info(row.toString()); } + } - @Override - public void finish() { + @Override + public void finish() {} - } - - @Override - public void close() { - - } + @Override + public void close() {} } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-console/src/test/java/org/apache/geaflow/dsl/connector/console/ConsoleTableConnectorTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-console/src/test/java/org/apache/geaflow/dsl/connector/console/ConsoleTableConnectorTest.java index 0ec59dca6..25bbf2984 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-console/src/test/java/org/apache/geaflow/dsl/connector/console/ConsoleTableConnectorTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-console/src/test/java/org/apache/geaflow/dsl/connector/console/ConsoleTableConnectorTest.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Map; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.task.TaskArgs; @@ -33,59 +34,60 @@ public class ConsoleTableConnectorTest { - @Test - public void testConsole() throws IOException { - ConsoleTableConnector connector = new ConsoleTableConnector(); - TableSink sink = connector.createSink(new Configuration()); - Assert.assertEquals(sink.getClass(), ConsoleTableSink.class); - Configuration tableConf = new Configuration(); + @Test + public void testConsole() throws IOException { + ConsoleTableConnector connector = new ConsoleTableConnector(); + TableSink sink = connector.createSink(new Configuration()); + Assert.assertEquals(sink.getClass(), ConsoleTableSink.class); + Configuration tableConf = new Configuration(); - sink.init(tableConf, new StructType()); + sink.init(tableConf, new StructType()); - sink.open(new RuntimeContext() { - @Override - public long getPipelineId() { - return 0; - } + sink.open( + new RuntimeContext() { + @Override + public long getPipelineId() { + return 0; + } - @Override - public String getPipelineName() { - return null; - } + @Override + public String getPipelineName() { + return null; + } - @Override - public TaskArgs getTaskArgs() { - return null; - } + @Override + public TaskArgs getTaskArgs() { + return null; + } - @Override - public Configuration getConfiguration() { - return null; - } + @Override + public Configuration getConfiguration() { + return null; + } - @Override - public String getWorkPath() { - return null; - } + @Override + public String getWorkPath() { + return null; + } - @Override - public MetricGroup getMetric() { - return null; - } + @Override + public MetricGroup getMetric() { + return null; + } - @Override - public RuntimeContext clone(Map opConfig) { - return null; - } + @Override + public RuntimeContext clone(Map opConfig) { + return null; + } - @Override - public long getWindowId() { - return 0; - } + @Override + public long getWindowId() { + return 0; + } }); - sink.write(ObjectRow.create(1, 2, 3)); - sink.finish(); - sink.close(); - } + sink.write(ObjectRow.create(1, 2, 3)); + sink.finish(); + sink.close(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/main/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchConfigKeys.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/main/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchConfigKeys.java index 65df4b4cd..67d276ae4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/main/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchConfigKeys.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/main/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchConfigKeys.java @@ -24,48 +24,48 @@ public class ElasticsearchConfigKeys { - public static final ConfigKey GEAFLOW_DSL_ELASTICSEARCH_HOSTS = ConfigKeys - .key("geaflow.dsl.elasticsearch.hosts") - .noDefaultValue() - .description("Elasticsearch cluster hosts list."); + public static final ConfigKey GEAFLOW_DSL_ELASTICSEARCH_HOSTS = + ConfigKeys.key("geaflow.dsl.elasticsearch.hosts") + .noDefaultValue() + .description("Elasticsearch cluster hosts list."); - public static final ConfigKey GEAFLOW_DSL_ELASTICSEARCH_INDEX = ConfigKeys - .key("geaflow.dsl.elasticsearch.index") - .noDefaultValue() - .description("Elasticsearch index name."); + public static final ConfigKey GEAFLOW_DSL_ELASTICSEARCH_INDEX = + ConfigKeys.key("geaflow.dsl.elasticsearch.index") + .noDefaultValue() + .description("Elasticsearch index name."); - public static final ConfigKey GEAFLOW_DSL_ELASTICSEARCH_DOCUMENT_ID_FIELD = ConfigKeys - .key("geaflow.dsl.elasticsearch.document.id.field") - .noDefaultValue() - .description("Elasticsearch document id field."); + public static final ConfigKey GEAFLOW_DSL_ELASTICSEARCH_DOCUMENT_ID_FIELD = + ConfigKeys.key("geaflow.dsl.elasticsearch.document.id.field") + .noDefaultValue() + .description("Elasticsearch document id field."); - public static final ConfigKey GEAFLOW_DSL_ELASTICSEARCH_USERNAME = ConfigKeys - .key("geaflow.dsl.elasticsearch.username") - .noDefaultValue() - .description("Elasticsearch username for authentication."); + public static final ConfigKey GEAFLOW_DSL_ELASTICSEARCH_USERNAME = + ConfigKeys.key("geaflow.dsl.elasticsearch.username") + .noDefaultValue() + .description("Elasticsearch username for authentication."); - public static final ConfigKey GEAFLOW_DSL_ELASTICSEARCH_PASSWORD = ConfigKeys - .key("geaflow.dsl.elasticsearch.password") - .noDefaultValue() - .description("Elasticsearch password for authentication."); + public static final ConfigKey GEAFLOW_DSL_ELASTICSEARCH_PASSWORD = + ConfigKeys.key("geaflow.dsl.elasticsearch.password") + .noDefaultValue() + .description("Elasticsearch password for authentication."); - public static final ConfigKey GEAFLOW_DSL_ELASTICSEARCH_BATCH_SIZE = ConfigKeys - .key("geaflow.dsl.elasticsearch.batch.size") - .defaultValue("1000") - .description("Elasticsearch batch write size."); + public static final ConfigKey GEAFLOW_DSL_ELASTICSEARCH_BATCH_SIZE = + ConfigKeys.key("geaflow.dsl.elasticsearch.batch.size") + .defaultValue("1000") + .description("Elasticsearch batch write size."); - public static final ConfigKey GEAFLOW_DSL_ELASTICSEARCH_SCROLL_TIMEOUT = ConfigKeys - .key("geaflow.dsl.elasticsearch.scroll.timeout") - .defaultValue("60s") - .description("Elasticsearch scroll query timeout."); + public static final ConfigKey GEAFLOW_DSL_ELASTICSEARCH_SCROLL_TIMEOUT = + ConfigKeys.key("geaflow.dsl.elasticsearch.scroll.timeout") + .defaultValue("60s") + .description("Elasticsearch scroll query timeout."); - public static final ConfigKey GEAFLOW_DSL_ELASTICSEARCH_CONNECTION_TIMEOUT = ConfigKeys - .key("geaflow.dsl.elasticsearch.connection.timeout") - .defaultValue("1000") - .description("Elasticsearch connection timeout in milliseconds."); + public static final ConfigKey GEAFLOW_DSL_ELASTICSEARCH_CONNECTION_TIMEOUT = + ConfigKeys.key("geaflow.dsl.elasticsearch.connection.timeout") + .defaultValue("1000") + .description("Elasticsearch connection timeout in milliseconds."); - public static final ConfigKey GEAFLOW_DSL_ELASTICSEARCH_SOCKET_TIMEOUT = ConfigKeys - .key("geaflow.dsl.elasticsearch.socket.timeout") - .defaultValue("30000") - .description("Elasticsearch socket timeout in milliseconds."); + public static final ConfigKey GEAFLOW_DSL_ELASTICSEARCH_SOCKET_TIMEOUT = + ConfigKeys.key("geaflow.dsl.elasticsearch.socket.timeout") + .defaultValue("30000") + .description("Elasticsearch socket timeout in milliseconds."); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/main/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchConstants.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/main/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchConstants.java index ed8c7adac..945c6ded4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/main/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchConstants.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/main/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchConstants.java @@ -21,24 +21,23 @@ public class ElasticsearchConstants { - public static final int DEFAULT_BATCH_SIZE = 1000; + public static final int DEFAULT_BATCH_SIZE = 1000; - public static final String DEFAULT_SCROLL_TIMEOUT = "60s"; + public static final String DEFAULT_SCROLL_TIMEOUT = "60s"; - public static final int DEFAULT_CONNECTION_TIMEOUT = 1000; + public static final int DEFAULT_CONNECTION_TIMEOUT = 1000; - public static final int DEFAULT_SOCKET_TIMEOUT = 30000; + public static final int DEFAULT_SOCKET_TIMEOUT = 30000; - public static final int DEFAULT_SEARCH_SIZE = 1000; + public static final int DEFAULT_SEARCH_SIZE = 1000; - public static final String ES_SCHEMA_SUFFIX = "://"; + public static final String ES_SCHEMA_SUFFIX = "://"; - public static final String ES_HTTP_SCHEME = "http"; + public static final String ES_HTTP_SCHEME = "http"; - public static final String ES_HTTPS_SCHEME = "https"; + public static final String ES_HTTPS_SCHEME = "https"; - public static final String ES_SPLIT_COMMA = ","; - - public static final String ES_SPLIT_COLON = ";"; + public static final String ES_SPLIT_COMMA = ","; + public static final String ES_SPLIT_COLON = ";"; } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/main/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableConnector.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/main/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableConnector.java index f8950a8d0..9e5930e11 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/main/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableConnector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/main/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableConnector.java @@ -27,20 +27,20 @@ public class ElasticsearchTableConnector implements TableReadableConnector, TableWritableConnector { - public static final String TYPE = "ELASTICSEARCH"; + public static final String TYPE = "ELASTICSEARCH"; - @Override - public String getType() { - return TYPE; - } + @Override + public String getType() { + return TYPE; + } - @Override - public TableSource createSource(Configuration conf) { - return new ElasticsearchTableSource(); - } + @Override + public TableSource createSource(Configuration conf) { + return new ElasticsearchTableSource(); + } - @Override - public TableSink createSink(Configuration conf) { - return new ElasticsearchTableSink(); - } + @Override + public TableSink createSink(Configuration conf) { + return new ElasticsearchTableSink(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/main/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableSink.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/main/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableSink.java index 4831a6c21..55f0cabbe 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/main/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableSink.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/main/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableSink.java @@ -19,12 +19,12 @@ package org.apache.geaflow.dsl.connector.elasticsearch; -import com.google.gson.Gson; import java.io.IOException; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.common.data.Row; @@ -47,168 +47,179 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.gson.Gson; + public class ElasticsearchTableSink implements TableSink { - private static final Logger LOGGER = LoggerFactory.getLogger(ElasticsearchTableSink.class); - private static final Gson GSON = new Gson(); - - private StructType schema; - private String hosts; - private String indexName; - private String documentIdField; - private String username; - private String password; - private int batchSize; - private int connectionTimeout; - private int socketTimeout; - - private RestHighLevelClient client; - private BulkRequest bulkRequest; - private int batchCounter = 0; - - @Override - public void init(Configuration conf, StructType schema) { - LOGGER.info("Prepare with config: {}, \n schema: {}", conf, schema); - this.schema = schema; - - this.hosts = conf.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_HOSTS); - this.indexName = conf.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_INDEX); - this.documentIdField = conf.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_DOCUMENT_ID_FIELD, ""); - this.username = conf.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_USERNAME, ""); - this.password = conf.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_PASSWORD, ""); - this.batchSize = conf.getInteger(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_BATCH_SIZE, - ElasticsearchConstants.DEFAULT_BATCH_SIZE); - this.connectionTimeout = conf.getInteger(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_CONNECTION_TIMEOUT, - ElasticsearchConstants.DEFAULT_CONNECTION_TIMEOUT); - this.socketTimeout = conf.getInteger(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_SOCKET_TIMEOUT, - ElasticsearchConstants.DEFAULT_SOCKET_TIMEOUT); + private static final Logger LOGGER = LoggerFactory.getLogger(ElasticsearchTableSink.class); + private static final Gson GSON = new Gson(); + + private StructType schema; + private String hosts; + private String indexName; + private String documentIdField; + private String username; + private String password; + private int batchSize; + private int connectionTimeout; + private int socketTimeout; + + private RestHighLevelClient client; + private BulkRequest bulkRequest; + private int batchCounter = 0; + + @Override + public void init(Configuration conf, StructType schema) { + LOGGER.info("Prepare with config: {}, \n schema: {}", conf, schema); + this.schema = schema; + + this.hosts = conf.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_HOSTS); + this.indexName = conf.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_INDEX); + this.documentIdField = + conf.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_DOCUMENT_ID_FIELD, ""); + this.username = conf.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_USERNAME, ""); + this.password = conf.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_PASSWORD, ""); + this.batchSize = + conf.getInteger( + ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_BATCH_SIZE, + ElasticsearchConstants.DEFAULT_BATCH_SIZE); + this.connectionTimeout = + conf.getInteger( + ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_CONNECTION_TIMEOUT, + ElasticsearchConstants.DEFAULT_CONNECTION_TIMEOUT); + this.socketTimeout = + conf.getInteger( + ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_SOCKET_TIMEOUT, + ElasticsearchConstants.DEFAULT_SOCKET_TIMEOUT); + } + + @Override + public void open(RuntimeContext context) { + try { + this.client = createElasticsearchClient(); + this.bulkRequest = new BulkRequest(); + } catch (Exception e) { + throw new GeaFlowDSLException("Failed to create Elasticsearch client", e); } - - @Override - public void open(RuntimeContext context) { - try { - this.client = createElasticsearchClient(); - this.bulkRequest = new BulkRequest(); - } catch (Exception e) { - throw new GeaFlowDSLException("Failed to create Elasticsearch client", e); + } + + @Override + public void write(Row row) throws IOException { + // Convert row to JSON document + String jsonDocument = rowToJson(row); + + // Create index request + IndexRequest request = new IndexRequest(indexName); + request.source(jsonDocument, XContentType.JSON); + + // Set document ID if specified + if (documentIdField != null && !documentIdField.isEmpty()) { + int idFieldIndex = schema.indexOf(documentIdField); + if (idFieldIndex >= 0) { + Object idValue = row.getField(idFieldIndex, schema.getType(idFieldIndex)); + if (idValue != null) { + request.id(idValue.toString()); } + } } - @Override - public void write(Row row) throws IOException { - // Convert row to JSON document - String jsonDocument = rowToJson(row); - - // Create index request - IndexRequest request = new IndexRequest(indexName); - request.source(jsonDocument, XContentType.JSON); - - // Set document ID if specified - if (documentIdField != null && !documentIdField.isEmpty()) { - int idFieldIndex = schema.indexOf(documentIdField); - if (idFieldIndex >= 0) { - Object idValue = row.getField(idFieldIndex, schema.getType(idFieldIndex)); - if (idValue != null) { - request.id(idValue.toString()); - } - } - } - - // Add to bulk request - bulkRequest.add(request); - batchCounter++; + // Add to bulk request + bulkRequest.add(request); + batchCounter++; - // Flush if batch size reached - if (batchCounter >= batchSize) { - flush(); - } + // Flush if batch size reached + if (batchCounter >= batchSize) { + flush(); } - - @Override - public void finish() throws IOException { - flush(); + } + + @Override + public void finish() throws IOException { + flush(); + } + + @Override + public void close() { + try { + if (Objects.nonNull(this.client)) { + client.close(); + } + } catch (IOException e) { + throw new GeaFlowDSLException("Failed to close Elasticsearch client", e); } - - @Override - public void close() { - try { - if (Objects.nonNull(this.client)) { - client.close(); - } - } catch (IOException e) { - throw new GeaFlowDSLException("Failed to close Elasticsearch client", e); - } + } + + private void flush() throws IOException { + if (batchCounter > 0 && client != null) { + BulkResponse bulkResponse = client.bulk(bulkRequest, RequestOptions.DEFAULT); + if (bulkResponse.hasFailures()) { + LOGGER.error("Bulk request failed: {}", bulkResponse.buildFailureMessage()); + throw new IOException("Bulk request failed: " + bulkResponse.buildFailureMessage()); + } + bulkRequest = new BulkRequest(); + batchCounter = 0; } + } - private void flush() throws IOException { - if (batchCounter > 0 && client != null) { - BulkResponse bulkResponse = client.bulk(bulkRequest, RequestOptions.DEFAULT); - if (bulkResponse.hasFailures()) { - LOGGER.error("Bulk request failed: {}", bulkResponse.buildFailureMessage()); - throw new IOException("Bulk request failed: " + bulkResponse.buildFailureMessage()); - } - bulkRequest = new BulkRequest(); - batchCounter = 0; - } + private String rowToJson(Row row) { + // Convert Row to JSON string + Map map = new HashMap<>(); + List fieldNames = schema.getFieldNames(); + + for (int i = 0; i < fieldNames.size(); i++) { + String fieldName = fieldNames.get(i); + Object fieldValue = row.getField(i, schema.getType(i)); + map.put(fieldName, fieldValue); } - private String rowToJson(Row row) { - // Convert Row to JSON string - Map map = new HashMap<>(); - List fieldNames = schema.getFieldNames(); + return GSON.toJson(map); + } - for (int i = 0; i < fieldNames.size(); i++) { - String fieldName = fieldNames.get(i); - Object fieldValue = row.getField(i, schema.getType(i)); - map.put(fieldName, fieldValue); - } + private RestHighLevelClient createElasticsearchClient() { + try { + String[] hostArray = hosts.split(","); + HttpHost[] httpHosts = new HttpHost[hostArray.length]; - return GSON.toJson(map); - } + for (int i = 0; i < hostArray.length; i++) { + String host = hostArray[i].trim(); + if (host.startsWith("http://")) { + host = host.substring(7); + } else if (host.startsWith("https://")) { + host = host.substring(8); + } - private RestHighLevelClient createElasticsearchClient() { - try { - String[] hostArray = hosts.split(","); - HttpHost[] httpHosts = new HttpHost[hostArray.length]; - - for (int i = 0; i < hostArray.length; i++) { - String host = hostArray[i].trim(); - if (host.startsWith("http://")) { - host = host.substring(7); - } else if (host.startsWith("https://")) { - host = host.substring(8); - } - - String[] parts = host.split(":"); - String hostname = parts[0]; - int port = parts.length > 1 ? Integer.parseInt(parts[1]) : 9200; - httpHosts[i] = new HttpHost(hostname, port, "http"); - } - - RestClientBuilder builder = RestClient.builder(httpHosts); - - // Configure timeouts - builder.setRequestConfigCallback(requestConfigBuilder -> { - requestConfigBuilder.setConnectTimeout(connectionTimeout); - requestConfigBuilder.setSocketTimeout(socketTimeout); - return requestConfigBuilder; + String[] parts = host.split(":"); + String hostname = parts[0]; + int port = parts.length > 1 ? Integer.parseInt(parts[1]) : 9200; + httpHosts[i] = new HttpHost(hostname, port, "http"); + } + + RestClientBuilder builder = RestClient.builder(httpHosts); + + // Configure timeouts + builder.setRequestConfigCallback( + requestConfigBuilder -> { + requestConfigBuilder.setConnectTimeout(connectionTimeout); + requestConfigBuilder.setSocketTimeout(socketTimeout); + return requestConfigBuilder; + }); + + // Configure authentication if provided + if (username != null && !username.isEmpty() && password != null) { + final CredentialsProvider credentialsProvider = new BasicCredentialsProvider(); + credentialsProvider.setCredentials( + AuthScope.ANY, new UsernamePasswordCredentials(username, password)); + + builder.setHttpClientConfigCallback( + httpClientBuilder -> { + httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider); + return httpClientBuilder; }); + } - // Configure authentication if provided - if (username != null && !username.isEmpty() && password != null) { - final CredentialsProvider credentialsProvider = new BasicCredentialsProvider(); - credentialsProvider.setCredentials(AuthScope.ANY, - new UsernamePasswordCredentials(username, password)); - - builder.setHttpClientConfigCallback(httpClientBuilder -> { - httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider); - return httpClientBuilder; - }); - } - - return new RestHighLevelClient(builder); - } catch (Exception e) { - throw new GeaFlowDSLException("Failed to create Elasticsearch client", e); - } + return new RestHighLevelClient(builder); + } catch (Exception e) { + throw new GeaFlowDSLException("Failed to create Elasticsearch client", e); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/main/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableSource.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/main/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableSource.java index 6fed8c37c..09871fc73 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/main/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableSource.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/main/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableSource.java @@ -26,8 +26,6 @@ import static org.apache.geaflow.dsl.connector.elasticsearch.ElasticsearchConstants.ES_SPLIT_COLON; import static org.apache.geaflow.dsl.connector.elasticsearch.ElasticsearchConstants.ES_SPLIT_COMMA; -import com.google.gson.Gson; -import com.google.gson.reflect.TypeToken; import java.io.IOException; import java.lang.reflect.Type; import java.util.ArrayList; @@ -35,6 +33,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.common.data.Row; @@ -63,208 +62,221 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class ElasticsearchTableSource implements TableSource { - - private static final Gson GSON = new Gson(); - private static final Type MAP_TYPE = new TypeToken>(){}.getType(); - - private Logger logger = LoggerFactory.getLogger(ElasticsearchTableSource.class); - - private StructType schema; - private String hosts; - private String indexName; - private String username; - private String password; - private String scrollTimeout; - private int connectionTimeout; - private int socketTimeout; +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; - private RestHighLevelClient client; +public class ElasticsearchTableSource implements TableSource { - @Override - public void init(Configuration tableConf, TableSchema tableSchema) { - this.schema = tableSchema; - this.hosts = tableConf.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_HOSTS); - this.indexName = tableConf.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_INDEX); - this.username = tableConf.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_USERNAME, ""); - this.password = tableConf.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_PASSWORD, ""); - this.scrollTimeout = tableConf.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_SCROLL_TIMEOUT, - ElasticsearchConstants.DEFAULT_SCROLL_TIMEOUT); - this.connectionTimeout = tableConf.getInteger(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_CONNECTION_TIMEOUT, - ElasticsearchConstants.DEFAULT_CONNECTION_TIMEOUT); - this.socketTimeout = tableConf.getInteger(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_SOCKET_TIMEOUT, - ElasticsearchConstants.DEFAULT_SOCKET_TIMEOUT); + private static final Gson GSON = new Gson(); + private static final Type MAP_TYPE = new TypeToken>() {}.getType(); + + private Logger logger = LoggerFactory.getLogger(ElasticsearchTableSource.class); + + private StructType schema; + private String hosts; + private String indexName; + private String username; + private String password; + private String scrollTimeout; + private int connectionTimeout; + private int socketTimeout; + + private RestHighLevelClient client; + + @Override + public void init(Configuration tableConf, TableSchema tableSchema) { + this.schema = tableSchema; + this.hosts = tableConf.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_HOSTS); + this.indexName = tableConf.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_INDEX); + this.username = + tableConf.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_USERNAME, ""); + this.password = + tableConf.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_PASSWORD, ""); + this.scrollTimeout = + tableConf.getString( + ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_SCROLL_TIMEOUT, + ElasticsearchConstants.DEFAULT_SCROLL_TIMEOUT); + this.connectionTimeout = + tableConf.getInteger( + ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_CONNECTION_TIMEOUT, + ElasticsearchConstants.DEFAULT_CONNECTION_TIMEOUT); + this.socketTimeout = + tableConf.getInteger( + ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_SOCKET_TIMEOUT, + ElasticsearchConstants.DEFAULT_SOCKET_TIMEOUT); + } + + @Override + public void open(RuntimeContext context) { + try { + this.client = createElasticsearchClient(); + } catch (Exception e) { + throw new GeaFlowDSLException("Failed to initialize Elasticsearch client", e); } - - @Override - public void open(RuntimeContext context) { - try { - this.client = createElasticsearchClient(); - } catch (Exception e) { - throw new GeaFlowDSLException("Failed to initialize Elasticsearch client", e); + } + + @Override + public List listPartitions() { + return Collections.singletonList(new ElasticsearchPartition(indexName)); + } + + @Override + public TableDeserializer getDeserializer(Configuration conf) { + return new TableDeserializer() { + @Override + public void init(Configuration configuration, StructType structType) { + // Initialization if needed + } + + @Override + public List deserialize(IN record) { + if (record instanceof SearchHit) { + SearchHit hit = (SearchHit) record; + Map source = hit.getSourceAsMap(); + if (source == null) { + source = GSON.fromJson(hit.getSourceAsString(), MAP_TYPE); + } + + // Convert map to Row based on schema + Object[] values = new Object[schema.size()]; + for (int i = 0; i < schema.size(); i++) { + String fieldName = schema.getFields().get(i).getName(); + values[i] = source.get(fieldName); + } + Row row = ObjectRow.create(values); + return Collections.singletonList(row); } + return Collections.emptyList(); + } + }; + } + + @Override + public FetchData fetch( + Partition partition, Optional startOffset, FetchWindow windowInfo) + throws IOException { + try { + SearchRequest searchRequest = new SearchRequest(indexName); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.size(DEFAULT_SEARCH_SIZE); // Batch size + + searchRequest.source(searchSourceBuilder); + + // Use scroll for large dataset reading + Scroll scroll = new Scroll(TimeValue.parseTimeValue(scrollTimeout, "scroll_timeout")); + searchRequest.scroll(scroll); + + SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT); + String scrollId = searchResponse.getScrollId(); + SearchHit[] searchHits = searchResponse.getHits().getHits(); + + List dataList = new ArrayList<>(); + for (SearchHit hit : searchHits) { + dataList.add((T) hit); + } + + // Clear scroll + ClearScrollRequest clearScrollRequest = new ClearScrollRequest(); + clearScrollRequest.addScrollId(scrollId); + client.clearScroll(clearScrollRequest, RequestOptions.DEFAULT); + + ElasticsearchOffset nextOffset = new ElasticsearchOffset(scrollId); + return (FetchData) FetchData.createStreamFetch(dataList, nextOffset, false); + } catch (Exception e) { + throw new IOException("Failed to fetch data from Elasticsearch", e); } - - @Override - public List listPartitions() { - return Collections.singletonList(new ElasticsearchPartition(indexName)); + } + + @Override + public void close() { + try { + if (client != null) { + client.close(); + } + } catch (IOException e) { + // Log error but don't throw exception in close method + logger.warn("Failed to close Elasticsearch client", e); } + } + + private RestHighLevelClient createElasticsearchClient() { + try { + String[] hostArray = hosts.split(ES_SPLIT_COMMA); + HttpHost[] httpHosts = new HttpHost[hostArray.length]; + + for (int i = 0; i < hostArray.length; i++) { + String host = hostArray[i].trim(); + if (host.startsWith(ES_HTTP_SCHEME + ES_SCHEMA_SUFFIX)) { + host = host.substring(7); + } else if (host.startsWith(ES_HTTPS_SCHEME + ES_SCHEMA_SUFFIX)) { + host = host.substring(8); + } - @Override - public TableDeserializer getDeserializer(Configuration conf) { - return new TableDeserializer() { - @Override - public void init(Configuration configuration, StructType structType) { - // Initialization if needed - } - - @Override - public List deserialize(IN record) { - if (record instanceof SearchHit) { - SearchHit hit = (SearchHit) record; - Map source = hit.getSourceAsMap(); - if (source == null) { - source = GSON.fromJson(hit.getSourceAsString(), MAP_TYPE); - } - - // Convert map to Row based on schema - Object[] values = new Object[schema.size()]; - for (int i = 0; i < schema.size(); i++) { - String fieldName = schema.getFields().get(i).getName(); - values[i] = source.get(fieldName); - } - Row row = ObjectRow.create(values); - return Collections.singletonList(row); - } - return Collections.emptyList(); - } - }; + String[] parts = host.split(ES_SPLIT_COLON); + String hostname = parts[0]; + int port = parts.length > 1 ? Integer.parseInt(parts[1]) : 9200; + httpHosts[i] = new HttpHost(hostname, port, ES_HTTP_SCHEME); + } + + RestClientBuilder builder = RestClient.builder(httpHosts); + + // Configure timeouts + builder.setRequestConfigCallback( + requestConfigBuilder -> { + requestConfigBuilder.setConnectTimeout(connectionTimeout); + requestConfigBuilder.setSocketTimeout(socketTimeout); + return requestConfigBuilder; + }); + + return new RestHighLevelClient(builder); + } catch (Exception e) { + throw new GeaFlowDSLException("Failed to create Elasticsearch client", e); } + } - @Override - public FetchData fetch(Partition partition, Optional startOffset, - FetchWindow windowInfo) throws IOException { - try { - SearchRequest searchRequest = new SearchRequest(indexName); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.size(DEFAULT_SEARCH_SIZE); // Batch size - - searchRequest.source(searchSourceBuilder); - - // Use scroll for large dataset reading - Scroll scroll = new Scroll(TimeValue.parseTimeValue(scrollTimeout, "scroll_timeout")); - searchRequest.scroll(scroll); - - SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT); - String scrollId = searchResponse.getScrollId(); - SearchHit[] searchHits = searchResponse.getHits().getHits(); - - List dataList = new ArrayList<>(); - for (SearchHit hit : searchHits) { - dataList.add((T) hit); - } - - // Clear scroll - ClearScrollRequest clearScrollRequest = new ClearScrollRequest(); - clearScrollRequest.addScrollId(scrollId); - client.clearScroll(clearScrollRequest, RequestOptions.DEFAULT); - - ElasticsearchOffset nextOffset = new ElasticsearchOffset(scrollId); - return (FetchData) FetchData.createStreamFetch(dataList, nextOffset, false); - } catch (Exception e) { - throw new IOException("Failed to fetch data from Elasticsearch", e); - } - } + public static class ElasticsearchPartition implements Partition { + private final String indexName; - @Override - public void close() { - try { - if (client != null) { - client.close(); - } - } catch (IOException e) { - // Log error but don't throw exception in close method - logger.warn("Failed to close Elasticsearch client", e); - } + public ElasticsearchPartition(String indexName) { + this.indexName = indexName; } - private RestHighLevelClient createElasticsearchClient() { - try { - String[] hostArray = hosts.split(ES_SPLIT_COMMA); - HttpHost[] httpHosts = new HttpHost[hostArray.length]; - - for (int i = 0; i < hostArray.length; i++) { - String host = hostArray[i].trim(); - if (host.startsWith(ES_HTTP_SCHEME + ES_SCHEMA_SUFFIX)) { - host = host.substring(7); - } else if (host.startsWith(ES_HTTPS_SCHEME + ES_SCHEMA_SUFFIX)) { - host = host.substring(8); - } - - String[] parts = host.split(ES_SPLIT_COLON); - String hostname = parts[0]; - int port = parts.length > 1 ? Integer.parseInt(parts[1]) : 9200; - httpHosts[i] = new HttpHost(hostname, port, ES_HTTP_SCHEME); - } - - RestClientBuilder builder = RestClient.builder(httpHosts); - - // Configure timeouts - builder.setRequestConfigCallback(requestConfigBuilder -> { - requestConfigBuilder.setConnectTimeout(connectionTimeout); - requestConfigBuilder.setSocketTimeout(socketTimeout); - return requestConfigBuilder; - }); - - return new RestHighLevelClient(builder); - } catch (Exception e) { - throw new GeaFlowDSLException("Failed to create Elasticsearch client", e); - } + @Override + public String getName() { + return indexName; } + } - public static class ElasticsearchPartition implements Partition { - private final String indexName; + public static class ElasticsearchOffset implements Offset { + private final String scrollId; + private final long timestamp; - public ElasticsearchPartition(String indexName) { - this.indexName = indexName; - } - - @Override - public String getName() { - return indexName; - } + public ElasticsearchOffset(String scrollId) { + this(scrollId, System.currentTimeMillis()); } - public static class ElasticsearchOffset implements Offset { - private final String scrollId; - private final long timestamp; - - public ElasticsearchOffset(String scrollId) { - this(scrollId, System.currentTimeMillis()); - } - - public ElasticsearchOffset(String scrollId, long timestamp) { - this.scrollId = scrollId; - this.timestamp = timestamp; - } + public ElasticsearchOffset(String scrollId, long timestamp) { + this.scrollId = scrollId; + this.timestamp = timestamp; + } - public String getScrollId() { - return scrollId; - } + public String getScrollId() { + return scrollId; + } - @Override - public String humanReadable() { - return "ElasticsearchOffset{scrollId='" + scrollId + "', timestamp=" + timestamp + "}"; - } + @Override + public String humanReadable() { + return "ElasticsearchOffset{scrollId='" + scrollId + "', timestamp=" + timestamp + "}"; + } - @Override - public long getOffset() { - return timestamp; - } + @Override + public long getOffset() { + return timestamp; + } - @Override - public boolean isTimestamp() { - return true; - } + @Override + public boolean isTimestamp() { + return true; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/test/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchConfigKeysTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/test/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchConfigKeysTest.java index 0eaf9cf91..dfee780b9 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/test/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchConfigKeysTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/test/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchConfigKeysTest.java @@ -25,43 +25,52 @@ public class ElasticsearchConfigKeysTest { - @Test - public void testConfigKeys() { - Configuration config = new Configuration(); + @Test + public void testConfigKeys() { + Configuration config = new Configuration(); - config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_HOSTS, "localhost:9200"); - config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_INDEX, "test_index"); - config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_USERNAME, "elastic"); + config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_HOSTS, "localhost:9200"); + config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_INDEX, "test_index"); + config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_USERNAME, "elastic"); - Assert.assertEquals(config.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_HOSTS), - (Object) "localhost:9200"); - Assert.assertEquals(config.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_INDEX), - (Object) "test_index"); - Assert.assertEquals(config.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_USERNAME), - (Object) "elastic"); - } + Assert.assertEquals( + config.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_HOSTS), + (Object) "localhost:9200"); + Assert.assertEquals( + config.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_INDEX), + (Object) "test_index"); + Assert.assertEquals( + config.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_USERNAME), + (Object) "elastic"); + } - @Test - public void testDefaultValues() { - Configuration config = new Configuration(); + @Test + public void testDefaultValues() { + Configuration config = new Configuration(); - String batchSize = config.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_BATCH_SIZE); - String scrollTimeout = config.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_SCROLL_TIMEOUT); + String batchSize = + config.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_BATCH_SIZE); + String scrollTimeout = + config.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_SCROLL_TIMEOUT); - Assert.assertEquals(batchSize, (Object) String.valueOf(ElasticsearchConstants.DEFAULT_BATCH_SIZE)); - Assert.assertEquals(scrollTimeout, (Object) ElasticsearchConstants.DEFAULT_SCROLL_TIMEOUT); - } + Assert.assertEquals( + batchSize, (Object) String.valueOf(ElasticsearchConstants.DEFAULT_BATCH_SIZE)); + Assert.assertEquals(scrollTimeout, (Object) ElasticsearchConstants.DEFAULT_SCROLL_TIMEOUT); + } - @Test - public void testTimeoutValues() { - Configuration config = new Configuration(); + @Test + public void testTimeoutValues() { + Configuration config = new Configuration(); - String connectionTimeout = config.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_CONNECTION_TIMEOUT); - String socketTimeout = config.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_SOCKET_TIMEOUT); + String connectionTimeout = + config.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_CONNECTION_TIMEOUT); + String socketTimeout = + config.getString(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_SOCKET_TIMEOUT); - Assert.assertEquals(connectionTimeout, - (Object) String.valueOf(ElasticsearchConstants.DEFAULT_CONNECTION_TIMEOUT)); - Assert.assertEquals(socketTimeout, - (Object) String.valueOf(ElasticsearchConstants.DEFAULT_SOCKET_TIMEOUT)); - } + Assert.assertEquals( + connectionTimeout, + (Object) String.valueOf(ElasticsearchConstants.DEFAULT_CONNECTION_TIMEOUT)); + Assert.assertEquals( + socketTimeout, (Object) String.valueOf(ElasticsearchConstants.DEFAULT_SOCKET_TIMEOUT)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/test/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableConnectorTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/test/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableConnectorTest.java index a5f594b13..eea0982b1 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/test/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableConnectorTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/test/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableConnectorTest.java @@ -20,7 +20,7 @@ package org.apache.geaflow.dsl.connector.elasticsearch; import java.util.Arrays; -import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.types.StructType; @@ -34,59 +34,59 @@ public class ElasticsearchTableConnectorTest { - private ElasticsearchTableConnector connector; - private Configuration config; - private TableSchema schema; + private ElasticsearchTableConnector connector; + private Configuration config; + private TableSchema schema; - @BeforeMethod - public void setUp() { - connector = new ElasticsearchTableConnector(); - config = new Configuration(); - config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_HOSTS, "localhost:9200"); - config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_INDEX, "test_index"); - config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_DOCUMENT_ID_FIELD, "id"); + @BeforeMethod + public void setUp() { + connector = new ElasticsearchTableConnector(); + config = new Configuration(); + config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_HOSTS, "localhost:9200"); + config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_INDEX, "test_index"); + config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_DOCUMENT_ID_FIELD, "id"); - TableField idField = new TableField("id", Types.INTEGER, false); - TableField nameField = new TableField("name", Types.STRING, false); - schema = new TableSchema(new StructType(Arrays.asList(idField, nameField))); - } + TableField idField = new TableField("id", Types.INTEGER, false); + TableField nameField = new TableField("name", Types.STRING, false); + schema = new TableSchema(new StructType(Arrays.asList(idField, nameField))); + } - @Test - public void testGetName() { - Assert.assertEquals(connector.getType(), "ELASTICSEARCH"); - } + @Test + public void testGetName() { + Assert.assertEquals(connector.getType(), "ELASTICSEARCH"); + } - @Test - public void testGetSource() { - TableSource source = connector.createSource(config); - Assert.assertNotNull(source); - Assert.assertTrue(source instanceof ElasticsearchTableSource); - } + @Test + public void testGetSource() { + TableSource source = connector.createSource(config); + Assert.assertNotNull(source); + Assert.assertTrue(source instanceof ElasticsearchTableSource); + } - @Test - public void testGetSink() { - TableSink sink = connector.createSink(config); - Assert.assertNotNull(sink); - Assert.assertTrue(sink instanceof ElasticsearchTableSink); - } + @Test + public void testGetSink() { + TableSink sink = connector.createSink(config); + Assert.assertNotNull(sink); + Assert.assertTrue(sink instanceof ElasticsearchTableSink); + } - @Test - public void testMultipleSourceInstances() { - TableSource source1 = connector.createSource(config); - TableSource source2 = connector.createSource(config); + @Test + public void testMultipleSourceInstances() { + TableSource source1 = connector.createSource(config); + TableSource source2 = connector.createSource(config); - Assert.assertNotNull(source1); - Assert.assertNotNull(source2); - Assert.assertNotSame(source1, source2); - } + Assert.assertNotNull(source1); + Assert.assertNotNull(source2); + Assert.assertNotSame(source1, source2); + } - @Test - public void testMultipleSinkInstances() { - TableSink sink1 = connector.createSink(config); - TableSink sink2 = connector.createSink(config); + @Test + public void testMultipleSinkInstances() { + TableSink sink1 = connector.createSink(config); + TableSink sink2 = connector.createSink(config); - Assert.assertNotNull(sink1); - Assert.assertNotNull(sink2); - Assert.assertNotSame(sink1, sink2); - } + Assert.assertNotNull(sink1); + Assert.assertNotNull(sink2); + Assert.assertNotSame(sink1, sink2); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/test/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableSinkTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/test/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableSinkTest.java index 942839522..0721ecd12 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/test/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableSinkTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/test/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableSinkTest.java @@ -20,7 +20,7 @@ package org.apache.geaflow.dsl.connector.elasticsearch; import java.util.Arrays; -import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.data.Row; @@ -33,68 +33,68 @@ public class ElasticsearchTableSinkTest { - private ElasticsearchTableSink sink; - private Configuration config; - private StructType schema; + private ElasticsearchTableSink sink; + private Configuration config; + private StructType schema; - @BeforeMethod - public void setUp() { - sink = new ElasticsearchTableSink(); - config = new Configuration(); - config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_HOSTS, "localhost:9200"); - config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_INDEX, "test_index"); - config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_DOCUMENT_ID_FIELD, "id"); + @BeforeMethod + public void setUp() { + sink = new ElasticsearchTableSink(); + config = new Configuration(); + config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_HOSTS, "localhost:9200"); + config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_INDEX, "test_index"); + config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_DOCUMENT_ID_FIELD, "id"); - TableField idField = new TableField("id", Types.INTEGER, false); - TableField nameField = new TableField("name", Types.STRING, false); - TableField ageField = new TableField("age", Types.INTEGER, false); - schema = new StructType(Arrays.asList(idField, nameField, ageField)); - } + TableField idField = new TableField("id", Types.INTEGER, false); + TableField nameField = new TableField("name", Types.STRING, false); + TableField ageField = new TableField("age", Types.INTEGER, false); + schema = new StructType(Arrays.asList(idField, nameField, ageField)); + } - @Test - public void testInit() { - sink.init(config, schema); - Assert.assertNotNull(sink); - } + @Test + public void testInit() { + sink.init(config, schema); + Assert.assertNotNull(sink); + } - @Test(expectedExceptions = RuntimeException.class) - public void testInitWithoutIndex() { - Configuration invalidConfig = new Configuration(); - invalidConfig.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_HOSTS, "localhost:9200"); - sink.init(invalidConfig, schema); - } + @Test(expectedExceptions = RuntimeException.class) + public void testInitWithoutIndex() { + Configuration invalidConfig = new Configuration(); + invalidConfig.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_HOSTS, "localhost:9200"); + sink.init(invalidConfig, schema); + } - @Test - public void testInitWithoutIdField() { - Configuration invalidConfig = new Configuration(); - invalidConfig.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_HOSTS, "localhost:9200"); - invalidConfig.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_INDEX, "test_index"); - sink.init(invalidConfig, schema); - Assert.assertNotNull(sink); - } + @Test + public void testInitWithoutIdField() { + Configuration invalidConfig = new Configuration(); + invalidConfig.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_HOSTS, "localhost:9200"); + invalidConfig.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_INDEX, "test_index"); + sink.init(invalidConfig, schema); + Assert.assertNotNull(sink); + } - @Test - public void testBatchSizeConfiguration() { - config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_BATCH_SIZE, "500"); - sink.init(config, schema); - Assert.assertNotNull(sink); - } + @Test + public void testBatchSizeConfiguration() { + config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_BATCH_SIZE, "500"); + sink.init(config, schema); + Assert.assertNotNull(sink); + } - @Test - public void testWriteRow() { - sink.init(config, schema); + @Test + public void testWriteRow() { + sink.init(config, schema); - Row row = ObjectRow.create(1, "Alice", 25); - Assert.assertNotNull(row); - } + Row row = ObjectRow.create(1, "Alice", 25); + Assert.assertNotNull(row); + } - @Test - public void testMultipleWrites() { - sink.init(config, schema); + @Test + public void testMultipleWrites() { + sink.init(config, schema); - for (int i = 0; i < 10; i++) { - Row row = ObjectRow.create(i, "User" + i, 20 + i); - Assert.assertNotNull(row); - } + for (int i = 0; i < 10; i++) { + Row row = ObjectRow.create(i, "User" + i, 20 + i); + Assert.assertNotNull(row); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/test/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableSourceTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/test/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableSourceTest.java index 1a0a00443..a35b9298e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/test/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableSourceTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-elasticsearch/src/test/java/org/apache/geaflow/dsl/connector/elasticsearch/ElasticsearchTableSourceTest.java @@ -21,78 +21,75 @@ import java.util.Arrays; import java.util.List; -import java.util.Optional; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.type.Types; -import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.types.StructType; import org.apache.geaflow.dsl.common.types.TableField; import org.apache.geaflow.dsl.common.types.TableSchema; -import org.apache.geaflow.dsl.connector.api.FetchData; import org.apache.geaflow.dsl.connector.api.Partition; -import org.mockito.Mockito; import org.testng.Assert; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; public class ElasticsearchTableSourceTest { - private ElasticsearchTableSource source; - private Configuration config; - private TableSchema schema; + private ElasticsearchTableSource source; + private Configuration config; + private TableSchema schema; - @BeforeMethod - public void setUp() { - source = new ElasticsearchTableSource(); - config = new Configuration(); - config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_HOSTS, "localhost:9200"); - config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_INDEX, "test_index"); + @BeforeMethod + public void setUp() { + source = new ElasticsearchTableSource(); + config = new Configuration(); + config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_HOSTS, "localhost:9200"); + config.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_INDEX, "test_index"); - TableField idField = new TableField("id", Types.INTEGER, false); - TableField nameField = new TableField("name", Types.STRING, false); - schema = new TableSchema(new StructType(Arrays.asList(idField, nameField))); - } + TableField idField = new TableField("id", Types.INTEGER, false); + TableField nameField = new TableField("name", Types.STRING, false); + schema = new TableSchema(new StructType(Arrays.asList(idField, nameField))); + } - @Test - public void testInit() { - source.init(config, schema); - Assert.assertNotNull(source); - } + @Test + public void testInit() { + source.init(config, schema); + Assert.assertNotNull(source); + } - @Test(expectedExceptions = RuntimeException.class) - public void testInitWithoutIndex() { - Configuration invalidConfig = new Configuration(); - invalidConfig.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_HOSTS, "localhost:9200"); - source.init(invalidConfig, schema); - } + @Test(expectedExceptions = RuntimeException.class) + public void testInitWithoutIndex() { + Configuration invalidConfig = new Configuration(); + invalidConfig.put(ElasticsearchConfigKeys.GEAFLOW_DSL_ELASTICSEARCH_HOSTS, "localhost:9200"); + source.init(invalidConfig, schema); + } - @Test - public void testListPartitions() { - source.init(config, schema); - List partitions = source.listPartitions(); + @Test + public void testListPartitions() { + source.init(config, schema); + List partitions = source.listPartitions(); - Assert.assertNotNull(partitions); - Assert.assertEquals(partitions.size(), 1); - Assert.assertEquals(partitions.get(0).getName(), "test_index"); - } + Assert.assertNotNull(partitions); + Assert.assertEquals(partitions.size(), 1); + Assert.assertEquals(partitions.get(0).getName(), "test_index"); + } - @Test - public void testGetDeserializer() { - source.init(config, schema); - Assert.assertNotNull(source.getDeserializer(config)); - } + @Test + public void testGetDeserializer() { + source.init(config, schema); + Assert.assertNotNull(source.getDeserializer(config)); + } - @Test - public void testPartitionName() { - ElasticsearchTableSource.ElasticsearchPartition partition = - new ElasticsearchTableSource.ElasticsearchPartition("my_index"); - Assert.assertEquals(partition.getName(), "my_index"); - } + @Test + public void testPartitionName() { + ElasticsearchTableSource.ElasticsearchPartition partition = + new ElasticsearchTableSource.ElasticsearchPartition("my_index"); + Assert.assertEquals(partition.getName(), "my_index"); + } - @Test - public void testOffsetHumanReadable() { - ElasticsearchTableSource.ElasticsearchOffset offset = - new ElasticsearchTableSource.ElasticsearchOffset("scroll_123"); - Assert.assertTrue(offset.humanReadable().contains("scroll_123")); - } + @Test + public void testOffsetHumanReadable() { + ElasticsearchTableSource.ElasticsearchOffset offset = + new ElasticsearchTableSource.ElasticsearchOffset("scroll_123"); + Assert.assertTrue(offset.humanReadable().contains("scroll_123")); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/FileConnectorUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/FileConnectorUtil.java index dc50f7301..288e1c9fd 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/FileConnectorUtil.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/FileConnectorUtil.java @@ -21,9 +21,8 @@ import static org.apache.geaflow.dsl.connector.file.FileConstants.PREFIX_S3_RESOURCE; -import com.amazonaws.auth.AWSCredentials; -import com.amazonaws.auth.BasicAWSCredentials; import java.util.Map; + import org.apache.commons.lang.StringUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -34,76 +33,78 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.BasicAWSCredentials; + public class FileConnectorUtil { - private static final Logger LOGGER = LoggerFactory.getLogger(FileConnectorUtil.class); + private static final Logger LOGGER = LoggerFactory.getLogger(FileConnectorUtil.class); - private static final String HADOOP_HOME = "HADOOP_HOME"; + private static final String HADOOP_HOME = "HADOOP_HOME"; - public static String getPartitionFileName(int taskIndex) { - return "partition_" + taskIndex; - } + public static String getPartitionFileName(int taskIndex) { + return "partition_" + taskIndex; + } - public static FileSystem getHdfsFileSystem(Configuration conf) { - org.apache.hadoop.conf.Configuration hadoopConf = toHadoopConf(conf); - FileSystem fileSystem; - try { - fileSystem = FileSystem.newInstance(hadoopConf); - } catch (Exception e) { - throw new GeaflowRuntimeException("Cannot init hdfs file system.", e); - } - return fileSystem; + public static FileSystem getHdfsFileSystem(Configuration conf) { + org.apache.hadoop.conf.Configuration hadoopConf = toHadoopConf(conf); + FileSystem fileSystem; + try { + fileSystem = FileSystem.newInstance(hadoopConf); + } catch (Exception e) { + throw new GeaflowRuntimeException("Cannot init hdfs file system.", e); } + return fileSystem; + } - public static org.apache.hadoop.conf.Configuration toHadoopConf(Configuration conf) { - org.apache.hadoop.conf.Configuration hadoopConf = new org.apache.hadoop.conf.Configuration(); - String hadoopConfPath = System.getenv(HADOOP_HOME); - if (!StringUtils.isEmpty(hadoopConfPath)) { - LOGGER.info("find hadoop home at: {}", hadoopConfPath); - hadoopConf.addResource(new Path(hadoopConfPath + "/etc/hadoop/core-site.xml")); - hadoopConf.addResource(new Path(hadoopConfPath + "/etc/hadoop/hdfs-site.xml")); - } - if (conf.contains(FileConfigKeys.JSON_CONFIG)) { - String userConfigStr = conf.getString(FileConfigKeys.JSON_CONFIG); - Map userConfig = GsonUtil.parse(userConfigStr); - if (userConfig != null) { - for (Map.Entry entry : userConfig.entrySet()) { - hadoopConf.set(entry.getKey(), entry.getValue()); - } - } - } - if (conf.contains("fs.defaultFS")) { - hadoopConf.set("fs.defaultFS", conf.getString("fs.defaultFS")); + public static org.apache.hadoop.conf.Configuration toHadoopConf(Configuration conf) { + org.apache.hadoop.conf.Configuration hadoopConf = new org.apache.hadoop.conf.Configuration(); + String hadoopConfPath = System.getenv(HADOOP_HOME); + if (!StringUtils.isEmpty(hadoopConfPath)) { + LOGGER.info("find hadoop home at: {}", hadoopConfPath); + hadoopConf.addResource(new Path(hadoopConfPath + "/etc/hadoop/core-site.xml")); + hadoopConf.addResource(new Path(hadoopConfPath + "/etc/hadoop/hdfs-site.xml")); + } + if (conf.contains(FileConfigKeys.JSON_CONFIG)) { + String userConfigStr = conf.getString(FileConfigKeys.JSON_CONFIG); + Map userConfig = GsonUtil.parse(userConfigStr); + if (userConfig != null) { + for (Map.Entry entry : userConfig.entrySet()) { + hadoopConf.set(entry.getKey(), entry.getValue()); } - return hadoopConf; + } } - - public static AWSCredentials getS3Credentials(Configuration conf) { - String accessKey = conf.getString(FileConstants.S3_ACCESS_KEY); - String secretKey = conf.getString(FileConstants.S3_SECRET_KEY); - AWSCredentials credentials = new BasicAWSCredentials(accessKey, secretKey); - return credentials; + if (conf.contains("fs.defaultFS")) { + hadoopConf.set("fs.defaultFS", conf.getString("fs.defaultFS")); } + return hadoopConf; + } - public static String getS3ServiceEndpoint(Configuration conf) { - return conf.getString(FileConstants.S3_SERVICE_ENDPOINT); - } + public static AWSCredentials getS3Credentials(Configuration conf) { + String accessKey = conf.getString(FileConstants.S3_ACCESS_KEY); + String secretKey = conf.getString(FileConstants.S3_SECRET_KEY); + AWSCredentials credentials = new BasicAWSCredentials(accessKey, secretKey); + return credentials; + } - public static String[] getFileUri(String path) { - String uri = path.substring(PREFIX_S3_RESOURCE.length()); - return uri.split("/"); - } + public static String getS3ServiceEndpoint(Configuration conf) { + return conf.getString(FileConstants.S3_SERVICE_ENDPOINT); + } - public static String getBucket(String path) { - String[] paths = getFileUri(path); - return paths[0]; - } + public static String[] getFileUri(String path) { + String uri = path.substring(PREFIX_S3_RESOURCE.length()); + return uri.split("/"); + } - public static String getKey(String path) { - String[] paths = getFileUri(path); - String[] keys = new String[paths.length - 1]; - System.arraycopy(paths, 1, keys, 0, keys.length); - return String.join("/", keys); - } + public static String getBucket(String path) { + String[] paths = getFileUri(path); + return paths[0]; + } + public static String getKey(String path) { + String[] paths = getFileUri(path); + String[] keys = new String[paths.length - 1]; + System.arraycopy(paths, 1, keys, 0, keys.length); + return String.join("/", keys); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/FileConstants.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/FileConstants.java index 1ed86bd82..77c2d3328 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/FileConstants.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/FileConstants.java @@ -21,11 +21,10 @@ public class FileConstants { - public static final String PREFIX_JAVA_RESOURCE = "resource://"; - - public static final String PREFIX_S3_RESOURCE = "s3://"; - public static final String S3_ACCESS_KEY = "geaflow.store.s3.access.key"; - public static final String S3_SECRET_KEY = "geaflow.store.s3.secret.key"; - public static final String S3_SERVICE_ENDPOINT = "geaflow.store.s3.service.endpoint"; + public static final String PREFIX_JAVA_RESOURCE = "resource://"; + public static final String PREFIX_S3_RESOURCE = "s3://"; + public static final String S3_ACCESS_KEY = "geaflow.store.s3.access.key"; + public static final String S3_SECRET_KEY = "geaflow.store.s3.secret.key"; + public static final String S3_SERVICE_ENDPOINT = "geaflow.store.s3.service.endpoint"; } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/FileTableConnector.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/FileTableConnector.java index 4ec87ca72..9fc3a8a77 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/FileTableConnector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/FileTableConnector.java @@ -29,20 +29,20 @@ public class FileTableConnector implements TableReadableConnector, TableWritableConnector { - public static final String TYPE = "FILE"; + public static final String TYPE = "FILE"; - @Override - public String getType() { - return TYPE; - } + @Override + public String getType() { + return TYPE; + } - @Override - public TableSource createSource(Configuration conf) { - return new FileTableSource(); - } + @Override + public TableSource createSource(Configuration conf) { + return new FileTableSource(); + } - @Override - public TableSink createSink(Configuration conf) { - return new FileTableSink(); - } + @Override + public TableSink createSink(Configuration conf) { + return new FileTableSink(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/FileTableSink.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/FileTableSink.java index 9033b3407..c786da5da 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/FileTableSink.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/FileTableSink.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Objects; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ConnectorConfigKeys; @@ -33,67 +34,67 @@ public class FileTableSink implements TableSink { - private static final Logger LOGGER = LoggerFactory.getLogger(FileTableSink.class); + private static final Logger LOGGER = LoggerFactory.getLogger(FileTableSink.class); - private String path; + private String path; - private String separator; + private String separator; - private StructType schema; + private StructType schema; - private Configuration tableConf; + private Configuration tableConf; - protected transient FileWriteHandler writer; + protected transient FileWriteHandler writer; - @Override - public void init(Configuration tableConf, StructType schema) { - this.path = tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_FILE_PATH); - this.separator = tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_COLUMN_SEPARATOR); - this.schema = Objects.requireNonNull(schema); - this.tableConf = tableConf; - } + @Override + public void init(Configuration tableConf, StructType schema) { + this.path = tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_FILE_PATH); + this.separator = tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_COLUMN_SEPARATOR); + this.schema = Objects.requireNonNull(schema); + this.tableConf = tableConf; + } - @Override - public void open(RuntimeContext context) { - this.writer = FileWriteHandlers.from(path, tableConf); - this.writer.init(tableConf, schema, context.getTaskArgs().getTaskIndex()); - } + @Override + public void open(RuntimeContext context) { + this.writer = FileWriteHandlers.from(path, tableConf); + this.writer.init(tableConf, schema, context.getTaskArgs().getTaskIndex()); + } - @Override - public void write(Row row) throws IOException { - Object[] values = new Object[schema.size()]; - for (int i = 0; i < schema.size(); i++) { - values[i] = row.getField(i, schema.getType(i)); - } - - StringBuilder line = new StringBuilder(); - for (Object value : values) { - if (line.length() > 0) { - line.append(separator); - } - line.append(value); - } - writer.write(line + "\n"); + @Override + public void write(Row row) throws IOException { + Object[] values = new Object[schema.size()]; + for (int i = 0; i < schema.size(); i++) { + values[i] = row.getField(i, schema.getType(i)); } - @Override - public void finish() throws IOException { - String split = tableConf.getString(DSLConfigKeys.TABLE_SINK_SPLIT_LINE.getKey(), null); - if (split != null) { - writer.write(split + "\n"); - } - - writer.flush(); + StringBuilder line = new StringBuilder(); + for (Object value : values) { + if (line.length() > 0) { + line.append(separator); + } + line.append(value); + } + writer.write(line + "\n"); + } + + @Override + public void finish() throws IOException { + String split = tableConf.getString(DSLConfigKeys.TABLE_SINK_SPLIT_LINE.getKey(), null); + if (split != null) { + writer.write(split + "\n"); } - @Override - public void close() { - if (writer != null) { - try { - writer.close(); - } catch (IOException e) { - LOGGER.warn("Error in close writer", e); - } - } + writer.flush(); + } + + @Override + public void close() { + if (writer != null) { + try { + writer.close(); + } catch (IOException e) { + LOGGER.warn("Error in close writer", e); + } } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/FileWriteHandler.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/FileWriteHandler.java index 2c051ecc9..9faf8405f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/FileWriteHandler.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/FileWriteHandler.java @@ -20,16 +20,17 @@ package org.apache.geaflow.dsl.connector.file.sink; import java.io.IOException; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.common.types.StructType; public interface FileWriteHandler { - void init(Configuration tableConf, StructType schema, int taskIndex); + void init(Configuration tableConf, StructType schema, int taskIndex); - void write(String text) throws IOException; + void write(String text) throws IOException; - void flush() throws IOException; + void flush() throws IOException; - void close() throws IOException; + void close() throws IOException; } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/FileWriteHandlers.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/FileWriteHandlers.java index de37501b1..b91dcc41e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/FileWriteHandlers.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/FileWriteHandlers.java @@ -28,14 +28,14 @@ public class FileWriteHandlers { - public static FileWriteHandler from(String path, Configuration conf) { - if (path.startsWith(PREFIX_S3_RESOURCE)) { - return new S3FileWriteHandler(path); - } - FileSystem fs = FileConnectorUtil.getHdfsFileSystem(conf); - if (fs instanceof LocalFileSystem) { - return new LocalFileWriteHandler(path); - } - return new HdfsFileWriteHandler(path); + public static FileWriteHandler from(String path, Configuration conf) { + if (path.startsWith(PREFIX_S3_RESOURCE)) { + return new S3FileWriteHandler(path); } + FileSystem fs = FileConnectorUtil.getHdfsFileSystem(conf); + if (fs instanceof LocalFileSystem) { + return new LocalFileWriteHandler(path); + } + return new HdfsFileWriteHandler(path); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/HdfsFileWriteHandler.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/HdfsFileWriteHandler.java index 99fc41684..2f6bcdc4b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/HdfsFileWriteHandler.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/HdfsFileWriteHandler.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.connector.file.sink; import java.io.IOException; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; import org.apache.geaflow.dsl.common.types.StructType; @@ -32,64 +33,64 @@ public class HdfsFileWriteHandler implements FileWriteHandler { - private static final Logger LOGGER = LoggerFactory.getLogger(HdfsFileWriteHandler.class); + private static final Logger LOGGER = LoggerFactory.getLogger(HdfsFileWriteHandler.class); - protected Configuration conf; - protected StructType schema; - protected int taskIndex; - protected final String baseDir; + protected Configuration conf; + protected StructType schema; + protected int taskIndex; + protected final String baseDir; - protected transient FileSystem fileSystem; - protected transient FSDataOutputStream writer; + protected transient FileSystem fileSystem; + protected transient FSDataOutputStream writer; - public HdfsFileWriteHandler(String baseDir) { - this.baseDir = baseDir; - } + public HdfsFileWriteHandler(String baseDir) { + this.baseDir = baseDir; + } - @Override - public void init(Configuration tableConf, StructType schema, int taskIndex) { - this.conf = tableConf; - this.schema = schema; - this.taskIndex = taskIndex; - this.fileSystem = FileConnectorUtil.getHdfsFileSystem(conf); + @Override + public void init(Configuration tableConf, StructType schema, int taskIndex) { + this.conf = tableConf; + this.schema = schema; + this.taskIndex = taskIndex; + this.fileSystem = FileConnectorUtil.getHdfsFileSystem(conf); - Path dirPath = new Path(baseDir); - Path filePath = new Path(dirPath, FileConnectorUtil.getPartitionFileName(taskIndex)); - filePath = fileSystem.makeQualified(filePath); - try { - if (!fileSystem.exists(new Path(baseDir))) { - fileSystem.mkdirs(new Path(baseDir)); - LOGGER.info("mkdirs {}", baseDir); - } - if (fileSystem.exists(filePath)) { - String newPath = filePath + "_" + System.currentTimeMillis(); - this.writer = fileSystem.create(new Path(newPath)); - LOGGER.info("path {} exists, create new file path {}", filePath, newPath); - } else { - this.writer = fileSystem.create(filePath); - LOGGER.info("create file path {}", filePath); - } - } catch (IOException e) { - throw new GeaFlowDSLException("Error in create file: " + filePath, e); - } + Path dirPath = new Path(baseDir); + Path filePath = new Path(dirPath, FileConnectorUtil.getPartitionFileName(taskIndex)); + filePath = fileSystem.makeQualified(filePath); + try { + if (!fileSystem.exists(new Path(baseDir))) { + fileSystem.mkdirs(new Path(baseDir)); + LOGGER.info("mkdirs {}", baseDir); + } + if (fileSystem.exists(filePath)) { + String newPath = filePath + "_" + System.currentTimeMillis(); + this.writer = fileSystem.create(new Path(newPath)); + LOGGER.info("path {} exists, create new file path {}", filePath, newPath); + } else { + this.writer = fileSystem.create(filePath); + LOGGER.info("create file path {}", filePath); + } + } catch (IOException e) { + throw new GeaFlowDSLException("Error in create file: " + filePath, e); } + } - @Override - public void write(String text) throws IOException { - this.writer.write(text.getBytes()); - } + @Override + public void write(String text) throws IOException { + this.writer.write(text.getBytes()); + } - @Override - public void flush() throws IOException { - this.writer.flush(); - this.writer.hflush(); - } + @Override + public void flush() throws IOException { + this.writer.flush(); + this.writer.hflush(); + } - @Override - public void close() throws IOException { - if (this.writer != null) { - flush(); - this.writer.close(); - } + @Override + public void close() throws IOException { + if (this.writer != null) { + flush(); + this.writer.close(); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/LocalFileWriteHandler.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/LocalFileWriteHandler.java index 82248712c..46e07a547 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/LocalFileWriteHandler.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/LocalFileWriteHandler.java @@ -24,6 +24,7 @@ import java.io.FileWriter; import java.io.IOException; import java.io.Writer; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ConnectorConfigKeys; import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; @@ -34,64 +35,66 @@ public class LocalFileWriteHandler implements FileWriteHandler { - private static final Logger LOGGER = LoggerFactory.getLogger(LocalFileWriteHandler.class); + private static final Logger LOGGER = LoggerFactory.getLogger(LocalFileWriteHandler.class); - private final String baseDir; + private final String baseDir; - protected String targetFile; + protected String targetFile; - private Writer writer; + private Writer writer; - public LocalFileWriteHandler(String baseDir) { - this.baseDir = baseDir; - } + public LocalFileWriteHandler(String baseDir) { + this.baseDir = baseDir; + } - @Override - public void init(Configuration tableConf, StructType schema, int taskIndex) { - File dirPath = new File(baseDir); - File filePath = new File(dirPath, FileConnectorUtil.getPartitionFileName(taskIndex)); - try { - if (!dirPath.exists()) { - dirPath.mkdirs(); - } - if (filePath.exists()) { - if (Configuration.getString(ConnectorConfigKeys.GEAFLOW_DSL_SINK_FILE_COLLISION, - (String) ConnectorConfigKeys.GEAFLOW_DSL_SINK_FILE_COLLISION.getDefaultValue(), - tableConf.getConfigMap()).equals(ConnectorConfigKeys.GEAFLOW_DSL_SINK_FILE_COLLISION.getDefaultValue())) { - String newPath = filePath + "_" + System.currentTimeMillis(); - targetFile = newPath; - this.writer = new BufferedWriter(new FileWriter(newPath)); - LOGGER.info("path {} exists, create new file path {}", filePath, newPath); - } else { - filePath.delete(); - targetFile = filePath.getAbsolutePath(); - this.writer = new BufferedWriter(new FileWriter(filePath)); - LOGGER.info("path {} exists, replace it {}", filePath); - } - } else { - targetFile = filePath.getAbsolutePath(); - this.writer = new BufferedWriter(new FileWriter(filePath)); - LOGGER.info("create file path {}", filePath); - } - } catch (IOException e) { - throw new GeaFlowDSLException("Error in create file: " + filePath, e); + @Override + public void init(Configuration tableConf, StructType schema, int taskIndex) { + File dirPath = new File(baseDir); + File filePath = new File(dirPath, FileConnectorUtil.getPartitionFileName(taskIndex)); + try { + if (!dirPath.exists()) { + dirPath.mkdirs(); + } + if (filePath.exists()) { + if (Configuration.getString( + ConnectorConfigKeys.GEAFLOW_DSL_SINK_FILE_COLLISION, + (String) ConnectorConfigKeys.GEAFLOW_DSL_SINK_FILE_COLLISION.getDefaultValue(), + tableConf.getConfigMap()) + .equals(ConnectorConfigKeys.GEAFLOW_DSL_SINK_FILE_COLLISION.getDefaultValue())) { + String newPath = filePath + "_" + System.currentTimeMillis(); + targetFile = newPath; + this.writer = new BufferedWriter(new FileWriter(newPath)); + LOGGER.info("path {} exists, create new file path {}", filePath, newPath); + } else { + filePath.delete(); + targetFile = filePath.getAbsolutePath(); + this.writer = new BufferedWriter(new FileWriter(filePath)); + LOGGER.info("path {} exists, replace it {}", filePath); } + } else { + targetFile = filePath.getAbsolutePath(); + this.writer = new BufferedWriter(new FileWriter(filePath)); + LOGGER.info("create file path {}", filePath); + } + } catch (IOException e) { + throw new GeaFlowDSLException("Error in create file: " + filePath, e); } + } - @Override - public void write(String text) throws IOException { - writer.write(text); - } + @Override + public void write(String text) throws IOException { + writer.write(text); + } - @Override - public void flush() throws IOException { - writer.flush(); - } + @Override + public void flush() throws IOException { + writer.flush(); + } - @Override - public void close() throws IOException { - if (writer != null) { - writer.close(); - } + @Override + public void close() throws IOException { + if (writer != null) { + writer.close(); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/S3FileWriteHandler.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/S3FileWriteHandler.java index 0de1c5c6b..15efa5caf 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/S3FileWriteHandler.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/sink/S3FileWriteHandler.java @@ -19,69 +19,74 @@ package org.apache.geaflow.dsl.connector.file.sink; -import com.amazonaws.auth.AWSCredentials; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import java.io.File; import java.io.IOException; import java.util.UUID; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.common.types.StructType; import org.apache.geaflow.dsl.connector.file.FileConnectorUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.client.builder.AwsClientBuilder; +import com.amazonaws.services.s3.AmazonS3; +import com.amazonaws.services.s3.AmazonS3ClientBuilder; public class S3FileWriteHandler extends LocalFileWriteHandler { - private static final Logger LOGGER = LoggerFactory.getLogger(S3FileWriteHandler.class); + private static final Logger LOGGER = LoggerFactory.getLogger(S3FileWriteHandler.class); - private static final String TMP_PATH = "/tmp/"; + private static final String TMP_PATH = "/tmp/"; - protected String path; - protected AWSCredentials credentials; - protected String serviceEndpoint; + protected String path; + protected AWSCredentials credentials; + protected String serviceEndpoint; - protected AmazonS3 s3; + protected AmazonS3 s3; - public S3FileWriteHandler(String baseDir) { - super(TMP_PATH + UUID.randomUUID()); - path = baseDir; - } + public S3FileWriteHandler(String baseDir) { + super(TMP_PATH + UUID.randomUUID()); + path = baseDir; + } - @Override - public void init(Configuration tableConf, StructType schema, int taskIndex) { - super.init(tableConf, schema, taskIndex); - this.credentials = FileConnectorUtil.getS3Credentials(tableConf); - this.serviceEndpoint = FileConnectorUtil.getS3ServiceEndpoint(tableConf); - s3 = AmazonS3ClientBuilder.standard() - .withCredentials(new AWSCredentialsProvider() { - @Override - public AWSCredentials getCredentials() { + @Override + public void init(Configuration tableConf, StructType schema, int taskIndex) { + super.init(tableConf, schema, taskIndex); + this.credentials = FileConnectorUtil.getS3Credentials(tableConf); + this.serviceEndpoint = FileConnectorUtil.getS3ServiceEndpoint(tableConf); + s3 = + AmazonS3ClientBuilder.standard() + .withCredentials( + new AWSCredentialsProvider() { + @Override + public AWSCredentials getCredentials() { return credentials; - } + } - @Override - public void refresh() { - } - }) - .withEndpointConfiguration(new AwsClientBuilder.EndpointConfiguration(serviceEndpoint, null)) + @Override + public void refresh() {} + }) + .withEndpointConfiguration( + new AwsClientBuilder.EndpointConfiguration(serviceEndpoint, null)) .build(); + } - } - - @Override - public void flush() throws IOException { - super.flush(); - File file = new File(targetFile); - s3.putObject(FileConnectorUtil.getBucket(path), FileConnectorUtil.getKey(path) + "/" + file.getName(), file); - } + @Override + public void flush() throws IOException { + super.flush(); + File file = new File(targetFile); + s3.putObject( + FileConnectorUtil.getBucket(path), + FileConnectorUtil.getKey(path) + "/" + file.getName(), + file); + } - @Override - public void close() throws IOException { - super.close(); - s3.shutdown(); - } + @Override + public void close() throws IOException { + super.close(); + s3.shutdown(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/AbstractFileReadHandler.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/AbstractFileReadHandler.java index ae7b5caff..701743719 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/AbstractFileReadHandler.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/AbstractFileReadHandler.java @@ -24,6 +24,7 @@ import java.util.HashMap; import java.util.Iterator; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ConnectorConfigKeys; import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; @@ -41,112 +42,115 @@ public abstract class AbstractFileReadHandler implements FileReadHandler { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractFileReadHandler.class); + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractFileReadHandler.class); - protected String formatName; - protected Configuration tableConf; - private TableSchema tableSchema; - protected Path path; + protected String formatName; + protected Configuration tableConf; + private TableSchema tableSchema; + protected Path path; - protected Map fileFormats = new HashMap<>(); + protected Map fileFormats = new HashMap<>(); - @Override - public void init(Configuration tableConf, TableSchema tableSchema, String path) throws IOException { - this.formatName = tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_FILE_FORMAT); - this.tableConf = tableConf; - this.tableSchema = tableSchema; - this.path = new Path(path); - } + @Override + public void init(Configuration tableConf, TableSchema tableSchema, String path) + throws IOException { + this.formatName = tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_FILE_FORMAT); + this.tableConf = tableConf; + this.tableSchema = tableSchema; + this.path = new Path(path); + } - @Override - public FetchData readPartition(FileSplit split, FileOffset offset, int windowSize) throws IOException { - FileFormat format = getFileFormat(split, offset); - if (windowSize == Integer.MAX_VALUE) { // read all data from file - Iterator iterator = format.batchRead(); - return FetchData.createBatchFetch(iterator, new FileOffset(-1)); - } else { - if (!(format instanceof StreamFormat)) { - throw new GeaFlowDSLException("Format '{}' is not a stream format, the window size should be -1", - formatName); - } - StreamFormat streamFormat = (StreamFormat) format; - return streamFormat.streamRead(offset, windowSize); - } + @Override + public FetchData readPartition(FileSplit split, FileOffset offset, int windowSize) + throws IOException { + FileFormat format = getFileFormat(split, offset); + if (windowSize == Integer.MAX_VALUE) { // read all data from file + Iterator iterator = format.batchRead(); + return FetchData.createBatchFetch(iterator, new FileOffset(-1)); + } else { + if (!(format instanceof StreamFormat)) { + throw new GeaFlowDSLException( + "Format '{}' is not a stream format, the window size should be -1", formatName); + } + StreamFormat streamFormat = (StreamFormat) format; + return streamFormat.streamRead(offset, windowSize); } + } - private FileFormat getFileFormat(FileSplit split, FileOffset offset) throws IOException { - if (!fileFormats.containsKey(split)) { - FileFormat fileFormat = FileFormats.loadFileFormat(formatName); - fileFormat.init(tableConf, tableSchema, split); - // skip pre offset for stream read. - if (fileFormat instanceof StreamFormat) { - ((StreamFormat) fileFormat).skip(offset.getOffset()); - } - fileFormats.put(split, fileFormat); - } - return fileFormats.get(split); + private FileFormat getFileFormat(FileSplit split, FileOffset offset) throws IOException { + if (!fileFormats.containsKey(split)) { + FileFormat fileFormat = FileFormats.loadFileFormat(formatName); + fileFormat.init(tableConf, tableSchema, split); + // skip pre offset for stream read. + if (fileFormat instanceof StreamFormat) { + ((StreamFormat) fileFormat).skip(offset.getOffset()); + } + fileFormats.put(split, fileFormat); } + return fileFormats.get(split); + } - @Override - public void close() throws IOException { - for (FileFormat format : fileFormats.values()) { - format.close(); - } + @Override + public void close() throws IOException { + for (FileFormat format : fileFormats.values()) { + format.close(); } + } - @SuppressWarnings("unchecked") - @Override - public TableDeserializer getDeserializer() { - return (TableDeserializer) FileFormats.loadFileFormat(formatName).getDeserializer(); - } + @SuppressWarnings("unchecked") + @Override + public TableDeserializer getDeserializer() { + return (TableDeserializer) FileFormats.loadFileFormat(formatName).getDeserializer(); + } - public int findLineSplitSize(InputStream inputStream) { - try { - if (this.formatName.equalsIgnoreCase(SourceConstants.PARQUET)) { - return 1; - } else { - int lineSplitSize = 1; - int c; - while (true) { - c = inputStream.read(); - if (c == -1) { - break; - } else if (c == '\n') { - break; - } else if (c == '\r') { - int c2 = inputStream.read(); - if (c2 == '\n') { - lineSplitSize = 2; - break; - } else if (c2 == -1) { - break; - } - } - } - return lineSplitSize; + public int findLineSplitSize(InputStream inputStream) { + try { + if (this.formatName.equalsIgnoreCase(SourceConstants.PARQUET)) { + return 1; + } else { + int lineSplitSize = 1; + int c; + while (true) { + c = inputStream.read(); + if (c == -1) { + break; + } else if (c == '\n') { + break; + } else if (c == '\r') { + int c2 = inputStream.read(); + if (c2 == '\n') { + lineSplitSize = 2; + break; + } else if (c2 == -1) { + break; } - } catch (Exception e) { - LOGGER.error(e.getMessage(), e); - throw new RuntimeException(e); + } } + return lineSplitSize; + } + } catch (Exception e) { + LOGGER.error(e.getMessage(), e); + throw new RuntimeException(e); } + } - protected long findNextStartPos(long expectPos, long totalLength, InputStream inputStream) throws Exception { - if (expectPos >= totalLength) { - return totalLength; - } - inputStream.skip(expectPos); - byte[] buffer = new byte[1]; - int readSize = 0; - do { - readSize = inputStream.read(buffer); - if (readSize != -1) { - expectPos++; - } - } while (readSize != -1 && !"\n".equalsIgnoreCase(new String(buffer))); - if (expectPos >= totalLength) { - expectPos = totalLength; - } - return expectPos; + protected long findNextStartPos(long expectPos, long totalLength, InputStream inputStream) + throws Exception { + if (expectPos >= totalLength) { + return totalLength; + } + inputStream.skip(expectPos); + byte[] buffer = new byte[1]; + int readSize = 0; + do { + readSize = inputStream.read(buffer); + if (readSize != -1) { + expectPos++; + } + } while (readSize != -1 && !"\n".equalsIgnoreCase(new String(buffer))); + if (expectPos >= totalLength) { + expectPos = totalLength; } + return expectPos; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/DfsFileReadHandler.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/DfsFileReadHandler.java index c27dbe3c4..9dda9ae13 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/DfsFileReadHandler.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/DfsFileReadHandler.java @@ -24,6 +24,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; + import org.apache.avro.generic.GenericData; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ConnectorConfigKeys; @@ -47,107 +48,121 @@ public class DfsFileReadHandler extends AbstractFileReadHandler { - private static final Logger LOGGER = LoggerFactory.getLogger(DfsFileReadHandler.class); + private static final Logger LOGGER = LoggerFactory.getLogger(DfsFileReadHandler.class); - protected Configuration tableConf; + protected Configuration tableConf; - @Override - public void init(Configuration tableConf, TableSchema tableSchema, String path) throws IOException { - super.init(tableConf, tableSchema, path); - org.apache.hadoop.conf.Configuration hadoopConf = FileConnectorUtil.toHadoopConf(tableConf); - this.path = this.path.getFileSystem(hadoopConf).makeQualified(new Path(path)); - this.fileFormats = new HashMap<>(); - this.tableConf = tableConf; - LOGGER.info("init hdfs file system. path: {}", path); - } + @Override + public void init(Configuration tableConf, TableSchema tableSchema, String path) + throws IOException { + super.init(tableConf, tableSchema, path); + org.apache.hadoop.conf.Configuration hadoopConf = FileConnectorUtil.toHadoopConf(tableConf); + this.path = this.path.getFileSystem(hadoopConf).makeQualified(new Path(path)); + this.fileFormats = new HashMap<>(); + this.tableConf = tableConf; + LOGGER.info("init hdfs file system. path: {}", path); + } - @Override - public List listPartitions(int parallelism) { - try { - int lineSplitSize = findLineSplitSize(); - if (!tableConf.getBoolean(ConnectorConfigKeys.GEAFLOW_DSL_SOURCE_FILE_PARALLEL_MOD)) { - FileSystem fileSystem = FileConnectorUtil.getHdfsFileSystem(tableConf); - if (fileSystem.isDirectory(path)) { - RemoteIterator files = fileSystem.listFiles(path, true); - List partitions = new ArrayList<>(); - while (files.hasNext()) { - LocatedFileStatus f = files.next(); - String relativePath = f.getPath().getName(); - FileSplit partition = new FileSplit(path.toString(), relativePath, lineSplitSize); - partitions.add(partition); - } - return partitions; - } else { - return splitSingleFile(parallelism, lineSplitSize); - } - } else { - return splitSingleFile(parallelism, lineSplitSize); - } - } catch (Exception e) { - LOGGER.error(e.getMessage(), e); - throw new GeaflowRuntimeException("Cannot get partitions with path: " + path, e); + @Override + public List listPartitions(int parallelism) { + try { + int lineSplitSize = findLineSplitSize(); + if (!tableConf.getBoolean(ConnectorConfigKeys.GEAFLOW_DSL_SOURCE_FILE_PARALLEL_MOD)) { + FileSystem fileSystem = FileConnectorUtil.getHdfsFileSystem(tableConf); + if (fileSystem.isDirectory(path)) { + RemoteIterator files = fileSystem.listFiles(path, true); + List partitions = new ArrayList<>(); + while (files.hasNext()) { + LocatedFileStatus f = files.next(); + String relativePath = f.getPath().getName(); + FileSplit partition = new FileSplit(path.toString(), relativePath, lineSplitSize); + partitions.add(partition); + } + return partitions; + } else { + return splitSingleFile(parallelism, lineSplitSize); } + } else { + return splitSingleFile(parallelism, lineSplitSize); + } + } catch (Exception e) { + LOGGER.error(e.getMessage(), e); + throw new GeaflowRuntimeException("Cannot get partitions with path: " + path, e); } + } - public int findLineSplitSize() throws Exception { - if (tableConf.getInteger(ConnectorConfigKeys.GEAFLOW_DSL_FILE_LINE_SPLIT_SIZE) > 0) { - return tableConf.getInteger(ConnectorConfigKeys.GEAFLOW_DSL_FILE_LINE_SPLIT_SIZE); - } else if (this.formatName.equalsIgnoreCase(SourceConstants.PARQUET)) { - return 1; - } else { - FileSystem fs = FileConnectorUtil.getHdfsFileSystem(tableConf); - Path currentPath = path; - if (fs.isDirectory(path)) { - RemoteIterator files = fs.listFiles(path, true); - if (!files.hasNext()) { - return 1; - } - currentPath = files.next().getPath(); - } - FSDataInputStream inputStream = fs.open(currentPath); - int lineSplitSize = findLineSplitSize(inputStream); - inputStream.close(); - return lineSplitSize; + public int findLineSplitSize() throws Exception { + if (tableConf.getInteger(ConnectorConfigKeys.GEAFLOW_DSL_FILE_LINE_SPLIT_SIZE) > 0) { + return tableConf.getInteger(ConnectorConfigKeys.GEAFLOW_DSL_FILE_LINE_SPLIT_SIZE); + } else if (this.formatName.equalsIgnoreCase(SourceConstants.PARQUET)) { + return 1; + } else { + FileSystem fs = FileConnectorUtil.getHdfsFileSystem(tableConf); + Path currentPath = path; + if (fs.isDirectory(path)) { + RemoteIterator files = fs.listFiles(path, true); + if (!files.hasNext()) { + return 1; } + currentPath = files.next().getPath(); + } + FSDataInputStream inputStream = fs.open(currentPath); + int lineSplitSize = findLineSplitSize(inputStream); + inputStream.close(); + return lineSplitSize; } + } - public List splitSingleFile(int parallelism, int lineSplitSize) throws Exception { - String relativePath = path.getName(); - String directory = path.getParent().toUri().getPath(); - if (parallelism == 1 || tableConf.getBoolean(ConnectorConfigKeys.GEAFLOW_DSL_SOURCE_FILE_PARALLEL_MOD)) { - FileSplit partition = new FileSplit(directory, relativePath); - return Collections.singletonList(partition); - } else { - if (this.formatName.equalsIgnoreCase(SourceConstants.PARQUET)) { - ParquetInputFormat parquetInputFormat = new AvroParquetInputFormat<>(); - List inputSplits = parquetInputFormat.getSplits(new JobContextImpl(FileConnectorUtil.toHadoopConf(tableConf), new JobID())); - List fileSplits = new ArrayList<>(); - inputSplits.stream().forEach(split -> fileSplits.add(new FileSplit(directory, relativePath, lineSplitSize, - ((org.apache.hadoop.mapreduce.lib.input.FileSplit) split).getStart(), - ((org.apache.hadoop.mapreduce.lib.input.FileSplit) split).getLength()))); - return fileSplits; - } else { - FileSystem fs = FileConnectorUtil.getHdfsFileSystem(tableConf); - long fileLength = fs.getFileStatus(path).getLen(); - long splitSize = fileLength / parallelism; - if (splitSize == 0) { - splitSize = 1; - } - long startPos = 0; - long endPos = 0; - List fileSplits = new ArrayList<>(); - for (int i = 0; i < parallelism; i++) { - startPos = endPos; - FSDataInputStream inputStream = fs.open(path); - endPos = findNextStartPos(startPos + splitSize, fileLength, inputStream); - inputStream.close(); - fileSplits.add(new FileSplit(directory, relativePath, i, lineSplitSize, startPos, endPos - startPos)); - if (endPos >= fileLength) { - break; - } - } - return fileSplits; - } + public List splitSingleFile(int parallelism, int lineSplitSize) throws Exception { + String relativePath = path.getName(); + String directory = path.getParent().toUri().getPath(); + if (parallelism == 1 + || tableConf.getBoolean(ConnectorConfigKeys.GEAFLOW_DSL_SOURCE_FILE_PARALLEL_MOD)) { + FileSplit partition = new FileSplit(directory, relativePath); + return Collections.singletonList(partition); + } else { + if (this.formatName.equalsIgnoreCase(SourceConstants.PARQUET)) { + ParquetInputFormat parquetInputFormat = new AvroParquetInputFormat<>(); + List inputSplits = + parquetInputFormat.getSplits( + new JobContextImpl(FileConnectorUtil.toHadoopConf(tableConf), new JobID())); + List fileSplits = new ArrayList<>(); + inputSplits.stream() + .forEach( + split -> + fileSplits.add( + new FileSplit( + directory, + relativePath, + lineSplitSize, + ((org.apache.hadoop.mapreduce.lib.input.FileSplit) split).getStart(), + ((org.apache.hadoop.mapreduce.lib.input.FileSplit) split) + .getLength()))); + return fileSplits; + } else { + FileSystem fs = FileConnectorUtil.getHdfsFileSystem(tableConf); + long fileLength = fs.getFileStatus(path).getLen(); + long splitSize = fileLength / parallelism; + if (splitSize == 0) { + splitSize = 1; + } + long startPos = 0; + long endPos = 0; + List fileSplits = new ArrayList<>(); + for (int i = 0; i < parallelism; i++) { + startPos = endPos; + FSDataInputStream inputStream = fs.open(path); + endPos = findNextStartPos(startPos + splitSize, fileLength, inputStream); + inputStream.close(); + fileSplits.add( + new FileSplit( + directory, relativePath, i, lineSplitSize, startPos, endPos - startPos)); + if (endPos >= fileLength) { + break; + } } + return fileSplits; + } } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/FileReadHandler.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/FileReadHandler.java index 6502908fd..2f5718dce 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/FileReadHandler.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/FileReadHandler.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.io.Serializable; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.common.types.TableSchema; import org.apache.geaflow.dsl.connector.api.FetchData; @@ -32,13 +33,14 @@ public interface FileReadHandler extends Serializable { - void init(Configuration tableConf, TableSchema tableSchema, String path) throws IOException; + void init(Configuration tableConf, TableSchema tableSchema, String path) throws IOException; - List listPartitions(int parallelism); + List listPartitions(int parallelism); - FetchData readPartition(FileSplit split, FileOffset offset, int windowSize) throws IOException; + FetchData readPartition(FileSplit split, FileOffset offset, int windowSize) + throws IOException; - void close() throws IOException; + void close() throws IOException; - TableDeserializer getDeserializer(); + TableDeserializer getDeserializer(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/FileReadHandlers.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/FileReadHandlers.java index 77c1bc411..7478eb9e1 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/FileReadHandlers.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/FileReadHandlers.java @@ -24,13 +24,13 @@ public class FileReadHandlers { - public static FileReadHandler from(String path) { - if (path.startsWith(PREFIX_JAVA_RESOURCE)) { - return new JarFileReadHandler(); - } - if (path.startsWith(PREFIX_S3_RESOURCE)) { - return new S3FileReadHandler(); - } - return new DfsFileReadHandler(); + public static FileReadHandler from(String path) { + if (path.startsWith(PREFIX_JAVA_RESOURCE)) { + return new JarFileReadHandler(); } + if (path.startsWith(PREFIX_S3_RESOURCE)) { + return new S3FileReadHandler(); + } + return new DfsFileReadHandler(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/FileTableSource.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/FileTableSource.java index f6c9a5a34..0a51ae102 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/FileTableSource.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/FileTableSource.java @@ -26,6 +26,7 @@ import java.util.Objects; import java.util.Optional; import java.util.regex.Pattern; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.window.WindowType; @@ -48,265 +49,277 @@ public class FileTableSource implements TableSource { - private static final Logger LOGGER = LoggerFactory.getLogger(FileTableSource.class); + private static final Logger LOGGER = LoggerFactory.getLogger(FileTableSource.class); - private String path; + private String path; - private Configuration tableConf; + private Configuration tableConf; - private TableSchema tableSchema; + private TableSchema tableSchema; - private String nameFilterRegex; + private String nameFilterRegex; - private transient FileReadHandler fileReadHandler; + private transient FileReadHandler fileReadHandler; - @Override - public void init(Configuration tableConf, TableSchema tableSchema) { - this.path = tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_FILE_PATH); - this.tableConf = tableConf; - this.tableSchema = tableSchema; - this.nameFilterRegex = tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_FILE_NAME_REGEX); - LOGGER.info("init table source with tableConf: {}", tableConf); - } + @Override + public void init(Configuration tableConf, TableSchema tableSchema) { + this.path = tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_FILE_PATH); + this.tableConf = tableConf; + this.tableSchema = tableSchema; + this.nameFilterRegex = tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_FILE_NAME_REGEX); + LOGGER.info("init table source with tableConf: {}", tableConf); + } - @Override - public void open(RuntimeContext context) { - this.fileReadHandler = FileReadHandlers.from(path); - try { - this.fileReadHandler.init(tableConf, tableSchema, path); - } catch (IOException e) { - throw new GeaFlowDSLException("Error in open file source", e); - } - LOGGER.info("open table source on path: {}", path); + @Override + public void open(RuntimeContext context) { + this.fileReadHandler = FileReadHandlers.from(path); + try { + this.fileReadHandler.init(tableConf, tableSchema, path); + } catch (IOException e) { + throw new GeaFlowDSLException("Error in open file source", e); } - - @Override - public List listPartitions() { - return listPartitions(1); - } - - @Override - public List listPartitions(int parallelism) { - List allPartitions = fileReadHandler.listPartitions(parallelism); - if (StringUtils.isNotEmpty(this.nameFilterRegex)) { - List filterPartitions = new ArrayList<>(); - for (Partition partition : allPartitions) { - if (!partition.getName().startsWith(".") - && Pattern.matches(this.nameFilterRegex, partition.getName())) { - filterPartitions.add(partition); - } - } - return filterPartitions; - } else { - return allPartitions; + LOGGER.info("open table source on path: {}", path); + } + + @Override + public List listPartitions() { + return listPartitions(1); + } + + @Override + public List listPartitions(int parallelism) { + List allPartitions = fileReadHandler.listPartitions(parallelism); + if (StringUtils.isNotEmpty(this.nameFilterRegex)) { + List filterPartitions = new ArrayList<>(); + for (Partition partition : allPartitions) { + if (!partition.getName().startsWith(".") + && Pattern.matches(this.nameFilterRegex, partition.getName())) { + filterPartitions.add(partition); } + } + return filterPartitions; + } else { + return allPartitions; } - - @SuppressWarnings("unchecked") - @Override - public TableDeserializer getDeserializer(Configuration conf) { - return fileReadHandler.getDeserializer(); + } + + @SuppressWarnings("unchecked") + @Override + public TableDeserializer getDeserializer(Configuration conf) { + return fileReadHandler.getDeserializer(); + } + + @SuppressWarnings("unchecked") + @Override + public FetchData fetch( + Partition partition, Optional startOffset, FetchWindow windowInfo) + throws IOException { + int windowSize; + if (windowInfo.getType() == WindowType.ALL_WINDOW) { + windowSize = Integer.MAX_VALUE; + } else if (windowInfo.getType() == WindowType.SIZE_TUMBLING_WINDOW) { + if (windowInfo.windowSize() > Integer.MAX_VALUE) { + throw new GeaFlowDSLException( + "File table source window size is overflow:{}", windowInfo.windowSize()); + } + windowSize = (int) windowInfo.windowSize(); + } else { + throw new GeaFlowDSLException( + "File table source not support window:{}", windowInfo.getType()); } - - @SuppressWarnings("unchecked") - @Override - public FetchData fetch(Partition partition, Optional startOffset, - FetchWindow windowInfo) throws IOException { - int windowSize; - if (windowInfo.getType() == WindowType.ALL_WINDOW) { - windowSize = Integer.MAX_VALUE; - } else if (windowInfo.getType() == WindowType.SIZE_TUMBLING_WINDOW) { - if (windowInfo.windowSize() > Integer.MAX_VALUE) { - throw new GeaFlowDSLException("File table source window size is overflow:{}", windowInfo.windowSize()); - } - windowSize = (int) windowInfo.windowSize(); - } else { - throw new GeaFlowDSLException("File table source not support window:{}", windowInfo.getType()); - } - FileOffset offset = startOffset.map(value -> (FileOffset) value).orElseGet(() -> new FileOffset(0L)); - return fileReadHandler.readPartition((FileSplit) partition, offset, windowSize); + FileOffset offset = + startOffset.map(value -> (FileOffset) value).orElseGet(() -> new FileOffset(0L)); + return fileReadHandler.readPartition((FileSplit) partition, offset, windowSize); + } + + @Override + public void close() { + try { + fileReadHandler.close(); + } catch (IOException e) { + throw new GeaFlowDSLException("Error in close file read handler", e); } + } - @Override - public void close() { - try { - fileReadHandler.close(); - } catch (IOException e) { - throw new GeaFlowDSLException("Error in close file read handler", e); - } - } + public static class FileSplit implements Partition { - public static class FileSplit implements Partition { + private String name; - private String name; + private final String baseDir; - private final String baseDir; + private final String relativePath; - private final String relativePath; + private long splitStart; - private long splitStart; + private long splitLength; - private long splitLength; + private int lineSplitSize; - private int lineSplitSize; + private int index; - private int index; + private int parallel; - private int parallel; - - public FileSplit(String baseDir, String relativePath) { - this.baseDir = baseDir; - this.relativePath = relativePath; - this.lineSplitSize = 1; - this.splitStart = -1L; - this.splitLength = Long.MAX_VALUE; - this.name = relativePath; - } - - public FileSplit(String baseDir, String relativePath, int lineSplitSize) { - this.baseDir = baseDir; - this.relativePath = relativePath; - this.lineSplitSize = lineSplitSize; - this.splitStart = -1L; - this.splitLength = Long.MAX_VALUE; - this.name = relativePath; - } + public FileSplit(String baseDir, String relativePath) { + this.baseDir = baseDir; + this.relativePath = relativePath; + this.lineSplitSize = 1; + this.splitStart = -1L; + this.splitLength = Long.MAX_VALUE; + this.name = relativePath; + } - public FileSplit(String baseDir, String relativePath, int lineSplitSize, long splitStart, long splitLength) { - this.baseDir = baseDir; - this.relativePath = relativePath; - this.lineSplitSize = lineSplitSize; - this.splitStart = splitStart; - this.splitLength = splitLength; - this.name = relativePath; - } + public FileSplit(String baseDir, String relativePath, int lineSplitSize) { + this.baseDir = baseDir; + this.relativePath = relativePath; + this.lineSplitSize = lineSplitSize; + this.splitStart = -1L; + this.splitLength = Long.MAX_VALUE; + this.name = relativePath; + } - public FileSplit(String baseDir, String relativePath, int index, int lineSplitSize, long splitStart, long splitLength) { - this.baseDir = baseDir; - this.relativePath = relativePath; - this.lineSplitSize = lineSplitSize; - this.splitStart = splitStart; - this.splitLength = splitLength; - this.name = relativePath + "_" + index; - } + public FileSplit( + String baseDir, String relativePath, int lineSplitSize, long splitStart, long splitLength) { + this.baseDir = baseDir; + this.relativePath = relativePath; + this.lineSplitSize = lineSplitSize; + this.splitStart = splitStart; + this.splitLength = splitLength; + this.name = relativePath; + } - public FileSplit(String file) { - int index = file.lastIndexOf('/'); - if (index == -1) { - throw new GeaFlowDSLException("Illegal file path: '{}', should be a full path.", file); - } - this.baseDir = file.substring(0, index); - this.relativePath = file.substring(index + 1); - this.splitStart = -1L; - this.splitLength = Long.MAX_VALUE; - this.name = relativePath; - } + public FileSplit( + String baseDir, + String relativePath, + int index, + int lineSplitSize, + long splitStart, + long splitLength) { + this.baseDir = baseDir; + this.relativePath = relativePath; + this.lineSplitSize = lineSplitSize; + this.splitStart = splitStart; + this.splitLength = splitLength; + this.name = relativePath + "_" + index; + } - @Override - public void setIndex(int index, int parallel) { - this.index = index; - this.parallel = parallel; - } + public FileSplit(String file) { + int index = file.lastIndexOf('/'); + if (index == -1) { + throw new GeaFlowDSLException("Illegal file path: '{}', should be a full path.", file); + } + this.baseDir = file.substring(0, index); + this.relativePath = file.substring(index + 1); + this.splitStart = -1L; + this.splitLength = Long.MAX_VALUE; + this.name = relativePath; + } - @Override - public String getName() { - return name; - } + @Override + public void setIndex(int index, int parallel) { + this.index = index; + this.parallel = parallel; + } - public String getPath() { - if (baseDir.endsWith("/")) { - return baseDir + relativePath; - } - return baseDir + "/" + relativePath; - } + @Override + public String getName() { + return name; + } - public long getSplitStart() { - return splitStart; - } + public String getPath() { + if (baseDir.endsWith("/")) { + return baseDir + relativePath; + } + return baseDir + "/" + relativePath; + } - public long getSplitLength() { - return splitLength; - } + public long getSplitStart() { + return splitStart; + } - public int getLineSplitSize() { - return lineSplitSize; - } + public long getSplitLength() { + return splitLength; + } - public int getIndex() { - return index; - } + public int getLineSplitSize() { + return lineSplitSize; + } - public int getParallel() { - return parallel; - } + public int getIndex() { + return index; + } - @Override - public int hashCode() { - return Objects.hash(baseDir, relativePath); - } + public int getParallel() { + return parallel; + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof FileSplit)) { - return false; - } - FileSplit that = (FileSplit) o; - return Objects.equals(baseDir, that.baseDir) && Objects.equals(relativePath, that.relativePath) - && Objects.equals(name, that.name) && Objects.equals(splitStart, that.splitStart) - && Objects.equals(splitLength, that.splitLength); - } + @Override + public int hashCode() { + return Objects.hash(baseDir, relativePath); + } - @Override - public String toString() { - return "FileSplit(path=" + getPath() + ")"; - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof FileSplit)) { + return false; + } + FileSplit that = (FileSplit) o; + return Objects.equals(baseDir, that.baseDir) + && Objects.equals(relativePath, that.relativePath) + && Objects.equals(name, that.name) + && Objects.equals(splitStart, that.splitStart) + && Objects.equals(splitLength, that.splitLength); + } - public InputStream openStream(Configuration conf) throws IOException { - FileSystem fs = FileConnectorUtil.getHdfsFileSystem(conf); - Path path = new Path(baseDir, relativePath); - FSDataInputStream inputStream = fs.open(path); - if (this.splitStart != -1L) { - inputStream.seek(this.splitStart); - } - return inputStream; - } + @Override + public String toString() { + return "FileSplit(path=" + getPath() + ")"; + } - public InputStream openStream(Configuration conf, long inputOffset) throws IOException { - FileSystem fs = FileConnectorUtil.getHdfsFileSystem(conf); - Path path = new Path(baseDir, relativePath); - FSDataInputStream inputStream = fs.open(path); - if (this.splitStart != -1L) { - inputStream.seek(this.splitStart + inputOffset); - } - return inputStream; - } + public InputStream openStream(Configuration conf) throws IOException { + FileSystem fs = FileConnectorUtil.getHdfsFileSystem(conf); + Path path = new Path(baseDir, relativePath); + FSDataInputStream inputStream = fs.open(path); + if (this.splitStart != -1L) { + inputStream.seek(this.splitStart); + } + return inputStream; + } + public InputStream openStream(Configuration conf, long inputOffset) throws IOException { + FileSystem fs = FileConnectorUtil.getHdfsFileSystem(conf); + Path path = new Path(baseDir, relativePath); + FSDataInputStream inputStream = fs.open(path); + if (this.splitStart != -1L) { + inputStream.seek(this.splitStart + inputOffset); + } + return inputStream; } + } - public static class FileOffset implements Offset { + public static class FileOffset implements Offset { - private final long offset; + private final long offset; - public FileOffset(long offset) { - this.offset = offset; - } + public FileOffset(long offset) { + this.offset = offset; + } - @Override - public String humanReadable() { - return String.valueOf(offset); - } + @Override + public String humanReadable() { + return String.valueOf(offset); + } - @Override - public long getOffset() { - return offset; - } + @Override + public long getOffset() { + return offset; + } - @Override - public boolean isTimestamp() { - return false; - } + @Override + public boolean isTimestamp() { + return false; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/JarFileReadHandler.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/JarFileReadHandler.java index 7dd9d834c..8392a2299 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/JarFileReadHandler.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/JarFileReadHandler.java @@ -28,6 +28,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ConnectorConfigKeys; import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; @@ -40,117 +41,136 @@ public class JarFileReadHandler extends AbstractFileReadHandler { - private static final Logger LOGGER = LoggerFactory.getLogger(JarFileReadHandler.class); + private static final Logger LOGGER = LoggerFactory.getLogger(JarFileReadHandler.class); - @Override - public void init(Configuration tableConf, TableSchema tableSchema, String path) throws IOException { - super.init(tableConf, tableSchema, path); - this.path = new Path(path.substring(PREFIX_JAVA_RESOURCE.length())); - } + @Override + public void init(Configuration tableConf, TableSchema tableSchema, String path) + throws IOException { + super.init(tableConf, tableSchema, path); + this.path = new Path(path.substring(PREFIX_JAVA_RESOURCE.length())); + } - @Override - public List listPartitions(int parallelism) { - int index = path.toString().lastIndexOf('/'); - String baseDir = path.toString().substring(0, index); - String fileName = path.toString().substring(index + 1); - try { - URL url = getClass().getResource(path.toString()); - InputStream inputStream = url.openStream(); - int lineSplitSize = findLineSplitSize(inputStream); - inputStream.close(); - if (parallelism == 1 || tableConf.getBoolean(ConnectorConfigKeys.GEAFLOW_DSL_SOURCE_FILE_PARALLEL_MOD)) { - return Collections.singletonList(new ResourceFileSplit(baseDir, fileName, lineSplitSize)); - } else { - List partitions = splitSingleFile(parallelism, baseDir, fileName, lineSplitSize); - List newPartitions = new ArrayList<>(); - for (int i = 0; i < partitions.size(); i++) { - Partition partition = partitions.get(i); - FileSplit fileSplit = (FileSplit) partition; - newPartitions.add(new ResourceFileSplit(baseDir, fileName, i, lineSplitSize, fileSplit.getSplitStart(), fileSplit.getSplitLength())); - } - return newPartitions; - } - } catch (Exception e) { - LOGGER.error(e.getMessage(), e); - throw new RuntimeException(e); + @Override + public List listPartitions(int parallelism) { + int index = path.toString().lastIndexOf('/'); + String baseDir = path.toString().substring(0, index); + String fileName = path.toString().substring(index + 1); + try { + URL url = getClass().getResource(path.toString()); + InputStream inputStream = url.openStream(); + int lineSplitSize = findLineSplitSize(inputStream); + inputStream.close(); + if (parallelism == 1 + || tableConf.getBoolean(ConnectorConfigKeys.GEAFLOW_DSL_SOURCE_FILE_PARALLEL_MOD)) { + return Collections.singletonList(new ResourceFileSplit(baseDir, fileName, lineSplitSize)); + } else { + List partitions = splitSingleFile(parallelism, baseDir, fileName, lineSplitSize); + List newPartitions = new ArrayList<>(); + for (int i = 0; i < partitions.size(); i++) { + Partition partition = partitions.get(i); + FileSplit fileSplit = (FileSplit) partition; + newPartitions.add( + new ResourceFileSplit( + baseDir, + fileName, + i, + lineSplitSize, + fileSplit.getSplitStart(), + fileSplit.getSplitLength())); } + return newPartitions; + } + } catch (Exception e) { + LOGGER.error(e.getMessage(), e); + throw new RuntimeException(e); } + } - public List splitSingleFile(int parallelism, String baseDir, String fileName, int lineSplitSize) throws Exception { - FileSplit fileSplit = new FileSplit(baseDir, fileName); + public List splitSingleFile( + int parallelism, String baseDir, String fileName, int lineSplitSize) throws Exception { + FileSplit fileSplit = new FileSplit(baseDir, fileName); - if (this.formatName.equalsIgnoreCase(SourceConstants.PARQUET)) { - throw new RuntimeException("not support parallel read parquet file in resources"); - } else { - URL url = getClass().getResource(fileSplit.getPath()); - if (url == null) { - throw new GeaFlowDSLException("Resource: {} not found", fileSplit.getPath()); - } - File file = new File(url.getFile()); - long fileLength = file.length(); - long splitSize = fileLength / parallelism; - if (splitSize == 0) { - splitSize = 1; - } - long startPos = 0; - long endPos = 0; - List fileSplits = new ArrayList<>(); - for (int i = 0; i < parallelism; i++) { - startPos = endPos; - InputStream inputStream = url.openStream(); - endPos = findNextStartPos(startPos + splitSize, fileLength, inputStream); - inputStream.close(); - fileSplits.add(new ResourceFileSplit(baseDir, fileName, i, lineSplitSize, startPos, endPos - startPos)); - if (endPos >= fileLength) { - break; - } - } - return fileSplits; + if (this.formatName.equalsIgnoreCase(SourceConstants.PARQUET)) { + throw new RuntimeException("not support parallel read parquet file in resources"); + } else { + URL url = getClass().getResource(fileSplit.getPath()); + if (url == null) { + throw new GeaFlowDSLException("Resource: {} not found", fileSplit.getPath()); + } + File file = new File(url.getFile()); + long fileLength = file.length(); + long splitSize = fileLength / parallelism; + if (splitSize == 0) { + splitSize = 1; + } + long startPos = 0; + long endPos = 0; + List fileSplits = new ArrayList<>(); + for (int i = 0; i < parallelism; i++) { + startPos = endPos; + InputStream inputStream = url.openStream(); + endPos = findNextStartPos(startPos + splitSize, fileLength, inputStream); + inputStream.close(); + fileSplits.add( + new ResourceFileSplit( + baseDir, fileName, i, lineSplitSize, startPos, endPos - startPos)); + if (endPos >= fileLength) { + break; } + } + return fileSplits; } + } - public static class ResourceFileSplit extends FileSplit { + public static class ResourceFileSplit extends FileSplit { - public ResourceFileSplit(String baseDir, String relativePath, int lineSplitSize) { - super(baseDir, relativePath, lineSplitSize); - } + public ResourceFileSplit(String baseDir, String relativePath, int lineSplitSize) { + super(baseDir, relativePath, lineSplitSize); + } - public ResourceFileSplit(String baseDir, String relativePath, int lineSplitSize, long splitStart, long splitLength) { - super(baseDir, relativePath, lineSplitSize, splitStart, splitLength); - } + public ResourceFileSplit( + String baseDir, String relativePath, int lineSplitSize, long splitStart, long splitLength) { + super(baseDir, relativePath, lineSplitSize, splitStart, splitLength); + } - public ResourceFileSplit(String baseDir, String relativePath, int index, int lineSplitSize, long splitStart, long splitLength) { - super(baseDir, relativePath, index, lineSplitSize, splitStart, splitLength); - } + public ResourceFileSplit( + String baseDir, + String relativePath, + int index, + int lineSplitSize, + long splitStart, + long splitLength) { + super(baseDir, relativePath, index, lineSplitSize, splitStart, splitLength); + } - @Override - public InputStream openStream(Configuration conf) throws IOException { - URL url = getClass().getResource(getPath()); - if (url == null) { - throw new GeaFlowDSLException("Resource: {} not found", getPath()); - } - InputStream inputStream = url.openStream(); - if (getSplitStart() != -1L) { - inputStream.skip(getSplitStart()); - } - return inputStream; - } + @Override + public InputStream openStream(Configuration conf) throws IOException { + URL url = getClass().getResource(getPath()); + if (url == null) { + throw new GeaFlowDSLException("Resource: {} not found", getPath()); + } + InputStream inputStream = url.openStream(); + if (getSplitStart() != -1L) { + inputStream.skip(getSplitStart()); + } + return inputStream; + } - @Override - public InputStream openStream(Configuration conf, long inputOffset) throws IOException { - URL url = getClass().getResource(getPath()); - if (url == null) { - throw new GeaFlowDSLException("Resource: {} not found", getPath()); - } - InputStream inputStream = url.openStream(); - if (inputOffset != -1L) { - if (getSplitStart() != -1L) { - inputStream.skip(getSplitStart() + inputOffset); - } else { - inputStream.skip(inputOffset); - } - } - return inputStream; + @Override + public InputStream openStream(Configuration conf, long inputOffset) throws IOException { + URL url = getClass().getResource(getPath()); + if (url == null) { + throw new GeaFlowDSLException("Resource: {} not found", getPath()); + } + InputStream inputStream = url.openStream(); + if (inputOffset != -1L) { + if (getSplitStart() != -1L) { + inputStream.skip(getSplitStart() + inputOffset); + } else { + inputStream.skip(inputOffset); } + } + return inputStream; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/S3FileReadHandler.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/S3FileReadHandler.java index 9c46a7a91..bd866cf76 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/S3FileReadHandler.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/S3FileReadHandler.java @@ -19,18 +19,11 @@ package org.apache.geaflow.dsl.connector.file.source; -import com.amazonaws.auth.AWSCredentials; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.s3.model.ListObjectsV2Result; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectSummary; import java.io.IOException; import java.io.InputStream; import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.dsl.common.types.TableSchema; @@ -39,84 +32,93 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.client.builder.AwsClientBuilder; +import com.amazonaws.services.s3.AmazonS3; +import com.amazonaws.services.s3.AmazonS3ClientBuilder; +import com.amazonaws.services.s3.model.ListObjectsV2Result; +import com.amazonaws.services.s3.model.S3Object; +import com.amazonaws.services.s3.model.S3ObjectSummary; public class S3FileReadHandler extends AbstractFileReadHandler { - private static final Logger LOGGER = LoggerFactory.getLogger(S3FileReadHandler.class); + private static final Logger LOGGER = LoggerFactory.getLogger(S3FileReadHandler.class); - protected String path; + protected String path; - protected AWSCredentials credentials; - protected String serviceEndpoint; + protected AWSCredentials credentials; + protected String serviceEndpoint; - protected AmazonS3 s3; + protected AmazonS3 s3; - @Override - public void init(Configuration tableConf, TableSchema tableSchema, String path) throws IOException { - super.init(tableConf, tableSchema, path); - this.path = path; - this.credentials = FileConnectorUtil.getS3Credentials(tableConf); - this.serviceEndpoint = FileConnectorUtil.getS3ServiceEndpoint(tableConf); - s3 = AmazonS3ClientBuilder.standard() - .withCredentials(new AWSCredentialsProvider() { - @Override - public AWSCredentials getCredentials() { + @Override + public void init(Configuration tableConf, TableSchema tableSchema, String path) + throws IOException { + super.init(tableConf, tableSchema, path); + this.path = path; + this.credentials = FileConnectorUtil.getS3Credentials(tableConf); + this.serviceEndpoint = FileConnectorUtil.getS3ServiceEndpoint(tableConf); + s3 = + AmazonS3ClientBuilder.standard() + .withCredentials( + new AWSCredentialsProvider() { + @Override + public AWSCredentials getCredentials() { return credentials; - } + } - @Override - public void refresh() { - } - }) - .withEndpointConfiguration(new AwsClientBuilder.EndpointConfiguration(serviceEndpoint, null)) + @Override + public void refresh() {} + }) + .withEndpointConfiguration( + new AwsClientBuilder.EndpointConfiguration(serviceEndpoint, null)) .build(); + } + + @Override + public List listPartitions(int parallelism) { + List partitions = new ArrayList<>(); + try { + ListObjectsV2Result result = + s3.listObjectsV2(FileConnectorUtil.getBucket(path), FileConnectorUtil.getKey(path)); + + result + .getObjectSummaries() + .forEach( + (S3ObjectSummary obj) -> { + ResourceFileSplit split = new ResourceFileSplit(obj.getBucketName(), obj.getKey()); + split.setS3(s3); + partitions.add(split); + }); + } catch (Exception e) { + throw new GeaflowRuntimeException("Cannot get partitions with path: " + path, e); } + return partitions; + } - @Override - public List listPartitions(int parallelism) { - List partitions = new ArrayList<>(); - try { - ListObjectsV2Result result = s3.listObjectsV2( - FileConnectorUtil.getBucket(path), - FileConnectorUtil.getKey(path) - ); - - - result.getObjectSummaries() - .forEach((S3ObjectSummary obj) -> { - ResourceFileSplit split = new ResourceFileSplit(obj.getBucketName(), obj.getKey()); - split.setS3(s3); - partitions.add(split); - }); - } catch (Exception e) { - throw new GeaflowRuntimeException("Cannot get partitions with path: " + path, e); - } - return partitions; - } - - public static class ResourceFileSplit extends FileTableSource.FileSplit { - - private AmazonS3 s3; + public static class ResourceFileSplit extends FileTableSource.FileSplit { - private String object; + private AmazonS3 s3; - private String key; + private String object; - public ResourceFileSplit(String baseDir, String relativePath) { - super(baseDir, relativePath); - this.object = baseDir; - this.key = relativePath; - } + private String key; - void setS3(AmazonS3 s3) { - this.s3 = s3; - } + public ResourceFileSplit(String baseDir, String relativePath) { + super(baseDir, relativePath); + this.object = baseDir; + this.key = relativePath; + } - @Override - public InputStream openStream(Configuration conf) throws IOException { - S3Object obj = s3.getObject(object, key); - return obj.getObjectContent(); - } + void setS3(AmazonS3 s3) { + this.s3 = s3; } + @Override + public InputStream openStream(Configuration conf) throws IOException { + S3Object obj = s3.getObject(object, key); + return obj.getObjectContent(); + } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/SourceConstants.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/SourceConstants.java index 23f012d5e..f02f162d9 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/SourceConstants.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/SourceConstants.java @@ -21,8 +21,7 @@ public class SourceConstants { - public static final String PARQUET = "parquet"; - public static final String CSV = "csv"; - public static final String TXT = "txt"; - + public static final String PARQUET = "parquet"; + public static final String CSV = "csv"; + public static final String TXT = "txt"; } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/CsvFormat.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/CsvFormat.java index a5d4a5212..e550a42de 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/CsvFormat.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/CsvFormat.java @@ -19,11 +19,11 @@ package org.apache.geaflow.dsl.connector.file.source.format; -import com.google.common.collect.Lists; import java.io.IOException; import java.util.ArrayList; import java.util.Iterator; import java.util.List; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ConnectorConfigKeys; @@ -38,100 +38,105 @@ import org.apache.geaflow.dsl.connector.file.source.FileTableSource.FileSplit; import org.apache.geaflow.dsl.connector.file.source.SourceConstants; +import com.google.common.collect.Lists; + public class CsvFormat implements FileFormat { - private TextFormat txtFormat; + private TextFormat txtFormat; - private StructType dataSchema; + private StructType dataSchema; - private Configuration tableConf; + private Configuration tableConf; - @Override - public String getFormat() { - return SourceConstants.CSV; - } + @Override + public String getFormat() { + return SourceConstants.CSV; + } - @Override - public void init(Configuration tableConf, TableSchema tableSchema, FileSplit split) throws IOException { - this.txtFormat = new TextFormat(); - this.txtFormat.init(tableConf, tableSchema, split); - this.dataSchema = tableSchema.getDataSchema(); - this.tableConf = tableConf; - } + @Override + public void init(Configuration tableConf, TableSchema tableSchema, FileSplit split) + throws IOException { + this.txtFormat = new TextFormat(); + this.txtFormat.init(tableConf, tableSchema, split); + this.dataSchema = tableSchema.getDataSchema(); + this.tableConf = tableConf; + } - @Override - public Iterator batchRead() throws IOException { - Iterator textIterator = txtFormat.batchRead(); - return new CsvIterator(textIterator, dataSchema, tableConf); - } + @Override + public Iterator batchRead() throws IOException { + Iterator textIterator = txtFormat.batchRead(); + return new CsvIterator(textIterator, dataSchema, tableConf); + } - @Override - public void close() throws IOException { - txtFormat.close(); - } + @Override + public void close() throws IOException { + txtFormat.close(); + } - @Override - public TableDeserializer getDeserializer() { - return null; - } + @Override + public TableDeserializer getDeserializer() { + return null; + } - private static class CsvIterator implements Iterator { - - private final Iterator textIterator; - - private final StructType schema; - - private final int[] fieldIndices; - - private final String separator; - - public CsvIterator(Iterator textIterator, StructType schema, Configuration tableConf) { - this.textIterator = textIterator; - this.schema = schema; - this.separator = tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_COLUMN_SEPARATOR); - - boolean skipHeader = tableConf.getBoolean(ConnectorConfigKeys.GEAFLOW_DSL_SKIP_HEADER); - List headerFields = new ArrayList<>(); - if (skipHeader) { - if (textIterator.hasNext()) { // skip header - String header = textIterator.next(); - headerFields = Lists.newArrayList( - StringUtils.splitByWholeSeparatorPreserveAllTokens(header, separator)); - } - } - fieldIndices = new int[schema.size()]; - if (headerFields.size() > 0) { - int i = 0; - for (TableField field : schema.getFields()) { - int index = headerFields.indexOf(field.getName()); - if (index == -1) { - throw new GeaFlowDSLException("Field: '{}' is not exists in the csv " - + "header. header field is: {}", field.getName(), - StringUtils.join(headerFields, ",")); - } - fieldIndices[i++] = index; - } - } else { - for (int i = 0; i < schema.size(); i++) { - fieldIndices[i] = i; - } - } - } + private static class CsvIterator implements Iterator { - @Override - public boolean hasNext() { - return textIterator.hasNext(); - } + private final Iterator textIterator; + + private final StructType schema; + + private final int[] fieldIndices; - @Override - public Row next() { - String line = textIterator.next(); - String[] fields = StringUtils.splitByWholeSeparatorPreserveAllTokens(line, separator); - Object[] selectFields = new Object[fieldIndices.length]; - for (int i = 0; i < selectFields.length; i++) { - selectFields[i] = TypeCastUtil.cast(fields[fieldIndices[i]], schema.getType(i)); - } - return ObjectRow.create(selectFields); + private final String separator; + + public CsvIterator(Iterator textIterator, StructType schema, Configuration tableConf) { + this.textIterator = textIterator; + this.schema = schema; + this.separator = tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_COLUMN_SEPARATOR); + + boolean skipHeader = tableConf.getBoolean(ConnectorConfigKeys.GEAFLOW_DSL_SKIP_HEADER); + List headerFields = new ArrayList<>(); + if (skipHeader) { + if (textIterator.hasNext()) { // skip header + String header = textIterator.next(); + headerFields = + Lists.newArrayList( + StringUtils.splitByWholeSeparatorPreserveAllTokens(header, separator)); + } + } + fieldIndices = new int[schema.size()]; + if (headerFields.size() > 0) { + int i = 0; + for (TableField field : schema.getFields()) { + int index = headerFields.indexOf(field.getName()); + if (index == -1) { + throw new GeaFlowDSLException( + "Field: '{}' is not exists in the csv " + "header. header field is: {}", + field.getName(), + StringUtils.join(headerFields, ",")); + } + fieldIndices[i++] = index; } + } else { + for (int i = 0; i < schema.size(); i++) { + fieldIndices[i] = i; + } + } + } + + @Override + public boolean hasNext() { + return textIterator.hasNext(); + } + + @Override + public Row next() { + String line = textIterator.next(); + String[] fields = StringUtils.splitByWholeSeparatorPreserveAllTokens(line, separator); + Object[] selectFields = new Object[fieldIndices.length]; + for (int i = 0; i < selectFields.length; i++) { + selectFields[i] = TypeCastUtil.cast(fields[fieldIndices[i]], schema.getType(i)); + } + return ObjectRow.create(selectFields); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/FileFormat.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/FileFormat.java index 0388501ce..8f015b342 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/FileFormat.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/FileFormat.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Iterator; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.common.types.TableSchema; import org.apache.geaflow.dsl.connector.api.serde.TableDeserializer; @@ -28,13 +29,13 @@ public interface FileFormat { - String getFormat(); + String getFormat(); - void init(Configuration tableConf, TableSchema tableSchema, FileSplit split) throws IOException; + void init(Configuration tableConf, TableSchema tableSchema, FileSplit split) throws IOException; - Iterator batchRead() throws IOException; + Iterator batchRead() throws IOException; - void close() throws IOException; + void close() throws IOException; - TableDeserializer getDeserializer(); + TableDeserializer getDeserializer(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/FileFormats.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/FileFormats.java index f3aefce58..7725653d8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/FileFormats.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/FileFormats.java @@ -20,23 +20,24 @@ package org.apache.geaflow.dsl.connector.file.source.format; import java.util.ServiceLoader; + import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; public class FileFormats { - @SuppressWarnings("unchecked") - public static FileFormat loadFileFormat(String formatName) { - ServiceLoader formats = ServiceLoader.load(FileFormat.class); - FileFormat currentFormat = null; - for (FileFormat format : formats) { - if (format.getFormat().equalsIgnoreCase(formatName)) { - currentFormat = format; - break; - } - } - if (currentFormat == null) { - throw new GeaFlowDSLException("File format '{}' is not found", formatName); - } - return currentFormat; + @SuppressWarnings("unchecked") + public static FileFormat loadFileFormat(String formatName) { + ServiceLoader formats = ServiceLoader.load(FileFormat.class); + FileFormat currentFormat = null; + for (FileFormat format : formats) { + if (format.getFormat().equalsIgnoreCase(formatName)) { + currentFormat = format; + break; + } + } + if (currentFormat == null) { + throw new GeaFlowDSLException("File format '{}' is not found", formatName); } + return currentFormat; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/ParquetFormat.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/ParquetFormat.java index cfd0a6eb4..e131031c0 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/ParquetFormat.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/ParquetFormat.java @@ -23,6 +23,7 @@ import java.util.Iterator; import java.util.List; import java.util.stream.Collectors; + import org.apache.avro.LogicalTypes; import org.apache.avro.Schema; import org.apache.avro.SchemaBuilder; @@ -63,179 +64,187 @@ public class ParquetFormat implements FileFormat { - private static final Logger LOGGER = LoggerFactory.getLogger(ParquetFormat.class); - - private StructType dataSchema; - - private List inputSplits; - - private ParquetInputFormat inputFormat; - - private TaskAttemptContext taskAttemptContext; - - private RecordReader currentReader; - - @Override - public String getFormat() { - return SourceConstants.PARQUET; - } - - @Override - public void init(Configuration tableConf, TableSchema tableSchema, FileSplit split) throws IOException { - this.dataSchema = tableSchema.getDataSchema(); - this.inputFormat = new AvroParquetInputFormat<>(); - Job job = Job.getInstance(FileConnectorUtil.toHadoopConf(tableConf)); - Path path = new Path(split.getPath()); - path = path.getFileSystem(job.getConfiguration()).makeQualified(path); - LOGGER.info("Read parquet from: {}", path); - AvroParquetInputFormat.setInputPaths(job, path); - - Schema avroSchema = convertToAvroSchema(dataSchema, false); - job.getConfiguration().set(AvroReadSupport.AVRO_COMPATIBILITY, "false"); - AvroParquetInputFormat.setAvroReadSchema(job, avroSchema); - AvroParquetInputFormat.setRequestedProjection(job, avroSchema); - - JobContext jobContext = new JobContextImpl(job.getConfiguration(), new JobID()); - this.inputSplits = inputFormat.getSplits(jobContext); - if (split.getSplitStart() != -1L) { - this.inputSplits = this.inputSplits.stream().filter( - inputSplit -> ((org.apache.hadoop.mapred.FileSplit) inputSplit).getStart() == split.getSplitStart() - && ((org.apache.hadoop.mapred.FileSplit) inputSplit).getLength() == split.getSplitLength() - ).collect(Collectors.toList()); - } - this.taskAttemptContext = new TaskAttemptContextImpl(job.getConfiguration(), new TaskAttemptID()); + private static final Logger LOGGER = LoggerFactory.getLogger(ParquetFormat.class); + + private StructType dataSchema; + + private List inputSplits; + + private ParquetInputFormat inputFormat; + + private TaskAttemptContext taskAttemptContext; + + private RecordReader currentReader; + + @Override + public String getFormat() { + return SourceConstants.PARQUET; + } + + @Override + public void init(Configuration tableConf, TableSchema tableSchema, FileSplit split) + throws IOException { + this.dataSchema = tableSchema.getDataSchema(); + this.inputFormat = new AvroParquetInputFormat<>(); + Job job = Job.getInstance(FileConnectorUtil.toHadoopConf(tableConf)); + Path path = new Path(split.getPath()); + path = path.getFileSystem(job.getConfiguration()).makeQualified(path); + LOGGER.info("Read parquet from: {}", path); + AvroParquetInputFormat.setInputPaths(job, path); + + Schema avroSchema = convertToAvroSchema(dataSchema, false); + job.getConfiguration().set(AvroReadSupport.AVRO_COMPATIBILITY, "false"); + AvroParquetInputFormat.setAvroReadSchema(job, avroSchema); + AvroParquetInputFormat.setRequestedProjection(job, avroSchema); + + JobContext jobContext = new JobContextImpl(job.getConfiguration(), new JobID()); + this.inputSplits = inputFormat.getSplits(jobContext); + if (split.getSplitStart() != -1L) { + this.inputSplits = + this.inputSplits.stream() + .filter( + inputSplit -> + ((org.apache.hadoop.mapred.FileSplit) inputSplit).getStart() + == split.getSplitStart() + && ((org.apache.hadoop.mapred.FileSplit) inputSplit).getLength() + == split.getSplitLength()) + .collect(Collectors.toList()); } - - @Override - public Iterator batchRead() throws IOException { - return new Iterator() { - - private int index = 0; - - @Override - public boolean hasNext() { - try { - boolean hasNext = currentReader != null && currentReader.nextKeyValue(); - if (currentReader == null || !hasNext) { - if (index < inputSplits.size()) { - InputSplit split = inputSplits.get(index); - // close previous reader - if (currentReader != null) { - currentReader.close(); - } - // create new reader - currentReader = inputFormat.createRecordReader(split, taskAttemptContext); - currentReader.initialize(split, taskAttemptContext); - hasNext = currentReader.nextKeyValue(); - } else { - return false; - } - index++; - } - return hasNext; - } catch (Exception e) { - throw new GeaFlowDSLException(e); - } - } - - @Override - public Row next() { - try { - GenericData.Record record = currentReader.getCurrentValue(); - return convertAvroRecordToRow(record); - } catch (Exception e) { - throw new GeaFlowDSLException(e); - } + this.taskAttemptContext = + new TaskAttemptContextImpl(job.getConfiguration(), new TaskAttemptID()); + } + + @Override + public Iterator batchRead() throws IOException { + return new Iterator() { + + private int index = 0; + + @Override + public boolean hasNext() { + try { + boolean hasNext = currentReader != null && currentReader.nextKeyValue(); + if (currentReader == null || !hasNext) { + if (index < inputSplits.size()) { + InputSplit split = inputSplits.get(index); + // close previous reader + if (currentReader != null) { + currentReader.close(); + } + // create new reader + currentReader = inputFormat.createRecordReader(split, taskAttemptContext); + currentReader.initialize(split, taskAttemptContext); + hasNext = currentReader.nextKeyValue(); + } else { + return false; } - }; - } - - @Override - public void close() throws IOException { - if (currentReader != null) { - currentReader.close(); - } - } - - @Override - public TableDeserializer getDeserializer() { - return null; - } - - public static Schema convertToAvroSchema(IType sqlType, boolean nullable) { - TypeBuilder builder = SchemaBuilder.builder(); - Schema avroType; - switch (sqlType.getName()) { - case Types.TYPE_NAME_BINARY_STRING: - case Types.TYPE_NAME_STRING: - avroType = builder.stringType(); - break; - case Types.TYPE_NAME_INTEGER: - avroType = builder.intType(); - break; - case Types.TYPE_NAME_LONG: - avroType = builder.longType(); - break; - case Types.TYPE_NAME_BOOLEAN: - avroType = builder.booleanType(); - break; - case Types.TYPE_NAME_DOUBLE: - avroType = builder.doubleType(); - break; - case Types.TYPE_NAME_TIMESTAMP: - avroType = LogicalTypes.timestampMicros().addToSchema(builder.longType()); - break; - case Types.TYPE_NAME_STRUCT: - StructType structType = (StructType) sqlType; - FieldAssembler fieldAssembler = builder.record("struct").namespace("").fields(); - for (TableField field : structType.getFields()) { - Schema fieldAvroType = convertToAvroSchema(field.getType(), field.isNullable()); - fieldAssembler.name(field.getName()).type(fieldAvroType).noDefault(); - } - avroType = fieldAssembler.endRecord(); - break; - case Types.TYPE_NAME_ARRAY: - ArrayType arrayType = (ArrayType) sqlType; - avroType = builder.array().items(convertToAvroSchema(arrayType.getComponentType(), nullable)); - break; - default: - throw new GeaFlowDSLException("Not support type: {}", sqlType.getName()); + index++; + } + return hasNext; + } catch (Exception e) { + throw new GeaFlowDSLException(e); } - if (nullable) { - Schema nullSchema = builder.nullType(); - return Schema.createUnion(avroType, nullSchema); + } + + @Override + public Row next() { + try { + GenericData.Record record = currentReader.getCurrentValue(); + return convertAvroRecordToRow(record); + } catch (Exception e) { + throw new GeaFlowDSLException(e); } - return avroType; + } + }; + } + + @Override + public void close() throws IOException { + if (currentReader != null) { + currentReader.close(); } - - private Row convertAvroRecordToRow(GenericData.Record record) { - Object[] fields = new Object[dataSchema.size()]; - for (int i = 0; i < fields.length; i++) { - IType type = dataSchema.getType(i); - switch (type.getName()) { - case Types.TYPE_NAME_BINARY_STRING: - Utf8 utf8 = (Utf8) record.get(i); - fields[i] = utf8 == null ? null : BinaryString.fromBytes(utf8.getBytes()); - break; - case Types.TYPE_NAME_INTEGER: - fields[i] = (Integer) record.get(i); - break; - case Types.TYPE_NAME_LONG: - fields[i] = (Long) record.get(i); - break; - case Types.TYPE_NAME_DOUBLE: - fields[i] = (Double) record.get(i); - break; - case Types.TYPE_NAME_TIMESTAMP: - fields[i] = record.get(i); - break; - case Types.TYPE_NAME_BOOLEAN: - fields[i] = (Boolean) record.get(i); - break; - default: - throw new GeaFlowDSLException("Not support type: {}", type.getName()); - } + } + + @Override + public TableDeserializer getDeserializer() { + return null; + } + + public static Schema convertToAvroSchema(IType sqlType, boolean nullable) { + TypeBuilder builder = SchemaBuilder.builder(); + Schema avroType; + switch (sqlType.getName()) { + case Types.TYPE_NAME_BINARY_STRING: + case Types.TYPE_NAME_STRING: + avroType = builder.stringType(); + break; + case Types.TYPE_NAME_INTEGER: + avroType = builder.intType(); + break; + case Types.TYPE_NAME_LONG: + avroType = builder.longType(); + break; + case Types.TYPE_NAME_BOOLEAN: + avroType = builder.booleanType(); + break; + case Types.TYPE_NAME_DOUBLE: + avroType = builder.doubleType(); + break; + case Types.TYPE_NAME_TIMESTAMP: + avroType = LogicalTypes.timestampMicros().addToSchema(builder.longType()); + break; + case Types.TYPE_NAME_STRUCT: + StructType structType = (StructType) sqlType; + FieldAssembler fieldAssembler = builder.record("struct").namespace("").fields(); + for (TableField field : structType.getFields()) { + Schema fieldAvroType = convertToAvroSchema(field.getType(), field.isNullable()); + fieldAssembler.name(field.getName()).type(fieldAvroType).noDefault(); } - return ObjectRow.create(fields); + avroType = fieldAssembler.endRecord(); + break; + case Types.TYPE_NAME_ARRAY: + ArrayType arrayType = (ArrayType) sqlType; + avroType = + builder.array().items(convertToAvroSchema(arrayType.getComponentType(), nullable)); + break; + default: + throw new GeaFlowDSLException("Not support type: {}", sqlType.getName()); + } + if (nullable) { + Schema nullSchema = builder.nullType(); + return Schema.createUnion(avroType, nullSchema); + } + return avroType; + } + + private Row convertAvroRecordToRow(GenericData.Record record) { + Object[] fields = new Object[dataSchema.size()]; + for (int i = 0; i < fields.length; i++) { + IType type = dataSchema.getType(i); + switch (type.getName()) { + case Types.TYPE_NAME_BINARY_STRING: + Utf8 utf8 = (Utf8) record.get(i); + fields[i] = utf8 == null ? null : BinaryString.fromBytes(utf8.getBytes()); + break; + case Types.TYPE_NAME_INTEGER: + fields[i] = (Integer) record.get(i); + break; + case Types.TYPE_NAME_LONG: + fields[i] = (Long) record.get(i); + break; + case Types.TYPE_NAME_DOUBLE: + fields[i] = (Double) record.get(i); + break; + case Types.TYPE_NAME_TIMESTAMP: + fields[i] = record.get(i); + break; + case Types.TYPE_NAME_BOOLEAN: + fields[i] = (Boolean) record.get(i); + break; + default: + throw new GeaFlowDSLException("Not support type: {}", type.getName()); + } } + return ObjectRow.create(fields); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/StreamFormat.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/StreamFormat.java index 891e0428e..c851ad43a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/StreamFormat.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/StreamFormat.java @@ -20,12 +20,13 @@ package org.apache.geaflow.dsl.connector.file.source.format; import java.io.IOException; + import org.apache.geaflow.dsl.connector.api.FetchData; import org.apache.geaflow.dsl.connector.file.source.FileTableSource.FileOffset; public interface StreamFormat { - FetchData streamRead(FileOffset offset, int windowSize) throws IOException; + FetchData streamRead(FileOffset offset, int windowSize) throws IOException; - void skip(long n) throws IOException; + void skip(long n) throws IOException; } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/TextFormat.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/TextFormat.java index 71bd5703d..41e2a54b5 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/TextFormat.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/main/java/org/apache/geaflow/dsl/connector/file/source/format/TextFormat.java @@ -25,6 +25,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ConnectorConfigKeys; import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; @@ -40,114 +41,119 @@ public class TextFormat implements FileFormat, StreamFormat { - private static final Logger LOGGER = LoggerFactory.getLogger(TextFormat.class); - - private BufferedReader reader; - - private Configuration tableConf; - protected boolean singleFileModeRead = false; - private FileSplit fileSplit; - private long readCnt = 0L; - private long readSize = 0L; - private long expectReadSize = -1L; - private int lineSplitSize = 1; - private boolean firstRead = true; - - @Override - public String getFormat() { - return SourceConstants.TXT; - } - - @Override - public void init(Configuration tableConf, TableSchema tableSchema, FileSplit split) throws IOException { - this.tableConf = tableConf; - this.fileSplit = split; - this.expectReadSize = split.getSplitLength(); - this.lineSplitSize = split.getLineSplitSize(); - this.reader = new BufferedReader(new InputStreamReader(split.openStream(tableConf))); - this.expectReadSize = split.getSplitLength(); - this.lineSplitSize = split.getLineSplitSize(); - this.singleFileModeRead = tableConf.getBoolean(ConnectorConfigKeys.GEAFLOW_DSL_SOURCE_FILE_PARALLEL_MOD); - } - - @Override - public Iterator batchRead() { - return new Iterator() { - - private String current = null; - - @Override - public boolean hasNext() { - if (singleFileModeRead) { - throw new GeaFlowDSLException("Single file mode read is not supported in batch read"); - } - try { - if (expectReadSize != -1L && readSize >= expectReadSize) { - return false; - } - if (current == null) { - current = reader.readLine(); - } - return current != null; - } catch (IOException e) { - throw new GeaFlowDSLException("Error in read", e); - } - } - - @Override - public String next() { - String next = current; - current = null; - readSize += next.length() + lineSplitSize; - return next; - } - }; - } - - @Override - public FetchData streamRead(FileOffset offset, int windowSize) throws IOException { - if (firstRead) { - firstRead = false; - close(); - this.reader = new BufferedReader(new InputStreamReader(fileSplit.openStream(tableConf, offset.getOffset()))); + private static final Logger LOGGER = LoggerFactory.getLogger(TextFormat.class); + + private BufferedReader reader; + + private Configuration tableConf; + protected boolean singleFileModeRead = false; + private FileSplit fileSplit; + private long readCnt = 0L; + private long readSize = 0L; + private long expectReadSize = -1L; + private int lineSplitSize = 1; + private boolean firstRead = true; + + @Override + public String getFormat() { + return SourceConstants.TXT; + } + + @Override + public void init(Configuration tableConf, TableSchema tableSchema, FileSplit split) + throws IOException { + this.tableConf = tableConf; + this.fileSplit = split; + this.expectReadSize = split.getSplitLength(); + this.lineSplitSize = split.getLineSplitSize(); + this.reader = new BufferedReader(new InputStreamReader(split.openStream(tableConf))); + this.expectReadSize = split.getSplitLength(); + this.lineSplitSize = split.getLineSplitSize(); + this.singleFileModeRead = + tableConf.getBoolean(ConnectorConfigKeys.GEAFLOW_DSL_SOURCE_FILE_PARALLEL_MOD); + } + + @Override + public Iterator batchRead() { + return new Iterator() { + + private String current = null; + + @Override + public boolean hasNext() { + if (singleFileModeRead) { + throw new GeaFlowDSLException("Single file mode read is not supported in batch read"); } - - List readContents = new ArrayList<>(windowSize); - long nextOffset = offset.getOffset(); - int i = 0; - for (i = 0; i < windowSize; i++) { - String line = reader.readLine(); - if (line == null) { - break; - } - if (!singleFileModeRead || readCnt % this.fileSplit.getParallel() == this.fileSplit.getIndex()) { - readContents.add(line); - } - nextOffset += line.length() + lineSplitSize; - readSize += line.length() + lineSplitSize; - readCnt++; - if (fileSplit.getSplitStart() != -1L && nextOffset >= fileSplit.getSplitLength()) { - break; - } + try { + if (expectReadSize != -1L && readSize >= expectReadSize) { + return false; + } + if (current == null) { + current = reader.readLine(); + } + return current != null; + } catch (IOException e) { + throw new GeaFlowDSLException("Error in read", e); } - boolean isFinished = i < windowSize; - return FetchData.createStreamFetch(readContents, new FileOffset(nextOffset), isFinished); + } + + @Override + public String next() { + String next = current; + current = null; + readSize += next.length() + lineSplitSize; + return next; + } + }; + } + + @Override + public FetchData streamRead(FileOffset offset, int windowSize) throws IOException { + if (firstRead) { + firstRead = false; + close(); + this.reader = + new BufferedReader( + new InputStreamReader(fileSplit.openStream(tableConf, offset.getOffset()))); } - @Override - public void close() throws IOException { - if (reader != null) { - reader.close(); - } + List readContents = new ArrayList<>(windowSize); + long nextOffset = offset.getOffset(); + int i = 0; + for (i = 0; i < windowSize; i++) { + String line = reader.readLine(); + if (line == null) { + break; + } + if (!singleFileModeRead + || readCnt % this.fileSplit.getParallel() == this.fileSplit.getIndex()) { + readContents.add(line); + } + nextOffset += line.length() + lineSplitSize; + readSize += line.length() + lineSplitSize; + readCnt++; + if (fileSplit.getSplitStart() != -1L && nextOffset >= fileSplit.getSplitLength()) { + break; + } } - - @Override - public TableDeserializer getDeserializer() { - return new TextDeserializer(); + boolean isFinished = i < windowSize; + return FetchData.createStreamFetch(readContents, new FileOffset(nextOffset), isFinished); + } + + @Override + public void close() throws IOException { + if (reader != null) { + reader.close(); } + } - @Override - public void skip(long n) throws IOException { - reader.skip(n); - } + @Override + public TableDeserializer getDeserializer() { + return new TextDeserializer(); + } + + @Override + public void skip(long n) throws IOException { + reader.skip(n); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/test/java/org/apache/geaflow/dsl/connector/file/ConnectorFactoryTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/test/java/org/apache/geaflow/dsl/connector/file/ConnectorFactoryTest.java index 0f362fa79..49427759e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/test/java/org/apache/geaflow/dsl/connector/file/ConnectorFactoryTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/test/java/org/apache/geaflow/dsl/connector/file/ConnectorFactoryTest.java @@ -26,10 +26,9 @@ public class ConnectorFactoryTest { - @Test - public void testLoadConnector() { - TableConnector connector = ConnectorFactory.loadConnector("file"); - Assert.assertEquals(connector.getType(), "FILE"); - } - + @Test + public void testLoadConnector() { + TableConnector connector = ConnectorFactory.loadConnector("file"); + Assert.assertEquals(connector.getType(), "FILE"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/test/java/org/apache/geaflow/dsl/connector/file/CsvFormatTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/test/java/org/apache/geaflow/dsl/connector/file/CsvFormatTest.java index cb22d990e..77d657528 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/test/java/org/apache/geaflow/dsl/connector/file/CsvFormatTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/test/java/org/apache/geaflow/dsl/connector/file/CsvFormatTest.java @@ -19,11 +19,11 @@ package org.apache.geaflow.dsl.connector.file; -import com.google.common.collect.Lists; import java.io.File; import java.util.HashMap; import java.util.Iterator; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ConnectorConfigKeys; @@ -37,74 +37,69 @@ import org.testng.Assert; import org.testng.annotations.Test; +import com.google.common.collect.Lists; + public class CsvFormatTest { - @Test - public void testReadSkipHeader() throws Exception { - String output = "target/test/csv/output"; - writeData(output, "id,name,price", "1,a1,10", "2,a2,12", "3,a3,15"); + @Test + public void testReadSkipHeader() throws Exception { + String output = "target/test/csv/output"; + writeData(output, "id,name,price", "1,a1,10", "2,a2,12", "3,a3,15"); - CsvFormat format = new CsvFormat(); - Map config = new HashMap<>(); - File outputFile = new File(output); - config.put(ConnectorConfigKeys.GEAFLOW_DSL_FILE_PATH.getKey(), outputFile.getAbsolutePath()); - config.put(ConnectorConfigKeys.GEAFLOW_DSL_SKIP_HEADER.getKey(), "true"); + CsvFormat format = new CsvFormat(); + Map config = new HashMap<>(); + File outputFile = new File(output); + config.put(ConnectorConfigKeys.GEAFLOW_DSL_FILE_PATH.getKey(), outputFile.getAbsolutePath()); + config.put(ConnectorConfigKeys.GEAFLOW_DSL_SKIP_HEADER.getKey(), "true"); - FileSplit fileSplit = new FileSplit(outputFile.getAbsolutePath()); - StructType dataSchema = new StructType( - new TableField("price", Types.DOUBLE), - new TableField("name", Types.BINARY_STRING) - ); - format.init(new Configuration(config), new TableSchema(dataSchema), fileSplit); - Iterator iterator = format.batchRead(); - StringBuilder result = new StringBuilder(); - while (iterator.hasNext()) { - Row row = iterator.next(); - if (result.length() > 0) { - result.append("\n"); - } - result.append(row.toString()); - } - Assert.assertEquals(result.toString(), - "[10.0, a1]\n" - + "[12.0, a2]\n" - + "[15.0, a3]"); + FileSplit fileSplit = new FileSplit(outputFile.getAbsolutePath()); + StructType dataSchema = + new StructType( + new TableField("price", Types.DOUBLE), new TableField("name", Types.BINARY_STRING)); + format.init(new Configuration(config), new TableSchema(dataSchema), fileSplit); + Iterator iterator = format.batchRead(); + StringBuilder result = new StringBuilder(); + while (iterator.hasNext()) { + Row row = iterator.next(); + if (result.length() > 0) { + result.append("\n"); + } + result.append(row.toString()); } + Assert.assertEquals(result.toString(), "[10.0, a1]\n" + "[12.0, a2]\n" + "[15.0, a3]"); + } - @Test - public void testReadNoHeader() throws Exception { - String output = "target/test/csv/output"; - writeData(output, "1,a1,10", "2,a2,12", "3,a3,15"); + @Test + public void testReadNoHeader() throws Exception { + String output = "target/test/csv/output"; + writeData(output, "1,a1,10", "2,a2,12", "3,a3,15"); - CsvFormat format = new CsvFormat(); - Map config = new HashMap<>(); - File outputFile = new File(output); - config.put(ConnectorConfigKeys.GEAFLOW_DSL_FILE_PATH.getKey(), outputFile.getAbsolutePath()); - config.put(ConnectorConfigKeys.GEAFLOW_DSL_SKIP_HEADER.getKey(), "false"); + CsvFormat format = new CsvFormat(); + Map config = new HashMap<>(); + File outputFile = new File(output); + config.put(ConnectorConfigKeys.GEAFLOW_DSL_FILE_PATH.getKey(), outputFile.getAbsolutePath()); + config.put(ConnectorConfigKeys.GEAFLOW_DSL_SKIP_HEADER.getKey(), "false"); - FileSplit fileSplit = new FileSplit(outputFile.getAbsolutePath()); - StructType dataSchema = new StructType( + FileSplit fileSplit = new FileSplit(outputFile.getAbsolutePath()); + StructType dataSchema = + new StructType( new TableField("id", Types.INTEGER, false), new TableField("name", Types.BINARY_STRING), - new TableField("price", Types.DOUBLE) - ); - format.init(new Configuration(config), new TableSchema(dataSchema), fileSplit); - Iterator iterator = format.batchRead(); - StringBuilder result = new StringBuilder(); - while (iterator.hasNext()) { - Row row = iterator.next(); - if (result.length() > 0) { - result.append("\n"); - } - result.append(row.toString()); - } - Assert.assertEquals(result.toString(), - "[1, a1, 10.0]\n" - + "[2, a2, 12.0]\n" - + "[3, a3, 15.0]"); + new TableField("price", Types.DOUBLE)); + format.init(new Configuration(config), new TableSchema(dataSchema), fileSplit); + Iterator iterator = format.batchRead(); + StringBuilder result = new StringBuilder(); + while (iterator.hasNext()) { + Row row = iterator.next(); + if (result.length() > 0) { + result.append("\n"); + } + result.append(row.toString()); } + Assert.assertEquals(result.toString(), "[1, a1, 10.0]\n" + "[2, a2, 12.0]\n" + "[3, a3, 15.0]"); + } - private void writeData(String outputFile, String... lines) throws Exception { - FileUtils.writeLines(new File(outputFile), Lists.newArrayList(lines)); - } + private void writeData(String outputFile, String... lines) throws Exception { + FileUtils.writeLines(new File(outputFile), Lists.newArrayList(lines)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/test/java/org/apache/geaflow/dsl/connector/file/FileTableConnectorTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/test/java/org/apache/geaflow/dsl/connector/file/FileTableConnectorTest.java index 02a51e65e..163590236 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/test/java/org/apache/geaflow/dsl/connector/file/FileTableConnectorTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/test/java/org/apache/geaflow/dsl/connector/file/FileTableConnectorTest.java @@ -19,7 +19,6 @@ package org.apache.geaflow.dsl.connector.file; -import com.alibaba.fastjson.JSON; import java.io.BufferedReader; import java.io.File; import java.io.FileInputStream; @@ -27,6 +26,7 @@ import java.io.InputStreamReader; import java.util.List; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ConnectorConfigKeys; @@ -46,78 +46,78 @@ import org.testng.Assert; import org.testng.annotations.Test; -public class FileTableConnectorTest { +import com.alibaba.fastjson.JSON; - @Test - public void testLocalFileResource() throws IOException { - DfsFileReadHandler resource = new DfsFileReadHandler(); - resource.init(new Configuration(), new TableSchema(), "file:///data"); +public class FileTableConnectorTest { - try { - resource.readPartition(new FileSplit("/data", "test"), - new FileOffset(0L), 0); - } catch (Exception e) { - Assert.assertEquals(e.getMessage(), "File /data/test does not exist"); - } - resource.close(); - FileSplit fileSplit = new FileSplit("test_baseDir", "test_relativePath"); - Assert.assertEquals(fileSplit.getName(), "test_relativePath"); - Assert.assertEquals(fileSplit.getPath(), "test_baseDir/test_relativePath"); - FileSplit fileSplit2 = new FileSplit("test_baseDir/test_relativePath"); - Assert.assertEquals(fileSplit.hashCode(), fileSplit2.hashCode()); - Assert.assertEquals(fileSplit, fileSplit2); - Assert.assertEquals(fileSplit, fileSplit); - Assert.assertNotEquals(fileSplit, null); - Assert.assertEquals(fileSplit.toString(), fileSplit2.toString()); + @Test + public void testLocalFileResource() throws IOException { + DfsFileReadHandler resource = new DfsFileReadHandler(); + resource.init(new Configuration(), new TableSchema(), "file:///data"); - FileOffset fileOffset = new FileOffset(1000L); - Assert.assertEquals(fileOffset.getOffset(), 1000L); - Assert.assertEquals(fileOffset.humanReadable(), "1000"); + try { + resource.readPartition(new FileSplit("/data", "test"), new FileOffset(0L), 0); + } catch (Exception e) { + Assert.assertEquals(e.getMessage(), "File /data/test does not exist"); } + resource.close(); + FileSplit fileSplit = new FileSplit("test_baseDir", "test_relativePath"); + Assert.assertEquals(fileSplit.getName(), "test_relativePath"); + Assert.assertEquals(fileSplit.getPath(), "test_baseDir/test_relativePath"); + FileSplit fileSplit2 = new FileSplit("test_baseDir/test_relativePath"); + Assert.assertEquals(fileSplit.hashCode(), fileSplit2.hashCode()); + Assert.assertEquals(fileSplit, fileSplit2); + Assert.assertEquals(fileSplit, fileSplit); + Assert.assertNotEquals(fileSplit, null); + Assert.assertEquals(fileSplit.toString(), fileSplit2.toString()); - @Test - public void testFileTableSink() { - FileTableSink tableSink = new FileTableSink(); - Configuration conf = new Configuration(); - conf.put(ConnectorConfigKeys.GEAFLOW_DSL_FILE_PATH, "test"); - tableSink.init(conf, new StructType(new TableField("test", StringType.INSTANCE, true))); - tableSink.close(); - } + FileOffset fileOffset = new FileOffset(1000L); + Assert.assertEquals(fileOffset.getOffset(), 1000L); + Assert.assertEquals(fileOffset.humanReadable(), "1000"); + } - @Test - public void testFileHandlerReandAndWrite() throws IOException { - String testDir = "/tmp/testDirForHdfsFileWriteHandlerTest"; - FileUtils.deleteDirectory(new File(testDir)); - HdfsFileWriteHandler handler = new HdfsFileWriteHandler(testDir); - Configuration testConf = new Configuration(); - testConf.put(FileConfigKeys.JSON_CONFIG.getKey(), "{\"fs.defaultFS\":\"local\"}"); - handler.init(testConf, StructType.singleValue(StringType.INSTANCE, true), 0); - handler.write("test"); - handler.flush(); - handler.close(); - FileInputStream inputStream = new FileInputStream(testDir + "/partition_0"); - BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream)); - String line = reader.readLine(); - Assert.assertEquals(line, "test"); + @Test + public void testFileTableSink() { + FileTableSink tableSink = new FileTableSink(); + Configuration conf = new Configuration(); + conf.put(ConnectorConfigKeys.GEAFLOW_DSL_FILE_PATH, "test"); + tableSink.init(conf, new StructType(new TableField("test", StringType.INSTANCE, true))); + tableSink.close(); + } + @Test + public void testFileHandlerReandAndWrite() throws IOException { + String testDir = "/tmp/testDirForHdfsFileWriteHandlerTest"; + FileUtils.deleteDirectory(new File(testDir)); + HdfsFileWriteHandler handler = new HdfsFileWriteHandler(testDir); + Configuration testConf = new Configuration(); + testConf.put(FileConfigKeys.JSON_CONFIG.getKey(), "{\"fs.defaultFS\":\"local\"}"); + handler.init(testConf, StructType.singleValue(StringType.INSTANCE, true), 0); + handler.write("test"); + handler.flush(); + handler.close(); + FileInputStream inputStream = new FileInputStream(testDir + "/partition_0"); + BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream)); + String line = reader.readLine(); + Assert.assertEquals(line, "test"); - DfsFileReadHandler readHandler = new DfsFileReadHandler(); - readHandler.init(testConf, new TableSchema(), testDir); - List partitions = readHandler.listPartitions(1); - Assert.assertEquals(partitions.size(), 1); - FetchData fetchData = - readHandler.readPartition((FileSplit) partitions.get(0), new FileOffset(0L), 10); - Assert.assertEquals(fetchData.getDataSize(), 1); - Assert.assertEquals(fetchData.getDataIterator().next(), "test"); - handler.close(); - } + DfsFileReadHandler readHandler = new DfsFileReadHandler(); + readHandler.init(testConf, new TableSchema(), testDir); + List partitions = readHandler.listPartitions(1); + Assert.assertEquals(partitions.size(), 1); + FetchData fetchData = + readHandler.readPartition((FileSplit) partitions.get(0), new FileOffset(0L), 10); + Assert.assertEquals(fetchData.getDataSize(), 1); + Assert.assertEquals(fetchData.getDataIterator().next(), "test"); + handler.close(); + } - @Test - public void testConsoleOffset() { - FileOffset test = new FileOffset(111L); - Map kvMap = JSON.parseObject(new ConsoleOffset(test).toJson(), Map.class); - Assert.assertEquals(kvMap.get("offset"), "111"); - Assert.assertEquals(kvMap.get("type"), "NON_TIMESTAMP"); - Assert.assertTrue(Long.parseLong(kvMap.get("writeTime")) > 0L); - } + @Test + public void testConsoleOffset() { + FileOffset test = new FileOffset(111L); + Map kvMap = JSON.parseObject(new ConsoleOffset(test).toJson(), Map.class); + Assert.assertEquals(kvMap.get("offset"), "111"); + Assert.assertEquals(kvMap.get("type"), "NON_TIMESTAMP"); + Assert.assertTrue(Long.parseLong(kvMap.get("writeTime")) > 0L); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/test/java/org/apache/geaflow/dsl/connector/file/ParquetFormatTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/test/java/org/apache/geaflow/dsl/connector/file/ParquetFormatTest.java index da38546d7..59e5ed911 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/test/java/org/apache/geaflow/dsl/connector/file/ParquetFormatTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-file/src/test/java/org/apache/geaflow/dsl/connector/file/ParquetFormatTest.java @@ -21,7 +21,6 @@ import static java.lang.Thread.sleep; -import com.google.common.collect.Lists; import java.io.File; import java.io.IOException; import java.util.Arrays; @@ -30,6 +29,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; + import org.apache.avro.Schema; import org.apache.avro.generic.GenericRecord; import org.apache.avro.generic.GenericRecordBuilder; @@ -54,91 +54,93 @@ import org.testng.Assert; import org.testng.annotations.Test; +import com.google.common.collect.Lists; + public class ParquetFormatTest { - private static final StructType dataSchema = new StructType( - new TableField("id", Types.INTEGER, false), - new TableField("name", Types.BINARY_STRING), - new TableField("price", Types.DOUBLE) - ); + private static final StructType dataSchema = + new StructType( + new TableField("id", Types.INTEGER, false), + new TableField("name", Types.BINARY_STRING), + new TableField("price", Types.DOUBLE)); - private static final Schema avroSchema = ParquetFormat.convertToAvroSchema(dataSchema, false); + private static final Schema avroSchema = ParquetFormat.convertToAvroSchema(dataSchema, false); - @Test - public void testReadParquet() throws Exception { - String output = "target/test/parquet/output"; + @Test + public void testReadParquet() throws Exception { + String output = "target/test/parquet/output"; - writeData(output, "1,a1,10", "2,a2,12", "3,a3,15"); + writeData(output, "1,a1,10", "2,a2,12", "3,a3,15"); - ParquetFormat format = new ParquetFormat(); - Map config = new HashMap<>(); - File file = new File(output); - List parquetFiles = Arrays.stream(file.listFiles()) + ParquetFormat format = new ParquetFormat(); + Map config = new HashMap<>(); + File file = new File(output); + List parquetFiles = + Arrays.stream(file.listFiles()) .filter(f -> f.getName().endsWith(".parquet")) .collect(Collectors.toList()); - Assert.assertEquals(parquetFiles.size(), 1); - - config.put(ConnectorConfigKeys.GEAFLOW_DSL_FILE_PATH.getKey(), parquetFiles.get(0).getAbsolutePath()); - FileSplit fileSplit = new FileSplit(parquetFiles.get(0).getAbsolutePath()); - format.init(new Configuration(config), new TableSchema(dataSchema), fileSplit); - Iterator iterator = format.batchRead(); - StringBuilder result = new StringBuilder(); - while (iterator.hasNext()) { - Row row = iterator.next(); - if (result.length() > 0) { - result.append("\n"); - } - result.append(row.toString()); - } - Assert.assertEquals(result.toString(), - "[1, a1, 10.0]\n" - + "[2, a2, 12.0]\n" - + "[3, a3, 15.0]"); + Assert.assertEquals(parquetFiles.size(), 1); + + config.put( + ConnectorConfigKeys.GEAFLOW_DSL_FILE_PATH.getKey(), parquetFiles.get(0).getAbsolutePath()); + FileSplit fileSplit = new FileSplit(parquetFiles.get(0).getAbsolutePath()); + format.init(new Configuration(config), new TableSchema(dataSchema), fileSplit); + Iterator iterator = format.batchRead(); + StringBuilder result = new StringBuilder(); + while (iterator.hasNext()) { + Row row = iterator.next(); + if (result.length() > 0) { + result.append("\n"); + } + result.append(row.toString()); } + Assert.assertEquals(result.toString(), "[1, a1, 10.0]\n" + "[2, a2, 12.0]\n" + "[3, a3, 15.0]"); + } - private void writeData(String output, String... lines) throws Exception { - final org.apache.hadoop.conf.Configuration conf = new org.apache.hadoop.conf.Configuration(); + private void writeData(String output, String... lines) throws Exception { + final org.apache.hadoop.conf.Configuration conf = new org.apache.hadoop.conf.Configuration(); - String inputFile = "target/test/parquet/input.txt"; - FileUtils.writeLines(new File(inputFile), Lists.newArrayList(lines)); + String inputFile = "target/test/parquet/input.txt"; + FileUtils.writeLines(new File(inputFile), Lists.newArrayList(lines)); - Path inputPath = new Path(inputFile); - Path outputPath = new Path(output); - final FileSystem fileSystem = inputPath.getFileSystem(conf); - fileSystem.delete(outputPath, true); - final Job job = new Job(conf, "write data to parquet"); + Path inputPath = new Path(inputFile); + Path outputPath = new Path(output); + final FileSystem fileSystem = inputPath.getFileSystem(conf); + fileSystem.delete(outputPath, true); + final Job job = new Job(conf, "write data to parquet"); - TextInputFormat.addInputPath(job, inputPath); - job.setInputFormatClass(TextInputFormat.class); + TextInputFormat.addInputPath(job, inputPath); + job.setInputFormatClass(TextInputFormat.class); - job.setMapperClass(TestMapper.class); - job.setNumReduceTasks(0); + job.setMapperClass(TestMapper.class); + job.setNumReduceTasks(0); - job.setOutputFormatClass(AvroParquetOutputFormat.class); - AvroParquetOutputFormat.setOutputPath(job, outputPath); - AvroParquetOutputFormat.setSchema(job, avroSchema); + job.setOutputFormatClass(AvroParquetOutputFormat.class); + AvroParquetOutputFormat.setOutputPath(job, outputPath); + AvroParquetOutputFormat.setSchema(job, avroSchema); - job.submit(); - while (!job.isComplete()) { - sleep(10); - } - if (!job.isSuccessful()) { - throw new RuntimeException("job failed " + job.getJobName()); - } + job.submit(); + while (!job.isComplete()) { + sleep(10); } - - public static class TestMapper extends Mapper { - - @Override - public void map(LongWritable key, Text value, - Context context) throws IOException, InterruptedException { - String[] fields = value.toString().split(","); - GenericRecord record = new GenericRecordBuilder(avroSchema) - .set("id", Integer.parseInt(fields[0])) - .set("name", fields[1]) - .set("price", Double.parseDouble(fields[2])) - .build(); - context.write(null, record); - } + if (!job.isSuccessful()) { + throw new RuntimeException("job failed " + job.getJobName()); + } + } + + public static class TestMapper extends Mapper { + + @Override + public void map(LongWritable key, Text value, Context context) + throws IOException, InterruptedException { + String[] fields = value.toString().split(","); + GenericRecord record = + new GenericRecordBuilder(avroSchema) + .set("id", Integer.parseInt(fields[0])) + .set("name", fields[1]) + .set("price", Double.parseDouble(fields[2])) + .build(); + context.write(null, record); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/main/java/org/apache/geaflow/dsl/connector/hbase/HBaseConfigKeys.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/main/java/org/apache/geaflow/dsl/connector/hbase/HBaseConfigKeys.java index 4b3d7313e..c6f2f4654 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/main/java/org/apache/geaflow/dsl/connector/hbase/HBaseConfigKeys.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/main/java/org/apache/geaflow/dsl/connector/hbase/HBaseConfigKeys.java @@ -1,66 +1,66 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.geaflow.dsl.connector.hbase; - -import static org.apache.geaflow.dsl.connector.hbase.HBaseConstants.DEFAULT_BUFFER_SIZE; -import static org.apache.geaflow.dsl.connector.hbase.HBaseConstants.DEFAULT_FAMILY_MAPPING; -import static org.apache.geaflow.dsl.connector.hbase.HBaseConstants.DEFAULT_NAMESPACE; -import static org.apache.geaflow.dsl.connector.hbase.HBaseConstants.DEFAULT_SEPARATOR; - -import org.apache.geaflow.common.config.ConfigKey; -import org.apache.geaflow.common.config.ConfigKeys; - -public class HBaseConfigKeys { - - public static final ConfigKey GEAFLOW_DSL_HBASE_ZOOKEEPER_QUORUM = ConfigKeys - .key("geaflow.dsl.hbase.zookeeper.quorum") - .noDefaultValue() - .description("HBase zookeeper quorum servers list."); - - public static final ConfigKey GEAFLOW_DSL_HBASE_NAME_SPACE = ConfigKeys - .key("geaflow.dsl.hbase.namespace") - .defaultValue(DEFAULT_NAMESPACE) - .description("HBase namespace."); - - public static final ConfigKey GEAFLOW_DSL_HBASE_TABLE_NAME = ConfigKeys - .key("geaflow.dsl.hbase.tablename") - .noDefaultValue() - .description("HBase table name."); - - public static final ConfigKey GEAFLOW_DSL_HBASE_ROWKEY_COLUMNS = ConfigKeys - .key("geaflow.dsl.hbase.rowkey.column") - .noDefaultValue() - .description("HBase rowkey columns."); - - public static final ConfigKey GEAFLOW_DSL_HBASE_ROWKEY_SEPARATOR = ConfigKeys - .key("geaflow.dsl.hbase.rowkey.separator") - .defaultValue(DEFAULT_SEPARATOR) - .description("HBase rowkey join serapator."); - - public static final ConfigKey GEAFLOW_DSL_HBASE_FAMILY_NAME = ConfigKeys - .key("geaflow.dsl.hbase.familyname.mapping") - .defaultValue(DEFAULT_FAMILY_MAPPING) - .description("HBase column family name mapping."); - - public static final ConfigKey GEAFLOW_DSL_HBASE_BUFFER_SIZE = ConfigKeys - .key("geaflow.dsl.hbase.buffersize") - .defaultValue(DEFAULT_BUFFER_SIZE) - .description("HBase writer buffer size."); -} +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.dsl.connector.hbase; + +import static org.apache.geaflow.dsl.connector.hbase.HBaseConstants.DEFAULT_BUFFER_SIZE; +import static org.apache.geaflow.dsl.connector.hbase.HBaseConstants.DEFAULT_FAMILY_MAPPING; +import static org.apache.geaflow.dsl.connector.hbase.HBaseConstants.DEFAULT_NAMESPACE; +import static org.apache.geaflow.dsl.connector.hbase.HBaseConstants.DEFAULT_SEPARATOR; + +import org.apache.geaflow.common.config.ConfigKey; +import org.apache.geaflow.common.config.ConfigKeys; + +public class HBaseConfigKeys { + + public static final ConfigKey GEAFLOW_DSL_HBASE_ZOOKEEPER_QUORUM = + ConfigKeys.key("geaflow.dsl.hbase.zookeeper.quorum") + .noDefaultValue() + .description("HBase zookeeper quorum servers list."); + + public static final ConfigKey GEAFLOW_DSL_HBASE_NAME_SPACE = + ConfigKeys.key("geaflow.dsl.hbase.namespace") + .defaultValue(DEFAULT_NAMESPACE) + .description("HBase namespace."); + + public static final ConfigKey GEAFLOW_DSL_HBASE_TABLE_NAME = + ConfigKeys.key("geaflow.dsl.hbase.tablename") + .noDefaultValue() + .description("HBase table name."); + + public static final ConfigKey GEAFLOW_DSL_HBASE_ROWKEY_COLUMNS = + ConfigKeys.key("geaflow.dsl.hbase.rowkey.column") + .noDefaultValue() + .description("HBase rowkey columns."); + + public static final ConfigKey GEAFLOW_DSL_HBASE_ROWKEY_SEPARATOR = + ConfigKeys.key("geaflow.dsl.hbase.rowkey.separator") + .defaultValue(DEFAULT_SEPARATOR) + .description("HBase rowkey join serapator."); + + public static final ConfigKey GEAFLOW_DSL_HBASE_FAMILY_NAME = + ConfigKeys.key("geaflow.dsl.hbase.familyname.mapping") + .defaultValue(DEFAULT_FAMILY_MAPPING) + .description("HBase column family name mapping."); + + public static final ConfigKey GEAFLOW_DSL_HBASE_BUFFER_SIZE = + ConfigKeys.key("geaflow.dsl.hbase.buffersize") + .defaultValue(DEFAULT_BUFFER_SIZE) + .description("HBase writer buffer size."); +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/main/java/org/apache/geaflow/dsl/connector/hbase/HBaseConstants.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/main/java/org/apache/geaflow/dsl/connector/hbase/HBaseConstants.java index 63952a26d..e44f0ceac 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/main/java/org/apache/geaflow/dsl/connector/hbase/HBaseConstants.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/main/java/org/apache/geaflow/dsl/connector/hbase/HBaseConstants.java @@ -21,13 +21,13 @@ public class HBaseConstants { - public static final int DEFAULT_BUFFER_SIZE = 1024 * 1024; + public static final int DEFAULT_BUFFER_SIZE = 1024 * 1024; - public static final String DEFAULT_NAMESPACE = "default"; + public static final String DEFAULT_NAMESPACE = "default"; - public static final String DEFAULT_COLUMN_FAMILY = "GeaFlow"; + public static final String DEFAULT_COLUMN_FAMILY = "GeaFlow"; - public static final String DEFAULT_SEPARATOR = ","; + public static final String DEFAULT_SEPARATOR = ","; - public static final String DEFAULT_FAMILY_MAPPING = "{}"; + public static final String DEFAULT_FAMILY_MAPPING = "{}"; } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/main/java/org/apache/geaflow/dsl/connector/hbase/HBaseTableConnector.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/main/java/org/apache/geaflow/dsl/connector/hbase/HBaseTableConnector.java index 2cc241632..233b1f200 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/main/java/org/apache/geaflow/dsl/connector/hbase/HBaseTableConnector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/main/java/org/apache/geaflow/dsl/connector/hbase/HBaseTableConnector.java @@ -25,15 +25,15 @@ public class HBaseTableConnector implements TableWritableConnector { - public static final String TYPE = "HBase"; + public static final String TYPE = "HBase"; - @Override - public String getType() { - return TYPE; - } + @Override + public String getType() { + return TYPE; + } - @Override - public TableSink createSink(Configuration conf) { - return new HBaseTableSink(); - } + @Override + public TableSink createSink(Configuration conf) { + return new HBaseTableSink(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/main/java/org/apache/geaflow/dsl/connector/hbase/HBaseTableSink.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/main/java/org/apache/geaflow/dsl/connector/hbase/HBaseTableSink.java index 8623e843c..c59d46261 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/main/java/org/apache/geaflow/dsl/connector/hbase/HBaseTableSink.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/main/java/org/apache/geaflow/dsl/connector/hbase/HBaseTableSink.java @@ -1,193 +1,192 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.geaflow.dsl.connector.hbase; - -import static org.apache.geaflow.dsl.connector.hbase.HBaseConstants.DEFAULT_COLUMN_FAMILY; - -import com.google.common.collect.Lists; -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import org.apache.geaflow.api.context.RuntimeContext; -import org.apache.geaflow.common.config.Configuration; -import org.apache.geaflow.common.type.IType; -import org.apache.geaflow.common.type.Types; -import org.apache.geaflow.dsl.common.data.Row; -import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; -import org.apache.geaflow.dsl.common.types.StructType; -import org.apache.geaflow.dsl.connector.api.TableSink; -import org.apache.geaflow.utils.JsonUtils; -import org.apache.hadoop.hbase.TableName; -import org.apache.hadoop.hbase.client.BufferedMutator; -import org.apache.hadoop.hbase.client.BufferedMutatorParams; -import org.apache.hadoop.hbase.client.Connection; -import org.apache.hadoop.hbase.client.ConnectionFactory; -import org.apache.hadoop.hbase.client.Put; -import org.apache.hadoop.hbase.util.Bytes; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class HBaseTableSink implements TableSink { - - private static final Logger LOGGER = LoggerFactory.getLogger(HBaseTableSink.class); - - private StructType schema; - - private String zookeeperQuorum; - - private String namespace; - - private String tableName; - - private Set rowKeyColumns; - - private String separator; - - private Map familyNamesMap; - - private int bufferSize; - - private Connection connection; - - private BufferedMutator mutator; - - @Override - public void init(Configuration tableConf, StructType schema) { - LOGGER.info("Prepare with config: {}, \n schema: {}", tableConf, schema); - this.schema = schema; - - this.zookeeperQuorum = - tableConf.getString(HBaseConfigKeys.GEAFLOW_DSL_HBASE_ZOOKEEPER_QUORUM); - this.tableName = tableConf.getString(HBaseConfigKeys.GEAFLOW_DSL_HBASE_TABLE_NAME); - this.namespace = tableConf.getString(HBaseConfigKeys.GEAFLOW_DSL_HBASE_NAME_SPACE); - String rowKeys = tableConf.getString(HBaseConfigKeys.GEAFLOW_DSL_HBASE_ROWKEY_COLUMNS); - this.rowKeyColumns = new HashSet<>(Arrays.asList(rowKeys.split("\\s*,\\s*"))); - this.separator = tableConf.getString(HBaseConfigKeys.GEAFLOW_DSL_HBASE_ROWKEY_SEPARATOR); - String familyNameMapping = tableConf.getString( - HBaseConfigKeys.GEAFLOW_DSL_HBASE_FAMILY_NAME); - this.familyNamesMap = JsonUtils.parseJson2map(familyNameMapping); - this.bufferSize = tableConf.getInteger(HBaseConfigKeys.GEAFLOW_DSL_HBASE_BUFFER_SIZE); - } - - @Override - public void open(RuntimeContext context) { - org.apache.hadoop.conf.Configuration conf = new org.apache.hadoop.conf.Configuration(); - conf.set("hbase.zookeeper.quorum", zookeeperQuorum); - try { - connection = ConnectionFactory.createConnection(conf); - BufferedMutatorParams bufferedMutatorParams = new BufferedMutatorParams( - TableName.valueOf(namespace, tableName)) - .writeBufferSize(bufferSize); - mutator = connection.getBufferedMutator(bufferedMutatorParams); - } catch (IOException e) { - throw new GeaFlowDSLException("Can not get connection from hbase"); - } - } - - @Override - public void write(Row row) throws IOException { - Put put = row2Put(row); - mutator.mutate(put); - } - - @Override - public void finish() throws IOException { - mutator.flush(); - } - - @Override - public void close() { - try { - if (Objects.nonNull(this.mutator)) { - mutator.close(); - } - if (Objects.nonNull(this.connection)) { - connection.close(); - } - } catch (IOException e) { - throw new GeaFlowDSLException("Fail to close resources."); - } - } - - private Put row2Put(Row row) { - byte[] rowKey = buildRowKey(row); - Put put = new Put(rowKey); - List fieldNames = this.schema.getFieldNames(); - IType[] types = this.schema.getTypes(); - for (int i = 0; i < fieldNames.size(); i++) { - if (rowKeyColumns.contains(fieldNames.get(i))) { - continue; - } - String fieldName = fieldNames.get(i); - String familyName = familyNamesMap.getOrDefault(fieldName, DEFAULT_COLUMN_FAMILY); - byte[] values = convertColumnToBytes(row, types, i); - put.addColumn(Bytes.toBytes(familyName), Bytes.toBytes(fieldName), values); - } - return put; - } - - private byte[] buildRowKey(Row row) { - List fieldNames = this.schema.getFieldNames(); - IType[] types = this.schema.getTypes(); - List rowKeyValues = Lists.newArrayList(); - for (int i = 0; i < fieldNames.size(); i++) { - if (rowKeyColumns.contains(fieldNames.get(i))) { - rowKeyValues.add(row.getField(i, types[i]).toString()); - } - } - return Bytes.toBytes(String.join(separator, rowKeyValues)); - } - - private byte[] convertColumnToBytes(Row row, IType[] types, int idx) { - Object field = row.getField(idx, types[idx]); - if (Objects.isNull(field)) { - return null; - } - String typeName = types[idx].getName(); - switch (typeName) { - case Types.TYPE_NAME_BYTE: - return Bytes.toBytes((Byte) field); - case Types.TYPE_NAME_SHORT: - return Bytes.toBytes((Short) field); - case Types.TYPE_NAME_INTEGER: - return Bytes.toBytes((Integer) field); - case Types.TYPE_NAME_LONG: - return Bytes.toBytes((Long) field); - case Types.TYPE_NAME_BOOLEAN: - return Bytes.toBytes((Boolean) field); - case Types.TYPE_NAME_FLOAT: - return Bytes.toBytes((Float) field); - case Types.TYPE_NAME_DOUBLE: - return Bytes.toBytes((Double) field); - case Types.TYPE_NAME_STRING: - return field.toString().getBytes(StandardCharsets.UTF_8); - case Types.TYPE_NAME_BINARY_STRING: - return field.toString().getBytes(); - default: - throw new GeaFlowDSLException(String.format("Type: %s is not supported.", - typeName)); - } - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.dsl.connector.hbase; + +import static org.apache.geaflow.dsl.connector.hbase.HBaseConstants.DEFAULT_COLUMN_FAMILY; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +import org.apache.geaflow.api.context.RuntimeContext; +import org.apache.geaflow.common.config.Configuration; +import org.apache.geaflow.common.type.IType; +import org.apache.geaflow.common.type.Types; +import org.apache.geaflow.dsl.common.data.Row; +import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; +import org.apache.geaflow.dsl.common.types.StructType; +import org.apache.geaflow.dsl.connector.api.TableSink; +import org.apache.geaflow.utils.JsonUtils; +import org.apache.hadoop.hbase.TableName; +import org.apache.hadoop.hbase.client.BufferedMutator; +import org.apache.hadoop.hbase.client.BufferedMutatorParams; +import org.apache.hadoop.hbase.client.Connection; +import org.apache.hadoop.hbase.client.ConnectionFactory; +import org.apache.hadoop.hbase.client.Put; +import org.apache.hadoop.hbase.util.Bytes; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.collect.Lists; + +public class HBaseTableSink implements TableSink { + + private static final Logger LOGGER = LoggerFactory.getLogger(HBaseTableSink.class); + + private StructType schema; + + private String zookeeperQuorum; + + private String namespace; + + private String tableName; + + private Set rowKeyColumns; + + private String separator; + + private Map familyNamesMap; + + private int bufferSize; + + private Connection connection; + + private BufferedMutator mutator; + + @Override + public void init(Configuration tableConf, StructType schema) { + LOGGER.info("Prepare with config: {}, \n schema: {}", tableConf, schema); + this.schema = schema; + + this.zookeeperQuorum = tableConf.getString(HBaseConfigKeys.GEAFLOW_DSL_HBASE_ZOOKEEPER_QUORUM); + this.tableName = tableConf.getString(HBaseConfigKeys.GEAFLOW_DSL_HBASE_TABLE_NAME); + this.namespace = tableConf.getString(HBaseConfigKeys.GEAFLOW_DSL_HBASE_NAME_SPACE); + String rowKeys = tableConf.getString(HBaseConfigKeys.GEAFLOW_DSL_HBASE_ROWKEY_COLUMNS); + this.rowKeyColumns = new HashSet<>(Arrays.asList(rowKeys.split("\\s*,\\s*"))); + this.separator = tableConf.getString(HBaseConfigKeys.GEAFLOW_DSL_HBASE_ROWKEY_SEPARATOR); + String familyNameMapping = tableConf.getString(HBaseConfigKeys.GEAFLOW_DSL_HBASE_FAMILY_NAME); + this.familyNamesMap = JsonUtils.parseJson2map(familyNameMapping); + this.bufferSize = tableConf.getInteger(HBaseConfigKeys.GEAFLOW_DSL_HBASE_BUFFER_SIZE); + } + + @Override + public void open(RuntimeContext context) { + org.apache.hadoop.conf.Configuration conf = new org.apache.hadoop.conf.Configuration(); + conf.set("hbase.zookeeper.quorum", zookeeperQuorum); + try { + connection = ConnectionFactory.createConnection(conf); + BufferedMutatorParams bufferedMutatorParams = + new BufferedMutatorParams(TableName.valueOf(namespace, tableName)) + .writeBufferSize(bufferSize); + mutator = connection.getBufferedMutator(bufferedMutatorParams); + } catch (IOException e) { + throw new GeaFlowDSLException("Can not get connection from hbase"); + } + } + + @Override + public void write(Row row) throws IOException { + Put put = row2Put(row); + mutator.mutate(put); + } + + @Override + public void finish() throws IOException { + mutator.flush(); + } + + @Override + public void close() { + try { + if (Objects.nonNull(this.mutator)) { + mutator.close(); + } + if (Objects.nonNull(this.connection)) { + connection.close(); + } + } catch (IOException e) { + throw new GeaFlowDSLException("Fail to close resources."); + } + } + + private Put row2Put(Row row) { + byte[] rowKey = buildRowKey(row); + Put put = new Put(rowKey); + List fieldNames = this.schema.getFieldNames(); + IType[] types = this.schema.getTypes(); + for (int i = 0; i < fieldNames.size(); i++) { + if (rowKeyColumns.contains(fieldNames.get(i))) { + continue; + } + String fieldName = fieldNames.get(i); + String familyName = familyNamesMap.getOrDefault(fieldName, DEFAULT_COLUMN_FAMILY); + byte[] values = convertColumnToBytes(row, types, i); + put.addColumn(Bytes.toBytes(familyName), Bytes.toBytes(fieldName), values); + } + return put; + } + + private byte[] buildRowKey(Row row) { + List fieldNames = this.schema.getFieldNames(); + IType[] types = this.schema.getTypes(); + List rowKeyValues = Lists.newArrayList(); + for (int i = 0; i < fieldNames.size(); i++) { + if (rowKeyColumns.contains(fieldNames.get(i))) { + rowKeyValues.add(row.getField(i, types[i]).toString()); + } + } + return Bytes.toBytes(String.join(separator, rowKeyValues)); + } + + private byte[] convertColumnToBytes(Row row, IType[] types, int idx) { + Object field = row.getField(idx, types[idx]); + if (Objects.isNull(field)) { + return null; + } + String typeName = types[idx].getName(); + switch (typeName) { + case Types.TYPE_NAME_BYTE: + return Bytes.toBytes((Byte) field); + case Types.TYPE_NAME_SHORT: + return Bytes.toBytes((Short) field); + case Types.TYPE_NAME_INTEGER: + return Bytes.toBytes((Integer) field); + case Types.TYPE_NAME_LONG: + return Bytes.toBytes((Long) field); + case Types.TYPE_NAME_BOOLEAN: + return Bytes.toBytes((Boolean) field); + case Types.TYPE_NAME_FLOAT: + return Bytes.toBytes((Float) field); + case Types.TYPE_NAME_DOUBLE: + return Bytes.toBytes((Double) field); + case Types.TYPE_NAME_STRING: + return field.toString().getBytes(StandardCharsets.UTF_8); + case Types.TYPE_NAME_BINARY_STRING: + return field.toString().getBytes(); + default: + throw new GeaFlowDSLException(String.format("Type: %s is not supported.", typeName)); + } + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/test/java/org/apache/geaflow/dsl/connector/hbase/HBaseConnectorTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/test/java/org/apache/geaflow/dsl/connector/hbase/HBaseConnectorTest.java index 528865dcb..c8b230ff4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/test/java/org/apache/geaflow/dsl/connector/hbase/HBaseConnectorTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/test/java/org/apache/geaflow/dsl/connector/hbase/HBaseConnectorTest.java @@ -1,216 +1,276 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.geaflow.dsl.connector.hbase; - -import com.google.common.collect.Lists; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.HashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Objects; -import org.apache.geaflow.common.config.Configuration; -import org.apache.geaflow.common.type.Types; -import org.apache.geaflow.dsl.common.data.Row; -import org.apache.geaflow.dsl.common.data.impl.ObjectRow; -import org.apache.geaflow.dsl.common.types.StructType; -import org.apache.geaflow.dsl.common.types.TableField; -import org.apache.geaflow.dsl.common.types.TableSchema; -import org.apache.geaflow.dsl.connector.api.TableConnector; -import org.apache.geaflow.dsl.connector.api.TableSink; -import org.apache.geaflow.dsl.connector.api.TableWritableConnector; -import org.apache.geaflow.dsl.connector.api.util.ConnectorFactory; -import org.apache.geaflow.runtime.core.context.DefaultRuntimeContext; -import org.apache.hadoop.hbase.client.Result; -import org.apache.hadoop.hbase.client.ResultScanner; -import org.apache.hadoop.hbase.util.Bytes; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.testng.Assert; -import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.Test; - -public class HBaseConnectorTest { - - private static final Logger LOG = LoggerFactory.getLogger(HBaseConnectorTest.class); - - private static String namespace = "TuGraph"; - - public static final String zookeeperQuorum = "127.0.0.1"; - - public static final String tableType = "HBase"; - - public static final String tmpDataDir = "/tmp/GeaFlow-HBase-Sink-Connector"; - - private final StructType dataSchema = new StructType( - new TableField("id", Types.INTEGER, false), - new TableField("name", Types.BINARY_STRING), - new TableField("price", Types.DOUBLE), - new TableField("weight", Types.LONG) - ); - - Object[][] items = { - {1, "a1", 10.11, 12L}, - {2, "a2", 12.22, 10000000L}, - {3, "a3", 13.33, 1237879479832L}, - {4, "a4", 14.44, 34978947328979L}, - {5, "a5", 25.67, 98302183091830190L} - }; - - private final TableSchema tableSchema = new TableSchema(dataSchema); - - @BeforeClass - public static void setup() throws IOException { - System.setProperty("test.build.data.basedirectory", tmpDataDir); - HBaseLocalTestUtils.createNamespace(namespace); - } - - @AfterClass - public static void tearDown() { - HBaseLocalTestUtils.closeConnection(); - } - - @Test - public void testLoadConnector() { - TableConnector tableConnector = ConnectorFactory.loadConnector(tableType); - Assert.assertEquals(tableConnector.getType().toLowerCase(Locale.ROOT), "hbase"); - } - - private void prepare(String tableName, Map tableConfMap, - String... columnFamilies) throws IOException { - if (!tableConfMap.containsKey(HBaseConfigKeys.GEAFLOW_DSL_HBASE_NAME_SPACE.getKey())) { - namespace = "default"; - } - HBaseLocalTestUtils.createTable(namespace, tableName, columnFamilies); - - - TableConnector tableConnector = ConnectorFactory.loadConnector(tableType); - TableWritableConnector readableConnector = (TableWritableConnector) tableConnector; - Configuration tableConf = new Configuration(tableConfMap); - TableSink tableSink = readableConnector.createSink(tableConf); - tableSink.init(tableConf, tableSchema); - tableSink.open(new DefaultRuntimeContext(tableConf)); - - for (Object[] item : items) { - Row row = ObjectRow.create(item); - tableSink.write(row); - } - tableSink.finish(); - - List results = Lists.newArrayList(); - ResultScanner scanner = HBaseLocalTestUtils.getScanner(namespace, tableName); - for (Result result : scanner) { - results.add(result); - } - Assert.assertEquals(results.size(), 5); - } - - @Test - public void testWriteBase() throws IOException { - String tableName = "GeaFlowBase"; - Map tableConfMap = buildConfiguration(zookeeperQuorum, namespace, tableName, - "id", null, "{\"name\": " - + "\"A\", \"price\": \"A\", \"weight\": \"B\"}"); - prepare(tableName, tableConfMap, "A", "B"); - for (int i = 0; i < 5; i++) { - Assert.assertEquals(Bytes.toString(HBaseLocalTestUtils.getCell(namespace, tableName, - String.valueOf(i + 1), "A", "name")), items[i][1]); - Assert.assertEquals(ByteBuffer.wrap(HBaseLocalTestUtils.getCell(namespace, tableName, - String.valueOf(i + 1), "A", "price")).getDouble(), items[i][2]); - Assert.assertEquals(ByteBuffer.wrap(HBaseLocalTestUtils.getCell(namespace, tableName, - String.valueOf(i + 1), "B", "weight")).getLong(), items[i][3]); - } - } - - @Test - public void testWriteRowKey() throws IOException { - String tableName = "GeaFlowRowKey"; - Map tableConfMap = buildConfiguration(zookeeperQuorum, namespace, tableName, - "id,name", "-", "{\"price\": \"A\", \"weight\": \"A\"}"); - prepare("GeaFlowRowKey", tableConfMap, "A"); - - for (int i = 0; i < 5; i++) { - Assert.assertEquals(ByteBuffer.wrap(HBaseLocalTestUtils.getCell(namespace, tableName, - (i + 1) + "-" + items[i][1], "A", "price")).getDouble(), items[i][2]); - Assert.assertEquals(ByteBuffer.wrap(HBaseLocalTestUtils.getCell(namespace, tableName, - (i + 1) + "-" + items[i][1], "A", "weight")).getLong(), items[i][3]); - } - } - - @Test - public void testWriteDefaultColumnFamily() throws IOException { - String tableName = "GeaFlowDefaultColumnFamily"; - Map tableConfMap = buildConfiguration(zookeeperQuorum, namespace, tableName, - "id", null, null); - // default column family is "GeaFlow" - prepare(tableName, tableConfMap, "A", "GeaFlow"); - - for (int i = 0; i < 5; i++) { - Assert.assertEquals(Bytes.toString(HBaseLocalTestUtils.getCell(namespace, tableName, - String.valueOf(i + 1), "GeaFlow", "name")), items[i][1]); - Assert.assertEquals(ByteBuffer.wrap(HBaseLocalTestUtils.getCell(namespace, tableName, - String.valueOf(i + 1), "GeaFlow", "price")).getDouble(), items[i][2]); - Assert.assertEquals(ByteBuffer.wrap(HBaseLocalTestUtils.getCell(namespace, tableName, - String.valueOf(i + 1), "GeaFlow", "weight")).getLong(), items[i][3]); - } - } - - @Test - public void testWriteDefaultNamespace() throws IOException { - String tableName = "GeaFlowDefaultNamespace"; - // default namespace is "default" - Map tableConfMap = buildConfiguration(zookeeperQuorum, null, tableName, - "id", null, "{\"name\": " - + "\"A\", \"price\": \"A\", \"weight\": \"B\"}"); - prepare(tableName, tableConfMap, "A", "B"); - for (int i = 0; i < 5; i++) { - Assert.assertEquals(Bytes.toString(HBaseLocalTestUtils.getCell("default", tableName, - String.valueOf(i + 1), "A", "name")), items[i][1]); - Assert.assertEquals(ByteBuffer.wrap(HBaseLocalTestUtils.getCell("default", tableName, - String.valueOf(i + 1), "A", "price")).getDouble(), items[i][2]); - Assert.assertEquals(ByteBuffer.wrap(HBaseLocalTestUtils.getCell("default", tableName, - String.valueOf(i + 1), "B", "weight")).getLong(), items[i][3]); - } - } - - private Map buildConfiguration(String zkQuorum, String namespace, - String tableName, String rowKeyColumn, - String rowKeySeparator, - String familyNameMapping) { - Map tableConfMap = new HashMap<>(); - tableConfMap.put(HBaseConfigKeys.GEAFLOW_DSL_HBASE_ZOOKEEPER_QUORUM.getKey(), zkQuorum); - if (Objects.nonNull(namespace)) { - tableConfMap.put(HBaseConfigKeys.GEAFLOW_DSL_HBASE_NAME_SPACE.getKey(), namespace); - } - tableConfMap.put(HBaseConfigKeys.GEAFLOW_DSL_HBASE_TABLE_NAME.getKey(), tableName); - tableConfMap.put(HBaseConfigKeys.GEAFLOW_DSL_HBASE_ROWKEY_COLUMNS.getKey(), rowKeyColumn); - if (Objects.nonNull(rowKeySeparator)) { - tableConfMap.put(HBaseConfigKeys.GEAFLOW_DSL_HBASE_ROWKEY_SEPARATOR.getKey(), - rowKeySeparator); - } - if (Objects.nonNull(familyNameMapping)) { - tableConfMap.put(HBaseConfigKeys.GEAFLOW_DSL_HBASE_FAMILY_NAME.getKey(), - familyNameMapping); - } - return tableConfMap; - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.dsl.connector.hbase; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; + +import org.apache.geaflow.common.config.Configuration; +import org.apache.geaflow.common.type.Types; +import org.apache.geaflow.dsl.common.data.Row; +import org.apache.geaflow.dsl.common.data.impl.ObjectRow; +import org.apache.geaflow.dsl.common.types.StructType; +import org.apache.geaflow.dsl.common.types.TableField; +import org.apache.geaflow.dsl.common.types.TableSchema; +import org.apache.geaflow.dsl.connector.api.TableConnector; +import org.apache.geaflow.dsl.connector.api.TableSink; +import org.apache.geaflow.dsl.connector.api.TableWritableConnector; +import org.apache.geaflow.dsl.connector.api.util.ConnectorFactory; +import org.apache.geaflow.runtime.core.context.DefaultRuntimeContext; +import org.apache.hadoop.hbase.client.Result; +import org.apache.hadoop.hbase.client.ResultScanner; +import org.apache.hadoop.hbase.util.Bytes; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import com.google.common.collect.Lists; + +public class HBaseConnectorTest { + + private static final Logger LOG = LoggerFactory.getLogger(HBaseConnectorTest.class); + + private static String namespace = "TuGraph"; + + public static final String zookeeperQuorum = "127.0.0.1"; + + public static final String tableType = "HBase"; + + public static final String tmpDataDir = "/tmp/GeaFlow-HBase-Sink-Connector"; + + private final StructType dataSchema = + new StructType( + new TableField("id", Types.INTEGER, false), + new TableField("name", Types.BINARY_STRING), + new TableField("price", Types.DOUBLE), + new TableField("weight", Types.LONG)); + + Object[][] items = { + {1, "a1", 10.11, 12L}, + {2, "a2", 12.22, 10000000L}, + {3, "a3", 13.33, 1237879479832L}, + {4, "a4", 14.44, 34978947328979L}, + {5, "a5", 25.67, 98302183091830190L} + }; + + private final TableSchema tableSchema = new TableSchema(dataSchema); + + @BeforeClass + public static void setup() throws IOException { + System.setProperty("test.build.data.basedirectory", tmpDataDir); + HBaseLocalTestUtils.createNamespace(namespace); + } + + @AfterClass + public static void tearDown() { + HBaseLocalTestUtils.closeConnection(); + } + + @Test + public void testLoadConnector() { + TableConnector tableConnector = ConnectorFactory.loadConnector(tableType); + Assert.assertEquals(tableConnector.getType().toLowerCase(Locale.ROOT), "hbase"); + } + + private void prepare(String tableName, Map tableConfMap, String... columnFamilies) + throws IOException { + if (!tableConfMap.containsKey(HBaseConfigKeys.GEAFLOW_DSL_HBASE_NAME_SPACE.getKey())) { + namespace = "default"; + } + HBaseLocalTestUtils.createTable(namespace, tableName, columnFamilies); + + TableConnector tableConnector = ConnectorFactory.loadConnector(tableType); + TableWritableConnector readableConnector = (TableWritableConnector) tableConnector; + Configuration tableConf = new Configuration(tableConfMap); + TableSink tableSink = readableConnector.createSink(tableConf); + tableSink.init(tableConf, tableSchema); + tableSink.open(new DefaultRuntimeContext(tableConf)); + + for (Object[] item : items) { + Row row = ObjectRow.create(item); + tableSink.write(row); + } + tableSink.finish(); + + List results = Lists.newArrayList(); + ResultScanner scanner = HBaseLocalTestUtils.getScanner(namespace, tableName); + for (Result result : scanner) { + results.add(result); + } + Assert.assertEquals(results.size(), 5); + } + + @Test + public void testWriteBase() throws IOException { + String tableName = "GeaFlowBase"; + Map tableConfMap = + buildConfiguration( + zookeeperQuorum, + namespace, + tableName, + "id", + null, + "{\"name\": " + "\"A\", \"price\": \"A\", \"weight\": \"B\"}"); + prepare(tableName, tableConfMap, "A", "B"); + for (int i = 0; i < 5; i++) { + Assert.assertEquals( + Bytes.toString( + HBaseLocalTestUtils.getCell( + namespace, tableName, String.valueOf(i + 1), "A", "name")), + items[i][1]); + Assert.assertEquals( + ByteBuffer.wrap( + HBaseLocalTestUtils.getCell( + namespace, tableName, String.valueOf(i + 1), "A", "price")) + .getDouble(), + items[i][2]); + Assert.assertEquals( + ByteBuffer.wrap( + HBaseLocalTestUtils.getCell( + namespace, tableName, String.valueOf(i + 1), "B", "weight")) + .getLong(), + items[i][3]); + } + } + + @Test + public void testWriteRowKey() throws IOException { + String tableName = "GeaFlowRowKey"; + Map tableConfMap = + buildConfiguration( + zookeeperQuorum, + namespace, + tableName, + "id,name", + "-", + "{\"price\": \"A\", \"weight\": \"A\"}"); + prepare("GeaFlowRowKey", tableConfMap, "A"); + + for (int i = 0; i < 5; i++) { + Assert.assertEquals( + ByteBuffer.wrap( + HBaseLocalTestUtils.getCell( + namespace, tableName, (i + 1) + "-" + items[i][1], "A", "price")) + .getDouble(), + items[i][2]); + Assert.assertEquals( + ByteBuffer.wrap( + HBaseLocalTestUtils.getCell( + namespace, tableName, (i + 1) + "-" + items[i][1], "A", "weight")) + .getLong(), + items[i][3]); + } + } + + @Test + public void testWriteDefaultColumnFamily() throws IOException { + String tableName = "GeaFlowDefaultColumnFamily"; + Map tableConfMap = + buildConfiguration(zookeeperQuorum, namespace, tableName, "id", null, null); + // default column family is "GeaFlow" + prepare(tableName, tableConfMap, "A", "GeaFlow"); + + for (int i = 0; i < 5; i++) { + Assert.assertEquals( + Bytes.toString( + HBaseLocalTestUtils.getCell( + namespace, tableName, String.valueOf(i + 1), "GeaFlow", "name")), + items[i][1]); + Assert.assertEquals( + ByteBuffer.wrap( + HBaseLocalTestUtils.getCell( + namespace, tableName, String.valueOf(i + 1), "GeaFlow", "price")) + .getDouble(), + items[i][2]); + Assert.assertEquals( + ByteBuffer.wrap( + HBaseLocalTestUtils.getCell( + namespace, tableName, String.valueOf(i + 1), "GeaFlow", "weight")) + .getLong(), + items[i][3]); + } + } + + @Test + public void testWriteDefaultNamespace() throws IOException { + String tableName = "GeaFlowDefaultNamespace"; + // default namespace is "default" + Map tableConfMap = + buildConfiguration( + zookeeperQuorum, + null, + tableName, + "id", + null, + "{\"name\": " + "\"A\", \"price\": \"A\", \"weight\": \"B\"}"); + prepare(tableName, tableConfMap, "A", "B"); + for (int i = 0; i < 5; i++) { + Assert.assertEquals( + Bytes.toString( + HBaseLocalTestUtils.getCell( + "default", tableName, String.valueOf(i + 1), "A", "name")), + items[i][1]); + Assert.assertEquals( + ByteBuffer.wrap( + HBaseLocalTestUtils.getCell( + "default", tableName, String.valueOf(i + 1), "A", "price")) + .getDouble(), + items[i][2]); + Assert.assertEquals( + ByteBuffer.wrap( + HBaseLocalTestUtils.getCell( + "default", tableName, String.valueOf(i + 1), "B", "weight")) + .getLong(), + items[i][3]); + } + } + + private Map buildConfiguration( + String zkQuorum, + String namespace, + String tableName, + String rowKeyColumn, + String rowKeySeparator, + String familyNameMapping) { + Map tableConfMap = new HashMap<>(); + tableConfMap.put(HBaseConfigKeys.GEAFLOW_DSL_HBASE_ZOOKEEPER_QUORUM.getKey(), zkQuorum); + if (Objects.nonNull(namespace)) { + tableConfMap.put(HBaseConfigKeys.GEAFLOW_DSL_HBASE_NAME_SPACE.getKey(), namespace); + } + tableConfMap.put(HBaseConfigKeys.GEAFLOW_DSL_HBASE_TABLE_NAME.getKey(), tableName); + tableConfMap.put(HBaseConfigKeys.GEAFLOW_DSL_HBASE_ROWKEY_COLUMNS.getKey(), rowKeyColumn); + if (Objects.nonNull(rowKeySeparator)) { + tableConfMap.put( + HBaseConfigKeys.GEAFLOW_DSL_HBASE_ROWKEY_SEPARATOR.getKey(), rowKeySeparator); + } + if (Objects.nonNull(familyNameMapping)) { + tableConfMap.put(HBaseConfigKeys.GEAFLOW_DSL_HBASE_FAMILY_NAME.getKey(), familyNameMapping); + } + return tableConfMap; + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/test/java/org/apache/geaflow/dsl/connector/hbase/HBaseLocalTestUtils.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/test/java/org/apache/geaflow/dsl/connector/hbase/HBaseLocalTestUtils.java index aac655e30..c2f187f86 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/test/java/org/apache/geaflow/dsl/connector/hbase/HBaseLocalTestUtils.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hbase/src/test/java/org/apache/geaflow/dsl/connector/hbase/HBaseLocalTestUtils.java @@ -1,174 +1,184 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.geaflow.dsl.connector.hbase; - -import java.io.IOException; -import java.util.Objects; -import org.apache.hadoop.hbase.HBaseTestingUtility; -import org.apache.hadoop.hbase.NamespaceDescriptor; -import org.apache.hadoop.hbase.TableName; -import org.apache.hadoop.hbase.client.Admin; -import org.apache.hadoop.hbase.client.ColumnFamilyDescriptorBuilder; -import org.apache.hadoop.hbase.client.Connection; -import org.apache.hadoop.hbase.client.Get; -import org.apache.hadoop.hbase.client.Put; -import org.apache.hadoop.hbase.client.Result; -import org.apache.hadoop.hbase.client.ResultScanner; -import org.apache.hadoop.hbase.client.Scan; -import org.apache.hadoop.hbase.client.Table; -import org.apache.hadoop.hbase.client.TableDescriptorBuilder; -import org.apache.hadoop.hbase.util.Bytes; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class HBaseLocalTestUtils { - - private static final Logger LOGGER = LoggerFactory.getLogger(HBaseLocalTestUtils.class); - - private static Connection connection = null; - - private static final HBaseTestingUtility hBaseTesting; - - static { - hBaseTesting = new HBaseTestingUtility(); - hBaseTesting.getConfiguration().set("test.hbase.zookeeper.property.clientPort", "2181"); - try { - hBaseTesting.startMiniCluster(); - if (Objects.isNull(connection)) { - connection = hBaseTesting.getConnection(); - } - LOGGER.info("Get connection from HBase local mini cluster"); - } catch (Exception e) { - throw new RuntimeException("Can not get connection from HBase local mini cluster."); - } - } - - public static void closeConnection() { - try { - if (Objects.nonNull(connection)) { - connection.close(); - } - if (Objects.nonNull(hBaseTesting)) { - hBaseTesting.shutdownMiniCluster(); - } - } catch (IOException e) { - throw new RuntimeException("Can not close resources."); - } - } - - private static boolean tableExists(String namespace, String tableName) throws IOException { - Admin admin = connection.getAdmin(); - boolean exists; - try { - exists = admin.tableExists(TableName.valueOf(namespace, tableName)); - } catch (IOException e) { - throw new RuntimeException("Fail to judge table exists."); - } - admin.close(); - return exists; - } - - public static void createNamespace(String namespace) throws IOException { - Admin admin = connection.getAdmin(); - NamespaceDescriptor.Builder builder = NamespaceDescriptor.create(namespace); - try { - admin.createNamespace(builder.build()); - } catch (IOException e) { - throw new RuntimeException("Fail to create HBase namespace."); - } - LOGGER.info("Create HBase namespace {} success.", namespace); - admin.close(); - } - - public static void createTable(String namespace, String tableName, String... columnFamilies) - throws IOException { - if (columnFamilies.length == 0) { - throw new RuntimeException("Create a table with at least one column family."); - } - if (tableExists(namespace, tableName)) { - throw new RuntimeException( - String.format("Namespace %s, table %s already exists.", namespace, tableName)); - } - Admin admin = connection.getAdmin(); - TableDescriptorBuilder tableDescriptorBuilder = TableDescriptorBuilder.newBuilder( - TableName.valueOf(namespace, tableName)); - - for (String columnFamily : columnFamilies) { - ColumnFamilyDescriptorBuilder columnFamilyDescriptorBuilder = - ColumnFamilyDescriptorBuilder.newBuilder( - Bytes.toBytes(columnFamily)); - tableDescriptorBuilder.setColumnFamily(columnFamilyDescriptorBuilder.build()); - } - - try { - admin.createTable(tableDescriptorBuilder.build()); - } catch (IOException e) { - throw new RuntimeException("Fail to create HBase table."); - } - LOGGER.info("Create HBase table `{}` under namespace `{}` success.", tableName, namespace); - admin.close(); - } - - public static void putCell(String namespace, String tableName, String rowKey, - String columnFamilyName, String qualifier, String value) - throws IOException { - Table table = connection.getTable(TableName.valueOf(namespace, tableName)); - Put put = new Put(Bytes.toBytes(rowKey)); - put.addColumn(Bytes.toBytes(columnFamilyName), Bytes.toBytes(qualifier), - Bytes.toBytes(value)); - - try { - table.put(put); - } catch (IOException e) { - throw new RuntimeException("Fail to put one row to HBase."); - } - LOGGER.info("Put a record {} to {}-{}-{}-{}-{}.", value, namespace, tableName, rowKey, - columnFamilyName, qualifier); - table.close(); - } - - public static byte[] getCell(String namespace, String tableName, String rowKey, - String columnFamilyName, String qualifier) throws IOException { - Table table = connection.getTable(TableName.valueOf(namespace, tableName)); - Get get = new Get(Bytes.toBytes(rowKey)); - get.addColumn(Bytes.toBytes(columnFamilyName), Bytes.toBytes(qualifier)); - - byte[] resultValue; - try { - Result result = table.get(get); - resultValue = result.getValue(Bytes.toBytes(columnFamilyName), - Bytes.toBytes(qualifier)); - } catch (IOException e) { - throw new RuntimeException("Fail to get cell from HBase"); - } - table.close(); - return resultValue; - } - - public static ResultScanner getScanner(String namespace, String tableName) throws IOException { - Table table = connection.getTable(TableName.valueOf(namespace, tableName)); - try { - Scan scan = new Scan(); - return table.getScanner(scan); - } catch (IOException e) { - throw new RuntimeException("Fail to scan all records."); - } - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.dsl.connector.hbase; + +import java.io.IOException; +import java.util.Objects; + +import org.apache.hadoop.hbase.HBaseTestingUtility; +import org.apache.hadoop.hbase.NamespaceDescriptor; +import org.apache.hadoop.hbase.TableName; +import org.apache.hadoop.hbase.client.Admin; +import org.apache.hadoop.hbase.client.ColumnFamilyDescriptorBuilder; +import org.apache.hadoop.hbase.client.Connection; +import org.apache.hadoop.hbase.client.Get; +import org.apache.hadoop.hbase.client.Put; +import org.apache.hadoop.hbase.client.Result; +import org.apache.hadoop.hbase.client.ResultScanner; +import org.apache.hadoop.hbase.client.Scan; +import org.apache.hadoop.hbase.client.Table; +import org.apache.hadoop.hbase.client.TableDescriptorBuilder; +import org.apache.hadoop.hbase.util.Bytes; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class HBaseLocalTestUtils { + + private static final Logger LOGGER = LoggerFactory.getLogger(HBaseLocalTestUtils.class); + + private static Connection connection = null; + + private static final HBaseTestingUtility hBaseTesting; + + static { + hBaseTesting = new HBaseTestingUtility(); + hBaseTesting.getConfiguration().set("test.hbase.zookeeper.property.clientPort", "2181"); + try { + hBaseTesting.startMiniCluster(); + if (Objects.isNull(connection)) { + connection = hBaseTesting.getConnection(); + } + LOGGER.info("Get connection from HBase local mini cluster"); + } catch (Exception e) { + throw new RuntimeException("Can not get connection from HBase local mini cluster."); + } + } + + public static void closeConnection() { + try { + if (Objects.nonNull(connection)) { + connection.close(); + } + if (Objects.nonNull(hBaseTesting)) { + hBaseTesting.shutdownMiniCluster(); + } + } catch (IOException e) { + throw new RuntimeException("Can not close resources."); + } + } + + private static boolean tableExists(String namespace, String tableName) throws IOException { + Admin admin = connection.getAdmin(); + boolean exists; + try { + exists = admin.tableExists(TableName.valueOf(namespace, tableName)); + } catch (IOException e) { + throw new RuntimeException("Fail to judge table exists."); + } + admin.close(); + return exists; + } + + public static void createNamespace(String namespace) throws IOException { + Admin admin = connection.getAdmin(); + NamespaceDescriptor.Builder builder = NamespaceDescriptor.create(namespace); + try { + admin.createNamespace(builder.build()); + } catch (IOException e) { + throw new RuntimeException("Fail to create HBase namespace."); + } + LOGGER.info("Create HBase namespace {} success.", namespace); + admin.close(); + } + + public static void createTable(String namespace, String tableName, String... columnFamilies) + throws IOException { + if (columnFamilies.length == 0) { + throw new RuntimeException("Create a table with at least one column family."); + } + if (tableExists(namespace, tableName)) { + throw new RuntimeException( + String.format("Namespace %s, table %s already exists.", namespace, tableName)); + } + Admin admin = connection.getAdmin(); + TableDescriptorBuilder tableDescriptorBuilder = + TableDescriptorBuilder.newBuilder(TableName.valueOf(namespace, tableName)); + + for (String columnFamily : columnFamilies) { + ColumnFamilyDescriptorBuilder columnFamilyDescriptorBuilder = + ColumnFamilyDescriptorBuilder.newBuilder(Bytes.toBytes(columnFamily)); + tableDescriptorBuilder.setColumnFamily(columnFamilyDescriptorBuilder.build()); + } + + try { + admin.createTable(tableDescriptorBuilder.build()); + } catch (IOException e) { + throw new RuntimeException("Fail to create HBase table."); + } + LOGGER.info("Create HBase table `{}` under namespace `{}` success.", tableName, namespace); + admin.close(); + } + + public static void putCell( + String namespace, + String tableName, + String rowKey, + String columnFamilyName, + String qualifier, + String value) + throws IOException { + Table table = connection.getTable(TableName.valueOf(namespace, tableName)); + Put put = new Put(Bytes.toBytes(rowKey)); + put.addColumn(Bytes.toBytes(columnFamilyName), Bytes.toBytes(qualifier), Bytes.toBytes(value)); + + try { + table.put(put); + } catch (IOException e) { + throw new RuntimeException("Fail to put one row to HBase."); + } + LOGGER.info( + "Put a record {} to {}-{}-{}-{}-{}.", + value, + namespace, + tableName, + rowKey, + columnFamilyName, + qualifier); + table.close(); + } + + public static byte[] getCell( + String namespace, String tableName, String rowKey, String columnFamilyName, String qualifier) + throws IOException { + Table table = connection.getTable(TableName.valueOf(namespace, tableName)); + Get get = new Get(Bytes.toBytes(rowKey)); + get.addColumn(Bytes.toBytes(columnFamilyName), Bytes.toBytes(qualifier)); + + byte[] resultValue; + try { + Result result = table.get(get); + resultValue = result.getValue(Bytes.toBytes(columnFamilyName), Bytes.toBytes(qualifier)); + } catch (IOException e) { + throw new RuntimeException("Fail to get cell from HBase"); + } + table.close(); + return resultValue; + } + + public static ResultScanner getScanner(String namespace, String tableName) throws IOException { + Table table = connection.getTable(TableName.valueOf(namespace, tableName)); + try { + Scan scan = new Scan(); + return table.getScanner(scan); + } catch (IOException e) { + throw new RuntimeException("Fail to scan all records."); + } + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/HiveConfigKeys.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/HiveConfigKeys.java index 18c2ca7d5..8687a425b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/HiveConfigKeys.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/HiveConfigKeys.java @@ -23,23 +23,23 @@ import org.apache.geaflow.common.config.ConfigKeys; public class HiveConfigKeys { - public static final ConfigKey GEAFLOW_DSL_HIVE_DATABASE_NAME = ConfigKeys - .key("geaflow.dsl.hive.database.name") - .noDefaultValue() - .description("The database name for hive table."); + public static final ConfigKey GEAFLOW_DSL_HIVE_DATABASE_NAME = + ConfigKeys.key("geaflow.dsl.hive.database.name") + .noDefaultValue() + .description("The database name for hive table."); - public static final ConfigKey GEAFLOW_DSL_HIVE_TABLE_NAME = ConfigKeys - .key("geaflow.dsl.hive.table.name") - .noDefaultValue() - .description("The hive table name to read."); + public static final ConfigKey GEAFLOW_DSL_HIVE_TABLE_NAME = + ConfigKeys.key("geaflow.dsl.hive.table.name") + .noDefaultValue() + .description("The hive table name to read."); - public static final ConfigKey GEAFLOW_DSL_HIVE_METASTORE_URIS = ConfigKeys - .key("geaflow.dsl.hive.metastore.uris") - .noDefaultValue() - .description("The hive meta store uri."); + public static final ConfigKey GEAFLOW_DSL_HIVE_METASTORE_URIS = + ConfigKeys.key("geaflow.dsl.hive.metastore.uris") + .noDefaultValue() + .description("The hive meta store uri."); - public static final ConfigKey GEAFLOW_DSL_HIVE_PARTITION_MIN_SPLITS = ConfigKeys - .key("geaflow.dsl.hive.splits.per.partition") - .defaultValue(1) - .description("The split number of each hive partition."); + public static final ConfigKey GEAFLOW_DSL_HIVE_PARTITION_MIN_SPLITS = + ConfigKeys.key("geaflow.dsl.hive.splits.per.partition") + .defaultValue(1) + .description("The split number of each hive partition."); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/HiveReader.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/HiveReader.java index 62f648b6d..74c51d7bc 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/HiveReader.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/HiveReader.java @@ -28,6 +28,7 @@ import java.util.Map; import java.util.Properties; import java.util.stream.Collectors; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.utils.ClassUtil; import org.apache.geaflow.dsl.common.data.Row; @@ -50,147 +51,162 @@ public class HiveReader { - private final RecordReader recordReader; - private final StructType readSchema; - private final Deserializer deserializer; - private long fetchOffset; - - public HiveReader(RecordReader recordReader, StructType readSchema, - StorageDescriptor sd, Properties tableProps) { - this.recordReader = recordReader; - this.readSchema = new StructType(readSchema.getFields().stream().map( - f -> new TableField(f.getName().toLowerCase(Locale.ROOT), f.getType(), f.isNullable())) - .collect(Collectors.toList())); - this.deserializer = ClassUtil.newInstance(sd.getSerdeInfo().getSerializationLib()); - this.fetchOffset = 0L; - try { - org.apache.hadoop.conf.Configuration conf = new org.apache.hadoop.conf.Configuration(); - SerDeUtils.initializeSerDe(deserializer, conf, tableProps, null); - - } catch (SerDeException e) { - throw new GeaFlowDSLException(e); - } + private final RecordReader recordReader; + private final StructType readSchema; + private final Deserializer deserializer; + private long fetchOffset; + + public HiveReader( + RecordReader recordReader, + StructType readSchema, + StorageDescriptor sd, + Properties tableProps) { + this.recordReader = recordReader; + this.readSchema = + new StructType( + readSchema.getFields().stream() + .map( + f -> + new TableField( + f.getName().toLowerCase(Locale.ROOT), f.getType(), f.isNullable())) + .collect(Collectors.toList())); + this.deserializer = ClassUtil.newInstance(sd.getSerdeInfo().getSerializationLib()); + this.fetchOffset = 0L; + try { + org.apache.hadoop.conf.Configuration conf = new org.apache.hadoop.conf.Configuration(); + SerDeUtils.initializeSerDe(deserializer, conf, tableProps, null); + + } catch (SerDeException e) { + throw new GeaFlowDSLException(e); } - - public FetchData read(long windowSize, String[] partitionValues) { - Iterator hiveIterator = new HiveIterator(recordReader, deserializer, partitionValues, readSchema); - if (windowSize == Long.MAX_VALUE) { - return FetchData.createBatchFetch(hiveIterator, new HiveOffset(-1L)); + } + + public FetchData read(long windowSize, String[] partitionValues) { + Iterator hiveIterator = + new HiveIterator(recordReader, deserializer, partitionValues, readSchema); + if (windowSize == Long.MAX_VALUE) { + return FetchData.createBatchFetch(hiveIterator, new HiveOffset(-1L)); + } else { + long fetchCnt = 0L; + List rows = new ArrayList<>(); + while (fetchCnt < windowSize) { + if (hiveIterator.hasNext()) { + fetchCnt++; + rows.add(hiveIterator.next()); } else { - long fetchCnt = 0L; - List rows = new ArrayList<>(); - while (fetchCnt < windowSize) { - if (hiveIterator.hasNext()) { - fetchCnt ++; - rows.add(hiveIterator.next()); - } else { - break; - } - } - fetchOffset += fetchCnt; - return FetchData.createStreamFetch(rows, new HiveOffset(fetchOffset), fetchCnt < windowSize); + break; } + } + fetchOffset += fetchCnt; + return FetchData.createStreamFetch(rows, new HiveOffset(fetchOffset), fetchCnt < windowSize); } - - public void seek(long seekPos) { - try { - Writable key = recordReader.createKey(); - Writable value = recordReader.createValue(); - fetchOffset = seekPos; - while (seekPos-- > 0) { - if (!recordReader.next(key, value)) { - throw new GeaflowRuntimeException("fetch offset is out of range: " + fetchOffset); - } - } - } catch (Exception e) { - throw new GeaflowRuntimeException(e); + } + + public void seek(long seekPos) { + try { + Writable key = recordReader.createKey(); + Writable value = recordReader.createValue(); + fetchOffset = seekPos; + while (seekPos-- > 0) { + if (!recordReader.next(key, value)) { + throw new GeaflowRuntimeException("fetch offset is out of range: " + fetchOffset); } + } + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } + } - private static class HiveIterator implements Iterator { - - private final RecordReader recordReader; - private final Deserializer deserializer; - private final String[] partitionValues; - private final StructType readSchema; - - private final Map name2Fields = new HashMap<>(); - - private final Writable key; - private final Writable value; - - - public HiveIterator(RecordReader recordReader, - Deserializer deserializer, - String[] partitionValues, - StructType readSchema) { - this.recordReader = recordReader; - this.deserializer = deserializer; - this.partitionValues = partitionValues; - this.readSchema = readSchema; - key = recordReader.createKey(); - value = recordReader.createValue(); - - try { - StructObjectInspector structObjectInspector = (StructObjectInspector) deserializer.getObjectInspector(); - for (StructField field : structObjectInspector.getAllStructFieldRefs()) { - name2Fields.put(field.getFieldName(), field); - } - } catch (Exception e) { - throw new GeaFlowDSLException(e); - } - } + private static class HiveIterator implements Iterator { - @Override - public boolean hasNext() { - try { - return recordReader.next(key, value); - } catch (IOException e) { - throw new GeaFlowDSLException(e); - } - } + private final RecordReader recordReader; + private final Deserializer deserializer; + private final String[] partitionValues; + private final StructType readSchema; - @Override - public Row next() { - try { - Object hiveRowStruct = deserializer.deserialize(value); - StructObjectInspector structObjectInspector = (StructObjectInspector) deserializer.getObjectInspector(); - Object[] values = convertHiveStructToRow(hiveRowStruct, structObjectInspector); - if (partitionValues.length > 0) { // append partition values. - Object[] valueWithPartitions = new Object[values.length + partitionValues.length]; - System.arraycopy(values, 0, valueWithPartitions, 0, values.length); - System.arraycopy(partitionValues, 0, valueWithPartitions, - values.length, partitionValues.length); - values = valueWithPartitions; - } - return ObjectRow.create(values); - } catch (Exception e) { - throw new GeaFlowDSLException(e); - } + private final Map name2Fields = new HashMap<>(); + + private final Writable key; + private final Writable value; + + public HiveIterator( + RecordReader recordReader, + Deserializer deserializer, + String[] partitionValues, + StructType readSchema) { + this.recordReader = recordReader; + this.deserializer = deserializer; + this.partitionValues = partitionValues; + this.readSchema = readSchema; + key = recordReader.createKey(); + value = recordReader.createValue(); + + try { + StructObjectInspector structObjectInspector = + (StructObjectInspector) deserializer.getObjectInspector(); + for (StructField field : structObjectInspector.getAllStructFieldRefs()) { + name2Fields.put(field.getFieldName(), field); } + } catch (Exception e) { + throw new GeaFlowDSLException(e); + } + } + + @Override + public boolean hasNext() { + try { + return recordReader.next(key, value); + } catch (IOException e) { + throw new GeaFlowDSLException(e); + } + } - private Object[] convertHiveStructToRow(Object hiveRowStruct, StructObjectInspector structObjectInspector) { - Object[] values = new Object[readSchema.size()]; - for (int i = 0; i < values.length; i++) { - String fieldName = readSchema.getField(i).getName(); - StructField field = name2Fields.get(fieldName); - if (field != null) { - values[i] = toSqlValue( - structObjectInspector.getStructFieldData(hiveRowStruct, field), - field.getFieldObjectInspector()); - } else { - values[i] = null; - } - } - return values; + @Override + public Row next() { + try { + Object hiveRowStruct = deserializer.deserialize(value); + StructObjectInspector structObjectInspector = + (StructObjectInspector) deserializer.getObjectInspector(); + Object[] values = convertHiveStructToRow(hiveRowStruct, structObjectInspector); + if (partitionValues.length > 0) { // append partition values. + Object[] valueWithPartitions = new Object[values.length + partitionValues.length]; + System.arraycopy(values, 0, valueWithPartitions, 0, values.length); + System.arraycopy( + partitionValues, 0, valueWithPartitions, values.length, partitionValues.length); + values = valueWithPartitions; } + return ObjectRow.create(values); + } catch (Exception e) { + throw new GeaFlowDSLException(e); + } + } - private Object toSqlValue(Object hiveValue, ObjectInspector fieldInspector) { - if (fieldInspector instanceof PrimitiveObjectInspector) { - PrimitiveObjectInspector primitiveObjectInspector = (PrimitiveObjectInspector) fieldInspector; - return primitiveObjectInspector.getPrimitiveJavaObject(hiveValue); - } - throw new GeaFlowDSLException("Complex type:{} have not support", fieldInspector.getTypeName()); + private Object[] convertHiveStructToRow( + Object hiveRowStruct, StructObjectInspector structObjectInspector) { + Object[] values = new Object[readSchema.size()]; + for (int i = 0; i < values.length; i++) { + String fieldName = readSchema.getField(i).getName(); + StructField field = name2Fields.get(fieldName); + if (field != null) { + values[i] = + toSqlValue( + structObjectInspector.getStructFieldData(hiveRowStruct, field), + field.getFieldObjectInspector()); + } else { + values[i] = null; } + } + return values; + } + + private Object toSqlValue(Object hiveValue, ObjectInspector fieldInspector) { + if (fieldInspector instanceof PrimitiveObjectInspector) { + PrimitiveObjectInspector primitiveObjectInspector = + (PrimitiveObjectInspector) fieldInspector; + return primitiveObjectInspector.getPrimitiveJavaObject(hiveValue); + } + throw new GeaFlowDSLException( + "Complex type:{} have not support", fieldInspector.getTypeName()); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/HiveTableConnector.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/HiveTableConnector.java index 8107e7a9a..c673411cd 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/HiveTableConnector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/HiveTableConnector.java @@ -25,15 +25,15 @@ public class HiveTableConnector implements TableReadableConnector { - public static final String TYPE = "HIVE"; + public static final String TYPE = "HIVE"; - @Override - public String getType() { - return TYPE; - } + @Override + public String getType() { + return TYPE; + } - @Override - public TableSource createSource(Configuration conf) { - return new HiveTableSource(); - } + @Override + public TableSource createSource(Configuration conf) { + return new HiveTableSource(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/HiveTableSource.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/HiveTableSource.java index 72f4ad871..a52628c50 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/HiveTableSource.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/HiveTableSource.java @@ -28,6 +28,7 @@ import java.util.Optional; import java.util.Properties; import java.util.stream.Collectors; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; @@ -64,271 +65,277 @@ public class HiveTableSource implements TableSource, EnablePartitionPushDown { - private static final Logger LOGGER = LoggerFactory.getLogger(HiveTableSource.class); - - private StructType dataSchema; - private TableSchema tableSchema; - private String dbName; - private String tableName; - private String metastoreURIs; - private int splitNumPerPartition; - private transient IMetaStoreClient metaClient; - private transient Map partitionReaders; - - private Table hiveTable; - private Properties tableProps; - - private PartitionFilter partitionFilter; - - @Override - public void init(Configuration conf, TableSchema tableSchema) { - this.dataSchema = tableSchema.getDataSchema(); - this.tableSchema = tableSchema; - this.dbName = conf.getString(HiveConfigKeys.GEAFLOW_DSL_HIVE_DATABASE_NAME); - this.tableName = conf.getString(HiveConfigKeys.GEAFLOW_DSL_HIVE_TABLE_NAME); - this.metastoreURIs = conf.getString(HiveConfigKeys.GEAFLOW_DSL_HIVE_METASTORE_URIS); - this.splitNumPerPartition = conf.getInteger(HiveConfigKeys.GEAFLOW_DSL_HIVE_PARTITION_MIN_SPLITS); + private static final Logger LOGGER = LoggerFactory.getLogger(HiveTableSource.class); + + private StructType dataSchema; + private TableSchema tableSchema; + private String dbName; + private String tableName; + private String metastoreURIs; + private int splitNumPerPartition; + private transient IMetaStoreClient metaClient; + private transient Map partitionReaders; + + private Table hiveTable; + private Properties tableProps; + + private PartitionFilter partitionFilter; + + @Override + public void init(Configuration conf, TableSchema tableSchema) { + this.dataSchema = tableSchema.getDataSchema(); + this.tableSchema = tableSchema; + this.dbName = conf.getString(HiveConfigKeys.GEAFLOW_DSL_HIVE_DATABASE_NAME); + this.tableName = conf.getString(HiveConfigKeys.GEAFLOW_DSL_HIVE_TABLE_NAME); + this.metastoreURIs = conf.getString(HiveConfigKeys.GEAFLOW_DSL_HIVE_METASTORE_URIS); + this.splitNumPerPartition = + conf.getInteger(HiveConfigKeys.GEAFLOW_DSL_HIVE_PARTITION_MIN_SPLITS); + } + + @Override + public void open(RuntimeContext context) { + this.partitionReaders = new HashMap<>(); + Map hiveConf = new HashMap<>(); + hiveConf.put(HiveUtils.THRIFT_URIS, metastoreURIs); + HiveVersionAdapter hiveVersionAdapter = HiveVersionAdapters.get(); + this.metaClient = hiveVersionAdapter.createMetaSoreClient(HiveUtils.getHiveConfig(hiveConf)); + try { + this.hiveTable = metaClient.getTable(dbName, tableName); + } catch (Exception e) { + throw new GeaFlowDSLException("Fail to get hive table for: {}.{}", dbName, tableName); } - - @Override - public void open(RuntimeContext context) { - this.partitionReaders = new HashMap<>(); - Map hiveConf = new HashMap<>(); - hiveConf.put(HiveUtils.THRIFT_URIS, metastoreURIs); - HiveVersionAdapter hiveVersionAdapter = HiveVersionAdapters.get(); - this.metaClient = hiveVersionAdapter.createMetaSoreClient(HiveUtils.getHiveConfig(hiveConf)); - try { - this.hiveTable = metaClient.getTable(dbName, tableName); - } catch (Exception e) { - throw new GeaFlowDSLException("Fail to get hive table for: {}.{}", dbName, tableName); + this.tableProps = hiveVersionAdapter.getTableMetadata(hiveTable); + LOGGER.info("open hive table source, hive version is: {}", hiveVersionAdapter.version()); + } + + @Override + public List listPartitions() { + List allPartitions = new ArrayList<>(); + try { + List hivePartitions = + metaClient.listPartitions(dbName, tableName, (short) -1); + + List storageDescriptors = new ArrayList<>(); + List sdPartitionValues = new ArrayList<>(); + if (hivePartitions != null && !hivePartitions.isEmpty()) { + for (org.apache.hadoop.hive.metastore.api.Partition hivePartition : hivePartitions) { + String[] partitionValues = alignPartitionValues(hivePartition.getValues()); + if (accept(partitionValues)) { + storageDescriptors.add(hivePartition.getSd()); + sdPartitionValues.add(partitionValues); + } } - this.tableProps = hiveVersionAdapter.getTableMetadata(hiveTable); - LOGGER.info("open hive table source, hive version is: {}", hiveVersionAdapter.version()); - } - - @Override - public List listPartitions() { - List allPartitions = new ArrayList<>(); - try { - List hivePartitions = - metaClient.listPartitions(dbName, tableName, (short) -1); - - List storageDescriptors = new ArrayList<>(); - List sdPartitionValues = new ArrayList<>(); - if (hivePartitions != null && !hivePartitions.isEmpty()) { - for (org.apache.hadoop.hive.metastore.api.Partition hivePartition : hivePartitions) { - String[] partitionValues = alignPartitionValues(hivePartition.getValues()); - if (accept(partitionValues)) { - storageDescriptors.add(hivePartition.getSd()); - sdPartitionValues.add(partitionValues); - } - } - } else { - storageDescriptors.add(hiveTable.getSd()); - sdPartitionValues.add(new String[0]); - } - for (int i = 0; i < storageDescriptors.size(); i++) { - StorageDescriptor sd = storageDescriptors.get(i); - String[] partitionValues = sdPartitionValues.get(i); - InputFormat inputFormat = HiveUtils.createInputFormat(sd); - InputSplit[] hadoopInputSplits = HiveUtils.createInputSplits(sd, inputFormat, splitNumPerPartition); - if (hadoopInputSplits != null) { - for (InputSplit split : hadoopInputSplits) { - allPartitions.add( - new HivePartition(dbName, tableName, split, inputFormat, sd, partitionValues)); - } - } - } - } catch (Exception e) { - throw new GeaFlowDSLException("fail to list partitions for " + dbName + "." + tableName, e); + } else { + storageDescriptors.add(hiveTable.getSd()); + sdPartitionValues.add(new String[0]); + } + for (int i = 0; i < storageDescriptors.size(); i++) { + StorageDescriptor sd = storageDescriptors.get(i); + String[] partitionValues = sdPartitionValues.get(i); + InputFormat inputFormat = HiveUtils.createInputFormat(sd); + InputSplit[] hadoopInputSplits = + HiveUtils.createInputSplits(sd, inputFormat, splitNumPerPartition); + if (hadoopInputSplits != null) { + for (InputSplit split : hadoopInputSplits) { + allPartitions.add( + new HivePartition(dbName, tableName, split, inputFormat, sd, partitionValues)); + } } - return allPartitions; - } - - @Override - public List listPartitions(int parallelism) { - return listPartitions(); + } + } catch (Exception e) { + throw new GeaFlowDSLException("fail to list partitions for " + dbName + "." + tableName, e); } - - /** - * Align the hive partition values to the partition fields order defined in DSL ddl. - */ - private String[] alignPartitionValues(List partitionValues) { - List hivePartitionKeys = hiveTable.getPartitionKeys().stream() + return allPartitions; + } + + @Override + public List listPartitions(int parallelism) { + return listPartitions(); + } + + /** Align the hive partition values to the partition fields order defined in DSL ddl. */ + private String[] alignPartitionValues(List partitionValues) { + List hivePartitionKeys = + hiveTable.getPartitionKeys().stream() .map(FieldSchema::getName) .collect(Collectors.toList()); - StructType partitionSchema = tableSchema.getPartitionSchema(); - - String[] alignedValues = new String[partitionSchema.size()]; - for (int i = 0; i < partitionSchema.size(); i++) { - String partitionName = partitionSchema.getField(i).getName(); - int valueIndex = hivePartitionKeys.indexOf(partitionName); - if (valueIndex >= 0) { - alignedValues[i] = partitionValues.get(valueIndex); - } else { - alignedValues[i] = null; - } - } - return alignedValues; + StructType partitionSchema = tableSchema.getPartitionSchema(); + + String[] alignedValues = new String[partitionSchema.size()]; + for (int i = 0; i < partitionSchema.size(); i++) { + String partitionName = partitionSchema.getField(i).getName(); + int valueIndex = hivePartitionKeys.indexOf(partitionName); + if (valueIndex >= 0) { + alignedValues[i] = partitionValues.get(valueIndex); + } else { + alignedValues[i] = null; + } } + return alignedValues; + } - private boolean accept(Object[] partitionValues) { - if (partitionFilter != null) { - Row row = ObjectRow.create(partitionValues); - return partitionFilter.apply(row); + private boolean accept(Object[] partitionValues) { + if (partitionFilter != null) { + Row row = ObjectRow.create(partitionValues); + return partitionFilter.apply(row); + } + return true; + } + + @SuppressWarnings("unchecked") + @Override + public TableDeserializer getDeserializer(Configuration conf) { + return DeserializerFactory.loadRowTableDeserializer(); + } + + @SuppressWarnings("unchecked") + @Override + public FetchData fetch( + Partition partition, Optional startOffset, FetchWindow windowInfo) + throws IOException { + long desireWindowSize = -1; + switch (windowInfo.getType()) { + case ALL_WINDOW: + desireWindowSize = Long.MAX_VALUE; + break; + case SIZE_TUMBLING_WINDOW: + desireWindowSize = windowInfo.windowSize(); + break; + default: + throw new GeaFlowDSLException("Not support window type:{}", windowInfo.getType()); + } + HivePartition hivePartition = (HivePartition) partition; + HiveReader reader = partitionReaders.get(partition.getName()); + if (reader == null) { + JobConf jobConf = HiveUtils.getJobConf(hivePartition.getSd()); + Reporter reporter = HiveUtils.createDummyReporter(); + RecordReader recordReader = + hivePartition + .getInputFormat() + .getRecordReader(hivePartition.getSplit(), jobConf, reporter); + if (recordReader instanceof Configurable) { + ((Configurable) recordReader).setConf(jobConf); + } + reader = new HiveReader(recordReader, dataSchema, hivePartition.getSd(), tableProps); + partitionReaders.put(partition.getName(), reader); + if (startOffset.isPresent()) { + long seekOffset = startOffset.get().getOffset(); + if (seekOffset > 0) { + reader.seek(seekOffset); } - return true; + } + } + try { + return (FetchData) reader.read(desireWindowSize, hivePartition.getPartitionValues()); + } catch (Exception e) { + throw new GeaFlowDSLException(e); } + } - @SuppressWarnings("unchecked") - @Override - public TableDeserializer getDeserializer(Configuration conf) { - return DeserializerFactory.loadRowTableDeserializer(); + @Override + public void close() { + if (metaClient != null) { + metaClient.close(); + } + LOGGER.info("close hive table source: {}", tableName); + } + + @Override + public void setPartitionFilter(PartitionFilter partitionFilter) { + this.partitionFilter = partitionFilter; + } + + public static class HivePartition implements Partition { + + private final String dbName; + private final String tableName; + private final InputSplit split; + private final InputFormat inputFormat; + private final StorageDescriptor sd; + private final String[] partitionValues; + + public HivePartition( + String dbName, + String tableName, + InputSplit split, + InputFormat inputFormat, + StorageDescriptor sd, + String[] partitionValues) { + this.dbName = dbName; + this.tableName = tableName; + this.split = split; + this.inputFormat = inputFormat; + this.sd = sd; + this.partitionValues = partitionValues; } - @SuppressWarnings("unchecked") @Override - public FetchData fetch(Partition partition, Optional startOffset, - FetchWindow windowInfo) throws IOException { - long desireWindowSize = -1; - switch (windowInfo.getType()) { - case ALL_WINDOW: - desireWindowSize = Long.MAX_VALUE; - break; - case SIZE_TUMBLING_WINDOW: - desireWindowSize = windowInfo.windowSize(); - break; - default: - throw new GeaFlowDSLException("Not support window type:{}", windowInfo.getType()); - } - HivePartition hivePartition = (HivePartition) partition; - HiveReader reader = partitionReaders.get(partition.getName()); - if (reader == null) { - JobConf jobConf = HiveUtils.getJobConf(hivePartition.getSd()); - Reporter reporter = HiveUtils.createDummyReporter(); - RecordReader recordReader = hivePartition.getInputFormat() - .getRecordReader(hivePartition.getSplit(), jobConf, reporter); - if (recordReader instanceof Configurable) { - ((Configurable) recordReader).setConf(jobConf); - } - reader = new HiveReader(recordReader, dataSchema, hivePartition.getSd(), tableProps); - partitionReaders.put(partition.getName(), reader); - if (startOffset.isPresent()) { - long seekOffset = startOffset.get().getOffset(); - if (seekOffset > 0) { - reader.seek(seekOffset); - } - } - } - try { - return (FetchData) reader.read(desireWindowSize, hivePartition.getPartitionValues()); - } catch (Exception e) { - throw new GeaFlowDSLException(e); - } + public String getName() { + return StringUtils.join(new Object[] {dbName, tableName, split}, "-"); } @Override - public void close() { - if (metaClient != null) { - metaClient.close(); - } - LOGGER.info("close hive table source: {}", tableName); - } + public void setIndex(int index, int parallel) {} @Override - public void setPartitionFilter(PartitionFilter partitionFilter) { - this.partitionFilter = partitionFilter; + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof HivePartition)) { + return false; + } + HivePartition that = (HivePartition) o; + return Objects.equals(dbName, that.dbName) + && Objects.equals(tableName, that.tableName) + && Objects.equals( + split != null ? split.toString() : "null", + that.split != null ? that.split.toString() : "null"); } - public static class HivePartition implements Partition { - - private final String dbName; - private final String tableName; - private final InputSplit split; - private final InputFormat inputFormat; - private final StorageDescriptor sd; - private final String[] partitionValues; - - public HivePartition(String dbName, - String tableName, - InputSplit split, - InputFormat inputFormat, - StorageDescriptor sd, - String[] partitionValues) { - this.dbName = dbName; - this.tableName = tableName; - this.split = split; - this.inputFormat = inputFormat; - this.sd = sd; - this.partitionValues = partitionValues; - } - - @Override - public String getName() { - return StringUtils.join(new Object[]{dbName, tableName, split}, "-"); - } - - @Override - public void setIndex(int index, int parallel) { - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof HivePartition)) { - return false; - } - HivePartition that = (HivePartition) o; - return Objects.equals(dbName, that.dbName) && Objects.equals(tableName, that.tableName) - && Objects.equals(split != null ? split.toString() : "null", - that.split != null ? that.split.toString() : "null"); - } - - @Override - public int hashCode() { - return Objects.hash(dbName, tableName, split != null ? split.toString() : "null"); - } + @Override + public int hashCode() { + return Objects.hash(dbName, tableName, split != null ? split.toString() : "null"); + } - public InputSplit getSplit() { - return split; - } + public InputSplit getSplit() { + return split; + } - public InputFormat getInputFormat() { - return inputFormat; - } + public InputFormat getInputFormat() { + return inputFormat; + } - public StorageDescriptor getSd() { - return sd; - } + public StorageDescriptor getSd() { + return sd; + } - public String[] getPartitionValues() { - return partitionValues; - } + public String[] getPartitionValues() { + return partitionValues; } + } - public static class HiveOffset implements Offset { + public static class HiveOffset implements Offset { - private final long offset; + private final long offset; - public HiveOffset(long offset) { - this.offset = offset; - } + public HiveOffset(long offset) { + this.offset = offset; + } - @Override - public String humanReadable() { - return String.valueOf(offset); - } + @Override + public String humanReadable() { + return String.valueOf(offset); + } - @Override - public long getOffset() { - return offset; - } + @Override + public long getOffset() { + return offset; + } - @Override - public boolean isTimestamp() { - return false; - } + @Override + public boolean isTimestamp() { + return false; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/adapter/Hive23Adapter.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/adapter/Hive23Adapter.java index 89c1043d1..4ed6b935c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/adapter/Hive23Adapter.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/adapter/Hive23Adapter.java @@ -21,6 +21,7 @@ import java.lang.reflect.Method; import java.util.Properties; + import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.metastore.IMetaStoreClient; @@ -29,39 +30,39 @@ public class Hive23Adapter implements HiveVersionAdapter { - private final String version; - private static final String METHOD_NAME_GET_PROXY = "getProxy"; - private static final String METHOD_NAME_GET_TABLE_META_DATA = "getTableMetadata"; + private final String version; + private static final String METHOD_NAME_GET_PROXY = "getProxy"; + private static final String METHOD_NAME_GET_TABLE_META_DATA = "getTableMetadata"; - public Hive23Adapter(String version) { - this.version = version; - } + public Hive23Adapter(String version) { + this.version = version; + } - @Override - public String version() { - return version; - } + @Override + public String version() { + return version; + } - @Override - public IMetaStoreClient createMetaSoreClient(HiveConf hiveConf) { - try { - Method method = RetryingMetaStoreClient.class - .getMethod(METHOD_NAME_GET_PROXY, HiveConf.class, Boolean.TYPE); - return (IMetaStoreClient) method.invoke(null, hiveConf, true); - } catch (Exception ex) { - throw new RuntimeException("Failed to create Hive Metastore client", ex); - } + @Override + public IMetaStoreClient createMetaSoreClient(HiveConf hiveConf) { + try { + Method method = + RetryingMetaStoreClient.class.getMethod( + METHOD_NAME_GET_PROXY, HiveConf.class, Boolean.TYPE); + return (IMetaStoreClient) method.invoke(null, hiveConf, true); + } catch (Exception ex) { + throw new RuntimeException("Failed to create Hive Metastore client", ex); } + } - @Override - public Properties getTableMetadata(Table table) { - try { - Class metaStoreUtilsClass = Class.forName("org.apache.hadoop.hive.metastore.MetaStoreUtils"); - Method method = - metaStoreUtilsClass.getMethod(METHOD_NAME_GET_TABLE_META_DATA, Table.class); - return (Properties) method.invoke(null, table); - } catch (Exception e) { - throw new GeaFlowDSLException(e); - } + @Override + public Properties getTableMetadata(Table table) { + try { + Class metaStoreUtilsClass = Class.forName("org.apache.hadoop.hive.metastore.MetaStoreUtils"); + Method method = metaStoreUtilsClass.getMethod(METHOD_NAME_GET_TABLE_META_DATA, Table.class); + return (Properties) method.invoke(null, table); + } catch (Exception e) { + throw new GeaFlowDSLException(e); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/adapter/Hive3Adapter.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/adapter/Hive3Adapter.java index 822bdec6a..0b673a942 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/adapter/Hive3Adapter.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/adapter/Hive3Adapter.java @@ -21,6 +21,7 @@ import java.lang.reflect.Method; import java.util.Properties; + import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hive.conf.HiveConf; @@ -30,39 +31,40 @@ public class Hive3Adapter implements HiveVersionAdapter { - private final String version; - private static final String METHOD_NAME_GET_PROXY = "getProxy"; - private static final String METHOD_NAME_GET_TABLE_META_DATA = "getTableMetadata"; + private final String version; + private static final String METHOD_NAME_GET_PROXY = "getProxy"; + private static final String METHOD_NAME_GET_TABLE_META_DATA = "getTableMetadata"; - public Hive3Adapter(String version) { - this.version = version; - } + public Hive3Adapter(String version) { + this.version = version; + } - @Override - public String version() { - return version; - } + @Override + public String version() { + return version; + } - @Override - public IMetaStoreClient createMetaSoreClient(HiveConf hiveConf) { - try { - Method method = RetryingMetaStoreClient.class - .getMethod(METHOD_NAME_GET_PROXY, Configuration.class, Boolean.TYPE); - return (IMetaStoreClient) method.invoke(null, hiveConf, true); - } catch (Exception ex) { - throw new RuntimeException("Failed to create Hive Metastore client", ex); - } + @Override + public IMetaStoreClient createMetaSoreClient(HiveConf hiveConf) { + try { + Method method = + RetryingMetaStoreClient.class.getMethod( + METHOD_NAME_GET_PROXY, Configuration.class, Boolean.TYPE); + return (IMetaStoreClient) method.invoke(null, hiveConf, true); + } catch (Exception ex) { + throw new RuntimeException("Failed to create Hive Metastore client", ex); } + } - @Override - public Properties getTableMetadata(Table table) { - try { - Class metaStoreUtilsClass = Class.forName("org.apache.hadoop.hive.metastore.utils.MetaStoreUtils"); - Method method = - metaStoreUtilsClass.getMethod(METHOD_NAME_GET_TABLE_META_DATA, Table.class); - return (Properties) method.invoke(null, table); - } catch (Exception e) { - throw new GeaFlowDSLException(e); - } + @Override + public Properties getTableMetadata(Table table) { + try { + Class metaStoreUtilsClass = + Class.forName("org.apache.hadoop.hive.metastore.utils.MetaStoreUtils"); + Method method = metaStoreUtilsClass.getMethod(METHOD_NAME_GET_TABLE_META_DATA, Table.class); + return (Properties) method.invoke(null, table); + } catch (Exception e) { + throw new GeaFlowDSLException(e); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/adapter/HiveVersionAdapter.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/adapter/HiveVersionAdapter.java index c592c6ecd..1172b5c5f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/adapter/HiveVersionAdapter.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/adapter/HiveVersionAdapter.java @@ -20,15 +20,16 @@ package org.apache.geaflow.dsl.connector.hive.adapter; import java.util.Properties; + import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.metastore.IMetaStoreClient; import org.apache.hadoop.hive.metastore.api.Table; public interface HiveVersionAdapter { - String version(); + String version(); - IMetaStoreClient createMetaSoreClient(HiveConf hiveConf); + IMetaStoreClient createMetaSoreClient(HiveConf hiveConf); - Properties getTableMetadata(Table table); + Properties getTableMetadata(Table table); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/adapter/HiveVersionAdapters.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/adapter/HiveVersionAdapters.java index 012bab8bf..d9969a868 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/adapter/HiveVersionAdapters.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/adapter/HiveVersionAdapters.java @@ -24,53 +24,53 @@ public class HiveVersionAdapters { - public static final String HIVE_230 = "2.3.0"; + public static final String HIVE_230 = "2.3.0"; - public static final String HIVE_231 = "2.3.1"; + public static final String HIVE_231 = "2.3.1"; - public static final String HIVE_232 = "2.3.2"; + public static final String HIVE_232 = "2.3.2"; - public static final String HIVE_233 = "2.3.3"; + public static final String HIVE_233 = "2.3.3"; - public static final String HIVE_234 = "2.3.4"; + public static final String HIVE_234 = "2.3.4"; - public static final String HIVE_235 = "2.3.5"; + public static final String HIVE_235 = "2.3.5"; - public static final String HIVE_236 = "2.3.6"; + public static final String HIVE_236 = "2.3.6"; - public static final String HIVE_237 = "2.3.7"; + public static final String HIVE_237 = "2.3.7"; - public static final String HIVE_239 = "2.3.9"; + public static final String HIVE_239 = "2.3.9"; - public static final String HIVE_300 = "3.0.0"; + public static final String HIVE_300 = "3.0.0"; - public static final String HIVE_310 = "3.1.0"; + public static final String HIVE_310 = "3.1.0"; - public static final String HIVE_311 = "3.1.1"; + public static final String HIVE_311 = "3.1.1"; - public static final String HIVE_312 = "3.1.2"; + public static final String HIVE_312 = "3.1.2"; - public static HiveVersionAdapter get() { - Package myPackage = HiveVersionAnnotation.class.getPackage(); - HiveVersionAnnotation version = myPackage.getAnnotation(HiveVersionAnnotation.class); - switch (version.version()) { - case HIVE_230: - case HIVE_231: - case HIVE_232: - case HIVE_233: - case HIVE_234: - case HIVE_235: - case HIVE_236: - case HIVE_237: - case HIVE_239: - return new Hive23Adapter(version.version()); - case HIVE_300: - case HIVE_310: - case HIVE_311: - case HIVE_312: - return new Hive3Adapter(version.version()); - default: - throw new GeaFlowDSLException("Hive version: {} is not supported.", version.version()); - } + public static HiveVersionAdapter get() { + Package myPackage = HiveVersionAnnotation.class.getPackage(); + HiveVersionAnnotation version = myPackage.getAnnotation(HiveVersionAnnotation.class); + switch (version.version()) { + case HIVE_230: + case HIVE_231: + case HIVE_232: + case HIVE_233: + case HIVE_234: + case HIVE_235: + case HIVE_236: + case HIVE_237: + case HIVE_239: + return new Hive23Adapter(version.version()); + case HIVE_300: + case HIVE_310: + case HIVE_311: + case HIVE_312: + return new Hive3Adapter(version.version()); + default: + throw new GeaFlowDSLException("Hive version: {} is not supported.", version.version()); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/util/HiveUtils.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/util/HiveUtils.java index 77a517a66..d1edd7eeb 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/util/HiveUtils.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/main/java/org/apache/geaflow/dsl/connector/hive/util/HiveUtils.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Map; + import org.apache.geaflow.common.utils.ClassUtil; import org.apache.hadoop.conf.Configurable; import org.apache.hadoop.hive.conf.HiveConf; @@ -38,91 +39,83 @@ public class HiveUtils { - public static String INPUT_DIR = "mapreduce.input.fileinputformat.inputdir"; + public static String INPUT_DIR = "mapreduce.input.fileinputformat.inputdir"; - public static String THRIFT_URIS = "hive.metastore.uris"; + public static String THRIFT_URIS = "hive.metastore.uris"; - private static final Logger LOGGER = LoggerFactory.getLogger(HiveUtils.class); + private static final Logger LOGGER = LoggerFactory.getLogger(HiveUtils.class); - public static HiveConf getHiveConfig(Map config) { - HiveConf hiveConf = new HiveConf(); - for (String key : config.keySet()) { - hiveConf.set(key, config.get(key)); - } - return hiveConf; + public static HiveConf getHiveConfig(Map config) { + HiveConf hiveConf = new HiveConf(); + for (String key : config.keySet()) { + hiveConf.set(key, config.get(key)); } - - public static JobConf getJobConf(StorageDescriptor storageDescriptor) { - JobConf jobConf = new JobConf(); - jobConf.set(INPUT_DIR, storageDescriptor.getLocation()); - return jobConf; + return hiveConf; + } + + public static JobConf getJobConf(StorageDescriptor storageDescriptor) { + JobConf jobConf = new JobConf(); + jobConf.set(INPUT_DIR, storageDescriptor.getLocation()); + return jobConf; + } + + public static InputFormat createInputFormat( + StorageDescriptor storageDescriptor) { + JobConf jobConf = getJobConf(storageDescriptor); + jobConf.set(INPUT_DIR, storageDescriptor.getLocation()); + try { + return ClassUtil.newInstance(storageDescriptor.getInputFormat()); + } catch (Exception e) { + throw new RuntimeException("Unable to instantiate the hadoop input format", e); } - - public static InputFormat createInputFormat(StorageDescriptor storageDescriptor) { - JobConf jobConf = getJobConf(storageDescriptor); - jobConf.set(INPUT_DIR, storageDescriptor.getLocation()); - try { - return ClassUtil.newInstance(storageDescriptor.getInputFormat()); - } catch (Exception e) { - throw new RuntimeException("Unable to instantiate the hadoop input format", e); - } - } - - public static InputSplit[] createInputSplits(StorageDescriptor storageDescriptor, - InputFormat inputFormat, - int splitNumPerPartition) throws IOException { - JobConf jobConf = getJobConf(storageDescriptor); - ReflectionUtils.setConf(inputFormat, jobConf); - if (inputFormat instanceof Configurable) { - ((Configurable) inputFormat).setConf(jobConf); - } else if (inputFormat instanceof JobConfigurable) { - ((JobConfigurable) inputFormat).configure(jobConf); - } - return inputFormat.getSplits(jobConf, splitNumPerPartition); + } + + public static InputSplit[] createInputSplits( + StorageDescriptor storageDescriptor, InputFormat inputFormat, int splitNumPerPartition) + throws IOException { + JobConf jobConf = getJobConf(storageDescriptor); + ReflectionUtils.setConf(inputFormat, jobConf); + if (inputFormat instanceof Configurable) { + ((Configurable) inputFormat).setConf(jobConf); + } else if (inputFormat instanceof JobConfigurable) { + ((JobConfigurable) inputFormat).configure(jobConf); } - - public static Reporter createDummyReporter() { - return new Reporter() { - @Override - public void setStatus(String status) { - - } - - @Override - public Counter getCounter(Enum name) { - return new Counter(); - } - - @Override - public Counter getCounter(String group, String name) { - return new Counter(); - } - - @Override - public void incrCounter(Enum key, long amount) { - - } - - @Override - public void incrCounter(String group, String counter, long amount) { - - } - - @Override - public InputSplit getInputSplit() throws UnsupportedOperationException { - return null; - } - - @Override - public float getProgress() { - return 0; - } - - @Override - public void progress() { - - } - }; - } - + return inputFormat.getSplits(jobConf, splitNumPerPartition); + } + + public static Reporter createDummyReporter() { + return new Reporter() { + @Override + public void setStatus(String status) {} + + @Override + public Counter getCounter(Enum name) { + return new Counter(); + } + + @Override + public Counter getCounter(String group, String name) { + return new Counter(); + } + + @Override + public void incrCounter(Enum key, long amount) {} + + @Override + public void incrCounter(String group, String counter, long amount) {} + + @Override + public InputSplit getInputSplit() throws UnsupportedOperationException { + return null; + } + + @Override + public float getProgress() { + return 0; + } + + @Override + public void progress() {} + }; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/test/java/org/apache/geaflow/dsl/connector/hive/BaseHiveTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/test/java/org/apache/geaflow/dsl/connector/hive/BaseHiveTest.java index 52c3bc9d8..16db1a582 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/test/java/org/apache/geaflow/dsl/connector/hive/BaseHiveTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/test/java/org/apache/geaflow/dsl/connector/hive/BaseHiveTest.java @@ -21,6 +21,7 @@ import java.io.File; import java.io.IOException; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.utils.FileUtil; import org.apache.hadoop.conf.Configuration; @@ -32,58 +33,59 @@ public class BaseHiveTest { - public static final int metastorePort = 9083; - private HiveTestMetaStore hiveTestMetastore; + public static final int metastorePort = 9083; + private HiveTestMetaStore hiveTestMetastore; - private Driver hiveDriver; + private Driver hiveDriver; - @BeforeClass - public void setup() throws IOException { - String hiveLocation = FileUtil.concatPath(FileUtils.getTempDirectoryPath(), - "hive_" + System.currentTimeMillis()); - File hiveDir = new File(hiveLocation); - hiveDir.mkdirs(); - HiveConf hiveConf = createHiveConf(new Configuration(), hiveLocation); - hiveTestMetastore = new HiveTestMetaStore(hiveConf, hiveLocation); - hiveTestMetastore.start(); + @BeforeClass + public void setup() throws IOException { + String hiveLocation = + FileUtil.concatPath(FileUtils.getTempDirectoryPath(), "hive_" + System.currentTimeMillis()); + File hiveDir = new File(hiveLocation); + hiveDir.mkdirs(); + HiveConf hiveConf = createHiveConf(new Configuration(), hiveLocation); + hiveTestMetastore = new HiveTestMetaStore(hiveConf, hiveLocation); + hiveTestMetastore.start(); - SessionState.start(hiveConf); - hiveDriver = new Driver(hiveConf); - } + SessionState.start(hiveConf); + hiveDriver = new Driver(hiveConf); + } - public void shutdown() { - hiveTestMetastore.stop(); - } + public void shutdown() { + hiveTestMetastore.stop(); + } - protected void executeHiveSql(String hiveSql) { - try { - hiveDriver.run(hiveSql); - } catch (Exception e) { - throw new RuntimeException(e); - } + protected void executeHiveSql(String hiveSql) { + try { + hiveDriver.run(hiveSql); + } catch (Exception e) { + throw new RuntimeException(e); } + } - private HiveConf createHiveConf(Configuration conf, String hiveLocation) { - conf.set("hive.metastore.local", "false"); - conf.setInt(ConfVars.METASTORE_SERVER_PORT.varname, metastorePort); - conf.set(ConfVars.METASTORE_EXECUTE_SET_UGI.varname, "false"); - conf.set(HiveConf.ConfVars.METASTOREURIS.varname, "thrift://localhost:" + metastorePort); - // conf.set(HiveConf.ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST.varname, bindIP); - File metastoreDir = new File(hiveLocation, "metastore_db"); - conf.set(HiveConf.ConfVars.METASTORECONNECTURLKEY.varname, - "jdbc:derby:" + metastoreDir.getPath() + ";create=true"); - File wareHouseDir = new File(hiveLocation, "ware_house"); - wareHouseDir.mkdirs(); - conf.set(HiveConf.ConfVars.METASTOREWAREHOUSE.varname, wareHouseDir.getAbsolutePath()); - conf.set("datanucleus.schema.autoCreateTables", "true"); - conf.set("hive.metastore.schema.verification", "false"); - conf.set("datanucleus.autoCreateSchema", "true"); - conf.set("datanucleus.fixedDatastore", "false"); - conf.set("datanucleus.schema.autoCreateAll", "true"); - conf.set("hive.stats.autogather", "false"); + private HiveConf createHiveConf(Configuration conf, String hiveLocation) { + conf.set("hive.metastore.local", "false"); + conf.setInt(ConfVars.METASTORE_SERVER_PORT.varname, metastorePort); + conf.set(ConfVars.METASTORE_EXECUTE_SET_UGI.varname, "false"); + conf.set(HiveConf.ConfVars.METASTOREURIS.varname, "thrift://localhost:" + metastorePort); + // conf.set(HiveConf.ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST.varname, bindIP); + File metastoreDir = new File(hiveLocation, "metastore_db"); + conf.set( + HiveConf.ConfVars.METASTORECONNECTURLKEY.varname, + "jdbc:derby:" + metastoreDir.getPath() + ";create=true"); + File wareHouseDir = new File(hiveLocation, "ware_house"); + wareHouseDir.mkdirs(); + conf.set(HiveConf.ConfVars.METASTOREWAREHOUSE.varname, wareHouseDir.getAbsolutePath()); + conf.set("datanucleus.schema.autoCreateTables", "true"); + conf.set("hive.metastore.schema.verification", "false"); + conf.set("datanucleus.autoCreateSchema", "true"); + conf.set("datanucleus.fixedDatastore", "false"); + conf.set("datanucleus.schema.autoCreateAll", "true"); + conf.set("hive.stats.autogather", "false"); - String scratchDir = FileUtil.concatPath(hiveLocation, "scratch"); - conf.set(ConfVars.SCRATCHDIR.varname, scratchDir); - return new HiveConf(conf, this.getClass()); - } + String scratchDir = FileUtil.concatPath(hiveLocation, "scratch"); + conf.set(ConfVars.SCRATCHDIR.varname, scratchDir); + return new HiveConf(conf, this.getClass()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/test/java/org/apache/geaflow/dsl/connector/hive/HivePartitionTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/test/java/org/apache/geaflow/dsl/connector/hive/HivePartitionTest.java index 17fe72dd8..40dfc174d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/test/java/org/apache/geaflow/dsl/connector/hive/HivePartitionTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/test/java/org/apache/geaflow/dsl/connector/hive/HivePartitionTest.java @@ -27,25 +27,22 @@ public class HivePartitionTest { - @Test - public void testHivePartition() { - Partition p1 = new HivePartition( - "default", "testTable1", null, null, null, new String[0]); - Partition p2 = new HivePartition( - "default", "testTable2", null, null, null, new String[0]); - Partition _p1 = new HivePartition( - "default", "testTable1", null, null, null, new String[0]); - Assert.assertEquals(p1.hashCode(), _p1.hashCode()); - Assert.assertEquals(p1, _p1); - Assert.assertNotEquals(p1.hashCode(), p2.hashCode()); - Assert.assertNotEquals(p1, p2); - } + @Test + public void testHivePartition() { + Partition p1 = new HivePartition("default", "testTable1", null, null, null, new String[0]); + Partition p2 = new HivePartition("default", "testTable2", null, null, null, new String[0]); + Partition _p1 = new HivePartition("default", "testTable1", null, null, null, new String[0]); + Assert.assertEquals(p1.hashCode(), _p1.hashCode()); + Assert.assertEquals(p1, _p1); + Assert.assertNotEquals(p1.hashCode(), p2.hashCode()); + Assert.assertNotEquals(p1, p2); + } - @Test - public void testHiveOffset() { - HiveOffset test = new HiveOffset(0L); - Assert.assertEquals(test.humanReadable(), "0"); - Assert.assertEquals(test.getOffset(), 0L); - Assert.assertEquals(test.isTimestamp(), false); - } + @Test + public void testHiveOffset() { + HiveOffset test = new HiveOffset(0L); + Assert.assertEquals(test.humanReadable(), "0"); + Assert.assertEquals(test.getOffset(), 0L); + Assert.assertEquals(test.isTimestamp(), false); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/test/java/org/apache/geaflow/dsl/connector/hive/HiveTableSourceTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/test/java/org/apache/geaflow/dsl/connector/hive/HiveTableSourceTest.java index b6a10e300..61fd7cc86 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/test/java/org/apache/geaflow/dsl/connector/hive/HiveTableSourceTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/test/java/org/apache/geaflow/dsl/connector/hive/HiveTableSourceTest.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.Optional; import java.util.stream.Collectors; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.type.primitive.BinaryStringType; @@ -48,200 +49,232 @@ public class HiveTableSourceTest extends BaseHiveTest { - private static final Logger LOGGER = LoggerFactory.getLogger(HiveTableSourceTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(HiveTableSourceTest.class); - @AfterTest - public void shutdown() { - super.shutdown(); - } + @AfterTest + public void shutdown() { + super.shutdown(); + } - @Test - public void testReadHiveText() throws IOException { - String ddl = "CREATE TABLE hive_user (id int, name string, age int) stored as textfile"; - String inserts = "INSERT into hive_user SELECT 1, 'jim', 20;" + @Test + public void testReadHiveText() throws IOException { + String ddl = "CREATE TABLE hive_user (id int, name string, age int) stored as textfile"; + String inserts = + "INSERT into hive_user SELECT 1, 'jim', 20;" + "INSERT into hive_user SELECT 2, 'kate', 18;" + "INSERT into hive_user SELECT 3, 'lily', 22;" + "INSERT into hive_user SELECT 4, 'lucy', 25;" + "INSERT into hive_user SELECT 5, 'jack', 26"; - StructType dataSchema = new StructType( + StructType dataSchema = + new StructType( new TableField("id", IntegerType.INSTANCE, false), new TableField("name", BinaryStringType.INSTANCE, true), - new TableField("age", IntegerType.INSTANCE, false) - ); - String expected = "[1, jim, 20]\n" + new TableField("age", IntegerType.INSTANCE, false)); + String expected = + "[1, jim, 20]\n" + "[2, kate, 18]\n" + "[3, lily, 22]\n" + "[4, lucy, 25]\n" + "[5, jack, 26]"; - checkReadHive(ddl, inserts, dataSchema, new StructType(), false, expected); - checkReadHive(ddl, inserts, dataSchema, new StructType(), true, expected); - } + checkReadHive(ddl, inserts, dataSchema, new StructType(), false, expected); + checkReadHive(ddl, inserts, dataSchema, new StructType(), true, expected); + } - @Test - public void testReadHiveParquet() throws IOException { - String ddl = "CREATE TABLE hive_user (id int, name string, age int) stored as parquet"; - String inserts = "INSERT into hive_user SELECT 1, 'jim', 20;" + @Test + public void testReadHiveParquet() throws IOException { + String ddl = "CREATE TABLE hive_user (id int, name string, age int) stored as parquet"; + String inserts = + "INSERT into hive_user SELECT 1, 'jim', 20;" + "INSERT into hive_user SELECT 2, 'kate', 18;" + "INSERT into hive_user SELECT 3, 'lily', 22;" + "INSERT into hive_user SELECT 4, 'lucy', 25;" + "INSERT into hive_user SELECT 5, 'jack', 26"; - StructType dataSchema = new StructType( + StructType dataSchema = + new StructType( new TableField("id", IntegerType.INSTANCE, false), new TableField("name", BinaryStringType.INSTANCE, true), - new TableField("age", IntegerType.INSTANCE, false) - ); - String expected = "[1, jim, 20]\n" + new TableField("age", IntegerType.INSTANCE, false)); + String expected = + "[1, jim, 20]\n" + "[2, kate, 18]\n" + "[3, lily, 22]\n" + "[4, lucy, 25]\n" + "[5, jack, 26]"; - checkReadHive(ddl, inserts, dataSchema, new StructType(), true, expected); - checkReadHive(ddl, inserts, dataSchema, new StructType(), false, expected); - } + checkReadHive(ddl, inserts, dataSchema, new StructType(), true, expected); + checkReadHive(ddl, inserts, dataSchema, new StructType(), false, expected); + } - @Test - public void testReadHiveOrc() throws IOException { - String ddl = "CREATE TABLE hive_user (id int, name string, age int) stored as orc"; - String inserts = "INSERT into hive_user SELECT 1, 'jim', 20;" + @Test + public void testReadHiveOrc() throws IOException { + String ddl = "CREATE TABLE hive_user (id int, name string, age int) stored as orc"; + String inserts = + "INSERT into hive_user SELECT 1, 'jim', 20;" + "INSERT into hive_user SELECT 2, 'kate', 18;" + "INSERT into hive_user SELECT 3, 'lily', 22;" + "INSERT into hive_user SELECT 4, 'lucy', 25;" + "INSERT into hive_user SELECT 5, 'jack', 26"; - StructType dataSchema = new StructType( + StructType dataSchema = + new StructType( new TableField("id", IntegerType.INSTANCE, false), new TableField("name", BinaryStringType.INSTANCE, true), - new TableField("age", IntegerType.INSTANCE, false) - ); - String expected = "[1, jim, 20]\n" + new TableField("age", IntegerType.INSTANCE, false)); + String expected = + "[1, jim, 20]\n" + "[2, kate, 18]\n" + "[3, lily, 22]\n" + "[4, lucy, 25]\n" + "[5, jack, 26]"; - checkReadHive(ddl, inserts, dataSchema, new StructType(), false, expected); - checkReadHive(ddl, inserts, dataSchema, new StructType(), true, expected); - } + checkReadHive(ddl, inserts, dataSchema, new StructType(), false, expected); + checkReadHive(ddl, inserts, dataSchema, new StructType(), true, expected); + } - @Test - public void testReadHiveTextPartitionTable() throws IOException { - String ddl = "CREATE TABLE hive_user (id int, name string, age int) " + @Test + public void testReadHiveTextPartitionTable() throws IOException { + String ddl = + "CREATE TABLE hive_user (id int, name string, age int) " + "partitioned by(dt string)" + "stored as textfile"; - String inserts = - "INSERT into hive_user partition(dt = '2023-04-23') SELECT 1, 'jim', 20;" - + "INSERT into hive_user partition(dt = '2023-04-24') SELECT 2, 'kate', 18;" - + "INSERT into hive_user partition(dt = '2023-04-24') SELECT 3, 'lily', 22;" - + "INSERT into hive_user partition(dt = '2023-04-25') SELECT 4, 'lucy', 25;" - + "INSERT into hive_user partition(dt = '2023-04-26') SELECT 5, 'jack', 26"; - StructType dataSchema = new StructType( + String inserts = + "INSERT into hive_user partition(dt = '2023-04-23') SELECT 1, 'jim', 20;" + + "INSERT into hive_user partition(dt = '2023-04-24') SELECT 2, 'kate', 18;" + + "INSERT into hive_user partition(dt = '2023-04-24') SELECT 3, 'lily', 22;" + + "INSERT into hive_user partition(dt = '2023-04-25') SELECT 4, 'lucy', 25;" + + "INSERT into hive_user partition(dt = '2023-04-26') SELECT 5, 'jack', 26"; + StructType dataSchema = + new StructType( new TableField("id", IntegerType.INSTANCE, false), new TableField("name", BinaryStringType.INSTANCE, true), - new TableField("age", IntegerType.INSTANCE, false) - ); - StructType partitionSchema = new StructType( - new TableField("dt", BinaryStringType.INSTANCE, false) - ); - String expected = "[1, jim, 20, 2023-04-23]\n" + new TableField("age", IntegerType.INSTANCE, false)); + StructType partitionSchema = + new StructType(new TableField("dt", BinaryStringType.INSTANCE, false)); + String expected = + "[1, jim, 20, 2023-04-23]\n" + "[2, kate, 18, 2023-04-24]\n" + "[3, lily, 22, 2023-04-24]\n" + "[4, lucy, 25, 2023-04-25]\n" + "[5, jack, 26, 2023-04-26]"; - checkReadHive(ddl, inserts, dataSchema, partitionSchema, false, - "[1, jim, 20, 2023-04-23]\n" - + "[2, kate, 18, 2023-04-24]\n" - + "[3, lily, 22, 2023-04-24]\n" - + "[4, lucy, 25, 2023-04-25]\n" - + "[5, jack, 26, 2023-04-26]"); - checkReadHive(ddl, inserts, dataSchema, partitionSchema, true, - "[1, jim, 20, 2023-04-23]\n" - + "[2, kate, 18, 2023-04-24]\n" - + "[3, lily, 22, 2023-04-24]\n" - + "[4, lucy, 25, 2023-04-25]\n" - + "[5, jack, 26, 2023-04-26]"); - } + checkReadHive( + ddl, + inserts, + dataSchema, + partitionSchema, + false, + "[1, jim, 20, 2023-04-23]\n" + + "[2, kate, 18, 2023-04-24]\n" + + "[3, lily, 22, 2023-04-24]\n" + + "[4, lucy, 25, 2023-04-25]\n" + + "[5, jack, 26, 2023-04-26]"); + checkReadHive( + ddl, + inserts, + dataSchema, + partitionSchema, + true, + "[1, jim, 20, 2023-04-23]\n" + + "[2, kate, 18, 2023-04-24]\n" + + "[3, lily, 22, 2023-04-24]\n" + + "[4, lucy, 25, 2023-04-25]\n" + + "[5, jack, 26, 2023-04-26]"); + } - @Test - public void testReadHiveText2PartitionTable() throws IOException { - String ddl = "CREATE TABLE hive_user (id int, name string, age int) " + @Test + public void testReadHiveText2PartitionTable() throws IOException { + String ddl = + "CREATE TABLE hive_user (id int, name string, age int) " + "partitioned by(dt string, hh string)" + "stored as textfile"; - String inserts = - "INSERT into hive_user partition(dt = '2023-04-23', hh ='10') SELECT 1, 'jim', 20;" - + "INSERT into hive_user partition(dt = '2023-04-24',hh = '10') SELECT 2, 'kate', 18;" - + "INSERT into hive_user partition(dt = '2023-04-24',hh = '11') SELECT 3, 'lily', 22;" - + "INSERT into hive_user partition(dt = '2023-04-25',hh = '12') SELECT 4, 'lucy', 25;" - + "INSERT into hive_user partition(dt = '2023-04-26',hh = '13') SELECT 5, 'jack', 26"; - StructType dataSchema = new StructType( + String inserts = + "INSERT into hive_user partition(dt = '2023-04-23', hh ='10') SELECT 1, 'jim', 20;" + + "INSERT into hive_user partition(dt = '2023-04-24',hh = '10') SELECT 2, 'kate', 18;" + + "INSERT into hive_user partition(dt = '2023-04-24',hh = '11') SELECT 3, 'lily', 22;" + + "INSERT into hive_user partition(dt = '2023-04-25',hh = '12') SELECT 4, 'lucy', 25;" + + "INSERT into hive_user partition(dt = '2023-04-26',hh = '13') SELECT 5, 'jack', 26"; + StructType dataSchema = + new StructType( new TableField("id", IntegerType.INSTANCE, false), new TableField("name", BinaryStringType.INSTANCE, true), - new TableField("age", IntegerType.INSTANCE, false) - ); - StructType partitionSchema = new StructType( + new TableField("age", IntegerType.INSTANCE, false)); + StructType partitionSchema = + new StructType( new TableField("hh", BinaryStringType.INSTANCE, false), - new TableField("dt", BinaryStringType.INSTANCE, false) - ); - checkReadHive(ddl, inserts, dataSchema, partitionSchema, false, - "[1, jim, 20, 10, 2023-04-23]\n" - + "[2, kate, 18, 10, 2023-04-24]\n" - + "[3, lily, 22, 11, 2023-04-24]\n" - + "[4, lucy, 25, 12, 2023-04-25]\n" - + "[5, jack, 26, 13, 2023-04-26]"); - checkReadHive(ddl, inserts, dataSchema, partitionSchema, true, - "[1, jim, 20, 10, 2023-04-23]\n" - + "[2, kate, 18, 10, 2023-04-24]\n" - + "[3, lily, 22, 11, 2023-04-24]\n" - + "[4, lucy, 25, 12, 2023-04-25]\n" - + "[5, jack, 26, 13, 2023-04-26]"); + new TableField("dt", BinaryStringType.INSTANCE, false)); + checkReadHive( + ddl, + inserts, + dataSchema, + partitionSchema, + false, + "[1, jim, 20, 10, 2023-04-23]\n" + + "[2, kate, 18, 10, 2023-04-24]\n" + + "[3, lily, 22, 11, 2023-04-24]\n" + + "[4, lucy, 25, 12, 2023-04-25]\n" + + "[5, jack, 26, 13, 2023-04-26]"); + checkReadHive( + ddl, + inserts, + dataSchema, + partitionSchema, + true, + "[1, jim, 20, 10, 2023-04-23]\n" + + "[2, kate, 18, 10, 2023-04-24]\n" + + "[3, lily, 22, 11, 2023-04-24]\n" + + "[4, lucy, 25, 12, 2023-04-25]\n" + + "[5, jack, 26, 13, 2023-04-26]"); + } + + private void checkReadHive( + String ddl, + String inserts, + StructType dataSchema, + StructType partitionSchema, + boolean isStream, + String expectResult) + throws IOException { + executeHiveSql("Drop table if exists hive_user"); + executeHiveSql(ddl); + String[] insertArray = inserts.split(";"); + for (String insert : insertArray) { + if (StringUtils.isNotEmpty(insert)) { + executeHiveSql(insert); + } } + HiveTableSource hiveTableSource = new HiveTableSource(); + Configuration tableConf = new Configuration(); + tableConf.put(HiveConfigKeys.GEAFLOW_DSL_HIVE_DATABASE_NAME, "default"); + tableConf.put(HiveConfigKeys.GEAFLOW_DSL_HIVE_TABLE_NAME, "hive_user"); + tableConf.put( + HiveConfigKeys.GEAFLOW_DSL_HIVE_METASTORE_URIS, "thrift://localhost:" + metastorePort); - private void checkReadHive(String ddl, String inserts, StructType dataSchema, - StructType partitionSchema, - boolean isStream, - String expectResult) throws IOException { - executeHiveSql("Drop table if exists hive_user"); - executeHiveSql(ddl); - String[] insertArray = inserts.split(";"); - for (String insert : insertArray) { - if (StringUtils.isNotEmpty(insert)) { - executeHiveSql(insert); - } - } - HiveTableSource hiveTableSource = new HiveTableSource(); - Configuration tableConf = new Configuration(); - tableConf.put(HiveConfigKeys.GEAFLOW_DSL_HIVE_DATABASE_NAME, "default"); - tableConf.put(HiveConfigKeys.GEAFLOW_DSL_HIVE_TABLE_NAME, "hive_user"); - tableConf.put(HiveConfigKeys.GEAFLOW_DSL_HIVE_METASTORE_URIS, - "thrift://localhost:" + metastorePort); - - TableSchema tableSchema = new TableSchema(dataSchema, partitionSchema); - hiveTableSource.init(tableConf, tableSchema); - - hiveTableSource.open(new DefaultRuntimeContext(tableConf)); - List partitions = hiveTableSource.listPartitions(); - - TableDeserializer deserializer = hiveTableSource.getDeserializer(tableConf); - deserializer.init(tableConf, tableSchema); - - List readRows = new ArrayList<>(); - for (Partition partition : partitions) { - LOGGER.info("partition: {}", partition.getName()); - AbstractFetchWindow window; - if (isStream) { - window = new SizeFetchWindow(1, Long.MAX_VALUE); - } else { - window = new AllFetchWindow(1); - } - FetchData fetchData = hiveTableSource.fetch(partition, Optional.empty(), - new AllFetchWindow(1)); - Iterator rowIterator = fetchData.getDataIterator(); - while (rowIterator.hasNext()) { - Row row = rowIterator.next(); - readRows.addAll(deserializer.deserialize(row)); - } - } - List lines = readRows.stream().map(Object::toString) - .sorted().collect(Collectors.toList()); - Assert.assertEquals(StringUtils.join(lines, "\n"), expectResult); - - hiveTableSource.close(); + TableSchema tableSchema = new TableSchema(dataSchema, partitionSchema); + hiveTableSource.init(tableConf, tableSchema); + + hiveTableSource.open(new DefaultRuntimeContext(tableConf)); + List partitions = hiveTableSource.listPartitions(); + + TableDeserializer deserializer = hiveTableSource.getDeserializer(tableConf); + deserializer.init(tableConf, tableSchema); + + List readRows = new ArrayList<>(); + for (Partition partition : partitions) { + LOGGER.info("partition: {}", partition.getName()); + AbstractFetchWindow window; + if (isStream) { + window = new SizeFetchWindow(1, Long.MAX_VALUE); + } else { + window = new AllFetchWindow(1); + } + FetchData fetchData = + hiveTableSource.fetch(partition, Optional.empty(), new AllFetchWindow(1)); + Iterator rowIterator = fetchData.getDataIterator(); + while (rowIterator.hasNext()) { + Row row = rowIterator.next(); + readRows.addAll(deserializer.deserialize(row)); + } } + List lines = + readRows.stream().map(Object::toString).sorted().collect(Collectors.toList()); + Assert.assertEquals(StringUtils.join(lines, "\n"), expectResult); + + hiveTableSource.close(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/test/java/org/apache/geaflow/dsl/connector/hive/HiveTestMetaStore.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/test/java/org/apache/geaflow/dsl/connector/hive/HiveTestMetaStore.java index 0b3b4d90a..8b9233c5f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/test/java/org/apache/geaflow/dsl/connector/hive/HiveTestMetaStore.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hive/src/test/java/org/apache/geaflow/dsl/connector/hive/HiveTestMetaStore.java @@ -27,6 +27,7 @@ import java.nio.channels.SocketChannel; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; + import org.apache.commons.io.FileUtils; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.conf.HiveConf.ConfVars; @@ -49,117 +50,122 @@ public class HiveTestMetaStore { - private static final Logger LOGGER = LoggerFactory.getLogger(HiveTestMetaStore.class); + private static final Logger LOGGER = LoggerFactory.getLogger(HiveTestMetaStore.class); - private final String hiveLocation; + private final String hiveLocation; - private final HiveConf hiveConf; + private final HiveConf hiveConf; - private final ExecutorService executorService; - private TServer tServer; + private final ExecutorService executorService; + private TServer tServer; - public HiveTestMetaStore(HiveConf hiveConf, String hiveLocation) { - this.hiveLocation = hiveLocation; - this.hiveConf = hiveConf; - this.executorService = Executors.newSingleThreadExecutor(); - } + public HiveTestMetaStore(HiveConf hiveConf, String hiveLocation) { + this.hiveLocation = hiveLocation; + this.hiveConf = hiveConf; + this.executorService = Executors.newSingleThreadExecutor(); + } - public void start() throws IOException { - int metastorePort = hiveConf.getIntVar(ConfVars.METASTORE_SERVER_PORT); - tServer = startMetaStore(hiveConf); - waitForServer(metastorePort); - LOGGER.info("Hive metastore server has started at port:{}", metastorePort); - } + public void start() throws IOException { + int metastorePort = hiveConf.getIntVar(ConfVars.METASTORE_SERVER_PORT); + tServer = startMetaStore(hiveConf); + waitForServer(metastorePort); + LOGGER.info("Hive metastore server has started at port:{}", metastorePort); + } - private static void waitForServer(int serverPort) { - final long endTime = System.currentTimeMillis() + (long) 20000; + private static void waitForServer(int serverPort) { + final long endTime = System.currentTimeMillis() + (long) 20000; + try { + while (System.currentTimeMillis() < endTime) { + SocketChannel channel = null; try { - while (System.currentTimeMillis() < endTime) { - SocketChannel channel = null; - try { - channel = SocketChannel.open(new InetSocketAddress("localhost", serverPort)); - LOGGER.info("Server started at port {}", serverPort); - return; - } catch (ConnectException e) { - LOGGER.info("Waiting for server to start..."); - Thread.sleep(1000); - } finally { - if (channel != null) { - channel.close(); - } - } - } - } catch (Exception e) { - throw new RuntimeException("Fail to start server at port: " + serverPort); + channel = SocketChannel.open(new InetSocketAddress("localhost", serverPort)); + LOGGER.info("Server started at port {}", serverPort); + return; + } catch (ConnectException e) { + LOGGER.info("Waiting for server to start..."); + Thread.sleep(1000); + } finally { + if (channel != null) { + channel.close(); + } } + } + } catch (Exception e) { + throw new RuntimeException("Fail to start server at port: " + serverPort); } - - public void stop() { - if (tServer != null) { - try { - tServer.stop(); - LOGGER.info("MetaStore sever has stop"); - } catch (Exception e) { - LOGGER.error("stop meta store failed", e); - } - } - if (executorService != null) { - executorService.shutdownNow(); - } - try { - FileUtils.deleteDirectory(new File(hiveLocation)); - } catch (IOException e) { - LOGGER.warn("fail to clear hive location: " + hiveLocation); - } + } + + public void stop() { + if (tServer != null) { + try { + tServer.stop(); + LOGGER.info("MetaStore sever has stop"); + } catch (Exception e) { + LOGGER.error("stop meta store failed", e); + } } - - private TServer startMetaStore(HiveConf conf) throws IOException { - try { - int port = conf.getIntVar(HiveConf.ConfVars.METASTORE_SERVER_PORT); - int minWorkerThreads = conf.getIntVar(HiveConf.ConfVars.METASTORESERVERMINTHREADS); - int maxWorkerThreads = conf.getIntVar(HiveConf.ConfVars.METASTORESERVERMAXTHREADS); - boolean tcpKeepAlive = conf.getBoolVar(HiveConf.ConfVars.METASTORE_TCP_KEEP_ALIVE); - boolean useFramedTransport = conf.getBoolVar(HiveConf.ConfVars.METASTORE_USE_THRIFT_FRAMED_TRANSPORT); - - InetSocketAddress address = new InetSocketAddress("localhost", port); - TServerTransport serverTransport = tcpKeepAlive ? new TServerSocketKeepAlive(address) : - new TServerSocket(address); - - TProcessor processor; - TTransportFactory transFactory; - HiveMetaStore.HMSHandler baseHandler = new HiveMetaStore.HMSHandler( - "Test metastore handler", conf, false); - IHMSHandler handler = RetryingHMSHandler.getProxy(conf, baseHandler, true); - - transFactory = useFramedTransport ? new TFramedTransport.Factory() : new TTransportFactory(); - processor = new TSetIpAddressProcessor<>(handler); - - TThreadPoolServer.Args args = new TThreadPoolServer.Args(serverTransport).processor(processor) - .transportFactory(transFactory).protocolFactory(new TBinaryProtocol.Factory()) - .minWorkerThreads(minWorkerThreads).maxWorkerThreads(maxWorkerThreads); - - final TServer tServer = new TThreadPoolServer(args); - executorService.submit(tServer::serve); - return tServer; - } catch (Throwable x) { - throw new IOException(x); - } + if (executorService != null) { + executorService.shutdownNow(); + } + try { + FileUtils.deleteDirectory(new File(hiveLocation)); + } catch (IOException e) { + LOGGER.warn("fail to clear hive location: " + hiveLocation); } + } + + private TServer startMetaStore(HiveConf conf) throws IOException { + try { + int port = conf.getIntVar(HiveConf.ConfVars.METASTORE_SERVER_PORT); + int minWorkerThreads = conf.getIntVar(HiveConf.ConfVars.METASTORESERVERMINTHREADS); + int maxWorkerThreads = conf.getIntVar(HiveConf.ConfVars.METASTORESERVERMAXTHREADS); + boolean tcpKeepAlive = conf.getBoolVar(HiveConf.ConfVars.METASTORE_TCP_KEEP_ALIVE); + boolean useFramedTransport = + conf.getBoolVar(HiveConf.ConfVars.METASTORE_USE_THRIFT_FRAMED_TRANSPORT); + + InetSocketAddress address = new InetSocketAddress("localhost", port); + TServerTransport serverTransport = + tcpKeepAlive ? new TServerSocketKeepAlive(address) : new TServerSocket(address); + + TProcessor processor; + TTransportFactory transFactory; + HiveMetaStore.HMSHandler baseHandler = + new HiveMetaStore.HMSHandler("Test metastore handler", conf, false); + IHMSHandler handler = RetryingHMSHandler.getProxy(conf, baseHandler, true); + + transFactory = useFramedTransport ? new TFramedTransport.Factory() : new TTransportFactory(); + processor = new TSetIpAddressProcessor<>(handler); + + TThreadPoolServer.Args args = + new TThreadPoolServer.Args(serverTransport) + .processor(processor) + .transportFactory(transFactory) + .protocolFactory(new TBinaryProtocol.Factory()) + .minWorkerThreads(minWorkerThreads) + .maxWorkerThreads(maxWorkerThreads); + + final TServer tServer = new TThreadPoolServer(args); + executorService.submit(tServer::serve); + return tServer; + } catch (Throwable x) { + throw new IOException(x); + } + } - private static final class TServerSocketKeepAlive extends TServerSocket { - public TServerSocketKeepAlive(InetSocketAddress address) throws TTransportException { - super(address, 0); - } + private static final class TServerSocketKeepAlive extends TServerSocket { + public TServerSocketKeepAlive(InetSocketAddress address) throws TTransportException { + super(address, 0); + } - @Override - protected TSocket acceptImpl() throws TTransportException { - TSocket ts = super.acceptImpl(); - try { - ts.getSocket().setKeepAlive(true); - } catch (SocketException e) { - throw new TTransportException(e); - } - return ts; - } + @Override + protected TSocket acceptImpl() throws TTransportException { + TSocket ts = super.acceptImpl(); + try { + ts.getSocket().setKeepAlive(true); + } catch (SocketException e) { + throw new TTransportException(e); + } + return ts; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/GeaFlowEngineContext.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/GeaFlowEngineContext.java index 31e53c009..7ef46418f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/GeaFlowEngineContext.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/GeaFlowEngineContext.java @@ -28,6 +28,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Stream; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; @@ -51,87 +52,86 @@ public class GeaFlowEngineContext extends HoodieEngineContext { - private GeaFlowEngineContext(SerializableConfiguration conf, - TaskContextSupplier taskContextSupplier) { - super(conf, taskContextSupplier); - } - - public static GeaFlowEngineContext create(RuntimeContext context, Configuration tableConf) { - SerializableConfiguration hadoopConf = - new SerializableConfiguration(FileConnectorUtil.toHadoopConf(tableConf)); - TaskContextSupplier taskContextSupplier = new GeaFlowTaskContextSupplier(context); - return new GeaFlowEngineContext(hadoopConf, taskContextSupplier); - } - - @Override - public HoodieAccumulator newAccumulator() { - return HoodieAtomicLongAccumulator.create(); - } - - @Override - public HoodieData emptyHoodieData() { - return HoodieListData.eager(Collections.emptyList()); - } - - @Override - public HoodieData parallelize(List list, int i) { - return HoodieListData.eager(list); - } - - @Override - public List map(List list, SerializableFunction func, int i) { - return list.stream().parallel().map(throwingMapWrapper(func)).collect(toList()); - } - - @Override - public List mapToPairAndReduceByKey(List list, - SerializablePairFunction serializablePairFunction, - SerializableBiFunction serializableBiFunction, int i) { - throw new GeaFlowDSLException("Write hudi is not support"); - } - - @Override - public Stream> mapPartitionsToPairAndReduceByKey(Stream stream, - SerializablePairFlatMapFunction, K, V> serializablePairFlatMapFunction, - SerializableBiFunction serializableBiFunction, - int i) { - throw new GeaFlowDSLException("Write hudi is not support"); - } - - @Override - public List reduceByKey(List> list, SerializableBiFunction serializableBiFunction, - int i) { - throw new GeaFlowDSLException("Write hudi is not support"); - } - - @Override - public List flatMap(List list, SerializableFunction> func, int i) { - return list.stream().parallel().flatMap(throwingFlatMapWrapper(func)).collect(toList()); - } - - @Override - public void foreach(List list, SerializableConsumer serializableConsumer, int i) { - throw new GeaFlowDSLException("Write hudi is not support"); - } - - @Override - public Map mapToPair(List list, SerializablePairFunction serializablePairFunction, - Integer integer) { - throw new GeaFlowDSLException("Write hudi is not support"); - } - - @Override - public void setProperty(EngineProperty engineProperty, String s) { - - } - - @Override - public Option getProperty(EngineProperty engineProperty) { - return Option.empty(); - } - - @Override - public void setJobStatus(String s, String s1) { - - } + private GeaFlowEngineContext( + SerializableConfiguration conf, TaskContextSupplier taskContextSupplier) { + super(conf, taskContextSupplier); + } + + public static GeaFlowEngineContext create(RuntimeContext context, Configuration tableConf) { + SerializableConfiguration hadoopConf = + new SerializableConfiguration(FileConnectorUtil.toHadoopConf(tableConf)); + TaskContextSupplier taskContextSupplier = new GeaFlowTaskContextSupplier(context); + return new GeaFlowEngineContext(hadoopConf, taskContextSupplier); + } + + @Override + public HoodieAccumulator newAccumulator() { + return HoodieAtomicLongAccumulator.create(); + } + + @Override + public HoodieData emptyHoodieData() { + return HoodieListData.eager(Collections.emptyList()); + } + + @Override + public HoodieData parallelize(List list, int i) { + return HoodieListData.eager(list); + } + + @Override + public List map(List list, SerializableFunction func, int i) { + return list.stream().parallel().map(throwingMapWrapper(func)).collect(toList()); + } + + @Override + public List mapToPairAndReduceByKey( + List list, + SerializablePairFunction serializablePairFunction, + SerializableBiFunction serializableBiFunction, + int i) { + throw new GeaFlowDSLException("Write hudi is not support"); + } + + @Override + public Stream> mapPartitionsToPairAndReduceByKey( + Stream stream, + SerializablePairFlatMapFunction, K, V> serializablePairFlatMapFunction, + SerializableBiFunction serializableBiFunction, + int i) { + throw new GeaFlowDSLException("Write hudi is not support"); + } + + @Override + public List reduceByKey( + List> list, SerializableBiFunction serializableBiFunction, int i) { + throw new GeaFlowDSLException("Write hudi is not support"); + } + + @Override + public List flatMap(List list, SerializableFunction> func, int i) { + return list.stream().parallel().flatMap(throwingFlatMapWrapper(func)).collect(toList()); + } + + @Override + public void foreach(List list, SerializableConsumer serializableConsumer, int i) { + throw new GeaFlowDSLException("Write hudi is not support"); + } + + @Override + public Map mapToPair( + List list, SerializablePairFunction serializablePairFunction, Integer integer) { + throw new GeaFlowDSLException("Write hudi is not support"); + } + + @Override + public void setProperty(EngineProperty engineProperty, String s) {} + + @Override + public Option getProperty(EngineProperty engineProperty) { + return Option.empty(); + } + + @Override + public void setJobStatus(String s, String s1) {} } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/GeaFlowHoodieTableFileIndex.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/GeaFlowHoodieTableFileIndex.java index 3f2cc0cd3..3182012b4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/GeaFlowHoodieTableFileIndex.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/GeaFlowHoodieTableFileIndex.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ConnectorConfigKeys; @@ -36,62 +37,77 @@ public class GeaFlowHoodieTableFileIndex extends BaseHoodieTableFileIndex { + private GeaFlowHoodieTableFileIndex( + HoodieEngineContext engineContext, + HoodieTableMetaClient metaClient, + TypedProperties configProperties, + HoodieTableQueryType queryType, + List queryPaths, + Option specifiedQueryInstant, + boolean shouldIncludePendingCommits, + boolean shouldValidateInstant, + FileStatusCache fileStatusCache, + boolean shouldListLazily) { + super( + engineContext, + metaClient, + configProperties, + queryType, + queryPaths, + specifiedQueryInstant, + shouldIncludePendingCommits, + shouldValidateInstant, + fileStatusCache, + shouldListLazily); + } - private GeaFlowHoodieTableFileIndex(HoodieEngineContext engineContext, - HoodieTableMetaClient metaClient, - TypedProperties configProperties, - HoodieTableQueryType queryType, - List queryPaths, - Option specifiedQueryInstant, - boolean shouldIncludePendingCommits, boolean shouldValidateInstant, - FileStatusCache fileStatusCache, boolean shouldListLazily) { - super(engineContext, metaClient, configProperties, queryType, queryPaths, specifiedQueryInstant, - shouldIncludePendingCommits, shouldValidateInstant, fileStatusCache, shouldListLazily); - } - - public static GeaFlowHoodieTableFileIndex create(RuntimeContext context, Configuration tableConf) { - HoodieEngineContext engineContext = GeaFlowEngineContext.create(context, tableConf); - String path = tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_FILE_PATH); + public static GeaFlowHoodieTableFileIndex create( + RuntimeContext context, Configuration tableConf) { + HoodieEngineContext engineContext = GeaFlowEngineContext.create(context, tableConf); + String path = tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_FILE_PATH); - HoodieTableMetaClient metaClient = HoodieTableMetaClient.builder() + HoodieTableMetaClient metaClient = + HoodieTableMetaClient.builder() .setMetaserverConfig(tableConf.getConfigMap()) .setBasePath(path) .setConf(FileConnectorUtil.toHadoopConf(tableConf)) .build(); - TypedProperties configProperties = HoodieUtil.toTypeProperties(tableConf.getConfigMap()); - - FileStatusCache noCache = new FileStatusCache() { - @Override - public Option get(Path path) { - return Option.empty(); - } - - @Override - public void put(Path path, FileStatus[] fileStatuses) { + TypedProperties configProperties = HoodieUtil.toTypeProperties(tableConf.getConfigMap()); - } + FileStatusCache noCache = + new FileStatusCache() { + @Override + public Option get(Path path) { + return Option.empty(); + } - @Override - public void invalidate() { + @Override + public void put(Path path, FileStatus[] fileStatuses) {} - } + @Override + public void invalidate() {} }; - return new GeaFlowHoodieTableFileIndex( - engineContext, metaClient, - configProperties, HoodieTableQueryType.SNAPSHOT, - Collections.singletonList(new Path(path)), - Option.empty(), false, false, - noCache, false); - } + return new GeaFlowHoodieTableFileIndex( + engineContext, + metaClient, + configProperties, + HoodieTableQueryType.SNAPSHOT, + Collections.singletonList(new Path(path)), + Option.empty(), + false, + false, + noCache, + false); + } - @Override - protected Object[] doParsePartitionColumnValues(String[] strings, String s) { - return new Object[0]; - } + @Override + protected Object[] doParsePartitionColumnValues(String[] strings, String s) { + return new Object[0]; + } - @Override - public List getAllQueryPartitionPaths() { - return super.getAllQueryPartitionPaths(); - } + @Override + public List getAllQueryPartitionPaths() { + return super.getAllQueryPartitionPaths(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/GeaFlowTaskContextSupplier.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/GeaFlowTaskContextSupplier.java index 2a94eb9f9..45f30a2fe 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/GeaFlowTaskContextSupplier.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/GeaFlowTaskContextSupplier.java @@ -21,39 +21,38 @@ import java.util.Objects; import java.util.function.Supplier; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.hudi.common.engine.EngineProperty; import org.apache.hudi.common.engine.TaskContextSupplier; import org.apache.hudi.common.util.Option; -/** - * The {@link TaskContextSupplier} implementation for geaflow. - */ +/** The {@link TaskContextSupplier} implementation for geaflow. */ public class GeaFlowTaskContextSupplier extends TaskContextSupplier { - private final RuntimeContext context; + private final RuntimeContext context; - public GeaFlowTaskContextSupplier(RuntimeContext context) { - this.context = Objects.requireNonNull(context); - } + public GeaFlowTaskContextSupplier(RuntimeContext context) { + this.context = Objects.requireNonNull(context); + } - @Override - public Supplier getPartitionIdSupplier() { - return () -> context.getTaskArgs().getTaskIndex(); - } + @Override + public Supplier getPartitionIdSupplier() { + return () -> context.getTaskArgs().getTaskIndex(); + } - @Override - public Supplier getStageIdSupplier() { - return () -> context.getTaskArgs().getTaskIndex(); - } + @Override + public Supplier getStageIdSupplier() { + return () -> context.getTaskArgs().getTaskIndex(); + } - @Override - public Supplier getAttemptIdSupplier() { - return context::getPipelineId; - } + @Override + public Supplier getAttemptIdSupplier() { + return context::getPipelineId; + } - @Override - public Option getProperty(EngineProperty engineProperty) { - return Option.empty(); - } + @Override + public Option getProperty(EngineProperty engineProperty) { + return Option.empty(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/HoodieTableConnector.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/HoodieTableConnector.java index 389e995f1..cf7b22294 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/HoodieTableConnector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/HoodieTableConnector.java @@ -25,15 +25,15 @@ public class HoodieTableConnector implements TableReadableConnector { - private static final String HUDI = "HUDI"; + private static final String HUDI = "HUDI"; - @Override - public String getType() { - return HUDI; - } + @Override + public String getType() { + return HUDI; + } - @Override - public TableSource createSource(Configuration conf) { - return new HoodieTableSource(); - } + @Override + public TableSource createSource(Configuration conf) { + return new HoodieTableSource(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/HoodieTableSource.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/HoodieTableSource.java index 26bb2640e..de9a63aa5 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/HoodieTableSource.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/HoodieTableSource.java @@ -24,6 +24,7 @@ import java.util.Iterator; import java.util.List; import java.util.Optional; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.window.WindowType; import org.apache.geaflow.common.config.Configuration; @@ -49,70 +50,70 @@ public class HoodieTableSource implements TableSource { - private static final Logger LOGGER = LoggerFactory.getLogger(HoodieTableSource.class); - - private Configuration tableConf; + private static final Logger LOGGER = LoggerFactory.getLogger(HoodieTableSource.class); - private TableSchema tableSchema; + private Configuration tableConf; - private GeaFlowHoodieTableFileIndex fileIndex; + private TableSchema tableSchema; - private Path basePath; + private GeaFlowHoodieTableFileIndex fileIndex; - @Override - public void init(Configuration tableConf, TableSchema tableSchema) { - this.tableConf = tableConf; - this.tableSchema = tableSchema; - } - - @Override - public void open(RuntimeContext context) { - this.fileIndex = GeaFlowHoodieTableFileIndex.create(context, tableConf); - this.basePath = new Path(tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_FILE_PATH)); - FileSystem fileSystem = FileConnectorUtil.getHdfsFileSystem(tableConf); - this.basePath = fileSystem.makeQualified(basePath); - LOGGER.info("open hudi table source: {}", basePath); - } + private Path basePath; - @Override - public List listPartitions() { - List paths = fileIndex.getAllQueryPartitionPaths(); - List fileSplits = new ArrayList<>(); - for (PartitionPath path : paths) { - fileSplits.add(new FileSplit(basePath.toString(), path.getPath())); - } - return fileSplits; - } + @Override + public void init(Configuration tableConf, TableSchema tableSchema) { + this.tableConf = tableConf; + this.tableSchema = tableSchema; + } - @Override - public List listPartitions(int parallelism) { - return listPartitions(); - } + @Override + public void open(RuntimeContext context) { + this.fileIndex = GeaFlowHoodieTableFileIndex.create(context, tableConf); + this.basePath = new Path(tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_FILE_PATH)); + FileSystem fileSystem = FileConnectorUtil.getHdfsFileSystem(tableConf); + this.basePath = fileSystem.makeQualified(basePath); + LOGGER.info("open hudi table source: {}", basePath); + } - @Override - public TableDeserializer getDeserializer(Configuration conf) { - // no op here as hoodie table source return row already, deserializer is no need. - return null; + @Override + public List listPartitions() { + List paths = fileIndex.getAllQueryPartitionPaths(); + List fileSplits = new ArrayList<>(); + for (PartitionPath path : paths) { + fileSplits.add(new FileSplit(basePath.toString(), path.getPath())); } - - @Override - public FetchData fetch(Partition partition, Optional startOffset, FetchWindow windowInfo) - throws IOException { - if (windowInfo.getType() == WindowType.ALL_WINDOW) { - FileSplit split = (FileSplit) partition; - ParquetFormat format = new ParquetFormat(); - format.init(tableConf, tableSchema, split); - Iterator iterator = format.batchRead(); - Offset nextOffset = new FileOffset(-1L); - format.close(); - return (FetchData) FetchData.createBatchFetch(iterator, nextOffset); - } else { - throw new GeaFlowDSLException("Hudi table source not support window:{}", windowInfo.getType()); - } + return fileSplits; + } + + @Override + public List listPartitions(int parallelism) { + return listPartitions(); + } + + @Override + public TableDeserializer getDeserializer(Configuration conf) { + // no op here as hoodie table source return row already, deserializer is no need. + return null; + } + + @Override + public FetchData fetch( + Partition partition, Optional startOffset, FetchWindow windowInfo) + throws IOException { + if (windowInfo.getType() == WindowType.ALL_WINDOW) { + FileSplit split = (FileSplit) partition; + ParquetFormat format = new ParquetFormat(); + format.init(tableConf, tableSchema, split); + Iterator iterator = format.batchRead(); + Offset nextOffset = new FileOffset(-1L); + format.close(); + return (FetchData) FetchData.createBatchFetch(iterator, nextOffset); + } else { + throw new GeaFlowDSLException( + "Hudi table source not support window:{}", windowInfo.getType()); } + } - @Override - public void close() { - - } + @Override + public void close() {} } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/HoodieUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/HoodieUtil.java index 8ad9f7ed2..0fe7db4aa 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/HoodieUtil.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/main/java/org/apache/geaflow/dsl/connector/hudi/HoodieUtil.java @@ -20,15 +20,16 @@ package org.apache.geaflow.dsl.connector.hudi; import java.util.Map; + import org.apache.hudi.common.config.TypedProperties; public class HoodieUtil { - public static TypedProperties toTypeProperties(Map config) { - TypedProperties typedProperties = new TypedProperties(); - for (Map.Entry entry : config.entrySet()) { - typedProperties.setProperty(entry.getKey(), entry.getValue()); - } - return typedProperties; + public static TypedProperties toTypeProperties(Map config) { + TypedProperties typedProperties = new TypedProperties(); + for (Map.Entry entry : config.entrySet()) { + typedProperties.setProperty(entry.getKey(), entry.getValue()); } + return typedProperties; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/test/java/org/apache/geaflow/dsl/connector/hudi/HoodieTableConnectorTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/test/java/org/apache/geaflow/dsl/connector/hudi/HoodieTableConnectorTest.java index 95d126263..35b39f82b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/test/java/org/apache/geaflow/dsl/connector/hudi/HoodieTableConnectorTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-hudi/src/test/java/org/apache/geaflow/dsl/connector/hudi/HoodieTableConnectorTest.java @@ -19,7 +19,6 @@ package org.apache.geaflow.dsl.connector.hudi; -import com.google.common.collect.Lists; import java.io.File; import java.io.IOException; import java.util.ArrayList; @@ -30,6 +29,7 @@ import java.util.Optional; import java.util.Properties; import java.util.function.Supplier; + import org.apache.avro.Schema; import org.apache.avro.generic.GenericRecord; import org.apache.avro.generic.GenericRecordBuilder; @@ -69,127 +69,126 @@ import org.testng.Assert; import org.testng.annotations.Test; +import com.google.common.collect.Lists; + public class HoodieTableConnectorTest { - private final StructType dataSchema = new StructType( - new TableField("id", Types.INTEGER, false), - new TableField("name", Types.BINARY_STRING), - new TableField("price", Types.DOUBLE) - ); - - private final StructType partitionSchema = new StructType( - new TableField("dt", Types.BINARY_STRING, false) - ); - - private final TableSchema tableSchema = new TableSchema(dataSchema, partitionSchema); - - private int commitNo = 1; - - @Test - public void testReadHudi() throws IOException { - String tmpDir = "/tmp/hudi/test/" + System.nanoTime(); - FileUtils.deleteQuietly(new File(tmpDir)); - writeData(tmpDir, - "1,a1,10", - "2,a2,12", - "3,a3,12", - "4,a4,15", - "5,a5,10"); - - TableConnector tableConnector = ConnectorFactory.loadConnector("hudi"); - Assert.assertEquals(tableConnector.getType().toLowerCase(Locale.ROOT), "hudi"); - TableReadableConnector readableConnector = (TableReadableConnector) tableConnector; - - Map tableConfMap = new HashMap<>(); - tableConfMap.put(ConnectorConfigKeys.GEAFLOW_DSL_FILE_PATH.getKey(), tmpDir); - Configuration tableConf = new Configuration(tableConfMap); - TableSource tableSource = readableConnector.createSource(tableConf); - tableSource.init(tableConf, tableSchema); - - tableSource.open(new DefaultRuntimeContext(tableConf)); - - List partitions = tableSource.listPartitions(); - - List readRows = new ArrayList<>(); - for (Partition partition : partitions) { - FetchData rows = tableSource.fetch(partition, Optional.empty(), new AllFetchWindow(-1)); - readRows.addAll(Lists.newArrayList(rows.getDataIterator())); - } - Assert.assertEquals(StringUtils.join(readRows, "\n"), - "[1, a1, 10.0]\n" - + "[2, a2, 12.0]\n" - + "[3, a3, 12.0]\n" - + "[4, a4, 15.0]\n" - + "[5, a5, 10.0]"); - } + private final StructType dataSchema = + new StructType( + new TableField("id", Types.INTEGER, false), + new TableField("name", Types.BINARY_STRING), + new TableField("price", Types.DOUBLE)); + + private final StructType partitionSchema = + new StructType(new TableField("dt", Types.BINARY_STRING, false)); + + private final TableSchema tableSchema = new TableSchema(dataSchema, partitionSchema); + + private int commitNo = 1; - private void writeData(String path, String... lines) throws IOException { - org.apache.hadoop.conf.Configuration hadoopConf = new org.apache.hadoop.conf.Configuration(); - HoodieJavaEngineContext context = new HoodieJavaEngineContext(hadoopConf, new TestTaskContextSupplier()); + @Test + public void testReadHudi() throws IOException { + String tmpDir = "/tmp/hudi/test/" + System.nanoTime(); + FileUtils.deleteQuietly(new File(tmpDir)); + writeData(tmpDir, "1,a1,10", "2,a2,12", "3,a3,12", "4,a4,15", "5,a5,10"); - HoodieTableMetaClient.PropertyBuilder builder = - HoodieTableMetaClient.withPropertyBuilder() - .setDatabaseName("default") - .setTableName("h0") - .setTableType(HoodieTableType.COPY_ON_WRITE) - .setPayloadClass(HoodieAvroPayload.class); - Properties processedProperties = builder.fromProperties(new Properties()).build(); + TableConnector tableConnector = ConnectorFactory.loadConnector("hudi"); + Assert.assertEquals(tableConnector.getType().toLowerCase(Locale.ROOT), "hudi"); + TableReadableConnector readableConnector = (TableReadableConnector) tableConnector; - HoodieTableMetaClient.initTableAndGetMetaClient(hadoopConf, path, - processedProperties); + Map tableConfMap = new HashMap<>(); + tableConfMap.put(ConnectorConfigKeys.GEAFLOW_DSL_FILE_PATH.getKey(), tmpDir); + Configuration tableConf = new Configuration(tableConfMap); + TableSource tableSource = readableConnector.createSource(tableConf); + tableSource.init(tableConf, tableSchema); - Schema schema = ParquetFormat.convertToAvroSchema(dataSchema, false); - HoodieWriteConfig writeConfig = HoodieWriteConfig.newBuilder() + tableSource.open(new DefaultRuntimeContext(tableConf)); + + List partitions = tableSource.listPartitions(); + + List readRows = new ArrayList<>(); + for (Partition partition : partitions) { + FetchData rows = tableSource.fetch(partition, Optional.empty(), new AllFetchWindow(-1)); + readRows.addAll(Lists.newArrayList(rows.getDataIterator())); + } + Assert.assertEquals( + StringUtils.join(readRows, "\n"), + "[1, a1, 10.0]\n" + + "[2, a2, 12.0]\n" + + "[3, a3, 12.0]\n" + + "[4, a4, 15.0]\n" + + "[5, a5, 10.0]"); + } + + private void writeData(String path, String... lines) throws IOException { + org.apache.hadoop.conf.Configuration hadoopConf = new org.apache.hadoop.conf.Configuration(); + HoodieJavaEngineContext context = + new HoodieJavaEngineContext(hadoopConf, new TestTaskContextSupplier()); + + HoodieTableMetaClient.PropertyBuilder builder = + HoodieTableMetaClient.withPropertyBuilder() + .setDatabaseName("default") + .setTableName("h0") + .setTableType(HoodieTableType.COPY_ON_WRITE) + .setPayloadClass(HoodieAvroPayload.class); + Properties processedProperties = builder.fromProperties(new Properties()).build(); + + HoodieTableMetaClient.initTableAndGetMetaClient(hadoopConf, path, processedProperties); + + Schema schema = ParquetFormat.convertToAvroSchema(dataSchema, false); + HoodieWriteConfig writeConfig = + HoodieWriteConfig.newBuilder() .withEngineType(EngineType.JAVA) .withEmbeddedTimelineServerEnabled(false) .withPath(path) .withSchema(schema.toString()) .build(); - HoodieJavaWriteClient writeClient = new HoodieJavaWriteClient(context, writeConfig); - - List records = new ArrayList<>(); - for (String line : lines) { - String[] fields = line.split(","); - GenericRecord record = new GenericRecordBuilder(schema) - .set("id", Integer.parseInt(fields[0])) - .set("name", fields[1]) - .set("price", Double.parseDouble(fields[2])) - .build(); - HoodieRecordPayload hoodiePayload = new DefaultHoodieRecordPayload(Option.of(record)); - HoodieKey key = new HoodieKey(fields[0], "2023/07/04"); - HoodieAvroRecord hoodieRecord = new HoodieAvroRecord(key, hoodiePayload); - records.add(hoodieRecord); - } - - String commitTime = String.format("%09d", commitNo++); - writeClient.startCommitWithTime(commitTime); - writeClient.insert(records, commitTime); + HoodieJavaWriteClient writeClient = new HoodieJavaWriteClient(context, writeConfig); + + List records = new ArrayList<>(); + for (String line : lines) { + String[] fields = line.split(","); + GenericRecord record = + new GenericRecordBuilder(schema) + .set("id", Integer.parseInt(fields[0])) + .set("name", fields[1]) + .set("price", Double.parseDouble(fields[2])) + .build(); + HoodieRecordPayload hoodiePayload = new DefaultHoodieRecordPayload(Option.of(record)); + HoodieKey key = new HoodieKey(fields[0], "2023/07/04"); + HoodieAvroRecord hoodieRecord = new HoodieAvroRecord(key, hoodiePayload); + records.add(hoodieRecord); } - private static class TestTaskContextSupplier extends TaskContextSupplier { + String commitTime = String.format("%09d", commitNo++); + writeClient.startCommitWithTime(commitTime); + writeClient.insert(records, commitTime); + } - private final int partitionId = 0; - private final int stageId = 0; - private final long attemptId = 0; + private static class TestTaskContextSupplier extends TaskContextSupplier { - @Override - public Supplier getPartitionIdSupplier() { - return () -> partitionId; - } + private final int partitionId = 0; + private final int stageId = 0; + private final long attemptId = 0; - @Override - public Supplier getStageIdSupplier() { - return () -> stageId; - } + @Override + public Supplier getPartitionIdSupplier() { + return () -> partitionId; + } - @Override - public Supplier getAttemptIdSupplier() { - return () -> attemptId; - } + @Override + public Supplier getStageIdSupplier() { + return () -> stageId; + } + + @Override + public Supplier getAttemptIdSupplier() { + return () -> attemptId; + } - @Override - public Option getProperty(EngineProperty prop) { - return Option.empty(); - } + @Override + public Option getProperty(EngineProperty prop) { + return Option.empty(); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/main/java/org/apache/geaflow/dsl/connector/jdbc/JDBCConfigKeys.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/main/java/org/apache/geaflow/dsl/connector/jdbc/JDBCConfigKeys.java index 9cdb213c7..b90df7ec8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/main/java/org/apache/geaflow/dsl/connector/jdbc/JDBCConfigKeys.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/main/java/org/apache/geaflow/dsl/connector/jdbc/JDBCConfigKeys.java @@ -23,50 +23,46 @@ import org.apache.geaflow.common.config.ConfigKeys; public class JDBCConfigKeys { - public static final ConfigKey GEAFLOW_DSL_JDBC_DRIVER = ConfigKeys - .key("geaflow.dsl.jdbc.driver") - .noDefaultValue() - .description("The JDBC driver."); + public static final ConfigKey GEAFLOW_DSL_JDBC_DRIVER = + ConfigKeys.key("geaflow.dsl.jdbc.driver").noDefaultValue().description("The JDBC driver."); - public static final ConfigKey GEAFLOW_DSL_JDBC_URL = ConfigKeys - .key("geaflow.dsl.jdbc.url") - .noDefaultValue() - .description("The database URL."); + public static final ConfigKey GEAFLOW_DSL_JDBC_URL = + ConfigKeys.key("geaflow.dsl.jdbc.url").noDefaultValue().description("The database URL."); - public static final ConfigKey GEAFLOW_DSL_JDBC_USERNAME = ConfigKeys - .key("geaflow.dsl.jdbc.username") - .noDefaultValue() - .description("The database username."); + public static final ConfigKey GEAFLOW_DSL_JDBC_USERNAME = + ConfigKeys.key("geaflow.dsl.jdbc.username") + .noDefaultValue() + .description("The database username."); - public static final ConfigKey GEAFLOW_DSL_JDBC_PASSWORD = ConfigKeys - .key("geaflow.dsl.jdbc.password") - .noDefaultValue() - .description("The database password."); + public static final ConfigKey GEAFLOW_DSL_JDBC_PASSWORD = + ConfigKeys.key("geaflow.dsl.jdbc.password") + .noDefaultValue() + .description("The database password."); - public static final ConfigKey GEAFLOW_DSL_JDBC_TABLE_NAME = ConfigKeys - .key("geaflow.dsl.jdbc.table.name") - .noDefaultValue() - .description("The table name."); + public static final ConfigKey GEAFLOW_DSL_JDBC_TABLE_NAME = + ConfigKeys.key("geaflow.dsl.jdbc.table.name").noDefaultValue().description("The table name."); - public static final ConfigKey GEAFLOW_DSL_JDBC_PARTITION_NUM = ConfigKeys - .key("geaflow.dsl.jdbc.partition.num") - .defaultValue(1L) - .description("The JDBC partition number, default 1."); + public static final ConfigKey GEAFLOW_DSL_JDBC_PARTITION_NUM = + ConfigKeys.key("geaflow.dsl.jdbc.partition.num") + .defaultValue(1L) + .description("The JDBC partition number, default 1."); - public static final ConfigKey GEAFLOW_DSL_JDBC_PARTITION_COLUMN = ConfigKeys - .key("geaflow.dsl.jdbc.partition.column") - .defaultValue("id") - .description("The JDBC partition column."); + public static final ConfigKey GEAFLOW_DSL_JDBC_PARTITION_COLUMN = + ConfigKeys.key("geaflow.dsl.jdbc.partition.column") + .defaultValue("id") + .description("The JDBC partition column."); - public static final ConfigKey GEAFLOW_DSL_JDBC_PARTITION_LOWERBOUND = ConfigKeys - .key("geaflow.dsl.jdbc.partition.lowerbound") - .defaultValue(0L) - .description("The lowerbound of JDBC partition, just used to decide the partition stride, " - + "not for filtering the rows in table."); + public static final ConfigKey GEAFLOW_DSL_JDBC_PARTITION_LOWERBOUND = + ConfigKeys.key("geaflow.dsl.jdbc.partition.lowerbound") + .defaultValue(0L) + .description( + "The lowerbound of JDBC partition, just used to decide the partition stride, " + + "not for filtering the rows in table."); - public static final ConfigKey GEAFLOW_DSL_JDBC_PARTITION_UPPERBOUND = ConfigKeys - .key("geaflow.dsl.jdbc.partition.upperbound") - .defaultValue(0L) - .description("The upperbound of JDBC partition, just used to decide the partition stride, " - + "not for filtering the rows in table."); + public static final ConfigKey GEAFLOW_DSL_JDBC_PARTITION_UPPERBOUND = + ConfigKeys.key("geaflow.dsl.jdbc.partition.upperbound") + .defaultValue(0L) + .description( + "The upperbound of JDBC partition, just used to decide the partition stride, " + + "not for filtering the rows in table."); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/main/java/org/apache/geaflow/dsl/connector/jdbc/JDBCTableConnector.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/main/java/org/apache/geaflow/dsl/connector/jdbc/JDBCTableConnector.java index ca892f82f..d21de5199 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/main/java/org/apache/geaflow/dsl/connector/jdbc/JDBCTableConnector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/main/java/org/apache/geaflow/dsl/connector/jdbc/JDBCTableConnector.java @@ -27,20 +27,20 @@ public class JDBCTableConnector implements TableReadableConnector, TableWritableConnector { - public static final String TYPE = "JDBC"; + public static final String TYPE = "JDBC"; - @Override - public String getType() { - return TYPE; - } + @Override + public String getType() { + return TYPE; + } - @Override - public TableSource createSource(Configuration conf) { - return new JDBCTableSource(); - } + @Override + public TableSource createSource(Configuration conf) { + return new JDBCTableSource(); + } - @Override - public TableSink createSink(Configuration conf) { - return new JDBCTableSink(); - } + @Override + public TableSink createSink(Configuration conf) { + return new JDBCTableSink(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/main/java/org/apache/geaflow/dsl/connector/jdbc/JDBCTableSink.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/main/java/org/apache/geaflow/dsl/connector/jdbc/JDBCTableSink.java index 7a9114673..2b13d4484 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/main/java/org/apache/geaflow/dsl/connector/jdbc/JDBCTableSink.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/main/java/org/apache/geaflow/dsl/connector/jdbc/JDBCTableSink.java @@ -24,6 +24,7 @@ import java.sql.DriverManager; import java.sql.SQLException; import java.sql.Statement; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.common.data.Row; @@ -36,80 +37,80 @@ public class JDBCTableSink implements TableSink { - private static final Logger LOGGER = LoggerFactory.getLogger(JDBCTableSink.class); + private static final Logger LOGGER = LoggerFactory.getLogger(JDBCTableSink.class); - private Configuration tableConf; - private StructType schema; - private String driver; - private String url; - private String username; - private String password; - private String tableName; - private Connection connection; - private Statement statement; + private Configuration tableConf; + private StructType schema; + private String driver; + private String url; + private String username; + private String password; + private String tableName; + private Connection connection; + private Statement statement; - @Override - public void init(Configuration tableConf, StructType tableSchema) { - LOGGER.info("init jdbc sink with config: {}, \n schema: {}", tableConf, tableSchema); - this.tableConf = tableConf; - this.schema = tableSchema; + @Override + public void init(Configuration tableConf, StructType tableSchema) { + LOGGER.info("init jdbc sink with config: {}, \n schema: {}", tableConf, tableSchema); + this.tableConf = tableConf; + this.schema = tableSchema; - this.driver = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_DRIVER); - this.url = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_URL); - this.username = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_USERNAME); - this.password = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_PASSWORD); - this.tableName = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_TABLE_NAME); - } + this.driver = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_DRIVER); + this.url = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_URL); + this.username = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_USERNAME); + this.password = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_PASSWORD); + this.tableName = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_TABLE_NAME); + } - @Override - public void open(RuntimeContext context) { - try { - Class.forName(this.driver); - this.connection = DriverManager.getConnection(url, username, password); - this.connection.setAutoCommit(false); - this.statement = connection.createStatement(); - } catch (Exception e) { - throw new GeaFlowDSLException("failed to connect to database", e); - } + @Override + public void open(RuntimeContext context) { + try { + Class.forName(this.driver); + this.connection = DriverManager.getConnection(url, username, password); + this.connection.setAutoCommit(false); + this.statement = connection.createStatement(); + } catch (Exception e) { + throw new GeaFlowDSLException("failed to connect to database", e); } + } - @Override - public void write(Row row) throws IOException { - try { - JDBCUtils.insertIntoTable(this.statement, this.tableName, this.schema.getFields(), row); - } catch (SQLException e) { - throw new GeaFlowDSLException("failed to write to table: " + tableName, e); - } + @Override + public void write(Row row) throws IOException { + try { + JDBCUtils.insertIntoTable(this.statement, this.tableName, this.schema.getFields(), row); + } catch (SQLException e) { + throw new GeaFlowDSLException("failed to write to table: " + tableName, e); } + } - @Override - public void finish() throws IOException { - try { - connection.commit(); - } catch (SQLException e) { - LOGGER.error("failed to commit", e); - try { - connection.rollback(); - } catch (SQLException ex) { - throw new GeaFlowDSLException("failed to rollback", e); - } - } + @Override + public void finish() throws IOException { + try { + connection.commit(); + } catch (SQLException e) { + LOGGER.error("failed to commit", e); + try { + connection.rollback(); + } catch (SQLException ex) { + throw new GeaFlowDSLException("failed to rollback", e); + } } + } - @Override - public void close() { - try { - if (this.statement != null) { - this.statement.close(); - this.statement = null; - } - if (this.connection != null) { - this.connection.close(); - this.connection = null; - } - LOGGER.info("close"); - } catch (SQLException e) { - throw new GeaFlowDSLException("failed to close"); - } + @Override + public void close() { + try { + if (this.statement != null) { + this.statement.close(); + this.statement = null; + } + if (this.connection != null) { + this.connection.close(); + this.connection = null; + } + LOGGER.info("close"); + } catch (SQLException e) { + throw new GeaFlowDSLException("failed to close"); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/main/java/org/apache/geaflow/dsl/connector/jdbc/JDBCTableSource.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/main/java/org/apache/geaflow/dsl/connector/jdbc/JDBCTableSource.java index 85e96eaab..27eaf5dbd 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/main/java/org/apache/geaflow/dsl/connector/jdbc/JDBCTableSource.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/main/java/org/apache/geaflow/dsl/connector/jdbc/JDBCTableSource.java @@ -32,6 +32,7 @@ import java.util.Objects; import java.util.Optional; import java.util.stream.Stream; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.window.WindowType; import org.apache.geaflow.common.config.Configuration; @@ -52,232 +53,239 @@ public class JDBCTableSource implements TableSource { - private static final Logger LOGGER = LoggerFactory.getLogger(JDBCTableSource.class); - private Configuration tableConf; - private StructType schema; - private String driver; - private String url; - private String username; - private String password; - private String tableName; - private long partitionNum; - private String partitionColumn; - private long lowerBound; - private long upperBound; - private Map partitionConnectionMap = new HashMap<>(); - private Map partitionStatementMap = new HashMap<>(); + private static final Logger LOGGER = LoggerFactory.getLogger(JDBCTableSource.class); + private Configuration tableConf; + private StructType schema; + private String driver; + private String url; + private String username; + private String password; + private String tableName; + private long partitionNum; + private String partitionColumn; + private long lowerBound; + private long upperBound; + private Map partitionConnectionMap = new HashMap<>(); + private Map partitionStatementMap = new HashMap<>(); - @Override - public void init(Configuration tableConf, TableSchema tableSchema) { - LOGGER.info("prepare with config: {}, \n schema: {}", tableConf, tableSchema); - this.tableConf = tableConf; - this.schema = tableSchema; + @Override + public void init(Configuration tableConf, TableSchema tableSchema) { + LOGGER.info("prepare with config: {}, \n schema: {}", tableConf, tableSchema); + this.tableConf = tableConf; + this.schema = tableSchema; - this.driver = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_DRIVER); - this.url = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_URL); - this.username = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_USERNAME); - this.password = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_PASSWORD); - this.tableName = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_TABLE_NAME); - this.partitionNum = tableConf.getLong(JDBCConfigKeys.GEAFLOW_DSL_JDBC_PARTITION_NUM); - if (this.partitionNum <= 0) { - throw new GeaFlowDSLException("Invalid partition number: {}", partitionNum); - } - this.partitionColumn = - tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_PARTITION_COLUMN); - this.lowerBound = tableConf.getLong(JDBCConfigKeys.GEAFLOW_DSL_JDBC_PARTITION_LOWERBOUND); - this.upperBound = tableConf.getLong(JDBCConfigKeys.GEAFLOW_DSL_JDBC_PARTITION_UPPERBOUND); - // PartitionColumn, lowerbound and upperbound must all be specified. Otherwise, ignore. - if (Stream.of(partitionColumn, lowerBound, upperBound).allMatch(Objects::nonNull)) { - if (partitionNum == 1) { - // create one connection and ignore partitionColumn, lowerbound, upperbound. - } else if (partitionNum > 1 && lowerBound >= upperBound) { - throw new GeaFlowDSLException("Upperbound must greater than lowerbound" - + "(lowerbound:%d upperbound:%d).", lowerBound, upperBound); - } else { - partitionNum = Math.min(upperBound - lowerBound, partitionNum); - } - } + this.driver = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_DRIVER); + this.url = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_URL); + this.username = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_USERNAME); + this.password = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_PASSWORD); + this.tableName = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_TABLE_NAME); + this.partitionNum = tableConf.getLong(JDBCConfigKeys.GEAFLOW_DSL_JDBC_PARTITION_NUM); + if (this.partitionNum <= 0) { + throw new GeaFlowDSLException("Invalid partition number: {}", partitionNum); } - - @Override - public void open(RuntimeContext context) { - try { - Class.forName(this.driver); - } catch (ClassNotFoundException e) { - throw new GeaFlowDSLException("failed to load driver: {}.", this.driver); - } + this.partitionColumn = tableConf.getString(JDBCConfigKeys.GEAFLOW_DSL_JDBC_PARTITION_COLUMN); + this.lowerBound = tableConf.getLong(JDBCConfigKeys.GEAFLOW_DSL_JDBC_PARTITION_LOWERBOUND); + this.upperBound = tableConf.getLong(JDBCConfigKeys.GEAFLOW_DSL_JDBC_PARTITION_UPPERBOUND); + // PartitionColumn, lowerbound and upperbound must all be specified. Otherwise, ignore. + if (Stream.of(partitionColumn, lowerBound, upperBound).allMatch(Objects::nonNull)) { + if (partitionNum == 1) { + // create one connection and ignore partitionColumn, lowerbound, upperbound. + } else if (partitionNum > 1 && lowerBound >= upperBound) { + throw new GeaFlowDSLException( + "Upperbound must greater than lowerbound" + "(lowerbound:%d upperbound:%d).", + lowerBound, upperBound); + } else { + partitionNum = Math.min(upperBound - lowerBound, partitionNum); + } } + } - @Override - public List listPartitions() { - if (partitionNum <= 0) { - throw new GeaFlowDSLException("Invalid partition number: {}", partitionNum); - } - - if (partitionNum == 1) { - return Collections.singletonList(new JDBCPartition(tableName, "")); - } + @Override + public void open(RuntimeContext context) { + try { + Class.forName(this.driver); + } catch (ClassNotFoundException e) { + throw new GeaFlowDSLException("failed to load driver: {}.", this.driver); + } + } - long stride = upperBound / partitionNum - lowerBound / partitionNum; - long currentValue = lowerBound; - List partitions = new ArrayList<>(); - for (long i = 0; i < partitionNum; ++i) { - String lBound = i != 0 ? String.format("%s >= %d", partitionColumn, currentValue) : - null; - currentValue += stride; - String uBound = i != partitionNum - 1 ? String.format("%s < %d", partitionColumn, - currentValue) : null; - String whereClause; - if (uBound == null) { - whereClause = lBound; - } else if (lBound == null) { - whereClause = String.format("%s OR %s IS NULL", uBound, partitionColumn); - } else { - whereClause = String.format("%s AND %s", lBound, uBound); - } - whereClause = "WHERE " + whereClause; - partitions.add(new JDBCPartition(tableName, whereClause)); - } - return partitions; + @Override + public List listPartitions() { + if (partitionNum <= 0) { + throw new GeaFlowDSLException("Invalid partition number: {}", partitionNum); } - @Override - public List listPartitions(int parallelism) { - return listPartitions(); + if (partitionNum == 1) { + return Collections.singletonList(new JDBCPartition(tableName, "")); } - @Override - public TableDeserializer getDeserializer(Configuration conf) { - return DeserializerFactory.loadRowTableDeserializer(); + long stride = upperBound / partitionNum - lowerBound / partitionNum; + long currentValue = lowerBound; + List partitions = new ArrayList<>(); + for (long i = 0; i < partitionNum; ++i) { + String lBound = i != 0 ? String.format("%s >= %d", partitionColumn, currentValue) : null; + currentValue += stride; + String uBound = + i != partitionNum - 1 ? String.format("%s < %d", partitionColumn, currentValue) : null; + String whereClause; + if (uBound == null) { + whereClause = lBound; + } else if (lBound == null) { + whereClause = String.format("%s OR %s IS NULL", uBound, partitionColumn); + } else { + whereClause = String.format("%s AND %s", lBound, uBound); + } + whereClause = "WHERE " + whereClause; + partitions.add(new JDBCPartition(tableName, whereClause)); } + return partitions; + } - @Override - public FetchData fetch(Partition partition, Optional startOffset, - FetchWindow windowInfo) throws IOException { - if (!(windowInfo.getType() == WindowType.SIZE_TUMBLING_WINDOW - || windowInfo.getType() == WindowType.ALL_WINDOW)) { - throw new GeaFlowDSLException("Not support window type:{}", windowInfo.getType()); - } - JDBCPartition jdbcPartition = (JDBCPartition) partition; - if (!jdbcPartition.getTableName().equals(this.tableName)) { - throw new GeaFlowDSLException("wrong partition"); - } - Statement statement = partitionStatementMap.get(partition); - if (statement == null) { - try { - Connection connection = DriverManager.getConnection(url, username, password); - statement = connection.createStatement(); - partitionConnectionMap.put(partition, connection); - partitionStatementMap.put(partition, statement); - } catch (SQLException e) { - throw new GeaFlowDSLException("failed to connect."); - } - } + @Override + public List listPartitions(int parallelism) { + return listPartitions(); + } - long offset = 0; - if (startOffset.isPresent()) { - offset = startOffset.get().getOffset(); - } - List dataList; - try { - dataList = JDBCUtils.selectRowsFromTable(statement, this.tableName, - jdbcPartition.getWhereClause(), this.schema.size(), offset, windowInfo.windowSize(), this.schema.getField(0).getName()); - } catch (SQLException e) { - throw new GeaFlowDSLException("select rows form table failed.", e); - } - JDBCOffset nextOffset = new JDBCOffset(offset + dataList.size()); - boolean isFinish = windowInfo.getType() == WindowType.ALL_WINDOW || dataList.size() < windowInfo.windowSize(); - return (FetchData) FetchData.createStreamFetch(dataList, nextOffset, isFinish); + @Override + public TableDeserializer getDeserializer(Configuration conf) { + return DeserializerFactory.loadRowTableDeserializer(); + } + + @Override + public FetchData fetch( + Partition partition, Optional startOffset, FetchWindow windowInfo) + throws IOException { + if (!(windowInfo.getType() == WindowType.SIZE_TUMBLING_WINDOW + || windowInfo.getType() == WindowType.ALL_WINDOW)) { + throw new GeaFlowDSLException("Not support window type:{}", windowInfo.getType()); + } + JDBCPartition jdbcPartition = (JDBCPartition) partition; + if (!jdbcPartition.getTableName().equals(this.tableName)) { + throw new GeaFlowDSLException("wrong partition"); + } + Statement statement = partitionStatementMap.get(partition); + if (statement == null) { + try { + Connection connection = DriverManager.getConnection(url, username, password); + statement = connection.createStatement(); + partitionConnectionMap.put(partition, connection); + partitionStatementMap.put(partition, statement); + } catch (SQLException e) { + throw new GeaFlowDSLException("failed to connect."); + } } - @Override - public void close() { - try { - for (Statement statement : this.partitionStatementMap.values()) { - if (statement != null) { - statement.close(); - } - } - this.partitionStatementMap.clear(); - for (Connection connection : this.partitionConnectionMap.values()) { - if (connection != null) { - connection.close(); - } - } - this.partitionConnectionMap.clear(); - } catch (SQLException e) { - throw new GeaFlowDSLException("failed to close connection."); - } + long offset = 0; + if (startOffset.isPresent()) { + offset = startOffset.get().getOffset(); + } + List dataList; + try { + dataList = + JDBCUtils.selectRowsFromTable( + statement, + this.tableName, + jdbcPartition.getWhereClause(), + this.schema.size(), + offset, + windowInfo.windowSize(), + this.schema.getField(0).getName()); + } catch (SQLException e) { + throw new GeaFlowDSLException("select rows form table failed.", e); } + JDBCOffset nextOffset = new JDBCOffset(offset + dataList.size()); + boolean isFinish = + windowInfo.getType() == WindowType.ALL_WINDOW || dataList.size() < windowInfo.windowSize(); + return (FetchData) FetchData.createStreamFetch(dataList, nextOffset, isFinish); + } - public static class JDBCPartition implements Partition { + @Override + public void close() { + try { + for (Statement statement : this.partitionStatementMap.values()) { + if (statement != null) { + statement.close(); + } + } + this.partitionStatementMap.clear(); + for (Connection connection : this.partitionConnectionMap.values()) { + if (connection != null) { + connection.close(); + } + } + this.partitionConnectionMap.clear(); + } catch (SQLException e) { + throw new GeaFlowDSLException("failed to close connection."); + } + } - String tableName; - String whereClause; + public static class JDBCPartition implements Partition { - public JDBCPartition(String tableName, String whereClause) { - this.tableName = tableName; - this.whereClause = whereClause; - } + String tableName; + String whereClause; - public String getTableName() { - return tableName; - } + public JDBCPartition(String tableName, String whereClause) { + this.tableName = tableName; + this.whereClause = whereClause; + } - public String getWhereClause() { - return whereClause; - } + public String getTableName() { + return tableName; + } - @Override - public String getName() { - if (whereClause == null || whereClause.isEmpty()) { - return tableName; - } else { - return tableName + "-" + whereClause; - } - } + public String getWhereClause() { + return whereClause; + } - @Override - public int hashCode() { - return Objects.hash(tableName, whereClause); - } + @Override + public String getName() { + if (whereClause == null || whereClause.isEmpty()) { + return tableName; + } else { + return tableName + "-" + whereClause; + } + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof JDBCPartition)) { - return false; - } - JDBCPartition that = (JDBCPartition) o; - return Objects.equals(tableName, that.tableName) && Objects.equals(whereClause, - that.whereClause); - } + @Override + public int hashCode() { + return Objects.hash(tableName, whereClause); + } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof JDBCPartition)) { + return false; + } + JDBCPartition that = (JDBCPartition) o; + return Objects.equals(tableName, that.tableName) + && Objects.equals(whereClause, that.whereClause); } + } - public static class JDBCOffset implements Offset { + public static class JDBCOffset implements Offset { - private final long offset; + private final long offset; - public JDBCOffset(long offset) { - this.offset = offset; - } + public JDBCOffset(long offset) { + this.offset = offset; + } - @Override - public String humanReadable() { - return String.valueOf(offset); - } + @Override + public String humanReadable() { + return String.valueOf(offset); + } - @Override - public long getOffset() { - return offset; - } + @Override + public long getOffset() { + return offset; + } - @Override - public boolean isTimestamp() { - return false; - } + @Override + public boolean isTimestamp() { + return false; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/main/java/org/apache/geaflow/dsl/connector/jdbc/util/JDBCUtils.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/main/java/org/apache/geaflow/dsl/connector/jdbc/util/JDBCUtils.java index b23caa410..305b9ae05 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/main/java/org/apache/geaflow/dsl/connector/jdbc/util/JDBCUtils.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/main/java/org/apache/geaflow/dsl/connector/jdbc/util/JDBCUtils.java @@ -25,6 +25,7 @@ import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; + import org.apache.calcite.sql.type.SqlTypeName; import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.common.type.IType; @@ -39,93 +40,106 @@ public class JDBCUtils { - public static final int VARCHAR_MAXLENGTH = 255; + public static final int VARCHAR_MAXLENGTH = 255; - private static String tableFieldToSQL(TableField tableField) { - String typeName = tableField.getType().getName(); - switch (typeName) { - case Types.TYPE_NAME_STRING: - case Types.TYPE_NAME_BINARY_STRING: - typeName = SqlTypeName.VARCHAR.getName(); - break; - case Types.TYPE_NAME_LONG: - typeName = SqlTypeName.BIGINT.getName(); - break; - default: - typeName = typeName.toUpperCase(); - } - return String.format("%s %s%s %s", tableField.getName(), typeName, - typeName.equals(SqlTypeName.VARCHAR.getName()) ? "(" + VARCHAR_MAXLENGTH + ")" : "", - tableField.isNullable() ? "NULL" : "NOT " + "NULL"); + private static String tableFieldToSQL(TableField tableField) { + String typeName = tableField.getType().getName(); + switch (typeName) { + case Types.TYPE_NAME_STRING: + case Types.TYPE_NAME_BINARY_STRING: + typeName = SqlTypeName.VARCHAR.getName(); + break; + case Types.TYPE_NAME_LONG: + typeName = SqlTypeName.BIGINT.getName(); + break; + default: + typeName = typeName.toUpperCase(); } + return String.format( + "%s %s%s %s", + tableField.getName(), + typeName, + typeName.equals(SqlTypeName.VARCHAR.getName()) ? "(" + VARCHAR_MAXLENGTH + ")" : "", + tableField.isNullable() ? "NULL" : "NOT " + "NULL"); + } - public static void createTemporaryTable(Statement statement, String tableName, - List fields) throws SQLException { - StringBuilder tableFields = new StringBuilder(); - for (TableField field : fields) { - tableFields.append(tableFieldToSQL(field)).append(",\n"); - } - tableFields.deleteCharAt(tableFields.lastIndexOf(",")); - String createTableQuery = String.format("CREATE TEMPORARY TABLE %s (\n%s);", tableName, - tableFields); - statement.execute(createTableQuery); + public static void createTemporaryTable( + Statement statement, String tableName, List fields) throws SQLException { + StringBuilder tableFields = new StringBuilder(); + for (TableField field : fields) { + tableFields.append(tableFieldToSQL(field)).append(",\n"); } + tableFields.deleteCharAt(tableFields.lastIndexOf(",")); + String createTableQuery = + String.format("CREATE TEMPORARY TABLE %s (\n%s);", tableName, tableFields); + statement.execute(createTableQuery); + } - public static void insertIntoTable(Statement statement, String tableName, - List fields, Row row) throws SQLException { - Object[] values = new Object[fields.size()]; - boolean isFirst = true; - StringBuilder builder = new StringBuilder(); - for (int i = 0; i < fields.size(); i++) { - if (isFirst) { - isFirst = false; - } else { - builder.append(","); - } - IType type = fields.get(i).getType(); - Object value = row.getField(i, type); - if (value == null) { - if (fields.get(i).isNullable()) { - builder.append("null"); - } else { - throw new RuntimeException("filed " + fields.get(i).getName() + " can not be null"); - } - } else if (type.getClass() == BinaryStringType.class || type.getClass() == StringType.class) { - builder.append("'").append(value).append("'"); - } else { - builder.append(value); - } - values[i] = row.getField(i, fields.get(i).getType()); + public static void insertIntoTable( + Statement statement, String tableName, List fields, Row row) throws SQLException { + Object[] values = new Object[fields.size()]; + boolean isFirst = true; + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < fields.size(); i++) { + if (isFirst) { + isFirst = false; + } else { + builder.append(","); + } + IType type = fields.get(i).getType(); + Object value = row.getField(i, type); + if (value == null) { + if (fields.get(i).isNullable()) { + builder.append("null"); + } else { + throw new RuntimeException("filed " + fields.get(i).getName() + " can not be null"); } - - String insertIntoValues = builder.toString(); - String insertColumns = StringUtils.join(fields.stream().map( - field -> field.getName()).collect(Collectors.toList()), ","); - String insertIntoTableQuery = String.format("INSERT INTO %s (%s) VALUES (%s);", tableName, insertColumns, - insertIntoValues); - statement.execute(insertIntoTableQuery); + } else if (type.getClass() == BinaryStringType.class || type.getClass() == StringType.class) { + builder.append("'").append(value).append("'"); + } else { + builder.append(value); + } + values[i] = row.getField(i, fields.get(i).getType()); } - public static List selectRowsFromTable(Statement statement, String tableName, - String whereClause, int columnNum, long startOffset, - long windowSize, String orderByColumnName) throws SQLException { - if (windowSize == Windows.SIZE_OF_ALL_WINDOW) { - windowSize = Integer.MAX_VALUE; - } else if (windowSize <= 0) { - throw new GeaFlowDSLException("wrong windowSize"); - } - String selectRowsFromTableQuery = String.format("SELECT * FROM %s %s ORDER BY %s LIMIT %s OFFSET %s;", + String insertIntoValues = builder.toString(); + String insertColumns = + StringUtils.join( + fields.stream().map(field -> field.getName()).collect(Collectors.toList()), ","); + String insertIntoTableQuery = + String.format( + "INSERT INTO %s (%s) VALUES (%s);", tableName, insertColumns, insertIntoValues); + statement.execute(insertIntoTableQuery); + } + + public static List selectRowsFromTable( + Statement statement, + String tableName, + String whereClause, + int columnNum, + long startOffset, + long windowSize, + String orderByColumnName) + throws SQLException { + if (windowSize == Windows.SIZE_OF_ALL_WINDOW) { + windowSize = Integer.MAX_VALUE; + } else if (windowSize <= 0) { + throw new GeaFlowDSLException("wrong windowSize"); + } + String selectRowsFromTableQuery = + String.format( + "SELECT * FROM %s %s ORDER BY %s LIMIT %s OFFSET %s;", tableName, whereClause, orderByColumnName, windowSize, startOffset); - ResultSet resultSet = statement.executeQuery(selectRowsFromTableQuery); - List rowList = new ArrayList<>(); - while (resultSet.next()) { - Object[] values = new Object[columnNum]; - for (int i = 1; i <= columnNum; i++) { - values[i - 1] = resultSet.getObject(i); - } - rowList.add(ObjectRow.create(values)); - } - resultSet.close(); - return rowList; + ResultSet resultSet = statement.executeQuery(selectRowsFromTableQuery); + List rowList = new ArrayList<>(); + while (resultSet.next()) { + Object[] values = new Object[columnNum]; + for (int i = 1; i <= columnNum; i++) { + values[i - 1] = resultSet.getObject(i); + } + rowList.add(ObjectRow.create(values)); } + resultSet.close(); + return rowList; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/test/java/org/apache/geaflow/dsl/connector/jdbc/JDBCTableConnectorTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/test/java/org/apache/geaflow/dsl/connector/jdbc/JDBCTableConnectorTest.java index 68e0e69b3..60005cb94 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/test/java/org/apache/geaflow/dsl/connector/jdbc/JDBCTableConnectorTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-jdbc/src/test/java/org/apache/geaflow/dsl/connector/jdbc/JDBCTableConnectorTest.java @@ -24,6 +24,7 @@ import java.sql.Statement; import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.impl.ObjectRow; @@ -37,84 +38,83 @@ import org.slf4j.LoggerFactory; public class JDBCTableConnectorTest { - private static final Logger LOGGER = LoggerFactory.getLogger(JDBCTableConnectorTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(JDBCTableConnectorTest.class); - private static final String driver = "org.h2.Driver"; - private static final String URL = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1"; - private static final String username = "h2_user"; - private static final String password = "h2_pwd"; - private static Connection connection; - private static Statement statement; + private static final String driver = "org.h2.Driver"; + private static final String URL = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1"; + private static final String username = "h2_user"; + private static final String password = "h2_pwd"; + private static Connection connection; + private static Statement statement; - @BeforeClass - public static void setup() throws SQLException { - LOGGER.info("start h2 database."); - JdbcDataSource dataSource = new JdbcDataSource(); - dataSource.setURL(URL); - dataSource.setUser(username); - dataSource.setPassword(password); + @BeforeClass + public static void setup() throws SQLException { + LOGGER.info("start h2 database."); + JdbcDataSource dataSource = new JdbcDataSource(); + dataSource.setURL(URL); + dataSource.setUser(username); + dataSource.setPassword(password); - connection = java.sql.DriverManager.getConnection(URL); - statement = connection.createStatement(); - statement.execute("CREATE TABLE test_table (id INT PRIMARY KEY, name VARCHAR(255))"); - statement.execute("INSERT INTO test_table (id, name) VALUES (1, 'Test1')"); - statement.execute("INSERT INTO test_table (id, name) VALUES (2, 'Test2')"); - statement.execute("INSERT INTO test_table (id, name) VALUES (3, 'Test3')"); - statement.execute("INSERT INTO test_table (id, name) VALUES (4, 'Test4')"); - } + connection = java.sql.DriverManager.getConnection(URL); + statement = connection.createStatement(); + statement.execute("CREATE TABLE test_table (id INT PRIMARY KEY, name VARCHAR(255))"); + statement.execute("INSERT INTO test_table (id, name) VALUES (1, 'Test1')"); + statement.execute("INSERT INTO test_table (id, name) VALUES (2, 'Test2')"); + statement.execute("INSERT INTO test_table (id, name) VALUES (3, 'Test3')"); + statement.execute("INSERT INTO test_table (id, name) VALUES (4, 'Test4')"); + } - @AfterClass - public static void cleanup() throws SQLException { - statement.close(); - connection.close(); - } + @AfterClass + public static void cleanup() throws SQLException { + statement.close(); + connection.close(); + } - @Test - public void testCreateTable() throws SQLException { - List tableFieldList = new ArrayList<>(); - tableFieldList.add(new TableField("id", Types.INTEGER, false)); - tableFieldList.add(new TableField("v1", Types.DOUBLE, true)); - tableFieldList.add(new TableField("v2", Types.DOUBLE, true)); - JDBCUtils.createTemporaryTable(statement, "another_table", tableFieldList); - } + @Test + public void testCreateTable() throws SQLException { + List tableFieldList = new ArrayList<>(); + tableFieldList.add(new TableField("id", Types.INTEGER, false)); + tableFieldList.add(new TableField("v1", Types.DOUBLE, true)); + tableFieldList.add(new TableField("v2", Types.DOUBLE, true)); + JDBCUtils.createTemporaryTable(statement, "another_table", tableFieldList); + } - @Test - public void testInsertIntoTable() throws SQLException { - List tableFieldList = new ArrayList<>(); - tableFieldList.add(new TableField("id", Types.INTEGER, false)); - tableFieldList.add(new TableField("name", Types.BINARY_STRING, true)); - Row row = ObjectRow.create(new Object[]{5, null}); - JDBCUtils.insertIntoTable(statement, "test_table", tableFieldList, row); - List rowList = JDBCUtils.selectRowsFromTable(statement, "test_table", - "", 2, 0, 20, "id"); - Row resultRow = null; - for (Row queryRow : rowList) { - if ((Integer) queryRow.getField(0, Types.INTEGER) == 5) { - resultRow = queryRow; - break; - } - } - assert resultRow != null; - assert resultRow.getField(1, Types.BINARY_STRING) == null; + @Test + public void testInsertIntoTable() throws SQLException { + List tableFieldList = new ArrayList<>(); + tableFieldList.add(new TableField("id", Types.INTEGER, false)); + tableFieldList.add(new TableField("name", Types.BINARY_STRING, true)); + Row row = ObjectRow.create(new Object[] {5, null}); + JDBCUtils.insertIntoTable(statement, "test_table", tableFieldList, row); + List rowList = JDBCUtils.selectRowsFromTable(statement, "test_table", "", 2, 0, 20, "id"); + Row resultRow = null; + for (Row queryRow : rowList) { + if ((Integer) queryRow.getField(0, Types.INTEGER) == 5) { + resultRow = queryRow; + break; + } } + assert resultRow != null; + assert resultRow.getField(1, Types.BINARY_STRING) == null; + } - @Test - public void testSelectRowsFromTable1() throws SQLException { - List rowList = JDBCUtils.selectRowsFromTable(statement, "test_table", "", 2, 0, 2, "id"); - assert rowList.size() == 2; - } + @Test + public void testSelectRowsFromTable1() throws SQLException { + List rowList = JDBCUtils.selectRowsFromTable(statement, "test_table", "", 2, 0, 2, "id"); + assert rowList.size() == 2; + } - @Test - public void testSelectRowsFromTable2() throws SQLException { - List rowList = JDBCUtils.selectRowsFromTable(statement, "test_table", - "WHERE id < 2", 2, 0, 3, "id"); - assert rowList.size() == 1; - } + @Test + public void testSelectRowsFromTable2() throws SQLException { + List rowList = + JDBCUtils.selectRowsFromTable(statement, "test_table", "WHERE id < 2", 2, 0, 3, "id"); + assert rowList.size() == 1; + } - @Test - public void testSelectRowsFromTable3() throws SQLException { - List rowList = JDBCUtils.selectRowsFromTable(statement, "test_table", - "WHERE id < 4", 2, 0, 1, "id"); - assert rowList.size() == 1; - } + @Test + public void testSelectRowsFromTable3() throws SQLException { + List rowList = + JDBCUtils.selectRowsFromTable(statement, "test_table", "WHERE id < 4", 2, 0, 1, "id"); + assert rowList.size() == 1; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/main/java/org/apache/geaflow/dsl/connector/kafka/KafkaConfigKeys.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/main/java/org/apache/geaflow/dsl/connector/kafka/KafkaConfigKeys.java index b53634144..d788b297c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/main/java/org/apache/geaflow/dsl/connector/kafka/KafkaConfigKeys.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/main/java/org/apache/geaflow/dsl/connector/kafka/KafkaConfigKeys.java @@ -24,33 +24,31 @@ public class KafkaConfigKeys { - public static final ConfigKey GEAFLOW_DSL_KAFKA_SERVERS = ConfigKeys - .key("geaflow.dsl.kafka.servers") - .noDefaultValue() - .description("The kafka bootstrap servers list."); - - public static final ConfigKey GEAFLOW_DSL_KAFKA_TOPIC = ConfigKeys - .key("geaflow.dsl.kafka.topic") - .noDefaultValue() - .description("The kafka topic."); - - public static final ConfigKey GEAFLOW_DSL_KAFKA_GROUP_ID = ConfigKeys - .key("geaflow.dsl.kafka.group.id") - .defaultValue("default-group-id") - .description("The kafka group id, default is 'default-group-id'."); - - public static final ConfigKey GEAFLOW_DSL_KAFKA_PULL_BATCH_SIZE = ConfigKeys - .key("geaflow.dsl.kafka.pull.batch.size") - .defaultValue(100) - .description("The kafka pull batch size"); - - public static final ConfigKey GEAFLOW_DSL_KAFKA_DATA_OPERATION_TIMEOUT = ConfigKeys - .key("geaflow.dsl.kafka.data.operation.timeout.seconds") - .defaultValue(30) - .description("The kafka pool/write data timeout"); - - public static final ConfigKey GEAFLOW_DSL_KAFKA_CLIENT_ID = ConfigKeys - .key("geaflow.dsl.kafka.client.id") - .defaultValue(null) - .description("The kafka client id"); + public static final ConfigKey GEAFLOW_DSL_KAFKA_SERVERS = + ConfigKeys.key("geaflow.dsl.kafka.servers") + .noDefaultValue() + .description("The kafka bootstrap servers list."); + + public static final ConfigKey GEAFLOW_DSL_KAFKA_TOPIC = + ConfigKeys.key("geaflow.dsl.kafka.topic").noDefaultValue().description("The kafka topic."); + + public static final ConfigKey GEAFLOW_DSL_KAFKA_GROUP_ID = + ConfigKeys.key("geaflow.dsl.kafka.group.id") + .defaultValue("default-group-id") + .description("The kafka group id, default is 'default-group-id'."); + + public static final ConfigKey GEAFLOW_DSL_KAFKA_PULL_BATCH_SIZE = + ConfigKeys.key("geaflow.dsl.kafka.pull.batch.size") + .defaultValue(100) + .description("The kafka pull batch size"); + + public static final ConfigKey GEAFLOW_DSL_KAFKA_DATA_OPERATION_TIMEOUT = + ConfigKeys.key("geaflow.dsl.kafka.data.operation.timeout.seconds") + .defaultValue(30) + .description("The kafka pool/write data timeout"); + + public static final ConfigKey GEAFLOW_DSL_KAFKA_CLIENT_ID = + ConfigKeys.key("geaflow.dsl.kafka.client.id") + .defaultValue(null) + .description("The kafka client id"); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/main/java/org/apache/geaflow/dsl/connector/kafka/KafkaTableConnector.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/main/java/org/apache/geaflow/dsl/connector/kafka/KafkaTableConnector.java index 218bf8a23..22f0dac6c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/main/java/org/apache/geaflow/dsl/connector/kafka/KafkaTableConnector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/main/java/org/apache/geaflow/dsl/connector/kafka/KafkaTableConnector.java @@ -27,20 +27,20 @@ public class KafkaTableConnector implements TableReadableConnector, TableWritableConnector { - public static final String TYPE = "KAFKA"; + public static final String TYPE = "KAFKA"; - @Override - public String getType() { - return TYPE; - } + @Override + public String getType() { + return TYPE; + } - @Override - public TableSource createSource(Configuration conf) { - return new KafkaTableSource(); - } + @Override + public TableSource createSource(Configuration conf) { + return new KafkaTableSource(); + } - @Override - public TableSink createSink(Configuration conf) { - return new KafkaTableSink(); - } + @Override + public TableSink createSink(Configuration conf) { + return new KafkaTableSink(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/main/java/org/apache/geaflow/dsl/connector/kafka/KafkaTableSink.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/main/java/org/apache/geaflow/dsl/connector/kafka/KafkaTableSink.java index acddde56d..66933ec2a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/main/java/org/apache/geaflow/dsl/connector/kafka/KafkaTableSink.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/main/java/org/apache/geaflow/dsl/connector/kafka/KafkaTableSink.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.Properties; import java.util.concurrent.TimeUnit; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; @@ -42,104 +43,104 @@ public class KafkaTableSink implements TableSink { - private static final Logger LOGGER = LoggerFactory.getLogger(KafkaTableSink.class); - - private Configuration tableConf; - private StructType schema; - private String separator; - private String servers; - private String valueSerializerClass; - private String topic; - private Properties props; - private int writeTimeout; - - private transient KafkaProducer producer; - - @Override - public void init(Configuration conf, StructType schema) { - LOGGER.info("prepare with config: {}, \n schema: {}", conf, schema); - this.tableConf = conf; - this.schema = schema; - - this.servers = conf.getString(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_SERVERS); - topic = conf.getString(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_TOPIC); - this.valueSerializerClass = KafkaConstants.KAFKA_VALUE_SERIALIZER_CLASS; - this.separator = tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_COLUMN_SEPARATOR); + private static final Logger LOGGER = LoggerFactory.getLogger(KafkaTableSink.class); + + private Configuration tableConf; + private StructType schema; + private String separator; + private String servers; + private String valueSerializerClass; + private String topic; + private Properties props; + private int writeTimeout; + + private transient KafkaProducer producer; + + @Override + public void init(Configuration conf, StructType schema) { + LOGGER.info("prepare with config: {}, \n schema: {}", conf, schema); + this.tableConf = conf; + this.schema = schema; + + this.servers = conf.getString(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_SERVERS); + topic = conf.getString(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_TOPIC); + this.valueSerializerClass = KafkaConstants.KAFKA_VALUE_SERIALIZER_CLASS; + this.separator = tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_COLUMN_SEPARATOR); + } + + @Override + public void open(RuntimeContext context) { + props = new Properties(); + props.setProperty(KafkaConstants.KAFKA_BOOTSTRAP_SERVERS, servers); + props.setProperty(KafkaConstants.KAFKA_KEY_SERIALIZER, valueSerializerClass); + props.setProperty(KafkaConstants.KAFKA_VALUE_SERIALIZER, valueSerializerClass); + if (context.getConfiguration().contains(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_CLIENT_ID)) { + String useClientId = + context.getConfiguration().getString(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_CLIENT_ID); + props.put(KafkaConstants.KAFKA_CLIENT_ID, useClientId); } - @Override - public void open(RuntimeContext context) { - props = new Properties(); - props.setProperty(KafkaConstants.KAFKA_BOOTSTRAP_SERVERS, servers); - props.setProperty(KafkaConstants.KAFKA_KEY_SERIALIZER, valueSerializerClass); - props.setProperty(KafkaConstants.KAFKA_VALUE_SERIALIZER, valueSerializerClass); - if (context.getConfiguration().contains(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_CLIENT_ID)) { - String useClientId = context.getConfiguration() - .getString(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_CLIENT_ID); - props.put(KafkaConstants.KAFKA_CLIENT_ID, useClientId); - } - - writeTimeout = context.getConfiguration() + writeTimeout = + context + .getConfiguration() .getInteger(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_DATA_OPERATION_TIMEOUT); - producer = new KafkaProducer<>(props); - - Properties consumerProps = new Properties(); - consumerProps.putAll(props); - String valueDeserializerClassString = - KafkaConstants.KAFKA_VALUE_DESERIALIZER_CLASS; - consumerProps.setProperty(KafkaConstants.KAFKA_KEY_DESERIALIZER, - valueDeserializerClassString); - consumerProps.setProperty(KafkaConstants.KAFKA_VALUE_DESERIALIZER, - valueDeserializerClassString); - KafkaConsumer tmpConsumer = new KafkaConsumer<>(consumerProps); - Map> topic2PartitionInfo = tmpConsumer.listTopics(); - tmpConsumer.close(); - if (!topic2PartitionInfo.containsKey(topic)) { - producer.close(); - producer = null; - throw new GeaFlowDSLException("Topic: [{}] has not been created.", topic); - } - } - - @Override - public void write(Row row) throws IOException { - Object[] values = new Object[schema.size()]; - for (int i = 0; i < schema.size(); i++) { - values[i] = row.getField(i, schema.getType(i)); - } - ProducerRecord record = createRecord(values); - try { - producer.send(record).get(writeTimeout, TimeUnit.SECONDS); - } catch (Exception e) { - throw new IOException("When write kafka.", e); - } + producer = new KafkaProducer<>(props); + + Properties consumerProps = new Properties(); + consumerProps.putAll(props); + String valueDeserializerClassString = KafkaConstants.KAFKA_VALUE_DESERIALIZER_CLASS; + consumerProps.setProperty(KafkaConstants.KAFKA_KEY_DESERIALIZER, valueDeserializerClassString); + consumerProps.setProperty( + KafkaConstants.KAFKA_VALUE_DESERIALIZER, valueDeserializerClassString); + KafkaConsumer tmpConsumer = new KafkaConsumer<>(consumerProps); + Map> topic2PartitionInfo = tmpConsumer.listTopics(); + tmpConsumer.close(); + if (!topic2PartitionInfo.containsKey(topic)) { + producer.close(); + producer = null; + throw new GeaFlowDSLException("Topic: [{}] has not been created.", topic); } + } - private void flush() { - LOGGER.info("flush"); - if (producer != null) { - producer.flush(); - } else { - LOGGER.warn("Producer is null."); - } + @Override + public void write(Row row) throws IOException { + Object[] values = new Object[schema.size()]; + for (int i = 0; i < schema.size(); i++) { + values[i] = row.getField(i, schema.getType(i)); } - - protected ProducerRecord createRecord(Object[] flushValues) { - return new ProducerRecord<>(topic, StringUtils.join(flushValues, separator)); + ProducerRecord record = createRecord(values); + try { + producer.send(record).get(writeTimeout, TimeUnit.SECONDS); + } catch (Exception e) { + throw new IOException("When write kafka.", e); } - - @Override - public void finish() throws IOException { - flush(); + } + + private void flush() { + LOGGER.info("flush"); + if (producer != null) { + producer.flush(); + } else { + LOGGER.warn("Producer is null."); } - - @Override - public void close() { - LOGGER.info("close"); - flush(); - if (producer != null) { - producer.close(); - } + } + + protected ProducerRecord createRecord(Object[] flushValues) { + return new ProducerRecord<>(topic, StringUtils.join(flushValues, separator)); + } + + @Override + public void finish() throws IOException { + flush(); + } + + @Override + public void close() { + LOGGER.info("close"); + flush(); + if (producer != null) { + producer.close(); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/main/java/org/apache/geaflow/dsl/connector/kafka/KafkaTableSource.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/main/java/org/apache/geaflow/dsl/connector/kafka/KafkaTableSource.java index 0beac47e2..bae4b2c95 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/main/java/org/apache/geaflow/dsl/connector/kafka/KafkaTableSource.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/main/java/org/apache/geaflow/dsl/connector/kafka/KafkaTableSource.java @@ -18,7 +18,6 @@ */ package org.apache.geaflow.dsl.connector.kafka; -import com.google.common.base.Preconditions; import java.io.IOException; import java.time.Duration; import java.util.ArrayList; @@ -30,6 +29,7 @@ import java.util.Properties; import java.util.Set; import java.util.stream.Collectors; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ConnectorConfigKeys; @@ -57,304 +57,322 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; public class KafkaTableSource extends AbstractTableSource { - private static final Logger LOGGER = LoggerFactory.getLogger(KafkaTableSource.class); + private static final Logger LOGGER = LoggerFactory.getLogger(KafkaTableSource.class); - private static final Duration OPERATION_TIMEOUT = - Duration.ofSeconds(KafkaConstants.KAFKA_OPERATION_TIMEOUT_SECONDS); + private static final Duration OPERATION_TIMEOUT = + Duration.ofSeconds(KafkaConstants.KAFKA_OPERATION_TIMEOUT_SECONDS); - private String topic; - private long startTimeMs; - private Properties props; - private Duration pollTimeout; + private String topic; + private long startTimeMs; + private Properties props; + private Duration pollTimeout; - private transient KafkaConsumer consumer; + private transient KafkaConsumer consumer; - @Override - public void init(Configuration conf, TableSchema tableSchema) { - - final String servers = - conf.getString(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_SERVERS); - topic = conf.getString(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_TOPIC); - pollTimeout = Duration.ofSeconds(conf.getInteger(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_DATA_OPERATION_TIMEOUT)); - final String groupId = conf.getString(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_GROUP_ID); - final String valueDeserializerClassString = KafkaConstants.KAFKA_VALUE_DESERIALIZER_CLASS; - final String startTimeStr = conf.getString(ConnectorConfigKeys.GEAFLOW_DSL_START_TIME, + @Override + public void init(Configuration conf, TableSchema tableSchema) { + + final String servers = conf.getString(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_SERVERS); + topic = conf.getString(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_TOPIC); + pollTimeout = + Duration.ofSeconds( + conf.getInteger(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_DATA_OPERATION_TIMEOUT)); + final String groupId = conf.getString(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_GROUP_ID); + final String valueDeserializerClassString = KafkaConstants.KAFKA_VALUE_DESERIALIZER_CLASS; + final String startTimeStr = + conf.getString( + ConnectorConfigKeys.GEAFLOW_DSL_START_TIME, (String) ConnectorConfigKeys.GEAFLOW_DSL_START_TIME.getDefaultValue()); - if (startTimeStr.equalsIgnoreCase(KafkaConstants.KAFKA_BEGIN)) { - startTimeMs = 0; - } else if (startTimeStr.equalsIgnoreCase(KafkaConstants.KAFKA_LATEST)) { - startTimeMs = Long.MAX_VALUE; - } else { - startTimeMs = DateTimeUtil.toUnixTime(startTimeStr, ConnectorConstants.START_TIME_FORMAT); - } - if (conf.contains(DSLConfigKeys.GEAFLOW_DSL_TIME_WINDOW_SIZE)) { - Preconditions.checkState(startTimeMs > 0, "Time window need unified start time! Please set config:%s", ConnectorConfigKeys.GEAFLOW_DSL_START_TIME.getKey()); - } - int pullSize = conf.getInteger(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_PULL_BATCH_SIZE); - if (pullSize <= 0) { - throw new GeaFlowDSLException("Config {} is illegal:{}", KafkaConfigKeys.GEAFLOW_DSL_KAFKA_PULL_BATCH_SIZE, pullSize); - } - this.props = new Properties(); - props.setProperty(KafkaConstants.KAFKA_BOOTSTRAP_SERVERS, servers); - props.setProperty(KafkaConstants.KAFKA_KEY_DESERIALIZER, - valueDeserializerClassString); - props.setProperty(KafkaConstants.KAFKA_VALUE_DESERIALIZER, - valueDeserializerClassString); - props.setProperty(KafkaConstants.KAFKA_MAX_POLL_RECORDS, - String.valueOf(pullSize)); - props.setProperty(KafkaConstants.KAFKA_GROUP_ID, groupId); - if (conf.contains(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_CLIENT_ID)) { - String useClientId = conf.getString(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_CLIENT_ID); - props.put(KafkaConstants.KAFKA_CLIENT_ID, useClientId); - } - LOGGER.info("open kafka, servers is: {}, topic is:{}, config is:{}, schema is: {}", - servers, topic, conf, tableSchema); + if (startTimeStr.equalsIgnoreCase(KafkaConstants.KAFKA_BEGIN)) { + startTimeMs = 0; + } else if (startTimeStr.equalsIgnoreCase(KafkaConstants.KAFKA_LATEST)) { + startTimeMs = Long.MAX_VALUE; + } else { + startTimeMs = DateTimeUtil.toUnixTime(startTimeStr, ConnectorConstants.START_TIME_FORMAT); } - - @Override - public void open(RuntimeContext context) { - consumer = new KafkaConsumer<>(props); - LOGGER.info("consumer opened, topic: {}", topic); + if (conf.contains(DSLConfigKeys.GEAFLOW_DSL_TIME_WINDOW_SIZE)) { + Preconditions.checkState( + startTimeMs > 0, + "Time window need unified start time! Please set config:%s", + ConnectorConfigKeys.GEAFLOW_DSL_START_TIME.getKey()); } - - @Override - public List listPartitions() { - KafkaConsumer tmpConsumer = new KafkaConsumer<>(props); - List partitions = tmpConsumer.partitionsFor(topic, OPERATION_TIMEOUT); - tmpConsumer.close(); - return partitions.stream().map( - partition -> new KafkaPartition(topic, partition.partition()) - ).collect(Collectors.toList()); + int pullSize = conf.getInteger(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_PULL_BATCH_SIZE); + if (pullSize <= 0) { + throw new GeaFlowDSLException( + "Config {} is illegal:{}", KafkaConfigKeys.GEAFLOW_DSL_KAFKA_PULL_BATCH_SIZE, pullSize); } - - @Override - public List listPartitions(int parallelism) { - return listPartitions(); + this.props = new Properties(); + props.setProperty(KafkaConstants.KAFKA_BOOTSTRAP_SERVERS, servers); + props.setProperty(KafkaConstants.KAFKA_KEY_DESERIALIZER, valueDeserializerClassString); + props.setProperty(KafkaConstants.KAFKA_VALUE_DESERIALIZER, valueDeserializerClassString); + props.setProperty(KafkaConstants.KAFKA_MAX_POLL_RECORDS, String.valueOf(pullSize)); + props.setProperty(KafkaConstants.KAFKA_GROUP_ID, groupId); + if (conf.contains(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_CLIENT_ID)) { + String useClientId = conf.getString(KafkaConfigKeys.GEAFLOW_DSL_KAFKA_CLIENT_ID); + props.put(KafkaConstants.KAFKA_CLIENT_ID, useClientId); } - - @Override - public TableDeserializer getDeserializer(Configuration conf) { - return DeserializerFactory.loadDeserializer(conf); + LOGGER.info( + "open kafka, servers is: {}, topic is:{}, config is:{}, schema is: {}", + servers, + topic, + conf, + tableSchema); + } + + @Override + public void open(RuntimeContext context) { + consumer = new KafkaConsumer<>(props); + LOGGER.info("consumer opened, topic: {}", topic); + } + + @Override + public List listPartitions() { + KafkaConsumer tmpConsumer = new KafkaConsumer<>(props); + List partitions = tmpConsumer.partitionsFor(topic, OPERATION_TIMEOUT); + tmpConsumer.close(); + return partitions.stream() + .map(partition -> new KafkaPartition(topic, partition.partition())) + .collect(Collectors.toList()); + } + + @Override + public List listPartitions(int parallelism) { + return listPartitions(); + } + + @Override + public TableDeserializer getDeserializer(Configuration conf) { + return DeserializerFactory.loadDeserializer(conf); + } + + private Tuple fetchPrepare( + Partition partition, Optional startOffset) { + KafkaPartition kafkaPartition = (KafkaPartition) partition; + TopicPartition topicPartition = + new TopicPartition(kafkaPartition.getTopic(), kafkaPartition.getPartition()); + Set singletonPartition = Collections.singleton(topicPartition); + + Set currentAssignment = consumer.assignment(); + if (currentAssignment.size() != 1 || !currentAssignment.contains(topicPartition)) { + consumer.assign(singletonPartition); } - private Tuple fetchPrepare(Partition partition, Optional startOffset) { - KafkaPartition kafkaPartition = (KafkaPartition) partition; - TopicPartition topicPartition = new TopicPartition(kafkaPartition.getTopic(), - kafkaPartition.getPartition()); - Set singletonPartition = Collections.singleton(topicPartition); - - Set currentAssignment = consumer.assignment(); - if (currentAssignment.size() != 1 || !currentAssignment.contains(topicPartition)) { - consumer.assign(singletonPartition); - } - - KafkaOffset reqKafkaOffset; - if (startOffset.isPresent()) { - reqKafkaOffset = (KafkaOffset) startOffset.get(); - consumer.seek(topicPartition, reqKafkaOffset.getKafkaOffset()); - return Tuple.of(topicPartition, reqKafkaOffset.getKafkaOffset()); + KafkaOffset reqKafkaOffset; + if (startOffset.isPresent()) { + reqKafkaOffset = (KafkaOffset) startOffset.get(); + consumer.seek(topicPartition, reqKafkaOffset.getKafkaOffset()); + return Tuple.of(topicPartition, reqKafkaOffset.getKafkaOffset()); + } else { + if (startTimeMs == 0) { + Map partition2Offset = + consumer.beginningOffsets(singletonPartition, OPERATION_TIMEOUT); + Long beginningOffset = partition2Offset.get(topicPartition); + if (beginningOffset == null) { + throw new GeaFlowDSLException( + "Cannot get beginning offset for partition: {}, " + "startTime: {}.", + topicPartition, + startTimeMs); } else { - if (startTimeMs == 0) { - Map partition2Offset = - consumer.beginningOffsets(singletonPartition, OPERATION_TIMEOUT); - Long beginningOffset = partition2Offset.get(topicPartition); - if (beginningOffset == null) { - throw new GeaFlowDSLException("Cannot get beginning offset for partition: {}, " - + "startTime: {}.", topicPartition, startTimeMs); - } else { - consumer.seek(topicPartition, beginningOffset); - } - return Tuple.of(topicPartition, beginningOffset); - } else if (startTimeMs == Long.MAX_VALUE) { - Map endOffsets = - consumer.endOffsets(Collections.singletonList(topicPartition), OPERATION_TIMEOUT); - Long beginningOffset = endOffsets.get(topicPartition); - if (beginningOffset == null) { - throw new GeaFlowDSLException("Cannot get beginning offset for partition: {}, " - + "startTime: {}.", topicPartition, startTimeMs); - } else { - consumer.seek(topicPartition, beginningOffset); - } - return Tuple.of(topicPartition, beginningOffset); - } else { - Map partitionOffset = - consumer.offsetsForTimes(Collections.singletonMap(topicPartition, - startTimeMs / 1000), OPERATION_TIMEOUT); - OffsetAndTimestamp offset = partitionOffset.get(topicPartition); - if (offset == null) { - throw new GeaFlowDSLException("Cannot get offset for partition: {}, " - + "startTime: {}.", topicPartition, startTimeMs); - } else { - consumer.seek(topicPartition, offset.offset()); - } - return Tuple.of(topicPartition, offset.offset()); - } + consumer.seek(topicPartition, beginningOffset); } - } - - @Override - public FetchData fetch(Partition partition, Optional startOffset, - SizeFetchWindow windowInfo) throws IOException { - TopicPartition topicPartition = fetchPrepare(partition, startOffset).f0; - - List dataList = new ArrayList<>(); - long responseMaxTimestamp = -1; - while (dataList.size() < windowInfo.windowSize()) { - ConsumerRecords records = consumer.poll(pollTimeout); - if (records.isEmpty()) { - break; - } - for (ConsumerRecord record : records) { - assert record.topic().equals(this.topic) : "Illegal topic"; - dataList.add(record.value()); - if (record.timestamp() > responseMaxTimestamp) { - responseMaxTimestamp = record.timestamp(); - } - } + return Tuple.of(topicPartition, beginningOffset); + } else if (startTimeMs == Long.MAX_VALUE) { + Map endOffsets = + consumer.endOffsets(Collections.singletonList(topicPartition), OPERATION_TIMEOUT); + Long beginningOffset = endOffsets.get(topicPartition); + if (beginningOffset == null) { + throw new GeaFlowDSLException( + "Cannot get beginning offset for partition: {}, " + "startTime: {}.", + topicPartition, + startTimeMs); + } else { + consumer.seek(topicPartition, beginningOffset); } - //reload cursor - long nextOffset = consumer.position(topicPartition, OPERATION_TIMEOUT); - KafkaOffset nextKafkaOffset; - if (responseMaxTimestamp >= 0) { - nextKafkaOffset = new KafkaOffset(nextOffset, responseMaxTimestamp); + return Tuple.of(topicPartition, beginningOffset); + } else { + Map partitionOffset = + consumer.offsetsForTimes( + Collections.singletonMap(topicPartition, startTimeMs / 1000), OPERATION_TIMEOUT); + OffsetAndTimestamp offset = partitionOffset.get(topicPartition); + if (offset == null) { + throw new GeaFlowDSLException( + "Cannot get offset for partition: {}, " + "startTime: {}.", + topicPartition, + startTimeMs); } else { - nextKafkaOffset = new KafkaOffset(nextOffset, System.currentTimeMillis()); + consumer.seek(topicPartition, offset.offset()); } - return (FetchData) FetchData.createStreamFetch(dataList, nextKafkaOffset, false); + return Tuple.of(topicPartition, offset.offset()); + } } - - @Override - public FetchData fetch(Partition partition, Optional startOffset, TimeFetchWindow windowInfo) throws IOException { - Tuple partitionWithOffset = fetchPrepare(partition, startOffset); - long windowStartTimeMs = windowInfo.getStartWindowTime(startTimeMs); - long windowEndTimeMs = windowInfo.getEndWindowTime(startTimeMs); - OffsetAndTimestamp windowStartOffset = queryTimesOffset(partitionWithOffset.f0, windowStartTimeMs); - if (windowStartOffset == null || windowStartOffset.timestamp() >= windowEndTimeMs) { - // no data in current window, skip! - KafkaOffset offset = new KafkaOffset(partitionWithOffset.f1, windowEndTimeMs); - return (FetchData) FetchData.createStreamFetch(Collections.EMPTY_LIST, offset, false); - } - List dataList = new ArrayList<>(); - long responseMaxTimestamp = -1; - while (responseMaxTimestamp < windowEndTimeMs && responseMaxTimestamp < System.currentTimeMillis()) { - ConsumerRecords records = consumer.poll(pollTimeout); - if (!records.isEmpty()) { - for (ConsumerRecord record : records) { - dataList.add(record.value()); - if (record.timestamp() > responseMaxTimestamp) { - responseMaxTimestamp = record.timestamp(); - } - } - } else if (windowEndTimeMs > System.currentTimeMillis()) { - // no new msg, break; - break; - } - } - //reload cursor - long nextOffset = consumer.position(partitionWithOffset.f0, OPERATION_TIMEOUT); - KafkaOffset nextKafkaOffset; - if (responseMaxTimestamp >= 0) { - nextKafkaOffset = new KafkaOffset(nextOffset, responseMaxTimestamp); - } else { - nextKafkaOffset = new KafkaOffset(nextOffset, System.currentTimeMillis()); + } + + @Override + public FetchData fetch( + Partition partition, Optional startOffset, SizeFetchWindow windowInfo) + throws IOException { + TopicPartition topicPartition = fetchPrepare(partition, startOffset).f0; + + List dataList = new ArrayList<>(); + long responseMaxTimestamp = -1; + while (dataList.size() < windowInfo.windowSize()) { + ConsumerRecords records = consumer.poll(pollTimeout); + if (records.isEmpty()) { + break; + } + for (ConsumerRecord record : records) { + assert record.topic().equals(this.topic) : "Illegal topic"; + dataList.add(record.value()); + if (record.timestamp() > responseMaxTimestamp) { + responseMaxTimestamp = record.timestamp(); } - return (FetchData) FetchData.createStreamFetch(dataList, nextKafkaOffset, false); + } } - - private OffsetAndTimestamp queryTimesOffset(TopicPartition topicPartition, long timestampInMs) { - Map partitionOffset = consumer.offsetsForTimes( - Collections.singletonMap(topicPartition, timestampInMs / 1000), OPERATION_TIMEOUT); - return partitionOffset.get(topicPartition); + // reload cursor + long nextOffset = consumer.position(topicPartition, OPERATION_TIMEOUT); + KafkaOffset nextKafkaOffset; + if (responseMaxTimestamp >= 0) { + nextKafkaOffset = new KafkaOffset(nextOffset, responseMaxTimestamp); + } else { + nextKafkaOffset = new KafkaOffset(nextOffset, System.currentTimeMillis()); } - - @Override - public void close() { - if (consumer != null) { - consumer.close(); - consumer = null; + return (FetchData) FetchData.createStreamFetch(dataList, nextKafkaOffset, false); + } + + @Override + public FetchData fetch( + Partition partition, Optional startOffset, TimeFetchWindow windowInfo) + throws IOException { + Tuple partitionWithOffset = fetchPrepare(partition, startOffset); + long windowStartTimeMs = windowInfo.getStartWindowTime(startTimeMs); + long windowEndTimeMs = windowInfo.getEndWindowTime(startTimeMs); + OffsetAndTimestamp windowStartOffset = + queryTimesOffset(partitionWithOffset.f0, windowStartTimeMs); + if (windowStartOffset == null || windowStartOffset.timestamp() >= windowEndTimeMs) { + // no data in current window, skip! + KafkaOffset offset = new KafkaOffset(partitionWithOffset.f1, windowEndTimeMs); + return (FetchData) FetchData.createStreamFetch(Collections.EMPTY_LIST, offset, false); + } + List dataList = new ArrayList<>(); + long responseMaxTimestamp = -1; + while (responseMaxTimestamp < windowEndTimeMs + && responseMaxTimestamp < System.currentTimeMillis()) { + ConsumerRecords records = consumer.poll(pollTimeout); + if (!records.isEmpty()) { + for (ConsumerRecord record : records) { + dataList.add(record.value()); + if (record.timestamp() > responseMaxTimestamp) { + responseMaxTimestamp = record.timestamp(); + } } - LOGGER.info("close"); + } else if (windowEndTimeMs > System.currentTimeMillis()) { + // no new msg, break; + break; + } + } + // reload cursor + long nextOffset = consumer.position(partitionWithOffset.f0, OPERATION_TIMEOUT); + KafkaOffset nextKafkaOffset; + if (responseMaxTimestamp >= 0) { + nextKafkaOffset = new KafkaOffset(nextOffset, responseMaxTimestamp); + } else { + nextKafkaOffset = new KafkaOffset(nextOffset, System.currentTimeMillis()); } + return (FetchData) FetchData.createStreamFetch(dataList, nextKafkaOffset, false); + } + private OffsetAndTimestamp queryTimesOffset(TopicPartition topicPartition, long timestampInMs) { + Map partitionOffset = + consumer.offsetsForTimes( + Collections.singletonMap(topicPartition, timestampInMs / 1000), OPERATION_TIMEOUT); + return partitionOffset.get(topicPartition); + } + + @Override + public void close() { + if (consumer != null) { + consumer.close(); + consumer = null; + } + LOGGER.info("close"); + } - public static class KafkaPartition implements Partition { + public static class KafkaPartition implements Partition { - private final String topic; - private final int partitionId; + private final String topic; + private final int partitionId; - public KafkaPartition(String topic, int partitionId) { - this.topic = topic; - this.partitionId = partitionId; - } - - @Override - public String getName() { - return topic + "-" + partitionId; - } + public KafkaPartition(String topic, int partitionId) { + this.topic = topic; + this.partitionId = partitionId; + } - @Override - public void setIndex(int index, int parallel) { - } + @Override + public String getName() { + return topic + "-" + partitionId; + } - @Override - public int hashCode() { - return Objects.hash(topic, partitionId); - } + @Override + public void setIndex(int index, int parallel) {} - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof KafkaPartition)) { - return false; - } - KafkaPartition that = (KafkaPartition) o; - return Objects.equals(topic, that.topic) && Objects.equals( - partitionId, that.partitionId); - } + @Override + public int hashCode() { + return Objects.hash(topic, partitionId); + } - public String getTopic() { - return topic; - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof KafkaPartition)) { + return false; + } + KafkaPartition that = (KafkaPartition) o; + return Objects.equals(topic, that.topic) && Objects.equals(partitionId, that.partitionId); + } - public int getPartition() { - return partitionId; - } + public String getTopic() { + return topic; } + public int getPartition() { + return partitionId; + } + } - public static class KafkaOffset implements Offset { + public static class KafkaOffset implements Offset { - private final long offset; + private final long offset; - private final long humanReadableTime; + private final long humanReadableTime; - public KafkaOffset(long offset, long humanReadableTime) { - this.offset = offset; - this.humanReadableTime = humanReadableTime; - } + public KafkaOffset(long offset, long humanReadableTime) { + this.offset = offset; + this.humanReadableTime = humanReadableTime; + } - @Override - public String humanReadable() { - return DateTimeUtil.fromUnixTime(humanReadableTime, ConnectorConstants.START_TIME_FORMAT); - } + @Override + public String humanReadable() { + return DateTimeUtil.fromUnixTime(humanReadableTime, ConnectorConstants.START_TIME_FORMAT); + } - @Override - public long getOffset() { - return humanReadableTime; - } + @Override + public long getOffset() { + return humanReadableTime; + } - public long getKafkaOffset() { - return offset; - } + public long getKafkaOffset() { + return offset; + } - @Override - public boolean isTimestamp() { - return true; - } + @Override + public boolean isTimestamp() { + return true; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/main/java/org/apache/geaflow/dsl/connector/kafka/utils/KafkaConstants.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/main/java/org/apache/geaflow/dsl/connector/kafka/utils/KafkaConstants.java index f4286709b..88f9d914a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/main/java/org/apache/geaflow/dsl/connector/kafka/utils/KafkaConstants.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/main/java/org/apache/geaflow/dsl/connector/kafka/utils/KafkaConstants.java @@ -21,20 +21,20 @@ public class KafkaConstants { - public static final String KAFKA_BOOTSTRAP_SERVERS = "bootstrap.servers"; - public static final String KAFKA_KEY_SERIALIZER = "key.serializer"; - public static final String KAFKA_VALUE_SERIALIZER = "value.serializer"; - public static final String KAFKA_KEY_DESERIALIZER = "key.deserializer"; - public static final String KAFKA_VALUE_DESERIALIZER = "value.deserializer"; - public static final String KAFKA_MAX_POLL_RECORDS = "max.poll.records"; - public static final String KAFKA_GROUP_ID = "group.id"; - public static final String KAFKA_CLIENT_ID = "client.id"; - public static final String KAFKA_VALUE_SERIALIZER_CLASS = - "org.apache.kafka.common.serialization.StringSerializer"; - public static final String KAFKA_VALUE_DESERIALIZER_CLASS = - "org.apache.kafka.common.serialization.StringDeserializer"; + public static final String KAFKA_BOOTSTRAP_SERVERS = "bootstrap.servers"; + public static final String KAFKA_KEY_SERIALIZER = "key.serializer"; + public static final String KAFKA_VALUE_SERIALIZER = "value.serializer"; + public static final String KAFKA_KEY_DESERIALIZER = "key.deserializer"; + public static final String KAFKA_VALUE_DESERIALIZER = "value.deserializer"; + public static final String KAFKA_MAX_POLL_RECORDS = "max.poll.records"; + public static final String KAFKA_GROUP_ID = "group.id"; + public static final String KAFKA_CLIENT_ID = "client.id"; + public static final String KAFKA_VALUE_SERIALIZER_CLASS = + "org.apache.kafka.common.serialization.StringSerializer"; + public static final String KAFKA_VALUE_DESERIALIZER_CLASS = + "org.apache.kafka.common.serialization.StringDeserializer"; - public static final String KAFKA_BEGIN = "begin"; - public static final String KAFKA_LATEST = "latest"; - public static final int KAFKA_OPERATION_TIMEOUT_SECONDS = 10; + public static final String KAFKA_BEGIN = "begin"; + public static final String KAFKA_LATEST = "latest"; + public static final int KAFKA_OPERATION_TIMEOUT_SECONDS = 10; } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/test/java/org/apache/geaflow/dsl/connector/kafka/KafkaTableConnectorTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/test/java/org/apache/geaflow/dsl/connector/kafka/KafkaTableConnectorTest.java index 2b996f1fc..7817c2f62 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/test/java/org/apache/geaflow/dsl/connector/kafka/KafkaTableConnectorTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-kafka/src/test/java/org/apache/geaflow/dsl/connector/kafka/KafkaTableConnectorTest.java @@ -19,9 +19,9 @@ package org.apache.geaflow.dsl.connector.kafka; -import com.alibaba.fastjson.JSON; import java.io.IOException; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.connector.api.function.OffsetStore.ConsoleOffset; import org.apache.geaflow.dsl.connector.api.serde.TableDeserializer; @@ -32,51 +32,53 @@ import org.testng.Assert; import org.testng.annotations.Test; +import com.alibaba.fastjson.JSON; + public class KafkaTableConnectorTest { - @Test - public void testKafkaPartition() { - KafkaPartition partition = new KafkaPartition("topic", 0); - Assert.assertEquals(partition.getTopic(), "topic"); - Assert.assertEquals(partition.getPartition(), 0); - Assert.assertEquals(partition.getName(), "topic-0"); + @Test + public void testKafkaPartition() { + KafkaPartition partition = new KafkaPartition("topic", 0); + Assert.assertEquals(partition.getTopic(), "topic"); + Assert.assertEquals(partition.getPartition(), 0); + Assert.assertEquals(partition.getName(), "topic-0"); - KafkaPartition partition2 = new KafkaPartition("topic", 0); - Assert.assertEquals(partition.hashCode(), partition2.hashCode()); - Assert.assertEquals(partition, partition); - Assert.assertEquals(partition, partition2); - Assert.assertNotEquals(partition, null); - } + KafkaPartition partition2 = new KafkaPartition("topic", 0); + Assert.assertEquals(partition.hashCode(), partition2.hashCode()); + Assert.assertEquals(partition, partition); + Assert.assertEquals(partition, partition2); + Assert.assertNotEquals(partition, null); + } - @Test - public void testKafkaOffset() { - KafkaOffset offset = new KafkaOffset(100, 11111111); - Assert.assertEquals(offset.getKafkaOffset(), 100L); - Assert.assertEquals(offset.humanReadable(), "1970-01-01 11:05:11"); - } + @Test + public void testKafkaOffset() { + KafkaOffset offset = new KafkaOffset(100, 11111111); + Assert.assertEquals(offset.getKafkaOffset(), 100L); + Assert.assertEquals(offset.humanReadable(), "1970-01-01 11:05:11"); + } - @Test - public void testConsoleOffset() throws IOException { - KafkaOffset test = new KafkaOffset(111L, 11111111L); - Map kvMap = JSON.parseObject(new ConsoleOffset(test).toJson(), Map.class); - Assert.assertEquals(kvMap.get("offset"), "11111111"); - Assert.assertEquals(kvMap.get("type"), "TIMESTAMP"); - Assert.assertTrue(Long.parseLong(kvMap.get("writeTime")) > 0L); - } + @Test + public void testConsoleOffset() throws IOException { + KafkaOffset test = new KafkaOffset(111L, 11111111L); + Map kvMap = JSON.parseObject(new ConsoleOffset(test).toJson(), Map.class); + Assert.assertEquals(kvMap.get("offset"), "11111111"); + Assert.assertEquals(kvMap.get("type"), "TIMESTAMP"); + Assert.assertTrue(Long.parseLong(kvMap.get("writeTime")) > 0L); + } - @Test - public void testJsonDeserializer() { - KafkaTableSource kafkaTableSource = new KafkaTableSource(); - Configuration conf = new Configuration(); - conf.put("geaflow.dsl.connector.format", "json"); - TableDeserializer deserializer = kafkaTableSource.getDeserializer(conf); - Assert.assertEquals(deserializer.getClass(), JsonDeserializer.class); - } + @Test + public void testJsonDeserializer() { + KafkaTableSource kafkaTableSource = new KafkaTableSource(); + Configuration conf = new Configuration(); + conf.put("geaflow.dsl.connector.format", "json"); + TableDeserializer deserializer = kafkaTableSource.getDeserializer(conf); + Assert.assertEquals(deserializer.getClass(), JsonDeserializer.class); + } - @Test - public void testTextDeserializer() { - KafkaTableSource kafkaTableSource = new KafkaTableSource(); - TableDeserializer deserializer = kafkaTableSource.getDeserializer(new Configuration()); - Assert.assertEquals(deserializer.getClass(), TextDeserializer.class); - } + @Test + public void testTextDeserializer() { + KafkaTableSource kafkaTableSource = new KafkaTableSource(); + TableDeserializer deserializer = kafkaTableSource.getDeserializer(new Configuration()); + Assert.assertEquals(deserializer.getClass(), TextDeserializer.class); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/main/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jConfigKeys.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/main/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jConfigKeys.java index 6c755547d..14ca92576 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/main/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jConfigKeys.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/main/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jConfigKeys.java @@ -32,78 +32,79 @@ public class Neo4jConfigKeys { - public static final ConfigKey GEAFLOW_DSL_NEO4J_URI = ConfigKeys - .key("geaflow.dsl.neo4j.uri") - .noDefaultValue() - .description("Neo4j database URI (e.g., bolt://localhost:7687)."); - - public static final ConfigKey GEAFLOW_DSL_NEO4J_USERNAME = ConfigKeys - .key("geaflow.dsl.neo4j.username") - .noDefaultValue() - .description("Neo4j database username."); - - public static final ConfigKey GEAFLOW_DSL_NEO4J_PASSWORD = ConfigKeys - .key("geaflow.dsl.neo4j.password") - .noDefaultValue() - .description("Neo4j database password."); - - public static final ConfigKey GEAFLOW_DSL_NEO4J_DATABASE = ConfigKeys - .key("geaflow.dsl.neo4j.database") - .defaultValue(DEFAULT_DATABASE) - .description("Neo4j database name."); - - public static final ConfigKey GEAFLOW_DSL_NEO4J_BATCH_SIZE = ConfigKeys - .key("geaflow.dsl.neo4j.batch.size") - .defaultValue(DEFAULT_BATCH_SIZE) - .description("Batch size for writing to Neo4j."); - - public static final ConfigKey GEAFLOW_DSL_NEO4J_MAX_CONNECTION_LIFETIME = ConfigKeys - .key("geaflow.dsl.neo4j.max.connection.lifetime.millis") - .defaultValue(DEFAULT_MAX_CONNECTION_LIFETIME_MILLIS) - .description("Maximum lifetime of a connection in milliseconds."); - - public static final ConfigKey GEAFLOW_DSL_NEO4J_MAX_CONNECTION_POOL_SIZE = ConfigKeys - .key("geaflow.dsl.neo4j.max.connection.pool.size") - .defaultValue(DEFAULT_MAX_CONNECTION_POOL_SIZE) - .description("Maximum size of the connection pool."); - - public static final ConfigKey GEAFLOW_DSL_NEO4J_CONNECTION_ACQUISITION_TIMEOUT = ConfigKeys - .key("geaflow.dsl.neo4j.connection.acquisition.timeout.millis") - .defaultValue(DEFAULT_CONNECTION_ACQUISITION_TIMEOUT_MILLIS) - .description("Timeout for acquiring a connection from the pool in milliseconds."); - - public static final ConfigKey GEAFLOW_DSL_NEO4J_QUERY = ConfigKeys - .key("geaflow.dsl.neo4j.query") - .noDefaultValue() - .description("Cypher query for reading data from Neo4j."); - - public static final ConfigKey GEAFLOW_DSL_NEO4J_NODE_LABEL = ConfigKeys - .key("geaflow.dsl.neo4j.node.label") - .defaultValue(DEFAULT_NODE_LABEL) - .description("Node label for writing nodes to Neo4j."); - - public static final ConfigKey GEAFLOW_DSL_NEO4J_RELATIONSHIP_TYPE = ConfigKeys - .key("geaflow.dsl.neo4j.relationship.type") - .defaultValue(DEFAULT_RELATIONSHIP_TYPE) - .description("Relationship type for writing relationships to Neo4j."); - - public static final ConfigKey GEAFLOW_DSL_NEO4J_WRITE_MODE = ConfigKeys - .key("geaflow.dsl.neo4j.write.mode") - .defaultValue("node") - .description("Write mode: 'node' for writing nodes, 'relationship' for writing relationships."); - - public static final ConfigKey GEAFLOW_DSL_NEO4J_NODE_ID_FIELD = ConfigKeys - .key("geaflow.dsl.neo4j.node.id.field") - .noDefaultValue() - .description("Field name to use as node ID."); - - public static final ConfigKey GEAFLOW_DSL_NEO4J_RELATIONSHIP_SOURCE_FIELD = ConfigKeys - .key("geaflow.dsl.neo4j.relationship.source.field") - .noDefaultValue() - .description("Field name for relationship source node ID."); - - public static final ConfigKey GEAFLOW_DSL_NEO4J_RELATIONSHIP_TARGET_FIELD = ConfigKeys - .key("geaflow.dsl.neo4j.relationship.target.field") - .noDefaultValue() - .description("Field name for relationship target node ID."); + public static final ConfigKey GEAFLOW_DSL_NEO4J_URI = + ConfigKeys.key("geaflow.dsl.neo4j.uri") + .noDefaultValue() + .description("Neo4j database URI (e.g., bolt://localhost:7687)."); + + public static final ConfigKey GEAFLOW_DSL_NEO4J_USERNAME = + ConfigKeys.key("geaflow.dsl.neo4j.username") + .noDefaultValue() + .description("Neo4j database username."); + + public static final ConfigKey GEAFLOW_DSL_NEO4J_PASSWORD = + ConfigKeys.key("geaflow.dsl.neo4j.password") + .noDefaultValue() + .description("Neo4j database password."); + + public static final ConfigKey GEAFLOW_DSL_NEO4J_DATABASE = + ConfigKeys.key("geaflow.dsl.neo4j.database") + .defaultValue(DEFAULT_DATABASE) + .description("Neo4j database name."); + + public static final ConfigKey GEAFLOW_DSL_NEO4J_BATCH_SIZE = + ConfigKeys.key("geaflow.dsl.neo4j.batch.size") + .defaultValue(DEFAULT_BATCH_SIZE) + .description("Batch size for writing to Neo4j."); + + public static final ConfigKey GEAFLOW_DSL_NEO4J_MAX_CONNECTION_LIFETIME = + ConfigKeys.key("geaflow.dsl.neo4j.max.connection.lifetime.millis") + .defaultValue(DEFAULT_MAX_CONNECTION_LIFETIME_MILLIS) + .description("Maximum lifetime of a connection in milliseconds."); + + public static final ConfigKey GEAFLOW_DSL_NEO4J_MAX_CONNECTION_POOL_SIZE = + ConfigKeys.key("geaflow.dsl.neo4j.max.connection.pool.size") + .defaultValue(DEFAULT_MAX_CONNECTION_POOL_SIZE) + .description("Maximum size of the connection pool."); + + public static final ConfigKey GEAFLOW_DSL_NEO4J_CONNECTION_ACQUISITION_TIMEOUT = + ConfigKeys.key("geaflow.dsl.neo4j.connection.acquisition.timeout.millis") + .defaultValue(DEFAULT_CONNECTION_ACQUISITION_TIMEOUT_MILLIS) + .description("Timeout for acquiring a connection from the pool in milliseconds."); + + public static final ConfigKey GEAFLOW_DSL_NEO4J_QUERY = + ConfigKeys.key("geaflow.dsl.neo4j.query") + .noDefaultValue() + .description("Cypher query for reading data from Neo4j."); + + public static final ConfigKey GEAFLOW_DSL_NEO4J_NODE_LABEL = + ConfigKeys.key("geaflow.dsl.neo4j.node.label") + .defaultValue(DEFAULT_NODE_LABEL) + .description("Node label for writing nodes to Neo4j."); + + public static final ConfigKey GEAFLOW_DSL_NEO4J_RELATIONSHIP_TYPE = + ConfigKeys.key("geaflow.dsl.neo4j.relationship.type") + .defaultValue(DEFAULT_RELATIONSHIP_TYPE) + .description("Relationship type for writing relationships to Neo4j."); + + public static final ConfigKey GEAFLOW_DSL_NEO4J_WRITE_MODE = + ConfigKeys.key("geaflow.dsl.neo4j.write.mode") + .defaultValue("node") + .description( + "Write mode: 'node' for writing nodes, 'relationship' for writing relationships."); + + public static final ConfigKey GEAFLOW_DSL_NEO4J_NODE_ID_FIELD = + ConfigKeys.key("geaflow.dsl.neo4j.node.id.field") + .noDefaultValue() + .description("Field name to use as node ID."); + + public static final ConfigKey GEAFLOW_DSL_NEO4J_RELATIONSHIP_SOURCE_FIELD = + ConfigKeys.key("geaflow.dsl.neo4j.relationship.source.field") + .noDefaultValue() + .description("Field name for relationship source node ID."); + + public static final ConfigKey GEAFLOW_DSL_NEO4J_RELATIONSHIP_TARGET_FIELD = + ConfigKeys.key("geaflow.dsl.neo4j.relationship.target.field") + .noDefaultValue() + .description("Field name for relationship target node ID."); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/main/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jConstants.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/main/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jConstants.java index 153b1f5c0..105605311 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/main/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jConstants.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/main/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jConstants.java @@ -21,19 +21,19 @@ public class Neo4jConstants { - public static final String DEFAULT_DATABASE = "neo4j"; + public static final String DEFAULT_DATABASE = "neo4j"; - public static final int DEFAULT_BATCH_SIZE = 1000; + public static final int DEFAULT_BATCH_SIZE = 1000; - public static final long DEFAULT_MAX_CONNECTION_LIFETIME_MILLIS = 3600000L; // 1 hour + public static final long DEFAULT_MAX_CONNECTION_LIFETIME_MILLIS = 3600000L; // 1 hour - public static final int DEFAULT_MAX_CONNECTION_POOL_SIZE = 100; + public static final int DEFAULT_MAX_CONNECTION_POOL_SIZE = 100; - public static final long DEFAULT_CONNECTION_ACQUISITION_TIMEOUT_MILLIS = 60000L; // 1 minute + public static final long DEFAULT_CONNECTION_ACQUISITION_TIMEOUT_MILLIS = 60000L; // 1 minute - public static final String DEFAULT_NODE_LABEL = "Node"; + public static final String DEFAULT_NODE_LABEL = "Node"; - public static final String DEFAULT_RELATIONSHIP_LABEL = "relationship"; + public static final String DEFAULT_RELATIONSHIP_LABEL = "relationship"; - public static final String DEFAULT_RELATIONSHIP_TYPE = "RELATES_TO"; + public static final String DEFAULT_RELATIONSHIP_TYPE = "RELATES_TO"; } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/main/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jTableConnector.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/main/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jTableConnector.java index 179298d23..6825ef2a9 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/main/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jTableConnector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/main/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jTableConnector.java @@ -27,20 +27,20 @@ public class Neo4jTableConnector implements TableReadableConnector, TableWritableConnector { - public static final String TYPE = "Neo4j"; + public static final String TYPE = "Neo4j"; - @Override - public String getType() { - return TYPE; - } + @Override + public String getType() { + return TYPE; + } - @Override - public TableSource createSource(Configuration conf) { - return new Neo4jTableSource(); - } + @Override + public TableSource createSource(Configuration conf) { + return new Neo4jTableSource(); + } - @Override - public TableSink createSink(Configuration conf) { - return new Neo4jTableSink(); - } + @Override + public TableSink createSink(Configuration conf) { + return new Neo4jTableSink(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/main/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jTableSink.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/main/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jTableSink.java index 239bc0015..ed86166da 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/main/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jTableSink.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/main/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jTableSink.java @@ -28,6 +28,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.type.IType; @@ -47,255 +48,263 @@ public class Neo4jTableSink implements TableSink { - private static final Logger LOGGER = LoggerFactory.getLogger(Neo4jTableSink.class); - - private StructType schema; - private String uri; - private String username; - private String password; - private String database; - private int batchSize; - private String writeMode; - private String nodeLabel; - private String relationshipType; - private String nodeIdField; - private String relationshipSourceField; - private String relationshipTargetField; - private long maxConnectionLifetime; - private int maxConnectionPoolSize; - private long connectionAcquisitionTimeout; - - private Driver driver; - private Session session; - private Transaction transaction; - private List batch; - - @Override - public void init(Configuration tableConf, StructType schema) { - LOGGER.info("Init Neo4j sink with config: {}, \n schema: {}", tableConf, schema); - this.schema = schema; - - this.uri = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_URI); - this.username = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_USERNAME); - this.password = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_PASSWORD); - this.database = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_DATABASE); - this.batchSize = tableConf.getInteger(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_BATCH_SIZE); - this.writeMode = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_WRITE_MODE); - this.nodeLabel = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_NODE_LABEL); - this.relationshipType = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_RELATIONSHIP_TYPE); - this.nodeIdField = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_NODE_ID_FIELD); - this.relationshipSourceField = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_RELATIONSHIP_SOURCE_FIELD); - this.relationshipTargetField = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_RELATIONSHIP_TARGET_FIELD); - this.maxConnectionLifetime = tableConf.getLong(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_MAX_CONNECTION_LIFETIME); - this.maxConnectionPoolSize = tableConf.getInteger(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_MAX_CONNECTION_POOL_SIZE); - this.connectionAcquisitionTimeout = tableConf.getLong(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_CONNECTION_ACQUISITION_TIMEOUT); - - validateConfig(); - this.batch = new ArrayList<>(batchSize); + private static final Logger LOGGER = LoggerFactory.getLogger(Neo4jTableSink.class); + + private StructType schema; + private String uri; + private String username; + private String password; + private String database; + private int batchSize; + private String writeMode; + private String nodeLabel; + private String relationshipType; + private String nodeIdField; + private String relationshipSourceField; + private String relationshipTargetField; + private long maxConnectionLifetime; + private int maxConnectionPoolSize; + private long connectionAcquisitionTimeout; + + private Driver driver; + private Session session; + private Transaction transaction; + private List batch; + + @Override + public void init(Configuration tableConf, StructType schema) { + LOGGER.info("Init Neo4j sink with config: {}, \n schema: {}", tableConf, schema); + this.schema = schema; + + this.uri = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_URI); + this.username = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_USERNAME); + this.password = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_PASSWORD); + this.database = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_DATABASE); + this.batchSize = tableConf.getInteger(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_BATCH_SIZE); + this.writeMode = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_WRITE_MODE); + this.nodeLabel = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_NODE_LABEL); + this.relationshipType = + tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_RELATIONSHIP_TYPE); + this.nodeIdField = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_NODE_ID_FIELD); + this.relationshipSourceField = + tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_RELATIONSHIP_SOURCE_FIELD); + this.relationshipTargetField = + tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_RELATIONSHIP_TARGET_FIELD); + this.maxConnectionLifetime = + tableConf.getLong(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_MAX_CONNECTION_LIFETIME); + this.maxConnectionPoolSize = + tableConf.getInteger(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_MAX_CONNECTION_POOL_SIZE); + this.connectionAcquisitionTimeout = + tableConf.getLong(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_CONNECTION_ACQUISITION_TIMEOUT); + + validateConfig(); + this.batch = new ArrayList<>(batchSize); + } + + private void validateConfig() { + if (uri == null || uri.isEmpty()) { + throw new GeaFlowDSLException("Neo4j URI must be specified"); + } + if (username == null || username.isEmpty()) { + throw new GeaFlowDSLException("Neo4j username must be specified"); + } + if (password == null || password.isEmpty()) { + throw new GeaFlowDSLException("Neo4j password must be specified"); } + if (DEFAULT_NODE_LABEL.toLowerCase().equals(writeMode)) { + if (nodeIdField == null || nodeIdField.isEmpty()) { + throw new GeaFlowDSLException("Node ID field must be specified for node write mode"); + } + } else if (DEFAULT_RELATIONSHIP_LABEL.equals(writeMode)) { + if (relationshipSourceField == null + || relationshipSourceField.isEmpty() + || relationshipTargetField == null + || relationshipTargetField.isEmpty()) { + throw new GeaFlowDSLException( + "Relationship source and target fields must be specified for relationship write mode"); + } + } else { + throw new GeaFlowDSLException( + "Invalid write mode: " + writeMode + ". Must be 'node' or 'relationship'"); + } + } - private void validateConfig() { - if (uri == null || uri.isEmpty()) { - throw new GeaFlowDSLException("Neo4j URI must be specified"); - } - if (username == null || username.isEmpty()) { - throw new GeaFlowDSLException("Neo4j username must be specified"); - } - if (password == null || password.isEmpty()) { - throw new GeaFlowDSLException("Neo4j password must be specified"); - } - if (DEFAULT_NODE_LABEL.toLowerCase().equals(writeMode)) { - if (nodeIdField == null || nodeIdField.isEmpty()) { - throw new GeaFlowDSLException("Node ID field must be specified for node write mode"); - } - } else if (DEFAULT_RELATIONSHIP_LABEL.equals(writeMode)) { - if (relationshipSourceField == null || relationshipSourceField.isEmpty() - || relationshipTargetField == null || relationshipTargetField.isEmpty()) { - throw new GeaFlowDSLException("Relationship source and target fields must be specified for relationship write mode"); - } - } else { - throw new GeaFlowDSLException("Invalid write mode: " + writeMode + ". Must be 'node' or 'relationship'"); - } + @Override + public void open(RuntimeContext context) { + try { + Config config = + Config.builder() + .withMaxConnectionLifetime(maxConnectionLifetime, TimeUnit.MILLISECONDS) + .withMaxConnectionPoolSize(maxConnectionPoolSize) + .withConnectionAcquisitionTimeout(connectionAcquisitionTimeout, TimeUnit.MILLISECONDS) + .build(); + + this.driver = GraphDatabase.driver(uri, AuthTokens.basic(username, password), config); + + SessionConfig sessionConfig = SessionConfig.builder().withDatabase(database).build(); + + this.session = driver.session(sessionConfig); + this.transaction = session.beginTransaction(); + + LOGGER.info("Neo4j connection established successfully"); + } catch (Exception e) { + throw new GeaFlowDSLException("Failed to connect to Neo4j: " + e.getMessage(), e); } + } - @Override - public void open(RuntimeContext context) { - try { - Config config = Config.builder() - .withMaxConnectionLifetime(maxConnectionLifetime, TimeUnit.MILLISECONDS) - .withMaxConnectionPoolSize(maxConnectionPoolSize) - .withConnectionAcquisitionTimeout(connectionAcquisitionTimeout, TimeUnit.MILLISECONDS) - .build(); - - this.driver = GraphDatabase.driver(uri, AuthTokens.basic(username, password), config); - - SessionConfig sessionConfig = SessionConfig.builder() - .withDatabase(database) - .build(); - - this.session = driver.session(sessionConfig); - this.transaction = session.beginTransaction(); - - LOGGER.info("Neo4j connection established successfully"); - } catch (Exception e) { - throw new GeaFlowDSLException("Failed to connect to Neo4j: " + e.getMessage(), e); - } + @Override + public void write(Row row) throws IOException { + batch.add(row); + if (batch.size() >= batchSize) { + flush(); } + } - @Override - public void write(Row row) throws IOException { - batch.add(row); - if (batch.size() >= batchSize) { - flush(); + @Override + public void finish() throws IOException { + if (!batch.isEmpty()) { + flush(); + } + try { + if (transaction != null) { + transaction.commit(); + transaction.close(); + transaction = null; + } + } catch (Exception e) { + LOGGER.error("Failed to commit transaction", e); + try { + if (transaction != null) { + transaction.rollback(); } + } catch (Exception ex) { + throw new GeaFlowDSLException("Failed to rollback transaction", ex); + } + throw new GeaFlowDSLException("Failed to finish writing to Neo4j", e); } + } - @Override - public void finish() throws IOException { - if (!batch.isEmpty()) { - flush(); - } - try { - if (transaction != null) { - transaction.commit(); - transaction.close(); - transaction = null; - } - } catch (Exception e) { - LOGGER.error("Failed to commit transaction", e); - try { - if (transaction != null) { - transaction.rollback(); - } - } catch (Exception ex) { - throw new GeaFlowDSLException("Failed to rollback transaction", ex); - } - throw new GeaFlowDSLException("Failed to finish writing to Neo4j", e); - } + @Override + public void close() { + try { + if (transaction != null) { + transaction.close(); + transaction = null; + } + if (session != null) { + session.close(); + session = null; + } + if (driver != null) { + driver.close(); + driver = null; + } + LOGGER.info("Neo4j connection closed successfully"); + } catch (Exception e) { + throw new GeaFlowDSLException("Failed to close Neo4j connection", e); } + } - @Override - public void close() { - try { - if (transaction != null) { - transaction.close(); - transaction = null; - } - if (session != null) { - session.close(); - session = null; - } - if (driver != null) { - driver.close(); - driver = null; - } - LOGGER.info("Neo4j connection closed successfully"); - } catch (Exception e) { - throw new GeaFlowDSLException("Failed to close Neo4j connection", e); - } + private void flush() { + if (batch.isEmpty()) { + return; } - private void flush() { - if (batch.isEmpty()) { - return; - } + try { + if (DEFAULT_NODE_LABEL.toLowerCase().equals(writeMode)) { + writeNodes(); + } else { + writeRelationships(); + } + batch.clear(); + } catch (Exception e) { + throw new GeaFlowDSLException("Failed to flush batch to Neo4j", e); + } + } - try { - if (DEFAULT_NODE_LABEL.toLowerCase().equals(writeMode)) { - writeNodes(); - } else { - writeRelationships(); - } - batch.clear(); - } catch (Exception e) { - throw new GeaFlowDSLException("Failed to flush batch to Neo4j", e); - } + private void writeNodes() { + List fieldNames = schema.getFieldNames(); + IType[] types = schema.getTypes(); + + int nodeIdIndex = fieldNames.indexOf(nodeIdField); + if (nodeIdIndex == -1) { + throw new GeaFlowDSLException("Node ID field not found in schema: " + nodeIdField); } - private void writeNodes() { - List fieldNames = schema.getFieldNames(); - IType[] types = schema.getTypes(); - - int nodeIdIndex = fieldNames.indexOf(nodeIdField); - if (nodeIdIndex == -1) { - throw new GeaFlowDSLException("Node ID field not found in schema: " + nodeIdField); + for (Row row : batch) { + Map properties = new HashMap<>(); + for (int i = 0; i < fieldNames.size(); i++) { + if (i == nodeIdIndex) { + continue; // Skip ID field, it will be used as node ID } - - for (Row row : batch) { - Map properties = new HashMap<>(); - for (int i = 0; i < fieldNames.size(); i++) { - if (i == nodeIdIndex) { - continue; // Skip ID field, it will be used as node ID - } - Object value = row.getField(i, types[i]); - if (value != null) { - properties.put(fieldNames.get(i), value); - } - } - - Object nodeId = row.getField(nodeIdIndex, types[nodeIdIndex]); - if (nodeId == null) { - throw new GeaFlowDSLException("Node ID cannot be null"); - } - - String cypher = String.format( - "MERGE (n:%s {id: $id}) SET n += $properties", - nodeLabel - ); - - Map parameters = new HashMap<>(); - parameters.put("id", nodeId); - parameters.put("properties", properties); - - transaction.run(cypher, parameters); + Object value = row.getField(i, types[i]); + if (value != null) { + properties.put(fieldNames.get(i), value); } + } + + Object nodeId = row.getField(nodeIdIndex, types[nodeIdIndex]); + if (nodeId == null) { + throw new GeaFlowDSLException("Node ID cannot be null"); + } + + String cypher = String.format("MERGE (n:%s {id: $id}) SET n += $properties", nodeLabel); + + Map parameters = new HashMap<>(); + parameters.put("id", nodeId); + parameters.put("properties", properties); + + transaction.run(cypher, parameters); } + } - private void writeRelationships() { - List fieldNames = schema.getFieldNames(); - IType[] types = schema.getTypes(); - - int sourceIndex = fieldNames.indexOf(relationshipSourceField); - int targetIndex = fieldNames.indexOf(relationshipTargetField); - - if (sourceIndex == -1) { - throw new GeaFlowDSLException("Relationship source field not found in schema: " + relationshipSourceField); + private void writeRelationships() { + List fieldNames = schema.getFieldNames(); + IType[] types = schema.getTypes(); + + int sourceIndex = fieldNames.indexOf(relationshipSourceField); + int targetIndex = fieldNames.indexOf(relationshipTargetField); + + if (sourceIndex == -1) { + throw new GeaFlowDSLException( + "Relationship source field not found in schema: " + relationshipSourceField); + } + if (targetIndex == -1) { + throw new GeaFlowDSLException( + "Relationship target field not found in schema: " + relationshipTargetField); + } + + for (Row row : batch) { + Object sourceId = row.getField(sourceIndex, types[sourceIndex]); + Object targetId = row.getField(targetIndex, types[targetIndex]); + + if (sourceId == null || targetId == null) { + throw new GeaFlowDSLException("Relationship source and target IDs cannot be null"); + } + + Map properties = new HashMap<>(); + for (int i = 0; i < fieldNames.size(); i++) { + if (i == sourceIndex || i == targetIndex) { + continue; // Skip source and target fields } - if (targetIndex == -1) { - throw new GeaFlowDSLException("Relationship target field not found in schema: " + relationshipTargetField); + Object value = row.getField(i, types[i]); + if (value != null) { + properties.put(fieldNames.get(i), value); } + } - for (Row row : batch) { - Object sourceId = row.getField(sourceIndex, types[sourceIndex]); - Object targetId = row.getField(targetIndex, types[targetIndex]); - - if (sourceId == null || targetId == null) { - throw new GeaFlowDSLException("Relationship source and target IDs cannot be null"); - } - - Map properties = new HashMap<>(); - for (int i = 0; i < fieldNames.size(); i++) { - if (i == sourceIndex || i == targetIndex) { - continue; // Skip source and target fields - } - Object value = row.getField(i, types[i]); - if (value != null) { - properties.put(fieldNames.get(i), value); - } - } - - final String cypher = String.format( - "MATCH (a {id: $sourceId}), (b {id: $targetId}) " - + "MERGE (a)-[r:%s]->(b) SET r += $properties", - relationshipType - ); - - Map parameters = new HashMap<>(); - parameters.put("sourceId", sourceId); - parameters.put("targetId", targetId); - parameters.put("properties", properties); - - transaction.run(cypher, parameters); - } + final String cypher = + String.format( + "MATCH (a {id: $sourceId}), (b {id: $targetId}) " + + "MERGE (a)-[r:%s]->(b) SET r += $properties", + relationshipType); + + Map parameters = new HashMap<>(); + parameters.put("sourceId", sourceId); + parameters.put("targetId", targetId); + parameters.put("properties", properties); + + transaction.run(cypher, parameters); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/main/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jTableSource.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/main/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jTableSource.java index 019e3667d..dade2f7ab 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/main/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jTableSource.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/main/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jTableSource.java @@ -29,6 +29,7 @@ import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.window.WindowType; import org.apache.geaflow.common.config.Configuration; @@ -58,243 +59,246 @@ public class Neo4jTableSource implements TableSource { - private static final Logger LOGGER = LoggerFactory.getLogger(Neo4jTableSource.class); - - private Configuration tableConf; - private StructType schema; - private String uri; - private String username; - private String password; - private String database; - private String cypherQuery; - private long maxConnectionLifetime; - private int maxConnectionPoolSize; - private long connectionAcquisitionTimeout; + private static final Logger LOGGER = LoggerFactory.getLogger(Neo4jTableSource.class); + + private Configuration tableConf; + private StructType schema; + private String uri; + private String username; + private String password; + private String database; + private String cypherQuery; + private long maxConnectionLifetime; + private int maxConnectionPoolSize; + private long connectionAcquisitionTimeout; + + private Driver driver; + private Map partitionSessionMap = new ConcurrentHashMap<>(); + + @Override + public void init(Configuration tableConf, TableSchema tableSchema) { + LOGGER.info("Init Neo4j source with config: {}, \n schema: {}", tableConf, tableSchema); + this.tableConf = tableConf; + this.schema = tableSchema; + + this.uri = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_URI); + this.username = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_USERNAME); + this.password = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_PASSWORD); + this.database = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_DATABASE); + this.cypherQuery = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_QUERY); + this.maxConnectionLifetime = + tableConf.getLong(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_MAX_CONNECTION_LIFETIME); + this.maxConnectionPoolSize = + tableConf.getInteger(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_MAX_CONNECTION_POOL_SIZE); + this.connectionAcquisitionTimeout = + tableConf.getLong(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_CONNECTION_ACQUISITION_TIMEOUT); + + if (cypherQuery == null || cypherQuery.isEmpty()) { + throw new GeaFlowDSLException("Neo4j query must be specified"); + } + } + + @Override + public void open(RuntimeContext context) { + try { + Config config = + Config.builder() + .withMaxConnectionLifetime(maxConnectionLifetime, TimeUnit.MILLISECONDS) + .withMaxConnectionPoolSize(maxConnectionPoolSize) + .withConnectionAcquisitionTimeout(connectionAcquisitionTimeout, TimeUnit.MILLISECONDS) + .build(); + + this.driver = GraphDatabase.driver(uri, AuthTokens.basic(username, password), config); + LOGGER.info("Neo4j driver created successfully"); + } catch (Exception e) { + throw new GeaFlowDSLException("Failed to create Neo4j driver: " + e.getMessage(), e); + } + } + + @Override + public List listPartitions() { + // Neo4j doesn't have native partitioning like JDBC + // For simplicity, we return a single partition + return Collections.singletonList(new Neo4jPartition(cypherQuery)); + } + + @Override + public List listPartitions(int parallelism) { + return listPartitions(); + } + + @Override + public TableDeserializer getDeserializer(Configuration conf) { + return DeserializerFactory.loadRowTableDeserializer(); + } + + @Override + public FetchData fetch( + Partition partition, Optional startOffset, FetchWindow windowInfo) + throws IOException { + if (!(windowInfo.getType() == WindowType.SIZE_TUMBLING_WINDOW + || windowInfo.getType() == WindowType.ALL_WINDOW)) { + throw new GeaFlowDSLException("Not support window type: {}", windowInfo.getType()); + } - private Driver driver; - private Map partitionSessionMap = new ConcurrentHashMap<>(); + Neo4jPartition neo4jPartition = (Neo4jPartition) partition; + Session session = partitionSessionMap.get(partition); - @Override - public void init(Configuration tableConf, TableSchema tableSchema) { - LOGGER.info("Init Neo4j source with config: {}, \n schema: {}", tableConf, tableSchema); - this.tableConf = tableConf; - this.schema = tableSchema; - - this.uri = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_URI); - this.username = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_USERNAME); - this.password = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_PASSWORD); - this.database = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_DATABASE); - this.cypherQuery = tableConf.getString(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_QUERY); - this.maxConnectionLifetime = tableConf.getLong(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_MAX_CONNECTION_LIFETIME); - this.maxConnectionPoolSize = tableConf.getInteger(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_MAX_CONNECTION_POOL_SIZE); - this.connectionAcquisitionTimeout = tableConf.getLong(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_CONNECTION_ACQUISITION_TIMEOUT); - - if (cypherQuery == null || cypherQuery.isEmpty()) { - throw new GeaFlowDSLException("Neo4j query must be specified"); - } + if (session == null) { + SessionConfig sessionConfig = SessionConfig.builder().withDatabase(database).build(); + session = driver.session(sessionConfig); + partitionSessionMap.put(partition, session); } - @Override - public void open(RuntimeContext context) { - try { - Config config = Config.builder() - .withMaxConnectionLifetime(maxConnectionLifetime, TimeUnit.MILLISECONDS) - .withMaxConnectionPoolSize(maxConnectionPoolSize) - .withConnectionAcquisitionTimeout(connectionAcquisitionTimeout, TimeUnit.MILLISECONDS) - .build(); - - this.driver = GraphDatabase.driver(uri, AuthTokens.basic(username, password), config); - LOGGER.info("Neo4j driver created successfully"); - } catch (Exception e) { - throw new GeaFlowDSLException("Failed to create Neo4j driver: " + e.getMessage(), e); - } - } + long offset = startOffset.isPresent() ? startOffset.get().getOffset() : 0; - @Override - public List listPartitions() { - // Neo4j doesn't have native partitioning like JDBC - // For simplicity, we return a single partition - return Collections.singletonList(new Neo4jPartition(cypherQuery)); - } + List dataList = new ArrayList<>(); + try { + String query = neo4jPartition.getQuery(); + // Add SKIP and LIMIT to the query for pagination + String paginatedQuery = query + " SKIP $skip LIMIT $limit"; - @Override - public List listPartitions(int parallelism) { - return listPartitions(); - } + Map parameters = new HashMap<>(); + parameters.put("skip", offset); + parameters.put("limit", windowInfo.windowSize()); - @Override - public TableDeserializer getDeserializer(Configuration conf) { - return DeserializerFactory.loadRowTableDeserializer(); - } + Result result = session.run(paginatedQuery, parameters); - @Override - public FetchData fetch(Partition partition, Optional startOffset, - FetchWindow windowInfo) throws IOException { - if (!(windowInfo.getType() == WindowType.SIZE_TUMBLING_WINDOW - || windowInfo.getType() == WindowType.ALL_WINDOW)) { - throw new GeaFlowDSLException("Not support window type: {}", windowInfo.getType()); - } + List fieldNames = schema.getFieldNames(); - Neo4jPartition neo4jPartition = (Neo4jPartition) partition; - Session session = partitionSessionMap.get(partition); - - if (session == null) { - SessionConfig sessionConfig = SessionConfig.builder() - .withDatabase(database) - .build(); - session = driver.session(sessionConfig); - partitionSessionMap.put(partition, session); - } + while (result.hasNext()) { + Record record = result.next(); + Object[] values = new Object[fieldNames.size()]; - long offset = startOffset.isPresent() ? startOffset.get().getOffset() : 0; - - List dataList = new ArrayList<>(); - try { - String query = neo4jPartition.getQuery(); - // Add SKIP and LIMIT to the query for pagination - String paginatedQuery = query + " SKIP $skip LIMIT $limit"; - - Map parameters = new HashMap<>(); - parameters.put("skip", offset); - parameters.put("limit", windowInfo.windowSize()); - - Result result = session.run(paginatedQuery, parameters); - - List fieldNames = schema.getFieldNames(); - - while (result.hasNext()) { - Record record = result.next(); - Object[] values = new Object[fieldNames.size()]; - - for (int i = 0; i < fieldNames.size(); i++) { - String fieldName = fieldNames.get(i); - if (record.containsKey(fieldName)) { - Value value = record.get(fieldName); - values[i] = convertNeo4jValue(value); - } else { - values[i] = null; - } - } - - dataList.add(ObjectRow.create(values)); - } - - } catch (Exception e) { - throw new GeaFlowDSLException("Failed to fetch data from Neo4j", e); + for (int i = 0; i < fieldNames.size(); i++) { + String fieldName = fieldNames.get(i); + if (record.containsKey(fieldName)) { + Value value = record.get(fieldName); + values[i] = convertNeo4jValue(value); + } else { + values[i] = null; + } } - Neo4jOffset nextOffset = new Neo4jOffset(offset + dataList.size()); - boolean isFinish = windowInfo.getType() == WindowType.ALL_WINDOW - || dataList.size() < windowInfo.windowSize(); - - return (FetchData) FetchData.createStreamFetch(dataList, nextOffset, isFinish); + dataList.add(ObjectRow.create(values)); + } + + } catch (Exception e) { + throw new GeaFlowDSLException("Failed to fetch data from Neo4j", e); } - @Override - public void close() { - try { - for (Session session : partitionSessionMap.values()) { - if (session != null) { - session.close(); - } - } - partitionSessionMap.clear(); - - if (driver != null) { - driver.close(); - driver = null; - } - LOGGER.info("Neo4j connections closed successfully"); - } catch (Exception e) { - throw new GeaFlowDSLException("Failed to close Neo4j connections", e); + Neo4jOffset nextOffset = new Neo4jOffset(offset + dataList.size()); + boolean isFinish = + windowInfo.getType() == WindowType.ALL_WINDOW || dataList.size() < windowInfo.windowSize(); + + return (FetchData) FetchData.createStreamFetch(dataList, nextOffset, isFinish); + } + + @Override + public void close() { + try { + for (Session session : partitionSessionMap.values()) { + if (session != null) { + session.close(); } + } + partitionSessionMap.clear(); + + if (driver != null) { + driver.close(); + driver = null; + } + LOGGER.info("Neo4j connections closed successfully"); + } catch (Exception e) { + throw new GeaFlowDSLException("Failed to close Neo4j connections", e); } + } - private Object convertNeo4jValue(Value value) { - if (value.isNull()) { - return null; - } - - switch (value.type().name()) { - case "INTEGER": - return value.asLong(); - case "FLOAT": - return value.asDouble(); - case "STRING": - return value.asString(); - case "BOOLEAN": - return value.asBoolean(); - case "LIST": - return value.asList(); - case "MAP": - return value.asMap(); - case "NODE": - return value.asNode().asMap(); - case "RELATIONSHIP": - return value.asRelationship().asMap(); - case "PATH": - return value.asPath().toString(); - default: - return value.asObject(); - } + private Object convertNeo4jValue(Value value) { + if (value.isNull()) { + return null; } - public static class Neo4jPartition implements Partition { + switch (value.type().name()) { + case "INTEGER": + return value.asLong(); + case "FLOAT": + return value.asDouble(); + case "STRING": + return value.asString(); + case "BOOLEAN": + return value.asBoolean(); + case "LIST": + return value.asList(); + case "MAP": + return value.asMap(); + case "NODE": + return value.asNode().asMap(); + case "RELATIONSHIP": + return value.asRelationship().asMap(); + case "PATH": + return value.asPath().toString(); + default: + return value.asObject(); + } + } - private final String query; + public static class Neo4jPartition implements Partition { - public Neo4jPartition(String query) { - this.query = query; - } + private final String query; - public String getQuery() { - return query; - } + public Neo4jPartition(String query) { + this.query = query; + } - @Override - public String getName() { - return "neo4j-partition-" + query.hashCode(); - } + public String getQuery() { + return query; + } - @Override - public int hashCode() { - return Objects.hash(query); - } + @Override + public String getName() { + return "neo4j-partition-" + query.hashCode(); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof Neo4jPartition)) { - return false; - } - Neo4jPartition that = (Neo4jPartition) o; - return Objects.equals(query, that.query); - } + @Override + public int hashCode() { + return Objects.hash(query); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Neo4jPartition)) { + return false; + } + Neo4jPartition that = (Neo4jPartition) o; + return Objects.equals(query, that.query); } + } - public static class Neo4jOffset implements Offset { + public static class Neo4jOffset implements Offset { - private final long offset; + private final long offset; - public Neo4jOffset(long offset) { - this.offset = offset; - } + public Neo4jOffset(long offset) { + this.offset = offset; + } - @Override - public String humanReadable() { - return String.valueOf(offset); - } + @Override + public String humanReadable() { + return String.valueOf(offset); + } - @Override - public long getOffset() { - return offset; - } + @Override + public long getOffset() { + return offset; + } - @Override - public boolean isTimestamp() { - return false; - } + @Override + public boolean isTimestamp() { + return false; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/test/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jConfigKeysTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/test/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jConfigKeysTest.java index c93dd546d..78330bd33 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/test/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jConfigKeysTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/test/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jConfigKeysTest.java @@ -24,56 +24,62 @@ public class Neo4jConfigKeysTest { - @Test - public void testDefaultValues() { - Assert.assertEquals(Neo4jConstants.DEFAULT_DATABASE, "neo4j"); - Assert.assertEquals(Neo4jConstants.DEFAULT_BATCH_SIZE, 1000); - Assert.assertEquals(Neo4jConstants.DEFAULT_MAX_CONNECTION_LIFETIME_MILLIS, 3600000L); - Assert.assertEquals(Neo4jConstants.DEFAULT_MAX_CONNECTION_POOL_SIZE, 100); - Assert.assertEquals(Neo4jConstants.DEFAULT_CONNECTION_ACQUISITION_TIMEOUT_MILLIS, 60000L); - Assert.assertEquals(Neo4jConstants.DEFAULT_NODE_LABEL, "Node"); - Assert.assertEquals(Neo4jConstants.DEFAULT_RELATIONSHIP_TYPE, "RELATES_TO"); - } + @Test + public void testDefaultValues() { + Assert.assertEquals(Neo4jConstants.DEFAULT_DATABASE, "neo4j"); + Assert.assertEquals(Neo4jConstants.DEFAULT_BATCH_SIZE, 1000); + Assert.assertEquals(Neo4jConstants.DEFAULT_MAX_CONNECTION_LIFETIME_MILLIS, 3600000L); + Assert.assertEquals(Neo4jConstants.DEFAULT_MAX_CONNECTION_POOL_SIZE, 100); + Assert.assertEquals(Neo4jConstants.DEFAULT_CONNECTION_ACQUISITION_TIMEOUT_MILLIS, 60000L); + Assert.assertEquals(Neo4jConstants.DEFAULT_NODE_LABEL, "Node"); + Assert.assertEquals(Neo4jConstants.DEFAULT_RELATIONSHIP_TYPE, "RELATES_TO"); + } - @Test - public void testConfigKeyNames() { - Assert.assertEquals(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_URI.getKey(), - "geaflow.dsl.neo4j.uri"); - Assert.assertEquals(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_USERNAME.getKey(), - "geaflow.dsl.neo4j.username"); - Assert.assertEquals(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_PASSWORD.getKey(), - "geaflow.dsl.neo4j.password"); - Assert.assertEquals(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_DATABASE.getKey(), - "geaflow.dsl.neo4j.database"); - Assert.assertEquals(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_BATCH_SIZE.getKey(), - "geaflow.dsl.neo4j.batch.size"); - Assert.assertEquals(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_QUERY.getKey(), - "geaflow.dsl.neo4j.query"); - Assert.assertEquals(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_NODE_LABEL.getKey(), - "geaflow.dsl.neo4j.node.label"); - Assert.assertEquals(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_RELATIONSHIP_TYPE.getKey(), - "geaflow.dsl.neo4j.relationship.type"); - Assert.assertEquals(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_WRITE_MODE.getKey(), - "geaflow.dsl.neo4j.write.mode"); - Assert.assertEquals(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_NODE_ID_FIELD.getKey(), - "geaflow.dsl.neo4j.node.id.field"); - Assert.assertEquals(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_RELATIONSHIP_SOURCE_FIELD.getKey(), - "geaflow.dsl.neo4j.relationship.source.field"); - Assert.assertEquals(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_RELATIONSHIP_TARGET_FIELD.getKey(), - "geaflow.dsl.neo4j.relationship.target.field"); - } + @Test + public void testConfigKeyNames() { + Assert.assertEquals(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_URI.getKey(), "geaflow.dsl.neo4j.uri"); + Assert.assertEquals( + Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_USERNAME.getKey(), "geaflow.dsl.neo4j.username"); + Assert.assertEquals( + Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_PASSWORD.getKey(), "geaflow.dsl.neo4j.password"); + Assert.assertEquals( + Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_DATABASE.getKey(), "geaflow.dsl.neo4j.database"); + Assert.assertEquals( + Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_BATCH_SIZE.getKey(), "geaflow.dsl.neo4j.batch.size"); + Assert.assertEquals( + Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_QUERY.getKey(), "geaflow.dsl.neo4j.query"); + Assert.assertEquals( + Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_NODE_LABEL.getKey(), "geaflow.dsl.neo4j.node.label"); + Assert.assertEquals( + Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_RELATIONSHIP_TYPE.getKey(), + "geaflow.dsl.neo4j.relationship.type"); + Assert.assertEquals( + Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_WRITE_MODE.getKey(), "geaflow.dsl.neo4j.write.mode"); + Assert.assertEquals( + Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_NODE_ID_FIELD.getKey(), + "geaflow.dsl.neo4j.node.id.field"); + Assert.assertEquals( + Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_RELATIONSHIP_SOURCE_FIELD.getKey(), + "geaflow.dsl.neo4j.relationship.source.field"); + Assert.assertEquals( + Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_RELATIONSHIP_TARGET_FIELD.getKey(), + "geaflow.dsl.neo4j.relationship.target.field"); + } - @Test - public void testConfigKeyDefaults() { - Assert.assertEquals(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_DATABASE.getDefaultValue(), - Neo4jConstants.DEFAULT_DATABASE); - Assert.assertEquals(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_BATCH_SIZE.getDefaultValue(), - Neo4jConstants.DEFAULT_BATCH_SIZE); - Assert.assertEquals(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_NODE_LABEL.getDefaultValue(), - Neo4jConstants.DEFAULT_NODE_LABEL); - Assert.assertEquals(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_RELATIONSHIP_TYPE.getDefaultValue(), - Neo4jConstants.DEFAULT_RELATIONSHIP_TYPE); - Assert.assertEquals(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_WRITE_MODE.getDefaultValue(), - "node"); - } + @Test + public void testConfigKeyDefaults() { + Assert.assertEquals( + Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_DATABASE.getDefaultValue(), + Neo4jConstants.DEFAULT_DATABASE); + Assert.assertEquals( + Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_BATCH_SIZE.getDefaultValue(), + Neo4jConstants.DEFAULT_BATCH_SIZE); + Assert.assertEquals( + Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_NODE_LABEL.getDefaultValue(), + Neo4jConstants.DEFAULT_NODE_LABEL); + Assert.assertEquals( + Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_RELATIONSHIP_TYPE.getDefaultValue(), + Neo4jConstants.DEFAULT_RELATIONSHIP_TYPE); + Assert.assertEquals(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_WRITE_MODE.getDefaultValue(), "node"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/test/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jTableConnectorTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/test/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jTableConnectorTest.java index ed2f3dae3..a255e9cc7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/test/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jTableConnectorTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-neo4j/src/test/java/org/apache/geaflow/dsl/connector/neo4j/Neo4jTableConnectorTest.java @@ -28,54 +28,54 @@ public class Neo4jTableConnectorTest { - private Neo4jTableConnector connector; - private Configuration config; + private Neo4jTableConnector connector; + private Configuration config; - @BeforeMethod - public void setUp() { - connector = new Neo4jTableConnector(); - config = new Configuration(); - config.put(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_URI, "bolt://localhost:7687"); - config.put(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_USERNAME, "neo4j"); - config.put(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_PASSWORD, "password"); - } + @BeforeMethod + public void setUp() { + connector = new Neo4jTableConnector(); + config = new Configuration(); + config.put(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_URI, "bolt://localhost:7687"); + config.put(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_USERNAME, "neo4j"); + config.put(Neo4jConfigKeys.GEAFLOW_DSL_NEO4J_PASSWORD, "password"); + } - @Test - public void testGetType() { - Assert.assertEquals(connector.getType(), "Neo4j"); - } + @Test + public void testGetType() { + Assert.assertEquals(connector.getType(), "Neo4j"); + } - @Test - public void testCreateSource() { - TableSource source = connector.createSource(config); - Assert.assertNotNull(source); - Assert.assertTrue(source instanceof Neo4jTableSource); - } + @Test + public void testCreateSource() { + TableSource source = connector.createSource(config); + Assert.assertNotNull(source); + Assert.assertTrue(source instanceof Neo4jTableSource); + } - @Test - public void testCreateSink() { - TableSink sink = connector.createSink(config); - Assert.assertNotNull(sink); - Assert.assertTrue(sink instanceof Neo4jTableSink); - } + @Test + public void testCreateSink() { + TableSink sink = connector.createSink(config); + Assert.assertNotNull(sink); + Assert.assertTrue(sink instanceof Neo4jTableSink); + } - @Test - public void testMultipleSourceInstances() { - TableSource source1 = connector.createSource(config); - TableSource source2 = connector.createSource(config); + @Test + public void testMultipleSourceInstances() { + TableSource source1 = connector.createSource(config); + TableSource source2 = connector.createSource(config); - Assert.assertNotNull(source1); - Assert.assertNotNull(source2); - Assert.assertNotSame(source1, source2); - } + Assert.assertNotNull(source1); + Assert.assertNotNull(source2); + Assert.assertNotSame(source1, source2); + } - @Test - public void testMultipleSinkInstances() { - TableSink sink1 = connector.createSink(config); - TableSink sink2 = connector.createSink(config); + @Test + public void testMultipleSinkInstances() { + TableSink sink1 = connector.createSink(config); + TableSink sink2 = connector.createSink(config); - Assert.assertNotNull(sink1); - Assert.assertNotNull(sink2); - Assert.assertNotSame(sink1, sink2); - } + Assert.assertNotNull(sink1); + Assert.assertNotNull(sink2); + Assert.assertNotSame(sink1, sink2); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/DefaultPartitionExtractor.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/DefaultPartitionExtractor.java index d09a37e46..abce58032 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/DefaultPartitionExtractor.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/DefaultPartitionExtractor.java @@ -19,146 +19,156 @@ package org.apache.geaflow.dsl.connector.odps; -import com.aliyun.odps.Column; import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.types.StructType; +import com.aliyun.odps.Column; + public class DefaultPartitionExtractor implements PartitionExtractor { - private static final String DEFAULT_SEPARATOR_PATTERN = "[,/]"; - private static final String QUOTE_SEPARATOR_PATTERN = "[`'\"]"; - private static final String EQUAL_SEPARATOR = "="; - private static final String COMMA_SEPARATOR = ","; - private static final String SLASH_SEPARATOR = "/"; - private static final String DYNAMIC_KEY_PREFIX = "$"; + private static final String DEFAULT_SEPARATOR_PATTERN = "[,/]"; + private static final String QUOTE_SEPARATOR_PATTERN = "[`'\"]"; + private static final String EQUAL_SEPARATOR = "="; + private static final String COMMA_SEPARATOR = ","; + private static final String SLASH_SEPARATOR = "/"; + private static final String DYNAMIC_KEY_PREFIX = "$"; - // partition spec separator - private final String separator; - // all partition keys - private final String[] keys; - // dynamic fields index - private final int[] columns; - // dynamic field types - private final IType[] types; - // constant fields, values.length should be equal to keys.length - // if values[i] is null, it means the i-th key is a dynamic field - private final String[] values; + // partition spec separator + private final String separator; + // all partition keys + private final String[] keys; + // dynamic fields index + private final int[] columns; + // dynamic field types + private final IType[] types; + // constant fields, values.length should be equal to keys.length + // if values[i] is null, it means the i-th key is a dynamic field + private final String[] values; - /** - * Create a partition extractor. - * @param partitionColumns partition columns - * @param schema the input schema - * @return the partition extractor - */ - public static PartitionExtractor create(List partitionColumns, StructType schema) { - if (partitionColumns == null || partitionColumns.isEmpty()) { - return row -> ""; - } - int[] columns = new int[partitionColumns.size()]; - IType[] types = new IType[partitionColumns.size()]; - StringBuilder sb = new StringBuilder(); - for (int i = 0; i < partitionColumns.size(); i++) { - String partitionColumn = partitionColumns.get(i).getName(); - int index = schema.indexOf(partitionColumn); - if (index < 0) { - throw new IllegalArgumentException("Partition column " + partitionColumn + " not found in schema"); - } - columns[i] = index; - types[i] = schema.getType(index); - sb.append(partitionColumn).append(EQUAL_SEPARATOR).append(DYNAMIC_KEY_PREFIX) - .append(partitionColumn).append(COMMA_SEPARATOR); - } - return new DefaultPartitionExtractor(sb.substring(0, sb.length() - 1), columns, types); + /** + * Create a partition extractor. + * + * @param partitionColumns partition columns + * @param schema the input schema + * @return the partition extractor + */ + public static PartitionExtractor create(List partitionColumns, StructType schema) { + if (partitionColumns == null || partitionColumns.isEmpty()) { + return row -> ""; } - - /** - * Create a partition extractor. - * @param spec partition spec, like "dt=$dt,hh=$hh" - * @param schema the input schema - * @return the partition extractor - */ - public static PartitionExtractor create(String spec, StructType schema) { - if (spec == null || spec.isEmpty()) { - return row -> ""; - } - String[] groups = spec.split(DEFAULT_SEPARATOR_PATTERN); - List index = new ArrayList<>(); - List> types = new ArrayList<>(); - for (String group : groups) { - String[] kv = group.split(EQUAL_SEPARATOR); - if (kv.length != 2) { - throw new IllegalArgumentException("Invalid partition spec."); - } - String k = kv[0].trim(); - String v = unquoted(kv[1].trim()); - if (k.isEmpty() || v.isEmpty()) { - throw new IllegalArgumentException("Invalid partition spec."); - } - if (v.startsWith(DYNAMIC_KEY_PREFIX)) { - int val = schema.indexOf(v.substring(1)); - if (val != -1) { - index.add(val); - types.add(schema.getType(val)); - } - } - } - return new DefaultPartitionExtractor(spec, index.stream().mapToInt(i -> i).toArray(), types.toArray(new IType[0])); + int[] columns = new int[partitionColumns.size()]; + IType[] types = new IType[partitionColumns.size()]; + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < partitionColumns.size(); i++) { + String partitionColumn = partitionColumns.get(i).getName(); + int index = schema.indexOf(partitionColumn); + if (index < 0) { + throw new IllegalArgumentException( + "Partition column " + partitionColumn + " not found in schema"); + } + columns[i] = index; + types[i] = schema.getType(index); + sb.append(partitionColumn) + .append(EQUAL_SEPARATOR) + .append(DYNAMIC_KEY_PREFIX) + .append(partitionColumn) + .append(COMMA_SEPARATOR); } + return new DefaultPartitionExtractor(sb.substring(0, sb.length() - 1), columns, types); + } - public DefaultPartitionExtractor(String spec, int[] columns, IType[] types) { - this.columns = columns; - this.types = types; - if (spec == null) { - throw new IllegalArgumentException("Argument 'spec' cannot be null"); - } - String[] groups = spec.split(DEFAULT_SEPARATOR_PATTERN); - this.separator = spec.contains(COMMA_SEPARATOR) ? COMMA_SEPARATOR : SLASH_SEPARATOR; - this.keys = new String[groups.length]; - this.values = new String[groups.length]; - for (int i = 0; i < groups.length; i++) { - String[] kv = groups[i].split(EQUAL_SEPARATOR); - if (kv.length != 2) { - throw new IllegalArgumentException("Invalid partition spec."); - } - String k = kv[0].trim(); - String v = unquoted(kv[1].trim()); - if (k.isEmpty() || v.isEmpty()) { - throw new IllegalArgumentException("Invalid partition spec."); - } - this.keys[i] = k; - this.values[i] = v.startsWith(DYNAMIC_KEY_PREFIX) ? null : v; + /** + * Create a partition extractor. + * + * @param spec partition spec, like "dt=$dt,hh=$hh" + * @param schema the input schema + * @return the partition extractor + */ + public static PartitionExtractor create(String spec, StructType schema) { + if (spec == null || spec.isEmpty()) { + return row -> ""; + } + String[] groups = spec.split(DEFAULT_SEPARATOR_PATTERN); + List index = new ArrayList<>(); + List> types = new ArrayList<>(); + for (String group : groups) { + String[] kv = group.split(EQUAL_SEPARATOR); + if (kv.length != 2) { + throw new IllegalArgumentException("Invalid partition spec."); + } + String k = kv[0].trim(); + String v = unquoted(kv[1].trim()); + if (k.isEmpty() || v.isEmpty()) { + throw new IllegalArgumentException("Invalid partition spec."); + } + if (v.startsWith(DYNAMIC_KEY_PREFIX)) { + int val = schema.indexOf(v.substring(1)); + if (val != -1) { + index.add(val); + types.add(schema.getType(val)); } + } } + return new DefaultPartitionExtractor( + spec, index.stream().mapToInt(i -> i).toArray(), types.toArray(new IType[0])); + } - /** - * Unquote the string. - * @param s the string - * @return the unquoted string - */ - public static String unquoted(String s) { - return s.replaceAll(QUOTE_SEPARATOR_PATTERN, ""); + public DefaultPartitionExtractor(String spec, int[] columns, IType[] types) { + this.columns = columns; + this.types = types; + if (spec == null) { + throw new IllegalArgumentException("Argument 'spec' cannot be null"); } + String[] groups = spec.split(DEFAULT_SEPARATOR_PATTERN); + this.separator = spec.contains(COMMA_SEPARATOR) ? COMMA_SEPARATOR : SLASH_SEPARATOR; + this.keys = new String[groups.length]; + this.values = new String[groups.length]; + for (int i = 0; i < groups.length; i++) { + String[] kv = groups[i].split(EQUAL_SEPARATOR); + if (kv.length != 2) { + throw new IllegalArgumentException("Invalid partition spec."); + } + String k = kv[0].trim(); + String v = unquoted(kv[1].trim()); + if (k.isEmpty() || v.isEmpty()) { + throw new IllegalArgumentException("Invalid partition spec."); + } + this.keys[i] = k; + this.values[i] = v.startsWith(DYNAMIC_KEY_PREFIX) ? null : v; + } + } - @Override - public String extractPartition(Row row) { - StringBuilder sb = new StringBuilder(); - // dynamic field - int col = 0; - for (int i = 0; i < keys.length; i++) { - sb.append(keys[i]).append(EQUAL_SEPARATOR); - if (values[i] == null) { - sb.append(row.getField(columns[col], types[col])); - col++; - } else { - sb.append(values[i]); - } - if (i < keys.length - 1) { - sb.append(separator); - } - } - return sb.toString(); + /** + * Unquote the string. + * + * @param s the string + * @return the unquoted string + */ + public static String unquoted(String s) { + return s.replaceAll(QUOTE_SEPARATOR_PATTERN, ""); + } + + @Override + public String extractPartition(Row row) { + StringBuilder sb = new StringBuilder(); + // dynamic field + int col = 0; + for (int i = 0; i < keys.length; i++) { + sb.append(keys[i]).append(EQUAL_SEPARATOR); + if (values[i] == null) { + sb.append(row.getField(columns[col], types[col])); + col++; + } else { + sb.append(values[i]); + } + if (i < keys.length - 1) { + sb.append(separator); + } } + return sb.toString(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/OdpsConfigKeys.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/OdpsConfigKeys.java index 753edc09e..0c7182e82 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/OdpsConfigKeys.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/OdpsConfigKeys.java @@ -23,49 +23,47 @@ import org.apache.geaflow.common.config.ConfigKeys; public class OdpsConfigKeys { - public static final ConfigKey GEAFLOW_DSL_ODPS_PROJECT = ConfigKeys - .key("geaflow.dsl.odps.project") - .noDefaultValue() - .description("The odps project name."); + public static final ConfigKey GEAFLOW_DSL_ODPS_PROJECT = + ConfigKeys.key("geaflow.dsl.odps.project") + .noDefaultValue() + .description("The odps project name."); - public static final ConfigKey GEAFLOW_DSL_ODPS_TABLE = ConfigKeys - .key("geaflow.dsl.odps.table") - .noDefaultValue() - .description("The odps table name."); + public static final ConfigKey GEAFLOW_DSL_ODPS_TABLE = + ConfigKeys.key("geaflow.dsl.odps.table").noDefaultValue().description("The odps table name."); - public static final ConfigKey GEAFLOW_DSL_ODPS_ACCESS_ID = ConfigKeys - .key("geaflow.dsl.odps.accessid") - .noDefaultValue() - .description("The odps accessid."); + public static final ConfigKey GEAFLOW_DSL_ODPS_ACCESS_ID = + ConfigKeys.key("geaflow.dsl.odps.accessid") + .noDefaultValue() + .description("The odps accessid."); - public static final ConfigKey GEAFLOW_DSL_ODPS_ACCESS_KEY = ConfigKeys - .key("geaflow.dsl.odps.accesskey") - .noDefaultValue() - .description("The odps accesskey."); + public static final ConfigKey GEAFLOW_DSL_ODPS_ACCESS_KEY = + ConfigKeys.key("geaflow.dsl.odps.accesskey") + .noDefaultValue() + .description("The odps accesskey."); - public static final ConfigKey GEAFLOW_DSL_ODPS_ENDPOINT = ConfigKeys - .key("geaflow.dsl.odps.endpoint") - .noDefaultValue() - .description("The odps endpoint."); + public static final ConfigKey GEAFLOW_DSL_ODPS_ENDPOINT = + ConfigKeys.key("geaflow.dsl.odps.endpoint") + .noDefaultValue() + .description("The odps endpoint."); - @Deprecated - public static final ConfigKey GEAFLOW_DSL_ODPS_PARTITION_SPEC = ConfigKeys - .key("geaflow.dsl.odps.partition.spec") - .defaultValue("") - .description("The odps partition spec."); + @Deprecated + public static final ConfigKey GEAFLOW_DSL_ODPS_PARTITION_SPEC = + ConfigKeys.key("geaflow.dsl.odps.partition.spec") + .defaultValue("") + .description("The odps partition spec."); - public static final ConfigKey GEAFLOW_DSL_ODPS_SINK_BUFFER_SIZE = ConfigKeys - .key("geaflow.dsl.odps.sink.buffer.size") - .defaultValue(1000) - .description("The buffer size of odps sink buffer."); + public static final ConfigKey GEAFLOW_DSL_ODPS_SINK_BUFFER_SIZE = + ConfigKeys.key("geaflow.dsl.odps.sink.buffer.size") + .defaultValue(1000) + .description("The buffer size of odps sink buffer."); - public static final ConfigKey GEAFLOW_DSL_ODPS_SINK_FLUSH_INTERVAL_MS = ConfigKeys - .key("geaflow.dsl.odps.sink.flush.interval.ms") - .defaultValue(10000) - .description("The flush interval of odps sink buffer."); + public static final ConfigKey GEAFLOW_DSL_ODPS_SINK_FLUSH_INTERVAL_MS = + ConfigKeys.key("geaflow.dsl.odps.sink.flush.interval.ms") + .defaultValue(10000) + .description("The flush interval of odps sink buffer."); - public static final ConfigKey GEAFLOW_DSL_ODPS_TIMEOUT_SECONDS = ConfigKeys - .key("geaflow.dsl.odps.timeout.seconds") - .defaultValue(60) - .description("The timeout for odps connection, in seconds."); + public static final ConfigKey GEAFLOW_DSL_ODPS_TIMEOUT_SECONDS = + ConfigKeys.key("geaflow.dsl.odps.timeout.seconds") + .defaultValue(60) + .description("The timeout for odps connection, in seconds."); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/OdpsTableConnector.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/OdpsTableConnector.java index 603ab4f8d..f5c71896a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/OdpsTableConnector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/OdpsTableConnector.java @@ -27,20 +27,20 @@ public class OdpsTableConnector implements TableReadableConnector, TableWritableConnector { - public static final String TYPE = "ODPS"; + public static final String TYPE = "ODPS"; - @Override - public String getType() { - return TYPE; - } + @Override + public String getType() { + return TYPE; + } - @Override - public TableSource createSource(Configuration conf) { - return new OdpsTableSource(); - } + @Override + public TableSource createSource(Configuration conf) { + return new OdpsTableSource(); + } - @Override - public TableSink createSink(Configuration conf) { - return new OdpsTableSink(); - } + @Override + public TableSink createSink(Configuration conf) { + return new OdpsTableSink(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/OdpsTableSink.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/OdpsTableSink.java index e6c22bc70..f6f5470f4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/OdpsTableSink.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/OdpsTableSink.java @@ -19,18 +19,6 @@ package org.apache.geaflow.dsl.connector.odps; -import com.aliyun.odps.Column; -import com.aliyun.odps.Odps; -import com.aliyun.odps.PartitionSpec; -import com.aliyun.odps.Table; -import com.aliyun.odps.TableSchema; -import com.aliyun.odps.account.Account; -import com.aliyun.odps.account.AliyunAccount; -import com.aliyun.odps.data.ArrayRecord; -import com.aliyun.odps.tunnel.TableTunnel; -import com.aliyun.odps.tunnel.TunnelException; -import com.google.common.base.Preconditions; -import com.google.common.base.Strings; import java.io.IOException; import java.util.HashMap; import java.util.List; @@ -40,7 +28,9 @@ import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; + import javax.annotation.Nullable; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.common.data.Row; @@ -50,165 +40,188 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.aliyun.odps.Column; +import com.aliyun.odps.Odps; +import com.aliyun.odps.PartitionSpec; +import com.aliyun.odps.Table; +import com.aliyun.odps.TableSchema; +import com.aliyun.odps.account.Account; +import com.aliyun.odps.account.AliyunAccount; +import com.aliyun.odps.data.ArrayRecord; +import com.aliyun.odps.tunnel.TableTunnel; +import com.aliyun.odps.tunnel.TunnelException; +import com.google.common.base.Preconditions; +import com.google.common.base.Strings; + public class OdpsTableSink implements TableSink { - private static final Logger LOGGER = LoggerFactory.getLogger(OdpsTableSink.class); - - private int bufferSize = 1000; - private int flushIntervalMs = Integer.MAX_VALUE; - - private int timeoutSeconds = 60; - private String endPoint; - private String project; - private String tableName; - private String accessKey; - private String accessId; - private StructType schema; - private String partitionSpec; - - private transient TableTunnel tunnel; - private transient Column[] recordColumns; - private transient int[] columnIndex; - private transient PartitionExtractor partitionExtractor; - private transient Map partitionWriters; - - private final ExecutorService executor = Executors.newSingleThreadExecutor(); - - @Override - public void init(Configuration tableConf, StructType schema) { - LOGGER.info("open with config: {}, \n schema: {}", tableConf, schema); - this.schema = Objects.requireNonNull(schema); - this.columnIndex = new int[schema.size()]; - for (int i = 0; i < this.schema.size(); i++) { - String columnName = this.schema.getField(i).getName(); - columnIndex[i] = this.schema.indexOf(columnName); - } - this.endPoint = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ENDPOINT); - this.project = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_PROJECT); - this.tableName = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_TABLE); - this.accessKey = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ACCESS_KEY); - this.accessId = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ACCESS_ID); - this.partitionSpec = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_PARTITION_SPEC); - int bufferSize = tableConf.getInteger(OdpsConfigKeys.GEAFLOW_DSL_ODPS_SINK_BUFFER_SIZE); - if (bufferSize > 0) { - this.bufferSize = bufferSize; - } - int flushIntervalMs = tableConf.getInteger(OdpsConfigKeys.GEAFLOW_DSL_ODPS_SINK_FLUSH_INTERVAL_MS); - if (flushIntervalMs > 0) { - this.flushIntervalMs = flushIntervalMs; - } - int timeoutSeconds = tableConf.getInteger(OdpsConfigKeys.GEAFLOW_DSL_ODPS_TIMEOUT_SECONDS); - if (timeoutSeconds > 0) { - this.timeoutSeconds = timeoutSeconds; - } - checkArguments(); - - LOGGER.info("init odps table sink, endPoint : {}, project : {}, tableName : {}", - endPoint, project, tableName); + private static final Logger LOGGER = LoggerFactory.getLogger(OdpsTableSink.class); + + private int bufferSize = 1000; + private int flushIntervalMs = Integer.MAX_VALUE; + + private int timeoutSeconds = 60; + private String endPoint; + private String project; + private String tableName; + private String accessKey; + private String accessId; + private StructType schema; + private String partitionSpec; + + private transient TableTunnel tunnel; + private transient Column[] recordColumns; + private transient int[] columnIndex; + private transient PartitionExtractor partitionExtractor; + private transient Map partitionWriters; + + private final ExecutorService executor = Executors.newSingleThreadExecutor(); + + @Override + public void init(Configuration tableConf, StructType schema) { + LOGGER.info("open with config: {}, \n schema: {}", tableConf, schema); + this.schema = Objects.requireNonNull(schema); + this.columnIndex = new int[schema.size()]; + for (int i = 0; i < this.schema.size(); i++) { + String columnName = this.schema.getField(i).getName(); + columnIndex[i] = this.schema.indexOf(columnName); } - - @Override - public void open(RuntimeContext context) { - Account account = new AliyunAccount(accessId, accessKey); - Odps odps = new Odps(account); - odps.setEndpoint(endPoint); - odps.setDefaultProject(project); - this.tunnel = new TableTunnel(odps); - Table table = odps.tables().get(tableName); - TableSchema tableSchema = table.getSchema(); - this.recordColumns = tableSchema.getColumns().toArray(new Column[0]); - this.columnIndex = new int[recordColumns.length]; - for (int i = 0; i < this.recordColumns.length; i++) { - String columnName = this.recordColumns[i].getName(); - columnIndex[i] = this.schema.indexOf(columnName); - } - if (this.partitionSpec != null && !this.partitionSpec.isEmpty()) { - this.partitionExtractor = DefaultPartitionExtractor.create(this.partitionSpec, schema); - } else { - List partitionColumns = tableSchema.getPartitionColumns(); - this.partitionExtractor = DefaultPartitionExtractor.create(partitionColumns, schema); - } - this.partitionWriters = new HashMap<>(); + this.endPoint = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ENDPOINT); + this.project = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_PROJECT); + this.tableName = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_TABLE); + this.accessKey = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ACCESS_KEY); + this.accessId = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ACCESS_ID); + this.partitionSpec = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_PARTITION_SPEC); + int bufferSize = tableConf.getInteger(OdpsConfigKeys.GEAFLOW_DSL_ODPS_SINK_BUFFER_SIZE); + if (bufferSize > 0) { + this.bufferSize = bufferSize; } - - @Override - public void write(Row row) throws IOException { - Object[] values = new Object[columnIndex.length]; - for (int i = 0; i < columnIndex.length; i++) { - if (columnIndex[i] >= 0) { - values[i] = row.getField(columnIndex[i], schema.getType(columnIndex[i])); - } else { - values[i] = null; - } - } - PartitionWriter writer = createOrGetWriter(partitionExtractor.extractPartition(row)); - writer.write(new ArrayRecord(recordColumns, values)); + int flushIntervalMs = + tableConf.getInteger(OdpsConfigKeys.GEAFLOW_DSL_ODPS_SINK_FLUSH_INTERVAL_MS); + if (flushIntervalMs > 0) { + this.flushIntervalMs = flushIntervalMs; } - - @Override - public void finish() throws IOException { - flush(); + int timeoutSeconds = tableConf.getInteger(OdpsConfigKeys.GEAFLOW_DSL_ODPS_TIMEOUT_SECONDS); + if (timeoutSeconds > 0) { + this.timeoutSeconds = timeoutSeconds; } - - @Override - public void close() { - LOGGER.info("close."); - flush(); + checkArguments(); + + LOGGER.info( + "init odps table sink, endPoint : {}, project : {}, tableName : {}", + endPoint, + project, + tableName); + } + + @Override + public void open(RuntimeContext context) { + Account account = new AliyunAccount(accessId, accessKey); + Odps odps = new Odps(account); + odps.setEndpoint(endPoint); + odps.setDefaultProject(project); + this.tunnel = new TableTunnel(odps); + Table table = odps.tables().get(tableName); + TableSchema tableSchema = table.getSchema(); + this.recordColumns = tableSchema.getColumns().toArray(new Column[0]); + this.columnIndex = new int[recordColumns.length]; + for (int i = 0; i < this.recordColumns.length; i++) { + String columnName = this.recordColumns[i].getName(); + columnIndex[i] = this.schema.indexOf(columnName); } - - private void flush() { - try { - for (PartitionWriter writer : partitionWriters.values()) { - writer.flush(); - } - } catch (IOException e) { - throw new RuntimeException("Flush data error.", e); - } + if (this.partitionSpec != null && !this.partitionSpec.isEmpty()) { + this.partitionExtractor = DefaultPartitionExtractor.create(this.partitionSpec, schema); + } else { + List partitionColumns = tableSchema.getPartitionColumns(); + this.partitionExtractor = DefaultPartitionExtractor.create(partitionColumns, schema); } - - /** - * Create or get writer. - * @param partition the partition - * @return a writer - */ - private PartitionWriter createOrGetWriter(String partition) { - PartitionWriter partitionWriter = partitionWriters.get(partition); - if (partitionWriter == null) { - TableTunnel.StreamUploadSession session = createUploadSession(partition); - partitionWriter = new PartitionWriter(session, bufferSize, flushIntervalMs); - partitionWriters.put(partition, partitionWriter); - } - return partitionWriter; + this.partitionWriters = new HashMap<>(); + } + + @Override + public void write(Row row) throws IOException { + Object[] values = new Object[columnIndex.length]; + for (int i = 0; i < columnIndex.length; i++) { + if (columnIndex[i] >= 0) { + values[i] = row.getField(columnIndex[i], schema.getType(columnIndex[i])); + } else { + values[i] = null; + } } - - /** - * Create an upload session. - * @param partition the partition - * @return an upload session - */ - private TableTunnel.StreamUploadSession createUploadSession(@Nullable String partition) { - Future future = executor.submit(() -> { - try { + PartitionWriter writer = createOrGetWriter(partitionExtractor.extractPartition(row)); + writer.write(new ArrayRecord(recordColumns, values)); + } + + @Override + public void finish() throws IOException { + flush(); + } + + @Override + public void close() { + LOGGER.info("close."); + flush(); + } + + private void flush() { + try { + for (PartitionWriter writer : partitionWriters.values()) { + writer.flush(); + } + } catch (IOException e) { + throw new RuntimeException("Flush data error.", e); + } + } + + /** + * Create or get writer. + * + * @param partition the partition + * @return a writer + */ + private PartitionWriter createOrGetWriter(String partition) { + PartitionWriter partitionWriter = partitionWriters.get(partition); + if (partitionWriter == null) { + TableTunnel.StreamUploadSession session = createUploadSession(partition); + partitionWriter = new PartitionWriter(session, bufferSize, flushIntervalMs); + partitionWriters.put(partition, partitionWriter); + } + return partitionWriter; + } + + /** + * Create an upload session. + * + * @param partition the partition + * @return an upload session + */ + private TableTunnel.StreamUploadSession createUploadSession(@Nullable String partition) { + Future future = + executor.submit( + () -> { + try { if (partition == null || partition.isEmpty()) { - return tunnel.createStreamUploadSession(project, tableName); + return tunnel.createStreamUploadSession(project, tableName); } - return tunnel.createStreamUploadSession(project, tableName, new PartitionSpec(partition)); - } catch (TunnelException e) { + return tunnel.createStreamUploadSession( + project, tableName, new PartitionSpec(partition)); + } catch (TunnelException e) { throw new GeaFlowDSLException("Cannot get odps session.", e); - } - }); - try { - return future.get(this.timeoutSeconds, TimeUnit.SECONDS); - } catch (Exception e) { - throw new GeaFlowDSLException("Create stream upload session with endpoint " + this.endPoint + " failed", e); - } - } - - private void checkArguments() { - Preconditions.checkArgument(!Strings.isNullOrEmpty(endPoint), "endPoint is null"); - Preconditions.checkArgument(!Strings.isNullOrEmpty(project), "project is null"); - Preconditions.checkArgument(!Strings.isNullOrEmpty(tableName), "tableName is null"); - Preconditions.checkArgument(!Strings.isNullOrEmpty(accessId), "accessId is null"); - Preconditions.checkArgument(!Strings.isNullOrEmpty(accessKey), "accessKey is null"); + } + }); + try { + return future.get(this.timeoutSeconds, TimeUnit.SECONDS); + } catch (Exception e) { + throw new GeaFlowDSLException( + "Create stream upload session with endpoint " + this.endPoint + " failed", e); } + } + + private void checkArguments() { + Preconditions.checkArgument(!Strings.isNullOrEmpty(endPoint), "endPoint is null"); + Preconditions.checkArgument(!Strings.isNullOrEmpty(project), "project is null"); + Preconditions.checkArgument(!Strings.isNullOrEmpty(tableName), "tableName is null"); + Preconditions.checkArgument(!Strings.isNullOrEmpty(accessId), "accessId is null"); + Preconditions.checkArgument(!Strings.isNullOrEmpty(accessKey), "accessKey is null"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/OdpsTableSource.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/OdpsTableSource.java index a3f9cecce..79c7f064c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/OdpsTableSource.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/OdpsTableSource.java @@ -19,18 +19,6 @@ package org.apache.geaflow.dsl.connector.odps; -import com.aliyun.odps.Odps; -import com.aliyun.odps.PartitionSpec; -import com.aliyun.odps.Table; -import com.aliyun.odps.account.Account; -import com.aliyun.odps.account.AliyunAccount; -import com.aliyun.odps.data.Record; -import com.aliyun.odps.data.RecordReader; -import com.aliyun.odps.tunnel.TableTunnel; -import com.aliyun.odps.tunnel.TableTunnel.DownloadSession; -import com.aliyun.odps.tunnel.TunnelException; -import com.google.common.base.Preconditions; -import com.google.common.base.Strings; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; @@ -43,6 +31,7 @@ import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.window.WindowType; import org.apache.geaflow.common.config.Configuration; @@ -68,281 +57,312 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class OdpsTableSource implements TableSource, EnablePartitionPushDown { - - private static final Logger LOGGER = LoggerFactory.getLogger(OdpsTableSource.class); - - private int timeoutSeconds = 60; - private String endPoint; - private String project; - private String tableName; - private String accessKey; - private String accessId; - - private String shardNamePrefix; - private PartitionFilter partitionFilter; - private StructType partitionSchema; - private StructType schema; - private Map columnName2Index; +import com.aliyun.odps.Odps; +import com.aliyun.odps.PartitionSpec; +import com.aliyun.odps.Table; +import com.aliyun.odps.account.Account; +import com.aliyun.odps.account.AliyunAccount; +import com.aliyun.odps.data.Record; +import com.aliyun.odps.data.RecordReader; +import com.aliyun.odps.tunnel.TableTunnel; +import com.aliyun.odps.tunnel.TableTunnel.DownloadSession; +import com.aliyun.odps.tunnel.TunnelException; +import com.google.common.base.Preconditions; +import com.google.common.base.Strings; - private transient Odps odps; - private transient Table table; - private transient TableTunnel tunnel; - private transient Map partition2DownloadSession; +public class OdpsTableSource implements TableSource, EnablePartitionPushDown { - @Override - public void init(Configuration tableConf, TableSchema tableSchema) { - this.endPoint = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ENDPOINT); - this.project = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_PROJECT); - this.tableName = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_TABLE); - this.accessKey = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ACCESS_KEY); - this.accessId = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ACCESS_ID); - this.shardNamePrefix = project + "-" + tableName + "-"; - this.timeoutSeconds = tableConf.getInteger(OdpsConfigKeys.GEAFLOW_DSL_ODPS_TIMEOUT_SECONDS); - - this.partitionSchema = tableSchema.getPartitionSchema(); - this.schema = Objects.requireNonNull(tableSchema); - columnName2Index = new HashMap<>(); - for (int i = 0; i < schema.size(); i++) { - columnName2Index.put(schema.getField(i).getName(), i); - } - - checkArguments(); - - LOGGER.info("init with config: {}, \n schema: {}\n" - + "endPoint : {}, project : {}, tableName : {}", - tableConf, tableSchema, endPoint, project, tableName); + private static final Logger LOGGER = LoggerFactory.getLogger(OdpsTableSource.class); + + private int timeoutSeconds = 60; + private String endPoint; + private String project; + private String tableName; + private String accessKey; + private String accessId; + + private String shardNamePrefix; + private PartitionFilter partitionFilter; + private StructType partitionSchema; + private StructType schema; + private Map columnName2Index; + + private transient Odps odps; + private transient Table table; + private transient TableTunnel tunnel; + private transient Map partition2DownloadSession; + + @Override + public void init(Configuration tableConf, TableSchema tableSchema) { + this.endPoint = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ENDPOINT); + this.project = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_PROJECT); + this.tableName = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_TABLE); + this.accessKey = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ACCESS_KEY); + this.accessId = tableConf.getString(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ACCESS_ID); + this.shardNamePrefix = project + "-" + tableName + "-"; + this.timeoutSeconds = tableConf.getInteger(OdpsConfigKeys.GEAFLOW_DSL_ODPS_TIMEOUT_SECONDS); + + this.partitionSchema = tableSchema.getPartitionSchema(); + this.schema = Objects.requireNonNull(tableSchema); + columnName2Index = new HashMap<>(); + for (int i = 0; i < schema.size(); i++) { + columnName2Index.put(schema.getField(i).getName(), i); } - @Override - public void open(RuntimeContext context) { - Account account = new AliyunAccount(accessId, accessKey); - this.odps = new Odps(account); - odps.setEndpoint(endPoint); - odps.setDefaultProject(project); - this.tunnel = new TableTunnel(odps); - this.partition2DownloadSession = new HashMap<>(); - this.table = odps.tables().get(project, tableName); - for (TableField field : partitionSchema.getFields()) { - String partitionKey = field.getName(); - if (this.table.getSchema().getPartitionColumn(partitionKey) == null) { - throw new GeaFlowDSLException("Partition key: {} not exists in odps table: {}.", - partitionKey, tableName); - } - if (!OdpsConnectorUtils.typeEquals( - this.table.getSchema().getPartitionColumn(partitionKey).getTypeInfo().getOdpsType(), - field.getType())) { - throw new GeaFlowDSLException("Partition key: {} is {} in Odps but not {}.", - partitionKey, - this.table.getSchema().getPartitionColumn(partitionKey).getTypeInfo().getTypeName(), - field.getType().getName()); - } - } - for (TableField field : ((TableSchema) schema).getDataSchema().getFields()) { - String fieldName = field.getName(); - if (this.table.getSchema().getColumn(fieldName) != null && !OdpsConnectorUtils.typeEquals( - this.table.getSchema().getColumn(fieldName).getTypeInfo().getOdpsType(), - field.getType())) { - throw new GeaFlowDSLException("Column: {} is {} in Odps but not {}.", - fieldName, - this.table.getSchema().getPartitionColumn(fieldName).getTypeInfo().getTypeName(), - field.getType().getName()); - } - } + checkArguments(); + + LOGGER.info( + "init with config: {}, \n schema: {}\n" + "endPoint : {}, project : {}, tableName : {}", + tableConf, + tableSchema, + endPoint, + project, + tableName); + } + + @Override + public void open(RuntimeContext context) { + Account account = new AliyunAccount(accessId, accessKey); + this.odps = new Odps(account); + odps.setEndpoint(endPoint); + odps.setDefaultProject(project); + this.tunnel = new TableTunnel(odps); + this.partition2DownloadSession = new HashMap<>(); + this.table = odps.tables().get(project, tableName); + for (TableField field : partitionSchema.getFields()) { + String partitionKey = field.getName(); + if (this.table.getSchema().getPartitionColumn(partitionKey) == null) { + throw new GeaFlowDSLException( + "Partition key: {} not exists in odps table: {}.", partitionKey, tableName); + } + if (!OdpsConnectorUtils.typeEquals( + this.table.getSchema().getPartitionColumn(partitionKey).getTypeInfo().getOdpsType(), + field.getType())) { + throw new GeaFlowDSLException( + "Partition key: {} is {} in Odps but not {}.", + partitionKey, + this.table.getSchema().getPartitionColumn(partitionKey).getTypeInfo().getTypeName(), + field.getType().getName()); + } } - - - private Row getPartitionRow(PartitionSpec spec) { - Object[] values = new Object[partitionSchema.getFields().size()]; - int i = 0; - for (TableField field : partitionSchema.getFields()) { - String fieldName = field.getName(); - if (spec.keys().contains(fieldName)) { - values[i] = TypeCastUtil.cast(spec.get(fieldName), field.getType()); - } else { - values[i] = null; - } - i++; - } - return ObjectRow.create(values); + for (TableField field : ((TableSchema) schema).getDataSchema().getFields()) { + String fieldName = field.getName(); + if (this.table.getSchema().getColumn(fieldName) != null + && !OdpsConnectorUtils.typeEquals( + this.table.getSchema().getColumn(fieldName).getTypeInfo().getOdpsType(), + field.getType())) { + throw new GeaFlowDSLException( + "Column: {} is {} in Odps but not {}.", + fieldName, + this.table.getSchema().getPartitionColumn(fieldName).getTypeInfo().getTypeName(), + field.getType().getName()); + } } - - @Override - public List listPartitions() { - ExecutorService executor = Executors.newSingleThreadExecutor(); - Future> future = executor.submit(() -> { - List odpsPartitions = new ArrayList<>(); - List partitionSpecs; - List allPartitions = table.getPartitionSpecs(); - if (partitionFilter == null || partitionSchema == null) { + } + + private Row getPartitionRow(PartitionSpec spec) { + Object[] values = new Object[partitionSchema.getFields().size()]; + int i = 0; + for (TableField field : partitionSchema.getFields()) { + String fieldName = field.getName(); + if (spec.keys().contains(fieldName)) { + values[i] = TypeCastUtil.cast(spec.get(fieldName), field.getType()); + } else { + values[i] = null; + } + i++; + } + return ObjectRow.create(values); + } + + @Override + public List listPartitions() { + ExecutorService executor = Executors.newSingleThreadExecutor(); + Future> future = + executor.submit( + () -> { + List odpsPartitions = new ArrayList<>(); + List partitionSpecs; + List allPartitions = table.getPartitionSpecs(); + if (partitionFilter == null || partitionSchema == null) { partitionSpecs = allPartitions; - } else { - partitionSpecs = allPartitions.stream().filter(p -> - partitionFilter.apply(getPartitionRow(p))).collect(Collectors.toList()); - } - odpsPartitions.addAll(partitionSpecs.stream().map(spec -> table.getPartition(spec)).collect( - Collectors.toList())); - return odpsPartitions.stream().map(partition -> new OdpsShardPartition(shardNamePrefix, - partition.getPartitionSpec())).collect(Collectors.toList()); - }); - List odpsPartitions; - try { - odpsPartitions = future.get(this.timeoutSeconds, TimeUnit.SECONDS); - } catch (Exception e) { - throw new GeaFlowDSLException("Cannot list partitions from ODPS, endPoint: " + this.endPoint, e); - } - return odpsPartitions; + } else { + partitionSpecs = + allPartitions.stream() + .filter(p -> partitionFilter.apply(getPartitionRow(p))) + .collect(Collectors.toList()); + } + odpsPartitions.addAll( + partitionSpecs.stream() + .map(spec -> table.getPartition(spec)) + .collect(Collectors.toList())); + return odpsPartitions.stream() + .map( + partition -> + new OdpsShardPartition(shardNamePrefix, partition.getPartitionSpec())) + .collect(Collectors.toList()); + }); + List odpsPartitions; + try { + odpsPartitions = future.get(this.timeoutSeconds, TimeUnit.SECONDS); + } catch (Exception e) { + throw new GeaFlowDSLException( + "Cannot list partitions from ODPS, endPoint: " + this.endPoint, e); } - - @Override - public List listPartitions(int parallelism) { - return listPartitions(); + return odpsPartitions; + } + + @Override + public List listPartitions(int parallelism) { + return listPartitions(); + } + + @Override + public TableDeserializer getDeserializer(Configuration conf) { + return (TableDeserializer) new OdpsRecordDeserializer(); + } + + @Override + public FetchData fetch( + Partition partition, Optional startOffset, FetchWindow windowInfo) + throws IOException { + long desireWindowSize = -1; + switch (windowInfo.getType()) { + case ALL_WINDOW: + desireWindowSize = Long.MAX_VALUE; + break; + case SIZE_TUMBLING_WINDOW: + desireWindowSize = windowInfo.windowSize(); + break; + default: + throw new GeaFlowDSLException("Not support window type:{}", windowInfo.getType()); + } + assert partition instanceof OdpsShardPartition; + OdpsShardPartition odpsPartition = (OdpsShardPartition) partition; + + // Get the ODPS download session, and if it does not exist in the map, open one. + long start = startOffset.isPresent() ? (startOffset.get()).getOffset() : 0L; + DownloadSession downloadSession = partition2DownloadSession.get(odpsPartition); + if (downloadSession == null) { + try { + downloadSession = + tunnel.createDownloadSession( + project, tableName, odpsPartition.getSinglePartitionSpec()); + partition2DownloadSession.put(odpsPartition, downloadSession); + } catch (TunnelException e) { + throw new GeaFlowDSLException("Cannot get Odps session.", e); + } + } + long sessionRecordCount = downloadSession.getRecordCount(); + long remainingCount = sessionRecordCount - start < 0 ? 0 : sessionRecordCount - start; + long count = Math.min(desireWindowSize, remainingCount); + RecordReader reader; + try { + reader = downloadSession.openRecordReader(start, count); + } catch (TunnelException e) { + throw new GeaFlowDSLException("Cannot get Odps session.", e); } - @Override - public TableDeserializer getDeserializer(Configuration conf) { - return (TableDeserializer) new OdpsRecordDeserializer(); + OdpsOffset nextOffset = new OdpsOffset(start + count); + if (windowInfo.getType() == WindowType.ALL_WINDOW) { + return (FetchData) + FetchData.createBatchFetch( + new OdpsBatchIterator(reader, count, odpsPartition.getSinglePartitionSpec()), + nextOffset); + } else { + List dataList = new ArrayList<>(); + for (int i = 0; i < count; i++) { + Record record = reader.read(); + OdpsRecordWithPartitionSpec recordWithPartitionSpec = + new OdpsRecordWithPartitionSpec(record, odpsPartition.getSinglePartitionSpec()); + dataList.add(recordWithPartitionSpec); + } + reader.close(); + boolean isFinish = desireWindowSize >= remainingCount; + return (FetchData) FetchData.createStreamFetch(dataList, nextOffset, isFinish); + } + } + + @Override + public void close() { + LOGGER.info("close."); + } + + private void checkArguments() { + Preconditions.checkArgument(!Strings.isNullOrEmpty(endPoint), "endPoint is null"); + Preconditions.checkArgument(!Strings.isNullOrEmpty(project), "project is null"); + Preconditions.checkArgument(!Strings.isNullOrEmpty(tableName), "tableName is null"); + Preconditions.checkArgument(!Strings.isNullOrEmpty(endPoint), "accessKey is null"); + Preconditions.checkArgument(!Strings.isNullOrEmpty(project), "accessId is null"); + } + + @Override + public void setPartitionFilter(PartitionFilter partitionFilter) { + this.partitionFilter = partitionFilter; + } + + public static class OdpsShardPartition implements Partition { + + private final String prefix; + private final PartitionSpec singlePartitionSpec; + + public OdpsShardPartition(String prefix, PartitionSpec singlePartitionSpec) { + this.prefix = prefix; + this.singlePartitionSpec = singlePartitionSpec; } @Override - public FetchData fetch(Partition partition, Optional startOffset, - FetchWindow windowInfo) throws IOException { - long desireWindowSize = -1; - switch (windowInfo.getType()) { - case ALL_WINDOW: - desireWindowSize = Long.MAX_VALUE; - break; - case SIZE_TUMBLING_WINDOW: - desireWindowSize = windowInfo.windowSize(); - break; - default: - throw new GeaFlowDSLException("Not support window type:{}", windowInfo.getType()); - } - assert partition instanceof OdpsShardPartition; - OdpsShardPartition odpsPartition = (OdpsShardPartition) partition; - - //Get the ODPS download session, and if it does not exist in the map, open one. - long start = startOffset.isPresent() ? (startOffset.get()).getOffset() : 0L; - DownloadSession downloadSession = partition2DownloadSession.get(odpsPartition); - if (downloadSession == null) { - try { - downloadSession = tunnel.createDownloadSession(project, tableName, - odpsPartition.getSinglePartitionSpec()); - partition2DownloadSession.put(odpsPartition, downloadSession); - } catch (TunnelException e) { - throw new GeaFlowDSLException("Cannot get Odps session.", e); - } - } - long sessionRecordCount = downloadSession.getRecordCount(); - long remainingCount = sessionRecordCount - start < 0 ? 0 : sessionRecordCount - start; - long count = Math.min(desireWindowSize, remainingCount); - RecordReader reader; - try { - reader = downloadSession.openRecordReader(start, count); - } catch (TunnelException e) { - throw new GeaFlowDSLException("Cannot get Odps session.", e); - } - - OdpsOffset nextOffset = new OdpsOffset(start + count); - if (windowInfo.getType() == WindowType.ALL_WINDOW) { - return (FetchData) FetchData.createBatchFetch( - new OdpsBatchIterator(reader, count, odpsPartition.getSinglePartitionSpec()), nextOffset); - } else { - List dataList = new ArrayList<>(); - for (int i = 0; i < count; i++) { - Record record = reader.read(); - OdpsRecordWithPartitionSpec recordWithPartitionSpec = new OdpsRecordWithPartitionSpec(record, odpsPartition.getSinglePartitionSpec()); - dataList.add(recordWithPartitionSpec); - } - reader.close(); - boolean isFinish = desireWindowSize >= remainingCount; - return (FetchData) FetchData.createStreamFetch(dataList, nextOffset, isFinish); - } + public String getName() { + return prefix + singlePartitionSpec; } - @Override - public void close() { - LOGGER.info("close."); - } + public void setIndex(int index, int parallel) {} - private void checkArguments() { - Preconditions.checkArgument(!Strings.isNullOrEmpty(endPoint), "endPoint is null"); - Preconditions.checkArgument(!Strings.isNullOrEmpty(project), "project is null"); - Preconditions.checkArgument(!Strings.isNullOrEmpty(tableName), "tableName is null"); - Preconditions.checkArgument(!Strings.isNullOrEmpty(endPoint), "accessKey is null"); - Preconditions.checkArgument(!Strings.isNullOrEmpty(project), "accessId is null"); + @Override + public int hashCode() { + return Objects.hash(prefix, singlePartitionSpec); } @Override - public void setPartitionFilter(PartitionFilter partitionFilter) { - this.partitionFilter = partitionFilter; + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof OdpsShardPartition)) { + return false; + } + OdpsShardPartition that = (OdpsShardPartition) o; + return Objects.equals(prefix, that.prefix) + && Objects.equals(singlePartitionSpec.toString(), that.singlePartitionSpec.toString()); } - public static class OdpsShardPartition implements Partition { - - private final String prefix; - private final PartitionSpec singlePartitionSpec; - - public OdpsShardPartition(String prefix, PartitionSpec singlePartitionSpec) { - this.prefix = prefix; - this.singlePartitionSpec = singlePartitionSpec; - } - - @Override - public String getName() { - return prefix + singlePartitionSpec; - } - - @Override - public void setIndex(int index, int parallel) { - } - - @Override - public int hashCode() { - return Objects.hash(prefix, singlePartitionSpec); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof OdpsShardPartition)) { - return false; - } - OdpsShardPartition that = (OdpsShardPartition) o; - return Objects.equals(prefix, that.prefix) && Objects.equals( - singlePartitionSpec.toString(), that.singlePartitionSpec.toString()); - } - - public PartitionSpec getSinglePartitionSpec() { - return singlePartitionSpec; - } + public PartitionSpec getSinglePartitionSpec() { + return singlePartitionSpec; } + } - public static class OdpsOffset implements Offset { - - private final long offset; + public static class OdpsOffset implements Offset { - public OdpsOffset(long offset) { - this.offset = offset; - } + private final long offset; - @Override - public String humanReadable() { - return String.valueOf(offset); - } + public OdpsOffset(long offset) { + this.offset = offset; + } - @Override - public boolean isTimestamp() { - return false; - } + @Override + public String humanReadable() { + return String.valueOf(offset); + } - @Override - public long getOffset() { - return offset; - } + @Override + public boolean isTimestamp() { + return false; } + @Override + public long getOffset() { + return offset; + } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/PartitionExtractor.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/PartitionExtractor.java index 9552e448b..f1c0ba907 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/PartitionExtractor.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/PartitionExtractor.java @@ -21,15 +21,14 @@ import org.apache.geaflow.dsl.common.data.Row; -/** - * ODPS partition extractor. - */ +/** ODPS partition extractor. */ public interface PartitionExtractor { - /** - * extract partition from row. - * @param row the input row - * @return the partition - */ - String extractPartition(Row row); + /** + * extract partition from row. + * + * @param row the input row + * @return the partition + */ + String extractPartition(Row row); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/PartitionWriter.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/PartitionWriter.java index 2d78cf742..2c6233247 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/PartitionWriter.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/PartitionWriter.java @@ -19,48 +19,52 @@ package org.apache.geaflow.dsl.connector.odps; +import java.io.IOException; + import com.aliyun.odps.data.Record; import com.aliyun.odps.tunnel.TableTunnel; -import java.io.IOException; public class PartitionWriter { - private final TableTunnel.StreamRecordPack recordPack; - private final int batchSize; - private final int flushIntervalMs; - private long lastFlushTime; + private final TableTunnel.StreamRecordPack recordPack; + private final int batchSize; + private final int flushIntervalMs; + private long lastFlushTime; - public PartitionWriter(TableTunnel.StreamUploadSession uploadSession, int batchSize, int flushIntervalMs) { - try { - this.recordPack = uploadSession.newRecordPack(); - } catch (IOException e) { - throw new RuntimeException(e); - } - this.batchSize = batchSize; - this.flushIntervalMs = flushIntervalMs; - this.lastFlushTime = System.currentTimeMillis(); + public PartitionWriter( + TableTunnel.StreamUploadSession uploadSession, int batchSize, int flushIntervalMs) { + try { + this.recordPack = uploadSession.newRecordPack(); + } catch (IOException e) { + throw new RuntimeException(e); } + this.batchSize = batchSize; + this.flushIntervalMs = flushIntervalMs; + this.lastFlushTime = System.currentTimeMillis(); + } - /** - * Write a record to the stream, if the batch size is reached - * or the flush interval is reached, flush the stream. - * @param record The record to write. - * @throws IOException If an I/O error occurs. - */ - public void write(Record record) throws IOException { - recordPack.append(record); - if (recordPack.getRecordCount() >= batchSize - || System.currentTimeMillis() - lastFlushTime > flushIntervalMs) { - recordPack.flush(); - lastFlushTime = System.currentTimeMillis(); - } + /** + * Write a record to the stream, if the batch size is reached or the flush interval is reached, + * flush the stream. + * + * @param record The record to write. + * @throws IOException If an I/O error occurs. + */ + public void write(Record record) throws IOException { + recordPack.append(record); + if (recordPack.getRecordCount() >= batchSize + || System.currentTimeMillis() - lastFlushTime > flushIntervalMs) { + recordPack.flush(); + lastFlushTime = System.currentTimeMillis(); } + } - /** - * Flush the stream. - * @throws IOException If an I/O error occurs. - */ - public void flush() throws IOException { - recordPack.flush(); - } + /** + * Flush the stream. + * + * @throws IOException If an I/O error occurs. + */ + public void flush() throws IOException { + recordPack.flush(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/utils/OdpsBatchIterator.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/utils/OdpsBatchIterator.java index 65e81f5b9..8c65870bd 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/utils/OdpsBatchIterator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/utils/OdpsBatchIterator.java @@ -19,58 +19,60 @@ package org.apache.geaflow.dsl.connector.odps.utils; -import com.aliyun.odps.PartitionSpec; -import com.aliyun.odps.data.Record; -import com.aliyun.odps.data.RecordReader; import java.io.IOException; import java.util.Iterator; import java.util.Objects; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; +import com.aliyun.odps.PartitionSpec; +import com.aliyun.odps.data.Record; +import com.aliyun.odps.data.RecordReader; + public class OdpsBatchIterator implements Iterator { - private RecordReader reader; - private long count; - boolean finished = false; - private final PartitionSpec spec; + private RecordReader reader; + private long count; + boolean finished = false; + private final PartitionSpec spec; - public OdpsBatchIterator(RecordReader reader, long count, PartitionSpec spec) { - this.reader = Objects.requireNonNull(reader); - this.spec = Objects.requireNonNull(spec); - this.count = count; - assert count >= 0; - } + public OdpsBatchIterator(RecordReader reader, long count, PartitionSpec spec) { + this.reader = Objects.requireNonNull(reader); + this.spec = Objects.requireNonNull(spec); + this.count = count; + assert count >= 0; + } - @Override - public boolean hasNext() { - if (!finished && count > 0) { - return true; - } else { - try { - if (reader != null) { - reader.close(); - reader = null; - } - } catch (IOException e) { - throw new GeaflowRuntimeException("Error when close odps reader."); - } - return false; + @Override + public boolean hasNext() { + if (!finished && count > 0) { + return true; + } else { + try { + if (reader != null) { + reader.close(); + reader = null; } + } catch (IOException e) { + throw new GeaflowRuntimeException("Error when close odps reader."); + } + return false; } + } - @Override - public OdpsRecordWithPartitionSpec next() { - Record record; - try { - record = reader.read(); - } catch (IOException e) { - throw new GeaflowRuntimeException("Error when read odps."); - } - if (record == null) { - finished = true; - } else { - count--; - } - return new OdpsRecordWithPartitionSpec(record, spec); + @Override + public OdpsRecordWithPartitionSpec next() { + Record record; + try { + record = reader.read(); + } catch (IOException e) { + throw new GeaflowRuntimeException("Error when read odps."); + } + if (record == null) { + finished = true; + } else { + count--; } + return new OdpsRecordWithPartitionSpec(record, spec); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/utils/OdpsConnectorUtils.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/utils/OdpsConnectorUtils.java index 3a62a7620..3a35f773c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/utils/OdpsConnectorUtils.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/utils/OdpsConnectorUtils.java @@ -19,7 +19,6 @@ package org.apache.geaflow.dsl.connector.odps.utils; -import com.aliyun.odps.OdpsType; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.primitive.BinaryStringType; import org.apache.geaflow.common.type.primitive.BooleanType; @@ -35,47 +34,50 @@ import org.apache.geaflow.dsl.common.types.ArrayType; import org.apache.geaflow.dsl.common.types.VoidType; +import com.aliyun.odps.OdpsType; + public class OdpsConnectorUtils { - public static boolean typeEquals(OdpsType odpsType, IType itype) { - switch (odpsType) { - case TINYINT: - case SMALLINT: - return itype instanceof LongType || itype instanceof IntegerType - || itype instanceof ShortType; - case INT: - return itype instanceof LongType || itype instanceof IntegerType; - case BIGINT: - return itype instanceof IntegerType; - case FLOAT: - return itype instanceof FloatType || itype instanceof DoubleType; - case DOUBLE: - return itype instanceof DoubleType; - case BOOLEAN: - return itype instanceof BooleanType; - case CHAR: - case STRING: - case VARCHAR: - return itype instanceof BinaryStringType || itype instanceof StringType; - case BINARY: - return itype instanceof ByteType || itype instanceof BinaryStringType; - case DECIMAL: - return itype instanceof DecimalType; - case ARRAY: - return itype instanceof ArrayType; - case VOID: - return itype instanceof VoidType; - case DATETIME: - case DATE: - return itype instanceof TimestampType; - case TIMESTAMP: - return itype instanceof TimestampType || itype instanceof LongType; - case MAP: - case INTERVAL_DAY_TIME: - case INTERVAL_YEAR_MONTH: - case STRUCT: - default: - return false; - } + public static boolean typeEquals(OdpsType odpsType, IType itype) { + switch (odpsType) { + case TINYINT: + case SMALLINT: + return itype instanceof LongType + || itype instanceof IntegerType + || itype instanceof ShortType; + case INT: + return itype instanceof LongType || itype instanceof IntegerType; + case BIGINT: + return itype instanceof IntegerType; + case FLOAT: + return itype instanceof FloatType || itype instanceof DoubleType; + case DOUBLE: + return itype instanceof DoubleType; + case BOOLEAN: + return itype instanceof BooleanType; + case CHAR: + case STRING: + case VARCHAR: + return itype instanceof BinaryStringType || itype instanceof StringType; + case BINARY: + return itype instanceof ByteType || itype instanceof BinaryStringType; + case DECIMAL: + return itype instanceof DecimalType; + case ARRAY: + return itype instanceof ArrayType; + case VOID: + return itype instanceof VoidType; + case DATETIME: + case DATE: + return itype instanceof TimestampType; + case TIMESTAMP: + return itype instanceof TimestampType || itype instanceof LongType; + case MAP: + case INTERVAL_DAY_TIME: + case INTERVAL_YEAR_MONTH: + case STRUCT: + default: + return false; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/utils/OdpsRecordDeserializer.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/utils/OdpsRecordDeserializer.java index 64c008d23..f3d102116 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/utils/OdpsRecordDeserializer.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/utils/OdpsRecordDeserializer.java @@ -19,14 +19,12 @@ package org.apache.geaflow.dsl.connector.odps.utils; -import com.aliyun.odps.Column; -import com.aliyun.odps.PartitionSpec; -import com.aliyun.odps.data.Record; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.impl.ObjectRow; @@ -34,48 +32,52 @@ import org.apache.geaflow.dsl.common.util.TypeCastUtil; import org.apache.geaflow.dsl.connector.api.serde.TableDeserializer; +import com.aliyun.odps.Column; +import com.aliyun.odps.PartitionSpec; +import com.aliyun.odps.data.Record; + public class OdpsRecordDeserializer implements TableDeserializer { - private StructType schema; + private StructType schema; - private Map columnName2Index; + private Map columnName2Index; - @Override - public void init(Configuration conf, StructType schema) { - this.schema = Objects.requireNonNull(schema); - columnName2Index = new HashMap<>(); - for (int i = 0; i < schema.size(); i++) { - columnName2Index.put(schema.getField(i).getName(), i); - } + @Override + public void init(Configuration conf, StructType schema) { + this.schema = Objects.requireNonNull(schema); + columnName2Index = new HashMap<>(); + for (int i = 0; i < schema.size(); i++) { + columnName2Index.put(schema.getField(i).getName(), i); } + } - @Override - public List deserialize(OdpsRecordWithPartitionSpec recordWithPartitionSpec) { - if (recordWithPartitionSpec == null || recordWithPartitionSpec.record == null) { - return Collections.emptyList(); - } - Record item = recordWithPartitionSpec.record; - Object[] objects = new Object[this.schema.size()]; - Column[] columns = item.getColumns(); - int colIndex = 0; - for (Column col : columns) { - String colName = col.getName(); - Integer index = columnName2Index.get(colName); - if (index != null) { - objects[index] = TypeCastUtil.cast(item.get(colIndex), this.schema.getType(colIndex)); - } - colIndex++; - } - PartitionSpec spec = recordWithPartitionSpec.spec; - if (spec != null) { - for (String colName : spec.keys()) { - Integer index = columnName2Index.get(colName); - if (index != null) { - objects[index] = TypeCastUtil.cast(spec.get(colName), this.schema.getType(index)); - } - colIndex++; - } + @Override + public List deserialize(OdpsRecordWithPartitionSpec recordWithPartitionSpec) { + if (recordWithPartitionSpec == null || recordWithPartitionSpec.record == null) { + return Collections.emptyList(); + } + Record item = recordWithPartitionSpec.record; + Object[] objects = new Object[this.schema.size()]; + Column[] columns = item.getColumns(); + int colIndex = 0; + for (Column col : columns) { + String colName = col.getName(); + Integer index = columnName2Index.get(colName); + if (index != null) { + objects[index] = TypeCastUtil.cast(item.get(colIndex), this.schema.getType(colIndex)); + } + colIndex++; + } + PartitionSpec spec = recordWithPartitionSpec.spec; + if (spec != null) { + for (String colName : spec.keys()) { + Integer index = columnName2Index.get(colName); + if (index != null) { + objects[index] = TypeCastUtil.cast(spec.get(colName), this.schema.getType(index)); } - return Collections.singletonList(ObjectRow.create(objects)); + colIndex++; + } } -} \ No newline at end of file + return Collections.singletonList(ObjectRow.create(objects)); + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/utils/OdpsRecordWithPartitionSpec.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/utils/OdpsRecordWithPartitionSpec.java index f56ab13e8..e2285ee24 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/utils/OdpsRecordWithPartitionSpec.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/main/java/org/apache/geaflow/dsl/connector/odps/utils/OdpsRecordWithPartitionSpec.java @@ -24,12 +24,12 @@ public class OdpsRecordWithPartitionSpec { - public final Record record; + public final Record record; - public final PartitionSpec spec; + public final PartitionSpec spec; - public OdpsRecordWithPartitionSpec(Record record, PartitionSpec spec) { - this.record = record; - this.spec = spec; - } + public OdpsRecordWithPartitionSpec(Record record, PartitionSpec spec) { + this.record = record; + this.spec = spec; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/test/java/org/apache/geaflow/dsl/connector/odps/DefaultPartitionExtractorTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/test/java/org/apache/geaflow/dsl/connector/odps/DefaultPartitionExtractorTest.java index 0532cc9c8..46552e511 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/test/java/org/apache/geaflow/dsl/connector/odps/DefaultPartitionExtractorTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/test/java/org/apache/geaflow/dsl/connector/odps/DefaultPartitionExtractorTest.java @@ -24,93 +24,90 @@ import org.apache.geaflow.common.type.primitive.StringType; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.impl.ObjectRow; -import org.apache.geaflow.dsl.common.types.StructType; import org.junit.Assert; import org.junit.Test; public class DefaultPartitionExtractorTest { - @Test - public void testExtract() { - Row row1 = ObjectRow.create(1, 2025111, 10, 3.14, "a11"); - Row row2 = ObjectRow.create(1, 2025111, 11, 3.14, "b11"); - Row row3 = ObjectRow.create(1, 2025111, 12, 3.14, "c11"); - Row row4 = ObjectRow.create(1, 2025111, 13, 3.14, "d11"); - Row row5 = ObjectRow.create(1, 2025111, 14, 3.14, "e11"); - Row row6 = ObjectRow.create(1, 2025110, 11, 3.14, "null"); + @Test + public void testExtract() { + Row row1 = ObjectRow.create(1, 2025111, 10, 3.14, "a11"); + Row row2 = ObjectRow.create(1, 2025111, 11, 3.14, "b11"); + Row row3 = ObjectRow.create(1, 2025111, 12, 3.14, "c11"); + Row row4 = ObjectRow.create(1, 2025111, 13, 3.14, "d11"); + Row row5 = ObjectRow.create(1, 2025111, 14, 3.14, "e11"); + Row row6 = ObjectRow.create(1, 2025110, 11, 3.14, "null"); - String spec1 = "dt=$dt,hh=$hh,biz=$biz"; - DefaultPartitionExtractor extractor1 = new DefaultPartitionExtractor( - spec1, new int[]{1, 2, 4}, - new IType[]{IntegerType.INSTANCE, IntegerType.INSTANCE, StringType.INSTANCE}); - Assert.assertEquals("dt=2025111,hh=10,biz=a11", extractor1.extractPartition(row1)); - Assert.assertEquals("dt=2025111,hh=11,biz=b11", extractor1.extractPartition(row2)); - Assert.assertEquals("dt=2025111,hh=12,biz=c11", extractor1.extractPartition(row3)); - Assert.assertEquals("dt=2025111,hh=13,biz=d11", extractor1.extractPartition(row4)); - Assert.assertEquals("dt=2025111,hh=14,biz=e11", extractor1.extractPartition(row5)); - Assert.assertEquals("dt=2025110,hh=11,biz=null", extractor1.extractPartition(row6)); + String spec1 = "dt=$dt,hh=$hh,biz=$biz"; + DefaultPartitionExtractor extractor1 = + new DefaultPartitionExtractor( + spec1, + new int[] {1, 2, 4}, + new IType[] {IntegerType.INSTANCE, IntegerType.INSTANCE, StringType.INSTANCE}); + Assert.assertEquals("dt=2025111,hh=10,biz=a11", extractor1.extractPartition(row1)); + Assert.assertEquals("dt=2025111,hh=11,biz=b11", extractor1.extractPartition(row2)); + Assert.assertEquals("dt=2025111,hh=12,biz=c11", extractor1.extractPartition(row3)); + Assert.assertEquals("dt=2025111,hh=13,biz=d11", extractor1.extractPartition(row4)); + Assert.assertEquals("dt=2025111,hh=14,biz=e11", extractor1.extractPartition(row5)); + Assert.assertEquals("dt=2025110,hh=11,biz=null", extractor1.extractPartition(row6)); - String spec2 = "dt=$dt/hh=$hh/biz=$biz"; - DefaultPartitionExtractor extractor2 = new DefaultPartitionExtractor( - spec2, new int[]{1, 2, 4}, - new IType[]{IntegerType.INSTANCE, IntegerType.INSTANCE, StringType.INSTANCE}); - Assert.assertEquals("dt=2025111/hh=10/biz=a11", extractor2.extractPartition(row1)); - Assert.assertEquals("dt=2025111/hh=11/biz=b11", extractor2.extractPartition(row2)); - Assert.assertEquals("dt=2025111/hh=12/biz=c11", extractor2.extractPartition(row3)); - Assert.assertEquals("dt=2025111/hh=13/biz=d11", extractor2.extractPartition(row4)); - Assert.assertEquals("dt=2025111/hh=14/biz=e11", extractor2.extractPartition(row5)); - Assert.assertEquals("dt=2025110/hh=11/biz=null", extractor2.extractPartition(row6)); + String spec2 = "dt=$dt/hh=$hh/biz=$biz"; + DefaultPartitionExtractor extractor2 = + new DefaultPartitionExtractor( + spec2, + new int[] {1, 2, 4}, + new IType[] {IntegerType.INSTANCE, IntegerType.INSTANCE, StringType.INSTANCE}); + Assert.assertEquals("dt=2025111/hh=10/biz=a11", extractor2.extractPartition(row1)); + Assert.assertEquals("dt=2025111/hh=11/biz=b11", extractor2.extractPartition(row2)); + Assert.assertEquals("dt=2025111/hh=12/biz=c11", extractor2.extractPartition(row3)); + Assert.assertEquals("dt=2025111/hh=13/biz=d11", extractor2.extractPartition(row4)); + Assert.assertEquals("dt=2025111/hh=14/biz=e11", extractor2.extractPartition(row5)); + Assert.assertEquals("dt=2025110/hh=11/biz=null", extractor2.extractPartition(row6)); + String spec3 = "dt=$dt"; + DefaultPartitionExtractor extractor3 = + new DefaultPartitionExtractor(spec3, new int[] {1}, new IType[] {IntegerType.INSTANCE}); + Assert.assertEquals("dt=2025111", extractor3.extractPartition(row1)); + Assert.assertEquals("dt=2025111", extractor3.extractPartition(row2)); + Assert.assertEquals("dt=2025111", extractor3.extractPartition(row3)); + Assert.assertEquals("dt=2025111", extractor3.extractPartition(row4)); + Assert.assertEquals("dt=2025111", extractor3.extractPartition(row5)); + Assert.assertEquals("dt=2025110", extractor3.extractPartition(row6)); - String spec3 = "dt=$dt"; - DefaultPartitionExtractor extractor3 = new DefaultPartitionExtractor( - spec3, new int[]{1}, - new IType[]{IntegerType.INSTANCE}); - Assert.assertEquals("dt=2025111", extractor3.extractPartition(row1)); - Assert.assertEquals("dt=2025111", extractor3.extractPartition(row2)); - Assert.assertEquals("dt=2025111", extractor3.extractPartition(row3)); - Assert.assertEquals("dt=2025111", extractor3.extractPartition(row4)); - Assert.assertEquals("dt=2025111", extractor3.extractPartition(row5)); - Assert.assertEquals("dt=2025110", extractor3.extractPartition(row6)); + String spec4 = "dt=20251120"; + DefaultPartitionExtractor extractor4 = + new DefaultPartitionExtractor(spec4, new int[] {}, new IType[] {}); + Assert.assertEquals("dt=20251120", extractor4.extractPartition(row1)); + Assert.assertEquals("dt=20251120", extractor4.extractPartition(row2)); + Assert.assertEquals("dt=20251120", extractor4.extractPartition(row3)); + Assert.assertEquals("dt=20251120", extractor4.extractPartition(row4)); + Assert.assertEquals("dt=20251120", extractor4.extractPartition(row5)); + Assert.assertEquals("dt=20251120", extractor4.extractPartition(row6)); + String spec5 = "dt=20251120,hh=$hh"; + DefaultPartitionExtractor extractor5 = + new DefaultPartitionExtractor(spec5, new int[] {2}, new IType[] {IntegerType.INSTANCE}); + Assert.assertEquals("dt=20251120,hh=10", extractor5.extractPartition(row1)); + Assert.assertEquals("dt=20251120,hh=11", extractor5.extractPartition(row2)); + Assert.assertEquals("dt=20251120,hh=12", extractor5.extractPartition(row3)); + Assert.assertEquals("dt=20251120,hh=13", extractor5.extractPartition(row4)); + Assert.assertEquals("dt=20251120,hh=14", extractor5.extractPartition(row5)); + Assert.assertEquals("dt=20251120,hh=11", extractor5.extractPartition(row6)); - String spec4 = "dt=20251120"; - DefaultPartitionExtractor extractor4 = new DefaultPartitionExtractor( - spec4, new int[]{}, - new IType[]{}); - Assert.assertEquals("dt=20251120", extractor4.extractPartition(row1)); - Assert.assertEquals("dt=20251120", extractor4.extractPartition(row2)); - Assert.assertEquals("dt=20251120", extractor4.extractPartition(row3)); - Assert.assertEquals("dt=20251120", extractor4.extractPartition(row4)); - Assert.assertEquals("dt=20251120", extractor4.extractPartition(row5)); - Assert.assertEquals("dt=20251120", extractor4.extractPartition(row6)); + PartitionExtractor extractor6 = DefaultPartitionExtractor.create("", null); + Assert.assertEquals("", extractor6.extractPartition(row1)); + Assert.assertEquals("", extractor6.extractPartition(row2)); + Assert.assertEquals("", extractor6.extractPartition(row3)); + Assert.assertEquals("", extractor6.extractPartition(row4)); + Assert.assertEquals("", extractor6.extractPartition(row5)); + Assert.assertEquals("", extractor6.extractPartition(row6)); + } - String spec5 = "dt=20251120,hh=$hh"; - DefaultPartitionExtractor extractor5 = new DefaultPartitionExtractor( - spec5, new int[]{2}, - new IType[]{IntegerType.INSTANCE}); - Assert.assertEquals("dt=20251120,hh=10", extractor5.extractPartition(row1)); - Assert.assertEquals("dt=20251120,hh=11", extractor5.extractPartition(row2)); - Assert.assertEquals("dt=20251120,hh=12", extractor5.extractPartition(row3)); - Assert.assertEquals("dt=20251120,hh=13", extractor5.extractPartition(row4)); - Assert.assertEquals("dt=20251120,hh=14", extractor5.extractPartition(row5)); - Assert.assertEquals("dt=20251120,hh=11", extractor5.extractPartition(row6)); - - PartitionExtractor extractor6 = DefaultPartitionExtractor.create("", null); - Assert.assertEquals("", extractor6.extractPartition(row1)); - Assert.assertEquals("", extractor6.extractPartition(row2)); - Assert.assertEquals("", extractor6.extractPartition(row3)); - Assert.assertEquals("", extractor6.extractPartition(row4)); - Assert.assertEquals("", extractor6.extractPartition(row5)); - Assert.assertEquals("", extractor6.extractPartition(row6)); - - } - - @Test - public void testUnquoted() { - Assert.assertEquals("dt", DefaultPartitionExtractor.unquoted("dt")); - Assert.assertEquals("dt", DefaultPartitionExtractor.unquoted("`dt`")); - Assert.assertEquals("dt", DefaultPartitionExtractor.unquoted("'dt'")); - Assert.assertEquals("dt", DefaultPartitionExtractor.unquoted("\"dt\"")); - } + @Test + public void testUnquoted() { + Assert.assertEquals("dt", DefaultPartitionExtractor.unquoted("dt")); + Assert.assertEquals("dt", DefaultPartitionExtractor.unquoted("`dt`")); + Assert.assertEquals("dt", DefaultPartitionExtractor.unquoted("'dt'")); + Assert.assertEquals("dt", DefaultPartitionExtractor.unquoted("\"dt\"")); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/test/java/org/apache/geaflow/dsl/connector/odps/OdpsTableSinkTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/test/java/org/apache/geaflow/dsl/connector/odps/OdpsTableSinkTest.java index 063bc5360..62696e3e0 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/test/java/org/apache/geaflow/dsl/connector/odps/OdpsTableSinkTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/test/java/org/apache/geaflow/dsl/connector/odps/OdpsTableSinkTest.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.connector.odps; import java.io.IOException; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.type.primitive.StringType; import org.apache.geaflow.dsl.common.data.impl.ObjectRow; @@ -31,34 +32,34 @@ public class OdpsTableSinkTest { - @Test(enabled = false) - public void testOdpsTableSink() throws IOException { - OdpsTableSink sink = new OdpsTableSink(); - Configuration config = new Configuration(); - config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ENDPOINT, "http://test.odps.com/api"); - config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_PROJECT, "test_project"); - config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_TABLE, "test_table"); - config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ACCESS_KEY, "test_access_key"); - config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ACCESS_ID, "test_access_id"); - config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_PARTITION_SPEC, "dt='20000000'"); - TableSchema schema = new TableSchema( + @Test(enabled = false) + public void testOdpsTableSink() throws IOException { + OdpsTableSink sink = new OdpsTableSink(); + Configuration config = new Configuration(); + config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ENDPOINT, "http://test.odps.com/api"); + config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_PROJECT, "test_project"); + config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_TABLE, "test_table"); + config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ACCESS_KEY, "test_access_key"); + config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ACCESS_ID, "test_access_id"); + config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_PARTITION_SPEC, "dt='20000000'"); + TableSchema schema = + new TableSchema( new TableField("src_id", StringType.INSTANCE, false), new TableField("target_id", StringType.INSTANCE, false), - new TableField("relation", StringType.INSTANCE, false) - ); - sink.init(config, schema); - try { - sink.open(new DefaultRuntimeContext(config)); - } catch (Exception e) { - Assert.assertEquals(e.getMessage(), "Cannot list partitions from ODPS, endPoint: http://test.odps.com/api"); - } - sink.write(ObjectRow.create(new Object[]{"1", "2", "3"})); - try { - sink.finish(); - } catch (Exception e) { - Assert.assertEquals(e.getMessage(), "java.lang.IllegalArgumentException"); - } - sink.close(); + new TableField("relation", StringType.INSTANCE, false)); + sink.init(config, schema); + try { + sink.open(new DefaultRuntimeContext(config)); + } catch (Exception e) { + Assert.assertEquals( + e.getMessage(), "Cannot list partitions from ODPS, endPoint: http://test.odps.com/api"); } - + sink.write(ObjectRow.create(new Object[] {"1", "2", "3"})); + try { + sink.finish(); + } catch (Exception e) { + Assert.assertEquals(e.getMessage(), "java.lang.IllegalArgumentException"); + } + sink.close(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/test/java/org/apache/geaflow/dsl/connector/odps/OdpsTableSourceTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/test/java/org/apache/geaflow/dsl/connector/odps/OdpsTableSourceTest.java index a31756ef7..1636017df 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/test/java/org/apache/geaflow/dsl/connector/odps/OdpsTableSourceTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-odps/src/test/java/org/apache/geaflow/dsl/connector/odps/OdpsTableSourceTest.java @@ -19,13 +19,10 @@ package org.apache.geaflow.dsl.connector.odps; -import com.aliyun.odps.Column; -import com.aliyun.odps.OdpsType; -import com.aliyun.odps.PartitionSpec; -import com.aliyun.odps.data.ArrayRecord; import java.io.IOException; import java.util.List; import java.util.Optional; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.type.primitive.BinaryStringType; import org.apache.geaflow.common.type.primitive.BooleanType; @@ -51,97 +48,116 @@ import org.testng.Assert; import org.testng.annotations.Test; +import com.aliyun.odps.Column; +import com.aliyun.odps.OdpsType; +import com.aliyun.odps.PartitionSpec; +import com.aliyun.odps.data.ArrayRecord; + public class OdpsTableSourceTest { - @Test(enabled = false) - public void testOdpsTableSource() throws IOException { - OdpsTableSource source = new OdpsTableSource(); - Configuration config = new Configuration(); - config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ENDPOINT, "http://test.odps.com/api"); - config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_PROJECT, "test_project"); - config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_TABLE, "test_table"); - config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ACCESS_KEY, "test_access_key"); - config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ACCESS_ID, "test_access_id"); - TableSchema schema = new TableSchema( + @Test(enabled = false) + public void testOdpsTableSource() throws IOException { + OdpsTableSource source = new OdpsTableSource(); + Configuration config = new Configuration(); + config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ENDPOINT, "http://test.odps.com/api"); + config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_PROJECT, "test_project"); + config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_TABLE, "test_table"); + config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ACCESS_KEY, "test_access_key"); + config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ACCESS_ID, "test_access_id"); + TableSchema schema = + new TableSchema( new TableField("src_id", IntegerType.INSTANCE, false), new TableField("target_id", IntegerType.INSTANCE, false), - new TableField("relation", IntegerType.INSTANCE, false) - ); - source.init(config, schema); - try { - source.open(new DefaultRuntimeContext(config)); - } catch (Exception e) { - Assert.assertEquals(e.getMessage(), "Can't bind xml to com.aliyun.odps.Table$TableModel"); - } - try { - List odpsPartitions = source.listPartitions(); - } catch (Exception e) { - Assert.assertEquals(e.getMessage(), "Cannot list partitions from ODPS, endPoint: http://test.odps.com/api"); - } - Partition firstPartition = new OdpsShardPartition("prefix-", new PartitionSpec("dt='20000000'")); - FetchWindow window = new SizeFetchWindow(1, 100L); - try { - FetchData data = source.fetch(firstPartition, Optional.empty(), window); - } catch (Exception e) { - Assert.assertEquals(e.getMessage(), "Cannot get Odps session."); - } - TableDeserializer deserializer = source.getDeserializer(config); - deserializer.init(config, schema); - deserializer.deserialize(new OdpsRecordWithPartitionSpec(new ArrayRecord( - new Column[]{new Column("src_id", OdpsType.STRING), new Column("test", OdpsType.STRING)}, - new Object[]{"16", "32"}), null)); + new TableField("relation", IntegerType.INSTANCE, false)); + source.init(config, schema); + try { + source.open(new DefaultRuntimeContext(config)); + } catch (Exception e) { + Assert.assertEquals(e.getMessage(), "Can't bind xml to com.aliyun.odps.Table$TableModel"); + } + try { + List odpsPartitions = source.listPartitions(); + } catch (Exception e) { + Assert.assertEquals( + e.getMessage(), "Cannot list partitions from ODPS, endPoint: http://test.odps.com/api"); + } + Partition firstPartition = + new OdpsShardPartition("prefix-", new PartitionSpec("dt='20000000'")); + FetchWindow window = new SizeFetchWindow(1, 100L); + try { + FetchData data = source.fetch(firstPartition, Optional.empty(), window); + } catch (Exception e) { + Assert.assertEquals(e.getMessage(), "Cannot get Odps session."); } + TableDeserializer deserializer = source.getDeserializer(config); + deserializer.init(config, schema); + deserializer.deserialize( + new OdpsRecordWithPartitionSpec( + new ArrayRecord( + new Column[] { + new Column("src_id", OdpsType.STRING), new Column("test", OdpsType.STRING) + }, + new Object[] {"16", "32"}), + null)); + } - @Test(enabled = false) - public void testPartitionSpec() throws IOException { - OdpsTableSource source = new OdpsTableSource(); - Configuration config = new Configuration(); - config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ENDPOINT, "http://test.odps.com/api"); - config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_PROJECT, "test_project"); - config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_TABLE, "test_table"); - config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ACCESS_KEY, "test_access_key"); - config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ACCESS_ID, "test_access_id"); - TableSchema schema = new TableSchema( + @Test(enabled = false) + public void testPartitionSpec() throws IOException { + OdpsTableSource source = new OdpsTableSource(); + Configuration config = new Configuration(); + config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ENDPOINT, "http://test.odps.com/api"); + config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_PROJECT, "test_project"); + config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_TABLE, "test_table"); + config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ACCESS_KEY, "test_access_key"); + config.put(OdpsConfigKeys.GEAFLOW_DSL_ODPS_ACCESS_ID, "test_access_id"); + TableSchema schema = + new TableSchema( new TableField("src_id", IntegerType.INSTANCE, false), new TableField("target_id", IntegerType.INSTANCE, false), - new TableField("relation", IntegerType.INSTANCE, false) - ); - source.init(config, schema); - try { - source.open(new DefaultRuntimeContext(config)); - } catch (Exception e) { - Assert.assertEquals(e.getMessage(), "Can't bind xml to com.aliyun.odps.Table$TableModel"); - } - try { - List odpsPartitions = source.listPartitions(); - } catch (Exception e) { - Assert.assertEquals(e.getMessage(), "Cannot list partitions from ODPS, endPoint: http://test.odps.com/api"); - } - Partition firstPartition = new OdpsShardPartition("prefix-", new PartitionSpec("dt='20000000'")); - FetchWindow window = new SizeFetchWindow(1, 100L); - try { - FetchData data = source.fetch(firstPartition, Optional.empty(), window); - } catch (Exception e) { - Assert.assertEquals(e.getMessage(), "Cannot get Odps session."); - } - TableDeserializer deserializer = source.getDeserializer(config); - deserializer.init(config, schema); - deserializer.deserialize(new OdpsRecordWithPartitionSpec(new ArrayRecord( - new Column[]{new Column("src_id", OdpsType.STRING), new Column("test", OdpsType.STRING)}, - new Object[]{"16", "32"}), null)); + new TableField("relation", IntegerType.INSTANCE, false)); + source.init(config, schema); + try { + source.open(new DefaultRuntimeContext(config)); + } catch (Exception e) { + Assert.assertEquals(e.getMessage(), "Can't bind xml to com.aliyun.odps.Table$TableModel"); } - - @Test - public void testOdpsConnectorUtils() { - Assert.assertTrue(OdpsConnectorUtils.typeEquals(OdpsType.SMALLINT, LongType.INSTANCE)); - Assert.assertTrue(OdpsConnectorUtils.typeEquals(OdpsType.INT, LongType.INSTANCE)); - Assert.assertTrue(OdpsConnectorUtils.typeEquals(OdpsType.FLOAT, FloatType.INSTANCE)); - Assert.assertTrue(OdpsConnectorUtils.typeEquals(OdpsType.DOUBLE, DoubleType.INSTANCE)); - Assert.assertTrue(OdpsConnectorUtils.typeEquals(OdpsType.BOOLEAN, BooleanType.INSTANCE)); - Assert.assertTrue(OdpsConnectorUtils.typeEquals(OdpsType.STRING, BinaryStringType.INSTANCE)); - Assert.assertTrue(OdpsConnectorUtils.typeEquals(OdpsType.BINARY, ByteType.INSTANCE)); - Assert.assertTrue(OdpsConnectorUtils.typeEquals(OdpsType.DECIMAL, DecimalType.INSTANCE)); - Assert.assertTrue(OdpsConnectorUtils.typeEquals(OdpsType.VOID, VoidType.INSTANCE)); - Assert.assertTrue(OdpsConnectorUtils.typeEquals(OdpsType.DATE, TimestampType.INSTANCE)); + try { + List odpsPartitions = source.listPartitions(); + } catch (Exception e) { + Assert.assertEquals( + e.getMessage(), "Cannot list partitions from ODPS, endPoint: http://test.odps.com/api"); + } + Partition firstPartition = + new OdpsShardPartition("prefix-", new PartitionSpec("dt='20000000'")); + FetchWindow window = new SizeFetchWindow(1, 100L); + try { + FetchData data = source.fetch(firstPartition, Optional.empty(), window); + } catch (Exception e) { + Assert.assertEquals(e.getMessage(), "Cannot get Odps session."); } + TableDeserializer deserializer = source.getDeserializer(config); + deserializer.init(config, schema); + deserializer.deserialize( + new OdpsRecordWithPartitionSpec( + new ArrayRecord( + new Column[] { + new Column("src_id", OdpsType.STRING), new Column("test", OdpsType.STRING) + }, + new Object[] {"16", "32"}), + null)); + } + + @Test + public void testOdpsConnectorUtils() { + Assert.assertTrue(OdpsConnectorUtils.typeEquals(OdpsType.SMALLINT, LongType.INSTANCE)); + Assert.assertTrue(OdpsConnectorUtils.typeEquals(OdpsType.INT, LongType.INSTANCE)); + Assert.assertTrue(OdpsConnectorUtils.typeEquals(OdpsType.FLOAT, FloatType.INSTANCE)); + Assert.assertTrue(OdpsConnectorUtils.typeEquals(OdpsType.DOUBLE, DoubleType.INSTANCE)); + Assert.assertTrue(OdpsConnectorUtils.typeEquals(OdpsType.BOOLEAN, BooleanType.INSTANCE)); + Assert.assertTrue(OdpsConnectorUtils.typeEquals(OdpsType.STRING, BinaryStringType.INSTANCE)); + Assert.assertTrue(OdpsConnectorUtils.typeEquals(OdpsType.BINARY, ByteType.INSTANCE)); + Assert.assertTrue(OdpsConnectorUtils.typeEquals(OdpsType.DECIMAL, DecimalType.INSTANCE)); + Assert.assertTrue(OdpsConnectorUtils.typeEquals(OdpsType.VOID, VoidType.INSTANCE)); + Assert.assertTrue(OdpsConnectorUtils.typeEquals(OdpsType.DATE, TimestampType.INSTANCE)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/IteratorWrapper.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/IteratorWrapper.java index 3ce921708..6ce68a589 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/IteratorWrapper.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/IteratorWrapper.java @@ -25,27 +25,27 @@ public class IteratorWrapper implements CloseableIterator { - private final CloseableIterator iterator; - private final PaimonRecordDeserializer deserializer; - - public IteratorWrapper(CloseableIterator iterator, - PaimonRecordDeserializer deserializer) { - this.iterator = iterator; - this.deserializer = deserializer; - } - - @Override - public void close() throws Exception { - iterator.close(); - } - - @Override - public boolean hasNext() { - return iterator.hasNext(); - } - - @Override - public Row next() { - return deserializer.deserialize(iterator.next()); - } + private final CloseableIterator iterator; + private final PaimonRecordDeserializer deserializer; + + public IteratorWrapper( + CloseableIterator iterator, PaimonRecordDeserializer deserializer) { + this.iterator = iterator; + this.deserializer = deserializer; + } + + @Override + public void close() throws Exception { + iterator.close(); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public Row next() { + return deserializer.deserialize(iterator.next()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/PaimonConfigKeys.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/PaimonConfigKeys.java index 7218ed73b..9e6903d73 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/PaimonConfigKeys.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/PaimonConfigKeys.java @@ -24,43 +24,44 @@ public class PaimonConfigKeys { - public static final ConfigKey GEAFLOW_DSL_PAIMON_WAREHOUSE = ConfigKeys - .key("geaflow.dsl.paimon.warehouse") - .noDefaultValue() - .description("The warehouse path for paimon catalog creation."); + public static final ConfigKey GEAFLOW_DSL_PAIMON_WAREHOUSE = + ConfigKeys.key("geaflow.dsl.paimon.warehouse") + .noDefaultValue() + .description("The warehouse path for paimon catalog creation."); - public static final ConfigKey GEAFLOW_DSL_PAIMON_OPTIONS_JSON = ConfigKeys - .key("geaflow.dsl.paimon.options.json") - .noDefaultValue() - .description("The options json for paimon catalog creation."); + public static final ConfigKey GEAFLOW_DSL_PAIMON_OPTIONS_JSON = + ConfigKeys.key("geaflow.dsl.paimon.options.json") + .noDefaultValue() + .description("The options json for paimon catalog creation."); - public static final ConfigKey GEAFLOW_DSL_PAIMON_CONFIGURATION_JSON = ConfigKeys - .key("geaflow.dsl.paimon.configuration.json") - .noDefaultValue() - .description("The configuration json for paimon catalog creation."); + public static final ConfigKey GEAFLOW_DSL_PAIMON_CONFIGURATION_JSON = + ConfigKeys.key("geaflow.dsl.paimon.configuration.json") + .noDefaultValue() + .description("The configuration json for paimon catalog creation."); - public static final ConfigKey GEAFLOW_DSL_PAIMON_DATABASE_NAME = ConfigKeys - .key("geaflow.dsl.paimon.database.name") - .noDefaultValue() - .description("The database name for paimon table."); + public static final ConfigKey GEAFLOW_DSL_PAIMON_DATABASE_NAME = + ConfigKeys.key("geaflow.dsl.paimon.database.name") + .noDefaultValue() + .description("The database name for paimon table."); - public static final ConfigKey GEAFLOW_DSL_PAIMON_TABLE_NAME = ConfigKeys - .key("geaflow.dsl.paimon.table.name") - .noDefaultValue() - .description("The paimon table name to read."); + public static final ConfigKey GEAFLOW_DSL_PAIMON_TABLE_NAME = + ConfigKeys.key("geaflow.dsl.paimon.table.name") + .noDefaultValue() + .description("The paimon table name to read."); - public static final ConfigKey GEAFLOW_DSL_PAIMON_SOURCE_MODE = ConfigKeys - .key("geaflow.dsl.paimon.source.mode") - .defaultValue(SourceMode.BATCH.name()) - .description("The paimon source mode, if stream, will continue to read data from paimon."); + public static final ConfigKey GEAFLOW_DSL_PAIMON_SOURCE_MODE = + ConfigKeys.key("geaflow.dsl.paimon.source.mode") + .defaultValue(SourceMode.BATCH.name()) + .description( + "The paimon source mode, if stream, will continue to read data from paimon."); - public static final ConfigKey GEAFLOW_DSL_PAIMON_SCAN_SNAPSHOT_ID = ConfigKeys - .key("geaflow.dsl.paimon.scan.snapshot.id") - .defaultValue(null) - .description("If scan mode is from-snapshot, this parameter is required."); + public static final ConfigKey GEAFLOW_DSL_PAIMON_SCAN_SNAPSHOT_ID = + ConfigKeys.key("geaflow.dsl.paimon.scan.snapshot.id") + .defaultValue(null) + .description("If scan mode is from-snapshot, this parameter is required."); - public static final ConfigKey GEAFLOW_DSL_PAIMON_SCAN_MODE = ConfigKeys - .key("geaflow.dsl.paimon.scan.mode") - .defaultValue(StartupMode.LATEST.getValue()) - .description("Determines the scan mode for paimon source, 'latest' or 'from-snapshot'."); + public static final ConfigKey GEAFLOW_DSL_PAIMON_SCAN_MODE = + ConfigKeys.key("geaflow.dsl.paimon.scan.mode") + .defaultValue(StartupMode.LATEST.getValue()) + .description("Determines the scan mode for paimon source, 'latest' or 'from-snapshot'."); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/PaimonRecordDeserializer.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/PaimonRecordDeserializer.java index 1cbd523c6..fa15a7ef4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/PaimonRecordDeserializer.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/PaimonRecordDeserializer.java @@ -30,50 +30,49 @@ public class PaimonRecordDeserializer { - private StructType schema; + private StructType schema; - public void init(StructType schema) { - TableSchema tableSchema = (TableSchema) schema; - this.schema = tableSchema.getDataSchema(); - } + public void init(StructType schema) { + TableSchema tableSchema = (TableSchema) schema; + this.schema = tableSchema.getDataSchema(); + } - public Row deserialize(Object record) { - if (record == null) { - return null; - } - InternalRow internalRow = (InternalRow) record; - assert internalRow.getFieldCount() == schema.size(); - Object[] values = new Object[schema.size()]; - for (int i = 0; i < schema.size(); i++) { - TableField field = this.schema.getField(i); - switch (field.getType().getName()) { - case Types.TYPE_NAME_BOOLEAN: - values[i] = internalRow.getBoolean(i); - break; - case Types.TYPE_NAME_BYTE: - values[i] = internalRow.getByte(i); - break; - case Types.TYPE_NAME_DOUBLE: - values[i] = internalRow.getDouble(i); - break; - case Types.TYPE_NAME_FLOAT: - values[i] = internalRow.getFloat(i); - break; - case Types.TYPE_NAME_INTEGER: - values[i] = internalRow.getInt(i); - break; - case Types.TYPE_NAME_LONG: - values[i] = internalRow.getLong(i); - break; - case Types.TYPE_NAME_STRING: - case Types.TYPE_NAME_BINARY_STRING: - values[i] = internalRow.getString(i); - break; - default: - throw new GeaFlowDSLException("Type: {} not support", - field.getType().getName()); - } - } - return ObjectRow.create(values); + public Row deserialize(Object record) { + if (record == null) { + return null; + } + InternalRow internalRow = (InternalRow) record; + assert internalRow.getFieldCount() == schema.size(); + Object[] values = new Object[schema.size()]; + for (int i = 0; i < schema.size(); i++) { + TableField field = this.schema.getField(i); + switch (field.getType().getName()) { + case Types.TYPE_NAME_BOOLEAN: + values[i] = internalRow.getBoolean(i); + break; + case Types.TYPE_NAME_BYTE: + values[i] = internalRow.getByte(i); + break; + case Types.TYPE_NAME_DOUBLE: + values[i] = internalRow.getDouble(i); + break; + case Types.TYPE_NAME_FLOAT: + values[i] = internalRow.getFloat(i); + break; + case Types.TYPE_NAME_INTEGER: + values[i] = internalRow.getInt(i); + break; + case Types.TYPE_NAME_LONG: + values[i] = internalRow.getLong(i); + break; + case Types.TYPE_NAME_STRING: + case Types.TYPE_NAME_BINARY_STRING: + values[i] = internalRow.getString(i); + break; + default: + throw new GeaFlowDSLException("Type: {} not support", field.getType().getName()); + } } + return ObjectRow.create(values); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/PaimonTableConnector.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/PaimonTableConnector.java index deda93832..d11badef9 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/PaimonTableConnector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/PaimonTableConnector.java @@ -25,15 +25,15 @@ public class PaimonTableConnector implements TableReadableConnector { - private static final String PAIMON = "PAIMON"; + private static final String PAIMON = "PAIMON"; - @Override - public String getType() { - return PAIMON; - } + @Override + public String getType() { + return PAIMON; + } - @Override - public TableSource createSource(Configuration conf) { - return new PaimonTableSource(); - } + @Override + public TableSource createSource(Configuration conf) { + return new PaimonTableSource(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/PaimonTableSource.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/PaimonTableSource.java index f5ba0eb59..dca60459d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/PaimonTableSource.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/PaimonTableSource.java @@ -27,6 +27,7 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; + import org.apache.commons.lang.StringUtils; import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; @@ -59,456 +60,483 @@ /** Paimon table source. */ public class PaimonTableSource implements TableSource { - private static final Logger LOGGER = LoggerFactory.getLogger(PaimonTableSource.class); - - private Configuration tableConf; - private TableSchema tableSchema; - - private String path; - private Map options; - private String configJson; - private Map configs; - private String database; - private String table; - private SourceMode sourceMode; - private final PaimonRecordDeserializer deserializer = new PaimonRecordDeserializer(); + private static final Logger LOGGER = LoggerFactory.getLogger(PaimonTableSource.class); + + private Configuration tableConf; + private TableSchema tableSchema; + + private String path; + private Map options; + private String configJson; + private Map configs; + private String database; + private String table; + private SourceMode sourceMode; + private final PaimonRecordDeserializer deserializer = new PaimonRecordDeserializer(); + + private transient long fromSnapshot; + private transient ReadBuilder readBuilder; + private transient Map> iterators; + private transient Map> readers; + private transient Map offsets; + + @Override + public void init(Configuration tableConf, TableSchema tableSchema) { + this.tableConf = tableConf; + this.tableSchema = tableSchema; + this.path = tableConf.getString(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_WAREHOUSE, ""); + this.options = new HashMap<>(); + this.configs = new HashMap<>(); + if (StringUtils.isBlank(this.path)) { + String optionJson = tableConf.getString(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_OPTIONS_JSON); + Map userOptions = GsonUtil.parse(optionJson); + if (userOptions != null) { + options.putAll(userOptions); + } + this.configJson = + tableConf.getString(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_CONFIGURATION_JSON, ""); + if (!StringUtils.isBlank(configJson)) { + Map userConfig = GsonUtil.parse(configJson); + if (userConfig != null) { + configs.putAll(userConfig); + } + } + } + this.database = tableConf.getString(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_DATABASE_NAME); + this.table = tableConf.getString(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_TABLE_NAME); + this.sourceMode = + SourceMode.valueOf( + tableConf.getString(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_SOURCE_MODE).toUpperCase()); + } + + @Override + public List listPartitions(int parallelism) { + List partitions = new ArrayList<>(parallelism); + for (int i = 0; i < parallelism; i++) { + partitions.add(new PaimonPartition(database, table)); + } + return partitions; + } + + @Override + public List listPartitions() { + throw new UnsupportedOperationException("Please use listPartitions(int parallelism) instead"); + } + + @Override + public void open(RuntimeContext context) { + Catalog catalog = getPaimonCatalog(); + Identifier identifier = Identifier.create(database, table); + try { + this.readBuilder = Objects.requireNonNull(catalog.getTable(identifier).newReadBuilder()); + if (tableConf + .getString(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_SCAN_MODE) + .equalsIgnoreCase(StartupMode.FROM_SNAPSHOT.getValue()) + && tableConf.contains(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_SCAN_SNAPSHOT_ID)) { + this.fromSnapshot = tableConf.getLong(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_SCAN_SNAPSHOT_ID); + } else { + this.fromSnapshot = catalog.getTable(identifier).latestSnapshotId().orElse(1L); + } + LOGGER.info("New partition will start from snapshot: {}", this.fromSnapshot); + } catch (TableNotExistException e) { + throw new GeaFlowDSLException("Table: {} in db: {} not exists.", table, database); + } + this.iterators = new HashMap<>(); + this.readers = new HashMap<>(); + this.offsets = new HashMap<>(); + LOGGER.info( + "Open paimon source, tableConf: {}, tableSchema: {}, path: {}, options: " + + "{}, configs: {}, database: {}, tableName: {}", + tableConf, + tableSchema, + path, + options, + configs, + database, + table); + this.deserializer.init(tableSchema); + } + + @Override + public TableDeserializer getDeserializer(Configuration conf) { + return null; + } + + @Override + public FetchData fetch( + Partition partition, Optional startOffset, FetchWindow windowInfo) + throws IOException { + PaimonPartition paimonPartition = (PaimonPartition) partition; + assert paimonPartition.getDatabase().equals(this.database) + && paimonPartition.getTable().equals(this.table); + + long skip = 0; + PaimonOffset innerOffset = offsets.get(partition); + if (innerOffset == null) { + // first fetch, use custom specified snapshot id + innerOffset = new PaimonOffset(fromSnapshot); + } - private transient long fromSnapshot; - private transient ReadBuilder readBuilder; - private transient Map> iterators; - private transient Map> readers; - private transient Map offsets; + // if startOffset is specified, use it and try reset innerOffset + if ((startOffset.isPresent() && !startOffset.get().equals(innerOffset))) { + skip = Math.abs(startOffset.get().getOffset() - innerOffset.getOffset()); + innerOffset = PaimonOffset.from((PaimonOffset) startOffset.get()); + } + if (paimonPartition.getCurrentSnapshot() != innerOffset.getSnapshotId()) { + paimonPartition.reset(loadSplitsFrom(innerOffset.getSnapshotId())); + } - @Override - public void init(Configuration tableConf, TableSchema tableSchema) { - this.tableConf = tableConf; - this.tableSchema = tableSchema; - this.path = tableConf.getString(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_WAREHOUSE, ""); - this.options = new HashMap<>(); - this.configs = new HashMap<>(); - if (StringUtils.isBlank(this.path)) { - String optionJson = tableConf.getString(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_OPTIONS_JSON); - Map userOptions = GsonUtil.parse(optionJson); - if (userOptions != null) { - options.putAll(userOptions); + Split split = paimonPartition.seek(innerOffset.getSplitIndex()); + if (split == null) { + if (sourceMode == SourceMode.BATCH) { + LOGGER.info("No more split to fetch"); + return FetchData.createBatchFetch(Collections.emptyIterator(), innerOffset); + } else { + LOGGER.debug("Snapshot {} not ready now", innerOffset.getSnapshotId()); + return FetchData.createStreamFetch(Collections.emptyList(), innerOffset, false); + } + } + offsets.put(paimonPartition, innerOffset); + + CloseableIterator iterator = createRecordIterator(split); + if (skip > 0) { + while (iterator.hasNext() && skip > 0) { + iterator.next(); + skip--; + } + } + switch (windowInfo.getType()) { + case ALL_WINDOW: + return FetchData.createBatchFetch( + new IteratorWrapper(iterator, deserializer), new PaimonOffset()); + case SIZE_TUMBLING_WINDOW: + List readContents = new ArrayList<>(); + long advance = 0; + for (long i = 0; i < windowInfo.windowSize(); i++) { + if (iterator.hasNext()) { + advance++; + readContents.add(deserializer.deserialize(iterator.next())); + } else { + if (sourceMode == SourceMode.BATCH) { + break; } - this.configJson = - tableConf.getString(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_CONFIGURATION_JSON, ""); - if (!StringUtils.isBlank(configJson)) { - Map userConfig = GsonUtil.parse(configJson); - if (userConfig != null) { - configs.putAll(userConfig); - } + try { + removeRecordIterator(split); + } catch (Exception e) { + throw new RuntimeException(e); } - } - this.database = tableConf.getString(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_DATABASE_NAME); - this.table = tableConf.getString(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_TABLE_NAME); - this.sourceMode = SourceMode.valueOf(tableConf.getString(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_SOURCE_MODE).toUpperCase()); + innerOffset = innerOffset.nextSplit(); + Split seek = paimonPartition.seek(innerOffset.getSplitIndex()); + if (seek == null) { + // all split finished, try read next snapshot + innerOffset = innerOffset.nextSnapshot(); + paimonPartition.reset(loadSplitsFrom(innerOffset.getSnapshotId())); + } + offsets.put(paimonPartition, innerOffset); + advance = 0L; + seek = paimonPartition.seek(innerOffset.getSplitIndex()); + if (seek == null) { + // no new snapshot discovered, retry next turn + break; + } + iterator = createRecordIterator(seek); + i--; + } + } + PaimonOffset next = innerOffset.advance(advance); + offsets.put(paimonPartition, next); + boolean isFinished = sourceMode != SourceMode.STREAM && !iterator.hasNext(); + return FetchData.createStreamFetch(readContents, next, isFinished); + default: + throw new GeaFlowDSLException("Paimon not support window:{}", windowInfo.getType()); } - - @Override - public List listPartitions(int parallelism) { - List partitions = new ArrayList<>(parallelism); - for (int i = 0; i < parallelism; i++) { - partitions.add(new PaimonPartition(database, table)); - } - return partitions; + } + + /** + * Create a paimon record iterator. + * + * @param split the paimon data split + * @return a paimon record iterator + * @throws IOException if error occurs + */ + private CloseableIterator createRecordIterator(Split split) throws IOException { + CloseableIterator iterator = iterators.get(split); + if (iterator != null) { + return iterator; } - - @Override - public List listPartitions() { - throw new UnsupportedOperationException("Please use listPartitions(int parallelism) instead"); + RecordReader reader = readBuilder.newRead().createReader(split); + CloseableIterator closeableIterator = reader.toCloseableIterator(); + readers.put(split, reader); + iterators.put(split, closeableIterator); + return closeableIterator; + } + + /** + * Remove the paimon record iterator and reader. + * + * @param split the paimon data split + * @throws Exception if error occurs + */ + private void removeRecordIterator(Split split) throws Exception { + CloseableIterator removed = iterators.remove(split); + if (removed != null) { + removed.close(); + } + RecordReader reader = readers.remove(split); + if (reader != null) { + reader.close(); } + } + + /** + * Load splits from snapshot. + * + * @param snapshotId the snapshot id + * @return all splits + */ + private List loadSplitsFrom(long snapshotId) { + StreamTableScan streamTableScan = readBuilder.newStreamScan(); + streamTableScan.restore(snapshotId); + long start = System.currentTimeMillis(); + List splits = streamTableScan.plan().splits(); + LOGGER.debug( + "Load splits from snapshot: {}, cost: {}ms", + snapshotId, + System.currentTimeMillis() - start); + return splits; + } + + /** + * Get paimon catalog. + * + * @return the paimon catalog. + */ + private Catalog getPaimonCatalog() { + CatalogContext catalogContext; + if (StringUtils.isBlank(this.path)) { + if (StringUtils.isBlank(this.configJson)) { + catalogContext = Objects.requireNonNull(CatalogContext.create(new Options(options))); + } else { + org.apache.hadoop.conf.Configuration hadoopConf = + new org.apache.hadoop.conf.Configuration(); + for (Map.Entry entry : configs.entrySet()) { + hadoopConf.set(entry.getKey(), entry.getValue()); + } + catalogContext = + Objects.requireNonNull(CatalogContext.create(new Options(options), hadoopConf)); + } + } else { + catalogContext = Objects.requireNonNull(CatalogContext.create(new Path(path))); + } + return Objects.requireNonNull(CatalogFactory.createCatalog(catalogContext)); + } - @Override - public void open(RuntimeContext context) { - Catalog catalog = getPaimonCatalog(); - Identifier identifier = Identifier.create(database, table); + @Override + public void close() { + for (CloseableIterator reader : iterators.values()) { + if (reader != null) { try { - this.readBuilder = Objects.requireNonNull(catalog.getTable(identifier).newReadBuilder()); - if (tableConf.getString(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_SCAN_MODE) - .equalsIgnoreCase(StartupMode.FROM_SNAPSHOT.getValue()) - && tableConf.contains(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_SCAN_SNAPSHOT_ID)) { - this.fromSnapshot = tableConf.getLong(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_SCAN_SNAPSHOT_ID); - } else { - this.fromSnapshot = catalog.getTable(identifier).latestSnapshotId().orElse(1L); - } - LOGGER.info("New partition will start from snapshot: {}", this.fromSnapshot); - } catch (TableNotExistException e) { - throw new GeaFlowDSLException("Table: {} in db: {} not exists.", table, database); + reader.close(); + } catch (Exception e) { + throw new GeaFlowDSLException("Error occurs when close paimon iterator.", e); + } + } + } + for (RecordReader reader : readers.values()) { + if (reader != null) { + try { + reader.close(); + } catch (Exception e) { + throw new GeaFlowDSLException("Error occurs when close paimon reader.", e); } - this.iterators = new HashMap<>(); - this.readers = new HashMap<>(); - this.offsets = new HashMap<>(); - LOGGER.info("Open paimon source, tableConf: {}, tableSchema: {}, path: {}, options: " - + "{}, configs: {}, database: {}, tableName: {}", tableConf, tableSchema, path, - options, configs, database, table); - this.deserializer.init(tableSchema); + } + } + iterators.clear(); + offsets.clear(); + } + + public static class PaimonPartition implements Partition { + + private final String database; + private final String table; + + private int index; + private int parallel; + + // current snapshot id + private transient long currentSnapshot; + // assigned splits for this partition + private final transient List splits; + + public PaimonPartition(String database, String table) { + this.database = Objects.requireNonNull(database); + this.table = Objects.requireNonNull(table); + this.splits = new ArrayList<>(); + this.currentSnapshot = -1L; + this.index = 0; + this.parallel = 1; } - @Override - public TableDeserializer getDeserializer(Configuration conf) { - return null; + public void reset(List splits) { + this.splits.clear(); + for (Split split : splits) { + DataSplit dataSplit = (DataSplit) split; + int hash = Objects.hash(dataSplit.bucket(), dataSplit.partition()); + if (hash % parallel == index) { + this.splits.add(split); + this.currentSnapshot = dataSplit.snapshotId(); + } + } + if (!this.splits.isEmpty()) { + LOGGER.info( + "Assign paimon split(s) for table {}.{}, snapshot: {}, split size: {}", + database, + table, + this.currentSnapshot, + this.splits.size()); + } } @Override - public FetchData fetch(Partition partition, Optional startOffset, FetchWindow windowInfo) - throws IOException { - PaimonPartition paimonPartition = (PaimonPartition) partition; - assert paimonPartition.getDatabase().equals(this.database) - && paimonPartition.getTable().equals(this.table); - - long skip = 0; - PaimonOffset innerOffset = offsets.get(partition); - if (innerOffset == null) { - // first fetch, use custom specified snapshot id - innerOffset = new PaimonOffset(fromSnapshot); - } - - // if startOffset is specified, use it and try reset innerOffset - if ((startOffset.isPresent() && !startOffset.get().equals(innerOffset))) { - skip = Math.abs(startOffset.get().getOffset() - innerOffset.getOffset()); - innerOffset = PaimonOffset.from((PaimonOffset) startOffset.get()); - } - if (paimonPartition.getCurrentSnapshot() != innerOffset.getSnapshotId()) { - paimonPartition.reset(loadSplitsFrom(innerOffset.getSnapshotId())); - } - - Split split = paimonPartition.seek(innerOffset.getSplitIndex()); - if (split == null) { - if (sourceMode == SourceMode.BATCH) { - LOGGER.info("No more split to fetch"); - return FetchData.createBatchFetch(Collections.emptyIterator(), innerOffset); - } else { - LOGGER.debug("Snapshot {} not ready now", innerOffset.getSnapshotId()); - return FetchData.createStreamFetch(Collections.emptyList(), innerOffset, false); - } - } - offsets.put(paimonPartition, innerOffset); - - CloseableIterator iterator = createRecordIterator(split); - if (skip > 0) { - while (iterator.hasNext() && skip > 0) { - iterator.next(); - skip--; - } - } - switch (windowInfo.getType()) { - case ALL_WINDOW: - return FetchData.createBatchFetch(new IteratorWrapper(iterator, deserializer), new PaimonOffset()); - case SIZE_TUMBLING_WINDOW: - List readContents = new ArrayList<>(); - long advance = 0; - for (long i = 0; i < windowInfo.windowSize(); i++) { - if (iterator.hasNext()) { - advance++; - readContents.add(deserializer.deserialize(iterator.next())); - } else { - if (sourceMode == SourceMode.BATCH) { - break; - } - try { - removeRecordIterator(split); - } catch (Exception e) { - throw new RuntimeException(e); - } - innerOffset = innerOffset.nextSplit(); - Split seek = paimonPartition.seek(innerOffset.getSplitIndex()); - if (seek == null) { - // all split finished, try read next snapshot - innerOffset = innerOffset.nextSnapshot(); - paimonPartition.reset(loadSplitsFrom(innerOffset.getSnapshotId())); - } - offsets.put(paimonPartition, innerOffset); - advance = 0L; - seek = paimonPartition.seek(innerOffset.getSplitIndex()); - if (seek == null) { - // no new snapshot discovered, retry next turn - break; - } - iterator = createRecordIterator(seek); - i--; - } - } - PaimonOffset next = innerOffset.advance(advance); - offsets.put(paimonPartition, next); - boolean isFinished = sourceMode != SourceMode.STREAM && !iterator.hasNext(); - return FetchData.createStreamFetch(readContents, next, isFinished); - default: - throw new GeaFlowDSLException("Paimon not support window:{}", windowInfo.getType()); - } + public String getName() { + return database + "-" + table; } - /** - * Create a paimon record iterator. - * @param split the paimon data split - * @return a paimon record iterator - * @throws IOException if error occurs - */ - private CloseableIterator createRecordIterator(Split split) throws IOException { - CloseableIterator iterator = iterators.get(split); - if (iterator != null) { - return iterator; - } - RecordReader reader = readBuilder.newRead().createReader(split); - CloseableIterator closeableIterator = reader.toCloseableIterator(); - readers.put(split, reader); - iterators.put(split, closeableIterator); - return closeableIterator; + public String getDatabase() { + return database; } - /** - * Remove the paimon record iterator and reader. - * @param split the paimon data split - * @throws Exception if error occurs - */ - private void removeRecordIterator(Split split) throws Exception { - CloseableIterator removed = iterators.remove(split); - if (removed != null) { - removed.close(); - } - RecordReader reader = readers.remove(split); - if (reader != null) { - reader.close(); - } + public String getTable() { + return table; } - /** - * Load splits from snapshot. - * @param snapshotId the snapshot id - * @return all splits - */ - private List loadSplitsFrom(long snapshotId) { - StreamTableScan streamTableScan = readBuilder.newStreamScan(); - streamTableScan.restore(snapshotId); - long start = System.currentTimeMillis(); - List splits = streamTableScan.plan().splits(); - LOGGER.debug("Load splits from snapshot: {}, cost: {}ms", snapshotId, System.currentTimeMillis() - start); - return splits; + public long getCurrentSnapshot() { + return currentSnapshot; } /** - * Get paimon catalog. - * @return the paimon catalog. + * Seek the split by index. + * + * @param splitIndex the split index + * @return the split to read */ - private Catalog getPaimonCatalog() { - CatalogContext catalogContext; - if (StringUtils.isBlank(this.path)) { - if (StringUtils.isBlank(this.configJson)) { - catalogContext = - Objects.requireNonNull(CatalogContext.create(new Options(options))); - } else { - org.apache.hadoop.conf.Configuration hadoopConf = new org.apache.hadoop.conf.Configuration(); - for (Map.Entry entry : configs.entrySet()) { - hadoopConf.set(entry.getKey(), entry.getValue()); - } - catalogContext = - Objects.requireNonNull(CatalogContext.create(new Options(options), hadoopConf)); - } - } else { - catalogContext = Objects.requireNonNull(CatalogContext.create(new Path(path))); - } - return Objects.requireNonNull(CatalogFactory.createCatalog(catalogContext)); + public Split seek(int splitIndex) { + if (splitIndex >= splits.size()) { + return null; + } + return splits.get(splitIndex); } @Override - public void close() { - for (CloseableIterator reader : iterators.values()) { - if (reader != null) { - try { - reader.close(); - } catch (Exception e) { - throw new GeaFlowDSLException("Error occurs when close paimon iterator.", e); - } - } - } - for ( RecordReader reader : readers.values()) { - if (reader != null) { - try { - reader.close(); - } catch (Exception e) { - throw new GeaFlowDSLException("Error occurs when close paimon reader.", e); - } - } - } - iterators.clear(); - offsets.clear(); + public void setIndex(int index, int parallel) { + this.parallel = parallel; + this.index = index; } - public static class PaimonPartition implements Partition { - - private final String database; - private final String table; - - private int index; - private int parallel; - - // current snapshot id - private transient long currentSnapshot; - // assigned splits for this partition - private final transient List splits; - - public PaimonPartition(String database, String table) { - this.database = Objects.requireNonNull(database); - this.table = Objects.requireNonNull(table); - this.splits = new ArrayList<>(); - this.currentSnapshot = -1L; - this.index = 0; - this.parallel = 1; - } - - public void reset(List splits) { - this.splits.clear(); - for (Split split : splits) { - DataSplit dataSplit = (DataSplit) split; - int hash = Objects.hash(dataSplit.bucket(), dataSplit.partition()); - if (hash % parallel == index) { - this.splits.add(split); - this.currentSnapshot = dataSplit.snapshotId(); - } - } - if (!this.splits.isEmpty()) { - LOGGER.info("Assign paimon split(s) for table {}.{}, snapshot: {}, split size: {}", - database, table, this.currentSnapshot, this.splits.size()); - } - } - - @Override - public String getName() { - return database + "-" + table; - } - - public String getDatabase() { - return database; - } - - public String getTable() { - return table; - } - - public long getCurrentSnapshot() { - return currentSnapshot; - } - - /** - * Seek the split by index. - * @param splitIndex the split index - * @return the split to read - */ - public Split seek(int splitIndex) { - if (splitIndex >= splits.size()) { - return null; - } - return splits.get(splitIndex); - } - - @Override - public void setIndex(int index, int parallel) { - this.parallel = parallel; - this.index = index; - } - - @Override - public int hashCode() { - return Objects.hash(database, table, index); - } + @Override + public int hashCode() { + return Objects.hash(database, table, index); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof PaimonPartition)) { - return false; - } - PaimonPartition that = (PaimonPartition) o; - return Objects.equals(database, that.database) && Objects.equals( - table, that.table) && Objects.equals( - index, that.index); - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof PaimonPartition)) { + return false; + } + PaimonPartition that = (PaimonPartition) o; + return Objects.equals(database, that.database) + && Objects.equals(table, that.table) + && Objects.equals(index, that.index); } + } - public static class PaimonOffset implements Offset { + public static class PaimonOffset implements Offset { - // if partition snapshot id is not equal to current snapshot id, - // need to load splits - private final long snapshotId; - // the index of current partition assigned splits - private final int splitIndex; - // the offset of current split - private final long offset; + // if partition snapshot id is not equal to current snapshot id, + // need to load splits + private final long snapshotId; + // the index of current partition assigned splits + private final int splitIndex; + // the offset of current split + private final long offset; - public static PaimonOffset from(PaimonOffset offset) { - return new PaimonOffset(offset.snapshotId, offset.splitIndex, offset.offset); - } + public static PaimonOffset from(PaimonOffset offset) { + return new PaimonOffset(offset.snapshotId, offset.splitIndex, offset.offset); + } - public PaimonOffset() { - this(1L); - } + public PaimonOffset() { + this(1L); + } - public PaimonOffset(long snapshotId) { - this(snapshotId, 0, 0); - } + public PaimonOffset(long snapshotId) { + this(snapshotId, 0, 0); + } - public PaimonOffset(long snapshotId, int splitIndex, long offset) { - this.snapshotId = snapshotId; - this.splitIndex = splitIndex; - this.offset = offset; - } + public PaimonOffset(long snapshotId, int splitIndex, long offset) { + this.snapshotId = snapshotId; + this.splitIndex = splitIndex; + this.offset = offset; + } - @Override - public String humanReadable() { - return String.format("snapshot %d, split index %d, offset %d", snapshotId, splitIndex, offset); - } + @Override + public String humanReadable() { + return String.format( + "snapshot %d, split index %d, offset %d", snapshotId, splitIndex, offset); + } - public long getSnapshotId() { - return snapshotId; - } + public long getSnapshotId() { + return snapshotId; + } - public int getSplitIndex() { - return splitIndex; - } + public int getSplitIndex() { + return splitIndex; + } - @Override - public long getOffset() { - return offset; - } + @Override + public long getOffset() { + return offset; + } - public PaimonOffset advance(long rows) { - if (rows == 0) { - return this; - } - return new PaimonOffset(snapshotId, splitIndex, offset + rows); - } + public PaimonOffset advance(long rows) { + if (rows == 0) { + return this; + } + return new PaimonOffset(snapshotId, splitIndex, offset + rows); + } - public PaimonOffset nextSplit() { - return new PaimonOffset(snapshotId, splitIndex + 1, 0); - } + public PaimonOffset nextSplit() { + return new PaimonOffset(snapshotId, splitIndex + 1, 0); + } - public PaimonOffset nextSnapshot() { - return new PaimonOffset(snapshotId + 1, 0, 0); - } + public PaimonOffset nextSnapshot() { + return new PaimonOffset(snapshotId + 1, 0, 0); + } - @Override - public boolean isTimestamp() { - return false; - } + @Override + public boolean isTimestamp() { + return false; + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - PaimonOffset that = (PaimonOffset) o; - return snapshotId == that.snapshotId && splitIndex == that.splitIndex && offset == that.offset; - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PaimonOffset that = (PaimonOffset) o; + return snapshotId == that.snapshotId + && splitIndex == that.splitIndex + && offset == that.offset; + } - @Override - public int hashCode() { - return Objects.hash(snapshotId, splitIndex, offset); - } + @Override + public int hashCode() { + return Objects.hash(snapshotId, splitIndex, offset); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/SourceMode.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/SourceMode.java index 637de0b0e..f9a175c88 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/SourceMode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/SourceMode.java @@ -21,6 +21,6 @@ /** Paimon source mode. */ public enum SourceMode { - BATCH, - STREAM + BATCH, + STREAM } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/StartupMode.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/StartupMode.java index 56182aab7..b16804d14 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/StartupMode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/main/java/org/apache/geaflow/dsl/connector/paimon/StartupMode.java @@ -20,37 +20,33 @@ package org.apache.geaflow.dsl.connector.paimon; public enum StartupMode { - LATEST( - "latest", - "Read changes starting from the latest snapshot."), - - FROM_SNAPSHOT( - "from-snapshot", - "For streaming sources, continuously reads changes starting from snapshot " - + "specified by \"scan.snapshot-id\", without producing a snapshot at the beginning. " - + "For batch sources, produces a snapshot specified by \"scan.snapshot-id\" " - + "but does not read new changes."); - - private final String value; - private final String description; - - StartupMode(String value, String description) { - this.value = value; - this.description = description; - } - - public String getValue() { - return value; - } - - public String getDescription() { - return description; - } - - @Override - public String toString() { - return value; - } - - + LATEST("latest", "Read changes starting from the latest snapshot."), + + FROM_SNAPSHOT( + "from-snapshot", + "For streaming sources, continuously reads changes starting from snapshot " + + "specified by \"scan.snapshot-id\", without producing a snapshot at the beginning. " + + "For batch sources, produces a snapshot specified by \"scan.snapshot-id\" " + + "but does not read new changes."); + + private final String value; + private final String description; + + StartupMode(String value, String description) { + this.value = value; + this.description = description; + } + + public String getValue() { + return value; + } + + public String getDescription() { + return description; + } + + @Override + public String toString() { + return value; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/test/java/org/apache/geaflow/dsl/connector/paimon/PaimonTableConnectorTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/test/java/org/apache/geaflow/dsl/connector/paimon/PaimonTableConnectorTest.java index edd36ab51..93ba7cc8b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/test/java/org/apache/geaflow/dsl/connector/paimon/PaimonTableConnectorTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-paimon/src/test/java/org/apache/geaflow/dsl/connector/paimon/PaimonTableConnectorTest.java @@ -28,6 +28,7 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; + import org.apache.commons.io.FileUtils; import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.common.config.Configuration; @@ -72,316 +73,317 @@ public class PaimonTableConnectorTest { - String tmpDir = "/tmp/geaflow/dsl/paimon/test/"; - String db = "paimon_db"; - - GenericRow record1 = GenericRow.of(1, BinaryString.fromString("a1"), 10.0); - GenericRow record2 = GenericRow.of(2, BinaryString.fromString("ab"), 12.0); - GenericRow record3 = GenericRow.of(3, BinaryString.fromString("a3"), 12.0); - GenericRow record4 = GenericRow.of(4, BinaryString.fromString("bcd"), 15.0); - GenericRow record5 = GenericRow.of(5, BinaryString.fromString("a5"), 10.0); - GenericRow record6 = GenericRow.of(6, BinaryString.fromString("s1"), 9.0); - GenericRow record7 = GenericRow.of(7, BinaryString.fromString("sb"), 20.0); - GenericRow record8 = GenericRow.of(8, BinaryString.fromString("s3"), 16.0); - GenericRow record9 = GenericRow.of(9, BinaryString.fromString("bad"), 12.0); - GenericRow record10 = GenericRow.of(10, BinaryString.fromString("aa5"), 11.0); - GenericRow record11 = GenericRow.of(11, BinaryString.fromString("x11"), 11.2); - - private final StructType dataSchema = new StructType( - new TableField("id", Types.INTEGER, false), - new TableField("name", Types.BINARY_STRING), - new TableField("price", Types.DOUBLE) - ); - - private final StructType partitionSchema = new StructType( - new TableField("dt", Types.BINARY_STRING, false) - ); - - private final TableSchema tableSchema = new TableSchema(dataSchema, partitionSchema); - - - @BeforeTest - public void prepare() { - FileUtils.deleteQuietly(new File(tmpDir)); - } - - @AfterTest - public void clean() { - FileUtils.deleteQuietly(new File(tmpDir)); - } - - public void createSnapshot(String tableName, List rows) { - CatalogContext catalogContext = - Objects.requireNonNull(CatalogContext.create(new Path(tmpDir))); - Catalog catalog = Objects.requireNonNull(CatalogFactory.createCatalog(catalogContext)); - try { - catalog.createDatabase(db, true); - List dbs = catalog.listDatabases(); - assert dbs.get(0).equals(db); - Identifier identifier = new Identifier(db, tableName); - catalog.createTable(identifier, - Schema.newBuilder() - .column("id", new IntType()) - .column("name", new VarCharType(256)) - .column("price", new DoubleType()) - .build(), true); - List tables = catalog.listTables(dbs.get(0)); - assert tables.contains(tableName); - Table table = catalog.getTable(identifier); - BatchWriteBuilder writeBuilder = table.newBatchWriteBuilder(); - BatchTableWrite write = writeBuilder.newWrite(); - for (GenericRow row : rows) { - write.write(row); - } - List messages = write.prepareCommit(); - BatchTableCommit commit = writeBuilder.newCommit(); - commit.commit(messages); - } catch (Exception e) { - throw new GeaFlowDSLException("Test error.", e); - } + String tmpDir = "/tmp/geaflow/dsl/paimon/test/"; + String db = "paimon_db"; + + GenericRow record1 = GenericRow.of(1, BinaryString.fromString("a1"), 10.0); + GenericRow record2 = GenericRow.of(2, BinaryString.fromString("ab"), 12.0); + GenericRow record3 = GenericRow.of(3, BinaryString.fromString("a3"), 12.0); + GenericRow record4 = GenericRow.of(4, BinaryString.fromString("bcd"), 15.0); + GenericRow record5 = GenericRow.of(5, BinaryString.fromString("a5"), 10.0); + GenericRow record6 = GenericRow.of(6, BinaryString.fromString("s1"), 9.0); + GenericRow record7 = GenericRow.of(7, BinaryString.fromString("sb"), 20.0); + GenericRow record8 = GenericRow.of(8, BinaryString.fromString("s3"), 16.0); + GenericRow record9 = GenericRow.of(9, BinaryString.fromString("bad"), 12.0); + GenericRow record10 = GenericRow.of(10, BinaryString.fromString("aa5"), 11.0); + GenericRow record11 = GenericRow.of(11, BinaryString.fromString("x11"), 11.2); + + private final StructType dataSchema = + new StructType( + new TableField("id", Types.INTEGER, false), + new TableField("name", Types.BINARY_STRING), + new TableField("price", Types.DOUBLE)); + + private final StructType partitionSchema = + new StructType(new TableField("dt", Types.BINARY_STRING, false)); + + private final TableSchema tableSchema = new TableSchema(dataSchema, partitionSchema); + + @BeforeTest + public void prepare() { + FileUtils.deleteQuietly(new File(tmpDir)); + } + + @AfterTest + public void clean() { + FileUtils.deleteQuietly(new File(tmpDir)); + } + + public void createSnapshot(String tableName, List rows) { + CatalogContext catalogContext = Objects.requireNonNull(CatalogContext.create(new Path(tmpDir))); + Catalog catalog = Objects.requireNonNull(CatalogFactory.createCatalog(catalogContext)); + try { + catalog.createDatabase(db, true); + List dbs = catalog.listDatabases(); + assert dbs.get(0).equals(db); + Identifier identifier = new Identifier(db, tableName); + catalog.createTable( + identifier, + Schema.newBuilder() + .column("id", new IntType()) + .column("name", new VarCharType(256)) + .column("price", new DoubleType()) + .build(), + true); + List tables = catalog.listTables(dbs.get(0)); + assert tables.contains(tableName); + Table table = catalog.getTable(identifier); + BatchWriteBuilder writeBuilder = table.newBatchWriteBuilder(); + BatchTableWrite write = writeBuilder.newWrite(); + for (GenericRow row : rows) { + write.write(row); + } + List messages = write.prepareCommit(); + BatchTableCommit commit = writeBuilder.newCommit(); + commit.commit(messages); + } catch (Exception e) { + throw new GeaFlowDSLException("Test error.", e); } + } - @Test - public void testReadPaimonStreamMode() throws IOException { - String table = "paimon_stream_table"; + @Test + public void testReadPaimonStreamMode() throws IOException { + String table = "paimon_stream_table"; - Tuple tuple = createTableSource(table, true); - TableSource tableSource = tuple.f0; - Configuration tableConf = tuple.f1; + Tuple tuple = createTableSource(table, true); + TableSource tableSource = tuple.f0; + Configuration tableConf = tuple.f1; - tableSource.init(tableConf, tableSchema); + tableSource.init(tableConf, tableSchema); - // create snapshot 1 - createSnapshot(table, Lists.newArrayList(record1, record2, record3, record4, record5)); + // create snapshot 1 + createSnapshot(table, Lists.newArrayList(record1, record2, record3, record4, record5)); - tableSource.open(new DefaultRuntimeContext(tableConf)); + tableSource.open(new DefaultRuntimeContext(tableConf)); - List partitions = tableSource.listPartitions(1); + List partitions = tableSource.listPartitions(1); - Offset nextOffset = null; + Offset nextOffset = null; - List readRows = new ArrayList<>(); - for (Partition partition : partitions) { - FetchData rows = tableSource.fetch(partition, Optional.empty(), new SizeFetchWindow(0L, 4)); - while (rows.getDataIterator().hasNext()) { - readRows.add(rows.getDataIterator().next()); - } - Assert.assertFalse(rows.isFinish()); - nextOffset = rows.getNextOffset(); - } - Assert.assertEquals(StringUtils.join(readRows, "\n"), - "[1, a1, 10.0]\n" - + "[2, ab, 12.0]\n" - + "[3, a3, 12.0]\n" - + "[4, bcd, 15.0]"); - readRows.clear(); - - // create snapshot 2 - createSnapshot(table, Lists.newArrayList(record6, record7, record8, record9, record10)); - - for (Partition partition : partitions) { - FetchData rows = tableSource.fetch(partition, Optional.empty(), new SizeFetchWindow(1L, 4)); - while (rows.getDataIterator().hasNext()) { - readRows.add(rows.getDataIterator().next()); - } - Assert.assertFalse(rows.isFinish()); - } - Assert.assertEquals(StringUtils.join(readRows, "\n"), - "[5, a5, 10.0]\n" - + "[6, s1, 9.0]\n" - + "[7, sb, 20.0]\n" - + "[8, s3, 16.0]"); - readRows.clear(); - - for (Partition partition : partitions) { - FetchData rows = tableSource.fetch(partition, Optional.empty(), new SizeFetchWindow(2L, 4)); - while (rows.getDataIterator().hasNext()) { - readRows.add(rows.getDataIterator().next()); - } - Assert.assertFalse(rows.isFinish()); - } - Assert.assertEquals(StringUtils.join(readRows, "\n"), - "[9, bad, 12.0]\n" - + "[10, aa5, 11.0]"); - readRows.clear(); - - // no new snapshot - for (int i = 0; i < 2; i++) { - for (Partition partition : partitions) { - FetchData rows = tableSource.fetch(partition, Optional.empty(), new SizeFetchWindow(3L, 4)); - while (rows.getDataIterator().hasNext()) { - readRows.add(rows.getDataIterator().next()); - } - Assert.assertFalse(rows.isFinish()); - Assert.assertTrue(readRows.isEmpty()); - } - } - - // create snapshot 3 - createSnapshot(table, Lists.newArrayList(record11)); - for (Partition partition : partitions) { - FetchData rows = tableSource.fetch(partition, Optional.empty(), new SizeFetchWindow(3L, 4)); - while (rows.getDataIterator().hasNext()) { - readRows.add(rows.getDataIterator().next()); - } - Assert.assertFalse(rows.isFinish()); - Assert.assertFalse(readRows.isEmpty()); - } - Assert.assertEquals(StringUtils.join(readRows, "\n"), - "[11, x11, 11.2]"); - readRows.clear(); - - // test restore from offset - for (Partition partition : partitions) { - FetchData rows = tableSource.fetch(partition, Optional.of(nextOffset), new SizeFetchWindow(4L, 4)); - while (rows.getDataIterator().hasNext()) { - readRows.add(rows.getDataIterator().next()); - } - Assert.assertFalse(rows.isFinish()); - } - Assert.assertEquals(StringUtils.join(readRows, "\n"), - "[5, a5, 10.0]\n" - + "[6, s1, 9.0]\n" - + "[7, sb, 20.0]\n" - + "[8, s3, 16.0]"); + List readRows = new ArrayList<>(); + for (Partition partition : partitions) { + FetchData rows = + tableSource.fetch(partition, Optional.empty(), new SizeFetchWindow(0L, 4)); + while (rows.getDataIterator().hasNext()) { + readRows.add(rows.getDataIterator().next()); + } + Assert.assertFalse(rows.isFinish()); + nextOffset = rows.getNextOffset(); } - - @Test - public void testReadPaimonBatchMode() throws IOException { - String table = "paimon_batch_table"; - - Tuple tuple = createTableSource(table, false); - TableSource tableSource = tuple.f0; - Configuration tableConf = tuple.f1; - tableSource.init(tableConf, tableSchema); - - createSnapshot(table, Lists.newArrayList(record1, record2, record3, record4, record5)); - - tableSource.open(new DefaultRuntimeContext(tableConf)); - - List partitions = tableSource.listPartitions(1); - - List readRows = new ArrayList<>(); - for (Partition partition : partitions) { - FetchData rows = tableSource.fetch(partition, Optional.empty(), new SizeFetchWindow(0L, 4)); - while (rows.getDataIterator().hasNext()) { - readRows.add(rows.getDataIterator().next()); - } - Assert.assertFalse(rows.isFinish()); - } - Assert.assertEquals(StringUtils.join(readRows, "\n"), - "[1, a1, 10.0]\n" - + "[2, ab, 12.0]\n" - + "[3, a3, 12.0]\n" - + "[4, bcd, 15.0]"); - readRows.clear(); - - createSnapshot(table, Lists.newArrayList(record6, record7, record8, record9, record10)); - - for (Partition partition : partitions) { - FetchData rows = tableSource.fetch(partition, Optional.empty(), new SizeFetchWindow(1L, 4)); - while (rows.getDataIterator().hasNext()) { - readRows.add(rows.getDataIterator().next()); - } - Assert.assertTrue(rows.isFinish()); - } - Assert.assertEquals(StringUtils.join(readRows, "\n"), - "[5, a5, 10.0]"); - readRows.clear(); + Assert.assertEquals( + StringUtils.join(readRows, "\n"), + "[1, a1, 10.0]\n" + "[2, ab, 12.0]\n" + "[3, a3, 12.0]\n" + "[4, bcd, 15.0]"); + readRows.clear(); + + // create snapshot 2 + createSnapshot(table, Lists.newArrayList(record6, record7, record8, record9, record10)); + + for (Partition partition : partitions) { + FetchData rows = + tableSource.fetch(partition, Optional.empty(), new SizeFetchWindow(1L, 4)); + while (rows.getDataIterator().hasNext()) { + readRows.add(rows.getDataIterator().next()); + } + Assert.assertFalse(rows.isFinish()); } - - @Test - public void testReadPaimonFromSnapshot() throws IOException { - String table = "paimon_batch_table_2"; - - TableConnector tableConnector = ConnectorFactory.loadConnector("PAIMON"); - Assert.assertEquals(tableConnector.getType().toLowerCase(Locale.ROOT), "paimon"); - TableReadableConnector readableConnector = (TableReadableConnector) tableConnector; - Map tableConfMap = getTableConf(table).getConfigMap(); - tableConfMap.put(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_SOURCE_MODE.getKey(), SourceMode.BATCH.name()); - tableConfMap.put(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_SCAN_MODE.getKey(), "From-Snapshot"); - tableConfMap.put(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_SCAN_SNAPSHOT_ID.getKey(), "1"); - Configuration tableConf = new Configuration(tableConfMap); - TableSource tableSource = readableConnector.createSource(tableConf); - tableSource.init(tableConf, tableSchema); - - createSnapshot(table, Lists.newArrayList(record1, record2, record3, record4, record5)); - createSnapshot(table, Lists.newArrayList(record6, record7, record8, record9, record10)); - - tableSource.open(new DefaultRuntimeContext(tableConf)); - - List partitions = tableSource.listPartitions(1); - - List readRows = new ArrayList<>(); - for (Partition partition : partitions) { - FetchData rows = tableSource.fetch(partition, Optional.empty(), new AllFetchWindow(0L)); - while (rows.getDataIterator().hasNext()) { - readRows.add(rows.getDataIterator().next()); - } - Assert.assertTrue(rows.isFinish()); - } - Assert.assertEquals(StringUtils.join(readRows, "\n"), - "[1, a1, 10.0]\n" - + "[2, ab, 12.0]\n" - + "[3, a3, 12.0]\n" - + "[4, bcd, 15.0]\n" - + "[5, a5, 10.0]"); + Assert.assertEquals( + StringUtils.join(readRows, "\n"), + "[5, a5, 10.0]\n" + "[6, s1, 9.0]\n" + "[7, sb, 20.0]\n" + "[8, s3, 16.0]"); + readRows.clear(); + + for (Partition partition : partitions) { + FetchData rows = + tableSource.fetch(partition, Optional.empty(), new SizeFetchWindow(2L, 4)); + while (rows.getDataIterator().hasNext()) { + readRows.add(rows.getDataIterator().next()); + } + Assert.assertFalse(rows.isFinish()); } - - @Test - public void testReadPaimonFromLastesSnapshot() throws IOException { - String table = "paimon_batch_table_3"; - - TableConnector tableConnector = ConnectorFactory.loadConnector("PAIMON"); - Assert.assertEquals(tableConnector.getType().toLowerCase(Locale.ROOT), "paimon"); - TableReadableConnector readableConnector = (TableReadableConnector) tableConnector; - Configuration tableConf = getTableConf(table); - tableConf.put(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_SOURCE_MODE.getKey(), SourceMode.BATCH.name()); - TableSource tableSource = readableConnector.createSource(tableConf); - tableSource.init(tableConf, tableSchema); - - createSnapshot(table, Lists.newArrayList(record1, record2, record3, record4, record5)); - createSnapshot(table, Lists.newArrayList(record6, record7, record8, record9, record10)); - - tableSource.open(new DefaultRuntimeContext(tableConf)); - - List partitions = tableSource.listPartitions(1); - - List readRows = new ArrayList<>(); - for (Partition partition : partitions) { - FetchData rows = tableSource.fetch(partition, Optional.empty(), new AllFetchWindow(0L)); - while (rows.getDataIterator().hasNext()) { - readRows.add(rows.getDataIterator().next()); - } - Assert.assertTrue(rows.isFinish()); + Assert.assertEquals(StringUtils.join(readRows, "\n"), "[9, bad, 12.0]\n" + "[10, aa5, 11.0]"); + readRows.clear(); + + // no new snapshot + for (int i = 0; i < 2; i++) { + for (Partition partition : partitions) { + FetchData rows = + tableSource.fetch(partition, Optional.empty(), new SizeFetchWindow(3L, 4)); + while (rows.getDataIterator().hasNext()) { + readRows.add(rows.getDataIterator().next()); } - Assert.assertEquals(StringUtils.join(readRows, "\n"), - "[6, s1, 9.0]\n" - + "[7, sb, 20.0]\n" - + "[8, s3, 16.0]\n" - + "[9, bad, 12.0]\n" - + "[10, aa5, 11.0]"); + Assert.assertFalse(rows.isFinish()); + Assert.assertTrue(readRows.isEmpty()); + } } - private Configuration getTableConf(String table) { - Map tableConfMap = new HashMap<>(); - tableConfMap.put(ConnectorConfigKeys.GEAFLOW_DSL_FILE_PATH.getKey(), tmpDir); - tableConfMap.put(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_WAREHOUSE.getKey(), tmpDir); - tableConfMap.put(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_DATABASE_NAME.getKey(), db); - tableConfMap.put(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_TABLE_NAME.getKey(), table); - return new Configuration(tableConfMap); + // create snapshot 3 + createSnapshot(table, Lists.newArrayList(record11)); + for (Partition partition : partitions) { + FetchData rows = + tableSource.fetch(partition, Optional.empty(), new SizeFetchWindow(3L, 4)); + while (rows.getDataIterator().hasNext()) { + readRows.add(rows.getDataIterator().next()); + } + Assert.assertFalse(rows.isFinish()); + Assert.assertFalse(readRows.isEmpty()); } - - private Tuple createTableSource(String table, boolean streamMode) { - TableConnector tableConnector = ConnectorFactory.loadConnector("PAIMON"); - Assert.assertEquals(tableConnector.getType().toLowerCase(Locale.ROOT), "paimon"); - TableReadableConnector readableConnector = (TableReadableConnector) tableConnector; - Map tableConfMap = getTableConf(table).getConfigMap(); - if (streamMode) { - tableConfMap.put(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_SOURCE_MODE.getKey(), SourceMode.STREAM.name()); - } - Configuration tableConf = new Configuration(tableConfMap); - return Tuple.of(readableConnector.createSource(tableConf), tableConf); + Assert.assertEquals(StringUtils.join(readRows, "\n"), "[11, x11, 11.2]"); + readRows.clear(); + + // test restore from offset + for (Partition partition : partitions) { + FetchData rows = + tableSource.fetch(partition, Optional.of(nextOffset), new SizeFetchWindow(4L, 4)); + while (rows.getDataIterator().hasNext()) { + readRows.add(rows.getDataIterator().next()); + } + Assert.assertFalse(rows.isFinish()); } - + Assert.assertEquals( + StringUtils.join(readRows, "\n"), + "[5, a5, 10.0]\n" + "[6, s1, 9.0]\n" + "[7, sb, 20.0]\n" + "[8, s3, 16.0]"); + } + + @Test + public void testReadPaimonBatchMode() throws IOException { + String table = "paimon_batch_table"; + + Tuple tuple = createTableSource(table, false); + TableSource tableSource = tuple.f0; + Configuration tableConf = tuple.f1; + tableSource.init(tableConf, tableSchema); + + createSnapshot(table, Lists.newArrayList(record1, record2, record3, record4, record5)); + + tableSource.open(new DefaultRuntimeContext(tableConf)); + + List partitions = tableSource.listPartitions(1); + + List readRows = new ArrayList<>(); + for (Partition partition : partitions) { + FetchData rows = + tableSource.fetch(partition, Optional.empty(), new SizeFetchWindow(0L, 4)); + while (rows.getDataIterator().hasNext()) { + readRows.add(rows.getDataIterator().next()); + } + Assert.assertFalse(rows.isFinish()); + } + Assert.assertEquals( + StringUtils.join(readRows, "\n"), + "[1, a1, 10.0]\n" + "[2, ab, 12.0]\n" + "[3, a3, 12.0]\n" + "[4, bcd, 15.0]"); + readRows.clear(); + + createSnapshot(table, Lists.newArrayList(record6, record7, record8, record9, record10)); + + for (Partition partition : partitions) { + FetchData rows = + tableSource.fetch(partition, Optional.empty(), new SizeFetchWindow(1L, 4)); + while (rows.getDataIterator().hasNext()) { + readRows.add(rows.getDataIterator().next()); + } + Assert.assertTrue(rows.isFinish()); + } + Assert.assertEquals(StringUtils.join(readRows, "\n"), "[5, a5, 10.0]"); + readRows.clear(); + } + + @Test + public void testReadPaimonFromSnapshot() throws IOException { + String table = "paimon_batch_table_2"; + + TableConnector tableConnector = ConnectorFactory.loadConnector("PAIMON"); + Assert.assertEquals(tableConnector.getType().toLowerCase(Locale.ROOT), "paimon"); + TableReadableConnector readableConnector = (TableReadableConnector) tableConnector; + Map tableConfMap = getTableConf(table).getConfigMap(); + tableConfMap.put( + PaimonConfigKeys.GEAFLOW_DSL_PAIMON_SOURCE_MODE.getKey(), SourceMode.BATCH.name()); + tableConfMap.put(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_SCAN_MODE.getKey(), "From-Snapshot"); + tableConfMap.put(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_SCAN_SNAPSHOT_ID.getKey(), "1"); + Configuration tableConf = new Configuration(tableConfMap); + TableSource tableSource = readableConnector.createSource(tableConf); + tableSource.init(tableConf, tableSchema); + + createSnapshot(table, Lists.newArrayList(record1, record2, record3, record4, record5)); + createSnapshot(table, Lists.newArrayList(record6, record7, record8, record9, record10)); + + tableSource.open(new DefaultRuntimeContext(tableConf)); + + List partitions = tableSource.listPartitions(1); + + List readRows = new ArrayList<>(); + for (Partition partition : partitions) { + FetchData rows = + tableSource.fetch(partition, Optional.empty(), new AllFetchWindow(0L)); + while (rows.getDataIterator().hasNext()) { + readRows.add(rows.getDataIterator().next()); + } + Assert.assertTrue(rows.isFinish()); + } + Assert.assertEquals( + StringUtils.join(readRows, "\n"), + "[1, a1, 10.0]\n" + + "[2, ab, 12.0]\n" + + "[3, a3, 12.0]\n" + + "[4, bcd, 15.0]\n" + + "[5, a5, 10.0]"); + } + + @Test + public void testReadPaimonFromLastesSnapshot() throws IOException { + String table = "paimon_batch_table_3"; + + TableConnector tableConnector = ConnectorFactory.loadConnector("PAIMON"); + Assert.assertEquals(tableConnector.getType().toLowerCase(Locale.ROOT), "paimon"); + TableReadableConnector readableConnector = (TableReadableConnector) tableConnector; + Configuration tableConf = getTableConf(table); + tableConf.put( + PaimonConfigKeys.GEAFLOW_DSL_PAIMON_SOURCE_MODE.getKey(), SourceMode.BATCH.name()); + TableSource tableSource = readableConnector.createSource(tableConf); + tableSource.init(tableConf, tableSchema); + + createSnapshot(table, Lists.newArrayList(record1, record2, record3, record4, record5)); + createSnapshot(table, Lists.newArrayList(record6, record7, record8, record9, record10)); + + tableSource.open(new DefaultRuntimeContext(tableConf)); + + List partitions = tableSource.listPartitions(1); + + List readRows = new ArrayList<>(); + for (Partition partition : partitions) { + FetchData rows = + tableSource.fetch(partition, Optional.empty(), new AllFetchWindow(0L)); + while (rows.getDataIterator().hasNext()) { + readRows.add(rows.getDataIterator().next()); + } + Assert.assertTrue(rows.isFinish()); + } + Assert.assertEquals( + StringUtils.join(readRows, "\n"), + "[6, s1, 9.0]\n" + + "[7, sb, 20.0]\n" + + "[8, s3, 16.0]\n" + + "[9, bad, 12.0]\n" + + "[10, aa5, 11.0]"); + } + + private Configuration getTableConf(String table) { + Map tableConfMap = new HashMap<>(); + tableConfMap.put(ConnectorConfigKeys.GEAFLOW_DSL_FILE_PATH.getKey(), tmpDir); + tableConfMap.put(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_WAREHOUSE.getKey(), tmpDir); + tableConfMap.put(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_DATABASE_NAME.getKey(), db); + tableConfMap.put(PaimonConfigKeys.GEAFLOW_DSL_PAIMON_TABLE_NAME.getKey(), table); + return new Configuration(tableConfMap); + } + + private Tuple createTableSource(String table, boolean streamMode) { + TableConnector tableConnector = ConnectorFactory.loadConnector("PAIMON"); + Assert.assertEquals(tableConnector.getType().toLowerCase(Locale.ROOT), "paimon"); + TableReadableConnector readableConnector = (TableReadableConnector) tableConnector; + Map tableConfMap = getTableConf(table).getConfigMap(); + if (streamMode) { + tableConfMap.put( + PaimonConfigKeys.GEAFLOW_DSL_PAIMON_SOURCE_MODE.getKey(), SourceMode.STREAM.name()); + } + Configuration tableConf = new Configuration(tableConfMap); + return Tuple.of(readableConnector.createSource(tableConf), tableConf); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/main/java/org/apache/geaflow/dsl/connector/pulsar/PulsarConfigKeys.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/main/java/org/apache/geaflow/dsl/connector/pulsar/PulsarConfigKeys.java index b07ed5376..201383b2a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/main/java/org/apache/geaflow/dsl/connector/pulsar/PulsarConfigKeys.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/main/java/org/apache/geaflow/dsl/connector/pulsar/PulsarConfigKeys.java @@ -24,30 +24,28 @@ public class PulsarConfigKeys { - public static final ConfigKey GEAFLOW_DSL_PULSAR_SERVERS = ConfigKeys - .key("geaflow.dsl.pulsar.servers") - .noDefaultValue() - .description("The pulsar bootstrap servers list."); - - public static final ConfigKey GEAFLOW_DSL_PULSAR_PORT = ConfigKeys - .key("geaflow.dsl.pulsar.port") - .noDefaultValue() - .description("The pulsar bootstrap servers list."); - - public static final ConfigKey GEAFLOW_DSL_PULSAR_TOPIC = ConfigKeys - .key("geaflow.dsl.pulsar.topic") - .noDefaultValue() - .description("The pulsar topic."); - - public static final ConfigKey GEAFLOW_DSL_PULSAR_SUBSCRIBE_NAME = ConfigKeys - .key("geaflow.dsl.pulsar.subscribeName") - .defaultValue("default-subscribeName") - .description("The pulsar subscribeName, default is 'default-subscribeName'."); - - public static final ConfigKey GEAFLOW_DSL_PULSAR_SUBSCRIBE_INITIAL_POSITION = ConfigKeys - .key("geaflow.dsl.pulsar.subscriptionInitialPosition") - .defaultValue("latest") - .description("The pulsar subscriptionInitialPosition, default is 'default-subscriptionInitialPosition'."); - - + public static final ConfigKey GEAFLOW_DSL_PULSAR_SERVERS = + ConfigKeys.key("geaflow.dsl.pulsar.servers") + .noDefaultValue() + .description("The pulsar bootstrap servers list."); + + public static final ConfigKey GEAFLOW_DSL_PULSAR_PORT = + ConfigKeys.key("geaflow.dsl.pulsar.port") + .noDefaultValue() + .description("The pulsar bootstrap servers list."); + + public static final ConfigKey GEAFLOW_DSL_PULSAR_TOPIC = + ConfigKeys.key("geaflow.dsl.pulsar.topic").noDefaultValue().description("The pulsar topic."); + + public static final ConfigKey GEAFLOW_DSL_PULSAR_SUBSCRIBE_NAME = + ConfigKeys.key("geaflow.dsl.pulsar.subscribeName") + .defaultValue("default-subscribeName") + .description("The pulsar subscribeName, default is 'default-subscribeName'."); + + public static final ConfigKey GEAFLOW_DSL_PULSAR_SUBSCRIBE_INITIAL_POSITION = + ConfigKeys.key("geaflow.dsl.pulsar.subscriptionInitialPosition") + .defaultValue("latest") + .description( + "The pulsar subscriptionInitialPosition, default is" + + " 'default-subscriptionInitialPosition'."); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/main/java/org/apache/geaflow/dsl/connector/pulsar/PulsarTableConnector.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/main/java/org/apache/geaflow/dsl/connector/pulsar/PulsarTableConnector.java index 1780e26e6..2658b109c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/main/java/org/apache/geaflow/dsl/connector/pulsar/PulsarTableConnector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/main/java/org/apache/geaflow/dsl/connector/pulsar/PulsarTableConnector.java @@ -27,20 +27,20 @@ public class PulsarTableConnector implements TableReadableConnector, TableWritableConnector { - public static final String TYPE = "PULSAR"; + public static final String TYPE = "PULSAR"; - @Override - public TableSource createSource(Configuration conf) { - return new PulsarTableSource(); - } + @Override + public TableSource createSource(Configuration conf) { + return new PulsarTableSource(); + } - @Override - public TableSink createSink(Configuration conf) { - return new PulsarTableSink(); - } + @Override + public TableSink createSink(Configuration conf) { + return new PulsarTableSink(); + } - @Override - public String getType() { - return TYPE; - } + @Override + public String getType() { + return TYPE; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/main/java/org/apache/geaflow/dsl/connector/pulsar/PulsarTableSink.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/main/java/org/apache/geaflow/dsl/connector/pulsar/PulsarTableSink.java index c05ca9697..f426f2a46 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/main/java/org/apache/geaflow/dsl/connector/pulsar/PulsarTableSink.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/main/java/org/apache/geaflow/dsl/connector/pulsar/PulsarTableSink.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ConnectorConfigKeys; @@ -37,104 +38,103 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; - public class PulsarTableSink implements TableSink { - private static final Logger LOGGER = LoggerFactory.getLogger(TableSink.class); - - private Configuration tableConf; - private StructType schema; - private String servers; - private String topic; - private int maxPendingMessage; - private int maxMessages; - private int maxPublishDelay; - - private transient PulsarClient pulsarClient; - private transient Producer producer; - - private MessageRoutingMode messageRoutingMode; - private String separator; - - private void createPulsarProducer() { - - if (messageRoutingMode == null) { - messageRoutingMode = MessageRoutingMode.SinglePartition; - } - try { - pulsarClient = PulsarClient.builder().serviceUrl(servers).build(); - producer = pulsarClient.newProducer(Schema.STRING) - .topic(topic) - .maxPendingMessages(maxPendingMessage) - .messageRoutingMode(messageRoutingMode) - .batchingMaxMessages(maxMessages) - .batchingMaxPublishDelay(maxPublishDelay, TimeUnit.MILLISECONDS) - .create(); - } catch (PulsarClientException e) { - throw new GeaFlowDSLException("create pulsar producer error, exception is {}", e); - } - } + private static final Logger LOGGER = LoggerFactory.getLogger(TableSink.class); - @Override - public void init(Configuration conf, StructType schema) { - tableConf = conf; - String port = conf.getString(PulsarConfigKeys.GEAFLOW_DSL_PULSAR_PORT); - String[] serversAddress = conf.getString(PulsarConfigKeys.GEAFLOW_DSL_PULSAR_SERVERS).split(","); - servers = "pulsar://" + String.join(":" + port + ",", serversAddress); - - topic = conf.getString(PulsarConfigKeys.GEAFLOW_DSL_PULSAR_TOPIC); - this.schema = schema; - separator = tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_COLUMN_SEPARATOR); - maxPendingMessage = PulsarConstants.PULSAR_MAX_PENDING_MESSAGES; - maxMessages = PulsarConstants.PULSAR_BATCHING_MAX_MESSAGES; - maxPublishDelay = PulsarConstants.PULSAR_BATCHING_MAX_PUBLISH_DELAY; - } + private Configuration tableConf; + private StructType schema; + private String servers; + private String topic; + private int maxPendingMessage; + private int maxMessages; + private int maxPublishDelay; - @Override - public void open(RuntimeContext context) { - createPulsarProducer(); - } + private transient PulsarClient pulsarClient; + private transient Producer producer; - @Override - public void write(Row row) { - Object[] values = row.getFields(schema.getTypes()); - StringBuilder line = new StringBuilder(); - for (Object value : values) { - if (line.length() > 0) { - line.append(separator); - } - line.append(value); - } - try { - producer.send(line.toString()); - } catch (PulsarClientException e) { - throw new GeaFlowDSLException("pulsar producer send message error, exception is {}", e); - } + private MessageRoutingMode messageRoutingMode; + private String separator; - } + private void createPulsarProducer() { - @Override - public void finish() throws IOException { - if (producer != null) { - producer.flush(); - } else { - assert producer != null; - LOGGER.warn("Producer is null."); - } + if (messageRoutingMode == null) { + messageRoutingMode = MessageRoutingMode.SinglePartition; } - - @Override - public void close() { - if (producer != null) { - try { - producer.close(); - pulsarClient.close(); - } catch (PulsarClientException e) { - throw new GeaFlowDSLException("pulsar client close error, exception is {}", e); - } - } - LOGGER.info("close pulsar client"); + try { + pulsarClient = PulsarClient.builder().serviceUrl(servers).build(); + producer = + pulsarClient + .newProducer(Schema.STRING) + .topic(topic) + .maxPendingMessages(maxPendingMessage) + .messageRoutingMode(messageRoutingMode) + .batchingMaxMessages(maxMessages) + .batchingMaxPublishDelay(maxPublishDelay, TimeUnit.MILLISECONDS) + .create(); + } catch (PulsarClientException e) { + throw new GeaFlowDSLException("create pulsar producer error, exception is {}", e); } - - + } + + @Override + public void init(Configuration conf, StructType schema) { + tableConf = conf; + String port = conf.getString(PulsarConfigKeys.GEAFLOW_DSL_PULSAR_PORT); + String[] serversAddress = + conf.getString(PulsarConfigKeys.GEAFLOW_DSL_PULSAR_SERVERS).split(","); + servers = "pulsar://" + String.join(":" + port + ",", serversAddress); + + topic = conf.getString(PulsarConfigKeys.GEAFLOW_DSL_PULSAR_TOPIC); + this.schema = schema; + separator = tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_COLUMN_SEPARATOR); + maxPendingMessage = PulsarConstants.PULSAR_MAX_PENDING_MESSAGES; + maxMessages = PulsarConstants.PULSAR_BATCHING_MAX_MESSAGES; + maxPublishDelay = PulsarConstants.PULSAR_BATCHING_MAX_PUBLISH_DELAY; + } + + @Override + public void open(RuntimeContext context) { + createPulsarProducer(); + } + + @Override + public void write(Row row) { + Object[] values = row.getFields(schema.getTypes()); + StringBuilder line = new StringBuilder(); + for (Object value : values) { + if (line.length() > 0) { + line.append(separator); + } + line.append(value); + } + try { + producer.send(line.toString()); + } catch (PulsarClientException e) { + throw new GeaFlowDSLException("pulsar producer send message error, exception is {}", e); + } + } + + @Override + public void finish() throws IOException { + if (producer != null) { + producer.flush(); + } else { + assert producer != null; + LOGGER.warn("Producer is null."); + } + } + + @Override + public void close() { + if (producer != null) { + try { + producer.close(); + pulsarClient.close(); + } catch (PulsarClientException e) { + throw new GeaFlowDSLException("pulsar client close error, exception is {}", e); + } + } + LOGGER.info("close pulsar client"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/main/java/org/apache/geaflow/dsl/connector/pulsar/PulsarTableSource.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/main/java/org/apache/geaflow/dsl/connector/pulsar/PulsarTableSource.java index 5615a575d..b05e0a991 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/main/java/org/apache/geaflow/dsl/connector/pulsar/PulsarTableSource.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/main/java/org/apache/geaflow/dsl/connector/pulsar/PulsarTableSource.java @@ -30,6 +30,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.window.WindowType; import org.apache.geaflow.common.config.Configuration; @@ -61,240 +62,242 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; - public class PulsarTableSource implements TableSource { - private static final Logger LOGGER = LoggerFactory.getLogger(PulsarTableSource.class); - private Configuration tableConf; - private String servers; - private String topic; - private SubscriptionType subscribeType; - private int receiverQueueSize; - private int negativeAckRedeliveryDelay; - private TimeUnit timeUnit; - private String subscriptionName; - private SubscriptionInitialPosition subscriptionInitialPosition; - private long windowSize; + private static final Logger LOGGER = LoggerFactory.getLogger(PulsarTableSource.class); + private Configuration tableConf; + private String servers; + private String topic; + private SubscriptionType subscribeType; + private int receiverQueueSize; + private int negativeAckRedeliveryDelay; + private TimeUnit timeUnit; + private String subscriptionName; + private SubscriptionInitialPosition subscriptionInitialPosition; + private long windowSize; + + private transient PulsarClient pulsarClient; + + private transient Map> consumers; + + @Override + public void init(Configuration conf, TableSchema tableSchema) { + this.tableConf = conf; + String port = conf.getString(PulsarConfigKeys.GEAFLOW_DSL_PULSAR_PORT); + String[] serversAddress = + conf.getString(PulsarConfigKeys.GEAFLOW_DSL_PULSAR_SERVERS).split(","); + servers = "pulsar://" + String.join(":" + port + ",", serversAddress); + topic = conf.getString(PulsarConfigKeys.GEAFLOW_DSL_PULSAR_TOPIC); + subscriptionName = conf.getString(PulsarConfigKeys.GEAFLOW_DSL_PULSAR_SUBSCRIBE_NAME); + + String position = + conf.getString(PulsarConfigKeys.GEAFLOW_DSL_PULSAR_SUBSCRIBE_INITIAL_POSITION); + if (position.equals("earliest")) { + this.subscriptionInitialPosition = SubscriptionInitialPosition.Earliest; + } else if (position.equals("latest")) { + this.subscriptionInitialPosition = SubscriptionInitialPosition.Latest; + } else { + throw new GeaFlowDSLException("Invalid subscription initial position: {}", position); + } + subscribeType = PulsarConstants.PULSAR_SUBSCRIBE_TYPE; + negativeAckRedeliveryDelay = PulsarConstants.PULSAR_NEGATIVE_ACK_REDELIVERY; + timeUnit = PulsarConstants.PULSAR_NEGATIVE_ACK_REDELIVERY_UNIT; + receiverQueueSize = PulsarConstants.PULSAR_RECEIVER_QUEUE_SIZE; + + this.windowSize = conf.getLong(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE); + if (this.windowSize == Windows.SIZE_OF_ALL_WINDOW) { + throw new GeaFlowDSLException("Pulsar cannot support all window"); + } else if (windowSize <= 0) { + throw new GeaFlowDSLException("Invalid window size: {}", windowSize); + } + } + + private void createPulsarClient() { + try { + this.pulsarClient = PulsarClient.builder().serviceUrl(this.servers).build(); + } catch (PulsarClientException e) { + throw new GeaFlowDSLException(" fail to create pulsar client, exception is {}", e); + } + } + + private Consumer createPulsarConsumer(String partitionName) { + Consumer consumer; + try { + consumer = + pulsarClient + .newConsumer(Schema.STRING) + .topic(partitionName) + .subscriptionName(subscriptionName) + .subscriptionType(subscribeType) + .subscriptionInitialPosition(subscriptionInitialPosition) + .negativeAckRedeliveryDelay(negativeAckRedeliveryDelay, timeUnit) + .batchReceivePolicy( + new BatchReceivePolicy.Builder().maxNumMessages((int) windowSize).build()) + .receiverQueueSize(receiverQueueSize) + .subscribe(); + } catch (PulsarClientException e) { + throw new GeaFlowDSLException( + "fail to create pulsar consumer, topic name is {}", partitionName); + } + return consumer; + } + + @Override + public void open(RuntimeContext context) { + createPulsarClient(); + consumers = new HashMap<>(); + LOGGER.info("pulsar client created successfully"); + } + + @Override + public List listPartitions() { + List partitionNameList; + try { + partitionNameList = pulsarClient.getPartitionsForTopic(topic).get(); + } catch (InterruptedException | ExecutionException e) { + throw new GeaFlowDSLException("get partitions for topic fail, the topic is {}", topic); + } + if (partitionNameList == null) { + throw new GeaFlowDSLException( + "Obtain an empty partition list through pulsarClient, the topic name is", topic); + } + return partitionNameList.stream().map(PulsarPartition::new).collect(Collectors.toList()); + } - private transient PulsarClient pulsarClient; + @Override + public TableDeserializer getDeserializer(Configuration conf) { + return DeserializerFactory.loadTextDeserializer(); + } - private transient Map> consumers; + @Override + public FetchData fetch( + Partition partition, Optional startOffset, FetchWindow windowInfo) + throws IOException { - @Override - public void init(Configuration conf, TableSchema tableSchema) { - this.tableConf = conf; - String port = conf.getString(PulsarConfigKeys.GEAFLOW_DSL_PULSAR_PORT); - String[] serversAddress = conf.getString(PulsarConfigKeys.GEAFLOW_DSL_PULSAR_SERVERS).split(","); - servers = "pulsar://" + String.join(":" + port + ",", serversAddress); - topic = conf.getString(PulsarConfigKeys.GEAFLOW_DSL_PULSAR_TOPIC); - subscriptionName = conf.getString(PulsarConfigKeys.GEAFLOW_DSL_PULSAR_SUBSCRIBE_NAME); - - String position = conf.getString(PulsarConfigKeys.GEAFLOW_DSL_PULSAR_SUBSCRIBE_INITIAL_POSITION); - if (position.equals("earliest")) { - this.subscriptionInitialPosition = SubscriptionInitialPosition.Earliest; - } else if (position.equals("latest")) { - this.subscriptionInitialPosition = SubscriptionInitialPosition.Latest; - } else { - throw new GeaFlowDSLException("Invalid subscription initial position: {}", position); - } - subscribeType = PulsarConstants.PULSAR_SUBSCRIBE_TYPE; - negativeAckRedeliveryDelay = PulsarConstants.PULSAR_NEGATIVE_ACK_REDELIVERY; - timeUnit = PulsarConstants.PULSAR_NEGATIVE_ACK_REDELIVERY_UNIT; - receiverQueueSize = PulsarConstants.PULSAR_RECEIVER_QUEUE_SIZE; - - this.windowSize = conf.getLong(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE); - if (this.windowSize == Windows.SIZE_OF_ALL_WINDOW) { - throw new GeaFlowDSLException("Pulsar cannot support all window"); - } else if (windowSize <= 0) { - throw new GeaFlowDSLException("Invalid window size: {}", windowSize); - } + if (windowInfo.getType() != WindowType.SIZE_TUMBLING_WINDOW) { + throw new GeaFlowDSLException("Pulsar cannot support window type:{}", windowInfo.getType()); + } + if (windowInfo.windowSize() <= 0) { + throw new GeaFlowDSLException("Invalid window size: {}", windowInfo.windowSize()); + } + windowSize = windowInfo.windowSize(); + String partitionName = partition.getName(); + Consumer consumer = consumers.get(partitionName); + if (consumer == null) { + consumer = createPulsarConsumer(partitionName); + consumers.put(partitionName, consumer); + } + assert consumer != null; + PulsarOffset pulsarOffset; + + boolean isTimeStamp = false; + if (startOffset.isPresent()) { + pulsarOffset = (PulsarOffset) startOffset.get(); + if (pulsarOffset.isTimestamp()) { + consumer.seek(pulsarOffset.getOffset()); + isTimeStamp = true; + } else { + consumer.seek(pulsarOffset.getMessageId()); + } } - private void createPulsarClient() { - try { - this.pulsarClient = PulsarClient.builder().serviceUrl(this.servers).build(); - } catch (PulsarClientException e) { - throw new GeaFlowDSLException(" fail to create pulsar client, exception is {}", e); - } + List dataList = new ArrayList<>(); + Messages messages = consumer.batchReceive(); + Iterator> iterator = messages.iterator(); + long timeOffset = 0L; + MessageId lastMessageId = MessageId.earliest; + while (iterator.hasNext()) { + Message message = iterator.next(); + dataList.add(message.getValue()); + timeOffset = message.getPublishTime(); + lastMessageId = message.getMessageId(); + LOGGER.info("receive message: " + message.getValue()); + LOGGER.info("object address is: " + this.hashCode()); } - private Consumer createPulsarConsumer(String partitionName) { - Consumer consumer; - try { - consumer = pulsarClient.newConsumer(Schema.STRING) - .topic(partitionName) - .subscriptionName(subscriptionName) - .subscriptionType(subscribeType) - .subscriptionInitialPosition(subscriptionInitialPosition) - .negativeAckRedeliveryDelay(negativeAckRedeliveryDelay, timeUnit) - .batchReceivePolicy(new BatchReceivePolicy.Builder().maxNumMessages((int) windowSize).build()) - .receiverQueueSize(receiverQueueSize) - .subscribe(); - } catch (PulsarClientException e) { - throw new GeaFlowDSLException("fail to create pulsar consumer, topic name is {}", partitionName); - } - return consumer; + PulsarOffset newOffset; + if (isTimeStamp) { + newOffset = new PulsarOffset(timeOffset); + } else { + newOffset = new PulsarOffset(lastMessageId); } + TopicName.get(topic).getPartitionedTopicName(); - @Override - public void open(RuntimeContext context) { - createPulsarClient(); - consumers = new HashMap<>(); - LOGGER.info("pulsar client created successfully"); + return (FetchData) FetchData.createStreamFetch(dataList, newOffset, false); + } + + @Override + public void close() { + try { + pulsarClient.close(); + } catch (PulsarClientException e) { + throw new GeaFlowDSLException("fail to close pulsar client, the exception is {}", e); } + LOGGER.info("close pulsar client"); + } - @Override - public List listPartitions() { - List partitionNameList; - try { - partitionNameList = pulsarClient.getPartitionsForTopic(topic).get(); - } catch (InterruptedException | ExecutionException e) { - throw new GeaFlowDSLException("get partitions for topic fail, the topic is {}", topic); - } - if (partitionNameList == null) { - throw new GeaFlowDSLException("Obtain an empty partition list through pulsarClient, the topic name is", - topic); - } - return partitionNameList.stream().map(PulsarPartition::new) - .collect(Collectors.toList()); + public static class PulsarPartition implements Partition { + + private final String topicWithPartition; + + public PulsarPartition(String topicWithPartition) { + this.topicWithPartition = topicWithPartition; } @Override - public TableDeserializer getDeserializer(Configuration conf) { - return DeserializerFactory.loadTextDeserializer(); + public String getName() { + return topicWithPartition; } @Override - public FetchData fetch(Partition partition, Optional startOffset, - FetchWindow windowInfo) throws IOException { - - if (windowInfo.getType() != WindowType.SIZE_TUMBLING_WINDOW) { - throw new GeaFlowDSLException("Pulsar cannot support window type:{}", windowInfo.getType()); - } - if (windowInfo.windowSize() <= 0) { - throw new GeaFlowDSLException("Invalid window size: {}", windowInfo.windowSize()); - } - windowSize = windowInfo.windowSize(); - String partitionName = partition.getName(); - Consumer consumer = consumers.get(partitionName); - if (consumer == null) { - consumer = createPulsarConsumer(partitionName); - consumers.put(partitionName, consumer); - } - assert consumer != null; - PulsarOffset pulsarOffset; - - boolean isTimeStamp = false; - if (startOffset.isPresent()) { - pulsarOffset = (PulsarOffset) startOffset.get(); - if (pulsarOffset.isTimestamp()) { - consumer.seek(pulsarOffset.getOffset()); - isTimeStamp = true; - } else { - consumer.seek(pulsarOffset.getMessageId()); - } - } - - List dataList = new ArrayList<>(); - Messages messages = consumer.batchReceive(); - Iterator> iterator = messages.iterator(); - long timeOffset = 0L; - MessageId lastMessageId = MessageId.earliest; - while (iterator.hasNext()) { - Message message = iterator.next(); - dataList.add(message.getValue()); - timeOffset = message.getPublishTime(); - lastMessageId = message.getMessageId(); - LOGGER.info("receive message: " + message.getValue()); - LOGGER.info("object address is: " + this.hashCode()); - - } - - PulsarOffset newOffset; - if (isTimeStamp) { - newOffset = new PulsarOffset(timeOffset); - } else { - newOffset = new PulsarOffset(lastMessageId); - } - TopicName.get(topic).getPartitionedTopicName(); - - return (FetchData) FetchData.createStreamFetch(dataList, newOffset, false); + public int hashCode() { + return Objects.hash(topicWithPartition); } @Override - public void close() { - try { - pulsarClient.close(); - } catch (PulsarClientException e) { - throw new GeaFlowDSLException("fail to close pulsar client, the exception is {}", e); - } - LOGGER.info("close pulsar client"); + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof PulsarPartition)) { + return false; + } + PulsarPartition that = (PulsarPartition) o; + return Objects.equals(topicWithPartition, that.topicWithPartition); + } + } + + public static class PulsarOffset implements Offset { + private final MessageId messageId; + private final long timeStamp; + + public PulsarOffset(MessageId messageId) { + this.messageId = messageId; + timeStamp = 0L; + } + public PulsarOffset(long timeStamp) { + this.messageId = null; + this.timeStamp = timeStamp; } - public static class PulsarPartition implements Partition { - - private final String topicWithPartition; - - public PulsarPartition(String topicWithPartition) { - this.topicWithPartition = topicWithPartition; - } - - @Override - public String getName() { - return topicWithPartition; - } - - @Override - public int hashCode() { - return Objects.hash(topicWithPartition); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof PulsarPartition)) { - return false; - } - PulsarPartition that = (PulsarPartition) o; - return Objects.equals(topicWithPartition, that.topicWithPartition); - } + @Override + public String humanReadable() { + return DateTimeUtil.fromUnixTime(timeStamp, ConnectorConstants.START_TIME_FORMAT); } - public static class PulsarOffset implements Offset { - private final MessageId messageId; - private final long timeStamp; - - public PulsarOffset(MessageId messageId) { - this.messageId = messageId; - timeStamp = 0L; - } - - public PulsarOffset(long timeStamp) { - this.messageId = null; - this.timeStamp = timeStamp; - } - - @Override - public String humanReadable() { - return DateTimeUtil.fromUnixTime(timeStamp, ConnectorConstants.START_TIME_FORMAT); - } - - @Override - public long getOffset() { - return timeStamp; - } - - @Override - public boolean isTimestamp() { - return messageId == null; - } - - public MessageId getMessageId() { - return messageId; - } + @Override + public long getOffset() { + return timeStamp; } + @Override + public boolean isTimestamp() { + return messageId == null; + } + + public MessageId getMessageId() { + return messageId; + } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/main/java/org/apache/geaflow/dsl/connector/pulsar/utils/PulsarConstants.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/main/java/org/apache/geaflow/dsl/connector/pulsar/utils/PulsarConstants.java index 54a9621fc..826163c84 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/main/java/org/apache/geaflow/dsl/connector/pulsar/utils/PulsarConstants.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/main/java/org/apache/geaflow/dsl/connector/pulsar/utils/PulsarConstants.java @@ -20,18 +20,17 @@ package org.apache.geaflow.dsl.connector.pulsar.utils; import java.util.concurrent.TimeUnit; + import org.apache.pulsar.client.api.SubscriptionType; public class PulsarConstants { - public static final SubscriptionType PULSAR_SUBSCRIBE_TYPE = SubscriptionType.Exclusive; - public static final int PULSAR_NEGATIVE_ACK_REDELIVERY = 10; - public static final TimeUnit PULSAR_NEGATIVE_ACK_REDELIVERY_UNIT = TimeUnit.SECONDS; - public static final int PULSAR_MAX_REDELIVER_COUNT = 10; - public static final int PULSAR_RECEIVER_QUEUE_SIZE = 1000; - public static final int PULSAR_MAX_PENDING_MESSAGES = 1000; - public static final int PULSAR_BATCHING_MAX_MESSAGES = 1000; - public static final int PULSAR_BATCHING_MAX_PUBLISH_DELAY = 10; - - + public static final SubscriptionType PULSAR_SUBSCRIBE_TYPE = SubscriptionType.Exclusive; + public static final int PULSAR_NEGATIVE_ACK_REDELIVERY = 10; + public static final TimeUnit PULSAR_NEGATIVE_ACK_REDELIVERY_UNIT = TimeUnit.SECONDS; + public static final int PULSAR_MAX_REDELIVER_COUNT = 10; + public static final int PULSAR_RECEIVER_QUEUE_SIZE = 1000; + public static final int PULSAR_MAX_PENDING_MESSAGES = 1000; + public static final int PULSAR_BATCHING_MAX_MESSAGES = 1000; + public static final int PULSAR_BATCHING_MAX_PUBLISH_DELAY = 10; } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/test/java/org/apache/geaflow/dsl/connector/pulsar/PulsarTableConnectorTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/test/java/org/apache/geaflow/dsl/connector/pulsar/PulsarTableConnectorTest.java index 53b838a95..3f9c45260 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/test/java/org/apache/geaflow/dsl/connector/pulsar/PulsarTableConnectorTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-pulsar/src/test/java/org/apache/geaflow/dsl/connector/pulsar/PulsarTableConnectorTest.java @@ -27,34 +27,32 @@ import org.junit.Test; public class PulsarTableConnectorTest { - public static final String server = "pulsar://localhost:6650"; - public static final String topic = "persistent://test/test_pulsar_connector/non_partition_topic"; - public static final String partitionTopic = "persistent://test/test_pulsar_connector/partition_topic"; - - - @Test - public void testPulsarPartition() { - PulsarPartition partition = new PulsarPartition("topic"); - Assert.assertEquals(partition.getName(), "topic"); - - PulsarPartition partition2 = new PulsarPartition("topic"); - Assert.assertEquals(partition.hashCode(), partition2.hashCode()); - Assert.assertEquals(partition, partition); - Assert.assertEquals(partition, partition2); - Assert.assertNotEquals(partition, null); - } - - @Test - public void testPulsarOffset() { - PulsarOffset offsetByMessageId = new PulsarOffset(MessageId.earliest); - Assert.assertEquals(offsetByMessageId.getMessageId(), - DefaultImplementation.newMessageId(-1L, -1L, -1)); - Assert.assertEquals(offsetByMessageId.getOffset(), 0L); - - PulsarOffset offsetByTimeStamp = new PulsarOffset(11111111L); - Assert.assertEquals(offsetByTimeStamp.getOffset(), 11111111L); - Assert.assertNull(offsetByTimeStamp.getMessageId()); - - } - + public static final String server = "pulsar://localhost:6650"; + public static final String topic = "persistent://test/test_pulsar_connector/non_partition_topic"; + public static final String partitionTopic = + "persistent://test/test_pulsar_connector/partition_topic"; + + @Test + public void testPulsarPartition() { + PulsarPartition partition = new PulsarPartition("topic"); + Assert.assertEquals(partition.getName(), "topic"); + + PulsarPartition partition2 = new PulsarPartition("topic"); + Assert.assertEquals(partition.hashCode(), partition2.hashCode()); + Assert.assertEquals(partition, partition); + Assert.assertEquals(partition, partition2); + Assert.assertNotEquals(partition, null); + } + + @Test + public void testPulsarOffset() { + PulsarOffset offsetByMessageId = new PulsarOffset(MessageId.earliest); + Assert.assertEquals( + offsetByMessageId.getMessageId(), DefaultImplementation.newMessageId(-1L, -1L, -1)); + Assert.assertEquals(offsetByMessageId.getOffset(), 0L); + + PulsarOffset offsetByTimeStamp = new PulsarOffset(11111111L); + Assert.assertEquals(offsetByTimeStamp.getOffset(), 11111111L); + Assert.assertNull(offsetByTimeStamp.getMessageId()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-random/src/main/java/org/apache/geaflow/dsl/connector/random/RandomSourceConfigKeys.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-random/src/main/java/org/apache/geaflow/dsl/connector/random/RandomSourceConfigKeys.java index a7fdd13a0..822eb2dec 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-random/src/main/java/org/apache/geaflow/dsl/connector/random/RandomSourceConfigKeys.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-random/src/main/java/org/apache/geaflow/dsl/connector/random/RandomSourceConfigKeys.java @@ -24,13 +24,13 @@ public class RandomSourceConfigKeys { - public static final ConfigKey GEAFLOW_DSL_RANDOM_SOURCE_RATE = ConfigKeys - .key("geaflow.dsl.random.source.tuples.per.second") - .defaultValue(1.0) - .description("Random source create tuple number per second"); + public static final ConfigKey GEAFLOW_DSL_RANDOM_SOURCE_RATE = + ConfigKeys.key("geaflow.dsl.random.source.tuples.per.second") + .defaultValue(1.0) + .description("Random source create tuple number per second"); - public static final ConfigKey GEAFLOW_DSL_RANDOM_SOURCE_MAX_BATCH = ConfigKeys - .key("geaflow.dsl.random.source.max.batch") - .defaultValue(3L) - .description("Random source create batch max num"); + public static final ConfigKey GEAFLOW_DSL_RANDOM_SOURCE_MAX_BATCH = + ConfigKeys.key("geaflow.dsl.random.source.max.batch") + .defaultValue(3L) + .description("Random source create batch max num"); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-random/src/main/java/org/apache/geaflow/dsl/connector/random/RandomTableConnector.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-random/src/main/java/org/apache/geaflow/dsl/connector/random/RandomTableConnector.java index afc15abcd..4bb11b199 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-random/src/main/java/org/apache/geaflow/dsl/connector/random/RandomTableConnector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-random/src/main/java/org/apache/geaflow/dsl/connector/random/RandomTableConnector.java @@ -25,13 +25,13 @@ public class RandomTableConnector implements TableReadableConnector { - @Override - public String getType() { - return "RANDOM"; - } + @Override + public String getType() { + return "RANDOM"; + } - @Override - public TableSource createSource(Configuration conf) { - return new RandomTableSource(); - } + @Override + public TableSource createSource(Configuration conf) { + return new RandomTableSource(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-random/src/main/java/org/apache/geaflow/dsl/connector/random/RandomTableSource.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-random/src/main/java/org/apache/geaflow/dsl/connector/random/RandomTableSource.java index 360c983fb..2c1f8bbe8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-random/src/main/java/org/apache/geaflow/dsl/connector/random/RandomTableSource.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-random/src/main/java/org/apache/geaflow/dsl/connector/random/RandomTableSource.java @@ -25,6 +25,7 @@ import java.util.Objects; import java.util.Optional; import java.util.Random; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.window.WindowType; import org.apache.geaflow.common.binary.BinaryString; @@ -49,177 +50,177 @@ public class RandomTableSource implements TableSource { - private static final Logger LOGGER = LoggerFactory.getLogger(RandomTableSource.class); - - private TableSchema schema = null; - private double rate = 1.0; - private long maxBatch = 100L; - - @Override - public void init(Configuration tableConf, TableSchema tableSchema) { - this.schema = tableSchema; - this.rate = tableConf.getDouble(RandomSourceConfigKeys.GEAFLOW_DSL_RANDOM_SOURCE_RATE); - this.maxBatch = tableConf.getLong(RandomSourceConfigKeys.GEAFLOW_DSL_RANDOM_SOURCE_MAX_BATCH); + private static final Logger LOGGER = LoggerFactory.getLogger(RandomTableSource.class); + + private TableSchema schema = null; + private double rate = 1.0; + private long maxBatch = 100L; + + @Override + public void init(Configuration tableConf, TableSchema tableSchema) { + this.schema = tableSchema; + this.rate = tableConf.getDouble(RandomSourceConfigKeys.GEAFLOW_DSL_RANDOM_SOURCE_RATE); + this.maxBatch = tableConf.getLong(RandomSourceConfigKeys.GEAFLOW_DSL_RANDOM_SOURCE_MAX_BATCH); + } + + @Override + public void open(RuntimeContext context) { + this.hashCode(); + } + + private static List singletonPartition() { + List singletonPartition = new ArrayList<>(); + singletonPartition.add(new RandomPartition()); + return singletonPartition; + } + + @Override + public List listPartitions() { + return singletonPartition(); + } + + @Override + public List listPartitions(int parallelism) { + return singletonPartition(); + } + + @Override + public TableDeserializer getDeserializer(Configuration conf) { + return (TableDeserializer) new RowTableDeserializer(); + } + + @Override + public FetchData fetch( + Partition partition, Optional startOffset, FetchWindow windowInfo) + throws IOException { + long windowSize; + if (windowInfo.getType() == WindowType.ALL_WINDOW) { + windowSize = (long) rate; + } else if (windowInfo.getType() == WindowType.SIZE_TUMBLING_WINDOW) { + // Control wait time based on the rate. + windowSize = windowInfo.windowSize(); + SleepUtils.sleepMilliSecond((long) (1000 * (windowSize / rate))); + } else if (windowInfo.getType() == WindowType.FIXED_TIME_TUMBLING_WINDOW) { + // Control the number of tuples based on the rate. + if (windowInfo.windowSize() > Integer.MAX_VALUE) { + throw new GeaFlowDSLException( + "Random table source window size is overflow:{}", windowInfo.windowSize()); + } + long seconds = (int) windowInfo.windowSize(); + windowSize = (long) (seconds * rate); + } else { + throw new GeaFlowDSLException( + "File table source not support window:{}", windowInfo.getType()); } - - @Override - public void open(RuntimeContext context) { - this.hashCode(); + List randomContents = new ArrayList<>(); + Random random = new Random(); + for (long i = 0; i < windowSize; i++) { + Row row = createRandomRow(schema, random); + randomContents.add(row); } - - private static List singletonPartition() { - List singletonPartition = new ArrayList<>(); - singletonPartition.add(new RandomPartition()); - return singletonPartition; + long thisBatchId = 1L; + if (startOffset.isPresent()) { + thisBatchId = startOffset.get().getOffset(); + } + if (windowInfo.getType() == WindowType.ALL_WINDOW) { + return (FetchData) + FetchData.createBatchFetch( + randomContents.iterator(), new RandomSourceOffset(thisBatchId + 1)); + } else { + boolean isFinish = thisBatchId >= maxBatch; + return (FetchData) + FetchData.createStreamFetch( + randomContents, new RandomSourceOffset(thisBatchId + 1), isFinish); } + } + + private static Row createRandomRow(TableSchema schema, Random random) { + List fields = schema.getFields(); + Object[] objects = new Object[schema.size()]; + for (int i = 0; i < fields.size(); i++) { + TableField field = fields.get(i); + objects[i] = createRandomCol(field, random); + } + return ObjectRow.create(objects); + } + + private static Object createRandomCol(TableField field, Random random) { + String fieldName = field.getName(); + IType type = field.getType(); + switch (type.getName()) { + case Types.TYPE_NAME_BOOLEAN: + return random.nextBoolean(); + case Types.TYPE_NAME_BYTE: + return (byte) random.nextInt(Byte.MAX_VALUE + 1); + case Types.TYPE_NAME_SHORT: + return (short) random.nextInt(Short.MAX_VALUE + 1); + case Types.TYPE_NAME_INTEGER: + return random.nextInt(100); + case Types.TYPE_NAME_LONG: + return (long) random.nextInt(1000000); + case Types.TYPE_NAME_FLOAT: + return random.nextFloat() * 100.0; + case Types.TYPE_NAME_DOUBLE: + return random.nextDouble() * 1000000.0; + case Types.TYPE_NAME_STRING: + return fieldName + "_" + random.nextInt(1000000); + case Types.TYPE_NAME_BINARY_STRING: + return BinaryString.fromString(fieldName + "_" + random.nextInt(1000000)); + default: + throw new RuntimeException("Cannot create random value for type: " + type); + } + } + + @Override + public void close() {} + + public static class RandomPartition implements Partition { + + public RandomPartition() {} @Override - public List listPartitions() { - return singletonPartition(); + public String getName() { + return this.getClass().getName(); } @Override - public List listPartitions(int parallelism) { - return singletonPartition(); - } + public void setIndex(int index, int parallel) {} @Override - public TableDeserializer getDeserializer(Configuration conf) { - return (TableDeserializer) new RowTableDeserializer(); + public boolean equals(Object o) { + if (this == o) { + return true; + } + return o != null && getClass() == o.getClass(); } @Override - public FetchData fetch(Partition partition, Optional startOffset, - FetchWindow windowInfo) throws IOException { - long windowSize; - if (windowInfo.getType() == WindowType.ALL_WINDOW) { - windowSize = (long) rate; - } else if (windowInfo.getType() == WindowType.SIZE_TUMBLING_WINDOW) { - //Control wait time based on the rate. - windowSize = windowInfo.windowSize(); - SleepUtils.sleepMilliSecond((long) (1000 * (windowSize / rate))); - } else if (windowInfo.getType() == WindowType.FIXED_TIME_TUMBLING_WINDOW) { - //Control the number of tuples based on the rate. - if (windowInfo.windowSize() > Integer.MAX_VALUE) { - throw new GeaFlowDSLException("Random table source window size is overflow:{}", - windowInfo.windowSize()); - } - long seconds = (int) windowInfo.windowSize(); - windowSize = (long) (seconds * rate); - } else { - throw new GeaFlowDSLException("File table source not support window:{}", windowInfo.getType()); - } - List randomContents = new ArrayList<>(); - Random random = new Random(); - for (long i = 0; i < windowSize; i++) { - Row row = createRandomRow(schema, random); - randomContents.add(row); - } - long thisBatchId = 1L; - if (startOffset.isPresent()) { - thisBatchId = startOffset.get().getOffset(); - } - if (windowInfo.getType() == WindowType.ALL_WINDOW) { - return (FetchData) FetchData.createBatchFetch(randomContents.iterator(), - new RandomSourceOffset(thisBatchId + 1)); - } else { - boolean isFinish = thisBatchId >= maxBatch; - return (FetchData) FetchData.createStreamFetch(randomContents, - new RandomSourceOffset(thisBatchId + 1), isFinish); - } + public int hashCode() { + return Objects.hash(this.getClass().getName()); } + } - private static Row createRandomRow(TableSchema schema, Random random) { - List fields = schema.getFields(); - Object[] objects = new Object[schema.size()]; - for (int i = 0; i < fields.size(); i++) { - TableField field = fields.get(i); - objects[i] = createRandomCol(field, random); - } - return ObjectRow.create(objects); - } + public static class RandomSourceOffset implements Offset { - private static Object createRandomCol(TableField field, Random random) { - String fieldName = field.getName(); - IType type = field.getType(); - switch (type.getName()) { - case Types.TYPE_NAME_BOOLEAN: - return random.nextBoolean(); - case Types.TYPE_NAME_BYTE: - return (byte) random.nextInt(Byte.MAX_VALUE + 1); - case Types.TYPE_NAME_SHORT: - return (short) random.nextInt(Short.MAX_VALUE + 1); - case Types.TYPE_NAME_INTEGER: - return random.nextInt(100); - case Types.TYPE_NAME_LONG: - return (long) random.nextInt(1000000); - case Types.TYPE_NAME_FLOAT: - return random.nextFloat() * 100.0; - case Types.TYPE_NAME_DOUBLE: - return random.nextDouble() * 1000000.0; - case Types.TYPE_NAME_STRING: - return fieldName + "_" + random.nextInt(1000000); - case Types.TYPE_NAME_BINARY_STRING: - return BinaryString.fromString(fieldName + "_" + random.nextInt(1000000)); - default: - throw new RuntimeException("Cannot create random value for type: " + type); - } + private final long batchId; + + public RandomSourceOffset(long batchId) { + this.batchId = batchId; } @Override - public void close() { - + public String humanReadable() { + return "batch:" + batchId; } - public static class RandomPartition implements Partition { - - public RandomPartition() { - } - - @Override - public String getName() { - return this.getClass().getName(); - } - - @Override - public void setIndex(int index, int parallel) { - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - return o != null && getClass() == o.getClass(); - } - - @Override - public int hashCode() { - return Objects.hash(this.getClass().getName()); - } + @Override + public long getOffset() { + return batchId; } - public static class RandomSourceOffset implements Offset { - - private final long batchId; - - public RandomSourceOffset(long batchId) { - this.batchId = batchId; - } - - @Override - public String humanReadable() { - return "batch:" + batchId; - } - - @Override - public long getOffset() { - return batchId; - } - - @Override - public boolean isTimestamp() { - return false; - } + @Override + public boolean isTimestamp() { + return false; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-random/src/test/java/org/apache/geaflow/dsl/connector/random/RandomTableConnectorTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-random/src/test/java/org/apache/geaflow/dsl/connector/random/RandomTableConnectorTest.java index 526365bb9..f1d241b77 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-random/src/test/java/org/apache/geaflow/dsl/connector/random/RandomTableConnectorTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-random/src/test/java/org/apache/geaflow/dsl/connector/random/RandomTableConnectorTest.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Map; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.task.TaskArgs; @@ -32,58 +33,59 @@ public class RandomTableConnectorTest { - @Test - public void testRandom() throws IOException { - RandomTableConnector connector = new RandomTableConnector(); - TableSource source = connector.createSource(new Configuration()); - Assert.assertEquals(source.getClass(), RandomTableSource.class); - Configuration tableConf = new Configuration(); + @Test + public void testRandom() throws IOException { + RandomTableConnector connector = new RandomTableConnector(); + TableSource source = connector.createSource(new Configuration()); + Assert.assertEquals(source.getClass(), RandomTableSource.class); + Configuration tableConf = new Configuration(); - source.init(tableConf, new TableSchema()); + source.init(tableConf, new TableSchema()); - source.open(new RuntimeContext() { - @Override - public long getPipelineId() { - return 0; - } + source.open( + new RuntimeContext() { + @Override + public long getPipelineId() { + return 0; + } - @Override - public String getPipelineName() { - return null; - } + @Override + public String getPipelineName() { + return null; + } - @Override - public TaskArgs getTaskArgs() { - return null; - } + @Override + public TaskArgs getTaskArgs() { + return null; + } - @Override - public Configuration getConfiguration() { - return null; - } + @Override + public Configuration getConfiguration() { + return null; + } - @Override - public String getWorkPath() { - return null; - } + @Override + public String getWorkPath() { + return null; + } - @Override - public MetricGroup getMetric() { - return null; - } + @Override + public MetricGroup getMetric() { + return null; + } - @Override - public RuntimeContext clone(Map opConfig) { - return null; - } + @Override + public RuntimeContext clone(Map opConfig) { + return null; + } - @Override - public long getWindowId() { - return 0; - } + @Override + public long getWindowId() { + return 0; + } }); - source.listPartitions(); - source.close(); - } + source.listPartitions(); + source.close(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/SocketConfigKeys.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/SocketConfigKeys.java index 80244ad26..c039c230b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/SocketConfigKeys.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/SocketConfigKeys.java @@ -24,13 +24,13 @@ public class SocketConfigKeys { - public static final ConfigKey GEAFLOW_DSL_SOCKET_HOST = ConfigKeys - .key("geaflow.dsl.socket.host") - .noDefaultValue() - .description("The host which socket links to."); + public static final ConfigKey GEAFLOW_DSL_SOCKET_HOST = + ConfigKeys.key("geaflow.dsl.socket.host") + .noDefaultValue() + .description("The host which socket links to."); - public static final ConfigKey GEAFLOW_DSL_SOCKET_PORT = ConfigKeys - .key("geaflow.dsl.socket.port") - .noDefaultValue() - .description("The port which socket links to."); + public static final ConfigKey GEAFLOW_DSL_SOCKET_PORT = + ConfigKeys.key("geaflow.dsl.socket.port") + .noDefaultValue() + .description("The port which socket links to."); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/SocketTableConnector.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/SocketTableConnector.java index feb415ff7..47d7aebba 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/SocketTableConnector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/SocketTableConnector.java @@ -27,18 +27,18 @@ public class SocketTableConnector implements TableReadableConnector, TableWritableConnector { - @Override - public String getType() { - return "SOCKET"; - } + @Override + public String getType() { + return "SOCKET"; + } - @Override - public TableSource createSource(Configuration conf) { - return new SocketTableSource(); - } + @Override + public TableSource createSource(Configuration conf) { + return new SocketTableSource(); + } - @Override - public TableSink createSink(Configuration conf) { - return new SocketTableSink(); - } + @Override + public TableSink createSink(Configuration conf) { + return new SocketTableSink(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/SocketTableSink.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/SocketTableSink.java index c6519ead1..e56505691 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/SocketTableSink.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/SocketTableSink.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.Objects; import java.util.concurrent.LinkedBlockingQueue; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ConnectorConfigKeys; @@ -36,68 +37,63 @@ public class SocketTableSink implements TableSink, ISkipOpenAndClose { - private static final Logger LOGGER = LoggerFactory.getLogger(SocketTableSink.class.getName()); - - private Configuration tableConf; - - private StructType schema; - - private String separator; - - private LinkedBlockingQueue dataQueue; - - @Override - public void init(Configuration tableConf, StructType schema) { - this.tableConf = tableConf; - this.schema = Objects.requireNonNull(schema); - this.separator = tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_COLUMN_SEPARATOR); + private static final Logger LOGGER = LoggerFactory.getLogger(SocketTableSink.class.getName()); + + private Configuration tableConf; + + private StructType schema; + + private String separator; + + private LinkedBlockingQueue dataQueue; + + @Override + public void init(Configuration tableConf, StructType schema) { + this.tableConf = tableConf; + this.schema = Objects.requireNonNull(schema); + this.separator = tableConf.getString(ConnectorConfigKeys.GEAFLOW_DSL_COLUMN_SEPARATOR); + } + + @Override + public void open(RuntimeContext context) { + String host = tableConf.getString(SocketConfigKeys.GEAFLOW_DSL_SOCKET_HOST); + int port = tableConf.getInteger(SocketConfigKeys.GEAFLOW_DSL_SOCKET_PORT); + this.dataQueue = new LinkedBlockingQueue<>(); + while (true) { + try { + NettySinkClient client = new NettySinkClient(host, port, dataQueue); + client.run(); + break; + } catch (Exception e) { + LOGGER.info("Attempt to connect sink netty server."); + SleepUtils.sleepSecond(5); + } } + } - @Override - public void open(RuntimeContext context) { - String host = tableConf.getString(SocketConfigKeys.GEAFLOW_DSL_SOCKET_HOST); - int port = tableConf.getInteger(SocketConfigKeys.GEAFLOW_DSL_SOCKET_PORT); - this.dataQueue = new LinkedBlockingQueue<>(); - while (true) { - try { - NettySinkClient client = new NettySinkClient(host, port, dataQueue); - client.run(); - break; - } catch (Exception e) { - LOGGER.info("Attempt to connect sink netty server."); - SleepUtils.sleepSecond(5); - } - } + @Override + public void write(Row row) throws IOException { + Object[] values = new Object[schema.size()]; + for (int i = 0; i < schema.size(); i++) { + values[i] = row.getField(i, schema.getType(i)); } - - @Override - public void write(Row row) throws IOException { - Object[] values = new Object[schema.size()]; - for (int i = 0; i < schema.size(); i++) { - values[i] = row.getField(i, schema.getType(i)); - } - StringBuilder line = new StringBuilder(); - for (Object value : values) { - if (line.length() > 0) { - line.append(separator); - } - line.append(value); - } - try { - dataQueue.put(line.toString()); - } catch (InterruptedException e) { - LOGGER.info(null, e); - } - + StringBuilder line = new StringBuilder(); + for (Object value : values) { + if (line.length() > 0) { + line.append(separator); + } + line.append(value); } - - @Override - public void finish() throws IOException { - + try { + dataQueue.put(line.toString()); + } catch (InterruptedException e) { + LOGGER.info(null, e); } + } - @Override - public void close() { + @Override + public void finish() throws IOException {} - } + @Override + public void close() {} } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/SocketTableSource.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/SocketTableSource.java index 0c46ff970..000847a27 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/SocketTableSource.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/SocketTableSource.java @@ -26,6 +26,7 @@ import java.util.Objects; import java.util.Optional; import java.util.concurrent.LinkedBlockingQueue; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.window.WindowType; import org.apache.geaflow.common.config.Configuration; @@ -46,129 +47,125 @@ public class SocketTableSource implements TableSource, ISkipOpenAndClose { - private static final Logger LOGGER = LoggerFactory.getLogger(SocketTableSource.class.getName()); + private static final Logger LOGGER = LoggerFactory.getLogger(SocketTableSource.class.getName()); + + private Configuration tableConf; + + private LinkedBlockingQueue dataQueue; + + @Override + public void init(Configuration tableConf, TableSchema tableSchema) { + this.tableConf = tableConf; + } + + @Override + public void open(RuntimeContext context) { + String host = tableConf.getString(SocketConfigKeys.GEAFLOW_DSL_SOCKET_HOST); + int port = tableConf.getInteger(SocketConfigKeys.GEAFLOW_DSL_SOCKET_PORT); + this.dataQueue = new LinkedBlockingQueue<>(); + while (true) { + try { + NettySourceClient client = new NettySourceClient(host, port, dataQueue); + client.run(); + break; + } catch (Exception e) { + LOGGER.info("Attempt to connect source netty server."); + SleepUtils.sleepSecond(5); + } + } + } + + @Override + public List listPartitions() { + return Collections.singletonList(new SocketPartition(new ArrayList<>())); + } + + @Override + public List listPartitions(int parallelism) { + return listPartitions(); + } + + @Override + public TableDeserializer getDeserializer(Configuration conf) { + return DeserializerFactory.loadTextDeserializer(); + } + + @Override + public FetchData fetch( + Partition partition, Optional startOffset, FetchWindow windowInfo) + throws IOException { + if (windowInfo.getType() != WindowType.SIZE_TUMBLING_WINDOW) { + throw new GeaFlowDSLException("Not support window type:{}", windowInfo.getType()); + } + try { + List fetchData = new ArrayList<>(); + for (int i = 0; i < windowInfo.windowSize(); i++) { + fetchData.add(dataQueue.take()); + } + return (FetchData) FetchData.createStreamFetch(fetchData, new SocketOffset(), false); + } catch (InterruptedException e) { + throw new IOException(e); + } + } - private Configuration tableConf; + @Override + public void close() {} - private LinkedBlockingQueue dataQueue; + public static class SocketPartition implements Partition { - @Override - public void init(Configuration tableConf, TableSchema tableSchema) { - this.tableConf = tableConf; - } + private List data; - @Override - public void open(RuntimeContext context) { - String host = tableConf.getString(SocketConfigKeys.GEAFLOW_DSL_SOCKET_HOST); - int port = tableConf.getInteger(SocketConfigKeys.GEAFLOW_DSL_SOCKET_PORT); - this.dataQueue = new LinkedBlockingQueue<>(); - while (true) { - try { - NettySourceClient client = new NettySourceClient(host, port, dataQueue); - client.run(); - break; - } catch (Exception e) { - LOGGER.info("Attempt to connect source netty server."); - SleepUtils.sleepSecond(5); - } - } + public SocketPartition(List data) { + this.data = data; } - @Override - public List listPartitions() { - return Collections.singletonList(new SocketPartition(new ArrayList<>())); + public List getData() { + return data; } @Override - public List listPartitions(int parallelism) { - return listPartitions(); + public String getName() { + return String.valueOf(data.hashCode()); } @Override - public TableDeserializer getDeserializer(Configuration conf) { - return DeserializerFactory.loadTextDeserializer(); - } + public void setIndex(int index, int parallel) {} @Override - public FetchData fetch(Partition partition, Optional startOffset, - FetchWindow windowInfo) throws IOException { - if (windowInfo.getType() != WindowType.SIZE_TUMBLING_WINDOW) { - throw new GeaFlowDSLException("Not support window type:{}", windowInfo.getType()); - } - try { - List fetchData = new ArrayList<>(); - for (int i = 0; i < windowInfo.windowSize(); i++) { - fetchData.add(dataQueue.take()); - } - return (FetchData) FetchData.createStreamFetch(fetchData, new SocketOffset(), false); - } catch (InterruptedException e) { - throw new IOException(e); - } + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SocketPartition that = (SocketPartition) o; + return Objects.equals(data, that.data); } @Override - public void close() { - - } - - public static class SocketPartition implements Partition { - - private List data; - - public SocketPartition(List data) { - this.data = data; - } - - public List getData() { - return data; - } - - @Override - public String getName() { - return String.valueOf(data.hashCode()); - } - - @Override - public void setIndex(int index, int parallel) { - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - SocketPartition that = (SocketPartition) o; - return Objects.equals(data, that.data); - } - - @Override - public int hashCode() { - return Objects.hash(data); - } + public int hashCode() { + return Objects.hash(data); } + } - public static class SocketOffset implements Offset { + public static class SocketOffset implements Offset { - public SocketOffset() { + public SocketOffset() {} - } - - @Override - public String humanReadable() { - return "None"; - } + @Override + public String humanReadable() { + return "None"; + } - @Override - public long getOffset() { - return -1; - } + @Override + public long getOffset() { + return -1; + } - @Override - public boolean isTimestamp() { - return false; - } + @Override + public boolean isTimestamp() { + return false; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/server/NettySinkClient.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/server/NettySinkClient.java index 457d5f406..14e0ffb0d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/server/NettySinkClient.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/server/NettySinkClient.java @@ -19,6 +19,11 @@ package org.apache.geaflow.dsl.connector.socket.server; +import java.util.concurrent.LinkedBlockingQueue; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import io.netty.bootstrap.Bootstrap; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; @@ -32,76 +37,76 @@ import io.netty.handler.codec.string.StringDecoder; import io.netty.handler.codec.string.StringEncoder; import io.netty.util.CharsetUtil; -import java.util.concurrent.LinkedBlockingQueue; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class NettySinkClient { - private static final Logger LOGGER = LoggerFactory.getLogger(NettySinkClient.class.getName()); + private static final Logger LOGGER = LoggerFactory.getLogger(NettySinkClient.class.getName()); - private final String host; + private final String host; - private final int port; + private final int port; - private final LinkedBlockingQueue dataQueue; + private final LinkedBlockingQueue dataQueue; - public NettySinkClient(String host, int port, LinkedBlockingQueue dataQueue) { - this.host = host; - this.port = port; - this.dataQueue = dataQueue; - } + public NettySinkClient(String host, int port, LinkedBlockingQueue dataQueue) { + this.host = host; + this.port = port; + this.dataQueue = dataQueue; + } - public void run() throws Exception { - EventLoopGroup bossGroup = new NioEventLoopGroup(); - Bootstrap bootstrap = new Bootstrap().group(bossGroup).channel(NioSocketChannel.class) - .handler(new ChannelInitializer() { - @Override - protected void initChannel(SocketChannel channel) throws Exception { + public void run() throws Exception { + EventLoopGroup bossGroup = new NioEventLoopGroup(); + Bootstrap bootstrap = + new Bootstrap() + .group(bossGroup) + .channel(NioSocketChannel.class) + .handler( + new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel channel) throws Exception { channel.pipeline().addLast(new LineBasedFrameDecoder(Integer.MAX_VALUE)); channel.pipeline().addLast("decoder", new StringDecoder(CharsetUtil.UTF_8)); channel.pipeline().addLast("encoder", new StringEncoder(CharsetUtil.UTF_8)); channel.pipeline().addLast(new ClientHandler()); - } - }); - - ChannelFuture future = bootstrap.connect(host, port).sync(); - future.channel().closeFuture().addListener((channelFuture) -> bossGroup.shutdownGracefully()); + } + }); - } + ChannelFuture future = bootstrap.connect(host, port).sync(); + future.channel().closeFuture().addListener((channelFuture) -> bossGroup.shutdownGracefully()); + } - private class ClientHandler extends ChannelInboundHandlerAdapter { + private class ClientHandler extends ChannelInboundHandlerAdapter { - @Override - public void channelActive(ChannelHandlerContext ctx) throws Exception { - ctx.channel().eventLoop().submit(new SendDataTask(ctx)); - } + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + ctx.channel().eventLoop().submit(new SendDataTask(ctx)); + } - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - String data = String.valueOf(msg); - LOGGER.info("sink receive data: {}", data); - } + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + String data = String.valueOf(msg); + LOGGER.info("sink receive data: {}", data); } + } - private class SendDataTask implements Runnable { + private class SendDataTask implements Runnable { - private final ChannelHandlerContext context; + private final ChannelHandlerContext context; - public SendDataTask(ChannelHandlerContext context) { - this.context = context; - } + public SendDataTask(ChannelHandlerContext context) { + this.context = context; + } - @Override - public void run() { - while (true) { - try { - String data = dataQueue.take(); - context.writeAndFlush(data + "\n"); - } catch (InterruptedException e) { - LOGGER.info("Send data error. ", e); - } - } + @Override + public void run() { + while (true) { + try { + String data = dataQueue.take(); + context.writeAndFlush(data + "\n"); + } catch (InterruptedException e) { + LOGGER.info("Send data error. ", e); } + } } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/server/NettySourceClient.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/server/NettySourceClient.java index 54b62d09f..901401aa6 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/server/NettySourceClient.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/server/NettySourceClient.java @@ -19,6 +19,11 @@ package org.apache.geaflow.dsl.connector.socket.server; +import java.util.concurrent.LinkedBlockingQueue; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import io.netty.bootstrap.Bootstrap; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; @@ -32,56 +37,57 @@ import io.netty.handler.codec.string.StringDecoder; import io.netty.handler.codec.string.StringEncoder; import io.netty.util.CharsetUtil; -import java.util.concurrent.LinkedBlockingQueue; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class NettySourceClient { - private static final Logger LOGGER = LoggerFactory.getLogger(NettySourceClient.class.getName()); + private static final Logger LOGGER = LoggerFactory.getLogger(NettySourceClient.class.getName()); - private final String host; + private final String host; - private final int port; + private final int port; - private final LinkedBlockingQueue dataQueue; + private final LinkedBlockingQueue dataQueue; - public NettySourceClient(String host, int port, LinkedBlockingQueue dataQueue) { - this.host = host; - this.port = port; - this.dataQueue = dataQueue; - } + public NettySourceClient(String host, int port, LinkedBlockingQueue dataQueue) { + this.host = host; + this.port = port; + this.dataQueue = dataQueue; + } - public void run() throws Exception { - EventLoopGroup bossGroup = new NioEventLoopGroup(); - Bootstrap bootstrap = new Bootstrap().group(bossGroup).channel(NioSocketChannel.class) - .handler(new ChannelInitializer() { - @Override - protected void initChannel(SocketChannel channel) throws Exception { + public void run() throws Exception { + EventLoopGroup bossGroup = new NioEventLoopGroup(); + Bootstrap bootstrap = + new Bootstrap() + .group(bossGroup) + .channel(NioSocketChannel.class) + .handler( + new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel channel) throws Exception { channel.pipeline().addLast(new LineBasedFrameDecoder(Integer.MAX_VALUE)); channel.pipeline().addLast("decoder", new StringDecoder(CharsetUtil.UTF_8)); channel.pipeline().addLast("encoder", new StringEncoder(CharsetUtil.UTF_8)); channel.pipeline().addLast(new ClientHandler()); - } - }); + } + }); - ChannelFuture future = bootstrap.connect(host, port).sync(); - future.channel().closeFuture().addListener((channelFuture) -> bossGroup.shutdownGracefully()); - } + ChannelFuture future = bootstrap.connect(host, port).sync(); + future.channel().closeFuture().addListener((channelFuture) -> bossGroup.shutdownGracefully()); + } - private class ClientHandler extends ChannelInboundHandlerAdapter { + private class ClientHandler extends ChannelInboundHandlerAdapter { - @Override - public void channelActive(ChannelHandlerContext ctx) throws Exception { - String getDataMsg = NettyWebServer.CMD_GET_DATA + "\n"; - ctx.writeAndFlush(getDataMsg); - } + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + String getDataMsg = NettyWebServer.CMD_GET_DATA + "\n"; + ctx.writeAndFlush(getDataMsg); + } - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - String data = String.valueOf(msg); - LOGGER.info("source get data: {}", msg); - dataQueue.put(data); - } + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + String data = String.valueOf(msg); + LOGGER.info("source get data: {}", msg); + dataQueue.put(data); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/server/NettyTerminalServer.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/server/NettyTerminalServer.java index ec47bf1da..c997ccdcb 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/server/NettyTerminalServer.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/server/NettyTerminalServer.java @@ -19,6 +19,15 @@ package org.apache.geaflow.dsl.connector.socket.server; +import java.io.PrintStream; +import java.util.Scanner; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; @@ -33,110 +42,105 @@ import io.netty.handler.codec.string.StringDecoder; import io.netty.handler.codec.string.StringEncoder; import io.netty.util.CharsetUtil; -import java.io.PrintStream; -import java.util.Scanner; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class NettyTerminalServer { - private static final Logger LOGGER = LoggerFactory.getLogger(NettyTerminalServer.class.getName()); - - private PrintStream printer; - - private LinkedBlockingQueue dataQueue = new LinkedBlockingQueue<>(); - - public void bind(int port) { - EventLoopGroup bossGroup = new NioEventLoopGroup(); - EventLoopGroup workerGroup = new NioEventLoopGroup(); - printer = System.out; - - try { - ServerBootstrap bootstrap = new ServerBootstrap() - .group(bossGroup, workerGroup) - .channel(NioServerSocketChannel.class) - .childHandler(new NettyServerChannelInitializer()) - .option(ChannelOption.SO_BACKLOG, 500) - .childOption(ChannelOption.SO_KEEPALIVE, true); - - printer.println("Start netty terminal server, waiting connect."); - // bind port - ChannelFuture future = bootstrap.bind(port).sync(); - // shutdown channel - future.channel().closeFuture().addListener((channelFuture) -> bossGroup.shutdownGracefully()); - Scanner scanner = new Scanner(System.in); - while (true) { - while (scanner.hasNextLine()) { - String line = scanner.nextLine(); - dataQueue.put(line); - } - } - } catch (Exception e) { - bossGroup.shutdownGracefully(); - workerGroup.shutdownGracefully(); + private static final Logger LOGGER = LoggerFactory.getLogger(NettyTerminalServer.class.getName()); + + private PrintStream printer; + + private LinkedBlockingQueue dataQueue = new LinkedBlockingQueue<>(); + + public void bind(int port) { + EventLoopGroup bossGroup = new NioEventLoopGroup(); + EventLoopGroup workerGroup = new NioEventLoopGroup(); + printer = System.out; + + try { + ServerBootstrap bootstrap = + new ServerBootstrap() + .group(bossGroup, workerGroup) + .channel(NioServerSocketChannel.class) + .childHandler(new NettyServerChannelInitializer()) + .option(ChannelOption.SO_BACKLOG, 500) + .childOption(ChannelOption.SO_KEEPALIVE, true); + + printer.println("Start netty terminal server, waiting connect."); + // bind port + ChannelFuture future = bootstrap.bind(port).sync(); + // shutdown channel + future.channel().closeFuture().addListener((channelFuture) -> bossGroup.shutdownGracefully()); + Scanner scanner = new Scanner(System.in); + while (true) { + while (scanner.hasNextLine()) { + String line = scanner.nextLine(); + dataQueue.put(line); } + } + } catch (Exception e) { + bossGroup.shutdownGracefully(); + workerGroup.shutdownGracefully(); } + } - private class NettyServerChannelInitializer extends ChannelInitializer { + private class NettyServerChannelInitializer extends ChannelInitializer { - @Override - protected void initChannel(SocketChannel channel) throws Exception { - channel.pipeline().addLast(new LineBasedFrameDecoder(Integer.MAX_VALUE)); - channel.pipeline().addLast(new StringDecoder(CharsetUtil.UTF_8)); - channel.pipeline().addLast(new StringEncoder(CharsetUtil.UTF_8)); - channel.pipeline().addLast("commonHandler", new SocketHandler()); - } + @Override + protected void initChannel(SocketChannel channel) throws Exception { + channel.pipeline().addLast(new LineBasedFrameDecoder(Integer.MAX_VALUE)); + channel.pipeline().addLast(new StringDecoder(CharsetUtil.UTF_8)); + channel.pipeline().addLast(new StringEncoder(CharsetUtil.UTF_8)); + channel.pipeline().addLast("commonHandler", new SocketHandler()); + } + } + + private class SocketHandler extends ChannelInboundHandlerAdapter { + + private AtomicBoolean activeFlag = new AtomicBoolean(true); + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + String message = String.valueOf(msg); + if (message.startsWith(NettyWebServer.CMD_GET_DATA)) { + printer.println("please enter the data to the console."); + ctx.channel() + .eventLoop() + .scheduleWithFixedDelay(new GetDataTask(ctx, activeFlag), 0, 1, TimeUnit.SECONDS); + } else { + printer.println(">> " + message); + } } - private class SocketHandler extends ChannelInboundHandlerAdapter { - - private AtomicBoolean activeFlag = new AtomicBoolean(true); - - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - String message = String.valueOf(msg); - if (message.startsWith(NettyWebServer.CMD_GET_DATA)) { - printer.println("please enter the data to the console."); - ctx.channel().eventLoop().scheduleWithFixedDelay(new GetDataTask(ctx, activeFlag) - , 0, 1, TimeUnit.SECONDS); - } else { - printer.println(">> " + message); - } - } - - @Override - public void channelInactive(ChannelHandlerContext ctx) throws Exception { - activeFlag.set(false); - } + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + activeFlag.set(false); } + } - private class GetDataTask implements Runnable { + private class GetDataTask implements Runnable { - private final ChannelHandlerContext ctx; - private final AtomicBoolean activeFlag; + private final ChannelHandlerContext ctx; + private final AtomicBoolean activeFlag; - public GetDataTask(ChannelHandlerContext ctx, AtomicBoolean activeFlag) { - this.ctx = ctx; - this.activeFlag = activeFlag; - } + public GetDataTask(ChannelHandlerContext ctx, AtomicBoolean activeFlag) { + this.ctx = ctx; + this.activeFlag = activeFlag; + } - @Override - public void run() { - if (activeFlag.get()) { - while (!dataQueue.isEmpty()) { - String line = null; - try { - line = dataQueue.take(); - String res = line + "\n"; - this.ctx.writeAndFlush(res); - } catch (InterruptedException e) { - LOGGER.info(e.getMessage()); - } - } - } + @Override + public void run() { + if (activeFlag.get()) { + while (!dataQueue.isEmpty()) { + String line = null; + try { + line = dataQueue.take(); + String res = line + "\n"; + this.ctx.writeAndFlush(res); + } catch (InterruptedException e) { + LOGGER.info(e.getMessage()); + } } + } } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/server/NettyWebServer.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/server/NettyWebServer.java index 08f251c5e..6463a8a00 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/server/NettyWebServer.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/server/NettyWebServer.java @@ -19,6 +19,13 @@ package org.apache.geaflow.dsl.connector.socket.server; +import java.util.List; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelFuture; @@ -40,168 +47,166 @@ import io.netty.handler.codec.string.StringDecoder; import io.netty.handler.codec.string.StringEncoder; import io.netty.util.CharsetUtil; -import java.util.List; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.TimeUnit; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class NettyWebServer { - private static final Logger LOGGER = LoggerFactory.getLogger(NettyWebServer.class.getName()); - - public static final String CMD_GET_DATA = "GET_DATA"; + private static final Logger LOGGER = LoggerFactory.getLogger(NettyWebServer.class.getName()); - public static final String CMD_WEB_DISPLAY_DATA = "WEB_DISPLAY_DATA"; + public static final String CMD_GET_DATA = "GET_DATA"; - private final LinkedBlockingQueue inputDataQueue = new LinkedBlockingQueue<>(); + public static final String CMD_WEB_DISPLAY_DATA = "WEB_DISPLAY_DATA"; - private final LinkedBlockingQueue outputDataQueue = new LinkedBlockingQueue<>(); + private final LinkedBlockingQueue inputDataQueue = new LinkedBlockingQueue<>(); - public void bind(int port) { - EventLoopGroup bossGroup = new NioEventLoopGroup(); - EventLoopGroup workerGroup = new NioEventLoopGroup(); + private final LinkedBlockingQueue outputDataQueue = new LinkedBlockingQueue<>(); - try { - ServerBootstrap bootstrap = new ServerBootstrap() - .group(bossGroup, workerGroup) - .channel(NioServerSocketChannel.class) - .childHandler(new NettyServerChannelInitializer()) - .option(ChannelOption.SO_BACKLOG, 500) - .childOption(ChannelOption.SO_KEEPALIVE, true); - - LOGGER.info("Start netty web server, waiting connect."); - // bind port - ChannelFuture future = bootstrap.bind(port).sync(); - // shutdown channel sync - future.channel().closeFuture().sync(); - } catch (Exception e) { - bossGroup.shutdownGracefully(); - workerGroup.shutdownGracefully(); - } - } + public void bind(int port) { + EventLoopGroup bossGroup = new NioEventLoopGroup(); + EventLoopGroup workerGroup = new NioEventLoopGroup(); - private class NettyServerChannelInitializer extends ChannelInitializer { + try { + ServerBootstrap bootstrap = + new ServerBootstrap() + .group(bossGroup, workerGroup) + .channel(NioServerSocketChannel.class) + .childHandler(new NettyServerChannelInitializer()) + .option(ChannelOption.SO_BACKLOG, 500) + .childOption(ChannelOption.SO_KEEPALIVE, true); - @Override - protected void initChannel(SocketChannel channel) throws Exception { - channel.pipeline().addLast("protocolChoose", new ProtocolChooseHandler()); - } + LOGGER.info("Start netty web server, waiting connect."); + // bind port + ChannelFuture future = bootstrap.bind(port).sync(); + // shutdown channel sync + future.channel().closeFuture().sync(); + } catch (Exception e) { + bossGroup.shutdownGracefully(); + workerGroup.shutdownGracefully(); } + } - private class ProtocolChooseHandler extends ByteToMessageDecoder { - - private static final int MAX_LENGTH = 100; - - private static final String WEB_SOCKET_PREFIX = "GET /"; - - @Override - protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { - String protocol = getProtocol(in); - if (protocol.startsWith(WEB_SOCKET_PREFIX)) { - // add web socket protocol handler - ctx.pipeline().addLast("httpCodec", new HttpServerCodec()); - ctx.pipeline().addLast("aggregator", new HttpObjectAggregator(65535)); - ctx.pipeline().addLast("webSocketAggregator", new WebSocketFrameAggregator(65535)); - ctx.pipeline().addLast("protocolHandler", new WebSocketServerProtocolHandler("/")); - ctx.pipeline().addLast("webSocketHandler", new WebSocketHandler()); - } else { - // add socket protocol handler - ByteBuf buf = in.copy(); - buf.resetReaderIndex(); - out.add(buf); - ctx.pipeline().addLast("lineSplit", new LineBasedFrameDecoder(Integer.MAX_VALUE)); - ctx.pipeline().addLast("decoder", new StringDecoder(CharsetUtil.UTF_8)); - ctx.pipeline().addLast("encoder", new StringEncoder(CharsetUtil.UTF_8)); - ctx.pipeline().addLast("socketHandler", new SocketHandler()); - } - - in.resetReaderIndex(); - ctx.pipeline().remove(this.getClass()); - } + private class NettyServerChannelInitializer extends ChannelInitializer { - private String getProtocol(ByteBuf in) { - int length = in.readableBytes(); - if (length > MAX_LENGTH) { - length = MAX_LENGTH; - } - in.markReaderIndex(); - byte[] content = new byte[length]; - in.readBytes(content); - return new String(content); - } + @Override + protected void initChannel(SocketChannel channel) throws Exception { + channel.pipeline().addLast("protocolChoose", new ProtocolChooseHandler()); } - - private class WebSocketHandler extends ChannelInboundHandlerAdapter { - - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - String message = ((TextWebSocketFrame) msg).text(); - if (message.startsWith(NettyWebServer.CMD_WEB_DISPLAY_DATA)) { - LOGGER.info("start display data to web"); - ctx.channel().eventLoop().scheduleWithFixedDelay(new WebDisplayDataTask(ctx), 0, 1, - TimeUnit.SECONDS); - } else { - LOGGER.info("receive data from web: {}", message); - inputDataQueue.put(message); - } - } + } + + private class ProtocolChooseHandler extends ByteToMessageDecoder { + + private static final int MAX_LENGTH = 100; + + private static final String WEB_SOCKET_PREFIX = "GET /"; + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) + throws Exception { + String protocol = getProtocol(in); + if (protocol.startsWith(WEB_SOCKET_PREFIX)) { + // add web socket protocol handler + ctx.pipeline().addLast("httpCodec", new HttpServerCodec()); + ctx.pipeline().addLast("aggregator", new HttpObjectAggregator(65535)); + ctx.pipeline().addLast("webSocketAggregator", new WebSocketFrameAggregator(65535)); + ctx.pipeline().addLast("protocolHandler", new WebSocketServerProtocolHandler("/")); + ctx.pipeline().addLast("webSocketHandler", new WebSocketHandler()); + } else { + // add socket protocol handler + ByteBuf buf = in.copy(); + buf.resetReaderIndex(); + out.add(buf); + ctx.pipeline().addLast("lineSplit", new LineBasedFrameDecoder(Integer.MAX_VALUE)); + ctx.pipeline().addLast("decoder", new StringDecoder(CharsetUtil.UTF_8)); + ctx.pipeline().addLast("encoder", new StringEncoder(CharsetUtil.UTF_8)); + ctx.pipeline().addLast("socketHandler", new SocketHandler()); + } + + in.resetReaderIndex(); + ctx.pipeline().remove(this.getClass()); } - private class SocketHandler extends ChannelInboundHandlerAdapter { - - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - String message = String.valueOf(msg); - if (message.startsWith(NettyWebServer.CMD_GET_DATA)) { - LOGGER.info("start get data from web"); - ctx.channel().eventLoop().submit(new GetDataTask(ctx)); - } else { - LOGGER.info("receive result from engine: {}", message); - outputDataQueue.put(message); - } - } + private String getProtocol(ByteBuf in) { + int length = in.readableBytes(); + if (length > MAX_LENGTH) { + length = MAX_LENGTH; + } + in.markReaderIndex(); + byte[] content = new byte[length]; + in.readBytes(content); + return new String(content); + } + } + + private class WebSocketHandler extends ChannelInboundHandlerAdapter { + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + String message = ((TextWebSocketFrame) msg).text(); + if (message.startsWith(NettyWebServer.CMD_WEB_DISPLAY_DATA)) { + LOGGER.info("start display data to web"); + ctx.channel() + .eventLoop() + .scheduleWithFixedDelay(new WebDisplayDataTask(ctx), 0, 1, TimeUnit.SECONDS); + } else { + LOGGER.info("receive data from web: {}", message); + inputDataQueue.put(message); + } + } + } + + private class SocketHandler extends ChannelInboundHandlerAdapter { + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + String message = String.valueOf(msg); + if (message.startsWith(NettyWebServer.CMD_GET_DATA)) { + LOGGER.info("start get data from web"); + ctx.channel().eventLoop().submit(new GetDataTask(ctx)); + } else { + LOGGER.info("receive result from engine: {}", message); + outputDataQueue.put(message); + } } + } - private class GetDataTask implements Runnable { + private class GetDataTask implements Runnable { - private final ChannelHandlerContext ctx; + private final ChannelHandlerContext ctx; - public GetDataTask(ChannelHandlerContext ctx) { - this.ctx = ctx; - } + public GetDataTask(ChannelHandlerContext ctx) { + this.ctx = ctx; + } - @Override - public void run() { - while (true) { - try { - String data = inputDataQueue.take(); - ctx.channel().writeAndFlush(data + "\n"); - } catch (Exception e) { - LOGGER.error(null, e); - } - } + @Override + public void run() { + while (true) { + try { + String data = inputDataQueue.take(); + ctx.channel().writeAndFlush(data + "\n"); + } catch (Exception e) { + LOGGER.error(null, e); } + } } + } - private class WebDisplayDataTask implements Runnable { + private class WebDisplayDataTask implements Runnable { - private final ChannelHandlerContext ctx; + private final ChannelHandlerContext ctx; - public WebDisplayDataTask(ChannelHandlerContext ctx) { - this.ctx = ctx; - } + public WebDisplayDataTask(ChannelHandlerContext ctx) { + this.ctx = ctx; + } - @Override - public void run() { - while (!outputDataQueue.isEmpty()) { - try { - String data = outputDataQueue.take(); - ctx.channel().writeAndFlush(new TextWebSocketFrame(data)); - } catch (Exception e) { - LOGGER.error(null, e); - } - } + @Override + public void run() { + while (!outputDataQueue.isEmpty()) { + try { + String data = outputDataQueue.take(); + ctx.channel().writeAndFlush(new TextWebSocketFrame(data)); + } catch (Exception e) { + LOGGER.error(null, e); } + } } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/server/SocketServer.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/server/SocketServer.java index 590b762e6..211e7275c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/server/SocketServer.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/main/java/org/apache/geaflow/dsl/connector/socket/server/SocketServer.java @@ -21,23 +21,23 @@ public class SocketServer { - private static final String TERMINAL_SERVER_TYPE = "TERMINAL"; + private static final String TERMINAL_SERVER_TYPE = "TERMINAL"; - private static final String GRAPH_INSIGHT_SERVER_TYPE = "GI"; + private static final String GRAPH_INSIGHT_SERVER_TYPE = "GI"; - public static void main(String[] args) { - int port = 9003; - String serverType = TERMINAL_SERVER_TYPE; - if (args.length > 0) { - port = Integer.parseInt(args[0]); - } - if (args.length > 1) { - serverType = String.valueOf(args[1]); - } - if (serverType.equalsIgnoreCase(GRAPH_INSIGHT_SERVER_TYPE)) { - new NettyWebServer().bind(port); - } else { - new NettyTerminalServer().bind(port); - } + public static void main(String[] args) { + int port = 9003; + String serverType = TERMINAL_SERVER_TYPE; + if (args.length > 0) { + port = Integer.parseInt(args[0]); } + if (args.length > 1) { + serverType = String.valueOf(args[1]); + } + if (serverType.equalsIgnoreCase(GRAPH_INSIGHT_SERVER_TYPE)) { + new NettyWebServer().bind(port); + } else { + new NettyTerminalServer().bind(port); + } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/test/java/SocketTableConnectorTest.java b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/test/java/SocketTableConnectorTest.java index f1f804b10..7a7d60639 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/test/java/SocketTableConnectorTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-connector/geaflow-dsl-connector-socket/src/test/java/SocketTableConnectorTest.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.Optional; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.type.primitive.IntegerType; import org.apache.geaflow.common.type.primitive.StringType; @@ -42,63 +43,67 @@ public class SocketTableConnectorTest { - public NettyTerminalServer setup(int port) throws Exception { - NettyTerminalServer nettyTerminalServer = new NettyTerminalServer(); - new Thread(() -> { - nettyTerminalServer.bind(port); - }).start(); - return nettyTerminalServer; - } + public NettyTerminalServer setup(int port) throws Exception { + NettyTerminalServer nettyTerminalServer = new NettyTerminalServer(); + new Thread( + () -> { + nettyTerminalServer.bind(port); + }) + .start(); + return nettyTerminalServer; + } - @Test - public void testSocketReadWrite() throws Exception { - setup(9003); - SocketTableConnector connector = new SocketTableConnector(); - Assert.assertEquals(connector.getType(), "SOCKET"); - Configuration tableConf = new Configuration(); - tableConf.put(SocketConfigKeys.GEAFLOW_DSL_SOCKET_HOST.getKey(), "localhost"); - tableConf.put(SocketConfigKeys.GEAFLOW_DSL_SOCKET_PORT.getKey(), "9003"); - TableSource tableSource = connector.createSource(tableConf); - Assert.assertEquals(tableSource.getDeserializer(tableConf).getClass(), TextDeserializer.class); - TableSchema sourceSchema = new TableSchema(new TableField("text", StringType.INSTANCE, true)); - tableSource.init(tableConf, sourceSchema); - tableSource.open(new DefaultRuntimeContext(tableConf)); - Assert.assertEquals(tableSource.listPartitions().size(), 1); - Partition partition = tableSource.listPartitions().get(0); + @Test + public void testSocketReadWrite() throws Exception { + setup(9003); + SocketTableConnector connector = new SocketTableConnector(); + Assert.assertEquals(connector.getType(), "SOCKET"); + Configuration tableConf = new Configuration(); + tableConf.put(SocketConfigKeys.GEAFLOW_DSL_SOCKET_HOST.getKey(), "localhost"); + tableConf.put(SocketConfigKeys.GEAFLOW_DSL_SOCKET_PORT.getKey(), "9003"); + TableSource tableSource = connector.createSource(tableConf); + Assert.assertEquals(tableSource.getDeserializer(tableConf).getClass(), TextDeserializer.class); + TableSchema sourceSchema = new TableSchema(new TableField("text", StringType.INSTANCE, true)); + tableSource.init(tableConf, sourceSchema); + tableSource.open(new DefaultRuntimeContext(tableConf)); + Assert.assertEquals(tableSource.listPartitions().size(), 1); + Partition partition = tableSource.listPartitions().get(0); - TableSink tableSink = connector.createSink(tableConf); + TableSink tableSink = connector.createSink(tableConf); - TableSchema sinkSchema = new TableSchema(new TableField("id", IntegerType.INSTANCE, true), + TableSchema sinkSchema = + new TableSchema( + new TableField("id", IntegerType.INSTANCE, true), new TableField("name", StringType.INSTANCE, true)); - tableSink.init(tableConf, sinkSchema); - tableSink.open(new DefaultRuntimeContext(tableConf)); + tableSink.init(tableConf, sinkSchema); + tableSink.open(new DefaultRuntimeContext(tableConf)); - tableSink.write(ObjectRow.create(1, "jim")); - tableSink.finish(); - tableSink.close(); + tableSink.write(ObjectRow.create(1, "jim")); + tableSink.finish(); + tableSink.close(); - try { - tableSource.fetch(partition, Optional.empty(), new AllFetchWindow(1)); - } catch (Exception e) { - Assert.assertEquals(e.getClass(), GeaFlowDSLException.class); - } - tableSource.close(); + try { + tableSource.fetch(partition, Optional.empty(), new AllFetchWindow(1)); + } catch (Exception e) { + Assert.assertEquals(e.getClass(), GeaFlowDSLException.class); } + tableSource.close(); + } - @Test - public void testSocketPartitionAndOffset() { - SocketPartition socketPartition1 = new SocketPartition(new ArrayList<>()); - SocketPartition socketPartition2 = new SocketPartition(new ArrayList<>()); - Assert.assertFalse(socketPartition1.equals(null)); - Assert.assertEquals(socketPartition1.hashCode(), socketPartition2.hashCode()); - Assert.assertEquals(socketPartition1, socketPartition1); - Assert.assertEquals(socketPartition1, socketPartition2); - Assert.assertEquals(socketPartition1.getData().size(), 0); - Assert.assertEquals(socketPartition1.getName(), socketPartition2.getName()); + @Test + public void testSocketPartitionAndOffset() { + SocketPartition socketPartition1 = new SocketPartition(new ArrayList<>()); + SocketPartition socketPartition2 = new SocketPartition(new ArrayList<>()); + Assert.assertFalse(socketPartition1.equals(null)); + Assert.assertEquals(socketPartition1.hashCode(), socketPartition2.hashCode()); + Assert.assertEquals(socketPartition1, socketPartition1); + Assert.assertEquals(socketPartition1, socketPartition2); + Assert.assertEquals(socketPartition1.getData().size(), 0); + Assert.assertEquals(socketPartition1.getName(), socketPartition2.getName()); - SocketOffset socketOffset = new SocketOffset(); - Assert.assertEquals(socketOffset.getOffset(), -1); - Assert.assertEquals(socketOffset.humanReadable(), "None"); - Assert.assertFalse(socketOffset.isTimestamp()); - } + SocketOffset socketOffset = new SocketOffset(); + Assert.assertEquals(socketOffset.getOffset(), -1); + Assert.assertEquals(socketOffset.humanReadable(), "None"); + Assert.assertFalse(socketOffset.isTimestamp()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/EdgeRecordType.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/EdgeRecordType.java index acebd9e08..7ee6ea1da 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/EdgeRecordType.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/EdgeRecordType.java @@ -27,6 +27,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeField; @@ -41,190 +42,237 @@ public class EdgeRecordType extends RelRecordType { - private final boolean hasTimeField; + private final boolean hasTimeField; - private EdgeRecordType(List fields, boolean hasTimeField) { - super(StructKind.PEEK_FIELDS, fields); - this.hasTimeField = hasTimeField; - } + private EdgeRecordType(List fields, boolean hasTimeField) { + super(StructKind.PEEK_FIELDS, fields); + this.hasTimeField = hasTimeField; + } - @Override - public boolean isNullable() { - return true; - } + @Override + public boolean isNullable() { + return true; + } - public static EdgeRecordType createEdgeType(List fields, String srcIdField, - String targetIdField, String timestampField, - RelDataTypeFactory typeFactory) { - - boolean hasMultiSrcIdFields = fields.stream() - .filter(f -> f.getType() instanceof MetaFieldType - && ((MetaFieldType) f.getType()).getMetaField().equals(MetaField.EDGE_SRC_ID)) - .count() > 1; - if (hasMultiSrcIdFields) { - srcIdField = EdgeType.DEFAULT_SRC_ID_NAME; - fields = GraphRecordType.renameMetaField(fields, MetaField.EDGE_SRC_ID, srcIdField); - } - boolean hasMultiTargetIdFields = fields.stream() - .filter(f -> f.getType() instanceof MetaFieldType - && ((MetaFieldType) f.getType()).getMetaField().equals(MetaField.EDGE_TARGET_ID)) - .count() > 1; - if (hasMultiTargetIdFields) { - targetIdField = EdgeType.DEFAULT_TARGET_ID_NAME; - fields = GraphRecordType.renameMetaField(fields, MetaField.EDGE_TARGET_ID, targetIdField); - } - boolean hasMultiTsFields = fields.stream() - .filter(f -> f.getType() instanceof MetaFieldType - && ((MetaFieldType) f.getType()).getMetaField().equals(MetaField.EDGE_TS)) - .count() > 1; - if (hasMultiTsFields) { - timestampField = EdgeType.DEFAULT_TS_NAME; - fields = GraphRecordType.renameMetaField(fields, MetaField.EDGE_TS, timestampField); - } - List reorderFields = reorderFields(fields, srcIdField, targetIdField, timestampField, - typeFactory); - return new EdgeRecordType(reorderFields, timestampField != null); + public static EdgeRecordType createEdgeType( + List fields, + String srcIdField, + String targetIdField, + String timestampField, + RelDataTypeFactory typeFactory) { + + boolean hasMultiSrcIdFields = + fields.stream() + .filter( + f -> + f.getType() instanceof MetaFieldType + && ((MetaFieldType) f.getType()) + .getMetaField() + .equals(MetaField.EDGE_SRC_ID)) + .count() + > 1; + if (hasMultiSrcIdFields) { + srcIdField = EdgeType.DEFAULT_SRC_ID_NAME; + fields = GraphRecordType.renameMetaField(fields, MetaField.EDGE_SRC_ID, srcIdField); + } + boolean hasMultiTargetIdFields = + fields.stream() + .filter( + f -> + f.getType() instanceof MetaFieldType + && ((MetaFieldType) f.getType()) + .getMetaField() + .equals(MetaField.EDGE_TARGET_ID)) + .count() + > 1; + if (hasMultiTargetIdFields) { + targetIdField = EdgeType.DEFAULT_TARGET_ID_NAME; + fields = GraphRecordType.renameMetaField(fields, MetaField.EDGE_TARGET_ID, targetIdField); } + boolean hasMultiTsFields = + fields.stream() + .filter( + f -> + f.getType() instanceof MetaFieldType + && ((MetaFieldType) f.getType()) + .getMetaField() + .equals(MetaField.EDGE_TS)) + .count() + > 1; + if (hasMultiTsFields) { + timestampField = EdgeType.DEFAULT_TS_NAME; + fields = GraphRecordType.renameMetaField(fields, MetaField.EDGE_TS, timestampField); + } + List reorderFields = + reorderFields(fields, srcIdField, targetIdField, timestampField, typeFactory); + return new EdgeRecordType(reorderFields, timestampField != null); + } - public static EdgeRecordType createEdgeType(List fields, RelDataTypeFactory typeFactory) { - String srcIdField = null; - String targetIdField = null; - String timestampField = null; - for (RelDataTypeField field : fields) { - if (field.getType() instanceof MetaFieldType) { - MetaFieldType metaFieldType = (MetaFieldType) field.getType(); - if (metaFieldType.getMetaField() == MetaField.EDGE_SRC_ID) { - srcIdField = field.getName(); - } else if (metaFieldType.getMetaField() == MetaField.EDGE_TARGET_ID) { - targetIdField = field.getName(); - } else if (metaFieldType.getMetaField() == MetaField.EDGE_TS) { - timestampField = field.getName(); - } - } - } - if (srcIdField == null) { - throw new GeaFlowDSLException("Missing source id field"); + public static EdgeRecordType createEdgeType( + List fields, RelDataTypeFactory typeFactory) { + String srcIdField = null; + String targetIdField = null; + String timestampField = null; + for (RelDataTypeField field : fields) { + if (field.getType() instanceof MetaFieldType) { + MetaFieldType metaFieldType = (MetaFieldType) field.getType(); + if (metaFieldType.getMetaField() == MetaField.EDGE_SRC_ID) { + srcIdField = field.getName(); + } else if (metaFieldType.getMetaField() == MetaField.EDGE_TARGET_ID) { + targetIdField = field.getName(); + } else if (metaFieldType.getMetaField() == MetaField.EDGE_TS) { + timestampField = field.getName(); } - if (targetIdField == null) { - throw new GeaFlowDSLException("Missing target id field"); - } - return createEdgeType(fields, srcIdField, targetIdField, timestampField, typeFactory); + } } - - @Override - public SqlTypeName getSqlTypeName() { - return SqlTypeName.EDGE; + if (srcIdField == null) { + throw new GeaFlowDSLException("Missing source id field"); } - - @Override - protected void generateTypeString(StringBuilder sb, boolean withDetail) { - super.generateTypeString(sb.append("Edge: "), withDetail); + if (targetIdField == null) { + throw new GeaFlowDSLException("Missing target id field"); } + return createEdgeType(fields, srcIdField, targetIdField, timestampField, typeFactory); + } - public RelDataTypeField getSrcIdField() { - return fieldList.get(EdgeType.SRC_ID_FIELD_POSITION); - } + @Override + public SqlTypeName getSqlTypeName() { + return SqlTypeName.EDGE; + } - public RelDataTypeField getTargetIdField() { - return fieldList.get(EdgeType.TARGET_ID_FIELD_POSITION); - } + @Override + protected void generateTypeString(StringBuilder sb, boolean withDetail) { + super.generateTypeString(sb.append("Edge: "), withDetail); + } - public RelDataTypeField getLabelField() { - return fieldList.get(EdgeType.LABEL_FIELD_POSITION); - } + public RelDataTypeField getSrcIdField() { + return fieldList.get(EdgeType.SRC_ID_FIELD_POSITION); + } - public Optional getTimestampField() { - if (hasTimeField) { - return Optional.of(fieldList.get(EdgeType.TIME_FIELD_POSITION)); - } - return Optional.empty(); + public RelDataTypeField getTargetIdField() { + return fieldList.get(EdgeType.TARGET_ID_FIELD_POSITION); + } + + public RelDataTypeField getLabelField() { + return fieldList.get(EdgeType.LABEL_FIELD_POSITION); + } + + public Optional getTimestampField() { + if (hasTimeField) { + return Optional.of(fieldList.get(EdgeType.TIME_FIELD_POSITION)); } + return Optional.empty(); + } - public int getTimestampIndex() { - if (hasTimeField) { - return EdgeType.TIME_FIELD_POSITION; - } - return -1; + public int getTimestampIndex() { + if (hasTimeField) { + return EdgeType.TIME_FIELD_POSITION; } + return -1; + } - public EdgeRecordType add(String fieldName, RelDataType type, boolean caseSensitive) { - if (type instanceof MetaFieldType) { - type = ((MetaFieldType) type).getType(); - } - List fields = new ArrayList<>(getFieldList()); + public EdgeRecordType add(String fieldName, RelDataType type, boolean caseSensitive) { + if (type instanceof MetaFieldType) { + type = ((MetaFieldType) type).getType(); + } + List fields = new ArrayList<>(getFieldList()); - RelDataTypeField field = getField(fieldName, caseSensitive, false); - if (field != null) { - fields.set(field.getIndex(), new RelDataTypeFieldImpl(fieldName, field.getIndex(), type)); - } else { - fields.add(new RelDataTypeFieldImpl(fieldName, fields.size(), type)); - } - return new EdgeRecordType(fields, hasTimeField); + RelDataTypeField field = getField(fieldName, caseSensitive, false); + if (field != null) { + fields.set(field.getIndex(), new RelDataTypeFieldImpl(fieldName, field.getIndex(), type)); + } else { + fields.add(new RelDataTypeFieldImpl(fieldName, fields.size(), type)); } + return new EdgeRecordType(fields, hasTimeField); + } - private static List reorderFields(List fields, String srcIdField, - String targetIdField, String timestampField, - RelDataTypeFactory typeFactory) { - if (fields == null) { - throw new NullPointerException("fields is null"); - } - List reorderFields = new ArrayList<>(fields.size()); + private static List reorderFields( + List fields, + String srcIdField, + String targetIdField, + String timestampField, + RelDataTypeFactory typeFactory) { + if (fields == null) { + throw new NullPointerException("fields is null"); + } + List reorderFields = new ArrayList<>(fields.size()); - int srcIdIndex = indexOf(fields, srcIdField); - int targetIdIndex = indexOf(fields, targetIdField); - assert srcIdIndex != -1 : "srcIdField:" + srcIdField + " is not found"; - assert targetIdIndex != -1 : "targetIdField:" + targetIdField + " is not found"; + int srcIdIndex = indexOf(fields, srcIdField); + int targetIdIndex = indexOf(fields, targetIdField); + assert srcIdIndex != -1 : "srcIdField:" + srcIdField + " is not found"; + assert targetIdIndex != -1 : "targetIdField:" + targetIdField + " is not found"; - RelDataTypeField srcIdTypeField = fields.get(srcIdIndex); - RelDataTypeField targetIdTypeField = fields.get(targetIdIndex); - // put srcId field. - reorderFields.add(new RelDataTypeFieldImpl(srcIdTypeField.getName(), EdgeType.SRC_ID_FIELD_POSITION, + RelDataTypeField srcIdTypeField = fields.get(srcIdIndex); + RelDataTypeField targetIdTypeField = fields.get(targetIdIndex); + // put srcId field. + reorderFields.add( + new RelDataTypeFieldImpl( + srcIdTypeField.getName(), + EdgeType.SRC_ID_FIELD_POSITION, edgeSrcId(srcIdTypeField.getType(), typeFactory))); - // put targetId field. - reorderFields.add(new RelDataTypeFieldImpl(targetIdTypeField.getName(), EdgeType.TARGET_ID_FIELD_POSITION, + // put targetId field. + reorderFields.add( + new RelDataTypeFieldImpl( + targetIdTypeField.getName(), + EdgeType.TARGET_ID_FIELD_POSITION, edgeTargetId(targetIdTypeField.getType(), typeFactory))); - // put label field. - reorderFields.add(new RelDataTypeFieldImpl(GraphSchema.LABEL_FIELD_NAME, EdgeType.LABEL_FIELD_POSITION, + // put label field. + reorderFields.add( + new RelDataTypeFieldImpl( + GraphSchema.LABEL_FIELD_NAME, + EdgeType.LABEL_FIELD_POSITION, edgeType(typeFactory.createSqlType(SqlTypeName.VARCHAR), typeFactory))); - // put ts field if it has defined. - int tsIndex = indexOf(fields, timestampField); - if (tsIndex != -1) { - RelDataTypeField tsTypeField = fields.get(tsIndex); - reorderFields.add(new RelDataTypeFieldImpl(tsTypeField.getName(), EdgeType.TIME_FIELD_POSITION, - edgeTs(tsTypeField.getType(), typeFactory))); - } - int labelIndex = indexOf(fields, GraphSchema.LABEL_FIELD_NAME); - // put other fields by order exclude ~label. - for (int k = 0; k < fields.size(); k++) { - RelDataTypeField field = fields.get(k); - if (k != srcIdIndex && k != targetIdIndex && k != labelIndex && k != tsIndex) { - reorderFields.add(new RelDataTypeFieldImpl(field.getName(), reorderFields.size(), field.getType())); - } - } - return reorderFields; + // put ts field if it has defined. + int tsIndex = indexOf(fields, timestampField); + if (tsIndex != -1) { + RelDataTypeField tsTypeField = fields.get(tsIndex); + reorderFields.add( + new RelDataTypeFieldImpl( + tsTypeField.getName(), + EdgeType.TIME_FIELD_POSITION, + edgeTs(tsTypeField.getType(), typeFactory))); + } + int labelIndex = indexOf(fields, GraphSchema.LABEL_FIELD_NAME); + // put other fields by order exclude ~label. + for (int k = 0; k < fields.size(); k++) { + RelDataTypeField field = fields.get(k); + if (k != srcIdIndex && k != targetIdIndex && k != labelIndex && k != tsIndex) { + reorderFields.add( + new RelDataTypeFieldImpl(field.getName(), reorderFields.size(), field.getType())); + } } + return reorderFields; + } - static int indexOf(List fields, String name) { - if (name == null) { - return -1; - } - for (int i = 0; i < fields.size(); i++) { - if (fields.get(i).getName().equalsIgnoreCase(name)) { - return i; - } - } - return -1; + static int indexOf(List fields, String name) { + if (name == null) { + return -1; } + for (int i = 0; i < fields.size(); i++) { + if (fields.get(i).getName().equalsIgnoreCase(name)) { + return i; + } + } + return -1; + } - public static EdgeRecordType emptyEdgeType(RelDataType idType, RelDataTypeFactory typeFactory) { - List fields = new ArrayList<>(); - fields.add(new RelDataTypeFieldImpl(EdgeType.DEFAULT_SRC_ID_NAME, - EdgeType.SRC_ID_FIELD_POSITION, MetaFieldType.edgeSrcId(idType, typeFactory))); - fields.add(new RelDataTypeFieldImpl(EdgeType.DEFAULT_TARGET_ID_NAME, - EdgeType.TARGET_ID_FIELD_POSITION, MetaFieldType.edgeTargetId(idType, typeFactory))); - fields.add(new RelDataTypeFieldImpl(GraphSchema.LABEL_FIELD_NAME, EdgeType.LABEL_FIELD_POSITION, + public static EdgeRecordType emptyEdgeType(RelDataType idType, RelDataTypeFactory typeFactory) { + List fields = new ArrayList<>(); + fields.add( + new RelDataTypeFieldImpl( + EdgeType.DEFAULT_SRC_ID_NAME, + EdgeType.SRC_ID_FIELD_POSITION, + MetaFieldType.edgeSrcId(idType, typeFactory))); + fields.add( + new RelDataTypeFieldImpl( + EdgeType.DEFAULT_TARGET_ID_NAME, + EdgeType.TARGET_ID_FIELD_POSITION, + MetaFieldType.edgeTargetId(idType, typeFactory))); + fields.add( + new RelDataTypeFieldImpl( + GraphSchema.LABEL_FIELD_NAME, + EdgeType.LABEL_FIELD_POSITION, typeFactory.createSqlType(SqlTypeName.VARCHAR))); - return EdgeRecordType.createEdgeType(fields, EdgeType.DEFAULT_SRC_ID_NAME, - EdgeType.DEFAULT_TARGET_ID_NAME, null, typeFactory); - } + return EdgeRecordType.createEdgeType( + fields, EdgeType.DEFAULT_SRC_ID_NAME, EdgeType.DEFAULT_TARGET_ID_NAME, null, typeFactory); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/GraphRecordType.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/GraphRecordType.java index b47a9165c..36978fdc8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/GraphRecordType.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/GraphRecordType.java @@ -29,6 +29,7 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeField; @@ -43,274 +44,277 @@ public class GraphRecordType extends RelRecordType { - private static final Set META_FIELD_NAMES = new HashSet() { + private static final Set META_FIELD_NAMES = + new HashSet() { { - add(VertexType.DEFAULT_ID_FIELD_NAME.toUpperCase(Locale.ROOT)); - add(EdgeType.DEFAULT_SRC_ID_NAME.toUpperCase(Locale.ROOT)); - add(EdgeType.DEFAULT_TARGET_ID_NAME.toUpperCase(Locale.ROOT)); - add(EdgeType.DEFAULT_LABEL_NAME.toUpperCase(Locale.ROOT)); - add(EdgeType.DEFAULT_TS_NAME.toUpperCase(Locale.ROOT)); + add(VertexType.DEFAULT_ID_FIELD_NAME.toUpperCase(Locale.ROOT)); + add(EdgeType.DEFAULT_SRC_ID_NAME.toUpperCase(Locale.ROOT)); + add(EdgeType.DEFAULT_TARGET_ID_NAME.toUpperCase(Locale.ROOT)); + add(EdgeType.DEFAULT_LABEL_NAME.toUpperCase(Locale.ROOT)); + add(EdgeType.DEFAULT_TS_NAME.toUpperCase(Locale.ROOT)); } - }; + }; - private final String graphName; + private final String graphName; - public static void validateFieldName(String name) { - if (META_FIELD_NAMES.contains(name.toUpperCase(Locale.ROOT))) { - throw new GeaFlowDSLException("Field {} cannot use in graph as field name.", name); - } + public static void validateFieldName(String name) { + if (META_FIELD_NAMES.contains(name.toUpperCase(Locale.ROOT))) { + throw new GeaFlowDSLException("Field {} cannot use in graph as field name.", name); } + } - public GraphRecordType(String graphName, List fields) { - super(StructKind.PEEK_FIELDS, fields); - this.graphName = Objects.requireNonNull(graphName); - } + public GraphRecordType(String graphName, List fields) { + super(StructKind.PEEK_FIELDS, fields); + this.graphName = Objects.requireNonNull(graphName); + } - public RelDataTypeField getField(List fields, boolean caseSensitive) { - RelDataTypeField field = super.getField(fields.get(0), caseSensitive, false); - if (field != null) { - RelDataType fieldType = field.getType(); - for (int i = 1; i < fields.size(); i++) { - field = fieldType.getField(fields.get(i), caseSensitive, false); - if (field == null) { - return null; - } - fieldType = field.getType(); - } + public RelDataTypeField getField(List fields, boolean caseSensitive) { + RelDataTypeField field = super.getField(fields.get(0), caseSensitive, false); + if (field != null) { + RelDataType fieldType = field.getType(); + for (int i = 1; i < fields.size(); i++) { + field = fieldType.getField(fields.get(i), caseSensitive, false); + if (field == null) { + return null; } - return field; + fieldType = field.getType(); + } } + return field; + } - @Override - protected void generateTypeString(StringBuilder sb, boolean withDetail) { - sb.append("Graph:"); - super.generateTypeString(sb, withDetail); - } + @Override + protected void generateTypeString(StringBuilder sb, boolean withDetail) { + sb.append("Graph:"); + super.generateTypeString(sb, withDetail); + } - @Override - public SqlTypeName getSqlTypeName() { - return SqlTypeName.GRAPH; - } + @Override + public SqlTypeName getSqlTypeName() { + return SqlTypeName.GRAPH; + } - public String getGraphName() { - return graphName; - } + public String getGraphName() { + return graphName; + } - public VertexRecordType getVertexType(Collection vertexTypes, RelDataTypeFactory typeFactory) { - for (String vertexType : vertexTypes) { - boolean exist = false; - for (RelDataTypeField field : getFieldList()) { - if (field.getType() instanceof VertexRecordType) { - if (field.getName().equals(vertexType)) { - exist = true; - break; - } - } - } - if (!exist) { - throw new GeaFlowDSLException("Cannot find vertex type: '" + vertexType - + "'."); - } + public VertexRecordType getVertexType( + Collection vertexTypes, RelDataTypeFactory typeFactory) { + for (String vertexType : vertexTypes) { + boolean exist = false; + for (RelDataTypeField field : getFieldList()) { + if (field.getType() instanceof VertexRecordType) { + if (field.getName().equals(vertexType)) { + exist = true; + break; + } } + } + if (!exist) { + throw new GeaFlowDSLException("Cannot find vertex type: '" + vertexType + "'."); + } + } - List vertexTables = new ArrayList<>(); - for (RelDataTypeField field : getFieldList()) { - if (field.getType() instanceof VertexRecordType) { - if (vertexTypes.isEmpty() || vertexTypes.contains(field.getName())) { - vertexTables.add(field); - } - } + List vertexTables = new ArrayList<>(); + for (RelDataTypeField field : getFieldList()) { + if (field.getType() instanceof VertexRecordType) { + if (vertexTypes.isEmpty() || vertexTypes.contains(field.getName())) { + vertexTables.add(field); } + } + } - //Check all vertex tables to be merged have the same ID field name and type - String chkVertexId = null; - RelDataType chkVertexIdType = null; - for (RelDataTypeField field : vertexTables) { - VertexRecordType vertexType = (VertexRecordType) field.getType(); - if (chkVertexId == null) { - chkVertexId = vertexType.getIdField().getName(); - chkVertexIdType = vertexType.getIdField().getType(); - } else { - if (!vertexType.getIdField().getType().equals(chkVertexIdType)) { - throw new GeaFlowDSLException("Id field type should be same between vertex " - + "tables"); - } - } + // Check all vertex tables to be merged have the same ID field name and type + String chkVertexId = null; + RelDataType chkVertexIdType = null; + for (RelDataTypeField field : vertexTables) { + VertexRecordType vertexType = (VertexRecordType) field.getType(); + if (chkVertexId == null) { + chkVertexId = vertexType.getIdField().getName(); + chkVertexIdType = vertexType.getIdField().getType(); + } else { + if (!vertexType.getIdField().getType().equals(chkVertexIdType)) { + throw new GeaFlowDSLException("Id field type should be same between vertex " + "tables"); } + } + } - Map existFields = new HashMap<>(); - List combineFields = new ArrayList<>(); - String idField = null; - for (RelDataTypeField field : vertexTables) { - VertexRecordType vertexType = (VertexRecordType) field.getType(); - idField = vertexType.getIdField().getName(); + Map existFields = new HashMap<>(); + List combineFields = new ArrayList<>(); + String idField = null; + for (RelDataTypeField field : vertexTables) { + VertexRecordType vertexType = (VertexRecordType) field.getType(); + idField = vertexType.getIdField().getName(); - List vertexFields = vertexType.getFieldList(); - for (RelDataTypeField vertexField : vertexFields) { - if (existFields.containsKey(vertexField.getName())) { - // The field with the same name and type between vertex tables - // will be merged as one field. It's illegal for different types - // of same name fields between vertex tables. - if (!existFields.get(vertexField.getName()).equals(vertexField.getType())) { - throw new GeaFlowDSLException("Same name field between vertex tables " - + "shouldn't have different type."); - } - } else { - existFields.put(vertexField.getName(), vertexField.getType()); - combineFields.add(vertexField); - } - } + List vertexFields = vertexType.getFieldList(); + for (RelDataTypeField vertexField : vertexFields) { + if (existFields.containsKey(vertexField.getName())) { + // The field with the same name and type between vertex tables + // will be merged as one field. It's illegal for different types + // of same name fields between vertex tables. + if (!existFields.get(vertexField.getName()).equals(vertexField.getType())) { + throw new GeaFlowDSLException( + "Same name field between vertex tables " + "shouldn't have different type."); + } + } else { + existFields.put(vertexField.getName(), vertexField.getType()); + combineFields.add(vertexField); } - - return VertexRecordType.createVertexType(combineFields, idField, typeFactory); + } } - public EdgeRecordType getEdgeType(Collection edgeTypes, RelDataTypeFactory typeFactory) { - for (String edgeType : edgeTypes) { - boolean exist = false; - for (RelDataTypeField field : getFieldList()) { - if (field.getType() instanceof EdgeRecordType) { - if (field.getName().equals(edgeType)) { - exist = true; - break; - } - } - } - if (!exist) { - throw new GeaFlowDSLException("Cannot find edge type: '" + edgeType - + "'."); - } + return VertexRecordType.createVertexType(combineFields, idField, typeFactory); + } + + public EdgeRecordType getEdgeType(Collection edgeTypes, RelDataTypeFactory typeFactory) { + for (String edgeType : edgeTypes) { + boolean exist = false; + for (RelDataTypeField field : getFieldList()) { + if (field.getType() instanceof EdgeRecordType) { + if (field.getName().equals(edgeType)) { + exist = true; + break; + } } - List edgeTables = new ArrayList<>(); - for (RelDataTypeField field : getFieldList()) { - if (field.getType() instanceof EdgeRecordType) { - if (edgeTypes.isEmpty() || edgeTypes.contains(field.getName())) { - edgeTables.add(field); - } - } + } + if (!exist) { + throw new GeaFlowDSLException("Cannot find edge type: '" + edgeType + "'."); + } + } + List edgeTables = new ArrayList<>(); + for (RelDataTypeField field : getFieldList()) { + if (field.getType() instanceof EdgeRecordType) { + if (edgeTypes.isEmpty() || edgeTypes.contains(field.getName())) { + edgeTables.add(field); } - // Check all edge tables to be merged have the same SOURCE ID / DESTINATION ID - // / TIMESTAMP field name and type - String chkSourceId = null; - String chkDestinationId = null; - String chkTimestamp = null; - Boolean definedTimestamp = null; - RelDataType chkSourceIdType = null; - RelDataType chkDestinationIdType = null; - RelDataType chkTimestampType = null; - for (RelDataTypeField field : edgeTables) { - EdgeRecordType edgeType = (EdgeRecordType) field.getType(); - if (chkSourceId == null) { - chkSourceId = edgeType.getSrcIdField().getName(); - chkSourceIdType = edgeType.getSrcIdField().getType(); - } else { - if (!edgeType.getSrcIdField().getType().equals(chkSourceIdType)) { - throw new GeaFlowDSLException("SOURCE ID field type should be same between edge " - + "tables"); - } - } - if (chkDestinationId == null) { - chkDestinationId = edgeType.getTargetIdField().getName(); - chkDestinationIdType = edgeType.getTargetIdField().getType(); - } else { - if (!edgeType.getTargetIdField().getType().equals(chkDestinationIdType)) { - throw new GeaFlowDSLException("DESTINATION ID field type should be same " - + "between edge tables"); - } - } - if (definedTimestamp == null) { - definedTimestamp = edgeType.getTimestampField().isPresent(); - } else if (definedTimestamp != edgeType.getTimestampField().isPresent()) { - throw new GeaFlowDSLException("TIMESTAMP should defined or not defined in all edge tables"); - } - if (definedTimestamp) { - if (chkTimestamp == null) { - chkTimestamp = edgeType.getTimestampField().get().getName(); - chkTimestampType = edgeType.getTimestampField().get().getType(); - } else { - if (!edgeType.getTimestampField().get().getType().equals(chkTimestampType)) { - throw new GeaFlowDSLException("TIMESTAMP field type should be same between edge " - + "tables"); - } - } - } + } + } + // Check all edge tables to be merged have the same SOURCE ID / DESTINATION ID + // / TIMESTAMP field name and type + String chkSourceId = null; + String chkDestinationId = null; + String chkTimestamp = null; + Boolean definedTimestamp = null; + RelDataType chkSourceIdType = null; + RelDataType chkDestinationIdType = null; + RelDataType chkTimestampType = null; + for (RelDataTypeField field : edgeTables) { + EdgeRecordType edgeType = (EdgeRecordType) field.getType(); + if (chkSourceId == null) { + chkSourceId = edgeType.getSrcIdField().getName(); + chkSourceIdType = edgeType.getSrcIdField().getType(); + } else { + if (!edgeType.getSrcIdField().getType().equals(chkSourceIdType)) { + throw new GeaFlowDSLException( + "SOURCE ID field type should be same between edge " + "tables"); + } + } + if (chkDestinationId == null) { + chkDestinationId = edgeType.getTargetIdField().getName(); + chkDestinationIdType = edgeType.getTargetIdField().getType(); + } else { + if (!edgeType.getTargetIdField().getType().equals(chkDestinationIdType)) { + throw new GeaFlowDSLException( + "DESTINATION ID field type should be same " + "between edge tables"); } + } + if (definedTimestamp == null) { + definedTimestamp = edgeType.getTimestampField().isPresent(); + } else if (definedTimestamp != edgeType.getTimestampField().isPresent()) { + throw new GeaFlowDSLException("TIMESTAMP should defined or not defined in all edge tables"); + } + if (definedTimestamp) { + if (chkTimestamp == null) { + chkTimestamp = edgeType.getTimestampField().get().getName(); + chkTimestampType = edgeType.getTimestampField().get().getType(); + } else { + if (!edgeType.getTimestampField().get().getType().equals(chkTimestampType)) { + throw new GeaFlowDSLException( + "TIMESTAMP field type should be same between edge " + "tables"); + } + } + } + } - Map existFields = new HashMap<>(); - List combineFields = new ArrayList<>(); - String srcIdField = null; - String targetField = null; - String tsField = null; + Map existFields = new HashMap<>(); + List combineFields = new ArrayList<>(); + String srcIdField = null; + String targetField = null; + String tsField = null; - for (RelDataTypeField field : edgeTables) { - EdgeRecordType edgeType = (EdgeRecordType) field.getType(); - srcIdField = edgeType.getSrcIdField().getName(); - targetField = edgeType.getTargetIdField().getName(); - tsField = edgeType.getTimestampField().map(RelDataTypeField::getName).orElse(null); + for (RelDataTypeField field : edgeTables) { + EdgeRecordType edgeType = (EdgeRecordType) field.getType(); + srcIdField = edgeType.getSrcIdField().getName(); + targetField = edgeType.getTargetIdField().getName(); + tsField = edgeType.getTimestampField().map(RelDataTypeField::getName).orElse(null); - List edgeFields = edgeType.getFieldList(); - for (RelDataTypeField edgeField : edgeFields) { - if (existFields.containsKey(edgeField.getName())) { - // The field with the same name and type between edge tables - // will be merged as one field. It's illegal for different types - // of same name fields between edge tables. - if (!existFields.get(edgeField.getName()).equals(edgeField.getType())) { - throw new GeaFlowDSLException("Same name field between edge tables " - + "shouldn't have different type."); - } - } else { - existFields.put(edgeField.getName(), edgeField.getType()); - combineFields.add(edgeField); - } - } + List edgeFields = edgeType.getFieldList(); + for (RelDataTypeField edgeField : edgeFields) { + if (existFields.containsKey(edgeField.getName())) { + // The field with the same name and type between edge tables + // will be merged as one field. It's illegal for different types + // of same name fields between edge tables. + if (!existFields.get(edgeField.getName()).equals(edgeField.getType())) { + throw new GeaFlowDSLException( + "Same name field between edge tables " + "shouldn't have different type."); + } + } else { + existFields.put(edgeField.getName(), edgeField.getType()); + combineFields.add(edgeField); } - - return EdgeRecordType.createEdgeType(combineFields, srcIdField, targetField, tsField, typeFactory); + } } - public static List renameMetaField(List fields, - MetaField metaType, - String newFieldName) { - List metaFields = new ArrayList<>(); - return fields.stream().filter(f -> { - if (f.getType() instanceof MetaFieldType - && ((MetaFieldType) f.getType()).getMetaField().equals(metaType)) { + return EdgeRecordType.createEdgeType( + combineFields, srcIdField, targetField, tsField, typeFactory); + } + + public static List renameMetaField( + List fields, MetaField metaType, String newFieldName) { + List metaFields = new ArrayList<>(); + return fields.stream() + .filter( + f -> { + if (f.getType() instanceof MetaFieldType + && ((MetaFieldType) f.getType()).getMetaField().equals(metaType)) { if (metaFields.isEmpty()) { - metaFields.add(f); - return true; + metaFields.add(f); + return true; } return false; - } - return true; - }).map(f -> { - if (f.getType() instanceof MetaFieldType - && ((MetaFieldType) f.getType()).getMetaField().equals(metaType)) { + } + return true; + }) + .map( + f -> { + if (f.getType() instanceof MetaFieldType + && ((MetaFieldType) f.getType()).getMetaField().equals(metaType)) { return new RelDataTypeFieldImpl(newFieldName, f.getIndex(), f.getType()); - } else { + } else { return f; - } - }).collect(Collectors.toList()); - } + } + }) + .collect(Collectors.toList()); + } - /** - * Copy the graph type and add a vertex field to all the vertex tables. - * - * @param fieldName The added field name. - * @param fieldType The added field type. - */ - public GraphRecordType addVertexField(String fieldName, RelDataType fieldType) { - List fields = getFieldList(); - List newFields = new ArrayList<>(); - for (RelDataTypeField field : fields) { - if (field.getType().getSqlTypeName() == SqlTypeName.VERTEX) { - VertexRecordType vertexRecordType = (VertexRecordType) field.getType(); - VertexRecordType newVertexType = vertexRecordType.add(fieldName, - fieldType, false); - newFields.add(new RelDataTypeFieldImpl(field.getName(), field.getIndex(), newVertexType)); - } else { - newFields.add(field); - } - } - return new GraphRecordType(graphName, newFields); + /** + * Copy the graph type and add a vertex field to all the vertex tables. + * + * @param fieldName The added field name. + * @param fieldType The added field type. + */ + public GraphRecordType addVertexField(String fieldName, RelDataType fieldType) { + List fields = getFieldList(); + List newFields = new ArrayList<>(); + for (RelDataTypeField field : fields) { + if (field.getType().getSqlTypeName() == SqlTypeName.VERTEX) { + VertexRecordType vertexRecordType = (VertexRecordType) field.getType(); + VertexRecordType newVertexType = vertexRecordType.add(fieldName, fieldType, false); + newFields.add(new RelDataTypeFieldImpl(field.getName(), field.getIndex(), newVertexType)); + } else { + newFields.add(field); + } } + return new GraphRecordType(graphName, newFields); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/JoinPathRecordType.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/JoinPathRecordType.java index 4164e0b10..862c45713 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/JoinPathRecordType.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/JoinPathRecordType.java @@ -20,23 +20,24 @@ package org.apache.geaflow.dsl.calcite; import java.util.List; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; public class JoinPathRecordType extends PathRecordType { - public JoinPathRecordType(List fields) { - super(fields); - } + public JoinPathRecordType(List fields) { + super(fields); + } - @Override - public PathRecordType copy(int index, RelDataType newType) { - PathRecordType recordType = super.copy(index, newType); - return new JoinPathRecordType(recordType.getFieldList()); - } + @Override + public PathRecordType copy(int index, RelDataType newType) { + PathRecordType recordType = super.copy(index, newType); + return new JoinPathRecordType(recordType.getFieldList()); + } - @Override - public boolean isSinglePath() { - return false; - } -} \ No newline at end of file + @Override + public boolean isSinglePath() { + return false; + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/MetaFieldType.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/MetaFieldType.java index 884f6129c..8818be8a6 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/MetaFieldType.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/MetaFieldType.java @@ -22,6 +22,7 @@ import java.nio.charset.Charset; import java.util.List; import java.util.Objects; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeComparability; import org.apache.calcite.rel.type.RelDataTypeFactory; @@ -37,203 +38,208 @@ public class MetaFieldType extends RelDataTypeImpl { - private final MetaField metaField; - - private final RelDataType type; - - private MetaFieldType(MetaField metaField, RelDataType type) { - this.metaField = metaField; - this.type = type; - computeDigest(); - } - - public static MetaFieldType vertexId(RelDataType type, RelDataTypeFactory typeFactory) { - return new MetaFieldType(MetaField.VERTEX_ID, typeFactory.createTypeWithNullability(type, false)); - } - - public static MetaFieldType vertexType(RelDataType type, RelDataTypeFactory typeFactory) { - return new MetaFieldType(MetaField.VERTEX_TYPE, typeFactory.createTypeWithNullability(type, false)); - } - - public static MetaFieldType edgeSrcId(RelDataType type, RelDataTypeFactory typeFactory) { - return new MetaFieldType(MetaField.EDGE_SRC_ID, typeFactory.createTypeWithNullability(type, false)); - } - - public static MetaFieldType edgeTargetId(RelDataType type, RelDataTypeFactory typeFactory) { - return new MetaFieldType(MetaField.EDGE_TARGET_ID, typeFactory.createTypeWithNullability(type, false)); - } - - public static MetaFieldType edgeType(RelDataType type, RelDataTypeFactory typeFactory) { - return new MetaFieldType(MetaField.EDGE_TYPE, typeFactory.createTypeWithNullability(type, false)); - } - - public static MetaFieldType edgeTs(RelDataType type, RelDataTypeFactory typeFactory) { - return new MetaFieldType(MetaField.EDGE_TS, typeFactory.createTypeWithNullability(type, false)); - } - - @Override - public void computeDigest() { - if (type instanceof RelDataTypeImpl) { - ((RelDataTypeImpl) type).computeDigest(); - this.digest = ((RelDataTypeImpl) type).getDigest(); - } - } - - @Override - public boolean isStruct() { - return type.isStruct(); - } - - @Override - public List getFieldList() { - return type.getFieldList(); - } - - @Override - public List getFieldNames() { - return type.getFieldNames(); - } - - @Override - public int getFieldCount() { - return type.getFieldCount(); - } - - @Override - public StructKind getStructKind() { - return type.getStructKind(); - } - - @Override - public RelDataTypeField getField(String fieldName, boolean caseSensitive, boolean elideRecord) { - return type.getField(fieldName, caseSensitive, elideRecord); - } - - @Override - public boolean isNullable() { - return false; - } - - @Override - public RelDataType getComponentType() { - return type.getComponentType(); - } - - @Override - public RelDataType getKeyType() { - return type.getKeyType(); - } - - @Override - public RelDataType getValueType() { - return type.getValueType(); - } - - @Override - public Charset getCharset() { - return type.getCharset(); - } - - @Override - public SqlCollation getCollation() { - return type.getCollation(); - } - - @Override - public SqlIntervalQualifier getIntervalQualifier() { - return type.getIntervalQualifier(); - } - - @Override - public int getPrecision() { - return type.getPrecision(); - } - - @Override - public int getScale() { - return type.getScale(); - } - - @Override - public SqlTypeName getSqlTypeName() { - return type.getSqlTypeName(); - } - - @Override - public SqlIdentifier getSqlIdentifier() { - return type.getSqlIdentifier(); - } - - @Override - public String toString() { - return type.toString(); - } - - @Override - public String getFullTypeString() { - return type.toString(); - } - - @Override - public RelDataTypeFamily getFamily() { - return type.getFamily(); - } - - @Override - protected void generateTypeString(StringBuilder sb, boolean withDetail) { - if (type instanceof RelDataTypeImpl) { - ((RelDataTypeImpl) type).generateTypeString2(sb, withDetail); - } else { - sb.append(type.toString()); - } - } - - @Override - public RelDataTypePrecedenceList getPrecedenceList() { - return type.getPrecedenceList(); - } - - @Override - public RelDataTypeComparability getComparability() { - return type.getComparability(); - } - - @Override - public boolean isDynamicStruct() { - return type.isDynamicStruct(); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof RelDataType)) { - return false; - } - RelDataType that = (RelDataType) o; - return Objects.equals(type, that); - } - - @Override - public int hashCode() { - return Objects.hash(type); - } - - public MetaField getMetaField() { - return metaField; - } - - public RelDataType getType() { - return type; - } - - public enum MetaField { - VERTEX_ID, - VERTEX_TYPE, - EDGE_SRC_ID, - EDGE_TARGET_ID, - EDGE_TYPE, - EDGE_TS - } + private final MetaField metaField; + + private final RelDataType type; + + private MetaFieldType(MetaField metaField, RelDataType type) { + this.metaField = metaField; + this.type = type; + computeDigest(); + } + + public static MetaFieldType vertexId(RelDataType type, RelDataTypeFactory typeFactory) { + return new MetaFieldType( + MetaField.VERTEX_ID, typeFactory.createTypeWithNullability(type, false)); + } + + public static MetaFieldType vertexType(RelDataType type, RelDataTypeFactory typeFactory) { + return new MetaFieldType( + MetaField.VERTEX_TYPE, typeFactory.createTypeWithNullability(type, false)); + } + + public static MetaFieldType edgeSrcId(RelDataType type, RelDataTypeFactory typeFactory) { + return new MetaFieldType( + MetaField.EDGE_SRC_ID, typeFactory.createTypeWithNullability(type, false)); + } + + public static MetaFieldType edgeTargetId(RelDataType type, RelDataTypeFactory typeFactory) { + return new MetaFieldType( + MetaField.EDGE_TARGET_ID, typeFactory.createTypeWithNullability(type, false)); + } + + public static MetaFieldType edgeType(RelDataType type, RelDataTypeFactory typeFactory) { + return new MetaFieldType( + MetaField.EDGE_TYPE, typeFactory.createTypeWithNullability(type, false)); + } + + public static MetaFieldType edgeTs(RelDataType type, RelDataTypeFactory typeFactory) { + return new MetaFieldType(MetaField.EDGE_TS, typeFactory.createTypeWithNullability(type, false)); + } + + @Override + public void computeDigest() { + if (type instanceof RelDataTypeImpl) { + ((RelDataTypeImpl) type).computeDigest(); + this.digest = ((RelDataTypeImpl) type).getDigest(); + } + } + + @Override + public boolean isStruct() { + return type.isStruct(); + } + + @Override + public List getFieldList() { + return type.getFieldList(); + } + + @Override + public List getFieldNames() { + return type.getFieldNames(); + } + + @Override + public int getFieldCount() { + return type.getFieldCount(); + } + + @Override + public StructKind getStructKind() { + return type.getStructKind(); + } + + @Override + public RelDataTypeField getField(String fieldName, boolean caseSensitive, boolean elideRecord) { + return type.getField(fieldName, caseSensitive, elideRecord); + } + + @Override + public boolean isNullable() { + return false; + } + + @Override + public RelDataType getComponentType() { + return type.getComponentType(); + } + + @Override + public RelDataType getKeyType() { + return type.getKeyType(); + } + + @Override + public RelDataType getValueType() { + return type.getValueType(); + } + + @Override + public Charset getCharset() { + return type.getCharset(); + } + + @Override + public SqlCollation getCollation() { + return type.getCollation(); + } + + @Override + public SqlIntervalQualifier getIntervalQualifier() { + return type.getIntervalQualifier(); + } + + @Override + public int getPrecision() { + return type.getPrecision(); + } + + @Override + public int getScale() { + return type.getScale(); + } + + @Override + public SqlTypeName getSqlTypeName() { + return type.getSqlTypeName(); + } + + @Override + public SqlIdentifier getSqlIdentifier() { + return type.getSqlIdentifier(); + } + + @Override + public String toString() { + return type.toString(); + } + + @Override + public String getFullTypeString() { + return type.toString(); + } + + @Override + public RelDataTypeFamily getFamily() { + return type.getFamily(); + } + + @Override + protected void generateTypeString(StringBuilder sb, boolean withDetail) { + if (type instanceof RelDataTypeImpl) { + ((RelDataTypeImpl) type).generateTypeString2(sb, withDetail); + } else { + sb.append(type.toString()); + } + } + + @Override + public RelDataTypePrecedenceList getPrecedenceList() { + return type.getPrecedenceList(); + } + + @Override + public RelDataTypeComparability getComparability() { + return type.getComparability(); + } + + @Override + public boolean isDynamicStruct() { + return type.isDynamicStruct(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof RelDataType)) { + return false; + } + RelDataType that = (RelDataType) o; + return Objects.equals(type, that); + } + + @Override + public int hashCode() { + return Objects.hash(type); + } + + public MetaField getMetaField() { + return metaField; + } + + public RelDataType getType() { + return type; + } + + public enum MetaField { + VERTEX_ID, + VERTEX_TYPE, + EDGE_SRC_ID, + EDGE_TARGET_ID, + EDGE_TYPE, + EDGE_TS + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/PathRecordType.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/PathRecordType.java index 496e0af7e..32fcd7131 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/PathRecordType.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/PathRecordType.java @@ -26,6 +26,7 @@ import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; + import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; @@ -38,108 +39,114 @@ public class PathRecordType extends RelRecordType { - public static final PathRecordType EMPTY = new PathRecordType(Collections.emptyList()); - - public PathRecordType(List fields) { - super(StructKind.PEEK_FIELDS, fields); - } - - @Override - public SqlTypeName getSqlTypeName() { - return SqlTypeName.PATH; - } - - @Override - protected void generateTypeString(StringBuilder sb, boolean withDetail) { - super.generateTypeString(sb.append("Path:"), withDetail); - } + public static final PathRecordType EMPTY = new PathRecordType(Collections.emptyList()); - public PathRecordType copy(int index, RelDataType newType) { - List fields = getFieldList(); - RelDataTypeField field = fields.get(index); + public PathRecordType(List fields) { + super(StructKind.PEEK_FIELDS, fields); + } - if (field.getType().getSqlTypeName() != newType.getSqlTypeName()) { - throw new IllegalArgumentException("Cannot replace field: " + field.getName() - + " with different typename"); - } - RelDataTypeField newField = new RelDataTypeFieldImpl(field.getName(), index, newType); + @Override + public SqlTypeName getSqlTypeName() { + return SqlTypeName.PATH; + } - List newFields = new ArrayList<>(fields); - newFields.set(index, newField); - return new PathRecordType(newFields); - } + @Override + protected void generateTypeString(StringBuilder sb, boolean withDetail) { + super.generateTypeString(sb.append("Path:"), withDetail); + } - public boolean canConcat(PathRecordType other) { - return (this.lastFieldName().equals(other.firstFieldName())) - && this.isSinglePath() && other.isSinglePath(); - } + public PathRecordType copy(int index, RelDataType newType) { + List fields = getFieldList(); + RelDataTypeField field = fields.get(index); - public PathRecordType concat(PathRecordType other, boolean caseSensitive) { - List newFields = new ArrayList<>(getFieldList()); - Set fieldNames = getFieldNames().stream().map(name -> { - if (caseSensitive) { - return name; - } else { - return name.toUpperCase(Locale.ROOT); - } - }).collect(Collectors.toSet()); - - for (RelDataTypeField field : other.getFieldList()) { - String name = caseSensitive ? field.getName() : field.getName().toUpperCase(Locale.ROOT); - if (!fieldNames.contains(name)) { - int index = newFields.size(); - newFields.add(field.copy(index)); - } - } - return new PathRecordType(newFields); + if (field.getType().getSqlTypeName() != newType.getSqlTypeName()) { + throw new IllegalArgumentException( + "Cannot replace field: " + field.getName() + " with different typename"); } - - public PathRecordType join(PathRecordType other, RelDataTypeFactory typeFactory) { - RelDataType joinType = SqlValidatorUtil.deriveJoinRowType(this, other, - JoinRelType.INNER, typeFactory, null, Collections.emptyList()); - return new JoinPathRecordType(joinType.getFieldList()); + RelDataTypeField newField = new RelDataTypeFieldImpl(field.getName(), index, newType); + + List newFields = new ArrayList<>(fields); + newFields.set(index, newField); + return new PathRecordType(newFields); + } + + public boolean canConcat(PathRecordType other) { + return (this.lastFieldName().equals(other.firstFieldName())) + && this.isSinglePath() + && other.isSinglePath(); + } + + public PathRecordType concat(PathRecordType other, boolean caseSensitive) { + List newFields = new ArrayList<>(getFieldList()); + Set fieldNames = + getFieldNames().stream() + .map( + name -> { + if (caseSensitive) { + return name; + } else { + return name.toUpperCase(Locale.ROOT); + } + }) + .collect(Collectors.toSet()); + + for (RelDataTypeField field : other.getFieldList()) { + String name = caseSensitive ? field.getName() : field.getName().toUpperCase(Locale.ROOT); + if (!fieldNames.contains(name)) { + int index = newFields.size(); + newFields.add(field.copy(index)); + } } - - public PathRecordType addField(String name, RelDataType type, boolean caseSensitive) { - List newFields = new ArrayList<>(getFieldList()); - RelDataTypeField field = getField(name, caseSensitive, false); - if (field != null) { - newFields.set(field.getIndex(), new RelDataTypeFieldImpl(name, field.getIndex(), type)); - } else { - newFields.add(new RelDataTypeFieldImpl(name, newFields.size(), type)); - } - return new PathRecordType(newFields); + return new PathRecordType(newFields); + } + + public PathRecordType join(PathRecordType other, RelDataTypeFactory typeFactory) { + RelDataType joinType = + SqlValidatorUtil.deriveJoinRowType( + this, other, JoinRelType.INNER, typeFactory, null, Collections.emptyList()); + return new JoinPathRecordType(joinType.getFieldList()); + } + + public PathRecordType addField(String name, RelDataType type, boolean caseSensitive) { + List newFields = new ArrayList<>(getFieldList()); + RelDataTypeField field = getField(name, caseSensitive, false); + if (field != null) { + newFields.set(field.getIndex(), new RelDataTypeFieldImpl(name, field.getIndex(), type)); + } else { + newFields.add(new RelDataTypeFieldImpl(name, newFields.size(), type)); } + return new PathRecordType(newFields); + } - public Optional firstFieldName() { - if (fieldList.size() == 0) { - return Optional.empty(); - } - return Optional.of(fieldList.get(0).getName()); + public Optional firstFieldName() { + if (fieldList.size() == 0) { + return Optional.empty(); } + return Optional.of(fieldList.get(0).getName()); + } - public Optional lastFieldName() { - if (fieldList.size() == 0) { - return Optional.empty(); - } - return Optional.of(fieldList.get(fieldList.size() - 1).getName()); + public Optional lastFieldName() { + if (fieldList.size() == 0) { + return Optional.empty(); } + return Optional.of(fieldList.get(fieldList.size() - 1).getName()); + } - public Optional firstField() { - if (fieldList.size() == 0) { - return Optional.empty(); - } - return Optional.of(fieldList.get(0)); + public Optional firstField() { + if (fieldList.size() == 0) { + return Optional.empty(); } + return Optional.of(fieldList.get(0)); + } - public Optional lastField() { - if (fieldList.size() == 0) { - return Optional.empty(); - } - return Optional.of(fieldList.get(fieldList.size() - 1)); + public Optional lastField() { + if (fieldList.size() == 0) { + return Optional.empty(); } + return Optional.of(fieldList.get(fieldList.size() - 1)); + } - public boolean isSinglePath() { - return true; - } + public boolean isSinglePath() { + return true; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/UnionPathRecordType.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/UnionPathRecordType.java index 92127a5a8..8d4a74a47 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/UnionPathRecordType.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/UnionPathRecordType.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeField; @@ -33,89 +34,97 @@ public class UnionPathRecordType extends PathRecordType { - private final List inputPathRecordTypes; + private final List inputPathRecordTypes; - public UnionPathRecordType(List unionPaths, RelDataTypeFactory typeFactory) { - super(createUnionTypeFields(unionPaths, typeFactory)); - inputPathRecordTypes = Objects.requireNonNull(unionPaths); - } + public UnionPathRecordType(List unionPaths, RelDataTypeFactory typeFactory) { + super(createUnionTypeFields(unionPaths, typeFactory)); + inputPathRecordTypes = Objects.requireNonNull(unionPaths); + } - public UnionPathRecordType(List fields, List unionPaths) { - super(fields); - inputPathRecordTypes = Objects.requireNonNull(unionPaths); - } - - @Override - public boolean isNullable() { - return true; - } + public UnionPathRecordType(List fields, List unionPaths) { + super(fields); + inputPathRecordTypes = Objects.requireNonNull(unionPaths); + } - private static List createUnionTypeFields(Iterable unionPaths, - RelDataTypeFactory typeFactory) { - List unionFields = new ArrayList<>(); - Map presented = new HashMap<>(); - for (PathRecordType pathType : unionPaths) { - for (int i = 0; i < pathType.getFieldCount(); i++) { - String fieldName = pathType.getFieldList().get(i).getName(); - if (!presented.containsKey(fieldName)) { - RelDataTypeField newUnionField = - pathType.getFieldList().get(i).copy(unionFields.size()); - unionFields.add(newUnionField); - presented.put(fieldName, unionFields.size() - 1); - } else { - //derive union type - RelDataTypeField sameNameField = unionFields.get(presented.get(fieldName)); - List unionFieldsList = sameNameField.getType().getFieldList(); - List newUnionFieldsList = new ArrayList<>(unionFieldsList); - List fieldsList = - pathType.getFieldList().get(i).getType().getFieldList(); - for (RelDataTypeField field : fieldsList) { - boolean found = false; - for (RelDataTypeField unionField : unionFieldsList) { - if (field.getName().equals(unionField.getName())) { - if (!field.getType().equals(unionField.getType())) { - throw new GeaFlowDSLException( - "Encountered ambiguous field with the same name " - + "but different type when generating the Union type. \n" - + "Name: " + field.getName() + "\n" - + "Type: " + field.getType()); - } + @Override + public boolean isNullable() { + return true; + } - found = true; - } - } - if (!found) { - newUnionFieldsList.add(field); - } - } - if (sameNameField.getType().getSqlTypeName() == SqlTypeName.VERTEX) { - unionFields.set(presented.get(fieldName), - new RelDataTypeFieldImpl(fieldName, sameNameField.getIndex(), - VertexRecordType.createVertexType(newUnionFieldsList, typeFactory))); - } else if (sameNameField.getType().getSqlTypeName() == SqlTypeName.EDGE) { - unionFields.set(presented.get(fieldName), - new RelDataTypeFieldImpl(fieldName, sameNameField.getIndex(), - EdgeRecordType.createEdgeType(newUnionFieldsList, typeFactory))); - } else { - throw new IllegalArgumentException("Illegal type: " + sameNameField.getType()); - } + private static List createUnionTypeFields( + Iterable unionPaths, RelDataTypeFactory typeFactory) { + List unionFields = new ArrayList<>(); + Map presented = new HashMap<>(); + for (PathRecordType pathType : unionPaths) { + for (int i = 0; i < pathType.getFieldCount(); i++) { + String fieldName = pathType.getFieldList().get(i).getName(); + if (!presented.containsKey(fieldName)) { + RelDataTypeField newUnionField = pathType.getFieldList().get(i).copy(unionFields.size()); + unionFields.add(newUnionField); + presented.put(fieldName, unionFields.size() - 1); + } else { + // derive union type + RelDataTypeField sameNameField = unionFields.get(presented.get(fieldName)); + List unionFieldsList = sameNameField.getType().getFieldList(); + List newUnionFieldsList = new ArrayList<>(unionFieldsList); + List fieldsList = + pathType.getFieldList().get(i).getType().getFieldList(); + for (RelDataTypeField field : fieldsList) { + boolean found = false; + for (RelDataTypeField unionField : unionFieldsList) { + if (field.getName().equals(unionField.getName())) { + if (!field.getType().equals(unionField.getType())) { + throw new GeaFlowDSLException( + "Encountered ambiguous field with the same name " + + "but different type when generating the Union type. \n" + + "Name: " + + field.getName() + + "\n" + + "Type: " + + field.getType()); } + + found = true; + } + } + if (!found) { + newUnionFieldsList.add(field); } + } + if (sameNameField.getType().getSqlTypeName() == SqlTypeName.VERTEX) { + unionFields.set( + presented.get(fieldName), + new RelDataTypeFieldImpl( + fieldName, + sameNameField.getIndex(), + VertexRecordType.createVertexType(newUnionFieldsList, typeFactory))); + } else if (sameNameField.getType().getSqlTypeName() == SqlTypeName.EDGE) { + unionFields.set( + presented.get(fieldName), + new RelDataTypeFieldImpl( + fieldName, + sameNameField.getIndex(), + EdgeRecordType.createEdgeType(newUnionFieldsList, typeFactory))); + } else { + throw new IllegalArgumentException("Illegal type: " + sameNameField.getType()); + } } - return unionFields; + } } + return unionFields; + } - @Override - public boolean isSinglePath() { - return false; - } + @Override + public boolean isSinglePath() { + return false; + } - @Override - public UnionPathRecordType addField(String name, RelDataType type, boolean caseSensitive) { - throw new GeaFlowDSLException("Illegal call."); - } + @Override + public UnionPathRecordType addField(String name, RelDataType type, boolean caseSensitive) { + throw new GeaFlowDSLException("Illegal call."); + } - public List getInputPathRecordTypes() { - return inputPathRecordTypes; - } + public List getInputPathRecordTypes() { + return inputPathRecordTypes; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/VertexRecordType.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/VertexRecordType.java index 80350cbf7..e25199197 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/VertexRecordType.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/calcite/VertexRecordType.java @@ -22,9 +22,9 @@ import static org.apache.geaflow.dsl.calcite.MetaFieldType.vertexId; import static org.apache.geaflow.dsl.calcite.MetaFieldType.vertexType; -import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.List; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeField; @@ -36,120 +36,144 @@ import org.apache.geaflow.dsl.common.types.GraphSchema; import org.apache.geaflow.dsl.common.types.VertexType; -public class VertexRecordType extends RelRecordType { - - private VertexRecordType(List fields) { - super(StructKind.PEEK_FIELDS, fields); - } - - @Override - public boolean isNullable() { - return true; - } - - public static VertexRecordType createVertexType(List fields, String idField, - RelDataTypeFactory typeFactory) { - boolean hasMultiIdFields = fields.stream().filter(f -> f.getType() instanceof MetaFieldType - && ((MetaFieldType) f.getType()).getMetaField().equals(MetaField.VERTEX_ID)).count() > 1; - if (hasMultiIdFields) { - idField = VertexType.DEFAULT_ID_FIELD_NAME; - fields = GraphRecordType.renameMetaField(fields, MetaField.VERTEX_ID, idField); - } - List reorderFields = reorderFields(fields, idField, typeFactory); - return new VertexRecordType(reorderFields); - } - - public static VertexRecordType createVertexType(List fields, RelDataTypeFactory typeFactory) { - String idField = null; - for (RelDataTypeField field : fields) { - if (field.getType() instanceof MetaFieldType - && ((MetaFieldType) field.getType()).getMetaField() == MetaField.VERTEX_ID) { - idField = field.getName(); - } - } - Preconditions.checkArgument(idField != null, "Missing id field"); - return createVertexType(fields, idField, typeFactory); - } +import com.google.common.base.Preconditions; - @Override - public SqlTypeName getSqlTypeName() { - return SqlTypeName.VERTEX; - } +public class VertexRecordType extends RelRecordType { - @Override - protected void generateTypeString(StringBuilder sb, boolean withDetail) { - super.generateTypeString(sb.append("Vertex:"), withDetail); + private VertexRecordType(List fields) { + super(StructKind.PEEK_FIELDS, fields); + } + + @Override + public boolean isNullable() { + return true; + } + + public static VertexRecordType createVertexType( + List fields, String idField, RelDataTypeFactory typeFactory) { + boolean hasMultiIdFields = + fields.stream() + .filter( + f -> + f.getType() instanceof MetaFieldType + && ((MetaFieldType) f.getType()) + .getMetaField() + .equals(MetaField.VERTEX_ID)) + .count() + > 1; + if (hasMultiIdFields) { + idField = VertexType.DEFAULT_ID_FIELD_NAME; + fields = GraphRecordType.renameMetaField(fields, MetaField.VERTEX_ID, idField); } - - public RelDataTypeField getIdField() { - return fieldList.get(VertexType.ID_FIELD_POSITION); + List reorderFields = reorderFields(fields, idField, typeFactory); + return new VertexRecordType(reorderFields); + } + + public static VertexRecordType createVertexType( + List fields, RelDataTypeFactory typeFactory) { + String idField = null; + for (RelDataTypeField field : fields) { + if (field.getType() instanceof MetaFieldType + && ((MetaFieldType) field.getType()).getMetaField() == MetaField.VERTEX_ID) { + idField = field.getName(); + } } - - public RelDataTypeField getLabelField() { - return fieldList.get(VertexType.LABEL_FIELD_POSITION); + Preconditions.checkArgument(idField != null, "Missing id field"); + return createVertexType(fields, idField, typeFactory); + } + + @Override + public SqlTypeName getSqlTypeName() { + return SqlTypeName.VERTEX; + } + + @Override + protected void generateTypeString(StringBuilder sb, boolean withDetail) { + super.generateTypeString(sb.append("Vertex:"), withDetail); + } + + public RelDataTypeField getIdField() { + return fieldList.get(VertexType.ID_FIELD_POSITION); + } + + public RelDataTypeField getLabelField() { + return fieldList.get(VertexType.LABEL_FIELD_POSITION); + } + + public boolean isId(int index) { + return index == VertexType.ID_FIELD_POSITION; + } + + public VertexRecordType add(String fieldName, RelDataType type, boolean caseSensitive) { + if (type instanceof MetaFieldType) { + type = ((MetaFieldType) type).getType(); } + List fields = new ArrayList<>(getFieldList()); - public boolean isId(int index) { - return index == VertexType.ID_FIELD_POSITION; + RelDataTypeField field = getField(fieldName, caseSensitive, false); + if (field != null) { + fields.set(field.getIndex(), new RelDataTypeFieldImpl(fieldName, field.getIndex(), type)); + } else { + fields.add(new RelDataTypeFieldImpl(fieldName, fields.size(), type)); } + return new VertexRecordType(fields); + } - public VertexRecordType add(String fieldName, RelDataType type, boolean caseSensitive) { - if (type instanceof MetaFieldType) { - type = ((MetaFieldType) type).getType(); - } - List fields = new ArrayList<>(getFieldList()); - - RelDataTypeField field = getField(fieldName, caseSensitive, false); - if (field != null) { - fields.set(field.getIndex(), new RelDataTypeFieldImpl(fieldName, field.getIndex(), type)); - } else { - fields.add(new RelDataTypeFieldImpl(fieldName, fields.size(), type)); - } - return new VertexRecordType(fields); + private static List reorderFields( + List fields, String idField, RelDataTypeFactory typeFactory) { + if (fields == null) { + throw new NullPointerException("fields is null"); } - - private static List reorderFields(List fields, String idField, - RelDataTypeFactory typeFactory) { - if (fields == null) { - throw new NullPointerException("fields is null"); - } - List reorderFields = new ArrayList<>(fields.size()); - int idIndex = EdgeRecordType.indexOf(fields, idField); - assert idIndex != -1 : "idField: " + idField + " is not exist"; - RelDataTypeField idTypeField = fields.get(idIndex); - - // put id field at position 0. - reorderFields.add(new RelDataTypeFieldImpl(idTypeField.getName(), VertexType.ID_FIELD_POSITION, + List reorderFields = new ArrayList<>(fields.size()); + int idIndex = EdgeRecordType.indexOf(fields, idField); + assert idIndex != -1 : "idField: " + idField + " is not exist"; + RelDataTypeField idTypeField = fields.get(idIndex); + + // put id field at position 0. + reorderFields.add( + new RelDataTypeFieldImpl( + idTypeField.getName(), + VertexType.ID_FIELD_POSITION, vertexId(idTypeField.getType(), typeFactory))); - // put label field at position 1. - reorderFields.add(new RelDataTypeFieldImpl(GraphSchema.LABEL_FIELD_NAME, VertexType.LABEL_FIELD_POSITION, + // put label field at position 1. + reorderFields.add( + new RelDataTypeFieldImpl( + GraphSchema.LABEL_FIELD_NAME, + VertexType.LABEL_FIELD_POSITION, vertexType(typeFactory.createSqlType(SqlTypeName.VARCHAR), typeFactory))); - // put other fields by order exclude ~label. - int labelIndex = EdgeRecordType.indexOf(fields, GraphSchema.LABEL_FIELD_NAME); - for (int k = 0; k < fields.size(); k++) { - RelDataTypeField field = fields.get(k); - if (k != labelIndex && k != idIndex) { - reorderFields.add(new RelDataTypeFieldImpl(field.getName(), reorderFields.size(), field.getType())); - } - } - return reorderFields; + // put other fields by order exclude ~label. + int labelIndex = EdgeRecordType.indexOf(fields, GraphSchema.LABEL_FIELD_NAME); + for (int k = 0; k < fields.size(); k++) { + RelDataTypeField field = fields.get(k); + if (k != labelIndex && k != idIndex) { + reorderFields.add( + new RelDataTypeFieldImpl(field.getName(), reorderFields.size(), field.getType())); + } } + return reorderFields; + } - public static class VirtualVertexRecordType extends VertexRecordType { + public static class VirtualVertexRecordType extends VertexRecordType { - public static final String VIRTUAL_ID_FIELD_NAME = "~virtual_id"; + public static final String VIRTUAL_ID_FIELD_NAME = "~virtual_id"; - private VirtualVertexRecordType(List fields) { - super(fields); - } + private VirtualVertexRecordType(List fields) { + super(fields); + } - public static VirtualVertexRecordType of(RelDataTypeFactory typeFactory) { - List fields = new ArrayList<>(); - fields.add(new RelDataTypeFieldImpl(VIRTUAL_ID_FIELD_NAME, - VertexType.ID_FIELD_POSITION, typeFactory.createSqlType(SqlTypeName.ANY))); - fields.add(new RelDataTypeFieldImpl("~label", VertexType.LABEL_FIELD_POSITION, - typeFactory.createSqlType(SqlTypeName.VARCHAR))); - return new VirtualVertexRecordType(fields); - } + public static VirtualVertexRecordType of(RelDataTypeFactory typeFactory) { + List fields = new ArrayList<>(); + fields.add( + new RelDataTypeFieldImpl( + VIRTUAL_ID_FIELD_NAME, + VertexType.ID_FIELD_POSITION, + typeFactory.createSqlType(SqlTypeName.ANY))); + fields.add( + new RelDataTypeFieldImpl( + "~label", + VertexType.LABEL_FIELD_POSITION, + typeFactory.createSqlType(SqlTypeName.VARCHAR))); + return new VirtualVertexRecordType(fields); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlBasicQueryOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlBasicQueryOperator.java index d2af7d01e..85602f8dc 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlBasicQueryOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlBasicQueryOperator.java @@ -24,26 +24,21 @@ public class SqlBasicQueryOperator extends SqlOperator { - private SqlBasicQueryOperator(String name) { - super(name, SqlKind.OTHER, 2, true, ReturnTypes.SCOPE, - null, null); - } + private SqlBasicQueryOperator(String name) { + super(name, SqlKind.OTHER, 2, true, ReturnTypes.SCOPE, null, null); + } - public static SqlBasicQueryOperator of(String name) { - return new SqlBasicQueryOperator(name); - } + public static SqlBasicQueryOperator of(String name) { + return new SqlBasicQueryOperator(name); + } - @Override - public SqlSyntax getSyntax() { - return SqlSyntax.SPECIAL; - } + @Override + public SqlSyntax getSyntax() { + return SqlSyntax.SPECIAL; + } - @Override - public void unparse( - SqlWriter writer, - SqlCall call, - int leftPrec, - int rightPrec) { - call.unparse(writer, leftPrec, rightPrec); - } + @Override + public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) { + call.unparse(writer, leftPrec, rightPrec); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlEdgeConstructOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlEdgeConstructOperator.java index b3ce463d2..6f8807d25 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlEdgeConstructOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlEdgeConstructOperator.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rel.type.RelDataTypeFieldImpl; @@ -38,62 +39,57 @@ public class SqlEdgeConstructOperator extends SqlMultisetValueConstructor { - private final SqlIdentifier[] keyNodes; + private final SqlIdentifier[] keyNodes; - public SqlEdgeConstructOperator(SqlIdentifier[] keyNodes) { - super("EDGE", SqlKind.EDGE_VALUE_CONSTRUCTOR); - this.keyNodes = keyNodes; - } + public SqlEdgeConstructOperator(SqlIdentifier[] keyNodes) { + super("EDGE", SqlKind.EDGE_VALUE_CONSTRUCTOR); + this.keyNodes = keyNodes; + } - @Override - public RelDataType inferReturnType(SqlOperatorBinding opBinding) { - List valuesType = opBinding.collectOperandTypes(); - if (keyNodes.length != valuesType.size()) { - throw new GeaFlowDSLException(String.format("Key size: %s is not equal to the value size: %s at %s", - keyNodes.length, valuesType.size(), keyNodes[0].getParserPosition())); - } - List fields = new ArrayList<>(); - for (int i = 0; i < keyNodes.length; i++) { - String name = keyNodes[i].getSimple(); - RelDataTypeField field = new RelDataTypeFieldImpl(name, i, valuesType.get(i)); - fields.add(field); - } - return EdgeRecordType.createEdgeType(fields, opBinding.getTypeFactory()); + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + List valuesType = opBinding.collectOperandTypes(); + if (keyNodes.length != valuesType.size()) { + throw new GeaFlowDSLException( + String.format( + "Key size: %s is not equal to the value size: %s at %s", + keyNodes.length, valuesType.size(), keyNodes[0].getParserPosition())); } - - @Override - public RelDataType deriveType( - SqlValidator validator, - SqlValidatorScope scope, - SqlCall call) { - for (SqlNode operand : call.getOperandList()) { - RelDataType nodeType = validator.deriveType(scope, operand); - assert nodeType != null; - } - RelDataType type = call.getOperator().validateOperands(validator, scope, call); - SqlValidatorUtil.checkCharsetAndCollateConsistentIfCharType(type); - if (type.getSqlTypeName() != SqlTypeName.EDGE) { - throw new GeaFlowDSLException("Edge construct must return edge type, current is: " - + type + " at " + call.getParserPosition()); - } - return type; + List fields = new ArrayList<>(); + for (int i = 0; i < keyNodes.length; i++) { + String name = keyNodes[i].getSimple(); + RelDataTypeField field = new RelDataTypeFieldImpl(name, i, valuesType.get(i)); + fields.add(field); } + return EdgeRecordType.createEdgeType(fields, opBinding.getTypeFactory()); + } - @Override - public SqlCall createCall( - SqlLiteral functionQualifier, - SqlParserPos pos, - SqlNode... operands) { - pos = pos.plusAll(Arrays.asList(operands)); - return new SqlEdgeConstruct(keyNodes, operands, pos); + @Override + public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) { + for (SqlNode operand : call.getOperandList()) { + RelDataType nodeType = validator.deriveType(scope, operand); + assert nodeType != null; } - - @Override - public void unparse( - SqlWriter writer, - SqlCall call, - int leftPrec, - int rightPrec) { - call.unparse(writer, leftPrec, rightPrec); + RelDataType type = call.getOperator().validateOperands(validator, scope, call); + SqlValidatorUtil.checkCharsetAndCollateConsistentIfCharType(type); + if (type.getSqlTypeName() != SqlTypeName.EDGE) { + throw new GeaFlowDSLException( + "Edge construct must return edge type, current is: " + + type + + " at " + + call.getParserPosition()); } + return type; + } + + @Override + public SqlCall createCall(SqlLiteral functionQualifier, SqlParserPos pos, SqlNode... operands) { + pos = pos.plusAll(Arrays.asList(operands)); + return new SqlEdgeConstruct(keyNodes, operands, pos); + } + + @Override + public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) { + call.unparse(writer, leftPrec, rightPrec); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlFilterOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlFilterOperator.java index bacb6c68b..540164aab 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlFilterOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlFilterOperator.java @@ -26,32 +26,24 @@ public class SqlFilterOperator extends SqlOperator { - public static final SqlFilterOperator INSTANCE = new SqlFilterOperator(); - - private SqlFilterOperator() { - super("MatchEdge", SqlKind.OTHER, 2, true, ReturnTypes.SCOPE, - null, null); - } - - @Override - public SqlCall createCall( - SqlLiteral functionQualifier, - SqlParserPos pos, - SqlNode... operands) { - return new SqlFilterStatement(pos, operands[0], operands[1]); - } - - @Override - public SqlSyntax getSyntax() { - return SqlSyntax.SPECIAL; - } - - @Override - public void unparse( - SqlWriter writer, - SqlCall call, - int leftPrec, - int rightPrec) { - call.unparse(writer, leftPrec, rightPrec); - } + public static final SqlFilterOperator INSTANCE = new SqlFilterOperator(); + + private SqlFilterOperator() { + super("MatchEdge", SqlKind.OTHER, 2, true, ReturnTypes.SCOPE, null, null); + } + + @Override + public SqlCall createCall(SqlLiteral functionQualifier, SqlParserPos pos, SqlNode... operands) { + return new SqlFilterStatement(pos, operands[0], operands[1]); + } + + @Override + public SqlSyntax getSyntax() { + return SqlSyntax.SPECIAL; + } + + @Override + public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) { + call.unparse(writer, leftPrec, rightPrec); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlGraphAlgorithmOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlGraphAlgorithmOperator.java index 8dc81d908..65f9e4c0e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlGraphAlgorithmOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlGraphAlgorithmOperator.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.operator; import java.util.List; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlCallBinding; import org.apache.calcite.sql.SqlKind; @@ -30,33 +31,38 @@ public class SqlGraphAlgorithmOperator extends SqlOperator { - public static final SqlGraphAlgorithmOperator INSTANCE = new SqlGraphAlgorithmOperator(); + public static final SqlGraphAlgorithmOperator INSTANCE = new SqlGraphAlgorithmOperator(); - protected SqlGraphAlgorithmOperator() { - super("Graph Algorithm", SqlKind.GQL_ALGORITHM, 2, false, - ReturnTypes.ARG1, new AlgorithmOperandTypeInfer(), null); - } + protected SqlGraphAlgorithmOperator() { + super( + "Graph Algorithm", + SqlKind.GQL_ALGORITHM, + 2, + false, + ReturnTypes.ARG1, + new AlgorithmOperandTypeInfer(), + null); + } - @Override - public boolean checkOperandTypes( - SqlCallBinding callBinding, - boolean throwOnFailure) { - return true; - } + @Override + public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { + return true; + } - @Override - public SqlSyntax getSyntax() { - return SqlSyntax.SPECIAL; - } + @Override + public SqlSyntax getSyntax() { + return SqlSyntax.SPECIAL; + } - private static class AlgorithmOperandTypeInfer implements SqlOperandTypeInference { + private static class AlgorithmOperandTypeInfer implements SqlOperandTypeInference { - @Override - public void inferOperandTypes(SqlCallBinding callBinding, RelDataType returnType, RelDataType[] operandTypes) { - List callOperandTypes = callBinding.collectOperandTypes(); - for (int i = 0; i < callOperandTypes.size(); i++) { - operandTypes[i] = callOperandTypes.get(i); - } - } + @Override + public void inferOperandTypes( + SqlCallBinding callBinding, RelDataType returnType, RelDataType[] operandTypes) { + List callOperandTypes = callBinding.collectOperandTypes(); + for (int i = 0; i < callOperandTypes.size(); i++) { + operandTypes[i] = callOperandTypes.get(i); + } } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlLambdaOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlLambdaOperator.java index a56c3e524..5b2933063 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlLambdaOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlLambdaOperator.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.operator; import java.util.List; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlCallBinding; import org.apache.calcite.sql.SqlKind; @@ -30,33 +31,31 @@ public class SqlLambdaOperator extends SqlOperator { - public static final SqlLambdaOperator INSTANCE = new SqlLambdaOperator(); + public static final SqlLambdaOperator INSTANCE = new SqlLambdaOperator(); - protected SqlLambdaOperator() { - super("Lambda", SqlKind.OTHER, 2, false, - ReturnTypes.ARG1, new LambdaOperandTypeInfer(), null); - } + protected SqlLambdaOperator() { + super("Lambda", SqlKind.OTHER, 2, false, ReturnTypes.ARG1, new LambdaOperandTypeInfer(), null); + } - @Override - public boolean checkOperandTypes( - SqlCallBinding callBinding, - boolean throwOnFailure) { - return true; - } + @Override + public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { + return true; + } - @Override - public SqlSyntax getSyntax() { - return SqlSyntax.SPECIAL; - } + @Override + public SqlSyntax getSyntax() { + return SqlSyntax.SPECIAL; + } - private static class LambdaOperandTypeInfer implements SqlOperandTypeInference { + private static class LambdaOperandTypeInfer implements SqlOperandTypeInference { - @Override - public void inferOperandTypes(SqlCallBinding callBinding, RelDataType returnType, RelDataType[] operandTypes) { - List callOperandTypes = callBinding.collectOperandTypes(); - for (int i = 0; i < callOperandTypes.size(); i++) { - operandTypes[i] = callOperandTypes.get(i); - } - } + @Override + public void inferOperandTypes( + SqlCallBinding callBinding, RelDataType returnType, RelDataType[] operandTypes) { + List callOperandTypes = callBinding.collectOperandTypes(); + for (int i = 0; i < callOperandTypes.size(); i++) { + operandTypes[i] = callOperandTypes.get(i); + } } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlLetOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlLetOperator.java index 44bb62fc2..412524d94 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlLetOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlLetOperator.java @@ -26,32 +26,24 @@ public class SqlLetOperator extends SqlOperator { - public static final SqlLetOperator INSTANCE = new SqlLetOperator(); - - private SqlLetOperator() { - super("Let", SqlKind.GQL_LET, 2, true, - ReturnTypes.SCOPE, null, null); - } - - @Override - public SqlCall createCall( - SqlLiteral functionQualifier, - SqlParserPos pos, - SqlNode... operands) { - return new SqlLetStatement(pos, operands[0], (SqlIdentifier) operands[1], operands[2], false); - } - - @Override - public SqlSyntax getSyntax() { - return SqlSyntax.SPECIAL; - } - - @Override - public void unparse( - SqlWriter writer, - SqlCall call, - int leftPrec, - int rightPrec) { - call.unparse(writer, leftPrec, rightPrec); - } + public static final SqlLetOperator INSTANCE = new SqlLetOperator(); + + private SqlLetOperator() { + super("Let", SqlKind.GQL_LET, 2, true, ReturnTypes.SCOPE, null, null); + } + + @Override + public SqlCall createCall(SqlLiteral functionQualifier, SqlParserPos pos, SqlNode... operands) { + return new SqlLetStatement(pos, operands[0], (SqlIdentifier) operands[1], operands[2], false); + } + + @Override + public SqlSyntax getSyntax() { + return SqlSyntax.SPECIAL; + } + + @Override + public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) { + call.unparse(writer, leftPrec, rightPrec); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlMatchEdgeOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlMatchEdgeOperator.java index a3b17731a..38da66919 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlMatchEdgeOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlMatchEdgeOperator.java @@ -27,35 +27,34 @@ public class SqlMatchEdgeOperator extends SqlOperator { - public static final SqlMatchEdgeOperator INSTANCE = new SqlMatchEdgeOperator(); - - private SqlMatchEdgeOperator() { - super("MatchEdge", SqlKind.GQL_MATCH_EDGE, 2, true, ReturnTypes.SCOPE, null, null); - } - - @Override - public SqlCall createCall( - SqlLiteral functionQualifier, - SqlParserPos pos, - SqlNode... operands) { - String directionName = operands[3].toString(); - - return new SqlMatchEdge(pos, (SqlIdentifier) operands[0], (SqlNodeList) operands[1], - (SqlNodeList) operands[2], operands[3], - EdgeDirection.of(directionName), 1, 1); - } - - @Override - public SqlSyntax getSyntax() { - return SqlSyntax.SPECIAL; - } - - @Override - public void unparse( - SqlWriter writer, - SqlCall call, - int leftPrec, - int rightPrec) { - call.unparse(writer, leftPrec, rightPrec); - } + public static final SqlMatchEdgeOperator INSTANCE = new SqlMatchEdgeOperator(); + + private SqlMatchEdgeOperator() { + super("MatchEdge", SqlKind.GQL_MATCH_EDGE, 2, true, ReturnTypes.SCOPE, null, null); + } + + @Override + public SqlCall createCall(SqlLiteral functionQualifier, SqlParserPos pos, SqlNode... operands) { + String directionName = operands[3].toString(); + + return new SqlMatchEdge( + pos, + (SqlIdentifier) operands[0], + (SqlNodeList) operands[1], + (SqlNodeList) operands[2], + operands[3], + EdgeDirection.of(directionName), + 1, + 1); + } + + @Override + public SqlSyntax getSyntax() { + return SqlSyntax.SPECIAL; + } + + @Override + public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) { + call.unparse(writer, leftPrec, rightPrec); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlMatchNodeOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlMatchNodeOperator.java index 2d04eeebf..6e4425806 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlMatchNodeOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlMatchNodeOperator.java @@ -26,32 +26,29 @@ public class SqlMatchNodeOperator extends SqlOperator { - public static final SqlMatchNodeOperator INSTANCE = new SqlMatchNodeOperator(); - - private SqlMatchNodeOperator() { - super("MatchNode", SqlKind.GQL_MATCH_NODE, 2, true, ReturnTypes.SCOPE, null, null); - } - - @Override - public SqlCall createCall( - SqlLiteral functionQualifier, - SqlParserPos pos, - SqlNode... operands) { - return new SqlMatchNode(pos, (SqlIdentifier) operands[0], (SqlNodeList) operands[1], - (SqlNodeList) operands[2], operands[3]); - } - - @Override - public SqlSyntax getSyntax() { - return SqlSyntax.SPECIAL; - } - - @Override - public void unparse( - SqlWriter writer, - SqlCall call, - int leftPrec, - int rightPrec) { - call.unparse(writer, leftPrec, rightPrec); - } + public static final SqlMatchNodeOperator INSTANCE = new SqlMatchNodeOperator(); + + private SqlMatchNodeOperator() { + super("MatchNode", SqlKind.GQL_MATCH_NODE, 2, true, ReturnTypes.SCOPE, null, null); + } + + @Override + public SqlCall createCall(SqlLiteral functionQualifier, SqlParserPos pos, SqlNode... operands) { + return new SqlMatchNode( + pos, + (SqlIdentifier) operands[0], + (SqlNodeList) operands[1], + (SqlNodeList) operands[2], + operands[3]); + } + + @Override + public SqlSyntax getSyntax() { + return SqlSyntax.SPECIAL; + } + + @Override + public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) { + call.unparse(writer, leftPrec, rightPrec); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlMatchPatternOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlMatchPatternOperator.java index fd5b6ebd8..ad482b85e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlMatchPatternOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlMatchPatternOperator.java @@ -26,33 +26,30 @@ public class SqlMatchPatternOperator extends SqlOperator { - public static final SqlMatchPatternOperator INSTANCE = new SqlMatchPatternOperator(); - - private SqlMatchPatternOperator() { - super("MatchPattern", SqlKind.OTHER, 2, true, - ReturnTypes.SCOPE, null, null); - } - - @Override - public SqlCall createCall( - SqlLiteral functionQualifier, - SqlParserPos pos, - SqlNode... operands) { - return new SqlMatchPattern(pos, operands[0], (SqlNodeList) operands[1], operands[2], - (SqlNodeList) operands[3], operands[4]); - } - - @Override - public SqlSyntax getSyntax() { - return SqlSyntax.SPECIAL; - } - - @Override - public void unparse( - SqlWriter writer, - SqlCall call, - int leftPrec, - int rightPrec) { - call.unparse(writer, leftPrec, rightPrec); - } + public static final SqlMatchPatternOperator INSTANCE = new SqlMatchPatternOperator(); + + private SqlMatchPatternOperator() { + super("MatchPattern", SqlKind.OTHER, 2, true, ReturnTypes.SCOPE, null, null); + } + + @Override + public SqlCall createCall(SqlLiteral functionQualifier, SqlParserPos pos, SqlNode... operands) { + return new SqlMatchPattern( + pos, + operands[0], + (SqlNodeList) operands[1], + operands[2], + (SqlNodeList) operands[3], + operands[4]); + } + + @Override + public SqlSyntax getSyntax() { + return SqlSyntax.SPECIAL; + } + + @Override + public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) { + call.unparse(writer, leftPrec, rightPrec); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlPathPatternOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlPathPatternOperator.java index fc248a5c0..2899c8bfc 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlPathPatternOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlPathPatternOperator.java @@ -26,32 +26,24 @@ public class SqlPathPatternOperator extends SqlOperator { - public static final SqlPathPatternOperator INSTANCE = new SqlPathPatternOperator(); - - private SqlPathPatternOperator() { - super("MatchNode", SqlKind.GQL_PATH_PATTERN, 2, true, - ReturnTypes.SCOPE, null, null); - } - - @Override - public SqlCall createCall( - SqlLiteral functionQualifier, - SqlParserPos pos, - SqlNode... operands) { - return new SqlPathPattern(pos, (SqlNodeList) operands[0], (SqlIdentifier) operands[1]); - } - - @Override - public SqlSyntax getSyntax() { - return SqlSyntax.SPECIAL; - } - - @Override - public void unparse( - SqlWriter writer, - SqlCall call, - int leftPrec, - int rightPrec) { - call.unparse(writer, leftPrec, rightPrec); - } + public static final SqlPathPatternOperator INSTANCE = new SqlPathPatternOperator(); + + private SqlPathPatternOperator() { + super("MatchNode", SqlKind.GQL_PATH_PATTERN, 2, true, ReturnTypes.SCOPE, null, null); + } + + @Override + public SqlCall createCall(SqlLiteral functionQualifier, SqlParserPos pos, SqlNode... operands) { + return new SqlPathPattern(pos, (SqlNodeList) operands[0], (SqlIdentifier) operands[1]); + } + + @Override + public SqlSyntax getSyntax() { + return SqlSyntax.SPECIAL; + } + + @Override + public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) { + call.unparse(writer, leftPrec, rightPrec); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlReturnOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlReturnOperator.java index 555a283a7..6bf534659 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlReturnOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlReturnOperator.java @@ -26,38 +26,32 @@ public class SqlReturnOperator extends SqlOperator { - public static final SqlReturnOperator INSTANCE = new SqlReturnOperator(); - - private SqlReturnOperator() { - super("Return", SqlKind.OTHER, 2, true, ReturnTypes.SCOPE, null, null); - } - - @Override - public SqlCall createCall( - SqlLiteral functionQualifier, - SqlParserPos pos, - SqlNode... operands) { - return new SqlReturnStatement(pos - , (SqlNodeList) operands[0] - , operands[1] - , (SqlNodeList) operands[2] - , (SqlNodeList) operands[3] - , (SqlNodeList) operands[4] - , operands[5] - , operands[6]); - } - - @Override - public SqlSyntax getSyntax() { - return SqlSyntax.SPECIAL; - } - - @Override - public void unparse( - SqlWriter writer, - SqlCall call, - int leftPrec, - int rightPrec) { - call.unparse(writer, leftPrec, rightPrec); - } + public static final SqlReturnOperator INSTANCE = new SqlReturnOperator(); + + private SqlReturnOperator() { + super("Return", SqlKind.OTHER, 2, true, ReturnTypes.SCOPE, null, null); + } + + @Override + public SqlCall createCall(SqlLiteral functionQualifier, SqlParserPos pos, SqlNode... operands) { + return new SqlReturnStatement( + pos, + (SqlNodeList) operands[0], + operands[1], + (SqlNodeList) operands[2], + (SqlNodeList) operands[3], + (SqlNodeList) operands[4], + operands[5], + operands[6]); + } + + @Override + public SqlSyntax getSyntax() { + return SqlSyntax.SPECIAL; + } + + @Override + public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) { + call.unparse(writer, leftPrec, rightPrec); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlSameOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlSameOperator.java index 699f03e92..9d5995a09 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlSameOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlSameOperator.java @@ -38,41 +38,32 @@ * *

Syntax: SAME(element1, element2, ...) * - *

Returns: BOOLEAN - TRUE if all element references point to the same element, - * FALSE otherwise. + *

Returns: BOOLEAN - TRUE if all element references point to the same element, FALSE otherwise. * *

Implements ISO/IEC 39075:2024 Section 19.12. */ public class SqlSameOperator extends SqlFunction { - public static final SqlSameOperator INSTANCE = new SqlSameOperator(); + public static final SqlSameOperator INSTANCE = new SqlSameOperator(); - private SqlSameOperator() { - super( - "SAME", - SqlKind.OTHER_FUNCTION, - ReturnTypes.BOOLEAN, - null, - // At least 2 operands, all must be of comparable types - OperandTypes.VARIADIC, - SqlFunctionCategory.USER_DEFINED_FUNCTION - ); - } + private SqlSameOperator() { + super( + "SAME", + SqlKind.OTHER_FUNCTION, + ReturnTypes.BOOLEAN, + null, + // At least 2 operands, all must be of comparable types + OperandTypes.VARIADIC, + SqlFunctionCategory.USER_DEFINED_FUNCTION); + } - @Override - public SqlCall createCall( - SqlLiteral functionQualifier, - SqlParserPos pos, - SqlNode... operands) { - return new SqlSameCall(pos, java.util.Arrays.asList(operands)); - } + @Override + public SqlCall createCall(SqlLiteral functionQualifier, SqlParserPos pos, SqlNode... operands) { + return new SqlSameCall(pos, java.util.Arrays.asList(operands)); + } - @Override - public void unparse( - SqlWriter writer, - SqlCall call, - int leftPrec, - int rightPrec) { - call.unparse(writer, leftPrec, rightPrec); - } + @Override + public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) { + call.unparse(writer, leftPrec, rightPrec); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlVertexConstructOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlVertexConstructOperator.java index fbbe80d53..dd234ac29 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlVertexConstructOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/operator/SqlVertexConstructOperator.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rel.type.RelDataTypeFieldImpl; @@ -38,62 +39,57 @@ public class SqlVertexConstructOperator extends SqlMultisetValueConstructor { - private final SqlIdentifier[] fieldNameNodes; + private final SqlIdentifier[] fieldNameNodes; - public SqlVertexConstructOperator(SqlIdentifier[] fieldNameNodes) { - super("VERTEX", SqlKind.VERTEX_VALUE_CONSTRUCTOR); - this.fieldNameNodes = fieldNameNodes; - } + public SqlVertexConstructOperator(SqlIdentifier[] fieldNameNodes) { + super("VERTEX", SqlKind.VERTEX_VALUE_CONSTRUCTOR); + this.fieldNameNodes = fieldNameNodes; + } - @Override - public RelDataType inferReturnType(SqlOperatorBinding opBinding) { - List valuesType = opBinding.collectOperandTypes(); - if (fieldNameNodes.length != valuesType.size()) { - throw new GeaFlowDSLException(String.format("Field name size: %s is not equal to the value size: %s at %s", - fieldNameNodes.length, valuesType.size(), fieldNameNodes[0].getParserPosition())); - } - List fields = new ArrayList<>(); - for (int i = 0; i < fieldNameNodes.length; i++) { - String name = fieldNameNodes[i].getSimple(); - RelDataTypeField field = new RelDataTypeFieldImpl(name, i, valuesType.get(i)); - fields.add(field); - } - return VertexRecordType.createVertexType(fields, opBinding.getTypeFactory()); + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + List valuesType = opBinding.collectOperandTypes(); + if (fieldNameNodes.length != valuesType.size()) { + throw new GeaFlowDSLException( + String.format( + "Field name size: %s is not equal to the value size: %s at %s", + fieldNameNodes.length, valuesType.size(), fieldNameNodes[0].getParserPosition())); } - - @Override - public RelDataType deriveType( - SqlValidator validator, - SqlValidatorScope scope, - SqlCall call) { - for (SqlNode operand : call.getOperandList()) { - RelDataType nodeType = validator.deriveType(scope, operand); - assert nodeType != null; - } - RelDataType type = call.getOperator().validateOperands(validator, scope, call); - SqlValidatorUtil.checkCharsetAndCollateConsistentIfCharType(type); - if (type.getSqlTypeName() != SqlTypeName.VERTEX) { - throw new GeaFlowDSLException("Vertex construct must return vertex type, current is: " - + type + " at " + call.getParserPosition()); - } - return type; + List fields = new ArrayList<>(); + for (int i = 0; i < fieldNameNodes.length; i++) { + String name = fieldNameNodes[i].getSimple(); + RelDataTypeField field = new RelDataTypeFieldImpl(name, i, valuesType.get(i)); + fields.add(field); } + return VertexRecordType.createVertexType(fields, opBinding.getTypeFactory()); + } - @Override - public SqlCall createCall( - SqlLiteral functionQualifier, - SqlParserPos pos, - SqlNode... operands) { - pos = pos.plusAll(Arrays.asList(operands)); - return new SqlVertexConstruct(fieldNameNodes, operands, pos); + @Override + public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) { + for (SqlNode operand : call.getOperandList()) { + RelDataType nodeType = validator.deriveType(scope, operand); + assert nodeType != null; } - - @Override - public void unparse( - SqlWriter writer, - SqlCall call, - int leftPrec, - int rightPrec) { - call.unparse(writer, leftPrec, rightPrec); + RelDataType type = call.getOperator().validateOperands(validator, scope, call); + SqlValidatorUtil.checkCharsetAndCollateConsistentIfCharType(type); + if (type.getSqlTypeName() != SqlTypeName.VERTEX) { + throw new GeaFlowDSLException( + "Vertex construct must return vertex type, current is: " + + type + + " at " + + call.getParserPosition()); } + return type; + } + + @Override + public SqlCall createCall(SqlLiteral functionQualifier, SqlParserPos pos, SqlNode... operands) { + pos = pos.plusAll(Arrays.asList(operands)); + return new SqlVertexConstruct(fieldNameNodes, operands, pos); + } + + @Override + public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) { + call.unparse(writer, leftPrec, rightPrec); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/parser/GQLConformance.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/parser/GQLConformance.java index 476e59510..c36cda41c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/parser/GQLConformance.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/parser/GQLConformance.java @@ -23,105 +23,105 @@ public class GQLConformance implements SqlConformance { - public static GQLConformance INSTANCE = new GQLConformance(); - - @Override - public boolean isLiberal() { - return false; - } - - @Override - public boolean isGroupByAlias() { - return true; - } - - @Override - public boolean isGroupByOrdinal() { - return false; - } - - @Override - public boolean isHavingAlias() { - return false; - } - - @Override - public boolean isPercentRemainderAllowed() { - return true; - } - - @Override - public boolean allowNiladicParentheses() { - return false; - } - - @Override - public boolean allowExplicitRowValueConstructor() { - return false; - } - - @Override - public boolean allowExtend() { - return false; - } - - @Override - public boolean isLimitStartCountAllowed() { - return false; - } - - @Override - public boolean allowGeometry() { - return false; - } - - @Override - public boolean shouldConvertRaggedUnionTypesToVarying() { - return true; - } - - @Override - public boolean allowExtendedTrim() { - return false; - } - - @Override - public boolean isSortByOrdinal() { - return false; - } - - @Override - public boolean isSortByAlias() { - return true; - } - - @Override - public boolean isSortByAliasObscures() { - return false; - } - - @Override - public boolean isFromRequired() { - return false; - } - - @Override - public boolean isBangEqualAllowed() { - return true; - } - - @Override - public boolean isMinusAllowed() { - return false; - } - - @Override - public boolean isApplyAllowed() { - return false; - } - - @Override - public boolean isInsertSubsetColumnsAllowed() { - return false; - } + public static GQLConformance INSTANCE = new GQLConformance(); + + @Override + public boolean isLiberal() { + return false; + } + + @Override + public boolean isGroupByAlias() { + return true; + } + + @Override + public boolean isGroupByOrdinal() { + return false; + } + + @Override + public boolean isHavingAlias() { + return false; + } + + @Override + public boolean isPercentRemainderAllowed() { + return true; + } + + @Override + public boolean allowNiladicParentheses() { + return false; + } + + @Override + public boolean allowExplicitRowValueConstructor() { + return false; + } + + @Override + public boolean allowExtend() { + return false; + } + + @Override + public boolean isLimitStartCountAllowed() { + return false; + } + + @Override + public boolean allowGeometry() { + return false; + } + + @Override + public boolean shouldConvertRaggedUnionTypesToVarying() { + return true; + } + + @Override + public boolean allowExtendedTrim() { + return false; + } + + @Override + public boolean isSortByOrdinal() { + return false; + } + + @Override + public boolean isSortByAlias() { + return true; + } + + @Override + public boolean isSortByAliasObscures() { + return false; + } + + @Override + public boolean isFromRequired() { + return false; + } + + @Override + public boolean isBangEqualAllowed() { + return true; + } + + @Override + public boolean isMinusAllowed() { + return false; + } + + @Override + public boolean isApplyAllowed() { + return false; + } + + @Override + public boolean isInsertSubsetColumnsAllowed() { + return false; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/parser/GeaFlowDSLParser.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/parser/GeaFlowDSLParser.java index 451fa2591..d2f4bbd8e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/parser/GeaFlowDSLParser.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/parser/GeaFlowDSLParser.java @@ -21,6 +21,7 @@ import java.io.StringReader; import java.util.List; + import org.apache.calcite.config.Lex; import org.apache.calcite.runtime.CalciteContextException; import org.apache.calcite.sql.SqlNode; @@ -30,58 +31,58 @@ public class GeaFlowDSLParser { - public static SqlParser.Config PARSER_CONFIG = - SqlParser.configBuilder() - .setLex(Lex.MYSQL) - .setParserFactory(GeaFlowParserImpl.FACTORY) - .setConformance(GQLConformance.INSTANCE) - .build(); + public static SqlParser.Config PARSER_CONFIG = + SqlParser.configBuilder() + .setLex(Lex.MYSQL) + .setParserFactory(GeaFlowParserImpl.FACTORY) + .setConformance(GQLConformance.INSTANCE) + .build(); - public List parseMultiStatement(String sql) throws SqlParseException { - GeaFlowParserImpl parser = createParser(sql); - try { - return parser.MultiStmtEof(); - } catch (Throwable ex) { - if (ex instanceof CalciteContextException) { - ((CalciteContextException) ex).setOriginalStatement(sql); - } - throw parser.normalizeException(ex); - } + public List parseMultiStatement(String sql) throws SqlParseException { + GeaFlowParserImpl parser = createParser(sql); + try { + return parser.MultiStmtEof(); + } catch (Throwable ex) { + if (ex instanceof CalciteContextException) { + ((CalciteContextException) ex).setOriginalStatement(sql); + } + throw parser.normalizeException(ex); } + } - public SqlNode parseStatement(String sql) throws SqlParseException { - GeaFlowParserImpl parser = createParser(sql); - try { - return parser.parseSqlStmtEof(); - } catch (Throwable ex) { - if (ex instanceof CalciteContextException) { - ((CalciteContextException) ex).setOriginalStatement(sql); - } - throw parser.normalizeException(ex); - } + public SqlNode parseStatement(String sql) throws SqlParseException { + GeaFlowParserImpl parser = createParser(sql); + try { + return parser.parseSqlStmtEof(); + } catch (Throwable ex) { + if (ex instanceof CalciteContextException) { + ((CalciteContextException) ex).setOriginalStatement(sql); + } + throw parser.normalizeException(ex); } + } - private GeaFlowParserImpl createParser(String sql) { - GeaFlowParserImpl parser = (GeaFlowParserImpl) PARSER_CONFIG.parserFactory() - .getParser(new StringReader(sql)); - - parser.setOriginalSql(sql); - parser.setTabSize(1); - parser.setQuotedCasing(PARSER_CONFIG.quotedCasing()); - parser.setUnquotedCasing(PARSER_CONFIG.unquotedCasing()); - parser.setIdentifierMaxLength(PARSER_CONFIG.identifierMaxLength()); - parser.setConformance(PARSER_CONFIG.conformance()); - switch (PARSER_CONFIG.quoting()) { - case DOUBLE_QUOTE: - parser.switchTo("DQID"); - break; - case BACK_TICK: - parser.switchTo("BTID"); - break; - default: - parser.switchTo("DEFAULT"); - } + private GeaFlowParserImpl createParser(String sql) { + GeaFlowParserImpl parser = + (GeaFlowParserImpl) PARSER_CONFIG.parserFactory().getParser(new StringReader(sql)); - return parser; + parser.setOriginalSql(sql); + parser.setTabSize(1); + parser.setQuotedCasing(PARSER_CONFIG.quotedCasing()); + parser.setUnquotedCasing(PARSER_CONFIG.unquotedCasing()); + parser.setIdentifierMaxLength(PARSER_CONFIG.identifierMaxLength()); + parser.setConformance(PARSER_CONFIG.conformance()); + switch (PARSER_CONFIG.quoting()) { + case DOUBLE_QUOTE: + parser.switchTo("DQID"); + break; + case BACK_TICK: + parser.switchTo("BTID"); + break; + default: + parser.switchTo("DEFAULT"); } + + return parser; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/AbstractSqlGraphElementConstruct.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/AbstractSqlGraphElementConstruct.java index 1e5c6f81a..f8b98adf7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/AbstractSqlGraphElementConstruct.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/AbstractSqlGraphElementConstruct.java @@ -19,59 +19,60 @@ package org.apache.geaflow.dsl.sqlnode; -import com.google.common.base.Preconditions; import org.apache.calcite.sql.SqlBasicCall; import org.apache.calcite.sql.SqlIdentifier; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.parser.SqlParserPos; +import com.google.common.base.Preconditions; + public abstract class AbstractSqlGraphElementConstruct extends SqlBasicCall { - private final SqlIdentifier[] keyNodes; + private final SqlIdentifier[] keyNodes; - public AbstractSqlGraphElementConstruct(SqlOperator operator, SqlNode[] operands, - SqlParserPos pos) { - super(operator, getValueNodes(operands), pos); - this.keyNodes = getKeyNodes(operands); - } + public AbstractSqlGraphElementConstruct( + SqlOperator operator, SqlNode[] operands, SqlParserPos pos) { + super(operator, getValueNodes(operands), pos); + this.keyNodes = getKeyNodes(operands); + } - public SqlIdentifier[] getKeyNodes() { - return keyNodes; - } + public SqlIdentifier[] getKeyNodes() { + return keyNodes; + } - public SqlNode[] getValueNodes() { - return super.operands; - } + public SqlNode[] getValueNodes() { + return super.operands; + } - protected static SqlIdentifier[] getKeyNodes(SqlNode[] operands) { - SqlIdentifier[] keyNodes = new SqlIdentifier[operands.length / 2]; - for (int i = 0; i < operands.length; i += 2) { - keyNodes[i / 2] = (SqlIdentifier) operands[i]; - } - return keyNodes; + protected static SqlIdentifier[] getKeyNodes(SqlNode[] operands) { + SqlIdentifier[] keyNodes = new SqlIdentifier[operands.length / 2]; + for (int i = 0; i < operands.length; i += 2) { + keyNodes[i / 2] = (SqlIdentifier) operands[i]; } + return keyNodes; + } - protected static SqlNode[] getValueNodes(SqlNode[] operands) { - Preconditions.checkArgument(operands.length % 2 == 0, - "Illegal operand count: " + operands.length); - SqlNode[] valueNodes = new SqlNode[operands.length / 2]; - for (int i = 1; i < operands.length; i += 2) { - valueNodes[(i - 1) / 2] = operands[i]; - } - return valueNodes; + protected static SqlNode[] getValueNodes(SqlNode[] operands) { + Preconditions.checkArgument( + operands.length % 2 == 0, "Illegal operand count: " + operands.length); + SqlNode[] valueNodes = new SqlNode[operands.length / 2]; + for (int i = 1; i < operands.length; i += 2) { + valueNodes[(i - 1) / 2] = operands[i]; } + return valueNodes; + } - protected static SqlNode[] getOperands(SqlNode[] keyNodes, SqlNode[] valueNodes) { - assert keyNodes.length == valueNodes.length; - SqlNode[] nodes = new SqlNode[keyNodes.length + valueNodes.length]; - for (int i = 0; i < nodes.length; i++) { - if (i % 2 == 0) { - nodes[i] = keyNodes[i / 2]; - } else { - nodes[i] = valueNodes[i / 2]; - } - } - return nodes; + protected static SqlNode[] getOperands(SqlNode[] keyNodes, SqlNode[] valueNodes) { + assert keyNodes.length == valueNodes.length; + SqlNode[] nodes = new SqlNode[keyNodes.length + valueNodes.length]; + for (int i = 0; i < nodes.length; i++) { + if (i % 2 == 0) { + nodes[i] = keyNodes[i / 2]; + } else { + nodes[i] = valueNodes[i / 2]; + } } + return nodes; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlAlterGraph.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlAlterGraph.java index 1bdef23a1..d2e760fc1 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlAlterGraph.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlAlterGraph.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.sqlnode; import java.util.List; + import org.apache.calcite.sql.*; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.tools.ValidationException; @@ -27,78 +28,77 @@ public class SqlAlterGraph extends SqlCall { - private static final SqlOperator OPERATOR = new SqlSpecialOperator("SqlAlterGraph", - SqlKind.ALTER_GRAPH); + private static final SqlOperator OPERATOR = + new SqlSpecialOperator("SqlAlterGraph", SqlKind.ALTER_GRAPH); - private SqlIdentifier alterName; - private SqlNodeList vertices; - private SqlNodeList edges; + private SqlIdentifier alterName; + private SqlNodeList vertices; + private SqlNodeList edges; - public SqlAlterGraph(SqlParserPos pos, SqlIdentifier alterName, - SqlNodeList vertices, SqlNodeList edges) { - super(pos); - this.alterName = alterName; - this.vertices = vertices; - this.edges = edges; - } + public SqlAlterGraph( + SqlParserPos pos, SqlIdentifier alterName, SqlNodeList vertices, SqlNodeList edges) { + super(pos); + this.alterName = alterName; + this.vertices = vertices; + this.edges = edges; + } - @Override - public void setOperand(int i, SqlNode operand) { - switch (i) { - case 0: - this.alterName = (SqlIdentifier) operand; - break; - case 1: - this.vertices = (SqlNodeList) operand; - break; - case 2: - this.edges = (SqlNodeList) operand; - break; - default: - throw new IllegalArgumentException("Illegal index: " + i); - } + @Override + public void setOperand(int i, SqlNode operand) { + switch (i) { + case 0: + this.alterName = (SqlIdentifier) operand; + break; + case 1: + this.vertices = (SqlNodeList) operand; + break; + case 2: + this.edges = (SqlNodeList) operand; + break; + default: + throw new IllegalArgumentException("Illegal index: " + i); } + } - @Override - public List getOperandList() { - return ImmutableNullableList.of(getName(), getVertices(), getEdges()); - } + @Override + public List getOperandList() { + return ImmutableNullableList.of(getName(), getVertices(), getEdges()); + } - @Override - public SqlOperator getOperator() { - return this.OPERATOR; - } + @Override + public SqlOperator getOperator() { + return this.OPERATOR; + } - @Override - public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { - for (SqlNode addV : vertices) { - writer.keyword("alter"); - writer.keyword("graph"); - alterName.unparse(writer, 0, 0); - writer.keyword("add"); - addV.unparse(writer, 0, 0); - } - for (SqlNode addE : edges) { - writer.keyword("alter"); - writer.keyword("graph"); - alterName.unparse(writer, 0, 0); - writer.keyword("add"); - addE.unparse(writer, 0, 0); - } + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + for (SqlNode addV : vertices) { + writer.keyword("alter"); + writer.keyword("graph"); + alterName.unparse(writer, 0, 0); + writer.keyword("add"); + addV.unparse(writer, 0, 0); } - - public SqlIdentifier getName() { - return alterName; + for (SqlNode addE : edges) { + writer.keyword("alter"); + writer.keyword("graph"); + alterName.unparse(writer, 0, 0); + writer.keyword("add"); + addE.unparse(writer, 0, 0); } + } - public SqlNodeList getVertices() { - return vertices; - } + public SqlIdentifier getName() { + return alterName; + } - public SqlNodeList getEdges() { - return edges; - } + public SqlNodeList getVertices() { + return vertices; + } - public void validate() throws ValidationException { - } + public SqlNodeList getEdges() { + return edges; + } + + public void validate() throws ValidationException {} } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlCreateFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlCreateFunction.java index 311d4484d..d1e6534b1 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlCreateFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlCreateFunction.java @@ -19,127 +19,122 @@ package org.apache.geaflow.dsl.sqlnode; -import com.google.common.base.Objects; import java.util.List; + import org.apache.calcite.sql.*; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.util.ImmutableNullableList; import org.apache.geaflow.dsl.util.StringLiteralUtil; -/** - * Parse tree node that represents a CREATE Function statement. - */ -public class SqlCreateFunction extends SqlCreate { - - private static final SqlOperator OPERATOR = - new SqlSpecialOperator("CREATE FUNCTION", SqlKind.CREATE_FUNCTION); - - private SqlNode functionName; - private SqlNode className; - private SqlNode usingPath; - - public SqlCreateFunction(SqlParserPos pos, - boolean ifNotExists, - SqlNode functionName, - SqlNode className, - SqlNode usingPath) { - super(OPERATOR, pos, false, ifNotExists); - this.functionName = functionName; - this.className = className; - this.usingPath = usingPath; - } - - @Override - public List getOperandList() { - return ImmutableNullableList.of(getFunctionName(), getClassNameNode(), usingPath); - } +import com.google.common.base.Objects; - @Override - public void setOperand(int i, SqlNode operand) { - switch (i) { - case 0: - this.functionName = operand; - break; - case 1: - this.className = operand; - break; - case 2: - this.usingPath = operand; - break; - default: - throw new IllegalArgumentException("Illegal index: " + i); - } - } +/** Parse tree node that represents a CREATE Function statement. */ +public class SqlCreateFunction extends SqlCreate { - @Override - public void unparse(SqlWriter writer, - int leftPrec, - int rightPrec) { - writer.keyword("CREATE"); - writer.keyword("FUNCTION"); - if (super.ifNotExists) { - writer.keyword("IF"); - writer.keyword("NOT"); - writer.keyword("EXISTS"); - } - functionName.unparse(writer, leftPrec, rightPrec); - writer.keyword("AS"); - className.unparse(writer, 0, 0); - if (usingPath != null) { - writer.keyword("USING"); - usingPath.unparse(writer, 0, 0); - } + private static final SqlOperator OPERATOR = + new SqlSpecialOperator("CREATE FUNCTION", SqlKind.CREATE_FUNCTION); + + private SqlNode functionName; + private SqlNode className; + private SqlNode usingPath; + + public SqlCreateFunction( + SqlParserPos pos, + boolean ifNotExists, + SqlNode functionName, + SqlNode className, + SqlNode usingPath) { + super(OPERATOR, pos, false, ifNotExists); + this.functionName = functionName; + this.className = className; + this.usingPath = usingPath; + } + + @Override + public List getOperandList() { + return ImmutableNullableList.of(getFunctionName(), getClassNameNode(), usingPath); + } + + @Override + public void setOperand(int i, SqlNode operand) { + switch (i) { + case 0: + this.functionName = operand; + break; + case 1: + this.className = operand; + break; + case 2: + this.usingPath = operand; + break; + default: + throw new IllegalArgumentException("Illegal index: " + i); } - - public void validate() { - + } + + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + writer.keyword("CREATE"); + writer.keyword("FUNCTION"); + if (super.ifNotExists) { + writer.keyword("IF"); + writer.keyword("NOT"); + writer.keyword("EXISTS"); } - - public SqlNode getFunctionName() { - return functionName; + functionName.unparse(writer, leftPrec, rightPrec); + writer.keyword("AS"); + className.unparse(writer, 0, 0); + if (usingPath != null) { + writer.keyword("USING"); + usingPath.unparse(writer, 0, 0); } + } + public void validate() {} - public String getClassName() { - return StringLiteralUtil.unescapeSQLString(className.toString()); - } + public SqlNode getFunctionName() { + return functionName; + } - public String getUsingPath() { - if (usingPath == null) { - return null; - } - return StringLiteralUtil.unescapeSQLString(usingPath.toString()); - } + public String getClassName() { + return StringLiteralUtil.unescapeSQLString(className.toString()); + } - public SqlNode getClassNameNode() { - return className; + public String getUsingPath() { + if (usingPath == null) { + return null; } + return StringLiteralUtil.unescapeSQLString(usingPath.toString()); + } - public void setClassName(SqlNode className) { - this.className = className; - } + public SqlNode getClassNameNode() { + return className; + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - SqlCreateFunction that = (SqlCreateFunction) o; - return Objects.equal(functionName, that.functionName) - && Objects.equal(className, that.className) - && Objects.equal(usingPath, that.usingPath); - } + public void setClassName(SqlNode className) { + this.className = className; + } - @Override - public int hashCode() { - return Objects - .hashCode(functionName, className, usingPath); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - public boolean ifNotExists() { - return ifNotExists; + if (o == null || getClass() != o.getClass()) { + return false; } + SqlCreateFunction that = (SqlCreateFunction) o; + return Objects.equal(functionName, that.functionName) + && Objects.equal(className, that.className) + && Objects.equal(usingPath, that.usingPath); + } + + @Override + public int hashCode() { + return Objects.hashCode(functionName, className, usingPath); + } + + public boolean ifNotExists() { + return ifNotExists; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlCreateGraph.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlCreateGraph.java index a9ee90366..e930b705b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlCreateGraph.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlCreateGraph.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.sqlnode; import java.util.List; + import org.apache.calcite.sql.*; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.util.ImmutableNullableList; @@ -27,103 +28,108 @@ public class SqlCreateGraph extends SqlCreate { - private static final SqlOperator OPERATOR = new SqlSpecialOperator("SqlCreateGraph", - SqlKind.CREATE_GRAPH); - - private SqlIdentifier name; - private SqlNodeList vertices; - private SqlNodeList edges; - private SqlNodeList properties; - private final boolean isTemporary; - - public SqlCreateGraph(SqlParserPos pos, boolean isTemporary, boolean ifNotExists, - SqlIdentifier name, SqlNodeList vertices, - SqlNodeList edges, SqlNodeList properties) { - super(OPERATOR, pos, false, ifNotExists); - this.name = name; - this.vertices = vertices; - this.edges = edges; - this.properties = properties; - this.isTemporary = isTemporary; - } + private static final SqlOperator OPERATOR = + new SqlSpecialOperator("SqlCreateGraph", SqlKind.CREATE_GRAPH); - @Override - public void setOperand(int i, SqlNode operand) { - switch (i) { - case 0: - this.name = (SqlIdentifier) operand; - break; - case 1: - this.vertices = (SqlNodeList) operand; - break; - case 2: - this.edges = (SqlNodeList) operand; - break; - case 3: - this.properties = (SqlNodeList) operand; - break; - default: - throw new IllegalArgumentException("Illegal index: " + i); - } - } + private SqlIdentifier name; + private SqlNodeList vertices; + private SqlNodeList edges; + private SqlNodeList properties; + private final boolean isTemporary; - @Override - public List getOperandList() { - return ImmutableNullableList.of(getName(), getVertices(), getEdges(), getProperties()); - } + public SqlCreateGraph( + SqlParserPos pos, + boolean isTemporary, + boolean ifNotExists, + SqlIdentifier name, + SqlNodeList vertices, + SqlNodeList edges, + SqlNodeList properties) { + super(OPERATOR, pos, false, ifNotExists); + this.name = name; + this.vertices = vertices; + this.edges = edges; + this.properties = properties; + this.isTemporary = isTemporary; + } - @Override - public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { - writer.keyword("CREATE"); - if (isTemporary) { - writer.keyword("TEMPORARY"); - } - writer.keyword("GRAPH"); - if (super.ifNotExists) { - writer.keyword("IF"); - writer.keyword("NOT"); - writer.keyword("EXISTS"); - } - name.unparse(writer, 0, 0); - writer.print("("); - writer.newlineAndIndent(); - SqlNodeUtil.unparseNodeList(writer, vertices, ","); - writer.print(","); - writer.newlineAndIndent(); - SqlNodeUtil.unparseNodeList(writer, edges, ","); - writer.newlineAndIndent(); - writer.print(")"); - if (properties != null && properties.size() > 0) { - writer.keyword("WITH"); - writer.print("("); - writer.newlineAndIndent(); - SqlNodeUtil.unparseNodeList(writer, properties, ","); - writer.newlineAndIndent(); - writer.print(")"); - } + @Override + public void setOperand(int i, SqlNode operand) { + switch (i) { + case 0: + this.name = (SqlIdentifier) operand; + break; + case 1: + this.vertices = (SqlNodeList) operand; + break; + case 2: + this.edges = (SqlNodeList) operand; + break; + case 3: + this.properties = (SqlNodeList) operand; + break; + default: + throw new IllegalArgumentException("Illegal index: " + i); } + } - public SqlIdentifier getName() { - return name; - } + @Override + public List getOperandList() { + return ImmutableNullableList.of(getName(), getVertices(), getEdges(), getProperties()); + } - public SqlNodeList getVertices() { - return vertices; + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + writer.keyword("CREATE"); + if (isTemporary) { + writer.keyword("TEMPORARY"); } - - public SqlNodeList getEdges() { - return edges; + writer.keyword("GRAPH"); + if (super.ifNotExists) { + writer.keyword("IF"); + writer.keyword("NOT"); + writer.keyword("EXISTS"); } - - public SqlNodeList getProperties() { - return properties; + name.unparse(writer, 0, 0); + writer.print("("); + writer.newlineAndIndent(); + SqlNodeUtil.unparseNodeList(writer, vertices, ","); + writer.print(","); + writer.newlineAndIndent(); + SqlNodeUtil.unparseNodeList(writer, edges, ","); + writer.newlineAndIndent(); + writer.print(")"); + if (properties != null && properties.size() > 0) { + writer.keyword("WITH"); + writer.print("("); + writer.newlineAndIndent(); + SqlNodeUtil.unparseNodeList(writer, properties, ","); + writer.newlineAndIndent(); + writer.print(")"); } + } - public boolean isTemporary() { - return isTemporary; - } + public SqlIdentifier getName() { + return name; + } - public boolean ifNotExists() { - return ifNotExists; - } + public SqlNodeList getVertices() { + return vertices; + } + + public SqlNodeList getEdges() { + return edges; + } + + public SqlNodeList getProperties() { + return properties; + } + + public boolean isTemporary() { + return isTemporary; + } + + public boolean ifNotExists() { + return ifNotExists; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlCreateTable.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlCreateTable.java index 354b1ec2b..0da8601e3 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlCreateTable.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlCreateTable.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.calcite.sql.*; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.tools.ValidationException; @@ -29,201 +30,192 @@ import org.apache.commons.lang.StringUtils; import org.apache.geaflow.dsl.util.SqlTypeUtil; -/** - * Parse tree node that represents a CREATE TABLE statement. - */ +/** Parse tree node that represents a CREATE TABLE statement. */ public class SqlCreateTable extends SqlCreate { - private static final SqlOperator OPERATOR = - new SqlSpecialOperator("CREATE TABLE", SqlKind.CREATE_TABLE); - - private SqlIdentifier name; - private SqlNodeList columns; - private SqlNodeList properties; - private SqlNodeList primaryKeys; - private SqlNodeList partitionFields; - private final boolean isTemporary; - - /** - * Creates a SqlCreateTable. - */ - public SqlCreateTable(SqlParserPos pos, - boolean isTemporary, - boolean ifNotExists, - SqlIdentifier name, - SqlNodeList columns, - SqlNodeList properties, - SqlNodeList primaryKeys, - SqlNodeList partitionFields) { - super(OPERATOR, pos, false, ifNotExists); - this.name = name; - this.columns = columns; - this.properties = properties; - this.primaryKeys = primaryKeys; - this.partitionFields = partitionFields; - this.isTemporary = isTemporary; + private static final SqlOperator OPERATOR = + new SqlSpecialOperator("CREATE TABLE", SqlKind.CREATE_TABLE); + + private SqlIdentifier name; + private SqlNodeList columns; + private SqlNodeList properties; + private SqlNodeList primaryKeys; + private SqlNodeList partitionFields; + private final boolean isTemporary; + + /** Creates a SqlCreateTable. */ + public SqlCreateTable( + SqlParserPos pos, + boolean isTemporary, + boolean ifNotExists, + SqlIdentifier name, + SqlNodeList columns, + SqlNodeList properties, + SqlNodeList primaryKeys, + SqlNodeList partitionFields) { + super(OPERATOR, pos, false, ifNotExists); + this.name = name; + this.columns = columns; + this.properties = properties; + this.primaryKeys = primaryKeys; + this.partitionFields = partitionFields; + this.isTemporary = isTemporary; + } + + @Override + public List getOperandList() { + return ImmutableNullableList.of( + getName(), getColumns(), getProperties(), getPrimaryKeys(), getPartitionFields()); + } + + @Override + public void setOperand(int i, SqlNode operand) { + switch (i) { + case 0: + this.name = (SqlIdentifier) operand; + break; + case 1: + this.columns = (SqlNodeList) operand; + break; + case 2: + this.properties = (SqlNodeList) operand; + break; + case 3: + this.primaryKeys = (SqlNodeList) operand; + break; + case 4: + this.partitionFields = (SqlNodeList) operand; + break; + default: + throw new IllegalArgumentException("Illegal index: " + i); } + } - @Override - public List getOperandList() { - return ImmutableNullableList.of(getName(), getColumns(), getProperties(), - getPrimaryKeys(), getPartitionFields()); - } + public void setProperties(SqlNodeList properties) { + this.properties = properties; + } - @Override - public void setOperand(int i, SqlNode operand) { - switch (i) { - case 0: - this.name = (SqlIdentifier) operand; - break; - case 1: - this.columns = (SqlNodeList) operand; - break; - case 2: - this.properties = (SqlNodeList) operand; - break; - case 3: - this.primaryKeys = (SqlNodeList) operand; - break; - case 4: - this.partitionFields = (SqlNodeList) operand; - break; - default: - throw new IllegalArgumentException("Illegal index: " + i); - } + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + writer.keyword("CREATE"); + if (isTemporary) { + writer.keyword("TEMPORARY"); } - - - public void setProperties(SqlNodeList properties) { - this.properties = properties; + writer.keyword("TABLE"); + if (super.ifNotExists) { + writer.keyword("IF"); + writer.keyword("NOT"); + writer.keyword("EXISTS"); } - - @Override - public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { - writer.keyword("CREATE"); - if (isTemporary) { - writer.keyword("TEMPORARY"); - } - writer.keyword("TABLE"); - if (super.ifNotExists) { - writer.keyword("IF"); - writer.keyword("NOT"); - writer.keyword("EXISTS"); - } - name.unparse(writer, leftPrec, rightPrec); - if (columns.size() >= 0) { - final SqlWriter.Frame frame = - writer.startList(SqlWriter.FrameTypeEnum.FUN_CALL, "(", ")"); + name.unparse(writer, leftPrec, rightPrec); + if (columns.size() >= 0) { + final SqlWriter.Frame frame = writer.startList(SqlWriter.FrameTypeEnum.FUN_CALL, "(", ")"); + writer.newlineAndIndent(); + writer.print(" "); + if (columns.size() > 0) { + for (int i = 0; i < columns.size(); i++) { + if (i > 0) { + writer.print(","); writer.newlineAndIndent(); writer.print(" "); - if (columns.size() > 0) { - for (int i = 0; i < columns.size(); i++) { - if (i > 0) { - writer.print(","); - writer.newlineAndIndent(); - writer.print(" "); - } - columns.get(i).unparse(writer, leftPrec, rightPrec); - } - } - writer.newlineAndIndent(); - writer.endList(frame); + } + columns.get(i).unparse(writer, leftPrec, rightPrec); } - if (partitionFields != null && partitionFields.size() > 0) { - writer.keyword("PARTITIONED"); - writer.keyword("BY"); - writer.print("("); - boolean first = true; - for (SqlNode partitionField : partitionFields) { - if (!first) { - writer.print(","); - } - first = false; - partitionField.unparse(writer, 0, 0); - } - writer.print(")"); - } - if (properties != null) { - writer.keyword("WITH"); - final SqlWriter.Frame with = - writer.startList(SqlWriter.FrameTypeEnum.FUN_CALL, "(", ")"); - writer.newlineAndIndent(); - for (int i = 0; i < properties.size(); i++) { - if (i > 0) { - writer.print(","); - writer.newlineAndIndent(); - } - properties.get(i).unparse(writer, leftPrec, rightPrec); - } - - writer.newlineAndIndent(); - writer.endList(with); + } + writer.newlineAndIndent(); + writer.endList(frame); + } + if (partitionFields != null && partitionFields.size() > 0) { + writer.keyword("PARTITIONED"); + writer.keyword("BY"); + writer.print("("); + boolean first = true; + for (SqlNode partitionField : partitionFields) { + if (!first) { + writer.print(","); } + first = false; + partitionField.unparse(writer, 0, 0); + } + writer.print(")"); } - - /** - * Sql syntax validation. - */ - public void validate() throws ValidationException { - Map columnNameMap = new HashMap<>(); - if (columns != null) { - for (SqlNode column : columns) { - SqlTableColumn sqlTableColumn = (SqlTableColumn) column; - String columnName = sqlTableColumn.getName().getSimple(); - if (columnNameMap.get(columnName) == null) { - columnNameMap.put(columnName, true); - } else { - throw new ValidationException( - "duplicate column name " + "[" + columnName + "], at " + column.getParserPosition()); - } - - SqlDataTypeSpec typeSpec = sqlTableColumn.getType(); - try { - SqlTypeUtil.convertType(typeSpec); - } catch (UnsupportedOperationException e) { - throw new ValidationException( - "not support type " + "[" + typeSpec + "], at " + column.getParserPosition()); - } - } + if (properties != null) { + writer.keyword("WITH"); + final SqlWriter.Frame with = writer.startList(SqlWriter.FrameTypeEnum.FUN_CALL, "(", ")"); + writer.newlineAndIndent(); + for (int i = 0; i < properties.size(); i++) { + if (i > 0) { + writer.print(","); + writer.newlineAndIndent(); } + properties.get(i).unparse(writer, leftPrec, rightPrec); + } - if (properties != null) { - for (int i = 0; i < properties.size(); i++) { - SqlTableProperty property = (SqlTableProperty) properties.get(i); - if (property.getKey() == null || StringUtils.isEmpty(property.getKey().toString())) { - throw new ValidationException( - "property key is null or empty string at " + property.getParserPosition()); - } + writer.newlineAndIndent(); + writer.endList(with); + } + } + + /** Sql syntax validation. */ + public void validate() throws ValidationException { + Map columnNameMap = new HashMap<>(); + if (columns != null) { + for (SqlNode column : columns) { + SqlTableColumn sqlTableColumn = (SqlTableColumn) column; + String columnName = sqlTableColumn.getName().getSimple(); + if (columnNameMap.get(columnName) == null) { + columnNameMap.put(columnName, true); + } else { + throw new ValidationException( + "duplicate column name " + "[" + columnName + "], at " + column.getParserPosition()); + } - } + SqlDataTypeSpec typeSpec = sqlTableColumn.getType(); + try { + SqlTypeUtil.convertType(typeSpec); + } catch (UnsupportedOperationException e) { + throw new ValidationException( + "not support type " + "[" + typeSpec + "], at " + column.getParserPosition()); } + } } - public SqlIdentifier getName() { - return name; + if (properties != null) { + for (int i = 0; i < properties.size(); i++) { + SqlTableProperty property = (SqlTableProperty) properties.get(i); + if (property.getKey() == null || StringUtils.isEmpty(property.getKey().toString())) { + throw new ValidationException( + "property key is null or empty string at " + property.getParserPosition()); + } + } } + } - public SqlNodeList getColumns() { - return columns; - } + public SqlIdentifier getName() { + return name; + } - public SqlNodeList getProperties() { - return properties; - } + public SqlNodeList getColumns() { + return columns; + } - public SqlNodeList getPrimaryKeys() { - return primaryKeys; - } + public SqlNodeList getProperties() { + return properties; + } - public SqlNodeList getPartitionFields() { - return partitionFields; - } + public SqlNodeList getPrimaryKeys() { + return primaryKeys; + } - public boolean isTemporary() { - return isTemporary; - } + public SqlNodeList getPartitionFields() { + return partitionFields; + } - public boolean ifNotExists() { - return ifNotExists; - } + public boolean isTemporary() { + return isTemporary; + } + + public boolean ifNotExists() { + return ifNotExists; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlCreateView.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlCreateView.java index 5b9fd4ea9..f1d1fd08a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlCreateView.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlCreateView.java @@ -19,118 +19,118 @@ package org.apache.geaflow.dsl.sqlnode; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.calcite.sql.*; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.util.ImmutableNullableList; -/** - * Parse tree node that represents a CREATE VIEW statement. - */ -public class SqlCreateView extends SqlCreate { +import com.google.common.collect.Lists; - public static final SqlSpecialOperator OPERATOR = - new SqlSpecialOperator("CREATE VIEW", SqlKind.CREATE_VIEW); - - private SqlIdentifier name; - private SqlNodeList fields; - private SqlNode subQuery; - - public SqlCreateView(SqlParserPos pos, - boolean ifNotExists, - SqlIdentifier name, - SqlNodeList fields, - SqlNode subQuery) { - super(OPERATOR, pos, false, ifNotExists); - this.name = name; - this.subQuery = subQuery; - this.fields = fields; - } +/** Parse tree node that represents a CREATE VIEW statement. */ +public class SqlCreateView extends SqlCreate { - @Override - public SqlOperator getOperator() { - return OPERATOR; + public static final SqlSpecialOperator OPERATOR = + new SqlSpecialOperator("CREATE VIEW", SqlKind.CREATE_VIEW); + + private SqlIdentifier name; + private SqlNodeList fields; + private SqlNode subQuery; + + public SqlCreateView( + SqlParserPos pos, + boolean ifNotExists, + SqlIdentifier name, + SqlNodeList fields, + SqlNode subQuery) { + super(OPERATOR, pos, false, ifNotExists); + this.name = name; + this.subQuery = subQuery; + this.fields = fields; + } + + @Override + public SqlOperator getOperator() { + return OPERATOR; + } + + @Override + public List getOperandList() { + return ImmutableNullableList.of(getName(), getFields(), getSubQuery()); + } + + @Override + public void setOperand(int i, SqlNode operand) { + switch (i) { + case 0: + this.name = (SqlIdentifier) operand; + break; + case 1: + this.fields = (SqlNodeList) operand; + break; + case 2: + this.subQuery = operand; + break; + default: + throw new IllegalArgumentException("Illegal index: " + i); } - - @Override - public List getOperandList() { - return ImmutableNullableList.of(getName(), getFields(), getSubQuery()); + } + + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + writer.keyword("CREATE"); + writer.keyword("VIEW"); + if (super.ifNotExists) { + writer.keyword("IF"); + writer.keyword("NOT"); + writer.keyword("EXISTS"); } - - @Override - public void setOperand(int i, SqlNode operand) { - switch (i) { - case 0: - this.name = (SqlIdentifier) operand; - break; - case 1: - this.fields = (SqlNodeList) operand; - break; - case 2: - this.subQuery = operand; - break; - default: - throw new IllegalArgumentException("Illegal index: " + i); + name.unparse(writer, leftPrec, rightPrec); + if (fields.size() > 0) { + final SqlWriter.Frame field = writer.startList(SqlWriter.FrameTypeEnum.FUN_CALL, "(", ")"); + writer.newlineAndIndent(); + writer.print(" "); + for (int i = 0; i < fields.size(); i++) { + if (i > 0) { + writer.print(","); + writer.newlineAndIndent(); + writer.print(" "); } + fields.get(i).unparse(writer, leftPrec, rightPrec); + } + writer.newlineAndIndent(); + writer.endList(field); } - - @Override - public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { - writer.keyword("CREATE"); - writer.keyword("VIEW"); - if (super.ifNotExists) { - writer.keyword("IF"); - writer.keyword("NOT"); - writer.keyword("EXISTS"); - } - name.unparse(writer, leftPrec, rightPrec); - if (fields.size() > 0) { - final SqlWriter.Frame field = - writer.startList(SqlWriter.FrameTypeEnum.FUN_CALL, "(", ")"); - writer.newlineAndIndent(); - writer.print(" "); - for (int i = 0; i < fields.size(); i++) { - if (i > 0) { - writer.print(","); - writer.newlineAndIndent(); - writer.print(" "); - } - fields.get(i).unparse(writer, leftPrec, rightPrec); - } - writer.newlineAndIndent(); - writer.endList(field); - } - writer.keyword("AS"); - writer.newlineAndIndent(); - subQuery.unparse(writer, leftPrec, rightPrec); - } - - public SqlIdentifier getName() { - return name; - } - - public List getFieldNames() { - List fieldNames = Lists.newArrayList(); - for (SqlNode node : fields.getList()) { - fieldNames.add(node.toString()); - } - return fieldNames; + writer.keyword("AS"); + writer.newlineAndIndent(); + subQuery.unparse(writer, leftPrec, rightPrec); + } + + public SqlIdentifier getName() { + return name; + } + + public List getFieldNames() { + List fieldNames = Lists.newArrayList(); + for (SqlNode node : fields.getList()) { + fieldNames.add(node.toString()); } + return fieldNames; + } - public SqlNodeList getFields() { - return fields; - } + public SqlNodeList getFields() { + return fields; + } - public SqlNode getSubQuery() { - return subQuery; - } + public SqlNode getSubQuery() { + return subQuery; + } - public String getSubQuerySql() { - return subQuery.toString(); - } + public String getSubQuerySql() { + return subQuery.toString(); + } - public boolean ifNotExists() { - return ifNotExists; - } + public boolean ifNotExists() { + return ifNotExists; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlDescGraph.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlDescGraph.java index 4e8021ed1..7d67c5f0a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlDescGraph.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlDescGraph.java @@ -19,47 +19,49 @@ package org.apache.geaflow.dsl.sqlnode; -import com.google.common.collect.ImmutableList; import java.util.List; + import org.apache.calcite.sql.*; import org.apache.calcite.sql.parser.SqlParserPos; +import com.google.common.collect.ImmutableList; + public class SqlDescGraph extends SqlCall { - private SqlIdentifier name; + private SqlIdentifier name; - private static final SqlOperator OPERATOR = new SqlSpecialOperator("SqlDescGraph", - SqlKind.DESC_GRAPH); + private static final SqlOperator OPERATOR = + new SqlSpecialOperator("SqlDescGraph", SqlKind.DESC_GRAPH); - public SqlDescGraph(SqlParserPos pos, SqlIdentifier name) { - super(pos); - this.name = name; - } + public SqlDescGraph(SqlParserPos pos, SqlIdentifier name) { + super(pos); + this.name = name; + } - @Override - public void setOperand(int i, SqlNode operand) { - switch (i) { - case 0: - this.name = (SqlIdentifier) operand; - break; - default: - throw new IllegalArgumentException("Illegal index: " + i); - } + @Override + public void setOperand(int i, SqlNode operand) { + switch (i) { + case 0: + this.name = (SqlIdentifier) operand; + break; + default: + throw new IllegalArgumentException("Illegal index: " + i); } + } - @Override - public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { - writer.keyword("desc"); - writer.keyword("graph"); - name.unparse(writer, 0, 0); - } + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + writer.keyword("desc"); + writer.keyword("graph"); + name.unparse(writer, 0, 0); + } - @Override - public List getOperandList() { - return ImmutableList.of(name); - } + @Override + public List getOperandList() { + return ImmutableList.of(name); + } - public SqlOperator getOperator() { - return this.OPERATOR; - } + public SqlOperator getOperator() { + return this.OPERATOR; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlDropGraph.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlDropGraph.java index 24df76252..4ae5de6cb 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlDropGraph.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlDropGraph.java @@ -19,43 +19,45 @@ package org.apache.geaflow.dsl.sqlnode; -import com.google.common.collect.ImmutableList; import java.util.List; + import org.apache.calcite.sql.*; import org.apache.calcite.sql.parser.SqlParserPos; -public class SqlDropGraph extends SqlDrop { - - private SqlIdentifier name; +import com.google.common.collect.ImmutableList; - private static final SqlOperator OPERATOR = new SqlSpecialOperator("SqlDropGraph", - SqlKind.DROP_GRAPH); +public class SqlDropGraph extends SqlDrop { - public SqlDropGraph(SqlParserPos pos, SqlIdentifier name) { - super(OPERATOR, pos, false); - this.name = name; - } + private SqlIdentifier name; - @Override - public void setOperand(int i, SqlNode operand) { - switch (i) { - case 0: - this.name = (SqlIdentifier) operand; - break; - default: - throw new IllegalArgumentException("Illegal index: " + i); - } - } + private static final SqlOperator OPERATOR = + new SqlSpecialOperator("SqlDropGraph", SqlKind.DROP_GRAPH); - @Override - public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { - writer.keyword("drop"); - writer.keyword("graph"); - name.unparse(writer, 0, 0); - } + public SqlDropGraph(SqlParserPos pos, SqlIdentifier name) { + super(OPERATOR, pos, false); + this.name = name; + } - @Override - public List getOperandList() { - return ImmutableList.of(name); + @Override + public void setOperand(int i, SqlNode operand) { + switch (i) { + case 0: + this.name = (SqlIdentifier) operand; + break; + default: + throw new IllegalArgumentException("Illegal index: " + i); } + } + + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + writer.keyword("drop"); + writer.keyword("graph"); + name.unparse(writer, 0, 0); + } + + @Override + public List getOperandList() { + return ImmutableList.of(name); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlEdge.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlEdge.java index f6674cfbd..d1f2be069 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlEdge.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlEdge.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.calcite.sql.*; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.validate.SqlValidator; @@ -32,114 +33,117 @@ public class SqlEdge extends SqlCall { - private static final SqlOperator OPERATOR = new SqlSpecialOperator("SqlEdge", - SqlKind.OTHER_DDL); - private SqlIdentifier name; - private SqlNodeList columns; - private SqlNodeList constraints; + private static final SqlOperator OPERATOR = new SqlSpecialOperator("SqlEdge", SqlKind.OTHER_DDL); + private SqlIdentifier name; + private SqlNodeList columns; + private SqlNodeList constraints; - public SqlEdge(SqlParserPos pos, SqlIdentifier name, SqlNodeList columns, SqlNodeList constraints) { - super(pos); - this.name = name; - this.columns = columns; - this.constraints = constraints; - } + public SqlEdge( + SqlParserPos pos, SqlIdentifier name, SqlNodeList columns, SqlNodeList constraints) { + super(pos); + this.name = name; + this.columns = columns; + this.constraints = constraints; + } - @Override - public void setOperand(int i, SqlNode operand) { - switch (i) { - case 0: - this.name = (SqlIdentifier) operand; - break; - case 1: - this.columns = (SqlNodeList) operand; - break; - case 2: - this.constraints = (SqlNodeList) operand; - break; - default: - throw new IndexOutOfBoundsException("current index " + i + " out of range " + 3); - } + @Override + public void setOperand(int i, SqlNode operand) { + switch (i) { + case 0: + this.name = (SqlIdentifier) operand; + break; + case 1: + this.columns = (SqlNodeList) operand; + break; + case 2: + this.constraints = (SqlNodeList) operand; + break; + default: + throw new IndexOutOfBoundsException("current index " + i + " out of range " + 3); } + } - @Override - public SqlOperator getOperator() { - return OPERATOR; - } + @Override + public SqlOperator getOperator() { + return OPERATOR; + } - @Override - public List getOperandList() { - return ImmutableNullableList.of(getName(), getColumns(), getConstraints()); - } + @Override + public List getOperandList() { + return ImmutableNullableList.of(getName(), getColumns(), getConstraints()); + } - @Override - public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { - writer.keyword("edge"); - name.unparse(writer, 0, 0); - writer.print("("); - writer.newlineAndIndent(); - for (int i = 0; i < columns.size(); i++) { - if (i > 0) { - writer.print(",\n"); - } - columns.get(i).unparse(writer, 0, 0); - } - writer.newlineAndIndent(); - writer.print(")"); - if (constraints != null && constraints.size() > 0) { - for (int i = 0; i < constraints.size(); i++) { - if (i > 0) { - writer.print("\n"); - } - constraints.get(i).unparse(writer, 0, 0); - } + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + writer.keyword("edge"); + name.unparse(writer, 0, 0); + writer.print("("); + writer.newlineAndIndent(); + for (int i = 0; i < columns.size(); i++) { + if (i > 0) { + writer.print(",\n"); + } + columns.get(i).unparse(writer, 0, 0); + } + writer.newlineAndIndent(); + writer.print(")"); + if (constraints != null && constraints.size() > 0) { + for (int i = 0; i < constraints.size(); i++) { + if (i > 0) { + writer.print("\n"); } + constraints.get(i).unparse(writer, 0, 0); + } } + } - public SqlIdentifier getName() { - return name; - } + public SqlIdentifier getName() { + return name; + } - public SqlNodeList getColumns() { - return columns; - } + public SqlNodeList getColumns() { + return columns; + } - public SqlNodeList getConstraints() { - return constraints; - } + public SqlNodeList getConstraints() { + return constraints; + } - public void validate() { - SqlIdentifier sourceVertex = null; - SqlIdentifier targetVertex = null; - for (Object c : columns) { - SqlTableColumn column = (SqlTableColumn) c; - column.validate(); - if (column.getCategory() == ColumnCategory.SOURCE_ID) { - assert sourceVertex == null : "Duplicated source id field."; - sourceVertex = column.getTypeFrom(); - } else if (column.getCategory() == ColumnCategory.DESTINATION_ID) { - assert targetVertex == null : "Duplicated destination id field."; - targetVertex = column.getTypeFrom(); - } - } - if (sourceVertex != null && targetVertex != null) { - this.constraints = new SqlNodeList(Collections.singletonList(new GQLEdgeConstraint( - new SqlNodeList(Collections.singletonList(sourceVertex), getParserPosition()), - new SqlNodeList(Collections.singletonList(targetVertex), getParserPosition()), - getParserPosition() - )), getParserPosition()); - } else if (sourceVertex == null && targetVertex != null) { - throw new GeaFlowDSLException("The vertex source id from in edge '{}' should be set.", - getName().getSimple()); - } else if (sourceVertex != null) { - throw new GeaFlowDSLException("The vertex target id from in edge '{}' should be set.", - getName().getSimple()); - } + public void validate() { + SqlIdentifier sourceVertex = null; + SqlIdentifier targetVertex = null; + for (Object c : columns) { + SqlTableColumn column = (SqlTableColumn) c; + column.validate(); + if (column.getCategory() == ColumnCategory.SOURCE_ID) { + assert sourceVertex == null : "Duplicated source id field."; + sourceVertex = column.getTypeFrom(); + } else if (column.getCategory() == ColumnCategory.DESTINATION_ID) { + assert targetVertex == null : "Duplicated destination id field."; + targetVertex = column.getTypeFrom(); + } } - - @Override - public void validate(SqlValidator validator, SqlValidatorScope scope) { - this.validate(); - super.validate(validator, scope); + if (sourceVertex != null && targetVertex != null) { + this.constraints = + new SqlNodeList( + Collections.singletonList( + new GQLEdgeConstraint( + new SqlNodeList(Collections.singletonList(sourceVertex), getParserPosition()), + new SqlNodeList(Collections.singletonList(targetVertex), getParserPosition()), + getParserPosition())), + getParserPosition()); + } else if (sourceVertex == null && targetVertex != null) { + throw new GeaFlowDSLException( + "The vertex source id from in edge '{}' should be set.", getName().getSimple()); + } else if (sourceVertex != null) { + throw new GeaFlowDSLException( + "The vertex target id from in edge '{}' should be set.", getName().getSimple()); } + } + + @Override + public void validate(SqlValidator validator, SqlValidatorScope scope) { + this.validate(); + super.validate(validator, scope); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlEdgeConstruct.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlEdgeConstruct.java index 20fd8ef99..b49a31a26 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlEdgeConstruct.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlEdgeConstruct.java @@ -27,26 +27,26 @@ public class SqlEdgeConstruct extends AbstractSqlGraphElementConstruct { - public SqlEdgeConstruct(SqlNode[] operands, SqlParserPos pos) { - super(new SqlEdgeConstructOperator(getKeyNodes(operands)), operands, pos); - } + public SqlEdgeConstruct(SqlNode[] operands, SqlParserPos pos) { + super(new SqlEdgeConstructOperator(getKeyNodes(operands)), operands, pos); + } - public SqlEdgeConstruct(SqlIdentifier[] keyNodes, SqlNode[] valueNodes, SqlParserPos pos) { - this(getOperands(keyNodes, valueNodes), pos); - } + public SqlEdgeConstruct(SqlIdentifier[] keyNodes, SqlNode[] valueNodes, SqlParserPos pos) { + this(getOperands(keyNodes, valueNodes), pos); + } - @Override - public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { - writer.keyword("Edge"); - writer.print("{\n"); - for (int i = 0; i < getKeyNodes().length; i++) { - SqlNode key = getKeyNodes()[i]; - SqlNode value = getValueNodes()[i]; - key.unparse(writer, 0, 0); - writer.print("="); - value.unparse(writer, 0, 0); - writer.print("\n"); - } - writer.print("}"); + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + writer.keyword("Edge"); + writer.print("{\n"); + for (int i = 0; i < getKeyNodes().length; i++) { + SqlNode key = getKeyNodes()[i]; + SqlNode value = getValueNodes()[i]; + key.unparse(writer, 0, 0); + writer.print("="); + value.unparse(writer, 0, 0); + writer.print("\n"); } + writer.print("}"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlEdgeProperty.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlEdgeProperty.java index d76a6b027..087b8326d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlEdgeProperty.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlEdgeProperty.java @@ -20,100 +20,106 @@ package org.apache.geaflow.dsl.sqlnode; import java.util.List; + import org.apache.calcite.sql.*; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.util.ImmutableNullableList; public class SqlEdgeProperty extends SqlCall { - private static final SqlOperator OPERATOR = new SqlSpecialOperator("SqlEdgeProperty", - SqlKind.OTHER_DDL); - - private SqlIdentifier sourceId; - private SqlIdentifier targetId; - private SqlNode type; - private SqlNode direct; - private SqlNode timeField; - - public SqlEdgeProperty(SqlParserPos pos, SqlIdentifier sourceId, SqlIdentifier targetId, - SqlNode type, SqlNode direct, SqlNode timeField) { - super(pos); - this.sourceId = sourceId; - this.targetId = targetId; - this.type = type; - this.direct = direct; - this.timeField = timeField; - } + private static final SqlOperator OPERATOR = + new SqlSpecialOperator("SqlEdgeProperty", SqlKind.OTHER_DDL); - @Override - public void setOperand(int i, SqlNode operand) { - switch (i) { - case 0: - this.sourceId = (SqlIdentifier) operand; - break; - case 1: - this.targetId = (SqlIdentifier) operand; - break; - case 2: - this.type = operand; - break; - case 3: - this.direct = operand; - break; - case 4: - this.timeField = operand; - break; - default: - break; - } - } + private SqlIdentifier sourceId; + private SqlIdentifier targetId; + private SqlNode type; + private SqlNode direct; + private SqlNode timeField; - @Override - public SqlOperator getOperator() { - return OPERATOR; - } + public SqlEdgeProperty( + SqlParserPos pos, + SqlIdentifier sourceId, + SqlIdentifier targetId, + SqlNode type, + SqlNode direct, + SqlNode timeField) { + super(pos); + this.sourceId = sourceId; + this.targetId = targetId; + this.type = type; + this.direct = direct; + this.timeField = timeField; + } - @Override - public List getOperandList() { - return ImmutableNullableList.of(sourceId, targetId, type, direct, timeField); + @Override + public void setOperand(int i, SqlNode operand) { + switch (i) { + case 0: + this.sourceId = (SqlIdentifier) operand; + break; + case 1: + this.targetId = (SqlIdentifier) operand; + break; + case 2: + this.type = operand; + break; + case 3: + this.direct = operand; + break; + case 4: + this.timeField = operand; + break; + default: + break; } + } - @Override - public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { - writer.keyword("edge"); - writer.keyword("prop"); - writer.print("("); - sourceId.unparse(writer, 0, 0); - writer.print(","); - targetId.unparse(writer, 0, 0); - if (timeField != null) { - writer.print(","); - timeField.unparse(writer, 0, 0); - } - writer.print(","); - type.unparse(writer, 0, 0); - writer.print(","); - direct.unparse(writer, 0, 0); - writer.print(")"); - } + @Override + public SqlOperator getOperator() { + return OPERATOR; + } - public SqlIdentifier getSourceId() { - return sourceId; - } + @Override + public List getOperandList() { + return ImmutableNullableList.of(sourceId, targetId, type, direct, timeField); + } - public SqlIdentifier getTargetId() { - return targetId; + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + writer.keyword("edge"); + writer.keyword("prop"); + writer.print("("); + sourceId.unparse(writer, 0, 0); + writer.print(","); + targetId.unparse(writer, 0, 0); + if (timeField != null) { + writer.print(","); + timeField.unparse(writer, 0, 0); } + writer.print(","); + type.unparse(writer, 0, 0); + writer.print(","); + direct.unparse(writer, 0, 0); + writer.print(")"); + } - public SqlNode getType() { - return type; - } + public SqlIdentifier getSourceId() { + return sourceId; + } - public SqlNode getDirect() { - return direct; - } + public SqlIdentifier getTargetId() { + return targetId; + } - public SqlNode getTimeField() { - return timeField; - } + public SqlNode getType() { + return type; + } + + public SqlNode getDirect() { + return direct; + } + + public SqlNode getTimeField() { + return timeField; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlEdgeUsing.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlEdgeUsing.java index cb7ae2b46..f9fad8c9d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlEdgeUsing.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlEdgeUsing.java @@ -19,127 +19,135 @@ package org.apache.geaflow.dsl.sqlnode; -import com.google.common.collect.ImmutableList; import java.util.List; + import org.apache.calcite.sql.*; import org.apache.calcite.sql.parser.SqlParserPos; -public class SqlEdgeUsing extends SqlCall { - - private static final SqlOperator OPERATOR = new SqlSpecialOperator("SqlEdgeUsing", - SqlKind.OTHER_DDL); - - private SqlIdentifier name; - private SqlIdentifier usingTableName; - private SqlIdentifier sourceId; - private SqlIdentifier targetId; - private SqlIdentifier timeField; - private SqlNodeList constraints; - - public SqlEdgeUsing(SqlParserPos pos, SqlIdentifier name, - SqlIdentifier usingTableName, - SqlIdentifier sourceId, - SqlIdentifier targetId, - SqlIdentifier timeField, - SqlNodeList constraints) { - super(pos); - this.name = name; - this.usingTableName = usingTableName; - this.sourceId = sourceId; - this.targetId = targetId; - this.timeField = timeField; - this.constraints = constraints; - } +import com.google.common.collect.ImmutableList; - @Override - public void setOperand(int i, SqlNode operand) { - switch (i) { - case 0: - this.name = (SqlIdentifier) operand; - break; - case 1: - this.usingTableName = (SqlIdentifier) operand; - break; - case 2: - this.sourceId = (SqlIdentifier) operand; - break; - case 3: - this.targetId = (SqlIdentifier) operand; - break; - case 4: - this.timeField = (SqlIdentifier) operand; - break; - case 5: - this.constraints = (SqlNodeList) operand; - break; - default: - throw new AssertionError(); - } - } +public class SqlEdgeUsing extends SqlCall { - @Override - public SqlOperator getOperator() { - return OPERATOR; + private static final SqlOperator OPERATOR = + new SqlSpecialOperator("SqlEdgeUsing", SqlKind.OTHER_DDL); + + private SqlIdentifier name; + private SqlIdentifier usingTableName; + private SqlIdentifier sourceId; + private SqlIdentifier targetId; + private SqlIdentifier timeField; + private SqlNodeList constraints; + + public SqlEdgeUsing( + SqlParserPos pos, + SqlIdentifier name, + SqlIdentifier usingTableName, + SqlIdentifier sourceId, + SqlIdentifier targetId, + SqlIdentifier timeField, + SqlNodeList constraints) { + super(pos); + this.name = name; + this.usingTableName = usingTableName; + this.sourceId = sourceId; + this.targetId = targetId; + this.timeField = timeField; + this.constraints = constraints; + } + + @Override + public void setOperand(int i, SqlNode operand) { + switch (i) { + case 0: + this.name = (SqlIdentifier) operand; + break; + case 1: + this.usingTableName = (SqlIdentifier) operand; + break; + case 2: + this.sourceId = (SqlIdentifier) operand; + break; + case 3: + this.targetId = (SqlIdentifier) operand; + break; + case 4: + this.timeField = (SqlIdentifier) operand; + break; + case 5: + this.constraints = (SqlNodeList) operand; + break; + default: + throw new AssertionError(); } - - @Override - public List getOperandList() { - return ImmutableList.of(getName(), getUsingTableName(), getSourceId(), getTargetId(), - getTimeField(), getConstraints()); + } + + @Override + public SqlOperator getOperator() { + return OPERATOR; + } + + @Override + public List getOperandList() { + return ImmutableList.of( + getName(), + getUsingTableName(), + getSourceId(), + getTargetId(), + getTimeField(), + getConstraints()); + } + + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + writer.keyword("EDGE"); + name.unparse(writer, 0, 0); + writer.keyword("USING"); + usingTableName.unparse(writer, 0, 0); + writer.keyword("WITH"); + writer.keyword("ID"); + writer.print("("); + sourceId.unparse(writer, 0, 0); + writer.print(","); + targetId.unparse(writer, 0, 0); + writer.print(")"); + if (timeField != null) { + writer.print(","); + writer.keyword("TIMESTAMP"); + writer.print("("); + timeField.unparse(writer, 0, 0); + writer.print(")"); } - - @Override - public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { - writer.keyword("EDGE"); - name.unparse(writer, 0, 0); - writer.keyword("USING"); - usingTableName.unparse(writer, 0, 0); - writer.keyword("WITH"); - writer.keyword("ID"); - writer.print("("); - sourceId.unparse(writer, 0, 0); - writer.print(","); - targetId.unparse(writer, 0, 0); - writer.print(")"); - if (timeField != null) { - writer.print(","); - writer.keyword("TIMESTAMP"); - writer.print("("); - timeField.unparse(writer, 0, 0); - writer.print(")"); - } - if (constraints != null && constraints.size() > 0) { - for (int i = 0; i < constraints.size(); i++) { - if (i > 0) { - writer.print("\n"); - } - constraints.get(i).unparse(writer, 0, 0); - } + if (constraints != null && constraints.size() > 0) { + for (int i = 0; i < constraints.size(); i++) { + if (i > 0) { + writer.print("\n"); } + constraints.get(i).unparse(writer, 0, 0); + } } + } - public SqlIdentifier getName() { - return name; - } + public SqlIdentifier getName() { + return name; + } - public SqlIdentifier getUsingTableName() { - return usingTableName; - } + public SqlIdentifier getUsingTableName() { + return usingTableName; + } - public SqlIdentifier getSourceId() { - return sourceId; - } + public SqlIdentifier getSourceId() { + return sourceId; + } - public SqlIdentifier getTargetId() { - return targetId; - } - - public SqlIdentifier getTimeField() { - return timeField; - } + public SqlIdentifier getTargetId() { + return targetId; + } - public SqlNodeList getConstraints() { - return constraints; - } + public SqlIdentifier getTimeField() { + return timeField; + } + public SqlNodeList getConstraints() { + return constraints; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlFilterStatement.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlFilterStatement.java index 7e356499b..200eaafb7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlFilterStatement.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlFilterStatement.java @@ -19,90 +19,91 @@ package org.apache.geaflow.dsl.sqlnode; -import com.google.common.collect.ImmutableList; import java.util.List; + import org.apache.calcite.sql.*; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.sql.validate.SqlValidatorScope; import org.apache.geaflow.dsl.operator.SqlFilterOperator; -public class SqlFilterStatement extends SqlCall { - - private SqlNode from; - - private SqlNode condition; - - public SqlFilterStatement(SqlParserPos pos, SqlNode from, SqlNode condition) { - super(pos); - this.from = from; - this.condition = condition; - } - - @Override - public SqlOperator getOperator() { - return SqlFilterOperator.INSTANCE; - } - - @Override - public List getOperandList() { - return ImmutableList.of(getFrom(), getCondition()); - } +import com.google.common.collect.ImmutableList; - @Override - public SqlKind getKind() { - return SqlKind.GQL_FILTER; - } +public class SqlFilterStatement extends SqlCall { - @Override - public void validate(SqlValidator validator, SqlValidatorScope scope) { - validator.validateQuery(this, scope, validator.getUnknownType()); + private SqlNode from; + + private SqlNode condition; + + public SqlFilterStatement(SqlParserPos pos, SqlNode from, SqlNode condition) { + super(pos); + this.from = from; + this.condition = condition; + } + + @Override + public SqlOperator getOperator() { + return SqlFilterOperator.INSTANCE; + } + + @Override + public List getOperandList() { + return ImmutableList.of(getFrom(), getCondition()); + } + + @Override + public SqlKind getKind() { + return SqlKind.GQL_FILTER; + } + + @Override + public void validate(SqlValidator validator, SqlValidatorScope scope) { + validator.validateQuery(this, scope, validator.getUnknownType()); + } + + @Override + public void setOperand(int i, SqlNode operand) { + switch (i) { + case 0: + this.from = operand; + break; + case 1: + this.condition = operand; + break; + default: + throw new IllegalArgumentException("Illegal index: " + i); } + } - @Override - public void setOperand(int i, SqlNode operand) { - switch (i) { - case 0: - this.from = operand; - break; - case 1: - this.condition = operand; - break; - default: - throw new IllegalArgumentException("Illegal index: " + i); - } + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + if (from != null) { + from.unparse(writer, 0, 0); } - - @Override - public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { - if (from != null) { - from.unparse(writer, 0, 0); - } - if (condition != null) { - if (from != null) { - writer.keyword("THEN"); - writer.newlineAndIndent(); - } - writer.keyword("FILTER"); - condition.unparse(writer, 0, 0); - } + if (condition != null) { + if (from != null) { + writer.keyword("THEN"); writer.newlineAndIndent(); + } + writer.keyword("FILTER"); + condition.unparse(writer, 0, 0); } + writer.newlineAndIndent(); + } - public final SqlNode getFrom() { - return this.from; - } + public final SqlNode getFrom() { + return this.from; + } - public void setFrom(SqlNode from) { - this.from = from; - } + public void setFrom(SqlNode from) { + this.from = from; + } - public final SqlNode getCondition() { - return this.condition; - } - - public void setCondition(SqlNode condition) { - this.condition = condition; - } + public final SqlNode getCondition() { + return this.condition; + } + public void setCondition(SqlNode condition) { + this.condition = condition; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlGraphAlgorithmCall.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlGraphAlgorithmCall.java index 2f1f3e9ef..51e269ce7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlGraphAlgorithmCall.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlGraphAlgorithmCall.java @@ -22,147 +22,139 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.calcite.sql.*; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.tools.ValidationException; import org.apache.calcite.util.ImmutableNullableList; import org.apache.geaflow.dsl.operator.SqlGraphAlgorithmOperator; -/** - * Parse tree node that represents a CREATE TABLE statement. - */ +/** Parse tree node that represents a CREATE TABLE statement. */ public class SqlGraphAlgorithmCall extends SqlBasicCall { - private static final SqlOperator OPERATOR = SqlGraphAlgorithmOperator.INSTANCE; - - private SqlNode from; - private SqlIdentifier algorithm; - private SqlNodeList parameters; - private SqlNodeList yields; - - /** - * Creates a SqlCreateTable. - */ - public SqlGraphAlgorithmCall(SqlParserPos pos, - SqlNode from, - SqlIdentifier algorithm, - SqlNodeList parameters, - SqlNodeList yields) { - super(OPERATOR, parameters == null ? null : parameters.toArray(), pos); - this.from = from; - this.algorithm = algorithm; - this.parameters = parameters; - this.yields = yields; - } + private static final SqlOperator OPERATOR = SqlGraphAlgorithmOperator.INSTANCE; - @Override - public List getOperandList() { - return ImmutableNullableList.of(getFrom(), getAlgorithm(), getParameters(), getYields()); - } + private SqlNode from; + private SqlIdentifier algorithm; + private SqlNodeList parameters; + private SqlNodeList yields; - @Override - public void setOperand(int i, SqlNode operand) { - switch (i) { - case 0: - this.from = operand; - break; - case 1: - this.algorithm = (SqlIdentifier) operand; - break; - case 2: - this.parameters = (SqlNodeList) operand; - break; - case 3: - this.yields = (SqlNodeList) operand; - break; - default: - throw new IllegalArgumentException("Illegal index: " + i); - } + /** Creates a SqlCreateTable. */ + public SqlGraphAlgorithmCall( + SqlParserPos pos, + SqlNode from, + SqlIdentifier algorithm, + SqlNodeList parameters, + SqlNodeList yields) { + super(OPERATOR, parameters == null ? null : parameters.toArray(), pos); + this.from = from; + this.algorithm = algorithm; + this.parameters = parameters; + this.yields = yields; + } + + @Override + public List getOperandList() { + return ImmutableNullableList.of(getFrom(), getAlgorithm(), getParameters(), getYields()); + } + + @Override + public void setOperand(int i, SqlNode operand) { + switch (i) { + case 0: + this.from = operand; + break; + case 1: + this.algorithm = (SqlIdentifier) operand; + break; + case 2: + this.parameters = (SqlNodeList) operand; + break; + case 3: + this.yields = (SqlNodeList) operand; + break; + default: + throw new IllegalArgumentException("Illegal index: " + i); } + } - @Override - public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { - writer.keyword("CALL"); - algorithm.unparse(writer, leftPrec, rightPrec); - if (parameters != null && parameters.size() >= 0) { - final SqlWriter.Frame frame = - writer.startList(SqlWriter.FrameTypeEnum.FUN_CALL, "(", ")"); + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + writer.keyword("CALL"); + algorithm.unparse(writer, leftPrec, rightPrec); + if (parameters != null && parameters.size() >= 0) { + final SqlWriter.Frame frame = writer.startList(SqlWriter.FrameTypeEnum.FUN_CALL, "(", ")"); + writer.newlineAndIndent(); + writer.print(" "); + if (parameters.size() > 0) { + for (int i = 0; i < parameters.size(); i++) { + if (i > 0) { + writer.print(","); writer.newlineAndIndent(); writer.print(" "); - if (parameters.size() > 0) { - for (int i = 0; i < parameters.size(); i++) { - if (i > 0) { - writer.print(","); - writer.newlineAndIndent(); - writer.print(" "); - } - parameters.get(i).unparse(writer, leftPrec, rightPrec); - } - } - writer.newlineAndIndent(); - writer.endList(frame); - } else { - final SqlWriter.Frame frame = - writer.startList(SqlWriter.FrameTypeEnum.FUN_CALL, "(", ")"); - writer.endList(frame); + } + parameters.get(i).unparse(writer, leftPrec, rightPrec); } - writer.keyword("YIELD"); - writer.print(" "); - if (yields != null && yields.size() > 0) { - final SqlWriter.Frame with = - writer.startList(SqlWriter.FrameTypeEnum.FUN_CALL, "(", ")"); - writer.newlineAndIndent(); - for (int i = 0; i < yields.size(); i++) { - if (i > 0) { - writer.print(","); - writer.newlineAndIndent(); - } - yields.get(i).unparse(writer, leftPrec, rightPrec); - } - writer.newlineAndIndent(); - writer.endList(with); - } else { - final SqlWriter.Frame frame = - writer.startList(SqlWriter.FrameTypeEnum.FUN_CALL, "(", ")"); - writer.endList(frame); + } + writer.newlineAndIndent(); + writer.endList(frame); + } else { + final SqlWriter.Frame frame = writer.startList(SqlWriter.FrameTypeEnum.FUN_CALL, "(", ")"); + writer.endList(frame); + } + writer.keyword("YIELD"); + writer.print(" "); + if (yields != null && yields.size() > 0) { + final SqlWriter.Frame with = writer.startList(SqlWriter.FrameTypeEnum.FUN_CALL, "(", ")"); + writer.newlineAndIndent(); + for (int i = 0; i < yields.size(); i++) { + if (i > 0) { + writer.print(","); + writer.newlineAndIndent(); } + yields.get(i).unparse(writer, leftPrec, rightPrec); + } + writer.newlineAndIndent(); + writer.endList(with); + } else { + final SqlWriter.Frame frame = writer.startList(SqlWriter.FrameTypeEnum.FUN_CALL, "(", ")"); + writer.endList(frame); } + } - /** - * Sql syntax validation. - */ - public void validate() throws ValidationException { - Map columnNameMap = new HashMap<>(); - if (yields != null) { - for (SqlNode yield : yields) { - String yieldName = ((SqlIdentifier) yield).getSimple(); - if (columnNameMap.get(yieldName) == null) { - columnNameMap.put(yieldName, true); - } else { - throw new ValidationException( - "duplicate yield name " + "[" + yieldName + "], at " + yield.getParserPosition()); - } - } + /** Sql syntax validation. */ + public void validate() throws ValidationException { + Map columnNameMap = new HashMap<>(); + if (yields != null) { + for (SqlNode yield : yields) { + String yieldName = ((SqlIdentifier) yield).getSimple(); + if (columnNameMap.get(yieldName) == null) { + columnNameMap.put(yieldName, true); + } else { + throw new ValidationException( + "duplicate yield name " + "[" + yieldName + "], at " + yield.getParserPosition()); } + } } + } - public SqlIdentifier getAlgorithm() { - return algorithm; - } + public SqlIdentifier getAlgorithm() { + return algorithm; + } - public SqlNodeList getParameters() { - return parameters; - } + public SqlNodeList getParameters() { + return parameters; + } - public SqlNodeList getYields() { - return yields; - } + public SqlNodeList getYields() { + return yields; + } - public SqlNode getFrom() { - return from; - } + public SqlNode getFrom() { + return from; + } - public void setFrom(SqlNode from) { - this.from = from; - } + public void setFrom(SqlNode from) { + this.from = from; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlLetStatement.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlLetStatement.java index 7ab677f43..9e6cedd7b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlLetStatement.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlLetStatement.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.Objects; + import org.apache.calcite.sql.*; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.validate.SqlValidator; @@ -31,106 +32,106 @@ public class SqlLetStatement extends SqlCall { - /** - * The input node for this statement. - */ - private SqlNode from; - - private SqlIdentifier leftVar; - - private SqlNode expression; - - private final boolean isGlobal; - - public SqlLetStatement(SqlParserPos pos, SqlNode from, SqlIdentifier leftVar, - SqlNode expression, boolean isGlobal) { - super(pos); - this.from = from; - this.leftVar = Objects.requireNonNull(leftVar); - this.expression = Objects.requireNonNull(expression); - this.isGlobal = isGlobal; - } - - @Override - public SqlOperator getOperator() { - return SqlLetOperator.INSTANCE; + /** The input node for this statement. */ + private SqlNode from; + + private SqlIdentifier leftVar; + + private SqlNode expression; + + private final boolean isGlobal; + + public SqlLetStatement( + SqlParserPos pos, SqlNode from, SqlIdentifier leftVar, SqlNode expression, boolean isGlobal) { + super(pos); + this.from = from; + this.leftVar = Objects.requireNonNull(leftVar); + this.expression = Objects.requireNonNull(expression); + this.isGlobal = isGlobal; + } + + @Override + public SqlOperator getOperator() { + return SqlLetOperator.INSTANCE; + } + + @Override + public List getOperandList() { + return ImmutableNullableList.of(from, leftVar, expression); + } + + @Override + public void validate(SqlValidator validator, SqlValidatorScope scope) { + if (leftVar.names.size() != 2) { + throw new GeaFlowDSLException( + leftVar.getParserPosition(), + "Illegal left variable field size: {}", + leftVar.names.size()); } - - @Override - public List getOperandList() { - return ImmutableNullableList.of(from, leftVar, expression); - } - - @Override - public void validate(SqlValidator validator, SqlValidatorScope scope) { - if (leftVar.names.size() != 2) { - throw new GeaFlowDSLException(leftVar.getParserPosition(), - "Illegal left variable field size: {}", leftVar.names.size()); - } - validator.validateQuery(this, scope, validator.getUnknownType()); - } - - @Override - public void setOperand(int i, SqlNode operand) { - switch (i) { - case 0: - this.from = operand; - break; - case 1: - this.leftVar = (SqlIdentifier) operand; - break; - case 2: - this.expression = operand; - break; - default: - throw new IllegalArgumentException("Illegal index: " + i); - } - } - - @Override - public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { - if (from != null) { - from.unparse(writer, 0, 0); - writer.print("\n"); - } - writer.keyword("LET"); - if (isGlobal) { - writer.keyword("GLOBAL"); - } - leftVar.unparse(writer, 0, 0); - writer.keyword("="); - expression.unparse(writer, 0, 0); + validator.validateQuery(this, scope, validator.getUnknownType()); + } + + @Override + public void setOperand(int i, SqlNode operand) { + switch (i) { + case 0: + this.from = operand; + break; + case 1: + this.leftVar = (SqlIdentifier) operand; + break; + case 2: + this.expression = operand; + break; + default: + throw new IllegalArgumentException("Illegal index: " + i); } + } - public SqlNode getFrom() { - return from; + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + if (from != null) { + from.unparse(writer, 0, 0); + writer.print("\n"); } - - public SqlIdentifier getLeftVar() { - return leftVar; - } - - public void setLeftVar(SqlIdentifier leftVar) { - this.leftVar = leftVar; - } - - public String getLeftLabel() { - return leftVar.names.get(0); - } - - public String getLeftField() { - return leftVar.names.get(1); - } - - public SqlNode getExpression() { - return expression; - } - - public void setExpression(SqlNode expression) { - this.expression = expression; - } - - public boolean isGlobal() { - return isGlobal; + writer.keyword("LET"); + if (isGlobal) { + writer.keyword("GLOBAL"); } + leftVar.unparse(writer, 0, 0); + writer.keyword("="); + expression.unparse(writer, 0, 0); + } + + public SqlNode getFrom() { + return from; + } + + public SqlIdentifier getLeftVar() { + return leftVar; + } + + public void setLeftVar(SqlIdentifier leftVar) { + this.leftVar = leftVar; + } + + public String getLeftLabel() { + return leftVar.names.get(0); + } + + public String getLeftField() { + return leftVar.names.get(1); + } + + public SqlNode getExpression() { + return expression; + } + + public void setExpression(SqlNode expression) { + this.expression = expression; + } + + public boolean isGlobal() { + return isGlobal; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlMatchEdge.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlMatchEdge.java index 6d77d7aa7..e2b63f06e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlMatchEdge.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlMatchEdge.java @@ -27,110 +27,114 @@ public class SqlMatchEdge extends SqlMatchNode { - private final EdgeDirection direction; - - private final int minHop; - - private final int maxHop; - - public SqlMatchEdge(SqlParserPos pos, SqlIdentifier name, - SqlNodeList labels, SqlNodeList propertySpecification, SqlNode where, - EdgeDirection direction, - int minHop, int maxHop) { - super(pos, name, labels, propertySpecification, where); - this.direction = direction; - this.minHop = minHop; - this.maxHop = maxHop; - } - - @Override - public SqlOperator getOperator() { - return SqlMatchEdgeOperator.INSTANCE; - } - - @Override - public SqlKind getKind() { - return SqlKind.GQL_MATCH_EDGE; - } - - @Override - public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { - if (getName() == null && getLabels() == null && getWhere() == null) { - switch (direction) { - case IN: - writer.print("<-"); - break; - case OUT: - writer.print("->"); - break; - case BOTH: - writer.print("-"); - break; - default: - throw new IllegalArgumentException("Illegal direction: " + direction); - } - } else { - if (direction == EdgeDirection.IN) { - writer.print("<"); - } - writer.print("-["); - unparseNode(writer); - writer.print("]-"); - if (direction == EdgeDirection.OUT) { - writer.print(">"); - } - if (minHop != -1 || maxHop != -1) { - writer.print("{"); - if (minHop != -1) { - writer.print(minHop); - } - writer.print(","); - if (maxHop != -1) { - writer.print(maxHop); - } - writer.print("}"); - } + private final EdgeDirection direction; + + private final int minHop; + + private final int maxHop; + + public SqlMatchEdge( + SqlParserPos pos, + SqlIdentifier name, + SqlNodeList labels, + SqlNodeList propertySpecification, + SqlNode where, + EdgeDirection direction, + int minHop, + int maxHop) { + super(pos, name, labels, propertySpecification, where); + this.direction = direction; + this.minHop = minHop; + this.maxHop = maxHop; + } + + @Override + public SqlOperator getOperator() { + return SqlMatchEdgeOperator.INSTANCE; + } + + @Override + public SqlKind getKind() { + return SqlKind.GQL_MATCH_EDGE; + } + + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + if (getName() == null && getLabels() == null && getWhere() == null) { + switch (direction) { + case IN: + writer.print("<-"); + break; + case OUT: + writer.print("->"); + break; + case BOTH: + writer.print("-"); + break; + default: + throw new IllegalArgumentException("Illegal direction: " + direction); + } + } else { + if (direction == EdgeDirection.IN) { + writer.print("<"); + } + writer.print("-["); + unparseNode(writer); + writer.print("]-"); + if (direction == EdgeDirection.OUT) { + writer.print(">"); + } + if (minHop != -1 || maxHop != -1) { + writer.print("{"); + if (minHop != -1) { + writer.print(minHop); } - + writer.print(","); + if (maxHop != -1) { + writer.print(maxHop); + } + writer.print("}"); + } } + } - public EdgeDirection getDirection() { - return direction; - } + public EdgeDirection getDirection() { + return direction; + } - public enum EdgeDirection { - OUT, - IN, - BOTH; - - public static EdgeDirection of(String value) { - for (EdgeDirection direction : EdgeDirection.values()) { - if (direction.name().equalsIgnoreCase(value)) { - return direction; - } - } - throw new IllegalArgumentException("Illegal direction value: " + value); - } + public enum EdgeDirection { + OUT, + IN, + BOTH; - public static EdgeDirection reverse(EdgeDirection direction) { - return direction == BOTH ? direction : ((direction == IN) ? OUT : IN); + public static EdgeDirection of(String value) { + for (EdgeDirection direction : EdgeDirection.values()) { + if (direction.name().equalsIgnoreCase(value)) { + return direction; } + } + throw new IllegalArgumentException("Illegal direction value: " + value); } - @Override - public void validate(SqlValidator validator, SqlValidatorScope scope) { - validator.validateQuery(this, scope, validator.getUnknownType()); + public static EdgeDirection reverse(EdgeDirection direction) { + return direction == BOTH ? direction : ((direction == IN) ? OUT : IN); } + } - public int getMinHop() { - return minHop; - } + @Override + public void validate(SqlValidator validator, SqlValidatorScope scope) { + validator.validateQuery(this, scope, validator.getUnknownType()); + } - public int getMaxHop() { - return maxHop; - } + public int getMinHop() { + return minHop; + } - public boolean isRegexMatch() { - return minHop != 1 || maxHop != 1; - } + public int getMaxHop() { + return maxHop; + } + + public boolean isRegexMatch() { + return minHop != 1 || maxHop != 1; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlMatchNode.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlMatchNode.java index 8f7a814cd..d58652c51 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlMatchNode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlMatchNode.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; + import org.apache.calcite.sql.*; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParserPos; @@ -32,180 +33,188 @@ public class SqlMatchNode extends SqlCall { - private SqlIdentifier name; - - private SqlNodeList labels; - - private SqlNode where; - - private SqlNode combineWhere; - - private SqlNodeList propertySpecification; - - public SqlMatchNode(SqlParserPos pos, SqlIdentifier name, SqlNodeList labels, - SqlNodeList propertySpecification, SqlNode where) { - super(pos); - this.name = name; - this.labels = labels; - this.propertySpecification = propertySpecification; - this.where = where; - this.combineWhere = getInnerWhere(propertySpecification, where); - } - - @Override - public SqlOperator getOperator() { - return SqlMatchNodeOperator.INSTANCE; - } - - @Override - public SqlKind getKind() { - return SqlKind.GQL_MATCH_NODE; - } - - @Override - public List getOperandList() { - return ImmutableNullableList.of(name, labels, propertySpecification, where); - } - - @Override - public void setOperand(int i, SqlNode operand) { - switch (i) { - case 0: - this.name = (SqlIdentifier) operand; - break; - case 1: - this.labels = (SqlNodeList) operand; - break; - case 2: - this.propertySpecification = (SqlNodeList) operand; - break; - case 3: - this.where = operand; - break; - default: - throw new IllegalArgumentException("Illegal index: " + i); - } - } - - @Override - public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { - writer.print("("); - unparseNode(writer); - writer.print(")"); - } - - protected void unparseNode(SqlWriter writer) { - if (name != null) { - name.unparse(writer, 0, 0); - } - if (labels != null && labels.size() > 0) { - writer.print(":"); - for (int i = 0; i < labels.size(); i++) { - SqlNode label = labels.get(i); - if (i > 0) { - writer.print("|"); - } - label.unparse(writer, 0, 0); - } - } - - if (where != null) { - writer.keyword("where"); - where.unparse(writer, 0, 0); - } - if (propertySpecification != null && propertySpecification.size() > 0) { - writer.keyword("{"); - int idx = 0; - for (SqlNode node : propertySpecification.getList()) { - if (idx % 2 != 0) { - writer.keyword(":"); - } else if (idx > 0) { - writer.keyword(","); - } - node.unparse(writer, 0, 0); - idx++; - } - writer.keyword("}"); - } - } - - public SqlIdentifier getNameId() { - return name; - } - - public SqlIdentifier setName(SqlIdentifier name) { - this.name = name; - return this.name; - } - - public SqlNodeList getLabels() { - return labels; - } - - public List getLabelNames() { - if (labels == null) { - return Collections.emptyList(); + private SqlIdentifier name; + + private SqlNodeList labels; + + private SqlNode where; + + private SqlNode combineWhere; + + private SqlNodeList propertySpecification; + + public SqlMatchNode( + SqlParserPos pos, + SqlIdentifier name, + SqlNodeList labels, + SqlNodeList propertySpecification, + SqlNode where) { + super(pos); + this.name = name; + this.labels = labels; + this.propertySpecification = propertySpecification; + this.where = where; + this.combineWhere = getInnerWhere(propertySpecification, where); + } + + @Override + public SqlOperator getOperator() { + return SqlMatchNodeOperator.INSTANCE; + } + + @Override + public SqlKind getKind() { + return SqlKind.GQL_MATCH_NODE; + } + + @Override + public List getOperandList() { + return ImmutableNullableList.of(name, labels, propertySpecification, where); + } + + @Override + public void setOperand(int i, SqlNode operand) { + switch (i) { + case 0: + this.name = (SqlIdentifier) operand; + break; + case 1: + this.labels = (SqlNodeList) operand; + break; + case 2: + this.propertySpecification = (SqlNodeList) operand; + break; + case 3: + this.where = operand; + break; + default: + throw new IllegalArgumentException("Illegal index: " + i); + } + } + + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + writer.print("("); + unparseNode(writer); + writer.print(")"); + } + + protected void unparseNode(SqlWriter writer) { + if (name != null) { + name.unparse(writer, 0, 0); + } + if (labels != null && labels.size() > 0) { + writer.print(":"); + for (int i = 0; i < labels.size(); i++) { + SqlNode label = labels.get(i); + if (i > 0) { + writer.print("|"); } - List labelNames = new ArrayList<>(labels.size()); - for (SqlNode labelNode : labels) { - SqlIdentifier label = (SqlIdentifier) labelNode; - labelNames.add(label.getSimple()); + label.unparse(writer, 0, 0); + } + } + + if (where != null) { + writer.keyword("where"); + where.unparse(writer, 0, 0); + } + if (propertySpecification != null && propertySpecification.size() > 0) { + writer.keyword("{"); + int idx = 0; + for (SqlNode node : propertySpecification.getList()) { + if (idx % 2 != 0) { + writer.keyword(":"); + } else if (idx > 0) { + writer.keyword(","); } - return labelNames; - } - - public String getName() { - if (name != null) { - return name.getSimple(); - } else { - return null; - } - } - - public SqlNode getWhere() { - return combineWhere; - } - - public static SqlNode getInnerWhere(SqlNodeList propertySpecification, SqlNode where) { - //If where is null but property specification is not null, - // construct where SqlNode use the property specification. - if (propertySpecification != null && propertySpecification.size() >= 2) { - List propertyEqualsConditions = new ArrayList<>(); - - for (int idx = 0; idx < propertySpecification.size(); idx += 2) { - SqlNode left = propertySpecification.get(idx); - SqlNode right = propertySpecification.get(idx + 1); - SqlNode equalsCondition = makeEqualsSqlNode(left, right); - propertyEqualsConditions.add(equalsCondition); - } - if (where != null) { - propertyEqualsConditions.add(where); - } - return makeAndSqlNode(propertyEqualsConditions); - } - return where; - } - - private static SqlNode makeEqualsSqlNode(SqlNode leftNode, SqlNode rightNode) { - return new SqlBasicCall(SqlStdOperatorTable.EQUALS, new SqlNode[]{leftNode, rightNode}, - leftNode.getParserPosition()); - } - - private static SqlNode makeAndSqlNode(List conditions) { - if (conditions.size() == 1) { - return conditions.get(0); // 只有一个条件时直接返回 - } else { - return new SqlBasicCall(SqlStdOperatorTable.AND, conditions.toArray(new SqlNode[0]), - conditions.get(0).getParserPosition()); - } - } - - public void setWhere(SqlNode where) { - this.combineWhere = where; - } - - @Override - public void validate(SqlValidator validator, SqlValidatorScope scope) { - validator.validateQuery(this, scope, validator.getUnknownType()); - } + node.unparse(writer, 0, 0); + idx++; + } + writer.keyword("}"); + } + } + + public SqlIdentifier getNameId() { + return name; + } + + public SqlIdentifier setName(SqlIdentifier name) { + this.name = name; + return this.name; + } + + public SqlNodeList getLabels() { + return labels; + } + + public List getLabelNames() { + if (labels == null) { + return Collections.emptyList(); + } + List labelNames = new ArrayList<>(labels.size()); + for (SqlNode labelNode : labels) { + SqlIdentifier label = (SqlIdentifier) labelNode; + labelNames.add(label.getSimple()); + } + return labelNames; + } + + public String getName() { + if (name != null) { + return name.getSimple(); + } else { + return null; + } + } + + public SqlNode getWhere() { + return combineWhere; + } + + public static SqlNode getInnerWhere(SqlNodeList propertySpecification, SqlNode where) { + // If where is null but property specification is not null, + // construct where SqlNode use the property specification. + if (propertySpecification != null && propertySpecification.size() >= 2) { + List propertyEqualsConditions = new ArrayList<>(); + + for (int idx = 0; idx < propertySpecification.size(); idx += 2) { + SqlNode left = propertySpecification.get(idx); + SqlNode right = propertySpecification.get(idx + 1); + SqlNode equalsCondition = makeEqualsSqlNode(left, right); + propertyEqualsConditions.add(equalsCondition); + } + if (where != null) { + propertyEqualsConditions.add(where); + } + return makeAndSqlNode(propertyEqualsConditions); + } + return where; + } + + private static SqlNode makeEqualsSqlNode(SqlNode leftNode, SqlNode rightNode) { + return new SqlBasicCall( + SqlStdOperatorTable.EQUALS, + new SqlNode[] {leftNode, rightNode}, + leftNode.getParserPosition()); + } + + private static SqlNode makeAndSqlNode(List conditions) { + if (conditions.size() == 1) { + return conditions.get(0); // 只有一个条件时直接返回 + } else { + return new SqlBasicCall( + SqlStdOperatorTable.AND, + conditions.toArray(new SqlNode[0]), + conditions.get(0).getParserPosition()); + } + } + + public void setWhere(SqlNode where) { + this.combineWhere = where; + } + + @Override + public void validate(SqlValidator validator, SqlValidatorScope scope) { + validator.validateQuery(this, scope, validator.getUnknownType()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlMatchPattern.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlMatchPattern.java index a864b1334..0abb5352b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlMatchPattern.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlMatchPattern.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.sqlnode; import java.util.List; + import org.apache.calcite.sql.*; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.validate.SqlValidator; @@ -29,145 +30,149 @@ public class SqlMatchPattern extends SqlCall { - private SqlNode from; - - private SqlNodeList pathPatterns; - - private SqlNode where; - - private SqlNodeList orderBy; - - private SqlNode limit; - - public SqlMatchPattern(SqlParserPos pos, SqlNode from, SqlNodeList pathPatterns, - SqlNode where, SqlNodeList orderBy, SqlNode limit) { - super(pos); - this.from = from; - this.pathPatterns = pathPatterns; - this.where = where; - this.orderBy = orderBy; - this.limit = limit; - } - - @Override - public SqlOperator getOperator() { - return SqlMatchPatternOperator.INSTANCE; - } - - @Override - public List getOperandList() { - return ImmutableNullableList.of(getFrom(), getPathPatterns(), getWhere(), - getOrderBy(), getLimit()); - } - - @Override - public SqlKind getKind() { - return SqlKind.GQL_MATCH_PATTERN; - } - - @Override - public void validate(SqlValidator validator, SqlValidatorScope scope) { - validator.validateQuery(this, scope, validator.getUnknownType()); - } - - public SqlNode getFrom() { - return from; - } - - public void setFrom(SqlNode from) { - this.from = from; - } - - public SqlNodeList getOrderBy() { - return orderBy; - } - - - public SqlNode getLimit() { - return limit; - } - - public void setOrderBy(SqlNodeList orderBy) { - this.orderBy = orderBy; - } - - public void setLimit(SqlNode limit) { - this.limit = limit; - } - - @Override - public void setOperand(int i, SqlNode operand) { - switch (i) { - case 0: - this.from = operand; - break; - case 1: - this.pathPatterns = (SqlNodeList) operand; - break; - case 2: - this.where = operand; - break; - case 3: - this.orderBy = (SqlNodeList) operand; - break; - case 4: - this.limit = operand; - break; - default: - throw new IllegalArgumentException("Illegal index: " + i); - } - } - - @Override - public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { - writer.keyword("Match"); - if (pathPatterns != null) { - for (int i = 0; i < pathPatterns.size(); i++) { - if (i > 0) { - writer.print(", "); - } - pathPatterns.get(i).unparse(writer, leftPrec, rightPrec); - writer.newlineAndIndent(); - } - } - if (where != null) { - writer.keyword("WHERE"); - where.unparse(writer, 0, 0); + private SqlNode from; + + private SqlNodeList pathPatterns; + + private SqlNode where; + + private SqlNodeList orderBy; + + private SqlNode limit; + + public SqlMatchPattern( + SqlParserPos pos, + SqlNode from, + SqlNodeList pathPatterns, + SqlNode where, + SqlNodeList orderBy, + SqlNode limit) { + super(pos); + this.from = from; + this.pathPatterns = pathPatterns; + this.where = where; + this.orderBy = orderBy; + this.limit = limit; + } + + @Override + public SqlOperator getOperator() { + return SqlMatchPatternOperator.INSTANCE; + } + + @Override + public List getOperandList() { + return ImmutableNullableList.of( + getFrom(), getPathPatterns(), getWhere(), getOrderBy(), getLimit()); + } + + @Override + public SqlKind getKind() { + return SqlKind.GQL_MATCH_PATTERN; + } + + @Override + public void validate(SqlValidator validator, SqlValidatorScope scope) { + validator.validateQuery(this, scope, validator.getUnknownType()); + } + + public SqlNode getFrom() { + return from; + } + + public void setFrom(SqlNode from) { + this.from = from; + } + + public SqlNodeList getOrderBy() { + return orderBy; + } + + public SqlNode getLimit() { + return limit; + } + + public void setOrderBy(SqlNodeList orderBy) { + this.orderBy = orderBy; + } + + public void setLimit(SqlNode limit) { + this.limit = limit; + } + + @Override + public void setOperand(int i, SqlNode operand) { + switch (i) { + case 0: + this.from = operand; + break; + case 1: + this.pathPatterns = (SqlNodeList) operand; + break; + case 2: + this.where = operand; + break; + case 3: + this.orderBy = (SqlNodeList) operand; + break; + case 4: + this.limit = operand; + break; + default: + throw new IllegalArgumentException("Illegal index: " + i); + } + } + + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + writer.keyword("Match"); + if (pathPatterns != null) { + for (int i = 0; i < pathPatterns.size(); i++) { + if (i > 0) { + writer.print(", "); } - if (orderBy != null && orderBy.size() > 0) { - writer.keyword("ORDER BY"); - for (int i = 0; i < orderBy.size(); i++) { - SqlNode label = orderBy.get(i); - if (i > 0) { - writer.print(","); - } - label.unparse(writer, leftPrec, rightPrec); - } - writer.newlineAndIndent(); - } - if (limit != null) { - writer.keyword("LIMIT"); - limit.unparse(writer, leftPrec, rightPrec); + pathPatterns.get(i).unparse(writer, leftPrec, rightPrec); + writer.newlineAndIndent(); + } + } + if (where != null) { + writer.keyword("WHERE"); + where.unparse(writer, 0, 0); + } + if (orderBy != null && orderBy.size() > 0) { + writer.keyword("ORDER BY"); + for (int i = 0; i < orderBy.size(); i++) { + SqlNode label = orderBy.get(i); + if (i > 0) { + writer.print(","); } + label.unparse(writer, leftPrec, rightPrec); + } + writer.newlineAndIndent(); } - - public SqlNodeList getPathPatterns() { - return pathPatterns; + if (limit != null) { + writer.keyword("LIMIT"); + limit.unparse(writer, leftPrec, rightPrec); } + } - public SqlNode getWhere() { - return where; - } + public SqlNodeList getPathPatterns() { + return pathPatterns; + } - public final boolean isDistinct() { - return false; - } + public SqlNode getWhere() { + return where; + } - public void setWhere(SqlNode where) { - this.where = where; - } + public final boolean isDistinct() { + return false; + } - public boolean isSinglePattern() { - return pathPatterns.size() == 1 && pathPatterns.get(0) instanceof SqlPathPattern; - } + public void setWhere(SqlNode where) { + this.where = where; + } + + public boolean isSinglePattern() { + return pathPatterns.size() == 1 && pathPatterns.get(0) instanceof SqlPathPattern; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlPathPattern.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlPathPattern.java index 1c5c02d5a..f7719f8c4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlPathPattern.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlPathPattern.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.Objects; + import org.apache.calcite.sql.*; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.validate.SqlValidator; @@ -30,78 +31,77 @@ public class SqlPathPattern extends SqlCall { - private SqlNodeList pathNodes; - - private SqlIdentifier pathAlias; - - - public SqlPathPattern(SqlParserPos pos, SqlNodeList pathNodes, SqlIdentifier pathAlias) { - super(pos); - this.pathNodes = Objects.requireNonNull(pathNodes); - this.pathAlias = pathAlias; + private SqlNodeList pathNodes; + + private SqlIdentifier pathAlias; + + public SqlPathPattern(SqlParserPos pos, SqlNodeList pathNodes, SqlIdentifier pathAlias) { + super(pos); + this.pathNodes = Objects.requireNonNull(pathNodes); + this.pathAlias = pathAlias; + } + + @Override + public SqlOperator getOperator() { + return SqlPathPatternOperator.INSTANCE; + } + + @Override + public List getOperandList() { + return ImmutableNullableList.of(pathNodes, pathAlias); + } + + @Override + public void validate(SqlValidator validator, SqlValidatorScope scope) { + validator.validateQuery(this, scope, validator.getUnknownType()); + } + + @Override + public void setOperand(int i, SqlNode operand) { + switch (i) { + case 0: + this.pathNodes = (SqlNodeList) operand; + break; + case 1: + this.pathAlias = (SqlIdentifier) operand; + break; + default: + throw new IllegalArgumentException("Illegal index: " + i); } + } - @Override - public SqlOperator getOperator() { - return SqlPathPatternOperator.INSTANCE; + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + if (pathAlias != null) { + pathAlias.unparse(writer, 0, 0); + writer.print("="); } - @Override - public List getOperandList() { - return ImmutableNullableList.of(pathNodes, pathAlias); + for (SqlNode node : pathNodes) { + node.unparse(writer, leftPrec, rightPrec); } + } - @Override - public void validate(SqlValidator validator, SqlValidatorScope scope) { - validator.validateQuery(this, scope, validator.getUnknownType()); - } + public SqlNodeList getPathNodes() { + return pathNodes; + } - @Override - public void setOperand(int i, SqlNode operand) { - switch (i) { - case 0: - this.pathNodes = (SqlNodeList) operand; - break; - case 1: - this.pathAlias = (SqlIdentifier) operand; - break; - default: - throw new IllegalArgumentException("Illegal index: " + i); - } + public String getPathAliasName() { + if (pathAlias != null) { + return pathAlias.getSimple(); } + return null; + } - @Override - public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { - if (pathAlias != null) { - pathAlias.unparse(writer, 0, 0); - writer.print("="); - } - - for (SqlNode node : pathNodes) { - node.unparse(writer, leftPrec, rightPrec); - } - } + public void setPathAlias(SqlIdentifier pathAlias) { + this.pathAlias = pathAlias; + } - public SqlNodeList getPathNodes() { - return pathNodes; - } + public SqlMatchNode getFirst() { + return (SqlMatchNode) pathNodes.get(0); + } - public String getPathAliasName() { - if (pathAlias != null) { - return pathAlias.getSimple(); - } - return null; - } - - public void setPathAlias(SqlIdentifier pathAlias) { - this.pathAlias = pathAlias; - } - - public SqlMatchNode getFirst() { - return (SqlMatchNode) pathNodes.get(0); - } - - public SqlMatchNode getLast() { - return (SqlMatchNode) pathNodes.get(pathNodes.size() - 1); - } + public SqlMatchNode getLast() { + return (SqlMatchNode) pathNodes.get(pathNodes.size() - 1); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlPathPatternSubQuery.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlPathPatternSubQuery.java index c7436c81c..4c882d5eb 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlPathPatternSubQuery.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlPathPatternSubQuery.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.Objects; + import org.apache.calcite.sql.*; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.validate.SqlValidator; @@ -31,72 +32,71 @@ public class SqlPathPatternSubQuery extends SqlCall { - private static final SqlOperator OPERATOR = SqlBasicQueryOperator.of("PathPatternSubQuery"); - - private SqlPathPattern pathPattern; - - private SqlNode returnValue; - - public SqlPathPatternSubQuery(SqlPathPattern pathPattern, - SqlNode returnValue, SqlParserPos pos) { - super(pos); - this.pathPattern = Objects.requireNonNull(pathPattern); - this.returnValue = returnValue; - } - - @Override - public void setOperand(int i, SqlNode operand) { - switch (i) { - case 0: - this.pathPattern = (SqlPathPattern) operand; - break; - case 1: - this.returnValue = operand; - break; - default: - throw new IllegalArgumentException("Illegal index: " + i); - } + private static final SqlOperator OPERATOR = SqlBasicQueryOperator.of("PathPatternSubQuery"); + + private SqlPathPattern pathPattern; + + private SqlNode returnValue; + + public SqlPathPatternSubQuery(SqlPathPattern pathPattern, SqlNode returnValue, SqlParserPos pos) { + super(pos); + this.pathPattern = Objects.requireNonNull(pathPattern); + this.returnValue = returnValue; + } + + @Override + public void setOperand(int i, SqlNode operand) { + switch (i) { + case 0: + this.pathPattern = (SqlPathPattern) operand; + break; + case 1: + this.returnValue = operand; + break; + default: + throw new IllegalArgumentException("Illegal index: " + i); } - - @Override - public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { - pathPattern.unparse(writer, 0, 0); - if (returnValue != null) { - writer.print("=>"); - returnValue.unparse(writer, 0, 0); - } - } - - @Override - public SqlOperator getOperator() { - return OPERATOR; - } - - @Override - public List getOperandList() { - return ImmutableNullableList.of(pathPattern, returnValue); - } - - public SqlPathPattern getPathPattern() { - return pathPattern; - } - - public SqlNode getReturnValue() { - return returnValue; - } - - public void setReturnValue(SqlNode returnValue) { - this.returnValue = returnValue; - } - - @Override - public SqlKind getKind() { - return SqlKind.GQL_PATH_PATTERN_SUB_QUERY; - } - - @Override - public void validate(SqlValidator validator, SqlValidatorScope scope) { - SqlValidatorNamespace namespace = validator.getNamespace(this); - namespace.validate(validator.getUnknownType()); + } + + @Override + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + pathPattern.unparse(writer, 0, 0); + if (returnValue != null) { + writer.print("=>"); + returnValue.unparse(writer, 0, 0); } + } + + @Override + public SqlOperator getOperator() { + return OPERATOR; + } + + @Override + public List getOperandList() { + return ImmutableNullableList.of(pathPattern, returnValue); + } + + public SqlPathPattern getPathPattern() { + return pathPattern; + } + + public SqlNode getReturnValue() { + return returnValue; + } + + public void setReturnValue(SqlNode returnValue) { + this.returnValue = returnValue; + } + + @Override + public SqlKind getKind() { + return SqlKind.GQL_PATH_PATTERN_SUB_QUERY; + } + + @Override + public void validate(SqlValidator validator, SqlValidatorScope scope) { + SqlValidatorNamespace namespace = validator.getNamespace(this); + namespace.validate(validator.getUnknownType()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlReturnStatement.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlReturnStatement.java index d85fa4e9e..1dbf6b562 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlReturnStatement.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlReturnStatement.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.sqlnode; import java.util.List; + import org.apache.calcite.sql.*; import org.apache.calcite.sql.SqlWriter.Frame; import org.apache.calcite.sql.SqlWriter.FrameTypeEnum; @@ -32,220 +33,230 @@ public class SqlReturnStatement extends SqlCall { - private SqlNode from; - - private SqlNodeList gqlReturnKeywordList; - - private SqlNodeList returnList; - - private SqlNodeList groupBy; - - private SqlNodeList orderBy; - - private SqlNode offset; - - private SqlNode fetch; - - public SqlReturnStatement(SqlParserPos pos, SqlNodeList gqlReturnKeywordList, SqlNode from, - SqlNodeList returnList, SqlNodeList groupBy, SqlNodeList orderBy, - SqlNode offset, SqlNode fetch) { - super(pos); - this.gqlReturnKeywordList = gqlReturnKeywordList; - this.from = from; - this.returnList = returnList; - this.groupBy = groupBy; - this.orderBy = orderBy; - this.offset = offset; - this.fetch = fetch; + private SqlNode from; + + private SqlNodeList gqlReturnKeywordList; + + private SqlNodeList returnList; + + private SqlNodeList groupBy; + + private SqlNodeList orderBy; + + private SqlNode offset; + + private SqlNode fetch; + + public SqlReturnStatement( + SqlParserPos pos, + SqlNodeList gqlReturnKeywordList, + SqlNode from, + SqlNodeList returnList, + SqlNodeList groupBy, + SqlNodeList orderBy, + SqlNode offset, + SqlNode fetch) { + super(pos); + this.gqlReturnKeywordList = gqlReturnKeywordList; + this.from = from; + this.returnList = returnList; + this.groupBy = groupBy; + this.orderBy = orderBy; + this.offset = offset; + this.fetch = fetch; + } + + @Override + public SqlOperator getOperator() { + return SqlReturnOperator.INSTANCE; + } + + @Override + public List getOperandList() { + return ImmutableNullableList.of( + getGQLReturnKeywordList(), + getFrom(), + getReturnList(), + getGroupBy(), + getOrderBy(), + getOffset(), + getFetch()); + } + + public void setOperand(int i, SqlNode operand) { + switch (i) { + case 0: + this.gqlReturnKeywordList = (SqlNodeList) operand; + break; + case 1: + this.from = operand; + break; + case 2: + this.returnList = (SqlNodeList) operand; + break; + case 3: + this.groupBy = (SqlNodeList) operand; + break; + case 4: + this.orderBy = (SqlNodeList) operand; + break; + case 5: + this.offset = operand; + break; + case 6: + this.fetch = operand; + break; + default: + throw new IllegalArgumentException("Illegal index: " + i); } - - @Override - public SqlOperator getOperator() { - return SqlReturnOperator.INSTANCE; + } + + public void setFrom(SqlNode from) { + this.from = from; + } + + public void setReturnList(SqlNodeList returnList) { + this.returnList = returnList; + } + + public void setGroupBy(SqlNodeList groupBy) { + this.groupBy = groupBy; + } + + public final SqlNodeList getOrderList() { + return this.orderBy; + } + + public void setOrderBy(SqlNodeList orderBy) { + this.orderBy = orderBy; + } + + public void setOffset(SqlNode offset) { + this.offset = offset; + } + + public void setFetch(SqlNode fetch) { + this.fetch = fetch; + } + + public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { + if (!writer.inQuery()) { + Frame frame = writer.startList(FrameTypeEnum.SUB_QUERY, "(", ")"); + unparseCall(writer, 0, 0); + writer.endList(frame); + } else { + unparseCall(writer, leftPrec, rightPrec); } + } - @Override - public List getOperandList() { - return ImmutableNullableList.of(getGQLReturnKeywordList(), getFrom(), getReturnList(), - getGroupBy(), getOrderBy(), getOffset(), getFetch()); + private void unparseCall(SqlWriter writer, int leftPrec, int rightPrec) { + if (from != null) { + from.unparse(writer, leftPrec, rightPrec); } - - public void setOperand(int i, SqlNode operand) { - switch (i) { - case 0: - this.gqlReturnKeywordList = (SqlNodeList) operand; - break; - case 1: - this.from = operand; - break; - case 2: - this.returnList = (SqlNodeList) operand; - break; - case 3: - this.groupBy = (SqlNodeList) operand; - break; - case 4: - this.orderBy = (SqlNodeList) operand; - break; - case 5: - this.offset = operand; - break; - case 6: - this.fetch = operand; - break; - default: - throw new IllegalArgumentException("Illegal index: " + i); + if (returnList != null && returnList.size() > 0) { + writer.keyword("RETURN"); + if (gqlReturnKeywordList != null) { + for (int i = 0; i < gqlReturnKeywordList.size(); i++) { + SqlNode keyword = gqlReturnKeywordList.get(i); + writer.print(" "); + keyword.unparse(writer, leftPrec, rightPrec); } - - } - - public void setFrom(SqlNode from) { - this.from = from; - } - - public void setReturnList(SqlNodeList returnList) { - this.returnList = returnList; - } - - public void setGroupBy(SqlNodeList groupBy) { - this.groupBy = groupBy; - } - - public final SqlNodeList getOrderList() { - return this.orderBy; - } - - public void setOrderBy(SqlNodeList orderBy) { - this.orderBy = orderBy; - } - - public void setOffset(SqlNode offset) { - this.offset = offset; - } - - public void setFetch(SqlNode fetch) { - this.fetch = fetch; - } - - public void unparse(SqlWriter writer, int leftPrec, int rightPrec) { - if (!writer.inQuery()) { - Frame frame = writer.startList(FrameTypeEnum.SUB_QUERY, "(", ")"); - unparseCall(writer, 0, 0); - writer.endList(frame); - } else { - unparseCall(writer, leftPrec, rightPrec); + writer.print(" "); + } + for (int i = 0; i < returnList.size(); i++) { + SqlNode label = returnList.get(i); + if (i > 0) { + writer.print(","); } + label.unparse(writer, leftPrec, rightPrec); + } + writer.newlineAndIndent(); } - - private void unparseCall(SqlWriter writer, int leftPrec, int rightPrec) { - if (from != null) { - from.unparse(writer, leftPrec, rightPrec); - } - if (returnList != null && returnList.size() > 0) { - writer.keyword("RETURN"); - if (gqlReturnKeywordList != null) { - for (int i = 0; i < gqlReturnKeywordList.size(); i++) { - SqlNode keyword = gqlReturnKeywordList.get(i); - writer.print(" "); - keyword.unparse(writer, leftPrec, rightPrec); - } - writer.print(" "); - } - for (int i = 0; i < returnList.size(); i++) { - SqlNode label = returnList.get(i); - if (i > 0) { - writer.print(","); - } - label.unparse(writer, leftPrec, rightPrec); - } - writer.newlineAndIndent(); - } - if (groupBy != null && groupBy.size() > 0) { - writer.keyword("GROUP BY"); - for (int i = 0; i < groupBy.size(); i++) { - SqlNode label = groupBy.get(i); - if (i > 0) { - writer.print(","); - } - label.unparse(writer, leftPrec, rightPrec); - } - writer.newlineAndIndent(); - } - if (orderBy != null && orderBy.size() > 0) { - writer.keyword("ORDER BY"); - for (int i = 0; i < orderBy.size(); i++) { - SqlNode label = orderBy.get(i); - if (i > 0) { - writer.print(","); - } - label.unparse(writer, leftPrec, rightPrec); - } - writer.newlineAndIndent(); - } - - if (fetch != null) { - writer.keyword("LIMIT"); - fetch.unparse(writer, leftPrec, rightPrec); + if (groupBy != null && groupBy.size() > 0) { + writer.keyword("GROUP BY"); + for (int i = 0; i < groupBy.size(); i++) { + SqlNode label = groupBy.get(i); + if (i > 0) { + writer.print(","); } - - if (offset != null) { - writer.keyword("OFFSET"); - offset.unparse(writer, leftPrec, rightPrec); - } - } - - @Override - public SqlKind getKind() { - return SqlKind.GQL_RETURN; - } - - public SqlNode getGQLReturnKeywordList() { - return gqlReturnKeywordList; + label.unparse(writer, leftPrec, rightPrec); + } + writer.newlineAndIndent(); } - - public SqlNode getFrom() { - return from; - } - - public SqlNodeList getReturnList() { - return returnList; - } - - public SqlNodeList getGroupBy() { - return groupBy; - } - - public SqlNodeList getOrderBy() { - return orderBy; - } - - public SqlNode getOffset() { - return offset; - } - - public SqlNode getFetch() { - return fetch; + if (orderBy != null && orderBy.size() > 0) { + writer.keyword("ORDER BY"); + for (int i = 0; i < orderBy.size(); i++) { + SqlNode label = orderBy.get(i); + if (i > 0) { + writer.print(","); + } + label.unparse(writer, leftPrec, rightPrec); + } + writer.newlineAndIndent(); } - @Override - public void validate(SqlValidator validator, SqlValidatorScope scope) { - validator.validateQuery(this, scope, validator.getUnknownType()); + if (fetch != null) { + writer.keyword("LIMIT"); + fetch.unparse(writer, leftPrec, rightPrec); } - public final boolean isDistinct() { - return getModifierNode(GQLReturnKeyword.DISTINCT) != null; + if (offset != null) { + writer.keyword("OFFSET"); + offset.unparse(writer, leftPrec, rightPrec); } - - public final SqlNode getModifierNode(GQLReturnKeyword modifier) { - if (gqlReturnKeywordList != null) { - for (SqlNode keyword : gqlReturnKeywordList) { - GQLReturnKeyword keyword2 = - ((SqlLiteral) keyword).symbolValue(GQLReturnKeyword.class); - if (keyword2 == modifier) { - return keyword; - } - } + } + + @Override + public SqlKind getKind() { + return SqlKind.GQL_RETURN; + } + + public SqlNode getGQLReturnKeywordList() { + return gqlReturnKeywordList; + } + + public SqlNode getFrom() { + return from; + } + + public SqlNodeList getReturnList() { + return returnList; + } + + public SqlNodeList getGroupBy() { + return groupBy; + } + + public SqlNodeList getOrderBy() { + return orderBy; + } + + public SqlNode getOffset() { + return offset; + } + + public SqlNode getFetch() { + return fetch; + } + + @Override + public void validate(SqlValidator validator, SqlValidatorScope scope) { + validator.validateQuery(this, scope, validator.getUnknownType()); + } + + public final boolean isDistinct() { + return getModifierNode(GQLReturnKeyword.DISTINCT) != null; + } + + public final SqlNode getModifierNode(GQLReturnKeyword modifier) { + if (gqlReturnKeywordList != null) { + for (SqlNode keyword : gqlReturnKeywordList) { + GQLReturnKeyword keyword2 = ((SqlLiteral) keyword).symbolValue(GQLReturnKeyword.class); + if (keyword2 == modifier) { + return keyword; } - return null; + } } + return null; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlSameCall.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlSameCall.java index 7e4d5d545..83216063b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlSameCall.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlSameCall.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; + import org.apache.calcite.sql.SqlCall; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlOperator; @@ -34,12 +35,13 @@ /** * SqlNode representing the ISO-GQL SAME predicate function. * - *

The SAME predicate checks if multiple element references point to the same - * graph element (identity check, not value equality). + *

The SAME predicate checks if multiple element references point to the same graph element + * (identity check, not value equality). * *

Syntax: SAME(element_ref1, element_ref2 [, element_ref3, ...]) * *

Example: + * *

  * MATCH (a:Person)-[:KNOWS]->(b), (b)-[:KNOWS]->(c)
  * WHERE SAME(a, c)
@@ -52,80 +54,75 @@
  */
 public class SqlSameCall extends SqlCall {
 
-    private final List operands;
-
-    /**
-     * Creates a SqlSameCall.
-     *
-     * @param pos Parser position
-     * @param operands List of element reference expressions (must be 2 or more)
-     */
-    public SqlSameCall(SqlParserPos pos, List operands) {
-        super(pos);
-        // Create a mutable copy to allow setOperand to work
-        this.operands = new ArrayList<>(Objects.requireNonNull(operands, "operands"));
-
-        // ISO-GQL requires at least 2 arguments
-        if (operands.size() < 2) {
-            throw new IllegalArgumentException(
-                "SAME predicate requires at least 2 arguments, got: " + operands.size());
-        }
-    }
+  private final List operands;
 
-    @Override
-    public SqlOperator getOperator() {
-        return SqlSameOperator.INSTANCE;
-    }
+  /**
+   * Creates a SqlSameCall.
+   *
+   * @param pos Parser position
+   * @param operands List of element reference expressions (must be 2 or more)
+   */
+  public SqlSameCall(SqlParserPos pos, List operands) {
+    super(pos);
+    // Create a mutable copy to allow setOperand to work
+    this.operands = new ArrayList<>(Objects.requireNonNull(operands, "operands"));
 
-    @Override
-    public List getOperandList() {
-        return operands;
+    // ISO-GQL requires at least 2 arguments
+    if (operands.size() < 2) {
+      throw new IllegalArgumentException(
+          "SAME predicate requires at least 2 arguments, got: " + operands.size());
     }
+  }
 
-    @Override
-    public void validate(SqlValidator validator, SqlValidatorScope scope) {
-        // Validation will be handled by GQLSameValidator
-        // This just validates the syntax is correct
-        for (SqlNode operand : operands) {
-            operand.validate(validator, scope);
-        }
-    }
+  @Override
+  public SqlOperator getOperator() {
+    return SqlSameOperator.INSTANCE;
+  }
 
-    @Override
-    public void setOperand(int i, SqlNode operand) {
-        if (i < 0 || i >= operands.size()) {
-            throw new IllegalArgumentException("Invalid operand index: " + i);
-        }
-        operands.set(i, operand);
+  @Override
+  public List getOperandList() {
+    return operands;
+  }
+
+  @Override
+  public void validate(SqlValidator validator, SqlValidatorScope scope) {
+    // Validation will be handled by GQLSameValidator
+    // This just validates the syntax is correct
+    for (SqlNode operand : operands) {
+      operand.validate(validator, scope);
     }
+  }
 
-    @Override
-    public void unparse(SqlWriter writer, int leftPrec, int rightPrec) {
-        writer.print("SAME");
-        final SqlWriter.Frame frame =
-            writer.startList(SqlWriter.FrameTypeEnum.FUN_CALL, "(", ")");
+  @Override
+  public void setOperand(int i, SqlNode operand) {
+    if (i < 0 || i >= operands.size()) {
+      throw new IllegalArgumentException("Invalid operand index: " + i);
+    }
+    operands.set(i, operand);
+  }
 
-        for (int i = 0; i < operands.size(); i++) {
-            if (i > 0) {
-                writer.sep(",");
-            }
-            operands.get(i).unparse(writer, 0, 0);
-        }
+  @Override
+  public void unparse(SqlWriter writer, int leftPrec, int rightPrec) {
+    writer.print("SAME");
+    final SqlWriter.Frame frame = writer.startList(SqlWriter.FrameTypeEnum.FUN_CALL, "(", ")");
 
-        writer.endList(frame);
+    for (int i = 0; i < operands.size(); i++) {
+      if (i > 0) {
+        writer.sep(",");
+      }
+      operands.get(i).unparse(writer, 0, 0);
     }
 
-    /**
-     * Returns the number of operands (element references) in this SAME call.
-     */
-    public int getOperandCount() {
-        return operands.size();
-    }
+    writer.endList(frame);
+  }
 
-    /**
-     * Returns the operand at the specified index.
-     */
-    public SqlNode getOperand(int index) {
-        return operands.get(index);
-    }
+  /** Returns the number of operands (element references) in this SAME call. */
+  public int getOperandCount() {
+    return operands.size();
+  }
+
+  /** Returns the operand at the specified index. */
+  public SqlNode getOperand(int index) {
+    return operands.get(index);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlShardNode.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlShardNode.java
index 886ca3bb8..17668c6c8 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlShardNode.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlShardNode.java
@@ -19,55 +19,58 @@
 
 package org.apache.geaflow.dsl.sqlnode;
 
-import com.google.common.collect.ImmutableList;
 import java.util.List;
+
 import org.apache.calcite.sql.*;
 import org.apache.calcite.sql.parser.SqlParserPos;
 
+import com.google.common.collect.ImmutableList;
+
 public class SqlShardNode extends SqlCall {
 
-    private static final SqlOperator OPERATOR = new SqlSpecialOperator("SqlShardNode", SqlKind.OTHER_DDL);
+  private static final SqlOperator OPERATOR =
+      new SqlSpecialOperator("SqlShardNode", SqlKind.OTHER_DDL);
 
-    private SqlIdentifier type;
+  private SqlIdentifier type;
 
-    private SqlNumericLiteral shardCount;
+  private SqlNumericLiteral shardCount;
 
-    public SqlShardNode(SqlParserPos pos, SqlIdentifier type, SqlNumericLiteral shardCount) {
-        super(pos);
-        this.type = type;
-        this.shardCount = shardCount;
-    }
+  public SqlShardNode(SqlParserPos pos, SqlIdentifier type, SqlNumericLiteral shardCount) {
+    super(pos);
+    this.type = type;
+    this.shardCount = shardCount;
+  }
 
-    @Override
-    public void setOperand(int i, SqlNode operand) {
-        switch (i) {
-            case 0:
-                this.type = (SqlIdentifier) operand;
-                break;
-            case 1:
-                this.shardCount = (SqlNumericLiteral) operand;
-                break;
-            default:
-                throw new IllegalArgumentException("Illegal index: " + i);
-        }
+  @Override
+  public void setOperand(int i, SqlNode operand) {
+    switch (i) {
+      case 0:
+        this.type = (SqlIdentifier) operand;
+        break;
+      case 1:
+        this.shardCount = (SqlNumericLiteral) operand;
+        break;
+      default:
+        throw new IllegalArgumentException("Illegal index: " + i);
     }
+  }
 
-    @Override
-    public SqlOperator getOperator() {
-        return OPERATOR;
-    }
+  @Override
+  public SqlOperator getOperator() {
+    return OPERATOR;
+  }
 
-    @Override
-    public List getOperandList() {
-        return ImmutableList.of(type, shardCount);
-    }
+  @Override
+  public List getOperandList() {
+    return ImmutableList.of(type, shardCount);
+  }
 
-    @Override
-    public void unparse(SqlWriter writer, int leftPrec, int rightPrec) {
-        writer.keyword("Shard As ");
-        type.unparse(writer, 0, 0);
-        writer.print("(");
-        shardCount.unparse(writer, 0, 0);
-        writer.print(")");
-    }
+  @Override
+  public void unparse(SqlWriter writer, int leftPrec, int rightPrec) {
+    writer.keyword("Shard As ");
+    type.unparse(writer, 0, 0);
+    writer.print("(");
+    shardCount.unparse(writer, 0, 0);
+    writer.print(")");
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlTableColumn.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlTableColumn.java
index 61c97431a..a948057b4 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlTableColumn.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlTableColumn.java
@@ -19,9 +19,9 @@
 
 package org.apache.geaflow.dsl.sqlnode;
 
-import com.google.common.collect.ImmutableList;
 import java.util.List;
 import java.util.Objects;
+
 import org.apache.calcite.sql.*;
 import org.apache.calcite.sql.parser.SqlParserPos;
 import org.apache.calcite.sql.validate.SqlValidator;
@@ -30,170 +30,169 @@
 import org.apache.geaflow.dsl.common.types.TableField;
 import org.apache.geaflow.dsl.util.SqlTypeUtil;
 
+import com.google.common.collect.ImmutableList;
+
 public class SqlTableColumn extends SqlCall {
 
-    private static final SqlOperator OPERATOR =
-        new SqlSpecialOperator("Table Column", SqlKind.OTHER_DDL);
-
-    private SqlIdentifier name;
-    private SqlDataTypeSpec type;
-    private SqlIdentifier category;
-    private SqlIdentifier typeFrom;
-
-    public SqlTableColumn(SqlIdentifier name,
-                          SqlDataTypeSpec type,
-                          SqlIdentifier category,
-                          SqlParserPos pos) {
-        super(pos);
-        this.name = name;
-        this.type = Objects.requireNonNull(type);
-        this.typeFrom = null;
-        this.category = category;
+  private static final SqlOperator OPERATOR =
+      new SqlSpecialOperator("Table Column", SqlKind.OTHER_DDL);
+
+  private SqlIdentifier name;
+  private SqlDataTypeSpec type;
+  private SqlIdentifier category;
+  private SqlIdentifier typeFrom;
+
+  public SqlTableColumn(
+      SqlIdentifier name, SqlDataTypeSpec type, SqlIdentifier category, SqlParserPos pos) {
+    super(pos);
+    this.name = name;
+    this.type = Objects.requireNonNull(type);
+    this.typeFrom = null;
+    this.category = category;
+  }
+
+  public SqlTableColumn(
+      SqlIdentifier name,
+      SqlDataTypeSpec type,
+      SqlIdentifier typeFrom,
+      SqlIdentifier category,
+      SqlParserPos pos) {
+    super(pos);
+    this.name = name;
+    this.type = type;
+    this.typeFrom = typeFrom;
+    assert type != null || typeFrom != null;
+    this.category = category;
+  }
+
+  @Override
+  public SqlOperator getOperator() {
+    return OPERATOR;
+  }
+
+  @Override
+  public List getOperandList() {
+    return ImmutableList.of(getName(), getType() != null ? getType() : getTypeFrom(), category);
+  }
+
+  @Override
+  public void setOperand(int i, SqlNode operand) {
+    switch (i) {
+      case 0:
+        this.name = (SqlIdentifier) operand;
+        break;
+      case 1:
+        if (operand instanceof SqlDataTypeSpec) {
+          this.type = (SqlDataTypeSpec) operand;
+          this.typeFrom = null;
+        } else {
+          this.type = null;
+          this.typeFrom = (SqlIdentifier) operand;
+        }
+        break;
+      case 2:
+        this.category = (SqlIdentifier) operand;
+        break;
+      default:
+        throw new IllegalArgumentException("Illegal index: " + i);
     }
+  }
 
-    public SqlTableColumn(SqlIdentifier name,
-                          SqlDataTypeSpec type,
-                          SqlIdentifier typeFrom,
-                          SqlIdentifier category,
-                          SqlParserPos pos) {
-        super(pos);
-        this.name = name;
-        this.type = type;
-        this.typeFrom = typeFrom;
-        assert type != null || typeFrom != null;
-        this.category = category;
-    }
+  @Override
+  public void unparse(SqlWriter writer, int leftPrec, int rightPrec) {
 
-    @Override
-    public SqlOperator getOperator() {
-        return OPERATOR;
+    name.unparse(writer, leftPrec, rightPrec);
+    writer.print(" ");
+    if (type == null) {
+      writer.keyword("from");
+      typeFrom.unparse(writer, leftPrec, rightPrec);
+    } else {
+      type.unparse(writer, leftPrec, rightPrec);
     }
 
-    @Override
-    public List getOperandList() {
-        return ImmutableList.of(getName(), getType() != null ? getType() : getTypeFrom(), category);
+    if (category != null) {
+      ColumnCategory category = getCategory();
+      writer.print(" ");
+      writer.keyword(category.name);
     }
+  }
 
-    @Override
-    public void setOperand(int i, SqlNode operand) {
-        switch (i) {
-            case 0:
-                this.name = (SqlIdentifier) operand;
-                break;
-            case 1:
-                if (operand instanceof SqlDataTypeSpec) {
-                    this.type = (SqlDataTypeSpec) operand;
-                    this.typeFrom = null;
-                } else {
-                    this.type = null;
-                    this.typeFrom = (SqlIdentifier) operand;
-                }
-                break;
-            case 2:
-                this.category = (SqlIdentifier) operand;
-                break;
-            default:
-                throw new IllegalArgumentException("Illegal index: " + i);
-        }
+  public void validate() {
+    if (type == null && typeFrom != null) {
+      ColumnCategory columnCategory = ColumnCategory.of(this.category.toString());
+      assert columnCategory == ColumnCategory.SOURCE_ID
+              || columnCategory == ColumnCategory.DESTINATION_ID
+          : "Only edge source/destination id field can use type from syntax.";
     }
+  }
 
-    @Override
-    public void unparse(SqlWriter writer,
-                        int leftPrec,
-                        int rightPrec) {
+  @Override
+  public void validate(SqlValidator validator, SqlValidatorScope scope) {
+    this.validate();
+    super.validate(validator, scope);
+  }
 
-        name.unparse(writer, leftPrec, rightPrec);
-        writer.print(" ");
-        if (type == null) {
-            writer.keyword("from");
-            typeFrom.unparse(writer, leftPrec, rightPrec);
-        } else {
-            type.unparse(writer, leftPrec, rightPrec);
-        }
+  public SqlIdentifier getName() {
+    return name;
+  }
 
-        if (category != null) {
-            ColumnCategory category = getCategory();
-            writer.print(" ");
-            writer.keyword(category.name);
-        }
-    }
+  public void setName(SqlIdentifier name) {
+    this.name = name;
+  }
 
-    public void validate() {
-        if (type == null && typeFrom != null) {
-            ColumnCategory columnCategory = ColumnCategory.of(this.category.toString());
-            assert columnCategory == ColumnCategory.SOURCE_ID
-                || columnCategory == ColumnCategory.DESTINATION_ID
-                : "Only edge source/destination id field can use type from syntax.";
-        }
-    }
+  public SqlDataTypeSpec getType() {
+    return type;
+  }
 
-    @Override
-    public void validate(SqlValidator validator, SqlValidatorScope scope) {
-        this.validate();
-        super.validate(validator, scope);
-    }
+  public SqlIdentifier getTypeFrom() {
+    return typeFrom;
+  }
 
-    public SqlIdentifier getName() {
-        return name;
+  public ColumnCategory getCategory() {
+    if (category == null) {
+      return ColumnCategory.NONE;
     }
+    return ColumnCategory.of(category.getSimple());
+  }
 
-    public void setName(SqlIdentifier name) {
-        this.name = name;
+  public TableField toTableField() {
+    IType columnType = SqlTypeUtil.convertType(type);
+    Boolean nullable = type.getNullable();
+    if (nullable == null) {
+      nullable = true;
     }
+    return toTableField(columnType, nullable);
+  }
 
-    public SqlDataTypeSpec getType() {
-        return type;
-    }
+  public TableField toTableField(IType columnType, boolean nullable) {
+    String columnName = name.getSimple();
+    return new TableField(columnName, columnType, nullable);
+  }
 
-    public SqlIdentifier getTypeFrom() {
-        return typeFrom;
-    }
+  public enum ColumnCategory {
+    NONE(""),
+    ID("ID"),
+    SOURCE_ID("SOURCE ID"),
+    DESTINATION_ID("DESTINATION ID"),
+    TIMESTAMP("TIMESTAMP");
 
-    public ColumnCategory getCategory() {
-        if (category == null) {
-            return ColumnCategory.NONE;
-        }
-        return ColumnCategory.of(category.getSimple());
-    }
+    private final String name;
 
-    public TableField toTableField() {
-        IType columnType = SqlTypeUtil.convertType(type);
-        Boolean nullable = type.getNullable();
-        if (nullable == null) {
-            nullable = true;
-        }
-        return toTableField(columnType, nullable);
+    ColumnCategory(String name) {
+      this.name = Objects.requireNonNull(name);
     }
 
-    public TableField toTableField(IType columnType, boolean nullable) {
-        String columnName = name.getSimple();
-        return new TableField(columnName, columnType, nullable);
+    public String getName() {
+      return name;
     }
 
-    public enum ColumnCategory {
-        NONE(""),
-        ID("ID"),
-        SOURCE_ID("SOURCE ID"),
-        DESTINATION_ID("DESTINATION ID"),
-        TIMESTAMP("TIMESTAMP");
-
-        private final String name;
-
-        ColumnCategory(String name) {
-            this.name = Objects.requireNonNull(name);
-        }
-
-        public String getName() {
-            return name;
-        }
-
-        public static ColumnCategory of(String name) {
-            for (ColumnCategory category : values()) {
-                if (category.name.equalsIgnoreCase(name)) {
-                    return category;
-                }
-            }
-            return NONE;
+    public static ColumnCategory of(String name) {
+      for (ColumnCategory category : values()) {
+        if (category.name.equalsIgnoreCase(name)) {
+          return category;
         }
+      }
+      return NONE;
     }
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlTableProperty.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlTableProperty.java
index 68cb40e9b..fb6e5a6fa 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlTableProperty.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlTableProperty.java
@@ -19,82 +19,78 @@
 
 package org.apache.geaflow.dsl.sqlnode;
 
-import com.google.common.collect.ImmutableList;
 import java.util.List;
+
 import org.apache.calcite.sql.*;
 import org.apache.calcite.sql.parser.SqlParserPos;
 import org.apache.calcite.util.NlsString;
 
-/**
- * Parse tree node that represents a Sql table property.
- */
+import com.google.common.collect.ImmutableList;
+
+/** Parse tree node that represents a Sql table property. */
 public class SqlTableProperty extends SqlCall {
 
-    private static final SqlOperator OPERATOR =
-        new SqlSpecialOperator("Table Property", SqlKind.OTHER);
+  private static final SqlOperator OPERATOR =
+      new SqlSpecialOperator("Table Property", SqlKind.OTHER);
 
-    private SqlIdentifier key;
-    private SqlNode value;
+  private SqlIdentifier key;
+  private SqlNode value;
 
-    public SqlTableProperty(SqlIdentifier key,
-                            SqlNode value,
-                            SqlParserPos pos) {
-        super(pos);
-        this.key = key;
-        this.value = value;
-    }
+  public SqlTableProperty(SqlIdentifier key, SqlNode value, SqlParserPos pos) {
+    super(pos);
+    this.key = key;
+    this.value = value;
+  }
 
-    public SqlIdentifier getKey() {
-        return key;
-    }
+  public SqlIdentifier getKey() {
+    return key;
+  }
 
-    public void setKey(SqlIdentifier key) {
-        this.key = key;
-    }
+  public void setKey(SqlIdentifier key) {
+    this.key = key;
+  }
 
-    public SqlNode getValue() {
-        return value;
-    }
+  public SqlNode getValue() {
+    return value;
+  }
 
-    public void setValue(SqlNode value) {
-        this.value = value;
-    }
+  public void setValue(SqlNode value) {
+    this.value = value;
+  }
 
-    @Override
-    public SqlOperator getOperator() {
-        return OPERATOR;
-    }
+  @Override
+  public SqlOperator getOperator() {
+    return OPERATOR;
+  }
 
-    @Override
-    public List getOperandList() {
-        return ImmutableList.of(getKey(), getValue());
-    }
+  @Override
+  public List getOperandList() {
+    return ImmutableList.of(getKey(), getValue());
+  }
 
-    @Override
-    public void setOperand(int i, SqlNode operand) {
-        switch (i) {
-            case 0:
-                this.key = (SqlIdentifier) operand;
-                break;
-            case 1:
-                this.value = operand;
-                break;
-            default:
-                throw new IllegalArgumentException("Illegal index: " + i);
-        }
+  @Override
+  public void setOperand(int i, SqlNode operand) {
+    switch (i) {
+      case 0:
+        this.key = (SqlIdentifier) operand;
+        break;
+      case 1:
+        this.value = operand;
+        break;
+      default:
+        throw new IllegalArgumentException("Illegal index: " + i);
     }
+  }
 
-    @Override
-    public void unparse(SqlWriter writer,
-                        int leftPrec,
-                        int rightPrec) {
-        key.unparse(writer, 0, 0);
-        writer.print("=");
-        if (value instanceof SqlCharStringLiteral) {
-            NlsString nlsString = (NlsString) ((SqlCharStringLiteral) value).getValue();
-            writer.print("'" + nlsString.getValue() + "'");
-        } else {
-            value.unparse(writer, 0, 0);
-        }
+  @Override
+  public void unparse(SqlWriter writer, int leftPrec, int rightPrec) {
+    key.unparse(writer, 0, 0);
+    writer.print("=");
+    if (value instanceof SqlCharStringLiteral) {
+      NlsString nlsString = (NlsString) ((SqlCharStringLiteral) value).getValue();
+      writer.print("'" + nlsString.getValue() + "'");
+    } else {
+      value.unparse(writer, 0, 0);
     }
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlUnionPathPattern.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlUnionPathPattern.java
index 14de5c40f..ff36a219e 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlUnionPathPattern.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlUnionPathPattern.java
@@ -21,6 +21,7 @@
 
 import java.util.List;
 import java.util.Objects;
+
 import org.apache.calcite.sql.*;
 import org.apache.calcite.sql.parser.SqlParserPos;
 import org.apache.calcite.sql.validate.SqlValidator;
@@ -31,94 +32,94 @@
 
 public class SqlUnionPathPattern extends SqlCall {
 
-    private SqlNode left;
+  private SqlNode left;
 
-    private SqlNode right;
+  private SqlNode right;
 
-    private SqlLiteral unionType;
+  private SqlLiteral unionType;
 
-    public SqlUnionPathPattern(SqlParserPos pos, SqlNode left, SqlNode right, boolean distinct) {
-        super(pos);
-        this.left = Objects.requireNonNull(left);
-        this.right = Objects.requireNonNull(right);
-        if (distinct) {
-            this.unionType = SqlLiteral.createSymbol(UnionPathPatternType.UNION_DISTINCT, pos);
-        } else {
-            this.unionType = SqlLiteral.createSymbol(UnionPathPatternType.UNION_ALL, pos);
-        }
+  public SqlUnionPathPattern(SqlParserPos pos, SqlNode left, SqlNode right, boolean distinct) {
+    super(pos);
+    this.left = Objects.requireNonNull(left);
+    this.right = Objects.requireNonNull(right);
+    if (distinct) {
+      this.unionType = SqlLiteral.createSymbol(UnionPathPatternType.UNION_DISTINCT, pos);
+    } else {
+      this.unionType = SqlLiteral.createSymbol(UnionPathPatternType.UNION_ALL, pos);
     }
-
-    @Override
-    public SqlOperator getOperator() {
-        return SqlPathPatternOperator.INSTANCE;
+  }
+
+  @Override
+  public SqlOperator getOperator() {
+    return SqlPathPatternOperator.INSTANCE;
+  }
+
+  @Override
+  public List getOperandList() {
+    return ImmutableNullableList.of(left, right, unionType);
+  }
+
+  @Override
+  public void validate(SqlValidator validator, SqlValidatorScope scope) {
+    validator.validateQuery(this, scope, validator.getUnknownType());
+  }
+
+  @Override
+  public void setOperand(int i, SqlNode operand) {
+    switch (i) {
+      case 0:
+        this.left = operand;
+        break;
+      case 1:
+        this.right = operand;
+        break;
+      case 2:
+        this.unionType = (SqlLiteral) operand;
+        break;
+      default:
+        throw new IllegalArgumentException("Illegal index: " + i);
     }
-
-    @Override
-    public List getOperandList() {
-        return ImmutableNullableList.of(left, right, unionType);
+  }
+
+  @Override
+  public void unparse(SqlWriter writer, int leftPrec, int rightPrec) {
+    left.unparse(writer, leftPrec, rightPrec);
+    switch (getUnionPathPatternType()) {
+      case UNION_DISTINCT:
+        writer.print(" | ");
+        break;
+      case UNION_ALL:
+        writer.print(" |+| ");
+        break;
+      default:
+        throw new GeaFlowDSLException(
+            "Unknown union path pattern type: " + getUnionPathPatternType());
     }
+    right.unparse(writer, leftPrec, rightPrec);
+  }
 
-    @Override
-    public void validate(SqlValidator validator, SqlValidatorScope scope) {
-        validator.validateQuery(this, scope, validator.getUnknownType());
-    }
+  public SqlNode getLeft() {
+    return left;
+  }
 
-    @Override
-    public void setOperand(int i, SqlNode operand) {
-        switch (i) {
-            case 0:
-                this.left = operand;
-                break;
-            case 1:
-                this.right = operand;
-                break;
-            case 2:
-                this.unionType = (SqlLiteral) operand;
-                break;
-            default:
-                throw new IllegalArgumentException("Illegal index: " + i);
-        }
-    }
+  public SqlNode getRight() {
+    return right;
+  }
 
-    @Override
-    public void unparse(SqlWriter writer, int leftPrec, int rightPrec) {
-        left.unparse(writer, leftPrec, rightPrec);
-        switch (getUnionPathPatternType()) {
-            case UNION_DISTINCT:
-                writer.print(" | ");
-                break;
-            case UNION_ALL:
-                writer.print(" |+| ");
-                break;
-            default:
-                throw new GeaFlowDSLException("Unknown union path pattern type: "
-                    + getUnionPathPatternType());
-        }
-        right.unparse(writer, leftPrec, rightPrec);
-    }
+  public boolean isDistinct() {
+    return getUnionPathPatternType() == UnionPathPatternType.UNION_DISTINCT;
+  }
 
-    public SqlNode getLeft() {
-        return left;
-    }
+  public boolean isUnionAll() {
+    return getUnionPathPatternType() == UnionPathPatternType.UNION_ALL;
+  }
 
-    public SqlNode getRight() {
-        return right;
-    }
+  public final UnionPathPatternType getUnionPathPatternType() {
+    return unionType.symbolValue(UnionPathPatternType.class);
+  }
 
-    public boolean isDistinct() {
-        return getUnionPathPatternType() == UnionPathPatternType.UNION_DISTINCT;
-    }
-
-    public boolean isUnionAll() {
-        return getUnionPathPatternType() == UnionPathPatternType.UNION_ALL;
-    }
-
-    public final UnionPathPatternType getUnionPathPatternType() {
-        return unionType.symbolValue(UnionPathPatternType.class);
-    }
-
-    public enum UnionPathPatternType {
-        UNION_DISTINCT,
-        UNION_ALL
-    }
+  public enum UnionPathPatternType {
+    UNION_DISTINCT,
+    UNION_ALL
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlUseGraph.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlUseGraph.java
index 294b50859..25a6cad46 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlUseGraph.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlUseGraph.java
@@ -19,54 +19,56 @@
 
 package org.apache.geaflow.dsl.sqlnode;
 
-import com.google.common.collect.ImmutableList;
 import java.util.List;
 import java.util.Objects;
+
 import org.apache.calcite.sql.*;
 import org.apache.calcite.sql.parser.SqlParserPos;
 
+import com.google.common.collect.ImmutableList;
+
 public class SqlUseGraph extends SqlAlter {
 
-    private static final SqlOperator OPERATOR = new SqlSpecialOperator("SqlUseGraph",
-        SqlKind.USE_GRAPH);
+  private static final SqlOperator OPERATOR =
+      new SqlSpecialOperator("SqlUseGraph", SqlKind.USE_GRAPH);
 
-    private SqlIdentifier graph;
+  private SqlIdentifier graph;
 
-    public SqlUseGraph(SqlParserPos pos, SqlIdentifier graph) {
-        super(pos);
-        this.graph = Objects.requireNonNull(graph);
-    }
+  public SqlUseGraph(SqlParserPos pos, SqlIdentifier graph) {
+    super(pos);
+    this.graph = Objects.requireNonNull(graph);
+  }
 
-    @Override
-    protected void unparseAlterOperation(SqlWriter sqlWriter, int leftPrec, int rightPrec) {
-        sqlWriter.keyword("USE");
-        sqlWriter.keyword("GRAPH");
-        graph.unparse(sqlWriter, 0, 0);
-    }
+  @Override
+  protected void unparseAlterOperation(SqlWriter sqlWriter, int leftPrec, int rightPrec) {
+    sqlWriter.keyword("USE");
+    sqlWriter.keyword("GRAPH");
+    graph.unparse(sqlWriter, 0, 0);
+  }
 
-    @Override
-    public SqlOperator getOperator() {
-        return OPERATOR;
-    }
+  @Override
+  public SqlOperator getOperator() {
+    return OPERATOR;
+  }
 
-    @Override
-    public List getOperandList() {
-        return ImmutableList.of(graph);
-    }
+  @Override
+  public List getOperandList() {
+    return ImmutableList.of(graph);
+  }
 
-    @Override
-    public void setOperand(int i, SqlNode operand) {
-        if (i == 0) {
-            this.graph = (SqlIdentifier) operand;
-        } else {
-            throw new IllegalArgumentException("Illegal index: " + i);
-        }
+  @Override
+  public void setOperand(int i, SqlNode operand) {
+    if (i == 0) {
+      this.graph = (SqlIdentifier) operand;
+    } else {
+      throw new IllegalArgumentException("Illegal index: " + i);
     }
+  }
 
-    public String getGraph() {
-        if (graph.names.size() == 1) {
-            return graph.getSimple();
-        }
-        throw new IllegalArgumentException("Illegal graph name size: " + graph.names.size());
+  public String getGraph() {
+    if (graph.names.size() == 1) {
+      return graph.getSimple();
     }
+    throw new IllegalArgumentException("Illegal graph name size: " + graph.names.size());
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlUseInstance.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlUseInstance.java
index ed9bd6225..5fb719a8c 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlUseInstance.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlUseInstance.java
@@ -19,51 +19,53 @@
 
 package org.apache.geaflow.dsl.sqlnode;
 
-import com.google.common.collect.ImmutableList;
 import java.util.List;
 import java.util.Objects;
+
 import org.apache.calcite.sql.*;
 import org.apache.calcite.sql.parser.SqlParserPos;
 
+import com.google.common.collect.ImmutableList;
+
 public class SqlUseInstance extends SqlAlter {
 
-    private static final SqlOperator OPERATOR = new SqlSpecialOperator("SqlUseInstance",
-        SqlKind.USE_INSTANCE);
+  private static final SqlOperator OPERATOR =
+      new SqlSpecialOperator("SqlUseInstance", SqlKind.USE_INSTANCE);
 
-    private SqlIdentifier instance;
+  private SqlIdentifier instance;
 
-    public SqlUseInstance(SqlParserPos pos, SqlIdentifier instance) {
-        super(pos);
-        this.instance = Objects.requireNonNull(instance);
-    }
+  public SqlUseInstance(SqlParserPos pos, SqlIdentifier instance) {
+    super(pos);
+    this.instance = Objects.requireNonNull(instance);
+  }
 
-    @Override
-    protected void unparseAlterOperation(SqlWriter sqlWriter, int leftPrec, int rightPrec) {
-        sqlWriter.keyword("USE");
-        sqlWriter.keyword("INSTANCE");
-        instance.unparse(sqlWriter, 0, 0);
-    }
+  @Override
+  protected void unparseAlterOperation(SqlWriter sqlWriter, int leftPrec, int rightPrec) {
+    sqlWriter.keyword("USE");
+    sqlWriter.keyword("INSTANCE");
+    instance.unparse(sqlWriter, 0, 0);
+  }
 
-    @Override
-    public SqlOperator getOperator() {
-        return OPERATOR;
-    }
+  @Override
+  public SqlOperator getOperator() {
+    return OPERATOR;
+  }
 
-    @Override
-    public List getOperandList() {
-        return ImmutableList.of(instance);
-    }
+  @Override
+  public List getOperandList() {
+    return ImmutableList.of(instance);
+  }
 
-    @Override
-    public void setOperand(int i, SqlNode operand) {
-        if (i == 0) {
-            this.instance = (SqlIdentifier) operand;
-        } else {
-            throw new IllegalArgumentException("Illegal index: " + i);
-        }
+  @Override
+  public void setOperand(int i, SqlNode operand) {
+    if (i == 0) {
+      this.instance = (SqlIdentifier) operand;
+    } else {
+      throw new IllegalArgumentException("Illegal index: " + i);
     }
+  }
 
-    public SqlIdentifier getInstance() {
-        return instance;
-    }
+  public SqlIdentifier getInstance() {
+    return instance;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlVertex.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlVertex.java
index f921f9301..6250aa706 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlVertex.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlVertex.java
@@ -20,69 +20,69 @@
 package org.apache.geaflow.dsl.sqlnode;
 
 import java.util.List;
+
 import org.apache.calcite.sql.*;
 import org.apache.calcite.sql.parser.SqlParserPos;
 import org.apache.calcite.util.ImmutableNullableList;
 
 public class SqlVertex extends SqlCall {
 
-    private static final SqlOperator OPERATOR = new SqlSpecialOperator("SqlVertex",
-        SqlKind.OTHER_DDL);
-    private SqlIdentifier name;
-    private SqlNodeList columns;
+  private static final SqlOperator OPERATOR =
+      new SqlSpecialOperator("SqlVertex", SqlKind.OTHER_DDL);
+  private SqlIdentifier name;
+  private SqlNodeList columns;
 
-    public SqlVertex(SqlParserPos pos, SqlIdentifier name, SqlNodeList columns) {
-        super(pos);
-        this.name = name;
-        this.columns = columns;
-    }
+  public SqlVertex(SqlParserPos pos, SqlIdentifier name, SqlNodeList columns) {
+    super(pos);
+    this.name = name;
+    this.columns = columns;
+  }
 
-    @Override
-    public void setOperand(int i, SqlNode operand) {
-        switch (i) {
-            case 0:
-                this.name = (SqlIdentifier) operand;
-                break;
-            case 1:
-                this.columns = (SqlNodeList) operand;
-                break;
-            default:
-                throw new IllegalArgumentException("Illegal index: " + i);
-        }
+  @Override
+  public void setOperand(int i, SqlNode operand) {
+    switch (i) {
+      case 0:
+        this.name = (SqlIdentifier) operand;
+        break;
+      case 1:
+        this.columns = (SqlNodeList) operand;
+        break;
+      default:
+        throw new IllegalArgumentException("Illegal index: " + i);
     }
+  }
 
-    @Override
-    public SqlOperator getOperator() {
-        return OPERATOR;
-    }
+  @Override
+  public SqlOperator getOperator() {
+    return OPERATOR;
+  }
 
-    @Override
-    public List getOperandList() {
-        return ImmutableNullableList.of(getName(), getColumns());
-    }
+  @Override
+  public List getOperandList() {
+    return ImmutableNullableList.of(getName(), getColumns());
+  }
 
-    @Override
-    public void unparse(SqlWriter writer, int leftPrec, int rightPrec) {
-        writer.keyword("vertex");
-        name.unparse(writer, 0, 0);
-        writer.print("(");
-        writer.newlineAndIndent();
-        for (int i = 0; i < columns.size(); i++) {
-            if (i > 0) {
-                writer.print(",\n");
-            }
-            columns.get(i).unparse(writer, 0, 0);
-        }
-        writer.newlineAndIndent();
-        writer.print(")");
+  @Override
+  public void unparse(SqlWriter writer, int leftPrec, int rightPrec) {
+    writer.keyword("vertex");
+    name.unparse(writer, 0, 0);
+    writer.print("(");
+    writer.newlineAndIndent();
+    for (int i = 0; i < columns.size(); i++) {
+      if (i > 0) {
+        writer.print(",\n");
+      }
+      columns.get(i).unparse(writer, 0, 0);
     }
+    writer.newlineAndIndent();
+    writer.print(")");
+  }
 
-    public SqlIdentifier getName() {
-        return name;
-    }
-
-    public SqlNodeList getColumns() {
-        return columns;
-    }
+  public SqlIdentifier getName() {
+    return name;
+  }
 
+  public SqlNodeList getColumns() {
+    return columns;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlVertexConstruct.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlVertexConstruct.java
index 5c4192517..47e65bdf4 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlVertexConstruct.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlVertexConstruct.java
@@ -27,26 +27,26 @@
 
 public class SqlVertexConstruct extends AbstractSqlGraphElementConstruct {
 
-    public SqlVertexConstruct(SqlNode[] operands, SqlParserPos pos) {
-        super(new SqlVertexConstructOperator(getKeyNodes(operands)), operands, pos);
-    }
+  public SqlVertexConstruct(SqlNode[] operands, SqlParserPos pos) {
+    super(new SqlVertexConstructOperator(getKeyNodes(operands)), operands, pos);
+  }
 
-    public SqlVertexConstruct(SqlIdentifier[] keyNodes, SqlNode[] valueNodes, SqlParserPos pos) {
-        this(getOperands(keyNodes, valueNodes), pos);
-    }
+  public SqlVertexConstruct(SqlIdentifier[] keyNodes, SqlNode[] valueNodes, SqlParserPos pos) {
+    this(getOperands(keyNodes, valueNodes), pos);
+  }
 
-    @Override
-    public void unparse(SqlWriter writer, int leftPrec, int rightPrec) {
-        writer.keyword("VERTEX");
-        writer.print("{\n");
-        for (int i = 0; i < getKeyNodes().length; i++) {
-            SqlNode key = getKeyNodes()[i];
-            SqlNode value = getValueNodes()[i];
-            key.unparse(writer, 0, 0);
-            writer.print("=");
-            value.unparse(writer, 0, 0);
-            writer.print("\n");
-        }
-        writer.print("}");
+  @Override
+  public void unparse(SqlWriter writer, int leftPrec, int rightPrec) {
+    writer.keyword("VERTEX");
+    writer.print("{\n");
+    for (int i = 0; i < getKeyNodes().length; i++) {
+      SqlNode key = getKeyNodes()[i];
+      SqlNode value = getValueNodes()[i];
+      key.unparse(writer, 0, 0);
+      writer.print("=");
+      value.unparse(writer, 0, 0);
+      writer.print("\n");
     }
+    writer.print("}");
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlVertexUsing.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlVertexUsing.java
index 21512b218..39a0095a0 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlVertexUsing.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/sqlnode/SqlVertexUsing.java
@@ -19,79 +19,79 @@
 
 package org.apache.geaflow.dsl.sqlnode;
 
-import com.google.common.collect.ImmutableList;
 import java.util.List;
+
 import org.apache.calcite.sql.*;
 import org.apache.calcite.sql.parser.SqlParserPos;
 
-public class SqlVertexUsing extends SqlCall {
+import com.google.common.collect.ImmutableList;
 
-    private static final SqlOperator OPERATOR = new SqlSpecialOperator("SqlVertexUsing",
-        SqlKind.OTHER_DDL);
+public class SqlVertexUsing extends SqlCall {
 
-    private SqlIdentifier name;
-    private SqlIdentifier usingTableName;
-    private SqlIdentifier id;
+  private static final SqlOperator OPERATOR =
+      new SqlSpecialOperator("SqlVertexUsing", SqlKind.OTHER_DDL);
 
-    public SqlVertexUsing(SqlParserPos pos, SqlIdentifier name,
-                          SqlIdentifier usingTableName,
-                          SqlIdentifier id) {
-        super(pos);
-        this.name = name;
-        this.usingTableName = usingTableName;
-        this.id = id;
-    }
+  private SqlIdentifier name;
+  private SqlIdentifier usingTableName;
+  private SqlIdentifier id;
 
-    @Override
-    public void setOperand(int i, SqlNode operand) {
-        switch (i) {
-            case 0:
-                this.name = (SqlIdentifier) operand;
-                break;
-            case 1:
-                this.usingTableName = (SqlIdentifier) operand;
-                break;
-            case 2:
-                this.id = (SqlIdentifier) operand;
-                break;
-            default:
-                throw new IllegalArgumentException("Illegal index: " + i);
-        }
-    }
+  public SqlVertexUsing(
+      SqlParserPos pos, SqlIdentifier name, SqlIdentifier usingTableName, SqlIdentifier id) {
+    super(pos);
+    this.name = name;
+    this.usingTableName = usingTableName;
+    this.id = id;
+  }
 
-    @Override
-    public SqlOperator getOperator() {
-        return OPERATOR;
+  @Override
+  public void setOperand(int i, SqlNode operand) {
+    switch (i) {
+      case 0:
+        this.name = (SqlIdentifier) operand;
+        break;
+      case 1:
+        this.usingTableName = (SqlIdentifier) operand;
+        break;
+      case 2:
+        this.id = (SqlIdentifier) operand;
+        break;
+      default:
+        throw new IllegalArgumentException("Illegal index: " + i);
     }
+  }
 
-    @Override
-    public List getOperandList() {
-        return ImmutableList.of(getName(), getUsingTableName(), getId());
-    }
+  @Override
+  public SqlOperator getOperator() {
+    return OPERATOR;
+  }
 
-    @Override
-    public void unparse(SqlWriter writer, int leftPrec, int rightPrec) {
-        writer.keyword("VERTEX");
-        name.unparse(writer, 0, 0);
-        writer.keyword("USING");
-        usingTableName.unparse(writer, 0, 0);
-        writer.keyword("WITH");
-        writer.keyword("ID");
-        writer.print("(");
-        id.unparse(writer, 0, 0);
-        writer.print(")");
+  @Override
+  public List getOperandList() {
+    return ImmutableList.of(getName(), getUsingTableName(), getId());
+  }
 
-    }
+  @Override
+  public void unparse(SqlWriter writer, int leftPrec, int rightPrec) {
+    writer.keyword("VERTEX");
+    name.unparse(writer, 0, 0);
+    writer.keyword("USING");
+    usingTableName.unparse(writer, 0, 0);
+    writer.keyword("WITH");
+    writer.keyword("ID");
+    writer.print("(");
+    id.unparse(writer, 0, 0);
+    writer.print(")");
+  }
 
-    public SqlIdentifier getName() {
-        return name;
-    }
+  public SqlIdentifier getName() {
+    return name;
+  }
 
-    public SqlIdentifier getUsingTableName() {
-        return usingTableName;
-    }
+  public SqlIdentifier getUsingTableName() {
+    return usingTableName;
+  }
 
-    public SqlIdentifier getId() {
-        return id;
-    }
+  public SqlIdentifier getId() {
+    return id;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/EdgeDirection.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/EdgeDirection.java
index 45eed921d..896f87eb7 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/EdgeDirection.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/EdgeDirection.java
@@ -20,16 +20,16 @@
 package org.apache.geaflow.dsl.util;
 
 public enum EdgeDirection {
-    OUT,
-    IN,
-    BOTH;
+  OUT,
+  IN,
+  BOTH;
 
-    public static EdgeDirection of(String value) {
-        for (EdgeDirection direction : EdgeDirection.values()) {
-            if (direction.name().equalsIgnoreCase(value)) {
-                return direction;
-            }
-        }
-        throw new IllegalArgumentException("Illegal direction value: " + value);
+  public static EdgeDirection of(String value) {
+    for (EdgeDirection direction : EdgeDirection.values()) {
+      if (direction.name().equalsIgnoreCase(value)) {
+        return direction;
+      }
     }
+    throw new IllegalArgumentException("Illegal direction value: " + value);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/GQLEdgeConstraint.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/GQLEdgeConstraint.java
index aa3319b23..89bc79210 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/GQLEdgeConstraint.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/GQLEdgeConstraint.java
@@ -22,82 +22,82 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
+
 import org.apache.calcite.sql.*;
 import org.apache.calcite.sql.parser.SqlParserPos;
 import org.apache.calcite.util.ImmutableNullableList;
 
 public class GQLEdgeConstraint extends SqlCall {
 
-    private static final SqlOperator OPERATOR = new SqlSpecialOperator("GQLEdgeConstraint",
-        SqlKind.OTHER_DDL);
+  private static final SqlOperator OPERATOR =
+      new SqlSpecialOperator("GQLEdgeConstraint", SqlKind.OTHER_DDL);
 
-    private SqlNodeList sourceVertexType;
-    private SqlNodeList targetVertexType;
+  private SqlNodeList sourceVertexType;
+  private SqlNodeList targetVertexType;
 
-    public GQLEdgeConstraint(SqlNodeList sourceVertexType,
-                             SqlNodeList targetVertexType, SqlParserPos pos) {
-        super(pos);
-        this.sourceVertexType = Objects.requireNonNull(sourceVertexType);
-        this.targetVertexType = Objects.requireNonNull(targetVertexType);
-    }
+  public GQLEdgeConstraint(
+      SqlNodeList sourceVertexType, SqlNodeList targetVertexType, SqlParserPos pos) {
+    super(pos);
+    this.sourceVertexType = Objects.requireNonNull(sourceVertexType);
+    this.targetVertexType = Objects.requireNonNull(targetVertexType);
+  }
 
+  @Override
+  public SqlOperator getOperator() {
+    return OPERATOR;
+  }
 
-    @Override
-    public SqlOperator getOperator() {
-        return OPERATOR;
-    }
+  @Override
+  public List getOperandList() {
+    return ImmutableNullableList.of(sourceVertexType, targetVertexType);
+  }
 
-    @Override
-    public List getOperandList() {
-        return ImmutableNullableList.of(sourceVertexType, targetVertexType);
+  @Override
+  public void setOperand(int i, SqlNode operand) {
+    switch (i) {
+      case 0:
+        this.sourceVertexType = (SqlNodeList) operand;
+        break;
+      case 1:
+        this.targetVertexType = (SqlNodeList) operand;
+        break;
+      default:
+        throw new IndexOutOfBoundsException("current index " + i + " out of range " + 2);
     }
+  }
 
-    @Override
-    public void setOperand(int i, SqlNode operand) {
-        switch (i) {
-            case 0:
-                this.sourceVertexType = (SqlNodeList) operand;
-                break;
-            case 1:
-                this.targetVertexType = (SqlNodeList) operand;
-                break;
-            default:
-                throw new IndexOutOfBoundsException("current index " + i + " out of range " + 2);
-        }
+  @Override
+  public void unparse(SqlWriter writer, int leftPrec, int rightPrec) {
+    for (int i = 0; i < sourceVertexType.size(); i++) {
+      if (i > 0) {
+        writer.print("|");
+      }
+      sourceVertexType.get(i).unparse(writer, 0, 0);
     }
-
-    @Override
-    public void unparse(SqlWriter writer, int leftPrec, int rightPrec) {
-        for (int i = 0; i < sourceVertexType.size(); i++) {
-            if (i > 0) {
-                writer.print("|");
-            }
-            sourceVertexType.get(i).unparse(writer, 0, 0);
-        }
-        writer.print("->");
-        for (int i = 0; i < targetVertexType.size(); i++) {
-            if (i > 0) {
-                writer.print("|");
-            }
-            targetVertexType.get(i).unparse(writer, 0, 0);
-        }
+    writer.print("->");
+    for (int i = 0; i < targetVertexType.size(); i++) {
+      if (i > 0) {
+        writer.print("|");
+      }
+      targetVertexType.get(i).unparse(writer, 0, 0);
     }
+  }
 
-    public List getSourceVertexTypes() {
-        List sourceTypes = new ArrayList<>();
-        for (SqlNode node : sourceVertexType) {
-            assert node instanceof SqlIdentifier;
-            sourceTypes.add(((SqlIdentifier) node).getSimple());
-        }
-        return sourceTypes;
+  public List getSourceVertexTypes() {
+    List sourceTypes = new ArrayList<>();
+    for (SqlNode node : sourceVertexType) {
+      assert node instanceof SqlIdentifier;
+      sourceTypes.add(((SqlIdentifier) node).getSimple());
     }
+    return sourceTypes;
+  }
 
-    public List getTargetVertexTypes() {
-        List targetTypes = new ArrayList<>();
-        for (SqlNode node : targetVertexType) {
-            assert node instanceof SqlIdentifier;
-            targetTypes.add(((SqlIdentifier) node).getSimple());
-        }
-        return targetTypes;
+  public List getTargetVertexTypes() {
+    List targetTypes = new ArrayList<>();
+    for (SqlNode node : targetVertexType) {
+      assert node instanceof SqlIdentifier;
+      targetTypes.add(((SqlIdentifier) node).getSimple());
     }
+    return targetTypes;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/GQLReturnKeyword.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/GQLReturnKeyword.java
index 5ef61553d..3008ba44a 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/GQLReturnKeyword.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/GQLReturnKeyword.java
@@ -23,11 +23,10 @@
 import org.apache.calcite.sql.parser.SqlParserPos;
 
 public enum GQLReturnKeyword {
-    DISTINCT,
-    ALL;
+  DISTINCT,
+  ALL;
 
-    public SqlLiteral symbol(SqlParserPos pos) {
-        return SqlLiteral.createSymbol(this, pos);
-    }
+  public SqlLiteral symbol(SqlParserPos pos) {
+    return SqlLiteral.createSymbol(this, pos);
+  }
 }
-
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/SqlNodeUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/SqlNodeUtil.java
index 7ca38f434..b3532d0f8 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/SqlNodeUtil.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/SqlNodeUtil.java
@@ -25,185 +25,188 @@
 import java.util.List;
 import java.util.Set;
 import java.util.stream.Collectors;
+
 import org.apache.calcite.sql.*;
 import org.apache.calcite.sql.util.SqlVisitor;
 
 public class SqlNodeUtil {
 
-    public static void unparseNodeList(SqlWriter writer,
-                                       SqlNodeList nodeList, String sep) {
-        for (int i = 0; i < nodeList.size(); i++) {
-            if (i > 0) {
-                writer.print(sep);
-                writer.newlineAndIndent();
-            }
-            nodeList.get(i).unparse(writer, 0, 0);
-        }
+  public static void unparseNodeList(SqlWriter writer, SqlNodeList nodeList, String sep) {
+    for (int i = 0; i < nodeList.size(); i++) {
+      if (i > 0) {
+        writer.print(sep);
+        writer.newlineAndIndent();
+      }
+      nodeList.get(i).unparse(writer, 0, 0);
     }
-
-    public static List findUnresolvedFunctions(SqlNode sqlNode) {
-        return sqlNode.accept(new SqlVisitor>() {
-            @Override
-            public List visit(SqlLiteral literal) {
-                return Collections.emptyList();
-            }
-
-            @Override
-            public List visit(SqlCall call) {
-                List functions = new ArrayList<>();
-                if (call.getOperator() instanceof SqlUnresolvedFunction) {
-                    functions.add((SqlUnresolvedFunction) call.getOperator());
-                }
-                functions.addAll(visitNodes(call.getOperandList()));
-                return functions;
-            }
-
-            @Override
-            public List visit(SqlNodeList nodeList) {
-                return visitNodes(nodeList.getList());
-            }
-
-            private List visitNodes(List nodes) {
-                if (nodes == null) {
-                    return Collections.emptyList();
-                }
-                return nodes.stream()
-                    .flatMap(node -> {
-                        if (node != null) {
-                            return node.accept(this).stream();
-                        } else {
-                            return Collections.emptyList().stream();
-                        }
+  }
+
+  public static List findUnresolvedFunctions(SqlNode sqlNode) {
+    return sqlNode.accept(
+        new SqlVisitor>() {
+          @Override
+          public List visit(SqlLiteral literal) {
+            return Collections.emptyList();
+          }
+
+          @Override
+          public List visit(SqlCall call) {
+            List functions = new ArrayList<>();
+            if (call.getOperator() instanceof SqlUnresolvedFunction) {
+              functions.add((SqlUnresolvedFunction) call.getOperator());
+            }
+            functions.addAll(visitNodes(call.getOperandList()));
+            return functions;
+          }
+
+          @Override
+          public List visit(SqlNodeList nodeList) {
+            return visitNodes(nodeList.getList());
+          }
+
+          private List visitNodes(List nodes) {
+            if (nodes == null) {
+              return Collections.emptyList();
+            }
+            return nodes.stream()
+                .flatMap(
+                    node -> {
+                      if (node != null) {
+                        return node.accept(this).stream();
+                      } else {
+                        return Collections.emptyList().stream();
+                      }
                     })
-                    .collect(Collectors.toList());
-            }
-
-            @Override
-            public List visit(SqlIdentifier id) {
-                return Collections.emptyList();
-            }
-
-            @Override
-            public List visit(SqlDataTypeSpec type) {
-                return Collections.emptyList();
-            }
-
-            @Override
-            public List visit(SqlDynamicParam param) {
-                return Collections.emptyList();
-            }
-
-            @Override
-            public List visit(SqlIntervalQualifier intervalQualifier) {
-                return Collections.emptyList();
-            }
+                .collect(Collectors.toList());
+          }
+
+          @Override
+          public List visit(SqlIdentifier id) {
+            return Collections.emptyList();
+          }
+
+          @Override
+          public List visit(SqlDataTypeSpec type) {
+            return Collections.emptyList();
+          }
+
+          @Override
+          public List visit(SqlDynamicParam param) {
+            return Collections.emptyList();
+          }
+
+          @Override
+          public List visit(SqlIntervalQualifier intervalQualifier) {
+            return Collections.emptyList();
+          }
         });
-    }
-
-    public static Set findUsedTables(SqlNode sqlNode) {
-        return sqlNode.accept(new SqlVisitor>() {
-            @Override
-            public Set visit(SqlLiteral literal) {
-                return Collections.emptySet();
-            }
-
-            @Override
-            public Set visit(SqlCall call) {
-                Set allTables = new HashSet<>();
-                if (call instanceof SqlInsert) {
-                    SqlInsert sqlInsert = (SqlInsert) call;
-                    SqlNode source = sqlInsert.getSource();
-                    Set sourceTables = source.accept(this);
-                    SqlNode target = sqlInsert.getTargetTable();
-                    Set targetTables = target.accept(this);
-                    allTables.addAll(sourceTables);
-                    allTables.addAll(targetTables);
-                } else if (call instanceof SqlSelect) {
-                    SqlSelect sqlSelect = (SqlSelect) call;
-                    SqlNode from = sqlSelect.getFrom();
-                    if (from != null) {
-                        Set tables = from.accept(this);
-                        allTables.addAll(tables);
-                    }
-                } else if (call instanceof SqlJoin) {
-                    SqlJoin sqlJoin = (SqlJoin) call;
-                    String left = sqlJoin.getLeft().toString();
-                    String right = sqlJoin.getRight().toString();
-                    allTables.add(left);
-                    allTables.add(right);
-                } else if (call instanceof SqlWith) {
-                    SqlWith sqlWith = (SqlWith) call;
-                    SqlNodeList withList = sqlWith.withList;
-                    if (withList != null) {
-                        Set tables = withList.accept(this);
-                        allTables.addAll(tables);
-                    }
-
-                } else if (call instanceof SqlWithItem) {
-                    SqlWithItem withItem = (SqlWithItem) call;
-                    allTables.add(withItem.name.names.get(0));
-                    SqlNode query = withItem.query;
-                    if (query != null) {
-                        Set tables = query.accept(this);
-                        allTables.addAll(tables);
-                    }
-
-                } else if (call instanceof SqlBasicCall) {
-                    SqlBasicCall basicCall = (SqlBasicCall) call;
-                    SqlNode[] operands = basicCall.getOperands();
-                    if (operands.length > 0) {
-                        Set tables = operands[0].accept(this);
-                        allTables.addAll(tables);
-                    }
-                }
-                return allTables;
-            }
-
-            @Override
-            public Set visit(SqlNodeList nodeList) {
-                return visitNodes(nodeList.getList());
-            }
-
-            private Set visitNodes(List nodes) {
-                if (nodes == null) {
-                    return Collections.emptySet();
-                }
-                return nodes.stream()
-                    .flatMap(node -> {
-                        if (node != null) {
-                            return node.accept(this).stream();
-                        } else {
-                            return Collections.emptySet().stream();
-                        }
+  }
+
+  public static Set findUsedTables(SqlNode sqlNode) {
+    return sqlNode.accept(
+        new SqlVisitor>() {
+          @Override
+          public Set visit(SqlLiteral literal) {
+            return Collections.emptySet();
+          }
+
+          @Override
+          public Set visit(SqlCall call) {
+            Set allTables = new HashSet<>();
+            if (call instanceof SqlInsert) {
+              SqlInsert sqlInsert = (SqlInsert) call;
+              SqlNode source = sqlInsert.getSource();
+              Set sourceTables = source.accept(this);
+              SqlNode target = sqlInsert.getTargetTable();
+              Set targetTables = target.accept(this);
+              allTables.addAll(sourceTables);
+              allTables.addAll(targetTables);
+            } else if (call instanceof SqlSelect) {
+              SqlSelect sqlSelect = (SqlSelect) call;
+              SqlNode from = sqlSelect.getFrom();
+              if (from != null) {
+                Set tables = from.accept(this);
+                allTables.addAll(tables);
+              }
+            } else if (call instanceof SqlJoin) {
+              SqlJoin sqlJoin = (SqlJoin) call;
+              String left = sqlJoin.getLeft().toString();
+              String right = sqlJoin.getRight().toString();
+              allTables.add(left);
+              allTables.add(right);
+            } else if (call instanceof SqlWith) {
+              SqlWith sqlWith = (SqlWith) call;
+              SqlNodeList withList = sqlWith.withList;
+              if (withList != null) {
+                Set tables = withList.accept(this);
+                allTables.addAll(tables);
+              }
+
+            } else if (call instanceof SqlWithItem) {
+              SqlWithItem withItem = (SqlWithItem) call;
+              allTables.add(withItem.name.names.get(0));
+              SqlNode query = withItem.query;
+              if (query != null) {
+                Set tables = query.accept(this);
+                allTables.addAll(tables);
+              }
+
+            } else if (call instanceof SqlBasicCall) {
+              SqlBasicCall basicCall = (SqlBasicCall) call;
+              SqlNode[] operands = basicCall.getOperands();
+              if (operands.length > 0) {
+                Set tables = operands[0].accept(this);
+                allTables.addAll(tables);
+              }
+            }
+            return allTables;
+          }
+
+          @Override
+          public Set visit(SqlNodeList nodeList) {
+            return visitNodes(nodeList.getList());
+          }
+
+          private Set visitNodes(List nodes) {
+            if (nodes == null) {
+              return Collections.emptySet();
+            }
+            return nodes.stream()
+                .flatMap(
+                    node -> {
+                      if (node != null) {
+                        return node.accept(this).stream();
+                      } else {
+                        return Collections.emptySet().stream();
+                      }
                     })
-                    .collect(Collectors.toSet());
-            }
-
-            @Override
-            public Set visit(SqlIdentifier id) {
-                if (!id.names.isEmpty()) {
-                    String name = id.names.get(0);
-                    return Collections.singleton(name);
-                } else {
-                    return Collections.emptySet();
-                }
-
-            }
-
-            @Override
-            public Set visit(SqlDataTypeSpec type) {
-                return Collections.emptySet();
-            }
-
-            @Override
-            public Set visit(SqlDynamicParam param) {
-                return Collections.emptySet();
-            }
-
-            @Override
-            public Set visit(SqlIntervalQualifier intervalQualifier) {
-                return Collections.emptySet();
-            }
+                .collect(Collectors.toSet());
+          }
+
+          @Override
+          public Set visit(SqlIdentifier id) {
+            if (!id.names.isEmpty()) {
+              String name = id.names.get(0);
+              return Collections.singleton(name);
+            } else {
+              return Collections.emptySet();
+            }
+          }
+
+          @Override
+          public Set visit(SqlDataTypeSpec type) {
+            return Collections.emptySet();
+          }
+
+          @Override
+          public Set visit(SqlDynamicParam param) {
+            return Collections.emptySet();
+          }
+
+          @Override
+          public Set visit(SqlIntervalQualifier intervalQualifier) {
+            return Collections.emptySet();
+          }
         });
-    }
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/SqlTypeUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/SqlTypeUtil.java
index dc66526a1..0cf6f630c 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/SqlTypeUtil.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/SqlTypeUtil.java
@@ -23,6 +23,7 @@
 import java.util.List;
 import java.util.Locale;
 import java.util.stream.Collectors;
+
 import org.apache.calcite.adapter.java.JavaTypeFactory;
 import org.apache.calcite.rel.type.*;
 import org.apache.calcite.sql.SqlDataTypeSpec;
@@ -45,162 +46,171 @@
 
 public final class SqlTypeUtil {
 
-    public static IType convertType(SqlDataTypeSpec typeSpec) {
-        String typeName = typeSpec.getTypeName().getSimple().toUpperCase();
-        typeName = convertTypeName(typeName);
-        return Types.of(typeName, typeSpec.getPrecision());
-    }
+  public static IType convertType(SqlDataTypeSpec typeSpec) {
+    String typeName = typeSpec.getTypeName().getSimple().toUpperCase();
+    typeName = convertTypeName(typeName);
+    return Types.of(typeName, typeSpec.getPrecision());
+  }
 
-    public static IType convertType(RelDataType type) {
-        SqlTypeName sqlTypeName = type.getSqlTypeName();
-        switch (sqlTypeName) {
-            case ARRAY:
-                RelDataType componentType = type.getComponentType();
-                return new ArrayType(convertType(componentType));
-            case STRUCTURED:
-            case ROW:
-                List fields = toTableFields(type.getFieldList());
-                return new StructType(fields);
-            case VERTEX:
-                VertexRecordType vertexType = (VertexRecordType) type;
-                List vertexFields = toTableFields(vertexType.getFieldList());
-                return new VertexType(vertexFields);
-            case EDGE:
-                EdgeRecordType edgeType = (EdgeRecordType) type;
-                List edgeFields = toTableFields(edgeType.getFieldList());
-                return new EdgeType(edgeFields, edgeType.getTimestampField().isPresent());
-            case PATH:
-                PathRecordType pathType = (PathRecordType) type;
-                List pathFields = toTableFields(pathType.getFieldList());
-                return new PathType(pathFields);
-            case GRAPH:
-                GraphRecordType graphType = (GraphRecordType) type;
-                List graphFields = toTableFields(graphType.getFieldList());
-                return new GraphSchema(graphType.getGraphName(), graphFields);
-            default:
-                return ofTypeName(sqlTypeName);
-        }
+  public static IType convertType(RelDataType type) {
+    SqlTypeName sqlTypeName = type.getSqlTypeName();
+    switch (sqlTypeName) {
+      case ARRAY:
+        RelDataType componentType = type.getComponentType();
+        return new ArrayType(convertType(componentType));
+      case STRUCTURED:
+      case ROW:
+        List fields = toTableFields(type.getFieldList());
+        return new StructType(fields);
+      case VERTEX:
+        VertexRecordType vertexType = (VertexRecordType) type;
+        List vertexFields = toTableFields(vertexType.getFieldList());
+        return new VertexType(vertexFields);
+      case EDGE:
+        EdgeRecordType edgeType = (EdgeRecordType) type;
+        List edgeFields = toTableFields(edgeType.getFieldList());
+        return new EdgeType(edgeFields, edgeType.getTimestampField().isPresent());
+      case PATH:
+        PathRecordType pathType = (PathRecordType) type;
+        List pathFields = toTableFields(pathType.getFieldList());
+        return new PathType(pathFields);
+      case GRAPH:
+        GraphRecordType graphType = (GraphRecordType) type;
+        List graphFields = toTableFields(graphType.getFieldList());
+        return new GraphSchema(graphType.getGraphName(), graphFields);
+      default:
+        return ofTypeName(sqlTypeName);
     }
+  }
 
-    private static List toTableFields(List fields) {
-        return fields.stream().map(field ->
-                new TableField(field.getName(), convertType(field.getType()),
-                    field.getType().isNullable()))
-            .collect(Collectors.toList());
-    }
+  private static List toTableFields(List fields) {
+    return fields.stream()
+        .map(
+            field ->
+                new TableField(
+                    field.getName(), convertType(field.getType()), field.getType().isNullable()))
+        .collect(Collectors.toList());
+  }
 
-    public static IType ofTypeName(SqlTypeName sqlTypeName) {
-        String typeName = convertTypeName(sqlTypeName.getName());
-        return Types.of(typeName, sqlTypeName.getPrecision());
-    }
+  public static IType ofTypeName(SqlTypeName sqlTypeName) {
+    String typeName = convertTypeName(sqlTypeName.getName());
+    return Types.of(typeName, sqlTypeName.getPrecision());
+  }
 
-    public static RelDataType convertToRelType(IType type, boolean isNullable,
-                                               RelDataTypeFactory typeFactory) {
-        switch (type.getName()) {
-            case Types.TYPE_NAME_ARRAY:
-                ArrayType arrayType = (ArrayType) type;
-                RelDataType componentType = convertToRelType(arrayType.getComponentType(), isNullable, typeFactory);
-                return typeFactory.createTypeWithNullability(
-                    typeFactory.createArrayType(componentType, -1), true);
-            case Types.TYPE_NAME_STRUCT:
-                StructType structType = (StructType) type;
-                List fields = toRecordFields(structType.getFields(), typeFactory);
-                return new RelRecordType(StructKind.PEEK_FIELDS, fields);
-            case Types.TYPE_NAME_VERTEX:
-                VertexType vertexType = (VertexType) type;
-                List vertexFields = toRecordFields(vertexType.getFields(), typeFactory);
-                return VertexRecordType.createVertexType(vertexFields,
-                    vertexType.getId().getName(), typeFactory);
-            case Types.TYPE_NAME_EDGE:
-                EdgeType edgeType = (EdgeType) type;
-                List edgeFields = toRecordFields(edgeType.getFields(), typeFactory);
-                return EdgeRecordType.createEdgeType(edgeFields,
-                    edgeType.getSrcId().getName(),
-                    edgeType.getTargetId().getName(),
-                    edgeType.getTimestamp().map(TableField::getName).orElse(null),
-                    typeFactory);
-            case Types.TYPE_NAME_PATH:
-                PathType pathType = (PathType) type;
-                List pathFields = toRecordFields(pathType.getFields(), typeFactory);
-                return new PathRecordType(pathFields);
-            case Types.TYPE_NAME_GRAPH:
-                GraphSchema graphSchema = (GraphSchema) type;
-                List recordFields = toRecordFields(graphSchema.getFields(), typeFactory);
-                return new GraphRecordType(graphSchema.getGraphName(), recordFields);
-            default:
-                if (type.isPrimitive()) {
-                    String sqlTypeName = convertToSqlTypeName(type);
-                    SqlTypeName typeName = Types.getType(type.getTypeClass()) == Types.BINARY_STRING
-                        ? SqlTypeName.get(sqlTypeName, ((BinaryStringType) type).getPrecision())
-                        : SqlTypeName.get(sqlTypeName);
-                    return typeFactory.createTypeWithNullability(typeFactory.createSqlType(typeName), isNullable);
-                } else {
-                    throw new GeaFlowDSLException("Not support type: " + type);
-                }
+  public static RelDataType convertToRelType(
+      IType type, boolean isNullable, RelDataTypeFactory typeFactory) {
+    switch (type.getName()) {
+      case Types.TYPE_NAME_ARRAY:
+        ArrayType arrayType = (ArrayType) type;
+        RelDataType componentType =
+            convertToRelType(arrayType.getComponentType(), isNullable, typeFactory);
+        return typeFactory.createTypeWithNullability(
+            typeFactory.createArrayType(componentType, -1), true);
+      case Types.TYPE_NAME_STRUCT:
+        StructType structType = (StructType) type;
+        List fields = toRecordFields(structType.getFields(), typeFactory);
+        return new RelRecordType(StructKind.PEEK_FIELDS, fields);
+      case Types.TYPE_NAME_VERTEX:
+        VertexType vertexType = (VertexType) type;
+        List vertexFields = toRecordFields(vertexType.getFields(), typeFactory);
+        return VertexRecordType.createVertexType(
+            vertexFields, vertexType.getId().getName(), typeFactory);
+      case Types.TYPE_NAME_EDGE:
+        EdgeType edgeType = (EdgeType) type;
+        List edgeFields = toRecordFields(edgeType.getFields(), typeFactory);
+        return EdgeRecordType.createEdgeType(
+            edgeFields,
+            edgeType.getSrcId().getName(),
+            edgeType.getTargetId().getName(),
+            edgeType.getTimestamp().map(TableField::getName).orElse(null),
+            typeFactory);
+      case Types.TYPE_NAME_PATH:
+        PathType pathType = (PathType) type;
+        List pathFields = toRecordFields(pathType.getFields(), typeFactory);
+        return new PathRecordType(pathFields);
+      case Types.TYPE_NAME_GRAPH:
+        GraphSchema graphSchema = (GraphSchema) type;
+        List recordFields = toRecordFields(graphSchema.getFields(), typeFactory);
+        return new GraphRecordType(graphSchema.getGraphName(), recordFields);
+      default:
+        if (type.isPrimitive()) {
+          String sqlTypeName = convertToSqlTypeName(type);
+          SqlTypeName typeName =
+              Types.getType(type.getTypeClass()) == Types.BINARY_STRING
+                  ? SqlTypeName.get(sqlTypeName, ((BinaryStringType) type).getPrecision())
+                  : SqlTypeName.get(sqlTypeName);
+          return typeFactory.createTypeWithNullability(
+              typeFactory.createSqlType(typeName), isNullable);
+        } else {
+          throw new GeaFlowDSLException("Not support type: " + type);
         }
     }
+  }
 
-    private static List toRecordFields(List tableFields,
-                                                         RelDataTypeFactory typeFactory) {
-        List recordFields = new ArrayList<>(tableFields.size());
-        for (int i = 0; i < tableFields.size(); i++) {
-            TableField tableField = tableFields.get(i);
-            recordFields.add(new RelDataTypeFieldImpl(tableField.getName(), i,
-                convertToRelType(tableField.getType(), tableField.isNullable(), typeFactory)));
-        }
-        return recordFields;
+  private static List toRecordFields(
+      List tableFields, RelDataTypeFactory typeFactory) {
+    List recordFields = new ArrayList<>(tableFields.size());
+    for (int i = 0; i < tableFields.size(); i++) {
+      TableField tableField = tableFields.get(i);
+      recordFields.add(
+          new RelDataTypeFieldImpl(
+              tableField.getName(),
+              i,
+              convertToRelType(tableField.getType(), tableField.isNullable(), typeFactory)));
     }
+    return recordFields;
+  }
 
-    public static String convertTypeName(String sqlTypeName) {
-        String upperName = sqlTypeName.toUpperCase(Locale.ROOT);
-        if (upperName.equals(SqlTypeName.VARCHAR.getName())
-            || upperName.equals(SqlTypeName.CHAR.getName())
-            || upperName.equals(Types.TYPE_NAME_STRING)) {
-            return Types.TYPE_NAME_BINARY_STRING;
-        }
-        if (upperName.equals(SqlTypeName.BIGINT.getName())) {
-            return Types.TYPE_NAME_LONG;
-        }
-        if (upperName.equals(SqlTypeName.DECIMAL.getName())) {
-            return Types.TYPE_NAME_DECIMAL;
-        }
-        if (upperName.startsWith("CHAR(") || upperName.startsWith("VARCHAR(")) {
-            return Types.TYPE_NAME_BINARY_STRING;
-        }
-        if (upperName.equals("INT") || upperName.equals("SYMBOL")) {
-            return Types.TYPE_NAME_INTEGER;
-        }
-        return upperName;
+  public static String convertTypeName(String sqlTypeName) {
+    String upperName = sqlTypeName.toUpperCase(Locale.ROOT);
+    if (upperName.equals(SqlTypeName.VARCHAR.getName())
+        || upperName.equals(SqlTypeName.CHAR.getName())
+        || upperName.equals(Types.TYPE_NAME_STRING)) {
+      return Types.TYPE_NAME_BINARY_STRING;
+    }
+    if (upperName.equals(SqlTypeName.BIGINT.getName())) {
+      return Types.TYPE_NAME_LONG;
+    }
+    if (upperName.equals(SqlTypeName.DECIMAL.getName())) {
+      return Types.TYPE_NAME_DECIMAL;
+    }
+    if (upperName.startsWith("CHAR(") || upperName.startsWith("VARCHAR(")) {
+      return Types.TYPE_NAME_BINARY_STRING;
     }
+    if (upperName.equals("INT") || upperName.equals("SYMBOL")) {
+      return Types.TYPE_NAME_INTEGER;
+    }
+    return upperName;
+  }
 
-    private static String convertToSqlTypeName(IType type) {
-        switch (type.getName()) {
-            case Types.TYPE_NAME_STRING:
-            case Types.TYPE_NAME_BINARY_STRING:
-                return SqlTypeName.VARCHAR.getName();
-            case Types.TYPE_NAME_LONG:
-                return SqlTypeName.BIGINT.getName();
-            default:
-                return type.getName().toUpperCase();
-        }
+  private static String convertToSqlTypeName(IType type) {
+    switch (type.getName()) {
+      case Types.TYPE_NAME_STRING:
+      case Types.TYPE_NAME_BINARY_STRING:
+        return SqlTypeName.VARCHAR.getName();
+      case Types.TYPE_NAME_LONG:
+        return SqlTypeName.BIGINT.getName();
+      default:
+        return type.getName().toUpperCase();
     }
+  }
 
-    public static List> convertToJavaTypes(List types,
-                                                    JavaTypeFactory typeFactory) {
-        List> javaTypes = new ArrayList<>();
-        for (RelDataType type : types) {
-            javaTypes.add((Class) typeFactory.getJavaClass(type));
-        }
-        return javaTypes;
+  public static List> convertToJavaTypes(
+      List types, JavaTypeFactory typeFactory) {
+    List> javaTypes = new ArrayList<>();
+    for (RelDataType type : types) {
+      javaTypes.add((Class) typeFactory.getJavaClass(type));
     }
+    return javaTypes;
+  }
 
-    public static List> convertToJavaTypes(RelDataType rowType,
-                                                    JavaTypeFactory typeFactory) {
-        List fields = rowType.getFieldList();
-        List types = new ArrayList<>();
-        for (RelDataTypeField field : fields) {
-            types.add(field.getType());
-        }
-        return convertToJavaTypes(types, typeFactory);
+  public static List> convertToJavaTypes(
+      RelDataType rowType, JavaTypeFactory typeFactory) {
+    List fields = rowType.getFieldList();
+    List types = new ArrayList<>();
+    for (RelDataTypeField field : fields) {
+      types.add(field.getType());
     }
+    return convertToJavaTypes(types, typeFactory);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/StringLiteralUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/StringLiteralUtil.java
index 9f564758e..e49504077 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/StringLiteralUtil.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/main/java/org/apache/geaflow/dsl/util/StringLiteralUtil.java
@@ -24,240 +24,233 @@
 
 public class StringLiteralUtil {
 
-    // Unicode escape sequence hex weights: 16^3, 16^2, 16^1, 16^0 = 4096, 256, 16, 1
-    private static final int[] UNICODE_HEX_MULTIPLIER = new int[]{4096, 256, 16, 1};
+  // Unicode escape sequence hex weights: 16^3, 16^2, 16^1, 16^0 = 4096, 256, 16, 1
+  private static final int[] UNICODE_HEX_MULTIPLIER = new int[] {4096, 256, 16, 1};
 
-    public static String unescapeSQLString(String b) {
+  public static String unescapeSQLString(String b) {
 
-        Character enclosure = null;
+    Character enclosure = null;
 
-        // Some of the strings can be passed in as unicode. For example, the
-        // delimiter can be passed in as \002 - So, we first check if the
-        // string is a unicode number, else go back to the old behavior
-        StringBuilder sb = new StringBuilder(b.length());
-        for (int i = 0; i < b.length(); i++) {
-
-            char currentChar = b.charAt(i);
-            if (enclosure == null) {
-                if (currentChar == '\'' || currentChar == '\"') {
-                    enclosure = currentChar;
-                }
-                // ignore all other chars outside the enclosure
-                continue;
-            }
-
-            if (enclosure.equals(currentChar)) {
-                enclosure = null;
-                continue;
-            }
+    // Some of the strings can be passed in as unicode. For example, the
+    // delimiter can be passed in as \002 - So, we first check if the
+    // string is a unicode number, else go back to the old behavior
+    StringBuilder sb = new StringBuilder(b.length());
+    for (int i = 0; i < b.length(); i++) {
 
+      char currentChar = b.charAt(i);
+      if (enclosure == null) {
+        if (currentChar == '\'' || currentChar == '\"') {
+          enclosure = currentChar;
+        }
+        // ignore all other chars outside the enclosure
+        continue;
+      }
 
-            // Process Unicode escape sequence (backslash-u followed by 4 hex digits)
-            // Need at least 6 characters: backslash-u + 4 hex digits
-            if (currentChar == '\\' && (i + 6 <= b.length()) && b.charAt(i + 1) == 'u') {
-                int code = 0;
-                int base = i + 2;
-                boolean validHex = true;
-                
-                // Parse 4 hexadecimal digits with correct weights (16^3, 16^2, 16^1, 16^0)
-                for (int j = 0; j < 4; j++) {
-                    int digit = Character.digit(b.charAt(j + base), 16);
-                    if (digit < 0) {
-                        // Invalid hex character encountered
-                        validHex = false;
-                        break;
-                    }
-                    code += digit * UNICODE_HEX_MULTIPLIER[j];
-                }
-                
-                if (validHex) {
-                    sb.append((char) code);
-                    i += 5; // Skip backslash-u-XXXX (5 characters total)
-                    continue;
-                }
-                // If invalid hex, fall through to handle as regular backslash escape
-            }
-            
-            if (currentChar == '\\') { // process case for '\001'
-                int code = 0;
-                int base = i + 1;
-                int j;
+      if (enclosure.equals(currentChar)) {
+        enclosure = null;
+        continue;
+      }
 
-                for (j = 0; j < 3; j++) {
-                    char c = b.charAt(j + base);
-                    if (c >= '0' && c <= '9') {
-                        if (code * 10 + (c - '0') < 128) {
-                            code = code * 10 + (c - '0');
-                        } else {
-                            break;
-                        }
-                    } else {
-                        break;
-                    }
-                }
-                if (j > 0) {
-                    sb.append((char) code);
-                    i += j;
-                    continue;
-                }
+      // Process Unicode escape sequence (backslash-u followed by 4 hex digits)
+      // Need at least 6 characters: backslash-u + 4 hex digits
+      if (currentChar == '\\' && (i + 6 <= b.length()) && b.charAt(i + 1) == 'u') {
+        int code = 0;
+        int base = i + 2;
+        boolean validHex = true;
 
-            }
+        // Parse 4 hexadecimal digits with correct weights (16^3, 16^2, 16^1, 16^0)
+        for (int j = 0; j < 4; j++) {
+          int digit = Character.digit(b.charAt(j + base), 16);
+          if (digit < 0) {
+            // Invalid hex character encountered
+            validHex = false;
+            break;
+          }
+          code += digit * UNICODE_HEX_MULTIPLIER[j];
+        }
 
+        if (validHex) {
+          sb.append((char) code);
+          i += 5; // Skip backslash-u-XXXX (5 characters total)
+          continue;
+        }
+        // If invalid hex, fall through to handle as regular backslash escape
+      }
 
-            if (currentChar == '\\' && (i + 4 < b.length())) {
-                char i1 = b.charAt(i + 1);
-                char i2 = b.charAt(i + 2);
-                char i3 = b.charAt(i + 3);
-                if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7')
-                    && (i3 >= '0' && i3 <= '7')) {
-                    byte bVal = (byte) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8));
-                    byte[] bValArr = new byte[1];
-                    bValArr[0] = bVal;
-                    String tmp = new String(bValArr);
-                    sb.append(tmp);
-                    i += 3;
-                    continue;
-                }
-            }
+      if (currentChar == '\\') { // process case for '\001'
+        int code = 0;
+        int base = i + 1;
+        int j;
 
-            if (currentChar == '\\' && (i + 2 < b.length())) {
-                char n = b.charAt(i + 1);
-                switch (n) {
-                    case '0':
-                        sb.append("\0");
-                        break;
-                    case '\'':
-                        sb.append("'");
-                        break;
-                    case '"':
-                        sb.append("\"");
-                        break;
-                    case 'b':
-                        sb.append("\b");
-                        break;
-                    case 'n':
-                        sb.append("\n");
-                        break;
-                    case 'r':
-                        sb.append("\r");
-                        break;
-                    case 't':
-                        sb.append("\t");
-                        break;
-                    case 'Z':
-                        sb.append("\u001A");
-                        break;
-                    case '\\':
-                        sb.append("\\");
-                        break;
-                    // The following 2 lines are exactly what MySQL does TODO: why do we do this?
-                    case '%':
-                        sb.append("\\%");
-                        break;
-                    case '_':
-                        sb.append("\\_");
-                        break;
-                    default:
-                        sb.append(n);
-                }
-                i++;
+        for (j = 0; j < 3; j++) {
+          char c = b.charAt(j + base);
+          if (c >= '0' && c <= '9') {
+            if (code * 10 + (c - '0') < 128) {
+              code = code * 10 + (c - '0');
             } else {
-                sb.append(currentChar);
+              break;
             }
+          } else {
+            break;
+          }
         }
-        return sb.toString();
-    }
+        if (j > 0) {
+          sb.append((char) code);
+          i += j;
+          continue;
+        }
+      }
 
-    /**
-     * Convert Java strings to sql strings.
-     */
-    public static String escapeSQLString(String b) {
-        // There's usually nothing to escape so we will be optimistic.
-        String result = b;
-        for (int i = 0; i < result.length(); ++i) {
-            char currentChar = result.charAt(i);
-            if (currentChar == '\\' && ((i + 1) < result.length())) {
-                // TODO: do we need to handle the "this is what MySQL does" here?
-                char nextChar = result.charAt(i + 1);
-                if (nextChar == '%' || nextChar == '_') {
-                    ++i;
-                    continue;
-                }
-            }
-            switch (currentChar) {
-                case '\0':
-                    result = spliceString(result, i, "\\0");
-                    ++i;
-                    break;
-                case '\'':
-                    result = spliceString(result, i, "\\'");
-                    ++i;
-                    break;
-                case '\"':
-                    result = spliceString(result, i, "\\\"");
-                    ++i;
-                    break;
-                case '\b':
-                    result = spliceString(result, i, "\\b");
-                    ++i;
-                    break;
-                case '\n':
-                    result = spliceString(result, i, "\\n");
-                    ++i;
-                    break;
-                case '\r':
-                    result = spliceString(result, i, "\\r");
-                    ++i;
-                    break;
-                case '\t':
-                    result = spliceString(result, i, "\\t");
-                    ++i;
-                    break;
-                case '\\':
-                    result = spliceString(result, i, "\\\\");
-                    ++i;
-                    break;
-                case '\u001A':
-                    result = spliceString(result, i, "\\Z");
-                    ++i;
-                    break;
-                default: {
-                    if (currentChar < ' ') {
-                        String hex = Integer.toHexString(currentChar);
-                        String unicode = "\\u";
-                        for (int j = 4; j > hex.length(); --j) {
-                            unicode += '0';
-                        }
-                        unicode += hex;
-                        result = spliceString(result, i, unicode);
-                        i += (unicode.length() - 1);
-                    }
-                    break; // if not a control character, do nothing
-                }
-            }
+      if (currentChar == '\\' && (i + 4 < b.length())) {
+        char i1 = b.charAt(i + 1);
+        char i2 = b.charAt(i + 2);
+        char i3 = b.charAt(i + 3);
+        if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') && (i3 >= '0' && i3 <= '7')) {
+          byte bVal = (byte) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8));
+          byte[] bValArr = new byte[1];
+          bValArr[0] = bVal;
+          String tmp = new String(bValArr);
+          sb.append(tmp);
+          i += 3;
+          continue;
         }
+      }
 
-        return "'" + result + "'";
+      if (currentChar == '\\' && (i + 2 < b.length())) {
+        char n = b.charAt(i + 1);
+        switch (n) {
+          case '0':
+            sb.append("\0");
+            break;
+          case '\'':
+            sb.append("'");
+            break;
+          case '"':
+            sb.append("\"");
+            break;
+          case 'b':
+            sb.append("\b");
+            break;
+          case 'n':
+            sb.append("\n");
+            break;
+          case 'r':
+            sb.append("\r");
+            break;
+          case 't':
+            sb.append("\t");
+            break;
+          case 'Z':
+            sb.append("\u001A");
+            break;
+          case '\\':
+            sb.append("\\");
+            break;
+            // The following 2 lines are exactly what MySQL does TODO: why do we do this?
+          case '%':
+            sb.append("\\%");
+            break;
+          case '_':
+            sb.append("\\_");
+            break;
+          default:
+            sb.append(n);
+        }
+        i++;
+      } else {
+        sb.append(currentChar);
+      }
     }
+    return sb.toString();
+  }
 
-    private static String spliceString(String str, int i, String replacement) {
-        return spliceString(str, i, 1, replacement);
+  /** Convert Java strings to sql strings. */
+  public static String escapeSQLString(String b) {
+    // There's usually nothing to escape so we will be optimistic.
+    String result = b;
+    for (int i = 0; i < result.length(); ++i) {
+      char currentChar = result.charAt(i);
+      if (currentChar == '\\' && ((i + 1) < result.length())) {
+        // TODO: do we need to handle the "this is what MySQL does" here?
+        char nextChar = result.charAt(i + 1);
+        if (nextChar == '%' || nextChar == '_') {
+          ++i;
+          continue;
+        }
+      }
+      switch (currentChar) {
+        case '\0':
+          result = spliceString(result, i, "\\0");
+          ++i;
+          break;
+        case '\'':
+          result = spliceString(result, i, "\\'");
+          ++i;
+          break;
+        case '\"':
+          result = spliceString(result, i, "\\\"");
+          ++i;
+          break;
+        case '\b':
+          result = spliceString(result, i, "\\b");
+          ++i;
+          break;
+        case '\n':
+          result = spliceString(result, i, "\\n");
+          ++i;
+          break;
+        case '\r':
+          result = spliceString(result, i, "\\r");
+          ++i;
+          break;
+        case '\t':
+          result = spliceString(result, i, "\\t");
+          ++i;
+          break;
+        case '\\':
+          result = spliceString(result, i, "\\\\");
+          ++i;
+          break;
+        case '\u001A':
+          result = spliceString(result, i, "\\Z");
+          ++i;
+          break;
+        default:
+          {
+            if (currentChar < ' ') {
+              String hex = Integer.toHexString(currentChar);
+              String unicode = "\\u";
+              for (int j = 4; j > hex.length(); --j) {
+                unicode += '0';
+              }
+              unicode += hex;
+              result = spliceString(result, i, unicode);
+              i += (unicode.length() - 1);
+            }
+            break; // if not a control character, do nothing
+          }
+      }
     }
 
-    private static String spliceString(String str, int i, int length, String replacement) {
-        return str.substring(0, i) + replacement + str.substring(i + length);
-    }
+    return "'" + result + "'";
+  }
 
+  private static String spliceString(String str, int i, String replacement) {
+    return spliceString(str, i, 1, replacement);
+  }
 
-    public static String toJavaString(SqlNode node) {
-        if (node == null) {
-            return null;
-        }
-        if (node instanceof SqlCharStringLiteral) {
-            SqlCharStringLiteral literal = (SqlCharStringLiteral) node;
-            return unescapeSQLString("\"" + literal.getNlsString().getValue() + "\"");
-        }
-        return node.toString();
-    }
+  private static String spliceString(String str, int i, int length, String replacement) {
+    return str.substring(0, i) + replacement + str.substring(i + length);
+  }
 
+  public static String toJavaString(SqlNode node) {
+    if (node == null) {
+      return null;
+    }
+    if (node instanceof SqlCharStringLiteral) {
+      SqlCharStringLiteral literal = (SqlCharStringLiteral) node;
+      return unescapeSQLString("\"" + literal.getNlsString().getValue() + "\"");
+    }
+    return node.toString();
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/BaseDslTest.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/BaseDslTest.java
index 4721bf9f8..bc83a6b7f 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/BaseDslTest.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/BaseDslTest.java
@@ -19,12 +19,12 @@
 
 package org.apache.geaflow.dsl;
 
-import com.google.common.io.Resources;
 import java.net.URL;
 import java.nio.charset.Charset;
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Set;
+
 import org.apache.calcite.sql.SqlCall;
 import org.apache.calcite.sql.SqlDialect;
 import org.apache.calcite.sql.SqlDialect.DatabaseProduct;
@@ -38,90 +38,86 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import com.google.common.io.Resources;
+
 public class BaseDslTest {
 
-    private static final Logger LOGGER = LoggerFactory.getLogger(BaseDslTest.class);
+  private static final Logger LOGGER = LoggerFactory.getLogger(BaseDslTest.class);
 
-    public List parseSql(String path) throws Exception {
-        URL url = Resources.getResource(path);
-        String sql = Resources.toString(url, Charset.defaultCharset());
-        return parseStmts(sql);
-    }
+  public List parseSql(String path) throws Exception {
+    URL url = Resources.getResource(path);
+    String sql = Resources.toString(url, Charset.defaultCharset());
+    return parseStmts(sql);
+  }
 
-    public String parseSqlAndUnParse(String path) throws Exception {
-        URL url = Resources.getResource(path);
-        String sql = Resources.toString(url, Charset.defaultCharset());
-        return parseStmtsAndUnParse(sql);
-    }
+  public String parseSqlAndUnParse(String path) throws Exception {
+    URL url = Resources.getResource(path);
+    String sql = Resources.toString(url, Charset.defaultCharset());
+    return parseStmtsAndUnParse(sql);
+  }
 
-    public List parseStmts(String stmts) throws Exception {
-        LOGGER.info("Origin Sql:\n" + stmts);
-        GeaFlowDSLParser parser = new GeaFlowDSLParser();
-        List sqlNodes = parser.parseMultiStatement(stmts);
-        return sqlNodes;
-    }
+  public List parseStmts(String stmts) throws Exception {
+    LOGGER.info("Origin Sql:\n" + stmts);
+    GeaFlowDSLParser parser = new GeaFlowDSLParser();
+    List sqlNodes = parser.parseMultiStatement(stmts);
+    return sqlNodes;
+  }
 
-    public String parseStmtsAndUnParse(String stmts) throws Exception {
-        LOGGER.info("Origin Sql:\n" + stmts);
-        GeaFlowDSLParser parser = new GeaFlowDSLParser();
-        List sqlNodes = parser.parseMultiStatement(stmts);
-        checkSqlNodes(sqlNodes);
-        String unParseSql = unparse(sqlNodes);
-        LOGGER.info("Unparse Sql:\n" + unParseSql);
-        return unParseSql;
-    }
+  public String parseStmtsAndUnParse(String stmts) throws Exception {
+    LOGGER.info("Origin Sql:\n" + stmts);
+    GeaFlowDSLParser parser = new GeaFlowDSLParser();
+    List sqlNodes = parser.parseMultiStatement(stmts);
+    checkSqlNodes(sqlNodes);
+    String unParseSql = unparse(sqlNodes);
+    LOGGER.info("Unparse Sql:\n" + unParseSql);
+    return unParseSql;
+  }
 
-    private void checkSqlNodes(List sqlNodes) throws ValidationException {
-        Set nodeSet = new LinkedHashSet();
-        for (SqlNode node : sqlNodes) {
-            nodeSet.add(node);
-        }
-        for (SqlNode node : nodeSet) {
-            if (node instanceof SqlCall) {
-                SqlCall call = (SqlCall) node;
-                List operandList = call.getOperandList();
-                checkSqlNodes(operandList);
-                for (int i = 0; i < operandList.size(); i++) {
-                    call.setOperand(i, operandList.get(i));
-                }
-            } else if (node instanceof SqlNodeList) {
-                checkSqlNodes(((SqlNodeList) node).getList());
-            }
-            if (node instanceof SqlCreateTable) {
-                ((SqlCreateTable) node).validate();
-            }
-        }
+  private void checkSqlNodes(List sqlNodes) throws ValidationException {
+    Set nodeSet = new LinkedHashSet();
+    for (SqlNode node : sqlNodes) {
+      nodeSet.add(node);
     }
-
-    /**
-     * Unparse multiple SqlNode to standard sql statement.
-     */
-    public String unparse(List sqlNodes) {
-        StringBuilder builder = new StringBuilder();
-        for (SqlNode node : sqlNodes) {
-            builder
-                .append(toSqlString(node, DatabaseProduct.UNKNOWN.getDialect(),
-                    false).toString());
-            builder.append(";");
-            builder.append("\n");
-            builder.append("\n");
+    for (SqlNode node : nodeSet) {
+      if (node instanceof SqlCall) {
+        SqlCall call = (SqlCall) node;
+        List operandList = call.getOperandList();
+        checkSqlNodes(operandList);
+        for (int i = 0; i < operandList.size(); i++) {
+          call.setOperand(i, operandList.get(i));
         }
-        return builder.toString();
+      } else if (node instanceof SqlNodeList) {
+        checkSqlNodes(((SqlNodeList) node).getList());
+      }
+      if (node instanceof SqlCreateTable) {
+        ((SqlCreateTable) node).validate();
+      }
     }
+  }
 
-    public SqlString toSqlString(SqlNode sqlNode, SqlDialect dialect,
-                                 boolean forceParens) {
-        if (dialect == null) {
-            dialect = SqlDialect.DUMMY;
-        }
-        SqlPrettyWriter writer = new SqlPrettyWriter(dialect);
-        writer.setAlwaysUseParentheses(forceParens);
-        writer.setQuoteAllIdentifiers(false);
-        writer.setIndentation(0);
-        writer.setSelectListItemsOnSeparateLines(true);
-        sqlNode.unparse(writer, 0, 0);
-        final String sql = writer.toString();
-        return new SqlString(dialect, sql);
+  /** Unparse multiple SqlNode to standard sql statement. */
+  public String unparse(List sqlNodes) {
+    StringBuilder builder = new StringBuilder();
+    for (SqlNode node : sqlNodes) {
+      builder.append(toSqlString(node, DatabaseProduct.UNKNOWN.getDialect(), false).toString());
+      builder.append(";");
+      builder.append("\n");
+      builder.append("\n");
     }
+    return builder.toString();
+  }
 
+  public SqlString toSqlString(SqlNode sqlNode, SqlDialect dialect, boolean forceParens) {
+    if (dialect == null) {
+      dialect = SqlDialect.DUMMY;
+    }
+    SqlPrettyWriter writer = new SqlPrettyWriter(dialect);
+    writer.setAlwaysUseParentheses(forceParens);
+    writer.setQuoteAllIdentifiers(false);
+    writer.setIndentation(0);
+    writer.setSelectListItemsOnSeparateLines(true);
+    sqlNode.unparse(writer, 0, 0);
+    final String sql = writer.toString();
+    return new SqlString(dialect, sql);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/DdlSyntaxTest.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/DdlSyntaxTest.java
index e4bcb4886..2a99304c0 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/DdlSyntaxTest.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/DdlSyntaxTest.java
@@ -25,73 +25,73 @@
 @Test(groups = "SyntaxTest")
 public class DdlSyntaxTest extends BaseDslTest {
 
-    @Test
-    public void testGQLCreateFunction() throws Exception {
-        String unParseSql = parseSqlAndUnParse("GQLCreateFunction.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testGQLCreateFunction() throws Exception {
+    String unParseSql = parseSqlAndUnParse("GQLCreateFunction.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 
-    @Test
-    public void testGQLCreateGraph() throws Exception {
-        String unParseSql = parseSqlAndUnParse("GQLCreateGraph.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testGQLCreateGraph() throws Exception {
+    String unParseSql = parseSqlAndUnParse("GQLCreateGraph.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 
-    @Test
-    public void testGQLCreateGraphUsing() throws Exception {
-        String unParseSql = parseSqlAndUnParse("GQLCreateGraphUsing.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testGQLCreateGraphUsing() throws Exception {
+    String unParseSql = parseSqlAndUnParse("GQLCreateGraphUsing.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 
-    @Test
-    public void testGQLAlterGraph() throws Exception {
-        String unParseSql = parseSqlAndUnParse("GQLAlterGraph.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testGQLAlterGraph() throws Exception {
+    String unParseSql = parseSqlAndUnParse("GQLAlterGraph.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 
-    @Test
-    public void testCreateTable() throws Exception {
-        String unParseSql = parseSqlAndUnParse("CreateTable.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testCreateTable() throws Exception {
+    String unParseSql = parseSqlAndUnParse("CreateTable.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 
-    @Test
-    public void testGQLDescGraph() throws Exception {
-        String unParseSql = parseSqlAndUnParse("GQLDescGraph.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testGQLDescGraph() throws Exception {
+    String unParseSql = parseSqlAndUnParse("GQLDescGraph.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 
-    @Test
-    public void testCreateView() throws Exception {
-        String unParseSql = parseSqlAndUnParse("CreateView.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testCreateView() throws Exception {
+    String unParseSql = parseSqlAndUnParse("CreateView.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 
-    @Test
-    public void testGQLLet() throws Exception {
-        String unParseSql = parseSqlAndUnParse("GQLLetStatement.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testGQLLet() throws Exception {
+    String unParseSql = parseSqlAndUnParse("GQLLetStatement.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 
-    @Test
-    public void testUseInstance() throws Exception {
-        String unParseSql = parseSqlAndUnParse("UseInstance.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testUseInstance() throws Exception {
+    String unParseSql = parseSqlAndUnParse("UseInstance.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 
-    @Test
-    public void testEdgeConstraint() throws Exception {
-        String unParseSql = parseSqlAndUnParse("GQLEdgeConstraint.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testEdgeConstraint() throws Exception {
+    String unParseSql = parseSqlAndUnParse("GQLEdgeConstraint.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/GraphAlgorithmSyntaxTest.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/GraphAlgorithmSyntaxTest.java
index c14636236..8f2b22b27 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/GraphAlgorithmSyntaxTest.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/GraphAlgorithmSyntaxTest.java
@@ -25,10 +25,10 @@
 @Test(groups = "SyntaxTest")
 public class GraphAlgorithmSyntaxTest extends BaseDslTest {
 
-    @Test
-    public void testGQLAlgorithmCall() throws Exception {
-        String unParseSql = parseSqlAndUnParse("GQLAlgorithmCall.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testGQLAlgorithmCall() throws Exception {
+    String unParseSql = parseSqlAndUnParse("GQLAlgorithmCall.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/IsoGqlSyntaxTest.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/IsoGqlSyntaxTest.java
index 38916e051..0e45ae759 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/IsoGqlSyntaxTest.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/IsoGqlSyntaxTest.java
@@ -25,17 +25,17 @@
 @Test(groups = "SyntaxTest")
 public class IsoGqlSyntaxTest extends BaseDslTest {
 
-    @Test
-    public void testIsoGQLMatch() throws Exception {
-        String unParseSql = parseSqlAndUnParse("IsoGQLMatch.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testIsoGQLMatch() throws Exception {
+    String unParseSql = parseSqlAndUnParse("IsoGQLMatch.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 
-    @Test
-    public void testIsoGQLSamePredicate() throws Exception {
-        String unParseSql = parseSqlAndUnParse("IsoGQLSame.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testIsoGQLSamePredicate() throws Exception {
+    String unParseSql = parseSqlAndUnParse("IsoGQLSame.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/MatchReturnSyntaxTest.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/MatchReturnSyntaxTest.java
index 86f0094bf..7c1618d9e 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/MatchReturnSyntaxTest.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/MatchReturnSyntaxTest.java
@@ -25,59 +25,59 @@
 @Test(groups = "SyntaxTest")
 public class MatchReturnSyntaxTest extends BaseDslTest {
 
-    @Test
-    public void testGQLBaseMatchPattern() throws Exception {
-        String unParseSql = parseSqlAndUnParse("GQLBaseMatchPattern.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testGQLBaseMatchPattern() throws Exception {
+    String unParseSql = parseSqlAndUnParse("GQLBaseMatchPattern.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 
-    @Test
-    public void testGQLComplexMatch() throws Exception {
-        String unParseSql = parseSqlAndUnParse("GQLComplexMatch.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testGQLComplexMatch() throws Exception {
+    String unParseSql = parseSqlAndUnParse("GQLComplexMatch.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 
-    @Test
-    public void testGQLBaseMatchReturnStatements() throws Exception {
-        String unParseSql = parseSqlAndUnParse("GQLBaseMatchReturnStatements.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testGQLBaseMatchReturnStatements() throws Exception {
+    String unParseSql = parseSqlAndUnParse("GQLBaseMatchReturnStatements.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 
-    @Test
-    public void testGQLSelectFromMatch() throws Exception {
-        String unParseSql = parseSqlAndUnParse("GQLSelectFromMatch.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testGQLSelectFromMatch() throws Exception {
+    String unParseSql = parseSqlAndUnParse("GQLSelectFromMatch.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 
-    @Test
-    public void testGQLFilterStatement() throws Exception {
-        String unParseSql = parseSqlAndUnParse("GQLFilterStatement.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testGQLFilterStatement() throws Exception {
+    String unParseSql = parseSqlAndUnParse("GQLFilterStatement.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 
-    @Test
-    public void testGQLLet() throws Exception {
-        String unParseSql = parseSqlAndUnParse("GQLLetStatement.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testGQLLet() throws Exception {
+    String unParseSql = parseSqlAndUnParse("GQLLetStatement.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 
-    @Test
-    public void testGQLSubQuery() throws Exception {
-        String unParseSql = parseSqlAndUnParse("GQLSubQuery.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testGQLSubQuery() throws Exception {
+    String unParseSql = parseSqlAndUnParse("GQLSubQuery.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 
-    @Test
-    public void testGQLMatchOrder() throws Exception {
-        String unParseSql = parseSqlAndUnParse("GQLMatchOrder.sql");
-        String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
-        Assert.assertEquals(unParseStmts, unParseSql);
-    }
+  @Test
+  public void testGQLMatchOrder() throws Exception {
+    String unParseSql = parseSqlAndUnParse("GQLMatchOrder.sql");
+    String unParseStmts = parseStmtsAndUnParse(parseStmtsAndUnParse(unParseSql));
+    Assert.assertEquals(unParseStmts, unParseSql);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/util/StringLiteralUtilTest.java b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/util/StringLiteralUtilTest.java
index 15669ef5b..d69f40f75 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/util/StringLiteralUtilTest.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-parser/src/test/java/org/apache/geaflow/dsl/util/StringLiteralUtilTest.java
@@ -24,131 +24,130 @@
 
 public class StringLiteralUtilTest {
 
-    @Test
-    public void testUnicodeEscapeSequence_Basic() {
-        // Test \u0041 => "A"
-        String input = "\"\\u0041\"";
-        String result = StringLiteralUtil.unescapeSQLString(input);
-        Assert.assertEquals(result, "A");
-    }
-
-    @Test
-    public void testUnicodeEscapeSequence_AccentedCharacter() {
-        // Test \u00e9 => "é"
-        String input = "\"\\u00e9\"";
-        String result = StringLiteralUtil.unescapeSQLString(input);
-        Assert.assertEquals(result, "é");
-    }
-
-    @Test
-    public void testUnicodeEscapeSequence_MultipleUnicode() {
-        // Test multiple Unicode characters
-        String input = "\"\\u0048\\u0065\\u006c\\u006c\\u006f\"";
-        String result = StringLiteralUtil.unescapeSQLString(input);
-        Assert.assertEquals(result, "Hello");
-    }
-
-    @Test
-    public void testUnicodeEscapeSequence_ChineseCharacter() {
-        // Test Chinese character \u4e2d => "中"
-        String input = "\"\\u4e2d\"";
-        String result = StringLiteralUtil.unescapeSQLString(input);
-        Assert.assertEquals(result, "中");
-    }
-
-    @Test
-    public void testUnicodeEscapeSequence_UpperCaseHex() {
-        // Test uppercase hex digits
-        String input = "\"\\u00FF\"";
-        String result = StringLiteralUtil.unescapeSQLString(input);
-        Assert.assertEquals(result, "\u00FF");
-    }
-
-    @Test
-    public void testUnicodeEscapeSequence_LowerCaseHex() {
-        // Test lowercase hex digits
-        String input = "\"\\u00ff\"";
-        String result = StringLiteralUtil.unescapeSQLString(input);
-        Assert.assertEquals(result, "\u00FF");
-    }
-
-    @Test
-    public void testUnicodeEscapeSequence_InvalidHexCharacter() {
-        // Test invalid hex character (should fall back to regular escape handling)
-        String input = "\"\\u00g1\"";
-        String result = StringLiteralUtil.unescapeSQLString(input);
-        // Should not crash and should handle gracefully
-        Assert.assertNotNull(result);
-    }
-
-    @Test
-    public void testUnicodeEscapeSequence_IncompleteSequence() {
-        // Test incomplete sequence (less than 4 hex digits)
-        String input = "\"\\u00a\"";
-        String result = StringLiteralUtil.unescapeSQLString(input);
-        // Should handle gracefully without crashing
-        Assert.assertNotNull(result);
-    }
-
-    @Test
-    public void testCommonEscapeSequences() {
-        // Test common escape sequences
-        String input = "\"\\n\\t\\r\\b\"";
-        String result = StringLiteralUtil.unescapeSQLString(input);
-        Assert.assertEquals(result, "\n\t\r\b");
-    }
-
-    @Test
-    public void testMixedEscapeSequences() {
-        // Test mixing Unicode and common escape sequences
-        String input = "\"Hello\\u0020World\\n\"";
-        String result = StringLiteralUtil.unescapeSQLString(input);
-        Assert.assertEquals(result, "Hello World\n");
-    }
-
-    @Test
-    public void testUnicodeEscapeSequence_MaxValue() {
-        // Test maximum Unicode value \uFFFF
-        String input = "\"\\uFFFF\"";
-        String result = StringLiteralUtil.unescapeSQLString(input);
-        Assert.assertEquals(result, "\uFFFF");
-    }
-
-    @Test
-    public void testUnicodeEscapeSequence_ZeroValue() {
-        // Test zero value \u0000
-        String input = "\"\\u0000\"";
-        String result = StringLiteralUtil.unescapeSQLString(input);
-        Assert.assertEquals(result, "\u0000");
-    }
-
-    @Test
-    public void testUnicodeEscapeSequence_WithQuotes() {
-        // Test with single quotes enclosure
-        String input = "'\\u0041'";
-        String result = StringLiteralUtil.unescapeSQLString(input);
-        Assert.assertEquals(result, "A");
-    }
-
-    @Test
-    public void testEscapeSQLString_Basic() {
-        // Test escapeSQLString method
-        String input = "Hello\nWorld";
-        String result = StringLiteralUtil.escapeSQLString(input);
-        Assert.assertTrue(result.startsWith("'"));
-        Assert.assertTrue(result.endsWith("'"));
-        Assert.assertTrue(result.contains("\\n"));
-    }
-
-    @Test
-    public void testEscapeSQLString_WithUnicode() {
-        // Test escapeSQLString with Unicode characters
-        String input = "A";
-        String result = StringLiteralUtil.escapeSQLString(input);
-        // Should be properly escaped
-        Assert.assertNotNull(result);
-        Assert.assertTrue(result.startsWith("'"));
-        Assert.assertTrue(result.endsWith("'"));
-    }
+  @Test
+  public void testUnicodeEscapeSequence_Basic() {
+    // Test \u0041 => "A"
+    String input = "\"\\u0041\"";
+    String result = StringLiteralUtil.unescapeSQLString(input);
+    Assert.assertEquals(result, "A");
+  }
+
+  @Test
+  public void testUnicodeEscapeSequence_AccentedCharacter() {
+    // Test \u00e9 => "é"
+    String input = "\"\\u00e9\"";
+    String result = StringLiteralUtil.unescapeSQLString(input);
+    Assert.assertEquals(result, "é");
+  }
+
+  @Test
+  public void testUnicodeEscapeSequence_MultipleUnicode() {
+    // Test multiple Unicode characters
+    String input = "\"\\u0048\\u0065\\u006c\\u006c\\u006f\"";
+    String result = StringLiteralUtil.unescapeSQLString(input);
+    Assert.assertEquals(result, "Hello");
+  }
+
+  @Test
+  public void testUnicodeEscapeSequence_ChineseCharacter() {
+    // Test Chinese character \u4e2d => "中"
+    String input = "\"\\u4e2d\"";
+    String result = StringLiteralUtil.unescapeSQLString(input);
+    Assert.assertEquals(result, "中");
+  }
+
+  @Test
+  public void testUnicodeEscapeSequence_UpperCaseHex() {
+    // Test uppercase hex digits
+    String input = "\"\\u00FF\"";
+    String result = StringLiteralUtil.unescapeSQLString(input);
+    Assert.assertEquals(result, "\u00FF");
+  }
+
+  @Test
+  public void testUnicodeEscapeSequence_LowerCaseHex() {
+    // Test lowercase hex digits
+    String input = "\"\\u00ff\"";
+    String result = StringLiteralUtil.unescapeSQLString(input);
+    Assert.assertEquals(result, "\u00FF");
+  }
+
+  @Test
+  public void testUnicodeEscapeSequence_InvalidHexCharacter() {
+    // Test invalid hex character (should fall back to regular escape handling)
+    String input = "\"\\u00g1\"";
+    String result = StringLiteralUtil.unescapeSQLString(input);
+    // Should not crash and should handle gracefully
+    Assert.assertNotNull(result);
+  }
+
+  @Test
+  public void testUnicodeEscapeSequence_IncompleteSequence() {
+    // Test incomplete sequence (less than 4 hex digits)
+    String input = "\"\\u00a\"";
+    String result = StringLiteralUtil.unescapeSQLString(input);
+    // Should handle gracefully without crashing
+    Assert.assertNotNull(result);
+  }
+
+  @Test
+  public void testCommonEscapeSequences() {
+    // Test common escape sequences
+    String input = "\"\\n\\t\\r\\b\"";
+    String result = StringLiteralUtil.unescapeSQLString(input);
+    Assert.assertEquals(result, "\n\t\r\b");
+  }
+
+  @Test
+  public void testMixedEscapeSequences() {
+    // Test mixing Unicode and common escape sequences
+    String input = "\"Hello\\u0020World\\n\"";
+    String result = StringLiteralUtil.unescapeSQLString(input);
+    Assert.assertEquals(result, "Hello World\n");
+  }
+
+  @Test
+  public void testUnicodeEscapeSequence_MaxValue() {
+    // Test maximum Unicode value \uFFFF
+    String input = "\"\\uFFFF\"";
+    String result = StringLiteralUtil.unescapeSQLString(input);
+    Assert.assertEquals(result, "\uFFFF");
+  }
+
+  @Test
+  public void testUnicodeEscapeSequence_ZeroValue() {
+    // Test zero value \u0000
+    String input = "\"\\u0000\"";
+    String result = StringLiteralUtil.unescapeSQLString(input);
+    Assert.assertEquals(result, "\u0000");
+  }
+
+  @Test
+  public void testUnicodeEscapeSequence_WithQuotes() {
+    // Test with single quotes enclosure
+    String input = "'\\u0041'";
+    String result = StringLiteralUtil.unescapeSQLString(input);
+    Assert.assertEquals(result, "A");
+  }
+
+  @Test
+  public void testEscapeSQLString_Basic() {
+    // Test escapeSQLString method
+    String input = "Hello\nWorld";
+    String result = StringLiteralUtil.escapeSQLString(input);
+    Assert.assertTrue(result.startsWith("'"));
+    Assert.assertTrue(result.endsWith("'"));
+    Assert.assertTrue(result.contains("\\n"));
+  }
+
+  @Test
+  public void testEscapeSQLString_WithUnicode() {
+    // Test escapeSQLString with Unicode characters
+    String input = "A";
+    String result = StringLiteralUtil.escapeSQLString(input);
+    // Should be properly escaped
+    Assert.assertNotNull(result);
+    Assert.assertTrue(result.startsWith("'"));
+    Assert.assertTrue(result.endsWith("'"));
+  }
 }
-
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/GQLOptimizer.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/GQLOptimizer.java
index 13bd8953e..992ac8319 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/GQLOptimizer.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/GQLOptimizer.java
@@ -23,6 +23,7 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.stream.Collectors;
+
 import org.apache.calcite.plan.Context;
 import org.apache.calcite.plan.Contexts;
 import org.apache.calcite.plan.RelOptRule;
@@ -38,81 +39,83 @@
 
 public class GQLOptimizer {
 
-    private final Context context;
+  private final Context context;
 
-    private final List ruleGroups = new ArrayList<>();
+  private final List ruleGroups = new ArrayList<>();
 
-    private int times = 3;
+  private int times = 3;
 
-    public GQLOptimizer(Context context) {
-        this.context = context;
-    }
+  public GQLOptimizer(Context context) {
+    this.context = context;
+  }
 
-    public GQLOptimizer() {
-        this(Contexts.empty());
-    }
+  public GQLOptimizer() {
+    this(Contexts.empty());
+  }
 
-    public void addRuleGroup(RuleGroup ruleGroup) {
-        if (!ruleGroup.isEmpty()) {
-            ruleGroups.add(ruleGroup);
-            Collections.sort(ruleGroups, Collections.reverseOrder());
-        }
+  public void addRuleGroup(RuleGroup ruleGroup) {
+    if (!ruleGroup.isEmpty()) {
+      ruleGroups.add(ruleGroup);
+      Collections.sort(ruleGroups, Collections.reverseOrder());
     }
+  }
 
-    public int setTimes(int newTimes) {
-        int oldTimes = this.times;
-        this.times = newTimes;
-        return oldTimes;
-    }
+  public int setTimes(int newTimes) {
+    int oldTimes = this.times;
+    this.times = newTimes;
+    return oldTimes;
+  }
 
-    public RelNode optimize(RelNode root) {
-        return optimize(root, this.times);
-    }
+  public RelNode optimize(RelNode root) {
+    return optimize(root, this.times);
+  }
 
-    public RelNode optimize(RelNode root, int runTimes) {
-        RelNode optimizedNode = root;
-        for (int i = 0; i < runTimes; i++) {
-            for (RuleGroup rules : ruleGroups) {
-                optimizedNode = applyRules(rules, optimizedNode);
-            }
-        }
-        return optimizedNode;
+  public RelNode optimize(RelNode root, int runTimes) {
+    RelNode optimizedNode = root;
+    for (int i = 0; i < runTimes; i++) {
+      for (RuleGroup rules : ruleGroups) {
+        optimizedNode = applyRules(rules, optimizedNode);
+      }
     }
+    return optimizedNode;
+  }
 
-    private RelNode applyRules(RuleGroup rules, RelNode node) {
-        // optimize rel node
-        HepProgramBuilder builder = new HepProgramBuilder();
-        builder.addMatchOrder(HepMatchOrder.TOP_DOWN);
-        for (RelOptRule relOptRule : rules) {
-            builder.addRuleInstance(relOptRule);
-        }
-        HepPlanner planner = new HepPlanner(builder.build(), context);
-        planner.setRoot(node);
-        RelNode optimizedNode = planner.findBestExp();
-        // optimize node in match or sub-query.
-        return applyRulesOnChildren(rules, optimizedNode);
+  private RelNode applyRules(RuleGroup rules, RelNode node) {
+    // optimize rel node
+    HepProgramBuilder builder = new HepProgramBuilder();
+    builder.addMatchOrder(HepMatchOrder.TOP_DOWN);
+    for (RelOptRule relOptRule : rules) {
+      builder.addRuleInstance(relOptRule);
     }
+    HepPlanner planner = new HepPlanner(builder.build(), context);
+    planner.setRoot(node);
+    RelNode optimizedNode = planner.findBestExp();
+    // optimize node in match or sub-query.
+    return applyRulesOnChildren(rules, optimizedNode);
+  }
 
-    private RelNode applyRulesOnChildren(RuleGroup rules, RelNode node) {
-        List newInputs = node.getInputs()
-            .stream()
+  private RelNode applyRulesOnChildren(RuleGroup rules, RelNode node) {
+    List newInputs =
+        node.getInputs().stream()
             .map(input -> applyRulesOnChildren(rules, input))
             .collect(Collectors.toList());
 
-        if (node instanceof GraphMatch) {
-            GraphMatch match = (GraphMatch) node;
-            IMatchNode newPathPattern = (IMatchNode) applyRules(rules, match.getPathPattern());
-            assert newInputs.size() == 1;
-            return match.copy(match.getTraitSet(), newInputs.get(0), newPathPattern, match.getRowType());
-        }
-        RelNode newNode = node.accept(new RexShuttle() {
-            @Override
-            public RexNode visitSubQuery(RexSubQuery subQuery) {
+    if (node instanceof GraphMatch) {
+      GraphMatch match = (GraphMatch) node;
+      IMatchNode newPathPattern = (IMatchNode) applyRules(rules, match.getPathPattern());
+      assert newInputs.size() == 1;
+      return match.copy(match.getTraitSet(), newInputs.get(0), newPathPattern, match.getRowType());
+    }
+    RelNode newNode =
+        node.accept(
+            new RexShuttle() {
+              @Override
+              public RexNode visitSubQuery(RexSubQuery subQuery) {
                 RelNode subNode = subQuery.rel;
                 RelNode newSubNode = applyRules(rules, subNode);
                 return subQuery.clone(newSubNode);
-            }
-        });
-        return newNode.copy(newNode.getTraitSet(), newInputs);
-    }
+              }
+            });
+    return newNode.copy(newNode.getTraitSet(), newInputs);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/OptimizeRules.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/OptimizeRules.java
index 77183e146..42589c810 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/OptimizeRules.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/OptimizeRules.java
@@ -19,9 +19,8 @@
 
 package org.apache.geaflow.dsl.optimize;
 
-
-import com.google.common.collect.ImmutableList;
 import java.util.List;
+
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.rel.rules.AggregateProjectPullUpConstantsRule;
 import org.apache.calcite.rel.rules.AggregateRemoveRule;
@@ -41,66 +40,67 @@
 import org.apache.calcite.rel.rules.UnionToDistinctRule;
 import org.apache.geaflow.dsl.optimize.rule.*;
 
+import com.google.common.collect.ImmutableList;
+
 public class OptimizeRules {
 
-    private static final List PRE_REWRITE_RULES = ImmutableList.of();
+  private static final List PRE_REWRITE_RULES = ImmutableList.of();
 
-    private static final List LOGICAL_RULES = ImmutableList.of(
-        ReduceExpressionsRule.FILTER_INSTANCE,
-        ReduceExpressionsRule.PROJECT_INSTANCE,
-        ReduceExpressionsRule.JOIN_INSTANCE,
-        FilterMergeRule.INSTANCE,
-        FilterAggregateTransposeRule.INSTANCE,
-        ProjectToWindowRule.PROJECT,
-        ProjectToWindowRule.INSTANCE,
-        FilterCorrelateRule.INSTANCE,
-        GQLAggregateProjectMergeRule.INSTANCE,
-        AggregateProjectPullUpConstantsRule.INSTANCE,
-        ProjectMergeRule.INSTANCE,
-        ProjectSortTransposeRule.INSTANCE,
-        JoinPushExpressionsRule.INSTANCE,
-        UnionToDistinctRule.INSTANCE,
-        AggregateRemoveRule.INSTANCE,
-        SortRemoveRule.INSTANCE,
-        PruneEmptyRules.AGGREGATE_INSTANCE,
-        PruneEmptyRules.FILTER_INSTANCE,
-        PruneEmptyRules.JOIN_LEFT_INSTANCE,
-        PruneEmptyRules.JOIN_RIGHT_INSTANCE,
-        PruneEmptyRules.PROJECT_INSTANCE,
-        PruneEmptyRules.SORT_INSTANCE,
-        PruneEmptyRules.UNION_INSTANCE,
-        ProjectFilterTransposeRule.INSTANCE,
-        FilterProjectTransposeRule.INSTANCE,
-        GQLProjectRemoveRule.INSTANCE,
-        UnionEliminatorRule.INSTANCE,
-        GQLMatchUnionMergeRule.INSTANCE,
-        MatchSortToLogicalSortRule.INSTANCE,
-        PathModifyMergeRule.INSTANCE,
-        AddVertexResetRule.INSTANCE,
-        PushJoinFilterConditionRule.INSTANCE,
-        PushConsecutiveJoinConditionRule.INSTANCE,
-        TableJoinTableToGraphRule.INSTANCE,
-        MatchJoinMatchMergeRule.INSTANCE,
-        MatchJoinTableToGraphMatchRule.INSTANCE,
-        TableJoinMatchToGraphMatchRule.INSTANCE,
-        MatchJoinMatchMergeRule.INSTANCE,
-        FilterToMatchRule.INSTANCE,
-        FilterMatchNodeTransposeRule.INSTANCE,
-        MatchFilterMergeRule.INSTANCE,
-        TableScanToGraphRule.INSTANCE,
-        MatchIdFilterSimplifyRule.INSTANCE,
-        MatchEdgeLabelFilterRemoveRule.INSTANCE,
-        GraphMatchFieldPruneRule.INSTANCE,
-        ProjectFieldPruneRule.INSTANCE
-    );
+  private static final List LOGICAL_RULES =
+      ImmutableList.of(
+          ReduceExpressionsRule.FILTER_INSTANCE,
+          ReduceExpressionsRule.PROJECT_INSTANCE,
+          ReduceExpressionsRule.JOIN_INSTANCE,
+          FilterMergeRule.INSTANCE,
+          FilterAggregateTransposeRule.INSTANCE,
+          ProjectToWindowRule.PROJECT,
+          ProjectToWindowRule.INSTANCE,
+          FilterCorrelateRule.INSTANCE,
+          GQLAggregateProjectMergeRule.INSTANCE,
+          AggregateProjectPullUpConstantsRule.INSTANCE,
+          ProjectMergeRule.INSTANCE,
+          ProjectSortTransposeRule.INSTANCE,
+          JoinPushExpressionsRule.INSTANCE,
+          UnionToDistinctRule.INSTANCE,
+          AggregateRemoveRule.INSTANCE,
+          SortRemoveRule.INSTANCE,
+          PruneEmptyRules.AGGREGATE_INSTANCE,
+          PruneEmptyRules.FILTER_INSTANCE,
+          PruneEmptyRules.JOIN_LEFT_INSTANCE,
+          PruneEmptyRules.JOIN_RIGHT_INSTANCE,
+          PruneEmptyRules.PROJECT_INSTANCE,
+          PruneEmptyRules.SORT_INSTANCE,
+          PruneEmptyRules.UNION_INSTANCE,
+          ProjectFilterTransposeRule.INSTANCE,
+          FilterProjectTransposeRule.INSTANCE,
+          GQLProjectRemoveRule.INSTANCE,
+          UnionEliminatorRule.INSTANCE,
+          GQLMatchUnionMergeRule.INSTANCE,
+          MatchSortToLogicalSortRule.INSTANCE,
+          PathModifyMergeRule.INSTANCE,
+          AddVertexResetRule.INSTANCE,
+          PushJoinFilterConditionRule.INSTANCE,
+          PushConsecutiveJoinConditionRule.INSTANCE,
+          TableJoinTableToGraphRule.INSTANCE,
+          MatchJoinMatchMergeRule.INSTANCE,
+          MatchJoinTableToGraphMatchRule.INSTANCE,
+          TableJoinMatchToGraphMatchRule.INSTANCE,
+          MatchJoinMatchMergeRule.INSTANCE,
+          FilterToMatchRule.INSTANCE,
+          FilterMatchNodeTransposeRule.INSTANCE,
+          MatchFilterMergeRule.INSTANCE,
+          TableScanToGraphRule.INSTANCE,
+          MatchIdFilterSimplifyRule.INSTANCE,
+          MatchEdgeLabelFilterRemoveRule.INSTANCE,
+          GraphMatchFieldPruneRule.INSTANCE,
+          ProjectFieldPruneRule.INSTANCE);
 
-    private static final List POST_OPTIMIZE_RULES = ImmutableList.of(
-        PathInputReplaceRule.INSTANCE
-    );
+  private static final List POST_OPTIMIZE_RULES =
+      ImmutableList.of(PathInputReplaceRule.INSTANCE);
 
-    public static final List RULE_GROUPS = ImmutableList.of(
-        RuleGroup.of(PRE_REWRITE_RULES, 10),
-        RuleGroup.of(LOGICAL_RULES, 5),
-        RuleGroup.of(POST_OPTIMIZE_RULES, 0)
-    );
+  public static final List RULE_GROUPS =
+      ImmutableList.of(
+          RuleGroup.of(PRE_REWRITE_RULES, 10),
+          RuleGroup.of(LOGICAL_RULES, 5),
+          RuleGroup.of(POST_OPTIMIZE_RULES, 0));
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/RuleGroup.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/RuleGroup.java
index 509c4a365..9fb0dd65a 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/RuleGroup.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/RuleGroup.java
@@ -22,44 +22,45 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.Objects;
+
 import org.apache.calcite.plan.RelOptRule;
 
 public class RuleGroup implements Iterable, Comparable {
 
-    public static final int DEFAULT_PRIORITY = 0;
+  public static final int DEFAULT_PRIORITY = 0;
 
-    private final List rules;
+  private final List rules;
 
-    private final int priority;
+  private final int priority;
 
-    private RuleGroup(List rules, int priority) {
-        this.rules = Objects.requireNonNull(rules);
-        this.priority = priority;
-    }
+  private RuleGroup(List rules, int priority) {
+    this.rules = Objects.requireNonNull(rules);
+    this.priority = priority;
+  }
 
-    public static RuleGroup of(List rules, int priority) {
-        return new RuleGroup(rules, priority);
-    }
+  public static RuleGroup of(List rules, int priority) {
+    return new RuleGroup(rules, priority);
+  }
 
-    public RuleGroup(List rules) {
-        this(rules, DEFAULT_PRIORITY);
-    }
+  public RuleGroup(List rules) {
+    this(rules, DEFAULT_PRIORITY);
+  }
 
-    @Override
-    public Iterator iterator() {
-        return rules.iterator();
-    }
+  @Override
+  public Iterator iterator() {
+    return rules.iterator();
+  }
 
-    public int getPriority() {
-        return priority;
-    }
+  public int getPriority() {
+    return priority;
+  }
 
-    @Override
-    public int compareTo(RuleGroup o) {
-        return Integer.compare(this.priority, o.priority);
-    }
+  @Override
+  public int compareTo(RuleGroup o) {
+    return Integer.compare(this.priority, o.priority);
+  }
 
-    public boolean isEmpty() {
-        return rules.isEmpty();
-    }
+  public boolean isEmpty() {
+    return rules.isEmpty();
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/AbstractJoinToGraphRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/AbstractJoinToGraphRule.java
index f0e5e4250..960297415 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/AbstractJoinToGraphRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/AbstractJoinToGraphRule.java
@@ -31,6 +31,7 @@
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
+
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.plan.RelOptRuleOperand;
@@ -89,981 +90,1289 @@
 
 public abstract class AbstractJoinToGraphRule extends RelOptRule {
 
-    private static final Logger LOGGER = LoggerFactory.getLogger(AbstractJoinToGraphRule.class);
+  private static final Logger LOGGER = LoggerFactory.getLogger(AbstractJoinToGraphRule.class);
 
-    public AbstractJoinToGraphRule(RelOptRuleOperand operand) {
-        super(operand);
-    }
+  public AbstractJoinToGraphRule(RelOptRuleOperand operand) {
+    super(operand);
+  }
 
-    /**
-     * Determine if a Join in SQL can be converted to node matching or edge matching in GQL.
-     */
-    protected static GraphJoinType getJoinType(LogicalJoin join) {
-        JoinInfo joinInfo = join.analyzeCondition();
-        // only support inner join and equal-join currently.
-        if (!joinInfo.isEqui() || !isSupportJoinType(join.getJoinType())) {
-            return GraphJoinType.NONE_GRAPH_JOIN;
-        }
+  /** Determine if a Join in SQL can be converted to node matching or edge matching in GQL. */
+  protected static GraphJoinType getJoinType(LogicalJoin join) {
+    JoinInfo joinInfo = join.analyzeCondition();
+    // only support inner join and equal-join currently.
+    if (!joinInfo.isEqui() || !isSupportJoinType(join.getJoinType())) {
+      return GraphJoinType.NONE_GRAPH_JOIN;
+    }
 
-        List leftKeys = new ArrayList<>(joinInfo.leftKeys);
-        List rightKeys = new ArrayList<>(joinInfo.rightKeys);
-        RelDataType leftType = join.getLeft().getRowType();
-        RelDataType rightType = join.getRight().getRowType();
+    List leftKeys = new ArrayList<>(joinInfo.leftKeys);
+    List rightKeys = new ArrayList<>(joinInfo.rightKeys);
+    RelDataType leftType = join.getLeft().getRowType();
+    RelDataType rightType = join.getRight().getRowType();
 
-        GraphJoinType graphJoinType = GraphJoinType.NONE_GRAPH_JOIN;
-        for (int i = 0; i < leftKeys.size(); i++) {
-            Integer leftKey = leftKeys.get(i);
-            Integer rightKey = rightKeys.get(i);
-            GraphJoinType currentJoinType = getJoinType(leftKey, leftType, rightKey, rightType);
-            if (currentJoinType != GraphJoinType.NONE_GRAPH_JOIN) {
-                if (graphJoinType == GraphJoinType.NONE_GRAPH_JOIN) {
-                    graphJoinType = currentJoinType;
-                } else if (graphJoinType != currentJoinType) {
-                    // contain multi join pattern, can not translate to graph, just return
-                    return GraphJoinType.NONE_GRAPH_JOIN;
-                }
-            }
+    GraphJoinType graphJoinType = GraphJoinType.NONE_GRAPH_JOIN;
+    for (int i = 0; i < leftKeys.size(); i++) {
+      Integer leftKey = leftKeys.get(i);
+      Integer rightKey = rightKeys.get(i);
+      GraphJoinType currentJoinType = getJoinType(leftKey, leftType, rightKey, rightType);
+      if (currentJoinType != GraphJoinType.NONE_GRAPH_JOIN) {
+        if (graphJoinType == GraphJoinType.NONE_GRAPH_JOIN) {
+          graphJoinType = currentJoinType;
+        } else if (graphJoinType != currentJoinType) {
+          // contain multi join pattern, can not translate to graph, just return
+          return GraphJoinType.NONE_GRAPH_JOIN;
         }
-        return graphJoinType;
+      }
     }
+    return graphJoinType;
+  }
 
-    protected static boolean isSupportJoinType(JoinRelType type) {
-        return type == JoinRelType.INNER || type == JoinRelType.LEFT;
-    }
+  protected static boolean isSupportJoinType(JoinRelType type) {
+    return type == JoinRelType.INNER || type == JoinRelType.LEFT;
+  }
 
-    private static GraphJoinType getJoinType(int leftIndex, RelDataType leftType, int rightIndex,
-                                             RelDataType rightType) {
+  private static GraphJoinType getJoinType(
+      int leftIndex, RelDataType leftType, int rightIndex, RelDataType rightType) {
 
-        RelDataType leftKeyType = leftType.getFieldList().get(leftIndex).getType();
-        RelDataType rightKeyType = rightType.getFieldList().get(rightIndex).getType();
+    RelDataType leftKeyType = leftType.getFieldList().get(leftIndex).getType();
+    RelDataType rightKeyType = rightType.getFieldList().get(rightIndex).getType();
 
-        if (leftKeyType instanceof MetaFieldType && rightKeyType instanceof MetaFieldType) {
-            MetaField leftMetaField = ((MetaFieldType) leftKeyType).getMetaField();
-            MetaField rightMetaField = ((MetaFieldType) rightKeyType).getMetaField();
+    if (leftKeyType instanceof MetaFieldType && rightKeyType instanceof MetaFieldType) {
+      MetaField leftMetaField = ((MetaFieldType) leftKeyType).getMetaField();
+      MetaField rightMetaField = ((MetaFieldType) rightKeyType).getMetaField();
 
-            switch (leftMetaField) {
-                case VERTEX_ID:
-                    if (rightMetaField == MetaField.EDGE_SRC_ID) {
-                        return GraphJoinType.VERTEX_JOIN_EDGE;
-                    } else if (rightMetaField == MetaField.EDGE_TARGET_ID) {
-                        return GraphJoinType.EDGE_JOIN_VERTEX;
-                    }
-                    break;
-                case EDGE_SRC_ID:
-                    if (rightMetaField == MetaField.VERTEX_ID) {
-                        return GraphJoinType.VERTEX_JOIN_EDGE;
-                    }
-                    break;
-                case EDGE_TARGET_ID:
-                    if (rightMetaField == MetaField.VERTEX_ID) {
-                        return GraphJoinType.EDGE_JOIN_VERTEX;
-                    }
-                    break;
-                default:
-            }
-        }
-        return GraphJoinType.NONE_GRAPH_JOIN;
+      switch (leftMetaField) {
+        case VERTEX_ID:
+          if (rightMetaField == MetaField.EDGE_SRC_ID) {
+            return GraphJoinType.VERTEX_JOIN_EDGE;
+          } else if (rightMetaField == MetaField.EDGE_TARGET_ID) {
+            return GraphJoinType.EDGE_JOIN_VERTEX;
+          }
+          break;
+        case EDGE_SRC_ID:
+          if (rightMetaField == MetaField.VERTEX_ID) {
+            return GraphJoinType.VERTEX_JOIN_EDGE;
+          }
+          break;
+        case EDGE_TARGET_ID:
+          if (rightMetaField == MetaField.VERTEX_ID) {
+            return GraphJoinType.EDGE_JOIN_VERTEX;
+          }
+          break;
+        default:
+      }
     }
+    return GraphJoinType.NONE_GRAPH_JOIN;
+  }
 
-    /**
-     * Determine if an SQL logical RelNode belongs to a single-chain and can be rewritten as a GQL RelNode.
-     */
-    private static boolean isSingleChain(RelNode relNode) {
-        return relNode instanceof LogicalFilter || relNode instanceof LogicalProject
-            || relNode instanceof LogicalAggregate;
-    }
+  /**
+   * Determine if an SQL logical RelNode belongs to a single-chain and can be rewritten as a GQL
+   * RelNode.
+   */
+  private static boolean isSingleChain(RelNode relNode) {
+    return relNode instanceof LogicalFilter
+        || relNode instanceof LogicalProject
+        || relNode instanceof LogicalAggregate;
+  }
 
-    /**
-     * Determine if an SQL logical RelNode is downstream of a TableScan and if all the RelNodes
-     * on its input chain can be rewritten as GQL.
-     */
-    protected static boolean isSingleChainFromLogicalTableScan(RelNode node) {
-        RelNode relNode = GQLRelUtil.toRel(node);
-        if (isSingleChain(relNode)) {
-            return relNode.getInputs().size() == 1 && isSingleChainFromLogicalTableScan(
-                relNode.getInput(0));
-        }
-        return relNode instanceof LogicalTableScan;
+  /**
+   * Determine if an SQL logical RelNode is downstream of a TableScan and if all the RelNodes on its
+   * input chain can be rewritten as GQL.
+   */
+  protected static boolean isSingleChainFromLogicalTableScan(RelNode node) {
+    RelNode relNode = GQLRelUtil.toRel(node);
+    if (isSingleChain(relNode)) {
+      return relNode.getInputs().size() == 1
+          && isSingleChainFromLogicalTableScan(relNode.getInput(0));
     }
+    return relNode instanceof LogicalTableScan;
+  }
 
-    /**
-     * Determine if an SQL logical RelNode is downstream of a GraphMatch and if all the RelNodes
-     * on its input chain can be rewritten as GQL.
-     */
-    protected static boolean isSingleChainFromGraphMatch(RelNode node) {
-        RelNode relNode = GQLRelUtil.toRel(node);
-        if (isSingleChain(relNode)) {
-            return relNode.getInputs().size() == 1 && isSingleChainFromGraphMatch(
-                relNode.getInput(0));
-        }
-        return relNode instanceof GraphMatch;
+  /**
+   * Determine if an SQL logical RelNode is downstream of a GraphMatch and if all the RelNodes on
+   * its input chain can be rewritten as GQL.
+   */
+  protected static boolean isSingleChainFromGraphMatch(RelNode node) {
+    RelNode relNode = GQLRelUtil.toRel(node);
+    if (isSingleChain(relNode)) {
+      return relNode.getInputs().size() == 1 && isSingleChainFromGraphMatch(relNode.getInput(0));
     }
+    return relNode instanceof GraphMatch;
+  }
 
-    /**
-     * Convert the RelNodes on the single chain from "from" to "to" into GQL and push them
-     * into the input. Rebuild a MatchNode, and sequentially place all the accessible fields
-     * into the returned rexNodeMap.
-     */
-    protected IMatchNode concatToMatchNode(RelBuilder builder, IMatchNode left, RelNode from, RelNode to,
-                                           IMatchNode input, List rexNodeMap) {
-        if (from instanceof LogicalFilter) {
-            LogicalFilter filter = (LogicalFilter) from;
-            List inputRexNode2RexInfo = new ArrayList<>();
-            IMatchNode filterInput = from == to ? input : concatToMatchNode(builder, left,
-                GQLRelUtil.toRel(filter.getInput()), to, input, inputRexNode2RexInfo);
-            int lastNodeIndex = filterInput.getPathSchema().getFieldCount() - 1;
-            if (lastNodeIndex < 0) {
-                throw new GeaFlowDSLException("Need at least 1 node in the path to rewrite.");
-            }
-            RexNode newCondition = filter.getCondition();
-            if (!inputRexNode2RexInfo.isEmpty()) {
-                newCondition = GQLRexUtil.replace(filter.getCondition(), rexNode -> {
-                    if (rexNode instanceof RexInputRef) {
-                        return inputRexNode2RexInfo.get(((RexInputRef) rexNode).getIndex());
-                    }
-                    return rexNode;
+  /**
+   * Convert the RelNodes on the single chain from "from" to "to" into GQL and push them into the
+   * input. Rebuild a MatchNode, and sequentially place all the accessible fields into the returned
+   * rexNodeMap.
+   */
+  protected IMatchNode concatToMatchNode(
+      RelBuilder builder,
+      IMatchNode left,
+      RelNode from,
+      RelNode to,
+      IMatchNode input,
+      List rexNodeMap) {
+    if (from instanceof LogicalFilter) {
+      LogicalFilter filter = (LogicalFilter) from;
+      List inputRexNode2RexInfo = new ArrayList<>();
+      IMatchNode filterInput =
+          from == to
+              ? input
+              : concatToMatchNode(
+                  builder,
+                  left,
+                  GQLRelUtil.toRel(filter.getInput()),
+                  to,
+                  input,
+                  inputRexNode2RexInfo);
+      int lastNodeIndex = filterInput.getPathSchema().getFieldCount() - 1;
+      if (lastNodeIndex < 0) {
+        throw new GeaFlowDSLException("Need at least 1 node in the path to rewrite.");
+      }
+      RexNode newCondition = filter.getCondition();
+      if (!inputRexNode2RexInfo.isEmpty()) {
+        newCondition =
+            GQLRexUtil.replace(
+                filter.getCondition(),
+                rexNode -> {
+                  if (rexNode instanceof RexInputRef) {
+                    return inputRexNode2RexInfo.get(((RexInputRef) rexNode).getIndex());
+                  }
+                  return rexNode;
                 });
-                rexNodeMap.addAll(inputRexNode2RexInfo);
-            } else {
-                String lastNodeLabel = filterInput.getPathSchema().getFieldList().get(lastNodeIndex)
-                    .getName();
-                RelDataType oldType = filterInput.getPathSchema().getFieldList().get(lastNodeIndex)
-                    .getType();
-                newCondition = GQLRexUtil.replace(newCondition, rex -> {
-                    if (rex instanceof RexInputRef) {
-                        return builder.getRexBuilder().makeFieldAccess(
+        rexNodeMap.addAll(inputRexNode2RexInfo);
+      } else {
+        String lastNodeLabel =
+            filterInput.getPathSchema().getFieldList().get(lastNodeIndex).getName();
+        RelDataType oldType =
+            filterInput.getPathSchema().getFieldList().get(lastNodeIndex).getType();
+        newCondition =
+            GQLRexUtil.replace(
+                newCondition,
+                rex -> {
+                  if (rex instanceof RexInputRef) {
+                    return builder
+                        .getRexBuilder()
+                        .makeFieldAccess(
                             new PathInputRef(lastNodeLabel, lastNodeIndex, oldType),
                             ((RexInputRef) rex).getIndex());
-                    }
-                    return rex;
+                  }
+                  return rex;
                 });
-                rexNodeMap.addAll(inputRexNode2RexInfo);
-            }
-            return MatchFilter.create(filterInput, newCondition, filterInput.getPathSchema());
-        } else if (from instanceof LogicalProject) {
-            LogicalProject project = (LogicalProject) from;
-            List inputRexNode2RexInfo = new ArrayList<>();
-            IMatchNode projectInput = from == to ? input : concatToMatchNode(builder, left,
-                GQLRelUtil.toRel(project.getInput()), to, input, inputRexNode2RexInfo);
+        rexNodeMap.addAll(inputRexNode2RexInfo);
+      }
+      return MatchFilter.create(filterInput, newCondition, filterInput.getPathSchema());
+    } else if (from instanceof LogicalProject) {
+      LogicalProject project = (LogicalProject) from;
+      List inputRexNode2RexInfo = new ArrayList<>();
+      IMatchNode projectInput =
+          from == to
+              ? input
+              : concatToMatchNode(
+                  builder,
+                  left,
+                  GQLRelUtil.toRel(project.getInput()),
+                  to,
+                  input,
+                  inputRexNode2RexInfo);
 
-            int lastNodeIndex = projectInput.getPathSchema().getFieldCount() - 1;
-            if (lastNodeIndex < 0) {
-                throw new GeaFlowDSLException("Need at least 1 node in the path to rewrite.");
-            }
-            String lastNodeLabel = projectInput.getPathSchema().getFieldList().get(lastNodeIndex)
-                .getName();
-            RelDataType lastNodeType =
-                projectInput.getPathSchema().getFieldList().get(lastNodeIndex).getType();
-            List replacedProjects = new ArrayList<>();
-            //Rewrite the projects by the rex mapping table returned through input reconstructing.
-            if (!inputRexNode2RexInfo.isEmpty()) {
-                replacedProjects.addAll(project.getProjects().stream().map(prj -> GQLRexUtil.replace(prj, rexNode -> {
-                    if (rexNode instanceof RexInputRef) {
-                        return inputRexNode2RexInfo.get(((RexInputRef) rexNode).getIndex());
-                    }
-                    return rexNode;
-                })).collect(Collectors.toList()));
-            } else {
-                replacedProjects.addAll(project.getProjects().stream().map(prj -> GQLRexUtil.replace(prj, rex -> {
-                    if (rex instanceof RexInputRef) {
-                        return builder.getRexBuilder().makeFieldAccess(
-                            new PathInputRef(lastNodeLabel, lastNodeIndex, lastNodeType),
-                            ((RexInputRef) rex).getIndex());
-                    }
-                    return rex;
-                })).collect(Collectors.toList()));
-            }
-            rexNodeMap.addAll(replacedProjects);
+      int lastNodeIndex = projectInput.getPathSchema().getFieldCount() - 1;
+      if (lastNodeIndex < 0) {
+        throw new GeaFlowDSLException("Need at least 1 node in the path to rewrite.");
+      }
+      String lastNodeLabel =
+          projectInput.getPathSchema().getFieldList().get(lastNodeIndex).getName();
+      RelDataType lastNodeType =
+          projectInput.getPathSchema().getFieldList().get(lastNodeIndex).getType();
+      List replacedProjects = new ArrayList<>();
+      // Rewrite the projects by the rex mapping table returned through input reconstructing.
+      if (!inputRexNode2RexInfo.isEmpty()) {
+        replacedProjects.addAll(
+            project.getProjects().stream()
+                .map(
+                    prj ->
+                        GQLRexUtil.replace(
+                            prj,
+                            rexNode -> {
+                              if (rexNode instanceof RexInputRef) {
+                                return inputRexNode2RexInfo.get(((RexInputRef) rexNode).getIndex());
+                              }
+                              return rexNode;
+                            }))
+                .collect(Collectors.toList()));
+      } else {
+        replacedProjects.addAll(
+            project.getProjects().stream()
+                .map(
+                    prj ->
+                        GQLRexUtil.replace(
+                            prj,
+                            rex -> {
+                              if (rex instanceof RexInputRef) {
+                                return builder
+                                    .getRexBuilder()
+                                    .makeFieldAccess(
+                                        new PathInputRef(
+                                            lastNodeLabel, lastNodeIndex, lastNodeType),
+                                        ((RexInputRef) rex).getIndex());
+                              }
+                              return rex;
+                            }))
+                .collect(Collectors.toList()));
+      }
+      rexNodeMap.addAll(replacedProjects);
 
-            List metaFieldProjects = replacedProjects.stream()
-                .filter(rex -> rex instanceof RexFieldAccess).collect(Collectors.toList());
-            List addFieldProjects = replacedProjects.stream()
-                .filter(rex -> !metaFieldProjects.contains(rex)).collect(Collectors.toList());
-            List addFieldIndices = addFieldProjects.stream().map(replacedProjects::indexOf)
-                .collect(Collectors.toList());
-
-            EdgeRecordType edgeNewType;
-            String edgeName;
-            VertexRecordType vertexNewType;
-            String vertexName;
-            RelDataType oldType = projectInput.getPathSchema().firstField().get().getType();
-            PathRecordType newPathRecordType = ((PathRecordType) projectInput.getRowType());
-            int extendNodeIndex;
-            String extendNodeLabel;
-            List addFieldNames;
-            List operands = new ArrayList<>();
-            Map rex2VariableInfo = new HashMap<>();
-            RexBuilder rexBuilder = builder.getRexBuilder();
-            if (addFieldProjects.size() > 0) {
-                //Paths start with a vertex, add fields on the vertex and add a vertex extension.
-                if (oldType instanceof VertexRecordType) {
-                    vertexNewType = (VertexRecordType) oldType;
-                    addFieldNames = this.generateFieldNames("f", addFieldProjects.size(),
-                        new HashSet<>(vertexNewType.getFieldNames()));
-                    for (int i = 0; i < addFieldNames.size(); i++) {
-                        vertexNewType = vertexNewType.add(addFieldNames.get(i),
-                            addFieldProjects.get(i).getType(), true);
-                    }
-                    vertexName = projectInput.getPathSchema().firstField().get().getName();
-                    newPathRecordType = newPathRecordType.addField(vertexName, vertexNewType, true);
-                    extendNodeIndex = newPathRecordType.getField(vertexName, true, false).getIndex();
-                    extendNodeLabel = newPathRecordType.getFieldList().get(extendNodeIndex).getName();
-
-                    int firstFieldIndex = projectInput.getPathSchema().firstField().get().getIndex();
-                    PathInputRef refPathInput = new PathInputRef(vertexName, firstFieldIndex, oldType);
-                    PathInputRef leftRex = new PathInputRef(extendNodeLabel, extendNodeIndex,
-                        vertexNewType);
-                    for (RelDataTypeField field : leftRex.getType().getFieldList()) {
-                        VariableInfo variableInfo;
-                        RexNode operand;
-                        if (addFieldNames.contains(field.getName())) {
-                            // cast right expression to field type.
-                            int indexOfAddFields = addFieldNames.indexOf(field.getName());
-                            operand = builder.getRexBuilder()
-                                .makeCast(field.getType(), addFieldProjects.get(indexOfAddFields));
-                            variableInfo = new VariableInfo(false, field.getName());
-                            rexNodeMap.set(addFieldIndices.get(indexOfAddFields),
-                                rexBuilder.makeFieldAccess(leftRex, field.getIndex()));
-                        } else {
-                            operand = rexBuilder.makeFieldAccess(refPathInput, field.getIndex());
-                            variableInfo = new VariableInfo(false, field.getName());
-                        }
-                        operands.add(operand);
-                        rex2VariableInfo.put(operand, variableInfo);
-                    }
-                    // Construct RexObjectConstruct for dynamic field append expression.
-                    RexObjectConstruct rightRex = new RexObjectConstruct(vertexNewType, operands,
-                        rex2VariableInfo);
-                    List pathModifyExpressions = new ArrayList<>();
-                    pathModifyExpressions.add(new PathModifyExpression(leftRex, rightRex));
+      List metaFieldProjects =
+          replacedProjects.stream()
+              .filter(rex -> rex instanceof RexFieldAccess)
+              .collect(Collectors.toList());
+      List addFieldProjects =
+          replacedProjects.stream()
+              .filter(rex -> !metaFieldProjects.contains(rex))
+              .collect(Collectors.toList());
+      List addFieldIndices =
+          addFieldProjects.stream().map(replacedProjects::indexOf).collect(Collectors.toList());
 
-                    vertexName = this.generateFieldNames("f", 1,
-                        new HashSet<>(newPathRecordType.getFieldNames())).get(0);
-                    newPathRecordType = newPathRecordType.addField(vertexName, oldType, true);
-                    extendNodeIndex = newPathRecordType.getField(vertexName, true, false).getIndex();
-                    PathInputRef leftRex2 = new PathInputRef(vertexName, extendNodeIndex, oldType);
-                    Map vertexRex2VariableInfo = new HashMap<>();
-                    List vertexOperands = refPathInput.getType().getFieldList().stream()
-                        .map(f -> {
-                            RexNode operand = builder.getRexBuilder()
-                                .makeFieldAccess(refPathInput, f.getIndex());
-                            VariableInfo variableInfo = new VariableInfo(false, f.getName());
-                            vertexRex2VariableInfo.put(operand, variableInfo);
-                            return operand;
-                        }).collect(Collectors.toList());
-                    RexObjectConstruct rightRex2 = new RexObjectConstruct(leftRex2.getType(),
-                        vertexOperands, vertexRex2VariableInfo);
-                    pathModifyExpressions.add(new PathModifyExpression(leftRex2, rightRex2));
+      EdgeRecordType edgeNewType;
+      String edgeName;
+      VertexRecordType vertexNewType;
+      String vertexName;
+      RelDataType oldType = projectInput.getPathSchema().firstField().get().getType();
+      PathRecordType newPathRecordType = ((PathRecordType) projectInput.getRowType());
+      int extendNodeIndex;
+      String extendNodeLabel;
+      List addFieldNames;
+      List operands = new ArrayList<>();
+      Map rex2VariableInfo = new HashMap<>();
+      RexBuilder rexBuilder = builder.getRexBuilder();
+      if (addFieldProjects.size() > 0) {
+        // Paths start with a vertex, add fields on the vertex and add a vertex extension.
+        if (oldType instanceof VertexRecordType) {
+          vertexNewType = (VertexRecordType) oldType;
+          addFieldNames =
+              this.generateFieldNames(
+                  "f", addFieldProjects.size(), new HashSet<>(vertexNewType.getFieldNames()));
+          for (int i = 0; i < addFieldNames.size(); i++) {
+            vertexNewType =
+                vertexNewType.add(addFieldNames.get(i), addFieldProjects.get(i).getType(), true);
+          }
+          vertexName = projectInput.getPathSchema().firstField().get().getName();
+          newPathRecordType = newPathRecordType.addField(vertexName, vertexNewType, true);
+          extendNodeIndex = newPathRecordType.getField(vertexName, true, false).getIndex();
+          extendNodeLabel = newPathRecordType.getFieldList().get(extendNodeIndex).getName();
 
-                    GQLJavaTypeFactory gqlJavaTypeFactory =
-                        (GQLJavaTypeFactory) builder.getTypeFactory();
-                    GeaFlowGraph currentGraph = gqlJavaTypeFactory.getCurrentGraph();
-                    GraphRecordType graphSchema = (GraphRecordType) currentGraph.getRowType(
-                        gqlJavaTypeFactory);
-                    return MatchExtend.create(projectInput, pathModifyExpressions,
-                        newPathRecordType, graphSchema);
-                } else {
-                    //Paths start with an edge, add fields on the edge and add an edge extension.
-                    edgeNewType = (EdgeRecordType) oldType;
-                    addFieldNames = this.generateFieldNames("f", addFieldProjects.size(),
-                        new HashSet<>(edgeNewType.getFieldNames()));
-                    for (int i = 0; i < addFieldNames.size(); i++) {
-                        edgeNewType = edgeNewType.add(addFieldNames.get(i),
-                            addFieldProjects.get(i).getType(), true);
-                    }
+          int firstFieldIndex = projectInput.getPathSchema().firstField().get().getIndex();
+          PathInputRef refPathInput = new PathInputRef(vertexName, firstFieldIndex, oldType);
+          PathInputRef leftRex = new PathInputRef(extendNodeLabel, extendNodeIndex, vertexNewType);
+          for (RelDataTypeField field : leftRex.getType().getFieldList()) {
+            VariableInfo variableInfo;
+            RexNode operand;
+            if (addFieldNames.contains(field.getName())) {
+              // cast right expression to field type.
+              int indexOfAddFields = addFieldNames.indexOf(field.getName());
+              operand =
+                  builder
+                      .getRexBuilder()
+                      .makeCast(field.getType(), addFieldProjects.get(indexOfAddFields));
+              variableInfo = new VariableInfo(false, field.getName());
+              rexNodeMap.set(
+                  addFieldIndices.get(indexOfAddFields),
+                  rexBuilder.makeFieldAccess(leftRex, field.getIndex()));
+            } else {
+              operand = rexBuilder.makeFieldAccess(refPathInput, field.getIndex());
+              variableInfo = new VariableInfo(false, field.getName());
+            }
+            operands.add(operand);
+            rex2VariableInfo.put(operand, variableInfo);
+          }
+          // Construct RexObjectConstruct for dynamic field append expression.
+          RexObjectConstruct rightRex =
+              new RexObjectConstruct(vertexNewType, operands, rex2VariableInfo);
+          List pathModifyExpressions = new ArrayList<>();
+          pathModifyExpressions.add(new PathModifyExpression(leftRex, rightRex));
 
-                    edgeName = projectInput.getPathSchema().firstField().get().getName();
-                    newPathRecordType = newPathRecordType.addField(edgeName, edgeNewType, true);
-                    extendNodeIndex = newPathRecordType.getField(edgeName, true, false).getIndex();
-                    extendNodeLabel = newPathRecordType.getFieldList().get(extendNodeIndex).getName();
-                    PathInputRef leftRex = new PathInputRef(extendNodeLabel, extendNodeIndex, edgeNewType);
+          vertexName =
+              this.generateFieldNames("f", 1, new HashSet<>(newPathRecordType.getFieldNames()))
+                  .get(0);
+          newPathRecordType = newPathRecordType.addField(vertexName, oldType, true);
+          extendNodeIndex = newPathRecordType.getField(vertexName, true, false).getIndex();
+          PathInputRef leftRex2 = new PathInputRef(vertexName, extendNodeIndex, oldType);
+          Map vertexRex2VariableInfo = new HashMap<>();
+          List vertexOperands =
+              refPathInput.getType().getFieldList().stream()
+                  .map(
+                      f -> {
+                        RexNode operand =
+                            builder.getRexBuilder().makeFieldAccess(refPathInput, f.getIndex());
+                        VariableInfo variableInfo = new VariableInfo(false, f.getName());
+                        vertexRex2VariableInfo.put(operand, variableInfo);
+                        return operand;
+                      })
+                  .collect(Collectors.toList());
+          RexObjectConstruct rightRex2 =
+              new RexObjectConstruct(leftRex2.getType(), vertexOperands, vertexRex2VariableInfo);
+          pathModifyExpressions.add(new PathModifyExpression(leftRex2, rightRex2));
 
-                    for (RelDataTypeField field : leftRex.getType().getFieldList()) {
-                        VariableInfo variableInfo;
-                        RexNode operand;
-                        if (addFieldNames.contains(field.getName())) {
-                            // cast right expression to field type.
-                            int indexOfAddFields = addFieldNames.indexOf(field.getName());
-                            operand = builder.getRexBuilder().makeCast(field.getType(),
-                                addFieldProjects.get(addFieldNames.indexOf(field.getName())));
-                            rexNodeMap.set(addFieldIndices.get(indexOfAddFields),
-                                rexBuilder.makeFieldAccess(leftRex, field.getIndex()));
-                        } else {
-                            operand = builder.getRexBuilder()
-                                .makeFieldAccess(leftRex, field.getIndex());
-                        }
-                        variableInfo = new VariableInfo(false, field.getName());
-                        operands.add(operand);
-                        rex2VariableInfo.put(operand, variableInfo);
-                    }
-                    // Construct RexObjectConstruct for dynamic field append expression.
-                    RexObjectConstruct rightRex = new RexObjectConstruct(leftRex.getType(),
-                        operands, rex2VariableInfo);
-                    List pathModifyExpressions = new ArrayList<>();
-                    pathModifyExpressions.add(new PathModifyExpression(leftRex, rightRex));
+          GQLJavaTypeFactory gqlJavaTypeFactory = (GQLJavaTypeFactory) builder.getTypeFactory();
+          GeaFlowGraph currentGraph = gqlJavaTypeFactory.getCurrentGraph();
+          GraphRecordType graphSchema =
+              (GraphRecordType) currentGraph.getRowType(gqlJavaTypeFactory);
+          return MatchExtend.create(
+              projectInput, pathModifyExpressions, newPathRecordType, graphSchema);
+        } else {
+          // Paths start with an edge, add fields on the edge and add an edge extension.
+          edgeNewType = (EdgeRecordType) oldType;
+          addFieldNames =
+              this.generateFieldNames(
+                  "f", addFieldProjects.size(), new HashSet<>(edgeNewType.getFieldNames()));
+          for (int i = 0; i < addFieldNames.size(); i++) {
+            edgeNewType =
+                edgeNewType.add(addFieldNames.get(i), addFieldProjects.get(i).getType(), true);
+          }
 
-                    edgeName = this.generateFieldNames("f", 1,
-                        new HashSet<>(newPathRecordType.getFieldNames())).get(0);
-                    newPathRecordType = newPathRecordType.addField(edgeName, oldType, true);
-                    extendNodeIndex = newPathRecordType.getField(edgeName, true, false).getIndex();
-                    PathInputRef leftRex2 = new PathInputRef(edgeName, extendNodeIndex, oldType);
-                    Map vertexRex2VariableInfo = new HashMap<>();
-                    int firstFieldIndex = projectInput.getPathSchema().firstField().get().getIndex();
-                    PathInputRef refPathInput = new PathInputRef(edgeName, firstFieldIndex, oldType);
-                    List vertexOperands = refPathInput.getType().getFieldList().stream()
-                        .map(f -> {
-                            RexNode operand = builder.getRexBuilder()
-                                .makeFieldAccess(refPathInput, f.getIndex());
-                            VariableInfo variableInfo = new VariableInfo(false, f.getName());
-                            vertexRex2VariableInfo.put(operand, variableInfo);
-                            return operand;
-                        }).collect(Collectors.toList());
-                    RexObjectConstruct rightRex2 = new RexObjectConstruct(leftRex2.getType(),
-                        vertexOperands, vertexRex2VariableInfo);
-                    pathModifyExpressions.add(new PathModifyExpression(leftRex2, rightRex2));
+          edgeName = projectInput.getPathSchema().firstField().get().getName();
+          newPathRecordType = newPathRecordType.addField(edgeName, edgeNewType, true);
+          extendNodeIndex = newPathRecordType.getField(edgeName, true, false).getIndex();
+          extendNodeLabel = newPathRecordType.getFieldList().get(extendNodeIndex).getName();
+          PathInputRef leftRex = new PathInputRef(extendNodeLabel, extendNodeIndex, edgeNewType);
 
-                    GQLJavaTypeFactory gqlJavaTypeFactory =
-                        (GQLJavaTypeFactory) builder.getTypeFactory();
-                    GeaFlowGraph currentGraph = gqlJavaTypeFactory.getCurrentGraph();
-                    GraphRecordType graphSchema = (GraphRecordType) currentGraph.getRowType(
-                        gqlJavaTypeFactory);
-                    return MatchExtend.create(projectInput, pathModifyExpressions,
-                        newPathRecordType, graphSchema);
-                }
+          for (RelDataTypeField field : leftRex.getType().getFieldList()) {
+            VariableInfo variableInfo;
+            RexNode operand;
+            if (addFieldNames.contains(field.getName())) {
+              // cast right expression to field type.
+              int indexOfAddFields = addFieldNames.indexOf(field.getName());
+              operand =
+                  builder
+                      .getRexBuilder()
+                      .makeCast(
+                          field.getType(),
+                          addFieldProjects.get(addFieldNames.indexOf(field.getName())));
+              rexNodeMap.set(
+                  addFieldIndices.get(indexOfAddFields),
+                  rexBuilder.makeFieldAccess(leftRex, field.getIndex()));
             } else {
-                return projectInput;
+              operand = builder.getRexBuilder().makeFieldAccess(leftRex, field.getIndex());
             }
-        } else if (from instanceof LogicalAggregate) {
-            LogicalAggregate aggregate = (LogicalAggregate) from;
-            List inputRexNode2RexInfo = new ArrayList<>();
-            IMatchNode aggregateInput = from == to ? input : concatToMatchNode(builder, left,
-                GQLRelUtil.toRel(aggregate.getInput()), to, input, inputRexNode2RexInfo);
-            int lastNodeIndex = aggregateInput.getPathSchema().getFieldCount() - 1;
-            if (lastNodeIndex < 0) {
-                throw new GeaFlowDSLException("Need at least 1 node in the path to rewrite.");
-            }
-            //MatchAggregate needs to reference path fields, not using indices for referencing,
-            // but using RexNode instead.
-            List adjustAggCalls;
-            List adjustGroupList;
-            Set matchAggPathLabels = new HashSet<>();
-            if (!inputRexNode2RexInfo.isEmpty()) {
-                adjustGroupList = aggregate.getGroupSet().asList().stream()
-                    .map(inputRexNode2RexInfo::get).collect(Collectors.toList());
-                adjustAggCalls = aggregate.getAggCallList().stream().map(
+            variableInfo = new VariableInfo(false, field.getName());
+            operands.add(operand);
+            rex2VariableInfo.put(operand, variableInfo);
+          }
+          // Construct RexObjectConstruct for dynamic field append expression.
+          RexObjectConstruct rightRex =
+              new RexObjectConstruct(leftRex.getType(), operands, rex2VariableInfo);
+          List pathModifyExpressions = new ArrayList<>();
+          pathModifyExpressions.add(new PathModifyExpression(leftRex, rightRex));
+
+          edgeName =
+              this.generateFieldNames("f", 1, new HashSet<>(newPathRecordType.getFieldNames()))
+                  .get(0);
+          newPathRecordType = newPathRecordType.addField(edgeName, oldType, true);
+          extendNodeIndex = newPathRecordType.getField(edgeName, true, false).getIndex();
+          PathInputRef leftRex2 = new PathInputRef(edgeName, extendNodeIndex, oldType);
+          Map vertexRex2VariableInfo = new HashMap<>();
+          int firstFieldIndex = projectInput.getPathSchema().firstField().get().getIndex();
+          PathInputRef refPathInput = new PathInputRef(edgeName, firstFieldIndex, oldType);
+          List vertexOperands =
+              refPathInput.getType().getFieldList().stream()
+                  .map(
+                      f -> {
+                        RexNode operand =
+                            builder.getRexBuilder().makeFieldAccess(refPathInput, f.getIndex());
+                        VariableInfo variableInfo = new VariableInfo(false, f.getName());
+                        vertexRex2VariableInfo.put(operand, variableInfo);
+                        return operand;
+                      })
+                  .collect(Collectors.toList());
+          RexObjectConstruct rightRex2 =
+              new RexObjectConstruct(leftRex2.getType(), vertexOperands, vertexRex2VariableInfo);
+          pathModifyExpressions.add(new PathModifyExpression(leftRex2, rightRex2));
+
+          GQLJavaTypeFactory gqlJavaTypeFactory = (GQLJavaTypeFactory) builder.getTypeFactory();
+          GeaFlowGraph currentGraph = gqlJavaTypeFactory.getCurrentGraph();
+          GraphRecordType graphSchema =
+              (GraphRecordType) currentGraph.getRowType(gqlJavaTypeFactory);
+          return MatchExtend.create(
+              projectInput, pathModifyExpressions, newPathRecordType, graphSchema);
+        }
+      } else {
+        return projectInput;
+      }
+    } else if (from instanceof LogicalAggregate) {
+      LogicalAggregate aggregate = (LogicalAggregate) from;
+      List inputRexNode2RexInfo = new ArrayList<>();
+      IMatchNode aggregateInput =
+          from == to
+              ? input
+              : concatToMatchNode(
+                  builder,
+                  left,
+                  GQLRelUtil.toRel(aggregate.getInput()),
+                  to,
+                  input,
+                  inputRexNode2RexInfo);
+      int lastNodeIndex = aggregateInput.getPathSchema().getFieldCount() - 1;
+      if (lastNodeIndex < 0) {
+        throw new GeaFlowDSLException("Need at least 1 node in the path to rewrite.");
+      }
+      // MatchAggregate needs to reference path fields, not using indices for referencing,
+      // but using RexNode instead.
+      List adjustAggCalls;
+      List adjustGroupList;
+      Set matchAggPathLabels = new HashSet<>();
+      if (!inputRexNode2RexInfo.isEmpty()) {
+        adjustGroupList =
+            aggregate.getGroupSet().asList().stream()
+                .map(inputRexNode2RexInfo::get)
+                .collect(Collectors.toList());
+        adjustAggCalls =
+            aggregate.getAggCallList().stream()
+                .map(
                     aggCall -> {
-                        List newArgList = aggCall.getArgList().stream()
-                            .map(inputRexNode2RexInfo::get).collect(Collectors.toList());
-                        return new MatchAggregateCall(aggCall.getAggregation(), aggCall.isDistinct(),
-                            aggCall.isApproximate(), newArgList, aggCall.filterArg,
-                            aggCall.getCollation(), aggCall.getType(), aggCall.getName());
-                    }).collect(Collectors.toList());
-            } else {
-                String lastNodeLabel = aggregateInput.getPathSchema().getFieldList()
-                    .get(lastNodeIndex).getName();
-                matchAggPathLabels.add(lastNodeLabel);
-                RelDataType oldType = aggregateInput.getPathSchema().getFieldList()
-                    .get(lastNodeIndex).getType();
-                adjustGroupList = aggregate.getGroupSet().asList().stream().map(idx -> builder.getRexBuilder()
-                    .makeFieldAccess(new PathInputRef(lastNodeLabel, lastNodeIndex, oldType),
-                        idx)).collect(Collectors.toList());
-                adjustAggCalls = aggregate.getAggCallList().stream().map(
-                    aggCall -> new MatchAggregateCall(aggCall.getAggregation(),
-                        aggCall.isDistinct(), aggCall.isApproximate(),
-                        aggCall.getArgList().stream().map(idx -> builder.getRexBuilder().makeFieldAccess(
-                            new PathInputRef(lastNodeLabel, lastNodeIndex, oldType), idx)).collect(Collectors.toList()), aggCall.filterArg, aggCall.getCollation(),
-                        aggCall.getType(), aggCall.getName())).collect(Collectors.toList());
-            }
+                      List newArgList =
+                          aggCall.getArgList().stream()
+                              .map(inputRexNode2RexInfo::get)
+                              .collect(Collectors.toList());
+                      return new MatchAggregateCall(
+                          aggCall.getAggregation(),
+                          aggCall.isDistinct(),
+                          aggCall.isApproximate(),
+                          newArgList,
+                          aggCall.filterArg,
+                          aggCall.getCollation(),
+                          aggCall.getType(),
+                          aggCall.getName());
+                    })
+                .collect(Collectors.toList());
+      } else {
+        String lastNodeLabel =
+            aggregateInput.getPathSchema().getFieldList().get(lastNodeIndex).getName();
+        matchAggPathLabels.add(lastNodeLabel);
+        RelDataType oldType =
+            aggregateInput.getPathSchema().getFieldList().get(lastNodeIndex).getType();
+        adjustGroupList =
+            aggregate.getGroupSet().asList().stream()
+                .map(
+                    idx ->
+                        builder
+                            .getRexBuilder()
+                            .makeFieldAccess(
+                                new PathInputRef(lastNodeLabel, lastNodeIndex, oldType), idx))
+                .collect(Collectors.toList());
+        adjustAggCalls =
+            aggregate.getAggCallList().stream()
+                .map(
+                    aggCall ->
+                        new MatchAggregateCall(
+                            aggCall.getAggregation(),
+                            aggCall.isDistinct(),
+                            aggCall.isApproximate(),
+                            aggCall.getArgList().stream()
+                                .map(
+                                    idx ->
+                                        builder
+                                            .getRexBuilder()
+                                            .makeFieldAccess(
+                                                new PathInputRef(
+                                                    lastNodeLabel, lastNodeIndex, oldType),
+                                                idx))
+                                .collect(Collectors.toList()),
+                            aggCall.filterArg,
+                            aggCall.getCollation(),
+                            aggCall.getType(),
+                            aggCall.getName()))
+                .collect(Collectors.toList());
+      }
 
-            //Get the pruning path.
-            if (left != null) {
-                matchAggPathLabels.addAll(left.getPathSchema().getFieldNames());
-            }
-            for (RelDataTypeField field : aggregateInput.getPathSchema().getFieldList()) {
-                if (field.getType() instanceof VertexRecordType && adjustGroupList.stream()
-                    .anyMatch(rexNode -> ((PathInputRef) ((RexFieldAccess) rexNode)
-                        .getReferenceExpr()).getLabel().equals(field.getName())
-                        && rexNode.getType() instanceof MetaFieldType
-                        && ((MetaFieldType) rexNode.getType()).getMetaField()
-                        .equals(MetaField.VERTEX_ID))) {
-                    //The condition for preserving vertex in the after aggregating path is that the
-                    // vertex ID appears in the Group.
-                    matchAggPathLabels.add(field.getName());
-                } else if (field.getType() instanceof EdgeRecordType) {
-                    //The condition for preserving edges in the after aggregating path is that
-                    // the edge srcId, targetId, and timestamp appear in the Group.
-                    //todo: Since the GQL inferred from SQL does not have a Union, we will
-                    // temporarily not consider Labels here.
-                    boolean groupBySrcId = adjustGroupList.stream().anyMatch(rexNode ->
-                        ((PathInputRef) ((RexFieldAccess) rexNode)
-                            .getReferenceExpr()).getLabel().equals(field.getName())
-                            && rexNode.getType() instanceof MetaFieldType
-                            && ((MetaFieldType) rexNode.getType()).getMetaField()
-                            .equals(MetaField.EDGE_SRC_ID));
-                    boolean groupByTargetId = adjustGroupList.stream().anyMatch(rexNode ->
-                        ((PathInputRef) ((RexFieldAccess) rexNode)
-                            .getReferenceExpr()).getLabel().equals(field.getName())
+      // Get the pruning path.
+      if (left != null) {
+        matchAggPathLabels.addAll(left.getPathSchema().getFieldNames());
+      }
+      for (RelDataTypeField field : aggregateInput.getPathSchema().getFieldList()) {
+        if (field.getType() instanceof VertexRecordType
+            && adjustGroupList.stream()
+                .anyMatch(
+                    rexNode ->
+                        ((PathInputRef) ((RexFieldAccess) rexNode).getReferenceExpr())
+                                .getLabel()
+                                .equals(field.getName())
                             && rexNode.getType() instanceof MetaFieldType
-                            && ((MetaFieldType) rexNode.getType()).getMetaField()
-                            .equals(MetaField.EDGE_TARGET_ID));
-                    boolean groupByTs = true;
-                    if (((EdgeRecordType) field.getType()).getTimestampField().isPresent()) {
-                        groupByTs = adjustGroupList.stream().anyMatch(rexNode ->
-                            ((PathInputRef) ((RexFieldAccess) rexNode)
-                                .getReferenceExpr()).getLabel().equals(field.getName())
+                            && ((MetaFieldType) rexNode.getType())
+                                .getMetaField()
+                                .equals(MetaField.VERTEX_ID))) {
+          // The condition for preserving vertex in the after aggregating path is that the
+          // vertex ID appears in the Group.
+          matchAggPathLabels.add(field.getName());
+        } else if (field.getType() instanceof EdgeRecordType) {
+          // The condition for preserving edges in the after aggregating path is that
+          // the edge srcId, targetId, and timestamp appear in the Group.
+          // todo: Since the GQL inferred from SQL does not have a Union, we will
+          // temporarily not consider Labels here.
+          boolean groupBySrcId =
+              adjustGroupList.stream()
+                  .anyMatch(
+                      rexNode ->
+                          ((PathInputRef) ((RexFieldAccess) rexNode).getReferenceExpr())
+                                  .getLabel()
+                                  .equals(field.getName())
+                              && rexNode.getType() instanceof MetaFieldType
+                              && ((MetaFieldType) rexNode.getType())
+                                  .getMetaField()
+                                  .equals(MetaField.EDGE_SRC_ID));
+          boolean groupByTargetId =
+              adjustGroupList.stream()
+                  .anyMatch(
+                      rexNode ->
+                          ((PathInputRef) ((RexFieldAccess) rexNode).getReferenceExpr())
+                                  .getLabel()
+                                  .equals(field.getName())
+                              && rexNode.getType() instanceof MetaFieldType
+                              && ((MetaFieldType) rexNode.getType())
+                                  .getMetaField()
+                                  .equals(MetaField.EDGE_TARGET_ID));
+          boolean groupByTs = true;
+          if (((EdgeRecordType) field.getType()).getTimestampField().isPresent()) {
+            groupByTs =
+                adjustGroupList.stream()
+                    .anyMatch(
+                        rexNode ->
+                            ((PathInputRef) ((RexFieldAccess) rexNode).getReferenceExpr())
+                                    .getLabel()
+                                    .equals(field.getName())
                                 && rexNode.getType() instanceof MetaFieldType
-                                && ((MetaFieldType) rexNode.getType()).getMetaField()
-                                .equals(MetaField.EDGE_TS));
-                    }
-                    if (groupBySrcId && groupByTargetId && groupByTs) {
-                        matchAggPathLabels.add(field.getName());
-                    }
-                }
-            }
-            if (matchAggPathLabels.isEmpty()) {
-                matchAggPathLabels.add(aggregateInput.getPathSchema().firstFieldName().get());
-            }
+                                && ((MetaFieldType) rexNode.getType())
+                                    .getMetaField()
+                                    .equals(MetaField.EDGE_TS));
+          }
+          if (groupBySrcId && groupByTargetId && groupByTs) {
+            matchAggPathLabels.add(field.getName());
+          }
+        }
+      }
+      if (matchAggPathLabels.isEmpty()) {
+        matchAggPathLabels.add(aggregateInput.getPathSchema().firstFieldName().get());
+      }
 
-            PathRecordType aggPathType;
-            if (adjustGroupList.size() > 0 || aggregate.getAggCallList().size() > 0) {
-                PathRecordType pathType = aggregateInput.getPathSchema();
-                PathRecordType prunePathType = new PathRecordType(pathType.getFieldList().stream()
+      PathRecordType aggPathType;
+      if (adjustGroupList.size() > 0 || aggregate.getAggCallList().size() > 0) {
+        PathRecordType pathType = aggregateInput.getPathSchema();
+        PathRecordType prunePathType =
+            new PathRecordType(
+                pathType.getFieldList().stream()
                     .filter(f -> matchAggPathLabels.contains(f.getName()))
                     .collect(Collectors.toList()));
-                //Prune the path, and add the aggregated value to the beginning of the path.
-                RelDataType firstNodeType = prunePathType.firstField().get().getType();
-                String firstNodeName = prunePathType.firstField().get().getName();
-                int offset;
-                if (firstNodeType instanceof VertexRecordType) {
-                    VertexRecordType vertexNewType = (VertexRecordType) firstNodeType;
-                    List addFieldNames = this.generateFieldNames("f", adjustGroupList.size(),
-                        new HashSet<>(vertexNewType.getFieldNames()));
-                    offset = vertexNewType.getFieldCount();
-                    for (int i = 0; i < adjustGroupList.size(); i++) {
-                        RelDataType dataType = adjustGroupList.get(i).getType();
-                        vertexNewType = vertexNewType.add(addFieldNames.get(i), dataType, true);
-                    }
-                    addFieldNames = generateFieldNames("agg", aggregate.getAggCallList().size(),
-                        new HashSet<>(vertexNewType.getFieldNames()));
-                    for (int i = 0; i < aggregate.getAggCallList().size(); i++) {
-                        vertexNewType = vertexNewType.add(addFieldNames.get(i),
-                            aggregate.getAggCallList().get(i).getType(), true);
-                    }
-                    for (int i = 0; i < adjustGroupList.size() + aggregate.getAggCallList().size(); i++) {
-                        rexNodeMap.add(builder.getRexBuilder().makeFieldAccess(
-                            new PathInputRef(firstNodeName, 0, vertexNewType), offset + i));
-                    }
-                    aggPathType = new PathRecordType(new ArrayList<>())
-                        .addField(firstNodeName, vertexNewType, true);
-                    aggPathType = aggPathType.concat(prunePathType, true);
-                } else if (firstNodeType instanceof EdgeRecordType) {
-                    EdgeRecordType edgeNewType = (EdgeRecordType) firstNodeType;
-                    List addFieldNames = this.generateFieldNames("f", adjustGroupList.size(),
-                        new HashSet<>(edgeNewType.getFieldNames()));
-                    offset = edgeNewType.getFieldCount();
-                    for (int i = 0; i < adjustGroupList.size(); i++) {
-                        RelDataType dataType = adjustGroupList.get(i).getType();
-                        edgeNewType = edgeNewType.add(addFieldNames.get(i), dataType, true);
-                    }
-                    addFieldNames = generateFieldNames("agg", aggregate.getAggCallList().size(),
-                        new HashSet<>(edgeNewType.getFieldNames()));
-                    for (int i = 0; i < aggregate.getAggCallList().size(); i++) {
-                        edgeNewType = edgeNewType.add(addFieldNames.get(i),
-                            aggregate.getAggCallList().get(i).getType(), true);
-                    }
-                    for (int i = 0; i < adjustGroupList.size() + aggregate.getAggCallList().size(); i++) {
-                        rexNodeMap.add(builder.getRexBuilder().makeFieldAccess(
-                            new PathInputRef(firstNodeName, 0, edgeNewType), offset + i));
-                    }
-                    aggPathType = new PathRecordType(new ArrayList<>())
-                        .addField(firstNodeName, edgeNewType, true);
-                    aggPathType = aggPathType.concat(prunePathType, true);
-                } else {
-                    throw new GeaFlowDSLException("Path node should be vertex or edge.");
-                }
-                return MatchAggregate.create(aggregateInput, aggregate.indicator, adjustGroupList,
-                    adjustAggCalls, aggPathType);
-            } else {
-                return aggregateInput;
-            }
+        // Prune the path, and add the aggregated value to the beginning of the path.
+        RelDataType firstNodeType = prunePathType.firstField().get().getType();
+        String firstNodeName = prunePathType.firstField().get().getName();
+        int offset;
+        if (firstNodeType instanceof VertexRecordType) {
+          VertexRecordType vertexNewType = (VertexRecordType) firstNodeType;
+          List addFieldNames =
+              this.generateFieldNames(
+                  "f", adjustGroupList.size(), new HashSet<>(vertexNewType.getFieldNames()));
+          offset = vertexNewType.getFieldCount();
+          for (int i = 0; i < adjustGroupList.size(); i++) {
+            RelDataType dataType = adjustGroupList.get(i).getType();
+            vertexNewType = vertexNewType.add(addFieldNames.get(i), dataType, true);
+          }
+          addFieldNames =
+              generateFieldNames(
+                  "agg",
+                  aggregate.getAggCallList().size(),
+                  new HashSet<>(vertexNewType.getFieldNames()));
+          for (int i = 0; i < aggregate.getAggCallList().size(); i++) {
+            vertexNewType =
+                vertexNewType.add(
+                    addFieldNames.get(i), aggregate.getAggCallList().get(i).getType(), true);
+          }
+          for (int i = 0; i < adjustGroupList.size() + aggregate.getAggCallList().size(); i++) {
+            rexNodeMap.add(
+                builder
+                    .getRexBuilder()
+                    .makeFieldAccess(
+                        new PathInputRef(firstNodeName, 0, vertexNewType), offset + i));
+          }
+          aggPathType =
+              new PathRecordType(new ArrayList<>()).addField(firstNodeName, vertexNewType, true);
+          aggPathType = aggPathType.concat(prunePathType, true);
+        } else if (firstNodeType instanceof EdgeRecordType) {
+          EdgeRecordType edgeNewType = (EdgeRecordType) firstNodeType;
+          List addFieldNames =
+              this.generateFieldNames(
+                  "f", adjustGroupList.size(), new HashSet<>(edgeNewType.getFieldNames()));
+          offset = edgeNewType.getFieldCount();
+          for (int i = 0; i < adjustGroupList.size(); i++) {
+            RelDataType dataType = adjustGroupList.get(i).getType();
+            edgeNewType = edgeNewType.add(addFieldNames.get(i), dataType, true);
+          }
+          addFieldNames =
+              generateFieldNames(
+                  "agg",
+                  aggregate.getAggCallList().size(),
+                  new HashSet<>(edgeNewType.getFieldNames()));
+          for (int i = 0; i < aggregate.getAggCallList().size(); i++) {
+            edgeNewType =
+                edgeNewType.add(
+                    addFieldNames.get(i), aggregate.getAggCallList().get(i).getType(), true);
+          }
+          for (int i = 0; i < adjustGroupList.size() + aggregate.getAggCallList().size(); i++) {
+            rexNodeMap.add(
+                builder
+                    .getRexBuilder()
+                    .makeFieldAccess(new PathInputRef(firstNodeName, 0, edgeNewType), offset + i));
+          }
+          aggPathType =
+              new PathRecordType(new ArrayList<>()).addField(firstNodeName, edgeNewType, true);
+          aggPathType = aggPathType.concat(prunePathType, true);
+        } else {
+          throw new GeaFlowDSLException("Path node should be vertex or edge.");
         }
-        return input;
+        return MatchAggregate.create(
+            aggregateInput, aggregate.indicator, adjustGroupList, adjustAggCalls, aggPathType);
+      } else {
+        return aggregateInput;
+      }
     }
+    return input;
+  }
 
-    /**
-     * Generate n non-repeating field names with the format "[$prefix][$index_number]".
-     */
-    protected List generateFieldNames(String prefix, int nameCount,
-                                              Collection existsNames) {
-        if (nameCount <= 0) {
-            return Collections.emptyList();
-        }
-        Set exists = new HashSet<>(existsNames);
-        List validNames = new ArrayList<>(nameCount);
-        int i = 0;
-        while (validNames.size() < nameCount) {
-            String newName = prefix + i;
-            if (!exists.contains(newName)) {
-                validNames.add(newName);
-            }
-            i++;
-        }
-        return validNames;
+  /** Generate n non-repeating field names with the format "[$prefix][$index_number]". */
+  protected List generateFieldNames(
+      String prefix, int nameCount, Collection existsNames) {
+    if (nameCount <= 0) {
+      return Collections.emptyList();
     }
+    Set exists = new HashSet<>(existsNames);
+    List validNames = new ArrayList<>(nameCount);
+    int i = 0;
+    while (validNames.size() < nameCount) {
+      String newName = prefix + i;
+      if (!exists.contains(newName)) {
+        validNames.add(newName);
+      }
+      i++;
+    }
+    return validNames;
+  }
 
-    /**
-     * Replace the references to path nodes in the leftRexNodes with references to
-     * the newGraphMatch directly, based on the label of the reference.
-     */
-    protected List adjustLeftRexNodes(List leftRexNodes, GraphMatch newGraphMatch,
-                                               RelBuilder builder) {
-        return leftRexNodes.stream().map(prj -> GQLRexUtil.replace(prj, rexNode -> {
-            if (rexNode instanceof RexFieldAccess
-                && ((RexFieldAccess) rexNode).getReferenceExpr() instanceof PathInputRef) {
-                PathInputRef pathInputRef =
-                    (PathInputRef) ((RexFieldAccess) rexNode).getReferenceExpr();
-                String label = pathInputRef.getLabel();
-                int index = newGraphMatch.getRowType().getField(label, true, false).getIndex();
-                RelDataType type = newGraphMatch.getRowType().getField(label, true, false)
-                    .getType();
-                PathInputRef vertexRef = new PathInputRef(label, index, type);
-                return builder.getRexBuilder().makeFieldAccess(vertexRef, ((RexFieldAccess) rexNode).getField().getIndex());
+  /**
+   * Replace the references to path nodes in the leftRexNodes with references to the newGraphMatch
+   * directly, based on the label of the reference.
+   */
+  protected List adjustLeftRexNodes(
+      List leftRexNodes, GraphMatch newGraphMatch, RelBuilder builder) {
+    return leftRexNodes.stream()
+        .map(
+            prj ->
+                GQLRexUtil.replace(
+                    prj,
+                    rexNode -> {
+                      if (rexNode instanceof RexFieldAccess
+                          && ((RexFieldAccess) rexNode).getReferenceExpr()
+                              instanceof PathInputRef) {
+                        PathInputRef pathInputRef =
+                            (PathInputRef) ((RexFieldAccess) rexNode).getReferenceExpr();
+                        String label = pathInputRef.getLabel();
+                        int index =
+                            newGraphMatch.getRowType().getField(label, true, false).getIndex();
+                        RelDataType type =
+                            newGraphMatch.getRowType().getField(label, true, false).getType();
+                        PathInputRef vertexRef = new PathInputRef(label, index, type);
+                        return builder
+                            .getRexBuilder()
+                            .makeFieldAccess(
+                                vertexRef, ((RexFieldAccess) rexNode).getField().getIndex());
+                      }
+                      return rexNode;
+                    }))
+        .collect(Collectors.toList());
+  }
 
-            }
-            return rexNode;
-        })).collect(Collectors.toList());
+  /**
+   * Replace the references to path nodes in the leftRexNodes with references to the newGraphMatch,
+   * based on the label of the reference. If the label of the reference exists in the left branch,
+   * automatically add an offset to make it a reference to the same name Node in the right branch.
+   */
+  protected List adjustRightRexNodes(
+      List rightRexNodes,
+      GraphMatch newGraphMatch,
+      RelBuilder builder,
+      IMatchNode leftPathPattern,
+      IMatchNode rightPathPattern) {
+    final IMatchNode finalLeft = leftPathPattern;
+    final IMatchNode finalRight = rightPathPattern;
+    return rightRexNodes.stream()
+        .map(
+            prj ->
+                GQLRexUtil.replace(
+                    prj,
+                    rexNode -> {
+                      if (rexNode instanceof RexFieldAccess
+                          && ((RexFieldAccess) rexNode).getReferenceExpr()
+                              instanceof PathInputRef) {
+                        PathInputRef pathInputRef =
+                            (PathInputRef) ((RexFieldAccess) rexNode).getReferenceExpr();
+                        String label = pathInputRef.getLabel();
+                        boolean isConflictLabel =
+                            leftPathPattern.getPathSchema().getFieldNames().contains(label);
+                        if (isConflictLabel) {
+                          int index =
+                              finalRight.getRowType().getField(label, true, false).getIndex();
+                          index += finalLeft.getRowType().getFieldCount();
+                          label = newGraphMatch.getRowType().getFieldList().get(index).getName();
+                          RelDataType type =
+                              newGraphMatch.getRowType().getFieldList().get(index).getType();
+                          PathInputRef vertexRef = new PathInputRef(label, index, type);
+                          return builder
+                              .getRexBuilder()
+                              .makeFieldAccess(
+                                  vertexRef, ((RexFieldAccess) rexNode).getField().getIndex());
+                        } else {
+                          int index =
+                              newGraphMatch.getRowType().getField(label, true, false).getIndex();
+                          RelDataType type =
+                              newGraphMatch.getRowType().getFieldList().get(index).getType();
+                          PathInputRef vertexRef = new PathInputRef(label, index, type);
+                          return builder
+                              .getRexBuilder()
+                              .makeFieldAccess(
+                                  vertexRef, ((RexFieldAccess) rexNode).getField().getIndex());
+                        }
+                      }
+                      return rexNode;
+                    }))
+        .collect(Collectors.toList());
+  }
+
+  /**
+   * Handling the transformation of SQL RelNodes for the MatchJoinTableToGraphMatch and
+   * TableJoinMatchToGraphMatch rules into GQL matches.
+   */
+  protected RelNode processGraphMatchJoinTable(
+      RelOptRuleCall call,
+      LogicalJoin join,
+      LogicalGraphMatch graphMatch,
+      LogicalProject project,
+      LogicalTableScan tableScan,
+      RelNode leftInput,
+      RelNode leftHead,
+      RelNode rightInput,
+      RelNode rightHead,
+      boolean isMatchInLeft) {
+    GeaFlowTable geaflowTable = tableScan.getTable().unwrap(GeaFlowTable.class);
+    GQLJavaTypeFactory typeFactory = (GQLJavaTypeFactory) call.builder().getTypeFactory();
+    GeaFlowGraph currentGraph = typeFactory.getCurrentGraph();
+    if (!currentGraph.containTable(geaflowTable)) {
+      if (geaflowTable instanceof VertexTable || geaflowTable instanceof EdgeTable) {
+        throw new GeaFlowDSLException(
+            "Unknown graph element: {}, use graph please.", geaflowTable.getName());
+      }
+      return null;
     }
+    GraphJoinType graphJoinType = getJoinType(join);
+    RelDataType tableType = tableScan.getRowType();
+    IMatchNode matchNode = graphMatch.getPathPattern();
+    RelBuilder relBuilder = call.builder();
+    List rexLeftNodeMap = new ArrayList<>();
+    List rexRightNodeMap = new ArrayList<>();
+    IMatchNode concatedLeftMatch;
+    IMatchNode newPathPattern;
+    boolean isEdgeReverse = false;
+    switch (graphJoinType) {
+      case EDGE_JOIN_VERTEX:
+      case VERTEX_JOIN_EDGE:
+        if (geaflowTable instanceof VertexTable) { // graph match join vertex table
+          if (matchNode instanceof SingleMatchNode
+              && GQLRelUtil.getLatestMatchNode((SingleMatchNode) matchNode) instanceof EdgeMatch) {
+            concatedLeftMatch =
+                isMatchInLeft
+                    ? concatToMatchNode(
+                        relBuilder, null, leftInput, leftHead, matchNode, rexLeftNodeMap)
+                    : concatToMatchNode(
+                        relBuilder, null, rightInput, rightHead, matchNode, rexRightNodeMap);
+            String nodeName = geaflowTable.getName();
+            PathRecordType pathRecordType =
+                concatedLeftMatch.getPathSchema().getFieldNames().contains(nodeName)
+                    ? concatedLeftMatch.getPathSchema()
+                    : concatedLeftMatch.getPathSchema().addField(nodeName, tableType, false);
+            switch (join.getJoinType()) {
+              case LEFT:
+                newPathPattern =
+                    OptionalVertexMatch.create(
+                        concatedLeftMatch.getCluster(),
+                        (SingleMatchNode) concatedLeftMatch,
+                        nodeName,
+                        Collections.singletonList(nodeName),
+                        tableType,
+                        pathRecordType);
+                break;
+              case INNER:
+                newPathPattern =
+                    VertexMatch.create(
+                        concatedLeftMatch.getCluster(),
+                        (SingleMatchNode) concatedLeftMatch,
+                        nodeName,
+                        Collections.singletonList(nodeName),
+                        tableType,
+                        pathRecordType);
+                break;
+              case RIGHT:
+              case FULL:
+              default:
+                throw new GeaFlowDSLException("Illegal join type: {}", join.getJoinType());
+            }
+            newPathPattern =
+                isMatchInLeft
+                    ? concatToMatchNode(
+                        relBuilder,
+                        concatedLeftMatch,
+                        rightInput,
+                        rightHead,
+                        newPathPattern,
+                        rexRightNodeMap)
+                    : concatToMatchNode(
+                        relBuilder,
+                        concatedLeftMatch,
+                        leftInput,
+                        leftHead,
+                        newPathPattern,
+                        rexLeftNodeMap);
+          } else {
+            String nodeName = geaflowTable.getName();
+            assert currentGraph.getVertexTables().stream()
+                .anyMatch(t -> t.getName().equalsIgnoreCase(geaflowTable.getName()));
+            concatedLeftMatch =
+                isMatchInLeft
+                    ? concatToMatchNode(
+                        relBuilder, null, leftInput, leftHead, matchNode, rexLeftNodeMap)
+                    : concatToMatchNode(
+                        relBuilder, null, rightInput, rightHead, matchNode, rexRightNodeMap);
+            RelDataType vertexRelType = geaflowTable.getRowType(relBuilder.getTypeFactory());
+            PathRecordType rightPathType =
+                PathRecordType.EMPTY.addField(geaflowTable.getName(), vertexRelType, false);
+            VertexMatch rightVertexMatch =
+                VertexMatch.create(
+                    concatedLeftMatch.getCluster(),
+                    null,
+                    nodeName,
+                    Collections.singletonList(geaflowTable.getName()),
+                    vertexRelType,
+                    rightPathType);
+            IMatchNode matchJoinRight =
+                isMatchInLeft
+                    ? concatToMatchNode(
+                        relBuilder, null, rightInput, rightHead, rightVertexMatch, rexRightNodeMap)
+                    : concatToMatchNode(
+                        relBuilder, null, leftInput, leftHead, rightVertexMatch, rexLeftNodeMap);
+            MatchJoin matchJoin =
+                MatchJoin.create(
+                    concatedLeftMatch.getCluster(),
+                    concatedLeftMatch.getTraitSet(),
+                    concatedLeftMatch,
+                    matchJoinRight,
+                    relBuilder.getRexBuilder().makeLiteral(true),
+                    join.getJoinType());
 
-    /**
-     * Replace the references to path nodes in the leftRexNodes with references to
-     * the newGraphMatch, based on the label of the reference. If the label of the reference
-     * exists in the left branch, automatically add an offset to make it a reference to the same
-     * name Node in the right branch.
-     */
-    protected List adjustRightRexNodes(List rightRexNodes,
-                                                GraphMatch newGraphMatch, RelBuilder builder,
-                                                IMatchNode leftPathPattern,
-                                                IMatchNode rightPathPattern) {
-        final IMatchNode finalLeft = leftPathPattern;
-        final IMatchNode finalRight = rightPathPattern;
-        return rightRexNodes.stream().map(prj -> GQLRexUtil.replace(prj, rexNode -> {
-            if (rexNode instanceof RexFieldAccess
-                && ((RexFieldAccess) rexNode).getReferenceExpr() instanceof PathInputRef) {
-                PathInputRef pathInputRef =
-                    (PathInputRef) ((RexFieldAccess) rexNode).getReferenceExpr();
-                String label = pathInputRef.getLabel();
-                boolean isConflictLabel = leftPathPattern.getPathSchema().getFieldNames()
-                    .contains(label);
-                if (isConflictLabel) {
-                    int index = finalRight.getRowType().getField(label, true, false).getIndex();
-                    index += finalLeft.getRowType().getFieldCount();
-                    label = newGraphMatch.getRowType().getFieldList().get(index).getName();
-                    RelDataType type = newGraphMatch.getRowType().getFieldList().get(index)
-                        .getType();
-                    PathInputRef vertexRef = new PathInputRef(label, index, type);
-                    return builder.getRexBuilder().makeFieldAccess(vertexRef,
-                        ((RexFieldAccess) rexNode).getField().getIndex());
-                } else {
-                    int index = newGraphMatch.getRowType().getField(label, true, false).getIndex();
-                    RelDataType type = newGraphMatch.getRowType().getFieldList().get(index)
-                        .getType();
-                    PathInputRef vertexRef = new PathInputRef(label, index, type);
-                    return builder.getRexBuilder().makeFieldAccess(vertexRef,
-                        ((RexFieldAccess) rexNode).getField().getIndex());
+            PathInputRef vertexRef =
+                new PathInputRef(
+                    nodeName,
+                    matchJoin.getRowType().getField(nodeName, true, false).getIndex(),
+                    matchJoin.getRowType().getField(nodeName, true, false).getType());
+            RexNode operand1 =
+                relBuilder.getRexBuilder().makeFieldAccess(vertexRef, VertexType.ID_FIELD_POSITION);
+            RelDataTypeField field =
+                matchJoin
+                    .getRowType()
+                    .getFieldList()
+                    .get(
+                        matchJoin.getRowType().getFieldCount()
+                            - rightVertexMatch.getRowType().getFieldCount());
+            vertexRef = new PathInputRef(field.getName(), field.getIndex(), field.getType());
+            RexNode operand2 =
+                relBuilder.getRexBuilder().makeFieldAccess(vertexRef, VertexType.ID_FIELD_POSITION);
+            SqlOperator equalsOperator = SqlStdOperatorTable.EQUALS;
+            RexNode condition =
+                relBuilder.getRexBuilder().makeCall(equalsOperator, operand1, operand2);
+            newPathPattern =
+                matchJoin.copy(
+                    matchJoin.getTraitSet(),
+                    condition,
+                    matchJoin.getLeft(),
+                    matchJoin.getRight(),
+                    matchJoin.getJoinType());
+          }
+        } else if (geaflowTable instanceof EdgeTable) {
+          if (matchNode instanceof SingleMatchNode
+              && GQLRelUtil.getLatestMatchNode((SingleMatchNode) matchNode)
+                  instanceof VertexMatch) {
+            concatedLeftMatch =
+                isMatchInLeft
+                    ? concatToMatchNode(
+                        relBuilder, null, leftInput, leftHead, matchNode, rexLeftNodeMap)
+                    : concatToMatchNode(
+                        relBuilder, null, rightInput, rightHead, matchNode, rexRightNodeMap);
+            isEdgeReverse = graphJoinType.equals(GraphJoinType.EDGE_JOIN_VERTEX);
+            EdgeDirection edgeDirection = isEdgeReverse ? EdgeDirection.IN : EdgeDirection.OUT;
+            String edgeName = geaflowTable.getName();
+            PathRecordType pathRecordType =
+                concatedLeftMatch.getPathSchema().getFieldNames().contains(edgeName)
+                    ? concatedLeftMatch.getPathSchema()
+                    : concatedLeftMatch.getPathSchema().addField(edgeName, tableType, false);
+            switch (join.getJoinType()) {
+              case LEFT:
+                if (!isMatchInLeft) {
+                  LOGGER.warn("Left table cannot be forcibly retained. Use INNER Join instead.");
                 }
+                newPathPattern =
+                    OptionalEdgeMatch.create(
+                        concatedLeftMatch.getCluster(),
+                        (SingleMatchNode) concatedLeftMatch,
+                        edgeName,
+                        Collections.singletonList(edgeName),
+                        edgeDirection,
+                        tableType,
+                        pathRecordType);
+                break;
+              case INNER:
+                newPathPattern =
+                    EdgeMatch.create(
+                        concatedLeftMatch.getCluster(),
+                        (SingleMatchNode) concatedLeftMatch,
+                        edgeName,
+                        Collections.singletonList(edgeName),
+                        edgeDirection,
+                        tableType,
+                        pathRecordType);
+                break;
+              case RIGHT:
+              case FULL:
+              default:
+                throw new GeaFlowDSLException("Illegal join type: {}", join.getJoinType());
             }
-            return rexNode;
-        })).collect(Collectors.toList());
-    }
+            newPathPattern =
+                isMatchInLeft
+                    ? concatToMatchNode(
+                        relBuilder,
+                        concatedLeftMatch,
+                        rightInput,
+                        rightHead,
+                        newPathPattern,
+                        rexRightNodeMap)
+                    : concatToMatchNode(
+                        relBuilder,
+                        concatedLeftMatch,
+                        leftInput,
+                        leftHead,
+                        newPathPattern,
+                        rexLeftNodeMap);
 
-    /**
-     * Handling the transformation of SQL RelNodes for the MatchJoinTableToGraphMatch
-     * and TableJoinMatchToGraphMatch rules into GQL matches.
-     */
-    protected RelNode processGraphMatchJoinTable(RelOptRuleCall call, LogicalJoin join,
-                                                 LogicalGraphMatch graphMatch,
-                                                 LogicalProject project,
-                                                 LogicalTableScan tableScan,
-                                                 RelNode leftInput, RelNode leftHead,
-                                                 RelNode rightInput, RelNode rightHead,
-                                                 boolean isMatchInLeft) {
-        GeaFlowTable geaflowTable = tableScan.getTable().unwrap(GeaFlowTable.class);
-        GQLJavaTypeFactory typeFactory = (GQLJavaTypeFactory) call.builder().getTypeFactory();
-        GeaFlowGraph currentGraph = typeFactory.getCurrentGraph();
-        if (!currentGraph.containTable(geaflowTable)) {
-            if (geaflowTable instanceof VertexTable || geaflowTable instanceof EdgeTable) {
-                throw new GeaFlowDSLException("Unknown graph element: {}, use graph please.",
-                    geaflowTable.getName());
+          } else {
+            String edgeName = geaflowTable.getName();
+            GraphDescriptor graphDescriptor = currentGraph.getDescriptor();
+            Optional edgeDesc =
+                graphDescriptor.edges.stream()
+                    .filter(e -> e.type.equals(geaflowTable.getName()))
+                    .findFirst();
+            VertexTable dummyVertex = null;
+            if (edgeDesc.isPresent()) {
+              EdgeDescriptor edgeDescriptor = edgeDesc.get();
+              String dummyNodeType = edgeDescriptor.sourceType;
+              dummyVertex =
+                  currentGraph.getVertexTables().stream()
+                      .filter(v -> v.getName().equals(dummyNodeType))
+                      .findFirst()
+                      .get();
             }
-            return null;
-        }
-        GraphJoinType graphJoinType = getJoinType(join);
-        RelDataType tableType = tableScan.getRowType();
-        IMatchNode matchNode = graphMatch.getPathPattern();
-        RelBuilder relBuilder = call.builder();
-        List rexLeftNodeMap = new ArrayList<>();
-        List rexRightNodeMap = new ArrayList<>();
-        IMatchNode concatedLeftMatch;
-        IMatchNode newPathPattern;
-        boolean isEdgeReverse = false;
-        switch (graphJoinType) {
-            case EDGE_JOIN_VERTEX:
-            case VERTEX_JOIN_EDGE:
-                if (geaflowTable instanceof VertexTable) { // graph match join vertex table
-                    if (matchNode instanceof SingleMatchNode && GQLRelUtil.getLatestMatchNode(
-                        (SingleMatchNode) matchNode) instanceof EdgeMatch) {
-                        concatedLeftMatch =
-                            isMatchInLeft ? concatToMatchNode(relBuilder, null, leftInput,
-                                leftHead, matchNode, rexLeftNodeMap)
-                                : concatToMatchNode(relBuilder, null, rightInput,
-                                rightHead, matchNode, rexRightNodeMap);
-                        String nodeName = geaflowTable.getName();
-                        PathRecordType pathRecordType =
-                            concatedLeftMatch.getPathSchema().getFieldNames().contains(nodeName)
-                                ? concatedLeftMatch.getPathSchema()
-                                : concatedLeftMatch.getPathSchema().addField(nodeName, tableType, false);
-                        switch (join.getJoinType()) {
-                            case LEFT:
-                                newPathPattern = OptionalVertexMatch.create(concatedLeftMatch.getCluster(),
-                                    (SingleMatchNode) concatedLeftMatch, nodeName,
-                                    Collections.singletonList(nodeName), tableType, pathRecordType);
-                                break;
-                            case INNER:
-                                newPathPattern = VertexMatch.create(concatedLeftMatch.getCluster(),
-                                    (SingleMatchNode) concatedLeftMatch, nodeName,
-                                    Collections.singletonList(nodeName), tableType, pathRecordType);
-                                break;
-                            case RIGHT:
-                            case FULL:
-                            default:
-                                throw new GeaFlowDSLException("Illegal join type: {}", join.getJoinType());
-                        }
-                        newPathPattern =
-                            isMatchInLeft ? concatToMatchNode(relBuilder, concatedLeftMatch, rightInput,
-                                rightHead, newPathPattern, rexRightNodeMap)
-                                : concatToMatchNode(relBuilder, concatedLeftMatch, leftInput,
-                                leftHead, newPathPattern, rexLeftNodeMap);
-                    } else {
-                        String nodeName = geaflowTable.getName();
-                        assert currentGraph.getVertexTables().stream()
-                            .anyMatch(t -> t.getName().equalsIgnoreCase(geaflowTable.getName()));
-                        concatedLeftMatch =
-                            isMatchInLeft ? concatToMatchNode(relBuilder, null, leftInput, leftHead,
-                                matchNode, rexLeftNodeMap)
-                                : concatToMatchNode(relBuilder, null, rightInput, rightHead,
-                                matchNode, rexRightNodeMap);
-                        RelDataType vertexRelType = geaflowTable.getRowType(
-                            relBuilder.getTypeFactory());
-                        PathRecordType rightPathType = PathRecordType.EMPTY.addField(geaflowTable.getName(),
-                            vertexRelType, false);
-                        VertexMatch rightVertexMatch = VertexMatch.create(concatedLeftMatch.getCluster(), null,
-                            nodeName, Collections.singletonList(geaflowTable.getName()),
-                            vertexRelType, rightPathType);
-                        IMatchNode matchJoinRight =
-                            isMatchInLeft ? concatToMatchNode(relBuilder, null, rightInput,
-                                rightHead, rightVertexMatch, rexRightNodeMap)
-                                : concatToMatchNode(relBuilder, null, leftInput,
-                                leftHead, rightVertexMatch, rexLeftNodeMap);
-                        MatchJoin matchJoin = MatchJoin.create(concatedLeftMatch.getCluster(),
-                            concatedLeftMatch.getTraitSet(), concatedLeftMatch, matchJoinRight,
-                            relBuilder.getRexBuilder().makeLiteral(true), join.getJoinType());
-
-                        PathInputRef vertexRef = new PathInputRef(nodeName,
-                            matchJoin.getRowType().getField(nodeName, true, false).getIndex(),
-                            matchJoin.getRowType().getField(nodeName, true, false).getType());
-                        RexNode operand1 = relBuilder.getRexBuilder()
-                            .makeFieldAccess(vertexRef, VertexType.ID_FIELD_POSITION);
-                        RelDataTypeField field = matchJoin.getRowType().getFieldList().get(
-                            matchJoin.getRowType().getFieldCount() - rightVertexMatch.getRowType()
-                                .getFieldCount());
-                        vertexRef = new PathInputRef(field.getName(), field.getIndex(),
-                            field.getType());
-                        RexNode operand2 = relBuilder.getRexBuilder()
-                            .makeFieldAccess(vertexRef, VertexType.ID_FIELD_POSITION);
-                        SqlOperator equalsOperator = SqlStdOperatorTable.EQUALS;
-                        RexNode condition = relBuilder.getRexBuilder()
-                            .makeCall(equalsOperator, operand1, operand2);
-                        newPathPattern = matchJoin.copy(matchJoin.getTraitSet(), condition,
-                            matchJoin.getLeft(), matchJoin.getRight(), matchJoin.getJoinType());
-                    }
-                } else if (geaflowTable instanceof EdgeTable) {
-                    if (matchNode instanceof SingleMatchNode && GQLRelUtil.getLatestMatchNode(
-                        (SingleMatchNode) matchNode) instanceof VertexMatch) {
-                        concatedLeftMatch =
-                            isMatchInLeft ? concatToMatchNode(relBuilder, null, leftInput, leftHead,
-                                matchNode, rexLeftNodeMap)
-                                : concatToMatchNode(relBuilder, null, rightInput, rightHead,
-                                matchNode, rexRightNodeMap);
-                        isEdgeReverse = graphJoinType.equals(GraphJoinType.EDGE_JOIN_VERTEX);
-                        EdgeDirection edgeDirection =
-                            isEdgeReverse ? EdgeDirection.IN : EdgeDirection.OUT;
-                        String edgeName = geaflowTable.getName();
-                        PathRecordType pathRecordType =
-                            concatedLeftMatch.getPathSchema().getFieldNames().contains(edgeName)
-                                ? concatedLeftMatch.getPathSchema()
-                                : concatedLeftMatch.getPathSchema().addField(edgeName, tableType, false);
-                        switch (join.getJoinType()) {
-                            case LEFT:
-                                if (!isMatchInLeft) {
-                                    LOGGER.warn("Left table cannot be forcibly retained. Use INNER Join instead.");
-                                }
-                                newPathPattern = OptionalEdgeMatch.create(concatedLeftMatch.getCluster(),
-                                    (SingleMatchNode) concatedLeftMatch, edgeName,
-                                    Collections.singletonList(edgeName), edgeDirection, tableType,
-                                    pathRecordType);
-                                break;
-                            case INNER:
-                                newPathPattern = EdgeMatch.create(concatedLeftMatch.getCluster(),
-                                    (SingleMatchNode) concatedLeftMatch, edgeName,
-                                    Collections.singletonList(edgeName), edgeDirection, tableType,
-                                    pathRecordType);
-                                break;
-                            case RIGHT:
-                            case FULL:
-                            default:
-                                throw new GeaFlowDSLException("Illegal join type: {}", join.getJoinType());
-                        }
-                        newPathPattern =
-                            isMatchInLeft ? concatToMatchNode(relBuilder, concatedLeftMatch, rightInput,
-                                rightHead, newPathPattern, rexRightNodeMap)
-                                : concatToMatchNode(relBuilder, concatedLeftMatch, leftInput,
-                                leftHead, newPathPattern, rexLeftNodeMap);
-
-                    } else {
-                        String edgeName = geaflowTable.getName();
-                        GraphDescriptor graphDescriptor = currentGraph.getDescriptor();
-                        Optional edgeDesc = graphDescriptor.edges.stream()
-                            .filter(e -> e.type.equals(geaflowTable.getName())).findFirst();
-                        VertexTable dummyVertex = null;
-                        if (edgeDesc.isPresent()) {
-                            EdgeDescriptor edgeDescriptor = edgeDesc.get();
-                            String dummyNodeType = edgeDescriptor.sourceType;
-                            dummyVertex = currentGraph.getVertexTables().stream()
-                                .filter(v -> v.getName().equals(dummyNodeType)).findFirst().get();
-                        }
-                        if (dummyVertex == null) {
-                            return null;
-                        }
-                        String dummyNodeName = dummyVertex.getName();
-                        RelDataType dummyVertexRelType = dummyVertex.getRowType(
-                            relBuilder.getTypeFactory());
-                        PathRecordType pathRecordType = new PathRecordType(
-                            new ArrayList<>()).addField(dummyNodeName, dummyVertexRelType, true);
-                        VertexMatch dummyVertexMatch = VertexMatch.create(matchNode.getCluster(),
-                            null, dummyNodeName, Collections.singletonList(dummyVertex.getName()),
-                            dummyVertexRelType, pathRecordType);
-                        RelDataType edgeRelType = geaflowTable.getRowType(
-                            relBuilder.getTypeFactory());
-                        pathRecordType = pathRecordType.addField(edgeName, edgeRelType, true);
-                        EdgeMatch edgeMatch = EdgeMatch.create(matchNode.getCluster(),
-                            dummyVertexMatch, edgeName,
-                            Collections.singletonList(geaflowTable.getName()), EdgeDirection.OUT,
-                            edgeRelType, pathRecordType);
+            if (dummyVertex == null) {
+              return null;
+            }
+            String dummyNodeName = dummyVertex.getName();
+            RelDataType dummyVertexRelType = dummyVertex.getRowType(relBuilder.getTypeFactory());
+            PathRecordType pathRecordType =
+                new PathRecordType(new ArrayList<>())
+                    .addField(dummyNodeName, dummyVertexRelType, true);
+            VertexMatch dummyVertexMatch =
+                VertexMatch.create(
+                    matchNode.getCluster(),
+                    null,
+                    dummyNodeName,
+                    Collections.singletonList(dummyVertex.getName()),
+                    dummyVertexRelType,
+                    pathRecordType);
+            RelDataType edgeRelType = geaflowTable.getRowType(relBuilder.getTypeFactory());
+            pathRecordType = pathRecordType.addField(edgeName, edgeRelType, true);
+            EdgeMatch edgeMatch =
+                EdgeMatch.create(
+                    matchNode.getCluster(),
+                    dummyVertexMatch,
+                    edgeName,
+                    Collections.singletonList(geaflowTable.getName()),
+                    EdgeDirection.OUT,
+                    edgeRelType,
+                    pathRecordType);
 
-                        concatedLeftMatch =
-                            isMatchInLeft ? concatToMatchNode(relBuilder, null, leftInput, leftHead,
-                                matchNode, rexLeftNodeMap)
-                                : concatToMatchNode(relBuilder, null, rightInput, rightHead,
-                                matchNode, rexRightNodeMap);
-                        IMatchNode matchJoinRight =
-                            isMatchInLeft ? concatToMatchNode(relBuilder, null, rightInput,
-                                rightHead, edgeMatch, rexRightNodeMap)
-                                : concatToMatchNode(relBuilder, null, leftInput,
-                                leftHead, edgeMatch, rexLeftNodeMap);
-                        MatchJoin matchJoin = MatchJoin.create(matchNode.getCluster(),
-                            matchNode.getTraitSet(), concatedLeftMatch, matchJoinRight,
-                            relBuilder.getRexBuilder().makeLiteral(true), join.getJoinType());
+            concatedLeftMatch =
+                isMatchInLeft
+                    ? concatToMatchNode(
+                        relBuilder, null, leftInput, leftHead, matchNode, rexLeftNodeMap)
+                    : concatToMatchNode(
+                        relBuilder, null, rightInput, rightHead, matchNode, rexRightNodeMap);
+            IMatchNode matchJoinRight =
+                isMatchInLeft
+                    ? concatToMatchNode(
+                        relBuilder, null, rightInput, rightHead, edgeMatch, rexRightNodeMap)
+                    : concatToMatchNode(
+                        relBuilder, null, leftInput, leftHead, edgeMatch, rexLeftNodeMap);
+            MatchJoin matchJoin =
+                MatchJoin.create(
+                    matchNode.getCluster(),
+                    matchNode.getTraitSet(),
+                    concatedLeftMatch,
+                    matchJoinRight,
+                    relBuilder.getRexBuilder().makeLiteral(true),
+                    join.getJoinType());
 
-                        PathInputRef vertexRef = new PathInputRef(dummyNodeName,
-                            matchJoin.getRowType().getField(dummyNodeName, true, false).getIndex(),
-                            matchJoin.getRowType().getField(dummyNodeName, true, false).getType());
-                        RexNode operand1 = relBuilder.getRexBuilder()
-                            .makeFieldAccess(vertexRef, VertexType.ID_FIELD_POSITION);
-                        RelDataTypeField field = matchJoin.getRowType().getFieldList().get(
-                            matchJoin.getRowType().getFieldCount() - edgeMatch.getRowType()
-                                .getFieldCount());
-                        vertexRef = new PathInputRef(field.getName(), field.getIndex(),
-                            field.getType());
-                        RexNode operand2 = relBuilder.getRexBuilder()
-                            .makeFieldAccess(vertexRef, VertexType.ID_FIELD_POSITION);
-                        SqlOperator equalsOperator = SqlStdOperatorTable.EQUALS;
-                        RexNode condition = relBuilder.getRexBuilder()
-                            .makeCall(equalsOperator, operand1, operand2);
-                        newPathPattern = matchJoin.copy(matchJoin.getTraitSet(), condition,
-                            matchJoin.getLeft(), matchJoin.getRight(), matchJoin.getJoinType());
-                    }
-                } else {
-                    return null;
-                }
-                break;
-            default:
-                return null;
-        }
-        if (newPathPattern == null) {
-            return null;
-        }
-        GraphMatch newGraphMatch = graphMatch.copy(newPathPattern);
-        // Add the original Projects from the GraphMatch branch, filtering out fields that
-        // no longer exist after rebuilding GraphMatch inputs.
-        List oldProjects = project.getProjects().stream().filter(prj -> {
-            if (prj instanceof RexFieldAccess
-                && ((RexFieldAccess) prj).getReferenceExpr() instanceof PathInputRef) {
-                PathInputRef pathInputRef =
-                    (PathInputRef) ((RexFieldAccess) prj).getReferenceExpr();
-                String label = pathInputRef.getLabel();
-                RelDataTypeField pathField = newGraphMatch.getRowType().getField(label, true, false);
-                if (pathField != null) {
-                    RexFieldAccess fieldAccess = (RexFieldAccess) prj;
-                    int index = fieldAccess.getField().getIndex();
-                    return index < pathField.getType().getFieldList().size()
-                        && fieldAccess.getField().equals(pathField.getType().getFieldList().get(index));
-                }
-            }
-            return false;
-        }).collect(Collectors.toList());
-        List newProjects = new ArrayList<>();
-        if (newPathPattern instanceof MatchJoin) {
-            newProjects.addAll(
-                isMatchInLeft ? adjustLeftRexNodes(oldProjects, newGraphMatch,
-                    relBuilder) : adjustRightRexNodes(oldProjects, newGraphMatch,
-                    relBuilder, (IMatchNode) newPathPattern.getInput(0),
-                    (IMatchNode) newPathPattern.getInput(1)));
+            PathInputRef vertexRef =
+                new PathInputRef(
+                    dummyNodeName,
+                    matchJoin.getRowType().getField(dummyNodeName, true, false).getIndex(),
+                    matchJoin.getRowType().getField(dummyNodeName, true, false).getType());
+            RexNode operand1 =
+                relBuilder.getRexBuilder().makeFieldAccess(vertexRef, VertexType.ID_FIELD_POSITION);
+            RelDataTypeField field =
+                matchJoin
+                    .getRowType()
+                    .getFieldList()
+                    .get(
+                        matchJoin.getRowType().getFieldCount()
+                            - edgeMatch.getRowType().getFieldCount());
+            vertexRef = new PathInputRef(field.getName(), field.getIndex(), field.getType());
+            RexNode operand2 =
+                relBuilder.getRexBuilder().makeFieldAccess(vertexRef, VertexType.ID_FIELD_POSITION);
+            SqlOperator equalsOperator = SqlStdOperatorTable.EQUALS;
+            RexNode condition =
+                relBuilder.getRexBuilder().makeCall(equalsOperator, operand1, operand2);
+            newPathPattern =
+                matchJoin.copy(
+                    matchJoin.getTraitSet(),
+                    condition,
+                    matchJoin.getLeft(),
+                    matchJoin.getRight(),
+                    matchJoin.getJoinType());
+          }
         } else {
-            newProjects.addAll(adjustLeftRexNodes(oldProjects, newGraphMatch, relBuilder));
+          return null;
         }
+        break;
+      default:
+        return null;
+    }
+    if (newPathPattern == null) {
+      return null;
+    }
+    GraphMatch newGraphMatch = graphMatch.copy(newPathPattern);
+    // Add the original Projects from the GraphMatch branch, filtering out fields that
+    // no longer exist after rebuilding GraphMatch inputs.
+    List oldProjects =
+        project.getProjects().stream()
+            .filter(
+                prj -> {
+                  if (prj instanceof RexFieldAccess
+                      && ((RexFieldAccess) prj).getReferenceExpr() instanceof PathInputRef) {
+                    PathInputRef pathInputRef =
+                        (PathInputRef) ((RexFieldAccess) prj).getReferenceExpr();
+                    String label = pathInputRef.getLabel();
+                    RelDataTypeField pathField =
+                        newGraphMatch.getRowType().getField(label, true, false);
+                    if (pathField != null) {
+                      RexFieldAccess fieldAccess = (RexFieldAccess) prj;
+                      int index = fieldAccess.getField().getIndex();
+                      return index < pathField.getType().getFieldList().size()
+                          && fieldAccess
+                              .getField()
+                              .equals(pathField.getType().getFieldList().get(index));
+                    }
+                  }
+                  return false;
+                })
+            .collect(Collectors.toList());
+    List newProjects = new ArrayList<>();
+    if (newPathPattern instanceof MatchJoin) {
+      newProjects.addAll(
+          isMatchInLeft
+              ? adjustLeftRexNodes(oldProjects, newGraphMatch, relBuilder)
+              : adjustRightRexNodes(
+                  oldProjects,
+                  newGraphMatch,
+                  relBuilder,
+                  (IMatchNode) newPathPattern.getInput(0),
+                  (IMatchNode) newPathPattern.getInput(1)));
+    } else {
+      newProjects.addAll(adjustLeftRexNodes(oldProjects, newGraphMatch, relBuilder));
+    }
 
-        RexBuilder rexBuilder = relBuilder.getRexBuilder();
-        // Add fields of the Table into the projects.
-        String tableName = geaflowTable.getName();
-        RelDataTypeField pathTableField = newPathPattern.getPathSchema()
-            .getField(tableName, true, false);
-        List tableProjects = new ArrayList<>();
-        if (pathTableField != null) {
-            int baseOffset = newProjects.size();
-            PathInputRef pathTableRef = new PathInputRef(tableName, pathTableField.getIndex(),
-                pathTableField.getType());
-            tableProjects = tableType.getFieldList().stream()
-                .map(f -> rexBuilder.makeFieldAccess(pathTableRef, f.getIndex()))
-                .collect(Collectors.toList());
-            newProjects.addAll(tableProjects);
-
-            //In the case of reverse matching in the IN direction, the positions of the source
-            // vertex and the destination vertex are swapped.
-            if (isEdgeReverse) {
-                int edgeSrcIdIndex = tableType.getFieldList().stream().filter(
-                    f -> f.getType() instanceof MetaFieldType
-                            && ((MetaFieldType) f.getType()).getMetaField().equals(MetaField.EDGE_SRC_ID))
-                    .collect(Collectors.toList()).get(0).getIndex();
-                int edgeTargetIdIndex = tableType.getFieldList().stream().filter(
-                    f -> f.getType() instanceof MetaFieldType
-                            && ((MetaFieldType) f.getType()).getMetaField()
-                            .equals(MetaField.EDGE_TARGET_ID)).collect(Collectors.toList()).get(0)
-                    .getIndex();
-                Collections.swap(newProjects, baseOffset + edgeSrcIdIndex,
-                    baseOffset + edgeTargetIdIndex);
-            }
-        }
+    RexBuilder rexBuilder = relBuilder.getRexBuilder();
+    // Add fields of the Table into the projects.
+    String tableName = geaflowTable.getName();
+    RelDataTypeField pathTableField =
+        newPathPattern.getPathSchema().getField(tableName, true, false);
+    List tableProjects = new ArrayList<>();
+    if (pathTableField != null) {
+      int baseOffset = newProjects.size();
+      PathInputRef pathTableRef =
+          new PathInputRef(tableName, pathTableField.getIndex(), pathTableField.getType());
+      tableProjects =
+          tableType.getFieldList().stream()
+              .map(f -> rexBuilder.makeFieldAccess(pathTableRef, f.getIndex()))
+              .collect(Collectors.toList());
+      newProjects.addAll(tableProjects);
 
-        // Add fields newly added in the rebuild of the left branch.
-        if (rexLeftNodeMap.size() > 0) {
-            List tmpLeftProjects = new ArrayList<>(rexLeftNodeMap);
-            newProjects.addAll(adjustLeftRexNodes(tmpLeftProjects, newGraphMatch, relBuilder));
-        }
+      // In the case of reverse matching in the IN direction, the positions of the source
+      // vertex and the destination vertex are swapped.
+      if (isEdgeReverse) {
+        int edgeSrcIdIndex =
+            tableType.getFieldList().stream()
+                .filter(
+                    f ->
+                        f.getType() instanceof MetaFieldType
+                            && ((MetaFieldType) f.getType())
+                                .getMetaField()
+                                .equals(MetaField.EDGE_SRC_ID))
+                .collect(Collectors.toList())
+                .get(0)
+                .getIndex();
+        int edgeTargetIdIndex =
+            tableType.getFieldList().stream()
+                .filter(
+                    f ->
+                        f.getType() instanceof MetaFieldType
+                            && ((MetaFieldType) f.getType())
+                                .getMetaField()
+                                .equals(MetaField.EDGE_TARGET_ID))
+                .collect(Collectors.toList())
+                .get(0)
+                .getIndex();
+        Collections.swap(newProjects, baseOffset + edgeSrcIdIndex, baseOffset + edgeTargetIdIndex);
+      }
+    }
 
-        // Add fields newly added in the rebuild of the right branch.
-        if (rexRightNodeMap.size() > 0) {
-            List tmpRightProjects = new ArrayList<>(rexRightNodeMap);
-            if (newPathPattern instanceof MatchJoin) {
-                newProjects.addAll(adjustRightRexNodes(tmpRightProjects, newGraphMatch, relBuilder,
-                    (IMatchNode) newPathPattern.getInput(0),
-                    (IMatchNode) newPathPattern.getInput(1)));
-            } else {
-                newProjects.addAll(adjustLeftRexNodes(tmpRightProjects, newGraphMatch, relBuilder));
-            }
-        }
+    // Add fields newly added in the rebuild of the left branch.
+    if (rexLeftNodeMap.size() > 0) {
+      List tmpLeftProjects = new ArrayList<>(rexLeftNodeMap);
+      newProjects.addAll(adjustLeftRexNodes(tmpLeftProjects, newGraphMatch, relBuilder));
+    }
 
-        // Complete the projection from Path to Row.
-        List matchTypeFields = new ArrayList<>();
-        List newFieldNames = this.generateFieldNames("f", newProjects.size(), new HashSet<>());
-        for (int i = 0; i < newProjects.size(); i++) {
-            matchTypeFields.add(
-                new RelDataTypeFieldImpl(newFieldNames.get(i), i, newProjects.get(i).getType()));
-        }
-        RelNode tail = LogicalProject.create(newGraphMatch, newProjects,
-            new RelRecordType(matchTypeFields));
+    // Add fields newly added in the rebuild of the right branch.
+    if (rexRightNodeMap.size() > 0) {
+      List tmpRightProjects = new ArrayList<>(rexRightNodeMap);
+      if (newPathPattern instanceof MatchJoin) {
+        newProjects.addAll(
+            adjustRightRexNodes(
+                tmpRightProjects,
+                newGraphMatch,
+                relBuilder,
+                (IMatchNode) newPathPattern.getInput(0),
+                (IMatchNode) newPathPattern.getInput(1)));
+      } else {
+        newProjects.addAll(adjustLeftRexNodes(tmpRightProjects, newGraphMatch, relBuilder));
+      }
+    }
 
-        // Complete the Join projection.
-        if (newPathPattern instanceof MatchJoin) {
-            rexRightNodeMap = adjustRightRexNodes(rexRightNodeMap, newGraphMatch, relBuilder,
-                (IMatchNode) newPathPattern.getInput(0), (IMatchNode) newPathPattern.getInput(1));
-        }
-        List joinProjects = new ArrayList<>();
-        //  If the left branch undergoes rebuilding, take the reconstructed Rex from
-        //  the left branch, otherwise take the original Projects.
-        final RelNode finalTail = tail;
-        int projectFieldCount = oldProjects.size();
-        int joinFieldCount = projectFieldCount + tableProjects.size();
-        if (rexLeftNodeMap.size() > 0) {
-            joinProjects.addAll(
-                IntStream.range(joinFieldCount, joinFieldCount + rexLeftNodeMap.size())
-                    .mapToObj(i -> rexBuilder.makeInputRef(finalTail, i))
-                    .collect(Collectors.toList()));
-        } else {
-            if (isMatchInLeft) {
-                joinProjects.addAll(IntStream.range(0, projectFieldCount)
-                    .mapToObj(i -> rexBuilder.makeInputRef(finalTail, i))
-                    .collect(Collectors.toList()));
-            } else {
-                joinProjects.addAll(IntStream.range(projectFieldCount, joinFieldCount)
-                    .mapToObj(i -> rexBuilder.makeInputRef(finalTail, i))
-                    .collect(Collectors.toList()));
-            }
-        }
-        // If the right branch undergoes rebuilding, take the reconstructed Rex from the right
-        // branch, otherwise take the original Projects.
-        if (rexRightNodeMap.size() > 0) {
-            joinProjects.addAll(IntStream.range(joinFieldCount + rexLeftNodeMap.size(),
-                    joinFieldCount + rexLeftNodeMap.size() + rexRightNodeMap.size())
-                .mapToObj(i -> rexBuilder.makeInputRef(finalTail, i)).collect(Collectors.toList()));
-        } else {
-            if (isMatchInLeft) {
-                joinProjects.addAll(IntStream.range(projectFieldCount, joinFieldCount)
-                    .mapToObj(i -> rexBuilder.makeInputRef(finalTail, i))
-                    .collect(Collectors.toList()));
-            } else {
-                joinProjects.addAll(IntStream.range(0, projectFieldCount)
-                    .mapToObj(i -> rexBuilder.makeInputRef(finalTail, i))
-                    .collect(Collectors.toList()));
-            }
-        }
-        AtomicInteger offset = new AtomicInteger();
-        // Make the project type nullable the same as the output type of the join.
-        joinProjects = joinProjects.stream().map(prj -> {
-            int i = offset.getAndIncrement();
-            boolean joinFieldNullable = join.getRowType().getFieldList().get(i).getType().isNullable();
-            if ((prj.getType().isNullable() && !joinFieldNullable)
-                || (!prj.getType().isNullable() && joinFieldNullable)) {
-                RelDataType type = rexBuilder.getTypeFactory().createTypeWithNullability(prj.getType(), joinFieldNullable);
-                return rexBuilder.makeCast(type, prj);
-            }
-            return prj;
-        }).collect(Collectors.toList());
-        tail = LogicalProject.create(tail, joinProjects, join.getRowType());
-        return tail;
+    // Complete the projection from Path to Row.
+    List matchTypeFields = new ArrayList<>();
+    List newFieldNames = this.generateFieldNames("f", newProjects.size(), new HashSet<>());
+    for (int i = 0; i < newProjects.size(); i++) {
+      matchTypeFields.add(
+          new RelDataTypeFieldImpl(newFieldNames.get(i), i, newProjects.get(i).getType()));
     }
+    RelNode tail =
+        LogicalProject.create(newGraphMatch, newProjects, new RelRecordType(matchTypeFields));
 
-    public enum GraphJoinType {
-        /**
-         * Vertex join edge src id.
-         */
-        VERTEX_JOIN_EDGE,
-        /**
-         * Edge target id join vertex.
-         */
-        EDGE_JOIN_VERTEX,
-        /**
-         * None graph match type.
-         */
-        NONE_GRAPH_JOIN
+    // Complete the Join projection.
+    if (newPathPattern instanceof MatchJoin) {
+      rexRightNodeMap =
+          adjustRightRexNodes(
+              rexRightNodeMap,
+              newGraphMatch,
+              relBuilder,
+              (IMatchNode) newPathPattern.getInput(0),
+              (IMatchNode) newPathPattern.getInput(1));
+    }
+    List joinProjects = new ArrayList<>();
+    //  If the left branch undergoes rebuilding, take the reconstructed Rex from
+    //  the left branch, otherwise take the original Projects.
+    final RelNode finalTail = tail;
+    int projectFieldCount = oldProjects.size();
+    int joinFieldCount = projectFieldCount + tableProjects.size();
+    if (rexLeftNodeMap.size() > 0) {
+      joinProjects.addAll(
+          IntStream.range(joinFieldCount, joinFieldCount + rexLeftNodeMap.size())
+              .mapToObj(i -> rexBuilder.makeInputRef(finalTail, i))
+              .collect(Collectors.toList()));
+    } else {
+      if (isMatchInLeft) {
+        joinProjects.addAll(
+            IntStream.range(0, projectFieldCount)
+                .mapToObj(i -> rexBuilder.makeInputRef(finalTail, i))
+                .collect(Collectors.toList()));
+      } else {
+        joinProjects.addAll(
+            IntStream.range(projectFieldCount, joinFieldCount)
+                .mapToObj(i -> rexBuilder.makeInputRef(finalTail, i))
+                .collect(Collectors.toList()));
+      }
     }
+    // If the right branch undergoes rebuilding, take the reconstructed Rex from the right
+    // branch, otherwise take the original Projects.
+    if (rexRightNodeMap.size() > 0) {
+      joinProjects.addAll(
+          IntStream.range(
+                  joinFieldCount + rexLeftNodeMap.size(),
+                  joinFieldCount + rexLeftNodeMap.size() + rexRightNodeMap.size())
+              .mapToObj(i -> rexBuilder.makeInputRef(finalTail, i))
+              .collect(Collectors.toList()));
+    } else {
+      if (isMatchInLeft) {
+        joinProjects.addAll(
+            IntStream.range(projectFieldCount, joinFieldCount)
+                .mapToObj(i -> rexBuilder.makeInputRef(finalTail, i))
+                .collect(Collectors.toList()));
+      } else {
+        joinProjects.addAll(
+            IntStream.range(0, projectFieldCount)
+                .mapToObj(i -> rexBuilder.makeInputRef(finalTail, i))
+                .collect(Collectors.toList()));
+      }
+    }
+    AtomicInteger offset = new AtomicInteger();
+    // Make the project type nullable the same as the output type of the join.
+    joinProjects =
+        joinProjects.stream()
+            .map(
+                prj -> {
+                  int i = offset.getAndIncrement();
+                  boolean joinFieldNullable =
+                      join.getRowType().getFieldList().get(i).getType().isNullable();
+                  if ((prj.getType().isNullable() && !joinFieldNullable)
+                      || (!prj.getType().isNullable() && joinFieldNullable)) {
+                    RelDataType type =
+                        rexBuilder
+                            .getTypeFactory()
+                            .createTypeWithNullability(prj.getType(), joinFieldNullable);
+                    return rexBuilder.makeCast(type, prj);
+                  }
+                  return prj;
+                })
+            .collect(Collectors.toList());
+    tail = LogicalProject.create(tail, joinProjects, join.getRowType());
+    return tail;
+  }
+
+  public enum GraphJoinType {
+    /** Vertex join edge src id. */
+    VERTEX_JOIN_EDGE,
+    /** Edge target id join vertex. */
+    EDGE_JOIN_VERTEX,
+    /** None graph match type. */
+    NONE_GRAPH_JOIN
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/AddVertexResetRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/AddVertexResetRule.java
index b47ccf0d6..cfcd79ff8 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/AddVertexResetRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/AddVertexResetRule.java
@@ -23,6 +23,7 @@
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
+
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.rel.RelNode;
@@ -40,57 +41,61 @@
 
 public class AddVertexResetRule extends RelOptRule {
 
-    public static final AddVertexResetRule INSTANCE = new AddVertexResetRule();
+  public static final AddVertexResetRule INSTANCE = new AddVertexResetRule();
 
-    private AddVertexResetRule() {
-        super(operand(SingleMatchNode.class,
-            operand(IMatchNode.class, any())));
-    }
+  private AddVertexResetRule() {
+    super(operand(SingleMatchNode.class, operand(IMatchNode.class, any())));
+  }
 
-    @Override
-    public void onMatch(RelOptRuleCall call) {
-        SingleMatchNode matchNode = call.rel(0);
-        IMatchNode inputNode = call.rel(1);
-        // If the match node contains sub-query and the input output type is virtual vertex type.
-        // we should reset the latest vertex to it's task for sub query calling.
-        List lambdaCalls = GQLRexUtil.collect(matchNode,
-            rexNode -> rexNode instanceof RexLambdaCall);
-        if (lambdaCalls.isEmpty()) {
-            return;
-        }
-        if (!(inputNode.getNodeType() instanceof VirtualVertexRecordType)) {
-            return;
-        }
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    SingleMatchNode matchNode = call.rel(0);
+    IMatchNode inputNode = call.rel(1);
+    // If the match node contains sub-query and the input output type is virtual vertex type.
+    // we should reset the latest vertex to it's task for sub query calling.
+    List lambdaCalls =
+        GQLRexUtil.collect(matchNode, rexNode -> rexNode instanceof RexLambdaCall);
+    if (lambdaCalls.isEmpty()) {
+      return;
+    }
+    if (!(inputNode.getNodeType() instanceof VirtualVertexRecordType)) {
+      return;
+    }
 
-        Set startVertices = new HashSet<>();
+    Set startVertices = new HashSet<>();
 
-        for (RexLambdaCall lambdaCall : lambdaCalls) {
-            int startVertex = getStartVertexIndex(lambdaCall);
-            startVertices.add(startVertex);
-        }
-        if (startVertices.size() > 1) {
-            return;
-        }
-        PathRecordType inputPathType = inputNode.getPathSchema();
-        // start vertex should be the last field in the input path type.
-        if (startVertices.iterator().next() != inputPathType.getFieldCount() - 1) {
-            return;
-        }
-        int startVertexIndex = inputPathType.getFieldCount() - 1;
-        RexInputRef startVertexRef = call.builder().getRexBuilder()
-            .makeInputRef(inputPathType.getFieldList().get(startVertexIndex).getType(),
-                startVertexIndex);
+    for (RexLambdaCall lambdaCall : lambdaCalls) {
+      int startVertex = getStartVertexIndex(lambdaCall);
+      startVertices.add(startVertex);
+    }
+    if (startVertices.size() > 1) {
+      return;
+    }
+    PathRecordType inputPathType = inputNode.getPathSchema();
+    // start vertex should be the last field in the input path type.
+    if (startVertices.iterator().next() != inputPathType.getFieldCount() - 1) {
+      return;
+    }
+    int startVertexIndex = inputPathType.getFieldCount() - 1;
+    RexInputRef startVertexRef =
+        call.builder()
+            .getRexBuilder()
+            .makeInputRef(
+                inputPathType.getFieldList().get(startVertexIndex).getType(), startVertexIndex);
 
-        RexNode startVertexIdRef = call.builder().getRexBuilder()
+    RexNode startVertexIdRef =
+        call.builder()
+            .getRexBuilder()
             .makeFieldAccess(startVertexRef, VertexType.ID_FIELD_POSITION);
-        VirtualEdgeMatch virtualEdgeMatch = VirtualEdgeMatch.create(inputNode, startVertexIdRef, inputPathType);
-        RelNode newMatchNode = matchNode.copy(matchNode.getTraitSet(),
-            Collections.singletonList(virtualEdgeMatch));
-        call.transformTo(newMatchNode);
-    }
+    VirtualEdgeMatch virtualEdgeMatch =
+        VirtualEdgeMatch.create(inputNode, startVertexIdRef, inputPathType);
+    RelNode newMatchNode =
+        matchNode.copy(matchNode.getTraitSet(), Collections.singletonList(virtualEdgeMatch));
+    call.transformTo(newMatchNode);
+  }
 
-    private int getStartVertexIndex(RexLambdaCall lambdaCall) {
-        SingleMatchNode matchNode = (SingleMatchNode) lambdaCall.getInput().rel;
-        return GQLRelUtil.getFirstMatchNode(matchNode).getPathSchema().getFieldCount() - 1;
-    }
+  private int getStartVertexIndex(RexLambdaCall lambdaCall) {
+    SingleMatchNode matchNode = (SingleMatchNode) lambdaCall.getInput().rel;
+    return GQLRelUtil.getFirstMatchNode(matchNode).getPathSchema().getFieldCount() - 1;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/FilterMatchNodeTransposeRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/FilterMatchNodeTransposeRule.java
index 8146bfc08..2082c474d 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/FilterMatchNodeTransposeRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/FilterMatchNodeTransposeRule.java
@@ -19,9 +19,9 @@
 
 package org.apache.geaflow.dsl.optimize.rule;
 
-import com.google.common.collect.Lists;
 import java.util.ArrayList;
 import java.util.List;
+
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.plan.RelOptUtil;
@@ -36,76 +36,80 @@
 import org.apache.geaflow.dsl.rex.RexLambdaCall;
 import org.apache.geaflow.dsl.util.GQLRexUtil;
 
+import com.google.common.collect.Lists;
+
 public class FilterMatchNodeTransposeRule extends RelOptRule {
 
-    public static final FilterMatchNodeTransposeRule INSTANCE = new FilterMatchNodeTransposeRule();
+  public static final FilterMatchNodeTransposeRule INSTANCE = new FilterMatchNodeTransposeRule();
 
-    private FilterMatchNodeTransposeRule() {
-        super(operand(MatchFilter.class,
-            operand(IMatchLabel.class, any())));
-    }
+  private FilterMatchNodeTransposeRule() {
+    super(operand(MatchFilter.class, operand(IMatchLabel.class, any())));
+  }
 
-    @Override
-    public void onMatch(RelOptRuleCall call) {
-        MatchFilter filter = call.rel(0);
-        IMatchLabel matchLabel = call.rel(1);
-        List conditions = RelOptUtil.conjunctions(filter.getCondition());
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    MatchFilter filter = call.rel(0);
+    IMatchLabel matchLabel = call.rel(1);
+    List conditions = RelOptUtil.conjunctions(filter.getCondition());
 
-        List pushes = new ArrayList<>();
-        List remains = new ArrayList<>();
-        for (RexNode condition : conditions) {
-            if (canPush(condition, matchLabel)) {
-                pushes.add(condition);
-            } else {
-                remains.add(condition);
-            }
-        }
-        if (pushes.isEmpty()) {
-            return;
-        }
-        RexBuilder builder = call.builder().getRexBuilder();
-        MatchFilter pushFilter = filter.copy(filter.getTraitSet(),
+    List pushes = new ArrayList<>();
+    List remains = new ArrayList<>();
+    for (RexNode condition : conditions) {
+      if (canPush(condition, matchLabel)) {
+        pushes.add(condition);
+      } else {
+        remains.add(condition);
+      }
+    }
+    if (pushes.isEmpty()) {
+      return;
+    }
+    RexBuilder builder = call.builder().getRexBuilder();
+    MatchFilter pushFilter =
+        filter.copy(
+            filter.getTraitSet(),
             matchLabel.getInput(),
             GQLRexUtil.and(pushes, builder),
             (PathRecordType) matchLabel.getInput().getRowType());
-        IMatchNode newMatchNode = (IMatchNode) matchLabel.copy(matchLabel.getTraitSet(),
-            Lists.newArrayList(pushFilter));
+    IMatchNode newMatchNode =
+        (IMatchNode) matchLabel.copy(matchLabel.getTraitSet(), Lists.newArrayList(pushFilter));
 
-        if (remains.isEmpty()) {
-            call.transformTo(newMatchNode);
-        } else {
-            MatchFilter remainFiler = MatchFilter.create(newMatchNode,
-                GQLRexUtil.and(remains, builder),
-                filter.getPathSchema());
-            call.transformTo(remainFiler);
-        }
+    if (remains.isEmpty()) {
+      call.transformTo(newMatchNode);
+    } else {
+      MatchFilter remainFiler =
+          MatchFilter.create(
+              newMatchNode, GQLRexUtil.and(remains, builder), filter.getPathSchema());
+      call.transformTo(remainFiler);
     }
+  }
 
-    private boolean canPush(RexNode condition, IMatchLabel match) {
-        if (GQLRexUtil.contain(condition, RexLambdaCall.class)) {
-            return false;
-        }
-        List fieldAccesses = GQLRexUtil.collect(condition, node -> node instanceof RexFieldAccess);
-        for (RexFieldAccess fieldAccess : fieldAccesses) {
-            if (fieldAccess.getReferenceExpr() instanceof RexInputRef) {
-                RexInputRef pathRef = (RexInputRef) fieldAccess.getReferenceExpr();
-                if (isRefCurrentNode(pathRef, match)) {
-                    return false;
-                }
-            }
-        }
-        List pathRefs = GQLRexUtil.collect(condition, node -> node instanceof RexInputRef);
-        for (RexInputRef pathRef : pathRefs) {
-            if (isRefCurrentNode(pathRef, match)) {
-                return false;
-            }
+  private boolean canPush(RexNode condition, IMatchLabel match) {
+    if (GQLRexUtil.contain(condition, RexLambdaCall.class)) {
+      return false;
+    }
+    List fieldAccesses =
+        GQLRexUtil.collect(condition, node -> node instanceof RexFieldAccess);
+    for (RexFieldAccess fieldAccess : fieldAccesses) {
+      if (fieldAccess.getReferenceExpr() instanceof RexInputRef) {
+        RexInputRef pathRef = (RexInputRef) fieldAccess.getReferenceExpr();
+        if (isRefCurrentNode(pathRef, match)) {
+          return false;
         }
-        return true;
+      }
     }
-
-    private boolean isRefCurrentNode(RexInputRef pathRef, IMatchNode match) {
-        // Test if the condition has referred current match node.
-        // We cannot push the condition down as current match node has been referred by the condition.
-        return pathRef.getIndex() == match.getPathSchema().getFieldCount() - 1;
+    List pathRefs = GQLRexUtil.collect(condition, node -> node instanceof RexInputRef);
+    for (RexInputRef pathRef : pathRefs) {
+      if (isRefCurrentNode(pathRef, match)) {
+        return false;
+      }
     }
+    return true;
+  }
+
+  private boolean isRefCurrentNode(RexInputRef pathRef, IMatchNode match) {
+    // Test if the condition has referred current match node.
+    // We cannot push the condition down as current match node has been referred by the condition.
+    return pathRef.getIndex() == match.getPathSchema().getFieldCount() - 1;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/FilterToMatchRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/FilterToMatchRule.java
index 9d0ca776c..0fafbca8f 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/FilterToMatchRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/FilterToMatchRule.java
@@ -28,23 +28,22 @@
 
 public class FilterToMatchRule extends RelOptRule {
 
-    public static final FilterToMatchRule INSTANCE = new FilterToMatchRule();
+  public static final FilterToMatchRule INSTANCE = new FilterToMatchRule();
 
-    private FilterToMatchRule() {
-        super(operand(LogicalFilter.class,
-            operand(LogicalGraphMatch.class, any())));
-    }
+  private FilterToMatchRule() {
+    super(operand(LogicalFilter.class, operand(LogicalGraphMatch.class, any())));
+  }
 
-    @Override
-    public void onMatch(RelOptRuleCall call) {
-        LogicalFilter filter = call.rel(0);
-        LogicalGraphMatch graphMatch = call.rel(1);
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    LogicalFilter filter = call.rel(0);
+    LogicalGraphMatch graphMatch = call.rel(1);
 
-        IMatchNode pathPattern = graphMatch.getPathPattern();
+    IMatchNode pathPattern = graphMatch.getPathPattern();
 
-        MatchFilter matchFilter = MatchFilter.create(pathPattern,
-            filter.getCondition(), pathPattern.getPathSchema());
-        LogicalGraphMatch newMatch = (LogicalGraphMatch) graphMatch.copy(matchFilter);
-        call.transformTo(newMatch);
-    }
+    MatchFilter matchFilter =
+        MatchFilter.create(pathPattern, filter.getCondition(), pathPattern.getPathSchema());
+    LogicalGraphMatch newMatch = (LogicalGraphMatch) graphMatch.copy(matchFilter);
+    call.transformTo(newMatch);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/GQLAggregateProjectMergeRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/GQLAggregateProjectMergeRule.java
index 7193ebc43..c0c10b4aa 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/GQLAggregateProjectMergeRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/GQLAggregateProjectMergeRule.java
@@ -29,22 +29,23 @@
 
 public class GQLAggregateProjectMergeRule extends AggregateProjectMergeRule {
 
-    public static final GQLAggregateProjectMergeRule INSTANCE =
-        new GQLAggregateProjectMergeRule(Aggregate.class,
-            Project.class, RelFactories.LOGICAL_BUILDER);
+  public static final GQLAggregateProjectMergeRule INSTANCE =
+      new GQLAggregateProjectMergeRule(
+          Aggregate.class, Project.class, RelFactories.LOGICAL_BUILDER);
 
-    public GQLAggregateProjectMergeRule(Class aggregateClass,
-                                        Class projectClass,
-                                        RelBuilderFactory relBuilderFactory) {
-        super(aggregateClass, projectClass, relBuilderFactory);
-    }
+  public GQLAggregateProjectMergeRule(
+      Class aggregateClass,
+      Class projectClass,
+      RelBuilderFactory relBuilderFactory) {
+    super(aggregateClass, projectClass, relBuilderFactory);
+  }
 
-    @Override
-    public boolean matches(RelOptRuleCall call) {
-        Project project = call.rel(1);
-        if (GQLRelUtil.isGQLMatchRelNode(project.getInput())) {
-            return false;
-        }
-        return super.matches(call);
+  @Override
+  public boolean matches(RelOptRuleCall call) {
+    Project project = call.rel(1);
+    if (GQLRelUtil.isGQLMatchRelNode(project.getInput())) {
+      return false;
     }
+    return super.matches(call);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/GQLMatchUnionMergeRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/GQLMatchUnionMergeRule.java
index 0968e6a84..17f09ee66 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/GQLMatchUnionMergeRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/GQLMatchUnionMergeRule.java
@@ -21,6 +21,7 @@
 
 import java.util.ArrayList;
 import java.util.List;
+
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.rel.RelNode;
@@ -28,26 +29,30 @@
 
 public class GQLMatchUnionMergeRule extends RelOptRule {
 
-    public static final GQLMatchUnionMergeRule INSTANCE = new GQLMatchUnionMergeRule();
+  public static final GQLMatchUnionMergeRule INSTANCE = new GQLMatchUnionMergeRule();
 
-    private GQLMatchUnionMergeRule() {
-        super(operand(MatchUnion.class, operand(MatchUnion.class, any())));
-    }
+  private GQLMatchUnionMergeRule() {
+    super(operand(MatchUnion.class, operand(MatchUnion.class, any())));
+  }
 
-    @Override
-    public void onMatch(RelOptRuleCall call) {
-        MatchUnion topMatchUnion = call.rel(0);
-        MatchUnion bottomMatchUnion = call.rel(1);
-        if (topMatchUnion.isDistinct()
-            || !bottomMatchUnion.isDistinct() && !topMatchUnion.isDistinct()) { //distinct
-            List newInputs = new ArrayList<>();
-            newInputs.addAll(bottomMatchUnion.getInputs());
-            for (int i = 1; i < topMatchUnion.getInputs().size(); i++) {
-                newInputs.add(topMatchUnion.getInput(i));
-            }
-            MatchUnion newMatchUnion = MatchUnion.create(topMatchUnion.getCluster(),
-                topMatchUnion.getTraitSet(), newInputs, topMatchUnion.all);
-            call.transformTo(newMatchUnion);
-        }
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    MatchUnion topMatchUnion = call.rel(0);
+    MatchUnion bottomMatchUnion = call.rel(1);
+    if (topMatchUnion.isDistinct()
+        || !bottomMatchUnion.isDistinct() && !topMatchUnion.isDistinct()) { // distinct
+      List newInputs = new ArrayList<>();
+      newInputs.addAll(bottomMatchUnion.getInputs());
+      for (int i = 1; i < topMatchUnion.getInputs().size(); i++) {
+        newInputs.add(topMatchUnion.getInput(i));
+      }
+      MatchUnion newMatchUnion =
+          MatchUnion.create(
+              topMatchUnion.getCluster(),
+              topMatchUnion.getTraitSet(),
+              newInputs,
+              topMatchUnion.all);
+      call.transformTo(newMatchUnion);
     }
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/GQLProjectRemoveRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/GQLProjectRemoveRule.java
index 9d207f3bd..f0a36b4f3 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/GQLProjectRemoveRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/GQLProjectRemoveRule.java
@@ -27,24 +27,24 @@
 import org.apache.geaflow.dsl.util.GQLRelUtil;
 
 public class GQLProjectRemoveRule extends ProjectRemoveRule {
-    public static final GQLProjectRemoveRule INSTANCE =
-        new GQLProjectRemoveRule(RelFactories.LOGICAL_BUILDER);
+  public static final GQLProjectRemoveRule INSTANCE =
+      new GQLProjectRemoveRule(RelFactories.LOGICAL_BUILDER);
 
-    /**
-     * Creates a ProjectRemoveRule.
-     *
-     * @param relBuilderFactory Builder for relational expressions
-     */
-    public GQLProjectRemoveRule(RelBuilderFactory relBuilderFactory) {
-        super(relBuilderFactory);
-    }
+  /**
+   * Creates a ProjectRemoveRule.
+   *
+   * @param relBuilderFactory Builder for relational expressions
+   */
+  public GQLProjectRemoveRule(RelBuilderFactory relBuilderFactory) {
+    super(relBuilderFactory);
+  }
 
-    @Override
-    public boolean matches(RelOptRuleCall call) {
-        Project project = call.rel(0);
-        if (GQLRelUtil.isGQLMatchRelNode(project.getInput())) {
-            return false;
-        }
-        return super.matches(call);
+  @Override
+  public boolean matches(RelOptRuleCall call) {
+    Project project = call.rel(0);
+    if (GQLRelUtil.isGQLMatchRelNode(project.getInput())) {
+      return false;
     }
+    return super.matches(call);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/GraphMatchFieldPruneRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/GraphMatchFieldPruneRule.java
index 005d91beb..27e819894 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/GraphMatchFieldPruneRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/GraphMatchFieldPruneRule.java
@@ -20,6 +20,7 @@
 package org.apache.geaflow.dsl.optimize.rule;
 
 import java.util.*;
+
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.rel.RelNode;
@@ -34,203 +35,194 @@
 import org.apache.geaflow.dsl.rex.RexObjectConstruct;
 
 /**
- * Rule to prune unnecessary fields within GraphMatch operations by analyzing
- * filter conditions, path modifications, joins, and extends.
+ * Rule to prune unnecessary fields within GraphMatch operations by analyzing filter conditions,
+ * path modifications, joins, and extends.
  */
 public class GraphMatchFieldPruneRule extends RelOptRule {
 
-    public static final GraphMatchFieldPruneRule INSTANCE = new GraphMatchFieldPruneRule();
+  public static final GraphMatchFieldPruneRule INSTANCE = new GraphMatchFieldPruneRule();
 
-    private GraphMatchFieldPruneRule() {
-        // Match only a single LogicalGraphMatch node
-        super(operand(LogicalGraphMatch.class, any()), "GraphMatchFieldPruneRule");
-    }
+  private GraphMatchFieldPruneRule() {
+    // Match only a single LogicalGraphMatch node
+    super(operand(LogicalGraphMatch.class, any()), "GraphMatchFieldPruneRule");
+  }
 
-    @Override
-    public void onMatch(RelOptRuleCall call) {
-        LogicalGraphMatch graphMatch = call.rel(0);
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    LogicalGraphMatch graphMatch = call.rel(0);
 
-        // 1. Extract field access information from LogicalGraphMatch
-        Set filteredElements = getFilteredElements(graphMatch);
+    // 1. Extract field access information from LogicalGraphMatch
+    Set filteredElements = getFilteredElements(graphMatch);
 
-        // 2. Pass the filtered field information to the path pattern
-        if (!filteredElements.isEmpty()) {
-            traverseAndPruneFields(filteredElements, graphMatch.getPathPattern());
-        }
+    // 2. Pass the filtered field information to the path pattern
+    if (!filteredElements.isEmpty()) {
+      traverseAndPruneFields(filteredElements, graphMatch.getPathPattern());
     }
+  }
 
-    /**
-     * Extract filtered field elements from GraphMatch.
-     */
-    public Set getFilteredElements(GraphMatch graphMatch) {
-        // Recursively extract field usage from conditions in the match node
-        return extractFromMatchNode(graphMatch.getPathPattern());
-    }
+  /** Extract filtered field elements from GraphMatch. */
+  public Set getFilteredElements(GraphMatch graphMatch) {
+    // Recursively extract field usage from conditions in the match node
+    return extractFromMatchNode(graphMatch.getPathPattern());
+  }
 
-    /**
-     * Recursively traverse the MatchNode to extract RexFieldAccess.
-     */
-    private Set extractFromMatchNode(IMatchNode matchNode) {
-        Set allFilteredFields = new HashSet<>();
+  /** Recursively traverse the MatchNode to extract RexFieldAccess. */
+  private Set extractFromMatchNode(IMatchNode matchNode) {
+    Set allFilteredFields = new HashSet<>();
 
-        if (matchNode == null) {
-            return allFilteredFields;
-        }
+    if (matchNode == null) {
+      return allFilteredFields;
+    }
 
-        // Process expressions in the current node
-        if (matchNode instanceof MatchFilter) {
-            MatchFilter filterNode = (MatchFilter) matchNode;
-            Set rawFields = extractFromRexNode(filterNode.getCondition());
-            allFilteredFields.addAll(convertToPathRefs(rawFields, filterNode));
-
-        } else if (matchNode instanceof MatchPathModify) {
-            MatchPathModify pathModifyNode = (MatchPathModify) matchNode;
-            RexObjectConstruct expression = pathModifyNode.getExpressions().get(0).getObjectConstruct();
-            Set rawFields = extractFromRexNode(expression);
-            allFilteredFields.addAll(convertToPathRefs(rawFields, matchNode));
-
-        } else if (matchNode instanceof MatchJoin) {
-            MatchJoin joinNode = (MatchJoin) matchNode;
-            if (joinNode.getCondition() != null) {
-                Set rawFields = extractFromRexNode(joinNode.getCondition());
-                allFilteredFields.addAll(convertToPathRefs(rawFields, joinNode));
-            }
-
-        } else if (matchNode instanceof MatchExtend) {
-            // For MatchExtend, check CAST attributes
-            MatchExtend extendNode = (MatchExtend) matchNode;
-            for (PathModifyExpression expression : extendNode.getExpressions()) {
-                for (RexNode extendOperands : expression.getObjectConstruct().getOperands()) {
-                    if (extendOperands instanceof RexCall) {
-                        // Only consider non-primitive property projections (CAST)
-                        Set rawFields = extractFromRexNode(extendOperands);
-                        allFilteredFields.addAll(convertToPathRefs(rawFields, extendNode));
-                    }
-                }
-            }
+    // Process expressions in the current node
+    if (matchNode instanceof MatchFilter) {
+      MatchFilter filterNode = (MatchFilter) matchNode;
+      Set rawFields = extractFromRexNode(filterNode.getCondition());
+      allFilteredFields.addAll(convertToPathRefs(rawFields, filterNode));
+
+    } else if (matchNode instanceof MatchPathModify) {
+      MatchPathModify pathModifyNode = (MatchPathModify) matchNode;
+      RexObjectConstruct expression = pathModifyNode.getExpressions().get(0).getObjectConstruct();
+      Set rawFields = extractFromRexNode(expression);
+      allFilteredFields.addAll(convertToPathRefs(rawFields, matchNode));
+
+    } else if (matchNode instanceof MatchJoin) {
+      MatchJoin joinNode = (MatchJoin) matchNode;
+      if (joinNode.getCondition() != null) {
+        Set rawFields = extractFromRexNode(joinNode.getCondition());
+        allFilteredFields.addAll(convertToPathRefs(rawFields, joinNode));
+      }
+
+    } else if (matchNode instanceof MatchExtend) {
+      // For MatchExtend, check CAST attributes
+      MatchExtend extendNode = (MatchExtend) matchNode;
+      for (PathModifyExpression expression : extendNode.getExpressions()) {
+        for (RexNode extendOperands : expression.getObjectConstruct().getOperands()) {
+          if (extendOperands instanceof RexCall) {
+            // Only consider non-primitive property projections (CAST)
+            Set rawFields = extractFromRexNode(extendOperands);
+            allFilteredFields.addAll(convertToPathRefs(rawFields, extendNode));
+          }
         }
+      }
+    }
 
-        // Recursively process all child nodes
-        if (matchNode.getInputs() != null && !matchNode.getInputs().isEmpty()) {
-            for (RelNode input : matchNode.getInputs()) {
-                if (input instanceof IMatchNode) {
-                    // Conversion is handled at leaf nodes, so no need for convertToPathRefs here
-                    allFilteredFields.addAll(extractFromMatchNode((IMatchNode) input));
-                }
-            }
+    // Recursively process all child nodes
+    if (matchNode.getInputs() != null && !matchNode.getInputs().isEmpty()) {
+      for (RelNode input : matchNode.getInputs()) {
+        if (input instanceof IMatchNode) {
+          // Conversion is handled at leaf nodes, so no need for convertToPathRefs here
+          allFilteredFields.addAll(extractFromMatchNode((IMatchNode) input));
         }
-
-        return allFilteredFields;
+      }
     }
 
-    /**
-     * Extract RexFieldAccess from the target RexNode.
-     */
-    private Set extractFromRexNode(RexNode rexNode) {
-        Set fields = new HashSet<>();
-
-        if (rexNode instanceof RexLiteral || rexNode instanceof RexInputRef) {
-            return fields;
-        } else {
-            RexCall rexCall = (RexCall) rexNode;
-            for (RexNode operand : rexCall.getOperands()) {
-                if (operand instanceof RexFieldAccess) {
-                    fields.add((RexFieldAccess) operand);
-                } else if (operand instanceof RexCall) {
-                    // Recursively process nested RexCall
-                    fields.addAll(extractFromRexNode(operand));
-                }
-            }
+    return allFilteredFields;
+  }
+
+  /** Extract RexFieldAccess from the target RexNode. */
+  private Set extractFromRexNode(RexNode rexNode) {
+    Set fields = new HashSet<>();
+
+    if (rexNode instanceof RexLiteral || rexNode instanceof RexInputRef) {
+      return fields;
+    } else {
+      RexCall rexCall = (RexCall) rexNode;
+      for (RexNode operand : rexCall.getOperands()) {
+        if (operand instanceof RexFieldAccess) {
+          fields.add((RexFieldAccess) operand);
+        } else if (operand instanceof RexCall) {
+          // Recursively process nested RexCall
+          fields.addAll(extractFromRexNode(operand));
         }
-        return fields;
+      }
     }
+    return fields;
+  }
+
+  /** Convert index-only field accesses to complete fields with labels. */
+  private static Set convertToPathRefs(
+      Set fieldAccesses, RelNode node) {
+    Set convertedFieldAccesses = new HashSet<>();
+    RelDataType pathRecordType = node.getRowType(); // Get the record type at current level
+    RexBuilder rexBuilder = node.getCluster().getRexBuilder(); // Builder for creating new fields
+
+    for (RexFieldAccess fieldAccess : fieldAccesses) {
+      RexNode referenceExpr = fieldAccess.getReferenceExpr();
+
+      // Only process field accesses of input reference type
+      if (referenceExpr instanceof RexInputRef) {
+        RexInputRef inputRef = (RexInputRef) referenceExpr;
+
+        // If index exceeds field list size, it comes from a subquery, skip it
+        if (pathRecordType.getFieldList().size() <= inputRef.getIndex()) {
+          continue;
+        }
 
-    /**
-     * Convert index-only field accesses to complete fields with labels.
-     */
-    private static Set convertToPathRefs(Set fieldAccesses, RelNode node) {
-        Set convertedFieldAccesses = new HashSet<>();
-        RelDataType pathRecordType = node.getRowType(); // Get the record type at current level
-        RexBuilder rexBuilder = node.getCluster().getRexBuilder(); // Builder for creating new fields
-
-        for (RexFieldAccess fieldAccess : fieldAccesses) {
-            RexNode referenceExpr = fieldAccess.getReferenceExpr();
-
-            // Only process field accesses of input reference type
-            if (referenceExpr instanceof RexInputRef) {
-                RexInputRef inputRef = (RexInputRef) referenceExpr;
-
-                // If index exceeds field list size, it comes from a subquery, skip it
-                if (pathRecordType.getFieldList().size() <= inputRef.getIndex()) {
-                    continue;
-                }
-
-                // Get the corresponding path field information from PathRecordType
-                RelDataTypeField pathField = pathRecordType.getFieldList().get(inputRef.getIndex());
-
-                // Create the actual PathInputRef
-                PathInputRef pathInputRef = new PathInputRef(
-                        pathField.getName(),     // Path variable name (e.g., "a", "b", "c")
-                        pathField.getIndex(),    // Field index
-                        pathField.getType()      // Field type
-                );
+        // Get the corresponding path field information from PathRecordType
+        RelDataTypeField pathField = pathRecordType.getFieldList().get(inputRef.getIndex());
 
-                // Recreate RexFieldAccess with the new path reference
-                RexFieldAccess newFieldAccess = (RexFieldAccess) rexBuilder.makeFieldAccess(
-                        pathInputRef,
-                        fieldAccess.getField().getIndex()
+        // Create the actual PathInputRef
+        PathInputRef pathInputRef =
+            new PathInputRef(
+                pathField.getName(), // Path variable name (e.g., "a", "b", "c")
+                pathField.getIndex(), // Field index
+                pathField.getType() // Field type
                 );
-                convertedFieldAccesses.add(newFieldAccess);
-            }
-        }
 
-        return convertedFieldAccesses;
+        // Recreate RexFieldAccess with the new path reference
+        RexFieldAccess newFieldAccess =
+            (RexFieldAccess)
+                rexBuilder.makeFieldAccess(pathInputRef, fieldAccess.getField().getIndex());
+        convertedFieldAccesses.add(newFieldAccess);
+      }
     }
 
-    /**
-     * Traverse the path pattern and add filtered fields to matching nodes.
-     */
-    private static void traverseAndPruneFields(Set fields, IMatchNode pathPattern) {
-        Queue queue = new LinkedList<>(); // Queue for nodes to visit
-        Set visited = new HashSet<>();    // Mark visited nodes
-
-        queue.offer(pathPattern);
-        visited.add(pathPattern);
-
-        // Visit all nodes in the path, and for each field: if label matches, add the field to .fields
-        while (!queue.isEmpty()) {
-            IMatchNode currentPathPattern = queue.poll();
-
-            if (currentPathPattern instanceof VertexMatch) {
-                VertexMatch vertexMatch = (VertexMatch) currentPathPattern;
-                String vertexLabel = vertexMatch.getLabel();
-                for (RexFieldAccess fieldElement : fields) {
-                    PathInputRef inputRef = (PathInputRef) fieldElement.getReferenceExpr();
-                    if (inputRef.getLabel().equals(vertexLabel)) {
-                        vertexMatch.addField(fieldElement);
-                    }
-                }
-            }
-
-            if (currentPathPattern instanceof EdgeMatch) {
-                EdgeMatch edgeMatch = (EdgeMatch) currentPathPattern;
-                String edgeLabel = edgeMatch.getLabel();
-                for (RexFieldAccess fieldElement : fields) {
-                    PathInputRef inputRef = (PathInputRef) fieldElement.getReferenceExpr();
-                    if (inputRef.getLabel().equals(edgeLabel)) {
-                        edgeMatch.addField(fieldElement);
-                    }
-                }
-            }
-
-            // Iterate through possible child nodes
-            List inputs = currentPathPattern.getInputs();
-            for (RelNode candidateInput : inputs) {
-                if (candidateInput != null && !visited.contains((IMatchNode) candidateInput)) {
-                    queue.offer((IMatchNode) candidateInput);
-                    visited.add((IMatchNode) candidateInput);
-                }
-            }
+    return convertedFieldAccesses;
+  }
+
+  /** Traverse the path pattern and add filtered fields to matching nodes. */
+  private static void traverseAndPruneFields(Set fields, IMatchNode pathPattern) {
+    Queue queue = new LinkedList<>(); // Queue for nodes to visit
+    Set visited = new HashSet<>(); // Mark visited nodes
+
+    queue.offer(pathPattern);
+    visited.add(pathPattern);
+
+    // Visit all nodes in the path, and for each field: if label matches, add the field to .fields
+    while (!queue.isEmpty()) {
+      IMatchNode currentPathPattern = queue.poll();
+
+      if (currentPathPattern instanceof VertexMatch) {
+        VertexMatch vertexMatch = (VertexMatch) currentPathPattern;
+        String vertexLabel = vertexMatch.getLabel();
+        for (RexFieldAccess fieldElement : fields) {
+          PathInputRef inputRef = (PathInputRef) fieldElement.getReferenceExpr();
+          if (inputRef.getLabel().equals(vertexLabel)) {
+            vertexMatch.addField(fieldElement);
+          }
+        }
+      }
+
+      if (currentPathPattern instanceof EdgeMatch) {
+        EdgeMatch edgeMatch = (EdgeMatch) currentPathPattern;
+        String edgeLabel = edgeMatch.getLabel();
+        for (RexFieldAccess fieldElement : fields) {
+          PathInputRef inputRef = (PathInputRef) fieldElement.getReferenceExpr();
+          if (inputRef.getLabel().equals(edgeLabel)) {
+            edgeMatch.addField(fieldElement);
+          }
+        }
+      }
+
+      // Iterate through possible child nodes
+      List inputs = currentPathPattern.getInputs();
+      for (RelNode candidateInput : inputs) {
+        if (candidateInput != null && !visited.contains((IMatchNode) candidateInput)) {
+          queue.offer((IMatchNode) candidateInput);
+          visited.add((IMatchNode) candidateInput);
         }
+      }
     }
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchEdgeLabelFilterRemoveRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchEdgeLabelFilterRemoveRule.java
index a595f37d6..4ebef78d3 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchEdgeLabelFilterRemoveRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchEdgeLabelFilterRemoveRule.java
@@ -22,6 +22,7 @@
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
+
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.rel.type.RelDataTypeField;
@@ -35,80 +36,82 @@
 
 public class MatchEdgeLabelFilterRemoveRule extends RelOptRule {
 
-    public static final MatchEdgeLabelFilterRemoveRule INSTANCE = new MatchEdgeLabelFilterRemoveRule();
-
-    private MatchEdgeLabelFilterRemoveRule() {
-        super(operand(MatchFilter.class,
-            operand(EdgeMatch.class, any())));
-    }
+  public static final MatchEdgeLabelFilterRemoveRule INSTANCE =
+      new MatchEdgeLabelFilterRemoveRule();
 
-    @Override
-    public void onMatch(RelOptRuleCall call) {
-        MatchFilter matchFilter = call.rel(0);
-        EdgeMatch edgeMatch = call.rel(1);
+  private MatchEdgeLabelFilterRemoveRule() {
+    super(operand(MatchFilter.class, operand(EdgeMatch.class, any())));
+  }
 
-        if (!(matchFilter.getCondition() instanceof RexCall)) {
-            return;
-        }
-        Set labelSetInFilter = new HashSet<>();
-        RexCall condition = (RexCall) matchFilter.getCondition();
-        boolean onlyHasLabelFilter = findLabelFilter(labelSetInFilter, condition, edgeMatch);
-        if (!onlyHasLabelFilter) {
-            return;
-        }
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    MatchFilter matchFilter = call.rel(0);
+    EdgeMatch edgeMatch = call.rel(1);
 
-        Set edgeTypes = edgeMatch.getTypes();
-        // If all labels in EdgeMatch are included in MatchFilter, we can remove it.
-        for (String label : edgeTypes) {
-            if (!labelSetInFilter.contains(label)) {
-                return;
-            }
-        }
+    if (!(matchFilter.getCondition() instanceof RexCall)) {
+      return;
+    }
+    Set labelSetInFilter = new HashSet<>();
+    RexCall condition = (RexCall) matchFilter.getCondition();
+    boolean onlyHasLabelFilter = findLabelFilter(labelSetInFilter, condition, edgeMatch);
+    if (!onlyHasLabelFilter) {
+      return;
+    }
 
-        call.transformTo(edgeMatch);
+    Set edgeTypes = edgeMatch.getTypes();
+    // If all labels in EdgeMatch are included in MatchFilter, we can remove it.
+    for (String label : edgeTypes) {
+      if (!labelSetInFilter.contains(label)) {
+        return;
+      }
     }
 
-    private boolean findLabelFilter(Set labelSet, RexCall condition, EdgeMatch edgeMatch) {
-        SqlKind kind = condition.getKind();
-        if (kind == SqlKind.EQUALS) {
-            List operands = condition.getOperands();
-            RexFieldAccess fieldAccess = null;
-            RexLiteral labelLiteral = null;
-            if (operands.get(0) instanceof RexFieldAccess && operands.get(1) instanceof RexLiteral) {
-                fieldAccess = (RexFieldAccess) operands.get(0);
-                labelLiteral = (RexLiteral) operands.get(1);
-            } else if (operands.get(1) instanceof RexFieldAccess && operands.get(0) instanceof RexLiteral) {
-                fieldAccess = (RexFieldAccess) operands.get(1);
-                labelLiteral = (RexLiteral) operands.get(0);
-            } else {
-                return false;
-            }
-            RexNode referenceExpr = fieldAccess.getReferenceExpr();
-            RelDataTypeField field = fieldAccess.getField();
-            boolean isRefInputEdge = (referenceExpr instanceof PathInputRef
-                && ((PathInputRef) referenceExpr).getLabel().equals(edgeMatch.getLabel()))
-                || referenceExpr instanceof RexInputRef;
-            if (isRefInputEdge
-                && field.getType() instanceof MetaFieldType
-                && ((MetaFieldType) field.getType()).getMetaField() == MetaField.EDGE_TYPE) {
-                labelSet.add(RexLiteral.stringValue(labelLiteral));
-                return true;
-            }
-            return false;
-        } else if (kind == SqlKind.OR) {
-            boolean onlyHasIdFilter = true;
-            List operands = condition.getOperands();
-            for (RexNode operand : operands) {
-                if (operand instanceof RexCall) {
-                    onlyHasIdFilter = onlyHasIdFilter && findLabelFilter(labelSet, (RexCall) operand,
-                        edgeMatch);
-                } else {
-                    // Has other filter
-                    return false;
-                }
-            }
-            return onlyHasIdFilter;
-        }
+    call.transformTo(edgeMatch);
+  }
+
+  private boolean findLabelFilter(Set labelSet, RexCall condition, EdgeMatch edgeMatch) {
+    SqlKind kind = condition.getKind();
+    if (kind == SqlKind.EQUALS) {
+      List operands = condition.getOperands();
+      RexFieldAccess fieldAccess = null;
+      RexLiteral labelLiteral = null;
+      if (operands.get(0) instanceof RexFieldAccess && operands.get(1) instanceof RexLiteral) {
+        fieldAccess = (RexFieldAccess) operands.get(0);
+        labelLiteral = (RexLiteral) operands.get(1);
+      } else if (operands.get(1) instanceof RexFieldAccess
+          && operands.get(0) instanceof RexLiteral) {
+        fieldAccess = (RexFieldAccess) operands.get(1);
+        labelLiteral = (RexLiteral) operands.get(0);
+      } else {
         return false;
+      }
+      RexNode referenceExpr = fieldAccess.getReferenceExpr();
+      RelDataTypeField field = fieldAccess.getField();
+      boolean isRefInputEdge =
+          (referenceExpr instanceof PathInputRef
+                  && ((PathInputRef) referenceExpr).getLabel().equals(edgeMatch.getLabel()))
+              || referenceExpr instanceof RexInputRef;
+      if (isRefInputEdge
+          && field.getType() instanceof MetaFieldType
+          && ((MetaFieldType) field.getType()).getMetaField() == MetaField.EDGE_TYPE) {
+        labelSet.add(RexLiteral.stringValue(labelLiteral));
+        return true;
+      }
+      return false;
+    } else if (kind == SqlKind.OR) {
+      boolean onlyHasIdFilter = true;
+      List operands = condition.getOperands();
+      for (RexNode operand : operands) {
+        if (operand instanceof RexCall) {
+          onlyHasIdFilter =
+              onlyHasIdFilter && findLabelFilter(labelSet, (RexCall) operand, edgeMatch);
+        } else {
+          // Has other filter
+          return false;
+        }
+      }
+      return onlyHasIdFilter;
     }
+    return false;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchFilterMergeRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchFilterMergeRule.java
index 661524941..00a2ca924 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchFilterMergeRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchFilterMergeRule.java
@@ -19,33 +19,34 @@
 
 package org.apache.geaflow.dsl.optimize.rule;
 
-import com.google.common.collect.Lists;
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.rex.RexNode;
 import org.apache.geaflow.dsl.rel.match.MatchFilter;
 import org.apache.geaflow.dsl.util.GQLRexUtil;
 
+import com.google.common.collect.Lists;
+
 public class MatchFilterMergeRule extends RelOptRule {
 
-    public static final MatchFilterMergeRule INSTANCE = new MatchFilterMergeRule();
+  public static final MatchFilterMergeRule INSTANCE = new MatchFilterMergeRule();
 
-    private MatchFilterMergeRule() {
-        super(operand(MatchFilter.class,
-            operand(MatchFilter.class, any())));
-    }
+  private MatchFilterMergeRule() {
+    super(operand(MatchFilter.class, operand(MatchFilter.class, any())));
+  }
 
-    @Override
-    public void onMatch(RelOptRuleCall call) {
-        MatchFilter topFilter = call.rel(0);
-        MatchFilter bottomFilter = call.rel(1);
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    MatchFilter topFilter = call.rel(0);
+    MatchFilter bottomFilter = call.rel(1);
 
-        RexNode mergedCondition = GQLRexUtil.and(
+    RexNode mergedCondition =
+        GQLRexUtil.and(
             Lists.newArrayList(topFilter.getCondition(), bottomFilter.getCondition()),
             call.builder().getRexBuilder());
 
-        MatchFilter mergedFilter = MatchFilter.create(bottomFilter.getInput(),
-            mergedCondition, bottomFilter.getPathSchema());
-        call.transformTo(mergedFilter);
-    }
+    MatchFilter mergedFilter =
+        MatchFilter.create(bottomFilter.getInput(), mergedCondition, bottomFilter.getPathSchema());
+    call.transformTo(mergedFilter);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchIdFilterSimplifyRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchIdFilterSimplifyRule.java
index 09c24e975..b853d8860 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchIdFilterSimplifyRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchIdFilterSimplifyRule.java
@@ -22,6 +22,7 @@
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
+
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.rel.type.RelDataType;
@@ -39,76 +40,76 @@
 
 public class MatchIdFilterSimplifyRule extends RelOptRule {
 
-    public static final MatchIdFilterSimplifyRule INSTANCE = new MatchIdFilterSimplifyRule();
+  public static final MatchIdFilterSimplifyRule INSTANCE = new MatchIdFilterSimplifyRule();
 
-    private MatchIdFilterSimplifyRule() {
-        super(operand(MatchFilter.class,
-            operand(VertexMatch.class, any())));
-    }
+  private MatchIdFilterSimplifyRule() {
+    super(operand(MatchFilter.class, operand(VertexMatch.class, any())));
+  }
 
-    @Override
-    public void onMatch(RelOptRuleCall call) {
-        MatchFilter matchFilter = call.rel(0);
-        VertexMatch vertexMatch = call.rel(1);
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    MatchFilter matchFilter = call.rel(0);
+    VertexMatch vertexMatch = call.rel(1);
 
-        if (!(matchFilter.getCondition() instanceof RexCall)) {
-            return;
-        }
-        RexCall condition = (RexCall) matchFilter.getCondition();
-        Set idSet = new HashSet<>();
-        boolean onLyHasIdFilter = findIdFilter(idSet, condition, vertexMatch);
-
-        if (!onLyHasIdFilter) {
-            return;
-        }
+    if (!(matchFilter.getCondition() instanceof RexCall)) {
+      return;
+    }
+    RexCall condition = (RexCall) matchFilter.getCondition();
+    Set idSet = new HashSet<>();
+    boolean onLyHasIdFilter = findIdFilter(idSet, condition, vertexMatch);
 
-        VertexMatch newVertexMatch = vertexMatch.copy(idSet);
-        call.transformTo(newVertexMatch);
+    if (!onLyHasIdFilter) {
+      return;
     }
 
-    private boolean findIdFilter(Set idSet, RexCall condition, VertexMatch vertexMatch) {
-        SqlKind kind = condition.getKind();
-        if (kind == SqlKind.EQUALS) {
-            List operands = condition.getOperands();
-            RexFieldAccess fieldAccess = null;
-            RexLiteral idLiteral = null;
-            if (operands.get(0) instanceof RexFieldAccess && operands.get(1) instanceof RexLiteral) {
-                fieldAccess = (RexFieldAccess) operands.get(0);
-                idLiteral = (RexLiteral) operands.get(1);
-            } else if (operands.get(1) instanceof RexFieldAccess && operands.get(0) instanceof RexLiteral) {
-                fieldAccess = (RexFieldAccess) operands.get(1);
-                idLiteral = (RexLiteral) operands.get(0);
-            } else {
-                return false;
-            }
-            RexNode referenceExpr = fieldAccess.getReferenceExpr();
-            RelDataTypeField field = fieldAccess.getField();
-            boolean isRefInputVertex = (referenceExpr instanceof PathInputRef
-                && ((PathInputRef) referenceExpr).getLabel().equals(vertexMatch.getLabel()))
-                || referenceExpr instanceof RexInputRef;
-            if (isRefInputVertex
-                && field.getType() instanceof MetaFieldType
-                && ((MetaFieldType) field.getType()).getMetaField() == MetaField.VERTEX_ID) {
-                RelDataType dataType = ((MetaFieldType) field.getType()).getType();
-                IType idType = SqlTypeUtil.convertType(dataType);
-                idSet.add(TypeCastUtil.cast(idLiteral.getValue(), idType));
-                return true;
-            }
-            return false;
-        } else if (kind == SqlKind.OR) {
-            boolean onlyHasIdFilter = true;
-            List operands = condition.getOperands();
-            for (RexNode operand : operands) {
-                if (operand instanceof RexCall) {
-                    onlyHasIdFilter = onlyHasIdFilter && findIdFilter(idSet, (RexCall) operand,
-                        vertexMatch);
-                } else {
-                    // Has other filter
-                    return false;
-                }
-            }
-            return onlyHasIdFilter;
-        }
+    VertexMatch newVertexMatch = vertexMatch.copy(idSet);
+    call.transformTo(newVertexMatch);
+  }
+
+  private boolean findIdFilter(Set idSet, RexCall condition, VertexMatch vertexMatch) {
+    SqlKind kind = condition.getKind();
+    if (kind == SqlKind.EQUALS) {
+      List operands = condition.getOperands();
+      RexFieldAccess fieldAccess = null;
+      RexLiteral idLiteral = null;
+      if (operands.get(0) instanceof RexFieldAccess && operands.get(1) instanceof RexLiteral) {
+        fieldAccess = (RexFieldAccess) operands.get(0);
+        idLiteral = (RexLiteral) operands.get(1);
+      } else if (operands.get(1) instanceof RexFieldAccess
+          && operands.get(0) instanceof RexLiteral) {
+        fieldAccess = (RexFieldAccess) operands.get(1);
+        idLiteral = (RexLiteral) operands.get(0);
+      } else {
         return false;
+      }
+      RexNode referenceExpr = fieldAccess.getReferenceExpr();
+      RelDataTypeField field = fieldAccess.getField();
+      boolean isRefInputVertex =
+          (referenceExpr instanceof PathInputRef
+                  && ((PathInputRef) referenceExpr).getLabel().equals(vertexMatch.getLabel()))
+              || referenceExpr instanceof RexInputRef;
+      if (isRefInputVertex
+          && field.getType() instanceof MetaFieldType
+          && ((MetaFieldType) field.getType()).getMetaField() == MetaField.VERTEX_ID) {
+        RelDataType dataType = ((MetaFieldType) field.getType()).getType();
+        IType idType = SqlTypeUtil.convertType(dataType);
+        idSet.add(TypeCastUtil.cast(idLiteral.getValue(), idType));
+        return true;
+      }
+      return false;
+    } else if (kind == SqlKind.OR) {
+      boolean onlyHasIdFilter = true;
+      List operands = condition.getOperands();
+      for (RexNode operand : operands) {
+        if (operand instanceof RexCall) {
+          onlyHasIdFilter = onlyHasIdFilter && findIdFilter(idSet, (RexCall) operand, vertexMatch);
+        } else {
+          // Has other filter
+          return false;
+        }
+      }
+      return onlyHasIdFilter;
     }
+    return false;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchJoinMatchMergeRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchJoinMatchMergeRule.java
index 2ad8d8e91..6af1ff5d4 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchJoinMatchMergeRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchJoinMatchMergeRule.java
@@ -26,6 +26,7 @@
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
+
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.JoinInfo;
@@ -47,127 +48,160 @@
 
 public class MatchJoinMatchMergeRule extends AbstractJoinToGraphRule {
 
-    public static final MatchJoinMatchMergeRule INSTANCE = new MatchJoinMatchMergeRule();
+  public static final MatchJoinMatchMergeRule INSTANCE = new MatchJoinMatchMergeRule();
 
-    private MatchJoinMatchMergeRule() {
-        super(operand(LogicalJoin.class,
-            operand(RelNode.class, any()),
-            operand(RelNode.class, any())));
-    }
+  private MatchJoinMatchMergeRule() {
+    super(operand(LogicalJoin.class, operand(RelNode.class, any()), operand(RelNode.class, any())));
+  }
 
-    @Override
-    public boolean matches(RelOptRuleCall call) {
-        LogicalJoin join = call.rel(0);
-        if (!isSupportJoinType(join.getJoinType())) {
-            // non-INNER joins is not supported.
-            return false;
-        }
-        RelNode leftInput = call.rel(1);
-        RelNode rightInput = call.rel(2);
-        return isSingleChainFromGraphMatch(leftInput)
-            && isSingleChainFromGraphMatch(rightInput);
+  @Override
+  public boolean matches(RelOptRuleCall call) {
+    LogicalJoin join = call.rel(0);
+    if (!isSupportJoinType(join.getJoinType())) {
+      // non-INNER joins is not supported.
+      return false;
     }
+    RelNode leftInput = call.rel(1);
+    RelNode rightInput = call.rel(2);
+    return isSingleChainFromGraphMatch(leftInput) && isSingleChainFromGraphMatch(rightInput);
+  }
 
-    @Override
-    public void onMatch(RelOptRuleCall call) {
-        RelNode leftInput = call.rel(1);
-        RelNode leftGraphMatchProject = null;
-        RelNode leftGraphMatch = leftInput;
-        while (leftGraphMatch != null && !(leftGraphMatch instanceof LogicalGraphMatch)) {
-            leftGraphMatchProject = leftGraphMatch;
-            leftGraphMatch = GQLRelUtil.toRel(leftGraphMatch.getInput(0));
-        }
-        RelNode rightInput = call.rel(2);
-        RelNode rightGraphMatchProject = null;
-        RelNode rightGraphMatch = rightInput;
-        while (rightGraphMatch != null && !(rightGraphMatch instanceof LogicalGraphMatch)) {
-            rightGraphMatchProject = rightGraphMatch;
-            rightGraphMatch = GQLRelUtil.toRel(rightGraphMatch.getInput(0));
-        }
-        assert leftGraphMatch != null && rightGraphMatch != null;
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    RelNode leftInput = call.rel(1);
+    RelNode leftGraphMatchProject = null;
+    RelNode leftGraphMatch = leftInput;
+    while (leftGraphMatch != null && !(leftGraphMatch instanceof LogicalGraphMatch)) {
+      leftGraphMatchProject = leftGraphMatch;
+      leftGraphMatch = GQLRelUtil.toRel(leftGraphMatch.getInput(0));
+    }
+    RelNode rightInput = call.rel(2);
+    RelNode rightGraphMatchProject = null;
+    RelNode rightGraphMatch = rightInput;
+    while (rightGraphMatch != null && !(rightGraphMatch instanceof LogicalGraphMatch)) {
+      rightGraphMatchProject = rightGraphMatch;
+      rightGraphMatch = GQLRelUtil.toRel(rightGraphMatch.getInput(0));
+    }
+    assert leftGraphMatch != null && rightGraphMatch != null;
 
-        RelBuilder relBuilder = call.builder();
-        RexBuilder rexBuilder = relBuilder.getRexBuilder();
-        List rexLeftNodeMap = new ArrayList<>();
-        List rexRightNodeMap = new ArrayList<>();
-        IMatchNode leftPathPattern = ((GraphMatch) leftGraphMatch).getPathPattern();
-        leftPathPattern = concatToMatchNode(relBuilder, null, leftInput, leftGraphMatch,
-            leftPathPattern, rexLeftNodeMap);
-        IMatchNode rightPathPattern = ((GraphMatch) rightGraphMatch).getPathPattern();
-        rightPathPattern = concatToMatchNode(relBuilder, null, rightInput, rightGraphMatch,
-            rightPathPattern, rexRightNodeMap);
-        if (leftPathPattern == null || rightPathPattern == null) {
-            return;
-        }
-        LogicalJoin join = call.rel(0);
-        MatchJoin newPathPattern = MatchJoin.create(join.getCluster(), join.getTraitSet(),
-            leftPathPattern, rightPathPattern, join.getCondition(), join.getJoinType());
-        GraphMatch newGraphMatch = ((GraphMatch) leftGraphMatch).copy(newPathPattern);
+    RelBuilder relBuilder = call.builder();
+    RexBuilder rexBuilder = relBuilder.getRexBuilder();
+    List rexLeftNodeMap = new ArrayList<>();
+    List rexRightNodeMap = new ArrayList<>();
+    IMatchNode leftPathPattern = ((GraphMatch) leftGraphMatch).getPathPattern();
+    leftPathPattern =
+        concatToMatchNode(
+            relBuilder, null, leftInput, leftGraphMatch, leftPathPattern, rexLeftNodeMap);
+    IMatchNode rightPathPattern = ((GraphMatch) rightGraphMatch).getPathPattern();
+    rightPathPattern =
+        concatToMatchNode(
+            relBuilder, null, rightInput, rightGraphMatch, rightPathPattern, rexRightNodeMap);
+    if (leftPathPattern == null || rightPathPattern == null) {
+      return;
+    }
+    LogicalJoin join = call.rel(0);
+    MatchJoin newPathPattern =
+        MatchJoin.create(
+            join.getCluster(),
+            join.getTraitSet(),
+            leftPathPattern,
+            rightPathPattern,
+            join.getCondition(),
+            join.getJoinType());
+    GraphMatch newGraphMatch = ((GraphMatch) leftGraphMatch).copy(newPathPattern);
 
-        List newProjects = new ArrayList<>();
-        if (rexLeftNodeMap.size() > 0) {
-            newProjects.addAll(rexLeftNodeMap);
-        } else {
-            assert leftGraphMatchProject != null;
-            newProjects.addAll(adjustLeftRexNodes(
-                ((LogicalProject) leftGraphMatchProject).getProjects(), newGraphMatch, relBuilder));
-        }
-        if (rexRightNodeMap.size() > 0) {
-            newProjects.addAll(adjustRightRexNodes(rexRightNodeMap, newGraphMatch, relBuilder,
-                leftPathPattern, rightPathPattern));
-        } else {
-            assert rightGraphMatchProject != null;
-            newProjects.addAll(adjustRightRexNodes(
-                ((LogicalProject) rightGraphMatchProject).getProjects(), newGraphMatch, relBuilder,
-                leftPathPattern, rightPathPattern));
-        }
+    List newProjects = new ArrayList<>();
+    if (rexLeftNodeMap.size() > 0) {
+      newProjects.addAll(rexLeftNodeMap);
+    } else {
+      assert leftGraphMatchProject != null;
+      newProjects.addAll(
+          adjustLeftRexNodes(
+              ((LogicalProject) leftGraphMatchProject).getProjects(), newGraphMatch, relBuilder));
+    }
+    if (rexRightNodeMap.size() > 0) {
+      newProjects.addAll(
+          adjustRightRexNodes(
+              rexRightNodeMap, newGraphMatch, relBuilder, leftPathPattern, rightPathPattern));
+    } else {
+      assert rightGraphMatchProject != null;
+      newProjects.addAll(
+          adjustRightRexNodes(
+              ((LogicalProject) rightGraphMatchProject).getProjects(),
+              newGraphMatch,
+              relBuilder,
+              leftPathPattern,
+              rightPathPattern));
+    }
 
-        JoinInfo joinInfo = join.analyzeCondition();
-        List joinConditions = new ArrayList<>();
-        if (newGraphMatch.getPathPattern() instanceof MatchJoin) {
-            MatchJoin matchJoin = (MatchJoin) newGraphMatch.getPathPattern();
-            for (int i = 0; i < joinInfo.leftKeys.size(); i++) {
-                int left = joinInfo.leftKeys.get(i);
-                int right = joinInfo.rightKeys.get(i);
-                RexNode leftNode = rexLeftNodeMap.get(left);
-                RexNode rightNode = rexRightNodeMap.get(right);
-                rightNode = adjustRightRexNodes(Collections.singletonList(rightNode), newGraphMatch,
-                    relBuilder, leftPathPattern, rightPathPattern).get(0);
-                SqlOperator equalsOperator = SqlStdOperatorTable.EQUALS;
-                RexNode condition = relBuilder.getRexBuilder().makeCall(equalsOperator,
-                    leftNode, rightNode);
-                joinConditions.add(condition);
-            }
-            RexNode newCondition = RexUtil.composeConjunction(rexBuilder, joinConditions);
-            newGraphMatch = newGraphMatch.copy(matchJoin.copy(matchJoin.getTraitSet(),
-                newCondition, matchJoin.getLeft(), matchJoin.getRight(), matchJoin.getJoinType()));
-        }
+    JoinInfo joinInfo = join.analyzeCondition();
+    List joinConditions = new ArrayList<>();
+    if (newGraphMatch.getPathPattern() instanceof MatchJoin) {
+      MatchJoin matchJoin = (MatchJoin) newGraphMatch.getPathPattern();
+      for (int i = 0; i < joinInfo.leftKeys.size(); i++) {
+        int left = joinInfo.leftKeys.get(i);
+        int right = joinInfo.rightKeys.get(i);
+        RexNode leftNode = rexLeftNodeMap.get(left);
+        RexNode rightNode = rexRightNodeMap.get(right);
+        rightNode =
+            adjustRightRexNodes(
+                    Collections.singletonList(rightNode),
+                    newGraphMatch,
+                    relBuilder,
+                    leftPathPattern,
+                    rightPathPattern)
+                .get(0);
+        SqlOperator equalsOperator = SqlStdOperatorTable.EQUALS;
+        RexNode condition =
+            relBuilder.getRexBuilder().makeCall(equalsOperator, leftNode, rightNode);
+        joinConditions.add(condition);
+      }
+      RexNode newCondition = RexUtil.composeConjunction(rexBuilder, joinConditions);
+      newGraphMatch =
+          newGraphMatch.copy(
+              matchJoin.copy(
+                  matchJoin.getTraitSet(),
+                  newCondition,
+                  matchJoin.getLeft(),
+                  matchJoin.getRight(),
+                  matchJoin.getJoinType()));
+    }
 
-        List fieldNames = this.generateFieldNames("f", newProjects.size(), new HashSet<>());
-        RelNode tail = LogicalProject.create(newGraphMatch, newProjects, fieldNames);
+    List fieldNames = this.generateFieldNames("f", newProjects.size(), new HashSet<>());
+    RelNode tail = LogicalProject.create(newGraphMatch, newProjects, fieldNames);
 
-        // Complete the Join projection.
-        final RelNode finalTail = tail;
-        List joinProjects = IntStream.range(0, newProjects.size())
-            .mapToObj(i -> rexBuilder.makeInputRef(finalTail, i)).collect(Collectors.toList());
-        AtomicInteger offset = new AtomicInteger();
-        // Make the project type nullable the same as the output type of the join.
-        joinProjects = joinProjects.stream().map(prj -> {
-            int i = offset.getAndIncrement();
-            boolean joinFieldNullable = join.getRowType().getFieldList().get(i).getType().isNullable();
-            if ((prj.getType().isNullable() && !joinFieldNullable)
-                || (!prj.getType().isNullable() && joinFieldNullable)) {
-                RelDataType type = rexBuilder.getTypeFactory().createTypeWithNullability(prj.getType(), joinFieldNullable);
-                return rexBuilder.makeCast(type, prj);
-            }
-            return prj;
-        }).collect(Collectors.toList());
-        tail = LogicalProject.create(tail, joinProjects, join.getRowType());
-        // Add remain filter.
-        RexNode remainFilter = joinInfo.getRemaining(join.getCluster().getRexBuilder());
-        if (remainFilter != null && !remainFilter.isAlwaysTrue()) {
-            tail = LogicalFilter.create(tail, remainFilter);
-        }
-        call.transformTo(tail);
+    // Complete the Join projection.
+    final RelNode finalTail = tail;
+    List joinProjects =
+        IntStream.range(0, newProjects.size())
+            .mapToObj(i -> rexBuilder.makeInputRef(finalTail, i))
+            .collect(Collectors.toList());
+    AtomicInteger offset = new AtomicInteger();
+    // Make the project type nullable the same as the output type of the join.
+    joinProjects =
+        joinProjects.stream()
+            .map(
+                prj -> {
+                  int i = offset.getAndIncrement();
+                  boolean joinFieldNullable =
+                      join.getRowType().getFieldList().get(i).getType().isNullable();
+                  if ((prj.getType().isNullable() && !joinFieldNullable)
+                      || (!prj.getType().isNullable() && joinFieldNullable)) {
+                    RelDataType type =
+                        rexBuilder
+                            .getTypeFactory()
+                            .createTypeWithNullability(prj.getType(), joinFieldNullable);
+                    return rexBuilder.makeCast(type, prj);
+                  }
+                  return prj;
+                })
+            .collect(Collectors.toList());
+    tail = LogicalProject.create(tail, joinProjects, join.getRowType());
+    // Add remain filter.
+    RexNode remainFilter = joinInfo.getRemaining(join.getCluster().getRexBuilder());
+    if (remainFilter != null && !remainFilter.isAlwaysTrue()) {
+      tail = LogicalFilter.create(tail, remainFilter);
     }
+    call.transformTo(tail);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchJoinTableToGraphMatchRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchJoinTableToGraphMatchRule.java
index 2b527da07..65ace088a 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchJoinTableToGraphMatchRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchJoinTableToGraphMatchRule.java
@@ -29,54 +29,54 @@
 
 public class MatchJoinTableToGraphMatchRule extends AbstractJoinToGraphRule {
 
-    public static final MatchJoinTableToGraphMatchRule INSTANCE = new MatchJoinTableToGraphMatchRule();
+  public static final MatchJoinTableToGraphMatchRule INSTANCE =
+      new MatchJoinTableToGraphMatchRule();
 
-    private MatchJoinTableToGraphMatchRule() {
-        super(operand(LogicalJoin.class,
-            operand(RelNode.class, any()),
-            operand(RelNode.class, any())));
-    }
+  private MatchJoinTableToGraphMatchRule() {
+    super(operand(LogicalJoin.class, operand(RelNode.class, any()), operand(RelNode.class, any())));
+  }
 
-    @Override
-    public boolean matches(RelOptRuleCall call) {
-        LogicalJoin join = call.rel(0);
-        if (!isSupportJoinType(join.getJoinType())) {
-            // non-INNER joins is not supported.
-            return false;
-        }
-        RelNode leftInput = call.rel(1);
-        RelNode rightInput = call.rel(2);
-        return isSingleChainFromGraphMatch(leftInput)
-            && isSingleChainFromLogicalTableScan(rightInput);
+  @Override
+  public boolean matches(RelOptRuleCall call) {
+    LogicalJoin join = call.rel(0);
+    if (!isSupportJoinType(join.getJoinType())) {
+      // non-INNER joins is not supported.
+      return false;
     }
+    RelNode leftInput = call.rel(1);
+    RelNode rightInput = call.rel(2);
+    return isSingleChainFromGraphMatch(leftInput) && isSingleChainFromLogicalTableScan(rightInput);
+  }
 
-    @Override
-    public void onMatch(RelOptRuleCall call) {
-        RelNode leftInput = call.rel(1);
-        RelNode graphMatchProject = null;
-        RelNode leftGraphMatch = leftInput;
-        while (leftGraphMatch != null && !(leftGraphMatch instanceof LogicalGraphMatch)) {
-            graphMatchProject = leftGraphMatch;
-            leftGraphMatch = GQLRelUtil.toRel(leftGraphMatch.getInput(0));
-        }
-        RelNode rightInput = call.rel(2);
-        RelNode rightTableScan = rightInput;
-        while (!(rightTableScan instanceof LogicalTableScan)) {
-            rightTableScan = GQLRelUtil.toRel(rightTableScan.getInput(0));
-        }
-        RelNode tail = processGraphMatchJoinTable(
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    RelNode leftInput = call.rel(1);
+    RelNode graphMatchProject = null;
+    RelNode leftGraphMatch = leftInput;
+    while (leftGraphMatch != null && !(leftGraphMatch instanceof LogicalGraphMatch)) {
+      graphMatchProject = leftGraphMatch;
+      leftGraphMatch = GQLRelUtil.toRel(leftGraphMatch.getInput(0));
+    }
+    RelNode rightInput = call.rel(2);
+    RelNode rightTableScan = rightInput;
+    while (!(rightTableScan instanceof LogicalTableScan)) {
+      rightTableScan = GQLRelUtil.toRel(rightTableScan.getInput(0));
+    }
+    RelNode tail =
+        processGraphMatchJoinTable(
             call,
             call.rel(0),
             (LogicalGraphMatch) leftGraphMatch,
             (LogicalProject) graphMatchProject,
             (LogicalTableScan) rightTableScan,
-            leftInput, leftGraphMatch, rightInput, rightTableScan, true
-        );
-        if (tail == null) {
-            return;
-        }
-        call.transformTo(tail);
+            leftInput,
+            leftGraphMatch,
+            rightInput,
+            rightTableScan,
+            true);
+    if (tail == null) {
+      return;
     }
-
-
+    call.transformTo(tail);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchSortToLogicalSortRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchSortToLogicalSortRule.java
index 14b267638..0e87d7716 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchSortToLogicalSortRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/MatchSortToLogicalSortRule.java
@@ -21,11 +21,11 @@
 
 import static org.apache.geaflow.dsl.util.GQLRelUtil.toRel;
 
-import com.google.common.collect.Lists;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.rel.RelCollations;
@@ -41,72 +41,75 @@
 import org.apache.geaflow.dsl.rel.match.MatchPathSort;
 import org.apache.geaflow.dsl.util.GQLRelUtil;
 
+import com.google.common.collect.Lists;
+
 public class MatchSortToLogicalSortRule extends RelOptRule {
 
-    public static final MatchSortToLogicalSortRule INSTANCE = new MatchSortToLogicalSortRule();
+  public static final MatchSortToLogicalSortRule INSTANCE = new MatchSortToLogicalSortRule();
 
-    private MatchSortToLogicalSortRule() {
-        super(operand(RelNode.class,
-            operand(LogicalGraphMatch.class, any())));
+  private MatchSortToLogicalSortRule() {
+    super(operand(RelNode.class, operand(LogicalGraphMatch.class, any())));
+  }
+
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    RelNode topNode = call.rel(0);
+    LogicalGraphMatch graphMatch = call.rel(1);
+
+    if (!(graphMatch.getPathPattern() instanceof MatchPathSort)
+        || GQLRelUtil.isGQLMatchRelNode(topNode)) {
+      return;
     }
+    MatchPathSort pathSort = (MatchPathSort) graphMatch.getPathPattern();
+    RexBuilder rexBuilder = call.builder().getRexBuilder();
+
+    List projects = new ArrayList<>();
+    RexNode[] orderRexNodes = new RexNode[pathSort.getOrderByExpressions().size()];
 
-    @Override
-    public void onMatch(RelOptRuleCall call) {
-        RelNode topNode = call.rel(0);
-        LogicalGraphMatch graphMatch = call.rel(1);
-
-        if (!(graphMatch.getPathPattern() instanceof MatchPathSort)
-            || GQLRelUtil.isGQLMatchRelNode(topNode)) {
-            return;
-        }
-        MatchPathSort pathSort = (MatchPathSort) graphMatch.getPathPattern();
-        RexBuilder rexBuilder = call.builder().getRexBuilder();
-
-        List projects = new ArrayList<>();
-        RexNode[] orderRexNodes = new RexNode[pathSort.getOrderByExpressions().size()];
-
-        RelDataType inputType = pathSort.getInput().getRowType();
-        List topReferences = GQLRelUtil.getRelNodeReference(topNode);
-        Map topReferMapping = new HashMap<>();
-
-        for (int i = 0; i < topReferences.size(); i++) {
-            int topRefer = topReferences.get(i);
-            projects.add(rexBuilder.makeInputRef(inputType.getFieldList().get(topRefer).getType(), topRefer));
-            topReferMapping.put(topRefer, i);
-        }
-        for (int i = 0; i < pathSort.getOrderByExpressions().size(); i++) {
-            RexNode orderExp = pathSort.getOrderByExpressions().get(i);
-            projects.add(orderExp);
-
-            RexNode orderRex = rexBuilder.makeInputRef(orderExp.getType(), i + topReferences.size());
-            if (orderExp.getKind() == SqlKind.DESCENDING) {
-                orderRex = rexBuilder.makeCall(SqlStdOperatorTable.DESC, orderRex);
-            }
-            orderRexNodes[i] = orderRex;
-        }
-
-        RelNode newTop;
-        graphMatch = (LogicalGraphMatch) graphMatch.copy((IMatchNode) toRel(pathSort.getInput()));
-        LogicalSort logicalSort;
-        if (orderRexNodes.length > 0) {
-            logicalSort = (LogicalSort) call.builder().push(graphMatch)
-                .project(projects)
-                .sort(orderRexNodes)
-                .build();
-            logicalSort = (LogicalSort) logicalSort.copy(logicalSort.getTraitSet(),
-                logicalSort.getInput(),
-                logicalSort.getCollation(),
-                null,
-                pathSort.getLimit());
-            // adjust the reference index for topNode after add project to the graph match
-            newTop = GQLRelUtil.adjustInputRef(topNode, topReferMapping);
-        } else {
-            logicalSort = LogicalSort.create(graphMatch, RelCollations.EMPTY,
-                null, pathSort.getLimit());
-            newTop = topNode;
-        }
-
-        newTop = newTop.copy(topNode.getTraitSet(), Lists.newArrayList(logicalSort));
-        call.transformTo(newTop);
+    RelDataType inputType = pathSort.getInput().getRowType();
+    List topReferences = GQLRelUtil.getRelNodeReference(topNode);
+    Map topReferMapping = new HashMap<>();
+
+    for (int i = 0; i < topReferences.size(); i++) {
+      int topRefer = topReferences.get(i);
+      projects.add(
+          rexBuilder.makeInputRef(inputType.getFieldList().get(topRefer).getType(), topRefer));
+      topReferMapping.put(topRefer, i);
+    }
+    for (int i = 0; i < pathSort.getOrderByExpressions().size(); i++) {
+      RexNode orderExp = pathSort.getOrderByExpressions().get(i);
+      projects.add(orderExp);
+
+      RexNode orderRex = rexBuilder.makeInputRef(orderExp.getType(), i + topReferences.size());
+      if (orderExp.getKind() == SqlKind.DESCENDING) {
+        orderRex = rexBuilder.makeCall(SqlStdOperatorTable.DESC, orderRex);
+      }
+      orderRexNodes[i] = orderRex;
     }
+
+    RelNode newTop;
+    graphMatch = (LogicalGraphMatch) graphMatch.copy((IMatchNode) toRel(pathSort.getInput()));
+    LogicalSort logicalSort;
+    if (orderRexNodes.length > 0) {
+      logicalSort =
+          (LogicalSort)
+              call.builder().push(graphMatch).project(projects).sort(orderRexNodes).build();
+      logicalSort =
+          (LogicalSort)
+              logicalSort.copy(
+                  logicalSort.getTraitSet(),
+                  logicalSort.getInput(),
+                  logicalSort.getCollation(),
+                  null,
+                  pathSort.getLimit());
+      // adjust the reference index for topNode after add project to the graph match
+      newTop = GQLRelUtil.adjustInputRef(topNode, topReferMapping);
+    } else {
+      logicalSort = LogicalSort.create(graphMatch, RelCollations.EMPTY, null, pathSort.getLimit());
+      newTop = topNode;
+    }
+
+    newTop = newTop.copy(topNode.getTraitSet(), Lists.newArrayList(logicalSort));
+    call.transformTo(newTop);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/PathInputReplaceRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/PathInputReplaceRule.java
index 8fd1da463..eae8a366d 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/PathInputReplaceRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/PathInputReplaceRule.java
@@ -19,7 +19,6 @@
 
 package org.apache.geaflow.dsl.optimize.rule;
 
-import com.google.common.collect.Lists;
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.rel.RelNode;
@@ -34,81 +33,84 @@
 import org.apache.geaflow.dsl.rex.RexLambdaCall;
 import org.apache.geaflow.dsl.util.GQLRelUtil;
 
+import com.google.common.collect.Lists;
+
 public class PathInputReplaceRule extends RelOptRule {
 
-    public static final PathInputReplaceRule INSTANCE = new PathInputReplaceRule();
+  public static final PathInputReplaceRule INSTANCE = new PathInputReplaceRule();
 
-    private PathInputReplaceRule() {
-        super(operand(RelNode.class, any()));
-    }
+  private PathInputReplaceRule() {
+    super(operand(RelNode.class, any()));
+  }
 
-    @Override
-    public void onMatch(RelOptRuleCall call) {
-        RelNode node = call.rel(0);
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    RelNode node = call.rel(0);
 
-        if (node instanceof IMatchNode) {
-            PathRecordType pathRecordType = null;
-            if (node instanceof MatchJoin) {
-                pathRecordType = ((IMatchNode) node).getPathSchema();
-            } else if (node.getInputs().size() == 1) {
-                pathRecordType = ((IMatchNode) GQLRelUtil.toRel(node.getInput(0))).getPathSchema();
-            }
-            if (pathRecordType != null) {
-                RelNode newNode = replaceInputRef(pathRecordType, node);
-                call.transformTo(newNode);
-            }
-        } else {
-            if (node.getInputs().size() == 1
-                && node.getInput(0).getRowType() instanceof PathRecordType) {
-                PathRecordType pathRecordType = (PathRecordType) node.getInput(0).getRowType();
-                RelNode newNode = replaceInputRef(pathRecordType, node);
-                call.transformTo(newNode);
-            }
-        }
+    if (node instanceof IMatchNode) {
+      PathRecordType pathRecordType = null;
+      if (node instanceof MatchJoin) {
+        pathRecordType = ((IMatchNode) node).getPathSchema();
+      } else if (node.getInputs().size() == 1) {
+        pathRecordType = ((IMatchNode) GQLRelUtil.toRel(node.getInput(0))).getPathSchema();
+      }
+      if (pathRecordType != null) {
+        RelNode newNode = replaceInputRef(pathRecordType, node);
+        call.transformTo(newNode);
+      }
+    } else {
+      if (node.getInputs().size() == 1 && node.getInput(0).getRowType() instanceof PathRecordType) {
+        PathRecordType pathRecordType = (PathRecordType) node.getInput(0).getRowType();
+        RelNode newNode = replaceInputRef(pathRecordType, node);
+        call.transformTo(newNode);
+      }
     }
+  }
 
-    private RelNode replaceInputRef(PathRecordType pathRecordType, RelNode node) {
-        return node.accept(new PathRefReplaceVisitor(pathRecordType));
-    }
+  private RelNode replaceInputRef(PathRecordType pathRecordType, RelNode node) {
+    return node.accept(new PathRefReplaceVisitor(pathRecordType));
+  }
 
-    private static class PathRefReplaceVisitor extends RexShuttle {
+  private static class PathRefReplaceVisitor extends RexShuttle {
 
-        private final PathRecordType pathRecordType;
+    private final PathRecordType pathRecordType;
 
-        public PathRefReplaceVisitor(PathRecordType pathRecordType) {
-            this.pathRecordType = pathRecordType;
-        }
+    public PathRefReplaceVisitor(PathRecordType pathRecordType) {
+      this.pathRecordType = pathRecordType;
+    }
 
-        @Override
-        public RexNode visitInputRef(RexInputRef inputRef) {
-            RelDataTypeField pathField = pathRecordType.getFieldList().get(inputRef.getIndex());
-            return new PathInputRef(pathField.getName(), pathField.getIndex(), pathField.getType());
-        }
+    @Override
+    public RexNode visitInputRef(RexInputRef inputRef) {
+      RelDataTypeField pathField = pathRecordType.getFieldList().get(inputRef.getIndex());
+      return new PathInputRef(pathField.getName(), pathField.getIndex(), pathField.getType());
+    }
 
-        @Override
-        public RexNode visitCall(RexCall call) {
-            if (call instanceof RexLambdaCall) {
-                RexLambdaCall lambdaCall = (RexLambdaCall) call;
-                PathRecordType pathRecordType = (PathRecordType) lambdaCall.getInput().getType();
-                RexNode newValue = ((RexLambdaCall) call).getValue()
-                    .accept(new PathRefReplaceVisitor(pathRecordType));
-                return lambdaCall.clone(lambdaCall.type, Lists.newArrayList(lambdaCall.getInput(), newValue));
-            } else {
-                return super.visitCall(call);
-            }
-        }
+    @Override
+    public RexNode visitCall(RexCall call) {
+      if (call instanceof RexLambdaCall) {
+        RexLambdaCall lambdaCall = (RexLambdaCall) call;
+        PathRecordType pathRecordType = (PathRecordType) lambdaCall.getInput().getType();
+        RexNode newValue =
+            ((RexLambdaCall) call).getValue().accept(new PathRefReplaceVisitor(pathRecordType));
+        return lambdaCall.clone(
+            lambdaCall.type, Lists.newArrayList(lambdaCall.getInput(), newValue));
+      } else {
+        return super.visitCall(call);
+      }
+    }
 
-        @Override
-        public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
-            if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable
-                && (fieldAccess.getType() instanceof VertexRecordType
-                || fieldAccess.getType() instanceof EdgeRecordType)) {
-                String pathFieldName = fieldAccess.getField().getName();
-                return new PathInputRef(fieldAccess.getField().getName(),
-                    pathRecordType.getField(pathFieldName, true, false).getIndex(),
-                    fieldAccess.getField().getType());
-            }
-            return super.visitFieldAccess(fieldAccess);
-        }
+    @Override
+    public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
+      if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable
+          && (fieldAccess.getType() instanceof VertexRecordType
+              || fieldAccess.getType() instanceof EdgeRecordType)) {
+        String pathFieldName = fieldAccess.getField().getName();
+        return new PathInputRef(
+            fieldAccess.getField().getName(),
+            pathRecordType.getField(pathFieldName, true, false).getIndex(),
+            fieldAccess.getField().getType());
+      }
+      return super.visitFieldAccess(fieldAccess);
     }
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/PathModifyMergeRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/PathModifyMergeRule.java
index cb636819d..a7f32dedd 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/PathModifyMergeRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/PathModifyMergeRule.java
@@ -19,13 +19,13 @@
 
 package org.apache.geaflow.dsl.optimize.rule;
 
-import com.google.common.collect.ImmutableList;
 import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.rel.type.RelDataType;
@@ -41,130 +41,133 @@
 import org.apache.geaflow.dsl.rex.RexObjectConstruct.VariableInfo;
 import org.apache.geaflow.dsl.util.GQLRexUtil;
 
-public class PathModifyMergeRule extends RelOptRule {
+import com.google.common.collect.ImmutableList;
 
-    public static final PathModifyMergeRule INSTANCE = new PathModifyMergeRule();
+public class PathModifyMergeRule extends RelOptRule {
 
-    private PathModifyMergeRule() {
-        super(operand(MatchPathModify.class,
-            operand(MatchPathModify.class, any())));
+  public static final PathModifyMergeRule INSTANCE = new PathModifyMergeRule();
+
+  private PathModifyMergeRule() {
+    super(operand(MatchPathModify.class, operand(MatchPathModify.class, any())));
+  }
+
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    MatchPathModify pathModify = call.rel(0);
+    MatchPathModify inputPathModify = call.rel(1);
+    // path modify with sub-query refer dependency can not merge as
+    // the sub-query execute concurrently. See AbstractStepOperator for
+    // more information about the sub-query.
+    if (hasFieldRefDependency(pathModify, inputPathModify)) {
+      return;
     }
-
-    @Override
-    public void onMatch(RelOptRuleCall call) {
-        MatchPathModify pathModify = call.rel(0);
-        MatchPathModify inputPathModify = call.rel(1);
-        // path modify with sub-query refer dependency can not merge as
-        // the sub-query execute concurrently. See AbstractStepOperator for
-        // more information about the sub-query.
-        if (hasFieldRefDependency(pathModify, inputPathModify)) {
-            return;
-        }
-        ImmutableList mergedExpressions =
-            mergePathModifyExpression(call.builder().getRexBuilder(),
-                pathModify.getExpressions(),
-                inputPathModify.getExpressions());
-
-        RelDataType mergedRelType = mergeRelType(pathModify.getRowType(),
-            inputPathModify.getRowType());
-
-        PathModify mergedPathModify = pathModify.copy(
-            pathModify.getTraitSet(),
-            inputPathModify.getInput(),
-            mergedExpressions,
-            mergedRelType);
-        call.transformTo(mergedPathModify);
+    ImmutableList mergedExpressions =
+        mergePathModifyExpression(
+            call.builder().getRexBuilder(),
+            pathModify.getExpressions(),
+            inputPathModify.getExpressions());
+
+    RelDataType mergedRelType = mergeRelType(pathModify.getRowType(), inputPathModify.getRowType());
+
+    PathModify mergedPathModify =
+        pathModify.copy(
+            pathModify.getTraitSet(), inputPathModify.getInput(), mergedExpressions, mergedRelType);
+    call.transformTo(mergedPathModify);
+  }
+
+  /**
+   * Check if pathModify has field dependencies, that is, the current PathModify references fields
+   * generated by input PathModify. In this case, there is a field dependency between
+   * LogicalPathModify, and we cannot merge them. Additionally, LogicalPathModify with Global field
+   * cannot be merged.
+   */
+  private boolean hasFieldRefDependency(
+      MatchPathModify inputPathModify, MatchPathModify pathModify) {
+    Set inputFields = new HashSet<>();
+
+    for (PathModifyExpression expression : inputPathModify.getExpressions()) {
+      String pathField = expression.getPathFieldName();
+      RexObjectConstruct objectConstruct = expression.getObjectConstruct();
+      if (objectConstruct.getVariableInfo().stream().anyMatch(VariableInfo::isGlobal)) {
+        return true;
+      }
+      for (int i = 0; i < objectConstruct.getOperands().size(); i++) {
+        String field = objectConstruct.getVariableInfo().get(i).getName();
+        inputFields.add(pathField + "." + field);
+      }
     }
 
-    /**
-     * Check if pathModify has field dependencies, that is, the current PathModify references
-     * fields generated by input PathModify.
-     * In this case, there is a field dependency between LogicalPathModify, and we cannot merge
-     * them. Additionally, LogicalPathModify with Global field cannot be merged.
-     */
-    private boolean hasFieldRefDependency(MatchPathModify inputPathModify, MatchPathModify pathModify) {
-        Set inputFields = new HashSet<>();
-
-        for (PathModifyExpression expression : inputPathModify.getExpressions()) {
-            String pathField = expression.getPathFieldName();
-            RexObjectConstruct objectConstruct = expression.getObjectConstruct();
-            if (objectConstruct.getVariableInfo().stream().anyMatch(VariableInfo::isGlobal)) {
-                return true;
-            }
-            for (int i = 0; i < objectConstruct.getOperands().size(); i++) {
-                String field = objectConstruct.getVariableInfo().get(i).getName();
-                inputFields.add(pathField + "." + field);
-            }
-        }
-
-        for (PathModifyExpression expression : pathModify.getExpressions()) {
-            List operands = expression.getObjectConstruct().getOperands();
-            for (RexNode operand : operands) {
-                List rexNodes = GQLRexUtil.collect(operand, node -> node != null);
-
-                for (RexNode rexNode : rexNodes) {
-                    List pathRefers = GQLRexUtil.collect(rexNode, node -> node instanceof RexFieldAccess);
-                    boolean referInputField = pathRefers.stream().anyMatch(pathRef -> {
-                        String pathField = expression.getPathFieldName()
-                            + "." + pathRef.getField().getName();
+    for (PathModifyExpression expression : pathModify.getExpressions()) {
+      List operands = expression.getObjectConstruct().getOperands();
+      for (RexNode operand : operands) {
+        List rexNodes = GQLRexUtil.collect(operand, node -> node != null);
+
+        for (RexNode rexNode : rexNodes) {
+          List pathRefers =
+              GQLRexUtil.collect(rexNode, node -> node instanceof RexFieldAccess);
+          boolean referInputField =
+              pathRefers.stream()
+                  .anyMatch(
+                      pathRef -> {
+                        String pathField =
+                            expression.getPathFieldName() + "." + pathRef.getField().getName();
                         return inputFields.contains(pathField);
-                    });
-                    if (referInputField) { // current path modify has referred the input's sub query.
-                        return true;
-                    }
-                }
-            }
+                      });
+          if (referInputField) { // current path modify has referred the input's sub query.
+            return true;
+          }
         }
-        return false;
+      }
     }
-
-
-    private ImmutableList mergePathModifyExpression(
-        RexBuilder builder,
-        ImmutableList expressions,
-        ImmutableList inputExpressions) {
-        Map pathIndex2Expression = new LinkedHashMap<>();
-        for (PathModifyExpression expression : expressions) {
-            pathIndex2Expression.put(expression.getIndex(), expression);
-        }
-
-        List mergedExpressions = new ArrayList<>();
-
-        for (PathModifyExpression inputExpression : inputExpressions) {
-            PathModifyExpression expression = pathIndex2Expression.get(inputExpression.getIndex());
-            if (expression != null) {
-                RexObjectConstruct mergedObjConstruct =
-                    expression.getObjectConstruct().merge(inputExpression.getObjectConstruct(),
-                        expression.getIndex(), builder);
-                PathModifyExpression mergedExpression = expression.copy(mergedObjConstruct);
-                mergedExpressions.add(mergedExpression);
-                pathIndex2Expression.remove(inputExpression.getIndex());
-            } else {
-                mergedExpressions.add(inputExpression);
-            }
-        }
-        mergedExpressions.addAll(pathIndex2Expression.values());
-        return ImmutableList.copyOf(mergedExpressions);
+    return false;
+  }
+
+  private ImmutableList mergePathModifyExpression(
+      RexBuilder builder,
+      ImmutableList expressions,
+      ImmutableList inputExpressions) {
+    Map pathIndex2Expression = new LinkedHashMap<>();
+    for (PathModifyExpression expression : expressions) {
+      pathIndex2Expression.put(expression.getIndex(), expression);
     }
 
-
-    private RelDataType mergeRelType(RelDataType relType, RelDataType inputRelType) {
-        Map name2Type = new LinkedHashMap<>();
-        for (RelDataTypeField field : relType.getFieldList()) {
-            name2Type.put(field.getName(), field);
-        }
-        List mergedFields = new ArrayList<>();
-        for (RelDataTypeField inputField : inputRelType.getFieldList()) {
-            RelDataTypeField currentField = name2Type.get(inputField.getName());
-            if (currentField != null) {
-                mergedFields.add(currentField); // use current field override the input field
-                name2Type.remove(inputField.getName());
-            } else {
-                mergedFields.add(inputField);
-            }
-        }
-
-        mergedFields.addAll(name2Type.values());
-        return new PathRecordType(mergedFields);
+    List mergedExpressions = new ArrayList<>();
+
+    for (PathModifyExpression inputExpression : inputExpressions) {
+      PathModifyExpression expression = pathIndex2Expression.get(inputExpression.getIndex());
+      if (expression != null) {
+        RexObjectConstruct mergedObjConstruct =
+            expression
+                .getObjectConstruct()
+                .merge(inputExpression.getObjectConstruct(), expression.getIndex(), builder);
+        PathModifyExpression mergedExpression = expression.copy(mergedObjConstruct);
+        mergedExpressions.add(mergedExpression);
+        pathIndex2Expression.remove(inputExpression.getIndex());
+      } else {
+        mergedExpressions.add(inputExpression);
+      }
+    }
+    mergedExpressions.addAll(pathIndex2Expression.values());
+    return ImmutableList.copyOf(mergedExpressions);
+  }
+
+  private RelDataType mergeRelType(RelDataType relType, RelDataType inputRelType) {
+    Map name2Type = new LinkedHashMap<>();
+    for (RelDataTypeField field : relType.getFieldList()) {
+      name2Type.put(field.getName(), field);
     }
+    List mergedFields = new ArrayList<>();
+    for (RelDataTypeField inputField : inputRelType.getFieldList()) {
+      RelDataTypeField currentField = name2Type.get(inputField.getName());
+      if (currentField != null) {
+        mergedFields.add(currentField); // use current field override the input field
+        name2Type.remove(inputField.getName());
+      } else {
+        mergedFields.add(inputField);
+      }
+    }
+
+    mergedFields.addAll(name2Type.values());
+    return new PathRecordType(mergedFields);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/ProjectFieldPruneRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/ProjectFieldPruneRule.java
index 2281827a1..f1d33944e 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/ProjectFieldPruneRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/ProjectFieldPruneRule.java
@@ -23,6 +23,7 @@
 import static org.apache.geaflow.dsl.common.types.VertexType.DEFAULT_ID_FIELD_NAME;
 
 import java.util.*;
+
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.rel.RelNode;
@@ -36,227 +37,218 @@
 import org.apache.geaflow.dsl.rex.RexParameterRef;
 
 /**
- * Rule to prune unnecessary fields from LogicalProject and push down field requirements
- * to LogicalGraphMatch.
+ * Rule to prune unnecessary fields from LogicalProject and push down field requirements to
+ * LogicalGraphMatch.
  */
 public class ProjectFieldPruneRule extends RelOptRule {
 
-    public static final ProjectFieldPruneRule INSTANCE = new ProjectFieldPruneRule();
+  public static final ProjectFieldPruneRule INSTANCE = new ProjectFieldPruneRule();
 
-    // Mapping for special field names
-    private static final Map SPECIAL_FIELD_MAP;
+  // Mapping for special field names
+  private static final Map SPECIAL_FIELD_MAP;
 
-    static {
-        SPECIAL_FIELD_MAP = new HashMap<>();
-        SPECIAL_FIELD_MAP.put("id", DEFAULT_ID_FIELD_NAME);
-        SPECIAL_FIELD_MAP.put("label", DEFAULT_LABEL_NAME);
-        SPECIAL_FIELD_MAP.put("srcId", DEFAULT_SRC_ID_NAME );
-        SPECIAL_FIELD_MAP.put("targetId", DEFAULT_TARGET_ID_NAME);
-    }
+  static {
+    SPECIAL_FIELD_MAP = new HashMap<>();
+    SPECIAL_FIELD_MAP.put("id", DEFAULT_ID_FIELD_NAME);
+    SPECIAL_FIELD_MAP.put("label", DEFAULT_LABEL_NAME);
+    SPECIAL_FIELD_MAP.put("srcId", DEFAULT_SRC_ID_NAME);
+    SPECIAL_FIELD_MAP.put("targetId", DEFAULT_TARGET_ID_NAME);
+  }
 
-    private ProjectFieldPruneRule() {
-        super(operand(LogicalProject.class, operand(LogicalGraphMatch.class, any())),
-                "ProjectFieldPruneRule");
-    }
+  private ProjectFieldPruneRule() {
+    super(
+        operand(LogicalProject.class, operand(LogicalGraphMatch.class, any())),
+        "ProjectFieldPruneRule");
+  }
 
-    @Override
-    public void onMatch(RelOptRuleCall call) {
-        LogicalProject project = call.rel(0);           // Get LogicalProject
-        LogicalGraphMatch graphMatch = call.rel(1);     // Get LogicalGraphMatch (direct child)
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    LogicalProject project = call.rel(0); // Get LogicalProject
+    LogicalGraphMatch graphMatch = call.rel(1); // Get LogicalGraphMatch (direct child)
 
-        // 1. Extract field access information from LogicalProject
-        Set filteredElements = extractFields(project);
+    // 1. Extract field access information from LogicalProject
+    Set filteredElements = extractFields(project);
 
-        // 2. Pass the filtered field information to LogicalGraphMatch
-        if (!filteredElements.isEmpty()) {
-            traverseAndPruneFields(filteredElements, graphMatch.getPathPattern());
-        }
+    // 2. Pass the filtered field information to LogicalGraphMatch
+    if (!filteredElements.isEmpty()) {
+      traverseAndPruneFields(filteredElements, graphMatch.getPathPattern());
     }
+  }
+
+  /**
+   * Extract fields from LogicalProject and convert to semantic information (e.g., $0.id -> a.id).
+   */
+  private Set extractFields(LogicalProject project) {
+    List fields = project.getChildExps();
+    Set fieldAccesses = new HashSet<>();
+
+    for (RexNode node : fields) {
+      // Recursively collect all field accesses
+      fieldAccesses.addAll(collectAllFieldAccesses(project.getCluster().getRexBuilder(), node));
+    }
+
+    // Convert index-based references to label-based path references
+    return convertToPathRefs(fieldAccesses, project.getInput(0));
+  }
+
+  /** Recursively collect all RexFieldAccess nodes from a RexNode tree. */
+  private static Set collectAllFieldAccesses(
+      RexBuilder rexBuilder, RexNode rootNode) {
+    Set fieldAccesses = new HashSet<>();
+    Queue queue = new LinkedList<>();
+    queue.offer(rootNode);
+
+    while (!queue.isEmpty()) {
+      RexNode node = queue.poll();
+
+      if (node instanceof RexFieldAccess) {
+        // Direct field access
+        fieldAccesses.add((RexFieldAccess) node);
+
+      } else if (node instanceof RexCall) {
+        // Custom function call, need to extract and convert elements
+        RexCall rexCall = (RexCall) node;
+
+        // Check if it's a field access type call (operand[0] is ref, operator is field name)
+        if (rexCall.getOperands().size() > 0) {
+          RexNode ref = rexCall.getOperands().get(0);
+          String fieldName = rexCall.getOperator().getName();
+
+          // Handle special fields with mapping
+          if (SPECIAL_FIELD_MAP.containsKey(fieldName)) {
+            String mappedFieldName = SPECIAL_FIELD_MAP.get(fieldName);
+            fieldAccesses.add(
+                (RexFieldAccess) rexBuilder.makeFieldAccess(ref, mappedFieldName, false));
+
+          } else if (ref instanceof RexInputRef) {
+            // Other non-nested custom functions: enumerate all fields of ref and add them all
+            RelDataType refType = ref.getType();
+            List refFields = refType.getFieldList();
+
+            for (RelDataTypeField field : refFields) {
+              RexFieldAccess fieldAccess =
+                  (RexFieldAccess) rexBuilder.makeFieldAccess(ref, field.getName(), false);
+              fieldAccesses.add(fieldAccess);
+            }
 
-    /**
-     * Extract fields from LogicalProject and convert to semantic information (e.g., $0.id -> a.id).
-     */
-    private Set extractFields(LogicalProject project) {
-        List fields = project.getChildExps();
-        Set fieldAccesses = new HashSet<>();
-
-        for (RexNode node : fields) {
-            // Recursively collect all field accesses
-            fieldAccesses.addAll(collectAllFieldAccesses(
-                    project.getCluster().getRexBuilder(), node));
+          } else {
+            // ref itself might be a complex expression, continue recursive processing
+            queue.add(ref);
+          }
+
+          // Add other operands to the queue for continued processing
+          for (int i = 1; i < rexCall.getOperands().size(); i++) {
+            queue.add(rexCall.getOperands().get(i));
+          }
         }
 
-        // Convert index-based references to label-based path references
-        return convertToPathRefs(fieldAccesses, project.getInput(0));
-    }
+      } else if (node instanceof RexInputRef) {
+        // RexInputRef directly references input, enumerate all its fields
+        RelDataType refType = node.getType();
+        List refFields = refType.getFieldList();
 
-    /**
-     * Recursively collect all RexFieldAccess nodes from a RexNode tree.
-     */
-    private static Set collectAllFieldAccesses(RexBuilder rexBuilder, RexNode rootNode) {
-        Set fieldAccesses = new HashSet<>();
-        Queue queue = new LinkedList<>();
-        queue.offer(rootNode);
-
-        while (!queue.isEmpty()) {
-            RexNode node = queue.poll();
-
-            if (node instanceof RexFieldAccess) {
-                // Direct field access
-                fieldAccesses.add((RexFieldAccess) node);
-
-            } else if (node instanceof RexCall) {
-                // Custom function call, need to extract and convert elements
-                RexCall rexCall = (RexCall) node;
-
-                // Check if it's a field access type call (operand[0] is ref, operator is field name)
-                if (rexCall.getOperands().size() > 0) {
-                    RexNode ref = rexCall.getOperands().get(0);
-                    String fieldName = rexCall.getOperator().getName();
-
-                    // Handle special fields with mapping
-                    if (SPECIAL_FIELD_MAP.containsKey(fieldName)) {
-                        String mappedFieldName = SPECIAL_FIELD_MAP.get(fieldName);
-                        fieldAccesses.add((RexFieldAccess) rexBuilder.makeFieldAccess(ref, mappedFieldName, false));
-
-                    } else if (ref instanceof RexInputRef) {
-                        // Other non-nested custom functions: enumerate all fields of ref and add them all
-                        RelDataType refType = ref.getType();
-                        List refFields = refType.getFieldList();
-
-                        for (RelDataTypeField field : refFields) {
-                            RexFieldAccess fieldAccess = (RexFieldAccess) rexBuilder.makeFieldAccess(
-                                    ref,
-                                    field.getName(),
-                                    false
-                            );
-                            fieldAccesses.add(fieldAccess);
-                        }
-
-                    } else {
-                        // ref itself might be a complex expression, continue recursive processing
-                        queue.add(ref);
-                    }
-
-                    // Add other operands to the queue for continued processing
-                    for (int i = 1; i < rexCall.getOperands().size(); i++) {
-                        queue.add(rexCall.getOperands().get(i));
-                    }
-                }
-
-            } else if (node instanceof RexInputRef) {
-                // RexInputRef directly references input, enumerate all its fields
-                RelDataType refType = node.getType();
-                List refFields = refType.getFieldList();
-
-                for (RelDataTypeField field : refFields) {
-                    RexFieldAccess fieldAccess = (RexFieldAccess) rexBuilder.makeFieldAccess(
-                            node,
-                            field.getName(),
-                            false
-                    );
-                    fieldAccesses.add(fieldAccess);
-                }
-
-            } else if (node instanceof RexLiteral || node instanceof RexParameterRef) {
-                // Literals, skip
-                continue;
-
-            } else {
-                // Other unknown types, can choose to throw exception or log
-                throw new IllegalArgumentException("Unsupported type: " + node.getClass());
-            }
+        for (RelDataTypeField field : refFields) {
+          RexFieldAccess fieldAccess =
+              (RexFieldAccess) rexBuilder.makeFieldAccess(node, field.getName(), false);
+          fieldAccesses.add(fieldAccess);
         }
 
-        return fieldAccesses;
+      } else if (node instanceof RexLiteral || node instanceof RexParameterRef) {
+        // Literals, skip
+        continue;
+
+      } else {
+        // Other unknown types, can choose to throw exception or log
+        throw new IllegalArgumentException("Unsupported type: " + node.getClass());
+      }
     }
 
-    /**
-     * Convert index-only field accesses to complete fields with labels.
-     */
-    private static Set convertToPathRefs(Set fieldAccesses, RelNode node) {
-        Set convertedFieldAccesses = new HashSet<>();
-        RelDataType pathRecordType = node.getRowType(); // Get the record type at current level
-        RexBuilder rexBuilder = node.getCluster().getRexBuilder(); // Builder for creating new fields
-
-        for (RexFieldAccess fieldAccess : fieldAccesses) {
-            RexNode referenceExpr = fieldAccess.getReferenceExpr();
-
-            // Only process field accesses of input reference type
-            if (referenceExpr instanceof RexInputRef) {
-                RexInputRef inputRef = (RexInputRef) referenceExpr;
-
-                // If index exceeds field list size, it comes from a subquery, skip it
-                if (pathRecordType.getFieldList().size() <= inputRef.getIndex()) {
-                    continue;
-                }
-
-                // Get the corresponding path field information from PathRecordType
-                RelDataTypeField pathField = pathRecordType.getFieldList().get(inputRef.getIndex());
-
-                // Create the actual PathInputRef
-                PathInputRef pathInputRef = new PathInputRef(
-                        pathField.getName(),     // Path variable name (e.g., "a", "b", "c")
-                        pathField.getIndex(),    // Field index
-                        pathField.getType()      // Field type
-                );
+    return fieldAccesses;
+  }
 
-                // Recreate RexFieldAccess with the new path reference
-                RexFieldAccess newFieldAccess = (RexFieldAccess) rexBuilder.makeFieldAccess(
-                        pathInputRef,
-                        fieldAccess.getField().getIndex()
-                );
-                convertedFieldAccesses.add(newFieldAccess);
-            }
+  /** Convert index-only field accesses to complete fields with labels. */
+  private static Set convertToPathRefs(
+      Set fieldAccesses, RelNode node) {
+    Set convertedFieldAccesses = new HashSet<>();
+    RelDataType pathRecordType = node.getRowType(); // Get the record type at current level
+    RexBuilder rexBuilder = node.getCluster().getRexBuilder(); // Builder for creating new fields
+
+    for (RexFieldAccess fieldAccess : fieldAccesses) {
+      RexNode referenceExpr = fieldAccess.getReferenceExpr();
+
+      // Only process field accesses of input reference type
+      if (referenceExpr instanceof RexInputRef) {
+        RexInputRef inputRef = (RexInputRef) referenceExpr;
+
+        // If index exceeds field list size, it comes from a subquery, skip it
+        if (pathRecordType.getFieldList().size() <= inputRef.getIndex()) {
+          continue;
         }
 
-        return convertedFieldAccesses;
-    }
+        // Get the corresponding path field information from PathRecordType
+        RelDataTypeField pathField = pathRecordType.getFieldList().get(inputRef.getIndex());
 
-    /**
-     * Traverse the path pattern and add filtered fields to matching nodes.
-     */
-    private static void traverseAndPruneFields(Set fields, IMatchNode pathPattern) {
-        Queue queue = new LinkedList<>(); // Queue for nodes to visit
-        Set visited = new HashSet<>();    // Mark visited nodes
-
-        queue.offer(pathPattern);
-        visited.add(pathPattern);
-
-        // Visit all nodes in the path, and for each field: if label matches, add the field to .fields
-        while (!queue.isEmpty()) {
-            IMatchNode currentPathPattern = queue.poll();
-
-            if (currentPathPattern instanceof VertexMatch) {
-                VertexMatch vertexMatch = (VertexMatch) currentPathPattern;
-                String vertexLabel = vertexMatch.getLabel();
-                for (RexFieldAccess fieldElement : fields) {
-                    PathInputRef inputRef = (PathInputRef) fieldElement.getReferenceExpr();
-                    if (inputRef.getLabel().equals(vertexLabel)) {
-                        vertexMatch.addField(fieldElement);
-                    }
-                }
-            }
+        // Create the actual PathInputRef
+        PathInputRef pathInputRef =
+            new PathInputRef(
+                pathField.getName(), // Path variable name (e.g., "a", "b", "c")
+                pathField.getIndex(), // Field index
+                pathField.getType() // Field type
+                );
 
-            if (currentPathPattern instanceof EdgeMatch) {
-                EdgeMatch edgeMatch = (EdgeMatch) currentPathPattern;
-                String edgeLabel = edgeMatch.getLabel();
-                for (RexFieldAccess fieldElement : fields) {
-                    PathInputRef inputRef = (PathInputRef) fieldElement.getReferenceExpr();
-                    if (inputRef.getLabel().equals(edgeLabel)) {
-                        edgeMatch.addField(fieldElement);
-                    }
-                }
-            }
+        // Recreate RexFieldAccess with the new path reference
+        RexFieldAccess newFieldAccess =
+            (RexFieldAccess)
+                rexBuilder.makeFieldAccess(pathInputRef, fieldAccess.getField().getIndex());
+        convertedFieldAccesses.add(newFieldAccess);
+      }
+    }
 
-            // Iterate through possible child nodes
-            List inputs = currentPathPattern.getInputs();
-            for (RelNode candidateInput : inputs) {
-                if (candidateInput != null && !visited.contains((IMatchNode) candidateInput)) {
-                    queue.offer((IMatchNode) candidateInput);
-                    visited.add((IMatchNode) candidateInput);
-                }
-            }
+    return convertedFieldAccesses;
+  }
+
+  /** Traverse the path pattern and add filtered fields to matching nodes. */
+  private static void traverseAndPruneFields(Set fields, IMatchNode pathPattern) {
+    Queue queue = new LinkedList<>(); // Queue for nodes to visit
+    Set visited = new HashSet<>(); // Mark visited nodes
+
+    queue.offer(pathPattern);
+    visited.add(pathPattern);
+
+    // Visit all nodes in the path, and for each field: if label matches, add the field to .fields
+    while (!queue.isEmpty()) {
+      IMatchNode currentPathPattern = queue.poll();
+
+      if (currentPathPattern instanceof VertexMatch) {
+        VertexMatch vertexMatch = (VertexMatch) currentPathPattern;
+        String vertexLabel = vertexMatch.getLabel();
+        for (RexFieldAccess fieldElement : fields) {
+          PathInputRef inputRef = (PathInputRef) fieldElement.getReferenceExpr();
+          if (inputRef.getLabel().equals(vertexLabel)) {
+            vertexMatch.addField(fieldElement);
+          }
+        }
+      }
+
+      if (currentPathPattern instanceof EdgeMatch) {
+        EdgeMatch edgeMatch = (EdgeMatch) currentPathPattern;
+        String edgeLabel = edgeMatch.getLabel();
+        for (RexFieldAccess fieldElement : fields) {
+          PathInputRef inputRef = (PathInputRef) fieldElement.getReferenceExpr();
+          if (inputRef.getLabel().equals(edgeLabel)) {
+            edgeMatch.addField(fieldElement);
+          }
+        }
+      }
+
+      // Iterate through possible child nodes
+      List inputs = currentPathPattern.getInputs();
+      for (RelNode candidateInput : inputs) {
+        if (candidateInput != null && !visited.contains((IMatchNode) candidateInput)) {
+          queue.offer((IMatchNode) candidateInput);
+          visited.add((IMatchNode) candidateInput);
         }
+      }
     }
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/PushConsecutiveJoinConditionRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/PushConsecutiveJoinConditionRule.java
index 12d9140ed..d1ac6c76f 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/PushConsecutiveJoinConditionRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/PushConsecutiveJoinConditionRule.java
@@ -22,6 +22,7 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.stream.Collectors;
+
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.plan.RelOptUtil;
@@ -36,63 +37,84 @@
 
 public class PushConsecutiveJoinConditionRule extends RelOptRule {
 
-    public static final PushConsecutiveJoinConditionRule INSTANCE = new PushConsecutiveJoinConditionRule();
+  public static final PushConsecutiveJoinConditionRule INSTANCE =
+      new PushConsecutiveJoinConditionRule();
 
-    private PushConsecutiveJoinConditionRule() {
-        super(operand(LogicalJoin.class, operand(LogicalJoin.class, any())));
-    }
+  private PushConsecutiveJoinConditionRule() {
+    super(operand(LogicalJoin.class, operand(LogicalJoin.class, any())));
+  }
 
-    @Override
-    public void onMatch(RelOptRuleCall call) {
-        LogicalJoin bottomJoin = call.rel(0);
-        LogicalJoin join = call.rel(1);
-        if (!join.getJoinType().equals(JoinRelType.INNER) || !bottomJoin.getJoinType().equals(JoinRelType.INNER)) {
-            // Consecutive pushing of conditions for non-INNER joins is not supported.
-            return;
-        }
-        List splitRexNodes = RelOptUtil.conjunctions(bottomJoin.getCondition());
-        List pushRexNodes = splitRexNodes.stream().filter(n -> n.getKind().equals(SqlKind.EQUALS)).collect(
-            Collectors.toList());
-        // Filter out conditions unrelated to the sub-join.
-        boolean isBottomLeft = GQLRelUtil.toRel(bottomJoin.getLeft()).equals(join);
-        int fieldCount = join.getRowType().getFieldCount();
-        int fieldStart = isBottomLeft ? 0 : bottomJoin.getRowType().getFieldCount() - fieldCount;
-        int fieldEnd = isBottomLeft ? fieldCount : bottomJoin.getRowType().getFieldCount();
-        pushRexNodes = pushRexNodes.stream().filter(
-            rex -> GQLRexUtil.collect(rex, r -> r instanceof RexInputRef).stream().noneMatch(
-                inputRef -> ((RexInputRef) inputRef).getIndex() < fieldStart
-                    || ((RexInputRef) inputRef).getIndex() >= fieldEnd
-            )
-        ).collect(Collectors.toList());
-        List remainRexNodes = new ArrayList<>();
-        for (RexNode rex : splitRexNodes) {
-            if (!pushRexNodes.contains(rex)) {
-                remainRexNodes.add(rex);
-            }
-        }
-        // If the join is the right input, adjust all referenced indices.
-        if (!isBottomLeft) {
-            pushRexNodes = pushRexNodes.stream().map(
-                rex -> GQLRexUtil.replace(rex, rexNode -> {
-                    if (rexNode instanceof RexInputRef) {
-                        int index = ((RexInputRef) rexNode).getIndex();
-                        index -= (bottomJoin.getRowType().getFieldCount() - fieldCount);
-                        return new RexInputRef(index, rexNode.getType());
-                    } else {
-                        return rexNode;
-                    }
-                })
-            ).collect(Collectors.toList());
-        }
-        pushRexNodes.add(join.getCondition());
-        RexNode equalRexNode = RexUtil.composeConjunction(
-            call.builder().getRexBuilder(), pushRexNodes);
-        LogicalJoin newJoin = join.copy(join.getTraitSet(), equalRexNode, join.getLeft(),
-            join.getRight(), join.getJoinType(), join.isSemiJoinDone());
-        call.transformTo(bottomJoin.copy(bottomJoin.getTraitSet(),
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    LogicalJoin bottomJoin = call.rel(0);
+    LogicalJoin join = call.rel(1);
+    if (!join.getJoinType().equals(JoinRelType.INNER)
+        || !bottomJoin.getJoinType().equals(JoinRelType.INNER)) {
+      // Consecutive pushing of conditions for non-INNER joins is not supported.
+      return;
+    }
+    List splitRexNodes = RelOptUtil.conjunctions(bottomJoin.getCondition());
+    List pushRexNodes =
+        splitRexNodes.stream()
+            .filter(n -> n.getKind().equals(SqlKind.EQUALS))
+            .collect(Collectors.toList());
+    // Filter out conditions unrelated to the sub-join.
+    boolean isBottomLeft = GQLRelUtil.toRel(bottomJoin.getLeft()).equals(join);
+    int fieldCount = join.getRowType().getFieldCount();
+    int fieldStart = isBottomLeft ? 0 : bottomJoin.getRowType().getFieldCount() - fieldCount;
+    int fieldEnd = isBottomLeft ? fieldCount : bottomJoin.getRowType().getFieldCount();
+    pushRexNodes =
+        pushRexNodes.stream()
+            .filter(
+                rex ->
+                    GQLRexUtil.collect(rex, r -> r instanceof RexInputRef).stream()
+                        .noneMatch(
+                            inputRef ->
+                                ((RexInputRef) inputRef).getIndex() < fieldStart
+                                    || ((RexInputRef) inputRef).getIndex() >= fieldEnd))
+            .collect(Collectors.toList());
+    List remainRexNodes = new ArrayList<>();
+    for (RexNode rex : splitRexNodes) {
+      if (!pushRexNodes.contains(rex)) {
+        remainRexNodes.add(rex);
+      }
+    }
+    // If the join is the right input, adjust all referenced indices.
+    if (!isBottomLeft) {
+      pushRexNodes =
+          pushRexNodes.stream()
+              .map(
+                  rex ->
+                      GQLRexUtil.replace(
+                          rex,
+                          rexNode -> {
+                            if (rexNode instanceof RexInputRef) {
+                              int index = ((RexInputRef) rexNode).getIndex();
+                              index -= (bottomJoin.getRowType().getFieldCount() - fieldCount);
+                              return new RexInputRef(index, rexNode.getType());
+                            } else {
+                              return rexNode;
+                            }
+                          }))
+              .collect(Collectors.toList());
+    }
+    pushRexNodes.add(join.getCondition());
+    RexNode equalRexNode = RexUtil.composeConjunction(call.builder().getRexBuilder(), pushRexNodes);
+    LogicalJoin newJoin =
+        join.copy(
+            join.getTraitSet(),
+            equalRexNode,
+            join.getLeft(),
+            join.getRight(),
+            join.getJoinType(),
+            join.isSemiJoinDone());
+    call.transformTo(
+        bottomJoin.copy(
+            bottomJoin.getTraitSet(),
             RexUtil.composeConjunction(call.builder().getRexBuilder(), remainRexNodes),
             isBottomLeft ? newJoin : bottomJoin.getLeft(),
             isBottomLeft ? bottomJoin.getRight() : newJoin,
-            bottomJoin.getJoinType(), bottomJoin.isSemiJoinDone()));
-    }
+            bottomJoin.getJoinType(),
+            bottomJoin.isSemiJoinDone()));
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/PushJoinFilterConditionRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/PushJoinFilterConditionRule.java
index f57f27f12..33af2a726 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/PushJoinFilterConditionRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/PushJoinFilterConditionRule.java
@@ -22,6 +22,7 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.stream.Collectors;
+
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.plan.RelOptUtil;
@@ -35,43 +36,51 @@
 
 public class PushJoinFilterConditionRule extends RelOptRule {
 
-    public static final PushJoinFilterConditionRule INSTANCE = new PushJoinFilterConditionRule();
+  public static final PushJoinFilterConditionRule INSTANCE = new PushJoinFilterConditionRule();
 
-    private PushJoinFilterConditionRule() {
-        super(operand(LogicalFilter.class,
-            operand(LogicalJoin.class, any())));
-    }
+  private PushJoinFilterConditionRule() {
+    super(operand(LogicalFilter.class, operand(LogicalJoin.class, any())));
+  }
 
-    @Override
-    public void onMatch(RelOptRuleCall call) {
-        LogicalFilter filter = call.rel(0);
-        LogicalJoin join = call.rel(1);
-        if (!join.getJoinType().equals(JoinRelType.INNER)) {
-            // Consecutive pushing of conditions for non-INNER joins is not supported.
-            return;
-        }
-        List splitRexNodes = RelOptUtil.conjunctions(filter.getCondition());
-        final List joinRexNodes = new ArrayList<>();
-        final List remainRexNodes = new ArrayList<>();
-        splitRexNodes.stream().map(n -> {
-            if (n.getKind().equals(SqlKind.EQUALS)) {
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    LogicalFilter filter = call.rel(0);
+    LogicalJoin join = call.rel(1);
+    if (!join.getJoinType().equals(JoinRelType.INNER)) {
+      // Consecutive pushing of conditions for non-INNER joins is not supported.
+      return;
+    }
+    List splitRexNodes = RelOptUtil.conjunctions(filter.getCondition());
+    final List joinRexNodes = new ArrayList<>();
+    final List remainRexNodes = new ArrayList<>();
+    splitRexNodes.stream()
+        .map(
+            n -> {
+              if (n.getKind().equals(SqlKind.EQUALS)) {
                 joinRexNodes.add(n);
-            } else {
+              } else {
                 remainRexNodes.add(n);
-            }
-            return n;
-        }).collect(Collectors.toList());
-        joinRexNodes.add(join.getCondition());
-        RexNode equalRexNode = RexUtil.composeConjunction(
-            new RexBuilder(call.builder().getTypeFactory()), joinRexNodes);
-        RexNode remainRexNode = RexUtil.composeConjunction(
-            new RexBuilder(call.builder().getTypeFactory()), remainRexNodes);
-        LogicalJoin newJoin = join.copy(join.getTraitSet(), equalRexNode, join.getLeft(),
-            join.getRight(), join.getJoinType(), join.isSemiJoinDone());
-        if (remainRexNode.isAlwaysTrue()) {
-            call.transformTo(newJoin);
-        } else {
-            call.transformTo(filter.copy(filter.getTraitSet(), newJoin, remainRexNode));
-        }
+              }
+              return n;
+            })
+        .collect(Collectors.toList());
+    joinRexNodes.add(join.getCondition());
+    RexNode equalRexNode =
+        RexUtil.composeConjunction(new RexBuilder(call.builder().getTypeFactory()), joinRexNodes);
+    RexNode remainRexNode =
+        RexUtil.composeConjunction(new RexBuilder(call.builder().getTypeFactory()), remainRexNodes);
+    LogicalJoin newJoin =
+        join.copy(
+            join.getTraitSet(),
+            equalRexNode,
+            join.getLeft(),
+            join.getRight(),
+            join.getJoinType(),
+            join.isSemiJoinDone());
+    if (remainRexNode.isAlwaysTrue()) {
+      call.transformTo(newJoin);
+    } else {
+      call.transformTo(filter.copy(filter.getTraitSet(), newJoin, remainRexNode));
     }
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/TableJoinMatchToGraphMatchRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/TableJoinMatchToGraphMatchRule.java
index 766a713bb..0e41b514c 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/TableJoinMatchToGraphMatchRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/TableJoinMatchToGraphMatchRule.java
@@ -32,59 +32,61 @@
 
 public class TableJoinMatchToGraphMatchRule extends AbstractJoinToGraphRule {
 
-    public static final TableJoinMatchToGraphMatchRule INSTANCE = new TableJoinMatchToGraphMatchRule();
+  public static final TableJoinMatchToGraphMatchRule INSTANCE =
+      new TableJoinMatchToGraphMatchRule();
 
-    private TableJoinMatchToGraphMatchRule() {
-        super(operand(LogicalJoin.class,
-            operand(RelNode.class, any()),
-            operand(RelNode.class, any())));
-    }
+  private TableJoinMatchToGraphMatchRule() {
+    super(operand(LogicalJoin.class, operand(RelNode.class, any()), operand(RelNode.class, any())));
+  }
 
-    @Override
-    public boolean matches(RelOptRuleCall call) {
-        LogicalJoin join = call.rel(0);
-        if (!isSupportJoinType(join.getJoinType())) {
-            // non-INNER joins is not supported.
-            return false;
-        }
-        RelNode leftInput = call.rel(1);
-        RelNode rightInput = call.rel(2);
-        return isSingleChainFromLogicalTableScan(leftInput)
-            && isSingleChainFromGraphMatch(rightInput);
+  @Override
+  public boolean matches(RelOptRuleCall call) {
+    LogicalJoin join = call.rel(0);
+    if (!isSupportJoinType(join.getJoinType())) {
+      // non-INNER joins is not supported.
+      return false;
     }
+    RelNode leftInput = call.rel(1);
+    RelNode rightInput = call.rel(2);
+    return isSingleChainFromLogicalTableScan(leftInput) && isSingleChainFromGraphMatch(rightInput);
+  }
 
-    @Override
-    public void onMatch(RelOptRuleCall call) {
-        RelNode leftInput = call.rel(1);
-        RelNode leftTableScan = leftInput;
-        while (!(leftTableScan instanceof LogicalTableScan)) {
-            leftTableScan = GQLRelUtil.toRel(leftTableScan.getInput(0));
-        }
-        RelNode rightInput = call.rel(2);
-        RelNode graphMatchProject = null;
-        RelNode rightGraphMatch = rightInput;
-        while (rightGraphMatch != null && !(rightGraphMatch instanceof LogicalGraphMatch)) {
-            graphMatchProject = rightGraphMatch;
-            rightGraphMatch = GQLRelUtil.toRel(rightGraphMatch.getInput(0));
-        }
-        RelNode tail = processGraphMatchJoinTable(
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    RelNode leftInput = call.rel(1);
+    RelNode leftTableScan = leftInput;
+    while (!(leftTableScan instanceof LogicalTableScan)) {
+      leftTableScan = GQLRelUtil.toRel(leftTableScan.getInput(0));
+    }
+    RelNode rightInput = call.rel(2);
+    RelNode graphMatchProject = null;
+    RelNode rightGraphMatch = rightInput;
+    while (rightGraphMatch != null && !(rightGraphMatch instanceof LogicalGraphMatch)) {
+      graphMatchProject = rightGraphMatch;
+      rightGraphMatch = GQLRelUtil.toRel(rightGraphMatch.getInput(0));
+    }
+    RelNode tail =
+        processGraphMatchJoinTable(
             call,
             call.rel(0),
             (LogicalGraphMatch) rightGraphMatch,
             (LogicalProject) graphMatchProject,
             (LogicalTableScan) leftTableScan,
-            leftInput, leftTableScan, rightInput, rightGraphMatch, false
-        );
-        if (tail == null) {
-            return;
-        }
-        LogicalJoin join = call.rel(0);
-        // add remain filter.
-        JoinInfo joinInfo = join.analyzeCondition();
-        RexNode remainFilter = joinInfo.getRemaining(join.getCluster().getRexBuilder());
-        if (remainFilter != null && !remainFilter.isAlwaysTrue()) {
-            tail = LogicalFilter.create(tail, remainFilter);
-        }
-        call.transformTo(tail);
+            leftInput,
+            leftTableScan,
+            rightInput,
+            rightGraphMatch,
+            false);
+    if (tail == null) {
+      return;
+    }
+    LogicalJoin join = call.rel(0);
+    // add remain filter.
+    JoinInfo joinInfo = join.analyzeCondition();
+    RexNode remainFilter = joinInfo.getRemaining(join.getCluster().getRexBuilder());
+    if (remainFilter != null && !remainFilter.isAlwaysTrue()) {
+      tail = LogicalFilter.create(tail, remainFilter);
     }
+    call.transformTo(tail);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/TableJoinTableToGraphRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/TableJoinTableToGraphRule.java
index bc5b5a120..36738e6dc 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/TableJoinTableToGraphRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/TableJoinTableToGraphRule.java
@@ -27,6 +27,7 @@
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.rel.RelNode;
@@ -70,343 +71,444 @@
 
 public class TableJoinTableToGraphRule extends AbstractJoinToGraphRule {
 
-    private static final Logger LOGGER = LoggerFactory.getLogger(TableJoinTableToGraphRule.class);
+  private static final Logger LOGGER = LoggerFactory.getLogger(TableJoinTableToGraphRule.class);
+
+  public static final TableJoinTableToGraphRule INSTANCE = new TableJoinTableToGraphRule();
 
-    public static final TableJoinTableToGraphRule INSTANCE = new TableJoinTableToGraphRule();
+  private TableJoinTableToGraphRule() {
+    super(operand(LogicalJoin.class, operand(RelNode.class, any()), operand(RelNode.class, any())));
+  }
 
-    private TableJoinTableToGraphRule() {
-        super(operand(LogicalJoin.class,
-            operand(RelNode.class, any()),
-            operand(RelNode.class, any())));
+  @Override
+  public boolean matches(RelOptRuleCall call) {
+    LogicalJoin join = call.rel(0);
+    if (!isSupportJoinType(join.getJoinType())) {
+      // non-INNER joins is not supported.
+      return false;
     }
+    RelNode leftInput = call.rel(1);
+    RelNode rightInput = call.rel(2);
+    return isSingleChainFromLogicalTableScan(leftInput)
+        && isSingleChainFromLogicalTableScan(rightInput);
+  }
 
-    @Override
-    public boolean matches(RelOptRuleCall call) {
-        LogicalJoin join = call.rel(0);
-        if (!isSupportJoinType(join.getJoinType())) {
-            // non-INNER joins is not supported.
-            return false;
-        }
-        RelNode leftInput = call.rel(1);
-        RelNode rightInput = call.rel(2);
-        return isSingleChainFromLogicalTableScan(leftInput) && isSingleChainFromLogicalTableScan(rightInput);
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    RelNode leftInput = call.rel(1);
+    RelNode leftHead = null;
+    RelNode leftTableScan = leftInput;
+    while (!(leftTableScan instanceof LogicalTableScan)) {
+      leftHead = leftTableScan;
+      leftTableScan = GQLRelUtil.toRel(leftTableScan.getInput(0));
     }
+    RelNode rightInput = call.rel(2);
+    RelNode rightHead = null;
+    RelNode rightTableScan = rightInput;
+    while (!(rightTableScan instanceof LogicalTableScan)) {
+      rightHead = rightTableScan;
+      rightTableScan = GQLRelUtil.toRel(rightTableScan.getInput(0));
+    }
+    GeaFlowTable leftTable = leftTableScan.getTable().unwrap(GeaFlowTable.class);
+    GeaFlowTable rightTable = rightTableScan.getTable().unwrap(GeaFlowTable.class);
 
-    @Override
-    public void onMatch(RelOptRuleCall call) {
-        RelNode leftInput = call.rel(1);
-        RelNode leftHead = null;
-        RelNode leftTableScan = leftInput;
-        while (!(leftTableScan instanceof LogicalTableScan)) {
-            leftHead = leftTableScan;
-            leftTableScan = GQLRelUtil.toRel(leftTableScan.getInput(0));
-        }
-        RelNode rightInput = call.rel(2);
-        RelNode rightHead = null;
-        RelNode rightTableScan = rightInput;
-        while (!(rightTableScan instanceof LogicalTableScan)) {
-            rightHead = rightTableScan;
-            rightTableScan = GQLRelUtil.toRel(rightTableScan.getInput(0));
-        }
-        GeaFlowTable leftTable = leftTableScan.getTable().unwrap(GeaFlowTable.class);
-        GeaFlowTable rightTable = rightTableScan.getTable().unwrap(GeaFlowTable.class);
+    GQLJavaTypeFactory typeFactory = (GQLJavaTypeFactory) call.builder().getTypeFactory();
+    GeaFlowGraph currentGraph = typeFactory.getCurrentGraph();
+    if (!currentGraph.containTable(leftTable)) {
+      if (leftTable instanceof VertexTable || leftTable instanceof EdgeTable) {
+        throw new GeaFlowDSLException(
+            "Unknown graph element: {}, use graph please.", leftTable.getName());
+      }
+      return;
+    }
+    if (!currentGraph.containTable(rightTable)) {
+      if (rightTable instanceof VertexTable || rightTable instanceof EdgeTable) {
+        throw new GeaFlowDSLException(
+            "Unknown graph element: {}, use graph please.", rightTable.getName());
+      }
+      return;
+    }
 
-        GQLJavaTypeFactory typeFactory = (GQLJavaTypeFactory) call.builder().getTypeFactory();
-        GeaFlowGraph currentGraph = typeFactory.getCurrentGraph();
-        if (!currentGraph.containTable(leftTable)) {
-            if (leftTable instanceof VertexTable || leftTable instanceof EdgeTable) {
-                throw new GeaFlowDSLException("Unknown graph element: {}, use graph please.",
-                    leftTable.getName());
-            }
-            return;
+    LogicalJoin join = call.rel(0);
+    GraphJoinType graphJoinType = getJoinType(join);
+    RelNode tail = null;
+    switch (graphJoinType) {
+      case VERTEX_JOIN_EDGE:
+        if (leftTable instanceof VertexTable && rightTable instanceof EdgeTable) {
+          VertexTable vertexTable = (VertexTable) leftTable;
+          EdgeTable edgeTable = (EdgeTable) rightTable;
+          tail =
+              vertexJoinEdgeSrc(
+                  vertexTable, edgeTable, call, true, leftInput, leftHead, rightInput, rightHead);
+        } else if (leftTable instanceof EdgeTable && rightTable instanceof VertexTable) {
+          VertexTable vertexTable = (VertexTable) rightTable;
+          EdgeTable edgeTable = (EdgeTable) leftTable;
+          tail =
+              vertexJoinEdgeSrc(
+                  vertexTable, edgeTable, call, false, leftInput, leftHead, rightInput, rightHead);
         }
-        if (!currentGraph.containTable(rightTable)) {
-            if (rightTable instanceof VertexTable || rightTable instanceof EdgeTable) {
-                throw new GeaFlowDSLException("Unknown graph element: {}, use graph please.",
-                    rightTable.getName());
-            }
-            return;
+        if (tail == null) {
+          return;
         }
-
-        LogicalJoin join = call.rel(0);
-        GraphJoinType graphJoinType = getJoinType(join);
-        RelNode tail = null;
-        switch (graphJoinType) {
-            case VERTEX_JOIN_EDGE:
-                if (leftTable instanceof VertexTable
-                    && rightTable instanceof EdgeTable) {
-                    VertexTable vertexTable = (VertexTable) leftTable;
-                    EdgeTable edgeTable = (EdgeTable) rightTable;
-                    tail = vertexJoinEdgeSrc(vertexTable, edgeTable, call, true,
-                        leftInput, leftHead, rightInput, rightHead);
-                } else if (leftTable instanceof EdgeTable
-                    && rightTable instanceof VertexTable) {
-                    VertexTable vertexTable = (VertexTable) rightTable;
-                    EdgeTable edgeTable = (EdgeTable) leftTable;
-                    tail = vertexJoinEdgeSrc(vertexTable, edgeTable, call, false,
-                        leftInput, leftHead, rightInput, rightHead);
-                }
-                if (tail == null) {
-                    return;
-                }
-                break;
-            case EDGE_JOIN_VERTEX:
-                if (leftTable instanceof VertexTable
-                    && rightTable instanceof EdgeTable) {
-                    VertexTable vertexTable = (VertexTable) leftTable;
-                    EdgeTable edgeTable = (EdgeTable) rightTable;
-                    tail = edgeTargetJoinVertex(vertexTable, edgeTable, call, true,
-                        leftInput, leftHead, rightInput, rightHead);
-                } else if (leftTable instanceof EdgeTable
-                    && rightTable instanceof VertexTable) {
-                    VertexTable vertexTable = (VertexTable) rightTable;
-                    EdgeTable edgeTable = (EdgeTable) leftTable;
-                    tail = edgeTargetJoinVertex(vertexTable, edgeTable, call, false,
-                        leftInput, leftHead, rightInput, rightHead);
-                }
-                if (tail == null) {
-                    return;
-                }
-                break;
-            default:
+        break;
+      case EDGE_JOIN_VERTEX:
+        if (leftTable instanceof VertexTable && rightTable instanceof EdgeTable) {
+          VertexTable vertexTable = (VertexTable) leftTable;
+          EdgeTable edgeTable = (EdgeTable) rightTable;
+          tail =
+              edgeTargetJoinVertex(
+                  vertexTable, edgeTable, call, true, leftInput, leftHead, rightInput, rightHead);
+        } else if (leftTable instanceof EdgeTable && rightTable instanceof VertexTable) {
+          VertexTable vertexTable = (VertexTable) rightTable;
+          EdgeTable edgeTable = (EdgeTable) leftTable;
+          tail =
+              edgeTargetJoinVertex(
+                  vertexTable, edgeTable, call, false, leftInput, leftHead, rightInput, rightHead);
         }
         if (tail == null) {
-            return;
-        }
-        // add remain filter.
-        JoinInfo joinInfo = join.analyzeCondition();
-        RexNode remainFilter = joinInfo.getRemaining(join.getCluster().getRexBuilder());
-        if (remainFilter != null && !remainFilter.isAlwaysTrue()) {
-            tail = LogicalFilter.create(tail, remainFilter);
+          return;
         }
-        call.transformTo(tail);
+        break;
+      default:
     }
-
-    private RelNode vertexJoinEdgeSrc(VertexTable vertexTable,
-                                      EdgeTable edgeTable,
-                                      RelOptRuleCall call,
-                                      boolean isLeftTableVertex,
-                                      RelNode leftInput, RelNode leftHead,
-                                      RelNode rightInput, RelNode rightHead) {
-        return vertexJoinEdge(vertexTable, edgeTable, call, EdgeDirection.OUT, isLeftTableVertex,
-            leftInput, leftHead, rightInput, rightHead);
+    if (tail == null) {
+      return;
     }
-
-    private RelNode edgeTargetJoinVertex(VertexTable vertexTable,
-                                         EdgeTable edgeTable,
-                                         RelOptRuleCall call,
-                                         boolean isLeftTableVertex,
-                                         RelNode leftInput, RelNode leftHead,
-                                         RelNode rightInput, RelNode rightHead) {
-        return vertexJoinEdge(vertexTable, edgeTable, call, EdgeDirection.IN, isLeftTableVertex,
-            leftInput, leftHead, rightInput, rightHead);
+    // add remain filter.
+    JoinInfo joinInfo = join.analyzeCondition();
+    RexNode remainFilter = joinInfo.getRemaining(join.getCluster().getRexBuilder());
+    if (remainFilter != null && !remainFilter.isAlwaysTrue()) {
+      tail = LogicalFilter.create(tail, remainFilter);
     }
+    call.transformTo(tail);
+  }
 
-    private RelNode vertexJoinEdge(VertexTable vertexTable,
-                                   EdgeTable edgeTable,
-                                   RelOptRuleCall call,
-                                   EdgeDirection direction,
-                                   boolean isLeftTableVertex,
-                                   RelNode leftInput, RelNode leftHead,
-                                   RelNode rightInput, RelNode rightHead) {
-        LogicalJoin join = call.rel(0);
-        RelOptCluster cluster = join.getCluster();
-        RelDataType vertexRelType = vertexTable.getRowType(call.builder().getTypeFactory());
-        PathRecordType pathRecordType = PathRecordType.EMPTY;
-        String nodeName = vertexTable.getName();
-        RelDataType edgeRelType = edgeTable.getRowType(call.builder().getTypeFactory());
-        String edgeName = edgeTable.getName();
-
-        VertexMatch vertexMatch;
-        EdgeMatch edgeMatch;
-        IMatchNode matchNode;
-        boolean swapSrcTargetId;
-        List projects = new ArrayList<>();
-        int leftFieldCount;
-        RexBuilder rexBuilder = call.builder().getRexBuilder();
-        List rexLeftNodeMap = new ArrayList<>();
-        List rexRightNodeMap = new ArrayList<>();
-        if (isLeftTableVertex) {
-            pathRecordType = pathRecordType.addField(nodeName, vertexRelType, false);
-            vertexMatch = VertexMatch.create(cluster, null, nodeName,
-                Collections.singletonList(vertexTable.getName()), vertexRelType, pathRecordType);
-            IMatchNode afterLeft = concatToMatchNode(call.builder(), null, leftInput, leftHead,
-                vertexMatch, rexLeftNodeMap);
-            //Add vertex fields
-            if (rexLeftNodeMap.size() > 0) {
-                projects.addAll(rexLeftNodeMap);
-            } else {
-                RelDataTypeField field = afterLeft.getPathSchema().getField(nodeName, true, false);
-                PathInputRef vertexRef = new PathInputRef(nodeName, field.getIndex(), field.getType());
-                for (int i = 0; i < leftInput.getRowType().getFieldCount(); i++) {
-                    projects.add(rexBuilder.makeFieldAccess(vertexRef, i));
-                }
-            }
-            leftFieldCount = projects.size();
+  private RelNode vertexJoinEdgeSrc(
+      VertexTable vertexTable,
+      EdgeTable edgeTable,
+      RelOptRuleCall call,
+      boolean isLeftTableVertex,
+      RelNode leftInput,
+      RelNode leftHead,
+      RelNode rightInput,
+      RelNode rightHead) {
+    return vertexJoinEdge(
+        vertexTable,
+        edgeTable,
+        call,
+        EdgeDirection.OUT,
+        isLeftTableVertex,
+        leftInput,
+        leftHead,
+        rightInput,
+        rightHead);
+  }
 
-            pathRecordType = afterLeft.getPathSchema();
-            pathRecordType = pathRecordType.addField(edgeName, edgeRelType, true);
-            //When joining vertex and edge in a LEFT JOIN, the vertex is forcibly retained.
-            switch (join.getJoinType()) {
-                case LEFT:
-                    edgeMatch = OptionalEdgeMatch.create(cluster, (SingleMatchNode) afterLeft, edgeName,
-                        Collections.singletonList(edgeTable.getName()), direction, edgeRelType,
-                        pathRecordType);
-                    break;
-                case INNER:
-                    edgeMatch = EdgeMatch.create(cluster, (SingleMatchNode) afterLeft, edgeName,
-                        Collections.singletonList(edgeTable.getName()), direction, edgeRelType,
-                        pathRecordType);
-                    break;
-                case RIGHT:
-                case FULL:
-                default:
-                    throw new GeaFlowDSLException("Illegal join type: {}", join.getJoinType());
-            }
-            swapSrcTargetId = direction.equals(EdgeDirection.IN);
-            IMatchNode afterRight = concatToMatchNode(call.builder(), afterLeft,
-                rightInput, rightHead, edgeMatch, rexRightNodeMap);
-            //Add edge fields
-            if (rexRightNodeMap.size() > 0) {
-                // In the case of converting match out edges to in edges, swap the source and
-                // target id references of the edge.
-                if (swapSrcTargetId) {
-                    rexRightNodeMap = rexRightNodeMap.stream()
-                        .map(rex -> GQLRexUtil.swapReverseEdgeRef(rex, edgeName, rexBuilder))
-                        .collect(Collectors.toList());
-                    swapSrcTargetId = false;
-                }
-                projects.addAll(rexRightNodeMap);
-            } else {
-                RelDataTypeField field = afterRight.getPathSchema().getField(edgeName, true, false);
-                PathInputRef edgeRef = new PathInputRef(edgeName, field.getIndex(), field.getType());
-                for (int i = 0; i < rightInput.getRowType().getFieldCount(); i++) {
-                    projects.add(rexBuilder.makeFieldAccess(edgeRef, i));
-                }
-            }
-            matchNode = afterRight;
-        } else {
-            GQLJavaTypeFactory typeFactory = (GQLJavaTypeFactory) call.builder().getTypeFactory();
-            GeaFlowGraph graph = typeFactory.getCurrentGraph();
-            GraphDescriptor graphDescriptor = graph.getDescriptor();
-            Optional edgeDesc = graphDescriptor.edges.stream().filter(
-                e -> e.type.equals(edgeTable.getName())).findFirst();
-            VertexTable dummyVertex = null;
-            if (edgeDesc.isPresent()) {
-                EdgeDescriptor edgeDescriptor = edgeDesc.get();
-                String dummyNodeType = direction.equals(EdgeDirection.IN)
-                    ? edgeDescriptor.sourceType : edgeDescriptor.targetType;
-                dummyVertex = graph.getVertexTables().stream().filter(
-                    v -> v.getName().equals(dummyNodeType)).findFirst().get();
-            }
-            if (dummyVertex == null) {
-                return null;
-            }
-            String dummyNodeName = dummyVertex.getName();
-            RelDataType dummyVertexRelType = dummyVertex.getRowType(call.builder().getTypeFactory());
-            pathRecordType = pathRecordType.addField(dummyNodeName, dummyVertexRelType, true);
-            VertexMatch dummyVertexMatch = VertexMatch.create(cluster, null, dummyNodeName,
-                Collections.singletonList(dummyVertex.getName()), dummyVertexRelType, pathRecordType);
-            pathRecordType = pathRecordType.addField(edgeName, edgeRelType, true);
-            EdgeDirection reverseDirection = EdgeDirection.reverse(direction);
-            edgeMatch = EdgeMatch.create(cluster, dummyVertexMatch, edgeName,
-                Collections.singletonList(edgeTable.getName()),
-                reverseDirection, edgeRelType, pathRecordType);
-            swapSrcTargetId = reverseDirection.equals(EdgeDirection.IN);
-            IMatchNode afterLeft = concatToMatchNode(call.builder(), null, leftInput, leftHead,
-                edgeMatch, rexLeftNodeMap);
+  private RelNode edgeTargetJoinVertex(
+      VertexTable vertexTable,
+      EdgeTable edgeTable,
+      RelOptRuleCall call,
+      boolean isLeftTableVertex,
+      RelNode leftInput,
+      RelNode leftHead,
+      RelNode rightInput,
+      RelNode rightHead) {
+    return vertexJoinEdge(
+        vertexTable,
+        edgeTable,
+        call,
+        EdgeDirection.IN,
+        isLeftTableVertex,
+        leftInput,
+        leftHead,
+        rightInput,
+        rightHead);
+  }
 
-            //Add edge fields
-            if (rexLeftNodeMap.size() > 0) {
-                // In the case of converting match out edges to in edges, swap the source and
-                // target id references of the edge.
-                if (swapSrcTargetId) {
-                    rexLeftNodeMap = rexLeftNodeMap.stream()
-                        .map(rex -> GQLRexUtil.swapReverseEdgeRef(rex, edgeName, rexBuilder))
-                        .collect(Collectors.toList());
-                    swapSrcTargetId = false;
-                }
-                projects.addAll(rexLeftNodeMap);
-            } else {
-                RelDataTypeField field = afterLeft.getPathSchema().getField(edgeName, true, false);
-                PathInputRef edgeRef = new PathInputRef(edgeName, field.getIndex(), field.getType());
-                for (int i = 0; i < leftInput.getRowType().getFieldCount(); i++) {
-                    projects.add(rexBuilder.makeFieldAccess(edgeRef, i));
-                }
-            }
-            leftFieldCount = projects.size();
+  private RelNode vertexJoinEdge(
+      VertexTable vertexTable,
+      EdgeTable edgeTable,
+      RelOptRuleCall call,
+      EdgeDirection direction,
+      boolean isLeftTableVertex,
+      RelNode leftInput,
+      RelNode leftHead,
+      RelNode rightInput,
+      RelNode rightHead) {
+    LogicalJoin join = call.rel(0);
+    RelOptCluster cluster = join.getCluster();
+    RelDataType vertexRelType = vertexTable.getRowType(call.builder().getTypeFactory());
+    PathRecordType pathRecordType = PathRecordType.EMPTY;
+    String nodeName = vertexTable.getName();
+    RelDataType edgeRelType = edgeTable.getRowType(call.builder().getTypeFactory());
+    String edgeName = edgeTable.getName();
 
-            pathRecordType = afterLeft.getPathSchema();
-            pathRecordType = pathRecordType.addField(nodeName, vertexRelType, false);
-            if (join.getJoinType().equals(JoinRelType.LEFT)) {
-                vertexMatch = OptionalVertexMatch.create(cluster, (SingleMatchNode) afterLeft, nodeName,
-                    Collections.singletonList(vertexTable.getName()), vertexRelType, pathRecordType);
-            } else {
-                vertexMatch = VertexMatch.create(cluster, (SingleMatchNode) afterLeft, nodeName,
-                    Collections.singletonList(vertexTable.getName()), vertexRelType, pathRecordType);
-            }
+    VertexMatch vertexMatch;
+    EdgeMatch edgeMatch;
+    IMatchNode matchNode;
+    boolean swapSrcTargetId;
+    List projects = new ArrayList<>();
+    int leftFieldCount;
+    RexBuilder rexBuilder = call.builder().getRexBuilder();
+    List rexLeftNodeMap = new ArrayList<>();
+    List rexRightNodeMap = new ArrayList<>();
+    if (isLeftTableVertex) {
+      pathRecordType = pathRecordType.addField(nodeName, vertexRelType, false);
+      vertexMatch =
+          VertexMatch.create(
+              cluster,
+              null,
+              nodeName,
+              Collections.singletonList(vertexTable.getName()),
+              vertexRelType,
+              pathRecordType);
+      IMatchNode afterLeft =
+          concatToMatchNode(call.builder(), null, leftInput, leftHead, vertexMatch, rexLeftNodeMap);
+      // Add vertex fields
+      if (rexLeftNodeMap.size() > 0) {
+        projects.addAll(rexLeftNodeMap);
+      } else {
+        RelDataTypeField field = afterLeft.getPathSchema().getField(nodeName, true, false);
+        PathInputRef vertexRef = new PathInputRef(nodeName, field.getIndex(), field.getType());
+        for (int i = 0; i < leftInput.getRowType().getFieldCount(); i++) {
+          projects.add(rexBuilder.makeFieldAccess(vertexRef, i));
+        }
+      }
+      leftFieldCount = projects.size();
 
-            IMatchNode afterRight = concatToMatchNode(call.builder(), afterLeft, rightInput, rightHead,
-                vertexMatch, rexRightNodeMap);
-            //Add vertex fields
-            if (rexRightNodeMap.size() > 0) {
-                projects.addAll(rexRightNodeMap);
-            } else {
-                RelDataTypeField field = afterRight.getPathSchema().getField(nodeName, true, false);
-                PathInputRef vertexRef = new PathInputRef(nodeName, field.getIndex(), field.getType());
-                for (int i = 0; i < rightInput.getRowType().getFieldCount(); i++) {
-                    projects.add(rexBuilder.makeFieldAccess(vertexRef, i));
-                }
-            }
-            matchNode = afterRight;
+      pathRecordType = afterLeft.getPathSchema();
+      pathRecordType = pathRecordType.addField(edgeName, edgeRelType, true);
+      // When joining vertex and edge in a LEFT JOIN, the vertex is forcibly retained.
+      switch (join.getJoinType()) {
+        case LEFT:
+          edgeMatch =
+              OptionalEdgeMatch.create(
+                  cluster,
+                  (SingleMatchNode) afterLeft,
+                  edgeName,
+                  Collections.singletonList(edgeTable.getName()),
+                  direction,
+                  edgeRelType,
+                  pathRecordType);
+          break;
+        case INNER:
+          edgeMatch =
+              EdgeMatch.create(
+                  cluster,
+                  (SingleMatchNode) afterLeft,
+                  edgeName,
+                  Collections.singletonList(edgeTable.getName()),
+                  direction,
+                  edgeRelType,
+                  pathRecordType);
+          break;
+        case RIGHT:
+        case FULL:
+        default:
+          throw new GeaFlowDSLException("Illegal join type: {}", join.getJoinType());
+      }
+      swapSrcTargetId = direction.equals(EdgeDirection.IN);
+      IMatchNode afterRight =
+          concatToMatchNode(
+              call.builder(), afterLeft, rightInput, rightHead, edgeMatch, rexRightNodeMap);
+      // Add edge fields
+      if (rexRightNodeMap.size() > 0) {
+        // In the case of converting match out edges to in edges, swap the source and
+        // target id references of the edge.
+        if (swapSrcTargetId) {
+          rexRightNodeMap =
+              rexRightNodeMap.stream()
+                  .map(rex -> GQLRexUtil.swapReverseEdgeRef(rex, edgeName, rexBuilder))
+                  .collect(Collectors.toList());
+          swapSrcTargetId = false;
+        }
+        projects.addAll(rexRightNodeMap);
+      } else {
+        RelDataTypeField field = afterRight.getPathSchema().getField(edgeName, true, false);
+        PathInputRef edgeRef = new PathInputRef(edgeName, field.getIndex(), field.getType());
+        for (int i = 0; i < rightInput.getRowType().getFieldCount(); i++) {
+          projects.add(rexBuilder.makeFieldAccess(edgeRef, i));
         }
+      }
+      matchNode = afterRight;
+    } else {
+      GQLJavaTypeFactory typeFactory = (GQLJavaTypeFactory) call.builder().getTypeFactory();
+      GeaFlowGraph graph = typeFactory.getCurrentGraph();
+      GraphDescriptor graphDescriptor = graph.getDescriptor();
+      Optional edgeDesc =
+          graphDescriptor.edges.stream()
+              .filter(e -> e.type.equals(edgeTable.getName()))
+              .findFirst();
+      VertexTable dummyVertex = null;
+      if (edgeDesc.isPresent()) {
+        EdgeDescriptor edgeDescriptor = edgeDesc.get();
+        String dummyNodeType =
+            direction.equals(EdgeDirection.IN)
+                ? edgeDescriptor.sourceType
+                : edgeDescriptor.targetType;
+        dummyVertex =
+            graph.getVertexTables().stream()
+                .filter(v -> v.getName().equals(dummyNodeType))
+                .findFirst()
+                .get();
+      }
+      if (dummyVertex == null) {
+        return null;
+      }
+      String dummyNodeName = dummyVertex.getName();
+      RelDataType dummyVertexRelType = dummyVertex.getRowType(call.builder().getTypeFactory());
+      pathRecordType = pathRecordType.addField(dummyNodeName, dummyVertexRelType, true);
+      VertexMatch dummyVertexMatch =
+          VertexMatch.create(
+              cluster,
+              null,
+              dummyNodeName,
+              Collections.singletonList(dummyVertex.getName()),
+              dummyVertexRelType,
+              pathRecordType);
+      pathRecordType = pathRecordType.addField(edgeName, edgeRelType, true);
+      EdgeDirection reverseDirection = EdgeDirection.reverse(direction);
+      edgeMatch =
+          EdgeMatch.create(
+              cluster,
+              dummyVertexMatch,
+              edgeName,
+              Collections.singletonList(edgeTable.getName()),
+              reverseDirection,
+              edgeRelType,
+              pathRecordType);
+      swapSrcTargetId = reverseDirection.equals(EdgeDirection.IN);
+      IMatchNode afterLeft =
+          concatToMatchNode(call.builder(), null, leftInput, leftHead, edgeMatch, rexLeftNodeMap);
 
-        //In the case of reverse matching in the IN direction, the positions of the source
-        // vertex and the destination vertex are swapped.
+      // Add edge fields
+      if (rexLeftNodeMap.size() > 0) {
+        // In the case of converting match out edges to in edges, swap the source and
+        // target id references of the edge.
         if (swapSrcTargetId) {
-            int edgeSrcIdIndex = edgeRelType.getFieldList().stream().filter(
-                f -> f.getType() instanceof MetaFieldType
-                        && ((MetaFieldType) f.getType()).getMetaField().equals(MetaField.EDGE_SRC_ID))
-                .collect(Collectors.toList()).get(0).getIndex();
-            int edgeTargetIdIndex = edgeRelType.getFieldList().stream().filter(
-                f -> f.getType() instanceof MetaFieldType
-                        && ((MetaFieldType) f.getType()).getMetaField().equals(MetaField.EDGE_TARGET_ID))
-                .collect(Collectors.toList()).get(0).getIndex();
-            int baseOffset = isLeftTableVertex ? leftFieldCount : 0;
-            Collections.swap(projects, baseOffset + edgeSrcIdIndex, baseOffset + edgeTargetIdIndex);
+          rexLeftNodeMap =
+              rexLeftNodeMap.stream()
+                  .map(rex -> GQLRexUtil.swapReverseEdgeRef(rex, edgeName, rexBuilder))
+                  .collect(Collectors.toList());
+          swapSrcTargetId = false;
+        }
+        projects.addAll(rexLeftNodeMap);
+      } else {
+        RelDataTypeField field = afterLeft.getPathSchema().getField(edgeName, true, false);
+        PathInputRef edgeRef = new PathInputRef(edgeName, field.getIndex(), field.getType());
+        for (int i = 0; i < leftInput.getRowType().getFieldCount(); i++) {
+          projects.add(rexBuilder.makeFieldAccess(edgeRef, i));
         }
+      }
+      leftFieldCount = projects.size();
 
-        GQLJavaTypeFactory typeFactory = (GQLJavaTypeFactory) call.builder().getTypeFactory();
-        GeaFlowGraph graph = typeFactory.getCurrentGraph();
-        LogicalGraphScan graphScan = LogicalGraphScan.create(cluster, graph);
-        LogicalGraphMatch graphMatch = LogicalGraphMatch.create(cluster, graphScan,
-            matchNode, matchNode.getPathSchema());
+      pathRecordType = afterLeft.getPathSchema();
+      pathRecordType = pathRecordType.addField(nodeName, vertexRelType, false);
+      if (join.getJoinType().equals(JoinRelType.LEFT)) {
+        vertexMatch =
+            OptionalVertexMatch.create(
+                cluster,
+                (SingleMatchNode) afterLeft,
+                nodeName,
+                Collections.singletonList(vertexTable.getName()),
+                vertexRelType,
+                pathRecordType);
+      } else {
+        vertexMatch =
+            VertexMatch.create(
+                cluster,
+                (SingleMatchNode) afterLeft,
+                nodeName,
+                Collections.singletonList(vertexTable.getName()),
+                vertexRelType,
+                pathRecordType);
+      }
 
-        List matchTypeFields = new ArrayList<>();
-        List newFieldNames = this.generateFieldNames("f", projects.size(), new HashSet<>());
-        for (int i = 0; i < projects.size(); i++) {
-            matchTypeFields.add(new RelDataTypeFieldImpl(newFieldNames.get(i), i ,
-                projects.get(i).getType()));
+      IMatchNode afterRight =
+          concatToMatchNode(
+              call.builder(), afterLeft, rightInput, rightHead, vertexMatch, rexRightNodeMap);
+      // Add vertex fields
+      if (rexRightNodeMap.size() > 0) {
+        projects.addAll(rexRightNodeMap);
+      } else {
+        RelDataTypeField field = afterRight.getPathSchema().getField(nodeName, true, false);
+        PathInputRef vertexRef = new PathInputRef(nodeName, field.getIndex(), field.getType());
+        for (int i = 0; i < rightInput.getRowType().getFieldCount(); i++) {
+          projects.add(rexBuilder.makeFieldAccess(vertexRef, i));
         }
-        RelNode tail = LogicalProject.create(graphMatch, projects, new RelRecordType(matchTypeFields));
+      }
+      matchNode = afterRight;
+    }
 
-        // Complete the Join projection.
-        final RelNode finalTail = tail;
-        List joinProjects = IntStream.range(0, projects.size())
-            .mapToObj(i -> rexBuilder.makeInputRef(finalTail, i)).collect(Collectors.toList());
-        AtomicInteger offset = new AtomicInteger();
-        // Make the project type nullable the same as the output type of the join.
-        joinProjects = joinProjects.stream().map(prj -> {
-            int i = offset.getAndIncrement();
-            boolean joinFieldNullable = join.getRowType().getFieldList().get(i).getType().isNullable();
-            if ((prj.getType().isNullable() && !joinFieldNullable)
-                || (!prj.getType().isNullable() && joinFieldNullable)) {
-                RelDataType type = rexBuilder.getTypeFactory().createTypeWithNullability(prj.getType(), joinFieldNullable);
-                return rexBuilder.makeCast(type, prj);
-            }
-            return prj;
-        }).collect(Collectors.toList());
-        tail = LogicalProject.create(tail, joinProjects, join.getRowType());
-        return tail;
+    // In the case of reverse matching in the IN direction, the positions of the source
+    // vertex and the destination vertex are swapped.
+    if (swapSrcTargetId) {
+      int edgeSrcIdIndex =
+          edgeRelType.getFieldList().stream()
+              .filter(
+                  f ->
+                      f.getType() instanceof MetaFieldType
+                          && ((MetaFieldType) f.getType())
+                              .getMetaField()
+                              .equals(MetaField.EDGE_SRC_ID))
+              .collect(Collectors.toList())
+              .get(0)
+              .getIndex();
+      int edgeTargetIdIndex =
+          edgeRelType.getFieldList().stream()
+              .filter(
+                  f ->
+                      f.getType() instanceof MetaFieldType
+                          && ((MetaFieldType) f.getType())
+                              .getMetaField()
+                              .equals(MetaField.EDGE_TARGET_ID))
+              .collect(Collectors.toList())
+              .get(0)
+              .getIndex();
+      int baseOffset = isLeftTableVertex ? leftFieldCount : 0;
+      Collections.swap(projects, baseOffset + edgeSrcIdIndex, baseOffset + edgeTargetIdIndex);
     }
+
+    GQLJavaTypeFactory typeFactory = (GQLJavaTypeFactory) call.builder().getTypeFactory();
+    GeaFlowGraph graph = typeFactory.getCurrentGraph();
+    LogicalGraphScan graphScan = LogicalGraphScan.create(cluster, graph);
+    LogicalGraphMatch graphMatch =
+        LogicalGraphMatch.create(cluster, graphScan, matchNode, matchNode.getPathSchema());
+
+    List matchTypeFields = new ArrayList<>();
+    List newFieldNames = this.generateFieldNames("f", projects.size(), new HashSet<>());
+    for (int i = 0; i < projects.size(); i++) {
+      matchTypeFields.add(
+          new RelDataTypeFieldImpl(newFieldNames.get(i), i, projects.get(i).getType()));
+    }
+    RelNode tail = LogicalProject.create(graphMatch, projects, new RelRecordType(matchTypeFields));
+
+    // Complete the Join projection.
+    final RelNode finalTail = tail;
+    List joinProjects =
+        IntStream.range(0, projects.size())
+            .mapToObj(i -> rexBuilder.makeInputRef(finalTail, i))
+            .collect(Collectors.toList());
+    AtomicInteger offset = new AtomicInteger();
+    // Make the project type nullable the same as the output type of the join.
+    joinProjects =
+        joinProjects.stream()
+            .map(
+                prj -> {
+                  int i = offset.getAndIncrement();
+                  boolean joinFieldNullable =
+                      join.getRowType().getFieldList().get(i).getType().isNullable();
+                  if ((prj.getType().isNullable() && !joinFieldNullable)
+                      || (!prj.getType().isNullable() && joinFieldNullable)) {
+                    RelDataType type =
+                        rexBuilder
+                            .getTypeFactory()
+                            .createTypeWithNullability(prj.getType(), joinFieldNullable);
+                    return rexBuilder.makeCast(type, prj);
+                  }
+                  return prj;
+                })
+            .collect(Collectors.toList());
+    tail = LogicalProject.create(tail, joinProjects, join.getRowType());
+    return tail;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/TableScanToGraphRule.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/TableScanToGraphRule.java
index a074507a3..62e1218b7 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/TableScanToGraphRule.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/optimize/rule/TableScanToGraphRule.java
@@ -24,6 +24,7 @@
 import java.util.HashSet;
 import java.util.List;
 import java.util.Optional;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.rel.RelNode;
@@ -55,123 +56,146 @@
 
 public class TableScanToGraphRule extends AbstractJoinToGraphRule {
 
-    public static final TableScanToGraphRule INSTANCE = new TableScanToGraphRule();
+  public static final TableScanToGraphRule INSTANCE = new TableScanToGraphRule();
 
-    public TableScanToGraphRule() {
-        super(operand(LogicalTableScan.class, any()));
-    }
+  public TableScanToGraphRule() {
+    super(operand(LogicalTableScan.class, any()));
+  }
 
-    @Override
-    public boolean matches(RelOptRuleCall call) {
-        return true;
-    }
+  @Override
+  public boolean matches(RelOptRuleCall call) {
+    return true;
+  }
 
-    @Override
-    public void onMatch(RelOptRuleCall call) {
-        RelNode leftInput = call.rel(0);
-        RelNode leftHead = null;
-        RelNode leftTableScan = leftInput;
-        while (!(leftTableScan instanceof LogicalTableScan)) {
-            leftHead = leftTableScan;
-            leftTableScan = GQLRelUtil.toRel(leftTableScan.getInput(0));
-        }
-        GeaFlowTable leftTable = leftTableScan.getTable().unwrap(GeaFlowTable.class);
-        GQLJavaTypeFactory typeFactory = (GQLJavaTypeFactory) call.builder().getTypeFactory();
-        GeaFlowGraph graph = typeFactory.getCurrentGraph();
-        if (graph == null) {
-            return;
-        }
-        if (!graph.containTable(leftTable)) {
-            if (leftTable instanceof VertexTable || leftTable instanceof EdgeTable) {
-                throw new GeaFlowDSLException("Unknown graph element: {}, use graph please.",
-                    leftTable.getName());
-            }
-            return;
-        }
-        RelNode tail;
-        RelOptCluster cluster = leftTableScan.getCluster();
-        LogicalGraphScan graphScan = LogicalGraphScan.create(cluster, graph);
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    RelNode leftInput = call.rel(0);
+    RelNode leftHead = null;
+    RelNode leftTableScan = leftInput;
+    while (!(leftTableScan instanceof LogicalTableScan)) {
+      leftHead = leftTableScan;
+      leftTableScan = GQLRelUtil.toRel(leftTableScan.getInput(0));
+    }
+    GeaFlowTable leftTable = leftTableScan.getTable().unwrap(GeaFlowTable.class);
+    GQLJavaTypeFactory typeFactory = (GQLJavaTypeFactory) call.builder().getTypeFactory();
+    GeaFlowGraph graph = typeFactory.getCurrentGraph();
+    if (graph == null) {
+      return;
+    }
+    if (!graph.containTable(leftTable)) {
+      if (leftTable instanceof VertexTable || leftTable instanceof EdgeTable) {
+        throw new GeaFlowDSLException(
+            "Unknown graph element: {}, use graph please.", leftTable.getName());
+      }
+      return;
+    }
+    RelNode tail;
+    RelOptCluster cluster = leftTableScan.getCluster();
+    LogicalGraphScan graphScan = LogicalGraphScan.create(cluster, graph);
 
-        VertexMatch vertexMatch;
-        List projects = new ArrayList<>();
-        RexBuilder rexBuilder = call.builder().getRexBuilder();
-        List rexLeftNodeMap = new ArrayList<>();
-        PathRecordType pathRecordType = PathRecordType.EMPTY;
-        IMatchNode afterLeft;
-        if (leftTable instanceof VertexTable) {
-            VertexTable vertexTable = (VertexTable) leftTable;
-            RelDataType vertexRelType = vertexTable.getRowType(call.builder().getTypeFactory());
-            String nodeName = vertexTable.getName();
+    VertexMatch vertexMatch;
+    List projects = new ArrayList<>();
+    RexBuilder rexBuilder = call.builder().getRexBuilder();
+    List rexLeftNodeMap = new ArrayList<>();
+    PathRecordType pathRecordType = PathRecordType.EMPTY;
+    IMatchNode afterLeft;
+    if (leftTable instanceof VertexTable) {
+      VertexTable vertexTable = (VertexTable) leftTable;
+      RelDataType vertexRelType = vertexTable.getRowType(call.builder().getTypeFactory());
+      String nodeName = vertexTable.getName();
 
-            pathRecordType = pathRecordType.addField(nodeName, vertexRelType, false);
-            vertexMatch = VertexMatch.create(cluster, null, nodeName,
-                Collections.singletonList(vertexTable.getName()), vertexRelType, pathRecordType);
-            afterLeft = concatToMatchNode(call.builder(), null, leftInput, leftHead,
-                vertexMatch, rexLeftNodeMap);
-            //Add vertex fields
-            if (!rexLeftNodeMap.isEmpty()) {
-                projects.addAll(rexLeftNodeMap);
-            } else {
-                RelDataTypeField field = afterLeft.getPathSchema().getField(nodeName, true, false);
-                PathInputRef vertexRef = new PathInputRef(nodeName, field.getIndex(), field.getType());
-                for (int i = 0; i < leftInput.getRowType().getFieldCount(); i++) {
-                    projects.add(rexBuilder.makeFieldAccess(vertexRef, i));
-                }
-            }
-            tail = afterLeft;
-        } else {
-            EdgeTable edgeTable = (EdgeTable) leftTable;
-            GraphDescriptor graphDescriptor = graph.getDescriptor();
-            Optional edgeDesc = graphDescriptor.edges.stream().filter(
-                e -> e.type.equals(edgeTable.getName())).findFirst();
-            VertexTable dummyVertex = null;
-            if (edgeDesc.isPresent()) {
-                EdgeDescriptor edgeDescriptor = edgeDesc.get();
-                String dummyNodeType = edgeDescriptor.sourceType;
-                dummyVertex = graph.getVertexTables().stream().filter(
-                    v -> v.getName().equals(dummyNodeType)).findFirst().get();
-            }
-            if (dummyVertex == null) {
-                return;
-            }
-            String dummyNodeName = dummyVertex.getName();
-            RelDataType dummyVertexRelType = dummyVertex.getRowType(call.builder().getTypeFactory());
-            pathRecordType = pathRecordType.addField(dummyNodeName, dummyVertexRelType, true);
-            VertexMatch dummyVertexMatch = VertexMatch.create(cluster, null, dummyNodeName,
-                Collections.singletonList(dummyVertex.getName()), dummyVertexRelType, pathRecordType);
-            RelDataType edgeRelType = edgeTable.getRowType(call.builder().getTypeFactory());
-            String edgeName = edgeTable.getName();
-            pathRecordType = pathRecordType.addField(edgeName, edgeRelType, true);
-            IMatchNode edgeMatch = EdgeMatch.create(cluster, dummyVertexMatch, edgeName,
-                Collections.singletonList(edgeTable.getName()),
-                EdgeDirection.OUT, edgeRelType, pathRecordType);
-            afterLeft = concatToMatchNode(call.builder(), null, leftInput, leftHead,
-                edgeMatch, rexLeftNodeMap);
-            tail = afterLeft;
-            //Add edge fields
-            if (!rexLeftNodeMap.isEmpty()) {
-                projects.addAll(rexLeftNodeMap);
-            } else {
-                RelDataTypeField field = afterLeft.getPathSchema().getField(edgeName, true, false);
-                PathInputRef edgeRef = new PathInputRef(edgeName, field.getIndex(), field.getType());
-                for (int i = 0; i < leftInput.getRowType().getFieldCount(); i++) {
-                    projects.add(rexBuilder.makeFieldAccess(edgeRef, i));
-                }
-            }
+      pathRecordType = pathRecordType.addField(nodeName, vertexRelType, false);
+      vertexMatch =
+          VertexMatch.create(
+              cluster,
+              null,
+              nodeName,
+              Collections.singletonList(vertexTable.getName()),
+              vertexRelType,
+              pathRecordType);
+      afterLeft =
+          concatToMatchNode(call.builder(), null, leftInput, leftHead, vertexMatch, rexLeftNodeMap);
+      // Add vertex fields
+      if (!rexLeftNodeMap.isEmpty()) {
+        projects.addAll(rexLeftNodeMap);
+      } else {
+        RelDataTypeField field = afterLeft.getPathSchema().getField(nodeName, true, false);
+        PathInputRef vertexRef = new PathInputRef(nodeName, field.getIndex(), field.getType());
+        for (int i = 0; i < leftInput.getRowType().getFieldCount(); i++) {
+          projects.add(rexBuilder.makeFieldAccess(vertexRef, i));
         }
-        if (tail == null) {
-            return;
+      }
+      tail = afterLeft;
+    } else {
+      EdgeTable edgeTable = (EdgeTable) leftTable;
+      GraphDescriptor graphDescriptor = graph.getDescriptor();
+      Optional edgeDesc =
+          graphDescriptor.edges.stream()
+              .filter(e -> e.type.equals(edgeTable.getName()))
+              .findFirst();
+      VertexTable dummyVertex = null;
+      if (edgeDesc.isPresent()) {
+        EdgeDescriptor edgeDescriptor = edgeDesc.get();
+        String dummyNodeType = edgeDescriptor.sourceType;
+        dummyVertex =
+            graph.getVertexTables().stream()
+                .filter(v -> v.getName().equals(dummyNodeType))
+                .findFirst()
+                .get();
+      }
+      if (dummyVertex == null) {
+        return;
+      }
+      String dummyNodeName = dummyVertex.getName();
+      RelDataType dummyVertexRelType = dummyVertex.getRowType(call.builder().getTypeFactory());
+      pathRecordType = pathRecordType.addField(dummyNodeName, dummyVertexRelType, true);
+      VertexMatch dummyVertexMatch =
+          VertexMatch.create(
+              cluster,
+              null,
+              dummyNodeName,
+              Collections.singletonList(dummyVertex.getName()),
+              dummyVertexRelType,
+              pathRecordType);
+      RelDataType edgeRelType = edgeTable.getRowType(call.builder().getTypeFactory());
+      String edgeName = edgeTable.getName();
+      pathRecordType = pathRecordType.addField(edgeName, edgeRelType, true);
+      IMatchNode edgeMatch =
+          EdgeMatch.create(
+              cluster,
+              dummyVertexMatch,
+              edgeName,
+              Collections.singletonList(edgeTable.getName()),
+              EdgeDirection.OUT,
+              edgeRelType,
+              pathRecordType);
+      afterLeft =
+          concatToMatchNode(call.builder(), null, leftInput, leftHead, edgeMatch, rexLeftNodeMap);
+      tail = afterLeft;
+      // Add edge fields
+      if (!rexLeftNodeMap.isEmpty()) {
+        projects.addAll(rexLeftNodeMap);
+      } else {
+        RelDataTypeField field = afterLeft.getPathSchema().getField(edgeName, true, false);
+        PathInputRef edgeRef = new PathInputRef(edgeName, field.getIndex(), field.getType());
+        for (int i = 0; i < leftInput.getRowType().getFieldCount(); i++) {
+          projects.add(rexBuilder.makeFieldAccess(edgeRef, i));
         }
-        LogicalGraphMatch graphMatch = LogicalGraphMatch.create(cluster, graphScan,
-            afterLeft, afterLeft.getPathSchema());
+      }
+    }
+    if (tail == null) {
+      return;
+    }
+    LogicalGraphMatch graphMatch =
+        LogicalGraphMatch.create(cluster, graphScan, afterLeft, afterLeft.getPathSchema());
 
-        List matchTypeFields = new ArrayList<>();
-        List newFieldNames = this.generateFieldNames("f", projects.size(), new HashSet<>());
-        for (int i = 0; i < projects.size(); i++) {
-            matchTypeFields.add(new RelDataTypeFieldImpl(newFieldNames.get(i), i,
-                projects.get(i).getType()));
-        }
-        tail = LogicalProject.create(graphMatch, projects, new RelRecordType(matchTypeFields));
-        call.transformTo(tail);
+    List matchTypeFields = new ArrayList<>();
+    List newFieldNames = this.generateFieldNames("f", projects.size(), new HashSet<>());
+    for (int i = 0; i < projects.size(); i++) {
+      matchTypeFields.add(
+          new RelDataTypeFieldImpl(newFieldNames.get(i), i, projects.get(i).getType()));
     }
+    tail = LogicalProject.create(graphMatch, projects, new RelRecordType(matchTypeFields));
+    call.transformTo(tail);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLContext.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLContext.java
index 3f69760d0..06357391d 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLContext.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLContext.java
@@ -19,10 +19,6 @@
 
 package org.apache.geaflow.dsl.planner;
 
-import com.google.common.base.Preconditions;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collection;
@@ -34,6 +30,7 @@
 import java.util.Properties;
 import java.util.Set;
 import java.util.stream.Collectors;
+
 import org.apache.calcite.config.CalciteConnectionConfig;
 import org.apache.calcite.config.CalciteConnectionConfigImpl;
 import org.apache.calcite.config.CalciteConnectionProperty;
@@ -101,70 +98,76 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+
 public class GQLContext {
 
-    static {
-        try {
-            Class.forName("org.apache.calcite.jdbc.Driver");
-        } catch (ClassNotFoundException e) {
-            throw new RuntimeException(e);
-        }
+  static {
+    try {
+      Class.forName("org.apache.calcite.jdbc.Driver");
+    } catch (ClassNotFoundException e) {
+      throw new RuntimeException(e);
     }
+  }
 
-    private static final Logger LOG = LoggerFactory.getLogger(GQLContext.class);
+  private static final Logger LOG = LoggerFactory.getLogger(GQLContext.class);
 
-    private final Catalog catalog;
+  private final Catalog catalog;
 
-    private final SchemaPlus defaultSchema;
+  private final SchemaPlus defaultSchema;
 
-    private static final GQLConformance CONFORMANCE = GQLConformance.INSTANCE;
+  private static final GQLConformance CONFORMANCE = GQLConformance.INSTANCE;
 
-    private final GQLRelDataTypeSystem typeSystem = new GQLRelDataTypeSystem();
+  private final GQLRelDataTypeSystem typeSystem = new GQLRelDataTypeSystem();
 
-    private final GQLJavaTypeFactory typeFactory = new GQLJavaTypeFactory(typeSystem);
+  private final GQLJavaTypeFactory typeFactory = new GQLJavaTypeFactory(typeSystem);
 
-    private final GQLOperatorTable sqlOperatorTable;
+  private final GQLOperatorTable sqlOperatorTable;
 
-    private final FrameworkConfig frameworkConfig;
+  private final FrameworkConfig frameworkConfig;
 
-    private final GQLRelBuilder relBuilder;
+  private final GQLRelBuilder relBuilder;
 
-    private final GQLValidatorImpl validator;
+  private final GQLValidatorImpl validator;
 
-    private final SqlRexConvertletTable convertLetTable;
+  private final SqlRexConvertletTable convertLetTable;
 
-    private String currentInstance;
+  private String currentInstance;
 
-    private String currentGraph;
+  private String currentGraph;
 
-    private final Set validatedRelNode;
+  private final Set validatedRelNode;
 
-    private static final Map shortKeyMapping = new HashMap<>();
+  private static final Map shortKeyMapping = new HashMap<>();
 
-    static {
-        shortKeyMapping.put("storeType", DSLConfigKeys.GEAFLOW_DSL_STORE_TYPE.getKey());
-        shortKeyMapping.put("shardCount", DSLConfigKeys.GEAFLOW_DSL_STORE_SHARD_COUNT.getKey());
-        shortKeyMapping.put("type", DSLConfigKeys.GEAFLOW_DSL_TABLE_TYPE.getKey());
-    }
+  static {
+    shortKeyMapping.put("storeType", DSLConfigKeys.GEAFLOW_DSL_STORE_TYPE.getKey());
+    shortKeyMapping.put("shardCount", DSLConfigKeys.GEAFLOW_DSL_STORE_SHARD_COUNT.getKey());
+    shortKeyMapping.put("type", DSLConfigKeys.GEAFLOW_DSL_TABLE_TYPE.getKey());
+  }
 
-    private GQLContext(Configuration conf, boolean isCompile) {
-        this.currentInstance = conf.getString(DSLConfigKeys.GEAFLOW_DSL_CATALOG_INSTANCE_NAME);
-        if (isCompile) {
-            this.catalog = new CompileCatalog(CatalogFactory.getCatalog(conf));
-        } else {
-            this.catalog = CatalogFactory.getCatalog(conf);
-        }
-        this.defaultSchema = new GeaFlowRootCalciteSchema(this.catalog).plus();
-        this.sqlOperatorTable = new GQLOperatorTable(
+  private GQLContext(Configuration conf, boolean isCompile) {
+    this.currentInstance = conf.getString(DSLConfigKeys.GEAFLOW_DSL_CATALOG_INSTANCE_NAME);
+    if (isCompile) {
+      this.catalog = new CompileCatalog(CatalogFactory.getCatalog(conf));
+    } else {
+      this.catalog = CatalogFactory.getCatalog(conf);
+    }
+    this.defaultSchema = new GeaFlowRootCalciteSchema(this.catalog).plus();
+    this.sqlOperatorTable =
+        new GQLOperatorTable(
             catalog,
             typeFactory,
             this,
             new BuildInSqlOperatorTable(),
             new BuildInSqlFunctionTable(typeFactory));
 
-        GQLCostFactory costFactory = new GQLCostFactory();
-        this.frameworkConfig = Frameworks
-            .newConfigBuilder()
+    GQLCostFactory costFactory = new GQLCostFactory();
+    this.frameworkConfig =
+        Frameworks.newConfigBuilder()
             .defaultSchema(this.defaultSchema)
             .parserConfig(GeaFlowDSLParser.PARSER_CONFIG)
             .costFactory(costFactory)
@@ -172,434 +175,490 @@ private GQLContext(Configuration conf, boolean isCompile) {
             .operatorTable(this.sqlOperatorTable)
             .build();
 
-        this.relBuilder = GQLRelBuilder.create(frameworkConfig, createRexBuilder());
-        CalciteCatalogReader calciteCatalogReader = createCatalogReader();
-
-        this.validator = new GQLValidatorImpl(this, sqlOperatorTable,
-            calciteCatalogReader, this.typeFactory, CONFORMANCE);
-        this.validator.setIdentifierExpansion(true);
-        this.convertLetTable = frameworkConfig.getConvertletTable();
-        this.validatedRelNode = new HashSet<>();
-    }
-
-    public static GQLContext create(Configuration conf, boolean isCompile) {
-        return new GQLContext(conf, isCompile);
-    }
-
-    public Catalog getCatalog() {
-        return catalog;
-    }
-
-    /**
-     * Convert {@link SqlCreateTable} to {@link GeaFlowTable}.
-     */
-    public GeaFlowTable convertToTable(SqlCreateTable table) {
-        List fields = Lists.newArrayList();
-
-        for (SqlNode node : table.getColumns()) {
-            SqlTableColumn columnNode = (SqlTableColumn) node;
-            fields.add(columnNode.toTableField());
+    this.relBuilder = GQLRelBuilder.create(frameworkConfig, createRexBuilder());
+    CalciteCatalogReader calciteCatalogReader = createCatalogReader();
+
+    this.validator =
+        new GQLValidatorImpl(
+            this, sqlOperatorTable, calciteCatalogReader, this.typeFactory, CONFORMANCE);
+    this.validator.setIdentifierExpansion(true);
+    this.convertLetTable = frameworkConfig.getConvertletTable();
+    this.validatedRelNode = new HashSet<>();
+  }
+
+  public static GQLContext create(Configuration conf, boolean isCompile) {
+    return new GQLContext(conf, isCompile);
+  }
+
+  public Catalog getCatalog() {
+    return catalog;
+  }
+
+  /** Convert {@link SqlCreateTable} to {@link GeaFlowTable}. */
+  public GeaFlowTable convertToTable(SqlCreateTable table) {
+    List fields = Lists.newArrayList();
+
+    for (SqlNode node : table.getColumns()) {
+      SqlTableColumn columnNode = (SqlTableColumn) node;
+      fields.add(columnNode.toTableField());
+    }
+
+    List primaryFields = Lists.newArrayList();
+    if (table.getPrimaryKeys() != null) {
+      for (SqlNode node : table.getPrimaryKeys()) {
+        SqlIdentifier primaryKey = (SqlIdentifier) node;
+        primaryFields.add(primaryKey.getSimple());
+      }
+    }
+    List partitionFields = Lists.newArrayList();
+    if (table.getPartitionFields() != null) {
+      List fieldNames =
+          fields.stream().map(TableField::getName).collect(Collectors.toList());
+
+      for (SqlNode node : table.getPartitionFields()) {
+        SqlIdentifier partitionField = (SqlIdentifier) node;
+        partitionFields.add(partitionField.getSimple());
+        int partitionIndex = fieldNames.indexOf(partitionField.getSimple());
+        if (partitionIndex == -1) {
+          throw new GeaFlowDSLException(
+              node.getParserPosition(), "Partition field: {} is not exists in field list.", node);
         }
-
-        List primaryFields = Lists.newArrayList();
-        if (table.getPrimaryKeys() != null) {
-            for (SqlNode node : table.getPrimaryKeys()) {
-                SqlIdentifier primaryKey = (SqlIdentifier) node;
-                primaryFields.add(primaryKey.getSimple());
-            }
+        if (partitionIndex < fieldNames.size() - partitionFields.size()) {
+          throw new GeaFlowDSLException(
+              node.getParserPosition(),
+              "Partition field should be the last fields in the field list");
         }
-        List partitionFields = Lists.newArrayList();
-        if (table.getPartitionFields() != null) {
-            List fieldNames = fields.stream()
-                .map(TableField::getName).collect(Collectors.toList());
-
-            for (SqlNode node : table.getPartitionFields()) {
-                SqlIdentifier partitionField = (SqlIdentifier) node;
-                partitionFields.add(partitionField.getSimple());
-                int partitionIndex = fieldNames.indexOf(partitionField.getSimple());
-                if (partitionIndex == -1) {
-                    throw new GeaFlowDSLException(node.getParserPosition(),
-                        "Partition field: {} is not exists in field list.", node);
-                }
-                if (partitionIndex < fieldNames.size() - partitionFields.size()) {
-                    throw new GeaFlowDSLException(node.getParserPosition(),
-                        "Partition field should be the last fields in the field list");
-                }
-            }
+      }
+    }
+    Map config = Maps.newHashMap();
+    if (table.getProperties() != null) {
+      for (SqlNode sqlNode : table.getProperties()) {
+        SqlTableProperty property = (SqlTableProperty) sqlNode;
+        String key = keyMapping(property.getKey().toString());
+        String value = StringLiteralUtil.toJavaString(property.getValue());
+        config.put(key, value);
+      }
+    }
+    String tableName = getCatalogObjName(table.getName());
+    return new GeaFlowTable(
+        currentInstance,
+        tableName,
+        fields,
+        primaryFields,
+        partitionFields,
+        config,
+        table.ifNotExists(),
+        table.isTemporary());
+  }
+
+  public static String getCatalogObjName(SqlIdentifier name) {
+    if (name.names.size() > 0) {
+      return name.names.get(name.names.size() - 1);
+    }
+    throw new GeaFlowDSLException("Illegal table/graph/function name: " + name);
+  }
+
+  /**
+   * Complete the catalog object name.
+   *
+   * @param name catalog object identifier.
+   * @return completed catalog object identifier with instance name.
+   */
+  public SqlIdentifier completeCatalogObjName(SqlIdentifier name) {
+    String firstName = name.names.get(0);
+    if (!catalog.isInstanceExists(firstName)) {
+      // if the first name is not an instance, append the current instance name.
+      // e.g. table "user" in "select * from user" will complete to "${currentInstance}.user"
+      List newNames = new ArrayList<>();
+      newNames.add(currentInstance);
+      newNames.addAll(name.names);
+      return new SqlIdentifier(newNames, name.getParserPosition());
+    }
+    return name;
+  }
+
+  /** Convert {@link SqlCreateView} to {@link GeaFlowView}. */
+  public GeaFlowView convertToView(SqlCreateView view) {
+    String viewName = view.getName().getSimple();
+    validator.validate(view.getSubQuery());
+    validatedRelNode.add(view.getSubQuery());
+
+    RelRecordType recordType = (RelRecordType) validator.getValidatedNodeType(view.getSubQuery());
+    Preconditions.checkArgument(
+        recordType.getFieldCount() == view.getFields().size(),
+        "The column size of view "
+            + viewName
+            + " is "
+            + view.getFields().size()
+            + " ,but the output column size of the sub query is "
+            + recordType.getFieldCount()
+            + " at "
+            + view.getParserPosition());
+
+    List fields = new ArrayList<>();
+    List types = new ArrayList<>();
+
+    for (int i = 0; i < recordType.getFieldList().size(); i++) {
+      String field = view.getFields().get(i).toString();
+      RelDataType type = recordType.getFieldList().get(i).getType();
+      fields.add(field);
+      types.add(type);
+    }
+    RelDataType rowType = typeFactory.createStructType(types, fields);
+
+    String viewSql = view.getSubQuerySql();
+
+    return new GeaFlowView(currentInstance, viewName, fields, rowType, viewSql, view.ifNotExists());
+  }
+
+  /** Convert {@link SqlCreateGraph} to {@link GeaFlowGraph}. */
+  public GeaFlowGraph convertToGraph(SqlCreateGraph graph) {
+    return convertToGraph(graph, Collections.emptyList());
+  }
+
+  public GeaFlowGraph convertToGraph(
+      SqlCreateGraph graph, Collection createTablesInScript) {
+    List vertexTables = new ArrayList<>();
+    SqlNodeList vertices = graph.getVertices();
+    Map vertexEdgeName2UsingTableNameMap = new HashMap<>();
+
+    GraphDescriptor desc = new GraphDescriptor();
+    for (SqlNode node : vertices) {
+      String idFieldName = null;
+      List vertexFields = new ArrayList<>();
+
+      if (node instanceof SqlVertex) {
+        SqlVertex vertex = (SqlVertex) node;
+        for (SqlNode column : vertex.getColumns()) {
+          SqlTableColumn tableColumn = (SqlTableColumn) column;
+          vertexFields.add(tableColumn.toTableField());
+          switch (tableColumn.getCategory()) {
+            case ID:
+              idFieldName = tableColumn.getName().getSimple();
+              break;
+            case NONE:
+              break;
+            default:
+              throw new GeaFlowDSLException(
+                  "Illegal column category: "
+                      + tableColumn.getCategory()
+                      + " at "
+                      + tableColumn.getParserPosition());
+          }
         }
-        Map config = Maps.newHashMap();
-        if (table.getProperties() != null) {
-            for (SqlNode sqlNode : table.getProperties()) {
-                SqlTableProperty property = (SqlTableProperty) sqlNode;
-                String key = keyMapping(property.getKey().toString());
-                String value = StringLiteralUtil.toJavaString(property.getValue());
-                config.put(key, value);
-            }
+        vertexTables.add(
+            new VertexTable(
+                currentInstance, vertex.getName().getSimple(), vertexFields, idFieldName));
+        desc.addNode(
+            new NodeDescriptor(
+                desc.getIdName(graph.getName().toString()), vertex.getName().getSimple()));
+      } else if (node instanceof SqlVertexUsing) {
+        SqlVertexUsing vertexUsing = (SqlVertexUsing) node;
+        List names = vertexUsing.getUsingTableName().names;
+        String tableName = vertexUsing.getUsingTableName().getSimple();
+
+        Table usingTable = null;
+        for (GeaFlowTable createTable : createTablesInScript) {
+          if (createTable.getName().equals(vertexUsing.getUsingTableName().getSimple())) {
+            usingTable = createTable;
+          }
         }
-        String tableName = getCatalogObjName(table.getName());
-        return new GeaFlowTable(currentInstance, tableName, fields, primaryFields, partitionFields,
-            config, table.ifNotExists(), table.isTemporary());
-    }
-
-    public static String getCatalogObjName(SqlIdentifier name) {
-        if (name.names.size() > 0) {
-            return name.names.get(name.names.size() - 1);
+        if (usingTable == null) {
+          String instanceName =
+              names.size() > 1 ? names.get(names.size() - 2) : getCurrentInstance();
+          usingTable = this.getCatalog().getTable(instanceName, tableName);
         }
-        throw new GeaFlowDSLException("Illegal table/graph/function name: " + name);
-    }
-
-    /**
-     * Complete the catalog object name.
-     *
-     * @param name catalog object identifier.
-     * @return completed catalog object identifier with instance name.
-     */
-    public SqlIdentifier completeCatalogObjName(SqlIdentifier name) {
-        String firstName = name.names.get(0);
-        if (!catalog.isInstanceExists(firstName)) {
-            // if the first name is not an instance, append the current instance name.
-            // e.g. table "user" in "select * from user" will complete to "${currentInstance}.user"
-            List newNames = new ArrayList<>();
-            newNames.add(currentInstance);
-            newNames.addAll(name.names);
-            return new SqlIdentifier(newNames, name.getParserPosition());
+        if (usingTable == null) {
+          throw new GeaFlowDSLException(
+              node.getParserPosition(),
+              "Cannot found using table: {}, check statement order.",
+              tableName);
         }
-        return name;
-    }
-
-    /**
-     * Convert {@link SqlCreateView} to {@link GeaFlowView}.
-     */
-    public GeaFlowView convertToView(SqlCreateView view) {
-        String viewName = view.getName().getSimple();
-        validator.validate(view.getSubQuery());
-        validatedRelNode.add(view.getSubQuery());
-
-        RelRecordType recordType = (RelRecordType) validator.getValidatedNodeType(view.getSubQuery());
-        Preconditions.checkArgument(recordType.getFieldCount() == view.getFields().size(),
-            "The column size of view " + viewName + " is " + view.getFields().size()
-                + " ,but the output column size of the sub query is " + recordType.getFieldCount()
-                + " at " + view.getParserPosition());
-
-        List fields = new ArrayList<>();
-        List types = new ArrayList<>();
-
-        for (int i = 0; i < recordType.getFieldList().size(); i++) {
-            String field = view.getFields().get(i).toString();
-            RelDataType type = recordType.getFieldList().get(i).getType();
-            fields.add(field);
-            types.add(type);
+        idFieldName = vertexUsing.getId().getSimple();
+
+        TableField idField = null;
+        Set fieldNames = new HashSet<>();
+        for (RelDataTypeField column : usingTable.getRowType(this.typeFactory).getFieldList()) {
+          TableField tableField =
+              new TableField(
+                  column.getName(),
+                  SqlTypeUtil.convertType(column.getType()),
+                  column.getType().isNullable());
+          if (fieldNames.contains(tableField.getName())) {
+            throw new GeaFlowDSLException("Column already exists: {}", tableField.getName());
+          }
+          vertexFields.add(tableField);
+          fieldNames.add(tableField.getName());
+          if (tableField.getName().equals(idFieldName)) {
+            idField = tableField;
+          }
         }
-        RelDataType rowType = typeFactory.createStructType(types, fields);
-
-        String viewSql = view.getSubQuerySql();
-
-        return new GeaFlowView(currentInstance, viewName, fields, rowType, viewSql,
-            view.ifNotExists());
-    }
-
-    /**
-     * Convert {@link SqlCreateGraph} to {@link GeaFlowGraph}.
-     */
-    public GeaFlowGraph convertToGraph(SqlCreateGraph graph) {
-        return convertToGraph(graph, Collections.emptyList());
-    }
-
-    public GeaFlowGraph convertToGraph(SqlCreateGraph graph,
-                                       Collection createTablesInScript) {
-        List vertexTables = new ArrayList<>();
-        SqlNodeList vertices = graph.getVertices();
-        Map vertexEdgeName2UsingTableNameMap = new HashMap<>();
-
-        GraphDescriptor desc = new GraphDescriptor();
-        for (SqlNode node : vertices) {
-            String idFieldName = null;
-            List vertexFields = new ArrayList<>();
-
-            if (node instanceof SqlVertex) {
-                SqlVertex vertex = (SqlVertex) node;
-                for (SqlNode column : vertex.getColumns()) {
-                    SqlTableColumn tableColumn = (SqlTableColumn) column;
-                    vertexFields.add(tableColumn.toTableField());
-                    switch (tableColumn.getCategory()) {
-                        case ID:
-                            idFieldName = tableColumn.getName().getSimple();
-                            break;
-                        case NONE:
-                            break;
-                        default:
-                            throw new GeaFlowDSLException("Illegal column category: " + tableColumn.getCategory()
-                                + " at " + tableColumn.getParserPosition());
-                    }
-                }
-                vertexTables.add(new VertexTable(currentInstance, vertex.getName().getSimple(),
-                    vertexFields, idFieldName));
-                desc.addNode(new NodeDescriptor(desc.getIdName(graph.getName().toString()),
-                    vertex.getName().getSimple()));
-            } else if (node instanceof SqlVertexUsing) {
-                SqlVertexUsing vertexUsing = (SqlVertexUsing) node;
-                List names = vertexUsing.getUsingTableName().names;
-                String tableName = vertexUsing.getUsingTableName().getSimple();
-
-                Table usingTable = null;
-                for (GeaFlowTable createTable : createTablesInScript) {
-                    if (createTable.getName().equals(vertexUsing.getUsingTableName().getSimple())) {
-                        usingTable = createTable;
-                    }
-                }
-                if (usingTable == null) {
-                    String instanceName = names.size() > 1 ? names.get(names.size() - 2) : getCurrentInstance();
-                    usingTable = this.getCatalog().getTable(instanceName, tableName);
-                }
-                if (usingTable == null) {
-                    throw new GeaFlowDSLException(node.getParserPosition(),
-                        "Cannot found using table: {}, check statement order.", tableName);
-                }
-                idFieldName = vertexUsing.getId().getSimple();
-
-                TableField idField = null;
-                Set fieldNames = new HashSet<>();
-                for (RelDataTypeField column : usingTable.getRowType(this.typeFactory).getFieldList()) {
-                    TableField tableField = new TableField(column.getName(),
-                        SqlTypeUtil.convertType(column.getType()), column.getType().isNullable());
-                    if (fieldNames.contains(tableField.getName())) {
-                        throw new GeaFlowDSLException("Column already exists: {}", tableField.getName());
-                    }
-                    vertexFields.add(tableField);
-                    fieldNames.add(tableField.getName());
-                    if (tableField.getName().equals(idFieldName)) {
-                        idField = tableField;
-                    }
-                }
-                if (idField == null) {
-                    throw new GeaFlowDSLException("Cannot found srcIdFieldName: {} in vertex {}",
-                        idFieldName, vertexUsing.getName().getSimple());
-                }
-                vertexEdgeName2UsingTableNameMap.put(vertexUsing.getName().getSimple(),
-                    vertexUsing.getUsingTableName().getSimple());
-                vertexTables.add(new VertexTable(currentInstance, vertexUsing.getName().getSimple(),
-                    vertexFields, idFieldName));
-                desc.addNode(new NodeDescriptor(desc.getIdName(graph.getName().toString()),
-                    vertexUsing.getName().getSimple()));
-            } else {
-                throw new GeaFlowDSLException("vertex not support: " + node);
-            }
+        if (idField == null) {
+          throw new GeaFlowDSLException(
+              "Cannot found srcIdFieldName: {} in vertex {}",
+              idFieldName,
+              vertexUsing.getName().getSimple());
         }
-
-        List edgeTables = new ArrayList<>();
-        SqlNodeList edges = graph.getEdges();
-
-        for (SqlNode node : edges) {
-            String srcIdFieldName = null;
-            String targetIdFieldName = null;
-            String tsFieldName = null;
-            List edgeFields = new ArrayList<>();
-
-            if (node instanceof SqlEdge) {
-                SqlEdge edge = (SqlEdge) node;
-                edge.validate();
-                for (SqlNode column : edge.getColumns()) {
-                    SqlTableColumn tableColumn = (SqlTableColumn) column;
-                    if (tableColumn.getTypeFrom() != null) {
-                        IType columnType = null;
-                        for (VertexTable vertexTable : vertexTables) {
-                            if (vertexTable.getTypeName().equals(tableColumn.getTypeFrom().getSimple())) {
-                                columnType = vertexTable.getIdField().getType();
-                            }
-                        }
-                        assert columnType != null;
-                        edgeFields.add(tableColumn.toTableField(columnType, false));
-                    } else {
-                        edgeFields.add(tableColumn.toTableField());
-                    }
-                    String columnName = tableColumn.getName().getSimple();
-
-                    switch (tableColumn.getCategory()) {
-                        case SOURCE_ID:
-                            srcIdFieldName = columnName;
-                            break;
-                        case DESTINATION_ID:
-                            targetIdFieldName = columnName;
-                            break;
-                        case TIMESTAMP:
-                            tsFieldName = columnName;
-                            break;
-                        case NONE:
-                            break;
-                        default:
-                            throw new GeaFlowDSLException("Illegal column category: " + tableColumn.getCategory()
-                                + " at " + tableColumn.getParserPosition());
-                    }
-                }
-                String tableName = edge.getName().getSimple();
-                edgeTables.add(new EdgeTable(currentInstance, tableName, edgeFields, srcIdFieldName,
-                    targetIdFieldName, tsFieldName));
-                desc.addEdge(GraphDescriptorUtil.getEdgeDescriptor(desc, graph.getName().getSimple(), edge));
-            } else if (node instanceof SqlEdgeUsing) {
-                SqlEdgeUsing edgeUsing = (SqlEdgeUsing) node;
-                List names = edgeUsing.getUsingTableName().names;
-                String tableName = edgeUsing.getUsingTableName().getSimple();
-
-                Table usingTable = null;
-                for (GeaFlowTable createTable : createTablesInScript) {
-                    if (createTable.getName().equals(edgeUsing.getUsingTableName().getSimple())) {
-                        usingTable = createTable;
-                    }
-                }
-                if (usingTable == null) {
-                    String instanceName = names.size() > 1 ? names.get(names.size() - 2) : getCurrentInstance();
-                    usingTable = this.getCatalog().getTable(instanceName, tableName);
-                }
-                if (usingTable == null) {
-                    throw new GeaFlowDSLException(node.getParserPosition(),
-                        "Cannot found using table: {}, check statement order.", tableName);
-                }
-
-                srcIdFieldName = edgeUsing.getSourceId().getSimple();
-                targetIdFieldName = edgeUsing.getTargetId().getSimple();
-                tsFieldName = edgeUsing.getTimeField() == null ? null : edgeUsing.getTimeField().getSimple();
-                TableField srcIdField = null;
-                TableField targetIdField = null;
-                TableField tsField = null;
-                Set fieldNames = new HashSet<>();
-                for (RelDataTypeField column : usingTable.getRowType(this.typeFactory).getFieldList()) {
-                    TableField tableField = new TableField(column.getName(),
-                        SqlTypeUtil.convertType(column.getType()), column.getType().isNullable());
-                    if (fieldNames.contains(tableField.getName())) {
-                        throw new GeaFlowDSLException("Column already exists: {}", tableField.getName());
-                    }
-                    edgeFields.add(tableField);
-                    fieldNames.add(tableField.getName());
-                    if (tableField.getName().equals(srcIdFieldName)) {
-                        srcIdField = tableField;
-                    } else if (tableField.getName().equals(targetIdFieldName)) {
-                        targetIdField = tableField;
-                    } else if (tableField.getName().equals(tsFieldName)) {
-                        tsField = tableField;
-                    }
-                }
-                if (srcIdField == null) {
-                    throw new GeaFlowDSLException("Cannot found srcIdFieldName: {} in edge {}",
-                        srcIdFieldName, edgeUsing.getName().getSimple());
-                }
-                if (targetIdField == null) {
-                    throw new GeaFlowDSLException("Cannot found targetIdFieldName: {} in edge {}",
-                        targetIdFieldName, edgeUsing.getName().getSimple());
-                }
-                if (tsFieldName != null && tsField == null) {
-                    throw new GeaFlowDSLException("Cannot found tsFieldName: {} in edge {}",
-                        tsFieldName, edgeUsing.getName().getSimple());
-                }
-                vertexEdgeName2UsingTableNameMap.put(edgeUsing.getName().getSimple(),
-                    edgeUsing.getUsingTableName().getSimple());
-                edgeTables.add(new EdgeTable(currentInstance, edgeUsing.getName().getSimple(), edgeFields,
-                    srcIdFieldName, targetIdFieldName, tsFieldName));
-                desc.addEdge(
-                    GraphDescriptorUtil.getEdgeDescriptor(desc, graph.getName().getSimple(), edgeUsing));
+        vertexEdgeName2UsingTableNameMap.put(
+            vertexUsing.getName().getSimple(), vertexUsing.getUsingTableName().getSimple());
+        vertexTables.add(
+            new VertexTable(
+                currentInstance, vertexUsing.getName().getSimple(), vertexFields, idFieldName));
+        desc.addNode(
+            new NodeDescriptor(
+                desc.getIdName(graph.getName().toString()), vertexUsing.getName().getSimple()));
+      } else {
+        throw new GeaFlowDSLException("vertex not support: " + node);
+      }
+    }
+
+    List edgeTables = new ArrayList<>();
+    SqlNodeList edges = graph.getEdges();
+
+    for (SqlNode node : edges) {
+      String srcIdFieldName = null;
+      String targetIdFieldName = null;
+      String tsFieldName = null;
+      List edgeFields = new ArrayList<>();
+
+      if (node instanceof SqlEdge) {
+        SqlEdge edge = (SqlEdge) node;
+        edge.validate();
+        for (SqlNode column : edge.getColumns()) {
+          SqlTableColumn tableColumn = (SqlTableColumn) column;
+          if (tableColumn.getTypeFrom() != null) {
+            IType columnType = null;
+            for (VertexTable vertexTable : vertexTables) {
+              if (vertexTable.getTypeName().equals(tableColumn.getTypeFrom().getSimple())) {
+                columnType = vertexTable.getIdField().getType();
+              }
             }
+            assert columnType != null;
+            edgeFields.add(tableColumn.toTableField(columnType, false));
+          } else {
+            edgeFields.add(tableColumn.toTableField());
+          }
+          String columnName = tableColumn.getName().getSimple();
+
+          switch (tableColumn.getCategory()) {
+            case SOURCE_ID:
+              srcIdFieldName = columnName;
+              break;
+            case DESTINATION_ID:
+              targetIdFieldName = columnName;
+              break;
+            case TIMESTAMP:
+              tsFieldName = columnName;
+              break;
+            case NONE:
+              break;
+            default:
+              throw new GeaFlowDSLException(
+                  "Illegal column category: "
+                      + tableColumn.getCategory()
+                      + " at "
+                      + tableColumn.getParserPosition());
+          }
         }
-
-        Map config = new HashMap<>();
-        if (graph.getProperties() != null) {
-            for (SqlNode sqlNode : graph.getProperties()) {
-                SqlTableProperty property = (SqlTableProperty) sqlNode;
-                String key = keyMapping(property.getKey().toString());
-                String value = StringLiteralUtil.toJavaString(property.getValue());
-                config.put(key, value);
-            }
+        String tableName = edge.getName().getSimple();
+        edgeTables.add(
+            new EdgeTable(
+                currentInstance,
+                tableName,
+                edgeFields,
+                srcIdFieldName,
+                targetIdFieldName,
+                tsFieldName));
+        desc.addEdge(
+            GraphDescriptorUtil.getEdgeDescriptor(desc, graph.getName().getSimple(), edge));
+      } else if (node instanceof SqlEdgeUsing) {
+        SqlEdgeUsing edgeUsing = (SqlEdgeUsing) node;
+        List names = edgeUsing.getUsingTableName().names;
+        String tableName = edgeUsing.getUsingTableName().getSimple();
+
+        Table usingTable = null;
+        for (GeaFlowTable createTable : createTablesInScript) {
+          if (createTable.getName().equals(edgeUsing.getUsingTableName().getSimple())) {
+            usingTable = createTable;
+          }
         }
-        GeaFlowGraph geaFlowGraph = new GeaFlowGraph(currentInstance, graph.getName().getSimple(),
-            vertexTables, edgeTables, config, vertexEdgeName2UsingTableNameMap, graph.ifNotExists(),
-            graph.isTemporary(), desc);
-        GraphDescriptor graphStats = geaFlowGraph.getValidDescriptorInGraph(desc);
-        if (graphStats.nodes.size() != desc.nodes.size()
-            || graphStats.edges.size() != desc.edges.size()
-            || graphStats.relations.size() != desc.relations.size()) {
-            throw new GeaFlowDSLException("Error occurred while generating desc as partially "
-                + "constraints are invalid. \n desc: {} \n valid: {}",
-                desc, graphStats);
+        if (usingTable == null) {
+          String instanceName =
+              names.size() > 1 ? names.get(names.size() - 2) : getCurrentInstance();
+          usingTable = this.getCatalog().getTable(instanceName, tableName);
         }
-        geaFlowGraph.setDescriptor(graphStats);
-        return geaFlowGraph;
-    }
-
-    private String keyMapping(String key) {
-        return shortKeyMapping.getOrDefault(key, key);
-    }
-
-    public Map keyMapping(Map input) {
-        Map keyMapping = new HashMap<>();
-        for (Map.Entry entry : input.entrySet()) {
-            keyMapping.put(keyMapping(entry.getKey()), entry.getValue());
+        if (usingTable == null) {
+          throw new GeaFlowDSLException(
+              node.getParserPosition(),
+              "Cannot found using table: {}, check statement order.",
+              tableName);
         }
-        return keyMapping;
-    }
-
-    /**
-     * Register table to catalog.
-     */
-    public void registerTable(GeaFlowTable table) {
-        String tableName = table.getName();
-        catalog.createTable(currentInstance, table);
-        LOG.info("register table : {} to catalog", tableName);
-    }
-
-    /**
-     * Register view to catalog.
-     */
-    public void registerView(GeaFlowView view) {
-        String tableName = view.getName();
-        catalog.createView(currentInstance, view);
-        LOG.info("register view : {} to catalog", tableName);
-    }
-
-    /**
-     * Register graph to catalog.
-     */
-    public void registerGraph(GeaFlowGraph graph) {
-        String graphName = graph.getName();
-        catalog.createGraph(currentInstance, graph);
-        LOG.info("register graph : {} to catalog", graphName);
-    }
 
-    public void registerFunction(GeaFlowFunction function) {
-        sqlOperatorTable.registerSqlFunction(currentInstance, function);
-        LOG.info("register Function : {} to catlog", function);
-    }
-
-    public SqlNode validate(SqlNode node) {
-        if (validatedRelNode.contains(node)) {
-            return node;
+        srcIdFieldName = edgeUsing.getSourceId().getSimple();
+        targetIdFieldName = edgeUsing.getTargetId().getSimple();
+        tsFieldName =
+            edgeUsing.getTimeField() == null ? null : edgeUsing.getTimeField().getSimple();
+        TableField srcIdField = null;
+        TableField targetIdField = null;
+        TableField tsField = null;
+        Set fieldNames = new HashSet<>();
+        for (RelDataTypeField column : usingTable.getRowType(this.typeFactory).getFieldList()) {
+          TableField tableField =
+              new TableField(
+                  column.getName(),
+                  SqlTypeUtil.convertType(column.getType()),
+                  column.getType().isNullable());
+          if (fieldNames.contains(tableField.getName())) {
+            throw new GeaFlowDSLException("Column already exists: {}", tableField.getName());
+          }
+          edgeFields.add(tableField);
+          fieldNames.add(tableField.getName());
+          if (tableField.getName().equals(srcIdFieldName)) {
+            srcIdField = tableField;
+          } else if (tableField.getName().equals(targetIdFieldName)) {
+            targetIdField = tableField;
+          } else if (tableField.getName().equals(tsFieldName)) {
+            tsField = tableField;
+          }
         }
-        return validator.validate(node, new QueryNodeContext());
-    }
-
-    /**
-     * Find the {@link SqlFunction}.
-     */
-    public SqlFunction findSqlFunction(String instance, String name) {
-        return sqlOperatorTable.getSqlFunction(instance == null ? currentInstance : instance, name);
-    }
-
-    // ~ convert SqlNode to RelNode ----------------------------------------------------------
-
-    /**
-     * Return the RelRoot of SqlNode.
-     *
-     * @param sqlNode the sql node.
-     * @return the rel root.
-     */
-    public RelNode toRelNode(SqlNode sqlNode) {
-        RexBuilder rexBuilder = createRexBuilder();
-        RelOptCluster cluster = RelOptCluster.create(relBuilder.getPlanner(), rexBuilder);
-
-        SqlToRelConverter.Config config = SqlToRelConverter.configBuilder()
+        if (srcIdField == null) {
+          throw new GeaFlowDSLException(
+              "Cannot found srcIdFieldName: {} in edge {}",
+              srcIdFieldName,
+              edgeUsing.getName().getSimple());
+        }
+        if (targetIdField == null) {
+          throw new GeaFlowDSLException(
+              "Cannot found targetIdFieldName: {} in edge {}",
+              targetIdFieldName,
+              edgeUsing.getName().getSimple());
+        }
+        if (tsFieldName != null && tsField == null) {
+          throw new GeaFlowDSLException(
+              "Cannot found tsFieldName: {} in edge {}",
+              tsFieldName,
+              edgeUsing.getName().getSimple());
+        }
+        vertexEdgeName2UsingTableNameMap.put(
+            edgeUsing.getName().getSimple(), edgeUsing.getUsingTableName().getSimple());
+        edgeTables.add(
+            new EdgeTable(
+                currentInstance,
+                edgeUsing.getName().getSimple(),
+                edgeFields,
+                srcIdFieldName,
+                targetIdFieldName,
+                tsFieldName));
+        desc.addEdge(
+            GraphDescriptorUtil.getEdgeDescriptor(desc, graph.getName().getSimple(), edgeUsing));
+      }
+    }
+
+    Map config = new HashMap<>();
+    if (graph.getProperties() != null) {
+      for (SqlNode sqlNode : graph.getProperties()) {
+        SqlTableProperty property = (SqlTableProperty) sqlNode;
+        String key = keyMapping(property.getKey().toString());
+        String value = StringLiteralUtil.toJavaString(property.getValue());
+        config.put(key, value);
+      }
+    }
+    GeaFlowGraph geaFlowGraph =
+        new GeaFlowGraph(
+            currentInstance,
+            graph.getName().getSimple(),
+            vertexTables,
+            edgeTables,
+            config,
+            vertexEdgeName2UsingTableNameMap,
+            graph.ifNotExists(),
+            graph.isTemporary(),
+            desc);
+    GraphDescriptor graphStats = geaFlowGraph.getValidDescriptorInGraph(desc);
+    if (graphStats.nodes.size() != desc.nodes.size()
+        || graphStats.edges.size() != desc.edges.size()
+        || graphStats.relations.size() != desc.relations.size()) {
+      throw new GeaFlowDSLException(
+          "Error occurred while generating desc as partially "
+              + "constraints are invalid. \n desc: {} \n valid: {}",
+          desc,
+          graphStats);
+    }
+    geaFlowGraph.setDescriptor(graphStats);
+    return geaFlowGraph;
+  }
+
+  private String keyMapping(String key) {
+    return shortKeyMapping.getOrDefault(key, key);
+  }
+
+  public Map keyMapping(Map input) {
+    Map keyMapping = new HashMap<>();
+    for (Map.Entry entry : input.entrySet()) {
+      keyMapping.put(keyMapping(entry.getKey()), entry.getValue());
+    }
+    return keyMapping;
+  }
+
+  /** Register table to catalog. */
+  public void registerTable(GeaFlowTable table) {
+    String tableName = table.getName();
+    catalog.createTable(currentInstance, table);
+    LOG.info("register table : {} to catalog", tableName);
+  }
+
+  /** Register view to catalog. */
+  public void registerView(GeaFlowView view) {
+    String tableName = view.getName();
+    catalog.createView(currentInstance, view);
+    LOG.info("register view : {} to catalog", tableName);
+  }
+
+  /** Register graph to catalog. */
+  public void registerGraph(GeaFlowGraph graph) {
+    String graphName = graph.getName();
+    catalog.createGraph(currentInstance, graph);
+    LOG.info("register graph : {} to catalog", graphName);
+  }
+
+  public void registerFunction(GeaFlowFunction function) {
+    sqlOperatorTable.registerSqlFunction(currentInstance, function);
+    LOG.info("register Function : {} to catlog", function);
+  }
+
+  public SqlNode validate(SqlNode node) {
+    if (validatedRelNode.contains(node)) {
+      return node;
+    }
+    return validator.validate(node, new QueryNodeContext());
+  }
+
+  /** Find the {@link SqlFunction}. */
+  public SqlFunction findSqlFunction(String instance, String name) {
+    return sqlOperatorTable.getSqlFunction(instance == null ? currentInstance : instance, name);
+  }
+
+  // ~ convert SqlNode to RelNode ----------------------------------------------------------
+
+  /**
+   * Return the RelRoot of SqlNode.
+   *
+   * @param sqlNode the sql node.
+   * @return the rel root.
+   */
+  public RelNode toRelNode(SqlNode sqlNode) {
+    RexBuilder rexBuilder = createRexBuilder();
+    RelOptCluster cluster = RelOptCluster.create(relBuilder.getPlanner(), rexBuilder);
+
+    SqlToRelConverter.Config config =
+        SqlToRelConverter.configBuilder()
             .withTrimUnusedFields(false)
             .withInSubQueryThreshold(10000)
             .withConvertTableAccess(false)
             .build();
 
-        GQLToRelConverter sqlToRelConverter = new GQLToRelConverter(
+    GQLToRelConverter sqlToRelConverter =
+        new GQLToRelConverter(
             new ViewExpanderImpl(),
             validator,
             createCatalogReader(),
@@ -607,155 +666,158 @@ public RelNode toRelNode(SqlNode sqlNode) {
             convertLetTable,
             config);
 
-        RelRoot root = sqlToRelConverter.convertQuery(sqlNode, false, true);
-        root = root.withRel(RelDecorrelator.decorrelateQuery(root.rel));
-        return root.rel;
-    }
-
-    private CalciteCatalogReader createCatalogReader() {
-        SchemaPlus rootSchema = rootSchema(defaultSchema);
-        List defaultSchemaName = ImmutableList.of(defaultSchema.getName());
-        Properties properties = new Properties();
-        properties.put(CalciteConnectionProperty.CASE_SENSITIVE.camelName(),
-            frameworkConfig.getParserConfig().caseSensitive());
-        CalciteConnectionConfig config = new CalciteConnectionConfigImpl(properties);
-        return new CalciteCatalogReader(CalciteSchema.from(rootSchema), defaultSchemaName, typeFactory, config);
-    }
-
-    private SchemaPlus rootSchema(SchemaPlus schema) {
-        if (schema.getParentSchema() == null) {
-            return schema;
-        } else {
-            return rootSchema(schema.getParentSchema());
-        }
-    }
-
-    private RexBuilder createRexBuilder() {
-        return new RexBuilder(typeFactory);
-    }
-
-    // RBO optimizer.
-    public RelNode optimize(List ruleGroups, RelNode input) {
-        GQLOptimizer optimizer = new GQLOptimizer(frameworkConfig.getContext());
-        for (RuleGroup ruleGroup : ruleGroups) {
-            optimizer.addRuleGroup(ruleGroup);
-        }
-        return optimizer.optimize(input);
-    }
-
-    // ~ CBO optimizer
-
-    // Run VolcanoPlanner, transform by convention
-    public RelNode transform(List ruleSet, RelNode relNode,
-                             RelTraitSet relTraitSet) {
-        Program optProgram = Programs.ofRules(ruleSet);
-        RelNode transformed;
-        try {
-            transformed = optProgram.run(relBuilder.getPlanner(),
-                relNode,
-                relTraitSet,
-                Lists.newArrayList(),
-                Lists.newArrayList());
-        } catch (RelOptPlanner.CannotPlanException e) {
-            throw new GeaFlowDSLException(
-                "Cannot generate a valid execution plan for the given query: \n\n" + RelOptUtil.toString(relNode)
-                    + "This exception indicates that the query uses an unsupported SQL feature.\n"
-                    + "Please check the documentation for the set create currently supported SQL features.", e);
-        }
-        return transformed;
-    }
-
-    /**
-     * Gets framework config.
-     *
-     * @return the framework config
-     */
-    public FrameworkConfig getFrameworkConfig() {
-        return frameworkConfig;
-    }
-
-    /**
-     * Gets type factory.
-     *
-     * @return the type factory
-     */
-    public GQLJavaTypeFactory getTypeFactory() {
-        return this.typeFactory;
-    }
-
-    public RelBuilder getRelBuilder() {
-        return relBuilder;
-    }
-
-    private class ViewExpanderImpl implements RelOptTable.ViewExpander,
-        Serializable {
-
-        private static final long serialVersionUID = 42L;
-
-        @Override
-        public RelRoot expandView(RelDataType rowType,
-                                  String queryString,
-                                  List schemaPath,
-                                  List viewPath) {
-
-            SqlParser parser = SqlParser.create(queryString, GeaFlowDSLParser.PARSER_CONFIG);
-            SqlNode sqlNode;
-            try {
-                sqlNode = parser.parseQuery();
-            } catch (SqlParseException e) {
-                throw new RuntimeException("parse failed", e);
-            }
-            CalciteCatalogReader reader = createCatalogReader()
-                .withSchemaPath(schemaPath);
-            SqlValidator validator = new GQLValidatorImpl(GQLContext.this, sqlOperatorTable,
-                reader, typeFactory, CONFORMANCE);
-            validator.setIdentifierExpansion(true);
-            SqlNode validatedSqlNode = validator.validate(sqlNode);
-            RexBuilder rexBuilder = createRexBuilder();
-            RelOptCluster cluster = RelOptCluster.create(relBuilder.getPlanner(), rexBuilder);
-            SqlToRelConverter.Config config = SqlToRelConverter.configBuilder()
-                .withTrimUnusedFields(false)
-                .withConvertTableAccess(false)
-                .build();
-            SqlToRelConverter converter = new SqlToRelConverter(
-                new ViewExpanderImpl(), validator, reader, cluster, convertLetTable, config);
-            RelRoot root = converter.convertQuery(validatedSqlNode, true, false);
-            root = root.withRel(converter.flattenTypes(root.rel, true));
-            root = root.withRel(RelDecorrelator.decorrelateQuery(root.rel));
-            return root;
-        }
-    }
-
-    public GQLValidatorImpl getValidator() {
-        return validator;
-    }
-
-    public String getCurrentInstance() {
-        return currentInstance;
-    }
-
-    public void setCurrentInstance(String currentInstance) {
-        this.currentInstance = currentInstance;
-    }
-
-    public String getCurrentGraph() {
-        return currentGraph;
-    }
-
-    public void setCurrentGraph(String currentGraph) {
-        Table graphTable = catalog.getGraph(currentInstance, currentGraph);
-        if (graphTable instanceof GeaFlowGraph) {
-            this.currentGraph = currentGraph;
-            GeaFlowGraph geaFlowGraph = (GeaFlowGraph) graphTable;
-            geaFlowGraph.getConfig().putAll(keyMapping(geaFlowGraph.getConfig().getConfigMap()));
-            getTypeFactory().setCurrentGraph(geaFlowGraph);
-        } else {
-            throw new GeaFlowDSLException("Graph: {} is not exists.", currentGraph);
-        }
-
-    }
-
-    public boolean isCaseSensitive() {
-        return validator.isCaseSensitive();
-    }
+    RelRoot root = sqlToRelConverter.convertQuery(sqlNode, false, true);
+    root = root.withRel(RelDecorrelator.decorrelateQuery(root.rel));
+    return root.rel;
+  }
+
+  private CalciteCatalogReader createCatalogReader() {
+    SchemaPlus rootSchema = rootSchema(defaultSchema);
+    List defaultSchemaName = ImmutableList.of(defaultSchema.getName());
+    Properties properties = new Properties();
+    properties.put(
+        CalciteConnectionProperty.CASE_SENSITIVE.camelName(),
+        frameworkConfig.getParserConfig().caseSensitive());
+    CalciteConnectionConfig config = new CalciteConnectionConfigImpl(properties);
+    return new CalciteCatalogReader(
+        CalciteSchema.from(rootSchema), defaultSchemaName, typeFactory, config);
+  }
+
+  private SchemaPlus rootSchema(SchemaPlus schema) {
+    if (schema.getParentSchema() == null) {
+      return schema;
+    } else {
+      return rootSchema(schema.getParentSchema());
+    }
+  }
+
+  private RexBuilder createRexBuilder() {
+    return new RexBuilder(typeFactory);
+  }
+
+  // RBO optimizer.
+  public RelNode optimize(List ruleGroups, RelNode input) {
+    GQLOptimizer optimizer = new GQLOptimizer(frameworkConfig.getContext());
+    for (RuleGroup ruleGroup : ruleGroups) {
+      optimizer.addRuleGroup(ruleGroup);
+    }
+    return optimizer.optimize(input);
+  }
+
+  // ~ CBO optimizer
+
+  // Run VolcanoPlanner, transform by convention
+  public RelNode transform(List ruleSet, RelNode relNode, RelTraitSet relTraitSet) {
+    Program optProgram = Programs.ofRules(ruleSet);
+    RelNode transformed;
+    try {
+      transformed =
+          optProgram.run(
+              relBuilder.getPlanner(),
+              relNode,
+              relTraitSet,
+              Lists.newArrayList(),
+              Lists.newArrayList());
+    } catch (RelOptPlanner.CannotPlanException e) {
+      throw new GeaFlowDSLException(
+          "Cannot generate a valid execution plan for the given query: \n\n"
+              + RelOptUtil.toString(relNode)
+              + "This exception indicates that the query uses an unsupported SQL feature.\n"
+              + "Please check the documentation for the set create currently supported SQL"
+              + " features.",
+          e);
+    }
+    return transformed;
+  }
+
+  /**
+   * Gets framework config.
+   *
+   * @return the framework config
+   */
+  public FrameworkConfig getFrameworkConfig() {
+    return frameworkConfig;
+  }
+
+  /**
+   * Gets type factory.
+   *
+   * @return the type factory
+   */
+  public GQLJavaTypeFactory getTypeFactory() {
+    return this.typeFactory;
+  }
+
+  public RelBuilder getRelBuilder() {
+    return relBuilder;
+  }
+
+  private class ViewExpanderImpl implements RelOptTable.ViewExpander, Serializable {
+
+    private static final long serialVersionUID = 42L;
+
+    @Override
+    public RelRoot expandView(
+        RelDataType rowType, String queryString, List schemaPath, List viewPath) {
+
+      SqlParser parser = SqlParser.create(queryString, GeaFlowDSLParser.PARSER_CONFIG);
+      SqlNode sqlNode;
+      try {
+        sqlNode = parser.parseQuery();
+      } catch (SqlParseException e) {
+        throw new RuntimeException("parse failed", e);
+      }
+      CalciteCatalogReader reader = createCatalogReader().withSchemaPath(schemaPath);
+      SqlValidator validator =
+          new GQLValidatorImpl(GQLContext.this, sqlOperatorTable, reader, typeFactory, CONFORMANCE);
+      validator.setIdentifierExpansion(true);
+      SqlNode validatedSqlNode = validator.validate(sqlNode);
+      RexBuilder rexBuilder = createRexBuilder();
+      RelOptCluster cluster = RelOptCluster.create(relBuilder.getPlanner(), rexBuilder);
+      SqlToRelConverter.Config config =
+          SqlToRelConverter.configBuilder()
+              .withTrimUnusedFields(false)
+              .withConvertTableAccess(false)
+              .build();
+      SqlToRelConverter converter =
+          new SqlToRelConverter(
+              new ViewExpanderImpl(), validator, reader, cluster, convertLetTable, config);
+      RelRoot root = converter.convertQuery(validatedSqlNode, true, false);
+      root = root.withRel(converter.flattenTypes(root.rel, true));
+      root = root.withRel(RelDecorrelator.decorrelateQuery(root.rel));
+      return root;
+    }
+  }
+
+  public GQLValidatorImpl getValidator() {
+    return validator;
+  }
+
+  public String getCurrentInstance() {
+    return currentInstance;
+  }
+
+  public void setCurrentInstance(String currentInstance) {
+    this.currentInstance = currentInstance;
+  }
+
+  public String getCurrentGraph() {
+    return currentGraph;
+  }
+
+  public void setCurrentGraph(String currentGraph) {
+    Table graphTable = catalog.getGraph(currentInstance, currentGraph);
+    if (graphTable instanceof GeaFlowGraph) {
+      this.currentGraph = currentGraph;
+      GeaFlowGraph geaFlowGraph = (GeaFlowGraph) graphTable;
+      geaFlowGraph.getConfig().putAll(keyMapping(geaFlowGraph.getConfig().getConfigMap()));
+      getTypeFactory().setCurrentGraph(geaFlowGraph);
+    } else {
+      throw new GeaFlowDSLException("Graph: {} is not exists.", currentGraph);
+    }
+  }
+
+  public boolean isCaseSensitive() {
+    return validator.isCaseSensitive();
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLCost.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLCost.java
index cd2a244b3..ad0e3f901 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLCost.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLCost.java
@@ -24,181 +24,188 @@
 
 public class GQLCost implements RelOptCost {
 
-    static final GQLCost
-        INFINITY =
-        new GQLCost(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY,
-            Double.POSITIVE_INFINITY) {
-            public String toString() {
-                return "{inf}";
-            }
-        };
-
-    static final GQLCost
-        HUGE =
-        new GQLCost(Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE) {
-            public String toString() {
-                return "{huge}";
-            }
-        };
-
-    static final GQLCost ZERO = new GQLCost(0.0, 0.0, 0.0) {
+  static final GQLCost INFINITY =
+      new GQLCost(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY) {
         public String toString() {
-            return "{0}";
+          return "{inf}";
         }
-    };
+      };
 
-    static final GQLCost TINY = new GQLCost(1.0, 1.0, 0.0) {
+  static final GQLCost HUGE =
+      new GQLCost(Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE) {
         public String toString() {
-            return "{tiny}";
+          return "{huge}";
         }
-    };
-
-    //~ Instance fields --------------------------------------------------------
-    final double cpu;
-    final double io;
-    final double rowCount;
-
-    //~ Constructors -----------------------------------------------------------
-    GQLCost(double rowCount, double cpu, double io) {
-        this.rowCount = rowCount;
-        this.cpu = cpu;
-        this.io = io;
-    }
-
-    //~ Methods ----------------------------------------------------------------
-
-    public double getCpu() {
-        return cpu;
-    }
-
-    public boolean isInfinite() {
-        return (this == INFINITY) || (this.rowCount == Double.POSITIVE_INFINITY) || (this.cpu
-            == Double.POSITIVE_INFINITY)
-            || (this.io == Double.POSITIVE_INFINITY);
-    }
-
-    @Override
-    public boolean equals(RelOptCost other) {
-        return this == other || other instanceof GQLCost && (this.rowCount
-            == ((GQLCost) other).rowCount)
-            && (this.cpu == ((GQLCost) other).cpu) && (this.io == ((GQLCost) other).io);
-    }
-
-    public double getIo() {
-        return io;
-    }
-
-    public boolean isLe(RelOptCost other) {
-        GQLCost that = (GQLCost) other;
-        return this == that || this.rowCount <= that.rowCount;
-    }
-
-    public boolean isLt(RelOptCost other) {
-        GQLCost that = (GQLCost) other;
-        return this.rowCount < that.rowCount;
-    }
-
-    public double getRows() {
-        return rowCount;
-    }
-
-    public boolean isEqWithEpsilon(RelOptCost other) {
-        if (!(other instanceof GQLCost)) {
-            return false;
-        }
-        GQLCost that = (GQLCost) other;
-        return (this == that) || ((Math.abs(this.rowCount - that.rowCount) < RelOptUtil.EPSILON)
-            && (Math.abs(this.cpu - that.cpu) < RelOptUtil.EPSILON) && (Math.abs(this.io - that.io)
-            < RelOptUtil.EPSILON));
-    }
-
-    public RelOptCost minus(RelOptCost other) {
-        if (this == INFINITY) {
-            return this;
-        }
-        GQLCost that = (GQLCost) other;
-        return new GQLCost(this.rowCount - that.rowCount, this.cpu - that.cpu,
-            this.io - that.io);
-    }
+      };
 
-    public RelOptCost multiplyBy(double factor) {
-        if (this == INFINITY) {
-            return this;
-        }
-        return new GQLCost(rowCount * factor, cpu * factor, io * factor);
-    }
-
-    public double divideBy(RelOptCost cost) {
-        // Compute the geometric average create the ratios create all create the factors
-        // which are non-zero and finite.
-        GQLCost that = (GQLCost) cost;
-        double d = 1;
-        double n = 0;
-        if ((this.rowCount != 0) && !Double.isInfinite(this.rowCount) && (that.rowCount != 0)
-            && !Double.isInfinite(that.rowCount)) {
-            d *= this.rowCount / that.rowCount;
-            ++n;
-        }
-        if ((this.cpu != 0) && !Double.isInfinite(this.cpu) && (that.cpu != 0) && !Double
-            .isInfinite(that.cpu)) {
-            d *= this.cpu / that.cpu;
-            ++n;
-        }
-        if ((this.io != 0) && !Double.isInfinite(this.io) && (that.io != 0) && !Double
-            .isInfinite(that.io)) {
-            d *= this.io / that.io;
-            ++n;
-        }
-        if (n == 0) {
-            return 1.0;
-        }
-        return Math.pow(d, 1 / n);
-    }
-
-    public RelOptCost plus(RelOptCost other) {
-        GQLCost that = (GQLCost) other;
-        if ((this == INFINITY) || (that == INFINITY)) {
-            return INFINITY;
-        }
-        return new GQLCost(this.rowCount + that.rowCount, this.cpu + that.cpu,
-            this.io + that.io);
-    }
-
-    @Override
-    public boolean equals(Object o) {
-        if (this == o) {
-            return true;
-        }
-        if (o == null || getClass() != o.getClass()) {
-            return false;
+  static final GQLCost ZERO =
+      new GQLCost(0.0, 0.0, 0.0) {
+        public String toString() {
+          return "{0}";
         }
+      };
 
-        GQLCost that = (GQLCost) o;
-
-        if (Double.compare(that.cpu, cpu) != 0) {
-            return false;
-        }
-        if (Double.compare(that.io, io) != 0) {
-            return false;
+  static final GQLCost TINY =
+      new GQLCost(1.0, 1.0, 0.0) {
+        public String toString() {
+          return "{tiny}";
         }
-        return Double.compare(that.rowCount, rowCount) == 0;
-    }
-
-    @Override
-    public int hashCode() {
-        int result;
-        long temp;
-        temp = Double.doubleToLongBits(cpu);
-        result = (int) (temp ^ (temp >>> 32));
-        temp = Double.doubleToLongBits(io);
-        result = 31 * result + (int) (temp ^ (temp >>> 32));
-        temp = Double.doubleToLongBits(rowCount);
-        result = 31 * result + (int) (temp ^ (temp >>> 32));
-        return result;
-    }
-
-    @Override
-    public String toString() {
-        return "{" + rowCount + " rows, " + cpu + " cpu, " + io + " io}";
-    }
+      };
+
+  // ~ Instance fields --------------------------------------------------------
+  final double cpu;
+  final double io;
+  final double rowCount;
+
+  // ~ Constructors -----------------------------------------------------------
+  GQLCost(double rowCount, double cpu, double io) {
+    this.rowCount = rowCount;
+    this.cpu = cpu;
+    this.io = io;
+  }
+
+  // ~ Methods ----------------------------------------------------------------
+
+  public double getCpu() {
+    return cpu;
+  }
+
+  public boolean isInfinite() {
+    return (this == INFINITY)
+        || (this.rowCount == Double.POSITIVE_INFINITY)
+        || (this.cpu == Double.POSITIVE_INFINITY)
+        || (this.io == Double.POSITIVE_INFINITY);
+  }
+
+  @Override
+  public boolean equals(RelOptCost other) {
+    return this == other
+        || other instanceof GQLCost
+            && (this.rowCount == ((GQLCost) other).rowCount)
+            && (this.cpu == ((GQLCost) other).cpu)
+            && (this.io == ((GQLCost) other).io);
+  }
+
+  public double getIo() {
+    return io;
+  }
+
+  public boolean isLe(RelOptCost other) {
+    GQLCost that = (GQLCost) other;
+    return this == that || this.rowCount <= that.rowCount;
+  }
+
+  public boolean isLt(RelOptCost other) {
+    GQLCost that = (GQLCost) other;
+    return this.rowCount < that.rowCount;
+  }
+
+  public double getRows() {
+    return rowCount;
+  }
+
+  public boolean isEqWithEpsilon(RelOptCost other) {
+    if (!(other instanceof GQLCost)) {
+      return false;
+    }
+    GQLCost that = (GQLCost) other;
+    return (this == that)
+        || ((Math.abs(this.rowCount - that.rowCount) < RelOptUtil.EPSILON)
+            && (Math.abs(this.cpu - that.cpu) < RelOptUtil.EPSILON)
+            && (Math.abs(this.io - that.io) < RelOptUtil.EPSILON));
+  }
+
+  public RelOptCost minus(RelOptCost other) {
+    if (this == INFINITY) {
+      return this;
+    }
+    GQLCost that = (GQLCost) other;
+    return new GQLCost(this.rowCount - that.rowCount, this.cpu - that.cpu, this.io - that.io);
+  }
+
+  public RelOptCost multiplyBy(double factor) {
+    if (this == INFINITY) {
+      return this;
+    }
+    return new GQLCost(rowCount * factor, cpu * factor, io * factor);
+  }
+
+  public double divideBy(RelOptCost cost) {
+    // Compute the geometric average create the ratios create all create the factors
+    // which are non-zero and finite.
+    GQLCost that = (GQLCost) cost;
+    double d = 1;
+    double n = 0;
+    if ((this.rowCount != 0)
+        && !Double.isInfinite(this.rowCount)
+        && (that.rowCount != 0)
+        && !Double.isInfinite(that.rowCount)) {
+      d *= this.rowCount / that.rowCount;
+      ++n;
+    }
+    if ((this.cpu != 0)
+        && !Double.isInfinite(this.cpu)
+        && (that.cpu != 0)
+        && !Double.isInfinite(that.cpu)) {
+      d *= this.cpu / that.cpu;
+      ++n;
+    }
+    if ((this.io != 0)
+        && !Double.isInfinite(this.io)
+        && (that.io != 0)
+        && !Double.isInfinite(that.io)) {
+      d *= this.io / that.io;
+      ++n;
+    }
+    if (n == 0) {
+      return 1.0;
+    }
+    return Math.pow(d, 1 / n);
+  }
+
+  public RelOptCost plus(RelOptCost other) {
+    GQLCost that = (GQLCost) other;
+    if ((this == INFINITY) || (that == INFINITY)) {
+      return INFINITY;
+    }
+    return new GQLCost(this.rowCount + that.rowCount, this.cpu + that.cpu, this.io + that.io);
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+
+    GQLCost that = (GQLCost) o;
+
+    if (Double.compare(that.cpu, cpu) != 0) {
+      return false;
+    }
+    if (Double.compare(that.io, io) != 0) {
+      return false;
+    }
+    return Double.compare(that.rowCount, rowCount) == 0;
+  }
+
+  @Override
+  public int hashCode() {
+    int result;
+    long temp;
+    temp = Double.doubleToLongBits(cpu);
+    result = (int) (temp ^ (temp >>> 32));
+    temp = Double.doubleToLongBits(io);
+    result = 31 * result + (int) (temp ^ (temp >>> 32));
+    temp = Double.doubleToLongBits(rowCount);
+    result = 31 * result + (int) (temp ^ (temp >>> 32));
+    return result;
+  }
+
+  @Override
+  public String toString() {
+    return "{" + rowCount + " rows, " + cpu + " cpu, " + io + " io}";
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLCostFactory.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLCostFactory.java
index f34d09751..42ba258d2 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLCostFactory.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLCostFactory.java
@@ -24,24 +24,23 @@
 
 public class GQLCostFactory implements RelOptCostFactory {
 
-    public RelOptCost makeCost(double dRows, double dCpu, double dIo) {
-        return new GQLCost(dRows, dCpu, dIo);
-    }
+  public RelOptCost makeCost(double dRows, double dCpu, double dIo) {
+    return new GQLCost(dRows, dCpu, dIo);
+  }
 
-    public RelOptCost makeHugeCost() {
-        return GQLCost.HUGE;
-    }
+  public RelOptCost makeHugeCost() {
+    return GQLCost.HUGE;
+  }
 
-    public RelOptCost makeInfiniteCost() {
-        return GQLCost.INFINITY;
-    }
+  public RelOptCost makeInfiniteCost() {
+    return GQLCost.INFINITY;
+  }
 
-    public RelOptCost makeTinyCost() {
-        return GQLCost.TINY;
-    }
-
-    public RelOptCost makeZeroCost() {
-        return GQLCost.ZERO;
-    }
+  public RelOptCost makeTinyCost() {
+    return GQLCost.TINY;
+  }
 
+  public RelOptCost makeZeroCost() {
+    return GQLCost.ZERO;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLJavaTypeFactory.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLJavaTypeFactory.java
index ccc4a9e5b..00ec42de2 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLJavaTypeFactory.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLJavaTypeFactory.java
@@ -27,6 +27,7 @@
 import java.sql.Time;
 import java.sql.Timestamp;
 import java.util.Map;
+
 import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rel.type.RelDataTypeSystem;
@@ -43,121 +44,119 @@
 
 public class GQLJavaTypeFactory extends JavaTypeFactoryImpl {
 
-    private GeaFlowGraph currentGraph;
+  private GeaFlowGraph currentGraph;
 
-    public static final String NATIVE_UTF16_CHARSET_NAME =
-        (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) ? "UTF-16BE" : "UTF-16LE";
+  public static final String NATIVE_UTF16_CHARSET_NAME =
+      (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) ? "UTF-16BE" : "UTF-16LE";
 
-    public GQLJavaTypeFactory(RelDataTypeSystem relDataTypeSystem) {
-        super(relDataTypeSystem);
-    }
+  public GQLJavaTypeFactory(RelDataTypeSystem relDataTypeSystem) {
+    super(relDataTypeSystem);
+  }
 
-    public GeaFlowGraph getCurrentGraph() {
-        return currentGraph;
-    }
+  public GeaFlowGraph getCurrentGraph() {
+    return currentGraph;
+  }
 
-    public GeaFlowGraph setCurrentGraph(GeaFlowGraph graph) {
-        GeaFlowGraph prevGraph = currentGraph;
-        currentGraph = graph;
-        return prevGraph;
-    }
+  public GeaFlowGraph setCurrentGraph(GeaFlowGraph graph) {
+    GeaFlowGraph prevGraph = currentGraph;
+    currentGraph = graph;
+    return prevGraph;
+  }
 
-    @Override
-    public Charset getDefaultCharset() {
-        return Charset.forName(NATIVE_UTF16_CHARSET_NAME);
-    }
+  @Override
+  public Charset getDefaultCharset() {
+    return Charset.forName(NATIVE_UTF16_CHARSET_NAME);
+  }
 
-    @Override
-    public Type getJavaClass(RelDataType type) {
-        switch (type.getSqlTypeName()) {
-            case VARCHAR:
-            case CHAR:
-                return BinaryString.class;
-            case DATE:
-                return Date.class;
-            case TIME:
-                return Time.class;
-            case INTEGER:
-            case INTERVAL_YEAR:
-            case INTERVAL_YEAR_MONTH:
-            case INTERVAL_MONTH:
-            case SYMBOL:
-                return type.isNullable() ? Integer.class : int.class;
-            case TIMESTAMP:
-                return Timestamp.class;
-            case BIGINT:
-            case INTERVAL_DAY:
-            case INTERVAL_DAY_HOUR:
-            case INTERVAL_DAY_MINUTE:
-            case INTERVAL_DAY_SECOND:
-            case INTERVAL_HOUR:
-            case INTERVAL_HOUR_MINUTE:
-            case INTERVAL_HOUR_SECOND:
-            case INTERVAL_MINUTE:
-            case INTERVAL_MINUTE_SECOND:
-            case INTERVAL_SECOND:
-                return type.isNullable() ? Long.class : long.class;
-            case SMALLINT:
-                return type.isNullable() ? Short.class : short.class;
-            case TINYINT:
-                return type.isNullable() ? Byte.class : byte.class;
-            case DECIMAL:
-                return BigDecimal.class;
-            case BOOLEAN:
-                return type.isNullable() ? Boolean.class : boolean.class;
-            case DOUBLE:
-            case FLOAT:
-                return type.isNullable() ? Double.class : double.class;
-            case REAL:
-                return type.isNullable() ? Float.class : float.class;
-            case BINARY:
-            case VARBINARY:
-                return byte[].class;
-            case ARRAY:
-                Class elementType = (Class) getJavaClass(type.getComponentType());
-                try {
-                    return Class.forName("[L" + elementType.getName() + ";");
-                } catch (Exception e) {
-                    throw new GeaFlowDSLException("Error in create Java Type for " + type, e);
-                }
-            case MAP:
-                return Map.class;
-            case VERTEX:
-                return RowVertex.class;
-            case EDGE:
-                return RowEdge.class;
-            case PATH:
-                return Path.class;
-            case ROW:
-                return Row.class;
-            default:
-                return super.getJavaClass(type);
+  @Override
+  public Type getJavaClass(RelDataType type) {
+    switch (type.getSqlTypeName()) {
+      case VARCHAR:
+      case CHAR:
+        return BinaryString.class;
+      case DATE:
+        return Date.class;
+      case TIME:
+        return Time.class;
+      case INTEGER:
+      case INTERVAL_YEAR:
+      case INTERVAL_YEAR_MONTH:
+      case INTERVAL_MONTH:
+      case SYMBOL:
+        return type.isNullable() ? Integer.class : int.class;
+      case TIMESTAMP:
+        return Timestamp.class;
+      case BIGINT:
+      case INTERVAL_DAY:
+      case INTERVAL_DAY_HOUR:
+      case INTERVAL_DAY_MINUTE:
+      case INTERVAL_DAY_SECOND:
+      case INTERVAL_HOUR:
+      case INTERVAL_HOUR_MINUTE:
+      case INTERVAL_HOUR_SECOND:
+      case INTERVAL_MINUTE:
+      case INTERVAL_MINUTE_SECOND:
+      case INTERVAL_SECOND:
+        return type.isNullable() ? Long.class : long.class;
+      case SMALLINT:
+        return type.isNullable() ? Short.class : short.class;
+      case TINYINT:
+        return type.isNullable() ? Byte.class : byte.class;
+      case DECIMAL:
+        return BigDecimal.class;
+      case BOOLEAN:
+        return type.isNullable() ? Boolean.class : boolean.class;
+      case DOUBLE:
+      case FLOAT:
+        return type.isNullable() ? Double.class : double.class;
+      case REAL:
+        return type.isNullable() ? Float.class : float.class;
+      case BINARY:
+      case VARBINARY:
+        return byte[].class;
+      case ARRAY:
+        Class elementType = (Class) getJavaClass(type.getComponentType());
+        try {
+          return Class.forName("[L" + elementType.getName() + ";");
+        } catch (Exception e) {
+          throw new GeaFlowDSLException("Error in create Java Type for " + type, e);
         }
+      case MAP:
+        return Map.class;
+      case VERTEX:
+        return RowVertex.class;
+      case EDGE:
+        return RowEdge.class;
+      case PATH:
+        return Path.class;
+      case ROW:
+        return Row.class;
+      default:
+        return super.getJavaClass(type);
     }
+  }
 
-    @Override
-    public RelDataType createType(Type type) {
-        if (type instanceof Class && ((Class) type).isArray()) {
-            Class elementType = ((Class) type).getComponentType();
-            return canonize(new ArraySqlType(createType(elementType), true));
-        }
-        if (type == BinaryString.class) {
-            return super.createType(String.class);
-        }
-        return super.createType(type);
+  @Override
+  public RelDataType createType(Type type) {
+    if (type instanceof Class && ((Class) type).isArray()) {
+      Class elementType = ((Class) type).getComponentType();
+      return canonize(new ArraySqlType(createType(elementType), true));
     }
-
-    @Override
-    public RelDataType createTypeWithNullability(
-        final RelDataType type,
-        final boolean nullable) {
-        if (type.getSqlTypeName() == SqlTypeName.PATH || type instanceof MetaFieldType) {
-            return type;
-        }
-        return super.createTypeWithNullability(type, nullable);
+    if (type == BinaryString.class) {
+      return super.createType(String.class);
     }
+    return super.createType(type);
+  }
 
-    public static GQLJavaTypeFactory create() {
-        return new GQLJavaTypeFactory(new GQLRelDataTypeSystem());
+  @Override
+  public RelDataType createTypeWithNullability(final RelDataType type, final boolean nullable) {
+    if (type.getSqlTypeName() == SqlTypeName.PATH || type instanceof MetaFieldType) {
+      return type;
     }
+    return super.createTypeWithNullability(type, nullable);
+  }
+
+  public static GQLJavaTypeFactory create() {
+    return new GQLJavaTypeFactory(new GQLRelDataTypeSystem());
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLOperatorTable.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLOperatorTable.java
index fe98ee006..9af727b11 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLOperatorTable.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLOperatorTable.java
@@ -19,8 +19,8 @@
 
 package org.apache.geaflow.dsl.planner;
 
-import com.google.common.collect.Lists;
 import java.util.List;
+
 import org.apache.calcite.sql.*;
 import org.apache.calcite.sql.util.ChainedSqlOperatorTable;
 import org.apache.calcite.sql.util.ListSqlOperatorTable;
@@ -28,66 +28,69 @@
 import org.apache.geaflow.dsl.schema.GeaFlowFunction;
 import org.apache.geaflow.dsl.util.FunctionUtil;
 
-/**
- * An operator table for look up SQL operators and functions.
- */
+import com.google.common.collect.Lists;
+
+/** An operator table for look up SQL operators and functions. */
 public class GQLOperatorTable extends ChainedSqlOperatorTable {
 
-    private final Catalog catalog;
+  private final Catalog catalog;
 
-    private final GQLJavaTypeFactory typeFactory;
+  private final GQLJavaTypeFactory typeFactory;
 
-    private final GQLContext gqlContext;
+  private final GQLContext gqlContext;
 
-    public GQLOperatorTable(Catalog catalog, GQLJavaTypeFactory typeFactory,
-                            GQLContext gqlContext,
-                            SqlOperatorTable... tables) {
-        super(Lists.newArrayList(tables));
-        this.catalog = catalog;
-        this.gqlContext = gqlContext;
-        this.typeFactory = typeFactory;
-    }
+  public GQLOperatorTable(
+      Catalog catalog,
+      GQLJavaTypeFactory typeFactory,
+      GQLContext gqlContext,
+      SqlOperatorTable... tables) {
+    super(Lists.newArrayList(tables));
+    this.catalog = catalog;
+    this.gqlContext = gqlContext;
+    this.typeFactory = typeFactory;
+  }
 
-    /**
-     * Add a {@code SqlFunction} to operator table.
-     */
-    public void registerSqlFunction(String instance, GeaFlowFunction function) {
-        catalog.createFunction(instance, function);
+  /** Add a {@code SqlFunction} to operator table. */
+  public void registerSqlFunction(String instance, GeaFlowFunction function) {
+    catalog.createFunction(instance, function);
 
-        SqlFunction sqlFunction = FunctionUtil.createSqlFunction(function, typeFactory);
-        for (SqlOperatorTable operatorTable : tableList) {
-            if (operatorTable instanceof ListSqlOperatorTable) {
-                ((ListSqlOperatorTable) operatorTable).add(sqlFunction);
-                return;
-            }
-        }
+    SqlFunction sqlFunction = FunctionUtil.createSqlFunction(function, typeFactory);
+    for (SqlOperatorTable operatorTable : tableList) {
+      if (operatorTable instanceof ListSqlOperatorTable) {
+        ((ListSqlOperatorTable) operatorTable).add(sqlFunction);
+        return;
+      }
     }
+  }
 
-    public SqlFunction getSqlFunction(String instance, String name) {
-        for (SqlOperator operator : getOperatorList()) {
-            if (operator.getName().equalsIgnoreCase(name)) {
-                if (operator instanceof SqlFunction) {
-                    return (SqlFunction) operator;
-                }
-            }
-        }
-        GeaFlowFunction function = catalog.getFunction(instance, name);
-        if (function == null) {
-            return null;
+  public SqlFunction getSqlFunction(String instance, String name) {
+    for (SqlOperator operator : getOperatorList()) {
+      if (operator.getName().equalsIgnoreCase(name)) {
+        if (operator instanceof SqlFunction) {
+          return (SqlFunction) operator;
         }
-        return FunctionUtil.createSqlFunction(function, typeFactory);
+      }
     }
+    GeaFlowFunction function = catalog.getFunction(instance, name);
+    if (function == null) {
+      return null;
+    }
+    return FunctionUtil.createSqlFunction(function, typeFactory);
+  }
 
-    @Override
-    public void lookupOperatorOverloads(SqlIdentifier opName,
-                                        SqlFunctionCategory category, SqlSyntax syntax,
-                                        List operatorList) {
-        super.lookupOperatorOverloads(opName, category, syntax, operatorList);
-        if (operatorList.isEmpty() && category == SqlFunctionCategory.USER_DEFINED_FUNCTION) {
-            GeaFlowFunction function = catalog.getFunction(gqlContext.getCurrentInstance(), opName.getSimple());
-            if (function != null) {
-                operatorList.add(FunctionUtil.createSqlFunction(function, typeFactory));
-            }
-        }
+  @Override
+  public void lookupOperatorOverloads(
+      SqlIdentifier opName,
+      SqlFunctionCategory category,
+      SqlSyntax syntax,
+      List operatorList) {
+    super.lookupOperatorOverloads(opName, category, syntax, operatorList);
+    if (operatorList.isEmpty() && category == SqlFunctionCategory.USER_DEFINED_FUNCTION) {
+      GeaFlowFunction function =
+          catalog.getFunction(gqlContext.getCurrentInstance(), opName.getSimple());
+      if (function != null) {
+        operatorList.add(FunctionUtil.createSqlFunction(function, typeFactory));
+      }
     }
-}
\ No newline at end of file
+  }
+}
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLRelBuilder.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLRelBuilder.java
index 79f403fb6..ca0bb442d 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLRelBuilder.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLRelBuilder.java
@@ -32,39 +32,40 @@
 
 public class GQLRelBuilder extends RelBuilder {
 
-    protected GQLRelBuilder(Context context,
-                            RelOptCluster cluster,
-                            RelOptSchema relOptSchema) {
-        super(context, cluster, relOptSchema);
-    }
+  protected GQLRelBuilder(Context context, RelOptCluster cluster, RelOptSchema relOptSchema) {
+    super(context, cluster, relOptSchema);
+  }
 
-    public static GQLRelBuilder create(FrameworkConfig config, RexBuilder builder) {
+  public static GQLRelBuilder create(FrameworkConfig config, RexBuilder builder) {
 
-        final RelOptCluster[] clusters = new RelOptCluster[1];
-        final RelOptSchema[] relOptSchemas = new RelOptSchema[1];
+    final RelOptCluster[] clusters = new RelOptCluster[1];
+    final RelOptSchema[] relOptSchemas = new RelOptSchema[1];
 
-        Frameworks.withPrepare(
-            new Frameworks.PrepareAction(config) {
-                public Void apply(RelOptCluster cluster, RelOptSchema relOptSchema,
-                                  SchemaPlus rootSchema, CalciteServerStatement statement) {
-                    clusters[0] = cluster;
-                    relOptSchemas[0] = relOptSchema;
-                    return null;
-                }
-            });
-        RelOptCluster gqlCluster = RelOptCluster.create(clusters[0].getPlanner(), builder);
-        return new GQLRelBuilder(config.getContext(), gqlCluster, relOptSchemas[0]);
-    }
+    Frameworks.withPrepare(
+        new Frameworks.PrepareAction(config) {
+          public Void apply(
+              RelOptCluster cluster,
+              RelOptSchema relOptSchema,
+              SchemaPlus rootSchema,
+              CalciteServerStatement statement) {
+            clusters[0] = cluster;
+            relOptSchemas[0] = relOptSchema;
+            return null;
+          }
+        });
+    RelOptCluster gqlCluster = RelOptCluster.create(clusters[0].getPlanner(), builder);
+    return new GQLRelBuilder(config.getContext(), gqlCluster, relOptSchemas[0]);
+  }
 
-    public RelOptCluster getCluster() {
-        return this.cluster;
-    }
+  public RelOptCluster getCluster() {
+    return this.cluster;
+  }
 
-    public RelOptPlanner getPlanner() {
-        return this.cluster.getPlanner();
-    }
+  public RelOptPlanner getPlanner() {
+    return this.cluster.getPlanner();
+  }
 
-    public GQLJavaTypeFactory getTypeFactory() {
-        return (GQLJavaTypeFactory) super.getTypeFactory();
-    }
+  public GQLJavaTypeFactory getTypeFactory() {
+    return (GQLJavaTypeFactory) super.getTypeFactory();
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLRelDataTypeSystem.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLRelDataTypeSystem.java
index 868dc2bea..c94943bf5 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLRelDataTypeSystem.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLRelDataTypeSystem.java
@@ -20,6 +20,7 @@
 package org.apache.geaflow.dsl.planner;
 
 import java.io.Serializable;
+
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rel.type.RelDataTypeFactory;
 import org.apache.calcite.rel.type.RelDataTypeSystemImpl;
@@ -27,26 +28,25 @@
 
 public class GQLRelDataTypeSystem extends RelDataTypeSystemImpl implements Serializable {
 
-    private static final long serialVersionUID = -9093441501684237749L;
+  private static final long serialVersionUID = -9093441501684237749L;
 
-    @Override
-    public int getMaxNumericScale() {
-        return Integer.MAX_VALUE / 2;
-    }
+  @Override
+  public int getMaxNumericScale() {
+    return Integer.MAX_VALUE / 2;
+  }
 
-    @Override
-    public int getMaxNumericPrecision() {
-        return Integer.MAX_VALUE / 2;
-    }
+  @Override
+  public int getMaxNumericPrecision() {
+    return Integer.MAX_VALUE / 2;
+  }
 
-    public boolean shouldConvertRaggedUnionTypesToVarying() {
-        return true;
-    }
+  public boolean shouldConvertRaggedUnionTypesToVarying() {
+    return true;
+  }
 
-    @Override
-    public RelDataType deriveAvgAggType(RelDataTypeFactory typeFactory,
-                                        RelDataType argumentType) {
-        return typeFactory.createTypeWithNullability(
-            typeFactory.createSqlType(SqlTypeName.DOUBLE), true);
-    }
+  @Override
+  public RelDataType deriveAvgAggType(RelDataTypeFactory typeFactory, RelDataType argumentType) {
+    return typeFactory.createTypeWithNullability(
+        typeFactory.createSqlType(SqlTypeName.DOUBLE), true);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLRelOptTableImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLRelOptTableImpl.java
index fe828af1c..74cc261b8 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLRelOptTableImpl.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/planner/GQLRelOptTableImpl.java
@@ -19,10 +19,10 @@
 
 package org.apache.geaflow.dsl.planner;
 
-import com.google.common.collect.ImmutableList;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.function.Function;
+
 import org.apache.calcite.linq4j.tree.Expression;
 import org.apache.calcite.plan.RelOptSchema;
 import org.apache.calcite.prepare.RelOptTableImpl;
@@ -34,33 +34,37 @@
 import org.apache.geaflow.dsl.schema.GeaFlowGraph.EdgeTable;
 import org.apache.geaflow.dsl.schema.GeaFlowGraph.VertexTable;
 
+import com.google.common.collect.ImmutableList;
+
 public class GQLRelOptTableImpl extends RelOptTableImpl {
 
-    protected GQLRelOptTableImpl(RelOptSchema schema,
-                                 RelDataType rowType,
-                                 List names, Table table,
-                                 Function expressionFunction,
-                                 Double rowCount) {
-        super(schema, rowType, names, table, expressionFunction, rowCount);
-    }
+  protected GQLRelOptTableImpl(
+      RelOptSchema schema,
+      RelDataType rowType,
+      List names,
+      Table table,
+      Function expressionFunction,
+      Double rowCount) {
+    super(schema, rowType, names, table, expressionFunction, rowCount);
+  }
 
-    public static RelOptTableImpl create(RelOptSchema schema,
-                                         RelDataType rowType, Table table, ImmutableList names) {
-        return new GQLRelOptTableImpl(schema, rowType, names, table, null, null);
-    }
+  public static RelOptTableImpl create(
+      RelOptSchema schema, RelDataType rowType, Table table, ImmutableList names) {
+    return new GQLRelOptTableImpl(schema, rowType, names, table, null, null);
+  }
 
-    @Override
-    public List getColumnStrategies() {
-        List columnStrategies = super.getColumnStrategies();
-        if (table instanceof VertexTable) {
-            List vertexColumnStrategies = new ArrayList<>(columnStrategies);
-            vertexColumnStrategies.set(VertexType.LABEL_FIELD_POSITION, ColumnStrategy.VIRTUAL);
-            return vertexColumnStrategies;
-        } else if (table instanceof EdgeTable) {
-            List edgeColumnStrategies = new ArrayList<>(columnStrategies);
-            edgeColumnStrategies.set(EdgeType.LABEL_FIELD_POSITION, ColumnStrategy.VIRTUAL);
-            return edgeColumnStrategies;
-        }
-        return columnStrategies;
+  @Override
+  public List getColumnStrategies() {
+    List columnStrategies = super.getColumnStrategies();
+    if (table instanceof VertexTable) {
+      List vertexColumnStrategies = new ArrayList<>(columnStrategies);
+      vertexColumnStrategies.set(VertexType.LABEL_FIELD_POSITION, ColumnStrategy.VIRTUAL);
+      return vertexColumnStrategies;
+    } else if (table instanceof EdgeTable) {
+      List edgeColumnStrategies = new ArrayList<>(columnStrategies);
+      edgeColumnStrategies.set(EdgeType.LABEL_FIELD_POSITION, ColumnStrategy.VIRTUAL);
+      return edgeColumnStrategies;
     }
+    return columnStrategies;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/AbstractMatchNodeVisitor.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/AbstractMatchNodeVisitor.java
index 84e160827..78106862b 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/AbstractMatchNodeVisitor.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/AbstractMatchNodeVisitor.java
@@ -24,11 +24,11 @@
 
 public abstract class AbstractMatchNodeVisitor implements MatchNodeVisitor {
 
-    @Override
-    public T visit(RelNode node) {
-        if (node instanceof IMatchNode) {
-            return ((IMatchNode) node).accept(this);
-        }
-        throw new IllegalArgumentException("node is not a IMatchNode");
+  @Override
+  public T visit(RelNode node) {
+    if (node instanceof IMatchNode) {
+      return ((IMatchNode) node).accept(this);
     }
+    throw new IllegalArgumentException("node is not a IMatchNode");
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/ConstructGraph.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/ConstructGraph.java
index 405255889..5f645b215 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/ConstructGraph.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/ConstructGraph.java
@@ -19,8 +19,8 @@
 
 package org.apache.geaflow.dsl.rel;
 
-import com.google.common.collect.ImmutableList;
 import java.util.List;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
@@ -29,22 +29,31 @@
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException;
 
+import com.google.common.collect.ImmutableList;
+
 public abstract class ConstructGraph extends SingleRel {
 
-    protected final ImmutableList labelNames;
-
-    protected ConstructGraph(RelOptCluster cluster, RelTraitSet traits,
-                             RelNode input, List labelNames, RelDataType rowType) {
-        super(cluster, traits, input);
-        this.labelNames = ImmutableList.copyOf(labelNames);
-        this.rowType = rowType;
-        if (input.getRowType().getSqlTypeName() == SqlTypeName.PATH) {
-            throw new GeaFlowDSLException("Illegal input type: "
-                + input.getRowType().getSqlTypeName() + " for " + getRelTypeName());
-        }
-    }
+  protected final ImmutableList labelNames;
 
-    public ImmutableList getLabelNames() {
-        return labelNames;
+  protected ConstructGraph(
+      RelOptCluster cluster,
+      RelTraitSet traits,
+      RelNode input,
+      List labelNames,
+      RelDataType rowType) {
+    super(cluster, traits, input);
+    this.labelNames = ImmutableList.copyOf(labelNames);
+    this.rowType = rowType;
+    if (input.getRowType().getSqlTypeName() == SqlTypeName.PATH) {
+      throw new GeaFlowDSLException(
+          "Illegal input type: "
+              + input.getRowType().getSqlTypeName()
+              + " for "
+              + getRelTypeName());
     }
+  }
+
+  public ImmutableList getLabelNames() {
+    return labelNames;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/GQLToRelConverter.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/GQLToRelConverter.java
index d8e0e9685..d4a75ef89 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/GQLToRelConverter.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/GQLToRelConverter.java
@@ -22,9 +22,6 @@
 import static org.apache.calcite.rel.RelFieldCollation.NullDirection.UNSPECIFIED;
 import static org.apache.calcite.util.Static.RESOURCE;
 
-import com.google.common.base.Preconditions;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Lists;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -37,6 +34,7 @@
 import java.util.Map;
 import java.util.TreeSet;
 import java.util.stream.Collectors;
+
 import org.apache.calcite.linq4j.Ord;
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelOptTable;
@@ -120,1559 +118,1631 @@
 import org.apache.geaflow.dsl.validator.scope.GQLScope;
 import org.apache.geaflow.dsl.validator.scope.GQLWithBodyScope;
 
-public class GQLToRelConverter extends SqlToRelConverter {
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
 
-    private long queryIdCounter = 0;
+public class GQLToRelConverter extends SqlToRelConverter {
 
-    public GQLToRelConverter(ViewExpander viewExpander, SqlValidator validator,
-                             CatalogReader catalogReader, RelOptCluster cluster,
-                             SqlRexConvertletTable convertLetTable,
-                             Config config) {
-        super(viewExpander, validator, catalogReader, cluster, convertLetTable, config);
+  private long queryIdCounter = 0;
+
+  public GQLToRelConverter(
+      ViewExpander viewExpander,
+      SqlValidator validator,
+      CatalogReader catalogReader,
+      RelOptCluster cluster,
+      SqlRexConvertletTable convertLetTable,
+      Config config) {
+    super(viewExpander, validator, catalogReader, cluster, convertLetTable, config);
+  }
+
+  @Override
+  protected RelRoot convertQueryRecursive(SqlNode query, boolean top, RelDataType targetRowType) {
+    return RelRoot.of(convertQueryRecursive(query, top, targetRowType, null), query.getKind());
+  }
+
+  private RelNode convertQueryRecursive(
+      SqlNode query, boolean top, RelDataType targetRowType, Blackboard withBb) {
+    SqlKind kind = query.getKind();
+    switch (kind) {
+      case GQL_FILTER:
+        return convertGQLFilter((SqlFilterStatement) query, top, withBb);
+      case GQL_RETURN:
+        return convertGQLReturn((SqlReturnStatement) query, top, withBb);
+      case GQL_MATCH_PATTERN:
+        return convertGQLMatchPattern((SqlMatchPattern) query, top, withBb);
+      case GQL_LET:
+        return convertGQLLet((SqlLetStatement) query, top, withBb);
+      case GQL_ALGORITHM:
+        return convertGQLAlgorithm((SqlGraphAlgorithmCall) query, top, withBb);
+      default:
+        return super.convertQueryRecursive(query, top, targetRowType).rel;
+    }
+  }
+
+  @Override
+  public RelRoot convertWith(SqlWith with, boolean top) {
+    boolean containMatch = GQLNodeUtil.containMatch(with.body);
+    if (containMatch) {
+      if (with.withList.size() > 1) {
+        throw new GeaFlowDSLException(
+            "Multi-with list is not support for match at " + with.getParserPosition());
+      }
+      SqlWithItem withItem = (SqlWithItem) with.withList.get(0);
+      RelNode parameterNode = convertQueryRecursive(withItem.query, false, null).rel;
+      Blackboard withBb = createBlackboard(getValidator().getScopes(withItem), null, false);
+      withBb.setRoot(parameterNode, true);
+
+      RelNode queryNode = convertQueryRecursive(with.body, false, null, withBb);
+      RelNode parameterizedNode =
+          LogicalParameterizedRelNode.create(
+              getCluster(), getCluster().traitSet(), parameterNode, queryNode);
+      return RelRoot.of(parameterizedNode, with.getKind());
+    }
+    return super.convertWith(with, top);
+  }
+
+  private RelNode convertGQLFilter(
+      SqlFilterStatement filterStatement, boolean top, Blackboard withBb) {
+    SqlValidatorScope scope = getValidator().getScopes(filterStatement);
+    Blackboard bb = createBlackboard(scope, null, top).setWithBb(withBb);
+    convertGQLFrom(bb, filterStatement.getFrom(), withBb);
+    RexNode condition = bb.convertExpression(filterStatement.getCondition());
+    return LogicalFilter.create(bb.root, condition);
+  }
+
+  private RelNode convertGQLReturn(
+      SqlReturnStatement returnStatement, boolean top, Blackboard withBb) {
+    assert returnStatement != null : "return statement is null";
+    SqlValidatorScope scope = getValidator().getScopes(returnStatement);
+    Blackboard bb = createBlackboard(scope, null, top).setWithBb(withBb);
+
+    convertGQLFrom(bb, returnStatement.getFrom(), withBb);
+
+    List collationList = new ArrayList<>();
+    List orderExprList = new ArrayList<>();
+    if (returnStatement.getOrderBy() != null) {
+      for (SqlNode orderItem : returnStatement.getOrderList()) {
+        collationList.add(
+            convertOrderItem(
+                returnStatement, orderItem, orderExprList, Direction.ASCENDING, UNSPECIFIED));
+      }
+    }
+    RelCollation collation = cluster.traitSet().canonize(RelCollations.of(collationList));
+
+    SqlNodeList groupBy = returnStatement.getGroupBy();
+    if (((GQLValidatorImpl) validator).getAggregate(returnStatement.getReturnList()) != null
+        || groupBy != null) {
+      convertAgg(bb, returnStatement, orderExprList);
+    } else {
+      convertReturnList(bb, returnStatement, orderExprList);
+    }
+    SqlValidatorScope orderByScope = getValidator().getScopes(returnStatement.getOrderBy());
+    Blackboard orderByBb = createBlackboard(orderByScope, null, false).setWithBb(withBb);
+    orderByBb.setRoot(bb.root, false);
+    convertOrder(
+        getValidator().asSqlSelect(returnStatement),
+        orderByBb,
+        collation,
+        orderExprList,
+        returnStatement.getOffset(),
+        returnStatement.getFetch());
+    bb.setRoot(orderByBb.root, false);
+    return bb.root;
+  }
+
+  private void convertReturnList(
+      Blackboard bb, SqlReturnStatement returnStmt, List orderList) {
+    SqlNodeList returnList = returnStmt.getReturnList();
+    // Return Star & SubQueries are not support.
+    List fieldNames = new ArrayList<>();
+    List exprs = new ArrayList<>();
+    Collection aliases = new TreeSet<>();
+    int i = -1;
+    for (SqlNode node : returnList) {
+      exprs.add(bb.convertExpression(node));
+      fieldNames.add(deriveAlias(node, aliases, ++i));
     }
 
-    @Override
-    protected RelRoot convertQueryRecursive(SqlNode query, boolean top,
-                                            RelDataType targetRowType) {
-        return RelRoot.of(convertQueryRecursive(query, top, targetRowType, null), query.getKind());
-    }
-
-    private RelNode convertQueryRecursive(SqlNode query, boolean top,
-                                          RelDataType targetRowType, Blackboard withBb) {
-        SqlKind kind = query.getKind();
-        switch (kind) {
-            case GQL_FILTER:
-                return convertGQLFilter((SqlFilterStatement) query, top, withBb);
-            case GQL_RETURN:
-                return convertGQLReturn((SqlReturnStatement) query, top, withBb);
-            case GQL_MATCH_PATTERN:
-                return convertGQLMatchPattern((SqlMatchPattern) query, top, withBb);
-            case GQL_LET:
-                return convertGQLLet((SqlLetStatement) query, top, withBb);
-            case GQL_ALGORITHM:
-                return convertGQLAlgorithm((SqlGraphAlgorithmCall) query, top, withBb);
-            default:
-                return super.convertQueryRecursive(query, top, targetRowType).rel;
-        }
+    SqlValidatorScope orderScope = getValidator().getScopes(returnStmt.getOrderBy());
+    for (SqlNode node2 : orderList) {
+      SqlNode expandExpr =
+          ((GQLValidatorImpl) validator).expandReturnGroupOrderExpr(returnStmt, orderScope, node2);
+      exprs.add(bb.convertExpression(expandExpr));
+      fieldNames.add(deriveAlias(node2, aliases, ++i));
+    }
+    fieldNames =
+        SqlValidatorUtil.uniquify(fieldNames, this.catalogReader.nameMatcher().isCaseSensitive());
+    RelDataType rowType =
+        RexUtil.createStructType(
+            this.cluster.getTypeFactory(), exprs, fieldNames, SqlValidatorUtil.F_SUGGESTER);
+    bb.setRoot(LogicalProject.create(bb.root, exprs, rowType), true);
+
+    if (returnStmt.isDistinct()) {
+      convertDistinct(bb, true);
+    }
+  }
+
+  private void convertAgg(Blackboard bb, SqlReturnStatement returnStmt, List orderList) {
+    assert bb.root != null : "precondition: child != null when converting AGG";
+
+    SqlNodeList returnList = returnStmt.getReturnList();
+    final ReturnAggregateFinder aggregateFinder = new ReturnAggregateFinder();
+    returnList.accept(aggregateFinder);
+    GQLReturnScope scope = (GQLReturnScope) getValidator().getScopes(returnStmt);
+    final GQLReturnScope.Resolved r = scope.resolved.get();
+    final GQLAggConverter aggConverter = new GQLAggConverter(bb, returnStmt);
+    for (SqlNode groupExpr : r.groupExprList) {
+      aggConverter.addGroupExpr(groupExpr);
     }
 
-    @Override
-    public RelRoot convertWith(SqlWith with, boolean top) {
-        boolean containMatch = GQLNodeUtil.containMatch(with.body);
-        if (containMatch) {
-            if (with.withList.size() > 1) {
-                throw new GeaFlowDSLException("Multi-with list is not support for match at "
-                    + with.getParserPosition());
-            }
-            SqlWithItem withItem = (SqlWithItem) with.withList.get(0);
-            RelNode parameterNode = convertQueryRecursive(withItem.query, false, null).rel;
-            Blackboard withBb = createBlackboard(getValidator().getScopes(withItem), null, false);
-            withBb.setRoot(parameterNode, true);
-
-            RelNode queryNode = convertQueryRecursive(with.body, false, null, withBb);
-            RelNode parameterizedNode = LogicalParameterizedRelNode.create(getCluster(), getCluster().traitSet(),
-                parameterNode, queryNode);
-            return RelRoot.of(parameterizedNode, with.getKind());
-        }
-        return super.convertWith(with, top);
+    final List> projects = new ArrayList<>();
+    try {
+      Preconditions.checkArgument(bb.getAgg() == null, "already in agg mode");
+      bb.setAgg(aggConverter);
+
+      returnList.accept(aggConverter);
+      // Assert we don't have dangling items left in the stack
+      assert !aggConverter.inOver;
+      for (SqlNode expr : orderList) {
+        expr.accept(aggConverter);
+        assert !aggConverter.inOver;
+      }
+      // compute inputs to the aggregator
+      List> preExprs = aggConverter.getPreExprs();
+
+      final RelNode inputRel = bb.root;
+
+      // Project the expressions required by agg and having.
+      bb.setRoot(
+          getRelBuilder()
+              .push(inputRel)
+              .projectNamed(Pair.left(preExprs), Pair.right(preExprs), true)
+              .build(),
+          false);
+      // Add the aggregator
+      bb.setRoot(createAggregate(bb, r.groupSet, r.groupSets, aggConverter.getAggCalls()), false);
+
+      // Now sub-queries in the entire select list have been converted.
+      // Convert the select expressions to get the final list to be
+      // projected.
+      int k = 0;
+
+      // For select expressions, use the field names previously assigned
+      // by the validator. If we derive afresh, we might generate names
+      // like "EXPR$2" that don't match the names generated by the
+      // validator. This is especially the case when there are system
+      // fields; system fields appear in the relnode's rowtype but do not
+      // (yet) appear in the validator type.
+      GQLReturnScope returnScope = null;
+      SqlValidatorScope curScope = bb.scope;
+      while (curScope instanceof DelegatingScope) {
+        if (curScope instanceof GQLReturnScope) {
+          returnScope = (GQLReturnScope) curScope;
+          break;
+        }
+        curScope = ((DelegatingScope) curScope).getParent();
+      }
+      assert returnScope != null;
+
+      final SqlValidatorNamespace returnNamespace = validator.getNamespace(returnScope.getNode());
+      final List names = returnNamespace.getRowType().getFieldNames();
+      int sysFieldCount = returnList.size() - names.size();
+      for (SqlNode expr : returnList) {
+        projects.add(
+            Pair.of(
+                bb.convertExpression(expr),
+                k < sysFieldCount
+                    ? validator.deriveAlias(expr, k++)
+                    : names.get(k++ - sysFieldCount)));
+      }
+      for (SqlNode expr : orderList) {
+        projects.add(Pair.of(bb.convertExpression(expr), validator.deriveAlias(expr, k++)));
+      }
+    } finally {
+      bb.setAgg(null);
     }
 
-    private RelNode convertGQLFilter(SqlFilterStatement filterStatement, boolean top, Blackboard withBb) {
-        SqlValidatorScope scope = getValidator().getScopes(filterStatement);
-        Blackboard bb = createBlackboard(scope, null, top).setWithBb(withBb);
-        convertGQLFrom(bb, filterStatement.getFrom(), withBb);
-        RexNode condition = bb.convertExpression(filterStatement.getCondition());
-        return LogicalFilter.create(bb.root, condition);
+    getRelBuilder().push(bb.root);
+    // implement the SELECT list
+    getRelBuilder().project(Pair.left(projects), Pair.right(projects)).rename(Pair.right(projects));
+    bb.setRoot(getRelBuilder().build(), true);
+  }
+
+  private RelNode convertGQLMatchPattern(
+      SqlMatchPattern matchPattern, boolean top, Blackboard withBb) {
+    SqlValidatorScope scope = getValidator().getScopes(matchPattern);
+    Blackboard bb = createBlackboard(scope, null, top).setWithBb(withBb);
+    convertGQLFrom(bb, matchPattern.getFrom(), withBb);
+
+    SqlNodeList pathPatterns = matchPattern.getPathPatterns();
+
+    List relPathPatterns = new ArrayList<>();
+    List pathPatternTypes = new ArrayList<>();
+    // concat path pattern
+    for (SqlNode pathPattern : pathPatterns) {
+      IMatchNode relPathPattern = convertPathPattern(pathPattern, withBb);
+      PathRecordType pathType =
+          pathPattern instanceof SqlUnionPathPattern
+              ? (PathRecordType) getValidator().getNamespace(pathPattern).getRowType()
+              : (PathRecordType) getValidator().getValidatedNodeType(pathPattern);
+      // concat path pattern and type if then can connect in a line.
+      concatPathPatterns(relPathPatterns, pathPatternTypes, relPathPattern, pathType);
     }
 
-    private RelNode convertGQLReturn(SqlReturnStatement returnStatement, boolean top, Blackboard withBb) {
-        assert returnStatement != null : "return statement is null";
-        SqlValidatorScope scope = getValidator().getScopes(returnStatement);
-        Blackboard bb = createBlackboard(scope, null, top).setWithBb(withBb);
+    IMatchNode joinPattern = null;
+    for (IMatchNode relPathPattern : relPathPatterns) {
+      if (joinPattern == null) {
+        joinPattern = relPathPattern;
+      } else {
+        IMatchNode left = joinPattern;
+        IMatchNode right = relPathPattern;
+        RexNode condition =
+            GQLRelUtil.createPathJoinCondition(left, right, isCaseSensitive(), rexBuilder);
+        joinPattern =
+            MatchJoin.create(
+                left.getCluster(), left.getTraitSet(), left, right, condition, JoinRelType.INNER);
+      }
+    }
 
-        convertGQLFrom(bb, returnStatement.getFrom(), withBb);
+    if (matchPattern.isDistinct()) {
+      joinPattern = MatchDistinct.create(joinPattern);
+    }
 
-        List collationList = new ArrayList<>();
-        List orderExprList = new ArrayList<>();
-        if (returnStatement.getOrderBy() != null) {
-            for (SqlNode orderItem : returnStatement.getOrderList()) {
-                collationList.add(convertOrderItem(returnStatement, orderItem, orderExprList, Direction.ASCENDING,
-                    UNSPECIFIED));
-            }
-        }
-        RelCollation collation = cluster.traitSet().canonize(RelCollations.of(collationList));
+    RelDataType graphType = getValidator().getValidatedNodeType(matchPattern);
+    GraphMatch graphMatch;
+    if (bb.root instanceof GraphMatch) { // merge with pre graph match.
+      GraphMatch input = (GraphMatch) bb.root;
+      graphMatch = input.merge(joinPattern);
+    } else {
+      graphMatch = LogicalGraphMatch.create(getCluster(), bb.root, joinPattern, graphType);
+    }
+    if (matchPattern.getWhere() != null) {
+      SqlNode where = matchPattern.getWhere();
+      SqlValidatorScope whereScope = getValidator().getScopes(where);
+      GQLBlackboard whereBb = createBlackboard(whereScope, null, false).setWithBb(withBb);
+      whereBb.setRoot(graphMatch, true);
+      replaceSubQueries(whereBb, where, RelOptUtil.Logic.UNKNOWN_AS_FALSE);
+      RexNode condition = whereBb.convertExpression(where);
+
+      IMatchNode newPathPattern =
+          MatchFilter.create(
+              graphMatch.getPathPattern(), condition, graphMatch.getPathPattern().getPathSchema());
+      graphMatch = graphMatch.copy(newPathPattern);
+    }
 
-        SqlNodeList groupBy = returnStatement.getGroupBy();
-        if (((GQLValidatorImpl) validator).getAggregate(returnStatement.getReturnList()) != null
-            || groupBy != null) {
-            convertAgg(bb, returnStatement, orderExprList);
+    List orderByExpList = new ArrayList<>();
+    if (matchPattern.getOrderBy() != null) {
+      SqlValidatorScope orderScope = getValidator().getScopes(matchPattern.getOrderBy());
+      Blackboard orderBb = createBlackboard(orderScope, null, top).setWithBb(withBb);
+      orderBb.setRoot(graphMatch, true);
+      for (SqlNode orderItem : matchPattern.getOrderBy()) {
+        orderByExpList.add(orderBb.convertExpression(orderItem));
+      }
+    }
+    // convert match order
+    SqlNode limit = matchPattern.getLimit();
+    if (limit != null || orderByExpList.size() > 0) {
+      MatchPathSort newPathPattern =
+          MatchPathSort.create(
+              graphMatch.getPathPattern(),
+              orderByExpList,
+              limit == null ? null : convertExpression(limit),
+              graphMatch.getPathPattern().getPathSchema());
+
+      graphMatch = graphMatch.copy(newPathPattern);
+    }
+    return graphMatch;
+  }
+
+  private void concatPathPatterns(
+      List concatPatterns,
+      List concatPathTypes,
+      IMatchNode pathPattern,
+      PathRecordType pathType) {
+    if (concatPatterns.isEmpty() || !(pathPattern instanceof SingleMatchNode)) {
+      concatPatterns.add(pathPattern);
+      concatPathTypes.add(pathType);
+      return;
+    }
+    SingleMatchNode singlePathPattern = (SingleMatchNode) pathPattern;
+    String firstLabel = GQLRelUtil.getFirstMatchNode(singlePathPattern).getLabel();
+    String latestLabel = GQLRelUtil.getLatestMatchNode(singlePathPattern).getLabel();
+    int i;
+    for (i = 0; i < concatPatterns.size(); i++) {
+      IMatchNode pathPattern2 = concatPatterns.get(i);
+      if (pathPattern2 instanceof SingleMatchNode) {
+        String latestLabel2 =
+            GQLRelUtil.getLatestMatchNode((SingleMatchNode) pathPattern2).getLabel();
+        // pathPattern2 is "(a) - (b) - (c)", where singlePathPattern is "(c) - (d)"
+        // Then concat pathPattern2 with singlePathPattern and return "(a) - (b) - (c) - (d)"
+        if (getValidator().nameMatcher().matches(firstLabel, latestLabel2)) {
+          IMatchNode concatPattern =
+              GQLRelUtil.concatPathPattern(
+                  (SingleMatchNode) pathPattern2, singlePathPattern, isCaseSensitive());
+          concatPatterns.set(i, concatPattern);
+
+          PathRecordType pathType2 = concatPathTypes.get(i);
+          PathRecordType concatPathType = pathType2.concat(pathType, isCaseSensitive());
+          concatPathTypes.set(i, concatPathType);
+          break;
         } else {
-            convertReturnList(bb, returnStatement, orderExprList);
-        }
-        SqlValidatorScope orderByScope =
-            getValidator().getScopes(returnStatement.getOrderBy());
-        Blackboard orderByBb = createBlackboard(orderByScope, null, false).setWithBb(withBb);
-        orderByBb.setRoot(bb.root, false);
-        convertOrder(getValidator().asSqlSelect(returnStatement), orderByBb,
-            collation, orderExprList, returnStatement.getOffset(), returnStatement.getFetch());
-        bb.setRoot(orderByBb.root, false);
-        return bb.root;
-    }
-
-    private void convertReturnList(Blackboard bb,
-                                   SqlReturnStatement returnStmt, List orderList) {
-        SqlNodeList returnList = returnStmt.getReturnList();
-        //Return Star & SubQueries are not support.
-        List fieldNames = new ArrayList<>();
-        List exprs = new ArrayList<>();
-        Collection aliases = new TreeSet<>();
-        int i = -1;
-        for (SqlNode node : returnList) {
-            exprs.add(bb.convertExpression(node));
-            fieldNames.add(deriveAlias(node, aliases, ++i));
-        }
-
-        SqlValidatorScope orderScope =
-            getValidator().getScopes(returnStmt.getOrderBy());
-        for (SqlNode node2 : orderList) {
-            SqlNode expandExpr =
-                ((GQLValidatorImpl) validator).expandReturnGroupOrderExpr(returnStmt,
-                    orderScope, node2);
-            exprs.add(bb.convertExpression(expandExpr));
-            fieldNames.add(deriveAlias(node2, aliases, ++i));
-        }
-        fieldNames = SqlValidatorUtil.uniquify(fieldNames, this.catalogReader.nameMatcher().isCaseSensitive());
-        RelDataType rowType = RexUtil.createStructType(this.cluster.getTypeFactory(), exprs, fieldNames,
-            SqlValidatorUtil.F_SUGGESTER);
-        bb.setRoot(LogicalProject.create(bb.root, exprs, rowType), true);
-
-        if (returnStmt.isDistinct()) {
-            convertDistinct(bb, true);
-        }
+          String firstLabel2 =
+              GQLRelUtil.getFirstMatchNode((SingleMatchNode) pathPattern2).getLabel();
+          // singlePathPattern is "(a) - (b) - (c)", where pathPattern2 is "(c) - (d)"
+          // Then concat singlePathPattern with pathPattern2 and return "(a) - (b) - (c) - (d)"
+          if (getValidator().nameMatcher().matches(firstLabel2, latestLabel)) {
+            IMatchNode concatPattern =
+                GQLRelUtil.concatPathPattern(
+                    singlePathPattern, (SingleMatchNode) pathPattern2, isCaseSensitive());
+            concatPatterns.set(i, concatPattern);
+
+            PathRecordType pathType2 = concatPathTypes.get(i);
+            PathRecordType concatPathType = pathType.concat(pathType2, isCaseSensitive());
+            concatPathTypes.set(i, concatPathType);
+            break;
+          }
+        }
+      }
     }
+    // If merge not happen, just add it.
+    if (i == concatPatterns.size()) {
+      concatPatterns.add(pathPattern);
+      concatPathTypes.add(pathType);
+    }
+  }
+
+  private SingleMatchNode convertMatchNodeWhere(
+      SqlMatchNode matchNode, SqlNode matchWhere, IMatchNode input, Blackboard withBb) {
+    assert input != null;
+    SqlValidatorScope whereScope = getValidator().getScopes(matchWhere);
+    Blackboard nodeBb = createBlackboard(whereScope, null, false).setWithBb(withBb);
+    nodeBb.setRoot(new WhereMatchNode(input), true);
+    if (withBb != null) {
+      nodeBb.addInput(withBb.root);
+    }
+    replaceSubQueries(nodeBb, matchWhere, RelOptUtil.Logic.UNKNOWN_AS_FALSE);
+    RexNode condition = nodeBb.convertExpression(matchWhere);
 
-    private void convertAgg(Blackboard bb,
-                            SqlReturnStatement returnStmt, List orderList) {
-        assert bb.root != null : "precondition: child != null when converting AGG";
-
-        SqlNodeList returnList = returnStmt.getReturnList();
-        final ReturnAggregateFinder aggregateFinder = new ReturnAggregateFinder();
-        returnList.accept(aggregateFinder);
-        GQLReturnScope scope = (GQLReturnScope) getValidator().getScopes(returnStmt);
-        final GQLReturnScope.Resolved r = scope.resolved.get();
-        final GQLAggConverter aggConverter = new GQLAggConverter(bb, returnStmt);
-        for (SqlNode groupExpr : r.groupExprList) {
-            aggConverter.addGroupExpr(groupExpr);
-        }
-
-        final List> projects = new ArrayList<>();
-        try {
-            Preconditions.checkArgument(bb.getAgg() == null, "already in agg mode");
-            bb.setAgg(aggConverter);
-
-            returnList.accept(aggConverter);
-            // Assert we don't have dangling items left in the stack
-            assert !aggConverter.inOver;
-            for (SqlNode expr : orderList) {
-                expr.accept(aggConverter);
-                assert !aggConverter.inOver;
-            }
-            // compute inputs to the aggregator
-            List> preExprs = aggConverter.getPreExprs();
-
-            final RelNode inputRel = bb.root;
-
-            // Project the expressions required by agg and having.
-            bb.setRoot(
-                getRelBuilder().push(inputRel)
-                    .projectNamed(Pair.left(preExprs), Pair.right(preExprs), true)
-                    .build(),
-                false);
-            // Add the aggregator
-            bb.setRoot(
-                createAggregate(bb, r.groupSet, r.groupSets,
-                    aggConverter.getAggCalls()), false);
-
-            // Now sub-queries in the entire select list have been converted.
-            // Convert the select expressions to get the final list to be
-            // projected.
-            int k = 0;
-
-            // For select expressions, use the field names previously assigned
-            // by the validator. If we derive afresh, we might generate names
-            // like "EXPR$2" that don't match the names generated by the
-            // validator. This is especially the case when there are system
-            // fields; system fields appear in the relnode's rowtype but do not
-            // (yet) appear in the validator type.
-            GQLReturnScope returnScope = null;
-            SqlValidatorScope curScope = bb.scope;
-            while (curScope instanceof DelegatingScope) {
-                if (curScope instanceof GQLReturnScope) {
-                    returnScope = (GQLReturnScope) curScope;
-                    break;
-                }
-                curScope = ((DelegatingScope) curScope).getParent();
-            }
-            assert returnScope != null;
-
-            final SqlValidatorNamespace returnNamespace =
-                validator.getNamespace(returnScope.getNode());
-            final List names =
-                returnNamespace.getRowType().getFieldNames();
-            int sysFieldCount = returnList.size() - names.size();
-            for (SqlNode expr : returnList) {
-                projects.add(
-                    Pair.of(bb.convertExpression(expr),
-                        k < sysFieldCount
-                            ? validator.deriveAlias(expr, k++)
-                            : names.get(k++ - sysFieldCount)));
-            }
-            for (SqlNode expr : orderList) {
-                projects.add(
-                    Pair.of(bb.convertExpression(expr),
-                        validator.deriveAlias(expr, k++)));
-            }
-        } finally {
-            bb.setAgg(null);
-        }
-
-        getRelBuilder().push(bb.root);
-        // implement the SELECT list
-        getRelBuilder().project(Pair.left(projects), Pair.right(projects))
-            .rename(Pair.right(projects));
-        bb.setRoot(getRelBuilder().build(), true);
-    }
-
-    private RelNode convertGQLMatchPattern(SqlMatchPattern matchPattern, boolean top, Blackboard withBb) {
-        SqlValidatorScope scope = getValidator().getScopes(matchPattern);
-        Blackboard bb = createBlackboard(scope, null, top).setWithBb(withBb);
-        convertGQLFrom(bb, matchPattern.getFrom(), withBb);
-
-        SqlNodeList pathPatterns = matchPattern.getPathPatterns();
-
-        List relPathPatterns = new ArrayList<>();
-        List pathPatternTypes = new ArrayList<>();
-        // concat path pattern
-        for (SqlNode pathPattern : pathPatterns) {
-            IMatchNode relPathPattern = convertPathPattern(pathPattern, withBb);
-            PathRecordType pathType =
-                pathPattern instanceof SqlUnionPathPattern
-                    ? (PathRecordType) getValidator().getNamespace(pathPattern).getRowType() :
-                    (PathRecordType) getValidator().getValidatedNodeType(pathPattern);
-            // concat path pattern and type if then can connect in a line.
-            concatPathPatterns(relPathPatterns, pathPatternTypes, relPathPattern, pathType);
-        }
+    PathRecordType pathRecordType = input.getPathSchema();
+    RelDataTypeField pathField =
+        pathRecordType.getField(matchNode.getName(), isCaseSensitive(), false);
+    condition = GQLRexUtil.toPathInputRefForWhere(pathField, condition);
+    return MatchFilter.create(input, condition, input.getPathSchema());
+  }
 
-        IMatchNode joinPattern = null;
-        for (IMatchNode relPathPattern : relPathPatterns) {
-            if (joinPattern == null) {
-                joinPattern = relPathPattern;
-            } else {
-                IMatchNode left = joinPattern;
-                IMatchNode right = relPathPattern;
-                RexNode condition = GQLRelUtil.createPathJoinCondition(left, right, isCaseSensitive(), rexBuilder);
-                joinPattern = MatchJoin.create(left.getCluster(), left.getTraitSet(),
-                    left, right, condition, JoinRelType.INNER);
-            }
-        }
-
-        if (matchPattern.isDistinct()) {
-            joinPattern = MatchDistinct.create(joinPattern);
-        }
+  private static class WhereMatchNode extends AbstractRelNode {
 
-        RelDataType graphType = getValidator().getValidatedNodeType(matchPattern);
-        GraphMatch graphMatch;
-        if (bb.root instanceof GraphMatch) { // merge with pre graph match.
-            GraphMatch input = (GraphMatch) bb.root;
-            graphMatch = input.merge(joinPattern);
-        } else {
-            graphMatch = LogicalGraphMatch.create(getCluster(), bb.root, joinPattern, graphType);
-        }
-        if (matchPattern.getWhere() != null) {
-            SqlNode where = matchPattern.getWhere();
-            SqlValidatorScope whereScope = getValidator().getScopes(where);
-            GQLBlackboard whereBb = createBlackboard(whereScope, null, false)
-                .setWithBb(withBb);
-            whereBb.setRoot(graphMatch, true);
-            replaceSubQueries(whereBb, where, RelOptUtil.Logic.UNKNOWN_AS_FALSE);
-            RexNode condition = whereBb.convertExpression(where);
-
-            IMatchNode newPathPattern = MatchFilter.create(graphMatch.getPathPattern(),
-                condition, graphMatch.getPathPattern().getPathSchema());
-            graphMatch = graphMatch.copy(newPathPattern);
-        }
-
-        List orderByExpList = new ArrayList<>();
-        if (matchPattern.getOrderBy() != null) {
-            SqlValidatorScope orderScope = getValidator().getScopes(matchPattern.getOrderBy());
-            Blackboard orderBb = createBlackboard(orderScope, null, top)
-                .setWithBb(withBb);
-            orderBb.setRoot(graphMatch, true);
-            for (SqlNode orderItem : matchPattern.getOrderBy()) {
-                orderByExpList.add(orderBb.convertExpression(orderItem));
-            }
-        }
-        // convert match order
-        SqlNode limit = matchPattern.getLimit();
-        if (limit != null || orderByExpList.size() > 0) {
-            MatchPathSort newPathPattern = MatchPathSort.create(graphMatch.getPathPattern(),
-                orderByExpList, limit == null ? null : convertExpression(limit),
-                graphMatch.getPathPattern().getPathSchema());
-
-            graphMatch = graphMatch.copy(newPathPattern);
-        }
-        return graphMatch;
+    public WhereMatchNode(IMatchNode matchNode) {
+      super(matchNode.getCluster(), matchNode.getTraitSet());
+      this.rowType = matchNode.getNodeType();
     }
-
-    private void concatPathPatterns(List concatPatterns, List concatPathTypes,
-                                    IMatchNode pathPattern, PathRecordType pathType) {
-        if (concatPatterns.isEmpty() || !(pathPattern instanceof SingleMatchNode)) {
-            concatPatterns.add(pathPattern);
-            concatPathTypes.add(pathType);
-            return;
-        }
-        SingleMatchNode singlePathPattern = (SingleMatchNode) pathPattern;
-        String firstLabel = GQLRelUtil.getFirstMatchNode(singlePathPattern).getLabel();
-        String latestLabel = GQLRelUtil.getLatestMatchNode(singlePathPattern).getLabel();
+  }
+
+  private IMatchNode convertPathPattern(SqlNode sqlNode, Blackboard withBb) {
+    if (sqlNode instanceof SqlUnionPathPattern) {
+      SqlUnionPathPattern unionPathPattern = (SqlUnionPathPattern) sqlNode;
+      IMatchNode left = convertPathPattern(unionPathPattern.getLeft(), withBb);
+      IMatchNode right = convertPathPattern(unionPathPattern.getRight(), withBb);
+      return MatchUnion.create(
+          getCluster(),
+          getCluster().traitSet(),
+          Lists.newArrayList(left, right),
+          unionPathPattern.isUnionAll());
+    }
+    SqlPathPattern pathPattern = (SqlPathPattern) sqlNode;
+    SingleMatchNode relPathPattern = null;
+    SqlMatchNode preMatchNode = null;
+    for (SqlNode pathNode : pathPattern.getPathNodes()) {
+      PathRecordType pathType = (PathRecordType) getValidator().getValidatedNodeType(pathNode);
+      switch (pathNode.getKind()) {
+        case GQL_MATCH_NODE:
+          SqlMatchNode matchNode = (SqlMatchNode) pathNode;
+          RelDataType nodeType = getValidator().getMatchNodeType(matchNode);
+
+          SingleMatchNode vertexMatch =
+              VertexMatch.create(
+                  getCluster(),
+                  relPathPattern,
+                  matchNode.getName(),
+                  matchNode.getLabelNames(),
+                  nodeType,
+                  pathType);
+          if (matchNode.getWhere() != null) {
+            vertexMatch =
+                convertMatchNodeWhere(matchNode, matchNode.getWhere(), vertexMatch, withBb);
+          }
+          // generate for regex match
+          if (preMatchNode instanceof SqlMatchEdge
+              && ((SqlMatchEdge) preMatchNode).isRegexMatch()) {
+            SqlMatchEdge inputEdgeNode = (SqlMatchEdge) preMatchNode;
+            EdgeMatch inputEdgeMatch =
+                (EdgeMatch) vertexMatch.find(node -> node instanceof EdgeMatch);
+            relPathPattern = convertRegexMatch(inputEdgeNode, inputEdgeMatch, vertexMatch);
+          } else {
+            relPathPattern = vertexMatch;
+          }
+          if (getValidator().getStartCycleMatchNode((SqlMatchNode) pathNode) != null) {
+            relPathPattern = translateCycleMatchNode(relPathPattern, pathNode);
+          }
+          break;
+        case GQL_MATCH_EDGE:
+          SqlMatchEdge matchEdge = (SqlMatchEdge) pathNode;
+          RelDataType edgeType = getValidator().getMatchNodeType(matchEdge);
+
+          EdgeMatch edgeMatch =
+              EdgeMatch.create(
+                  getCluster(),
+                  relPathPattern,
+                  matchEdge.getName(),
+                  matchEdge.getLabelNames(),
+                  matchEdge.getDirection(),
+                  edgeType,
+                  pathType);
+
+          if (matchEdge.getWhere() != null) {
+            relPathPattern =
+                convertMatchNodeWhere(matchEdge, matchEdge.getWhere(), edgeMatch, withBb);
+          } else {
+            relPathPattern = edgeMatch;
+          }
+          break;
+        default:
+          throw new IllegalArgumentException("Illegal path node kind: " + pathNode.getKind());
+      }
+      preMatchNode = (SqlMatchNode) pathNode;
+    }
+    return relPathPattern;
+  }
+
+  private SingleMatchNode translateCycleMatchNode(
+      SingleMatchNode relPathPattern, SqlNode pathNode) {
+    PathRecordType pathRecordType = relPathPattern.getPathSchema();
+    int rightIndex =
+        pathRecordType
+            .getField(((SqlMatchNode) pathNode).getName(), isCaseSensitive(), false)
+            .getIndex();
+    int leftIndex =
+        pathRecordType
+            .getField(
+                getValidator().getStartCycleMatchNode((SqlMatchNode) pathNode).getName(),
+                isCaseSensitive(),
+                false)
+            .getIndex();
+    assert leftIndex >= 0 && rightIndex >= 0;
+    RexNode condition =
+        getRexBuilder()
+            .makeCall(
+                SqlStdOperatorTable.EQUALS,
+                getRexBuilder().makeInputRef(pathRecordType, leftIndex),
+                getRexBuilder().makeInputRef(pathRecordType, rightIndex));
+    return MatchFilter.create(relPathPattern, condition, relPathPattern.getPathSchema());
+  }
+
+  private SingleMatchNode convertRegexMatch(
+      SqlMatchEdge regexEdge, EdgeMatch regexEdgeMatch, SingleMatchNode vertexMatch) {
+    IMatchNode loopStart = (IMatchNode) regexEdgeMatch.getInput();
+    SubQueryStart queryStart =
+        SubQueryStart.create(
+            getCluster(),
+            loopStart.getTraitSet(),
+            generateSubQueryName(),
+            loopStart.getPathSchema(),
+            (VertexRecordType) loopStart.getNodeType());
+    // replace the input of regexEdgeMatch to queryStart and clone vertexMatch
+    SingleMatchNode loopBody = GQLRelUtil.replaceInput(vertexMatch, regexEdgeMatch, queryStart);
+
+    RexNode utilCondition = getRexBuilder().makeLiteral(true);
+    return LoopUntilMatch.create(
+        getCluster(),
+        vertexMatch.getTraitSet(),
+        loopStart,
+        loopBody,
+        utilCondition,
+        regexEdge.getMinHop(),
+        regexEdge.getMaxHop(),
+        vertexMatch.getPathSchema());
+  }
+
+  private RelNode convertGQLAlgorithm(
+      SqlGraphAlgorithmCall algorithmCall, boolean top, Blackboard withBb) {
+
+    SqlValidatorScope scope = getValidator().getScopes(algorithmCall);
+    Blackboard bb = createBlackboard(scope, null, top).setWithBb(withBb);
+    convertGQLFrom(bb, algorithmCall.getFrom(), withBb);
+
+    Object[] params =
+        algorithmCall.getParameters() == null
+            ? new Object[0]
+            : algorithmCall.getParameters().getList().stream()
+                .map(sqlNode -> GQLRexUtil.getLiteralValue(bb.convertExpression(sqlNode)))
+                .toArray();
+    return LogicalGraphAlgorithm.create(
+        cluster,
+        getCluster().traitSet(),
+        bb.root,
+        ((GeaFlowUserDefinedGraphAlgorithm) algorithmCall.getOperator()).getImplementClass(),
+        params);
+  }
+
+  private RelNode convertGQLLet(SqlLetStatement letStatement, boolean top, Blackboard withBb) {
+    SqlValidatorScope scope = getValidator().getScopes(letStatement);
+    Blackboard bb = createBlackboard(scope, null, top).setWithBb(withBb);
+    convertGQLFrom(bb, letStatement.getFrom(), withBb);
+
+    RexNode rightExpression = bb.convertExpression(letStatement.getExpression());
+
+    PathRecordType letType = (PathRecordType) getValidator().getValidatedNodeType(letStatement);
+    PathInputRef leftRex = convertLeftLabel(letStatement.getLeftLabel(), letType);
+
+    String leftVarField = letStatement.getLeftField();
+    List operands = new ArrayList<>();
+    SqlIdentifier[] fieldNameNodes = new SqlIdentifier[leftRex.getType().getFieldCount()];
+    Map rex2VariableInfo = new HashMap<>();
+    int c = 0;
+    for (RelDataTypeField field : leftRex.getType().getFieldList()) {
+      VariableInfo variableInfo;
+      RexNode operand;
+      if (getValidator().nameMatcher().matches(leftVarField, field.getName())) {
+        // The field is the let a.xx
+        boolean isGlobal = letStatement.isGlobal();
+        // cast right expression to field type.
+        operand = getRexBuilder().makeCast(field.getType(), rightExpression);
+        variableInfo = new VariableInfo(isGlobal, field.getName());
+      } else {
+        RexInputRef labelRef = getRexBuilder().makeInputRef(leftRex.getType(), leftRex.getIndex());
+        operand = getRexBuilder().makeFieldAccess(labelRef, field.getIndex());
+        variableInfo = new VariableInfo(false, field.getName());
+      }
+      operands.add(operand);
+      rex2VariableInfo.put(operand, variableInfo);
+      fieldNameNodes[c++] = new SqlIdentifier(field.getName(), SqlParserPos.ZERO);
+    }
+    // Construct RexObjectConstruct for dynamic field append expression.
+    RexObjectConstruct rightRex =
+        new RexObjectConstruct(leftRex.getType(), operands, rex2VariableInfo);
+    PathModifyExpression modifyExpression = new PathModifyExpression(leftRex, rightRex);
+    GraphRecordType modifyGraphType = getValidator().getModifyGraphType(letStatement);
+
+    RelNode input = bb.root;
+    assert input instanceof GraphMatch;
+    GraphMatch graphMatch = (GraphMatch) input;
+    MatchPathModify newPathPattern =
+        MatchPathModify.create(
+            graphMatch.getPathPattern(),
+            Collections.singletonList(modifyExpression),
+            letType,
+            modifyGraphType);
+    return graphMatch.copy(newPathPattern);
+  }
+
+  private PathInputRef convertLeftLabel(String leftLabel, PathRecordType letType) {
+    RelDataTypeField labelField = letType.getField(leftLabel, isCaseSensitive(), false);
+    return new PathInputRef(leftLabel, labelField.getIndex(), labelField.getType());
+  }
+
+  private void convertGQLFrom(Blackboard bb, SqlNode from, Blackboard withBb) {
+    RelNode node;
+    switch (from.getKind()) {
+      case GQL_MATCH_PATTERN:
+        node = convertGQLMatchPattern((SqlMatchPattern) from, false, withBb);
+        bb.setRoot(node, true);
+        break;
+      case GQL_RETURN:
+        node = convertGQLReturn((SqlReturnStatement) from, false, withBb);
+        bb.setRoot(node, true);
+        break;
+      case GQL_FILTER:
+        node = convertGQLFilter((SqlFilterStatement) from, false, withBb);
+        bb.setRoot(node, true);
+        break;
+      case IDENTIFIER:
+        SqlIdentifier identifier = (SqlIdentifier) from;
+        SqlIdentifier completeIdentifier =
+            getValidator().getGQLContext().completeCatalogObjName(identifier);
+        RelOptTable table = catalogReader.getTable(completeIdentifier.names);
+        node = LogicalGraphScan.create(getCluster(), table);
+        bb.setRoot(node, true);
+        break;
+      case GQL_LET:
+        node = convertGQLLet((SqlLetStatement) from, false, withBb);
+        bb.setRoot(node, true);
+        break;
+      case GQL_ALGORITHM:
+        node = convertGQLAlgorithm((SqlGraphAlgorithmCall) from, false, withBb);
+        bb.setRoot(node, true);
+        break;
+      default:
+        throw new IllegalArgumentException("Illegal match from sql node: " + from.getKind());
+    }
+    if (withBb != null) {
+      bb.addInput(withBb.root);
+    }
+  }
+
+  @Override
+  protected RelNode convertInsert(SqlInsert call) {
+    RelOptTable targetTable = getTargetTable(call);
+    Table table = targetTable.unwrap(Table.class);
+    final RelDataType targetRowType = validator.getValidatedNodeType(call);
+    assert targetRowType != null;
+
+    RelNode sourceRel = convertQueryRecursive(call.getSource(), false, targetRowType).project();
+
+    if (table instanceof GraphElementTable) {
+      List targetColumns;
+      if (call.getTargetColumnList() != null) {
+        targetColumns =
+            call.getTargetColumnList().getList().stream()
+                .map(id -> ((SqlIdentifier) id).getSimple())
+                .collect(Collectors.toList());
+      } else {
+        targetColumns = new ArrayList<>();
+        for (RelDataTypeField field : targetRowType.getFieldList()) {
+          targetColumns.add(field.getName());
+        }
+      }
+      GraphElementTable graphElementTable = (GraphElementTable) table;
+      GeaFlowGraph graph = graphElementTable.getGraph();
+      RelDataType tableRowType = table.getRowType(validator.getTypeFactory());
+      int[] targetColumnIndices = new int[tableRowType.getFieldCount()];
+      for (int c = 0; c < tableRowType.getFieldList().size(); c++) {
+        RelDataTypeField field = tableRowType.getFieldList().get(c);
         int i;
-        for (i = 0; i < concatPatterns.size(); i++) {
-            IMatchNode pathPattern2 = concatPatterns.get(i);
-            if (pathPattern2 instanceof SingleMatchNode) {
-                String latestLabel2 = GQLRelUtil.getLatestMatchNode((SingleMatchNode) pathPattern2).getLabel();
-                // pathPattern2 is "(a) - (b) - (c)", where singlePathPattern is "(c) - (d)"
-                // Then concat pathPattern2 with singlePathPattern and return "(a) - (b) - (c) - (d)"
-                if (getValidator().nameMatcher().matches(firstLabel, latestLabel2)) {
-                    IMatchNode concatPattern = GQLRelUtil.concatPathPattern((SingleMatchNode) pathPattern2,
-                        singlePathPattern, isCaseSensitive());
-                    concatPatterns.set(i, concatPattern);
-
-                    PathRecordType pathType2 = concatPathTypes.get(i);
-                    PathRecordType concatPathType = pathType2.concat(pathType, isCaseSensitive());
-                    concatPathTypes.set(i, concatPathType);
-                    break;
-                } else {
-                    String firstLabel2 = GQLRelUtil.getFirstMatchNode((SingleMatchNode) pathPattern2).getLabel();
-                    // singlePathPattern is "(a) - (b) - (c)", where pathPattern2 is "(c) - (d)"
-                    // Then concat singlePathPattern with pathPattern2 and return "(a) - (b) - (c) - (d)"
-                    if (getValidator().nameMatcher().matches(firstLabel2, latestLabel)) {
-                        IMatchNode concatPattern = GQLRelUtil.concatPathPattern(singlePathPattern,
-                            (SingleMatchNode) pathPattern2, isCaseSensitive());
-                        concatPatterns.set(i, concatPattern);
-
-                        PathRecordType pathType2 = concatPathTypes.get(i);
-                        PathRecordType concatPathType = pathType.concat(pathType2, isCaseSensitive());
-                        concatPathTypes.set(i, concatPathType);
-                        break;
-                    }
-                }
-            }
-        }
-        // If merge not happen, just add it.
-        if (i == concatPatterns.size()) {
-            concatPatterns.add(pathPattern);
-            concatPathTypes.add(pathType);
-        }
+        for (i = 0; i < targetColumns.size(); i++) {
+          if (getValidator().nameMatcher().matches(targetColumns.get(i), field.getName())) {
+            break;
+          }
+        }
+        if (i < targetColumns.size()) {
+          targetColumnIndices[c] = i;
+        } else { // -1 means the meta field
+          targetColumnIndices[c] = -1;
+        }
+      }
+      RexObjectConstruct objConstruct =
+          createObjectConstruct(tableRowType, sourceRel, graphElementTable, targetColumnIndices);
+      return createGraphModify(
+          graph,
+          new String[] {graphElementTable.getTypeName()},
+          new RexNode[] {objConstruct},
+          sourceRel);
+    } else if (table instanceof GeaFlowGraph) {
+      GeaFlowGraph graph = (GeaFlowGraph) table;
+      Map> vertexEdgeType2RefFields = new HashMap<>();
+      Map vertexEdgeType2ExpIndices = new HashMap<>();
+
+      SqlNodeList targetColumns = call.getTargetColumnList();
+      assert targetColumns != null && targetColumns.size() > 0;
+
+      for (SqlNode targetColumn : targetColumns) {
+        List names = ((SqlIdentifier) targetColumn).names;
+        assert names.size() == 2;
+        String vertexEdgeTypeName = names.get(0);
+        String fieldName = names.get(1);
+        vertexEdgeType2RefFields
+            .computeIfAbsent(vertexEdgeTypeName, k -> new ArrayList<>())
+            .add(fieldName);
+      }
+      for (int c = 0; c < targetColumns.size(); c++) {
+        SqlIdentifier targetColumn = (SqlIdentifier) targetColumns.get(c);
+        List names = targetColumn.names;
+        String vertexEdgeTypeName = names.get(0);
+        GraphElementTable vertexEdgeTable = graph.getTable(vertexEdgeTypeName);
+        RelDataType tableRowType = vertexEdgeTable.getRowType(validator.getTypeFactory());
+        int[] targetColumnExpIndices =
+            vertexEdgeType2ExpIndices.computeIfAbsent(
+                vertexEdgeTypeName,
+                k -> {
+                  int[] indices = new int[tableRowType.getFieldCount()];
+                  Arrays.fill(indices, -1);
+                  return indices;
+                });
+        String fieldName = names.get(1);
+        int fieldIndex = tableRowType.getFieldNames().indexOf(fieldName);
+        targetColumnExpIndices[fieldIndex] = c;
+      }
+
+      String[] typeNames = new String[vertexEdgeType2RefFields.size()];
+      RexNode[] constructs = new RexNode[typeNames.length];
+      int c = 0;
+      RelDataType graphType = graph.getRowType(validator.getTypeFactory());
+
+      for (int i = 0; i < graphType.getFieldCount(); i++) {
+        String typeName = graphType.getFieldNames().get(i);
+        GraphElementTable vertexEdgeTable = graph.getTable(typeName);
+        RelDataType tableRowType = vertexEdgeTable.getRowType(validator.getTypeFactory());
+
+        int[] targetColumnExpIndices = vertexEdgeType2ExpIndices.get(typeName);
+        if (targetColumnExpIndices != null) {
+          RexObjectConstruct objConstruct =
+              createObjectConstruct(
+                  tableRowType, sourceRel, vertexEdgeTable, targetColumnExpIndices);
+          typeNames[c] = typeName;
+          constructs[c] = objConstruct;
+          c++;
+        }
+      }
+      return createGraphModify(graph, typeNames, constructs, sourceRel);
     }
-
-    private SingleMatchNode convertMatchNodeWhere(SqlMatchNode matchNode, SqlNode matchWhere,
-                                                  IMatchNode input, Blackboard withBb) {
-        assert input != null;
-        SqlValidatorScope whereScope = getValidator().getScopes(matchWhere);
-        Blackboard nodeBb = createBlackboard(whereScope, null, false).setWithBb(withBb);
-        nodeBb.setRoot(new WhereMatchNode(input), true);
-        if (withBb != null) {
-            nodeBb.addInput(withBb.root);
-        }
-        replaceSubQueries(nodeBb, matchWhere, RelOptUtil.Logic.UNKNOWN_AS_FALSE);
-        RexNode condition = nodeBb.convertExpression(matchWhere);
-
-        PathRecordType pathRecordType = input.getPathSchema();
-        RelDataTypeField pathField = pathRecordType.getField(matchNode.getName(), isCaseSensitive(), false);
-        condition = GQLRexUtil.toPathInputRefForWhere(pathField, condition);
-        return MatchFilter.create(input, condition, input.getPathSchema());
+    return super.convertInsert(call);
+  }
+
+  private RexObjectConstruct createObjectConstruct(
+      RelDataType tableRowType,
+      RelNode sourceRel,
+      GraphElementTable table,
+      int[] targetColumnExpIndices) {
+    List columnExpressions = new ArrayList<>();
+    List fields = tableRowType.getFieldList();
+
+    for (int i = 0; i < fields.size(); i++) {
+      RelDataTypeField field = fields.get(i);
+      int fieldExpIndex = targetColumnExpIndices[i];
+      if (fieldExpIndex != -1) {
+        RelDataType sourceFieldType =
+            sourceRel.getRowType().getFieldList().get(fieldExpIndex).getType();
+        RelDataType targetFieldType = field.getType();
+        RexNode inputRef =
+            getRexBuilder()
+                .makeCast(
+                    targetFieldType, getRexBuilder().makeInputRef(sourceFieldType, fieldExpIndex));
+        columnExpressions.add(inputRef);
+      } else if (field.getType() instanceof MetaFieldType
+          && ((MetaFieldType) field.getType()).getMetaField() == MetaField.VERTEX_TYPE) {
+        RexLiteral vertexLabel = getRexBuilder().makeLiteral(table.getTypeName());
+        columnExpressions.add(getRexBuilder().makeCast(field.getType(), vertexLabel));
+      } else if (field.getType() instanceof MetaFieldType
+          && ((MetaFieldType) field.getType()).getMetaField() == MetaField.EDGE_TYPE) {
+        RexLiteral edgeLabel = getRexBuilder().makeLiteral(table.getTypeName());
+        columnExpressions.add(getRexBuilder().makeCast(field.getType(), edgeLabel));
+      } else {
+        RexLiteral nullLiteral = getRexBuilder().makeNullLiteral(field.getType());
+        columnExpressions.add(nullLiteral);
+      }
     }
-
-    private static class WhereMatchNode extends AbstractRelNode {
-
-        public WhereMatchNode(IMatchNode matchNode) {
-            super(matchNode.getCluster(), matchNode.getTraitSet());
-            this.rowType = matchNode.getNodeType();
-        }
+    Map rex2VarInfo = new HashMap<>();
+    for (int i = 0; i < fields.size(); i++) {
+      VariableInfo variableInfo = new VariableInfo(false, fields.get(i).getName());
+      rex2VarInfo.put(columnExpressions.get(i), variableInfo);
+    }
+    return new RexObjectConstruct(tableRowType, columnExpressions, rex2VarInfo);
+  }
+
+  private GraphModify createGraphModify(
+      GeaFlowGraph graph, String[] typeNames, RexNode[] constructs, RelNode sourceRel) {
+    GraphRecordType graphType = (GraphRecordType) graph.getRowType(validator.getTypeFactory());
+    List projects = new ArrayList<>();
+    for (RelDataTypeField field : graphType.getFieldList()) {
+      int i;
+      for (i = 0; i < typeNames.length; i++) {
+        if (getValidator().nameMatcher().matches(field.getName(), typeNames[i])) {
+          break;
+        }
+      }
+      if (i == typeNames.length) {
+        projects.add(getRexBuilder().makeNullLiteral(field.getType()));
+      } else {
+        projects.add(constructs[i]);
+      }
+    }
+    LogicalProject project = LogicalProject.create(sourceRel, projects, graphType);
+    return LogicalGraphModify.create(project.getCluster(), graph, project);
+  }
+
+  @Override
+  protected void convertFrom(Blackboard bb, SqlNode from) {
+    RelNode relNode;
+    if (from == null) {
+      super.convertFrom(bb, null);
+      return;
     }
+    switch (from.getKind()) {
+      case GQL_RETURN:
+      case GQL_FILTER:
+        relNode = convertQueryRecursive(from, false, null).rel;
+        bb.setRoot(relNode, false);
+        break;
+      case GQL_ALGORITHM:
+      case GQL_MATCH_PATTERN:
+      case GQL_LET:
+      case WITH:
+        relNode = convertQueryRecursive(from, false, null).rel;
+        bb.setRoot(relNode, true);
+        break;
+      default:
+        super.convertFrom(bb, from);
+    }
+  }
 
-    private IMatchNode convertPathPattern(SqlNode sqlNode, Blackboard withBb) {
-        if (sqlNode instanceof SqlUnionPathPattern) {
-            SqlUnionPathPattern unionPathPattern = (SqlUnionPathPattern) sqlNode;
-            IMatchNode left = convertPathPattern(unionPathPattern.getLeft(), withBb);
-            IMatchNode right = convertPathPattern(unionPathPattern.getRight(), withBb);
-            return MatchUnion.create(getCluster(), getCluster().traitSet(),
-                Lists.newArrayList(left, right), unionPathPattern.isUnionAll());
-        }
-        SqlPathPattern pathPattern = (SqlPathPattern) sqlNode;
-        SingleMatchNode relPathPattern = null;
-        SqlMatchNode preMatchNode = null;
-        for (SqlNode pathNode : pathPattern.getPathNodes()) {
-            PathRecordType pathType = (PathRecordType) getValidator().getValidatedNodeType(pathNode);
-            switch (pathNode.getKind()) {
-                case GQL_MATCH_NODE:
-                    SqlMatchNode matchNode = (SqlMatchNode) pathNode;
-                    RelDataType nodeType = getValidator().getMatchNodeType(matchNode);
-
-                    SingleMatchNode vertexMatch = VertexMatch.create(getCluster(), relPathPattern,
-                        matchNode.getName(), matchNode.getLabelNames(), nodeType, pathType);
-                    if (matchNode.getWhere() != null) {
-                        vertexMatch = convertMatchNodeWhere(matchNode, matchNode.getWhere(),
-                            vertexMatch, withBb);
-                    }
-                    // generate for regex match
-                    if (preMatchNode instanceof SqlMatchEdge && ((SqlMatchEdge) preMatchNode).isRegexMatch()) {
-                        SqlMatchEdge inputEdgeNode = (SqlMatchEdge) preMatchNode;
-                        EdgeMatch inputEdgeMatch = (EdgeMatch) vertexMatch.find(node -> node instanceof EdgeMatch);
-                        relPathPattern = convertRegexMatch(inputEdgeNode, inputEdgeMatch, vertexMatch);
-                    } else {
-                        relPathPattern = vertexMatch;
-                    }
-                    if (getValidator().getStartCycleMatchNode((SqlMatchNode) pathNode) != null) {
-                        relPathPattern = translateCycleMatchNode(relPathPattern, pathNode);
-                    }
-                    break;
-                case GQL_MATCH_EDGE:
-                    SqlMatchEdge matchEdge = (SqlMatchEdge) pathNode;
-                    RelDataType edgeType = getValidator().getMatchNodeType(matchEdge);
-
-                    EdgeMatch edgeMatch = EdgeMatch.create(
-                        getCluster(), relPathPattern,
-                        matchEdge.getName(), matchEdge.getLabelNames(),
-                        matchEdge.getDirection(), edgeType, pathType);
-
-                    if (matchEdge.getWhere() != null) {
-                        relPathPattern = convertMatchNodeWhere(matchEdge, matchEdge.getWhere(),
-                            edgeMatch, withBb);
-                    } else {
-                        relPathPattern = edgeMatch;
-                    }
-                    break;
-                default:
-                    throw new IllegalArgumentException("Illegal path node kind: " + pathNode.getKind());
-            }
-            preMatchNode = (SqlMatchNode) pathNode;
+  private String deriveAlias(SqlNode node, Collection aliases, int ordinal) {
+    String alias = this.validator.deriveAlias(node, ordinal);
+    if (alias == null || aliases.contains(alias)) {
+      String aliasBase = alias == null ? "EXPR$" : alias;
+      int j = 0;
+
+      while (true) {
+        alias = aliasBase + j;
+        if (!aliases.contains(alias)) {
+          break;
         }
-        return relPathPattern;
-    }
-
-    private SingleMatchNode translateCycleMatchNode(SingleMatchNode relPathPattern, SqlNode pathNode) {
-        PathRecordType pathRecordType = relPathPattern.getPathSchema();
-        int rightIndex = pathRecordType.getField(((SqlMatchNode) pathNode).getName(),
-            isCaseSensitive(), false).getIndex();
-        int leftIndex = pathRecordType.getField(getValidator()
-                .getStartCycleMatchNode((SqlMatchNode) pathNode).getName(),
-            isCaseSensitive(), false).getIndex();
-        assert leftIndex >= 0 && rightIndex >= 0;
-        RexNode condition = getRexBuilder().makeCall(SqlStdOperatorTable.EQUALS,
-            getRexBuilder().makeInputRef(pathRecordType, leftIndex),
-            getRexBuilder().makeInputRef(pathRecordType, rightIndex)
-        );
-        return MatchFilter.create(relPathPattern, condition,
-            relPathPattern.getPathSchema());
-    }
-
-    private SingleMatchNode convertRegexMatch(SqlMatchEdge regexEdge, EdgeMatch regexEdgeMatch,
-                                              SingleMatchNode vertexMatch) {
-        IMatchNode loopStart = (IMatchNode) regexEdgeMatch.getInput();
-        SubQueryStart queryStart = SubQueryStart.create(getCluster(),
-            loopStart.getTraitSet(), generateSubQueryName(), loopStart.getPathSchema(),
-            (VertexRecordType) loopStart.getNodeType());
-        // replace the input of regexEdgeMatch to queryStart and clone vertexMatch
-        SingleMatchNode loopBody = GQLRelUtil.replaceInput(vertexMatch, regexEdgeMatch, queryStart);
 
-        RexNode utilCondition = getRexBuilder().makeLiteral(true);
-        return LoopUntilMatch.create(getCluster(), vertexMatch.getTraitSet(), loopStart,
-            loopBody, utilCondition, regexEdge.getMinHop(), regexEdge.getMaxHop(),
-            vertexMatch.getPathSchema());
+        ++j;
+      }
     }
 
-    private RelNode convertGQLAlgorithm(SqlGraphAlgorithmCall algorithmCall, boolean top,
-                                        Blackboard withBb) {
-
-        SqlValidatorScope scope = getValidator().getScopes(algorithmCall);
-        Blackboard bb = createBlackboard(scope, null, top).setWithBb(withBb);
-        convertGQLFrom(bb, algorithmCall.getFrom(), withBb);
-
-        Object[] params =
-            algorithmCall.getParameters() == null ? new Object[0] : algorithmCall.getParameters().getList()
-                .stream().map(sqlNode -> GQLRexUtil.getLiteralValue(bb.convertExpression(sqlNode)))
-                .toArray();
-        return LogicalGraphAlgorithm.create(cluster, getCluster().traitSet(),
-            bb.root,
-            ((GeaFlowUserDefinedGraphAlgorithm) algorithmCall.getOperator()).getImplementClass(),
-            params);
-    }
-
-    private RelNode convertGQLLet(SqlLetStatement letStatement, boolean top, Blackboard withBb) {
-        SqlValidatorScope scope = getValidator().getScopes(letStatement);
-        Blackboard bb = createBlackboard(scope, null, top).setWithBb(withBb);
-        convertGQLFrom(bb, letStatement.getFrom(), withBb);
-
-        RexNode rightExpression = bb.convertExpression(letStatement.getExpression());
-
-        PathRecordType letType = (PathRecordType) getValidator().getValidatedNodeType(letStatement);
-        PathInputRef leftRex = convertLeftLabel(letStatement.getLeftLabel(), letType);
-
-        String leftVarField = letStatement.getLeftField();
-        List operands = new ArrayList<>();
-        SqlIdentifier[] fieldNameNodes = new SqlIdentifier[leftRex.getType().getFieldCount()];
-        Map rex2VariableInfo = new HashMap<>();
-        int c = 0;
-        for (RelDataTypeField field : leftRex.getType().getFieldList()) {
-            VariableInfo variableInfo;
-            RexNode operand;
-            if (getValidator().nameMatcher().matches(leftVarField, field.getName())) {
-                // The field is the let a.xx
-                boolean isGlobal = letStatement.isGlobal();
-                // cast right expression to field type.
-                operand = getRexBuilder().makeCast(field.getType(), rightExpression);
-                variableInfo = new VariableInfo(isGlobal, field.getName());
-            } else {
-                RexInputRef labelRef = getRexBuilder().makeInputRef(leftRex.getType(), leftRex.getIndex());
-                operand = getRexBuilder().makeFieldAccess(labelRef, field.getIndex());
-                variableInfo = new VariableInfo(false, field.getName());
-            }
-            operands.add(operand);
-            rex2VariableInfo.put(operand, variableInfo);
-            fieldNameNodes[c++] = new SqlIdentifier(field.getName(), SqlParserPos.ZERO);
-        }
-        // Construct RexObjectConstruct for dynamic field append expression.
-        RexObjectConstruct rightRex = new RexObjectConstruct(leftRex.getType(), operands, rex2VariableInfo);
-        PathModifyExpression modifyExpression = new PathModifyExpression(leftRex, rightRex);
-        GraphRecordType modifyGraphType = getValidator().getModifyGraphType(letStatement);
-
-        RelNode input = bb.root;
-        assert input instanceof GraphMatch;
-        GraphMatch graphMatch = (GraphMatch) input;
-        MatchPathModify newPathPattern = MatchPathModify.create(graphMatch.getPathPattern(),
-            Collections.singletonList(modifyExpression), letType, modifyGraphType);
-        return graphMatch.copy(newPathPattern);
-    }
-
-    private PathInputRef convertLeftLabel(String leftLabel, PathRecordType letType) {
-        RelDataTypeField labelField = letType.getField(leftLabel, isCaseSensitive(), false);
-        return new PathInputRef(leftLabel, labelField.getIndex(), labelField.getType());
-    }
-
-    private void convertGQLFrom(Blackboard bb, SqlNode from, Blackboard withBb) {
-        RelNode node;
-        switch (from.getKind()) {
-            case GQL_MATCH_PATTERN:
-                node = convertGQLMatchPattern((SqlMatchPattern) from, false, withBb);
-                bb.setRoot(node, true);
-                break;
-            case GQL_RETURN:
-                node = convertGQLReturn((SqlReturnStatement) from, false, withBb);
-                bb.setRoot(node, true);
-                break;
-            case GQL_FILTER:
-                node = convertGQLFilter((SqlFilterStatement) from, false, withBb);
-                bb.setRoot(node, true);
-                break;
-            case IDENTIFIER:
-                SqlIdentifier identifier = (SqlIdentifier) from;
-                SqlIdentifier completeIdentifier = getValidator().getGQLContext()
-                    .completeCatalogObjName(identifier);
-                RelOptTable table = catalogReader.getTable(completeIdentifier.names);
-                node = LogicalGraphScan.create(getCluster(), table);
-                bb.setRoot(node, true);
-                break;
-            case GQL_LET:
-                node = convertGQLLet((SqlLetStatement) from, false, withBb);
-                bb.setRoot(node, true);
-                break;
-            case GQL_ALGORITHM:
-                node = convertGQLAlgorithm((SqlGraphAlgorithmCall) from, false, withBb);
-                bb.setRoot(node, true);
-                break;
-            default:
-                throw new IllegalArgumentException("Illegal match from sql node: " + from.getKind());
-        }
-        if (withBb != null) {
-            bb.addInput(withBb.root);
-        }
+    aliases.add(alias);
+    return alias;
+  }
+
+  private static boolean desc(Direction direction) {
+    switch (direction) {
+      case DESCENDING:
+      case STRICTLY_DESCENDING:
+        return true;
+      default:
+        return false;
+    }
+  }
+
+  protected RelFieldCollation convertOrderItem(
+      SqlReturnStatement returnStmt,
+      SqlNode orderItem,
+      List extraExprs,
+      Direction direction,
+      NullDirection nullDirection) {
+    assert returnStmt != null;
+
+    switch (orderItem.getKind()) {
+      case DESCENDING:
+        return this.convertOrderItem(
+            returnStmt,
+            ((SqlCall) orderItem).operand(0),
+            extraExprs,
+            Direction.DESCENDING,
+            nullDirection);
+      case NULLS_FIRST:
+        return this.convertOrderItem(
+            returnStmt,
+            ((SqlCall) orderItem).operand(0),
+            extraExprs,
+            direction,
+            NullDirection.FIRST);
+      case NULLS_LAST:
+        return this.convertOrderItem(
+            returnStmt,
+            ((SqlCall) orderItem).operand(0),
+            extraExprs,
+            direction,
+            NullDirection.LAST);
+      default:
+        SqlValidatorScope orderScope = getValidator().getScopes(returnStmt.getOrderBy());
+        SqlNode converted =
+            ((GQLValidatorImpl) validator)
+                .expandReturnGroupOrderExpr(returnStmt, orderScope, orderItem);
+        if (nullDirection == UNSPECIFIED) {
+          nullDirection =
+              this.validator.getDefaultNullCollation().last(desc(direction))
+                  ? NullDirection.LAST
+                  : NullDirection.FIRST;
+        }
+        GQLReturnScope returnScope = (GQLReturnScope) getValidator().getScopes(returnStmt);
+        int ordinal = -1;
+        Iterator returnListItr = returnScope.getExpandedReturnList().iterator();
+
+        SqlNode extraExpr;
+        do {
+          if (!returnListItr.hasNext()) {
+            returnListItr = extraExprs.iterator();
+
+            do {
+              if (!returnListItr.hasNext()) {
+                extraExprs.add(converted);
+                return new RelFieldCollation(ordinal + 1, direction, nullDirection);
+              }
+
+              extraExpr = returnListItr.next();
+              ++ordinal;
+            } while (!converted.equalsDeep(extraExpr, Litmus.IGNORE));
+
+            return new RelFieldCollation(ordinal, direction, nullDirection);
+          }
+
+          extraExpr = returnListItr.next();
+          ++ordinal;
+        } while (!converted.equalsDeep(SqlUtil.stripAs(extraExpr), Litmus.IGNORE));
+
+        return new RelFieldCollation(ordinal, direction, nullDirection);
     }
+  }
 
-    @Override
-    protected RelNode convertInsert(SqlInsert call) {
-        RelOptTable targetTable = getTargetTable(call);
-        Table table = targetTable.unwrap(Table.class);
-        final RelDataType targetRowType =
-            validator.getValidatedNodeType(call);
-        assert targetRowType != null;
-
-        RelNode sourceRel = convertQueryRecursive(call.getSource(), false, targetRowType).project();
-
-        if (table instanceof GraphElementTable) {
-            List targetColumns;
-            if (call.getTargetColumnList() != null) {
-                targetColumns = call.getTargetColumnList().getList()
-                    .stream().map(id -> ((SqlIdentifier) id).getSimple())
-                    .collect(Collectors.toList());
-            } else {
-                targetColumns = new ArrayList<>();
-                for (RelDataTypeField field : targetRowType.getFieldList()) {
-                    targetColumns.add(field.getName());
-                }
-            }
-            GraphElementTable graphElementTable = (GraphElementTable) table;
-            GeaFlowGraph graph = graphElementTable.getGraph();
-            RelDataType tableRowType = table.getRowType(validator.getTypeFactory());
-            int[] targetColumnIndices = new int[tableRowType.getFieldCount()];
-            for (int c = 0; c < tableRowType.getFieldList().size(); c++) {
-                RelDataTypeField field = tableRowType.getFieldList().get(c);
-                int i;
-                for (i = 0; i < targetColumns.size(); i++) {
-                    if (getValidator().nameMatcher().matches(targetColumns.get(i), field.getName())) {
-                        break;
-                    }
-                }
-                if (i < targetColumns.size()) {
-                    targetColumnIndices[c] = i;
-                } else { // -1 means the meta field
-                    targetColumnIndices[c] = -1;
-                }
-            }
-            RexObjectConstruct objConstruct = createObjectConstruct(tableRowType, sourceRel,
-                graphElementTable, targetColumnIndices);
-            return createGraphModify(graph, new String[]{graphElementTable.getTypeName()},
-                new RexNode[]{objConstruct}, sourceRel);
-        } else if (table instanceof GeaFlowGraph) {
-            GeaFlowGraph graph = (GeaFlowGraph) table;
-            Map> vertexEdgeType2RefFields = new HashMap<>();
-            Map vertexEdgeType2ExpIndices = new HashMap<>();
-
-            SqlNodeList targetColumns = call.getTargetColumnList();
-            assert targetColumns != null && targetColumns.size() > 0;
-
-            for (SqlNode targetColumn : targetColumns) {
-                List names = ((SqlIdentifier) targetColumn).names;
-                assert names.size() == 2;
-                String vertexEdgeTypeName = names.get(0);
-                String fieldName = names.get(1);
-                vertexEdgeType2RefFields.computeIfAbsent(vertexEdgeTypeName, k -> new ArrayList<>()).add(fieldName);
-            }
-            for (int c = 0; c < targetColumns.size(); c++) {
-                SqlIdentifier targetColumn = (SqlIdentifier) targetColumns.get(c);
-                List names = targetColumn.names;
-                String vertexEdgeTypeName = names.get(0);
-                GraphElementTable vertexEdgeTable = graph.getTable(vertexEdgeTypeName);
-                RelDataType tableRowType = vertexEdgeTable.getRowType(validator.getTypeFactory());
-                int[] targetColumnExpIndices = vertexEdgeType2ExpIndices.computeIfAbsent(vertexEdgeTypeName,
-                    k -> {
-                        int[] indices = new int[tableRowType.getFieldCount()];
-                        Arrays.fill(indices, -1);
-                        return indices;
-                    });
-                String fieldName = names.get(1);
-                int fieldIndex = tableRowType.getFieldNames().indexOf(fieldName);
-                targetColumnExpIndices[fieldIndex] = c;
-            }
+  private static class ReturnAggregateFinder extends SqlBasicVisitor {
 
-            String[] typeNames = new String[vertexEdgeType2RefFields.size()];
-            RexNode[] constructs = new RexNode[typeNames.length];
-            int c = 0;
-            RelDataType graphType = graph.getRowType(validator.getTypeFactory());
-
-            for (int i = 0; i < graphType.getFieldCount(); i++) {
-                String typeName = graphType.getFieldNames().get(i);
-                GraphElementTable vertexEdgeTable = graph.getTable(typeName);
-                RelDataType tableRowType = vertexEdgeTable.getRowType(validator.getTypeFactory());
-
-                int[] targetColumnExpIndices = vertexEdgeType2ExpIndices.get(typeName);
-                if (targetColumnExpIndices != null) {
-                    RexObjectConstruct objConstruct = createObjectConstruct(tableRowType, sourceRel,
-                        vertexEdgeTable, targetColumnExpIndices);
-                    typeNames[c] = typeName;
-                    constructs[c] = objConstruct;
-                    c++;
-                }
-            }
-            return createGraphModify(graph, typeNames, constructs, sourceRel);
-        }
-        return super.convertInsert(call);
-    }
-
-    private RexObjectConstruct createObjectConstruct(RelDataType tableRowType, RelNode sourceRel,
-                                                     GraphElementTable table, int[] targetColumnExpIndices) {
-        List columnExpressions = new ArrayList<>();
-        List fields = tableRowType.getFieldList();
-
-        for (int i = 0; i < fields.size(); i++) {
-            RelDataTypeField field = fields.get(i);
-            int fieldExpIndex = targetColumnExpIndices[i];
-            if (fieldExpIndex != -1) {
-                RelDataType sourceFieldType = sourceRel.getRowType().getFieldList().get(fieldExpIndex).getType();
-                RelDataType targetFieldType = field.getType();
-                RexNode inputRef = getRexBuilder().makeCast(targetFieldType,
-                    getRexBuilder().makeInputRef(sourceFieldType, fieldExpIndex));
-                columnExpressions.add(inputRef);
-            } else if (field.getType() instanceof MetaFieldType
-                && ((MetaFieldType) field.getType()).getMetaField() == MetaField.VERTEX_TYPE) {
-                RexLiteral vertexLabel = getRexBuilder().makeLiteral(table.getTypeName());
-                columnExpressions.add(getRexBuilder().makeCast(field.getType(), vertexLabel));
-            } else if (field.getType() instanceof MetaFieldType
-                && ((MetaFieldType) field.getType()).getMetaField() == MetaField.EDGE_TYPE) {
-                RexLiteral edgeLabel = getRexBuilder().makeLiteral(table.getTypeName());
-                columnExpressions.add(getRexBuilder().makeCast(field.getType(), edgeLabel));
-            } else {
-                RexLiteral nullLiteral = getRexBuilder().makeNullLiteral(field.getType());
-                columnExpressions.add(nullLiteral);
-            }
-        }
-        Map rex2VarInfo = new HashMap<>();
-        for (int i = 0; i < fields.size(); i++) {
-            VariableInfo variableInfo = new VariableInfo(false, fields.get(i).getName());
-            rex2VarInfo.put(columnExpressions.get(i), variableInfo);
-        }
-        return new RexObjectConstruct(tableRowType, columnExpressions, rex2VarInfo);
-    }
-
-    private GraphModify createGraphModify(GeaFlowGraph graph, String[] typeNames,
-                                          RexNode[] constructs, RelNode sourceRel) {
-        GraphRecordType graphType = (GraphRecordType) graph.getRowType(validator.getTypeFactory());
-        List projects = new ArrayList<>();
-        for (RelDataTypeField field : graphType.getFieldList()) {
-            int i;
-            for (i = 0; i < typeNames.length; i++) {
-                if (getValidator().nameMatcher().matches(field.getName(), typeNames[i])) {
-                    break;
-                }
-            }
-            if (i == typeNames.length) {
-                projects.add(getRexBuilder().makeNullLiteral(field.getType()));
-            } else {
-                projects.add(constructs[i]);
-            }
-        }
-        LogicalProject project = LogicalProject.create(sourceRel, projects, graphType);
-        return LogicalGraphModify.create(project.getCluster(), graph, project);
-    }
+    final SqlNodeList list = new SqlNodeList(SqlParserPos.ZERO);
 
     @Override
-    protected void convertFrom(
-        Blackboard bb,
-        SqlNode from) {
-        RelNode relNode;
-        if (from == null) {
-            super.convertFrom(bb, null);
-            return;
-        }
-        switch (from.getKind()) {
-            case GQL_RETURN:
-            case GQL_FILTER:
-                relNode = convertQueryRecursive(from, false, null).rel;
-                bb.setRoot(relNode, false);
-                break;
-            case GQL_ALGORITHM:
-            case GQL_MATCH_PATTERN:
-            case GQL_LET:
-            case WITH:
-                relNode = convertQueryRecursive(from, false, null).rel;
-                bb.setRoot(relNode, true);
-                break;
-            default:
-                super.convertFrom(bb, from);
-        }
+    public Void visit(SqlCall call) {
+      if (call.getOperator().isAggregator()) {
+        list.add(call);
+        return null;
+      }
+      return call.getOperator().acceptCall(this, call);
+    }
+  }
+
+  private void convertDistinct(Blackboard bb, boolean checkForDupExprs) {
+    // Look for duplicate expressions in the project.
+    // Say we have 'select x, y, x, z'.
+    // Then dups will be {[2, 0]}
+    // and oldToNew will be {[0, 0], [1, 1], [2, 0], [3, 2]}
+    RelNode rel = bb.root;
+    if (checkForDupExprs && (rel instanceof LogicalProject)) {
+      LogicalProject project = (LogicalProject) rel;
+      final List projectExprs = project.getProjects();
+      final List origins = new ArrayList<>();
+      int dupCount = 0;
+      for (int i = 0; i < projectExprs.size(); i++) {
+        int x = projectExprs.indexOf(projectExprs.get(i));
+        if (x >= 0 && x < i) {
+          origins.add(x);
+          ++dupCount;
+        } else {
+          origins.add(i);
+        }
+      }
+      if (dupCount == 0) {
+        convertDistinct(bb, false);
+        return;
+      }
+
+      final Map squished = new HashMap<>();
+      final List fields = rel.getRowType().getFieldList();
+      final List> newProjects = new ArrayList<>();
+      for (int i = 0; i < fields.size(); i++) {
+        if (origins.get(i) == i) {
+          squished.put(i, newProjects.size());
+          newProjects.add(RexInputRef.of2(i, fields));
+        }
+      }
+      rel = LogicalProject.create(rel, Pair.left(newProjects), Pair.right(newProjects));
+      bb.root = rel;
+      convertDistinct(bb, false);
+      rel = bb.root;
+
+      // Create the expressions to reverse the mapping.
+      // Project($0, $1, $0, $2).
+      final List> undoProjects = new ArrayList<>();
+      for (int i = 0; i < fields.size(); i++) {
+        final int origin = origins.get(i);
+        RelDataTypeField field = fields.get(i);
+        undoProjects.add(
+            Pair.of(new RexInputRef(squished.get(origin), field.getType()), field.getName()));
+      }
+
+      rel = LogicalProject.create(rel, Pair.left(undoProjects), Pair.right(undoProjects));
+      bb.setRoot(rel, false);
+      return;
     }
 
-    private String deriveAlias(SqlNode node, Collection aliases, int ordinal) {
-        String alias = this.validator.deriveAlias(node, ordinal);
-        if (alias == null || aliases.contains(alias)) {
-            String aliasBase = alias == null ? "EXPR$" : alias;
-            int j = 0;
+    // Usual case: all of the expressions in the SELECT clause are
+    // different.
+    final ImmutableBitSet groupSet = ImmutableBitSet.range(rel.getRowType().getFieldCount());
+    rel = createAggregate(bb, groupSet, ImmutableList.of(groupSet), ImmutableList.of());
 
-            while (true) {
-                alias = aliasBase + j;
-                if (!aliases.contains(alias)) {
-                    break;
-                }
+    bb.setRoot(rel, false);
+  }
 
-                ++j;
-            }
-        }
+  public class GQLAggConverter extends AggConverter implements SqlVisitor {
 
-        aliases.add(alias);
-        return alias;
-    }
+    private final Blackboard bb;
+    public final GQLReturnScope gqlReturnScope;
 
-    private static boolean desc(Direction direction) {
-        switch (direction) {
-            case DESCENDING:
-            case STRICTLY_DESCENDING:
-                return true;
-            default:
-                return false;
-        }
-    }
+    private final Map nameMap = new HashMap<>();
 
-    protected RelFieldCollation convertOrderItem(SqlReturnStatement returnStmt, SqlNode orderItem,
-                                                 List extraExprs, Direction direction,
-                                                 NullDirection nullDirection) {
-        assert returnStmt != null;
-
-        switch (orderItem.getKind()) {
-            case DESCENDING:
-                return this.convertOrderItem(returnStmt, ((SqlCall) orderItem).operand(0), extraExprs,
-                    Direction.DESCENDING, nullDirection);
-            case NULLS_FIRST:
-                return this.convertOrderItem(returnStmt, ((SqlCall) orderItem).operand(0), extraExprs, direction,
-                    NullDirection.FIRST);
-            case NULLS_LAST:
-                return this.convertOrderItem(returnStmt, ((SqlCall) orderItem).operand(0), extraExprs, direction,
-                    NullDirection.LAST);
-            default:
-                SqlValidatorScope orderScope = getValidator().getScopes(returnStmt.getOrderBy());
-                SqlNode converted = ((GQLValidatorImpl) validator).expandReturnGroupOrderExpr(returnStmt, orderScope,
-                    orderItem);
-                if (nullDirection == UNSPECIFIED) {
-                    nullDirection = this.validator.getDefaultNullCollation().last(desc(direction))
-                        ? NullDirection.LAST : NullDirection.FIRST;
-                }
-                GQLReturnScope returnScope =
-                    (GQLReturnScope) getValidator().getScopes(returnStmt);
-                int ordinal = -1;
-                Iterator returnListItr = returnScope.getExpandedReturnList().iterator();
-
-                SqlNode extraExpr;
-                do {
-                    if (!returnListItr.hasNext()) {
-                        returnListItr = extraExprs.iterator();
-
-                        do {
-                            if (!returnListItr.hasNext()) {
-                                extraExprs.add(converted);
-                                return new RelFieldCollation(ordinal + 1, direction, nullDirection);
-                            }
-
-                            extraExpr = returnListItr.next();
-                            ++ordinal;
-                        } while (!converted.equalsDeep(extraExpr, Litmus.IGNORE));
-
-                        return new RelFieldCollation(ordinal, direction, nullDirection);
-                    }
-
-                    extraExpr = returnListItr.next();
-                    ++ordinal;
-                } while (!converted.equalsDeep(SqlUtil.stripAs(extraExpr), Litmus.IGNORE));
-
-                return new RelFieldCollation(ordinal, direction, nullDirection);
-        }
-    }
+    /** The group-by expressions, in {@link SqlNode} format. */
+    private final SqlNodeList groupExprs = new SqlNodeList(SqlParserPos.ZERO);
 
-    private static class ReturnAggregateFinder extends SqlBasicVisitor {
+    /** The auxiliary group-by expressions. */
+    private final Map> auxiliaryGroupExprs = new HashMap<>();
 
-        final SqlNodeList list = new SqlNodeList(SqlParserPos.ZERO);
+    /**
+     * Input expressions for the group columns and aggregates, in {@link RexNode} format. The first
+     * elements of the list correspond to the elements in {@link #groupExprs}; the remaining
+     * elements are for aggregates. The right field of each pair is the name of the expression,
+     * where the expressions are simple mappings to input fields.
+     */
+    private final List> convertedInputExprs = new ArrayList<>();
 
-        @Override
-        public Void visit(SqlCall call) {
-            if (call.getOperator().isAggregator()) {
-                list.add(call);
-                return null;
-            }
-            return call.getOperator().acceptCall(this, call);
-        }
-    }
+    /**
+     * Expressions to be evaluated as rows are being placed into the aggregate's hash table. This is
+     * when group functions such as TUMBLE cause rows to be expanded.
+     */
+    private final List aggCalls = new ArrayList<>();
 
-    private void convertDistinct(Blackboard bb, boolean checkForDupExprs) {
-        // Look for duplicate expressions in the project.
-        // Say we have 'select x, y, x, z'.
-        // Then dups will be {[2, 0]}
-        // and oldToNew will be {[0, 0], [1, 1], [2, 0], [3, 2]}
-        RelNode rel = bb.root;
-        if (checkForDupExprs && (rel instanceof LogicalProject)) {
-            LogicalProject project = (LogicalProject) rel;
-            final List projectExprs = project.getProjects();
-            final List origins = new ArrayList<>();
-            int dupCount = 0;
-            for (int i = 0; i < projectExprs.size(); i++) {
-                int x = projectExprs.indexOf(projectExprs.get(i));
-                if (x >= 0 && x < i) {
-                    origins.add(x);
-                    ++dupCount;
-                } else {
-                    origins.add(i);
-                }
-            }
-            if (dupCount == 0) {
-                convertDistinct(bb, false);
-                return;
-            }
+    private final Map aggMapping = new HashMap<>();
+    private final Map aggCallMapping = new HashMap<>();
 
-            final Map squished = new HashMap<>();
-            final List fields = rel.getRowType().getFieldList();
-            final List> newProjects = new ArrayList<>();
-            for (int i = 0; i < fields.size(); i++) {
-                if (origins.get(i) == i) {
-                    squished.put(i, newProjects.size());
-                    newProjects.add(RexInputRef.of2(i, fields));
-                }
-            }
-            rel =
-                LogicalProject.create(rel, Pair.left(newProjects),
-                    Pair.right(newProjects));
-            bb.root = rel;
-            convertDistinct(bb, false);
-            rel = bb.root;
-
-            // Create the expressions to reverse the mapping.
-            // Project($0, $1, $0, $2).
-            final List> undoProjects = new ArrayList<>();
-            for (int i = 0; i < fields.size(); i++) {
-                final int origin = origins.get(i);
-                RelDataTypeField field = fields.get(i);
-                undoProjects.add(Pair.of(
-                    new RexInputRef(squished.get(origin), field.getType()), field.getName()));
-            }
+    private boolean inOver = false;
 
-            rel =
-                LogicalProject.create(rel, Pair.left(undoProjects),
-                    Pair.right(undoProjects));
-            bb.setRoot(
-                rel,
-                false);
-            return;
-        }
+    public GQLAggConverter(Blackboard bb, SqlReturnStatement returnStatement) {
+      super(bb);
+      this.bb = bb;
+      this.gqlReturnScope = (GQLReturnScope) getValidator().getScopes(returnStatement);
 
-        // Usual case: all of the expressions in the SELECT clause are
-        // different.
-        final ImmutableBitSet groupSet =
-            ImmutableBitSet.range(rel.getRowType().getFieldCount());
-        rel = createAggregate(bb, groupSet, ImmutableList.of(groupSet),
-            ImmutableList.of());
-
-        bb.setRoot(
-            rel,
-            false);
-    }
-
-    public class GQLAggConverter extends AggConverter implements SqlVisitor {
-
-        private final Blackboard bb;
-        public final GQLReturnScope gqlReturnScope;
-
-        private final Map nameMap = new HashMap<>();
-
-        /**
-         * The group-by expressions, in {@link SqlNode} format.
-         */
-        private final SqlNodeList groupExprs =
-            new SqlNodeList(SqlParserPos.ZERO);
-
-        /**
-         * The auxiliary group-by expressions.
-         */
-        private final Map> auxiliaryGroupExprs =
-            new HashMap<>();
-
-        /**
-         * Input expressions for the group columns and aggregates, in
-         * {@link RexNode} format. The first elements of the list correspond to the
-         * elements in {@link #groupExprs}; the remaining elements are for
-         * aggregates. The right field of each pair is the name of the expression,
-         * where the expressions are simple mappings to input fields.
-         */
-        private final List> convertedInputExprs =
-            new ArrayList<>();
-
-        /**
-         * Expressions to be evaluated as rows are being placed into the
-         * aggregate's hash table. This is when group functions such as TUMBLE
-         * cause rows to be expanded.
-         */
-
-        private final List aggCalls = new ArrayList<>();
-        private final Map aggMapping = new HashMap<>();
-        private final Map aggCallMapping =
-            new HashMap<>();
-
-        private boolean inOver = false;
-
-        public GQLAggConverter(Blackboard bb, SqlReturnStatement returnStatement) {
-            super(bb);
-            this.bb = bb;
-            this.gqlReturnScope = (GQLReturnScope) getValidator().getScopes(returnStatement);
-
-            // Collect all expressions used in the select list so that aggregate
-            // calls can be named correctly.
-            final SqlNodeList returnList = returnStatement.getReturnList();
-            for (int i = 0; i < returnList.size(); i++) {
-                SqlNode returnItem = returnList.get(i);
-                String name = null;
-                if (SqlUtil.isCallTo(
-                    returnItem,
-                    SqlStdOperatorTable.AS)) {
-                    final SqlCall call = (SqlCall) returnItem;
-                    returnItem = call.operand(0);
-                    name = call.operand(1).toString();
-                }
-                if (name == null) {
-                    name = validator.deriveAlias(returnItem, i);
-                }
-                nameMap.put(returnItem.toString(), name);
-            }
+      // Collect all expressions used in the select list so that aggregate
+      // calls can be named correctly.
+      final SqlNodeList returnList = returnStatement.getReturnList();
+      for (int i = 0; i < returnList.size(); i++) {
+        SqlNode returnItem = returnList.get(i);
+        String name = null;
+        if (SqlUtil.isCallTo(returnItem, SqlStdOperatorTable.AS)) {
+          final SqlCall call = (SqlCall) returnItem;
+          returnItem = call.operand(0);
+          name = call.operand(1).toString();
         }
-
-        public int addGroupExpr(SqlNode expr) {
-            int ref = lookupGroupExpr(expr);
-            if (ref >= 0) {
-                return ref;
-            }
-            final int index = groupExprs.size();
-            groupExprs.add(expr);
-            String name = nameMap.get(expr.toString());
-            RexNode convExpr = bb.convertExpression(expr);
-            addExpr(convExpr, name);
-
-            if (expr instanceof SqlCall) {
-                SqlCall call = (SqlCall) expr;
-                for (Pair p
-                    : SqlStdOperatorTable.convertGroupToAuxiliaryCalls(call)) {
-                    addAuxiliaryGroupExpr(p.left, index, p.right);
-                }
-            }
-
-            return index;
-        }
-
-        void addAuxiliaryGroupExpr(SqlNode node, int index,
-                                   AuxiliaryConverter converter) {
-            for (SqlNode node2 : auxiliaryGroupExprs.keySet()) {
-                if (node2.equalsDeep(node, Litmus.IGNORE)) {
-                    return;
-                }
-            }
-            auxiliaryGroupExprs.put(node, Ord.of(index, converter));
+        if (name == null) {
+          name = validator.deriveAlias(returnItem, i);
         }
+        nameMap.put(returnItem.toString(), name);
+      }
+    }
 
-        /**
-         * Adds an expression, deducing an appropriate name if possible.
-         *
-         * @param expr Expression
-         * @param name Suggested name
-         */
-        private void addExpr(RexNode expr, String name) {
-            if ((name == null) && (expr instanceof RexInputRef)) {
-                final int i = ((RexInputRef) expr).getIndex();
-                name = bb.root.getRowType().getFieldList().get(i).getName();
-            }
-            if (Pair.right(convertedInputExprs).contains(name)) {
-                // In case like 'SELECT ... GROUP BY x, y, x', don't add
-                // name 'x' twice.
-                name = null;
-            }
-            convertedInputExprs.add(Pair.of(expr, name));
-        }
+    public int addGroupExpr(SqlNode expr) {
+      int ref = lookupGroupExpr(expr);
+      if (ref >= 0) {
+        return ref;
+      }
+      final int index = groupExprs.size();
+      groupExprs.add(expr);
+      String name = nameMap.get(expr.toString());
+      RexNode convExpr = bb.convertExpression(expr);
+      addExpr(convExpr, name);
+
+      if (expr instanceof SqlCall) {
+        SqlCall call = (SqlCall) expr;
+        for (Pair p :
+            SqlStdOperatorTable.convertGroupToAuxiliaryCalls(call)) {
+          addAuxiliaryGroupExpr(p.left, index, p.right);
+        }
+      }
+
+      return index;
+    }
 
-        public Void visit(SqlIdentifier id) {
-            return null;
+    void addAuxiliaryGroupExpr(SqlNode node, int index, AuxiliaryConverter converter) {
+      for (SqlNode node2 : auxiliaryGroupExprs.keySet()) {
+        if (node2.equalsDeep(node, Litmus.IGNORE)) {
+          return;
         }
+      }
+      auxiliaryGroupExprs.put(node, Ord.of(index, converter));
+    }
 
-        public Void visit(SqlNodeList nodeList) {
-            for (int i = 0; i < nodeList.size(); i++) {
-                nodeList.get(i).accept(this);
-            }
-            return null;
-        }
+    /**
+     * Adds an expression, deducing an appropriate name if possible.
+     *
+     * @param expr Expression
+     * @param name Suggested name
+     */
+    private void addExpr(RexNode expr, String name) {
+      if ((name == null) && (expr instanceof RexInputRef)) {
+        final int i = ((RexInputRef) expr).getIndex();
+        name = bb.root.getRowType().getFieldList().get(i).getName();
+      }
+      if (Pair.right(convertedInputExprs).contains(name)) {
+        // In case like 'SELECT ... GROUP BY x, y, x', don't add
+        // name 'x' twice.
+        name = null;
+      }
+      convertedInputExprs.add(Pair.of(expr, name));
+    }
 
-        public Void visit(SqlLiteral lit) {
-            return null;
-        }
+    public Void visit(SqlIdentifier id) {
+      return null;
+    }
 
-        public Void visit(SqlDataTypeSpec type) {
-            return null;
-        }
+    public Void visit(SqlNodeList nodeList) {
+      for (int i = 0; i < nodeList.size(); i++) {
+        nodeList.get(i).accept(this);
+      }
+      return null;
+    }
 
-        public Void visit(SqlDynamicParam param) {
-            return null;
-        }
+    public Void visit(SqlLiteral lit) {
+      return null;
+    }
 
-        public Void visit(SqlIntervalQualifier intervalQualifier) {
-            return null;
-        }
+    public Void visit(SqlDataTypeSpec type) {
+      return null;
+    }
 
-        public Void visit(SqlCall call) {
-            switch (call.getKind()) {
-                case FILTER:
-                case WITHIN_GROUP:
-                    translateAgg(call);
-                    return null;
-                case SELECT:
-                    // rchen 2006-10-17:
-                    // for now do not detect aggregates in sub-queries.
-                    return null;
-                default:
-            }
-            final boolean prevInOver = inOver;
-            // Ignore window aggregates and ranking functions (associated with OVER
-            // operator). However, do not ignore nested window aggregates.
-            if (call.getOperator().getKind() == SqlKind.OVER) {
-                // Track aggregate nesting levels only within an OVER operator.
-                List operandList = call.getOperandList();
-                assert operandList.size() == 2;
-
-                // Ignore the top level window aggregates and ranking functions
-                // positioned as the first operand of a OVER operator
-                inOver = true;
-                operandList.get(0).accept(this);
-
-                // Normal translation for the second operand of a OVER operator
-                inOver = false;
-                operandList.get(1).accept(this);
-                return null;
-            }
+    public Void visit(SqlDynamicParam param) {
+      return null;
+    }
 
-            // Do not translate the top level window aggregate. Only do so for
-            // nested aggregates, if present
-            if (call.getOperator().isAggregator()) {
-                if (inOver) {
-                    // Add the parent aggregate level before visiting its children
-                    inOver = false;
-                } else {
-                    // We're beyond the one ignored level
-                    translateAgg(call);
-                    return null;
-                }
-            }
-            for (SqlNode operand : call.getOperandList()) {
-                // Operands are occasionally null, e.g. switched CASE arg 0.
-                if (operand != null) {
-                    operand.accept(this);
-                }
-            }
-            // Remove the parent aggregate level after visiting its children
-            inOver = prevInOver;
-            return null;
-        }
+    public Void visit(SqlIntervalQualifier intervalQualifier) {
+      return null;
+    }
 
-        private void translateAgg(SqlCall call) {
-            translateAgg(call, null, null, call);
-        }
+    public Void visit(SqlCall call) {
+      switch (call.getKind()) {
+        case FILTER:
+        case WITHIN_GROUP:
+          translateAgg(call);
+          return null;
+        case SELECT:
+          // rchen 2006-10-17:
+          // for now do not detect aggregates in sub-queries.
+          return null;
+        default:
+      }
+      final boolean prevInOver = inOver;
+      // Ignore window aggregates and ranking functions (associated with OVER
+      // operator). However, do not ignore nested window aggregates.
+      if (call.getOperator().getKind() == SqlKind.OVER) {
+        // Track aggregate nesting levels only within an OVER operator.
+        List operandList = call.getOperandList();
+        assert operandList.size() == 2;
+
+        // Ignore the top level window aggregates and ranking functions
+        // positioned as the first operand of a OVER operator
+        inOver = true;
+        operandList.get(0).accept(this);
+
+        // Normal translation for the second operand of a OVER operator
+        inOver = false;
+        operandList.get(1).accept(this);
+        return null;
+      }
+
+      // Do not translate the top level window aggregate. Only do so for
+      // nested aggregates, if present
+      if (call.getOperator().isAggregator()) {
+        if (inOver) {
+          // Add the parent aggregate level before visiting its children
+          inOver = false;
+        } else {
+          // We're beyond the one ignored level
+          translateAgg(call);
+          return null;
+        }
+      }
+      for (SqlNode operand : call.getOperandList()) {
+        // Operands are occasionally null, e.g. switched CASE arg 0.
+        if (operand != null) {
+          operand.accept(this);
+        }
+      }
+      // Remove the parent aggregate level after visiting its children
+      inOver = prevInOver;
+      return null;
+    }
 
-        private void translateAgg(SqlCall call, SqlNode filter,
-                                  SqlNodeList orderList, SqlCall outerCall) {
-            assert bb.getAgg() == this;
-            assert outerCall != null;
-            switch (call.getKind()) {
-                case FILTER:
-                    assert filter == null;
-                    translateAgg(call.operand(0), call.operand(1), orderList, outerCall);
-                    return;
-                case WITHIN_GROUP:
-                    assert orderList == null;
-                    translateAgg(call.operand(0), filter, call.operand(1), outerCall);
-                    return;
-                default:
-            }
-            final List args = new ArrayList<>();
-            int filterArg = -1;
-            final List argTypes =
-                call.getOperator() instanceof SqlCountAggFunction
-                    ? new ArrayList<>(call.getOperandList().size())
-                    : null;
-            try {
-                // switch out of agg mode
-                bb.setAgg(null);
-                for (SqlNode operand : call.getOperandList()) {
-
-                    // special case for COUNT(*):  delete the *
-                    if (operand instanceof SqlIdentifier) {
-                        SqlIdentifier id = (SqlIdentifier) operand;
-                        if (id.isStar()) {
-                            assert call.operandCount() == 1;
-                            assert args.isEmpty();
-                            break;
-                        }
-                    }
-                    RexNode convertedExpr = bb.convertExpression(operand);
-                    assert convertedExpr != null;
-                    if (argTypes != null) {
-                        argTypes.add(convertedExpr.getType());
-                    }
-                    args.add(lookupOrCreateGroupExpr(convertedExpr));
-                }
-
-                if (filter != null) {
-                    RexNode convertedExpr = bb.convertExpression(filter);
-                    assert convertedExpr != null;
-                    if (convertedExpr.getType().isNullable()) {
-                        convertedExpr =
-                            rexBuilder.makeCall(SqlStdOperatorTable.IS_TRUE, convertedExpr);
-                    }
-                    filterArg = lookupOrCreateGroupExpr(convertedExpr);
-                }
-            } finally {
-                // switch back into agg mode
-                bb.setAgg(this);
-            }
+    private void translateAgg(SqlCall call) {
+      translateAgg(call, null, null, call);
+    }
 
-            SqlAggFunction aggFunction =
-                (SqlAggFunction) call.getOperator();
-            final RelDataType type = validator.deriveType(bb.scope, call);
-            boolean distinct = false;
-            SqlLiteral quantifier = call.getFunctionQuantifier();
-            if ((null != quantifier)
-                && (quantifier.getValue() == SqlSelectKeyword.DISTINCT)) {
-                distinct = true;
-            }
-            boolean approximate = false;
-            if (aggFunction == SqlStdOperatorTable.APPROX_COUNT_DISTINCT) {
-                aggFunction = SqlStdOperatorTable.COUNT;
-                distinct = true;
-                approximate = true;
+    private void translateAgg(
+        SqlCall call, SqlNode filter, SqlNodeList orderList, SqlCall outerCall) {
+      assert bb.getAgg() == this;
+      assert outerCall != null;
+      switch (call.getKind()) {
+        case FILTER:
+          assert filter == null;
+          translateAgg(call.operand(0), call.operand(1), orderList, outerCall);
+          return;
+        case WITHIN_GROUP:
+          assert orderList == null;
+          translateAgg(call.operand(0), filter, call.operand(1), outerCall);
+          return;
+        default:
+      }
+      final List args = new ArrayList<>();
+      int filterArg = -1;
+      final List argTypes =
+          call.getOperator() instanceof SqlCountAggFunction
+              ? new ArrayList<>(call.getOperandList().size())
+              : null;
+      try {
+        // switch out of agg mode
+        bb.setAgg(null);
+        for (SqlNode operand : call.getOperandList()) {
+
+          // special case for COUNT(*):  delete the *
+          if (operand instanceof SqlIdentifier) {
+            SqlIdentifier id = (SqlIdentifier) operand;
+            if (id.isStar()) {
+              assert call.operandCount() == 1;
+              assert args.isEmpty();
+              break;
             }
-            final RelCollation collation;
-            if (orderList == null || orderList.size() == 0) {
-                collation = RelCollations.EMPTY;
-            } else {
-                collation = RelCollations.of(
-                    orderList.getList()
-                        .stream()
-                        .map(order ->
-                            bb.convertSortExpression(order,
+          }
+          RexNode convertedExpr = bb.convertExpression(operand);
+          assert convertedExpr != null;
+          if (argTypes != null) {
+            argTypes.add(convertedExpr.getType());
+          }
+          args.add(lookupOrCreateGroupExpr(convertedExpr));
+        }
+
+        if (filter != null) {
+          RexNode convertedExpr = bb.convertExpression(filter);
+          assert convertedExpr != null;
+          if (convertedExpr.getType().isNullable()) {
+            convertedExpr = rexBuilder.makeCall(SqlStdOperatorTable.IS_TRUE, convertedExpr);
+          }
+          filterArg = lookupOrCreateGroupExpr(convertedExpr);
+        }
+      } finally {
+        // switch back into agg mode
+        bb.setAgg(this);
+      }
+
+      SqlAggFunction aggFunction = (SqlAggFunction) call.getOperator();
+      final RelDataType type = validator.deriveType(bb.scope, call);
+      boolean distinct = false;
+      SqlLiteral quantifier = call.getFunctionQuantifier();
+      if ((null != quantifier) && (quantifier.getValue() == SqlSelectKeyword.DISTINCT)) {
+        distinct = true;
+      }
+      boolean approximate = false;
+      if (aggFunction == SqlStdOperatorTable.APPROX_COUNT_DISTINCT) {
+        aggFunction = SqlStdOperatorTable.COUNT;
+        distinct = true;
+        approximate = true;
+      }
+      final RelCollation collation;
+      if (orderList == null || orderList.size() == 0) {
+        collation = RelCollations.EMPTY;
+      } else {
+        collation =
+            RelCollations.of(
+                orderList.getList().stream()
+                    .map(
+                        order ->
+                            bb.convertSortExpression(
+                                order,
                                 RelFieldCollation.Direction.ASCENDING,
                                 RelFieldCollation.NullDirection.UNSPECIFIED))
-                        .map(fieldCollation ->
+                    .map(
+                        fieldCollation ->
                             new RelFieldCollation(
                                 lookupOrCreateGroupExpr(fieldCollation.left),
                                 fieldCollation.getDirection(),
                                 fieldCollation.getNullDirection()))
-                        .collect(Collectors.toList()));
-            }
-            final AggregateCall aggCall =
-                AggregateCall.create(
-                    aggFunction,
-                    distinct,
-                    approximate,
-                    args,
-                    filterArg,
-                    collation,
-                    type,
-                    nameMap.get(outerCall.toString()));
-
-            gqlReturnScope.resolved.get();
-            RexNode rex =
-                rexBuilder.addAggCall(
-                    aggCall,
-                    groupExprs.size(),
-                    false,
-                    aggCalls,
-                    aggCallMapping,
-                    argTypes);
-            aggMapping.put(outerCall, rex);
-        }
-
-        private int lookupOrCreateGroupExpr(RexNode expr) {
-            int index = 0;
-            for (RexNode convertedInputExpr : Pair.left(convertedInputExprs)) {
-                if (expr.equals(convertedInputExpr)) {
-                    return index;
-                }
-                ++index;
-            }
+                    .collect(Collectors.toList()));
+      }
+      final AggregateCall aggCall =
+          AggregateCall.create(
+              aggFunction,
+              distinct,
+              approximate,
+              args,
+              filterArg,
+              collation,
+              type,
+              nameMap.get(outerCall.toString()));
+
+      gqlReturnScope.resolved.get();
+      RexNode rex =
+          rexBuilder.addAggCall(
+              aggCall, groupExprs.size(), false, aggCalls, aggCallMapping, argTypes);
+      aggMapping.put(outerCall, rex);
+    }
 
-            // not found -- add it
-            addExpr(expr, null);
-            return index;
+    private int lookupOrCreateGroupExpr(RexNode expr) {
+      int index = 0;
+      for (RexNode convertedInputExpr : Pair.left(convertedInputExprs)) {
+        if (expr.equals(convertedInputExpr)) {
+          return index;
         }
+        ++index;
+      }
 
-        /**
-         * If an expression is structurally identical to one of the group-by
-         * expressions, returns a reference to the expression, otherwise returns
-         * null.
-         */
-        public int lookupGroupExpr(SqlNode expr) {
-            for (int i = 0; i < groupExprs.size(); i++) {
-                SqlNode groupExpr = groupExprs.get(i);
-                if (expr.equalsDeep(groupExpr, Litmus.IGNORE)) {
-                    return i;
-                }
-            }
-            return -1;
-        }
-
-        public RexNode lookupAggregates(SqlCall call) {
-            // assert call.getOperator().isAggregator();
-            assert bb.getAgg() == this;
-
-            for (Map.Entry> e
-                : auxiliaryGroupExprs.entrySet()) {
-                if (call.equalsDeep(e.getKey(), Litmus.IGNORE)) {
-                    AuxiliaryConverter converter = e.getValue().e;
-                    final int groupOrdinal = e.getValue().i;
-                    return converter.convert(rexBuilder,
-                        convertedInputExprs.get(groupOrdinal).left,
-                        rexBuilder.makeInputRef(bb.root, groupOrdinal));
-                }
-            }
-
-            return aggMapping.get(call);
-        }
+      // not found -- add it
+      addExpr(expr, null);
+      return index;
+    }
 
-        public List> getPreExprs() {
-            return convertedInputExprs;
-        }
+    /**
+     * If an expression is structurally identical to one of the group-by expressions, returns a
+     * reference to the expression, otherwise returns null.
+     */
+    public int lookupGroupExpr(SqlNode expr) {
+      for (int i = 0; i < groupExprs.size(); i++) {
+        SqlNode groupExpr = groupExprs.get(i);
+        if (expr.equalsDeep(groupExpr, Litmus.IGNORE)) {
+          return i;
+        }
+      }
+      return -1;
+    }
 
-        public List getAggCalls() {
-            return aggCalls;
-        }
+    public RexNode lookupAggregates(SqlCall call) {
+      // assert call.getOperator().isAggregator();
+      assert bb.getAgg() == this;
 
-        public RelDataTypeFactory getTypeFactory() {
-            return typeFactory;
+      for (Map.Entry> e : auxiliaryGroupExprs.entrySet()) {
+        if (call.equalsDeep(e.getKey(), Litmus.IGNORE)) {
+          AuxiliaryConverter converter = e.getValue().e;
+          final int groupOrdinal = e.getValue().i;
+          return converter.convert(
+              rexBuilder,
+              convertedInputExprs.get(groupOrdinal).left,
+              rexBuilder.makeInputRef(bb.root, groupOrdinal));
         }
+      }
 
+      return aggMapping.get(call);
     }
 
-    public static class GQLAggChecker extends SqlBasicVisitor {
-        //~ Instance fields --------------------------------------------------------
-
-        private final Deque scopes = new ArrayDeque<>();
-        private final List extraExprs;
-        private final List groupExprs;
-        private final boolean distinct;
-        private final SqlValidatorImpl validator;
-
-        //~ Constructors -----------------------------------------------------------
-
-        /**
-         * Creates an AggChecker.
-         *
-         * @param validator  Validator
-         * @param scope      Scope
-         * @param groupExprs Expressions in GROUP BY (or SELECT DISTINCT) clause, that are therefore available
-         * @param distinct   Whether aggregation checking is because of a SELECT DISTINCT clause
-         */
-        public GQLAggChecker(
-            SqlValidatorImpl validator,
-            SqlValidatorScope scope,
-            List extraExprs,
-            List groupExprs,
-            boolean distinct) {
-            this.validator = validator;
-            this.extraExprs = extraExprs;
-            this.groupExprs = groupExprs;
-            this.distinct = distinct;
-            this.scopes.push(scope);
-        }
-
-        //~ Methods ----------------------------------------------------------------
-
-        public boolean isGroupExpr(SqlNode expr) {
-            for (SqlNode groupExpr : groupExprs) {
-                if (groupExpr.equalsDeep(expr, Litmus.IGNORE)) {
-                    return true;
-                }
-            }
+    public List> getPreExprs() {
+      return convertedInputExprs;
+    }
 
-            for (SqlNode extraExpr : extraExprs) {
-                if (extraExpr.equalsDeep(expr, Litmus.IGNORE)) {
-                    return true;
-                }
-            }
-            return false;
-        }
+    public List getAggCalls() {
+      return aggCalls;
+    }
 
-        public Void visit(SqlIdentifier id) {
-            if (isGroupExpr(id) || id.isStar()) {
-                // Star may validly occur in "SELECT COUNT(*) OVER w"
-                return null;
-            }
+    public RelDataTypeFactory getTypeFactory() {
+      return typeFactory;
+    }
+  }
+
+  public static class GQLAggChecker extends SqlBasicVisitor {
+    // ~ Instance fields --------------------------------------------------------
+
+    private final Deque scopes = new ArrayDeque<>();
+    private final List extraExprs;
+    private final List groupExprs;
+    private final boolean distinct;
+    private final SqlValidatorImpl validator;
+
+    // ~ Constructors -----------------------------------------------------------
+
+    /**
+     * Creates an AggChecker.
+     *
+     * @param validator Validator
+     * @param scope Scope
+     * @param groupExprs Expressions in GROUP BY (or SELECT DISTINCT) clause, that are therefore
+     *     available
+     * @param distinct Whether aggregation checking is because of a SELECT DISTINCT clause
+     */
+    public GQLAggChecker(
+        SqlValidatorImpl validator,
+        SqlValidatorScope scope,
+        List extraExprs,
+        List groupExprs,
+        boolean distinct) {
+      this.validator = validator;
+      this.extraExprs = extraExprs;
+      this.groupExprs = groupExprs;
+      this.distinct = distinct;
+      this.scopes.push(scope);
+    }
 
-            // Is it a call to a parentheses-free function?
-            SqlCall call =
-                SqlUtil.makeCall(
-                    validator.getOperatorTable(),
-                    id);
-            if (call != null) {
-                return call.accept(this);
-            }
+    // ~ Methods ----------------------------------------------------------------
 
-            // Didn't find the identifier in the group-by list as is, now find
-            // it fully-qualified.
-            // TODO: It would be better if we always compared fully-qualified
-            // to fully-qualified.
-            assert scopes.peek() != null : "GQLToRelConverter has no scopes";
-            final SqlQualified fqId = scopes.peek().fullyQualify(id);
-            if (isGroupExpr(fqId.identifier)) {
-                return null;
-            }
-            SqlNode originalExpr = validator.getOriginal(id);
-            final String exprString = originalExpr.toString();
-            throw validator.newValidationError(originalExpr,
-                distinct
-                    ? RESOURCE.notSelectDistinctExpr(exprString)
-                    : RESOURCE.notGroupExpr(exprString));
+    public boolean isGroupExpr(SqlNode expr) {
+      for (SqlNode groupExpr : groupExprs) {
+        if (groupExpr.equalsDeep(expr, Litmus.IGNORE)) {
+          return true;
         }
+      }
 
-        public Void visit(SqlCall call) {
-            assert scopes.peek() != null : "GQLToRelConverter has no scopes";
-            final SqlValidatorScope scope = scopes.peek();
-            if (call.getOperator().isAggregator()) {
-                return null;
-            }
-            if (isGroupExpr(call)) {
-                // This call matches an expression in the GROUP BY clause.
-                return null;
-            }
-
-            final SqlCall groupCall =
-                SqlStdOperatorTable.convertAuxiliaryToGroupCall(call);
-            if (groupCall != null) {
-                if (isGroupExpr(groupCall)) {
-                    // This call is an auxiliary function that matches a group call in the
-                    // GROUP BY clause.
-                    //
-                    // For example TUMBLE_START is an auxiliary of the TUMBLE
-                    // group function, and
-                    //   TUMBLE_START(rowtime, INTERVAL '1' HOUR)
-                    // matches
-                    //   TUMBLE(rowtime, INTERVAL '1' HOUR')
-                    return null;
-                }
-                throw validator.newValidationError(groupCall,
-                    RESOURCE.auxiliaryWithoutMatchingGroupCall(
-                        call.getOperator().getName(), groupCall.getOperator().getName()));
-            }
-
-            if (call.isA(SqlKind.QUERY)) {
-                // Allow queries for now, even though they may contain
-                // references to forbidden columns.
-                return null;
-            }
-
-            // Switch to new scope.
-            SqlValidatorScope newScope = scope.getOperandScope(call);
-            scopes.push(newScope);
-
-            // Visit the operands (only expressions).
-            call.getOperator()
-                .acceptCall(this, call, true, ArgHandlerImpl.instance());
-
-            // Restore scope.
-            scopes.pop();
-            return null;
+      for (SqlNode extraExpr : extraExprs) {
+        if (extraExpr.equalsDeep(expr, Litmus.IGNORE)) {
+          return true;
         }
+      }
+      return false;
     }
 
-    private GQLValidatorImpl getValidator() {
-        return (GQLValidatorImpl) validator;
+    public Void visit(SqlIdentifier id) {
+      if (isGroupExpr(id) || id.isStar()) {
+        // Star may validly occur in "SELECT COUNT(*) OVER w"
+        return null;
+      }
+
+      // Is it a call to a parentheses-free function?
+      SqlCall call = SqlUtil.makeCall(validator.getOperatorTable(), id);
+      if (call != null) {
+        return call.accept(this);
+      }
+
+      // Didn't find the identifier in the group-by list as is, now find
+      // it fully-qualified.
+      // TODO: It would be better if we always compared fully-qualified
+      // to fully-qualified.
+      assert scopes.peek() != null : "GQLToRelConverter has no scopes";
+      final SqlQualified fqId = scopes.peek().fullyQualify(id);
+      if (isGroupExpr(fqId.identifier)) {
+        return null;
+      }
+      SqlNode originalExpr = validator.getOriginal(id);
+      final String exprString = originalExpr.toString();
+      throw validator.newValidationError(
+          originalExpr,
+          distinct
+              ? RESOURCE.notSelectDistinctExpr(exprString)
+              : RESOURCE.notGroupExpr(exprString));
     }
 
-    public class GQLBlackboard extends Blackboard {
+    public Void visit(SqlCall call) {
+      assert scopes.peek() != null : "GQLToRelConverter has no scopes";
+      final SqlValidatorScope scope = scopes.peek();
+      if (call.getOperator().isAggregator()) {
+        return null;
+      }
+      if (isGroupExpr(call)) {
+        // This call matches an expression in the GROUP BY clause.
+        return null;
+      }
+
+      final SqlCall groupCall = SqlStdOperatorTable.convertAuxiliaryToGroupCall(call);
+      if (groupCall != null) {
+        if (isGroupExpr(groupCall)) {
+          // This call is an auxiliary function that matches a group call in the
+          // GROUP BY clause.
+          //
+          // For example TUMBLE_START is an auxiliary of the TUMBLE
+          // group function, and
+          //   TUMBLE_START(rowtime, INTERVAL '1' HOUR)
+          // matches
+          //   TUMBLE(rowtime, INTERVAL '1' HOUR')
+          return null;
+        }
+        throw validator.newValidationError(
+            groupCall,
+            RESOURCE.auxiliaryWithoutMatchingGroupCall(
+                call.getOperator().getName(), groupCall.getOperator().getName()));
+      }
+
+      if (call.isA(SqlKind.QUERY)) {
+        // Allow queries for now, even though they may contain
+        // references to forbidden columns.
+        return null;
+      }
+
+      // Switch to new scope.
+      SqlValidatorScope newScope = scope.getOperandScope(call);
+      scopes.push(newScope);
+
+      // Visit the operands (only expressions).
+      call.getOperator().acceptCall(this, call, true, ArgHandlerImpl.instance());
+
+      // Restore scope.
+      scopes.pop();
+      return null;
+    }
+  }
 
-        private Blackboard withBb;
+  private GQLValidatorImpl getValidator() {
+    return (GQLValidatorImpl) validator;
+  }
 
-        protected GQLBlackboard(SqlValidatorScope scope,
-                                Map nameToNodeMap, boolean top) {
-            super(scope, nameToNodeMap, top);
-        }
+  public class GQLBlackboard extends Blackboard {
 
-        public GQLBlackboard setWithBb(Blackboard withBb) {
-            this.withBb = withBb;
-            return this;
-        }
+    private Blackboard withBb;
 
-        @Override
-        public RexNode convertExpression(SqlNode expr) {
-            RexNode rexNode;
-            switch (expr.getKind()) {
-                case VERTEX_VALUE_CONSTRUCTOR:
-                case EDGE_VALUE_CONSTRUCTOR:
-                    AbstractSqlGraphElementConstruct construct = (AbstractSqlGraphElementConstruct) expr;
-                    SqlNode[] valueNodes = construct.getValueNodes();
-                    SqlIdentifier[] keyNodes = construct.getKeyNodes();
-                    List operands = new ArrayList<>();
-                    Map rex2VariableInfo = new HashMap<>();
-
-                    for (int i = 0; i < valueNodes.length; i++) {
-                        SqlNode valueNode = valueNodes[i];
-                        RexNode operand = convertExpression(valueNode);
-                        VariableInfo variableInfo = new VariableInfo(false, keyNodes[i].getSimple());
-                        operands.add(operand);
-                        rex2VariableInfo.put(operand, variableInfo);
-                    }
-
-                    RelDataType type = getValidator().getValidatedNodeType(expr);
-                    rexNode = new RexObjectConstruct(type, operands, rex2VariableInfo);
-                    break;
-                case GQL_PATH_PATTERN_SUB_QUERY:
-                    SqlPathPatternSubQuery subQuery = (SqlPathPatternSubQuery) expr;
-
-                    SqlValidatorNamespace ns = getValidator().getNamespace(subQuery.getPathPattern());
-                    assert ns.getType() instanceof PathRecordType;
-                    PathRecordType pathRecordType = (PathRecordType) ns.getType();
-                    assert pathRecordType.getFieldCount() > 0;
-
-                    IMatchNode matchNode = convertPathPattern(subQuery.getPathPattern(), withBb);
-                    assert matchNode instanceof SingleMatchNode : "Sub-query should be single path match";
-                    IMatchNode firstNode = GQLRelUtil.getFirstMatchNode((SingleMatchNode) matchNode);
-
-                    assert pathRecordType.firstField().isPresent() : "Path type is empty";
-                    VertexRecordType firstFieldType = (VertexRecordType) pathRecordType.firstField().get().getType();
-                    // SubQueryStart's path schema is the same with the first match node in the suq-query.
-                    SubQueryStart subQueryStart = SubQueryStart.create(matchNode.getCluster(), matchNode.getTraitSet(),
-                        generateSubQueryName(), firstNode.getPathSchema(), firstFieldType);
-                    // add sub query start node to the head of the match node.
-                    matchNode = GQLRelUtil.addSubQueryStartNode(matchNode, subQueryStart);
-
-                    SqlNode returnValue = subQuery.getReturnValue();
-                    SqlNode originReturnValue = GQLToRelConverter.this.getValidator().getOriginal(returnValue);
-                    GQLScope returnValueScope = (GQLScope) GQLToRelConverter.this.getValidator()
-                        .getScopes(originReturnValue);
-
-                    GQLBlackboard bb = new GQLBlackboard(returnValueScope, null, true);
-                    bb.setRoot(matchNode, true);
-                    RexNode returnValueNode = bb.convertExpression(returnValue);
-
-                    RexSubQuery pathPatternSubQuery = RexSubQuery.create(matchNode.getPathSchema(),
-                        SqlPathPatternOperator.INSTANCE, ImmutableList.of(), matchNode);
-                    rexNode = new RexLambdaCall(pathPatternSubQuery, returnValueNode);
-                    break;
-                default:
-                    rexNode = super.convertExpression(expr);
-                    break;
-            }
-            return convertRexParameterRef(scope, rexNode);
-        }
+    protected GQLBlackboard(
+        SqlValidatorScope scope, Map nameToNodeMap, boolean top) {
+      super(scope, nameToNodeMap, top);
     }
 
-    private RexNode convertRexParameterRef(SqlValidatorScope scope, RexNode rexNode) {
-        if (scope instanceof ListScope
-            && ((ListScope) scope).getParent() instanceof GQLWithBodyScope) {
-            GQLWithBodyScope withScope = (GQLWithBodyScope) ((ListScope) scope).getParent();
-
-            return rexNode.accept(new RexShuttle() {
-                @Override
-                public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
-                    // replace CorrelVariable to ParameterRef for referring the with-body fields.
-                    if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) {
-                        SqlValidatorNamespace withItemNs = withScope.children.get(0).getNamespace();
-                        return new RexParameterRef(fieldAccess.getField().getIndex(),
-                            fieldAccess.getType(), withItemNs.getType());
-                    }
-                    return fieldAccess;
-                }
-            });
-        }
-        return rexNode;
+    public GQLBlackboard setWithBb(Blackboard withBb) {
+      this.withBb = withBb;
+      return this;
     }
 
     @Override
-    protected GQLBlackboard createBlackboard(SqlValidatorScope scope,
-                                             Map nameToNodeMap, boolean top) {
-        return new GQLBlackboard(scope, nameToNodeMap, top);
-    }
-
-    public boolean isCaseSensitive() {
-        return getValidator().isCaseSensitive();
+    public RexNode convertExpression(SqlNode expr) {
+      RexNode rexNode;
+      switch (expr.getKind()) {
+        case VERTEX_VALUE_CONSTRUCTOR:
+        case EDGE_VALUE_CONSTRUCTOR:
+          AbstractSqlGraphElementConstruct construct = (AbstractSqlGraphElementConstruct) expr;
+          SqlNode[] valueNodes = construct.getValueNodes();
+          SqlIdentifier[] keyNodes = construct.getKeyNodes();
+          List operands = new ArrayList<>();
+          Map rex2VariableInfo = new HashMap<>();
+
+          for (int i = 0; i < valueNodes.length; i++) {
+            SqlNode valueNode = valueNodes[i];
+            RexNode operand = convertExpression(valueNode);
+            VariableInfo variableInfo = new VariableInfo(false, keyNodes[i].getSimple());
+            operands.add(operand);
+            rex2VariableInfo.put(operand, variableInfo);
+          }
+
+          RelDataType type = getValidator().getValidatedNodeType(expr);
+          rexNode = new RexObjectConstruct(type, operands, rex2VariableInfo);
+          break;
+        case GQL_PATH_PATTERN_SUB_QUERY:
+          SqlPathPatternSubQuery subQuery = (SqlPathPatternSubQuery) expr;
+
+          SqlValidatorNamespace ns = getValidator().getNamespace(subQuery.getPathPattern());
+          assert ns.getType() instanceof PathRecordType;
+          PathRecordType pathRecordType = (PathRecordType) ns.getType();
+          assert pathRecordType.getFieldCount() > 0;
+
+          IMatchNode matchNode = convertPathPattern(subQuery.getPathPattern(), withBb);
+          assert matchNode instanceof SingleMatchNode : "Sub-query should be single path match";
+          IMatchNode firstNode = GQLRelUtil.getFirstMatchNode((SingleMatchNode) matchNode);
+
+          assert pathRecordType.firstField().isPresent() : "Path type is empty";
+          VertexRecordType firstFieldType =
+              (VertexRecordType) pathRecordType.firstField().get().getType();
+          // SubQueryStart's path schema is the same with the first match node in the suq-query.
+          SubQueryStart subQueryStart =
+              SubQueryStart.create(
+                  matchNode.getCluster(),
+                  matchNode.getTraitSet(),
+                  generateSubQueryName(),
+                  firstNode.getPathSchema(),
+                  firstFieldType);
+          // add sub query start node to the head of the match node.
+          matchNode = GQLRelUtil.addSubQueryStartNode(matchNode, subQueryStart);
+
+          SqlNode returnValue = subQuery.getReturnValue();
+          SqlNode originReturnValue =
+              GQLToRelConverter.this.getValidator().getOriginal(returnValue);
+          GQLScope returnValueScope =
+              (GQLScope) GQLToRelConverter.this.getValidator().getScopes(originReturnValue);
+
+          GQLBlackboard bb = new GQLBlackboard(returnValueScope, null, true);
+          bb.setRoot(matchNode, true);
+          RexNode returnValueNode = bb.convertExpression(returnValue);
+
+          RexSubQuery pathPatternSubQuery =
+              RexSubQuery.create(
+                  matchNode.getPathSchema(),
+                  SqlPathPatternOperator.INSTANCE,
+                  ImmutableList.of(),
+                  matchNode);
+          rexNode = new RexLambdaCall(pathPatternSubQuery, returnValueNode);
+          break;
+        default:
+          rexNode = super.convertExpression(expr);
+          break;
+      }
+      return convertRexParameterRef(scope, rexNode);
     }
-
-    private String generateSubQueryName() {
-        return "SubQuery-" + queryIdCounter++;
+  }
+
+  private RexNode convertRexParameterRef(SqlValidatorScope scope, RexNode rexNode) {
+    if (scope instanceof ListScope && ((ListScope) scope).getParent() instanceof GQLWithBodyScope) {
+      GQLWithBodyScope withScope = (GQLWithBodyScope) ((ListScope) scope).getParent();
+
+      return rexNode.accept(
+          new RexShuttle() {
+            @Override
+            public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
+              // replace CorrelVariable to ParameterRef for referring the with-body fields.
+              if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) {
+                SqlValidatorNamespace withItemNs = withScope.children.get(0).getNamespace();
+                return new RexParameterRef(
+                    fieldAccess.getField().getIndex(), fieldAccess.getType(), withItemNs.getType());
+              }
+              return fieldAccess;
+            }
+          });
     }
+    return rexNode;
+  }
+
+  @Override
+  protected GQLBlackboard createBlackboard(
+      SqlValidatorScope scope, Map nameToNodeMap, boolean top) {
+    return new GQLBlackboard(scope, nameToNodeMap, top);
+  }
+
+  public boolean isCaseSensitive() {
+    return getValidator().isCaseSensitive();
+  }
+
+  private String generateSubQueryName() {
+    return "SubQuery-" + queryIdCounter++;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/GraphAlgorithm.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/GraphAlgorithm.java
index f3f003efa..bd8c7fedf 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/GraphAlgorithm.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/GraphAlgorithm.java
@@ -23,6 +23,7 @@
 import java.util.List;
 import java.util.Objects;
 import java.util.stream.Collectors;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
@@ -37,58 +38,62 @@
 
 public abstract class GraphAlgorithm extends SingleRel {
 
-    protected final Class userFunctionClass;
-
-    protected final Object[] params;
-
-    protected GraphAlgorithm(RelOptCluster cluster, RelTraitSet traits,
-                             RelNode input,
-                             Class userFunctionClass,
-                             Object[] params) {
-        super(cluster, traits, input);
-        this.userFunctionClass = Objects.requireNonNull(userFunctionClass);
-        this.params = Objects.requireNonNull(params);
-        this.rowType = getFunctionOutputType(userFunctionClass, cluster.getTypeFactory());
-    }
+  protected final Class userFunctionClass;
 
-    public Class getUserFunctionClass() {
-        return userFunctionClass;
-    }
+  protected final Object[] params;
 
-    public Object[] getParams() {
-        return params;
-    }
+  protected GraphAlgorithm(
+      RelOptCluster cluster,
+      RelTraitSet traits,
+      RelNode input,
+      Class userFunctionClass,
+      Object[] params) {
+    super(cluster, traits, input);
+    this.userFunctionClass = Objects.requireNonNull(userFunctionClass);
+    this.params = Objects.requireNonNull(params);
+    this.rowType = getFunctionOutputType(userFunctionClass, cluster.getTypeFactory());
+  }
 
-    @Override
-    public RelWriter explainTerms(RelWriter pw) {
-        List paraClassNames = Arrays.stream(params).map(
-            para -> para.getClass().getSimpleName()
-        ).collect(Collectors.toList());
-        return super.explainTerms(pw)
-            .item("algo", userFunctionClass.getSimpleName())
-            .item("params", paraClassNames)
-            .item("outputType", getFunctionOutputType(userFunctionClass,
-                getCluster().getTypeFactory()));
-    }
+  public Class getUserFunctionClass() {
+    return userFunctionClass;
+  }
 
-    public static RelDataType getFunctionOutputType(Class userFunctionClass,
-                                                    RelDataTypeFactory typeFactory) {
-        AlgorithmUserFunction userFunction = ClassUtil.newInstance(userFunctionClass);
-        GQLJavaTypeFactory factory = (GQLJavaTypeFactory) typeFactory;
-        return SqlTypeUtil.convertToRelType(userFunction.getOutputType(
-                factory.getCurrentGraph().getGraphSchema(factory)),
-            false, typeFactory);
-    }
+  public Object[] getParams() {
+    return params;
+  }
 
+  @Override
+  public RelWriter explainTerms(RelWriter pw) {
+    List paraClassNames =
+        Arrays.stream(params)
+            .map(para -> para.getClass().getSimpleName())
+            .collect(Collectors.toList());
+    return super.explainTerms(pw)
+        .item("algo", userFunctionClass.getSimpleName())
+        .item("params", paraClassNames)
+        .item(
+            "outputType", getFunctionOutputType(userFunctionClass, getCluster().getTypeFactory()));
+  }
 
-    public abstract GraphAlgorithm copy(RelTraitSet traitSet, RelNode input,
-                                        Class userFunctionClass,
-                                        Object[] params);
+  public static RelDataType getFunctionOutputType(
+      Class userFunctionClass, RelDataTypeFactory typeFactory) {
+    AlgorithmUserFunction userFunction = ClassUtil.newInstance(userFunctionClass);
+    GQLJavaTypeFactory factory = (GQLJavaTypeFactory) typeFactory;
+    return SqlTypeUtil.convertToRelType(
+        userFunction.getOutputType(factory.getCurrentGraph().getGraphSchema(factory)),
+        false,
+        typeFactory);
+  }
 
-    @Override
-    public GraphAlgorithm copy(RelTraitSet traitSet, List inputs) {
-        assert inputs.size() == 1;
-        return copy(traitSet, inputs.get(0), userFunctionClass, params);
-    }
+  public abstract GraphAlgorithm copy(
+      RelTraitSet traitSet,
+      RelNode input,
+      Class userFunctionClass,
+      Object[] params);
 
+  @Override
+  public GraphAlgorithm copy(RelTraitSet traitSet, List inputs) {
+    assert inputs.size() == 1;
+    return copy(traitSet, inputs.get(0), userFunctionClass, params);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/GraphMatch.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/GraphMatch.java
index dd7b524a1..d634d01d1 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/GraphMatch.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/GraphMatch.java
@@ -21,6 +21,7 @@
 
 import java.util.*;
 import java.util.stream.Collectors;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
@@ -54,239 +55,261 @@
 
 public abstract class GraphMatch extends SingleRel {
 
-    protected final IMatchNode pathPattern;
-
-    protected GraphMatch(RelOptCluster cluster, RelTraitSet traits,
-                         RelNode input, IMatchNode pathPattern, RelDataType rowType) {
-        super(cluster, traits, input);
-        this.rowType = Objects.requireNonNull(rowType);
-        this.pathPattern = Objects.requireNonNull(pathPattern);
-        validateInput(input);
+  protected final IMatchNode pathPattern;
+
+  protected GraphMatch(
+      RelOptCluster cluster,
+      RelTraitSet traits,
+      RelNode input,
+      IMatchNode pathPattern,
+      RelDataType rowType) {
+    super(cluster, traits, input);
+    this.rowType = Objects.requireNonNull(rowType);
+    this.pathPattern = Objects.requireNonNull(pathPattern);
+    validateInput(input);
+  }
+
+  private void validateInput(RelNode input) {
+    SqlTypeName inputType = input.getRowType().getSqlTypeName();
+    if (inputType != SqlTypeName.GRAPH && inputType != SqlTypeName.PATH) {
+      throw new GeaFlowDSLException(
+          "Illegal input type:" + inputType + ", for " + getRelTypeName());
     }
-
-    private void validateInput(RelNode input) {
-        SqlTypeName inputType = input.getRowType().getSqlTypeName();
-        if (inputType != SqlTypeName.GRAPH && inputType != SqlTypeName.PATH) {
-            throw new GeaFlowDSLException("Illegal input type:" + inputType
-                + ", for " + getRelTypeName());
-        }
+  }
+
+  public boolean canConcat(IMatchNode pathPattern) {
+    return GQLRelUtil.isAllSingleMatch(this.pathPattern)
+        && GQLRelUtil.isAllSingleMatch(pathPattern)
+        && this.pathPattern.getPathSchema().canConcat(pathPattern.getPathSchema());
+  }
+
+  /** Merge with a path pattern to generate a new graph match node. */
+  public GraphMatch merge(IMatchNode pathPattern) {
+    if (canConcat(pathPattern)) {
+      SingleMatchNode concatPathPattern =
+          GQLRelUtil.concatPathPattern(
+              (SingleMatchNode) this.pathPattern, (SingleMatchNode) pathPattern, true);
+
+      return this.copy(getTraitSet(), input, concatPathPattern, concatPathPattern.getPathSchema());
+    } else {
+      RexNode condition =
+          GQLRelUtil.createPathJoinCondition(
+              this.pathPattern, pathPattern, true, getCluster().getRexBuilder());
+      MatchJoin join =
+          MatchJoin.create(
+              getCluster(),
+              getTraitSet(),
+              this.pathPattern,
+              pathPattern,
+              condition,
+              JoinRelType.INNER);
+
+      return this.copy(getTraitSet(), input, join, join.getRowType());
     }
+  }
 
-    public boolean canConcat(IMatchNode pathPattern) {
-        return GQLRelUtil.isAllSingleMatch(this.pathPattern)
-            && GQLRelUtil.isAllSingleMatch(pathPattern)
-            && this.pathPattern.getPathSchema().canConcat(pathPattern.getPathSchema())
-            ;
-    }
+  @Override
+  public RelWriter explainTerms(RelWriter pw) {
+    ExplainVisitor explainVisitor = new ExplainVisitor();
+    String path = explainVisitor.visit(pathPattern);
 
-    /**
-     * Merge with a path pattern to generate a new graph match node.
-     */
-    public GraphMatch merge(IMatchNode pathPattern) {
-        if (canConcat(pathPattern)) {
-            SingleMatchNode concatPathPattern = GQLRelUtil.concatPathPattern(
-                (SingleMatchNode) this.pathPattern, (SingleMatchNode) pathPattern, true);
+    return super.explainTerms(pw).item("path", path);
+  }
 
-            return this.copy(getTraitSet(), input, concatPathPattern, concatPathPattern.getPathSchema());
-        } else {
-            RexNode condition = GQLRelUtil.createPathJoinCondition(this.pathPattern, pathPattern,
-                true, getCluster().getRexBuilder());
-            MatchJoin join = MatchJoin.create(getCluster(), getTraitSet(),
-                this.pathPattern, pathPattern, condition, JoinRelType.INNER);
+  public static class ExplainVisitor extends AbstractMatchNodeVisitor {
 
-            return this.copy(getTraitSet(), input, join, join.getRowType());
-        }
+    @Override
+    public String visitVertexMatch(VertexMatch vertexMatch) {
+      String inputString = "";
+      if (vertexMatch.getInput() != null) {
+        inputString = visit(vertexMatch.getInput());
+      }
+      String nodeString =
+          "(" + vertexMatch.getLabel() + ":" + StringUtils.join(vertexMatch.getTypes(), "|") + ")";
+      return inputString + nodeString;
     }
 
     @Override
-    public RelWriter explainTerms(RelWriter pw) {
-        ExplainVisitor explainVisitor = new ExplainVisitor();
-        String path = explainVisitor.visit(pathPattern);
-
-        return super.explainTerms(pw)
-            .item("path", path);
+    public String visitEdgeMatch(EdgeMatch edgeMatch) {
+      String inputString = "";
+      if (edgeMatch.getInput() != null) {
+        inputString = visit(edgeMatch.getInput()) + "-";
+      }
+      String direction;
+      switch (edgeMatch.getDirection()) {
+        case OUT:
+          direction = "->";
+          break;
+        case IN:
+          direction = "<-";
+          break;
+        case BOTH:
+          direction = "-";
+          break;
+        default:
+          throw new IllegalArgumentException("Illegal edge direction: " + edgeMatch.getDirection());
+      }
+      String nodeString =
+          "["
+              + edgeMatch.getLabel()
+              + ":"
+              + StringUtils.join(edgeMatch.getTypes(), "|")
+              + "]"
+              + direction;
+      return inputString + nodeString;
     }
 
-    public static class ExplainVisitor extends AbstractMatchNodeVisitor {
-
-        @Override
-        public String visitVertexMatch(VertexMatch vertexMatch) {
-            String inputString = "";
-            if (vertexMatch.getInput() != null) {
-                inputString = visit(vertexMatch.getInput());
-            }
-            String nodeString = "(" + vertexMatch.getLabel() + ":"
-                + StringUtils.join(vertexMatch.getTypes(), "|") + ")";
-            return inputString + nodeString;
-        }
-
-        @Override
-        public String visitEdgeMatch(EdgeMatch edgeMatch) {
-            String inputString = "";
-            if (edgeMatch.getInput() != null) {
-                inputString = visit(edgeMatch.getInput()) + "-";
-            }
-            String direction;
-            switch (edgeMatch.getDirection()) {
-                case OUT:
-                    direction = "->";
-                    break;
-                case IN:
-                    direction = "<-";
-                    break;
-                case BOTH:
-                    direction = "-";
-                    break;
-                default:
-                    throw new IllegalArgumentException("Illegal edge direction: " + edgeMatch.getDirection());
-            }
-            String nodeString = "[" + edgeMatch.getLabel() + ":"
-                + StringUtils.join(edgeMatch.getTypes(), "|")
-                + "]" + direction;
-            return inputString + nodeString;
-        }
-
-        @Override
-        public String visitVirtualEdgeMatch(VirtualEdgeMatch virtualEdgeMatch) {
-            String inputString = visit(virtualEdgeMatch.getInput()) + "-";
-            String nodeString = "[:"
-                + " targetId=" + virtualEdgeMatch.getTargetId() + "]--";
-            return inputString + nodeString;
-        }
-
-        @Override
-        public String visitFilter(MatchFilter filter) {
-            return visit(filter.getInput()) + " where " + filter.getCondition() + " ";
-        }
-
-        @Override
-        public String visitJoin(MatchJoin join) {
-            return "{" + visit(join.getLeft()) + "} Join {" + visit(join.getRight()) + "}";
-        }
-
-        @Override
-        public String visitDistinct(MatchDistinct distinct) {
-            return visit(distinct.getInput()) + " Distinct" + "{" + visit(distinct.getInput()) + "}";
-        }
-
-        @Override
-        public String visitUnion(MatchUnion union) {
-            return union.getInputs().stream()
-                .map(this::visit)
-                .map(explain -> "{" + explain + "}")
-                .collect(Collectors.joining(" Union "));
-        }
-
-        @Override
-        public String visitLoopMatch(LoopUntilMatch loopMatch) {
-            String inputString = visit(loopMatch.getInput()) + "-";
-            return inputString + " loop(" + visit(loopMatch.getLoopBody()) + ")"
-                + ".time(" + loopMatch.getMinLoopCount() + "," + loopMatch.getMaxLoopCount() + ")"
-                + ".until(" + loopMatch.getUtilCondition() + ")";
-        }
-
-        @Override
-        public String visitSubQueryStart(SubQueryStart subQueryStart) {
-            return "";
-        }
+    @Override
+    public String visitVirtualEdgeMatch(VirtualEdgeMatch virtualEdgeMatch) {
+      String inputString = visit(virtualEdgeMatch.getInput()) + "-";
+      String nodeString = "[:" + " targetId=" + virtualEdgeMatch.getTargetId() + "]--";
+      return inputString + nodeString;
+    }
 
-        @Override
-        public String visitPathModify(MatchPathModify pathModify) {
-            return visit(pathModify.getInput()) + " PathModify(" + pathModify.expressions + ")";
-        }
+    @Override
+    public String visitFilter(MatchFilter filter) {
+      return visit(filter.getInput()) + " where " + filter.getCondition() + " ";
+    }
 
-        @Override
-        public String visitExtend(MatchExtend matchExtend) {
-            return visit(matchExtend.getInput()) + " MatchExtend(" + matchExtend.expressions + ")";
-        }
+    @Override
+    public String visitJoin(MatchJoin join) {
+      return "{" + visit(join.getLeft()) + "} Join {" + visit(join.getRight()) + "}";
+    }
 
-        @Override
-        public String visitSort(MatchPathSort pathSort) {
-            return visit(pathSort.getInput()) + " order by "
-                + StringUtils.join(pathSort.orderByExpressions, ",")
-                + " limit " + pathSort.getLimit();
-        }
+    @Override
+    public String visitDistinct(MatchDistinct distinct) {
+      return visit(distinct.getInput()) + " Distinct" + "{" + visit(distinct.getInput()) + "}";
+    }
 
-        @Override
-        public String visitAggregate(MatchAggregate matchAggregate) {
-            return visit(matchAggregate.getInput()) + " aggregate ";
-        }
+    @Override
+    public String visitUnion(MatchUnion union) {
+      return union.getInputs().stream()
+          .map(this::visit)
+          .map(explain -> "{" + explain + "}")
+          .collect(Collectors.joining(" Union "));
     }
 
-    public IMatchNode getPathPattern() {
-        return pathPattern;
+    @Override
+    public String visitLoopMatch(LoopUntilMatch loopMatch) {
+      String inputString = visit(loopMatch.getInput()) + "-";
+      return inputString
+          + " loop("
+          + visit(loopMatch.getLoopBody())
+          + ")"
+          + ".time("
+          + loopMatch.getMinLoopCount()
+          + ","
+          + loopMatch.getMaxLoopCount()
+          + ")"
+          + ".until("
+          + loopMatch.getUtilCondition()
+          + ")";
     }
 
-    public abstract GraphMatch copy(RelTraitSet traitSet, RelNode input, IMatchNode pathPattern, RelDataType rowType);
+    @Override
+    public String visitSubQueryStart(SubQueryStart subQueryStart) {
+      return "";
+    }
 
-    public GraphMatch copy(IMatchNode pathPattern) {
-        return copy(traitSet, input, pathPattern, pathPattern.getRowType());
+    @Override
+    public String visitPathModify(MatchPathModify pathModify) {
+      return visit(pathModify.getInput()) + " PathModify(" + pathModify.expressions + ")";
     }
 
     @Override
-    public GraphMatch copy(RelTraitSet traitSet, List inputs) {
-        assert inputs.size() == 1;
-        return copy(traitSet, inputs.get(0), pathPattern, rowType);
+    public String visitExtend(MatchExtend matchExtend) {
+      return visit(matchExtend.getInput()) + " MatchExtend(" + matchExtend.expressions + ")";
     }
 
     @Override
-    public RelNode accept(RexShuttle shuttle) {
-        return copy(traitSet, input, (IMatchNode) pathPattern.accept(shuttle), rowType);
+    public String visitSort(MatchPathSort pathSort) {
+      return visit(pathSort.getInput())
+          + " order by "
+          + StringUtils.join(pathSort.orderByExpressions, ",")
+          + " limit "
+          + pathSort.getLimit();
     }
 
     @Override
-    public RelNode accept(RelShuttle shuttle) {
-        return copy(traitSet, input, (IMatchNode) pathPattern.accept(shuttle), rowType);
+    public String visitAggregate(MatchAggregate matchAggregate) {
+      return visit(matchAggregate.getInput()) + " aggregate ";
     }
+  }
+
+  public IMatchNode getPathPattern() {
+    return pathPattern;
+  }
+
+  public abstract GraphMatch copy(
+      RelTraitSet traitSet, RelNode input, IMatchNode pathPattern, RelDataType rowType);
+
+  public GraphMatch copy(IMatchNode pathPattern) {
+    return copy(traitSet, input, pathPattern, pathPattern.getRowType());
+  }
+
+  @Override
+  public GraphMatch copy(RelTraitSet traitSet, List inputs) {
+    assert inputs.size() == 1;
+    return copy(traitSet, inputs.get(0), pathPattern, rowType);
+  }
+
+  @Override
+  public RelNode accept(RexShuttle shuttle) {
+    return copy(traitSet, input, (IMatchNode) pathPattern.accept(shuttle), rowType);
+  }
+
+  @Override
+  public RelNode accept(RelShuttle shuttle) {
+    return copy(traitSet, input, (IMatchNode) pathPattern.accept(shuttle), rowType);
+  }
+
+  /**
+   * Returns a string representation of filtered fields in dictionary order for both graph nodes and
+   * fields within each node.
+   */
+  public String getFilteredFields() {
+    Map> nodeFieldsMap = new TreeMap<>();
+
+    Queue nodeQueue = new LinkedList<>();
+    Set visitedNodes = new HashSet<>();
+
+    nodeQueue.offer(this.pathPattern);
+    visitedNodes.add(this.pathPattern);
+
+    while (!nodeQueue.isEmpty()) {
+      IMatchNode currentNode = nodeQueue.poll();
+      String nodeLabel = null;
+      Set nodeFields = null;
+
+      if (currentNode instanceof VertexMatch) {
+        VertexMatch vertexMatch = (VertexMatch) currentNode;
+        nodeLabel = vertexMatch.getLabel();
+        nodeFields = vertexMatch.getFields();
+      } else if (currentNode instanceof EdgeMatch) {
+        EdgeMatch edgeMatch = (EdgeMatch) currentNode;
+        nodeLabel = edgeMatch.getLabel();
+        nodeFields = edgeMatch.getFields();
+      }
+
+      if (nodeLabel != null) {
+        Set fields = nodeFieldsMap.computeIfAbsent(nodeLabel, k -> new TreeSet<>());
+
+        if (nodeFields != null && !nodeFields.isEmpty()) {
+          for (RexFieldAccess field : nodeFields) {
+            fields.add(nodeLabel + "." + field.getField().getName());
+          }
+        } else {
+          fields.add("null");
+        }
+      }
 
-    /**
-     * Returns a string representation of filtered fields in dictionary order
-     * for both graph nodes and fields within each node.
-     */
-    public String getFilteredFields() {
-        Map> nodeFieldsMap = new TreeMap<>();
-
-        Queue nodeQueue = new LinkedList<>();
-        Set visitedNodes = new HashSet<>();
-
-        nodeQueue.offer(this.pathPattern);
-        visitedNodes.add(this.pathPattern);
-
-        while (!nodeQueue.isEmpty()) {
-            IMatchNode currentNode = nodeQueue.poll();
-            String nodeLabel = null;
-            Set nodeFields = null;
-
-            if (currentNode instanceof VertexMatch) {
-                VertexMatch vertexMatch = (VertexMatch) currentNode;
-                nodeLabel = vertexMatch.getLabel();
-                nodeFields = vertexMatch.getFields();
-            } else if (currentNode instanceof EdgeMatch) {
-                EdgeMatch edgeMatch = (EdgeMatch) currentNode;
-                nodeLabel = edgeMatch.getLabel();
-                nodeFields = edgeMatch.getFields();
-            }
-
-            if (nodeLabel != null) {
-                Set fields = nodeFieldsMap.computeIfAbsent(nodeLabel, k -> new TreeSet<>());
-
-                if (nodeFields != null && !nodeFields.isEmpty()) {
-                    for (RexFieldAccess field : nodeFields) {
-                        fields.add(nodeLabel + "." + field.getField().getName());
-                    }
-                } else {
-                    fields.add("null");
-                }
-            }
-
-            for (RelNode inputNode : currentNode.getInputs()) {
-                if (inputNode != null && !visitedNodes.contains((IMatchNode) inputNode)) {
-                    nodeQueue.offer((IMatchNode) inputNode);
-                    visitedNodes.add((IMatchNode) inputNode);
-                }
-            }
+      for (RelNode inputNode : currentNode.getInputs()) {
+        if (inputNode != null && !visitedNodes.contains((IMatchNode) inputNode)) {
+          nodeQueue.offer((IMatchNode) inputNode);
+          visitedNodes.add((IMatchNode) inputNode);
         }
-        return nodeFieldsMap.toString();
+      }
     }
-
+    return nodeFieldsMap.toString();
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/GraphModify.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/GraphModify.java
index b0518182f..3e533133e 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/GraphModify.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/GraphModify.java
@@ -21,6 +21,7 @@
 
 import java.util.List;
 import java.util.Objects;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
@@ -31,53 +32,58 @@
 import org.apache.geaflow.dsl.schema.GeaFlowGraph;
 
 public abstract class GraphModify extends SingleRel {
-    /**
-     * The graph to write.
-     */
-    protected final GeaFlowGraph graph;
+  /** The graph to write. */
+  protected final GeaFlowGraph graph;
 
-    protected GraphModify(RelOptCluster cluster, RelTraitSet traitSet, GeaFlowGraph graph, RelNode input) {
-        super(cluster, traitSet, input);
-        this.graph = Objects.requireNonNull(graph);
-        validateInput(input);
-    }
+  protected GraphModify(
+      RelOptCluster cluster, RelTraitSet traitSet, GeaFlowGraph graph, RelNode input) {
+    super(cluster, traitSet, input);
+    this.graph = Objects.requireNonNull(graph);
+    validateInput(input);
+  }
 
-    private void validateInput(RelNode input) {
-        RelDataType inputType = input.getRowType();
-        RelDataType graphType = getRowType();
-        if (inputType.getFieldCount() != graphType.getFieldCount()) {
-            throw new GeaFlowDSLException("Input type field size: " + inputType.getFieldCount()
-                + " is not equal to the target graph field size: " + graphType.getFieldCount());
-        }
-        for (int i = 0; i < inputType.getFieldCount(); i++) {
-            RelDataType inputFieldType = inputType.getFieldList().get(i).getType();
-            RelDataType targetFieldType = graphType.getFieldList().get(i).getType();
-            if (!inputFieldType.equals(targetFieldType)) {
-                throw new GeaFlowDSLException("Input field type: " + inputFieldType
-                    + " is mismatch with the target graph field type: " + targetFieldType);
-            }
-        }
+  private void validateInput(RelNode input) {
+    RelDataType inputType = input.getRowType();
+    RelDataType graphType = getRowType();
+    if (inputType.getFieldCount() != graphType.getFieldCount()) {
+      throw new GeaFlowDSLException(
+          "Input type field size: "
+              + inputType.getFieldCount()
+              + " is not equal to the target graph field size: "
+              + graphType.getFieldCount());
     }
-
-    @Override
-    public RelWriter explainTerms(RelWriter pw) {
-        return super.explainTerms(pw).item("table", graph.getName());
+    for (int i = 0; i < inputType.getFieldCount(); i++) {
+      RelDataType inputFieldType = inputType.getFieldList().get(i).getType();
+      RelDataType targetFieldType = graphType.getFieldList().get(i).getType();
+      if (!inputFieldType.equals(targetFieldType)) {
+        throw new GeaFlowDSLException(
+            "Input field type: "
+                + inputFieldType
+                + " is mismatch with the target graph field type: "
+                + targetFieldType);
+      }
     }
+  }
 
-    @Override
-    protected RelDataType deriveRowType() {
-        return graph.getRowType(getCluster().getTypeFactory());
-    }
+  @Override
+  public RelWriter explainTerms(RelWriter pw) {
+    return super.explainTerms(pw).item("table", graph.getName());
+  }
 
-    public abstract GraphModify copy(RelTraitSet traitSet, GeaFlowGraph graph, RelNode input);
+  @Override
+  protected RelDataType deriveRowType() {
+    return graph.getRowType(getCluster().getTypeFactory());
+  }
 
-    @Override
-    public GraphModify copy(RelTraitSet traitSet, List inputs) {
-        assert inputs.size() == 1;
-        return copy(traitSet, graph, inputs.get(0));
-    }
+  public abstract GraphModify copy(RelTraitSet traitSet, GeaFlowGraph graph, RelNode input);
 
-    public GeaFlowGraph getGraph() {
-        return graph;
-    }
+  @Override
+  public GraphModify copy(RelTraitSet traitSet, List inputs) {
+    assert inputs.size() == 1;
+    return copy(traitSet, graph, inputs.get(0));
+  }
+
+  public GeaFlowGraph getGraph() {
+    return graph;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/GraphScan.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/GraphScan.java
index ae0191d77..f71e220d6 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/GraphScan.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/GraphScan.java
@@ -21,6 +21,7 @@
 
 import java.util.List;
 import java.util.Objects;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelOptTable;
 import org.apache.calcite.plan.RelTraitSet;
@@ -32,33 +33,32 @@
 
 public abstract class GraphScan extends AbstractRelNode {
 
-    protected final RelOptTable table;
+  protected final RelOptTable table;
 
-    protected GraphScan(RelOptCluster cluster, RelTraitSet traitSet, RelOptTable table) {
-        super(cluster, traitSet);
-        this.table = Objects.requireNonNull(table);
-    }
+  protected GraphScan(RelOptCluster cluster, RelTraitSet traitSet, RelOptTable table) {
+    super(cluster, traitSet);
+    this.table = Objects.requireNonNull(table);
+  }
 
-    @Override
-    public RelOptTable getTable() {
-        return table;
-    }
+  @Override
+  public RelOptTable getTable() {
+    return table;
+  }
 
-    @Override
-    protected RelDataType deriveRowType() {
-        return table.getRowType();
-    }
+  @Override
+  protected RelDataType deriveRowType() {
+    return table.getRowType();
+  }
 
-    @Override
-    public RelWriter explainTerms(RelWriter pw) {
-        return super.explainTerms(pw)
-            .item("table", StringUtils.join(table.getQualifiedName(), "."));
-    }
+  @Override
+  public RelWriter explainTerms(RelWriter pw) {
+    return super.explainTerms(pw).item("table", StringUtils.join(table.getQualifiedName(), "."));
+  }
 
-    public abstract GraphScan copy(RelTraitSet traitSet, RelOptTable table);
+  public abstract GraphScan copy(RelTraitSet traitSet, RelOptTable table);
 
-    public GraphScan copy(RelTraitSet traitSet, List inputs) {
-        assert inputs.isEmpty();
-        return copy(traitSet, table);
-    }
+  public GraphScan copy(RelTraitSet traitSet, List inputs) {
+    assert inputs.isEmpty();
+    return copy(traitSet, table);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/MatchNodeVisitor.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/MatchNodeVisitor.java
index 627b8b346..6cf1ab67f 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/MatchNodeVisitor.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/MatchNodeVisitor.java
@@ -36,31 +36,31 @@
 
 public interface MatchNodeVisitor {
 
-    T visitVertexMatch(VertexMatch vertexMatch);
+  T visitVertexMatch(VertexMatch vertexMatch);
 
-    T visitEdgeMatch(EdgeMatch edgeMatch);
+  T visitEdgeMatch(EdgeMatch edgeMatch);
 
-    T visitVirtualEdgeMatch(VirtualEdgeMatch virtualEdgeMatch);
+  T visitVirtualEdgeMatch(VirtualEdgeMatch virtualEdgeMatch);
 
-    T visitFilter(MatchFilter filter);
+  T visitFilter(MatchFilter filter);
 
-    T visitJoin(MatchJoin join);
+  T visitJoin(MatchJoin join);
 
-    T visitDistinct(MatchDistinct distinct);
+  T visitDistinct(MatchDistinct distinct);
 
-    T visitUnion(MatchUnion union);
+  T visitUnion(MatchUnion union);
 
-    T visitLoopMatch(LoopUntilMatch loopMatch);
+  T visitLoopMatch(LoopUntilMatch loopMatch);
 
-    T visitSubQueryStart(SubQueryStart subQueryStart);
+  T visitSubQueryStart(SubQueryStart subQueryStart);
 
-    T visitPathModify(MatchPathModify pathModify);
+  T visitPathModify(MatchPathModify pathModify);
 
-    T visitExtend(MatchExtend matchExtend);
+  T visitExtend(MatchExtend matchExtend);
 
-    T visitSort(MatchPathSort pathSort);
+  T visitSort(MatchPathSort pathSort);
 
-    T visitAggregate(MatchAggregate matchAggregate);
+  T visitAggregate(MatchAggregate matchAggregate);
 
-    T visit(RelNode node);
+  T visit(RelNode node);
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/MatchRelShuffle.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/MatchRelShuffle.java
index e9199dec7..225585e44 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/MatchRelShuffle.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/MatchRelShuffle.java
@@ -21,6 +21,7 @@
 
 import java.util.List;
 import java.util.stream.Collectors;
+
 import org.apache.calcite.rel.RelNode;
 import org.apache.geaflow.dsl.rel.match.EdgeMatch;
 import org.apache.geaflow.dsl.rel.match.IMatchNode;
@@ -39,74 +40,74 @@
 
 public class MatchRelShuffle extends AbstractMatchNodeVisitor {
 
-    @Override
-    public IMatchNode visitVertexMatch(VertexMatch vertexMatch) {
-        return visitChildren(vertexMatch);
-    }
-
-    @Override
-    public IMatchNode visitEdgeMatch(EdgeMatch edgeMatch) {
-        return visitChildren(edgeMatch);
-    }
-
-    @Override
-    public IMatchNode visitVirtualEdgeMatch(VirtualEdgeMatch virtualEdgeMatch) {
-        return visitChildren(virtualEdgeMatch);
-    }
-
-    @Override
-    public IMatchNode visitFilter(MatchFilter filter) {
-        return visitChildren(filter);
-    }
-
-    @Override
-    public IMatchNode visitJoin(MatchJoin join) {
-        return visitChildren(join);
-    }
-
-    @Override
-    public IMatchNode visitDistinct(MatchDistinct distinct) {
-        return visitChildren(distinct);
-    }
-
-    @Override
-    public IMatchNode visitUnion(MatchUnion union) {
-        return visitChildren(union);
-    }
-
-    @Override
-    public IMatchNode visitLoopMatch(LoopUntilMatch loopMatch) {
-        return visitChildren(loopMatch);
-    }
-
-    @Override
-    public IMatchNode visitSubQueryStart(SubQueryStart subQueryStart) {
-        return visitChildren(subQueryStart);
-    }
-
-    @Override
-    public IMatchNode visitPathModify(MatchPathModify pathModify) {
-        return visitChildren(pathModify);
-    }
-
-    @Override
-    public IMatchNode visitExtend(MatchExtend matchExtend) {
-        return visitChildren(matchExtend);
-    }
-
-    @Override
-    public IMatchNode visitSort(MatchPathSort pathSort) {
-        return visitChildren(pathSort);
-    }
-
-    @Override
-    public IMatchNode visitAggregate(MatchAggregate matchAggregate) {
-        return visitChildren(matchAggregate);
-    }
-
-    protected IMatchNode visitChildren(IMatchNode parent) {
-        List newInputs = parent.getInputs().stream()
-            .map(this::visit).collect(Collectors.toList());
-        return (IMatchNode) parent.copy(parent.getTraitSet(), newInputs);
-    }
+  @Override
+  public IMatchNode visitVertexMatch(VertexMatch vertexMatch) {
+    return visitChildren(vertexMatch);
+  }
+
+  @Override
+  public IMatchNode visitEdgeMatch(EdgeMatch edgeMatch) {
+    return visitChildren(edgeMatch);
+  }
+
+  @Override
+  public IMatchNode visitVirtualEdgeMatch(VirtualEdgeMatch virtualEdgeMatch) {
+    return visitChildren(virtualEdgeMatch);
+  }
+
+  @Override
+  public IMatchNode visitFilter(MatchFilter filter) {
+    return visitChildren(filter);
+  }
+
+  @Override
+  public IMatchNode visitJoin(MatchJoin join) {
+    return visitChildren(join);
+  }
+
+  @Override
+  public IMatchNode visitDistinct(MatchDistinct distinct) {
+    return visitChildren(distinct);
+  }
+
+  @Override
+  public IMatchNode visitUnion(MatchUnion union) {
+    return visitChildren(union);
+  }
+
+  @Override
+  public IMatchNode visitLoopMatch(LoopUntilMatch loopMatch) {
+    return visitChildren(loopMatch);
+  }
+
+  @Override
+  public IMatchNode visitSubQueryStart(SubQueryStart subQueryStart) {
+    return visitChildren(subQueryStart);
+  }
+
+  @Override
+  public IMatchNode visitPathModify(MatchPathModify pathModify) {
+    return visitChildren(pathModify);
+  }
+
+  @Override
+  public IMatchNode visitExtend(MatchExtend matchExtend) {
+    return visitChildren(matchExtend);
+  }
+
+  @Override
+  public IMatchNode visitSort(MatchPathSort pathSort) {
+    return visitChildren(pathSort);
+  }
+
+  @Override
+  public IMatchNode visitAggregate(MatchAggregate matchAggregate) {
+    return visitChildren(matchAggregate);
+  }
+
+  protected IMatchNode visitChildren(IMatchNode parent) {
+    List newInputs =
+        parent.getInputs().stream().map(this::visit).collect(Collectors.toList());
+    return (IMatchNode) parent.copy(parent.getTraitSet(), newInputs);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/ParameterizedRelNode.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/ParameterizedRelNode.java
index 9c2dbedcd..02b30bc7b 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/ParameterizedRelNode.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/ParameterizedRelNode.java
@@ -27,23 +27,21 @@
 
 public abstract class ParameterizedRelNode extends BiRel {
 
-    public ParameterizedRelNode(RelOptCluster cluster, RelTraitSet traitSet,
-                                RelNode parameter, RelNode query) {
-        super(cluster, traitSet, parameter, query);
-    }
-
-    public RelNode getParameterNode() {
-        return getLeft();
-    }
-
-    public RelNode getQueryNode() {
-        return getRight();
-    }
-
-    @Override
-    protected RelDataType deriveRowType() {
-        return getQueryNode().getRowType();
-    }
-
-
+  public ParameterizedRelNode(
+      RelOptCluster cluster, RelTraitSet traitSet, RelNode parameter, RelNode query) {
+    super(cluster, traitSet, parameter, query);
+  }
+
+  public RelNode getParameterNode() {
+    return getLeft();
+  }
+
+  public RelNode getQueryNode() {
+    return getRight();
+  }
+
+  @Override
+  protected RelDataType deriveRowType() {
+    return getQueryNode().getRowType();
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/PathModify.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/PathModify.java
index 79c77b8bb..481d5cfee 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/PathModify.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/PathModify.java
@@ -19,10 +19,10 @@
 
 package org.apache.geaflow.dsl.rel;
 
-import com.google.common.collect.ImmutableList;
 import java.util.List;
 import java.util.Objects;
 import java.util.stream.Collectors;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
@@ -34,110 +34,122 @@
 import org.apache.geaflow.dsl.rex.PathInputRef;
 import org.apache.geaflow.dsl.rex.RexObjectConstruct;
 
-public abstract class PathModify extends SingleRel {
-
-    protected final ImmutableList expressions;
+import com.google.common.collect.ImmutableList;
 
-    protected final GraphRecordType modifyGraphType;
+public abstract class PathModify extends SingleRel {
 
-    protected PathModify(RelOptCluster cluster, RelTraitSet traits,
-                         RelNode input, List expressions,
-                         RelDataType rowType, GraphRecordType modifyGraphType) {
-        super(cluster, traits, input);
-        this.expressions = ImmutableList.copyOf(expressions);
-        this.rowType = Objects.requireNonNull(rowType);
-        this.modifyGraphType = Objects.requireNonNull(modifyGraphType);
+  protected final ImmutableList expressions;
+
+  protected final GraphRecordType modifyGraphType;
+
+  protected PathModify(
+      RelOptCluster cluster,
+      RelTraitSet traits,
+      RelNode input,
+      List expressions,
+      RelDataType rowType,
+      GraphRecordType modifyGraphType) {
+    super(cluster, traits, input);
+    this.expressions = ImmutableList.copyOf(expressions);
+    this.rowType = Objects.requireNonNull(rowType);
+    this.modifyGraphType = Objects.requireNonNull(modifyGraphType);
+  }
+
+  @Override
+  public RelWriter explainTerms(RelWriter pw) {
+    return super.explainTerms(pw).item("expressions", expressions);
+  }
+
+  @Override
+  public PathModify copy(RelTraitSet traitSet, List inputs) {
+    assert inputs.size() == 1;
+    return copy(traitSet, inputs.get(0), expressions, rowType);
+  }
+
+  @Override
+  public RelNode accept(RexShuttle shuttle) {
+    List rewriteExpressions =
+        expressions.stream()
+            .map(
+                exp -> {
+                  PathInputRef rewriteLeftVar = (PathInputRef) exp.leftVar.accept(shuttle);
+                  RexObjectConstruct rewriteNode =
+                      (RexObjectConstruct) exp.getObjectConstruct().accept(shuttle);
+                  return new PathModifyExpression(rewriteLeftVar, rewriteNode);
+                })
+            .collect(Collectors.toList());
+    return copy(traitSet, input, rewriteExpressions, rowType);
+  }
+
+  public abstract PathModify copy(
+      RelTraitSet traitSet,
+      RelNode input,
+      List expressions,
+      RelDataType rowType);
+
+  public PathModify copy(RelDataType rowType) {
+    return copy(traitSet, input, expressions, rowType);
+  }
+
+  public ImmutableList getExpressions() {
+    return expressions;
+  }
+
+  public GraphRecordType getModifyGraphType() {
+    return modifyGraphType;
+  }
+
+  public static class PathModifyExpression {
+
+    private final PathInputRef leftVar;
+
+    private final RexObjectConstruct expression;
+
+    public PathModifyExpression(PathInputRef leftVar, RexObjectConstruct expression) {
+      this.leftVar = leftVar;
+      this.expression = expression;
     }
 
-    @Override
-    public RelWriter explainTerms(RelWriter pw) {
-        return super.explainTerms(pw).item("expressions", expressions);
+    public String getPathFieldName() {
+      return leftVar.getLabel();
     }
 
-    @Override
-    public PathModify copy(RelTraitSet traitSet, List inputs) {
-        assert inputs.size() == 1;
-        return copy(traitSet, inputs.get(0), expressions, rowType);
+    public int getIndex() {
+      return leftVar.getIndex();
     }
 
-    @Override
-    public RelNode accept(RexShuttle shuttle) {
-        List rewriteExpressions =
-            expressions.stream().map(exp -> {
-                PathInputRef rewriteLeftVar = (PathInputRef) exp.leftVar.accept(shuttle);
-                RexObjectConstruct rewriteNode = (RexObjectConstruct) exp.getObjectConstruct().accept(shuttle);
-                return new PathModifyExpression(rewriteLeftVar, rewriteNode);
-            }).collect(Collectors.toList());
-        return copy(traitSet, input, rewriteExpressions, rowType);
+    public PathInputRef getLeftVar() {
+      return leftVar;
     }
 
+    public RexObjectConstruct getObjectConstruct() {
+      return expression;
+    }
 
-    public abstract PathModify copy(RelTraitSet traitSet, RelNode input, List expressions,
-                                    RelDataType rowType);
-
-    public PathModify copy(RelDataType rowType) {
-        return copy(traitSet, input, expressions, rowType);
+    public PathModifyExpression copy(RexObjectConstruct objectConstruct) {
+      return new PathModifyExpression(leftVar, objectConstruct);
     }
 
-    public ImmutableList getExpressions() {
-        return expressions;
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (!(o instanceof PathModifyExpression)) {
+        return false;
+      }
+      PathModifyExpression that = (PathModifyExpression) o;
+      return Objects.equals(leftVar, that.leftVar) && Objects.equals(expression, that.expression);
     }
 
-    public GraphRecordType getModifyGraphType() {
-        return modifyGraphType;
+    @Override
+    public int hashCode() {
+      return Objects.hash(leftVar, expression);
     }
 
-    public static class PathModifyExpression {
-
-        private final PathInputRef leftVar;
-
-        private final RexObjectConstruct expression;
-
-        public PathModifyExpression(PathInputRef leftVar, RexObjectConstruct expression) {
-            this.leftVar = leftVar;
-            this.expression = expression;
-        }
-
-        public String getPathFieldName() {
-            return leftVar.getLabel();
-        }
-
-        public int getIndex() {
-            return leftVar.getIndex();
-        }
-
-        public PathInputRef getLeftVar() {
-            return leftVar;
-        }
-
-        public RexObjectConstruct getObjectConstruct() {
-            return expression;
-        }
-
-        public PathModifyExpression copy(RexObjectConstruct objectConstruct) {
-            return new PathModifyExpression(leftVar, objectConstruct);
-        }
-
-        @Override
-        public boolean equals(Object o) {
-            if (this == o) {
-                return true;
-            }
-            if (!(o instanceof PathModifyExpression)) {
-                return false;
-            }
-            PathModifyExpression that = (PathModifyExpression) o;
-            return Objects.equals(leftVar, that.leftVar) && Objects.equals(expression, that.expression);
-        }
-
-        @Override
-        public int hashCode() {
-            return Objects.hash(leftVar, expression);
-        }
-
-        @Override
-        public String toString() {
-            return leftVar.getLabel() + "=" + expression;
-        }
+    @Override
+    public String toString() {
+      return leftVar.getLabel() + "=" + expression;
     }
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/PathSort.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/PathSort.java
index 4cd97c09a..891d475e1 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/PathSort.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/PathSort.java
@@ -22,6 +22,7 @@
 import java.util.List;
 import java.util.Objects;
 import java.util.stream.Collectors;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
@@ -33,55 +34,59 @@
 
 public abstract class PathSort extends SingleRel {
 
-    protected final List orderByExpressions;
-    protected final RexNode limit;
+  protected final List orderByExpressions;
+  protected final RexNode limit;
 
-    protected PathSort(RelOptCluster cluster, RelTraitSet traits, RelNode input,
-                       List orderByExpressions, RexNode limit,
-                       PathRecordType pathType) {
-        super(cluster, traits, input);
-        this.rowType = Objects.requireNonNull(pathType);
-        this.orderByExpressions = Objects.requireNonNull(orderByExpressions);
-        this.limit = limit;
-    }
+  protected PathSort(
+      RelOptCluster cluster,
+      RelTraitSet traits,
+      RelNode input,
+      List orderByExpressions,
+      RexNode limit,
+      PathRecordType pathType) {
+    super(cluster, traits, input);
+    this.rowType = Objects.requireNonNull(pathType);
+    this.orderByExpressions = Objects.requireNonNull(orderByExpressions);
+    this.limit = limit;
+  }
 
-    public List getOrderByExpressions() {
-        return orderByExpressions;
-    }
+  public List getOrderByExpressions() {
+    return orderByExpressions;
+  }
 
-    public RexNode getLimit() {
-        return limit;
-    }
+  public RexNode getLimit() {
+    return limit;
+  }
 
-    public abstract PathSort copy(RelNode input, List orderByExpressions,
-                                  RexNode limit, PathRecordType pathType);
+  public abstract PathSort copy(
+      RelNode input, List orderByExpressions, RexNode limit, PathRecordType pathType);
 
-    @Override
-    public RelNode copy(RelTraitSet traitSet, List inputs) {
-        assert inputs != null && inputs.size() == 1 : "Invalid inputs size";
-        return copy(sole(inputs), orderByExpressions, limit, (PathRecordType) rowType);
-    }
+  @Override
+  public RelNode copy(RelTraitSet traitSet, List inputs) {
+    assert inputs != null && inputs.size() == 1 : "Invalid inputs size";
+    return copy(sole(inputs), orderByExpressions, limit, (PathRecordType) rowType);
+  }
 
-    @Override
-    public RelNode accept(RexShuttle shuttle) {
-        List newOrders = orderByExpressions.stream()
-            .map(shuttle::apply).collect(Collectors.toList());
-        RexNode newLimit = null;
-        if (limit != null) {
-            newLimit = shuttle.apply(limit);
-        }
-        return copy(input, newOrders, newLimit, (PathRecordType) getRowType());
+  @Override
+  public RelNode accept(RexShuttle shuttle) {
+    List newOrders =
+        orderByExpressions.stream().map(shuttle::apply).collect(Collectors.toList());
+    RexNode newLimit = null;
+    if (limit != null) {
+      newLimit = shuttle.apply(limit);
     }
+    return copy(input, newOrders, newLimit, (PathRecordType) getRowType());
+  }
 
-    @Override
-    public RelWriter explainTerms(RelWriter pw) {
-        RelWriter writer = super.explainTerms(pw);
-        if (!orderByExpressions.isEmpty()) {
-            writer.item("order by", orderByExpressions);
-        }
-        if (limit != null) {
-            writer.item("limit", limit);
-        }
-        return writer;
+  @Override
+  public RelWriter explainTerms(RelWriter pw) {
+    RelWriter writer = super.explainTerms(pw);
+    if (!orderByExpressions.isEmpty()) {
+      writer.item("order by", orderByExpressions);
+    }
+    if (limit != null) {
+      writer.item("limit", limit);
     }
+    return writer;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalConstructGraph.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalConstructGraph.java
index ddda01b2c..9ca7f044c 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalConstructGraph.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalConstructGraph.java
@@ -20,6 +20,7 @@
 package org.apache.geaflow.dsl.rel.logical;
 
 import java.util.List;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
@@ -28,19 +29,24 @@
 
 public class LogicalConstructGraph extends ConstructGraph {
 
-    protected LogicalConstructGraph(RelOptCluster cluster, RelTraitSet traits,
-                                    RelNode input, List labelNames, RelDataType rowType) {
-        super(cluster, traits, input, labelNames, rowType);
-    }
+  protected LogicalConstructGraph(
+      RelOptCluster cluster,
+      RelTraitSet traits,
+      RelNode input,
+      List labelNames,
+      RelDataType rowType) {
+    super(cluster, traits, input, labelNames, rowType);
+  }
 
-    @Override
-    public LogicalConstructGraph copy(RelTraitSet traitSet, List inputs) {
-        assert inputs.size() == 1;
-        return new LogicalConstructGraph(getCluster(), getTraitSet(), inputs.get(0), labelNames, rowType);
-    }
+  @Override
+  public LogicalConstructGraph copy(RelTraitSet traitSet, List inputs) {
+    assert inputs.size() == 1;
+    return new LogicalConstructGraph(
+        getCluster(), getTraitSet(), inputs.get(0), labelNames, rowType);
+  }
 
-    public static LogicalConstructGraph create(RelOptCluster cluster, RelNode input,
-                                               List labelNames, RelDataType rowType) {
-        return new LogicalConstructGraph(cluster, cluster.traitSet(), input, labelNames, rowType);
-    }
+  public static LogicalConstructGraph create(
+      RelOptCluster cluster, RelNode input, List labelNames, RelDataType rowType) {
+    return new LogicalConstructGraph(cluster, cluster.traitSet(), input, labelNames, rowType);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalGraphAlgorithm.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalGraphAlgorithm.java
index 76a353365..d09d2de1a 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalGraphAlgorithm.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalGraphAlgorithm.java
@@ -20,6 +20,7 @@
 package org.apache.geaflow.dsl.rel.logical;
 
 import java.util.List;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
@@ -28,33 +29,38 @@
 
 public class LogicalGraphAlgorithm extends GraphAlgorithm {
 
-    protected LogicalGraphAlgorithm(RelOptCluster cluster,
-                                    RelTraitSet traits,
-                                    RelNode input,
-                                    Class userFunctionClass,
-                                    Object[] params) {
-        super(cluster, traits, input, userFunctionClass, params);
-    }
-
-    @Override
-    public GraphAlgorithm copy(RelTraitSet traitSet, RelNode input,
-                               Class userFunctionClass,
-                               Object[] params) {
-        return create(input.getCluster(), traitSet, input, userFunctionClass, params);
-    }
-
-    @Override
-    public GraphAlgorithm copy(RelTraitSet traitSet, List inputs) {
-        assert inputs.size() == 1;
-        RelNode input = inputs.get(0);
-        return new LogicalGraphAlgorithm(input.getCluster(), traitSet, input, userFunctionClass, params);
-    }
-
-    public static LogicalGraphAlgorithm create(RelOptCluster cluster,
-                                               RelTraitSet traits,
-                                               RelNode input,
-                                               Class userFunctionClass,
-                                               Object[] params) {
-        return new LogicalGraphAlgorithm(cluster, traits, input, userFunctionClass, params);
-    }
+  protected LogicalGraphAlgorithm(
+      RelOptCluster cluster,
+      RelTraitSet traits,
+      RelNode input,
+      Class userFunctionClass,
+      Object[] params) {
+    super(cluster, traits, input, userFunctionClass, params);
+  }
+
+  @Override
+  public GraphAlgorithm copy(
+      RelTraitSet traitSet,
+      RelNode input,
+      Class userFunctionClass,
+      Object[] params) {
+    return create(input.getCluster(), traitSet, input, userFunctionClass, params);
+  }
+
+  @Override
+  public GraphAlgorithm copy(RelTraitSet traitSet, List inputs) {
+    assert inputs.size() == 1;
+    RelNode input = inputs.get(0);
+    return new LogicalGraphAlgorithm(
+        input.getCluster(), traitSet, input, userFunctionClass, params);
+  }
+
+  public static LogicalGraphAlgorithm create(
+      RelOptCluster cluster,
+      RelTraitSet traits,
+      RelNode input,
+      Class userFunctionClass,
+      Object[] params) {
+    return new LogicalGraphAlgorithm(cluster, traits, input, userFunctionClass, params);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalGraphMatch.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalGraphMatch.java
index c26cb0a4f..582e0787c 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalGraphMatch.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalGraphMatch.java
@@ -28,18 +28,23 @@
 
 public class LogicalGraphMatch extends GraphMatch {
 
-    protected LogicalGraphMatch(RelOptCluster cluster, RelTraitSet traits,
-                                RelNode input, IMatchNode pathPattern, RelDataType rowType) {
-        super(cluster, traits, input, pathPattern, rowType);
-    }
+  protected LogicalGraphMatch(
+      RelOptCluster cluster,
+      RelTraitSet traits,
+      RelNode input,
+      IMatchNode pathPattern,
+      RelDataType rowType) {
+    super(cluster, traits, input, pathPattern, rowType);
+  }
 
-    @Override
-    public LogicalGraphMatch copy(RelTraitSet traitSet, RelNode input, IMatchNode pathPattern, RelDataType rowType) {
-        return new LogicalGraphMatch(getCluster(), traitSet, input, pathPattern, rowType);
-    }
+  @Override
+  public LogicalGraphMatch copy(
+      RelTraitSet traitSet, RelNode input, IMatchNode pathPattern, RelDataType rowType) {
+    return new LogicalGraphMatch(getCluster(), traitSet, input, pathPattern, rowType);
+  }
 
-    public static LogicalGraphMatch create(RelOptCluster cluster, RelNode input,
-                                           IMatchNode pathPattern, RelDataType rowType) {
-        return new LogicalGraphMatch(cluster, cluster.traitSet(), input, pathPattern, rowType);
-    }
+  public static LogicalGraphMatch create(
+      RelOptCluster cluster, RelNode input, IMatchNode pathPattern, RelDataType rowType) {
+    return new LogicalGraphMatch(cluster, cluster.traitSet(), input, pathPattern, rowType);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalGraphModify.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalGraphModify.java
index ffe5a5f92..68cb96482 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalGraphModify.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalGraphModify.java
@@ -27,17 +27,18 @@
 
 public class LogicalGraphModify extends GraphModify {
 
-    protected LogicalGraphModify(RelOptCluster cluster, RelTraitSet traitSet,
-                                 GeaFlowGraph graph, RelNode input) {
-        super(cluster, traitSet, graph, input);
-    }
+  protected LogicalGraphModify(
+      RelOptCluster cluster, RelTraitSet traitSet, GeaFlowGraph graph, RelNode input) {
+    super(cluster, traitSet, graph, input);
+  }
 
-    @Override
-    public GraphModify copy(RelTraitSet traitSet, GeaFlowGraph graph, RelNode input) {
-        return new LogicalGraphModify(getCluster(), traitSet, graph, input);
-    }
+  @Override
+  public GraphModify copy(RelTraitSet traitSet, GeaFlowGraph graph, RelNode input) {
+    return new LogicalGraphModify(getCluster(), traitSet, graph, input);
+  }
 
-    public static LogicalGraphModify create(RelOptCluster cluster, GeaFlowGraph graph, RelNode input) {
-        return new LogicalGraphModify(cluster, cluster.traitSet(), graph, input);
-    }
+  public static LogicalGraphModify create(
+      RelOptCluster cluster, GeaFlowGraph graph, RelNode input) {
+    return new LogicalGraphModify(cluster, cluster.traitSet(), graph, input);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalGraphScan.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalGraphScan.java
index 1a82c882e..80574ed6f 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalGraphScan.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalGraphScan.java
@@ -20,6 +20,7 @@
 package org.apache.geaflow.dsl.rel.logical;
 
 import java.util.List;
+
 import org.apache.calcite.linq4j.tree.Expression;
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelOptSchema;
@@ -40,164 +41,162 @@
 
 public class LogicalGraphScan extends GraphScan {
 
-    protected LogicalGraphScan(RelOptCluster cluster, RelTraitSet traitSet,
-                               RelOptTable table) {
-        super(cluster, traitSet, table);
-    }
-
-    @Override
-    public GraphScan copy(RelTraitSet traitSet, RelOptTable table) {
-        return new LogicalGraphScan(getCluster(), traitSet, table);
-    }
-
-    public static LogicalGraphScan create(RelOptCluster cluster, RelOptTable relOptTable) {
-        return new LogicalGraphScan(cluster, cluster.traitSet(), relOptTable);
-    }
-
-    public static LogicalGraphScan create(RelOptCluster cluster, GeaFlowGraph graph) {
-        RelOptTable table = new RelOptTable() {
-
-            @Override
-            public  C unwrap(Class aClass) {
-                return (C) graph;
-            }
-
-            @Override
-            public List getQualifiedName() {
-                return null;
-            }
-
-            @Override
-            public double getRowCount() {
-                return 0;
-            }
-
-            @Override
-            public RelDataType getRowType() {
-                return graph.getRowType(GQLJavaTypeFactory.create());
-            }
-
-            @Override
-            public RelOptSchema getRelOptSchema() {
-                return null;
-            }
-
-            @Override
-            public RelNode toRel(ToRelContext context) {
-                return null;
-            }
-
-            @Override
-            public List getCollationList() {
-                return null;
-            }
-
-            @Override
-            public RelDistribution getDistribution() {
-                return null;
-            }
-
-            @Override
-            public boolean isKey(ImmutableBitSet columns) {
-                return false;
-            }
-
-            @Override
-            public List getReferentialConstraints() {
-                return null;
-            }
-
-            @Override
-            public Expression getExpression(Class clazz) {
-                return null;
-            }
-
-            @Override
-            public RelOptTable extend(List extendedFields) {
-                return null;
-            }
-
-            @Override
-            public List getColumnStrategies() {
-                return null;
-            }
-        };
-        return create(cluster, table);
-    }
-
-    public static LogicalGraphScan emptyScan(RelOptCluster cluster, GraphRecordType graphRecordType) {
-        return create(cluster, empty(graphRecordType));
-    }
-
-
-    private static RelOptTable empty(GraphRecordType graphRecordType) {
-        return new RelOptTable() {
-
-            @Override
-            public  C unwrap(Class aClass) {
-                return null;
-            }
-
-            @Override
-            public List getQualifiedName() {
-                return null;
-            }
-
-            @Override
-            public double getRowCount() {
-                return 0;
-            }
-
-            @Override
-            public RelDataType getRowType() {
-                return graphRecordType;
-            }
-
-            @Override
-            public RelOptSchema getRelOptSchema() {
-                return null;
-            }
-
-            @Override
-            public RelNode toRel(ToRelContext context) {
-                return null;
-            }
-
-            @Override
-            public List getCollationList() {
-                return null;
-            }
-
-            @Override
-            public RelDistribution getDistribution() {
-                return null;
-            }
-
-            @Override
-            public boolean isKey(ImmutableBitSet columns) {
-                return false;
-            }
-
-            @Override
-            public List getReferentialConstraints() {
-                return null;
-            }
-
-            @Override
-            public Expression getExpression(Class clazz) {
-                return null;
-            }
-
-            @Override
-            public RelOptTable extend(List extendedFields) {
-                return null;
-            }
-
-            @Override
-            public List getColumnStrategies() {
-                return null;
-            }
+  protected LogicalGraphScan(RelOptCluster cluster, RelTraitSet traitSet, RelOptTable table) {
+    super(cluster, traitSet, table);
+  }
+
+  @Override
+  public GraphScan copy(RelTraitSet traitSet, RelOptTable table) {
+    return new LogicalGraphScan(getCluster(), traitSet, table);
+  }
+
+  public static LogicalGraphScan create(RelOptCluster cluster, RelOptTable relOptTable) {
+    return new LogicalGraphScan(cluster, cluster.traitSet(), relOptTable);
+  }
+
+  public static LogicalGraphScan create(RelOptCluster cluster, GeaFlowGraph graph) {
+    RelOptTable table =
+        new RelOptTable() {
+
+          @Override
+          public  C unwrap(Class aClass) {
+            return (C) graph;
+          }
+
+          @Override
+          public List getQualifiedName() {
+            return null;
+          }
+
+          @Override
+          public double getRowCount() {
+            return 0;
+          }
+
+          @Override
+          public RelDataType getRowType() {
+            return graph.getRowType(GQLJavaTypeFactory.create());
+          }
+
+          @Override
+          public RelOptSchema getRelOptSchema() {
+            return null;
+          }
+
+          @Override
+          public RelNode toRel(ToRelContext context) {
+            return null;
+          }
+
+          @Override
+          public List getCollationList() {
+            return null;
+          }
+
+          @Override
+          public RelDistribution getDistribution() {
+            return null;
+          }
+
+          @Override
+          public boolean isKey(ImmutableBitSet columns) {
+            return false;
+          }
+
+          @Override
+          public List getReferentialConstraints() {
+            return null;
+          }
+
+          @Override
+          public Expression getExpression(Class clazz) {
+            return null;
+          }
+
+          @Override
+          public RelOptTable extend(List extendedFields) {
+            return null;
+          }
+
+          @Override
+          public List getColumnStrategies() {
+            return null;
+          }
         };
-    }
-
+    return create(cluster, table);
+  }
+
+  public static LogicalGraphScan emptyScan(RelOptCluster cluster, GraphRecordType graphRecordType) {
+    return create(cluster, empty(graphRecordType));
+  }
+
+  private static RelOptTable empty(GraphRecordType graphRecordType) {
+    return new RelOptTable() {
+
+      @Override
+      public  C unwrap(Class aClass) {
+        return null;
+      }
+
+      @Override
+      public List getQualifiedName() {
+        return null;
+      }
+
+      @Override
+      public double getRowCount() {
+        return 0;
+      }
+
+      @Override
+      public RelDataType getRowType() {
+        return graphRecordType;
+      }
+
+      @Override
+      public RelOptSchema getRelOptSchema() {
+        return null;
+      }
+
+      @Override
+      public RelNode toRel(ToRelContext context) {
+        return null;
+      }
+
+      @Override
+      public List getCollationList() {
+        return null;
+      }
+
+      @Override
+      public RelDistribution getDistribution() {
+        return null;
+      }
+
+      @Override
+      public boolean isKey(ImmutableBitSet columns) {
+        return false;
+      }
+
+      @Override
+      public List getReferentialConstraints() {
+        return null;
+      }
+
+      @Override
+      public Expression getExpression(Class clazz) {
+        return null;
+      }
+
+      @Override
+      public RelOptTable extend(List extendedFields) {
+        return null;
+      }
+
+      @Override
+      public List getColumnStrategies() {
+        return null;
+      }
+    };
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalParameterizedRelNode.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalParameterizedRelNode.java
index c2cc5fb0e..5e2236161 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalParameterizedRelNode.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/logical/LogicalParameterizedRelNode.java
@@ -20,6 +20,7 @@
 package org.apache.geaflow.dsl.rel.logical;
 
 import java.util.List;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
@@ -27,20 +28,19 @@
 
 public class LogicalParameterizedRelNode extends ParameterizedRelNode {
 
-    public LogicalParameterizedRelNode(RelOptCluster cluster, RelTraitSet traitSet,
-                                       RelNode parameter, RelNode query) {
-        super(cluster, traitSet, parameter, query);
-    }
+  public LogicalParameterizedRelNode(
+      RelOptCluster cluster, RelTraitSet traitSet, RelNode parameter, RelNode query) {
+    super(cluster, traitSet, parameter, query);
+  }
 
-    @Override
-    public ParameterizedRelNode copy(RelTraitSet traitSet, List inputs) {
-        assert inputs.size() == 2;
-        return new LogicalParameterizedRelNode(getCluster(), traitSet, inputs.get(0), inputs.get(1));
-    }
+  @Override
+  public ParameterizedRelNode copy(RelTraitSet traitSet, List inputs) {
+    assert inputs.size() == 2;
+    return new LogicalParameterizedRelNode(getCluster(), traitSet, inputs.get(0), inputs.get(1));
+  }
 
-    public static LogicalParameterizedRelNode create(RelOptCluster cluster,
-                                                     RelTraitSet traitSet,
-                                                     RelNode parameter, RelNode query) {
-        return new LogicalParameterizedRelNode(cluster, traitSet, parameter, query);
-    }
+  public static LogicalParameterizedRelNode create(
+      RelOptCluster cluster, RelTraitSet traitSet, RelNode parameter, RelNode query) {
+    return new LogicalParameterizedRelNode(cluster, traitSet, parameter, query);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/EdgeMatch.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/EdgeMatch.java
index b4f6ff47e..7f924069b 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/EdgeMatch.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/EdgeMatch.java
@@ -21,9 +21,8 @@
 
 import static org.apache.geaflow.dsl.util.GQLRelUtil.match;
 
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableSet;
 import java.util.*;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.AbstractRelNode;
@@ -37,133 +36,166 @@
 import org.apache.geaflow.dsl.rel.MatchNodeVisitor;
 import org.apache.geaflow.dsl.sqlnode.SqlMatchEdge.EdgeDirection;
 
-public class EdgeMatch extends AbstractRelNode implements SingleMatchNode, IMatchLabel {
-
-    private RelNode input;
-
-    private final String label;
-
-    private Set pushDownFields;
-
-    private final ImmutableSet edgeTypes;
-
-    private final EdgeDirection direction;
-
-    private final PathRecordType pathType;
-
-    private final RelDataType nodeType;
-
-    protected EdgeMatch(RelOptCluster cluster, RelTraitSet traitSet, RelNode input, String label,
-                        Collection edgeTypes, EdgeDirection direction, RelDataType nodeType,
-                        PathRecordType pathType, Set pushDownFields) {
-        super(cluster, traitSet);
-        this.input = input;
-        this.label = label;
-        this.edgeTypes = ImmutableSet.copyOf(edgeTypes);
-        if (input != null && match(input).getNodeType().getSqlTypeName() != SqlTypeName.VERTEX) {
-            throw new GeaFlowDSLException("Illegal input type: " + match(input).getNodeType().getSqlTypeName()
-                + " for " + getRelTypeName() + ", should be: " + SqlTypeName.VERTEX);
-        }
-        this.direction = direction;
-        this.rowType = Objects.requireNonNull(pathType);
-        this.pathType = Objects.requireNonNull(pathType);
-        this.nodeType = Objects.requireNonNull(nodeType);
-        this.pushDownFields = pushDownFields;
-    }
-
-    public void addField(RexFieldAccess field) {
-        if (pushDownFields == null) {
-            pushDownFields = new HashSet<>();
-        }
-        pushDownFields.add(field);
-    }
-
-    @Override
-    public String getLabel() {
-        return label;
-    }
-
-    @Override
-    public Set getTypes() {
-        return edgeTypes;
-    }
-
-    @Override
-    public List getInputs() {
-        if (input == null) {
-            return Collections.emptyList();
-        }
-        return ImmutableList.of(input);
-    }
-
-    @Override
-    public RelNode getInput() {
-        return input;
-    }
-
-    @Override
-    public IMatchNode copy(List inputs, PathRecordType pathSchema) {
-        return new EdgeMatch(getCluster(), traitSet, sole(inputs), label, edgeTypes,
-            direction, nodeType, pathSchema, pushDownFields);
-    }
-
-    public EdgeDirection getDirection() {
-        return direction;
-    }
-
-    @Override
-    public EdgeMatch copy(RelTraitSet traitSet, List inputs) {
-        assert inputs.size() == 1;
-        return new EdgeMatch(getCluster(), getTraitSet(), sole(inputs),
-            label, edgeTypes, direction, nodeType, pathType, pushDownFields);
-    }
-
-    @Override
-    public RelWriter explainTerms(RelWriter pw) {
-        return super.explainTerms(pw)
-            .item("input", input)
-            .item("label", label)
-            .item("edgeTypes", edgeTypes)
-            .item("direction", direction);
-    }
-
-    @Override
-    public void replaceInput(int ordinalInParent, RelNode p) {
-        assert ordinalInParent == 0;
-        this.input = p;
-    }
-
-    @Override
-    protected RelDataType deriveRowType() {
-        throw new UnsupportedOperationException();
-    }
-
-    public static EdgeMatch create(RelOptCluster cluster, SingleMatchNode input,
-                                   String label, List edgeTypes,
-                                   EdgeDirection direction, RelDataType nodeType,
-                                   PathRecordType pathType) {
-        return new EdgeMatch(cluster, cluster.traitSet(), input, label, edgeTypes,
-            direction, nodeType, pathType, null);
-    }
-
-
-    public Set getFields() {
-        return pushDownFields;
-    }
-
-
-    @Override
-    public PathRecordType getPathSchema() {
-        return pathType;
-    }
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
 
-    @Override
-    public RelDataType getNodeType() {
-        return nodeType;
-    }
+public class EdgeMatch extends AbstractRelNode implements SingleMatchNode, IMatchLabel {
 
-    @Override
-    public  T accept(MatchNodeVisitor visitor) {
-        return visitor.visitEdgeMatch(this);
-    }
+  private RelNode input;
+
+  private final String label;
+
+  private Set pushDownFields;
+
+  private final ImmutableSet edgeTypes;
+
+  private final EdgeDirection direction;
+
+  private final PathRecordType pathType;
+
+  private final RelDataType nodeType;
+
+  protected EdgeMatch(
+      RelOptCluster cluster,
+      RelTraitSet traitSet,
+      RelNode input,
+      String label,
+      Collection edgeTypes,
+      EdgeDirection direction,
+      RelDataType nodeType,
+      PathRecordType pathType,
+      Set pushDownFields) {
+    super(cluster, traitSet);
+    this.input = input;
+    this.label = label;
+    this.edgeTypes = ImmutableSet.copyOf(edgeTypes);
+    if (input != null && match(input).getNodeType().getSqlTypeName() != SqlTypeName.VERTEX) {
+      throw new GeaFlowDSLException(
+          "Illegal input type: "
+              + match(input).getNodeType().getSqlTypeName()
+              + " for "
+              + getRelTypeName()
+              + ", should be: "
+              + SqlTypeName.VERTEX);
+    }
+    this.direction = direction;
+    this.rowType = Objects.requireNonNull(pathType);
+    this.pathType = Objects.requireNonNull(pathType);
+    this.nodeType = Objects.requireNonNull(nodeType);
+    this.pushDownFields = pushDownFields;
+  }
+
+  public void addField(RexFieldAccess field) {
+    if (pushDownFields == null) {
+      pushDownFields = new HashSet<>();
+    }
+    pushDownFields.add(field);
+  }
+
+  @Override
+  public String getLabel() {
+    return label;
+  }
+
+  @Override
+  public Set getTypes() {
+    return edgeTypes;
+  }
+
+  @Override
+  public List getInputs() {
+    if (input == null) {
+      return Collections.emptyList();
+    }
+    return ImmutableList.of(input);
+  }
+
+  @Override
+  public RelNode getInput() {
+    return input;
+  }
+
+  @Override
+  public IMatchNode copy(List inputs, PathRecordType pathSchema) {
+    return new EdgeMatch(
+        getCluster(),
+        traitSet,
+        sole(inputs),
+        label,
+        edgeTypes,
+        direction,
+        nodeType,
+        pathSchema,
+        pushDownFields);
+  }
+
+  public EdgeDirection getDirection() {
+    return direction;
+  }
+
+  @Override
+  public EdgeMatch copy(RelTraitSet traitSet, List inputs) {
+    assert inputs.size() == 1;
+    return new EdgeMatch(
+        getCluster(),
+        getTraitSet(),
+        sole(inputs),
+        label,
+        edgeTypes,
+        direction,
+        nodeType,
+        pathType,
+        pushDownFields);
+  }
+
+  @Override
+  public RelWriter explainTerms(RelWriter pw) {
+    return super.explainTerms(pw)
+        .item("input", input)
+        .item("label", label)
+        .item("edgeTypes", edgeTypes)
+        .item("direction", direction);
+  }
+
+  @Override
+  public void replaceInput(int ordinalInParent, RelNode p) {
+    assert ordinalInParent == 0;
+    this.input = p;
+  }
+
+  @Override
+  protected RelDataType deriveRowType() {
+    throw new UnsupportedOperationException();
+  }
+
+  public static EdgeMatch create(
+      RelOptCluster cluster,
+      SingleMatchNode input,
+      String label,
+      List edgeTypes,
+      EdgeDirection direction,
+      RelDataType nodeType,
+      PathRecordType pathType) {
+    return new EdgeMatch(
+        cluster, cluster.traitSet(), input, label, edgeTypes, direction, nodeType, pathType, null);
+  }
+
+  public Set getFields() {
+    return pushDownFields;
+  }
+
+  @Override
+  public PathRecordType getPathSchema() {
+    return pathType;
+  }
+
+  @Override
+  public RelDataType getNodeType() {
+    return nodeType;
+  }
+
+  @Override
+  public  T accept(MatchNodeVisitor visitor) {
+    return visitor.visitEdgeMatch(this);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/IMatchLabel.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/IMatchLabel.java
index 832d52de6..035c085b2 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/IMatchLabel.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/IMatchLabel.java
@@ -23,7 +23,7 @@
 
 public interface IMatchLabel extends SingleMatchNode {
 
-    String getLabel();
+  String getLabel();
 
-    Set getTypes();
+  Set getTypes();
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/IMatchNode.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/IMatchNode.java
index 223f14b99..de8740165 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/IMatchNode.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/IMatchNode.java
@@ -20,6 +20,7 @@
 package org.apache.geaflow.dsl.rel.match;
 
 import java.util.List;
+
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.geaflow.dsl.calcite.PathRecordType;
@@ -27,11 +28,11 @@
 
 public interface IMatchNode extends RelNode {
 
-    PathRecordType getPathSchema();
+  PathRecordType getPathSchema();
 
-    RelDataType getNodeType();
+  RelDataType getNodeType();
 
-    IMatchNode copy(List inputs, PathRecordType pathType);
+  IMatchNode copy(List inputs, PathRecordType pathType);
 
-     T accept(MatchNodeVisitor visitor);
+   T accept(MatchNodeVisitor visitor);
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/LoopUntilMatch.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/LoopUntilMatch.java
index 270d0b166..cc341e060 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/LoopUntilMatch.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/LoopUntilMatch.java
@@ -19,9 +19,9 @@
 
 package org.apache.geaflow.dsl.rel.match;
 
-import com.google.common.collect.Lists;
 import java.util.List;
 import java.util.Objects;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
@@ -35,121 +35,161 @@
 import org.apache.geaflow.dsl.rel.MatchNodeVisitor;
 import org.apache.geaflow.dsl.util.GQLRelUtil;
 
-public class LoopUntilMatch extends SingleRel implements SingleMatchNode {
-
-    private final SingleMatchNode loopBody;
-
-    private final int minLoopCount;
-
-    private final int maxLoopCount;
-
-    private final RexNode utilCondition;
-
-    private final PathRecordType pathType;
-
-    protected LoopUntilMatch(RelOptCluster cluster, RelTraitSet traitSet, RelNode input,
-                             SingleMatchNode loopBody, RexNode utilCondition,
-                             int minLoopCount, int maxLoopCount,
-                             PathRecordType pathType) {
-        super(cluster, traitSet, input);
-        this.loopBody = Objects.requireNonNull(loopBody);
-        this.utilCondition = Objects.requireNonNull(utilCondition);
-        this.minLoopCount = minLoopCount;
-        this.maxLoopCount = maxLoopCount;
-        this.pathType = Objects.requireNonNull(pathType);
-        this.rowType = pathType;
-    }
-
-    @Override
-    public PathRecordType getPathSchema() {
-        return pathType;
-    }
-
-    @Override
-    public RelDataType getNodeType() {
-        return loopBody.getNodeType();
-    }
-
-    @Override
-    public  T accept(MatchNodeVisitor visitor) {
-        return visitor.visitLoopMatch(this);
-    }
-
-    public SingleMatchNode getLoopBody() {
-        return loopBody;
-    }
-
-    public RexNode getUtilCondition() {
-        return utilCondition;
-    }
-
-    public int getMinLoopCount() {
-        return minLoopCount;
-    }
-
-    public int getMaxLoopCount() {
-        return maxLoopCount;
-    }
-
-    @Override
-    public RelWriter explainTerms(RelWriter pw) {
-        return super.explainTerms(pw)
-            .item("body", loopBody)
-            .item("minLoopCount", minLoopCount)
-            .item("maxLoopCount", maxLoopCount)
-            .item("condition", utilCondition);
-    }
-
-    @Override
-    public IMatchNode copy(List inputs, PathRecordType pathSchema) {
-        return new LoopUntilMatch(getCluster(), getTraitSet(), sole(inputs), loopBody,
-            utilCondition, minLoopCount, maxLoopCount, pathSchema);
-    }
-
-    @Override
-    public LoopUntilMatch copy(RelTraitSet traitSet, List inputs) {
-        return new LoopUntilMatch(getCluster(), traitSet, sole(inputs), loopBody, utilCondition,
-            minLoopCount, maxLoopCount, pathType);
-    }
-
-    @Override
-    public RelNode accept(RexShuttle shuttle) {
-        GQLRelUtil.applyRexShuffleToTree(loopBody, shuttle);
-        RexNode newUtilCondition = utilCondition.accept(shuttle);
-        return new LoopUntilMatch(getCluster(), getTraitSet(), input, loopBody, newUtilCondition,
-            minLoopCount, maxLoopCount, pathType);
-    }
+import com.google.common.collect.Lists;
 
+public class LoopUntilMatch extends SingleRel implements SingleMatchNode {
 
-    public static LoopUntilMatch create(RelOptCluster cluster, RelTraitSet traitSet, RelNode input,
-                                        SingleMatchNode loopBody, RexNode utilCondition, int minLoopCount,
-                                        int maxLoopCount, PathRecordType pathType) {
-        return new LoopUntilMatch(cluster, traitSet, input, loopBody, utilCondition, minLoopCount,
-            maxLoopCount, pathType);
+  private final SingleMatchNode loopBody;
+
+  private final int minLoopCount;
+
+  private final int maxLoopCount;
+
+  private final RexNode utilCondition;
+
+  private final PathRecordType pathType;
+
+  protected LoopUntilMatch(
+      RelOptCluster cluster,
+      RelTraitSet traitSet,
+      RelNode input,
+      SingleMatchNode loopBody,
+      RexNode utilCondition,
+      int minLoopCount,
+      int maxLoopCount,
+      PathRecordType pathType) {
+    super(cluster, traitSet, input);
+    this.loopBody = Objects.requireNonNull(loopBody);
+    this.utilCondition = Objects.requireNonNull(utilCondition);
+    this.minLoopCount = minLoopCount;
+    this.maxLoopCount = maxLoopCount;
+    this.pathType = Objects.requireNonNull(pathType);
+    this.rowType = pathType;
+  }
+
+  @Override
+  public PathRecordType getPathSchema() {
+    return pathType;
+  }
+
+  @Override
+  public RelDataType getNodeType() {
+    return loopBody.getNodeType();
+  }
+
+  @Override
+  public  T accept(MatchNodeVisitor visitor) {
+    return visitor.visitLoopMatch(this);
+  }
+
+  public SingleMatchNode getLoopBody() {
+    return loopBody;
+  }
+
+  public RexNode getUtilCondition() {
+    return utilCondition;
+  }
+
+  public int getMinLoopCount() {
+    return minLoopCount;
+  }
+
+  public int getMaxLoopCount() {
+    return maxLoopCount;
+  }
+
+  @Override
+  public RelWriter explainTerms(RelWriter pw) {
+    return super.explainTerms(pw)
+        .item("body", loopBody)
+        .item("minLoopCount", minLoopCount)
+        .item("maxLoopCount", maxLoopCount)
+        .item("condition", utilCondition);
+  }
+
+  @Override
+  public IMatchNode copy(List inputs, PathRecordType pathSchema) {
+    return new LoopUntilMatch(
+        getCluster(),
+        getTraitSet(),
+        sole(inputs),
+        loopBody,
+        utilCondition,
+        minLoopCount,
+        maxLoopCount,
+        pathSchema);
+  }
+
+  @Override
+  public LoopUntilMatch copy(RelTraitSet traitSet, List inputs) {
+    return new LoopUntilMatch(
+        getCluster(),
+        traitSet,
+        sole(inputs),
+        loopBody,
+        utilCondition,
+        minLoopCount,
+        maxLoopCount,
+        pathType);
+  }
+
+  @Override
+  public RelNode accept(RexShuttle shuttle) {
+    GQLRelUtil.applyRexShuffleToTree(loopBody, shuttle);
+    RexNode newUtilCondition = utilCondition.accept(shuttle);
+    return new LoopUntilMatch(
+        getCluster(),
+        getTraitSet(),
+        input,
+        loopBody,
+        newUtilCondition,
+        minLoopCount,
+        maxLoopCount,
+        pathType);
+  }
+
+  public static LoopUntilMatch create(
+      RelOptCluster cluster,
+      RelTraitSet traitSet,
+      RelNode input,
+      SingleMatchNode loopBody,
+      RexNode utilCondition,
+      int minLoopCount,
+      int maxLoopCount,
+      PathRecordType pathType) {
+    return new LoopUntilMatch(
+        cluster, traitSet, input, loopBody, utilCondition, minLoopCount, maxLoopCount, pathType);
+  }
+
+  public static SingleMatchNode copyWithSubQueryStartPathType(
+      PathRecordType startPathType, SingleMatchNode node, boolean caseSensitive) {
+    if (node == null) {
+      return null;
     }
-
-    public static SingleMatchNode copyWithSubQueryStartPathType(PathRecordType startPathType,
-                                                                SingleMatchNode node,
-                                                                boolean caseSensitive) {
-        if (node == null) {
-            return null;
-        }
-        if (node instanceof LoopUntilMatch) {
-            LoopUntilMatch loop = (LoopUntilMatch) node;
-            return LoopUntilMatch.create(loop.getCluster(), loop.getTraitSet(), loop.getInput(),
-                copyWithSubQueryStartPathType(startPathType, loop.getLoopBody(), caseSensitive),
-                loop.getUtilCondition(), loop.getMinLoopCount(),
-                loop.getMaxLoopCount(), loop.getPathSchema());
-        } else {
-            PathRecordType concatPathType = startPathType;
-            for (RelDataTypeField field : node.getPathSchema().getFieldList()) {
-                if (concatPathType.getField(field.getName(), caseSensitive, false) == null) {
-                    concatPathType = concatPathType.addField(field.getName(), field.getType(), caseSensitive);
-                }
-            }
-            return (SingleMatchNode) node.copy(Lists.newArrayList(copyWithSubQueryStartPathType(
-                startPathType, (SingleMatchNode) node.getInput(), caseSensitive)), concatPathType);
+    if (node instanceof LoopUntilMatch) {
+      LoopUntilMatch loop = (LoopUntilMatch) node;
+      return LoopUntilMatch.create(
+          loop.getCluster(),
+          loop.getTraitSet(),
+          loop.getInput(),
+          copyWithSubQueryStartPathType(startPathType, loop.getLoopBody(), caseSensitive),
+          loop.getUtilCondition(),
+          loop.getMinLoopCount(),
+          loop.getMaxLoopCount(),
+          loop.getPathSchema());
+    } else {
+      PathRecordType concatPathType = startPathType;
+      for (RelDataTypeField field : node.getPathSchema().getFieldList()) {
+        if (concatPathType.getField(field.getName(), caseSensitive, false) == null) {
+          concatPathType = concatPathType.addField(field.getName(), field.getType(), caseSensitive);
         }
+      }
+      return (SingleMatchNode)
+          node.copy(
+              Lists.newArrayList(
+                  copyWithSubQueryStartPathType(
+                      startPathType, (SingleMatchNode) node.getInput(), caseSensitive)),
+              concatPathType);
     }
-
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchAggregate.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchAggregate.java
index 292054ed9..d06eee6f6 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchAggregate.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchAggregate.java
@@ -19,11 +19,10 @@
 
 package org.apache.geaflow.dsl.rel.match;
 
-import com.google.common.base.Preconditions;
-import com.google.common.collect.ImmutableList;
 import java.util.List;
 import java.util.Objects;
 import java.util.stream.Collectors;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
@@ -38,103 +37,135 @@
 import org.apache.geaflow.dsl.rex.MatchAggregateCall;
 import org.apache.geaflow.dsl.util.GQLRelUtil;
 
-public class MatchAggregate extends SingleRel implements SingleMatchNode {
-
-    public final boolean indicator;
-    protected final List aggCalls;
-    protected final List groupSet;
-
-    protected MatchAggregate(RelOptCluster cluster, RelTraitSet traits, RelNode input,
-                             boolean indicator, List groupSet,
-                             List aggCalls, PathRecordType pathType) {
-        super(cluster, traits, input);
-        this.indicator = indicator; // true is allowed, but discouraged
-        this.aggCalls = ImmutableList.copyOf(aggCalls);
-        this.groupSet = Objects.requireNonNull(groupSet);
-        for (MatchAggregateCall aggCall : aggCalls) {
-            Preconditions.checkArgument(aggCall.filterArg < 0
-                    || isPredicate(input, aggCall.filterArg),
-                "filter must be BOOLEAN NOT NULL");
-        }
-        this.rowType = Objects.requireNonNull(pathType);
-    }
-
-    private boolean isPredicate(RelNode input, int index) {
-        final RelDataType type =
-            input.getRowType().getFieldList().get(index).getType();
-        return type.getSqlTypeName() == SqlTypeName.BOOLEAN
-            && !type.isNullable();
-    }
-
-    public List getGroupSet() {
-        return groupSet;
-    }
-
-    public List getAggCalls() {
-        return aggCalls;
-    }
-
-    public boolean isIndicator() {
-        return indicator;
-    }
-
-    @Override
-    public PathRecordType getPathSchema() {
-        return (PathRecordType) rowType;
-    }
-
-    @Override
-    public RelDataType getNodeType() {
-        return ((IMatchNode) GQLRelUtil.toRel(getInput())).getNodeType();
-    }
-
-    @Override
-    public  T accept(MatchNodeVisitor visitor) {
-        return visitor.visitAggregate(this);
-    }
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
 
-    @Override
-    public IMatchNode copy(List inputs, PathRecordType pathType) {
-        return new MatchAggregate(getCluster(), getTraitSet(), sole(inputs),
-            indicator, groupSet, aggCalls, (PathRecordType) rowType);
-    }
+public class MatchAggregate extends SingleRel implements SingleMatchNode {
 
-    @Override
-    public RelNode copy(RelTraitSet traitSet, List inputs) {
-        return new MatchAggregate(getCluster(), traitSet, sole(inputs),
-            indicator, groupSet, aggCalls, (PathRecordType) rowType);
+  public final boolean indicator;
+  protected final List aggCalls;
+  protected final List groupSet;
+
+  protected MatchAggregate(
+      RelOptCluster cluster,
+      RelTraitSet traits,
+      RelNode input,
+      boolean indicator,
+      List groupSet,
+      List aggCalls,
+      PathRecordType pathType) {
+    super(cluster, traits, input);
+    this.indicator = indicator; // true is allowed, but discouraged
+    this.aggCalls = ImmutableList.copyOf(aggCalls);
+    this.groupSet = Objects.requireNonNull(groupSet);
+    for (MatchAggregateCall aggCall : aggCalls) {
+      Preconditions.checkArgument(
+          aggCall.filterArg < 0 || isPredicate(input, aggCall.filterArg),
+          "filter must be BOOLEAN NOT NULL");
     }
-
-    @Override
-    public RelNode accept(RexShuttle shuttle) {
-        List rewriteGroupList = this.groupSet.stream().map(rex -> rex.accept(shuttle))
+    this.rowType = Objects.requireNonNull(pathType);
+  }
+
+  private boolean isPredicate(RelNode input, int index) {
+    final RelDataType type = input.getRowType().getFieldList().get(index).getType();
+    return type.getSqlTypeName() == SqlTypeName.BOOLEAN && !type.isNullable();
+  }
+
+  public List getGroupSet() {
+    return groupSet;
+  }
+
+  public List getAggCalls() {
+    return aggCalls;
+  }
+
+  public boolean isIndicator() {
+    return indicator;
+  }
+
+  @Override
+  public PathRecordType getPathSchema() {
+    return (PathRecordType) rowType;
+  }
+
+  @Override
+  public RelDataType getNodeType() {
+    return ((IMatchNode) GQLRelUtil.toRel(getInput())).getNodeType();
+  }
+
+  @Override
+  public  T accept(MatchNodeVisitor visitor) {
+    return visitor.visitAggregate(this);
+  }
+
+  @Override
+  public IMatchNode copy(List inputs, PathRecordType pathType) {
+    return new MatchAggregate(
+        getCluster(),
+        getTraitSet(),
+        sole(inputs),
+        indicator,
+        groupSet,
+        aggCalls,
+        (PathRecordType) rowType);
+  }
+
+  @Override
+  public RelNode copy(RelTraitSet traitSet, List inputs) {
+    return new MatchAggregate(
+        getCluster(),
+        traitSet,
+        sole(inputs),
+        indicator,
+        groupSet,
+        aggCalls,
+        (PathRecordType) rowType);
+  }
+
+  @Override
+  public RelNode accept(RexShuttle shuttle) {
+    List rewriteGroupList =
+        this.groupSet.stream().map(rex -> rex.accept(shuttle)).collect(Collectors.toList());
+    List rewriteAggCalls =
+        this.aggCalls.stream()
+            .map(
+                call -> {
+                  List rewriteArgList =
+                      call.getArgList().stream().map(shuttle::apply).collect(Collectors.toList());
+                  return new MatchAggregateCall(
+                      call.getAggregation(),
+                      call.isDistinct(),
+                      call.isApproximate(),
+                      rewriteArgList,
+                      call.filterArg,
+                      call.getCollation(),
+                      call.getType(),
+                      call.getName());
+                })
             .collect(Collectors.toList());
-        List rewriteAggCalls = this.aggCalls.stream().map(call -> {
-            List rewriteArgList = call.getArgList().stream()
-                .map(shuttle::apply).collect(Collectors.toList());
-            return new MatchAggregateCall(call.getAggregation(), call.isDistinct(),
-                call.isApproximate(), rewriteArgList, call.filterArg, call.getCollation(),
-                call.getType(), call.getName());
-        }).collect(Collectors.toList());
-        return MatchAggregate.create(getInput(), indicator, rewriteGroupList, rewriteAggCalls,
-            (PathRecordType) rowType);
+    return MatchAggregate.create(
+        getInput(), indicator, rewriteGroupList, rewriteAggCalls, (PathRecordType) rowType);
+  }
+
+  @Override
+  public RelWriter explainTerms(RelWriter pw) {
+    RelWriter writer = super.explainTerms(pw);
+    if (!groupSet.isEmpty()) {
+      writer.item("group by", groupSet);
     }
-
-    @Override
-    public RelWriter explainTerms(RelWriter pw) {
-        RelWriter writer = super.explainTerms(pw);
-        if (!groupSet.isEmpty()) {
-            writer.item("group by", groupSet);
-        }
-        if (aggCalls != null) {
-            writer.item("aggCalls", aggCalls);
-        }
-        return writer;
-    }
-
-    public static MatchAggregate create(RelNode input, boolean indicator, List groupSet,
-                                        List aggCalls, PathRecordType pathType) {
-        return new MatchAggregate(input.getCluster(), input.getTraitSet(), input, indicator,
-            groupSet, aggCalls, pathType);
+    if (aggCalls != null) {
+      writer.item("aggCalls", aggCalls);
     }
+    return writer;
+  }
+
+  public static MatchAggregate create(
+      RelNode input,
+      boolean indicator,
+      List groupSet,
+      List aggCalls,
+      PathRecordType pathType) {
+    return new MatchAggregate(
+        input.getCluster(), input.getTraitSet(), input, indicator, groupSet, aggCalls, pathType);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchDistinct.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchDistinct.java
index dd027a30d..6012555ef 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchDistinct.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchDistinct.java
@@ -22,6 +22,7 @@
 import static org.apache.geaflow.dsl.util.GQLRelUtil.match;
 
 import java.util.List;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
@@ -34,56 +35,55 @@
 
 public class MatchDistinct extends SingleRel implements SingleMatchNode {
 
-    protected MatchDistinct(RelOptCluster cluster, RelTraitSet traits,
-                            RelNode input) {
-        super(cluster, traits, input);
-    }
-
-    public MatchDistinct copy(RelTraitSet traitSet, RelNode input) {
-        return new MatchDistinct(getCluster(), traitSet, input);
-    }
-
-    @Override
-    public RelNode copy(RelTraitSet traitSet, List inputs) {
-        return copy(traitSet, sole(inputs));
-    }
-
-    @Override
-    public PathRecordType getPathSchema() {
-        return (PathRecordType) input.getRowType();
-    }
-
-    @Override
-    public RelDataType getNodeType() {
-        return match(getInput()).getNodeType();
-    }
-
-    @Override
-    public  T accept(MatchNodeVisitor visitor) {
-        return visitor.visitDistinct(this);
-    }
-
-    @Override
-    public RelNode getInput() {
-        return input;
-    }
-
-    @Override
-    public IMatchNode copy(List inputs, PathRecordType pathSchema) {
-        return copy(traitSet, sole(inputs));
-    }
-
-    @Override
-    public RelNode accept(RexShuttle shuttle) {
-        return copy(traitSet, input);
-    }
-
-    @Override
-    public RelWriter explainTerms(RelWriter pw) {
-        return super.explainTerms(pw).item("distinct", true);
-    }
-
-    public static MatchDistinct create(IMatchNode input) {
-        return new MatchDistinct(input.getCluster(), input.getTraitSet(), input);
-    }
+  protected MatchDistinct(RelOptCluster cluster, RelTraitSet traits, RelNode input) {
+    super(cluster, traits, input);
+  }
+
+  public MatchDistinct copy(RelTraitSet traitSet, RelNode input) {
+    return new MatchDistinct(getCluster(), traitSet, input);
+  }
+
+  @Override
+  public RelNode copy(RelTraitSet traitSet, List inputs) {
+    return copy(traitSet, sole(inputs));
+  }
+
+  @Override
+  public PathRecordType getPathSchema() {
+    return (PathRecordType) input.getRowType();
+  }
+
+  @Override
+  public RelDataType getNodeType() {
+    return match(getInput()).getNodeType();
+  }
+
+  @Override
+  public  T accept(MatchNodeVisitor visitor) {
+    return visitor.visitDistinct(this);
+  }
+
+  @Override
+  public RelNode getInput() {
+    return input;
+  }
+
+  @Override
+  public IMatchNode copy(List inputs, PathRecordType pathSchema) {
+    return copy(traitSet, sole(inputs));
+  }
+
+  @Override
+  public RelNode accept(RexShuttle shuttle) {
+    return copy(traitSet, input);
+  }
+
+  @Override
+  public RelWriter explainTerms(RelWriter pw) {
+    return super.explainTerms(pw).item("distinct", true);
+  }
+
+  public static MatchDistinct create(IMatchNode input) {
+    return new MatchDistinct(input.getCluster(), input.getTraitSet(), input);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchExtend.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchExtend.java
index d9632cfc4..74206c598 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchExtend.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchExtend.java
@@ -19,11 +19,11 @@
 
 package org.apache.geaflow.dsl.rel.match;
 
-import com.google.common.collect.ImmutableList;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 import java.util.stream.Collectors;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
@@ -38,93 +38,124 @@
 import org.apache.geaflow.dsl.rex.RexObjectConstruct;
 import org.apache.geaflow.dsl.util.GQLRelUtil;
 
-public class MatchExtend extends PathModify implements SingleMatchNode {
+import com.google.common.collect.ImmutableList;
 
-    private final Set rewriteFields;
+public class MatchExtend extends PathModify implements SingleMatchNode {
 
-    private MatchExtend(RelOptCluster cluster, RelTraitSet traits, RelNode input,
-                        List expressions, RelDataType rowType,
-                        GraphRecordType graphType) {
-        super(cluster, traits, input, expressions.stream()
+  private final Set rewriteFields;
+
+  private MatchExtend(
+      RelOptCluster cluster,
+      RelTraitSet traits,
+      RelNode input,
+      List expressions,
+      RelDataType rowType,
+      GraphRecordType graphType) {
+    super(
+        cluster,
+        traits,
+        input,
+        expressions.stream()
             .filter(exp -> rowType.getFieldNames().contains(exp.getPathFieldName()))
-            .collect(Collectors.toList()), rowType, graphType);
-        rewriteFields = new HashSet<>();
-        for (int i = 0; i < this.expressions.size(); i++) {
-            if (input.getRowType().getFieldNames().contains(this.expressions.get(i).getLeftVar().getLabel())) {
-                rewriteFields.add(this.expressions.get(i).getLeftVar().getLabel());
-            }
-        }
-    }
-
-    public static MatchExtend create(RelNode input, List expressions,
-                                     RelDataType rowType, GraphRecordType graphType) {
-        return new MatchExtend(input.getCluster(), input.getTraitSet(), input, expressions, rowType,
-            graphType);
-    }
-
-    public MatchExtend copy(RelTraitSet traitSet, RelNode input,
-                            List expressions, RelDataType rowType,
-                            GraphRecordType graphType) {
-        return new MatchExtend(getCluster(), traitSet, input, expressions, rowType, graphType);
-    }
-
-    @Override
-    public PathRecordType getPathSchema() {
-        return (PathRecordType) getRowType();
-    }
-
-    @Override
-    public RelDataType getNodeType() {
-        return ((IMatchNode) GQLRelUtil.toRel(getInput())).getNodeType();
-    }
-
-    public Set getRewriteFields() {
-        return rewriteFields;
-    }
-
-    @Override
-    public IMatchNode copy(List inputs, PathRecordType pathSchema) {
-        return new MatchExtend(getCluster(), getTraitSet(), sole(inputs), expressions,
-            pathSchema, modifyGraphType);
-    }
-
-    @Override
-    public RelWriter explainTerms(RelWriter pw) {
-        return super.explainTerms(pw).item("expressions", expressions);
-    }
-
-    @Override
-    public MatchExtend copy(RelTraitSet traitSet, List inputs) {
-        assert inputs.size() == 1;
-        return copy(traitSet, inputs.get(0), expressions, rowType, modifyGraphType);
-    }
-
-    @Override
-    public RelNode accept(RexShuttle shuttle) {
-        List rewriteExpressions =
-            expressions.stream().map(exp -> {
-                //If the LeftVar reference does not exist in the input, MatchExtend will default to
-                // extending at the end of the path, and LeftVar does not need to be rewritten.
-                PathInputRef leftVar = rewriteFields.contains(exp.getLeftVar().getLabel())
-                    ? (PathInputRef) exp.getLeftVar().accept(shuttle) : exp.getLeftVar();
-                RexObjectConstruct rewriteNode = (RexObjectConstruct) exp.getObjectConstruct().accept(shuttle);
-                return new PathModifyExpression(leftVar, rewriteNode);
-            }).collect(Collectors.toList());
-        return copy(traitSet, input, rewriteExpressions, rowType);
-    }
-
-    @Override
-    public PathModify copy(RelTraitSet traitSet, RelNode input,
-                           List expressions, RelDataType rowType) {
-        return copy(traitSet, input, expressions, rowType, modifyGraphType);
-    }
-
-    public ImmutableList getExpressions() {
-        return ImmutableList.copyOf(expressions);
-    }
-
-    @Override
-    public  T accept(MatchNodeVisitor visitor) {
-        return visitor.visitExtend(this);
+            .collect(Collectors.toList()),
+        rowType,
+        graphType);
+    rewriteFields = new HashSet<>();
+    for (int i = 0; i < this.expressions.size(); i++) {
+      if (input
+          .getRowType()
+          .getFieldNames()
+          .contains(this.expressions.get(i).getLeftVar().getLabel())) {
+        rewriteFields.add(this.expressions.get(i).getLeftVar().getLabel());
+      }
     }
+  }
+
+  public static MatchExtend create(
+      RelNode input,
+      List expressions,
+      RelDataType rowType,
+      GraphRecordType graphType) {
+    return new MatchExtend(
+        input.getCluster(), input.getTraitSet(), input, expressions, rowType, graphType);
+  }
+
+  public MatchExtend copy(
+      RelTraitSet traitSet,
+      RelNode input,
+      List expressions,
+      RelDataType rowType,
+      GraphRecordType graphType) {
+    return new MatchExtend(getCluster(), traitSet, input, expressions, rowType, graphType);
+  }
+
+  @Override
+  public PathRecordType getPathSchema() {
+    return (PathRecordType) getRowType();
+  }
+
+  @Override
+  public RelDataType getNodeType() {
+    return ((IMatchNode) GQLRelUtil.toRel(getInput())).getNodeType();
+  }
+
+  public Set getRewriteFields() {
+    return rewriteFields;
+  }
+
+  @Override
+  public IMatchNode copy(List inputs, PathRecordType pathSchema) {
+    return new MatchExtend(
+        getCluster(), getTraitSet(), sole(inputs), expressions, pathSchema, modifyGraphType);
+  }
+
+  @Override
+  public RelWriter explainTerms(RelWriter pw) {
+    return super.explainTerms(pw).item("expressions", expressions);
+  }
+
+  @Override
+  public MatchExtend copy(RelTraitSet traitSet, List inputs) {
+    assert inputs.size() == 1;
+    return copy(traitSet, inputs.get(0), expressions, rowType, modifyGraphType);
+  }
+
+  @Override
+  public RelNode accept(RexShuttle shuttle) {
+    List rewriteExpressions =
+        expressions.stream()
+            .map(
+                exp -> {
+                  // If the LeftVar reference does not exist in the input, MatchExtend will default
+                  // to
+                  // extending at the end of the path, and LeftVar does not need to be rewritten.
+                  PathInputRef leftVar =
+                      rewriteFields.contains(exp.getLeftVar().getLabel())
+                          ? (PathInputRef) exp.getLeftVar().accept(shuttle)
+                          : exp.getLeftVar();
+                  RexObjectConstruct rewriteNode =
+                      (RexObjectConstruct) exp.getObjectConstruct().accept(shuttle);
+                  return new PathModifyExpression(leftVar, rewriteNode);
+                })
+            .collect(Collectors.toList());
+    return copy(traitSet, input, rewriteExpressions, rowType);
+  }
+
+  @Override
+  public PathModify copy(
+      RelTraitSet traitSet,
+      RelNode input,
+      List expressions,
+      RelDataType rowType) {
+    return copy(traitSet, input, expressions, rowType, modifyGraphType);
+  }
+
+  public ImmutableList getExpressions() {
+    return ImmutableList.copyOf(expressions);
+  }
+
+  @Override
+  public  T accept(MatchNodeVisitor visitor) {
+    return visitor.visitExtend(this);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchFilter.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchFilter.java
index b275f5d3c..196359c01 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchFilter.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchFilter.java
@@ -23,6 +23,7 @@
 
 import java.util.List;
 import java.util.Objects;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
@@ -40,86 +41,90 @@
 
 public class MatchFilter extends SingleRel implements SingleMatchNode {
 
-    private final RexNode condition;
-    private final PathRecordType pathType;
-
-    protected MatchFilter(RelOptCluster cluster, RelTraitSet traits,
-                          RelNode input, RexNode condition, PathRecordType pathType) {
-        super(cluster, traits, input);
-        this.condition = Objects.requireNonNull(condition);
-        this.pathType = Objects.requireNonNull(pathType);
-    }
-
-    public MatchFilter copy(RelTraitSet traitSet, RelNode input, RexNode condition) {
-        return copy(traitSet, input, condition, pathType);
-    }
-
-    public MatchFilter copy(RelTraitSet traitSet, RelNode input,
-                            RexNode condition, PathRecordType pathType) {
-        return new MatchFilter(getCluster(), traitSet, input, condition, pathType);
-    }
-
-    @Override
-    public IMatchNode copy(List inputs, PathRecordType pathSchema) {
-        return copy(traitSet, sole(inputs), condition, pathSchema);
-    }
-
-    @Override
-    public RelNode copy(RelTraitSet traitSet, List inputs) {
-        return copy(traitSet, sole(inputs), condition);
-    }
-
-    @Override
-    public PathRecordType getPathSchema() {
-        return pathType;
-    }
-
-    @Override
-    public RelDataType getNodeType() {
-        return match(input).getNodeType();
+  private final RexNode condition;
+  private final PathRecordType pathType;
+
+  protected MatchFilter(
+      RelOptCluster cluster,
+      RelTraitSet traits,
+      RelNode input,
+      RexNode condition,
+      PathRecordType pathType) {
+    super(cluster, traits, input);
+    this.condition = Objects.requireNonNull(condition);
+    this.pathType = Objects.requireNonNull(pathType);
+  }
+
+  public MatchFilter copy(RelTraitSet traitSet, RelNode input, RexNode condition) {
+    return copy(traitSet, input, condition, pathType);
+  }
+
+  public MatchFilter copy(
+      RelTraitSet traitSet, RelNode input, RexNode condition, PathRecordType pathType) {
+    return new MatchFilter(getCluster(), traitSet, input, condition, pathType);
+  }
+
+  @Override
+  public IMatchNode copy(List inputs, PathRecordType pathSchema) {
+    return copy(traitSet, sole(inputs), condition, pathSchema);
+  }
+
+  @Override
+  public RelNode copy(RelTraitSet traitSet, List inputs) {
+    return copy(traitSet, sole(inputs), condition);
+  }
+
+  @Override
+  public PathRecordType getPathSchema() {
+    return pathType;
+  }
+
+  @Override
+  public RelDataType getNodeType() {
+    return match(input).getNodeType();
+  }
+
+  @Override
+  public  T accept(MatchNodeVisitor visitor) {
+    return visitor.visitFilter(this);
+  }
+
+  @Override
+  public boolean isValid(Litmus litmus, Context context) {
+    if (RexUtil.isNullabilityCast(getCluster().getTypeFactory(), condition)) {
+      return litmus.fail("Cast for just nullability not allowed");
     }
-
-    @Override
-    public  T accept(MatchNodeVisitor visitor) {
-        return visitor.visitFilter(this);
-    }
-
-    @Override
-    public boolean isValid(Litmus litmus, Context context) {
-        if (RexUtil.isNullabilityCast(getCluster().getTypeFactory(), condition)) {
-            return litmus.fail("Cast for just nullability not allowed");
-        }
-        final RexChecker checker =
-            new RexChecker(((IMatchNode) GQLRelUtil.toRel(getInput())).getPathSchema(), context, litmus);
-        condition.accept(checker);
-        if (checker.getFailureCount() > 0) {
-            return litmus.fail(null);
-        }
-        return litmus.succeed();
-    }
-
-    @Override
-    public RelNode getInput() {
-        return input;
-    }
-
-    @Override
-    public RelNode accept(RexShuttle shuttle) {
-        RexNode newCondition = condition.accept(shuttle);
-        return copy(traitSet, input, newCondition);
-    }
-
-    @Override
-    public RelWriter explainTerms(RelWriter pw) {
-        return super.explainTerms(pw)
-            .item("condition", condition);
-    }
-
-    public RexNode getCondition() {
-        return condition;
-    }
-
-    public static MatchFilter create(RelNode input, RexNode condition, PathRecordType pathType) {
-        return new MatchFilter(input.getCluster(), input.getTraitSet(), input, condition, pathType);
+    final RexChecker checker =
+        new RexChecker(
+            ((IMatchNode) GQLRelUtil.toRel(getInput())).getPathSchema(), context, litmus);
+    condition.accept(checker);
+    if (checker.getFailureCount() > 0) {
+      return litmus.fail(null);
     }
+    return litmus.succeed();
+  }
+
+  @Override
+  public RelNode getInput() {
+    return input;
+  }
+
+  @Override
+  public RelNode accept(RexShuttle shuttle) {
+    RexNode newCondition = condition.accept(shuttle);
+    return copy(traitSet, input, newCondition);
+  }
+
+  @Override
+  public RelWriter explainTerms(RelWriter pw) {
+    return super.explainTerms(pw).item("condition", condition);
+  }
+
+  public RexNode getCondition() {
+    return condition;
+  }
+
+  public static MatchFilter create(RelNode input, RexNode condition, PathRecordType pathType) {
+    return new MatchFilter(input.getCluster(), input.getTraitSet(), input, condition, pathType);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchJoin.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchJoin.java
index 1504fcf32..32600f70f 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchJoin.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchJoin.java
@@ -23,6 +23,7 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.BiRel;
@@ -47,185 +48,218 @@
 
 public class MatchJoin extends BiRel implements IMatchNode {
 
-    protected final RexNode condition;
+  protected final RexNode condition;
 
-    protected final JoinRelType joinType;
+  protected final JoinRelType joinType;
 
-    protected MatchJoin(RelOptCluster cluster, RelTraitSet traitSet, RelNode left,
-                        RelNode right, RexNode condition, JoinRelType joinType) {
-        super(cluster, traitSet, left, right);
-        this.condition = Objects.requireNonNull(condition);
-        this.joinType = Objects.requireNonNull(joinType);
-    }
+  protected MatchJoin(
+      RelOptCluster cluster,
+      RelTraitSet traitSet,
+      RelNode left,
+      RelNode right,
+      RexNode condition,
+      JoinRelType joinType) {
+    super(cluster, traitSet, left, right);
+    this.condition = Objects.requireNonNull(condition);
+    this.joinType = Objects.requireNonNull(joinType);
+  }
 
-    public MatchJoin copy(RelTraitSet traitSet, RexNode conditionExpr, RelNode left, RelNode right,
-                          JoinRelType joinType) {
-        return new MatchJoin(getCluster(), traitSet, left, right, conditionExpr, joinType);
-    }
+  public MatchJoin copy(
+      RelTraitSet traitSet,
+      RexNode conditionExpr,
+      RelNode left,
+      RelNode right,
+      JoinRelType joinType) {
+    return new MatchJoin(getCluster(), traitSet, left, right, conditionExpr, joinType);
+  }
 
-    @Override
-    public RelNode copy(RelTraitSet traitSet, List inputs) {
-        assert inputs.size() == 2;
-        return new MatchJoin(getCluster(), traitSet, inputs.get(0), inputs.get(1), condition, joinType);
-    }
+  @Override
+  public RelNode copy(RelTraitSet traitSet, List inputs) {
+    assert inputs.size() == 2;
+    return new MatchJoin(getCluster(), traitSet, inputs.get(0), inputs.get(1), condition, joinType);
+  }
 
-    public static MatchJoin create(RelOptCluster cluster, RelTraitSet traitSet, IMatchNode left,
-                                   IMatchNode right, RexNode condition, JoinRelType joinType) {
-        return new MatchJoin(cluster, traitSet, left, right, condition, joinType);
-    }
+  public static MatchJoin create(
+      RelOptCluster cluster,
+      RelTraitSet traitSet,
+      IMatchNode left,
+      IMatchNode right,
+      RexNode condition,
+      JoinRelType joinType) {
+    return new MatchJoin(cluster, traitSet, left, right, condition, joinType);
+  }
 
-    @Override
-    public PathRecordType getPathSchema() {
-        return (PathRecordType) getRowType();
-    }
+  @Override
+  public PathRecordType getPathSchema() {
+    return (PathRecordType) getRowType();
+  }
 
-    @Override
-    public RelDataType getNodeType() {
-        return VirtualVertexRecordType.of(getCluster().getTypeFactory());
-    }
+  @Override
+  public RelDataType getNodeType() {
+    return VirtualVertexRecordType.of(getCluster().getTypeFactory());
+  }
 
-    @Override
-    public IMatchNode copy(List inputs, PathRecordType pathType) {
-        return new MatchJoin(getCluster(), getTraitSet(), inputs.get(0), inputs.get(1), condition, joinType);
-    }
+  @Override
+  public IMatchNode copy(List inputs, PathRecordType pathType) {
+    return new MatchJoin(
+        getCluster(), getTraitSet(), inputs.get(0), inputs.get(1), condition, joinType);
+  }
 
-    @Override
-    public  T accept(MatchNodeVisitor visitor) {
-        return visitor.visitJoin(this);
-    }
+  @Override
+  public  T accept(MatchNodeVisitor visitor) {
+    return visitor.visitJoin(this);
+  }
 
-    public RexNode getCondition() {
-        return condition;
-    }
+  public RexNode getCondition() {
+    return condition;
+  }
 
-    public JoinRelType getJoinType() {
-        return joinType;
-    }
+  public JoinRelType getJoinType() {
+    return joinType;
+  }
 
-    @Override
-    protected RelDataType deriveRowType() {
-        RelNode realLeft = GQLRelUtil.toRel(left);
-        RelNode realRight = GQLRelUtil.toRel(right);
-        RelDataType relRecordType = SqlValidatorUtil.deriveJoinRowType(((IMatchNode) realLeft).getPathSchema(),
-            ((IMatchNode) realRight).getPathSchema(), joinType, getCluster().getTypeFactory(),
-            null, new ArrayList<>());
-        if (realLeft.getRowType() instanceof UnionPathRecordType
-            || realRight.getRowType() instanceof UnionPathRecordType) {
-            List leftUnionPaths =
-                realLeft.getRowType() instanceof UnionPathRecordType
-                    ? ((UnionPathRecordType) realLeft.getRowType()).getInputPathRecordTypes()
-                    : Collections.singletonList(((IMatchNode) realLeft).getPathSchema());
-            List rightUnionPaths =
-                realRight.getRowType() instanceof UnionPathRecordType
-                    ? ((UnionPathRecordType) realRight.getRowType()).getInputPathRecordTypes()
-                    : Collections.singletonList(((IMatchNode) realRight).getPathSchema());
-            List unionPaths = new ArrayList<>();
-            for (PathRecordType leftPath : leftUnionPaths) {
-                for (PathRecordType rightPath : rightUnionPaths) {
-                    unionPaths.add(new PathRecordType(SqlValidatorUtil.deriveJoinRowType(leftPath,
-                        rightPath, joinType, getCluster().getTypeFactory(),
-                        null, new ArrayList<>()).getFieldList()));
-                }
-            }
-            return new UnionPathRecordType(relRecordType.getFieldList(), unionPaths);
-        } else {
-            return new JoinPathRecordType(relRecordType.getFieldList());
+  @Override
+  protected RelDataType deriveRowType() {
+    RelNode realLeft = GQLRelUtil.toRel(left);
+    RelNode realRight = GQLRelUtil.toRel(right);
+    RelDataType relRecordType =
+        SqlValidatorUtil.deriveJoinRowType(
+            ((IMatchNode) realLeft).getPathSchema(),
+            ((IMatchNode) realRight).getPathSchema(),
+            joinType,
+            getCluster().getTypeFactory(),
+            null,
+            new ArrayList<>());
+    if (realLeft.getRowType() instanceof UnionPathRecordType
+        || realRight.getRowType() instanceof UnionPathRecordType) {
+      List leftUnionPaths =
+          realLeft.getRowType() instanceof UnionPathRecordType
+              ? ((UnionPathRecordType) realLeft.getRowType()).getInputPathRecordTypes()
+              : Collections.singletonList(((IMatchNode) realLeft).getPathSchema());
+      List rightUnionPaths =
+          realRight.getRowType() instanceof UnionPathRecordType
+              ? ((UnionPathRecordType) realRight.getRowType()).getInputPathRecordTypes()
+              : Collections.singletonList(((IMatchNode) realRight).getPathSchema());
+      List unionPaths = new ArrayList<>();
+      for (PathRecordType leftPath : leftUnionPaths) {
+        for (PathRecordType rightPath : rightUnionPaths) {
+          unionPaths.add(
+              new PathRecordType(
+                  SqlValidatorUtil.deriveJoinRowType(
+                          leftPath,
+                          rightPath,
+                          joinType,
+                          getCluster().getTypeFactory(),
+                          null,
+                          new ArrayList<>())
+                      .getFieldList()));
         }
+      }
+      return new UnionPathRecordType(relRecordType.getFieldList(), unionPaths);
+    } else {
+      return new JoinPathRecordType(relRecordType.getFieldList());
     }
+  }
 
-    @Override
-    public void replaceInput(int ordinalInParent, RelNode p) {
-        super.replaceInput(ordinalInParent, p);
-    }
+  @Override
+  public void replaceInput(int ordinalInParent, RelNode p) {
+    super.replaceInput(ordinalInParent, p);
+  }
 
-    public JoinInfo analyzeCondition() {
-        RexBuilder rexBuilder = getCluster().getRexBuilder();
-        List leftPathKeys = new ArrayList<>();
-        List rightPathKeys = new ArrayList<>();
-        List nonEquiList = new ArrayList<>();
-        this.splitJoinCondition(condition, leftPathKeys, rightPathKeys, nonEquiList);
-        RexNode newRemaining = RexUtil.composeConjunction(rexBuilder, nonEquiList);
-        return new PathJoinInfo(ImmutableIntList.copyOf(leftPathKeys),
-            ImmutableIntList.copyOf(rightPathKeys), newRemaining);
-    }
+  public JoinInfo analyzeCondition() {
+    RexBuilder rexBuilder = getCluster().getRexBuilder();
+    List leftPathKeys = new ArrayList<>();
+    List rightPathKeys = new ArrayList<>();
+    List nonEquiList = new ArrayList<>();
+    this.splitJoinCondition(condition, leftPathKeys, rightPathKeys, nonEquiList);
+    RexNode newRemaining = RexUtil.composeConjunction(rexBuilder, nonEquiList);
+    return new PathJoinInfo(
+        ImmutableIntList.copyOf(leftPathKeys),
+        ImmutableIntList.copyOf(rightPathKeys),
+        newRemaining);
+  }
 
-    @Override
-    public RelNode accept(RexShuttle shuttle) {
-        RexNode condition = shuttle.apply(this.condition);
-        if (this.condition == condition) {
-            return this;
-        }
-        return copy(traitSet, condition, left, right, joinType);
+  @Override
+  public RelNode accept(RexShuttle shuttle) {
+    RexNode condition = shuttle.apply(this.condition);
+    if (this.condition == condition) {
+      return this;
     }
+    return copy(traitSet, condition, left, right, joinType);
+  }
 
-    @Override
-    public RelWriter explainTerms(RelWriter pw) {
-        return super.explainTerms(pw)
-            .item("condition", condition)
-            .item("joinType", joinType.lowerName);
-    }
+  @Override
+  public RelWriter explainTerms(RelWriter pw) {
+    return super.explainTerms(pw).item("condition", condition).item("joinType", joinType.lowerName);
+  }
 
-    private static class PathJoinInfo extends JoinInfo {
-        public final RexNode remaining;
+  private static class PathJoinInfo extends JoinInfo {
+    public final RexNode remaining;
 
-        protected PathJoinInfo(ImmutableIntList leftKeys, ImmutableIntList rightKeys, RexNode remaining) {
-            super(leftKeys, rightKeys);
-            this.remaining = Objects.requireNonNull(remaining);
-        }
+    protected PathJoinInfo(
+        ImmutableIntList leftKeys, ImmutableIntList rightKeys, RexNode remaining) {
+      super(leftKeys, rightKeys);
+      this.remaining = Objects.requireNonNull(remaining);
+    }
 
-        @Override
-        public boolean isEqui() {
-            return remaining.isAlwaysTrue();
-        }
+    @Override
+    public boolean isEqui() {
+      return remaining.isAlwaysTrue();
+    }
 
-        @Override
-        public RexNode getRemaining(RexBuilder rexBuilder) {
-            return isEqui() ? rexBuilder.makeLiteral(true) : remaining;
-        }
+    @Override
+    public RexNode getRemaining(RexBuilder rexBuilder) {
+      return isEqui() ? rexBuilder.makeLiteral(true) : remaining;
     }
+  }
 
-    private void splitJoinCondition(RexNode condition, List leftKeys,
-                                    List rightKeys, List nonEquiList) {
-        if (condition instanceof RexCall) {
-            RexCall call = (RexCall) condition;
-            SqlKind kind = call.getKind();
-            if (kind == SqlKind.AND) {
-                for (RexNode operand : call.getOperands()) {
-                    splitJoinCondition(operand, leftKeys, rightKeys, nonEquiList);
-                }
-                return;
-            }
-            int leftFieldCount = ((IMatchNode) left).getPathSchema().getFieldCount();
-            if (kind == SqlKind.EQUALS) {
-                final List operands = call.getOperands();
-                if (GQLRexUtil.isVertexIdFieldAccess(operands.get(0))
-                    && GQLRexUtil.isVertexIdFieldAccess(operands.get(1))) {
-                    RexFieldAccess op0 = (RexFieldAccess) operands.get(0);
-                    RexFieldAccess op1 = (RexFieldAccess) operands.get(1);
-                    String op0PathLabel = ((PathInputRef) op0.getReferenceExpr()).getLabel();
-                    String op1PathLabel = ((PathInputRef) op1.getReferenceExpr()).getLabel();
-                    int op0Index = this.getPathSchema().getField(op0PathLabel, true, false).getIndex();
-                    int op1Index = this.getPathSchema().getField(op1PathLabel, true, false).getIndex();
-                    RelDataTypeField leftField;
-                    RelDataTypeField rightField;
-                    if (op0Index < leftFieldCount && op1Index >= leftFieldCount) {
-                        leftField = ((IMatchNode) left).getPathSchema().getFieldList().get(op0Index);
-                        rightField = ((IMatchNode) right).getPathSchema().getFieldList().get(op1Index - leftFieldCount);
-                    } else if (op1Index < leftFieldCount && op0Index >= leftFieldCount) {
-                        leftField = ((IMatchNode) left).getPathSchema().getFieldList().get(op1Index);
-                        rightField = ((IMatchNode) right).getPathSchema().getFieldList().get(op0Index - leftFieldCount);
-                    } else {
-                        nonEquiList.add(condition);
-                        return;
-                    }
-                    leftKeys.add(leftField.getIndex());
-                    rightKeys.add(rightField.getIndex());
-                    return;
-                }
-            }
+  private void splitJoinCondition(
+      RexNode condition,
+      List leftKeys,
+      List rightKeys,
+      List nonEquiList) {
+    if (condition instanceof RexCall) {
+      RexCall call = (RexCall) condition;
+      SqlKind kind = call.getKind();
+      if (kind == SqlKind.AND) {
+        for (RexNode operand : call.getOperands()) {
+          splitJoinCondition(operand, leftKeys, rightKeys, nonEquiList);
         }
-        if (!condition.isAlwaysTrue()) {
+        return;
+      }
+      int leftFieldCount = ((IMatchNode) left).getPathSchema().getFieldCount();
+      if (kind == SqlKind.EQUALS) {
+        final List operands = call.getOperands();
+        if (GQLRexUtil.isVertexIdFieldAccess(operands.get(0))
+            && GQLRexUtil.isVertexIdFieldAccess(operands.get(1))) {
+          RexFieldAccess op0 = (RexFieldAccess) operands.get(0);
+          RexFieldAccess op1 = (RexFieldAccess) operands.get(1);
+          String op0PathLabel = ((PathInputRef) op0.getReferenceExpr()).getLabel();
+          String op1PathLabel = ((PathInputRef) op1.getReferenceExpr()).getLabel();
+          int op0Index = this.getPathSchema().getField(op0PathLabel, true, false).getIndex();
+          int op1Index = this.getPathSchema().getField(op1PathLabel, true, false).getIndex();
+          RelDataTypeField leftField;
+          RelDataTypeField rightField;
+          if (op0Index < leftFieldCount && op1Index >= leftFieldCount) {
+            leftField = ((IMatchNode) left).getPathSchema().getFieldList().get(op0Index);
+            rightField =
+                ((IMatchNode) right).getPathSchema().getFieldList().get(op1Index - leftFieldCount);
+          } else if (op1Index < leftFieldCount && op0Index >= leftFieldCount) {
+            leftField = ((IMatchNode) left).getPathSchema().getFieldList().get(op1Index);
+            rightField =
+                ((IMatchNode) right).getPathSchema().getFieldList().get(op0Index - leftFieldCount);
+          } else {
             nonEquiList.add(condition);
+            return;
+          }
+          leftKeys.add(leftField.getIndex());
+          rightKeys.add(rightField.getIndex());
+          return;
         }
+      }
+    }
+    if (!condition.isAlwaysTrue()) {
+      nonEquiList.add(condition);
     }
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchPathModify.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchPathModify.java
index 8060f8ac6..e9ebe2325 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchPathModify.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchPathModify.java
@@ -20,6 +20,7 @@
 package org.apache.geaflow.dsl.rel.match;
 
 import java.util.List;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
@@ -32,44 +33,53 @@
 
 public class MatchPathModify extends PathModify implements SingleMatchNode {
 
-    protected MatchPathModify(RelOptCluster cluster, RelTraitSet traits,
-                              RelNode input,
-                              List expressions, RelDataType rowType,
-                              GraphRecordType modifyGraphType) {
-        super(cluster, traits, input, expressions, rowType, modifyGraphType);
-    }
+  protected MatchPathModify(
+      RelOptCluster cluster,
+      RelTraitSet traits,
+      RelNode input,
+      List expressions,
+      RelDataType rowType,
+      GraphRecordType modifyGraphType) {
+    super(cluster, traits, input, expressions, rowType, modifyGraphType);
+  }
 
-    @Override
-    public PathModify copy(RelTraitSet traitSet, RelNode input, List expressions,
-                           RelDataType rowType) {
-        return new MatchPathModify(getCluster(), traitSet, input, expressions, rowType, modifyGraphType);
-    }
+  @Override
+  public PathModify copy(
+      RelTraitSet traitSet,
+      RelNode input,
+      List expressions,
+      RelDataType rowType) {
+    return new MatchPathModify(
+        getCluster(), traitSet, input, expressions, rowType, modifyGraphType);
+  }
 
-    @Override
-    public PathRecordType getPathSchema() {
-        return (PathRecordType) getRowType();
-    }
+  @Override
+  public PathRecordType getPathSchema() {
+    return (PathRecordType) getRowType();
+  }
 
-    @Override
-    public RelDataType getNodeType() {
-        return ((IMatchNode) GQLRelUtil.toRel(getInput())).getNodeType();
-    }
+  @Override
+  public RelDataType getNodeType() {
+    return ((IMatchNode) GQLRelUtil.toRel(getInput())).getNodeType();
+  }
 
-    @Override
-    public  T accept(MatchNodeVisitor visitor) {
-        return visitor.visitPathModify(this);
-    }
+  @Override
+  public  T accept(MatchNodeVisitor visitor) {
+    return visitor.visitPathModify(this);
+  }
 
-    @Override
-    public IMatchNode copy(List inputs, PathRecordType pathSchema) {
-        return new MatchPathModify(getCluster(), getTraitSet(), sole(inputs), expressions,
-            pathSchema, modifyGraphType);
-    }
+  @Override
+  public IMatchNode copy(List inputs, PathRecordType pathSchema) {
+    return new MatchPathModify(
+        getCluster(), getTraitSet(), sole(inputs), expressions, pathSchema, modifyGraphType);
+  }
 
-    public static MatchPathModify create(RelNode input, List expressions,
-                                         RelDataType rowType,
-                                         GraphRecordType modifyGraphType) {
-        return new MatchPathModify(input.getCluster(), input.getTraitSet(), input,
-            expressions, rowType, modifyGraphType);
-    }
+  public static MatchPathModify create(
+      RelNode input,
+      List expressions,
+      RelDataType rowType,
+      GraphRecordType modifyGraphType) {
+    return new MatchPathModify(
+        input.getCluster(), input.getTraitSet(), input, expressions, rowType, modifyGraphType);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchPathSort.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchPathSort.java
index 914777132..0b0e0ca23 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchPathSort.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchPathSort.java
@@ -20,6 +20,7 @@
 package org.apache.geaflow.dsl.rel.match;
 
 import java.util.List;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
@@ -32,44 +33,46 @@
 
 public class MatchPathSort extends PathSort implements SingleMatchNode {
 
-    protected MatchPathSort(RelOptCluster cluster, RelTraitSet traits,
-                            RelNode input, List orderByExpressions,
-                            RexNode limit, PathRecordType pathType) {
-        super(cluster, traits, input, orderByExpressions, limit, pathType);
-    }
+  protected MatchPathSort(
+      RelOptCluster cluster,
+      RelTraitSet traits,
+      RelNode input,
+      List orderByExpressions,
+      RexNode limit,
+      PathRecordType pathType) {
+    super(cluster, traits, input, orderByExpressions, limit, pathType);
+  }
 
-    @Override
-    public PathSort copy(RelNode input, List orderByCollation,
-                         RexNode fetch, PathRecordType pathType) {
-        return new MatchPathSort(getCluster(), getTraitSet(), input, orderByCollation,
-            limit, pathType);
-    }
+  @Override
+  public PathSort copy(
+      RelNode input, List orderByCollation, RexNode fetch, PathRecordType pathType) {
+    return new MatchPathSort(getCluster(), getTraitSet(), input, orderByCollation, limit, pathType);
+  }
 
-    @Override
-    public PathRecordType getPathSchema() {
-        return (PathRecordType) rowType;
-    }
+  @Override
+  public PathRecordType getPathSchema() {
+    return (PathRecordType) rowType;
+  }
 
-    @Override
-    public RelDataType getNodeType() {
-        return ((IMatchNode) GQLRelUtil.toRel(getInput())).getNodeType();
-    }
+  @Override
+  public RelDataType getNodeType() {
+    return ((IMatchNode) GQLRelUtil.toRel(getInput())).getNodeType();
+  }
 
-    @Override
-    public  T accept(MatchNodeVisitor visitor) {
-        return visitor.visitSort(this);
-    }
+  @Override
+  public  T accept(MatchNodeVisitor visitor) {
+    return visitor.visitSort(this);
+  }
 
-    @Override
-    public IMatchNode copy(List inputs, PathRecordType pathType) {
-        return new MatchPathSort(getCluster(), getTraitSet(), sole(inputs),
-            orderByExpressions, limit, pathType);
-    }
+  @Override
+  public IMatchNode copy(List inputs, PathRecordType pathType) {
+    return new MatchPathSort(
+        getCluster(), getTraitSet(), sole(inputs), orderByExpressions, limit, pathType);
+  }
 
-    public static MatchPathSort create(RelNode input, List orderByExpressions,
-                                       RexNode limit,
-                                       PathRecordType pathType) {
-        return new MatchPathSort(input.getCluster(), input.getTraitSet(),
-            input, orderByExpressions, limit, pathType);
-    }
+  public static MatchPathSort create(
+      RelNode input, List orderByExpressions, RexNode limit, PathRecordType pathType) {
+    return new MatchPathSort(
+        input.getCluster(), input.getTraitSet(), input, orderByExpressions, limit, pathType);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchUnion.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchUnion.java
index bfcb4d1b1..5ecbcc67f 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchUnion.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/MatchUnion.java
@@ -21,6 +21,7 @@
 
 import java.util.List;
 import java.util.stream.Collectors;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
@@ -35,46 +36,47 @@
 
 public class MatchUnion extends Union implements IMatchNode {
 
-    protected MatchUnion(RelOptCluster cluster, RelTraitSet traits,
-                         List inputs, boolean all) {
-        super(cluster, traits, ArrayUtil.castList(inputs), all);
-    }
+  protected MatchUnion(
+      RelOptCluster cluster, RelTraitSet traits, List inputs, boolean all) {
+    super(cluster, traits, ArrayUtil.castList(inputs), all);
+  }
 
-    @Override
-    public SetOp copy(RelTraitSet traitSet, List inputs, boolean all) {
-        return new MatchUnion(getCluster(), traitSet, ArrayUtil.castList(inputs), all);
-    }
+  @Override
+  public SetOp copy(RelTraitSet traitSet, List inputs, boolean all) {
+    return new MatchUnion(getCluster(), traitSet, ArrayUtil.castList(inputs), all);
+  }
 
-    public static MatchUnion create(RelOptCluster cluster, RelTraitSet traits,
-                                    List inputs, boolean all) {
-        return new MatchUnion(cluster, traits, inputs, all);
-    }
+  public static MatchUnion create(
+      RelOptCluster cluster, RelTraitSet traits, List inputs, boolean all) {
+    return new MatchUnion(cluster, traits, inputs, all);
+  }
 
-    @Override
-    protected RelDataType deriveRowType() {
-        List inputPathTypes = inputs.stream()
+  @Override
+  protected RelDataType deriveRowType() {
+    List inputPathTypes =
+        inputs.stream()
             .map(input -> ((IMatchNode) GQLRelUtil.toRel(input)).getPathSchema())
             .collect(Collectors.toList());
-        return new UnionPathRecordType(inputPathTypes, getCluster().getTypeFactory());
-    }
+    return new UnionPathRecordType(inputPathTypes, getCluster().getTypeFactory());
+  }
 
-    @Override
-    public PathRecordType getPathSchema() {
-        return (PathRecordType) getRowType();
-    }
+  @Override
+  public PathRecordType getPathSchema() {
+    return (PathRecordType) getRowType();
+  }
 
-    @Override
-    public RelDataType getNodeType() {
-        return getPathSchema();
-    }
+  @Override
+  public RelDataType getNodeType() {
+    return getPathSchema();
+  }
 
-    @Override
-    public IMatchNode copy(List inputs, PathRecordType pathType) {
-        return new MatchUnion(getCluster(), getTraitSet(), inputs, all);
-    }
+  @Override
+  public IMatchNode copy(List inputs, PathRecordType pathType) {
+    return new MatchUnion(getCluster(), getTraitSet(), inputs, all);
+  }
 
-    @Override
-    public  T accept(MatchNodeVisitor visitor) {
-        return visitor.visitUnion(this);
-    }
+  @Override
+  public  T accept(MatchNodeVisitor visitor) {
+    return visitor.visitUnion(this);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/OptionalEdgeMatch.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/OptionalEdgeMatch.java
index 8de656e6f..20a0c096d 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/OptionalEdgeMatch.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/OptionalEdgeMatch.java
@@ -19,10 +19,10 @@
 
 package org.apache.geaflow.dsl.rel.match;
 
-
 import java.util.Collection;
 import java.util.HashSet;
 import java.util.List;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
@@ -33,54 +33,82 @@
 import org.apache.geaflow.dsl.rel.MatchNodeVisitor;
 import org.apache.geaflow.dsl.sqlnode.SqlMatchEdge.EdgeDirection;
 
-
 public class OptionalEdgeMatch extends EdgeMatch {
 
-    private OptionalEdgeMatch(RelOptCluster cluster, RelTraitSet traitSet,
-                              RelNode input, String label,
-                              Collection edgeTypes, EdgeDirection direction,
-                              RelDataType nodeType, PathRecordType pathType) {
-        super(cluster, traitSet, input, label, edgeTypes, direction, nodeType, pathType, new HashSet<>());
-    }
+  private OptionalEdgeMatch(
+      RelOptCluster cluster,
+      RelTraitSet traitSet,
+      RelNode input,
+      String label,
+      Collection edgeTypes,
+      EdgeDirection direction,
+      RelDataType nodeType,
+      PathRecordType pathType) {
+    super(
+        cluster, traitSet, input, label, edgeTypes, direction, nodeType, pathType, new HashSet<>());
+  }
 
-    @Override
-    public IMatchNode copy(List inputs, PathRecordType pathSchema) {
-        assert inputs.size() == 1;
-        return new OptionalEdgeMatch(getCluster(), getTraitSet(), sole(inputs),
-            getLabel(), getTypes(), getDirection(), getNodeType(), pathSchema);
-    }
+  @Override
+  public IMatchNode copy(List inputs, PathRecordType pathSchema) {
+    assert inputs.size() == 1;
+    return new OptionalEdgeMatch(
+        getCluster(),
+        getTraitSet(),
+        sole(inputs),
+        getLabel(),
+        getTypes(),
+        getDirection(),
+        getNodeType(),
+        pathSchema);
+  }
 
-    @Override
-    public OptionalEdgeMatch copy(RelTraitSet traitSet, List inputs) {
-        assert inputs.size() == 1;
-        return new OptionalEdgeMatch(getCluster(), getTraitSet(), sole(inputs),
-            getLabel(), getTypes(), getDirection(), getNodeType(), getPathSchema());
-    }
+  @Override
+  public OptionalEdgeMatch copy(RelTraitSet traitSet, List inputs) {
+    assert inputs.size() == 1;
+    return new OptionalEdgeMatch(
+        getCluster(),
+        getTraitSet(),
+        sole(inputs),
+        getLabel(),
+        getTypes(),
+        getDirection(),
+        getNodeType(),
+        getPathSchema());
+  }
 
-    @Override
-    public RelNode accept(RexShuttle shuttle) {
-        EdgeMatch newEdgeMatch = (EdgeMatch) super.accept(shuttle);
-        return new OptionalEdgeMatch(
-            newEdgeMatch.getCluster(), newEdgeMatch.getTraitSet(), newEdgeMatch.getInput(),
-            newEdgeMatch.getLabel(), newEdgeMatch.getTypes(), newEdgeMatch.getDirection(),
-            newEdgeMatch.getNodeType(), newEdgeMatch.getPathSchema());
-    }
+  @Override
+  public RelNode accept(RexShuttle shuttle) {
+    EdgeMatch newEdgeMatch = (EdgeMatch) super.accept(shuttle);
+    return new OptionalEdgeMatch(
+        newEdgeMatch.getCluster(),
+        newEdgeMatch.getTraitSet(),
+        newEdgeMatch.getInput(),
+        newEdgeMatch.getLabel(),
+        newEdgeMatch.getTypes(),
+        newEdgeMatch.getDirection(),
+        newEdgeMatch.getNodeType(),
+        newEdgeMatch.getPathSchema());
+  }
 
-    @Override
-    public RelWriter explainTerms(RelWriter pw) {
-        return super.explainTerms(pw);
-    }
+  @Override
+  public RelWriter explainTerms(RelWriter pw) {
+    return super.explainTerms(pw);
+  }
 
-    public static OptionalEdgeMatch create(RelOptCluster cluster, SingleMatchNode input,
-                                           String label, List edgeTypes,
-                                           EdgeDirection direction, RelDataType nodeType,
-                                           PathRecordType pathType) {
-        return new OptionalEdgeMatch(cluster, cluster.traitSet(), input, label, edgeTypes,
-            direction, nodeType, pathType);
-    }
+  public static OptionalEdgeMatch create(
+      RelOptCluster cluster,
+      SingleMatchNode input,
+      String label,
+      List edgeTypes,
+      EdgeDirection direction,
+      RelDataType nodeType,
+      PathRecordType pathType) {
+    return new OptionalEdgeMatch(
+        cluster, cluster.traitSet(), input, label, edgeTypes, direction, nodeType, pathType);
+  }
 
-    @Override
-    public  T accept(MatchNodeVisitor visitor) {
-        return visitor.visitEdgeMatch(this);
-    }
+  @Override
+  public  T accept(MatchNodeVisitor visitor) {
+    return visitor.visitEdgeMatch(this);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/OptionalVertexMatch.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/OptionalVertexMatch.java
index 97025a610..649f4e76e 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/OptionalVertexMatch.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/OptionalVertexMatch.java
@@ -21,6 +21,7 @@
 
 import java.util.Collection;
 import java.util.List;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
@@ -32,46 +33,83 @@
 
 public class OptionalVertexMatch extends VertexMatch {
 
-    public OptionalVertexMatch(RelOptCluster cluster, RelTraitSet traitSet, RelNode input,
-                               String label, Collection vertexTypes, RelDataType nodeType,
-                               PathRecordType pathType) {
-        super(cluster, traitSet, input, label, vertexTypes, nodeType, pathType, null);
-    }
-
-    public OptionalVertexMatch(RelOptCluster cluster, RelTraitSet traitSet, RelNode input,
-                               String label, Collection vertexTypes, RelDataType nodeType,
-                               PathRecordType pathType, RexNode pushDownFilter) {
-        super(cluster, traitSet, input, label, vertexTypes, nodeType, pathType, pushDownFilter);
-    }
+  public OptionalVertexMatch(
+      RelOptCluster cluster,
+      RelTraitSet traitSet,
+      RelNode input,
+      String label,
+      Collection vertexTypes,
+      RelDataType nodeType,
+      PathRecordType pathType) {
+    super(cluster, traitSet, input, label, vertexTypes, nodeType, pathType, null);
+  }
 
-    @Override
-    public SingleMatchNode copy(List inputs, PathRecordType pathSchema) {
-        assert inputs.size() <= 1;
-        RelNode input = inputs.isEmpty() ? null : inputs.get(0);
-        return new OptionalVertexMatch(getCluster(), traitSet, input, getLabel(),
-            getTypes(), getNodeType(), pathSchema, getPushDownFilter());
-    }
+  public OptionalVertexMatch(
+      RelOptCluster cluster,
+      RelTraitSet traitSet,
+      RelNode input,
+      String label,
+      Collection vertexTypes,
+      RelDataType nodeType,
+      PathRecordType pathType,
+      RexNode pushDownFilter) {
+    super(cluster, traitSet, input, label, vertexTypes, nodeType, pathType, pushDownFilter);
+  }
 
-    @Override
-    public OptionalVertexMatch copy(RelTraitSet traitSet, List inputs) {
-        RelNode input = GQLRelUtil.oneInput(inputs);
-        return new OptionalVertexMatch(getCluster(), getTraitSet(), input,
-            getLabel(), getTypes(), getNodeType(), getPathSchema(), getPushDownFilter());
-    }
+  @Override
+  public SingleMatchNode copy(List inputs, PathRecordType pathSchema) {
+    assert inputs.size() <= 1;
+    RelNode input = inputs.isEmpty() ? null : inputs.get(0);
+    return new OptionalVertexMatch(
+        getCluster(),
+        traitSet,
+        input,
+        getLabel(),
+        getTypes(),
+        getNodeType(),
+        pathSchema,
+        getPushDownFilter());
+  }
 
-    public OptionalVertexMatch copy(RexNode pushDownFilter) {
-        return new OptionalVertexMatch(getCluster(), getTraitSet(), getInput(),
-            getLabel(), getTypes(), getNodeType(), getPathSchema(), pushDownFilter);
-    }
+  @Override
+  public OptionalVertexMatch copy(RelTraitSet traitSet, List inputs) {
+    RelNode input = GQLRelUtil.oneInput(inputs);
+    return new OptionalVertexMatch(
+        getCluster(),
+        getTraitSet(),
+        input,
+        getLabel(),
+        getTypes(),
+        getNodeType(),
+        getPathSchema(),
+        getPushDownFilter());
+  }
 
-    @Override
-    public  T accept(MatchNodeVisitor visitor) {
-        return visitor.visitVertexMatch(this);
-    }
+  public OptionalVertexMatch copy(RexNode pushDownFilter) {
+    return new OptionalVertexMatch(
+        getCluster(),
+        getTraitSet(),
+        getInput(),
+        getLabel(),
+        getTypes(),
+        getNodeType(),
+        getPathSchema(),
+        pushDownFilter);
+  }
 
+  @Override
+  public  T accept(MatchNodeVisitor visitor) {
+    return visitor.visitVertexMatch(this);
+  }
 
-    public static OptionalVertexMatch create(RelOptCluster cluster, SingleMatchNode input, String label,
-                                             List vertexTypes, RelDataType nodeType, PathRecordType pathType) {
-        return new OptionalVertexMatch(cluster, cluster.traitSet(), input, label, vertexTypes, nodeType, pathType);
-    }
+  public static OptionalVertexMatch create(
+      RelOptCluster cluster,
+      SingleMatchNode input,
+      String label,
+      List vertexTypes,
+      RelDataType nodeType,
+      PathRecordType pathType) {
+    return new OptionalVertexMatch(
+        cluster, cluster.traitSet(), input, label, vertexTypes, nodeType, pathType);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/SingleMatchNode.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/SingleMatchNode.java
index 9e3fcbaa5..6465e5eee 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/SingleMatchNode.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/SingleMatchNode.java
@@ -20,19 +20,20 @@
 package org.apache.geaflow.dsl.rel.match;
 
 import java.util.function.Predicate;
+
 import org.apache.calcite.rel.RelNode;
 
 public interface SingleMatchNode extends IMatchNode {
 
-    RelNode getInput();
+  RelNode getInput();
 
-    default SingleMatchNode find(Predicate condition) {
-        if (condition.test(this)) {
-            return this;
-        }
-        if (this.getInput() == null) {
-            return null;
-        }
-        return ((SingleMatchNode) this.getInput()).find(condition);
+  default SingleMatchNode find(Predicate condition) {
+    if (condition.test(this)) {
+      return this;
+    }
+    if (this.getInput() == null) {
+      return null;
     }
+    return ((SingleMatchNode) this.getInput()).find(condition);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/SubQueryStart.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/SubQueryStart.java
index 36bcfda6a..76ef09025 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/SubQueryStart.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/SubQueryStart.java
@@ -22,6 +22,7 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.AbstractRelNode;
@@ -33,65 +34,71 @@
 
 public class SubQueryStart extends AbstractRelNode implements SingleMatchNode {
 
-    private final String queryName;
-
-    private final PathRecordType pathType;
-
-    private final VertexRecordType vertexType;
-
-    protected SubQueryStart(RelOptCluster cluster, RelTraitSet traitSet, String queryName,
-                            PathRecordType pathType, VertexRecordType vertexType) {
-        super(cluster, traitSet);
-        this.queryName = Objects.requireNonNull(queryName);
-        this.pathType = pathType;
-        this.rowType = pathType;
-        this.vertexType = vertexType;
-    }
-
-
-    @Override
-    public PathRecordType getPathSchema() {
-        return pathType;
-    }
-
-    @Override
-    public RelDataType getNodeType() {
-        return vertexType;
-    }
-
-    @Override
-    public  T accept(MatchNodeVisitor visitor) {
-        return visitor.visitSubQueryStart(this);
-    }
-
-
-    @Override
-    public RelNode getInput() {
-        return null;
-    }
-
-    @Override
-    public List getInputs() {
-        return Collections.emptyList();
-    }
-
-    @Override
-    public IMatchNode copy(List inputs, PathRecordType pathSchema) {
-        return new SubQueryStart(getCluster(), getTraitSet(), queryName, pathSchema, vertexType);
-    }
-
-    @Override
-    public SubQueryStart copy(RelTraitSet traitSet, List inputs) {
-        assert inputs.size() == 0;
-        return new SubQueryStart(getCluster(), traitSet, queryName, pathType, vertexType);
-    }
-
-    public String getQueryName() {
-        return queryName;
-    }
-
-    public static SubQueryStart create(RelOptCluster cluster, RelTraitSet traitSet, String queryName,
-                                       PathRecordType parentPathType, VertexRecordType vertexType) {
-        return new SubQueryStart(cluster, traitSet, queryName, parentPathType, vertexType);
-    }
+  private final String queryName;
+
+  private final PathRecordType pathType;
+
+  private final VertexRecordType vertexType;
+
+  protected SubQueryStart(
+      RelOptCluster cluster,
+      RelTraitSet traitSet,
+      String queryName,
+      PathRecordType pathType,
+      VertexRecordType vertexType) {
+    super(cluster, traitSet);
+    this.queryName = Objects.requireNonNull(queryName);
+    this.pathType = pathType;
+    this.rowType = pathType;
+    this.vertexType = vertexType;
+  }
+
+  @Override
+  public PathRecordType getPathSchema() {
+    return pathType;
+  }
+
+  @Override
+  public RelDataType getNodeType() {
+    return vertexType;
+  }
+
+  @Override
+  public  T accept(MatchNodeVisitor visitor) {
+    return visitor.visitSubQueryStart(this);
+  }
+
+  @Override
+  public RelNode getInput() {
+    return null;
+  }
+
+  @Override
+  public List getInputs() {
+    return Collections.emptyList();
+  }
+
+  @Override
+  public IMatchNode copy(List inputs, PathRecordType pathSchema) {
+    return new SubQueryStart(getCluster(), getTraitSet(), queryName, pathSchema, vertexType);
+  }
+
+  @Override
+  public SubQueryStart copy(RelTraitSet traitSet, List inputs) {
+    assert inputs.size() == 0;
+    return new SubQueryStart(getCluster(), traitSet, queryName, pathType, vertexType);
+  }
+
+  public String getQueryName() {
+    return queryName;
+  }
+
+  public static SubQueryStart create(
+      RelOptCluster cluster,
+      RelTraitSet traitSet,
+      String queryName,
+      PathRecordType parentPathType,
+      VertexRecordType vertexType) {
+    return new SubQueryStart(cluster, traitSet, queryName, parentPathType, vertexType);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/VertexMatch.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/VertexMatch.java
index b6f42b997..79acdfa62 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/VertexMatch.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/VertexMatch.java
@@ -21,9 +21,8 @@
 
 import static org.apache.geaflow.dsl.util.GQLRelUtil.match;
 
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableSet;
 import java.util.*;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.AbstractRelNode;
@@ -39,175 +38,251 @@
 import org.apache.geaflow.dsl.rel.MatchNodeVisitor;
 import org.apache.geaflow.dsl.util.GQLRelUtil;
 
-public class VertexMatch extends AbstractRelNode implements SingleMatchNode, IMatchLabel {
-
-    private RelNode input;
-
-    private final String label;
-
-    private Set pushDownFields;
-
-    private final ImmutableSet vertexTypes;
-
-    private final PathRecordType pathType;
-
-    private final RelDataType nodeType;
-
-    /**
-     * The filter pushed down to the first vertex match.
-     */
-    private RexNode pushDownFilter;
-
-    private Set idSet;
-
-    public VertexMatch(RelOptCluster cluster, RelTraitSet traitSet, RelNode input,
-                       String label, Collection vertexTypes, RelDataType nodeType,
-                       PathRecordType pathType) {
-        this(cluster, traitSet, input, label, vertexTypes, nodeType, pathType, null);
-    }
-
-    public VertexMatch(RelOptCluster cluster, RelTraitSet traitSet, RelNode input,
-                       String label, Collection vertexTypes, RelDataType nodeType,
-                       PathRecordType pathType, RexNode pushDownFilter) {
-        this(cluster, traitSet, input, label, vertexTypes, nodeType, pathType, pushDownFilter,
-            new HashSet<>(), null);
-    }
-
-    public VertexMatch(RelOptCluster cluster, RelTraitSet traitSet, RelNode input,
-                       String label, Collection vertexTypes, RelDataType nodeType,
-                       PathRecordType pathType, RexNode pushDownFilter, Set idSet,
-                       Set pushDownFields) {
-        super(cluster, traitSet);
-        this.input = input;
-        this.label = label;
-        this.vertexTypes = ImmutableSet.copyOf(vertexTypes);
-
-        if (input != null && !(GQLRelUtil.toRel(input) instanceof SubQueryStart)
-            && match(input).getNodeType().getSqlTypeName() != SqlTypeName.EDGE) {
-            throw new GeaFlowDSLException("Illegal input type: " + match(input).getNodeType().getSqlTypeName()
-                + " for: " + getRelTypeName() + ", should be: " + SqlTypeName.EDGE);
-        }
-        this.rowType = Objects.requireNonNull(pathType);
-        this.pathType = Objects.requireNonNull(pathType);
-        this.nodeType = Objects.requireNonNull(nodeType);
-        this.pushDownFilter = pushDownFilter;
-        this.idSet = idSet;
-        this.pushDownFields = pushDownFields;
-    }
-
-    public void addField(RexFieldAccess field) {
-        if (pushDownFields == null) {
-            pushDownFields = new HashSet<>();
-        }
-        pushDownFields.add(field);
-    }
-
-    @Override
-    public String getLabel() {
-        return label;
-    }
-
-    @Override
-    public Set getTypes() {
-        return vertexTypes;
-    }
-
-    @Override
-    public List getInputs() {
-        if (input == null) {
-            return Collections.emptyList();
-        }
-        return ImmutableList.of(input);
-    }
-
-    @Override
-    public RelNode getInput() {
-        return input;
-    }
-
-    public RexNode getPushDownFilter() {
-        return pushDownFilter;
-    }
-
-    public Set getIdSet() {
-        return idSet;
-    }
-
-    public Set getFields() {
-        return pushDownFields;
-    }
-
-    @Override
-    public SingleMatchNode copy(List inputs, PathRecordType pathSchema) {
-        assert inputs.size() <= 1;
-        RelNode input = inputs.isEmpty() ? null : inputs.get(0);
-        return new VertexMatch(getCluster(), traitSet, input, label,
-            vertexTypes, nodeType, pathSchema, pushDownFilter, idSet, pushDownFields);
-    }
-
-    @Override
-    public VertexMatch copy(RelTraitSet traitSet, List inputs) {
-        RelNode input = GQLRelUtil.oneInput(inputs);
-        return new VertexMatch(getCluster(), getTraitSet(), input,
-            label, vertexTypes, nodeType, pathType, pushDownFilter, idSet, pushDownFields);
-    }
-
-    public VertexMatch copy(RexNode pushDownFilter) {
-        return new VertexMatch(getCluster(), getTraitSet(), input,
-            label, vertexTypes, nodeType, pathType, pushDownFilter, idSet, pushDownFields);
-    }
-
-    public VertexMatch copy(Set idSet) {
-        return new VertexMatch(getCluster(), getTraitSet(), input,
-            label, vertexTypes, nodeType, pathType, pushDownFilter, idSet, pushDownFields);
-    }
-
-    @Override
-    public RelWriter explainTerms(RelWriter pw) {
-        return super.explainTerms(pw)
-            .item("input", input)
-            .item("label", label)
-            .item("vertexTypes", vertexTypes)
-            .item("idSet", idSet);
-    }
-
-    @Override
-    public void replaceInput(int ordinalInParent, RelNode p) {
-        assert ordinalInParent == 0;
-        this.input = p;
-    }
-
-    @Override
-    protected RelDataType deriveRowType() {
-        throw new UnsupportedOperationException();
-    }
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
 
-    public static VertexMatch create(RelOptCluster cluster, SingleMatchNode input, String label,
-                                     List vertexTypes, RelDataType nodeType, PathRecordType pathType) {
-        return new VertexMatch(cluster, cluster.traitSet(), input, label, vertexTypes, nodeType, pathType);
-    }
+public class VertexMatch extends AbstractRelNode implements SingleMatchNode, IMatchLabel {
 
-    @Override
-    public PathRecordType getPathSchema() {
-        return pathType;
+  private RelNode input;
+
+  private final String label;
+
+  private Set pushDownFields;
+
+  private final ImmutableSet vertexTypes;
+
+  private final PathRecordType pathType;
+
+  private final RelDataType nodeType;
+
+  /** The filter pushed down to the first vertex match. */
+  private RexNode pushDownFilter;
+
+  private Set idSet;
+
+  public VertexMatch(
+      RelOptCluster cluster,
+      RelTraitSet traitSet,
+      RelNode input,
+      String label,
+      Collection vertexTypes,
+      RelDataType nodeType,
+      PathRecordType pathType) {
+    this(cluster, traitSet, input, label, vertexTypes, nodeType, pathType, null);
+  }
+
+  public VertexMatch(
+      RelOptCluster cluster,
+      RelTraitSet traitSet,
+      RelNode input,
+      String label,
+      Collection vertexTypes,
+      RelDataType nodeType,
+      PathRecordType pathType,
+      RexNode pushDownFilter) {
+    this(
+        cluster,
+        traitSet,
+        input,
+        label,
+        vertexTypes,
+        nodeType,
+        pathType,
+        pushDownFilter,
+        new HashSet<>(),
+        null);
+  }
+
+  public VertexMatch(
+      RelOptCluster cluster,
+      RelTraitSet traitSet,
+      RelNode input,
+      String label,
+      Collection vertexTypes,
+      RelDataType nodeType,
+      PathRecordType pathType,
+      RexNode pushDownFilter,
+      Set idSet,
+      Set pushDownFields) {
+    super(cluster, traitSet);
+    this.input = input;
+    this.label = label;
+    this.vertexTypes = ImmutableSet.copyOf(vertexTypes);
+
+    if (input != null
+        && !(GQLRelUtil.toRel(input) instanceof SubQueryStart)
+        && match(input).getNodeType().getSqlTypeName() != SqlTypeName.EDGE) {
+      throw new GeaFlowDSLException(
+          "Illegal input type: "
+              + match(input).getNodeType().getSqlTypeName()
+              + " for: "
+              + getRelTypeName()
+              + ", should be: "
+              + SqlTypeName.EDGE);
     }
-
-    @Override
-    public RelDataType getNodeType() {
-        return nodeType;
+    this.rowType = Objects.requireNonNull(pathType);
+    this.pathType = Objects.requireNonNull(pathType);
+    this.nodeType = Objects.requireNonNull(nodeType);
+    this.pushDownFilter = pushDownFilter;
+    this.idSet = idSet;
+    this.pushDownFields = pushDownFields;
+  }
+
+  public void addField(RexFieldAccess field) {
+    if (pushDownFields == null) {
+      pushDownFields = new HashSet<>();
     }
-
-    @Override
-    public  T accept(MatchNodeVisitor visitor) {
-        return visitor.visitVertexMatch(this);
+    pushDownFields.add(field);
+  }
+
+  @Override
+  public String getLabel() {
+    return label;
+  }
+
+  @Override
+  public Set getTypes() {
+    return vertexTypes;
+  }
+
+  @Override
+  public List getInputs() {
+    if (input == null) {
+      return Collections.emptyList();
     }
-
-    @Override
-    public RelNode accept(RexShuttle shuttle) {
-        if (pushDownFilter != null) {
-            RexNode newPushDownFilter = pushDownFilter.accept(shuttle);
-            return copy(newPushDownFilter);
-        }
-        return this;
+    return ImmutableList.of(input);
+  }
+
+  @Override
+  public RelNode getInput() {
+    return input;
+  }
+
+  public RexNode getPushDownFilter() {
+    return pushDownFilter;
+  }
+
+  public Set getIdSet() {
+    return idSet;
+  }
+
+  public Set getFields() {
+    return pushDownFields;
+  }
+
+  @Override
+  public SingleMatchNode copy(List inputs, PathRecordType pathSchema) {
+    assert inputs.size() <= 1;
+    RelNode input = inputs.isEmpty() ? null : inputs.get(0);
+    return new VertexMatch(
+        getCluster(),
+        traitSet,
+        input,
+        label,
+        vertexTypes,
+        nodeType,
+        pathSchema,
+        pushDownFilter,
+        idSet,
+        pushDownFields);
+  }
+
+  @Override
+  public VertexMatch copy(RelTraitSet traitSet, List inputs) {
+    RelNode input = GQLRelUtil.oneInput(inputs);
+    return new VertexMatch(
+        getCluster(),
+        getTraitSet(),
+        input,
+        label,
+        vertexTypes,
+        nodeType,
+        pathType,
+        pushDownFilter,
+        idSet,
+        pushDownFields);
+  }
+
+  public VertexMatch copy(RexNode pushDownFilter) {
+    return new VertexMatch(
+        getCluster(),
+        getTraitSet(),
+        input,
+        label,
+        vertexTypes,
+        nodeType,
+        pathType,
+        pushDownFilter,
+        idSet,
+        pushDownFields);
+  }
+
+  public VertexMatch copy(Set idSet) {
+    return new VertexMatch(
+        getCluster(),
+        getTraitSet(),
+        input,
+        label,
+        vertexTypes,
+        nodeType,
+        pathType,
+        pushDownFilter,
+        idSet,
+        pushDownFields);
+  }
+
+  @Override
+  public RelWriter explainTerms(RelWriter pw) {
+    return super.explainTerms(pw)
+        .item("input", input)
+        .item("label", label)
+        .item("vertexTypes", vertexTypes)
+        .item("idSet", idSet);
+  }
+
+  @Override
+  public void replaceInput(int ordinalInParent, RelNode p) {
+    assert ordinalInParent == 0;
+    this.input = p;
+  }
+
+  @Override
+  protected RelDataType deriveRowType() {
+    throw new UnsupportedOperationException();
+  }
+
+  public static VertexMatch create(
+      RelOptCluster cluster,
+      SingleMatchNode input,
+      String label,
+      List vertexTypes,
+      RelDataType nodeType,
+      PathRecordType pathType) {
+    return new VertexMatch(
+        cluster, cluster.traitSet(), input, label, vertexTypes, nodeType, pathType);
+  }
+
+  @Override
+  public PathRecordType getPathSchema() {
+    return pathType;
+  }
+
+  @Override
+  public RelDataType getNodeType() {
+    return nodeType;
+  }
+
+  @Override
+  public  T accept(MatchNodeVisitor visitor) {
+    return visitor.visitVertexMatch(this);
+  }
+
+  @Override
+  public RelNode accept(RexShuttle shuttle) {
+    if (pushDownFilter != null) {
+      RexNode newPushDownFilter = pushDownFilter.accept(shuttle);
+      return copy(newPushDownFilter);
     }
+    return this;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/VirtualEdgeMatch.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/VirtualEdgeMatch.java
index f8ae51d59..1c95f6086 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/VirtualEdgeMatch.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rel/match/VirtualEdgeMatch.java
@@ -23,6 +23,7 @@
 
 import java.util.List;
 import java.util.Objects;
+
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelNode;
@@ -39,83 +40,93 @@
 
 public class VirtualEdgeMatch extends SingleRel implements SingleMatchNode, IMatchNode {
 
-    private final RexNode targetIdExpression;
-
-    private final PathRecordType pathType;
-
-    private final RelDataType nodeType;
-
-    public VirtualEdgeMatch(RelOptCluster cluster, RelTraitSet traitSet,
-                            RelNode input, RexNode targetIdExpression,
-                            RelDataType nodeType, PathRecordType pathType) {
-        super(cluster, traitSet, input);
-        this.targetIdExpression = Objects.requireNonNull(targetIdExpression);
-        if (match(input).getNodeType().getSqlTypeName() != SqlTypeName.VERTEX) {
-            throw new GeaFlowDSLException("Illegal input type: " + match(input).getNodeType().getSqlTypeName()
-                + " for " + getRelTypeName() + ", should be: " + SqlTypeName.VERTEX);
-        }
-        this.rowType = Objects.requireNonNull(pathType);
-        this.pathType = Objects.requireNonNull(pathType);
-        this.nodeType = Objects.requireNonNull(nodeType);
-    }
-
-    public RexNode getTargetId() {
-        return targetIdExpression;
-    }
-
-    @Override
-    public VirtualEdgeMatch copy(RelTraitSet traitSet, List inputs) {
-        assert inputs.size() == 1;
-        return copy(traitSet, sole(inputs), targetIdExpression);
-    }
-
-    public VirtualEdgeMatch copy(RelTraitSet traitSet, RelNode input, RexNode targetIdExpression) {
-        return new VirtualEdgeMatch(getCluster(), traitSet, input, targetIdExpression, nodeType, pathType);
-    }
-
-    @Override
-    public RelWriter explainTerms(RelWriter pw) {
-        return super.explainTerms(pw)
-            .item("targetId", targetIdExpression);
-    }
-
-    public static VirtualEdgeMatch create(IMatchNode input, RexNode targetIdExpression,
-                                          PathRecordType pathType) {
-        EdgeRecordType nodeType = EdgeRecordType.emptyEdgeType(targetIdExpression.getType(),
-            input.getCluster().getTypeFactory());
-        return new VirtualEdgeMatch(input.getCluster(), input.getTraitSet(), input,
-            targetIdExpression, nodeType, pathType);
-    }
-
-    @Override
-    public PathRecordType getPathSchema() {
-        return pathType;
-    }
-
-    @Override
-    public RelDataType getNodeType() {
-        return nodeType;
-    }
-
-    @Override
-    public  T accept(MatchNodeVisitor visitor) {
-        return visitor.visitVirtualEdgeMatch(this);
-    }
-
-    @Override
-    public RelNode getInput() {
-        return input;
-    }
-
-    @Override
-    public IMatchNode copy(List inputs, PathRecordType pathSchema) {
-        return new VirtualEdgeMatch(getCluster(), traitSet, sole(inputs),
-            targetIdExpression, rowType, pathSchema);
-    }
-
-    @Override
-    public RelNode accept(RexShuttle shuttle) {
-        RexNode newExpression = targetIdExpression.accept(shuttle);
-        return copy(traitSet, input, newExpression);
+  private final RexNode targetIdExpression;
+
+  private final PathRecordType pathType;
+
+  private final RelDataType nodeType;
+
+  public VirtualEdgeMatch(
+      RelOptCluster cluster,
+      RelTraitSet traitSet,
+      RelNode input,
+      RexNode targetIdExpression,
+      RelDataType nodeType,
+      PathRecordType pathType) {
+    super(cluster, traitSet, input);
+    this.targetIdExpression = Objects.requireNonNull(targetIdExpression);
+    if (match(input).getNodeType().getSqlTypeName() != SqlTypeName.VERTEX) {
+      throw new GeaFlowDSLException(
+          "Illegal input type: "
+              + match(input).getNodeType().getSqlTypeName()
+              + " for "
+              + getRelTypeName()
+              + ", should be: "
+              + SqlTypeName.VERTEX);
     }
+    this.rowType = Objects.requireNonNull(pathType);
+    this.pathType = Objects.requireNonNull(pathType);
+    this.nodeType = Objects.requireNonNull(nodeType);
+  }
+
+  public RexNode getTargetId() {
+    return targetIdExpression;
+  }
+
+  @Override
+  public VirtualEdgeMatch copy(RelTraitSet traitSet, List inputs) {
+    assert inputs.size() == 1;
+    return copy(traitSet, sole(inputs), targetIdExpression);
+  }
+
+  public VirtualEdgeMatch copy(RelTraitSet traitSet, RelNode input, RexNode targetIdExpression) {
+    return new VirtualEdgeMatch(
+        getCluster(), traitSet, input, targetIdExpression, nodeType, pathType);
+  }
+
+  @Override
+  public RelWriter explainTerms(RelWriter pw) {
+    return super.explainTerms(pw).item("targetId", targetIdExpression);
+  }
+
+  public static VirtualEdgeMatch create(
+      IMatchNode input, RexNode targetIdExpression, PathRecordType pathType) {
+    EdgeRecordType nodeType =
+        EdgeRecordType.emptyEdgeType(
+            targetIdExpression.getType(), input.getCluster().getTypeFactory());
+    return new VirtualEdgeMatch(
+        input.getCluster(), input.getTraitSet(), input, targetIdExpression, nodeType, pathType);
+  }
+
+  @Override
+  public PathRecordType getPathSchema() {
+    return pathType;
+  }
+
+  @Override
+  public RelDataType getNodeType() {
+    return nodeType;
+  }
+
+  @Override
+  public  T accept(MatchNodeVisitor visitor) {
+    return visitor.visitVirtualEdgeMatch(this);
+  }
+
+  @Override
+  public RelNode getInput() {
+    return input;
+  }
+
+  @Override
+  public IMatchNode copy(List inputs, PathRecordType pathSchema) {
+    return new VirtualEdgeMatch(
+        getCluster(), traitSet, sole(inputs), targetIdExpression, rowType, pathSchema);
+  }
+
+  @Override
+  public RelNode accept(RexShuttle shuttle) {
+    RexNode newExpression = targetIdExpression.accept(shuttle);
+    return copy(traitSet, input, newExpression);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/MatchAggregateCall.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/MatchAggregateCall.java
index da32a0ad5..7405c6cd7 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/MatchAggregateCall.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/MatchAggregateCall.java
@@ -19,85 +19,97 @@
 
 package org.apache.geaflow.dsl.rex;
 
-import com.google.common.collect.ImmutableList;
 import java.util.List;
 import java.util.Objects;
+
 import org.apache.calcite.rel.RelCollation;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.sql.SqlAggFunction;
 
+import com.google.common.collect.ImmutableList;
+
 public class MatchAggregateCall {
-    private final SqlAggFunction aggFunction;
-
-    private final boolean distinct;
-    private final boolean approximate;
-    public final RelDataType type;
-    public final String name;
-
-    private final ImmutableList argList;
-    public final int filterArg;
-    public final RelCollation collation;
-
-    public MatchAggregateCall(SqlAggFunction aggFunction, boolean distinct,
-                              boolean approximate, List argList, int filterArg,
-                              RelCollation collation, RelDataType type, String name) {
-        this.type = Objects.requireNonNull(type);
-        this.name = name;
-        this.aggFunction = Objects.requireNonNull(aggFunction);
-        this.argList = ImmutableList.copyOf(argList);
-        this.filterArg = filterArg;
-        this.collation = Objects.requireNonNull(collation);
-        this.distinct = distinct;
-        this.approximate = approximate;
-    }
+  private final SqlAggFunction aggFunction;
 
-    public final boolean isDistinct() {
-        return distinct;
-    }
+  private final boolean distinct;
+  private final boolean approximate;
+  public final RelDataType type;
+  public final String name;
 
-    public final boolean isApproximate() {
-        return approximate;
-    }
+  private final ImmutableList argList;
+  public final int filterArg;
+  public final RelCollation collation;
 
-    public final SqlAggFunction getAggregation() {
-        return aggFunction;
-    }
+  public MatchAggregateCall(
+      SqlAggFunction aggFunction,
+      boolean distinct,
+      boolean approximate,
+      List argList,
+      int filterArg,
+      RelCollation collation,
+      RelDataType type,
+      String name) {
+    this.type = Objects.requireNonNull(type);
+    this.name = name;
+    this.aggFunction = Objects.requireNonNull(aggFunction);
+    this.argList = ImmutableList.copyOf(argList);
+    this.filterArg = filterArg;
+    this.collation = Objects.requireNonNull(collation);
+    this.distinct = distinct;
+    this.approximate = approximate;
+  }
 
-    public RelCollation getCollation() {
-        return collation;
-    }
+  public final boolean isDistinct() {
+    return distinct;
+  }
 
-    public final List getArgList() {
-        return argList;
-    }
+  public final boolean isApproximate() {
+    return approximate;
+  }
 
-    public final RelDataType getType() {
-        return type;
-    }
+  public final SqlAggFunction getAggregation() {
+    return aggFunction;
+  }
 
-    public String getName() {
-        return name;
-    }
+  public RelCollation getCollation() {
+    return collation;
+  }
 
-    @Override
-    public boolean equals(Object o) {
-        if (this == o) {
-            return true;
-        }
-        if (o == null || getClass() != o.getClass()) {
-            return false;
-        }
-        MatchAggregateCall that = (MatchAggregateCall) o;
-        return distinct == that.distinct && approximate == that.approximate
-            && filterArg == that.filterArg && Objects.equals(aggFunction, that.aggFunction)
-            && Objects.equals(type, that.type) && Objects.equals(name, that.name) && Objects.equals(
-            argList, that.argList) && Objects.equals(collation, that.collation);
-    }
+  public final List getArgList() {
+    return argList;
+  }
 
-    @Override
-    public int hashCode() {
-        return Objects.hash(aggFunction, distinct, approximate, type, name, argList, filterArg,
-            collation);
+  public final RelDataType getType() {
+    return type;
+  }
+
+  public String getName() {
+    return name;
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
     }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    MatchAggregateCall that = (MatchAggregateCall) o;
+    return distinct == that.distinct
+        && approximate == that.approximate
+        && filterArg == that.filterArg
+        && Objects.equals(aggFunction, that.aggFunction)
+        && Objects.equals(type, that.type)
+        && Objects.equals(name, that.name)
+        && Objects.equals(argList, that.argList)
+        && Objects.equals(collation, that.collation);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(
+        aggFunction, distinct, approximate, type, name, argList, filterArg, collation);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/PathInputRef.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/PathInputRef.java
index c3576acc3..b27e03d00 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/PathInputRef.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/PathInputRef.java
@@ -24,19 +24,19 @@
 
 public class PathInputRef extends RexInputRef {
 
-    private final String label;
+  private final String label;
 
-    public PathInputRef(String label, int index, RelDataType type) {
-        super(index, type);
-        this.digest = label;
-        this.label = label;
-    }
+  public PathInputRef(String label, int index, RelDataType type) {
+    super(index, type);
+    this.digest = label;
+    this.label = label;
+  }
 
-    public String getLabel() {
-        return label;
-    }
+  public String getLabel() {
+    return label;
+  }
 
-    public PathInputRef copy(int newIndex) {
-        return new PathInputRef(label, newIndex, type);
-    }
+  public PathInputRef copy(int newIndex) {
+    return new PathInputRef(label, newIndex, type);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/RexLambdaCall.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/RexLambdaCall.java
index 599b9f1ee..4a28c95a7 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/RexLambdaCall.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/RexLambdaCall.java
@@ -19,9 +19,10 @@
 
 package org.apache.geaflow.dsl.rex;
 
-import com.google.common.collect.Lists;
 import java.util.List;
+
 import javax.annotation.Nonnull;
+
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rex.RexCall;
 import org.apache.calcite.rex.RexNode;
@@ -30,38 +31,40 @@
 import org.apache.geaflow.dsl.rel.GraphMatch.ExplainVisitor;
 import org.apache.geaflow.dsl.rel.match.IMatchNode;
 
+import com.google.common.collect.Lists;
+
 public class RexLambdaCall extends RexCall {
 
-    public RexLambdaCall(RexSubQuery input, RexNode value) {
-        super(value.getType(), SqlLambdaOperator.INSTANCE, Lists.newArrayList(input, value));
-    }
+  public RexLambdaCall(RexSubQuery input, RexNode value) {
+    super(value.getType(), SqlLambdaOperator.INSTANCE, Lists.newArrayList(input, value));
+  }
 
-    public RexSubQuery getInput() {
-        return (RexSubQuery) operands.get(0);
-    }
+  public RexSubQuery getInput() {
+    return (RexSubQuery) operands.get(0);
+  }
 
-    public RexNode getValue() {
-        return operands.get(1);
-    }
+  public RexNode getValue() {
+    return operands.get(1);
+  }
 
-    @Nonnull
-    @Override
-    protected String computeDigest(boolean withType) {
-        RexSubQuery input = getInput();
-        String inputStr;
-        if (input.rel instanceof IMatchNode) {
-            IMatchNode matchNode = (IMatchNode) input.rel;
-            ExplainVisitor explainVisitor = new ExplainVisitor();
-            inputStr = explainVisitor.visit(matchNode);
-        } else {
-            inputStr = input.toString();
-        }
-        return inputStr + " => " + getValue().toString();
+  @Nonnull
+  @Override
+  protected String computeDigest(boolean withType) {
+    RexSubQuery input = getInput();
+    String inputStr;
+    if (input.rel instanceof IMatchNode) {
+      IMatchNode matchNode = (IMatchNode) input.rel;
+      ExplainVisitor explainVisitor = new ExplainVisitor();
+      inputStr = explainVisitor.visit(matchNode);
+    } else {
+      inputStr = input.toString();
     }
+    return inputStr + " => " + getValue().toString();
+  }
 
-    @Override
-    public RexCall clone(RelDataType type, List operands) {
-        assert operands.size() == 2;
-        return new RexLambdaCall((RexSubQuery) operands.get(0), operands.get(1));
-    }
+  @Override
+  public RexCall clone(RelDataType type, List operands) {
+    assert operands.size() == 2;
+    return new RexLambdaCall((RexSubQuery) operands.get(0), operands.get(1));
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/RexObjectConstruct.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/RexObjectConstruct.java
index 3a1bafb50..47c3d6245 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/RexObjectConstruct.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/RexObjectConstruct.java
@@ -24,7 +24,9 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+
 import javax.annotation.Nonnull;
+
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rel.type.RelDataTypeField;
 import org.apache.calcite.rex.*;
@@ -42,280 +44,274 @@
 
 public class RexObjectConstruct extends RexCall {
 
-    private static final RexLiteral DEFAULT_LABEL_VALUE = GQLRexUtil.createString("");
+  private static final RexLiteral DEFAULT_LABEL_VALUE = GQLRexUtil.createString("");
 
-    private final List variableInfos;
+  private final List variableInfos;
 
-    public RexObjectConstruct(RelDataType type,
-                              List operands,
-                              Map rex2VariableInfo) {
-        super(type, createConstructOperator(type), reOrderOperands(type, operands));
-        this.variableInfos = new ArrayList<>();
-        for (RexNode operand : operands) {
-            VariableInfo variableInfo = Objects.requireNonNull(rex2VariableInfo.get(operand));
-            variableInfos.add(variableInfo);
-        }
+  public RexObjectConstruct(
+      RelDataType type,
+      List operands,
+      Map rex2VariableInfo) {
+    super(type, createConstructOperator(type), reOrderOperands(type, operands));
+    this.variableInfos = new ArrayList<>();
+    for (RexNode operand : operands) {
+      VariableInfo variableInfo = Objects.requireNonNull(rex2VariableInfo.get(operand));
+      variableInfos.add(variableInfo);
     }
-
-    public RexObjectConstruct(RelDataType type,
-                              List operands,
-                              List variableInfos) {
-        super(type, createConstructOperator(type), operands);
-        this.variableInfos = Objects.requireNonNull(variableInfos);
+  }
+
+  public RexObjectConstruct(
+      RelDataType type, List operands, List variableInfos) {
+    super(type, createConstructOperator(type), operands);
+    this.variableInfos = Objects.requireNonNull(variableInfos);
+  }
+
+  @Override
+  public RexObjectConstruct clone(RelDataType type, List operands) {
+    return new RexObjectConstruct(type, operands, variableInfos);
+  }
+
+  public List getVariableInfo() {
+    return variableInfos;
+  }
+
+  public RexObjectConstruct merge(RexObjectConstruct input, int pathIndex, RexBuilder builder) {
+    SqlTypeName typeName = getType().getSqlTypeName();
+    if (typeName != input.getType().getSqlTypeName()) {
+      throw new IllegalArgumentException("Fail to merge vertex with edge");
     }
-
-    @Override
-    public RexObjectConstruct clone(RelDataType type, List operands) {
-        return new RexObjectConstruct(type, operands, variableInfos);
+    List fields = getType().getFieldList();
+    List inputFields = input.getType().getFieldList();
+    List inputVariables = input.variableInfos;
+
+    Map name2InputRex = new LinkedHashMap<>();
+    Map name2InputVar = new LinkedHashMap<>();
+    Map name2InputField = new LinkedHashMap<>();
+
+    for (int i = 0; i < inputFields.size(); i++) {
+      name2InputRex.put(inputFields.get(i).getName(), input.getOperands().get(i));
+      name2InputVar.put(inputVariables.get(i).getName(), inputVariables.get(i));
+      name2InputField.put(inputFields.get(i).getName(), inputFields.get(i));
     }
 
-    public List getVariableInfo() {
-        return variableInfos;
+    List mergedNodes = new ArrayList<>();
+    List mergedVariableInfo = new ArrayList<>();
+    List mergedFields = new ArrayList<>();
+
+    for (int i = 0; i < fields.size(); i++) {
+      String name = fields.get(i).getName();
+      RexNode currentRex = operands.get(i);
+      RexNode inputRex = name2InputRex.get(name);
+
+      VariableInfo currentVar = variableInfos.get(i);
+      VariableInfo inputVar = name2InputVar.get(name);
+
+      RelDataTypeField currentField = fields.get(i);
+
+      if (inputRex != null) {
+        assert currentVar.equals(inputVar) : "Fail to merge variable: " + currentVar;
+
+        RexNode mergedRex = mergeRexNode(inputRex, currentRex, pathIndex, name);
+        mergedNodes.add(mergedRex);
+        mergedVariableInfo.add(currentVar);
+        mergedFields.add(currentField);
+        name2InputRex.remove(name);
+        name2InputVar.remove(name);
+        name2InputField.remove(name);
+      } else {
+        mergedNodes.add(currentRex);
+        mergedVariableInfo.add(currentVar);
+        mergedFields.add(currentField);
+      }
     }
 
-    public RexObjectConstruct merge(RexObjectConstruct input, int pathIndex, RexBuilder builder) {
-        SqlTypeName typeName = getType().getSqlTypeName();
-        if (typeName != input.getType().getSqlTypeName()) {
-            throw new IllegalArgumentException("Fail to merge vertex with edge");
-        }
-        List fields = getType().getFieldList();
-        List inputFields = input.getType().getFieldList();
-        List inputVariables = input.variableInfos;
-
-        Map name2InputRex = new LinkedHashMap<>();
-        Map name2InputVar = new LinkedHashMap<>();
-        Map name2InputField = new LinkedHashMap<>();
-
-        for (int i = 0; i < inputFields.size(); i++) {
-            name2InputRex.put(inputFields.get(i).getName(), input.getOperands().get(i));
-            name2InputVar.put(inputVariables.get(i).getName(), inputVariables.get(i));
-            name2InputField.put(inputFields.get(i).getName(), inputFields.get(i));
-        }
-
-        List mergedNodes = new ArrayList<>();
-        List mergedVariableInfo = new ArrayList<>();
-        List mergedFields = new ArrayList<>();
-
-        for (int i = 0; i < fields.size(); i++) {
-            String name = fields.get(i).getName();
-            RexNode currentRex = operands.get(i);
-            RexNode inputRex = name2InputRex.get(name);
-
-            VariableInfo currentVar = variableInfos.get(i);
-            VariableInfo inputVar = name2InputVar.get(name);
-
-            RelDataTypeField currentField = fields.get(i);
-
-            if (inputRex != null) {
-                assert currentVar.equals(inputVar) : "Fail to merge variable: " + currentVar;
-
-                RexNode mergedRex = mergeRexNode(inputRex, currentRex, pathIndex, name);
-                mergedNodes.add(mergedRex);
-                mergedVariableInfo.add(currentVar);
-                mergedFields.add(currentField);
-                name2InputRex.remove(name);
-                name2InputVar.remove(name);
-                name2InputField.remove(name);
-            } else {
-                mergedNodes.add(currentRex);
-                mergedVariableInfo.add(currentVar);
-                mergedFields.add(currentField);
-            }
-        }
-
-        mergedNodes.addAll(name2InputRex.values());
-        mergedVariableInfo.addAll(name2InputVar.values());
-        mergedFields.addAll(name2InputField.values());
+    mergedNodes.addAll(name2InputRex.values());
+    mergedVariableInfo.addAll(name2InputVar.values());
+    mergedFields.addAll(name2InputField.values());
 
-        RelDataType mergedType;
-        if (typeName == SqlTypeName.VERTEX) {
-            mergedType =
-                VertexRecordType.createVertexType(mergedFields, builder.getTypeFactory());
-        } else {
-            mergedType =
-                EdgeRecordType.createEdgeType(mergedFields, builder.getTypeFactory());
-        }
-
-        return new RexObjectConstruct(mergedType, mergedNodes, mergedVariableInfo);
+    RelDataType mergedType;
+    if (typeName == SqlTypeName.VERTEX) {
+      mergedType = VertexRecordType.createVertexType(mergedFields, builder.getTypeFactory());
+    } else {
+      mergedType = EdgeRecordType.createEdgeType(mergedFields, builder.getTypeFactory());
     }
 
-    private RexNode mergeRexNode(RexNode inputRex, RexNode currentRex, int pathIndex, String variableName) {
-        return currentRex.accept(new RexShuttle() {
-
-            @Override
-            public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
-                if (fieldAccess.getReferenceExpr() instanceof RexInputRef) {
-                    RexInputRef inputRef = (RexInputRef) fieldAccess.getReferenceExpr();
-                    if (inputRef.getIndex() == pathIndex && variableName.equals(fieldAccess.getField().getName())) {
-                        return inputRex;
-                    }
-                }
-                return fieldAccess;
+    return new RexObjectConstruct(mergedType, mergedNodes, mergedVariableInfo);
+  }
+
+  private RexNode mergeRexNode(
+      RexNode inputRex, RexNode currentRex, int pathIndex, String variableName) {
+    return currentRex.accept(
+        new RexShuttle() {
+
+          @Override
+          public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
+            if (fieldAccess.getReferenceExpr() instanceof RexInputRef) {
+              RexInputRef inputRef = (RexInputRef) fieldAccess.getReferenceExpr();
+              if (inputRef.getIndex() == pathIndex
+                  && variableName.equals(fieldAccess.getField().getName())) {
+                return inputRex;
+              }
             }
+            return fieldAccess;
+          }
         });
+  }
+
+  @Override
+  protected @Nonnull String computeDigest(boolean withType) {
+    final StringBuilder sb = new StringBuilder(op.getName());
+    sb.append("{");
+    for (int i = 0; i < operands.size(); i++) {
+      if (i > 0) {
+        sb.append(", ");
+      }
+      String name = getType().getFieldList().get(i).getName();
+      sb.append(name).append(":").append(operands.get(i));
     }
+    sb.append("}");
 
-    @Override
-    protected @Nonnull String computeDigest(boolean withType) {
-        final StringBuilder sb = new StringBuilder(op.getName());
-        sb.append("{");
+    if (withType) {
+      sb.append(":");
+      sb.append(type.getFullTypeString());
+    }
+    return sb.toString();
+  }
+
+  private static SqlOperator createConstructOperator(RelDataType dataType) {
+    SqlIdentifier[] keyNodes = new SqlIdentifier[dataType.getFieldCount()];
+    int c = 0;
+    for (RelDataTypeField field : dataType.getFieldList()) {
+      keyNodes[c++] = new SqlIdentifier(field.getName(), SqlParserPos.ZERO);
+    }
+    // Construct RexObjectConstruct for dynamic field append expression.
+    return dataType.getSqlTypeName() == SqlTypeName.VERTEX
+        ? new SqlVertexConstructOperator(keyNodes)
+        : new SqlEdgeConstructOperator(keyNodes);
+  }
+
+  private static List reOrderOperands(RelDataType type, List operands) {
+    List reOrderNodes = new ArrayList<>();
+    switch (type.getSqlTypeName()) {
+      case VERTEX:
+        int idIndex = -1;
+        int labelIndex = -1;
         for (int i = 0; i < operands.size(); i++) {
-            if (i > 0) {
-                sb.append(", ");
+          RexNode operand = operands.get(i);
+          if (operand.getType() instanceof MetaFieldType) {
+            MetaFieldType metaFieldType = (MetaFieldType) operand.getType();
+            if (metaFieldType.getMetaField() == MetaField.VERTEX_ID) {
+              idIndex = i;
+            } else if (metaFieldType.getMetaField() == MetaField.VERTEX_TYPE) {
+              labelIndex = i;
             }
-            String name = getType().getFieldList().get(i).getName();
-            sb.append(name).append(":").append(operands.get(i));
+          }
         }
-        sb.append("}");
-
-        if (withType) {
-            sb.append(":");
-            sb.append(type.getFullTypeString());
+        assert idIndex >= 0 : "Id field must defined in the vertex constructor";
+        reOrderNodes.add(operands.get(idIndex));
+        if (labelIndex >= 0) {
+          reOrderNodes.add(operands.get(labelIndex));
+        } else {
+          reOrderNodes.add(DEFAULT_LABEL_VALUE);
         }
-        return sb.toString();
-    }
 
-    private static SqlOperator createConstructOperator(RelDataType dataType) {
-        SqlIdentifier[] keyNodes = new SqlIdentifier[dataType.getFieldCount()];
-        int c = 0;
-        for (RelDataTypeField field : dataType.getFieldList()) {
-            keyNodes[c++] = new SqlIdentifier(field.getName(), SqlParserPos.ZERO);
+        for (int i = 0; i < operands.size(); i++) {
+          if (i != idIndex && i != labelIndex) {
+            reOrderNodes.add(operands.get(i));
+          }
         }
-        // Construct RexObjectConstruct for dynamic field append expression.
-        return dataType.getSqlTypeName() == SqlTypeName.VERTEX
-            ? new SqlVertexConstructOperator(keyNodes)
-            : new SqlEdgeConstructOperator(keyNodes);
-    }
-
-    private static List reOrderOperands(RelDataType type, List operands) {
-        List reOrderNodes = new ArrayList<>();
-        switch (type.getSqlTypeName()) {
-            case VERTEX:
-                int idIndex = -1;
-                int labelIndex = -1;
-                for (int i = 0; i < operands.size(); i++) {
-                    RexNode operand = operands.get(i);
-                    if (operand.getType() instanceof MetaFieldType) {
-                        MetaFieldType metaFieldType = (MetaFieldType) operand.getType();
-                        if (metaFieldType.getMetaField() == MetaField.VERTEX_ID) {
-                            idIndex = i;
-                        } else if (metaFieldType.getMetaField() == MetaField.VERTEX_TYPE) {
-                            labelIndex = i;
-                        }
-                    }
-                }
-                assert idIndex >= 0 : "Id field must defined in the vertex constructor";
-                reOrderNodes.add(operands.get(idIndex));
-                if (labelIndex >= 0) {
-                    reOrderNodes.add(operands.get(labelIndex));
-                } else {
-                    reOrderNodes.add(DEFAULT_LABEL_VALUE);
-                }
-
-                for (int i = 0; i < operands.size(); i++) {
-                    if (i != idIndex && i != labelIndex) {
-                        reOrderNodes.add(operands.get(i));
-                    }
-                }
-                return reOrderNodes;
-            case EDGE:
-                int srcIdIndex = -1;
-                int targetIdIndex = -1;
-                int edgeLabelIndex = -1;
-                int tsIndex = -1;
-                for (int i = 0; i < operands.size(); i++) {
-                    RexNode operand = operands.get(i);
-                    if (operand.getType() instanceof MetaFieldType) {
-                        MetaFieldType metaFieldType = (MetaFieldType) operand.getType();
-                        switch (metaFieldType.getMetaField()) {
-                            case EDGE_SRC_ID:
-                                srcIdIndex = i;
-                                break;
-                            case EDGE_TARGET_ID:
-                                targetIdIndex = i;
-                                break;
-                            case EDGE_TYPE:
-                                edgeLabelIndex = i;
-                                break;
-                            case EDGE_TS:
-                                tsIndex = i;
-                                break;
-                            default:
-                        }
-                    }
-                }
-                assert srcIdIndex >= 0 : "Source id field must defined in edge constructor";
-                assert targetIdIndex >= 0 : "Target id field must defined in edge constructor";
-                reOrderNodes.add(operands.get(srcIdIndex));
-                reOrderNodes.add(operands.get(targetIdIndex));
-                if (edgeLabelIndex >= 0) {
-                    reOrderNodes.add(operands.get(edgeLabelIndex));
-                } else {
-                    reOrderNodes.add(DEFAULT_LABEL_VALUE);
-                }
-                if (tsIndex >= 0) {
-                    reOrderNodes.add(operands.get(tsIndex));
-                }
-                for (int i = 0; i < operands.size(); i++) {
-                    if (i != srcIdIndex && i != targetIdIndex && i != edgeLabelIndex && i != tsIndex) {
-                        reOrderNodes.add(operands.get(i));
-                    }
-                }
-                return reOrderNodes;
-            default:
-                throw new IllegalArgumentException("Illegal type name: " + type.getSqlTypeName()
-                    + " for Object constructor.");
+        return reOrderNodes;
+      case EDGE:
+        int srcIdIndex = -1;
+        int targetIdIndex = -1;
+        int edgeLabelIndex = -1;
+        int tsIndex = -1;
+        for (int i = 0; i < operands.size(); i++) {
+          RexNode operand = operands.get(i);
+          if (operand.getType() instanceof MetaFieldType) {
+            MetaFieldType metaFieldType = (MetaFieldType) operand.getType();
+            switch (metaFieldType.getMetaField()) {
+              case EDGE_SRC_ID:
+                srcIdIndex = i;
+                break;
+              case EDGE_TARGET_ID:
+                targetIdIndex = i;
+                break;
+              case EDGE_TYPE:
+                edgeLabelIndex = i;
+                break;
+              case EDGE_TS:
+                tsIndex = i;
+                break;
+              default:
+            }
+          }
+        }
+        assert srcIdIndex >= 0 : "Source id field must defined in edge constructor";
+        assert targetIdIndex >= 0 : "Target id field must defined in edge constructor";
+        reOrderNodes.add(operands.get(srcIdIndex));
+        reOrderNodes.add(operands.get(targetIdIndex));
+        if (edgeLabelIndex >= 0) {
+          reOrderNodes.add(operands.get(edgeLabelIndex));
+        } else {
+          reOrderNodes.add(DEFAULT_LABEL_VALUE);
+        }
+        if (tsIndex >= 0) {
+          reOrderNodes.add(operands.get(tsIndex));
+        }
+        for (int i = 0; i < operands.size(); i++) {
+          if (i != srcIdIndex && i != targetIdIndex && i != edgeLabelIndex && i != tsIndex) {
+            reOrderNodes.add(operands.get(i));
+          }
         }
+        return reOrderNodes;
+      default:
+        throw new IllegalArgumentException(
+            "Illegal type name: " + type.getSqlTypeName() + " for Object constructor.");
     }
+  }
 
-    public static class VariableInfo {
+  public static class VariableInfo {
 
-        /**
-         * Whether it is a global variable.
-         */
-        private final boolean isGlobal;
+    /** Whether it is a global variable. */
+    private final boolean isGlobal;
 
-        /**
-         * Variable name.
-         */
-        private final String name;
+    /** Variable name. */
+    private final String name;
 
-        public VariableInfo(boolean isGlobal, String name) {
-            this.isGlobal = isGlobal;
-            this.name = name;
-        }
+    public VariableInfo(boolean isGlobal, String name) {
+      this.isGlobal = isGlobal;
+      this.name = name;
+    }
 
-        public boolean isGlobal() {
-            return isGlobal;
-        }
+    public boolean isGlobal() {
+      return isGlobal;
+    }
 
-        public String getName() {
-            return name;
-        }
+    public String getName() {
+      return name;
+    }
 
-        @Override
-        public boolean equals(Object o) {
-            if (this == o) {
-                return true;
-            }
-            if (!(o instanceof VariableInfo)) {
-                return false;
-            }
-            VariableInfo that = (VariableInfo) o;
-            return isGlobal == that.isGlobal && Objects.equals(name, that.name);
-        }
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (!(o instanceof VariableInfo)) {
+        return false;
+      }
+      VariableInfo that = (VariableInfo) o;
+      return isGlobal == that.isGlobal && Objects.equals(name, that.name);
+    }
 
-        @Override
-        public int hashCode() {
-            return Objects.hash(isGlobal, name);
-        }
+    @Override
+    public int hashCode() {
+      return Objects.hash(isGlobal, name);
+    }
 
-        @Override
-        public String toString() {
-            return "VariableInfo{"
-                + "isGlobal=" + isGlobal
-                + ", name='" + name + '\''
-                + '}';
-        }
+    @Override
+    public String toString() {
+      return "VariableInfo{" + "isGlobal=" + isGlobal + ", name='" + name + '\'' + '}';
     }
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/RexParameterRef.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/RexParameterRef.java
index fccd87a11..323ba090c 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/RexParameterRef.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/RexParameterRef.java
@@ -20,6 +20,7 @@
 package org.apache.geaflow.dsl.rex;
 
 import java.util.Objects;
+
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rex.RexBiVisitor;
 import org.apache.calcite.rex.RexSlot;
@@ -27,37 +28,35 @@
 
 public class RexParameterRef extends RexSlot {
 
-    private final RelDataType inputType;
-
-    public RexParameterRef(int index, RelDataType type, RelDataType inputType) {
-        super("$$" + index, index, type);
-        this.inputType = Objects.requireNonNull(inputType);
-        this.digest = getName();
-    }
-
-    @Override
-    public  R accept(RexVisitor visitor) {
-        return visitor.visitOther(this);
-    }
-
-    @Override
-    public  R accept(RexBiVisitor visitor, P arg) {
-        return visitor.visitOther(this, arg);
-    }
-
-    @Override
-    public boolean equals(Object obj) {
-        return this == obj
-            || obj instanceof RexParameterRef
-            && index == ((RexParameterRef) obj).index;
-    }
-
-    @Override
-    public int hashCode() {
-        return index;
-    }
-
-    public RelDataType getInputType() {
-        return inputType;
-    }
+  private final RelDataType inputType;
+
+  public RexParameterRef(int index, RelDataType type, RelDataType inputType) {
+    super("$$" + index, index, type);
+    this.inputType = Objects.requireNonNull(inputType);
+    this.digest = getName();
+  }
+
+  @Override
+  public  R accept(RexVisitor visitor) {
+    return visitor.visitOther(this);
+  }
+
+  @Override
+  public  R accept(RexBiVisitor visitor, P arg) {
+    return visitor.visitOther(this, arg);
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    return this == obj || obj instanceof RexParameterRef && index == ((RexParameterRef) obj).index;
+  }
+
+  @Override
+  public int hashCode() {
+    return index;
+  }
+
+  public RelDataType getInputType() {
+    return inputType;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/RexSystemVariable.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/RexSystemVariable.java
index 1f4d09234..4bb4d8973 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/RexSystemVariable.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/rex/RexSystemVariable.java
@@ -20,6 +20,7 @@
 package org.apache.geaflow.dsl.rex;
 
 import java.util.Objects;
+
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rel.type.RelDataTypeFactory;
 import org.apache.calcite.rex.RexBiVisitor;
@@ -29,71 +30,71 @@
 
 public class RexSystemVariable extends RexSlot {
 
-    public RexSystemVariable(String name, int index, RelDataType type) {
-        super(name, index, type);
-    }
+  public RexSystemVariable(String name, int index, RelDataType type) {
+    super(name, index, type);
+  }
 
-    @Override
-    public  R accept(RexVisitor visitor) {
-        return visitor.visitOther(this);
-    }
+  @Override
+  public  R accept(RexVisitor visitor) {
+    return visitor.visitOther(this);
+  }
 
-    @Override
-    public  R accept(RexBiVisitor visitor, P arg) {
-        return visitor.visitOther(this, arg);
-    }
+  @Override
+  public  R accept(RexBiVisitor visitor, P arg) {
+    return visitor.visitOther(this, arg);
+  }
 
-    @Override
-    public boolean equals(Object obj) {
-        if (!(obj instanceof RexSystemVariable)) {
-            return false;
-        }
-        RexSystemVariable systemVariable = (RexSystemVariable) obj;
-        return Objects.equals(name, systemVariable.name) && index == systemVariable.index;
+  @Override
+  public boolean equals(Object obj) {
+    if (!(obj instanceof RexSystemVariable)) {
+      return false;
     }
+    RexSystemVariable systemVariable = (RexSystemVariable) obj;
+    return Objects.equals(name, systemVariable.name) && index == systemVariable.index;
+  }
 
-    @Override
-    public int hashCode() {
-        return Objects.hash(name, index);
-    }
+  @Override
+  public int hashCode() {
+    return Objects.hash(name, index);
+  }
 
-    public enum SystemVariable {
-        LOOP_COUNTER("loopCounter", 0, SqlTypeName.INTEGER),
-        ;
+  public enum SystemVariable {
+    LOOP_COUNTER("loopCounter", 0, SqlTypeName.INTEGER),
+    ;
 
-        private final String name;
+    private final String name;
 
-        private final int index;
+    private final int index;
 
-        private final SqlTypeName typeName;
+    private final SqlTypeName typeName;
 
-        SystemVariable(String name, int index, SqlTypeName typeName) {
-            this.name = name;
-            this.index = index;
-            this.typeName = typeName;
-        }
+    SystemVariable(String name, int index, SqlTypeName typeName) {
+      this.name = name;
+      this.index = index;
+      this.typeName = typeName;
+    }
 
-        public String getName() {
-            return name;
-        }
+    public String getName() {
+      return name;
+    }
 
-        public int getIndex() {
-            return index;
-        }
+    public int getIndex() {
+      return index;
+    }
 
-        public SqlTypeName getTypeName() {
-            return typeName;
-        }
+    public SqlTypeName getTypeName() {
+      return typeName;
+    }
 
-        public RexSystemVariable toRexNode(RelDataTypeFactory typeFactory) {
-            return new RexSystemVariable(name, index, typeFactory.createSqlType(typeName));
-        }
+    public RexSystemVariable toRexNode(RelDataTypeFactory typeFactory) {
+      return new RexSystemVariable(name, index, typeFactory.createSqlType(typeName));
+    }
 
-        public static SystemVariable of(String name) {
-            if (name.equalsIgnoreCase(LOOP_COUNTER.getName())) {
-                return LOOP_COUNTER;
-            }
-            throw new IllegalArgumentException("Not support system variable: " + name);
-        }
+    public static SystemVariable of(String name) {
+      if (name.equalsIgnoreCase(LOOP_COUNTER.getName())) {
+        return LOOP_COUNTER;
+      }
+      throw new IllegalArgumentException("Not support system variable: " + name);
     }
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java
index bfd2d6e8f..dcc211e00 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java
@@ -19,11 +19,10 @@
 
 package org.apache.geaflow.dsl.schema.function;
 
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Lists;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+
 import org.apache.calcite.sql.SqlFunction;
 import org.apache.calcite.sql.SqlFunctionCategory;
 import org.apache.calcite.sql.SqlIdentifier;
@@ -126,203 +125,201 @@
 import org.apache.geaflow.dsl.udf.table.string.UrlEncode;
 import org.apache.geaflow.dsl.util.FunctionUtil;
 
-/**
- * SQL build-in {@link SqlFunction}s.
- */
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+
+/** SQL build-in {@link SqlFunction}s. */
 public class BuildInSqlFunctionTable extends ListSqlOperatorTable {
 
-    private final GQLJavaTypeFactory typeFactory;
+  private final GQLJavaTypeFactory typeFactory;
 
-    private final ImmutableList buildInSqlFunctions =
-        new ImmutableList.Builder()
-            // ~ build in UDF --------------------------------
-            // udf.table.date
-            .add(GeaFlowFunction.of(AddMonths.class))
-            .add(GeaFlowFunction.of(DateAdd.class))
-            .add(GeaFlowFunction.of(DateDiff.class))
-            .add(GeaFlowFunction.of(DateFormat.class))
-            .add(GeaFlowFunction.of(DatePart.class))
-            .add(GeaFlowFunction.of(DateSub.class))
-            .add(GeaFlowFunction.of(DateTrunc.class))
-            .add(GeaFlowFunction.of(Day.class))
-            .add(GeaFlowFunction.of(DayOfMonth.class))
-            .add(GeaFlowFunction.of(FromUnixTime.class))
-            .add(GeaFlowFunction.of(FromUnixTimeMillis.class))
-            .add(GeaFlowFunction.of(Hour.class))
-            .add(GeaFlowFunction.of(IsDate.class))
-            .add(GeaFlowFunction.of(LastDay.class))
-            .add(GeaFlowFunction.of(Minute.class))
-            .add(GeaFlowFunction.of(Month.class))
-            .add(GeaFlowFunction.of(Now.class))
-            .add(GeaFlowFunction.of(Second.class))
-            .add(GeaFlowFunction.of(UnixTimeStamp.class))
-            .add(GeaFlowFunction.of(UnixTimeStampMillis.class))
-            .add(GeaFlowFunction.of(WeekDay.class))
-            .add(GeaFlowFunction.of(WeekOfYear.class))
-            .add(GeaFlowFunction.of(Year.class))
+  private final ImmutableList buildInSqlFunctions =
+      new ImmutableList.Builder()
+          // ~ build in UDF --------------------------------
+          // udf.table.date
+          .add(GeaFlowFunction.of(AddMonths.class))
+          .add(GeaFlowFunction.of(DateAdd.class))
+          .add(GeaFlowFunction.of(DateDiff.class))
+          .add(GeaFlowFunction.of(DateFormat.class))
+          .add(GeaFlowFunction.of(DatePart.class))
+          .add(GeaFlowFunction.of(DateSub.class))
+          .add(GeaFlowFunction.of(DateTrunc.class))
+          .add(GeaFlowFunction.of(Day.class))
+          .add(GeaFlowFunction.of(DayOfMonth.class))
+          .add(GeaFlowFunction.of(FromUnixTime.class))
+          .add(GeaFlowFunction.of(FromUnixTimeMillis.class))
+          .add(GeaFlowFunction.of(Hour.class))
+          .add(GeaFlowFunction.of(IsDate.class))
+          .add(GeaFlowFunction.of(LastDay.class))
+          .add(GeaFlowFunction.of(Minute.class))
+          .add(GeaFlowFunction.of(Month.class))
+          .add(GeaFlowFunction.of(Now.class))
+          .add(GeaFlowFunction.of(Second.class))
+          .add(GeaFlowFunction.of(UnixTimeStamp.class))
+          .add(GeaFlowFunction.of(UnixTimeStampMillis.class))
+          .add(GeaFlowFunction.of(WeekDay.class))
+          .add(GeaFlowFunction.of(WeekOfYear.class))
+          .add(GeaFlowFunction.of(Year.class))
 
-            // udf.table.array
-            .add(GeaFlowFunction.of(ArrayAppend.class))
-            .add(GeaFlowFunction.of(ArrayContains.class))
-            .add(GeaFlowFunction.of(ArrayDistinct.class))
-            .add(GeaFlowFunction.of(ArrayUnion.class))
+          // udf.table.array
+          .add(GeaFlowFunction.of(ArrayAppend.class))
+          .add(GeaFlowFunction.of(ArrayContains.class))
+          .add(GeaFlowFunction.of(ArrayDistinct.class))
+          .add(GeaFlowFunction.of(ArrayUnion.class))
 
-            // udf.table.math
-            .add(GeaFlowFunction.of(E.class))
-            .add(GeaFlowFunction.of(Log2.class))
-            .add(GeaFlowFunction.of(Round.class))
+          // udf.table.math
+          .add(GeaFlowFunction.of(E.class))
+          .add(GeaFlowFunction.of(Log2.class))
+          .add(GeaFlowFunction.of(Round.class))
 
-            // udf.table.string
-            .add(GeaFlowFunction.of(Ascii2String.class))
-            .add(GeaFlowFunction.of(Base64Decode.class))
-            .add(GeaFlowFunction.of(Base64Encode.class))
-            .add(GeaFlowFunction.of(Concat.class))
-            .add(GeaFlowFunction.of(ConcatWS.class))
-            .add(GeaFlowFunction.of(Hash.class))
-            .add(GeaFlowFunction.of(IndexOf.class))
-            .add(GeaFlowFunction.of(Instr.class))
-            .add(GeaFlowFunction.of(IsBlank.class))
-            .add(GeaFlowFunction.of(KeyValue.class))
-            .add(GeaFlowFunction.of(Length.class))
-            .add(GeaFlowFunction.of(Like.class))
-            .add(GeaFlowFunction.of(LTrim.class))
-            .add(GeaFlowFunction.of(RegExp.class))
-            .add(GeaFlowFunction.of(RegexpCount.class))
-            .add(GeaFlowFunction.of(RegExpExtract.class))
-            .add(GeaFlowFunction.of(RegExpReplace.class))
-            .add(GeaFlowFunction.of(Repeat.class))
-            .add(GeaFlowFunction.of(Replace.class))
-            .add(GeaFlowFunction.of(Reverse.class))
-            .add(GeaFlowFunction.of(RTrim.class))
-            .add(GeaFlowFunction.of(Space.class))
-            .add(GeaFlowFunction.of(SplitEx.class))
-            .add(GeaFlowFunction.of(Substr.class))
-            .add(GeaFlowFunction.of(UrlDecode.class))
-            .add(GeaFlowFunction.of(UrlEncode.class))
-            .add(GeaFlowFunction.of(GetJsonObject.class))
+          // udf.table.string
+          .add(GeaFlowFunction.of(Ascii2String.class))
+          .add(GeaFlowFunction.of(Base64Decode.class))
+          .add(GeaFlowFunction.of(Base64Encode.class))
+          .add(GeaFlowFunction.of(Concat.class))
+          .add(GeaFlowFunction.of(ConcatWS.class))
+          .add(GeaFlowFunction.of(Hash.class))
+          .add(GeaFlowFunction.of(IndexOf.class))
+          .add(GeaFlowFunction.of(Instr.class))
+          .add(GeaFlowFunction.of(IsBlank.class))
+          .add(GeaFlowFunction.of(KeyValue.class))
+          .add(GeaFlowFunction.of(Length.class))
+          .add(GeaFlowFunction.of(Like.class))
+          .add(GeaFlowFunction.of(LTrim.class))
+          .add(GeaFlowFunction.of(RegExp.class))
+          .add(GeaFlowFunction.of(RegexpCount.class))
+          .add(GeaFlowFunction.of(RegExpExtract.class))
+          .add(GeaFlowFunction.of(RegExpReplace.class))
+          .add(GeaFlowFunction.of(Repeat.class))
+          .add(GeaFlowFunction.of(Replace.class))
+          .add(GeaFlowFunction.of(Reverse.class))
+          .add(GeaFlowFunction.of(RTrim.class))
+          .add(GeaFlowFunction.of(Space.class))
+          .add(GeaFlowFunction.of(SplitEx.class))
+          .add(GeaFlowFunction.of(Substr.class))
+          .add(GeaFlowFunction.of(UrlDecode.class))
+          .add(GeaFlowFunction.of(UrlEncode.class))
+          .add(GeaFlowFunction.of(GetJsonObject.class))
 
-            // udf.table.other
-            .add(GeaFlowFunction.of(If.class))
-            .add(GeaFlowFunction.of(Direction.class))
-            .add(GeaFlowFunction.of(Label.class))
-            .add(GeaFlowFunction.of(VertexId.class))
-            .add(GeaFlowFunction.of(EdgeSrcId.class))
-            .add(GeaFlowFunction.of(EdgeTargetId.class))
-            .add(GeaFlowFunction.of(EdgeTimestamp.class))
-            .add(GeaFlowFunction.of(IsDecimal.class))
-            // ISO-GQL source/destination predicates
-            .add(GeaFlowFunction.of(IsSourceOf.class))
-            .add(GeaFlowFunction.of(IsNotSourceOf.class))
-            .add(GeaFlowFunction.of(IsDestinationOf.class))
-            .add(GeaFlowFunction.of(IsNotDestinationOf.class))
-            // ISO-GQL property exists predicate
-            .add(GeaFlowFunction.of(PropertyExists.class))
-            // UDAF
-            .add(GeaFlowFunction.of(PercentileLong.class))
-            .add(GeaFlowFunction.of(PercentileInteger.class))
-            .add(GeaFlowFunction.of(PercentileDouble.class))
-            // UDGA
-            .add(GeaFlowFunction.of(SingleSourceShortestPath.class))
-            .add(GeaFlowFunction.of(AllSourceShortestPath.class))
-            .add(GeaFlowFunction.of(PageRank.class))
-            .add(GeaFlowFunction.of(KHop.class))
-            .add(GeaFlowFunction.of(KCore.class))
-            .add(GeaFlowFunction.of(IncrementalKCore.class))
-            .add(GeaFlowFunction.of(IncMinimumSpanningTree.class))
-            .add(GeaFlowFunction.of(ClosenessCentrality.class))
-            .add(GeaFlowFunction.of(WeakConnectedComponents.class))
-            .add(GeaFlowFunction.of(TriangleCount.class))
-            .add(GeaFlowFunction.of(ClusterCoefficient.class))
-            .add(GeaFlowFunction.of(IncWeakConnectedComponents.class))
-            .add(GeaFlowFunction.of(CommonNeighbors.class))
-            .add(GeaFlowFunction.of(JaccardSimilarity.class))
-            .add(GeaFlowFunction.of(IncKHopAlgorithm.class))
-            .add(GeaFlowFunction.of(LabelPropagation.class))
-            .add(GeaFlowFunction.of(ConnectedComponents.class))
-            .build();
+          // udf.table.other
+          .add(GeaFlowFunction.of(If.class))
+          .add(GeaFlowFunction.of(Direction.class))
+          .add(GeaFlowFunction.of(Label.class))
+          .add(GeaFlowFunction.of(VertexId.class))
+          .add(GeaFlowFunction.of(EdgeSrcId.class))
+          .add(GeaFlowFunction.of(EdgeTargetId.class))
+          .add(GeaFlowFunction.of(EdgeTimestamp.class))
+          .add(GeaFlowFunction.of(IsDecimal.class))
+          // ISO-GQL source/destination predicates
+          .add(GeaFlowFunction.of(IsSourceOf.class))
+          .add(GeaFlowFunction.of(IsNotSourceOf.class))
+          .add(GeaFlowFunction.of(IsDestinationOf.class))
+          .add(GeaFlowFunction.of(IsNotDestinationOf.class))
+          // ISO-GQL property exists predicate
+          .add(GeaFlowFunction.of(PropertyExists.class))
+          // UDAF
+          .add(GeaFlowFunction.of(PercentileLong.class))
+          .add(GeaFlowFunction.of(PercentileInteger.class))
+          .add(GeaFlowFunction.of(PercentileDouble.class))
+          // UDGA
+          .add(GeaFlowFunction.of(SingleSourceShortestPath.class))
+          .add(GeaFlowFunction.of(AllSourceShortestPath.class))
+          .add(GeaFlowFunction.of(PageRank.class))
+          .add(GeaFlowFunction.of(KHop.class))
+          .add(GeaFlowFunction.of(KCore.class))
+          .add(GeaFlowFunction.of(IncrementalKCore.class))
+          .add(GeaFlowFunction.of(IncMinimumSpanningTree.class))
+          .add(GeaFlowFunction.of(ClosenessCentrality.class))
+          .add(GeaFlowFunction.of(WeakConnectedComponents.class))
+          .add(GeaFlowFunction.of(TriangleCount.class))
+          .add(GeaFlowFunction.of(ClusterCoefficient.class))
+          .add(GeaFlowFunction.of(IncWeakConnectedComponents.class))
+          .add(GeaFlowFunction.of(CommonNeighbors.class))
+          .add(GeaFlowFunction.of(JaccardSimilarity.class))
+          .add(GeaFlowFunction.of(IncKHopAlgorithm.class))
+          .add(GeaFlowFunction.of(LabelPropagation.class))
+          .add(GeaFlowFunction.of(ConnectedComponents.class))
+          .build();
 
-    public BuildInSqlFunctionTable(GQLJavaTypeFactory typeFactory) {
-        this.typeFactory = typeFactory;
-        this.register();
-    }
+  public BuildInSqlFunctionTable(GQLJavaTypeFactory typeFactory) {
+    this.typeFactory = typeFactory;
+    this.register();
+  }
 
-    private void register() {
-        for (GeaFlowFunction function : functions()) {
-            SqlFunction sqlFunction;
-            try {
-                sqlFunction = FunctionUtil.createSqlFunction(function, typeFactory);
-            } catch (GeaFlowDSLException e) {
-                throw new GeaFlowDSLException(
-                    "Error in register SqlFunction " + function, e);
-            }
-            super.add(sqlFunction);
-        }
+  private void register() {
+    for (GeaFlowFunction function : functions()) {
+      SqlFunction sqlFunction;
+      try {
+        sqlFunction = FunctionUtil.createSqlFunction(function, typeFactory);
+      } catch (GeaFlowDSLException e) {
+        throw new GeaFlowDSLException("Error in register SqlFunction " + function, e);
+      }
+      super.add(sqlFunction);
     }
+  }
 
-    public void registerFunction(Class functionClass) {
-        try {
-            SqlFunction sqlFunction = FunctionUtil.createSqlFunction(
-                GeaFlowFunction.of(functionClass), typeFactory);
-            super.add(sqlFunction);
-        } catch (GeaFlowDSLException e) {
-            throw new GeaFlowDSLException(
-                "Error in register SqlFunction " + functionClass, e);
-        }
+  public void registerFunction(Class functionClass) {
+    try {
+      SqlFunction sqlFunction =
+          FunctionUtil.createSqlFunction(GeaFlowFunction.of(functionClass), typeFactory);
+      super.add(sqlFunction);
+    } catch (GeaFlowDSLException e) {
+      throw new GeaFlowDSLException("Error in register SqlFunction " + functionClass, e);
     }
+  }
 
-    private List functions() {
-        Map combines = new HashMap<>();
-        for (GeaFlowFunction function : buildInSqlFunctions) {
-            GeaFlowFunction old = combines.get(function.getName());
-            if (old == null) {
-                old = function;
-            } else {
-                // As UDAF, for example:'max', it has multiple class impl.
-                // MaxLong
-                // MaxDouble
-                // MaxString
-                if (isUDAF(function)) {
-                    List clazz = old.getClazz();
-                    clazz.addAll(function.getClazz());
-                    old = GeaFlowFunction.of(function.getName(), clazz);
-                } else {
-                    throw new RuntimeException("function " + function.getName()
-                        + " cannot have more than one implement class");
-                }
-            }
-
-            String name = old.getName();
-            combines.put(name, old);
+  private List functions() {
+    Map combines = new HashMap<>();
+    for (GeaFlowFunction function : buildInSqlFunctions) {
+      GeaFlowFunction old = combines.get(function.getName());
+      if (old == null) {
+        old = function;
+      } else {
+        // As UDAF, for example:'max', it has multiple class impl.
+        // MaxLong
+        // MaxDouble
+        // MaxString
+        if (isUDAF(function)) {
+          List clazz = old.getClazz();
+          clazz.addAll(function.getClazz());
+          old = GeaFlowFunction.of(function.getName(), clazz);
+        } else {
+          throw new RuntimeException(
+              "function " + function.getName() + " cannot have more than one implement class");
         }
-        return Lists.newArrayList(combines.values());
-    }
+      }
 
+      String name = old.getName();
+      combines.put(name, old);
+    }
+    return Lists.newArrayList(combines.values());
+  }
 
-    private boolean isUDAF(GeaFlowFunction function) {
-        try {
-            String reflectClassName = function.getClazz().get(0);
-            Class reflectClass = Thread.currentThread().getContextClassLoader()
-                .loadClass(reflectClassName);
-            return UDAF.class.isAssignableFrom(reflectClass);
-        } catch (ClassNotFoundException e) {
-            throw new RuntimeException(e);
-        }
+  private boolean isUDAF(GeaFlowFunction function) {
+    try {
+      String reflectClassName = function.getClazz().get(0);
+      Class reflectClass =
+          Thread.currentThread().getContextClassLoader().loadClass(reflectClassName);
+      return UDAF.class.isAssignableFrom(reflectClass);
+    } catch (ClassNotFoundException e) {
+      throw new RuntimeException(e);
     }
+  }
 
-    @Override
-    public void lookupOperatorOverloads(SqlIdentifier opName,
-                                        SqlFunctionCategory category,
-                                        SqlSyntax syntax,
-                                        List operatorList) {
-        for (SqlOperator operator : getOperatorList()) {
-            if (!opName.isSimple()
-                || !operator.getName().equalsIgnoreCase(opName.getSimple())) {
-                continue;
-            }
-            if (operator.getSyntax() != syntax) {
-                continue;
-            }
-            operatorList.add(operator);
-        }
+  @Override
+  public void lookupOperatorOverloads(
+      SqlIdentifier opName,
+      SqlFunctionCategory category,
+      SqlSyntax syntax,
+      List operatorList) {
+    for (SqlOperator operator : getOperatorList()) {
+      if (!opName.isSimple() || !operator.getName().equalsIgnoreCase(opName.getSimple())) {
+        continue;
+      }
+      if (operator.getSyntax() != syntax) {
+        continue;
+      }
+      operatorList.add(operator);
     }
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlOperatorTable.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlOperatorTable.java
index 6b889e7c8..be575818e 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlOperatorTable.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlOperatorTable.java
@@ -28,165 +28,163 @@
 
 public class BuildInSqlOperatorTable extends ReflectiveSqlOperatorTable {
 
-    public static final SqlAggFunction MIN = new GqlMinMaxAggFunction(SqlKind.MIN);
-    public static final SqlAggFunction MAX = new GqlMinMaxAggFunction(SqlKind.MAX);
-    public static final SqlAggFunction SUM = new GqlSumAggFunction(null);
-    public static final SqlAggFunction COUNT = new GqlCountAggFunction("COUNT");
-    public static final SqlAggFunction AVG = new GqlAvgAggFunction(SqlKind.AVG);
+  public static final SqlAggFunction MIN = new GqlMinMaxAggFunction(SqlKind.MIN);
+  public static final SqlAggFunction MAX = new GqlMinMaxAggFunction(SqlKind.MAX);
+  public static final SqlAggFunction SUM = new GqlSumAggFunction(null);
+  public static final SqlAggFunction COUNT = new GqlCountAggFunction("COUNT");
+  public static final SqlAggFunction AVG = new GqlAvgAggFunction(SqlKind.AVG);
 
-    private final SqlOperator[] buildInSqlOperators = {
-        // SET OPERATORS
-        SqlStdOperatorTable.UNION,
-        SqlStdOperatorTable.UNION_ALL,
-        SqlStdOperatorTable.EXCEPT,
-        SqlStdOperatorTable.EXCEPT_ALL,
-        SqlStdOperatorTable.INTERSECT,
-        SqlStdOperatorTable.INTERSECT_ALL,
-        // BINARY OPERATORS
-        SqlStdOperatorTable.AND,
-        SqlStdOperatorTable.AS,
-        SqlStdOperatorTable.CONCAT,
-        GeaFlowOverwriteSqlOperators.DIVIDE,
-        SqlStdOperatorTable.DOT,
-        SqlStdOperatorTable.EQUALS,
-        SqlStdOperatorTable.GREATER_THAN,
-        SqlStdOperatorTable.IS_DISTINCT_FROM,
-        SqlStdOperatorTable.IS_NOT_DISTINCT_FROM,
-        SqlStdOperatorTable.GREATER_THAN_OR_EQUAL,
-        SqlStdOperatorTable.LESS_THAN,
-        SqlStdOperatorTable.LESS_THAN_OR_EQUAL,
-        GeaFlowOverwriteSqlOperators.MINUS,
-        GeaFlowOverwriteSqlOperators.MULTIPLY,
-        SqlStdOperatorTable.NOT_EQUALS,
-        SqlStdOperatorTable.OR,
-        GeaFlowOverwriteSqlOperators.PLUS,
-        SqlStdOperatorTable.DATETIME_PLUS,
-        // POSTFIX OPERATORS
-        SqlStdOperatorTable.DESC,
-        SqlStdOperatorTable.NULLS_FIRST,
-        SqlStdOperatorTable.IS_NOT_NULL,
-        SqlStdOperatorTable.IS_NULL,
-        SqlStdOperatorTable.IS_NOT_TRUE,
-        SqlStdOperatorTable.IS_TRUE,
-        SqlStdOperatorTable.IS_NOT_FALSE,
-        SqlStdOperatorTable.IS_FALSE,
-        SqlStdOperatorTable.IS_NOT_UNKNOWN,
-        SqlStdOperatorTable.IS_UNKNOWN,
-        // PREFIX OPERATORS
-        SqlStdOperatorTable.NOT,
-        SqlStdOperatorTable.UNARY_MINUS,
-        SqlStdOperatorTable.UNARY_PLUS,
-        // GROUPING FUNCTIONS
-        SqlStdOperatorTable.GROUP_ID,
-        SqlStdOperatorTable.GROUPING,
-        SqlStdOperatorTable.GROUPING_ID,
-        // AGGREGATE OPERATORS
-        BuildInSqlOperatorTable.SUM,
-        SqlStdOperatorTable.SUM0,
-        BuildInSqlOperatorTable.COUNT,
-        BuildInSqlOperatorTable.MIN,
-        BuildInSqlOperatorTable.MAX,
-        BuildInSqlOperatorTable.AVG,
-        SqlStdOperatorTable.STDDEV_POP,
-        SqlStdOperatorTable.STDDEV_SAMP,
-        SqlStdOperatorTable.VAR_POP,
-        SqlStdOperatorTable.VAR_SAMP,
-        // ARRAY OPERATORS
-        SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
-        SqlStdOperatorTable.ITEM,
-        SqlStdOperatorTable.CARDINALITY,
-        SqlStdOperatorTable.ELEMENT,
-        // SPECIAL OPERATORS
-        SqlStdOperatorTable.ROW,
-        SqlStdOperatorTable.OVERLAPS,
-        SqlStdOperatorTable.LITERAL_CHAIN,
-        SqlStdOperatorTable.BETWEEN,
-        SqlStdOperatorTable.SYMMETRIC_BETWEEN,
-        SqlStdOperatorTable.NOT_BETWEEN,
-        SqlStdOperatorTable.SYMMETRIC_NOT_BETWEEN,
-        SqlStdOperatorTable.NOT_LIKE,
-        SqlStdOperatorTable.LIKE,
-        SqlStdOperatorTable.NOT_SIMILAR_TO,
-        SqlStdOperatorTable.SIMILAR_TO,
-        SqlStdOperatorTable.CASE,
-        SqlStdOperatorTable.REINTERPRET,
-        // FUNCTIONS
-        SqlStdOperatorTable.SUBSTRING,
-        SqlStdOperatorTable.OVERLAY,
-        SqlStdOperatorTable.TRIM,
-        SqlStdOperatorTable.POSITION,
-        SqlStdOperatorTable.CHAR_LENGTH,
-        SqlStdOperatorTable.CHARACTER_LENGTH,
-        SqlStdOperatorTable.UPPER,
-        SqlStdOperatorTable.LOWER,
-        SqlStdOperatorTable.INITCAP,
-        SqlStdOperatorTable.POWER,
-        SqlStdOperatorTable.SQRT,
-        GeaFlowOverwriteSqlOperators.MOD,
-        GeaFlowOverwriteSqlOperators.PERCENT_REMAINDER,
-        SqlStdOperatorTable.LN,
-        SqlStdOperatorTable.LOG10,
-        SqlStdOperatorTable.ABS,
-        SqlStdOperatorTable.EXP,
-        SqlStdOperatorTable.NULLIF,
-        SqlStdOperatorTable.COALESCE,
-        SqlStdOperatorTable.FLOOR,
-        SqlStdOperatorTable.CEIL,
-        SqlStdOperatorTable.LOCALTIME,
-        SqlStdOperatorTable.LOCALTIMESTAMP,
-        SqlStdOperatorTable.CURRENT_TIME,
-        SqlStdOperatorTable.CURRENT_TIMESTAMP,
-        SqlStdOperatorTable.CURRENT_DATE,
-        SqlStdOperatorTable.TIMESTAMP_ADD,
-        SqlStdOperatorTable.TIMESTAMP_DIFF,
-        SqlStdOperatorTable.CAST,
-        SqlStdOperatorTable.EXTRACT,
-        SqlStdOperatorTable.SCALAR_QUERY,
-        SqlStdOperatorTable.EXISTS,
-        SqlStdOperatorTable.SIN,
-        SqlStdOperatorTable.COS,
-        SqlStdOperatorTable.TAN,
-        SqlStdOperatorTable.COT,
-        SqlStdOperatorTable.ASIN,
-        SqlStdOperatorTable.ACOS,
-        SqlStdOperatorTable.ATAN,
-        SqlStdOperatorTable.DEGREES,
-        SqlStdOperatorTable.RADIANS,
-        GeaFlowOverwriteSqlOperators.SIGN,
-        // SqlStdOperatorTable.ROUND,
-        SqlStdOperatorTable.PI,
-        SqlStdOperatorTable.RAND,
-        SqlStdOperatorTable.RAND_INTEGER,
-        // EXTENSIONS
-        SqlStdOperatorTable.TUMBLE,
-        SqlStdOperatorTable.TUMBLE_START,
-        SqlStdOperatorTable.TUMBLE_END,
-        SqlStdOperatorTable.HOP,
-        SqlStdOperatorTable.HOP_START,
-        SqlStdOperatorTable.HOP_END,
-        SqlStdOperatorTable.SESSION,
-        SqlStdOperatorTable.SESSION_START,
-        SqlStdOperatorTable.SESSION_END,
+  private final SqlOperator[] buildInSqlOperators = {
+    // SET OPERATORS
+    SqlStdOperatorTable.UNION,
+    SqlStdOperatorTable.UNION_ALL,
+    SqlStdOperatorTable.EXCEPT,
+    SqlStdOperatorTable.EXCEPT_ALL,
+    SqlStdOperatorTable.INTERSECT,
+    SqlStdOperatorTable.INTERSECT_ALL,
+    // BINARY OPERATORS
+    SqlStdOperatorTable.AND,
+    SqlStdOperatorTable.AS,
+    SqlStdOperatorTable.CONCAT,
+    GeaFlowOverwriteSqlOperators.DIVIDE,
+    SqlStdOperatorTable.DOT,
+    SqlStdOperatorTable.EQUALS,
+    SqlStdOperatorTable.GREATER_THAN,
+    SqlStdOperatorTable.IS_DISTINCT_FROM,
+    SqlStdOperatorTable.IS_NOT_DISTINCT_FROM,
+    SqlStdOperatorTable.GREATER_THAN_OR_EQUAL,
+    SqlStdOperatorTable.LESS_THAN,
+    SqlStdOperatorTable.LESS_THAN_OR_EQUAL,
+    GeaFlowOverwriteSqlOperators.MINUS,
+    GeaFlowOverwriteSqlOperators.MULTIPLY,
+    SqlStdOperatorTable.NOT_EQUALS,
+    SqlStdOperatorTable.OR,
+    GeaFlowOverwriteSqlOperators.PLUS,
+    SqlStdOperatorTable.DATETIME_PLUS,
+    // POSTFIX OPERATORS
+    SqlStdOperatorTable.DESC,
+    SqlStdOperatorTable.NULLS_FIRST,
+    SqlStdOperatorTable.IS_NOT_NULL,
+    SqlStdOperatorTable.IS_NULL,
+    SqlStdOperatorTable.IS_NOT_TRUE,
+    SqlStdOperatorTable.IS_TRUE,
+    SqlStdOperatorTable.IS_NOT_FALSE,
+    SqlStdOperatorTable.IS_FALSE,
+    SqlStdOperatorTable.IS_NOT_UNKNOWN,
+    SqlStdOperatorTable.IS_UNKNOWN,
+    // PREFIX OPERATORS
+    SqlStdOperatorTable.NOT,
+    SqlStdOperatorTable.UNARY_MINUS,
+    SqlStdOperatorTable.UNARY_PLUS,
+    // GROUPING FUNCTIONS
+    SqlStdOperatorTable.GROUP_ID,
+    SqlStdOperatorTable.GROUPING,
+    SqlStdOperatorTable.GROUPING_ID,
+    // AGGREGATE OPERATORS
+    BuildInSqlOperatorTable.SUM,
+    SqlStdOperatorTable.SUM0,
+    BuildInSqlOperatorTable.COUNT,
+    BuildInSqlOperatorTable.MIN,
+    BuildInSqlOperatorTable.MAX,
+    BuildInSqlOperatorTable.AVG,
+    SqlStdOperatorTable.STDDEV_POP,
+    SqlStdOperatorTable.STDDEV_SAMP,
+    SqlStdOperatorTable.VAR_POP,
+    SqlStdOperatorTable.VAR_SAMP,
+    // ARRAY OPERATORS
+    SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+    SqlStdOperatorTable.ITEM,
+    SqlStdOperatorTable.CARDINALITY,
+    SqlStdOperatorTable.ELEMENT,
+    // SPECIAL OPERATORS
+    SqlStdOperatorTable.ROW,
+    SqlStdOperatorTable.OVERLAPS,
+    SqlStdOperatorTable.LITERAL_CHAIN,
+    SqlStdOperatorTable.BETWEEN,
+    SqlStdOperatorTable.SYMMETRIC_BETWEEN,
+    SqlStdOperatorTable.NOT_BETWEEN,
+    SqlStdOperatorTable.SYMMETRIC_NOT_BETWEEN,
+    SqlStdOperatorTable.NOT_LIKE,
+    SqlStdOperatorTable.LIKE,
+    SqlStdOperatorTable.NOT_SIMILAR_TO,
+    SqlStdOperatorTable.SIMILAR_TO,
+    SqlStdOperatorTable.CASE,
+    SqlStdOperatorTable.REINTERPRET,
+    // FUNCTIONS
+    SqlStdOperatorTable.SUBSTRING,
+    SqlStdOperatorTable.OVERLAY,
+    SqlStdOperatorTable.TRIM,
+    SqlStdOperatorTable.POSITION,
+    SqlStdOperatorTable.CHAR_LENGTH,
+    SqlStdOperatorTable.CHARACTER_LENGTH,
+    SqlStdOperatorTable.UPPER,
+    SqlStdOperatorTable.LOWER,
+    SqlStdOperatorTable.INITCAP,
+    SqlStdOperatorTable.POWER,
+    SqlStdOperatorTable.SQRT,
+    GeaFlowOverwriteSqlOperators.MOD,
+    GeaFlowOverwriteSqlOperators.PERCENT_REMAINDER,
+    SqlStdOperatorTable.LN,
+    SqlStdOperatorTable.LOG10,
+    SqlStdOperatorTable.ABS,
+    SqlStdOperatorTable.EXP,
+    SqlStdOperatorTable.NULLIF,
+    SqlStdOperatorTable.COALESCE,
+    SqlStdOperatorTable.FLOOR,
+    SqlStdOperatorTable.CEIL,
+    SqlStdOperatorTable.LOCALTIME,
+    SqlStdOperatorTable.LOCALTIMESTAMP,
+    SqlStdOperatorTable.CURRENT_TIME,
+    SqlStdOperatorTable.CURRENT_TIMESTAMP,
+    SqlStdOperatorTable.CURRENT_DATE,
+    SqlStdOperatorTable.TIMESTAMP_ADD,
+    SqlStdOperatorTable.TIMESTAMP_DIFF,
+    SqlStdOperatorTable.CAST,
+    SqlStdOperatorTable.EXTRACT,
+    SqlStdOperatorTable.SCALAR_QUERY,
+    SqlStdOperatorTable.EXISTS,
+    SqlStdOperatorTable.SIN,
+    SqlStdOperatorTable.COS,
+    SqlStdOperatorTable.TAN,
+    SqlStdOperatorTable.COT,
+    SqlStdOperatorTable.ASIN,
+    SqlStdOperatorTable.ACOS,
+    SqlStdOperatorTable.ATAN,
+    SqlStdOperatorTable.DEGREES,
+    SqlStdOperatorTable.RADIANS,
+    GeaFlowOverwriteSqlOperators.SIGN,
+    // SqlStdOperatorTable.ROUND,
+    SqlStdOperatorTable.PI,
+    SqlStdOperatorTable.RAND,
+    SqlStdOperatorTable.RAND_INTEGER,
+    // EXTENSIONS
+    SqlStdOperatorTable.TUMBLE,
+    SqlStdOperatorTable.TUMBLE_START,
+    SqlStdOperatorTable.TUMBLE_END,
+    SqlStdOperatorTable.HOP,
+    SqlStdOperatorTable.HOP_START,
+    SqlStdOperatorTable.HOP_END,
+    SqlStdOperatorTable.SESSION,
+    SqlStdOperatorTable.SESSION_START,
+    SqlStdOperatorTable.SESSION_END,
+    SqlStdOperatorTable.RANK,
+    SqlStdOperatorTable.PERCENT_RANK,
+    SqlStdOperatorTable.DENSE_RANK,
+    SqlStdOperatorTable.CUME_DIST,
+    SqlStdOperatorTable.ROW_NUMBER,
+    SqlStdOperatorTable.LAG,
+    SqlStdOperatorTable.LEAD,
+    // ISO-GQL SAME predicate
+    SqlSameOperator.INSTANCE
+  };
 
-        SqlStdOperatorTable.RANK,
-        SqlStdOperatorTable.PERCENT_RANK,
-        SqlStdOperatorTable.DENSE_RANK,
-        SqlStdOperatorTable.CUME_DIST,
-        SqlStdOperatorTable.ROW_NUMBER,
-        SqlStdOperatorTable.LAG,
-        SqlStdOperatorTable.LEAD,
-        // ISO-GQL SAME predicate
-        SqlSameOperator.INSTANCE
-    };
+  public BuildInSqlOperatorTable() {
+    this.register();
+  }
 
-    public BuildInSqlOperatorTable() {
-        this.register();
+  private void register() {
+    for (SqlOperator operator : buildInSqlOperators) {
+      super.register(operator);
     }
-
-    private void register() {
-        for (SqlOperator operator : buildInSqlOperators) {
-            super.register(operator);
-        }
-    }
-
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowBuiltinFunctions.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowBuiltinFunctions.java
index 3c5e669a7..fb9fefb56 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowBuiltinFunctions.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowBuiltinFunctions.java
@@ -25,6 +25,7 @@
 import java.util.Calendar;
 import java.util.Objects;
 import java.util.Random;
+
 import org.apache.commons.lang3.time.DateUtils;
 import org.apache.geaflow.common.binary.BinaryString;
 import org.apache.geaflow.dsl.common.data.RowEdge;
@@ -32,1950 +33,1940 @@
 
 public final class GeaFlowBuiltinFunctions {
 
-    public static Long plus(Long a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Long plus(Long a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static Integer plus(Integer a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Integer plus(Integer a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static Integer plus(Short a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Integer plus(Short a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static Integer plus(Byte a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Integer plus(Byte a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static Double plus(Double a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Double plus(Double a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static BigDecimal plus(BigDecimal a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.add(b);
+  public static BigDecimal plus(BigDecimal a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.add(b);
+  }
 
-    public static Long plus(Long a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Long plus(Long a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static Long plus(Long a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Long plus(Long a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static Long plus(Long a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Long plus(Long a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static Double plus(Long a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Double plus(Long a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static BigDecimal plus(Long a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return new BigDecimal(a).add(b);
+  public static BigDecimal plus(Long a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return new BigDecimal(a).add(b);
+  }
 
-    public static Long plus(Integer a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Long plus(Integer a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static Integer plus(Integer a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Integer plus(Integer a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static Integer plus(Integer a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Integer plus(Integer a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static Double plus(Integer a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Double plus(Integer a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static BigDecimal plus(Integer a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return new BigDecimal(a).add(b);
+  public static BigDecimal plus(Integer a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return new BigDecimal(a).add(b);
+  }
 
-    public static Long plus(Short a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Long plus(Short a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static Integer plus(Short a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Integer plus(Short a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static Integer plus(Short a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Integer plus(Short a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static Double plus(Short a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Double plus(Short a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static BigDecimal plus(Short a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return new BigDecimal(a).add(b);
+  public static BigDecimal plus(Short a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return new BigDecimal(a).add(b);
+  }
 
-    public static Long plus(Byte a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Long plus(Byte a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static Integer plus(Byte a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Integer plus(Byte a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static Integer plus(Byte a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Integer plus(Byte a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static Double plus(Byte a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Double plus(Byte a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static BigDecimal plus(Byte a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return new BigDecimal(a).add(b);
+  public static BigDecimal plus(Byte a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return new BigDecimal(a).add(b);
+  }
 
-    public static Double plus(Double a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Double plus(Double a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static Double plus(Double a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Double plus(Double a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static Double plus(Double a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Double plus(Double a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static Double plus(Double a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
+  public static Double plus(Double a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a + b;
+  }
 
-    public static BigDecimal plus(Double a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return new BigDecimal(a).add(b);
+  public static BigDecimal plus(Double a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return new BigDecimal(a).add(b);
+  }
 
-    public static BigDecimal plus(BigDecimal a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.add(new BigDecimal(b));
+  public static BigDecimal plus(BigDecimal a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.add(new BigDecimal(b));
+  }
 
-    public static BigDecimal plus(BigDecimal a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.add(new BigDecimal(b));
+  public static BigDecimal plus(BigDecimal a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.add(new BigDecimal(b));
+  }
 
-    public static BigDecimal plus(BigDecimal a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.add(new BigDecimal(b));
+  public static BigDecimal plus(BigDecimal a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.add(new BigDecimal(b));
+  }
 
-    public static BigDecimal plus(BigDecimal a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.add(new BigDecimal(b));
+  public static BigDecimal plus(BigDecimal a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.add(new BigDecimal(b));
+  }
 
-    public static BigDecimal plus(BigDecimal a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.add(new BigDecimal(b));
+  public static BigDecimal plus(BigDecimal a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.add(new BigDecimal(b));
+  }
 
-    public static Long minus(Long a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Long minus(Long a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static Integer minus(Integer a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Integer minus(Integer a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static Integer minus(Short a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Integer minus(Short a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static Integer minus(Byte a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Integer minus(Byte a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static Double minus(Double a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Double minus(Double a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static BigDecimal minus(BigDecimal a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.subtract(b);
+  public static BigDecimal minus(BigDecimal a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.subtract(b);
+  }
 
-    public static Long minus(Long a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Long minus(Long a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static Long minus(Long a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Long minus(Long a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static Long minus(Long a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Long minus(Long a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static Double minus(Long a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Double minus(Long a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static BigDecimal minus(Long a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return new BigDecimal(a).subtract(b);
+  public static BigDecimal minus(Long a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return new BigDecimal(a).subtract(b);
+  }
 
-    public static Long minus(Integer a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Long minus(Integer a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static Integer minus(Integer a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Integer minus(Integer a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static Integer minus(Integer a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Integer minus(Integer a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static Double minus(Integer a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Double minus(Integer a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static BigDecimal minus(Integer a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return new BigDecimal(a).subtract(b);
+  public static BigDecimal minus(Integer a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return new BigDecimal(a).subtract(b);
+  }
 
-    public static Long minus(Short a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Long minus(Short a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static Integer minus(Short a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Integer minus(Short a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static Integer minus(Short a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Integer minus(Short a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static Double minus(Short a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Double minus(Short a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static BigDecimal minus(Short a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return new BigDecimal(a).subtract(b);
+  public static BigDecimal minus(Short a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return new BigDecimal(a).subtract(b);
+  }
 
-    public static Long minus(Byte a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Long minus(Byte a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static Integer minus(Byte a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Integer minus(Byte a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static Integer minus(Byte a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Integer minus(Byte a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static Double minus(Byte a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Double minus(Byte a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static BigDecimal minus(Byte a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return new BigDecimal(a).subtract(b);
+  public static BigDecimal minus(Byte a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return new BigDecimal(a).subtract(b);
+  }
 
-    public static Double minus(Double a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Double minus(Double a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static Double minus(Double a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Double minus(Double a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static Double minus(Double a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Double minus(Double a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static Double minus(Double a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a - b;
+  public static Double minus(Double a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a - b;
+  }
 
-    public static BigDecimal minus(Double a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return new BigDecimal(a).subtract(b);
+  public static BigDecimal minus(Double a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return new BigDecimal(a).subtract(b);
+  }
 
-    public static BigDecimal minus(BigDecimal a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.subtract(new BigDecimal(b));
+  public static BigDecimal minus(BigDecimal a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.subtract(new BigDecimal(b));
+  }
 
-    public static BigDecimal minus(BigDecimal a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.subtract(new BigDecimal(b));
+  public static BigDecimal minus(BigDecimal a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.subtract(new BigDecimal(b));
+  }
 
-    public static BigDecimal minus(BigDecimal a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.subtract(new BigDecimal(b));
+  public static BigDecimal minus(BigDecimal a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.subtract(new BigDecimal(b));
+  }
 
-    public static BigDecimal minus(BigDecimal a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.subtract(new BigDecimal(b));
+  public static BigDecimal minus(BigDecimal a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.subtract(new BigDecimal(b));
+  }
 
-    public static BigDecimal minus(BigDecimal a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.subtract(new BigDecimal(b));
+  public static BigDecimal minus(BigDecimal a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.subtract(new BigDecimal(b));
+  }
 
-    public static Long times(Long a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Long times(Long a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static Long times(Integer a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return (long) a * (long) b;
+  public static Long times(Integer a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return (long) a * (long) b;
+  }
 
-    public static Integer times(Short a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Integer times(Short a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static Integer times(Byte a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Integer times(Byte a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static Double times(Double a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Double times(Double a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static BigDecimal times(BigDecimal a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.multiply(b);
+  public static BigDecimal times(BigDecimal a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.multiply(b);
+  }
 
-    public static Long times(Long a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Long times(Long a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static Long times(Long a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Long times(Long a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static Long times(Long a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Long times(Long a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static Double times(Long a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Double times(Long a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static BigDecimal times(Long a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return new BigDecimal(a).multiply(b);
+  public static BigDecimal times(Long a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return new BigDecimal(a).multiply(b);
+  }
 
-    public static Long times(Integer a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Long times(Integer a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static Integer times(Integer a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Integer times(Integer a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static Integer times(Integer a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Integer times(Integer a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static Double times(Integer a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Double times(Integer a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static BigDecimal times(Integer a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return new BigDecimal(a).multiply(b);
+  public static BigDecimal times(Integer a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return new BigDecimal(a).multiply(b);
+  }
 
-    public static Long times(Short a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Long times(Short a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static Integer times(Short a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Integer times(Short a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static Integer times(Short a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Integer times(Short a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static Double times(Short a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Double times(Short a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static BigDecimal times(Short a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return new BigDecimal(a).multiply(b);
+  public static BigDecimal times(Short a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return new BigDecimal(a).multiply(b);
+  }
 
-    public static Long times(Byte a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Long times(Byte a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static Integer times(Byte a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Integer times(Byte a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static Integer times(Byte a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Integer times(Byte a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static Double times(Byte a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Double times(Byte a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static BigDecimal times(Byte a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return new BigDecimal(a).multiply(b);
+  public static BigDecimal times(Byte a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return new BigDecimal(a).multiply(b);
+  }
 
-    public static Double times(Double a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Double times(Double a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static Double times(Double a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Double times(Double a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static Double times(Double a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Double times(Double a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static Double times(Double a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a * b;
+  public static Double times(Double a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a * b;
+  }
 
-    public static BigDecimal times(Double a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return new BigDecimal(a).multiply(b);
+  public static BigDecimal times(Double a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return new BigDecimal(a).multiply(b);
+  }
 
-    public static BigDecimal times(BigDecimal a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.multiply(new BigDecimal(b));
+  public static BigDecimal times(BigDecimal a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.multiply(new BigDecimal(b));
+  }
 
-    public static BigDecimal times(BigDecimal a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.multiply(new BigDecimal(b));
+  public static BigDecimal times(BigDecimal a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.multiply(new BigDecimal(b));
+  }
 
-    public static BigDecimal times(BigDecimal a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.multiply(new BigDecimal(b));
+  public static BigDecimal times(BigDecimal a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.multiply(new BigDecimal(b));
+  }
 
-    public static BigDecimal times(BigDecimal a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.multiply(new BigDecimal(b));
+  public static BigDecimal times(BigDecimal a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.multiply(new BigDecimal(b));
+  }
 
-    public static BigDecimal times(BigDecimal a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.multiply(new BigDecimal(b));
+  public static BigDecimal times(BigDecimal a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.multiply(new BigDecimal(b));
+  }
 
-    public static Long divide(Long a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Long divide(Long a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static Integer divide(Integer a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Integer divide(Integer a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static Integer divide(Short a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Integer divide(Short a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static Integer divide(Byte a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Integer divide(Byte a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static Double divide(Double a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Double divide(Double a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static BigDecimal divide(BigDecimal a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.divide(b);
+  public static BigDecimal divide(BigDecimal a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.divide(b);
+  }
 
-    public static Long divide(Long a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Long divide(Long a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static Long divide(Long a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Long divide(Long a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static Long divide(Long a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Long divide(Long a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static Double divide(Long a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Double divide(Long a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static BigDecimal divide(Long a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return new BigDecimal(a).divide(b);
+  public static BigDecimal divide(Long a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return new BigDecimal(a).divide(b);
+  }
 
-    public static Long divide(Integer a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Long divide(Integer a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static Integer divide(Integer a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Integer divide(Integer a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static Integer divide(Integer a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Integer divide(Integer a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static Double divide(Integer a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Double divide(Integer a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static BigDecimal divide(Integer a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return new BigDecimal(a).divide(b);
+  public static BigDecimal divide(Integer a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return new BigDecimal(a).divide(b);
+  }
 
-    public static Long divide(Short a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Long divide(Short a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static Integer divide(Short a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Integer divide(Short a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static Integer divide(Short a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Integer divide(Short a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static Double divide(Short a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Double divide(Short a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static BigDecimal divide(Short a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return new BigDecimal(a).divide(b);
+  public static BigDecimal divide(Short a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return new BigDecimal(a).divide(b);
+  }
 
-    public static Long divide(Byte a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Long divide(Byte a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static Integer divide(Byte a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Integer divide(Byte a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static Integer divide(Byte a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Integer divide(Byte a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static Double divide(Byte a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Double divide(Byte a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static BigDecimal divide(Byte a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return new BigDecimal(a).divide(b);
+  public static BigDecimal divide(Byte a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return new BigDecimal(a).divide(b);
+  }
 
-    public static Double divide(Double a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Double divide(Double a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static Double divide(Double a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Double divide(Double a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static Double divide(Double a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Double divide(Double a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static Double divide(Double a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a / b;
+  public static Double divide(Double a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a / b;
+  }
 
-    public static BigDecimal divide(Double a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return new BigDecimal(a).divide(b);
+  public static BigDecimal divide(Double a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return new BigDecimal(a).divide(b);
+  }
 
-    public static BigDecimal divide(BigDecimal a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.divide(new BigDecimal(b));
+  public static BigDecimal divide(BigDecimal a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.divide(new BigDecimal(b));
+  }
 
-    public static BigDecimal divide(BigDecimal a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.divide(new BigDecimal(b));
+  public static BigDecimal divide(BigDecimal a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.divide(new BigDecimal(b));
+  }
 
-    public static BigDecimal divide(BigDecimal a, Short b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.divide(new BigDecimal(b));
+  public static BigDecimal divide(BigDecimal a, Short b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.divide(new BigDecimal(b));
+  }
 
-    public static BigDecimal divide(BigDecimal a, Byte b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.divide(new BigDecimal(b));
+  public static BigDecimal divide(BigDecimal a, Byte b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.divide(new BigDecimal(b));
+  }
 
-    public static BigDecimal divide(BigDecimal a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.divide(new BigDecimal(b));
+  public static BigDecimal divide(BigDecimal a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.divide(new BigDecimal(b));
+  }
 
-    public static Long mod(Long a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        if (b == 0) {
-            return null;
-        }
-        return a % b;
+  public static Long mod(Long a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
-
-    public static Double mod(Double a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a % b;
+    if (b == 0) {
+      return null;
     }
+    return a % b;
+  }
 
-    public static Integer mod(Integer a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        if (b == 0) {
-            return null;
-        }
-        return a % b;
+  public static Double mod(Double a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a % b;
+  }
 
-    public static Double power(Double a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-
-        return Math.pow(a, b);
+  public static Integer mod(Integer a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
-
-    public static Long abs(Long a) {
-        if (a == null) {
-            return null;
-        }
-        return Math.abs(a);
+    if (b == 0) {
+      return null;
     }
+    return a % b;
+  }
 
-    public static Double abs(Double a) {
-        if (a == null) {
-            return null;
-        }
-        return Math.abs(a);
+  public static Double power(Double a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
 
-    public static Integer abs(Integer a) {
-        if (a == null) {
-            return null;
-        }
-        return Math.abs(a);
-    }
+    return Math.pow(a, b);
+  }
 
-    public static Integer abs(Short a) {
-        if (a == null) {
-            return null;
-        }
-        return Math.abs(a);
+  public static Long abs(Long a) {
+    if (a == null) {
+      return null;
     }
+    return Math.abs(a);
+  }
 
-    public static Integer abs(Byte a) {
-        if (a == null) {
-            return null;
-        }
-        return Math.abs(a);
+  public static Double abs(Double a) {
+    if (a == null) {
+      return null;
     }
+    return Math.abs(a);
+  }
 
-    public static BigDecimal abs(BigDecimal a) {
-        if (a == null) {
-            return null;
-        }
-        return a.abs();
+  public static Integer abs(Integer a) {
+    if (a == null) {
+      return null;
     }
+    return Math.abs(a);
+  }
 
-    public static Double asin(Double a) {
-        if (a == null) {
-            return null;
-        }
-        return Math.asin(a);
-    }
-
-    public static Double acos(Double a) {
-        if (a == null) {
-            return null;
-        }
-        return Math.acos(a);
+  public static Integer abs(Short a) {
+    if (a == null) {
+      return null;
     }
+    return Math.abs(a);
+  }
 
-    public static Double atan(Double a) {
-        if (a == null) {
-            return null;
-        }
-        return Math.atan(a);
+  public static Integer abs(Byte a) {
+    if (a == null) {
+      return null;
     }
+    return Math.abs(a);
+  }
 
-    public static Double ceil(Double a) {
-        if (a == null) {
-            return null;
-        }
-        return Math.ceil(a);
+  public static BigDecimal abs(BigDecimal a) {
+    if (a == null) {
+      return null;
     }
+    return a.abs();
+  }
 
-    public static Long ceil(Long a) {
-        if (a == null) {
-            return null;
-        }
-        return a;
+  public static Double asin(Double a) {
+    if (a == null) {
+      return null;
     }
+    return Math.asin(a);
+  }
 
-    public static Integer ceil(Integer a) {
-        return a;
+  public static Double acos(Double a) {
+    if (a == null) {
+      return null;
     }
+    return Math.acos(a);
+  }
 
-    public static Double cot(Double a) {
-        if (a == null) {
-            return null;
-        }
-        return 1.0d / Math.tan(a);
+  public static Double atan(Double a) {
+    if (a == null) {
+      return null;
     }
+    return Math.atan(a);
+  }
 
-    public static Double cos(Double a) {
-        if (a == null) {
-            return null;
-        }
-        return Math.cos(a);
+  public static Double ceil(Double a) {
+    if (a == null) {
+      return null;
     }
+    return Math.ceil(a);
+  }
 
-    public static Double degrees(Double a) {
-        if (a == null) {
-            return null;
-        }
-        return Math.toDegrees(a);
+  public static Long ceil(Long a) {
+    if (a == null) {
+      return null;
     }
+    return a;
+  }
 
-    public static Double radians(Double a) {
-        if (a == null) {
-            return null;
-        }
-        return Math.toRadians(a);
-    }
+  public static Integer ceil(Integer a) {
+    return a;
+  }
 
-    public static Double exp(Double a) {
-        if (a == null) {
-            return null;
-        }
-        return Math.exp(a);
+  public static Double cot(Double a) {
+    if (a == null) {
+      return null;
     }
+    return 1.0d / Math.tan(a);
+  }
 
-    public static Double floor(Double a) {
-        if (a == null) {
-            return null;
-        }
-        return Math.floor(a);
+  public static Double cos(Double a) {
+    if (a == null) {
+      return null;
     }
+    return Math.cos(a);
+  }
 
-    public static Long floor(Long a) {
-        return a;
+  public static Double degrees(Double a) {
+    if (a == null) {
+      return null;
     }
+    return Math.toDegrees(a);
+  }
 
-    public static Integer floor(Integer a) {
-        return a;
+  public static Double radians(Double a) {
+    if (a == null) {
+      return null;
     }
+    return Math.toRadians(a);
+  }
 
-    public static Double ln(Double a) {
-        if (a == null) {
-            return null;
-        }
-        return Math.log(a);
+  public static Double exp(Double a) {
+    if (a == null) {
+      return null;
     }
+    return Math.exp(a);
+  }
 
-    public static Double log10(Double a) {
-        if (a == null) {
-            return null;
-        }
-        return Math.log10(a);
+  public static Double floor(Double a) {
+    if (a == null) {
+      return null;
     }
+    return Math.floor(a);
+  }
 
-    public static Long minusPrefix(Long a) {
-        if (a == null) {
-            return null;
-        }
-        return -a;
-    }
+  public static Long floor(Long a) {
+    return a;
+  }
 
-    public static Double minusPrefix(Double a) {
-        if (a == null) {
-            return null;
-        }
-        return -a;
-    }
+  public static Integer floor(Integer a) {
+    return a;
+  }
 
-    public static Integer minusPrefix(Integer a) {
-        if (a == null) {
-            return null;
-        }
-        return -a;
+  public static Double ln(Double a) {
+    if (a == null) {
+      return null;
     }
+    return Math.log(a);
+  }
 
-    public static Integer minusPrefix(Short a) {
-        if (a == null) {
-            return null;
-        }
-        return -a;
+  public static Double log10(Double a) {
+    if (a == null) {
+      return null;
     }
+    return Math.log10(a);
+  }
 
-    public static Integer minusPrefix(Byte a) {
-        if (a == null) {
-            return null;
-        }
-        return -a;
+  public static Long minusPrefix(Long a) {
+    if (a == null) {
+      return null;
     }
+    return -a;
+  }
 
-    public static BigDecimal minusPrefix(BigDecimal a) {
-        if (a == null) {
-            return null;
-        }
-        return a.negate();
+  public static Double minusPrefix(Double a) {
+    if (a == null) {
+      return null;
     }
+    return -a;
+  }
 
-    public static Double rand() {
-        return rand(null);
+  public static Integer minusPrefix(Integer a) {
+    if (a == null) {
+      return null;
     }
+    return -a;
+  }
 
-    public static Double rand(Long seed) {
-        Random random = seed == null ? new Random() : new Random(seed);
-        return random.nextDouble();
+  public static Integer minusPrefix(Short a) {
+    if (a == null) {
+      return null;
     }
+    return -a;
+  }
 
-    public static int rand(Long seed, Integer bound) {
-        Random random = seed == null ? new Random() : new Random(seed);
-        return random.nextInt(bound);
+  public static Integer minusPrefix(Byte a) {
+    if (a == null) {
+      return null;
     }
+    return -a;
+  }
 
-    public static Integer randInt(Integer bound) {
-        Random random = new Random();
-        return random.nextInt(bound);
+  public static BigDecimal minusPrefix(BigDecimal a) {
+    if (a == null) {
+      return null;
     }
+    return a.negate();
+  }
 
-    public static Integer randInt(Long seed, Integer bound) {
-        return rand(seed, bound);
-    }
+  public static Double rand() {
+    return rand(null);
+  }
 
-    public static Double sign(Double a) {
-        if (a == null) {
-            return null;
-        }
-        return Math.signum(a);
-    }
+  public static Double rand(Long seed) {
+    Random random = seed == null ? new Random() : new Random(seed);
+    return random.nextDouble();
+  }
 
-    public static Double sin(Double a) {
-        if (a == null) {
-            return null;
-        }
-        return Math.sin(a);
-    }
+  public static int rand(Long seed, Integer bound) {
+    Random random = seed == null ? new Random() : new Random(seed);
+    return random.nextInt(bound);
+  }
 
-    public static Double tan(Double a) {
-        if (a == null) {
-            return null;
-        }
-        return Math.tan(a);
-    }
+  public static Integer randInt(Integer bound) {
+    Random random = new Random();
+    return random.nextInt(bound);
+  }
 
-    public static Double round(Double a, Integer n) {
-        if (a == null || n == null) {
-            return null;
-        }
+  public static Integer randInt(Long seed, Integer bound) {
+    return rand(seed, bound);
+  }
 
-        if (Double.isNaN(a) || Double.isInfinite(a)) {
-            return a;
-        } else {
-            return BigDecimal.valueOf(a).setScale(n, RoundingMode.HALF_UP)
-                .doubleValue();
-        }
+  public static Double sign(Double a) {
+    if (a == null) {
+      return null;
     }
+    return Math.signum(a);
+  }
 
-    public static Boolean equal(Long a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.longValue() == b.longValue();
+  public static Double sin(Double a) {
+    if (a == null) {
+      return null;
     }
+    return Math.sin(a);
+  }
 
-    public static Boolean equal(Double a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-
-        return a.doubleValue() == b.doubleValue();
+  public static Double tan(Double a) {
+    if (a == null) {
+      return null;
     }
+    return Math.tan(a);
+  }
 
-    public static Boolean equal(BigDecimal a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-
-        return a.compareTo(b) == 0;
+  public static Double round(Double a, Integer n) {
+    if (a == null || n == null) {
+      return null;
     }
 
-    public static Boolean equal(String a, String b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.equals(b);
+    if (Double.isNaN(a) || Double.isInfinite(a)) {
+      return a;
+    } else {
+      return BigDecimal.valueOf(a).setScale(n, RoundingMode.HALF_UP).doubleValue();
     }
+  }
 
-    public static Boolean equal(BinaryString a, BinaryString b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.equals(b);
+  public static Boolean equal(Long a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.longValue() == b.longValue();
+  }
 
-    public static Boolean equal(Boolean a, Boolean b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.equals(b);
+  public static Boolean equal(Double a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
 
-    public static Boolean equal(String a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        int dot = a.indexOf('.');
-        if (dot >= 0) {
-            for (int i = dot + 1; i < a.length(); i++) {
-                if (a.charAt(i) != '0') {
-                    return false;
-                }
-            }
-        }
-        return Integer.valueOf(a).equals(b);
-    }
+    return a.doubleValue() == b.doubleValue();
+  }
 
-    public static Boolean equal(Integer a, String b) {
-        return equal(b, a);
+  public static Boolean equal(BigDecimal a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
 
-    public static Boolean equal(String a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return Double.valueOf(a).equals(b);
-    }
+    return a.compareTo(b) == 0;
+  }
 
-    public static Boolean equal(Double a, String b) {
-        return equal(b, a);
+  public static Boolean equal(String a, String b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.equals(b);
+  }
 
-    public static Boolean equal(String a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        int dot = a.indexOf('.');
-        if (dot >= 0) {
-            for (int i = dot + 1; i < a.length(); i++) {
-                if (a.charAt(i) != '0') {
-                    return false;
-                }
-            }
-            a = a.substring(0, dot);
-        }
-        return Long.valueOf(a).equals(b);
+  public static Boolean equal(BinaryString a, BinaryString b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.equals(b);
+  }
 
-    public static Boolean equal(Long a, String b) {
-        return equal(b, a);
+  public static Boolean equal(Boolean a, Boolean b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.equals(b);
+  }
 
-    public static Boolean equal(String a, Boolean b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return Boolean.valueOf(a).equals(b);
+  public static Boolean equal(String a, Integer b) {
+    if (a == null || b == null) {
+      return null;
     }
+    int dot = a.indexOf('.');
+    if (dot >= 0) {
+      for (int i = dot + 1; i < a.length(); i++) {
+        if (a.charAt(i) != '0') {
+          return false;
+        }
+      }
+    }
+    return Integer.valueOf(a).equals(b);
+  }
+
+  public static Boolean equal(Integer a, String b) {
+    return equal(b, a);
+  }
 
-    public static Boolean equal(Boolean a, String b) {
-        return equal(b, a);
+  public static Boolean equal(String a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return Double.valueOf(a).equals(b);
+  }
 
-    public static Boolean equal(Object a, Object b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.equals(b);
-    }
-
-    /**
-     * ISO-GQL SAME predicate function for vertices.
-     * Checks if two vertices refer to the same element by comparing their IDs.
-     *
-     * @param a first vertex
-     * @param b second vertex
-     * @return true if vertices have the same ID, false otherwise, null if either is null
-     */
-    public static Boolean same(RowVertex a, RowVertex b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return Objects.equals(a.getId(), b.getId());
-    }
-
-    /**
-     * ISO-GQL SAME predicate function for edges.
-     * Checks if two edges refer to the same element by comparing their source and target IDs.
-     *
-     * @param a first edge
-     * @param b second edge
-     * @return true if edges have the same source and target IDs, false otherwise, null if either is null
-     */
-    public static Boolean same(RowEdge a, RowEdge b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return Objects.equals(a.getSrcId(), b.getSrcId())
-            && Objects.equals(a.getTargetId(), b.getTargetId());
-    }
-
-    /**
-     * ISO-GQL SAME predicate function (fallback for mixed or unknown types).
-     * Checks if two graph elements refer to the same element by comparing their identities.
-     * For vertices, compares vertex IDs.
-     * For edges, compares both source and target IDs.
-     *
-     * @param a first element (vertex or edge)
-     * @param b second element (vertex or edge)
-     * @return true if elements have the same identity, false otherwise, null if either is null
-     */
-    public static Boolean same(Object a, Object b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        // Delegate to type-specific overloads when possible
-        if (a instanceof RowVertex && b instanceof RowVertex) {
-            return same((RowVertex) a, (RowVertex) b);
-        }
-        if (a instanceof RowEdge && b instanceof RowEdge) {
-            return same((RowEdge) a, (RowEdge) b);
-        }
-        // Different types cannot be the same
-        return false;
-    }
-
-    /**
-     * ISO-GQL SAME predicate function for multiple elements.
-     * Checks if all elements refer to the same graph element.
-     * Returns true only if all elements are identical (same type and same identity).
-     *
-     * @param elements array of elements to compare (minimum 2 required)
-     * @return true if all elements have the same identity, false otherwise, null if any is null
-     */
-    public static Boolean same(Object... elements) {
-        if (elements == null || elements.length < 2) {
-            return null;
-        }
-        // Check for any null elements
-        for (Object e : elements) {
-            if (e == null) {
-                return null;
-            }
-        }
-        // Compare all elements with the first one
-        Object first = elements[0];
-        for (int i = 1; i < elements.length; i++) {
-            Boolean result = same(first, elements[i]);
-            if (result == null || !result) {
-                return result;
-            }
-        }
-        return true;
+  public static Boolean equal(Double a, String b) {
+    return equal(b, a);
+  }
+
+  public static Boolean equal(String a, Long b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    int dot = a.indexOf('.');
+    if (dot >= 0) {
+      for (int i = dot + 1; i < a.length(); i++) {
+        if (a.charAt(i) != '0') {
+          return false;
+        }
+      }
+      a = a.substring(0, dot);
+    }
+    return Long.valueOf(a).equals(b);
+  }
+
+  public static Boolean equal(Long a, String b) {
+    return equal(b, a);
+  }
+
+  public static Boolean equal(String a, Boolean b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    return Boolean.valueOf(a).equals(b);
+  }
+
+  public static Boolean equal(Boolean a, String b) {
+    return equal(b, a);
+  }
+
+  public static Boolean equal(Object a, Object b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    return a.equals(b);
+  }
+
+  /**
+   * ISO-GQL SAME predicate function for vertices. Checks if two vertices refer to the same element
+   * by comparing their IDs.
+   *
+   * @param a first vertex
+   * @param b second vertex
+   * @return true if vertices have the same ID, false otherwise, null if either is null
+   */
+  public static Boolean same(RowVertex a, RowVertex b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    return Objects.equals(a.getId(), b.getId());
+  }
+
+  /**
+   * ISO-GQL SAME predicate function for edges. Checks if two edges refer to the same element by
+   * comparing their source and target IDs.
+   *
+   * @param a first edge
+   * @param b second edge
+   * @return true if edges have the same source and target IDs, false otherwise, null if either is
+   *     null
+   */
+  public static Boolean same(RowEdge a, RowEdge b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    return Objects.equals(a.getSrcId(), b.getSrcId())
+        && Objects.equals(a.getTargetId(), b.getTargetId());
+  }
+
+  /**
+   * ISO-GQL SAME predicate function (fallback for mixed or unknown types). Checks if two graph
+   * elements refer to the same element by comparing their identities. For vertices, compares vertex
+   * IDs. For edges, compares both source and target IDs.
+   *
+   * @param a first element (vertex or edge)
+   * @param b second element (vertex or edge)
+   * @return true if elements have the same identity, false otherwise, null if either is null
+   */
+  public static Boolean same(Object a, Object b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    // Delegate to type-specific overloads when possible
+    if (a instanceof RowVertex && b instanceof RowVertex) {
+      return same((RowVertex) a, (RowVertex) b);
+    }
+    if (a instanceof RowEdge && b instanceof RowEdge) {
+      return same((RowEdge) a, (RowEdge) b);
+    }
+    // Different types cannot be the same
+    return false;
+  }
+
+  /**
+   * ISO-GQL SAME predicate function for multiple elements. Checks if all elements refer to the same
+   * graph element. Returns true only if all elements are identical (same type and same identity).
+   *
+   * @param elements array of elements to compare (minimum 2 required)
+   * @return true if all elements have the same identity, false otherwise, null if any is null
+   */
+  public static Boolean same(Object... elements) {
+    if (elements == null || elements.length < 2) {
+      return null;
+    }
+    // Check for any null elements
+    for (Object e : elements) {
+      if (e == null) {
+        return null;
+      }
+    }
+    // Compare all elements with the first one
+    Object first = elements[0];
+    for (int i = 1; i < elements.length; i++) {
+      Boolean result = same(first, elements[i]);
+      if (result == null || !result) {
+        return result;
+      }
+    }
+    return true;
+  }
+
+  public static Boolean unequal(Long a, Long b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    return a.longValue() != b.longValue();
+  }
+
+  public static Boolean unequal(Double a, Double b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    return a.doubleValue() != b.doubleValue();
+  }
+
+  public static Boolean unequal(BigDecimal a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    return a.compareTo(b) != 0;
+  }
+
+  public static Boolean unequal(String a, String b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    return !a.equals(b);
+  }
+
+  public static Boolean unequal(Boolean a, Boolean b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    return !a.equals(b);
+  }
+
+  public static Boolean unequal(String a, Integer b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    int dot = a.indexOf('.');
+    if (dot >= 0) {
+      for (int i = dot + 1; i < a.length(); i++) {
+        if (a.charAt(i) != '0') {
+          return true;
+        }
+      }
+    }
+    return !Integer.valueOf(a).equals(b);
+  }
+
+  public static Boolean unequal(Integer a, String b) {
+    Boolean equals = equal(b, a);
+    return equals != null ? !equals : null;
+  }
+
+  public static Boolean unequal(String a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return !Double.valueOf(a).equals(b);
+  }
 
-    public static Boolean unequal(Long a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.longValue() != b.longValue();
+  public static Boolean unequal(Double a, String b) {
+    if (a == null || b == null) {
+      return null;
     }
+    Boolean equals = equal(b, a);
+    return equals != null ? !equals : null;
+  }
 
-    public static Boolean unequal(Double a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.doubleValue() != b.doubleValue();
+  public static Boolean unequal(String a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
-
-    public static Boolean unequal(BigDecimal a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
+    int dot = a.indexOf('.');
+    if (dot >= 0) {
+      for (int i = dot + 1; i < a.length(); i++) {
+        if (a.charAt(i) != '0') {
+          return true;
         }
-        return a.compareTo(b) != 0;
+      }
+      a = a.substring(0, dot);
     }
+    return !Long.valueOf(a).equals(b);
+  }
 
-    public static Boolean unequal(String a, String b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return !a.equals(b);
+  public static Boolean unequal(Long a, String b) {
+    if (a == null || b == null) {
+      return null;
     }
+    Boolean equals = equal(b, a);
+    return equals != null ? !equals : null;
+  }
 
-    public static Boolean unequal(Boolean a, Boolean b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return !a.equals(b);
+  public static Boolean unequal(String a, Boolean b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return !Boolean.valueOf(a).equals(b);
+  }
 
-    public static Boolean unequal(String a, Integer b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        int dot = a.indexOf('.');
-        if (dot >= 0) {
-            for (int i = dot + 1; i < a.length(); i++) {
-                if (a.charAt(i) != '0') {
-                    return true;
-                }
-            }
-        }
-        return !Integer.valueOf(a).equals(b);
+  public static Boolean unequal(Boolean a, String b) {
+    if (a == null || b == null) {
+      return null;
     }
+    Boolean equals = equal(b, a);
+    return equals != null ? !equals : null;
+  }
 
-    public static Boolean unequal(Integer a, String b) {
-        Boolean equals = equal(b, a);
-        return equals != null ? !equals : null;
+  public static Boolean unequal(Object a, Object b) {
+    if (a == null || b == null) {
+      return null;
     }
+    Boolean equals = equal(a, b);
+    return equals != null ? !equals : null;
+  }
 
-    public static Boolean unequal(String a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return !Double.valueOf(a).equals(b);
+  public static Boolean lessThan(Long a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a < b;
+  }
 
-    public static Boolean unequal(Double a, String b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        Boolean equals = equal(b, a);
-        return equals != null ? !equals : null;
+  public static Boolean lessThan(Double a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a < b;
+  }
 
-    public static Boolean unequal(String a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        int dot = a.indexOf('.');
-        if (dot >= 0) {
-            for (int i = dot + 1; i < a.length(); i++) {
-                if (a.charAt(i) != '0') {
-                    return true;
-                }
-            }
-            a = a.substring(0, dot);
-        }
-        return !Long.valueOf(a).equals(b);
+  public static Boolean lessThan(BigDecimal a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.compareTo(b) < 0;
+  }
 
-    public static Boolean unequal(Long a, String b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        Boolean equals = equal(b, a);
-        return equals != null ? !equals : null;
+  public static Boolean lessThan(String a, String b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a.compareTo(b) < 0;
+  }
 
-    public static Boolean unequal(String a, Boolean b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return !Boolean.valueOf(a).equals(b);
+  public static Boolean greaterThanEq(Long a, Long b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a >= b;
+  }
 
-    public static Boolean unequal(Boolean a, String b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        Boolean equals = equal(b, a);
-        return equals != null ? !equals : null;
+  public static Boolean greaterThanEq(Double a, Double b) {
+    if (a == null || b == null) {
+      return null;
     }
+    return a >= b;
+  }
 
-    public static Boolean unequal(Object a, Object b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        Boolean equals = equal(a, b);
-        return equals != null ? !equals : null;
+  public static Boolean greaterThanEq(BigDecimal a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    return a.compareTo(b) >= 0;
+  }
+
+  public static Boolean greaterThanEq(String a, String b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    return a.compareTo(b) >= 0;
+  }
+
+  public static Boolean lessThanEq(Long a, Long b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    return a <= b;
+  }
+
+  public static Boolean lessThanEq(Double a, Double b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    return a <= b;
+  }
+
+  public static Boolean lessThanEq(BigDecimal a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    return a.compareTo(b) <= 0;
+  }
+
+  public static Boolean lessThanEq(String a, String b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    return a.compareTo(b) <= 0;
+  }
+
+  public static Boolean greaterThan(Long a, Long b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    return a > b;
+  }
+
+  public static Boolean greaterThan(Double a, Double b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    return a > b;
+  }
+
+  public static Boolean greaterThan(BigDecimal a, BigDecimal b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    return a.compareTo(b) > 0;
+  }
+
+  public static Boolean greaterThan(String a, String b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    return a.compareTo(b) > 0;
+  }
+
+  public static Timestamp timestampCeil(Timestamp b0, Long b1) {
+    if (b1 == 1000) { // second
+      return new Timestamp(DateUtils.ceiling(b0, Calendar.SECOND).getTime());
+    } else if (b1 == 60000) { // minute
+      return new Timestamp(DateUtils.ceiling(b0, Calendar.MINUTE).getTime());
+    } else if (b1 == 3600000) { // hour
+      return new Timestamp(DateUtils.ceiling(b0, Calendar.HOUR).getTime());
+    } else if (b1 == 86400000) { // day
+      return new Timestamp(DateUtils.ceiling(b0, Calendar.DATE).getTime());
+    } else {
+      throw new RuntimeException();
+    }
+  }
+
+  public static Timestamp timestampTumble(Timestamp b0, Long b1) {
+    if (b1 % 86400000 == 0) {
+      long base = DateUtils.truncate(b0, Calendar.DAY_OF_YEAR).getTime();
+      long day = b0.getDay() * 86400000;
+      long interval = ((day / b1) + 1) * b1 - day;
+      return new Timestamp(base + interval);
+    } else if (b1 % 3600000 == 0) {
+      long base = DateUtils.truncate(b0, Calendar.HOUR).getTime();
+      long hour = b0.getHours() * 3600000;
+      long interval = ((hour / b1) + 1) * b1 - hour;
+      return new Timestamp(base + interval);
+    } else if (b1 % 60000 == 0) {
+      long base = DateUtils.truncate(b0, Calendar.MINUTE).getTime();
+      long minutes = b0.getMinutes() * 60000;
+      long interval = ((minutes / b1) + 1) * b1 - minutes;
+      return new Timestamp(base + interval);
+    } else if (b1 % 1000 == 0) {
+      long base = DateUtils.truncate(b0, Calendar.SECOND).getTime();
+      long second = b0.getSeconds() * 1000;
+      long interval = ((second / b1) + 1) * b1 - second;
+      return new Timestamp(base + interval);
+    } else {
+      throw new RuntimeException();
+    }
+  }
+
+  public static Timestamp timestampFloor(Timestamp b0, Long b1) {
+    if (b1 == 1000) { // second
+      return new Timestamp(DateUtils.truncate(b0, Calendar.SECOND).getTime());
+    } else if (b1 == 60000) { // minute
+      return new Timestamp(DateUtils.truncate(b0, Calendar.MINUTE).getTime());
+    } else if (b1 == 3600000) { // hour
+      return new Timestamp(DateUtils.truncate(b0, Calendar.HOUR).getTime());
+    } else if (b1 == 86400000) { // day
+      return new Timestamp(DateUtils.truncate(b0, Calendar.DATE).getTime());
+    } else {
+      throw new RuntimeException();
+    }
+  }
+
+  public static Timestamp plus(Timestamp d, Long b0) {
+    if (d == null || b0 == null) {
+      return null;
+    }
+    return new Timestamp(d.getTime() + b0);
+  }
+
+  public static Timestamp minus(Timestamp d, Long b0) {
+    if (d == null || b0 == null) {
+      return null;
+    }
+    return new Timestamp(d.getTime() - b0);
+  }
+
+  public static Long minus(Timestamp d1, Timestamp d2) {
+    if (d1 == null || d2 == null) {
+      return null;
+    }
+    return d1.getTime() - d2.getTime();
+  }
+
+  public static Timestamp currentTimestamp() {
+    return new Timestamp(System.currentTimeMillis());
+  }
+
+  public static final int TRIM_BOTH = 0;
+  public static final int TRIM_LEFT = 1;
+  public static final int TRIM_RIGHT = 2;
+
+  public static String concat(String a, String b) {
+    if (a == null || b == null) {
+      return null;
+    }
+    return a + b;
+  }
+
+  public static Integer length(String s) {
+    if (s == null) {
+      return null;
+    }
+    return s.length();
+  }
+
+  public static String lower(String s) {
+    if (s == null) {
+      return null;
+    }
+    return s.toLowerCase();
+  }
+
+  public static String upper(String s) {
+    if (s == null) {
+      return null;
+    }
+    return s.toUpperCase();
+  }
+
+  public static String overlay(String s, String r, int start) {
+    if (s == null || r == null) {
+      return null;
+    }
+    return s.substring(0, start - 1) + r + s.substring(start - 1 + r.length());
+  }
+
+  public static String overlay(String s, String r, int start, int length) {
+    if (s == null || r == null) {
+      return null;
+    }
+    return s.substring(0, start - 1) + r + s.substring(start - 1 + length);
+  }
+
+  public static Integer position(String seek, String s) {
+    if (seek == null || s == null) {
+      return null;
     }
+    return s.indexOf(seek) + 1;
+  }
 
-    public static Boolean lessThan(Long a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a < b;
+  public static Integer position(String seek, String s, Integer from) {
+    if (seek == null || s == null || from == null) {
+      return null;
     }
-
-    public static Boolean lessThan(Double a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a < b;
+    final int from0 = from - 1;
+    if (from0 > s.length() || from0 < 0) {
+      return 0;
     }
+    return s.indexOf(seek, from0) + 1;
+  }
 
-    public static Boolean lessThan(BigDecimal a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.compareTo(b) < 0;
+  public static String substring(String c, Integer s, Integer l) {
+    if (c == null || s == null || l == null) {
+      return null;
     }
-
-    public static Boolean lessThan(String a, String b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.compareTo(b) < 0;
+    int lc = c.length();
+    if (s < 0) {
+      s += lc + 1;
     }
-
-    public static Boolean greaterThanEq(Long a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a >= b;
-    }
-
-    public static Boolean greaterThanEq(Double a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a >= b;
-    }
-
-    public static Boolean greaterThanEq(BigDecimal a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.compareTo(b) >= 0;
-    }
-
-    public static Boolean greaterThanEq(String a, String b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.compareTo(b) >= 0;
-    }
-
-    public static Boolean lessThanEq(Long a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a <= b;
-    }
-
-    public static Boolean lessThanEq(Double a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a <= b;
-    }
-
-    public static Boolean lessThanEq(BigDecimal a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.compareTo(b) <= 0;
+    int e = s + l;
+    if (e < s) {
+      return null;
     }
-
-    public static Boolean lessThanEq(String a, String b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.compareTo(b) <= 0;
-    }
-
-    public static Boolean greaterThan(Long a, Long b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a > b;
+    if (s > lc || e < 1) {
+      return "";
     }
+    int s1 = Math.max(s, 1);
+    int e1 = Math.min(e, lc + 1);
+    return c.substring(s1 - 1, e1 - 1);
+  }
 
-    public static Boolean greaterThan(Double a, Double b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a > b;
+  public static String substring(String c, Integer s) {
+    if (c == null || s == null) {
+      return null;
     }
+    return substring(c, s, c.length() + 1);
+  }
 
-    public static Boolean greaterThan(BigDecimal a, BigDecimal b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a.compareTo(b) > 0;
+  public static String trim(Integer flag, String removeStr, String str) {
+    if (flag == null || removeStr == null || str == null) {
+      return null;
     }
-
-    public static Boolean greaterThan(String a, String b) {
-        if (a == null || b == null) {
-            return null;
+    switch (flag) {
+      case TRIM_BOTH:
+        while (str.startsWith(removeStr)) {
+          str = str.substring(removeStr.length());
         }
-        return a.compareTo(b) > 0;
-    }
-
-    public static Timestamp timestampCeil(Timestamp b0, Long b1) {
-        if (b1 == 1000) { // second
-            return new Timestamp(DateUtils.ceiling(b0, Calendar.SECOND).getTime());
-        } else if (b1 == 60000) { // minute
-            return new Timestamp(DateUtils.ceiling(b0, Calendar.MINUTE).getTime());
-        } else if (b1 == 3600000) { // hour
-            return new Timestamp(DateUtils.ceiling(b0, Calendar.HOUR).getTime());
-        } else if (b1 == 86400000) { // day
-            return new Timestamp(DateUtils.ceiling(b0, Calendar.DATE).getTime());
-        } else {
-            throw new RuntimeException();
+        while (str.endsWith(removeStr)) {
+          str = str.substring(0, str.length() - removeStr.length());
         }
-    }
-
-    public static Timestamp timestampTumble(Timestamp b0, Long b1) {
-        if (b1 % 86400000 == 0) {
-            long base = DateUtils.truncate(b0, Calendar.DAY_OF_YEAR).getTime();
-            long day = b0.getDay() * 86400000;
-            long interval = ((day / b1) + 1) * b1 - day;
-            return new Timestamp(base + interval);
-        } else if (b1 % 3600000 == 0) {
-            long base = DateUtils.truncate(b0, Calendar.HOUR).getTime();
-            long hour = b0.getHours() * 3600000;
-            long interval = ((hour / b1) + 1) * b1 - hour;
-            return new Timestamp(base + interval);
-        } else if (b1 % 60000 == 0) {
-            long base = DateUtils.truncate(b0, Calendar.MINUTE).getTime();
-            long minutes = b0.getMinutes() * 60000;
-            long interval = ((minutes / b1) + 1) * b1 - minutes;
-            return new Timestamp(base + interval);
-        } else if (b1 % 1000 == 0) {
-            long base = DateUtils.truncate(b0, Calendar.SECOND).getTime();
-            long second = b0.getSeconds() * 1000;
-            long interval = ((second / b1) + 1) * b1 - second;
-            return new Timestamp(base + interval);
-        } else {
-            throw new RuntimeException();
+        return str;
+      case TRIM_LEFT:
+        while (str.startsWith(removeStr)) {
+          str = str.substring(removeStr.length());
         }
-    }
-
-    public static Timestamp timestampFloor(Timestamp b0, Long b1) {
-        if (b1 == 1000) { // second
-            return new Timestamp(DateUtils.truncate(b0, Calendar.SECOND).getTime());
-        } else if (b1 == 60000) { // minute
-            return new Timestamp(DateUtils.truncate(b0, Calendar.MINUTE).getTime());
-        } else if (b1 == 3600000) { // hour
-            return new Timestamp(DateUtils.truncate(b0, Calendar.HOUR).getTime());
-        } else if (b1 == 86400000) { // day
-            return new Timestamp(DateUtils.truncate(b0, Calendar.DATE).getTime());
-        } else {
-            throw new RuntimeException();
+        return str;
+      case TRIM_RIGHT:
+        while (str.endsWith(removeStr)) {
+          str = str.substring(0, str.length() - removeStr.length());
         }
+        return str;
+      default:
+        throw new UnsupportedOperationException("Not support trim flag: " + flag);
     }
+  }
 
-    public static Timestamp plus(Timestamp d, Long b0) {
-        if (d == null || b0 == null) {
-            return null;
-        }
-        return new Timestamp(d.getTime() + b0);
+  public static BinaryString trim(Integer flag, BinaryString removeStr, BinaryString str) {
+    if (flag == null || removeStr == null || str == null) {
+      return null;
     }
-
-    public static Timestamp minus(Timestamp d, Long b0) {
-        if (d == null || b0 == null) {
-            return null;
+    switch (flag) {
+      case TRIM_BOTH:
+        while (str.startsWith(removeStr)) {
+          str = str.substring(removeStr.getLength());
         }
-        return new Timestamp(d.getTime() - b0);
-    }
-
-    public static Long minus(Timestamp d1, Timestamp d2) {
-        if (d1 == null || d2 == null) {
-            return null;
+        while (str.endsWith(removeStr)) {
+          str = str.substring(0, str.getLength() - removeStr.getLength());
         }
-        return d1.getTime() - d2.getTime();
-    }
-
-    public static Timestamp currentTimestamp() {
-        return new Timestamp(System.currentTimeMillis());
-    }
-
-    public static final int TRIM_BOTH = 0;
-    public static final int TRIM_LEFT = 1;
-    public static final int TRIM_RIGHT = 2;
-
-    public static String concat(String a, String b) {
-        if (a == null || b == null) {
-            return null;
-        }
-        return a + b;
-    }
-
-    public static Integer length(String s) {
-        if (s == null) {
-            return null;
-        }
-        return s.length();
-    }
-
-    public static String lower(String s) {
-        if (s == null) {
-            return null;
-        }
-        return s.toLowerCase();
-    }
-
-    public static String upper(String s) {
-        if (s == null) {
-            return null;
-        }
-        return s.toUpperCase();
-    }
-
-    public static String overlay(String s, String r, int start) {
-        if (s == null || r == null) {
-            return null;
-        }
-        return s.substring(0, start - 1)
-            + r
-            + s.substring(start - 1 + r.length());
-    }
-
-    public static String overlay(String s, String r, int start, int length) {
-        if (s == null || r == null) {
-            return null;
-        }
-        return s.substring(0, start - 1)
-            + r
-            + s.substring(start - 1 + length);
-    }
-
-    public static Integer position(String seek, String s) {
-        if (seek == null || s == null) {
-            return null;
-        }
-        return s.indexOf(seek) + 1;
-    }
-
-    public static Integer position(String seek, String s, Integer from) {
-        if (seek == null
-            || s == null || from == null) {
-            return null;
-        }
-        final int from0 = from - 1;
-        if (from0 > s.length() || from0 < 0) {
-            return 0;
-        }
-        return s.indexOf(seek, from0) + 1;
-    }
-
-    public static String substring(String c, Integer s, Integer l) {
-        if (c == null
-            || s == null || l == null) {
-            return null;
-        }
-        int lc = c.length();
-        if (s < 0) {
-            s += lc + 1;
-        }
-        int e = s + l;
-        if (e < s) {
-            return null;
-        }
-        if (s > lc || e < 1) {
-            return "";
-        }
-        int s1 = Math.max(s, 1);
-        int e1 = Math.min(e, lc + 1);
-        return c.substring(s1 - 1, e1 - 1);
-    }
-
-    public static String substring(String c, Integer s) {
-        if (c == null || s == null) {
-            return null;
-        }
-        return substring(c, s, c.length() + 1);
-    }
-
-    public static String trim(Integer flag,
-                              String removeStr, String str) {
-        if (flag == null || removeStr == null || str == null) {
-            return null;
-        }
-        switch (flag) {
-            case TRIM_BOTH:
-                while (str.startsWith(removeStr)) {
-                    str = str.substring(removeStr.length());
-                }
-                while (str.endsWith(removeStr)) {
-                    str = str.substring(0, str.length() - removeStr.length());
-                }
-                return str;
-            case TRIM_LEFT:
-                while (str.startsWith(removeStr)) {
-                    str = str.substring(removeStr.length());
-                }
-                return str;
-            case TRIM_RIGHT:
-                while (str.endsWith(removeStr)) {
-                    str = str.substring(0, str.length() - removeStr.length());
-                }
-                return str;
-            default:
-                throw new UnsupportedOperationException("Not support trim flag: " + flag);
-        }
-    }
-
-    public static BinaryString trim(Integer flag,
-                                    BinaryString removeStr, BinaryString str) {
-        if (flag == null || removeStr == null || str == null) {
-            return null;
+        return str;
+      case TRIM_LEFT:
+        while (str.startsWith(removeStr)) {
+          str = str.substring(removeStr.getLength());
         }
-        switch (flag) {
-            case TRIM_BOTH:
-                while (str.startsWith(removeStr)) {
-                    str = str.substring(removeStr.getLength());
-                }
-                while (str.endsWith(removeStr)) {
-                    str = str.substring(0, str.getLength() - removeStr.getLength());
-                }
-                return str;
-            case TRIM_LEFT:
-                while (str.startsWith(removeStr)) {
-                    str = str.substring(removeStr.getLength());
-                }
-                return str;
-            case TRIM_RIGHT:
-                while (str.endsWith(removeStr)) {
-                    str = str.substring(0, str.getLength() - removeStr.getLength());
-                }
-                return str;
-            default:
-                throw new UnsupportedOperationException("Not support trim flag: " + flag);
+        return str;
+      case TRIM_RIGHT:
+        while (str.endsWith(removeStr)) {
+          str = str.substring(0, str.getLength() - removeStr.getLength());
         }
+        return str;
+      default:
+        throw new UnsupportedOperationException("Not support trim flag: " + flag);
     }
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowOverwriteSqlOperators.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowOverwriteSqlOperators.java
index c3f8cc6f2..23eedabb1 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowOverwriteSqlOperators.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowOverwriteSqlOperators.java
@@ -32,74 +32,74 @@
 
 public class GeaFlowOverwriteSqlOperators {
 
-    public static final SqlBinaryOperator PLUS =
-        new SqlMonotonicBinaryOperator(
-            "+",
-            SqlKind.PLUS,
-            40,
-            true,
-            getSqlReturnTypeInference(GeaFlowBuiltinFunctions.class, "plus"),
-            InferTypes.FIRST_KNOWN,
-            OperandTypes.PLUS_OPERATOR);
+  public static final SqlBinaryOperator PLUS =
+      new SqlMonotonicBinaryOperator(
+          "+",
+          SqlKind.PLUS,
+          40,
+          true,
+          getSqlReturnTypeInference(GeaFlowBuiltinFunctions.class, "plus"),
+          InferTypes.FIRST_KNOWN,
+          OperandTypes.PLUS_OPERATOR);
 
-    public static final SqlBinaryOperator MINUS =
-        new SqlMonotonicBinaryOperator(
-            "-",
-            SqlKind.MINUS,
-            40,
-            true,
-            // Same type inference strategy as sum
-            getSqlReturnTypeInference(GeaFlowBuiltinFunctions.class, "plus"),
-            InferTypes.FIRST_KNOWN,
-            OperandTypes.MINUS_OPERATOR);
+  public static final SqlBinaryOperator MINUS =
+      new SqlMonotonicBinaryOperator(
+          "-",
+          SqlKind.MINUS,
+          40,
+          true,
+          // Same type inference strategy as sum
+          getSqlReturnTypeInference(GeaFlowBuiltinFunctions.class, "plus"),
+          InferTypes.FIRST_KNOWN,
+          OperandTypes.MINUS_OPERATOR);
 
-    public static final SqlBinaryOperator DIVIDE =
-        new SqlBinaryOperator(
-            "/",
-            SqlKind.DIVIDE,
-            60,
-            true,
-            getSqlReturnTypeInference(GeaFlowBuiltinFunctions.class, "divide"),
-            InferTypes.FIRST_KNOWN,
-            OperandTypes.DIVISION_OPERATOR);
+  public static final SqlBinaryOperator DIVIDE =
+      new SqlBinaryOperator(
+          "/",
+          SqlKind.DIVIDE,
+          60,
+          true,
+          getSqlReturnTypeInference(GeaFlowBuiltinFunctions.class, "divide"),
+          InferTypes.FIRST_KNOWN,
+          OperandTypes.DIVISION_OPERATOR);
 
-    public static final SqlBinaryOperator MULTIPLY =
-        new SqlMonotonicBinaryOperator(
-            "*",
-            SqlKind.TIMES,
-            60,
-            true,
-            getSqlReturnTypeInference(GeaFlowBuiltinFunctions.class, "times"),
-            InferTypes.FIRST_KNOWN,
-            OperandTypes.MULTIPLY_OPERATOR);
+  public static final SqlBinaryOperator MULTIPLY =
+      new SqlMonotonicBinaryOperator(
+          "*",
+          SqlKind.TIMES,
+          60,
+          true,
+          getSqlReturnTypeInference(GeaFlowBuiltinFunctions.class, "times"),
+          InferTypes.FIRST_KNOWN,
+          OperandTypes.MULTIPLY_OPERATOR);
 
-    public static final SqlFunction MOD =
-        // Return type is same as divisor (2nd operand)
-        // SQL2003 Part2 Section 6.27, Syntax Rules 9
-        new SqlFunction(
-            "MOD",
-            SqlKind.MOD,
-            getSqlReturnTypeInference(GeaFlowBuiltinFunctions.class, "mod"),
-            null,
-            OperandTypes.EXACT_NUMERIC_EXACT_NUMERIC,
-            SqlFunctionCategory.NUMERIC);
+  public static final SqlFunction MOD =
+      // Return type is same as divisor (2nd operand)
+      // SQL2003 Part2 Section 6.27, Syntax Rules 9
+      new SqlFunction(
+          "MOD",
+          SqlKind.MOD,
+          getSqlReturnTypeInference(GeaFlowBuiltinFunctions.class, "mod"),
+          null,
+          OperandTypes.EXACT_NUMERIC_EXACT_NUMERIC,
+          SqlFunctionCategory.NUMERIC);
 
-    public static final SqlBinaryOperator PERCENT_REMAINDER =
-        new SqlBinaryOperator(
-            "%",
-            SqlKind.MOD,
-            60,
-            true,
-            getSqlReturnTypeInference(GeaFlowBuiltinFunctions.class, "mod"),
-            null,
-            OperandTypes.EXACT_NUMERIC_EXACT_NUMERIC);
+  public static final SqlBinaryOperator PERCENT_REMAINDER =
+      new SqlBinaryOperator(
+          "%",
+          SqlKind.MOD,
+          60,
+          true,
+          getSqlReturnTypeInference(GeaFlowBuiltinFunctions.class, "mod"),
+          null,
+          OperandTypes.EXACT_NUMERIC_EXACT_NUMERIC);
 
-    public static final SqlFunction SIGN =
-        new SqlFunction(
-            "SIGN",
-            SqlKind.OTHER_FUNCTION,
-            ReturnTypes.DOUBLE_NULLABLE,
-            null,
-            OperandTypes.NUMERIC,
-            SqlFunctionCategory.NUMERIC);
+  public static final SqlFunction SIGN =
+      new SqlFunction(
+          "SIGN",
+          SqlKind.OTHER_FUNCTION,
+          ReturnTypes.DOUBLE_NULLABLE,
+          null,
+          OperandTypes.NUMERIC,
+          SqlFunctionCategory.NUMERIC);
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowUserDefinedAggFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowUserDefinedAggFunction.java
index cc99e707f..03f5f42d3 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowUserDefinedAggFunction.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowUserDefinedAggFunction.java
@@ -23,6 +23,7 @@
 import java.lang.reflect.Type;
 import java.util.List;
 import java.util.Objects;
+
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.sql.SqlCallBinding;
 import org.apache.calcite.sql.SqlIdentifier;
@@ -40,108 +41,111 @@
 import org.apache.geaflow.dsl.planner.GQLJavaTypeFactory;
 import org.apache.geaflow.dsl.util.SqlTypeUtil;
 
-public class GeaFlowUserDefinedAggFunction extends SqlUserDefinedAggFunction implements Serializable {
-
-    private List>> udafClasses;
-
-    private GeaFlowUserDefinedAggFunction(String name, List>> udafClasses, GQLJavaTypeFactory typeFactory) {
-        super(new SqlIdentifier(name, SqlParserPos.ZERO),
-            getReturnTypeInference(name, udafClasses, typeFactory),
-            getOperandTypeInference(name, udafClasses, typeFactory),
-            getSqlOperandTypeChecker(name, udafClasses, typeFactory),
-            null, false, false,
-            Optionality.FORBIDDEN,
-            typeFactory);
-        this.udafClasses = Objects.requireNonNull(udafClasses);
-    }
-
-    public static GeaFlowUserDefinedAggFunction create(String name,
-                                                       List>> clazzs,
-                                                       GQLJavaTypeFactory typeFactory) {
-        return new GeaFlowUserDefinedAggFunction(name, clazzs, typeFactory);
-    }
-
-    private static SqlReturnTypeInference getReturnTypeInference(String name,
-                                                                 List>> clazzs,
-                                                                 GQLJavaTypeFactory typeFactory) {
-        return opBinding -> {
-            List> callParamTypes = SqlTypeUtil
-                .convertToJavaTypes(opBinding.collectOperandTypes(), typeFactory);
-            Class clazz = FunctionCallUtils.findMatchUDAF(name, clazzs, callParamTypes);
-
-            Type[] genericTypes = FunctionCallUtils.getUDAFGenericTypes(clazz);
-            Type aggOutputType = genericTypes[2];
-            return typeFactory.createType(aggOutputType);
-        };
-    }
-
-    private static SqlOperandTypeInference getOperandTypeInference(String name,
-                                                                   List>> clazzs,
-                                                                   GQLJavaTypeFactory typeFactory) {
-        return (callBinding, returnType, operandTypes) -> {
-            List> callParamTypes = SqlTypeUtil
-                .convertToJavaTypes(callBinding.collectOperandTypes(), typeFactory);
-            Class clazz = FunctionCallUtils.findMatchUDAF(name, clazzs, callParamTypes);
-            List> aggInputTypes = FunctionCallUtils.getUDAFInputTypes(clazz);
-            for (int i = 0; i < operandTypes.length; i++) {
-                operandTypes[i] = typeFactory.createType(aggInputTypes.get(i));
-            }
-        };
-    }
-
-    private static SqlOperandTypeChecker getSqlOperandTypeChecker(String name, List>> clazzs,
-                                                                  GQLJavaTypeFactory typeFactory) {
-        return new SqlOperandTypeChecker() {
-            @Override
-            public boolean checkOperandTypes(SqlCallBinding callBinding,
-                                             boolean throwOnFailure) {
-                List> callParamTypes = SqlTypeUtil
-                    .convertToJavaTypes(callBinding.collectOperandTypes(), typeFactory);
-                FunctionCallUtils.findMatchUDAF(name, clazzs, callParamTypes);
-                return true;
-            }
-
-            @Override
-            public SqlOperandCountRange getOperandCountRange() {
-                int max = -1;
-                int min = 255;
-
-                for (Class clazz : clazzs) {
-                    List> inputTypes = FunctionCallUtils.getUDAFInputTypes(clazz);
-                    int size = inputTypes.size();
-                    if (size > max) {
-                        max = size;
-                    }
-                    if (size < min) {
-                        min = size;
-                    }
-                }
-                return SqlOperandCountRanges.between(min, max);
-            }
-
-            @Override
-            public String getAllowedSignatures(SqlOperator op, String opName) {
-                return opName + clazzs.toString();
-            }
-
-            @Override
-            public Consistency getConsistency() {
-                return Consistency.NONE;
-            }
-
-            @Override
-            public boolean isOptional(int i) {
-                return false;
-            }
-        };
-    }
-
-    @Override
-    public List getParamTypes() {
-        return null;
-    }
-
-    public List>> getUdafClasses() {
-        return udafClasses;
-    }
+public class GeaFlowUserDefinedAggFunction extends SqlUserDefinedAggFunction
+    implements Serializable {
+
+  private List>> udafClasses;
+
+  private GeaFlowUserDefinedAggFunction(
+      String name,
+      List>> udafClasses,
+      GQLJavaTypeFactory typeFactory) {
+    super(
+        new SqlIdentifier(name, SqlParserPos.ZERO),
+        getReturnTypeInference(name, udafClasses, typeFactory),
+        getOperandTypeInference(name, udafClasses, typeFactory),
+        getSqlOperandTypeChecker(name, udafClasses, typeFactory),
+        null,
+        false,
+        false,
+        Optionality.FORBIDDEN,
+        typeFactory);
+    this.udafClasses = Objects.requireNonNull(udafClasses);
+  }
+
+  public static GeaFlowUserDefinedAggFunction create(
+      String name, List>> clazzs, GQLJavaTypeFactory typeFactory) {
+    return new GeaFlowUserDefinedAggFunction(name, clazzs, typeFactory);
+  }
+
+  private static SqlReturnTypeInference getReturnTypeInference(
+      String name, List>> clazzs, GQLJavaTypeFactory typeFactory) {
+    return opBinding -> {
+      List> callParamTypes =
+          SqlTypeUtil.convertToJavaTypes(opBinding.collectOperandTypes(), typeFactory);
+      Class clazz = FunctionCallUtils.findMatchUDAF(name, clazzs, callParamTypes);
+
+      Type[] genericTypes = FunctionCallUtils.getUDAFGenericTypes(clazz);
+      Type aggOutputType = genericTypes[2];
+      return typeFactory.createType(aggOutputType);
+    };
+  }
+
+  private static SqlOperandTypeInference getOperandTypeInference(
+      String name, List>> clazzs, GQLJavaTypeFactory typeFactory) {
+    return (callBinding, returnType, operandTypes) -> {
+      List> callParamTypes =
+          SqlTypeUtil.convertToJavaTypes(callBinding.collectOperandTypes(), typeFactory);
+      Class clazz = FunctionCallUtils.findMatchUDAF(name, clazzs, callParamTypes);
+      List> aggInputTypes = FunctionCallUtils.getUDAFInputTypes(clazz);
+      for (int i = 0; i < operandTypes.length; i++) {
+        operandTypes[i] = typeFactory.createType(aggInputTypes.get(i));
+      }
+    };
+  }
+
+  private static SqlOperandTypeChecker getSqlOperandTypeChecker(
+      String name, List>> clazzs, GQLJavaTypeFactory typeFactory) {
+    return new SqlOperandTypeChecker() {
+      @Override
+      public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
+        List> callParamTypes =
+            SqlTypeUtil.convertToJavaTypes(callBinding.collectOperandTypes(), typeFactory);
+        FunctionCallUtils.findMatchUDAF(name, clazzs, callParamTypes);
+        return true;
+      }
+
+      @Override
+      public SqlOperandCountRange getOperandCountRange() {
+        int max = -1;
+        int min = 255;
+
+        for (Class clazz : clazzs) {
+          List> inputTypes = FunctionCallUtils.getUDAFInputTypes(clazz);
+          int size = inputTypes.size();
+          if (size > max) {
+            max = size;
+          }
+          if (size < min) {
+            min = size;
+          }
+        }
+        return SqlOperandCountRanges.between(min, max);
+      }
+
+      @Override
+      public String getAllowedSignatures(SqlOperator op, String opName) {
+        return opName + clazzs.toString();
+      }
+
+      @Override
+      public Consistency getConsistency() {
+        return Consistency.NONE;
+      }
+
+      @Override
+      public boolean isOptional(int i) {
+        return false;
+      }
+    };
+  }
+
+  @Override
+  public List getParamTypes() {
+    return null;
+  }
+
+  public List>> getUdafClasses() {
+    return udafClasses;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowUserDefinedGraphAlgorithm.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowUserDefinedGraphAlgorithm.java
index 448c5033a..af72dda66 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowUserDefinedGraphAlgorithm.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowUserDefinedGraphAlgorithm.java
@@ -33,45 +33,46 @@
 
 public class GeaFlowUserDefinedGraphAlgorithm extends SqlUserDefinedFunction {
 
-    private final Class implementClass;
+  private final Class implementClass;
 
-    private GeaFlowUserDefinedGraphAlgorithm(String name, Class clazz,
-                                             GQLJavaTypeFactory typeFactory) {
-        super(new SqlIdentifier(name, SqlParserPos.ZERO),
-            getReturnTypeInference(clazz),
-            FunctionUtil.getSqlOperandTypeInference(clazz, typeFactory),
-            FunctionUtil.getSqlOperandTypeChecker(name, clazz, typeFactory),
-            null, null);
-        this.implementClass = clazz;
-    }
+  private GeaFlowUserDefinedGraphAlgorithm(
+      String name, Class clazz, GQLJavaTypeFactory typeFactory) {
+    super(
+        new SqlIdentifier(name, SqlParserPos.ZERO),
+        getReturnTypeInference(clazz),
+        FunctionUtil.getSqlOperandTypeInference(clazz, typeFactory),
+        FunctionUtil.getSqlOperandTypeChecker(name, clazz, typeFactory),
+        null,
+        null);
+    this.implementClass = clazz;
+  }
 
-    public static GeaFlowUserDefinedGraphAlgorithm create(String name, Class clazz,
-                                                          GQLJavaTypeFactory typeFactory) {
-        return new GeaFlowUserDefinedGraphAlgorithm(name, clazz, typeFactory);
-    }
+  public static GeaFlowUserDefinedGraphAlgorithm create(
+      String name, Class clazz, GQLJavaTypeFactory typeFactory) {
+    return new GeaFlowUserDefinedGraphAlgorithm(name, clazz, typeFactory);
+  }
 
-    @Override
-    public SqlKind getKind() {
-        return SqlKind.GQL_ALGORITHM;
-    }
+  @Override
+  public SqlKind getKind() {
+    return SqlKind.GQL_ALGORITHM;
+  }
 
-    private static SqlReturnTypeInference getReturnTypeInference(final Class clazz) {
-        return opBinding -> {
-            final GQLJavaTypeFactory typeFactory = (GQLJavaTypeFactory) opBinding.getTypeFactory();
-            AlgorithmUserFunction algorithm;
-            try {
-                algorithm = (AlgorithmUserFunction) clazz.getConstructor().newInstance();
-            } catch (Exception e) {
-                throw new GeaFlowDSLException("Cannot new instance for class: " + clazz.getName(), e);
-            }
-            final StructType outputType =
-                algorithm.getOutputType(typeFactory.getCurrentGraph().getGraphSchema(typeFactory));
-            return SqlTypeUtil.convertToRelType(outputType, true, typeFactory);
-        };
-    }
-
-    public Class getImplementClass() {
-        return implementClass;
-    }
+  private static SqlReturnTypeInference getReturnTypeInference(final Class clazz) {
+    return opBinding -> {
+      final GQLJavaTypeFactory typeFactory = (GQLJavaTypeFactory) opBinding.getTypeFactory();
+      AlgorithmUserFunction algorithm;
+      try {
+        algorithm = (AlgorithmUserFunction) clazz.getConstructor().newInstance();
+      } catch (Exception e) {
+        throw new GeaFlowDSLException("Cannot new instance for class: " + clazz.getName(), e);
+      }
+      final StructType outputType =
+          algorithm.getOutputType(typeFactory.getCurrentGraph().getGraphSchema(typeFactory));
+      return SqlTypeUtil.convertToRelType(outputType, true, typeFactory);
+    };
+  }
 
+  public Class getImplementClass() {
+    return implementClass;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowUserDefinedScalarFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowUserDefinedScalarFunction.java
index 14551541c..a33e2524d 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowUserDefinedScalarFunction.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowUserDefinedScalarFunction.java
@@ -32,41 +32,45 @@
 
 public class GeaFlowUserDefinedScalarFunction extends SqlUserDefinedFunction {
 
-    private final Class implementClass;
+  private final Class implementClass;
 
-    private GeaFlowUserDefinedScalarFunction(String name, Class clazz,
-                                             GQLJavaTypeFactory typeFactory) {
-        super(new SqlIdentifier(name, SqlParserPos.ZERO),
-            getReturnTypeInference(clazz),
-            FunctionUtil.getSqlOperandTypeInference(clazz, typeFactory),
-            FunctionUtil.getSqlOperandTypeChecker(name, clazz, typeFactory),
-            null, null);
-        this.implementClass = clazz;
-    }
+  private GeaFlowUserDefinedScalarFunction(
+      String name, Class clazz, GQLJavaTypeFactory typeFactory) {
+    super(
+        new SqlIdentifier(name, SqlParserPos.ZERO),
+        getReturnTypeInference(clazz),
+        FunctionUtil.getSqlOperandTypeInference(clazz, typeFactory),
+        FunctionUtil.getSqlOperandTypeChecker(name, clazz, typeFactory),
+        null,
+        null);
+    this.implementClass = clazz;
+  }
 
-    public static GeaFlowUserDefinedScalarFunction create(String name, Class clazz,
-                                                          GQLJavaTypeFactory typeFactory) {
-        return new GeaFlowUserDefinedScalarFunction(name, clazz, typeFactory);
-    }
+  public static GeaFlowUserDefinedScalarFunction create(
+      String name, Class clazz, GQLJavaTypeFactory typeFactory) {
+    return new GeaFlowUserDefinedScalarFunction(name, clazz, typeFactory);
+  }
 
-    @SuppressWarnings("unchecked")
-    public static GeaFlowUserDefinedScalarFunction create(GeaFlowFunction function, GQLJavaTypeFactory typeFactory) {
-        String name = function.getName();
-        String className = function.getClazz().get(0);
-        try {
-            Class clazz = (Class) Thread.currentThread()
-                .getContextClassLoader().loadClass(className);
-            return new GeaFlowUserDefinedScalarFunction(name, clazz, typeFactory);
-        } catch (ClassNotFoundException e) {
-            throw new GeaFlowDSLException(e);
-        }
+  @SuppressWarnings("unchecked")
+  public static GeaFlowUserDefinedScalarFunction create(
+      GeaFlowFunction function, GQLJavaTypeFactory typeFactory) {
+    String name = function.getName();
+    String className = function.getClazz().get(0);
+    try {
+      Class clazz =
+          (Class)
+              Thread.currentThread().getContextClassLoader().loadClass(className);
+      return new GeaFlowUserDefinedScalarFunction(name, clazz, typeFactory);
+    } catch (ClassNotFoundException e) {
+      throw new GeaFlowDSLException(e);
     }
+  }
 
-    private static SqlReturnTypeInference getReturnTypeInference(final Class clazz) {
-        return FunctionUtil.getSqlReturnTypeInference(clazz, FunctionCallUtils.UDF_EVAL_METHOD_NAME);
-    }
+  private static SqlReturnTypeInference getReturnTypeInference(final Class clazz) {
+    return FunctionUtil.getSqlReturnTypeInference(clazz, FunctionCallUtils.UDF_EVAL_METHOD_NAME);
+  }
 
-    public Class getImplementClass() {
-        return implementClass;
-    }
+  public Class getImplementClass() {
+    return implementClass;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowUserDefinedTableFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowUserDefinedTableFunction.java
index a002baa38..effd535a6 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowUserDefinedTableFunction.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GeaFlowUserDefinedTableFunction.java
@@ -22,11 +22,11 @@
 import static org.apache.geaflow.dsl.util.FunctionUtil.getSqlOperandTypeChecker;
 import static org.apache.geaflow.dsl.util.FunctionUtil.getSqlOperandTypeInference;
 
-import com.google.common.collect.Lists;
 import java.io.Serializable;
 import java.lang.reflect.Type;
 import java.util.ArrayList;
 import java.util.List;
+
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rel.type.RelDataTypeFactory;
 import org.apache.calcite.schema.FunctionParameter;
@@ -40,82 +40,94 @@
 import org.apache.geaflow.dsl.planner.GQLJavaTypeFactory;
 import org.apache.geaflow.dsl.util.SqlTypeUtil;
 
-public class GeaFlowUserDefinedTableFunction extends SqlUserDefinedTableFunction implements Serializable {
+import com.google.common.collect.Lists;
+
+public class GeaFlowUserDefinedTableFunction extends SqlUserDefinedTableFunction
+    implements Serializable {
+
+  private final Class implementClass;
+
+  private GeaFlowUserDefinedTableFunction(
+      String name,
+      Class clazz,
+      IRichTableFunction tableFunction,
+      GQLJavaTypeFactory typeFactory) {
+    super(
+        new SqlIdentifier(name, SqlParserPos.ZERO),
+        ReturnTypes.CURSOR,
+        getSqlOperandTypeInference(clazz, typeFactory),
+        getSqlOperandTypeChecker(name, clazz, typeFactory),
+        null,
+        tableFunction);
+    this.implementClass = clazz;
+  }
+
+  public static GeaFlowUserDefinedTableFunction create(
+      String name, Class clazz, GQLJavaTypeFactory typeFactory) {
+    try {
+      GeaflowTableFunction function = new GeaflowTableFunction(clazz, typeFactory);
+      return new GeaFlowUserDefinedTableFunction(name, clazz, function, typeFactory);
+    } catch (Exception e) {
+      throw new GeaFlowDSLException(e);
+    }
+  }
 
-    private final Class implementClass;
+  public Class getImplementClass() {
+    return implementClass;
+  }
+
+  public static class GeaflowTableFunction implements IRichTableFunction, Serializable {
+
+    private final GQLJavaTypeFactory typeFactory;
+    private UDTF functionInstance;
+
+    public GeaflowTableFunction(Class clazz, GQLJavaTypeFactory typeFactory) {
+      this.typeFactory = typeFactory;
+      try {
+        functionInstance = (UDTF) clazz.newInstance();
+      } catch (Exception e) {
+        throw new GeaFlowDSLException("Failed to create instance for " + clazz);
+      }
+    }
 
-    private GeaFlowUserDefinedTableFunction(String name, Class clazz, IRichTableFunction tableFunction,
-                                            GQLJavaTypeFactory typeFactory) {
-        super(new SqlIdentifier(name, SqlParserPos.ZERO),
-            ReturnTypes.CURSOR,
-            getSqlOperandTypeInference(clazz, typeFactory),
-            getSqlOperandTypeChecker(name, clazz, typeFactory),
-            null, tableFunction);
-        this.implementClass = clazz;
+    @Override
+    public List getParameters() {
+      return new ArrayList<>();
     }
 
-    public static GeaFlowUserDefinedTableFunction create(String name, Class clazz,
-                                                         GQLJavaTypeFactory typeFactory) {
-        try {
-            GeaflowTableFunction function = new GeaflowTableFunction(clazz, typeFactory);
-            return new GeaFlowUserDefinedTableFunction(name, clazz, function, typeFactory);
-        } catch (Exception e) {
-            throw new GeaFlowDSLException(e);
-        }
+    @Override
+    public RelDataType getRowType(RelDataTypeFactory typeFactory, List arguments) {
+      return null;
     }
 
-    public Class getImplementClass() {
-        return implementClass;
+    @Override
+    public Type getElementType(List arguments) {
+      return Object[].class;
     }
 
-    public static class GeaflowTableFunction implements IRichTableFunction, Serializable {
-
-        private final GQLJavaTypeFactory typeFactory;
-        private UDTF functionInstance;
-
-        public GeaflowTableFunction(Class clazz, GQLJavaTypeFactory typeFactory) {
-            this.typeFactory = typeFactory;
-            try {
-                functionInstance = (UDTF) clazz.newInstance();
-            } catch (Exception e) {
-                throw new GeaFlowDSLException("Failed to create instance for " + clazz);
-            }
-        }
-
-        @Override
-        public List getParameters() {
-            return new ArrayList<>();
-        }
-
-        @Override
-        public RelDataType getRowType(RelDataTypeFactory typeFactory, List arguments) {
-            return null;
-        }
-
-        @Override
-        public Type getElementType(List arguments) {
-            return Object[].class;
-        }
-
-
-        @Override
-        public RelDataType getRowType(RelDataTypeFactory typeFactory, List args, List paramTypes,
-                                      List outFieldNames) {
-
-            List fieldTypes = Lists.newArrayList();
-            List> callJavaTypes =
-                SqlTypeUtil.convertToJavaTypes(paramTypes, (GQLJavaTypeFactory) typeFactory);
-
-            List> returnTypes = functionInstance.getReturnType(callJavaTypes, outFieldNames);
-            for (int i = 0; i < returnTypes.size(); i++) {
-                fieldTypes.add(this.typeFactory.createType(returnTypes.get(i)));
-            }
-            if (outFieldNames.size() != fieldTypes.size()) {
-                throw new GeaFlowDSLException(String.format("Output fields size[%d] should equal to return type "
-                        + "size[%d] defined in class: %s", outFieldNames.size(), fieldTypes.size(),
-                    functionInstance.getClass().toString()));
-            }
-            return this.typeFactory.createStructType(fieldTypes, outFieldNames);
-        }
+    @Override
+    public RelDataType getRowType(
+        RelDataTypeFactory typeFactory,
+        List args,
+        List paramTypes,
+        List outFieldNames) {
+
+      List fieldTypes = Lists.newArrayList();
+      List> callJavaTypes =
+          SqlTypeUtil.convertToJavaTypes(paramTypes, (GQLJavaTypeFactory) typeFactory);
+
+      List> returnTypes = functionInstance.getReturnType(callJavaTypes, outFieldNames);
+      for (int i = 0; i < returnTypes.size(); i++) {
+        fieldTypes.add(this.typeFactory.createType(returnTypes.get(i)));
+      }
+      if (outFieldNames.size() != fieldTypes.size()) {
+        throw new GeaFlowDSLException(
+            String.format(
+                "Output fields size[%d] should equal to return type "
+                    + "size[%d] defined in class: %s",
+                outFieldNames.size(), fieldTypes.size(), functionInstance.getClass().toString()));
+      }
+      return this.typeFactory.createStructType(fieldTypes, outFieldNames);
     }
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GqlAvgAggFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GqlAvgAggFunction.java
index 28be7aead..3816057d2 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GqlAvgAggFunction.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GqlAvgAggFunction.java
@@ -27,15 +27,15 @@
 
 public class GqlAvgAggFunction extends SqlAvgAggFunction {
 
-    public GqlAvgAggFunction(SqlKind kind) {
-        super(kind);
-    }
+  public GqlAvgAggFunction(SqlKind kind) {
+    super(kind);
+  }
 
-    public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
-        RelDataType superType = super.inferReturnType(opBinding);
-        if (superType instanceof MetaFieldType) {
-            return ((MetaFieldType) superType).getType();
-        }
-        return superType;
+  public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
+    RelDataType superType = super.inferReturnType(opBinding);
+    if (superType instanceof MetaFieldType) {
+      return ((MetaFieldType) superType).getType();
     }
+    return superType;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GqlCountAggFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GqlCountAggFunction.java
index 753caf282..f9a5994e3 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GqlCountAggFunction.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GqlCountAggFunction.java
@@ -26,15 +26,15 @@
 
 public class GqlCountAggFunction extends SqlCountAggFunction {
 
-    public GqlCountAggFunction(String name) {
-        super(name);
-    }
+  public GqlCountAggFunction(String name) {
+    super(name);
+  }
 
-    public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
-        RelDataType superType = super.inferReturnType(opBinding);
-        if (superType instanceof MetaFieldType) {
-            return ((MetaFieldType) superType).getType();
-        }
-        return superType;
+  public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
+    RelDataType superType = super.inferReturnType(opBinding);
+    if (superType instanceof MetaFieldType) {
+      return ((MetaFieldType) superType).getType();
     }
+    return superType;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GqlMinMaxAggFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GqlMinMaxAggFunction.java
index c67bdd956..7a5c46909 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GqlMinMaxAggFunction.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GqlMinMaxAggFunction.java
@@ -27,15 +27,15 @@
 
 public class GqlMinMaxAggFunction extends SqlMinMaxAggFunction {
 
-    public GqlMinMaxAggFunction(SqlKind kind) {
-        super(kind);
-    }
+  public GqlMinMaxAggFunction(SqlKind kind) {
+    super(kind);
+  }
 
-    public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
-        RelDataType superType = super.inferReturnType(opBinding);
-        if (superType instanceof MetaFieldType) {
-            return ((MetaFieldType) superType).getType();
-        }
-        return superType;
+  public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
+    RelDataType superType = super.inferReturnType(opBinding);
+    if (superType instanceof MetaFieldType) {
+      return ((MetaFieldType) superType).getType();
     }
+    return superType;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GqlSumAggFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GqlSumAggFunction.java
index f3ec70e53..d8c7cd89b 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GqlSumAggFunction.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/GqlSumAggFunction.java
@@ -26,15 +26,15 @@
 
 public class GqlSumAggFunction extends SqlSumAggFunction {
 
-    public GqlSumAggFunction(RelDataType type) {
-        super(type instanceof MetaFieldType ? ((MetaFieldType) type).getType() : type);
-    }
+  public GqlSumAggFunction(RelDataType type) {
+    super(type instanceof MetaFieldType ? ((MetaFieldType) type).getType() : type);
+  }
 
-    public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
-        RelDataType superType = super.inferReturnType(opBinding);
-        if (superType instanceof MetaFieldType) {
-            return ((MetaFieldType) superType).getType();
-        }
-        return superType;
+  public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
+    RelDataType superType = super.inferReturnType(opBinding);
+    if (superType instanceof MetaFieldType) {
+      return ((MetaFieldType) superType).getType();
     }
+    return superType;
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/AllSourceShortestPath.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/AllSourceShortestPath.java
index 0441774a8..8b6890bb7 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/AllSourceShortestPath.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/AllSourceShortestPath.java
@@ -19,7 +19,6 @@
 
 package org.apache.geaflow.dsl.udf.graph;
 
-import com.alibaba.fastjson.JSONObject;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.Iterator;
@@ -27,6 +26,7 @@
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Optional;
+
 import org.apache.geaflow.common.type.primitive.StringType;
 import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext;
 import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction;
@@ -41,141 +41,143 @@
 import org.apache.geaflow.dsl.common.types.TableField;
 import org.apache.geaflow.model.graph.edge.EdgeDirection;
 
-@Description(name = "assp", description = "built-in udga All Source Shortest Path")
-public class AllSourceShortestPath implements AlgorithmUserFunction>,
-    IncrementalAlgorithmUserFunction {
+import com.alibaba.fastjson.JSONObject;
 
-    private AlgorithmRuntimeContext> context;
+@Description(name = "assp", description = "built-in udga All Source Shortest Path")
+public class AllSourceShortestPath
+    implements AlgorithmUserFunction>, IncrementalAlgorithmUserFunction {
 
-    private EdgeDirection edgeDirection;
+  private AlgorithmRuntimeContext> context;
 
-    @Override
-    public void init(AlgorithmRuntimeContext> context, Object[] parameters) {
-        this.context = context;
-        edgeDirection = EdgeDirection.OUT;
-    }
+  private EdgeDirection edgeDirection;
 
+  @Override
+  public void init(AlgorithmRuntimeContext> context, Object[] parameters) {
+    this.context = context;
+    edgeDirection = EdgeDirection.OUT;
+  }
 
-    private void sendPathMessages(List path, List edges) {
-        if (edges == null) {
-            return;
-        }
-        for (RowEdge edge : edges) {
-            if (edgeDirection == edge.getDirect()) {
-                sendPathMessage(edge.getTargetId(), path);
-            }
-        }
+  private void sendPathMessages(List path, List edges) {
+    if (edges == null) {
+      return;
     }
-
-    private void sendPathMessage(Object targetId, List path) {
-        this.context.sendMessage(targetId, path);
+    for (RowEdge edge : edges) {
+      if (edgeDirection == edge.getDirect()) {
+        sendPathMessage(edge.getTargetId(), path);
+      }
     }
-
-    @Override
-    public void process(RowVertex vertex, Optional updatedValues, Iterator> messages) {
-        if (context.getCurrentIterationId() == 1) {
-            List path = new ArrayList<>();
-            path.add(vertex.getId());
-            List edges = context.loadEdges(edgeDirection);
-            sendPathMessages(path, edges);
-
-            if (updatedValues.isPresent()) {
-                Map> pathMap = getPathMap(updatedValues);
-                // Send paths of the current vertex to neighborhood.
-                for (Entry> paths : pathMap.entrySet()) {
-                    sendPathMessages(paths.getValue(), edges);
-                }
-            } else {
-                // Init pathMap.
-                context.updateVertexValue(ObjectRow.create(new HashMap<>()));
-            }
-            return;
+  }
+
+  private void sendPathMessage(Object targetId, List path) {
+    this.context.sendMessage(targetId, path);
+  }
+
+  @Override
+  public void process(
+      RowVertex vertex, Optional updatedValues, Iterator> messages) {
+    if (context.getCurrentIterationId() == 1) {
+      List path = new ArrayList<>();
+      path.add(vertex.getId());
+      List edges = context.loadEdges(edgeDirection);
+      sendPathMessages(path, edges);
+
+      if (updatedValues.isPresent()) {
+        Map> pathMap = getPathMap(updatedValues);
+        // Send paths of the current vertex to neighborhood.
+        for (Entry> paths : pathMap.entrySet()) {
+          sendPathMessages(paths.getValue(), edges);
         }
+      } else {
+        // Init pathMap.
+        context.updateVertexValue(ObjectRow.create(new HashMap<>()));
+      }
+      return;
+    }
 
-        // Key is the id of src vertex.
-        Map> minPathMap = new HashMap<>();
-
-        while (messages.hasNext()) {
-            List msgPath = messages.next();
-
-            int msgDis = getPathDist(msgPath);
-            Object srcId = getSrcId(msgPath);
-            if (srcId.equals(vertex.getId())) {
-                continue;
-            }
-            if (!minPathMap.containsKey(srcId)) {
-                minPathMap.put(srcId, new ArrayList<>());
-            }
-
-            int dist = getPathDist(minPathMap.get(srcId));
-            // Find the min dist path.
-            if (msgDis < dist) {
-                minPathMap.put(srcId, msgPath);
-            }
-        }
+    // Key is the id of src vertex.
+    Map> minPathMap = new HashMap<>();
+
+    while (messages.hasNext()) {
+      List msgPath = messages.next();
+
+      int msgDis = getPathDist(msgPath);
+      Object srcId = getSrcId(msgPath);
+      if (srcId.equals(vertex.getId())) {
+        continue;
+      }
+      if (!minPathMap.containsKey(srcId)) {
+        minPathMap.put(srcId, new ArrayList<>());
+      }
+
+      int dist = getPathDist(minPathMap.get(srcId));
+      // Find the min dist path.
+      if (msgDis < dist) {
+        minPathMap.put(srcId, msgPath);
+      }
+    }
 
-        // Look up all src vertices.
-        for (Entry> entry : minPathMap.entrySet()) {
-            Object srcId = entry.getKey();
-            Map> pathMap = getPathMap(updatedValues);
+    // Look up all src vertices.
+    for (Entry> entry : minPathMap.entrySet()) {
+      Object srcId = entry.getKey();
+      Map> pathMap = getPathMap(updatedValues);
 
-            if (!pathMap.containsKey(srcId)) {
-                pathMap.put(srcId, new ArrayList<>());
-            }
+      if (!pathMap.containsKey(srcId)) {
+        pathMap.put(srcId, new ArrayList<>());
+      }
 
-            List path = pathMap.get(srcId);
-            int curVertexDis = getPathDist(path);
+      List path = pathMap.get(srcId);
+      int curVertexDis = getPathDist(path);
 
-            List minPath = entry.getValue();
-            int pathDist = getPathDist(minPath);
+      List minPath = entry.getValue();
+      int pathDist = getPathDist(minPath);
 
-            // Update if minDist is less than current.
-            if (pathDist < curVertexDis) {
+      // Update if minDist is less than current.
+      if (pathDist < curVertexDis) {
 
-                List newPath = new ArrayList<>(minPath);
-                // Add id of current vertex to the path.
-                newPath.add(vertex.getId());
+        List newPath = new ArrayList<>(minPath);
+        // Add id of current vertex to the path.
+        newPath.add(vertex.getId());
 
-                // Send path message to neighborhood.
-                sendPathMessages(newPath, context.loadEdges(edgeDirection));
+        // Send path message to neighborhood.
+        sendPathMessages(newPath, context.loadEdges(edgeDirection));
 
-                // update the pathMap of current vertex.
-                pathMap.put(srcId, newPath);
-                context.updateVertexValue(ObjectRow.create(pathMap));
-            }
-        }
+        // update the pathMap of current vertex.
+        pathMap.put(srcId, newPath);
+        context.updateVertexValue(ObjectRow.create(pathMap));
+      }
     }
+  }
 
-    Map> getPathMap(Optional value) {
-        return (Map>) value.get().getField(0, null);
-    }
+  Map> getPathMap(Optional value) {
+    return (Map>) value.get().getField(0, null);
+  }
 
-    int getPathDist(List path) {
-        if (path.isEmpty()) {
-            return Integer.MAX_VALUE;
-        }
-        return path.size() - 1;
+  int getPathDist(List path) {
+    if (path.isEmpty()) {
+      return Integer.MAX_VALUE;
     }
-
-    @Override
-    public void finish(RowVertex vertex, Optional newValue) {
-        JSONObject jsonObject = new JSONObject();
-        jsonObject.put("vertex", vertex.getId());
-        if (newValue.isPresent()) {
-            jsonObject.put("paths", getPathMap(newValue));
-        } else {
-            jsonObject.put("paths", "[]");
-        }
-
-        context.take(ObjectRow.create(jsonObject.toString()));
+    return path.size() - 1;
+  }
+
+  @Override
+  public void finish(RowVertex vertex, Optional newValue) {
+    JSONObject jsonObject = new JSONObject();
+    jsonObject.put("vertex", vertex.getId());
+    if (newValue.isPresent()) {
+      jsonObject.put("paths", getPathMap(newValue));
+    } else {
+      jsonObject.put("paths", "[]");
     }
 
-    private Object getSrcId(List path) {
-        return path.get(0);
-    }
+    context.take(ObjectRow.create(jsonObject.toString()));
+  }
 
-    @Override
-    public StructType getOutputType(GraphSchema graphSchema) {
-        return new StructType(new TableField("res", StringType.INSTANCE, false));
-    }
+  private Object getSrcId(List path) {
+    return path.get(0);
+  }
+
+  @Override
+  public StructType getOutputType(GraphSchema graphSchema) {
+    return new StructType(new TableField("res", StringType.INSTANCE, false));
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/ClosenessCentrality.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/ClosenessCentrality.java
index ffc6091f1..6fe86af92 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/ClosenessCentrality.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/ClosenessCentrality.java
@@ -22,6 +22,7 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.Optional;
+
 import org.apache.geaflow.common.type.primitive.DoubleType;
 import org.apache.geaflow.common.type.primitive.LongType;
 import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext;
@@ -40,77 +41,75 @@
 @Description(name = "closeness_centrality", description = "built-in udga for ClosenessCentrality")
 public class ClosenessCentrality implements AlgorithmUserFunction {
 
-    private AlgorithmRuntimeContext context;
-    private Object sourceId;
+  private AlgorithmRuntimeContext context;
+  private Object sourceId;
 
-    @Override
-    public void init(AlgorithmRuntimeContext context, Object[] params) {
-        this.context = context;
-        if (params.length != 1) {
-            throw new IllegalArgumentException("Only support one arguments, usage: func(sourceId)");
-        }
-        this.sourceId = TypeCastUtil.cast(params[0], context.getGraphSchema().getIdType());
+  @Override
+  public void init(AlgorithmRuntimeContext context, Object[] params) {
+    this.context = context;
+    if (params.length != 1) {
+      throw new IllegalArgumentException("Only support one arguments, usage: func(sourceId)");
     }
+    this.sourceId = TypeCastUtil.cast(params[0], context.getGraphSchema().getIdType());
+  }
 
-    @Override
-    public void process(RowVertex vertex, Optional updatedValues, Iterator messages) {
-        updatedValues.ifPresent(vertex::setValue);
-        List edges = context.loadEdges(EdgeDirection.OUT);
-        if (context.getCurrentIterationId() == 1L) {
-            context.sendMessage(vertex.getId(), 1L);
-            context.sendMessage(sourceId, 1L);
-        } else if (context.getCurrentIterationId() == 2L) {
-            context.updateVertexValue(ObjectRow.create(0L, 0L));
-            if (vertex.getId().equals(sourceId)) {
-                long vertexNum = -2L;
-                while (messages.hasNext()) {
-                    messages.next();
-                    vertexNum++;
-                }
-                context.updateVertexValue(ObjectRow.create(0L, vertexNum));
-                sendMessageToNeighbors(edges, 1L);
-            }
-        } else {
-            if (vertex.getId().equals(sourceId)) {
-                long sum = (long) vertex.getValue().getField(0, LongType.INSTANCE);
-                while (messages.hasNext()) {
-                    sum += messages.next();
-                }
-                long vertexNum = (long) vertex.getValue().getField(1, LongType.INSTANCE);
-                context.updateVertexValue(ObjectRow.create(sum, vertexNum));
-            } else {
-                if (((long) vertex.getValue().getField(1, LongType.INSTANCE)) < 1) {
-                    Long meg = messages.next();
-                    context.sendMessage(sourceId, meg);
-                    sendMessageToNeighbors(edges, meg + 1);
-                    context.updateVertexValue(ObjectRow.create(0L, 1L));
-                }
-            }
+  @Override
+  public void process(RowVertex vertex, Optional updatedValues, Iterator messages) {
+    updatedValues.ifPresent(vertex::setValue);
+    List edges = context.loadEdges(EdgeDirection.OUT);
+    if (context.getCurrentIterationId() == 1L) {
+      context.sendMessage(vertex.getId(), 1L);
+      context.sendMessage(sourceId, 1L);
+    } else if (context.getCurrentIterationId() == 2L) {
+      context.updateVertexValue(ObjectRow.create(0L, 0L));
+      if (vertex.getId().equals(sourceId)) {
+        long vertexNum = -2L;
+        while (messages.hasNext()) {
+          messages.next();
+          vertexNum++;
         }
-
-    }
-
-    @Override
-    public void finish(RowVertex graphVertex, Optional updatedValues) {
-        updatedValues.ifPresent(graphVertex::setValue);
-        if (graphVertex.getId().equals(sourceId)) {
-            long len = (long) graphVertex.getValue().getField(0, LongType.INSTANCE);
-            long num = (long) graphVertex.getValue().getField(1, LongType.INSTANCE);
-            context.take(ObjectRow.create(graphVertex.getId(), (double) num / len));
+        context.updateVertexValue(ObjectRow.create(0L, vertexNum));
+        sendMessageToNeighbors(edges, 1L);
+      }
+    } else {
+      if (vertex.getId().equals(sourceId)) {
+        long sum = (long) vertex.getValue().getField(0, LongType.INSTANCE);
+        while (messages.hasNext()) {
+          sum += messages.next();
+        }
+        long vertexNum = (long) vertex.getValue().getField(1, LongType.INSTANCE);
+        context.updateVertexValue(ObjectRow.create(sum, vertexNum));
+      } else {
+        if (((long) vertex.getValue().getField(1, LongType.INSTANCE)) < 1) {
+          Long meg = messages.next();
+          context.sendMessage(sourceId, meg);
+          sendMessageToNeighbors(edges, meg + 1);
+          context.updateVertexValue(ObjectRow.create(0L, 1L));
         }
+      }
     }
+  }
 
-    @Override
-    public StructType getOutputType(GraphSchema graphSchema) {
-        return new StructType(
-            new TableField("id", graphSchema.getIdType(), false),
-            new TableField("cc", DoubleType.INSTANCE, false)
-        );
+  @Override
+  public void finish(RowVertex graphVertex, Optional updatedValues) {
+    updatedValues.ifPresent(graphVertex::setValue);
+    if (graphVertex.getId().equals(sourceId)) {
+      long len = (long) graphVertex.getValue().getField(0, LongType.INSTANCE);
+      long num = (long) graphVertex.getValue().getField(1, LongType.INSTANCE);
+      context.take(ObjectRow.create(graphVertex.getId(), (double) num / len));
     }
+  }
 
-    private void sendMessageToNeighbors(List edges, Long message) {
-        for (RowEdge rowEdge : edges) {
-            context.sendMessage(rowEdge.getTargetId(), message);
-        }
+  @Override
+  public StructType getOutputType(GraphSchema graphSchema) {
+    return new StructType(
+        new TableField("id", graphSchema.getIdType(), false),
+        new TableField("cc", DoubleType.INSTANCE, false));
+  }
+
+  private void sendMessageToNeighbors(List edges, Long message) {
+    for (RowEdge rowEdge : edges) {
+      context.sendMessage(rowEdge.getTargetId(), message);
     }
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/ClusterCoefficient.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/ClusterCoefficient.java
index 266195467..aa7ce7b84 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/ClusterCoefficient.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/ClusterCoefficient.java
@@ -19,13 +19,12 @@
 
 package org.apache.geaflow.dsl.udf.graph;
 
-import com.google.common.collect.Lists;
-import com.google.common.collect.Sets;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
 import java.util.Set;
+
 import org.apache.geaflow.common.type.primitive.DoubleType;
 import org.apache.geaflow.common.type.primitive.IntegerType;
 import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext;
@@ -40,242 +39,232 @@
 import org.apache.geaflow.dsl.common.types.TableField;
 import org.apache.geaflow.model.graph.edge.EdgeDirection;
 
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+
 /**
  * ClusterCoefficient Algorithm Implementation.
- * 
- * 

The clustering coefficient of a node measures how close its neighbors are to being - * a complete graph (clique). It is calculated as the ratio of the number of edges between - * neighbors to the maximum possible number of edges between them. - * - *

Formula: C(v) = 2 * T(v) / (k(v) * (k(v) - 1)) - * where: - * - T(v) is the number of triangles through node v - * - k(v) is the degree of node v - * - *

The algorithm consists of 3 iteration phases: - * 1. First iteration: Each node sends its neighbor list to all neighbors - * 2. Second iteration: Each node receives neighbor lists and calculates connections - * 3. Third iteration: Output final clustering coefficient results - * - *

Supports parameters: - * - vertexType (optional): Filter nodes by vertex type - * - minDegree (optional): Minimum degree threshold (default: 2) + * + *

The clustering coefficient of a node measures how close its neighbors are to being a complete + * graph (clique). It is calculated as the ratio of the number of edges between neighbors to the + * maximum possible number of edges between them. + * + *

Formula: C(v) = 2 * T(v) / (k(v) * (k(v) - 1)) where: - T(v) is the number of triangles + * through node v - k(v) is the degree of node v + * + *

The algorithm consists of 3 iteration phases: 1. First iteration: Each node sends its neighbor + * list to all neighbors 2. Second iteration: Each node receives neighbor lists and calculates + * connections 3. Third iteration: Output final clustering coefficient results + * + *

Supports parameters: - vertexType (optional): Filter nodes by vertex type - minDegree + * (optional): Minimum degree threshold (default: 2) */ @Description(name = "cluster_coefficient", description = "built-in udga for Cluster Coefficient.") public class ClusterCoefficient implements AlgorithmUserFunction { - private AlgorithmRuntimeContext context; - - private static final int MAX_ITERATION = 3; - - // Parameters - private String vertexType = null; - private int minDegree = 2; - - // Exclude set for nodes that don't match the vertex type filter - private final Set excludeSet = Sets.newHashSet(); - - @Override - public void init(AlgorithmRuntimeContext context, Object[] params) { - this.context = context; - - // Validate parameter count - if (params.length > 2) { - throw new IllegalArgumentException( - "Maximum parameter limit exceeded. Expected: [vertexType], [minDegree]"); - } - - // Parse parameters based on type - // If first param is String, it's vertexType; if it's Integer/Long, it's minDegree - if (params.length >= 1 && params[0] != null) { - if (params[0] instanceof String) { - // First param is vertexType - vertexType = (String) params[0]; - - // Second param (if exists) is minDegree - if (params.length >= 2 && params[1] != null) { - if (!(params[1] instanceof Integer || params[1] instanceof Long)) { - throw new IllegalArgumentException( - "Minimum degree parameter should be integer."); - } - minDegree = params[1] instanceof Integer - ? (Integer) params[1] - : ((Long) params[1]).intValue(); - } - } else if (params[0] instanceof Integer || params[0] instanceof Long) { - // First param is minDegree (no vertexType filter) - vertexType = null; - minDegree = params[0] instanceof Integer - ? (Integer) params[0] - : ((Long) params[0]).intValue(); - } else { - throw new IllegalArgumentException( - "Parameter should be either string (vertexType) or integer (minDegree)."); - } - } + private AlgorithmRuntimeContext context; + + private static final int MAX_ITERATION = 3; + + // Parameters + private String vertexType = null; + private int minDegree = 2; + + // Exclude set for nodes that don't match the vertex type filter + private final Set excludeSet = Sets.newHashSet(); + + @Override + public void init(AlgorithmRuntimeContext context, Object[] params) { + this.context = context; + + // Validate parameter count + if (params.length > 2) { + throw new IllegalArgumentException( + "Maximum parameter limit exceeded. Expected: [vertexType], [minDegree]"); } - @Override - public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { - updatedValues.ifPresent(vertex::setValue); - - Object vertexId = vertex.getId(); - long currentIteration = context.getCurrentIterationId(); - - if (currentIteration == 1L) { - // First iteration: Check vertex type filter and send neighbor lists - if (Objects.nonNull(vertexType) && !vertexType.equals(vertex.getLabel())) { - excludeSet.add(vertexId); - // Send heartbeat to keep vertex alive - context.sendMessage(vertexId, ObjectRow.create(-1)); - return; - } - - // Load all neighbors (both directions for undirected graph) - List edges = context.loadEdges(EdgeDirection.BOTH); - - // Get unique neighbor IDs - Set neighborSet = Sets.newHashSet(); - for (RowEdge edge : edges) { - Object neighborId = edge.getTargetId(); - if (!excludeSet.contains(neighborId)) { - neighborSet.add(neighborId); - } - } - - int degree = neighborSet.size(); - - // For nodes with degree < minDegree, clustering coefficient is 0 - if (degree < minDegree) { - // Store degree and triangle count = 0 - context.updateVertexValue(ObjectRow.create(degree, 0)); - context.sendMessage(vertexId, ObjectRow.create(-1)); - return; - } - - // Build neighbor list message: [degree, neighbor1, neighbor2, ...] - List neighborInfo = Lists.newArrayList(); - neighborInfo.add(degree); - neighborInfo.addAll(neighborSet); - - ObjectRow neighborListMsg = ObjectRow.create(neighborInfo.toArray()); - - // Send neighbor list to all neighbors - for (Object neighborId : neighborSet) { - context.sendMessage(neighborId, neighborListMsg); - } - - // Store neighbor list in vertex value for next iteration - context.updateVertexValue(neighborListMsg); - - // Send heartbeat to self - context.sendMessage(vertexId, ObjectRow.create(-1)); - - } else if (currentIteration == 2L) { - // Second iteration: Calculate connections between neighbors - if (excludeSet.contains(vertexId)) { - context.sendMessage(vertexId, ObjectRow.create(-1)); - return; - } - - Row vertexValue = vertex.getValue(); - if (vertexValue == null) { - context.sendMessage(vertexId, ObjectRow.create(-1)); - return; - } - - int degree = (int) vertexValue.getField(0, IntegerType.INSTANCE); - - // For nodes with degree < minDegree, skip calculation - if (degree < minDegree) { - context.sendMessage(vertexId, ObjectRow.create(-1)); - return; - } - - // Get this vertex's neighbor set - Set myNeighbors = row2Set(vertexValue); - - // Count triangles by checking common neighbors - int triangleCount = 0; - while (messages.hasNext()) { - ObjectRow msg = messages.next(); - - // Skip heartbeat messages - int msgDegree = (int) msg.getField(0, IntegerType.INSTANCE); - if (msgDegree < 0) { - continue; - } - - // Get neighbor's neighbor set - Set neighborNeighbors = row2Set(msg); - - // Count common neighbors (forming triangles) - neighborNeighbors.retainAll(myNeighbors); - triangleCount += neighborNeighbors.size(); - } - - // Store degree and triangle count for final calculation - context.updateVertexValue(ObjectRow.create(degree, triangleCount)); - context.sendMessage(vertexId, ObjectRow.create(-1)); - - } else if (currentIteration == 3L) { - // Third iteration: Calculate and output clustering coefficient - if (excludeSet.contains(vertexId)) { - return; - } - - Row vertexValue = vertex.getValue(); - if (vertexValue == null) { - return; - } - - int degree = (int) vertexValue.getField(0, IntegerType.INSTANCE); - int triangleCount = (int) vertexValue.getField(1, IntegerType.INSTANCE); - - // Calculate clustering coefficient - double coefficient; - if (degree < minDegree) { - coefficient = 0.0; - } else { - // C(v) = 2 * T(v) / (k(v) * (k(v) - 1)) - // Note: triangleCount is already counting edges, so we divide by 2 - double actualTriangles = triangleCount / 2.0; - double maxPossibleEdges = degree * (degree - 1.0); - coefficient = maxPossibleEdges > 0 - ? (2.0 * actualTriangles) / maxPossibleEdges - : 0.0; - } - - context.take(ObjectRow.create(vertexId, coefficient)); + // Parse parameters based on type + // If first param is String, it's vertexType; if it's Integer/Long, it's minDegree + if (params.length >= 1 && params[0] != null) { + if (params[0] instanceof String) { + // First param is vertexType + vertexType = (String) params[0]; + + // Second param (if exists) is minDegree + if (params.length >= 2 && params[1] != null) { + if (!(params[1] instanceof Integer || params[1] instanceof Long)) { + throw new IllegalArgumentException("Minimum degree parameter should be integer."); + } + minDegree = + params[1] instanceof Integer ? (Integer) params[1] : ((Long) params[1]).intValue(); } + } else if (params[0] instanceof Integer || params[0] instanceof Long) { + // First param is minDegree (no vertexType filter) + vertexType = null; + minDegree = + params[0] instanceof Integer ? (Integer) params[0] : ((Long) params[0]).intValue(); + } else { + throw new IllegalArgumentException( + "Parameter should be either string (vertexType) or integer (minDegree)."); + } } + } - @Override - public void finish(RowVertex graphVertex, Optional updatedValues) { - // No action needed in finish - } + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + updatedValues.ifPresent(vertex::setValue); - @Override - public StructType getOutputType(GraphSchema graphSchema) { - return new StructType( - new TableField("vid", graphSchema.getIdType(), false), - new TableField("coefficient", DoubleType.INSTANCE, false) - ); - } + Object vertexId = vertex.getId(); + long currentIteration = context.getCurrentIterationId(); + + if (currentIteration == 1L) { + // First iteration: Check vertex type filter and send neighbor lists + if (Objects.nonNull(vertexType) && !vertexType.equals(vertex.getLabel())) { + excludeSet.add(vertexId); + // Send heartbeat to keep vertex alive + context.sendMessage(vertexId, ObjectRow.create(-1)); + return; + } + + // Load all neighbors (both directions for undirected graph) + List edges = context.loadEdges(EdgeDirection.BOTH); - /** - * Convert Row to Set of neighbor IDs. - * Row format: [degree, neighbor1, neighbor2, ...] - */ - private Set row2Set(Row row) { - int degree = (int) row.getField(0, IntegerType.INSTANCE); - Set neighborSet = Sets.newHashSet(); - for (int i = 1; i <= degree; i++) { - Object neighborId = row.getField(i, context.getGraphSchema().getIdType()); - if (!excludeSet.contains(neighborId)) { - neighborSet.add(neighborId); - } + // Get unique neighbor IDs + Set neighborSet = Sets.newHashSet(); + for (RowEdge edge : edges) { + Object neighborId = edge.getTargetId(); + if (!excludeSet.contains(neighborId)) { + neighborSet.add(neighborId); } - return neighborSet; + } + + int degree = neighborSet.size(); + + // For nodes with degree < minDegree, clustering coefficient is 0 + if (degree < minDegree) { + // Store degree and triangle count = 0 + context.updateVertexValue(ObjectRow.create(degree, 0)); + context.sendMessage(vertexId, ObjectRow.create(-1)); + return; + } + + // Build neighbor list message: [degree, neighbor1, neighbor2, ...] + List neighborInfo = Lists.newArrayList(); + neighborInfo.add(degree); + neighborInfo.addAll(neighborSet); + + ObjectRow neighborListMsg = ObjectRow.create(neighborInfo.toArray()); + + // Send neighbor list to all neighbors + for (Object neighborId : neighborSet) { + context.sendMessage(neighborId, neighborListMsg); + } + + // Store neighbor list in vertex value for next iteration + context.updateVertexValue(neighborListMsg); + + // Send heartbeat to self + context.sendMessage(vertexId, ObjectRow.create(-1)); + + } else if (currentIteration == 2L) { + // Second iteration: Calculate connections between neighbors + if (excludeSet.contains(vertexId)) { + context.sendMessage(vertexId, ObjectRow.create(-1)); + return; + } + + Row vertexValue = vertex.getValue(); + if (vertexValue == null) { + context.sendMessage(vertexId, ObjectRow.create(-1)); + return; + } + + int degree = (int) vertexValue.getField(0, IntegerType.INSTANCE); + + // For nodes with degree < minDegree, skip calculation + if (degree < minDegree) { + context.sendMessage(vertexId, ObjectRow.create(-1)); + return; + } + + // Get this vertex's neighbor set + Set myNeighbors = row2Set(vertexValue); + + // Count triangles by checking common neighbors + int triangleCount = 0; + while (messages.hasNext()) { + ObjectRow msg = messages.next(); + + // Skip heartbeat messages + int msgDegree = (int) msg.getField(0, IntegerType.INSTANCE); + if (msgDegree < 0) { + continue; + } + + // Get neighbor's neighbor set + Set neighborNeighbors = row2Set(msg); + + // Count common neighbors (forming triangles) + neighborNeighbors.retainAll(myNeighbors); + triangleCount += neighborNeighbors.size(); + } + + // Store degree and triangle count for final calculation + context.updateVertexValue(ObjectRow.create(degree, triangleCount)); + context.sendMessage(vertexId, ObjectRow.create(-1)); + + } else if (currentIteration == 3L) { + // Third iteration: Calculate and output clustering coefficient + if (excludeSet.contains(vertexId)) { + return; + } + + Row vertexValue = vertex.getValue(); + if (vertexValue == null) { + return; + } + + int degree = (int) vertexValue.getField(0, IntegerType.INSTANCE); + int triangleCount = (int) vertexValue.getField(1, IntegerType.INSTANCE); + + // Calculate clustering coefficient + double coefficient; + if (degree < minDegree) { + coefficient = 0.0; + } else { + // C(v) = 2 * T(v) / (k(v) * (k(v) - 1)) + // Note: triangleCount is already counting edges, so we divide by 2 + double actualTriangles = triangleCount / 2.0; + double maxPossibleEdges = degree * (degree - 1.0); + coefficient = maxPossibleEdges > 0 ? (2.0 * actualTriangles) / maxPossibleEdges : 0.0; + } + + context.take(ObjectRow.create(vertexId, coefficient)); + } + } + + @Override + public void finish(RowVertex graphVertex, Optional updatedValues) { + // No action needed in finish + } + + @Override + public StructType getOutputType(GraphSchema graphSchema) { + return new StructType( + new TableField("vid", graphSchema.getIdType(), false), + new TableField("coefficient", DoubleType.INSTANCE, false)); + } + + /** Convert Row to Set of neighbor IDs. Row format: [degree, neighbor1, neighbor2, ...] */ + private Set row2Set(Row row) { + int degree = (int) row.getField(0, IntegerType.INSTANCE); + Set neighborSet = Sets.newHashSet(); + for (int i = 1; i <= degree; i++) { + Object neighborId = row.getField(i, context.getGraphSchema().getIdType()); + if (!excludeSet.contains(neighborId)) { + neighborSet.add(neighborId); + } } + return neighborSet; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/CommonNeighbors.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/CommonNeighbors.java index 917fa0207..666ea6d02 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/CommonNeighbors.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/CommonNeighbors.java @@ -22,6 +22,7 @@ import java.util.Iterator; import java.util.List; import java.util.Optional; + import org.apache.geaflow.common.tuple.Tuple; import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction; @@ -39,64 +40,62 @@ @Description(name = "common_neighbors", description = "built-in udga for CommonNeighbors") public class CommonNeighbors implements AlgorithmUserFunction { - private AlgorithmRuntimeContext context; + private AlgorithmRuntimeContext context; - // tuple to store params - private Tuple vertices; + // tuple to store params + private Tuple vertices; - @Override - public void init(AlgorithmRuntimeContext context, Object[] params) { - this.context = context; + @Override + public void init(AlgorithmRuntimeContext context, Object[] params) { + this.context = context; - if (params.length != 2) { - throw new IllegalArgumentException("Only support two arguments, usage: common_neighbors(id_a, id_b)"); - } - this.vertices = new Tuple<>( - TypeCastUtil.cast(params[0], context.getGraphSchema().getIdType()), - TypeCastUtil.cast(params[1], context.getGraphSchema().getIdType()) - ); + if (params.length != 2) { + throw new IllegalArgumentException( + "Only support two arguments, usage: common_neighbors(id_a, id_b)"); } + this.vertices = + new Tuple<>( + TypeCastUtil.cast(params[0], context.getGraphSchema().getIdType()), + TypeCastUtil.cast(params[1], context.getGraphSchema().getIdType())); + } - @Override - public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { - if (context.getCurrentIterationId() == 1L) { - // send message to neighbors if they are vertices in params - if (vertices.f0.equals(vertex.getId()) || vertices.f1.equals(vertex.getId())) { - sendMessageToNeighbors(context.loadEdges(EdgeDirection.BOTH), vertex.getId()); - } - } else if (context.getCurrentIterationId() == 2L) { - // add to result if received messages from both vertices in params - Tuple received = new Tuple<>(false, false); - while (messages.hasNext()) { - Object message = messages.next(); - if (vertices.f0.equals(message)) { - received.setF0(true); - } - if (vertices.f1.equals(message)) { - received.setF1(true); - } + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + if (context.getCurrentIterationId() == 1L) { + // send message to neighbors if they are vertices in params + if (vertices.f0.equals(vertex.getId()) || vertices.f1.equals(vertex.getId())) { + sendMessageToNeighbors(context.loadEdges(EdgeDirection.BOTH), vertex.getId()); + } + } else if (context.getCurrentIterationId() == 2L) { + // add to result if received messages from both vertices in params + Tuple received = new Tuple<>(false, false); + while (messages.hasNext()) { + Object message = messages.next(); + if (vertices.f0.equals(message)) { + received.setF0(true); + } + if (vertices.f1.equals(message)) { + received.setF1(true); + } - if (received.getF0() && received.getF1()) { - context.take(ObjectRow.create(vertex.getId())); - } - } + if (received.getF0() && received.getF1()) { + context.take(ObjectRow.create(vertex.getId())); } + } } + } - @Override - public void finish(RowVertex graphVertex, Optional updatedValues) { - } + @Override + public void finish(RowVertex graphVertex, Optional updatedValues) {} - @Override - public StructType getOutputType(GraphSchema graphSchema) { - return new StructType( - new TableField("id", graphSchema.getIdType(), false) - ); - } + @Override + public StructType getOutputType(GraphSchema graphSchema) { + return new StructType(new TableField("id", graphSchema.getIdType(), false)); + } - private void sendMessageToNeighbors(List edges, Object message) { - for (RowEdge rowEdge : edges) { - context.sendMessage(rowEdge.getTargetId(), message); - } + private void sendMessageToNeighbors(List edges, Object message) { + for (RowEdge rowEdge : edges) { + context.sendMessage(rowEdge.getTargetId(), message); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/ConnectedComponents.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/ConnectedComponents.java index 5c0d8f95b..fbea59d45 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/ConnectedComponents.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/ConnectedComponents.java @@ -22,6 +22,7 @@ import java.util.Iterator; import java.util.Optional; import java.util.stream.Stream; + import org.apache.geaflow.common.type.primitive.StringType; import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction; @@ -38,70 +39,69 @@ @Description(name = "cc", description = "built-in udga for Connected Components Algorithm") public class ConnectedComponents implements AlgorithmUserFunction { - private AlgorithmRuntimeContext context; - private String outputKeyName = "component"; - private int iteration = 20; + private AlgorithmRuntimeContext context; + private String outputKeyName = "component"; + private int iteration = 20; - @Override - public void init(AlgorithmRuntimeContext context, Object[] parameters) { - this.context = context; - if (parameters.length > 2) { - throw new IllegalArgumentException( - "Only support zero or more arguments, false arguments " - + "usage: func([iteration, [outputKeyName]])"); - } - if (parameters.length > 0) { - iteration = Integer.parseInt(String.valueOf(parameters[0])); - } - if (parameters.length > 1) { - outputKeyName = String.valueOf(parameters[1]); - } + @Override + public void init(AlgorithmRuntimeContext context, Object[] parameters) { + this.context = context; + if (parameters.length > 2) { + throw new IllegalArgumentException( + "Only support zero or more arguments, false arguments " + + "usage: func([iteration, [outputKeyName]])"); } + if (parameters.length > 0) { + iteration = Integer.parseInt(String.valueOf(parameters[0])); + } + if (parameters.length > 1) { + outputKeyName = String.valueOf(parameters[1]); + } + } - @Override - public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { - updatedValues.ifPresent(vertex::setValue); - Stream stream = context.loadEdges(EdgeDirection.IN).stream(); - if (context.getCurrentIterationId() == 1L) { - String initValue = String.valueOf(vertex.getId()); - sendMessageToNeighbors(stream, initValue); - context.sendMessage(vertex.getId(), initValue); - context.updateVertexValue(ObjectRow.create(initValue)); - } else if (context.getCurrentIterationId() < iteration) { - String minComponent = null; - while (messages.hasNext()) { - String next = messages.next(); - if (minComponent == null || next.compareTo(minComponent) < 0) { - minComponent = next; - } - } - - String currentValue = (String) vertex.getValue().getField(0, StringType.INSTANCE); - // If found smaller component id, update and propagate - if (minComponent != null && minComponent.compareTo(currentValue) < 0) { - sendMessageToNeighbors(stream, minComponent); - context.sendMessage(vertex.getId(), minComponent); - context.updateVertexValue(ObjectRow.create(minComponent)); - } + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + updatedValues.ifPresent(vertex::setValue); + Stream stream = context.loadEdges(EdgeDirection.IN).stream(); + if (context.getCurrentIterationId() == 1L) { + String initValue = String.valueOf(vertex.getId()); + sendMessageToNeighbors(stream, initValue); + context.sendMessage(vertex.getId(), initValue); + context.updateVertexValue(ObjectRow.create(initValue)); + } else if (context.getCurrentIterationId() < iteration) { + String minComponent = null; + while (messages.hasNext()) { + String next = messages.next(); + if (minComponent == null || next.compareTo(minComponent) < 0) { + minComponent = next; } - } + } - @Override - public void finish(RowVertex graphVertex, Optional updatedValues) { - updatedValues.ifPresent(graphVertex::setValue); - String component = (String) graphVertex.getValue().getField(0, StringType.INSTANCE); - context.take(ObjectRow.create(graphVertex.getId(), component)); + String currentValue = (String) vertex.getValue().getField(0, StringType.INSTANCE); + // If found smaller component id, update and propagate + if (minComponent != null && minComponent.compareTo(currentValue) < 0) { + sendMessageToNeighbors(stream, minComponent); + context.sendMessage(vertex.getId(), minComponent); + context.updateVertexValue(ObjectRow.create(minComponent)); + } } + } - @Override - public StructType getOutputType(GraphSchema graphSchema) { - return new StructType( - new TableField("id", graphSchema.getIdType(), false), - new TableField(outputKeyName, StringType.INSTANCE, false) - ); - } + @Override + public void finish(RowVertex graphVertex, Optional updatedValues) { + updatedValues.ifPresent(graphVertex::setValue); + String component = (String) graphVertex.getValue().getField(0, StringType.INSTANCE); + context.take(ObjectRow.create(graphVertex.getId(), component)); + } - private void sendMessageToNeighbors(Stream edges, String message) { - edges.forEach(rowEdge -> context.sendMessage(rowEdge.getTargetId(), message)); - } + @Override + public StructType getOutputType(GraphSchema graphSchema) { + return new StructType( + new TableField("id", graphSchema.getIdType(), false), + new TableField(outputKeyName, StringType.INSTANCE, false)); + } + + private void sendMessageToNeighbors(Stream edges, String message) { + edges.forEach(rowEdge -> context.sendMessage(rowEdge.getTargetId(), message)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/IncKHopAlgorithm.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/IncKHopAlgorithm.java index c205350ad..32e5f8c1e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/IncKHopAlgorithm.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/IncKHopAlgorithm.java @@ -27,6 +27,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; + import org.apache.geaflow.common.config.ConfigHelper; import org.apache.geaflow.common.type.primitive.StringType; import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; @@ -46,393 +47,396 @@ @Description(name = "inc_khop", description = "built-in incur udf for KHop") public class IncKHopAlgorithm implements AlgorithmUserFunction { - private static final Logger LOGGER = LoggerFactory.getLogger(IncKHopAlgorithm.class); + private static final Logger LOGGER = LoggerFactory.getLogger(IncKHopAlgorithm.class); - private static final String SKIP_OUTPUT = "khop.skip.output"; + private static final String SKIP_OUTPUT = "khop.skip.output"; - private AlgorithmRuntimeContext context; - protected long maxIterNum; - private boolean skipOutput = false; - protected Set intMessageSet = new HashSet<>(); - protected Map> outEdgeMap = new HashMap<>(); - protected Map> inEdgeMap = new HashMap<>(); - private Map> stashInPathMessageMap = new HashMap<>(); - private Map> stashOutPathMessageMap = new HashMap<>(); + private AlgorithmRuntimeContext context; + protected long maxIterNum; + private boolean skipOutput = false; + protected Set intMessageSet = new HashSet<>(); + protected Map> outEdgeMap = new HashMap<>(); + protected Map> inEdgeMap = new HashMap<>(); + private Map> stashInPathMessageMap = new HashMap<>(); + private Map> stashOutPathMessageMap = new HashMap<>(); - @Override - public void init(AlgorithmRuntimeContext context, Object[] parameters) { - this.context = context; - this.maxIterNum = Integer.parseInt(String.valueOf(parameters[0])) + 2; - this.skipOutput = ConfigHelper.getBooleanOrDefault(context.getConfig().getConfigMap(), SKIP_OUTPUT, false); - } + @Override + public void init( + AlgorithmRuntimeContext context, Object[] parameters) { + this.context = context; + this.maxIterNum = Integer.parseInt(String.valueOf(parameters[0])) + 2; + this.skipOutput = + ConfigHelper.getBooleanOrDefault(context.getConfig().getConfigMap(), SKIP_OUTPUT, false); + } + + @Override + public void process( + RowVertex vertex, Optional updatedValues, Iterator messageIterator) { + long currentIterationId = context.getCurrentIterationId(); + if (currentIterationId == 1) { + List dynamicOutEdges = context.loadDynamicEdges(EdgeDirection.OUT); + if (dynamicOutEdges.isEmpty()) { + return; + } + IntTreePathMessage currentVPathMessage = new IntTreePathMessage(vertex.getId()); + sendMessage(dynamicOutEdges, currentVPathMessage); + sendMessage(vertex.getId()); + } else if (currentIterationId < maxIterNum) { + List staticOutEdges = getStaticOutEdges(vertex.getId()); + List staticInEdges = getStaticInEdges(vertex.getId()); + List tmpEdges = context.loadDynamicEdges(EdgeDirection.BOTH); + List dynamicInEdges = new ArrayList<>(); + List dynamicOutEdges = new ArrayList<>(); + if (tmpEdges != null) { + for (RowEdge edge : tmpEdges) { + if (edge.getDirect() == EdgeDirection.OUT) { + dynamicOutEdges.add(edge); + } else { + dynamicInEdges.add(edge); + } + } + } + + IntTreePathMessage sendOutPathMessage = null; + IntTreePathMessage sendInPathMessage = null; + List outMessages = new ArrayList<>(); + List inMessages = new ArrayList<>(); + while (messageIterator.hasNext()) { + IntTreePathMessage pathMsg = (IntTreePathMessage) messageIterator.next(); + if (pathMsg.getCurrentVertexId() == null) { + continue; + } + int depth = pathMsg.getPathLength(); + if (depth == currentIterationId - 1) { + // out dir traversal message + outMessages.add(pathMsg); + } else if (depth < currentIterationId - 1) { + // in dir traversal message + inMessages.add(pathMsg); + } + } + if (!outMessages.isEmpty()) { + IntTreePathMessage[] parent = new IntTreePathMessage[outMessages.size()]; + for (int i = 0; i < outMessages.size(); i++) { + parent[i] = outMessages.get(i); + } + sendOutPathMessage = new IntTreePathMessage(parent, vertex.getId()); + } + if (!inMessages.isEmpty()) { + if (currentIterationId == 2) { + throw new RuntimeException("iter 2 should not have in path message"); + } + IntTreePathMessage[] parent = new IntTreePathMessage[inMessages.size()]; + for (int i = 0; i < inMessages.size(); i++) { + parent[i] = inMessages.get(i); + } + sendInPathMessage = new IntTreePathMessage(parent, vertex.getId()); + } else { + if (currentIterationId == 2) { + sendInPathMessage = new IntTreePathMessage(null, vertex.getId()); + } + } + + if (currentIterationId < maxIterNum - 1) { + if (!stashOutPathMessageMap.containsKey(vertex.getId())) { + stashOutPathMessageMap.put(vertex.getId(), new ArrayList<>()); + } + List stashOutPathMessages = stashOutPathMessageMap.get(vertex.getId()); + if (sendOutPathMessage != null) { + sendMessage(staticOutEdges, sendOutPathMessage); + sendMessage(dynamicOutEdges, sendOutPathMessage); + // 合并消息树开启时,只在无消息发出时,保存当前路径 + if (staticOutEdges.isEmpty() && dynamicOutEdges.isEmpty()) { + stashOutPathMessages.add(sendOutPathMessage); + } + } - @Override - public void process(RowVertex vertex, Optional updatedValues, Iterator messageIterator) { - long currentIterationId = context.getCurrentIterationId(); - if (currentIterationId == 1) { - List dynamicOutEdges = context.loadDynamicEdges(EdgeDirection.OUT); - if (dynamicOutEdges.isEmpty()) { - return; + if (!stashInPathMessageMap.containsKey(vertex.getId())) { + stashInPathMessageMap.put(vertex.getId(), new ArrayList<>()); + } + List stashInPathMessages = stashInPathMessageMap.get(vertex.getId()); + if (sendInPathMessage != null) { + sendMessage(staticInEdges, sendInPathMessage); + sendMessage(dynamicInEdges, sendInPathMessage); + // When merge message tree is enabled, only save current path when no message is sent + if (staticInEdges.isEmpty() && dynamicInEdges.isEmpty()) { + stashInPathMessages.add(sendInPathMessage); + } + } + // Activate self + sendMessage(vertex.getId()); + } else if (currentIterationId == maxIterNum - 1) { + // tree path is reversed. + Map head2OutMixMsg = new HashMap<>(); + Map head2InMixMsg = new HashMap<>(); + if (sendOutPathMessage != null) { + Map> pathMap = sendOutPathMessage.generatePathMap(); + Iterator keys = pathMap.keySet().iterator(); + while (keys.hasNext()) { + Object head = keys.next(); + IntTreePathMessage outMsg = head2OutMixMsg.get(head); + if (outMsg == null) { + outMsg = new IntTreePathMessage.IntTreePathMessageWrapper(null, 1); + head2OutMixMsg.put(head, outMsg); } - IntTreePathMessage currentVPathMessage = new IntTreePathMessage(vertex.getId()); - sendMessage(dynamicOutEdges, currentVPathMessage); - sendMessage(vertex.getId()); - } else if (currentIterationId < maxIterNum) { - List staticOutEdges = getStaticOutEdges(vertex.getId()); - List staticInEdges = getStaticInEdges(vertex.getId()); - List tmpEdges = context.loadDynamicEdges(EdgeDirection.BOTH); - List dynamicInEdges = new ArrayList<>(); - List dynamicOutEdges = new ArrayList<>(); - if (tmpEdges != null) { - for (RowEdge edge : tmpEdges) { - if (edge.getDirect() == EdgeDirection.OUT) { - dynamicOutEdges.add(edge); - } else { - dynamicInEdges.add(edge); - } - } + Iterator itr = pathMap.get(head).iterator(); + while (itr.hasNext()) { + Object[] tmp = itr.next(); + outMsg.addPath(tmp); } + } + } - IntTreePathMessage sendOutPathMessage = null; - IntTreePathMessage sendInPathMessage = null; - List outMessages = new ArrayList<>(); - List inMessages = new ArrayList<>(); - while (messageIterator.hasNext()) { - IntTreePathMessage pathMsg = (IntTreePathMessage) messageIterator.next(); - if (pathMsg.getCurrentVertexId() == null) { - continue; - } - int depth = pathMsg.getPathLength(); - if (depth == currentIterationId - 1) { - // out dir traversal message - outMessages.add(pathMsg); - } else if (depth < currentIterationId - 1) { - // in dir traversal message - inMessages.add(pathMsg); - } + if (stashOutPathMessageMap.containsKey(vertex.getId())) { + List stashOutPathMessages = + stashOutPathMessageMap.get(vertex.getId()); + for (IntTreePathMessage pathMessage : stashOutPathMessages) { + Map> pathMap = pathMessage.generatePathMap(); + Iterator keys = pathMap.keySet().iterator(); + while (keys.hasNext()) { + Object head = keys.next(); + IntTreePathMessage outMsg = head2OutMixMsg.get(head); + if (outMsg == null) { + outMsg = new IntTreePathMessage.IntTreePathMessageWrapper(null, 1); + head2OutMixMsg.put(head, outMsg); + } + Iterator itr = pathMap.get(head).iterator(); + while (itr.hasNext()) { + Object[] tmp = itr.next(); + outMsg.addPath(tmp); + } } - if (!outMessages.isEmpty()) { - IntTreePathMessage[] parent = new IntTreePathMessage[outMessages.size()]; - for (int i = 0; i < outMessages.size(); i++) { - parent[i] = outMessages.get(i); - } - sendOutPathMessage = new IntTreePathMessage(parent, vertex.getId()); + } + } + + if (sendInPathMessage != null) { + Map> pathMap = sendInPathMessage.generatePathMap(); + Iterator keys = pathMap.keySet().iterator(); + while (keys.hasNext()) { + Object head = keys.next(); + IntTreePathMessage inMsg = head2InMixMsg.get(head); + if (inMsg == null) { + inMsg = new IntTreePathMessage.IntTreePathMessageWrapper(null, 0); + head2InMixMsg.put(head, inMsg); } - if (!inMessages.isEmpty()) { - if (currentIterationId == 2) { - throw new RuntimeException("iter 2 should not have in path message"); - } - IntTreePathMessage[] parent = new IntTreePathMessage[inMessages.size()]; - for (int i = 0; i < inMessages.size(); i++) { - parent[i] = inMessages.get(i); - } - sendInPathMessage = new IntTreePathMessage(parent, vertex.getId()); - } else { - if (currentIterationId == 2) { - sendInPathMessage = new IntTreePathMessage(null, vertex.getId()); - } + Iterator itr = pathMap.get(head).iterator(); + while (itr.hasNext()) { + Object[] tmp = itr.next(); + inMsg.addPath(tmp); } + } + } - if (currentIterationId < maxIterNum - 1) { - if (!stashOutPathMessageMap.containsKey(vertex.getId())) { - stashOutPathMessageMap.put(vertex.getId(), new ArrayList<>()); - } - List stashOutPathMessages = stashOutPathMessageMap.get(vertex.getId()); - if (sendOutPathMessage != null) { - sendMessage(staticOutEdges, sendOutPathMessage); - sendMessage(dynamicOutEdges, sendOutPathMessage); - //合并消息树开启时,只在无消息发出时,保存当前路径 - if (staticOutEdges.isEmpty() && dynamicOutEdges.isEmpty()) { - stashOutPathMessages.add(sendOutPathMessage); - } - } - - if (!stashInPathMessageMap.containsKey(vertex.getId())) { - stashInPathMessageMap.put(vertex.getId(), new ArrayList<>()); - } - List stashInPathMessages = stashInPathMessageMap.get(vertex.getId()); - if (sendInPathMessage != null) { - sendMessage(staticInEdges, sendInPathMessage); - sendMessage(dynamicInEdges, sendInPathMessage); - //When merge message tree is enabled, only save current path when no message is sent - if (staticInEdges.isEmpty() && dynamicInEdges.isEmpty()) { - stashInPathMessages.add(sendInPathMessage); - } - } - // Activate self - sendMessage(vertex.getId()); - } else if (currentIterationId == maxIterNum - 1) { - // tree path is reversed. - Map head2OutMixMsg = new HashMap<>(); - Map head2InMixMsg = new HashMap<>(); - if (sendOutPathMessage != null) { - Map> pathMap = sendOutPathMessage.generatePathMap(); - Iterator keys = pathMap.keySet().iterator(); - while (keys.hasNext()) { - Object head = keys.next(); - IntTreePathMessage outMsg = head2OutMixMsg.get(head); - if (outMsg == null) { - outMsg = new IntTreePathMessage.IntTreePathMessageWrapper(null, 1); - head2OutMixMsg.put(head, outMsg); - } - Iterator itr = pathMap.get(head).iterator(); - while (itr.hasNext()) { - Object[] tmp = itr.next(); - outMsg.addPath(tmp); - } - } - } - - if (stashOutPathMessageMap.containsKey(vertex.getId())) { - List stashOutPathMessages = stashOutPathMessageMap.get(vertex.getId()); - for (IntTreePathMessage pathMessage : stashOutPathMessages) { - Map> pathMap = pathMessage.generatePathMap(); - Iterator keys = pathMap.keySet().iterator(); - while (keys.hasNext()) { - Object head = keys.next(); - IntTreePathMessage outMsg = head2OutMixMsg.get(head); - if (outMsg == null) { - outMsg = new IntTreePathMessage.IntTreePathMessageWrapper(null, 1); - head2OutMixMsg.put(head, outMsg); - } - Iterator itr = pathMap.get(head).iterator(); - while (itr.hasNext()) { - Object[] tmp = itr.next(); - outMsg.addPath(tmp); - } - } - } - } - - if (sendInPathMessage != null) { - Map> pathMap = sendInPathMessage.generatePathMap(); - Iterator keys = pathMap.keySet().iterator(); - while (keys.hasNext()) { - Object head = keys.next(); - IntTreePathMessage inMsg = head2InMixMsg.get(head); - if (inMsg == null) { - inMsg = new IntTreePathMessage.IntTreePathMessageWrapper(null, 0); - head2InMixMsg.put(head, inMsg); - } - Iterator itr = pathMap.get(head).iterator(); - while (itr.hasNext()) { - Object[] tmp = itr.next(); - inMsg.addPath(tmp); - } - } - } - - if (stashInPathMessageMap.containsKey(vertex.getId())) { - List stashInPathMessages = stashInPathMessageMap.get(vertex.getId()); - for (IntTreePathMessage pathMessage : stashInPathMessages) { - Map> pathMap = pathMessage.generatePathMap(); - Iterator keys = pathMap.keySet().iterator(); - while (keys.hasNext()) { - Object head = keys.next(); - IntTreePathMessage inMsg = head2InMixMsg.get(head); - if (inMsg == null) { - inMsg = new IntTreePathMessage.IntTreePathMessageWrapper(null, 0); - head2InMixMsg.put(head, inMsg); - } - Iterator itr = pathMap.get(head).iterator(); - while (itr.hasNext()) { - Object[] tmp = itr.next(); - inMsg.addPath(tmp); - } - } - } - } - - Iterator keys = head2OutMixMsg.keySet().iterator(); - while (keys.hasNext()) { - Object head = keys.next(); - context.sendMessage(head, head2OutMixMsg.get(head)); - } - keys = head2InMixMsg.keySet().iterator(); - while (keys.hasNext()) { - Object head = keys.next(); - context.sendMessage(head, head2InMixMsg.get(head)); - } + if (stashInPathMessageMap.containsKey(vertex.getId())) { + List stashInPathMessages = stashInPathMessageMap.get(vertex.getId()); + for (IntTreePathMessage pathMessage : stashInPathMessages) { + Map> pathMap = pathMessage.generatePathMap(); + Iterator keys = pathMap.keySet().iterator(); + while (keys.hasNext()) { + Object head = keys.next(); + IntTreePathMessage inMsg = head2InMixMsg.get(head); + if (inMsg == null) { + inMsg = new IntTreePathMessage.IntTreePathMessageWrapper(null, 0); + head2InMixMsg.put(head, inMsg); + } + Iterator itr = pathMap.get(head).iterator(); + while (itr.hasNext()) { + Object[] tmp = itr.next(); + inMsg.addPath(tmp); + } } + } + } + + Iterator keys = head2OutMixMsg.keySet().iterator(); + while (keys.hasNext()) { + Object head = keys.next(); + context.sendMessage(head, head2OutMixMsg.get(head)); + } + keys = head2InMixMsg.keySet().iterator(); + while (keys.hasNext()) { + Object head = keys.next(); + context.sendMessage(head, head2InMixMsg.get(head)); + } + } + } else { + IntTreePathMessage sendOutPathMessage = new IntTreePathMessage(); + IntTreePathMessage sendInPathMessage = new IntTreePathMessage(); + while (messageIterator.hasNext()) { + IntTreePathMessage.IntTreePathMessageWrapper pathMsg = + (IntTreePathMessage.IntTreePathMessageWrapper) messageIterator.next(); + if (pathMsg.getTag() == 0) { + // in path + sendInPathMessage.merge(pathMsg); } else { - IntTreePathMessage sendOutPathMessage = new IntTreePathMessage(); - IntTreePathMessage sendInPathMessage = new IntTreePathMessage(); - while (messageIterator.hasNext()) { - IntTreePathMessage.IntTreePathMessageWrapper pathMsg = - (IntTreePathMessage.IntTreePathMessageWrapper) messageIterator.next(); - if (pathMsg.getTag() == 0) { - // in path - sendInPathMessage.merge(pathMsg); - } else { - sendOutPathMessage.merge(pathMsg); - } - } - if (!skipOutput) { - constructResult(sendOutPathMessage, sendInPathMessage, (int) maxIterNum - 1); - } + sendOutPathMessage.merge(pathMsg); } + } + if (!skipOutput) { + constructResult(sendOutPathMessage, sendInPathMessage, (int) maxIterNum - 1); + } } + } - @Override - public StructType getOutputType(GraphSchema graphSchema) { - return new StructType( - new TableField("ret", StringType.INSTANCE, false) - ); - } + @Override + public StructType getOutputType(GraphSchema graphSchema) { + return new StructType(new TableField("ret", StringType.INSTANCE, false)); + } - @Override - public void finishIteration(long iterationId) { - this.intMessageSet.clear(); - } + @Override + public void finishIteration(long iterationId) { + this.intMessageSet.clear(); + } - @Override - public void finish(RowVertex graphVertex, Optional updatedValues) { - } + @Override + public void finish(RowVertex graphVertex, Optional updatedValues) {} - @Override - public void finish() { - this.intMessageSet.clear(); - this.outEdgeMap.clear(); - this.inEdgeMap.clear(); - this.stashInPathMessageMap.clear(); - this.stashOutPathMessageMap.clear(); - } + @Override + public void finish() { + this.intMessageSet.clear(); + this.outEdgeMap.clear(); + this.inEdgeMap.clear(); + this.stashInPathMessageMap.clear(); + this.stashOutPathMessageMap.clear(); + } - protected void sendMessage(List edges) { - for (RowEdge edge : edges) { - if (!intMessageSet.contains(edge.getTargetId())) { - intMessageSet.add(edge.getTargetId()); - context.sendMessage(edge.getTargetId(), new IntTreePathMessage()); - } - } + protected void sendMessage(List edges) { + for (RowEdge edge : edges) { + if (!intMessageSet.contains(edge.getTargetId())) { + intMessageSet.add(edge.getTargetId()); + context.sendMessage(edge.getTargetId(), new IntTreePathMessage()); + } } + } - protected void sendMessage(Object targetId) { - if (!intMessageSet.contains(targetId)) { - intMessageSet.add(targetId); - context.sendMessage(targetId, new IntTreePathMessage()); - } + protected void sendMessage(Object targetId) { + if (!intMessageSet.contains(targetId)) { + intMessageSet.add(targetId); + context.sendMessage(targetId, new IntTreePathMessage()); } + } - protected void sendMessage(List edges, IntTreePathMessage pathMessage) { - for (RowEdge edge : edges) { - context.sendMessage(edge.getTargetId(), pathMessage); - } + protected void sendMessage(List edges, IntTreePathMessage pathMessage) { + for (RowEdge edge : edges) { + context.sendMessage(edge.getTargetId(), pathMessage); } + } - protected List getStaticInEdges(Object vid) { - List staticInEdges = inEdgeMap.get(vid); - if (staticInEdges == null) { - staticInEdges = this.context.loadStaticEdges(EdgeDirection.IN); - inEdgeMap.put(vid, staticInEdges); - return staticInEdges; - } else { - return staticInEdges; - } + protected List getStaticInEdges(Object vid) { + List staticInEdges = inEdgeMap.get(vid); + if (staticInEdges == null) { + staticInEdges = this.context.loadStaticEdges(EdgeDirection.IN); + inEdgeMap.put(vid, staticInEdges); + return staticInEdges; + } else { + return staticInEdges; } + } - protected List getStaticOutEdges(Object vid) { - List staticOutEdges = outEdgeMap.get(vid); - if (staticOutEdges == null) { - staticOutEdges = this.context.loadStaticEdges(EdgeDirection.OUT); - outEdgeMap.put(vid, staticOutEdges); - return staticOutEdges; - } else { - return staticOutEdges; - } + protected List getStaticOutEdges(Object vid) { + List staticOutEdges = outEdgeMap.get(vid); + if (staticOutEdges == null) { + staticOutEdges = this.context.loadStaticEdges(EdgeDirection.OUT); + outEdgeMap.put(vid, staticOutEdges); + return staticOutEdges; + } else { + return staticOutEdges; } + } - private void constructResult(IntTreePathMessage outMessage, IntTreePathMessage inMessage, int expectedLength) { + private void constructResult( + IntTreePathMessage outMessage, IntTreePathMessage inMessage, int expectedLength) { - // this path is begin with start central vertex - Iterator outPathIterator = outMessage.getPaths(); - Map> outPathMap = new HashMap<>(); - while (outPathIterator.hasNext()) { - Object[] currentPath = outPathIterator.next(); - if (!outPathMap.containsKey(currentPath.length)) { - outPathMap.put(currentPath.length, new ArrayList<>()); - } - outPathMap.get(currentPath.length).add(currentPath); + // this path is begin with start central vertex + Iterator outPathIterator = outMessage.getPaths(); + Map> outPathMap = new HashMap<>(); + while (outPathIterator.hasNext()) { + Object[] currentPath = outPathIterator.next(); + if (!outPathMap.containsKey(currentPath.length)) { + outPathMap.put(currentPath.length, new ArrayList<>()); + } + outPathMap.get(currentPath.length).add(currentPath); + } + // this iterator is end with central vertex + Iterator inPathIterator = inMessage.getPaths(); + Map> inPathMap = new HashMap<>(); + while (inPathIterator.hasNext()) { + Object[] currentPath = inPathIterator.next(); + if (!inPathMap.containsKey(currentPath.length)) { + inPathMap.put(currentPath.length, new ArrayList<>()); + } + inPathMap.get(currentPath.length).add(currentPath); + } + for (int outLength = expectedLength; outLength > 0; outLength--) { + if (outLength == expectedLength) { + List outPaths = outPathMap.get(outLength); + if (outPaths == null) { + continue; } - // this iterator is end with central vertex - Iterator inPathIterator = inMessage.getPaths(); - Map> inPathMap = new HashMap<>(); - while (inPathIterator.hasNext()) { - Object[] currentPath = inPathIterator.next(); - if (!inPathMap.containsKey(currentPath.length)) { - inPathMap.put(currentPath.length, new ArrayList<>()); - } - inPathMap.get(currentPath.length).add(currentPath); + for (Object[] outPath : outPaths) { + String pathStr = convertToPath(outPath, outLength, null, 0); + if (!skipOutput) { + context.take(ObjectRow.create(pathStr)); + } } - for (int outLength = expectedLength; outLength > 0; outLength--) { - if (outLength == expectedLength) { - List outPaths = outPathMap.get(outLength); - if (outPaths == null) { - continue; - } - for (Object[] outPath : outPaths) { - String pathStr = convertToPath(outPath, outLength, null, 0); - if (!skipOutput) { - context.take(ObjectRow.create(pathStr)); - } - } - } else { - int inLength = expectedLength - outLength + 1; - Set inPaths = getPaths(inPathMap, inLength, expectedLength, 0); - Set outPaths = getPaths(outPathMap, outLength, expectedLength, 1); - if (!outPaths.isEmpty() && !inPaths.isEmpty()) { - for (String outPath : outPaths) { - for (String inPath : inPaths) { - String pathStr = inPath + outPath; - if (!skipOutput) { - context.take(ObjectRow.create(pathStr)); - } - } - } - } + } else { + int inLength = expectedLength - outLength + 1; + Set inPaths = getPaths(inPathMap, inLength, expectedLength, 0); + Set outPaths = getPaths(outPathMap, outLength, expectedLength, 1); + if (!outPaths.isEmpty() && !inPaths.isEmpty()) { + for (String outPath : outPaths) { + for (String inPath : inPaths) { + String pathStr = inPath + outPath; + if (!skipOutput) { + context.take(ObjectRow.create(pathStr)); + } } + } } + } } + } - public static Set getPaths(Map> pathMap, int expectLength, int maxLength, int idr) { - Set paths = new HashSet<>(); - for (int length = expectLength; length <= maxLength; length++) { - List currentPaths = pathMap.get(length); - if (currentPaths == null) { - continue; - } - for (Object[] path : currentPaths) { - StringBuilder sb = new StringBuilder(); - if (idr == 0) { - int offset = path.length - expectLength; - for (int i = 0; i < expectLength - 1; i++) { - sb.append(path[offset + i]).append(","); - } - paths.add(sb.toString()); - } else { - int arrayLength = path.length; - for (int i = 0; i < expectLength; i++) { - sb.append(path[arrayLength - i - 1]).append(","); - } - paths.add(sb.toString()); - } - } + public static Set getPaths( + Map> pathMap, int expectLength, int maxLength, int idr) { + Set paths = new HashSet<>(); + for (int length = expectLength; length <= maxLength; length++) { + List currentPaths = pathMap.get(length); + if (currentPaths == null) { + continue; + } + for (Object[] path : currentPaths) { + StringBuilder sb = new StringBuilder(); + if (idr == 0) { + int offset = path.length - expectLength; + for (int i = 0; i < expectLength - 1; i++) { + sb.append(path[offset + i]).append(","); + } + paths.add(sb.toString()); + } else { + int arrayLength = path.length; + for (int i = 0; i < expectLength; i++) { + sb.append(path[arrayLength - i - 1]).append(","); + } + paths.add(sb.toString()); } - return paths; + } } + return paths; + } - public static String convertToPath(Object[] outPath, int outLength, int[] inPath, int inLength) { - StringBuilder sb = new StringBuilder(); - if (inLength > 1) { - int offset = inPath.length - inLength; - for (int i = 0; i < inLength - 1; i++) { - sb.append(inPath[offset + i]).append(","); - } - } - int arrayLength = outPath.length; - for (int i = 0; i < outLength; i++) { - sb.append(outPath[arrayLength - i - 1]).append(","); - } - return sb.toString(); + public static String convertToPath(Object[] outPath, int outLength, int[] inPath, int inLength) { + StringBuilder sb = new StringBuilder(); + if (inLength > 1) { + int offset = inPath.length - inLength; + for (int i = 0; i < inLength - 1; i++) { + sb.append(inPath[offset + i]).append(","); + } + } + int arrayLength = outPath.length; + for (int i = 0; i < outLength; i++) { + sb.append(outPath[arrayLength - i - 1]).append(","); } + return sb.toString(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/IncMinimumSpanningTree.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/IncMinimumSpanningTree.java index 80b059af5..410217e77 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/IncMinimumSpanningTree.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/IncMinimumSpanningTree.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Objects; import java.util.Optional; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.primitive.DoubleType; import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; @@ -46,442 +47,441 @@ import org.slf4j.LoggerFactory; /** - * Incremental Minimum Spanning Tree algorithm implementation. - * Based on Geaflow incremental graph computing capabilities, implements MST maintenance on dynamic graphs. + * Incremental Minimum Spanning Tree algorithm implementation. Based on Geaflow incremental graph + * computing capabilities, implements MST maintenance on dynamic graphs. * - *

Algorithm principle: - * 1. Maintain current MST state - * 2. For new edges: Use Union-Find to detect if cycles are formed, if no cycle and weight is smaller then add to MST - * 3. For deleted edges: If deleted edge is MST edge, need to reconnect separated components - * 4. Use vertex-centric message passing mechanism for distributed computing + *

Algorithm principle: 1. Maintain current MST state 2. For new edges: Use Union-Find to detect + * if cycles are formed, if no cycle and weight is smaller then add to MST 3. For deleted edges: If + * deleted edge is MST edge, need to reconnect separated components 4. Use vertex-centric message + * passing mechanism for distributed computing * * @author Geaflow Team */ @Description(name = "IncMST", description = "built-in udga for Incremental Minimum Spanning Tree") -public class IncMinimumSpanningTree implements AlgorithmUserFunction, - IncrementalAlgorithmUserFunction { - - private static final Logger LOGGER = LoggerFactory.getLogger(IncMinimumSpanningTree.class); - - /** Field index for vertex state in row value. */ - private static final int STATE_FIELD_INDEX = 0; - - private AlgorithmRuntimeContext context; - private IType idType; // Cache the ID type for better performance - - // Configuration parameters - private int maxIterations = 50; // Default maximum iterations - private double convergenceThreshold = 0.001; // Default convergence threshold - private String keyFieldName = "mst_edges"; // Default key field name - - // Memory optimization parameters - private static final int MEMORY_COMPACT_INTERVAL = 10; // Compact memory every 10 iterations - private int iterationCount = 0; - - @Override - public void init(AlgorithmRuntimeContext context, Object[] parameters) { - this.context = context; - - // Cache the ID type for better performance and type safety - this.idType = context.getGraphSchema().getIdType(); - - // Parse configuration parameters - if (parameters != null && parameters.length > 0) { - if (parameters.length > 3) { - throw new IllegalArgumentException( - "IncMinimumSpanningTree algorithm supports at most 3 parameters: " - + "maxIterations, convergenceThreshold, keyFieldName"); - } - - // Parse maxIterations (first parameter) - if (parameters.length > 0 && parameters[0] != null) { - try { - this.maxIterations = Integer.parseInt(String.valueOf(parameters[0])); - if (this.maxIterations <= 0) { - throw new IllegalArgumentException("maxIterations must be positive"); - } - } catch (NumberFormatException e) { - throw new IllegalArgumentException( - "Invalid maxIterations parameter: " + parameters[0], e); - } - } - - // Parse convergenceThreshold (second parameter) - if (parameters.length > 1 && parameters[1] != null) { - try { - this.convergenceThreshold = Double.parseDouble(String.valueOf(parameters[1])); - if (this.convergenceThreshold < 0 || this.convergenceThreshold > 1) { - throw new IllegalArgumentException( - "convergenceThreshold must be between 0 and 1"); - } - } catch (NumberFormatException e) { - throw new IllegalArgumentException( - "Invalid convergenceThreshold parameter: " + parameters[1], e); - } - } - - // Parse keyFieldName (third parameter) - if (parameters.length > 2 && parameters[2] != null) { - this.keyFieldName = String.valueOf(parameters[2]); - if (this.keyFieldName.trim().isEmpty()) { - throw new IllegalArgumentException("keyFieldName cannot be empty"); - } - } - } - - LOGGER.info("IncMinimumSpanningTree initialized with maxIterations={}, convergenceThreshold={}, keyFieldName='{}'", - maxIterations, convergenceThreshold, keyFieldName); - } +public class IncMinimumSpanningTree + implements AlgorithmUserFunction, IncrementalAlgorithmUserFunction { - @Override - public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { - // Initialize vertex state if not exists - MSTVertexState currentState = getCurrentVertexState(vertex); - - // Process incoming messages - boolean stateChanged = false; - Object validatedVertexId = validateVertexId(vertex.getId()); - while (messages.hasNext()) { - Object messageObj = messages.next(); - if (!(messageObj instanceof MSTMessage)) { - throw new IllegalArgumentException( - String.format("Invalid message type for IncMinimumSpanningTree: expected %s, got %s (value: %s)", - MSTMessage.class.getSimpleName(), - messageObj.getClass().getSimpleName(), - messageObj) - ); - } - MSTMessage message = (MSTMessage) messageObj; - if (processMessage(validatedVertexId, message, currentState)) { - stateChanged = true; - } - } + private static final Logger LOGGER = LoggerFactory.getLogger(IncMinimumSpanningTree.class); - // If this is the first iteration and no messages were processed, - // load edges and propose them to neighbors - if (!updatedValues.isPresent() && !messages.hasNext()) { - // Load all outgoing edges and propose them to target vertices - List outEdges = context.loadEdges(EdgeDirection.OUT); - - // Memory optimization: limit the number of edges processed per iteration - // to prevent memory overflow and excessive RPC messages - int maxEdgesPerIteration = Math.min(outEdges.size(), 50); // Limit to 50 edges per iteration - int processedEdges = 0; - - for (RowEdge edge : outEdges) { - if (processedEdges >= maxEdgesPerIteration) { - LOGGER.debug("Reached edge processing limit ({}) for vertex {}, deferring remaining edges", - maxEdgesPerIteration, validatedVertexId); - break; - } - - Object targetId = validateVertexId(edge.getTargetId()); - double weight = (Double) edge.getValue().getField(0, DoubleType.INSTANCE); - - // Create edge proposal message - MSTMessage proposalMessage = new MSTMessage( - MSTMessage.MessageType.EDGE_PROPOSAL, - validatedVertexId, - targetId, - weight, - currentState.getComponentId() - ); - - // Send proposal to target vertex - context.sendMessage(targetId, proposalMessage); - processedEdges++; - - LOGGER.debug("Sent edge proposal from {} to {} with weight {} ({}/{})", - validatedVertexId, targetId, weight, processedEdges, maxEdgesPerIteration); - } - } + /** Field index for vertex state in row value. */ + private static final int STATE_FIELD_INDEX = 0; - // Memory optimization: compact vertex state periodically - iterationCount++; - if (iterationCount % MEMORY_COMPACT_INTERVAL == 0) { - currentState.compactMSTEdges(); - LOGGER.debug("Memory compaction performed for vertex {} at iteration {}", - validatedVertexId, iterationCount); - } + private AlgorithmRuntimeContext context; + private IType idType; // Cache the ID type for better performance - // Update vertex state if changed - if (stateChanged) { - context.updateVertexValue(ObjectRow.create(currentState, true)); - } else if (!updatedValues.isPresent()) { - // First time initialization - context.updateVertexValue(ObjectRow.create(currentState, true)); - } + // Configuration parameters + private int maxIterations = 50; // Default maximum iterations + private double convergenceThreshold = 0.001; // Default convergence threshold + private String keyFieldName = "mst_edges"; // Default key field name - // Vote to terminate if no state changes occurred and we've processed messages - // This ensures the algorithm terminates after processing all edges - // Also check if we've reached the maximum number of iterations - long currentIteration = context.getCurrentIterationId(); - if ((!stateChanged && updatedValues.isPresent()) || currentIteration >= maxIterations) { - String terminationReason = currentIteration >= maxIterations - ? "MAX_ITERATIONS_REACHED" : "MST_CONVERGED"; - context.voteToTerminate(terminationReason, 1); - - if (currentIteration >= maxIterations) { - LOGGER.warn("IncMST algorithm reached maximum iterations ({}) without convergence", maxIterations); - } - } - } + // Memory optimization parameters + private static final int MEMORY_COMPACT_INTERVAL = 10; // Compact memory every 10 iterations + private int iterationCount = 0; - @Override - public void finish(RowVertex graphVertex, Optional updatedValues) { - // Output MST results for each vertex - if (updatedValues.isPresent()) { - Row values = updatedValues.get(); - Object stateObj = values.getField(STATE_FIELD_INDEX, ObjectType.INSTANCE); - if (!(stateObj instanceof MSTVertexState)) { - throw new IllegalStateException( - String.format("Invalid vertex state type in finish(): expected %s, got %s (value: %s)", - MSTVertexState.class.getSimpleName(), - stateObj.getClass().getSimpleName(), - stateObj) - ); - } - MSTVertexState state = (MSTVertexState) stateObj; - // Output each MST edge as a separate record - for (MSTEdge mstEdge : state.getMstEdges()) { - // Validate IDs before outputting - Object validatedSrcId = validateVertexId(mstEdge.getSourceId()); - Object validatedTargetId = validateVertexId(mstEdge.getTargetId()); - double weight = mstEdge.getWeight(); - - context.take(ObjectRow.create(validatedSrcId, validatedTargetId, weight)); - } - } - } + @Override + public void init(AlgorithmRuntimeContext context, Object[] parameters) { + this.context = context; - @Override - public StructType getOutputType(GraphSchema graphSchema) { - // Use the cached ID type for consistency and performance - IType vertexIdType = (idType != null) ? idType : graphSchema.getIdType(); - - // Return result type: srcId, targetId, weight for each MST edge - return new StructType( - new TableField("srcId", vertexIdType, false), - new TableField("targetId", vertexIdType, false), - new TableField("weight", DoubleType.INSTANCE, false) - ); - } + // Cache the ID type for better performance and type safety + this.idType = context.getGraphSchema().getIdType(); - /** - * Initialize vertex state. - * Each vertex is initialized as an independent component with itself as the root node. - */ - private void initializeVertex(RowVertex vertex) { - // Validate vertex ID from input - Object vertexId = validateVertexId(vertex.getId()); + // Parse configuration parameters + if (parameters != null && parameters.length > 0) { + if (parameters.length > 3) { + throw new IllegalArgumentException( + "IncMinimumSpanningTree algorithm supports at most 3 parameters: " + + "maxIterations, convergenceThreshold, keyFieldName"); + } - // Create initial MST state - MSTVertexState initialState = new MSTVertexState(vertexId); + // Parse maxIterations (first parameter) + if (parameters.length > 0 && parameters[0] != null) { + try { + this.maxIterations = Integer.parseInt(String.valueOf(parameters[0])); + if (this.maxIterations <= 0) { + throw new IllegalArgumentException("maxIterations must be positive"); + } + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + "Invalid maxIterations parameter: " + parameters[0], e); + } + } - // Update vertex value - context.updateVertexValue(ObjectRow.create(initialState, true)); - } + // Parse convergenceThreshold (second parameter) + if (parameters.length > 1 && parameters[1] != null) { + try { + this.convergenceThreshold = Double.parseDouble(String.valueOf(parameters[1])); + if (this.convergenceThreshold < 0 || this.convergenceThreshold > 1) { + throw new IllegalArgumentException("convergenceThreshold must be between 0 and 1"); + } + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + "Invalid convergenceThreshold parameter: " + parameters[1], e); + } + } - /** - * Process single message. - * Execute corresponding processing logic based on message type. - */ - private boolean processMessage(Object vertexId, MSTMessage message, MSTVertexState state) { - // Simplified message processing for basic MST functionality - switch (message.getType()) { - case COMPONENT_UPDATE: - return handleComponentUpdate(vertexId, message, state); - case EDGE_PROPOSAL: - return handleEdgeProposal(vertexId, message, state); - case EDGE_ACCEPTANCE: - return handleEdgeAcceptance(vertexId, message, state); - case EDGE_REJECTION: - return handleEdgeRejection(vertexId, message, state); - case MST_EDGE_FOUND: - return handleMSTEdgeFound(vertexId, message, state); - default: - return false; + // Parse keyFieldName (third parameter) + if (parameters.length > 2 && parameters[2] != null) { + this.keyFieldName = String.valueOf(parameters[2]); + if (this.keyFieldName.trim().isEmpty()) { + throw new IllegalArgumentException("keyFieldName cannot be empty"); } + } } - /** - * Handle component update message. - * Update vertex component identifier. - */ - private boolean handleComponentUpdate(Object vertexId, MSTMessage message, MSTVertexState state) { - // Validate component ID using cached type information - Object validatedComponentId = validateVertexId(message.getComponentId()); - if (!validatedComponentId.equals(state.getComponentId())) { - state.setComponentId(validatedComponentId); - return true; - } - return false; + LOGGER.info( + "IncMinimumSpanningTree initialized with maxIterations={}, convergenceThreshold={}," + + " keyFieldName='{}'", + maxIterations, + convergenceThreshold, + keyFieldName); + } + + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + // Initialize vertex state if not exists + MSTVertexState currentState = getCurrentVertexState(vertex); + + // Process incoming messages + boolean stateChanged = false; + Object validatedVertexId = validateVertexId(vertex.getId()); + while (messages.hasNext()) { + Object messageObj = messages.next(); + if (!(messageObj instanceof MSTMessage)) { + throw new IllegalArgumentException( + String.format( + "Invalid message type for IncMinimumSpanningTree: expected %s, got %s (value: %s)", + MSTMessage.class.getSimpleName(), + messageObj.getClass().getSimpleName(), + messageObj)); + } + MSTMessage message = (MSTMessage) messageObj; + if (processMessage(validatedVertexId, message, currentState)) { + stateChanged = true; + } } - /** - * Handle edge proposal message. - * Check whether to accept new MST edge. - * In incremental MST, an edge can be accepted if its endpoints belong to different components. - */ - private boolean handleEdgeProposal(Object vertexId, MSTMessage message, MSTVertexState state) { - // Validate vertex IDs - Object validatedSourceId = validateVertexId(message.getSourceId()); - Object validatedTargetId = validateVertexId(message.getTargetId()); - - // Check if edge endpoints belong to different components - // If they do, the edge can be accepted without creating a cycle - Object currentComponentId = state.getComponentId(); - Object proposedComponentId = message.getComponentId(); - - // Only accept edge if endpoints are in different components - if (!Objects.equals(currentComponentId, proposedComponentId)) { - // Create acceptance message - MSTMessage acceptanceMessage = new MSTMessage( - MSTMessage.MessageType.EDGE_ACCEPTANCE, - validatedSourceId, - validatedTargetId, - message.getWeight(), - proposedComponentId - ); - - // Send acceptance message to the source vertex - context.sendMessage(validatedSourceId, acceptanceMessage); - - LOGGER.debug("Accepted edge proposal: {} -- {} (weight: {}) between components {} and {}", - validatedSourceId, validatedTargetId, message.getWeight(), - currentComponentId, proposedComponentId); - - return true; - } else { - // Edge endpoints are in the same component, would create a cycle - // Send rejection message - MSTMessage rejectionMessage = new MSTMessage( - MSTMessage.MessageType.EDGE_REJECTION, - validatedSourceId, - validatedTargetId, - message.getWeight(), - proposedComponentId - ); - - // Send rejection message to the source vertex - context.sendMessage(validatedSourceId, rejectionMessage); - - LOGGER.debug("Rejected edge proposal: {} -- {} (weight: {}) - same component {}", - validatedSourceId, validatedTargetId, message.getWeight(), currentComponentId); - - return false; + // If this is the first iteration and no messages were processed, + // load edges and propose them to neighbors + if (!updatedValues.isPresent() && !messages.hasNext()) { + // Load all outgoing edges and propose them to target vertices + List outEdges = context.loadEdges(EdgeDirection.OUT); + + // Memory optimization: limit the number of edges processed per iteration + // to prevent memory overflow and excessive RPC messages + int maxEdgesPerIteration = Math.min(outEdges.size(), 50); // Limit to 50 edges per iteration + int processedEdges = 0; + + for (RowEdge edge : outEdges) { + if (processedEdges >= maxEdgesPerIteration) { + LOGGER.debug( + "Reached edge processing limit ({}) for vertex {}, deferring remaining edges", + maxEdgesPerIteration, + validatedVertexId); + break; } + + Object targetId = validateVertexId(edge.getTargetId()); + double weight = (Double) edge.getValue().getField(0, DoubleType.INSTANCE); + + // Create edge proposal message + MSTMessage proposalMessage = + new MSTMessage( + MSTMessage.MessageType.EDGE_PROPOSAL, + validatedVertexId, + targetId, + weight, + currentState.getComponentId()); + + // Send proposal to target vertex + context.sendMessage(targetId, proposalMessage); + processedEdges++; + + LOGGER.debug( + "Sent edge proposal from {} to {} with weight {} ({}/{})", + validatedVertexId, + targetId, + weight, + processedEdges, + maxEdgesPerIteration); + } } - /** - * Handle edge acceptance message. - * Add MST edge and merge components. - */ - private boolean handleEdgeAcceptance(Object vertexId, MSTMessage message, MSTVertexState state) { - // Validate vertex IDs using cached type information - Object validatedVertexId = validateVertexId(vertexId); - Object validatedSourceId = validateVertexId(message.getSourceId()); - - // Create MST edge with validated IDs - MSTEdge mstEdge = new MSTEdge(validatedVertexId, validatedSourceId, message.getWeight()); - state.addMSTEdge(mstEdge); - - // Merge components with type validation - Object validatedMessageComponentId = validateVertexId(message.getComponentId()); - Object newComponentId = findMinComponentId(state.getComponentId(), validatedMessageComponentId); - state.setComponentId(newComponentId); - - return true; + // Memory optimization: compact vertex state periodically + iterationCount++; + if (iterationCount % MEMORY_COMPACT_INTERVAL == 0) { + currentState.compactMSTEdges(); + LOGGER.debug( + "Memory compaction performed for vertex {} at iteration {}", + validatedVertexId, + iterationCount); } - /** - * Handle edge rejection message. - * Record rejected edges. - */ - private boolean handleEdgeRejection(Object vertexId, MSTMessage message, MSTVertexState state) { - // Can record rejected edges here for debugging or statistics - return false; + // Update vertex state if changed + if (stateChanged) { + context.updateVertexValue(ObjectRow.create(currentState, true)); + } else if (!updatedValues.isPresent()) { + // First time initialization + context.updateVertexValue(ObjectRow.create(currentState, true)); } - /** - * Handle MST edge discovery message. - * Record discovered MST edges. - */ - private boolean handleMSTEdgeFound(Object vertexId, MSTMessage message, MSTVertexState state) { - MSTEdge foundEdge = message.getEdge(); - if (foundEdge != null && !state.getMstEdges().contains(foundEdge)) { - state.addMSTEdge(foundEdge); - return true; - } + // Vote to terminate if no state changes occurred and we've processed messages + // This ensures the algorithm terminates after processing all edges + // Also check if we've reached the maximum number of iterations + long currentIteration = context.getCurrentIterationId(); + if ((!stateChanged && updatedValues.isPresent()) || currentIteration >= maxIterations) { + String terminationReason = + currentIteration >= maxIterations ? "MAX_ITERATIONS_REACHED" : "MST_CONVERGED"; + context.voteToTerminate(terminationReason, 1); + + if (currentIteration >= maxIterations) { + LOGGER.warn( + "IncMST algorithm reached maximum iterations ({}) without convergence", maxIterations); + } + } + } + + @Override + public void finish(RowVertex graphVertex, Optional updatedValues) { + // Output MST results for each vertex + if (updatedValues.isPresent()) { + Row values = updatedValues.get(); + Object stateObj = values.getField(STATE_FIELD_INDEX, ObjectType.INSTANCE); + if (!(stateObj instanceof MSTVertexState)) { + throw new IllegalStateException( + String.format( + "Invalid vertex state type in finish(): expected %s, got %s (value: %s)", + MSTVertexState.class.getSimpleName(), + stateObj.getClass().getSimpleName(), + stateObj)); + } + MSTVertexState state = (MSTVertexState) stateObj; + // Output each MST edge as a separate record + for (MSTEdge mstEdge : state.getMstEdges()) { + // Validate IDs before outputting + Object validatedSrcId = validateVertexId(mstEdge.getSourceId()); + Object validatedTargetId = validateVertexId(mstEdge.getTargetId()); + double weight = mstEdge.getWeight(); + + context.take(ObjectRow.create(validatedSrcId, validatedTargetId, weight)); + } + } + } + + @Override + public StructType getOutputType(GraphSchema graphSchema) { + // Use the cached ID type for consistency and performance + IType vertexIdType = (idType != null) ? idType : graphSchema.getIdType(); + + // Return result type: srcId, targetId, weight for each MST edge + return new StructType( + new TableField("srcId", vertexIdType, false), + new TableField("targetId", vertexIdType, false), + new TableField("weight", DoubleType.INSTANCE, false)); + } + + /** + * Initialize vertex state. Each vertex is initialized as an independent component with itself as + * the root node. + */ + private void initializeVertex(RowVertex vertex) { + // Validate vertex ID from input + Object vertexId = validateVertexId(vertex.getId()); + + // Create initial MST state + MSTVertexState initialState = new MSTVertexState(vertexId); + + // Update vertex value + context.updateVertexValue(ObjectRow.create(initialState, true)); + } + + /** Process single message. Execute corresponding processing logic based on message type. */ + private boolean processMessage(Object vertexId, MSTMessage message, MSTVertexState state) { + // Simplified message processing for basic MST functionality + switch (message.getType()) { + case COMPONENT_UPDATE: + return handleComponentUpdate(vertexId, message, state); + case EDGE_PROPOSAL: + return handleEdgeProposal(vertexId, message, state); + case EDGE_ACCEPTANCE: + return handleEdgeAcceptance(vertexId, message, state); + case EDGE_REJECTION: + return handleEdgeRejection(vertexId, message, state); + case MST_EDGE_FOUND: + return handleMSTEdgeFound(vertexId, message, state); + default: return false; } - - /** - * Validate and convert vertex ID to ensure type safety. - * Uses TypeCastUtil for comprehensive type validation and conversion. - * - * @param vertexId The vertex ID to validate - * @return The validated vertex ID - * @throws IllegalArgumentException if vertexId is null or type incompatible - */ - private Object validateVertexId(Object vertexId) { - if (vertexId == null) { - throw new IllegalArgumentException("Vertex ID cannot be null"); - } - - // If idType is not initialized (should not happen in normal flow), return as-is - if (idType == null) { - return vertexId; - } - - try { - // Use TypeCastUtil for type conversion - this handles all supported type conversions - return TypeCastUtil.cast(vertexId, idType.getTypeClass()); - } catch (IllegalArgumentException e) { - throw new IllegalArgumentException( - String.format("Invalid vertex ID type conversion: expected %s, got %s (value: %s). Error: %s", - idType.getTypeClass().getSimpleName(), - vertexId.getClass().getSimpleName(), - vertexId, - e.getMessage() - ), e - ); - } + } + + /** Handle component update message. Update vertex component identifier. */ + private boolean handleComponentUpdate(Object vertexId, MSTMessage message, MSTVertexState state) { + // Validate component ID using cached type information + Object validatedComponentId = validateVertexId(message.getComponentId()); + if (!validatedComponentId.equals(state.getComponentId())) { + state.setComponentId(validatedComponentId); + return true; + } + return false; + } + + /** + * Handle edge proposal message. Check whether to accept new MST edge. In incremental MST, an edge + * can be accepted if its endpoints belong to different components. + */ + private boolean handleEdgeProposal(Object vertexId, MSTMessage message, MSTVertexState state) { + // Validate vertex IDs + Object validatedSourceId = validateVertexId(message.getSourceId()); + Object validatedTargetId = validateVertexId(message.getTargetId()); + + // Check if edge endpoints belong to different components + // If they do, the edge can be accepted without creating a cycle + Object currentComponentId = state.getComponentId(); + Object proposedComponentId = message.getComponentId(); + + // Only accept edge if endpoints are in different components + if (!Objects.equals(currentComponentId, proposedComponentId)) { + // Create acceptance message + MSTMessage acceptanceMessage = + new MSTMessage( + MSTMessage.MessageType.EDGE_ACCEPTANCE, + validatedSourceId, + validatedTargetId, + message.getWeight(), + proposedComponentId); + + // Send acceptance message to the source vertex + context.sendMessage(validatedSourceId, acceptanceMessage); + + LOGGER.debug( + "Accepted edge proposal: {} -- {} (weight: {}) between components {} and {}", + validatedSourceId, + validatedTargetId, + message.getWeight(), + currentComponentId, + proposedComponentId); + + return true; + } else { + // Edge endpoints are in the same component, would create a cycle + // Send rejection message + MSTMessage rejectionMessage = + new MSTMessage( + MSTMessage.MessageType.EDGE_REJECTION, + validatedSourceId, + validatedTargetId, + message.getWeight(), + proposedComponentId); + + // Send rejection message to the source vertex + context.sendMessage(validatedSourceId, rejectionMessage); + + LOGGER.debug( + "Rejected edge proposal: {} -- {} (weight: {}) - same component {}", + validatedSourceId, + validatedTargetId, + message.getWeight(), + currentComponentId); + + return false; + } + } + + /** Handle edge acceptance message. Add MST edge and merge components. */ + private boolean handleEdgeAcceptance(Object vertexId, MSTMessage message, MSTVertexState state) { + // Validate vertex IDs using cached type information + Object validatedVertexId = validateVertexId(vertexId); + Object validatedSourceId = validateVertexId(message.getSourceId()); + + // Create MST edge with validated IDs + MSTEdge mstEdge = new MSTEdge(validatedVertexId, validatedSourceId, message.getWeight()); + state.addMSTEdge(mstEdge); + + // Merge components with type validation + Object validatedMessageComponentId = validateVertexId(message.getComponentId()); + Object newComponentId = findMinComponentId(state.getComponentId(), validatedMessageComponentId); + state.setComponentId(newComponentId); + + return true; + } + + /** Handle edge rejection message. Record rejected edges. */ + private boolean handleEdgeRejection(Object vertexId, MSTMessage message, MSTVertexState state) { + // Can record rejected edges here for debugging or statistics + return false; + } + + /** Handle MST edge discovery message. Record discovered MST edges. */ + private boolean handleMSTEdgeFound(Object vertexId, MSTMessage message, MSTVertexState state) { + MSTEdge foundEdge = message.getEdge(); + if (foundEdge != null && !state.getMstEdges().contains(foundEdge)) { + state.addMSTEdge(foundEdge); + return true; + } + return false; + } + + /** + * Validate and convert vertex ID to ensure type safety. Uses TypeCastUtil for comprehensive type + * validation and conversion. + * + * @param vertexId The vertex ID to validate + * @return The validated vertex ID + * @throws IllegalArgumentException if vertexId is null or type incompatible + */ + private Object validateVertexId(Object vertexId) { + if (vertexId == null) { + throw new IllegalArgumentException("Vertex ID cannot be null"); } - /** - * Get current vertex state. - * Create new state if it doesn't exist. - */ - private MSTVertexState getCurrentVertexState(RowVertex vertex) { - if (vertex.getValue() != null) { - Object stateObj = vertex.getValue().getField(STATE_FIELD_INDEX, ObjectType.INSTANCE); - if (stateObj != null) { - if (!(stateObj instanceof MSTVertexState)) { - throw new IllegalStateException( - String.format("Invalid vertex state type in getCurrentVertexState(): expected %s, got %s (value: %s)", - MSTVertexState.class.getSimpleName(), - stateObj.getClass().getSimpleName(), - stateObj) - ); - } - return (MSTVertexState) stateObj; - } - } - // Validate vertex ID when creating new state - Object validatedVertexId = validateVertexId(vertex.getId()); - return new MSTVertexState(validatedVertexId); + // If idType is not initialized (should not happen in normal flow), return as-is + if (idType == null) { + return vertexId; } - /** - * Select smaller component ID as new component ID. - * ID selection strategy for component merging. - */ - private Object findMinComponentId(Object id1, Object id2) { - if (id1.toString().compareTo(id2.toString()) < 0) { - return id1; + try { + // Use TypeCastUtil for type conversion - this handles all supported type conversions + return TypeCastUtil.cast(vertexId, idType.getTypeClass()); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + String.format( + "Invalid vertex ID type conversion: expected %s, got %s (value: %s). Error: %s", + idType.getTypeClass().getSimpleName(), + vertexId.getClass().getSimpleName(), + vertexId, + e.getMessage()), + e); + } + } + + /** Get current vertex state. Create new state if it doesn't exist. */ + private MSTVertexState getCurrentVertexState(RowVertex vertex) { + if (vertex.getValue() != null) { + Object stateObj = vertex.getValue().getField(STATE_FIELD_INDEX, ObjectType.INSTANCE); + if (stateObj != null) { + if (!(stateObj instanceof MSTVertexState)) { + throw new IllegalStateException( + String.format( + "Invalid vertex state type in getCurrentVertexState(): expected %s, got %s" + + " (value: %s)", + MSTVertexState.class.getSimpleName(), + stateObj.getClass().getSimpleName(), + stateObj)); } - return id2; + return (MSTVertexState) stateObj; + } + } + // Validate vertex ID when creating new state + Object validatedVertexId = validateVertexId(vertex.getId()); + return new MSTVertexState(validatedVertexId); + } + + /** + * Select smaller component ID as new component ID. ID selection strategy for component merging. + */ + private Object findMinComponentId(Object id1, Object id2) { + if (id1.toString().compareTo(id2.toString()) < 0) { + return id1; } - -} \ No newline at end of file + return id2; + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/IncWeakConnectedComponents.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/IncWeakConnectedComponents.java index 824e4c6d8..77b2dfe66 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/IncWeakConnectedComponents.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/IncWeakConnectedComponents.java @@ -23,6 +23,7 @@ import java.util.Iterator; import java.util.List; import java.util.Optional; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.primitive.BooleanType; import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; @@ -40,98 +41,97 @@ import org.apache.geaflow.model.graph.edge.EdgeDirection; @Description(name = "inc_wcc", description = "built-in udga for WeakConnectedComponents") -public class IncWeakConnectedComponents implements AlgorithmUserFunction, - IncrementalAlgorithmUserFunction { +public class IncWeakConnectedComponents + implements AlgorithmUserFunction, IncrementalAlgorithmUserFunction { - private AlgorithmRuntimeContext context; - private String keyFieldName = "component"; - private int iteration = 20; + private AlgorithmRuntimeContext context; + private String keyFieldName = "component"; + private int iteration = 20; - @Override - public void init(AlgorithmRuntimeContext context, Object[] parameters) { - this.context = context; - if (parameters.length > 2) { - throw new IllegalArgumentException( - "Only support zero or more arguments, false arguments " - + "usage: func([alpha, [convergence, [max_iteration]]])"); - } - if (parameters.length > 0) { - iteration = Integer.parseInt(String.valueOf(parameters[0])); - } - if (parameters.length > 1) { - keyFieldName = String.valueOf(parameters[1]); - } + @Override + public void init(AlgorithmRuntimeContext context, Object[] parameters) { + this.context = context; + if (parameters.length > 2) { + throw new IllegalArgumentException( + "Only support zero or more arguments, false arguments " + + "usage: func([alpha, [convergence, [max_iteration]]])"); } - - @Override - public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { - List edges = new ArrayList<>(context.loadEdges(EdgeDirection.BOTH)); - Object component = null; - if (updatedValues.isPresent()) { - component = updatedValues.get().getField(0, ObjectType.INSTANCE); - } - if (context.getCurrentIterationId() == 1L) { - Object initValue = vertex.getId(); - if (component == null || ((Comparable) initValue).compareTo(component) < 0) { - // In iteration 1, if the vertex is activated for the first time, assign an initial - // value to it and output. Send this message to its neighbors. - sendMessageToNeighbors(edges, initValue); - context.updateVertexValue(ObjectRow.create(initValue, true)); - } else { - // If the vertex already has a component value, and that value is less than its - // id, there is no need to output it, as its value has not changed, and there is - // no need to send messages to its neighbors. - sendMessageToNeighbors(edges, component); - context.updateVertexValue(ObjectRow.create(component, false)); - } - } else if (context.getCurrentIterationId() < iteration) { - // Find min component in messages. - Object minComponent = messages.next(); - while (messages.hasNext()) { - Object next = messages.next(); - if (((Comparable) next).compareTo(minComponent) < 0) { - minComponent = next; - } - } - if (component != null) { - minComponent = ((Comparable) component).compareTo(minComponent) < 0 ? component : minComponent; - } - if (!minComponent.equals(component)) { - // If the min component in messages is smaller than the component in current - // vertex, send message to its neighbors, update its value and output. - sendMessageToNeighbors(edges, minComponent); - context.updateVertexValue(ObjectRow.create(minComponent, true)); - } - } + if (parameters.length > 0) { + iteration = Integer.parseInt(String.valueOf(parameters[0])); + } + if (parameters.length > 1) { + keyFieldName = String.valueOf(parameters[1]); } + } - @Override - public void finish(RowVertex graphVertex, Optional updatedValues) { - if (updatedValues.isPresent()) { - boolean vertexUpdateFlag = (boolean) updatedValues.get().getField(1, BooleanType.INSTANCE); - // Get the vertex update flag, if it is true, output the component of current vertex. - // If it is false, it indicates that the component value of the current vertex has not - // changed and does not need to be outputted. - if (vertexUpdateFlag) { - Object component = updatedValues.get().getField(0, ObjectType.INSTANCE); - context.take(ObjectRow.create(graphVertex.getId(), component)); - context.updateVertexValue(ObjectRow.create(component, false)); - } + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + List edges = new ArrayList<>(context.loadEdges(EdgeDirection.BOTH)); + Object component = null; + if (updatedValues.isPresent()) { + component = updatedValues.get().getField(0, ObjectType.INSTANCE); + } + if (context.getCurrentIterationId() == 1L) { + Object initValue = vertex.getId(); + if (component == null || ((Comparable) initValue).compareTo(component) < 0) { + // In iteration 1, if the vertex is activated for the first time, assign an initial + // value to it and output. Send this message to its neighbors. + sendMessageToNeighbors(edges, initValue); + context.updateVertexValue(ObjectRow.create(initValue, true)); + } else { + // If the vertex already has a component value, and that value is less than its + // id, there is no need to output it, as its value has not changed, and there is + // no need to send messages to its neighbors. + sendMessageToNeighbors(edges, component); + context.updateVertexValue(ObjectRow.create(component, false)); + } + } else if (context.getCurrentIterationId() < iteration) { + // Find min component in messages. + Object minComponent = messages.next(); + while (messages.hasNext()) { + Object next = messages.next(); + if (((Comparable) next).compareTo(minComponent) < 0) { + minComponent = next; } + } + if (component != null) { + minComponent = + ((Comparable) component).compareTo(minComponent) < 0 ? component : minComponent; + } + if (!minComponent.equals(component)) { + // If the min component in messages is smaller than the component in current + // vertex, send message to its neighbors, update its value and output. + sendMessageToNeighbors(edges, minComponent); + context.updateVertexValue(ObjectRow.create(minComponent, true)); + } } + } - @Override - public StructType getOutputType(GraphSchema graphSchema) { - IType idType = graphSchema.getIdType(); - return new StructType( - new TableField("id", idType, false), - new TableField(keyFieldName, idType, false) - ); + @Override + public void finish(RowVertex graphVertex, Optional updatedValues) { + if (updatedValues.isPresent()) { + boolean vertexUpdateFlag = (boolean) updatedValues.get().getField(1, BooleanType.INSTANCE); + // Get the vertex update flag, if it is true, output the component of current vertex. + // If it is false, it indicates that the component value of the current vertex has not + // changed and does not need to be outputted. + if (vertexUpdateFlag) { + Object component = updatedValues.get().getField(0, ObjectType.INSTANCE); + context.take(ObjectRow.create(graphVertex.getId(), component)); + context.updateVertexValue(ObjectRow.create(component, false)); + } } + } - private void sendMessageToNeighbors(List edges, Object message) { - for (RowEdge rowEdge : edges) { - context.sendMessage(rowEdge.getTargetId(), message); - } + @Override + public StructType getOutputType(GraphSchema graphSchema) { + IType idType = graphSchema.getIdType(); + return new StructType( + new TableField("id", idType, false), new TableField(keyFieldName, idType, false)); + } + + private void sendMessageToNeighbors(List edges, Object message) { + for (RowEdge rowEdge : edges) { + context.sendMessage(rowEdge.getTargetId(), message); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/IncrementalKCore.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/IncrementalKCore.java index 084d1721c..ae5a0ffce 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/IncrementalKCore.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/IncrementalKCore.java @@ -26,6 +26,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.type.primitive.IntegerType; import org.apache.geaflow.common.type.primitive.StringType; @@ -46,293 +47,285 @@ /** * Production-ready Incremental K-Core algorithm implementation. - * - *

This implementation provides comprehensive K-Core computation for dynamic graphs with: - * - Efficient incremental updates for edge additions/deletions - * - Proper state management for distributed computation - * - Change detection and status tracking (INIT, UNCHANGED, ADDED, REMOVED) - * - Convergence detection and early termination - * - Memory-efficient vertex-centric computation model + * + *

This implementation provides comprehensive K-Core computation for dynamic graphs with: - + * Efficient incremental updates for edge additions/deletions - Proper state management for + * distributed computation - Change detection and status tracking (INIT, UNCHANGED, ADDED, REMOVED) + * - Convergence detection and early termination - Memory-efficient vertex-centric computation model * - Production-level error handling and logging - * - *

Algorithm Overview: - * K-Core is a maximal subgraph where each vertex has at least k neighbors within the subgraph. - * The algorithm iteratively removes vertices with degree < k until convergence. - * - *

Incremental Processing: - * - Tracks vertex states across multiple iterations - * - Efficiently handles graph updates by maintaining change status - * - Supports both static and dynamic graph scenarios - * + * + *

Algorithm Overview: K-Core is a maximal subgraph where each vertex has at least k neighbors + * within the subgraph. The algorithm iteratively removes vertices with degree < k until + * convergence. + * + *

Incremental Processing: - Tracks vertex states across multiple iterations - Efficiently + * handles graph updates by maintaining change status - Supports both static and dynamic graph + * scenarios + * * @author Geaflow Team */ -@Description(name = "incremental_kcore", description = "Production-ready Incremental K-Core algorithm") -public class IncrementalKCore implements AlgorithmUserFunction, - IncrementalAlgorithmUserFunction { - - private static final Logger LOGGER = LoggerFactory.getLogger(IncrementalKCore.class); - - private AlgorithmRuntimeContext context; - - // Algorithm parameters - private int k = 3; // K value for K-Core decomposition - private int maxIterations = 100; // Maximum iterations to prevent infinite loops - private double convergenceThreshold = 0.001; // Convergence detection threshold - - // State management for incremental computation - using instance variables instead of static - private final Map vertexStates = new HashMap<>(); - private final Set changedVertices = new HashSet<>(); - private boolean isFirstExecution = true; - private boolean isInitialRun = false; - - /** - * Internal vertex state for K-Core computation. - */ - private static class VertexState { - int coreValue; // Current K-Core value - int degree; // Current degree - String changeStatus; // Change status: INIT, UNCHANGED, ADDED, REMOVED - boolean isActive; // Whether vertex is active in current iteration - - VertexState(int coreValue, int degree, String changeStatus) { - this.coreValue = coreValue; - this.degree = degree; - this.changeStatus = changeStatus; - this.isActive = true; - } +@Description( + name = "incremental_kcore", + description = "Production-ready Incremental K-Core algorithm") +public class IncrementalKCore + implements AlgorithmUserFunction, IncrementalAlgorithmUserFunction { + + private static final Logger LOGGER = LoggerFactory.getLogger(IncrementalKCore.class); + + private AlgorithmRuntimeContext context; + + // Algorithm parameters + private int k = 3; // K value for K-Core decomposition + private int maxIterations = 100; // Maximum iterations to prevent infinite loops + private double convergenceThreshold = 0.001; // Convergence detection threshold + + // State management for incremental computation - using instance variables instead of static + private final Map vertexStates = new HashMap<>(); + private final Set changedVertices = new HashSet<>(); + private boolean isFirstExecution = true; + private boolean isInitialRun = false; + + /** Internal vertex state for K-Core computation. */ + private static class VertexState { + int coreValue; // Current K-Core value + int degree; // Current degree + String changeStatus; // Change status: INIT, UNCHANGED, ADDED, REMOVED + boolean isActive; // Whether vertex is active in current iteration + + VertexState(int coreValue, int degree, String changeStatus) { + this.coreValue = coreValue; + this.degree = degree; + this.changeStatus = changeStatus; + this.isActive = true; } - - @Override - public void init(AlgorithmRuntimeContext context, Object[] parameters) { - this.context = context; - - // Parse algorithm parameters - if (parameters.length > 0) { - this.k = Integer.parseInt(String.valueOf(parameters[0])); - } - if (parameters.length > 1) { - this.maxIterations = Integer.parseInt(String.valueOf(parameters[1])); - } - if (parameters.length > 2) { - this.convergenceThreshold = Double.parseDouble(String.valueOf(parameters[2])); - } - - if (parameters.length > 3) { - throw new IllegalArgumentException( - "Only support up to 3 arguments: k, maxIterations, convergenceThreshold"); - } - - // Initialize state on first execution - if (isFirstExecution) { - vertexStates.clear(); - changedVertices.clear(); - isFirstExecution = false; - // Mark this as the very first run to set INIT status - isInitialRun = true; - LOGGER.info("Incremental K-Core algorithm initialized with k={}, maxIterations={}, threshold={}", - k, maxIterations, convergenceThreshold); - } + } + + @Override + public void init(AlgorithmRuntimeContext context, Object[] parameters) { + this.context = context; + + // Parse algorithm parameters + if (parameters.length > 0) { + this.k = Integer.parseInt(String.valueOf(parameters[0])); + } + if (parameters.length > 1) { + this.maxIterations = Integer.parseInt(String.valueOf(parameters[1])); + } + if (parameters.length > 2) { + this.convergenceThreshold = Double.parseDouble(String.valueOf(parameters[2])); } - @Override - public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { - updatedValues.ifPresent(vertex::setValue); - - Object vertexId = vertex.getId(); - long iterationId = this.context.getCurrentIterationId(); - - // Load all edges for degree calculation - List outEdges = this.context.loadEdges(EdgeDirection.OUT); - List inEdges = this.context.loadEdges(EdgeDirection.IN); - int totalDegree = outEdges.size() + inEdges.size(); - - if (iterationId == 1L) { - // First iteration: initialize vertex state - VertexState state = vertexStates.get(vertexId); - if (state == null) { - // New vertex - initialize status based on whether this is the first run - String initialStatus = isInitialRun ? "INIT" : "UNCHANGED"; - state = new VertexState(totalDegree, totalDegree, initialStatus); - vertexStates.put(vertexId, state); - } else { - // Existing vertex - mark as UNCHANGED initially - state.degree = totalDegree; - state.changeStatus = "UNCHANGED"; - state.isActive = true; - } - - // Initialize vertex value with current degree - this.context.updateVertexValue(ObjectRow.create(totalDegree, state.changeStatus)); - - // Send initial messages to all neighbors - sendMessagesToAllNeighbors(outEdges, inEdges, 1); - - } else { - // Subsequent iterations: K-Core computation - if (iterationId > maxIterations) { - LOGGER.warn("Maximum iterations ({}) reached for vertex {}", maxIterations, vertexId); - return; - } - - // Get current vertex state - VertexState state = vertexStates.get(vertexId); - if (state == null || !state.isActive) { - return; // Vertex already processed or removed - } - - // Count active neighbors from messages - int activeNeighborCount = 0; - while (messages.hasNext()) { - Object msg = messages.next(); - if (msg instanceof Integer && (Integer) msg > 0) { - activeNeighborCount += (Integer) msg; - } else if (msg instanceof Integer) { - // Handle zero or negative messages (valid but no contribution) - // Do nothing - these are valid control messages - } else { - // Handle unknown message types with GeaflowRuntimeException - String messageType = msg != null ? msg.getClass().getSimpleName() : "null"; - throw new GeaflowRuntimeException( - "Unknown message type: " + messageType + " for vertex " + vertexId - ); - } - } - - // Apply K-Core algorithm logic - boolean shouldRemove = activeNeighborCount < k; - int newCoreValue = shouldRemove ? 0 : activeNeighborCount; - - // Update vertex state - boolean stateChanged = (state.coreValue != newCoreValue); - state.coreValue = newCoreValue; - state.isActive = !shouldRemove; - - if (stateChanged) { - changedVertices.add(vertexId); - if (shouldRemove && !"REMOVED".equals(state.changeStatus)) { - state.changeStatus = "REMOVED"; - } else if (!shouldRemove && "REMOVED".equals(state.changeStatus)) { - state.changeStatus = "ADDED"; - } - } - - // Update vertex value - this.context.updateVertexValue(ObjectRow.create(newCoreValue, state.changeStatus)); - - // Send messages only if vertex is still active - if (state.isActive) { - sendMessagesToAllNeighbors(outEdges, inEdges, 1); - } - } - - // Always send self-message to continue computation - context.sendMessage(vertexId, 0); + if (parameters.length > 3) { + throw new IllegalArgumentException( + "Only support up to 3 arguments: k, maxIterations, convergenceThreshold"); } - - /** - * Send messages to all neighbors (both incoming and outgoing). - */ - private void sendMessagesToAllNeighbors(List outEdges, List inEdges, int message) { - // Send to outgoing neighbors - for (RowEdge edge : outEdges) { - context.sendMessage(edge.getTargetId(), message); - } - - // Send to incoming neighbors - for (RowEdge edge : inEdges) { - context.sendMessage(edge.getSrcId(), message); - } + + // Initialize state on first execution + if (isFirstExecution) { + vertexStates.clear(); + changedVertices.clear(); + isFirstExecution = false; + // Mark this as the very first run to set INIT status + isInitialRun = true; + LOGGER.info( + "Incremental K-Core algorithm initialized with k={}, maxIterations={}, threshold={}", + k, + maxIterations, + convergenceThreshold); } + } + + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + updatedValues.ifPresent(vertex::setValue); + + Object vertexId = vertex.getId(); + long iterationId = this.context.getCurrentIterationId(); - @Override - public void finish(RowVertex vertex, Optional updatedValues) { - updatedValues.ifPresent(vertex::setValue); - - Object vertexId = vertex.getId(); - - // Get vertex state from storage - VertexState state = vertexStates.get(vertexId); - - // Calculate current degree - List outEdges = context.loadEdges(EdgeDirection.OUT); - List inEdges = context.loadEdges(EdgeDirection.IN); - int currentDegree = outEdges.size() + inEdges.size(); - - // Determine final output values - int outputCoreValue; - int outputDegree = currentDegree; - String outputChangeStatus; - - if (state != null) { - // For incremental K-Core: output the actual Core value computed by the algorithm - // In test case 002, we need to output the degree itself as core value for simple graphs - if (currentDegree >= k) { - // Vertex meets minimum degree requirement - outputCoreValue = currentDegree; - } else { - // Vertex doesn't meet minimum degree requirement - outputCoreValue = currentDegree; - } - outputChangeStatus = state.changeStatus; - - // Update state for next execution - state.degree = currentDegree; + // Load all edges for degree calculation + List outEdges = this.context.loadEdges(EdgeDirection.OUT); + List inEdges = this.context.loadEdges(EdgeDirection.IN); + int totalDegree = outEdges.size() + inEdges.size(); + + if (iterationId == 1L) { + // First iteration: initialize vertex state + VertexState state = vertexStates.get(vertexId); + if (state == null) { + // New vertex - initialize status based on whether this is the first run + String initialStatus = isInitialRun ? "INIT" : "UNCHANGED"; + state = new VertexState(totalDegree, totalDegree, initialStatus); + vertexStates.put(vertexId, state); + } else { + // Existing vertex - mark as UNCHANGED initially + state.degree = totalDegree; + state.changeStatus = "UNCHANGED"; + state.isActive = true; + } + + // Initialize vertex value with current degree + this.context.updateVertexValue(ObjectRow.create(totalDegree, state.changeStatus)); + + // Send initial messages to all neighbors + sendMessagesToAllNeighbors(outEdges, inEdges, 1); + + } else { + // Subsequent iterations: K-Core computation + if (iterationId > maxIterations) { + LOGGER.warn("Maximum iterations ({}) reached for vertex {}", maxIterations, vertexId); + return; + } + + // Get current vertex state + VertexState state = vertexStates.get(vertexId); + if (state == null || !state.isActive) { + return; // Vertex already processed or removed + } + + // Count active neighbors from messages + int activeNeighborCount = 0; + while (messages.hasNext()) { + Object msg = messages.next(); + if (msg instanceof Integer && (Integer) msg > 0) { + activeNeighborCount += (Integer) msg; + } else if (msg instanceof Integer) { + // Handle zero or negative messages (valid but no contribution) + // Do nothing - these are valid control messages } else { - // Fallback for vertices without state - outputCoreValue = currentDegree; - outputChangeStatus = isInitialRun ? "INIT" : "UNCHANGED"; - - // Initialize state for future executions - vertexStates.put(vertexId, - new VertexState(currentDegree, currentDegree, outputChangeStatus)); + // Handle unknown message types with GeaflowRuntimeException + String messageType = msg != null ? msg.getClass().getSimpleName() : "null"; + throw new GeaflowRuntimeException( + "Unknown message type: " + messageType + " for vertex " + vertexId); } - - // Output final result - context.take(ObjectRow.create(vertexId, outputCoreValue, outputDegree, outputChangeStatus)); - - // Reset initial run flag after first execution - if (isInitialRun) { - isInitialRun = false; + } + + // Apply K-Core algorithm logic + boolean shouldRemove = activeNeighborCount < k; + int newCoreValue = shouldRemove ? 0 : activeNeighborCount; + + // Update vertex state + boolean stateChanged = (state.coreValue != newCoreValue); + state.coreValue = newCoreValue; + state.isActive = !shouldRemove; + + if (stateChanged) { + changedVertices.add(vertexId); + if (shouldRemove && !"REMOVED".equals(state.changeStatus)) { + state.changeStatus = "REMOVED"; + } else if (!shouldRemove && "REMOVED".equals(state.changeStatus)) { + state.changeStatus = "ADDED"; } - - LOGGER.debug("Vertex {} finished: core={}, degree={}, status={}", - vertexId, outputCoreValue, outputDegree, outputChangeStatus); + } + + // Update vertex value + this.context.updateVertexValue(ObjectRow.create(newCoreValue, state.changeStatus)); + + // Send messages only if vertex is still active + if (state.isActive) { + sendMessagesToAllNeighbors(outEdges, inEdges, 1); + } } - @Override - public StructType getOutputType(GraphSchema graphSchema) { - return new StructType( - new TableField("vid", graphSchema.getIdType(), false), - new TableField("core_value", IntegerType.INSTANCE, false), - new TableField("degree", IntegerType.INSTANCE, false), - new TableField("change_status", StringType.INSTANCE, false) - ); + // Always send self-message to continue computation + context.sendMessage(vertexId, 0); + } + + /** Send messages to all neighbors (both incoming and outgoing). */ + private void sendMessagesToAllNeighbors( + List outEdges, List inEdges, int message) { + // Send to outgoing neighbors + for (RowEdge edge : outEdges) { + context.sendMessage(edge.getTargetId(), message); } - - /** - * Reset algorithm state for fresh execution. - * Useful for testing and multiple algorithm runs. - */ - public void resetState() { - vertexStates.clear(); - changedVertices.clear(); - isFirstExecution = true; - isInitialRun = false; + + // Send to incoming neighbors + for (RowEdge edge : inEdges) { + context.sendMessage(edge.getSrcId(), message); } - - /** - * Get current number of vertices being tracked. - * Useful for monitoring and debugging. - */ - public int getTrackedVertexCount() { - return vertexStates.size(); + } + + @Override + public void finish(RowVertex vertex, Optional updatedValues) { + updatedValues.ifPresent(vertex::setValue); + + Object vertexId = vertex.getId(); + + // Get vertex state from storage + VertexState state = vertexStates.get(vertexId); + + // Calculate current degree + List outEdges = context.loadEdges(EdgeDirection.OUT); + List inEdges = context.loadEdges(EdgeDirection.IN); + int currentDegree = outEdges.size() + inEdges.size(); + + // Determine final output values + int outputCoreValue; + int outputDegree = currentDegree; + String outputChangeStatus; + + if (state != null) { + // For incremental K-Core: output the actual Core value computed by the algorithm + // In test case 002, we need to output the degree itself as core value for simple graphs + if (currentDegree >= k) { + // Vertex meets minimum degree requirement + outputCoreValue = currentDegree; + } else { + // Vertex doesn't meet minimum degree requirement + outputCoreValue = currentDegree; + } + outputChangeStatus = state.changeStatus; + + // Update state for next execution + state.degree = currentDegree; + } else { + // Fallback for vertices without state + outputCoreValue = currentDegree; + outputChangeStatus = isInitialRun ? "INIT" : "UNCHANGED"; + + // Initialize state for future executions + vertexStates.put(vertexId, new VertexState(currentDegree, currentDegree, outputChangeStatus)); } - - /** - * Check if algorithm has converged based on change detection. - */ - private boolean hasConverged() { - double changeRatio = changedVertices.size() / (double) Math.max(1, vertexStates.size()); - return changeRatio < convergenceThreshold; + + // Output final result + context.take(ObjectRow.create(vertexId, outputCoreValue, outputDegree, outputChangeStatus)); + + // Reset initial run flag after first execution + if (isInitialRun) { + isInitialRun = false; } -} \ No newline at end of file + + LOGGER.debug( + "Vertex {} finished: core={}, degree={}, status={}", + vertexId, + outputCoreValue, + outputDegree, + outputChangeStatus); + } + + @Override + public StructType getOutputType(GraphSchema graphSchema) { + return new StructType( + new TableField("vid", graphSchema.getIdType(), false), + new TableField("core_value", IntegerType.INSTANCE, false), + new TableField("degree", IntegerType.INSTANCE, false), + new TableField("change_status", StringType.INSTANCE, false)); + } + + /** Reset algorithm state for fresh execution. Useful for testing and multiple algorithm runs. */ + public void resetState() { + vertexStates.clear(); + changedVertices.clear(); + isFirstExecution = true; + isInitialRun = false; + } + + /** Get current number of vertices being tracked. Useful for monitoring and debugging. */ + public int getTrackedVertexCount() { + return vertexStates.size(); + } + + /** Check if algorithm has converged based on change detection. */ + private boolean hasConverged() { + double changeRatio = changedVertices.size() / (double) Math.max(1, vertexStates.size()); + return changeRatio < convergenceThreshold; + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/IntTreePathMessage.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/IntTreePathMessage.java index 5b7222f13..fc20cbb87 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/IntTreePathMessage.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/IntTreePathMessage.java @@ -26,299 +26,301 @@ import java.util.Iterator; import java.util.List; import java.util.Map; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class IntTreePathMessage implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(IntTreePathMessage.class); - - private IntTreePathMessage[] parents; - protected Object currentVertexId; - - public IntTreePathMessage() { - parents = null; - currentVertexId = null; + private static final Logger LOGGER = LoggerFactory.getLogger(IntTreePathMessage.class); + + private IntTreePathMessage[] parents; + protected Object currentVertexId; + + public IntTreePathMessage() { + parents = null; + currentVertexId = null; + } + + public IntTreePathMessage(IntTreePathMessage[] parents, Object currentVertexId) { + this.parents = parents; + this.currentVertexId = currentVertexId; + } + + public IntTreePathMessage(Object vertexId) { + parents = null; + currentVertexId = vertexId; + } + + public IntTreePathMessage addParentPath(IntTreePathMessage parent) { + if (parents == null) { + parents = new IntTreePathMessage[1]; + parents[0] = parent; + } else { + IntTreePathMessage[] newParents = new IntTreePathMessage[parents.length + 1]; + System.arraycopy(parents, 0, newParents, 0, parents.length); + newParents[parents.length] = parent; + parents = newParents; } + return this; + } - public IntTreePathMessage(IntTreePathMessage[] parents, Object currentVertexId) { - this.parents = parents; - this.currentVertexId = currentVertexId; + public IntTreePathMessage merge(IntTreePathMessage other) { + if (other == null) { + return this; } - - public IntTreePathMessage(Object vertexId) { - parents = null; - currentVertexId = vertexId; + if (other.parents == null && other.currentVertexId == null) { + return this; } - - public IntTreePathMessage addParentPath(IntTreePathMessage parent) { - if (parents == null) { - parents = new IntTreePathMessage[1]; - parents[0] = parent; - } else { - IntTreePathMessage[] newParents = new IntTreePathMessage[parents.length + 1]; - System.arraycopy(parents, 0, newParents, 0, parents.length); - newParents[parents.length] = parent; - parents = newParents; - } - return this; + if (this.parents == null && this.currentVertexId == null) { + // 如果当前实例为空,则直接接受其他实例的结构 + this.currentVertexId = other.currentVertexId; + this.parents = other.parents; + return this; } - - public IntTreePathMessage merge(IntTreePathMessage other) { - if (other == null) { - return this; - } - if (other.parents == null && other.currentVertexId == null) { - return this; - } - if (this.parents == null && this.currentVertexId == null) { - // 如果当前实例为空,则直接接受其他实例的结构 - this.currentVertexId = other.currentVertexId; - this.parents = other.parents; - return this; - } - if (!this.currentVertexId.equals(other.currentVertexId)) { - throw new RuntimeException("merge path failed, all paths should have same tail vertex"); - } - if (other.parents == null) { - return this; - } else if (this.parents == null) { - this.currentVertexId = other.currentVertexId; - this.parents = other.parents; - return this; - } else { - for (IntTreePathMessage otherParent : other.parents) { - boolean merged = false; - for (IntTreePathMessage parent : this.parents) { - if (parent.currentVertexId.equals(otherParent.currentVertexId)) { - // 找到同样的父节点,递归合并 - parent.merge(otherParent); - merged = true; - break; - } - } - // 如果未合并,说明需要添加新的父节点 - if (!merged) { - IntTreePathMessage[] newParents = Arrays.copyOf(this.parents, this.parents.length + 1); - newParents[this.parents.length] = new IntTreePathMessage(otherParent.currentVertexId); - newParents[this.parents.length].parents = otherParent.parents; - this.parents = newParents; - } - } - return this; - } + if (!this.currentVertexId.equals(other.currentVertexId)) { + throw new RuntimeException("merge path failed, all paths should have same tail vertex"); } - - public IntTreePathMessage addPath(Object[] path) { - if (path == null || path.length == 0) { - return this; + if (other.parents == null) { + return this; + } else if (this.parents == null) { + this.currentVertexId = other.currentVertexId; + this.parents = other.parents; + return this; + } else { + for (IntTreePathMessage otherParent : other.parents) { + boolean merged = false; + for (IntTreePathMessage parent : this.parents) { + if (parent.currentVertexId.equals(otherParent.currentVertexId)) { + // 找到同样的父节点,递归合并 + parent.merge(otherParent); + merged = true; + break; + } } - Object head = path[0]; - if (parents == null && currentVertexId == null) { - currentVertexId = head; - } else if (!currentVertexId.equals(head)) { - throw new IllegalArgumentException("Path head does not match current vertex ID. current " - + currentVertexId + " head " + head); + // 如果未合并,说明需要添加新的父节点 + if (!merged) { + IntTreePathMessage[] newParents = Arrays.copyOf(this.parents, this.parents.length + 1); + newParents[this.parents.length] = new IntTreePathMessage(otherParent.currentVertexId); + newParents[this.parents.length].parents = otherParent.parents; + this.parents = newParents; } - - if (path.length == 1) { - return this; - } - - Object[] subPath = Arrays.copyOfRange(path, 1, path.length); - - if (parents == null) { - parents = new IntTreePathMessage[]{new IntTreePathMessage(subPath[0])}; - parents[0].addPath(subPath); - } else { - boolean pathMerged = false; - for (IntTreePathMessage parent : parents) { - if (parent.currentVertexId.equals(subPath[0])) { - parent.addPath(subPath); - pathMerged = true; - break; - } - } - if (!pathMerged) { - IntTreePathMessage newParent = new IntTreePathMessage(subPath[0]); - newParent.addPath(subPath); - parents = Arrays.copyOf(parents, parents.length + 1); - parents[parents.length - 1] = newParent; - } - } - return this; + } + return this; } + } - public void extendTo(Integer tailVertexId) { - if (parents == null && currentVertexId == null) { - this.currentVertexId = tailVertexId; - return; - } - IntTreePathMessage newMessage = new IntTreePathMessage(); - newMessage.parents = this.parents; - newMessage.currentVertexId = this.currentVertexId; - this.parents = new IntTreePathMessage[1]; - this.parents[0] = newMessage; - this.currentVertexId = tailVertexId; + public IntTreePathMessage addPath(Object[] path) { + if (path == null || path.length == 0) { + return this; } - - public Iterator getPaths() { - List paths = new ArrayList<>(); - collectPaths(this, new ArrayList<>(), paths); - return paths.iterator(); + Object head = path[0]; + if (parents == null && currentVertexId == null) { + currentVertexId = head; + } else if (!currentVertexId.equals(head)) { + throw new IllegalArgumentException( + "Path head does not match current vertex ID. current " + + currentVertexId + + " head " + + head); } - /** - * get paths, last element is tail vertex id. - * - * @param message - * @param path - * @param paths - */ - private void collectPaths(IntTreePathMessage message, List path, List paths) { - if (message == null) { - return; - } - if (message.parents == null && message.currentVertexId == null) { - return; - } - path.add(message.currentVertexId); - - if (message.parents == null || message.parents.length == 0) { - Object[] pathArray = new Object[path.size()]; - for (int i = 0; i < path.size(); i++) { - pathArray[i] = path.get(path.size() - 1 - i); - } - paths.add(pathArray); - } else { - for (IntTreePathMessage parent : message.parents) { - collectPaths(parent, new ArrayList<>(path), paths); - } - } + if (path.length == 1) { + return this; } - public long getPathSize() { - if (parents == null && currentVertexId == null) { - return 0; - } else if (parents == null) { - return 1; - } else { - long pathSize = 0L; - for (int i = 0; i < parents.length; i++) { - pathSize += parents[i].getPathSize(); - } - return pathSize; + Object[] subPath = Arrays.copyOfRange(path, 1, path.length); + + if (parents == null) { + parents = new IntTreePathMessage[] {new IntTreePathMessage(subPath[0])}; + parents[0].addPath(subPath); + } else { + boolean pathMerged = false; + for (IntTreePathMessage parent : parents) { + if (parent.currentVertexId.equals(subPath[0])) { + parent.addPath(subPath); + pathMerged = true; + break; } + } + if (!pathMerged) { + IntTreePathMessage newParent = new IntTreePathMessage(subPath[0]); + newParent.addPath(subPath); + parents = Arrays.copyOf(parents, parents.length + 1); + parents[parents.length - 1] = newParent; + } } + return this; + } - public int size() { - if (parents == null && currentVertexId == null) { - return 0; - } - int cnt = 0; - Iterator itr = getPaths(); - while (itr.hasNext()) { - itr.next(); - cnt++; - } - return cnt; + public void extendTo(Integer tailVertexId) { + if (parents == null && currentVertexId == null) { + this.currentVertexId = tailVertexId; + return; } - - public int getPathLength() { - return getPathLength(this); + IntTreePathMessage newMessage = new IntTreePathMessage(); + newMessage.parents = this.parents; + newMessage.currentVertexId = this.currentVertexId; + this.parents = new IntTreePathMessage[1]; + this.parents[0] = newMessage; + this.currentVertexId = tailVertexId; + } + + public Iterator getPaths() { + List paths = new ArrayList<>(); + collectPaths(this, new ArrayList<>(), paths); + return paths.iterator(); + } + + /** + * get paths, last element is tail vertex id. + * + * @param message + * @param path + * @param paths + */ + private void collectPaths(IntTreePathMessage message, List path, List paths) { + if (message == null) { + return; } - - private int getPathLength(IntTreePathMessage message) { - if (parents == null && currentVertexId == null) { - return 0; - } - if (parents == null) { - return 1; - } - if (message == null || message.parents == null) { - return 1; - } - int maxLength = 1; - for (IntTreePathMessage parent : message.parents) { - maxLength = Math.max(maxLength, 1 + getPathLength(parent)); - } - return maxLength; + if (message.parents == null && message.currentVertexId == null) { + return; + } + path.add(message.currentVertexId); + + if (message.parents == null || message.parents.length == 0) { + Object[] pathArray = new Object[path.size()]; + for (int i = 0; i < path.size(); i++) { + pathArray[i] = path.get(path.size() - 1 - i); + } + paths.add(pathArray); + } else { + for (IntTreePathMessage parent : message.parents) { + collectPaths(parent, new ArrayList<>(path), paths); + } } + } + + public long getPathSize() { + if (parents == null && currentVertexId == null) { + return 0; + } else if (parents == null) { + return 1; + } else { + long pathSize = 0L; + for (int i = 0; i < parents.length; i++) { + pathSize += parents[i].getPathSize(); + } + return pathSize; + } + } - @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - for (Iterator it = getPaths(); it.hasNext(); ) { - Object[] path = it.next(); - sb.append(Arrays.toString(path)); - } - return sb.toString(); + public int size() { + if (parents == null && currentVertexId == null) { + return 0; + } + int cnt = 0; + Iterator itr = getPaths(); + while (itr.hasNext()) { + itr.next(); + cnt++; } + return cnt; + } - public IntTreePathMessage filter(Integer id) { - if (this.currentVertexId.equals(id)) { - return null; - } - if (this.parents == null || this.parents.length == 0) { - return this; - } - List copyParents = new ArrayList<>(); - for (int i = 0; i < parents.length; i++) { - IntTreePathMessage copyParent = parents[i].filter(id); - if (copyParent != null) { - copyParents.add(copyParent); - } - } - if (copyParents.isEmpty()) { - return null; - } else { - IntTreePathMessage[] copyParentArray = new IntTreePathMessage[copyParents.size()]; - for (int i = 0; i < copyParents.size(); i++) { - copyParentArray[i] = copyParents.get(i); - } - return new IntTreePathMessage(copyParentArray, this.currentVertexId); - } + public int getPathLength() { + return getPathLength(this); + } + private int getPathLength(IntTreePathMessage message) { + if (parents == null && currentVertexId == null) { + return 0; } - - public Object getCurrentVertexId() { - return currentVertexId; + if (parents == null) { + return 1; } + if (message == null || message.parents == null) { + return 1; + } + int maxLength = 1; + for (IntTreePathMessage parent : message.parents) { + maxLength = Math.max(maxLength, 1 + getPathLength(parent)); + } + return maxLength; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + for (Iterator it = getPaths(); it.hasNext(); ) { + Object[] path = it.next(); + sb.append(Arrays.toString(path)); + } + return sb.toString(); + } - public Map> generatePathMap() { - Iterator outPathsIter = this.getPaths(); - Map> pathMap = new HashMap<>(); - int preSize = -1; - while (outPathsIter.hasNext()) { - Object[] path = outPathsIter.next(); - if (preSize > 0 && path.length != preSize) { - throw new RuntimeException("meet un equal size " + preSize + " " + path.length); - } - preSize = path.length; - if (!pathMap.containsKey(path[0])) { - pathMap.put(path[0], new ArrayList<>()); - } - pathMap.get(path[0]).add(path); - } - return pathMap; - + public IntTreePathMessage filter(Integer id) { + if (this.currentVertexId.equals(id)) { + return null; + } + if (this.parents == null || this.parents.length == 0) { + return this; + } + List copyParents = new ArrayList<>(); + for (int i = 0; i < parents.length; i++) { + IntTreePathMessage copyParent = parents[i].filter(id); + if (copyParent != null) { + copyParents.add(copyParent); + } + } + if (copyParents.isEmpty()) { + return null; + } else { + IntTreePathMessage[] copyParentArray = new IntTreePathMessage[copyParents.size()]; + for (int i = 0; i < copyParents.size(); i++) { + copyParentArray[i] = copyParents.get(i); + } + return new IntTreePathMessage(copyParentArray, this.currentVertexId); } + } + + public Object getCurrentVertexId() { + return currentVertexId; + } + + public Map> generatePathMap() { + Iterator outPathsIter = this.getPaths(); + Map> pathMap = new HashMap<>(); + int preSize = -1; + while (outPathsIter.hasNext()) { + Object[] path = outPathsIter.next(); + if (preSize > 0 && path.length != preSize) { + throw new RuntimeException("meet un equal size " + preSize + " " + path.length); + } + preSize = path.length; + if (!pathMap.containsKey(path[0])) { + pathMap.put(path[0], new ArrayList<>()); + } + pathMap.get(path[0]).add(path); + } + return pathMap; + } - public static class IntTreePathMessageWrapper extends IntTreePathMessage { + public static class IntTreePathMessageWrapper extends IntTreePathMessage { - public int tag; + public int tag; - public IntTreePathMessageWrapper(Object vertexId) { - super(vertexId); - } + public IntTreePathMessageWrapper(Object vertexId) { + super(vertexId); + } - public IntTreePathMessageWrapper(Object vertexId, int tag) { - super(vertexId); - this.tag = tag; - } + public IntTreePathMessageWrapper(Object vertexId, int tag) { + super(vertexId); + this.tag = tag; + } - public int getTag() { - return tag; - } + public int getTag() { + return tag; } -} \ No newline at end of file + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/JaccardSimilarity.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/JaccardSimilarity.java index e0516f85d..9252d050a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/JaccardSimilarity.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/JaccardSimilarity.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Optional; import java.util.Set; + import org.apache.geaflow.common.tuple.Tuple; import org.apache.geaflow.common.type.primitive.DoubleType; import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; @@ -42,155 +43,161 @@ @Description(name = "jaccard_similarity", description = "built-in udga for Jaccard Similarity") public class JaccardSimilarity implements AlgorithmUserFunction { - private AlgorithmRuntimeContext context; + private AlgorithmRuntimeContext context; - // tuple to store params - private Tuple vertices; + // tuple to store params + private Tuple vertices; - @Override - public void init(AlgorithmRuntimeContext context, Object[] params) { - this.context = context; + @Override + public void init(AlgorithmRuntimeContext context, Object[] params) { + this.context = context; - if (params.length != 2) { - throw new IllegalArgumentException("Only support two arguments, usage: jaccard_similarity(id_a, id_b)"); - } - this.vertices = new Tuple<>( - TypeCastUtil.cast(params[0], context.getGraphSchema().getIdType()), - TypeCastUtil.cast(params[1], context.getGraphSchema().getIdType()) - ); + if (params.length != 2) { + throw new IllegalArgumentException( + "Only support two arguments, usage: jaccard_similarity(id_a, id_b)"); } + this.vertices = + new Tuple<>( + TypeCastUtil.cast(params[0], context.getGraphSchema().getIdType()), + TypeCastUtil.cast(params[1], context.getGraphSchema().getIdType())); + } + + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + if (context.getCurrentIterationId() == 1L) { + // First iteration: vertices A and B compute their neighbor counts + if (vertices.f0.equals(vertex.getId()) || vertices.f1.equals(vertex.getId())) { + List edges = context.loadEdges(EdgeDirection.BOTH); + Object sourceId = vertex.getId(); + + // Calculate unique neighbors count (de-duplicate and exclude self-loops) + Set uniqueNeighbors = new HashSet<>(); + for (RowEdge edge : edges) { + Object targetId = edge.getTargetId(); + // Exclude self-loops: only add if targetId != sourceId + if (!sourceId.equals(targetId)) { + uniqueNeighbors.add(targetId); + } + } + + // Calculate neighbor count for this vertex + long neighborCount = uniqueNeighbors.size(); + + // Send messages to all unique neighbors + // Message format: [sourceId, neighborCount, messageType] + // messageType = 0: neighbor inquiry, messageType = 1: count from target vertex + for (Object neighbor : uniqueNeighbors) { + context.sendMessage(neighbor, ObjectRow.create(sourceId, neighborCount, 0L)); + } + + // Send neighbor count to the other target vertex (A ↔ B exchange) + // Message format: [vertexId, neighborCount, messageType] + // messageType = 1: this is a count message from target vertex B + if (vertices.f0.equals(sourceId) && !vertices.f0.equals(vertices.f1)) { + context.sendMessage(vertices.f1, ObjectRow.create(sourceId, neighborCount, 1L)); + } else if (vertices.f1.equals(sourceId) && !vertices.f0.equals(vertices.f1)) { + context.sendMessage(vertices.f0, ObjectRow.create(sourceId, neighborCount, 1L)); + } + } + } else if (context.getCurrentIterationId() == 2L) { + // Second iteration: calculate Jaccard similarity + if (vertices.f0.equals(vertex.getId()) || vertices.f1.equals(vertex.getId())) { + // Extract neighbor counts and count common neighbors + long neighborCountA = 0; + long neighborCountB = 0; + long localCommonNeighborCount = 0; + + while (messages.hasNext()) { + ObjectRow message = messages.next(); + Object senderId = message.getField(0, context.getGraphSchema().getIdType()); + long count = + (Long) + message.getField(1, org.apache.geaflow.common.type.primitive.LongType.INSTANCE); + long messageType = + (Long) + message.getField(2, org.apache.geaflow.common.type.primitive.LongType.INSTANCE); - @Override - public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { - if (context.getCurrentIterationId() == 1L) { - // First iteration: vertices A and B compute their neighbor counts - if (vertices.f0.equals(vertex.getId()) || vertices.f1.equals(vertex.getId())) { - List edges = context.loadEdges(EdgeDirection.BOTH); - Object sourceId = vertex.getId(); - - // Calculate unique neighbors count (de-duplicate and exclude self-loops) - Set uniqueNeighbors = new HashSet<>(); - for (RowEdge edge : edges) { - Object targetId = edge.getTargetId(); - // Exclude self-loops: only add if targetId != sourceId - if (!sourceId.equals(targetId)) { - uniqueNeighbors.add(targetId); - } - } - - // Calculate neighbor count for this vertex - long neighborCount = uniqueNeighbors.size(); - - // Send messages to all unique neighbors - // Message format: [sourceId, neighborCount, messageType] - // messageType = 0: neighbor inquiry, messageType = 1: count from target vertex - for (Object neighbor : uniqueNeighbors) { - context.sendMessage(neighbor, ObjectRow.create(sourceId, neighborCount, 0L)); - } - - // Send neighbor count to the other target vertex (A ↔ B exchange) - // Message format: [vertexId, neighborCount, messageType] - // messageType = 1: this is a count message from target vertex B - if (vertices.f0.equals(sourceId) && !vertices.f0.equals(vertices.f1)) { - context.sendMessage(vertices.f1, ObjectRow.create(sourceId, neighborCount, 1L)); - } else if (vertices.f1.equals(sourceId) && !vertices.f0.equals(vertices.f1)) { - context.sendMessage(vertices.f0, ObjectRow.create(sourceId, neighborCount, 1L)); - } + // messageType = 1: neighbor count from the other target vertex (A or B) + // messageType = 0: confirmation from common neighbor + if (messageType == 1L) { + // This is a count message from target vertex + if (vertices.f0.equals(senderId)) { + neighborCountA = count; + } else if (vertices.f1.equals(senderId)) { + neighborCountB = count; } - } else if (context.getCurrentIterationId() == 2L) { - // Second iteration: calculate Jaccard similarity - if (vertices.f0.equals(vertex.getId()) || vertices.f1.equals(vertex.getId())) { - // Extract neighbor counts and count common neighbors - long neighborCountA = 0; - long neighborCountB = 0; - long localCommonNeighborCount = 0; - - while (messages.hasNext()) { - ObjectRow message = messages.next(); - Object senderId = message.getField(0, context.getGraphSchema().getIdType()); - long count = (Long) message.getField(1, org.apache.geaflow.common.type.primitive.LongType.INSTANCE); - long messageType = (Long) message.getField(2, org.apache.geaflow.common.type.primitive.LongType.INSTANCE); - - // messageType = 1: neighbor count from the other target vertex (A or B) - // messageType = 0: confirmation from common neighbor - if (messageType == 1L) { - // This is a count message from target vertex - if (vertices.f0.equals(senderId)) { - neighborCountA = count; - } else if (vertices.f1.equals(senderId)) { - neighborCountB = count; - } - } else { - // This is a confirmation from a common neighbor - localCommonNeighborCount++; - } - } - - // Calculate and output the Jaccard coefficient only from vertex A - if (vertices.f0.equals(vertex.getId())) { - // If neighborCountA is 0, calculate it from edges - if (neighborCountA == 0) { - Object sourceId = vertex.getId(); - List edges = context.loadEdges(EdgeDirection.BOTH); - Set neighbors = new HashSet<>(); - for (RowEdge edge : edges) { - Object targetId = edge.getTargetId(); - if (!sourceId.equals(targetId)) { - neighbors.add(targetId); - } - } - neighborCountA = neighbors.size(); - } - - // Calculate Jaccard coefficient: |A ∩ B| / |A ∪ B| - long intersection = localCommonNeighborCount; - long union = neighborCountA + neighborCountB - intersection; - double jaccardCoefficient = union == 0 ? 0.0 : (double) intersection / union; - - // Output the result - context.take(ObjectRow.create(vertices.f0, vertices.f1, jaccardCoefficient)); - } - } else { - // For non-A, non-B vertices: check if they received messages from both A and B - boolean receivedFromA = false; - boolean receivedFromB = false; - - while (messages.hasNext()) { - ObjectRow message = messages.next(); - Object senderId = message.getField(0, context.getGraphSchema().getIdType()); - long messageType = (Long) message.getField(2, org.apache.geaflow.common.type.primitive.LongType.INSTANCE); - - // Only count messages with type 0 (neighbor inquiry) - if (messageType == 0L) { - if (vertices.f0.equals(senderId)) { - receivedFromA = true; - } - if (vertices.f1.equals(senderId)) { - receivedFromB = true; - } - } - } - - // If this vertex received messages from both A and B, it's a common neighbor - // Send confirmation to vertex A with format [vertexId, 1, 0] - if (receivedFromA && receivedFromB) { - context.sendMessage(vertices.f0, ObjectRow.create(vertex.getId(), 1L, 0L)); - } + } else { + // This is a confirmation from a common neighbor + localCommonNeighborCount++; + } + } + + // Calculate and output the Jaccard coefficient only from vertex A + if (vertices.f0.equals(vertex.getId())) { + // If neighborCountA is 0, calculate it from edges + if (neighborCountA == 0) { + Object sourceId = vertex.getId(); + List edges = context.loadEdges(EdgeDirection.BOTH); + Set neighbors = new HashSet<>(); + for (RowEdge edge : edges) { + Object targetId = edge.getTargetId(); + if (!sourceId.equals(targetId)) { + neighbors.add(targetId); + } } + neighborCountA = neighbors.size(); + } + + // Calculate Jaccard coefficient: |A ∩ B| / |A ∪ B| + long intersection = localCommonNeighborCount; + long union = neighborCountA + neighborCountB - intersection; + double jaccardCoefficient = union == 0 ? 0.0 : (double) intersection / union; + + // Output the result + context.take(ObjectRow.create(vertices.f0, vertices.f1, jaccardCoefficient)); } - } + } else { + // For non-A, non-B vertices: check if they received messages from both A and B + boolean receivedFromA = false; + boolean receivedFromB = false; - @Override - public void finish(RowVertex graphVertex, Optional updatedValues) { - // No additional finish processing needed - } + while (messages.hasNext()) { + ObjectRow message = messages.next(); + Object senderId = message.getField(0, context.getGraphSchema().getIdType()); + long messageType = + (Long) + message.getField(2, org.apache.geaflow.common.type.primitive.LongType.INSTANCE); - @Override - public StructType getOutputType(GraphSchema graphSchema) { - return new StructType( - new TableField("vertex_a", graphSchema.getIdType(), false), - new TableField("vertex_b", graphSchema.getIdType(), false), - new TableField("jaccard_coefficient", DoubleType.INSTANCE, false) - ); + // Only count messages with type 0 (neighbor inquiry) + if (messageType == 0L) { + if (vertices.f0.equals(senderId)) { + receivedFromA = true; + } + if (vertices.f1.equals(senderId)) { + receivedFromB = true; + } + } + } + + // If this vertex received messages from both A and B, it's a common neighbor + // Send confirmation to vertex A with format [vertexId, 1, 0] + if (receivedFromA && receivedFromB) { + context.sendMessage(vertices.f0, ObjectRow.create(vertex.getId(), 1L, 0L)); + } + } } -} \ No newline at end of file + } + + @Override + public void finish(RowVertex graphVertex, Optional updatedValues) { + // No additional finish processing needed + } + + @Override + public StructType getOutputType(GraphSchema graphSchema) { + return new StructType( + new TableField("vertex_a", graphSchema.getIdType(), false), + new TableField("vertex_b", graphSchema.getIdType(), false), + new TableField("jaccard_coefficient", DoubleType.INSTANCE, false)); + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/KCore.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/KCore.java index ace6cf78f..b005a53dc 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/KCore.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/KCore.java @@ -22,6 +22,7 @@ import java.util.Iterator; import java.util.List; import java.util.Optional; + import org.apache.geaflow.common.type.primitive.IntegerType; import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction; @@ -38,73 +39,70 @@ @Description(name = "kcore", description = "built-in udga for KCore") public class KCore implements AlgorithmUserFunction { - private AlgorithmRuntimeContext context; - private int k = 1; + private AlgorithmRuntimeContext context; + private int k = 1; - @Override - public void init(AlgorithmRuntimeContext context, Object[] params) { - this.context = context; - if (params.length > 1) { - throw new IllegalArgumentException( - "Only support 1 arguments, false arguments " - + "usage: func([k]])"); - } - if (params.length > 0) { - k = Integer.parseInt(String.valueOf(params[0])); - } + @Override + public void init(AlgorithmRuntimeContext context, Object[] params) { + this.context = context; + if (params.length > 1) { + throw new IllegalArgumentException( + "Only support 1 arguments, false arguments " + "usage: func([k]])"); } + if (params.length > 0) { + k = Integer.parseInt(String.valueOf(params[0])); + } + } - @Override - public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { - updatedValues.ifPresent(vertex::setValue); - boolean isFinish = false; - if (this.context.getCurrentIterationId() == 1) { - this.context.updateVertexValue(ObjectRow.create(-1)); - } else { - int currentV = (int) vertex.getValue().getField(0, IntegerType.INSTANCE); - if (currentV == 0) { - return; - } - int sum = 0; - while (messages.hasNext()) { - sum += messages.next(); - } - if (sum < k) { - isFinish = true; - sum = 0; - } - context.updateVertexValue(ObjectRow.create(sum)); - } - - if (isFinish) { - return; - } - - List outEdges = this.context.loadEdges(EdgeDirection.OUT); - for (RowEdge rowEdge : outEdges) { - context.sendMessage(rowEdge.getTargetId(), 1); - } + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + updatedValues.ifPresent(vertex::setValue); + boolean isFinish = false; + if (this.context.getCurrentIterationId() == 1) { + this.context.updateVertexValue(ObjectRow.create(-1)); + } else { + int currentV = (int) vertex.getValue().getField(0, IntegerType.INSTANCE); + if (currentV == 0) { + return; + } + int sum = 0; + while (messages.hasNext()) { + sum += messages.next(); + } + if (sum < k) { + isFinish = true; + sum = 0; + } + context.updateVertexValue(ObjectRow.create(sum)); + } - List inEdges = this.context.loadEdges(EdgeDirection.IN); - for (RowEdge rowEdge : inEdges) { - context.sendMessage(rowEdge.getTargetId(), 1); - } - context.sendMessage(vertex.getId(), 0); + if (isFinish) { + return; } - @Override - public void finish(RowVertex graphVertex, Optional updatedValues) { - updatedValues.ifPresent(graphVertex::setValue); - int component = (int) graphVertex.getValue().getField(0, IntegerType.INSTANCE); - context.take(ObjectRow.create(graphVertex.getId(), component)); + List outEdges = this.context.loadEdges(EdgeDirection.OUT); + for (RowEdge rowEdge : outEdges) { + context.sendMessage(rowEdge.getTargetId(), 1); } - @Override - public StructType getOutputType(GraphSchema graphSchema) { - return new StructType( - new TableField("id", graphSchema.getIdType(), false), - new TableField("v", IntegerType.INSTANCE, false) - ); + List inEdges = this.context.loadEdges(EdgeDirection.IN); + for (RowEdge rowEdge : inEdges) { + context.sendMessage(rowEdge.getTargetId(), 1); } + context.sendMessage(vertex.getId(), 0); + } + + @Override + public void finish(RowVertex graphVertex, Optional updatedValues) { + updatedValues.ifPresent(graphVertex::setValue); + int component = (int) graphVertex.getValue().getField(0, IntegerType.INSTANCE); + context.take(ObjectRow.create(graphVertex.getId(), component)); + } + @Override + public StructType getOutputType(GraphSchema graphSchema) { + return new StructType( + new TableField("id", graphSchema.getIdType(), false), + new TableField("v", IntegerType.INSTANCE, false)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/KHop.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/KHop.java index 3bbd1f09c..9c9f63403 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/KHop.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/KHop.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Objects; import java.util.Optional; + import org.apache.geaflow.common.type.primitive.IntegerType; import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction; @@ -41,70 +42,69 @@ @Description(name = "khop", description = "built-in udga for KHop") public class KHop implements AlgorithmUserFunction { - private static final String OUTPUT_ID = "id"; - private static final String OUTPUT_K = "k"; - private AlgorithmRuntimeContext context; - private Object srcId; - private int k = 1; + private static final String OUTPUT_ID = "id"; + private static final String OUTPUT_K = "k"; + private AlgorithmRuntimeContext context; + private Object srcId; + private int k = 1; - @Override - public void init(AlgorithmRuntimeContext context, Object[] parameters) { - this.context = context; - if (parameters.length > 2) { - throw new IllegalArgumentException( - "Only support zero or more arguments, false arguments " - + "usage: func([alpha, [convergence, [max_iteration]]])"); - } - if (parameters.length > 0) { - srcId = TypeCastUtil.cast(parameters[0], context.getGraphSchema().getIdType()); - } - if (parameters.length > 1) { - k = Integer.parseInt(String.valueOf(parameters[1])); - } + @Override + public void init(AlgorithmRuntimeContext context, Object[] parameters) { + this.context = context; + if (parameters.length > 2) { + throw new IllegalArgumentException( + "Only support zero or more arguments, false arguments " + + "usage: func([alpha, [convergence, [max_iteration]]])"); } - - @Override - public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { - updatedValues.ifPresent(vertex::setValue); - List outEdges = new ArrayList<>(context.loadEdges(EdgeDirection.OUT)); - if (context.getCurrentIterationId() == 1L) { - if (Objects.equals(srcId, vertex.getId())) { - sendMessageToNeighbors(outEdges, 1); - context.updateVertexValue(ObjectRow.create(0)); - } else { - context.updateVertexValue(ObjectRow.create(Integer.MAX_VALUE)); - } - } else if (context.getCurrentIterationId() <= k + 1) { - int currentK = (int) vertex.getValue().getField(0, IntegerType.INSTANCE); - if (messages.hasNext() && currentK == Integer.MAX_VALUE) { - Integer currK = messages.next(); - context.updateVertexValue(ObjectRow.create(currK)); - sendMessageToNeighbors(outEdges, currK + 1); - } - } + if (parameters.length > 0) { + srcId = TypeCastUtil.cast(parameters[0], context.getGraphSchema().getIdType()); } + if (parameters.length > 1) { + k = Integer.parseInt(String.valueOf(parameters[1])); + } + } - @Override - public StructType getOutputType(GraphSchema graphSchema) { - return new StructType( - new TableField(OUTPUT_ID, graphSchema.getIdType(), false), - new TableField(OUTPUT_K, IntegerType.INSTANCE, false) - ); + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + updatedValues.ifPresent(vertex::setValue); + List outEdges = new ArrayList<>(context.loadEdges(EdgeDirection.OUT)); + if (context.getCurrentIterationId() == 1L) { + if (Objects.equals(srcId, vertex.getId())) { + sendMessageToNeighbors(outEdges, 1); + context.updateVertexValue(ObjectRow.create(0)); + } else { + context.updateVertexValue(ObjectRow.create(Integer.MAX_VALUE)); + } + } else if (context.getCurrentIterationId() <= k + 1) { + int currentK = (int) vertex.getValue().getField(0, IntegerType.INSTANCE); + if (messages.hasNext() && currentK == Integer.MAX_VALUE) { + Integer currK = messages.next(); + context.updateVertexValue(ObjectRow.create(currK)); + sendMessageToNeighbors(outEdges, currK + 1); + } } + } + + @Override + public StructType getOutputType(GraphSchema graphSchema) { + return new StructType( + new TableField(OUTPUT_ID, graphSchema.getIdType(), false), + new TableField(OUTPUT_K, IntegerType.INSTANCE, false)); + } - @Override - public void finish(RowVertex vertex, Optional newValue) { - if (newValue.isPresent()) { - int currentK = (int) newValue.get().getField(0, IntegerType.INSTANCE); - if (currentK != Integer.MAX_VALUE) { - context.take(ObjectRow.create(vertex.getId(), currentK)); - } - } + @Override + public void finish(RowVertex vertex, Optional newValue) { + if (newValue.isPresent()) { + int currentK = (int) newValue.get().getField(0, IntegerType.INSTANCE); + if (currentK != Integer.MAX_VALUE) { + context.take(ObjectRow.create(vertex.getId(), currentK)); + } } + } - private void sendMessageToNeighbors(List outEdges, Integer message) { - for (RowEdge rowEdge : outEdges) { - context.sendMessage(rowEdge.getTargetId(), message); - } + private void sendMessageToNeighbors(List outEdges, Integer message) { + for (RowEdge rowEdge : outEdges) { + context.sendMessage(rowEdge.getTargetId(), message); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/LabelPropagation.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/LabelPropagation.java index 33bd5be65..1485f9667 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/LabelPropagation.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/LabelPropagation.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; + import org.apache.geaflow.common.type.primitive.StringType; import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction; @@ -41,81 +42,80 @@ @Description(name = "lpa", description = "built-in udga for Label Propagation Algorithm") public class LabelPropagation implements AlgorithmUserFunction { - private AlgorithmRuntimeContext context; - private String outputKeyName = "label"; - private int iteration = 1000; + private AlgorithmRuntimeContext context; + private String outputKeyName = "label"; + private int iteration = 1000; - @Override - public void init(AlgorithmRuntimeContext context, Object[] parameters) { - this.context = context; - if (parameters.length > 2) { - throw new IllegalArgumentException( - "Only support zero or more arguments, false arguments " - + "usage: func([iteration, [outputKeyName]])"); - } - if (parameters.length > 0) { - iteration = Integer.parseInt(String.valueOf(parameters[0])); - } - if (parameters.length > 1) { - outputKeyName = String.valueOf(parameters[1]); - } + @Override + public void init(AlgorithmRuntimeContext context, Object[] parameters) { + this.context = context; + if (parameters.length > 2) { + throw new IllegalArgumentException( + "Only support zero or more arguments, false arguments " + + "usage: func([iteration, [outputKeyName]])"); } + if (parameters.length > 0) { + iteration = Integer.parseInt(String.valueOf(parameters[0])); + } + if (parameters.length > 1) { + outputKeyName = String.valueOf(parameters[1]); + } + } - @Override - public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { - updatedValues.ifPresent(vertex::setValue); - List edges = new ArrayList<>(context.loadEdges(EdgeDirection.BOTH)); - - if (context.getCurrentIterationId() == 1L) { - String initLabel = String.valueOf(vertex.getId()); - context.updateVertexValue(ObjectRow.create(initLabel)); - sendMessageToNeighbors(edges, initLabel); - } else if (context.getCurrentIterationId() < iteration) { - Map labelCounts = new HashMap<>(); - String currentLabel = (String) vertex.getValue().getField(0, StringType.INSTANCE); + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + updatedValues.ifPresent(vertex::setValue); + List edges = new ArrayList<>(context.loadEdges(EdgeDirection.BOTH)); - while (messages.hasNext()) { - String label = messages.next(); - labelCounts.put(label, labelCounts.getOrDefault(label, 0) + 1); - } + if (context.getCurrentIterationId() == 1L) { + String initLabel = String.valueOf(vertex.getId()); + context.updateVertexValue(ObjectRow.create(initLabel)); + sendMessageToNeighbors(edges, initLabel); + } else if (context.getCurrentIterationId() < iteration) { + Map labelCounts = new HashMap<>(); + String currentLabel = (String) vertex.getValue().getField(0, StringType.INSTANCE); - String mostFrequentLabel = currentLabel; - int maxCount = 0; + while (messages.hasNext()) { + String label = messages.next(); + labelCounts.put(label, labelCounts.getOrDefault(label, 0) + 1); + } - for (Map.Entry entry : labelCounts.entrySet()) { - String label = entry.getKey(); - int count = entry.getValue(); - if (count >= maxCount && label.compareTo(mostFrequentLabel) < 0) { - mostFrequentLabel = label; - maxCount = count; - } - } + String mostFrequentLabel = currentLabel; + int maxCount = 0; - if (!mostFrequentLabel.equals(currentLabel)) { - context.updateVertexValue(ObjectRow.create(mostFrequentLabel)); - sendMessageToNeighbors(edges, mostFrequentLabel); - } + for (Map.Entry entry : labelCounts.entrySet()) { + String label = entry.getKey(); + int count = entry.getValue(); + if (count >= maxCount && label.compareTo(mostFrequentLabel) < 0) { + mostFrequentLabel = label; + maxCount = count; } - } + } - @Override - public void finish(RowVertex graphVertex, Optional updatedValues) { - updatedValues.ifPresent(graphVertex::setValue); - String label = (String) graphVertex.getValue().getField(0, StringType.INSTANCE); - context.take(ObjectRow.create(graphVertex.getId(), label)); + if (!mostFrequentLabel.equals(currentLabel)) { + context.updateVertexValue(ObjectRow.create(mostFrequentLabel)); + sendMessageToNeighbors(edges, mostFrequentLabel); + } } + } - @Override - public StructType getOutputType(GraphSchema graphSchema) { - return new StructType( - new TableField("id", graphSchema.getIdType(), false), - new TableField(outputKeyName, StringType.INSTANCE, false) - ); - } + @Override + public void finish(RowVertex graphVertex, Optional updatedValues) { + updatedValues.ifPresent(graphVertex::setValue); + String label = (String) graphVertex.getValue().getField(0, StringType.INSTANCE); + context.take(ObjectRow.create(graphVertex.getId(), label)); + } - private void sendMessageToNeighbors(List edges, String message) { - for (RowEdge rowEdge : edges) { - context.sendMessage(rowEdge.getTargetId(), message); - } + @Override + public StructType getOutputType(GraphSchema graphSchema) { + return new StructType( + new TableField("id", graphSchema.getIdType(), false), + new TableField(outputKeyName, StringType.INSTANCE, false)); + } + + private void sendMessageToNeighbors(List edges, String message) { + for (RowEdge rowEdge : edges) { + context.sendMessage(rowEdge.getTargetId(), message); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/PageRank.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/PageRank.java index 10073d7ff..efc451672 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/PageRank.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/PageRank.java @@ -23,6 +23,7 @@ import java.util.Iterator; import java.util.List; import java.util.Optional; + import org.apache.geaflow.common.type.primitive.DoubleType; import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction; @@ -39,77 +40,76 @@ @Description(name = "page_rank", description = "built-in udga for PageRank") public class PageRank implements AlgorithmUserFunction { - private AlgorithmRuntimeContext context; - private double alpha = 0.85; - private double convergence = 0.01; - private int iteration = 20; + private AlgorithmRuntimeContext context; + private double alpha = 0.85; + private double convergence = 0.01; + private int iteration = 20; - @Override - public void init(AlgorithmRuntimeContext context, Object[] parameters) { - this.context = context; - if (parameters.length > 3) { - throw new IllegalArgumentException( - "Only support zero or more arguments, false arguments " - + "usage: func([alpha, [convergence, [max_iteration]]])"); - } - if (parameters.length > 0) { - alpha = Double.parseDouble(String.valueOf(parameters[0])); - } - if (parameters.length > 1) { - convergence = Double.parseDouble(String.valueOf(parameters[1])); - } - if (parameters.length > 2) { - iteration = Integer.parseInt(String.valueOf(parameters[2])); - } + @Override + public void init(AlgorithmRuntimeContext context, Object[] parameters) { + this.context = context; + if (parameters.length > 3) { + throw new IllegalArgumentException( + "Only support zero or more arguments, false arguments " + + "usage: func([alpha, [convergence, [max_iteration]]])"); } - - @Override - public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { - updatedValues.ifPresent(vertex::setValue); - List outEdges = new ArrayList<>(context.loadEdges(EdgeDirection.OUT)); - if (context.getCurrentIterationId() == 1L) { - double initValue = 1.0; - sendMessageToNeighbors(outEdges, 1.0 / outEdges.size()); - context.sendMessage(vertex.getId(), -1.0); - context.updateVertexValue(ObjectRow.create(initValue)); - } else if (context.getCurrentIterationId() < iteration) { - double sum = 0.0; - while (messages.hasNext()) { - double input = (double) messages.next(); - input = input > 0 ? input : 0.0; - sum += input; - } - double pr = (1 - alpha) + (sum * alpha); - double currentPr = (double) vertex.getValue().getField(0, DoubleType.INSTANCE); - if (Math.abs(currentPr - pr) > convergence) { - context.updateVertexValue(ObjectRow.create(pr)); - currentPr = pr; - } - sendMessageToNeighbors(outEdges, currentPr / outEdges.size()); - context.sendMessage(vertex.getId(), -1.0); - context.updateVertexValue(ObjectRow.create(pr)); - } + if (parameters.length > 0) { + alpha = Double.parseDouble(String.valueOf(parameters[0])); + } + if (parameters.length > 1) { + convergence = Double.parseDouble(String.valueOf(parameters[1])); } + if (parameters.length > 2) { + iteration = Integer.parseInt(String.valueOf(parameters[2])); + } + } - @Override - public void finish(RowVertex vertex, Optional newValue) { - if (newValue.isPresent()) { - double currentPr = (double) newValue.get().getField(0, DoubleType.INSTANCE); - context.take(ObjectRow.create(vertex.getId(), currentPr)); - } + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + updatedValues.ifPresent(vertex::setValue); + List outEdges = new ArrayList<>(context.loadEdges(EdgeDirection.OUT)); + if (context.getCurrentIterationId() == 1L) { + double initValue = 1.0; + sendMessageToNeighbors(outEdges, 1.0 / outEdges.size()); + context.sendMessage(vertex.getId(), -1.0); + context.updateVertexValue(ObjectRow.create(initValue)); + } else if (context.getCurrentIterationId() < iteration) { + double sum = 0.0; + while (messages.hasNext()) { + double input = (double) messages.next(); + input = input > 0 ? input : 0.0; + sum += input; + } + double pr = (1 - alpha) + (sum * alpha); + double currentPr = (double) vertex.getValue().getField(0, DoubleType.INSTANCE); + if (Math.abs(currentPr - pr) > convergence) { + context.updateVertexValue(ObjectRow.create(pr)); + currentPr = pr; + } + sendMessageToNeighbors(outEdges, currentPr / outEdges.size()); + context.sendMessage(vertex.getId(), -1.0); + context.updateVertexValue(ObjectRow.create(pr)); } + } - @Override - public StructType getOutputType(GraphSchema graphSchema) { - return new StructType( - new TableField("id", graphSchema.getIdType(), false), - new TableField("pr", DoubleType.INSTANCE, false) - ); + @Override + public void finish(RowVertex vertex, Optional newValue) { + if (newValue.isPresent()) { + double currentPr = (double) newValue.get().getField(0, DoubleType.INSTANCE); + context.take(ObjectRow.create(vertex.getId(), currentPr)); } + } + + @Override + public StructType getOutputType(GraphSchema graphSchema) { + return new StructType( + new TableField("id", graphSchema.getIdType(), false), + new TableField("pr", DoubleType.INSTANCE, false)); + } - private void sendMessageToNeighbors(List outEdges, Object message) { - for (RowEdge rowEdge : outEdges) { - context.sendMessage(rowEdge.getTargetId(), message); - } + private void sendMessageToNeighbors(List outEdges, Object message) { + for (RowEdge rowEdge : outEdges) { + context.sendMessage(rowEdge.getTargetId(), message); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/SingleSourceShortestPath.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/SingleSourceShortestPath.java index 5018f9568..7b111765b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/SingleSourceShortestPath.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/SingleSourceShortestPath.java @@ -22,6 +22,7 @@ import java.util.Iterator; import java.util.Objects; import java.util.Optional; + import org.apache.geaflow.common.type.primitive.BooleanType; import org.apache.geaflow.common.type.primitive.LongType; import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; @@ -39,95 +40,93 @@ import org.apache.geaflow.model.graph.edge.EdgeDirection; @Description(name = "sssp", description = "built-in udga Single Source Shortest Path") -public class SingleSourceShortestPath implements AlgorithmUserFunction, - IncrementalAlgorithmUserFunction { +public class SingleSourceShortestPath + implements AlgorithmUserFunction, IncrementalAlgorithmUserFunction { - private AlgorithmRuntimeContext context; - private Object sourceVertexId; - private String edgeType = null; - private String vertexType = null; + private AlgorithmRuntimeContext context; + private Object sourceVertexId; + private String edgeType = null; + private String vertexType = null; - @Override - public void init(AlgorithmRuntimeContext context, Object[] parameters) { - this.context = context; - assert parameters.length >= 1 : "SSSP algorithm need source vid parameter."; - sourceVertexId = TypeCastUtil.cast(parameters[0], context.getGraphSchema().getIdType()); - assert sourceVertexId != null : "Source vid cannot be null for SSSP."; - if (parameters.length >= 2) { - assert parameters[1] instanceof String : "Edge type parameter should be string."; - edgeType = (String) parameters[1]; - } - if (parameters.length >= 3) { - assert parameters[2] instanceof String : "Vertex type parameter should be string."; - vertexType = (String) parameters[2]; - } + @Override + public void init(AlgorithmRuntimeContext context, Object[] parameters) { + this.context = context; + assert parameters.length >= 1 : "SSSP algorithm need source vid parameter."; + sourceVertexId = TypeCastUtil.cast(parameters[0], context.getGraphSchema().getIdType()); + assert sourceVertexId != null : "Source vid cannot be null for SSSP."; + if (parameters.length >= 2) { + assert parameters[1] instanceof String : "Edge type parameter should be string."; + edgeType = (String) parameters[1]; + } + if (parameters.length >= 3) { + assert parameters[2] instanceof String : "Vertex type parameter should be string."; + vertexType = (String) parameters[2]; } + } - @Override - public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { - if (vertexType != null && !vertex.getLabel().equals(vertexType)) { - return; - } - long newDistance; - if (Objects.equals(vertex.getId(), sourceVertexId)) { - newDistance = 0; - } else { - newDistance = Long.MAX_VALUE; - } - while (messages.hasNext()) { - long d = messages.next(); - if (d < newDistance) { - newDistance = d; - } - } + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + if (vertexType != null && !vertex.getLabel().equals(vertexType)) { + return; + } + long newDistance; + if (Objects.equals(vertex.getId(), sourceVertexId)) { + newDistance = 0; + } else { + newDistance = Long.MAX_VALUE; + } + while (messages.hasNext()) { + long d = messages.next(); + if (d < newDistance) { + newDistance = d; + } + } - boolean distanceUpdatedForIteration; - boolean distanceUpdatedForWindow = false; - if (updatedValues.isPresent()) { - long oldDistance = (long) updatedValues.get().getField(0, LongType.INSTANCE); - if (newDistance < oldDistance) { - distanceUpdatedForIteration = true; - } else { - newDistance = oldDistance; - distanceUpdatedForIteration = false; - } - distanceUpdatedForWindow = (Boolean) updatedValues.get().getField(1, BooleanType.INSTANCE); - } else { - distanceUpdatedForIteration = true; - } + boolean distanceUpdatedForIteration; + boolean distanceUpdatedForWindow = false; + if (updatedValues.isPresent()) { + long oldDistance = (long) updatedValues.get().getField(0, LongType.INSTANCE); + if (newDistance < oldDistance) { + distanceUpdatedForIteration = true; + } else { + newDistance = oldDistance; + distanceUpdatedForIteration = false; + } + distanceUpdatedForWindow = (Boolean) updatedValues.get().getField(1, BooleanType.INSTANCE); + } else { + distanceUpdatedForIteration = true; + } - distanceUpdatedForWindow = distanceUpdatedForWindow || distanceUpdatedForIteration; - context.updateVertexValue(ObjectRow.create(newDistance, distanceUpdatedForWindow)); - long scatterDistance = newDistance == Long.MAX_VALUE ? Long.MAX_VALUE : - newDistance + 1; - if (distanceUpdatedForIteration || context.getCurrentIterationId() <= 1L) { - for (RowEdge edge : context.loadEdges(EdgeDirection.OUT)) { - if (edgeType == null || edge.getLabel().equals(edgeType)) { - context.sendMessage(edge.getTargetId(), scatterDistance); - } - } + distanceUpdatedForWindow = distanceUpdatedForWindow || distanceUpdatedForIteration; + context.updateVertexValue(ObjectRow.create(newDistance, distanceUpdatedForWindow)); + long scatterDistance = newDistance == Long.MAX_VALUE ? Long.MAX_VALUE : newDistance + 1; + if (distanceUpdatedForIteration || context.getCurrentIterationId() <= 1L) { + for (RowEdge edge : context.loadEdges(EdgeDirection.OUT)) { + if (edgeType == null || edge.getLabel().equals(edgeType)) { + context.sendMessage(edge.getTargetId(), scatterDistance); } + } } + } - @Override - public void finish(RowVertex vertex, Optional newValue) { - if (newValue.isPresent()) { - Boolean distanceUpdated = (Boolean) newValue.get().getField(1, BooleanType.INSTANCE); - if (distanceUpdated) { - long currentDistance = (long) newValue.get().getField(0, LongType.INSTANCE); - if (currentDistance < Long.MAX_VALUE) { - context.take(ObjectRow.create(vertex.getId(), currentDistance)); - } - context.updateVertexValue(ObjectRow.create(currentDistance, false)); - } + @Override + public void finish(RowVertex vertex, Optional newValue) { + if (newValue.isPresent()) { + Boolean distanceUpdated = (Boolean) newValue.get().getField(1, BooleanType.INSTANCE); + if (distanceUpdated) { + long currentDistance = (long) newValue.get().getField(0, LongType.INSTANCE); + if (currentDistance < Long.MAX_VALUE) { + context.take(ObjectRow.create(vertex.getId(), currentDistance)); } + context.updateVertexValue(ObjectRow.create(currentDistance, false)); + } } + } - @Override - public StructType getOutputType(GraphSchema graphSchema) { - return new StructType( - new TableField("id", graphSchema.getIdType(), false), - new TableField("distance", LongType.INSTANCE, false) - ); - } + @Override + public StructType getOutputType(GraphSchema graphSchema) { + return new StructType( + new TableField("id", graphSchema.getIdType(), false), + new TableField("distance", LongType.INSTANCE, false)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/TriangleCount.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/TriangleCount.java index 85acaf573..6db7ae699 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/TriangleCount.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/TriangleCount.java @@ -19,13 +19,12 @@ package org.apache.geaflow.dsl.udf.graph; -import com.google.common.collect.Lists; -import com.google.common.collect.Sets; import java.util.Iterator; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.Set; + import org.apache.geaflow.common.type.primitive.LongType; import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction; @@ -39,90 +38,90 @@ import org.apache.geaflow.dsl.common.types.TableField; import org.apache.geaflow.model.graph.edge.EdgeDirection; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + @Description(name = "triangle_count", description = "built-in udga for Triangle Count.") public class TriangleCount implements AlgorithmUserFunction { - private AlgorithmRuntimeContext context; + private AlgorithmRuntimeContext context; - private final int maxIteration = 2; + private final int maxIteration = 2; - private String vertexType = null; + private String vertexType = null; - private final Set excludeSet = Sets.newHashSet(); + private final Set excludeSet = Sets.newHashSet(); - @Override - public void init(AlgorithmRuntimeContext context, Object[] params) { - this.context = context; - if (params.length >= 1) { - assert params[0] instanceof String : "Vertex type parameter should be string."; - vertexType = (String) params[0]; - } - assert params.length <= 1 : "Maximum parameter limit exceeded."; + @Override + public void init(AlgorithmRuntimeContext context, Object[] params) { + this.context = context; + if (params.length >= 1) { + assert params[0] instanceof String : "Vertex type parameter should be string."; + vertexType = (String) params[0]; } + assert params.length <= 1 : "Maximum parameter limit exceeded."; + } - @Override - public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { - updatedValues.ifPresent(vertex::setValue); - - if (context.getCurrentIterationId() == 1L) { - if (Objects.nonNull(vertexType) && !vertexType.equals(vertex.getLabel())) { - excludeSet.add((Long) vertex.getId()); - return; - } - - List rowEdges = context.loadEdges(EdgeDirection.BOTH); - List neighborInfo = Lists.newArrayList(); - neighborInfo.add((long) rowEdges.size()); - for (RowEdge rowEdge : rowEdges) { - neighborInfo.add(rowEdge.getTargetId()); - } - ObjectRow msg = ObjectRow.create(neighborInfo.toArray()); - for (int i = 1; i < neighborInfo.size(); i++) { - context.sendMessage(neighborInfo.get(i), msg); - } - context.sendMessage(vertex.getId(), ObjectRow.create(0L)); - context.updateVertexValue(msg); - } else if (context.getCurrentIterationId() <= maxIteration) { - if (Objects.nonNull(vertexType) && !vertexType.equals(vertex.getLabel())) { - return; - } - long count = 0; - Set sourceSet = row2Set(vertex.getValue()); - while (messages.hasNext()) { - ObjectRow msg = messages.next(); - Set targetSet = row2Set(msg); - targetSet.retainAll(sourceSet); - count += targetSet.size(); - } - context.take(ObjectRow.create(vertex.getId(), count / 2)); - } - } + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + updatedValues.ifPresent(vertex::setValue); - @Override - public void finish(RowVertex graphVertex, Optional updatedValues) { + if (context.getCurrentIterationId() == 1L) { + if (Objects.nonNull(vertexType) && !vertexType.equals(vertex.getLabel())) { + excludeSet.add((Long) vertex.getId()); + return; + } + List rowEdges = context.loadEdges(EdgeDirection.BOTH); + List neighborInfo = Lists.newArrayList(); + neighborInfo.add((long) rowEdges.size()); + for (RowEdge rowEdge : rowEdges) { + neighborInfo.add(rowEdge.getTargetId()); + } + ObjectRow msg = ObjectRow.create(neighborInfo.toArray()); + for (int i = 1; i < neighborInfo.size(); i++) { + context.sendMessage(neighborInfo.get(i), msg); + } + context.sendMessage(vertex.getId(), ObjectRow.create(0L)); + context.updateVertexValue(msg); + } else if (context.getCurrentIterationId() <= maxIteration) { + if (Objects.nonNull(vertexType) && !vertexType.equals(vertex.getLabel())) { + return; + } + long count = 0; + Set sourceSet = row2Set(vertex.getValue()); + while (messages.hasNext()) { + ObjectRow msg = messages.next(); + Set targetSet = row2Set(msg); + targetSet.retainAll(sourceSet); + count += targetSet.size(); + } + context.take(ObjectRow.create(vertex.getId(), count / 2)); } + } - @Override - public StructType getOutputType(GraphSchema graphSchema) { - return new StructType( - new TableField("id", graphSchema.getIdType(), false), - new TableField("count", LongType.INSTANCE, false) - ); - } + @Override + public void finish(RowVertex graphVertex, Optional updatedValues) {} - private Set row2Set(Row row) { - long len = (long) row.getField(0, LongType.INSTANCE); - Object[] ids = new Object[(int) len]; - for (int i = 0; i < len; i++) { - ids[i] = row.getField(i + 1, LongType.INSTANCE); - } - Set set = Sets.newHashSet(); - for (Object id : ids) { - if (!excludeSet.contains((long) id)) { - set.add((long) id); - } - } - return set; + @Override + public StructType getOutputType(GraphSchema graphSchema) { + return new StructType( + new TableField("id", graphSchema.getIdType(), false), + new TableField("count", LongType.INSTANCE, false)); + } + + private Set row2Set(Row row) { + long len = (long) row.getField(0, LongType.INSTANCE); + Object[] ids = new Object[(int) len]; + for (int i = 0; i < len; i++) { + ids[i] = row.getField(i + 1, LongType.INSTANCE); + } + Set set = Sets.newHashSet(); + for (Object id : ids) { + if (!excludeSet.contains((long) id)) { + set.add((long) id); + } } + return set; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/WeakConnectedComponents.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/WeakConnectedComponents.java index 1bd245aca..0cc5790da 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/WeakConnectedComponents.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/WeakConnectedComponents.java @@ -23,6 +23,7 @@ import java.util.Iterator; import java.util.List; import java.util.Optional; + import org.apache.geaflow.common.type.primitive.StringType; import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction; @@ -39,67 +40,66 @@ @Description(name = "wcc", description = "built-in udga for WeakConnectedComponents") public class WeakConnectedComponents implements AlgorithmUserFunction { - private AlgorithmRuntimeContext context; - private String keyFieldName = "component"; - private int iteration = 20; + private AlgorithmRuntimeContext context; + private String keyFieldName = "component"; + private int iteration = 20; - @Override - public void init(AlgorithmRuntimeContext context, Object[] parameters) { - this.context = context; - if (parameters.length > 2) { - throw new IllegalArgumentException( - "Only support zero or more arguments, false arguments " - + "usage: func([iteration, [keyFieldName]])"); - } - if (parameters.length > 0) { - iteration = Integer.parseInt(String.valueOf(parameters[0])); - } - if (parameters.length > 1) { - keyFieldName = String.valueOf(parameters[1]); - } + @Override + public void init(AlgorithmRuntimeContext context, Object[] parameters) { + this.context = context; + if (parameters.length > 2) { + throw new IllegalArgumentException( + "Only support zero or more arguments, false arguments " + + "usage: func([iteration, [keyFieldName]])"); + } + if (parameters.length > 0) { + iteration = Integer.parseInt(String.valueOf(parameters[0])); } + if (parameters.length > 1) { + keyFieldName = String.valueOf(parameters[1]); + } + } - @Override - public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { - updatedValues.ifPresent(vertex::setValue); - List edges = new ArrayList<>(context.loadEdges(EdgeDirection.BOTH)); - if (context.getCurrentIterationId() == 1L) { - String initValue = String.valueOf(vertex.getId()); - sendMessageToNeighbors(edges, initValue); - context.sendMessage(vertex.getId(), String.valueOf(vertex.getId())); - context.updateVertexValue(ObjectRow.create(initValue)); - } else if (context.getCurrentIterationId() < iteration) { - String minComponent = messages.next(); - while (messages.hasNext()) { - String next = messages.next(); - if (next.compareTo(minComponent) < 0) { - minComponent = next; - } - } - sendMessageToNeighbors(edges, minComponent); - context.sendMessage(vertex.getId(), minComponent); - context.updateVertexValue(ObjectRow.create(minComponent)); + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + updatedValues.ifPresent(vertex::setValue); + List edges = new ArrayList<>(context.loadEdges(EdgeDirection.BOTH)); + if (context.getCurrentIterationId() == 1L) { + String initValue = String.valueOf(vertex.getId()); + sendMessageToNeighbors(edges, initValue); + context.sendMessage(vertex.getId(), String.valueOf(vertex.getId())); + context.updateVertexValue(ObjectRow.create(initValue)); + } else if (context.getCurrentIterationId() < iteration) { + String minComponent = messages.next(); + while (messages.hasNext()) { + String next = messages.next(); + if (next.compareTo(minComponent) < 0) { + minComponent = next; } + } + sendMessageToNeighbors(edges, minComponent); + context.sendMessage(vertex.getId(), minComponent); + context.updateVertexValue(ObjectRow.create(minComponent)); } + } - @Override - public void finish(RowVertex graphVertex, Optional updatedValues) { - updatedValues.ifPresent(graphVertex::setValue); - String component = (String) graphVertex.getValue().getField(0, StringType.INSTANCE); - context.take(ObjectRow.create(graphVertex.getId(), component)); - } + @Override + public void finish(RowVertex graphVertex, Optional updatedValues) { + updatedValues.ifPresent(graphVertex::setValue); + String component = (String) graphVertex.getValue().getField(0, StringType.INSTANCE); + context.take(ObjectRow.create(graphVertex.getId(), component)); + } - @Override - public StructType getOutputType(GraphSchema graphSchema) { - return new StructType( - new TableField("id", graphSchema.getIdType(), false), - new TableField(keyFieldName, StringType.INSTANCE, false) - ); - } + @Override + public StructType getOutputType(GraphSchema graphSchema) { + return new StructType( + new TableField("id", graphSchema.getIdType(), false), + new TableField(keyFieldName, StringType.INSTANCE, false)); + } - private void sendMessageToNeighbors(List edges, String message) { - for (RowEdge rowEdge : edges) { - context.sendMessage(rowEdge.getTargetId(), message); - } + private void sendMessageToNeighbors(List edges, String message) { + for (RowEdge rowEdge : edges) { + context.sendMessage(rowEdge.getTargetId(), message); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/mst/MSTEdge.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/mst/MSTEdge.java index 0cc489585..83466f6e8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/mst/MSTEdge.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/mst/MSTEdge.java @@ -23,174 +23,179 @@ import java.util.Objects; /** - * MST edge class. - * Represents an edge in the minimum spanning tree, containing source vertex, target vertex and weight information. - * - *

Supported operations: - * - Create edge - * - Get edge endpoints - * - Check if it is a self-loop - * - Create reverse edge - * - Compare edges (by weight and endpoints) - * + * MST edge class. Represents an edge in the minimum spanning tree, containing source vertex, target + * vertex and weight information. + * + *

Supported operations: - Create edge - Get edge endpoints - Check if it is a self-loop - Create + * reverse edge - Compare edges (by weight and endpoints) + * * @author Geaflow Team */ public class MSTEdge implements Serializable, Comparable { - - private static final long serialVersionUID = 1L; - - /** Source vertex ID. */ - private Object sourceId; - - /** Target vertex ID. */ - private Object targetId; - - /** Edge weight. */ - private double weight; - - /** - * Constructor. - * @param sourceId Source vertex ID - * @param targetId Target vertex ID - * @param weight Edge weight - */ - public MSTEdge(Object sourceId, Object targetId, double weight) { - this.sourceId = sourceId; - this.targetId = targetId; - this.weight = weight; - } - - // Getters and Setters - - public Object getSourceId() { - return sourceId; - } - - public void setSourceId(Object sourceId) { - this.sourceId = sourceId; - } - - public Object getTargetId() { - return targetId; - } - - public void setTargetId(Object targetId) { - this.targetId = targetId; - } - - public double getWeight() { - return weight; - } - - public void setWeight(double weight) { - this.weight = weight; - } - - /** - * Get the other endpoint of the edge. - * @param vertexId Known vertex ID - * @return Other endpoint ID, returns null if vertexId is not an endpoint of the edge - */ - public Object getOtherEndpoint(Object vertexId) { - if (sourceId.equals(vertexId)) { - return targetId; - } else if (targetId.equals(vertexId)) { - return sourceId; - } - return null; - } - - /** - * Check if specified vertex is an endpoint of the edge. - * @param vertexId Vertex ID - * @return Whether it is an endpoint - */ - public boolean isEndpoint(Object vertexId) { - return sourceId.equals(vertexId) || targetId.equals(vertexId); - } - - /** - * Check if it is a self-loop edge. - * @return Whether it is a self-loop - */ - public boolean isSelfLoop() { - return sourceId.equals(targetId); - } - - /** - * Create reverse edge. - * @return Reverse edge - */ - public MSTEdge reverse() { - return new MSTEdge(targetId, sourceId, weight); - } - - /** - * Check if two edges are equal (ignoring direction). - * @param other Another edge - * @return Whether they are equal - */ - public boolean equalsIgnoreDirection(MSTEdge other) { - if (this == other) { - return true; - } - if (other == null || getClass() != other.getClass()) { - return false; - } - - // Quick weight comparison first (fastest check) - if (Double.compare(other.weight, weight) != 0) { - return false; - } - - // Short-circuit direction comparison using OR operator - return (Objects.equals(sourceId, other.sourceId) && Objects.equals(targetId, other.targetId)) - || (Objects.equals(sourceId, other.targetId) && Objects.equals(targetId, other.sourceId)); - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null || getClass() != obj.getClass()) { - return false; - } - MSTEdge edge = (MSTEdge) obj; - return Double.compare(edge.weight, weight) == 0 - && Objects.equals(sourceId, edge.sourceId) - && Objects.equals(targetId, edge.targetId); - } - - @Override - public int hashCode() { - return Objects.hash(sourceId, targetId, weight); - } - - @Override - public int compareTo(MSTEdge other) { - // First compare by weight - int weightCompare = Double.compare(this.weight, other.weight); - if (weightCompare != 0) { - return weightCompare; - } - - // If weights are equal, compare by source vertex ID - int sourceCompare = sourceId.toString().compareTo(other.sourceId.toString()); - if (sourceCompare != 0) { - return sourceCompare; - } - - // If source vertex IDs are equal, compare by target vertex ID - return targetId.toString().compareTo(other.targetId.toString()); - } - @Override - public String toString() { - return "MSTEdge{" - + "sourceId=" + sourceId - + ", targetId=" + targetId - + ", weight=" + weight - + '}'; - } -} \ No newline at end of file + private static final long serialVersionUID = 1L; + + /** Source vertex ID. */ + private Object sourceId; + + /** Target vertex ID. */ + private Object targetId; + + /** Edge weight. */ + private double weight; + + /** + * Constructor. + * + * @param sourceId Source vertex ID + * @param targetId Target vertex ID + * @param weight Edge weight + */ + public MSTEdge(Object sourceId, Object targetId, double weight) { + this.sourceId = sourceId; + this.targetId = targetId; + this.weight = weight; + } + + // Getters and Setters + + public Object getSourceId() { + return sourceId; + } + + public void setSourceId(Object sourceId) { + this.sourceId = sourceId; + } + + public Object getTargetId() { + return targetId; + } + + public void setTargetId(Object targetId) { + this.targetId = targetId; + } + + public double getWeight() { + return weight; + } + + public void setWeight(double weight) { + this.weight = weight; + } + + /** + * Get the other endpoint of the edge. + * + * @param vertexId Known vertex ID + * @return Other endpoint ID, returns null if vertexId is not an endpoint of the edge + */ + public Object getOtherEndpoint(Object vertexId) { + if (sourceId.equals(vertexId)) { + return targetId; + } else if (targetId.equals(vertexId)) { + return sourceId; + } + return null; + } + + /** + * Check if specified vertex is an endpoint of the edge. + * + * @param vertexId Vertex ID + * @return Whether it is an endpoint + */ + public boolean isEndpoint(Object vertexId) { + return sourceId.equals(vertexId) || targetId.equals(vertexId); + } + + /** + * Check if it is a self-loop edge. + * + * @return Whether it is a self-loop + */ + public boolean isSelfLoop() { + return sourceId.equals(targetId); + } + + /** + * Create reverse edge. + * + * @return Reverse edge + */ + public MSTEdge reverse() { + return new MSTEdge(targetId, sourceId, weight); + } + + /** + * Check if two edges are equal (ignoring direction). + * + * @param other Another edge + * @return Whether they are equal + */ + public boolean equalsIgnoreDirection(MSTEdge other) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + + // Quick weight comparison first (fastest check) + if (Double.compare(other.weight, weight) != 0) { + return false; + } + + // Short-circuit direction comparison using OR operator + return (Objects.equals(sourceId, other.sourceId) && Objects.equals(targetId, other.targetId)) + || (Objects.equals(sourceId, other.targetId) && Objects.equals(targetId, other.sourceId)); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + MSTEdge edge = (MSTEdge) obj; + return Double.compare(edge.weight, weight) == 0 + && Objects.equals(sourceId, edge.sourceId) + && Objects.equals(targetId, edge.targetId); + } + + @Override + public int hashCode() { + return Objects.hash(sourceId, targetId, weight); + } + + @Override + public int compareTo(MSTEdge other) { + // First compare by weight + int weightCompare = Double.compare(this.weight, other.weight); + if (weightCompare != 0) { + return weightCompare; + } + + // If weights are equal, compare by source vertex ID + int sourceCompare = sourceId.toString().compareTo(other.sourceId.toString()); + if (sourceCompare != 0) { + return sourceCompare; + } + + // If source vertex IDs are equal, compare by target vertex ID + return targetId.toString().compareTo(other.targetId.toString()); + } + + @Override + public String toString() { + return "MSTEdge{" + + "sourceId=" + + sourceId + + ", targetId=" + + targetId + + ", weight=" + + weight + + '}'; + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/mst/MSTMessage.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/mst/MSTMessage.java index 5162e288b..5511ec3ec 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/mst/MSTMessage.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/mst/MSTMessage.java @@ -23,248 +23,263 @@ import java.util.Objects; /** - * MST message class. - * Used for message passing between vertices, supporting different types of MST operations. - * - *

Message types: - * - COMPONENT_UPDATE: Component update message - * - EDGE_PROPOSAL: Edge proposal message - * - EDGE_ACCEPTANCE: Edge acceptance message - * - EDGE_REJECTION: Edge rejection message - * - MST_EDGE_FOUND: MST edge discovery message - * + * MST message class. Used for message passing between vertices, supporting different types of MST + * operations. + * + *

Message types: - COMPONENT_UPDATE: Component update message - EDGE_PROPOSAL: Edge proposal + * message - EDGE_ACCEPTANCE: Edge acceptance message - EDGE_REJECTION: Edge rejection message - + * MST_EDGE_FOUND: MST edge discovery message + * * @author Geaflow Team */ public class MSTMessage implements Serializable { - - private static final long serialVersionUID = 1L; - - /** Message type enumeration. */ - public enum MessageType { - /** Component update message. */ - COMPONENT_UPDATE, - /** Edge proposal message. */ - EDGE_PROPOSAL, - /** Edge acceptance message. */ - EDGE_ACCEPTANCE, - /** Edge rejection message. */ - EDGE_REJECTION, - /** MST edge discovery message. */ - MST_EDGE_FOUND - } - /** Message type. */ - private MessageType type; - - /** Source vertex ID. */ - private Object sourceId; - - /** Target vertex ID. */ - private Object targetId; - - /** Edge weight. */ - private double weight; - - /** Component ID. */ - private Object componentId; - - /** MST edge. */ - private MSTEdge edge; - - /** Message timestamp. */ - private long timestamp; - - /** - * Constructor. - * @param type Message type - * @param sourceId Source vertex ID - * @param targetId Target vertex ID - * @param weight Edge weight - */ - public MSTMessage(MessageType type, Object sourceId, Object targetId, double weight) { - this.type = type; - this.sourceId = sourceId; - this.targetId = targetId; - this.weight = weight; - this.timestamp = System.currentTimeMillis(); - } + private static final long serialVersionUID = 1L; - /** - * Constructor with component ID. - * @param type Message type - * @param sourceId Source vertex ID - * @param targetId Target vertex ID - * @param weight Edge weight - * @param componentId Component ID - */ - public MSTMessage(MessageType type, Object sourceId, Object targetId, double weight, Object componentId) { - this(type, sourceId, targetId, weight); - this.componentId = componentId; - } + /** Message type enumeration. */ + public enum MessageType { + /** Component update message. */ + COMPONENT_UPDATE, + /** Edge proposal message. */ + EDGE_PROPOSAL, + /** Edge acceptance message. */ + EDGE_ACCEPTANCE, + /** Edge rejection message. */ + EDGE_REJECTION, + /** MST edge discovery message. */ + MST_EDGE_FOUND + } - // Getters and Setters - - public MessageType getType() { - return type; - } + /** Message type. */ + private MessageType type; - public void setType(MessageType type) { - this.type = type; - } + /** Source vertex ID. */ + private Object sourceId; - public Object getSourceId() { - return sourceId; - } + /** Target vertex ID. */ + private Object targetId; - public void setSourceId(Object sourceId) { - this.sourceId = sourceId; - } + /** Edge weight. */ + private double weight; - public Object getTargetId() { - return targetId; - } + /** Component ID. */ + private Object componentId; - public void setTargetId(Object targetId) { - this.targetId = targetId; - } + /** MST edge. */ + private MSTEdge edge; - public double getWeight() { - return weight; - } + /** Message timestamp. */ + private long timestamp; - public void setWeight(double weight) { - this.weight = weight; - } + /** + * Constructor. + * + * @param type Message type + * @param sourceId Source vertex ID + * @param targetId Target vertex ID + * @param weight Edge weight + */ + public MSTMessage(MessageType type, Object sourceId, Object targetId, double weight) { + this.type = type; + this.sourceId = sourceId; + this.targetId = targetId; + this.weight = weight; + this.timestamp = System.currentTimeMillis(); + } - public Object getComponentId() { - return componentId; - } + /** + * Constructor with component ID. + * + * @param type Message type + * @param sourceId Source vertex ID + * @param targetId Target vertex ID + * @param weight Edge weight + * @param componentId Component ID + */ + public MSTMessage( + MessageType type, Object sourceId, Object targetId, double weight, Object componentId) { + this(type, sourceId, targetId, weight); + this.componentId = componentId; + } - public void setComponentId(Object componentId) { - this.componentId = componentId; - } + // Getters and Setters - public MSTEdge getEdge() { - return edge; - } + public MessageType getType() { + return type; + } - public void setEdge(MSTEdge edge) { - this.edge = edge; - } + public void setType(MessageType type) { + this.type = type; + } - public long getTimestamp() { - return timestamp; - } + public Object getSourceId() { + return sourceId; + } - public void setTimestamp(long timestamp) { - this.timestamp = timestamp; - } + public void setSourceId(Object sourceId) { + this.sourceId = sourceId; + } - /** - * Check if this is a component update message. - * @return Whether this is a component update message - */ - public boolean isComponentUpdate() { - return type == MessageType.COMPONENT_UPDATE; - } + public Object getTargetId() { + return targetId; + } - /** - * Check if this is an edge proposal message. - * @return Whether this is an edge proposal message - */ - public boolean isEdgeProposal() { - return type == MessageType.EDGE_PROPOSAL; - } + public void setTargetId(Object targetId) { + this.targetId = targetId; + } - /** - * Check if this is an edge acceptance message. - * @return Whether this is an edge acceptance message - */ - public boolean isEdgeAcceptance() { - return type == MessageType.EDGE_ACCEPTANCE; - } + public double getWeight() { + return weight; + } - /** - * Check if this is an edge rejection message. - * @return Whether this is an edge rejection message - */ - public boolean isEdgeRejection() { - return type == MessageType.EDGE_REJECTION; - } + public void setWeight(double weight) { + this.weight = weight; + } - /** - * Check if this is an MST edge discovery message. - * @return Whether this is an MST edge discovery message - */ - public boolean isMSTEdgeFound() { - return type == MessageType.MST_EDGE_FOUND; - } + public Object getComponentId() { + return componentId; + } - /** - * Check if the message is expired. - * @param currentTime Current time - * @param timeout Timeout duration (milliseconds) - * @return Whether the message is expired - */ - public boolean isExpired(long currentTime, long timeout) { - return (currentTime - timestamp) > timeout; - } + public void setComponentId(Object componentId) { + this.componentId = componentId; + } - /** - * Create a copy of the message. - * @return Message copy - */ - public MSTMessage copy() { - MSTMessage copy = new MSTMessage(type, sourceId, targetId, weight, componentId); - copy.setEdge(edge); - copy.setTimestamp(timestamp); - return copy; - } + public MSTEdge getEdge() { + return edge; + } - /** - * Create a reverse message. - * @return Reverse message - */ - public MSTMessage reverse() { - MSTMessage reverse = new MSTMessage(type, targetId, sourceId, weight, componentId); - reverse.setEdge(edge); - reverse.setTimestamp(timestamp); - return reverse; - } + public void setEdge(MSTEdge edge) { + this.edge = edge; + } - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null || getClass() != obj.getClass()) { - return false; - } - MSTMessage message = (MSTMessage) obj; - return Double.compare(message.weight, weight) == 0 - && timestamp == message.timestamp - && type == message.type - && Objects.equals(sourceId, message.sourceId) - && Objects.equals(targetId, message.targetId) - && Objects.equals(componentId, message.componentId) - && Objects.equals(edge, message.edge); - } + public long getTimestamp() { + return timestamp; + } - @Override - public int hashCode() { - return Objects.hash(type, sourceId, targetId, weight, componentId, edge, timestamp); - } + public void setTimestamp(long timestamp) { + this.timestamp = timestamp; + } + + /** + * Check if this is a component update message. + * + * @return Whether this is a component update message + */ + public boolean isComponentUpdate() { + return type == MessageType.COMPONENT_UPDATE; + } - @Override - public String toString() { - return "MSTMessage{" - + "type=" + type - + ", sourceId=" + sourceId - + ", targetId=" + targetId - + ", weight=" + weight - + ", componentId=" + componentId - + ", edge=" + edge - + ", timestamp=" + timestamp - + '}'; + /** + * Check if this is an edge proposal message. + * + * @return Whether this is an edge proposal message + */ + public boolean isEdgeProposal() { + return type == MessageType.EDGE_PROPOSAL; + } + + /** + * Check if this is an edge acceptance message. + * + * @return Whether this is an edge acceptance message + */ + public boolean isEdgeAcceptance() { + return type == MessageType.EDGE_ACCEPTANCE; + } + + /** + * Check if this is an edge rejection message. + * + * @return Whether this is an edge rejection message + */ + public boolean isEdgeRejection() { + return type == MessageType.EDGE_REJECTION; + } + + /** + * Check if this is an MST edge discovery message. + * + * @return Whether this is an MST edge discovery message + */ + public boolean isMSTEdgeFound() { + return type == MessageType.MST_EDGE_FOUND; + } + + /** + * Check if the message is expired. + * + * @param currentTime Current time + * @param timeout Timeout duration (milliseconds) + * @return Whether the message is expired + */ + public boolean isExpired(long currentTime, long timeout) { + return (currentTime - timestamp) > timeout; + } + + /** + * Create a copy of the message. + * + * @return Message copy + */ + public MSTMessage copy() { + MSTMessage copy = new MSTMessage(type, sourceId, targetId, weight, componentId); + copy.setEdge(edge); + copy.setTimestamp(timestamp); + return copy; + } + + /** + * Create a reverse message. + * + * @return Reverse message + */ + public MSTMessage reverse() { + MSTMessage reverse = new MSTMessage(type, targetId, sourceId, weight, componentId); + reverse.setEdge(edge); + reverse.setTimestamp(timestamp); + return reverse; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; } -} \ No newline at end of file + MSTMessage message = (MSTMessage) obj; + return Double.compare(message.weight, weight) == 0 + && timestamp == message.timestamp + && type == message.type + && Objects.equals(sourceId, message.sourceId) + && Objects.equals(targetId, message.targetId) + && Objects.equals(componentId, message.componentId) + && Objects.equals(edge, message.edge); + } + + @Override + public int hashCode() { + return Objects.hash(type, sourceId, targetId, weight, componentId, edge, timestamp); + } + + @Override + public String toString() { + return "MSTMessage{" + + "type=" + + type + + ", sourceId=" + + sourceId + + ", targetId=" + + targetId + + ", weight=" + + weight + + ", componentId=" + + componentId + + ", edge=" + + edge + + ", timestamp=" + + timestamp + + '}'; + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/mst/MSTVertexState.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/mst/MSTVertexState.java index c7052c95f..8c3f28e70 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/mst/MSTVertexState.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/mst/MSTVertexState.java @@ -25,291 +25,298 @@ import java.util.Set; /** - * MST vertex state class. - * Maintains state information for each vertex in the minimum spanning tree. - * - *

Contains information: - * - parentId: Parent node ID in MST - * - componentId: Component ID it belongs to - * - minEdgeWeight: Edge weight to parent node - * - isRoot: Whether it is a root node - * - mstEdges: MST edge set - * - changed: Whether the state has changed - * + * MST vertex state class. Maintains state information for each vertex in the minimum spanning tree. + * + *

Contains information: - parentId: Parent node ID in MST - componentId: Component ID it belongs + * to - minEdgeWeight: Edge weight to parent node - isRoot: Whether it is a root node - mstEdges: + * MST edge set - changed: Whether the state has changed + * * @author Geaflow Team */ public class MSTVertexState implements Serializable { - - private static final long serialVersionUID = 1L; - - /** Parent node ID in MST. */ - private Object parentId; - - /** Component ID it belongs to. */ - private Object componentId; - - /** Edge weight to parent node. */ - private double minEdgeWeight; - - /** Whether it is a root node. */ - private boolean isRoot; - - /** MST edge set with size limit to prevent memory overflow. */ - private Set mstEdges; - - /** Maximum number of MST edges to store per vertex (memory optimization). */ - private static final int MAX_MST_EDGES_PER_VERTEX = 100; // Reduced from 1000 to prevent memory overflow - - /** Whether the state has changed. */ - private boolean changed; - - /** Vertex ID. */ - private Object vertexId; - - /** - * Constructor. - * @param vertexId Vertex ID - */ - public MSTVertexState(Object vertexId) { - this.vertexId = vertexId; - this.parentId = vertexId; // Initially self as parent node - this.componentId = vertexId; // Initially self as independent component - this.minEdgeWeight = Double.MAX_VALUE; // Initial weight as infinity - this.isRoot = true; // Initially as root node - this.mstEdges = new HashSet<>(); // Initial MST edge set is empty - this.changed = false; // Initial state unchanged - } - // Getters and Setters - - public Object getParentId() { - return parentId; - } + private static final long serialVersionUID = 1L; - public void setParentId(Object parentId) { - this.parentId = parentId; - this.changed = true; - } + /** Parent node ID in MST. */ + private Object parentId; - public Object getComponentId() { - return componentId; - } + /** Component ID it belongs to. */ + private Object componentId; - public void setComponentId(Object componentId) { - this.componentId = componentId; - this.changed = true; - } + /** Edge weight to parent node. */ + private double minEdgeWeight; - public double getMinEdgeWeight() { - return minEdgeWeight; - } + /** Whether it is a root node. */ + private boolean isRoot; - public void setMinEdgeWeight(double minEdgeWeight) { - this.minEdgeWeight = minEdgeWeight; - this.changed = true; - } + /** MST edge set with size limit to prevent memory overflow. */ + private Set mstEdges; - public boolean isRoot() { - return isRoot; - } + /** Maximum number of MST edges to store per vertex (memory optimization). */ + private static final int MAX_MST_EDGES_PER_VERTEX = + 100; // Reduced from 1000 to prevent memory overflow - public void setRoot(boolean root) { - this.isRoot = root; - this.changed = true; - } + /** Whether the state has changed. */ + private boolean changed; - public Set getMstEdges() { - return mstEdges; - } + /** Vertex ID. */ + private Object vertexId; - public void setMstEdges(Set mstEdges) { - this.mstEdges = mstEdges; - this.changed = true; - } + /** + * Constructor. + * + * @param vertexId Vertex ID + */ + public MSTVertexState(Object vertexId) { + this.vertexId = vertexId; + this.parentId = vertexId; // Initially self as parent node + this.componentId = vertexId; // Initially self as independent component + this.minEdgeWeight = Double.MAX_VALUE; // Initial weight as infinity + this.isRoot = true; // Initially as root node + this.mstEdges = new HashSet<>(); // Initial MST edge set is empty + this.changed = false; // Initial state unchanged + } - public boolean isChanged() { - return changed; - } + // Getters and Setters - public void setChanged(boolean changed) { - this.changed = changed; - } + public Object getParentId() { + return parentId; + } - public Object getVertexId() { - return vertexId; - } + public void setParentId(Object parentId) { + this.parentId = parentId; + this.changed = true; + } - public void setVertexId(Object vertexId) { - this.vertexId = vertexId; - } + public Object getComponentId() { + return componentId; + } - /** - * Add MST edge with memory optimization. - * Prevents memory overflow by limiting the number of stored edges. - * @param edge MST edge - * @return Whether addition was successful - */ - public boolean addMSTEdge(MSTEdge edge) { - // Memory optimization: limit the number of MST edges per vertex - if (this.mstEdges.size() >= MAX_MST_EDGES_PER_VERTEX) { - // Remove the edge with highest weight to make room for new edge - MSTEdge heaviestEdge = this.mstEdges.stream() - .max(MSTEdge::compareTo) - .orElse(null); - if (heaviestEdge != null && edge.getWeight() < heaviestEdge.getWeight()) { - this.mstEdges.remove(heaviestEdge); - } else { - // New edge is heavier than all existing edges, skip it - return false; - } - } - - boolean added = this.mstEdges.add(edge); - if (added) { - this.changed = true; - } - return added; - } + public void setComponentId(Object componentId) { + this.componentId = componentId; + this.changed = true; + } - /** - * Remove MST edge. - * @param edge MST edge - * @return Whether removal was successful - */ - public boolean removeMSTEdge(MSTEdge edge) { - boolean removed = this.mstEdges.remove(edge); - if (removed) { - this.changed = true; - } - return removed; - } + public double getMinEdgeWeight() { + return minEdgeWeight; + } - /** - * Check if contains the specified MST edge. - * @param edge MST edge - * @return Whether it contains the edge - */ - public boolean containsMSTEdge(MSTEdge edge) { - return this.mstEdges.contains(edge); - } + public void setMinEdgeWeight(double minEdgeWeight) { + this.minEdgeWeight = minEdgeWeight; + this.changed = true; + } - /** - * Get the number of MST edges. - * @return Number of edges - */ - public int getMSTEdgeCount() { - return this.mstEdges.size(); - } + public boolean isRoot() { + return isRoot; + } - /** - * Clear MST edge set. - */ - public void clearMSTEdges() { - if (!this.mstEdges.isEmpty()) { - this.mstEdges.clear(); - this.changed = true; - } - } + public void setRoot(boolean root) { + this.isRoot = root; + this.changed = true; + } - /** - * Reset state change flag. - */ - public void resetChanged() { - this.changed = false; - } - - /** - * Memory optimization: compact MST edges by removing redundant edges. - * Keeps only the most important edges to prevent memory overflow. - */ - public void compactMSTEdges() { - if (this.mstEdges.size() > MAX_MST_EDGES_PER_VERTEX) { - // Convert to sorted list and keep only the lightest edges - Set compactedEdges = this.mstEdges.stream() - .sorted() - .limit(MAX_MST_EDGES_PER_VERTEX) - .collect(java.util.stream.Collectors.toSet()); - - this.mstEdges.clear(); - this.mstEdges.addAll(compactedEdges); - this.changed = true; - } - } - - /** - * Get memory usage estimate for this vertex state. - * @return Estimated memory usage in bytes - */ - public long getMemoryUsageEstimate() { - long baseSize = 8 * 8; // Object overhead + 8 fields - long edgesSize = this.mstEdges.size() * 32; // Approximate size per MSTEdge - return baseSize + edgesSize; + public Set getMstEdges() { + return mstEdges; + } + + public void setMstEdges(Set mstEdges) { + this.mstEdges = mstEdges; + this.changed = true; + } + + public boolean isChanged() { + return changed; + } + + public void setChanged(boolean changed) { + this.changed = changed; + } + + public Object getVertexId() { + return vertexId; + } + + public void setVertexId(Object vertexId) { + this.vertexId = vertexId; + } + + /** + * Add MST edge with memory optimization. Prevents memory overflow by limiting the number of + * stored edges. + * + * @param edge MST edge + * @return Whether addition was successful + */ + public boolean addMSTEdge(MSTEdge edge) { + // Memory optimization: limit the number of MST edges per vertex + if (this.mstEdges.size() >= MAX_MST_EDGES_PER_VERTEX) { + // Remove the edge with highest weight to make room for new edge + MSTEdge heaviestEdge = this.mstEdges.stream().max(MSTEdge::compareTo).orElse(null); + if (heaviestEdge != null && edge.getWeight() < heaviestEdge.getWeight()) { + this.mstEdges.remove(heaviestEdge); + } else { + // New edge is heavier than all existing edges, skip it + return false; + } } - /** - * Check if it is a leaf node (no child nodes). - * @return Whether it is a leaf node - */ - public boolean isLeaf() { - return this.mstEdges.isEmpty(); + boolean added = this.mstEdges.add(edge); + if (added) { + this.changed = true; } + return added; + } - /** - * Get edge weight to specified vertex. - * @param targetId Target vertex ID - * @return Edge weight, returns Double.MAX_VALUE if not exists - */ - public double getEdgeWeightTo(Object targetId) { - for (MSTEdge edge : mstEdges) { - if (edge.getTargetId().equals(targetId) || edge.getSourceId().equals(targetId)) { - return edge.getWeight(); - } - } - return Double.MAX_VALUE; + /** + * Remove MST edge. + * + * @param edge MST edge + * @return Whether removal was successful + */ + public boolean removeMSTEdge(MSTEdge edge) { + boolean removed = this.mstEdges.remove(edge); + if (removed) { + this.changed = true; } + return removed; + } - /** - * Check if connected to specified vertex. - * @param targetId Target vertex ID - * @return Whether connected - */ - public boolean isConnectedTo(Object targetId) { - return getEdgeWeightTo(targetId) < Double.MAX_VALUE; + /** + * Check if contains the specified MST edge. + * + * @param edge MST edge + * @return Whether it contains the edge + */ + public boolean containsMSTEdge(MSTEdge edge) { + return this.mstEdges.contains(edge); + } + + /** + * Get the number of MST edges. + * + * @return Number of edges + */ + public int getMSTEdgeCount() { + return this.mstEdges.size(); + } + + /** Clear MST edge set. */ + public void clearMSTEdges() { + if (!this.mstEdges.isEmpty()) { + this.mstEdges.clear(); + this.changed = true; } + } + + /** Reset state change flag. */ + public void resetChanged() { + this.changed = false; + } - @Override - public String toString() { - return "MSTVertexState{" - + "vertexId=" + vertexId - + ", parentId=" + parentId - + ", componentId=" + componentId - + ", minEdgeWeight=" + minEdgeWeight - + ", isRoot=" + isRoot - + ", mstEdges=" + mstEdges - + ", changed=" + changed - + '}'; + /** + * Memory optimization: compact MST edges by removing redundant edges. Keeps only the most + * important edges to prevent memory overflow. + */ + public void compactMSTEdges() { + if (this.mstEdges.size() > MAX_MST_EDGES_PER_VERTEX) { + // Convert to sorted list and keep only the lightest edges + Set compactedEdges = + this.mstEdges.stream() + .sorted() + .limit(MAX_MST_EDGES_PER_VERTEX) + .collect(java.util.stream.Collectors.toSet()); + + this.mstEdges.clear(); + this.mstEdges.addAll(compactedEdges); + this.changed = true; } + } + + /** + * Get memory usage estimate for this vertex state. + * + * @return Estimated memory usage in bytes + */ + public long getMemoryUsageEstimate() { + long baseSize = 8 * 8; // Object overhead + 8 fields + long edgesSize = this.mstEdges.size() * 32; // Approximate size per MSTEdge + return baseSize + edgesSize; + } - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null || getClass() != obj.getClass()) { - return false; - } - MSTVertexState that = (MSTVertexState) obj; - return Double.compare(that.minEdgeWeight, minEdgeWeight) == 0 - && isRoot == that.isRoot - && changed == that.changed - && Objects.equals(vertexId, that.vertexId) - && Objects.equals(parentId, that.parentId) - && Objects.equals(componentId, that.componentId) - && Objects.equals(mstEdges, that.mstEdges); + /** + * Check if it is a leaf node (no child nodes). + * + * @return Whether it is a leaf node + */ + public boolean isLeaf() { + return this.mstEdges.isEmpty(); + } + + /** + * Get edge weight to specified vertex. + * + * @param targetId Target vertex ID + * @return Edge weight, returns Double.MAX_VALUE if not exists + */ + public double getEdgeWeightTo(Object targetId) { + for (MSTEdge edge : mstEdges) { + if (edge.getTargetId().equals(targetId) || edge.getSourceId().equals(targetId)) { + return edge.getWeight(); + } } + return Double.MAX_VALUE; + } + + /** + * Check if connected to specified vertex. + * + * @param targetId Target vertex ID + * @return Whether connected + */ + public boolean isConnectedTo(Object targetId) { + return getEdgeWeightTo(targetId) < Double.MAX_VALUE; + } + + @Override + public String toString() { + return "MSTVertexState{" + + "vertexId=" + + vertexId + + ", parentId=" + + parentId + + ", componentId=" + + componentId + + ", minEdgeWeight=" + + minEdgeWeight + + ", isRoot=" + + isRoot + + ", mstEdges=" + + mstEdges + + ", changed=" + + changed + + '}'; + } - @Override - public int hashCode() { - return Objects.hash(vertexId, parentId, componentId, minEdgeWeight, isRoot, mstEdges, changed); + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; } -} \ No newline at end of file + if (obj == null || getClass() != obj.getClass()) { + return false; + } + MSTVertexState that = (MSTVertexState) obj; + return Double.compare(that.minEdgeWeight, minEdgeWeight) == 0 + && isRoot == that.isRoot + && changed == that.changed + && Objects.equals(vertexId, that.vertexId) + && Objects.equals(parentId, that.parentId) + && Objects.equals(componentId, that.componentId) + && Objects.equals(mstEdges, that.mstEdges); + } + + @Override + public int hashCode() { + return Objects.hash(vertexId, parentId, componentId, minEdgeWeight, isRoot, mstEdges, changed); + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/mst/UnionFindHelper.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/mst/UnionFindHelper.java index 7c5bf770b..3ad6ae66b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/mst/UnionFindHelper.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/mst/UnionFindHelper.java @@ -25,248 +25,255 @@ import java.util.Objects; /** - * Union-Find data structure helper class. - * Used for managing union and find operations on disjoint sets. - * - *

Supported operations: - * - makeSet: Create new set - * - find: Find the set an element belongs to - * - union: Merge two sets - * - getSetCount: Get number of sets - * - clear: Clear all sets - * + * Union-Find data structure helper class. Used for managing union and find operations on disjoint + * sets. + * + *

Supported operations: - makeSet: Create new set - find: Find the set an element belongs to - + * union: Merge two sets - getSetCount: Get number of sets - clear: Clear all sets + * * @author Geaflow Team */ public class UnionFindHelper implements Serializable { - - private static final long serialVersionUID = 1L; - - /** Parent node mapping. */ - private Map parent; - - /** Rank mapping (for path compression optimization). */ - private Map rank; - - /** Set size mapping. */ - private Map size; - - /** Number of sets. */ - private int setCount; - - /** - * Constructor. - */ - public UnionFindHelper() { - this.parent = new HashMap<>(); - this.rank = new HashMap<>(); - this.size = new HashMap<>(); - this.setCount = 0; - } - /** - * Create new set. - * @param x Element - */ - public void makeSet(Object x) { - if (!parent.containsKey(x)) { - parent.put(x, x); - rank.put(x, 0); - size.put(x, 1); - setCount++; - } - } + private static final long serialVersionUID = 1L; - /** - * Find the root node of the set an element belongs to. - * @param x Element - * @return Root node - */ - public Object find(Object x) { - if (!parent.containsKey(x)) { - return null; - } - - if (!parent.get(x).equals(x)) { - parent.put(x, find(parent.get(x))); - } - return parent.get(x); - } + /** Parent node mapping. */ + private Map parent; - /** - * Merge two sets. - * @param x First element - * @param y Second element - * @return Whether merge was successful - */ - public boolean union(Object x, Object y) { - Object rootX = find(x); - Object rootY = find(y); - - if (rootX == null || rootY == null) { - return false; - } - - if (rootX.equals(rootY)) { - return false; // Already in the same set - } - - // Union by rank - if (rank.get(rootX) < rank.get(rootY)) { - Object temp = rootX; - rootX = rootY; - rootY = temp; - } - - parent.put(rootY, rootX); - size.put(rootX, size.get(rootX) + size.get(rootY)); - - if (rank.get(rootX).equals(rank.get(rootY))) { - rank.put(rootX, rank.get(rootX) + 1); - } - - setCount--; - return true; - } + /** Rank mapping (for path compression optimization). */ + private Map rank; - /** - * Get number of sets. - * @return Number of sets - */ - public int getSetCount() { - return setCount; + /** Set size mapping. */ + private Map size; + + /** Number of sets. */ + private int setCount; + + /** Constructor. */ + public UnionFindHelper() { + this.parent = new HashMap<>(); + this.rank = new HashMap<>(); + this.size = new HashMap<>(); + this.setCount = 0; + } + + /** + * Create new set. + * + * @param x Element + */ + public void makeSet(Object x) { + if (!parent.containsKey(x)) { + parent.put(x, x); + rank.put(x, 0); + size.put(x, 1); + setCount++; } + } - /** - * Get size of specified set. - * @param x Any element in the set - * @return Set size - */ - public int getSetSize(Object x) { - Object root = find(x); - if (root == null) { - return 0; - } - return size.get(root); + /** + * Find the root node of the set an element belongs to. + * + * @param x Element + * @return Root node + */ + public Object find(Object x) { + if (!parent.containsKey(x)) { + return null; } - /** - * Check if two elements are in the same set. - * @param x First element - * @param y Second element - * @return Whether they are in the same set - */ - public boolean isConnected(Object x, Object y) { - Object rootX = find(x); - Object rootY = find(y); - return rootX != null && rootX.equals(rootY); + if (!parent.get(x).equals(x)) { + parent.put(x, find(parent.get(x))); } + return parent.get(x); + } - /** - * Clear all sets. - */ - public void clear() { - parent.clear(); - rank.clear(); - size.clear(); - setCount = 0; + /** + * Merge two sets. + * + * @param x First element + * @param y Second element + * @return Whether merge was successful + */ + public boolean union(Object x, Object y) { + Object rootX = find(x); + Object rootY = find(y); + + if (rootX == null || rootY == null) { + return false; } - /** - * Check if Union-Find structure is empty. - * @return Whether it is empty - */ - public boolean isEmpty() { - return parent.isEmpty(); + if (rootX.equals(rootY)) { + return false; // Already in the same set } - /** - * Get number of elements in Union-Find structure. - * @return Number of elements - */ - public int size() { - return parent.size(); + // Union by rank + if (rank.get(rootX) < rank.get(rootY)) { + Object temp = rootX; + rootX = rootY; + rootY = temp; } - /** - * Check if element exists. - * @param x Element - * @return Whether it exists - */ - public boolean contains(Object x) { - return parent.containsKey(x); + parent.put(rootY, rootX); + size.put(rootX, size.get(rootX) + size.get(rootY)); + + if (rank.get(rootX).equals(rank.get(rootY))) { + rank.put(rootX, rank.get(rootX) + 1); } - /** - * Remove element (and its set). - * @param x Element - * @return Whether removal was successful - */ - public boolean remove(Object x) { - if (!parent.containsKey(x)) { - return false; - } - - Object root = find(x); - int rootSize = size.get(root); - - if (rootSize == 1) { - // If set has only one element, remove directly - parent.remove(x); - rank.remove(x); - size.remove(x); - setCount--; - } else { - // If set has multiple elements, need to reorganize - // Simplified handling here, actual applications may need more complex logic - parent.remove(x); - size.put(root, rootSize - 1); - } - - return true; + setCount--; + return true; + } + + /** + * Get number of sets. + * + * @return Number of sets + */ + public int getSetCount() { + return setCount; + } + + /** + * Get size of specified set. + * + * @param x Any element in the set + * @return Set size + */ + public int getSetSize(Object x) { + Object root = find(x); + if (root == null) { + return 0; } + return size.get(root); + } + + /** + * Check if two elements are in the same set. + * + * @param x First element + * @param y Second element + * @return Whether they are in the same set + */ + public boolean isConnected(Object x, Object y) { + Object rootX = find(x); + Object rootY = find(y); + return rootX != null && rootX.equals(rootY); + } + + /** Clear all sets. */ + public void clear() { + parent.clear(); + rank.clear(); + size.clear(); + setCount = 0; + } + + /** + * Check if Union-Find structure is empty. + * + * @return Whether it is empty + */ + public boolean isEmpty() { + return parent.isEmpty(); + } + + /** + * Get number of elements in Union-Find structure. + * + * @return Number of elements + */ + public int size() { + return parent.size(); + } + + /** + * Check if element exists. + * + * @param x Element + * @return Whether it exists + */ + public boolean contains(Object x) { + return parent.containsKey(x); + } - /** - * Get all elements in specified set. - * @param root Set root node - * @return All elements in the set - */ - public java.util.Set getSetElements(Object root) { - java.util.Set elements = new java.util.HashSet<>(); - for (Object x : parent.keySet()) { - if (find(x).equals(root)) { - elements.add(x); - } - } - return elements; + /** + * Remove element (and its set). + * + * @param x Element + * @return Whether removal was successful + */ + public boolean remove(Object x) { + if (!parent.containsKey(x)) { + return false; } - @Override - public String toString() { - return "UnionFindHelper{" - + "parent=" + parent - + ", rank=" + rank - + ", size=" + size - + ", setCount=" + setCount - + '}'; + Object root = find(x); + int rootSize = size.get(root); + + if (rootSize == 1) { + // If set has only one element, remove directly + parent.remove(x); + rank.remove(x); + size.remove(x); + setCount--; + } else { + // If set has multiple elements, need to reorganize + // Simplified handling here, actual applications may need more complex logic + parent.remove(x); + size.put(root, rootSize - 1); } - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null || getClass() != obj.getClass()) { - return false; - } - UnionFindHelper that = (UnionFindHelper) obj; - return setCount == that.setCount - && Objects.equals(parent, that.parent) - && Objects.equals(rank, that.rank) - && Objects.equals(size, that.size); + return true; + } + + /** + * Get all elements in specified set. + * + * @param root Set root node + * @return All elements in the set + */ + public java.util.Set getSetElements(Object root) { + java.util.Set elements = new java.util.HashSet<>(); + for (Object x : parent.keySet()) { + if (find(x).equals(root)) { + elements.add(x); + } } + return elements; + } + + @Override + public String toString() { + return "UnionFindHelper{" + + "parent=" + + parent + + ", rank=" + + rank + + ", size=" + + size + + ", setCount=" + + setCount + + '}'; + } - @Override - public int hashCode() { - return Objects.hash(parent, rank, size, setCount); + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; } -} \ No newline at end of file + if (obj == null || getClass() != obj.getClass()) { + return false; + } + UnionFindHelper that = (UnionFindHelper) obj; + return setCount == that.setCount + && Objects.equals(parent, that.parent) + && Objects.equals(rank, that.rank) + && Objects.equals(size, that.size); + } + + @Override + public int hashCode() { + return Objects.hash(parent, rank, size, setCount); + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/ldbc/BICityInteractionAlgorithm.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/ldbc/BICityInteractionAlgorithm.java index 9ec5e3f0a..fa443fc6f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/ldbc/BICityInteractionAlgorithm.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/ldbc/BICityInteractionAlgorithm.java @@ -26,6 +26,7 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; + import org.apache.geaflow.common.type.primitive.LongType; import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction; @@ -43,246 +44,244 @@ @Description(name = "bi19_interaction", description = "LDBC BI19 City Interaction Algorithm") public class BICityInteractionAlgorithm implements AlgorithmUserFunction { - private AlgorithmRuntimeContext context; + private AlgorithmRuntimeContext context; - private final int personIteration = 1; - private final int messageIterationA = 2; - private final int messageIterationB = 0; - private final Long interactionMessage = -1L; - private final Long heartbeatMessage = -2L; + private final int personIteration = 1; + private final int messageIterationA = 2; + private final int messageIterationB = 0; + private final Long interactionMessage = -1L; + private final Long heartbeatMessage = -2L; - //Assert maxIterations % 3 = personIteration ensures results are collected at Person nodes. - private final int maxIterations = 31; + // Assert maxIterations % 3 = personIteration ensures results are collected at Person nodes. + private final int maxIterations = 31; - private final String personType = "Person"; - private final String postType = "Post"; - private final String commentType = "Comment"; + private final String personType = "Person"; + private final String postType = "Post"; + private final String commentType = "Comment"; - private final String knowsType = "knows"; - private final String hasCreatorType = "hasCreator"; - private final String replyOfType = "replyOf"; - private final String isLocatedInType = "isLocatedIn"; + private final String knowsType = "knows"; + private final String hasCreatorType = "hasCreator"; + private final String replyOfType = "replyOf"; + private final String isLocatedInType = "isLocatedIn"; - private RowVertex vertexCache; - private List vertexEdgesCache = new ArrayList<>(); + private RowVertex vertexCache; + private List vertexEdgesCache = new ArrayList<>(); - private Object city1Id; - private Object city2Id; + private Object city1Id; + private Object city2Id; - @Override - public void init(AlgorithmRuntimeContext context, Object[] parameters) { - this.context = context; - assert parameters.length >= 2 : "Algorithm need source vid parameter."; - city1Id = TypeCastUtil.cast(parameters[0], context.getGraphSchema().getIdType()); - assert city1Id != null : "city1Id cannot be null for algorithm."; - city2Id = TypeCastUtil.cast(parameters[1], context.getGraphSchema().getIdType()); - assert city2Id != null : "city2Id cannot be null for algorithm."; - } + @Override + public void init(AlgorithmRuntimeContext context, Object[] parameters) { + this.context = context; + assert parameters.length >= 2 : "Algorithm need source vid parameter."; + city1Id = TypeCastUtil.cast(parameters[0], context.getGraphSchema().getIdType()); + assert city1Id != null : "city1Id cannot be null for algorithm."; + city2Id = TypeCastUtil.cast(parameters[1], context.getGraphSchema().getIdType()); + assert city2Id != null : "city2Id cannot be null for algorithm."; + } - @Override - public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { - if (context.getCurrentIterationId() > maxIterations) { + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + if (context.getCurrentIterationId() > maxIterations) { + return; + } + Long vId = (Long) vertex.getId(); + updatedValues.ifPresent(vertex::setValue); + switch ((int) (context.getCurrentIterationId() % 3)) { + case personIteration: + if (vertex.getLabel().equals(personType)) { + Map city1PersonId2DistanceMap = new HashMap<>(); + Map city2PersonId2DistanceMap = new HashMap<>(); + if (context.getCurrentIterationId() == 1L) { + // Person in City 1/2 have their corresponding distance initialized to 0 + List isLocatedInEdges = loadEdges(vertex, isLocatedInType, EdgeDirection.OUT); + for (RowEdge e : isLocatedInEdges) { + if (e.getTargetId().equals(city1Id)) { + city1PersonId2DistanceMap.put(vId, 0L); + } + if (e.getTargetId().equals(city2Id)) { + city2PersonId2DistanceMap.put(vId, 0L); + } + } + } else if (context.getCurrentIterationId() == maxIterations) { + decodeVertexValueAsMaps( + vertex.getValue(), city1PersonId2DistanceMap, city2PersonId2DistanceMap); + for (Long city1Person : city1PersonId2DistanceMap.keySet()) { + for (Long city2Person : city2PersonId2DistanceMap.keySet()) { + Long weight = + city1PersonId2DistanceMap.get(city1Person) + + city2PersonId2DistanceMap.get(city2Person); + context.take(ObjectRow.create(city1Person, city2Person, weight)); + } + } return; - } - Long vId = (Long) vertex.getId(); - updatedValues.ifPresent(vertex::setValue); - switch ((int) (context.getCurrentIterationId() % 3)) { - case personIteration: - if (vertex.getLabel().equals(personType)) { - Map city1PersonId2DistanceMap = new HashMap<>(); - Map city2PersonId2DistanceMap = new HashMap<>(); - if (context.getCurrentIterationId() == 1L) { - //Person in City 1/2 have their corresponding distance initialized to 0 - List isLocatedInEdges = loadEdges(vertex, isLocatedInType, - EdgeDirection.OUT); - for (RowEdge e : isLocatedInEdges) { - if (e.getTargetId().equals(city1Id)) { - city1PersonId2DistanceMap.put(vId, 0L); - } - if (e.getTargetId().equals(city2Id)) { - city2PersonId2DistanceMap.put(vId, 0L); - } - } - } else if (context.getCurrentIterationId() == maxIterations) { - decodeVertexValueAsMaps(vertex.getValue(), - city1PersonId2DistanceMap, city2PersonId2DistanceMap); - for (Long city1Person : city1PersonId2DistanceMap.keySet()) { - for (Long city2Person : city2PersonId2DistanceMap.keySet()) { - Long weight = city1PersonId2DistanceMap.get(city1Person) - + city2PersonId2DistanceMap.get(city2Person); - context.take(ObjectRow.create(city1Person, city2Person, weight)); - } - } - return; - } else { - decodeVertexValueAsMaps(vertex.getValue(), - city1PersonId2DistanceMap, city2PersonId2DistanceMap); - } - - //Aggregate the total number of interactions from all the neighboring - // "knows" edges and temporarily store their distance map messages. - Map personId2Interactions = new HashMap<>(); - List distanceMapMessages = new ArrayList<>(); - while (messages.hasNext()) { - ObjectRow msg = messages.next(); - Long msgFlag = (Long) msg.getField(0, LongType.INSTANCE); - if (Objects.equals(msgFlag, interactionMessage)) { - Long personId = (Long) msg.getField(1, LongType.INSTANCE); - personId2Interactions.put(personId, personId2Interactions.containsKey(personId) - ? personId2Interactions.get(personId) + 1L : 1L); - } else if (msgFlag >= 0L) { - distanceMapMessages.add(msg); - } - } + } else { + decodeVertexValueAsMaps( + vertex.getValue(), city1PersonId2DistanceMap, city2PersonId2DistanceMap); + } - for (ObjectRow msg : distanceMapMessages) { - Long map1Size = (Long) msg.getField(0, LongType.INSTANCE); - Long map2Size = (Long) msg.getField(1, LongType.INSTANCE); - Long senderVid = (Long) msg.getField( - (int) ((map1Size + map2Size + 1) * 2), LongType.INSTANCE); - if (personId2Interactions.containsKey(senderVid) - && personId2Interactions.get(senderVid) > 0L) { - Map senderCity1PersonId2DistanceMap = new HashMap<>(); - Map senderCity2PersonId2DistanceMap = new HashMap<>(); - decodeVertexValueAsMaps(msg, - senderCity1PersonId2DistanceMap, senderCity2PersonId2DistanceMap); - double numInteractions = personId2Interactions.get(senderVid); - long deltaWeight = - Math.max(Math.round(40 - Math.sqrt(numInteractions)), 1L); - for (Long city1Person : senderCity1PersonId2DistanceMap.keySet()) { - long newWeight = - senderCity1PersonId2DistanceMap.get(city1Person) + deltaWeight; - if (!city1PersonId2DistanceMap.containsKey(city1Person) - || newWeight < city1PersonId2DistanceMap.get(city1Person)) { - city1PersonId2DistanceMap.put(city1Person, newWeight); - } - } - for (Long city2Person : senderCity2PersonId2DistanceMap.keySet()) { - long newWeight = - senderCity2PersonId2DistanceMap.get(city2Person) + deltaWeight; - if (!city2PersonId2DistanceMap.containsKey(city2Person) - || newWeight < city2PersonId2DistanceMap.get(city2Person)) { - city2PersonId2DistanceMap.put(city2Person, newWeight); - } - } - } - } - context.updateVertexValue(encodeMapsAsVertexValue(city1PersonId2DistanceMap, - city2PersonId2DistanceMap, vId)); + // Aggregate the total number of interactions from all the neighboring + // "knows" edges and temporarily store their distance map messages. + Map personId2Interactions = new HashMap<>(); + List distanceMapMessages = new ArrayList<>(); + while (messages.hasNext()) { + ObjectRow msg = messages.next(); + Long msgFlag = (Long) msg.getField(0, LongType.INSTANCE); + if (Objects.equals(msgFlag, interactionMessage)) { + Long personId = (Long) msg.getField(1, LongType.INSTANCE); + personId2Interactions.put( + personId, + personId2Interactions.containsKey(personId) + ? personId2Interactions.get(personId) + 1L + : 1L); + } else if (msgFlag >= 0L) { + distanceMapMessages.add(msg); + } + } - //Sending once own ID represent an interaction, which is then forwarded to - // the neighbors through 2 message nodes. - List hasCreatorEdges = loadEdges(vertex, hasCreatorType, - EdgeDirection.IN); - for (RowEdge e : hasCreatorEdges) { - context.sendMessage(e.getTargetId(), - ObjectRow.create(interactionMessage, vId)); - } - } - break; - case messageIterationA: - if (vertex.getLabel().equals(postType) || vertex.getLabel().equals(commentType)) { - List replyOfEdges = loadEdges(vertex, replyOfType, EdgeDirection.BOTH); - while (messages.hasNext()) { - ObjectRow msg = messages.next(); - for (RowEdge e : replyOfEdges) { - context.sendMessage(e.getTargetId(), msg); - } - } + for (ObjectRow msg : distanceMapMessages) { + Long map1Size = (Long) msg.getField(0, LongType.INSTANCE); + Long map2Size = (Long) msg.getField(1, LongType.INSTANCE); + Long senderVid = + (Long) msg.getField((int) ((map1Size + map2Size + 1) * 2), LongType.INSTANCE); + if (personId2Interactions.containsKey(senderVid) + && personId2Interactions.get(senderVid) > 0L) { + Map senderCity1PersonId2DistanceMap = new HashMap<>(); + Map senderCity2PersonId2DistanceMap = new HashMap<>(); + decodeVertexValueAsMaps( + msg, senderCity1PersonId2DistanceMap, senderCity2PersonId2DistanceMap); + double numInteractions = personId2Interactions.get(senderVid); + long deltaWeight = Math.max(Math.round(40 - Math.sqrt(numInteractions)), 1L); + for (Long city1Person : senderCity1PersonId2DistanceMap.keySet()) { + long newWeight = senderCity1PersonId2DistanceMap.get(city1Person) + deltaWeight; + if (!city1PersonId2DistanceMap.containsKey(city1Person) + || newWeight < city1PersonId2DistanceMap.get(city1Person)) { + city1PersonId2DistanceMap.put(city1Person, newWeight); } - break; - case messageIterationB: - if (vertex.getLabel().equals(postType) || vertex.getLabel().equals(commentType)) { - List hasCreatorEdges = loadEdges(vertex, hasCreatorType, EdgeDirection.OUT); - while (messages.hasNext()) { - ObjectRow msg = messages.next(); - for (RowEdge e : hasCreatorEdges) { - context.sendMessage(e.getTargetId(), msg); - } - } - } else if (vertex.getLabel().equals(personType)) { - //The distance map message arrive at the Person node at the same time as the - // forwarded interaction message - List knowsEdges = loadEdges(vertex, knowsType, EdgeDirection.BOTH); - for (RowEdge e : knowsEdges) { - context.sendMessage(e.getTargetId(), (ObjectRow) vertex.getValue()); - } + } + for (Long city2Person : senderCity2PersonId2DistanceMap.keySet()) { + long newWeight = senderCity2PersonId2DistanceMap.get(city2Person) + deltaWeight; + if (!city2PersonId2DistanceMap.containsKey(city2Person) + || newWeight < city2PersonId2DistanceMap.get(city2Person)) { + city2PersonId2DistanceMap.put(city2Person, newWeight); } - break; - default: - } - if (vertex.getLabel().equals(personType)) { - context.sendMessage(vertex.getId(), ObjectRow.create(heartbeatMessage, 0L)); - } - } - - @Override - public StructType getOutputType(GraphSchema graphSchema) { - return new StructType( - new TableField("person1Id", LongType.INSTANCE, false), - new TableField("person2Id", LongType.INSTANCE, false), - new TableField("totalWeight", LongType.INSTANCE, true) - ); - } + } + } + } + context.updateVertexValue( + encodeMapsAsVertexValue(city1PersonId2DistanceMap, city2PersonId2DistanceMap, vId)); - private List loadEdges(RowVertex vertex, String edgeLabel, EdgeDirection direction) { - if (!vertex.equals(vertexCache)) { - vertexEdgesCache.clear(); - vertexEdgesCache = context.loadEdges(EdgeDirection.BOTH); - vertexCache = vertex; + // Sending once own ID represent an interaction, which is then forwarded to + // the neighbors through 2 message nodes. + List hasCreatorEdges = loadEdges(vertex, hasCreatorType, EdgeDirection.IN); + for (RowEdge e : hasCreatorEdges) { + context.sendMessage(e.getTargetId(), ObjectRow.create(interactionMessage, vId)); + } + } + break; + case messageIterationA: + if (vertex.getLabel().equals(postType) || vertex.getLabel().equals(commentType)) { + List replyOfEdges = loadEdges(vertex, replyOfType, EdgeDirection.BOTH); + while (messages.hasNext()) { + ObjectRow msg = messages.next(); + for (RowEdge e : replyOfEdges) { + context.sendMessage(e.getTargetId(), msg); + } + } } - List results = new ArrayList<>(); - for (RowEdge e : vertexEdgesCache) { - if (e.getLabel().equals(edgeLabel) - && (direction == EdgeDirection.BOTH || e.getDirect() == direction)) { - results.add(e); + break; + case messageIterationB: + if (vertex.getLabel().equals(postType) || vertex.getLabel().equals(commentType)) { + List hasCreatorEdges = loadEdges(vertex, hasCreatorType, EdgeDirection.OUT); + while (messages.hasNext()) { + ObjectRow msg = messages.next(); + for (RowEdge e : hasCreatorEdges) { + context.sendMessage(e.getTargetId(), msg); } + } + } else if (vertex.getLabel().equals(personType)) { + // The distance map message arrive at the Person node at the same time as the + // forwarded interaction message + List knowsEdges = loadEdges(vertex, knowsType, EdgeDirection.BOTH); + for (RowEdge e : knowsEdges) { + context.sendMessage(e.getTargetId(), (ObjectRow) vertex.getValue()); + } } - return results; + break; + default: + } + if (vertex.getLabel().equals(personType)) { + context.sendMessage(vertex.getId(), ObjectRow.create(heartbeatMessage, 0L)); } + } - private static ObjectRow encodeMapsAsVertexValue(Map map1, Map map2, - Long... originValues) { - int originValuesLength = originValues != null ? originValues.length : 0; - Long[] values = new Long[(map1.keySet().size() + map2.keySet().size()) * 2 + 2 + originValuesLength]; - values[0] = (long) map1.keySet().size(); - values[1] = (long) map2.keySet().size(); - int index = 1; - for (Long key : map1.keySet()) { - values[index * 2] = key; - values[index * 2 + 1] = map1.get(key); - ++index; - } - for (Long key : map2.keySet()) { - values[index * 2] = key; - values[index * 2 + 1] = map2.get(key); - ++index; - } - for (index = 0; index < originValuesLength; index++) { - values[values.length - originValuesLength + index] = originValues[index]; - } - return ObjectRow.create((Object[]) values); + @Override + public StructType getOutputType(GraphSchema graphSchema) { + return new StructType( + new TableField("person1Id", LongType.INSTANCE, false), + new TableField("person2Id", LongType.INSTANCE, false), + new TableField("totalWeight", LongType.INSTANCE, true)); + } + + private List loadEdges(RowVertex vertex, String edgeLabel, EdgeDirection direction) { + if (!vertex.equals(vertexCache)) { + vertexEdgesCache.clear(); + vertexEdgesCache = context.loadEdges(EdgeDirection.BOTH); + vertexCache = vertex; } + List results = new ArrayList<>(); + for (RowEdge e : vertexEdgesCache) { + if (e.getLabel().equals(edgeLabel) + && (direction == EdgeDirection.BOTH || e.getDirect() == direction)) { + results.add(e); + } + } + return results; + } - private static void decodeVertexValueAsMaps(Row objectRow, Map map1, Map map2) { - long size1 = (long) objectRow.getField(0, LongType.INSTANCE); - long size2 = (long) objectRow.getField(1, LongType.INSTANCE); - int i = 1; - for (int j = 0; j < size1; j++) { - Long key = (Long) objectRow.getField(i * 2, LongType.INSTANCE); - Long value = (Long) objectRow.getField(i * 2 + 1, LongType.INSTANCE); - map1.put(key, value); - ++i; - } - for (int j = 0; j < size2; j++) { - Long key = (Long) objectRow.getField(i * 2, LongType.INSTANCE); - Long value = (Long) objectRow.getField(i * 2 + 1, LongType.INSTANCE); - map2.put(key, value); - ++i; - } + private static ObjectRow encodeMapsAsVertexValue( + Map map1, Map map2, Long... originValues) { + int originValuesLength = originValues != null ? originValues.length : 0; + Long[] values = + new Long[(map1.keySet().size() + map2.keySet().size()) * 2 + 2 + originValuesLength]; + values[0] = (long) map1.keySet().size(); + values[1] = (long) map2.keySet().size(); + int index = 1; + for (Long key : map1.keySet()) { + values[index * 2] = key; + values[index * 2 + 1] = map1.get(key); + ++index; + } + for (Long key : map2.keySet()) { + values[index * 2] = key; + values[index * 2 + 1] = map2.get(key); + ++index; } + for (index = 0; index < originValuesLength; index++) { + values[values.length - originValuesLength + index] = originValues[index]; + } + return ObjectRow.create((Object[]) values); + } - @Override - public void finish(RowVertex vertex, Optional newValue) { + private static void decodeVertexValueAsMaps( + Row objectRow, Map map1, Map map2) { + long size1 = (long) objectRow.getField(0, LongType.INSTANCE); + long size2 = (long) objectRow.getField(1, LongType.INSTANCE); + int i = 1; + for (int j = 0; j < size1; j++) { + Long key = (Long) objectRow.getField(i * 2, LongType.INSTANCE); + Long value = (Long) objectRow.getField(i * 2 + 1, LongType.INSTANCE); + map1.put(key, value); + ++i; + } + for (int j = 0; j < size2; j++) { + Long key = (Long) objectRow.getField(i * 2, LongType.INSTANCE); + Long value = (Long) objectRow.getField(i * 2 + 1, LongType.INSTANCE); + map2.put(key, value); + ++i; } + } + + @Override + public void finish(RowVertex vertex, Optional newValue) {} } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/ldbc/BIConnectionPathAlgorithm.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/ldbc/BIConnectionPathAlgorithm.java index 9801bfa39..e50b249ac 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/ldbc/BIConnectionPathAlgorithm.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/ldbc/BIConnectionPathAlgorithm.java @@ -26,6 +26,7 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.primitive.DoubleType; import org.apache.geaflow.common.type.primitive.IntegerType; @@ -46,262 +47,260 @@ @Description(name = "bi15_connection", description = "LDBC BI15 Connection Path Algorithm") public class BIConnectionPathAlgorithm implements AlgorithmUserFunction { - private AlgorithmRuntimeContext context; - private final int propagationIterations = 30; - private final int pathSearchIterations = 60; - private final double giganticThreshold = 1000000000.0; - private final double gigantic = 1000000001.0; + private AlgorithmRuntimeContext context; + private final int propagationIterations = 30; + private final int pathSearchIterations = 60; + private final double giganticThreshold = 1000000000.0; + private final double gigantic = 1000000001.0; - private final String personType = "Person"; - private final String postType = "Post"; - private final String commentType = "Comment"; - private final String forumType = "Forum"; - private final String knowsType = "knows"; - private final String hasCreatorType = "hasCreator"; - private final String replyOfType = "replyOf"; - private final String containerOfType = "containerOf"; + private final String personType = "Person"; + private final String postType = "Post"; + private final String commentType = "Comment"; + private final String forumType = "Forum"; + private final String knowsType = "knows"; + private final String hasCreatorType = "hasCreator"; + private final String replyOfType = "replyOf"; + private final String containerOfType = "containerOf"; - private RowVertex vertexCache; - private List vertexEdgesCache = new ArrayList<>(); + private RowVertex vertexCache; + private List vertexEdgesCache = new ArrayList<>(); - private Object leftSourceVertexId; - private Object rightSourceVertexId; - private Long startDate; - private Long endDate; + private Object leftSourceVertexId; + private Object rightSourceVertexId; + private Long startDate; + private Long endDate; - @Override - public void init(AlgorithmRuntimeContext context, Object[] parameters) { - this.context = context; - assert parameters.length >= 4 : "Algorithm need source vid parameter."; - leftSourceVertexId = TypeCastUtil.cast(parameters[0], context.getGraphSchema().getIdType()); - assert leftSourceVertexId != null : "leftSourceVertexId cannot be null for algorithm."; - rightSourceVertexId = TypeCastUtil.cast(parameters[1], context.getGraphSchema().getIdType()); - assert rightSourceVertexId != null : "rightSourceVertexId cannot be null for algorithm."; - startDate = (Long) TypeCastUtil.cast(parameters[2], Long.class); - assert startDate != null : "startDate cannot be null for algorithm."; - endDate = (Long) TypeCastUtil.cast(parameters[3], Long.class); - assert endDate != null : "endDate cannot be null for algorithm."; - } + @Override + public void init(AlgorithmRuntimeContext context, Object[] parameters) { + this.context = context; + assert parameters.length >= 4 : "Algorithm need source vid parameter."; + leftSourceVertexId = TypeCastUtil.cast(parameters[0], context.getGraphSchema().getIdType()); + assert leftSourceVertexId != null : "leftSourceVertexId cannot be null for algorithm."; + rightSourceVertexId = TypeCastUtil.cast(parameters[1], context.getGraphSchema().getIdType()); + assert rightSourceVertexId != null : "rightSourceVertexId cannot be null for algorithm."; + startDate = (Long) TypeCastUtil.cast(parameters[2], Long.class); + assert startDate != null : "startDate cannot be null for algorithm."; + endDate = (Long) TypeCastUtil.cast(parameters[3], Long.class); + assert endDate != null : "endDate cannot be null for algorithm."; + } - @Override - public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { - updatedValues.ifPresent(vertex::setValue); - //Stage 1 propagation - if (context.getCurrentIterationId() < propagationIterations) { - //Send heartbeat messages to keep nodes alive - if (vertex.getLabel().equals(personType)) { - context.sendMessage(vertex.getId(), ObjectRow.create(-1.0, 0L)); - } - switch (vertex.getLabel()) { - case forumType: - long creationDate = (long) vertex.getValue().getField(0, LongType.INSTANCE); - if (creationDate >= startDate && creationDate <= endDate) { - List containerOfEdges = loadEdges(vertex, containerOfType, EdgeDirection.OUT); - //The forum sends activation message to the post contained - for (RowEdge e : containerOfEdges) { - context.sendMessage(e.getTargetId(), ObjectRow.create(-1.0, 0L)); - } - } - break; - case postType: - case commentType: - List hasCreatorEdges = loadEdges(vertex, hasCreatorType, EdgeDirection.OUT); - boolean active = false; - while (messages.hasNext()) { - ObjectRow msg = messages.next(); - double score = (double) msg.getField(0, DoubleType.INSTANCE); - if (score < 0) { - //If an activation message is received from the upstream reply chain, - // the Post&Comment nodes transition to the active state. - active = true; - } else { - //Forward the score and the identity of the interacting person to the creator person node. - Object interPerson = msg.getField(1, LongType.INSTANCE); - for (RowEdge creatorEdge : hasCreatorEdges) { - context.sendMessage(creatorEdge.getTargetId(), - ObjectRow.create(score, interPerson)); - } - } - } - if (active) { - //Generate scores and interacting person IDs, and send them downstream. - double score = vertex.getLabel().equals(postType) ? 1.0 : 0.5; - List replyOfEdges = loadEdges(vertex, replyOfType, EdgeDirection.IN); - if (hasCreatorEdges.size() > 0 && replyOfEdges.size() > 0) { - for (RowEdge creatorEdge : hasCreatorEdges) { - for (RowEdge replyEdge : replyOfEdges) { - context.sendMessage(replyEdge.getTargetId(), - ObjectRow.create(score, creatorEdge.getTargetId())); - } - } - } - //Activate downstream Comment nodes - for (RowEdge replyEdge : replyOfEdges) { - context.sendMessage(replyEdge.getTargetId(), ObjectRow.create(-1.0, 0L)); - } - } - break; - case personType: - List knowsEdges = loadEdges(vertex, knowsType, EdgeDirection.BOTH); - Map knowsPersonId2InteractionScore = new HashMap<>(); - if (context.getCurrentIterationId() == 1) { - for (RowEdge e : knowsEdges) { - knowsPersonId2InteractionScore.put((Long) e.getTargetId(), 0.0); - } - } else { - Map valuesStoreMap = decodeObjectRowAsMap(vertex.getValue(), - LongType.INSTANCE, DoubleType.INSTANCE); - knowsPersonId2InteractionScore.putAll(valuesStoreMap); - } - //Aggregate the interaction scores from upstream - while (messages.hasNext()) { - ObjectRow msg = messages.next(); - double score = (double) msg.getField(0, DoubleType.INSTANCE); - if (score > 0) { - Long interPerson = (long) msg.getField(1, LongType.INSTANCE); - if (knowsPersonId2InteractionScore.containsKey(interPerson)) { - knowsPersonId2InteractionScore.put(interPerson, - knowsPersonId2InteractionScore.get(interPerson) + score); - } - } - } - ObjectRow mapEncodeRow = encodeMapAsObjectRow(knowsPersonId2InteractionScore); - context.updateVertexValue(mapEncodeRow); - break; - default: + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + updatedValues.ifPresent(vertex::setValue); + // Stage 1 propagation + if (context.getCurrentIterationId() < propagationIterations) { + // Send heartbeat messages to keep nodes alive + if (vertex.getLabel().equals(personType)) { + context.sendMessage(vertex.getId(), ObjectRow.create(-1.0, 0L)); + } + switch (vertex.getLabel()) { + case forumType: + long creationDate = (long) vertex.getValue().getField(0, LongType.INSTANCE); + if (creationDate >= startDate && creationDate <= endDate) { + List containerOfEdges = loadEdges(vertex, containerOfType, EdgeDirection.OUT); + // The forum sends activation message to the post contained + for (RowEdge e : containerOfEdges) { + context.sendMessage(e.getTargetId(), ObjectRow.create(-1.0, 0L)); } - } else if (context.getCurrentIterationId() < pathSearchIterations) { - //Stage 2 Bidirectional dijkstra path searching over Person nodes - //Send heartbeat messages to keep nodes alive - if (vertex.getLabel().equals(personType)) { - context.sendMessage(vertex.getId(), ObjectRow.create(0L, 0.0, gigantic, gigantic)); + } + break; + case postType: + case commentType: + List hasCreatorEdges = loadEdges(vertex, hasCreatorType, EdgeDirection.OUT); + boolean active = false; + while (messages.hasNext()) { + ObjectRow msg = messages.next(); + double score = (double) msg.getField(0, DoubleType.INSTANCE); + if (score < 0) { + // If an activation message is received from the upstream reply chain, + // the Post&Comment nodes transition to the active state. + active = true; + } else { + // Forward the score and the identity of the interacting person to the creator person + // node. + Object interPerson = msg.getField(1, LongType.INSTANCE); + for (RowEdge creatorEdge : hasCreatorEdges) { + context.sendMessage( + creatorEdge.getTargetId(), ObjectRow.create(score, interPerson)); + } } - if (personType.equals(vertex.getLabel())) { - Map valuesStoreMap = decodeObjectRowAsMap(vertex.getValue(), LongType.INSTANCE, - DoubleType.INSTANCE); - Map knowsPersonId2InteractionScore = new HashMap<>(valuesStoreMap); - boolean valueChanged = false; - Object vId = vertex.getId(); - double currentDistanceToLeft; - double currentDistanceToRight; - if (context.getCurrentIterationId() - propagationIterations == 0L) { - currentDistanceToLeft = - Objects.equals(vId, leftSourceVertexId) ? 0.0 : gigantic; - currentDistanceToRight = - Objects.equals(vId, rightSourceVertexId) ? 0.0 : gigantic; - valueChanged = true; - } else { - //Merge the message with the interaction scores saved locally, and calculate - // the distance value. - int mapSize = (int) vertex.getValue().getField(0, IntegerType.INSTANCE); - currentDistanceToLeft = (double) vertex.getValue() - .getField(1 + 2 * mapSize, DoubleType.INSTANCE); - currentDistanceToRight = (double) vertex.getValue() - .getField(1 + 2 * mapSize + 1, DoubleType.INSTANCE); - //Msg schema: Person.id BIGINT, interactionScore DOUBLE, - // leftSourceDistance DOUBLE, rightSourceDistance DOUBLE - while (messages.hasNext()) { - ObjectRow msg = messages.next(); - Long personId = (Long) msg.getField(0, LongType.INSTANCE); - double interactionScore = (double) msg.getField(1, DoubleType.INSTANCE); - double leftDistance = (double) msg.getField(2, DoubleType.INSTANCE); - double rightDistance = (double) msg.getField(3, DoubleType.INSTANCE); - if (leftDistance >= giganticThreshold - && rightDistance >= giganticThreshold) { - continue; - } - if (knowsPersonId2InteractionScore.containsKey(personId)) { - interactionScore += knowsPersonId2InteractionScore.get(personId); - } - double newDeltaDistance = (1.0 / (interactionScore + 1.0)); - leftDistance += newDeltaDistance; - rightDistance += newDeltaDistance; - if (leftDistance < currentDistanceToLeft) { - currentDistanceToLeft = leftDistance; - valueChanged = true; - } - if (rightDistance < currentDistanceToRight) { - currentDistanceToRight = rightDistance; - valueChanged = true; - } - } - } - context.updateVertexValue( - encodeMapAsObjectRow(knowsPersonId2InteractionScore, currentDistanceToLeft, - currentDistanceToRight)); - if (valueChanged) { - List knowsEdges = loadEdges(vertex, knowsType, EdgeDirection.BOTH); - for (RowEdge e : knowsEdges) { - context.sendMessage(e.getTargetId(), - ObjectRow.create(vId, knowsPersonId2InteractionScore.get( - (Long) e.getTargetId()), currentDistanceToLeft, currentDistanceToRight)); - } + } + if (active) { + // Generate scores and interacting person IDs, and send them downstream. + double score = vertex.getLabel().equals(postType) ? 1.0 : 0.5; + List replyOfEdges = loadEdges(vertex, replyOfType, EdgeDirection.IN); + if (hasCreatorEdges.size() > 0 && replyOfEdges.size() > 0) { + for (RowEdge creatorEdge : hasCreatorEdges) { + for (RowEdge replyEdge : replyOfEdges) { + context.sendMessage( + replyEdge.getTargetId(), ObjectRow.create(score, creatorEdge.getTargetId())); } + } + } + // Activate downstream Comment nodes + for (RowEdge replyEdge : replyOfEdges) { + context.sendMessage(replyEdge.getTargetId(), ObjectRow.create(-1.0, 0L)); + } + } + break; + case personType: + List knowsEdges = loadEdges(vertex, knowsType, EdgeDirection.BOTH); + Map knowsPersonId2InteractionScore = new HashMap<>(); + if (context.getCurrentIterationId() == 1) { + for (RowEdge e : knowsEdges) { + knowsPersonId2InteractionScore.put((Long) e.getTargetId(), 0.0); } + } else { + Map valuesStoreMap = + decodeObjectRowAsMap(vertex.getValue(), LongType.INSTANCE, DoubleType.INSTANCE); + knowsPersonId2InteractionScore.putAll(valuesStoreMap); + } + // Aggregate the interaction scores from upstream + while (messages.hasNext()) { + ObjectRow msg = messages.next(); + double score = (double) msg.getField(0, DoubleType.INSTANCE); + if (score > 0) { + Long interPerson = (long) msg.getField(1, LongType.INSTANCE); + if (knowsPersonId2InteractionScore.containsKey(interPerson)) { + knowsPersonId2InteractionScore.put( + interPerson, knowsPersonId2InteractionScore.get(interPerson) + score); + } + } + } + ObjectRow mapEncodeRow = encodeMapAsObjectRow(knowsPersonId2InteractionScore); + context.updateVertexValue(mapEncodeRow); + break; + default: + } + } else if (context.getCurrentIterationId() < pathSearchIterations) { + // Stage 2 Bidirectional dijkstra path searching over Person nodes + // Send heartbeat messages to keep nodes alive + if (vertex.getLabel().equals(personType)) { + context.sendMessage(vertex.getId(), ObjectRow.create(0L, 0.0, gigantic, gigantic)); + } + if (personType.equals(vertex.getLabel())) { + Map valuesStoreMap = + decodeObjectRowAsMap(vertex.getValue(), LongType.INSTANCE, DoubleType.INSTANCE); + Map knowsPersonId2InteractionScore = new HashMap<>(valuesStoreMap); + boolean valueChanged = false; + Object vId = vertex.getId(); + double currentDistanceToLeft; + double currentDistanceToRight; + if (context.getCurrentIterationId() - propagationIterations == 0L) { + currentDistanceToLeft = Objects.equals(vId, leftSourceVertexId) ? 0.0 : gigantic; + currentDistanceToRight = Objects.equals(vId, rightSourceVertexId) ? 0.0 : gigantic; + valueChanged = true; } else { - int mapSize = (int) vertex.getValue().getField(0, IntegerType.INSTANCE); - double currentDistanceToLeft = (double) vertex.getValue().getField( - 1 + 2 * mapSize, DoubleType.INSTANCE); - double currentDistanceToRight = (double) vertex.getValue().getField( - 1 + 2 * mapSize + 1, DoubleType.INSTANCE); - if (currentDistanceToLeft + currentDistanceToRight < giganticThreshold) { - context.take(ObjectRow.create(currentDistanceToLeft + currentDistanceToRight)); + // Merge the message with the interaction scores saved locally, and calculate + // the distance value. + int mapSize = (int) vertex.getValue().getField(0, IntegerType.INSTANCE); + currentDistanceToLeft = + (double) vertex.getValue().getField(1 + 2 * mapSize, DoubleType.INSTANCE); + currentDistanceToRight = + (double) vertex.getValue().getField(1 + 2 * mapSize + 1, DoubleType.INSTANCE); + // Msg schema: Person.id BIGINT, interactionScore DOUBLE, + // leftSourceDistance DOUBLE, rightSourceDistance DOUBLE + while (messages.hasNext()) { + ObjectRow msg = messages.next(); + Long personId = (Long) msg.getField(0, LongType.INSTANCE); + double interactionScore = (double) msg.getField(1, DoubleType.INSTANCE); + double leftDistance = (double) msg.getField(2, DoubleType.INSTANCE); + double rightDistance = (double) msg.getField(3, DoubleType.INSTANCE); + if (leftDistance >= giganticThreshold && rightDistance >= giganticThreshold) { + continue; + } + if (knowsPersonId2InteractionScore.containsKey(personId)) { + interactionScore += knowsPersonId2InteractionScore.get(personId); + } + double newDeltaDistance = (1.0 / (interactionScore + 1.0)); + leftDistance += newDeltaDistance; + rightDistance += newDeltaDistance; + if (leftDistance < currentDistanceToLeft) { + currentDistanceToLeft = leftDistance; + valueChanged = true; } + if (rightDistance < currentDistanceToRight) { + currentDistanceToRight = rightDistance; + valueChanged = true; + } + } } - + context.updateVertexValue( + encodeMapAsObjectRow( + knowsPersonId2InteractionScore, currentDistanceToLeft, currentDistanceToRight)); + if (valueChanged) { + List knowsEdges = loadEdges(vertex, knowsType, EdgeDirection.BOTH); + for (RowEdge e : knowsEdges) { + context.sendMessage( + e.getTargetId(), + ObjectRow.create( + vId, + knowsPersonId2InteractionScore.get((Long) e.getTargetId()), + currentDistanceToLeft, + currentDistanceToRight)); + } + } + } + } else { + int mapSize = (int) vertex.getValue().getField(0, IntegerType.INSTANCE); + double currentDistanceToLeft = + (double) vertex.getValue().getField(1 + 2 * mapSize, DoubleType.INSTANCE); + double currentDistanceToRight = + (double) vertex.getValue().getField(1 + 2 * mapSize + 1, DoubleType.INSTANCE); + if (currentDistanceToLeft + currentDistanceToRight < giganticThreshold) { + context.take(ObjectRow.create(currentDistanceToLeft + currentDistanceToRight)); + } } + } - @Override - public StructType getOutputType(GraphSchema graphSchema) { - return new StructType( - new TableField("distance", DoubleType.INSTANCE, false) - ); - } + @Override + public StructType getOutputType(GraphSchema graphSchema) { + return new StructType(new TableField("distance", DoubleType.INSTANCE, false)); + } - private List loadEdges(RowVertex vertex, String edgeLabel, EdgeDirection direction) { - if (!vertex.equals(vertexCache)) { - vertexEdgesCache.clear(); - vertexEdgesCache = context.loadEdges(EdgeDirection.BOTH); - vertexCache = vertex; - } - List results = new ArrayList<>(); - for (RowEdge e : vertexEdgesCache) { - if (e.getLabel().equals(edgeLabel) - && (direction == EdgeDirection.BOTH || e.getDirect() == direction)) { - results.add(e); - } - } - return results; + private List loadEdges(RowVertex vertex, String edgeLabel, EdgeDirection direction) { + if (!vertex.equals(vertexCache)) { + vertexEdgesCache.clear(); + vertexEdgesCache = context.loadEdges(EdgeDirection.BOTH); + vertexCache = vertex; } - - private static ObjectRow encodeMapAsObjectRow(Map map, Object... originValues) { - int originValuesLength = originValues != null ? originValues.length : 0; - Object[] values = new Object[map.keySet().size() * 2 + 1 + originValuesLength]; - values[0] = map.keySet().size(); - int index = 0; - for (Object key : map.keySet()) { - values[1 + index * 2] = key; - values[1 + index * 2 + 1] = map.get(key); - ++index; - } - for (index = 0; index < originValuesLength; index++) { - values[values.length - originValuesLength + index] = originValues[index]; - } - return ObjectRow.create((Object[]) values); + List results = new ArrayList<>(); + for (RowEdge e : vertexEdgesCache) { + if (e.getLabel().equals(edgeLabel) + && (direction == EdgeDirection.BOTH || e.getDirect() == direction)) { + results.add(e); + } } + return results; + } - private static Map decodeObjectRowAsMap(Row objectRow, IType keyType, IType valueType) { - int size = (int) objectRow.getField(0, IntegerType.INSTANCE); - Map result = new HashMap(); - for (int i = 0; i < size; i++) { - Object key = objectRow.getField(1 + i * 2, keyType); - Object value = objectRow.getField(1 + i * 2 + 1, valueType); - result.put(key, value); - } - return result; + private static ObjectRow encodeMapAsObjectRow(Map map, Object... originValues) { + int originValuesLength = originValues != null ? originValues.length : 0; + Object[] values = new Object[map.keySet().size() * 2 + 1 + originValuesLength]; + values[0] = map.keySet().size(); + int index = 0; + for (Object key : map.keySet()) { + values[1 + index * 2] = key; + values[1 + index * 2 + 1] = map.get(key); + ++index; } + for (index = 0; index < originValuesLength; index++) { + values[values.length - originValuesLength + index] = originValues[index]; + } + return ObjectRow.create((Object[]) values); + } - @Override - public void finish(RowVertex vertex, Optional newValue) { + private static Map decodeObjectRowAsMap(Row objectRow, IType keyType, IType valueType) { + int size = (int) objectRow.getField(0, IntegerType.INSTANCE); + Map result = new HashMap(); + for (int i = 0; i < size; i++) { + Object key = objectRow.getField(1 + i * 2, keyType); + Object value = objectRow.getField(1 + i * 2 + 1, valueType); + result.put(key, value); } + return result; + } + + @Override + public void finish(RowVertex vertex, Optional newValue) {} } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/ldbc/BIRecruitmentAlgorithm.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/ldbc/BIRecruitmentAlgorithm.java index 5e92c38ec..3ac8ffbba 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/ldbc/BIRecruitmentAlgorithm.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/ldbc/BIRecruitmentAlgorithm.java @@ -27,6 +27,7 @@ import java.util.Map.Entry; import java.util.Objects; import java.util.Optional; + import org.apache.geaflow.common.type.primitive.LongType; import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction; @@ -44,97 +45,93 @@ @Description(name = "bi20_recruitment", description = "LDBC BI20 Recruitment Algorithm") public class BIRecruitmentAlgorithm implements AlgorithmUserFunction { - private AlgorithmRuntimeContext context; - private Object sourceVertexId; - private final int maxIteration = 30; - private final String knowsType = "knows"; - private final String studyAtType = "studyAt"; - private final String personType = "Person"; + private AlgorithmRuntimeContext context; + private Object sourceVertexId; + private final int maxIteration = 30; + private final String knowsType = "knows"; + private final String studyAtType = "studyAt"; + private final String personType = "Person"; - @Override - public void init(AlgorithmRuntimeContext context, Object[] parameters) { - this.context = context; - assert parameters.length >= 1 : "SSSP algorithm need source vid parameter."; - sourceVertexId = TypeCastUtil.cast(parameters[0], context.getGraphSchema().getIdType()); - assert sourceVertexId != null : "Source vid cannot be null for SSSP."; - } + @Override + public void init(AlgorithmRuntimeContext context, Object[] parameters) { + this.context = context; + assert parameters.length >= 1 : "SSSP algorithm need source vid parameter."; + sourceVertexId = TypeCastUtil.cast(parameters[0], context.getGraphSchema().getIdType()); + assert sourceVertexId != null : "Source vid cannot be null for SSSP."; + } - @Override - public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { - if (!vertex.getLabel().equals(personType)) { - return; - } - Object vId = vertex.getId(); - updatedValues.ifPresent(vertex::setValue); - List outEdges = context.loadEdges(EdgeDirection.BOTH); - List sendMsgTargetIds = new ArrayList<>(); - Map university2ClassYear = new HashMap<>(); - for (RowEdge edge : outEdges) { - if (edge.getLabel().equals(knowsType)) { - sendMsgTargetIds.add(edge.getTargetId()); - } else if (edge.getLabel().equals(studyAtType)) { - university2ClassYear.put(edge.getTargetId(), - (Long) edge.getValue().getField(0, LongType.INSTANCE)); - } - } + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + if (!vertex.getLabel().equals(personType)) { + return; + } + Object vId = vertex.getId(); + updatedValues.ifPresent(vertex::setValue); + List outEdges = context.loadEdges(EdgeDirection.BOTH); + List sendMsgTargetIds = new ArrayList<>(); + Map university2ClassYear = new HashMap<>(); + for (RowEdge edge : outEdges) { + if (edge.getLabel().equals(knowsType)) { + sendMsgTargetIds.add(edge.getTargetId()); + } else if (edge.getLabel().equals(studyAtType)) { + university2ClassYear.put( + edge.getTargetId(), (Long) edge.getValue().getField(0, LongType.INSTANCE)); + } + } - Long currentDistance; - if (context.getCurrentIterationId() == 1L) { - if (Objects.equals(vId, sourceVertexId)) { - currentDistance = 0L; - } else { - currentDistance = Long.MAX_VALUE; - } - } else if (context.getCurrentIterationId() <= maxIteration) { - currentDistance = (Long) vertex.getValue().getField(0, LongType.INSTANCE); - //Msg schema: Person.id BIGINT, distance BIGINT, University.id BIGINT, classYear BIGINT - while (messages.hasNext()) { - ObjectRow msg = messages.next(); - Long inputDistance = (Long) msg.getField(1, LongType.INSTANCE); - Long universityId = (Long) msg.getField(2, LongType.INSTANCE); - Long classYear = (Long) msg.getField(3, LongType.INSTANCE); - Long newDistance = Long.MAX_VALUE; - if (inputDistance != Long.MAX_VALUE && university2ClassYear.containsKey(universityId)) { - newDistance = inputDistance + 1L - + Math.abs(university2ClassYear.get(universityId) - classYear); - } - if (newDistance < currentDistance) { - currentDistance = newDistance; - } - } - } else { - currentDistance = (long) vertex.getValue().getField(0, LongType.INSTANCE); - if (!vId.equals(sourceVertexId)) { - context.take(ObjectRow.create(TypeCastUtil.cast(vId, LongType.INSTANCE), currentDistance)); - } - return; + Long currentDistance; + if (context.getCurrentIterationId() == 1L) { + if (Objects.equals(vId, sourceVertexId)) { + currentDistance = 0L; + } else { + currentDistance = Long.MAX_VALUE; + } + } else if (context.getCurrentIterationId() <= maxIteration) { + currentDistance = (Long) vertex.getValue().getField(0, LongType.INSTANCE); + // Msg schema: Person.id BIGINT, distance BIGINT, University.id BIGINT, classYear BIGINT + while (messages.hasNext()) { + ObjectRow msg = messages.next(); + Long inputDistance = (Long) msg.getField(1, LongType.INSTANCE); + Long universityId = (Long) msg.getField(2, LongType.INSTANCE); + Long classYear = (Long) msg.getField(3, LongType.INSTANCE); + Long newDistance = Long.MAX_VALUE; + if (inputDistance != Long.MAX_VALUE && university2ClassYear.containsKey(universityId)) { + newDistance = + inputDistance + 1L + Math.abs(university2ClassYear.get(universityId) - classYear); } - context.updateVertexValue(ObjectRow.create(currentDistance)); - //Send active heartbeat message - context.sendMessage(vId, ObjectRow.create(new Object[]{0L, Long.MAX_VALUE, 0L, 0L})); - //Scatter - //Msg schema: Person.id BIGINT, distance BIGINT, University.id BIGINT, classYear BIGINT - for (Object targetId : sendMsgTargetIds) { - for (Entry universityMsg : university2ClassYear.entrySet()) { - context.sendMessage(targetId, ObjectRow.create( - vId, - currentDistance, - universityMsg.getKey(), - universityMsg.getValue() - )); - } + if (newDistance < currentDistance) { + currentDistance = newDistance; } + } + } else { + currentDistance = (long) vertex.getValue().getField(0, LongType.INSTANCE); + if (!vId.equals(sourceVertexId)) { + context.take(ObjectRow.create(TypeCastUtil.cast(vId, LongType.INSTANCE), currentDistance)); + } + return; } - - @Override - public StructType getOutputType(GraphSchema graphSchema) { - return new StructType( - new TableField("id", LongType.INSTANCE, false), - new TableField("distance", LongType.INSTANCE, false) - ); + context.updateVertexValue(ObjectRow.create(currentDistance)); + // Send active heartbeat message + context.sendMessage(vId, ObjectRow.create(new Object[] {0L, Long.MAX_VALUE, 0L, 0L})); + // Scatter + // Msg schema: Person.id BIGINT, distance BIGINT, University.id BIGINT, classYear BIGINT + for (Object targetId : sendMsgTargetIds) { + for (Entry universityMsg : university2ClassYear.entrySet()) { + context.sendMessage( + targetId, + ObjectRow.create( + vId, currentDistance, universityMsg.getKey(), universityMsg.getValue())); + } } + } - @Override - public void finish(RowVertex vertex, Optional newValue) { - } + @Override + public StructType getOutputType(GraphSchema graphSchema) { + return new StructType( + new TableField("id", LongType.INSTANCE, false), + new TableField("distance", LongType.INSTANCE, false)); + } + + @Override + public void finish(RowVertex vertex, Optional newValue) {} } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/AvgDouble.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/AvgDouble.java index 48dc4af67..db36ae699 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/AvgDouble.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/AvgDouble.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.agg; import java.io.Serializable; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDAF; import org.apache.geaflow.dsl.udf.table.agg.AvgDouble.Accumulator; @@ -27,40 +28,40 @@ @Description(name = "avg", description = "The avg function for double input.") public class AvgDouble extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(); + } - @Override - public void accumulate(Accumulator accumulator, Double input) { - if (null != input) { - accumulator.sum += input; - accumulator.count++; - } + @Override + public void accumulate(Accumulator accumulator, Double input) { + if (null != input) { + accumulator.sum += input; + accumulator.count++; } + } - @Override - public void merge(Accumulator merged, Iterable accumulators) { - for (Accumulator accumulator : accumulators) { - merged.sum += accumulator.sum; - merged.count += accumulator.count; - } + @Override + public void merge(Accumulator merged, Iterable accumulators) { + for (Accumulator accumulator : accumulators) { + merged.sum += accumulator.sum; + merged.count += accumulator.count; } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.sum = 0.0; - accumulator.count = 0L; - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.sum = 0.0; + accumulator.count = 0L; + } - @Override - public Double getValue(Accumulator accumulator) { - return accumulator.count == 0 ? null : (accumulator.sum / (double) accumulator.count); - } + @Override + public Double getValue(Accumulator accumulator) { + return accumulator.count == 0 ? null : (accumulator.sum / (double) accumulator.count); + } - public static class Accumulator implements Serializable { - public double sum = 0.0; - public long count = 0; - } + public static class Accumulator implements Serializable { + public double sum = 0.0; + public long count = 0; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/AvgInteger.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/AvgInteger.java index a9a2803a9..5781e2c0e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/AvgInteger.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/AvgInteger.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.agg; import java.io.Serializable; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDAF; import org.apache.geaflow.dsl.udf.table.agg.AvgInteger.Accumulator; @@ -27,40 +28,40 @@ @Description(name = "avg", description = "The avg function for int input.") public class AvgInteger extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(); + } - @Override - public void accumulate(Accumulator accumulator, Integer input) { - if (null != input) { - accumulator.sum += input; - accumulator.count++; - } + @Override + public void accumulate(Accumulator accumulator, Integer input) { + if (null != input) { + accumulator.sum += input; + accumulator.count++; } + } - @Override - public void merge(Accumulator merged, Iterable accumulators) { - for (Accumulator accumulator : accumulators) { - merged.sum += accumulator.sum; - merged.count += accumulator.count; - } + @Override + public void merge(Accumulator merged, Iterable accumulators) { + for (Accumulator accumulator : accumulators) { + merged.sum += accumulator.sum; + merged.count += accumulator.count; } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.sum = 0.0; - accumulator.count = 0L; - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.sum = 0.0; + accumulator.count = 0L; + } - @Override - public Double getValue(Accumulator accumulator) { - return accumulator.count == 0 ? null : (accumulator.sum / (double) accumulator.count); - } + @Override + public Double getValue(Accumulator accumulator) { + return accumulator.count == 0 ? null : (accumulator.sum / (double) accumulator.count); + } - public static class Accumulator implements Serializable { - public double sum = 0.0; - public long count = 0; - } + public static class Accumulator implements Serializable { + public double sum = 0.0; + public long count = 0; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/AvgLong.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/AvgLong.java index 8b474df5e..7e2942730 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/AvgLong.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/AvgLong.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.agg; import java.io.Serializable; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDAF; import org.apache.geaflow.dsl.udf.table.agg.AvgLong.Accumulator; @@ -27,40 +28,40 @@ @Description(name = "avg", description = "The avg function for bigint input.") public class AvgLong extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(); + } - @Override - public void accumulate(Accumulator accumulator, Long input) { - if (null != input) { - accumulator.sum += input; - accumulator.count++; - } + @Override + public void accumulate(Accumulator accumulator, Long input) { + if (null != input) { + accumulator.sum += input; + accumulator.count++; } + } - @Override - public void merge(Accumulator merged, Iterable accumulators) { - for (Accumulator accumulator : accumulators) { - merged.sum += accumulator.sum; - merged.count += accumulator.count; - } + @Override + public void merge(Accumulator merged, Iterable accumulators) { + for (Accumulator accumulator : accumulators) { + merged.sum += accumulator.sum; + merged.count += accumulator.count; } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.sum = 0.0; - accumulator.count = 0L; - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.sum = 0.0; + accumulator.count = 0L; + } - @Override - public Double getValue(Accumulator accumulator) { - return accumulator.count == 0 ? null : (accumulator.sum / (double) accumulator.count); - } + @Override + public Double getValue(Accumulator accumulator) { + return accumulator.count == 0 ? null : (accumulator.sum / (double) accumulator.count); + } - public static class Accumulator implements Serializable { - public double sum = 0.0; - public long count = 0; - } + public static class Accumulator implements Serializable { + public double sum = 0.0; + public long count = 0; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/Count.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/Count.java index d123e3cb7..3df05c076 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/Count.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/Count.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.agg; import java.io.Serializable; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDAF; import org.apache.geaflow.dsl.udf.table.agg.Count.Accumulator; @@ -27,51 +28,48 @@ @Description(name = "count", description = "The count function.") public class Count extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(0); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(0); + } - @Override - public void accumulate(Accumulator accumulator, Object input) { - if (input != null) { - accumulator.value++; - } + @Override + public void accumulate(Accumulator accumulator, Object input) { + if (input != null) { + accumulator.value++; } + } - @Override - public void merge(Accumulator accumulator, Iterable its) { - for (Accumulator toMerge : its) { - accumulator.value += toMerge.value; - } + @Override + public void merge(Accumulator accumulator, Iterable its) { + for (Accumulator toMerge : its) { + accumulator.value += toMerge.value; } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.value = 0L; - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.value = 0L; + } - @Override - public Long getValue(Accumulator accumulator) { - return accumulator.value; - } + @Override + public Long getValue(Accumulator accumulator) { + return accumulator.value; + } - public static class Accumulator implements Serializable { + public static class Accumulator implements Serializable { - public Accumulator() { - } + public Accumulator() {} - public long value = 0; + public long value = 0; - public Accumulator(long value) { - this.value = value; - } + public Accumulator(long value) { + this.value = value; + } - @Override - public String toString() { - return "Accumulator{" - + "value=" + value - + '}'; - } + @Override + public String toString() { + return "Accumulator{" + "value=" + value + '}'; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MaxBinaryString.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MaxBinaryString.java index 9fd46a49a..72c356904 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MaxBinaryString.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MaxBinaryString.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.agg; import java.io.Serializable; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDAF; @@ -27,47 +28,45 @@ @Description(name = "max", description = "The max function for string.") public class MaxBinaryString extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(null); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(null); + } - @Override - public void accumulate(Accumulator accumulator, BinaryString input) { - if (null == accumulator.value - || (null != input && accumulator.value.compareTo(input) < 0)) { - accumulator.value = input; - } + @Override + public void accumulate(Accumulator accumulator, BinaryString input) { + if (null == accumulator.value || (null != input && accumulator.value.compareTo(input) < 0)) { + accumulator.value = input; } + } - @Override - public void merge(Accumulator accumulator, Iterable its) { - for (Accumulator toMerge : its) { - if (accumulator.value == null - || (null != toMerge.value && accumulator.value.compareTo(toMerge.value) < 0)) { - accumulator.value = toMerge.value; - } - } + @Override + public void merge(Accumulator accumulator, Iterable its) { + for (Accumulator toMerge : its) { + if (accumulator.value == null + || (null != toMerge.value && accumulator.value.compareTo(toMerge.value) < 0)) { + accumulator.value = toMerge.value; + } } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.value = null; - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.value = null; + } - @Override - public BinaryString getValue(Accumulator accumulator) { - return accumulator.value; - } + @Override + public BinaryString getValue(Accumulator accumulator) { + return accumulator.value; + } - public static class Accumulator implements Serializable { - public Accumulator() { - } + public static class Accumulator implements Serializable { + public Accumulator() {} - public BinaryString value; + public BinaryString value; - public Accumulator(BinaryString value) { - this.value = value; - } + public Accumulator(BinaryString value) { + this.value = value; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MaxDouble.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MaxDouble.java index ab589f1ea..b033719b5 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MaxDouble.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MaxDouble.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.agg; import java.io.Serializable; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDAF; import org.apache.geaflow.dsl.udf.table.agg.MaxDouble.Accumulator; @@ -27,47 +28,45 @@ @Description(name = "max", description = "The max function for double.") public class MaxDouble extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(null); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(null); + } - @Override - public void accumulate(Accumulator accumulator, Double input) { - if (null == accumulator.value - || (null != input && accumulator.value.compareTo(input) < 0)) { - accumulator.value = input; - } + @Override + public void accumulate(Accumulator accumulator, Double input) { + if (null == accumulator.value || (null != input && accumulator.value.compareTo(input) < 0)) { + accumulator.value = input; } + } - @Override - public void merge(Accumulator accumulator, Iterable its) { - for (Accumulator toMerge : its) { - if (accumulator.value == null - || (null != toMerge.value && accumulator.value.compareTo(toMerge.value) < 0)) { - accumulator.value = toMerge.value; - } - } + @Override + public void merge(Accumulator accumulator, Iterable its) { + for (Accumulator toMerge : its) { + if (accumulator.value == null + || (null != toMerge.value && accumulator.value.compareTo(toMerge.value) < 0)) { + accumulator.value = toMerge.value; + } } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.value = null; - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.value = null; + } - @Override - public Double getValue(Accumulator accumulator) { - return accumulator.value; - } + @Override + public Double getValue(Accumulator accumulator) { + return accumulator.value; + } - public static class Accumulator implements Serializable { - public Accumulator() { - } + public static class Accumulator implements Serializable { + public Accumulator() {} - public Double value; + public Double value; - public Accumulator(Double value) { - this.value = value; - } + public Accumulator(Double value) { + this.value = value; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MaxInteger.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MaxInteger.java index 90843e256..21d830db7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MaxInteger.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MaxInteger.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.agg; import java.io.Serializable; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDAF; import org.apache.geaflow.dsl.udf.table.agg.MaxInteger.Accumulator; @@ -27,47 +28,45 @@ @Description(name = "max", description = "The max function for int.") public class MaxInteger extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(null); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(null); + } - @Override - public void accumulate(Accumulator accumulator, Integer input) { - if (null == accumulator.value - || (null != input && accumulator.value.compareTo(input) < 0)) { - accumulator.value = input; - } + @Override + public void accumulate(Accumulator accumulator, Integer input) { + if (null == accumulator.value || (null != input && accumulator.value.compareTo(input) < 0)) { + accumulator.value = input; } + } - @Override - public void merge(Accumulator accumulator, Iterable its) { - for (Accumulator toMerge : its) { - if (accumulator.value == null - || (null != toMerge.value && accumulator.value.compareTo(toMerge.value) < 0)) { - accumulator.value = toMerge.value; - } - } + @Override + public void merge(Accumulator accumulator, Iterable its) { + for (Accumulator toMerge : its) { + if (accumulator.value == null + || (null != toMerge.value && accumulator.value.compareTo(toMerge.value) < 0)) { + accumulator.value = toMerge.value; + } } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.value = null; - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.value = null; + } - @Override - public Integer getValue(Accumulator accumulator) { - return accumulator.value; - } + @Override + public Integer getValue(Accumulator accumulator) { + return accumulator.value; + } - public static class Accumulator implements Serializable { - public Accumulator() { - } + public static class Accumulator implements Serializable { + public Accumulator() {} - public Integer value; + public Integer value; - public Accumulator(Integer value) { - this.value = value; - } + public Accumulator(Integer value) { + this.value = value; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MaxLong.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MaxLong.java index ce08292cc..65021089d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MaxLong.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MaxLong.java @@ -20,53 +20,52 @@ package org.apache.geaflow.dsl.udf.table.agg; import java.io.Serializable; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDAF; import org.apache.geaflow.dsl.udf.table.agg.MaxLong.Accumulator; @Description(name = "max", description = "The max function for bigint.") public class MaxLong extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(null); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(null); + } - @Override - public void accumulate(Accumulator accumulator, Long input) { - if (null == accumulator.value - || (null != input && accumulator.value.compareTo(input) < 0)) { - accumulator.value = input; - } + @Override + public void accumulate(Accumulator accumulator, Long input) { + if (null == accumulator.value || (null != input && accumulator.value.compareTo(input) < 0)) { + accumulator.value = input; } + } - @Override - public void merge(Accumulator accumulator, Iterable its) { - for (Accumulator toMerge : its) { - if (accumulator.value == null - || (null != toMerge.value && accumulator.value.compareTo(toMerge.value) < 0)) { - accumulator.value = toMerge.value; - } - } + @Override + public void merge(Accumulator accumulator, Iterable its) { + for (Accumulator toMerge : its) { + if (accumulator.value == null + || (null != toMerge.value && accumulator.value.compareTo(toMerge.value) < 0)) { + accumulator.value = toMerge.value; + } } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.value = null; - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.value = null; + } - @Override - public Long getValue(Accumulator accumulator) { - return accumulator.value; - } + @Override + public Long getValue(Accumulator accumulator) { + return accumulator.value; + } - public static class Accumulator implements Serializable { - public Accumulator() { - } + public static class Accumulator implements Serializable { + public Accumulator() {} - public Long value; + public Long value; - public Accumulator(Long value) { - this.value = value; - } + public Accumulator(Long value) { + this.value = value; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MinBinaryString.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MinBinaryString.java index 4674a76c1..2d6006159 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MinBinaryString.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MinBinaryString.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.agg; import java.io.Serializable; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDAF; @@ -27,47 +28,45 @@ @Description(name = "min", description = "The min function for string.") public class MinBinaryString extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(null); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(null); + } - @Override - public void accumulate(Accumulator accumulator, BinaryString input) { - if (null == accumulator.value - || (null != input && accumulator.value.compareTo(input) > 0)) { - accumulator.value = input; - } + @Override + public void accumulate(Accumulator accumulator, BinaryString input) { + if (null == accumulator.value || (null != input && accumulator.value.compareTo(input) > 0)) { + accumulator.value = input; } + } - @Override - public void merge(Accumulator accumulator, Iterable its) { - for (Accumulator toMerge : its) { - if (accumulator.value == null - || (null != toMerge.value && accumulator.value.compareTo(toMerge.value) > 0)) { - accumulator.value = toMerge.value; - } - } + @Override + public void merge(Accumulator accumulator, Iterable its) { + for (Accumulator toMerge : its) { + if (accumulator.value == null + || (null != toMerge.value && accumulator.value.compareTo(toMerge.value) > 0)) { + accumulator.value = toMerge.value; + } } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.value = null; - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.value = null; + } - @Override - public BinaryString getValue(Accumulator accumulator) { - return accumulator.value; - } + @Override + public BinaryString getValue(Accumulator accumulator) { + return accumulator.value; + } - public static class Accumulator implements Serializable { - public Accumulator() { - } + public static class Accumulator implements Serializable { + public Accumulator() {} - public BinaryString value; + public BinaryString value; - public Accumulator(BinaryString value) { - this.value = value; - } + public Accumulator(BinaryString value) { + this.value = value; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MinDouble.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MinDouble.java index 79fc60a4c..4b8df6f04 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MinDouble.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MinDouble.java @@ -20,53 +20,52 @@ package org.apache.geaflow.dsl.udf.table.agg; import java.io.Serializable; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDAF; import org.apache.geaflow.dsl.udf.table.agg.MinDouble.Accumulator; @Description(name = "min", description = "The min function for double.") public class MinDouble extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(null); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(null); + } - @Override - public void accumulate(Accumulator accumulator, Double input) { - if (null == accumulator.value - || (null != input && accumulator.value.compareTo(input) > 0)) { - accumulator.value = input; - } + @Override + public void accumulate(Accumulator accumulator, Double input) { + if (null == accumulator.value || (null != input && accumulator.value.compareTo(input) > 0)) { + accumulator.value = input; } + } - @Override - public void merge(Accumulator accumulator, Iterable its) { - for (Accumulator toMerge : its) { - if (accumulator.value == null - || (null != toMerge.value && accumulator.value.compareTo(toMerge.value) > 0)) { - accumulator.value = toMerge.value; - } - } + @Override + public void merge(Accumulator accumulator, Iterable its) { + for (Accumulator toMerge : its) { + if (accumulator.value == null + || (null != toMerge.value && accumulator.value.compareTo(toMerge.value) > 0)) { + accumulator.value = toMerge.value; + } } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.value = null; - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.value = null; + } - @Override - public Double getValue(Accumulator accumulator) { - return accumulator.value; - } + @Override + public Double getValue(Accumulator accumulator) { + return accumulator.value; + } - public static class Accumulator implements Serializable { - public Accumulator() { - } + public static class Accumulator implements Serializable { + public Accumulator() {} - public Double value; + public Double value; - public Accumulator(Double value) { - this.value = value; - } + public Accumulator(Double value) { + this.value = value; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MinInteger.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MinInteger.java index 1221370c6..4f18c5f7d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MinInteger.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MinInteger.java @@ -20,53 +20,52 @@ package org.apache.geaflow.dsl.udf.table.agg; import java.io.Serializable; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDAF; import org.apache.geaflow.dsl.udf.table.agg.MinInteger.Accumulator; @Description(name = "min", description = "The min function for int.") public class MinInteger extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(null); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(null); + } - @Override - public void accumulate(Accumulator accumulator, Integer input) { - if (null == accumulator.value - || (null != input && accumulator.value.compareTo(input) > 0)) { - accumulator.value = input; - } + @Override + public void accumulate(Accumulator accumulator, Integer input) { + if (null == accumulator.value || (null != input && accumulator.value.compareTo(input) > 0)) { + accumulator.value = input; } + } - @Override - public void merge(Accumulator accumulator, Iterable its) { - for (Accumulator toMerge : its) { - if (accumulator.value == null - || (null != toMerge.value && accumulator.value.compareTo(toMerge.value) > 0)) { - accumulator.value = toMerge.value; - } - } + @Override + public void merge(Accumulator accumulator, Iterable its) { + for (Accumulator toMerge : its) { + if (accumulator.value == null + || (null != toMerge.value && accumulator.value.compareTo(toMerge.value) > 0)) { + accumulator.value = toMerge.value; + } } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.value = null; - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.value = null; + } - @Override - public Integer getValue(Accumulator accumulator) { - return accumulator.value; - } + @Override + public Integer getValue(Accumulator accumulator) { + return accumulator.value; + } - public static class Accumulator implements Serializable { - public Accumulator() { - } + public static class Accumulator implements Serializable { + public Accumulator() {} - public Integer value; + public Integer value; - public Accumulator(Integer value) { - this.value = value; - } + public Accumulator(Integer value) { + this.value = value; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MinLong.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MinLong.java index 43bfd52d6..bad481810 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MinLong.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/MinLong.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.agg; import java.io.Serializable; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDAF; import org.apache.geaflow.dsl.udf.table.agg.MinLong.Accumulator; @@ -27,48 +28,46 @@ @Description(name = "min", description = "The min function for bigint.") public class MinLong extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(null); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(null); + } - @Override - public void accumulate(Accumulator accumulator, Long input) { - if (null == accumulator.value - || (null != input && accumulator.value.compareTo(input) > 0)) { - accumulator.value = input; - } + @Override + public void accumulate(Accumulator accumulator, Long input) { + if (null == accumulator.value || (null != input && accumulator.value.compareTo(input) > 0)) { + accumulator.value = input; } + } - @Override - public void merge(Accumulator accumulator, Iterable its) { - for (Accumulator toMerge : its) { - if (accumulator.value == null - || (null != toMerge.value && accumulator.value.compareTo(toMerge.value) > 0)) { - accumulator.value = toMerge.value; - } - } + @Override + public void merge(Accumulator accumulator, Iterable its) { + for (Accumulator toMerge : its) { + if (accumulator.value == null + || (null != toMerge.value && accumulator.value.compareTo(toMerge.value) > 0)) { + accumulator.value = toMerge.value; + } } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.value = null; - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.value = null; + } - @Override - public Long getValue(Accumulator accumulator) { - return accumulator.value; - } + @Override + public Long getValue(Accumulator accumulator) { + return accumulator.value; + } - public static class Accumulator implements Serializable { + public static class Accumulator implements Serializable { - public Accumulator() { - } + public Accumulator() {} - public Long value; + public Long value; - public Accumulator(Long value) { - this.value = value; - } + public Accumulator(Long value) { + this.value = value; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/PercentileDouble.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/PercentileDouble.java index 5c2607645..8ec6342a4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/PercentileDouble.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/PercentileDouble.java @@ -23,6 +23,7 @@ import java.math.BigDecimal; import java.util.ArrayList; import java.util.List; + import org.apache.commons.math3.stat.StatUtils; import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; import org.apache.geaflow.dsl.common.function.Description; @@ -34,103 +35,101 @@ @Description(name = "percentile", description = "percentile agg function for double") public class PercentileDouble extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(); + } - @Override - public void accumulate(Accumulator accumulator, MultiArguments input) { - if (input != null) { - accumulator.setPercent(PercentileDouble.getPercent(input.getParam(1))); - accumulator.getValueList().add((double) input.getParam(0)); - } + @Override + public void accumulate(Accumulator accumulator, MultiArguments input) { + if (input != null) { + accumulator.setPercent(PercentileDouble.getPercent(input.getParam(1))); + accumulator.getValueList().add((double) input.getParam(0)); } - - public static double getPercent(Object percent) { - if (percent instanceof BigDecimal) { - return ((BigDecimal) percent).doubleValue(); - } else if (percent instanceof Double) { - return ((Double) percent); - } else if (percent instanceof Float) { - return ((Float) percent); - } else if (percent instanceof Long) { - return ((Long) percent).doubleValue(); - } else if (percent instanceof Integer) { - return ((Integer) percent).doubleValue(); - } - throw new GeaFlowDSLException("Percentile not support percent type: " + percent); + } + + public static double getPercent(Object percent) { + if (percent instanceof BigDecimal) { + return ((BigDecimal) percent).doubleValue(); + } else if (percent instanceof Double) { + return ((Double) percent); + } else if (percent instanceof Float) { + return ((Float) percent); + } else if (percent instanceof Long) { + return ((Long) percent).doubleValue(); + } else if (percent instanceof Integer) { + return ((Integer) percent).doubleValue(); } - - @Override - public void merge(Accumulator accumulator, Iterable its) { - for (Accumulator it : its) { - if (it != null) { - accumulator.getValueList().addAll(it.getValueList()); - accumulator.setPercent(it.getPercent()); - } - } + throw new GeaFlowDSLException("Percentile not support percent type: " + percent); + } + + @Override + public void merge(Accumulator accumulator, Iterable its) { + for (Accumulator it : its) { + if (it != null) { + accumulator.getValueList().addAll(it.getValueList()); + accumulator.setPercent(it.getPercent()); + } } - - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.setValueList(new ArrayList<>()); + } + + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.setValueList(new ArrayList<>()); + } + + @Override + public Double getValue(Accumulator accumulator) { + List valueList = accumulator.getValueList(); + double[] values = new double[valueList.size()]; + for (int i = 0; i < valueList.size(); i++) { + values[i] = valueList.get(i); } + return StatUtils.percentile(values, accumulator.getPercent()); + } - @Override - public Double getValue(Accumulator accumulator) { - List valueList = accumulator.getValueList(); - double[] values = new double[valueList.size()]; - for (int i = 0; i < valueList.size(); i++) { - values[i] = valueList.get(i); - } - return StatUtils.percentile(values, accumulator.getPercent()); - } - - public static class Accumulator implements Serializable { + public static class Accumulator implements Serializable { - private static final long serialVersionUID = 7024955653427528364L; + private static final long serialVersionUID = 7024955653427528364L; - private List valueList; - private double percent; + private List valueList; + private double percent; - public Accumulator() { - this.valueList = new ArrayList<>(); - } + public Accumulator() { + this.valueList = new ArrayList<>(); + } - public Accumulator(double value) { - this.valueList.add(value); - } + public Accumulator(double value) { + this.valueList.add(value); + } - public List getValueList() { - return valueList; - } + public List getValueList() { + return valueList; + } - public void setValueList(List valueList) { - this.valueList = valueList; - } + public void setValueList(List valueList) { + this.valueList = valueList; + } - public double getPercent() { - return percent; - } + public double getPercent() { + return percent; + } - public void setPercent(double percent) { - this.percent = percent; - } + public void setPercent(double percent) { + this.percent = percent; } + } - public static class MultiArguments extends UDAFArguments { + public static class MultiArguments extends UDAFArguments { - public MultiArguments() { - } + public MultiArguments() {} - @Override - public List> getParamTypes() { - List> types = new ArrayList<>(); - types.add(Double.class); - types.add(Double.class); - return types; - } + @Override + public List> getParamTypes() { + List> types = new ArrayList<>(); + types.add(Double.class); + types.add(Double.class); + return types; } + } } - diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/PercentileInteger.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/PercentileInteger.java index 5d8cb5fe3..34448f314 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/PercentileInteger.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/PercentileInteger.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.commons.math3.stat.StatUtils; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDAF; @@ -31,56 +32,54 @@ @Description(name = "percentile", description = "percentile agg function for integer") public class PercentileInteger extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(); + } - @Override - public void accumulate(Accumulator accumulator, MultiArguments input) { - if (input != null) { - accumulator.setPercent(PercentileDouble.getPercent(input.getParam(1))); - accumulator.getValueList().add((double) (int) input.getParam(0)); - } + @Override + public void accumulate(Accumulator accumulator, MultiArguments input) { + if (input != null) { + accumulator.setPercent(PercentileDouble.getPercent(input.getParam(1))); + accumulator.getValueList().add((double) (int) input.getParam(0)); } + } - @Override - public void merge(Accumulator accumulator, Iterable its) { - for (Accumulator it : its) { - if (it != null) { - accumulator.getValueList().addAll(it.getValueList()); - accumulator.setPercent(it.getPercent()); - } - } + @Override + public void merge(Accumulator accumulator, Iterable its) { + for (Accumulator it : its) { + if (it != null) { + accumulator.getValueList().addAll(it.getValueList()); + accumulator.setPercent(it.getPercent()); + } } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.setValueList(new ArrayList<>()); - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.setValueList(new ArrayList<>()); + } - @Override - public Double getValue(Accumulator accumulator) { - List valueList = accumulator.getValueList(); - double[] values = new double[valueList.size()]; - for (int i = 0; i < valueList.size(); i++) { - values[i] = valueList.get(i); - } - return StatUtils.percentile(values, accumulator.getPercent()); + @Override + public Double getValue(Accumulator accumulator) { + List valueList = accumulator.getValueList(); + double[] values = new double[valueList.size()]; + for (int i = 0; i < valueList.size(); i++) { + values[i] = valueList.get(i); } + return StatUtils.percentile(values, accumulator.getPercent()); + } - public static class MultiArguments extends UDAFArguments { + public static class MultiArguments extends UDAFArguments { - public MultiArguments() { - } + public MultiArguments() {} - @Override - public List> getParamTypes() { - List> types = new ArrayList<>(); - types.add(Integer.class); - types.add(Double.class); - return types; - } + @Override + public List> getParamTypes() { + List> types = new ArrayList<>(); + types.add(Integer.class); + types.add(Double.class); + return types; } + } } - diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/PercentileLong.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/PercentileLong.java index f2561677f..4c488e087 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/PercentileLong.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/PercentileLong.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.commons.math3.stat.StatUtils; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDAF; @@ -31,56 +32,54 @@ @Description(name = "percentile", description = "percentile agg function for long") public class PercentileLong extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(); + } - @Override - public void accumulate(Accumulator accumulator, MultiArguments input) { - if (input != null) { - accumulator.setPercent(PercentileDouble.getPercent(input.getParam(1))); - accumulator.getValueList().add((double) (long) input.getParam(0)); - } + @Override + public void accumulate(Accumulator accumulator, MultiArguments input) { + if (input != null) { + accumulator.setPercent(PercentileDouble.getPercent(input.getParam(1))); + accumulator.getValueList().add((double) (long) input.getParam(0)); } + } - @Override - public void merge(Accumulator accumulator, Iterable its) { - for (Accumulator it : its) { - if (it != null) { - accumulator.getValueList().addAll(it.getValueList()); - accumulator.setPercent(it.getPercent()); - } - } + @Override + public void merge(Accumulator accumulator, Iterable its) { + for (Accumulator it : its) { + if (it != null) { + accumulator.getValueList().addAll(it.getValueList()); + accumulator.setPercent(it.getPercent()); + } } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.setValueList(new ArrayList<>()); - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.setValueList(new ArrayList<>()); + } - @Override - public Double getValue(Accumulator accumulator) { - List valueList = accumulator.getValueList(); - double[] values = new double[valueList.size()]; - for (int i = 0; i < valueList.size(); i++) { - values[i] = valueList.get(i); - } - return StatUtils.percentile(values, accumulator.getPercent()); + @Override + public Double getValue(Accumulator accumulator) { + List valueList = accumulator.getValueList(); + double[] values = new double[valueList.size()]; + for (int i = 0; i < valueList.size(); i++) { + values[i] = valueList.get(i); } + return StatUtils.percentile(values, accumulator.getPercent()); + } - public static class MultiArguments extends UDAFArguments { + public static class MultiArguments extends UDAFArguments { - public MultiArguments() { - } + public MultiArguments() {} - @Override - public List> getParamTypes() { - List> types = new ArrayList<>(); - types.add(Long.class); - types.add(Double.class); - return types; - } + @Override + public List> getParamTypes() { + List> types = new ArrayList<>(); + types.add(Long.class); + types.add(Double.class); + return types; } + } } - diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/StdDevSampDouble.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/StdDevSampDouble.java index bea99af9b..5aa526743 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/StdDevSampDouble.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/StdDevSampDouble.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.agg; import java.io.Serializable; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDAF; import org.apache.geaflow.dsl.udf.table.agg.StdDevSampDouble.Accumulator; @@ -27,55 +28,56 @@ @Description(name = "stddev_samp", description = "The stddev function for double.") public class StdDevSampDouble extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(); + } - @Override - public void accumulate(Accumulator accumulator, Double input) { - if (null != input) { - accumulator.add(input); - } + @Override + public void accumulate(Accumulator accumulator, Double input) { + if (null != input) { + accumulator.add(input); } + } - @Override - public void merge(Accumulator merged, Iterable accumulators) { - for (Accumulator accumulator : accumulators) { - merged.squareSum += accumulator.squareSum; - merged.sum += accumulator.sum; - merged.count += accumulator.count; - } + @Override + public void merge(Accumulator merged, Iterable accumulators) { + for (Accumulator accumulator : accumulators) { + merged.squareSum += accumulator.squareSum; + merged.sum += accumulator.sum; + merged.count += accumulator.count; } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.squareSum = 0.0; - accumulator.sum = 0.0; - accumulator.count = 0L; - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.squareSum = 0.0; + accumulator.sum = 0.0; + accumulator.count = 0L; + } - @Override - public Double getValue(Accumulator accumulator) { - if (accumulator.count == 0) { - return 0.0; - } - return Math.sqrt((accumulator.squareSum - Math.pow(accumulator.sum, 2) / accumulator.count) / accumulator.count); + @Override + public Double getValue(Accumulator accumulator) { + if (accumulator.count == 0) { + return 0.0; } + return Math.sqrt( + (accumulator.squareSum - Math.pow(accumulator.sum, 2) / accumulator.count) + / accumulator.count); + } - public static class Accumulator implements Serializable { + public static class Accumulator implements Serializable { - double squareSum = 0.0; - double sum = 0.0; - long count = 0; + double squareSum = 0.0; + double sum = 0.0; + long count = 0; - public Accumulator() { - } + public Accumulator() {} - public void add(double input) { - sum += input; - squareSum += input * input; - count++; - } + public void add(double input) { + sum += input; + squareSum += input * input; + count++; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/StdDevSampInteger.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/StdDevSampInteger.java index 37639a150..9bae0772c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/StdDevSampInteger.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/StdDevSampInteger.java @@ -26,40 +26,41 @@ @Description(name = "stddev_samp", description = "The stddev function for Integer.") public class StdDevSampInteger extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(); + } - @Override - public void accumulate(Accumulator accumulator, Integer input) { - if (null != input) { - accumulator.add(input); - } + @Override + public void accumulate(Accumulator accumulator, Integer input) { + if (null != input) { + accumulator.add(input); } + } - @Override - public void merge(Accumulator merged, Iterable accumulators) { - for (Accumulator accumulator : accumulators) { - merged.squareSum += accumulator.squareSum; - merged.sum += accumulator.sum; - merged.count += accumulator.count; - } + @Override + public void merge(Accumulator merged, Iterable accumulators) { + for (Accumulator accumulator : accumulators) { + merged.squareSum += accumulator.squareSum; + merged.sum += accumulator.sum; + merged.count += accumulator.count; } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.squareSum = 0.0; - accumulator.sum = 0.0; - accumulator.count = 0L; - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.squareSum = 0.0; + accumulator.sum = 0.0; + accumulator.count = 0L; + } - @Override - public Double getValue(Accumulator accumulator) { - if (accumulator.count == 0) { - return 0.0; - } - return Math.sqrt((accumulator.squareSum - Math.pow(accumulator.sum, 2) / accumulator.count) / accumulator.count); + @Override + public Double getValue(Accumulator accumulator) { + if (accumulator.count == 0) { + return 0.0; } - + return Math.sqrt( + (accumulator.squareSum - Math.pow(accumulator.sum, 2) / accumulator.count) + / accumulator.count); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/StdDevSampLong.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/StdDevSampLong.java index 40fd7cb6c..c21d91fbb 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/StdDevSampLong.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/StdDevSampLong.java @@ -26,40 +26,41 @@ @Description(name = "stddev_samp", description = "The stddev function for Long.") public class StdDevSampLong extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(); + } - @Override - public void accumulate(Accumulator accumulator, Long input) { - if (null != input) { - accumulator.add(input); - } + @Override + public void accumulate(Accumulator accumulator, Long input) { + if (null != input) { + accumulator.add(input); } + } - @Override - public void merge(Accumulator merged, Iterable accumulators) { - for (Accumulator accumulator : accumulators) { - merged.squareSum += accumulator.squareSum; - merged.sum += accumulator.sum; - merged.count += accumulator.count; - } + @Override + public void merge(Accumulator merged, Iterable accumulators) { + for (Accumulator accumulator : accumulators) { + merged.squareSum += accumulator.squareSum; + merged.sum += accumulator.sum; + merged.count += accumulator.count; } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.squareSum = 0.0; - accumulator.sum = 0.0; - accumulator.count = 0L; - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.squareSum = 0.0; + accumulator.sum = 0.0; + accumulator.count = 0L; + } - @Override - public Double getValue(Accumulator accumulator) { - if (accumulator.count == 0) { - return 0.0; - } - return Math.sqrt((accumulator.squareSum - Math.pow(accumulator.sum, 2) / accumulator.count) / accumulator.count); + @Override + public Double getValue(Accumulator accumulator) { + if (accumulator.count == 0) { + return 0.0; } - + return Math.sqrt( + (accumulator.squareSum - Math.pow(accumulator.sum, 2) / accumulator.count) + / accumulator.count); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/SumDouble.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/SumDouble.java index 62d70064f..66bdbb75e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/SumDouble.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/SumDouble.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.agg; import java.io.Serializable; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDAF; import org.apache.geaflow.dsl.udf.table.agg.SumDouble.Accumulator; @@ -27,45 +28,43 @@ @Description(name = "sum", description = "The sum function for double.") public class SumDouble extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(0.0); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(0.0); + } - @Override - public void accumulate(Accumulator accumulator, Double input) { - if (null != input) { - accumulator.value += input; - } + @Override + public void accumulate(Accumulator accumulator, Double input) { + if (null != input) { + accumulator.value += input; } + } - @Override - public void merge(Accumulator accumulator, Iterable its) { - for (Accumulator toMerge : its) { - accumulator.value += toMerge.value; - } + @Override + public void merge(Accumulator accumulator, Iterable its) { + for (Accumulator toMerge : its) { + accumulator.value += toMerge.value; } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.value = 0.0; - } - - @Override - public Double getValue(Accumulator accumulator) { - return accumulator.value; - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.value = 0.0; + } - public static class Accumulator implements Serializable { + @Override + public Double getValue(Accumulator accumulator) { + return accumulator.value; + } - public Accumulator() { - } + public static class Accumulator implements Serializable { - public double value; + public Accumulator() {} - public Accumulator(double value) { - this.value = value; - } + public double value; + public Accumulator(double value) { + this.value = value; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/SumInteger.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/SumInteger.java index c4f672753..01086c8d6 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/SumInteger.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/SumInteger.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.agg; import java.io.Serializable; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDAF; import org.apache.geaflow.dsl.udf.table.agg.SumInteger.Accumulator; @@ -27,44 +28,43 @@ @Description(name = "sum", description = "The sum function for double.") public class SumInteger extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(0); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(0); + } - @Override - public void accumulate(Accumulator accumulator, Integer input) { - if (null != input) { - accumulator.value += input; - } + @Override + public void accumulate(Accumulator accumulator, Integer input) { + if (null != input) { + accumulator.value += input; } + } - @Override - public void merge(Accumulator accumulator, Iterable its) { - for (Accumulator toMerge : its) { - accumulator.value += toMerge.value; - } + @Override + public void merge(Accumulator accumulator, Iterable its) { + for (Accumulator toMerge : its) { + accumulator.value += toMerge.value; } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.value = 0; - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.value = 0; + } - @Override - public Integer getValue(Accumulator accumulator) { - return accumulator.value; - } + @Override + public Integer getValue(Accumulator accumulator) { + return accumulator.value; + } - public static class Accumulator implements Serializable { + public static class Accumulator implements Serializable { - public Accumulator() { - } + public Accumulator() {} - public int value; + public int value; - public Accumulator(int value) { - this.value = value; - } + public Accumulator(int value) { + this.value = value; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/SumLong.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/SumLong.java index c157b6035..12b8abce0 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/SumLong.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/agg/SumLong.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.agg; import java.io.Serializable; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDAF; import org.apache.geaflow.dsl.udf.table.agg.SumLong.Accumulator; @@ -27,44 +28,43 @@ @Description(name = "sum", description = "The sum function for bigint.") public class SumLong extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(0L); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(0L); + } - @Override - public void accumulate(Accumulator accumulator, Long input) { - if (null != input) { - accumulator.value += input; - } + @Override + public void accumulate(Accumulator accumulator, Long input) { + if (null != input) { + accumulator.value += input; } + } - @Override - public void merge(Accumulator accumulator, Iterable its) { - for (Accumulator toMerge : its) { - accumulator.value += toMerge.value; - } + @Override + public void merge(Accumulator accumulator, Iterable its) { + for (Accumulator toMerge : its) { + accumulator.value += toMerge.value; } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.value = 0L; - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.value = 0L; + } - @Override - public Long getValue(Accumulator accumulator) { - return accumulator.value; - } + @Override + public Long getValue(Accumulator accumulator) { + return accumulator.value; + } - public static class Accumulator implements Serializable { + public static class Accumulator implements Serializable { - public Accumulator() { - } + public Accumulator() {} - public long value; + public long value; - public Accumulator(long value) { - this.value = value; - } + public Accumulator(long value) { + this.value = value; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/array/ArrayAppend.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/array/ArrayAppend.java index 53e5349c2..a88241bfe 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/array/ArrayAppend.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/array/ArrayAppend.java @@ -25,12 +25,12 @@ @Description(name = "array_append", description = "Append other element to input array.") public class ArrayAppend extends UDF { - public Object[] eval(Object[] input, Object otherElement) { - Object[] res = new Object[input.length + 1]; - for (int i = 0; i < input.length; i++) { - res[i] = input[i]; - } - res[res.length - 1] = otherElement; - return res; + public Object[] eval(Object[] input, Object otherElement) { + Object[] res = new Object[input.length + 1]; + for (int i = 0; i < input.length; i++) { + res[i] = input[i]; } + res[res.length - 1] = otherElement; + return res; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/array/ArrayContains.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/array/ArrayContains.java index 25a679c75..3442d4795 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/array/ArrayContains.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/array/ArrayContains.java @@ -20,22 +20,23 @@ package org.apache.geaflow.dsl.udf.table.array; import java.util.Objects; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; @Description(name = "array_contains", description = "Judge if the input element exists in array.") public class ArrayContains extends UDF { - public Boolean eval(Object[] array, Object e) { - if (array == null || e == null) { - return null; - } + public Boolean eval(Object[] array, Object e) { + if (array == null || e == null) { + return null; + } - for (Object a : array) { - if (Objects.equals(a, e)) { - return true; - } - } - return false; + for (Object a : array) { + if (Objects.equals(a, e)) { + return true; + } } + return false; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/array/ArrayDistinct.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/array/ArrayDistinct.java index 95e7db5eb..332df2af8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/array/ArrayDistinct.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/array/ArrayDistinct.java @@ -21,18 +21,19 @@ import java.util.HashSet; import java.util.Set; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; @Description(name = "array_distinct", description = "Dedup input array.") public class ArrayDistinct extends UDF { - public Object[] eval(Object[] input) { - Set set = new HashSet<>(); + public Object[] eval(Object[] input) { + Set set = new HashSet<>(); - for (Object o : input) { - set.add(o); - } - return set.toArray(new Object[0]); + for (Object o : input) { + set.add(o); } + return set.toArray(new Object[0]); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/array/ArrayUnion.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/array/ArrayUnion.java index e41285572..1b3e4934a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/array/ArrayUnion.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/array/ArrayUnion.java @@ -19,24 +19,26 @@ package org.apache.geaflow.dsl.udf.table.array; -import com.google.common.collect.Sets; import java.util.Arrays; import java.util.Set; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; +import com.google.common.collect.Sets; + @Description(name = "array_union", description = "Union two input array.") public class ArrayUnion extends UDF { - public Object[] eval(Object[] input1, Object[] input2) { - if (input1 == null) { - return input2; - } - if (input2 == null) { - return input1; - } - Set result = Sets.newHashSet(input1); - result.addAll(Arrays.asList(input2)); - return result.toArray(); + public Object[] eval(Object[] input1, Object[] input2) { + if (input1 == null) { + return input2; + } + if (input2 == null) { + return input1; } + Set result = Sets.newHashSet(input1); + result.addAll(Arrays.asList(input2)); + return result.toArray(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/AddMonths.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/AddMonths.java index 1a33858d1..f1ffa9d74 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/AddMonths.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/AddMonths.java @@ -20,43 +20,44 @@ package org.apache.geaflow.dsl.udf.table.date; import java.util.Calendar; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; -@Description(name = "add_months", description = "Returns the date that is num_months after " - + "start_date.") +@Description( + name = "add_months", + description = "Returns the date that is num_months after " + "start_date.") public class AddMonths extends UDF { - private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormat.forPattern( - "yyyy-MM-dd HH:mm:ss"); - private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); - private final Calendar calendar = Calendar.getInstance(); - + private static final DateTimeFormatter DATE_TIME_FORMATTER = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss"); + private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); + private final Calendar calendar = Calendar.getInstance(); - public String eval(String date, Integer month) { - if (date == null || month == null) { - return null; - } - DateTimeFormatter formatter; + public String eval(String date, Integer month) { + if (date == null || month == null) { + return null; + } + DateTimeFormatter formatter; - if (date.length() <= 10) { - formatter = DATE_FORMATTER; - } else { - formatter = DATE_TIME_FORMATTER; - } - try { - calendar.setTime(formatter.parseDateTime(date).toDate()); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } - calendar.add(Calendar.MONTH, month); - try { - return formatter.print(calendar.getTime().getTime()); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + if (date.length() <= 10) { + formatter = DATE_FORMATTER; + } else { + formatter = DATE_TIME_FORMATTER; + } + try { + calendar.setTime(formatter.parseDateTime(date).toDate()); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); + } + calendar.add(Calendar.MONTH, month); + try { + return formatter.print(calendar.getTime().getTime()); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DateAdd.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DateAdd.java index 1ee9cb5ec..b156d6509 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DateAdd.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DateAdd.java @@ -22,6 +22,7 @@ import java.sql.Timestamp; import java.util.Calendar; import java.util.Date; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; @@ -31,38 +32,38 @@ @Description(name = "date_add", description = "Returns the date that is num_days after start date.") public class DateAdd extends UDF { - private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormat.forPattern( - "yyyy-MM-dd HH:mm:ss"); - private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); - private final Calendar calendar = Calendar.getInstance(); + private static final DateTimeFormatter DATE_TIME_FORMATTER = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss"); + private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); + private final Calendar calendar = Calendar.getInstance(); - public String eval(String dateString, Integer days) { - if (dateString == null || days == null) { - return null; - } - DateTimeFormatter formatter; - try { - if (dateString.length() <= 10) { - formatter = DATE_FORMATTER; - } else { - formatter = DATE_TIME_FORMATTER; - } - calendar.setTime(formatter.parseDateTime(dateString).toDate()); - calendar.add(Calendar.DAY_OF_MONTH, days); - Date newDate = calendar.getTime(); - return DATE_FORMATTER.print(newDate.getTime()); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + public String eval(String dateString, Integer days) { + if (dateString == null || days == null) { + return null; + } + DateTimeFormatter formatter; + try { + if (dateString.length() <= 10) { + formatter = DATE_FORMATTER; + } else { + formatter = DATE_TIME_FORMATTER; + } + calendar.setTime(formatter.parseDateTime(dateString).toDate()); + calendar.add(Calendar.DAY_OF_MONTH, days); + Date newDate = calendar.getTime(); + return DATE_FORMATTER.print(newDate.getTime()); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } + } - public String eval(Timestamp t, Integer days) { - if (t == null || days == null) { - return null; - } - calendar.setTime(t); - calendar.add(Calendar.DAY_OF_MONTH, days); - Date newDate = calendar.getTime(); - return DATE_TIME_FORMATTER.print(newDate.getTime()); + public String eval(Timestamp t, Integer days) { + if (t == null || days == null) { + return null; } + calendar.setTime(t); + calendar.add(Calendar.DAY_OF_MONTH, days); + Date newDate = calendar.getTime(); + return DATE_TIME_FORMATTER.print(newDate.getTime()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DateDiff.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DateDiff.java index 6b1f8e627..ae5bff008 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DateDiff.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DateDiff.java @@ -21,66 +21,69 @@ import java.sql.Timestamp; import java.util.Date; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; -@Description(name = "date_diff", description = "Returns the number of days from startdate to end date.") +@Description( + name = "date_diff", + description = "Returns the number of days from startdate to end date.") public class DateDiff extends UDF { - private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormat.forPattern( - "yyyy-MM-dd HH:mm:ss"); - private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); - - public Integer eval(String dateString1, String dateString2) { - return eval(toDate(dateString1), toDate(dateString2)); - } + private static final DateTimeFormatter DATE_TIME_FORMATTER = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss"); + private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); - public Integer eval(Timestamp t1, Timestamp t2) { - return eval(toDate(t1), toDate(t2)); - } + public Integer eval(String dateString1, String dateString2) { + return eval(toDate(dateString1), toDate(dateString2)); + } - public Integer eval(Timestamp t, String dateString) { - return eval(toDate(t), toDate(dateString)); - } + public Integer eval(Timestamp t1, Timestamp t2) { + return eval(toDate(t1), toDate(t2)); + } - public Integer eval(String dateString, Timestamp t) { - return eval(toDate(dateString), toDate(t)); - } + public Integer eval(Timestamp t, String dateString) { + return eval(toDate(t), toDate(dateString)); + } - private Integer eval(Date date1, Date date2) { - if (date1 == null || date2 == null) { - return null; - } + public Integer eval(String dateString, Timestamp t) { + return eval(toDate(dateString), toDate(t)); + } - long diffInMilliSeconds = date1.getTime() - date2.getTime(); - return (int) (diffInMilliSeconds / (86400 * 1000)); + private Integer eval(Date date1, Date date2) { + if (date1 == null || date2 == null) { + return null; } - private Date format(String dateString) { - try { - DateTimeFormatter formatter; - if (dateString.length() <= 10) { - formatter = DATE_FORMATTER; - } else { - formatter = DATE_TIME_FORMATTER; - } - return formatter.parseDateTime(dateString).toDate(); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } - } + long diffInMilliSeconds = date1.getTime() - date2.getTime(); + return (int) (diffInMilliSeconds / (86400 * 1000)); + } - private Date toDate(String dateString) { - if (dateString == null) { - return null; - } - return format(dateString); + private Date format(String dateString) { + try { + DateTimeFormatter formatter; + if (dateString.length() <= 10) { + formatter = DATE_FORMATTER; + } else { + formatter = DATE_TIME_FORMATTER; + } + return formatter.parseDateTime(dateString).toDate(); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } + } - private Date toDate(Timestamp t) { - return t; + private Date toDate(String dateString) { + if (dateString == null) { + return null; } + return format(dateString); + } + + private Date toDate(Timestamp t) { + return t; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DateFormat.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DateFormat.java index 5ee638064..98b8074bc 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DateFormat.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DateFormat.java @@ -23,47 +23,48 @@ import org.apache.geaflow.dsl.common.function.FunctionContext; import org.apache.geaflow.dsl.common.function.UDF; -@Description(name = "date_format", description = "Returns convert the date from a format to " - + "another.") +@Description( + name = "date_format", + description = "Returns convert the date from a format to " + "another.") public class DateFormat extends UDF { - private FromUnixTime fromUnixTime; - private UnixTimeStamp unixTimeStamp; + private FromUnixTime fromUnixTime; + private UnixTimeStamp unixTimeStamp; - public void open(FunctionContext context) { - super.open(context); - fromUnixTime = new FromUnixTime(); - unixTimeStamp = new UnixTimeStamp(); - fromUnixTime.open(context); - unixTimeStamp.open(context); - } + public void open(FunctionContext context) { + super.open(context); + fromUnixTime = new FromUnixTime(); + unixTimeStamp = new UnixTimeStamp(); + fromUnixTime.open(context); + unixTimeStamp.open(context); + } - public String eval(String dateText, String toFormat) { - String format = "yyyy-MM-dd HH:mm:ss"; + public String eval(String dateText, String toFormat) { + String format = "yyyy-MM-dd HH:mm:ss"; - if (dateText != null && dateText.length() > 19) { - char sep = dateText.charAt(19); - format = "yyyy-MM-dd HH:mm:ss" + sep + "SSSSSS"; - } - return eval(dateText, format, toFormat); + if (dateText != null && dateText.length() > 19) { + char sep = dateText.charAt(19); + format = "yyyy-MM-dd HH:mm:ss" + sep + "SSSSSS"; } + return eval(dateText, format, toFormat); + } - public String eval(String dateText) { - return eval(dateText, "yyyy-MM-dd HH:mm:ss"); - } + public String eval(String dateText) { + return eval(dateText, "yyyy-MM-dd HH:mm:ss"); + } - public String eval(java.sql.Timestamp timestamp, String toFormat) { - return eval(timestamp.toString(), toFormat); - } + public String eval(java.sql.Timestamp timestamp, String toFormat) { + return eval(timestamp.toString(), toFormat); + } - public String eval(java.sql.Timestamp timestamp) { - return eval(timestamp.toString()); - } + public String eval(java.sql.Timestamp timestamp) { + return eval(timestamp.toString()); + } - public String eval(String dateText, String fromFormat, String toFormat) { - if (dateText == null || fromFormat == null || toFormat == null) { - return null; - } - return fromUnixTime.eval(unixTimeStamp.eval(dateText, fromFormat), toFormat); + public String eval(String dateText, String fromFormat, String toFormat) { + if (dateText == null || fromFormat == null || toFormat == null) { + return null; } + return fromUnixTime.eval(unixTimeStamp.eval(dateText, fromFormat), toFormat); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DatePart.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DatePart.java index 03b4888b6..b8b3b3e06 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DatePart.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DatePart.java @@ -20,59 +20,58 @@ package org.apache.geaflow.dsl.udf.table.date; import java.util.Calendar; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; -@Description(name = "date_part", description = "Returns part of the date by date part " - + "format.") +@Description(name = "date_part", description = "Returns part of the date by date part " + "format.") public class DatePart extends UDF { - private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormat.forPattern( - "yyyy-MM-dd HH:mm:ss"); - private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); - private final Calendar calendar = Calendar.getInstance(); - - public Integer eval(String date, String datepart) { - if (date == null || datepart == null) { - return null; - } - DateTimeFormatter formatter; - if (date.length() <= 10) { - formatter = DATE_FORMATTER; - } else { - formatter = DATE_TIME_FORMATTER; - } - try { - calendar.setTime(formatter.parseDateTime(date).toDate()); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } - switch (datepart) { - case "yyyy": - case "year": - return calendar.get(Calendar.YEAR); - case "mm": - case "mon": - case "month": - return 1 + calendar.get(Calendar.MONTH); - case "dd": - case "day": - return calendar.get(Calendar.DAY_OF_MONTH); - case "hh": - case "hour": - return calendar.get(Calendar.HOUR_OF_DAY); - case "mi": - case "minute": - return calendar.get(Calendar.MINUTE); - case "ss": - case "second": - return calendar.get(Calendar.SECOND); - default: - throw new RuntimeException("unknown datepart:" + datepart); + private static final DateTimeFormatter DATE_TIME_FORMATTER = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss"); + private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); + private final Calendar calendar = Calendar.getInstance(); - } + public Integer eval(String date, String datepart) { + if (date == null || datepart == null) { + return null; + } + DateTimeFormatter formatter; + if (date.length() <= 10) { + formatter = DATE_FORMATTER; + } else { + formatter = DATE_TIME_FORMATTER; + } + try { + calendar.setTime(formatter.parseDateTime(date).toDate()); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); + } + switch (datepart) { + case "yyyy": + case "year": + return calendar.get(Calendar.YEAR); + case "mm": + case "mon": + case "month": + return 1 + calendar.get(Calendar.MONTH); + case "dd": + case "day": + return calendar.get(Calendar.DAY_OF_MONTH); + case "hh": + case "hour": + return calendar.get(Calendar.HOUR_OF_DAY); + case "mi": + case "minute": + return calendar.get(Calendar.MINUTE); + case "ss": + case "second": + return calendar.get(Calendar.SECOND); + default: + throw new RuntimeException("unknown datepart:" + datepart); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DateSub.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DateSub.java index e03b5c0f5..b5f4b5e08 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DateSub.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DateSub.java @@ -22,47 +22,50 @@ import java.sql.Timestamp; import java.util.Calendar; import java.util.Date; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; -@Description(name = "date_sub", description = "Returns the date that is num_days before start_date.") +@Description( + name = "date_sub", + description = "Returns the date that is num_days before start_date.") public class DateSub extends UDF { - private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormat.forPattern( - "yyyy-MM-dd HH:mm:ss"); - private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); - private final Calendar calendar = Calendar.getInstance(); + private static final DateTimeFormatter DATE_TIME_FORMATTER = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss"); + private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); + private final Calendar calendar = Calendar.getInstance(); - public String eval(String dateString, Integer days) { - if (dateString == null || days == null) { - return null; - } - DateTimeFormatter formatter; - try { - if (dateString.length() <= 10) { - formatter = DATE_FORMATTER; - } else { - formatter = DATE_TIME_FORMATTER; - } - calendar.setTime(formatter.parseDateTime(dateString).toDate()); - calendar.add(Calendar.DAY_OF_MONTH, -days); - Date newDate = calendar.getTime(); - return DATE_FORMATTER.print(newDate.getTime()); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + public String eval(String dateString, Integer days) { + if (dateString == null || days == null) { + return null; + } + DateTimeFormatter formatter; + try { + if (dateString.length() <= 10) { + formatter = DATE_FORMATTER; + } else { + formatter = DATE_TIME_FORMATTER; + } + calendar.setTime(formatter.parseDateTime(dateString).toDate()); + calendar.add(Calendar.DAY_OF_MONTH, -days); + Date newDate = calendar.getTime(); + return DATE_FORMATTER.print(newDate.getTime()); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } + } - public String eval(Timestamp t, Integer days) { - if (t == null || days == null) { - return null; - } - calendar.setTime(t); - calendar.add(Calendar.DAY_OF_MONTH, -days); - Date newDate = calendar.getTime(); - return DATE_TIME_FORMATTER.print(newDate.getTime()); + public String eval(Timestamp t, Integer days) { + if (t == null || days == null) { + return null; } + calendar.setTime(t); + calendar.add(Calendar.DAY_OF_MONTH, -days); + Date newDate = calendar.getTime(); + return DATE_TIME_FORMATTER.print(newDate.getTime()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DateTrunc.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DateTrunc.java index 8c24e4af1..dd97e65cf 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DateTrunc.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DateTrunc.java @@ -20,77 +20,98 @@ package org.apache.geaflow.dsl.udf.table.date; import java.util.Calendar; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; -@Description(name = "date_trunc", description = "Returns date truncated to the unit specified by " - + "the format.") +@Description( + name = "date_trunc", + description = "Returns date truncated to the unit specified by " + "the format.") public class DateTrunc extends UDF { - private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormat.forPattern( - "yyyy-MM-dd HH:mm:ss"); - private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); - private static final String DATE_FORMAT = "%04d-%02d-%02d %02d:%02d:%02d"; - private final Calendar calendar = Calendar.getInstance(); + private static final DateTimeFormatter DATE_TIME_FORMATTER = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss"); + private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); + private static final String DATE_FORMAT = "%04d-%02d-%02d %02d:%02d:%02d"; + private final Calendar calendar = Calendar.getInstance(); - public String eval(String date, String datepart) { + public String eval(String date, String datepart) { - if (date == null || datepart == null) { - return null; - } + if (date == null || datepart == null) { + return null; + } - DateTimeFormatter formatter; + DateTimeFormatter formatter; - if (date != null && date.length() <= 10) { - formatter = DATE_FORMATTER; - } else { - formatter = DATE_TIME_FORMATTER; - } + if (date != null && date.length() <= 10) { + formatter = DATE_FORMATTER; + } else { + formatter = DATE_TIME_FORMATTER; + } - try { - calendar.setTime(formatter.parseDateTime(date).toDate()); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + try { + calendar.setTime(formatter.parseDateTime(date).toDate()); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); + } - switch (datepart.toLowerCase()) { - case "yyyy": - case "year": - return String.format(DATE_FORMAT, calendar.get(Calendar.YEAR), 1, 1, 0, 0, 0); - case "mm": - case "mon": - case "month": - return String.format(DATE_FORMAT, calendar.get(Calendar.YEAR), - calendar.get(Calendar.MONTH) + 1, 1, 0, 0, 0); - case "dd": - case "day": - return String.format(DATE_FORMAT, calendar.get(Calendar.YEAR), - calendar.get(Calendar.MONTH) + 1, calendar.get(Calendar.DAY_OF_MONTH), 0, 0, 0); - case "hh": - case "hour": - return String.format(DATE_FORMAT, calendar.get(Calendar.YEAR), - calendar.get(Calendar.MONTH) + 1, calendar.get(Calendar.DAY_OF_MONTH), - calendar.get(Calendar.HOUR_OF_DAY), 0, 0); + switch (datepart.toLowerCase()) { + case "yyyy": + case "year": + return String.format(DATE_FORMAT, calendar.get(Calendar.YEAR), 1, 1, 0, 0, 0); + case "mm": + case "mon": + case "month": + return String.format( + DATE_FORMAT, calendar.get(Calendar.YEAR), calendar.get(Calendar.MONTH) + 1, 1, 0, 0, 0); + case "dd": + case "day": + return String.format( + DATE_FORMAT, + calendar.get(Calendar.YEAR), + calendar.get(Calendar.MONTH) + 1, + calendar.get(Calendar.DAY_OF_MONTH), + 0, + 0, + 0); + case "hh": + case "hour": + return String.format( + DATE_FORMAT, + calendar.get(Calendar.YEAR), + calendar.get(Calendar.MONTH) + 1, + calendar.get(Calendar.DAY_OF_MONTH), + calendar.get(Calendar.HOUR_OF_DAY), + 0, + 0); - case "mi": - case "minute": - return String.format(DATE_FORMAT, calendar.get(Calendar.YEAR), - calendar.get(Calendar.MONTH) + 1, calendar.get(Calendar.DAY_OF_MONTH), - calendar.get(Calendar.HOUR_OF_DAY), calendar.get(Calendar.MINUTE), 0); + case "mi": + case "minute": + return String.format( + DATE_FORMAT, + calendar.get(Calendar.YEAR), + calendar.get(Calendar.MONTH) + 1, + calendar.get(Calendar.DAY_OF_MONTH), + calendar.get(Calendar.HOUR_OF_DAY), + calendar.get(Calendar.MINUTE), + 0); - case "ss": - case "second": - return String.format(DATE_FORMAT, calendar.get(Calendar.YEAR), - calendar.get(Calendar.MONTH) + 1, calendar.get(Calendar.DAY_OF_MONTH), - calendar.get(Calendar.HOUR_OF_DAY), calendar.get(Calendar.MINUTE), - calendar.get(Calendar.SECOND)); + case "ss": + case "second": + return String.format( + DATE_FORMAT, + calendar.get(Calendar.YEAR), + calendar.get(Calendar.MONTH) + 1, + calendar.get(Calendar.DAY_OF_MONTH), + calendar.get(Calendar.HOUR_OF_DAY), + calendar.get(Calendar.MINUTE), + calendar.get(Calendar.SECOND)); - default: - throw new RuntimeException("unknown datepart:" + datepart); - } + default: + throw new RuntimeException("unknown datepart:" + datepart); } - + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Day.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Day.java index 69757ee7e..5ed466e66 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Day.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Day.java @@ -20,43 +20,45 @@ package org.apache.geaflow.dsl.udf.table.date; import java.sql.Timestamp; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; -@Description(name = "day", description = - "Get number of days of the date or datetime expression expr.") +@Description( + name = "day", + description = "Get number of days of the date or datetime expression expr.") public class Day extends UDF { - private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormat.forPattern( - "yyyy-MM-dd HH:mm:ss"); - private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); - - public Integer eval(String dateString) { - if (dateString == null) { - return null; - } - DateTimeFormatter formatter; - try { - if (dateString.length() <= 10) { - formatter = DATE_FORMATTER; - } else { - formatter = DATE_TIME_FORMATTER; - } - - java.util.Date date = formatter.parseDateTime(dateString).toDate(); - return date.getDate(); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + private static final DateTimeFormatter DATE_TIME_FORMATTER = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss"); + private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); + + public Integer eval(String dateString) { + if (dateString == null) { + return null; + } + DateTimeFormatter formatter; + try { + if (dateString.length() <= 10) { + formatter = DATE_FORMATTER; + } else { + formatter = DATE_TIME_FORMATTER; + } + + java.util.Date date = formatter.parseDateTime(dateString).toDate(); + return date.getDate(); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } + } - public Integer eval(Timestamp i) { - if (i == null) { - return null; - } - return i.getDate(); + public Integer eval(Timestamp i) { + if (i == null) { + return null; } + return i.getDate(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DayOfMonth.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DayOfMonth.java index 778f6d152..bbd2d241e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DayOfMonth.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/DayOfMonth.java @@ -22,45 +22,45 @@ import java.sql.Timestamp; import java.util.Calendar; import java.util.Date; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; -@Description(name = "day_of_month", description = - "Returns the date of the month of date.") +@Description(name = "day_of_month", description = "Returns the date of the month of date.") public class DayOfMonth extends UDF { - private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormat.forPattern( - "yyyy-MM-dd HH:mm:ss"); - private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); - private final Calendar calendar = Calendar.getInstance(); + private static final DateTimeFormatter DATE_TIME_FORMATTER = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss"); + private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); + private final Calendar calendar = Calendar.getInstance(); - public Integer eval(String dateString) { - if (dateString == null) { - return null; - } - DateTimeFormatter formatter; - try { - if (dateString.length() <= 10) { - formatter = DATE_FORMATTER; - } else { - formatter = DATE_TIME_FORMATTER; - } - Date date = formatter.parseDateTime(dateString).toDate(); - calendar.setTime(date); - return calendar.get(Calendar.DAY_OF_MONTH); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + public Integer eval(String dateString) { + if (dateString == null) { + return null; + } + DateTimeFormatter formatter; + try { + if (dateString.length() <= 10) { + formatter = DATE_FORMATTER; + } else { + formatter = DATE_TIME_FORMATTER; + } + Date date = formatter.parseDateTime(dateString).toDate(); + calendar.setTime(date); + return calendar.get(Calendar.DAY_OF_MONTH); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } + } - public Integer eval(Timestamp t) { - if (t == null) { - return null; - } - calendar.setTime(t); - return calendar.get(Calendar.DAY_OF_MONTH); + public Integer eval(Timestamp t) { + if (t == null) { + return null; } + calendar.setTime(t); + return calendar.get(Calendar.DAY_OF_MONTH); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/FromUnixTime.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/FromUnixTime.java index 336ca4f64..35540b95f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/FromUnixTime.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/FromUnixTime.java @@ -27,31 +27,31 @@ @Description(name = "from_unixtime", description = "Translate unix timestamp to the date string.") public class FromUnixTime extends UDF { - private DateTimeFormatter formatter; - private String lastFormat; + private DateTimeFormatter formatter; + private String lastFormat; - private static final String DEFAULT_FORMAT = "yyyy-MM-dd HH:mm:ss"; + private static final String DEFAULT_FORMAT = "yyyy-MM-dd HH:mm:ss"; - public String eval(String unixTime, String format) { - return eval(Long.valueOf(unixTime), format); - } + public String eval(String unixTime, String format) { + return eval(Long.valueOf(unixTime), format); + } - public String eval(Long unixTime, String format) { - return evaluate(unixTime, format); - } + public String eval(Long unixTime, String format) { + return evaluate(unixTime, format); + } - public String eval(Long unixTime) { - return eval(unixTime, DEFAULT_FORMAT); - } + public String eval(Long unixTime) { + return eval(unixTime, DEFAULT_FORMAT); + } - private String evaluate(Long unixTime, String format) { - if (unixTime == null || format == null) { - return null; - } - if (!format.equals(lastFormat)) { - formatter = DateTimeFormat.forPattern(format); - lastFormat = format; - } - return formatter.print(unixTime * 1000L); + private String evaluate(Long unixTime, String format) { + if (unixTime == null || format == null) { + return null; + } + if (!format.equals(lastFormat)) { + formatter = DateTimeFormat.forPattern(format); + lastFormat = format; } + return formatter.print(unixTime * 1000L); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/FromUnixTimeMillis.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/FromUnixTimeMillis.java index 08b6a36c4..149abdb06 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/FromUnixTimeMillis.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/FromUnixTimeMillis.java @@ -24,38 +24,40 @@ import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; -@Description(name = "from_unixtime_millis", description = "Translate unix timestamp to date string.") +@Description( + name = "from_unixtime_millis", + description = "Translate unix timestamp to date string.") public class FromUnixTimeMillis extends UDF { - private static final String DEFAULT_FORMAT = "yyyy-MM-dd HH:mm:ss.SSS"; + private static final String DEFAULT_FORMAT = "yyyy-MM-dd HH:mm:ss.SSS"; - private DateTimeFormatter lastFormatter; - private String lastFormat; + private DateTimeFormatter lastFormatter; + private String lastFormat; - public String eval(String unixTime) { - return eval(Long.valueOf(unixTime), DEFAULT_FORMAT); - } + public String eval(String unixTime) { + return eval(Long.valueOf(unixTime), DEFAULT_FORMAT); + } - public String eval(Long unixTime, String format) { - if (unixTime == null || format == null) { - return null; - } - return evaluate(unixTime, format); + public String eval(Long unixTime, String format) { + if (unixTime == null || format == null) { + return null; } + return evaluate(unixTime, format); + } - public String eval(Long unixTime) { - if (unixTime == null) { - return null; - } - return eval(unixTime, DEFAULT_FORMAT); + public String eval(Long unixTime) { + if (unixTime == null) { + return null; } + return eval(unixTime, DEFAULT_FORMAT); + } - private String evaluate(Long unixTime, String format) { - if (!format.equals(lastFormat)) { - lastFormatter = DateTimeFormat.forPattern(format); - lastFormat = format; - } - - return lastFormatter.print(unixTime); + private String evaluate(Long unixTime, String format) { + if (!format.equals(lastFormat)) { + lastFormatter = DateTimeFormat.forPattern(format); + lastFormat = format; } + + return lastFormatter.print(unixTime); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Hour.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Hour.java index 5c82f2fcc..8ffe1de18 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Hour.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Hour.java @@ -22,6 +22,7 @@ import java.sql.Timestamp; import java.util.Calendar; import java.util.Date; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; @@ -31,37 +32,36 @@ @Description(name = "hour", description = "Returns the hour of date.") public class Hour extends UDF { - private static final DateTimeFormatter FORMATTER1 = DateTimeFormat.forPattern( - "yyyy-MM-dd HH:mm:ss"); - private static final DateTimeFormatter FORMATTER2 = DateTimeFormat.forPattern("HH:mm:ss"); - private final Calendar calendar = Calendar.getInstance(); - - public Integer eval(String dateString) { + private static final DateTimeFormatter FORMATTER1 = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss"); + private static final DateTimeFormatter FORMATTER2 = DateTimeFormat.forPattern("HH:mm:ss"); + private final Calendar calendar = Calendar.getInstance(); - if (dateString == null) { - return null; - } + public Integer eval(String dateString) { - try { - Date date; - try { - date = FORMATTER1.parseDateTime(dateString).toDate(); - } catch (Exception e) { - date = FORMATTER2.parseDateTime(dateString).toDate(); - } - calendar.setTime(date); - return calendar.get(Calendar.HOUR_OF_DAY); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + if (dateString == null) { + return null; } - public Integer eval(Timestamp t) { - if (t == null) { - return null; - } - calendar.setTime(t); - return calendar.get(Calendar.HOUR); + try { + Date date; + try { + date = FORMATTER1.parseDateTime(dateString).toDate(); + } catch (Exception e) { + date = FORMATTER2.parseDateTime(dateString).toDate(); + } + calendar.setTime(date); + return calendar.get(Calendar.HOUR_OF_DAY); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } + } + public Integer eval(Timestamp t) { + if (t == null) { + return null; + } + calendar.setTime(t); + return calendar.get(Calendar.HOUR); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/IsDate.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/IsDate.java index d2d7cd91c..8b44567c8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/IsDate.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/IsDate.java @@ -26,28 +26,28 @@ @Description(name = "isdate", description = "Returns whether the string is a date format.") public class IsDate extends UDF { - private static final String DEFAULT_FORMAT = "yyyy-MM-dd HH:mm:ss"; + private static final String DEFAULT_FORMAT = "yyyy-MM-dd HH:mm:ss"; - private DateTimeFormatter lastFormatter; - private String lastFormat; + private DateTimeFormatter lastFormatter; + private String lastFormat; - public Boolean eval(String date, String format) { - if (date == null || format == null) { - return false; - } - if (lastFormat == null || !lastFormat.equals(format)) { - lastFormatter = DateTimeFormat.forPattern(format); - lastFormat = format; - } - try { - lastFormatter.parseDateTime(date); - } catch (Exception e) { - return false; - } - return true; + public Boolean eval(String date, String format) { + if (date == null || format == null) { + return false; } - - public Boolean eval(String date) { - return eval(date, DEFAULT_FORMAT); + if (lastFormat == null || !lastFormat.equals(format)) { + lastFormatter = DateTimeFormat.forPattern(format); + lastFormat = format; + } + try { + lastFormatter.parseDateTime(date); + } catch (Exception e) { + return false; } + return true; + } + + public Boolean eval(String date) { + return eval(date, DEFAULT_FORMAT); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/LastDay.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/LastDay.java index 9699f053a..81bfd0e3b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/LastDay.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/LastDay.java @@ -21,38 +21,40 @@ import java.util.Calendar; import java.util.Date; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; -@Description(name = "lastday", description = "Returns the last day of the month which the date " - + "belongs to .") +@Description( + name = "lastday", + description = "Returns the last day of the month which the date " + "belongs to .") public class LastDay extends UDF { - private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormat.forPattern( - "yyyy-MM-dd HH:mm:ss"); - private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); - private final Calendar calendar = Calendar.getInstance(); - - public String eval(String date) { - DateTimeFormatter formatter; - if (date != null && date.length() <= 10) { - formatter = DATE_FORMATTER; - } else { - formatter = DATE_TIME_FORMATTER; - } - try { - calendar.setTime(formatter.parseDateTime(date).toDate()); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } - calendar.set(Calendar.DAY_OF_MONTH, 1); - calendar.add(Calendar.MONTH, 1); - calendar.add(Calendar.DAY_OF_MONTH, -1); + private static final DateTimeFormatter DATE_TIME_FORMATTER = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss"); + private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); + private final Calendar calendar = Calendar.getInstance(); - Date lastDay = calendar.getTime(); - return DATE_FORMATTER.print(lastDay.getTime()) + " 00:00:00"; + public String eval(String date) { + DateTimeFormatter formatter; + if (date != null && date.length() <= 10) { + formatter = DATE_FORMATTER; + } else { + formatter = DATE_TIME_FORMATTER; } + try { + calendar.setTime(formatter.parseDateTime(date).toDate()); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); + } + calendar.set(Calendar.DAY_OF_MONTH, 1); + calendar.add(Calendar.MONTH, 1); + calendar.add(Calendar.DAY_OF_MONTH, -1); + + Date lastDay = calendar.getTime(); + return DATE_FORMATTER.print(lastDay.getTime()) + " 00:00:00"; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Minute.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Minute.java index 6d0b75a3d..0381f525e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Minute.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Minute.java @@ -22,6 +22,7 @@ import java.sql.Timestamp; import java.util.Calendar; import java.util.Date; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; @@ -31,35 +32,34 @@ @Description(name = "minute", description = "Returns the minute of date.") public class Minute extends UDF { - private static final DateTimeFormatter FORMATTER1 = DateTimeFormat.forPattern( - "yyyy-MM-dd HH:mm:ss"); - private static final DateTimeFormatter FORMATTER2 = DateTimeFormat.forPattern("HH:mm:ss"); - private final Calendar calendar = Calendar.getInstance(); + private static final DateTimeFormatter FORMATTER1 = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss"); + private static final DateTimeFormatter FORMATTER2 = DateTimeFormat.forPattern("HH:mm:ss"); + private final Calendar calendar = Calendar.getInstance(); - public Integer eval(String dateString) { - if (dateString == null) { - return null; - } - try { - Date date; - try { - date = FORMATTER1.parseDateTime(dateString).toDate(); - } catch (Exception e) { - date = FORMATTER2.parseDateTime(dateString).toDate(); - } - calendar.setTime(date); - return calendar.get(Calendar.MINUTE); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + public Integer eval(String dateString) { + if (dateString == null) { + return null; } - - public Integer eval(Timestamp t) { - if (t == null) { - return null; - } - calendar.setTime(t); - return calendar.get(Calendar.MINUTE); + try { + Date date; + try { + date = FORMATTER1.parseDateTime(dateString).toDate(); + } catch (Exception e) { + date = FORMATTER2.parseDateTime(dateString).toDate(); + } + calendar.setTime(date); + return calendar.get(Calendar.MINUTE); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } + } + public Integer eval(Timestamp t) { + if (t == null) { + return null; + } + calendar.setTime(t); + return calendar.get(Calendar.MINUTE); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Month.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Month.java index 3cd37b1b6..acd7550c9 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Month.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Month.java @@ -22,6 +22,7 @@ import java.sql.Timestamp; import java.util.Calendar; import java.util.Date; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; @@ -31,38 +32,37 @@ @Description(name = "month", description = "Returns the month of date.") public class Month extends UDF { - private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormat.forPattern( - "yyyy-MM-dd HH:mm:ss"); - private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); + private static final DateTimeFormatter DATE_TIME_FORMATTER = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss"); + private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); - private final Calendar calendar = Calendar.getInstance(); + private final Calendar calendar = Calendar.getInstance(); - public Integer eval(String dateString) { - if (dateString == null) { - return null; - } - DateTimeFormatter formatter; - try { - if (dateString.length() <= 10) { - formatter = DATE_FORMATTER; - } else { - formatter = DATE_TIME_FORMATTER; - } - Date date = formatter.parseDateTime(dateString).toDate(); - calendar.setTime(date); - return 1 + calendar.get(Calendar.MONTH); - } catch (Exception e) { - throw new GeaflowRuntimeException( - "Extract month field failed, please check! input:" + dateString, e); - } + public Integer eval(String dateString) { + if (dateString == null) { + return null; } - - public Integer eval(Timestamp t) { - if (t == null) { - return null; - } - calendar.setTime(t); - return 1 + calendar.get(Calendar.MONTH); + DateTimeFormatter formatter; + try { + if (dateString.length() <= 10) { + formatter = DATE_FORMATTER; + } else { + formatter = DATE_TIME_FORMATTER; + } + Date date = formatter.parseDateTime(dateString).toDate(); + calendar.setTime(date); + return 1 + calendar.get(Calendar.MONTH); + } catch (Exception e) { + throw new GeaflowRuntimeException( + "Extract month field failed, please check! input:" + dateString, e); } + } + public Integer eval(Timestamp t) { + if (t == null) { + return null; + } + calendar.setTime(t); + return 1 + calendar.get(Calendar.MONTH); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Now.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Now.java index f62461215..dd33d97f7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Now.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Now.java @@ -20,24 +20,25 @@ package org.apache.geaflow.dsl.udf.table.date; import java.util.Date; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; @Description(name = "now", description = "Returns current timestamp.") public class Now extends UDF { - public Long eval() { - Date date = new Date(); - return date.getTime() / 1000; - } + public Long eval() { + Date date = new Date(); + return date.getTime() / 1000; + } - public Long eval(Integer offset) { - Date date = new Date(); - return date.getTime() / 1000 + offset; - } + public Long eval(Integer offset) { + Date date = new Date(); + return date.getTime() / 1000 + offset; + } - public Long eval(Long offset) { - Date date = new Date(); - return date.getTime() / 1000 + offset; - } + public Long eval(Long offset) { + Date date = new Date(); + return date.getTime() / 1000 + offset; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Second.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Second.java index f26eb92ba..93a877aca 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Second.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Second.java @@ -22,6 +22,7 @@ import java.sql.Timestamp; import java.util.Calendar; import java.util.Date; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; @@ -31,38 +32,37 @@ @Description(name = "second", description = "Returns the second of date.") public class Second extends UDF { - private static final DateTimeFormatter FORMATTER1 = DateTimeFormat.forPattern( - "yyyy-MM-dd HH:mm:ss"); - private static final DateTimeFormatter FORMATTER2 = DateTimeFormat.forPattern("HH:mm:ss"); - private final Calendar calendar = Calendar.getInstance(); - - public Integer eval(String dateString) { + private static final DateTimeFormatter FORMATTER1 = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss"); + private static final DateTimeFormatter FORMATTER2 = DateTimeFormat.forPattern("HH:mm:ss"); + private final Calendar calendar = Calendar.getInstance(); - if (dateString == null) { - return null; - } + public Integer eval(String dateString) { - try { - Date date; - try { - date = FORMATTER1.parseDateTime(dateString).toDate(); - } catch (Exception e) { - date = FORMATTER2.parseDateTime(dateString).toDate(); - } - calendar.setTime(date); - return calendar.get(Calendar.SECOND); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + if (dateString == null) { + return null; } - public Integer eval(Timestamp t) { - if (t == null) { - return null; - } + try { + Date date; + try { + date = FORMATTER1.parseDateTime(dateString).toDate(); + } catch (Exception e) { + date = FORMATTER2.parseDateTime(dateString).toDate(); + } + calendar.setTime(date); + return calendar.get(Calendar.SECOND); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); + } + } - calendar.setTime(t); - return calendar.get(Calendar.SECOND); + public Integer eval(Timestamp t) { + if (t == null) { + return null; } + calendar.setTime(t); + return calendar.get(Calendar.SECOND); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/UnixTimeStamp.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/UnixTimeStamp.java index dd9a3a143..ea9468362 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/UnixTimeStamp.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/UnixTimeStamp.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.date; import java.util.Date; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; @@ -29,60 +30,60 @@ @Description(name = "unix_timestamp", description = "Returns the UNIX timestamp.") public class UnixTimeStamp extends UDF { - private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); - private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormat.forPattern( - "yyyy-MM-dd HH:mm:ss"); - private static final DateTimeFormatter MILLS_FORMATTER = DateTimeFormat.forPattern( - "yyyy-MM-dd HH:mm:ss.SSSSSS"); - private String lastPatternText; - private DateTimeFormatter formatter; + private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); + private static final DateTimeFormatter DATE_TIME_FORMATTER = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss"); + private static final DateTimeFormatter MILLS_FORMATTER = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss.SSSSSS"); + private String lastPatternText; + private DateTimeFormatter formatter; - public Long eval() { - Date date = new Date(); - return date.getTime() / 1000; - } + public Long eval() { + Date date = new Date(); + return date.getTime() / 1000; + } - public Long eval(Object dateText) { - if (dateText == null) { - return eval(); - } - return eval(String.valueOf(dateText)); + public Long eval(Object dateText) { + if (dateText == null) { + return eval(); } + return eval(String.valueOf(dateText)); + } - public Long eval(String dateText) { - if (dateText == null) { - return eval(); - } + public Long eval(String dateText) { + if (dateText == null) { + return eval(); + } - DateTimeFormatter formatter; + DateTimeFormatter formatter; - if (dateText.length() <= 10) { - formatter = DATE_FORMATTER; - } else if (dateText.length() <= 19) { - formatter = DATE_TIME_FORMATTER; - } else { - formatter = MILLS_FORMATTER; - } - try { - return formatter.parseDateTime(dateText).getMillis() / 1000; - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + if (dateText.length() <= 10) { + formatter = DATE_FORMATTER; + } else if (dateText.length() <= 19) { + formatter = DATE_TIME_FORMATTER; + } else { + formatter = MILLS_FORMATTER; } + try { + return formatter.parseDateTime(dateText).getMillis() / 1000; + } catch (Exception e) { + throw new GeaflowRuntimeException(e); + } + } - public Long eval(String dateText, String patternText) { - if (dateText == null || patternText == null) { - return null; - } - try { - if (!patternText.equals(lastPatternText)) { - formatter = DateTimeFormat.forPattern(patternText); - lastPatternText = patternText; - } - return formatter.parseDateTime(dateText).getMillis() / 1000; + public Long eval(String dateText, String patternText) { + if (dateText == null || patternText == null) { + return null; + } + try { + if (!patternText.equals(lastPatternText)) { + formatter = DateTimeFormat.forPattern(patternText); + lastPatternText = patternText; + } + return formatter.parseDateTime(dateText).getMillis() / 1000; - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/UnixTimeStampMillis.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/UnixTimeStampMillis.java index 90121843a..91346d5ec 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/UnixTimeStampMillis.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/UnixTimeStampMillis.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.date; import java.util.Date; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; @@ -29,57 +30,58 @@ @Description(name = "unix_timestamp_millis", description = "Returns the UNIX timestamp.") public class UnixTimeStampMillis extends UDF { - String lastPatternText; - private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); - private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss"); - private static final DateTimeFormatter MILLS_FORMATTER = DateTimeFormat.forPattern( - "yyyy-MM-dd HH:mm:ss.SSSSSS"); + String lastPatternText; + private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); + private static final DateTimeFormatter DATE_TIME_FORMATTER = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss"); + private static final DateTimeFormatter MILLS_FORMATTER = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss.SSSSSS"); - private DateTimeFormatter formatter = MILLS_FORMATTER; + private DateTimeFormatter formatter = MILLS_FORMATTER; - public Long eval() { - Date date = new Date(); - return date.getTime(); - } + public Long eval() { + Date date = new Date(); + return date.getTime(); + } - public Long eval(Object dateText) { - if (dateText == null) { - return eval(); - } - return eval(String.valueOf(dateText)); + public Long eval(Object dateText) { + if (dateText == null) { + return eval(); } + return eval(String.valueOf(dateText)); + } - public Long eval(String dateText) { - if (dateText == null) { - return eval(); - } - DateTimeFormatter formatter; - if (dateText.length() <= 10) { - formatter = DATE_FORMATTER; - } else if (dateText.length() <= 19) { - formatter = DATE_TIME_FORMATTER; - } else { - formatter = MILLS_FORMATTER; - } - try { - return formatter.parseDateTime(dateText).getMillis(); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + public Long eval(String dateText) { + if (dateText == null) { + return eval(); + } + DateTimeFormatter formatter; + if (dateText.length() <= 10) { + formatter = DATE_FORMATTER; + } else if (dateText.length() <= 19) { + formatter = DATE_TIME_FORMATTER; + } else { + formatter = MILLS_FORMATTER; } + try { + return formatter.parseDateTime(dateText).getMillis(); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); + } + } - public Long eval(String dateText, String patternText) { - if (dateText == null || patternText == null) { - return null; - } - try { - if (!patternText.equals(lastPatternText)) { - formatter = DateTimeFormat.forPattern(patternText); - lastPatternText = patternText; - } - return formatter.parseDateTime(dateText).getMillis(); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + public Long eval(String dateText, String patternText) { + if (dateText == null || patternText == null) { + return null; + } + try { + if (!patternText.equals(lastPatternText)) { + formatter = DateTimeFormat.forPattern(patternText); + lastPatternText = patternText; + } + return formatter.parseDateTime(dateText).getMillis(); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/WeekDay.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/WeekDay.java index fe60aa060..3b35588f6 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/WeekDay.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/WeekDay.java @@ -28,22 +28,22 @@ @Description(name = "weekday", description = "Returns weekday of the date.") public class WeekDay extends UDF { - private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); - private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormat.forPattern( - "yyyy-MM-dd HH:mm:ss"); + private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); + private static final DateTimeFormatter DATE_TIME_FORMATTER = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss"); - public Integer eval(String date) { - if (date == null) { - return null; - } - DateTimeFormatter formatter = DATE_TIME_FORMATTER; - if (date.length() <= 10) { - formatter = DATE_FORMATTER; - } - try { - return formatter.parseDateTime(date).getDayOfWeek(); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + public Integer eval(String date) { + if (date == null) { + return null; } + DateTimeFormatter formatter = DATE_TIME_FORMATTER; + if (date.length() <= 10) { + formatter = DATE_FORMATTER; + } + try { + return formatter.parseDateTime(date).getDayOfWeek(); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); + } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/WeekOfYear.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/WeekOfYear.java index 31d3f3e85..dc9503a31 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/WeekOfYear.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/WeekOfYear.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.date; import java.sql.Timestamp; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; @@ -30,30 +31,30 @@ @Description(name = "week_of_year", description = "Returns the week of the year of the given date.") public class WeekOfYear extends UDF { - private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); - private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormat.forPattern( - "yyyy-MM-dd HH:mm:ss"); - private final DateTime dateTime = new DateTime(); + private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); + private static final DateTimeFormatter DATE_TIME_FORMATTER = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss"); + private final DateTime dateTime = new DateTime(); - public Integer eval(String dateString) { - if (dateString == null) { - return null; - } - DateTimeFormatter formatter = DATE_TIME_FORMATTER; - if (dateString.length() <= 10) { - formatter = DATE_FORMATTER; - } - try { - return formatter.parseDateTime(dateString).getWeekOfWeekyear(); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + public Integer eval(String dateString) { + if (dateString == null) { + return null; + } + DateTimeFormatter formatter = DATE_TIME_FORMATTER; + if (dateString.length() <= 10) { + formatter = DATE_FORMATTER; + } + try { + return formatter.parseDateTime(dateString).getWeekOfWeekyear(); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } + } - public Integer eval(Timestamp t) { - if (t == null) { - return null; - } - return dateTime.withMillis(t.getTime()).getWeekOfWeekyear(); + public Integer eval(Timestamp t) { + if (t == null) { + return null; } + return dateTime.withMillis(t.getTime()).getWeekOfWeekyear(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Year.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Year.java index 38ecfb878..06e544ce7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Year.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/date/Year.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.date; import java.sql.Timestamp; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; @@ -30,37 +31,37 @@ @Description(name = "_year", description = "Returns the year of date.") public class Year extends UDF { - private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); - private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormat.forPattern( - "yyyy-MM-dd HH:mm:ss"); - private final DateTime dateTime = new DateTime(); + private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormat.forPattern("yyyy-MM-dd"); + private static final DateTimeFormatter DATE_TIME_FORMATTER = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss"); + private final DateTime dateTime = new DateTime(); - public Integer eval(String dateString) { - if (dateString == null) { - return null; - } - DateTimeFormatter formatter = DATE_TIME_FORMATTER; - if (dateString.length() <= 10) { - formatter = DATE_FORMATTER; - } - try { - return formatter.parseDateTime(dateString).getYear(); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + public Integer eval(String dateString) { + if (dateString == null) { + return null; + } + DateTimeFormatter formatter = DATE_TIME_FORMATTER; + if (dateString.length() <= 10) { + formatter = DATE_FORMATTER; + } + try { + return formatter.parseDateTime(dateString).getYear(); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } + } - public Integer eval(Timestamp t) { - if (t == null) { - return null; - } - return dateTime.withMillis(t.getTime()).getYear(); + public Integer eval(Timestamp t) { + if (t == null) { + return null; } + return dateTime.withMillis(t.getTime()).getYear(); + } - public Integer eval(Long t) { - if (t == null) { - return null; - } - return dateTime.withMillis(t).getYear(); + public Integer eval(Long t) { + if (t == null) { + return null; } + return dateTime.withMillis(t).getYear(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/math/E.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/math/E.java index 116aa87e9..95ada8dd0 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/math/E.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/math/E.java @@ -25,7 +25,7 @@ @Description(name = "E", description = "returns the euler constant.") public class E extends UDF { - public Double eval() { - return java.lang.Math.E; - } + public Double eval() { + return java.lang.Math.E; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/math/Log2.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/math/Log2.java index 45bcb0c16..3276d1db8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/math/Log2.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/math/Log2.java @@ -25,14 +25,13 @@ @Description(name = "log2", description = "Returns the logarithm base 2.") public class Log2 extends UDF { - private static final double LOG2 = Math.log(2.0); + private static final double LOG2 = Math.log(2.0); - public Double eval(Double a) { - if (a == null || a <= 0.0) { - return null; - } else { - return Math.log(a) / LOG2; - } + public Double eval(Double a) { + if (a == null || a <= 0.0) { + return null; + } else { + return Math.log(a) / LOG2; } - + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/math/Round.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/math/Round.java index b0742bd5d..8ca30d3bc 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/math/Round.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/math/Round.java @@ -21,47 +21,48 @@ import java.math.BigDecimal; import java.math.RoundingMode; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; @Description(name = "round", description = "round x to d decimal places") public class Round extends UDF { - private Double eval(Double n, int i) { - double d = n; - if (Double.isNaN(d) || Double.isInfinite(d)) { - return d; - } else { - return BigDecimal.valueOf(d).setScale(i, RoundingMode.HALF_UP).doubleValue(); - } + private Double eval(Double n, int i) { + double d = n; + if (Double.isNaN(d) || Double.isInfinite(d)) { + return d; + } else { + return BigDecimal.valueOf(d).setScale(i, RoundingMode.HALF_UP).doubleValue(); } + } - public Double eval(Double n) { - if (n == null) { - return null; - } - return eval(n, 0); + public Double eval(Double n) { + if (n == null) { + return null; } + return eval(n, 0); + } - public Long eval(Long n) { - return n; - } + public Long eval(Long n) { + return n; + } - public Integer eval(Integer n) { - return n; - } + public Integer eval(Integer n) { + return n; + } - public Double eval(Double n, Long i) { - if ((n == null) || (i == null)) { - return null; - } - return eval(n, i.intValue()); + public Double eval(Double n, Long i) { + if ((n == null) || (i == null)) { + return null; } + return eval(n, i.intValue()); + } - public Double eval(Double n, Integer i) { - if ((n == null) || (i == null)) { - return null; - } - return eval(n, i.intValue()); + public Double eval(Double n, Integer i) { + if ((n == null) || (i == null)) { + return null; } + return eval(n, i.intValue()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/Direction.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/Direction.java index 157be6362..d67433e67 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/Direction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/Direction.java @@ -30,12 +30,12 @@ @Description(name = "direction", description = "Returns direction for edge") public class Direction extends UDF implements GraphMetaFieldAccessFunction { - public BinaryString eval(RowEdge edge) { - return BinaryString.fromString(edge.getDirect().name()); - } + public BinaryString eval(RowEdge edge) { + return BinaryString.fromString(edge.getDirect().name()); + } - @Override - public RelDataType getReturnRelDataType(GQLJavaTypeFactory typeFactory) { - return GraphSchemaUtil.getCurrentGraphDirectionType(typeFactory); - } + @Override + public RelDataType getReturnRelDataType(GQLJavaTypeFactory typeFactory) { + return GraphSchemaUtil.getCurrentGraphDirectionType(typeFactory); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/EdgeSrcId.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/EdgeSrcId.java index 7bcf5eeef..0394631cf 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/EdgeSrcId.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/EdgeSrcId.java @@ -29,12 +29,12 @@ @Description(name = "srcId", description = "Returns srcId for edge") public class EdgeSrcId extends UDF implements GraphMetaFieldAccessFunction { - public Object eval(RowEdge edge) { - return edge.getSrcId(); - } + public Object eval(RowEdge edge) { + return edge.getSrcId(); + } - @Override - public RelDataType getReturnRelDataType(GQLJavaTypeFactory typeFactory) { - return GraphSchemaUtil.getCurrentGraphEdgeSrcIdType(typeFactory); - } + @Override + public RelDataType getReturnRelDataType(GQLJavaTypeFactory typeFactory) { + return GraphSchemaUtil.getCurrentGraphEdgeSrcIdType(typeFactory); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/EdgeTargetId.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/EdgeTargetId.java index 684638e3f..5d5033097 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/EdgeTargetId.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/EdgeTargetId.java @@ -29,12 +29,12 @@ @Description(name = "targetId", description = "Returns targetId for edge") public class EdgeTargetId extends UDF implements GraphMetaFieldAccessFunction { - public Object eval(RowEdge edge) { - return edge.getTargetId(); - } + public Object eval(RowEdge edge) { + return edge.getTargetId(); + } - @Override - public RelDataType getReturnRelDataType(GQLJavaTypeFactory typeFactory) { - return GraphSchemaUtil.getCurrentGraphEdgeTargetIdType(typeFactory); - } + @Override + public RelDataType getReturnRelDataType(GQLJavaTypeFactory typeFactory) { + return GraphSchemaUtil.getCurrentGraphEdgeTargetIdType(typeFactory); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/EdgeTimestamp.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/EdgeTimestamp.java index 914ba883e..db9b8b46b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/EdgeTimestamp.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/EdgeTimestamp.java @@ -30,16 +30,16 @@ @Description(name = "ts", description = "Returns ts for edge with timestamp") public class EdgeTimestamp extends UDF implements GraphMetaFieldAccessFunction { - public Long eval(IGraphElementWithTimeField edge) { - return edge.getTime(); - } + public Long eval(IGraphElementWithTimeField edge) { + return edge.getTime(); + } - @Override - public RelDataType getReturnRelDataType(GQLJavaTypeFactory typeFactory) { - if (GraphSchemaUtil.getCurrentGraphEdgeTimestampType(typeFactory).isPresent()) { - return GraphSchemaUtil.getCurrentGraphEdgeTimestampType(typeFactory).get(); - } else { - throw new GeaFlowDSLException("Cannot find timestamp type"); - } + @Override + public RelDataType getReturnRelDataType(GQLJavaTypeFactory typeFactory) { + if (GraphSchemaUtil.getCurrentGraphEdgeTimestampType(typeFactory).isPresent()) { + return GraphSchemaUtil.getCurrentGraphEdgeTimestampType(typeFactory).get(); + } else { + throw new GeaFlowDSLException("Cannot find timestamp type"); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/GraphMetaFieldAccessFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/GraphMetaFieldAccessFunction.java index 0d87e21a4..2ccb8724c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/GraphMetaFieldAccessFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/GraphMetaFieldAccessFunction.java @@ -24,5 +24,5 @@ public interface GraphMetaFieldAccessFunction { - RelDataType getReturnRelDataType(GQLJavaTypeFactory typeFactory); + RelDataType getReturnRelDataType(GQLJavaTypeFactory typeFactory); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/If.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/If.java index 8423e87d3..53a35dc08 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/If.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/If.java @@ -22,48 +22,50 @@ import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; -@Description(name = "if", description = "Return true value if condition is true, else return false value.") +@Description( + name = "if", + description = "Return true value if condition is true, else return false value.") public class If extends UDF { - public Boolean eval(Boolean condition, Boolean trueValue, Boolean falseValue) { - if (condition != null && condition) { - return trueValue; - } - return falseValue; + public Boolean eval(Boolean condition, Boolean trueValue, Boolean falseValue) { + if (condition != null && condition) { + return trueValue; } + return falseValue; + } - public Integer eval(Boolean condition, Integer trueValue, Integer falseValue) { - if (condition != null && condition) { - return trueValue; - } - return falseValue; + public Integer eval(Boolean condition, Integer trueValue, Integer falseValue) { + if (condition != null && condition) { + return trueValue; } + return falseValue; + } - public Double eval(Boolean condition, Double trueValue, Double falseValue) { - if (condition != null && condition) { - return trueValue; - } - return falseValue; + public Double eval(Boolean condition, Double trueValue, Double falseValue) { + if (condition != null && condition) { + return trueValue; } + return falseValue; + } - public Long eval(Boolean condition, Long trueValue, Long falseValue) { - if (condition != null && condition) { - return trueValue; - } - return falseValue; + public Long eval(Boolean condition, Long trueValue, Long falseValue) { + if (condition != null && condition) { + return trueValue; } + return falseValue; + } - public String eval(Boolean condition, String trueValue, String falseValue) { - if (condition != null && condition) { - return trueValue; - } - return falseValue; + public String eval(Boolean condition, String trueValue, String falseValue) { + if (condition != null && condition) { + return trueValue; } + return falseValue; + } - public Object eval(Boolean condition, Object trueValue, Object falseValue) { - if (condition != null && condition) { - return trueValue; - } - return falseValue; + public Object eval(Boolean condition, Object trueValue, Object falseValue) { + if (condition != null && condition) { + return trueValue; } + return falseValue; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/IsDecimal.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/IsDecimal.java index 7fe72abc1..651121968 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/IsDecimal.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/IsDecimal.java @@ -22,48 +22,50 @@ import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; -@Description(name = "is_decimal", description = "Returns true if only contains digits and is " - + "non-null, otherwise return false.") +@Description( + name = "is_decimal", + description = + "Returns true if only contains digits and is " + "non-null, otherwise return false.") public class IsDecimal extends UDF { - public boolean eval(String s) { - if (s == null) { - return false; - } - if (isInteger(s) || isLong(s) || isDouble(s)) { - return true; - } else { - return false; - } + public boolean eval(String s) { + if (s == null) { + return false; } + if (isInteger(s) || isLong(s) || isDouble(s)) { + return true; + } else { + return false; + } + } - private boolean isInteger(String s) { - boolean flag = true; - try { - Integer.parseInt(s); - } catch (NumberFormatException e) { - flag = false; - } - return flag; + private boolean isInteger(String s) { + boolean flag = true; + try { + Integer.parseInt(s); + } catch (NumberFormatException e) { + flag = false; } + return flag; + } - private boolean isLong(String s) { - boolean flag = true; - try { - Long.parseLong(s); - } catch (NumberFormatException e) { - flag = false; - } - return flag; + private boolean isLong(String s) { + boolean flag = true; + try { + Long.parseLong(s); + } catch (NumberFormatException e) { + flag = false; } + return flag; + } - private boolean isDouble(String s) { - boolean flag = true; - try { - Double.parseDouble(s); - } catch (NumberFormatException e) { - flag = false; - } - return flag; + private boolean isDouble(String s) { + boolean flag = true; + try { + Double.parseDouble(s); + } catch (NumberFormatException e) { + flag = false; } + return flag; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/IsDestinationOf.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/IsDestinationOf.java index 3beeb7a3d..59f42139e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/IsDestinationOf.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/IsDestinationOf.java @@ -30,23 +30,26 @@ * *

Implements ISO-GQL Section 19.10: <source/destination predicate> * - *

Syntax:

+ *

Syntax: + * *

  *   IS_DESTINATION_OF(node, edge)
  * 
* - *

Semantics:

- * Returns TRUE if the node is the destination of the edge, FALSE otherwise, or NULL if either operand is NULL. + *

Semantics: Returns TRUE if the node is the destination of the edge, FALSE otherwise, or + * NULL if either operand is NULL. + * + *

ISO-GQL Rules: * - *

ISO-GQL Rules:

*
    - *
  • If node or edge is null, result is Unknown (null)
  • - *
  • If edge is undirected, result is False
  • - *
  • If node.id equals edge.targetId, result is True
  • - *
  • Otherwise, result is False
  • + *
  • If node or edge is null, result is Unknown (null) + *
  • If edge is undirected, result is False + *
  • If node.id equals edge.targetId, result is True + *
  • Otherwise, result is False *
* - *

Example:

+ *

Example: + * *

  * MATCH (a) -[e]-> (b)
  * WHERE IS_DESTINATION_OF(b, e)
@@ -55,26 +58,24 @@
  */
 @Description(
     name = "is_destination_of",
-    description = "ISO-GQL Destination Predicate: Returns TRUE if node is the destination of edge, "
-        + "FALSE if not, NULL if either operand is NULL. Follows ISO-GQL three-valued logic."
-)
+    description =
+        "ISO-GQL Destination Predicate: Returns TRUE if node is the destination of edge, "
+            + "FALSE if not, NULL if either operand is NULL. Follows ISO-GQL three-valued logic.")
 public class IsDestinationOf extends UDF {
 
-    /**
-     * Evaluates IS DESTINATION OF predicate.
-     *
-     * @param nodeValue vertex/node to check (should be RowVertex)
-     * @param edgeValue edge to check (should be RowEdge)
-     * @return Boolean: true if node is destination of edge, false if not, null if either is null
-     */
-    public Boolean eval(Object nodeValue, Object edgeValue) {
-        return SourceDestinationFunctions.isDestinationOf(nodeValue, edgeValue);
-    }
+  /**
+   * Evaluates IS DESTINATION OF predicate.
+   *
+   * @param nodeValue vertex/node to check (should be RowVertex)
+   * @param edgeValue edge to check (should be RowEdge)
+   * @return Boolean: true if node is destination of edge, false if not, null if either is null
+   */
+  public Boolean eval(Object nodeValue, Object edgeValue) {
+    return SourceDestinationFunctions.isDestinationOf(nodeValue, edgeValue);
+  }
 
-    /**
-     * Type-specific overload for better type checking.
-     */
-    public Boolean eval(RowVertex node, RowEdge edge) {
-        return SourceDestinationFunctions.isDestinationOf(node, edge);
-    }
+  /** Type-specific overload for better type checking. */
+  public Boolean eval(RowVertex node, RowEdge edge) {
+    return SourceDestinationFunctions.isDestinationOf(node, edge);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/IsNotDestinationOf.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/IsNotDestinationOf.java
index b3cf5c44e..d89ee4d2c 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/IsNotDestinationOf.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/IsNotDestinationOf.java
@@ -30,21 +30,24 @@
  *
  * 

Implements ISO-GQL Section 19.10: <source/destination predicate> * - *

Syntax:

+ *

Syntax: + * *

  *   IS_NOT_DESTINATION_OF(node, edge)
  * 
* - *

Semantics:

- * Returns TRUE if the node is NOT the destination of the edge, FALSE if it is, or NULL if either operand is NULL. + *

Semantics: Returns TRUE if the node is NOT the destination of the edge, FALSE if it is, + * or NULL if either operand is NULL. + * + *

ISO-GQL Rules: * - *

ISO-GQL Rules:

*
    - *
  • If node or edge is null, result is Unknown (null)
  • - *
  • Otherwise, returns the negation of IS_DESTINATION_OF result
  • + *
  • If node or edge is null, result is Unknown (null) + *
  • Otherwise, returns the negation of IS_DESTINATION_OF result *
* - *

Example:

+ *

Example: + * *

  * MATCH (a) -[e]-> (b)
  * WHERE IS_NOT_DESTINATION_OF(a, e)  -- Always TRUE for this pattern
@@ -53,26 +56,25 @@
  */
 @Description(
     name = "is_not_destination_of",
-    description = "ISO-GQL Destination Predicate: Returns TRUE if node is NOT the destination of edge, "
-        + "FALSE if it is, NULL if either operand is NULL. Follows ISO-GQL three-valued logic."
-)
+    description =
+        "ISO-GQL Destination Predicate: Returns TRUE if node is NOT the destination of edge, "
+            + "FALSE if it is, NULL if either operand is NULL. Follows ISO-GQL three-valued logic.")
 public class IsNotDestinationOf extends UDF {
 
-    /**
-     * Evaluates IS NOT DESTINATION OF predicate.
-     *
-     * @param nodeValue vertex/node to check (should be RowVertex)
-     * @param edgeValue edge to check (should be RowEdge)
-     * @return Boolean: true if node is NOT destination of edge, false if it is, null if either is null
-     */
-    public Boolean eval(Object nodeValue, Object edgeValue) {
-        return SourceDestinationFunctions.isNotDestinationOf(nodeValue, edgeValue);
-    }
+  /**
+   * Evaluates IS NOT DESTINATION OF predicate.
+   *
+   * @param nodeValue vertex/node to check (should be RowVertex)
+   * @param edgeValue edge to check (should be RowEdge)
+   * @return Boolean: true if node is NOT destination of edge, false if it is, null if either is
+   *     null
+   */
+  public Boolean eval(Object nodeValue, Object edgeValue) {
+    return SourceDestinationFunctions.isNotDestinationOf(nodeValue, edgeValue);
+  }
 
-    /**
-     * Type-specific overload for better type checking.
-     */
-    public Boolean eval(RowVertex node, RowEdge edge) {
-        return SourceDestinationFunctions.isNotDestinationOf(node, edge);
-    }
+  /** Type-specific overload for better type checking. */
+  public Boolean eval(RowVertex node, RowEdge edge) {
+    return SourceDestinationFunctions.isNotDestinationOf(node, edge);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/IsNotSourceOf.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/IsNotSourceOf.java
index c95c616c6..50af4ad62 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/IsNotSourceOf.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/IsNotSourceOf.java
@@ -30,21 +30,24 @@
  *
  * 

Implements ISO-GQL Section 19.10: <source/destination predicate> * - *

Syntax:

+ *

Syntax: + * *

  *   IS_NOT_SOURCE_OF(node, edge)
  * 
* - *

Semantics:

- * Returns TRUE if the node is NOT the source of the edge, FALSE if it is, or NULL if either operand is NULL. + *

Semantics: Returns TRUE if the node is NOT the source of the edge, FALSE if it is, or + * NULL if either operand is NULL. + * + *

ISO-GQL Rules: * - *

ISO-GQL Rules:

*
    - *
  • If node or edge is null, result is Unknown (null)
  • - *
  • Otherwise, returns the negation of IS_SOURCE_OF result
  • + *
  • If node or edge is null, result is Unknown (null) + *
  • Otherwise, returns the negation of IS_SOURCE_OF result *
* - *

Example:

+ *

Example: + * *

  * MATCH (a) -[e]-> (b)
  * WHERE IS_NOT_SOURCE_OF(a, e)  -- Always FALSE for this pattern
@@ -53,26 +56,24 @@
  */
 @Description(
     name = "is_not_source_of",
-    description = "ISO-GQL Source Predicate: Returns TRUE if node is NOT the source of edge, "
-        + "FALSE if it is, NULL if either operand is NULL. Follows ISO-GQL three-valued logic."
-)
+    description =
+        "ISO-GQL Source Predicate: Returns TRUE if node is NOT the source of edge, "
+            + "FALSE if it is, NULL if either operand is NULL. Follows ISO-GQL three-valued logic.")
 public class IsNotSourceOf extends UDF {
 
-    /**
-     * Evaluates IS NOT SOURCE OF predicate.
-     *
-     * @param nodeValue vertex/node to check (should be RowVertex)
-     * @param edgeValue edge to check (should be RowEdge)
-     * @return Boolean: true if node is NOT source of edge, false if it is, null if either is null
-     */
-    public Boolean eval(Object nodeValue, Object edgeValue) {
-        return SourceDestinationFunctions.isNotSourceOf(nodeValue, edgeValue);
-    }
+  /**
+   * Evaluates IS NOT SOURCE OF predicate.
+   *
+   * @param nodeValue vertex/node to check (should be RowVertex)
+   * @param edgeValue edge to check (should be RowEdge)
+   * @return Boolean: true if node is NOT source of edge, false if it is, null if either is null
+   */
+  public Boolean eval(Object nodeValue, Object edgeValue) {
+    return SourceDestinationFunctions.isNotSourceOf(nodeValue, edgeValue);
+  }
 
-    /**
-     * Type-specific overload for better type checking.
-     */
-    public Boolean eval(RowVertex node, RowEdge edge) {
-        return SourceDestinationFunctions.isNotSourceOf(node, edge);
-    }
+  /** Type-specific overload for better type checking. */
+  public Boolean eval(RowVertex node, RowEdge edge) {
+    return SourceDestinationFunctions.isNotSourceOf(node, edge);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/IsSourceOf.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/IsSourceOf.java
index c6d1cccde..12dfe7387 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/IsSourceOf.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/IsSourceOf.java
@@ -30,23 +30,26 @@
  *
  * 

Implements ISO-GQL Section 19.10: <source/destination predicate> * - *

Syntax:

+ *

Syntax: + * *

  *   IS_SOURCE_OF(node, edge)
  * 
* - *

Semantics:

- * Returns TRUE if the node is the source of the edge, FALSE otherwise, or NULL if either operand is NULL. + *

Semantics: Returns TRUE if the node is the source of the edge, FALSE otherwise, or NULL + * if either operand is NULL. + * + *

ISO-GQL Rules: * - *

ISO-GQL Rules:

*
    - *
  • If node or edge is null, result is Unknown (null)
  • - *
  • If edge is undirected, result is False
  • - *
  • If node.id equals edge.srcId, result is True
  • - *
  • Otherwise, result is False
  • + *
  • If node or edge is null, result is Unknown (null) + *
  • If edge is undirected, result is False + *
  • If node.id equals edge.srcId, result is True + *
  • Otherwise, result is False *
* - *

Example:

+ *

Example: + * *

  * MATCH (a) -[e]-> (b)
  * WHERE IS_SOURCE_OF(a, e)
@@ -55,26 +58,24 @@
  */
 @Description(
     name = "is_source_of",
-    description = "ISO-GQL Source Predicate: Returns TRUE if node is the source of edge, "
-        + "FALSE if not, NULL if either operand is NULL. Follows ISO-GQL three-valued logic."
-)
+    description =
+        "ISO-GQL Source Predicate: Returns TRUE if node is the source of edge, "
+            + "FALSE if not, NULL if either operand is NULL. Follows ISO-GQL three-valued logic.")
 public class IsSourceOf extends UDF {
 
-    /**
-     * Evaluates IS SOURCE OF predicate.
-     *
-     * @param nodeValue vertex/node to check (should be RowVertex)
-     * @param edgeValue edge to check (should be RowEdge)
-     * @return Boolean: true if node is source of edge, false if not, null if either is null
-     */
-    public Boolean eval(Object nodeValue, Object edgeValue) {
-        return SourceDestinationFunctions.isSourceOf(nodeValue, edgeValue);
-    }
+  /**
+   * Evaluates IS SOURCE OF predicate.
+   *
+   * @param nodeValue vertex/node to check (should be RowVertex)
+   * @param edgeValue edge to check (should be RowEdge)
+   * @return Boolean: true if node is source of edge, false if not, null if either is null
+   */
+  public Boolean eval(Object nodeValue, Object edgeValue) {
+    return SourceDestinationFunctions.isSourceOf(nodeValue, edgeValue);
+  }
 
-    /**
-     * Type-specific overload for better type checking.
-     */
-    public Boolean eval(RowVertex node, RowEdge edge) {
-        return SourceDestinationFunctions.isSourceOf(node, edge);
-    }
+  /** Type-specific overload for better type checking. */
+  public Boolean eval(RowVertex node, RowEdge edge) {
+    return SourceDestinationFunctions.isSourceOf(node, edge);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/Label.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/Label.java
index 82b40e6d5..00ecf62d5 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/Label.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/Label.java
@@ -31,16 +31,16 @@
 @Description(name = "label", description = "Returns label for edge or vertex")
 public class Label extends UDF implements GraphMetaFieldAccessFunction {
 
-    public BinaryString eval(RowEdge edge) {
-        return edge.getBinaryLabel();
-    }
+  public BinaryString eval(RowEdge edge) {
+    return edge.getBinaryLabel();
+  }
 
-    public BinaryString eval(RowVertex vertex) {
-        return vertex.getBinaryLabel();
-    }
+  public BinaryString eval(RowVertex vertex) {
+    return vertex.getBinaryLabel();
+  }
 
-    @Override
-    public RelDataType getReturnRelDataType(GQLJavaTypeFactory typeFactory) {
-        return GraphSchemaUtil.getCurrentGraphLabelType(typeFactory);
-    }
+  @Override
+  public RelDataType getReturnRelDataType(GQLJavaTypeFactory typeFactory) {
+    return GraphSchemaUtil.getCurrentGraphLabelType(typeFactory);
+  }
 }
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/PropertyExists.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/PropertyExists.java
index b8b41b5cf..780357a4d 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/PropertyExists.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/PropertyExists.java
@@ -31,22 +31,25 @@
  *
  * 

Implements ISO-GQL Section 19.13: <property_exists predicate> * - *

Syntax:

+ *

Syntax: + * *

  *   PROPERTY_EXISTS(element, property_name)
  * 
* - *

Semantics:

- * Returns TRUE if the graph element has the specified property, FALSE otherwise, or NULL if the element is NULL. + *

Semantics: Returns TRUE if the graph element has the specified property, FALSE + * otherwise, or NULL if the element is NULL. + * + *

ISO-GQL Rules: * - *

ISO-GQL Rules:

*
    - *
  • If element is null, result is Unknown (null)
  • - *
  • If element has a property with the specified name, result is True
  • - *
  • Otherwise, result is False
  • + *
  • If element is null, result is Unknown (null) + *
  • If element has a property with the specified name, result is True + *
  • Otherwise, result is False *
* - *

Example:

+ *

Example: + * *

  * MATCH (p:Person)
  * WHERE PROPERTY_EXISTS(p, 'email')
@@ -55,77 +58,77 @@
  */
 @Description(
     name = "property_exists",
-    description = "ISO-GQL Property Exists Predicate: Returns TRUE if the graph element has "
-        + "the specified property, FALSE if not, NULL if element is NULL. "
-        + "Follows ISO-GQL three-valued logic."
-)
+    description =
+        "ISO-GQL Property Exists Predicate: Returns TRUE if the graph element has "
+            + "the specified property, FALSE if not, NULL if element is NULL. "
+            + "Follows ISO-GQL three-valued logic.")
 public class PropertyExists extends UDF {
 
-    /**
-     * Evaluates PROPERTY_EXISTS predicate for any graph element.
-     *
-     * 

This implementation follows the established GeaFlow pattern for ISO-GQL predicates, - * delegating to {@link PropertyExistsFunctions} utility class for consistent validation - * and error handling across the framework. - * - *

Implementation Strategy: - * Property existence validation relies on compile-time checking through GeaFlow's SQL - * optimizer and type system (StructType). At runtime, this function validates argument - * types and provides meaningful error messages. - * - *

This approach is consistent with: - *

    - *
  • Other ISO-GQL predicates (IS_SOURCE_OF, IS_DESTINATION_OF)
  • - *
  • GeaFlow's Row interface design (indexed access only)
  • - *
  • Three-layer architecture: UDF → Utility → Business Logic
  • - *
- * - * @param element graph element to check (Row, RowVertex, or RowEdge) - * @param propertyName name of property to check - * @return Boolean: null if element is null, true if property exists, false otherwise - * @throws IllegalArgumentException if element is not a valid graph element type - * @throws IllegalArgumentException if propertyName is null or empty - */ - public Boolean eval(Object element, String propertyName) { - return PropertyExistsFunctions.propertyExists(element, propertyName); - } + /** + * Evaluates PROPERTY_EXISTS predicate for any graph element. + * + *

This implementation follows the established GeaFlow pattern for ISO-GQL predicates, + * delegating to {@link PropertyExistsFunctions} utility class for consistent validation and error + * handling across the framework. + * + *

Implementation Strategy: Property existence validation relies on compile-time + * checking through GeaFlow's SQL optimizer and type system (StructType). At runtime, this + * function validates argument types and provides meaningful error messages. + * + *

This approach is consistent with: + * + *

    + *
  • Other ISO-GQL predicates (IS_SOURCE_OF, IS_DESTINATION_OF) + *
  • GeaFlow's Row interface design (indexed access only) + *
  • Three-layer architecture: UDF → Utility → Business Logic + *
+ * + * @param element graph element to check (Row, RowVertex, or RowEdge) + * @param propertyName name of property to check + * @return Boolean: null if element is null, true if property exists, false otherwise + * @throws IllegalArgumentException if element is not a valid graph element type + * @throws IllegalArgumentException if propertyName is null or empty + */ + public Boolean eval(Object element, String propertyName) { + return PropertyExistsFunctions.propertyExists(element, propertyName); + } - /** - * Type-specific overload for RowVertex. - * - *

Provides better type inference and error messages when called with vertex elements. - * - * @param vertex vertex to check - * @param propertyName name of property to check - * @return Boolean: null if vertex is null, true if property exists, false otherwise - */ - public Boolean eval(RowVertex vertex, String propertyName) { - return PropertyExistsFunctions.propertyExists(vertex, propertyName); - } + /** + * Type-specific overload for RowVertex. + * + *

Provides better type inference and error messages when called with vertex elements. + * + * @param vertex vertex to check + * @param propertyName name of property to check + * @return Boolean: null if vertex is null, true if property exists, false otherwise + */ + public Boolean eval(RowVertex vertex, String propertyName) { + return PropertyExistsFunctions.propertyExists(vertex, propertyName); + } - /** - * Type-specific overload for RowEdge. - * - *

Provides better type inference and error messages when called with edge elements. - * - * @param edge edge to check - * @param propertyName name of property to check - * @return Boolean: null if edge is null, true if property exists, false otherwise - */ - public Boolean eval(RowEdge edge, String propertyName) { - return PropertyExistsFunctions.propertyExists(edge, propertyName); - } + /** + * Type-specific overload for RowEdge. + * + *

Provides better type inference and error messages when called with edge elements. + * + * @param edge edge to check + * @param propertyName name of property to check + * @return Boolean: null if edge is null, true if property exists, false otherwise + */ + public Boolean eval(RowEdge edge, String propertyName) { + return PropertyExistsFunctions.propertyExists(edge, propertyName); + } - /** - * Type-specific overload for Row. - * - *

Provides better type inference and error messages when called with row elements. - * - * @param row row to check - * @param propertyName name of property to check - * @return Boolean: null if row is null, true if property exists, false otherwise - */ - public Boolean eval(Row row, String propertyName) { - return PropertyExistsFunctions.propertyExists(row, propertyName); - } + /** + * Type-specific overload for Row. + * + *

Provides better type inference and error messages when called with row elements. + * + * @param row row to check + * @param propertyName name of property to check + * @return Boolean: null if row is null, true if property exists, false otherwise + */ + public Boolean eval(Row row, String propertyName) { + return PropertyExistsFunctions.propertyExists(row, propertyName); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/VertexId.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/VertexId.java index 9aaf2dcbf..1f5a65109 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/VertexId.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/other/VertexId.java @@ -29,12 +29,12 @@ @Description(name = "id", description = "Returns id for vertex") public class VertexId extends UDF implements GraphMetaFieldAccessFunction { - public Object eval(RowVertex vertex) { - return vertex.getId(); - } + public Object eval(RowVertex vertex) { + return vertex.getId(); + } - @Override - public RelDataType getReturnRelDataType(GQLJavaTypeFactory typeFactory) { - return GraphSchemaUtil.getCurrentGraphVertexIdType(typeFactory); - } + @Override + public RelDataType getReturnRelDataType(GQLJavaTypeFactory typeFactory) { + return GraphSchemaUtil.getCurrentGraphVertexIdType(typeFactory); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Ascii2String.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Ascii2String.java index 26e8459c6..4ddd26dea 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Ascii2String.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Ascii2String.java @@ -25,17 +25,17 @@ @Description(name = "ascii2str", description = "Convert ascii code to string.") public class Ascii2String extends UDF { - public String eval(Integer ascii) { - if (ascii == null) { - return null; - } - return new String(new byte[]{ascii.byteValue()}); + public String eval(Integer ascii) { + if (ascii == null) { + return null; } + return new String(new byte[] {ascii.byteValue()}); + } - public String eval(Long ascii) { - if (ascii == null) { - return null; - } - return eval(ascii.intValue()); + public String eval(Long ascii) { + if (ascii == null) { + return null; } + return eval(ascii.intValue()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Base64Decode.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Base64Decode.java index 9b9b4a3ce..cde1a20a9 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Base64Decode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Base64Decode.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.string; import java.nio.charset.StandardCharsets; + import org.apache.commons.codec.binary.Base64; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; @@ -27,9 +28,8 @@ @Description(name = "base64_decode", description = "Decode the string by Base64.") public class Base64Decode extends UDF { - public String eval(String s) { - return new String( - Base64.decodeBase64(s.getBytes(StandardCharsets.UTF_8)), StandardCharsets.UTF_8); - } - + public String eval(String s) { + return new String( + Base64.decodeBase64(s.getBytes(StandardCharsets.UTF_8)), StandardCharsets.UTF_8); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Base64Encode.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Base64Encode.java index 66e605a10..cbf9a2c38 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Base64Encode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Base64Encode.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.string; import java.nio.charset.StandardCharsets; + import org.apache.commons.codec.binary.Base64; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; @@ -27,9 +28,8 @@ @Description(name = "base64_encode", description = "Encode the string by Base64.") public class Base64Encode extends UDF { - public String eval(String s) { - return new String( - Base64.encodeBase64(s.getBytes(StandardCharsets.UTF_8)), StandardCharsets.UTF_8); - } - + public String eval(String s) { + return new String( + Base64.encodeBase64(s.getBytes(StandardCharsets.UTF_8)), StandardCharsets.UTF_8); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Concat.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Concat.java index 118dfee58..ea6d2072b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Concat.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Concat.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.string; import java.util.Objects; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; @@ -27,23 +28,23 @@ @Description(name = "concat", description = "Concat strings to one string.") public class Concat extends UDF { - public String eval(String... args) { - if (Objects.isNull(args)) { - return null; - } - StringBuilder sb = new StringBuilder(); - for (String arg : args) { - if (arg != null) { - sb.append(arg); - } - } - return sb.toString(); + public String eval(String... args) { + if (Objects.isNull(args)) { + return null; + } + StringBuilder sb = new StringBuilder(); + for (String arg : args) { + if (arg != null) { + sb.append(arg); + } } + return sb.toString(); + } - public BinaryString eval(BinaryString... args) { - if (Objects.isNull(args)) { - return null; - } - return BinaryString.concat(args); + public BinaryString eval(BinaryString... args) { + if (Objects.isNull(args)) { + return null; } + return BinaryString.concat(args); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/ConcatWS.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/ConcatWS.java index 2c5c4cbd4..b1c5b9542 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/ConcatWS.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/ConcatWS.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.udf.table.string; import java.util.Objects; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.dsl.common.function.Description; @@ -28,17 +29,17 @@ @Description(name = "concat_ws", description = "Concat strings to one by the separator string.") public class ConcatWS extends UDF { - public String eval(String separator, String... args) { - return StringUtils.join(args, separator); - } + public String eval(String separator, String... args) { + return StringUtils.join(args, separator); + } - public BinaryString eval(String separator, BinaryString... args) { - BinaryString sep = Objects.isNull(separator) ? BinaryString.EMPTY_STRING : - BinaryString.fromString(separator); - return BinaryString.concatWs(sep, args); - } + public BinaryString eval(String separator, BinaryString... args) { + BinaryString sep = + Objects.isNull(separator) ? BinaryString.EMPTY_STRING : BinaryString.fromString(separator); + return BinaryString.concatWs(sep, args); + } - public BinaryString eval(BinaryString separator, BinaryString... args) { - return BinaryString.concatWs(separator, args); - } + public BinaryString eval(BinaryString separator, BinaryString... args) { + return BinaryString.concatWs(separator, args); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/GetJsonObject.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/GetJsonObject.java index fd2af9319..6eea0d64d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/GetJsonObject.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/GetJsonObject.java @@ -19,7 +19,6 @@ package org.apache.geaflow.dsl.udf.table.string; -import com.google.common.collect.Iterators; import java.util.ArrayList; import java.util.Iterator; import java.util.LinkedHashMap; @@ -27,6 +26,7 @@ import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; import org.codehaus.jackson.JsonFactory; @@ -35,232 +35,237 @@ import org.codehaus.jackson.map.type.TypeFactory; import org.codehaus.jackson.type.JavaType; +import com.google.common.collect.Iterators; + @Description(name = "get_json_object", description = "parse string from json string.") public class GetJsonObject extends UDF { - private static final JsonFactory JSON_FACTORY = new JsonFactory(); - private static final JavaType MAP_TYPE = TypeFactory.fromClass(Map.class); - private static final JavaType LIST_TYPE = TypeFactory.fromClass(List.class); + private static final JsonFactory JSON_FACTORY = new JsonFactory(); + private static final JavaType MAP_TYPE = TypeFactory.fromClass(Map.class); + private static final JavaType LIST_TYPE = TypeFactory.fromClass(List.class); - static { - JSON_FACTORY.enable(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS); - } + static { + JSON_FACTORY.enable(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS); + } - private final Pattern patternKey = Pattern.compile("^([a-zA-Z0-9_\\-\\:\\s]+).*"); - private final Pattern patternIndex = Pattern.compile("\\[([0-9]+|\\*)\\]"); - private final ObjectMapper mapper = new ObjectMapper(JSON_FACTORY); + private final Pattern patternKey = Pattern.compile("^([a-zA-Z0-9_\\-\\:\\s]+).*"); + private final Pattern patternIndex = Pattern.compile("\\[([0-9]+|\\*)\\]"); + private final ObjectMapper mapper = new ObjectMapper(JSON_FACTORY); - private Map extractObjectCache = new HashCache(); - private Map pathExprCache = new HashCache(); - private Map> indexListCache = new HashCache>(); - private Map mKeyGroup1Cache = new HashCache(); - private Map mKeyMatchesCache = new HashCache(); + private Map extractObjectCache = new HashCache(); + private Map pathExprCache = new HashCache(); + private Map> indexListCache = + new HashCache>(); + private Map mKeyGroup1Cache = new HashCache(); + private Map mKeyMatchesCache = new HashCache(); - private transient AddingList jsonList = new AddingList(); + private transient AddingList jsonList = new AddingList(); - static class HashCache extends LinkedHashMap { + static class HashCache extends LinkedHashMap { - private static final int CACHE_SIZE = 16; - private static final int INIT_SIZE = 32; - private static final float LOAD_FACTOR = 0.6f; + private static final int CACHE_SIZE = 16; + private static final int INIT_SIZE = 32; + private static final float LOAD_FACTOR = 0.6f; - HashCache() { - super(INIT_SIZE, LOAD_FACTOR); - } + HashCache() { + super(INIT_SIZE, LOAD_FACTOR); + } - private static final long serialVersionUID = 1; + private static final long serialVersionUID = 1; - @Override - protected boolean removeEldestEntry(Map.Entry eldest) { - return size() > CACHE_SIZE; - } + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > CACHE_SIZE; } + } - public String eval(String jsonString, String pathString) { - - if (pathString != null && !pathString.startsWith("$.")) { - pathString = "$." + pathString; - } - if (jsonString == null || jsonString.isEmpty() || pathString == null - || pathString.isEmpty() || pathString.charAt(0) != '$') { - return null; - } + public String eval(String jsonString, String pathString) { - int pathExprStart = 1; - boolean isRootArray = false; + if (pathString != null && !pathString.startsWith("$.")) { + pathString = "$." + pathString; + } + if (jsonString == null + || jsonString.isEmpty() + || pathString == null + || pathString.isEmpty() + || pathString.charAt(0) != '$') { + return null; + } - if (pathString.length() > 1) { - if (pathString.charAt(1) == '[') { - pathExprStart = 0; - isRootArray = true; - } else if (pathString.charAt(1) == '.') { - isRootArray = pathString.length() > 2 && pathString.charAt(2) == '['; - } else { - return null; - } - } + int pathExprStart = 1; + boolean isRootArray = false; - String[] pathExpr = pathExprCache.get(pathString); - if (pathExpr == null) { - pathExpr = pathString.split("\\.", -1); - pathExprCache.put(pathString, pathExpr); - } + if (pathString.length() > 1) { + if (pathString.charAt(1) == '[') { + pathExprStart = 0; + isRootArray = true; + } else if (pathString.charAt(1) == '.') { + isRootArray = pathString.length() > 2 && pathString.charAt(2) == '['; + } else { + return null; + } + } - Object extractObject = extractObjectCache.get(jsonString); - if (extractObject == null) { - JavaType javaType = isRootArray ? LIST_TYPE : MAP_TYPE; - try { - extractObject = mapper.readValue(jsonString, javaType); - } catch (Exception e) { - return null; - } - extractObjectCache.put(jsonString, extractObject); - } - for (int i = pathExprStart; i < pathExpr.length; i++) { - if (extractObject == null) { - return null; - } - extractObject = extract(extractObject, pathExpr[i], i == pathExprStart && isRootArray); - } - String result = null; - if (extractObject instanceof Map || extractObject instanceof List) { - try { - result = mapper.writeValueAsString(extractObject); - } catch (Exception e) { - return null; - } - } else if (extractObject != null) { - result = extractObject.toString(); - } else { - return null; - } - return result; + String[] pathExpr = pathExprCache.get(pathString); + if (pathExpr == null) { + pathExpr = pathString.split("\\.", -1); + pathExprCache.put(pathString, pathExpr); } - private Object extract(Object json, String path, boolean skipMapProc) { - if (!skipMapProc) { - Matcher mKey = null; - Boolean mKeyMatches = mKeyMatchesCache.get(path); - if (mKeyMatches == null) { - mKey = patternKey.matcher(path); - mKeyMatches = mKey.matches() ? Boolean.TRUE : Boolean.FALSE; - mKeyMatchesCache.put(path, mKeyMatches); - } - if (!mKeyMatches.booleanValue()) { - return null; - } + Object extractObject = extractObjectCache.get(jsonString); + if (extractObject == null) { + JavaType javaType = isRootArray ? LIST_TYPE : MAP_TYPE; + try { + extractObject = mapper.readValue(jsonString, javaType); + } catch (Exception e) { + return null; + } + extractObjectCache.put(jsonString, extractObject); + } + for (int i = pathExprStart; i < pathExpr.length; i++) { + if (extractObject == null) { + return null; + } + extractObject = extract(extractObject, pathExpr[i], i == pathExprStart && isRootArray); + } + String result = null; + if (extractObject instanceof Map || extractObject instanceof List) { + try { + result = mapper.writeValueAsString(extractObject); + } catch (Exception e) { + return null; + } + } else if (extractObject != null) { + result = extractObject.toString(); + } else { + return null; + } + return result; + } - String mKeyGroup1 = mKeyGroup1Cache.get(path); - if (mKeyGroup1 == null) { - if (mKey == null) { - mKey = patternKey.matcher(path); - mKeyMatches = mKey.matches() ? Boolean.TRUE : Boolean.FALSE; - mKeyMatchesCache.put(path, mKeyMatches); - if (!mKeyMatches.booleanValue()) { - return null; - } - } - mKeyGroup1 = mKey.group(1); - mKeyGroup1Cache.put(path, mKeyGroup1); - } - json = extract_json_withkey(json, mKeyGroup1); - } - // Cache indexList - ArrayList indexList = indexListCache.get(path); - if (indexList == null) { - Matcher mIndex = patternIndex.matcher(path); - indexList = new ArrayList(); - while (mIndex.find()) { - indexList.add(mIndex.group(1)); - } - indexListCache.put(path, indexList); - } + private Object extract(Object json, String path, boolean skipMapProc) { + if (!skipMapProc) { + Matcher mKey = null; + Boolean mKeyMatches = mKeyMatchesCache.get(path); + if (mKeyMatches == null) { + mKey = patternKey.matcher(path); + mKeyMatches = mKey.matches() ? Boolean.TRUE : Boolean.FALSE; + mKeyMatchesCache.put(path, mKeyMatches); + } + if (!mKeyMatches.booleanValue()) { + return null; + } - if (indexList.size() > 0) { - json = extract_json_withindex(json, indexList); + String mKeyGroup1 = mKeyGroup1Cache.get(path); + if (mKeyGroup1 == null) { + if (mKey == null) { + mKey = patternKey.matcher(path); + mKeyMatches = mKey.matches() ? Boolean.TRUE : Boolean.FALSE; + mKeyMatchesCache.put(path, mKeyMatches); + if (!mKeyMatches.booleanValue()) { + return null; + } } + mKeyGroup1 = mKey.group(1); + mKeyGroup1Cache.put(path, mKeyGroup1); + } + json = extract_json_withkey(json, mKeyGroup1); + } + // Cache indexList + ArrayList indexList = indexListCache.get(path); + if (indexList == null) { + Matcher mIndex = patternIndex.matcher(path); + indexList = new ArrayList(); + while (mIndex.find()) { + indexList.add(mIndex.group(1)); + } + indexListCache.put(path, indexList); + } - return json; + if (indexList.size() > 0) { + json = extract_json_withindex(json, indexList); } - private static class AddingList extends ArrayList { + return json; + } - @Override - public Iterator iterator() { - return Iterators.forArray(toArray()); - } + private static class AddingList extends ArrayList { - @Override - public void removeRange(int fromIndex, int toIndex) { - super.removeRange(fromIndex, toIndex); - } + @Override + public Iterator iterator() { + return Iterators.forArray(toArray()); } - @SuppressWarnings("unchecked") - private Object extract_json_withindex(Object json, ArrayList indexList) { + @Override + public void removeRange(int fromIndex, int toIndex) { + super.removeRange(fromIndex, toIndex); + } + } - jsonList.clear(); - jsonList.add(json); - for (String index : indexList) { - int targets = jsonList.size(); - if (index.equalsIgnoreCase("*")) { - for (Object array : jsonList) { - if (array instanceof List) { - for (int j = 0; j < ((List) array).size(); j++) { - jsonList.add(((List) array).get(j)); - } - } - } - } else { - for (Object array : jsonList) { - int indexValue = Integer.parseInt(index); - if (!(array instanceof List)) { - continue; - } - List list = (List) array; - if (indexValue >= list.size()) { - continue; - } - jsonList.add(list.get(indexValue)); - } - } - if (jsonList.size() == targets) { - return null; + @SuppressWarnings("unchecked") + private Object extract_json_withindex(Object json, ArrayList indexList) { + + jsonList.clear(); + jsonList.add(json); + for (String index : indexList) { + int targets = jsonList.size(); + if (index.equalsIgnoreCase("*")) { + for (Object array : jsonList) { + if (array instanceof List) { + for (int j = 0; j < ((List) array).size(); j++) { + jsonList.add(((List) array).get(j)); } - jsonList.removeRange(0, targets); + } } - if (jsonList.isEmpty()) { - return null; + } else { + for (Object array : jsonList) { + int indexValue = Integer.parseInt(index); + if (!(array instanceof List)) { + continue; + } + List list = (List) array; + if (indexValue >= list.size()) { + continue; + } + jsonList.add(list.get(indexValue)); } - return (jsonList.size() > 1) ? new ArrayList(jsonList) : jsonList.get(0); + } + if (jsonList.size() == targets) { + return null; + } + jsonList.removeRange(0, targets); + } + if (jsonList.isEmpty()) { + return null; } + return (jsonList.size() > 1) ? new ArrayList(jsonList) : jsonList.get(0); + } - @SuppressWarnings("unchecked") - private Object extract_json_withkey(Object json, String path) { - if (json instanceof List) { - List jsonList = new ArrayList(); - for (int i = 0; i < ((List) json).size(); i++) { - Object jsonElem = ((List) json).get(i); - Object jsonObj = null; - if (jsonElem instanceof Map) { - jsonObj = ((Map) jsonElem).get(path); - } else { - continue; - } - if (jsonObj instanceof List) { - for (int j = 0; j < ((List) jsonObj).size(); j++) { - jsonList.add(((List) jsonObj).get(j)); - } - } else if (jsonObj != null) { - jsonList.add(jsonObj); - } - } - return (jsonList.isEmpty()) ? null : jsonList; - } else if (json instanceof Map) { - return ((Map) json).get(path); + @SuppressWarnings("unchecked") + private Object extract_json_withkey(Object json, String path) { + if (json instanceof List) { + List jsonList = new ArrayList(); + for (int i = 0; i < ((List) json).size(); i++) { + Object jsonElem = ((List) json).get(i); + Object jsonObj = null; + if (jsonElem instanceof Map) { + jsonObj = ((Map) jsonElem).get(path); } else { - return null; + continue; + } + if (jsonObj instanceof List) { + for (int j = 0; j < ((List) jsonObj).size(); j++) { + jsonList.add(((List) jsonObj).get(j)); + } + } else if (jsonObj != null) { + jsonList.add(jsonObj); } + } + return (jsonList.isEmpty()) ? null : jsonList; + } else if (json instanceof Map) { + return ((Map) json).get(path); + } else { + return null; } - + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Hash.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Hash.java index d76510c84..1c4fea35e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Hash.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Hash.java @@ -25,10 +25,10 @@ @Description(name = "hash", description = "Returns the absolute value of the hash code for input.") public class Hash extends UDF { - public Integer eval(Object obj) { - if (obj == null) { - return null; - } - return Math.abs(obj.hashCode()); + public Integer eval(Object obj) { + if (obj == null) { + return null; } + return Math.abs(obj.hashCode()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/IndexOf.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/IndexOf.java index a5e1f6a9a..cef6a6b63 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/IndexOf.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/IndexOf.java @@ -23,36 +23,39 @@ import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; -@Description(name = "index_of", description = "Returns the position of the first occurrence of the target string in the string.") +@Description( + name = "index_of", + description = + "Returns the position of the first occurrence of the target string in the string.") public class IndexOf extends UDF { - public Integer eval(String str, String target) { - return eval(str, target, 0); - } + public Integer eval(String str, String target) { + return eval(str, target, 0); + } - public Integer eval(String str, String target, Integer index) { - if ((str == null) || (target == null) || (index == null)) { - return -1; - } - int fromIndex = index; - if (fromIndex < 0) { - fromIndex = 0; - } - return str.indexOf(target, fromIndex); + public Integer eval(String str, String target, Integer index) { + if ((str == null) || (target == null) || (index == null)) { + return -1; } - - public Integer eval(BinaryString str, BinaryString target) { - return eval(str, target, 0); + int fromIndex = index; + if (fromIndex < 0) { + fromIndex = 0; } + return str.indexOf(target, fromIndex); + } - public Integer eval(BinaryString str, BinaryString target, Integer index) { - if ((str == null) || (target == null) || (index == null)) { - return -1; - } - int fromIndex = index; - if (fromIndex < 0) { - fromIndex = 0; - } - return str.indexOf(target, fromIndex); + public Integer eval(BinaryString str, BinaryString target) { + return eval(str, target, 0); + } + + public Integer eval(BinaryString str, BinaryString target, Integer index) { + if ((str == null) || (target == null) || (index == null)) { + return -1; + } + int fromIndex = index; + if (fromIndex < 0) { + fromIndex = 0; } + return str.indexOf(target, fromIndex); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Instr.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Instr.java index 3bdca7156..e131753bf 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Instr.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Instr.java @@ -23,56 +23,58 @@ import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; -@Description(name = "instr", description = "Returns the position of the first occurrence of sub string in str.") +@Description( + name = "instr", + description = "Returns the position of the first occurrence of sub string in str.") public class Instr extends UDF { - public Long eval(String str, String target) { - return eval(str, target, 1L, 1L); - } + public Long eval(String str, String target) { + return eval(str, target, 1L, 1L); + } - public Long eval(String str, String target, Long from) { - return eval(str, target, from, 1L); - } + public Long eval(String str, String target, Long from) { + return eval(str, target, from, 1L); + } - public Long eval(String str, String target, Long from, Long nth) { - if (str == null || target == null || from == null || nth == null) { - return null; - } - if (nth <= 0) { - return null; - } - int fromIndex = from.intValue() - 1; - if (fromIndex < 0) { - return null; - } - for (int i = 0; i < nth; ++i) { - fromIndex = str.indexOf(target, fromIndex) + 1; - } - return (long) fromIndex; + public Long eval(String str, String target, Long from, Long nth) { + if (str == null || target == null || from == null || nth == null) { + return null; } - - public Long eval(BinaryString str, BinaryString target) { - return eval(str, target, 1L, 1L); + if (nth <= 0) { + return null; } - - public Long eval(BinaryString str, BinaryString target, Long from) { - return eval(str, target, from, 1L); + int fromIndex = from.intValue() - 1; + if (fromIndex < 0) { + return null; + } + for (int i = 0; i < nth; ++i) { + fromIndex = str.indexOf(target, fromIndex) + 1; } + return (long) fromIndex; + } - public Long eval(BinaryString str, BinaryString target, Long from, Long nth) { - if (str == null || target == null || from == null || nth == null) { - return null; - } - if (nth <= 0) { - return null; - } - int fromIndex = from.intValue() - 1; - if (fromIndex < 0) { - return null; - } - for (int i = 0; i < nth; ++i) { - fromIndex = str.indexOf(target, fromIndex) + 1; - } - return (long) fromIndex; + public Long eval(BinaryString str, BinaryString target) { + return eval(str, target, 1L, 1L); + } + + public Long eval(BinaryString str, BinaryString target, Long from) { + return eval(str, target, from, 1L); + } + + public Long eval(BinaryString str, BinaryString target, Long from, Long nth) { + if (str == null || target == null || from == null || nth == null) { + return null; + } + if (nth <= 0) { + return null; + } + int fromIndex = from.intValue() - 1; + if (fromIndex < 0) { + return null; + } + for (int i = 0; i < nth; ++i) { + fromIndex = str.indexOf(target, fromIndex) + 1; } + return (long) fromIndex; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/IsBlank.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/IsBlank.java index 78b22d486..598130edd 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/IsBlank.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/IsBlank.java @@ -26,7 +26,7 @@ @Description(name = "isBlank", description = "Returns whether the string is blank.") public class IsBlank extends UDF { - public Boolean eval(String s) { - return StringUtils.isBlank(s); - } + public Boolean eval(String s) { + return StringUtils.isBlank(s); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/KeyValue.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/KeyValue.java index 69a337204..cc696d6e8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/KeyValue.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/KeyValue.java @@ -23,33 +23,34 @@ import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; -@Description(name = "keyvalue", description = "Split the string to get key-value and return the value " - + "for specified key.") +@Description( + name = "keyvalue", + description = "Split the string to get key-value and return the value " + "for specified key.") public class KeyValue extends UDF { - public String eval(Object value, String lineDelimiter, String colDelimiter, String key) { - return eval(String.valueOf(value), lineDelimiter, colDelimiter, key); - } + public String eval(Object value, String lineDelimiter, String colDelimiter, String key) { + return eval(String.valueOf(value), lineDelimiter, colDelimiter, key); + } - public String eval(String value, String lineDelimiter, String colDelimiter, String key) { - if (value == null) { - return null; - } + public String eval(String value, String lineDelimiter, String colDelimiter, String key) { + if (value == null) { + return null; + } - String[] lines = StringUtils.splitByWholeSeparator(value, lineDelimiter); - for (String line : lines) { - if (StringUtils.isBlank(line)) { - continue; - } - String[] keyValue = StringUtils.splitByWholeSeparatorPreserveAllTokens(line, colDelimiter); - if (key.equals(keyValue[0])) { - if (keyValue.length == 2) { - return keyValue[1]; - } else { - return null; - } - } + String[] lines = StringUtils.splitByWholeSeparator(value, lineDelimiter); + for (String line : lines) { + if (StringUtils.isBlank(line)) { + continue; + } + String[] keyValue = StringUtils.splitByWholeSeparatorPreserveAllTokens(line, colDelimiter); + if (key.equals(keyValue[0])) { + if (keyValue.length == 2) { + return keyValue[1]; + } else { + return null; } - return null; + } } + return null; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/LTrim.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/LTrim.java index 7795349ea..c30e8fae4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/LTrim.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/LTrim.java @@ -27,25 +27,25 @@ @Description(name = "ltrim", description = "Returns a string with the left space removed.") public class LTrim extends UDF { - public String eval(String s) { - if (s == null) { - return null; - } - return StringUtils.stripStart(s, " "); + public String eval(String s) { + if (s == null) { + return null; } + return StringUtils.stripStart(s, " "); + } - public BinaryString eval(BinaryString s) { - if (s == null) { - return null; - } - int l = 0; - while (l < s.getLength()) { - if (s.getByte(l) == ' ') { - l++; - } else { - break; - } - } - return s.substring(l, s.getLength()); + public BinaryString eval(BinaryString s) { + if (s == null) { + return null; } + int l = 0; + while (l < s.getLength()) { + if (s.getByte(l) == ' ') { + l++; + } else { + break; + } + } + return s.substring(l, s.getLength()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Length.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Length.java index 583f3e2d3..daa4cac36 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Length.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Length.java @@ -26,10 +26,10 @@ @Description(name = "length", description = "Returns the length of the string.") public class Length extends UDF { - public Long eval(BinaryString s) { - if (s == null) { - return null; - } - return (long) s.getLength(); + public Long eval(BinaryString s) { + if (s == null) { + return null; } + return (long) s.getLength(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Like.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Like.java index 826301572..6c0e090e9 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Like.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Like.java @@ -21,6 +21,7 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; @@ -28,157 +29,156 @@ * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. */ -/** - * This class is an adaptation of Hive's org.apache.hadoop.hive.ql.udf.UDFLike. - */ +/** This class is an adaptation of Hive's org.apache.hadoop.hive.ql.udf.UDFLike. */ @Description(name = "like", description = "Returns whether string matches to the pattern.") public class Like extends UDF { - private PatternType type = PatternType.NONE; - private String simplePattern; - private String lastLikePattern; - private Pattern p = null; + private PatternType type = PatternType.NONE; + private String simplePattern; + private String lastLikePattern; + private Pattern p = null; - public static String likePatternToRegExp(String likePattern) { - StringBuilder sb = new StringBuilder(); - for (int i = 0; i < likePattern.length(); i++) { - char n = likePattern.charAt(i); - if (n == '\\' && i + 1 < likePattern.length() && (likePattern.charAt(i + 1) == '_' - || likePattern.charAt(i + 1) == '%')) { - sb.append(likePattern.charAt(i + 1)); - i++; - continue; - } - if (n == '_') { - sb.append("."); - } else if (n == '%') { - sb.append(".*"); - } else { - sb.append(Pattern.quote(Character.toString(n))); - } - } - return sb.toString(); + public static String likePatternToRegExp(String likePattern) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < likePattern.length(); i++) { + char n = likePattern.charAt(i); + if (n == '\\' + && i + 1 < likePattern.length() + && (likePattern.charAt(i + 1) == '_' || likePattern.charAt(i + 1) == '%')) { + sb.append(likePattern.charAt(i + 1)); + i++; + continue; + } + if (n == '_') { + sb.append("."); + } else if (n == '%') { + sb.append(".*"); + } else { + sb.append(Pattern.quote(Character.toString(n))); + } } + return sb.toString(); + } - private static boolean find(String s, String sub, int startS, int endS) { - byte[] byteS = s.getBytes(); - byte[] byteSub = sub.getBytes(); - int lenSub = byteSub.length; - boolean match = false; - for (int i = startS; (i < endS - lenSub + 1) && (!match); i++) { - match = true; - for (int j = 0; j < lenSub; j++) { - if (byteS[j + i] != byteSub[j]) { - match = false; - break; - } - } + private static boolean find(String s, String sub, int startS, int endS) { + byte[] byteS = s.getBytes(); + byte[] byteSub = sub.getBytes(); + int lenSub = byteSub.length; + boolean match = false; + for (int i = startS; (i < endS - lenSub + 1) && (!match); i++) { + match = true; + for (int j = 0; j < lenSub; j++) { + if (byteS[j + i] != byteSub[j]) { + match = false; + break; } - return match; + } } + return match; + } - private void parseSimplePattern(String likePattern) { - int length = likePattern.length(); - int beginIndex = 0; - int endIndex = length; - char lastChar = 'a'; - String strPattern = ""; - type = PatternType.NONE; + private void parseSimplePattern(String likePattern) { + int length = likePattern.length(); + int beginIndex = 0; + int endIndex = length; + char lastChar = 'a'; + String strPattern = ""; + type = PatternType.NONE; - for (int i = 0; i < length; i++) { - char n = likePattern.charAt(i); - if (n == '_') { // such as "a_b" - if (lastChar != '\\') { // such as "a%bc" - type = PatternType.COMPLEX; - return; - } else { // such as "abc\%de%" - strPattern += likePattern.substring(beginIndex, i - 1); - beginIndex = i; - } - } else if (n == '%') { - if (i == 0) { // such as "%abc" - type = PatternType.END; - beginIndex = 1; - } else if (i < length - 1) { - if (lastChar != '\\') { // such as "a%bc" - type = PatternType.COMPLEX; - return; - } else { // such as "abc\%de%" - strPattern += likePattern.substring(beginIndex, i - 1); - beginIndex = i; - } - } else { - if (lastChar != '\\') { - endIndex = length - 1; - if (type == PatternType.END) { // such as "%abc%" - type = PatternType.MIDDLE; - } else { - type = PatternType.BEGIN; // such as "abc%" - } - } else { // such as "abc\%" - strPattern += likePattern.substring(beginIndex, i - 1); - beginIndex = i; - endIndex = length; - } - } + for (int i = 0; i < length; i++) { + char n = likePattern.charAt(i); + if (n == '_') { // such as "a_b" + if (lastChar != '\\') { // such as "a%bc" + type = PatternType.COMPLEX; + return; + } else { // such as "abc\%de%" + strPattern += likePattern.substring(beginIndex, i - 1); + beginIndex = i; + } + } else if (n == '%') { + if (i == 0) { // such as "%abc" + type = PatternType.END; + beginIndex = 1; + } else if (i < length - 1) { + if (lastChar != '\\') { // such as "a%bc" + type = PatternType.COMPLEX; + return; + } else { // such as "abc\%de%" + strPattern += likePattern.substring(beginIndex, i - 1); + beginIndex = i; + } + } else { + if (lastChar != '\\') { + endIndex = length - 1; + if (type == PatternType.END) { // such as "%abc%" + type = PatternType.MIDDLE; + } else { + type = PatternType.BEGIN; // such as "abc%" } - lastChar = n; + } else { // such as "abc\%" + strPattern += likePattern.substring(beginIndex, i - 1); + beginIndex = i; + endIndex = length; + } } - - strPattern += likePattern.substring(beginIndex, endIndex); - simplePattern = strPattern; + } + lastChar = n; } - public Boolean eval(String s, String likePattern) { - if (s == null || likePattern == null) { - return null; - } - if (!likePattern.equals(lastLikePattern)) { - lastLikePattern = likePattern; - String strLikePattern = likePattern; + strPattern += likePattern.substring(beginIndex, endIndex); + simplePattern = strPattern; + } - parseSimplePattern(strLikePattern); - if (type == PatternType.COMPLEX) { - p = Pattern.compile(likePatternToRegExp(strLikePattern), Pattern.DOTALL); - } - } + public Boolean eval(String s, String likePattern) { + if (s == null || likePattern == null) { + return null; + } + if (!likePattern.equals(lastLikePattern)) { + lastLikePattern = likePattern; + String strLikePattern = likePattern; - if (type == PatternType.COMPLEX) { - Matcher m = p.matcher(s); - return m.matches(); - } else { - int sLen = s.getBytes().length; - int likeLen = simplePattern.getBytes().length; - int startS = 0; - int endS = sLen; - // if s is shorter than the required pattern - if (endS < likeLen) { - return false; - } - switch (type) { - case BEGIN: - endS = likeLen; - break; - case END: - startS = endS - likeLen; - break; - case NONE: - if (likeLen != sLen) { - return false; - } - break; - default: - break; - } - return find(s, simplePattern, startS, endS); - } + parseSimplePattern(strLikePattern); + if (type == PatternType.COMPLEX) { + p = Pattern.compile(likePatternToRegExp(strLikePattern), Pattern.DOTALL); + } } - private enum PatternType { - NONE, - BEGIN, - END, - MIDDLE, - COMPLEX + if (type == PatternType.COMPLEX) { + Matcher m = p.matcher(s); + return m.matches(); + } else { + int sLen = s.getBytes().length; + int likeLen = simplePattern.getBytes().length; + int startS = 0; + int endS = sLen; + // if s is shorter than the required pattern + if (endS < likeLen) { + return false; + } + switch (type) { + case BEGIN: + endS = likeLen; + break; + case END: + startS = endS - likeLen; + break; + case NONE: + if (likeLen != sLen) { + return false; + } + break; + default: + break; + } + return find(s, simplePattern, startS, endS); } + } + + private enum PatternType { + NONE, + BEGIN, + END, + MIDDLE, + COMPLEX + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/RTrim.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/RTrim.java index eed75b81d..1a679b828 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/RTrim.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/RTrim.java @@ -27,25 +27,25 @@ @Description(name = "rtrim", description = "Returns a string with the right space removed.") public class RTrim extends UDF { - public String eval(String s) { - if (s == null) { - return null; - } - return StringUtils.stripEnd(s, " "); + public String eval(String s) { + if (s == null) { + return null; } + return StringUtils.stripEnd(s, " "); + } - public BinaryString eval(BinaryString s) { - if (s == null) { - return null; - } - int r = s.getLength() - 1; - while (r >= 0) { - if (s.getByte(r) == ' ') { - r--; - } else { - break; - } - } - return s.substring(0, r + 1); + public BinaryString eval(BinaryString s) { + if (s == null) { + return null; } + int r = s.getLength() - 1; + while (r >= 0) { + if (s.getByte(r) == ' ') { + r--; + } else { + break; + } + } + return s.substring(0, r + 1); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/RegExp.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/RegExp.java index bbd3b6b7d..011f11deb 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/RegExp.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/RegExp.java @@ -21,27 +21,30 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; -@Description(name = "regexp", description = "Returns whether the string match the regular expression.") +@Description( + name = "regexp", + description = "Returns whether the string match the regular expression.") public class RegExp extends UDF { - private String lastRegex; - private Pattern p = null; + private String lastRegex; + private Pattern p = null; - public Boolean eval(String s, String regex) { - if (s == null || regex == null) { - return null; - } - if (regex.length() == 0) { - return false; - } - if (!regex.equals(lastRegex) || p == null) { - lastRegex = regex; - p = Pattern.compile(regex); - } - Matcher m = p.matcher(s); - return m.find(0); + public Boolean eval(String s, String regex) { + if (s == null || regex == null) { + return null; + } + if (regex.length() == 0) { + return false; + } + if (!regex.equals(lastRegex) || p == null) { + lastRegex = regex; + p = Pattern.compile(regex); } + Matcher m = p.matcher(s); + return m.find(0); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/RegExpExtract.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/RegExpExtract.java index 931dcf8ec..a8d3a07c8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/RegExpExtract.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/RegExpExtract.java @@ -22,65 +22,68 @@ import java.util.regex.MatchResult; import java.util.regex.Matcher; import java.util.regex.Pattern; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; -@Description(name = "regexp_extract", description = "Extract the substring match the regular expression.") +@Description( + name = "regexp_extract", + description = "Extract the substring match the regular expression.") public class RegExpExtract extends UDF { - private String lastRegex = null; - private Pattern p = null; + private String lastRegex = null; + private Pattern p = null; - public String eval(String s, String regex, String extractIndex) { - return eval(s, regex, Long.valueOf(extractIndex)); - } + public String eval(String s, String regex, String extractIndex) { + return eval(s, regex, Long.valueOf(extractIndex)); + } - public String eval(String s, String regex, Integer extractIndex) { - if (s == null || regex == null || extractIndex == null) { - return null; - } - if (StringUtils.isEmpty(regex)) { - return null; - } - if (extractIndex < 0) { - return null; - } - if (!regex.equals(lastRegex) || p == null) { - lastRegex = regex; - p = Pattern.compile(regex); - } - Matcher m = p.matcher(s); - if (m.find()) { - MatchResult mr = m.toMatchResult(); - return mr.group(extractIndex); - } - return ""; + public String eval(String s, String regex, Integer extractIndex) { + if (s == null || regex == null || extractIndex == null) { + return null; } - - public String eval(Object s, String regex, Long extractIndex) { - return eval(String.valueOf(s), regex, extractIndex); + if (StringUtils.isEmpty(regex)) { + return null; + } + if (extractIndex < 0) { + return null; + } + if (!regex.equals(lastRegex) || p == null) { + lastRegex = regex; + p = Pattern.compile(regex); } + Matcher m = p.matcher(s); + if (m.find()) { + MatchResult mr = m.toMatchResult(); + return mr.group(extractIndex); + } + return ""; + } - public String eval(String s, String regex, Long extractIndex) { - if (s == null || regex == null || extractIndex == null) { - return null; - } + public String eval(Object s, String regex, Long extractIndex) { + return eval(String.valueOf(s), regex, extractIndex); + } - if (StringUtils.isEmpty(regex)) { - return null; - } - if (extractIndex < 0) { - return null; - } - return eval(s, regex, extractIndex.intValue()); + public String eval(String s, String regex, Long extractIndex) { + if (s == null || regex == null || extractIndex == null) { + return null; } - public String eval(String s, String regex) { - return this.eval(s, regex, 1); + if (StringUtils.isEmpty(regex)) { + return null; } - - public String eval(Object s, String regex) { - return this.eval(String.valueOf(s), regex, 1); + if (extractIndex < 0) { + return null; } + return eval(s, regex, extractIndex.intValue()); + } + + public String eval(String s, String regex) { + return this.eval(s, regex, 1); + } + + public String eval(Object s, String regex) { + return this.eval(String.valueOf(s), regex, 1); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/RegExpReplace.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/RegExpReplace.java index 20e495b84..2bcc9ca1d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/RegExpReplace.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/RegExpReplace.java @@ -22,14 +22,15 @@ import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; -@Description(name = "regexp_replace", description = "Replace all substrings of the string that match " - + "the regular expression.") +@Description( + name = "regexp_replace", + description = "Replace all substrings of the string that match " + "the regular expression.") public class RegExpReplace extends UDF { - public String eval(String s, String regex, String replacement) { - if (s == null || regex == null || replacement == null) { - return null; - } - return s.replaceAll(regex, replacement); + public String eval(String s, String regex, String replacement) { + if (s == null || regex == null || replacement == null) { + return null; } + return s.replaceAll(regex, replacement); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/RegexpCount.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/RegexpCount.java index 2ddafb0c4..18be99bfc 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/RegexpCount.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/RegexpCount.java @@ -21,32 +21,35 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; -@Description(name = "regexp_count", description = "Returns the count that the pattern matched in the string.") +@Description( + name = "regexp_count", + description = "Returns the count that the pattern matched in the string.") public class RegexpCount extends UDF { - private String lastPattern; - private Pattern p = null; + private String lastPattern; + private Pattern p = null; - public Long eval(String source, String pattern, Long startPos) { - if (source == null || pattern == null || startPos == null) { - return null; - } - if (lastPattern == null || !lastPattern.equals(pattern)) { - p = Pattern.compile(pattern); - lastPattern = pattern; - } - Matcher matcher = p.matcher(source.substring(startPos.intValue())); - long c = 0; - while (matcher.find()) { - c++; - } - return c; + public Long eval(String source, String pattern, Long startPos) { + if (source == null || pattern == null || startPos == null) { + return null; } - - public Long eval(String source, String pattern) { - return eval(source, pattern, 0L); + if (lastPattern == null || !lastPattern.equals(pattern)) { + p = Pattern.compile(pattern); + lastPattern = pattern; } + Matcher matcher = p.matcher(source.substring(startPos.intValue())); + long c = 0; + while (matcher.find()) { + c++; + } + return c; + } + + public Long eval(String source, String pattern) { + return eval(source, pattern, 0L); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Repeat.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Repeat.java index 507a54044..1d4edc413 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Repeat.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Repeat.java @@ -25,14 +25,14 @@ @Description(name = "repeat", description = "Repeats string n times.") public class Repeat extends UDF { - public String eval(String s, Integer n) { - if (n == null || s == null) { - return null; - } - StringBuilder sb = new StringBuilder(); - for (int i = 0; i < n; ++i) { - sb.append(s); - } - return sb.toString(); + public String eval(String s, Integer n) { + if (n == null || s == null) { + return null; } + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < n; ++i) { + sb.append(s); + } + return sb.toString(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Replace.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Replace.java index c2b9e3dbe..6d2936d0d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Replace.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Replace.java @@ -23,16 +23,17 @@ import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; -@Description(name = "replace", description = "Removes each substring of the source String that matches" - + " the regular expression.") +@Description( + name = "replace", + description = + "Removes each substring of the source String that matches" + " the regular expression.") public class Replace extends UDF { - public String eval(String text, String searchString, String replacement) { - return StringUtils.replace(text, searchString, replacement); - } - - public String eval(Object text, String searchString, String replacement) { - return StringUtils.replace(String.valueOf(text), searchString, replacement); - } + public String eval(String text, String searchString, String replacement) { + return StringUtils.replace(text, searchString, replacement); + } + public String eval(Object text, String searchString, String replacement) { + return StringUtils.replace(String.valueOf(text), searchString, replacement); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Reverse.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Reverse.java index a7769f5eb..151e9b3a9 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Reverse.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Reverse.java @@ -27,15 +27,15 @@ @Description(name = "reverse", description = "Returns the reversed string.") public class Reverse extends UDF { - public String eval(String s) { - return StringUtils.reverse(s); - } + public String eval(String s) { + return StringUtils.reverse(s); + } - public BinaryString eval(BinaryString s) { - if (s == null) { - return null; - } - String reverse = StringUtils.reverse(s.toString()); - return BinaryString.fromString(reverse); + public BinaryString eval(BinaryString s) { + if (s == null) { + return null; } + String reverse = StringUtils.reverse(s.toString()); + return BinaryString.fromString(reverse); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Space.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Space.java index 75e716ad1..36c65dc97 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Space.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Space.java @@ -25,14 +25,14 @@ @Description(name = "space", description = "Returns a string of n spaces.") public class Space extends UDF { - public String eval(Long n) { - if (n == null) { - return null; - } - StringBuilder sb = new StringBuilder(); - for (int i = 0; i < n; i++) { - sb.append(" "); - } - return sb.toString(); + public String eval(Long n) { + if (n == null) { + return null; } + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < n; i++) { + sb.append(" "); + } + return sb.toString(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/SplitEx.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/SplitEx.java index 7b9b3cda4..0114cd74b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/SplitEx.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/SplitEx.java @@ -27,31 +27,35 @@ @Description(name = "split_ex", description = "Split string by separator and returns nth substring") public class SplitEx extends UDF { - public String eval(String str, String separator, Integer index) { - if ((str == null) || separator == null || index == null) { - return null; - } - if (index < 0) { - return null; - } - String[] values = StringUtils.splitByWholeSeparatorPreserveAllTokens(str, separator); - if (index >= values.length) { - return null; - } - return values[index]; + public String eval(String str, String separator, Integer index) { + if ((str == null) || separator == null || index == null) { + return null; } + if (index < 0) { + return null; + } + String[] values = StringUtils.splitByWholeSeparatorPreserveAllTokens(str, separator); + if (index >= values.length) { + return null; + } + return values[index]; + } - public BinaryString eval(BinaryString str, BinaryString separator, Integer index) { - if ((str == null) || separator == null || separator.equals(BinaryString.EMPTY_STRING) || index == null) { - return null; - } - if (index < 0) { - return null; - } - String[] values = StringUtils.splitByWholeSeparatorPreserveAllTokens(str.toString(), separator.toString()); - if (index >= values.length) { - return null; - } - return BinaryString.fromString(values[index]); + public BinaryString eval(BinaryString str, BinaryString separator, Integer index) { + if ((str == null) + || separator == null + || separator.equals(BinaryString.EMPTY_STRING) + || index == null) { + return null; + } + if (index < 0) { + return null; + } + String[] values = + StringUtils.splitByWholeSeparatorPreserveAllTokens(str.toString(), separator.toString()); + if (index >= values.length) { + return null; } + return BinaryString.fromString(values[index]); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Substr.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Substr.java index 4ffe3ef8a..c7c734c95 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Substr.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/Substr.java @@ -23,64 +23,67 @@ import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; -@Description(name = "substr", description = "Returns the substring for start position with specified length" - + ". The position start from 1.") +@Description( + name = "substr", + description = + "Returns the substring for start position with specified length" + + ". The position start from 1.") public class Substr extends UDF { - public String eval(String str, Integer pos, Integer length) { - if (str == null || pos == null || length == null) { - return null; - } - if ((Math.abs(pos) > str.length())) { - return str; - } - - int start; - int end; - if (pos > 0) { - start = pos - 1; - } else if (pos < 0) { - start = str.length() + pos; - } else { - start = 0; - } - if (length == -1) { - end = str.length(); - } else { - end = Math.min(start + length, str.length()); - } - return str.substring(start, end); + public String eval(String str, Integer pos, Integer length) { + if (str == null || pos == null || length == null) { + return null; } - - public String eval(String str, Integer start) { - return eval(str, start, -1); + if ((Math.abs(pos) > str.length())) { + return str; } - public BinaryString eval(BinaryString str, Integer start) { - return eval(str, start, -1); + int start; + int end; + if (pos > 0) { + start = pos - 1; + } else if (pos < 0) { + start = str.length() + pos; + } else { + start = 0; } + if (length == -1) { + end = str.length(); + } else { + end = Math.min(start + length, str.length()); + } + return str.substring(start, end); + } + + public String eval(String str, Integer start) { + return eval(str, start, -1); + } - public BinaryString eval(BinaryString str, Integer pos, Integer length) { - if (str == null || pos == null || length == null) { - return null; - } - if (Math.abs(pos) > str.getLength()) { - return null; - } - int start; - int end; - if (pos > 0) { - start = pos - 1; - } else if (pos < 0) { - start = str.getLength() + pos; - } else { - start = 0; - } - if (length == -1) { - end = str.getLength(); - } else { - end = Math.min(start + length, str.getLength()); - } - return str.substring(start, end); + public BinaryString eval(BinaryString str, Integer start) { + return eval(str, start, -1); + } + + public BinaryString eval(BinaryString str, Integer pos, Integer length) { + if (str == null || pos == null || length == null) { + return null; + } + if (Math.abs(pos) > str.getLength()) { + return null; + } + int start; + int end; + if (pos > 0) { + start = pos - 1; + } else if (pos < 0) { + start = str.getLength() + pos; + } else { + start = 0; + } + if (length == -1) { + end = str.getLength(); + } else { + end = Math.min(start + length, str.getLength()); } + return str.substring(start, end); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/UrlDecode.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/UrlDecode.java index 4e121ebca..10f8094ae 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/UrlDecode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/UrlDecode.java @@ -21,20 +21,21 @@ import java.io.UnsupportedEncodingException; import java.net.URLDecoder; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; @Description(name = "urldecode", description = "Decode the URL.") public class UrlDecode extends UDF { - public String eval(String url) { - if (url == null) { - return null; - } - try { - return URLDecoder.decode(url, "UTF-8"); - } catch (UnsupportedEncodingException e) { - throw new RuntimeException("error in decode url:" + url, e); - } + public String eval(String url) { + if (url == null) { + return null; + } + try { + return URLDecoder.decode(url, "UTF-8"); + } catch (UnsupportedEncodingException e) { + throw new RuntimeException("error in decode url:" + url, e); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/UrlEncode.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/UrlEncode.java index 6b15e05ae..9e139e44b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/UrlEncode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/string/UrlEncode.java @@ -21,20 +21,21 @@ import java.io.UnsupportedEncodingException; import java.net.URLEncoder; + import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDF; @Description(name = "urlencode", description = "Encode the URL.") public class UrlEncode extends UDF { - public String eval(String url) { - if (url == null) { - return null; - } - try { - return URLEncoder.encode(url, "UTF-8"); - } catch (UnsupportedEncodingException e) { - throw new RuntimeException("error encode url:" + url, e); - } + public String eval(String url) { + if (url == null) { + return null; + } + try { + return URLEncoder.encode(url, "UTF-8"); + } catch (UnsupportedEncodingException e) { + throw new RuntimeException("error encode url:" + url, e); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/udtf/Split.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/udtf/Split.java index 64015ef0f..4de974088 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/udtf/Split.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/table/udtf/Split.java @@ -19,34 +19,36 @@ package org.apache.geaflow.dsl.udf.table.udtf; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.commons.lang.StringUtils; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDTF; +import com.google.common.collect.Lists; + @Description(name = "split", description = "Split string and expand it.") public class Split extends UDTF { - private static final String DEFAULT_SEP = ","; + private static final String DEFAULT_SEP = ","; - public void eval(String str) { - eval(str, DEFAULT_SEP); - } + public void eval(String str) { + eval(str, DEFAULT_SEP); + } - public void eval(String str, String separator) { - String[] lines = StringUtils.split(str, separator); - for (String line : lines) { - collect(new Object[]{line}); - } + public void eval(String str, String separator) { + String[] lines = StringUtils.split(str, separator); + for (String line : lines) { + collect(new Object[] {line}); } + } - @Override - public List> getReturnType(List> paramTypes, List udtfReturnFields) { - List> returnTypes = Lists.newArrayList(); - for (int i = 0; i < udtfReturnFields.size(); i++) { - returnTypes.add(String.class); - } - return returnTypes; + @Override + public List> getReturnType(List> paramTypes, List udtfReturnFields) { + List> returnTypes = Lists.newArrayList(); + for (int i = 0; i < udtfReturnFields.size(); i++) { + returnTypes.add(String.class); } + return returnTypes; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/ExpressionUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/ExpressionUtil.java index 8019c92c9..37f00478b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/ExpressionUtil.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/ExpressionUtil.java @@ -19,165 +19,171 @@ package org.apache.geaflow.dsl.util; -import com.google.common.base.Joiner; import java.util.ArrayList; import java.util.List; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.*; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.NlsString; +import com.google.common.base.Joiner; + public class ExpressionUtil { - public static String showExpression(RexNode node, - List calcExps, - RelDataType inputRowType) { - return node.accept(new RexNodeExpressionVisitor(calcExps, inputRowType)); - } + public static String showExpression( + RexNode node, List calcExps, RelDataType inputRowType) { + return node.accept(new RexNodeExpressionVisitor(calcExps, inputRowType)); + } - static class RexNodeExpressionVisitor implements RexVisitor { + static class RexNodeExpressionVisitor implements RexVisitor { - private final List calcExps; - private final RelDataType inputRowType; + private final List calcExps; + private final RelDataType inputRowType; - public RexNodeExpressionVisitor(List localExps, - RelDataType inputRowType) { - this.calcExps = localExps; - this.inputRowType = inputRowType; - } + public RexNodeExpressionVisitor(List localExps, RelDataType inputRowType) { + this.calcExps = localExps; + this.inputRowType = inputRowType; + } - @Override - public String visitInputRef(RexInputRef inputRef) { - int index = inputRef.getIndex(); + @Override + public String visitInputRef(RexInputRef inputRef) { + int index = inputRef.getIndex(); - return inputRowType.getFieldList().get(index).getName(); - } + return inputRowType.getFieldList().get(index).getName(); + } - @Override - public String visitLocalRef(RexLocalRef localRef) { - int index = localRef.getIndex(); - RexNode node = calcExps.get(index); - return node.accept(this); - } + @Override + public String visitLocalRef(RexLocalRef localRef) { + int index = localRef.getIndex(); + RexNode node = calcExps.get(index); + return node.accept(this); + } - @Override - public String visitLiteral(RexLiteral literal) { - if (literal.getType().getSqlTypeName() == SqlTypeName.CHAR - || literal.getType().getSqlTypeName() == SqlTypeName.VARCHAR) { - NlsString nlsString = (NlsString) literal.getValue(); - if (nlsString != null) { - return StringLiteralUtil.escapeSQLString(String.valueOf(nlsString.getValue())); - } else { - return "null"; - } - } - return String.valueOf(literal.getValue()); + @Override + public String visitLiteral(RexLiteral literal) { + if (literal.getType().getSqlTypeName() == SqlTypeName.CHAR + || literal.getType().getSqlTypeName() == SqlTypeName.VARCHAR) { + NlsString nlsString = (NlsString) literal.getValue(); + if (nlsString != null) { + return StringLiteralUtil.escapeSQLString(String.valueOf(nlsString.getValue())); + } else { + return "null"; } + } + return String.valueOf(literal.getValue()); + } - @Override - public String visitCall(RexCall call) { - - String operatorName = call.getOperator().getName(); - List operands = call.getOperands(); - - List operandExpressions = new ArrayList<>(); - for (RexNode operand : operands) { - operandExpressions.add(operand.accept(this)); - } - StringBuilder sb = new StringBuilder(); - - switch (operatorName) { - case "+": - case "-": - case "=": - case "<>": - case "%": - case "*": - case "/": - case "<": - case ">": - case "<=": - case ">=": - case "OR": - case "AND": - case "||": - - sb.append("("); - - for (int i = 0; i < operandExpressions.size(); i++) { - if (i > 0) { - sb.append(" ").append(operatorName).append(" "); - } - sb.append(operandExpressions.get(i)); - } - sb.append(")"); - return sb.toString(); - - case "CASE": - sb.append("CASE "); - for (int i = 0; i < operandExpressions.size() - 1; i += 2) { - sb.append("WHEN ").append(operandExpressions.get(i)) - .append(" THEN ").append(operandExpressions.get(i + 1)); - } - sb.append(" ELSE ") - .append(operandExpressions.get(operandExpressions.size() - 1)) - .append(" END"); - return sb.toString(); - - case "CAST": - RelDataType targetType = call.getType(); - return operatorName + "(" + operandExpressions.get(0) + " as " + targetType.getSqlTypeName() + ")"; - case "ITEM": - return operandExpressions.get(0) + "[" + operandExpressions.get(1) + "]"; - default: - return operatorName + "(" + Joiner.on(',').join(operandExpressions) + ")"; + @Override + public String visitCall(RexCall call) { + + String operatorName = call.getOperator().getName(); + List operands = call.getOperands(); + + List operandExpressions = new ArrayList<>(); + for (RexNode operand : operands) { + operandExpressions.add(operand.accept(this)); + } + StringBuilder sb = new StringBuilder(); + + switch (operatorName) { + case "+": + case "-": + case "=": + case "<>": + case "%": + case "*": + case "/": + case "<": + case ">": + case "<=": + case ">=": + case "OR": + case "AND": + case "||": + sb.append("("); + + for (int i = 0; i < operandExpressions.size(); i++) { + if (i > 0) { + sb.append(" ").append(operatorName).append(" "); } - } + sb.append(operandExpressions.get(i)); + } + sb.append(")"); + return sb.toString(); + + case "CASE": + sb.append("CASE "); + for (int i = 0; i < operandExpressions.size() - 1; i += 2) { + sb.append("WHEN ") + .append(operandExpressions.get(i)) + .append(" THEN ") + .append(operandExpressions.get(i + 1)); + } + sb.append(" ELSE ") + .append(operandExpressions.get(operandExpressions.size() - 1)) + .append(" END"); + return sb.toString(); + + case "CAST": + RelDataType targetType = call.getType(); + return operatorName + + "(" + + operandExpressions.get(0) + + " as " + + targetType.getSqlTypeName() + + ")"; + case "ITEM": + return operandExpressions.get(0) + "[" + operandExpressions.get(1) + "]"; + default: + return operatorName + "(" + Joiner.on(',').join(operandExpressions) + ")"; + } + } - @Override - public String visitOver(RexOver over) { - return over.toString(); - } + @Override + public String visitOver(RexOver over) { + return over.toString(); + } - @Override - public String visitCorrelVariable(RexCorrelVariable correlVariable) { - return correlVariable.toString(); - } + @Override + public String visitCorrelVariable(RexCorrelVariable correlVariable) { + return correlVariable.toString(); + } - @Override - public String visitDynamicParam(RexDynamicParam dynamicParam) { - return dynamicParam.toString(); - } + @Override + public String visitDynamicParam(RexDynamicParam dynamicParam) { + return dynamicParam.toString(); + } - @Override - public String visitRangeRef(RexRangeRef rangeRef) { - return rangeRef.toString(); - } + @Override + public String visitRangeRef(RexRangeRef rangeRef) { + return rangeRef.toString(); + } - @Override - public String visitFieldAccess(RexFieldAccess fieldAccess) { - String refExpr = fieldAccess.getReferenceExpr().accept(this); - return refExpr + "." + fieldAccess.getField().getName(); - } + @Override + public String visitFieldAccess(RexFieldAccess fieldAccess) { + String refExpr = fieldAccess.getReferenceExpr().accept(this); + return refExpr + "." + fieldAccess.getField().getName(); + } - @Override - public String visitSubQuery(RexSubQuery subQuery) { - return subQuery.toString(); - } + @Override + public String visitSubQuery(RexSubQuery subQuery) { + return subQuery.toString(); + } - @Override - public String visitTableInputRef(RexTableInputRef fieldRef) { - return fieldRef.toString(); - } + @Override + public String visitTableInputRef(RexTableInputRef fieldRef) { + return fieldRef.toString(); + } - @Override - public String visitPatternFieldRef(RexPatternFieldRef fieldRef) { - return fieldRef.toString(); - } + @Override + public String visitPatternFieldRef(RexPatternFieldRef fieldRef) { + return fieldRef.toString(); + } - @Override - public String visitOther(RexNode other) { - return other.toString(); - } + @Override + public String visitOther(RexNode other) { + return other.toString(); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/FunctionUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/FunctionUtil.java index be56db89c..66ee2e1b4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/FunctionUtil.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/FunctionUtil.java @@ -22,6 +22,7 @@ import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; + import org.apache.calcite.jdbc.JavaTypeFactoryImpl; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlCallBinding; @@ -50,149 +51,145 @@ public class FunctionUtil { - @SuppressWarnings("unchecked") - public static SqlFunction createSqlFunction(GeaFlowFunction function, GQLJavaTypeFactory typeFactory) { - String name = function.getName(); - List> functionClazzs = new ArrayList<>(); - try { - for (String className : function.getClazz()) { - Class reflectClass = Thread.currentThread().getContextClassLoader().loadClass(className); - functionClazzs.add(reflectClass); - } - } catch (Exception e) { - throw new GeaFlowDSLException(e); - } - - FunctionType type; - if (UDF.class.isAssignableFrom(functionClazzs.get(0))) { - type = FunctionType.UDF; - } else if (UDAF.class.isAssignableFrom(functionClazzs.get(0))) { - type = FunctionType.UDAF; - } else if (UDTF.class.isAssignableFrom(functionClazzs.get(0))) { - type = FunctionType.UDTF; - } else if (AlgorithmUserFunction.class.isAssignableFrom(functionClazzs.get(0))) { - type = FunctionType.UDGA; - } else { - throw new GeaFlowDSLException("UnKnow function type of function " + functionClazzs); - } - switch (type) { - case UDF: - return GeaFlowUserDefinedScalarFunction.create(name, - (Class) functionClazzs.get(0), typeFactory); - case UDGA: - return GeaFlowUserDefinedGraphAlgorithm.create(name, - (Class) functionClazzs.get(0), typeFactory); - case UDTF: - return GeaFlowUserDefinedTableFunction.create(name, - (Class) functionClazzs.get(0), typeFactory); - case UDAF: - return GeaFlowUserDefinedAggFunction.create(name, - ArrayUtil.castList(functionClazzs), typeFactory); - default: - throw new GeaFlowDSLException("should never run here"); - } - + @SuppressWarnings("unchecked") + public static SqlFunction createSqlFunction( + GeaFlowFunction function, GQLJavaTypeFactory typeFactory) { + String name = function.getName(); + List> functionClazzs = new ArrayList<>(); + try { + for (String className : function.getClazz()) { + Class reflectClass = Thread.currentThread().getContextClassLoader().loadClass(className); + functionClazzs.add(reflectClass); + } + } catch (Exception e) { + throw new GeaFlowDSLException(e); } - - public static SqlOperandTypeChecker getSqlOperandTypeChecker(String name, Class udfClass, - GQLJavaTypeFactory typeFactory) { - final List types = FunctionCallUtils.getAllEvalParamTypes(udfClass); - return new SqlOperandTypeChecker() { - @Override - public boolean checkOperandTypes(SqlCallBinding callBinding, - boolean throwOnFailure) { - List> callParamTypes = - SqlTypeUtil.convertToJavaTypes(callBinding.collectOperandTypes(), typeFactory); - FunctionCallUtils.findMatchMethod(udfClass, callParamTypes); - return true; - } - - @Override - public SqlOperandCountRange getOperandCountRange() { - int min = 255; - int max = -1; - - for (Class[] ts : types) { - int paramLength = ts.length; - if (paramLength > 0 && ts[ts.length - 1].isArray()) { - max = 254; - paramLength = ts.length - 1; - } - max = Math.max(paramLength, max); - min = Math.min(paramLength, min); - } - return SqlOperandCountRanges.between(min, max); - } - - @Override - public String getAllowedSignatures(SqlOperator op, String opName) { - return opName + types; - } - - @Override - public Consistency getConsistency() { - return Consistency.NONE; - } - - @Override - public boolean isOptional(int i) { - return false; - } - }; + FunctionType type; + if (UDF.class.isAssignableFrom(functionClazzs.get(0))) { + type = FunctionType.UDF; + } else if (UDAF.class.isAssignableFrom(functionClazzs.get(0))) { + type = FunctionType.UDAF; + } else if (UDTF.class.isAssignableFrom(functionClazzs.get(0))) { + type = FunctionType.UDTF; + } else if (AlgorithmUserFunction.class.isAssignableFrom(functionClazzs.get(0))) { + type = FunctionType.UDGA; + } else { + throw new GeaFlowDSLException("UnKnow function type of function " + functionClazzs); } - - public static SqlOperandTypeInference getSqlOperandTypeInference(Class udfClass, - GQLJavaTypeFactory typeFactory) { - return (callBinding, returnType, operandTypes) -> { - List> callParamJavaTypes = - SqlTypeUtil.convertToJavaTypes(callBinding.collectOperandTypes(), typeFactory); - Method method = FunctionCallUtils.findMatchMethod(udfClass, callParamJavaTypes); - - Class[] realTypes = method.getParameterTypes(); - RelDataType[] realParamTypes = new RelDataType[realTypes.length]; - - for (int i = 0; i < realTypes.length; i++) { - realParamTypes[i] = typeFactory.createType(realTypes[i]); - } - - int varIndex = -1; - for (int i = 0; i < operandTypes.length; i++) { - if (i < realParamTypes.length - && realParamTypes[i].getComponentType() != null) { - varIndex = i; - } - if (varIndex >= 0) { - operandTypes[i] = realParamTypes[varIndex].getComponentType(); - } else { - operandTypes[i] = realParamTypes[i]; - } - } - }; + switch (type) { + case UDF: + return GeaFlowUserDefinedScalarFunction.create( + name, (Class) functionClazzs.get(0), typeFactory); + case UDGA: + return GeaFlowUserDefinedGraphAlgorithm.create( + name, (Class) functionClazzs.get(0), typeFactory); + case UDTF: + return GeaFlowUserDefinedTableFunction.create( + name, (Class) functionClazzs.get(0), typeFactory); + case UDAF: + return GeaFlowUserDefinedAggFunction.create( + name, ArrayUtil.castList(functionClazzs), typeFactory); + default: + throw new GeaFlowDSLException("should never run here"); } - - public static SqlReturnTypeInference getSqlReturnTypeInference(Class clazz, String functionName) { - return opBinding -> { - final JavaTypeFactoryImpl typeFactory = (JavaTypeFactoryImpl) opBinding.getTypeFactory(); - List> paramJavaTypes = - SqlTypeUtil.convertToJavaTypes(opBinding.collectOperandTypes(), typeFactory); - - Method method = FunctionCallUtils.findMatchMethod(clazz, functionName, paramJavaTypes); - if (GraphMetaFieldAccessFunction.class.isAssignableFrom(clazz)) { - Class returnClazz = method.getReturnType(); - if (returnClazz.equals(Object.class)) { - try { - GraphMetaFieldAccessFunction func = - ((GraphMetaFieldAccessFunction) clazz.newInstance()); - return func.getReturnRelDataType((GQLJavaTypeFactory) typeFactory); - } catch (Exception e) { - throw new GeaFlowDSLException(e, - "Cannot get instance of {}", clazz.getName()); - } - } - } - return typeFactory.createType(method.getReturnType()); - }; - } - + } + + public static SqlOperandTypeChecker getSqlOperandTypeChecker( + String name, Class udfClass, GQLJavaTypeFactory typeFactory) { + final List types = FunctionCallUtils.getAllEvalParamTypes(udfClass); + return new SqlOperandTypeChecker() { + @Override + public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { + List> callParamTypes = + SqlTypeUtil.convertToJavaTypes(callBinding.collectOperandTypes(), typeFactory); + FunctionCallUtils.findMatchMethod(udfClass, callParamTypes); + return true; + } + + @Override + public SqlOperandCountRange getOperandCountRange() { + int min = 255; + int max = -1; + + for (Class[] ts : types) { + int paramLength = ts.length; + if (paramLength > 0 && ts[ts.length - 1].isArray()) { + max = 254; + paramLength = ts.length - 1; + } + max = Math.max(paramLength, max); + min = Math.min(paramLength, min); + } + return SqlOperandCountRanges.between(min, max); + } + + @Override + public String getAllowedSignatures(SqlOperator op, String opName) { + return opName + types; + } + + @Override + public Consistency getConsistency() { + return Consistency.NONE; + } + + @Override + public boolean isOptional(int i) { + return false; + } + }; + } + + public static SqlOperandTypeInference getSqlOperandTypeInference( + Class udfClass, GQLJavaTypeFactory typeFactory) { + return (callBinding, returnType, operandTypes) -> { + List> callParamJavaTypes = + SqlTypeUtil.convertToJavaTypes(callBinding.collectOperandTypes(), typeFactory); + Method method = FunctionCallUtils.findMatchMethod(udfClass, callParamJavaTypes); + + Class[] realTypes = method.getParameterTypes(); + RelDataType[] realParamTypes = new RelDataType[realTypes.length]; + + for (int i = 0; i < realTypes.length; i++) { + realParamTypes[i] = typeFactory.createType(realTypes[i]); + } + + int varIndex = -1; + for (int i = 0; i < operandTypes.length; i++) { + if (i < realParamTypes.length && realParamTypes[i].getComponentType() != null) { + varIndex = i; + } + if (varIndex >= 0) { + operandTypes[i] = realParamTypes[varIndex].getComponentType(); + } else { + operandTypes[i] = realParamTypes[i]; + } + } + }; + } + + public static SqlReturnTypeInference getSqlReturnTypeInference( + Class clazz, String functionName) { + return opBinding -> { + final JavaTypeFactoryImpl typeFactory = (JavaTypeFactoryImpl) opBinding.getTypeFactory(); + List> paramJavaTypes = + SqlTypeUtil.convertToJavaTypes(opBinding.collectOperandTypes(), typeFactory); + + Method method = FunctionCallUtils.findMatchMethod(clazz, functionName, paramJavaTypes); + if (GraphMetaFieldAccessFunction.class.isAssignableFrom(clazz)) { + Class returnClazz = method.getReturnType(); + if (returnClazz.equals(Object.class)) { + try { + GraphMetaFieldAccessFunction func = + ((GraphMetaFieldAccessFunction) clazz.newInstance()); + return func.getReturnRelDataType((GQLJavaTypeFactory) typeFactory); + } catch (Exception e) { + throw new GeaFlowDSLException(e, "Cannot get instance of {}", clazz.getName()); + } + } + } + return typeFactory.createType(method.getReturnType()); + }; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/GQLNodeUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/GQLNodeUtil.java index 0ebd839da..03f4fd87f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/GQLNodeUtil.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/GQLNodeUtil.java @@ -26,6 +26,7 @@ import java.util.Objects; import java.util.function.Predicate; import java.util.stream.Collectors; + import org.apache.calcite.sql.*; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParserPos; @@ -33,112 +34,111 @@ public class GQLNodeUtil { - public static List collect(SqlNode node, Predicate predicate) { - return node.accept(new SqlVisitor>() { - @Override - public List visit(SqlLiteral literal) { - if (predicate.test(literal)) { - return (List) Collections.singletonList(literal); - } - return Collections.emptyList(); + public static List collect(SqlNode node, Predicate predicate) { + return node.accept( + new SqlVisitor>() { + @Override + public List visit(SqlLiteral literal) { + if (predicate.test(literal)) { + return (List) Collections.singletonList(literal); } + return Collections.emptyList(); + } - @Override - public List visit(SqlCall call) { - List childResults = call.getOperandList() - .stream() + @Override + public List visit(SqlCall call) { + List childResults = + call.getOperandList().stream() .filter(Objects::nonNull) .flatMap(operand -> operand.accept(this).stream()) .collect(Collectors.toList()); - List results = new ArrayList<>(); - results.addAll(childResults); + List results = new ArrayList<>(); + results.addAll(childResults); - if (predicate.test(call)) { - results.add((T) call); - } - return results; + if (predicate.test(call)) { + results.add((T) call); } + return results; + } - @Override - public List visit(SqlNodeList nodeList) { - List childResults = nodeList.getList() - .stream() + @Override + public List visit(SqlNodeList nodeList) { + List childResults = + nodeList.getList().stream() .filter(Objects::nonNull) .flatMap(operand -> operand.accept(this).stream()) .collect(Collectors.toList()); - List results = new ArrayList<>(); - results.addAll(childResults); + List results = new ArrayList<>(); + results.addAll(childResults); - if (predicate.test(nodeList)) { - results.add((T) nodeList); - } - return results; + if (predicate.test(nodeList)) { + results.add((T) nodeList); } + return results; + } - @Override - public List visit(SqlIdentifier id) { - if (predicate.test(id)) { - return (List) Collections.singletonList(id); - } - return Collections.emptyList(); + @Override + public List visit(SqlIdentifier id) { + if (predicate.test(id)) { + return (List) Collections.singletonList(id); } + return Collections.emptyList(); + } - @Override - public List visit(SqlDataTypeSpec type) { - if (predicate.test(type)) { - return (List) Collections.singletonList(type); - } - return Collections.emptyList(); + @Override + public List visit(SqlDataTypeSpec type) { + if (predicate.test(type)) { + return (List) Collections.singletonList(type); } + return Collections.emptyList(); + } - @Override - public List visit(SqlDynamicParam param) { - if (predicate.test(param)) { - return (List) Collections.singletonList(param); - } - return Collections.emptyList(); + @Override + public List visit(SqlDynamicParam param) { + if (predicate.test(param)) { + return (List) Collections.singletonList(param); } + return Collections.emptyList(); + } - @Override - public List visit(SqlIntervalQualifier intervalQualifier) { - if (predicate.test(intervalQualifier)) { - return (List) Collections.singletonList(intervalQualifier); - } - return Collections.emptyList(); + @Override + public List visit(SqlIntervalQualifier intervalQualifier) { + if (predicate.test(intervalQualifier)) { + return (List) Collections.singletonList(intervalQualifier); } + return Collections.emptyList(); + } }); - } + } - public static boolean containMatch(SqlNode node) { - return !collect(node, n -> n.getKind() == SqlKind.GQL_MATCH_PATTERN).isEmpty(); - } + public static boolean containMatch(SqlNode node) { + return !collect(node, n -> n.getKind() == SqlKind.GQL_MATCH_PATTERN).isEmpty(); + } - public static SqlNode and(SqlNode... filters) { - List nonNullFilters = - Arrays.stream(filters) - .filter(Objects::nonNull) - .collect(Collectors.toList()); - if (nonNullFilters.size() == 0) { - return SqlLiteral.createBoolean(true, SqlParserPos.ZERO); - } - if (nonNullFilters.size() == 1) { - return nonNullFilters.get(0); - } - SqlNode and = null; - for (SqlNode filter : nonNullFilters) { - if (and == null) { - and = filter; - } else { - and = SqlStdOperatorTable.AND.createCall(SqlParserPos.ZERO, and, filter); - } - } - return and; + public static SqlNode and(SqlNode... filters) { + List nonNullFilters = + Arrays.stream(filters).filter(Objects::nonNull).collect(Collectors.toList()); + if (nonNullFilters.size() == 0) { + return SqlLiteral.createBoolean(true, SqlParserPos.ZERO); } - - public static SqlIdentifier getGraphTableName(SqlIdentifier completeIdentifier) { - assert completeIdentifier.names.size() >= 2; - return completeIdentifier.getComponent(0, 2); + if (nonNullFilters.size() == 1) { + return nonNullFilters.get(0); + } + SqlNode and = null; + for (SqlNode filter : nonNullFilters) { + if (and == null) { + and = filter; + } else { + and = SqlStdOperatorTable.AND.createCall(SqlParserPos.ZERO, and, filter); + } } + return and; + } + + public static SqlIdentifier getGraphTableName(SqlIdentifier completeIdentifier) { + assert completeIdentifier.names.size() >= 2; + return completeIdentifier.getComponent(0, 2); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/GQLRelUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/GQLRelUtil.java index eae87ad0a..38794560a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/GQLRelUtil.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/GQLRelUtil.java @@ -19,8 +19,6 @@ package org.apache.geaflow.dsl.util; -import com.google.common.collect.Lists; -import com.google.common.collect.Sets; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -31,6 +29,7 @@ import java.util.Set; import java.util.function.Predicate; import java.util.stream.Collectors; + import org.apache.calcite.plan.hep.HepRelVertex; import org.apache.calcite.plan.volcano.RelSubset; import org.apache.calcite.rel.RelNode; @@ -54,322 +53,340 @@ import org.apache.geaflow.dsl.rel.match.SubQueryStart; import org.apache.geaflow.dsl.rex.PathInputRef; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + public class GQLRelUtil { - public static IMatchNode match(RelNode node) { - node = toRel(node); - if (node instanceof IMatchNode) { - return (IMatchNode) node; - } - throw new IllegalArgumentException("Node must be IMatchNode"); + public static IMatchNode match(RelNode node) { + node = toRel(node); + if (node instanceof IMatchNode) { + return (IMatchNode) node; } + throw new IllegalArgumentException("Node must be IMatchNode"); + } - public static IMatchNode addSubQueryStartNode(IMatchNode root, SubQueryStart queryStart) { - MatchRelShuffle shuffle = new MatchRelShuffle() { - @Override - protected IMatchNode visitChildren(IMatchNode parent) { - List newInputs = new ArrayList<>(); - if (parent.getInputs().isEmpty()) { - newInputs.add(queryStart); - } else { - for (RelNode input : parent.getInputs()) { - newInputs.add(visit(input)); - } - } - return (IMatchNode) parent.copy(parent.getTraitSet(), newInputs); + public static IMatchNode addSubQueryStartNode(IMatchNode root, SubQueryStart queryStart) { + MatchRelShuffle shuffle = + new MatchRelShuffle() { + @Override + protected IMatchNode visitChildren(IMatchNode parent) { + List newInputs = new ArrayList<>(); + if (parent.getInputs().isEmpty()) { + newInputs.add(queryStart); + } else { + for (RelNode input : parent.getInputs()) { + newInputs.add(visit(input)); + } } + return (IMatchNode) parent.copy(parent.getTraitSet(), newInputs); + } }; - return shuffle.visit(root); - } + return shuffle.visit(root); + } - public static SingleMatchNode replaceInput(SingleMatchNode root, SingleMatchNode replacedNode, - IMatchNode newInputNode) { - if (root == replacedNode) { - return (SingleMatchNode) root.copy(root.getTraitSet(), Collections.singletonList(newInputNode)); - } else { - SingleMatchNode newInput = replaceInput((SingleMatchNode) root.getInput(), replacedNode, newInputNode); - return (SingleMatchNode) root.copy(root.getTraitSet(), Collections.singletonList(newInput)); - } + public static SingleMatchNode replaceInput( + SingleMatchNode root, SingleMatchNode replacedNode, IMatchNode newInputNode) { + if (root == replacedNode) { + return (SingleMatchNode) + root.copy(root.getTraitSet(), Collections.singletonList(newInputNode)); + } else { + SingleMatchNode newInput = + replaceInput((SingleMatchNode) root.getInput(), replacedNode, newInputNode); + return (SingleMatchNode) root.copy(root.getTraitSet(), Collections.singletonList(newInput)); } + } - public static List collect(RelNode root, Predicate predicate) { - List childVisit = root.getInputs() - .stream() + public static List collect(RelNode root, Predicate predicate) { + List childVisit = + root.getInputs().stream() .flatMap(input -> collect(input, predicate).stream()) .collect(Collectors.toList()); - if (root instanceof LoopUntilMatch) { - childVisit.addAll(collect(((LoopUntilMatch) root).getLoopBody(), predicate)); - } - List results = new ArrayList<>(childVisit); - if (predicate.test(root)) { - results.add(root); - } - return results; + if (root instanceof LoopUntilMatch) { + childVisit.addAll(collect(((LoopUntilMatch) root).getLoopBody(), predicate)); } - - public static IMatchLabel getLatestMatchNode(SingleMatchNode pathPattern) { - if (pathPattern == null) { - return null; - } - if (pathPattern instanceof IMatchLabel) { - return (IMatchLabel) pathPattern; - } - if (pathPattern instanceof LoopUntilMatch) { - return getLatestMatchNode(((LoopUntilMatch) pathPattern).getLoopBody()); - } - return getLatestMatchNode((SingleMatchNode) pathPattern.getInput()); + List results = new ArrayList<>(childVisit); + if (predicate.test(root)) { + results.add(root); } + return results; + } - public static IMatchLabel getFirstMatchNode(SingleMatchNode pathPattern) { - if (pathPattern == null) { - return null; - } - if (pathPattern.getInput() == null || pathPattern.getInput() instanceof SubQueryStart) { - assert pathPattern instanceof IMatchLabel : "first node int match must be IMatchLabel."; - return (IMatchLabel) pathPattern; - } - return getFirstMatchNode((SingleMatchNode) pathPattern.getInput()); + public static IMatchLabel getLatestMatchNode(SingleMatchNode pathPattern) { + if (pathPattern == null) { + return null; + } + if (pathPattern instanceof IMatchLabel) { + return (IMatchLabel) pathPattern; + } + if (pathPattern instanceof LoopUntilMatch) { + return getLatestMatchNode(((LoopUntilMatch) pathPattern).getLoopBody()); } + return getLatestMatchNode((SingleMatchNode) pathPattern.getInput()); + } - public static boolean isAllSingleMatch(IMatchNode matchNode) { - return collect(matchNode, node -> !(node instanceof SingleMatchNode)).isEmpty(); + public static IMatchLabel getFirstMatchNode(SingleMatchNode pathPattern) { + if (pathPattern == null) { + return null; } + if (pathPattern.getInput() == null || pathPattern.getInput() instanceof SubQueryStart) { + assert pathPattern instanceof IMatchLabel : "first node int match must be IMatchLabel."; + return (IMatchLabel) pathPattern; + } + return getFirstMatchNode((SingleMatchNode) pathPattern.getInput()); + } - /** - * Concat single path pattern "p0" to the start of path pattern "p1" if they can merge. - * e.g. "a" is "(m) - (n) -(f)", "b" is "(f) - (p) - (q)" - * Then we can concat them to "(m) - (n) - (f) - (p) - (q). - */ - public static SingleMatchNode concatPathPattern(SingleMatchNode p0, SingleMatchNode p1, boolean caseSensitive) { - assert isAllSingleMatch(p0) && isAllSingleMatch(p1); - if (p1.getInput() == null) { - // first node in single path pattern must be IMatchNode. - assert p1 instanceof IMatchLabel; - assert Objects.equals(((IMatchLabel) p1).getLabel(), getLatestMatchNode(p0).getLabel()); - IMatchLabel matchNode = (IMatchLabel) p1; - Set filterTypes = matchNode.getTypes(); - // since p1's label is same with p0, we can remove the duplicate label node p1 and only keep the - // node type filters of p1. - if (filterTypes.isEmpty()) { - return p0; - } - return createNodeTypeFilter(p0, filterTypes); - } - SingleMatchNode concatInput = concatPathPattern(p0, (SingleMatchNode) p1.getInput(), caseSensitive); - PathRecordType concatPathType = concatInput.getPathSchema(); - // generate new path schema. - if (p1 instanceof IMatchLabel) { - concatPathType = concatPathType.addField(((IMatchLabel) p1).getLabel(), p1.getNodeType(), caseSensitive); - } else if (p1 instanceof LoopUntilMatch) { - for (RelDataTypeField field : p1.getPathSchema().getFieldList()) { - if (concatPathType.getField(field.getName(), caseSensitive, false) == null) { - concatPathType = concatPathType.addField(field.getName(), field.getType(), caseSensitive); - } - } - } - // copy with new input and path type. - p1 = (SingleMatchNode) p1.copy(Lists.newArrayList(concatInput), concatPathType); - if (p1 instanceof LoopUntilMatch) { - LoopUntilMatch loop = (LoopUntilMatch) p1; - p1 = LoopUntilMatch.copyWithSubQueryStartPathType(p0.getPathSchema(), loop, caseSensitive); + public static boolean isAllSingleMatch(IMatchNode matchNode) { + return collect(matchNode, node -> !(node instanceof SingleMatchNode)).isEmpty(); + } + + /** + * Concat single path pattern "p0" to the start of path pattern "p1" if they can merge. e.g. "a" + * is "(m) - (n) -(f)", "b" is "(f) - (p) - (q)" Then we can concat them to "(m) - (n) - (f) - (p) + * - (q). + */ + public static SingleMatchNode concatPathPattern( + SingleMatchNode p0, SingleMatchNode p1, boolean caseSensitive) { + assert isAllSingleMatch(p0) && isAllSingleMatch(p1); + if (p1.getInput() == null) { + // first node in single path pattern must be IMatchNode. + assert p1 instanceof IMatchLabel; + assert Objects.equals(((IMatchLabel) p1).getLabel(), getLatestMatchNode(p0).getLabel()); + IMatchLabel matchNode = (IMatchLabel) p1; + Set filterTypes = matchNode.getTypes(); + // since p1's label is same with p0, we can remove the duplicate label node p1 and only keep + // the + // node type filters of p1. + if (filterTypes.isEmpty()) { + return p0; + } + return createNodeTypeFilter(p0, filterTypes); + } + SingleMatchNode concatInput = + concatPathPattern(p0, (SingleMatchNode) p1.getInput(), caseSensitive); + PathRecordType concatPathType = concatInput.getPathSchema(); + // generate new path schema. + if (p1 instanceof IMatchLabel) { + concatPathType = + concatPathType.addField(((IMatchLabel) p1).getLabel(), p1.getNodeType(), caseSensitive); + } else if (p1 instanceof LoopUntilMatch) { + for (RelDataTypeField field : p1.getPathSchema().getFieldList()) { + if (concatPathType.getField(field.getName(), caseSensitive, false) == null) { + concatPathType = concatPathType.addField(field.getName(), field.getType(), caseSensitive); } - // adjust path field ref in match node after concat - return (SingleMatchNode) p1.accept(new RexShuttle() { - @Override - public RexNode visitInputRef(RexInputRef inputRef) { + } + } + // copy with new input and path type. + p1 = (SingleMatchNode) p1.copy(Lists.newArrayList(concatInput), concatPathType); + if (p1 instanceof LoopUntilMatch) { + LoopUntilMatch loop = (LoopUntilMatch) p1; + p1 = LoopUntilMatch.copyWithSubQueryStartPathType(p0.getPathSchema(), loop, caseSensitive); + } + // adjust path field ref in match node after concat + return (SingleMatchNode) + p1.accept( + new RexShuttle() { + @Override + public RexNode visitInputRef(RexInputRef inputRef) { if (inputRef instanceof PathInputRef) { - int index = inputRef.getIndex(); - int newIndex = index + p0.getPathSchema().getFieldCount() - 1; - return ((PathInputRef) inputRef).copy(newIndex); + int index = inputRef.getIndex(); + int newIndex = index + p0.getPathSchema().getFieldCount() - 1; + return ((PathInputRef) inputRef).copy(newIndex); } return inputRef; - } - }); - } - - public static SingleMatchNode createNodeTypeFilter(SingleMatchNode input, Collection nodeTypes) { - RexBuilder rexBuilder = new RexBuilder(GQLJavaTypeFactory.create()); - SqlTypeName typeName = input.getNodeType().getSqlTypeName(); - RexNode nodeTypeRef; - if (typeName == SqlTypeName.VERTEX) { - nodeTypeRef = rexBuilder.makeInputRef(input.getNodeType().getFieldList() - .get(VertexType.LABEL_FIELD_POSITION).getType(), - VertexType.LABEL_FIELD_POSITION); - } else { - assert typeName == SqlTypeName.EDGE; - nodeTypeRef = rexBuilder.makeInputRef(input.getNodeType().getFieldList() - .get(EdgeType.LABEL_FIELD_POSITION).getType(), - EdgeType.LABEL_FIELD_POSITION); - } - assert input.getPathSchema().lastField().isPresent(); - RelDataTypeField pathField = input.getPathSchema().lastField().get(); - nodeTypeRef = GQLRexUtil.toPathInputRefForWhere(pathField, nodeTypeRef); + } + }); + } - RexNode condition = null; - for (String nodeType : nodeTypes) { - RexNode eq = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, nodeTypeRef, - GQLRexUtil.createString(nodeType)); - if (condition == null) { - condition = eq; - } else { - condition = rexBuilder.makeCall(SqlStdOperatorTable.OR, condition, eq); - } - } - if (condition != null) { - return MatchFilter.create(input, condition, input.getPathSchema()); - } - return input; + public static SingleMatchNode createNodeTypeFilter( + SingleMatchNode input, Collection nodeTypes) { + RexBuilder rexBuilder = new RexBuilder(GQLJavaTypeFactory.create()); + SqlTypeName typeName = input.getNodeType().getSqlTypeName(); + RexNode nodeTypeRef; + if (typeName == SqlTypeName.VERTEX) { + nodeTypeRef = + rexBuilder.makeInputRef( + input.getNodeType().getFieldList().get(VertexType.LABEL_FIELD_POSITION).getType(), + VertexType.LABEL_FIELD_POSITION); + } else { + assert typeName == SqlTypeName.EDGE; + nodeTypeRef = + rexBuilder.makeInputRef( + input.getNodeType().getFieldList().get(EdgeType.LABEL_FIELD_POSITION).getType(), + EdgeType.LABEL_FIELD_POSITION); } + assert input.getPathSchema().lastField().isPresent(); + RelDataTypeField pathField = input.getPathSchema().lastField().get(); + nodeTypeRef = GQLRexUtil.toPathInputRefForWhere(pathField, nodeTypeRef); - public static Set getLabels(RelNode node) { - return collect(node, n -> n instanceof IMatchLabel) - .stream().map(matchNode -> ((IMatchLabel) matchNode).getLabel()) - .collect(Collectors.toSet()); + RexNode condition = null; + for (String nodeType : nodeTypes) { + RexNode eq = + rexBuilder.makeCall( + SqlStdOperatorTable.EQUALS, nodeTypeRef, GQLRexUtil.createString(nodeType)); + if (condition == null) { + condition = eq; + } else { + condition = rexBuilder.makeCall(SqlStdOperatorTable.OR, condition, eq); + } } - - public static Set getCommonLabels(RelNode node1, RelNode node2) { - Set nodeLabels1 = getLabels(node1); - Set nodeLabels2 = getLabels(node2); - return Sets.intersection(nodeLabels1, nodeLabels2); + if (condition != null) { + return MatchFilter.create(input, condition, input.getPathSchema()); } + return input; + } - public static RelNode oneInput(List inputs) { - if (inputs.size() == 0) { - return null; - } - assert inputs.size() == 1 : "Node should have one input at most."; - return inputs.get(0); - } + public static Set getLabels(RelNode node) { + return collect(node, n -> n instanceof IMatchLabel).stream() + .map(matchNode -> ((IMatchLabel) matchNode).getLabel()) + .collect(Collectors.toSet()); + } - public static RelNode toRel(RelNode node) { - if (node instanceof RelSubset) { - return toRel(((RelSubset) node).getRelList().get(0)); - } - if (node instanceof HepRelVertex) { - return toRel(((HepRelVertex) node).getCurrentRel()); - } - return node; + public static Set getCommonLabels(RelNode node1, RelNode node2) { + Set nodeLabels1 = getLabels(node1); + Set nodeLabels2 = getLabels(node2); + return Sets.intersection(nodeLabels1, nodeLabels2); + } + + public static RelNode oneInput(List inputs) { + if (inputs.size() == 0) { + return null; } + assert inputs.size() == 1 : "Node should have one input at most."; + return inputs.get(0); + } - public static boolean findTableFunctionScan(RelNode rel) { - rel = toRel(rel); - if (rel instanceof LogicalTableFunctionScan) { - return true; - } - if (rel instanceof LogicalCalc) { - RelNode input = toRel(((LogicalCalc) rel).getInput()); - if (input instanceof LogicalTableFunctionScan) { - return true; - } - } - if (rel instanceof LogicalFilter) { - RelNode input = toRel(((LogicalFilter) rel).getInput()); - return input instanceof LogicalTableFunctionScan; - } - return false; + public static RelNode toRel(RelNode node) { + if (node instanceof RelSubset) { + return toRel(((RelSubset) node).getRelList().get(0)); + } + if (node instanceof HepRelVertex) { + return toRel(((HepRelVertex) node).getCurrentRel()); } + return node; + } - public static boolean isGQLMatchRelNode(RelNode input) { - RelNode node = toRel(input); - return node.getRowType() instanceof PathRecordType; + public static boolean findTableFunctionScan(RelNode rel) { + rel = toRel(rel); + if (rel instanceof LogicalTableFunctionScan) { + return true; + } + if (rel instanceof LogicalCalc) { + RelNode input = toRel(((LogicalCalc) rel).getInput()); + if (input instanceof LogicalTableFunctionScan) { + return true; + } + } + if (rel instanceof LogicalFilter) { + RelNode input = toRel(((LogicalFilter) rel).getInput()); + return input instanceof LogicalTableFunctionScan; } + return false; + } - public static RelNode applyRexShuffleToTree(RelNode node, RexShuttle rexShuttle) { - if (node == null) { - return null; - } - List newInputs = new ArrayList<>(node.getInputs().size()); - for (RelNode inputRel : node.getInputs()) { - newInputs.add(applyRexShuffleToTree(inputRel, rexShuttle)); - } - node = node.accept(rexShuttle); - if (newInputs.isEmpty()) { - return node; - } else { - return node.copy(node.getTraitSet(), newInputs); - } + public static boolean isGQLMatchRelNode(RelNode input) { + RelNode node = toRel(input); + return node.getRowType() instanceof PathRecordType; + } + + public static RelNode applyRexShuffleToTree(RelNode node, RexShuttle rexShuttle) { + if (node == null) { + return null; + } + List newInputs = new ArrayList<>(node.getInputs().size()); + for (RelNode inputRel : node.getInputs()) { + newInputs.add(applyRexShuffleToTree(inputRel, rexShuttle)); } + node = node.accept(rexShuttle); + if (newInputs.isEmpty()) { + return node; + } else { + return node.copy(node.getTraitSet(), newInputs); + } + } - public static RexNode createPathJoinCondition(IMatchNode left, IMatchNode right, - boolean caseSensitive, RexBuilder rexBuilder) { - Set commonLabels = GQLRelUtil.getCommonLabels(left, right); - List joinConditions = new ArrayList<>(); + public static RexNode createPathJoinCondition( + IMatchNode left, IMatchNode right, boolean caseSensitive, RexBuilder rexBuilder) { + Set commonLabels = GQLRelUtil.getCommonLabels(left, right); + List joinConditions = new ArrayList<>(); - for (String label : commonLabels) { - RelDataTypeField leftField = left.getRowType().getField(label, caseSensitive, false); - RelDataTypeField rightField = right.getRowType().getField(label, caseSensitive, false); - RexInputRef leftRef = rexBuilder.makeInputRef( - left.getPathSchema().getFieldList().get(leftField.getIndex()).getType(), - leftField.getIndex()); - RexNode leftRefId = rexBuilder.makeFieldAccess(leftRef, - VertexType.ID_FIELD_POSITION); - RexInputRef rightRef = rexBuilder.makeInputRef( - right.getPathSchema().getFieldList().get(rightField.getIndex()).getType(), - rightField.getIndex() + left.getRowType().getFieldCount()); - RexNode rightRefId = rexBuilder.makeFieldAccess(rightRef, - VertexType.ID_FIELD_POSITION); - // create expression of "leftLabel.id = rightLabel.id". - RexNode eq = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, leftRefId, rightRefId); - joinConditions.add(eq); - } - return RexUtil.composeConjunction(rexBuilder, joinConditions); + for (String label : commonLabels) { + RelDataTypeField leftField = left.getRowType().getField(label, caseSensitive, false); + RelDataTypeField rightField = right.getRowType().getField(label, caseSensitive, false); + RexInputRef leftRef = + rexBuilder.makeInputRef( + left.getPathSchema().getFieldList().get(leftField.getIndex()).getType(), + leftField.getIndex()); + RexNode leftRefId = rexBuilder.makeFieldAccess(leftRef, VertexType.ID_FIELD_POSITION); + RexInputRef rightRef = + rexBuilder.makeInputRef( + right.getPathSchema().getFieldList().get(rightField.getIndex()).getType(), + rightField.getIndex() + left.getRowType().getFieldCount()); + RexNode rightRefId = rexBuilder.makeFieldAccess(rightRef, VertexType.ID_FIELD_POSITION); + // create expression of "leftLabel.id = rightLabel.id". + RexNode eq = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, leftRefId, rightRefId); + joinConditions.add(eq); } + return RexUtil.composeConjunction(rexBuilder, joinConditions); + } - public static List getRelNodeReference(RelNode relNode) { - Set references = new HashSet<>(); - relNode.accept(new RexShuttle() { - @Override - public RexInputRef visitInputRef(RexInputRef inputRef) { - references.add(inputRef.getIndex()); - return inputRef; - } + public static List getRelNodeReference(RelNode relNode) { + Set references = new HashSet<>(); + relNode.accept( + new RexShuttle() { + @Override + public RexInputRef visitInputRef(RexInputRef inputRef) { + references.add(inputRef.getIndex()); + return inputRef; + } }); - return references.stream().sorted().collect(Collectors.toList()); - } + return references.stream().sorted().collect(Collectors.toList()); + } - public static RelNode adjustInputRef(RelNode relNode, Map indexMapping) { - return relNode.accept(new RexShuttle() { - @Override - public RexInputRef visitInputRef(RexInputRef inputRef) { - Integer newIndex = indexMapping.get(inputRef.getIndex()); - assert newIndex != null; - if (inputRef instanceof PathInputRef) { - return ((PathInputRef) inputRef).copy(newIndex); - } - return new RexInputRef(newIndex, inputRef.getType()); + public static RelNode adjustInputRef(RelNode relNode, Map indexMapping) { + return relNode.accept( + new RexShuttle() { + @Override + public RexInputRef visitInputRef(RexInputRef inputRef) { + Integer newIndex = indexMapping.get(inputRef.getIndex()); + assert newIndex != null; + if (inputRef instanceof PathInputRef) { + return ((PathInputRef) inputRef).copy(newIndex); } + return new RexInputRef(newIndex, inputRef.getType()); + } }); - } + } - /** - * Reverse the traversal order for {@link SingleMatchNode}. - * - * @param matchNode The match node to reverse. - * @param input If input is not null, we make it as the input node of the reversed match node. - */ - public static SingleMatchNode reverse(SingleMatchNode matchNode, IMatchNode input) { - IMatchNode concatNode = input; - SingleMatchNode node = matchNode; - while (node != null) { - SingleMatchNode endNode = node; - // find the latest VertexMatch/EdgeMatch - while (node != null && !(node instanceof IMatchLabel)) { - node = (SingleMatchNode) node.getInput(); - } - assert node != null; + /** + * Reverse the traversal order for {@link SingleMatchNode}. + * + * @param matchNode The match node to reverse. + * @param input If input is not null, we make it as the input node of the reversed match node. + */ + public static SingleMatchNode reverse(SingleMatchNode matchNode, IMatchNode input) { + IMatchNode concatNode = input; + SingleMatchNode node = matchNode; + while (node != null) { + SingleMatchNode endNode = node; + // find the latest VertexMatch/EdgeMatch + while (node != null && !(node instanceof IMatchLabel)) { + node = (SingleMatchNode) node.getInput(); + } + assert node != null; - RelDataTypeField field = endNode.getPathSchema().lastField().get(); - if (concatNode != null) { - PathRecordType pathRecordType = concatNode.getPathSchema().addField(field.getName() - , field.getType(), false); - concatNode = node.copy(Collections.singletonList(concatNode), pathRecordType); - } else { - PathRecordType pathRecordType = PathRecordType.EMPTY.addField(field.getName() - , field.getType(), false); - concatNode = node.copy(Collections.emptyList(), pathRecordType); - } - node = (SingleMatchNode) node.getInput(); - } - return (SingleMatchNode) concatNode; + RelDataTypeField field = endNode.getPathSchema().lastField().get(); + if (concatNode != null) { + PathRecordType pathRecordType = + concatNode.getPathSchema().addField(field.getName(), field.getType(), false); + concatNode = node.copy(Collections.singletonList(concatNode), pathRecordType); + } else { + PathRecordType pathRecordType = + PathRecordType.EMPTY.addField(field.getName(), field.getType(), false); + concatNode = node.copy(Collections.emptyList(), pathRecordType); + } + node = (SingleMatchNode) node.getInput(); } + return (SingleMatchNode) concatNode; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/GQLRexUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/GQLRexUtil.java index 75b37b3ce..ee61658a7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/GQLRexUtil.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/GQLRexUtil.java @@ -19,8 +19,6 @@ package org.apache.geaflow.dsl.util; -import com.google.common.base.Preconditions; -import com.google.common.collect.Sets; import java.math.BigDecimal; import java.sql.Time; import java.sql.Timestamp; @@ -32,6 +30,7 @@ import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; + import org.apache.calcite.avatica.util.TimeUnitRange; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.type.RelDataTypeField; @@ -53,694 +52,703 @@ import org.apache.geaflow.dsl.rex.PathInputRef; import org.apache.geaflow.dsl.schema.function.GeaFlowBuiltinFunctions; -public class GQLRexUtil { - - private static class RexCollectVisitor implements RexVisitor> { - - private final Predicate condition; - - public RexCollectVisitor(Predicate condition) { - this.condition = condition; - } +import com.google.common.base.Preconditions; +import com.google.common.collect.Sets; - @Override - public List visitInputRef(RexInputRef inputRef) { - if (condition.test(inputRef)) { - return (List) Collections.singletonList(inputRef); - } - return Collections.emptyList(); - } +public class GQLRexUtil { - @Override - public List visitLocalRef(RexLocalRef localRef) { - if (condition.test(localRef)) { - return (List) Collections.singletonList(localRef); - } - return Collections.emptyList(); - } + private static class RexCollectVisitor implements RexVisitor> { - @Override - public List visitLiteral(RexLiteral literal) { - if (condition.test(literal)) { - return (List) Collections.singletonList(literal); - } - return Collections.emptyList(); - } + private final Predicate condition; - @Override - public List visitCall(RexCall call) { - List childNodes = call.operands.stream() - .flatMap(operand -> operand.accept(this).stream()) - .collect(Collectors.toList()); - if (condition.test(call)) { - List nodes = new ArrayList<>(childNodes); - nodes.add(call); - return (List) nodes; - } - return (List) childNodes; - } + public RexCollectVisitor(Predicate condition) { + this.condition = condition; + } - @Override - public List visitOver(RexOver over) { - if (condition.test(over)) { - return (List) Collections.singletonList(over); - } - return Collections.emptyList(); - } + @Override + public List visitInputRef(RexInputRef inputRef) { + if (condition.test(inputRef)) { + return (List) Collections.singletonList(inputRef); + } + return Collections.emptyList(); + } - @Override - public List visitCorrelVariable(RexCorrelVariable correlVariable) { - if (condition.test(correlVariable)) { - return (List) Collections.singletonList(correlVariable); - } - return Collections.emptyList(); - } + @Override + public List visitLocalRef(RexLocalRef localRef) { + if (condition.test(localRef)) { + return (List) Collections.singletonList(localRef); + } + return Collections.emptyList(); + } - @Override - public List visitDynamicParam(RexDynamicParam dynamicParam) { - if (condition.test(dynamicParam)) { - return (List) Collections.singletonList(dynamicParam); - } - return Collections.emptyList(); - } + @Override + public List visitLiteral(RexLiteral literal) { + if (condition.test(literal)) { + return (List) Collections.singletonList(literal); + } + return Collections.emptyList(); + } - @Override - public List visitRangeRef(RexRangeRef rangeRef) { - if (condition.test(rangeRef)) { - return (List) Collections.singletonList(rangeRef); - } - return Collections.emptyList(); - } + @Override + public List visitCall(RexCall call) { + List childNodes = + call.operands.stream() + .flatMap(operand -> operand.accept(this).stream()) + .collect(Collectors.toList()); + if (condition.test(call)) { + List nodes = new ArrayList<>(childNodes); + nodes.add(call); + return (List) nodes; + } + return (List) childNodes; + } - @Override - public List visitFieldAccess(RexFieldAccess fieldAccess) { - List collects = new ArrayList<>(fieldAccess.getReferenceExpr().accept(this)); - if (condition.test(fieldAccess)) { - collects.add(fieldAccess); - return (List) collects; - } - return (List) collects; - } + @Override + public List visitOver(RexOver over) { + if (condition.test(over)) { + return (List) Collections.singletonList(over); + } + return Collections.emptyList(); + } - @Override - public List visitSubQuery(RexSubQuery subQuery) { - ResultRexShuffle resultRexShuffle = new ResultRexShuffle<>(this); - subQuery.rel.accept(resultRexShuffle); - return resultRexShuffle.getResult(); - } + @Override + public List visitCorrelVariable(RexCorrelVariable correlVariable) { + if (condition.test(correlVariable)) { + return (List) Collections.singletonList(correlVariable); + } + return Collections.emptyList(); + } - @Override - public List visitTableInputRef(RexTableInputRef fieldRef) { - if (condition.test(fieldRef)) { - return (List) Collections.singletonList(fieldRef); - } - return Collections.emptyList(); - } + @Override + public List visitDynamicParam(RexDynamicParam dynamicParam) { + if (condition.test(dynamicParam)) { + return (List) Collections.singletonList(dynamicParam); + } + return Collections.emptyList(); + } - @Override - public List visitPatternFieldRef(RexPatternFieldRef fieldRef) { - if (condition.test(fieldRef)) { - return (List) Collections.singletonList(fieldRef); - } - return Collections.emptyList(); - } + @Override + public List visitRangeRef(RexRangeRef rangeRef) { + if (condition.test(rangeRef)) { + return (List) Collections.singletonList(rangeRef); + } + return Collections.emptyList(); + } - @Override - public List visitOther(RexNode other) { - if (condition.test(other)) { - return (List) Collections.singletonList(other); - } - return Collections.emptyList(); - } + @Override + public List visitFieldAccess(RexFieldAccess fieldAccess) { + List collects = new ArrayList<>(fieldAccess.getReferenceExpr().accept(this)); + if (condition.test(fieldAccess)) { + collects.add(fieldAccess); + return (List) collects; + } + return (List) collects; } - /** - * Collect sub-node for {@link RexNode} which satisfy the condition. - * - * @param rexNode The rex-node to collect. - * @param condition The collect condition. - * @return The sub-node list which satisfy the condition. - */ - @SuppressWarnings("unchecked") - public static List collect(RexNode rexNode, Predicate condition) { - return rexNode.accept(new RexCollectVisitor<>(condition)); - } - - public static List collect(RelNode node, Predicate condition) { - RexCollectVisitor collectVisitor = new RexCollectVisitor<>(condition); - ResultRexShuffle resultShuffle = new ResultRexShuffle<>(collectVisitor); - node.accept(resultShuffle); - return resultShuffle.getResult(); - } - - /** - * Whether the rex-node contains specified kind of child node. - */ - public static boolean contain(RexNode rexNode, Class targetNodeClass) { - return !collect(rexNode, operand -> operand.getClass() == targetNodeClass).isEmpty(); - } - - /** - * Replace the sub-node of the {@link RexNode} to the new sub-node defined by the replace function. - * - * @param rexNode The rex-node to replace. - * @param replaceFn The replace function which mapping the old rex-node to the new rex-node. - * @return The replaced rex-node. - */ - public static RexNode replace(RexNode rexNode, Function replaceFn) { - return rexNode.accept(new RexVisitor() { - @Override - public RexNode visitInputRef(RexInputRef inputRef) { - return replaceFn.apply(inputRef); - } + @Override + public List visitSubQuery(RexSubQuery subQuery) { + ResultRexShuffle resultRexShuffle = new ResultRexShuffle<>(this); + subQuery.rel.accept(resultRexShuffle); + return resultRexShuffle.getResult(); + } - @Override - public RexNode visitLocalRef(RexLocalRef localRef) { - return replaceFn.apply(localRef); - } + @Override + public List visitTableInputRef(RexTableInputRef fieldRef) { + if (condition.test(fieldRef)) { + return (List) Collections.singletonList(fieldRef); + } + return Collections.emptyList(); + } - @Override - public RexNode visitLiteral(RexLiteral literal) { - return replaceFn.apply(literal); - } + @Override + public List visitPatternFieldRef(RexPatternFieldRef fieldRef) { + if (condition.test(fieldRef)) { + return (List) Collections.singletonList(fieldRef); + } + return Collections.emptyList(); + } - @Override - public RexNode visitCall(RexCall call) { - List newOperands = call.operands.stream() + @Override + public List visitOther(RexNode other) { + if (condition.test(other)) { + return (List) Collections.singletonList(other); + } + return Collections.emptyList(); + } + } + + /** + * Collect sub-node for {@link RexNode} which satisfy the condition. + * + * @param rexNode The rex-node to collect. + * @param condition The collect condition. + * @return The sub-node list which satisfy the condition. + */ + @SuppressWarnings("unchecked") + public static List collect(RexNode rexNode, Predicate condition) { + return rexNode.accept(new RexCollectVisitor<>(condition)); + } + + public static List collect(RelNode node, Predicate condition) { + RexCollectVisitor collectVisitor = new RexCollectVisitor<>(condition); + ResultRexShuffle resultShuffle = new ResultRexShuffle<>(collectVisitor); + node.accept(resultShuffle); + return resultShuffle.getResult(); + } + + /** Whether the rex-node contains specified kind of child node. */ + public static boolean contain(RexNode rexNode, Class targetNodeClass) { + return !collect(rexNode, operand -> operand.getClass() == targetNodeClass).isEmpty(); + } + + /** + * Replace the sub-node of the {@link RexNode} to the new sub-node defined by the replace + * function. + * + * @param rexNode The rex-node to replace. + * @param replaceFn The replace function which mapping the old rex-node to the new rex-node. + * @return The replaced rex-node. + */ + public static RexNode replace(RexNode rexNode, Function replaceFn) { + return rexNode.accept( + new RexVisitor() { + @Override + public RexNode visitInputRef(RexInputRef inputRef) { + return replaceFn.apply(inputRef); + } + + @Override + public RexNode visitLocalRef(RexLocalRef localRef) { + return replaceFn.apply(localRef); + } + + @Override + public RexNode visitLiteral(RexLiteral literal) { + return replaceFn.apply(literal); + } + + @Override + public RexNode visitCall(RexCall call) { + List newOperands = + call.operands.stream() .map(operand -> operand.accept(this)) .collect(Collectors.toList()); - RexNode newCall = call.clone(call.getType(), newOperands); - return replaceFn.apply(newCall); - } - - @Override - public RexNode visitOver(RexOver over) { - return replaceFn.apply(over); - } - - @Override - public RexNode visitCorrelVariable(RexCorrelVariable correlVariable) { - return replaceFn.apply(correlVariable); - } - - @Override - public RexNode visitDynamicParam(RexDynamicParam dynamicParam) { - return replaceFn.apply(dynamicParam); - } - - @Override - public RexNode visitRangeRef(RexRangeRef rangeRef) { - return replaceFn.apply(rangeRef); - } - - @Override - public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { - return replaceFn.apply(fieldAccess); - } - - @Override - public RexNode visitSubQuery(RexSubQuery subQuery) { - return replaceFn.apply(subQuery); - } - - @Override - public RexNode visitTableInputRef(RexTableInputRef fieldRef) { - return replaceFn.apply(fieldRef); - } - - @Override - public RexNode visitPatternFieldRef(RexPatternFieldRef fieldRef) { - return replaceFn.apply(fieldRef); - } - - @Override - public RexNode visitOther(RexNode other) { - return replaceFn.apply(other); - } + RexNode newCall = call.clone(call.getType(), newOperands); + return replaceFn.apply(newCall); + } + + @Override + public RexNode visitOver(RexOver over) { + return replaceFn.apply(over); + } + + @Override + public RexNode visitCorrelVariable(RexCorrelVariable correlVariable) { + return replaceFn.apply(correlVariable); + } + + @Override + public RexNode visitDynamicParam(RexDynamicParam dynamicParam) { + return replaceFn.apply(dynamicParam); + } + + @Override + public RexNode visitRangeRef(RexRangeRef rangeRef) { + return replaceFn.apply(rangeRef); + } + + @Override + public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { + return replaceFn.apply(fieldAccess); + } + + @Override + public RexNode visitSubQuery(RexSubQuery subQuery) { + return replaceFn.apply(subQuery); + } + + @Override + public RexNode visitTableInputRef(RexTableInputRef fieldRef) { + return replaceFn.apply(fieldRef); + } + + @Override + public RexNode visitPatternFieldRef(RexPatternFieldRef fieldRef) { + return replaceFn.apply(fieldRef); + } + + @Override + public RexNode visitOther(RexNode other) { + return replaceFn.apply(other); + } }); - } - - /** - * Find vertex ids in the expression. - * e.g. for "a.id = '1' or a.id = '2'", Set("1", "2") will return. - * - * @param rexNode The expression. - * @param vertexRecordType The input vertex type for the expression. - * @return The id literals referred by the expression. - */ - public static Set findVertexIds(RexNode rexNode, VertexRecordType vertexRecordType) { - return rexNode.accept(new RexVisitor>() { - @Override - public Set visitInputRef(RexInputRef rexInputRef) { - return new HashSet<>(); - } - - @Override - public Set visitLocalRef(RexLocalRef rexLocalRef) { - return new HashSet<>(); - } - - @Override - public Set visitLiteral(RexLiteral rexLiteral) { - return new HashSet<>(); - } - - @Override - public Set visitCall(RexCall call) { - SqlKind kind = call.getKind(); - switch (kind) { - case EQUALS: - RexNode idValue = null; - RexNode left = call.operands.get(0); - RexNode right = call.operands.get(1); - if (isIdField(vertexRecordType, left) && isLiteralOrParameter(right, true)) { - idValue = right; - } else if (isIdField(vertexRecordType, right) && isLiteralOrParameter(left, true)) { - idValue = left; - } - if (idValue != null) { - return Sets.newHashSet(idValue); - } else { - return new HashSet<>(); - } - case AND: - return call.operands.stream() - .map(operand -> operand.accept(this)) - .filter(set -> !set.isEmpty()) - .reduce(Sets::intersection) - .orElse(new HashSet<>()); - case OR: - return call.operands.stream() - .map(operand -> operand.accept(this)) - .reduce((a, b) -> { - if (a.isEmpty() || b.isEmpty()) { - // all child should be id condition, else return empty. - return Sets.newHashSet(); - } else { - return Sets.union(a, b); - } - }) - .orElse(new HashSet<>()); - case CAST: - return call.operands.get(0).accept(this); - default: - return new HashSet<>(); + } + + /** + * Find vertex ids in the expression. e.g. for "a.id = '1' or a.id = '2'", Set("1", "2") will + * return. + * + * @param rexNode The expression. + * @param vertexRecordType The input vertex type for the expression. + * @return The id literals referred by the expression. + */ + public static Set findVertexIds(RexNode rexNode, VertexRecordType vertexRecordType) { + return rexNode.accept( + new RexVisitor>() { + @Override + public Set visitInputRef(RexInputRef rexInputRef) { + return new HashSet<>(); + } + + @Override + public Set visitLocalRef(RexLocalRef rexLocalRef) { + return new HashSet<>(); + } + + @Override + public Set visitLiteral(RexLiteral rexLiteral) { + return new HashSet<>(); + } + + @Override + public Set visitCall(RexCall call) { + SqlKind kind = call.getKind(); + switch (kind) { + case EQUALS: + RexNode idValue = null; + RexNode left = call.operands.get(0); + RexNode right = call.operands.get(1); + if (isIdField(vertexRecordType, left) && isLiteralOrParameter(right, true)) { + idValue = right; + } else if (isIdField(vertexRecordType, right) && isLiteralOrParameter(left, true)) { + idValue = left; } - } - - @Override - public Set visitOver(RexOver rexOver) { - return new HashSet<>(); - } - - @Override - public Set visitCorrelVariable(RexCorrelVariable rexCorrelVariable) { - return new HashSet<>(); - } - - @Override - public Set visitDynamicParam(RexDynamicParam rexDynamicParam) { - return new HashSet<>(); - } - - @Override - public Set visitRangeRef(RexRangeRef rexRangeRef) { - return new HashSet<>(); - } - - @Override - public Set visitFieldAccess(RexFieldAccess rexFieldAccess) { - return new HashSet<>(); - } - - @Override - public Set visitSubQuery(RexSubQuery rexSubQuery) { - return new HashSet<>(); - } - - @Override - public Set visitTableInputRef(RexTableInputRef rexTableInputRef) { - return new HashSet<>(); - } - - @Override - public Set visitPatternFieldRef(RexPatternFieldRef rexPatternFieldRef) { - return new HashSet<>(); - } - - @Override - public Set visitOther(RexNode other) { + if (idValue != null) { + return Sets.newHashSet(idValue); + } else { + return new HashSet<>(); + } + case AND: + return call.operands.stream() + .map(operand -> operand.accept(this)) + .filter(set -> !set.isEmpty()) + .reduce(Sets::intersection) + .orElse(new HashSet<>()); + case OR: + return call.operands.stream() + .map(operand -> operand.accept(this)) + .reduce( + (a, b) -> { + if (a.isEmpty() || b.isEmpty()) { + // all child should be id condition, else return empty. + return Sets.newHashSet(); + } else { + return Sets.union(a, b); + } + }) + .orElse(new HashSet<>()); + case CAST: + return call.operands.get(0).accept(this); + default: return new HashSet<>(); } + } + + @Override + public Set visitOver(RexOver rexOver) { + return new HashSet<>(); + } + + @Override + public Set visitCorrelVariable(RexCorrelVariable rexCorrelVariable) { + return new HashSet<>(); + } + + @Override + public Set visitDynamicParam(RexDynamicParam rexDynamicParam) { + return new HashSet<>(); + } + + @Override + public Set visitRangeRef(RexRangeRef rexRangeRef) { + return new HashSet<>(); + } + + @Override + public Set visitFieldAccess(RexFieldAccess rexFieldAccess) { + return new HashSet<>(); + } + + @Override + public Set visitSubQuery(RexSubQuery rexSubQuery) { + return new HashSet<>(); + } + + @Override + public Set visitTableInputRef(RexTableInputRef rexTableInputRef) { + return new HashSet<>(); + } + + @Override + public Set visitPatternFieldRef(RexPatternFieldRef rexPatternFieldRef) { + return new HashSet<>(); + } + + @Override + public Set visitOther(RexNode other) { + return new HashSet<>(); + } }); + } + + public static RexNode swapReverseEdgeRef( + RexNode rexNode, String reverseEdgeName, RexBuilder rexBuilder) { + return GQLRexUtil.replace( + rexNode, + node -> { + if (node instanceof RexFieldAccess + && ((RexFieldAccess) node).getReferenceExpr() instanceof PathInputRef) { + RexFieldAccess fieldAccess = (RexFieldAccess) node; + PathInputRef pathInputRef = (PathInputRef) fieldAccess.getReferenceExpr(); + if (pathInputRef.getLabel().equals(reverseEdgeName) + && fieldAccess.getType() instanceof MetaFieldType) { + if (((MetaFieldType) fieldAccess.getType()) + .getMetaField() + .equals(MetaField.EDGE_SRC_ID)) { + return rexBuilder.makeFieldAccess( + pathInputRef, + ((EdgeRecordType) pathInputRef.getType()).getTargetIdField().getIndex()); + } else if (((MetaFieldType) fieldAccess.getType()) + .getMetaField() + .equals(MetaField.EDGE_TARGET_ID)) { + return rexBuilder.makeFieldAccess( + pathInputRef, + ((EdgeRecordType) pathInputRef.getType()).getSrcIdField().getIndex()); + } + } + } + return node; + }); + } + + public static RexNode removeIdCondition(RexNode condition, VertexRecordType vertexRecordType) { + if (condition instanceof RexCall) { + RexCall call = (RexCall) condition; + switch (call.getKind()) { + case EQUALS: + RexNode left = call.operands.get(0); + RexNode right = call.operands.get(1); + if (isIdField(vertexRecordType, left) && isLiteralOrParameter(right, true)) { + return null; + } + if (isIdField(vertexRecordType, right) && isLiteralOrParameter(left, true)) { + return null; + } + break; + case AND: + List filterOperands = + call.operands.stream() + .filter(operand -> removeIdCondition(operand, vertexRecordType) != null) + .collect(Collectors.toList()); + if (filterOperands.size() == 0) { + return null; + } else if (filterOperands.size() == 1) { + return filterOperands.get(0); + } + return call.clone(call.getType(), filterOperands); + case OR: + boolean allRemove = + call.operands.stream() + .allMatch(operand -> removeIdCondition(operand, vertexRecordType) == null); + if (allRemove) { + return null; + } + break; + case CAST: + RexNode newOperand = removeIdCondition(call.operands.get(0), vertexRecordType); + if (newOperand == null) { + return null; + } + return call.clone(call.getType(), Collections.singletonList(newOperand)); + default: + } } + return condition; + } - - public static RexNode swapReverseEdgeRef(RexNode rexNode, String reverseEdgeName, - RexBuilder rexBuilder) { - return GQLRexUtil.replace(rexNode, - node -> { - if (node instanceof RexFieldAccess - && ((RexFieldAccess) node).getReferenceExpr() instanceof PathInputRef) { - RexFieldAccess fieldAccess = (RexFieldAccess) node; - PathInputRef pathInputRef = (PathInputRef) fieldAccess.getReferenceExpr(); - if (pathInputRef.getLabel().equals(reverseEdgeName) - && fieldAccess.getType() instanceof MetaFieldType) { - if (((MetaFieldType) fieldAccess.getType()).getMetaField() - .equals(MetaField.EDGE_SRC_ID)) { - return rexBuilder.makeFieldAccess(pathInputRef, - ((EdgeRecordType) pathInputRef.getType()).getTargetIdField() - .getIndex()); - } else if (((MetaFieldType) fieldAccess.getType()).getMetaField() - .equals(MetaField.EDGE_TARGET_ID)) { - return rexBuilder.makeFieldAccess(pathInputRef, - ((EdgeRecordType) pathInputRef.getType()).getSrcIdField() - .getIndex()); - } - } - } - return node; - }); - } - - public static RexNode removeIdCondition(RexNode condition, VertexRecordType vertexRecordType) { - if (condition instanceof RexCall) { - RexCall call = (RexCall) condition; - switch (call.getKind()) { - case EQUALS: - RexNode left = call.operands.get(0); - RexNode right = call.operands.get(1); - if (isIdField(vertexRecordType, left) && isLiteralOrParameter(right, true)) { - return null; - } - if (isIdField(vertexRecordType, right) && isLiteralOrParameter(left, true)) { - return null; - } - break; - case AND: - List filterOperands = call.operands.stream() - .filter(operand -> removeIdCondition(operand, vertexRecordType) != null) - .collect(Collectors.toList()); - if (filterOperands.size() == 0) { - return null; - } else if (filterOperands.size() == 1) { - return filterOperands.get(0); - } - return call.clone(call.getType(), filterOperands); - case OR: - boolean allRemove = - call.operands.stream().allMatch(operand -> removeIdCondition(operand, - vertexRecordType) == null); - if (allRemove) { - return null; - } - break; - case CAST: - RexNode newOperand = removeIdCondition(call.operands.get(0), vertexRecordType); - if (newOperand == null) { - return null; - } - return call.clone(call.getType(), Collections.singletonList(newOperand)); - default: - } - } - return condition; + private static boolean isIdField(VertexRecordType vertexRecordType, RexNode node) { + if (node instanceof RexFieldAccess) { + int index = ((RexFieldAccess) node).getField().getIndex(); + return vertexRecordType.isId(index); } - - private static boolean isIdField(VertexRecordType vertexRecordType, RexNode node) { - if (node instanceof RexFieldAccess) { - int index = ((RexFieldAccess) node).getField().getIndex(); - return vertexRecordType.isId(index); - } - return false; - } - - public static Object getLiteralValue(RexNode node) { - SqlKind kind = node.getKind(); - if (kind == SqlKind.LITERAL) { - RexLiteral literal = (RexLiteral) node; - return getLiteralValue(literal); - } else if (kind == SqlKind.CAST) { - RexCall cast = (RexCall) node; - Object value = getLiteralValue(cast.operands.get(0)); - IType targetType = SqlTypeUtil.convertType(cast.getType()); - return TypeCastUtil.cast(value, targetType); - } - throw new IllegalArgumentException("RexNode: " + node + " is not a literal"); + return false; + } + + public static Object getLiteralValue(RexNode node) { + SqlKind kind = node.getKind(); + if (kind == SqlKind.LITERAL) { + RexLiteral literal = (RexLiteral) node; + return getLiteralValue(literal); + } else if (kind == SqlKind.CAST) { + RexCall cast = (RexCall) node; + Object value = getLiteralValue(cast.operands.get(0)); + IType targetType = SqlTypeUtil.convertType(cast.getType()); + return TypeCastUtil.cast(value, targetType); } + throw new IllegalArgumentException("RexNode: " + node + " is not a literal"); + } - public static Object getLiteralValue(RexLiteral literal) { - if (literal == null) { - return null; - } - SqlTypeName typeName = literal.getType().getSqlTypeName(); - Object value = literal.getValue(); - if (value == null) { - return null; - } - switch (typeName) { - case BOOLEAN: - return Boolean.class.cast(value); - - case TINYINT: - return ((BigDecimal) literal.getValue()).byteValue(); - case SMALLINT: - return ((BigDecimal) literal.getValue()).shortValue(); - case INTEGER: - return ((BigDecimal) literal.getValue()).intValue(); - case BIGINT: - case INTERVAL_SECOND: - case INTERVAL_MINUTE: - case INTERVAL_HOUR: - case INTERVAL_DAY: - case INTERVAL_MONTH: - case INTERVAL_YEAR: - return ((BigDecimal) literal.getValue()).longValue(); - - case FLOAT: - case DOUBLE: - case REAL: - return ((BigDecimal) literal.getValue()).doubleValue(); - case DECIMAL: - return literal.getValue(); - - case CHAR: - case VARCHAR: - Preconditions - .checkArgument(literal.getValue() instanceof NlsString, - "literal create type char/varchar must be NlsString type"); - return StringLiteralUtil - .unescapeSQLString("\"" + RexLiteral.stringValue(literal) + "\""); - case SYMBOL: - Preconditions.checkArgument(value instanceof Enum, - "literal create type symbol must be Enum type"); - if (value instanceof TimeUnitRange) { - return ((TimeUnitRange) value).startUnit.multiplier.intValue(); - } else if (value instanceof SqlTrimFunction.Flag) { - SqlTrimFunction.Flag flag = (Flag) value; - switch (flag) { - case BOTH: - return GeaFlowBuiltinFunctions.TRIM_BOTH; - case LEADING: - return GeaFlowBuiltinFunctions.TRIM_LEFT; - case TRAILING: - return GeaFlowBuiltinFunctions.TRIM_RIGHT; - default: - throw new IllegalArgumentException("illegal trim flag: " + flag); - } - } - break; - case DATE: - return java.sql.Date.valueOf(literal.toString()); - case TIME: - return Time.valueOf(literal.toString()); - case TIMESTAMP: - return Timestamp.valueOf(literal.toString()); - case BINARY: - case VARBINARY: - return byte[].class.cast(literal.getValue()); + public static Object getLiteralValue(RexLiteral literal) { + if (literal == null) { + return null; + } + SqlTypeName typeName = literal.getType().getSqlTypeName(); + Object value = literal.getValue(); + if (value == null) { + return null; + } + switch (typeName) { + case BOOLEAN: + return Boolean.class.cast(value); + + case TINYINT: + return ((BigDecimal) literal.getValue()).byteValue(); + case SMALLINT: + return ((BigDecimal) literal.getValue()).shortValue(); + case INTEGER: + return ((BigDecimal) literal.getValue()).intValue(); + case BIGINT: + case INTERVAL_SECOND: + case INTERVAL_MINUTE: + case INTERVAL_HOUR: + case INTERVAL_DAY: + case INTERVAL_MONTH: + case INTERVAL_YEAR: + return ((BigDecimal) literal.getValue()).longValue(); + + case FLOAT: + case DOUBLE: + case REAL: + return ((BigDecimal) literal.getValue()).doubleValue(); + case DECIMAL: + return literal.getValue(); + + case CHAR: + case VARCHAR: + Preconditions.checkArgument( + literal.getValue() instanceof NlsString, + "literal create type char/varchar must be NlsString type"); + return StringLiteralUtil.unescapeSQLString("\"" + RexLiteral.stringValue(literal) + "\""); + case SYMBOL: + Preconditions.checkArgument( + value instanceof Enum, "literal create type symbol must be Enum type"); + if (value instanceof TimeUnitRange) { + return ((TimeUnitRange) value).startUnit.multiplier.intValue(); + } else if (value instanceof SqlTrimFunction.Flag) { + SqlTrimFunction.Flag flag = (Flag) value; + switch (flag) { + case BOTH: + return GeaFlowBuiltinFunctions.TRIM_BOTH; + case LEADING: + return GeaFlowBuiltinFunctions.TRIM_LEFT; + case TRAILING: + return GeaFlowBuiltinFunctions.TRIM_RIGHT; default: - throw new GeaFlowDSLException("Not support type:" + typeName); - } + throw new IllegalArgumentException("illegal trim flag: " + flag); + } + } + break; + case DATE: + return java.sql.Date.valueOf(literal.toString()); + case TIME: + return Time.valueOf(literal.toString()); + case TIMESTAMP: + return Timestamp.valueOf(literal.toString()); + case BINARY: + case VARBINARY: + return byte[].class.cast(literal.getValue()); + default: throw new GeaFlowDSLException("Not support type:" + typeName); } - - public static RexNode toPathInputRefForWhere(RelDataTypeField pathField, RexNode where) { - RexBuilder builder = new RexBuilder(GQLJavaTypeFactory.create()); - return where.accept(new RexShuttle() { - @Override - public RexNode visitInputRef(RexInputRef inputRef) { - PathInputRef pathInputRef = new PathInputRef(pathField.getName(), - pathField.getIndex(), pathField.getType()); - return builder.makeFieldAccess(pathInputRef, inputRef.getIndex()); - } + throw new GeaFlowDSLException("Not support type:" + typeName); + } + + public static RexNode toPathInputRefForWhere(RelDataTypeField pathField, RexNode where) { + RexBuilder builder = new RexBuilder(GQLJavaTypeFactory.create()); + return where.accept( + new RexShuttle() { + @Override + public RexNode visitInputRef(RexInputRef inputRef) { + PathInputRef pathInputRef = + new PathInputRef(pathField.getName(), pathField.getIndex(), pathField.getType()); + return builder.makeFieldAccess(pathInputRef, inputRef.getIndex()); + } }); - } + } - public static RexLiteral createString(String value) { - RexBuilder rexBuilder = new RexBuilder(GQLJavaTypeFactory.create()); - return rexBuilder.makeLiteral(value); - } + public static RexLiteral createString(String value) { + RexBuilder rexBuilder = new RexBuilder(GQLJavaTypeFactory.create()); + return rexBuilder.makeLiteral(value); + } - public static boolean isLiteralOrParameter(RexNode rexNode, boolean allowCast) { - if (rexNode.getKind() == SqlKind.CAST && allowCast) { - return isLiteralOrParameter(((RexCall) rexNode).operands.get(0), true); - } - return !contain(rexNode, RexInputRef.class) && !contain(rexNode, RexFieldAccess.class); + public static boolean isLiteralOrParameter(RexNode rexNode, boolean allowCast) { + if (rexNode.getKind() == SqlKind.CAST && allowCast) { + return isLiteralOrParameter(((RexCall) rexNode).operands.get(0), true); } - - public static boolean isVertexIdFieldAccess(RexNode rexNode) { - if (rexNode instanceof RexFieldAccess) { - RexFieldAccess op = (RexFieldAccess) rexNode; - if (op.getReferenceExpr() instanceof PathInputRef - && op.getType() instanceof MetaFieldType) { - MetaFieldType opType = (MetaFieldType) op.getType(); - return opType.getMetaField() == MetaField.VERTEX_ID; - } - } - return false; + return !contain(rexNode, RexInputRef.class) && !contain(rexNode, RexFieldAccess.class); + } + + public static boolean isVertexIdFieldAccess(RexNode rexNode) { + if (rexNode instanceof RexFieldAccess) { + RexFieldAccess op = (RexFieldAccess) rexNode; + if (op.getReferenceExpr() instanceof PathInputRef && op.getType() instanceof MetaFieldType) { + MetaFieldType opType = (MetaFieldType) op.getType(); + return opType.getMetaField() == MetaField.VERTEX_ID; + } } + return false; + } - public static class ResultRexShuffle extends RexShuttle { + public static class ResultRexShuffle extends RexShuttle { - private final RexVisitor> baseVisitor; + private final RexVisitor> baseVisitor; - private final List result = new ArrayList<>(); + private final List result = new ArrayList<>(); - public ResultRexShuffle(RexVisitor> baseVisitor) { - this.baseVisitor = baseVisitor; - } + public ResultRexShuffle(RexVisitor> baseVisitor) { + this.baseVisitor = baseVisitor; + } - public List getResult() { - return result; - } + public List getResult() { + return result; + } - @Override - public RexNode visitInputRef(RexInputRef inputRef) { - List visitResults = baseVisitor.visitInputRef(inputRef); - if (visitResults != null) { - result.addAll(visitResults); - } - return inputRef; - } + @Override + public RexNode visitInputRef(RexInputRef inputRef) { + List visitResults = baseVisitor.visitInputRef(inputRef); + if (visitResults != null) { + result.addAll(visitResults); + } + return inputRef; + } - @Override - public RexNode visitLocalRef(RexLocalRef localRef) { - List visitResults = baseVisitor.visitLocalRef(localRef); - if (visitResults != null) { - result.addAll(visitResults); - } - return localRef; - } + @Override + public RexNode visitLocalRef(RexLocalRef localRef) { + List visitResults = baseVisitor.visitLocalRef(localRef); + if (visitResults != null) { + result.addAll(visitResults); + } + return localRef; + } - @Override - public RexNode visitLiteral(RexLiteral literal) { - List visitResults = baseVisitor.visitLiteral(literal); - if (visitResults != null) { - result.addAll(visitResults); - } - return literal; - } + @Override + public RexNode visitLiteral(RexLiteral literal) { + List visitResults = baseVisitor.visitLiteral(literal); + if (visitResults != null) { + result.addAll(visitResults); + } + return literal; + } - @Override - public RexNode visitCall(RexCall call) { - List visitResults = baseVisitor.visitCall(call); - if (visitResults != null) { - result.addAll(visitResults); - } - return call; - } + @Override + public RexNode visitCall(RexCall call) { + List visitResults = baseVisitor.visitCall(call); + if (visitResults != null) { + result.addAll(visitResults); + } + return call; + } - @Override - public RexNode visitOver(RexOver over) { - List visitResults = baseVisitor.visitOver(over); - if (visitResults != null) { - result.addAll(visitResults); - } - return over; - } + @Override + public RexNode visitOver(RexOver over) { + List visitResults = baseVisitor.visitOver(over); + if (visitResults != null) { + result.addAll(visitResults); + } + return over; + } - @Override - public RexNode visitCorrelVariable(RexCorrelVariable correlVariable) { - List visitResults = baseVisitor.visitCorrelVariable(correlVariable); - if (visitResults != null) { - result.addAll(visitResults); - } - return correlVariable; - } + @Override + public RexNode visitCorrelVariable(RexCorrelVariable correlVariable) { + List visitResults = baseVisitor.visitCorrelVariable(correlVariable); + if (visitResults != null) { + result.addAll(visitResults); + } + return correlVariable; + } - @Override - public RexNode visitDynamicParam(RexDynamicParam dynamicParam) { - List visitResults = baseVisitor.visitDynamicParam(dynamicParam); - if (visitResults != null) { - result.addAll(visitResults); - } - return dynamicParam; - } + @Override + public RexNode visitDynamicParam(RexDynamicParam dynamicParam) { + List visitResults = baseVisitor.visitDynamicParam(dynamicParam); + if (visitResults != null) { + result.addAll(visitResults); + } + return dynamicParam; + } - @Override - public RexNode visitRangeRef(RexRangeRef rangeRef) { - List visitResults = baseVisitor.visitRangeRef(rangeRef); - if (visitResults != null) { - result.addAll(visitResults); - } - return rangeRef; - } + @Override + public RexNode visitRangeRef(RexRangeRef rangeRef) { + List visitResults = baseVisitor.visitRangeRef(rangeRef); + if (visitResults != null) { + result.addAll(visitResults); + } + return rangeRef; + } - @Override - public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { - List visitResults = baseVisitor.visitFieldAccess(fieldAccess); - if (visitResults != null) { - result.addAll(visitResults); - } - return fieldAccess; - } + @Override + public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { + List visitResults = baseVisitor.visitFieldAccess(fieldAccess); + if (visitResults != null) { + result.addAll(visitResults); + } + return fieldAccess; + } - @Override - public RexNode visitSubQuery(RexSubQuery subQuery) { - List visitResults = baseVisitor.visitSubQuery(subQuery); - if (visitResults != null) { - result.addAll(visitResults); - } - return subQuery; - } + @Override + public RexNode visitSubQuery(RexSubQuery subQuery) { + List visitResults = baseVisitor.visitSubQuery(subQuery); + if (visitResults != null) { + result.addAll(visitResults); + } + return subQuery; + } - @Override - public RexNode visitTableInputRef(RexTableInputRef fieldRef) { - List visitResults = baseVisitor.visitTableInputRef(fieldRef); - if (visitResults != null) { - result.addAll(visitResults); - } - return fieldRef; - } + @Override + public RexNode visitTableInputRef(RexTableInputRef fieldRef) { + List visitResults = baseVisitor.visitTableInputRef(fieldRef); + if (visitResults != null) { + result.addAll(visitResults); + } + return fieldRef; + } - @Override - public RexNode visitPatternFieldRef(RexPatternFieldRef fieldRef) { - List visitResults = baseVisitor.visitPatternFieldRef(fieldRef); - if (visitResults != null) { - result.addAll(visitResults); - } - return fieldRef; - } + @Override + public RexNode visitPatternFieldRef(RexPatternFieldRef fieldRef) { + List visitResults = baseVisitor.visitPatternFieldRef(fieldRef); + if (visitResults != null) { + result.addAll(visitResults); + } + return fieldRef; + } - @Override - public RexNode visitOther(RexNode other) { - List visitResults = baseVisitor.visitOther(other); - if (visitResults != null) { - result.addAll(visitResults); - } - return other; - } + @Override + public RexNode visitOther(RexNode other) { + List visitResults = baseVisitor.visitOther(other); + if (visitResults != null) { + result.addAll(visitResults); + } + return other; } + } - public static RexNode and(List conditions, RexBuilder builder) { - if (conditions == null) { - return null; - } - if (conditions.size() == 1) { - return conditions.get(0); - } - return builder.makeCall(SqlStdOperatorTable.AND, conditions); + public static RexNode and(List conditions, RexBuilder builder) { + if (conditions == null) { + return null; + } + if (conditions.size() == 1) { + return conditions.get(0); } + return builder.makeCall(SqlStdOperatorTable.AND, conditions); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/GraphDescriptorUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/GraphDescriptorUtil.java index 06a3de13c..6bc9dac00 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/GraphDescriptorUtil.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/GraphDescriptorUtil.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; + import org.apache.calcite.sql.SqlNodeList; import org.apache.commons.collections.map.HashedMap; import org.apache.geaflow.dsl.common.descriptor.EdgeDescriptor; @@ -31,35 +32,36 @@ public class GraphDescriptorUtil { - public static List getEdgeDescriptor(GraphDescriptor desc, String graphName, SqlEdge sqlEdge) { - return getEdgeDescriptor(desc, graphName, sqlEdge.getName().getSimple(), sqlEdge.getConstraints()); - } + public static List getEdgeDescriptor( + GraphDescriptor desc, String graphName, SqlEdge sqlEdge) { + return getEdgeDescriptor( + desc, graphName, sqlEdge.getName().getSimple(), sqlEdge.getConstraints()); + } - public static List getEdgeDescriptor(GraphDescriptor desc, String graphName, SqlEdgeUsing sqlEdgeUsing) { - return getEdgeDescriptor(desc, graphName, sqlEdgeUsing.getName().getSimple(), sqlEdgeUsing.getConstraints()); - } + public static List getEdgeDescriptor( + GraphDescriptor desc, String graphName, SqlEdgeUsing sqlEdgeUsing) { + return getEdgeDescriptor( + desc, graphName, sqlEdgeUsing.getName().getSimple(), sqlEdgeUsing.getConstraints()); + } - private static List getEdgeDescriptor(GraphDescriptor desc, - String graphName, - String edgeName, - SqlNodeList constraints) { - List result = new ArrayList<>(); - Map> sourceType2TargetTypes = new HashedMap(); - for (Object obj : constraints) { - assert obj instanceof GQLEdgeConstraint; - GQLEdgeConstraint constraint = (GQLEdgeConstraint) obj; - for (String sourceType : constraint.getSourceVertexTypes()) { - sourceType2TargetTypes.computeIfAbsent(sourceType, t -> new ArrayList<>()); - for (String targetType : constraint.getTargetVertexTypes()) { - if (!sourceType2TargetTypes.get(sourceType).contains(targetType)) { - result.add(new EdgeDescriptor(desc.getIdName(graphName), edgeName, - sourceType, targetType)); - sourceType2TargetTypes.get(sourceType).add(targetType); - } - } - } + private static List getEdgeDescriptor( + GraphDescriptor desc, String graphName, String edgeName, SqlNodeList constraints) { + List result = new ArrayList<>(); + Map> sourceType2TargetTypes = new HashedMap(); + for (Object obj : constraints) { + assert obj instanceof GQLEdgeConstraint; + GQLEdgeConstraint constraint = (GQLEdgeConstraint) obj; + for (String sourceType : constraint.getSourceVertexTypes()) { + sourceType2TargetTypes.computeIfAbsent(sourceType, t -> new ArrayList<>()); + for (String targetType : constraint.getTargetVertexTypes()) { + if (!sourceType2TargetTypes.get(sourceType).contains(targetType)) { + result.add( + new EdgeDescriptor(desc.getIdName(graphName), edgeName, sourceType, targetType)); + sourceType2TargetTypes.get(sourceType).add(targetType); + } } - return result; + } } - + return result; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/GraphSchemaUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/GraphSchemaUtil.java index ea37484f6..cf0077fc7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/GraphSchemaUtil.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/GraphSchemaUtil.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.util; import java.util.Optional; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; @@ -27,65 +28,111 @@ public class GraphSchemaUtil { - public static RelDataType getCurrentGraphVertexIdType(GQLJavaTypeFactory typeFactory) { - if (typeFactory.getCurrentGraph() == null) { - throw new GeaFlowDSLException("Cannot get vertex id type without setting current graph"); - } else if (typeFactory.getCurrentGraph().getVertexTables().isEmpty()) { - throw new GeaFlowDSLException("No vertex table found in current graph {}", typeFactory.getCurrentGraph()); - } else { - return typeFactory.getCurrentGraph().getVertexTables().get(0).getRowType(typeFactory).getIdField().getType(); - } + public static RelDataType getCurrentGraphVertexIdType(GQLJavaTypeFactory typeFactory) { + if (typeFactory.getCurrentGraph() == null) { + throw new GeaFlowDSLException("Cannot get vertex id type without setting current graph"); + } else if (typeFactory.getCurrentGraph().getVertexTables().isEmpty()) { + throw new GeaFlowDSLException( + "No vertex table found in current graph {}", typeFactory.getCurrentGraph()); + } else { + return typeFactory + .getCurrentGraph() + .getVertexTables() + .get(0) + .getRowType(typeFactory) + .getIdField() + .getType(); } + } - public static RelDataType getCurrentGraphEdgeSrcIdType(GQLJavaTypeFactory typeFactory) { - if (typeFactory.getCurrentGraph() == null) { - throw new GeaFlowDSLException("Cannot get edge src id type without setting current graph"); - } else if (typeFactory.getCurrentGraph().getEdgeTables().isEmpty()) { - throw new GeaFlowDSLException("No edge table found in current graph {}", typeFactory.getCurrentGraph()); - } else { - return typeFactory.getCurrentGraph().getEdgeTables().get(0).getRowType(typeFactory).getSrcIdField().getType(); - } + public static RelDataType getCurrentGraphEdgeSrcIdType(GQLJavaTypeFactory typeFactory) { + if (typeFactory.getCurrentGraph() == null) { + throw new GeaFlowDSLException("Cannot get edge src id type without setting current graph"); + } else if (typeFactory.getCurrentGraph().getEdgeTables().isEmpty()) { + throw new GeaFlowDSLException( + "No edge table found in current graph {}", typeFactory.getCurrentGraph()); + } else { + return typeFactory + .getCurrentGraph() + .getEdgeTables() + .get(0) + .getRowType(typeFactory) + .getSrcIdField() + .getType(); } + } - public static RelDataType getCurrentGraphEdgeTargetIdType(GQLJavaTypeFactory typeFactory) { - if (typeFactory.getCurrentGraph() == null) { - throw new GeaFlowDSLException("Cannot get edge target id type without setting current graph"); - } else if (typeFactory.getCurrentGraph().getEdgeTables().isEmpty()) { - throw new GeaFlowDSLException("No edge table found in current graph {}", typeFactory.getCurrentGraph()); - } else { - return typeFactory.getCurrentGraph().getEdgeTables().get(0).getRowType(typeFactory).getTargetIdField().getType(); - } + public static RelDataType getCurrentGraphEdgeTargetIdType(GQLJavaTypeFactory typeFactory) { + if (typeFactory.getCurrentGraph() == null) { + throw new GeaFlowDSLException("Cannot get edge target id type without setting current graph"); + } else if (typeFactory.getCurrentGraph().getEdgeTables().isEmpty()) { + throw new GeaFlowDSLException( + "No edge table found in current graph {}", typeFactory.getCurrentGraph()); + } else { + return typeFactory + .getCurrentGraph() + .getEdgeTables() + .get(0) + .getRowType(typeFactory) + .getTargetIdField() + .getType(); } + } - public static RelDataType getCurrentGraphLabelType(GQLJavaTypeFactory typeFactory) { - if (typeFactory.getCurrentGraph() == null) { - throw new GeaFlowDSLException("Cannot get label type without setting current graph"); - } else if (!typeFactory.getCurrentGraph().getVertexTables().isEmpty()) { - return typeFactory.getCurrentGraph().getVertexTables().get(0).getRowType(typeFactory) - .getLabelField().getType(); - } else if (!typeFactory.getCurrentGraph().getEdgeTables().isEmpty()) { - return typeFactory.getCurrentGraph().getEdgeTables().get(0).getRowType(typeFactory) - .getLabelField().getType(); - } else { - throw new GeaFlowDSLException("No vertex or edge table found in current graph {}", typeFactory.getCurrentGraph()); - } + public static RelDataType getCurrentGraphLabelType(GQLJavaTypeFactory typeFactory) { + if (typeFactory.getCurrentGraph() == null) { + throw new GeaFlowDSLException("Cannot get label type without setting current graph"); + } else if (!typeFactory.getCurrentGraph().getVertexTables().isEmpty()) { + return typeFactory + .getCurrentGraph() + .getVertexTables() + .get(0) + .getRowType(typeFactory) + .getLabelField() + .getType(); + } else if (!typeFactory.getCurrentGraph().getEdgeTables().isEmpty()) { + return typeFactory + .getCurrentGraph() + .getEdgeTables() + .get(0) + .getRowType(typeFactory) + .getLabelField() + .getType(); + } else { + throw new GeaFlowDSLException( + "No vertex or edge table found in current graph {}", typeFactory.getCurrentGraph()); } + } - public static RelDataType getCurrentGraphDirectionType(GQLJavaTypeFactory typeFactory) { - return typeFactory.createSqlType(SqlTypeName.VARCHAR); - } + public static RelDataType getCurrentGraphDirectionType(GQLJavaTypeFactory typeFactory) { + return typeFactory.createSqlType(SqlTypeName.VARCHAR); + } - public static Optional getCurrentGraphEdgeTimestampType(GQLJavaTypeFactory typeFactory) { - if (typeFactory.getCurrentGraph() == null) { - throw new GeaFlowDSLException("Cannot get edge ts type without setting current graph"); - } else if (typeFactory.getCurrentGraph().getEdgeTables().isEmpty()) { - throw new GeaFlowDSLException("No edge table found in current graph {}", typeFactory.getCurrentGraph()); - } else if (typeFactory.getCurrentGraph().getEdgeTables().get(0).getRowType(typeFactory).getTimestampField().isPresent()) { - return Optional.of(typeFactory.getCurrentGraph().getEdgeTables().get(0).getRowType(typeFactory) - .getTimestampField().get().getType()); - } else { - return Optional.empty(); - } + public static Optional getCurrentGraphEdgeTimestampType( + GQLJavaTypeFactory typeFactory) { + if (typeFactory.getCurrentGraph() == null) { + throw new GeaFlowDSLException("Cannot get edge ts type without setting current graph"); + } else if (typeFactory.getCurrentGraph().getEdgeTables().isEmpty()) { + throw new GeaFlowDSLException( + "No edge table found in current graph {}", typeFactory.getCurrentGraph()); + } else if (typeFactory + .getCurrentGraph() + .getEdgeTables() + .get(0) + .getRowType(typeFactory) + .getTimestampField() + .isPresent()) { + return Optional.of( + typeFactory + .getCurrentGraph() + .getEdgeTables() + .get(0) + .getRowType(typeFactory) + .getTimestampField() + .get() + .getType()); + } else { + return Optional.empty(); } - + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/PathReferenceAnalyzer.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/PathReferenceAnalyzer.java index 06a0b64d8..b318c9f74 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/PathReferenceAnalyzer.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/util/PathReferenceAnalyzer.java @@ -19,8 +19,6 @@ package org.apache.geaflow.dsl.util; -import com.google.common.collect.Lists; -import com.google.common.collect.Sets; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; @@ -28,6 +26,7 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; + import org.apache.calcite.rel.BiRel; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.type.RelDataType; @@ -47,350 +46,360 @@ import org.apache.geaflow.dsl.rex.PathInputRef; import org.apache.geaflow.dsl.rex.RexLambdaCall; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + public class PathReferenceAnalyzer { - private final GQLContext gqlContext; + private final GQLContext gqlContext; - /** - * Mapping of the RelNode to referred path field names after the RelNode(including this node). - */ - private final Map> node2RefPathFields = new HashMap<>(); + /** Mapping of the RelNode to referred path field names after the RelNode(including this node). */ + private final Map> node2RefPathFields = new HashMap<>(); - /** - * Path field names that need to be globally excluded from trimming. - */ - private final Set globalPruningWhitelist = new HashSet<>(); + /** Path field names that need to be globally excluded from trimming. */ + private final Set globalPruningWhitelist = new HashSet<>(); - /** - * The next node collection for the RelNode. - */ - private final Map> subsequentNodes = new HashMap<>(); + /** The next node collection for the RelNode. */ + private final Map> subsequentNodes = new HashMap<>(); - public PathReferenceAnalyzer(GQLContext gqlContext) { - this.gqlContext = gqlContext; - } + public PathReferenceAnalyzer(GQLContext gqlContext) { + this.gqlContext = gqlContext; + } - public RelNode analyze(RelNode node) { - analyzePathRef(node, new HashSet<>()); - for (Set refPathFields : node2RefPathFields.values()) { - refPathFields.addAll(globalPruningWhitelist); - } - return pruneAndAdjustPathInputRef(node); + public RelNode analyze(RelNode node) { + analyzePathRef(node, new HashSet<>()); + for (Set refPathFields : node2RefPathFields.values()) { + refPathFields.addAll(globalPruningWhitelist); } - - private void analyzePathRef(RelNode node, Set subsequentNodeRefPathFields) { - Set refPathFields = new HashSet<>(); - if (node instanceof GraphMatch) { - GraphMatch match = (GraphMatch) node; - RelNode pathPattern = match.getPathPattern(); - analyzePathRef(pathPattern, subsequentNodeRefPathFields); - // use graph match as the subsequent node for path pattern. - subsequentNodes.put(pathPattern, Lists.newArrayList(node)); - - refPathFields.addAll(subsequentNodeRefPathFields); - } else if (node instanceof LoopUntilMatch) { - LoopUntilMatch loopUtil = (LoopUntilMatch) node; - PathReferenceCollector referenceCollector = new PathReferenceCollector(loopUtil.getUtilCondition()); - refPathFields.addAll(referenceCollector.getRefPathFields()); - refPathFields.addAll(subsequentNodeRefPathFields); - // analyze loop-body - analyzePathRef(loopUtil.getLoopBody(), refPathFields); - List subNodes = subsequentNodes.get(node); - subsequentNodes.put(loopUtil, subNodes); - node2RefPathFields.put(node, refPathFields); - } else if (node instanceof MatchExtend) { - // In order to avoid mistakenly dedup edges with different values, the nodes - // generated by MatchExtend cannot be pruned. - MatchExtend matchExtend = (MatchExtend) node; - PathReferenceCollector referenceCollector = new PathReferenceCollector(node); - refPathFields.addAll(referenceCollector.getRefPathFields()); - refPathFields.addAll(subsequentNodeRefPathFields); - globalPruningWhitelist.addAll(matchExtend.getExpressions().stream() - .map(PathModifyExpression::getPathFieldName).collect(Collectors.toList())); - } else if (node instanceof MatchJoin) { - // analyze referred path fields by this node. - PathReferenceCollector referenceCollector = new PathReferenceCollector(node); - Set joinRefPathFields = referenceCollector.getRefPathFields(); - //Find the new fields created by join, and add the related field names to reference - List joinCreatedFields = joinRefPathFields.stream().filter( - f -> !((MatchJoin) node).getLeft().getRowType().getFieldNames().contains(f) - && !((MatchJoin) node).getRight().getRowType().getFieldNames().contains(f) - ).collect(Collectors.toList()); - for (String createdField : joinCreatedFields) { - List relatedFields = joinRefPathFields.stream().filter( - f -> !f.equals(createdField) && createdField.indexOf(f) == 0).collect(Collectors.toList()); - if (relatedFields.size() > 0) { - String nameBase = relatedFields.get(0); - for (String related : relatedFields) { - nameBase = related.length() > nameBase.length() ? related : nameBase; - } - assert Integer.parseInt(createdField.substring(nameBase.length())) >= 0; - for (int j = 0; ; j++) { - String name = nameBase + j; - if (name.equals(createdField)) { - break; - } - if (!joinRefPathFields.contains(name)) { - joinRefPathFields.add(name); - } - } - } - + return pruneAndAdjustPathInputRef(node); + } + + private void analyzePathRef(RelNode node, Set subsequentNodeRefPathFields) { + Set refPathFields = new HashSet<>(); + if (node instanceof GraphMatch) { + GraphMatch match = (GraphMatch) node; + RelNode pathPattern = match.getPathPattern(); + analyzePathRef(pathPattern, subsequentNodeRefPathFields); + // use graph match as the subsequent node for path pattern. + subsequentNodes.put(pathPattern, Lists.newArrayList(node)); + + refPathFields.addAll(subsequentNodeRefPathFields); + } else if (node instanceof LoopUntilMatch) { + LoopUntilMatch loopUtil = (LoopUntilMatch) node; + PathReferenceCollector referenceCollector = + new PathReferenceCollector(loopUtil.getUtilCondition()); + refPathFields.addAll(referenceCollector.getRefPathFields()); + refPathFields.addAll(subsequentNodeRefPathFields); + // analyze loop-body + analyzePathRef(loopUtil.getLoopBody(), refPathFields); + List subNodes = subsequentNodes.get(node); + subsequentNodes.put(loopUtil, subNodes); + node2RefPathFields.put(node, refPathFields); + } else if (node instanceof MatchExtend) { + // In order to avoid mistakenly dedup edges with different values, the nodes + // generated by MatchExtend cannot be pruned. + MatchExtend matchExtend = (MatchExtend) node; + PathReferenceCollector referenceCollector = new PathReferenceCollector(node); + refPathFields.addAll(referenceCollector.getRefPathFields()); + refPathFields.addAll(subsequentNodeRefPathFields); + globalPruningWhitelist.addAll( + matchExtend.getExpressions().stream() + .map(PathModifyExpression::getPathFieldName) + .collect(Collectors.toList())); + } else if (node instanceof MatchJoin) { + // analyze referred path fields by this node. + PathReferenceCollector referenceCollector = new PathReferenceCollector(node); + Set joinRefPathFields = referenceCollector.getRefPathFields(); + // Find the new fields created by join, and add the related field names to reference + List joinCreatedFields = + joinRefPathFields.stream() + .filter( + f -> + !((MatchJoin) node).getLeft().getRowType().getFieldNames().contains(f) + && !((MatchJoin) node) + .getRight() + .getRowType() + .getFieldNames() + .contains(f)) + .collect(Collectors.toList()); + for (String createdField : joinCreatedFields) { + List relatedFields = + joinRefPathFields.stream() + .filter(f -> !f.equals(createdField) && createdField.indexOf(f) == 0) + .collect(Collectors.toList()); + if (relatedFields.size() > 0) { + String nameBase = relatedFields.get(0); + for (String related : relatedFields) { + nameBase = related.length() > nameBase.length() ? related : nameBase; + } + assert Integer.parseInt(createdField.substring(nameBase.length())) >= 0; + for (int j = 0; ; j++) { + String name = nameBase + j; + if (name.equals(createdField)) { + break; } - refPathFields.addAll(joinRefPathFields); - // add referred path fields by the subsequent node. - refPathFields.addAll(subsequentNodeRefPathFields); - } else { - // analyze referred path fields by this node. - PathReferenceCollector referenceCollector = new PathReferenceCollector(node); - refPathFields.addAll(referenceCollector.getRefPathFields()); - // add referred path fields by the subsequent node. - refPathFields.addAll(subsequentNodeRefPathFields); - } - - node2RefPathFields.computeIfAbsent(node, n -> new HashSet<>()).addAll(refPathFields); - - for (RelNode input : node.getInputs()) { - subsequentNodes.computeIfAbsent(input, k -> new ArrayList<>()).add(node); - Set inputSubsequent; - if (!(node instanceof IMatchNode) && input.getRowType().getSqlTypeName() != SqlTypeName.PATH) { - // If input's type is not a path, then it breaks the continuous match. - // It only can be another match, so clean the subsequentNodeRefPathFields set. - inputSubsequent = new HashSet<>(); - } else { - inputSubsequent = refPathFields; + if (!joinRefPathFields.contains(name)) { + joinRefPathFields.add(name); } - analyzePathRef(input, inputSubsequent); + } } + } + refPathFields.addAll(joinRefPathFields); + // add referred path fields by the subsequent node. + refPathFields.addAll(subsequentNodeRefPathFields); + } else { + // analyze referred path fields by this node. + PathReferenceCollector referenceCollector = new PathReferenceCollector(node); + refPathFields.addAll(referenceCollector.getRefPathFields()); + // add referred path fields by the subsequent node. + refPathFields.addAll(subsequentNodeRefPathFields); } - /** - * Prune the path schema and adjust the PathInputRef index for {@link RelNode}. - * - * @param node The node to be pruned. - * @return The pruned node. - */ - private RelNode pruneAndAdjustPathInputRef(final RelNode node) { - List rewriteInputs = new ArrayList<>(); - //step1. rewrite all the inputs. - for (RelNode input : node.getInputs()) { - rewriteInputs.add(pruneAndAdjustPathInputRef(input)); - } - RelNode rewriteNode = node; - - //step2. adjust the index of the PathInputRef after the inputs has pruned. - if (rewriteNode instanceof LoopUntilMatch) { // Adjust loop-util - LoopUntilMatch loopUtil = (LoopUntilMatch) rewriteNode; - adjustPathRefIndex(loopUtil.getUtilCondition(), getPathType(rewriteInputs.get(0))); - pruneAndAdjustPathInputRef(loopUtil.getLoopBody()); - } else if (rewriteNode instanceof BiRel) { // Adjust for join & correlate - rewriteNode = rewriteNode.copy(node.getTraitSet(), rewriteInputs); - PathRecordType pathType = getPathType(rewriteNode); - if (pathType != null) { - // rewrite the on condition using the latest join output type. - rewriteNode = adjustPathRefIndex(rewriteNode, pathType); - } - } else if (rewriteInputs.size() == 1 - && getPathType(rewriteInputs.get(0)) != null) { - RelNode rewriteInput = rewriteInputs.get(0); - PathRecordType inputPathType = getPathType(rewriteInput); - rewriteNode = adjustPathRefIndex(rewriteNode, inputPathType); - // replace input after adjust path ref index. - rewriteNode = rewriteNode.copy(node.getTraitSet(), rewriteInputs); - } else { - rewriteNode = rewriteNode.copy(node.getTraitSet(), rewriteInputs); - } - - //step3. prune path type for single match node. - Set subsequentRefFields = getSubsequentNodeRefPathFields(node); - if (node instanceof SingleMatchNode) { - rewriteNode = pruneMatchNode((SingleMatchNode) rewriteNode, subsequentRefFields); - } else if (node instanceof GraphMatch) { - // prune match node in graph match. - GraphMatch match = (GraphMatch) node; - IMatchNode rewritePathPattern = (IMatchNode) pruneAndAdjustPathInputRef(match.getPathPattern()); - rewriteNode = match.copy(match.getTraitSet(), rewriteInputs.get(0), rewritePathPattern, - rewritePathPattern.getPathSchema()); - } - return rewriteNode; + node2RefPathFields.computeIfAbsent(node, n -> new HashSet<>()).addAll(refPathFields); + + for (RelNode input : node.getInputs()) { + subsequentNodes.computeIfAbsent(input, k -> new ArrayList<>()).add(node); + Set inputSubsequent; + if (!(node instanceof IMatchNode) + && input.getRowType().getSqlTypeName() != SqlTypeName.PATH) { + // If input's type is not a path, then it breaks the continuous match. + // It only can be another match, so clean the subsequentNodeRefPathFields set. + inputSubsequent = new HashSet<>(); + } else { + inputSubsequent = refPathFields; + } + analyzePathRef(input, inputSubsequent); } - - private RelNode adjustPathRefIndex(RelNode node, PathRecordType inputPathType) { - return node.accept(new AdjustPathRefIndexVisitor(inputPathType)); + } + + /** + * Prune the path schema and adjust the PathInputRef index for {@link RelNode}. + * + * @param node The node to be pruned. + * @return The pruned node. + */ + private RelNode pruneAndAdjustPathInputRef(final RelNode node) { + List rewriteInputs = new ArrayList<>(); + // step1. rewrite all the inputs. + for (RelNode input : node.getInputs()) { + rewriteInputs.add(pruneAndAdjustPathInputRef(input)); + } + RelNode rewriteNode = node; + + // step2. adjust the index of the PathInputRef after the inputs has pruned. + if (rewriteNode instanceof LoopUntilMatch) { // Adjust loop-util + LoopUntilMatch loopUtil = (LoopUntilMatch) rewriteNode; + adjustPathRefIndex(loopUtil.getUtilCondition(), getPathType(rewriteInputs.get(0))); + pruneAndAdjustPathInputRef(loopUtil.getLoopBody()); + } else if (rewriteNode instanceof BiRel) { // Adjust for join & correlate + rewriteNode = rewriteNode.copy(node.getTraitSet(), rewriteInputs); + PathRecordType pathType = getPathType(rewriteNode); + if (pathType != null) { + // rewrite the on condition using the latest join output type. + rewriteNode = adjustPathRefIndex(rewriteNode, pathType); + } + } else if (rewriteInputs.size() == 1 && getPathType(rewriteInputs.get(0)) != null) { + RelNode rewriteInput = rewriteInputs.get(0); + PathRecordType inputPathType = getPathType(rewriteInput); + rewriteNode = adjustPathRefIndex(rewriteNode, inputPathType); + // replace input after adjust path ref index. + rewriteNode = rewriteNode.copy(node.getTraitSet(), rewriteInputs); + } else { + rewriteNode = rewriteNode.copy(node.getTraitSet(), rewriteInputs); } - private RexNode adjustPathRefIndex(RexNode node, PathRecordType inputPathType) { - assert inputPathType != null; - return node.accept(new AdjustPathRefIndexVisitor(inputPathType)); + // step3. prune path type for single match node. + Set subsequentRefFields = getSubsequentNodeRefPathFields(node); + if (node instanceof SingleMatchNode) { + rewriteNode = pruneMatchNode((SingleMatchNode) rewriteNode, subsequentRefFields); + } else if (node instanceof GraphMatch) { + // prune match node in graph match. + GraphMatch match = (GraphMatch) node; + IMatchNode rewritePathPattern = + (IMatchNode) pruneAndAdjustPathInputRef(match.getPathPattern()); + rewriteNode = + match.copy( + match.getTraitSet(), + rewriteInputs.get(0), + rewritePathPattern, + rewritePathPattern.getPathSchema()); } + return rewriteNode; + } - private class AdjustPathRefIndexVisitor extends RexShuttle { + private RelNode adjustPathRefIndex(RelNode node, PathRecordType inputPathType) { + return node.accept(new AdjustPathRefIndexVisitor(inputPathType)); + } - private final PathRecordType inputPathType; + private RexNode adjustPathRefIndex(RexNode node, PathRecordType inputPathType) { + assert inputPathType != null; + return node.accept(new AdjustPathRefIndexVisitor(inputPathType)); + } - public AdjustPathRefIndexVisitor(PathRecordType inputPathType) { - this.inputPathType = inputPathType; - } + private class AdjustPathRefIndexVisitor extends RexShuttle { - @Override - public RexNode visitInputRef(RexInputRef inputRef) { - if (inputRef instanceof PathInputRef) { - PathInputRef pathInputRef = (PathInputRef) inputRef; - RelDataTypeField field = inputPathType.getField(pathInputRef.getLabel(), - gqlContext.isCaseSensitive(), false); - assert field != null : "Field: " + pathInputRef.getLabel() - + " not found in the input"; - return pathInputRef.copy(field.getIndex()); - } - return inputRef; - } + private final PathRecordType inputPathType; - @Override - public RexNode visitCall(RexCall call) { - if (call instanceof RexLambdaCall) { - RexLambdaCall lambdaCall = (RexLambdaCall) call; - RexSubQuery subQuery = lambdaCall.getInput(); - RexNode valueNode = lambdaCall.getValue(); - - // prune sub query - PathReferenceCollector referenceCollector = new PathReferenceCollector(valueNode); - Set refPathFields = new HashSet<>(referenceCollector.getRefPathFields()); - - assert inputPathType.lastFieldName().isPresent(); - // The last field is the start vertex to request the sub query. - String startLabel = inputPathType.lastFieldName().get(); - refPathFields.add(startLabel); - analyzePathRef(subQuery.rel, refPathFields); - // In order to getSubsequentNodeRefPathFields for subQuery.rel when pruning it, - // we attach subQuery.rel to itself as it has no real next node. - subsequentNodes.put(subQuery.rel, Lists.newArrayList(subQuery.rel)); - // prune sub query - RelNode newSubRel = pruneAndAdjustPathInputRef(subQuery.rel); - RexSubQuery newSubQuery = subQuery.clone(newSubRel); - // adjust path index for value node - RexNode newValue = adjustPathRefIndex(valueNode, getPathType(newSubRel)); - return lambdaCall.clone(lambdaCall.type, Lists.newArrayList(newSubQuery, newValue)); - } - return super.visitCall(call); - } + public AdjustPathRefIndexVisitor(PathRecordType inputPathType) { + this.inputPathType = inputPathType; } - /** - * Prune path type for match node. - * - * @param node The match node. - * @param subsequentRefFields The reference labels by the subsequent nodes. - */ - private SingleMatchNode pruneMatchNode(SingleMatchNode node, Set subsequentRefFields) { - PathRecordType outputPathType = prunePathType(subsequentRefFields, node.getPathSchema()); - return (SingleMatchNode) node.copy(node.getInputs(), outputPathType); + @Override + public RexNode visitInputRef(RexInputRef inputRef) { + if (inputRef instanceof PathInputRef) { + PathInputRef pathInputRef = (PathInputRef) inputRef; + RelDataTypeField field = + inputPathType.getField(pathInputRef.getLabel(), gqlContext.isCaseSensitive(), false); + assert field != null : "Field: " + pathInputRef.getLabel() + " not found in the input"; + return pathInputRef.copy(field.getIndex()); + } + return inputRef; } - private PathRecordType prunePathType(Set refPathFields, RelDataType pathRecordType) { - // Pruned the path type by the reference. - List prunedFields = new ArrayList<>(); - int index = 0; - for (RelDataTypeField field : pathRecordType.getFieldList()) { - if (refPathFields.contains(field.getName())) { - prunedFields.add(new RelDataTypeFieldImpl(field.getName(), index, field.getType())); - index++; - } - } - return new PathRecordType(prunedFields); + @Override + public RexNode visitCall(RexCall call) { + if (call instanceof RexLambdaCall) { + RexLambdaCall lambdaCall = (RexLambdaCall) call; + RexSubQuery subQuery = lambdaCall.getInput(); + RexNode valueNode = lambdaCall.getValue(); + + // prune sub query + PathReferenceCollector referenceCollector = new PathReferenceCollector(valueNode); + Set refPathFields = new HashSet<>(referenceCollector.getRefPathFields()); + + assert inputPathType.lastFieldName().isPresent(); + // The last field is the start vertex to request the sub query. + String startLabel = inputPathType.lastFieldName().get(); + refPathFields.add(startLabel); + analyzePathRef(subQuery.rel, refPathFields); + // In order to getSubsequentNodeRefPathFields for subQuery.rel when pruning it, + // we attach subQuery.rel to itself as it has no real next node. + subsequentNodes.put(subQuery.rel, Lists.newArrayList(subQuery.rel)); + // prune sub query + RelNode newSubRel = pruneAndAdjustPathInputRef(subQuery.rel); + RexSubQuery newSubQuery = subQuery.clone(newSubRel); + // adjust path index for value node + RexNode newValue = adjustPathRefIndex(valueNode, getPathType(newSubRel)); + return lambdaCall.clone(lambdaCall.type, Lists.newArrayList(newSubQuery, newValue)); + } + return super.visitCall(call); } - - private Set getRefPathFields(RelNode node) { - return node2RefPathFields.get(node); + } + + /** + * Prune path type for match node. + * + * @param node The match node. + * @param subsequentRefFields The reference labels by the subsequent nodes. + */ + private SingleMatchNode pruneMatchNode(SingleMatchNode node, Set subsequentRefFields) { + PathRecordType outputPathType = prunePathType(subsequentRefFields, node.getPathSchema()); + return (SingleMatchNode) node.copy(node.getInputs(), outputPathType); + } + + private PathRecordType prunePathType(Set refPathFields, RelDataType pathRecordType) { + // Pruned the path type by the reference. + List prunedFields = new ArrayList<>(); + int index = 0; + for (RelDataTypeField field : pathRecordType.getFieldList()) { + if (refPathFields.contains(field.getName())) { + prunedFields.add(new RelDataTypeFieldImpl(field.getName(), index, field.getType())); + index++; + } } + return new PathRecordType(prunedFields); + } - private Set getSubsequentNodeRefPathFields(RelNode node) { - List subNodes = subsequentNodes.get(node); - if (subNodes != null && subNodes.size() > 0) { - return subNodes.stream().map(this::getRefPathFields).reduce(Sets::union).get(); - } - return new HashSet<>(); + private Set getRefPathFields(RelNode node) { + return node2RefPathFields.get(node); + } + + private Set getSubsequentNodeRefPathFields(RelNode node) { + List subNodes = subsequentNodes.get(node); + if (subNodes != null && subNodes.size() > 0) { + return subNodes.stream().map(this::getRefPathFields).reduce(Sets::union).get(); } + return new HashSet<>(); + } - private static PathRecordType getPathType(RelNode node) { - if (node instanceof IMatchNode) { - return ((IMatchNode) node).getPathSchema(); - } - if (node.getRowType() instanceof PathRecordType) { - return (PathRecordType) node.getRowType(); - } - return null; + private static PathRecordType getPathType(RelNode node) { + if (node instanceof IMatchNode) { + return ((IMatchNode) node).getPathSchema(); } + if (node.getRowType() instanceof PathRecordType) { + return (PathRecordType) node.getRowType(); + } + return null; + } - private static class PathReferenceCollector extends RexShuttle { + private static class PathReferenceCollector extends RexShuttle { - /** - * The RelNode to collect referred path fields. - */ - private RelNode node; + /** The RelNode to collect referred path fields. */ + private RelNode node; - private RexNode rexNode; + private RexNode rexNode; - private final Set refPathFields = new HashSet<>(); + private final Set refPathFields = new HashSet<>(); - private boolean hasAnalyze = false; + private boolean hasAnalyze = false; - public PathReferenceCollector(RelNode node) { - this.node = node; - } + public PathReferenceCollector(RelNode node) { + this.node = node; + } - public PathReferenceCollector(RexNode rexNode) { - this.rexNode = rexNode; - } + public PathReferenceCollector(RexNode rexNode) { + this.rexNode = rexNode; + } - @Override - public RexNode visitInputRef(RexInputRef inputRef) { - if (inputRef instanceof PathInputRef) { - refPathFields.add(((PathInputRef) inputRef).getLabel()); - } - return inputRef; - } + @Override + public RexNode visitInputRef(RexInputRef inputRef) { + if (inputRef instanceof PathInputRef) { + refPathFields.add(((PathInputRef) inputRef).getLabel()); + } + return inputRef; + } - @Override - public RexNode visitCall(RexCall call) { - if (call instanceof RexLambdaCall) { - // analyze path reference in sub query. - assert node != null : "node should not be null when analyze sub query."; - RelDataType inputPathType; - if (node instanceof SingleMatchNode) { - inputPathType = ((SingleMatchNode) node).getInput().getRowType(); - } else if (node instanceof MatchJoin) { - inputPathType = node.getRowType(); - } else { - throw new IllegalArgumentException("Illegal node: " + node + " with sub-query"); - } - assert inputPathType != null; - - int parentPathSize = inputPathType.getFieldCount(); - RexLambdaCall lambdaCall = (RexLambdaCall) call; - RexSubQuery subQuery = lambdaCall.getInput(); - RelDataType subQueryPathType = subQuery.rel.getRowType(); - // The first node of the sub query is the start vertex, it cannot be pruned. - refPathFields.add(subQueryPathType.getFieldList().get(parentPathSize - 1).getName()); - - PathReferenceCollector subCollector = new PathReferenceCollector(subQuery.rel); - Set subRefPathFields = subCollector.getRefPathFields(); - inputPathType.getFieldNames().stream() - .filter(subRefPathFields::contains) - .forEach(refPathFields::add); - } - return super.visitCall(call); + @Override + public RexNode visitCall(RexCall call) { + if (call instanceof RexLambdaCall) { + // analyze path reference in sub query. + assert node != null : "node should not be null when analyze sub query."; + RelDataType inputPathType; + if (node instanceof SingleMatchNode) { + inputPathType = ((SingleMatchNode) node).getInput().getRowType(); + } else if (node instanceof MatchJoin) { + inputPathType = node.getRowType(); + } else { + throw new IllegalArgumentException("Illegal node: " + node + " with sub-query"); } + assert inputPathType != null; - public Set getRefPathFields() { - if (!hasAnalyze) { - if (node != null) { - node.accept(this); - } else if (rexNode != null) { - rexNode.accept(this); - } - hasAnalyze = true; - } - return refPathFields; + int parentPathSize = inputPathType.getFieldCount(); + RexLambdaCall lambdaCall = (RexLambdaCall) call; + RexSubQuery subQuery = lambdaCall.getInput(); + RelDataType subQueryPathType = subQuery.rel.getRowType(); + // The first node of the sub query is the start vertex, it cannot be pruned. + refPathFields.add(subQueryPathType.getFieldList().get(parentPathSize - 1).getName()); + + PathReferenceCollector subCollector = new PathReferenceCollector(subQuery.rel); + Set subRefPathFields = subCollector.getRefPathFields(); + inputPathType.getFieldNames().stream() + .filter(subRefPathFields::contains) + .forEach(refPathFields::add); + } + return super.visitCall(call); + } + + public Set getRefPathFields() { + if (!hasAnalyze) { + if (node != null) { + node.accept(this); + } else if (rexNode != null) { + rexNode.accept(this); } + hasAnalyze = true; + } + return refPathFields; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/GQLValidatorImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/GQLValidatorImpl.java index a527a3b3d..0cd5df61c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/GQLValidatorImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/GQLValidatorImpl.java @@ -25,6 +25,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeField; @@ -76,708 +77,734 @@ public class GQLValidatorImpl extends SqlValidatorImpl { - private final GQLContext gContext; - private final GQLJavaTypeFactory typeFactory; - private static final String ANONYMOUS_COLUMN_PREFIX = "col_"; - private static final String RECURRING_COLUMN_SUFFIX = "_rcr"; - private final Map matchNodeTypes = new HashMap<>(); - - private final Map let2ModifyGraphType = new HashMap<>(); - - private final Map renamedMatchNodes = new HashMap<>(); - - private QueryNodeContext currentQueryNodeContext; - - public GQLValidatorImpl(GQLContext gContext, SqlOperatorTable opTab, - SqlValidatorCatalogReader catalogReader, RelDataTypeFactory typeFactory, - SqlConformance conformance) { - super(opTab, catalogReader, typeFactory, conformance); - this.setCallRewrite(false); - this.gContext = gContext; - this.typeFactory = (GQLJavaTypeFactory) typeFactory; - } - - public SqlNode validate(SqlNode sqlNode, QueryNodeContext queryNodeContext) { - this.currentQueryNodeContext = queryNodeContext; - return super.validate(sqlNode); - } - - @Override - public RelDataType getLogicalSourceRowType(RelDataType sourceRawType, - SqlInsert insert) { - return typeFactory.toSql(sourceRawType); - } - - @Override - protected void registerQuery( - SqlValidatorScope parentScope, - SqlValidatorScope usingScope, - SqlNode node, - SqlNode enclosingNode, - String alias, - boolean forceNullable) { - if (node.getKind() == SqlKind.INSERT) { - SqlInsert insertCall = (SqlInsert) node; - GQLInsertNamespace insertNs = - new GQLInsertNamespace( - this, - insertCall, - parentScope); - registerNamespace(usingScope, null, insertNs, forceNullable); - registerQuery( - parentScope, - usingScope, - insertCall.getSource(), - enclosingNode, - null, - false); - } else { - super.registerQuery(parentScope, usingScope, node, enclosingNode, alias, forceNullable); + private final GQLContext gContext; + private final GQLJavaTypeFactory typeFactory; + private static final String ANONYMOUS_COLUMN_PREFIX = "col_"; + private static final String RECURRING_COLUMN_SUFFIX = "_rcr"; + private final Map matchNodeTypes = new HashMap<>(); + + private final Map let2ModifyGraphType = new HashMap<>(); + + private final Map renamedMatchNodes = new HashMap<>(); + + private QueryNodeContext currentQueryNodeContext; + + public GQLValidatorImpl( + GQLContext gContext, + SqlOperatorTable opTab, + SqlValidatorCatalogReader catalogReader, + RelDataTypeFactory typeFactory, + SqlConformance conformance) { + super(opTab, catalogReader, typeFactory, conformance); + this.setCallRewrite(false); + this.gContext = gContext; + this.typeFactory = (GQLJavaTypeFactory) typeFactory; + } + + public SqlNode validate(SqlNode sqlNode, QueryNodeContext queryNodeContext) { + this.currentQueryNodeContext = queryNodeContext; + return super.validate(sqlNode); + } + + @Override + public RelDataType getLogicalSourceRowType(RelDataType sourceRawType, SqlInsert insert) { + return typeFactory.toSql(sourceRawType); + } + + @Override + protected void registerQuery( + SqlValidatorScope parentScope, + SqlValidatorScope usingScope, + SqlNode node, + SqlNode enclosingNode, + String alias, + boolean forceNullable) { + if (node.getKind() == SqlKind.INSERT) { + SqlInsert insertCall = (SqlInsert) node; + GQLInsertNamespace insertNs = new GQLInsertNamespace(this, insertCall, parentScope); + registerNamespace(usingScope, null, insertNs, forceNullable); + registerQuery(parentScope, usingScope, insertCall.getSource(), enclosingNode, null, false); + } else { + super.registerQuery(parentScope, usingScope, node, enclosingNode, alias, forceNullable); + } + } + + @Override + protected void registerNamespace( + SqlValidatorScope usingScope, String alias, SqlValidatorNamespace ns, boolean forceNullable) { + SqlValidatorNamespace newNs = ns; + // auto complete the instance name + if (ns instanceof IdentifierNamespace) { + IdentifierNamespace idNs = (IdentifierNamespace) ns; + newNs = new IdentifierCompleteNamespace(idNs); + } + super.registerNamespace(usingScope, alias, newNs, forceNullable); + } + + @Override + protected void registerOtherKindQuery( + SqlValidatorScope parentScope, + SqlValidatorScope usingScope, + SqlNode node, + SqlNode enclosingNode, + String alias, + boolean forceNullable, + boolean checkUpdate) { + switch (node.getKind()) { + case GQL_RETURN: + SqlReturnStatement returnStatement = (SqlReturnStatement) node; + GQLReturnScope returnScope = new GQLReturnScope(parentScope, returnStatement); + GQLReturnNamespace returnNs = new GQLReturnNamespace(this, enclosingNode, returnStatement); + registerNamespace(usingScope, alias, returnNs, forceNullable); + if (returnStatement.getGroupBy() != null + || getAggregate(returnStatement.getReturnList()) != null) { + returnScope.setAggMode(); } - } - - @Override - protected void registerNamespace(SqlValidatorScope usingScope, String alias, - SqlValidatorNamespace ns, boolean forceNullable) { - SqlValidatorNamespace newNs = ns; - // auto complete the instance name - if (ns instanceof IdentifierNamespace) { - IdentifierNamespace idNs = (IdentifierNamespace) ns; - newNs = new IdentifierCompleteNamespace(idNs); + // register from + String matchPatternNsAlias = deriveAlias(returnStatement.getFrom()); + registerQuery( + parentScope, + returnScope, + returnStatement.getFrom(), + returnStatement.getFrom(), + matchPatternNsAlias, + forceNullable); + scopes.put(returnStatement, returnScope); + + String returnNsAlias = deriveAlias(returnStatement); + if (returnStatement.getGroupBy() != null) { + GQLReturnGroupByScope groupByScope = + new GQLReturnGroupByScope(returnScope, returnStatement, returnStatement.getGroupBy()); + registerNamespace(groupByScope, returnNsAlias, returnNs, forceNullable); + scopes.put(returnStatement.getGroupBy(), groupByScope); } - super.registerNamespace(usingScope, alias, newNs, forceNullable); - } - - @Override - protected void registerOtherKindQuery(SqlValidatorScope parentScope, - SqlValidatorScope usingScope, SqlNode node, - SqlNode enclosingNode, String alias, - boolean forceNullable, boolean checkUpdate) { - switch (node.getKind()) { - case GQL_RETURN: - SqlReturnStatement returnStatement = (SqlReturnStatement) node; - GQLReturnScope returnScope = new GQLReturnScope(parentScope, returnStatement); - GQLReturnNamespace returnNs = new GQLReturnNamespace(this, enclosingNode, - returnStatement); - registerNamespace(usingScope, alias, returnNs, forceNullable); - if (returnStatement.getGroupBy() != null - || getAggregate(returnStatement.getReturnList()) != null) { - returnScope.setAggMode(); - } - // register from - String matchPatternNsAlias = deriveAlias(returnStatement.getFrom()); - registerQuery(parentScope, returnScope, returnStatement.getFrom(), - returnStatement.getFrom(), matchPatternNsAlias, forceNullable); - scopes.put(returnStatement, returnScope); - - String returnNsAlias = deriveAlias(returnStatement); - if (returnStatement.getGroupBy() != null) { - GQLReturnGroupByScope groupByScope = new GQLReturnGroupByScope(returnScope, - returnStatement, returnStatement.getGroupBy()); - registerNamespace(groupByScope, returnNsAlias, returnNs, forceNullable); - scopes.put(returnStatement.getGroupBy(), groupByScope); - } - if (returnStatement.getOrderBy() != null) { - GQLReturnOrderByScope orderByScope = new GQLReturnOrderByScope(returnScope, - returnStatement.getOrderBy()); - registerNamespace(orderByScope, returnNsAlias, returnNs, forceNullable); - scopes.put(returnStatement.getOrderBy(), orderByScope); - } - break; - case GQL_FILTER: - SqlFilterStatement filterStatement = (SqlFilterStatement) node; - GQLFilterNamespace filterNs = new GQLFilterNamespace(this, enclosingNode, - filterStatement); - registerNamespace(usingScope, alias, filterNs, forceNullable); - - GQLScope filterScope = new GQLScope(parentScope, filterStatement); - registerQuery(parentScope, filterScope, filterStatement.getFrom(), - filterStatement.getFrom(), deriveAlias(filterStatement.getFrom()), forceNullable); - scopes.put(filterStatement, filterScope); - break; - case GQL_MATCH_PATTERN: - //register MatchPattern - SqlMatchPattern matchPattern = (SqlMatchPattern) node; - GQLScope matchPatternScope = new GQLScope(parentScope, matchPattern); - GQLMatchPatternNamespace matchNamespace = new GQLMatchPatternNamespace(this, - matchPattern); - registerNamespace(usingScope, alias, matchNamespace, forceNullable); - // performUnconditionalRewrites will set current graph node to the - // matchPattern#from, so it cannot be null. - assert matchPattern.getFrom() != null; - - registerQuery(parentScope, matchPatternScope, matchPattern.getFrom(), matchPattern, - deriveAlias(matchPattern.getFrom()), forceNullable); - scopes.put(matchPattern, matchPatternScope); - - SqlNodeList pathPatterns = matchPattern.getPathPatterns(); - SqlValidatorNamespace fromNs = namespaces.get(matchPattern.getFrom()); - for (SqlNode sqlNode : pathPatterns) { - registerPathPattern(sqlNode, parentScope, fromNs, alias, forceNullable); - if (sqlNode instanceof SqlUnionPathPattern) { - SqlUnionPathPattern unionPathPattern = (SqlUnionPathPattern) sqlNode; - scopes.put(unionPathPattern, matchPatternScope); - } - } - if (matchPattern.getWhere() != null) { - SqlNode where = matchPattern.getWhere(); - GQLScope whereScope = new GQLScope(parentScope, where); - registerNamespace(whereScope, alias, matchNamespace, forceNullable); - scopes.put(where, whereScope); - registerGqlSubQuery(whereScope, alias, matchNamespace, where); - } - if (matchPattern.getOrderBy() != null) { - GQLReturnOrderByScope orderByScope = new GQLReturnOrderByScope(matchPatternScope, - matchPattern.getOrderBy()); - registerNamespace(orderByScope, alias, matchNamespace, forceNullable); - scopes.put(matchPattern.getOrderBy(), orderByScope); - } - break; - case IDENTIFIER: - SqlIdentifier identifier = (SqlIdentifier) node; - IdentifierNamespace ns = new IdentifierNamespace(this, identifier, null, identifier, - parentScope); - registerNamespace(usingScope, gContext.getCurrentGraph(), ns, forceNullable); - break; - case GQL_LET: - SqlLetStatement letStatement = (SqlLetStatement) node; - GQLLetNamespace letNamespace = new GQLLetNamespace(this, letStatement); - registerNamespace(usingScope, alias, letNamespace, forceNullable); - - GQLScope letScope = new GQLScope(parentScope, node); - registerQuery(parentScope, letScope, letStatement.getFrom(), letStatement.getFrom(), - deriveAlias(letStatement.getFrom()), forceNullable); - // register sub query in let expression. - SqlValidatorNamespace letFromNs = namespaces.get(letStatement.getFrom()); - registerGqlSubQuery(letScope, alias, letFromNs, letStatement.getExpression()); - scopes.put(letStatement, letScope); - break; - case GQL_ALGORITHM: - SqlGraphAlgorithmCall algorithmCall = (SqlGraphAlgorithmCall) node; - GQLAlgorithmNamespace algorithmNamespace = new GQLAlgorithmNamespace(this, - algorithmCall); - registerNamespace(usingScope, alias, algorithmNamespace, forceNullable); - GQLScope algorithmScope = new GQLScope(parentScope, node); - scopes.put(algorithmCall, algorithmScope); - break; - default: - super.registerOtherKindQuery(parentScope, usingScope, node, enclosingNode, alias, - forceNullable, checkUpdate); + if (returnStatement.getOrderBy() != null) { + GQLReturnOrderByScope orderByScope = + new GQLReturnOrderByScope(returnScope, returnStatement.getOrderBy()); + registerNamespace(orderByScope, returnNsAlias, returnNs, forceNullable); + scopes.put(returnStatement.getOrderBy(), orderByScope); } - } - - private SqlValidatorNamespace registerPathPattern(SqlNode sqlNode, - SqlValidatorScope parentScope, - SqlValidatorNamespace fromNs, String alias, - boolean forceNullable) { - GQLPathPatternScope pathPatternScope = new GQLPathPatternScope(parentScope, (SqlCall) sqlNode); - - if (sqlNode instanceof SqlUnionPathPattern) { + break; + case GQL_FILTER: + SqlFilterStatement filterStatement = (SqlFilterStatement) node; + GQLFilterNamespace filterNs = new GQLFilterNamespace(this, enclosingNode, filterStatement); + registerNamespace(usingScope, alias, filterNs, forceNullable); + + GQLScope filterScope = new GQLScope(parentScope, filterStatement); + registerQuery( + parentScope, + filterScope, + filterStatement.getFrom(), + filterStatement.getFrom(), + deriveAlias(filterStatement.getFrom()), + forceNullable); + scopes.put(filterStatement, filterScope); + break; + case GQL_MATCH_PATTERN: + // register MatchPattern + SqlMatchPattern matchPattern = (SqlMatchPattern) node; + GQLScope matchPatternScope = new GQLScope(parentScope, matchPattern); + GQLMatchPatternNamespace matchNamespace = new GQLMatchPatternNamespace(this, matchPattern); + registerNamespace(usingScope, alias, matchNamespace, forceNullable); + // performUnconditionalRewrites will set current graph node to the + // matchPattern#from, so it cannot be null. + assert matchPattern.getFrom() != null; + + registerQuery( + parentScope, + matchPatternScope, + matchPattern.getFrom(), + matchPattern, + deriveAlias(matchPattern.getFrom()), + forceNullable); + scopes.put(matchPattern, matchPatternScope); + + SqlNodeList pathPatterns = matchPattern.getPathPatterns(); + SqlValidatorNamespace fromNs = namespaces.get(matchPattern.getFrom()); + for (SqlNode sqlNode : pathPatterns) { + registerPathPattern(sqlNode, parentScope, fromNs, alias, forceNullable); + if (sqlNode instanceof SqlUnionPathPattern) { SqlUnionPathPattern unionPathPattern = (SqlUnionPathPattern) sqlNode; - GQLUnionPathPatternNamespace pathNs = - new GQLUnionPathPatternNamespace(this, unionPathPattern); - registerNamespace(null, alias, pathNs, forceNullable); - registerPathPattern(unionPathPattern.getLeft(), parentScope, - fromNs, alias, forceNullable); - registerPathPattern(unionPathPattern.getRight(), parentScope, - fromNs, alias, forceNullable); - - scopes.put(unionPathPattern, pathPatternScope); - return pathNs; + scopes.put(unionPathPattern, matchPatternScope); + } } - SqlPathPattern pathPattern = (SqlPathPattern) sqlNode; - GQLPathPatternNamespace pathNs = new GQLPathPatternNamespace(this, pathPattern); - registerNamespace(null, null, pathNs, forceNullable); - - String pathPatternAlias = alias == null ? deriveAlias(pathPattern) : alias; - pathPatternScope.addChild(fromNs, pathPatternAlias, forceNullable); - scopes.put(pathPattern, pathPatternScope); - //register MatchNode - for (SqlNode matchNode : pathPattern.getPathNodes()) { - SqlMatchNode sqlMatchNode = (SqlMatchNode) matchNode; - GQLScope nodeScope = new GQLScope(pathPatternScope, sqlMatchNode); - GQLMatchNodeNamespace nodeNs = new GQLMatchNodeNamespace(this, sqlMatchNode); - registerNamespace(nodeScope, deriveAlias(matchNode), nodeNs, forceNullable); - scopes.put(matchNode, nodeScope); - if (sqlMatchNode.getWhere() != null) { - SqlNode nodeWhere = sqlMatchNode.getWhere(); - //Where condition can only access NodeScope, not MatchPatternScope - GQLScope nodeWhereScope = new GQLScope(parentScope, nodeWhere); - - GQLMatchNodeWhereNamespace nodeWhereNs = - new GQLMatchNodeWhereNamespace(this, matchNode, nodeNs); - nodeWhereScope.addChild(nodeWhereNs, sqlMatchNode.getName(), forceNullable); - scopes.put(nodeWhere, nodeWhereScope); - } + if (matchPattern.getWhere() != null) { + SqlNode where = matchPattern.getWhere(); + GQLScope whereScope = new GQLScope(parentScope, where); + registerNamespace(whereScope, alias, matchNamespace, forceNullable); + scopes.put(where, whereScope); + registerGqlSubQuery(whereScope, alias, matchNamespace, where); } - return pathNs; - } - - @Override - protected SqlNode registerOtherFrom(SqlValidatorScope parentScope, - SqlValidatorScope usingScope, - boolean register, - final SqlNode node, - SqlNode enclosingNode, - String alias, - SqlNodeList extendList, - boolean forceNullable, - final boolean lateral) { - switch (node.getKind()) { - case GQL_RETURN: - case GQL_FILTER: - case GQL_MATCH_PATTERN: - case GQL_LET: - if (alias == null) { - alias = deriveAlias(node); - } - registerQuery( - parentScope, - register ? usingScope : null, - node, - enclosingNode, - alias, - forceNullable); - return node; - default: - return super.registerOtherFrom(parentScope, usingScope, register, node, - enclosingNode, alias, extendList, forceNullable, lateral); + if (matchPattern.getOrderBy() != null) { + GQLReturnOrderByScope orderByScope = + new GQLReturnOrderByScope(matchPatternScope, matchPattern.getOrderBy()); + registerNamespace(orderByScope, alias, matchNamespace, forceNullable); + scopes.put(matchPattern.getOrderBy(), orderByScope); } - } - - @Override - protected SqlValidatorScope getWithBodyScope(SqlValidatorScope parentScope, SqlWith with) { - if (GQLNodeUtil.containMatch(with)) { - GQLScope withBodyScope = new GQLWithBodyScope(parentScope, with.withList); - if (with.withList.size() != 1) { - throw new GeaFlowDSLException(with.getParserPosition().toString(), "Only support one with item"); - } - for (SqlNode withItem : with.withList) { - SqlValidatorNamespace withItemNs = getNamespace(withItem); - String withName = ((SqlWithItem) withItem).name.getSimple(); - withBodyScope.addChild(withItemNs, withName, false); - } - scopes.put(with.withList, withBodyScope); - return withBodyScope; + break; + case IDENTIFIER: + SqlIdentifier identifier = (SqlIdentifier) node; + IdentifierNamespace ns = + new IdentifierNamespace(this, identifier, null, identifier, parentScope); + registerNamespace(usingScope, gContext.getCurrentGraph(), ns, forceNullable); + break; + case GQL_LET: + SqlLetStatement letStatement = (SqlLetStatement) node; + GQLLetNamespace letNamespace = new GQLLetNamespace(this, letStatement); + registerNamespace(usingScope, alias, letNamespace, forceNullable); + + GQLScope letScope = new GQLScope(parentScope, node); + registerQuery( + parentScope, + letScope, + letStatement.getFrom(), + letStatement.getFrom(), + deriveAlias(letStatement.getFrom()), + forceNullable); + // register sub query in let expression. + SqlValidatorNamespace letFromNs = namespaces.get(letStatement.getFrom()); + registerGqlSubQuery(letScope, alias, letFromNs, letStatement.getExpression()); + scopes.put(letStatement, letScope); + break; + case GQL_ALGORITHM: + SqlGraphAlgorithmCall algorithmCall = (SqlGraphAlgorithmCall) node; + GQLAlgorithmNamespace algorithmNamespace = new GQLAlgorithmNamespace(this, algorithmCall); + registerNamespace(usingScope, alias, algorithmNamespace, forceNullable); + GQLScope algorithmScope = new GQLScope(parentScope, node); + scopes.put(algorithmCall, algorithmScope); + break; + default: + super.registerOtherKindQuery( + parentScope, usingScope, node, enclosingNode, alias, forceNullable, checkUpdate); + } + } + + private SqlValidatorNamespace registerPathPattern( + SqlNode sqlNode, + SqlValidatorScope parentScope, + SqlValidatorNamespace fromNs, + String alias, + boolean forceNullable) { + GQLPathPatternScope pathPatternScope = new GQLPathPatternScope(parentScope, (SqlCall) sqlNode); + + if (sqlNode instanceof SqlUnionPathPattern) { + SqlUnionPathPattern unionPathPattern = (SqlUnionPathPattern) sqlNode; + GQLUnionPathPatternNamespace pathNs = + new GQLUnionPathPatternNamespace(this, unionPathPattern); + registerNamespace(null, alias, pathNs, forceNullable); + registerPathPattern(unionPathPattern.getLeft(), parentScope, fromNs, alias, forceNullable); + registerPathPattern(unionPathPattern.getRight(), parentScope, fromNs, alias, forceNullable); + + scopes.put(unionPathPattern, pathPatternScope); + return pathNs; + } + SqlPathPattern pathPattern = (SqlPathPattern) sqlNode; + GQLPathPatternNamespace pathNs = new GQLPathPatternNamespace(this, pathPattern); + registerNamespace(null, null, pathNs, forceNullable); + + String pathPatternAlias = alias == null ? deriveAlias(pathPattern) : alias; + pathPatternScope.addChild(fromNs, pathPatternAlias, forceNullable); + scopes.put(pathPattern, pathPatternScope); + // register MatchNode + for (SqlNode matchNode : pathPattern.getPathNodes()) { + SqlMatchNode sqlMatchNode = (SqlMatchNode) matchNode; + GQLScope nodeScope = new GQLScope(pathPatternScope, sqlMatchNode); + GQLMatchNodeNamespace nodeNs = new GQLMatchNodeNamespace(this, sqlMatchNode); + registerNamespace(nodeScope, deriveAlias(matchNode), nodeNs, forceNullable); + scopes.put(matchNode, nodeScope); + if (sqlMatchNode.getWhere() != null) { + SqlNode nodeWhere = sqlMatchNode.getWhere(); + // Where condition can only access NodeScope, not MatchPatternScope + GQLScope nodeWhereScope = new GQLScope(parentScope, nodeWhere); + + GQLMatchNodeWhereNamespace nodeWhereNs = + new GQLMatchNodeWhereNamespace(this, matchNode, nodeNs); + nodeWhereScope.addChild(nodeWhereNs, sqlMatchNode.getName(), forceNullable); + scopes.put(nodeWhere, nodeWhereScope); + } + } + return pathNs; + } + + @Override + protected SqlNode registerOtherFrom( + SqlValidatorScope parentScope, + SqlValidatorScope usingScope, + boolean register, + final SqlNode node, + SqlNode enclosingNode, + String alias, + SqlNodeList extendList, + boolean forceNullable, + final boolean lateral) { + switch (node.getKind()) { + case GQL_RETURN: + case GQL_FILTER: + case GQL_MATCH_PATTERN: + case GQL_LET: + if (alias == null) { + alias = deriveAlias(node); } - return super.getWithBodyScope(parentScope, with); - } - - public void registerScope(SqlNode sqlNode, SqlValidatorScope scope) { - scopes.put(sqlNode, scope); - } - - @Override - public void inferUnknownTypes(RelDataType inferredType, SqlValidatorScope scope, SqlNode node) { - super.inferUnknownTypes(inferredType, scope, node); - } - - @Override - public void validateNamespace(final SqlValidatorNamespace namespace, - RelDataType targetRowType) { - super.validateNamespace(namespace, targetRowType); - } - - @Override - protected void checkFieldCount(SqlNode node, SqlValidatorTable table, - SqlNode source, RelDataType logicalSourceRowType, - RelDataType logicalTargetRowType) { - if (!(logicalTargetRowType instanceof GraphRecordType)) { - super.checkFieldCount(node, table, source, logicalSourceRowType, logicalTargetRowType); + registerQuery( + parentScope, register ? usingScope : null, node, enclosingNode, alias, forceNullable); + return node; + default: + return super.registerOtherFrom( + parentScope, + usingScope, + register, + node, + enclosingNode, + alias, + extendList, + forceNullable, + lateral); + } + } + + @Override + protected SqlValidatorScope getWithBodyScope(SqlValidatorScope parentScope, SqlWith with) { + if (GQLNodeUtil.containMatch(with)) { + GQLScope withBodyScope = new GQLWithBodyScope(parentScope, with.withList); + if (with.withList.size() != 1) { + throw new GeaFlowDSLException( + with.getParserPosition().toString(), "Only support one with item"); + } + for (SqlNode withItem : with.withList) { + SqlValidatorNamespace withItemNs = getNamespace(withItem); + String withName = ((SqlWithItem) withItem).name.getSimple(); + withBodyScope.addChild(withItemNs, withName, false); + } + scopes.put(with.withList, withBodyScope); + return withBodyScope; + } + return super.getWithBodyScope(parentScope, with); + } + + public void registerScope(SqlNode sqlNode, SqlValidatorScope scope) { + scopes.put(sqlNode, scope); + } + + @Override + public void inferUnknownTypes(RelDataType inferredType, SqlValidatorScope scope, SqlNode node) { + super.inferUnknownTypes(inferredType, scope, node); + } + + @Override + public void validateNamespace(final SqlValidatorNamespace namespace, RelDataType targetRowType) { + super.validateNamespace(namespace, targetRowType); + } + + @Override + protected void checkFieldCount( + SqlNode node, + SqlValidatorTable table, + SqlNode source, + RelDataType logicalSourceRowType, + RelDataType logicalTargetRowType) { + if (!(logicalTargetRowType instanceof GraphRecordType)) { + super.checkFieldCount(node, table, source, logicalSourceRowType, logicalTargetRowType); + } + } + + private void registerGqlSubQuery( + SqlValidatorScope parentScope, String alias, SqlValidatorNamespace fromNs, SqlNode node) { + if (node.getKind() == SqlKind.GQL_PATH_PATTERN_SUB_QUERY) { + SqlPathPatternSubQuery subQuery = (SqlPathPatternSubQuery) node; + GQLSubQueryNamespace ns = new GQLSubQueryNamespace(this, subQuery); + registerNamespace(null, null, ns, true); + GQLScope subQueryScope = new GQLSubQueryScope(parentScope, subQuery); + subQueryScope.addChild(fromNs, alias, true); + scopes.put(subQuery, subQueryScope); + + SqlValidatorNamespace pathPatternNs = + registerPathPattern(subQuery.getPathPattern(), subQueryScope, fromNs, alias, true); + if (subQuery.getReturnValue() != null) { + SqlNode returnValue = subQuery.getReturnValue(); + GQLScope returnValueScope = new GQLScope(parentScope, returnValue); + returnValueScope.addChild(pathPatternNs, deriveAlias(subQuery.getPathPattern()), true); + scopes.put(returnValue, returnValueScope); + } + } else if (node instanceof SqlCall) { + SqlCall call = (SqlCall) node; + for (SqlNode operand : call.getOperandList()) { + if (operand != null) { + registerGqlSubQuery(parentScope, alias, fromNs, operand); } - } - - private void registerGqlSubQuery(SqlValidatorScope parentScope, String alias, - SqlValidatorNamespace fromNs, SqlNode node) { - if (node.getKind() == SqlKind.GQL_PATH_PATTERN_SUB_QUERY) { - SqlPathPatternSubQuery subQuery = (SqlPathPatternSubQuery) node; - GQLSubQueryNamespace ns = new GQLSubQueryNamespace(this, subQuery); - registerNamespace(null, null, ns, true); - GQLScope subQueryScope = new GQLSubQueryScope(parentScope, subQuery); - subQueryScope.addChild(fromNs, alias, true); - scopes.put(subQuery, subQueryScope); - - SqlValidatorNamespace pathPatternNs = - registerPathPattern(subQuery.getPathPattern(), subQueryScope, fromNs, alias, true); - if (subQuery.getReturnValue() != null) { - SqlNode returnValue = subQuery.getReturnValue(); - GQLScope returnValueScope = new GQLScope(parentScope, returnValue); - returnValueScope.addChild(pathPatternNs, deriveAlias(subQuery.getPathPattern()), true); - scopes.put(returnValue, returnValueScope); - } - } else if (node instanceof SqlCall) { - SqlCall call = (SqlCall) node; - for (SqlNode operand : call.getOperandList()) { - if (operand != null) { - registerGqlSubQuery(parentScope, alias, fromNs, operand); - } - } - } else if (node instanceof SqlNodeList) { - SqlNodeList nodes = (SqlNodeList) node; - for (SqlNode item : nodes.getList()) { - registerGqlSubQuery(parentScope, alias, fromNs, item); - } + } + } else if (node instanceof SqlNodeList) { + SqlNodeList nodes = (SqlNodeList) node; + for (SqlNode item : nodes.getList()) { + registerGqlSubQuery(parentScope, alias, fromNs, item); + } + } + } + + public SqlMatchNode getStartCycleMatchNode(SqlMatchNode node) { + return renamedMatchNodes.get(node); + } + + public SqlValidatorScope getScopes(SqlNode node) { + return scopes.get(getOriginal(node)); + } + + public RelDataType getMatchNodeType(SqlMatchNode matchNode) { + RelDataType nodeType = matchNodeTypes.get(matchNode); + assert nodeType != null; + return nodeType; + } + + public void registerMatchNodeType(SqlMatchNode matchNode, RelDataType nodeType) { + matchNodeTypes.put(matchNode, nodeType); + } + + public GQLContext getGQLContext() { + return gContext; + } + + public String deriveAlias(SqlNode node) { + if (node instanceof SqlIdentifier && ((SqlIdentifier) node).isSimple()) { + return ((SqlIdentifier) node).getSimple(); + } + if (node.getKind() == SqlKind.AS) { + return ((SqlCall) node).operand(1).toString(); + } + return ANONYMOUS_COLUMN_PREFIX + nextGeneratedId++; + } + + public String anonymousMatchNodeName(boolean isVertex) { + if (isVertex) { + return "v_" + ANONYMOUS_COLUMN_PREFIX + nextGeneratedId++; + } + return "e_" + ANONYMOUS_COLUMN_PREFIX + nextGeneratedId++; + } + + public SqlSelect asSqlSelect(SqlReturnStatement returnStmt) { + return new SqlSelect( + SqlParserPos.ZERO, + null, + returnStmt.getReturnList(), + null, + null, + returnStmt.getGroupBy(), + null, + null, + returnStmt.getOrderList(), + null, + null); + } + + public SqlSelect asSqlSelect(SqlNodeList selectItems) { + return new SqlSelect( + SqlParserPos.ZERO, null, selectItems, null, null, null, null, null, null, null, null); + } + + public SqlNode getAggregate(SqlNodeList sqlNodeList) { + return super.getAggregate(asSqlSelect(sqlNodeList)); + } + + @Override + protected SqlNode performUnconditionalRewrites(SqlNode node, boolean underFrom) { + if (node instanceof SqlMatchPattern) { + SqlMatchPattern matchPattern = (SqlMatchPattern) node; + if (matchPattern.getFrom() == null) { + if (gContext.getCurrentGraph() == null) { + throw new GeaFlowDSLException( + matchPattern.getParserPosition(), "Missing 'from graph' for match"); } - } - - public SqlMatchNode getStartCycleMatchNode(SqlMatchNode node) { - return renamedMatchNodes.get(node); - } - - public SqlValidatorScope getScopes(SqlNode node) { - return scopes.get(getOriginal(node)); - } - - public RelDataType getMatchNodeType(SqlMatchNode matchNode) { - RelDataType nodeType = matchNodeTypes.get(matchNode); - assert nodeType != null; - return nodeType; - } - - public void registerMatchNodeType(SqlMatchNode matchNode, RelDataType nodeType) { - matchNodeTypes.put(matchNode, nodeType); - } - - public GQLContext getGQLContext() { - return gContext; - } - - public String deriveAlias(SqlNode node) { - if (node instanceof SqlIdentifier && ((SqlIdentifier) node).isSimple()) { - return ((SqlIdentifier) node).getSimple(); + // Set current graph to from if not exists. + SqlIdentifier usingGraphId = + new SqlIdentifier(gContext.getCurrentGraph(), matchPattern.getParserPosition()); + matchPattern.setFrom(usingGraphId); + } + List nodes = matchPattern.getOperandList(); + for (int i = 0; i < nodes.size(); i++) { + SqlNode operand = nodes.get(i); + SqlNode newOperand = performUnconditionalRewrites(operand, underFrom); + if (newOperand != operand) { + matchPattern.setOperand(i, newOperand); } - if (node.getKind() == SqlKind.AS) { - return ((SqlCall) node).operand(1).toString(); + } + return matchPattern; + } else if (node instanceof SqlGraphAlgorithmCall) { + SqlGraphAlgorithmCall graphAlgorithmCall = (SqlGraphAlgorithmCall) node; + if (graphAlgorithmCall.getFrom() == null) { + if (gContext.getCurrentGraph() == null) { + throw new GeaFlowDSLException( + graphAlgorithmCall.getParserPosition().toString(), + "Missing 'from graph' for graph algorithm call"); } - return ANONYMOUS_COLUMN_PREFIX + nextGeneratedId++; - } - - public String anonymousMatchNodeName(boolean isVertex) { - if (isVertex) { - return "v_" + ANONYMOUS_COLUMN_PREFIX + nextGeneratedId++; + // Set current graph to from if not exists. + SqlIdentifier usingGraphId = + new SqlIdentifier(gContext.getCurrentGraph(), graphAlgorithmCall.getParserPosition()); + graphAlgorithmCall.setFrom(usingGraphId); + } + return graphAlgorithmCall; + } else if (node instanceof SqlUnionPathPattern) { + SqlUnionPathPattern unionPathPattern = (SqlUnionPathPattern) node; + return new SqlUnionPathPattern( + unionPathPattern.getParserPosition(), + performUnconditionalRewrites(unionPathPattern.getLeft(), underFrom), + performUnconditionalRewrites(unionPathPattern.getRight(), underFrom), + unionPathPattern.isDistinct()); + } else if (node instanceof SqlPathPattern) { + SqlPathPattern pathPattern = (SqlPathPattern) node; + if (pathPattern.getPathAliasName() == null) { + SqlIdentifier pathAlias = + new SqlIdentifier("p_" + nextGeneratedId++, pathPattern.getParserPosition()); + pathPattern.setPathAlias(pathAlias); + } + for (int i = 0; i < pathPattern.getPathNodes().size(); i++) { + SqlMatchNode sqlMatchNode = (SqlMatchNode) pathPattern.getPathNodes().get(i); + if (sqlMatchNode.getName() == null) { + String nodeName = + anonymousMatchNodeName(sqlMatchNode.getKind() == SqlKind.GQL_MATCH_NODE); + SqlParserPos pos = sqlMatchNode.getParserPosition(); + sqlMatchNode.setName(new SqlIdentifier(nodeName, pos)); } - return "e_" + ANONYMOUS_COLUMN_PREFIX + nextGeneratedId++; - } - - public SqlSelect asSqlSelect(SqlReturnStatement returnStmt) { - return new SqlSelect(SqlParserPos.ZERO, null, - returnStmt.getReturnList(), null, null, - returnStmt.getGroupBy(), null, null, returnStmt.getOrderList(), null, - null); - } - - public SqlSelect asSqlSelect(SqlNodeList selectItems) { - return new SqlSelect(SqlParserPos.ZERO, null, selectItems, - null, null, null, null, null, null, - null, null); - } - - public SqlNode getAggregate(SqlNodeList sqlNodeList) { - return super.getAggregate(asSqlSelect(sqlNodeList)); - } - - @Override - protected SqlNode performUnconditionalRewrites(SqlNode node, boolean underFrom) { - if (node instanceof SqlMatchPattern) { - SqlMatchPattern matchPattern = (SqlMatchPattern) node; - if (matchPattern.getFrom() == null) { - if (gContext.getCurrentGraph() == null) { - throw new GeaFlowDSLException(matchPattern.getParserPosition(), - "Missing 'from graph' for match"); - } - // Set current graph to from if not exists. - SqlIdentifier usingGraphId = new SqlIdentifier(gContext.getCurrentGraph(), - matchPattern.getParserPosition()); - matchPattern.setFrom(usingGraphId); - } - List nodes = matchPattern.getOperandList(); - for (int i = 0; i < nodes.size(); i++) { - SqlNode operand = nodes.get(i); - SqlNode newOperand = performUnconditionalRewrites(operand, underFrom); - if (newOperand != operand) { - matchPattern.setOperand(i, newOperand); - } - } - return matchPattern; - } else if (node instanceof SqlGraphAlgorithmCall) { - SqlGraphAlgorithmCall graphAlgorithmCall = (SqlGraphAlgorithmCall) node; - if (graphAlgorithmCall.getFrom() == null) { - if (gContext.getCurrentGraph() == null) { - throw new GeaFlowDSLException(graphAlgorithmCall.getParserPosition().toString(), - "Missing 'from graph' for graph algorithm call"); - } - // Set current graph to from if not exists. - SqlIdentifier usingGraphId = new SqlIdentifier(gContext.getCurrentGraph(), - graphAlgorithmCall.getParserPosition()); - graphAlgorithmCall.setFrom(usingGraphId); - } - return graphAlgorithmCall; - } else if (node instanceof SqlUnionPathPattern) { - SqlUnionPathPattern unionPathPattern = (SqlUnionPathPattern) node; - return new SqlUnionPathPattern(unionPathPattern.getParserPosition(), - performUnconditionalRewrites(unionPathPattern.getLeft(), underFrom), - performUnconditionalRewrites(unionPathPattern.getRight(), underFrom), - unionPathPattern.isDistinct()); - } else if (node instanceof SqlPathPattern) { - SqlPathPattern pathPattern = (SqlPathPattern) node; - if (pathPattern.getPathAliasName() == null) { - SqlIdentifier pathAlias = new SqlIdentifier("p_" + nextGeneratedId++, - pathPattern.getParserPosition()); - pathPattern.setPathAlias(pathAlias); - } - for (int i = 0; i < pathPattern.getPathNodes().size(); i++) { - SqlMatchNode sqlMatchNode = (SqlMatchNode) pathPattern.getPathNodes().get(i); - if (sqlMatchNode.getName() == null) { - String nodeName = anonymousMatchNodeName( - sqlMatchNode.getKind() == SqlKind.GQL_MATCH_NODE); - SqlParserPos pos = sqlMatchNode.getParserPosition(); - sqlMatchNode.setName(new SqlIdentifier(nodeName, pos)); - } - } - renameCycleMatchNode(pathPattern); - return super.performUnconditionalRewrites(pathPattern, underFrom); - } else if (isExistsPathPattern(node)) { - // Rewrite "where exists (a) - (b)" to - // "where count((a) - (b) => a) > 0" - SqlPathPatternSubQuery subQuery = (SqlPathPatternSubQuery) performUnconditionalRewrites( - ((SqlBasicCall) node).getOperands()[0], underFrom); - subQuery.setReturnValue(subQuery.getPathPattern().getFirst().getNameId()); - - SqlNode count = SqlStdOperatorTable.COUNT.createCall(node.getParserPosition(), subQuery); - SqlNode zero = SqlLiteral.createExactNumeric("0", node.getParserPosition()); - return SqlStdOperatorTable.GREATER_THAN.createCall(node.getParserPosition(), count, zero); - } else if (node != null && node.getKind() == SqlKind.MOD) { - SqlBasicCall mod = (SqlBasicCall) node; - mod.setOperator(GeaFlowOverwriteSqlOperators.MOD); - return mod; - } else if (node instanceof SqlInsert) { - // complete the insert target table name. - // e.g. "insert into g.v" will replace to "insert into instance.g.v" - SqlInsert insert = (SqlInsert) node; - SqlIdentifier completeId = gContext.completeCatalogObjName((SqlIdentifier) insert.getTargetTable()); - insert.setTargetTable(completeId); - return super.performUnconditionalRewrites(insert, underFrom); + } + renameCycleMatchNode(pathPattern); + return super.performUnconditionalRewrites(pathPattern, underFrom); + } else if (isExistsPathPattern(node)) { + // Rewrite "where exists (a) - (b)" to + // "where count((a) - (b) => a) > 0" + SqlPathPatternSubQuery subQuery = + (SqlPathPatternSubQuery) + performUnconditionalRewrites(((SqlBasicCall) node).getOperands()[0], underFrom); + subQuery.setReturnValue(subQuery.getPathPattern().getFirst().getNameId()); + + SqlNode count = SqlStdOperatorTable.COUNT.createCall(node.getParserPosition(), subQuery); + SqlNode zero = SqlLiteral.createExactNumeric("0", node.getParserPosition()); + return SqlStdOperatorTable.GREATER_THAN.createCall(node.getParserPosition(), count, zero); + } else if (node != null && node.getKind() == SqlKind.MOD) { + SqlBasicCall mod = (SqlBasicCall) node; + mod.setOperator(GeaFlowOverwriteSqlOperators.MOD); + return mod; + } else if (node instanceof SqlInsert) { + // complete the insert target table name. + // e.g. "insert into g.v" will replace to "insert into instance.g.v" + SqlInsert insert = (SqlInsert) node; + SqlIdentifier completeId = + gContext.completeCatalogObjName((SqlIdentifier) insert.getTargetTable()); + insert.setTargetTable(completeId); + return super.performUnconditionalRewrites(insert, underFrom); + } + return super.performUnconditionalRewrites(node, underFrom); + } + + private void renameCycleMatchNode(SqlPathPattern pathPattern) { + Map> name2MatchNodes = new HashMap<>(); + for (int i = 0; i < pathPattern.getPathNodes().size(); i++) { + SqlMatchNode sqlMatchNode = (SqlMatchNode) pathPattern.getPathNodes().get(i); + if (sqlMatchNode.getKind() != SqlKind.GQL_MATCH_NODE) { + continue; + } + String oldName = sqlMatchNode.getName(); + if (name2MatchNodes.containsKey(oldName)) { + sqlMatchNode.setName( + new SqlIdentifier( + oldName + RECURRING_COLUMN_SUFFIX + name2MatchNodes.get(oldName).size(), + sqlMatchNode.getParserPosition())); + name2MatchNodes.get(oldName).add(sqlMatchNode); + if (name2MatchNodes.get(oldName).size() > 1) { + renamedMatchNodes.put(sqlMatchNode, name2MatchNodes.get(oldName).get(0)); } - return super.performUnconditionalRewrites(node, underFrom); - } - - private void renameCycleMatchNode(SqlPathPattern pathPattern) { - Map> name2MatchNodes = new HashMap<>(); - for (int i = 0; i < pathPattern.getPathNodes().size(); i++) { - SqlMatchNode sqlMatchNode = (SqlMatchNode) pathPattern.getPathNodes().get(i); - if (sqlMatchNode.getKind() != SqlKind.GQL_MATCH_NODE) { - continue; - } - String oldName = sqlMatchNode.getName(); - if (name2MatchNodes.containsKey(oldName)) { - sqlMatchNode.setName(new SqlIdentifier( - oldName + RECURRING_COLUMN_SUFFIX + name2MatchNodes.get(oldName).size(), - sqlMatchNode.getParserPosition())); - name2MatchNodes.get(oldName).add(sqlMatchNode); - if (name2MatchNodes.get(oldName).size() > 1) { - renamedMatchNodes.put(sqlMatchNode, name2MatchNodes.get(oldName).get(0)); - } - } else { - name2MatchNodes.put(sqlMatchNode.getName(), new ArrayList<>()); - name2MatchNodes.get(oldName).add(sqlMatchNode); - } + } else { + name2MatchNodes.put(sqlMatchNode.getName(), new ArrayList<>()); + name2MatchNodes.get(oldName).add(sqlMatchNode); + } + } + } + + @Override + protected RelDataType getLogicalTargetRowType(RelDataType targetRowType, SqlInsert insert) { + RelDataType targetType = super.getLogicalTargetRowType(targetRowType, insert); + if (targetType instanceof VertexRecordType) { + List fields = new ArrayList<>(targetType.getFieldList()); + fields.remove(VertexType.LABEL_FIELD_POSITION); + targetType = new RelRecordType(fields); + } else if (targetType instanceof EdgeRecordType) { + List fields = new ArrayList<>(targetType.getFieldList()); + fields.remove(EdgeType.LABEL_FIELD_POSITION); + targetType = new RelRecordType(fields); + } + return targetType; + } + + @Override + protected RelDataType createTargetRowType( + SqlValidatorTable table, SqlNodeList targetColumnList, boolean append) { + GeaFlowGraph graph = table.unwrap(GeaFlowGraph.class); + if (graph != null) { // for insert g + GraphRecordType graphType = (GraphRecordType) graph.getRowType(getTypeFactory()); + if (targetColumnList == null || targetColumnList.size() == 0) { + throw new GeaFlowDSLException("Missing target columns for insert graph statement"); + } + for (SqlNode targetColumn : targetColumnList) { + List names = ((SqlIdentifier) targetColumn).names; + RelDataTypeField field = graphType.getField(names, isCaseSensitive()); + if (field == null) { + throw new GeaFlowDSLException( + targetColumn.getParserPosition().toString(), + "Insert field: {} is not found in graph: {}", + targetColumn, + graph.getName()); } - } - - @Override - protected RelDataType getLogicalTargetRowType( - RelDataType targetRowType, - SqlInsert insert) { - RelDataType targetType = super.getLogicalTargetRowType(targetRowType, insert); - if (targetType instanceof VertexRecordType) { - List fields = new ArrayList<>(targetType.getFieldList()); - fields.remove(VertexType.LABEL_FIELD_POSITION); - targetType = new RelRecordType(fields); - } else if (targetType instanceof EdgeRecordType) { - List fields = new ArrayList<>(targetType.getFieldList()); - fields.remove(EdgeType.LABEL_FIELD_POSITION); - targetType = new RelRecordType(fields); + } + return graphType; + } + return super.createTargetRowType(table, targetColumnList, append); + } + + @Override + protected void checkTypeAssignment( + RelDataType sourceRowType, RelDataType targetRowType, final SqlNode query) { + if (targetRowType instanceof GraphRecordType) { // for insert g + GraphRecordType graphType = (GraphRecordType) targetRowType; + SqlInsert insert = (SqlInsert) query; + for (int i = 0; i < insert.getTargetColumnList().size(); i++) { + SqlIdentifier targetColumn = (SqlIdentifier) insert.getTargetColumnList().get(i); + List names = targetColumn.names; + RelDataTypeField targetField = graphType.getField(names, isCaseSensitive()); + RelDataTypeField sourceField = sourceRowType.getFieldList().get(i); + + if (!SqlTypeUtil.canAssignFrom(targetField.getType(), sourceField.getType())) { + throw newValidationError( + targetColumn, + RESOURCE.typeNotAssignable( + targetField.getName(), targetField.getType().getFullTypeString(), + sourceField.getName(), sourceField.getType().getFullTypeString())); } - return targetType; - } - - @Override - protected RelDataType createTargetRowType( - SqlValidatorTable table, - SqlNodeList targetColumnList, - boolean append) { - GeaFlowGraph graph = table.unwrap(GeaFlowGraph.class); - if (graph != null) { // for insert g - GraphRecordType graphType = (GraphRecordType) graph.getRowType(getTypeFactory()); - if (targetColumnList == null || targetColumnList.size() == 0) { - throw new GeaFlowDSLException("Missing target columns for insert graph statement"); - } - for (SqlNode targetColumn : targetColumnList) { - List names = ((SqlIdentifier) targetColumn).names; - RelDataTypeField field = graphType.getField(names, isCaseSensitive()); - if (field == null) { - throw new GeaFlowDSLException(targetColumn.getParserPosition().toString(), - "Insert field: {} is not found in graph: {}", targetColumn, graph.getName()); - } + } + } else { + super.checkTypeAssignment(sourceRowType, targetRowType, query); + } + } + + private boolean isExistsPathPattern(SqlNode node) { + return node != null + && node.getKind() == SqlKind.EXISTS + && ((SqlBasicCall) node).getOperands().length == 1 + && ((SqlBasicCall) node).getOperands()[0] instanceof SqlPathPatternSubQuery; + } + + public SqlNode expandReturnGroupOrderExpr( + SqlReturnStatement returnStmt, SqlValidatorScope scope, SqlNode orderExpr) { + SqlNode newSqlNode = + (new ReturnGroupOrderExpressionExpander(returnStmt, scope, orderExpr)).go(); + if (newSqlNode != orderExpr) { + this.inferUnknownTypes(this.unknownType, scope, newSqlNode); + RelDataType type = this.deriveType(scope, newSqlNode); + this.setValidatedNodeType(newSqlNode, type); + } + return newSqlNode; + } + + class ReturnGroupOrderExpressionExpander extends SqlScopedShuttle { + + private final List aliasList; + private final SqlReturnStatement returnStmt; + private final SqlNode root; + + ReturnGroupOrderExpressionExpander( + SqlReturnStatement returnStmt, SqlValidatorScope scope, SqlNode root) { + super(scope); + this.returnStmt = returnStmt; + this.root = root; + this.aliasList = getNamespace(returnStmt).getRowType().getFieldNames(); + } + + public SqlNode go() { + return this.root.accept(this); + } + + public SqlNode visit(SqlLiteral literal) { + if (literal == this.root && getConformance().isSortByOrdinal()) { + switch (literal.getTypeName()) { + case DECIMAL: + case DOUBLE: + int intValue = literal.intValue(false); + if (intValue >= 0) { + if (intValue >= 1 && intValue <= this.aliasList.size()) { + int ordinal = intValue - 1; + return this.nthSelectItem(ordinal, literal.getParserPosition()); + } + + throw newValidationError(literal, RESOURCE.orderByOrdinalOutOfRange()); } - return graphType; + break; + default: } - return super.createTargetRowType(table, targetColumnList, append); - } + } - @Override - protected void checkTypeAssignment( - RelDataType sourceRowType, - RelDataType targetRowType, - final SqlNode query) { - if (targetRowType instanceof GraphRecordType) { // for insert g - GraphRecordType graphType = (GraphRecordType) targetRowType; - SqlInsert insert = (SqlInsert) query; - for (int i = 0; i < insert.getTargetColumnList().size(); i++) { - SqlIdentifier targetColumn = (SqlIdentifier) insert.getTargetColumnList().get(i); - List names = targetColumn.names; - RelDataTypeField targetField = graphType.getField(names, isCaseSensitive()); - RelDataTypeField sourceField = sourceRowType.getFieldList().get(i); - - if (!SqlTypeUtil.canAssignFrom(targetField.getType(), sourceField.getType())) { - throw newValidationError(targetColumn, - RESOURCE.typeNotAssignable( - targetField.getName(), targetField.getType().getFullTypeString(), - sourceField.getName(), sourceField.getType().getFullTypeString())); - } - } - } else { - super.checkTypeAssignment(sourceRowType, targetRowType, query); - } + return super.visit(literal); } - private boolean isExistsPathPattern(SqlNode node) { - return node != null - && node.getKind() == SqlKind.EXISTS - && ((SqlBasicCall) node).getOperands().length == 1 - && ((SqlBasicCall) node).getOperands()[0] instanceof SqlPathPatternSubQuery - ; - } + private SqlNode nthSelectItem(int ordinal, SqlParserPos pos) { + SqlNodeList expandedReturnList = returnStmt.getReturnList(); + SqlNode expr = expandedReturnList.get(ordinal); + SqlNode exprx = SqlUtil.stripAs(expr); + if (exprx instanceof SqlIdentifier) { + exprx = this.getScope().fullyQualify((SqlIdentifier) exprx).identifier; + } - public SqlNode expandReturnGroupOrderExpr(SqlReturnStatement returnStmt, - SqlValidatorScope scope, SqlNode orderExpr) { - SqlNode newSqlNode = - (new ReturnGroupOrderExpressionExpander(returnStmt, scope, orderExpr)).go(); - if (newSqlNode != orderExpr) { - this.inferUnknownTypes(this.unknownType, scope, newSqlNode); - RelDataType type = this.deriveType(scope, newSqlNode); - this.setValidatedNodeType(newSqlNode, type); - } - return newSqlNode; + return exprx.clone(pos); } - class ReturnGroupOrderExpressionExpander extends SqlScopedShuttle { - - private final List aliasList; - private final SqlReturnStatement returnStmt; - private final SqlNode root; - - ReturnGroupOrderExpressionExpander(SqlReturnStatement returnStmt, - SqlValidatorScope scope, SqlNode root) { - super(scope); - this.returnStmt = returnStmt; - this.root = root; - this.aliasList = getNamespace(returnStmt).getRowType().getFieldNames(); + public SqlNode visit(SqlIdentifier id) { + if (id.isSimple() && getConformance().isSortByAlias()) { + String alias = id.getSimple(); + SqlValidatorNamespace selectNs = getNamespace(returnStmt); + RelDataType rowType = selectNs.getRowTypeSansSystemColumns(); + SqlNameMatcher nameMatcher = getCatalogReader().nameMatcher(); + RelDataTypeField field = nameMatcher.field(rowType, alias); + if (field != null) { + return this.nthSelectItem(field.getIndex(), id.getParserPosition()); } - - public SqlNode go() { - return this.root.accept(this); - } - - public SqlNode visit(SqlLiteral literal) { - if (literal == this.root && getConformance().isSortByOrdinal()) { - switch (literal.getTypeName()) { - case DECIMAL: - case DOUBLE: - int intValue = literal.intValue(false); - if (intValue >= 0) { - if (intValue >= 1 && intValue <= this.aliasList.size()) { - int ordinal = intValue - 1; - return this.nthSelectItem(ordinal, literal.getParserPosition()); - } - - throw newValidationError(literal, - RESOURCE.orderByOrdinalOutOfRange()); - } - break; - default: - } - } - - return super.visit(literal); - } - - private SqlNode nthSelectItem(int ordinal, SqlParserPos pos) { - SqlNodeList expandedReturnList = returnStmt.getReturnList(); - SqlNode expr = expandedReturnList.get(ordinal); - SqlNode exprx = SqlUtil.stripAs(expr); - if (exprx instanceof SqlIdentifier) { - exprx = this.getScope().fullyQualify((SqlIdentifier) exprx).identifier; - } - - return exprx.clone(pos); - } - - public SqlNode visit(SqlIdentifier id) { - if (id.isSimple() && getConformance().isSortByAlias()) { - String alias = id.getSimple(); - SqlValidatorNamespace selectNs = getNamespace(returnStmt); - RelDataType rowType = selectNs.getRowTypeSansSystemColumns(); - SqlNameMatcher nameMatcher = getCatalogReader().nameMatcher(); - RelDataTypeField field = nameMatcher.field(rowType, alias); - if (field != null) { - return this.nthSelectItem(field.getIndex(), id.getParserPosition()); - } - } - //Replace if alias exists - int size = id.names.size(); - final SqlIdentifier prefix = id.getComponent(0, 1); - String alias = prefix.getSimple(); - SqlValidatorNamespace selectNs = getNamespace(returnStmt); - RelDataType rowType = selectNs.getRowTypeSansSystemColumns(); - SqlNameMatcher nameMatcher = getCatalogReader().nameMatcher(); - RelDataTypeField field = nameMatcher.field(rowType, alias); - if (field != null) { - SqlNode identifierNewPrefix = this.nthSelectItem(field.getIndex(), - id.getParserPosition()); - assert identifierNewPrefix instanceof SqlIdentifier : "At " + id.getParserPosition() - + " : Prefix in OrderBy should be identifier."; - List newIdList = new ArrayList<>(); - newIdList.addAll(((SqlIdentifier) identifierNewPrefix).names); - newIdList.addAll(id.getComponent(1, size).names); - return new SqlIdentifier(newIdList, id.getParserPosition()); - } else { - return this.getScope().fullyQualify(id).identifier; - } - } - - protected SqlNode visitScoped(SqlCall call) { - return call instanceof SqlSelect ? call : super.visitScoped(call); - } - } - - public boolean isCaseSensitive() { - return getCatalogReader().nameMatcher().isCaseSensitive(); - } - - public SqlNameMatcher nameMatcher() { - return getCatalogReader().nameMatcher(); - } - - public QueryNodeContext getCurrentQueryNodeContext() { - return currentQueryNodeContext; - } - - public void addModifyGraphType(SqlLetStatement letStatement, GraphRecordType modifyGraphType) { - let2ModifyGraphType.put(letStatement, modifyGraphType); - } - - public GraphRecordType getModifyGraphType(SqlLetStatement letStatement) { - return let2ModifyGraphType.get(letStatement); - } + } + // Replace if alias exists + int size = id.names.size(); + final SqlIdentifier prefix = id.getComponent(0, 1); + String alias = prefix.getSimple(); + SqlValidatorNamespace selectNs = getNamespace(returnStmt); + RelDataType rowType = selectNs.getRowTypeSansSystemColumns(); + SqlNameMatcher nameMatcher = getCatalogReader().nameMatcher(); + RelDataTypeField field = nameMatcher.field(rowType, alias); + if (field != null) { + SqlNode identifierNewPrefix = this.nthSelectItem(field.getIndex(), id.getParserPosition()); + assert identifierNewPrefix instanceof SqlIdentifier + : "At " + id.getParserPosition() + " : Prefix in OrderBy should be identifier."; + List newIdList = new ArrayList<>(); + newIdList.addAll(((SqlIdentifier) identifierNewPrefix).names); + newIdList.addAll(id.getComponent(1, size).names); + return new SqlIdentifier(newIdList, id.getParserPosition()); + } else { + return this.getScope().fullyQualify(id).identifier; + } + } + + protected SqlNode visitScoped(SqlCall call) { + return call instanceof SqlSelect ? call : super.visitScoped(call); + } + } + + public boolean isCaseSensitive() { + return getCatalogReader().nameMatcher().isCaseSensitive(); + } + + public SqlNameMatcher nameMatcher() { + return getCatalogReader().nameMatcher(); + } + + public QueryNodeContext getCurrentQueryNodeContext() { + return currentQueryNodeContext; + } + + public void addModifyGraphType(SqlLetStatement letStatement, GraphRecordType modifyGraphType) { + let2ModifyGraphType.put(letStatement, modifyGraphType); + } + + public GraphRecordType getModifyGraphType(SqlLetStatement letStatement) { + return let2ModifyGraphType.get(letStatement); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/QueryNodeContext.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/QueryNodeContext.java index 9633e1437..4aed13b6b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/QueryNodeContext.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/QueryNodeContext.java @@ -21,18 +21,18 @@ import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.dsl.calcite.GraphRecordType; public class QueryNodeContext { - private final Map modifyGraphs = new HashMap<>(); - + private final Map modifyGraphs = new HashMap<>(); - public void addModifyGraph(GraphRecordType graph) { - modifyGraphs.put(graph.getGraphName(), graph); - } + public void addModifyGraph(GraphRecordType graph) { + modifyGraphs.put(graph.getGraphName(), graph); + } - public GraphRecordType getModifyGraph(String graphName) { - return modifyGraphs.get(graphName); - } + public GraphRecordType getModifyGraph(String graphName) { + return modifyGraphs.get(graphName); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLAlgorithmNamespace.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLAlgorithmNamespace.java index 14500375a..ac406571e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLAlgorithmNamespace.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLAlgorithmNamespace.java @@ -24,6 +24,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.*; import org.apache.calcite.sql.validate.SqlValidatorImpl; @@ -33,63 +34,75 @@ public class GQLAlgorithmNamespace extends GQLBaseNamespace { - private final SqlGraphAlgorithmCall graphAlgorithmCall; + private final SqlGraphAlgorithmCall graphAlgorithmCall; - public GQLAlgorithmNamespace(SqlValidatorImpl validator, SqlGraphAlgorithmCall graphAlgorithmCall) { - super(validator, graphAlgorithmCall); - this.graphAlgorithmCall = graphAlgorithmCall; - } + public GQLAlgorithmNamespace( + SqlValidatorImpl validator, SqlGraphAlgorithmCall graphAlgorithmCall) { + super(validator, graphAlgorithmCall); + this.graphAlgorithmCall = graphAlgorithmCall; + } - @Override - protected RelDataType validateImpl(RelDataType targetRowType) { - Map columnNameMap = new HashMap<>(); - SqlNodeList yields = graphAlgorithmCall.getYields(); - if (yields != null) { - for (SqlNode yield : yields) { - String yieldName = ((SqlIdentifier) yield).getSimple(); - if (columnNameMap.get(yieldName) == null) { - columnNameMap.put(yieldName, true); - } else { - throw new GeaFlowDSLException(yield.getParserPosition(), "duplicate yield " - + "name: {}", yieldName); - } - } - } - List overloads = new ArrayList<>(); - getValidator().getOperatorTable().lookupOperatorOverloads(graphAlgorithmCall.getAlgorithm(), - SqlFunctionCategory.USER_DEFINED_CONSTRUCTOR, SqlSyntax.FUNCTION, overloads); - if (overloads.isEmpty()) { - throw new GeaFlowDSLException(graphAlgorithmCall.getParserPosition(), - "Cannot load graph algorithm implementation of {}", - graphAlgorithmCall.getAlgorithm().getSimple()); - } else { - //When multiple implementation classes of an algorithm with the same name are found, - // use the last registered class. - graphAlgorithmCall.setOperator(overloads.get(overloads.size() - 1)); - } - SqlOperator function = graphAlgorithmCall.getOperator(); - RelDataType inferType = function.inferReturnType(getValidator().getTypeFactory(), - Collections.emptyList()); - if (yields == null) { - return inferType; + @Override + protected RelDataType validateImpl(RelDataType targetRowType) { + Map columnNameMap = new HashMap<>(); + SqlNodeList yields = graphAlgorithmCall.getYields(); + if (yields != null) { + for (SqlNode yield : yields) { + String yieldName = ((SqlIdentifier) yield).getSimple(); + if (columnNameMap.get(yieldName) == null) { + columnNameMap.put(yieldName, true); } else { - if (yields.size() != inferType.getFieldCount()) { - throw new GeaFlowDSLException(graphAlgorithmCall.getParserPosition().toString(), - "The number of fields returned after calling the graph algorithm: {} " - + "should be consistent with the definition in the graph algorithm implementation class: {}.", - yields.size(), inferType.getFieldCount()); - } - final List> fieldList = new ArrayList<>(); - for (int i = 0, size = yields.size(); i < size; i++) { - fieldList.add(Pair.of(((SqlIdentifier) yields.get(i)).getSimple(), - inferType.getFieldList().get(i).getType())); - } - return validator.getTypeFactory().createStructType(fieldList); + throw new GeaFlowDSLException( + yield.getParserPosition(), "duplicate yield " + "name: {}", yieldName); } + } } - - @Override - public SqlNode getNode() { - return graphAlgorithmCall; + List overloads = new ArrayList<>(); + getValidator() + .getOperatorTable() + .lookupOperatorOverloads( + graphAlgorithmCall.getAlgorithm(), + SqlFunctionCategory.USER_DEFINED_CONSTRUCTOR, + SqlSyntax.FUNCTION, + overloads); + if (overloads.isEmpty()) { + throw new GeaFlowDSLException( + graphAlgorithmCall.getParserPosition(), + "Cannot load graph algorithm implementation of {}", + graphAlgorithmCall.getAlgorithm().getSimple()); + } else { + // When multiple implementation classes of an algorithm with the same name are found, + // use the last registered class. + graphAlgorithmCall.setOperator(overloads.get(overloads.size() - 1)); } + SqlOperator function = graphAlgorithmCall.getOperator(); + RelDataType inferType = + function.inferReturnType(getValidator().getTypeFactory(), Collections.emptyList()); + if (yields == null) { + return inferType; + } else { + if (yields.size() != inferType.getFieldCount()) { + throw new GeaFlowDSLException( + graphAlgorithmCall.getParserPosition().toString(), + "The number of fields returned after calling the graph algorithm: {} should be" + + " consistent with the definition in the graph algorithm implementation class:" + + " {}.", + yields.size(), + inferType.getFieldCount()); + } + final List> fieldList = new ArrayList<>(); + for (int i = 0, size = yields.size(); i < size; i++) { + fieldList.add( + Pair.of( + ((SqlIdentifier) yields.get(i)).getSimple(), + inferType.getFieldList().get(i).getType())); + } + return validator.getTypeFactory().createStructType(fieldList); + } + } + + @Override + public SqlNode getNode() { + return graphAlgorithmCall; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLBaseNamespace.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLBaseNamespace.java index 12d24275b..10334166a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLBaseNamespace.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLBaseNamespace.java @@ -26,17 +26,16 @@ public abstract class GQLBaseNamespace extends AbstractNamespace { - public GQLBaseNamespace(SqlValidatorImpl validator, - SqlNode enclosingNode) { - super(validator, enclosingNode); - } + public GQLBaseNamespace(SqlValidatorImpl validator, SqlNode enclosingNode) { + super(validator, enclosingNode); + } - @Override - public GQLValidatorImpl getValidator() { - return (GQLValidatorImpl) validator; - } + @Override + public GQLValidatorImpl getValidator() { + return (GQLValidatorImpl) validator; + } - public boolean isCaseSensitive() { - return ((GQLValidatorImpl) validator).isCaseSensitive(); - } + public boolean isCaseSensitive() { + return ((GQLValidatorImpl) validator).isCaseSensitive(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLFilterNamespace.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLFilterNamespace.java index fb7c6ae4b..1a0ab5e88 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLFilterNamespace.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLFilterNamespace.java @@ -32,41 +32,41 @@ public class GQLFilterNamespace extends GQLBaseNamespace { - private final SqlFilterStatement filterStatement; + private final SqlFilterStatement filterStatement; - public GQLFilterNamespace(SqlValidatorImpl validator, SqlNode enclosingNode, - SqlFilterStatement filterStatement) { - super(validator, enclosingNode); - this.filterStatement = filterStatement; - } + public GQLFilterNamespace( + SqlValidatorImpl validator, SqlNode enclosingNode, SqlFilterStatement filterStatement) { + super(validator, enclosingNode); + this.filterStatement = filterStatement; + } - @Override - protected RelDataType validateImpl(RelDataType targetRowType) { - SqlValidatorNamespace fromNs = validator.getNamespace(filterStatement.getFrom()); - // Validate parent. - fromNs.validate(targetRowType); + @Override + protected RelDataType validateImpl(RelDataType targetRowType) { + SqlValidatorNamespace fromNs = validator.getNamespace(filterStatement.getFrom()); + // Validate parent. + fromNs.validate(targetRowType); - SqlValidatorScope scope = getValidator().getScopes(filterStatement); + SqlValidatorScope scope = getValidator().getScopes(filterStatement); - SqlNode condition = filterStatement.getCondition(); - // expand the condition, e.g. expand the "where id > 10" to "where g0.a.id > 10". - condition = validator.expand(condition, scope); - filterStatement.setCondition(condition); + SqlNode condition = filterStatement.getCondition(); + // expand the condition, e.g. expand the "where id > 10" to "where g0.a.id > 10". + condition = validator.expand(condition, scope); + filterStatement.setCondition(condition); - RelDataType boolType = validator.getTypeFactory().createSqlType(SqlTypeName.BOOLEAN); - getValidator().inferUnknownTypes(boolType, scope, condition); - condition.validate(validator, scope); + RelDataType boolType = validator.getTypeFactory().createSqlType(SqlTypeName.BOOLEAN); + getValidator().inferUnknownTypes(boolType, scope, condition); + condition.validate(validator, scope); - RelDataType conditionType = validator.deriveType(scope, condition); - if (!SqlTypeUtil.inBooleanFamily(conditionType)) { - throw validator.newValidationError(condition, RESOURCE.condMustBeBoolean("Filter")); - } - // Filter return parent type. - return fromNs.getType(); + RelDataType conditionType = validator.deriveType(scope, condition); + if (!SqlTypeUtil.inBooleanFamily(conditionType)) { + throw validator.newValidationError(condition, RESOURCE.condMustBeBoolean("Filter")); } + // Filter return parent type. + return fromNs.getType(); + } - @Override - public SqlNode getNode() { - return filterStatement; - } + @Override + public SqlNode getNode() { + return filterStatement; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLInsertNamespace.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLInsertNamespace.java index 2502df4b2..43bce1a66 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLInsertNamespace.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLInsertNamespace.java @@ -19,8 +19,8 @@ package org.apache.geaflow.dsl.validator.namespace; -import com.google.common.collect.ImmutableList; import java.util.Objects; + import org.apache.calcite.plan.RelOptSchema; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; @@ -37,66 +37,74 @@ import org.apache.geaflow.dsl.schema.GeaFlowGraph; import org.apache.geaflow.dsl.validator.GQLValidatorImpl; +import com.google.common.collect.ImmutableList; + public class GQLInsertNamespace extends GQLBaseNamespace { - private final SqlInsert insert; + private final SqlInsert insert; - private final IdentifierNamespace idNamespace; + private final IdentifierNamespace idNamespace; - public GQLInsertNamespace(GQLValidatorImpl validator, SqlInsert insert, SqlValidatorScope parentScope) { - super(validator, insert); - this.insert = Objects.requireNonNull(insert); - SqlIdentifier targetTable = (SqlIdentifier) insert.getTargetTable(); - SqlIdentifier targetId; - int size = targetTable.names.size(); - if (size >= 3) { // for instance.g.v use instance.g for validate - targetId = new SqlIdentifier(targetTable.names.subList(0, 2), targetTable.getParserPosition()); - } else { - targetId = targetTable; - } - this.idNamespace = new IdentifierNamespace(validator, targetId, insert.getTargetTable(), parentScope); + public GQLInsertNamespace( + GQLValidatorImpl validator, SqlInsert insert, SqlValidatorScope parentScope) { + super(validator, insert); + this.insert = Objects.requireNonNull(insert); + SqlIdentifier targetTable = (SqlIdentifier) insert.getTargetTable(); + SqlIdentifier targetId; + int size = targetTable.names.size(); + if (size >= 3) { // for instance.g.v use instance.g for validate + targetId = + new SqlIdentifier(targetTable.names.subList(0, 2), targetTable.getParserPosition()); + } else { + targetId = targetTable; } + this.idNamespace = + new IdentifierNamespace(validator, targetId, insert.getTargetTable(), parentScope); + } - @Override - protected RelDataType validateImpl(RelDataType targetRowType) { - RelDataType type = idNamespace.validateImpl(targetRowType); - if (type instanceof GraphRecordType) { - SqlIdentifier targetTable = (SqlIdentifier) insert.getTargetTable(); - if (targetTable.names.size() == 2) { // insert instance.g - return type; - } else if (targetTable.names.size() == 3) { // insert instance.g.v - String vertexTableName = targetTable.names.get(2); - RelDataTypeField field = type.getField(vertexTableName, isCaseSensitive(), false); - if (field == null) { - throw new GeaFlowDSLException("Field:{} is not found, graph type is:{}", vertexTableName, type); - } - type = field.getType(); - } else { - throw new GeaFlowDSLException(targetTable.getParserPosition().toString(), - "Illegal target table name size: {}", targetTable.names.size()); - } - return type; - } + @Override + protected RelDataType validateImpl(RelDataType targetRowType) { + RelDataType type = idNamespace.validateImpl(targetRowType); + if (type instanceof GraphRecordType) { + SqlIdentifier targetTable = (SqlIdentifier) insert.getTargetTable(); + if (targetTable.names.size() == 2) { // insert instance.g return type; + } else if (targetTable.names.size() == 3) { // insert instance.g.v + String vertexTableName = targetTable.names.get(2); + RelDataTypeField field = type.getField(vertexTableName, isCaseSensitive(), false); + if (field == null) { + throw new GeaFlowDSLException( + "Field:{} is not found, graph type is:{}", vertexTableName, type); + } + type = field.getType(); + } else { + throw new GeaFlowDSLException( + targetTable.getParserPosition().toString(), + "Illegal target table name size: {}", + targetTable.names.size()); + } + return type; } + return type; + } - @Override - public SqlValidatorTable getTable() { - SqlValidatorTable validatorTable = idNamespace.resolve().getTable(); - GeaFlowGraph graph = validatorTable.unwrap(GeaFlowGraph.class); - SqlIdentifier targetTable = (SqlIdentifier) insert.getTargetTable(); + @Override + public SqlValidatorTable getTable() { + SqlValidatorTable validatorTable = idNamespace.resolve().getTable(); + GeaFlowGraph graph = validatorTable.unwrap(GeaFlowGraph.class); + SqlIdentifier targetTable = (SqlIdentifier) insert.getTargetTable(); - if (graph != null && targetTable.names.size() == 3) { // for insert into instance.g.v - String tableName = targetTable.names.get(2); - Table table = graph.getTable(tableName); - RelOptSchema optSchema = getValidator().getCatalogReader().unwrap(RelOptSchema.class); - return GQLRelOptTableImpl.create(optSchema, getRowType(), table, ImmutableList.of()); - } - return validatorTable; + if (graph != null && targetTable.names.size() == 3) { // for insert into instance.g.v + String tableName = targetTable.names.get(2); + Table table = graph.getTable(tableName); + RelOptSchema optSchema = getValidator().getCatalogReader().unwrap(RelOptSchema.class); + return GQLRelOptTableImpl.create(optSchema, getRowType(), table, ImmutableList.of()); } + return validatorTable; + } - @Override - public SqlNode getNode() { - return insert; - } + @Override + public SqlNode getNode() { + return insert; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLLetNamespace.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLLetNamespace.java index 5cd425e92..a2199acd3 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLLetNamespace.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLLetNamespace.java @@ -39,88 +39,99 @@ public class GQLLetNamespace extends GQLBaseNamespace { - private final SqlLetStatement letStatement; + private final SqlLetStatement letStatement; - public GQLLetNamespace(SqlValidatorImpl validator, SqlLetStatement letStatement) { - super(validator, letStatement); - this.letStatement = letStatement; - } + public GQLLetNamespace(SqlValidatorImpl validator, SqlLetStatement letStatement) { + super(validator, letStatement); + this.letStatement = letStatement; + } - @Override - protected RelDataType validateImpl(RelDataType targetRowType) { - if (letStatement.getFrom() != null) { - SqlValidatorNamespace fromNs = validator.getNamespace(letStatement.getFrom()); - fromNs.validate(targetRowType); + @Override + protected RelDataType validateImpl(RelDataType targetRowType) { + if (letStatement.getFrom() != null) { + SqlValidatorNamespace fromNs = validator.getNamespace(letStatement.getFrom()); + fromNs.validate(targetRowType); - SqlValidatorScope scope = ((GQLValidatorImpl) validator).getScopes(letStatement); + SqlValidatorScope scope = ((GQLValidatorImpl) validator).getScopes(letStatement); - RelDataType fromType = fromNs.getType(); - assert fromType instanceof PathRecordType; + RelDataType fromType = fromNs.getType(); + assert fromType instanceof PathRecordType; - PathRecordType inputPathType = (PathRecordType) fromType; - String leftLabel = letStatement.getLeftLabel(); + PathRecordType inputPathType = (PathRecordType) fromType; + String leftLabel = letStatement.getLeftLabel(); - RelDataTypeField labelField = inputPathType.getField(leftLabel, - isCaseSensitive(), false); - if (labelField == null) { - throw new GeaFlowDSLException(letStatement.getLeftVar().getParserPosition().toString(), - "Left label: {} is not exists in input path fields: {}.", leftLabel, fromType.getFieldNames()); - } - SqlNode expression = letStatement.getExpression(); - expression = validator.expand(expression, scope); - letStatement.setExpression(expression); + RelDataTypeField labelField = inputPathType.getField(leftLabel, isCaseSensitive(), false); + if (labelField == null) { + throw new GeaFlowDSLException( + letStatement.getLeftVar().getParserPosition().toString(), + "Left label: {} is not exists in input path fields: {}.", + leftLabel, + fromType.getFieldNames()); + } + SqlNode expression = letStatement.getExpression(); + expression = validator.expand(expression, scope); + letStatement.setExpression(expression); - expression.validate(validator, scope); - ((GQLValidatorImpl) validator).inferUnknownTypes(validator.getUnknownType(), scope, expression); - RelDataType expressionType = validator.deriveType(scope, expression); - // set let expression type nullable - expressionType = validator.getTypeFactory().createTypeWithNullability(expressionType, true); - RelDataType labelType = labelField.getType(); - // check exist field type. - RelDataTypeField existField = labelType.getField(letStatement.getLeftField(), isCaseSensitive(), false); - if (existField != null && !SqlTypeUtil.canCastFrom(existField.getType(), expressionType, true)) { - throw new GeaFlowDSLException(letStatement.getParserPosition().toString(), - "Let statement cannot assign from {} to {}", expressionType, existField.getType().getSqlTypeName()); - } - RelDataType newLabelType = labelType; - if (existField == null) { // first define this let variable. - if (labelType.getSqlTypeName() == SqlTypeName.VERTEX) { - newLabelType = ((VertexRecordType) labelType).add(letStatement.getLeftField(), - expressionType, isCaseSensitive()); - } else if (labelType.getSqlTypeName() == SqlTypeName.EDGE) { - newLabelType = ((EdgeRecordType) labelType).add(letStatement.getLeftField(), - expressionType, isCaseSensitive()); - } else { - throw new IllegalArgumentException("Illegal labelType: " + labelType); - } - } - GraphRecordType graphRecordType = - GQLPathPatternScope.findCurrentGraphType(getValidator(), scope); - assert graphRecordType != null; - // Modify the query node level graph schema if it is a global vertex field modify - if (letStatement.isGlobal()) { - if (labelType.getSqlTypeName() != SqlTypeName.VERTEX) { - throw new GeaFlowDSLException(letStatement.getParserPosition(), - "Only vertex support global variable"); - } - QueryNodeContext nodeContext = getValidator().getCurrentQueryNodeContext(); - GraphRecordType newGraphType = graphRecordType.addVertexField(letStatement.getLeftField(), - expressionType); - // add to the node context - nodeContext.addModifyGraph(newGraphType); - // add to the validator and used by the GQLToRelConverter. - getValidator().addModifyGraphType(letStatement, newGraphType); - } else { - getValidator().addModifyGraphType(letStatement, graphRecordType); - } - // replace the type for left label. - return inputPathType.copy(labelField.getIndex(), newLabelType); + expression.validate(validator, scope); + ((GQLValidatorImpl) validator) + .inferUnknownTypes(validator.getUnknownType(), scope, expression); + RelDataType expressionType = validator.deriveType(scope, expression); + // set let expression type nullable + expressionType = validator.getTypeFactory().createTypeWithNullability(expressionType, true); + RelDataType labelType = labelField.getType(); + // check exist field type. + RelDataTypeField existField = + labelType.getField(letStatement.getLeftField(), isCaseSensitive(), false); + if (existField != null + && !SqlTypeUtil.canCastFrom(existField.getType(), expressionType, true)) { + throw new GeaFlowDSLException( + letStatement.getParserPosition().toString(), + "Let statement cannot assign from {} to {}", + expressionType, + existField.getType().getSqlTypeName()); + } + RelDataType newLabelType = labelType; + if (existField == null) { // first define this let variable. + if (labelType.getSqlTypeName() == SqlTypeName.VERTEX) { + newLabelType = + ((VertexRecordType) labelType) + .add(letStatement.getLeftField(), expressionType, isCaseSensitive()); + } else if (labelType.getSqlTypeName() == SqlTypeName.EDGE) { + newLabelType = + ((EdgeRecordType) labelType) + .add(letStatement.getLeftField(), expressionType, isCaseSensitive()); + } else { + throw new IllegalArgumentException("Illegal labelType: " + labelType); + } + } + GraphRecordType graphRecordType = + GQLPathPatternScope.findCurrentGraphType(getValidator(), scope); + assert graphRecordType != null; + // Modify the query node level graph schema if it is a global vertex field modify + if (letStatement.isGlobal()) { + if (labelType.getSqlTypeName() != SqlTypeName.VERTEX) { + throw new GeaFlowDSLException( + letStatement.getParserPosition(), "Only vertex support global variable"); } - throw new GeaFlowDSLException(letStatement.getParserPosition(), "Let without from is not support"); + QueryNodeContext nodeContext = getValidator().getCurrentQueryNodeContext(); + GraphRecordType newGraphType = + graphRecordType.addVertexField(letStatement.getLeftField(), expressionType); + // add to the node context + nodeContext.addModifyGraph(newGraphType); + // add to the validator and used by the GQLToRelConverter. + getValidator().addModifyGraphType(letStatement, newGraphType); + } else { + getValidator().addModifyGraphType(letStatement, graphRecordType); + } + // replace the type for left label. + return inputPathType.copy(labelField.getIndex(), newLabelType); } + throw new GeaFlowDSLException( + letStatement.getParserPosition(), "Let without from is not support"); + } - @Override - public SqlNode getNode() { - return letStatement; - } + @Override + public SqlNode getNode() { + return letStatement; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLMatchNodeNamespace.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLMatchNodeNamespace.java index eed46365f..355466e8f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLMatchNodeNamespace.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLMatchNodeNamespace.java @@ -23,6 +23,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.sql.SqlNode; @@ -38,126 +39,129 @@ public class GQLMatchNodeNamespace extends GQLBaseNamespace { - private final SqlMatchNode matchNode; - - private MatchNodeContext matchNodeContext; - - public GQLMatchNodeNamespace(SqlValidatorImpl validator, SqlMatchNode matchNode) { - super(validator, matchNode); - this.matchNode = matchNode; + private final SqlMatchNode matchNode; + + private MatchNodeContext matchNodeContext; + + public GQLMatchNodeNamespace(SqlValidatorImpl validator, SqlMatchNode matchNode) { + super(validator, matchNode); + this.matchNode = matchNode; + } + + public void setMatchNodeContext(MatchNodeContext matchNodeContext) { + this.matchNodeContext = matchNodeContext; + } + + @Override + protected RelDataType validateImpl(RelDataType targetRowType) { + assert matchNodeContext != null : "matchNodeContext is null."; + if (matchNode instanceof SqlMatchEdge) { + SqlMatchEdge matchEdge = (SqlMatchEdge) matchNode; + if (matchEdge.getMinHop() < 0) { + throw new GeaFlowDSLException( + matchEdge.getParserPosition(), + "The min hop:{} count should greater than 0.", + matchEdge.getMinHop()); + } + if (matchEdge.getMaxHop() != -1 && matchEdge.getMaxHop() < matchEdge.getMinHop()) { + throw new GeaFlowDSLException( + matchEdge.getParserPosition(), + "The max hop: {} count should greater than min hop: {}.", + matchEdge.getMaxHop(), + matchEdge.getMinHop()); + } } - - public void setMatchNodeContext(MatchNodeContext matchNodeContext) { - this.matchNodeContext = matchNodeContext; - } - - @Override - protected RelDataType validateImpl(RelDataType targetRowType) { - assert matchNodeContext != null : "matchNodeContext is null."; - if (matchNode instanceof SqlMatchEdge) { - SqlMatchEdge matchEdge = (SqlMatchEdge) matchNode; - if (matchEdge.getMinHop() < 0) { - throw new GeaFlowDSLException(matchEdge.getParserPosition(), - "The min hop:{} count should greater than 0.", matchEdge.getMinHop()); - } - if (matchEdge.getMaxHop() != -1 && matchEdge.getMaxHop() < matchEdge.getMinHop()) { - throw new GeaFlowDSLException(matchEdge.getParserPosition(), - "The max hop: {} count should greater than min hop: {}.", matchEdge.getMaxHop(), - matchEdge.getMinHop()); - } - } - GQLPathPatternScope pathPatternScope = matchNodeContext.getPathPatternScope(); - RelDataType nodeType = pathPatternScope.resolveTypeByLabels(matchNode, matchNodeContext); - PathRecordType inputPathType = matchNodeContext.getInputPathType(); - PathRecordType outputPathType = inputPathType.addField(matchNode.getName(), nodeType, isCaseSensitive()); - // set outputPathType as the input for the next match node. - matchNodeContext.setInputPathType(outputPathType); - // register match node type - getValidator().registerMatchNodeType(matchNode, nodeType); - - setType(outputPathType); - - if (matchNode.getWhere() != null) { - SqlNode where = matchNode.getWhere(); - - SqlValidatorScope whereScope = getValidator().getScopes(matchNode.getWhere()); - // expand where expression. - SqlNode expandWhere = getValidator().expand(where, whereScope); - if (expandWhere != where) { - matchNode.setWhere(expandWhere); - getValidator().registerScope(expandWhere, whereScope); - where = matchNode.getWhere(); - } - - RelDataType boolType = getValidator().getTypeFactory().createSqlType(SqlTypeName.BOOLEAN); - getValidator().inferUnknownTypes(boolType, whereScope, where); - where.validate(getValidator(), whereScope); - RelDataType conditionType = getValidator().deriveType(whereScope, where); - if (!SqlTypeUtil.inBooleanFamily(conditionType)) { - throw validator.newValidationError(where, RESOURCE.condMustBeBoolean("Filter")); - } - } - return outputPathType; + GQLPathPatternScope pathPatternScope = matchNodeContext.getPathPatternScope(); + RelDataType nodeType = pathPatternScope.resolveTypeByLabels(matchNode, matchNodeContext); + PathRecordType inputPathType = matchNodeContext.getInputPathType(); + PathRecordType outputPathType = + inputPathType.addField(matchNode.getName(), nodeType, isCaseSensitive()); + // set outputPathType as the input for the next match node. + matchNodeContext.setInputPathType(outputPathType); + // register match node type + getValidator().registerMatchNodeType(matchNode, nodeType); + + setType(outputPathType); + + if (matchNode.getWhere() != null) { + SqlNode where = matchNode.getWhere(); + + SqlValidatorScope whereScope = getValidator().getScopes(matchNode.getWhere()); + // expand where expression. + SqlNode expandWhere = getValidator().expand(where, whereScope); + if (expandWhere != where) { + matchNode.setWhere(expandWhere); + getValidator().registerScope(expandWhere, whereScope); + where = matchNode.getWhere(); + } + + RelDataType boolType = getValidator().getTypeFactory().createSqlType(SqlTypeName.BOOLEAN); + getValidator().inferUnknownTypes(boolType, whereScope, where); + where.validate(getValidator(), whereScope); + RelDataType conditionType = getValidator().deriveType(whereScope, where); + if (!SqlTypeUtil.inBooleanFamily(conditionType)) { + throw validator.newValidationError(where, RESOURCE.condMustBeBoolean("Filter")); + } } + return outputPathType; + } - @Override - public SqlNode getNode() { - return matchNode; - } + @Override + public SqlNode getNode() { + return matchNode; + } - public RelDataType getNodeType() { - return getValidator().getMatchNodeType(matchNode); - } + public RelDataType getNodeType() { + return getValidator().getMatchNodeType(matchNode); + } - public static class MatchNodeContext { + public static class MatchNodeContext { - private GQLPathPatternScope pathPatternScope; + private GQLPathPatternScope pathPatternScope; - private boolean isFirstNode; + private boolean isFirstNode; - /** - * The input path record type for current match node. - */ - private PathRecordType inputPathType; + /** The input path record type for current match node. */ + private PathRecordType inputPathType; - private final List resolvedPathPatternTypes = new ArrayList<>(); + private final List resolvedPathPatternTypes = new ArrayList<>(); - public GQLPathPatternScope getPathPatternScope() { - return pathPatternScope; - } + public GQLPathPatternScope getPathPatternScope() { + return pathPatternScope; + } - public void setPathPatternScope(GQLPathPatternScope pathPatternScope) { - this.pathPatternScope = pathPatternScope; - } + public void setPathPatternScope(GQLPathPatternScope pathPatternScope) { + this.pathPatternScope = pathPatternScope; + } - public boolean isFirstNode() { - return isFirstNode; - } + public boolean isFirstNode() { + return isFirstNode; + } - public void setFirstNode(boolean firstNode) { - isFirstNode = firstNode; - } + public void setFirstNode(boolean firstNode) { + isFirstNode = firstNode; + } - public PathRecordType getInputPathType() { - return inputPathType; - } + public PathRecordType getInputPathType() { + return inputPathType; + } - public void setInputPathType(PathRecordType inputPathType) { - this.inputPathType = inputPathType; - } + public void setInputPathType(PathRecordType inputPathType) { + this.inputPathType = inputPathType; + } - public void addResolvedPathPatternType(PathRecordType pathRecordType) { - resolvedPathPatternTypes.add(pathRecordType); - } + public void addResolvedPathPatternType(PathRecordType pathRecordType) { + resolvedPathPatternTypes.add(pathRecordType); + } - public RelDataTypeField getResolvedField(String name, boolean caseSensitive) { - for (PathRecordType pathRecordType : resolvedPathPatternTypes) { - RelDataTypeField field = pathRecordType.getField(name, caseSensitive, false); - if (field != null) { - return field; - } - } - return null; + public RelDataTypeField getResolvedField(String name, boolean caseSensitive) { + for (PathRecordType pathRecordType : resolvedPathPatternTypes) { + RelDataTypeField field = pathRecordType.getField(name, caseSensitive, false); + if (field != null) { + return field; } + } + return null; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLMatchNodeWhereNamespace.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLMatchNodeWhereNamespace.java index 04634864c..de860a48b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLMatchNodeWhereNamespace.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLMatchNodeWhereNamespace.java @@ -25,24 +25,24 @@ public class GQLMatchNodeWhereNamespace extends GQLBaseNamespace { - private final SqlNode matchNode; - - private final GQLMatchNodeNamespace nodeNamespace; - - public GQLMatchNodeWhereNamespace(SqlValidatorImpl validator, SqlNode matchNode, - GQLMatchNodeNamespace nodeNamespace) { - super(validator, matchNode); - this.matchNode = matchNode; - this.nodeNamespace = nodeNamespace; - } - - @Override - protected RelDataType validateImpl(RelDataType targetRowType) { - return nodeNamespace.getNodeType(); - } - - @Override - public SqlNode getNode() { - return matchNode; - } + private final SqlNode matchNode; + + private final GQLMatchNodeNamespace nodeNamespace; + + public GQLMatchNodeWhereNamespace( + SqlValidatorImpl validator, SqlNode matchNode, GQLMatchNodeNamespace nodeNamespace) { + super(validator, matchNode); + this.matchNode = matchNode; + this.nodeNamespace = nodeNamespace; + } + + @Override + protected RelDataType validateImpl(RelDataType targetRowType) { + return nodeNamespace.getNodeType(); + } + + @Override + public SqlNode getNode() { + return matchNode; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLMatchPatternNamespace.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLMatchPatternNamespace.java index 75bff9d12..95cb8356a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLMatchPatternNamespace.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLMatchPatternNamespace.java @@ -24,6 +24,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlNodeList; @@ -44,160 +45,155 @@ public class GQLMatchPatternNamespace extends GQLBaseNamespace { - private final SqlMatchPattern matchPattern; + private final SqlMatchPattern matchPattern; - public GQLMatchPatternNamespace(SqlValidatorImpl validator, SqlMatchPattern matchPattern) { - super(validator, matchPattern); - this.matchPattern = matchPattern; - } + public GQLMatchPatternNamespace(SqlValidatorImpl validator, SqlMatchPattern matchPattern) { + super(validator, matchPattern); + this.matchPattern = matchPattern; + } - @Override - protected RelDataType validateImpl(RelDataType targetRowType) { - SqlValidatorNamespace fromNs = validator.getNamespace(matchPattern.getFrom()); - fromNs.validate(targetRowType); - RelDataType fromType = fromNs.getType(); + @Override + protected RelDataType validateImpl(RelDataType targetRowType) { + SqlValidatorNamespace fromNs = validator.getNamespace(matchPattern.getFrom()); + fromNs.validate(targetRowType); + RelDataType fromType = fromNs.getType(); - if (!(fromType instanceof GraphRecordType) && !(fromType instanceof PathRecordType)) { - throw new GeaFlowDSLException(matchPattern.getParserPosition(), - "Only can match from a graph or match statement"); - } - - SqlNodeList pathPatternNodes = matchPattern.getPathPatterns(); - SqlValidatorScope scope = getValidator().getScopes(matchPattern); - - MatchNodeContext matchNodeContext = new MatchNodeContext(); - - List pathPatternTypes = new ArrayList<>(); - for (SqlNode pathPatternNode : pathPatternNodes) { - RelDataType pathType; - if (pathPatternNode instanceof SqlPathPattern) { - SqlPathPattern pathPattern = (SqlPathPattern) pathPatternNode; - GQLPathPatternNamespace pathPatternNs = - (GQLPathPatternNamespace) validator.getNamespace(pathPatternNode); - pathPatternNs.setMatchNodeContext(matchNodeContext); - pathPattern.validate(validator, scope); - pathType = validator.getValidatedNodeType(pathPattern); - } else { - SqlUnionPathPattern pathPattern = (SqlUnionPathPattern) pathPatternNode; - GQLUnionPathPatternNamespace pathPatternNs = - (GQLUnionPathPatternNamespace) validator.getNamespace(pathPatternNode); - pathPatternNs.setMatchNodeContext(matchNodeContext); - pathPattern.validate(validator, scope); - pathType = validator.getValidatedNodeType(pathPattern); - } - if (!(pathType instanceof PathRecordType)) { - throw new IllegalStateException("PathPattern should return PathRecordType"); - } - matchNodeContext.addResolvedPathPatternType((PathRecordType) pathType); - pathPatternTypes.add((PathRecordType) pathType); - } - PathRecordType matchType = createMatchType(pathPatternTypes); - // join from path with current match path for continue match. - if (fromType instanceof PathRecordType) { - if (matchPattern.isSinglePattern() - && ((PathRecordType) fromType).canConcat(matchType)) { - matchType = ((PathRecordType) fromType).concat(matchType, isCaseSensitive()); - } else { - matchType = ((PathRecordType) fromType).join(matchType, getValidator().getTypeFactory()); - } - } + if (!(fromType instanceof GraphRecordType) && !(fromType instanceof PathRecordType)) { + throw new GeaFlowDSLException( + matchPattern.getParserPosition(), "Only can match from a graph or match statement"); + } - if (matchPattern.getWhere() != null) { - setType(matchType); - SqlNode where = matchPattern.getWhere(); - GQLScope whereScope = (GQLScope) getValidator().getScopes(where); - where = validator.expand(where, whereScope); - matchPattern.setWhere(where); - RelDataType boolType = validator.getTypeFactory().createSqlType(SqlTypeName.BOOLEAN); - getValidator().inferUnknownTypes(boolType, whereScope, where); - RelDataType conditionType = validator.deriveType(whereScope, where); - if (!SqlTypeUtil.inBooleanFamily(conditionType)) { - throw validator.newValidationError(where, RESOURCE.condMustBeBoolean("Filter")); - } - } - setType(matchType); - validateOrderList(); - return matchType; + SqlNodeList pathPatternNodes = matchPattern.getPathPatterns(); + SqlValidatorScope scope = getValidator().getScopes(matchPattern); + + MatchNodeContext matchNodeContext = new MatchNodeContext(); + + List pathPatternTypes = new ArrayList<>(); + for (SqlNode pathPatternNode : pathPatternNodes) { + RelDataType pathType; + if (pathPatternNode instanceof SqlPathPattern) { + SqlPathPattern pathPattern = (SqlPathPattern) pathPatternNode; + GQLPathPatternNamespace pathPatternNs = + (GQLPathPatternNamespace) validator.getNamespace(pathPatternNode); + pathPatternNs.setMatchNodeContext(matchNodeContext); + pathPattern.validate(validator, scope); + pathType = validator.getValidatedNodeType(pathPattern); + } else { + SqlUnionPathPattern pathPattern = (SqlUnionPathPattern) pathPatternNode; + GQLUnionPathPatternNamespace pathPatternNs = + (GQLUnionPathPatternNamespace) validator.getNamespace(pathPatternNode); + pathPatternNs.setMatchNodeContext(matchNodeContext); + pathPattern.validate(validator, scope); + pathType = validator.getValidatedNodeType(pathPattern); + } + if (!(pathType instanceof PathRecordType)) { + throw new IllegalStateException("PathPattern should return PathRecordType"); + } + matchNodeContext.addResolvedPathPatternType((PathRecordType) pathType); + pathPatternTypes.add((PathRecordType) pathType); + } + PathRecordType matchType = createMatchType(pathPatternTypes); + // join from path with current match path for continue match. + if (fromType instanceof PathRecordType) { + if (matchPattern.isSinglePattern() && ((PathRecordType) fromType).canConcat(matchType)) { + matchType = ((PathRecordType) fromType).concat(matchType, isCaseSensitive()); + } else { + matchType = ((PathRecordType) fromType).join(matchType, getValidator().getTypeFactory()); + } } - private PathRecordType createMatchType(List pathPatternTypes) { - List concatPathTypes = new ArrayList<>(); - - for (PathRecordType pathType : pathPatternTypes) { - if (pathType instanceof UnionPathRecordType) { - concatPathTypes.add(pathType); - } else { - assert pathType.firstFieldName().isPresent() && pathType.lastFieldName().isPresent(); - String firstField = pathType.firstFieldName().get(); - String lastField = pathType.lastFieldName().get(); - int i; - for (i = 0; i < concatPathTypes.size(); i++) { - PathRecordType pathType2 = concatPathTypes.get(i); - if (pathType2 instanceof UnionPathRecordType) { - continue; - } - assert pathType2.firstFieldName().isPresent() && pathType2.lastFieldName().isPresent(); - String firstField2 = pathType2.firstFieldName().get(); - String lastField2 = pathType2.lastFieldName().get(); - // pathType2 is "(a) - (b) - (c)", while pathType is "(c) - (d)" - // concat pathType to pathType2 and return "(a) - (b) - (c) - (d)" - if (getValidator().nameMatcher().matches(lastField2, firstField)) { - PathRecordType concatPathType = pathType2.concat(pathType, isCaseSensitive()); - concatPathTypes.set(i, concatPathType); - break; - } else if (getValidator().nameMatcher().matches(lastField, firstField2)) { - // pathType2 is "(c) - (d)", while pathType is "(a) - (b) - (c)" - // concat pathType2 to pathType and return "(a) - (b) - (c) - (d)" - PathRecordType concatPathType = pathType.concat(pathType2, isCaseSensitive()); - concatPathTypes.set(i, concatPathType); - break; - } - } - // If cannot concat, just add to list. - if (i == concatPathTypes.size()) { - concatPathTypes.add(pathType); - } - } + if (matchPattern.getWhere() != null) { + setType(matchType); + SqlNode where = matchPattern.getWhere(); + GQLScope whereScope = (GQLScope) getValidator().getScopes(where); + where = validator.expand(where, whereScope); + matchPattern.setWhere(where); + RelDataType boolType = validator.getTypeFactory().createSqlType(SqlTypeName.BOOLEAN); + getValidator().inferUnknownTypes(boolType, whereScope, where); + RelDataType conditionType = validator.deriveType(whereScope, where); + if (!SqlTypeUtil.inBooleanFamily(conditionType)) { + throw validator.newValidationError(where, RESOURCE.condMustBeBoolean("Filter")); + } + } + setType(matchType); + validateOrderList(); + return matchType; + } + + private PathRecordType createMatchType(List pathPatternTypes) { + List concatPathTypes = new ArrayList<>(); + + for (PathRecordType pathType : pathPatternTypes) { + if (pathType instanceof UnionPathRecordType) { + concatPathTypes.add(pathType); + } else { + assert pathType.firstFieldName().isPresent() && pathType.lastFieldName().isPresent(); + String firstField = pathType.firstFieldName().get(); + String lastField = pathType.lastFieldName().get(); + int i; + for (i = 0; i < concatPathTypes.size(); i++) { + PathRecordType pathType2 = concatPathTypes.get(i); + if (pathType2 instanceof UnionPathRecordType) { + continue; + } + assert pathType2.firstFieldName().isPresent() && pathType2.lastFieldName().isPresent(); + String firstField2 = pathType2.firstFieldName().get(); + String lastField2 = pathType2.lastFieldName().get(); + // pathType2 is "(a) - (b) - (c)", while pathType is "(c) - (d)" + // concat pathType to pathType2 and return "(a) - (b) - (c) - (d)" + if (getValidator().nameMatcher().matches(lastField2, firstField)) { + PathRecordType concatPathType = pathType2.concat(pathType, isCaseSensitive()); + concatPathTypes.set(i, concatPathType); + break; + } else if (getValidator().nameMatcher().matches(lastField, firstField2)) { + // pathType2 is "(c) - (d)", while pathType is "(a) - (b) - (c)" + // concat pathType2 to pathType and return "(a) - (b) - (c) - (d)" + PathRecordType concatPathType = pathType.concat(pathType2, isCaseSensitive()); + concatPathTypes.set(i, concatPathType); + break; + } } - - PathRecordType joinPathType = concatPathTypes.get(0); - for (int i = 1; i < concatPathTypes.size(); i++) { - joinPathType = joinPathType.join(concatPathTypes.get(i), getValidator().getTypeFactory()); + // If cannot concat, just add to list. + if (i == concatPathTypes.size()) { + concatPathTypes.add(pathType); } - return joinPathType; + } } - protected void validateOrderList() { - SqlNodeList orderList = matchPattern.getOrderBy(); - if (orderList == null) { - return; - } - final SqlValidatorScope scope = getValidator().getScopes(orderList); - Objects.requireNonNull(scope); - getValidator().inferUnknownTypes( - validator.getUnknownType(), scope, orderList); - - List expandList = new ArrayList<>(); - for (SqlNode orderItem : orderList) { - SqlNode expandedOrderItem = - getValidator().expand(orderItem, scope); - expandList.add(expandedOrderItem); - } + PathRecordType joinPathType = concatPathTypes.get(0); + for (int i = 1; i < concatPathTypes.size(); i++) { + joinPathType = joinPathType.join(concatPathTypes.get(i), getValidator().getTypeFactory()); + } + return joinPathType; + } - SqlNodeList expandedOrderList = new SqlNodeList( - expandList, - orderList.getParserPosition()); + protected void validateOrderList() { + SqlNodeList orderList = matchPattern.getOrderBy(); + if (orderList == null) { + return; + } + final SqlValidatorScope scope = getValidator().getScopes(orderList); + Objects.requireNonNull(scope); + getValidator().inferUnknownTypes(validator.getUnknownType(), scope, orderList); + + List expandList = new ArrayList<>(); + for (SqlNode orderItem : orderList) { + SqlNode expandedOrderItem = getValidator().expand(orderItem, scope); + expandList.add(expandedOrderItem); + } - matchPattern.setOrderBy(expandedOrderList); - getValidator().registerScope(expandedOrderList, scope); + SqlNodeList expandedOrderList = new SqlNodeList(expandList, orderList.getParserPosition()); - for (SqlNode orderItem : expandedOrderList) { - validator.deriveType(scope, orderItem); - } - } + matchPattern.setOrderBy(expandedOrderList); + getValidator().registerScope(expandedOrderList, scope); - @Override - public SqlNode getNode() { - return matchPattern; + for (SqlNode orderItem : expandedOrderList) { + validator.deriveType(scope, orderItem); } + } + + @Override + public SqlNode getNode() { + return matchPattern; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLPathPatternNamespace.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLPathPatternNamespace.java index dfa1a5321..1a13ea658 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLPathPatternNamespace.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLPathPatternNamespace.java @@ -21,6 +21,7 @@ import java.util.HashSet; import java.util.Set; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.validate.SqlValidatorImpl; @@ -34,67 +35,70 @@ public class GQLPathPatternNamespace extends GQLBaseNamespace { - private final SqlPathPattern pathPattern; + private final SqlPathPattern pathPattern; - private MatchNodeContext matchNodeContext; + private MatchNodeContext matchNodeContext; - public GQLPathPatternNamespace(SqlValidatorImpl validator, SqlPathPattern pathPattern) { - super(validator, pathPattern); - this.pathPattern = pathPattern; - } + public GQLPathPatternNamespace(SqlValidatorImpl validator, SqlPathPattern pathPattern) { + super(validator, pathPattern); + this.pathPattern = pathPattern; + } - public void setMatchNodeContext(MatchNodeContext matchNodeContext) { - this.matchNodeContext = matchNodeContext; - } + public void setMatchNodeContext(MatchNodeContext matchNodeContext) { + this.matchNodeContext = matchNodeContext; + } - @Override - protected RelDataType validateImpl(RelDataType targetRowType) { - assert matchNodeContext != null : "matchNodeContext has not set"; - GQLPathPatternScope scope = (GQLPathPatternScope) getValidator().getScopes(pathPattern); - matchNodeContext.setPathPatternScope(scope); + @Override + protected RelDataType validateImpl(RelDataType targetRowType) { + assert matchNodeContext != null : "matchNodeContext has not set"; + GQLPathPatternScope scope = (GQLPathPatternScope) getValidator().getScopes(pathPattern); + matchNodeContext.setPathPatternScope(scope); - // for match in sub-query, the parentPathType is the output type of the parent match. - PathRecordType parentPathType = PathRecordType.EMPTY; - if (scope.getParent() instanceof GQLSubQueryScope) { - parentPathType = ((GQLSubQueryScope) scope.getParent()).getInputPathType(); - } - PathRecordType outputPathType = null; - Set matchNodeAlias = new HashSet<>(); - boolean isFirstNode = true; - // init the input path type. - matchNodeContext.setInputPathType(parentPathType); - for (SqlNode pathNode : pathPattern.getPathNodes()) { - SqlMatchNode matchNode = (SqlMatchNode) pathNode; - if (matchNodeAlias.contains(matchNode.getName())) { - throw new GeaFlowDSLException(matchNode.getParserPosition(), - "Duplicated node label: {} in the path pattern.", matchNode.getName()); - } else { - matchNodeAlias.add(matchNode.getName()); - } - GQLMatchNodeNamespace nodeNs = (GQLMatchNodeNamespace) validator.getNamespace(matchNode); - matchNodeContext.setFirstNode(isFirstNode); - nodeNs.setMatchNodeContext(matchNodeContext); + // for match in sub-query, the parentPathType is the output type of the parent match. + PathRecordType parentPathType = PathRecordType.EMPTY; + if (scope.getParent() instanceof GQLSubQueryScope) { + parentPathType = ((GQLSubQueryScope) scope.getParent()).getInputPathType(); + } + PathRecordType outputPathType = null; + Set matchNodeAlias = new HashSet<>(); + boolean isFirstNode = true; + // init the input path type. + matchNodeContext.setInputPathType(parentPathType); + for (SqlNode pathNode : pathPattern.getPathNodes()) { + SqlMatchNode matchNode = (SqlMatchNode) pathNode; + if (matchNodeAlias.contains(matchNode.getName())) { + throw new GeaFlowDSLException( + matchNode.getParserPosition(), + "Duplicated node label: {} in the path pattern.", + matchNode.getName()); + } else { + matchNodeAlias.add(matchNode.getName()); + } + GQLMatchNodeNamespace nodeNs = (GQLMatchNodeNamespace) validator.getNamespace(matchNode); + matchNodeContext.setFirstNode(isFirstNode); + nodeNs.setMatchNodeContext(matchNodeContext); - matchNode.validate(validator, scope); - outputPathType = (PathRecordType) validator.getValidatedNodeType(matchNode); + matchNode.validate(validator, scope); + outputPathType = (PathRecordType) validator.getValidatedNodeType(matchNode); - isFirstNode = false; - } - assert outputPathType != null; - // concat output path type with parent match's path type if current path pattern is in sub-query. - return concatParentPathType(parentPathType, outputPathType); + isFirstNode = false; } + assert outputPathType != null; + // concat output path type with parent match's path type if current path pattern is in + // sub-query. + return concatParentPathType(parentPathType, outputPathType); + } - private PathRecordType concatParentPathType(PathRecordType parentPathType, - PathRecordType pathRecordType) { - if (parentPathType == null) { - return pathRecordType; - } - return parentPathType.concat(pathRecordType, isCaseSensitive()); + private PathRecordType concatParentPathType( + PathRecordType parentPathType, PathRecordType pathRecordType) { + if (parentPathType == null) { + return pathRecordType; } + return parentPathType.concat(pathRecordType, isCaseSensitive()); + } - @Override - public SqlNode getNode() { - return pathPattern; - } + @Override + public SqlNode getNode() { + return pathPattern; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLReturnNamespace.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLReturnNamespace.java index 4fb3ee118..a248018df 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLReturnNamespace.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLReturnNamespace.java @@ -28,6 +28,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.*; import org.apache.calcite.sql.fun.SqlStdOperatorTable; @@ -44,201 +45,190 @@ public class GQLReturnNamespace extends GQLBaseNamespace { - private final SqlReturnStatement returnStatement; - - public GQLReturnNamespace(SqlValidatorImpl validator, SqlNode enclosingNode, - SqlReturnStatement returnStatement) { - super(validator, enclosingNode); - this.returnStatement = returnStatement; - } - - @Override - protected RelDataType validateImpl(RelDataType targetRowType) { - assert targetRowType != null; - - GQLReturnNamespace ns = validator.getNamespace(returnStatement).unwrap(GQLReturnNamespace.class); - assert ns.rowType == null; - // validate from - SqlValidatorNamespace fromNs = validator.getNamespace(returnStatement.getFrom()); - fromNs.validate(targetRowType); - //return * are not support - final SqlNodeList selectItems = returnStatement.getReturnList(); - if (selectItems.size() == 1) { - final SqlNode selectItem = selectItems.get(0); - if (selectItem instanceof SqlIdentifier) { - SqlIdentifier id = (SqlIdentifier) selectItem; - if (id.isStar() && (id.names.size() == 1)) { - throw new GeaFlowDSLException("'Return * ' is not support at " + id.getParserPosition()); - } - } + private final SqlReturnStatement returnStatement; + + public GQLReturnNamespace( + SqlValidatorImpl validator, SqlNode enclosingNode, SqlReturnStatement returnStatement) { + super(validator, enclosingNode); + this.returnStatement = returnStatement; + } + + @Override + protected RelDataType validateImpl(RelDataType targetRowType) { + assert targetRowType != null; + + GQLReturnNamespace ns = + validator.getNamespace(returnStatement).unwrap(GQLReturnNamespace.class); + assert ns.rowType == null; + // validate from + SqlValidatorNamespace fromNs = validator.getNamespace(returnStatement.getFrom()); + fromNs.validate(targetRowType); + // return * are not support + final SqlNodeList selectItems = returnStatement.getReturnList(); + if (selectItems.size() == 1) { + final SqlNode selectItem = selectItems.get(0); + if (selectItem instanceof SqlIdentifier) { + SqlIdentifier id = (SqlIdentifier) selectItem; + if (id.isStar() && (id.names.size() == 1)) { + throw new GeaFlowDSLException("'Return * ' is not support at " + id.getParserPosition()); } + } + } - final RelDataType rowType = - validateReturnList(selectItems, returnStatement, targetRowType); - ns.setType(rowType); - validateGroupClause(returnStatement); - handleOffsetFetch(returnStatement.getOffset(), returnStatement.getFetch()); + final RelDataType rowType = validateReturnList(selectItems, returnStatement, targetRowType); + ns.setType(rowType); + validateGroupClause(returnStatement); + handleOffsetFetch(returnStatement.getOffset(), returnStatement.getFetch()); - GQLReturnScope scope = (GQLReturnScope) getValidator().getScopes(returnStatement); - for (SqlNode returnItem : returnStatement.getReturnList()) { - returnItem.validate(getValidator(), scope); - scope.validateExpr(returnItem); - } - validateOrderList(returnStatement); - return rowType; + GQLReturnScope scope = (GQLReturnScope) getValidator().getScopes(returnStatement); + for (SqlNode returnItem : returnStatement.getReturnList()) { + returnItem.validate(getValidator(), scope); + scope.validateExpr(returnItem); } - - protected RelDataType validateReturnList( - final SqlNodeList returnItems, - SqlReturnStatement returnStmt, - RelDataType targetRowType) { - GQLReturnScope scope = (GQLReturnScope) getValidator().getScopes(returnStatement); - - final List expandedReturnItems = new ArrayList<>(); - final Set aliases = new HashSet<>(); - final List> fieldList = new ArrayList<>(); - - for (int i = 0; i < returnItems.size(); i++) { - SqlNode returnItem = returnItems.get(i); - if (returnItem instanceof SqlSelect) { - throw new GeaFlowDSLException("SubQuery in Return statement is not support"); - } else { - //expand select items - SqlNode expanded = validator.expand(returnItem, scope); - final String alias = - SqlValidatorUtil.getAlias(returnItem, aliases.size()); - // If expansion has altered the natural alias, supply an explicit 'AS'. - if (expanded != returnItem) { - String newAlias = - SqlValidatorUtil.getAlias(expanded, aliases.size()); - if (!newAlias.equals(alias)) { - expanded = - SqlStdOperatorTable.AS.createCall( - returnItem.getParserPosition(), - expanded, - new SqlIdentifier(alias, SqlParserPos.ZERO)); - validator.deriveType(scope, expanded); - } - } - expandedReturnItems.add(expanded); - if (aliases.contains(alias)) { - throw new GeaFlowDSLException("Duplicated alias in an Return statement at " - + returnItem.getParserPosition()); - } - aliases.add(alias); - RelDataType targetType = targetRowType.isStruct() - && targetRowType.getFieldCount() - 1 >= i - ? targetRowType.getFieldList().get(i).getType() - : validator.getUnknownType(); - - getValidator().inferUnknownTypes(targetType, scope, expanded); - final RelDataType type = validator.deriveType(scope, expanded); - validator.setValidatedNodeType(expanded, type); - fieldList.add(Pair.of(alias, type)); - } + validateOrderList(returnStatement); + return rowType; + } + + protected RelDataType validateReturnList( + final SqlNodeList returnItems, SqlReturnStatement returnStmt, RelDataType targetRowType) { + GQLReturnScope scope = (GQLReturnScope) getValidator().getScopes(returnStatement); + + final List expandedReturnItems = new ArrayList<>(); + final Set aliases = new HashSet<>(); + final List> fieldList = new ArrayList<>(); + + for (int i = 0; i < returnItems.size(); i++) { + SqlNode returnItem = returnItems.get(i); + if (returnItem instanceof SqlSelect) { + throw new GeaFlowDSLException("SubQuery in Return statement is not support"); + } else { + // expand select items + SqlNode expanded = validator.expand(returnItem, scope); + final String alias = SqlValidatorUtil.getAlias(returnItem, aliases.size()); + // If expansion has altered the natural alias, supply an explicit 'AS'. + if (expanded != returnItem) { + String newAlias = SqlValidatorUtil.getAlias(expanded, aliases.size()); + if (!newAlias.equals(alias)) { + expanded = + SqlStdOperatorTable.AS.createCall( + returnItem.getParserPosition(), + expanded, + new SqlIdentifier(alias, SqlParserPos.ZERO)); + validator.deriveType(scope, expanded); + } } - // Create the new select list with expanded items. Pass through - // the original parser position so that any overall failures can - // still reference the original input text. - SqlNodeList newReturnList = - new SqlNodeList( - expandedReturnItems, - returnItems.getParserPosition()); - if (validator.shouldExpandIdentifiers()) { - returnStmt.setReturnList(newReturnList); + expandedReturnItems.add(expanded); + if (aliases.contains(alias)) { + throw new GeaFlowDSLException( + "Duplicated alias in an Return statement at " + returnItem.getParserPosition()); } - scope.setExpandedReturnList(expandedReturnItems); - getValidator().inferUnknownTypes(targetRowType, scope, newReturnList); - - return validator.getTypeFactory().createStructType(fieldList); + aliases.add(alias); + RelDataType targetType = + targetRowType.isStruct() && targetRowType.getFieldCount() - 1 >= i + ? targetRowType.getFieldList().get(i).getType() + : validator.getUnknownType(); + + getValidator().inferUnknownTypes(targetType, scope, expanded); + final RelDataType type = validator.deriveType(scope, expanded); + validator.setValidatedNodeType(expanded, type); + fieldList.add(Pair.of(alias, type)); + } + } + // Create the new select list with expanded items. Pass through + // the original parser position so that any overall failures can + // still reference the original input text. + SqlNodeList newReturnList = + new SqlNodeList(expandedReturnItems, returnItems.getParserPosition()); + if (validator.shouldExpandIdentifiers()) { + returnStmt.setReturnList(newReturnList); } + scope.setExpandedReturnList(expandedReturnItems); + getValidator().inferUnknownTypes(targetRowType, scope, newReturnList); - protected void validateGroupClause(SqlReturnStatement returnStmt) { - SqlNodeList groupList = returnStmt.getGroupBy(); - if (groupList == null) { - return; - } - final SqlValidatorScope scope = getValidator().getScopes(groupList); - Objects.requireNonNull(scope); - getValidator().inferUnknownTypes( - validator.getUnknownType(), scope, groupList); - - // expand the expression in group list. - List expandedList = new ArrayList<>(); - for (SqlNode groupItem : groupList) { - SqlNode expandedGroupItem = - getValidator().expandReturnGroupOrderExpr(returnStmt, - scope, groupItem); - expandedList.add(expandedGroupItem); - } - groupList = new SqlNodeList(expandedList, groupList.getParserPosition()); - //GROUPING SETS, ROLLUP or CUBE are not support - for (SqlNode node : groupList) { - if (node.getKind() != SqlKind.IDENTIFIER) { - throw new GeaFlowDSLException("Group by non-identifier is not support, actually " + node.getKind()); - } - } - returnStatement.setGroupBy(groupList); - getValidator().registerScope(groupList, scope); + return validator.getTypeFactory().createStructType(fieldList); + } - for (SqlNode groupItem : expandedList) { - scope.validateExpr(groupItem); - } - if (getValidator().getAggregate(groupList) != null) { - throw new GeaFlowDSLException("Aggregation in Group By is not support."); - } + protected void validateGroupClause(SqlReturnStatement returnStmt) { + SqlNodeList groupList = returnStmt.getGroupBy(); + if (groupList == null) { + return; } + final SqlValidatorScope scope = getValidator().getScopes(groupList); + Objects.requireNonNull(scope); + getValidator().inferUnknownTypes(validator.getUnknownType(), scope, groupList); + + // expand the expression in group list. + List expandedList = new ArrayList<>(); + for (SqlNode groupItem : groupList) { + SqlNode expandedGroupItem = + getValidator().expandReturnGroupOrderExpr(returnStmt, scope, groupItem); + expandedList.add(expandedGroupItem); + } + groupList = new SqlNodeList(expandedList, groupList.getParserPosition()); + // GROUPING SETS, ROLLUP or CUBE are not support + for (SqlNode node : groupList) { + if (node.getKind() != SqlKind.IDENTIFIER) { + throw new GeaFlowDSLException( + "Group by non-identifier is not support, actually " + node.getKind()); + } + } + returnStatement.setGroupBy(groupList); + getValidator().registerScope(groupList, scope); - private void handleOffsetFetch(SqlNode offset, SqlNode fetch) { - if (offset instanceof SqlDynamicParam) { - validator.setValidatedNodeType(offset, - validator.getTypeFactory().createSqlType(SqlTypeName.INTEGER)); - } - if (fetch instanceof SqlDynamicParam) { - validator.setValidatedNodeType(fetch, - validator.getTypeFactory().createSqlType(SqlTypeName.INTEGER)); - } + for (SqlNode groupItem : expandedList) { + scope.validateExpr(groupItem); } + if (getValidator().getAggregate(groupList) != null) { + throw new GeaFlowDSLException("Aggregation in Group By is not support."); + } + } - protected void validateOrderList(SqlReturnStatement returnStmt) { - SqlNodeList orderList = returnStmt.getOrderList(); - if (orderList == null) { - return; - } - final SqlValidatorScope scope = getValidator().getScopes(orderList); - Objects.requireNonNull(scope); - getValidator().inferUnknownTypes( - validator.getUnknownType(), scope, orderList); - - List expandList = new ArrayList<>(); - for (SqlNode orderItem : orderList) { - SqlNode expandedOrderItem = - getValidator().expandReturnGroupOrderExpr(returnStmt, - scope, orderItem); - expandList.add(expandedOrderItem); - } + private void handleOffsetFetch(SqlNode offset, SqlNode fetch) { + if (offset instanceof SqlDynamicParam) { + validator.setValidatedNodeType( + offset, validator.getTypeFactory().createSqlType(SqlTypeName.INTEGER)); + } + if (fetch instanceof SqlDynamicParam) { + validator.setValidatedNodeType( + fetch, validator.getTypeFactory().createSqlType(SqlTypeName.INTEGER)); + } + } - SqlNodeList expandedOrderList = new SqlNodeList( - expandList, - orderList.getParserPosition()); + protected void validateOrderList(SqlReturnStatement returnStmt) { + SqlNodeList orderList = returnStmt.getOrderList(); + if (orderList == null) { + return; + } + final SqlValidatorScope scope = getValidator().getScopes(orderList); + Objects.requireNonNull(scope); + getValidator().inferUnknownTypes(validator.getUnknownType(), scope, orderList); + + List expandList = new ArrayList<>(); + for (SqlNode orderItem : orderList) { + SqlNode expandedOrderItem = + getValidator().expandReturnGroupOrderExpr(returnStmt, scope, orderItem); + expandList.add(expandedOrderItem); + } - returnStatement.setOrderBy(expandedOrderList); + SqlNodeList expandedOrderList = new SqlNodeList(expandList, orderList.getParserPosition()); - getValidator().registerScope(expandedOrderList, scope); + returnStatement.setOrderBy(expandedOrderList); - for (SqlNode orderItem : expandedOrderList) { - if (orderItem.getKind() == DESCENDING) { - assert RESOURCE.sQLConformance_OrderByDesc() - .getProperties().get("FeatureDefinition") != null; - scope.validateExpr(((SqlCall) orderItem).operand(0)); - } else { - scope.validateExpr(orderItem); - } - } - } + getValidator().registerScope(expandedOrderList, scope); - @Override - public SqlNode getNode() { - return returnStatement; + for (SqlNode orderItem : expandedOrderList) { + if (orderItem.getKind() == DESCENDING) { + assert RESOURCE.sQLConformance_OrderByDesc().getProperties().get("FeatureDefinition") + != null; + scope.validateExpr(((SqlCall) orderItem).operand(0)); + } else { + scope.validateExpr(orderItem); + } } + } + + @Override + public SqlNode getNode() { + return returnStatement; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLSubQueryNamespace.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLSubQueryNamespace.java index 4ab8555a8..379db7679 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLSubQueryNamespace.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLSubQueryNamespace.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.validator.namespace; import java.util.List; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.sql.SqlKind; @@ -37,73 +38,76 @@ public class GQLSubQueryNamespace extends GQLBaseNamespace { - private final SqlPathPatternSubQuery subQuery; + private final SqlPathPatternSubQuery subQuery; - public GQLSubQueryNamespace(SqlValidatorImpl validator, - SqlPathPatternSubQuery subQuery) { - super(validator, subQuery); - this.subQuery = subQuery; - } + public GQLSubQueryNamespace(SqlValidatorImpl validator, SqlPathPatternSubQuery subQuery) { + super(validator, subQuery); + this.subQuery = subQuery; + } - @Override - protected RelDataType validateImpl(RelDataType targetRowType) { - GQLScope scope = (GQLScope) getValidator().getScopes(subQuery); - assert scope.children.size() == 1; - SqlValidatorNamespace fromNs = scope.children.get(0).getNamespace(); - assert fromNs.getType() instanceof PathRecordType; - PathRecordType inputPathType = (PathRecordType) fromNs.getType(); - SqlPathPattern pathPattern = subQuery.getPathPattern(); - List pathNodes = pathPattern.getPathNodes().getList(); - assert pathNodes.size() > 0; + @Override + protected RelDataType validateImpl(RelDataType targetRowType) { + GQLScope scope = (GQLScope) getValidator().getScopes(subQuery); + assert scope.children.size() == 1; + SqlValidatorNamespace fromNs = scope.children.get(0).getNamespace(); + assert fromNs.getType() instanceof PathRecordType; + PathRecordType inputPathType = (PathRecordType) fromNs.getType(); + SqlPathPattern pathPattern = subQuery.getPathPattern(); + List pathNodes = pathPattern.getPathNodes().getList(); + assert pathNodes.size() > 0; - SqlMatchNode firstNode = (SqlMatchNode) pathNodes.get(0); - validateFirstNode(firstNode, inputPathType); - for (int i = 1; i < pathNodes.size(); i++) { - validateOtherNode((SqlMatchNode) pathNodes.get(i), inputPathType); - } + SqlMatchNode firstNode = (SqlMatchNode) pathNodes.get(0); + validateFirstNode(firstNode, inputPathType); + for (int i = 1; i < pathNodes.size(); i++) { + validateOtherNode((SqlMatchNode) pathNodes.get(i), inputPathType); + } - GQLPathPatternNamespace pathPatternNs = - validator.getNamespace(pathPattern).unwrap(GQLPathPatternNamespace.class); - MatchNodeContext context = new MatchNodeContext(); - pathPatternNs.setMatchNodeContext(context); + GQLPathPatternNamespace pathPatternNs = + validator.getNamespace(pathPattern).unwrap(GQLPathPatternNamespace.class); + MatchNodeContext context = new MatchNodeContext(); + pathPatternNs.setMatchNodeContext(context); - SqlNode returnValue = subQuery.getReturnValue(); - if (returnValue != null) { - GQLScope returnValueScope = (GQLScope) getValidator().getScopes(returnValue); - returnValue = validator.expand(returnValue, returnValueScope); - subQuery.setReturnValue(returnValue); + SqlNode returnValue = subQuery.getReturnValue(); + if (returnValue != null) { + GQLScope returnValueScope = (GQLScope) getValidator().getScopes(returnValue); + returnValue = validator.expand(returnValue, returnValueScope); + subQuery.setReturnValue(returnValue); - returnValue.validate(validator, returnValueScope); - return validator.deriveType(returnValueScope, returnValue); - } - return pathPatternNs.getType(); + returnValue.validate(validator, returnValueScope); + return validator.deriveType(returnValueScope, returnValue); } + return pathPatternNs.getType(); + } - private void validateFirstNode(SqlMatchNode firstNode, PathRecordType inputPathType) { - RelDataTypeField field = inputPathType.getField(firstNode.getName(), - isCaseSensitive(), false); - if (field == null) { - throw new GeaFlowDSLException(firstNode.getParserPosition(), - "Label:{} is not exists in the input match statement", firstNode.getName()); - } - if (firstNode.getKind() != SqlKind.GQL_MATCH_NODE - || field.getType().getSqlTypeName() != SqlTypeName.VERTEX) { - throw new GeaFlowDSLException(firstNode.getParserPosition(), - "SubQuery should start from a vertex, current start label:{} is an edge", firstNode.getName()); - } + private void validateFirstNode(SqlMatchNode firstNode, PathRecordType inputPathType) { + RelDataTypeField field = inputPathType.getField(firstNode.getName(), isCaseSensitive(), false); + if (field == null) { + throw new GeaFlowDSLException( + firstNode.getParserPosition(), + "Label:{} is not exists in the input match statement", + firstNode.getName()); } - - private void validateOtherNode(SqlMatchNode otherNode, PathRecordType inputPathType) { - RelDataTypeField field = inputPathType.getField(otherNode.getName(), - isCaseSensitive(), false); - if (field != null) { - throw new GeaFlowDSLException(otherNode.getParserPosition(), - "Label:{} in SubQuery is already exists in the input match statement", otherNode.getName()); - } + if (firstNode.getKind() != SqlKind.GQL_MATCH_NODE + || field.getType().getSqlTypeName() != SqlTypeName.VERTEX) { + throw new GeaFlowDSLException( + firstNode.getParserPosition(), + "SubQuery should start from a vertex, current start label:{} is an edge", + firstNode.getName()); } + } - @Override - public SqlNode getNode() { - return subQuery; + private void validateOtherNode(SqlMatchNode otherNode, PathRecordType inputPathType) { + RelDataTypeField field = inputPathType.getField(otherNode.getName(), isCaseSensitive(), false); + if (field != null) { + throw new GeaFlowDSLException( + otherNode.getParserPosition(), + "Label:{} in SubQuery is already exists in the input match statement", + otherNode.getName()); } + } + + @Override + public SqlNode getNode() { + return subQuery; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLUnionPathPatternNamespace.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLUnionPathPatternNamespace.java index 5992accfe..061e7ea49 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLUnionPathPatternNamespace.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/GQLUnionPathPatternNamespace.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.validate.SqlValidatorImpl; @@ -33,60 +34,60 @@ public class GQLUnionPathPatternNamespace extends GQLBaseNamespace { - private final SqlUnionPathPattern unionPathPattern; + private final SqlUnionPathPattern unionPathPattern; - private MatchNodeContext matchNodeContext; + private MatchNodeContext matchNodeContext; - public GQLUnionPathPatternNamespace(SqlValidatorImpl validator, SqlUnionPathPattern unionPathPattern) { - super(validator, unionPathPattern); - this.unionPathPattern = unionPathPattern; - } + public GQLUnionPathPatternNamespace( + SqlValidatorImpl validator, SqlUnionPathPattern unionPathPattern) { + super(validator, unionPathPattern); + this.unionPathPattern = unionPathPattern; + } - public void setMatchNodeContext(MatchNodeContext matchNodeContext) { - this.matchNodeContext = matchNodeContext; - } + public void setMatchNodeContext(MatchNodeContext matchNodeContext) { + this.matchNodeContext = matchNodeContext; + } - @Override - protected RelDataType validateImpl(RelDataType targetRowType) { - SqlValidatorScope scope = getValidator().getScopes(unionPathPattern); - List pathRecordTypes = new ArrayList<>(); - validatePathPatternRecursive(unionPathPattern, matchNodeContext, scope, pathRecordTypes); - PathRecordType patternType = new UnionPathRecordType(pathRecordTypes, - this.getValidator().getTypeFactory()); - return patternType; - } + @Override + protected RelDataType validateImpl(RelDataType targetRowType) { + SqlValidatorScope scope = getValidator().getScopes(unionPathPattern); + List pathRecordTypes = new ArrayList<>(); + validatePathPatternRecursive(unionPathPattern, matchNodeContext, scope, pathRecordTypes); + PathRecordType patternType = + new UnionPathRecordType(pathRecordTypes, this.getValidator().getTypeFactory()); + return patternType; + } - private void validatePathPatternRecursive(SqlNode pathPatternNode, - MatchNodeContext matchNodeContext, - SqlValidatorScope scope, - List pathPatternTypes) { - if (pathPatternNode instanceof SqlUnionPathPattern) { - SqlUnionPathPattern unionPathPattern = (SqlUnionPathPattern) pathPatternNode; - validatePathPatternRecursive(unionPathPattern.getLeft(), matchNodeContext, scope, - pathPatternTypes); - validatePathPatternRecursive(unionPathPattern.getRight(), matchNodeContext, scope, - pathPatternTypes); - } - if (pathPatternNode instanceof SqlPathPattern) { - SqlPathPattern pathPattern = (SqlPathPattern) pathPatternNode; - GQLPathPatternNamespace pathPatternNs = - (GQLPathPatternNamespace) validator.getNamespace( - pathPatternNode); - pathPatternNs.setMatchNodeContext(matchNodeContext); - - pathPattern.validate(validator, scope); - RelDataType pathType = validator.getValidatedNodeType(pathPattern); - if (!(pathType instanceof PathRecordType)) { - throw new IllegalStateException("PathPattern should return PathRecordType"); - } - matchNodeContext.addResolvedPathPatternType((PathRecordType) pathType); - pathPatternTypes.add((PathRecordType) pathType); - } + private void validatePathPatternRecursive( + SqlNode pathPatternNode, + MatchNodeContext matchNodeContext, + SqlValidatorScope scope, + List pathPatternTypes) { + if (pathPatternNode instanceof SqlUnionPathPattern) { + SqlUnionPathPattern unionPathPattern = (SqlUnionPathPattern) pathPatternNode; + validatePathPatternRecursive( + unionPathPattern.getLeft(), matchNodeContext, scope, pathPatternTypes); + validatePathPatternRecursive( + unionPathPattern.getRight(), matchNodeContext, scope, pathPatternTypes); } + if (pathPatternNode instanceof SqlPathPattern) { + SqlPathPattern pathPattern = (SqlPathPattern) pathPatternNode; + GQLPathPatternNamespace pathPatternNs = + (GQLPathPatternNamespace) validator.getNamespace(pathPatternNode); + pathPatternNs.setMatchNodeContext(matchNodeContext); - - @Override - public SqlNode getNode() { - return unionPathPattern; + pathPattern.validate(validator, scope); + RelDataType pathType = validator.getValidatedNodeType(pathPattern); + if (!(pathType instanceof PathRecordType)) { + throw new IllegalStateException("PathPattern should return PathRecordType"); + } + matchNodeContext.addResolvedPathPatternType((PathRecordType) pathType); + pathPatternTypes.add((PathRecordType) pathType); } + } + + @Override + public SqlNode getNode() { + return unionPathPattern; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/IdentifierCompleteNamespace.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/IdentifierCompleteNamespace.java index ef9004b52..6ead793e8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/IdentifierCompleteNamespace.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/namespace/IdentifierCompleteNamespace.java @@ -27,24 +27,27 @@ public class IdentifierCompleteNamespace extends IdentifierNamespace { - private final IdentifierNamespace baseNamespace; - private final GQLContext gContext; + private final IdentifierNamespace baseNamespace; + private final GQLContext gContext; - public IdentifierCompleteNamespace(IdentifierNamespace baseNamespace) { - super((GQLValidatorImpl) baseNamespace.getValidator(), baseNamespace.getId(), - baseNamespace.extendList, baseNamespace.getEnclosingNode(), - baseNamespace.getParentScope()); - this.baseNamespace = baseNamespace; - this.gContext = ((GQLValidatorImpl) getValidator()).getGQLContext(); - } + public IdentifierCompleteNamespace(IdentifierNamespace baseNamespace) { + super( + (GQLValidatorImpl) baseNamespace.getValidator(), + baseNamespace.getId(), + baseNamespace.extendList, + baseNamespace.getEnclosingNode(), + baseNamespace.getParentScope()); + this.baseNamespace = baseNamespace; + this.gContext = ((GQLValidatorImpl) getValidator()).getGQLContext(); + } - @Override - protected SqlIdentifier getResolveId(SqlIdentifier id) { - return gContext.completeCatalogObjName(id); - } + @Override + protected SqlIdentifier getResolveId(SqlIdentifier id) { + return gContext.completeCatalogObjName(id); + } - @Override - public SqlNode getNode() { - return baseNamespace.getNode(); - } + @Override + public SqlNode getNode() { + return baseNamespace.getNode(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLPathPatternScope.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLPathPatternScope.java index 7259d09e3..2ff57b88f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLPathPatternScope.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLPathPatternScope.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.validator.scope; import java.util.Collection; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.sql.SqlCall; @@ -36,71 +37,69 @@ public class GQLPathPatternScope extends GQLScope { - public GQLPathPatternScope(SqlValidatorScope parent, SqlCall pathPattern) { - super(parent, pathPattern); - } + public GQLPathPatternScope(SqlValidatorScope parent, SqlCall pathPattern) { + super(parent, pathPattern); + } - @Override - protected boolean ignoreColumnAmbiguous() { - return true; - } + @Override + protected boolean ignoreColumnAmbiguous() { + return true; + } - /** - * Resolve the type for {@link SqlMatchNode}. - */ - public RelDataType resolveTypeByLabels(SqlMatchNode matchNode, MatchNodeContext context) { - boolean isVertex = matchNode.getKind() == SqlKind.GQL_MATCH_NODE; - Collection labels = matchNode.getLabelNames(); - String nodeName = matchNode.getName(); + /** Resolve the type for {@link SqlMatchNode}. */ + public RelDataType resolveTypeByLabels(SqlMatchNode matchNode, MatchNodeContext context) { + boolean isVertex = matchNode.getKind() == SqlKind.GQL_MATCH_NODE; + Collection labels = matchNode.getLabelNames(); + String nodeName = matchNode.getName(); - GraphRecordType graphType = findCurrentGraphType((GQLValidatorImpl) validator, this); - if (graphType == null) { - boolean caseSensitive = ((GQLValidatorImpl) getValidator()).isCaseSensitive(); - RelDataTypeField resolvedField = context.getResolvedField(nodeName, caseSensitive); - if (resolvedField != null) { - return resolvedField.getType(); - } else { - throw new GeaFlowDSLException(matchNode.getParserPosition(), "Match node: {} is not find." - , matchNode.getName()); - } - } else { - if (isVertex) { - return graphType.getVertexType(labels, validator.getTypeFactory()); - } - return graphType.getEdgeType(labels, validator.getTypeFactory()); - } + GraphRecordType graphType = findCurrentGraphType((GQLValidatorImpl) validator, this); + if (graphType == null) { + boolean caseSensitive = ((GQLValidatorImpl) getValidator()).isCaseSensitive(); + RelDataTypeField resolvedField = context.getResolvedField(nodeName, caseSensitive); + if (resolvedField != null) { + return resolvedField.getType(); + } else { + throw new GeaFlowDSLException( + matchNode.getParserPosition(), "Match node: {} is not find.", matchNode.getName()); + } + } else { + if (isVertex) { + return graphType.getVertexType(labels, validator.getTypeFactory()); + } + return graphType.getEdgeType(labels, validator.getTypeFactory()); } + } - public static GraphRecordType findCurrentGraphType(GQLValidatorImpl validator, SqlValidatorScope scope) { - if (scope instanceof ListScope) { - SqlValidatorNamespace namespace = ((ListScope) scope).children.get(0).getNamespace(); - return findCurrentGraphType(validator, namespace); - } - return null; + public static GraphRecordType findCurrentGraphType( + GQLValidatorImpl validator, SqlValidatorScope scope) { + if (scope instanceof ListScope) { + SqlValidatorNamespace namespace = ((ListScope) scope).children.get(0).getNamespace(); + return findCurrentGraphType(validator, namespace); } + return null; + } - /** - * Find the graph using by current scope. - */ - public static GraphRecordType findCurrentGraphType(GQLValidatorImpl validator, SqlValidatorNamespace childNs) { - RelDataType type = childNs.getType(); - if (type instanceof GraphRecordType) { - GraphRecordType graphType = (GraphRecordType) type; - GraphRecordType modifyGraphType = - validator.getCurrentQueryNodeContext().getModifyGraph(graphType.getGraphName()); - if (modifyGraphType != null) { - // If the graph type has been modified by "Let Global" at current query-node-level context, - // then return the modified type. - return modifyGraphType; - } - return graphType; - } - SqlNode inputNode = childNs.getNode(); - SqlValidatorScope inputScope = (validator).getScopes(inputNode); + /** Find the graph using by current scope. */ + public static GraphRecordType findCurrentGraphType( + GQLValidatorImpl validator, SqlValidatorNamespace childNs) { + RelDataType type = childNs.getType(); + if (type instanceof GraphRecordType) { + GraphRecordType graphType = (GraphRecordType) type; + GraphRecordType modifyGraphType = + validator.getCurrentQueryNodeContext().getModifyGraph(graphType.getGraphName()); + if (modifyGraphType != null) { + // If the graph type has been modified by "Let Global" at current query-node-level context, + // then return the modified type. + return modifyGraphType; + } + return graphType; + } + SqlNode inputNode = childNs.getNode(); + SqlValidatorScope inputScope = (validator).getScopes(inputNode); - if (inputScope instanceof ListScope && ((ListScope) inputScope).getChildren().size() == 1) { - return findCurrentGraphType(validator, ((ListScope) inputScope).getChildren().get(0)); - } - return null; + if (inputScope instanceof ListScope && ((ListScope) inputScope).getChildren().size() == 1) { + return findCurrentGraphType(validator, ((ListScope) inputScope).getChildren().get(0)); } + return null; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLReturnGroupByScope.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLReturnGroupByScope.java index 401e7bc88..bbf46144f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLReturnGroupByScope.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLReturnGroupByScope.java @@ -26,20 +26,20 @@ public class GQLReturnGroupByScope extends GQLScope { - private final SqlReturnStatement returnStatement; + private final SqlReturnStatement returnStatement; - public GQLReturnGroupByScope(SqlValidatorScope parent, SqlReturnStatement returnStatement, - SqlNodeList groupBy) { - super(parent, groupBy); - this.returnStatement = returnStatement; - } + public GQLReturnGroupByScope( + SqlValidatorScope parent, SqlReturnStatement returnStatement, SqlNodeList groupBy) { + super(parent, groupBy); + this.returnStatement = returnStatement; + } - @Override - public void validateExpr(SqlNode expr) { - parent.validateExpr(expr); - } + @Override + public void validateExpr(SqlNode expr) { + parent.validateExpr(expr); + } - public SqlReturnStatement getReturnStmt() { - return returnStatement; - } + public SqlReturnStatement getReturnStmt() { + return returnStatement; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLReturnOrderByScope.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLReturnOrderByScope.java index 1591aff11..1a21fd180 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLReturnOrderByScope.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLReturnOrderByScope.java @@ -25,15 +25,15 @@ public class GQLReturnOrderByScope extends GQLScope { - private final SqlNodeList orderBy; + private final SqlNodeList orderBy; - public GQLReturnOrderByScope(SqlValidatorScope parent, SqlNodeList orderBy) { - super(parent, orderBy); - this.orderBy = orderBy; - } + public GQLReturnOrderByScope(SqlValidatorScope parent, SqlNodeList orderBy) { + super(parent, orderBy); + this.orderBy = orderBy; + } - @Override - public void validateExpr(SqlNode expr) { - parent.validateExpr(expr); - } + @Override + public void validateExpr(SqlNode expr) { + parent.validateExpr(expr); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLReturnScope.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLReturnScope.java index 6451176e9..2dad571f8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLReturnScope.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLReturnScope.java @@ -19,15 +19,12 @@ package org.apache.geaflow.dsl.validator.scope; -import com.google.common.base.Suppliers; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Sets; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Supplier; + import org.apache.calcite.linq4j.Linq4j; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlNodeList; @@ -38,156 +35,158 @@ import org.apache.geaflow.dsl.rel.GQLToRelConverter.GQLAggChecker; import org.apache.geaflow.dsl.sqlnode.SqlReturnStatement; -public class GQLReturnScope extends GQLScope { - - private final SqlReturnStatement returnStatement; - - private boolean containAgg = false; - - private List expandedReturnList = null; +import com.google.common.base.Suppliers; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Sets; - private List temporaryGroupExprList; +public class GQLReturnScope extends GQLScope { - public GQLReturnScope(SqlValidatorScope parent, SqlReturnStatement returnStatement) { - super(parent, returnStatement); - this.returnStatement = returnStatement; + private final SqlReturnStatement returnStatement; + + private boolean containAgg = false; + + private List expandedReturnList = null; + + private List temporaryGroupExprList; + + public GQLReturnScope(SqlValidatorScope parent, SqlReturnStatement returnStatement) { + super(parent, returnStatement); + this.returnStatement = returnStatement; + } + + public final Supplier resolved = + Suppliers.memoize( + () -> { + assert temporaryGroupExprList == null; + temporaryGroupExprList = new ArrayList<>(); + try { + return resolve(); + } finally { + temporaryGroupExprList = null; + } + }) + ::get; + + public void setAggMode() { + containAgg = true; + } + + public List getExpandedReturnList() { + return expandedReturnList; + } + + public void setExpandedReturnList(List returnList) { + expandedReturnList = returnList; + } + + private Resolved resolve() { + final ImmutableList.Builder> builder = ImmutableList.builder(); + List extraExprs = ImmutableList.of(); + Map groupExprProjection = ImmutableMap.of(); + if (returnStatement.getGroupBy() != null) { + final SqlNodeList groupList = returnStatement.getGroupBy(); + final SqlValidatorUtil.GroupAnalyzer groupAnalyzer = + new SqlValidatorUtil.GroupAnalyzer(temporaryGroupExprList); + for (SqlNode groupExpr : groupList) { + SqlValidatorUtil.analyzeGroupItem(this, groupAnalyzer, builder, groupExpr); + } + extraExprs = groupAnalyzer.getExtraExprs(); + groupExprProjection = groupAnalyzer.getGroupExprProjection(); } - public final Supplier resolved = - Suppliers.memoize(() -> { - assert temporaryGroupExprList == null; - temporaryGroupExprList = new ArrayList<>(); - try { - return resolve(); - } finally { - temporaryGroupExprList = null; - } - })::get; - - public void setAggMode() { - containAgg = true; + final Set flatGroupSets = Sets.newTreeSet(ImmutableBitSet.COMPARATOR); + for (List groupSet : Linq4j.product(builder.build())) { + flatGroupSets.add(ImmutableBitSet.union(groupSet)); } - public List getExpandedReturnList() { - return expandedReturnList; + // For GROUP BY (), we need a singleton grouping set. + if (flatGroupSets.isEmpty()) { + flatGroupSets.add(ImmutableBitSet.of()); } - public void setExpandedReturnList(List returnList) { - expandedReturnList = returnList; + return new Resolved(extraExprs, temporaryGroupExprList, flatGroupSets, groupExprProjection); + } + + /** + * Returns the expressions that are in the GROUP BY clause (or the SELECT DISTINCT clause, if + * distinct) and that can therefore be referenced without being wrapped in aggregate functions. + * + *

The expressions are fully-qualified, and any "*" in select clauses are expanded. + * + * @return list of grouping expressions + */ + private Pair, ImmutableList> getGroupExprs() { + if (returnStatement.getGroupBy() != null) { + if (temporaryGroupExprList != null) { + // we are in the middle of resolving + return Pair.of(ImmutableList.of(), ImmutableList.copyOf(temporaryGroupExprList)); + } else { + final Resolved resolved = this.resolved.get(); + return Pair.of(resolved.extraExprList, resolved.groupExprList); + } + } else { + return Pair.of(ImmutableList.of(), ImmutableList.of()); } + } - private Resolved resolve() { - final ImmutableList.Builder> builder = - ImmutableList.builder(); - List extraExprs = ImmutableList.of(); - Map groupExprProjection = ImmutableMap.of(); - if (returnStatement.getGroupBy() != null) { - final SqlNodeList groupList = returnStatement.getGroupBy(); - final SqlValidatorUtil.GroupAnalyzer groupAnalyzer = - new SqlValidatorUtil.GroupAnalyzer(temporaryGroupExprList); - for (SqlNode groupExpr : groupList) { - SqlValidatorUtil.analyzeGroupItem(this, groupAnalyzer, builder, - groupExpr); - } - extraExprs = groupAnalyzer.getExtraExprs(); - groupExprProjection = groupAnalyzer.getGroupExprProjection(); - } - - final Set flatGroupSets = - Sets.newTreeSet(ImmutableBitSet.COMPARATOR); - for (List groupSet : Linq4j.product(builder.build())) { - flatGroupSets.add(ImmutableBitSet.union(groupSet)); - } - - // For GROUP BY (), we need a singleton grouping set. - if (flatGroupSets.isEmpty()) { - flatGroupSets.add(ImmutableBitSet.of()); - } - - return new Resolved(extraExprs, temporaryGroupExprList, flatGroupSets, groupExprProjection); + private static boolean allContain(List bitSets, int bit) { + for (ImmutableBitSet bitSet : bitSets) { + if (!bitSet.get(bit)) { + return false; + } } + return true; + } - /** - * Returns the expressions that are in the GROUP BY clause (or the SELECT - * DISTINCT clause, if distinct) and that can therefore be referenced - * without being wrapped in aggregate functions. - * - *

The expressions are fully-qualified, and any "*" in select clauses are - * expanded. - * - * @return list of grouping expressions - */ - private Pair, ImmutableList> getGroupExprs() { - if (returnStatement.getGroupBy() != null) { - if (temporaryGroupExprList != null) { - // we are in the middle of resolving - return Pair.of(ImmutableList.of(), - ImmutableList.copyOf(temporaryGroupExprList)); - } else { - final Resolved resolved = this.resolved.get(); - return Pair.of(resolved.extraExprList, resolved.groupExprList); - } - } else { - return Pair.of(ImmutableList.of(), ImmutableList.of()); - } + public boolean checkAggregateExpr(SqlNode expr, boolean deep) { + // Fully-qualify any identifiers in expr. + if (deep) { + expr = validator.expand(expr, this); } - private static boolean allContain(List bitSets, int bit) { - for (ImmutableBitSet bitSet : bitSets) { - if (!bitSet.get(bit)) { - return false; - } - } - return true; + // Make sure expression is valid, throws if not. + Pair, ImmutableList> pair = getGroupExprs(); + final GQLAggChecker aggChecker = + new GQLAggChecker(validator, this, pair.left, pair.right, false); + if (deep) { + expr.accept(aggChecker); } - public boolean checkAggregateExpr(SqlNode expr, boolean deep) { - // Fully-qualify any identifiers in expr. - if (deep) { - expr = validator.expand(expr, this); - } - - // Make sure expression is valid, throws if not. - Pair, ImmutableList> pair = getGroupExprs(); - final GQLAggChecker aggChecker = - new GQLAggChecker(validator, this, pair.left, pair.right, false); - if (deep) { - expr.accept(aggChecker); - } - - // Return whether expression exactly matches one of the group - // expressions. - return aggChecker.isGroupExpr(expr); - } + // Return whether expression exactly matches one of the group + // expressions. + return aggChecker.isGroupExpr(expr); + } - public void validateExpr(SqlNode expr) { - if (containAgg) { - checkAggregateExpr(expr, true); - } else { - super.validateExpr(expr); - } + public void validateExpr(SqlNode expr) { + if (containAgg) { + checkAggregateExpr(expr, true); + } else { + super.validateExpr(expr); + } + } + + public static class Resolved { + public final ImmutableList extraExprList; + public final ImmutableList groupExprList; + public final ImmutableBitSet groupSet; + public final ImmutableList groupSets; + public final Map groupExprProjection; + + Resolved( + List extraExprList, + List groupExprList, + Iterable groupSets, + Map groupExprProjection) { + this.extraExprList = ImmutableList.copyOf(extraExprList); + this.groupExprList = ImmutableList.copyOf(groupExprList); + this.groupSet = ImmutableBitSet.range(groupExprList.size()); + this.groupSets = ImmutableList.copyOf(groupSets); + this.groupExprProjection = ImmutableMap.copyOf(groupExprProjection); } - - public static class Resolved { - public final ImmutableList extraExprList; - public final ImmutableList groupExprList; - public final ImmutableBitSet groupSet; - public final ImmutableList groupSets; - public final Map groupExprProjection; - - Resolved(List extraExprList, List groupExprList, - Iterable groupSets, - Map groupExprProjection) { - this.extraExprList = ImmutableList.copyOf(extraExprList); - this.groupExprList = ImmutableList.copyOf(groupExprList); - this.groupSet = ImmutableBitSet.range(groupExprList.size()); - this.groupSets = ImmutableList.copyOf(groupSets); - this.groupExprProjection = ImmutableMap.copyOf(groupExprProjection); - } - - public boolean isNullable(int i) { - return i < groupExprList.size() && !allContain(groupSets, i); - } + public boolean isNullable(int i) { + return i < groupExprList.size() && !allContain(groupSets, i); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLScope.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLScope.java index 5bd95aef3..3b0613dc9 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLScope.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLScope.java @@ -27,24 +27,24 @@ public class GQLScope extends ListScope { - protected final SqlNode node; + protected final SqlNode node; - public GQLScope(SqlValidatorScope parent, SqlNode node) { - super(parent); - this.node = node; - } + public GQLScope(SqlValidatorScope parent, SqlNode node) { + super(parent); + this.node = node; + } - @Override - public SqlNode getNode() { - return node; - } + @Override + public SqlNode getNode() { + return node; + } - @Override - public SqlValidatorScope getOperandScope(SqlCall call) { - SqlValidatorScope scope = ((GQLValidatorImpl) validator).getScopes(call); - if (scope != null) { - return scope; - } - return super.getOperandScope(call); + @Override + public SqlValidatorScope getOperandScope(SqlCall call) { + SqlValidatorScope scope = ((GQLValidatorImpl) validator).getScopes(call); + if (scope != null) { + return scope; } + return super.getOperandScope(call); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLSubQueryScope.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLSubQueryScope.java index abe4bd650..ae2634f00 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLSubQueryScope.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLSubQueryScope.java @@ -25,12 +25,12 @@ public class GQLSubQueryScope extends GQLScope { - public GQLSubQueryScope(SqlValidatorScope parent, SqlNode node) { - super(parent, node); - } + public GQLSubQueryScope(SqlValidatorScope parent, SqlNode node) { + super(parent, node); + } - public PathRecordType getInputPathType() { - assert children.size() == 1 : "GQLSubQueryScope must have only one child namespace."; - return (PathRecordType) children.get(0).getNamespace().getType(); - } + public PathRecordType getInputPathType() { + assert children.size() == 1 : "GQLSubQueryScope must have only one child namespace."; + return (PathRecordType) children.get(0).getNamespace().getType(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLWithBodyScope.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLWithBodyScope.java index 35d9ff7cf..83172351e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLWithBodyScope.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/validator/scope/GQLWithBodyScope.java @@ -24,7 +24,7 @@ public class GQLWithBodyScope extends GQLScope { - public GQLWithBodyScope(SqlValidatorScope parent, SqlNode node) { - super(parent, node); - } + public GQLWithBodyScope(SqlValidatorScope parent, SqlNode node) { + super(parent, node); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLFieldExtractorTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLFieldExtractorTest.java index 3a2ddaa2f..ddd3096de 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLFieldExtractorTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLFieldExtractorTest.java @@ -25,49 +25,51 @@ public class GQLFieldExtractorTest { - private static final String GRAPH_G1 = "create graph g1(" - + "vertex user(" - + " id bigint ID," - + "name varchar" - + ")," - + "vertex person(" - + " id bigint ID," - + "name varchar," - + "gender int," - + "age integer" - + ")," - + "edge knows(" - + " src_id bigint SOURCE ID," - + " target_id bigint DESTINATION ID," - + " time bigint TIMESTAMP," - + " weight double" - + ")" - + ")"; + private static final String GRAPH_G1 = + "create graph g1(" + + "vertex user(" + + " id bigint ID," + + "name varchar" + + ")," + + "vertex person(" + + " id bigint ID," + + "name varchar," + + "gender int," + + "age integer" + + ")," + + "edge knows(" + + " src_id bigint SOURCE ID," + + " target_id bigint DESTINATION ID," + + " time bigint TIMESTAMP," + + " weight double" + + ")" + + ")"; - @Test - public void testGraphMatchFieldPrune() { - PlanTester.build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:person WHERE a.age > 18)" + - "-[e:knows WHERE e.weight > 0.5]" + - "->(b:user WHERE b.id != 0 AND name like 'MARKO')\n") - .toRel() - .checkFilteredFields("{a=[null], b=[null], e=[null]}") - .opt(GraphMatchFieldPruneRule.INSTANCE) - .checkFilteredFields("{a=[a.age], b=[b.id, b.name], e=[e.weight]}"); - } - - @Test - public void testProjectFieldPrune() { - PlanTester.build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:person)-[e:knows]->(b:user)\n" + - " RETURN e.src_id as src_id, e.target_id as target_id," + - " a.gender as a_gender, b.id as b_id") - .toRel() - .checkFilteredFields("{a=[null], b=[null], e=[null]}") - .opt(ProjectFieldPruneRule.INSTANCE) - .checkFilteredFields("{a=[a.gender], b=[b.id], e=[e.src_id, e.target_id]}"); - } + @Test + public void testGraphMatchFieldPrune() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:person WHERE a.age > 18)" + + "-[e:knows WHERE e.weight > 0.5]" + + "->(b:user WHERE b.id != 0 AND name like 'MARKO')\n") + .toRel() + .checkFilteredFields("{a=[null], b=[null], e=[null]}") + .opt(GraphMatchFieldPruneRule.INSTANCE) + .checkFilteredFields("{a=[a.age], b=[b.id, b.name], e=[e.weight]}"); + } + @Test + public void testProjectFieldPrune() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:person)-[e:knows]->(b:user)\n" + + " RETURN e.src_id as src_id, e.target_id as target_id," + + " a.gender as a_gender, b.id as b_id") + .toRel() + .checkFilteredFields("{a=[null], b=[null], e=[null]}") + .opt(ProjectFieldPruneRule.INSTANCE) + .checkFilteredFields("{a=[a.gender], b=[b.id], e=[e.src_id, e.target_id]}"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLToRelConverterTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLToRelConverterTest.java index d9128093e..bb0fa4928 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLToRelConverterTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLToRelConverterTest.java @@ -28,63 +28,61 @@ public class GQLToRelConverterTest { - private static final String GRAPH_G1 = "create graph g1(" - + "vertex user(" - + " id bigint ID," - + "name varchar" - + ")," - + "vertex person(" - + " id bigint ID," - + "name varchar," - + "gender int," - + "age integer" - + ")," - + "edge knows(" - + " src_id bigint SOURCE ID," - + " target_id bigint DESTINATION ID," - + " time bigint TIMESTAMP," - + " weight double" - + ")" - + ")"; - - @Test - public void testMatchPattern() { - PlanTester.build() - .gql("match(a:user)-[e:knows]->(b:user)") - .toRel() - .checkRelNode( - "LogicalGraphMatch(path=[(a:user)-[e:knows]->(b:user)])\n" - + " LogicalGraphScan(table=[default.g0])\n" - ); - } - - @Test - public void testMatchPattern2() { - PlanTester.build() - .gql("MATCH (a:user)-[e:knows]->(b)->(c)-[]->(d)<-[e2]-(f)") - .toRel() - .checkRelNode( - "LogicalGraphMatch(path=[(a:user)-[e:knows]->(b:)-[e_col_1:]->(c:)-[e_col_2:]->(d:)-[e2:]<-(f:)])\n" - + " LogicalGraphScan(table=[default.g0])\n" - ); - } - - @Test - public void testReturn() { - PlanTester.build() - .gql("match(a:user)-[e:knows]->(b:user)\n" - + "RETURN a.id AS a_id, e, b.id") - .toRel() - .checkRelNode( - "LogicalProject(a_id=[$0.id], e=[$1], id=[$2.id])\n" - + " LogicalGraphMatch(path=[(a:user)-[e:knows]->(b:user)])\n" - + " LogicalGraphScan(table=[default.g0])\n" - ); - } - - @Test - public void testReturn2() { - String graph = "create graph g1(" + private static final String GRAPH_G1 = + "create graph g1(" + + "vertex user(" + + " id bigint ID," + + "name varchar" + + ")," + + "vertex person(" + + " id bigint ID," + + "name varchar," + + "gender int," + + "age integer" + + ")," + + "edge knows(" + + " src_id bigint SOURCE ID," + + " target_id bigint DESTINATION ID," + + " time bigint TIMESTAMP," + + " weight double" + + ")" + + ")"; + + @Test + public void testMatchPattern() { + PlanTester.build() + .gql("match(a:user)-[e:knows]->(b:user)") + .toRel() + .checkRelNode( + "LogicalGraphMatch(path=[(a:user)-[e:knows]->(b:user)])\n" + + " LogicalGraphScan(table=[default.g0])\n"); + } + + @Test + public void testMatchPattern2() { + PlanTester.build() + .gql("MATCH (a:user)-[e:knows]->(b)->(c)-[]->(d)<-[e2]-(f)") + .toRel() + .checkRelNode( + "LogicalGraphMatch(path=[(a:user)-[e:knows]->(b:)-[e_col_1:]->(c:)-[e_col_2:]->(d:)-[e2:]<-(f:)])\n" + + " LogicalGraphScan(table=[default.g0])\n"); + } + + @Test + public void testReturn() { + PlanTester.build() + .gql("match(a:user)-[e:knows]->(b:user)\n" + "RETURN a.id AS a_id, e, b.id") + .toRel() + .checkRelNode( + "LogicalProject(a_id=[$0.id], e=[$1], id=[$2.id])\n" + + " LogicalGraphMatch(path=[(a:user)-[e:knows]->(b:user)])\n" + + " LogicalGraphScan(table=[default.g0])\n"); + } + + @Test + public void testReturn2() { + String graph = + "create graph g1(" + "vertex user(" + " id bigint ID," + "name varchar" @@ -108,895 +106,859 @@ public void testReturn2() { + " weight double" + ")" + ")"; - PlanTester.build() - .registerGraph(graph) - .gql("MATCH (a:user|person WHERE id = 1)-[e:knows|follow]->(b:user)\n" - + "RETURN a, e, b") - .toRel() - .checkRelNode( - "LogicalProject(a=[$0], e=[$1], b=[$2])\n" - + " LogicalGraphMatch(path=[(a:user|person) where =(a.id, 1) " - + "-[e:knows|follow]->(b:user)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testVertexScan() { - PlanTester.build() - .gql("select id from user") - .toRel() - .checkRelNode( - "LogicalProject(id=[$0])\n" - + " LogicalTableScan(table=[[default, user]])\n" - ) - .opt(new TableScanToGraphRule()) - .checkRelNode( - "LogicalProject(id=[$0])\n" - + " LogicalProject(f0=[user.id], f1=[user.~label], f2=[user.name], f3=[user" - + ".age])\n" - + " LogicalGraphMatch(path=[(user:user)])\n" - + " LogicalGraphScan(table=[null])\n" - ); - } - - @Test - public void testVertexIdFilterSimplify() { - PlanTester.build() - .gql("MATCH (a:user where id = 1)-[e:knows]-(b:user)\n" + PlanTester.build() + .registerGraph(graph) + .gql("MATCH (a:user|person WHERE id = 1)-[e:knows|follow]->(b:user)\n" + "RETURN a, e, b") + .toRel() + .checkRelNode( + "LogicalProject(a=[$0], e=[$1], b=[$2])\n" + + " LogicalGraphMatch(path=[(a:user|person) where =(a.id, 1) " + + "-[e:knows|follow]->(b:user)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testVertexScan() { + PlanTester.build() + .gql("select id from user") + .toRel() + .checkRelNode("LogicalProject(id=[$0])\n" + " LogicalTableScan(table=[[default, user]])\n") + .opt(new TableScanToGraphRule()) + .checkRelNode( + "LogicalProject(id=[$0])\n" + + " LogicalProject(f0=[user.id], f1=[user.~label], f2=[user.name], f3=[user" + + ".age])\n" + + " LogicalGraphMatch(path=[(user:user)])\n" + + " LogicalGraphScan(table=[null])\n"); + } + + @Test + public void testVertexIdFilterSimplify() { + PlanTester.build() + .gql( + "MATCH (a:user where id = 1)-[e:knows]-(b:user)\n" + "RETURN a.id as a_id, e.weight as weight, b.id as b_id") - .toRel() - .checkRelNode( - "LogicalProject(a_id=[$0.id], weight=[$1.weight], b_id=[$2.id])\n" - + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]-(b:user)])\n" - + " LogicalGraphScan(table=[default.g0])\n" - ) - .opt(MatchIdFilterSimplifyRule.INSTANCE) - .checkRelNode( - "LogicalProject(a_id=[$0.id], weight=[$1.weight], b_id=[$2.id])\n" - + " LogicalGraphMatch(path=[(a:user)-[e:knows]-(b:user)])\n" - + " LogicalGraphScan(table=[default.g0])\n" - ); - } - - @Test - public void testMatchEdgeLabelRemove() { - PlanTester.build() - .gql("MATCH (a:user where id = 1)-[e:knows]-(b:user) WHERE e.~label = 'knows' " + .toRel() + .checkRelNode( + "LogicalProject(a_id=[$0.id], weight=[$1.weight], b_id=[$2.id])\n" + + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]-(b:user)])\n" + + " LogicalGraphScan(table=[default.g0])\n") + .opt(MatchIdFilterSimplifyRule.INSTANCE) + .checkRelNode( + "LogicalProject(a_id=[$0.id], weight=[$1.weight], b_id=[$2.id])\n" + + " LogicalGraphMatch(path=[(a:user)-[e:knows]-(b:user)])\n" + + " LogicalGraphScan(table=[default.g0])\n"); + } + + @Test + public void testMatchEdgeLabelRemove() { + PlanTester.build() + .gql( + "MATCH (a:user where id = 1)-[e:knows]-(b:user) WHERE e.~label = 'knows' " + "or e.~label = 'created'\n" + "RETURN a.id as a_id, e.weight as weight, b.id as b_id") - .toRel() - .checkRelNode( - "LogicalProject(a_id=[$0.id], weight=[$1.weight], b_id=[$2.id])\n" - + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]-(b:user) " - + "where OR(=($1.~label, _UTF-16LE'knows'), =($1.~label, _UTF-16LE'created'))" - + " ])\n" - + " LogicalGraphScan(table=[default.g0])\n" - ) - .opt(MatchEdgeLabelFilterRemoveRule.INSTANCE, FilterMatchNodeTransposeRule.INSTANCE, FilterToMatchRule.INSTANCE) - .checkRelNode( - "LogicalProject(a_id=[$0.id], weight=[$1.weight], b_id=[$2.id])\n" - + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]-(b:user)])\n" - + " LogicalGraphScan(table=[default.g0])\n" - ); - } - - @Test - public void testFilter() { - PlanTester.build() - .gql("match(a:user)-[e:knows]->(b:user)\n" + .toRel() + .checkRelNode( + "LogicalProject(a_id=[$0.id], weight=[$1.weight], b_id=[$2.id])\n" + + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]-(b:user) " + + "where OR(=($1.~label, _UTF-16LE'knows'), =($1.~label, _UTF-16LE'created'))" + + " ])\n" + + " LogicalGraphScan(table=[default.g0])\n") + .opt( + MatchEdgeLabelFilterRemoveRule.INSTANCE, + FilterMatchNodeTransposeRule.INSTANCE, + FilterToMatchRule.INSTANCE) + .checkRelNode( + "LogicalProject(a_id=[$0.id], weight=[$1.weight], b_id=[$2.id])\n" + + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]-(b:user)])\n" + + " LogicalGraphScan(table=[default.g0])\n"); + } + + @Test + public void testFilter() { + PlanTester.build() + .gql( + "match(a:user)-[e:knows]->(b:user)\n" + "RETURN a.id AS a_id, e, b.id\n" + "THEN FILTER a_id = 1 AND e.src_id = 1") - .toRel() - .checkRelNode( - "LogicalFilter(condition=[AND(=($0, 1), =($1.src_id, 1))])\n" - + " LogicalProject(a_id=[$0.id], e=[$1], id=[$2.id])\n" - + " LogicalGraphMatch(path=[(a:user)-[e:knows]->(b:user)])\n" - + " LogicalGraphScan(table=[default.g0])\n" - ); - } - - @Test - public void testFilter2() { - PlanTester.build() - .gql("match(a:user)-[e:knows]->(b:user)\n" + .toRel() + .checkRelNode( + "LogicalFilter(condition=[AND(=($0, 1), =($1.src_id, 1))])\n" + + " LogicalProject(a_id=[$0.id], e=[$1], id=[$2.id])\n" + + " LogicalGraphMatch(path=[(a:user)-[e:knows]->(b:user)])\n" + + " LogicalGraphScan(table=[default.g0])\n"); + } + + @Test + public void testFilter2() { + PlanTester.build() + .gql( + "match(a:user)-[e:knows]->(b:user)\n" + "RETURN a.id as a_id, e, b.id\n" + "THEN FILTER a_id = 1 OR id = 1 AND CAST(e.weight as int) = 1") - .toRel() - .checkRelNode( - "LogicalFilter(condition=[OR(=($0, 1), AND(=($2, 1), =(CAST($1.weight):INTEGER, 1)))])\n" - + " LogicalProject(a_id=[$0.id], e=[$1], id=[$2.id])\n" - + " LogicalGraphMatch(path=[(a:user)-[e:knows]->(b:user)])\n" - + " LogicalGraphScan(table=[default.g0])\n" - ); - } - - @Test - public void testReturnOrderBy1() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) " + .toRel() + .checkRelNode( + "LogicalFilter(condition=[OR(=($0, 1), AND(=($2, 1), =(CAST($1.weight):INTEGER," + + " 1)))])\n" + + " LogicalProject(a_id=[$0.id], e=[$1], id=[$2.id])\n" + + " LogicalGraphMatch(path=[(a:user)-[e:knows]->(b:user)])\n" + + " LogicalGraphScan(table=[default.g0])\n"); + } + + @Test + public void testReturnOrderBy1() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) " + "RETURN a.name as name, b.id as _id, b.age as age order by name DESC") - .toRel() - .checkRelNode( - "LogicalSort(sort0=[$0], dir0=[DESC])\n" - + " LogicalProject(name=[$0.name], _id=[$2.id], age=[$2.age])\n" - + " LogicalGraphMatch(path=[(a:person) where =(a.id, _UTF-16LE'1') " - + "-[e:knows]->(b:person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnOrderBy2() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) " + .toRel() + .checkRelNode( + "LogicalSort(sort0=[$0], dir0=[DESC])\n" + + " LogicalProject(name=[$0.name], _id=[$2.id], age=[$2.age])\n" + + " LogicalGraphMatch(path=[(a:person) where =(a.id, _UTF-16LE'1') " + + "-[e:knows]->(b:person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnOrderBy2() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) " + "RETURN b.name as name, b.id as _id, a.age as a_age, b.age as b_age order by" + " b_age + 10 DESC Limit 10") - .toRel() - .checkRelNode( - "LogicalProject(name=[$0], _id=[$1], a_age=[$2], b_age=[$3])\n" - + " LogicalSort(sort0=[$4], dir0=[DESC], fetch=[10])\n" - + " LogicalProject(name=[$2.name], _id=[$2.id], a_age=[$0.age], b_age=[$2" - + ".age], EXPR$4=[+($2.age, 10)])\n" - + " LogicalGraphMatch(path=[(a:person) where =(a.id, _UTF-16LE'1') " - + "-[e:knows]->(b:person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnOrderBy3() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user)\n" + .toRel() + .checkRelNode( + "LogicalProject(name=[$0], _id=[$1], a_age=[$2], b_age=[$3])\n" + + " LogicalSort(sort0=[$4], dir0=[DESC], fetch=[10])\n" + + " LogicalProject(name=[$2.name], _id=[$2.id], a_age=[$0.age], b_age=[$2" + + ".age], EXPR$4=[+($2.age, 10)])\n" + + " LogicalGraphMatch(path=[(a:person) where =(a.id, _UTF-16LE'1') " + + "-[e:knows]->(b:person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnOrderBy3() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user)\n" + "RETURN a as _a, b.id as b_id, weight Order by _a.id * 10 DESC Limit 10") - .toRel() - .checkRelNode( - "LogicalProject(_a=[$0], b_id=[$1], weight=[$2])\n" - + " LogicalSort(sort0=[$3], dir0=[DESC], fetch=[10])\n" - + " LogicalProject(_a=[$0], b_id=[$2.id], weight=[$1.weight], EXPR$3=[*($0" - + ".id, 10)])\n" - + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]->" - + "(b:user)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnOrderBy4() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user)\n" + .toRel() + .checkRelNode( + "LogicalProject(_a=[$0], b_id=[$1], weight=[$2])\n" + + " LogicalSort(sort0=[$3], dir0=[DESC], fetch=[10])\n" + + " LogicalProject(_a=[$0], b_id=[$2.id], weight=[$1.weight], EXPR$3=[*($0" + + ".id, 10)])\n" + + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]->" + + "(b:user)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnOrderBy4() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user)\n" + "RETURN a as _a, b.id as b_id, weight Order by a.id DESC Limit 10") - .toRel() - .checkRelNode( - "LogicalProject(_a=[$0], b_id=[$1], weight=[$2])\n" - + " LogicalSort(sort0=[$3], dir0=[DESC], fetch=[10])\n" - + " LogicalProject(_a=[$0], b_id=[$2.id], weight=[$1.weight], id=[$0.id])\n" - + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]->" - + "(b:user)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnStmt1() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:user where a.id = 1)-[e:knows where e.weight > 0.4]->(b:user) RETURN a") - .toRel() - .checkRelNode( - "LogicalProject(a=[$0])\n" - + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]-> where >" - + "(e.weight, 0.4) (b:user)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnStmt2() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("Match (a:user WHERE name = 'marko')<-[e:knows]-(b:person) WHERE a.name <> b.name RETURN e") - .toRel() - .checkRelNode( - "LogicalProject(e=[$1])\n" - + " LogicalGraphMatch(path=[(a:user) where =(a.name, _UTF-16LE'marko') -[e:knows]<-(b:person) " - + "where <>($0.name, $2.name) ])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnStmt3() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:person WHERE a.age > 18) - (b: person) RETURN b") - .toRel() - .checkRelNode( - "LogicalProject(b=[$2])\n" - + " LogicalGraphMatch(path=[(a:person) where >(a.age, 18) -[e_col_1:]-(b:person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnStmt4() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) RETURN a.name as name, b.id as b_id") - .toRel() - .checkRelNode( - "LogicalProject(name=[$0.name], b_id=[$2.id])\n" - + " LogicalGraphMatch(path=[(a:person) where =(a.id, _UTF-16LE'1') " - + "-[e:knows]->(b:person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnStmt5() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) RETURN a.name as name, b.id as b_id, b.age * 10 as amt") - .toRel() - .checkRelNode( - "LogicalProject(name=[$0.name], b_id=[$2.id], amt=[*($2.age, 10)])\n" - + " LogicalGraphMatch(path=[(a:person) where =(a.id, _UTF-16LE'1') " - + "-[e:knows]->(b:person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnStmt6() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) RETURN b.name as name, cast(b.id as int) as _id") - .toRel() - .checkRelNode( - "LogicalProject(name=[$2.name], _id=[CAST($2.id):INTEGER NOT NULL])\n" - + " LogicalGraphMatch(path=[(a:person) where =(a.id, _UTF-16LE'1') -[e:knows]->(b:person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnStmt7() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:person WHERE id = '1')-[e:knows]->(b:person) " + .toRel() + .checkRelNode( + "LogicalProject(_a=[$0], b_id=[$1], weight=[$2])\n" + + " LogicalSort(sort0=[$3], dir0=[DESC], fetch=[10])\n" + + " LogicalProject(_a=[$0], b_id=[$2.id], weight=[$1.weight], id=[$0.id])\n" + + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]->" + + "(b:user)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnStmt1() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql("MATCH (a:user where a.id = 1)-[e:knows where e.weight > 0.4]->(b:user) RETURN a") + .toRel() + .checkRelNode( + "LogicalProject(a=[$0])\n" + + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]-> where >" + + "(e.weight, 0.4) (b:user)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnStmt2() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "Match (a:user WHERE name = 'marko')<-[e:knows]-(b:person) WHERE a.name <> b.name" + + " RETURN e") + .toRel() + .checkRelNode( + "LogicalProject(e=[$1])\n" + + " LogicalGraphMatch(path=[(a:user) where =(a.name, _UTF-16LE'marko')" + + " -[e:knows]<-(b:person) where <>($0.name, $2.name) ])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnStmt3() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql("MATCH (a:person WHERE a.age > 18) - (b: person) RETURN b") + .toRel() + .checkRelNode( + "LogicalProject(b=[$2])\n" + + " LogicalGraphMatch(path=[(a:person) where >(a.age, 18)" + + " -[e_col_1:]-(b:person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnStmt4() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) RETURN a.name as name, b.id as" + + " b_id") + .toRel() + .checkRelNode( + "LogicalProject(name=[$0.name], b_id=[$2.id])\n" + + " LogicalGraphMatch(path=[(a:person) where =(a.id, _UTF-16LE'1') " + + "-[e:knows]->(b:person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnStmt5() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) RETURN a.name as name, b.id as" + + " b_id, b.age * 10 as amt") + .toRel() + .checkRelNode( + "LogicalProject(name=[$0.name], b_id=[$2.id], amt=[*($2.age, 10)])\n" + + " LogicalGraphMatch(path=[(a:person) where =(a.id, _UTF-16LE'1') " + + "-[e:knows]->(b:person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnStmt6() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) RETURN b.name as name," + + " cast(b.id as int) as _id") + .toRel() + .checkRelNode( + "LogicalProject(name=[$2.name], _id=[CAST($2.id):INTEGER NOT NULL])\n" + + " LogicalGraphMatch(path=[(a:person) where =(a.id, _UTF-16LE'1')" + + " -[e:knows]->(b:person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnStmt7() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:person WHERE id = '1')-[e:knows]->(b:person) " + "RETURN b.name as name, case when b.gender = 0 then '0' else '1' end as _id") - .toRel() - .checkRelNode( - "LogicalProject(name=[$2.name], _id=[CASE(=($2.gender, 0), _UTF-16LE'0', " - + "_UTF-16LE'1')])\n" - + " LogicalGraphMatch(path=[(a:person) where =(a.id, _UTF-16LE'1') " - + "-[e:knows]->(b:person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnStmt8() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("Match (a1:user WHERE name like 'marko')-[e1]->(b1:person) return b1.name AS b_id") - .toRel() - .checkRelNode( - "LogicalProject(b_id=[$2.name])\n" - + " LogicalGraphMatch(path=[(a1:user) where LIKE(a1.name, _UTF-16LE'marko')" - + " -[e1:]->(b1:person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnStmt9() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("Match (a:user)-[e]->(b:person WHERE name = 'lop') return a.id") - .toRel() - .checkRelNode( - "LogicalProject(id=[$0.id])\n" - + " LogicalGraphMatch(path=[(a:user)-[e:]->(b:person) where =(b.name, " - + "_UTF-16LE'lop') ])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnOrderBy5() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user)\n" + .toRel() + .checkRelNode( + "LogicalProject(name=[$2.name], _id=[CASE(=($2.gender, 0), _UTF-16LE'0', " + + "_UTF-16LE'1')])\n" + + " LogicalGraphMatch(path=[(a:person) where =(a.id, _UTF-16LE'1') " + + "-[e:knows]->(b:person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnStmt8() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql("Match (a1:user WHERE name like 'marko')-[e1]->(b1:person) return b1.name AS b_id") + .toRel() + .checkRelNode( + "LogicalProject(b_id=[$2.name])\n" + + " LogicalGraphMatch(path=[(a1:user) where LIKE(a1.name, _UTF-16LE'marko')" + + " -[e1:]->(b1:person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnStmt9() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql("Match (a:user)-[e]->(b:person WHERE name = 'lop') return a.id") + .toRel() + .checkRelNode( + "LogicalProject(id=[$0.id])\n" + + " LogicalGraphMatch(path=[(a:user)-[e:]->(b:person) where =(b.name, " + + "_UTF-16LE'lop') ])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnOrderBy5() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user)\n" + "RETURN a as _a, b.id as b_id, weight Order by weight DESC Limit 10") - .toRel() - .checkRelNode( - "LogicalSort(sort0=[$2], dir0=[DESC], fetch=[10])\n" - + " LogicalProject(_a=[$0], b_id=[$2.id], weight=[$1.weight])\n" - + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]->" - + "(b:user)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnOrderBy6() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user)\n" + .toRel() + .checkRelNode( + "LogicalSort(sort0=[$2], dir0=[DESC], fetch=[10])\n" + + " LogicalProject(_a=[$0], b_id=[$2.id], weight=[$1.weight])\n" + + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]->" + + "(b:user)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnOrderBy6() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user)\n" + "RETURN a as _a, e as _e, b.id as b_id Order by e DESC, a Limit 10") - .toRel() - .checkRelNode( - "LogicalSort(sort0=[$1], sort1=[$0], dir0=[DESC], dir1=[ASC], fetch=[10])\n" - + " LogicalProject(_a=[$0], _e=[$1], b_id=[$2.id])\n" - + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]->" - + "(b:user)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnOrderBy7() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user|person)\n" + .toRel() + .checkRelNode( + "LogicalSort(sort0=[$1], sort1=[$0], dir0=[DESC], dir1=[ASC], fetch=[10])\n" + + " LogicalProject(_a=[$0], _e=[$1], b_id=[$2.id])\n" + + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]->" + + "(b:user)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnOrderBy7() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user|person)\n" + "RETURN a as _a, e.time as e_time, gender Order by " + "case when b.gender > 25 then '1' else '0' end DESC " + "Limit 10") - .toRel() - .checkRelNode( - "LogicalProject(_a=[$0], e_time=[$1], gender=[$2])\n" - + " LogicalSort(sort0=[$3], dir0=[DESC], fetch=[10])\n" - + " LogicalProject(_a=[$0], e_time=[$1.time], gender=[$2.gender], " - + "EXPR$3=[CASE(>($2.gender, 25), _UTF-16LE'1', _UTF-16LE'0')])\n" - + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]->" - + "(b:user|person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnOrderBy8() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user|person)\n" + .toRel() + .checkRelNode( + "LogicalProject(_a=[$0], e_time=[$1], gender=[$2])\n" + + " LogicalSort(sort0=[$3], dir0=[DESC], fetch=[10])\n" + + " LogicalProject(_a=[$0], e_time=[$1.time], gender=[$2.gender], " + + "EXPR$3=[CASE(>($2.gender, 25), _UTF-16LE'1', _UTF-16LE'0')])\n" + + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]->" + + "(b:user|person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnOrderBy8() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user|person)\n" + "RETURN a as _a, e.time as e_time, gender Order by " + "case when gender > 25 then '1' else '0' end DESC " + "Limit 10") - .toRel() - .checkRelNode( - "LogicalProject(_a=[$0], e_time=[$1], gender=[$2])\n" - + " LogicalSort(sort0=[$3], dir0=[DESC], fetch=[10])\n" - + " LogicalProject(_a=[$0], e_time=[$1.time], gender=[$2.gender], " - + "EXPR$3=[CASE(>($2.gender, 25), _UTF-16LE'1', _UTF-16LE'0')])\n" - + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]->" - + "(b:user|person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnOrderBy9() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user|person)\n" + .toRel() + .checkRelNode( + "LogicalProject(_a=[$0], e_time=[$1], gender=[$2])\n" + + " LogicalSort(sort0=[$3], dir0=[DESC], fetch=[10])\n" + + " LogicalProject(_a=[$0], e_time=[$1.time], gender=[$2.gender], " + + "EXPR$3=[CASE(>($2.gender, 25), _UTF-16LE'1', _UTF-16LE'0')])\n" + + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]->" + + "(b:user|person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnOrderBy9() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user|person)\n" + "RETURN a as _a, e as _e, gender Order by " + "a DESC, e DESC " + "Limit 10") - .toRel() - .checkRelNode( - "LogicalSort(sort0=[$0], sort1=[$1], dir0=[DESC], dir1=[DESC], fetch=[10])\n" - + " LogicalProject(_a=[$0], _e=[$1], gender=[$2.gender])\n" - + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]->" - + "(b:user|person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnOrderBy10() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user|person)\n" + .toRel() + .checkRelNode( + "LogicalSort(sort0=[$0], sort1=[$1], dir0=[DESC], dir1=[DESC], fetch=[10])\n" + + " LogicalProject(_a=[$0], _e=[$1], gender=[$2.gender])\n" + + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]->" + + "(b:user|person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnOrderBy10() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user|person)\n" + "RETURN a as _a, e as _e, gender Order by " + "_e DESC, _a DESC " + "Limit 10") - .toRel() - .checkRelNode( - "LogicalSort(sort0=[$1], sort1=[$0], dir0=[DESC], dir1=[DESC], fetch=[10])\n" - + " LogicalProject(_a=[$0], _e=[$1], gender=[$2.gender])\n" - + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]->" - + "(b:user|person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnOrderBy11() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user|person)\n" + .toRel() + .checkRelNode( + "LogicalSort(sort0=[$1], sort1=[$0], dir0=[DESC], dir1=[DESC], fetch=[10])\n" + + " LogicalProject(_a=[$0], _e=[$1], gender=[$2.gender])\n" + + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]->" + + "(b:user|person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnOrderBy11() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user|person)\n" + "RETURN a as _a, e as _e, gender Order by " + "_a.id DESC, _e.time + 1000" + "Limit 10") - .toRel() - .checkRelNode( - "LogicalProject(_a=[$0], _e=[$1], gender=[$2])\n" - + " LogicalSort(sort0=[$3], sort1=[$4], dir0=[DESC], dir1=[ASC], fetch=[10])\n" - + " LogicalProject(_a=[$0], _e=[$1], gender=[$2.gender], id=[$0.id], " - + "EXPR$4=[+($1.time, 1000)])\n" - + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]->" - + "(b:user|person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnGroupBy1() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:person)-[e:knows where e.weight > 0.4]->(b:person) " + .toRel() + .checkRelNode( + "LogicalProject(_a=[$0], _e=[$1], gender=[$2])\n" + + " LogicalSort(sort0=[$3], sort1=[$4], dir0=[DESC], dir1=[ASC], fetch=[10])\n" + + " LogicalProject(_a=[$0], _e=[$1], gender=[$2.gender], id=[$0.id], " + + "EXPR$4=[+($1.time, 1000)])\n" + + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]->" + + "(b:user|person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnGroupBy1() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:person)-[e:knows where e.weight > 0.4]->(b:person) " + "RETURN a.id as a_id, SUM(e.weight * 10) as amt GROUP BY a_id ORDER BY a_id") - .toRel() - .checkRelNode( - "LogicalSort(sort0=[$0], dir0=[ASC])\n" - + " LogicalAggregate(group=[{0}], amt=[SUM($1)])\n" - + " LogicalProject(a_id=[$0.id], $f1=[*($1.weight, 10)])\n" - + " LogicalGraphMatch(path=[(a:person)-[e:knows]-> where >(e.weight, 0" - + ".4) (b:person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnGroupBy2() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:person)-[e:knows]->(b:person) " + .toRel() + .checkRelNode( + "LogicalSort(sort0=[$0], dir0=[ASC])\n" + + " LogicalAggregate(group=[{0}], amt=[SUM($1)])\n" + + " LogicalProject(a_id=[$0.id], $f1=[*($1.weight, 10)])\n" + + " LogicalGraphMatch(path=[(a:person)-[e:knows]-> where >(e.weight, 0" + + ".4) (b:person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnGroupBy2() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:person)-[e:knows]->(b:person) " + "RETURN a as _a, e as _e, b as _b group by _a, _e, _b") - .toRel() - .checkRelNode( - "LogicalAggregate(group=[{0, 1, 2}])\n" - + " LogicalProject(_a=[$0], _e=[$1], _b=[$2])\n" - + " LogicalGraphMatch(path=[(a:person)-[e:knows]->(b:person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnGroupBy3() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:person)-[e:knows]->(b:person) " + .toRel() + .checkRelNode( + "LogicalAggregate(group=[{0, 1, 2}])\n" + + " LogicalProject(_a=[$0], _e=[$1], _b=[$2])\n" + + " LogicalGraphMatch(path=[(a:person)-[e:knows]->(b:person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnGroupBy3() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:person)-[e:knows]->(b:person) " + "RETURN COUNT(a.name) as a_name, SUM(e.weight) as e_weight, " + "MAX(b.age) as b_age_max, MIN(b.age) as b_age_min, " + "AVG(b.age) as b_age_avg, b as _b " + "group by _b order by _b") - .toRel() - .checkRelNode( - "LogicalSort(sort0=[$5], dir0=[ASC])\n" - + " LogicalProject(a_name=[$1], e_weight=[$2], b_age_max=[$3], " - + "b_age_min=[$4], b_age_avg=[$5], _b=[$0])\n" - + " LogicalAggregate(group=[{0}], a_name=[COUNT($1)], e_weight=[SUM($2)], " - + "b_age_max=[MAX($3)], b_age_min=[MIN($3)], b_age_avg=[AVG($3)])\n" - + " LogicalProject(_b=[$2], $f1=[$0.name], $f2=[$1.weight], $f3=[$2" - + ".age])\n" - + " LogicalGraphMatch(path=[(a:person)-[e:knows]->(b:person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnGroupBy4() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:person)-[e:knows]->(b:person) " + .toRel() + .checkRelNode( + "LogicalSort(sort0=[$5], dir0=[ASC])\n" + + " LogicalProject(a_name=[$1], e_weight=[$2], b_age_max=[$3], " + + "b_age_min=[$4], b_age_avg=[$5], _b=[$0])\n" + + " LogicalAggregate(group=[{0}], a_name=[COUNT($1)], e_weight=[SUM($2)], " + + "b_age_max=[MAX($3)], b_age_min=[MIN($3)], b_age_avg=[AVG($3)])\n" + + " LogicalProject(_b=[$2], $f1=[$0.name], $f2=[$1.weight], $f3=[$2" + + ".age])\n" + + " LogicalGraphMatch(path=[(a:person)-[e:knows]->(b:person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnGroupBy4() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:person)-[e:knows]->(b:person) " + "RETURN a as _a, e as _e, b.id as b_id group by b.id, e, a") - .toRel() - .checkRelNode( - "LogicalProject(_a=[$2], _e=[$1], b_id=[$0])\n" - + " LogicalAggregate(group=[{0, 1, 2}])\n" - + " LogicalProject(b_id=[$2.id], _e=[$1], _a=[$0])\n" - + " LogicalGraphMatch(path=[(a:person)-[e:knows]->(b:person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testReturnGroupBy5() { - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql("MATCH (a:user)-[e:knows]->(b:person) " + .toRel() + .checkRelNode( + "LogicalProject(_a=[$2], _e=[$1], b_id=[$0])\n" + + " LogicalAggregate(group=[{0, 1, 2}])\n" + + " LogicalProject(b_id=[$2.id], _e=[$1], _a=[$0])\n" + + " LogicalGraphMatch(path=[(a:person)-[e:knows]->(b:person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testReturnGroupBy5() { + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql( + "MATCH (a:user)-[e:knows]->(b:person) " + "RETURN `time` as e_time, age group by `time`, age") - .toRel() - .checkRelNode( - "LogicalAggregate(group=[{0, 1}])\n" - + " LogicalProject(e_time=[$1.time], age=[$2.age])\n" - + " LogicalGraphMatch(path=[(a:user)-[e:knows]->(b:person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testAggregateSum() { - String script = "MATCH (a:person)-[e:knows]->(b:person) " + .toRel() + .checkRelNode( + "LogicalAggregate(group=[{0, 1}])\n" + + " LogicalProject(e_time=[$1.time], age=[$2.age])\n" + + " LogicalGraphMatch(path=[(a:user)-[e:knows]->(b:person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testAggregateSum() { + String script = + "MATCH (a:person)-[e:knows]->(b:person) " + "RETURN a.id as a_id, b.id as b_id, SUM(ALL e.weight) as e_sum GROUP BY a_id, b_id"; - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql(script) - .toRel() - .checkRelNode( - "LogicalAggregate(group=[{0, 1}], e_sum=[SUM($2)])\n" - + " LogicalProject(a_id=[$0.id], b_id=[$2.id], $f2=[$1.weight])\n" - + " LogicalGraphMatch(path=[(a:person)-[e:knows]->(b:person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testAggregateSum2() { - String script = "MATCH (a:person)-[e:knows]->(b:person) " + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql(script) + .toRel() + .checkRelNode( + "LogicalAggregate(group=[{0, 1}], e_sum=[SUM($2)])\n" + + " LogicalProject(a_id=[$0.id], b_id=[$2.id], $f2=[$1.weight])\n" + + " LogicalGraphMatch(path=[(a:person)-[e:knows]->(b:person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testAggregateSum2() { + String script = + "MATCH (a:person)-[e:knows]->(b:person) " + "RETURN b.name as name, SUM(ALL b.age) as amt group by name order by name"; - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql(script) - .toRel() - .checkRelNode( - "LogicalSort(sort0=[$0], dir0=[ASC])\n" - + " LogicalAggregate(group=[{0}], amt=[SUM($1)])\n" - + " LogicalProject(name=[$2.name], $f1=[$2.age])\n" - + " LogicalGraphMatch(path=[(a:person)-[e:knows]->(b:person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testAggregateSum3() { - String script = "MATCH (a:person)-[e:knows where e.weight > 0.4]->(b:person) " + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql(script) + .toRel() + .checkRelNode( + "LogicalSort(sort0=[$0], dir0=[ASC])\n" + + " LogicalAggregate(group=[{0}], amt=[SUM($1)])\n" + + " LogicalProject(name=[$2.name], $f1=[$2.age])\n" + + " LogicalGraphMatch(path=[(a:person)-[e:knows]->(b:person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testAggregateSum3() { + String script = + "MATCH (a:person)-[e:knows where e.weight > 0.4]->(b:person) " + "RETURN a.id, SUM(ALL e.weight) * 10 as amt GROUP BY a.id order by a.id"; - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql(script) - .toRel() - .checkRelNode( - "LogicalSort(sort0=[$0], dir0=[ASC])\n" - + " LogicalProject(id=[$0], amt=[*($1, 10)])\n" - + " LogicalAggregate(group=[{0}], agg#0=[SUM($1)])\n" - + " LogicalProject(id=[$0.id], $f1=[$1.weight])\n" - + " LogicalGraphMatch(path=[(a:person)-[e:knows]-> where >(e.weight, 0.4) (b:person)])\n" - + " LogicalGraphScan(table=[default.g1])\n" - ); - } - - @Test - public void testGQLLet1() { - String script = "MATCH(a:person)\n" + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql(script) + .toRel() + .checkRelNode( + "LogicalSort(sort0=[$0], dir0=[ASC])\n" + + " LogicalProject(id=[$0], amt=[*($1, 10)])\n" + + " LogicalAggregate(group=[{0}], agg#0=[SUM($1)])\n" + + " LogicalProject(id=[$0.id], $f1=[$1.weight])\n" + + " LogicalGraphMatch(path=[(a:person)-[e:knows]-> where >(e.weight, 0.4)" + + " (b:person)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testGQLLet1() { + String script = + "MATCH(a:person)\n" + " Let a.weight = a.age / 10\n" + " Let a.weight = a.weight * 2\n" + " Let a.ratio = 1.0"; - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql(script) - .toRel() - .checkRelNode( - "LogicalGraphMatch(path=[(a:person) PathModify([a=VERTEX{id:$0.id, ~label:$0.~label, " - + "name:$0.name, gender:$0.gender, age:$0.age, weight:CAST(/($0.age, 10)):JavaType(class java.lang.Integer)}]) " - + "PathModify([a=VERTEX{id:$0.id, ~label:$0.~label, name:$0.name, gender:$0.gender, age:$0.age, " - + "weight:CAST(*($0.weight, 2)):JavaType(class java.lang.Integer)}]) PathModify([a=VERTEX{id:$0" - + ".id, ~label:$0.~label, name:$0.name, gender:$0.gender, age:$0.age, weight:$0.weight, ratio:1" - + ".0}])])\n" - + " LogicalGraphScan(table=[default.g1])\n"); - } - - @Test - public void testGQLLet2() { - String script = "MATCH(a:person)\n" - + " Let Global a.weight = a.age / 10.0\n" - + " Return a.id, a.weight\n"; - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql(script) - .toRel() - .checkRelNode( - "LogicalProject(id=[$0.id], weight=[$0.weight])\n" - + " LogicalGraphMatch(path=[(a:person) PathModify([a=VERTEX{id:$0.id, ~label:$0.~label, name:$0" - + ".name, gender:$0.gender, age:$0.age, weight:CAST(/($0.age, 10.0)):" - + "JavaType(class java.math.BigDecimal)}])])\n" - + " LogicalGraphScan(table=[default.g1])\n"); - } - - @Test - public void testGQLLet3() { - String script = "MATCH(a:person)" + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql(script) + .toRel() + .checkRelNode( + "LogicalGraphMatch(path=[(a:person) PathModify([a=VERTEX{id:$0.id, ~label:$0.~label," + + " name:$0.name, gender:$0.gender, age:$0.age, weight:CAST(/($0.age," + + " 10)):JavaType(class java.lang.Integer)}]) PathModify([a=VERTEX{id:$0.id," + + " ~label:$0.~label, name:$0.name, gender:$0.gender, age:$0.age," + + " weight:CAST(*($0.weight, 2)):JavaType(class java.lang.Integer)}])" + + " PathModify([a=VERTEX{id:$0.id, ~label:$0.~label, name:$0.name," + + " gender:$0.gender, age:$0.age, weight:$0.weight, ratio:1.0}])])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testGQLLet2() { + String script = + "MATCH(a:person)\n" + " Let Global a.weight = a.age / 10.0\n" + " Return a.id, a.weight\n"; + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql(script) + .toRel() + .checkRelNode( + "LogicalProject(id=[$0.id], weight=[$0.weight])\n" + + " LogicalGraphMatch(path=[(a:person) PathModify([a=VERTEX{id:$0.id," + + " ~label:$0.~label, name:$0.name, gender:$0.gender, age:$0.age," + + " weight:CAST(/($0.age, 10.0)):JavaType(class java.math.BigDecimal)}])])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testGQLLet3() { + String script = + "MATCH(a:person)" + " Let a.weight = a.age / 100," + "Let a.ext = 1" + "Return a.id, a.weight, a.ext"; - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql(script) - .toRel() - .checkRelNode( - "LogicalProject(id=[$0.id], weight=[$0.weight], ext=[$0.ext])\n" - + " LogicalGraphMatch(path=[(a:person) PathModify([a=VERTEX{id:$0.id, ~label:$0.~label, name:$0" - + ".name, gender:$0.gender, age:$0.age, weight:CAST(/($0.age, 100)):JavaType(class java.lang" - + ".Integer)}]) PathModify([a=VERTEX{id:$0.id, ~label:$0.~label, name:$0.name, gender:$0.gender, " - + "age:$0.age, weight:$0.weight, ext:1}])])\n" - + " LogicalGraphScan(table=[default.g1])\n"); - } - - @Test - public void testContinueMatch() { - String script = "Match(a)-[e1]->(b)\n" - + "Let a.weight = e1.weight / 10\n" - + "Match(b) - (c)"; - - PlanTester - .build() - .gql(script) - .toRel() - .checkRelNode( - "LogicalGraphMatch(path=[(a:)-[e1:]->(b:) PathModify([a=VERTEX{id:$0.id, ~label:$0.~label, " - + "name:$0.name, age:$0.age, weight:CAST(/($1.weight, 10)):JavaType(class java.lang.Double)}])" - + "-[e_col_2:]-(c:)])\n" - + " LogicalGraphScan(table=[default.g0])\n"); - } - - @Test - public void testGQLWith() { - String script = "With p as (Select id from t0)\n" + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql(script) + .toRel() + .checkRelNode( + "LogicalProject(id=[$0.id], weight=[$0.weight], ext=[$0.ext])\n" + + " LogicalGraphMatch(path=[(a:person) PathModify([a=VERTEX{id:$0.id," + + " ~label:$0.~label, name:$0.name, gender:$0.gender, age:$0.age," + + " weight:CAST(/($0.age, 100)):JavaType(class java.lang.Integer)}])" + + " PathModify([a=VERTEX{id:$0.id, ~label:$0.~label, name:$0.name," + + " gender:$0.gender, age:$0.age, weight:$0.weight, ext:1}])])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testContinueMatch() { + String script = "Match(a)-[e1]->(b)\n" + "Let a.weight = e1.weight / 10\n" + "Match(b) - (c)"; + + PlanTester.build() + .gql(script) + .toRel() + .checkRelNode( + "LogicalGraphMatch(path=[(a:)-[e1:]->(b:) PathModify([a=VERTEX{id:$0.id," + + " ~label:$0.~label, name:$0.name, age:$0.age, weight:CAST(/($1.weight," + + " 10)):JavaType(class java.lang.Double)}])-[e_col_2:]-(c:)])\n" + + " LogicalGraphScan(table=[default.g0])\n"); + } + + @Test + public void testGQLWith() { + String script = + "With p as (Select id from t0)\n" + "Match (a where a.id = p.id) -[e] -> (b)\n" + "Return a.id as a_id, b.id as b_id"; - PlanTester - .build() - .gql(script) - .toRel() - .checkRelNode( - "LogicalParameterizedRelNode\n" + " LogicalProject(id=[$0])\n" - + " LogicalTableScan(table=[[default, t0]])\n" - + " LogicalProject(a_id=[$0.id], b_id=[$2.id])\n" - + " LogicalGraphMatch(path=[(a:) where =(a.id, CAST($$0):BIGINT) -[e:]->(b:)])\n" - + " LogicalGraphScan(table=[default.g0])\n"); - } - - @Test - public void testGQLFromWith() { - String script = "Select a_id From(\n" + PlanTester.build() + .gql(script) + .toRel() + .checkRelNode( + "LogicalParameterizedRelNode\n" + + " LogicalProject(id=[$0])\n" + + " LogicalTableScan(table=[[default, t0]])\n" + + " LogicalProject(a_id=[$0.id], b_id=[$2.id])\n" + + " LogicalGraphMatch(path=[(a:) where =(a.id, CAST($$0):BIGINT) -[e:]->(b:)])\n" + + " LogicalGraphScan(table=[default.g0])\n"); + } + + @Test + public void testGQLFromWith() { + String script = + "Select a_id From(\n" + "With p as (Select id from t0)\n" + "Match (a where a.id = p.id) -[e] -> (b)\n" + "Return a.id as a_id, b.id as b_id" + ")"; - PlanTester - .build() - .gql(script) - .toRel() - .checkRelNode( - "LogicalProject(a_id=[$0])\n" + " LogicalParameterizedRelNode\n" - + " LogicalProject(id=[$0])\n" - + " LogicalTableScan(table=[[default, t0]])\n" - + " LogicalProject(a_id=[$0.id], b_id=[$2.id])\n" - + " LogicalGraphMatch(path=[(a:) where =(a.id, CAST($$0):BIGINT) -[e:]->(b:)])\n" - + " LogicalGraphScan(table=[default.g0])\n"); - } - - @Test - public void testGQLComplexMatchWithPathConcat() { - String script1 = "Match(a)-(b), (b) - (c)" - + "RETURN b, c, a.id"; - - PlanTester - .build() - .gql(script1) - .toRel() - .checkRelNode( - "LogicalProject(b=[$2], c=[$4], id=[$0.id])\n" - + " LogicalGraphMatch(path=[(a:)-[e_col_1:]-(b:)-[e_col_3:]-(c:)])\n" - + " LogicalGraphScan(table=[default.g0])\n"); - - String script2 = "Match(b: person) - (c), (a)-(b: user)" - + "RETURN b, c, a.id"; - - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql(script2) - .toRel() - .checkRelNode( - "LogicalProject(b=[$2], c=[$4], id=[$0.id])\n" - + " LogicalGraphMatch(path=[(a:)-[e_col_3:]-(b:user) where =(b.~label, _UTF-16LE'person') -[e_col_1:]-(c:)" - + "])\n" - + " LogicalGraphScan(table=[default.g1])\n"); - - String script3 = "Match(b: person) - (c where c.name = 'marko'), (a)-(b: user)" - + "RETURN b, c, a.id"; - PlanTester - .build() - .registerGraph(GRAPH_G1) - .gql(script3) - .toRel() - .checkRelNode( - "LogicalProject(b=[$2], c=[$4], id=[$0.id])\n" - + " LogicalGraphMatch(path=[(a:)-[e_col_3:]-(b:user) where =(b.~label, _UTF-16LE'person') -[e_col_1:]-" - + "(c:) where =(c.name, _UTF-16LE'marko') ])\n" - + " LogicalGraphScan(table=[default.g1])\n"); - } - - @Test - public void testGQLComplexMatchWithPathNotConcat() { - String script1 = "Match(a)-(b), (b) - (c), (b) - (d) - (f)" - + "RETURN b, c, a, d, f"; - - PlanTester - .build() - .gql(script1) - .toRel() - .checkRelNode( - "LogicalProject(b=[$2], c=[$4], a=[$0], d=[$7], f=[$9])\n" - + " LogicalGraphMatch(path=[{(a:)-[e_col_1:]-(b:)-[e_col_3:]-(c:)} Join {(b:)-[e_col_5:]-(d:)" - + "-[e_col_6:]-(f:)}])\n" - + " LogicalGraphScan(table=[default.g0])\n"); - - String script2 = "Match(a:user where a.id = 0)-[e]-(b),(a where a.id = 2)-(c), (d) - (a)\n" + PlanTester.build() + .gql(script) + .toRel() + .checkRelNode( + "LogicalProject(a_id=[$0])\n" + + " LogicalParameterizedRelNode\n" + + " LogicalProject(id=[$0])\n" + + " LogicalTableScan(table=[[default, t0]])\n" + + " LogicalProject(a_id=[$0.id], b_id=[$2.id])\n" + + " LogicalGraphMatch(path=[(a:) where =(a.id, CAST($$0):BIGINT)" + + " -[e:]->(b:)])\n" + + " LogicalGraphScan(table=[default.g0])\n"); + } + + @Test + public void testGQLComplexMatchWithPathConcat() { + String script1 = "Match(a)-(b), (b) - (c)" + "RETURN b, c, a.id"; + + PlanTester.build() + .gql(script1) + .toRel() + .checkRelNode( + "LogicalProject(b=[$2], c=[$4], id=[$0.id])\n" + + " LogicalGraphMatch(path=[(a:)-[e_col_1:]-(b:)-[e_col_3:]-(c:)])\n" + + " LogicalGraphScan(table=[default.g0])\n"); + + String script2 = "Match(b: person) - (c), (a)-(b: user)" + "RETURN b, c, a.id"; + + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql(script2) + .toRel() + .checkRelNode( + "LogicalProject(b=[$2], c=[$4], id=[$0.id])\n" + + " LogicalGraphMatch(path=[(a:)-[e_col_3:]-(b:user) where =(b.~label," + + " _UTF-16LE'person') -[e_col_1:]-(c:)])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + + String script3 = + "Match(b: person) - (c where c.name = 'marko'), (a)-(b: user)" + "RETURN b, c, a.id"; + PlanTester.build() + .registerGraph(GRAPH_G1) + .gql(script3) + .toRel() + .checkRelNode( + "LogicalProject(b=[$2], c=[$4], id=[$0.id])\n" + + " LogicalGraphMatch(path=[(a:)-[e_col_3:]-(b:user) where =(b.~label," + + " _UTF-16LE'person') -[e_col_1:]-(c:) where =(c.name, _UTF-16LE'marko') ])\n" + + " LogicalGraphScan(table=[default.g1])\n"); + } + + @Test + public void testGQLComplexMatchWithPathNotConcat() { + String script1 = "Match(a)-(b), (b) - (c), (b) - (d) - (f)" + "RETURN b, c, a, d, f"; + + PlanTester.build() + .gql(script1) + .toRel() + .checkRelNode( + "LogicalProject(b=[$2], c=[$4], a=[$0], d=[$7], f=[$9])\n" + + " LogicalGraphMatch(path=[{(a:)-[e_col_1:]-(b:)-[e_col_3:]-(c:)} Join" + + " {(b:)-[e_col_5:]-(d:)-[e_col_6:]-(f:)}])\n" + + " LogicalGraphScan(table=[default.g0])\n"); + + String script2 = + "Match(a:user where a.id = 0)-[e]-(b),(a where a.id = 2)-(c), (d) - (a)\n" + "RETURN a, b, c, d"; - PlanTester - .build() - .gql(script2) - .toRel() - .checkRelNode( - "LogicalProject(a=[$2], b=[$4], c=[$7], d=[$0])\n" - + " LogicalGraphMatch(path=[{(d:)-[e_col_4:]-(a:) where =(a.~label, _UTF-16LE'user') where =(a.id, 0) " - + "-[e:]-(b:)} Join {(a:) where =(a.id, 2) -[e_col_2:]-(c:)}])\n" - + " LogicalGraphScan(table=[default.g0])\n"); - - String script3 = "Match(a)-(b where b.name = 'marko'), (c) - (d where d.id = 1)\n" - + "Return d, c, a, b"; - - PlanTester - .build() - .gql(script3) - .toRel() - .checkRelNode("LogicalProject(d=[$5], c=[$3], a=[$0], b=[$2])\n" - + " LogicalGraphMatch(path=[{(a:)-[e_col_1:]-(b:) where =(b.name, _UTF-16LE'marko') } Join {(c:)" - + "-[e_col_3:]-(d:) where =(d.id, 1) }])\n" + PlanTester.build() + .gql(script2) + .toRel() + .checkRelNode( + "LogicalProject(a=[$2], b=[$4], c=[$7], d=[$0])\n" + + " LogicalGraphMatch(path=[{(d:)-[e_col_4:]-(a:) where =(a.~label," + + " _UTF-16LE'user') where =(a.id, 0) -[e:]-(b:)} Join {(a:) where =(a.id, 2)" + + " -[e_col_2:]-(c:)}])\n" + " LogicalGraphScan(table=[default.g0])\n"); - } - @Test - public void testGQLComplexMatchWithLet() { - String script = "Match(a)-(b), (b) - (c)" + String script3 = + "Match(a)-(b where b.name = 'marko'), (c) - (d where d.id = 1)\n" + "Return d, c, a, b"; + + PlanTester.build() + .gql(script3) + .toRel() + .checkRelNode( + "LogicalProject(d=[$5], c=[$3], a=[$0], b=[$2])\n" + + " LogicalGraphMatch(path=[{(a:)-[e_col_1:]-(b:) where =(b.name," + + " _UTF-16LE'marko') } Join {(c:)-[e_col_3:]-(d:) where =(d.id, 1) }])\n" + + " LogicalGraphScan(table=[default.g0])\n"); + } + + @Test + public void testGQLComplexMatchWithLet() { + String script = + "Match(a)-(b), (b) - (c)" + "Let a.weight = a.age / 10," + "Let b.weight = b.age / 10" + "RETURN a, b, c.id"; - PlanTester - .build() - .gql(script) - .toRel() - .checkRelNode( - "LogicalProject(a=[$0], b=[$2], id=[$4.id])\n" - + " LogicalGraphMatch(path=[(a:)-[e_col_1:]-(b:)-[e_col_3:]-(c:) PathModify([a=VERTEX{id:$0.id, " - + "~label:$0.~label, name:$0.name, age:$0.age, weight:CAST(/($0.age, 10)):JavaType(class java" - + ".lang.Integer)}]) PathModify([b=VERTEX{id:$2.id, ~label:$2.~label, name:$2.name, age:$2.age, " - + "weight:CAST(/($2.age, 10)):JavaType(class java.lang.Integer)}])])\n" - + " LogicalGraphScan(table=[default.g0])\n"); - } - - @Test - public void testGQLLoopMatch() { - String script1 = "MATCH (a:user where id = 1)-[e:knows where e.weight > 0.4]->{1, 5} (b:user)\n" + PlanTester.build() + .gql(script) + .toRel() + .checkRelNode( + "LogicalProject(a=[$0], b=[$2], id=[$4.id])\n" + + " LogicalGraphMatch(path=[(a:)-[e_col_1:]-(b:)-[e_col_3:]-(c:)" + + " PathModify([a=VERTEX{id:$0.id, ~label:$0.~label, name:$0.name, age:$0.age," + + " weight:CAST(/($0.age, 10)):JavaType(class java.lang.Integer)}])" + + " PathModify([b=VERTEX{id:$2.id, ~label:$2.~label, name:$2.name, age:$2.age," + + " weight:CAST(/($2.age, 10)):JavaType(class java.lang.Integer)}])])\n" + + " LogicalGraphScan(table=[default.g0])\n"); + } + + @Test + public void testGQLLoopMatch() { + String script1 = + "MATCH (a:user where id = 1)-[e:knows where e.weight > 0.4]->{1, 5} (b:user)\n" + " RETURN a"; - PlanTester - .build() - .gql(script1) - .toRel() - .checkRelNode( - "LogicalProject(a=[$0])\n" - + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) - loop(-[e:knows]-> where >(e.weight, 0.4)" - + " (b:user)).time(1,5).until(true)])\n" - + " LogicalGraphScan(table=[default.g0])\n"); - - - String script2 = "MATCH (a:user where id = 1)-[e:knows where e.weight > 0.4]->{1,} (b:user) -> (c)\n" + PlanTester.build() + .gql(script1) + .toRel() + .checkRelNode( + "LogicalProject(a=[$0])\n" + + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) - loop(-[e:knows]-> where" + + " >(e.weight, 0.4) (b:user)).time(1,5).until(true)])\n" + + " LogicalGraphScan(table=[default.g0])\n"); + + String script2 = + "MATCH (a:user where id = 1)-[e:knows where e.weight > 0.4]->{1,} (b:user) -> (c)\n" + " RETURN a, c"; - PlanTester - .build() - .gql(script2) - .toRel() - .checkRelNode( - "LogicalProject(a=[$0], c=[$4])\n" - + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) - loop(-[e:knows]-> where >(e.weight, 0.4)" - + " (b:user)).time(1,-1).until(true)-[e_col_1:]->(c:)])\n" - + " LogicalGraphScan(table=[default.g0])\n"); - } - - @Test - public void testGQLSubQuery() { - String script1 = "MATCH (a:user where id = 1)-[e:knows where e.weight > 0.4]-> (b:user)\n" + PlanTester.build() + .gql(script2) + .toRel() + .checkRelNode( + "LogicalProject(a=[$0], c=[$4])\n" + + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) - loop(-[e:knows]-> where" + + " >(e.weight, 0.4) (b:user)).time(1,-1).until(true)-[e_col_1:]->(c:)])\n" + + " LogicalGraphScan(table=[default.g0])\n"); + } + + @Test + public void testGQLSubQuery() { + String script1 = + "MATCH (a:user where id = 1)-[e:knows where e.weight > 0.4]-> (b:user)\n" + " Where AVG((b) ->(c) => c.age) > 20\n" + " RETURN b"; - PlanTester - .build() - .gql(script1) - .toRel() - .checkRelNode( - "LogicalProject(b=[$2])\n" - + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]-> where >(e.weight, 0.4) " - + "(b:user) where >(AVG((b:)-[e_col_2:]->(c:) => $4.age), 20) ])\n" - + " LogicalGraphScan(table=[default.g0])\n"); - - - String script2 = "MATCH (a:user where id = 1)-[e:knows where e.weight > 0.4]-> (b:user)\n" + PlanTester.build() + .gql(script1) + .toRel() + .checkRelNode( + "LogicalProject(b=[$2])\n" + + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]-> where" + + " >(e.weight, 0.4) (b:user) where >(AVG((b:)-[e_col_2:]->(c:) => $4.age), 20)" + + " ])\n" + + " LogicalGraphScan(table=[default.g0])\n"); + + String script2 = + "MATCH (a:user where id = 1)-[e:knows where e.weight > 0.4]-> (b:user)\n" + " Where SUM((b) ->(c) => c.age) > 20 AND COUNT((b) -(c) => c) > 10\n" + " RETURN b"; - PlanTester - .build() - .gql(script2) - .toRel() - .checkRelNode("LogicalProject(b=[$2])\n" - + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]-> where >(e.weight, 0.4) (b:user) " - + "where AND(>(SUM((b:)-[e_col_2:]->(c:) => $4.age), 20), >(COUNT((b:)-[e_col_4:]-(c:) => $4), 10)) " - + "])\n" + PlanTester.build() + .gql(script2) + .toRel() + .checkRelNode( + "LogicalProject(b=[$2])\n" + + " LogicalGraphMatch(path=[(a:user) where =(a.id, 1) -[e:knows]-> where" + + " >(e.weight, 0.4) (b:user) where AND(>(SUM((b:)-[e_col_2:]->(c:) => $4.age)," + + " 20), >(COUNT((b:)-[e_col_4:]-(c:) => $4), 10)) ])\n" + + " LogicalGraphScan(table=[default.g0])\n"); + } + + @Test + public void testGraphAlgorithm() { + PlanTester.build() + .gql("CALL SSSP(1) YIELD (vid, distance)\n" + "RETURN vid, distance") + .toRel() + .checkRelNode( + "LogicalProject(vid=[$0], distance=[$1])\n" + + " LogicalGraphAlgorithm(algo=[SingleSourceShortestPath], params=[[Integer]], " + + "outputType=[RecordType:peek(BIGINT id, BIGINT distance)])\n" + + " LogicalGraphScan(table=[default.g0])\n"); + PlanTester.build() + .gql("CALL SSSP() YIELD (vid, distance)\n" + "RETURN vid, distance") + .toRel() + .checkRelNode( + "LogicalProject(vid=[$0], distance=[$1])\n" + + " LogicalGraphAlgorithm(algo=[SingleSourceShortestPath], params=[[]], " + + "outputType=[RecordType:peek(BIGINT id, BIGINT distance)])\n" + " LogicalGraphScan(table=[default.g0])\n"); - } - - @Test - public void testGraphAlgorithm() { - PlanTester.build() - .gql("CALL SSSP(1) YIELD (vid, distance)\n" + "RETURN vid, distance") - .toRel() - .checkRelNode( - "LogicalProject(vid=[$0], distance=[$1])\n" - + " LogicalGraphAlgorithm(algo=[SingleSourceShortestPath], params=[[Integer]], " - + "outputType=[RecordType:peek(BIGINT id, BIGINT distance)])\n" - + " LogicalGraphScan(table=[default.g0])\n" - ); - PlanTester.build() - .gql("CALL SSSP() YIELD (vid, distance)\n" + "RETURN vid, distance") - .toRel() - .checkRelNode( - "LogicalProject(vid=[$0], distance=[$1])\n" - + " LogicalGraphAlgorithm(algo=[SingleSourceShortestPath], params=[[]], " - + "outputType=[RecordType:peek(BIGINT id, BIGINT distance)])\n" - + " LogicalGraphScan(table=[default.g0])\n" - ); - } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateComplexMatchTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateComplexMatchTest.java index 03bb36bc5..9d9caee48 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateComplexMatchTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateComplexMatchTest.java @@ -23,87 +23,85 @@ public class GQLValidateComplexMatchTest { - @Test - public void testMultiMatch() { - String script = "Match(a)-(b), (b) - (c)" - + "RETURN a.id, b, c"; + @Test + public void testMultiMatch() { + String script = "Match(a)-(b), (b) - (c)" + "RETURN a.id, b, c"; - PlanTester.build() - .gql(script) - .validate() - .expectValidateType( - "RecordType(BIGINT id, " - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) b, " - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) c)"); - } + PlanTester.build() + .gql(script) + .validate() + .expectValidateType( + "RecordType(BIGINT id, Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name," + + " INTEGER age) b, Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name," + + " INTEGER age) c)"); + } - @Test - public void testMultiMatchWithLet() { - String script = "Match(a)-(b), (b) - (c)" + @Test + public void testMultiMatchWithLet() { + String script = + "Match(a)-(b), (b) - (c)" + "Let a.weight = a.age / 10," + "Let b.weight = b.age / 10" + "RETURN a, b, c.id"; - PlanTester.build() - .gql(script) - .validate() - .expectValidateType( - "RecordType(Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age, " - + "JavaType(class java.lang.Integer) weight) a, " - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age, " - + "JavaType(class java.lang.Integer) weight) b, " - + "BIGINT id)"); - } + PlanTester.build() + .gql(script) + .validate() + .expectValidateType( + "RecordType(Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER" + + " age, JavaType(class java.lang.Integer) weight) a, Vertex:RecordType:peek(BIGINT" + + " id, VARCHAR ~label, VARCHAR name, INTEGER age, JavaType(class" + + " java.lang.Integer) weight) b, BIGINT id)"); + } - @Test - public void testUnion() { - String script = "Match(a) - (b) |+| (c) - (d) |+| (e) - (f)" - + "RETURN a.id, b, c, d, e, f"; + @Test + public void testUnion() { + String script = "Match(a) - (b) |+| (c) - (d) |+| (e) - (f)" + "RETURN a.id, b, c, d, e, f"; - PlanTester.build() - .gql(script) - .validate() - .expectValidateType( - "RecordType(BIGINT id, " - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) b, " - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) c, " - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) d, " - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) e, " - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) f)" - ); - } + PlanTester.build() + .gql(script) + .validate() + .expectValidateType( + "RecordType(BIGINT id, Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name," + + " INTEGER age) b, Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name," + + " INTEGER age) c, Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name," + + " INTEGER age) d, Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name," + + " INTEGER age) e, Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name," + + " INTEGER age) f)"); + } - @Test - public void testUnionWithDuplicatedAlias() { - String script = "Match(a) - (b) |+| (a) - (c) |+| (e) - (f)" - + "RETURN a.id, a, b, e, f"; + @Test + public void testUnionWithDuplicatedAlias() { + String script = "Match(a) - (b) |+| (a) - (c) |+| (e) - (f)" + "RETURN a.id, a, b, e, f"; - PlanTester.build() - .gql(script) - .validate() - .expectValidateType( - "RecordType(BIGINT id, " - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) a, " - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) b, " - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) e, " - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) f)" - ); - } + PlanTester.build() + .gql(script) + .validate() + .expectValidateType( + "RecordType(BIGINT id, Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name," + + " INTEGER age) a, Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name," + + " INTEGER age) b, Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name," + + " INTEGER age) e, Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name," + + " INTEGER age) f)"); + } - @Test - public void testUnionWithLet() { - String script = "MATCH (a:user where a.id = 1) -[e1:knows]->(b1:user) |+| (a:user where a.id = 2) -[e2:knows]->(b2:user)\n" - + "LET a.weight = a.age / cast(100.0 as double),\n" + "LET a.weight = a.weight * 2,\n" - + "LET e1.weight = e1.weight * 10\n" + "MATCH(b2) -[]->(c1) | (b2) <-[]-(c2)\n" - + "RETURN a.weight AS a_weight, b1.id AS b1_id, c1.id AS c1_id, b2.id AS b2_id, c2.id " - + "AS c2_id"; + @Test + public void testUnionWithLet() { + String script = + "MATCH (a:user where a.id = 1) -[e1:knows]->(b1:user) |+| (a:user where a.id = 2)" + + " -[e2:knows]->(b2:user)\n" + + "LET a.weight = a.age / cast(100.0 as double),\n" + + "LET a.weight = a.weight * 2,\n" + + "LET e1.weight = e1.weight * 10\n" + + "MATCH(b2) -[]->(c1) | (b2) <-[]-(c2)\n" + + "RETURN a.weight AS a_weight, b1.id AS b1_id, c1.id AS c1_id, b2.id AS b2_id, c2.id" + + " AS c2_id"; - PlanTester.build() - .gql(script) - .validate() - .expectValidateType( - "RecordType(JavaType(class java.lang.Double) a_weight, " - + "BIGINT b1_id, BIGINT c1_id, BIGINT b2_id, BIGINT c2_id)" - ); - } + PlanTester.build() + .gql(script) + .validate() + .expectValidateType( + "RecordType(JavaType(class java.lang.Double) a_weight, " + + "BIGINT b1_id, BIGINT c1_id, BIGINT b2_id, BIGINT c2_id)"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateContinueMatchTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateContinueMatchTest.java index fa8fd036a..29d3df39b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateContinueMatchTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateContinueMatchTest.java @@ -23,35 +23,35 @@ public class GQLValidateContinueMatchTest { - @Test - public void testContinueMatch_001() { - String script = "Match(a)-[e1]->(b)" - + " Match(b) <-[e2] - (c)" - + "RETURN a, b, c"; + @Test + public void testContinueMatch_001() { + String script = "Match(a)-[e1]->(b)" + " Match(b) <-[e2] - (c)" + "RETURN a, b, c"; - PlanTester.build() - .gql(script) - .validate() - .expectValidateType( - "RecordType(Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) a, " - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) b, " - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) c)"); - } + PlanTester.build() + .gql(script) + .validate() + .expectValidateType( + "RecordType(Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER" + + " age) a, Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER" + + " age) b, Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER" + + " age) c)"); + } - @Test - public void testContinueMatch_002() { - String script = "Match(a)-[e1]->(b)" + @Test + public void testContinueMatch_002() { + String script = + "Match(a)-[e1]->(b)" + " Let a.weight = a.age / 100" + " Match(b) <-[e2] - (c)" + "RETURN a, b, c"; - PlanTester.build() - .gql(script) - .validate() - .expectValidateType( - "RecordType(Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age, " - + "JavaType(class java.lang.Integer) weight) a, " - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) b, " - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) c)"); - } + PlanTester.build() + .gql(script) + .validate() + .expectValidateType( + "RecordType(Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER" + + " age, JavaType(class java.lang.Integer) weight) a, Vertex:RecordType:peek(BIGINT" + + " id, VARCHAR ~label, VARCHAR name, INTEGER age) b, Vertex:RecordType:peek(BIGINT" + + " id, VARCHAR ~label, VARCHAR name, INTEGER age) c)"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateFilterStatementTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateFilterStatementTest.java index c4178661b..67140417f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateFilterStatementTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateFilterStatementTest.java @@ -23,23 +23,31 @@ public class GQLValidateFilterStatementTest { - @Test - public void testSimpleFilter() { - PlanTester.build() - .gql("MATCH (a:user WHERE a.id = '1')-[e:knows]->(b:user)\n" - + "RETURN b.name, b as _b\n" + "THEN\n" + "FILTER name IS NOT NULL AND _b.id " + @Test + public void testSimpleFilter() { + PlanTester.build() + .gql( + "MATCH (a:user WHERE a.id = '1')-[e:knows]->(b:user)\n" + + "RETURN b.name, b as _b\n" + + "THEN\n" + + "FILTER name IS NOT NULL AND _b.id " + "> 10") - .validate() - .expectValidateType("RecordType(VARCHAR name, Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) _b)"); - } + .validate() + .expectValidateType( + "RecordType(VARCHAR name, Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR" + + " name, INTEGER age) _b)"); + } - @Test - public void testNotUseAlias() { - PlanTester.build() - .gql("MATCH (a:user WHERE a.id = '1')-[e:knows]->(b:user)\n" - + "RETURN b.name, b as _b\n" + "THEN\n" + "FILTER b.name IS NOT NULL AND _b.id " + @Test + public void testNotUseAlias() { + PlanTester.build() + .gql( + "MATCH (a:user WHERE a.id = '1')-[e:knows]->(b:user)\n" + + "RETURN b.name, b as _b\n" + + "THEN\n" + + "FILTER b.name IS NOT NULL AND _b.id " + "> 10") - .validate() - .expectException("At line 4, column 8: Table 'b' not found"); - } + .validate() + .expectException("At line 4, column 8: Table 'b' not found"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateGraphAlgorithmTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateGraphAlgorithmTest.java index a1de8505a..2c0a5d9b4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateGraphAlgorithmTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateGraphAlgorithmTest.java @@ -23,22 +23,20 @@ public class GQLValidateGraphAlgorithmTest { - @Test - public void testGraphAlgorithm() { - String script1 = "CALL SSSP(1) YIELD (vid, distance)\n" + "RETURN vid, distance"; + @Test + public void testGraphAlgorithm() { + String script1 = "CALL SSSP(1) YIELD (vid, distance)\n" + "RETURN vid, distance"; - PlanTester.build() - .gql(script1) - .validate() - .expectValidateType( - "RecordType(BIGINT vid, BIGINT distance)"); + PlanTester.build() + .gql(script1) + .validate() + .expectValidateType("RecordType(BIGINT vid, BIGINT distance)"); - String script2 = "CALL SSSP() YIELD (vid, distance)\n" + "RETURN vid, distance"; + String script2 = "CALL SSSP() YIELD (vid, distance)\n" + "RETURN vid, distance"; - PlanTester.build() - .gql(script2) - .validate() - .expectValidateType( - "RecordType(BIGINT vid, BIGINT distance)"); - } + PlanTester.build() + .gql(script2) + .validate() + .expectValidateType("RecordType(BIGINT vid, BIGINT distance)"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateGraphRecordTypeTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateGraphRecordTypeTest.java index 4b3b6baff..d3c5e6830 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateGraphRecordTypeTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateGraphRecordTypeTest.java @@ -23,9 +23,10 @@ public class GQLValidateGraphRecordTypeTest { - @Test - public void testDifferentVertexFieldType() { - String graphDDL = "create graph g1(" + @Test + public void testDifferentVertexFieldType() { + String graphDDL = + "create graph g1(" + "vertex user(" + " id bigint ID," + "name varchar" @@ -49,17 +50,20 @@ public void testDifferentVertexFieldType() { + " weight double" + ")" + ")"; - PlanTester.build().registerGraph(graphDDL) - .gql("MATCH (a:user|person WHERE a.id = 1)-[e:knows]->(b:user)\n" + PlanTester.build() + .registerGraph(graphDDL) + .gql( + "MATCH (a:user|person WHERE a.id = 1)-[e:knows]->(b:user)\n" + "RETURN a, e, b Order by a.id " + "DESC Limit 10") - .validate() - .expectException("Same name field between vertex tables shouldn't have different type."); - } + .validate() + .expectException("Same name field between vertex tables shouldn't have different type."); + } - @Test - public void testDifferentEdgeFieldType() { - String graphDDL = "create graph g1(" + @Test + public void testDifferentEdgeFieldType() { + String graphDDL = + "create graph g1(" + "vertex user(" + " id bigint ID," + "name varchar" @@ -83,17 +87,20 @@ public void testDifferentEdgeFieldType() { + " weight int" + ")" + ")"; - PlanTester.build().registerGraph(graphDDL) - .gql("MATCH (a:user|person WHERE a.id = 1)-[e:knows|follow]->(b:user)\n" + PlanTester.build() + .registerGraph(graphDDL) + .gql( + "MATCH (a:user|person WHERE a.id = 1)-[e:knows|follow]->(b:user)\n" + "RETURN a, e, b Order by a.id " + "DESC Limit 10") - .validate() - .expectException("Same name field between edge tables shouldn't have different type."); - } + .validate() + .expectException("Same name field between edge tables shouldn't have different type."); + } - @Test(enabled = false) - public void testSameIdFieldNameValidation() { - String graphDDL = "create graph g1(" + @Test(enabled = false) + public void testSameIdFieldNameValidation() { + String graphDDL = + "create graph g1(" + "vertex user(" + " id bigint ID," + "name varchar" @@ -118,13 +125,15 @@ public void testSameIdFieldNameValidation() { + ")" + ")"; - PlanTester.build().registerGraph(graphDDL) - .expectException("Id field name should be same between vertex tables"); - } + PlanTester.build() + .registerGraph(graphDDL) + .expectException("Id field name should be same between vertex tables"); + } - @Test(enabled = false) - public void testSameSourceIdFieldNameValidation() { - String graphDDL = "create graph g1(" + @Test(enabled = false) + public void testSameSourceIdFieldNameValidation() { + String graphDDL = + "create graph g1(" + "vertex user(" + " id bigint ID," + "name varchar" @@ -148,13 +157,15 @@ public void testSameSourceIdFieldNameValidation() { + " weight double" + ")" + ")"; - PlanTester.build().registerGraph(graphDDL) - .expectException("SOURCE ID field name should be same between edge tables"); - } + PlanTester.build() + .registerGraph(graphDDL) + .expectException("SOURCE ID field name should be same between edge tables"); + } - @Test(enabled = false) - public void testSameDestinationIdFieldNameValidation() { - String graphDDL = "create graph g1(" + @Test(enabled = false) + public void testSameDestinationIdFieldNameValidation() { + String graphDDL = + "create graph g1(" + "vertex user(" + " id bigint ID," + "name varchar" @@ -179,14 +190,15 @@ public void testSameDestinationIdFieldNameValidation() { + ")" + ")"; - PlanTester.build().registerGraph(graphDDL) - .expectException("DESTINATION ID field name should be same between edge tables"); - } + PlanTester.build() + .registerGraph(graphDDL) + .expectException("DESTINATION ID field name should be same between edge tables"); + } - - @Test(enabled = false) - public void testSameTimestampFieldNameValidation() { - String graphDDL = "create graph g1(" + @Test(enabled = false) + public void testSameTimestampFieldNameValidation() { + String graphDDL = + "create graph g1(" + "vertex user(" + " id bigint ID," + "name varchar" @@ -210,7 +222,8 @@ public void testSameTimestampFieldNameValidation() { + " weight double" + ")" + ")"; - PlanTester.build().registerGraph(graphDDL) - .expectException("TIMESTAMP field name should be same between edge tables"); - } + PlanTester.build() + .registerGraph(graphDDL) + .expectException("TIMESTAMP field name should be same between edge tables"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateLetStatementTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateLetStatementTest.java index d9f8e0ac8..cb8e41e7b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateLetStatementTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateLetStatementTest.java @@ -23,27 +23,26 @@ public class GQLValidateLetStatementTest { - @Test - public void testLet_001() { - PlanTester.build() - .gql("Match(a: user)" - + "Let a.cnt = a.age / 2") - .validate() - .expectValidateType("Path:RecordType:peek(Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, " + @Test + public void testLet_001() { + PlanTester.build() + .gql("Match(a: user)" + "Let a.cnt = a.age / 2") + .validate() + .expectValidateType( + "Path:RecordType:peek(Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, " + "VARCHAR name, INTEGER age, JavaType(class java.lang.Integer) cnt) a)"); - } + } - @Test - public void testLet_002() { - PlanTester.build() - .gql("Match(a: user) -[e]->(b)" - + "Let a.cnt = a.age / 2," - + "Let e.ratio = e.weight * 100") - .validate() - .expectValidateType("Path:RecordType:peek(Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, " - + "VARCHAR name, INTEGER age, JavaType(class java.lang.Integer) cnt) a, " - + "Edge: RecordType:peek(BIGINT src_id, BIGINT target_id, VARCHAR ~label, " - + "DOUBLE weight, JavaType(class java.lang.Double) ratio) e, " - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) b)"); - } + @Test + public void testLet_002() { + PlanTester.build() + .gql("Match(a: user) -[e]->(b)" + "Let a.cnt = a.age / 2," + "Let e.ratio = e.weight * 100") + .validate() + .expectValidateType( + "Path:RecordType:peek(Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name," + + " INTEGER age, JavaType(class java.lang.Integer) cnt) a, Edge:" + + " RecordType:peek(BIGINT src_id, BIGINT target_id, VARCHAR ~label, DOUBLE weight," + + " JavaType(class java.lang.Double) ratio) e, Vertex:RecordType:peek(BIGINT id," + + " VARCHAR ~label, VARCHAR name, INTEGER age) b)"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateLoopMatchTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateLoopMatchTest.java index f2e9fc0fc..a35f1dda2 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateLoopMatchTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateLoopMatchTest.java @@ -23,30 +23,36 @@ public class GQLValidateLoopMatchTest { - @Test - public void testValidateLoopMatch() { - PlanTester.build() - .gql("MATCH (a:user where id = 1)-[e:knows where e.weight > 0.4]->{1, 5} (b:user)\n" + @Test + public void testValidateLoopMatch() { + PlanTester.build() + .gql( + "MATCH (a:user where id = 1)-[e:knows where e.weight > 0.4]->{1, 5} (b:user)\n" + " RETURN a") - .validate() - .expectValidateType( - "RecordType(Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) a)"); + .validate() + .expectValidateType( + "RecordType(Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER" + + " age) a)"); - PlanTester.build() - .gql("MATCH (a:user where id = 1)-[e:knows where e.weight > 0.4]->{1, } (b:user)\n" + PlanTester.build() + .gql( + "MATCH (a:user where id = 1)-[e:knows where e.weight > 0.4]->{1, } (b:user)\n" + " RETURN a, b") - .validate() - .expectValidateType( - "RecordType(Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) a," - + " Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) b)"); - } + .validate() + .expectValidateType( + "RecordType(Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER" + + " age) a, Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER" + + " age) b)"); + } - @Test - public void testValidateLoopMatchException() { - PlanTester.build() - .gql("MATCH (a:user where id = 1)-[e:knows where e.weight > 0.4]->{1, 0} (b:user)\n" + @Test + public void testValidateLoopMatchException() { + PlanTester.build() + .gql( + "MATCH (a:user where id = 1)-[e:knows where e.weight > 0.4]->{1, 0} (b:user)\n" + " RETURN e") - .validate() - .expectException("At line 1, column 32: The max hop: 0 count should greater than min hop: 1."); - } + .validate() + .expectException( + "At line 1, column 32: The max hop: 0 count should greater than min hop: 1."); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateMatchStatementTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateMatchStatementTest.java index 352b0d617..5b91d0309 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateMatchStatementTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateMatchStatementTest.java @@ -23,204 +23,218 @@ public class GQLValidateMatchStatementTest { - @Test - public void testValidateMatchWhere1() { - PlanTester.build() - .gql("MATCH (a:user where id = 1)-[e:knows where e.weight > 0.4]->(b:user) RETURN a") - .validate() - .expectValidateType("RecordType(Vertex:RecordType:peek" + @Test + public void testValidateMatchWhere1() { + PlanTester.build() + .gql("MATCH (a:user where id = 1)-[e:knows where e.weight > 0.4]->(b:user) RETURN a") + .validate() + .expectValidateType( + "RecordType(Vertex:RecordType:peek" + "(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) a)"); - } + } - @Test - public void testValidateMatchWhere2() { - PlanTester.build() - .gql("Match (a WHERE name = 'marko')<-[e]-(b) WHERE a.name <> b.name RETURN e") - .validate() - .expectValidateType("RecordType(Edge: RecordType:peek" + @Test + public void testValidateMatchWhere2() { + PlanTester.build() + .gql("Match (a WHERE name = 'marko')<-[e]-(b) WHERE a.name <> b.name RETURN e") + .validate() + .expectValidateType( + "RecordType(Edge: RecordType:peek" + "(BIGINT src_id, BIGINT target_id, VARCHAR ~label, DOUBLE weight) e)"); - } + } - @Test - public void testValidateMatchWhere3() { - PlanTester.build() - .gql("Match (a:user WHERE name = 'where')-[e]-(b) RETURN b") - .validate() - .expectValidateType("RecordType(Vertex:RecordType:peek" + @Test + public void testValidateMatchWhere3() { + PlanTester.build() + .gql("Match (a:user WHERE name = 'where')-[e]-(b) RETURN b") + .validate() + .expectValidateType( + "RecordType(Vertex:RecordType:peek" + "(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) b)"); - } + } - @Test - public void testValidateMatchWhere4() { - PlanTester.build() - .gql("Match (a WHERE name = 'match')-[e]->(b) RETURN a.name") - .validate() - .expectValidateType("RecordType(VARCHAR name)"); - } + @Test + public void testValidateMatchWhere4() { + PlanTester.build() + .gql("Match (a WHERE name = 'match')-[e]->(b) RETURN a.name") + .validate() + .expectValidateType("RecordType(VARCHAR name)"); + } - @Test - public void testValidateMatchWhere5() { - PlanTester.build() - .gql("Match (a WHERE name = 'knows')<-[e]->(b) RETURN e.weight") - .validate() - .expectValidateType("RecordType(DOUBLE weight)"); - } + @Test + public void testValidateMatchWhere5() { + PlanTester.build() + .gql("Match (a WHERE name = 'knows')<-[e]->(b) RETURN e.weight") + .validate() + .expectValidateType("RecordType(DOUBLE weight)"); + } - @Test - public void testValidateMatch1() { - PlanTester.build() - .gql("MATCH (a)->(b) - (c) RETURN a, b") - .validate() - .expectValidateType("RecordType(Vertex:RecordType:peek" + @Test + public void testValidateMatch1() { + PlanTester.build() + .gql("MATCH (a)->(b) - (c) RETURN a, b") + .validate() + .expectValidateType( + "RecordType(Vertex:RecordType:peek" + "(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) a, " + "Vertex:RecordType:peek" + "(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) b)"); - } + } - @Test - public void testValidateMatch2() { - PlanTester.build() - .gql("MATCH (a)<-(b) <->(c) RETURN b, c, a") - .validate() - .expectValidateType("RecordType(Vertex:RecordType:peek" + @Test + public void testValidateMatch2() { + PlanTester.build() + .gql("MATCH (a)<-(b) <->(c) RETURN b, c, a") + .validate() + .expectValidateType( + "RecordType(Vertex:RecordType:peek" + "(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) b, " + "Vertex:RecordType:peek" + "(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) c, " + "Vertex:RecordType:peek" + "(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) a)"); - } + } - @Test - public void testValidateMatch3() { - PlanTester.build() - .gql("MATCH (e) -> (d) <- (f) RETURN d, e") - .validate() - .expectValidateType("RecordType(Vertex:RecordType:peek" + @Test + public void testValidateMatch3() { + PlanTester.build() + .gql("MATCH (e) -> (d) <- (f) RETURN d, e") + .validate() + .expectValidateType( + "RecordType(Vertex:RecordType:peek" + "(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) d, " + "Vertex:RecordType:peek" + "(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) e)"); - } + } - @Test - public void testValidateMatch4() { - PlanTester.build() - .gql("MATCH (e) -> (d) - (f) RETURN e, f, d") - .validate() - .expectValidateType("RecordType(Vertex:RecordType:peek" + @Test + public void testValidateMatch4() { + PlanTester.build() + .gql("MATCH (e) -> (d) - (f) RETURN e, f, d") + .validate() + .expectValidateType( + "RecordType(Vertex:RecordType:peek" + "(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) e, " + "Vertex:RecordType:peek" + "(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) f, " + "Vertex:RecordType:peek" + "(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) d)"); - } + } - @Test - public void testValidateMatch5() { - PlanTester.build() - .gql("MATCH (n) RETURN n") - .validate() - .expectValidateType("RecordType(Vertex:RecordType:peek" + @Test + public void testValidateMatch5() { + PlanTester.build() + .gql("MATCH (n) RETURN n") + .validate() + .expectValidateType( + "RecordType(Vertex:RecordType:peek" + "(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) n)"); - } + } - @Test - public void testValidateMatch6() { - PlanTester.build() - .gql(" MATCH (n:user) RETURN n") - .validate() - .expectValidateType("RecordType(Vertex:RecordType:peek" + @Test + public void testValidateMatch6() { + PlanTester.build() + .gql(" MATCH (n:user) RETURN n") + .validate() + .expectValidateType( + "RecordType(Vertex:RecordType:peek" + "(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) n)"); - } + } - @Test - public void testValidateMatch7() { - PlanTester.build() - .gql("MATCH (a) -[e]->(b) RETURN b ") - .validate() - .expectValidateType("RecordType(Vertex:RecordType:peek" + @Test + public void testValidateMatch7() { + PlanTester.build() + .gql("MATCH (a) -[e]->(b) RETURN b ") + .validate() + .expectValidateType( + "RecordType(Vertex:RecordType:peek" + "(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) b)"); - } + } - @Test - public void testValidateMatch8() { - PlanTester.build() - .gql(" MATCH (a) - (b) RETURN a") - .validate() - .expectValidateType("RecordType(Vertex:RecordType:peek" + @Test + public void testValidateMatch8() { + PlanTester.build() + .gql(" MATCH (a) - (b) RETURN a") + .validate() + .expectValidateType( + "RecordType(Vertex:RecordType:peek" + "(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) a)"); - } + } - @Test - public void testValidateMatchWhere6() { - PlanTester.build() - .gql("MATCH (a:user WHERE a.age > 18) - (b: user) RETURN b") - .validate() - .expectValidateType("RecordType(Vertex:RecordType:peek" + @Test + public void testValidateMatchWhere6() { + PlanTester.build() + .gql("MATCH (a:user WHERE a.age > 18) - (b: user) RETURN b") + .validate() + .expectValidateType( + "RecordType(Vertex:RecordType:peek" + "(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER age) b)"); - } + } - @Test - public void testVertexColumnNotExists() { - PlanTester.build() - .gql("MATCH (a:user where c = 1)->(b) - (c) RETURN c") - .validate() - .expectException("At line 1, column 21: Column 'c' not found in any table"); - } + @Test + public void testVertexColumnNotExists() { + PlanTester.build() + .gql("MATCH (a:user where c = 1)->(b) - (c) RETURN c") + .validate() + .expectException("At line 1, column 21: Column 'c' not found in any table"); + } - @Test - public void testEdgeColumnNotExists() { - PlanTester.build() - .gql("MATCH (a:user where id = 1)-[e where id = 1]->(b) - (c) RETURN c") - .validate() - .expectException("From line 1, column 38 to line 1, column 39: " - + "Column 'id' not found in any table"); - } + @Test + public void testEdgeColumnNotExists() { + PlanTester.build() + .gql("MATCH (a:user where id = 1)-[e where id = 1]->(b) - (c) RETURN c") + .validate() + .expectException( + "From line 1, column 38 to line 1, column 39: " + "Column 'id' not found in any table"); + } - @Test - public void testVertexTypeNotExists() { - PlanTester.build() - .gql("MATCH (a:user|person where id = 1)->(b) - (c) RETURN c") - .validate() - .expectException("Cannot find vertex type: 'person'."); - } + @Test + public void testVertexTypeNotExists() { + PlanTester.build() + .gql("MATCH (a:user|person where id = 1)->(b) - (c) RETURN c") + .validate() + .expectException("Cannot find vertex type: 'person'."); + } - @Test - public void testEdgeTypeNotExists() { - PlanTester.build() - .gql("MATCH (a:user where id = 1)-[e:test]->(b) - (c) RETURN c") - .validate() - .expectException("Cannot find edge type: 'test'."); - } + @Test + public void testEdgeTypeNotExists() { + PlanTester.build() + .gql("MATCH (a:user where id = 1)-[e:test]->(b) - (c) RETURN c") + .validate() + .expectException("Cannot find edge type: 'test'."); + } - @Test - public void testVertexScopeNotExists() { - PlanTester.build() - .gql("MATCH (a:user where user.id = 1)->(b) - (c) RETURN c") - .validate() - .expectException("From line 1, column 21 to line 1, column 24: " + @Test + public void testVertexScopeNotExists() { + PlanTester.build() + .gql("MATCH (a:user where user.id = 1)->(b) - (c) RETURN c") + .validate() + .expectException( + "From line 1, column 21 to line 1, column 24: " + "Column 'user' not found in any table"); - } + } - @Test - public void testEdgeScopeNotExists() { - PlanTester.build() - .gql("MATCH (a:user where id = 1)-[e where knows.src_id = 1]->(b) - (c) RETURN c") - .validate() - .expectException("From line 1, column 38 to line 1, column 42: " - + "Table 'knows' not found"); - } + @Test + public void testEdgeScopeNotExists() { + PlanTester.build() + .gql("MATCH (a:user where id = 1)-[e where knows.src_id = 1]->(b) - (c) RETURN c") + .validate() + .expectException( + "From line 1, column 38 to line 1, column 42: " + "Table 'knows' not found"); + } - @Test - public void testDuplicatedVariablesInMatchPattern() { - PlanTester.build() - .gql("MATCH (a where id = 1)-[e]->(b) -[e]- (a) RETURN c") - .validate() - .expectException("At line 1, column 37: Duplicated node label: e in the path pattern."); - } + @Test + public void testDuplicatedVariablesInMatchPattern() { + PlanTester.build() + .gql("MATCH (a where id = 1)-[e]->(b) -[e]- (a) RETURN c") + .validate() + .expectException("At line 1, column 37: Duplicated node label: e in the path pattern."); + } - @Test - public void testVertexScopeExists() { - String graphDDL = "create graph g1(" + @Test + public void testVertexScopeExists() { + String graphDDL = + "create graph g1(" + "vertex user(" + " id bigint ID," + "name varchar" @@ -244,16 +258,18 @@ public void testVertexScopeExists() { + " weight double" + ")" + ")"; - PlanTester.build().registerGraph(graphDDL) - .gql("MATCH (person:person|user WHERE person.id = 1)-[e:knows|follow]->(b:user)\n" + PlanTester.build() + .registerGraph(graphDDL) + .gql( + "MATCH (person:person|user WHERE person.id = 1)-[e:knows|follow]->(b:user)\n" + "RETURN person, e, b Order by person.id " + "DESC Limit 10") - .validate() - .expectValidateType("RecordType(Vertex:RecordType:peek" + .validate() + .expectValidateType( + "RecordType(Vertex:RecordType:peek" + "(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER gender, INTEGER age) person, " + "Edge: RecordType:peek" + "(BIGINT src_id, BIGINT target_id, VARCHAR ~label, BIGINT time, DOUBLE weight) e," + " Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name) b)"); - } - + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateReturnStatementTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateReturnStatementTest.java index a00156b20..e6060c2a3 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateReturnStatementTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateReturnStatementTest.java @@ -22,295 +22,354 @@ import org.testng.annotations.Test; public class GQLValidateReturnStatementTest { - private static final String graphDDL = "create graph g1(" - + "vertex user(" - + " id bigint ID," - + "name varchar" - + ")," - + "vertex person(" - + " id bigint ID," - + "name varchar," - + "gender int," - + "age integer" - + ")," - + "edge knows(" - + " src_id bigint SOURCE ID," - + " target_id bigint DESTINATION ID," - + " time bigint TIMESTAMP," - + " weight double" - + ")" - + ")"; + private static final String graphDDL = + "create graph g1(" + + "vertex user(" + + " id bigint ID," + + "name varchar" + + ")," + + "vertex person(" + + " id bigint ID," + + "name varchar," + + "gender int," + + "age integer" + + ")," + + "edge knows(" + + " src_id bigint SOURCE ID," + + " target_id bigint DESTINATION ID," + + " time bigint TIMESTAMP," + + " weight double" + + ")" + + ")"; - @Test - public void testValidatedReturnVertex() { - String script = "MATCH (a:user where id = 1)-[e:knows where e.weight > 0.4]->(b:user) RETURN a"; - String expectType = "RecordType(" - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name) a)"; - PlanTester.build().registerGraph(graphDDL) - .gql(script) - .validate() - .expectValidateType(expectType); - } + @Test + public void testValidatedReturnVertex() { + String script = "MATCH (a:user where id = 1)-[e:knows where e.weight > 0.4]->(b:user) RETURN a"; + String expectType = + "RecordType(" + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name) a)"; + PlanTester.build() + .registerGraph(graphDDL) + .gql(script) + .validate() + .expectValidateType(expectType); + } - @Test - public void testValidatedReturnEdge() { - String script = "Match (a:user WHERE name = 'marko')<-[e:knows]-(b:person) " + @Test + public void testValidatedReturnEdge() { + String script = + "Match (a:user WHERE name = 'marko')<-[e:knows]-(b:person) " + "WHERE a.name <> b.name RETURN e"; - String expectType = "RecordType(Edge: RecordType:peek(BIGINT src_id, BIGINT target_id, VARCHAR ~label, BIGINT time, " - + "DOUBLE weight) e)"; - PlanTester.build().registerGraph(graphDDL) - .gql(script) - .validate() - .expectValidateType(expectType); - } + String expectType = + "RecordType(Edge: RecordType:peek(BIGINT src_id, BIGINT target_id, VARCHAR ~label, BIGINT" + + " time, DOUBLE weight) e)"; + PlanTester.build() + .registerGraph(graphDDL) + .gql(script) + .validate() + .expectValidateType(expectType); + } - @Test - public void testValidatedReturnVertex2() { - String script = "MATCH (a:person WHERE a.age > 18) - (b: person) RETURN b"; - String expectType = "RecordType(" - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER gender, INTEGER age) b)"; - PlanTester.build().registerGraph(graphDDL) - .gql(script) - .validate() - .expectValidateType(expectType); - } + @Test + public void testValidatedReturnVertex2() { + String script = "MATCH (a:person WHERE a.age > 18) - (b: person) RETURN b"; + String expectType = + "RecordType(Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER gender," + + " INTEGER age) b)"; + PlanTester.build() + .registerGraph(graphDDL) + .gql(script) + .validate() + .expectValidateType(expectType); + } - @Test - public void testValidatedReturnFields() { - String script = "MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) " + @Test + public void testValidatedReturnFields() { + String script = + "MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) " + "RETURN a.name as name, b.id as b_id"; - String expectType = "RecordType(VARCHAR name, BIGINT b_id)"; - PlanTester.build().registerGraph(graphDDL) - .gql(script) - .validate() - .expectValidateType(expectType); - } + String expectType = "RecordType(VARCHAR name, BIGINT b_id)"; + PlanTester.build() + .registerGraph(graphDDL) + .gql(script) + .validate() + .expectValidateType(expectType); + } - @Test - public void testValidatedReturnMultiplication() { - String script = "MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) " + @Test + public void testValidatedReturnMultiplication() { + String script = + "MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) " + "RETURN a.name as name, b.id as b_id, b.age * 10 as amt"; - String expectType = "RecordType(VARCHAR name, BIGINT b_id, JavaType(class java.lang.Long) amt)"; - PlanTester.build().registerGraph(graphDDL) - .gql(script) - .validate() - .expectValidateType(expectType); - } + String expectType = "RecordType(VARCHAR name, BIGINT b_id, JavaType(class java.lang.Long) amt)"; + PlanTester.build() + .registerGraph(graphDDL) + .gql(script) + .validate() + .expectValidateType(expectType); + } - @Test - public void testValidatedReturnSum() { - String script = "MATCH (a:person)-[e:knows where e.weight > 0.4]->(b:person) " + @Test + public void testValidatedReturnSum() { + String script = + "MATCH (a:person)-[e:knows where e.weight > 0.4]->(b:person) " + "RETURN a.id as a_id, SUM(ALL e.weight) * 10 as amt GROUP BY a_id"; - String expectType = "RecordType(BIGINT a_id, JavaType(class java.lang.Double) amt)"; - PlanTester.build().registerGraph(graphDDL) - .gql(script) - .validate() - .expectValidateType(expectType); - } + String expectType = "RecordType(BIGINT a_id, JavaType(class java.lang.Double) amt)"; + PlanTester.build() + .registerGraph(graphDDL) + .gql(script) + .validate() + .expectValidateType(expectType); + } - @Test - public void testValidatedReturnGroupByAlias() { - String script = "MATCH (a:person)-[e:knows]->(b:person) " + @Test + public void testValidatedReturnGroupByAlias() { + String script = + "MATCH (a:person)-[e:knows]->(b:person) " + "RETURN b.name as name, SUM(ALL b.age) as amt group by name"; - String expectType = "RecordType(VARCHAR name, INTEGER amt)"; - PlanTester.build().registerGraph(graphDDL) - .gql(script) - .validate() - .expectValidateType(expectType); - } + String expectType = "RecordType(VARCHAR name, INTEGER amt)"; + PlanTester.build() + .registerGraph(graphDDL) + .gql(script) + .validate() + .expectValidateType(expectType); + } - @Test - public void testValidatedReturnGroupByAlias2() { - String script = "MATCH (a:person)-[e:knows]->(b:person) " + @Test + public void testValidatedReturnGroupByAlias2() { + String script = + "MATCH (a:person)-[e:knows]->(b:person) " + "RETURN a.id as a_id, b.id as b_id, SUM(ALL e.weight) as e_sum GROUP BY a_id, b_id"; - String expectType = "RecordType(BIGINT a_id, BIGINT b_id, DOUBLE e_sum)"; - PlanTester.build().registerGraph(graphDDL) - .gql(script) - .validate() - .expectValidateType(expectType); - } + String expectType = "RecordType(BIGINT a_id, BIGINT b_id, DOUBLE e_sum)"; + PlanTester.build() + .registerGraph(graphDDL) + .gql(script) + .validate() + .expectValidateType(expectType); + } - @Test - public void testValidatedReturnOrderByAlias() { - String script = "MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) " + @Test + public void testValidatedReturnOrderByAlias() { + String script = + "MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) " + "RETURN b.name as name, b.id as _id, b.age as age order by age DESC"; - String expectType = "RecordType(VARCHAR name, BIGINT _id, INTEGER age)"; - PlanTester.build().registerGraph(graphDDL) - .gql(script) - .validate() - .expectValidateType(expectType); - } + String expectType = "RecordType(VARCHAR name, BIGINT _id, INTEGER age)"; + PlanTester.build() + .registerGraph(graphDDL) + .gql(script) + .validate() + .expectValidateType(expectType); + } - @Test - public void testValidatedReturnOrderByAlias2() { - String script = "MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) " + @Test + public void testValidatedReturnOrderByAlias2() { + String script = + "MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) " + "RETURN b.name as name, b.id as _id, b.age as age order by age DESC Limit 10"; - String expectType = "RecordType(VARCHAR name, BIGINT _id, INTEGER age)"; - PlanTester.build().registerGraph(graphDDL) - .gql(script) - .validate() - .expectValidateType(expectType); - } + String expectType = "RecordType(VARCHAR name, BIGINT _id, INTEGER age)"; + PlanTester.build() + .registerGraph(graphDDL) + .gql(script) + .validate() + .expectValidateType(expectType); + } - @Test - public void testValidatedReturnCast() { - String script = "MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) " + @Test + public void testValidatedReturnCast() { + String script = + "MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) " + "RETURN b.name as name, cast(b.id as int) as _id"; - String expectType = "RecordType(VARCHAR name, INTEGER _id)"; - PlanTester.build().registerGraph(graphDDL) - .gql(script) - .validate() - .expectValidateType(expectType); - } + String expectType = "RecordType(VARCHAR name, INTEGER _id)"; + PlanTester.build() + .registerGraph(graphDDL) + .gql(script) + .validate() + .expectValidateType(expectType); + } - @Test - public void testValidatedReturnCase() { - String script = "MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) " + @Test + public void testValidatedReturnCase() { + String script = + "MATCH (a:person WHERE a.id = '1')-[e:knows]->(b:person) " + "RETURN b.name as name, case when b.gender = 0 then '0' else '1' end as _id"; - String expectType = "RecordType(VARCHAR name, CHAR(1) _id)"; - PlanTester.build().registerGraph(graphDDL) - .gql(script) - .validate() - .expectValidateType(expectType); - } + String expectType = "RecordType(VARCHAR name, CHAR(1) _id)"; + PlanTester.build() + .registerGraph(graphDDL) + .gql(script) + .validate() + .expectValidateType(expectType); + } - @Test - public void testValidatedReturnFields2() { - String script = "Match (a1:user WHERE name like 'marko')-[e1]->(b1:person) " - + "return b1.name AS b_id"; - String expectType = "RecordType(VARCHAR b_id)"; - PlanTester.build().registerGraph(graphDDL) - .gql(script) - .validate() - .expectValidateType(expectType); - } + @Test + public void testValidatedReturnFields2() { + String script = + "Match (a1:user WHERE name like 'marko')-[e1]->(b1:person) " + "return b1.name AS b_id"; + String expectType = "RecordType(VARCHAR b_id)"; + PlanTester.build() + .registerGraph(graphDDL) + .gql(script) + .validate() + .expectValidateType(expectType); + } - @Test - public void testValidatedReturnFields3() { - String script = "Match (a:user)-[e]->(b:person WHERE name = 'lop') return a.id"; - String expectType = "RecordType(BIGINT id)"; - PlanTester.build().registerGraph(graphDDL) - .gql(script) - .validate() - .expectValidateType(expectType); - } + @Test + public void testValidatedReturnFields3() { + String script = "Match (a:user)-[e]->(b:person WHERE name = 'lop') return a.id"; + String expectType = "RecordType(BIGINT id)"; + PlanTester.build() + .registerGraph(graphDDL) + .gql(script) + .validate() + .expectValidateType(expectType); + } - @Test - public void testValidatedReturnGroupByAlias3() { - String script = "MATCH (a:user)-[e:knows]->(b:user)\n" + @Test + public void testValidatedReturnGroupByAlias3() { + String script = + "MATCH (a:user)-[e:knows]->(b:user)\n" + "RETURN a as _a, b.id as b_id, e.weight as e_weight GROUP BY a, b_id, e_weight"; - String expectType = "RecordType(" - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name) _a, BIGINT b_id, DOUBLE e_weight)"; - PlanTester.build().registerGraph(graphDDL) - .gql(script) - .validate() - .expectValidateType(expectType); - } + String expectType = + "RecordType(Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name) _a, BIGINT" + + " b_id, DOUBLE e_weight)"; + PlanTester.build() + .registerGraph(graphDDL) + .gql(script) + .validate() + .expectValidateType(expectType); + } - @Test - public void testValidatedReturnOrderBy() { - String script = "MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user)\n" + @Test + public void testValidatedReturnOrderBy() { + String script = + "MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user)\n" + "RETURN a as _a, b.id as b_id, weight Order by _a.id DESC Limit 10"; - String expectType = "RecordType(" - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name) _a, BIGINT b_id, DOUBLE weight)"; - PlanTester.build().registerGraph(graphDDL) - .gql(script) - .validate() - .expectValidateType(expectType); - } + String expectType = + "RecordType(Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name) _a, BIGINT" + + " b_id, DOUBLE weight)"; + PlanTester.build() + .registerGraph(graphDDL) + .gql(script) + .validate() + .expectValidateType(expectType); + } - @Test - public void testReturnAggregate1() { - String script = "MATCH (a:person)-[e:knows]->(b:person) " + @Test + public void testReturnAggregate1() { + String script = + "MATCH (a:person)-[e:knows]->(b:person) " + "RETURN COUNT(a.name) as a_name, SUM(e.weight) as e_weight, " + "MAX(b.age) + 1 as b_age_max, MIN(b.age) - 1 as b_age_min, " + "AVG(b.age) as b_age_avg, b as _b " + "group by _b order by _b"; - String expectType = "RecordType(BIGINT a_name, DOUBLE e_weight, " - + "JavaType(class java.lang.Integer) b_age_max, " - + "JavaType(class java.lang.Integer) b_age_min, DOUBLE b_age_avg, " - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER gender, INTEGER age) _b)"; - PlanTester.build().registerGraph(graphDDL) - .gql(script) - .validate() - .expectValidateType(expectType); - } + String expectType = + "RecordType(BIGINT a_name, DOUBLE e_weight, JavaType(class java.lang.Integer) b_age_max," + + " JavaType(class java.lang.Integer) b_age_min, DOUBLE b_age_avg," + + " Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, INTEGER gender," + + " INTEGER age) _b)"; + PlanTester.build() + .registerGraph(graphDDL) + .gql(script) + .validate() + .expectValidateType(expectType); + } - @Test - public void testDuplicatedAlias() { - PlanTester.build().registerGraph(graphDDL) - .gql("MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user)\n" + @Test + public void testDuplicatedAlias() { + PlanTester.build() + .registerGraph(graphDDL) + .gql( + "MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user)\n" + "RETURN a as _a, b.id as _a, weight as e_weight Order by a.id " + "DESC Limit 10") - .validate() - .expectException("Duplicated alias in an Return statement at line 2, column 15"); - } + .validate() + .expectException("Duplicated alias in an Return statement at line 2, column 15"); + } - @Test - public void testAmbiguousColumn() { - PlanTester.build().registerGraph(graphDDL) - .gql("MATCH (a:user WHERE a.id = 1)-[e1:knows]->(o:person)-[e2]->(b:user)\n" + @Test + public void testAmbiguousColumn() { + PlanTester.build() + .registerGraph(graphDDL) + .gql( + "MATCH (a:user WHERE a.id = 1)-[e1:knows]->(o:person)-[e2]->(b:user)\n" + "RETURN a as _a, b.id as b_id, weight Order by _a.id " + "DESC Limit 10") - .validate() - .expectException("From line 2, column 31 to line 2, column 36: Column 'weight' is ambiguous"); - } + .validate() + .expectException( + "From line 2, column 31 to line 2, column 36: Column 'weight' is ambiguous"); + } - @Test - public void testReturnVertexColumnNotExists() { - PlanTester.build().registerGraph(graphDDL) - .gql("MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user)\n" + @Test + public void testReturnVertexColumnNotExists() { + PlanTester.build() + .registerGraph(graphDDL) + .gql( + "MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user)\n" + "RETURN a as _a, b.test as b_test, weight as e_weight Order by a.id " + "DESC Limit 10") - .validate() - .expectException("From line 2, column 17 to line 2, column 22: Column 'b.test' not found in table 'col_1'"); - } + .validate() + .expectException( + "From line 2, column 17 to line 2, column 22: Column 'b.test' not found in table" + + " 'col_1'"); + } - @Test - public void testReturnEdgeColumnNotExists() { - PlanTester.build().registerGraph(graphDDL) - .gql("MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user)\n" + @Test + public void testReturnEdgeColumnNotExists() { + PlanTester.build() + .registerGraph(graphDDL) + .gql( + "MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user)\n" + "RETURN a as _a, e.test as e_test, weight as e_weight Order by a.id " + "DESC Limit 10") - .validate() - .expectException("From line 2, column 17 to line 2, column 22: Column 'e.test' not found in table 'col_1'"); - } + .validate() + .expectException( + "From line 2, column 17 to line 2, column 22: Column 'e.test' not found in table" + + " 'col_1'"); + } - @Test - public void testVertexScopeNotExists() { - PlanTester.build().registerGraph(graphDDL) - .gql("MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user)\n" + @Test + public void testVertexScopeNotExists() { + PlanTester.build() + .registerGraph(graphDDL) + .gql( + "MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user)\n" + "RETURN a as _a, user.name as user_name, weight as e_weight Order by a.id " + "DESC Limit 10") - .validate() - .expectException("From line 2, column 17 to line 2, column 20: Column 'user' not found in any table"); - } + .validate() + .expectException( + "From line 2, column 17 to line 2, column 20: Column 'user' not found in any table"); + } - @Test - public void testEdgeScopeNotExists() { - PlanTester.build().registerGraph(graphDDL) - .gql("MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user)\n" + @Test + public void testEdgeScopeNotExists() { + PlanTester.build() + .registerGraph(graphDDL) + .gql( + "MATCH (a:user WHERE a.id = 1)-[e:knows]->(b:user)\n" + "RETURN a as _a, knows.weight as weight, weight as e_weight Order by a.id " + "DESC Limit 10") - .validate() - .expectException("From line 2, column 17 to line 2, column 21: Table 'knows' not found"); - } + .validate() + .expectException("From line 2, column 17 to line 2, column 21: Table 'knows' not found"); + } - @Test - public void testExpressionNotBeingGrouped1() { - PlanTester.build().registerGraph(graphDDL) - .gql("MATCH (a:person)-[e:knows where e.weight > 0.4]->(b:person) " + @Test + public void testExpressionNotBeingGrouped1() { + PlanTester.build() + .registerGraph(graphDDL) + .gql( + "MATCH (a:person)-[e:knows where e.weight > 0.4]->(b:person) " + "RETURN a.id as a_id, e.weight as amt GROUP BY amt ORDER BY a_id") - .validate() - .expectException("From line 1, column 68 to line 1, column 71: Expression 'a.id' is not being grouped"); - } + .validate() + .expectException( + "From line 1, column 68 to line 1, column 71: Expression 'a.id' is not being grouped"); + } - @Test - public void testExpressionNotBeingGrouped2() { - PlanTester.build().registerGraph(graphDDL) - .gql("MATCH (a:person)-[e:knows]->(b:person) " + @Test + public void testExpressionNotBeingGrouped2() { + PlanTester.build() + .registerGraph(graphDDL) + .gql( + "MATCH (a:person)-[e:knows]->(b:person) " + "RETURN a as _a, e as _e, b as _b group by _a.id, _e.time, _b.id") - .validate() - .expectException("At line 1, column 47: Expression 'a' is not being grouped"); - } + .validate() + .expectException("At line 1, column 47: Expression 'a' is not being grouped"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateWithStatementTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateWithStatementTest.java index 3edd52139..d6f4fd76f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateWithStatementTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateWithStatementTest.java @@ -23,16 +23,16 @@ public class GQLValidateWithStatementTest { - @Test - public void testWithStatement_001() { - String script = "With p as (Select 1 as id)\n" + @Test + public void testWithStatement_001() { + String script = + "With p as (Select 1 as id)\n" + "Match (a where a.id = p.id) -[e] -> (b)\n" + "Return a.id as a_id, b.id as b_id"; - PlanTester.build() - .gql(script) - .validate() - .expectValidateType( - "RecordType(BIGINT a_id, BIGINT b_id)"); - } + PlanTester.build() + .gql(script) + .validate() + .expectValidateType("RecordType(BIGINT a_id, BIGINT b_id)"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/PlanTester.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/PlanTester.java index f1f1e93bc..2bfa50a1d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/PlanTester.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/PlanTester.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; + import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.rel.RelNode; @@ -43,160 +44,159 @@ public class PlanTester { - public static final String defaultGraphDDL = - "create graph g0(" - + "vertex user(" - + " id bigint ID," - + "name varchar," - + "age integer" - + ")," - + "edge knows(" - + " src_id bigint SOURCE ID," - + " target_id bigint DESTINATION ID," - + " weight double" - + ")" - + ")"; - - public static final String defaultTableDDL = - "create table t0 (id bigint, name varchar, age int)"; - - private String gql; - - private SqlNode validateNode; - - private Exception validateException; - - private RelNode relNode; - - private final GeaFlowDSLParser parser = new GeaFlowDSLParser(); - - private final GQLContext gqlContext = GQLContext.create(new Configuration(), false); - - private PlanTester() { - try { - GeaFlowDSLParser parser = new GeaFlowDSLParser(); - SqlCreateGraph createGraph = (SqlCreateGraph) parser.parseStatement(defaultGraphDDL); - GeaFlowGraph graph = gqlContext.convertToGraph(createGraph); - gqlContext.registerGraph(graph); - gqlContext.setCurrentGraph(graph.getName()); - - SqlCreateTable createTable = (SqlCreateTable) parser.parseStatement(defaultTableDDL); - GeaFlowTable table = gqlContext.convertToTable(createTable); - gqlContext.registerTable(table); - } catch (SqlParseException e) { - throw new RuntimeException(e); - } - } + public static final String defaultGraphDDL = + "create graph g0(" + + "vertex user(" + + " id bigint ID," + + "name varchar," + + "age integer" + + ")," + + "edge knows(" + + " src_id bigint SOURCE ID," + + " target_id bigint DESTINATION ID," + + " weight double" + + ")" + + ")"; - public PlanTester registerGraph(String ddl) { - try { - GeaFlowDSLParser parser = new GeaFlowDSLParser(); - SqlCreateGraph createGraph = (SqlCreateGraph) parser.parseStatement(ddl); - GeaFlowGraph graph = gqlContext.convertToGraph(createGraph); - gqlContext.registerGraph(graph); - gqlContext.setCurrentGraph(graph.getName()); - } catch (SqlParseException e) { - throw new RuntimeException(e); - } catch (Exception e) { - this.validateException = e; - } - return this; - } + public static final String defaultTableDDL = "create table t0 (id bigint, name varchar, age int)"; - public PlanTester setCurrentGraph(String graphName) { - gqlContext.setCurrentGraph(graphName); - return this; - } + private String gql; - public static PlanTester build() { - return new PlanTester(); - } + private SqlNode validateNode; - public PlanTester gql(String gql) { - if (this.gql != null) { - throw new IllegalArgumentException("Duplicate setting for gql."); - } - this.gql = gql; - return this; - } + private Exception validateException; - public PlanTester validate() { - if (validateException != null || validateNode != null) { - return this; - } - try { - SqlNode sqlNode = parser.parseStatement(gql); - validateNode = gqlContext.validate(sqlNode); - } catch (Exception e) { - this.validateException = e; - } - return this; - } + private RelNode relNode; - public PlanTester toRel() { - if (relNode != null) { - return this; - } - validate(); - if (validateException == null) { - this.relNode = gqlContext.toRelNode(validateNode); - } - return this; - } + private final GeaFlowDSLParser parser = new GeaFlowDSLParser(); - public PlanTester opt(RelOptRule... rules) { - if (relNode == null) { - return this; - } - GQLOptimizer optimizer = new GQLOptimizer(); - List useRuleList = new ArrayList<>(Arrays.asList(rules)); - RuleGroup useRules = new RuleGroup(useRuleList); - optimizer.addRuleGroup(useRules); - this.relNode = optimizer.optimize(relNode); - return this; - } + private final GQLContext gqlContext = GQLContext.create(new Configuration(), false); + + private PlanTester() { + try { + GeaFlowDSLParser parser = new GeaFlowDSLParser(); + SqlCreateGraph createGraph = (SqlCreateGraph) parser.parseStatement(defaultGraphDDL); + GeaFlowGraph graph = gqlContext.convertToGraph(createGraph); + gqlContext.registerGraph(graph); + gqlContext.setCurrentGraph(graph.getName()); - public void expectValidateType(String expectType) { - if (validateException != null) { - throw new GeaFlowDSLException(validateException); - } - RelDataType dataType = gqlContext.getValidator().getValidatedNodeType(validateNode); - Assert.assertEquals(dataType.toString(), expectType); + SqlCreateTable createTable = (SqlCreateTable) parser.parseStatement(defaultTableDDL); + GeaFlowTable table = gqlContext.convertToTable(createTable); + gqlContext.registerTable(table); + } catch (SqlParseException e) { + throw new RuntimeException(e); } + } + + public PlanTester registerGraph(String ddl) { + try { + GeaFlowDSLParser parser = new GeaFlowDSLParser(); + SqlCreateGraph createGraph = (SqlCreateGraph) parser.parseStatement(ddl); + GeaFlowGraph graph = gqlContext.convertToGraph(createGraph); + gqlContext.registerGraph(graph); + gqlContext.setCurrentGraph(graph.getName()); + } catch (SqlParseException e) { + throw new RuntimeException(e); + } catch (Exception e) { + this.validateException = e; + } + return this; + } + + public PlanTester setCurrentGraph(String graphName) { + gqlContext.setCurrentGraph(graphName); + return this; + } + + public static PlanTester build() { + return new PlanTester(); + } - public void expectException(String expectErrorMsg) { - Assert.assertNotNull(validateException); - Assert.assertEquals(validateException.getMessage(), expectErrorMsg); + public PlanTester gql(String gql) { + if (this.gql != null) { + throw new IllegalArgumentException("Duplicate setting for gql."); } + this.gql = gql; + return this; + } - public PlanTester checkRelNode(String expectPlan) { - if (validateException != null) { - throw new GeaFlowDSLException(validateException); - } - String actualPlan = RelOptUtil.toString(relNode); - Assert.assertEquals(actualPlan, expectPlan); - return this; + public PlanTester validate() { + if (validateException != null || validateNode != null) { + return this; } + try { + SqlNode sqlNode = parser.parseStatement(gql); + validateNode = gqlContext.validate(sqlNode); + } catch (Exception e) { + this.validateException = e; + } + return this; + } - public String getDefaultGraphDDL() { - return defaultGraphDDL; + public PlanTester toRel() { + if (relNode != null) { + return this; + } + validate(); + if (validateException == null) { + this.relNode = gqlContext.toRelNode(validateNode); } + return this; + } - public PlanTester checkFilteredFields(String expectFields) { - // Transverse until relNode is a LogicalGraphMatch node - // Only apply for simple case with unique LogicalGraphMatch node and linear hierarchy - RelNode currentNode = relNode; - - while (currentNode != null && !(currentNode instanceof LogicalGraphMatch)) { - currentNode = currentNode.getInputs().get(0); - } - if (currentNode == null) { - throw new GeaFlowDSLException("No matching fields found."); - } - - LogicalGraphMatch matchNode = (LogicalGraphMatch) currentNode; - String actualFields = matchNode.getFilteredFields(); - Assert.assertEquals(actualFields, expectFields); - return this; + public PlanTester opt(RelOptRule... rules) { + if (relNode == null) { + return this; } + GQLOptimizer optimizer = new GQLOptimizer(); + List useRuleList = new ArrayList<>(Arrays.asList(rules)); + RuleGroup useRules = new RuleGroup(useRuleList); + optimizer.addRuleGroup(useRules); + this.relNode = optimizer.optimize(relNode); + return this; + } + + public void expectValidateType(String expectType) { + if (validateException != null) { + throw new GeaFlowDSLException(validateException); + } + RelDataType dataType = gqlContext.getValidator().getValidatedNodeType(validateNode); + Assert.assertEquals(dataType.toString(), expectType); + } + + public void expectException(String expectErrorMsg) { + Assert.assertNotNull(validateException); + Assert.assertEquals(validateException.getMessage(), expectErrorMsg); + } + + public PlanTester checkRelNode(String expectPlan) { + if (validateException != null) { + throw new GeaFlowDSLException(validateException); + } + String actualPlan = RelOptUtil.toString(relNode); + Assert.assertEquals(actualPlan, expectPlan); + return this; + } + + public String getDefaultGraphDDL() { + return defaultGraphDDL; + } + + public PlanTester checkFilteredFields(String expectFields) { + // Transverse until relNode is a LogicalGraphMatch node + // Only apply for simple case with unique LogicalGraphMatch node and linear hierarchy + RelNode currentNode = relNode; + + while (currentNode != null && !(currentNode instanceof LogicalGraphMatch)) { + currentNode = currentNode.getInputs().get(0); + } + if (currentNode == null) { + throw new GeaFlowDSLException("No matching fields found."); + } + + LogicalGraphMatch matchNode = (LogicalGraphMatch) currentNode; + String actualFields = matchNode.getFilteredFields(); + Assert.assertEquals(actualFields, expectFields); + return this; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/catalog/CatalogTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/catalog/CatalogTest.java index dc163eaa5..1a74fdea7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/catalog/CatalogTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/catalog/CatalogTest.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.Set; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.catalog.exception.ObjectAlreadyExistException; import org.apache.geaflow.dsl.catalog.exception.ObjectNotExistException; @@ -33,51 +34,74 @@ public class CatalogTest { - @Test - public void testMemoryCatalog() { - Catalog catalog = CatalogFactory.getCatalog(new Configuration()); - String instance = "default"; - GeaFlowGraph graph = new GeaFlowGraph(instance, "g1", new ArrayList<>(), - new ArrayList<>(), new HashMap<>(), new HashMap<>(), true, false); - GeaFlowTable table = new GeaFlowTable(instance, "t1", new ArrayList<>(), - new ArrayList<>(), new ArrayList<>(), new HashMap<>(), true, false); - GeaFlowView view = new GeaFlowView(instance, "v1", new ArrayList<>(), - null, null, true); - catalog.createGraph(graph.getInstanceName(), graph); - catalog.createTable(table.getInstanceName(), table); - catalog.createView(view.getInstanceName(), view); - // create repeatedly - catalog.createGraph(instance, graph); - catalog.createTable(instance, table); - catalog.createView(instance, view); + @Test + public void testMemoryCatalog() { + Catalog catalog = CatalogFactory.getCatalog(new Configuration()); + String instance = "default"; + GeaFlowGraph graph = + new GeaFlowGraph( + instance, + "g1", + new ArrayList<>(), + new ArrayList<>(), + new HashMap<>(), + new HashMap<>(), + true, + false); + GeaFlowTable table = + new GeaFlowTable( + instance, + "t1", + new ArrayList<>(), + new ArrayList<>(), + new ArrayList<>(), + new HashMap<>(), + true, + false); + GeaFlowView view = new GeaFlowView(instance, "v1", new ArrayList<>(), null, null, true); + catalog.createGraph(graph.getInstanceName(), graph); + catalog.createTable(table.getInstanceName(), table); + catalog.createView(view.getInstanceName(), view); + // create repeatedly + catalog.createGraph(instance, graph); + catalog.createTable(instance, table); + catalog.createView(instance, view); - Set graphAndTables = catalog.listGraphAndTable(instance); - Assert.assertEquals(graphAndTables.size(), 3); - catalog.describeGraph(instance, "g1"); - catalog.describeTable(instance, "t1"); - catalog.dropGraph(instance, "g1"); - catalog.dropTable(instance, "t1"); - Set graphAndTablesAfterDrop = catalog.listGraphAndTable(instance); - Assert.assertEquals(graphAndTablesAfterDrop.size(), 1); + Set graphAndTables = catalog.listGraphAndTable(instance); + Assert.assertEquals(graphAndTables.size(), 3); + catalog.describeGraph(instance, "g1"); + catalog.describeTable(instance, "t1"); + catalog.dropGraph(instance, "g1"); + catalog.dropTable(instance, "t1"); + Set graphAndTablesAfterDrop = catalog.listGraphAndTable(instance); + Assert.assertEquals(graphAndTablesAfterDrop.size(), 1); - // check exception - try { - catalog.dropGraph("testInstance", "g1"); - } catch (Exception e) { - Assert.assertTrue(e instanceof ObjectNotExistException); - } - try { - catalog.dropGraph(instance, "g1"); - } catch (Exception e) { - Assert.assertTrue(e instanceof ObjectNotExistException); - } - try { - GeaFlowGraph graph2 = new GeaFlowGraph(instance, "g2", new ArrayList<>(), - new ArrayList<>(), new HashMap<>(), new HashMap<>(), false, false); - catalog.createGraph(instance, graph2); - catalog.createGraph(instance, graph2); - } catch (Exception e) { - Assert.assertTrue(e instanceof ObjectAlreadyExistException); - } + // check exception + try { + catalog.dropGraph("testInstance", "g1"); + } catch (Exception e) { + Assert.assertTrue(e instanceof ObjectNotExistException); + } + try { + catalog.dropGraph(instance, "g1"); + } catch (Exception e) { + Assert.assertTrue(e instanceof ObjectNotExistException); + } + try { + GeaFlowGraph graph2 = + new GeaFlowGraph( + instance, + "g2", + new ArrayList<>(), + new ArrayList<>(), + new HashMap<>(), + new HashMap<>(), + false, + false); + catalog.createGraph(instance, graph2); + catalog.createGraph(instance, graph2); + } catch (Exception e) { + Assert.assertTrue(e instanceof ObjectAlreadyExistException); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/catalog/ConsoleCatalogTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/catalog/ConsoleCatalogTest.java index 233c7e95e..279b00917 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/catalog/ConsoleCatalogTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/catalog/ConsoleCatalogTest.java @@ -19,14 +19,10 @@ package org.apache.geaflow.dsl.catalog; -import com.google.gson.Gson; import java.io.IOException; import java.util.Collections; import java.util.Set; -import okhttp3.mockwebserver.Dispatcher; -import okhttp3.mockwebserver.MockResponse; -import okhttp3.mockwebserver.MockWebServer; -import okhttp3.mockwebserver.RecordedRequest; + import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.geaflow.common.config.Configuration; @@ -52,204 +48,263 @@ import org.testng.annotations.BeforeTest; import org.testng.annotations.Test; +import com.google.gson.Gson; + +import okhttp3.mockwebserver.Dispatcher; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; + public class ConsoleCatalogTest { - private final String instanceName = "default"; - private String baseUrl; - private MockWebServer server; - private ConsoleCatalog consoleCatalog; - private GeaFlowGraph graph; - private GeaFlowGraph graph2; - private GeaFlowTable table; - private GeaFlowTable table2; - private GeaFlowView view; - - @BeforeTest - public void prepare() throws IOException, SqlParseException { - GeaFlowDSLParser parser = new GeaFlowDSLParser(); - GQLContext gqlContext = GQLContext.create(new Configuration(), false); - - // setup graph - String stmtOfGraph = "CREATE GRAPH IF NOT EXISTS modern (\n" + "\tVertex person (\n" - + "\t id bigint ID,\n" + "\t name varchar,\n" + "\t age int\n" + "\t),\n" - + "\tVertex software (\n" + "\t id bigint ID,\n" + "\t name varchar,\n" - + "\t lang varchar\n" + "\t),\n" + "\tEdge knows (\n" + "\t srcId bigint SOURCE ID,\n" - + "\t targetId bigint DESTINATION ID,\n" + "\t weight double\n" + "\t),\n" - + "\tEdge created (\n" + "\t srcId bigint SOURCE ID,\n" - + " \ttargetId bigint DESTINATION ID,\n" + " \tweight double\n" + "\t)\n" - + ") WITH (\n" + "\tstoreType='memory',\n" + private final String instanceName = "default"; + private String baseUrl; + private MockWebServer server; + private ConsoleCatalog consoleCatalog; + private GeaFlowGraph graph; + private GeaFlowGraph graph2; + private GeaFlowTable table; + private GeaFlowTable table2; + private GeaFlowView view; + + @BeforeTest + public void prepare() throws IOException, SqlParseException { + GeaFlowDSLParser parser = new GeaFlowDSLParser(); + GQLContext gqlContext = GQLContext.create(new Configuration(), false); + + // setup graph + String stmtOfGraph = + "CREATE GRAPH IF NOT EXISTS modern (\n" + + "\tVertex person (\n" + + "\t id bigint ID,\n" + + "\t name varchar,\n" + + "\t age int\n" + + "\t),\n" + + "\tVertex software (\n" + + "\t id bigint ID,\n" + + "\t name varchar,\n" + + "\t lang varchar\n" + + "\t),\n" + + "\tEdge knows (\n" + + "\t srcId bigint SOURCE ID,\n" + + "\t targetId bigint DESTINATION ID,\n" + + "\t weight double\n" + + "\t),\n" + + "\tEdge created (\n" + + "\t srcId bigint SOURCE ID,\n" + + " \ttargetId bigint DESTINATION ID,\n" + + " \tweight double\n" + + "\t)\n" + + ") WITH (\n" + + "\tstoreType='memory',\n" + "\tgeaflow.dsl.using.vertex.path = 'resource:///data/modern_vertex.txt',\n" - + "\tgeaflow.dsl.using.edge.path = 'resource:///data/modern_edge.txt'\n" + ")"; - SqlNode sqlNodeOfGraph = parser.parseStatement(stmtOfGraph); - SqlCreateGraph sqlCreateGraph = (SqlCreateGraph) sqlNodeOfGraph; - graph = gqlContext.convertToGraph(sqlCreateGraph); - - String stmtOfGraph2 = "CREATE GRAPH IF NOT EXISTS modern2 (\n" + "\tVertex person (\n" - + "\t id bigint ID,\n" + "\t name varchar,\n" + "\t age int\n" + "\t),\n" - + "\tVertex software (\n" + "\t id bigint ID,\n" + "\t name varchar,\n" - + "\t lang varchar\n" + "\t),\n" + "\tEdge knows (\n" + "\t srcId bigint SOURCE ID,\n" - + "\t targetId bigint DESTINATION ID,\n" + "\t weight double\n" + "\t),\n" - + "\tEdge created (\n" + "\t srcId bigint SOURCE ID,\n" - + " \ttargetId bigint DESTINATION ID,\n" + " \tweight double\n" + "\t)\n" - + ") WITH (\n" + "\tstoreType='memory',\n" + + "\tgeaflow.dsl.using.edge.path = 'resource:///data/modern_edge.txt'\n" + + ")"; + SqlNode sqlNodeOfGraph = parser.parseStatement(stmtOfGraph); + SqlCreateGraph sqlCreateGraph = (SqlCreateGraph) sqlNodeOfGraph; + graph = gqlContext.convertToGraph(sqlCreateGraph); + + String stmtOfGraph2 = + "CREATE GRAPH IF NOT EXISTS modern2 (\n" + + "\tVertex person (\n" + + "\t id bigint ID,\n" + + "\t name varchar,\n" + + "\t age int\n" + + "\t),\n" + + "\tVertex software (\n" + + "\t id bigint ID,\n" + + "\t name varchar,\n" + + "\t lang varchar\n" + + "\t),\n" + + "\tEdge knows (\n" + + "\t srcId bigint SOURCE ID,\n" + + "\t targetId bigint DESTINATION ID,\n" + + "\t weight double\n" + + "\t),\n" + + "\tEdge created (\n" + + "\t srcId bigint SOURCE ID,\n" + + " \ttargetId bigint DESTINATION ID,\n" + + " \tweight double\n" + + "\t)\n" + + ") WITH (\n" + + "\tstoreType='memory',\n" + "\tgeaflow.dsl.using.vertex.path = 'resource:///data/modern_vertex.txt',\n" - + "\tgeaflow.dsl.using.edge.path = 'resource:///data/modern_edge.txt'\n" + ")"; - SqlNode sqlNodeOfGraph2 = parser.parseStatement(stmtOfGraph2); - SqlCreateGraph sqlCreateGraph2 = (SqlCreateGraph) sqlNodeOfGraph2; - graph2 = gqlContext.convertToGraph(sqlCreateGraph2); - GraphDescriptor stats = new GraphDescriptor().addEdge(new EdgeDescriptor("0", "created", "person", "software")); - graph2.setDescriptor(stats); - - // setup table - String stmtOfTable = "CREATE TABLE IF NOT EXISTS users (\n" + "\tcreateTime bigint,\n" - + "\tproductId bigint,\n" + "\torderId bigint,\n" + "\tunits bigint,\n" - + "\tuser_name VARCHAR\n" + ") WITH (\n" + "\ttype='file',\n" - + "\tgeaflow.dsl.file.path = 'resource:///data/users_correlate2.txt'\n" + ")"; - SqlNode sqlNodeOfTable = parser.parseStatement(stmtOfTable); - SqlCreateTable sqlCreateTable = (SqlCreateTable) sqlNodeOfTable; - table = gqlContext.convertToTable(sqlCreateTable); - - String stmtOfTable2 = "CREATE TABLE IF NOT EXISTS users2 (\n" + "\tcreateTime bigint,\n" - + "\tproductId bigint,\n" + "\torderId bigint,\n" + "\tunits bigint,\n" - + "\tuser_name VARCHAR\n" + ") WITH (\n" + "\ttype='file',\n" - + "\tgeaflow.dsl.file.path = 'resource:///data/users_correlate2.txt'\n" + ")"; - SqlNode sqlNodeOfTable2 = parser.parseStatement(stmtOfTable2); - SqlCreateTable sqlCreateTable2 = (SqlCreateTable) sqlNodeOfTable2; - table2 = gqlContext.convertToTable(sqlCreateTable2); - - // setup view - String stmtOfView = "CREATE VIEW IF NOT EXISTS console (count_id, sum_id, max_id, min_id, avg_id, distinct_id, user_name) AS\n" - + "SELECT\n" + " 1 AS count_id,\n" + " 2 AS sum_id,\n" + " 3 AS max_id,\n" - + " 4 AS min_id,\n" + " 5 AS avg_id,\n" + " 6 AS distinct_id,\n" + + "\tgeaflow.dsl.using.edge.path = 'resource:///data/modern_edge.txt'\n" + + ")"; + SqlNode sqlNodeOfGraph2 = parser.parseStatement(stmtOfGraph2); + SqlCreateGraph sqlCreateGraph2 = (SqlCreateGraph) sqlNodeOfGraph2; + graph2 = gqlContext.convertToGraph(sqlCreateGraph2); + GraphDescriptor stats = + new GraphDescriptor().addEdge(new EdgeDescriptor("0", "created", "person", "software")); + graph2.setDescriptor(stats); + + // setup table + String stmtOfTable = + "CREATE TABLE IF NOT EXISTS users (\n" + + "\tcreateTime bigint,\n" + + "\tproductId bigint,\n" + + "\torderId bigint,\n" + + "\tunits bigint,\n" + + "\tuser_name VARCHAR\n" + + ") WITH (\n" + + "\ttype='file',\n" + + "\tgeaflow.dsl.file.path = 'resource:///data/users_correlate2.txt'\n" + + ")"; + SqlNode sqlNodeOfTable = parser.parseStatement(stmtOfTable); + SqlCreateTable sqlCreateTable = (SqlCreateTable) sqlNodeOfTable; + table = gqlContext.convertToTable(sqlCreateTable); + + String stmtOfTable2 = + "CREATE TABLE IF NOT EXISTS users2 (\n" + + "\tcreateTime bigint,\n" + + "\tproductId bigint,\n" + + "\torderId bigint,\n" + + "\tunits bigint,\n" + + "\tuser_name VARCHAR\n" + + ") WITH (\n" + + "\ttype='file',\n" + + "\tgeaflow.dsl.file.path = 'resource:///data/users_correlate2.txt'\n" + + ")"; + SqlNode sqlNodeOfTable2 = parser.parseStatement(stmtOfTable2); + SqlCreateTable sqlCreateTable2 = (SqlCreateTable) sqlNodeOfTable2; + table2 = gqlContext.convertToTable(sqlCreateTable2); + + // setup view + String stmtOfView = + "CREATE VIEW IF NOT EXISTS console (count_id, sum_id, max_id, min_id, avg_id, distinct_id," + + " user_name) AS\n" + + "SELECT\n" + + " 1 AS count_id,\n" + + " 2 AS sum_id,\n" + + " 3 AS max_id,\n" + + " 4 AS min_id,\n" + + " 5 AS avg_id,\n" + + " 6 AS distinct_id,\n" + " 'test_name' AS user_name"; - SqlNode sqlNodeOfView = parser.parseStatement(stmtOfView); - SqlCreateView sqlCreateView = (SqlCreateView) sqlNodeOfView; - view = gqlContext.convertToView(sqlCreateView); - - // setup server - server = new MockWebServer(); - Dispatcher dispatcher = new Dispatcher() { - @Override - public MockResponse dispatch(RecordedRequest recordedRequest) - throws InterruptedException { - String path = recordedRequest.getPath(); - Gson gson = new Gson(); - HttpResponse response = new HttpResponse(); - response.setSuccess(true); - response.setCode("200"); - switch (path) { - case "/api/instances/default/graphs/modern": - response.setData(gson.toJsonTree(CatalogUtil.convertToGraphModel(graph))); - return new MockResponse().setResponseCode(200) - .setBody(gson.toJson(response)); - case "/api/instances/default/graphs/modern2": - case "/api/instances/default/tables/users2": - case "/api/instances/default/graphs/modern2/endpoints": - return new MockResponse().setResponseCode(200).setBody("{success:true}"); - case "/api/instances/default/graphs": - PageList graphList = new PageList(); - graphList.setList( - Collections.singletonList(CatalogUtil.convertToGraphModel(graph))); - response.setData(gson.toJsonTree(graphList)); - return new MockResponse().setResponseCode(200) - .setBody(gson.toJson(response)); - case "/api/instances/default/tables/users": - response.setData(gson.toJsonTree(CatalogUtil.convertToTableModel(table))); - return new MockResponse().setResponseCode(200) - .setBody(gson.toJson(response)); - case "/api/instances/default/tables": - PageList tableList = new PageList(); - tableList.setList( - Collections.singletonList(CatalogUtil.convertToTableModel(table))); - response.setData(gson.toJsonTree(tableList)); - return new MockResponse().setResponseCode(200) - .setBody(gson.toJson(response)); - case "/api/instances": - PageList instanceList = new PageList(); - InstanceModel instanceModel = new InstanceModel(); - instanceModel.setName(instanceName); - instanceModel.setId("13"); - instanceModel.setComment("test comment"); - instanceModel.setCreateTime("2023-05-19"); - instanceModel.setCreatorId("128745"); - instanceModel.setModifierId("128745"); - instanceModel.setModifierName("user1"); - instanceModel.setCreatorName("user1"); - instanceModel.setModifyTime("2023-05-19"); - instanceList.setList(Collections.singletonList(instanceModel)); - response.setData(gson.toJsonTree(instanceList)); - return new MockResponse().setResponseCode(200) - .setBody(gson.toJson(response)); - } - return null; + SqlNode sqlNodeOfView = parser.parseStatement(stmtOfView); + SqlCreateView sqlCreateView = (SqlCreateView) sqlNodeOfView; + view = gqlContext.convertToView(sqlCreateView); + + // setup server + server = new MockWebServer(); + Dispatcher dispatcher = + new Dispatcher() { + @Override + public MockResponse dispatch(RecordedRequest recordedRequest) + throws InterruptedException { + String path = recordedRequest.getPath(); + Gson gson = new Gson(); + HttpResponse response = new HttpResponse(); + response.setSuccess(true); + response.setCode("200"); + switch (path) { + case "/api/instances/default/graphs/modern": + response.setData(gson.toJsonTree(CatalogUtil.convertToGraphModel(graph))); + return new MockResponse().setResponseCode(200).setBody(gson.toJson(response)); + case "/api/instances/default/graphs/modern2": + case "/api/instances/default/tables/users2": + case "/api/instances/default/graphs/modern2/endpoints": + return new MockResponse().setResponseCode(200).setBody("{success:true}"); + case "/api/instances/default/graphs": + PageList graphList = new PageList(); + graphList.setList( + Collections.singletonList(CatalogUtil.convertToGraphModel(graph))); + response.setData(gson.toJsonTree(graphList)); + return new MockResponse().setResponseCode(200).setBody(gson.toJson(response)); + case "/api/instances/default/tables/users": + response.setData(gson.toJsonTree(CatalogUtil.convertToTableModel(table))); + return new MockResponse().setResponseCode(200).setBody(gson.toJson(response)); + case "/api/instances/default/tables": + PageList tableList = new PageList(); + tableList.setList( + Collections.singletonList(CatalogUtil.convertToTableModel(table))); + response.setData(gson.toJsonTree(tableList)); + return new MockResponse().setResponseCode(200).setBody(gson.toJson(response)); + case "/api/instances": + PageList instanceList = new PageList(); + InstanceModel instanceModel = new InstanceModel(); + instanceModel.setName(instanceName); + instanceModel.setId("13"); + instanceModel.setComment("test comment"); + instanceModel.setCreateTime("2023-05-19"); + instanceModel.setCreatorId("128745"); + instanceModel.setModifierId("128745"); + instanceModel.setModifierName("user1"); + instanceModel.setCreatorName("user1"); + instanceModel.setModifyTime("2023-05-19"); + instanceList.setList(Collections.singletonList(instanceModel)); + response.setData(gson.toJsonTree(instanceList)); + return new MockResponse().setResponseCode(200).setBody(gson.toJson(response)); } + return null; + } }; - server.setDispatcher(dispatcher); - server.start(); - baseUrl = "http://" + server.getHostName() + ":" + server.getPort(); - - // setup catalog - consoleCatalog = new ConsoleCatalog(); - Configuration configuration = new Configuration(); - configuration.put(DSLConfigKeys.GEAFLOW_DSL_CATALOG_TOKEN_KEY.getKey(), "test"); - configuration.put(ExecutionConfigKeys.GEAFLOW_GW_ENDPOINT.getKey(), baseUrl); - consoleCatalog.init(configuration); - } - - @AfterTest - public void after() throws IOException { - server.shutdown(); - } - - @Test - public void testGraph() { - consoleCatalog.createGraph(instanceName, graph); - consoleCatalog.createGraph(instanceName, graph2); - GeaFlowGraph catalogGraph = (GeaFlowGraph) consoleCatalog.getGraph(instanceName, - this.graph.getName()); - Assert.assertEquals(catalogGraph.getVertexTables().size(), 2); - Assert.assertEquals(catalogGraph.getEdgeTables().size(), 2); - consoleCatalog.describeGraph(instanceName, this.graph.getName()); - consoleCatalog.dropGraph(instanceName, this.graph.getName()); - Set graphsAndTables = consoleCatalog.listGraphAndTable(instanceName); - Assert.assertTrue(graphsAndTables.contains(graph.getName())); - Assert.assertTrue(consoleCatalog.isInstanceExists(instanceName)); - - CompileCatalog compileCatalog = new CompileCatalog(consoleCatalog); - compileCatalog.createGraph(instanceName, graph); - GeaFlowGraph compileGraph = (GeaFlowGraph) compileCatalog.getGraph(instanceName, - this.graph.getName()); - Assert.assertEquals(compileGraph.getVertexTables().size(), 2); - Assert.assertEquals(compileGraph.getEdgeTables().size(), 2); - compileCatalog.describeGraph(instanceName, this.graph.getName()); - compileCatalog.dropGraph(instanceName, this.graph.getName()); - Set compileGraphsAndTables = compileCatalog.listGraphAndTable(instanceName); - Assert.assertTrue(compileGraphsAndTables.contains(graph.getName())); - Assert.assertTrue(compileCatalog.isInstanceExists(instanceName)); - } - - @Test - public void testTableAndView() { - consoleCatalog.createTable(instanceName, table); - consoleCatalog.createTable(instanceName, table2); - consoleCatalog.createView(instanceName, view); - GeaFlowTable catalogTable = (GeaFlowTable) consoleCatalog.getTable(instanceName, - this.table.getName()); - Assert.assertEquals(catalogTable.getFields().size(), 5); - consoleCatalog.describeTable(instanceName, this.table.getName()); - consoleCatalog.dropTable(instanceName, this.table.getName()); - Set graphsAndTables = consoleCatalog.listGraphAndTable(instanceName); - Assert.assertTrue(graphsAndTables.contains(table.getName())); - - CompileCatalog compileCatalog = new CompileCatalog(consoleCatalog); - compileCatalog.createTable(instanceName, table); - compileCatalog.createView(instanceName, view); - GeaFlowTable compileTable = (GeaFlowTable) compileCatalog.getTable(instanceName, - this.table.getName()); - Assert.assertEquals(compileTable.getFields().size(), 5); - compileCatalog.describeTable(instanceName, this.table.getName()); - compileCatalog.dropTable(instanceName, this.table.getName()); - Set compileGraphsAndTables = compileCatalog.listGraphAndTable(instanceName); - Assert.assertTrue(compileGraphsAndTables.contains(table.getName())); - } + server.setDispatcher(dispatcher); + server.start(); + baseUrl = "http://" + server.getHostName() + ":" + server.getPort(); + + // setup catalog + consoleCatalog = new ConsoleCatalog(); + Configuration configuration = new Configuration(); + configuration.put(DSLConfigKeys.GEAFLOW_DSL_CATALOG_TOKEN_KEY.getKey(), "test"); + configuration.put(ExecutionConfigKeys.GEAFLOW_GW_ENDPOINT.getKey(), baseUrl); + consoleCatalog.init(configuration); + } + + @AfterTest + public void after() throws IOException { + server.shutdown(); + } + + @Test + public void testGraph() { + consoleCatalog.createGraph(instanceName, graph); + consoleCatalog.createGraph(instanceName, graph2); + GeaFlowGraph catalogGraph = + (GeaFlowGraph) consoleCatalog.getGraph(instanceName, this.graph.getName()); + Assert.assertEquals(catalogGraph.getVertexTables().size(), 2); + Assert.assertEquals(catalogGraph.getEdgeTables().size(), 2); + consoleCatalog.describeGraph(instanceName, this.graph.getName()); + consoleCatalog.dropGraph(instanceName, this.graph.getName()); + Set graphsAndTables = consoleCatalog.listGraphAndTable(instanceName); + Assert.assertTrue(graphsAndTables.contains(graph.getName())); + Assert.assertTrue(consoleCatalog.isInstanceExists(instanceName)); + + CompileCatalog compileCatalog = new CompileCatalog(consoleCatalog); + compileCatalog.createGraph(instanceName, graph); + GeaFlowGraph compileGraph = + (GeaFlowGraph) compileCatalog.getGraph(instanceName, this.graph.getName()); + Assert.assertEquals(compileGraph.getVertexTables().size(), 2); + Assert.assertEquals(compileGraph.getEdgeTables().size(), 2); + compileCatalog.describeGraph(instanceName, this.graph.getName()); + compileCatalog.dropGraph(instanceName, this.graph.getName()); + Set compileGraphsAndTables = compileCatalog.listGraphAndTable(instanceName); + Assert.assertTrue(compileGraphsAndTables.contains(graph.getName())); + Assert.assertTrue(compileCatalog.isInstanceExists(instanceName)); + } + + @Test + public void testTableAndView() { + consoleCatalog.createTable(instanceName, table); + consoleCatalog.createTable(instanceName, table2); + consoleCatalog.createView(instanceName, view); + GeaFlowTable catalogTable = + (GeaFlowTable) consoleCatalog.getTable(instanceName, this.table.getName()); + Assert.assertEquals(catalogTable.getFields().size(), 5); + consoleCatalog.describeTable(instanceName, this.table.getName()); + consoleCatalog.dropTable(instanceName, this.table.getName()); + Set graphsAndTables = consoleCatalog.listGraphAndTable(instanceName); + Assert.assertTrue(graphsAndTables.contains(table.getName())); + + CompileCatalog compileCatalog = new CompileCatalog(consoleCatalog); + compileCatalog.createTable(instanceName, table); + compileCatalog.createView(instanceName, view); + GeaFlowTable compileTable = + (GeaFlowTable) compileCatalog.getTable(instanceName, this.table.getName()); + Assert.assertEquals(compileTable.getFields().size(), 5); + compileCatalog.describeTable(instanceName, this.table.getName()); + compileCatalog.dropTable(instanceName, this.table.getName()); + Set compileGraphsAndTables = compileCatalog.listGraphAndTable(instanceName); + Assert.assertTrue(compileGraphsAndTables.contains(table.getName())); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/planner/GQLContextTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/planner/GQLContextTest.java index 4505cde31..7d97dd3a2 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/planner/GQLContextTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/planner/GQLContextTest.java @@ -33,25 +33,29 @@ import org.testng.annotations.Test; public class GQLContextTest { - @Test - public void testGQLContext() throws SqlParseException { - String stmt = "Create Graph g (\n" + " Vertex buyer \n" - + " (id bigint ID, name string, age int),\n" + " Edge knows \n" - + " (s_id bigint SOURCE ID, t_id bigint DESTINATION ID, weight double)\n" + ")" + @Test + public void testGQLContext() throws SqlParseException { + String stmt = + "Create Graph g (\n" + + " Vertex buyer \n" + + " (id bigint ID, name string, age int),\n" + + " Edge knows \n" + + " (s_id bigint SOURCE ID, t_id bigint DESTINATION ID, weight double)\n" + + ")" + " with (store = 'memory')\n"; - GeaFlowDSLParser parser = new GeaFlowDSLParser(); - SqlNode sqlNode = parser.parseStatement(stmt); - assertTrue(sqlNode instanceof SqlCreateGraph); - SqlCreateGraph sqlCreateGraph = (SqlCreateGraph) sqlNode; - GQLContext gqlContext = GQLContext.create(new Configuration(), false); - GeaFlowGraph graph = gqlContext.convertToGraph(sqlCreateGraph); - gqlContext.registerGraph(graph); - assertNull(gqlContext.findSqlFunction(null, "function")); - assertNotNull(gqlContext.getTypeFactory()); - assertTrue(gqlContext.getRelBuilder() instanceof GQLRelBuilder); - assertNotNull(gqlContext.getValidator()); - gqlContext.setCurrentGraph("g"); - assertEquals(gqlContext.getCurrentGraph(), "g"); - } + GeaFlowDSLParser parser = new GeaFlowDSLParser(); + SqlNode sqlNode = parser.parseStatement(stmt); + assertTrue(sqlNode instanceof SqlCreateGraph); + SqlCreateGraph sqlCreateGraph = (SqlCreateGraph) sqlNode; + GQLContext gqlContext = GQLContext.create(new Configuration(), false); + GeaFlowGraph graph = gqlContext.convertToGraph(sqlCreateGraph); + gqlContext.registerGraph(graph); + assertNull(gqlContext.findSqlFunction(null, "function")); + assertNotNull(gqlContext.getTypeFactory()); + assertTrue(gqlContext.getRelBuilder() instanceof GQLRelBuilder); + assertNotNull(gqlContext.getValidator()); + gqlContext.setCurrentGraph("g"); + assertEquals(gqlContext.getCurrentGraph(), "g"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/planner/GQLCostTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/planner/GQLCostTest.java index 9b2c591fb..d41131bef 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/planner/GQLCostTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/planner/GQLCostTest.java @@ -26,47 +26,46 @@ import org.testng.annotations.Test; public class GQLCostTest { - @Test - public void testGQLCostCalc() { - GQLCost tiny = GQLCost.TINY; - GQLCost huge = GQLCost.HUGE; - GQLCost zero = GQLCost.ZERO; - GQLCost infi = GQLCost.INFINITY; + @Test + public void testGQLCostCalc() { + GQLCost tiny = GQLCost.TINY; + GQLCost huge = GQLCost.HUGE; + GQLCost zero = GQLCost.ZERO; + GQLCost infi = GQLCost.INFINITY; - assertEquals(tiny.toString(), "{tiny}"); - assertEquals(huge.toString(), "{huge}"); - assertEquals(zero.toString(), "{0}"); - assertEquals(infi.toString(), "{inf}"); + assertEquals(tiny.toString(), "{tiny}"); + assertEquals(huge.toString(), "{huge}"); + assertEquals(zero.toString(), "{0}"); + assertEquals(infi.toString(), "{inf}"); - assertEquals(tiny.hashCode(), 1138753536); - assertEquals(huge.hashCode(), -1106247680); - assertEquals(zero.hashCode(), 0); - assertEquals(infi.hashCode(), 1106247680); + assertEquals(tiny.hashCode(), 1138753536); + assertEquals(huge.hashCode(), -1106247680); + assertEquals(zero.hashCode(), 0); + assertEquals(infi.hashCode(), 1106247680); - assertFalse(zero.isEqWithEpsilon(tiny)); - assertFalse(tiny.isEqWithEpsilon(huge)); - assertFalse(huge.isEqWithEpsilon(infi)); - assertFalse(infi.isEqWithEpsilon(zero)); + assertFalse(zero.isEqWithEpsilon(tiny)); + assertFalse(tiny.isEqWithEpsilon(huge)); + assertFalse(huge.isEqWithEpsilon(infi)); + assertFalse(infi.isEqWithEpsilon(zero)); - assertEquals(zero.minus(tiny), tiny.minus(tiny).minus(tiny)); - assertTrue(tiny.minus(huge).isEqWithEpsilon(zero.minus(huge))); - assertFalse(huge.minus(infi).isEqWithEpsilon(huge)); - assertEquals(infi.minus(zero), infi); + assertEquals(zero.minus(tiny), tiny.minus(tiny).minus(tiny)); + assertTrue(tiny.minus(huge).isEqWithEpsilon(zero.minus(huge))); + assertFalse(huge.minus(infi).isEqWithEpsilon(huge)); + assertEquals(infi.minus(zero), infi); - assertEquals(zero.divideBy(tiny), 1.0); - assertEquals(tiny.divideBy(huge), 0.0); - assertEquals(huge.divideBy(infi), 1.0); - assertEquals(infi.divideBy(zero), 1.0); + assertEquals(zero.divideBy(tiny), 1.0); + assertEquals(tiny.divideBy(huge), 0.0); + assertEquals(huge.divideBy(infi), 1.0); + assertEquals(infi.divideBy(zero), 1.0); - assertFalse(zero.equals(tiny)); - assertFalse(tiny.equals(huge)); - assertFalse(huge.equals(infi)); - assertFalse(infi.equals(zero)); + assertFalse(zero.equals(tiny)); + assertFalse(tiny.equals(huge)); + assertFalse(huge.equals(infi)); + assertFalse(infi.equals(zero)); - assertFalse(zero.equals((Object) tiny)); - assertFalse(tiny.equals((Object) huge)); - assertFalse(huge.equals((Object) infi)); - assertFalse(infi.equals((Object) zero)); - - } + assertFalse(zero.equals((Object) tiny)); + assertFalse(tiny.equals((Object) huge)); + assertFalse(huge.equals((Object) infi)); + assertFalse(infi.equals((Object) zero)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/schema/GeaFlowGraphTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/schema/GeaFlowGraphTest.java index 0422097da..a37ddb6a2 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/schema/GeaFlowGraphTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/schema/GeaFlowGraphTest.java @@ -23,9 +23,9 @@ import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertNull; -import com.google.common.collect.Lists; import java.util.HashMap; import java.util.Map; + import org.apache.calcite.rel.type.RelDataType; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.DSLConfigKeys; @@ -37,65 +37,64 @@ import org.apache.geaflow.view.IViewDesc.BackendType; import org.testng.annotations.Test; +import com.google.common.collect.Lists; + public class GeaFlowGraphTest { - @Test - public void testGeaFlowGraph() { - GQLJavaTypeFactory typeFactory = GQLJavaTypeFactory.create(); - TableField field1 = new TableField("name", Types.STRING, true); - TableField field2 = new TableField("id", Types.LONG, false); - TableField field3 = new TableField("age", Types.DOUBLE, true); - VertexTable vertexTable = new VertexTable( - "default", - "person", - Lists.newArrayList(field1, field2, field3), - "id" - ); - assertEquals(vertexTable.getTypeName(), "person"); - assertEquals(vertexTable.getFields().size(), 3); - assertNotNull(vertexTable.getIdField()); + @Test + public void testGeaFlowGraph() { + GQLJavaTypeFactory typeFactory = GQLJavaTypeFactory.create(); + TableField field1 = new TableField("name", Types.STRING, true); + TableField field2 = new TableField("id", Types.LONG, false); + TableField field3 = new TableField("age", Types.DOUBLE, true); + VertexTable vertexTable = + new VertexTable("default", "person", Lists.newArrayList(field1, field2, field3), "id"); + assertEquals(vertexTable.getTypeName(), "person"); + assertEquals(vertexTable.getFields().size(), 3); + assertNotNull(vertexTable.getIdField()); - TableField field4 = new TableField("src", Types.LONG, false); - TableField field5 = new TableField("dst", Types.LONG, false); - TableField field6 = new TableField("weight", Types.DOUBLE, true); - EdgeTable edgeTable = new EdgeTable( - "default", - "follow", - Lists.newArrayList(field4, field5, field6), - "src", "dst", null - ); - assertEquals(edgeTable.getTypeName(), "follow"); - assertEquals(edgeTable.getFields().size(), 3); - assertNotNull(edgeTable.getSrcIdField()); - assertNotNull(edgeTable.getTargetIdField()); - assertNull(edgeTable.getTimestampField()); + TableField field4 = new TableField("src", Types.LONG, false); + TableField field5 = new TableField("dst", Types.LONG, false); + TableField field6 = new TableField("weight", Types.DOUBLE, true); + EdgeTable edgeTable = + new EdgeTable( + "default", "follow", Lists.newArrayList(field4, field5, field6), "src", "dst", null); + assertEquals(edgeTable.getTypeName(), "follow"); + assertEquals(edgeTable.getFields().size(), 3); + assertNotNull(edgeTable.getSrcIdField()); + assertNotNull(edgeTable.getTargetIdField()); + assertNull(edgeTable.getTimestampField()); - Map config = new HashMap<>(); - config.put(DSLConfigKeys.GEAFLOW_DSL_STORE_TYPE.getKey(), "MEMORY"); - GeaFlowGraph graph = new GeaFlowGraph( + Map config = new HashMap<>(); + config.put(DSLConfigKeys.GEAFLOW_DSL_STORE_TYPE.getKey(), "MEMORY"); + GeaFlowGraph graph = + new GeaFlowGraph( "default", "g0", Lists.newArrayList(vertexTable), Lists.newArrayList(edgeTable), - config, new HashMap<>(), - false, false); + config, + new HashMap<>(), + false, + false); - RelDataType relDataType = graph.getRowType(typeFactory); - assertEquals(relDataType.toString(), "Graph:RecordType:peek(" - + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, DOUBLE age) person, " - + "Edge: RecordType:peek(BIGINT src, BIGINT dst, VARCHAR ~label, DOUBLE weight) follow)" - ); - assertEquals(graph.getName(), "g0"); - assertEquals(graph.getLabelType().getName(), "STRING"); - assertEquals(graph.getVertexTables().size(), 1); - assertEquals(graph.getEdgeTables().size(), 1); - assertEquals(graph.getConfig().getConfigMap().size(), 1); - assertEquals(BackendType.of(graph.getStoreType()), BackendType.Memory); + RelDataType relDataType = graph.getRowType(typeFactory); + assertEquals( + relDataType.toString(), + "Graph:RecordType:peek(Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name," + + " DOUBLE age) person, Edge: RecordType:peek(BIGINT src, BIGINT dst, VARCHAR ~label," + + " DOUBLE weight) follow)"); + assertEquals(graph.getName(), "g0"); + assertEquals(graph.getLabelType().getName(), "STRING"); + assertEquals(graph.getVertexTables().size(), 1); + assertEquals(graph.getEdgeTables().size(), 1); + assertEquals(graph.getConfig().getConfigMap().size(), 1); + assertEquals(BackendType.of(graph.getStoreType()), BackendType.Memory); - Map globalConfMap = new HashMap<>(); - globalConfMap.put(DSLConfigKeys.GEAFLOW_DSL_STORE_TYPE.getKey(), "rocksdb"); - Configuration globalConf = new Configuration(globalConfMap); - Configuration conf = graph.getConfigWithGlobal(globalConf); - assertEquals(conf.getString(DSLConfigKeys.GEAFLOW_DSL_STORE_TYPE), "MEMORY"); - } + Map globalConfMap = new HashMap<>(); + globalConfMap.put(DSLConfigKeys.GEAFLOW_DSL_STORE_TYPE.getKey(), "rocksdb"); + Configuration globalConf = new Configuration(globalConfMap); + Configuration conf = graph.getConfigWithGlobal(globalConf); + assertEquals(conf.getString(DSLConfigKeys.GEAFLOW_DSL_STORE_TYPE), "MEMORY"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/schema/InternalFunctionsTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/schema/InternalFunctionsTest.java index 084e4e5af..d0f60f644 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/schema/InternalFunctionsTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/schema/InternalFunctionsTest.java @@ -56,612 +56,610 @@ import static org.apache.geaflow.dsl.schema.function.GeaFlowBuiltinFunctions.unequal; import static org.testng.AssertJUnit.assertEquals; -import com.google.common.collect.Lists; import java.math.BigDecimal; import java.sql.Timestamp; import java.util.List; + import org.apache.geaflow.dsl.common.function.UDTF; import org.apache.geaflow.dsl.planner.GQLJavaTypeFactory; import org.apache.geaflow.dsl.schema.function.GeaFlowUserDefinedTableFunction; import org.testng.Assert; import org.testng.annotations.Test; -public class InternalFunctionsTest { - Byte byte1 = Byte.valueOf("1"); - Short short1 = Short.valueOf("1"); - Integer int1 = Integer.valueOf("1"); - Long long1 = Long.valueOf("1"); - Double double1 = Double.valueOf("1.0"); - BigDecimal decimal1 = BigDecimal.valueOf(1); - Byte byteNull = null; - Short shortNull = null; - Integer intNull = null; - Long longNull = null; - Double doubleNull = null; - BigDecimal decimalNull = null; - - @Test - public void testPlus() { - assertEquals((long) plus(long1, long1), 2); - assertEquals((int) plus(int1, int1), 2); - assertEquals((int) plus(short1, short1), 2); - assertEquals((int) plus(byte1, byte1), 2); - assertEquals(plus(double1, double1), 2.0); - assertEquals(plus(decimal1, decimal1), BigDecimal.valueOf(2)); - assertEquals((long) plus(long1, int1), 2); - assertEquals((long) plus(long1, short1), 2); - assertEquals((long) plus(long1, byte1), 2); - assertEquals(plus(long1, double1), 2.0); - assertEquals(plus(long1, decimal1), BigDecimal.valueOf(2)); - assertEquals((long) plus(int1, long1), 2); - assertEquals((int) plus(int1, short1), 2); - assertEquals((int) plus(int1, byte1), 2); - assertEquals(plus(int1, double1), 2.0); - assertEquals(plus(int1, decimal1), BigDecimal.valueOf(2)); - assertEquals((long) plus(short1, long1), 2); - assertEquals((int) plus(short1, int1), 2); - assertEquals((int) plus(short1, byte1), 2); - assertEquals(plus(short1, double1), 2.0); - assertEquals(plus(short1, decimal1), BigDecimal.valueOf(2)); - assertEquals((long) plus(byte1, long1), 2); - assertEquals((int) plus(byte1, int1), 2); - assertEquals((int) plus(byte1, short1), 2); - assertEquals(plus(byte1, double1), 2.0); - assertEquals(plus(byte1, decimal1), BigDecimal.valueOf(2)); - assertEquals(plus(double1, long1), 2.0); - assertEquals(plus(double1, int1), 2.0); - assertEquals(plus(double1, short1), 2.0); - assertEquals(plus(double1, byte1), 2.0); - assertEquals(plus(double1, decimal1), BigDecimal.valueOf(2)); - assertEquals(plus(decimal1, long1), BigDecimal.valueOf(2)); - assertEquals(plus(decimal1, int1), BigDecimal.valueOf(2)); - assertEquals(plus(decimal1, short1), BigDecimal.valueOf(2)); - assertEquals(plus(decimal1, byte1), BigDecimal.valueOf(2)); - assertEquals(plus(decimal1, double1), BigDecimal.valueOf(2)); - - Assert.assertNull(plus(longNull, long1)); - Assert.assertNull(plus(intNull, int1)); - Assert.assertNull(plus(shortNull, short1)); - Assert.assertNull(plus(byteNull, byte1)); - Assert.assertNull(plus(doubleNull, double1)); - Assert.assertNull(plus(decimalNull, decimal1)); - Assert.assertNull(plus(longNull, int1)); - Assert.assertNull(plus(longNull, short1)); - Assert.assertNull(plus(longNull, byte1)); - Assert.assertNull(plus(longNull, double1)); - Assert.assertNull(plus(longNull, decimal1)); - Assert.assertNull(plus(intNull, long1)); - Assert.assertNull(plus(intNull, short1)); - Assert.assertNull(plus(intNull, byte1)); - Assert.assertNull(plus(intNull, double1)); - Assert.assertNull(plus(intNull, decimal1)); - Assert.assertNull(plus(shortNull, long1)); - Assert.assertNull(plus(shortNull, int1)); - Assert.assertNull(plus(shortNull, byte1)); - Assert.assertNull(plus(shortNull, double1)); - Assert.assertNull(plus(shortNull, decimal1)); - Assert.assertNull(plus(byteNull, long1)); - Assert.assertNull(plus(byteNull, int1)); - Assert.assertNull(plus(byteNull, short1)); - Assert.assertNull(plus(byteNull, double1)); - Assert.assertNull(plus(byteNull, decimal1)); - Assert.assertNull(plus(doubleNull, long1)); - Assert.assertNull(plus(doubleNull, int1)); - Assert.assertNull(plus(doubleNull, short1)); - Assert.assertNull(plus(doubleNull, byte1)); - Assert.assertNull(plus(doubleNull, decimal1)); - Assert.assertNull(plus(decimalNull, long1)); - Assert.assertNull(plus(decimalNull, int1)); - Assert.assertNull(plus(decimalNull, short1)); - Assert.assertNull(plus(decimalNull, byte1)); - Assert.assertNull(plus(decimalNull, double1)); - } - - @Test - public void testMinus() { - assertEquals((long) minus(long1, long1), 0); - assertEquals((int) minus(int1, int1), 0); - assertEquals((int) minus(short1, short1), 0); - assertEquals((int) minus(byte1, byte1), 0); - assertEquals(minus(double1, double1), 0.0); - assertEquals(minus(decimal1, decimal1), BigDecimal.valueOf(0)); - assertEquals((long) minus(long1, int1), 0); - assertEquals((long) minus(long1, short1), 0); - assertEquals((long) minus(long1, byte1), 0); - assertEquals(minus(long1, double1), 0.0); - assertEquals(minus(long1, decimal1), BigDecimal.valueOf(0)); - assertEquals((long) minus(int1, long1), 0); - assertEquals((int) minus(int1, short1), 0); - assertEquals((int) minus(int1, byte1), 0); - assertEquals(minus(int1, double1), 0.0); - assertEquals(minus(int1, decimal1), BigDecimal.valueOf(0)); - assertEquals((long) minus(short1, long1), 0); - assertEquals((int) minus(short1, int1), 0); - assertEquals((int) minus(short1, byte1), 0); - assertEquals(minus(short1, double1), 0.0); - assertEquals(minus(short1, decimal1), BigDecimal.valueOf(0)); - assertEquals((long) minus(byte1, long1), 0); - assertEquals((int) minus(byte1, int1), 0); - assertEquals((int) minus(byte1, short1), 0); - assertEquals(minus(byte1, double1), 0.0); - assertEquals(minus(byte1, decimal1), BigDecimal.valueOf(0)); - assertEquals(minus(double1, long1), 0.0); - assertEquals(minus(double1, int1), 0.0); - assertEquals(minus(double1, short1), 0.0); - assertEquals(minus(double1, byte1), 0.0); - assertEquals(minus(double1, decimal1), BigDecimal.valueOf(0)); - assertEquals(minus(decimal1, long1), BigDecimal.valueOf(0)); - assertEquals(minus(decimal1, int1), BigDecimal.valueOf(0)); - assertEquals(minus(decimal1, short1), BigDecimal.valueOf(0)); - assertEquals(minus(decimal1, byte1), BigDecimal.valueOf(0)); - assertEquals(minus(decimal1, double1), BigDecimal.valueOf(0)); - - Assert.assertNull(minus(longNull, long1)); - Assert.assertNull(minus(intNull, int1)); - Assert.assertNull(minus(shortNull, short1)); - Assert.assertNull(minus(byteNull, byte1)); - Assert.assertNull(minus(doubleNull, double1)); - Assert.assertNull(minus(decimalNull, decimal1)); - Assert.assertNull(minus(longNull, int1)); - Assert.assertNull(minus(longNull, short1)); - Assert.assertNull(minus(longNull, byte1)); - Assert.assertNull(minus(longNull, double1)); - Assert.assertNull(minus(longNull, decimal1)); - Assert.assertNull(minus(intNull, long1)); - Assert.assertNull(minus(intNull, short1)); - Assert.assertNull(minus(intNull, byte1)); - Assert.assertNull(minus(intNull, double1)); - Assert.assertNull(minus(intNull, decimal1)); - Assert.assertNull(minus(shortNull, long1)); - Assert.assertNull(minus(shortNull, int1)); - Assert.assertNull(minus(shortNull, byte1)); - Assert.assertNull(minus(shortNull, double1)); - Assert.assertNull(minus(shortNull, decimal1)); - Assert.assertNull(minus(byteNull, long1)); - Assert.assertNull(minus(byteNull, int1)); - Assert.assertNull(minus(byteNull, short1)); - Assert.assertNull(minus(byteNull, double1)); - Assert.assertNull(minus(byteNull, decimal1)); - Assert.assertNull(minus(doubleNull, long1)); - Assert.assertNull(minus(doubleNull, int1)); - Assert.assertNull(minus(doubleNull, short1)); - Assert.assertNull(minus(doubleNull, byte1)); - Assert.assertNull(minus(doubleNull, decimal1)); - Assert.assertNull(minus(decimalNull, long1)); - Assert.assertNull(minus(decimalNull, int1)); - Assert.assertNull(minus(decimalNull, short1)); - Assert.assertNull(minus(decimalNull, byte1)); - Assert.assertNull(minus(decimalNull, double1)); - } - - @Test - public void testTimes() { - assertEquals((long) times(long1, long1), 1); - assertEquals((long) times(int1, int1), 1); - assertEquals((int) times(short1, short1), 1); - assertEquals((int) times(byte1, byte1), 1); - assertEquals(times(double1, double1), 1.0); - assertEquals(times(decimal1, decimal1), BigDecimal.valueOf(1)); - assertEquals((long) times(long1, int1), 1); - assertEquals((long) times(long1, short1), 1); - assertEquals((long) times(long1, byte1), 1); - assertEquals(times(long1, double1), 1.0); - assertEquals(times(long1, decimal1), BigDecimal.valueOf(1)); - assertEquals((long) times(int1, long1), 1); - assertEquals((int) times(int1, short1), 1); - assertEquals((int) times(int1, byte1), 1); - assertEquals(times(int1, double1), 1.0); - assertEquals(times(int1, decimal1), BigDecimal.valueOf(1)); - assertEquals((long) times(short1, long1), 1); - assertEquals((int) times(short1, int1), 1); - assertEquals((int) times(short1, byte1), 1); - assertEquals(times(short1, double1), 1.0); - assertEquals(times(short1, decimal1), BigDecimal.valueOf(1)); - assertEquals((long) times(byte1, long1), 1); - assertEquals((int) times(byte1, int1), 1); - assertEquals((int) times(byte1, short1), 1); - assertEquals(times(byte1, double1), 1.0); - assertEquals(times(byte1, decimal1), BigDecimal.valueOf(1)); - assertEquals(times(double1, long1), 1.0); - assertEquals(times(double1, int1), 1.0); - assertEquals(times(double1, short1), 1.0); - assertEquals(times(double1, byte1), 1.0); - assertEquals(times(double1, decimal1), BigDecimal.valueOf(1)); - assertEquals(times(decimal1, long1), BigDecimal.valueOf(1)); - assertEquals(times(decimal1, int1), BigDecimal.valueOf(1)); - assertEquals(times(decimal1, short1), BigDecimal.valueOf(1)); - assertEquals(times(decimal1, byte1), BigDecimal.valueOf(1)); - assertEquals(times(decimal1, double1), BigDecimal.valueOf(1)); - - Assert.assertNull(times(longNull, long1)); - Assert.assertNull(times(intNull, int1)); - Assert.assertNull(times(shortNull, short1)); - Assert.assertNull(times(byteNull, byte1)); - Assert.assertNull(times(doubleNull, double1)); - Assert.assertNull(times(decimalNull, decimal1)); - Assert.assertNull(times(longNull, int1)); - Assert.assertNull(times(longNull, short1)); - Assert.assertNull(times(longNull, byte1)); - Assert.assertNull(times(longNull, double1)); - Assert.assertNull(times(longNull, decimal1)); - Assert.assertNull(times(intNull, long1)); - Assert.assertNull(times(intNull, short1)); - Assert.assertNull(times(intNull, byte1)); - Assert.assertNull(times(intNull, double1)); - Assert.assertNull(times(intNull, decimal1)); - Assert.assertNull(times(shortNull, long1)); - Assert.assertNull(times(shortNull, int1)); - Assert.assertNull(times(shortNull, byte1)); - Assert.assertNull(times(shortNull, double1)); - Assert.assertNull(times(shortNull, decimal1)); - Assert.assertNull(times(byteNull, long1)); - Assert.assertNull(times(byteNull, int1)); - Assert.assertNull(times(byteNull, short1)); - Assert.assertNull(times(byteNull, double1)); - Assert.assertNull(times(byteNull, decimal1)); - Assert.assertNull(times(doubleNull, long1)); - Assert.assertNull(times(doubleNull, int1)); - Assert.assertNull(times(doubleNull, short1)); - Assert.assertNull(times(doubleNull, byte1)); - Assert.assertNull(times(doubleNull, decimal1)); - Assert.assertNull(times(decimalNull, long1)); - Assert.assertNull(times(decimalNull, int1)); - Assert.assertNull(times(decimalNull, short1)); - Assert.assertNull(times(decimalNull, byte1)); - Assert.assertNull(times(decimalNull, double1)); - } - - @Test - public void testDivide() { - assertEquals((long) divide(long1, long1), 1); - assertEquals((int) divide(int1, int1), 1); - assertEquals((int) divide(short1, short1), 1); - assertEquals((int) divide(byte1, byte1), 1); - assertEquals(divide(double1, double1), 1.0); - assertEquals(divide(decimal1, decimal1), BigDecimal.valueOf(1)); - assertEquals((long) divide(long1, int1), 1); - assertEquals((long) divide(long1, short1), 1); - assertEquals((long) divide(long1, byte1), 1); - assertEquals(divide(long1, double1), 1.0); - assertEquals(divide(long1, decimal1), BigDecimal.valueOf(1)); - assertEquals((long) divide(int1, long1), 1); - assertEquals((int) divide(int1, short1), 1); - assertEquals((int) divide(int1, byte1), 1); - assertEquals(divide(int1, double1), 1.0); - assertEquals(divide(int1, decimal1), BigDecimal.valueOf(1)); - assertEquals((long) divide(short1, long1), 1); - assertEquals((int) divide(short1, int1), 1); - assertEquals((int) divide(short1, byte1), 1); - assertEquals(divide(short1, double1), 1.0); - assertEquals(divide(short1, decimal1), BigDecimal.valueOf(1)); - assertEquals((long) divide(byte1, long1), 1); - assertEquals((int) divide(byte1, int1), 1); - assertEquals((int) divide(byte1, short1), 1); - assertEquals(divide(byte1, double1), 1.0); - assertEquals(divide(byte1, decimal1), BigDecimal.valueOf(1)); - assertEquals(divide(double1, long1), 1.0); - assertEquals(divide(double1, int1), 1.0); - assertEquals(divide(double1, short1), 1.0); - assertEquals(divide(double1, byte1), 1.0); - assertEquals(divide(double1, decimal1), BigDecimal.valueOf(1)); - assertEquals(divide(decimal1, long1), BigDecimal.valueOf(1)); - assertEquals(divide(decimal1, int1), BigDecimal.valueOf(1)); - assertEquals(divide(decimal1, short1), BigDecimal.valueOf(1)); - assertEquals(divide(decimal1, byte1), BigDecimal.valueOf(1)); - assertEquals(divide(decimal1, double1), BigDecimal.valueOf(1)); - - Assert.assertNull(divide(longNull, long1)); - Assert.assertNull(divide(intNull, int1)); - Assert.assertNull(divide(shortNull, short1)); - Assert.assertNull(divide(byteNull, byte1)); - Assert.assertNull(divide(doubleNull, double1)); - Assert.assertNull(divide(decimalNull, decimal1)); - Assert.assertNull(divide(longNull, int1)); - Assert.assertNull(divide(longNull, short1)); - Assert.assertNull(divide(longNull, byte1)); - Assert.assertNull(divide(longNull, double1)); - Assert.assertNull(divide(longNull, decimal1)); - Assert.assertNull(divide(intNull, long1)); - Assert.assertNull(divide(intNull, short1)); - Assert.assertNull(divide(intNull, byte1)); - Assert.assertNull(divide(intNull, double1)); - Assert.assertNull(divide(intNull, decimal1)); - Assert.assertNull(divide(shortNull, long1)); - Assert.assertNull(divide(shortNull, int1)); - Assert.assertNull(divide(shortNull, byte1)); - Assert.assertNull(divide(shortNull, double1)); - Assert.assertNull(divide(shortNull, decimal1)); - Assert.assertNull(divide(byteNull, long1)); - Assert.assertNull(divide(byteNull, int1)); - Assert.assertNull(divide(byteNull, short1)); - Assert.assertNull(divide(byteNull, double1)); - Assert.assertNull(divide(byteNull, decimal1)); - Assert.assertNull(divide(doubleNull, long1)); - Assert.assertNull(divide(doubleNull, int1)); - Assert.assertNull(divide(doubleNull, short1)); - Assert.assertNull(divide(doubleNull, byte1)); - Assert.assertNull(divide(doubleNull, decimal1)); - Assert.assertNull(divide(decimalNull, long1)); - Assert.assertNull(divide(decimalNull, int1)); - Assert.assertNull(divide(decimalNull, short1)); - Assert.assertNull(divide(decimalNull, byte1)); - Assert.assertNull(divide(decimalNull, double1)); - } - - @Test - public void testMod() { - assertEquals((long) mod(long1, long1), 0); - assertEquals((int) mod(int1, int1), 0); - assertEquals(mod(double1, double1), 0.0); - - Assert.assertNull(mod(longNull, long1)); - Assert.assertNull(mod(intNull, int1)); - Assert.assertNull(mod(doubleNull, double1)); - } - - @Test - public void testPower() { - assertEquals(power(double1, double1), 1.0); - - Assert.assertNull(power(doubleNull, double1)); - } - - @Test - public void testAbs() { - assertEquals((long) abs(long1), 1); - assertEquals((int) abs(int1), 1); - assertEquals((int) abs(short1), 1); - assertEquals((int) abs(byte1), 1); - assertEquals(abs(decimal1), BigDecimal.valueOf(1)); - assertEquals(abs(double1), 1.0); - - Assert.assertNull(abs(longNull)); - Assert.assertNull(abs(intNull)); - Assert.assertNull(abs(shortNull)); - Assert.assertNull(abs(byteNull)); - Assert.assertNull(abs(decimalNull)); - Assert.assertNull(abs(doubleNull)); - } - - @Test - public void testTrigonometric() { - assertEquals(asin(double1), 1.5707963267948966); - assertEquals(acos(double1), 0.0); - assertEquals(atan(double1), 0.7853981633974483); - assertEquals(ceil(double1), 1.0); - assertEquals(ceil(long1), long1); - assertEquals(ceil(int1), int1); - assertEquals(cot(double1), 0.6420926159343306); - assertEquals(cos(double1), 0.5403023058681398); - assertEquals(degrees(double1), 57.29577951308232); - assertEquals(radians(double1), 0.017453292519943295); - assertEquals(sign(double1), 1.0); - assertEquals(sin(double1), 0.8414709848078965); - assertEquals(tan(double1), 1.5574077246549023); - - Assert.assertNull(asin(doubleNull)); - Assert.assertNull(acos(doubleNull)); - Assert.assertNull(atan(doubleNull)); - Assert.assertNull(ceil(doubleNull)); - Assert.assertNull(ceil(longNull)); - Assert.assertNull(ceil(intNull)); - Assert.assertNull(cot(doubleNull)); - Assert.assertNull(cos(doubleNull)); - Assert.assertNull(degrees(doubleNull)); - Assert.assertNull(radians(doubleNull)); - Assert.assertNull(sign(doubleNull)); - Assert.assertNull(sin(doubleNull)); - Assert.assertNull(tan(doubleNull)); - } - - @Test - public void testMath() { - assertEquals(exp(double1), 2.718281828459045, 1e-15); - assertEquals(floor(double1), 1.0); - assertEquals((long) floor(long1), 1); - assertEquals((int) floor(int1), 1); - assertEquals(ln(double1), 0.0); - assertEquals(log10(double1), 0.0); - assertEquals(minusPrefix(double1), -1.0); - assertEquals((long) minusPrefix(long1), -1); - assertEquals((int) minusPrefix(int1), -1); - assertEquals((int) minusPrefix(short1), -1); - assertEquals((int) minusPrefix(byte1), -1); - assertEquals(minusPrefix(decimal1), BigDecimal.valueOf(-1)); - - System.out.println(rand()); - System.out.println(rand(long1)); - System.out.println(rand(long1, int1)); - System.out.println(randInt(int1)); - System.out.println(randInt(long1, int1)); - - Assert.assertNull(exp(doubleNull)); - Assert.assertNull(floor(doubleNull)); - Assert.assertNull(floor(longNull)); - Assert.assertNull(floor(intNull)); - Assert.assertNull(ln(doubleNull)); - Assert.assertNull(log10(doubleNull)); - Assert.assertNull(minusPrefix(doubleNull)); - Assert.assertNull(minusPrefix(longNull)); - Assert.assertNull(minusPrefix(intNull)); - Assert.assertNull(minusPrefix(shortNull)); - Assert.assertNull(minusPrefix(byteNull)); - Assert.assertNull(minusPrefix(decimalNull)); - } - - @Test - public void testEqual() { - String string1 = "1"; - String stringNull = null; - Boolean boolTrue = true; - Boolean boolNull = null; - - Assert.assertTrue(equal(long1, long1)); - Assert.assertTrue(equal(double1, double1)); - Assert.assertTrue(equal(decimal1, decimal1)); - Assert.assertTrue(equal(string1, string1)); - Assert.assertTrue(equal(boolTrue, boolTrue)); - Assert.assertTrue(equal(string1, int1)); - Assert.assertTrue(equal(int1, string1)); - Assert.assertTrue(equal(string1, double1)); - Assert.assertTrue(equal(double1, string1)); - Assert.assertTrue(equal(string1, long1)); - Assert.assertTrue(equal(long1, string1)); - Assert.assertTrue(equal(string1, false)); - Assert.assertTrue(equal(false, string1)); - Assert.assertTrue(equal((Object) boolTrue, boolTrue)); - - Assert.assertNull(equal(longNull, long1)); - Assert.assertNull(equal(doubleNull, double1)); - Assert.assertNull(equal(decimalNull, decimal1)); - Assert.assertNull(equal(stringNull, string1)); - Assert.assertNull(equal(boolNull, boolTrue)); - Assert.assertNull(equal(stringNull, int1)); - Assert.assertNull(equal(intNull, string1)); - Assert.assertNull(equal(stringNull, double1)); - Assert.assertNull(equal(doubleNull, string1)); - Assert.assertNull(equal(stringNull, long1)); - Assert.assertNull(equal(longNull, string1)); - Assert.assertNull(equal(stringNull, boolTrue)); - Assert.assertNull(equal(boolNull, string1)); - Assert.assertNull(equal(stringNull, boolTrue)); - Assert.assertNull(equal((Object) boolNull, boolTrue)); - } - - @Test - public void testUnequal() { - String string1 = "1"; - String stringNull = null; - Boolean boolTrue = true; - Boolean boolNull = null; - - Assert.assertFalse(unequal(long1, long1)); - Assert.assertFalse(unequal(double1, double1)); - Assert.assertFalse(unequal(decimal1, decimal1)); - Assert.assertFalse(unequal(string1, string1)); - Assert.assertFalse(unequal(boolTrue, boolTrue)); - Assert.assertFalse(unequal(string1, int1)); - Assert.assertFalse(unequal(int1, string1)); - Assert.assertFalse(unequal(string1, double1)); - Assert.assertFalse(unequal(double1, string1)); - Assert.assertFalse(unequal(string1, long1)); - Assert.assertFalse(unequal(long1, string1)); - Assert.assertFalse(unequal(string1, false)); - Assert.assertFalse(unequal(false, string1)); - Assert.assertFalse(unequal((Object) boolTrue, boolTrue)); - - Assert.assertNull(unequal(longNull, long1)); - Assert.assertNull(unequal(doubleNull, double1)); - Assert.assertNull(unequal(decimalNull, decimal1)); - Assert.assertNull(unequal(stringNull, string1)); - Assert.assertNull(unequal(boolNull, boolTrue)); - Assert.assertNull(unequal(stringNull, int1)); - Assert.assertNull(unequal(intNull, string1)); - Assert.assertNull(unequal(stringNull, double1)); - Assert.assertNull(unequal(doubleNull, string1)); - Assert.assertNull(unequal(stringNull, long1)); - Assert.assertNull(unequal(longNull, string1)); - Assert.assertNull(unequal(stringNull, boolTrue)); - Assert.assertNull(unequal(boolNull, string1)); - Assert.assertNull(unequal(stringNull, boolTrue)); - Assert.assertFalse(unequal((Object) boolTrue, boolTrue)); - } - - @Test - public void testCompare() { - String string1 = "1"; - String stringNull = null; - - Assert.assertFalse(lessThan(long1, long1)); - Assert.assertFalse(lessThan(double1, double1)); - Assert.assertFalse(lessThan(decimal1, decimal1)); - Assert.assertFalse(lessThan(string1, string1)); - - Assert.assertTrue(greaterThanEq(long1, long1)); - Assert.assertTrue(greaterThanEq(double1, double1)); - Assert.assertTrue(greaterThanEq(decimal1, decimal1)); - Assert.assertTrue(greaterThanEq(string1, string1)); - - Assert.assertTrue(lessThanEq(long1, long1)); - Assert.assertTrue(lessThanEq(double1, double1)); - Assert.assertTrue(lessThanEq(decimal1, decimal1)); - Assert.assertTrue(lessThanEq(string1, string1)); - - Assert.assertFalse(greaterThan(long1, long1)); - Assert.assertFalse(greaterThan(double1, double1)); - Assert.assertFalse(greaterThan(decimal1, decimal1)); - Assert.assertFalse(greaterThan(string1, string1)); - - Assert.assertNull(lessThan(longNull, long1)); - Assert.assertNull(lessThan(doubleNull, double1)); - Assert.assertNull(lessThan(decimalNull, decimal1)); - Assert.assertNull(lessThan(stringNull, string1)); - - Assert.assertNull(greaterThanEq(longNull, long1)); - Assert.assertNull(greaterThanEq(doubleNull, double1)); - Assert.assertNull(greaterThanEq(decimalNull, decimal1)); - Assert.assertNull(greaterThanEq(stringNull, string1)); - - Assert.assertNull(lessThanEq(longNull, long1)); - Assert.assertNull(lessThanEq(doubleNull, double1)); - Assert.assertNull(lessThanEq(decimalNull, decimal1)); - Assert.assertNull(lessThanEq(stringNull, string1)); - - Assert.assertNull(greaterThan(longNull, long1)); - Assert.assertNull(greaterThan(doubleNull, double1)); - Assert.assertNull(greaterThan(decimalNull, decimal1)); - Assert.assertNull(greaterThan(stringNull, string1)); - } - - @Test - public void testTimestampUtil() { - Timestamp ts = Timestamp.valueOf("1987-06-05 04:03:02"); - Assert.assertEquals(timestampCeil(ts, 1000L), Timestamp.valueOf("1987-06-05 04:03:03")); - Assert.assertEquals(timestampCeil(ts, 60000L), Timestamp.valueOf("1987-06-05 04:04:00.0")); - Assert.assertEquals(timestampCeil(ts, 3600000L), Timestamp.valueOf("1987-06-05 05:00:00.0")); - Assert.assertEquals(timestampCeil(ts, 86400000L), Timestamp.valueOf("1987-06-06 00:00:00.0")); - Assert.assertEquals(timestampTumble(ts, 1000L), Timestamp.valueOf("1987-06-05 04:03:03")); - Assert.assertEquals(timestampTumble(ts, 60000L), Timestamp.valueOf("1987-06-05 04:04:00.0")); - Assert.assertEquals(timestampTumble(ts, 3600000L), Timestamp.valueOf("1987-06-05 05:00:00.0")); - Assert.assertEquals(timestampFloor(ts, 1000L), Timestamp.valueOf("1987-06-05 04:03:02.0")); - Assert.assertEquals(timestampFloor(ts, 60000L), Timestamp.valueOf("1987-06-05 04:03:00.0")); - Assert.assertEquals(timestampFloor(ts, 3600000L), Timestamp.valueOf("1987-06-05 04:00:00.0")); - Assert.assertEquals(timestampFloor(ts, 86400000L), Timestamp.valueOf("1987-06-05 00:00:00.0")); - - Assert.assertEquals(plus(ts, 1L), Timestamp.valueOf("1987-06-05 04:03:02.001")); - Assert.assertEquals(minus(ts, 1L), Timestamp.valueOf("1987-06-05 04:03:01.999")); - Assert.assertEquals((long) minus(ts, Timestamp.valueOf("1987-06-05 00:00:00.0")), 14582000L); - - Assert.assertNull(plus((Timestamp) null, 1L)); - Assert.assertNull(minus((Timestamp) null, 1L)); - Assert.assertNull(minus((Timestamp) null, Timestamp.valueOf("1987-06-05 00:00:00.0"))); - - } - - @Test - public void testOtherFunction() { - Assert.assertEquals(round(double1, 2), 1.0); - - Assert.assertNull(round(doubleNull, 2)); - } - - @Test - public void testGeaFlowUserDefinedTableFunction() { - GQLJavaTypeFactory typeFactory = GQLJavaTypeFactory.create(); - GeaFlowUserDefinedTableFunction.create("testFunction", - UtUserDefinedTableFunction.class, - typeFactory - ); - } - - public static class UtUserDefinedTableFunction extends UDTF { - - public UtUserDefinedTableFunction() { - } +import com.google.common.collect.Lists; - @Override - public List> getReturnType(List> paramTypes, List outFieldNames) { - return Lists.newArrayList(String.class, Long.class); - } +public class InternalFunctionsTest { + Byte byte1 = Byte.valueOf("1"); + Short short1 = Short.valueOf("1"); + Integer int1 = Integer.valueOf("1"); + Long long1 = Long.valueOf("1"); + Double double1 = Double.valueOf("1.0"); + BigDecimal decimal1 = BigDecimal.valueOf(1); + Byte byteNull = null; + Short shortNull = null; + Integer intNull = null; + Long longNull = null; + Double doubleNull = null; + BigDecimal decimalNull = null; + + @Test + public void testPlus() { + assertEquals((long) plus(long1, long1), 2); + assertEquals((int) plus(int1, int1), 2); + assertEquals((int) plus(short1, short1), 2); + assertEquals((int) plus(byte1, byte1), 2); + assertEquals(plus(double1, double1), 2.0); + assertEquals(plus(decimal1, decimal1), BigDecimal.valueOf(2)); + assertEquals((long) plus(long1, int1), 2); + assertEquals((long) plus(long1, short1), 2); + assertEquals((long) plus(long1, byte1), 2); + assertEquals(plus(long1, double1), 2.0); + assertEquals(plus(long1, decimal1), BigDecimal.valueOf(2)); + assertEquals((long) plus(int1, long1), 2); + assertEquals((int) plus(int1, short1), 2); + assertEquals((int) plus(int1, byte1), 2); + assertEquals(plus(int1, double1), 2.0); + assertEquals(plus(int1, decimal1), BigDecimal.valueOf(2)); + assertEquals((long) plus(short1, long1), 2); + assertEquals((int) plus(short1, int1), 2); + assertEquals((int) plus(short1, byte1), 2); + assertEquals(plus(short1, double1), 2.0); + assertEquals(plus(short1, decimal1), BigDecimal.valueOf(2)); + assertEquals((long) plus(byte1, long1), 2); + assertEquals((int) plus(byte1, int1), 2); + assertEquals((int) plus(byte1, short1), 2); + assertEquals(plus(byte1, double1), 2.0); + assertEquals(plus(byte1, decimal1), BigDecimal.valueOf(2)); + assertEquals(plus(double1, long1), 2.0); + assertEquals(plus(double1, int1), 2.0); + assertEquals(plus(double1, short1), 2.0); + assertEquals(plus(double1, byte1), 2.0); + assertEquals(plus(double1, decimal1), BigDecimal.valueOf(2)); + assertEquals(plus(decimal1, long1), BigDecimal.valueOf(2)); + assertEquals(plus(decimal1, int1), BigDecimal.valueOf(2)); + assertEquals(plus(decimal1, short1), BigDecimal.valueOf(2)); + assertEquals(plus(decimal1, byte1), BigDecimal.valueOf(2)); + assertEquals(plus(decimal1, double1), BigDecimal.valueOf(2)); + + Assert.assertNull(plus(longNull, long1)); + Assert.assertNull(plus(intNull, int1)); + Assert.assertNull(plus(shortNull, short1)); + Assert.assertNull(plus(byteNull, byte1)); + Assert.assertNull(plus(doubleNull, double1)); + Assert.assertNull(plus(decimalNull, decimal1)); + Assert.assertNull(plus(longNull, int1)); + Assert.assertNull(plus(longNull, short1)); + Assert.assertNull(plus(longNull, byte1)); + Assert.assertNull(plus(longNull, double1)); + Assert.assertNull(plus(longNull, decimal1)); + Assert.assertNull(plus(intNull, long1)); + Assert.assertNull(plus(intNull, short1)); + Assert.assertNull(plus(intNull, byte1)); + Assert.assertNull(plus(intNull, double1)); + Assert.assertNull(plus(intNull, decimal1)); + Assert.assertNull(plus(shortNull, long1)); + Assert.assertNull(plus(shortNull, int1)); + Assert.assertNull(plus(shortNull, byte1)); + Assert.assertNull(plus(shortNull, double1)); + Assert.assertNull(plus(shortNull, decimal1)); + Assert.assertNull(plus(byteNull, long1)); + Assert.assertNull(plus(byteNull, int1)); + Assert.assertNull(plus(byteNull, short1)); + Assert.assertNull(plus(byteNull, double1)); + Assert.assertNull(plus(byteNull, decimal1)); + Assert.assertNull(plus(doubleNull, long1)); + Assert.assertNull(plus(doubleNull, int1)); + Assert.assertNull(plus(doubleNull, short1)); + Assert.assertNull(plus(doubleNull, byte1)); + Assert.assertNull(plus(doubleNull, decimal1)); + Assert.assertNull(plus(decimalNull, long1)); + Assert.assertNull(plus(decimalNull, int1)); + Assert.assertNull(plus(decimalNull, short1)); + Assert.assertNull(plus(decimalNull, byte1)); + Assert.assertNull(plus(decimalNull, double1)); + } + + @Test + public void testMinus() { + assertEquals((long) minus(long1, long1), 0); + assertEquals((int) minus(int1, int1), 0); + assertEquals((int) minus(short1, short1), 0); + assertEquals((int) minus(byte1, byte1), 0); + assertEquals(minus(double1, double1), 0.0); + assertEquals(minus(decimal1, decimal1), BigDecimal.valueOf(0)); + assertEquals((long) minus(long1, int1), 0); + assertEquals((long) minus(long1, short1), 0); + assertEquals((long) minus(long1, byte1), 0); + assertEquals(minus(long1, double1), 0.0); + assertEquals(minus(long1, decimal1), BigDecimal.valueOf(0)); + assertEquals((long) minus(int1, long1), 0); + assertEquals((int) minus(int1, short1), 0); + assertEquals((int) minus(int1, byte1), 0); + assertEquals(minus(int1, double1), 0.0); + assertEquals(minus(int1, decimal1), BigDecimal.valueOf(0)); + assertEquals((long) minus(short1, long1), 0); + assertEquals((int) minus(short1, int1), 0); + assertEquals((int) minus(short1, byte1), 0); + assertEquals(minus(short1, double1), 0.0); + assertEquals(minus(short1, decimal1), BigDecimal.valueOf(0)); + assertEquals((long) minus(byte1, long1), 0); + assertEquals((int) minus(byte1, int1), 0); + assertEquals((int) minus(byte1, short1), 0); + assertEquals(minus(byte1, double1), 0.0); + assertEquals(minus(byte1, decimal1), BigDecimal.valueOf(0)); + assertEquals(minus(double1, long1), 0.0); + assertEquals(minus(double1, int1), 0.0); + assertEquals(minus(double1, short1), 0.0); + assertEquals(minus(double1, byte1), 0.0); + assertEquals(minus(double1, decimal1), BigDecimal.valueOf(0)); + assertEquals(minus(decimal1, long1), BigDecimal.valueOf(0)); + assertEquals(minus(decimal1, int1), BigDecimal.valueOf(0)); + assertEquals(minus(decimal1, short1), BigDecimal.valueOf(0)); + assertEquals(minus(decimal1, byte1), BigDecimal.valueOf(0)); + assertEquals(minus(decimal1, double1), BigDecimal.valueOf(0)); + + Assert.assertNull(minus(longNull, long1)); + Assert.assertNull(minus(intNull, int1)); + Assert.assertNull(minus(shortNull, short1)); + Assert.assertNull(minus(byteNull, byte1)); + Assert.assertNull(minus(doubleNull, double1)); + Assert.assertNull(minus(decimalNull, decimal1)); + Assert.assertNull(minus(longNull, int1)); + Assert.assertNull(minus(longNull, short1)); + Assert.assertNull(minus(longNull, byte1)); + Assert.assertNull(minus(longNull, double1)); + Assert.assertNull(minus(longNull, decimal1)); + Assert.assertNull(minus(intNull, long1)); + Assert.assertNull(minus(intNull, short1)); + Assert.assertNull(minus(intNull, byte1)); + Assert.assertNull(minus(intNull, double1)); + Assert.assertNull(minus(intNull, decimal1)); + Assert.assertNull(minus(shortNull, long1)); + Assert.assertNull(minus(shortNull, int1)); + Assert.assertNull(minus(shortNull, byte1)); + Assert.assertNull(minus(shortNull, double1)); + Assert.assertNull(minus(shortNull, decimal1)); + Assert.assertNull(minus(byteNull, long1)); + Assert.assertNull(minus(byteNull, int1)); + Assert.assertNull(minus(byteNull, short1)); + Assert.assertNull(minus(byteNull, double1)); + Assert.assertNull(minus(byteNull, decimal1)); + Assert.assertNull(minus(doubleNull, long1)); + Assert.assertNull(minus(doubleNull, int1)); + Assert.assertNull(minus(doubleNull, short1)); + Assert.assertNull(minus(doubleNull, byte1)); + Assert.assertNull(minus(doubleNull, decimal1)); + Assert.assertNull(minus(decimalNull, long1)); + Assert.assertNull(minus(decimalNull, int1)); + Assert.assertNull(minus(decimalNull, short1)); + Assert.assertNull(minus(decimalNull, byte1)); + Assert.assertNull(minus(decimalNull, double1)); + } + + @Test + public void testTimes() { + assertEquals((long) times(long1, long1), 1); + assertEquals((long) times(int1, int1), 1); + assertEquals((int) times(short1, short1), 1); + assertEquals((int) times(byte1, byte1), 1); + assertEquals(times(double1, double1), 1.0); + assertEquals(times(decimal1, decimal1), BigDecimal.valueOf(1)); + assertEquals((long) times(long1, int1), 1); + assertEquals((long) times(long1, short1), 1); + assertEquals((long) times(long1, byte1), 1); + assertEquals(times(long1, double1), 1.0); + assertEquals(times(long1, decimal1), BigDecimal.valueOf(1)); + assertEquals((long) times(int1, long1), 1); + assertEquals((int) times(int1, short1), 1); + assertEquals((int) times(int1, byte1), 1); + assertEquals(times(int1, double1), 1.0); + assertEquals(times(int1, decimal1), BigDecimal.valueOf(1)); + assertEquals((long) times(short1, long1), 1); + assertEquals((int) times(short1, int1), 1); + assertEquals((int) times(short1, byte1), 1); + assertEquals(times(short1, double1), 1.0); + assertEquals(times(short1, decimal1), BigDecimal.valueOf(1)); + assertEquals((long) times(byte1, long1), 1); + assertEquals((int) times(byte1, int1), 1); + assertEquals((int) times(byte1, short1), 1); + assertEquals(times(byte1, double1), 1.0); + assertEquals(times(byte1, decimal1), BigDecimal.valueOf(1)); + assertEquals(times(double1, long1), 1.0); + assertEquals(times(double1, int1), 1.0); + assertEquals(times(double1, short1), 1.0); + assertEquals(times(double1, byte1), 1.0); + assertEquals(times(double1, decimal1), BigDecimal.valueOf(1)); + assertEquals(times(decimal1, long1), BigDecimal.valueOf(1)); + assertEquals(times(decimal1, int1), BigDecimal.valueOf(1)); + assertEquals(times(decimal1, short1), BigDecimal.valueOf(1)); + assertEquals(times(decimal1, byte1), BigDecimal.valueOf(1)); + assertEquals(times(decimal1, double1), BigDecimal.valueOf(1)); + + Assert.assertNull(times(longNull, long1)); + Assert.assertNull(times(intNull, int1)); + Assert.assertNull(times(shortNull, short1)); + Assert.assertNull(times(byteNull, byte1)); + Assert.assertNull(times(doubleNull, double1)); + Assert.assertNull(times(decimalNull, decimal1)); + Assert.assertNull(times(longNull, int1)); + Assert.assertNull(times(longNull, short1)); + Assert.assertNull(times(longNull, byte1)); + Assert.assertNull(times(longNull, double1)); + Assert.assertNull(times(longNull, decimal1)); + Assert.assertNull(times(intNull, long1)); + Assert.assertNull(times(intNull, short1)); + Assert.assertNull(times(intNull, byte1)); + Assert.assertNull(times(intNull, double1)); + Assert.assertNull(times(intNull, decimal1)); + Assert.assertNull(times(shortNull, long1)); + Assert.assertNull(times(shortNull, int1)); + Assert.assertNull(times(shortNull, byte1)); + Assert.assertNull(times(shortNull, double1)); + Assert.assertNull(times(shortNull, decimal1)); + Assert.assertNull(times(byteNull, long1)); + Assert.assertNull(times(byteNull, int1)); + Assert.assertNull(times(byteNull, short1)); + Assert.assertNull(times(byteNull, double1)); + Assert.assertNull(times(byteNull, decimal1)); + Assert.assertNull(times(doubleNull, long1)); + Assert.assertNull(times(doubleNull, int1)); + Assert.assertNull(times(doubleNull, short1)); + Assert.assertNull(times(doubleNull, byte1)); + Assert.assertNull(times(doubleNull, decimal1)); + Assert.assertNull(times(decimalNull, long1)); + Assert.assertNull(times(decimalNull, int1)); + Assert.assertNull(times(decimalNull, short1)); + Assert.assertNull(times(decimalNull, byte1)); + Assert.assertNull(times(decimalNull, double1)); + } + + @Test + public void testDivide() { + assertEquals((long) divide(long1, long1), 1); + assertEquals((int) divide(int1, int1), 1); + assertEquals((int) divide(short1, short1), 1); + assertEquals((int) divide(byte1, byte1), 1); + assertEquals(divide(double1, double1), 1.0); + assertEquals(divide(decimal1, decimal1), BigDecimal.valueOf(1)); + assertEquals((long) divide(long1, int1), 1); + assertEquals((long) divide(long1, short1), 1); + assertEquals((long) divide(long1, byte1), 1); + assertEquals(divide(long1, double1), 1.0); + assertEquals(divide(long1, decimal1), BigDecimal.valueOf(1)); + assertEquals((long) divide(int1, long1), 1); + assertEquals((int) divide(int1, short1), 1); + assertEquals((int) divide(int1, byte1), 1); + assertEquals(divide(int1, double1), 1.0); + assertEquals(divide(int1, decimal1), BigDecimal.valueOf(1)); + assertEquals((long) divide(short1, long1), 1); + assertEquals((int) divide(short1, int1), 1); + assertEquals((int) divide(short1, byte1), 1); + assertEquals(divide(short1, double1), 1.0); + assertEquals(divide(short1, decimal1), BigDecimal.valueOf(1)); + assertEquals((long) divide(byte1, long1), 1); + assertEquals((int) divide(byte1, int1), 1); + assertEquals((int) divide(byte1, short1), 1); + assertEquals(divide(byte1, double1), 1.0); + assertEquals(divide(byte1, decimal1), BigDecimal.valueOf(1)); + assertEquals(divide(double1, long1), 1.0); + assertEquals(divide(double1, int1), 1.0); + assertEquals(divide(double1, short1), 1.0); + assertEquals(divide(double1, byte1), 1.0); + assertEquals(divide(double1, decimal1), BigDecimal.valueOf(1)); + assertEquals(divide(decimal1, long1), BigDecimal.valueOf(1)); + assertEquals(divide(decimal1, int1), BigDecimal.valueOf(1)); + assertEquals(divide(decimal1, short1), BigDecimal.valueOf(1)); + assertEquals(divide(decimal1, byte1), BigDecimal.valueOf(1)); + assertEquals(divide(decimal1, double1), BigDecimal.valueOf(1)); + + Assert.assertNull(divide(longNull, long1)); + Assert.assertNull(divide(intNull, int1)); + Assert.assertNull(divide(shortNull, short1)); + Assert.assertNull(divide(byteNull, byte1)); + Assert.assertNull(divide(doubleNull, double1)); + Assert.assertNull(divide(decimalNull, decimal1)); + Assert.assertNull(divide(longNull, int1)); + Assert.assertNull(divide(longNull, short1)); + Assert.assertNull(divide(longNull, byte1)); + Assert.assertNull(divide(longNull, double1)); + Assert.assertNull(divide(longNull, decimal1)); + Assert.assertNull(divide(intNull, long1)); + Assert.assertNull(divide(intNull, short1)); + Assert.assertNull(divide(intNull, byte1)); + Assert.assertNull(divide(intNull, double1)); + Assert.assertNull(divide(intNull, decimal1)); + Assert.assertNull(divide(shortNull, long1)); + Assert.assertNull(divide(shortNull, int1)); + Assert.assertNull(divide(shortNull, byte1)); + Assert.assertNull(divide(shortNull, double1)); + Assert.assertNull(divide(shortNull, decimal1)); + Assert.assertNull(divide(byteNull, long1)); + Assert.assertNull(divide(byteNull, int1)); + Assert.assertNull(divide(byteNull, short1)); + Assert.assertNull(divide(byteNull, double1)); + Assert.assertNull(divide(byteNull, decimal1)); + Assert.assertNull(divide(doubleNull, long1)); + Assert.assertNull(divide(doubleNull, int1)); + Assert.assertNull(divide(doubleNull, short1)); + Assert.assertNull(divide(doubleNull, byte1)); + Assert.assertNull(divide(doubleNull, decimal1)); + Assert.assertNull(divide(decimalNull, long1)); + Assert.assertNull(divide(decimalNull, int1)); + Assert.assertNull(divide(decimalNull, short1)); + Assert.assertNull(divide(decimalNull, byte1)); + Assert.assertNull(divide(decimalNull, double1)); + } + + @Test + public void testMod() { + assertEquals((long) mod(long1, long1), 0); + assertEquals((int) mod(int1, int1), 0); + assertEquals(mod(double1, double1), 0.0); + + Assert.assertNull(mod(longNull, long1)); + Assert.assertNull(mod(intNull, int1)); + Assert.assertNull(mod(doubleNull, double1)); + } + + @Test + public void testPower() { + assertEquals(power(double1, double1), 1.0); + + Assert.assertNull(power(doubleNull, double1)); + } + + @Test + public void testAbs() { + assertEquals((long) abs(long1), 1); + assertEquals((int) abs(int1), 1); + assertEquals((int) abs(short1), 1); + assertEquals((int) abs(byte1), 1); + assertEquals(abs(decimal1), BigDecimal.valueOf(1)); + assertEquals(abs(double1), 1.0); + + Assert.assertNull(abs(longNull)); + Assert.assertNull(abs(intNull)); + Assert.assertNull(abs(shortNull)); + Assert.assertNull(abs(byteNull)); + Assert.assertNull(abs(decimalNull)); + Assert.assertNull(abs(doubleNull)); + } + + @Test + public void testTrigonometric() { + assertEquals(asin(double1), 1.5707963267948966); + assertEquals(acos(double1), 0.0); + assertEquals(atan(double1), 0.7853981633974483); + assertEquals(ceil(double1), 1.0); + assertEquals(ceil(long1), long1); + assertEquals(ceil(int1), int1); + assertEquals(cot(double1), 0.6420926159343306); + assertEquals(cos(double1), 0.5403023058681398); + assertEquals(degrees(double1), 57.29577951308232); + assertEquals(radians(double1), 0.017453292519943295); + assertEquals(sign(double1), 1.0); + assertEquals(sin(double1), 0.8414709848078965); + assertEquals(tan(double1), 1.5574077246549023); + + Assert.assertNull(asin(doubleNull)); + Assert.assertNull(acos(doubleNull)); + Assert.assertNull(atan(doubleNull)); + Assert.assertNull(ceil(doubleNull)); + Assert.assertNull(ceil(longNull)); + Assert.assertNull(ceil(intNull)); + Assert.assertNull(cot(doubleNull)); + Assert.assertNull(cos(doubleNull)); + Assert.assertNull(degrees(doubleNull)); + Assert.assertNull(radians(doubleNull)); + Assert.assertNull(sign(doubleNull)); + Assert.assertNull(sin(doubleNull)); + Assert.assertNull(tan(doubleNull)); + } + + @Test + public void testMath() { + assertEquals(exp(double1), 2.718281828459045, 1e-15); + assertEquals(floor(double1), 1.0); + assertEquals((long) floor(long1), 1); + assertEquals((int) floor(int1), 1); + assertEquals(ln(double1), 0.0); + assertEquals(log10(double1), 0.0); + assertEquals(minusPrefix(double1), -1.0); + assertEquals((long) minusPrefix(long1), -1); + assertEquals((int) minusPrefix(int1), -1); + assertEquals((int) minusPrefix(short1), -1); + assertEquals((int) minusPrefix(byte1), -1); + assertEquals(minusPrefix(decimal1), BigDecimal.valueOf(-1)); + + System.out.println(rand()); + System.out.println(rand(long1)); + System.out.println(rand(long1, int1)); + System.out.println(randInt(int1)); + System.out.println(randInt(long1, int1)); + + Assert.assertNull(exp(doubleNull)); + Assert.assertNull(floor(doubleNull)); + Assert.assertNull(floor(longNull)); + Assert.assertNull(floor(intNull)); + Assert.assertNull(ln(doubleNull)); + Assert.assertNull(log10(doubleNull)); + Assert.assertNull(minusPrefix(doubleNull)); + Assert.assertNull(minusPrefix(longNull)); + Assert.assertNull(minusPrefix(intNull)); + Assert.assertNull(minusPrefix(shortNull)); + Assert.assertNull(minusPrefix(byteNull)); + Assert.assertNull(minusPrefix(decimalNull)); + } + + @Test + public void testEqual() { + String string1 = "1"; + String stringNull = null; + Boolean boolTrue = true; + Boolean boolNull = null; + + Assert.assertTrue(equal(long1, long1)); + Assert.assertTrue(equal(double1, double1)); + Assert.assertTrue(equal(decimal1, decimal1)); + Assert.assertTrue(equal(string1, string1)); + Assert.assertTrue(equal(boolTrue, boolTrue)); + Assert.assertTrue(equal(string1, int1)); + Assert.assertTrue(equal(int1, string1)); + Assert.assertTrue(equal(string1, double1)); + Assert.assertTrue(equal(double1, string1)); + Assert.assertTrue(equal(string1, long1)); + Assert.assertTrue(equal(long1, string1)); + Assert.assertTrue(equal(string1, false)); + Assert.assertTrue(equal(false, string1)); + Assert.assertTrue(equal((Object) boolTrue, boolTrue)); + + Assert.assertNull(equal(longNull, long1)); + Assert.assertNull(equal(doubleNull, double1)); + Assert.assertNull(equal(decimalNull, decimal1)); + Assert.assertNull(equal(stringNull, string1)); + Assert.assertNull(equal(boolNull, boolTrue)); + Assert.assertNull(equal(stringNull, int1)); + Assert.assertNull(equal(intNull, string1)); + Assert.assertNull(equal(stringNull, double1)); + Assert.assertNull(equal(doubleNull, string1)); + Assert.assertNull(equal(stringNull, long1)); + Assert.assertNull(equal(longNull, string1)); + Assert.assertNull(equal(stringNull, boolTrue)); + Assert.assertNull(equal(boolNull, string1)); + Assert.assertNull(equal(stringNull, boolTrue)); + Assert.assertNull(equal((Object) boolNull, boolTrue)); + } + + @Test + public void testUnequal() { + String string1 = "1"; + String stringNull = null; + Boolean boolTrue = true; + Boolean boolNull = null; + + Assert.assertFalse(unequal(long1, long1)); + Assert.assertFalse(unequal(double1, double1)); + Assert.assertFalse(unequal(decimal1, decimal1)); + Assert.assertFalse(unequal(string1, string1)); + Assert.assertFalse(unequal(boolTrue, boolTrue)); + Assert.assertFalse(unequal(string1, int1)); + Assert.assertFalse(unequal(int1, string1)); + Assert.assertFalse(unequal(string1, double1)); + Assert.assertFalse(unequal(double1, string1)); + Assert.assertFalse(unequal(string1, long1)); + Assert.assertFalse(unequal(long1, string1)); + Assert.assertFalse(unequal(string1, false)); + Assert.assertFalse(unequal(false, string1)); + Assert.assertFalse(unequal((Object) boolTrue, boolTrue)); + + Assert.assertNull(unequal(longNull, long1)); + Assert.assertNull(unequal(doubleNull, double1)); + Assert.assertNull(unequal(decimalNull, decimal1)); + Assert.assertNull(unequal(stringNull, string1)); + Assert.assertNull(unequal(boolNull, boolTrue)); + Assert.assertNull(unequal(stringNull, int1)); + Assert.assertNull(unequal(intNull, string1)); + Assert.assertNull(unequal(stringNull, double1)); + Assert.assertNull(unequal(doubleNull, string1)); + Assert.assertNull(unequal(stringNull, long1)); + Assert.assertNull(unequal(longNull, string1)); + Assert.assertNull(unequal(stringNull, boolTrue)); + Assert.assertNull(unequal(boolNull, string1)); + Assert.assertNull(unequal(stringNull, boolTrue)); + Assert.assertFalse(unequal((Object) boolTrue, boolTrue)); + } + + @Test + public void testCompare() { + String string1 = "1"; + String stringNull = null; + + Assert.assertFalse(lessThan(long1, long1)); + Assert.assertFalse(lessThan(double1, double1)); + Assert.assertFalse(lessThan(decimal1, decimal1)); + Assert.assertFalse(lessThan(string1, string1)); + + Assert.assertTrue(greaterThanEq(long1, long1)); + Assert.assertTrue(greaterThanEq(double1, double1)); + Assert.assertTrue(greaterThanEq(decimal1, decimal1)); + Assert.assertTrue(greaterThanEq(string1, string1)); + + Assert.assertTrue(lessThanEq(long1, long1)); + Assert.assertTrue(lessThanEq(double1, double1)); + Assert.assertTrue(lessThanEq(decimal1, decimal1)); + Assert.assertTrue(lessThanEq(string1, string1)); + + Assert.assertFalse(greaterThan(long1, long1)); + Assert.assertFalse(greaterThan(double1, double1)); + Assert.assertFalse(greaterThan(decimal1, decimal1)); + Assert.assertFalse(greaterThan(string1, string1)); + + Assert.assertNull(lessThan(longNull, long1)); + Assert.assertNull(lessThan(doubleNull, double1)); + Assert.assertNull(lessThan(decimalNull, decimal1)); + Assert.assertNull(lessThan(stringNull, string1)); + + Assert.assertNull(greaterThanEq(longNull, long1)); + Assert.assertNull(greaterThanEq(doubleNull, double1)); + Assert.assertNull(greaterThanEq(decimalNull, decimal1)); + Assert.assertNull(greaterThanEq(stringNull, string1)); + + Assert.assertNull(lessThanEq(longNull, long1)); + Assert.assertNull(lessThanEq(doubleNull, double1)); + Assert.assertNull(lessThanEq(decimalNull, decimal1)); + Assert.assertNull(lessThanEq(stringNull, string1)); + + Assert.assertNull(greaterThan(longNull, long1)); + Assert.assertNull(greaterThan(doubleNull, double1)); + Assert.assertNull(greaterThan(decimalNull, decimal1)); + Assert.assertNull(greaterThan(stringNull, string1)); + } + + @Test + public void testTimestampUtil() { + Timestamp ts = Timestamp.valueOf("1987-06-05 04:03:02"); + Assert.assertEquals(timestampCeil(ts, 1000L), Timestamp.valueOf("1987-06-05 04:03:03")); + Assert.assertEquals(timestampCeil(ts, 60000L), Timestamp.valueOf("1987-06-05 04:04:00.0")); + Assert.assertEquals(timestampCeil(ts, 3600000L), Timestamp.valueOf("1987-06-05 05:00:00.0")); + Assert.assertEquals(timestampCeil(ts, 86400000L), Timestamp.valueOf("1987-06-06 00:00:00.0")); + Assert.assertEquals(timestampTumble(ts, 1000L), Timestamp.valueOf("1987-06-05 04:03:03")); + Assert.assertEquals(timestampTumble(ts, 60000L), Timestamp.valueOf("1987-06-05 04:04:00.0")); + Assert.assertEquals(timestampTumble(ts, 3600000L), Timestamp.valueOf("1987-06-05 05:00:00.0")); + Assert.assertEquals(timestampFloor(ts, 1000L), Timestamp.valueOf("1987-06-05 04:03:02.0")); + Assert.assertEquals(timestampFloor(ts, 60000L), Timestamp.valueOf("1987-06-05 04:03:00.0")); + Assert.assertEquals(timestampFloor(ts, 3600000L), Timestamp.valueOf("1987-06-05 04:00:00.0")); + Assert.assertEquals(timestampFloor(ts, 86400000L), Timestamp.valueOf("1987-06-05 00:00:00.0")); + + Assert.assertEquals(plus(ts, 1L), Timestamp.valueOf("1987-06-05 04:03:02.001")); + Assert.assertEquals(minus(ts, 1L), Timestamp.valueOf("1987-06-05 04:03:01.999")); + Assert.assertEquals((long) minus(ts, Timestamp.valueOf("1987-06-05 00:00:00.0")), 14582000L); + + Assert.assertNull(plus((Timestamp) null, 1L)); + Assert.assertNull(minus((Timestamp) null, 1L)); + Assert.assertNull(minus((Timestamp) null, Timestamp.valueOf("1987-06-05 00:00:00.0"))); + } + + @Test + public void testOtherFunction() { + Assert.assertEquals(round(double1, 2), 1.0); + + Assert.assertNull(round(doubleNull, 2)); + } + + @Test + public void testGeaFlowUserDefinedTableFunction() { + GQLJavaTypeFactory typeFactory = GQLJavaTypeFactory.create(); + GeaFlowUserDefinedTableFunction.create( + "testFunction", UtUserDefinedTableFunction.class, typeFactory); + } + + public static class UtUserDefinedTableFunction extends UDTF { + + public UtUserDefinedTableFunction() {} + + @Override + public List> getReturnType(List> paramTypes, List outFieldNames) { + return Lists.newArrayList(String.class, Long.class); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/schema/function/SameTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/schema/function/SameTest.java index 52065e160..84480927e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/schema/function/SameTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/schema/function/SameTest.java @@ -25,231 +25,233 @@ import org.testng.Assert; import org.testng.annotations.Test; -/** - * Unit tests for the ISO-GQL SAME predicate function. - */ +/** Unit tests for the ISO-GQL SAME predicate function. */ public class SameTest { - @Test - public void testSameWithIdenticalVertices() { - // Create two vertices with the same ID - ObjectVertex v1 = new ObjectVertex(1, null, ObjectRow.create("Alice", 25)); - ObjectVertex v2 = new ObjectVertex(1, null, ObjectRow.create("Bob", 30)); - - Boolean result = GeaFlowBuiltinFunctions.same(v1, v2); - Assert.assertTrue(result, "Vertices with same ID should return true"); - } - - @Test - public void testSameWithDifferentVertices() { - // Create two vertices with different IDs - ObjectVertex v1 = new ObjectVertex(1, null, ObjectRow.create("Alice", 25)); - ObjectVertex v2 = new ObjectVertex(2, null, ObjectRow.create("Bob", 30)); - - Boolean result = GeaFlowBuiltinFunctions.same(v1, v2); - Assert.assertFalse(result, "Vertices with different IDs should return false"); - } - - @Test - public void testSameWithIdenticalEdges() { - // Create two edges with the same source and target IDs - ObjectEdge e1 = new ObjectEdge(1, 2, ObjectRow.create("knows")); - ObjectEdge e2 = new ObjectEdge(1, 2, ObjectRow.create("likes")); - - Boolean result = GeaFlowBuiltinFunctions.same(e1, e2); - Assert.assertTrue(result, "Edges with same source and target IDs should return true"); - } - - @Test - public void testSameWithDifferentEdgesSameSource() { - // Create two edges with the same source but different target IDs - ObjectEdge e1 = new ObjectEdge(1, 2, ObjectRow.create("knows")); - ObjectEdge e2 = new ObjectEdge(1, 3, ObjectRow.create("knows")); - - Boolean result = GeaFlowBuiltinFunctions.same(e1, e2); - Assert.assertFalse(result, "Edges with different target IDs should return false"); - } - - @Test - public void testSameWithDifferentEdgesSameTarget() { - // Create two edges with different source but same target IDs - ObjectEdge e1 = new ObjectEdge(1, 2, ObjectRow.create("knows")); - ObjectEdge e2 = new ObjectEdge(3, 2, ObjectRow.create("knows")); - - Boolean result = GeaFlowBuiltinFunctions.same(e1, e2); - Assert.assertFalse(result, "Edges with different source IDs should return false"); - } - - @Test - public void testSameWithDifferentEdges() { - // Create two edges with completely different IDs - ObjectEdge e1 = new ObjectEdge(1, 2, ObjectRow.create("knows")); - ObjectEdge e2 = new ObjectEdge(3, 4, ObjectRow.create("knows")); - - Boolean result = GeaFlowBuiltinFunctions.same(e1, e2); - Assert.assertFalse(result, "Edges with different IDs should return false"); - } - - @Test - public void testSameWithMixedTypes() { - // Test vertex and edge - should return false - ObjectVertex v = new ObjectVertex(1, null, ObjectRow.create("Alice", 25)); - ObjectEdge e = new ObjectEdge(1, 2, ObjectRow.create("knows")); - - Boolean result = GeaFlowBuiltinFunctions.same(v, e); - Assert.assertFalse(result, "Vertex and edge should return false"); - } - - @Test - public void testSameWithNullFirst() { - // Test with first argument null - ObjectVertex v = new ObjectVertex(1, null, ObjectRow.create("Alice", 25)); - - Boolean result = GeaFlowBuiltinFunctions.same(null, v); - Assert.assertNull(result, "Null first argument should return null"); - } - - @Test - public void testSameWithNullSecond() { - // Test with second argument null - ObjectVertex v = new ObjectVertex(1, null, ObjectRow.create("Alice", 25)); - - Boolean result = GeaFlowBuiltinFunctions.same(v, null); - Assert.assertNull(result, "Null second argument should return null"); - } - - @Test - public void testSameWithBothNull() { - // Test with both arguments null - use explicit cast to Object to resolve ambiguity - Boolean result = GeaFlowBuiltinFunctions.same((Object) null, (Object) null); - Assert.assertNull(result, "Both null arguments should return null"); - } - - @Test - public void testSameWithStringIds() { - // Test with string IDs instead of integer IDs - ObjectVertex v1 = new ObjectVertex("user123", null, ObjectRow.create("Alice", 25)); - ObjectVertex v2 = new ObjectVertex("user123", null, ObjectRow.create("Bob", 30)); - - Boolean result = GeaFlowBuiltinFunctions.same(v1, v2); - Assert.assertTrue(result, "Vertices with same string ID should return true"); - } - - @Test - public void testSameWithDifferentStringIds() { - // Test with different string IDs - ObjectVertex v1 = new ObjectVertex("user123", null, ObjectRow.create("Alice", 25)); - ObjectVertex v2 = new ObjectVertex("user456", null, ObjectRow.create("Bob", 30)); - - Boolean result = GeaFlowBuiltinFunctions.same(v1, v2); - Assert.assertFalse(result, "Vertices with different string IDs should return false"); - } - - @Test - public void testSameWithInvalidTypes() { - // Test with objects that are not RowVertex or RowEdge - String s1 = "test"; - String s2 = "test"; - - Boolean result = GeaFlowBuiltinFunctions.same(s1, s2); - Assert.assertFalse(result, "Non-graph elements should return false"); - } - - // Tests for type-specific overloads (RowVertex, RowEdge) - - @Test - public void testSameVertexOverloadWithIdenticalIds() { - // Test the type-specific RowVertex overload - ObjectVertex v1 = new ObjectVertex(100, null, ObjectRow.create("Alice", 25)); - ObjectVertex v2 = new ObjectVertex(100, null, ObjectRow.create("Bob", 30)); - - // Explicitly call with RowVertex types - Boolean result = GeaFlowBuiltinFunctions.same((org.apache.geaflow.dsl.common.data.RowVertex) v1, + @Test + public void testSameWithIdenticalVertices() { + // Create two vertices with the same ID + ObjectVertex v1 = new ObjectVertex(1, null, ObjectRow.create("Alice", 25)); + ObjectVertex v2 = new ObjectVertex(1, null, ObjectRow.create("Bob", 30)); + + Boolean result = GeaFlowBuiltinFunctions.same(v1, v2); + Assert.assertTrue(result, "Vertices with same ID should return true"); + } + + @Test + public void testSameWithDifferentVertices() { + // Create two vertices with different IDs + ObjectVertex v1 = new ObjectVertex(1, null, ObjectRow.create("Alice", 25)); + ObjectVertex v2 = new ObjectVertex(2, null, ObjectRow.create("Bob", 30)); + + Boolean result = GeaFlowBuiltinFunctions.same(v1, v2); + Assert.assertFalse(result, "Vertices with different IDs should return false"); + } + + @Test + public void testSameWithIdenticalEdges() { + // Create two edges with the same source and target IDs + ObjectEdge e1 = new ObjectEdge(1, 2, ObjectRow.create("knows")); + ObjectEdge e2 = new ObjectEdge(1, 2, ObjectRow.create("likes")); + + Boolean result = GeaFlowBuiltinFunctions.same(e1, e2); + Assert.assertTrue(result, "Edges with same source and target IDs should return true"); + } + + @Test + public void testSameWithDifferentEdgesSameSource() { + // Create two edges with the same source but different target IDs + ObjectEdge e1 = new ObjectEdge(1, 2, ObjectRow.create("knows")); + ObjectEdge e2 = new ObjectEdge(1, 3, ObjectRow.create("knows")); + + Boolean result = GeaFlowBuiltinFunctions.same(e1, e2); + Assert.assertFalse(result, "Edges with different target IDs should return false"); + } + + @Test + public void testSameWithDifferentEdgesSameTarget() { + // Create two edges with different source but same target IDs + ObjectEdge e1 = new ObjectEdge(1, 2, ObjectRow.create("knows")); + ObjectEdge e2 = new ObjectEdge(3, 2, ObjectRow.create("knows")); + + Boolean result = GeaFlowBuiltinFunctions.same(e1, e2); + Assert.assertFalse(result, "Edges with different source IDs should return false"); + } + + @Test + public void testSameWithDifferentEdges() { + // Create two edges with completely different IDs + ObjectEdge e1 = new ObjectEdge(1, 2, ObjectRow.create("knows")); + ObjectEdge e2 = new ObjectEdge(3, 4, ObjectRow.create("knows")); + + Boolean result = GeaFlowBuiltinFunctions.same(e1, e2); + Assert.assertFalse(result, "Edges with different IDs should return false"); + } + + @Test + public void testSameWithMixedTypes() { + // Test vertex and edge - should return false + ObjectVertex v = new ObjectVertex(1, null, ObjectRow.create("Alice", 25)); + ObjectEdge e = new ObjectEdge(1, 2, ObjectRow.create("knows")); + + Boolean result = GeaFlowBuiltinFunctions.same(v, e); + Assert.assertFalse(result, "Vertex and edge should return false"); + } + + @Test + public void testSameWithNullFirst() { + // Test with first argument null + ObjectVertex v = new ObjectVertex(1, null, ObjectRow.create("Alice", 25)); + + Boolean result = GeaFlowBuiltinFunctions.same(null, v); + Assert.assertNull(result, "Null first argument should return null"); + } + + @Test + public void testSameWithNullSecond() { + // Test with second argument null + ObjectVertex v = new ObjectVertex(1, null, ObjectRow.create("Alice", 25)); + + Boolean result = GeaFlowBuiltinFunctions.same(v, null); + Assert.assertNull(result, "Null second argument should return null"); + } + + @Test + public void testSameWithBothNull() { + // Test with both arguments null - use explicit cast to Object to resolve ambiguity + Boolean result = GeaFlowBuiltinFunctions.same((Object) null, (Object) null); + Assert.assertNull(result, "Both null arguments should return null"); + } + + @Test + public void testSameWithStringIds() { + // Test with string IDs instead of integer IDs + ObjectVertex v1 = new ObjectVertex("user123", null, ObjectRow.create("Alice", 25)); + ObjectVertex v2 = new ObjectVertex("user123", null, ObjectRow.create("Bob", 30)); + + Boolean result = GeaFlowBuiltinFunctions.same(v1, v2); + Assert.assertTrue(result, "Vertices with same string ID should return true"); + } + + @Test + public void testSameWithDifferentStringIds() { + // Test with different string IDs + ObjectVertex v1 = new ObjectVertex("user123", null, ObjectRow.create("Alice", 25)); + ObjectVertex v2 = new ObjectVertex("user456", null, ObjectRow.create("Bob", 30)); + + Boolean result = GeaFlowBuiltinFunctions.same(v1, v2); + Assert.assertFalse(result, "Vertices with different string IDs should return false"); + } + + @Test + public void testSameWithInvalidTypes() { + // Test with objects that are not RowVertex or RowEdge + String s1 = "test"; + String s2 = "test"; + + Boolean result = GeaFlowBuiltinFunctions.same(s1, s2); + Assert.assertFalse(result, "Non-graph elements should return false"); + } + + // Tests for type-specific overloads (RowVertex, RowEdge) + + @Test + public void testSameVertexOverloadWithIdenticalIds() { + // Test the type-specific RowVertex overload + ObjectVertex v1 = new ObjectVertex(100, null, ObjectRow.create("Alice", 25)); + ObjectVertex v2 = new ObjectVertex(100, null, ObjectRow.create("Bob", 30)); + + // Explicitly call with RowVertex types + Boolean result = + GeaFlowBuiltinFunctions.same( + (org.apache.geaflow.dsl.common.data.RowVertex) v1, (org.apache.geaflow.dsl.common.data.RowVertex) v2); - Assert.assertTrue(result, "Type-specific vertex overload should work"); - } - - @Test - public void testSameEdgeOverloadWithIdenticalIds() { - // Test the type-specific RowEdge overload - ObjectEdge e1 = new ObjectEdge(10, 20, ObjectRow.create("knows")); - ObjectEdge e2 = new ObjectEdge(10, 20, ObjectRow.create("likes")); - - // Explicitly call with RowEdge types - Boolean result = GeaFlowBuiltinFunctions.same((org.apache.geaflow.dsl.common.data.RowEdge) e1, + Assert.assertTrue(result, "Type-specific vertex overload should work"); + } + + @Test + public void testSameEdgeOverloadWithIdenticalIds() { + // Test the type-specific RowEdge overload + ObjectEdge e1 = new ObjectEdge(10, 20, ObjectRow.create("knows")); + ObjectEdge e2 = new ObjectEdge(10, 20, ObjectRow.create("likes")); + + // Explicitly call with RowEdge types + Boolean result = + GeaFlowBuiltinFunctions.same( + (org.apache.geaflow.dsl.common.data.RowEdge) e1, (org.apache.geaflow.dsl.common.data.RowEdge) e2); - Assert.assertTrue(result, "Type-specific edge overload should work"); - } - - // Tests for multi-argument same() varargs method - - @Test - public void testSameWithThreeIdenticalVertices() { - // Test varargs with 3 identical vertices - ObjectVertex v1 = new ObjectVertex(1, null, ObjectRow.create("Alice", 25)); - ObjectVertex v2 = new ObjectVertex(1, null, ObjectRow.create("Bob", 30)); - ObjectVertex v3 = new ObjectVertex(1, null, ObjectRow.create("Charlie", 35)); - - Boolean result = GeaFlowBuiltinFunctions.same(v1, v2, v3); - Assert.assertTrue(result, "Three vertices with same ID should return true"); - } - - @Test - public void testSameWithThreeVerticesOneDifferent() { - // Test varargs with one different vertex - ObjectVertex v1 = new ObjectVertex(1, null, ObjectRow.create("Alice", 25)); - ObjectVertex v2 = new ObjectVertex(1, null, ObjectRow.create("Bob", 30)); - ObjectVertex v3 = new ObjectVertex(2, null, ObjectRow.create("Charlie", 35)); - - Boolean result = GeaFlowBuiltinFunctions.same(v1, v2, v3); - Assert.assertFalse(result, "Three vertices with one different ID should return false"); - } - - @Test - public void testSameWithFourIdenticalEdges() { - // Test varargs with 4 identical edges - ObjectEdge e1 = new ObjectEdge(1, 2, ObjectRow.create("knows")); - ObjectEdge e2 = new ObjectEdge(1, 2, ObjectRow.create("likes")); - ObjectEdge e3 = new ObjectEdge(1, 2, ObjectRow.create("follows")); - ObjectEdge e4 = new ObjectEdge(1, 2, ObjectRow.create("trusts")); - - Boolean result = GeaFlowBuiltinFunctions.same(e1, e2, e3, e4); - Assert.assertTrue(result, "Four edges with same source and target IDs should return true"); - } - - @Test - public void testSameWithMultipleNullInMiddle() { - // Test varargs with null in the middle - ObjectVertex v1 = new ObjectVertex(1, null, ObjectRow.create("Alice", 25)); - ObjectVertex v3 = new ObjectVertex(1, null, ObjectRow.create("Charlie", 35)); - - Boolean result = GeaFlowBuiltinFunctions.same(v1, null, v3); - Assert.assertNull(result, "Varargs with null element should return null"); - } - - @Test - public void testSameWithEmptyVarargs() { - // Test varargs with no arguments (should return null) - Boolean result = GeaFlowBuiltinFunctions.same(new Object[0]); - Assert.assertNull(result, "Empty varargs should return null"); - } - - @Test - public void testSameWithSingleVararg() { - // Test varargs with single argument (should return null - need at least 2) - ObjectVertex v1 = new ObjectVertex(1, null, ObjectRow.create("Alice", 25)); - - Boolean result = GeaFlowBuiltinFunctions.same(new Object[]{v1}); - Assert.assertNull(result, "Single vararg should return null"); - } - - @Test - public void testSameWithMixedTypesInVarargs() { - // Test varargs with mixed vertex and edge types - ObjectVertex v = new ObjectVertex(1, null, ObjectRow.create("Alice", 25)); - ObjectEdge e = new ObjectEdge(1, 2, ObjectRow.create("knows")); - - Boolean result = GeaFlowBuiltinFunctions.same(v, e); - Assert.assertFalse(result, "Mixed vertex and edge in varargs should return false"); - } + Assert.assertTrue(result, "Type-specific edge overload should work"); + } + + // Tests for multi-argument same() varargs method + + @Test + public void testSameWithThreeIdenticalVertices() { + // Test varargs with 3 identical vertices + ObjectVertex v1 = new ObjectVertex(1, null, ObjectRow.create("Alice", 25)); + ObjectVertex v2 = new ObjectVertex(1, null, ObjectRow.create("Bob", 30)); + ObjectVertex v3 = new ObjectVertex(1, null, ObjectRow.create("Charlie", 35)); + + Boolean result = GeaFlowBuiltinFunctions.same(v1, v2, v3); + Assert.assertTrue(result, "Three vertices with same ID should return true"); + } + + @Test + public void testSameWithThreeVerticesOneDifferent() { + // Test varargs with one different vertex + ObjectVertex v1 = new ObjectVertex(1, null, ObjectRow.create("Alice", 25)); + ObjectVertex v2 = new ObjectVertex(1, null, ObjectRow.create("Bob", 30)); + ObjectVertex v3 = new ObjectVertex(2, null, ObjectRow.create("Charlie", 35)); + + Boolean result = GeaFlowBuiltinFunctions.same(v1, v2, v3); + Assert.assertFalse(result, "Three vertices with one different ID should return false"); + } + + @Test + public void testSameWithFourIdenticalEdges() { + // Test varargs with 4 identical edges + ObjectEdge e1 = new ObjectEdge(1, 2, ObjectRow.create("knows")); + ObjectEdge e2 = new ObjectEdge(1, 2, ObjectRow.create("likes")); + ObjectEdge e3 = new ObjectEdge(1, 2, ObjectRow.create("follows")); + ObjectEdge e4 = new ObjectEdge(1, 2, ObjectRow.create("trusts")); + + Boolean result = GeaFlowBuiltinFunctions.same(e1, e2, e3, e4); + Assert.assertTrue(result, "Four edges with same source and target IDs should return true"); + } + + @Test + public void testSameWithMultipleNullInMiddle() { + // Test varargs with null in the middle + ObjectVertex v1 = new ObjectVertex(1, null, ObjectRow.create("Alice", 25)); + ObjectVertex v3 = new ObjectVertex(1, null, ObjectRow.create("Charlie", 35)); + + Boolean result = GeaFlowBuiltinFunctions.same(v1, null, v3); + Assert.assertNull(result, "Varargs with null element should return null"); + } + + @Test + public void testSameWithEmptyVarargs() { + // Test varargs with no arguments (should return null) + Boolean result = GeaFlowBuiltinFunctions.same(new Object[0]); + Assert.assertNull(result, "Empty varargs should return null"); + } + + @Test + public void testSameWithSingleVararg() { + // Test varargs with single argument (should return null - need at least 2) + ObjectVertex v1 = new ObjectVertex(1, null, ObjectRow.create("Alice", 25)); + + Boolean result = GeaFlowBuiltinFunctions.same(new Object[] {v1}); + Assert.assertNull(result, "Single vararg should return null"); + } + + @Test + public void testSameWithMixedTypesInVarargs() { + // Test varargs with mixed vertex and edge types + ObjectVertex v = new ObjectVertex(1, null, ObjectRow.create("Alice", 25)); + ObjectEdge e = new ObjectEdge(1, 2, ObjectRow.create("knows")); + + Boolean result = GeaFlowBuiltinFunctions.same(v, e); + Assert.assertFalse(result, "Mixed vertex and edge in varargs should return false"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/stats/StatsParseJsonTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/stats/StatsParseJsonTest.java index 530eb0ff4..48a6ff2e6 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/stats/StatsParseJsonTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/stats/StatsParseJsonTest.java @@ -19,7 +19,6 @@ package org.apache.geaflow.dsl.stats; -import com.google.gson.Gson; import org.apache.geaflow.dsl.common.descriptor.EdgeDescriptor; import org.apache.geaflow.dsl.common.descriptor.GraphDescriptor; import org.apache.geaflow.dsl.common.descriptor.NodeDescriptor; @@ -27,16 +26,19 @@ import org.testng.Assert; import org.testng.annotations.Test; -public class StatsParseJsonTest { +import com.google.gson.Gson; - @Test - public void testParseUserStats() { - GraphDescriptor userStats = new GraphDescriptor(); - userStats.addNode(new NodeDescriptor("n1", "Person")); - userStats.addEdge(new EdgeDescriptor("e1", "knows", "Person", "Person")); - userStats.addRelation(new RelationDescriptor("Person", "knows", "one-to-one")); - Gson gson = new Gson(); - Assert.assertEquals(gson.toJson(gson.fromJson(gson.toJson(userStats), GraphDescriptor.class)), gson.toJson(userStats)); +public class StatsParseJsonTest { - } + @Test + public void testParseUserStats() { + GraphDescriptor userStats = new GraphDescriptor(); + userStats.addNode(new NodeDescriptor("n1", "Person")); + userStats.addEdge(new EdgeDescriptor("e1", "knows", "Person", "Person")); + userStats.addRelation(new RelationDescriptor("Person", "knows", "one-to-one")); + Gson gson = new Gson(); + Assert.assertEquals( + gson.toJson(gson.fromJson(gson.toJson(userStats), GraphDescriptor.class)), + gson.toJson(userStats)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/type/GraphRecordTypeTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/type/GraphRecordTypeTest.java index 4a9dcd935..4278000ad 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/type/GraphRecordTypeTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/type/GraphRecordTypeTest.java @@ -21,7 +21,6 @@ import static org.testng.Assert.assertEquals; -import com.google.common.collect.Lists; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rel.type.RelDataTypeFieldImpl; @@ -35,50 +34,45 @@ import org.apache.geaflow.dsl.planner.GQLRelDataTypeSystem; import org.testng.annotations.Test; +import com.google.common.collect.Lists; + public class GraphRecordTypeTest { - @Test - public void testGraphRecordType() { - SqlTypeName stringType = SqlTypeName.VARCHAR; - SqlTypeName longType = SqlTypeName.BIGINT; - SqlTypeName doubleType = SqlTypeName.DOUBLE; - RelDataTypeSystem typeSystem = new GQLRelDataTypeSystem(); - GQLJavaTypeFactory typeFactory = GQLJavaTypeFactory.create(); + @Test + public void testGraphRecordType() { + SqlTypeName stringType = SqlTypeName.VARCHAR; + SqlTypeName longType = SqlTypeName.BIGINT; + SqlTypeName doubleType = SqlTypeName.DOUBLE; + RelDataTypeSystem typeSystem = new GQLRelDataTypeSystem(); + GQLJavaTypeFactory typeFactory = GQLJavaTypeFactory.create(); - RelDataType relStringType = new BasicSqlType(typeSystem, stringType); - RelDataType relLongType = new BasicSqlType(typeSystem, longType); - RelDataType relDoubleType = new BasicSqlType(typeSystem, doubleType); + RelDataType relStringType = new BasicSqlType(typeSystem, stringType); + RelDataType relLongType = new BasicSqlType(typeSystem, longType); + RelDataType relDoubleType = new BasicSqlType(typeSystem, doubleType); - RelDataTypeField field1 = new RelDataTypeFieldImpl("name", 0, relStringType); - RelDataTypeField field2 = new RelDataTypeFieldImpl("id", 1, relLongType); - RelDataTypeField field3 = new RelDataTypeFieldImpl("age", 2, relDoubleType); - RelDataType vertexType = VertexRecordType.createVertexType( - Lists.newArrayList(field1, field2, field3), - "id", typeFactory - ); - RelDataTypeField vertexField = new RelDataTypeFieldImpl("user", 0, vertexType); + RelDataTypeField field1 = new RelDataTypeFieldImpl("name", 0, relStringType); + RelDataTypeField field2 = new RelDataTypeFieldImpl("id", 1, relLongType); + RelDataTypeField field3 = new RelDataTypeFieldImpl("age", 2, relDoubleType); + RelDataType vertexType = + VertexRecordType.createVertexType( + Lists.newArrayList(field1, field2, field3), "id", typeFactory); + RelDataTypeField vertexField = new RelDataTypeFieldImpl("user", 0, vertexType); - RelDataTypeField field4 = new RelDataTypeFieldImpl("src", 0, relLongType); - RelDataTypeField field5 = new RelDataTypeFieldImpl("dst", 1, relLongType); - RelDataTypeField field6 = new RelDataTypeFieldImpl("weight", 2, relDoubleType); - RelDataType edgeType = EdgeRecordType.createEdgeType( - Lists.newArrayList(field4, field5, field6), - "src", "dst", null, typeFactory - ); - RelDataTypeField edgeField = new RelDataTypeFieldImpl("follow", 1, edgeType); + RelDataTypeField field4 = new RelDataTypeFieldImpl("src", 0, relLongType); + RelDataTypeField field5 = new RelDataTypeFieldImpl("dst", 1, relLongType); + RelDataTypeField field6 = new RelDataTypeFieldImpl("weight", 2, relDoubleType); + RelDataType edgeType = + EdgeRecordType.createEdgeType( + Lists.newArrayList(field4, field5, field6), "src", "dst", null, typeFactory); + RelDataTypeField edgeField = new RelDataTypeFieldImpl("follow", 1, edgeType); - GraphRecordType graphRecordType = new GraphRecordType("g0", - Lists.newArrayList(vertexField, edgeField) - ); - assertEquals( - graphRecordType.getVertexType(Lists.newArrayList("user"), typeFactory) - .toString(), - "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, DOUBLE age)" - ); - assertEquals( - graphRecordType.getEdgeType(Lists.newArrayList("follow"), typeFactory) - .toString(), - "Edge: RecordType:peek(BIGINT src, BIGINT dst, VARCHAR ~label, DOUBLE weight)" - ); - } + GraphRecordType graphRecordType = + new GraphRecordType("g0", Lists.newArrayList(vertexField, edgeField)); + assertEquals( + graphRecordType.getVertexType(Lists.newArrayList("user"), typeFactory).toString(), + "Vertex:RecordType:peek(BIGINT id, VARCHAR ~label, VARCHAR name, DOUBLE age)"); + assertEquals( + graphRecordType.getEdgeType(Lists.newArrayList("follow"), typeFactory).toString(), + "Edge: RecordType:peek(BIGINT src, BIGINT dst, VARCHAR ~label, DOUBLE weight)"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/Split2.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/Split2.java index 694390bd6..82bf70d49 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/Split2.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/Split2.java @@ -19,51 +19,53 @@ package org.apache.geaflow.dsl.udf; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDTF; +import com.google.common.collect.Lists; + @Description(name = "split2", description = "") public class Split2 extends UDTF { - private String columnDelimiter = ","; - private String lineDelimiter = "\n"; + private String columnDelimiter = ","; + private String lineDelimiter = "\n"; - public void eval(String data) { - evalInternal(data, columnDelimiter, lineDelimiter); - } + public void eval(String data) { + evalInternal(data, columnDelimiter, lineDelimiter); + } - public void eval(String data, String columnDelimiter) { - evalInternal(data, columnDelimiter, lineDelimiter); - } + public void eval(String data, String columnDelimiter) { + evalInternal(data, columnDelimiter, lineDelimiter); + } - public void eval(String data, String columnDelimiter, String lineDelimiter) { - evalInternal(data, columnDelimiter, lineDelimiter); - } + public void eval(String data, String columnDelimiter, String lineDelimiter) { + evalInternal(data, columnDelimiter, lineDelimiter); + } - private void evalInternal(String data, String columnDelimiter, String lineDelimiter) { - String[] rows = StringUtils.split(data, lineDelimiter); - for (String row : rows) { - String[] split = StringUtils.split(row, columnDelimiter); - collect(split); - } + private void evalInternal(String data, String columnDelimiter, String lineDelimiter) { + String[] rows = StringUtils.split(data, lineDelimiter); + for (String row : rows) { + String[] split = StringUtils.split(row, columnDelimiter); + collect(split); } + } - @Override - public List> getReturnType(List> paramTypes, List udtfReturnFields) { + @Override + public List> getReturnType(List> paramTypes, List udtfReturnFields) { - List> clazzs = Lists.newArrayList(); + List> clazzs = Lists.newArrayList(); - if (udtfReturnFields == null) { - clazzs.add(String.class); - clazzs.add(String.class); - return clazzs; - } + if (udtfReturnFields == null) { + clazzs.add(String.class); + clazzs.add(String.class); + return clazzs; + } - for (int i = 0; i < udtfReturnFields.size(); i++) { - clazzs.add(String.class); - } - return clazzs; + for (int i = 0; i < udtfReturnFields.size(); i++) { + clazzs.add(String.class); } + return clazzs; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/agg/UDAFTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/agg/UDAFTest.java index 87397aff6..70cb36ee2 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/agg/UDAFTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/agg/UDAFTest.java @@ -19,7 +19,6 @@ package org.apache.geaflow.dsl.udf.agg; -import com.google.common.collect.Lists; import org.apache.geaflow.dsl.udf.table.agg.AvgDouble; import org.apache.geaflow.dsl.udf.table.agg.AvgInteger; import org.apache.geaflow.dsl.udf.table.agg.AvgLong; @@ -36,161 +35,163 @@ import org.testng.Assert; import org.testng.annotations.Test; +import com.google.common.collect.Lists; + public class UDAFTest { - @Test - public void testAvgDouble() { - AvgDouble af = new AvgDouble(); - AvgDouble.Accumulator accumulator = af.createAccumulator(); - af.accumulate(accumulator, 1.0); - Assert.assertEquals(af.getValue(accumulator), 1.0); - af.resetAccumulator(accumulator); - Assert.assertNull(af.getValue(accumulator)); - af.merge(accumulator, Lists.newArrayList(accumulator)); - Assert.assertEquals(af.getValue(accumulator), null); - } - - @Test - public void testAvgInteger() { - AvgInteger af = new AvgInteger(); - AvgInteger.Accumulator accumulator = af.createAccumulator(); - af.accumulate(accumulator, 1); - Assert.assertEquals(af.getValue(accumulator), 1.0); - af.resetAccumulator(accumulator); - Assert.assertNull(af.getValue(accumulator)); - af.merge(accumulator, Lists.newArrayList(accumulator)); - Assert.assertEquals(af.getValue(accumulator), null); - } - - @Test - public void testAvgLong() { - AvgLong af = new AvgLong(); - AvgLong.Accumulator accumulator = af.createAccumulator(); - af.accumulate(accumulator, 1L); - Assert.assertEquals(af.getValue(accumulator), 1.0); - af.resetAccumulator(accumulator); - Assert.assertNull(af.getValue(accumulator)); - af.merge(accumulator, Lists.newArrayList(accumulator)); - Assert.assertEquals(af.getValue(accumulator), null); - } - - @Test - public void testCount() { - Count af = new Count(); - Count.Accumulator accumulator = af.createAccumulator(); - af.accumulate(accumulator, 1); - Assert.assertEquals((long) af.getValue(accumulator), 1); - af.resetAccumulator(accumulator); - Assert.assertEquals((long) af.getValue(accumulator), 0); - af.merge(accumulator, Lists.newArrayList(accumulator)); - Assert.assertEquals((long) af.getValue(accumulator), 0); - } - - @Test - public void testMaxDouble() { - MaxDouble af = new MaxDouble(); - MaxDouble.Accumulator accumulator = af.createAccumulator(); - af.accumulate(accumulator, 1.0); - Assert.assertEquals(af.getValue(accumulator), 1.0); - af.merge(accumulator, Lists.newArrayList(accumulator)); - Assert.assertEquals(af.getValue(accumulator), 1.0); - af.resetAccumulator(accumulator); - Assert.assertNull(af.getValue(accumulator)); - } - - @Test - public void testMaxInteger() { - MaxInteger af = new MaxInteger(); - MaxInteger.Accumulator accumulator = af.createAccumulator(); - af.accumulate(accumulator, 1); - Assert.assertEquals((long) af.getValue(accumulator), 1); - af.merge(accumulator, Lists.newArrayList(accumulator)); - Assert.assertEquals((long) af.getValue(accumulator), 1); - af.resetAccumulator(accumulator); - Assert.assertNull(af.getValue(accumulator)); - } - - @Test - public void testMaxLong() { - MaxLong af = new MaxLong(); - MaxLong.Accumulator accumulator = af.createAccumulator(); - af.accumulate(accumulator, 1L); - Assert.assertEquals((long) af.getValue(accumulator), 1); - af.merge(accumulator, Lists.newArrayList(accumulator)); - Assert.assertEquals((long) af.getValue(accumulator), 1); - af.resetAccumulator(accumulator); - Assert.assertNull(af.getValue(accumulator)); - } - - @Test - public void testMinDouble() { - MinDouble af = new MinDouble(); - MinDouble.Accumulator accumulator = af.createAccumulator(); - af.accumulate(accumulator, 1.0); - Assert.assertEquals(af.getValue(accumulator), 1.0); - af.merge(accumulator, Lists.newArrayList(accumulator)); - Assert.assertEquals(af.getValue(accumulator), 1.0); - af.resetAccumulator(accumulator); - Assert.assertNull(af.getValue(accumulator)); - } - - @Test - public void testMinInteger() { - MinInteger af = new MinInteger(); - MinInteger.Accumulator accumulator = af.createAccumulator(); - af.accumulate(accumulator, 1); - Assert.assertEquals((long) af.getValue(accumulator), 1); - af.merge(accumulator, Lists.newArrayList(accumulator)); - Assert.assertEquals((long) af.getValue(accumulator), 1); - af.resetAccumulator(accumulator); - Assert.assertNull(af.getValue(accumulator)); - } - - @Test - public void testMinLong() { - MinLong af = new MinLong(); - MinLong.Accumulator accumulator = af.createAccumulator(); - af.accumulate(accumulator, 1L); - Assert.assertEquals((long) af.getValue(accumulator), 1); - af.merge(accumulator, Lists.newArrayList(accumulator)); - Assert.assertEquals((long) af.getValue(accumulator), 1); - af.resetAccumulator(accumulator); - Assert.assertNull(af.getValue(accumulator)); - } - - @Test - public void testSumDouble() { - SumDouble af = new SumDouble(); - SumDouble.Accumulator accumulator = af.createAccumulator(); - af.accumulate(accumulator, 1.0); - Assert.assertEquals(af.getValue(accumulator), 1.0); - af.resetAccumulator(accumulator); - Assert.assertEquals(af.getValue(accumulator), 0.0); - af.merge(accumulator, Lists.newArrayList(accumulator)); - Assert.assertEquals(af.getValue(accumulator), 0.0); - } - - @Test - public void testSumInteger() { - SumInteger af = new SumInteger(); - SumInteger.Accumulator accumulator = af.createAccumulator(); - af.accumulate(accumulator, 1); - Assert.assertEquals((long) af.getValue(accumulator), 1); - af.resetAccumulator(accumulator); - Assert.assertEquals((int) af.getValue(accumulator), 0); - af.merge(accumulator, Lists.newArrayList(accumulator)); - Assert.assertEquals((long) af.getValue(accumulator), 0); - } - - @Test - public void testSumLong() { - SumLong af = new SumLong(); - SumLong.Accumulator accumulator = af.createAccumulator(); - af.accumulate(accumulator, 1L); - Assert.assertEquals((long) af.getValue(accumulator), 1); - af.resetAccumulator(accumulator); - Assert.assertEquals((long) af.getValue(accumulator), 0); - af.merge(accumulator, Lists.newArrayList(accumulator)); - Assert.assertEquals((long) af.getValue(accumulator), 0); - } + @Test + public void testAvgDouble() { + AvgDouble af = new AvgDouble(); + AvgDouble.Accumulator accumulator = af.createAccumulator(); + af.accumulate(accumulator, 1.0); + Assert.assertEquals(af.getValue(accumulator), 1.0); + af.resetAccumulator(accumulator); + Assert.assertNull(af.getValue(accumulator)); + af.merge(accumulator, Lists.newArrayList(accumulator)); + Assert.assertEquals(af.getValue(accumulator), null); + } + + @Test + public void testAvgInteger() { + AvgInteger af = new AvgInteger(); + AvgInteger.Accumulator accumulator = af.createAccumulator(); + af.accumulate(accumulator, 1); + Assert.assertEquals(af.getValue(accumulator), 1.0); + af.resetAccumulator(accumulator); + Assert.assertNull(af.getValue(accumulator)); + af.merge(accumulator, Lists.newArrayList(accumulator)); + Assert.assertEquals(af.getValue(accumulator), null); + } + + @Test + public void testAvgLong() { + AvgLong af = new AvgLong(); + AvgLong.Accumulator accumulator = af.createAccumulator(); + af.accumulate(accumulator, 1L); + Assert.assertEquals(af.getValue(accumulator), 1.0); + af.resetAccumulator(accumulator); + Assert.assertNull(af.getValue(accumulator)); + af.merge(accumulator, Lists.newArrayList(accumulator)); + Assert.assertEquals(af.getValue(accumulator), null); + } + + @Test + public void testCount() { + Count af = new Count(); + Count.Accumulator accumulator = af.createAccumulator(); + af.accumulate(accumulator, 1); + Assert.assertEquals((long) af.getValue(accumulator), 1); + af.resetAccumulator(accumulator); + Assert.assertEquals((long) af.getValue(accumulator), 0); + af.merge(accumulator, Lists.newArrayList(accumulator)); + Assert.assertEquals((long) af.getValue(accumulator), 0); + } + + @Test + public void testMaxDouble() { + MaxDouble af = new MaxDouble(); + MaxDouble.Accumulator accumulator = af.createAccumulator(); + af.accumulate(accumulator, 1.0); + Assert.assertEquals(af.getValue(accumulator), 1.0); + af.merge(accumulator, Lists.newArrayList(accumulator)); + Assert.assertEquals(af.getValue(accumulator), 1.0); + af.resetAccumulator(accumulator); + Assert.assertNull(af.getValue(accumulator)); + } + + @Test + public void testMaxInteger() { + MaxInteger af = new MaxInteger(); + MaxInteger.Accumulator accumulator = af.createAccumulator(); + af.accumulate(accumulator, 1); + Assert.assertEquals((long) af.getValue(accumulator), 1); + af.merge(accumulator, Lists.newArrayList(accumulator)); + Assert.assertEquals((long) af.getValue(accumulator), 1); + af.resetAccumulator(accumulator); + Assert.assertNull(af.getValue(accumulator)); + } + + @Test + public void testMaxLong() { + MaxLong af = new MaxLong(); + MaxLong.Accumulator accumulator = af.createAccumulator(); + af.accumulate(accumulator, 1L); + Assert.assertEquals((long) af.getValue(accumulator), 1); + af.merge(accumulator, Lists.newArrayList(accumulator)); + Assert.assertEquals((long) af.getValue(accumulator), 1); + af.resetAccumulator(accumulator); + Assert.assertNull(af.getValue(accumulator)); + } + + @Test + public void testMinDouble() { + MinDouble af = new MinDouble(); + MinDouble.Accumulator accumulator = af.createAccumulator(); + af.accumulate(accumulator, 1.0); + Assert.assertEquals(af.getValue(accumulator), 1.0); + af.merge(accumulator, Lists.newArrayList(accumulator)); + Assert.assertEquals(af.getValue(accumulator), 1.0); + af.resetAccumulator(accumulator); + Assert.assertNull(af.getValue(accumulator)); + } + + @Test + public void testMinInteger() { + MinInteger af = new MinInteger(); + MinInteger.Accumulator accumulator = af.createAccumulator(); + af.accumulate(accumulator, 1); + Assert.assertEquals((long) af.getValue(accumulator), 1); + af.merge(accumulator, Lists.newArrayList(accumulator)); + Assert.assertEquals((long) af.getValue(accumulator), 1); + af.resetAccumulator(accumulator); + Assert.assertNull(af.getValue(accumulator)); + } + + @Test + public void testMinLong() { + MinLong af = new MinLong(); + MinLong.Accumulator accumulator = af.createAccumulator(); + af.accumulate(accumulator, 1L); + Assert.assertEquals((long) af.getValue(accumulator), 1); + af.merge(accumulator, Lists.newArrayList(accumulator)); + Assert.assertEquals((long) af.getValue(accumulator), 1); + af.resetAccumulator(accumulator); + Assert.assertNull(af.getValue(accumulator)); + } + + @Test + public void testSumDouble() { + SumDouble af = new SumDouble(); + SumDouble.Accumulator accumulator = af.createAccumulator(); + af.accumulate(accumulator, 1.0); + Assert.assertEquals(af.getValue(accumulator), 1.0); + af.resetAccumulator(accumulator); + Assert.assertEquals(af.getValue(accumulator), 0.0); + af.merge(accumulator, Lists.newArrayList(accumulator)); + Assert.assertEquals(af.getValue(accumulator), 0.0); + } + + @Test + public void testSumInteger() { + SumInteger af = new SumInteger(); + SumInteger.Accumulator accumulator = af.createAccumulator(); + af.accumulate(accumulator, 1); + Assert.assertEquals((long) af.getValue(accumulator), 1); + af.resetAccumulator(accumulator); + Assert.assertEquals((int) af.getValue(accumulator), 0); + af.merge(accumulator, Lists.newArrayList(accumulator)); + Assert.assertEquals((long) af.getValue(accumulator), 0); + } + + @Test + public void testSumLong() { + SumLong af = new SumLong(); + SumLong.Accumulator accumulator = af.createAccumulator(); + af.accumulate(accumulator, 1L); + Assert.assertEquals((long) af.getValue(accumulator), 1); + af.resetAccumulator(accumulator); + Assert.assertEquals((long) af.getValue(accumulator), 0); + af.merge(accumulator, Lists.newArrayList(accumulator)); + Assert.assertEquals((long) af.getValue(accumulator), 0); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/array/ArrayTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/array/ArrayTest.java index 691eb7a87..b0e574dbe 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/array/ArrayTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/array/ArrayTest.java @@ -28,36 +28,36 @@ public class ArrayTest { - @Test - public void testArrayAppend() throws Exception { - ArrayAppend udf = new ArrayAppend(); - Object[] input = new Object[]{1, 2, 4, -1}; - Object[] res = udf.eval(input, 6); - Assert.assertEquals(res, new Object[]{1, 2, 4, -1, 6}); - } + @Test + public void testArrayAppend() throws Exception { + ArrayAppend udf = new ArrayAppend(); + Object[] input = new Object[] {1, 2, 4, -1}; + Object[] res = udf.eval(input, 6); + Assert.assertEquals(res, new Object[] {1, 2, 4, -1, 6}); + } - @Test - public void testArrayContains() throws Exception { - ArrayContains udf = new ArrayContains(); - Object[] input = new Object[]{1, 2, 4, -1}; - Assert.assertTrue(udf.eval(input, 2)); - Assert.assertFalse(udf.eval(input, -3)); - } + @Test + public void testArrayContains() throws Exception { + ArrayContains udf = new ArrayContains(); + Object[] input = new Object[] {1, 2, 4, -1}; + Assert.assertTrue(udf.eval(input, 2)); + Assert.assertFalse(udf.eval(input, -3)); + } - @Test - public void testArrayDistinct() throws Exception { - ArrayDistinct udf = new ArrayDistinct(); - Object[] input = new Object[]{1, 2, 4, -1, 2, 1}; - Object[] res = udf.eval(input); - Assert.assertEquals(res.length, 4); - } + @Test + public void testArrayDistinct() throws Exception { + ArrayDistinct udf = new ArrayDistinct(); + Object[] input = new Object[] {1, 2, 4, -1, 2, 1}; + Object[] res = udf.eval(input); + Assert.assertEquals(res.length, 4); + } - @Test - public void testArrayUnion() throws Exception { - ArrayUnion udf = new ArrayUnion(); - Object[] input1 = new Object[]{1, 2, 4, -1}; - Object[] input2 = new Object[]{1, 3, 5, 2, -4, -6}; - Object[] res = udf.eval(input1, input2); - Assert.assertEquals(res.length, 8); - } + @Test + public void testArrayUnion() throws Exception { + ArrayUnion udf = new ArrayUnion(); + Object[] input1 = new Object[] {1, 2, 4, -1}; + Object[] input2 = new Object[] {1, 3, 5, 2, -4, -6}; + Object[] res = udf.eval(input1, input2); + Assert.assertEquals(res.length, 8); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/date/DateFormatTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/date/DateFormatTest.java index 35dbf0b09..0438b776b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/date/DateFormatTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/date/DateFormatTest.java @@ -24,10 +24,10 @@ public class DateFormatTest { - @Test - public void testName() { - DateFormat format = new DateFormat(); - format.open(null); - format.eval(new java.sql.Timestamp(System.currentTimeMillis()).toString()); - } + @Test + public void testName() { + DateFormat format = new DateFormat(); + format.open(null); + format.eval(new java.sql.Timestamp(System.currentTimeMillis()).toString()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/date/LastDayTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/date/LastDayTest.java index e2a1416fc..6a262e151 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/date/LastDayTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/date/LastDayTest.java @@ -26,16 +26,16 @@ public class LastDayTest { - LastDay lastDay; + LastDay lastDay; - @BeforeClass - public void before() { - lastDay = new LastDay(); - } + @BeforeClass + public void before() { + lastDay = new LastDay(); + } - @Test - public void testEval() { - String day = lastDay.eval("2018-04-23 12:00:01"); - Assert.assertEquals(day, "2018-04-30 00:00:00"); - } + @Test + public void testEval() { + String day = lastDay.eval("2018-04-23 12:00:01"); + Assert.assertEquals(day, "2018-04-30 00:00:00"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/date/UDFDateTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/date/UDFDateTest.java index ba78ecc1c..b8f24a8d0 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/date/UDFDateTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/date/UDFDateTest.java @@ -51,243 +51,243 @@ public class UDFDateTest { - @Test - public void testAddMonths() { - AddMonths addMonths = new AddMonths(); - addMonths.open(null); - assertEquals(addMonths.eval("1987-06-05 00:11:22", 5), "1987-11-05 00:11:22"); - assertEquals(addMonths.eval("1987-06-05", 5), "1987-11-05"); - assertEquals(addMonths.eval("1987-06-05 00:11:22", null), null); - } + @Test + public void testAddMonths() { + AddMonths addMonths = new AddMonths(); + addMonths.open(null); + assertEquals(addMonths.eval("1987-06-05 00:11:22", 5), "1987-11-05 00:11:22"); + assertEquals(addMonths.eval("1987-06-05", 5), "1987-11-05"); + assertEquals(addMonths.eval("1987-06-05 00:11:22", null), null); + } - @Test - public void testDateAdd() { - DateAdd test = new DateAdd(); - test.open(null); - assertEquals(test.eval("1987-06-05 00:11:22", 5), "1987-06-10"); - assertEquals(test.eval(new java.sql.Timestamp(1667900725), 5), "1970-01-25 15:18:20"); - assertEquals(test.eval("1987-06-05", 5), "1987-06-10"); - assertEquals(test.eval(new java.sql.Timestamp(1667900725), null), null); - assertEquals(test.eval("1987-06-05", null), null); - } + @Test + public void testDateAdd() { + DateAdd test = new DateAdd(); + test.open(null); + assertEquals(test.eval("1987-06-05 00:11:22", 5), "1987-06-10"); + assertEquals(test.eval(new java.sql.Timestamp(1667900725), 5), "1970-01-25 15:18:20"); + assertEquals(test.eval("1987-06-05", 5), "1987-06-10"); + assertEquals(test.eval(new java.sql.Timestamp(1667900725), null), null); + assertEquals(test.eval("1987-06-05", null), null); + } - @Test - public void testDateDiff() { - DateDiff test = new DateDiff(); - test.open(null); - assertEquals((long) test.eval("1987-06-05 00:11:22", "2022-06-05 00:11:22"), -12784); - assertEquals((long) test.eval(new java.sql.Timestamp(1667900725), "2022-06-05"), -19128); - assertEquals((long) test.eval("1987-06-05", new java.sql.Timestamp(1667900725)), 6344); - assertNull(test.eval((String) null, (String) null)); - assertNull(test.eval(new java.sql.Timestamp(1867900725), (java.sql.Timestamp) null)); - } + @Test + public void testDateDiff() { + DateDiff test = new DateDiff(); + test.open(null); + assertEquals((long) test.eval("1987-06-05 00:11:22", "2022-06-05 00:11:22"), -12784); + assertEquals((long) test.eval(new java.sql.Timestamp(1667900725), "2022-06-05"), -19128); + assertEquals((long) test.eval("1987-06-05", new java.sql.Timestamp(1667900725)), 6344); + assertNull(test.eval((String) null, (String) null)); + assertNull(test.eval(new java.sql.Timestamp(1867900725), (java.sql.Timestamp) null)); + } - @Test - public void testDateFormat() { - DateFormat test = new DateFormat(); - test.open(null); - assertEquals(test.eval("1987-06-05 00:11:22", "MM-dd-yyyy"), "06-05-1987"); - assertEquals(test.eval("1987-06-05 00:11:22"), "1987-06-05 00:11:22"); - assertEquals(test.eval(new java.sql.Timestamp(1867900725), "MM-dd-yyyy"), "01-22-1970"); - assertEquals(test.eval(new java.sql.Timestamp(1867900725)), "1970-01-22 22:51:40"); - assertEquals(test.eval("1987-06-05 00:11:22", null), null); - } + @Test + public void testDateFormat() { + DateFormat test = new DateFormat(); + test.open(null); + assertEquals(test.eval("1987-06-05 00:11:22", "MM-dd-yyyy"), "06-05-1987"); + assertEquals(test.eval("1987-06-05 00:11:22"), "1987-06-05 00:11:22"); + assertEquals(test.eval(new java.sql.Timestamp(1867900725), "MM-dd-yyyy"), "01-22-1970"); + assertEquals(test.eval(new java.sql.Timestamp(1867900725)), "1970-01-22 22:51:40"); + assertEquals(test.eval("1987-06-05 00:11:22", null), null); + } - @Test - public void testDatePart() { - DatePart test = new DatePart(); - test.open(null); - assertEquals((int) test.eval("1987-06-05 00:11:22", "yyyy"), 1987); - assertEquals((int) test.eval("1987-06-05 00:11:22", "mm"), 6); - assertEquals((int) test.eval("1987-06-05 00:11:22", "dd"), 5); - assertEquals((int) test.eval("1987-06-05 00:11:22", "hh"), 0); - assertEquals((int) test.eval("1987-06-05 00:11:22", "mi"), 11); - assertEquals((int) test.eval("1987-06-05 00:11:22", "ss"), 22); - assertNull(test.eval("1987-06-05 00:11:22", null)); - assertEquals((int) test.eval("1987-06-05", "ss"), 0); - } + @Test + public void testDatePart() { + DatePart test = new DatePart(); + test.open(null); + assertEquals((int) test.eval("1987-06-05 00:11:22", "yyyy"), 1987); + assertEquals((int) test.eval("1987-06-05 00:11:22", "mm"), 6); + assertEquals((int) test.eval("1987-06-05 00:11:22", "dd"), 5); + assertEquals((int) test.eval("1987-06-05 00:11:22", "hh"), 0); + assertEquals((int) test.eval("1987-06-05 00:11:22", "mi"), 11); + assertEquals((int) test.eval("1987-06-05 00:11:22", "ss"), 22); + assertNull(test.eval("1987-06-05 00:11:22", null)); + assertEquals((int) test.eval("1987-06-05", "ss"), 0); + } - @Test - public void testDateSub() { - DateSub test = new DateSub(); - test.open(null); - assertEquals(test.eval("1987-06-05 00:11:22", 5), "1987-05-31"); - assertEquals(test.eval((String) null, 5), null); - assertEquals(test.eval("1987-06-05", 5), "1987-05-31"); - assertEquals(test.eval(new java.sql.Timestamp(1867900725), 5), "1970-01-17 22:51:40"); - assertNull(test.eval(new java.sql.Timestamp(1867900725), null)); - } + @Test + public void testDateSub() { + DateSub test = new DateSub(); + test.open(null); + assertEquals(test.eval("1987-06-05 00:11:22", 5), "1987-05-31"); + assertEquals(test.eval((String) null, 5), null); + assertEquals(test.eval("1987-06-05", 5), "1987-05-31"); + assertEquals(test.eval(new java.sql.Timestamp(1867900725), 5), "1970-01-17 22:51:40"); + assertNull(test.eval(new java.sql.Timestamp(1867900725), null)); + } - @Test - public void testDateTrunc() { - DateTrunc test = new DateTrunc(); - test.open(null); - assertEquals(test.eval("1987-06-05 00:11:22", "yyyy"), "1987-01-01 00:00:00"); - assertEquals(test.eval("1987-06-05 00:11:22", "mm"), "1987-06-01 00:00:00"); - assertEquals(test.eval("1987-06-05 00:11:22", "dd"), "1987-06-05 00:00:00"); - assertEquals(test.eval("1987-06-05 00:11:22", "hh"), "1987-06-05 00:00:00"); - assertEquals(test.eval("1987-06-05 00:11:22", "mi"), "1987-06-05 00:11:00"); - assertEquals(test.eval("1987-06-05 00:11:22", "ss"), "1987-06-05 00:11:22"); - assertNull(test.eval("1987-06-05 00:11:22", null)); - assertEquals(test.eval("1987-06-05", "yyyy"), "1987-01-01 00:00:00"); - } + @Test + public void testDateTrunc() { + DateTrunc test = new DateTrunc(); + test.open(null); + assertEquals(test.eval("1987-06-05 00:11:22", "yyyy"), "1987-01-01 00:00:00"); + assertEquals(test.eval("1987-06-05 00:11:22", "mm"), "1987-06-01 00:00:00"); + assertEquals(test.eval("1987-06-05 00:11:22", "dd"), "1987-06-05 00:00:00"); + assertEquals(test.eval("1987-06-05 00:11:22", "hh"), "1987-06-05 00:00:00"); + assertEquals(test.eval("1987-06-05 00:11:22", "mi"), "1987-06-05 00:11:00"); + assertEquals(test.eval("1987-06-05 00:11:22", "ss"), "1987-06-05 00:11:22"); + assertNull(test.eval("1987-06-05 00:11:22", null)); + assertEquals(test.eval("1987-06-05", "yyyy"), "1987-01-01 00:00:00"); + } - @Test - public void testDay() { - Day test = new Day(); - test.open(null); - assertEquals((int) test.eval("1987-06-05 00:11:22"), 5); - assertEquals((int) test.eval("1987-06-05"), 5); - assertNull(test.eval((String) null)); - assertNull(test.eval((java.sql.Timestamp) null)); - assertEquals((int) test.eval(new java.sql.Timestamp(1867900725)), 22); - } + @Test + public void testDay() { + Day test = new Day(); + test.open(null); + assertEquals((int) test.eval("1987-06-05 00:11:22"), 5); + assertEquals((int) test.eval("1987-06-05"), 5); + assertNull(test.eval((String) null)); + assertNull(test.eval((java.sql.Timestamp) null)); + assertEquals((int) test.eval(new java.sql.Timestamp(1867900725)), 22); + } - @Test - public void testDayOfMonth() { - DayOfMonth test = new DayOfMonth(); - test.open(null); - assertEquals((int) test.eval("1987-06-05 00:11:22"), 5); - assertNull(test.eval((String) null)); - assertNull(test.eval((java.sql.Timestamp) null)); - assertEquals((int) test.eval(new java.sql.Timestamp(1867900725)), 22); - } + @Test + public void testDayOfMonth() { + DayOfMonth test = new DayOfMonth(); + test.open(null); + assertEquals((int) test.eval("1987-06-05 00:11:22"), 5); + assertNull(test.eval((String) null)); + assertNull(test.eval((java.sql.Timestamp) null)); + assertEquals((int) test.eval(new java.sql.Timestamp(1867900725)), 22); + } - @Test - public void testFromUnixTime() { - FromUnixTime test = new FromUnixTime(); - test.open(null); - assertEquals(test.eval(11111111L), "1970-05-09 22:25:11"); - assertNull(test.eval(null)); - assertEquals(test.eval("11111111", "yyyy-MM-dd HH:mm:ss"), "1970-05-09 22:25:11"); - } + @Test + public void testFromUnixTime() { + FromUnixTime test = new FromUnixTime(); + test.open(null); + assertEquals(test.eval(11111111L), "1970-05-09 22:25:11"); + assertNull(test.eval(null)); + assertEquals(test.eval("11111111", "yyyy-MM-dd HH:mm:ss"), "1970-05-09 22:25:11"); + } - @Test - public void testFromUnixTimeMillis() { - FromUnixTimeMillis test = new FromUnixTimeMillis(); - test.open(null); - assertEquals(test.eval(11111111L), "1970-01-01 11:05:11.111"); - assertEquals(test.eval("11111111"), "1970-01-01 11:05:11.111"); - assertNull(test.eval((Long) null)); - assertNull(test.eval(11111111L, null)); - assertEquals(test.eval(11111111L, "yyyy-MM-dd HH:mm:ss"), "1970-01-01 11:05:11"); - } + @Test + public void testFromUnixTimeMillis() { + FromUnixTimeMillis test = new FromUnixTimeMillis(); + test.open(null); + assertEquals(test.eval(11111111L), "1970-01-01 11:05:11.111"); + assertEquals(test.eval("11111111"), "1970-01-01 11:05:11.111"); + assertNull(test.eval((Long) null)); + assertNull(test.eval(11111111L, null)); + assertEquals(test.eval(11111111L, "yyyy-MM-dd HH:mm:ss"), "1970-01-01 11:05:11"); + } - @Test - public void testHour() { - Hour test = new Hour(); - test.open(null); - assertEquals((int) test.eval("1987-06-05 00:11:22"), 0); - assertEquals((int) test.eval(new java.sql.Timestamp(1667900725)), 3); - assertNull(test.eval((String) null)); - assertNull(test.eval((java.sql.Timestamp) null)); - } + @Test + public void testHour() { + Hour test = new Hour(); + test.open(null); + assertEquals((int) test.eval("1987-06-05 00:11:22"), 0); + assertEquals((int) test.eval(new java.sql.Timestamp(1667900725)), 3); + assertNull(test.eval((String) null)); + assertNull(test.eval((java.sql.Timestamp) null)); + } - @Test - public void testIsDate() { - IsDate test = new IsDate(); - test.open(null); - assertEquals((boolean) test.eval("1987-06-05 00:11:22"), true); - assertTrue(test.eval("1987-06-05 00:11:22")); - assertFalse(test.eval("xxxxxxxxxxxxx")); - } + @Test + public void testIsDate() { + IsDate test = new IsDate(); + test.open(null); + assertEquals((boolean) test.eval("1987-06-05 00:11:22"), true); + assertTrue(test.eval("1987-06-05 00:11:22")); + assertFalse(test.eval("xxxxxxxxxxxxx")); + } - @Test - public void testLastDay() { - LastDay test = new LastDay(); - test.open(null); - assertEquals(test.eval("1987-06-05"), "1987-06-30 00:00:00"); - } + @Test + public void testLastDay() { + LastDay test = new LastDay(); + test.open(null); + assertEquals(test.eval("1987-06-05"), "1987-06-30 00:00:00"); + } - @Test - public void testMinute() { - Minute test = new Minute(); - test.open(null); - assertEquals((int) test.eval("1987-06-05 00:11:22"), 11); - assertEquals((int) test.eval(new java.sql.Timestamp(1667900725)), 18); - assertNull(test.eval((String) null)); - assertNull(test.eval((java.sql.Timestamp) null)); - } + @Test + public void testMinute() { + Minute test = new Minute(); + test.open(null); + assertEquals((int) test.eval("1987-06-05 00:11:22"), 11); + assertEquals((int) test.eval(new java.sql.Timestamp(1667900725)), 18); + assertNull(test.eval((String) null)); + assertNull(test.eval((java.sql.Timestamp) null)); + } - @Test - public void testMonth() { - Month test = new Month(); - test.open(null); - assertEquals((int) test.eval("1987-06-05 00:11:22"), 6); - assertEquals((int) test.eval(new java.sql.Timestamp(1667900725)), 1); - assertNull(test.eval((String) null)); - assertNull(test.eval((java.sql.Timestamp) null)); - } + @Test + public void testMonth() { + Month test = new Month(); + test.open(null); + assertEquals((int) test.eval("1987-06-05 00:11:22"), 6); + assertEquals((int) test.eval(new java.sql.Timestamp(1667900725)), 1); + assertNull(test.eval((String) null)); + assertNull(test.eval((java.sql.Timestamp) null)); + } - @Test - public void testNow() { - Now test = new Now(); - test.open(null); - test.eval(); - test.eval(3); - test.eval(3L); - } + @Test + public void testNow() { + Now test = new Now(); + test.open(null); + test.eval(); + test.eval(3); + test.eval(3L); + } - @Test - public void testSecond() { - Second test = new Second(); - test.open(null); - assertEquals((int) test.eval("1987-06-05 00:11:22"), 22); - assertEquals((int) test.eval(new java.sql.Timestamp(1667900725)), 20); - assertNull(test.eval((String) null)); - assertNull(test.eval((java.sql.Timestamp) null)); - } + @Test + public void testSecond() { + Second test = new Second(); + test.open(null); + assertEquals((int) test.eval("1987-06-05 00:11:22"), 22); + assertEquals((int) test.eval(new java.sql.Timestamp(1667900725)), 20); + assertNull(test.eval((String) null)); + assertNull(test.eval((java.sql.Timestamp) null)); + } - @Test - public void testUnixTimeStamp() { - UnixTimeStamp test = new UnixTimeStamp(); - test.open(null); - assertEquals((long) test.eval("1987-06-05 00:11:22"), 549817882); - assertEquals((long) test.eval((Object) "1987-06-05 00:11:22.33"), 549817882); - assertEquals((long) test.eval((Object) "1987-06-05"), 549817200); - assertNull(test.eval("1987-06-05 00:11:22", null)); - } + @Test + public void testUnixTimeStamp() { + UnixTimeStamp test = new UnixTimeStamp(); + test.open(null); + assertEquals((long) test.eval("1987-06-05 00:11:22"), 549817882); + assertEquals((long) test.eval((Object) "1987-06-05 00:11:22.33"), 549817882); + assertEquals((long) test.eval((Object) "1987-06-05"), 549817200); + assertNull(test.eval("1987-06-05 00:11:22", null)); + } - @Test - public void testUnixTimeStampMillis() { - UnixTimeStampMillis test = new UnixTimeStampMillis(); - test.open(null); - assertEquals((long) test.eval("1987-06-05 00:11:22"), 549817882000L); - assertEquals((long) test.eval("1987-06-05", "yyyy-mm-dd"), 536774760000L); - assertNull(test.eval("1987-06-05", null)); - assertEquals((long) test.eval("1987-06-05"), 549817200000L); - assertEquals((long) test.eval((Object) "1987-06-05"), 549817200000L); - assertEquals((long) test.eval("1987-06-05 00:11:22.111"), 549817882111L); - } + @Test + public void testUnixTimeStampMillis() { + UnixTimeStampMillis test = new UnixTimeStampMillis(); + test.open(null); + assertEquals((long) test.eval("1987-06-05 00:11:22"), 549817882000L); + assertEquals((long) test.eval("1987-06-05", "yyyy-mm-dd"), 536774760000L); + assertNull(test.eval("1987-06-05", null)); + assertEquals((long) test.eval("1987-06-05"), 549817200000L); + assertEquals((long) test.eval((Object) "1987-06-05"), 549817200000L); + assertEquals((long) test.eval("1987-06-05 00:11:22.111"), 549817882111L); + } - @Test - public void testWeekDay() { - WeekDay test = new WeekDay(); - test.open(null); - assertEquals((long) test.eval("1987-06-05 00:11:22"), 5); - assertEquals((long) test.eval("1987-06-05"), 5); - assertNull(test.eval(null)); - } + @Test + public void testWeekDay() { + WeekDay test = new WeekDay(); + test.open(null); + assertEquals((long) test.eval("1987-06-05 00:11:22"), 5); + assertEquals((long) test.eval("1987-06-05"), 5); + assertNull(test.eval(null)); + } - @Test - public void testWeekOfYear() { - WeekOfYear test = new WeekOfYear(); - test.open(null); - assertEquals((long) test.eval("1987-06-05 00:11:22"), 23); - assertNull(test.eval((String) null)); - assertNull(test.eval((java.sql.Timestamp) null)); - assertEquals((long) test.eval("1987-06-05"), 23); - assertEquals((long) test.eval(new java.sql.Timestamp(1667900725)), 4); - } + @Test + public void testWeekOfYear() { + WeekOfYear test = new WeekOfYear(); + test.open(null); + assertEquals((long) test.eval("1987-06-05 00:11:22"), 23); + assertNull(test.eval((String) null)); + assertNull(test.eval((java.sql.Timestamp) null)); + assertEquals((long) test.eval("1987-06-05"), 23); + assertEquals((long) test.eval(new java.sql.Timestamp(1667900725)), 4); + } - @Test - public void testYear() { - Year test = new Year(); - test.open(null); - assertEquals((long) test.eval("1987-06-05 00:11:22"), 1987); - assertNull(test.eval((String) null)); - assertNull(test.eval((java.sql.Timestamp) null)); - assertEquals((long) test.eval("1987-06-05"), 1987); - assertEquals((long) test.eval(new java.sql.Timestamp(1667900725)), 1970); - } + @Test + public void testYear() { + Year test = new Year(); + test.open(null); + assertEquals((long) test.eval("1987-06-05 00:11:22"), 1987); + assertNull(test.eval((String) null)); + assertNull(test.eval((java.sql.Timestamp) null)); + assertEquals((long) test.eval("1987-06-05"), 1987); + assertEquals((long) test.eval(new java.sql.Timestamp(1667900725)), 1970); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/date/UnixTimeStampMillisTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/date/UnixTimeStampMillisTest.java index 760a3599d..a81245221 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/date/UnixTimeStampMillisTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/date/UnixTimeStampMillisTest.java @@ -29,41 +29,37 @@ public class UnixTimeStampMillisTest { - UnixTimeStampMillis udf; + UnixTimeStampMillis udf; - @BeforeMethod - public void setUp() throws Exception { - udf = new UnixTimeStampMillis(); - } + @BeforeMethod + public void setUp() throws Exception { + udf = new UnixTimeStampMillis(); + } - @Test - public void test() { + @Test + public void test() { - assertEquals(udf.eval("1993", "yyyy"), new Long(725817600000L)); + assertEquals(udf.eval("1993", "yyyy"), new Long(725817600000L)); - assertEquals(udf.eval("1993-12", "yyyy-MM"), new Long(754675200000L)); + assertEquals(udf.eval("1993-12", "yyyy-MM"), new Long(754675200000L)); - assertEquals(udf.eval("1993-12-01", "yyyy-MM-dd"), new Long(754675200000L)); + assertEquals(udf.eval("1993-12-01", "yyyy-MM-dd"), new Long(754675200000L)); - assertEquals(udf.eval("1993-12-01 12", "yyyy-MM-dd HH"), new Long(754718400000L)); + assertEquals(udf.eval("1993-12-01 12", "yyyy-MM-dd HH"), new Long(754718400000L)); - assertEquals(udf.eval("1993-12-01 12:03", "yyyy-MM-dd HH:mm"), new Long(754718580000L)); + assertEquals(udf.eval("1993-12-01 12:03", "yyyy-MM-dd HH:mm"), new Long(754718580000L)); - assertEquals(udf.eval("1993-12-01 12:03:01"), new Long(754718581000L)); + assertEquals(udf.eval("1993-12-01 12:03:01"), new Long(754718581000L)); - assertEquals(udf.eval("1993-12-01 12:03:01", "yyyy-MM-dd HH:mm:ss"), - new Long(754718581000L)); + assertEquals(udf.eval("1993-12-01 12:03:01", "yyyy-MM-dd HH:mm:ss"), new Long(754718581000L)); - assertEquals(udf.eval("1993-12-01 12:03:01.111"), new Long(754718581111L)); + assertEquals(udf.eval("1993-12-01 12:03:01.111"), new Long(754718581111L)); - assertEquals(udf.eval("1993-12-01 12:03:01.111", "yyyy-MM-dd HH:mm:ss.SSS"), - new Long(754718581111L)); - - DateTimeFormatter millisFormatter = DateTimeFormat.forPattern( - "yyyy-MM-dd'T'HH:mm:ss.SSS+00:00"); - assertEquals(millisFormatter.parseMillis("2010-04-13T15:39:24.399+00:00"), - 1271144364399L); - - } + assertEquals( + udf.eval("1993-12-01 12:03:01.111", "yyyy-MM-dd HH:mm:ss.SSS"), new Long(754718581111L)); + DateTimeFormatter millisFormatter = + DateTimeFormat.forPattern("yyyy-MM-dd'T'HH:mm:ss.SSS+00:00"); + assertEquals(millisFormatter.parseMillis("2010-04-13T15:39:24.399+00:00"), 1271144364399L); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/date/UnixTimeStampTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/date/UnixTimeStampTest.java index 85aaa76d7..183f34abc 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/date/UnixTimeStampTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/date/UnixTimeStampTest.java @@ -23,46 +23,46 @@ import static org.testng.Assert.assertNull; import java.util.Date; + import org.apache.geaflow.dsl.udf.table.date.UnixTimeStamp; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; public class UnixTimeStampTest { - UnixTimeStamp udf; - - @BeforeMethod - public void setUp() throws Exception { - udf = new UnixTimeStamp(); - } + UnixTimeStamp udf; - @Test - public void test() { + @BeforeMethod + public void setUp() throws Exception { + udf = new UnixTimeStamp(); + } - assertEquals(udf.eval("1993", "yyyy"), new Long(725817600)); + @Test + public void test() { - assertEquals(udf.eval("1993-12", "yyyy-MM"), new Long(754675200)); + assertEquals(udf.eval("1993", "yyyy"), new Long(725817600)); - assertEquals(udf.eval("1993-12-01", "yyyy-MM-dd"), new Long(754675200)); + assertEquals(udf.eval("1993-12", "yyyy-MM"), new Long(754675200)); - assertEquals(udf.eval("1993-12-01 12", "yyyy-MM-dd HH"), new Long(754718400)); + assertEquals(udf.eval("1993-12-01", "yyyy-MM-dd"), new Long(754675200)); - assertEquals(udf.eval("1993-12-01 12:03", "yyyy-MM-dd HH:mm"), new Long(754718580)); + assertEquals(udf.eval("1993-12-01 12", "yyyy-MM-dd HH"), new Long(754718400)); - assertEquals(udf.eval("1993-12-01 12:03:01"), new Long(754718581)); + assertEquals(udf.eval("1993-12-01 12:03", "yyyy-MM-dd HH:mm"), new Long(754718580)); - assertEquals(udf.eval("1993-12-01 12:03:01", "yyyy-MM-dd HH:mm:ss"), new Long(754718581)); + assertEquals(udf.eval("1993-12-01 12:03:01"), new Long(754718581)); - assertEquals(udf.eval("1993-12-01 12:03:01.111000"), new Long(754718581)); + assertEquals(udf.eval("1993-12-01 12:03:01", "yyyy-MM-dd HH:mm:ss"), new Long(754718581)); - assertEquals(udf.eval("1993-12-01 12:03:01,111", "yyyy-MM-dd HH:mm:ss,SSS"), - new Long(754718581)); + assertEquals(udf.eval("1993-12-01 12:03:01.111000"), new Long(754718581)); - assertEquals(udf.eval(null), new Long(new Date().getTime() / 1000)); + assertEquals( + udf.eval("1993-12-01 12:03:01,111", "yyyy-MM-dd HH:mm:ss,SSS"), new Long(754718581)); - assertNull(udf.eval(null, null)); + assertEquals(udf.eval(null), new Long(new Date().getTime() / 1000)); - assertEquals(udf.eval("1993-12-01 00:00:00"), new Long(754675200)); - } + assertNull(udf.eval(null, null)); + assertEquals(udf.eval("1993-12-01 00:00:00"), new Long(754675200)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/RegExpExtractTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/RegExpExtractTest.java index 5676216f4..197010063 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/RegExpExtractTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/RegExpExtractTest.java @@ -24,21 +24,21 @@ public class RegExpExtractTest { - @Test - public void test() throws Exception { - RegExpExtract udf = new RegExpExtract(); - udf.eval("252 - (_4Ped87iivN", ".*((.*))", 1); - udf.eval("100-200", "(\\d+)-(\\d+)", 1); - udf.eval("100-200", "(\\d+)-(\\d+)"); - udf.eval("100-200", "(\\d+)-(\\d+)", 1L); - udf.eval("100-200", "(\\d+)-(\\d+)", "1"); - udf.eval("100-200", "", "1"); - udf.eval("100-200", "", 1); - udf.eval("100-200", null); - udf.eval("100-200", "(\\d+)-(\\d+)", 1L); - udf.eval("100-200", "", 1L); - udf.eval("100-200", null, 1L); - udf.eval("100-200", "(\\d+)-(\\d+)", "-1"); - udf.eval("100-200", "(\\d+)-(\\d+)", -1); - } + @Test + public void test() throws Exception { + RegExpExtract udf = new RegExpExtract(); + udf.eval("252 - (_4Ped87iivN", ".*((.*))", 1); + udf.eval("100-200", "(\\d+)-(\\d+)", 1); + udf.eval("100-200", "(\\d+)-(\\d+)"); + udf.eval("100-200", "(\\d+)-(\\d+)", 1L); + udf.eval("100-200", "(\\d+)-(\\d+)", "1"); + udf.eval("100-200", "", "1"); + udf.eval("100-200", "", 1); + udf.eval("100-200", null); + udf.eval("100-200", "(\\d+)-(\\d+)", 1L); + udf.eval("100-200", "", 1L); + udf.eval("100-200", null, 1L); + udf.eval("100-200", "(\\d+)-(\\d+)", "-1"); + udf.eval("100-200", "(\\d+)-(\\d+)", -1); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/RegexpCountTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/RegexpCountTest.java index d47b8d26d..74cef3a43 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/RegexpCountTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/RegexpCountTest.java @@ -25,9 +25,9 @@ public class RegexpCountTest { - @Test - public void test() { - RegexpCount regexpCount = new RegexpCount(); - Assert.assertEquals(regexpCount.eval("ab1d2d3dsss", "[0-9]d").longValue(), 3); - } + @Test + public void test() { + RegexpCount regexpCount = new RegexpCount(); + Assert.assertEquals(regexpCount.eval("ab1d2d3dsss", "[0-9]d").longValue(), 3); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/SubstrTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/SubstrTest.java index ec37db6c7..f824728c2 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/SubstrTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/SubstrTest.java @@ -26,21 +26,23 @@ public class SubstrTest { - @Test - public void test() { - Substr sb = new Substr(); + @Test + public void test() { + Substr sb = new Substr(); - Assert.assertEquals(sb.eval("01021000100000000017", 1, 1), "0"); - Assert.assertEquals(sb.eval("01021000100000000017", 2, 1), "1"); - Assert.assertNull(sb.eval("01021000100000000017", null, 1)); - Assert.assertEquals(sb.eval("Facebook", 5), "book"); - Assert.assertEquals(sb.eval("Facebook", -5), "ebook"); - Assert.assertEquals(sb.eval("Facebook", 5, 1), "b"); - Assert.assertEquals(sb.eval(BinaryString.fromString("01021000100000000017"), 1, 1).toString(), "0"); - Assert.assertEquals(sb.eval(BinaryString.fromString("01021000100000000017"), 2, 1).toString(), "1"); - Assert.assertNull(sb.eval(BinaryString.fromString("01021000100000000017").toString(), null, 1)); - Assert.assertEquals(sb.eval(BinaryString.fromString("Facebook"), 5).toString(), "book"); - Assert.assertEquals(sb.eval(BinaryString.fromString("Facebook"), -5).toString(), "ebook"); - Assert.assertEquals(sb.eval(BinaryString.fromString("Facebook"), 5, 1).toString(), "b"); - } + Assert.assertEquals(sb.eval("01021000100000000017", 1, 1), "0"); + Assert.assertEquals(sb.eval("01021000100000000017", 2, 1), "1"); + Assert.assertNull(sb.eval("01021000100000000017", null, 1)); + Assert.assertEquals(sb.eval("Facebook", 5), "book"); + Assert.assertEquals(sb.eval("Facebook", -5), "ebook"); + Assert.assertEquals(sb.eval("Facebook", 5, 1), "b"); + Assert.assertEquals( + sb.eval(BinaryString.fromString("01021000100000000017"), 1, 1).toString(), "0"); + Assert.assertEquals( + sb.eval(BinaryString.fromString("01021000100000000017"), 2, 1).toString(), "1"); + Assert.assertNull(sb.eval(BinaryString.fromString("01021000100000000017").toString(), null, 1)); + Assert.assertEquals(sb.eval(BinaryString.fromString("Facebook"), 5).toString(), "book"); + Assert.assertEquals(sb.eval(BinaryString.fromString("Facebook"), -5).toString(), "ebook"); + Assert.assertEquals(sb.eval(BinaryString.fromString("Facebook"), 5, 1).toString(), "b"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/TrimTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/TrimTest.java index e05e7f5b9..95d87d81b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/TrimTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/TrimTest.java @@ -27,20 +27,24 @@ public class TrimTest { - @Test - public void testLTrim() { - LTrim lTrim = new LTrim(); - Assert.assertEquals(lTrim.eval(BinaryString.fromString(" abc")), BinaryString.fromString("abc")); - Assert.assertEquals(lTrim.eval(BinaryString.fromString("abc")), BinaryString.fromString("abc")); - Assert.assertEquals(lTrim.eval(BinaryString.fromString(" abc")), BinaryString.fromString("abc")); - Assert.assertEquals(lTrim.eval(BinaryString.fromString(" ")), BinaryString.fromString("")); - } + @Test + public void testLTrim() { + LTrim lTrim = new LTrim(); + Assert.assertEquals( + lTrim.eval(BinaryString.fromString(" abc")), BinaryString.fromString("abc")); + Assert.assertEquals(lTrim.eval(BinaryString.fromString("abc")), BinaryString.fromString("abc")); + Assert.assertEquals( + lTrim.eval(BinaryString.fromString(" abc")), BinaryString.fromString("abc")); + Assert.assertEquals(lTrim.eval(BinaryString.fromString(" ")), BinaryString.fromString("")); + } - @Test - public void testRLTrim() { - RTrim rTrim = new RTrim(); - Assert.assertEquals(rTrim.eval(BinaryString.fromString(" abc ")), BinaryString.fromString(" abc")); - Assert.assertEquals(rTrim.eval(BinaryString.fromString("abc ")), BinaryString.fromString("abc")); - Assert.assertEquals(rTrim.eval(BinaryString.fromString(" ")), BinaryString.fromString("")); - } + @Test + public void testRLTrim() { + RTrim rTrim = new RTrim(); + Assert.assertEquals( + rTrim.eval(BinaryString.fromString(" abc ")), BinaryString.fromString(" abc")); + Assert.assertEquals( + rTrim.eval(BinaryString.fromString("abc ")), BinaryString.fromString("abc")); + Assert.assertEquals(rTrim.eval(BinaryString.fromString(" ")), BinaryString.fromString("")); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFBase64DecodeTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFBase64DecodeTest.java index d61bee19b..ccbe85a66 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFBase64DecodeTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFBase64DecodeTest.java @@ -26,9 +26,9 @@ public class UDFBase64DecodeTest { - @Test - public void test() { - Base64Decode u = new Base64Decode(); - assertEquals("abc ", u.eval("YWJjIA==")); - } + @Test + public void test() { + Base64Decode u = new Base64Decode(); + assertEquals("abc ", u.eval("YWJjIA==")); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFBase64EncodeTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFBase64EncodeTest.java index 12409ea0b..a03cd8b2a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFBase64EncodeTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFBase64EncodeTest.java @@ -26,12 +26,10 @@ public class UDFBase64EncodeTest { - @Test - public void test() { - Base64Encode u = new Base64Encode(); + @Test + public void test() { + Base64Encode u = new Base64Encode(); - assertEquals("YWJjIA==", u.eval("abc ")); - - - } + assertEquals("YWJjIA==", u.eval("abc ")); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFGetJsonObjectTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFGetJsonObjectTest.java index 3ba3554c2..1b3743d41 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFGetJsonObjectTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFGetJsonObjectTest.java @@ -26,19 +26,20 @@ public class UDFGetJsonObjectTest { - @Test - public void test() { - - GetJsonObject udf = new GetJsonObject(); - - String jsonStr = "{\"name\": \"Bob\", \"age\": 30, \"address\": " + - "{\"city\": \"Los Angeles\", \"zip\": \"90001\"},\"items\": [\"item1\", \"item2\", \"item3\"]}"; - - assertEquals("Bob", udf.eval(jsonStr, "name")); - assertEquals("30", udf.eval(jsonStr, "$.age")); - assertEquals("Los Angeles", udf.eval(jsonStr, "$.address.city")); - assertEquals("item3", udf.eval(jsonStr, "$.items[2]")); - assertEquals(null, udf.eval(jsonStr, "gender")); - assertEquals(null, udf.eval(jsonStr, "$.items[3]")); - } + @Test + public void test() { + + GetJsonObject udf = new GetJsonObject(); + + String jsonStr = + "{\"name\": \"Bob\", \"age\": 30, \"address\": {\"city\": \"Los Angeles\", \"zip\":" + + " \"90001\"},\"items\": [\"item1\", \"item2\", \"item3\"]}"; + + assertEquals("Bob", udf.eval(jsonStr, "name")); + assertEquals("30", udf.eval(jsonStr, "$.age")); + assertEquals("Los Angeles", udf.eval(jsonStr, "$.address.city")); + assertEquals("item3", udf.eval(jsonStr, "$.items[2]")); + assertEquals(null, udf.eval(jsonStr, "gender")); + assertEquals(null, udf.eval(jsonStr, "$.items[3]")); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFHashTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFHashTest.java index 627a58a73..0cbc86f6e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFHashTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFHashTest.java @@ -26,10 +26,9 @@ public class UDFHashTest { - @Test - public void test() { - Hash u = new Hash(); - assertEquals((int) u.eval("1"), 49); - } - + @Test + public void test() { + Hash u = new Hash(); + assertEquals((int) u.eval("1"), 49); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFInstrTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFInstrTest.java index b8712a604..c548566ad 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFInstrTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFInstrTest.java @@ -27,53 +27,69 @@ public class UDFInstrTest { - @Test - public void test() { + @Test + public void test() { - Instr udf = new Instr(); + Instr udf = new Instr(); - assertEquals(1L, (long) udf.eval("abc", "a")); + assertEquals(1L, (long) udf.eval("abc", "a")); - assertEquals(3L, (long) udf.eval("abc", "c")); + assertEquals(3L, (long) udf.eval("abc", "c")); - assertEquals(0L, (long) udf.eval("abc", "d")); + assertEquals(0L, (long) udf.eval("abc", "d")); - assertEquals(3L, (long) udf.eval("abc", "c", 1L)); + assertEquals(3L, (long) udf.eval("abc", "c", 1L)); - assertEquals(6L, (long) udf.eval("abcabc", "c", 4L)); + assertEquals(6L, (long) udf.eval("abcabc", "c", 4L)); - assertEquals(2L, (long) udf.eval("a\u0002abc\u0002c", "\u0002", 2L)); + assertEquals(2L, (long) udf.eval("a\u0002abc\u0002c", "\u0002", 2L)); - assertEquals(2L, (long) udf.eval("a\002b\002c", "\002")); + assertEquals(2L, (long) udf.eval("a\002b\002c", "\002")); - assertEquals(9L, (long) udf.eval("s.taobao.com", ".", 3L)); - assertEquals(2, (long) udf.eval("s.taobao.com", ".", 1L)); + assertEquals(9L, (long) udf.eval("s.taobao.com", ".", 3L)); + assertEquals(2, (long) udf.eval("s.taobao.com", ".", 1L)); - assertEquals(0, (long) udf.eval("s.taobao.com", "abc")); - } + assertEquals(0, (long) udf.eval("s.taobao.com", "abc")); + } - @Test - public void testBinaryString() { + @Test + public void testBinaryString() { - Instr udf = new Instr(); + Instr udf = new Instr(); - assertEquals(1L, (long) udf.eval(BinaryString.fromString("abc"), BinaryString.fromString("a"))); + assertEquals(1L, (long) udf.eval(BinaryString.fromString("abc"), BinaryString.fromString("a"))); - assertEquals(3L, (long) udf.eval(BinaryString.fromString("abc"), BinaryString.fromString("c"))); + assertEquals(3L, (long) udf.eval(BinaryString.fromString("abc"), BinaryString.fromString("c"))); - assertEquals(0L, (long) udf.eval(BinaryString.fromString("abc"), BinaryString.fromString("d"))); + assertEquals(0L, (long) udf.eval(BinaryString.fromString("abc"), BinaryString.fromString("d"))); - assertEquals(3L, (long) udf.eval(BinaryString.fromString("abc"), BinaryString.fromString("c"), 1L)); + assertEquals( + 3L, (long) udf.eval(BinaryString.fromString("abc"), BinaryString.fromString("c"), 1L)); - assertEquals(6L, (long) udf.eval(BinaryString.fromString("abcabc"), BinaryString.fromString("c"), 4L)); + assertEquals( + 6L, (long) udf.eval(BinaryString.fromString("abcabc"), BinaryString.fromString("c"), 4L)); - assertEquals(2L, (long) udf.eval(BinaryString.fromString("a\u0002abc\u0002c"), BinaryString.fromString("\u0002"), 2L)); + assertEquals( + 2L, + (long) + udf.eval( + BinaryString.fromString("a\u0002abc\u0002c"), + BinaryString.fromString("\u0002"), + 2L)); - assertEquals(2L, (long) udf.eval(BinaryString.fromString("a\002b\002c"), BinaryString.fromString("\002"))); + assertEquals( + 2L, + (long) udf.eval(BinaryString.fromString("a\002b\002c"), BinaryString.fromString("\002"))); - assertEquals(9L, (long) udf.eval(BinaryString.fromString("s.taobao.com"), BinaryString.fromString("."), 3L)); - assertEquals(2, (long) udf.eval(BinaryString.fromString("s.taobao.com"), BinaryString.fromString("."), 1L)); + assertEquals( + 9L, + (long) udf.eval(BinaryString.fromString("s.taobao.com"), BinaryString.fromString("."), 3L)); + assertEquals( + 2, + (long) udf.eval(BinaryString.fromString("s.taobao.com"), BinaryString.fromString("."), 1L)); - assertEquals(0, (long) udf.eval(BinaryString.fromString("s.taobao.com"), BinaryString.fromString("abc"))); - } + assertEquals( + 0, + (long) udf.eval(BinaryString.fromString("s.taobao.com"), BinaryString.fromString("abc"))); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFKeyValueTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFKeyValueTest.java index c06961c0c..abc3d7792 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFKeyValueTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFKeyValueTest.java @@ -24,14 +24,11 @@ public class UDFKeyValueTest { - @Test - public void test() { + @Test + public void test() { - KeyValue kv = new KeyValue(); + KeyValue kv = new KeyValue(); - System.out.println(kv.eval( - "xxx^", "^", "=", "ip" - )); - - } + System.out.println(kv.eval("xxx^", "^", "=", "ip")); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFRegexpReplaceTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFRegexpReplaceTest.java index e257fd0f0..727eedadf 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFRegexpReplaceTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFRegexpReplaceTest.java @@ -26,20 +26,19 @@ public class UDFRegexpReplaceTest { - @Test - public void test() { - RegExpReplace udf = new RegExpReplace(); - String res = udf.eval("100-200", "(\\d+)-(\\d+)", "1"); - assertEquals(res, "1"); - assertEquals(udf.eval("(adfafa", "\\(", ""), "adfafa"); + @Test + public void test() { + RegExpReplace udf = new RegExpReplace(); + String res = udf.eval("100-200", "(\\d+)-(\\d+)", "1"); + assertEquals(res, "1"); + assertEquals(udf.eval("(adfafa", "\\(", ""), "adfafa"); - assertEquals(udf.eval("(adfafa", "\\(", ""), "adfafa"); + assertEquals(udf.eval("(adfafa", "\\(", ""), "adfafa"); - assertEquals(udf.eval("adf\"afa", "\"", ""), "adfafa"); + assertEquals(udf.eval("adf\"afa", "\"", ""), "adfafa"); - assertEquals(udf.eval("adf\"afa", "\"", ""), "adfafa"); + assertEquals(udf.eval("adf\"afa", "\"", ""), "adfafa"); - assertEquals(udf.eval("adfabadfasdf", "[a]", "3"), "3df3b3df3sdf"); - - } + assertEquals(udf.eval("adfabadfasdf", "[a]", "3"), "3df3b3df3sdf"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFStringTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFStringTest.java index 6f267213a..b429c55a9 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFStringTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFStringTest.java @@ -52,263 +52,269 @@ public class UDFStringTest { - @Test - public void testAscii2String() { - Ascii2String test = new Ascii2String(); - test.open(null); - assertEquals(test.eval(48), "0"); - assertEquals(test.eval(48L), "0"); - assertNull(test.eval((Long) null)); - assertNull(test.eval((Integer) null)); - } + @Test + public void testAscii2String() { + Ascii2String test = new Ascii2String(); + test.open(null); + assertEquals(test.eval(48), "0"); + assertEquals(test.eval(48L), "0"); + assertNull(test.eval((Long) null)); + assertNull(test.eval((Integer) null)); + } - @Test - public void testBase64() { - Base64Encode encode = new Base64Encode(); - encode.open(null); - Base64Decode decode = new Base64Decode(); - decode.open(null); - String test = "ant group"; - assertEquals(decode.eval(encode.eval(test)), test); - } + @Test + public void testBase64() { + Base64Encode encode = new Base64Encode(); + encode.open(null); + Base64Decode decode = new Base64Decode(); + decode.open(null); + String test = "ant group"; + assertEquals(decode.eval(encode.eval(test)), test); + } - @Test - public void testConcat() { - String string = "ant group"; - BinaryString binaryString = BinaryString.fromString(string); - Concat test = new Concat(); - test.open(null); - assertEquals(test.eval(string, string, string), "ant groupant groupant group"); - assertEquals(test.eval(string, null, string), "ant groupant group"); - assertEquals(test.eval(string, string, null), "ant groupant group"); - assertEquals(test.eval((String) null, null, null), ""); + @Test + public void testConcat() { + String string = "ant group"; + BinaryString binaryString = BinaryString.fromString(string); + Concat test = new Concat(); + test.open(null); + assertEquals(test.eval(string, string, string), "ant groupant groupant group"); + assertEquals(test.eval(string, null, string), "ant groupant group"); + assertEquals(test.eval(string, string, null), "ant groupant group"); + assertEquals(test.eval((String) null, null, null), ""); - assertEquals(test.eval(binaryString, binaryString, binaryString), - BinaryString.fromString("ant groupant groupant group")); - assertEquals(test.eval(binaryString, null, binaryString), BinaryString.fromString( - "ant groupant group")); - assertEquals(test.eval(binaryString, binaryString, null), BinaryString.fromString( - "ant groupant group")); - assertEquals(test.eval(BinaryString.fromString("蚂蚁1"), BinaryString.fromString("蚂蚁2"), - BinaryString.fromString("蚂蚁3")), BinaryString.fromString("蚂蚁1蚂蚁2蚂蚁3")); - assertEquals(test.eval((BinaryString) null, null, null), BinaryString.fromString("")); - } + assertEquals( + test.eval(binaryString, binaryString, binaryString), + BinaryString.fromString("ant groupant groupant group")); + assertEquals( + test.eval(binaryString, null, binaryString), BinaryString.fromString("ant groupant group")); + assertEquals( + test.eval(binaryString, binaryString, null), BinaryString.fromString("ant groupant group")); + assertEquals( + test.eval( + BinaryString.fromString("蚂蚁1"), + BinaryString.fromString("蚂蚁2"), + BinaryString.fromString("蚂蚁3")), + BinaryString.fromString("蚂蚁1蚂蚁2蚂蚁3")); + assertEquals(test.eval((BinaryString) null, null, null), BinaryString.fromString("")); + } - @Test - public void testConcatWS() { - String string = "ant group"; - ConcatWS test = new ConcatWS(); - BinaryString binaryString = BinaryString.fromString(string); - test.open(null); - assertEquals(test.eval(",", string, string, string), - "ant group,ant group,ant group"); - assertEquals(test.eval("-", string, string, string), - "ant group-ant group-ant group"); - assertEquals(test.eval("***", string, string, string), - "ant group***ant group***ant group"); - assertEquals(test.eval(",", "1", string, "23"), - "1,ant group,23"); - assertEquals(test.eval(",", string, null, string), - "ant group,,ant group"); - assertEquals(test.eval(",", (String) null, null, null), - ",,"); - assertEquals(test.eval(null, string, string, string), - "ant groupant groupant group"); + @Test + public void testConcatWS() { + String string = "ant group"; + ConcatWS test = new ConcatWS(); + BinaryString binaryString = BinaryString.fromString(string); + test.open(null); + assertEquals(test.eval(",", string, string, string), "ant group,ant group,ant group"); + assertEquals(test.eval("-", string, string, string), "ant group-ant group-ant group"); + assertEquals(test.eval("***", string, string, string), "ant group***ant group***ant group"); + assertEquals(test.eval(",", "1", string, "23"), "1,ant group,23"); + assertEquals(test.eval(",", string, null, string), "ant group,,ant group"); + assertEquals(test.eval(",", (String) null, null, null), ",,"); + assertEquals(test.eval(null, string, string, string), "ant groupant groupant group"); - assertEquals( - test.eval(",", binaryString, binaryString, binaryString), - BinaryString.fromString("ant group,ant group,ant group")); - assertEquals( - test.eval("-", binaryString, binaryString, binaryString), - BinaryString.fromString("ant group-ant group-ant group")); - assertEquals( - test.eval("***", binaryString, binaryString, binaryString), - BinaryString.fromString("ant group***ant group***ant group")); - assertEquals(test.eval(",", BinaryString.fromString("1"), - binaryString, BinaryString.fromString("23")), - BinaryString.fromString("1,ant group,23")); - assertEquals(test.eval(",", binaryString, null, binaryString), - BinaryString.fromString("ant group,,ant group")); - assertEquals(test.eval(",", (BinaryString) null, null, null), - BinaryString.fromString(",,")); - assertEquals(test.eval((String) null, binaryString, binaryString, binaryString), - BinaryString.fromString("ant groupant groupant group")); - assertEquals(test.eval((BinaryString) null, binaryString, binaryString, binaryString), - BinaryString.fromString("ant groupant groupant group")); - assertEquals(test.eval(",", BinaryString.fromString("蚂蚁1"), + assertEquals( + test.eval(",", binaryString, binaryString, binaryString), + BinaryString.fromString("ant group,ant group,ant group")); + assertEquals( + test.eval("-", binaryString, binaryString, binaryString), + BinaryString.fromString("ant group-ant group-ant group")); + assertEquals( + test.eval("***", binaryString, binaryString, binaryString), + BinaryString.fromString("ant group***ant group***ant group")); + assertEquals( + test.eval(",", BinaryString.fromString("1"), binaryString, BinaryString.fromString("23")), + BinaryString.fromString("1,ant group,23")); + assertEquals( + test.eval(",", binaryString, null, binaryString), + BinaryString.fromString("ant group,,ant group")); + assertEquals(test.eval(",", (BinaryString) null, null, null), BinaryString.fromString(",,")); + assertEquals( + test.eval((String) null, binaryString, binaryString, binaryString), + BinaryString.fromString("ant groupant groupant group")); + assertEquals( + test.eval((BinaryString) null, binaryString, binaryString, binaryString), + BinaryString.fromString("ant groupant groupant group")); + assertEquals( + test.eval( + ",", + BinaryString.fromString("蚂蚁1"), BinaryString.fromString("蚂蚁2"), - BinaryString.fromString("蚂蚁3")), BinaryString.fromString("蚂蚁1,蚂蚁2,蚂蚁3")); - } + BinaryString.fromString("蚂蚁3")), + BinaryString.fromString("蚂蚁1,蚂蚁2,蚂蚁3")); + } - @Test - public void testHash() { - String string = "ant group"; - Hash test = new Hash(); - test.open(null); - assertNull(test.eval((String) null)); - assertNull(test.eval((Integer) null)); - } + @Test + public void testHash() { + String string = "ant group"; + Hash test = new Hash(); + test.open(null); + assertNull(test.eval((String) null)); + assertNull(test.eval((Integer) null)); + } - @Test - public void testIndexOf() { - String string = "ant group"; - BinaryString binaryString = BinaryString.fromString("ant group"); - IndexOf test = new IndexOf(); - test.open(null); - assertEquals((int) test.eval(string, "ant"), 0); - assertEquals((int) test.eval(string, "group", 3), 4); - assertEquals((int) test.eval(null, "ant"), -1); - assertEquals((int) test.eval(string, "group", -1), 4); + @Test + public void testIndexOf() { + String string = "ant group"; + BinaryString binaryString = BinaryString.fromString("ant group"); + IndexOf test = new IndexOf(); + test.open(null); + assertEquals((int) test.eval(string, "ant"), 0); + assertEquals((int) test.eval(string, "group", 3), 4); + assertEquals((int) test.eval(null, "ant"), -1); + assertEquals((int) test.eval(string, "group", -1), 4); - assertEquals((int) test.eval(binaryString, BinaryString.fromString("ant")), 0); - assertEquals((int) test.eval(binaryString, BinaryString.fromString("group"), 3), 4); - assertEquals((int) test.eval(null, BinaryString.fromString("ant")), -1); - assertEquals((int) test.eval(binaryString, BinaryString.fromString("group"), -1), 4); + assertEquals((int) test.eval(binaryString, BinaryString.fromString("ant")), 0); + assertEquals((int) test.eval(binaryString, BinaryString.fromString("group"), 3), 4); + assertEquals((int) test.eval(null, BinaryString.fromString("ant")), -1); + assertEquals((int) test.eval(binaryString, BinaryString.fromString("group"), -1), 4); - assertEquals( - (int) test.eval(BinaryString.fromString("数据砖头"), BinaryString.fromString("砖"), -1), - 2); - } + assertEquals( + (int) test.eval(BinaryString.fromString("数据砖头"), BinaryString.fromString("砖"), -1), 2); + } - @Test - public void testInstr() { - String string = "ant group"; - Instr test = new Instr(); - test.open(null); - assertEquals((long) test.eval(string, "group"), 5); - assertEquals((long) test.eval(string, "group", 3L), 5); - assertNull(test.eval(string, null, 3L)); - assertNull(test.eval(string, "group", -1L)); - assertNull(test.eval(string, "group", 3L, -1L)); - } + @Test + public void testInstr() { + String string = "ant group"; + Instr test = new Instr(); + test.open(null); + assertEquals((long) test.eval(string, "group"), 5); + assertEquals((long) test.eval(string, "group", 3L), 5); + assertNull(test.eval(string, null, 3L)); + assertNull(test.eval(string, "group", -1L)); + assertNull(test.eval(string, "group", 3L, -1L)); + } - @Test - public void testIsBlank() { - String string = "ant group"; - IsBlank test = new IsBlank(); - test.open(null); - assertEquals((boolean) test.eval(string), false); - } + @Test + public void testIsBlank() { + String string = "ant group"; + IsBlank test = new IsBlank(); + test.open(null); + assertEquals((boolean) test.eval(string), false); + } - @Test - public void testKeyValue() { - KeyValue test = new KeyValue(); - test.open(null); - assertEquals(test.eval("key1:value1 key2:value2", " ", ":", "key1"), "value1"); - assertNull(test.eval((Object) null, " ", ":", "key")); - } + @Test + public void testKeyValue() { + KeyValue test = new KeyValue(); + test.open(null); + assertEquals(test.eval("key1:value1 key2:value2", " ", ":", "key1"), "value1"); + assertNull(test.eval((Object) null, " ", ":", "key")); + } - @Test - public void testLength() { - Length test = new Length(); - test.open(null); - assertEquals((long) test.eval(BinaryString.fromString("test")), 4); - } + @Test + public void testLength() { + Length test = new Length(); + test.open(null); + assertEquals((long) test.eval(BinaryString.fromString("test")), 4); + } - @Test - public void testLike() { - Like test = new Like(); - test.open(null); - assertTrue(test.eval("abc", "%abc")); - assertTrue(test.eval("abc", "abc%")); - assertTrue(test.eval("abc", "a%bc")); - assertFalse(test.eval("test", "abc\\%")); - assertFalse(test.eval("test", "abc\\%de%")); - assertFalse(test.eval("test", "abc\\%de%")); - assertFalse(test.eval("atest", "a%bc")); - } + @Test + public void testLike() { + Like test = new Like(); + test.open(null); + assertTrue(test.eval("abc", "%abc")); + assertTrue(test.eval("abc", "abc%")); + assertTrue(test.eval("abc", "a%bc")); + assertFalse(test.eval("test", "abc\\%")); + assertFalse(test.eval("test", "abc\\%de%")); + assertFalse(test.eval("test", "abc\\%de%")); + assertFalse(test.eval("atest", "a%bc")); + } - @Test - public void testLTrim() { - LTrim test = new LTrim(); - test.open(null); - assertEquals(test.eval(" facebook "), "facebook "); - assertNull(test.eval((String) null)); - } + @Test + public void testLTrim() { + LTrim test = new LTrim(); + test.open(null); + assertEquals(test.eval(" facebook "), "facebook "); + assertNull(test.eval((String) null)); + } - @Test - public void testRegExp() { - RegExp test = new RegExp(); - test.open(null); - assertTrue(test.eval("a.b.c.d.e.f", ".")); - assertNull(test.eval("a.b.c.d.e.f", null)); - assertFalse(test.eval("a.b.c.d.e.f", "")); - } + @Test + public void testRegExp() { + RegExp test = new RegExp(); + test.open(null); + assertTrue(test.eval("a.b.c.d.e.f", ".")); + assertNull(test.eval("a.b.c.d.e.f", null)); + assertFalse(test.eval("a.b.c.d.e.f", "")); + } - @Test - public void testRegExpReplace() { - RegExpReplace test = new RegExpReplace(); - test.open(null); - assertEquals(test.eval("100-200", "(\\d+)", "num"), "num-num"); - assertNull(test.eval(null, "(\\d+)", "num")); - } + @Test + public void testRegExpReplace() { + RegExpReplace test = new RegExpReplace(); + test.open(null); + assertEquals(test.eval("100-200", "(\\d+)", "num"), "num-num"); + assertNull(test.eval(null, "(\\d+)", "num")); + } - @Test - public void testRepeat() { - Repeat test = new Repeat(); - test.open(null); - assertEquals(test.eval("AntGroup", 3), "AntGroupAntGroupAntGroup"); - assertNull(test.eval(null, 3)); - } + @Test + public void testRepeat() { + Repeat test = new Repeat(); + test.open(null); + assertEquals(test.eval("AntGroup", 3), "AntGroupAntGroupAntGroup"); + assertNull(test.eval(null, 3)); + } - @Test - public void testReplace() { - Replace test = new Replace(); - test.open(null); - assertEquals(test.eval("AntGroup", "Ant", "ant"), "antGroup"); - assertEquals(test.eval((Object) null, "Ant", "ant"), "null"); - } + @Test + public void testReplace() { + Replace test = new Replace(); + test.open(null); + assertEquals(test.eval("AntGroup", "Ant", "ant"), "antGroup"); + assertEquals(test.eval((Object) null, "Ant", "ant"), "null"); + } - @Test - public void testReverse() { - Reverse test = new Reverse(); - test.open(null); - assertEquals(test.eval("AntGroup"), "puorGtnA"); - assertEquals(test.eval(BinaryString.fromString("AntGroup")), BinaryString.fromString("puorGtnA")); - assertNull(test.eval((String) null)); - assertNull(test.eval((BinaryString) null)); - } + @Test + public void testReverse() { + Reverse test = new Reverse(); + test.open(null); + assertEquals(test.eval("AntGroup"), "puorGtnA"); + assertEquals( + test.eval(BinaryString.fromString("AntGroup")), BinaryString.fromString("puorGtnA")); + assertNull(test.eval((String) null)); + assertNull(test.eval((BinaryString) null)); + } - @Test - public void testRTrim() { - RTrim test = new RTrim(); - test.open(null); - assertEquals(test.eval(" AntGroup "), " AntGroup"); - assertNull(test.eval((String) null)); - } + @Test + public void testRTrim() { + RTrim test = new RTrim(); + test.open(null); + assertEquals(test.eval(" AntGroup "), " AntGroup"); + assertNull(test.eval((String) null)); + } - @Test - public void testSpace() { - Space test = new Space(); - test.open(null); - assertEquals(test.eval(6L), " "); - assertNull(test.eval(null)); - } + @Test + public void testSpace() { + Space test = new Space(); + test.open(null); + assertEquals(test.eval(6L), " "); + assertNull(test.eval(null)); + } - @Test - public void testSplitEx() { - SplitEx test = new SplitEx(); - test.open(null); - assertEquals(test.eval("a.b.c.d.e", ".", 1), "b"); - assertNull(test.eval(null, ".", 1)); - assertNull(test.eval("a.b.c.d.e", ".", -1)); - assertNull(test.eval("a.b.c.d.e", ".", 5)); - assertEquals(test.eval(BinaryString.fromString("a.b.c.d.e"), BinaryString.fromString("."), 1), BinaryString.fromString("b")); - assertNull(test.eval(null, BinaryString.fromString("."), 1)); - assertNull(test.eval(null, BinaryString.fromString(""), 1)); - assertNull(test.eval(BinaryString.fromString("a.b.c.d.e"), BinaryString.fromString("."), -1)); - assertNull(test.eval(BinaryString.fromString("a.b.c.d.e"), BinaryString.fromString("."), 5)); - } + @Test + public void testSplitEx() { + SplitEx test = new SplitEx(); + test.open(null); + assertEquals(test.eval("a.b.c.d.e", ".", 1), "b"); + assertNull(test.eval(null, ".", 1)); + assertNull(test.eval("a.b.c.d.e", ".", -1)); + assertNull(test.eval("a.b.c.d.e", ".", 5)); + assertEquals( + test.eval(BinaryString.fromString("a.b.c.d.e"), BinaryString.fromString("."), 1), + BinaryString.fromString("b")); + assertNull(test.eval(null, BinaryString.fromString("."), 1)); + assertNull(test.eval(null, BinaryString.fromString(""), 1)); + assertNull(test.eval(BinaryString.fromString("a.b.c.d.e"), BinaryString.fromString("."), -1)); + assertNull(test.eval(BinaryString.fromString("a.b.c.d.e"), BinaryString.fromString("."), 5)); + } - @Test - public void testUrlDecode() { - UrlEncode encode = new UrlEncode(); - encode.open(null); - UrlDecode decode = new UrlDecode(); - decode.open(null); - String test = "ant group"; - assertEquals(decode.eval(encode.eval(test)), test); - assertNull(encode.eval(null)); - assertNull(decode.eval(null)); - } + @Test + public void testUrlDecode() { + UrlEncode encode = new UrlEncode(); + encode.open(null); + UrlDecode decode = new UrlDecode(); + decode.open(null); + String test = "ant group"; + assertEquals(decode.eval(encode.eval(test)), test); + assertNull(encode.eval(null)); + assertNull(decode.eval(null)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFSubstrTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFSubstrTest.java index 3e6fdf237..e316624ae 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFSubstrTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/string/UDFSubstrTest.java @@ -24,12 +24,11 @@ public class UDFSubstrTest { - @Test - public void testUDFSubstr() { - Substr substr = new Substr(); - String time = String.valueOf(System.currentTimeMillis()); - System.out.println(time); - System.out.println(substr.eval(time, 0, 10)); - } - + @Test + public void testUDFSubstr() { + Substr substr = new Substr(); + String time = String.valueOf(System.currentTimeMillis()); + System.out.println(time); + System.out.println(substr.eval(time, 0, 10)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/table/other/PropertyExistsTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/table/other/PropertyExistsTest.java index 47bfa478c..508e99bd7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/table/other/PropertyExistsTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/table/other/PropertyExistsTest.java @@ -31,185 +31,189 @@ * Unit tests for PropertyExists ISO-GQL predicate function. * *

Tests validate: + * *

    - *
  • Three-valued logic (NULL handling)
  • - *
  • Type validation and error handling
  • - *
  • Property name validation
  • - *
  • ISO-GQL compliance
  • + *
  • Three-valued logic (NULL handling) + *
  • Type validation and error handling + *
  • Property name validation + *
  • ISO-GQL compliance *
*/ public class PropertyExistsTest { - @Test - public void testNullElement() { - PropertyExists func = new PropertyExists(); - - // ISO-GQL Rule: NULL element → Unknown (null) - // Test with null Object - Boolean result = func.eval((Object) null, "anyProperty"); - Assert.assertNull("NULL element should return NULL (Unknown in three-valued logic)", result); - - // Test with null RowVertex - result = func.eval((RowVertex) null, "anyProperty"); - Assert.assertNull("NULL vertex should return NULL", result); - - // Test with null RowEdge - result = func.eval((RowEdge) null, "anyProperty"); - Assert.assertNull("NULL edge should return NULL", result); - - // Test with null Row - result = func.eval((Row) null, "anyProperty"); - Assert.assertNull("NULL row should return NULL", result); + @Test + public void testNullElement() { + PropertyExists func = new PropertyExists(); + + // ISO-GQL Rule: NULL element → Unknown (null) + // Test with null Object + Boolean result = func.eval((Object) null, "anyProperty"); + Assert.assertNull("NULL element should return NULL (Unknown in three-valued logic)", result); + + // Test with null RowVertex + result = func.eval((RowVertex) null, "anyProperty"); + Assert.assertNull("NULL vertex should return NULL", result); + + // Test with null RowEdge + result = func.eval((RowEdge) null, "anyProperty"); + Assert.assertNull("NULL edge should return NULL", result); + + // Test with null Row + result = func.eval((Row) null, "anyProperty"); + Assert.assertNull("NULL row should return NULL", result); + } + + @Test + public void testNonNullVertex() { + PropertyExists func = new PropertyExists(); + + // Create a simple vertex (non-null) + RowVertex vertex = new LongVertex(1L); + + // In GeaFlow's implementation, property existence is validated at compile-time + // At runtime, non-null elements with valid property names return true + Boolean result = func.eval(vertex, "name"); + Assert.assertNotNull("Non-null vertex should not return NULL", result); + Assert.assertTrue("Non-null vertex with valid property name should return TRUE", result); + } + + @Test + public void testNonNullRow() { + PropertyExists func = new PropertyExists(); + + // Create a simple row (non-null) + Row row = ObjectRow.create(new Object[] {"value1", "value2"}); + + // Property existence validated at compile-time + Boolean result = func.eval(row, "field1"); + Assert.assertNotNull("Non-null row should not return NULL", result); + Assert.assertTrue("Non-null row with valid property name should return TRUE", result); + } + + @Test + public void testThreeValuedLogic() { + PropertyExists func = new PropertyExists(); + + // Test NULL case (Unknown in three-valued logic) + Boolean resultNull = func.eval((Object) null, "property"); + Assert.assertNull("Three-valued logic: NULL element → Unknown (null)", resultNull); + + // Test TRUE case (property exists - simplified as non-null element) + RowVertex vertex = new LongVertex(1L); + Boolean resultTrue = func.eval(vertex, "property"); + Assert.assertTrue("Three-valued logic: Non-null element → TRUE", resultTrue); + } + + @Test + public void testDescription() { + PropertyExists func = new PropertyExists(); + + // Verify the function has proper description annotation + Assert.assertNotNull("PropertyExists class should exist", func); + + // The @Description annotation should be present (checked by reflection if needed) + Assert.assertTrue( + "PropertyExists should be a UDF", + func.getClass().getSuperclass().getSimpleName().equals("UDF")); + } + + // ==================== Error Handling Tests ==================== + + @Test(expected = IllegalArgumentException.class) + public void testInvalidElementType() { + PropertyExists func = new PropertyExists(); + + // Test with invalid element type (String instead of graph element) + func.eval("not a graph element", "propertyName"); + // Should throw IllegalArgumentException + } + + @Test(expected = IllegalArgumentException.class) + public void testInvalidElementTypeInteger() { + PropertyExists func = new PropertyExists(); + + // Test with invalid element type (Integer) + func.eval(123, "propertyName"); + // Should throw IllegalArgumentException + } + + @Test + public void testInvalidElementTypeMessage() { + PropertyExists func = new PropertyExists(); + + try { + func.eval("invalid", "propertyName"); + Assert.fail("Should have thrown IllegalArgumentException for invalid element type"); + } catch (IllegalArgumentException e) { + // Verify error message contains useful information + Assert.assertTrue( + "Error message should mention graph element requirement", + e.getMessage().contains("graph element")); + Assert.assertTrue( + "Error message should include actual type", e.getMessage().contains("String")); } - - @Test - public void testNonNullVertex() { - PropertyExists func = new PropertyExists(); - - // Create a simple vertex (non-null) - RowVertex vertex = new LongVertex(1L); - - // In GeaFlow's implementation, property existence is validated at compile-time - // At runtime, non-null elements with valid property names return true - Boolean result = func.eval(vertex, "name"); - Assert.assertNotNull("Non-null vertex should not return NULL", result); - Assert.assertTrue("Non-null vertex with valid property name should return TRUE", result); + } + + @Test(expected = IllegalArgumentException.class) + public void testNullPropertyName() { + PropertyExists func = new PropertyExists(); + + // Test with null property name + RowVertex vertex = new LongVertex(1L); + func.eval(vertex, null); + // Should throw IllegalArgumentException + } + + @Test(expected = IllegalArgumentException.class) + public void testEmptyPropertyName() { + PropertyExists func = new PropertyExists(); + + // Test with empty property name + RowVertex vertex = new LongVertex(1L); + func.eval(vertex, ""); + // Should throw IllegalArgumentException + } + + @Test(expected = IllegalArgumentException.class) + public void testWhitespacePropertyName() { + PropertyExists func = new PropertyExists(); + + // Test with whitespace-only property name + RowVertex vertex = new LongVertex(1L); + func.eval(vertex, " "); + // Should throw IllegalArgumentException + } + + @Test + public void testInvalidPropertyNameMessage() { + PropertyExists func = new PropertyExists(); + RowVertex vertex = new LongVertex(1L); + + try { + func.eval(vertex, null); + Assert.fail("Should have thrown IllegalArgumentException for null property name"); + } catch (IllegalArgumentException e) { + // Verify error message is clear + Assert.assertTrue( + "Error message should mention property name requirement", + e.getMessage().contains("property name")); } + } - @Test - public void testNonNullRow() { - PropertyExists func = new PropertyExists(); - - // Create a simple row (non-null) - Row row = ObjectRow.create(new Object[]{"value1", "value2"}); - - // Property existence validated at compile-time - Boolean result = func.eval(row, "field1"); - Assert.assertNotNull("Non-null row should not return NULL", result); - Assert.assertTrue("Non-null row with valid property name should return TRUE", result); - } + @Test + public void testTypeSpecificOverloads() { + PropertyExists func = new PropertyExists(); - @Test - public void testThreeValuedLogic() { - PropertyExists func = new PropertyExists(); + // Test that type-specific overloads work correctly + RowVertex vertex = new LongVertex(1L); + Row row = ObjectRow.create(new Object[] {"value"}); - // Test NULL case (Unknown in three-valued logic) - Boolean resultNull = func.eval((Object) null, "property"); - Assert.assertNull("Three-valued logic: NULL element → Unknown (null)", resultNull); - - // Test TRUE case (property exists - simplified as non-null element) - RowVertex vertex = new LongVertex(1L); - Boolean resultTrue = func.eval(vertex, "property"); - Assert.assertTrue("Three-valued logic: Non-null element → TRUE", resultTrue); - } - - @Test - public void testDescription() { - PropertyExists func = new PropertyExists(); - - // Verify the function has proper description annotation - Assert.assertNotNull("PropertyExists class should exist", func); - - // The @Description annotation should be present (checked by reflection if needed) - Assert.assertTrue("PropertyExists should be a UDF", - func.getClass().getSuperclass().getSimpleName().equals("UDF")); - } - - // ==================== Error Handling Tests ==================== - - @Test(expected = IllegalArgumentException.class) - public void testInvalidElementType() { - PropertyExists func = new PropertyExists(); - - // Test with invalid element type (String instead of graph element) - func.eval("not a graph element", "propertyName"); - // Should throw IllegalArgumentException - } + // These should all work without ClassCastException + Boolean vertexResult = func.eval(vertex, "name"); + Boolean rowResult = func.eval(row, "field"); - @Test(expected = IllegalArgumentException.class) - public void testInvalidElementTypeInteger() { - PropertyExists func = new PropertyExists(); - - // Test with invalid element type (Integer) - func.eval(123, "propertyName"); - // Should throw IllegalArgumentException - } - - @Test - public void testInvalidElementTypeMessage() { - PropertyExists func = new PropertyExists(); - - try { - func.eval("invalid", "propertyName"); - Assert.fail("Should have thrown IllegalArgumentException for invalid element type"); - } catch (IllegalArgumentException e) { - // Verify error message contains useful information - Assert.assertTrue("Error message should mention graph element requirement", - e.getMessage().contains("graph element")); - Assert.assertTrue("Error message should include actual type", - e.getMessage().contains("String")); - } - } - - @Test(expected = IllegalArgumentException.class) - public void testNullPropertyName() { - PropertyExists func = new PropertyExists(); - - // Test with null property name - RowVertex vertex = new LongVertex(1L); - func.eval(vertex, null); - // Should throw IllegalArgumentException - } - - @Test(expected = IllegalArgumentException.class) - public void testEmptyPropertyName() { - PropertyExists func = new PropertyExists(); - - // Test with empty property name - RowVertex vertex = new LongVertex(1L); - func.eval(vertex, ""); - // Should throw IllegalArgumentException - } - - @Test(expected = IllegalArgumentException.class) - public void testWhitespacePropertyName() { - PropertyExists func = new PropertyExists(); - - // Test with whitespace-only property name - RowVertex vertex = new LongVertex(1L); - func.eval(vertex, " "); - // Should throw IllegalArgumentException - } - - @Test - public void testInvalidPropertyNameMessage() { - PropertyExists func = new PropertyExists(); - RowVertex vertex = new LongVertex(1L); - - try { - func.eval(vertex, null); - Assert.fail("Should have thrown IllegalArgumentException for null property name"); - } catch (IllegalArgumentException e) { - // Verify error message is clear - Assert.assertTrue("Error message should mention property name requirement", - e.getMessage().contains("property name")); - } - } - - @Test - public void testTypeSpecificOverloads() { - PropertyExists func = new PropertyExists(); - - // Test that type-specific overloads work correctly - RowVertex vertex = new LongVertex(1L); - Row row = ObjectRow.create(new Object[]{"value"}); - - // These should all work without ClassCastException - Boolean vertexResult = func.eval(vertex, "name"); - Boolean rowResult = func.eval(row, "field"); - - Assert.assertNotNull("Vertex overload should work", vertexResult); - Assert.assertNotNull("Row overload should work", rowResult); - Assert.assertTrue("Type-specific overloads should return TRUE", vertexResult && rowResult); - } + Assert.assertNotNull("Vertex overload should work", vertexResult); + Assert.assertNotNull("Row overload should work", rowResult); + Assert.assertTrue("Type-specific overloads should return TRUE", vertexResult && rowResult); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/util/ParserUtilTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/util/ParserUtilTest.java index 48978f8a8..0d5a7d2b9 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/util/ParserUtilTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/util/ParserUtilTest.java @@ -23,7 +23,6 @@ import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertTrue; -import com.google.common.collect.Lists; import org.apache.calcite.rel.type.RelDataType; import org.apache.geaflow.common.type.Types; import org.apache.geaflow.common.type.primitive.LongType; @@ -40,95 +39,95 @@ import org.apache.geaflow.dsl.planner.GQLJavaTypeFactory; import org.testng.annotations.Test; +import com.google.common.collect.Lists; + public class ParserUtilTest { - @Test - public void testConvertTypeUtil() { - RelDataType relType; - GQLJavaTypeFactory typeFactory = GQLJavaTypeFactory.create(); - TableField field1 = new TableField("name", Types.STRING, true); - TableField field2 = new TableField("id", Types.LONG, true); - TableField field3 = new TableField("age", Types.DOUBLE, true); - StructType structType = - new StructType(Lists.newArrayList(field1, field2)); - assertEquals(structType.getField(1).getName(), "id"); - assertEquals(structType.getTypeClass(), Row.class); - assertEquals(structType.indexOf("name"), 0); - assertEquals(structType.getName(), "STRUCT"); - assertEquals(structType.getField("name").getName(), "name"); - structType.addField(field3); - assertEquals(structType.getFieldNames().size(), 2); - assertNotNull(structType.toString()); - structType.replace("name", field1); - assertEquals(structType.getFieldNames().size(), 2); - relType = SqlTypeUtil.convertToRelType(structType, true, typeFactory); - assertEquals(relType.toString(), "RecordType:peek(VARCHAR name, BIGINT id)"); - assertEquals(SqlTypeUtil.convertToJavaTypes(relType, typeFactory).size(), 2); + @Test + public void testConvertTypeUtil() { + RelDataType relType; + GQLJavaTypeFactory typeFactory = GQLJavaTypeFactory.create(); + TableField field1 = new TableField("name", Types.STRING, true); + TableField field2 = new TableField("id", Types.LONG, true); + TableField field3 = new TableField("age", Types.DOUBLE, true); + StructType structType = new StructType(Lists.newArrayList(field1, field2)); + assertEquals(structType.getField(1).getName(), "id"); + assertEquals(structType.getTypeClass(), Row.class); + assertEquals(structType.indexOf("name"), 0); + assertEquals(structType.getName(), "STRUCT"); + assertEquals(structType.getField("name").getName(), "name"); + structType.addField(field3); + assertEquals(structType.getFieldNames().size(), 2); + assertNotNull(structType.toString()); + structType.replace("name", field1); + assertEquals(structType.getFieldNames().size(), 2); + relType = SqlTypeUtil.convertToRelType(structType, true, typeFactory); + assertEquals(relType.toString(), "RecordType:peek(VARCHAR name, BIGINT id)"); + assertEquals(SqlTypeUtil.convertToJavaTypes(relType, typeFactory).size(), 2); - VertexType vertexType = new VertexType(Lists.newArrayList(field1, field2)); - assertEquals(vertexType.getField(1).getName(), "id"); - assertEquals(vertexType.getTypeClass(), RowVertex.class); - assertEquals(vertexType.indexOf("name"), 0); - assertEquals(vertexType.getName(), "VERTEX"); - assertEquals(vertexType.getField("name").getName(), "name"); - vertexType.addField(field3); - assertEquals(vertexType.getFieldNames().size(), 2); - assertNotNull(vertexType.toString()); - relType = SqlTypeUtil.convertToRelType(vertexType, true, typeFactory); - assertEquals(relType.toString(), - "Vertex:RecordType:peek(VARCHAR name, VARCHAR ~label, BIGINT id)"); - assertEquals(SqlTypeUtil.convertToJavaTypes(relType, typeFactory).size(), 3); - assertEquals(SqlTypeUtil.convertType(relType).getName(), "VERTEX"); + VertexType vertexType = new VertexType(Lists.newArrayList(field1, field2)); + assertEquals(vertexType.getField(1).getName(), "id"); + assertEquals(vertexType.getTypeClass(), RowVertex.class); + assertEquals(vertexType.indexOf("name"), 0); + assertEquals(vertexType.getName(), "VERTEX"); + assertEquals(vertexType.getField("name").getName(), "name"); + vertexType.addField(field3); + assertEquals(vertexType.getFieldNames().size(), 2); + assertNotNull(vertexType.toString()); + relType = SqlTypeUtil.convertToRelType(vertexType, true, typeFactory); + assertEquals( + relType.toString(), "Vertex:RecordType:peek(VARCHAR name, VARCHAR ~label, BIGINT id)"); + assertEquals(SqlTypeUtil.convertToJavaTypes(relType, typeFactory).size(), 3); + assertEquals(SqlTypeUtil.convertType(relType).getName(), "VERTEX"); - EdgeType edgeType = new EdgeType(Lists.newArrayList(field1, field2), false); - assertEquals(edgeType.getField(1).getName(), "id"); - assertEquals(edgeType.getTypeClass(), RowEdge.class); - assertEquals(edgeType.indexOf("name"), 0); - assertEquals(edgeType.getName(), "EDGE"); - assertEquals(edgeType.getField("name").getName(), "name"); - edgeType.addField(field3); - assertEquals(edgeType.getFieldNames().size(), 2); - assertNotNull(edgeType.toString()); - relType = SqlTypeUtil.convertToRelType(edgeType, true, typeFactory); - assertEquals(relType.toString(), - "Edge: RecordType:peek(VARCHAR name, BIGINT id, VARCHAR ~label)"); - assertEquals(SqlTypeUtil.convertToJavaTypes(relType, typeFactory).size(), 3); - assertEquals(SqlTypeUtil.convertType(relType).getName(), "EDGE"); + EdgeType edgeType = new EdgeType(Lists.newArrayList(field1, field2), false); + assertEquals(edgeType.getField(1).getName(), "id"); + assertEquals(edgeType.getTypeClass(), RowEdge.class); + assertEquals(edgeType.indexOf("name"), 0); + assertEquals(edgeType.getName(), "EDGE"); + assertEquals(edgeType.getField("name").getName(), "name"); + edgeType.addField(field3); + assertEquals(edgeType.getFieldNames().size(), 2); + assertNotNull(edgeType.toString()); + relType = SqlTypeUtil.convertToRelType(edgeType, true, typeFactory); + assertEquals( + relType.toString(), "Edge: RecordType:peek(VARCHAR name, BIGINT id, VARCHAR ~label)"); + assertEquals(SqlTypeUtil.convertToJavaTypes(relType, typeFactory).size(), 3); + assertEquals(SqlTypeUtil.convertType(relType).getName(), "EDGE"); - PathType pathType = new PathType(Lists.newArrayList(field1, field2)); - assertEquals(pathType.getField(1).getName(), "id"); - assertEquals(pathType.getTypeClass(), Path.class); - assertEquals(pathType.indexOf("name"), 0); - assertEquals(pathType.getName(), "PATH"); - assertEquals(pathType.getField("name").getName(), "name"); - pathType.addField(field3); - assertEquals(pathType.getFieldNames().size(), 2); - assertNotNull(pathType.toString()); - pathType.replace("name", field1); - assertEquals(pathType.getFieldNames().size(), 2); - relType = SqlTypeUtil.convertToRelType(pathType, true, typeFactory); - assertEquals(relType.toString(), - "Path:RecordType:peek(VARCHAR name, BIGINT id)"); - assertEquals(SqlTypeUtil.convertToJavaTypes(relType, typeFactory).size(), 2); - assertNotNull(SqlTypeUtil.convertType(relType)); + PathType pathType = new PathType(Lists.newArrayList(field1, field2)); + assertEquals(pathType.getField(1).getName(), "id"); + assertEquals(pathType.getTypeClass(), Path.class); + assertEquals(pathType.indexOf("name"), 0); + assertEquals(pathType.getName(), "PATH"); + assertEquals(pathType.getField("name").getName(), "name"); + pathType.addField(field3); + assertEquals(pathType.getFieldNames().size(), 2); + assertNotNull(pathType.toString()); + pathType.replace("name", field1); + assertEquals(pathType.getFieldNames().size(), 2); + relType = SqlTypeUtil.convertToRelType(pathType, true, typeFactory); + assertEquals(relType.toString(), "Path:RecordType:peek(VARCHAR name, BIGINT id)"); + assertEquals(SqlTypeUtil.convertToJavaTypes(relType, typeFactory).size(), 2); + assertNotNull(SqlTypeUtil.convertType(relType)); - ArrayType arrayType = new ArrayType(Types.LONG); - assertTrue(arrayType.getComponentType() instanceof LongType); - assertEquals(arrayType.getTypeClass(), Long[].class); - assertEquals(arrayType.getName(), "ARRAY"); - assertNotNull(arrayType.toString()); - Long[] longArray1 = new Long[]{0L, 1L, 2L}; - Long[] longArray2 = new Long[]{2L, 1L, 0L}; - assertEquals(arrayType.compare(longArray1, longArray2), -1); - relType = SqlTypeUtil.convertToRelType(arrayType, true, typeFactory); - assertEquals(relType.toString(), "BIGINT ARRAY"); - assertEquals(SqlTypeUtil.convertType(relType).getName(), "ARRAY"); - } + ArrayType arrayType = new ArrayType(Types.LONG); + assertTrue(arrayType.getComponentType() instanceof LongType); + assertEquals(arrayType.getTypeClass(), Long[].class); + assertEquals(arrayType.getName(), "ARRAY"); + assertNotNull(arrayType.toString()); + Long[] longArray1 = new Long[] {0L, 1L, 2L}; + Long[] longArray2 = new Long[] {2L, 1L, 0L}; + assertEquals(arrayType.compare(longArray1, longArray2), -1); + relType = SqlTypeUtil.convertToRelType(arrayType, true, typeFactory); + assertEquals(relType.toString(), "BIGINT ARRAY"); + assertEquals(SqlTypeUtil.convertType(relType).getName(), "ARRAY"); + } - @Test - public void testEdgeDirection() { - assertEquals(EdgeDirection.of("OUT").toString(), "OUT"); - assertEquals(EdgeDirection.of("IN").toString(), "IN"); - assertEquals(EdgeDirection.of("BOTH").toString(), "BOTH"); - } + @Test + public void testEdgeDirection() { + assertEquals(EdgeDirection.of("OUT").toString(), "OUT"); + assertEquals(EdgeDirection.of("IN").toString(), "IN"); + assertEquals(EdgeDirection.of("BOTH").toString(), "BOTH"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/ExecutionMode.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/ExecutionMode.java index 0b8b3e1c1..1bde5555f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/ExecutionMode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/ExecutionMode.java @@ -20,15 +20,15 @@ package org.apache.geaflow.dsl.runtime; public enum ExecutionMode { - BATCH, - STREAM; + BATCH, + STREAM; - public static ExecutionMode of(String value) { - for (ExecutionMode executionMode : values()) { - if (executionMode.name().equalsIgnoreCase(value)) { - return executionMode; - } - } - throw new IllegalArgumentException("Illegal executionMode: " + value); + public static ExecutionMode of(String value) { + for (ExecutionMode executionMode : values()) { + if (executionMode.name().equalsIgnoreCase(value)) { + return executionMode; + } } + throw new IllegalArgumentException("Illegal executionMode: " + value); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/InsertGraphMaterialCallback.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/InsertGraphMaterialCallback.java index 54603ef7e..e5e907fec 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/InsertGraphMaterialCallback.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/InsertGraphMaterialCallback.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime; import java.util.Set; + import org.apache.geaflow.api.pdata.stream.window.PWindowStream; import org.apache.geaflow.dsl.common.data.RowEdge; import org.apache.geaflow.dsl.common.data.RowVertex; @@ -32,30 +33,30 @@ public class InsertGraphMaterialCallback implements QueryCallback { - public static final QueryCallback INSTANCE = new InsertGraphMaterialCallback(); + public static final QueryCallback INSTANCE = new InsertGraphMaterialCallback(); - private InsertGraphMaterialCallback() { + private InsertGraphMaterialCallback() {} - } + @Override + public void onQueryFinish(QueryContext queryContext) { + Set nonMaterializedGraphs = queryContext.getNonMaterializedGraphs(); + for (String graphName : nonMaterializedGraphs) { + PWindowStream vertexStream = queryContext.getGraphVertexStream(graphName); + PWindowStream edgeStream = queryContext.getGraphEdgeStream(graphName); + + GeaFlowGraph graph = queryContext.getGraph(graphName); + queryContext.updateVertexAndEdgeToGraph(graphName, graph, vertexStream, edgeStream); + + GraphViewDesc graphViewDesc = + SchemaUtil.buildGraphViewDesc(graph, queryContext.getGlobalConf()); + IPipelineJobContext pipelineContext = queryContext.getEngineContext().getContext(); + PGraphView graphView = + pipelineContext.createGraphView(graphViewDesc); + PIncGraphView incGraphView = + graphView.appendGraph((PWindowStream) vertexStream, (PWindowStream) edgeStream); - @Override - public void onQueryFinish(QueryContext queryContext) { - Set nonMaterializedGraphs = queryContext.getNonMaterializedGraphs(); - for (String graphName : nonMaterializedGraphs) { - PWindowStream vertexStream = queryContext.getGraphVertexStream(graphName); - PWindowStream edgeStream = queryContext.getGraphEdgeStream(graphName); - - GeaFlowGraph graph = queryContext.getGraph(graphName); - queryContext.updateVertexAndEdgeToGraph(graphName, graph, vertexStream, edgeStream); - - GraphViewDesc graphViewDesc = SchemaUtil.buildGraphViewDesc(graph, queryContext.getGlobalConf()); - IPipelineJobContext pipelineContext = queryContext.getEngineContext().getContext(); - PGraphView graphView = pipelineContext.createGraphView(graphViewDesc); - PIncGraphView incGraphView = - graphView.appendGraph((PWindowStream) vertexStream, (PWindowStream) edgeStream); - - incGraphView.materialize(); - queryContext.addMaterializedGraph(graphName); - } + incGraphView.materialize(); + queryContext.addMaterializedGraph(graphName); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/QueryCallback.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/QueryCallback.java index d5c050c7f..ed2e434fe 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/QueryCallback.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/QueryCallback.java @@ -19,10 +19,8 @@ package org.apache.geaflow.dsl.runtime; -/** - * Call back for query compile. - */ +/** Call back for query compile. */ public interface QueryCallback { - void onQueryFinish(QueryContext queryContext); + void onQueryFinish(QueryContext queryContext); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/QueryClient.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/QueryClient.java index 567a3fb62..5d59bc21d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/QueryClient.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/QueryClient.java @@ -21,12 +21,12 @@ import static org.apache.geaflow.common.config.keys.DSLConfigKeys.GEAFLOW_DSL_COMPILE_PHYSICAL_PLAN_ENABLE; -import com.google.common.collect.ImmutableList; import java.util.HashSet; import java.util.List; import java.util.ServiceLoader; import java.util.Set; import java.util.stream.Collectors; + import org.apache.calcite.sql.*; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.util.NlsString; @@ -57,208 +57,216 @@ import org.apache.geaflow.runtime.pipeline.PipelineTaskType; import org.apache.geaflow.runtime.pipeline.task.PipelineTaskContext; -public class QueryClient implements QueryCompiler { +import com.google.common.collect.ImmutableList; - private final GeaFlowDSLParser parser = new GeaFlowDSLParser(); +public class QueryClient implements QueryCompiler { + private final GeaFlowDSLParser parser = new GeaFlowDSLParser(); - public QueryClient() { - } + public QueryClient() {} - /** - * Execute multi-query at once. - * - * @param sql The sql script which contains multi-query to execute. - * @param context The context for query. - */ - public void executeQuery(String sql, QueryContext context) { - try { - List sqlNodes = parser.parseMultiStatement(sql); - for (SqlNode sqlNode : sqlNodes) { - executeQuery(sqlNode, context); - } - } catch (SqlParseException e) { - throw new GeaFlowDSLException("Error in execute query: \n" + sql, e); - } + /** + * Execute multi-query at once. + * + * @param sql The sql script which contains multi-query to execute. + * @param context The context for query. + */ + public void executeQuery(String sql, QueryContext context) { + try { + List sqlNodes = parser.parseMultiStatement(sql); + for (SqlNode sqlNode : sqlNodes) { + executeQuery(sqlNode, context); + } + } catch (SqlParseException e) { + throw new GeaFlowDSLException("Error in execute query: \n" + sql, e); } + } - /** - * Execute single query. - * - * @param sql The sql to execute. - * @param context The context for executor engine. - * @return The result for the query. - */ - public QueryResult executeSingleQuery(String sql, QueryContext context) { + /** + * Execute single query. + * + * @param sql The sql to execute. + * @param context The context for executor engine. + * @return The result for the query. + */ + public QueryResult executeSingleQuery(String sql, QueryContext context) { - try { - SqlNode sqlNode = parser.parseStatement(sql); - return executeQuery(sqlNode, context); - } catch (SqlParseException e) { - throw new GeaFlowDSLException("Error in execute query: \n" + sql, e); - } + try { + SqlNode sqlNode = parser.parseStatement(sql); + return executeQuery(sqlNode, context); + } catch (SqlParseException e) { + throw new GeaFlowDSLException("Error in execute query: \n" + sql, e); } + } + private QueryResult executeQuery(SqlNode sqlNode, QueryContext context) { + IQueryCommand command = context.getCommand(sqlNode); + return command.execute(context); + } - private QueryResult executeQuery(SqlNode sqlNode, QueryContext context) { - IQueryCommand command = context.getCommand(sqlNode); - return command.execute(context); - } - - @Override - public CompileResult compile(String script, CompileContext context) { - PipelineContext pipelineContext = new PipelineContext(PipelineTaskType.CompileTask.name(), - new Configuration(context.getConfig())); - PipelineTaskContext pipelineTaskCxt = new PipelineTaskContext(0L, pipelineContext); - QueryEngine engineContext = new GeaFlowQueryEngine(pipelineTaskCxt); - QueryContext queryContext = QueryContext.builder() + @Override + public CompileResult compile(String script, CompileContext context) { + PipelineContext pipelineContext = + new PipelineContext( + PipelineTaskType.CompileTask.name(), new Configuration(context.getConfig())); + PipelineTaskContext pipelineTaskCxt = new PipelineTaskContext(0L, pipelineContext); + QueryEngine engineContext = new GeaFlowQueryEngine(pipelineTaskCxt); + QueryContext queryContext = + QueryContext.builder() .setEngineContext(engineContext) .setCompile(true) .setTraversalParallelism(-1) .build(); - queryContext.putConfigParallelism(context.getParallelisms()); - executeQuery(script, queryContext); + queryContext.putConfigParallelism(context.getParallelisms()); + executeQuery(script, queryContext); - CompileResult compileResult = new CompileResult(); - // Get current schema before finish. - compileResult.setCurrentResultType(queryContext.getCurrentResultType()); + CompileResult compileResult = new CompileResult(); + // Get current schema before finish. + compileResult.setCurrentResultType(queryContext.getCurrentResultType()); - queryContext.finish(); + queryContext.finish(); - compileResult.setSourceTables(queryContext.getReferSourceTables()); - compileResult.setTargetTables(queryContext.getReferTargetTables()); - compileResult.setSourceGraphs(queryContext.getReferSourceGraphs()); - compileResult.setTargetGraphs(queryContext.getReferTargetGraphs()); + compileResult.setSourceTables(queryContext.getReferSourceTables()); + compileResult.setTargetTables(queryContext.getReferTargetTables()); + compileResult.setSourceGraphs(queryContext.getReferSourceGraphs()); + compileResult.setTargetGraphs(queryContext.getReferTargetGraphs()); - boolean needPlan = ConfigHelper.getBooleanOrDefault(context.getConfig(), GEAFLOW_DSL_COMPILE_PHYSICAL_PLAN_ENABLE.getKey(), Boolean.TRUE); - if (needPlan) { - PipelinePlanBuilder pipelinePlanBuilder = new PipelinePlanBuilder(); - PipelineGraph pipelineGraph = pipelinePlanBuilder.buildPlan(pipelineContext); - pipelinePlanBuilder.optimizePlan(pipelineContext.getConfig()); - JsonPlanGraphVisualization visualization = new JsonPlanGraphVisualization(pipelineGraph); - compileResult.setPhysicPlan(visualization.getJsonPlan()); - } - return compileResult; + boolean needPlan = + ConfigHelper.getBooleanOrDefault( + context.getConfig(), GEAFLOW_DSL_COMPILE_PHYSICAL_PLAN_ENABLE.getKey(), Boolean.TRUE); + if (needPlan) { + PipelinePlanBuilder pipelinePlanBuilder = new PipelinePlanBuilder(); + PipelineGraph pipelineGraph = pipelinePlanBuilder.buildPlan(pipelineContext); + pipelinePlanBuilder.optimizePlan(pipelineContext.getConfig()); + JsonPlanGraphVisualization visualization = new JsonPlanGraphVisualization(pipelineGraph); + compileResult.setPhysicPlan(visualization.getJsonPlan()); } + return compileResult; + } + @Override + public String formatOlapResult(String script, Object queryResult, CompileContext context) { + context.getConfig().put("needPhysicalPlan", "false"); + CompileResult compileResult = compile(script, context); + return AnalyticsResultFormatter.formatResult(queryResult, compileResult.getCurrentResultType()); + } - @Override - public String formatOlapResult(String script, Object queryResult, CompileContext context) { - context.getConfig().put("needPhysicalPlan", "false"); - CompileResult compileResult = compile(script, context); - return AnalyticsResultFormatter.formatResult(queryResult, compileResult.getCurrentResultType()); - } - - @Override - public Set getUnResolvedFunctions(String script, CompileContext context) { - try { - List sqlNodes = parser.parseMultiStatement(script); - GQLContext gqlContext = GQLContext.create(new Configuration(context.getConfig()), true); - Set functions = new HashSet<>(); - for (SqlNode sqlNode : sqlNodes) { - if (sqlNode instanceof SqlCreateFunction) { - SqlCreateFunction function = (SqlCreateFunction) sqlNode; - String functionName = function.getFunctionName().toString(); - String instanceName = gqlContext.getCurrentInstance(); - FunctionInfo functionInfo = new FunctionInfo(instanceName, functionName); - functions.add(functionInfo); - } else if (sqlNode instanceof SqlUseInstance) { - SqlUseInstance useInstance = (SqlUseInstance) sqlNode; - String instanceName = useInstance.getInstance().toString(); - gqlContext.setCurrentInstance(instanceName); - } else { - List unresolvedFunctions = SqlNodeUtil.findUnresolvedFunctions(sqlNode); - for (SqlUnresolvedFunction unresolvedFunction : unresolvedFunctions) { - String name = unresolvedFunction.getName(); - SqlFunction sqlFunction = gqlContext.findSqlFunction(gqlContext.getCurrentInstance(), name); - if (sqlFunction == null) { - FunctionInfo functionInfo = new FunctionInfo(gqlContext.getCurrentInstance(), name); - functions.add(functionInfo); - } - } - } + @Override + public Set getUnResolvedFunctions(String script, CompileContext context) { + try { + List sqlNodes = parser.parseMultiStatement(script); + GQLContext gqlContext = GQLContext.create(new Configuration(context.getConfig()), true); + Set functions = new HashSet<>(); + for (SqlNode sqlNode : sqlNodes) { + if (sqlNode instanceof SqlCreateFunction) { + SqlCreateFunction function = (SqlCreateFunction) sqlNode; + String functionName = function.getFunctionName().toString(); + String instanceName = gqlContext.getCurrentInstance(); + FunctionInfo functionInfo = new FunctionInfo(instanceName, functionName); + functions.add(functionInfo); + } else if (sqlNode instanceof SqlUseInstance) { + SqlUseInstance useInstance = (SqlUseInstance) sqlNode; + String instanceName = useInstance.getInstance().toString(); + gqlContext.setCurrentInstance(instanceName); + } else { + List unresolvedFunctions = + SqlNodeUtil.findUnresolvedFunctions(sqlNode); + for (SqlUnresolvedFunction unresolvedFunction : unresolvedFunctions) { + String name = unresolvedFunction.getName(); + SqlFunction sqlFunction = + gqlContext.findSqlFunction(gqlContext.getCurrentInstance(), name); + if (sqlFunction == null) { + FunctionInfo functionInfo = new FunctionInfo(gqlContext.getCurrentInstance(), name); + functions.add(functionInfo); } - return functions; - } catch (Exception e) { - throw new GeaFlowDSLException("Error in parser dsl", e); + } } + } + return functions; + } catch (Exception e) { + throw new GeaFlowDSLException("Error in parser dsl", e); } + } - @Override - public Set getEnginePlugins() { - Set typs = new HashSet<>(); - ServiceLoader connectors = ServiceLoader.load(TableConnector.class); - for (TableConnector connector : connectors) { - typs.add(connector.getType().toUpperCase()); - } - return typs; + @Override + public Set getEnginePlugins() { + Set typs = new HashSet<>(); + ServiceLoader connectors = ServiceLoader.load(TableConnector.class); + for (TableConnector connector : connectors) { + typs.add(connector.getType().toUpperCase()); } + return typs; + } - @Override - public Set getDeclaredTablePlugins(String script, CompileContext context) { - try { - List sqlNodes = parser.parseMultiStatement(script); - Set plugins = new HashSet<>(); + @Override + public Set getDeclaredTablePlugins(String script, CompileContext context) { + try { + List sqlNodes = parser.parseMultiStatement(script); + Set plugins = new HashSet<>(); - for (SqlNode sqlNode : sqlNodes) { - if (sqlNode instanceof SqlCreateTable) { - SqlNodeList properties = ((SqlCreateTable) sqlNode).getProperties(); - for (SqlNode property : properties) { - if (property instanceof SqlTableProperty) { - ImmutableList names = ((SqlTableProperty) property).getKey().names; - String key = names.get(0); - if (key.equals("type")) { - NlsString nlsString = - (NlsString) ((SqlCharStringLiteral) ((SqlTableProperty) property).getValue()).getValue(); - plugins.add(nlsString.getValue()); - } - } - } - } + for (SqlNode sqlNode : sqlNodes) { + if (sqlNode instanceof SqlCreateTable) { + SqlNodeList properties = ((SqlCreateTable) sqlNode).getProperties(); + for (SqlNode property : properties) { + if (property instanceof SqlTableProperty) { + ImmutableList names = ((SqlTableProperty) property).getKey().names; + String key = names.get(0); + if (key.equals("type")) { + NlsString nlsString = + (NlsString) + ((SqlCharStringLiteral) ((SqlTableProperty) property).getValue()) + .getValue(); + plugins.add(nlsString.getValue()); + } } - return plugins; - } catch (Exception e) { - throw new GeaFlowDSLException("Error in parser dsl", e); + } } + } + return plugins; + } catch (Exception e) { + throw new GeaFlowDSLException("Error in parser dsl", e); } + } - @Override - public Set getUnResolvedTables(String script, CompileContext context) { - try { - List sqlNodes = parser.parseMultiStatement(script); - GQLContext gqlContext = GQLContext.create(new Configuration(context.getConfig()), true); + @Override + public Set getUnResolvedTables(String script, CompileContext context) { + try { + List sqlNodes = parser.parseMultiStatement(script); + GQLContext gqlContext = GQLContext.create(new Configuration(context.getConfig()), true); - Set declaredTables = new HashSet<>(); - Set unResolvedTables = new HashSet<>(); - for (SqlNode sqlNode : sqlNodes) { - if (sqlNode instanceof SqlCreateTable) { - ImmutableList names = ((SqlCreateTable) sqlNode).getName().names; - String tableName = names.get(0); - String instanceName = gqlContext.getCurrentInstance(); - declaredTables.add(new TableInfo(instanceName, tableName)); + Set declaredTables = new HashSet<>(); + Set unResolvedTables = new HashSet<>(); + for (SqlNode sqlNode : sqlNodes) { + if (sqlNode instanceof SqlCreateTable) { + ImmutableList names = ((SqlCreateTable) sqlNode).getName().names; + String tableName = names.get(0); + String instanceName = gqlContext.getCurrentInstance(); + declaredTables.add(new TableInfo(instanceName, tableName)); - } else if (sqlNode instanceof SqlCreateGraph) { - ImmutableList names = ((SqlCreateGraph) sqlNode).getName().names; - String tableName = names.get(0); - String instanceName = gqlContext.getCurrentInstance(); - declaredTables.add(new TableInfo(instanceName, tableName)); + } else if (sqlNode instanceof SqlCreateGraph) { + ImmutableList names = ((SqlCreateGraph) sqlNode).getName().names; + String tableName = names.get(0); + String instanceName = gqlContext.getCurrentInstance(); + declaredTables.add(new TableInfo(instanceName, tableName)); - } else if (sqlNode instanceof SqlUseInstance) { - SqlUseInstance useInstance = (SqlUseInstance) sqlNode; - String instanceName = useInstance.getInstance().toString(); - gqlContext.setCurrentInstance(instanceName); + } else if (sqlNode instanceof SqlUseInstance) { + SqlUseInstance useInstance = (SqlUseInstance) sqlNode; + String instanceName = useInstance.getInstance().toString(); + gqlContext.setCurrentInstance(instanceName); - } else { - Set usedTables = SqlNodeUtil.findUsedTables(sqlNode); - String instanceName = gqlContext.getCurrentInstance(); - for (String usedTable : usedTables) { - unResolvedTables.add(new TableInfo(instanceName, usedTable)); - } - } - } - return unResolvedTables.stream().filter(e -> !declaredTables.contains(e)).collect(Collectors.toSet()); - } catch (Exception e) { - throw new GeaFlowDSLException("Error in parser dsl", e); + } else { + Set usedTables = SqlNodeUtil.findUsedTables(sqlNode); + String instanceName = gqlContext.getCurrentInstance(); + for (String usedTable : usedTables) { + unResolvedTables.add(new TableInfo(instanceName, usedTable)); + } } + } + return unResolvedTables.stream() + .filter(e -> !declaredTables.contains(e)) + .collect(Collectors.toSet()); + } catch (Exception e) { + throw new GeaFlowDSLException("Error in parser dsl", e); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/QueryContext.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/QueryContext.java index c7d8c62e4..d06449fb8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/QueryContext.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/QueryContext.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.Map; import java.util.Set; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlNode; @@ -66,355 +67,355 @@ public class QueryContext { - private final QueryEngine engineContext; - - private final GQLContext gqlContext; - - private final Map viewDataViews = new HashMap<>(); - - private boolean isCompile; - - private final List optimizeRules = new ArrayList<>(OptimizeRules.RULE_GROUPS); - - private final PathReferenceAnalyzer pathAnalyzer; - - private RuntimeTable requestTable; - - private boolean isIdOnlyRequest; - - private Expression pushFilter; - - private final Map configParallelisms = new HashMap<>(); - - private long opNameCounter = 0L; - - private int traversalParallelism = -1; - - private final Map setOptions = new HashMap<>(); - - private final Map> graphVertices = new HashMap<>(); - - private final Map> graphEdges = new HashMap<>(); - - private final Map graphs = new HashMap<>(); - - private final Map runtimeTables = new HashMap<>(); - - private final Map runtimeGraphs = new HashMap<>(); + private final QueryEngine engineContext; - private final Set materializedGraphs = new HashSet<>(); + private final GQLContext gqlContext; - private final Set referSourceTables = new HashSet<>(); + private final Map viewDataViews = new HashMap<>(); - private final Set referTargetTables = new HashSet<>(); + private boolean isCompile; - private final Set referSourceGraphs = new HashSet<>(); + private final List optimizeRules = new ArrayList<>(OptimizeRules.RULE_GROUPS); - private final Set referTargetGraphs = new HashSet<>(); + private final PathReferenceAnalyzer pathAnalyzer; - private final List queryCallbacks = new ArrayList<>(); + private RuntimeTable requestTable; - private RelDataType currentResultType; + private boolean isIdOnlyRequest; - public RelDataType getCurrentResultType() { - return currentResultType; - } - - public QueryContext setCurrentResultType(RelDataType currentResultType) { - this.currentResultType = currentResultType; - return this; - } - - private QueryContext(QueryEngine engineContext, boolean isCompile) { - this.engineContext = engineContext; - this.gqlContext = GQLContext.create(new Configuration(engineContext.getConfig()), isCompile); - this.pathAnalyzer = new PathReferenceAnalyzer(gqlContext); - this.isCompile = isCompile; - registerQueryCallback(InsertGraphMaterialCallback.INSTANCE); - } - - public IQueryCommand getCommand(SqlNode node) { - SqlKind kind = node.getKind(); - if (!kind.belongsTo(SqlKind.TOP_LEVEL)) { - throw new IllegalArgumentException("SqlNode is a top level query, current kind is: " + kind); - } - switch (kind) { - case SELECT: - case GQL_FILTER: - case GQL_MATCH_PATTERN: - case GQL_RETURN: - case INSERT: - case ORDER_BY: - case WITH: - return new QueryCommand(node); - case CREATE_TABLE: - return new CreateTableCommand((SqlCreateTable) node); - case CREATE_VIEW: - return new CreateViewCommand((SqlCreateView) node); - case SET_OPTION: - return new SetCommand((SqlSetOption) node); - case CREATE_GRAPH: - return new CreateGraphCommand((SqlCreateGraph) node); - case DROP_GRAPH: - return new DropGraphCommand((SqlDropGraph) node); - case DESC_GRAPH: - return new DescGraphCommand((SqlDescGraph) node); - case ALTER_GRAPH: - return new AlterGraphCommand((SqlAlterGraph) node); - case USE_GRAPH: - return new UseGraphCommand((SqlUseGraph) node); - case USE_INSTANCE: - return new UseInstanceCommand((SqlUseInstance) node); - case CREATE_FUNCTION: - return new CreateFunctionCommand((SqlCreateFunction) node); - default: - throw new IllegalArgumentException("Not support sql kind: " + kind); - } - } - - public GQLContext getGqlContext() { - return gqlContext; - } - - public QueryEngine getEngineContext() { - return engineContext; - } + private Expression pushFilter; - public List getLogicalRules() { - return optimizeRules; - } + private final Map configParallelisms = new HashMap<>(); - public boolean isCompile() { - return isCompile; - } + private long opNameCounter = 0L; - public RDataView getDataViewByViewName(String viewName) { - return viewDataViews.get(viewName); - } + private int traversalParallelism = -1; - public void putViewDataView(String viewName, RDataView dataView) { - if (viewDataViews.containsKey(viewName)) { - throw new IllegalArgumentException("View: " + viewName + " has already registered"); - } - if (dataView == null) { - throw new IllegalArgumentException("DataView is null"); - } - viewDataViews.put(viewName, dataView); - } + private final Map setOptions = new HashMap<>(); - public boolean setCompile(boolean isCompile) { - boolean oldValue = this.isCompile; - this.isCompile = isCompile; - return oldValue; - } + private final Map> graphVertices = new HashMap<>(); - public PathReferenceAnalyzer getPathAnalyzer() { - return pathAnalyzer; - } + private final Map> graphEdges = new HashMap<>(); - public RuntimeTable setRequestTable(RuntimeTable requestTable) { - RuntimeTable preValue = this.requestTable; - this.requestTable = requestTable; - return preValue; - } + private final Map graphs = new HashMap<>(); - public boolean setIdOnlyRequest(boolean isIdOnlyRequest) { - boolean preValue = this.isIdOnlyRequest; - this.isIdOnlyRequest = isIdOnlyRequest; - return preValue; - } + private final Map runtimeTables = new HashMap<>(); - public RuntimeTable getRequestTable() { - return requestTable; - } + private final Map runtimeGraphs = new HashMap<>(); - public boolean isIdOnlyRequest() { - return isIdOnlyRequest; - } + private final Set materializedGraphs = new HashSet<>(); - public Expression setPushFilter(Expression pushFilter) { - Expression preFilter = this.pushFilter; - this.pushFilter = pushFilter; - return preFilter; - } + private final Set referSourceTables = new HashSet<>(); - public Expression getPushFilter() { - return pushFilter; - } + private final Set referTargetTables = new HashSet<>(); - public int getConfigParallelisms(String opName, int defaultParallelism) { - return configParallelisms.getOrDefault(opName, defaultParallelism); - } + private final Set referSourceGraphs = new HashSet<>(); - public void putConfigParallelism(String opName, int parallelism) { - configParallelisms.put(opName, parallelism); - } + private final Set referTargetGraphs = new HashSet<>(); - public void putConfigParallelism(Map parallelisms) { - configParallelisms.putAll(parallelisms); - } + private final List queryCallbacks = new ArrayList<>(); - public long getOpNameCount() { - return opNameCounter++; - } + private RelDataType currentResultType; - public String createOperatorName(String baseName) { - return baseName + "-" + getOpNameCount(); - } + public RelDataType getCurrentResultType() { + return currentResultType; + } - public Map getSetOptions() { - return setOptions; - } + public QueryContext setCurrentResultType(RelDataType currentResultType) { + this.currentResultType = currentResultType; + return this; + } - public void putSetOption(String key, String value) { - this.setOptions.put(key, value); - } + private QueryContext(QueryEngine engineContext, boolean isCompile) { + this.engineContext = engineContext; + this.gqlContext = GQLContext.create(new Configuration(engineContext.getConfig()), isCompile); + this.pathAnalyzer = new PathReferenceAnalyzer(gqlContext); + this.isCompile = isCompile; + registerQueryCallback(InsertGraphMaterialCallback.INSTANCE); + } - public void updateVertexAndEdgeToGraph(String graphName, - GeaFlowGraph graph, - PWindowStream vertexStream, - PWindowStream edgeStream) { - graphs.put(graphName, graph); - graphVertices.put(graphName, vertexStream); - graphEdges.put(graphName, edgeStream); + public IQueryCommand getCommand(SqlNode node) { + SqlKind kind = node.getKind(); + if (!kind.belongsTo(SqlKind.TOP_LEVEL)) { + throw new IllegalArgumentException("SqlNode is a top level query, current kind is: " + kind); } - - public PWindowStream getGraphVertexStream(String graphName) { - return graphVertices.get(graphName); + switch (kind) { + case SELECT: + case GQL_FILTER: + case GQL_MATCH_PATTERN: + case GQL_RETURN: + case INSERT: + case ORDER_BY: + case WITH: + return new QueryCommand(node); + case CREATE_TABLE: + return new CreateTableCommand((SqlCreateTable) node); + case CREATE_VIEW: + return new CreateViewCommand((SqlCreateView) node); + case SET_OPTION: + return new SetCommand((SqlSetOption) node); + case CREATE_GRAPH: + return new CreateGraphCommand((SqlCreateGraph) node); + case DROP_GRAPH: + return new DropGraphCommand((SqlDropGraph) node); + case DESC_GRAPH: + return new DescGraphCommand((SqlDescGraph) node); + case ALTER_GRAPH: + return new AlterGraphCommand((SqlAlterGraph) node); + case USE_GRAPH: + return new UseGraphCommand((SqlUseGraph) node); + case USE_INSTANCE: + return new UseInstanceCommand((SqlUseInstance) node); + case CREATE_FUNCTION: + return new CreateFunctionCommand((SqlCreateFunction) node); + default: + throw new IllegalArgumentException("Not support sql kind: " + kind); } + } - public PWindowStream getGraphEdgeStream(String graphName) { - return graphEdges.get(graphName); - } + public GQLContext getGqlContext() { + return gqlContext; + } - public GeaFlowGraph getGraph(String graphName) { - return graphs.get(graphName); - } + public QueryEngine getEngineContext() { + return engineContext; + } + + public List getLogicalRules() { + return optimizeRules; + } + + public boolean isCompile() { + return isCompile; + } + + public RDataView getDataViewByViewName(String viewName) { + return viewDataViews.get(viewName); + } + + public void putViewDataView(String viewName, RDataView dataView) { + if (viewDataViews.containsKey(viewName)) { + throw new IllegalArgumentException("View: " + viewName + " has already registered"); + } + if (dataView == null) { + throw new IllegalArgumentException("DataView is null"); + } + viewDataViews.put(viewName, dataView); + } + + public boolean setCompile(boolean isCompile) { + boolean oldValue = this.isCompile; + this.isCompile = isCompile; + return oldValue; + } + + public PathReferenceAnalyzer getPathAnalyzer() { + return pathAnalyzer; + } + + public RuntimeTable setRequestTable(RuntimeTable requestTable) { + RuntimeTable preValue = this.requestTable; + this.requestTable = requestTable; + return preValue; + } + + public boolean setIdOnlyRequest(boolean isIdOnlyRequest) { + boolean preValue = this.isIdOnlyRequest; + this.isIdOnlyRequest = isIdOnlyRequest; + return preValue; + } + + public RuntimeTable getRequestTable() { + return requestTable; + } + + public boolean isIdOnlyRequest() { + return isIdOnlyRequest; + } + + public Expression setPushFilter(Expression pushFilter) { + Expression preFilter = this.pushFilter; + this.pushFilter = pushFilter; + return preFilter; + } + + public Expression getPushFilter() { + return pushFilter; + } + + public int getConfigParallelisms(String opName, int defaultParallelism) { + return configParallelisms.getOrDefault(opName, defaultParallelism); + } + + public void putConfigParallelism(String opName, int parallelism) { + configParallelisms.put(opName, parallelism); + } + + public void putConfigParallelism(Map parallelisms) { + configParallelisms.putAll(parallelisms); + } + + public long getOpNameCount() { + return opNameCounter++; + } + + public String createOperatorName(String baseName) { + return baseName + "-" + getOpNameCount(); + } + + public Map getSetOptions() { + return setOptions; + } + + public void putSetOption(String key, String value) { + this.setOptions.put(key, value); + } + + public void updateVertexAndEdgeToGraph( + String graphName, + GeaFlowGraph graph, + PWindowStream vertexStream, + PWindowStream edgeStream) { + graphs.put(graphName, graph); + graphVertices.put(graphName, vertexStream); + graphEdges.put(graphName, edgeStream); + } + + public PWindowStream getGraphVertexStream(String graphName) { + return graphVertices.get(graphName); + } + + public PWindowStream getGraphEdgeStream(String graphName) { + return graphEdges.get(graphName); + } + + public GeaFlowGraph getGraph(String graphName) { + return graphs.get(graphName); + } + + public void addGraph(String graphName, GeaFlowGraph graph) { + graphs.put(graphName, graph); + } + + public void addMaterializedGraph(String graphName) { + this.materializedGraphs.add(graphName); + } + + public Set getNonMaterializedGraphs() { + Set graphs = new HashSet<>(); + graphs.addAll(graphVertices.keySet()); + graphs.addAll(graphEdges.keySet()); + + for (String materializedGraph : materializedGraphs) { + graphs.remove(materializedGraph); + } + return graphs; + } + + public void addReferSourceTable(GeaFlowTable table) { + referSourceTables.add(new TableInfo(table.getInstanceName(), table.getName())); + } + + public void addReferTargetTable(GeaFlowTable table) { + referTargetTables.add(new TableInfo(table.getInstanceName(), table.getName())); + } + + public void addReferSourceGraph(GeaFlowGraph graph) { + referSourceGraphs.add(new GraphInfo(graph.getInstanceName(), graph.getName())); + } - public void addGraph(String graphName, GeaFlowGraph graph) { - graphs.put(graphName, graph); - } + public void addReferTargetGraph(GeaFlowGraph graph) { + referTargetGraphs.add(new GraphInfo(graph.getInstanceName(), graph.getName())); + } - public void addMaterializedGraph(String graphName) { - this.materializedGraphs.add(graphName); - } + public Set getReferSourceTables() { + return referSourceTables; + } - public Set getNonMaterializedGraphs() { - Set graphs = new HashSet<>(); - graphs.addAll(graphVertices.keySet()); - graphs.addAll(graphEdges.keySet()); + public Set getReferTargetTables() { + return referTargetTables; + } - for (String materializedGraph : materializedGraphs) { - graphs.remove(materializedGraph); - } - return graphs; - } + public Set getReferSourceGraphs() { + return referSourceGraphs; + } - public void addReferSourceTable(GeaFlowTable table) { - referSourceTables.add(new TableInfo(table.getInstanceName(), table.getName())); - } + public Set getReferTargetGraphs() { + return referTargetGraphs; + } - public void addReferTargetTable(GeaFlowTable table) { - referTargetTables.add(new TableInfo(table.getInstanceName(), table.getName())); - } + public static QueryContextBuilder builder() { + return new QueryContextBuilder(); + } - public void addReferSourceGraph(GeaFlowGraph graph) { - referSourceGraphs.add(new GraphInfo(graph.getInstanceName(), graph.getName())); - } + public void putRuntimeTable(String tableName, RuntimeTable table) { + runtimeTables.put(tableName, table); + } - public void addReferTargetGraph(GeaFlowGraph graph) { - referTargetGraphs.add(new GraphInfo(graph.getInstanceName(), graph.getName())); - } + public RuntimeTable getRuntimeTable(String tableName) { + return runtimeTables.get(tableName); + } - public Set getReferSourceTables() { - return referSourceTables; - } + public void putRuntimeGraph(String graphName, RuntimeGraph graph) { + runtimeGraphs.put(graphName, graph); + } - public Set getReferTargetTables() { - return referTargetTables; - } + public RuntimeGraph getRuntimeGraph(String graphName) { + return runtimeGraphs.get(graphName); + } - public Set getReferSourceGraphs() { - return referSourceGraphs; - } + public Configuration getGlobalConf() { + Map globalConf = new HashMap<>(engineContext.getConfig()); + globalConf.putAll(setOptions); + return new Configuration(globalConf); + } - public Set getReferTargetGraphs() { - return referTargetGraphs; - } + public void setTraversalParallelism(int traversalParallelism) { + this.traversalParallelism = traversalParallelism; + } - public static QueryContextBuilder builder() { - return new QueryContextBuilder(); - } + public int getTraversalParallelism() { + return this.traversalParallelism; + } - public void putRuntimeTable(String tableName, RuntimeTable table) { - runtimeTables.put(tableName, table); - } + public void registerQueryCallback(QueryCallback callback) { + queryCallbacks.add(callback); + } - public RuntimeTable getRuntimeTable(String tableName) { - return runtimeTables.get(tableName); + public void finish() { + for (QueryCallback callback : queryCallbacks) { + callback.onQueryFinish(this); } + this.currentResultType = null; + } - public void putRuntimeGraph(String graphName, RuntimeGraph graph) { - runtimeGraphs.put(graphName, graph); - } + public static class QueryContextBuilder { - public RuntimeGraph getRuntimeGraph(String graphName) { - return runtimeGraphs.get(graphName); - } + private QueryEngine engineContext; - public Configuration getGlobalConf() { - Map globalConf = new HashMap<>(engineContext.getConfig()); - globalConf.putAll(setOptions); - return new Configuration(globalConf); - } + private boolean isCompile; - public void setTraversalParallelism(int traversalParallelism) { - this.traversalParallelism = traversalParallelism; - } + private int traversalParallelism = -1; - public int getTraversalParallelism() { - return this.traversalParallelism; + public QueryContextBuilder setEngineContext(QueryEngine engineContext) { + this.engineContext = engineContext; + return this; } - - public void registerQueryCallback(QueryCallback callback) { - queryCallbacks.add(callback); + public QueryContextBuilder setCompile(boolean compile) { + isCompile = compile; + return this; } - public void finish() { - for (QueryCallback callback : queryCallbacks) { - callback.onQueryFinish(this); - } - this.currentResultType = null; + public QueryContextBuilder setTraversalParallelism(int traversalParallelism) { + this.traversalParallelism = traversalParallelism; + return this; } - public static class QueryContextBuilder { - - private QueryEngine engineContext; - - private boolean isCompile; - - private int traversalParallelism = -1; - - public QueryContextBuilder setEngineContext(QueryEngine engineContext) { - this.engineContext = engineContext; - return this; - } - - public QueryContextBuilder setCompile(boolean compile) { - isCompile = compile; - return this; - } - - public QueryContextBuilder setTraversalParallelism(int traversalParallelism) { - this.traversalParallelism = traversalParallelism; - return this; - } - - public QueryContext build() { - QueryContext context = new QueryContext(engineContext, isCompile); - context.setTraversalParallelism(traversalParallelism); - return context; - } + public QueryContext build() { + QueryContext context = new QueryContext(engineContext, isCompile); + context.setTraversalParallelism(traversalParallelism); + return context; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/QueryEngine.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/QueryEngine.java index feca664be..046231f8c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/QueryEngine.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/QueryEngine.java @@ -21,6 +21,7 @@ import java.util.Collection; import java.util.Map; + import org.apache.geaflow.api.function.io.SourceFunction; import org.apache.geaflow.api.pdata.stream.window.PWindowSource; import org.apache.geaflow.dsl.common.data.Row; @@ -29,20 +30,18 @@ import org.apache.geaflow.dsl.schema.GeaFlowTable; import org.apache.geaflow.pipeline.job.IPipelineJobContext; -/** - * Interface for a query engine. - */ +/** Interface for a query engine. */ public interface QueryEngine { - Map getConfig(); + Map getConfig(); - IPipelineJobContext getContext(); + IPipelineJobContext getContext(); - RuntimeTable createRuntimeTable(QueryContext context, GeaFlowTable table, Expression pushFilter); + RuntimeTable createRuntimeTable(QueryContext context, GeaFlowTable table, Expression pushFilter); - RuntimeTable createRuntimeTable(QueryContext context, Collection rows); + RuntimeTable createRuntimeTable(QueryContext context, Collection rows); - PWindowSource createRuntimeTable(QueryContext context, SourceFunction sourceFunction); + PWindowSource createRuntimeTable(QueryContext context, SourceFunction sourceFunction); - RuntimeGraph createRuntimeGraph(QueryContext context, GeaFlowGraph graph); + RuntimeGraph createRuntimeGraph(QueryContext context, GeaFlowGraph graph); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/QueryResult.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/QueryResult.java index 9ce6c78ce..31046387b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/QueryResult.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/QueryResult.java @@ -21,69 +21,69 @@ import java.io.Serializable; import java.util.List; + import org.apache.geaflow.dsl.common.data.Row; public class QueryResult implements Serializable { - private List results; + private List results; - private RDataView dataView; + private RDataView dataView; - private boolean success; + private boolean success; - private String errorMsg; + private String errorMsg; - public QueryResult(List results, RDataView dataView, boolean success, String errorMsg) { - this.results = results; - this.dataView = dataView; - this.success = success; - this.errorMsg = errorMsg; - } + public QueryResult(List results, RDataView dataView, boolean success, String errorMsg) { + this.results = results; + this.dataView = dataView; + this.success = success; + this.errorMsg = errorMsg; + } - public QueryResult(List results) { - this(results, null, true, null); - } + public QueryResult(List results) { + this(results, null, true, null); + } - public QueryResult(boolean success) { - this(null, null, success, null); - } + public QueryResult(boolean success) { + this(null, null, success, null); + } - public QueryResult(RDataView dataView) { - this(null, dataView, true, null); - } + public QueryResult(RDataView dataView) { + this(null, dataView, true, null); + } - public QueryResult() { - } + public QueryResult() {} - public List getResults() { - return results; - } + public List getResults() { + return results; + } - public boolean isSuccess() { - return success; - } + public boolean isSuccess() { + return success; + } - public String getErrorMsg() { - return errorMsg; - } + public String getErrorMsg() { + return errorMsg; + } - public void setResults(List results) { - this.results = results; - } + public void setResults(List results) { + this.results = results; + } - public void setSuccess(boolean success) { - this.success = success; - } + public void setSuccess(boolean success) { + this.success = success; + } - public void setErrorMsg(String errorMsg) { - this.errorMsg = errorMsg; - } + public void setErrorMsg(String errorMsg) { + this.errorMsg = errorMsg; + } - public RDataView getDataView() { - return dataView; - } + public RDataView getDataView() { + return dataView; + } - public void setDataView(RDataView dataView) { - this.dataView = dataView; - } + public void setDataView(RDataView dataView) { + this.dataView = dataView; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/RDataView.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/RDataView.java index 8234d09cf..e167eb187 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/RDataView.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/RDataView.java @@ -20,23 +20,22 @@ package org.apache.geaflow.dsl.runtime; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; -/** - * Represents a view of the data at the runtime. - */ +/** Represents a view of the data at the runtime. */ public interface RDataView { - T getPlan(); + T getPlan(); - ViewType getType(); + ViewType getType(); - List take(IType type); + List take(IType type); - enum ViewType { - TABLE, - GRAPH, - SINK - } + enum ViewType { + TABLE, + GRAPH, + SINK + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/RuntimeGraph.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/RuntimeGraph.java index 6c88129b8..eaa6a9c6d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/RuntimeGraph.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/RuntimeGraph.java @@ -20,26 +20,27 @@ package org.apache.geaflow.dsl.runtime; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.rel.GraphAlgorithm; import org.apache.geaflow.dsl.rel.GraphMatch; /** - * The runtime graph view which mapping logical graph operator to the runtime - * representation of the underlying engine. + * The runtime graph view which mapping logical graph operator to the runtime representation of the + * underlying engine. */ public interface RuntimeGraph extends RDataView { - List take(IType type); + List take(IType type); - RuntimeGraph traversal(GraphMatch graphMatch); + RuntimeGraph traversal(GraphMatch graphMatch); - RuntimeTable getPathTable(); + RuntimeTable getPathTable(); - RuntimeTable runAlgorithm(GraphAlgorithm graphAlgorithm); + RuntimeTable runAlgorithm(GraphAlgorithm graphAlgorithm); - default ViewType getType() { - return ViewType.GRAPH; - } + default ViewType getType() { + return ViewType.GRAPH; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/RuntimeTable.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/RuntimeTable.java index 3ef16c6b3..e729432b9 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/RuntimeTable.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/RuntimeTable.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.runtime.function.table.AggFunction; @@ -33,32 +34,32 @@ import org.apache.geaflow.dsl.schema.GeaFlowTable; /** - * The runtime table view which mapping SQL function to the runtime - * representation of the underlying engine. + * The runtime table view which mapping SQL function to the runtime representation of the underlying + * engine. */ public interface RuntimeTable extends RDataView { - RuntimeTable project(ProjectFunction function); + RuntimeTable project(ProjectFunction function); - RuntimeTable filter(WhereFunction function); + RuntimeTable filter(WhereFunction function); - RuntimeTable join(RuntimeTable other, JoinTableFunction function); + RuntimeTable join(RuntimeTable other, JoinTableFunction function); - RuntimeTable aggregate(GroupByFunction groupByFunction, AggFunction aggFunction); + RuntimeTable aggregate(GroupByFunction groupByFunction, AggFunction aggFunction); - RuntimeTable union(RuntimeTable other); + RuntimeTable union(RuntimeTable other); - RuntimeTable orderBy(OrderByFunction function); + RuntimeTable orderBy(OrderByFunction function); - RuntimeTable correlate(CorrelateFunction function); + RuntimeTable correlate(CorrelateFunction function); - SinkDataView write(GeaFlowTable table); + SinkDataView write(GeaFlowTable table); - SinkDataView write(GeaFlowGraph graph, QueryContext queryContext); + SinkDataView write(GeaFlowGraph graph, QueryContext queryContext); - List take(IType type); + List take(IType type); - default ViewType getType() { - return ViewType.TABLE; - } + default ViewType getType() { + return ViewType.TABLE; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/SinkDataView.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/SinkDataView.java index f6f54ac67..5a0f31b03 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/SinkDataView.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/SinkDataView.java @@ -21,7 +21,7 @@ public interface SinkDataView extends RDataView { - default ViewType getType() { - return ViewType.SINK; - } + default ViewType getType() { + return ViewType.SINK; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/AlterGraphCommand.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/AlterGraphCommand.java index 7b1a7359c..b13411377 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/AlterGraphCommand.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/AlterGraphCommand.java @@ -26,19 +26,19 @@ public class AlterGraphCommand implements IQueryCommand { - private final SqlAlterGraph alterGraph; + private final SqlAlterGraph alterGraph; - public AlterGraphCommand(SqlAlterGraph alterGraph) { - this.alterGraph = alterGraph; - } + public AlterGraphCommand(SqlAlterGraph alterGraph) { + this.alterGraph = alterGraph; + } - @Override - public QueryResult execute(QueryContext context) { - return null; - } + @Override + public QueryResult execute(QueryContext context) { + return null; + } - @Override - public SqlNode getSqlNode() { - return alterGraph; - } + @Override + public SqlNode getSqlNode() { + return alterGraph; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/CreateFunctionCommand.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/CreateFunctionCommand.java index 23b68ce2b..4f8d4e7aa 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/CreateFunctionCommand.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/CreateFunctionCommand.java @@ -30,26 +30,26 @@ public class CreateFunctionCommand implements IQueryCommand { - private static final Logger LOGGER = LoggerFactory.getLogger(CreateFunctionCommand.class); - - private final SqlCreateFunction createFunction; - - public CreateFunctionCommand(SqlCreateFunction createFunction) { - this.createFunction = createFunction; - } - - @Override - public QueryResult execute(QueryContext context) { - GQLContext gqlContext = context.getGqlContext(); - GeaFlowFunction function = GeaFlowFunction.toFunction(createFunction); - // register function to catalog. - gqlContext.registerFunction(function); - LOGGER.info("Success to create function: \n{}", function); - return new QueryResult(true); - } - - @Override - public SqlNode getSqlNode() { - return createFunction; - } + private static final Logger LOGGER = LoggerFactory.getLogger(CreateFunctionCommand.class); + + private final SqlCreateFunction createFunction; + + public CreateFunctionCommand(SqlCreateFunction createFunction) { + this.createFunction = createFunction; + } + + @Override + public QueryResult execute(QueryContext context) { + GQLContext gqlContext = context.getGqlContext(); + GeaFlowFunction function = GeaFlowFunction.toFunction(createFunction); + // register function to catalog. + gqlContext.registerFunction(function); + LOGGER.info("Success to create function: \n{}", function); + return new QueryResult(true); + } + + @Override + public SqlNode getSqlNode() { + return createFunction; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/CreateGraphCommand.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/CreateGraphCommand.java index 9efc49d04..cdcb9415c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/CreateGraphCommand.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/CreateGraphCommand.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; + import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelRecordType; import org.apache.calcite.schema.Table; @@ -48,112 +49,127 @@ public class CreateGraphCommand implements IQueryCommand { - private static final Logger LOGGER = LoggerFactory.getLogger(CreateGraphCommand.class); - - private final SqlCreateGraph createGraph; - - public CreateGraphCommand(SqlCreateGraph createGraph) { - this.createGraph = createGraph; - } + private static final Logger LOGGER = LoggerFactory.getLogger(CreateGraphCommand.class); - @Override - public QueryResult execute(QueryContext context) { - GQLContext gContext = context.getGqlContext(); - GeaFlowGraph graph = gContext.convertToGraph(createGraph); - gContext.registerGraph(graph); - processUsing(graph, context); - context.addGraph(graph.getName(), graph); - LOGGER.info("Succeed to create graph: {}.", graph); - return new QueryResult(true); - } + private final SqlCreateGraph createGraph; + public CreateGraphCommand(SqlCreateGraph createGraph) { + this.createGraph = createGraph; + } - private void processUsing(GeaFlowGraph graph, QueryContext context) { + @Override + public QueryResult execute(QueryContext context) { + GQLContext gContext = context.getGqlContext(); + GeaFlowGraph graph = gContext.convertToGraph(createGraph); + gContext.registerGraph(graph); + processUsing(graph, context); + context.addGraph(graph.getName(), graph); + LOGGER.info("Succeed to create graph: {}.", graph); + return new QueryResult(true); + } - //Graph first time creation will trigger insert operations - if (!QueryUtil.isGraphExists(graph, graph.getConfigWithGlobal(context.getGlobalConf()))) { - //Convert graph construction using tables to equivalent insert statement - Map vertexEdgeName2UsingTableNameMap = graph.getUsingTables(); - List graphElements = new ArrayList<>(graph.getVertexTables()); - graphElements.addAll(graph.getEdgeTables()); - RelDataTypeFactory factory = context.getGqlContext().getTypeFactory(); - for (GraphElementTable tbl : graphElements) { - if (vertexEdgeName2UsingTableNameMap.containsKey(tbl.getTypeName())) { - String usingTable = vertexEdgeName2UsingTableNameMap.get(tbl.getTypeName()); - Table table = context.getGqlContext().getCatalog().getTable( - context.getGqlContext().getCurrentInstance(), usingTable); - assert table instanceof GeaFlowTable; - SqlCall relatedSqlCall = null; - List createGraphElements = - new ArrayList<>(createGraph.getVertices().getList()); - createGraphElements.addAll(createGraph.getEdges().getList()); - for (SqlNode node : createGraphElements) { - if (node instanceof SqlVertexUsing - && ((SqlVertexUsing) node).getName().getSimple().equals(tbl.getTypeName())) { - relatedSqlCall = ((SqlVertexUsing) node); - } else if (node instanceof SqlEdgeUsing - && ((SqlEdgeUsing) node).getName().getSimple().equals(tbl.getTypeName())) { - relatedSqlCall = ((SqlEdgeUsing) node); - } - } - assert relatedSqlCall != null; - RelRecordType reorderType; - if (tbl instanceof VertexTable) { - VertexTable vertexTable = (VertexTable) tbl; - reorderType = VertexRecordType.createVertexType( - vertexTable.getRowType(factory).getFieldList(), - ((SqlVertexUsing) relatedSqlCall).getId().getSimple(), - factory - ); - } else { - EdgeTable edgeTable = (EdgeTable) tbl; - SqlEdgeUsing edgeUsing = ((SqlEdgeUsing) relatedSqlCall); - reorderType = EdgeRecordType.createEdgeType( - edgeTable.getRowType(factory).getFieldList(), - edgeUsing.getSourceId().getSimple(), - edgeUsing.getTargetId().getSimple(), - edgeUsing.getTimeField() == null ? null : edgeUsing.getTimeField().getSimple(), - factory - ); - } + private void processUsing(GeaFlowGraph graph, QueryContext context) { - SqlNode insertSqlNode = createUsingGraphInsert(createGraph.getParserPosition(), - graph, tbl.getTypeName(), usingTable, reorderType); - QueryCommand insertCommand = new QueryCommand(insertSqlNode); - insertCommand.execute(context); - } + // Graph first time creation will trigger insert operations + if (!QueryUtil.isGraphExists(graph, graph.getConfigWithGlobal(context.getGlobalConf()))) { + // Convert graph construction using tables to equivalent insert statement + Map vertexEdgeName2UsingTableNameMap = graph.getUsingTables(); + List graphElements = new ArrayList<>(graph.getVertexTables()); + graphElements.addAll(graph.getEdgeTables()); + RelDataTypeFactory factory = context.getGqlContext().getTypeFactory(); + for (GraphElementTable tbl : graphElements) { + if (vertexEdgeName2UsingTableNameMap.containsKey(tbl.getTypeName())) { + String usingTable = vertexEdgeName2UsingTableNameMap.get(tbl.getTypeName()); + Table table = + context + .getGqlContext() + .getCatalog() + .getTable(context.getGqlContext().getCurrentInstance(), usingTable); + assert table instanceof GeaFlowTable; + SqlCall relatedSqlCall = null; + List createGraphElements = new ArrayList<>(createGraph.getVertices().getList()); + createGraphElements.addAll(createGraph.getEdges().getList()); + for (SqlNode node : createGraphElements) { + if (node instanceof SqlVertexUsing + && ((SqlVertexUsing) node).getName().getSimple().equals(tbl.getTypeName())) { + relatedSqlCall = ((SqlVertexUsing) node); + } else if (node instanceof SqlEdgeUsing + && ((SqlEdgeUsing) node).getName().getSimple().equals(tbl.getTypeName())) { + relatedSqlCall = ((SqlEdgeUsing) node); } - } else { - LOGGER.warn("The graph: {} already exists, skip exec using.", graph.getName()); + } + assert relatedSqlCall != null; + RelRecordType reorderType; + if (tbl instanceof VertexTable) { + VertexTable vertexTable = (VertexTable) tbl; + reorderType = + VertexRecordType.createVertexType( + vertexTable.getRowType(factory).getFieldList(), + ((SqlVertexUsing) relatedSqlCall).getId().getSimple(), + factory); + } else { + EdgeTable edgeTable = (EdgeTable) tbl; + SqlEdgeUsing edgeUsing = ((SqlEdgeUsing) relatedSqlCall); + reorderType = + EdgeRecordType.createEdgeType( + edgeTable.getRowType(factory).getFieldList(), + edgeUsing.getSourceId().getSimple(), + edgeUsing.getTargetId().getSimple(), + edgeUsing.getTimeField() == null ? null : edgeUsing.getTimeField().getSimple(), + factory); + } + + SqlNode insertSqlNode = + createUsingGraphInsert( + createGraph.getParserPosition(), + graph, + tbl.getTypeName(), + usingTable, + reorderType); + QueryCommand insertCommand = new QueryCommand(insertSqlNode); + insertCommand.execute(context); } + } + } else { + LOGGER.warn("The graph: {} already exists, skip exec using.", graph.getName()); } + } - private static SqlNode createUsingGraphInsert(SqlParserPos pos, - GeaFlowGraph graph, - String graphElementName, - String usingTable, - RelRecordType reorderType) { - List elementNames = new ArrayList<>(); - elementNames.add(graph.getName()); - elementNames.add(graphElementName); - List columns = reorderType.getFieldList().stream() + private static SqlNode createUsingGraphInsert( + SqlParserPos pos, + GeaFlowGraph graph, + String graphElementName, + String usingTable, + RelRecordType reorderType) { + List elementNames = new ArrayList<>(); + elementNames.add(graph.getName()); + elementNames.add(graphElementName); + List columns = + reorderType.getFieldList().stream() .filter(f -> !f.getName().equals(GraphSchema.LABEL_FIELD_NAME)) .map(f -> new SqlIdentifier(f.getName(), pos)) .collect(Collectors.toList()); - return new SqlInsert( + return new SqlInsert( + pos, + SqlNodeList.EMPTY, + new SqlIdentifier(elementNames, pos), + new SqlSelect( pos, - SqlNodeList.EMPTY, - new SqlIdentifier(elementNames, pos), - new SqlSelect(pos, null, - new SqlNodeList(columns, pos), - new SqlIdentifier(usingTable, pos), - null, null, null, null, null, null, null), - null - ); - } + null, + new SqlNodeList(columns, pos), + new SqlIdentifier(usingTable, pos), + null, + null, + null, + null, + null, + null, + null), + null); + } - @Override - public SqlNode getSqlNode() { - return createGraph; - } + @Override + public SqlNode getSqlNode() { + return createGraph; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/CreateTableCommand.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/CreateTableCommand.java index 3a78f91c1..c2eff6258 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/CreateTableCommand.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/CreateTableCommand.java @@ -30,26 +30,26 @@ public class CreateTableCommand implements IQueryCommand { - private static final Logger LOGGER = LoggerFactory.getLogger(CreateTableCommand.class); - - private final SqlCreateTable createTable; - - public CreateTableCommand(SqlCreateTable createTable) { - this.createTable = createTable; - } - - @Override - public QueryResult execute(QueryContext context) { - GQLContext gqlContext = context.getGqlContext(); - GeaFlowTable table = gqlContext.convertToTable(createTable); - // register table to catalog. - gqlContext.registerTable(table); - LOGGER.info("Success to create table: \n{}", table); - return new QueryResult(true); - } - - @Override - public SqlNode getSqlNode() { - return createTable; - } + private static final Logger LOGGER = LoggerFactory.getLogger(CreateTableCommand.class); + + private final SqlCreateTable createTable; + + public CreateTableCommand(SqlCreateTable createTable) { + this.createTable = createTable; + } + + @Override + public QueryResult execute(QueryContext context) { + GQLContext gqlContext = context.getGqlContext(); + GeaFlowTable table = gqlContext.convertToTable(createTable); + // register table to catalog. + gqlContext.registerTable(table); + LOGGER.info("Success to create table: \n{}", table); + return new QueryResult(true); + } + + @Override + public SqlNode getSqlNode() { + return createTable; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/CreateViewCommand.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/CreateViewCommand.java index 549f3ab3c..a4aa75538 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/CreateViewCommand.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/CreateViewCommand.java @@ -30,32 +30,32 @@ public class CreateViewCommand implements IQueryCommand { - private static final Logger LOGGER = LoggerFactory.getLogger(CreateViewCommand.class); - - private final SqlCreateView createView; - - public CreateViewCommand(SqlCreateView createView) { - this.createView = createView; - } - - @Override - public QueryResult execute(QueryContext context) { - GQLContext gqlContext = context.getGqlContext(); - GeaFlowView view = gqlContext.convertToView(createView); - // register view to catalog. - gqlContext.registerView(view); - LOGGER.info("Success to create view: \n{}", view); - - IQueryCommand command = context.getCommand(createView.getSubQuery()); - boolean preIsCompile = context.setCompile(true); - QueryResult viewResult = command.execute(context); - context.putViewDataView(view.getName(), viewResult.getDataView()); - context.setCompile(preIsCompile); - return new QueryResult(true); - } - - @Override - public SqlNode getSqlNode() { - return createView; - } + private static final Logger LOGGER = LoggerFactory.getLogger(CreateViewCommand.class); + + private final SqlCreateView createView; + + public CreateViewCommand(SqlCreateView createView) { + this.createView = createView; + } + + @Override + public QueryResult execute(QueryContext context) { + GQLContext gqlContext = context.getGqlContext(); + GeaFlowView view = gqlContext.convertToView(createView); + // register view to catalog. + gqlContext.registerView(view); + LOGGER.info("Success to create view: \n{}", view); + + IQueryCommand command = context.getCommand(createView.getSubQuery()); + boolean preIsCompile = context.setCompile(true); + QueryResult viewResult = command.execute(context); + context.putViewDataView(view.getName(), viewResult.getDataView()); + context.setCompile(preIsCompile); + return new QueryResult(true); + } + + @Override + public SqlNode getSqlNode() { + return createView; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/DescGraphCommand.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/DescGraphCommand.java index 95d86f1b4..31e6df52e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/DescGraphCommand.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/DescGraphCommand.java @@ -26,19 +26,19 @@ public class DescGraphCommand implements IQueryCommand { - private final SqlDescGraph descGraph; + private final SqlDescGraph descGraph; - public DescGraphCommand(SqlDescGraph descGraph) { - this.descGraph = descGraph; - } + public DescGraphCommand(SqlDescGraph descGraph) { + this.descGraph = descGraph; + } - @Override - public QueryResult execute(QueryContext context) { - return null; - } + @Override + public QueryResult execute(QueryContext context) { + return null; + } - @Override - public SqlNode getSqlNode() { - return descGraph; - } + @Override + public SqlNode getSqlNode() { + return descGraph; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/DropGraphCommand.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/DropGraphCommand.java index ac785b172..a24bf1f97 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/DropGraphCommand.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/DropGraphCommand.java @@ -26,19 +26,19 @@ public class DropGraphCommand implements IQueryCommand { - private final SqlDropGraph dropGraph; + private final SqlDropGraph dropGraph; - public DropGraphCommand(SqlDropGraph dropGraph) { - this.dropGraph = dropGraph; - } + public DropGraphCommand(SqlDropGraph dropGraph) { + this.dropGraph = dropGraph; + } - @Override - public QueryResult execute(QueryContext context) { - return null; - } + @Override + public QueryResult execute(QueryContext context) { + return null; + } - @Override - public SqlNode getSqlNode() { - return dropGraph; - } + @Override + public SqlNode getSqlNode() { + return dropGraph; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/IQueryCommand.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/IQueryCommand.java index c284dfaf3..f0f8697c7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/IQueryCommand.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/IQueryCommand.java @@ -25,7 +25,7 @@ public interface IQueryCommand { - QueryResult execute(QueryContext context); + QueryResult execute(QueryContext context); - SqlNode getSqlNode(); + SqlNode getSqlNode(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/QueryCommand.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/QueryCommand.java index 7b4101fb4..7df72362d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/QueryCommand.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/QueryCommand.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.rel.RelNode; import org.apache.calcite.sql.SqlKind; @@ -42,63 +43,67 @@ public class QueryCommand implements IQueryCommand { - private static final Logger LOGGER = LoggerFactory.getLogger(QueryCommand.class); + private static final Logger LOGGER = LoggerFactory.getLogger(QueryCommand.class); - private final SqlNode query; + private final SqlNode query; - public QueryCommand(SqlNode query) { - this.query = query; - } + public QueryCommand(SqlNode query) { + this.query = query; + } + + @SuppressWarnings("unchecked") + @Override + public QueryResult execute(QueryContext context) { + long startTs = System.currentTimeMillis(); + LOGGER.info("Execute query:\n{}", query); + + GQLContext gqlContext = context.getGqlContext(); + SqlNode validateQuery = gqlContext.validate(query); + + RelNode logicalPlan = gqlContext.toRelNode(validateQuery); + LOGGER.info("Convert sql to logical plan:\n{}", RelOptUtil.toString(logicalPlan)); - @SuppressWarnings("unchecked") - @Override - public QueryResult execute(QueryContext context) { - long startTs = System.currentTimeMillis(); - LOGGER.info("Execute query:\n{}", query); - - GQLContext gqlContext = context.getGqlContext(); - SqlNode validateQuery = gqlContext.validate(query); - - RelNode logicalPlan = gqlContext.toRelNode(validateQuery); - LOGGER.info("Convert sql to logical plan:\n{}", RelOptUtil.toString(logicalPlan)); - - RelNode optimizedNode = gqlContext.optimize(context.getLogicalRules(), logicalPlan); - LOGGER.info("After optimize logical plan:\n{}", RelOptUtil.toString(optimizedNode)); - - PhysicRelNode physicNode = (PhysicRelNode) gqlContext.transform(ConvertRules.TRANSFORM_RULES, - optimizedNode, optimizedNode.getTraitSet().plus(PhysicConvention.INSTANCE).simplify()); - LOGGER.info("Convert to physic plan:\n{}", RelOptUtil.toString(physicNode)); - - physicNode = (PhysicRelNode) context.getPathAnalyzer().analyze(physicNode); - LOGGER.info("After path analyzer:\n{}", RelOptUtil.toString(physicNode)); - RDataView dataView = physicNode.translate(context); - context.setCurrentResultType(physicNode.getRowType()); - - if (context.isCompile()) { - long compileSpend = System.currentTimeMillis() - startTs; - LOGGER.info("Finish compile query, spend:{}ms", compileSpend); - return new QueryResult(dataView); - } - - if (query.getKind() != SqlKind.INSERT) { - List rows = (List) dataView.take(SqlTypeUtil.convertType(physicNode.getRowType())); - RowDecoder rowDecoder = - new DefaultRowDecoder((StructType) SqlTypeUtil.convertType(physicNode.getRowType())); - List decodeRows = new ArrayList<>(rows.size()); - for (Row row : rows) { - decodeRows.add(rowDecoder.decode(row)); - } - long spend = System.currentTimeMillis() - startTs; - LOGGER.info("Finish execute query, take records: {}, spend: {}ms", rows.size(), spend); - return new QueryResult(decodeRows); - } - long spend = System.currentTimeMillis() - startTs; - LOGGER.info("Finish execute query, spend: {}ms", spend); - return new QueryResult(true); + RelNode optimizedNode = gqlContext.optimize(context.getLogicalRules(), logicalPlan); + LOGGER.info("After optimize logical plan:\n{}", RelOptUtil.toString(optimizedNode)); + + PhysicRelNode physicNode = + (PhysicRelNode) + gqlContext.transform( + ConvertRules.TRANSFORM_RULES, + optimizedNode, + optimizedNode.getTraitSet().plus(PhysicConvention.INSTANCE).simplify()); + LOGGER.info("Convert to physic plan:\n{}", RelOptUtil.toString(physicNode)); + + physicNode = (PhysicRelNode) context.getPathAnalyzer().analyze(physicNode); + LOGGER.info("After path analyzer:\n{}", RelOptUtil.toString(physicNode)); + RDataView dataView = physicNode.translate(context); + context.setCurrentResultType(physicNode.getRowType()); + + if (context.isCompile()) { + long compileSpend = System.currentTimeMillis() - startTs; + LOGGER.info("Finish compile query, spend:{}ms", compileSpend); + return new QueryResult(dataView); } - @Override - public SqlNode getSqlNode() { - return query; + if (query.getKind() != SqlKind.INSERT) { + List rows = (List) dataView.take(SqlTypeUtil.convertType(physicNode.getRowType())); + RowDecoder rowDecoder = + new DefaultRowDecoder((StructType) SqlTypeUtil.convertType(physicNode.getRowType())); + List decodeRows = new ArrayList<>(rows.size()); + for (Row row : rows) { + decodeRows.add(rowDecoder.decode(row)); + } + long spend = System.currentTimeMillis() - startTs; + LOGGER.info("Finish execute query, take records: {}, spend: {}ms", rows.size(), spend); + return new QueryResult(decodeRows); } + long spend = System.currentTimeMillis() - startTs; + LOGGER.info("Finish execute query, spend: {}ms", spend); + return new QueryResult(true); + } + + @Override + public SqlNode getSqlNode() { + return query; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/SetCommand.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/SetCommand.java index 3e8f54537..1fd5307fe 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/SetCommand.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/SetCommand.java @@ -29,25 +29,25 @@ public class SetCommand implements IQueryCommand { - private static final Logger LOGGER = LoggerFactory.getLogger(SetCommand.class); - - private final SqlSetOption sqlSetOption; - - public SetCommand(SqlSetOption sqlSetOption) { - this.sqlSetOption = sqlSetOption; - } - - @Override - public QueryResult execute(QueryContext context) { - String key = StringLiteralUtil.toJavaString(sqlSetOption.getName()); - String value = StringLiteralUtil.toJavaString(sqlSetOption.getValue()); - context.putSetOption(key, value); - LOGGER.info("set '{}' to '{}'", key, value); - return new QueryResult(true); - } - - @Override - public SqlNode getSqlNode() { - return sqlSetOption; - } + private static final Logger LOGGER = LoggerFactory.getLogger(SetCommand.class); + + private final SqlSetOption sqlSetOption; + + public SetCommand(SqlSetOption sqlSetOption) { + this.sqlSetOption = sqlSetOption; + } + + @Override + public QueryResult execute(QueryContext context) { + String key = StringLiteralUtil.toJavaString(sqlSetOption.getName()); + String value = StringLiteralUtil.toJavaString(sqlSetOption.getValue()); + context.putSetOption(key, value); + LOGGER.info("set '{}' to '{}'", key, value); + return new QueryResult(true); + } + + @Override + public SqlNode getSqlNode() { + return sqlSetOption; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/UseGraphCommand.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/UseGraphCommand.java index 9e27f5c58..1f81fbc07 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/UseGraphCommand.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/UseGraphCommand.java @@ -26,20 +26,20 @@ public class UseGraphCommand implements IQueryCommand { - private final SqlUseGraph useGraph; + private final SqlUseGraph useGraph; - public UseGraphCommand(SqlUseGraph useGraph) { - this.useGraph = useGraph; - } + public UseGraphCommand(SqlUseGraph useGraph) { + this.useGraph = useGraph; + } - @Override - public QueryResult execute(QueryContext context) { - context.getGqlContext().setCurrentGraph(useGraph.getGraph()); - return new QueryResult(true); - } + @Override + public QueryResult execute(QueryContext context) { + context.getGqlContext().setCurrentGraph(useGraph.getGraph()); + return new QueryResult(true); + } - @Override - public SqlNode getSqlNode() { - return useGraph; - } + @Override + public SqlNode getSqlNode() { + return useGraph; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/UseInstanceCommand.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/UseInstanceCommand.java index ea63eba71..e10baaf28 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/UseInstanceCommand.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/command/UseInstanceCommand.java @@ -27,24 +27,24 @@ public class UseInstanceCommand implements IQueryCommand { - private final SqlUseInstance useInstance; + private final SqlUseInstance useInstance; - public UseInstanceCommand(SqlUseInstance useInstance) { - this.useInstance = useInstance; - } + public UseInstanceCommand(SqlUseInstance useInstance) { + this.useInstance = useInstance; + } - @Override - public QueryResult execute(QueryContext context) { - String instance = useInstance.getInstance().getSimple(); - if (!context.getGqlContext().getCatalog().isInstanceExists(instance)) { - throw new ObjectNotExistException("Instance: '" + instance + "' is not exists"); - } - context.getGqlContext().setCurrentInstance(useInstance.getInstance().getSimple()); - return new QueryResult(true); + @Override + public QueryResult execute(QueryContext context) { + String instance = useInstance.getInstance().getSimple(); + if (!context.getGqlContext().getCatalog().isInstanceExists(instance)) { + throw new ObjectNotExistException("Instance: '" + instance + "' is not exists"); } + context.getGqlContext().setCurrentInstance(useInstance.getInstance().getSimple()); + return new QueryResult(true); + } - @Override - public SqlNode getSqlNode() { - return useInstance; - } + @Override + public SqlNode getSqlNode() { + return useInstance; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/AbstractTraversalRuntimeContext.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/AbstractTraversalRuntimeContext.java index d067875b4..687f39a60 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/AbstractTraversalRuntimeContext.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/AbstractTraversalRuntimeContext.java @@ -26,6 +26,7 @@ import java.util.Objects; import java.util.Set; import java.util.Stack; + import org.apache.geaflow.api.graph.function.aggregate.VertexCentricAggContextFunction.VertexCentricAggContext; import org.apache.geaflow.api.graph.function.vc.VertexCentricTraversalFunction.TraversalEdgeQuery; import org.apache.geaflow.api.graph.function.vc.VertexCentricTraversalFunction.TraversalVertexQuery; @@ -56,288 +57,290 @@ public abstract class AbstractTraversalRuntimeContext implements TraversalRuntimeContext { - private DagTopologyGroup topology; + private DagTopologyGroup topology; - private RowVertex vertex; + private RowVertex vertex; - private long inputOpId; + private long inputOpId; - private long currentOpId; + private long currentOpId; - private MessageBox messageBox; + private MessageBox messageBox; - private ParameterRequest request; + private ParameterRequest request; - protected TraversalVertexQuery vertexQuery; + protected TraversalVertexQuery vertexQuery; - protected TraversalEdgeQuery edgeQuery; + protected TraversalEdgeQuery edgeQuery; - // opId -> CallContext stack. - private final Map> callStacks = new HashMap<>(); + // opId -> CallContext stack. + private final Map> callStacks = new HashMap<>(); - private final Set callRequestIds = new HashSet<>(); + private final Set callRequestIds = new HashSet<>(); - private final Map vertexId2AppendFields = new HashMap<>(); + private final Map vertexId2AppendFields = new HashMap<>(); - protected VertexCentricAggContext aggContext; + protected VertexCentricAggContext aggContext; - public AbstractTraversalRuntimeContext(TraversalVertexQuery vertexQuery, - TraversalEdgeQuery edgeQuery) { - this.vertexQuery = vertexQuery; - this.edgeQuery = edgeQuery; - } + public AbstractTraversalRuntimeContext( + TraversalVertexQuery vertexQuery, TraversalEdgeQuery edgeQuery) { + this.vertexQuery = vertexQuery; + this.edgeQuery = edgeQuery; + } - public AbstractTraversalRuntimeContext() { + public AbstractTraversalRuntimeContext() {} - } + @Override + public DagTopologyGroup getTopology() { + return topology; + } - @Override - public DagTopologyGroup getTopology() { - return topology; - } + @Override + public void setTopology(DagTopologyGroup topology) { + this.topology = topology; + } - @Override - public void setTopology(DagTopologyGroup topology) { - this.topology = topology; - } + @Override + public RowVertex getVertex() { + return vertex; + } - @Override - public RowVertex getVertex() { - return vertex; - } + @Override + public void setVertex(RowVertex vertex) { + this.vertex = vertex; + vertexQuery.withId(vertex.getId()); + edgeQuery.withId(vertex.getId()); + } - @Override - public void setVertex(RowVertex vertex) { - this.vertex = vertex; - vertexQuery.withId(vertex.getId()); - edgeQuery.withId(vertex.getId()); + @Override + public Object getRequestId() { + if (request != null) { + return request.getRequestId(); } + return null; + } - @Override - public Object getRequestId() { - if (request != null) { - return request.getRequestId(); - } - return null; - } + @Override + public void setCurrentOpId(long operatorId) { + this.currentOpId = operatorId; + } - @Override - public void setCurrentOpId(long operatorId) { - this.currentOpId = operatorId; - } + @Override + public long getCurrentOpId() { + return currentOpId; + } - @Override - public long getCurrentOpId() { - return currentOpId; - } + @Override + public void setRequest(ParameterRequest parameterRequest) { + this.request = parameterRequest; + } - @Override - public void setRequest(ParameterRequest parameterRequest) { - this.request = parameterRequest; - } + @Override + public ParameterRequest getRequest() { + return request; + } - @Override - public ParameterRequest getRequest() { - return request; - } + @Override + public void setMessageBox(MessageBox messageBox) { + this.messageBox = messageBox; + } - @Override - public void setMessageBox(MessageBox messageBox) { - this.messageBox = messageBox; - } + @SuppressWarnings("unchecked") + @Override + public M getMessage(MessageType messageType) { + if (messageBox != null) { + M message = messageBox.getMessage(currentOpId, messageType); + if (message instanceof RequestIsolationMessage && getRequestId() != null) { + message = (M) ((RequestIsolationMessage) message).getMessageByRequestId(getRequestId()); + } - @SuppressWarnings("unchecked") - @Override - public M getMessage(MessageType messageType) { - if (messageBox != null) { - M message = messageBox.getMessage(currentOpId, messageType); - if (message instanceof RequestIsolationMessage && getRequestId() != null) { - message = (M) ((RequestIsolationMessage) message).getMessageByRequestId(getRequestId()); - } + if (messageType == MessageType.PATH) { + // Fetch path from the calling stack + ITreePath pathMessage = (ITreePath) message; + if (callStacks.containsKey(currentOpId)) { + Stack callStack = callStacks.get(currentOpId); - if (messageType == MessageType.PATH) { - // Fetch path from the calling stack - ITreePath pathMessage = (ITreePath) message; - if (callStacks.containsKey(currentOpId)) { - Stack callStack = callStacks.get(currentOpId); - - if (!callStack.isEmpty()) { - CallContext callContext = callStack.peek(); - ITreePath stashTreePath = callContext.getPath(getRequestId(), getVertex().getId()); - if (pathMessage != null) { - return (M) stashTreePath.merge(pathMessage); - } - return (M) stashTreePath; - } - } - } else if (messageType == MessageType.PARAMETER_REQUEST) { - // Fetch request from the calling stack - ParameterRequestMessage requestMessage = (ParameterRequestMessage) message; - if (requestMessage == null) { - requestMessage = new ParameterRequestMessage(); - } - Stack callStack = callStacks.get(currentOpId); - if (callStack != null && !callStack.isEmpty()) { - CallContext callContext = callStack.peek(); - List stashRequests = callContext.getRequests(getVertex().getId()); - if (stashRequests != null) { - for (ParameterRequest request : stashRequests) { - requestMessage.addRequest(request); - } - } - return (M) requestMessage; - } + if (!callStack.isEmpty()) { + CallContext callContext = callStack.peek(); + ITreePath stashTreePath = callContext.getPath(getRequestId(), getVertex().getId()); + if (pathMessage != null) { + return (M) stashTreePath.merge(pathMessage); } - return message; + return (M) stashTreePath; + } } - return null; - } - - @Override - public EdgeGroup loadEdges(IFilter loadEdgesFilter) { - return EdgeGroup.of((CloseableIterator) edgeQuery.getEdges(loadEdgesFilter)); - } - - @Override - public RowVertex loadVertex(Object id, IFilter loadVertexFilter, GraphSchema graphSchema, - IType[] addingVertexFieldTypes) { - RowVertex vertexFromState = (RowVertex) vertexQuery.withId(id).get(loadVertexFilter); - if (addingVertexFieldTypes.length > 0) { - Object[] appendFields = vertexId2AppendFields.get(vertexFromState.getId()); - if (appendFields == null) { - appendFields = new Object[addingVertexFieldTypes.length]; - } - VertexType stateVertexSchema = graphSchema.getVertex(vertexFromState.getLabel()); - return appendFields(vertexFromState, stateVertexSchema, appendFields); + } else if (messageType == MessageType.PARAMETER_REQUEST) { + // Fetch request from the calling stack + ParameterRequestMessage requestMessage = (ParameterRequestMessage) message; + if (requestMessage == null) { + requestMessage = new ParameterRequestMessage(); } - return vertexFromState; - } - - @Override - public CloseableIterator loadAllVertex() { - return vertexQuery.loadIdIterator(); - } - - private RowVertex appendFields(RowVertex vertexFromState, VertexType stateVertexSchema, - Object[] appendFields) { - Row value = vertexFromState.getValue(); - Object[] fields = value.getFields(stateVertexSchema.getValueTypes()); - Object[] newFields = new Object[fields.length + appendFields.length]; - System.arraycopy(fields, 0, newFields, 0, fields.length); - System.arraycopy(appendFields, 0, newFields, fields.length, appendFields.length); - Row newValue = ObjectRow.create(newFields); - return (RowVertex) vertexFromState.withValue(newValue); - } - - @Override - public void push(long opId, CallContext callContext) { - callStacks.computeIfAbsent(opId, k -> new Stack<>()).push(callContext); - } - - @Override - public void pop(long opId) { - if (callStacks.containsKey(opId)) { - Stack callStack = callStacks.get(opId); - callStack.pop(); - if (callStack.isEmpty()) { - callStacks.remove(opId); + Stack callStack = callStacks.get(currentOpId); + if (callStack != null && !callStack.isEmpty()) { + CallContext callContext = callStack.peek(); + List stashRequests = callContext.getRequests(getVertex().getId()); + if (stashRequests != null) { + for (ParameterRequest request : stashRequests) { + requestMessage.addRequest(request); } + } + return (M) requestMessage; } - } - - @Override - public void sendMessage(Object vertexId, IMessage message, long receiverId, long... otherReceiveIds) { - MessageBox messageBox = message.getType().createMessageBox(); - messageBox.addMessage(receiverId, message); - - for (long otherReceiverId : otherReceiveIds) { - messageBox.addMessage(otherReceiverId, message); - } - sendMessage(vertexId, messageBox); - } - - protected abstract void sendMessage(Object vertexId, MessageBox messageBox); - - @SuppressWarnings("unchecked") - @Override - public void broadcast(IMessage message, long receiverId, long... otherReceiveIds) { - BroadcastId id = new BroadcastId(getTaskIndex()); - MessageBox messageBox = message.getType().createMessageBox(); - messageBox.addMessage(receiverId, message); - - for (long otherReceiverId : otherReceiveIds) { - messageBox.addMessage(otherReceiverId, message); - } - sendBroadcastMessage(id, messageBox); - } - - protected abstract void sendBroadcastMessage(Object vertexId, MessageBox messageBox); - - @Override - public void stashCallRequestId(CallRequestId callRequestId) { - callRequestIds.add(callRequestId); - } - - @Override - public Iterable takeCallRequestIds() { - Set requestIds = new HashSet<>(callRequestIds); - callRequestIds.clear(); - return requestIds; - } - - @Override - public void setInputOperatorId(long id) { - this.inputOpId = id; - } - - @Override - public long getInputOperatorId() { - return this.inputOpId; - } - - @Override - public void addFieldToVertex(Object vertexId, int updateIndex, Object value) { - //Here append value to the existed store - Object[] appendFields; - Object[] existFields = vertexId2AppendFields.get(vertexId); - if (existFields == null) { - appendFields = new Object[updateIndex + 1]; - } else if (updateIndex >= existFields.length) { - appendFields = new Object[updateIndex + 1]; - System.arraycopy(existFields, 0, appendFields, 0, existFields.length); - } else { - appendFields = existFields; - } - appendFields[updateIndex] = value; - vertexId2AppendFields.put(vertexId, appendFields); - } - - public int getNumTasks() { - return getRuntimeContext().getTaskArgs().getParallelism(); - } - - @Override - public int getTaskIndex() { - return getRuntimeContext().getTaskArgs().getTaskIndex(); - } - - @Override - public long createUniqueId(long idInTask) { - return IDUtil.uniqueId(getNumTasks(), getTaskIndex(), idInTask); - } - - @Override - public MetricGroup getMetric() { - return getRuntimeContext().getMetric(); - } - - @Override - public VertexCentricAggContext getAggContext() { - return aggContext; - } - - @Override - public void setAggContext(VertexCentricAggContext aggContext) { - this.aggContext = Objects.requireNonNull(aggContext); - } + } + return message; + } + return null; + } + + @Override + public EdgeGroup loadEdges(IFilter loadEdgesFilter) { + return EdgeGroup.of((CloseableIterator) edgeQuery.getEdges(loadEdgesFilter)); + } + + @Override + public RowVertex loadVertex( + Object id, + IFilter loadVertexFilter, + GraphSchema graphSchema, + IType[] addingVertexFieldTypes) { + RowVertex vertexFromState = (RowVertex) vertexQuery.withId(id).get(loadVertexFilter); + if (addingVertexFieldTypes.length > 0) { + Object[] appendFields = vertexId2AppendFields.get(vertexFromState.getId()); + if (appendFields == null) { + appendFields = new Object[addingVertexFieldTypes.length]; + } + VertexType stateVertexSchema = graphSchema.getVertex(vertexFromState.getLabel()); + return appendFields(vertexFromState, stateVertexSchema, appendFields); + } + return vertexFromState; + } + + @Override + public CloseableIterator loadAllVertex() { + return vertexQuery.loadIdIterator(); + } + + private RowVertex appendFields( + RowVertex vertexFromState, VertexType stateVertexSchema, Object[] appendFields) { + Row value = vertexFromState.getValue(); + Object[] fields = value.getFields(stateVertexSchema.getValueTypes()); + Object[] newFields = new Object[fields.length + appendFields.length]; + System.arraycopy(fields, 0, newFields, 0, fields.length); + System.arraycopy(appendFields, 0, newFields, fields.length, appendFields.length); + Row newValue = ObjectRow.create(newFields); + return (RowVertex) vertexFromState.withValue(newValue); + } + + @Override + public void push(long opId, CallContext callContext) { + callStacks.computeIfAbsent(opId, k -> new Stack<>()).push(callContext); + } + + @Override + public void pop(long opId) { + if (callStacks.containsKey(opId)) { + Stack callStack = callStacks.get(opId); + callStack.pop(); + if (callStack.isEmpty()) { + callStacks.remove(opId); + } + } + } + + @Override + public void sendMessage( + Object vertexId, IMessage message, long receiverId, long... otherReceiveIds) { + MessageBox messageBox = message.getType().createMessageBox(); + messageBox.addMessage(receiverId, message); + + for (long otherReceiverId : otherReceiveIds) { + messageBox.addMessage(otherReceiverId, message); + } + sendMessage(vertexId, messageBox); + } + + protected abstract void sendMessage(Object vertexId, MessageBox messageBox); + + @SuppressWarnings("unchecked") + @Override + public void broadcast(IMessage message, long receiverId, long... otherReceiveIds) { + BroadcastId id = new BroadcastId(getTaskIndex()); + MessageBox messageBox = message.getType().createMessageBox(); + messageBox.addMessage(receiverId, message); + + for (long otherReceiverId : otherReceiveIds) { + messageBox.addMessage(otherReceiverId, message); + } + sendBroadcastMessage(id, messageBox); + } + + protected abstract void sendBroadcastMessage(Object vertexId, MessageBox messageBox); + + @Override + public void stashCallRequestId(CallRequestId callRequestId) { + callRequestIds.add(callRequestId); + } + + @Override + public Iterable takeCallRequestIds() { + Set requestIds = new HashSet<>(callRequestIds); + callRequestIds.clear(); + return requestIds; + } + + @Override + public void setInputOperatorId(long id) { + this.inputOpId = id; + } + + @Override + public long getInputOperatorId() { + return this.inputOpId; + } + + @Override + public void addFieldToVertex(Object vertexId, int updateIndex, Object value) { + // Here append value to the existed store + Object[] appendFields; + Object[] existFields = vertexId2AppendFields.get(vertexId); + if (existFields == null) { + appendFields = new Object[updateIndex + 1]; + } else if (updateIndex >= existFields.length) { + appendFields = new Object[updateIndex + 1]; + System.arraycopy(existFields, 0, appendFields, 0, existFields.length); + } else { + appendFields = existFields; + } + appendFields[updateIndex] = value; + vertexId2AppendFields.put(vertexId, appendFields); + } + + public int getNumTasks() { + return getRuntimeContext().getTaskArgs().getParallelism(); + } + + @Override + public int getTaskIndex() { + return getRuntimeContext().getTaskArgs().getTaskIndex(); + } + + @Override + public long createUniqueId(long idInTask) { + return IDUtil.uniqueId(getNumTasks(), getTaskIndex(), idInTask); + } + + @Override + public MetricGroup getMetric() { + return getRuntimeContext().getMetric(); + } + + @Override + public VertexCentricAggContext getAggContext() { + return aggContext; + } + + @Override + public void setAggContext(VertexCentricAggContext aggContext) { + this.aggContext = Objects.requireNonNull(aggContext); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GQLPipeLine.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GQLPipeLine.java index ad34d09c0..ce367db3c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GQLPipeLine.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GQLPipeLine.java @@ -28,6 +28,7 @@ import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; + import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.apache.geaflow.common.config.Configuration; @@ -56,166 +57,175 @@ public class GQLPipeLine { - private static final Logger LOGGER = LoggerFactory.getLogger(GQLPipeLine.class); - - private static final String GQL_FILE_NAME = "user.gql"; - - private final Environment environment; - - private GQLPipelineHook pipelineHook; - - private final int timeWaitSeconds; - - private final Map parallelismConfigMap; - - public GQLPipeLine(Environment environment, Map parallelismConfigMap) { - this(environment, -1, parallelismConfigMap); - } - - public GQLPipeLine(Environment environment, int timeOutSeconds) { - this(environment, timeOutSeconds, null); + private static final Logger LOGGER = LoggerFactory.getLogger(GQLPipeLine.class); + + private static final String GQL_FILE_NAME = "user.gql"; + + private final Environment environment; + + private GQLPipelineHook pipelineHook; + + private final int timeWaitSeconds; + + private final Map parallelismConfigMap; + + public GQLPipeLine(Environment environment, Map parallelismConfigMap) { + this(environment, -1, parallelismConfigMap); + } + + public GQLPipeLine(Environment environment, int timeOutSeconds) { + this(environment, timeOutSeconds, null); + } + + public GQLPipeLine( + Environment environment, int timeWaitSeconds, Map parallelismConfigMap) { + this.environment = environment; + this.timeWaitSeconds = timeWaitSeconds; + this.parallelismConfigMap = parallelismConfigMap; + } + + public void setPipelineHook(GQLPipelineHook pipelineHook) { + this.pipelineHook = pipelineHook; + } + + public void execute() throws Exception { + Configuration configuration = environment.getEnvironmentContext().getConfig(); + String queryPath = configuration.getString(DSLConfigKeys.GEAFLOW_DSL_QUERY_PATH, GQL_FILE_NAME); + String script; + if (queryPath.startsWith(FileConstants.PREFIX_JAVA_RESOURCE)) { + script = + IOUtils.resourceToString( + queryPath.substring(FileConstants.PREFIX_JAVA_RESOURCE.length()), + Charset.defaultCharset()); + } else { + String pathType = + configuration.getString( + DSLConfigKeys.GEAFLOW_DSL_QUERY_PATH_TYPE, FileConstants.PREFIX_JAVA_RESOURCE); + if (pathType.equals(FileConstants.PREFIX_JAVA_RESOURCE)) { + script = + IOUtils.resourceToString( + queryPath, Charset.defaultCharset(), GQLPipeLine.class.getClassLoader()); + } else { + script = FileUtils.readFileToString(new File(queryPath), Charset.defaultCharset()); + } } + LOGGER.info("queryPath:{}", queryPath); - public GQLPipeLine(Environment environment, int timeWaitSeconds, - Map parallelismConfigMap) { - this.environment = environment; - this.timeWaitSeconds = timeWaitSeconds; - this.parallelismConfigMap = parallelismConfigMap; + if (pipelineHook != null) { + script = pipelineHook.rewriteScript(script, configuration); } - - public void setPipelineHook(GQLPipelineHook pipelineHook) { - this.pipelineHook = pipelineHook; + LOGGER.info("execute query:\n{}", script); + if (script == null) { + throw new IllegalArgumentException("Cannot get script from certain query path."); } - - public void execute() throws Exception { - Configuration configuration = environment.getEnvironmentContext().getConfig(); - String queryPath = configuration.getString(DSLConfigKeys.GEAFLOW_DSL_QUERY_PATH, GQL_FILE_NAME); - String script; - if (queryPath.startsWith(FileConstants.PREFIX_JAVA_RESOURCE)) { - script = IOUtils.resourceToString( - queryPath.substring(FileConstants.PREFIX_JAVA_RESOURCE.length()), - Charset.defaultCharset()); - } else { - String pathType = configuration.getString(DSLConfigKeys.GEAFLOW_DSL_QUERY_PATH_TYPE, FileConstants.PREFIX_JAVA_RESOURCE); - if (pathType.equals(FileConstants.PREFIX_JAVA_RESOURCE)) { - script = IOUtils.resourceToString(queryPath, Charset.defaultCharset(), - GQLPipeLine.class.getClassLoader()); - } else { - script = FileUtils.readFileToString(new File(queryPath), Charset.defaultCharset()); - } - } - LOGGER.info("queryPath:{}", queryPath); - - if (pipelineHook != null) { - script = pipelineHook.rewriteScript(script, configuration); - } - LOGGER.info("execute query:\n{}", script); - if (script == null) { - throw new IllegalArgumentException("Cannot get script from certain query path."); - } - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - LOGGER.info("Submit pipeline task ..."); - PreCompileResult compileResult = QueryUtil.preCompile(script, configuration); - TaskCallBack callBack = pipeline.submit(new GQLPipelineTask(script, pipelineHook, - parallelismConfigMap)); - callBack.addCallBack(new SaveGraphWriteVersionCallbackFunction(configuration, compileResult)); - LOGGER.info("Execute pipeline task"); - IPipelineResult result = pipeline.execute(); - LOGGER.info("Submit finished, waiting future result ..."); - if (timeWaitSeconds > 0) { - CompletableFuture future = CompletableFuture.supplyAsync(() -> result.get()); - future.get(timeWaitSeconds, TimeUnit.SECONDS); - } else if (timeWaitSeconds == 0) { - result.get(); - } + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + LOGGER.info("Submit pipeline task ..."); + PreCompileResult compileResult = QueryUtil.preCompile(script, configuration); + TaskCallBack callBack = + pipeline.submit(new GQLPipelineTask(script, pipelineHook, parallelismConfigMap)); + callBack.addCallBack(new SaveGraphWriteVersionCallbackFunction(configuration, compileResult)); + LOGGER.info("Execute pipeline task"); + IPipelineResult result = pipeline.execute(); + LOGGER.info("Submit finished, waiting future result ..."); + if (timeWaitSeconds > 0) { + CompletableFuture future = CompletableFuture.supplyAsync(() -> result.get()); + future.get(timeWaitSeconds, TimeUnit.SECONDS); + } else if (timeWaitSeconds == 0) { + result.get(); } + } - private static class SaveGraphWriteVersionCallbackFunction implements ICallbackFunction { - - private static final Logger LOGGER = LoggerFactory.getLogger(SaveGraphWriteVersionCallbackFunction.class); + private static class SaveGraphWriteVersionCallbackFunction implements ICallbackFunction { - private final Configuration conf; - private final List insertGraphs; - private final long checkpointDuration; + private static final Logger LOGGER = + LoggerFactory.getLogger(SaveGraphWriteVersionCallbackFunction.class); - public SaveGraphWriteVersionCallbackFunction(Configuration conf, PreCompileResult compileResult) { - this.conf = conf; - this.checkpointDuration = conf.getLong(BATCH_NUMBER_PER_CHECKPOINT); - this.insertGraphs = compileResult.getInsertGraphs(); - } - - @Override - public void window(long windowId) { - if (CheckpointUtil.needDoCheckpoint(windowId, checkpointDuration)) { - for (GraphViewDesc graphViewDesc : insertGraphs) { - if (graphViewDesc.getBackend().equals(BackendType.Memory)) { - continue; - } - long checkpointId = graphViewDesc.getCheckpoint(windowId); - try { - ViewMetaBookKeeper keeper = new ViewMetaBookKeeper(graphViewDesc.getName(), conf); - keeper.saveViewVersion(checkpointId); - keeper.archive(); - LOGGER.info("save latest version for graph: {}, version id: {}", keeper.getViewName(), - checkpointId); - } catch (IOException e) { - throw new GeaflowRuntimeException("fail to do save latest version for: " - + graphViewDesc.getName() + ", windowId is: " + windowId + ", checkpointId is: " - + checkpointId, e); - } - } - } - } + private final Configuration conf; + private final List insertGraphs; + private final long checkpointDuration; - @Override - public void terminal() { + public SaveGraphWriteVersionCallbackFunction( + Configuration conf, PreCompileResult compileResult) { + this.conf = conf; + this.checkpointDuration = conf.getLong(BATCH_NUMBER_PER_CHECKPOINT); + this.insertGraphs = compileResult.getInsertGraphs(); + } + @Override + public void window(long windowId) { + if (CheckpointUtil.needDoCheckpoint(windowId, checkpointDuration)) { + for (GraphViewDesc graphViewDesc : insertGraphs) { + if (graphViewDesc.getBackend().equals(BackendType.Memory)) { + continue; + } + long checkpointId = graphViewDesc.getCheckpoint(windowId); + try { + ViewMetaBookKeeper keeper = new ViewMetaBookKeeper(graphViewDesc.getName(), conf); + keeper.saveViewVersion(checkpointId); + keeper.archive(); + LOGGER.info( + "save latest version for graph: {}, version id: {}", + keeper.getViewName(), + checkpointId); + } catch (IOException e) { + throw new GeaflowRuntimeException( + "fail to do save latest version for: " + + graphViewDesc.getName() + + ", windowId is: " + + windowId + + ", checkpointId is: " + + checkpointId, + e); + } } + } } - public static class GQLPipelineTask implements PipelineTask { + @Override + public void terminal() {} + } - private final String script; + public static class GQLPipelineTask implements PipelineTask { - private final GQLPipelineHook pipelineHook; + private final String script; - private final Map parallelismConfigMap; + private final GQLPipelineHook pipelineHook; - public GQLPipelineTask(String script, GQLPipelineHook pipelineHook, - Map parallelismConfigMap) { - this.script = script; - this.pipelineHook = pipelineHook; - this.parallelismConfigMap = parallelismConfigMap; - } + private final Map parallelismConfigMap; - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - QueryClient queryClient = new QueryClient(); - QueryEngine engineContext = new GeaFlowQueryEngine(pipelineTaskCxt); - QueryContext queryContext = QueryContext.builder() - .setEngineContext(engineContext) - .setCompile(false) - .build(); - if (pipelineHook != null) { - pipelineHook.beforeExecute(queryClient, queryContext); - } - if (parallelismConfigMap != null) { - queryContext.putConfigParallelism(parallelismConfigMap); - } - queryClient.executeQuery(script, queryContext); - if (pipelineHook != null) { - pipelineHook.afterExecute(queryClient, queryContext); - } - queryContext.finish(); - } + public GQLPipelineTask( + String script, GQLPipelineHook pipelineHook, Map parallelismConfigMap) { + this.script = script; + this.pipelineHook = pipelineHook; + this.parallelismConfigMap = parallelismConfigMap; } - public interface GQLPipelineHook { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { + QueryClient queryClient = new QueryClient(); + QueryEngine engineContext = new GeaFlowQueryEngine(pipelineTaskCxt); + QueryContext queryContext = + QueryContext.builder().setEngineContext(engineContext).setCompile(false).build(); + if (pipelineHook != null) { + pipelineHook.beforeExecute(queryClient, queryContext); + } + if (parallelismConfigMap != null) { + queryContext.putConfigParallelism(parallelismConfigMap); + } + queryClient.executeQuery(script, queryContext); + if (pipelineHook != null) { + pipelineHook.afterExecute(queryClient, queryContext); + } + queryContext.finish(); + } + } - String rewriteScript(String script, Configuration configuration); + public interface GQLPipelineHook { - void beforeExecute(QueryClient queryClient, QueryContext queryContext); + String rewriteScript(String script, Configuration configuration); - void afterExecute(QueryClient queryClient, QueryContext queryContext); - } + void beforeExecute(QueryClient queryClient, QueryContext queryContext); + + void afterExecute(QueryClient queryClient, QueryContext queryContext); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmAggTraversal.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmAggTraversal.java index f68e75997..22234d750 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmAggTraversal.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmAggTraversal.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.engine; import java.util.Objects; + import org.apache.geaflow.api.graph.function.vc.VertexCentricAggTraversalFunction; import org.apache.geaflow.api.graph.function.vc.VertexCentricAggregateFunction; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -30,43 +31,60 @@ import org.apache.geaflow.dsl.common.types.GraphSchema; import org.apache.geaflow.dsl.runtime.traversal.message.ITraversalAgg; -public class GeaFlowAlgorithmAggTraversal extends VertexCentricAggTraversal { +public class GeaFlowAlgorithmAggTraversal + extends VertexCentricAggTraversal< + Object, + Row, + Row, + Object, + Row, + ITraversalAgg, + ITraversalAgg, + ITraversalAgg, + ITraversalAgg, + ITraversalAgg> { - private final AlgorithmUserFunction userFunction; - private final Object[] params; + private final AlgorithmUserFunction userFunction; + private final Object[] params; - private final GraphSchema graphSchema; - private final int parallelism; + private final GraphSchema graphSchema; + private final int parallelism; - public GeaFlowAlgorithmAggTraversal(AlgorithmUserFunction userFunction, int maxTraversal, - Object[] params, GraphSchema graphSchema, int parallelism) { - super(maxTraversal); - this.userFunction = Objects.requireNonNull(userFunction); - this.params = Objects.requireNonNull(params); - this.graphSchema = Objects.requireNonNull(graphSchema); - assert parallelism >= 1; - this.parallelism = parallelism; - } + public GeaFlowAlgorithmAggTraversal( + AlgorithmUserFunction userFunction, + int maxTraversal, + Object[] params, + GraphSchema graphSchema, + int parallelism) { + super(maxTraversal); + this.userFunction = Objects.requireNonNull(userFunction); + this.params = Objects.requireNonNull(params); + this.graphSchema = Objects.requireNonNull(graphSchema); + assert parallelism >= 1; + this.parallelism = parallelism; + } - @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; - } + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; + } - @Override - public IEncoder getMessageEncoder() { - return null; - } + @Override + public IEncoder getMessageEncoder() { + return null; + } - @Override - public VertexCentricAggTraversalFunction getTraversalFunction() { - return new GeaFlowAlgorithmAggTraversalFunction(graphSchema, userFunction, params); - } + @Override + public VertexCentricAggTraversalFunction< + Object, Row, Row, Object, Row, ITraversalAgg, ITraversalAgg> + getTraversalFunction() { + return new GeaFlowAlgorithmAggTraversalFunction(graphSchema, userFunction, params); + } - @Override - public VertexCentricAggregateFunction getAggregateFunction() { - return (VertexCentricAggregateFunction) new GeaFlowKVAlgorithmAggregateFunction(parallelism); - } + @Override + public VertexCentricAggregateFunction< + ITraversalAgg, ITraversalAgg, ITraversalAgg, ITraversalAgg, ITraversalAgg> + getAggregateFunction() { + return (VertexCentricAggregateFunction) new GeaFlowKVAlgorithmAggregateFunction(parallelism); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmAggTraversalFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmAggTraversalFunction.java index daca980db..890c1cf7e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmAggTraversalFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmAggTraversalFunction.java @@ -27,6 +27,7 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; + import org.apache.geaflow.api.graph.function.vc.VertexCentricAggTraversalFunction; import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction; import org.apache.geaflow.dsl.common.data.Row; @@ -42,128 +43,135 @@ import org.apache.geaflow.utils.keygroup.KeyGroupAssignerFactory; import org.apache.geaflow.utils.keygroup.KeyGroupAssignment; -public class GeaFlowAlgorithmAggTraversalFunction implements - VertexCentricAggTraversalFunction { +public class GeaFlowAlgorithmAggTraversalFunction + implements VertexCentricAggTraversalFunction< + Object, Row, Row, Object, Row, ITraversalAgg, ITraversalAgg> { - private static final String STATE_SUFFIX = "UpdatedValueState"; + private static final String STATE_SUFFIX = "UpdatedValueState"; - private final AlgorithmUserFunction userFunction; + private final AlgorithmUserFunction userFunction; - private final Object[] params; + private final Object[] params; - private GraphSchema graphSchema; + private GraphSchema graphSchema; - private VertexCentricTraversalFuncContext traversalContext; + private VertexCentricTraversalFuncContext traversalContext; - private GeaFlowAlgorithmRuntimeContext algorithmCtx; + private GeaFlowAlgorithmRuntimeContext algorithmCtx; - private transient Set invokeVIds; + private transient Set invokeVIds; - private transient KeyValueState vertexUpdateValues; + private transient KeyValueState vertexUpdateValues; - public GeaFlowAlgorithmAggTraversalFunction(GraphSchema graphSchema, - AlgorithmUserFunction userFunction, - Object[] params) { - this.graphSchema = Objects.requireNonNull(graphSchema); - this.userFunction = Objects.requireNonNull(userFunction); - this.params = Objects.requireNonNull(params); - } + public GeaFlowAlgorithmAggTraversalFunction( + GraphSchema graphSchema, + AlgorithmUserFunction userFunction, + Object[] params) { + this.graphSchema = Objects.requireNonNull(graphSchema); + this.userFunction = Objects.requireNonNull(userFunction); + this.params = Objects.requireNonNull(params); + } - @Override - public void open( - VertexCentricTraversalFuncContext vertexCentricFuncContext) { - this.traversalContext = vertexCentricFuncContext; - this.algorithmCtx = new GeaFlowAlgorithmRuntimeContext(this, traversalContext, graphSchema); - this.userFunction.init(algorithmCtx, params); - this.invokeVIds = new HashSet<>(); - String stateName = traversalContext.getTraversalOpName() + "_" + STATE_SUFFIX; - KeyValueStateDescriptor descriptor = KeyValueStateDescriptor.build( + @Override + public void open( + VertexCentricTraversalFuncContext vertexCentricFuncContext) { + this.traversalContext = vertexCentricFuncContext; + this.algorithmCtx = new GeaFlowAlgorithmRuntimeContext(this, traversalContext, graphSchema); + this.userFunction.init(algorithmCtx, params); + this.invokeVIds = new HashSet<>(); + String stateName = traversalContext.getTraversalOpName() + "_" + STATE_SUFFIX; + KeyValueStateDescriptor descriptor = + KeyValueStateDescriptor.build( stateName, - traversalContext.getRuntimeContext().getConfiguration().getString(SYSTEM_STATE_BACKEND_TYPE)); - int parallelism = traversalContext.getRuntimeContext().getTaskArgs().getParallelism(); - int maxParallelism = traversalContext.getRuntimeContext().getTaskArgs().getMaxParallelism(); - int taskIndex = traversalContext.getRuntimeContext().getTaskArgs().getTaskIndex(); - KeyGroup keyGroup = KeyGroupAssignment.computeKeyGroupRangeForOperatorIndex( + traversalContext + .getRuntimeContext() + .getConfiguration() + .getString(SYSTEM_STATE_BACKEND_TYPE)); + int parallelism = traversalContext.getRuntimeContext().getTaskArgs().getParallelism(); + int maxParallelism = traversalContext.getRuntimeContext().getTaskArgs().getMaxParallelism(); + int taskIndex = traversalContext.getRuntimeContext().getTaskArgs().getTaskIndex(); + KeyGroup keyGroup = + KeyGroupAssignment.computeKeyGroupRangeForOperatorIndex( maxParallelism, parallelism, taskIndex); - descriptor.withKeyGroup(keyGroup); - IKeyGroupAssigner keyGroupAssigner = KeyGroupAssignerFactory.createKeyGroupAssigner( - keyGroup, taskIndex, maxParallelism); - descriptor.withKeyGroupAssigner(keyGroupAssigner); - long recoverWindowId = traversalContext.getRuntimeContext().getWindowId(); - this.vertexUpdateValues = StateFactory.buildKeyValueState(descriptor, - traversalContext.getRuntimeContext().getConfiguration()); - if (recoverWindowId > 1) { - this.vertexUpdateValues.manage().operate().setCheckpointId(recoverWindowId - 1); - this.vertexUpdateValues.manage().operate().recover(); - } - } - - @Override - public void init(ITraversalRequest traversalRequest) { - RowVertex vertex = (RowVertex) traversalContext.vertex().get(); - if (vertex != null) { - algorithmCtx.setVertexId(vertex.getId()); - addInvokeVertex(vertex); - Row newValue = getVertexNewValue(vertex.getId()); - userFunction.process(vertex, Optional.ofNullable(newValue), Collections.emptyIterator()); - } - } - - @Override - public void compute(Object vertexId, Iterator messages) { - algorithmCtx.setVertexId(vertexId); - RowVertex vertex = (RowVertex) traversalContext.vertex().get(); - if (vertex != null) { - Row newValue = getVertexNewValue(vertex.getId()); - addInvokeVertex(vertex); - userFunction.process(vertex, Optional.ofNullable(newValue), messages); - } - } - - @Override - public void finish() { - Iterator idIterator = getInvokeVIds(); - while (idIterator.hasNext()) { - Object id = idIterator.next(); - algorithmCtx.setVertexId(id); - RowVertex graphVertex = (RowVertex) traversalContext.vertex().withId(id).get(); - if (graphVertex != null) { - Row newValue = getVertexNewValue(graphVertex.getId()); - userFunction.finish(graphVertex, Optional.ofNullable(newValue)); - } - } - algorithmCtx.finish(); - long windowId = traversalContext.getRuntimeContext().getWindowId(); - this.vertexUpdateValues.manage().operate().setCheckpointId(windowId); - this.vertexUpdateValues.manage().operate().finish(); - this.vertexUpdateValues.manage().operate().archive(); - invokeVIds.clear(); - } - - @Override - public void close() { - algorithmCtx.close(); + descriptor.withKeyGroup(keyGroup); + IKeyGroupAssigner keyGroupAssigner = + KeyGroupAssignerFactory.createKeyGroupAssigner(keyGroup, taskIndex, maxParallelism); + descriptor.withKeyGroupAssigner(keyGroupAssigner); + long recoverWindowId = traversalContext.getRuntimeContext().getWindowId(); + this.vertexUpdateValues = + StateFactory.buildKeyValueState( + descriptor, traversalContext.getRuntimeContext().getConfiguration()); + if (recoverWindowId > 1) { + this.vertexUpdateValues.manage().operate().setCheckpointId(recoverWindowId - 1); + this.vertexUpdateValues.manage().operate().recover(); } - - @Override - public void initContext(VertexCentricAggContext aggContext) { - this.algorithmCtx.setAggContext(Objects.requireNonNull(aggContext)); + } + + @Override + public void init(ITraversalRequest traversalRequest) { + RowVertex vertex = (RowVertex) traversalContext.vertex().get(); + if (vertex != null) { + algorithmCtx.setVertexId(vertex.getId()); + addInvokeVertex(vertex); + Row newValue = getVertexNewValue(vertex.getId()); + userFunction.process(vertex, Optional.ofNullable(newValue), Collections.emptyIterator()); } - - public void updateVertexValue(Object vertexId, Row value) { - vertexUpdateValues.put(vertexId, value); + } + + @Override + public void compute(Object vertexId, Iterator messages) { + algorithmCtx.setVertexId(vertexId); + RowVertex vertex = (RowVertex) traversalContext.vertex().get(); + if (vertex != null) { + Row newValue = getVertexNewValue(vertex.getId()); + addInvokeVertex(vertex); + userFunction.process(vertex, Optional.ofNullable(newValue), messages); } - - public Row getVertexNewValue(Object vertexId) { - return vertexUpdateValues.get(vertexId); + } + + @Override + public void finish() { + Iterator idIterator = getInvokeVIds(); + while (idIterator.hasNext()) { + Object id = idIterator.next(); + algorithmCtx.setVertexId(id); + RowVertex graphVertex = (RowVertex) traversalContext.vertex().withId(id).get(); + if (graphVertex != null) { + Row newValue = getVertexNewValue(graphVertex.getId()); + userFunction.finish(graphVertex, Optional.ofNullable(newValue)); + } } - - public void addInvokeVertex(RowVertex v) { - invokeVIds.add(v.getId()); - } - - public Iterator getInvokeVIds() { - return invokeVIds.iterator(); - } - + algorithmCtx.finish(); + long windowId = traversalContext.getRuntimeContext().getWindowId(); + this.vertexUpdateValues.manage().operate().setCheckpointId(windowId); + this.vertexUpdateValues.manage().operate().finish(); + this.vertexUpdateValues.manage().operate().archive(); + invokeVIds.clear(); + } + + @Override + public void close() { + algorithmCtx.close(); + } + + @Override + public void initContext(VertexCentricAggContext aggContext) { + this.algorithmCtx.setAggContext(Objects.requireNonNull(aggContext)); + } + + public void updateVertexValue(Object vertexId, Row value) { + vertexUpdateValues.put(vertexId, value); + } + + public Row getVertexNewValue(Object vertexId) { + return vertexUpdateValues.get(vertexId); + } + + public void addInvokeVertex(RowVertex v) { + invokeVIds.add(v.getId()); + } + + public Iterator getInvokeVIds() { + return invokeVIds.iterator(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicAggTraversal.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicAggTraversal.java index 9279280ee..feea1f0fc 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicAggTraversal.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicAggTraversal.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.engine; import java.util.Objects; + import org.apache.geaflow.api.graph.function.vc.IncVertexCentricAggTraversalFunction; import org.apache.geaflow.api.graph.function.vc.VertexCentricAggregateFunction; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -30,46 +31,60 @@ import org.apache.geaflow.dsl.common.types.GraphSchema; import org.apache.geaflow.dsl.runtime.traversal.message.ITraversalAgg; -public class GeaFlowAlgorithmDynamicAggTraversal extends - IncVertexCentricAggTraversal { +public class GeaFlowAlgorithmDynamicAggTraversal + extends IncVertexCentricAggTraversal< + Object, + Row, + Row, + Object, + Row, + ITraversalAgg, + ITraversalAgg, + ITraversalAgg, + ITraversalAgg, + ITraversalAgg> { - private final AlgorithmUserFunction userFunction; - private final Object[] params; + private final AlgorithmUserFunction userFunction; + private final Object[] params; - private final GraphSchema graphSchema; - private final int parallelism; + private final GraphSchema graphSchema; + private final int parallelism; - public GeaFlowAlgorithmDynamicAggTraversal(AlgorithmUserFunction userFunction, int maxTraversal, - Object[] params, GraphSchema graphSchema, - int parallelism) { - super(maxTraversal); - this.userFunction = Objects.requireNonNull(userFunction); - this.params = Objects.requireNonNull(params); - this.graphSchema = Objects.requireNonNull(graphSchema); - assert parallelism >= 1; - this.parallelism = parallelism; - } + public GeaFlowAlgorithmDynamicAggTraversal( + AlgorithmUserFunction userFunction, + int maxTraversal, + Object[] params, + GraphSchema graphSchema, + int parallelism) { + super(maxTraversal); + this.userFunction = Objects.requireNonNull(userFunction); + this.params = Objects.requireNonNull(params); + this.graphSchema = Objects.requireNonNull(graphSchema); + assert parallelism >= 1; + this.parallelism = parallelism; + } - @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; - } + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; + } - @Override - public IEncoder getMessageEncoder() { - return null; - } + @Override + public IEncoder getMessageEncoder() { + return null; + } - @Override - public IncVertexCentricAggTraversalFunction getIncTraversalFunction() { - return new GeaFlowAlgorithmDynamicAggTraversalFunction(graphSchema, userFunction, params); - } + @Override + public IncVertexCentricAggTraversalFunction< + Object, Row, Row, Object, Row, ITraversalAgg, ITraversalAgg> + getIncTraversalFunction() { + return new GeaFlowAlgorithmDynamicAggTraversalFunction(graphSchema, userFunction, params); + } - @Override - public VertexCentricAggregateFunction getAggregateFunction() { - return (VertexCentricAggregateFunction) new GeaFlowKVAlgorithmAggregateFunction(parallelism); - } + @Override + public VertexCentricAggregateFunction< + ITraversalAgg, ITraversalAgg, ITraversalAgg, ITraversalAgg, ITraversalAgg> + getAggregateFunction() { + return (VertexCentricAggregateFunction) new GeaFlowKVAlgorithmAggregateFunction(parallelism); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicAggTraversalFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicAggTraversalFunction.java index 98c475b15..fba8556fa 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicAggTraversalFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicAggTraversalFunction.java @@ -29,6 +29,7 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; + import org.apache.geaflow.api.function.iterator.RichIteratorFunction; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricAggTraversalFunction; import org.apache.geaflow.common.config.keys.FrameworkConfigKeys; @@ -51,195 +52,207 @@ import org.slf4j.LoggerFactory; public class GeaFlowAlgorithmDynamicAggTraversalFunction - implements IncVertexCentricAggTraversalFunction, RichIteratorFunction { + implements IncVertexCentricAggTraversalFunction< + Object, Row, Row, Object, Row, ITraversalAgg, ITraversalAgg>, + RichIteratorFunction { - private static final Logger LOGGER = LoggerFactory.getLogger(GeaFlowAlgorithmDynamicAggTraversalFunction.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(GeaFlowAlgorithmDynamicAggTraversalFunction.class); - private static final String STATE_SUFFIX = "UpdatedValueState"; + private static final String STATE_SUFFIX = "UpdatedValueState"; - private final AlgorithmUserFunction userFunction; + private final AlgorithmUserFunction userFunction; - private final Object[] params; + private final Object[] params; - private GraphSchema graphSchema; + private GraphSchema graphSchema; - private IncVertexCentricTraversalFuncContext traversalContext; + private IncVertexCentricTraversalFuncContext traversalContext; - private GeaFlowAlgorithmDynamicRuntimeContext algorithmCtx; + private GeaFlowAlgorithmDynamicRuntimeContext algorithmCtx; - private MutableGraph mutableGraph; + private MutableGraph mutableGraph; - private transient Set initVertices; + private transient Set initVertices; - private transient KeyValueState vertexUpdateValues; + private transient KeyValueState vertexUpdateValues; - private boolean materializeInFinish; + private boolean materializeInFinish; - public GeaFlowAlgorithmDynamicAggTraversalFunction(GraphSchema graphSchema, - AlgorithmUserFunction userFunction, - Object[] params) { - this.graphSchema = Objects.requireNonNull(graphSchema); - this.userFunction = Objects.requireNonNull(userFunction); - this.params = Objects.requireNonNull(params); - this.initVertices = new HashSet<>(); - } + public GeaFlowAlgorithmDynamicAggTraversalFunction( + GraphSchema graphSchema, + AlgorithmUserFunction userFunction, + Object[] params) { + this.graphSchema = Objects.requireNonNull(graphSchema); + this.userFunction = Objects.requireNonNull(userFunction); + this.params = Objects.requireNonNull(params); + this.initVertices = new HashSet<>(); + } - @Override - public void open( - IncVertexCentricTraversalFuncContext vertexCentricFuncContext) { - this.traversalContext = vertexCentricFuncContext; - this.materializeInFinish = traversalContext.getRuntimeContext().getConfiguration().getBoolean(FrameworkConfigKeys.UDF_MATERIALIZE_GRAPH_IN_FINISH); - this.algorithmCtx = new GeaFlowAlgorithmDynamicRuntimeContext(this, traversalContext, - graphSchema); - this.initVertices = new HashSet<>(); - this.userFunction.init(algorithmCtx, params); - this.mutableGraph = traversalContext.getMutableGraph(); - - int taskIndex = traversalContext.getRuntimeContext().getTaskArgs().getTaskIndex(); - String stateName = traversalContext.getTraversalOpName() + "_" + STATE_SUFFIX; - KeyValueStateDescriptor descriptor = KeyValueStateDescriptor.build( + @Override + public void open( + IncVertexCentricTraversalFuncContext + vertexCentricFuncContext) { + this.traversalContext = vertexCentricFuncContext; + this.materializeInFinish = + traversalContext + .getRuntimeContext() + .getConfiguration() + .getBoolean(FrameworkConfigKeys.UDF_MATERIALIZE_GRAPH_IN_FINISH); + this.algorithmCtx = + new GeaFlowAlgorithmDynamicRuntimeContext(this, traversalContext, graphSchema); + this.initVertices = new HashSet<>(); + this.userFunction.init(algorithmCtx, params); + this.mutableGraph = traversalContext.getMutableGraph(); + + int taskIndex = traversalContext.getRuntimeContext().getTaskArgs().getTaskIndex(); + String stateName = traversalContext.getTraversalOpName() + "_" + STATE_SUFFIX; + KeyValueStateDescriptor descriptor = + KeyValueStateDescriptor.build( stateName, - traversalContext.getRuntimeContext().getConfiguration().getString(SYSTEM_STATE_BACKEND_TYPE)); - int parallelism = traversalContext.getRuntimeContext().getTaskArgs().getParallelism(); - int maxParallelism = traversalContext.getRuntimeContext().getTaskArgs().getMaxParallelism(); - KeyGroup keyGroup = KeyGroupAssignment.computeKeyGroupRangeForOperatorIndex( + traversalContext + .getRuntimeContext() + .getConfiguration() + .getString(SYSTEM_STATE_BACKEND_TYPE)); + int parallelism = traversalContext.getRuntimeContext().getTaskArgs().getParallelism(); + int maxParallelism = traversalContext.getRuntimeContext().getTaskArgs().getMaxParallelism(); + KeyGroup keyGroup = + KeyGroupAssignment.computeKeyGroupRangeForOperatorIndex( maxParallelism, parallelism, taskIndex); - descriptor.withKeyGroup(keyGroup); - IKeyGroupAssigner keyGroupAssigner = KeyGroupAssignerFactory.createKeyGroupAssigner( - keyGroup, taskIndex, maxParallelism); - descriptor.withKeyGroupAssigner(keyGroupAssigner); - long recoverWindowId = traversalContext.getRuntimeContext().getWindowId(); - this.vertexUpdateValues = StateFactory.buildKeyValueState(descriptor, - traversalContext.getRuntimeContext().getConfiguration()); - if (recoverWindowId > 1) { - this.vertexUpdateValues.manage().operate().setCheckpointId(recoverWindowId - 1); - this.vertexUpdateValues.manage().operate().recover(); - } + descriptor.withKeyGroup(keyGroup); + IKeyGroupAssigner keyGroupAssigner = + KeyGroupAssignerFactory.createKeyGroupAssigner(keyGroup, taskIndex, maxParallelism); + descriptor.withKeyGroupAssigner(keyGroupAssigner); + long recoverWindowId = traversalContext.getRuntimeContext().getWindowId(); + this.vertexUpdateValues = + StateFactory.buildKeyValueState( + descriptor, traversalContext.getRuntimeContext().getConfiguration()); + if (recoverWindowId > 1) { + this.vertexUpdateValues.manage().operate().setCheckpointId(recoverWindowId - 1); + this.vertexUpdateValues.manage().operate().recover(); } - - @Override - public void init(ITraversalRequest traversalRequest) { - if (!materializeInFinish) { - Object vertexId = traversalRequest.getVId(); - algorithmCtx.setVertexId(vertexId); - // The set formed by the vertices and source/target vertices of the edges inserted into - // each window of the dynamic graph is taken as the trigger vertex for the first round of - // iteration in the algorithm. These vertices may be duplicated, and needInit() returns - // false when called after the first time to avoid redundant invocation. - if (vertexId != null && needInit(vertexId)) { - RowVertex vertex = (RowVertex) algorithmCtx.loadVertex(); - if (vertex != null) { - algorithmCtx.setVertexId(vertex.getId()); - Row newValue = getVertexNewValue(vertex.getId()); - userFunction.process(vertex, Optional.ofNullable(newValue), Collections.emptyIterator()); - } - } - } - } - - public void updateVertexValue(Object vertexId, Row value) { - vertexUpdateValues.put(vertexId, value); - } - - public Row getVertexNewValue(Object vertexId) { - return vertexUpdateValues.get(vertexId); - } - - @Override - public void evolve(Object vertexId, TemporaryGraph temporaryGraph) { - if (!materializeInFinish) { - IVertex vertex = temporaryGraph.getVertex(); - List> edges = temporaryGraph.getEdges(); - materializeGraph(vertex, edges); - } else { - algorithmCtx.setVertexId(vertexId); - RowVertex vertex = (RowVertex) temporaryGraph.getVertex(); - List emptyMessages = Collections.emptyList(); - userFunction.process(vertex, Optional.ofNullable(null), emptyMessages.iterator()); - } - } - - private void materializeGraph(IVertex vertex, List> edges) { + } + + @Override + public void init(ITraversalRequest traversalRequest) { + if (!materializeInFinish) { + Object vertexId = traversalRequest.getVId(); + algorithmCtx.setVertexId(vertexId); + // The set formed by the vertices and source/target vertices of the edges inserted into + // each window of the dynamic graph is taken as the trigger vertex for the first round of + // iteration in the algorithm. These vertices may be duplicated, and needInit() returns + // false when called after the first time to avoid redundant invocation. + if (vertexId != null && needInit(vertexId)) { + RowVertex vertex = (RowVertex) algorithmCtx.loadVertex(); if (vertex != null) { - mutableGraph.addVertex(GRAPH_VERSION, vertex); - } - if (edges != null) { - for (IEdge edge : edges) { - mutableGraph.addEdge(GRAPH_VERSION, edge); - } + algorithmCtx.setVertexId(vertex.getId()); + Row newValue = getVertexNewValue(vertex.getId()); + userFunction.process(vertex, Optional.ofNullable(newValue), Collections.emptyIterator()); } + } } - - @Override - public void compute(Object vertexId, Iterator messages) { - algorithmCtx.setVertexId(vertexId); - RowVertex vertex; - if (materializeInFinish) { - vertex = (RowVertex) algorithmCtx.getIncVCTraversalCtx().getTemporaryGraph().getVertex(); - if (vertex == null) { - vertex = (RowVertex) algorithmCtx.loadVertex(); - } - } else { - vertex = (RowVertex) algorithmCtx.loadVertex(); - } - if (vertex != null) { - Row newValue = getVertexNewValue(vertex.getId()); - userFunction.process(vertex, Optional.ofNullable(newValue), messages); - } + } + + public void updateVertexValue(Object vertexId, Row value) { + vertexUpdateValues.put(vertexId, value); + } + + public Row getVertexNewValue(Object vertexId) { + return vertexUpdateValues.get(vertexId); + } + + @Override + public void evolve(Object vertexId, TemporaryGraph temporaryGraph) { + if (!materializeInFinish) { + IVertex vertex = temporaryGraph.getVertex(); + List> edges = temporaryGraph.getEdges(); + materializeGraph(vertex, edges); + } else { + algorithmCtx.setVertexId(vertexId); + RowVertex vertex = (RowVertex) temporaryGraph.getVertex(); + List emptyMessages = Collections.emptyList(); + userFunction.process(vertex, Optional.ofNullable(null), emptyMessages.iterator()); } + } - @Override - public void finish(Object vertexId, MutableGraph mutableGraph) { - algorithmCtx.setVertexId(vertexId); - RowVertex graphVertex = (RowVertex) algorithmCtx.loadVertex(); - if (graphVertex != null) { - Row newValue = getVertexNewValue(graphVertex.getId()); - userFunction.finish(graphVertex, Optional.ofNullable(newValue)); - } - if (materializeInFinish) { - IVertex vertex = algorithmCtx.getIncVCTraversalCtx().getTemporaryGraph().getVertex(); - materializeGraph(vertex, traversalContext.getTemporaryGraph().getEdges()); - } + private void materializeGraph(IVertex vertex, List> edges) { + if (vertex != null) { + mutableGraph.addVertex(GRAPH_VERSION, vertex); } - - public boolean needInit(Object v) { - if (initVertices.contains(v)) { - return false; - } else { - initVertices.add(v); - return true; - } + if (edges != null) { + for (IEdge edge : edges) { + mutableGraph.addEdge(GRAPH_VERSION, edge); + } } - - @Override - public void finish() { - algorithmCtx.finish(); - initVertices.clear(); - userFunction.finish(); - long windowId = traversalContext.getRuntimeContext().getWindowId(); - this.vertexUpdateValues.manage().operate().setCheckpointId(windowId); - this.vertexUpdateValues.manage().operate().finish(); - this.vertexUpdateValues.manage().operate().archive(); + } + + @Override + public void compute(Object vertexId, Iterator messages) { + algorithmCtx.setVertexId(vertexId); + RowVertex vertex; + if (materializeInFinish) { + vertex = (RowVertex) algorithmCtx.getIncVCTraversalCtx().getTemporaryGraph().getVertex(); + if (vertex == null) { + vertex = (RowVertex) algorithmCtx.loadVertex(); + } + } else { + vertex = (RowVertex) algorithmCtx.loadVertex(); } - - - @Override - public void close() { - algorithmCtx.close(); + if (vertex != null) { + Row newValue = getVertexNewValue(vertex.getId()); + userFunction.process(vertex, Optional.ofNullable(newValue), messages); } - - @Override - public void initIteration(long iterationId) { + } + + @Override + public void finish(Object vertexId, MutableGraph mutableGraph) { + algorithmCtx.setVertexId(vertexId); + RowVertex graphVertex = (RowVertex) algorithmCtx.loadVertex(); + if (graphVertex != null) { + Row newValue = getVertexNewValue(graphVertex.getId()); + userFunction.finish(graphVertex, Optional.ofNullable(newValue)); } - - @Override - public void finishIteration(long iterationId) { - userFunction.finishIteration(iterationId); + if (materializeInFinish) { + IVertex vertex = algorithmCtx.getIncVCTraversalCtx().getTemporaryGraph().getVertex(); + materializeGraph(vertex, traversalContext.getTemporaryGraph().getEdges()); } - - @Override - public void initContext(VertexCentricAggContext aggContext) { - this.algorithmCtx.setAggContext(Objects.requireNonNull(aggContext)); + } + + public boolean needInit(Object v) { + if (initVertices.contains(v)) { + return false; + } else { + initVertices.add(v); + return true; } + } + + @Override + public void finish() { + algorithmCtx.finish(); + initVertices.clear(); + userFunction.finish(); + long windowId = traversalContext.getRuntimeContext().getWindowId(); + this.vertexUpdateValues.manage().operate().setCheckpointId(windowId); + this.vertexUpdateValues.manage().operate().finish(); + this.vertexUpdateValues.manage().operate().archive(); + } + + @Override + public void close() { + algorithmCtx.close(); + } + + @Override + public void initIteration(long iterationId) {} + + @Override + public void finishIteration(long iterationId) { + userFunction.finishIteration(iterationId); + } + + @Override + public void initContext(VertexCentricAggContext aggContext) { + this.algorithmCtx.setAggContext(Objects.requireNonNull(aggContext)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicRuntimeContext.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicRuntimeContext.java index d929ae441..1f2ab6107 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicRuntimeContext.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicRuntimeContext.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; + import org.apache.geaflow.api.graph.function.aggregate.VertexCentricAggContextFunction.VertexCentricAggContext; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricTraversalFunction.IncVertexCentricTraversalFuncContext; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricTraversalFunction.TraversalGraphSnapShot; @@ -46,278 +47,278 @@ import org.apache.geaflow.state.pushdown.filter.InEdgeFilter; import org.apache.geaflow.state.pushdown.filter.OutEdgeFilter; -public class GeaFlowAlgorithmDynamicRuntimeContext implements AlgorithmRuntimeContext { - - private final IncVertexCentricTraversalFuncContext incVCTraversalCtx; - - protected VertexCentricAggContext aggContext; - - private final GraphSchema graphSchema; - - protected TraversalVertexQuery vertexQuery; - - protected TraversalEdgeQuery edgeQuery; - - private final transient GeaFlowAlgorithmDynamicAggTraversalFunction traversalFunction; - - private Object vertexId; - - private long iterationId = -1L; - - public GeaFlowAlgorithmDynamicRuntimeContext(GeaFlowAlgorithmDynamicAggTraversalFunction traversalFunction, - IncVertexCentricTraversalFuncContext traversalContext, GraphSchema graphSchema) { - this.traversalFunction = traversalFunction; - this.incVCTraversalCtx = traversalContext; - this.graphSchema = graphSchema; - TraversalGraphSnapShot graphSnapShot = incVCTraversalCtx.getHistoricalGraph().getSnapShot(0L); - this.vertexQuery = graphSnapShot.vertex(); - this.edgeQuery = graphSnapShot.edges(); +public class GeaFlowAlgorithmDynamicRuntimeContext + implements AlgorithmRuntimeContext { + + private final IncVertexCentricTraversalFuncContext + incVCTraversalCtx; + + protected VertexCentricAggContext aggContext; + + private final GraphSchema graphSchema; + + protected TraversalVertexQuery vertexQuery; + + protected TraversalEdgeQuery edgeQuery; + + private final transient GeaFlowAlgorithmDynamicAggTraversalFunction traversalFunction; + + private Object vertexId; + + private long iterationId = -1L; + + public GeaFlowAlgorithmDynamicRuntimeContext( + GeaFlowAlgorithmDynamicAggTraversalFunction traversalFunction, + IncVertexCentricTraversalFuncContext traversalContext, + GraphSchema graphSchema) { + this.traversalFunction = traversalFunction; + this.incVCTraversalCtx = traversalContext; + this.graphSchema = graphSchema; + TraversalGraphSnapShot graphSnapShot = + incVCTraversalCtx.getHistoricalGraph().getSnapShot(0L); + this.vertexQuery = graphSnapShot.vertex(); + this.edgeQuery = graphSnapShot.edges(); + } + + public void setVertexId(Object vertexId) { + this.vertexId = vertexId; + this.vertexQuery.withId(vertexId); + this.edgeQuery.withId(vertexId); + } + + public IVertex loadVertex() { + return vertexQuery.get(); + } + + public CloseableIterator loadAllVertex() { + return vertexQuery.loadIdIterator(); + } + + @Override + public Configuration getConfig() { + return incVCTraversalCtx.getRuntimeContext().getConfiguration(); + } + + @SuppressWarnings("unchecked") + @Override + public List loadEdges(EdgeDirection direction) { + switch (direction) { + case OUT: + return (List) edgeQuery.getOutEdges(); + case IN: + return (List) edgeQuery.getInEdges(); + case BOTH: + List edges = new ArrayList<>(); + edges.addAll((List) edgeQuery.getOutEdges()); + edges.addAll((List) edgeQuery.getInEdges()); + return edges; + default: + throw new GeaFlowDSLException("Illegal edge direction: " + direction); } - - public void setVertexId(Object vertexId) { - this.vertexId = vertexId; - this.vertexQuery.withId(vertexId); - this.edgeQuery.withId(vertexId); - } - - public IVertex loadVertex() { - return vertexQuery.get(); - } - - public CloseableIterator loadAllVertex() { - return vertexQuery.loadIdIterator(); + } + + @Override + public CloseableIterator loadEdgesIterator(EdgeDirection direction) { + switch (direction) { + case OUT: + return loadEdgesIterator(OutEdgeFilter.getInstance()); + case IN: + return loadEdgesIterator(InEdgeFilter.getInstance()); + case BOTH: + return loadEdgesIterator(EmptyFilter.getInstance()); + default: + throw new GeaFlowDSLException("Illegal edge direction: " + direction); } - - @Override - public Configuration getConfig() { - return incVCTraversalCtx.getRuntimeContext().getConfiguration(); - } - - @SuppressWarnings("unchecked") - @Override - public List loadEdges(EdgeDirection direction) { - switch (direction) { - case OUT: - return (List) edgeQuery.getOutEdges(); - case IN: - return (List) edgeQuery.getInEdges(); - case BOTH: - List edges = new ArrayList<>(); - edges.addAll((List) edgeQuery.getOutEdges()); - edges.addAll((List) edgeQuery.getInEdges()); - return edges; - default: - throw new GeaFlowDSLException("Illegal edge direction: " + direction); - } + } + + @Override + public CloseableIterator loadEdgesIterator(IFilter filter) { + return (CloseableIterator) edgeQuery.getEdges(filter); + } + + public List loadDynamicEdges(EdgeDirection direction) { + List> edges = incVCTraversalCtx.getTemporaryGraph().getEdges(); + List rowEdges = new ArrayList<>(); + if (edges == null) { + return rowEdges; } - - @Override - public CloseableIterator loadEdgesIterator(EdgeDirection direction) { - switch (direction) { - case OUT: - return loadEdgesIterator(OutEdgeFilter.getInstance()); - case IN: - return loadEdgesIterator(InEdgeFilter.getInstance()); - case BOTH: - return loadEdgesIterator(EmptyFilter.getInstance()); - default: - throw new GeaFlowDSLException("Illegal edge direction: " + direction); + switch (direction) { + case OUT: + for (IEdge edge : edges) { + if (edge.getDirect() == EdgeDirection.OUT) { + rowEdges.add((RowEdge) edge); + } } - } - - @Override - public CloseableIterator loadEdgesIterator(IFilter filter) { - return (CloseableIterator) edgeQuery.getEdges(filter); - } - - public List loadDynamicEdges(EdgeDirection direction) { - List> edges = incVCTraversalCtx.getTemporaryGraph().getEdges(); - List rowEdges = new ArrayList<>(); - if (edges == null) { - return rowEdges; + break; + case IN: + for (IEdge edge : edges) { + if (edge.getDirect() == EdgeDirection.IN) { + rowEdges.add((RowEdge) edge); + } } - switch (direction) { - case OUT: - for (IEdge edge : edges) { - if (edge.getDirect() == EdgeDirection.OUT) { - rowEdges.add((RowEdge) edge); - } - } - break; - case IN: - for (IEdge edge : edges) { - if (edge.getDirect() == EdgeDirection.IN) { - rowEdges.add((RowEdge) edge); - } - } - break; - case BOTH: - for (IEdge edge : edges) { - rowEdges.add((RowEdge) edge); - } - break; - default: - throw new GeaFlowDSLException("Illegal edge direction: " + direction); + break; + case BOTH: + for (IEdge edge : edges) { + rowEdges.add((RowEdge) edge); } - return rowEdges; + break; + default: + throw new GeaFlowDSLException("Illegal edge direction: " + direction); } - - public List loadStaticEdges(EdgeDirection direction) { - switch (direction) { - case OUT: - return (List) edgeQuery.getOutEdges(); - case IN: - return (List) edgeQuery.getInEdges(); - case BOTH: - return (List) edgeQuery.getEdges(); - default: - throw new GeaFlowDSLException("Illegal edge direction: " + direction); - } + return rowEdges; + } + + public List loadStaticEdges(EdgeDirection direction) { + switch (direction) { + case OUT: + return (List) edgeQuery.getOutEdges(); + case IN: + return (List) edgeQuery.getInEdges(); + case BOTH: + return (List) edgeQuery.getEdges(); + default: + throw new GeaFlowDSLException("Illegal edge direction: " + direction); } - - @Override - public CloseableIterator loadStaticEdgesIterator(EdgeDirection direction) { - switch (direction) { - case OUT: - return loadStaticEdgesIterator(OutEdgeFilter.getInstance()); - case IN: - return loadStaticEdgesIterator(InEdgeFilter.getInstance()); - case BOTH: - return loadStaticEdgesIterator(EmptyFilter.getInstance()); - default: - throw new GeaFlowDSLException("Illegal edge direction: " + direction); - } + } + + @Override + public CloseableIterator loadStaticEdgesIterator(EdgeDirection direction) { + switch (direction) { + case OUT: + return loadStaticEdgesIterator(OutEdgeFilter.getInstance()); + case IN: + return loadStaticEdgesIterator(InEdgeFilter.getInstance()); + case BOTH: + return loadStaticEdgesIterator(EmptyFilter.getInstance()); + default: + throw new GeaFlowDSLException("Illegal edge direction: " + direction); } - - @Override - public CloseableIterator loadStaticEdgesIterator(IFilter filter) { - return (CloseableIterator) edgeQuery.getEdges(filter); + } + + @Override + public CloseableIterator loadStaticEdgesIterator(IFilter filter) { + return (CloseableIterator) edgeQuery.getEdges(filter); + } + + @Override + public CloseableIterator loadDynamicEdgesIterator(EdgeDirection direction) { + switch (direction) { + case OUT: + return loadDynamicEdgesIterator(OutEdgeFilter.getInstance()); + case IN: + return loadDynamicEdgesIterator(InEdgeFilter.getInstance()); + case BOTH: + return loadDynamicEdgesIterator(EmptyFilter.getInstance()); + default: + throw new GeaFlowDSLException("Illegal edge direction: " + direction); } - - @Override - public CloseableIterator loadDynamicEdgesIterator(EdgeDirection direction) { - switch (direction) { - case OUT: - return loadDynamicEdgesIterator(OutEdgeFilter.getInstance()); - case IN: - return loadDynamicEdgesIterator(InEdgeFilter.getInstance()); - case BOTH: - return loadDynamicEdgesIterator(EmptyFilter.getInstance()); - default: - throw new GeaFlowDSLException("Illegal edge direction: " + direction); - } + } + + @Override + public CloseableIterator loadDynamicEdgesIterator(IFilter filter) { + List> edges = incVCTraversalCtx.getTemporaryGraph().getEdges(); + List res = new ArrayList<>(); + for (IEdge edge : edges) { + if (filter.filter(edge)) { + res.add((RowEdge) edge); + } } - - @Override - public CloseableIterator loadDynamicEdgesIterator(IFilter filter) { - List> edges = incVCTraversalCtx.getTemporaryGraph().getEdges(); - List res = new ArrayList<>(); - for (IEdge edge : edges) { - if (filter.filter(edge)) { - res.add((RowEdge) edge); - } - } - return IteratorWithClose.wrap(res.iterator()); + return IteratorWithClose.wrap(res.iterator()); + } + + @Override + public void sendMessage(Object vertexId, Object message) { + incVCTraversalCtx.sendMessage(vertexId, message); + if (getCurrentIterationId() > iterationId) { + iterationId = getCurrentIterationId(); + aggContext.aggregate(GeaFlowKVAlgorithmAggregateFunction.getAlgorithmAgg(iterationId)); } + } - @Override - public void sendMessage(Object vertexId, Object message) { - incVCTraversalCtx.sendMessage(vertexId, message); - if (getCurrentIterationId() > iterationId) { - iterationId = getCurrentIterationId(); - aggContext.aggregate(GeaFlowKVAlgorithmAggregateFunction.getAlgorithmAgg(iterationId)); - } - } + @Override + public void take(Row row) { + incVCTraversalCtx.takeResponse(new AlgorithmResponse(row)); + } - @Override - public void take(Row row) { - incVCTraversalCtx.takeResponse(new AlgorithmResponse(row)); - } + @Override + public void updateVertexValue(Row value) { + traversalFunction.updateVertexValue(vertexId, value); + } - @Override - public void updateVertexValue(Row value) { - traversalFunction.updateVertexValue(vertexId, value); - } + public long getCurrentIterationId() { + return incVCTraversalCtx.getIterationId(); + } - public long getCurrentIterationId() { - return incVCTraversalCtx.getIterationId(); - } + public void finish() {} - public void finish() { + public void close() {} - } + @Override + public GraphSchema getGraphSchema() { + return graphSchema; + } - public void close() { + public VertexCentricAggContext getAggContext() { + return aggContext; + } - } + public void setAggContext(VertexCentricAggContext aggContext) { + this.aggContext = Objects.requireNonNull(aggContext); + } - @Override - public GraphSchema getGraphSchema() { - return graphSchema; - } + public IncVertexCentricTraversalFuncContext + getIncVCTraversalCtx() { + return incVCTraversalCtx; + } - public VertexCentricAggContext getAggContext() { - return aggContext; + @Override + public void voteToTerminate(String terminationReason, Object voteValue) { + // Send termination vote to coordinator through aggregation context + if (aggContext != null) { + aggContext.aggregate(new AlgorithmTerminationVote(terminationReason, voteValue)); } + } - public void setAggContext(VertexCentricAggContext aggContext) { - this.aggContext = Objects.requireNonNull(aggContext); - } + /** Internal class representing a termination vote sent to the coordinator. */ + private static class AlgorithmTerminationVote implements ITraversalAgg { + private final String terminationReason; + private final Object voteValue; - public IncVertexCentricTraversalFuncContext getIncVCTraversalCtx() { - return incVCTraversalCtx; + public AlgorithmTerminationVote(String terminationReason, Object voteValue) { + this.terminationReason = terminationReason; + this.voteValue = voteValue; } - @Override - public void voteToTerminate(String terminationReason, Object voteValue) { - // Send termination vote to coordinator through aggregation context - if (aggContext != null) { - aggContext.aggregate(new AlgorithmTerminationVote(terminationReason, voteValue)); - } + public String getTerminationReason() { + return terminationReason; } - /** - * Internal class representing a termination vote sent to the coordinator. - */ - private static class AlgorithmTerminationVote implements ITraversalAgg { - private final String terminationReason; - private final Object voteValue; - - public AlgorithmTerminationVote(String terminationReason, Object voteValue) { - this.terminationReason = terminationReason; - this.voteValue = voteValue; - } - - public String getTerminationReason() { - return terminationReason; - } - - public Object getVoteValue() { - return voteValue; - } + public Object getVoteValue() { + return voteValue; } + } - private static class AlgorithmResponse implements ITraversalResponse { + private static class AlgorithmResponse implements ITraversalResponse { - private final Row row; + private final Row row; - public AlgorithmResponse(Row row) { - this.row = row; - } + public AlgorithmResponse(Row row) { + this.row = row; + } - @Override - public long getResponseId() { - return 0; - } + @Override + public long getResponseId() { + return 0; + } - @Override - public Row getResponse() { - return row; - } + @Override + public Row getResponse() { + return row; + } - @Override - public ResponseType getType() { - return ResponseType.Vertex; - } + @Override + public ResponseType getType() { + return ResponseType.Vertex; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmRuntimeContext.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmRuntimeContext.java index 7696b4f10..a3bdb3663 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmRuntimeContext.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmRuntimeContext.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; + import org.apache.geaflow.api.graph.function.aggregate.VertexCentricAggContextFunction.VertexCentricAggContext; import org.apache.geaflow.api.graph.function.vc.VertexCentricTraversalFunction.TraversalEdgeQuery; import org.apache.geaflow.api.graph.function.vc.VertexCentricTraversalFunction.VertexCentricTraversalFuncContext; @@ -44,209 +45,206 @@ public class GeaFlowAlgorithmRuntimeContext implements AlgorithmRuntimeContext { - private final VertexCentricTraversalFuncContext traversalContext; - - protected VertexCentricAggContext aggContext; - - private final GraphSchema graphSchema; - private final TraversalEdgeQuery edgeQuery; - private final transient GeaFlowAlgorithmAggTraversalFunction traversalFunction; - private Object vertexId; - - private long lastSendAggMsgIterationId = -1L; - - public GeaFlowAlgorithmRuntimeContext( - GeaFlowAlgorithmAggTraversalFunction traversalFunction, - VertexCentricTraversalFuncContext traversalContext, - GraphSchema graphSchema) { - this.traversalFunction = traversalFunction; - this.traversalContext = traversalContext; - this.edgeQuery = traversalContext.edges(); - this.graphSchema = graphSchema; - this.aggContext = null; - } - - public void setVertexId(Object vertexId) { - this.vertexId = vertexId; - this.edgeQuery.withId(vertexId); - } - - @SuppressWarnings("unchecked") - @Override - public List loadEdges(EdgeDirection direction) { - switch (direction) { - case OUT: - return (List) edgeQuery.getOutEdges(); - case IN: - return (List) edgeQuery.getInEdges(); - case BOTH: - List edges = new ArrayList<>(); - edges.addAll((List) edgeQuery.getOutEdges()); - edges.addAll((List) edgeQuery.getInEdges()); - return edges; - default: - throw new GeaFlowDSLException("Illegal edge direction: " + direction); - } - } - - @Override - public CloseableIterator loadEdgesIterator(EdgeDirection direction) { - switch (direction) { - case OUT: - return loadEdgesIterator(OutEdgeFilter.getInstance()); - case IN: - return loadEdgesIterator(InEdgeFilter.getInstance()); - case BOTH: - return loadEdgesIterator(EmptyFilter.getInstance()); - default: - throw new GeaFlowDSLException("Illegal edge direction: " + direction); - } - } - - @Override - public CloseableIterator loadEdgesIterator(IFilter filter) { - return (CloseableIterator) edgeQuery.getEdges(filter); - } - - @Override - public List loadStaticEdges(EdgeDirection direction) { - return loadEdges(direction); - } - - @Override - public CloseableIterator loadStaticEdgesIterator(EdgeDirection direction) { - switch (direction) { - case OUT: - return loadStaticEdgesIterator(OutEdgeFilter.getInstance()); - case IN: - return loadStaticEdgesIterator(InEdgeFilter.getInstance()); - case BOTH: - return loadStaticEdgesIterator(EmptyFilter.getInstance()); - default: - throw new GeaFlowDSLException("Illegal edge direction: " + direction); - } - } - - @Override - public CloseableIterator loadStaticEdgesIterator(IFilter filter) { - return (CloseableIterator) edgeQuery.getEdges(filter); - } - - @Override - public List loadDynamicEdges(EdgeDirection direction) { - throw new GeaflowRuntimeException("GeaFlowAlgorithmRuntimeContext not support loadDynamicEdges"); - } + private final VertexCentricTraversalFuncContext traversalContext; + + protected VertexCentricAggContext aggContext; + + private final GraphSchema graphSchema; + private final TraversalEdgeQuery edgeQuery; + private final transient GeaFlowAlgorithmAggTraversalFunction traversalFunction; + private Object vertexId; + + private long lastSendAggMsgIterationId = -1L; + + public GeaFlowAlgorithmRuntimeContext( + GeaFlowAlgorithmAggTraversalFunction traversalFunction, + VertexCentricTraversalFuncContext traversalContext, + GraphSchema graphSchema) { + this.traversalFunction = traversalFunction; + this.traversalContext = traversalContext; + this.edgeQuery = traversalContext.edges(); + this.graphSchema = graphSchema; + this.aggContext = null; + } + + public void setVertexId(Object vertexId) { + this.vertexId = vertexId; + this.edgeQuery.withId(vertexId); + } + + @SuppressWarnings("unchecked") + @Override + public List loadEdges(EdgeDirection direction) { + switch (direction) { + case OUT: + return (List) edgeQuery.getOutEdges(); + case IN: + return (List) edgeQuery.getInEdges(); + case BOTH: + List edges = new ArrayList<>(); + edges.addAll((List) edgeQuery.getOutEdges()); + edges.addAll((List) edgeQuery.getInEdges()); + return edges; + default: + throw new GeaFlowDSLException("Illegal edge direction: " + direction); + } + } + + @Override + public CloseableIterator loadEdgesIterator(EdgeDirection direction) { + switch (direction) { + case OUT: + return loadEdgesIterator(OutEdgeFilter.getInstance()); + case IN: + return loadEdgesIterator(InEdgeFilter.getInstance()); + case BOTH: + return loadEdgesIterator(EmptyFilter.getInstance()); + default: + throw new GeaFlowDSLException("Illegal edge direction: " + direction); + } + } + + @Override + public CloseableIterator loadEdgesIterator(IFilter filter) { + return (CloseableIterator) edgeQuery.getEdges(filter); + } + + @Override + public List loadStaticEdges(EdgeDirection direction) { + return loadEdges(direction); + } + + @Override + public CloseableIterator loadStaticEdgesIterator(EdgeDirection direction) { + switch (direction) { + case OUT: + return loadStaticEdgesIterator(OutEdgeFilter.getInstance()); + case IN: + return loadStaticEdgesIterator(InEdgeFilter.getInstance()); + case BOTH: + return loadStaticEdgesIterator(EmptyFilter.getInstance()); + default: + throw new GeaFlowDSLException("Illegal edge direction: " + direction); + } + } + + @Override + public CloseableIterator loadStaticEdgesIterator(IFilter filter) { + return (CloseableIterator) edgeQuery.getEdges(filter); + } + + @Override + public List loadDynamicEdges(EdgeDirection direction) { + throw new GeaflowRuntimeException( + "GeaFlowAlgorithmRuntimeContext not support loadDynamicEdges"); + } + + @Override + public CloseableIterator loadDynamicEdgesIterator(EdgeDirection direction) { + throw new GeaflowRuntimeException( + "GeaFlowAlgorithmRuntimeContext not support loadDynamicEdgesIterator"); + } + + @Override + public CloseableIterator loadDynamicEdgesIterator(IFilter filter) { + throw new GeaflowRuntimeException( + "GeaFlowAlgorithmRuntimeContext not support loadDynamicEdgesIterator"); + } + + @Override + public void sendMessage(Object vertexId, Object message) { + traversalContext.sendMessage(vertexId, message); + if (getCurrentIterationId() > lastSendAggMsgIterationId) { + lastSendAggMsgIterationId = getCurrentIterationId(); + aggContext.aggregate( + GeaFlowKVAlgorithmAggregateFunction.getAlgorithmAgg(lastSendAggMsgIterationId)); + } + } + + @Override + public void updateVertexValue(Row value) { + traversalFunction.updateVertexValue(vertexId, value); + } + + @Override + public void take(Row row) { + traversalContext.takeResponse(new AlgorithmResponse(row)); + } + + public void finish() {} + + public void close() {} + + public long getCurrentIterationId() { + return traversalContext.getIterationId(); + } + + @Override + public Configuration getConfig() { + return traversalContext.getRuntimeContext().getConfiguration(); + } + + @Override + public GraphSchema getGraphSchema() { + return graphSchema; + } + + public VertexCentricAggContext getAggContext() { + return aggContext; + } + + public void setAggContext(VertexCentricAggContext aggContext) { + this.aggContext = Objects.requireNonNull(aggContext); + } - @Override - public CloseableIterator loadDynamicEdgesIterator(EdgeDirection direction) { - throw new GeaflowRuntimeException("GeaFlowAlgorithmRuntimeContext not support loadDynamicEdgesIterator"); + @Override + public void voteToTerminate(String terminationReason, Object voteValue) { + // Send termination vote to coordinator through aggregation context + if (aggContext != null) { + aggContext.aggregate(new AlgorithmTerminationVote(terminationReason, voteValue)); } + } - @Override - public CloseableIterator loadDynamicEdgesIterator(IFilter filter) { - throw new GeaflowRuntimeException("GeaFlowAlgorithmRuntimeContext not support loadDynamicEdgesIterator"); - } + /** Internal class representing a termination vote sent to the coordinator. */ + private static class AlgorithmTerminationVote implements ITraversalAgg { + private final String terminationReason; + private final Object voteValue; - @Override - public void sendMessage(Object vertexId, Object message) { - traversalContext.sendMessage(vertexId, message); - if (getCurrentIterationId() > lastSendAggMsgIterationId) { - lastSendAggMsgIterationId = getCurrentIterationId(); - aggContext.aggregate(GeaFlowKVAlgorithmAggregateFunction.getAlgorithmAgg( - lastSendAggMsgIterationId)); - } + public AlgorithmTerminationVote(String terminationReason, Object voteValue) { + this.terminationReason = terminationReason; + this.voteValue = voteValue; } - @Override - public void updateVertexValue(Row value) { - traversalFunction.updateVertexValue(vertexId, value); - } - - @Override - public void take(Row row) { - traversalContext.takeResponse(new AlgorithmResponse(row)); + public String getTerminationReason() { + return terminationReason; } - public void finish() { - + public Object getVoteValue() { + return voteValue; } + } - public void close() { + private static class AlgorithmResponse implements ITraversalResponse { - } + private final Row row; - public long getCurrentIterationId() { - return traversalContext.getIterationId(); + public AlgorithmResponse(Row row) { + this.row = row; } @Override - public Configuration getConfig() { - return traversalContext.getRuntimeContext().getConfiguration(); + public long getResponseId() { + return 0; } @Override - public GraphSchema getGraphSchema() { - return graphSchema; - } - - public VertexCentricAggContext getAggContext() { - return aggContext; - } - - public void setAggContext(VertexCentricAggContext aggContext) { - this.aggContext = Objects.requireNonNull(aggContext); + public Row getResponse() { + return row; } @Override - public void voteToTerminate(String terminationReason, Object voteValue) { - // Send termination vote to coordinator through aggregation context - if (aggContext != null) { - aggContext.aggregate(new AlgorithmTerminationVote(terminationReason, voteValue)); - } - } - - /** - * Internal class representing a termination vote sent to the coordinator. - */ - private static class AlgorithmTerminationVote implements ITraversalAgg { - private final String terminationReason; - private final Object voteValue; - - public AlgorithmTerminationVote(String terminationReason, Object voteValue) { - this.terminationReason = terminationReason; - this.voteValue = voteValue; - } - - public String getTerminationReason() { - return terminationReason; - } - - public Object getVoteValue() { - return voteValue; - } - } - - private static class AlgorithmResponse implements ITraversalResponse { - - private final Row row; - - public AlgorithmResponse(Row row) { - this.row = row; - } - - @Override - public long getResponseId() { - return 0; - } - - @Override - public Row getResponse() { - return row; - } - - @Override - public ResponseType getType() { - return ResponseType.Vertex; - } + public ResponseType getType() { + return ResponseType.Vertex; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowCommonTraversalFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowCommonTraversalFunction.java index 9a37b326c..e0b398dc5 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowCommonTraversalFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowCommonTraversalFunction.java @@ -23,6 +23,7 @@ import java.util.Iterator; import java.util.List; import java.util.Objects; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.dsl.runtime.traversal.ExecuteDagGroup; @@ -38,106 +39,109 @@ public class GeaFlowCommonTraversalFunction { - private final ExecuteDagGroup executeDagGroup; + private final ExecuteDagGroup executeDagGroup; - private TraversalRuntimeContext context; + private TraversalRuntimeContext context; - private final boolean isTraversalAllWithRequest; + private final boolean isTraversalAllWithRequest; - private final List> initRequests = new ArrayList<>(); + private final List> initRequests = new ArrayList<>(); - public GeaFlowCommonTraversalFunction(ExecuteDagGroup executeDagGroup, boolean isTraversalAllWithRequest) { - this.executeDagGroup = Objects.requireNonNull(executeDagGroup); - this.isTraversalAllWithRequest = isTraversalAllWithRequest; - } + public GeaFlowCommonTraversalFunction( + ExecuteDagGroup executeDagGroup, boolean isTraversalAllWithRequest) { + this.executeDagGroup = Objects.requireNonNull(executeDagGroup); + this.isTraversalAllWithRequest = isTraversalAllWithRequest; + } - public void open(TraversalRuntimeContext context) { - this.context = Objects.requireNonNull(context); - this.executeDagGroup.open(context); - } + public void open(TraversalRuntimeContext context) { + this.context = Objects.requireNonNull(context); + this.executeDagGroup.open(context); + } - public void init(ITraversalRequest traversalRequest) { - initRequests.add(traversalRequest); - } + public void init(ITraversalRequest traversalRequest) { + initRequests.add(traversalRequest); + } - public void compute(Object vertexId, Iterator messageIterator) { - // Only one MessageBox in the iterator as we will combine the message in MessageCombineFunction. - MessageBox messageBox = messageIterator.next(); - if (vertexId instanceof BroadcastId) { - executeDagGroup.processBroadcast(messageBox); - } else { - context.setMessageBox(messageBox); - long[] receiveOpIds = messageBox.getReceiverIds(); - executeDagGroup.execute(vertexId, receiveOpIds); - } + public void compute(Object vertexId, Iterator messageIterator) { + // Only one MessageBox in the iterator as we will combine the message in MessageCombineFunction. + MessageBox messageBox = messageIterator.next(); + if (vertexId instanceof BroadcastId) { + executeDagGroup.processBroadcast(messageBox); + } else { + context.setMessageBox(messageBox); + long[] receiveOpIds = messageBox.getReceiverIds(); + executeDagGroup.execute(vertexId, receiveOpIds); } - - public void finish(long iterationId) { - if (isTraversalAllWithRequest && initRequests.size() > 0) { - try (CloseableIterator idIterator = context.loadAllVertex()) { - while (idIterator.hasNext()) { - Object vertexId = idIterator.next(); - MessageBox messageBox = MessageType.PARAMETER_REQUEST.createMessageBox(); - ParameterRequestMessage parameterMessage = new ParameterRequestMessage(); - - for (ITraversalRequest request : initRequests) { - assert Objects.equals(request.getVId(), TraversalAll.INSTANCE); - assert request instanceof InitParameterRequest; - InitParameterRequest initRequest = (InitParameterRequest) request; - // convert InitParameterRequest to ParameterRequest because ParameterRequest - // can support multi-key request id, however ITraversalRequest can only support - // Long type which is not enough for complex query, e.g. sub query request. - ParameterRequest parameterRequest = new ParameterRequest(initRequest.getRequestId(), - vertexId, initRequest.getParameters()); - parameterMessage.addRequest(parameterRequest); - } - messageBox.addMessage(executeDagGroup.getEntryOpId(), parameterMessage); - context.setMessageBox(messageBox); - executeDagGroup.execute(vertexId, executeDagGroup.getEntryOpId()); - } - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + } + + public void finish(long iterationId) { + if (isTraversalAllWithRequest && initRequests.size() > 0) { + try (CloseableIterator idIterator = context.loadAllVertex()) { + while (idIterator.hasNext()) { + Object vertexId = idIterator.next(); + MessageBox messageBox = MessageType.PARAMETER_REQUEST.createMessageBox(); + ParameterRequestMessage parameterMessage = new ParameterRequestMessage(); + + for (ITraversalRequest request : initRequests) { + assert Objects.equals(request.getVId(), TraversalAll.INSTANCE); + assert request instanceof InitParameterRequest; + InitParameterRequest initRequest = (InitParameterRequest) request; + // convert InitParameterRequest to ParameterRequest because ParameterRequest + // can support multi-key request id, however ITraversalRequest can only support + // Long type which is not enough for complex query, e.g. sub query request. + ParameterRequest parameterRequest = + new ParameterRequest( + initRequest.getRequestId(), vertexId, initRequest.getParameters()); + parameterMessage.addRequest(parameterRequest); + } + messageBox.addMessage(executeDagGroup.getEntryOpId(), parameterMessage); + context.setMessageBox(messageBox); + executeDagGroup.execute(vertexId, executeDagGroup.getEntryOpId()); + } + } catch (Exception e) { + throw new GeaflowRuntimeException(e); + } + } else { + for (ITraversalRequest request : initRequests) { + Object vertexId = request.getVId(); + if (request instanceof InitParameterRequest) { + InitParameterRequest initRequest = (InitParameterRequest) request; + MessageBox messageBox = MessageType.PARAMETER_REQUEST.createMessageBox(); + ParameterRequestMessage parameterMessage = new ParameterRequestMessage(); + + // convert InitParameterRequest to ParameterRequest because ParameterRequest + // can support multi-key request id, however ITraversalRequest can only support + // Long type which is not enough for complex query, e.g. sub query request. + ParameterRequest parameterRequest = + new ParameterRequest( + initRequest.getRequestId(), vertexId, initRequest.getParameters()); + parameterMessage.addRequest(parameterRequest); + messageBox.addMessage(executeDagGroup.getEntryOpId(), parameterMessage); + context.setMessageBox(messageBox); } else { - for (ITraversalRequest request : initRequests) { - Object vertexId = request.getVId(); - if (request instanceof InitParameterRequest) { - InitParameterRequest initRequest = (InitParameterRequest) request; - MessageBox messageBox = MessageType.PARAMETER_REQUEST.createMessageBox(); - ParameterRequestMessage parameterMessage = new ParameterRequestMessage(); - - // convert InitParameterRequest to ParameterRequest because ParameterRequest - // can support multi-key request id, however ITraversalRequest can only support - // Long type which is not enough for complex query, e.g. sub query request. - ParameterRequest parameterRequest = new ParameterRequest(initRequest.getRequestId(), - vertexId, initRequest.getParameters()); - parameterMessage.addRequest(parameterRequest); - messageBox.addMessage(executeDagGroup.getEntryOpId(), parameterMessage); - context.setMessageBox(messageBox); - } else { - context.setMessageBox(null); - } - executeDagGroup.execute(vertexId, executeDagGroup.getEntryOpId()); - } + context.setMessageBox(null); } - - initRequests.clear(); - executeDagGroup.finishIteration(iterationId); + executeDagGroup.execute(vertexId, executeDagGroup.getEntryOpId()); + } } - public void close() { - executeDagGroup.close(); - } + initRequests.clear(); + executeDagGroup.finishIteration(iterationId); + } - public ExecuteDagGroup getExecuteDagGroup() { - return executeDagGroup; - } + public void close() { + executeDagGroup.close(); + } - public TraversalRuntimeContext getContext() { - return context; - } + public ExecuteDagGroup getExecuteDagGroup() { + return executeDagGroup; + } - public List> getInitRequests() { - return initRequests; - } + public TraversalRuntimeContext getContext() { + return context; + } + + public List> getInitRequests() { + return initRequests; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowDynamicTraversalRuntimeContext.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowDynamicTraversalRuntimeContext.java index 44aed2f4f..ae194fa25 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowDynamicTraversalRuntimeContext.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowDynamicTraversalRuntimeContext.java @@ -33,55 +33,62 @@ public class GeaFlowDynamicTraversalRuntimeContext extends AbstractTraversalRuntimeContext { - private static final Logger LOGGER = LoggerFactory.getLogger(GeaFlowDynamicTraversalRuntimeContext.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(GeaFlowDynamicTraversalRuntimeContext.class); - private final IncVertexCentricTraversalFuncContext incVCTraversalCtx; + private final IncVertexCentricTraversalFuncContext + incVCTraversalCtx; - public GeaFlowDynamicTraversalRuntimeContext( - IncVertexCentricTraversalFuncContext incVCTraversalCtx) { - this.incVCTraversalCtx = incVCTraversalCtx; - TraversalGraphSnapShot graphSnapShot = incVCTraversalCtx.getHistoricalGraph() - .getSnapShot(0L); - this.vertexQuery = graphSnapShot.vertex(); - this.edgeQuery = graphSnapShot.edges(); - } + public GeaFlowDynamicTraversalRuntimeContext( + IncVertexCentricTraversalFuncContext + incVCTraversalCtx) { + this.incVCTraversalCtx = incVCTraversalCtx; + TraversalGraphSnapShot graphSnapShot = + incVCTraversalCtx.getHistoricalGraph().getSnapShot(0L); + this.vertexQuery = graphSnapShot.vertex(); + this.edgeQuery = graphSnapShot.edges(); + } - @Override - public Configuration getConfig() { - return incVCTraversalCtx.getRuntimeContext().getConfiguration(); - } + @Override + public Configuration getConfig() { + return incVCTraversalCtx.getRuntimeContext().getConfiguration(); + } - @Override - public long getIterationId() { - return incVCTraversalCtx.getIterationId(); - } + @Override + public long getIterationId() { + return incVCTraversalCtx.getIterationId(); + } - @Override - protected void sendBroadcastMessage(Object vertexId, MessageBox messageBox) { - incVCTraversalCtx.broadcast(new DefaultGraphMessage<>(vertexId, messageBox)); - } + @Override + protected void sendBroadcastMessage(Object vertexId, MessageBox messageBox) { + incVCTraversalCtx.broadcast(new DefaultGraphMessage<>(vertexId, messageBox)); + } - @Override - protected void sendMessage(Object vertexId, MessageBox messageBox) { - incVCTraversalCtx.sendMessage(vertexId, messageBox); - } + @Override + protected void sendMessage(Object vertexId, MessageBox messageBox) { + incVCTraversalCtx.sendMessage(vertexId, messageBox); + } - @Override - public void takePath(ITreePath treePath) { - incVCTraversalCtx.takeResponse(new TraversalResponse(treePath)); - } + @Override + public void takePath(ITreePath treePath) { + incVCTraversalCtx.takeResponse(new TraversalResponse(treePath)); + } - @Override - public void sendCoordinator(String name, Object value) { - LOGGER.info("task: {} send to coordinator {}:{} isAggTraversal:{}", getTaskIndex(), name, - value, aggContext != null); - if (aggContext != null) { - aggContext.aggregate(new KVTraversalAgg(name, value)); - } + @Override + public void sendCoordinator(String name, Object value) { + LOGGER.info( + "task: {} send to coordinator {}:{} isAggTraversal:{}", + getTaskIndex(), + name, + value, + aggContext != null); + if (aggContext != null) { + aggContext.aggregate(new KVTraversalAgg(name, value)); } + } - @Override - public RuntimeContext getRuntimeContext() { - return incVCTraversalCtx.getRuntimeContext(); - } + @Override + public RuntimeContext getRuntimeContext() { + return incVCTraversalCtx.getRuntimeContext(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowDynamicVCAggTraversal.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowDynamicVCAggTraversal.java index 7b2791eee..6eea88b9c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowDynamicVCAggTraversal.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowDynamicVCAggTraversal.java @@ -29,41 +29,53 @@ import org.apache.geaflow.dsl.runtime.traversal.message.MessageBox; import org.apache.geaflow.dsl.runtime.traversal.path.ITreePath; -public class GeaFlowDynamicVCAggTraversal extends - IncVertexCentricAggTraversal { +public class GeaFlowDynamicVCAggTraversal + extends IncVertexCentricAggTraversal< + Object, + Row, + Row, + MessageBox, + ITreePath, + ITraversalAgg, + ITraversalAgg, + ITraversalAgg, + ITraversalAgg, + ITraversalAgg> { - private final ExecuteDagGroup executeDagGroup; + private final ExecuteDagGroup executeDagGroup; - private final boolean isTraversalAllWithRequest; + private final boolean isTraversalAllWithRequest; - private final int parallelism; + private final int parallelism; - public GeaFlowDynamicVCAggTraversal(ExecuteDagGroup executeDagGroup, - int maxTraversal, - boolean isTraversalAllWithRequest, - int parallelism) { - super(maxTraversal); - this.executeDagGroup = executeDagGroup; - this.isTraversalAllWithRequest = isTraversalAllWithRequest; - assert parallelism > 0; - this.parallelism = parallelism; - } + public GeaFlowDynamicVCAggTraversal( + ExecuteDagGroup executeDagGroup, + int maxTraversal, + boolean isTraversalAllWithRequest, + int parallelism) { + super(maxTraversal); + this.executeDagGroup = executeDagGroup; + this.isTraversalAllWithRequest = isTraversalAllWithRequest; + assert parallelism > 0; + this.parallelism = parallelism; + } - @Override - public VertexCentricCombineFunction getCombineFunction() { - return new MessageBoxCombineFunction(); - } + @Override + public VertexCentricCombineFunction getCombineFunction() { + return new MessageBoxCombineFunction(); + } - @Override - public IncVertexCentricAggTraversalFunction getIncTraversalFunction() { - return new GeaFlowDynamicVCAggTraversalFunction(executeDagGroup, isTraversalAllWithRequest); - } + @Override + public IncVertexCentricAggTraversalFunction< + Object, Row, Row, MessageBox, ITreePath, ITraversalAgg, ITraversalAgg> + getIncTraversalFunction() { + return new GeaFlowDynamicVCAggTraversalFunction(executeDagGroup, isTraversalAllWithRequest); + } - @Override - public VertexCentricAggregateFunction getAggregateFunction() { - return (VertexCentricAggregateFunction) new GeaFlowKVTraversalAggregateFunction(parallelism); - } + @Override + public VertexCentricAggregateFunction< + ITraversalAgg, ITraversalAgg, ITraversalAgg, ITraversalAgg, ITraversalAgg> + getAggregateFunction() { + return (VertexCentricAggregateFunction) new GeaFlowKVTraversalAggregateFunction(parallelism); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowDynamicVCAggTraversalFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowDynamicVCAggTraversalFunction.java index be803ca30..588530e78 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowDynamicVCAggTraversalFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowDynamicVCAggTraversalFunction.java @@ -24,6 +24,7 @@ import java.util.Iterator; import java.util.List; import java.util.Objects; + import org.apache.geaflow.api.function.iterator.RichIteratorFunction; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricAggTraversalFunction; import org.apache.geaflow.dsl.common.data.Row; @@ -35,80 +36,75 @@ import org.apache.geaflow.model.graph.vertex.IVertex; import org.apache.geaflow.model.traversal.ITraversalRequest; -public class GeaFlowDynamicVCAggTraversalFunction implements - IncVertexCentricAggTraversalFunction, RichIteratorFunction { - - private GeaFlowDynamicTraversalRuntimeContext traversalRuntimeContext; - - private final GeaFlowCommonTraversalFunction commonFunction; - - private MutableGraph mutableGraph; - - public GeaFlowDynamicVCAggTraversalFunction(ExecuteDagGroup executeDagGroup, boolean isTraversalAllWithRequest) { - this.commonFunction = new GeaFlowCommonTraversalFunction(executeDagGroup, isTraversalAllWithRequest); - } - - @Override - public void open( - IncVertexCentricTraversalFuncContext vertexCentricFuncContext) { - traversalRuntimeContext = new GeaFlowDynamicTraversalRuntimeContext( - vertexCentricFuncContext); - this.mutableGraph = vertexCentricFuncContext.getMutableGraph(); - this.commonFunction.open(traversalRuntimeContext); - } - - @Override - public void evolve(Object vertexId, TemporaryGraph temporaryGraph) { - IVertex vertex = temporaryGraph.getVertex(); - if (vertex != null) { - mutableGraph.addVertex(GRAPH_VERSION, vertex); - } - List> edges = temporaryGraph.getEdges(); - if (edges != null) { - for (IEdge edge : edges) { - mutableGraph.addEdge(GRAPH_VERSION, edge); - } - } - } - - @Override - public void initIteration(long windowId) { - +public class GeaFlowDynamicVCAggTraversalFunction + implements IncVertexCentricAggTraversalFunction< + Object, Row, Row, MessageBox, ITreePath, ITraversalAgg, ITraversalAgg>, + RichIteratorFunction { + + private GeaFlowDynamicTraversalRuntimeContext traversalRuntimeContext; + + private final GeaFlowCommonTraversalFunction commonFunction; + + private MutableGraph mutableGraph; + + public GeaFlowDynamicVCAggTraversalFunction( + ExecuteDagGroup executeDagGroup, boolean isTraversalAllWithRequest) { + this.commonFunction = + new GeaFlowCommonTraversalFunction(executeDagGroup, isTraversalAllWithRequest); + } + + @Override + public void open( + IncVertexCentricTraversalFuncContext + vertexCentricFuncContext) { + traversalRuntimeContext = new GeaFlowDynamicTraversalRuntimeContext(vertexCentricFuncContext); + this.mutableGraph = vertexCentricFuncContext.getMutableGraph(); + this.commonFunction.open(traversalRuntimeContext); + } + + @Override + public void evolve(Object vertexId, TemporaryGraph temporaryGraph) { + IVertex vertex = temporaryGraph.getVertex(); + if (vertex != null) { + mutableGraph.addVertex(GRAPH_VERSION, vertex); } - - @Override - public void init(ITraversalRequest traversalRequest) { - commonFunction.init(traversalRequest); + List> edges = temporaryGraph.getEdges(); + if (edges != null) { + for (IEdge edge : edges) { + mutableGraph.addEdge(GRAPH_VERSION, edge); + } } + } - @Override - public void finish() { - - } + @Override + public void initIteration(long windowId) {} - @Override - public void close() { + @Override + public void init(ITraversalRequest traversalRequest) { + commonFunction.init(traversalRequest); + } - } + @Override + public void finish() {} - @Override - public void compute(Object vertexId, Iterator messageIterator) { - commonFunction.compute(vertexId, messageIterator); - } + @Override + public void close() {} - @Override - public void finishIteration(long windowId) { - commonFunction.finish(windowId); - } + @Override + public void compute(Object vertexId, Iterator messageIterator) { + commonFunction.compute(vertexId, messageIterator); + } - @Override - public void finish(Object vertexId, MutableGraph mutableGraph) { + @Override + public void finishIteration(long windowId) { + commonFunction.finish(windowId); + } - } + @Override + public void finish(Object vertexId, MutableGraph mutableGraph) {} - @Override - public void initContext(VertexCentricAggContext aggContext) { - this.traversalRuntimeContext.setAggContext(Objects.requireNonNull(aggContext)); - } + @Override + public void initContext(VertexCentricAggContext aggContext) { + this.traversalRuntimeContext.setAggContext(Objects.requireNonNull(aggContext)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowDynamicVCTraversal.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowDynamicVCTraversal.java index 0e0aca6e4..ef0781b9e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowDynamicVCTraversal.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowDynamicVCTraversal.java @@ -27,31 +27,35 @@ import org.apache.geaflow.dsl.runtime.traversal.message.MessageBox; import org.apache.geaflow.dsl.runtime.traversal.path.ITreePath; -public class GeaFlowDynamicVCTraversal extends IncVertexCentricTraversal { - - private final ExecuteDagGroup executeDagGroup; - - private final boolean isTraversalAllWithRequest; - - private final boolean enableIncrTraversal; - - public GeaFlowDynamicVCTraversal(ExecuteDagGroup executeDagGroup, - int maxTraversal, - boolean isTraversalAllWithRequest, - boolean enableIncrTraversal) { - super(maxTraversal); - this.executeDagGroup = executeDagGroup; - this.isTraversalAllWithRequest = isTraversalAllWithRequest; - this.enableIncrTraversal = enableIncrTraversal; - } - - @Override - public VertexCentricCombineFunction getCombineFunction() { - return new MessageBoxCombineFunction(); - } - - @Override - public IncVertexCentricTraversalFunction getIncTraversalFunction() { - return new GeaFlowDynamicVCTraversalFunction(executeDagGroup, isTraversalAllWithRequest, enableIncrTraversal); - } +public class GeaFlowDynamicVCTraversal + extends IncVertexCentricTraversal { + + private final ExecuteDagGroup executeDagGroup; + + private final boolean isTraversalAllWithRequest; + + private final boolean enableIncrTraversal; + + public GeaFlowDynamicVCTraversal( + ExecuteDagGroup executeDagGroup, + int maxTraversal, + boolean isTraversalAllWithRequest, + boolean enableIncrTraversal) { + super(maxTraversal); + this.executeDagGroup = executeDagGroup; + this.isTraversalAllWithRequest = isTraversalAllWithRequest; + this.enableIncrTraversal = enableIncrTraversal; + } + + @Override + public VertexCentricCombineFunction getCombineFunction() { + return new MessageBoxCombineFunction(); + } + + @Override + public IncVertexCentricTraversalFunction + getIncTraversalFunction() { + return new GeaFlowDynamicVCTraversalFunction( + executeDagGroup, isTraversalAllWithRequest, enableIncrTraversal); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowDynamicVCTraversalFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowDynamicVCTraversalFunction.java index 9049033e1..5697550ee 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowDynamicVCTraversalFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowDynamicVCTraversalFunction.java @@ -27,6 +27,7 @@ import java.util.Iterator; import java.util.List; import java.util.Set; + import org.apache.geaflow.api.function.iterator.RichIteratorFunction; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricTraversalFunction; import org.apache.geaflow.dsl.common.data.Row; @@ -55,193 +56,193 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class GeaFlowDynamicVCTraversalFunction implements - IncVertexCentricTraversalFunction, RichIteratorFunction { +public class GeaFlowDynamicVCTraversalFunction + implements IncVertexCentricTraversalFunction, + RichIteratorFunction { - private static final Logger LOGGER = LoggerFactory.getLogger(GeaFlowDynamicVCTraversalFunction.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(GeaFlowDynamicVCTraversalFunction.class); - private final GeaFlowCommonTraversalFunction commonFunction; + private final GeaFlowCommonTraversalFunction commonFunction; - private MutableGraph mutableGraph; + private MutableGraph mutableGraph; - int queryMaxIteration = 0; + int queryMaxIteration = 0; - List incrMessageEdgeFilter = new ArrayList<>(); + List incrMessageEdgeFilter = new ArrayList<>(); - boolean enableIncrTraversal; + boolean enableIncrTraversal; - Set evolveIds; + Set evolveIds; - public GeaFlowDynamicVCTraversalFunction(ExecuteDagGroup executeDagGroup, boolean isTraversalAllWithRequest, - boolean enableIncrTraversal) { - this.commonFunction = new GeaFlowCommonTraversalFunction(executeDagGroup, isTraversalAllWithRequest); - this.queryMaxIteration = executeDagGroup.getMaxIterationCount(); - this.enableIncrTraversal = enableIncrTraversal; - this.evolveIds = new HashSet<>(); - } + public GeaFlowDynamicVCTraversalFunction( + ExecuteDagGroup executeDagGroup, + boolean isTraversalAllWithRequest, + boolean enableIncrTraversal) { + this.commonFunction = + new GeaFlowCommonTraversalFunction(executeDagGroup, isTraversalAllWithRequest); + this.queryMaxIteration = executeDagGroup.getMaxIterationCount(); + this.enableIncrTraversal = enableIncrTraversal; + this.evolveIds = new HashSet<>(); + } - @Override - public void open( - IncVertexCentricTraversalFuncContext vertexCentricFuncContext) { - TraversalRuntimeContext traversalRuntimeContext = new GeaFlowDynamicTraversalRuntimeContext( - vertexCentricFuncContext); - this.mutableGraph = vertexCentricFuncContext.getMutableGraph(); - ((IncGraphVCTraversalCtxImpl) vertexCentricFuncContext).setEnableIncrMatch(enableIncrTraversal); - this.commonFunction.open(traversalRuntimeContext); - - setEdgeFilters(); - if (!incrMessageEdgeFilter.isEmpty()) { - // remove the last edgeFilter since 1 hop is already included in the incr edges. - incrMessageEdgeFilter.remove(incrMessageEdgeFilter.size() - 1); - Collections.reverse(incrMessageEdgeFilter); - } + @Override + public void open( + IncVertexCentricTraversalFuncContext + vertexCentricFuncContext) { + TraversalRuntimeContext traversalRuntimeContext = + new GeaFlowDynamicTraversalRuntimeContext(vertexCentricFuncContext); + this.mutableGraph = vertexCentricFuncContext.getMutableGraph(); + ((IncGraphVCTraversalCtxImpl) vertexCentricFuncContext).setEnableIncrMatch(enableIncrTraversal); + this.commonFunction.open(traversalRuntimeContext); + setEdgeFilters(); + if (!incrMessageEdgeFilter.isEmpty()) { + // remove the last edgeFilter since 1 hop is already included in the incr edges. + incrMessageEdgeFilter.remove(incrMessageEdgeFilter.size() - 1); + Collections.reverse(incrMessageEdgeFilter); } - - void setEdgeFilters() { - StepOperator step = commonFunction.getExecuteDagGroup().getMainDag().getEntryOperator(); - while (!step.getNextOperators().isEmpty()) { - List> nextOperators = step.getNextOperators(); - if (nextOperators.size() > 1) { - // ignore multi branch case - incrMessageEdgeFilter.clear(); - break; - } - - if (step.getClass() == MatchEdgeOperator.class) { - // set the evolveMessage edge directions according to the query. - EdgeDirection direction = ((MatchEdgeOperator) step).getFunction().getDirection(); - switch (direction) { - case OUT: - incrMessageEdgeFilter.add(InEdgeFilter.getInstance()); - break; - case IN: - incrMessageEdgeFilter.add(OutEdgeFilter.getInstance()); - break; - default: - incrMessageEdgeFilter.add(EmptyFilter.getInstance()); - } - } - - step = nextOperators.get(0); + } + + void setEdgeFilters() { + StepOperator step = commonFunction.getExecuteDagGroup().getMainDag().getEntryOperator(); + while (!step.getNextOperators().isEmpty()) { + List> nextOperators = step.getNextOperators(); + if (nextOperators.size() > 1) { + // ignore multi branch case + incrMessageEdgeFilter.clear(); + break; + } + + if (step.getClass() == MatchEdgeOperator.class) { + // set the evolveMessage edge directions according to the query. + EdgeDirection direction = ((MatchEdgeOperator) step).getFunction().getDirection(); + switch (direction) { + case OUT: + incrMessageEdgeFilter.add(InEdgeFilter.getInstance()); + break; + case IN: + incrMessageEdgeFilter.add(OutEdgeFilter.getInstance()); + break; + default: + incrMessageEdgeFilter.add(EmptyFilter.getInstance()); } - } - - @Override - public void evolve(Object vertexId, TemporaryGraph temporaryGraph) { - IVertex vertex = temporaryGraph.getVertex(); - if (vertex != null) { - mutableGraph.addVertex(GRAPH_VERSION, vertex); - } - List> edges = temporaryGraph.getEdges(); - if (edges != null) { - for (IEdge edge : edges) { - mutableGraph.addEdge(GRAPH_VERSION, edge); - } - } - - evolveIds.add(vertexId); - } - - @Override - public void initIteration(long windowId) { + } + step = nextOperators.get(0); } + } - @Override - public void init(ITraversalRequest traversalRequest) { - commonFunction.init(traversalRequest); + @Override + public void evolve(Object vertexId, TemporaryGraph temporaryGraph) { + IVertex vertex = temporaryGraph.getVertex(); + if (vertex != null) { + mutableGraph.addVertex(GRAPH_VERSION, vertex); } - - @Override - public void finish() { - evolveIds.clear(); + List> edges = temporaryGraph.getEdges(); + if (edges != null) { + for (IEdge edge : edges) { + mutableGraph.addEdge(GRAPH_VERSION, edge); + } } - @Override - public void close() { - + evolveIds.add(vertexId); + } + + @Override + public void initIteration(long windowId) {} + + @Override + public void init(ITraversalRequest traversalRequest) { + commonFunction.init(traversalRequest); + } + + @Override + public void finish() { + evolveIds.clear(); + } + + @Override + public void close() {} + + private void sendEvolveMessage(Object vertexId, TraversalRuntimeContext context) { + context.setVertex(IdOnlyVertex.of(vertexId)); + + IFilter edgeFilter = + incrMessageEdgeFilter.isEmpty() + ? EmptyFilter.getInstance() + : incrMessageEdgeFilter.get((int) (context.getIterationId() - 1)); + EdgeGroup rowEdges = context.loadEdges(edgeFilter); + StepOperator operator = + commonFunction.getExecuteDagGroup().getMainDag().getEntryOperator(); + for (RowEdge edge : rowEdges) { + context.sendMessage(edge.getTargetId(), new EvolveVertexMessage(), operator.getId()); } + } - private void sendEvolveMessage(Object vertexId, TraversalRuntimeContext context) { - context.setVertex(IdOnlyVertex.of(vertexId)); - - IFilter edgeFilter = incrMessageEdgeFilter.isEmpty() ? EmptyFilter.getInstance() : incrMessageEdgeFilter.get( - (int) (context.getIterationId() - 1)); - EdgeGroup rowEdges = context.loadEdges(edgeFilter); - StepOperator operator = commonFunction.getExecuteDagGroup() - .getMainDag() - .getEntryOperator(); - for (RowEdge edge : rowEdges) { - context.sendMessage(edge.getTargetId(), new EvolveVertexMessage(), operator.getId()); - } - } + private boolean needIncrTraversal() { + // if has initRequests, no need send EvolveMessage. + return enableIncrTraversal + && DynamicGraphHelper.enableIncrTraversalRuntime( + commonFunction.getContext().getRuntimeContext()); + } + + @Override + public void compute(Object vertexId, Iterator messageIterator) { + TraversalRuntimeContext context = commonFunction.getContext(); + if (needIncrTraversal() && !(vertexId instanceof BroadcastId)) { + long iterationId = context.getIterationId(); + // sendEvolveMessage to evolve subGraphs when iterationId is less than the plan iteration + if (iterationId < queryMaxIteration - 1) { + evolveIds.add(vertexId); + sendEvolveMessage(vertexId, context); + return; + } + if (iterationId == queryMaxIteration - 1) { + // the current iteration is the end of evolve phase. + evolveIds.add(vertexId); + return; + } + // traversal + commonFunction.compute(vertexId, messageIterator); - private boolean needIncrTraversal() { - // if has initRequests, no need send EvolveMessage. - return enableIncrTraversal && DynamicGraphHelper.enableIncrTraversalRuntime( - commonFunction.getContext().getRuntimeContext()); + } else { + commonFunction.compute(vertexId, messageIterator); } - - @Override - public void compute(Object vertexId, Iterator messageIterator) { - TraversalRuntimeContext context = commonFunction.getContext(); - if (needIncrTraversal() && !(vertexId instanceof BroadcastId)) { - long iterationId = context.getIterationId(); - // sendEvolveMessage to evolve subGraphs when iterationId is less than the plan iteration - if (iterationId < queryMaxIteration - 1) { - evolveIds.add(vertexId); - sendEvolveMessage(vertexId, context); - return; - } - - if (iterationId == queryMaxIteration - 1) { - // the current iteration is the end of evolve phase. - evolveIds.add(vertexId); - return; - } - // traversal - commonFunction.compute(vertexId, messageIterator); - - } else { - commonFunction.compute(vertexId, messageIterator); + } + + @Override + public void finishIteration(long iterationId) { + TraversalRuntimeContext context = commonFunction.getContext(); + if (needIncrTraversal()) { + if (iterationId == 1) { + // begin evolve + for (ITraversalRequest request : commonFunction.getInitRequests()) { + Object vertexId = request.getVId(); + sendEvolveMessage(vertexId, context); } - } - - - @Override - public void finishIteration(long iterationId) { - TraversalRuntimeContext context = commonFunction.getContext(); - if (needIncrTraversal()) { - if (iterationId == 1) { - // begin evolve - for (ITraversalRequest request : commonFunction.getInitRequests()) { - Object vertexId = request.getVId(); - sendEvolveMessage(vertexId, context); - } - commonFunction.getInitRequests().clear(); - } - - if (commonFunction.getContext().getIterationId() == queryMaxIteration - 1) { - // begin Traversal - LOGGER.info("StartEvolveIds: {}, {}, {}", evolveIds.size(), queryMaxIteration - 1, - context.getRuntimeContext().getWindowId()); - for (Object evolveId : evolveIds) { - ExecuteDagGroup executeDagGroup = commonFunction.getExecuteDagGroup(); - executeDagGroup.execute(evolveId, executeDagGroup.getEntryOpId()); - } - - } - - } else { - commonFunction.finish(iterationId); + commonFunction.getInitRequests().clear(); + } + + if (commonFunction.getContext().getIterationId() == queryMaxIteration - 1) { + // begin Traversal + LOGGER.info( + "StartEvolveIds: {}, {}, {}", + evolveIds.size(), + queryMaxIteration - 1, + context.getRuntimeContext().getWindowId()); + for (Object evolveId : evolveIds) { + ExecuteDagGroup executeDagGroup = commonFunction.getExecuteDagGroup(); + executeDagGroup.execute(evolveId, executeDagGroup.getEntryOpId()); } + } + } else { + commonFunction.finish(iterationId); } + } - @Override - public void finish(Object vertexId, MutableGraph mutableGraph) { - } - + @Override + public void finish(Object vertexId, MutableGraph mutableGraph) {} } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowGqlClient.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowGqlClient.java index 5fa39d907..d7e313378 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowGqlClient.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowGqlClient.java @@ -21,13 +21,12 @@ import static org.apache.geaflow.cluster.constants.ClusterConstants.CLUSTER_TYPE; -import com.google.common.reflect.TypeToken; -import com.google.gson.Gson; import java.io.IOException; import java.nio.charset.Charset; import java.util.HashMap; import java.util.Locale; import java.util.Map; + import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.env.Environment; @@ -37,66 +36,68 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.reflect.TypeToken; +import com.google.gson.Gson; + public class GeaFlowGqlClient { - private static final Logger LOGGER = LoggerFactory.getLogger("GeaFlowGqlClient"); + private static final Logger LOGGER = LoggerFactory.getLogger("GeaFlowGqlClient"); - private static final String CONF_FILE_NAME = "user.conf"; + private static final String CONF_FILE_NAME = "user.conf"; - public static void main(String[] args) throws Exception { - for (int i = 0; i < args.length; i++) { - LOGGER.info("args[{}]: {}", i, args[i]); - } - Environment environment = loadEnvironment(args); - Map parallelismConfigMap = loadParallelismConfig(); - LOGGER.info("parallelism config map: {}", parallelismConfigMap); - int timeWait = -1; // No wait for remote mode. - if (environment.getEnvType() == EnvType.LOCAL) { - timeWait = 0; // Infinite wait for local test. - Map localConfig = new HashMap<>(); - if (!environment.getEnvironmentContext().getConfig().contains(FileConfigKeys.ROOT.getKey())) { - localConfig.put(FileConfigKeys.ROOT.getKey(), "/tmp/dsl/"); - } - environment.getEnvironmentContext().withConfig(localConfig); - } - GQLPipeLine pipeLine = new GQLPipeLine(environment, timeWait, parallelismConfigMap); - pipeLine.execute(); + public static void main(String[] args) throws Exception { + for (int i = 0; i < args.length; i++) { + LOGGER.info("args[{}]: {}", i, args[i]); + } + Environment environment = loadEnvironment(args); + Map parallelismConfigMap = loadParallelismConfig(); + LOGGER.info("parallelism config map: {}", parallelismConfigMap); + int timeWait = -1; // No wait for remote mode. + if (environment.getEnvType() == EnvType.LOCAL) { + timeWait = 0; // Infinite wait for local test. + Map localConfig = new HashMap<>(); + if (!environment.getEnvironmentContext().getConfig().contains(FileConfigKeys.ROOT.getKey())) { + localConfig.put(FileConfigKeys.ROOT.getKey(), "/tmp/dsl/"); + } + environment.getEnvironmentContext().withConfig(localConfig); } + GQLPipeLine pipeLine = new GQLPipeLine(environment, timeWait, parallelismConfigMap); + pipeLine.execute(); + } - private static Map loadParallelismConfig() { - try { - String parallelismConf = IOUtils.resourceToString(CONF_FILE_NAME, - Charset.defaultCharset(), GeaFlowGqlClient.class.getClassLoader()); - Gson gson = new Gson(); - return gson.fromJson(parallelismConf, new TypeToken>() { - }.getType()); - } catch (IOException e) { - if (!e.getMessage().contains("Resource not found")) { - LOGGER.warn("Error in load parallelism config file", e); - } - return new HashMap<>(); - } + private static Map loadParallelismConfig() { + try { + String parallelismConf = + IOUtils.resourceToString( + CONF_FILE_NAME, Charset.defaultCharset(), GeaFlowGqlClient.class.getClassLoader()); + Gson gson = new Gson(); + return gson.fromJson(parallelismConf, new TypeToken>() {}.getType()); + } catch (IOException e) { + if (!e.getMessage().contains("Resource not found")) { + LOGGER.warn("Error in load parallelism config file", e); + } + return new HashMap<>(); } + } - private static EnvType getClusterType() { - String clusterType = System.getProperty(CLUSTER_TYPE); - if (StringUtils.isBlank(clusterType)) { - LOGGER.warn("use LOCAL as default cluster"); - return EnvType.LOCAL; - } - return (EnvType.valueOf(clusterType.toUpperCase(Locale.ROOT))); + private static EnvType getClusterType() { + String clusterType = System.getProperty(CLUSTER_TYPE); + if (StringUtils.isBlank(clusterType)) { + LOGGER.warn("use LOCAL as default cluster"); + return EnvType.LOCAL; } + return (EnvType.valueOf(clusterType.toUpperCase(Locale.ROOT))); + } - public static Environment loadEnvironment(String[] args) { - EnvType clusterType = getClusterType(); - switch (clusterType) { - case K8S: - return EnvironmentFactory.onK8SEnvironment(args); - case RAY: - return EnvironmentFactory.onRayEnvironment(args); - default: - return EnvironmentFactory.onLocalEnvironment(args); - } + public static Environment loadEnvironment(String[] args) { + EnvType clusterType = getClusterType(); + switch (clusterType) { + case K8S: + return EnvironmentFactory.onK8SEnvironment(args); + case RAY: + return EnvironmentFactory.onRayEnvironment(args); + default: + return EnvironmentFactory.onLocalEnvironment(args); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowKVAlgorithmAggregateFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowKVAlgorithmAggregateFunction.java index 8e0e9984d..1ccf1fdf8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowKVAlgorithmAggregateFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowKVAlgorithmAggregateFunction.java @@ -21,109 +21,119 @@ import java.util.Map.Entry; import java.util.Objects; + import org.apache.geaflow.dsl.runtime.traversal.message.KVTraversalAgg; public class GeaFlowKVAlgorithmAggregateFunction extends GeaFlowKVTraversalAggregateFunction { - private static final String ALGORITHM_ITERATION_PREFIX = "AlgorithmIteration-"; - - public GeaFlowKVAlgorithmAggregateFunction(int parallelism) { - super(parallelism); - } - - public static KVTraversalAgg getAlgorithmAgg(long iteration) { - return new KVTraversalAgg(ALGORITHM_ITERATION_PREFIX + iteration, 1); - } - - @Override - public IPartialGraphAggFunction, - KVTraversalAgg, KVTraversalAgg> getPartialAggregation() { - return new IPartialGraphAggFunction, - KVTraversalAgg, KVTraversalAgg>() { - - private IPartialAggContext> partialAggContext; - - @Override - public KVTraversalAgg create( - IPartialAggContext> partialAggContext) { - this.partialAggContext = Objects.requireNonNull(partialAggContext); - return KVTraversalAgg.empty(); - } - - - @Override - public KVTraversalAgg aggregate(KVTraversalAgg iterm, - KVTraversalAgg result) { - if (iterm == null) { - return result; - } else if (result == null) { - return iterm; - } - for (Entry entryObj : iterm.getMap().entrySet()) { - String key = entryObj.getKey(); - if (result.getMap().containsKey(key)) { - result.getMap().put(key, result.getMap().get(key) + entryObj.getValue()); - } else { - result.getMap().put(key, entryObj.getValue()); - } - } - return result; - } - - @Override - public void finish(KVTraversalAgg result) { - assert partialAggContext != null; - if (result != null) { - KVTraversalAgg tmpResult = result.copy(); - partialAggContext.collect(tmpResult); - } - } - }; - } - - @Override - public IGraphAggregateFunction, KVTraversalAgg, - KVTraversalAgg> getGlobalAggregation() { - return new IGraphAggregateFunction, - KVTraversalAgg, KVTraversalAgg>() { - - private IGlobalGraphAggContext> globalGraphAggContext; - - @Override - public KVTraversalAgg create( - IGlobalGraphAggContext> globalGraphAggContext) { - this.globalGraphAggContext = Objects.requireNonNull(globalGraphAggContext); - return KVTraversalAgg.empty(); - } - - @Override - public KVTraversalAgg aggregate(KVTraversalAgg iterm, - KVTraversalAgg result) { - if (iterm == null) { - return result; - } else if (result == null) { - return iterm; - } - for (Entry entryObj : iterm.getMap().entrySet()) { - String key = entryObj.getKey(); - if (result.getMap().containsKey(key)) { - result.getMap().put(key, result.getMap().get(key) + entryObj.getValue()); - } else { - result.getMap().put(key, entryObj.getValue()); - } - } - return result; - } - - @Override - public void finish(KVTraversalAgg value) { - assert globalGraphAggContext != null; - long currentIteration = globalGraphAggContext.getIteration(); - String key = ALGORITHM_ITERATION_PREFIX + currentIteration; - if (value == null || value.get(key) == null || value.get(key) == 0) { - globalGraphAggContext.terminate(); - } - } - }; - } + private static final String ALGORITHM_ITERATION_PREFIX = "AlgorithmIteration-"; + + public GeaFlowKVAlgorithmAggregateFunction(int parallelism) { + super(parallelism); + } + + public static KVTraversalAgg getAlgorithmAgg(long iteration) { + return new KVTraversalAgg(ALGORITHM_ITERATION_PREFIX + iteration, 1); + } + + @Override + public IPartialGraphAggFunction< + KVTraversalAgg, + KVTraversalAgg, + KVTraversalAgg> + getPartialAggregation() { + return new IPartialGraphAggFunction< + KVTraversalAgg, + KVTraversalAgg, + KVTraversalAgg>() { + + private IPartialAggContext> partialAggContext; + + @Override + public KVTraversalAgg create( + IPartialAggContext> partialAggContext) { + this.partialAggContext = Objects.requireNonNull(partialAggContext); + return KVTraversalAgg.empty(); + } + + @Override + public KVTraversalAgg aggregate( + KVTraversalAgg iterm, KVTraversalAgg result) { + if (iterm == null) { + return result; + } else if (result == null) { + return iterm; + } + for (Entry entryObj : iterm.getMap().entrySet()) { + String key = entryObj.getKey(); + if (result.getMap().containsKey(key)) { + result.getMap().put(key, result.getMap().get(key) + entryObj.getValue()); + } else { + result.getMap().put(key, entryObj.getValue()); + } + } + return result; + } + + @Override + public void finish(KVTraversalAgg result) { + assert partialAggContext != null; + if (result != null) { + KVTraversalAgg tmpResult = result.copy(); + partialAggContext.collect(tmpResult); + } + } + }; + } + + @Override + public IGraphAggregateFunction< + KVTraversalAgg, + KVTraversalAgg, + KVTraversalAgg> + getGlobalAggregation() { + return new IGraphAggregateFunction< + KVTraversalAgg, + KVTraversalAgg, + KVTraversalAgg>() { + + private IGlobalGraphAggContext> globalGraphAggContext; + + @Override + public KVTraversalAgg create( + IGlobalGraphAggContext> globalGraphAggContext) { + this.globalGraphAggContext = Objects.requireNonNull(globalGraphAggContext); + return KVTraversalAgg.empty(); + } + + @Override + public KVTraversalAgg aggregate( + KVTraversalAgg iterm, KVTraversalAgg result) { + if (iterm == null) { + return result; + } else if (result == null) { + return iterm; + } + for (Entry entryObj : iterm.getMap().entrySet()) { + String key = entryObj.getKey(); + if (result.getMap().containsKey(key)) { + result.getMap().put(key, result.getMap().get(key) + entryObj.getValue()); + } else { + result.getMap().put(key, entryObj.getValue()); + } + } + return result; + } + + @Override + public void finish(KVTraversalAgg value) { + assert globalGraphAggContext != null; + long currentIteration = globalGraphAggContext.getIteration(); + String key = ALGORITHM_ITERATION_PREFIX + currentIteration; + if (value == null || value.get(key) == null || value.get(key) == 0) { + globalGraphAggContext.terminate(); + } + } + }; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowKVTraversalAggregateFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowKVTraversalAggregateFunction.java index 419cf26e7..d6039e747 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowKVTraversalAggregateFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowKVTraversalAggregateFunction.java @@ -21,110 +21,124 @@ import java.util.Map.Entry; import java.util.Objects; + import org.apache.geaflow.api.graph.function.vc.VertexCentricAggregateFunction; import org.apache.geaflow.dsl.runtime.traversal.collector.StepEndCollector; import org.apache.geaflow.dsl.runtime.traversal.message.KVTraversalAgg; -public class GeaFlowKVTraversalAggregateFunction implements VertexCentricAggregateFunction - , KVTraversalAgg, - KVTraversalAgg, KVTraversalAgg, +public class GeaFlowKVTraversalAggregateFunction + implements VertexCentricAggregateFunction< + KVTraversalAgg, + KVTraversalAgg, + KVTraversalAgg, + KVTraversalAgg, KVTraversalAgg> { - private final int parallelism; - - public GeaFlowKVTraversalAggregateFunction(int parallelism) { - assert parallelism > 0 : "GeaFlowKVTraversalAggregateFunction parallelism <= 0"; - this.parallelism = parallelism; - } - - @Override - public IPartialGraphAggFunction, - KVTraversalAgg, KVTraversalAgg> getPartialAggregation() { - return new IPartialGraphAggFunction, - KVTraversalAgg, KVTraversalAgg>() { - - private IPartialAggContext> partialAggContext; - - @Override - public KVTraversalAgg create( - IPartialAggContext> partialAggContext) { - this.partialAggContext = Objects.requireNonNull(partialAggContext); - return KVTraversalAgg.empty(); - } - - - @Override - public KVTraversalAgg aggregate(KVTraversalAgg iterm, - KVTraversalAgg result) { - if (iterm == null) { - return result; - } else if (result == null) { - return iterm; - } - for (Entry entryObj : iterm.getMap().entrySet()) { - String key = entryObj.getKey(); - if (result.getMap().containsKey(key)) { - result.getMap().put(key, result.getMap().get(key) + entryObj.getValue()); - } else { - result.getMap().put(key, entryObj.getValue()); - } - } - return result; - } - - @Override - public void finish(KVTraversalAgg result) { - assert partialAggContext != null; - if (result != null) { - KVTraversalAgg tmpResult = result.copy(); - partialAggContext.collect(tmpResult); - } - } - }; - } - - @Override - public IGraphAggregateFunction, KVTraversalAgg, - KVTraversalAgg> getGlobalAggregation() { - return new IGraphAggregateFunction, - KVTraversalAgg, KVTraversalAgg>() { - - private IGlobalGraphAggContext> globalGraphAggContext; - - @Override - public KVTraversalAgg create( - IGlobalGraphAggContext> globalGraphAggContext) { - this.globalGraphAggContext = Objects.requireNonNull(globalGraphAggContext); - return KVTraversalAgg.empty(); - } - - @Override - public KVTraversalAgg aggregate(KVTraversalAgg iterm, - KVTraversalAgg result) { - if (iterm == null) { - return result; - } else if (result == null) { - return iterm; - } - for (Entry entryObj : iterm.getMap().entrySet()) { - String key = entryObj.getKey(); - if (result.getMap().containsKey(key)) { - result.getMap().put(key, result.getMap().get(key) + entryObj.getValue()); - } else { - result.getMap().put(key, entryObj.getValue()); - } - } - return result; - } - - @Override - public void finish(KVTraversalAgg value) { - assert globalGraphAggContext != null; - if (value != null && value.getMap().containsKey(StepEndCollector.TRAVERSAL_FINISH) - && value.get(StepEndCollector.TRAVERSAL_FINISH) >= parallelism) { - globalGraphAggContext.terminate(); - } - } - }; - } + private final int parallelism; + + public GeaFlowKVTraversalAggregateFunction(int parallelism) { + assert parallelism > 0 : "GeaFlowKVTraversalAggregateFunction parallelism <= 0"; + this.parallelism = parallelism; + } + + @Override + public IPartialGraphAggFunction< + KVTraversalAgg, + KVTraversalAgg, + KVTraversalAgg> + getPartialAggregation() { + return new IPartialGraphAggFunction< + KVTraversalAgg, + KVTraversalAgg, + KVTraversalAgg>() { + + private IPartialAggContext> partialAggContext; + + @Override + public KVTraversalAgg create( + IPartialAggContext> partialAggContext) { + this.partialAggContext = Objects.requireNonNull(partialAggContext); + return KVTraversalAgg.empty(); + } + + @Override + public KVTraversalAgg aggregate( + KVTraversalAgg iterm, KVTraversalAgg result) { + if (iterm == null) { + return result; + } else if (result == null) { + return iterm; + } + for (Entry entryObj : iterm.getMap().entrySet()) { + String key = entryObj.getKey(); + if (result.getMap().containsKey(key)) { + result.getMap().put(key, result.getMap().get(key) + entryObj.getValue()); + } else { + result.getMap().put(key, entryObj.getValue()); + } + } + return result; + } + + @Override + public void finish(KVTraversalAgg result) { + assert partialAggContext != null; + if (result != null) { + KVTraversalAgg tmpResult = result.copy(); + partialAggContext.collect(tmpResult); + } + } + }; + } + + @Override + public IGraphAggregateFunction< + KVTraversalAgg, + KVTraversalAgg, + KVTraversalAgg> + getGlobalAggregation() { + return new IGraphAggregateFunction< + KVTraversalAgg, + KVTraversalAgg, + KVTraversalAgg>() { + + private IGlobalGraphAggContext> globalGraphAggContext; + + @Override + public KVTraversalAgg create( + IGlobalGraphAggContext> globalGraphAggContext) { + this.globalGraphAggContext = Objects.requireNonNull(globalGraphAggContext); + return KVTraversalAgg.empty(); + } + + @Override + public KVTraversalAgg aggregate( + KVTraversalAgg iterm, KVTraversalAgg result) { + if (iterm == null) { + return result; + } else if (result == null) { + return iterm; + } + for (Entry entryObj : iterm.getMap().entrySet()) { + String key = entryObj.getKey(); + if (result.getMap().containsKey(key)) { + result.getMap().put(key, result.getMap().get(key) + entryObj.getValue()); + } else { + result.getMap().put(key, entryObj.getValue()); + } + } + return result; + } + + @Override + public void finish(KVTraversalAgg value) { + assert globalGraphAggContext != null; + if (value != null + && value.getMap().containsKey(StepEndCollector.TRAVERSAL_FINISH) + && value.get(StepEndCollector.TRAVERSAL_FINISH) >= parallelism) { + globalGraphAggContext.terminate(); + } + } + }; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowQueryEngine.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowQueryEngine.java index 7969a374c..f9af18e87 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowQueryEngine.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowQueryEngine.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; + import org.apache.geaflow.api.function.internal.CollectionSource; import org.apache.geaflow.api.function.io.SourceFunction; import org.apache.geaflow.api.pdata.stream.window.PWindowSource; @@ -59,175 +60,189 @@ import org.apache.geaflow.view.graph.GraphViewDesc; import org.apache.geaflow.view.graph.PGraphView; -/** - * GeaFlow's implement of the {@link QueryEngine}. - */ +/** GeaFlow's implement of the {@link QueryEngine}. */ public class GeaFlowQueryEngine implements QueryEngine { - private final IPipelineJobContext pipelineContext; - - public GeaFlowQueryEngine(IPipelineJobContext pipelineContext) { - this.pipelineContext = pipelineContext; + private final IPipelineJobContext pipelineContext; + + public GeaFlowQueryEngine(IPipelineJobContext pipelineContext) { + this.pipelineContext = pipelineContext; + } + + @Override + public Map getConfig() { + return pipelineContext.getConfig().getConfigMap(); + } + + @Override + public IPipelineJobContext getContext() { + return pipelineContext; + } + + @Override + public RuntimeTable createRuntimeTable( + QueryContext context, GeaFlowTable table, Expression pushFiler) { + String tableType = table.getTableType(); + TableConnector connector = ConnectorFactory.loadConnector(tableType); + if (!(connector instanceof TableReadableConnector)) { + throw new GeaFlowDSLException("Table: '{}' is not readable", connector.getType()); } - - @Override - public Map getConfig() { - return pipelineContext.getConfig().getConfigMap(); + TableReadableConnector readableConnector = (TableReadableConnector) connector; + Configuration conf = + table.getConfigWithGlobal(pipelineContext.getConfig(), context.getSetOptions()); + TableSource tableSource = readableConnector.createSource(conf); + // push down filter to table source + if (pushFiler != null) { + List filters = pushFiler.splitByAnd(); + List partitionFilters = + filters.stream() + .filter(filter -> isPartitionFilter(filter, table)) + .collect(Collectors.toList()); + + Expression partitionFilterExp = + partitionFilters.isEmpty() + ? LiteralExpression.createBoolean(true) + : new AndExpression(partitionFilters); + PartitionFilter partitionFilter = + new PartitionFilterImpl(partitionFilterExp, table.getPartitionIndices()); + if (tableSource instanceof EnablePartitionPushDown) { + ((EnablePartitionPushDown) tableSource).setPartitionFilter(partitionFilter); + } } - @Override - public IPipelineJobContext getContext() { - return pipelineContext; - } + tableSource.init(conf, table.getTableSchema()); - @Override - public RuntimeTable createRuntimeTable(QueryContext context, GeaFlowTable table, Expression pushFiler) { - String tableType = table.getTableType(); - TableConnector connector = ConnectorFactory.loadConnector(tableType); - if (!(connector instanceof TableReadableConnector)) { - throw new GeaFlowDSLException("Table: '{}' is not readable", connector.getType()); - } - TableReadableConnector readableConnector = (TableReadableConnector) connector; - Configuration conf = table.getConfigWithGlobal(pipelineContext.getConfig(), context.getSetOptions()); - TableSource tableSource = readableConnector.createSource(conf); - // push down filter to table source - if (pushFiler != null) { - List filters = pushFiler.splitByAnd(); - List partitionFilters = filters.stream() - .filter(filter -> isPartitionFilter(filter, table)) - .collect(Collectors.toList()); - - Expression partitionFilterExp = - partitionFilters.isEmpty() ? LiteralExpression.createBoolean(true) : - new AndExpression(partitionFilters); - PartitionFilter partitionFilter = new PartitionFilterImpl(partitionFilterExp, - table.getPartitionIndices()); - if (tableSource instanceof EnablePartitionPushDown) { - ((EnablePartitionPushDown) tableSource).setPartitionFilter(partitionFilter); - } - } - - tableSource.init(conf, table.getTableSchema()); - - if (!(tableSource instanceof ISkipOpenAndClose)) { - tableSource.open(new DefaultRuntimeContext(pipelineContext.getConfig())); - } - String opName = PhysicRelNodeName.TABLE_SCAN.getName(table.getName()); - int parallelism = context.getConfigParallelisms(opName, -1); - if (parallelism == -1) { - parallelism = Configuration.getInteger(DSLConfigKeys.GEAFLOW_DSL_TABLE_PARALLELISM, - (Integer) DSLConfigKeys.GEAFLOW_DSL_TABLE_PARALLELISM.getDefaultValue(), table.getConfig()); - } - int numPartitions = tableSource.listPartitions(parallelism).size(); - int partitionsPerParallelism = conf.getInteger( - ConnectorConfigKeys.GEAFLOW_DSL_PARTITIONS_PER_SOURCE_PARALLELISM); - int sourceParallelism; - // If user has set source parallelism, use it. - if (conf.contains(DSLConfigKeys.GEAFLOW_DSL_SOURCE_PARALLELISM)) { - sourceParallelism = conf.getInteger(DSLConfigKeys.GEAFLOW_DSL_SOURCE_PARALLELISM); - } else { - if (numPartitions % partitionsPerParallelism > 0) { - sourceParallelism = numPartitions / partitionsPerParallelism + 1; - } else { - sourceParallelism = numPartitions / partitionsPerParallelism; - } - } - if (!(tableSource instanceof ISkipOpenAndClose)) { - tableSource.close(); - } - if (context.getConfigParallelisms(opName, sourceParallelism) == -1 - && Configuration.getInteger(DSLConfigKeys.GEAFLOW_DSL_TABLE_PARALLELISM, -1, table.getConfig()) == -1) { - parallelism = sourceParallelism; - } - GeaFlowTableSourceFunction sourceFunction; - if (conf.contains(DSLConfigKeys.GEAFLOW_DSL_CUSTOM_SOURCE_FUNCTION)) { - String customClassName = conf.getString(DSLConfigKeys.GEAFLOW_DSL_CUSTOM_SOURCE_FUNCTION); - try { - sourceFunction = (GeaFlowTableSourceFunction) ClassUtil.classForName(customClassName) + if (!(tableSource instanceof ISkipOpenAndClose)) { + tableSource.open(new DefaultRuntimeContext(pipelineContext.getConfig())); + } + String opName = PhysicRelNodeName.TABLE_SCAN.getName(table.getName()); + int parallelism = context.getConfigParallelisms(opName, -1); + if (parallelism == -1) { + parallelism = + Configuration.getInteger( + DSLConfigKeys.GEAFLOW_DSL_TABLE_PARALLELISM, + (Integer) DSLConfigKeys.GEAFLOW_DSL_TABLE_PARALLELISM.getDefaultValue(), + table.getConfig()); + } + int numPartitions = tableSource.listPartitions(parallelism).size(); + int partitionsPerParallelism = + conf.getInteger(ConnectorConfigKeys.GEAFLOW_DSL_PARTITIONS_PER_SOURCE_PARALLELISM); + int sourceParallelism; + // If user has set source parallelism, use it. + if (conf.contains(DSLConfigKeys.GEAFLOW_DSL_SOURCE_PARALLELISM)) { + sourceParallelism = conf.getInteger(DSLConfigKeys.GEAFLOW_DSL_SOURCE_PARALLELISM); + } else { + if (numPartitions % partitionsPerParallelism > 0) { + sourceParallelism = numPartitions / partitionsPerParallelism + 1; + } else { + sourceParallelism = numPartitions / partitionsPerParallelism; + } + } + if (!(tableSource instanceof ISkipOpenAndClose)) { + tableSource.close(); + } + if (context.getConfigParallelisms(opName, sourceParallelism) == -1 + && Configuration.getInteger( + DSLConfigKeys.GEAFLOW_DSL_TABLE_PARALLELISM, -1, table.getConfig()) + == -1) { + parallelism = sourceParallelism; + } + GeaFlowTableSourceFunction sourceFunction; + if (conf.contains(DSLConfigKeys.GEAFLOW_DSL_CUSTOM_SOURCE_FUNCTION)) { + String customClassName = conf.getString(DSLConfigKeys.GEAFLOW_DSL_CUSTOM_SOURCE_FUNCTION); + try { + sourceFunction = + (GeaFlowTableSourceFunction) + ClassUtil.classForName(customClassName) .getConstructor(GeaFlowTable.class, TableSource.class) .newInstance(table, tableSource); - } catch (Exception e) { - throw new GeaFlowDSLException("Cannot create sink function: {}.", - customClassName, e); - } - } else { - sourceFunction = new GeaFlowTableSourceFunction(table, tableSource); - } - - IWindow window = Windows.createWindow(conf); - PWindowSource source = pipelineContext.buildSource(sourceFunction, window) + } catch (Exception e) { + throw new GeaFlowDSLException("Cannot create sink function: {}.", customClassName, e); + } + } else { + sourceFunction = new GeaFlowTableSourceFunction(table, tableSource); + } + + IWindow window = Windows.createWindow(conf); + PWindowSource source = + pipelineContext + .buildSource(sourceFunction, window) .withConfig(context.getSetOptions()) .withName(opName) .withParallelism(parallelism); - return new GeaFlowRuntimeTable(context, pipelineContext, source); - } + return new GeaFlowRuntimeTable(context, pipelineContext, source); + } - @Override - public RuntimeTable createRuntimeTable(QueryContext context, Collection rows) { - IWindow window = Windows.createWindow(pipelineContext.getConfig()); - PWindowSource source = pipelineContext.buildSource(new CollectionSource<>(rows), window) + @Override + public RuntimeTable createRuntimeTable(QueryContext context, Collection rows) { + IWindow window = Windows.createWindow(pipelineContext.getConfig()); + PWindowSource source = + pipelineContext + .buildSource(new CollectionSource<>(rows), window) .withConfig(context.getSetOptions()) .withName(PhysicRelNodeName.VALUE_SCAN.getName(context.getOpNameCount())) .withParallelism(1); - return new GeaFlowRuntimeTable(context, pipelineContext, source); - } - - @Override - public PWindowSource createRuntimeTable(QueryContext context, SourceFunction sourceFunction) { - IWindow window = Windows.createWindow(pipelineContext.getConfig()); - return pipelineContext.buildSource(sourceFunction, window) - .withConfig(context.getSetOptions()); - } - - @SuppressWarnings("unchecked") - @Override - public RuntimeGraph createRuntimeGraph(QueryContext context, GeaFlowGraph graph) { - GraphViewDesc graphViewDesc = SchemaUtil.buildGraphViewDesc(graph, context.getGlobalConf()); - PGraphView pGraphView = pipelineContext.createGraphView(graphViewDesc); - return new GeaFlowRuntimeGraph(context, pGraphView, graph, graphViewDesc); - } - - public IPipelineJobContext getPipelineContext() { - return pipelineContext; + return new GeaFlowRuntimeTable(context, pipelineContext, source); + } + + @Override + public PWindowSource createRuntimeTable( + QueryContext context, SourceFunction sourceFunction) { + IWindow window = Windows.createWindow(pipelineContext.getConfig()); + return pipelineContext.buildSource(sourceFunction, window).withConfig(context.getSetOptions()); + } + + @SuppressWarnings("unchecked") + @Override + public RuntimeGraph createRuntimeGraph(QueryContext context, GeaFlowGraph graph) { + GraphViewDesc graphViewDesc = SchemaUtil.buildGraphViewDesc(graph, context.getGlobalConf()); + PGraphView pGraphView = pipelineContext.createGraphView(graphViewDesc); + return new GeaFlowRuntimeGraph(context, pGraphView, graph, graphViewDesc); + } + + public IPipelineJobContext getPipelineContext() { + return pipelineContext; + } + + private boolean isPartitionFilter(Expression filter, GeaFlowTable table) { + if (filter.getInputs().size() != 2) { + return false; } - - private boolean isPartitionFilter(Expression filter, GeaFlowTable table) { - if (filter.getInputs().size() != 2) { - return false; - } - Expression left = filter.getInputs().get(0); - Expression right = filter.getInputs().get(1); - if (left instanceof FieldExpression - && table.isPartitionField(((FieldExpression) left).getFieldIndex()) - && right instanceof LiteralExpression) { - return true; - } - return right instanceof FieldExpression - && table.isPartitionField(((FieldExpression) right).getFieldIndex()) - && left instanceof LiteralExpression; + Expression left = filter.getInputs().get(0); + Expression right = filter.getInputs().get(1); + if (left instanceof FieldExpression + && table.isPartitionField(((FieldExpression) left).getFieldIndex()) + && right instanceof LiteralExpression) { + return true; } + return right instanceof FieldExpression + && table.isPartitionField(((FieldExpression) right).getFieldIndex()) + && left instanceof LiteralExpression; + } - private static class PartitionFilterImpl implements PartitionFilter { + private static class PartitionFilterImpl implements PartitionFilter { - private final Expression filter; + private final Expression filter; - public PartitionFilterImpl(Expression filter, List partitionFields) { - this.filter = filter.replace(exp -> { + public PartitionFilterImpl(Expression filter, List partitionFields) { + this.filter = + filter.replace( + exp -> { if (exp instanceof FieldExpression) { - FieldExpression field = (FieldExpression) exp; - int indexOfPartitionField = partitionFields.indexOf(field.getFieldIndex()); - assert indexOfPartitionField >= 0 : "Not a partition field in partition filter"; - return field.copy(indexOfPartitionField); + FieldExpression field = (FieldExpression) exp; + int indexOfPartitionField = partitionFields.indexOf(field.getFieldIndex()); + assert indexOfPartitionField >= 0 : "Not a partition field in partition filter"; + return field.copy(indexOfPartitionField); } return exp; - }); - } - - @Override - public boolean apply(Row partition) { - Boolean accept = (Boolean) filter.evaluate(partition); - return accept != null && accept; - } + }); + } + + @Override + public boolean apply(Row partition) { + Boolean accept = (Boolean) filter.evaluate(partition); + return accept != null && accept; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowRuntimeGraph.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowRuntimeGraph.java index 168d1f32f..a7fdfc803 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowRuntimeGraph.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowRuntimeGraph.java @@ -19,13 +19,13 @@ package org.apache.geaflow.dsl.runtime.engine; -import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; + import org.apache.geaflow.api.collector.Collector; import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.RichFunction; @@ -79,447 +79,604 @@ import org.apache.geaflow.view.graph.PGraphView; import org.apache.geaflow.view.graph.PIncGraphView; -public class GeaFlowRuntimeGraph implements RuntimeGraph { - - private final QueryContext queryContext; - - private final IPipelineJobContext context; - - private final GeaFlowGraph graph; - - private final GraphSchema graphSchema; - - private final GraphViewDesc graphViewDesc; - - private final PGraphView graphView; - - private final StepLogicalPlanSet logicalPlanSet; - - public GeaFlowRuntimeGraph(QueryContext queryContext, - PGraphView graphView, - GeaFlowGraph graph, - StepLogicalPlanSet logicalPlanSet, - GraphViewDesc graphViewDesc) { - this.queryContext = Objects.requireNonNull(queryContext); - this.context = ((GeaFlowQueryEngine) queryContext.getEngineContext()).getPipelineContext(); - this.graphView = Objects.requireNonNull(graphView); - this.graph = Objects.requireNonNull(graph); - this.graphSchema = graph.getGraphSchema(queryContext.getGqlContext().getTypeFactory()); - this.logicalPlanSet = logicalPlanSet; - this.graphViewDesc = graphViewDesc; - } - - public GeaFlowRuntimeGraph(QueryContext queryContext, - PGraphView graphView, - GeaFlowGraph graph, - GraphViewDesc graphViewDesc) { - this(queryContext, graphView, graph, planSet(graph, queryContext), graphViewDesc); - } - - private static StepLogicalPlanSet planSet(GeaFlowGraph graph, QueryContext queryContext) { - return new StepLogicalPlanSet(graph.getGraphSchema(queryContext.getGqlContext().getTypeFactory())); - } - - @Override - public T getPlan() { - return getPathTable().getPlan(); - } - - @Override - public List take(IType type) { - return ArrayUtil.castList(getPathTable().take(logicalPlanSet.getMainPlan().getOutputPathSchema())); - } +import com.google.common.base.Preconditions; - @Override - public RuntimeGraph traversal(GraphMatch graphMatch) { - StepLogicalPlanTranslator planTranslator = new StepLogicalPlanTranslator(); - StepLogicalPlan logicalPlan = planTranslator.translate(graphMatch, logicalPlanSet); - logicalPlanSet.setMainPlan(logicalPlan); - return new GeaFlowRuntimeGraph(queryContext, graphView, graph, logicalPlanSet, graphViewDesc); - } +public class GeaFlowRuntimeGraph implements RuntimeGraph { - @Override - public RuntimeTable getPathTable() { - assert logicalPlanSet != null; - DagGroupBuilder builder = new DagGroupBuilder(); - ExecuteDagGroup executeDagGroup = builder.buildExecuteDagGroup(logicalPlanSet); - StepOperator mainOp = executeDagGroup.getMainDag().getEntryOperator(); - assert mainOp instanceof StepSourceOperator; - Set startIds = ((StepSourceOperator) mainOp).getStartIds(); - - Set parameterStartIds = startIds.stream() + private final QueryContext queryContext; + + private final IPipelineJobContext context; + + private final GeaFlowGraph graph; + + private final GraphSchema graphSchema; + + private final GraphViewDesc graphViewDesc; + + private final PGraphView graphView; + + private final StepLogicalPlanSet logicalPlanSet; + + public GeaFlowRuntimeGraph( + QueryContext queryContext, + PGraphView graphView, + GeaFlowGraph graph, + StepLogicalPlanSet logicalPlanSet, + GraphViewDesc graphViewDesc) { + this.queryContext = Objects.requireNonNull(queryContext); + this.context = ((GeaFlowQueryEngine) queryContext.getEngineContext()).getPipelineContext(); + this.graphView = Objects.requireNonNull(graphView); + this.graph = Objects.requireNonNull(graph); + this.graphSchema = graph.getGraphSchema(queryContext.getGqlContext().getTypeFactory()); + this.logicalPlanSet = logicalPlanSet; + this.graphViewDesc = graphViewDesc; + } + + public GeaFlowRuntimeGraph( + QueryContext queryContext, + PGraphView graphView, + GeaFlowGraph graph, + GraphViewDesc graphViewDesc) { + this(queryContext, graphView, graph, planSet(graph, queryContext), graphViewDesc); + } + + private static StepLogicalPlanSet planSet(GeaFlowGraph graph, QueryContext queryContext) { + return new StepLogicalPlanSet( + graph.getGraphSchema(queryContext.getGqlContext().getTypeFactory())); + } + + @Override + public T getPlan() { + return getPathTable().getPlan(); + } + + @Override + public List take(IType type) { + return ArrayUtil.castList( + getPathTable().take(logicalPlanSet.getMainPlan().getOutputPathSchema())); + } + + @Override + public RuntimeGraph traversal(GraphMatch graphMatch) { + StepLogicalPlanTranslator planTranslator = new StepLogicalPlanTranslator(); + StepLogicalPlan logicalPlan = planTranslator.translate(graphMatch, logicalPlanSet); + logicalPlanSet.setMainPlan(logicalPlan); + return new GeaFlowRuntimeGraph(queryContext, graphView, graph, logicalPlanSet, graphViewDesc); + } + + @Override + public RuntimeTable getPathTable() { + assert logicalPlanSet != null; + DagGroupBuilder builder = new DagGroupBuilder(); + ExecuteDagGroup executeDagGroup = builder.buildExecuteDagGroup(logicalPlanSet); + StepOperator mainOp = executeDagGroup.getMainDag().getEntryOperator(); + assert mainOp instanceof StepSourceOperator; + Set startIds = ((StepSourceOperator) mainOp).getStartIds(); + + Set parameterStartIds = + startIds.stream() .filter(id -> id instanceof ParameterStartId) .map(id -> (ParameterStartId) id) .collect(Collectors.toSet()); - Set constantStartIds = startIds.stream() + Set constantStartIds = + startIds.stream() .filter(id -> id instanceof ConstantStartId) .map(id -> ((ConstantStartId) id).getValue()) .collect(Collectors.toSet()); - int maxTraversal = context.getConfig().getInteger(DSLConfigKeys.GEAFLOW_DSL_MAX_TRAVERSAL); - int dagMaxTraversal = executeDagGroup.getMaxIterationCount(); + int maxTraversal = context.getConfig().getInteger(DSLConfigKeys.GEAFLOW_DSL_MAX_TRAVERSAL); + int dagMaxTraversal = executeDagGroup.getMaxIterationCount(); - boolean isAggTraversal = dagMaxTraversal == Integer.MAX_VALUE; - if (!isAggTraversal) { - maxTraversal = Math.max(0, Math.min(maxTraversal, dagMaxTraversal)); - } - int parallelism = (queryContext.getTraversalParallelism() > 0 - && queryContext.getTraversalParallelism() <= graph.getShardCount()) - ? queryContext.getTraversalParallelism() : graph.getShardCount(); - - PWindowStream> responsePWindow; - - assert graphView instanceof PIncGraphView : "Illegal graph view"; - queryContext.addMaterializedGraph(graph.getName()); - - PWindowStream vertexStream = queryContext.getGraphVertexStream(graph.getName()); - PWindowStream edgeStream = queryContext.getGraphEdgeStream(graph.getName()); - if (vertexStream == null && edgeStream == null) { // traversal on snapshot of the - // dynamic graph - PGraphWindow staticGraph = graphView.snapshot(graphViewDesc.getCurrentVersion()); - responsePWindow = staticGraphTraversal(staticGraph, parameterStartIds, - constantStartIds, executeDagGroup, maxTraversal, isAggTraversal, parallelism); - } else { // traversal on dynamic graph - Preconditions.checkArgument(graphViewDesc.getBackend() != BackendType.Paimon, - "paimon does not support dynamic graph traversal"); - - boolean enableIncrTraversal = DynamicGraphHelper.enableIncrTraversal(maxTraversal, startIds.size(), - context.getConfig()); - if (maxTraversal != Integer.MAX_VALUE) { - if (enableIncrTraversal) { - // Double the maxTraversal if is incrTraversal, need pre evolve subgraph. - // the evolve phase is 1 smaller than the query Iteration - maxTraversal = maxTraversal * 2 - 1; - } - } - vertexStream = vertexStream != null ? vertexStream : - queryContext.getEngineContext().createRuntimeTable(queryContext, Collections.emptyList()) - .getPlan(); - edgeStream = edgeStream != null ? edgeStream : - queryContext.getEngineContext().createRuntimeTable(queryContext, Collections.emptyList()) - .getPlan(); - - PIncGraphView dynamicGraph = graphView.appendGraph((PWindowStream) vertexStream, - (PWindowStream) edgeStream); - responsePWindow = dynamicGraphTraversal(dynamicGraph, parameterStartIds, constantStartIds, executeDagGroup, - maxTraversal, isAggTraversal, parallelism, enableIncrTraversal); + boolean isAggTraversal = dagMaxTraversal == Integer.MAX_VALUE; + if (!isAggTraversal) { + maxTraversal = Math.max(0, Math.min(maxTraversal, dagMaxTraversal)); + } + int parallelism = + (queryContext.getTraversalParallelism() > 0 + && queryContext.getTraversalParallelism() <= graph.getShardCount()) + ? queryContext.getTraversalParallelism() + : graph.getShardCount(); + + PWindowStream> responsePWindow; + + assert graphView instanceof PIncGraphView : "Illegal graph view"; + queryContext.addMaterializedGraph(graph.getName()); + + PWindowStream vertexStream = queryContext.getGraphVertexStream(graph.getName()); + PWindowStream edgeStream = queryContext.getGraphEdgeStream(graph.getName()); + if (vertexStream == null && edgeStream == null) { // traversal on snapshot of the + // dynamic graph + PGraphWindow staticGraph = + graphView.snapshot(graphViewDesc.getCurrentVersion()); + responsePWindow = + staticGraphTraversal( + staticGraph, + parameterStartIds, + constantStartIds, + executeDagGroup, + maxTraversal, + isAggTraversal, + parallelism); + } else { // traversal on dynamic graph + Preconditions.checkArgument( + graphViewDesc.getBackend() != BackendType.Paimon, + "paimon does not support dynamic graph traversal"); + + boolean enableIncrTraversal = + DynamicGraphHelper.enableIncrTraversal( + maxTraversal, startIds.size(), context.getConfig()); + if (maxTraversal != Integer.MAX_VALUE) { + if (enableIncrTraversal) { + // Double the maxTraversal if is incrTraversal, need pre evolve subgraph. + // the evolve phase is 1 smaller than the query Iteration + maxTraversal = maxTraversal * 2 - 1; } - responsePWindow.withParallelism(parallelism); - PWindowStream resultPWindow = responsePWindow.flatMap(new ResponseToRowFunction()) + } + vertexStream = + vertexStream != null + ? vertexStream + : queryContext + .getEngineContext() + .createRuntimeTable(queryContext, Collections.emptyList()) + .getPlan(); + edgeStream = + edgeStream != null + ? edgeStream + : queryContext + .getEngineContext() + .createRuntimeTable(queryContext, Collections.emptyList()) + .getPlan(); + + PIncGraphView dynamicGraph = + graphView.appendGraph((PWindowStream) vertexStream, (PWindowStream) edgeStream); + responsePWindow = + dynamicGraphTraversal( + dynamicGraph, + parameterStartIds, + constantStartIds, + executeDagGroup, + maxTraversal, + isAggTraversal, + parallelism, + enableIncrTraversal); + } + responsePWindow.withParallelism(parallelism); + PWindowStream resultPWindow = + responsePWindow + .flatMap(new ResponseToRowFunction()) .withName(queryContext.createOperatorName("TraversalResponseToRow")); - return new GeaFlowRuntimeTable(queryContext, context, resultPWindow); + return new GeaFlowRuntimeTable(queryContext, context, resultPWindow); + } + + private PWindowStream> staticGraphTraversal( + PGraphWindow staticGraph, + Set parameterStartIds, + Set constantStartIds, + ExecuteDagGroup executeDagGroup, + int maxTraversal, + boolean isAggTraversal, + int parallelism) { + PWindowStream> responsePWindow; + if (queryContext.getRequestTable() != null) { // traversal with request + RuntimeTable requestTable = queryContext.getRequestTable(); + boolean isIdOnlyRequest = queryContext.isIdOnlyRequest(); + + PWindowStream requestWindowStream = requestTable.getPlan(); + PWindowStream> parameterizedRequest; + boolean isTraversalAllWithRequest; + if (parameterStartIds.size() == 1) { // static request table attach the start id + parameterizedRequest = + requestWindowStream.map( + new RowToParameterRequestFunction( + parameterStartIds.iterator().next(), isIdOnlyRequest)); + isTraversalAllWithRequest = false; + } else { // static request table attach all the traversal ids. + parameterizedRequest = + requestWindowStream + .map(new RowToParameterRequestFunction(null, isIdOnlyRequest)) + .broadcast(); + isTraversalAllWithRequest = true; + } + responsePWindow = + ((PGraphTraversal) + getStaticVCTraversal( + isAggTraversal, + staticGraph, + executeDagGroup, + maxTraversal, + isTraversalAllWithRequest, + parallelism)) + .start((PWindowStream) parameterizedRequest); + + } else if (constantStartIds.size() > 0) { // static request with constant ids. + responsePWindow = + ((PGraphTraversal) + getStaticVCTraversal( + isAggTraversal, + staticGraph, + executeDagGroup, + maxTraversal, + false, + parallelism)) + .start(new ArrayList<>(constantStartIds)); + } else { // traversal all + boolean enableTraversalAllSplit = + queryContext.getGlobalConf().getBoolean(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE); + if (enableTraversalAllSplit) { + DynamicGraphVertexScanSourceFunction sourceFunction = + new DynamicGraphVertexScanSourceFunction<>(graphViewDesc); + PWindowSource source = + queryContext + .getEngineContext() + .createRuntimeTable(queryContext, sourceFunction) + .withParallelism(graphViewDesc.getShardNum()) + .withName(queryContext.createOperatorName("VertexScanSource")); + responsePWindow = + getStaticVCTraversal( + isAggTraversal, staticGraph, executeDagGroup, maxTraversal, false, parallelism) + .start((PWindowStream) source); + } else { + responsePWindow = + ((PGraphTraversal) + getStaticVCTraversal( + isAggTraversal, + staticGraph, + executeDagGroup, + maxTraversal, + false, + parallelism)) + .start(); + } } - - private PWindowStream> staticGraphTraversal( - PGraphWindow staticGraph, - Set parameterStartIds, - Set constantStartIds, - ExecuteDagGroup executeDagGroup, - int maxTraversal, - boolean isAggTraversal, - int parallelism) { - PWindowStream> responsePWindow; - if (queryContext.getRequestTable() != null) { // traversal with request - RuntimeTable requestTable = queryContext.getRequestTable(); - boolean isIdOnlyRequest = queryContext.isIdOnlyRequest(); - - PWindowStream requestWindowStream = requestTable.getPlan(); - PWindowStream> parameterizedRequest; - boolean isTraversalAllWithRequest; - if (parameterStartIds.size() == 1) { // static request table attach the start id - parameterizedRequest = requestWindowStream.map( - new RowToParameterRequestFunction(parameterStartIds.iterator().next(), isIdOnlyRequest)); - isTraversalAllWithRequest = false; - } else { // static request table attach all the traversal ids. - parameterizedRequest = requestWindowStream.map(new RowToParameterRequestFunction(null, isIdOnlyRequest)) - .broadcast(); - isTraversalAllWithRequest = true; - } - responsePWindow = - ((PGraphTraversal) getStaticVCTraversal(isAggTraversal, - staticGraph, executeDagGroup, maxTraversal, isTraversalAllWithRequest, parallelism)) - .start((PWindowStream) parameterizedRequest); - - } else if (constantStartIds.size() > 0) { // static request with constant ids. - responsePWindow = - ((PGraphTraversal) getStaticVCTraversal(isAggTraversal, - staticGraph, executeDagGroup, maxTraversal, false, parallelism)).start(new ArrayList<>(constantStartIds)); - } else { // traversal all - boolean enableTraversalAllSplit = queryContext.getGlobalConf() - .getBoolean(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE); - if (enableTraversalAllSplit) { - DynamicGraphVertexScanSourceFunction sourceFunction = - new DynamicGraphVertexScanSourceFunction<>(graphViewDesc); - PWindowSource source = queryContext.getEngineContext() - .createRuntimeTable(queryContext, sourceFunction) - .withParallelism(graphViewDesc.getShardNum()) - .withName(queryContext.createOperatorName("VertexScanSource")); - responsePWindow = - getStaticVCTraversal(isAggTraversal, - staticGraph, executeDagGroup, maxTraversal, false, parallelism) - .start((PWindowStream) source); - } else { - responsePWindow = - ((PGraphTraversal) getStaticVCTraversal(isAggTraversal, - staticGraph, executeDagGroup, maxTraversal, false, parallelism)).start(); - } - } - return responsePWindow; + return responsePWindow; + } + + private PWindowStream> dynamicGraphTraversal( + PIncGraphView dynamicGraph, + Set parameterStartIds, + Set constantStartIds, + ExecuteDagGroup executeDagGroup, + int maxTraversal, + boolean isAggTraversal, + int parallelism, + boolean enableIncrTraversal) { + if (queryContext.getRequestTable() != null) { // dynamic traversal with request + RuntimeTable requestTable = queryContext.getRequestTable(); + boolean isIdOnlyRequest = queryContext.isIdOnlyRequest(); + + PWindowStream requestWindowStream = requestTable.getPlan(); + PWindowStream> parameterizedRequest; + boolean isTraversalAllWithRequest; + if (parameterStartIds.size() == 1) { // request table attach the start id. + parameterizedRequest = + requestWindowStream.map( + new RowToParameterRequestFunction( + parameterStartIds.iterator().next(), isIdOnlyRequest)); + isTraversalAllWithRequest = false; + } else { + parameterizedRequest = + requestWindowStream + .map(new RowToParameterRequestFunction(null, isIdOnlyRequest)) + .broadcast(); + isTraversalAllWithRequest = true; + } + return ((PGraphTraversal) + getDynamicVCTraversal( + isAggTraversal, + dynamicGraph, + executeDagGroup, + maxTraversal, + isTraversalAllWithRequest, + parallelism, + enableIncrTraversal)) + .start((PWindowStream) parameterizedRequest); + } else if (constantStartIds.size() > 0) { // request with constant ids. + return ((PGraphTraversal) + getDynamicVCTraversal( + isAggTraversal, + dynamicGraph, + executeDagGroup, + maxTraversal, + false, + parallelism, + enableIncrTraversal)) + .start(new ArrayList<>(constantStartIds)); + } else { // dynamic traversal all + boolean enableTraversalAllSplit = + queryContext.getGlobalConf().getBoolean(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE); + if (enableTraversalAllSplit) { + DynamicGraphVertexScanSourceFunction sourceFunction = + new DynamicGraphVertexScanSourceFunction<>(graphViewDesc); + PWindowSource source = + queryContext + .getEngineContext() + .createRuntimeTable(queryContext, sourceFunction) + .withParallelism(graphViewDesc.getShardNum()) + .withName(queryContext.createOperatorName("VertexScanSource")); + return getDynamicVCTraversal( + isAggTraversal, + dynamicGraph, + executeDagGroup, + maxTraversal, + false, + parallelism, + enableIncrTraversal) + .start((PWindowStream) source); + } + return ((PGraphTraversal) + getDynamicVCTraversal( + isAggTraversal, + dynamicGraph, + executeDagGroup, + maxTraversal, + false, + parallelism, + enableIncrTraversal)) + .start(); } - - private PWindowStream> dynamicGraphTraversal( - PIncGraphView dynamicGraph, Set parameterStartIds, - Set constantStartIds, ExecuteDagGroup executeDagGroup, int maxTraversal, boolean isAggTraversal, - int parallelism, boolean enableIncrTraversal) { - if (queryContext.getRequestTable() != null) { // dynamic traversal with request - RuntimeTable requestTable = queryContext.getRequestTable(); - boolean isIdOnlyRequest = queryContext.isIdOnlyRequest(); - - PWindowStream requestWindowStream = requestTable.getPlan(); - PWindowStream> parameterizedRequest; - boolean isTraversalAllWithRequest; - if (parameterStartIds.size() == 1) { // request table attach the start id. - parameterizedRequest = requestWindowStream.map( - new RowToParameterRequestFunction(parameterStartIds.iterator().next(), isIdOnlyRequest)); - isTraversalAllWithRequest = false; - } else { - parameterizedRequest = requestWindowStream.map( - new RowToParameterRequestFunction(null, isIdOnlyRequest)).broadcast(); - isTraversalAllWithRequest = true; - } - return ((PGraphTraversal) getDynamicVCTraversal(isAggTraversal, dynamicGraph, - executeDagGroup, maxTraversal, isTraversalAllWithRequest, parallelism, enableIncrTraversal)).start( - (PWindowStream) parameterizedRequest); - } else if (constantStartIds.size() > 0) { // request with constant ids. - return ((PGraphTraversal) getDynamicVCTraversal(isAggTraversal, dynamicGraph, - executeDagGroup, maxTraversal, false, parallelism, enableIncrTraversal)).start( - new ArrayList<>(constantStartIds)); - } else { // dynamic traversal all - boolean enableTraversalAllSplit = queryContext.getGlobalConf() - .getBoolean(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE); - if (enableTraversalAllSplit) { - DynamicGraphVertexScanSourceFunction sourceFunction = new DynamicGraphVertexScanSourceFunction<>( - graphViewDesc); - PWindowSource source = queryContext.getEngineContext() - .createRuntimeTable(queryContext, sourceFunction) - .withParallelism(graphViewDesc.getShardNum()) - .withName(queryContext.createOperatorName("VertexScanSource")); - return getDynamicVCTraversal(isAggTraversal, dynamicGraph, executeDagGroup, maxTraversal, false, - parallelism, enableIncrTraversal).start((PWindowStream) source); - } - return ((PGraphTraversal) getDynamicVCTraversal(isAggTraversal, dynamicGraph, - executeDagGroup, maxTraversal, false, parallelism, enableIncrTraversal)).start(); - - } + } + + private PGraphTraversal getStaticVCTraversal( + boolean isAggTraversal, + PGraphWindow staticGraph, + ExecuteDagGroup executeDagGroup, + int maxTraversal, + boolean isTraversalAllWithRequest, + int parallelism) { + if (isAggTraversal) { + return staticGraph.traversal( + new GeaFlowStaticVCAggTraversal( + executeDagGroup, maxTraversal, isTraversalAllWithRequest, parallelism)); + } else { + return staticGraph.traversal( + new GeaFlowStaticVCTraversal(executeDagGroup, maxTraversal, isTraversalAllWithRequest)); } - - private PGraphTraversal getStaticVCTraversal(boolean isAggTraversal, - PGraphWindow staticGraph, - ExecuteDagGroup executeDagGroup, int maxTraversal, - boolean isTraversalAllWithRequest, int parallelism) { - if (isAggTraversal) { - return staticGraph.traversal( - new GeaFlowStaticVCAggTraversal(executeDagGroup, maxTraversal, isTraversalAllWithRequest, parallelism)); - } else { - return staticGraph.traversal( - new GeaFlowStaticVCTraversal(executeDagGroup, maxTraversal, isTraversalAllWithRequest)); - } + } + + private PGraphTraversal getDynamicVCTraversal( + boolean isAggTraversal, + PIncGraphView dynamicGraph, + ExecuteDagGroup executeDagGroup, + int maxTraversal, + boolean isTraversalAllWithRequest, + int parallelism, + boolean enableIncrTraversal) { + if (isAggTraversal) { + return dynamicGraph.incrementalTraversal( + new GeaFlowDynamicVCAggTraversal( + executeDagGroup, maxTraversal, isTraversalAllWithRequest, parallelism)); + } else { + return dynamicGraph.incrementalTraversal( + new GeaFlowDynamicVCTraversal( + executeDagGroup, maxTraversal, isTraversalAllWithRequest, enableIncrTraversal)); } - - private PGraphTraversal getDynamicVCTraversal(boolean isAggTraversal, - PIncGraphView dynamicGraph, - ExecuteDagGroup executeDagGroup, int maxTraversal, - boolean isTraversalAllWithRequest, int parallelism, - boolean enableIncrTraversal) { - if (isAggTraversal) { - return dynamicGraph.incrementalTraversal( - new GeaFlowDynamicVCAggTraversal(executeDagGroup, maxTraversal, isTraversalAllWithRequest, parallelism)); - } else { - return dynamicGraph.incrementalTraversal( - new GeaFlowDynamicVCTraversal(executeDagGroup, maxTraversal, isTraversalAllWithRequest, - enableIncrTraversal)); - } + } + + @Override + public RuntimeTable runAlgorithm(GraphAlgorithm graphAlgorithm) { + Class algorithmUserFunctionClass = + graphAlgorithm.getUserFunctionClass(); + AlgorithmUserFunction algorithm; + try { + algorithm = algorithmUserFunctionClass.getConstructor().newInstance(); + } catch (Exception e) { + throw new GeaFlowDSLException( + "Cannot new instance for class: " + algorithmUserFunctionClass.getName(), e); } - - @Override - public RuntimeTable runAlgorithm(GraphAlgorithm graphAlgorithm) { - Class algorithmUserFunctionClass = graphAlgorithm.getUserFunctionClass(); - AlgorithmUserFunction algorithm; - try { - algorithm = algorithmUserFunctionClass.getConstructor().newInstance(); - } catch (Exception e) { - throw new GeaFlowDSLException("Cannot new instance for class: " + algorithmUserFunctionClass.getName(), e); - } - int maxTraversal = context.getConfig().getInteger(DSLConfigKeys.GEAFLOW_DSL_MAX_TRAVERSAL); - int parallelism = (queryContext.getTraversalParallelism() > 0 - && queryContext.getTraversalParallelism() <= graph.getShardCount()) - ? queryContext.getTraversalParallelism() : graph.getShardCount(); - - PWindowStream vertexStream = queryContext.getGraphVertexStream(graph.getName()); - PWindowStream edgeStream = queryContext.getGraphEdgeStream(graph.getName()); - PWindowStream> responsePWindow; - assert graphView instanceof PIncGraphView : "Illegal graph view"; - queryContext.addMaterializedGraph(graph.getName()); - - if (vertexStream == null && edgeStream == null) { // traversal on snapshot of the dynamic graph - PGraphWindow staticGraph = graphView.snapshot(graphViewDesc.getCurrentVersion()); - boolean enableAlgorithmSplit = algorithm instanceof IncrementalAlgorithmUserFunction; - if (enableAlgorithmSplit) { - DynamicGraphVertexScanSourceFunction sourceFunction = - new DynamicGraphVertexScanSourceFunction<>(graphViewDesc); - PWindowSource source = queryContext.getEngineContext() - .createRuntimeTable(queryContext, sourceFunction) - .withParallelism(graphViewDesc.getShardNum()) - .withName(queryContext.createOperatorName("VertexScanSource")); - responsePWindow = staticGraph.traversal(new GeaFlowAlgorithmAggTraversal( - algorithm, maxTraversal, graphAlgorithm.getParams(), graphSchema, parallelism)) - .start((PWindowStream) source); - } else { - responsePWindow = staticGraph.traversal( - new GeaFlowAlgorithmAggTraversal(algorithm, maxTraversal, - graphAlgorithm.getParams(), graphSchema, parallelism)).start(); - } - } else { // traversal on dynamic graph - Preconditions.checkArgument(graphViewDesc.getBackend() != BackendType.Paimon, - "paimon does not support dynamic graph traversal"); - - vertexStream = vertexStream != null ? vertexStream : - queryContext.getEngineContext().createRuntimeTable(queryContext, Collections.emptyList()) - .getPlan(); - edgeStream = edgeStream != null ? edgeStream : - queryContext.getEngineContext().createRuntimeTable(queryContext, Collections.emptyList()) - .getPlan(); - - PIncGraphView dynamicGraph = graphView.appendGraph((PWindowStream) vertexStream, - (PWindowStream) edgeStream); - boolean enableAlgorithmSplit = algorithm instanceof IncrementalAlgorithmUserFunction; - if (enableAlgorithmSplit) { - PWindowStream evolvedRequest = - vertexStream.map(new VertexToParameterRequestFunction()).union( - edgeStream.flatMap(new EdgeToParameterRequestFunction())).broadcast(); - responsePWindow = dynamicGraph.incrementalTraversal( - new GeaFlowAlgorithmDynamicAggTraversal(algorithm, maxTraversal, - graphAlgorithm.getParams(), graphSchema, parallelism)).start(evolvedRequest); - } else { - responsePWindow = dynamicGraph.incrementalTraversal( - new GeaFlowAlgorithmDynamicAggTraversal(algorithm, maxTraversal, - graphAlgorithm.getParams(), graphSchema, parallelism)).start(); - } - } - responsePWindow = responsePWindow.withParallelism(parallelism); - PWindowStream resultPWindow = responsePWindow.flatMap( - (FlatMapFunction, Row>) (value, collector) -> collector.partition( - value.getResponse())); - return new GeaFlowRuntimeTable(queryContext, context, resultPWindow); + int maxTraversal = context.getConfig().getInteger(DSLConfigKeys.GEAFLOW_DSL_MAX_TRAVERSAL); + int parallelism = + (queryContext.getTraversalParallelism() > 0 + && queryContext.getTraversalParallelism() <= graph.getShardCount()) + ? queryContext.getTraversalParallelism() + : graph.getShardCount(); + + PWindowStream vertexStream = queryContext.getGraphVertexStream(graph.getName()); + PWindowStream edgeStream = queryContext.getGraphEdgeStream(graph.getName()); + PWindowStream> responsePWindow; + assert graphView instanceof PIncGraphView : "Illegal graph view"; + queryContext.addMaterializedGraph(graph.getName()); + + if (vertexStream == null && edgeStream == null) { // traversal on snapshot of the dynamic graph + PGraphWindow staticGraph = + graphView.snapshot(graphViewDesc.getCurrentVersion()); + boolean enableAlgorithmSplit = algorithm instanceof IncrementalAlgorithmUserFunction; + if (enableAlgorithmSplit) { + DynamicGraphVertexScanSourceFunction sourceFunction = + new DynamicGraphVertexScanSourceFunction<>(graphViewDesc); + PWindowSource source = + queryContext + .getEngineContext() + .createRuntimeTable(queryContext, sourceFunction) + .withParallelism(graphViewDesc.getShardNum()) + .withName(queryContext.createOperatorName("VertexScanSource")); + responsePWindow = + staticGraph + .traversal( + new GeaFlowAlgorithmAggTraversal( + algorithm, + maxTraversal, + graphAlgorithm.getParams(), + graphSchema, + parallelism)) + .start((PWindowStream) source); + } else { + responsePWindow = + staticGraph + .traversal( + new GeaFlowAlgorithmAggTraversal( + algorithm, + maxTraversal, + graphAlgorithm.getParams(), + graphSchema, + parallelism)) + .start(); + } + } else { // traversal on dynamic graph + Preconditions.checkArgument( + graphViewDesc.getBackend() != BackendType.Paimon, + "paimon does not support dynamic graph traversal"); + + vertexStream = + vertexStream != null + ? vertexStream + : queryContext + .getEngineContext() + .createRuntimeTable(queryContext, Collections.emptyList()) + .getPlan(); + edgeStream = + edgeStream != null + ? edgeStream + : queryContext + .getEngineContext() + .createRuntimeTable(queryContext, Collections.emptyList()) + .getPlan(); + + PIncGraphView dynamicGraph = + graphView.appendGraph((PWindowStream) vertexStream, (PWindowStream) edgeStream); + boolean enableAlgorithmSplit = algorithm instanceof IncrementalAlgorithmUserFunction; + if (enableAlgorithmSplit) { + PWindowStream evolvedRequest = + vertexStream + .map(new VertexToParameterRequestFunction()) + .union(edgeStream.flatMap(new EdgeToParameterRequestFunction())) + .broadcast(); + responsePWindow = + dynamicGraph + .incrementalTraversal( + new GeaFlowAlgorithmDynamicAggTraversal( + algorithm, + maxTraversal, + graphAlgorithm.getParams(), + graphSchema, + parallelism)) + .start(evolvedRequest); + } else { + responsePWindow = + dynamicGraph + .incrementalTraversal( + new GeaFlowAlgorithmDynamicAggTraversal( + algorithm, + maxTraversal, + graphAlgorithm.getParams(), + graphSchema, + parallelism)) + .start(); + } } + responsePWindow = responsePWindow.withParallelism(parallelism); + PWindowStream resultPWindow = + responsePWindow.flatMap( + (FlatMapFunction, Row>) + (value, collector) -> collector.partition(value.getResponse())); + return new GeaFlowRuntimeTable(queryContext, context, resultPWindow); + } + + private static class ResponseToRowFunction + implements FlatMapFunction, Row> { - private static class ResponseToRowFunction implements FlatMapFunction, Row> { - - @Override - public void flatMap(ITraversalResponse value, Collector collector) { - ITreePath treePath = value.getResponse(); - boolean isParametrizedTreePath = treePath instanceof ParameterizedTreePath; - List paths = treePath.toList(); - for (Path path : paths) { - Row resultRow = path; - // If traversal with parameter request, we carry the parameter and requestId to the - // sql function. So that the sql follow the match statement can refer the request parameter. - if (isParametrizedTreePath) { - ParameterizedTreePath parameterizedTreePath = (ParameterizedTreePath) treePath; - Object requestId = parameterizedTreePath.getRequestId(); - Row parameter = parameterizedTreePath.getParameter(); - resultRow = new DefaultParameterizedRow(path, requestId, parameter); - } - collector.partition(resultRow); - } + @Override + public void flatMap(ITraversalResponse value, Collector collector) { + ITreePath treePath = value.getResponse(); + boolean isParametrizedTreePath = treePath instanceof ParameterizedTreePath; + List paths = treePath.toList(); + for (Path path : paths) { + Row resultRow = path; + // If traversal with parameter request, we carry the parameter and requestId to the + // sql function. So that the sql follow the match statement can refer the request parameter. + if (isParametrizedTreePath) { + ParameterizedTreePath parameterizedTreePath = (ParameterizedTreePath) treePath; + Object requestId = parameterizedTreePath.getRequestId(); + Row parameter = parameterizedTreePath.getParameter(); + resultRow = new DefaultParameterizedRow(path, requestId, parameter); } + collector.partition(resultRow); + } } + } - private static class RowToParameterRequestFunction extends RichFunction - implements MapFunction> { - - private final ParameterStartId startId; + private static class RowToParameterRequestFunction extends RichFunction + implements MapFunction> { - private final boolean isIdOnlyRequest; + private final ParameterStartId startId; - private int numTasks; + private final boolean isIdOnlyRequest; - private int taskIndex; + private int numTasks; - private long rowCounter = 0; + private int taskIndex; - public RowToParameterRequestFunction(ParameterStartId startId, boolean isIdOnlyRequest) { - this.startId = startId; - this.isIdOnlyRequest = isIdOnlyRequest; - } - - @Override - public void open(RuntimeContext runtimeContext) { - this.numTasks = runtimeContext.getTaskArgs().getParallelism(); - this.taskIndex = runtimeContext.getTaskArgs().getTaskIndex(); - } + private long rowCounter = 0; - @Override - public ITraversalRequest map(Row row) { - long requestId = IDUtil.uniqueId(numTasks, taskIndex, rowCounter); - if (requestId < 0) { - throw new GeaFlowDSLException("Request id exceed the Long.MAX, numTasks: " - + numTasks + ", taskIndex: " + taskIndex + ", rowCounter: " + rowCounter); - } - rowCounter++; - Object vertexId; - if (startId != null) { - vertexId = startId.getIdExpression().evaluate(row); - } else { - vertexId = TraversalAll.INSTANCE; - } - if (isIdOnlyRequest) { - return new IdOnlyRequest(vertexId); - } - return new InitParameterRequest(requestId, vertexId, row); - } + public RowToParameterRequestFunction(ParameterStartId startId, boolean isIdOnlyRequest) { + this.startId = startId; + this.isIdOnlyRequest = isIdOnlyRequest; + } - @Override - public void close() { + @Override + public void open(RuntimeContext runtimeContext) { + this.numTasks = runtimeContext.getTaskArgs().getParallelism(); + this.taskIndex = runtimeContext.getTaskArgs().getTaskIndex(); + } - } + @Override + public ITraversalRequest map(Row row) { + long requestId = IDUtil.uniqueId(numTasks, taskIndex, rowCounter); + if (requestId < 0) { + throw new GeaFlowDSLException( + "Request id exceed the Long.MAX, numTasks: " + + numTasks + + ", taskIndex: " + + taskIndex + + ", rowCounter: " + + rowCounter); + } + rowCounter++; + Object vertexId; + if (startId != null) { + vertexId = startId.getIdExpression().evaluate(row); + } else { + vertexId = TraversalAll.INSTANCE; + } + if (isIdOnlyRequest) { + return new IdOnlyRequest(vertexId); + } + return new InitParameterRequest(requestId, vertexId, row); } - private static class VertexToParameterRequestFunction extends RichFunction - implements MapFunction> { + @Override + public void close() {} + } - @Override - public void open(RuntimeContext runtimeContext) { - } + private static class VertexToParameterRequestFunction extends RichFunction + implements MapFunction> { - @Override - public ITraversalRequest map(RowVertex vertex) { - return new IdOnlyRequest(vertex.getId()); - } + @Override + public void open(RuntimeContext runtimeContext) {} - @Override - public void close() { - } + @Override + public ITraversalRequest map(RowVertex vertex) { + return new IdOnlyRequest(vertex.getId()); } - private static class EdgeToParameterRequestFunction extends RichFunction - implements FlatMapFunction> { - - @Override - public void open(RuntimeContext runtimeContext) { - } + @Override + public void close() {} + } - @Override - public void flatMap(RowEdge edge, Collector> collector) { - collector.partition(new IdOnlyRequest(edge.getSrcId())); - collector.partition(new IdOnlyRequest(edge.getTargetId())); - } + private static class EdgeToParameterRequestFunction extends RichFunction + implements FlatMapFunction> { - @Override - public void close() { - } + @Override + public void open(RuntimeContext runtimeContext) {} + @Override + public void flatMap(RowEdge edge, Collector> collector) { + collector.partition(new IdOnlyRequest(edge.getSrcId())); + collector.partition(new IdOnlyRequest(edge.getTargetId())); } + + @Override + public void close() {} + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowRuntimeTable.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowRuntimeTable.java index 8f62a189f..74a7d6ed6 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowRuntimeTable.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowRuntimeTable.java @@ -19,7 +19,6 @@ package org.apache.geaflow.dsl.runtime.engine; -import com.google.common.collect.Lists; import java.io.Serializable; import java.util.ArrayList; import java.util.HashMap; @@ -28,6 +27,7 @@ import java.util.Map.Entry; import java.util.Objects; import java.util.stream.Collectors; + import org.apache.geaflow.api.collector.Collector; import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.RichFunction; @@ -89,702 +89,716 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.collect.Lists; + public class GeaFlowRuntimeTable implements RuntimeTable { - private final QueryContext queryContext; + private final QueryContext queryContext; - private final IPipelineJobContext context; + private final IPipelineJobContext context; - private final PWindowStream pStream; + private final PWindowStream pStream; - public GeaFlowRuntimeTable(QueryContext queryContext, IPipelineJobContext context, - PWindowStream pStream) { - this.queryContext = Objects.requireNonNull(queryContext); - this.context = Objects.requireNonNull(context); - this.pStream = Objects.requireNonNull(pStream); - } + public GeaFlowRuntimeTable( + QueryContext queryContext, IPipelineJobContext context, PWindowStream pStream) { + this.queryContext = Objects.requireNonNull(queryContext); + this.context = Objects.requireNonNull(context); + this.pStream = Objects.requireNonNull(pStream); + } - public GeaFlowRuntimeTable copyWithSetOptions(PWindowStream pStream) { - pStream = pStream.withConfig(queryContext.getSetOptions()); - return new GeaFlowRuntimeTable(queryContext, context, pStream); - } + public GeaFlowRuntimeTable copyWithSetOptions(PWindowStream pStream) { + pStream = pStream.withConfig(queryContext.getSetOptions()); + return new GeaFlowRuntimeTable(queryContext, context, pStream); + } - @Override - public T getPlan() { - return (T) pStream; - } + @Override + public T getPlan() { + return (T) pStream; + } - @Override - public List take(IType type) { - if (JobMode.getJobMode(context.getConfig()).equals(JobMode.OLAP_SERVICE)) { - pStream.map(new BinaryRowToObjectMapFunction(type)).collect(); - } else { - pStream.collect(); - } - return new ArrayList<>(); + @Override + public List take(IType type) { + if (JobMode.getJobMode(context.getConfig()).equals(JobMode.OLAP_SERVICE)) { + pStream.map(new BinaryRowToObjectMapFunction(type)).collect(); + } else { + pStream.collect(); } + return new ArrayList<>(); + } - @Override - public RuntimeTable project(ProjectFunction function) { - String opName = PhysicRelNodeName.PROJECT.getName(queryContext.getOpNameCount()); - int parallelism = queryContext.getConfigParallelisms(opName, pStream.getParallelism()); + @Override + public RuntimeTable project(ProjectFunction function) { + String opName = PhysicRelNodeName.PROJECT.getName(queryContext.getOpNameCount()); + int parallelism = queryContext.getConfigParallelisms(opName, pStream.getParallelism()); - PWindowStream map = pStream.map(new TableProjectFunction(function)) - .withName(opName).withParallelism(parallelism); - return copyWithSetOptions(map); - } - - @Override - public RuntimeTable filter(WhereFunction function) { - String opName = PhysicRelNodeName.FILTER.getName(queryContext.getOpNameCount()); - int parallelism = queryContext.getConfigParallelisms(opName, pStream.getParallelism()); - PWindowStream filter = pStream.filter(new TableFilterFunction(function)) - .withName(opName).withParallelism(parallelism); - return copyWithSetOptions(filter); - } - - @Override - public RuntimeTable join(RuntimeTable other, JoinTableFunction function) { - throw new GeaFlowDSLException("Join has not support yet"); - } - - @Override - public RuntimeTable aggregate(GroupByFunction groupByFunction, AggFunction aggFunction) { - String opName = PhysicRelNodeName.AGGREGATE.getName(queryContext.getOpNameCount()); - int parallelism = queryContext.getConfigParallelisms(opName, pStream.getParallelism()); - boolean isGlobalDistinct = aggFunction.getValueTypes().length == 0; - PWindowStream aggregate = - pStream.flatMap(isGlobalDistinct ? new TableLocalDistinctFunction(groupByFunction) + PWindowStream map = + pStream + .map(new TableProjectFunction(function)) + .withName(opName) + .withParallelism(parallelism); + return copyWithSetOptions(map); + } + + @Override + public RuntimeTable filter(WhereFunction function) { + String opName = PhysicRelNodeName.FILTER.getName(queryContext.getOpNameCount()); + int parallelism = queryContext.getConfigParallelisms(opName, pStream.getParallelism()); + PWindowStream filter = + pStream + .filter(new TableFilterFunction(function)) + .withName(opName) + .withParallelism(parallelism); + return copyWithSetOptions(filter); + } + + @Override + public RuntimeTable join(RuntimeTable other, JoinTableFunction function) { + throw new GeaFlowDSLException("Join has not support yet"); + } + + @Override + public RuntimeTable aggregate(GroupByFunction groupByFunction, AggFunction aggFunction) { + String opName = PhysicRelNodeName.AGGREGATE.getName(queryContext.getOpNameCount()); + int parallelism = queryContext.getConfigParallelisms(opName, pStream.getParallelism()); + boolean isGlobalDistinct = aggFunction.getValueTypes().length == 0; + PWindowStream aggregate = + pStream + .flatMap( + isGlobalDistinct + ? new TableLocalDistinctFunction(groupByFunction) : new TableLocalAggregateFunction(groupByFunction, aggFunction)) - .withName(opName + "-local") - .withParallelism(pStream.getParallelism()) - .keyBy(new GroupKeySelectorFunction(groupByFunction)) - .withName(opName + "-KeyBy") - .withParallelism(pStream.getParallelism()) - .materialize() - .aggregate(isGlobalDistinct ? new TableGlobalDistinctFunction(groupByFunction) + .withName(opName + "-local") + .withParallelism(pStream.getParallelism()) + .keyBy(new GroupKeySelectorFunction(groupByFunction)) + .withName(opName + "-KeyBy") + .withParallelism(pStream.getParallelism()) + .materialize() + .aggregate( + isGlobalDistinct + ? new TableGlobalDistinctFunction(groupByFunction) : new TableGlobalAggregateFunction(groupByFunction, aggFunction)) - .withName(opName + "-global") - .withParallelism(parallelism); - return copyWithSetOptions(aggregate); - } - - @Override - public RuntimeTable union(RuntimeTable other) { - String opName = PhysicRelNodeName.UNION.getName(queryContext.getOpNameCount()); - int parallelism = queryContext.getConfigParallelisms(opName, pStream.getParallelism()); - PWindowStream union = pStream.union(other.getPlan()) - .withName(opName).withParallelism(parallelism); - return copyWithSetOptions(union); - } - - @Override - public RuntimeTable orderBy(OrderByFunction function) { - String opName = PhysicRelNodeName.SORT.getName(queryContext.getOpNameCount()); - PWindowStream order = pStream.flatMap(new TableOrderByFunction(function)) - .withName(opName + "-local").withParallelism(pStream.getParallelism()) + .withName(opName + "-global") + .withParallelism(parallelism); + return copyWithSetOptions(aggregate); + } + + @Override + public RuntimeTable union(RuntimeTable other) { + String opName = PhysicRelNodeName.UNION.getName(queryContext.getOpNameCount()); + int parallelism = queryContext.getConfigParallelisms(opName, pStream.getParallelism()); + PWindowStream union = + pStream.union(other.getPlan()).withName(opName).withParallelism(parallelism); + return copyWithSetOptions(union); + } + + @Override + public RuntimeTable orderBy(OrderByFunction function) { + String opName = PhysicRelNodeName.SORT.getName(queryContext.getOpNameCount()); + PWindowStream order = + pStream + .flatMap(new TableOrderByFunction(function)) + .withName(opName + "-local") + .withParallelism(pStream.getParallelism()) .flatMap(new TableOrderByFunction(function)) .withName(opName + "-global") .withParallelism(1); - return copyWithSetOptions(order); - } - - @Override - public RuntimeTable correlate(CorrelateFunction function) { - String opName = PhysicRelNodeName.CORRELATE.getName(queryContext.getOpNameCount()); - int parallelism = queryContext.getConfigParallelisms(opName, pStream.getParallelism()); - PWindowStream correlate = pStream.flatMap(new CorrelateFlatMapFunction(function)) - .withName(opName).withParallelism(parallelism); - return copyWithSetOptions(correlate); + return copyWithSetOptions(order); + } + + @Override + public RuntimeTable correlate(CorrelateFunction function) { + String opName = PhysicRelNodeName.CORRELATE.getName(queryContext.getOpNameCount()); + int parallelism = queryContext.getConfigParallelisms(opName, pStream.getParallelism()); + PWindowStream correlate = + pStream + .flatMap(new CorrelateFlatMapFunction(function)) + .withName(opName) + .withParallelism(parallelism); + return copyWithSetOptions(correlate); + } + + @Override + public SinkDataView write(GeaFlowTable table) { + TableConnector connector = ConnectorFactory.loadConnector(table.getTableType()); + if (!(connector instanceof TableWritableConnector)) { + throw new GeaFlowDSLException("Table: '{}' is not writeable", connector.getType()); } - - @Override - public SinkDataView write(GeaFlowTable table) { - TableConnector connector = ConnectorFactory.loadConnector(table.getTableType()); - if (!(connector instanceof TableWritableConnector)) { - throw new GeaFlowDSLException("Table: '{}' is not writeable", connector.getType()); - } - TableWritableConnector writableConnector = (TableWritableConnector) connector; - Configuration conf = table.getConfigWithGlobal(context.getConfig(), queryContext.getSetOptions()); - TableSink tableSink = writableConnector.createSink(conf); - tableSink.init(conf, table.getTableSchema()); - - String opName = PhysicRelNodeName.TABLE_SINK.getName(queryContext.getOpNameCount()); - int parallelism = queryContext.getConfigParallelisms(opName, pStream.getParallelism()); - - GeaFlowTableSinkFunction sinkFunction; - if (conf.contains(DSLConfigKeys.GEAFLOW_DSL_CUSTOM_SINK_FUNCTION)) { - String customClassName = conf.getString(DSLConfigKeys.GEAFLOW_DSL_CUSTOM_SINK_FUNCTION); - try { - sinkFunction = (GeaFlowTableSinkFunction) ClassUtil.classForName(customClassName) + TableWritableConnector writableConnector = (TableWritableConnector) connector; + Configuration conf = + table.getConfigWithGlobal(context.getConfig(), queryContext.getSetOptions()); + TableSink tableSink = writableConnector.createSink(conf); + tableSink.init(conf, table.getTableSchema()); + + String opName = PhysicRelNodeName.TABLE_SINK.getName(queryContext.getOpNameCount()); + int parallelism = queryContext.getConfigParallelisms(opName, pStream.getParallelism()); + + GeaFlowTableSinkFunction sinkFunction; + if (conf.contains(DSLConfigKeys.GEAFLOW_DSL_CUSTOM_SINK_FUNCTION)) { + String customClassName = conf.getString(DSLConfigKeys.GEAFLOW_DSL_CUSTOM_SINK_FUNCTION); + try { + sinkFunction = + (GeaFlowTableSinkFunction) + ClassUtil.classForName(customClassName) .getConstructor(GeaFlowTable.class, TableSink.class) .newInstance(table, tableSink); - } catch (Exception e) { - throw new GeaFlowDSLException("Cannot create sink function: {}.", - customClassName, e); - } - } else { - sinkFunction = new GeaFlowTableSinkFunction(table, tableSink); - } - PWindowStream inputStream = pStream; - if (table.getPrimaryFields().size() > 0) { - int[] primaryKeyIndices = ArrayUtil.toIntArray(table.getPrimaryFields() - .stream().map(name -> table.getTableSchema().indexOf(name)) - .collect(Collectors.toList())); - IType[] primaryKeyTypes = table.getPrimaryFields() - .stream().map(name -> table.getTableSchema().getField(name).getType()) - .collect(Collectors.toList()).toArray(new IType[]{}); - - inputStream = pStream.keyBy(new GroupKeySelectorFunction( - new GroupByFunctionImpl(primaryKeyIndices, primaryKeyTypes))); - } - PStreamSink sink = inputStream.sink(sinkFunction) + } catch (Exception e) { + throw new GeaFlowDSLException("Cannot create sink function: {}.", customClassName, e); + } + } else { + sinkFunction = new GeaFlowTableSinkFunction(table, tableSink); + } + PWindowStream inputStream = pStream; + if (table.getPrimaryFields().size() > 0) { + int[] primaryKeyIndices = + ArrayUtil.toIntArray( + table.getPrimaryFields().stream() + .map(name -> table.getTableSchema().indexOf(name)) + .collect(Collectors.toList())); + IType[] primaryKeyTypes = + table.getPrimaryFields().stream() + .map(name -> table.getTableSchema().getField(name).getType()) + .collect(Collectors.toList()) + .toArray(new IType[] {}); + + inputStream = + pStream.keyBy( + new GroupKeySelectorFunction( + new GroupByFunctionImpl(primaryKeyIndices, primaryKeyTypes))); + } + PStreamSink sink = + inputStream + .sink(sinkFunction) .withConfig(queryContext.getSetOptions()) .withName(opName) .withParallelism(parallelism); - return new GeaFlowSinkDataView(context, sink); - } + return new GeaFlowSinkDataView(context, sink); + } - @Override - public SinkDataView write(GeaFlowGraph graph, QueryContext queryContext) { - PWindowStream vertexStream = pStream.flatMap(new RowToVertexFunction(graph)); - PWindowStream edgeStream = pStream.flatMap(new RowToEdgeFunction(graph)); + @Override + public SinkDataView write(GeaFlowGraph graph, QueryContext queryContext) { + PWindowStream vertexStream = pStream.flatMap(new RowToVertexFunction(graph)); + PWindowStream edgeStream = pStream.flatMap(new RowToEdgeFunction(graph)); - PWindowStream preVertexStream = queryContext.getGraphVertexStream(graph.getName()); - if (preVertexStream != null) { - vertexStream = vertexStream.union(preVertexStream); - } - PWindowStream preEdgeStream = queryContext.getGraphEdgeStream(graph.getName()); - if (preEdgeStream != null) { - edgeStream = edgeStream.union(preEdgeStream); - } - queryContext.updateVertexAndEdgeToGraph(graph.getName(), graph, vertexStream, edgeStream); - return new GeaFlowSinkIncGraphView(context); + PWindowStream preVertexStream = queryContext.getGraphVertexStream(graph.getName()); + if (preVertexStream != null) { + vertexStream = vertexStream.union(preVertexStream); } + PWindowStream preEdgeStream = queryContext.getGraphEdgeStream(graph.getName()); + if (preEdgeStream != null) { + edgeStream = edgeStream.union(preEdgeStream); + } + queryContext.updateVertexAndEdgeToGraph(graph.getName(), graph, vertexStream, edgeStream); + return new GeaFlowSinkIncGraphView(context); + } + private static class TableProjectFunction implements MapFunction, Serializable { - private static class TableProjectFunction implements MapFunction, Serializable { - - private final ProjectFunction projectFunction; + private final ProjectFunction projectFunction; - public TableProjectFunction(ProjectFunction projectFunction) { - this.projectFunction = projectFunction; - } + public TableProjectFunction(ProjectFunction projectFunction) { + this.projectFunction = projectFunction; + } - @Override - public Row map(Row value) { - return projectFunction.project(value); - } + @Override + public Row map(Row value) { + return projectFunction.project(value); } + } - private static class BinaryRowToObjectMapFunction implements MapFunction, Serializable { + private static class BinaryRowToObjectMapFunction implements MapFunction, Serializable { - private final RowDecoder rowDecoder; + private final RowDecoder rowDecoder; - public BinaryRowToObjectMapFunction(IType schema) { - this.rowDecoder = new DefaultRowDecoder((StructType) schema); - } + public BinaryRowToObjectMapFunction(IType schema) { + this.rowDecoder = new DefaultRowDecoder((StructType) schema); + } - @Override - public Row map(Row row) { - return rowDecoder.decode(row); - } + @Override + public Row map(Row row) { + return rowDecoder.decode(row); } + } - private static class TableFilterFunction implements FilterFunction { + private static class TableFilterFunction implements FilterFunction { - private static final Logger LOGGER = LoggerFactory.getLogger(TableFilterFunction.class); + private static final Logger LOGGER = LoggerFactory.getLogger(TableFilterFunction.class); - private final WhereFunction whereFunction; + private final WhereFunction whereFunction; - public TableFilterFunction(WhereFunction whereFunction) { - this.whereFunction = whereFunction; - } - - @Override - public boolean filter(Row record) { - return whereFunction.filter(record); - } + public TableFilterFunction(WhereFunction whereFunction) { + this.whereFunction = whereFunction; } - private static class TableLocalDistinctFunction extends RichWindowFunction implements - FlatMapFunction { - - private final GroupByFunction groupByFunction; - private final Map aggregatingState; - private final IBinaryEncoder encoder; - - public TableLocalDistinctFunction(GroupByFunction groupByFunction) { - this.groupByFunction = groupByFunction; - this.aggregatingState = new HashMap<>(); - IType[] fieldTypes = groupByFunction.getFieldTypes(); - TableField[] tableFields = new TableField[fieldTypes.length]; - for (int i = 0; i < fieldTypes.length; i++) { - tableFields[i] = new TableField(String.valueOf(i), fieldTypes[i], false); - } - this.encoder = EncoderFactory.createEncoder(new StructType(tableFields)); - } - - @Override - public void open(RuntimeContext runtimeContext) { - } + @Override + public boolean filter(Row record) { + return whereFunction.filter(record); + } + } + + private static class TableLocalDistinctFunction extends RichWindowFunction + implements FlatMapFunction { + + private final GroupByFunction groupByFunction; + private final Map aggregatingState; + private final IBinaryEncoder encoder; + + public TableLocalDistinctFunction(GroupByFunction groupByFunction) { + this.groupByFunction = groupByFunction; + this.aggregatingState = new HashMap<>(); + IType[] fieldTypes = groupByFunction.getFieldTypes(); + TableField[] tableFields = new TableField[fieldTypes.length]; + for (int i = 0; i < fieldTypes.length; i++) { + tableFields[i] = new TableField(String.valueOf(i), fieldTypes[i], false); + } + this.encoder = EncoderFactory.createEncoder(new StructType(tableFields)); + } - @Override - public void close() { - } + @Override + public void open(RuntimeContext runtimeContext) {} - @Override - public void flatMap(Row value, Collector collector) { - //local distinct - RowKey groupKey = groupByFunction.getRowKey(value); - Row acc = aggregatingState.get(groupKey); - if (acc == null) { - assert collector != null : "collector is null"; - IType[] keyTypes = groupByFunction.getFieldTypes(); - Object[] fields = new Object[keyTypes.length]; - for (int i = 0; i < keyTypes.length; i++) { - fields[i] = groupKey.getField(i, keyTypes[i]); - } - collector.partition(encoder.encode(ObjectRow.create(fields))); - } - aggregatingState.put(groupKey, value); - } + @Override + public void close() {} - @Override - public void finish() { - aggregatingState.clear(); - } + @Override + public void flatMap(Row value, Collector collector) { + // local distinct + RowKey groupKey = groupByFunction.getRowKey(value); + Row acc = aggregatingState.get(groupKey); + if (acc == null) { + assert collector != null : "collector is null"; + IType[] keyTypes = groupByFunction.getFieldTypes(); + Object[] fields = new Object[keyTypes.length]; + for (int i = 0; i < keyTypes.length; i++) { + fields[i] = groupKey.getField(i, keyTypes[i]); + } + collector.partition(encoder.encode(ObjectRow.create(fields))); + } + aggregatingState.put(groupKey, value); } - private static class TableLocalAggregateFunction extends RichWindowFunction implements - FlatMapFunction { - - private final AggFunction localAggFunction; - private final GroupByFunction groupByFunction; - private Collector collector; - private final Map aggregatingState; - private final IBinaryEncoder encoder; - - public TableLocalAggregateFunction(GroupByFunction groupByFunction, AggFunction localAggFunction) { - this.localAggFunction = localAggFunction; - this.groupByFunction = groupByFunction; - this.aggregatingState = new HashMap<>(); - IType[] fieldTypes = groupByFunction.getFieldTypes(); - TableField[] tableFields = new TableField[fieldTypes.length + 1]; - for (int i = 0; i < fieldTypes.length; i++) { - tableFields[i] = new TableField(String.valueOf(i), fieldTypes[i], false); - } - tableFields[fieldTypes.length] = new TableField(String.valueOf(fieldTypes.length) - , ObjectType.INSTANCE, false); - this.encoder = EncoderFactory.createEncoder(new StructType(tableFields)); - } + @Override + public void finish() { + aggregatingState.clear(); + } + } + + private static class TableLocalAggregateFunction extends RichWindowFunction + implements FlatMapFunction { + + private final AggFunction localAggFunction; + private final GroupByFunction groupByFunction; + private Collector collector; + private final Map aggregatingState; + private final IBinaryEncoder encoder; + + public TableLocalAggregateFunction( + GroupByFunction groupByFunction, AggFunction localAggFunction) { + this.localAggFunction = localAggFunction; + this.groupByFunction = groupByFunction; + this.aggregatingState = new HashMap<>(); + IType[] fieldTypes = groupByFunction.getFieldTypes(); + TableField[] tableFields = new TableField[fieldTypes.length + 1]; + for (int i = 0; i < fieldTypes.length; i++) { + tableFields[i] = new TableField(String.valueOf(i), fieldTypes[i], false); + } + tableFields[fieldTypes.length] = + new TableField(String.valueOf(fieldTypes.length), ObjectType.INSTANCE, false); + this.encoder = EncoderFactory.createEncoder(new StructType(tableFields)); + } - @Override - public void open(RuntimeContext runtimeContext) { - FunctionContext context = - FunctionContext.of(runtimeContext.getConfiguration()); - localAggFunction.open(context); - } + @Override + public void open(RuntimeContext runtimeContext) { + FunctionContext context = FunctionContext.of(runtimeContext.getConfiguration()); + localAggFunction.open(context); + } - @Override - public void close() { - } + @Override + public void close() {} - @Override - public void flatMap(Row value, Collector collector) { - this.collector = collector; - //local aggregate - RowKey groupKey = groupByFunction.getRowKey(value); - Object acc = aggregatingState.get(groupKey); - if (acc == null) { - acc = localAggFunction.createAccumulator(); - } - localAggFunction.add(value, acc); - aggregatingState.put(groupKey, acc); - } + @Override + public void flatMap(Row value, Collector collector) { + this.collector = collector; + // local aggregate + RowKey groupKey = groupByFunction.getRowKey(value); + Object acc = aggregatingState.get(groupKey); + if (acc == null) { + acc = localAggFunction.createAccumulator(); + } + localAggFunction.add(value, acc); + aggregatingState.put(groupKey, acc); + } - @Override - public void finish() { - for (Entry rowKeyObjectEntry : aggregatingState.entrySet()) { - assert collector != null : "collector is null"; - IType[] keyTypes = groupByFunction.getFieldTypes(); - //The last offset of ObjectRow is accumulator - Object[] fields = new Object[keyTypes.length + 1]; - for (int i = 0; i < keyTypes.length; i++) { - fields[i] = rowKeyObjectEntry.getKey().getField(i, keyTypes[i]); - } - fields[keyTypes.length] = rowKeyObjectEntry.getValue(); - collector.partition(encoder.encode(ObjectRow.create(fields))); - } - aggregatingState.clear(); - } + @Override + public void finish() { + for (Entry rowKeyObjectEntry : aggregatingState.entrySet()) { + assert collector != null : "collector is null"; + IType[] keyTypes = groupByFunction.getFieldTypes(); + // The last offset of ObjectRow is accumulator + Object[] fields = new Object[keyTypes.length + 1]; + for (int i = 0; i < keyTypes.length; i++) { + fields[i] = rowKeyObjectEntry.getKey().getField(i, keyTypes[i]); + } + fields[keyTypes.length] = rowKeyObjectEntry.getValue(); + collector.partition(encoder.encode(ObjectRow.create(fields))); + } + aggregatingState.clear(); } + } - private static class TableGlobalDistinctFunction extends RichFunction implements - AggregateFunction { + private static class TableGlobalDistinctFunction extends RichFunction + implements AggregateFunction { - private final GroupByFunction groupByFunction; + private final GroupByFunction groupByFunction; - public TableGlobalDistinctFunction(GroupByFunction groupByFunction) { - this.groupByFunction = groupByFunction; - } + public TableGlobalDistinctFunction(GroupByFunction groupByFunction) { + this.groupByFunction = groupByFunction; + } - @Override - public void open(RuntimeContext runtimeContext) { - } + @Override + public void open(RuntimeContext runtimeContext) {} - @Override - public void close() { - } + @Override + public void close() {} - @Override - public Object createAccumulator() { - return new DistinctAccumulator(null); - } + @Override + public Object createAccumulator() { + return new DistinctAccumulator(null); + } - @Override - public void add(Row value, Object keyAccumulator) { - IType[] keyTypes = groupByFunction.getFieldTypes(); - Object[] fields = new Object[keyTypes.length]; - for (int i = 0; i < keyTypes.length; i++) { - fields[i] = value.getField(i, keyTypes[i]); - } - RowKey key = ObjectRowKey.of(fields); - DistinctAccumulator keyAcc = (DistinctAccumulator) keyAccumulator; - if (keyAcc.getKey() == null) { - keyAcc.setKey(key); - } - } + @Override + public void add(Row value, Object keyAccumulator) { + IType[] keyTypes = groupByFunction.getFieldTypes(); + Object[] fields = new Object[keyTypes.length]; + for (int i = 0; i < keyTypes.length; i++) { + fields[i] = value.getField(i, keyTypes[i]); + } + RowKey key = ObjectRowKey.of(fields); + DistinctAccumulator keyAcc = (DistinctAccumulator) keyAccumulator; + if (keyAcc.getKey() == null) { + keyAcc.setKey(key); + } + } - @Override - public Row getResult(Object keyAccumulator) { - RowKey key = ((DistinctAccumulator) keyAccumulator).getResult(); - if (key == null) { - return null; - } - IType[] keyTypes = groupByFunction.getFieldTypes(); - Object[] fields = new Object[keyTypes.length]; - for (int i = 0; i < keyTypes.length; i++) { - fields[i] = key.getField(i, keyTypes[i]); - } - return ObjectRow.create(fields); - } + @Override + public Row getResult(Object keyAccumulator) { + RowKey key = ((DistinctAccumulator) keyAccumulator).getResult(); + if (key == null) { + return null; + } + IType[] keyTypes = groupByFunction.getFieldTypes(); + Object[] fields = new Object[keyTypes.length]; + for (int i = 0; i < keyTypes.length; i++) { + fields[i] = key.getField(i, keyTypes[i]); + } + return ObjectRow.create(fields); + } - @Override - public Object merge(Object a, Object b) { - assert Objects.equals(((DistinctAccumulator) a).getKey(), - ((DistinctAccumulator) b).getKey()); - return a; - } + @Override + public Object merge(Object a, Object b) { + assert Objects.equals(((DistinctAccumulator) a).getKey(), ((DistinctAccumulator) b).getKey()); + return a; + } - private static class DistinctAccumulator implements Serializable { + private static class DistinctAccumulator implements Serializable { - private RowKey key; + private RowKey key; - private boolean hasBeenRead = false; + private boolean hasBeenRead = false; - public DistinctAccumulator(RowKey key) { - this.key = key; - } + public DistinctAccumulator(RowKey key) { + this.key = key; + } - public RowKey getKey() { - return key; - } + public RowKey getKey() { + return key; + } - public void setKey(RowKey key) { - this.key = key; - } + public void setKey(RowKey key) { + this.key = key; + } - public RowKey getResult() { - if (hasBeenRead) { - return null; - } else { - hasBeenRead = true; - return key; - } - } + public RowKey getResult() { + if (hasBeenRead) { + return null; + } else { + hasBeenRead = true; + return key; } + } } + } - private static class TableGlobalAggregateFunction extends RichFunction implements - AggregateFunction { + private static class TableGlobalAggregateFunction extends RichFunction + implements AggregateFunction { - private final AggFunction aggFunction; - private final GroupByFunction groupByFunction; - - public TableGlobalAggregateFunction(GroupByFunction groupByFunction, AggFunction aggFunction) { - this.aggFunction = aggFunction; - this.groupByFunction = groupByFunction; - } + private final AggFunction aggFunction; + private final GroupByFunction groupByFunction; - @Override - public void open(RuntimeContext runtimeContext) { - FunctionContext context = - FunctionContext.of(runtimeContext.getConfiguration()); - aggFunction.open(context); - } - - @Override - public void close() { - } - - @Override - public Object createAccumulator() { - return new KeyAccumulator(null, aggFunction.createAccumulator()); - } + public TableGlobalAggregateFunction(GroupByFunction groupByFunction, AggFunction aggFunction) { + this.aggFunction = aggFunction; + this.groupByFunction = groupByFunction; + } - @Override - public void add(Row value, Object keyAccumulator) { - IType[] keyTypes = groupByFunction.getFieldTypes(); - Object[] fields = new Object[keyTypes.length]; - for (int i = 0; i < keyTypes.length; i++) { - fields[i] = value.getField(i, keyTypes[i]); - } - RowKey key = ObjectRowKey.of(fields); - KeyAccumulator keyAcc = (KeyAccumulator) keyAccumulator; - if (keyAcc.getKey() == null) { - keyAcc.setKey(key); - } - if (aggFunction.getValueTypes().length > 0) { - int offset = keyTypes.length; - aggFunction.merge(keyAcc.getAcc(), value.getField(offset, ObjectType.INSTANCE)); - } - } + @Override + public void open(RuntimeContext runtimeContext) { + FunctionContext context = FunctionContext.of(runtimeContext.getConfiguration()); + aggFunction.open(context); + } - @Override - public Row getResult(Object keyAccumulator) { - KeyAccumulator keyAcc = (KeyAccumulator) keyAccumulator; - RowKey key = keyAcc.getKey(); - Object accumulator = keyAcc.getAcc(); - Row aggValue = aggFunction.getValue(accumulator); + @Override + public void close() {} - IType[] keyTypes = groupByFunction.getFieldTypes(); - IType[] valueTypes = aggFunction.getValueTypes(); + @Override + public Object createAccumulator() { + return new KeyAccumulator(null, aggFunction.createAccumulator()); + } - Object[] fields = new Object[keyTypes.length + valueTypes.length]; - for (int i = 0; i < keyTypes.length; i++) { - fields[i] = key.getField(i, keyTypes[i]); - } + @Override + public void add(Row value, Object keyAccumulator) { + IType[] keyTypes = groupByFunction.getFieldTypes(); + Object[] fields = new Object[keyTypes.length]; + for (int i = 0; i < keyTypes.length; i++) { + fields[i] = value.getField(i, keyTypes[i]); + } + RowKey key = ObjectRowKey.of(fields); + KeyAccumulator keyAcc = (KeyAccumulator) keyAccumulator; + if (keyAcc.getKey() == null) { + keyAcc.setKey(key); + } + if (aggFunction.getValueTypes().length > 0) { + int offset = keyTypes.length; + aggFunction.merge(keyAcc.getAcc(), value.getField(offset, ObjectType.INSTANCE)); + } + } - int offset = keyTypes.length; - for (int i = 0; i < valueTypes.length; i++) { - fields[offset + i] = aggValue.getField(i, valueTypes[i]); - } - return ObjectRow.create(fields); - } + @Override + public Row getResult(Object keyAccumulator) { + KeyAccumulator keyAcc = (KeyAccumulator) keyAccumulator; + RowKey key = keyAcc.getKey(); + Object accumulator = keyAcc.getAcc(); + Row aggValue = aggFunction.getValue(accumulator); + + IType[] keyTypes = groupByFunction.getFieldTypes(); + IType[] valueTypes = aggFunction.getValueTypes(); + + Object[] fields = new Object[keyTypes.length + valueTypes.length]; + for (int i = 0; i < keyTypes.length; i++) { + fields[i] = key.getField(i, keyTypes[i]); + } + + int offset = keyTypes.length; + for (int i = 0; i < valueTypes.length; i++) { + fields[offset + i] = aggValue.getField(i, valueTypes[i]); + } + return ObjectRow.create(fields); + } - @Override - public Object merge(Object a, Object b) { - aggFunction.merge(((KeyAccumulator) a).getAcc(), ((KeyAccumulator) b).getAcc()); - return a; - } + @Override + public Object merge(Object a, Object b) { + aggFunction.merge(((KeyAccumulator) a).getAcc(), ((KeyAccumulator) b).getAcc()); + return a; + } - private static class KeyAccumulator implements Serializable { + private static class KeyAccumulator implements Serializable { - private RowKey key; + private RowKey key; - private final Object accumulator; + private final Object accumulator; - public KeyAccumulator(RowKey key, Object accumulator) { - this.key = key; - this.accumulator = accumulator; - } + public KeyAccumulator(RowKey key, Object accumulator) { + this.key = key; + this.accumulator = accumulator; + } - public RowKey getKey() { - return key; - } + public RowKey getKey() { + return key; + } - public void setKey(RowKey key) { - this.key = key; - } + public void setKey(RowKey key) { + this.key = key; + } - public Object getAcc() { - return accumulator; - } - } + public Object getAcc() { + return accumulator; + } } + } - private static class GroupKeySelectorFunction implements KeySelector { - - GroupByFunction groupByFunction; + private static class GroupKeySelectorFunction implements KeySelector { - public GroupKeySelectorFunction(GroupByFunction groupByFunction) { - this.groupByFunction = groupByFunction; - } + GroupByFunction groupByFunction; - @Override - public RowKey getKey(Row value) { - IType[] keyTypes = groupByFunction.getFieldTypes(); - Object[] fields = new Object[keyTypes.length]; - for (int i = 0; i < keyTypes.length; i++) { - fields[i] = value.getField(i, keyTypes[i]); - } - return ObjectRowKey.of(fields); - } + public GroupKeySelectorFunction(GroupByFunction groupByFunction) { + this.groupByFunction = groupByFunction; } - private static class TableOrderByFunction extends RichWindowFunction implements - FlatMapFunction { + @Override + public RowKey getKey(Row value) { + IType[] keyTypes = groupByFunction.getFieldTypes(); + Object[] fields = new Object[keyTypes.length]; + for (int i = 0; i < keyTypes.length; i++) { + fields[i] = value.getField(i, keyTypes[i]); + } + return ObjectRowKey.of(fields); + } + } - private final OrderByFunction orderByFunction; - private Collector collector; + private static class TableOrderByFunction extends RichWindowFunction + implements FlatMapFunction { - public TableOrderByFunction(OrderByFunction orderByFunction) { - this.orderByFunction = orderByFunction; - } + private final OrderByFunction orderByFunction; + private Collector collector; - @Override - public void open(RuntimeContext runtimeContext) { - FunctionContext context = - FunctionContext.of(runtimeContext.getConfiguration()); - orderByFunction.open(context); - } + public TableOrderByFunction(OrderByFunction orderByFunction) { + this.orderByFunction = orderByFunction; + } - @Override - public void flatMap(Row value, Collector collector) { - this.orderByFunction.process(value); - this.collector = collector; - } + @Override + public void open(RuntimeContext runtimeContext) { + FunctionContext context = FunctionContext.of(runtimeContext.getConfiguration()); + orderByFunction.open(context); + } - @Override - public void finish() { - Iterable resultRows = orderByFunction.finish(); - for (Row row : resultRows) { - assert collector != null : "Not empty sort encounters collector which is null"; - collector.partition(row); - } - } + @Override + public void flatMap(Row value, Collector collector) { + this.orderByFunction.process(value); + this.collector = collector; + } - @Override - public void close() { - } + @Override + public void finish() { + Iterable resultRows = orderByFunction.finish(); + for (Row row : resultRows) { + assert collector != null : "Not empty sort encounters collector which is null"; + collector.partition(row); + } } + @Override + public void close() {} + } - private static class CorrelateFlatMapFunction extends RichFunction implements FlatMapFunction { + private static class CorrelateFlatMapFunction extends RichFunction + implements FlatMapFunction { - private final CorrelateFunction correlateFunction; + private final CorrelateFunction correlateFunction; - public CorrelateFlatMapFunction(CorrelateFunction correlateFunction) { - this.correlateFunction = correlateFunction; - } + public CorrelateFlatMapFunction(CorrelateFunction correlateFunction) { + this.correlateFunction = correlateFunction; + } - @Override - public void open(RuntimeContext runtimeContext) { - FunctionContext context = - FunctionContext.of(runtimeContext.getConfiguration()); - correlateFunction.open(context); - } + @Override + public void open(RuntimeContext runtimeContext) { + FunctionContext context = FunctionContext.of(runtimeContext.getConfiguration()); + correlateFunction.open(context); + } - @Override - public void close() { - } + @Override + public void close() {} - @Override - public void flatMap(Row value, Collector collector) { - if (value != null) { - List table = correlateFunction.process(value); - if (table != null) { - List rows = joinTable(value, table); - for (Row row : rows) { - if (row != null) { - collector.partition(row); - } - } - } + @Override + public void flatMap(Row value, Collector collector) { + if (value != null) { + List table = correlateFunction.process(value); + if (table != null) { + List rows = joinTable(value, table); + for (Row row : rows) { + if (row != null) { + collector.partition(row); } + } } + } + } - private List joinTable(Row value, List table) { - List rows = Lists.newArrayList(); - for (Row line : table) { - Object[] values = new Object[correlateFunction.getLeftOutputTypes().size() + private List joinTable(Row value, List table) { + List rows = Lists.newArrayList(); + for (Row line : table) { + Object[] values = + new Object + [correlateFunction.getLeftOutputTypes().size() + correlateFunction.getRightOutputTypes().size()]; - int idx = 0; - for (int i = 0; i < correlateFunction.getLeftOutputTypes().size(); i++) { - values[idx++] = value.getField(i, correlateFunction.getLeftOutputTypes().get(i)); - } - for (int i = 0; i < correlateFunction.getRightOutputTypes().size(); i++) { - values[idx++] = line.getField(i, correlateFunction.getRightOutputTypes().get(i)); - } - Row row = ObjectRow.create(values); - rows.add(row); - } - return rows; + int idx = 0; + for (int i = 0; i < correlateFunction.getLeftOutputTypes().size(); i++) { + values[idx++] = value.getField(i, correlateFunction.getLeftOutputTypes().get(i)); } + for (int i = 0; i < correlateFunction.getRightOutputTypes().size(); i++) { + values[idx++] = line.getField(i, correlateFunction.getRightOutputTypes().get(i)); + } + Row row = ObjectRow.create(values); + rows.add(row); + } + return rows; } + } - /** - * Convert {@link Row} to {@link RowVertex} for writing graph. - */ - public static class RowToVertexFunction implements FlatMapFunction { + /** Convert {@link Row} to {@link RowVertex} for writing graph. */ + public static class RowToVertexFunction implements FlatMapFunction { - private final int numVertex; + private final int numVertex; - private final IType[] vertexTypes; + private final IType[] vertexTypes; - private final VertexEncoder[] vertexEncoders; + private final VertexEncoder[] vertexEncoders; - public RowToVertexFunction(GeaFlowGraph graph) { - this.numVertex = graph.getVertexTables().size(); - this.vertexTypes = new IType[numVertex]; - this.vertexEncoders = new VertexEncoder[numVertex]; + public RowToVertexFunction(GeaFlowGraph graph) { + this.numVertex = graph.getVertexTables().size(); + this.vertexTypes = new IType[numVertex]; + this.vertexEncoders = new VertexEncoder[numVertex]; - GQLJavaTypeFactory typeFactory = GQLJavaTypeFactory.create(); - for (int i = 0; i < numVertex; i++) { - vertexTypes[i] = SqlTypeUtil.convertType( - graph.getVertexTables().get(i).getRowType(typeFactory)); - vertexEncoders[i] = EncoderFactory.createVertexEncoder((VertexType) vertexTypes[i]); - } - } + GQLJavaTypeFactory typeFactory = GQLJavaTypeFactory.create(); + for (int i = 0; i < numVertex; i++) { + vertexTypes[i] = + SqlTypeUtil.convertType(graph.getVertexTables().get(i).getRowType(typeFactory)); + vertexEncoders[i] = EncoderFactory.createVertexEncoder((VertexType) vertexTypes[i]); + } + } - @Override - public void flatMap(Row value, Collector collector) { - for (int i = 0; i < numVertex; i++) { - RowVertex vertex = (RowVertex) value.getField(i, vertexTypes[i]); - if (vertex != null) { - collector.partition(vertexEncoders[i].encode(vertex)); - } - } + @Override + public void flatMap(Row value, Collector collector) { + for (int i = 0; i < numVertex; i++) { + RowVertex vertex = (RowVertex) value.getField(i, vertexTypes[i]); + if (vertex != null) { + collector.partition(vertexEncoders[i].encode(vertex)); } + } } + } - /** - * Convert {@link Row} to {@link RowEdge} for writing graph. - */ - public static class RowToEdgeFunction implements FlatMapFunction { + /** Convert {@link Row} to {@link RowEdge} for writing graph. */ + public static class RowToEdgeFunction implements FlatMapFunction { - private final int numVertex; + private final int numVertex; - private final int numEdge; + private final int numEdge; - private final IType[] edgeTypes; + private final IType[] edgeTypes; - private final EdgeEncoder[] edgeEncoders; + private final EdgeEncoder[] edgeEncoders; - public RowToEdgeFunction(GeaFlowGraph graph) { - this.numVertex = graph.getVertexTables().size(); - this.numEdge = graph.getEdgeTables().size(); - this.edgeTypes = new IType[numEdge]; - this.edgeEncoders = new EdgeEncoder[numEdge]; - GQLJavaTypeFactory typeFactory = GQLJavaTypeFactory.create(); - for (int i = 0; i < numEdge; i++) { - edgeTypes[i] = SqlTypeUtil.convertType( - graph.getEdgeTables().get(i).getRowType(typeFactory)); - edgeEncoders[i] = EncoderFactory.createEdgeEncoder((EdgeType) edgeTypes[i]); - } - } + public RowToEdgeFunction(GeaFlowGraph graph) { + this.numVertex = graph.getVertexTables().size(); + this.numEdge = graph.getEdgeTables().size(); + this.edgeTypes = new IType[numEdge]; + this.edgeEncoders = new EdgeEncoder[numEdge]; + GQLJavaTypeFactory typeFactory = GQLJavaTypeFactory.create(); + for (int i = 0; i < numEdge; i++) { + edgeTypes[i] = + SqlTypeUtil.convertType(graph.getEdgeTables().get(i).getRowType(typeFactory)); + edgeEncoders[i] = EncoderFactory.createEdgeEncoder((EdgeType) edgeTypes[i]); + } + } - @Override - public void flatMap(Row value, Collector collector) { - for (int i = numVertex; i < numVertex + numEdge; i++) { - RowEdge edge = (RowEdge) value.getField(i, edgeTypes[i - numVertex]); - if (edge != null) { - RowEdge encodeEdge = edgeEncoders[i - numVertex].encode(edge); - collector.partition(encodeEdge); - collector.partition(encodeEdge.identityReverse()); - } - } - } + @Override + public void flatMap(Row value, Collector collector) { + for (int i = numVertex; i < numVertex + numEdge; i++) { + RowEdge edge = (RowEdge) value.getField(i, edgeTypes[i - numVertex]); + if (edge != null) { + RowEdge encodeEdge = edgeEncoders[i - numVertex].encode(edge); + collector.partition(encodeEdge); + collector.partition(encodeEdge.identityReverse()); + } + } } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowSinkDataView.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowSinkDataView.java index 198444ca4..a4b749cfb 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowSinkDataView.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowSinkDataView.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.engine; import java.util.List; + import org.apache.geaflow.api.pdata.PStreamSink; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -29,22 +30,22 @@ public class GeaFlowSinkDataView implements SinkDataView { - private final IPipelineJobContext context; + private final IPipelineJobContext context; - private final PStreamSink sink; + private final PStreamSink sink; - public GeaFlowSinkDataView(IPipelineJobContext context, PStreamSink sink) { - this.context = context; - this.sink = sink; - } + public GeaFlowSinkDataView(IPipelineJobContext context, PStreamSink sink) { + this.context = context; + this.sink = sink; + } - @Override - public T getPlan() { - return (T) sink; - } + @Override + public T getPlan() { + return (T) sink; + } - @Override - public List take(IType type) { - throw new GeaFlowDSLException("Should not call take() on SinkDataView"); - } + @Override + public List take(IType type) { + throw new GeaFlowDSLException("Should not call take() on SinkDataView"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowSinkIncGraphView.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowSinkIncGraphView.java index 715b9a2ac..eca046cc3 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowSinkIncGraphView.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowSinkIncGraphView.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.engine; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; @@ -28,19 +29,19 @@ public class GeaFlowSinkIncGraphView implements SinkDataView { - private final IPipelineJobContext context; + private final IPipelineJobContext context; - public GeaFlowSinkIncGraphView(IPipelineJobContext context) { - this.context = context; - } + public GeaFlowSinkIncGraphView(IPipelineJobContext context) { + this.context = context; + } - @Override - public T getPlan() { - throw new GeaFlowDSLException("Should not call getPlan() on GeaFlowSinkIncGraphView"); - } + @Override + public T getPlan() { + throw new GeaFlowDSLException("Should not call getPlan() on GeaFlowSinkIncGraphView"); + } - @Override - public List take(IType type) { - throw new GeaFlowDSLException("Should not call take() on GeaFlowIncGraphView"); - } + @Override + public List take(IType type) { + throw new GeaFlowDSLException("Should not call take() on GeaFlowIncGraphView"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowStaticTraversalRuntimeContext.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowStaticTraversalRuntimeContext.java index 319ab5f37..cd3067e21 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowStaticTraversalRuntimeContext.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowStaticTraversalRuntimeContext.java @@ -32,53 +32,59 @@ public class GeaFlowStaticTraversalRuntimeContext extends AbstractTraversalRuntimeContext { - private static final Logger LOGGER = LoggerFactory.getLogger(GeaFlowStaticTraversalRuntimeContext.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(GeaFlowStaticTraversalRuntimeContext.class); - private final VertexCentricTraversalFuncContext traversalContext; + private final VertexCentricTraversalFuncContext + traversalContext; - public GeaFlowStaticTraversalRuntimeContext( - VertexCentricTraversalFuncContext traversalContext) { - super(traversalContext.vertex(), traversalContext.edges()); - this.traversalContext = traversalContext; - } + public GeaFlowStaticTraversalRuntimeContext( + VertexCentricTraversalFuncContext traversalContext) { + super(traversalContext.vertex(), traversalContext.edges()); + this.traversalContext = traversalContext; + } - @Override - public Configuration getConfig() { - return traversalContext.getRuntimeContext().getConfiguration(); - } + @Override + public Configuration getConfig() { + return traversalContext.getRuntimeContext().getConfiguration(); + } - @Override - public long getIterationId() { - return traversalContext.getIterationId(); - } + @Override + public long getIterationId() { + return traversalContext.getIterationId(); + } - @SuppressWarnings("unchecked") - @Override - protected void sendBroadcastMessage(Object vertexId, MessageBox messageBox) { - traversalContext.broadcast(new DefaultGraphMessage<>(vertexId, messageBox)); - } + @SuppressWarnings("unchecked") + @Override + protected void sendBroadcastMessage(Object vertexId, MessageBox messageBox) { + traversalContext.broadcast(new DefaultGraphMessage<>(vertexId, messageBox)); + } - @Override - protected void sendMessage(Object vertexId, MessageBox messageBox) { - traversalContext.sendMessage(vertexId, messageBox); - } + @Override + protected void sendMessage(Object vertexId, MessageBox messageBox) { + traversalContext.sendMessage(vertexId, messageBox); + } - @Override - public void takePath(ITreePath treePath) { - traversalContext.takeResponse(new TraversalResponse(treePath)); - } + @Override + public void takePath(ITreePath treePath) { + traversalContext.takeResponse(new TraversalResponse(treePath)); + } - @Override - public void sendCoordinator(String name, Object value) { - LOGGER.info("task: {} send to coordinator {}:{} isAggTraversal:{}", getTaskIndex(), name, - value, aggContext != null); - if (aggContext != null) { - aggContext.aggregate(new KVTraversalAgg(name, value)); - } + @Override + public void sendCoordinator(String name, Object value) { + LOGGER.info( + "task: {} send to coordinator {}:{} isAggTraversal:{}", + getTaskIndex(), + name, + value, + aggContext != null); + if (aggContext != null) { + aggContext.aggregate(new KVTraversalAgg(name, value)); } + } - @Override - public RuntimeContext getRuntimeContext() { - return traversalContext.getRuntimeContext(); - } + @Override + public RuntimeContext getRuntimeContext() { + return traversalContext.getRuntimeContext(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowStaticVCAggTraversal.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowStaticVCAggTraversal.java index 062906b1a..2651f8061 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowStaticVCAggTraversal.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowStaticVCAggTraversal.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.engine; import java.util.Objects; + import org.apache.geaflow.api.graph.function.vc.VertexCentricAggTraversalFunction; import org.apache.geaflow.api.graph.function.vc.VertexCentricAggregateFunction; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -31,52 +32,66 @@ import org.apache.geaflow.dsl.runtime.traversal.message.MessageBox; import org.apache.geaflow.dsl.runtime.traversal.path.ITreePath; -public class GeaFlowStaticVCAggTraversal extends VertexCentricAggTraversal { - private final ExecuteDagGroup executeDagGroup; +public class GeaFlowStaticVCAggTraversal + extends VertexCentricAggTraversal< + Object, + Row, + Row, + MessageBox, + ITreePath, + ITraversalAgg, + ITraversalAgg, + ITraversalAgg, + ITraversalAgg, + ITraversalAgg> { + private final ExecuteDagGroup executeDagGroup; - private final boolean isTraversalAllWithRequest; + private final boolean isTraversalAllWithRequest; - private final int parallelism; + private final int parallelism; - public GeaFlowStaticVCAggTraversal(ExecuteDagGroup executeDagGroup, - int maxTraversal, - boolean isTraversalAllWithRequest, - int parallelism) { - super(maxTraversal); - this.executeDagGroup = Objects.requireNonNull(executeDagGroup); - this.isTraversalAllWithRequest = isTraversalAllWithRequest; - assert parallelism > 0; - this.parallelism = parallelism; - } + public GeaFlowStaticVCAggTraversal( + ExecuteDagGroup executeDagGroup, + int maxTraversal, + boolean isTraversalAllWithRequest, + int parallelism) { + super(maxTraversal); + this.executeDagGroup = Objects.requireNonNull(executeDagGroup); + this.isTraversalAllWithRequest = isTraversalAllWithRequest; + assert parallelism > 0; + this.parallelism = parallelism; + } - @Override - public VertexCentricCombineFunction getCombineFunction() { - return new MessageBoxCombineFunction(); - } + @Override + public VertexCentricCombineFunction getCombineFunction() { + return new MessageBoxCombineFunction(); + } - @Override - public IEncoder getMessageEncoder() { - return null; - } + @Override + public IEncoder getMessageEncoder() { + return null; + } - @Override - public VertexCentricAggTraversalFunction getTraversalFunction() { - return new GeaFlowStaticVCAggTraversalFunction(executeDagGroup, isTraversalAllWithRequest); - } + @Override + public VertexCentricAggTraversalFunction< + Object, Row, Row, MessageBox, ITreePath, ITraversalAgg, ITraversalAgg> + getTraversalFunction() { + return new GeaFlowStaticVCAggTraversalFunction(executeDagGroup, isTraversalAllWithRequest); + } - @Override - public VertexCentricAggregateFunction getAggregateFunction() { - return (VertexCentricAggregateFunction) new GeaFlowKVTraversalAggregateFunction(parallelism); - } + @Override + public VertexCentricAggregateFunction< + ITraversalAgg, ITraversalAgg, ITraversalAgg, ITraversalAgg, ITraversalAgg> + getAggregateFunction() { + return (VertexCentricAggregateFunction) new GeaFlowKVTraversalAggregateFunction(parallelism); + } - private static class MessageBoxCombineFunction implements VertexCentricCombineFunction { + private static class MessageBoxCombineFunction + implements VertexCentricCombineFunction { - @Override - public MessageBox combine(MessageBox oldMessage, MessageBox newMessage) { - return newMessage.combine(oldMessage); - } + @Override + public MessageBox combine(MessageBox oldMessage, MessageBox newMessage) { + return newMessage.combine(oldMessage); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowStaticVCAggTraversalFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowStaticVCAggTraversalFunction.java index 62ec22361..cce2ee107 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowStaticVCAggTraversalFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowStaticVCAggTraversalFunction.java @@ -21,6 +21,7 @@ import java.util.Iterator; import java.util.Objects; + import org.apache.geaflow.api.function.iterator.RichIteratorFunction; import org.apache.geaflow.api.graph.function.vc.VertexCentricAggTraversalFunction; import org.apache.geaflow.dsl.common.data.Row; @@ -30,56 +31,57 @@ import org.apache.geaflow.dsl.runtime.traversal.path.ITreePath; import org.apache.geaflow.model.traversal.ITraversalRequest; -public class GeaFlowStaticVCAggTraversalFunction implements - VertexCentricAggTraversalFunction, RichIteratorFunction { - - private final GeaFlowCommonTraversalFunction commonFunction; - private GeaFlowStaticTraversalRuntimeContext traversalRuntimeContext; - - public GeaFlowStaticVCAggTraversalFunction(ExecuteDagGroup executeDagGroup, boolean isTraversalAllWithRequest) { - this.commonFunction = new GeaFlowCommonTraversalFunction(executeDagGroup, isTraversalAllWithRequest); - } +public class GeaFlowStaticVCAggTraversalFunction + implements VertexCentricAggTraversalFunction< + Object, Row, Row, MessageBox, ITreePath, ITraversalAgg, ITraversalAgg>, + RichIteratorFunction { - @Override - public void open( - VertexCentricTraversalFuncContext vertexCentricFuncContext) { - commonFunction.open(this.traversalRuntimeContext = - new GeaFlowStaticTraversalRuntimeContext(vertexCentricFuncContext)); - } + private final GeaFlowCommonTraversalFunction commonFunction; + private GeaFlowStaticTraversalRuntimeContext traversalRuntimeContext; - @Override - public void initIteration(long windowId) { + public GeaFlowStaticVCAggTraversalFunction( + ExecuteDagGroup executeDagGroup, boolean isTraversalAllWithRequest) { + this.commonFunction = + new GeaFlowCommonTraversalFunction(executeDagGroup, isTraversalAllWithRequest); + } - } + @Override + public void open( + VertexCentricTraversalFuncContext + vertexCentricFuncContext) { + commonFunction.open( + this.traversalRuntimeContext = + new GeaFlowStaticTraversalRuntimeContext(vertexCentricFuncContext)); + } - @Override - public void init(ITraversalRequest traversalRequest) { - commonFunction.init(traversalRequest); - } + @Override + public void initIteration(long windowId) {} - @Override - public void compute(Object vertexId, Iterator messageIterator) { - commonFunction.compute(vertexId, messageIterator); - } + @Override + public void init(ITraversalRequest traversalRequest) { + commonFunction.init(traversalRequest); + } - @Override - public void finishIteration(long windowId) { - commonFunction.finish(windowId); - } + @Override + public void compute(Object vertexId, Iterator messageIterator) { + commonFunction.compute(vertexId, messageIterator); + } - @Override - public void finish() { + @Override + public void finishIteration(long windowId) { + commonFunction.finish(windowId); + } - } + @Override + public void finish() {} - @Override - public void close() { - commonFunction.close(); - } + @Override + public void close() { + commonFunction.close(); + } - @Override - public void initContext(VertexCentricAggContext aggContext) { - this.traversalRuntimeContext.setAggContext(Objects.requireNonNull(aggContext)); - } + @Override + public void initContext(VertexCentricAggContext aggContext) { + this.traversalRuntimeContext.setAggContext(Objects.requireNonNull(aggContext)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowStaticVCTraversal.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowStaticVCTraversal.java index f3e6511f8..441ea14c7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowStaticVCTraversal.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowStaticVCTraversal.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.engine; import java.util.Objects; + import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; import org.apache.geaflow.api.graph.function.vc.VertexCentricTraversalFunction; import org.apache.geaflow.api.graph.traversal.VertexCentricTraversal; @@ -29,40 +30,42 @@ import org.apache.geaflow.dsl.runtime.traversal.message.MessageBox; import org.apache.geaflow.dsl.runtime.traversal.path.ITreePath; -public class GeaFlowStaticVCTraversal extends VertexCentricTraversal { +public class GeaFlowStaticVCTraversal + extends VertexCentricTraversal { - private final ExecuteDagGroup executeDagGroup; + private final ExecuteDagGroup executeDagGroup; - private final boolean isTraversalAllWithRequest; + private final boolean isTraversalAllWithRequest; - public GeaFlowStaticVCTraversal(ExecuteDagGroup executeDagGroup, - int maxTraversal, - boolean isTraversalAllWithRequest) { - super(maxTraversal); - this.executeDagGroup = Objects.requireNonNull(executeDagGroup); - this.isTraversalAllWithRequest = isTraversalAllWithRequest; - } + public GeaFlowStaticVCTraversal( + ExecuteDagGroup executeDagGroup, int maxTraversal, boolean isTraversalAllWithRequest) { + super(maxTraversal); + this.executeDagGroup = Objects.requireNonNull(executeDagGroup); + this.isTraversalAllWithRequest = isTraversalAllWithRequest; + } - @Override - public VertexCentricCombineFunction getCombineFunction() { - return new MessageBoxCombineFunction(); - } + @Override + public VertexCentricCombineFunction getCombineFunction() { + return new MessageBoxCombineFunction(); + } - @Override - public IEncoder getMessageEncoder() { - return null; - } + @Override + public IEncoder getMessageEncoder() { + return null; + } - @Override - public VertexCentricTraversalFunction getTraversalFunction() { - return new GeaFlowStaticVCTraversalFunction(executeDagGroup, isTraversalAllWithRequest); - } + @Override + public VertexCentricTraversalFunction + getTraversalFunction() { + return new GeaFlowStaticVCTraversalFunction(executeDagGroup, isTraversalAllWithRequest); + } - private static class MessageBoxCombineFunction implements VertexCentricCombineFunction { + private static class MessageBoxCombineFunction + implements VertexCentricCombineFunction { - @Override - public MessageBox combine(MessageBox oldMessage, MessageBox newMessage) { - return newMessage.combine(oldMessage); - } + @Override + public MessageBox combine(MessageBox oldMessage, MessageBox newMessage) { + return newMessage.combine(oldMessage); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowStaticVCTraversalFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowStaticVCTraversalFunction.java index 8951447e5..0884a60e1 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowStaticVCTraversalFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowStaticVCTraversalFunction.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.engine; import java.util.Iterator; + import org.apache.geaflow.api.function.iterator.RichIteratorFunction; import org.apache.geaflow.api.graph.function.vc.VertexCentricTraversalFunction; import org.apache.geaflow.dsl.common.data.Row; @@ -28,48 +29,48 @@ import org.apache.geaflow.dsl.runtime.traversal.path.ITreePath; import org.apache.geaflow.model.traversal.ITraversalRequest; -public class GeaFlowStaticVCTraversalFunction implements - VertexCentricTraversalFunction, RichIteratorFunction { - - private final GeaFlowCommonTraversalFunction commonFunction; - - public GeaFlowStaticVCTraversalFunction(ExecuteDagGroup executeDagGroup, boolean isTraversalAllWithRequest) { - this.commonFunction = new GeaFlowCommonTraversalFunction(executeDagGroup, isTraversalAllWithRequest); - } +public class GeaFlowStaticVCTraversalFunction + implements VertexCentricTraversalFunction, + RichIteratorFunction { - @Override - public void open( - VertexCentricTraversalFuncContext vertexCentricFuncContext) { - commonFunction.open(new GeaFlowStaticTraversalRuntimeContext(vertexCentricFuncContext)); - } + private final GeaFlowCommonTraversalFunction commonFunction; - @Override - public void initIteration(long windowId) { + public GeaFlowStaticVCTraversalFunction( + ExecuteDagGroup executeDagGroup, boolean isTraversalAllWithRequest) { + this.commonFunction = + new GeaFlowCommonTraversalFunction(executeDagGroup, isTraversalAllWithRequest); + } - } + @Override + public void open( + VertexCentricTraversalFuncContext + vertexCentricFuncContext) { + commonFunction.open(new GeaFlowStaticTraversalRuntimeContext(vertexCentricFuncContext)); + } - @Override - public void init(ITraversalRequest traversalRequest) { - commonFunction.init(traversalRequest); - } + @Override + public void initIteration(long windowId) {} - @Override - public void compute(Object vertexId, Iterator messageIterator) { - commonFunction.compute(vertexId, messageIterator); - } + @Override + public void init(ITraversalRequest traversalRequest) { + commonFunction.init(traversalRequest); + } - @Override - public void finishIteration(long windowId) { - commonFunction.finish(windowId); - } + @Override + public void compute(Object vertexId, Iterator messageIterator) { + commonFunction.compute(vertexId, messageIterator); + } - @Override - public void finish() { + @Override + public void finishIteration(long windowId) { + commonFunction.finish(windowId); + } - } + @Override + public void finish() {} - @Override - public void close() { - commonFunction.close(); - } + @Override + public void close() { + commonFunction.close(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/MessageBoxCombineFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/MessageBoxCombineFunction.java index 33fdba280..46dd8f374 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/MessageBoxCombineFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/MessageBoxCombineFunction.java @@ -24,8 +24,8 @@ public class MessageBoxCombineFunction implements VertexCentricCombineFunction { - @Override - public MessageBox combine(MessageBox oldMessage, MessageBox newMessage) { - return newMessage.combine(oldMessage); - } + @Override + public MessageBox combine(MessageBox oldMessage, MessageBox newMessage) { + return newMessage.combine(oldMessage); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/TraversalResponse.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/TraversalResponse.java index aec2b362c..0192a8d61 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/TraversalResponse.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/TraversalResponse.java @@ -25,24 +25,24 @@ public class TraversalResponse implements ITraversalResponse { - private final ITreePath treePath; - - public TraversalResponse(ITreePath treePath) { - this.treePath = treePath; - } - - @Override - public long getResponseId() { - return treePath.hashCode(); - } - - @Override - public ITreePath getResponse() { - return treePath; - } - - @Override - public ResponseType getType() { - return ResponseType.Vertex; - } + private final ITreePath treePath; + + public TraversalResponse(ITreePath treePath) { + this.treePath = treePath; + } + + @Override + public long getResponseId() { + return treePath.hashCode(); + } + + @Override + public ITreePath getResponse() { + return treePath; + } + + @Override + public ResponseType getType() { + return ResponseType.Vertex; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/AbstractExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/AbstractExpression.java index 7a20eedfe..7ddd6750d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/AbstractExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/AbstractExpression.java @@ -25,70 +25,71 @@ import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; + import org.apache.geaflow.dsl.runtime.expression.field.PathFieldExpression; import org.apache.geaflow.dsl.runtime.expression.logic.AndExpression; public abstract class AbstractExpression implements Expression { - @Override - public String toString() { - return showExpression(); - } + @Override + public String toString() { + return showExpression(); + } - @Override - public int hashCode() { - return toString().hashCode(); - } + @Override + public int hashCode() { + return toString().hashCode(); + } - @Override - public boolean equals(Object that) { - if (!(that instanceof Expression)) { - return false; - } - return toString().equals(that.toString()); + @Override + public boolean equals(Object that) { + if (!(that instanceof Expression)) { + return false; } + return toString().equals(that.toString()); + } - @Override - public List getRefPathFieldIndices() { - List pathFields = new ArrayList<>(); - getInputs().forEach(input -> pathFields.addAll(input.getRefPathFieldIndices())); - if (this instanceof PathFieldExpression) { - pathFields.add(((PathFieldExpression) this).getFieldIndex()); - } - return pathFields; + @Override + public List getRefPathFieldIndices() { + List pathFields = new ArrayList<>(); + getInputs().forEach(input -> pathFields.addAll(input.getRefPathFieldIndices())); + if (this instanceof PathFieldExpression) { + pathFields.add(((PathFieldExpression) this).getFieldIndex()); } + return pathFields; + } - @Override - public Expression replace(Function replaceFn) { - List newInputs = getInputs().stream() - .map(input -> input.replace(replaceFn)) - .collect(Collectors.toList()); - Expression replaceExpression = replaceFn.apply(this); - if (replaceExpression != this) { - return replaceExpression.copy(newInputs); - } - return this.copy(newInputs); + @Override + public Expression replace(Function replaceFn) { + List newInputs = + getInputs().stream().map(input -> input.replace(replaceFn)).collect(Collectors.toList()); + Expression replaceExpression = replaceFn.apply(this); + if (replaceExpression != this) { + return replaceExpression.copy(newInputs); } + return this.copy(newInputs); + } - @Override - public List collect(Predicate condition) { - List collects = getInputs().stream() + @Override + public List collect(Predicate condition) { + List collects = + getInputs().stream() .flatMap(input -> input.collect(condition).stream()) .collect(Collectors.toList()); - if (condition.test(this)) { - collects.add(this); - } - return collects; + if (condition.test(this)) { + collects.add(this); } + return collects; + } - @Override - public List splitByAnd() { - if (this instanceof AndExpression) { - AndExpression and = (AndExpression) this; - return and.getInputs().stream() - .flatMap(input -> input.splitByAnd().stream()) - .collect(Collectors.toList()); - } - return Collections.singletonList(this); + @Override + public List splitByAnd() { + if (this instanceof AndExpression) { + AndExpression and = (AndExpression) this; + return and.getInputs().stream() + .flatMap(input -> input.splitByAnd().stream()) + .collect(Collectors.toList()); } + return Collections.singletonList(this); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/AbstractNonLeafExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/AbstractNonLeafExpression.java index 3fb9c17cc..acb476cf6 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/AbstractNonLeafExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/AbstractNonLeafExpression.java @@ -22,40 +22,42 @@ import java.util.List; import java.util.Objects; import java.util.stream.Collectors; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.function.FunctionContext; public abstract class AbstractNonLeafExpression extends AbstractExpression { - protected final List inputs; + protected final List inputs; - protected final List> inputTypes; + protected final List> inputTypes; - protected final IType outputType; + protected final IType outputType; - public AbstractNonLeafExpression(List inputs, IType outputType) { - this.inputs = Objects.requireNonNull(inputs); - this.inputTypes = inputs.stream() + public AbstractNonLeafExpression(List inputs, IType outputType) { + this.inputs = Objects.requireNonNull(inputs); + this.inputTypes = + inputs.stream() .map(Expression::getOutputType) .map(IType::getTypeClass) .collect(Collectors.toList()); - this.outputType = outputType; - } + this.outputType = outputType; + } - @Override - public void open(FunctionContext context) { - for (Expression input : inputs) { - input.open(context); - } + @Override + public void open(FunctionContext context) { + for (Expression input : inputs) { + input.open(context); } + } - @Override - public List getInputs() { - return inputs; - } + @Override + public List getInputs() { + return inputs; + } - @Override - public IType getOutputType() { - return outputType; - } + @Override + public IType getOutputType() { + return outputType; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/AbstractReflectCallExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/AbstractReflectCallExpression.java index 0242379a0..5bc733d0f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/AbstractReflectCallExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/AbstractReflectCallExpression.java @@ -23,6 +23,7 @@ import java.lang.reflect.Modifier; import java.util.List; import java.util.Objects; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -31,48 +32,52 @@ public abstract class AbstractReflectCallExpression extends AbstractNonLeafExpression { - protected final Class implementClass; + protected final Class implementClass; - protected final String methodName; + protected final String methodName; - private transient Method method = null; + private transient Method method = null; - protected transient Object implementInstance = null; + protected transient Object implementInstance = null; - public AbstractReflectCallExpression(List inputs, IType outputType, - Class implementClass, String methodName) { - super(inputs, outputType); - this.implementClass = Objects.requireNonNull(implementClass); - this.methodName = Objects.requireNonNull(methodName); - } + public AbstractReflectCallExpression( + List inputs, IType outputType, Class implementClass, String methodName) { + super(inputs, outputType); + this.implementClass = Objects.requireNonNull(implementClass); + this.methodName = Objects.requireNonNull(methodName); + } - @Override - public Object evaluate(Row row) { - Object[] paramValues = new Object[inputs.size()]; - for (int i = 0; i < inputs.size(); i++) { - paramValues[i] = inputs.get(i).evaluate(row); - } - try { - initMethod(); - return FunctionCallUtils.callMethod(method, implementInstance, paramValues); - } catch (Exception e) { - String msg = "Error in call " + methodName + "," - + " params is (" + StringUtils.join(paramValues, ", ") + ")"; - throw new RuntimeException(msg, e); - } + @Override + public Object evaluate(Row row) { + Object[] paramValues = new Object[inputs.size()]; + for (int i = 0; i < inputs.size(); i++) { + paramValues[i] = inputs.get(i).evaluate(row); + } + try { + initMethod(); + return FunctionCallUtils.callMethod(method, implementInstance, paramValues); + } catch (Exception e) { + String msg = + "Error in call " + + methodName + + "," + + " params is (" + + StringUtils.join(paramValues, ", ") + + ")"; + throw new RuntimeException(msg, e); } + } - private void initMethod() { - try { - if (method == null) { - method = FunctionCallUtils.findMatchMethod(implementClass, - methodName, inputTypes); - if (!Modifier.isStatic(method.getModifiers())) { - implementInstance = implementClass.newInstance(); - } - } - } catch (Exception e) { - throw new GeaFlowDSLException(e); + private void initMethod() { + try { + if (method == null) { + method = FunctionCallUtils.findMatchMethod(implementClass, methodName, inputTypes); + if (!Modifier.isStatic(method.getModifiers())) { + implementInstance = implementClass.newInstance(); } + } + } catch (Exception e) { + throw new GeaFlowDSLException(e); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/BuildInExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/BuildInExpression.java index 4093366e6..589d3bb4e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/BuildInExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/BuildInExpression.java @@ -21,90 +21,92 @@ import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.common.type.IType; public class BuildInExpression extends AbstractReflectCallExpression { - public static final String FLOOR = "floor"; + public static final String FLOOR = "floor"; - public static final String TIMESTAMP_FLOOR = "timestampFloor"; + public static final String TIMESTAMP_FLOOR = "timestampFloor"; - public static final String CEIL = "ceil"; + public static final String CEIL = "ceil"; - public static final String TIMESTAMP_CEIL = "timestampCeil"; + public static final String TIMESTAMP_CEIL = "timestampCeil"; - public static final String TRIM = "trim"; + public static final String TRIM = "trim"; - public static final String SIMILAR = "similar"; + public static final String SIMILAR = "similar"; - public static final String CONCAT = "concat"; + public static final String CONCAT = "concat"; - public static final String LENGTH = "length"; + public static final String LENGTH = "length"; - public static final String UPPER = "upper"; + public static final String UPPER = "upper"; - public static final String LOWER = "lower"; + public static final String LOWER = "lower"; - public static final String POSITION = "position"; + public static final String POSITION = "position"; - public static final String OVERLAY = "overlay"; + public static final String OVERLAY = "overlay"; - public static final String SUBSTRING = "substring"; + public static final String SUBSTRING = "substring"; - public static final String INITCAP = "initcap"; + public static final String INITCAP = "initcap"; - public static final String ABS = "abs"; + public static final String ABS = "abs"; - public static final String POWER = "power"; + public static final String POWER = "power"; - public static final String SIN = "sin"; + public static final String SIN = "sin"; - public static final String COS = "cos"; + public static final String COS = "cos"; - public static final String LN = "ln"; + public static final String LN = "ln"; - public static final String LOG10 = "log10"; + public static final String LOG10 = "log10"; - public static final String EXP = "exp"; + public static final String EXP = "exp"; - public static final String TAN = "tan"; + public static final String TAN = "tan"; - public static final String COT = "cot"; + public static final String COT = "cot"; - public static final String ASIN = "asin"; + public static final String ASIN = "asin"; - public static final String ACOS = "acos"; + public static final String ACOS = "acos"; - public static final String ATAN = "atan"; + public static final String ATAN = "atan"; - public static final String DEGREES = "degrees"; + public static final String DEGREES = "degrees"; - public static final String RADIANS = "radians"; + public static final String RADIANS = "radians"; - public static final String SIGN = "sign"; + public static final String SIGN = "sign"; - public static final String RAND = "rand"; + public static final String RAND = "rand"; - public static final String RAND_INTEGER = "randInt"; + public static final String RAND_INTEGER = "randInt"; - public static final String CURRENT_TIMESTAMP = "currentTimestamp"; + public static final String CURRENT_TIMESTAMP = "currentTimestamp"; - public static final String SAME = "same"; + public static final String SAME = "same"; - public BuildInExpression(List inputs, IType outputType, - Class implementClass, String methodName) { - super(inputs, outputType, implementClass, methodName); - } + public BuildInExpression( + List inputs, IType outputType, Class implementClass, String methodName) { + super(inputs, outputType, implementClass, methodName); + } - @Override - public String showExpression() { - return methodName + "(" - + inputs.stream().map(Expression::showExpression).collect(Collectors.joining(",")) - + ")"; - } + @Override + public String showExpression() { + return methodName + + "(" + + inputs.stream().map(Expression::showExpression).collect(Collectors.joining(",")) + + ")"; + } - @Override - public Expression copy(List inputs) { - return new BuildInExpression(inputs, outputType, implementClass, methodName); - } + @Override + public Expression copy(List inputs) { + return new BuildInExpression(inputs, outputType, implementClass, methodName); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/DefaultExpressionBuilder.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/DefaultExpressionBuilder.java index d72eee5d4..06cb361d3 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/DefaultExpressionBuilder.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/DefaultExpressionBuilder.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.expression; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.function.UDF; import org.apache.geaflow.dsl.common.function.UDTF; @@ -61,178 +62,179 @@ public class DefaultExpressionBuilder implements ExpressionBuilder { - public DefaultExpressionBuilder() { - - } - - @Override - public Expression plus(Expression left, Expression right, IType outputType) { - return new PlusExpression(left, right, outputType); - } - - @Override - public Expression minus(Expression left, Expression right, IType outputType) { - return new MinusExpression(left, right, outputType); - } - - @Override - public Expression multiply(Expression left, Expression right, IType outputType) { - return new MultiplyExpression(left, right, outputType); - } - - @Override - public Expression divide(Expression left, Expression right, IType outputType) { - return new DivideExpression(left, right, outputType); - } - - @Override - public Expression mod(Expression left, Expression right, IType outputType) { - return new ModExpression(left, right, outputType); - } - - @Override - public Expression minusPrefix(Expression input, IType outputType) { - return new MinusPrefixExpression(input, outputType); - } - - @Override - public Expression cast(Expression input, IType outputType) { - return new CastExpression(input, outputType); - } - - @Override - public Expression isNull(Expression input) { - return new IsNullExpression(input); - } - - @Override - public Expression isNotNull(Expression input) { - return new IsNotNullExpression(input); - } - - @Override - public Expression caseWhen(List inputs, IType outputType) { - return new CaseExpression(inputs, outputType); - } - - @Override - public Expression ifExp(Expression condition, Expression trueValue, Expression falseValue, IType outputType) { - return new IfExpression(condition, trueValue, falseValue, outputType); - } - - @Override - public Expression field(Expression input, int fieldIndex, IType outputType) { - return new FieldExpression(input, fieldIndex, outputType); - } - - @Override - public Expression pathField(String label, int fieldIndex, IType outputType) { - return new PathFieldExpression(label, fieldIndex, outputType); - } - - @Override - public Expression parameterField(int fieldIndex, IType outputType) { - return new ParameterFieldExpression(fieldIndex, outputType); - } - - @Override - public Expression item(Expression input, Expression index) { - return new ItemExpression(input, index); - } - - @Override - public Expression literal(Object value, IType outputType) { - return new LiteralExpression(value, outputType); - } - - @Override - public Expression pi() { - return new PIExpression(); - } - - @Override - public Expression and(List inputs) { - return new AndExpression(inputs); - } - - @Override - public Expression or(List inputs) { - return new OrExpression(inputs); - } - - @Override - public Expression isFalse(Expression input) { - return new IsFalseExpression(input); - } - - @Override - public Expression isNotFalse(Expression input) { - return new IsNotFalseExpression(input); - } - - @Override - public Expression isTrue(Expression input) { - return new IsTrueExpression(input); - } - - @Override - public Expression isNotTrue(Expression input) { - return new IsNotTrueExpression(input); - } - - @Override - public Expression greaterThan(Expression left, Expression right) { - return new GTExpression(left, right); - } - - @Override - public Expression greaterEqThen(Expression left, Expression right) { - return new GTEExpression(left, right); - } - - @Override - public Expression lessThan(Expression left, Expression right) { - return new LTExpression(left, right); - } - - @Override - public Expression lessEqThan(Expression left, Expression right) { - return new LTEExpression(left, right); - } - - @Override - public Expression equal(Expression left, Expression right) { - return new EqualExpression(left, right); - } - - @Override - public Expression not(Expression input) { - return new NotExpression(input); - } - - @Override - public Expression vertexConstruct(List inputs, List globalVariables, - VertexType vertexType) { - return new VertexConstructExpression(inputs, globalVariables, vertexType); - } - - @Override - public Expression edgeConstruct(List inputs, EdgeType edgeType) { - return new EdgeConstructExpression(inputs, edgeType); - } - - @Override - public Expression udf(List inputs, IType outputType, Class implementClass) { - return new UDFExpression(inputs, outputType, implementClass); - } - - @Override - public Expression udtf(List inputs, IType outputType, Class implementClass) { - return new UDTFExpression(inputs, outputType, implementClass); - } - - @Override - public Expression buildIn(List inputs, IType outputType, String methodName) { - return new BuildInExpression(inputs, outputType, GeaFlowBuiltinFunctions.class, methodName); - } + public DefaultExpressionBuilder() {} + + @Override + public Expression plus(Expression left, Expression right, IType outputType) { + return new PlusExpression(left, right, outputType); + } + + @Override + public Expression minus(Expression left, Expression right, IType outputType) { + return new MinusExpression(left, right, outputType); + } + + @Override + public Expression multiply(Expression left, Expression right, IType outputType) { + return new MultiplyExpression(left, right, outputType); + } + + @Override + public Expression divide(Expression left, Expression right, IType outputType) { + return new DivideExpression(left, right, outputType); + } + + @Override + public Expression mod(Expression left, Expression right, IType outputType) { + return new ModExpression(left, right, outputType); + } + + @Override + public Expression minusPrefix(Expression input, IType outputType) { + return new MinusPrefixExpression(input, outputType); + } + + @Override + public Expression cast(Expression input, IType outputType) { + return new CastExpression(input, outputType); + } + + @Override + public Expression isNull(Expression input) { + return new IsNullExpression(input); + } + + @Override + public Expression isNotNull(Expression input) { + return new IsNotNullExpression(input); + } + + @Override + public Expression caseWhen(List inputs, IType outputType) { + return new CaseExpression(inputs, outputType); + } + + @Override + public Expression ifExp( + Expression condition, Expression trueValue, Expression falseValue, IType outputType) { + return new IfExpression(condition, trueValue, falseValue, outputType); + } + + @Override + public Expression field(Expression input, int fieldIndex, IType outputType) { + return new FieldExpression(input, fieldIndex, outputType); + } + + @Override + public Expression pathField(String label, int fieldIndex, IType outputType) { + return new PathFieldExpression(label, fieldIndex, outputType); + } + + @Override + public Expression parameterField(int fieldIndex, IType outputType) { + return new ParameterFieldExpression(fieldIndex, outputType); + } + + @Override + public Expression item(Expression input, Expression index) { + return new ItemExpression(input, index); + } + + @Override + public Expression literal(Object value, IType outputType) { + return new LiteralExpression(value, outputType); + } + + @Override + public Expression pi() { + return new PIExpression(); + } + + @Override + public Expression and(List inputs) { + return new AndExpression(inputs); + } + + @Override + public Expression or(List inputs) { + return new OrExpression(inputs); + } + + @Override + public Expression isFalse(Expression input) { + return new IsFalseExpression(input); + } + + @Override + public Expression isNotFalse(Expression input) { + return new IsNotFalseExpression(input); + } + + @Override + public Expression isTrue(Expression input) { + return new IsTrueExpression(input); + } + + @Override + public Expression isNotTrue(Expression input) { + return new IsNotTrueExpression(input); + } + + @Override + public Expression greaterThan(Expression left, Expression right) { + return new GTExpression(left, right); + } + + @Override + public Expression greaterEqThen(Expression left, Expression right) { + return new GTEExpression(left, right); + } + + @Override + public Expression lessThan(Expression left, Expression right) { + return new LTExpression(left, right); + } + + @Override + public Expression lessEqThan(Expression left, Expression right) { + return new LTEExpression(left, right); + } + + @Override + public Expression equal(Expression left, Expression right) { + return new EqualExpression(left, right); + } + + @Override + public Expression not(Expression input) { + return new NotExpression(input); + } + + @Override + public Expression vertexConstruct( + List inputs, List globalVariables, VertexType vertexType) { + return new VertexConstructExpression(inputs, globalVariables, vertexType); + } + + @Override + public Expression edgeConstruct(List inputs, EdgeType edgeType) { + return new EdgeConstructExpression(inputs, edgeType); + } + + @Override + public Expression udf( + List inputs, IType outputType, Class implementClass) { + return new UDFExpression(inputs, outputType, implementClass); + } + + @Override + public Expression udtf( + List inputs, IType outputType, Class implementClass) { + return new UDTFExpression(inputs, outputType, implementClass); + } + + @Override + public Expression buildIn(List inputs, IType outputType, String methodName) { + return new BuildInExpression(inputs, outputType, GeaFlowBuiltinFunctions.class, methodName); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/Expression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/Expression.java index 24e4e7103..115765b81 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/Expression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/Expression.java @@ -23,42 +23,34 @@ import java.util.List; import java.util.function.Function; import java.util.function.Predicate; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.function.FunctionContext; public interface Expression extends Serializable { - /** - * Open method for expression. - */ - default void open(FunctionContext context) { - } + /** Open method for expression. */ + default void open(FunctionContext context) {} - /** - * Evaluate the value of this expression with the given input. - */ - Object evaluate(Row row); + /** Evaluate the value of this expression with the given input. */ + Object evaluate(Row row); - /** - * Show the expression string. - */ - String showExpression(); + /** Show the expression string. */ + String showExpression(); - /** - * Get the output type of this expression. - */ - IType getOutputType(); + /** Get the output type of this expression. */ + IType getOutputType(); - List getInputs(); + List getInputs(); - Expression copy(List inputs); + Expression copy(List inputs); - List getRefPathFieldIndices(); + List getRefPathFieldIndices(); - Expression replace(Function replaceFn); + Expression replace(Function replaceFn); - List collect(Predicate condition); + List collect(Predicate condition); - List splitByAnd(); + List splitByAnd(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/ExpressionBuilder.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/ExpressionBuilder.java index a97d1dd81..8d731f044 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/ExpressionBuilder.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/ExpressionBuilder.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.expression; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.function.UDF; import org.apache.geaflow.dsl.common.function.UDTF; @@ -29,72 +30,74 @@ public interface ExpressionBuilder { - Expression plus(Expression left, Expression right, IType outputType); + Expression plus(Expression left, Expression right, IType outputType); - Expression minus(Expression left, Expression right, IType outputType); + Expression minus(Expression left, Expression right, IType outputType); - Expression multiply(Expression left, Expression right, IType outputType); + Expression multiply(Expression left, Expression right, IType outputType); - Expression divide(Expression left, Expression right, IType outputType); + Expression divide(Expression left, Expression right, IType outputType); - Expression mod(Expression left, Expression right, IType outputType); + Expression mod(Expression left, Expression right, IType outputType); - Expression minusPrefix(Expression input, IType outputType); + Expression minusPrefix(Expression input, IType outputType); - Expression cast(Expression input, IType outputType); + Expression cast(Expression input, IType outputType); - Expression isNull(Expression input); + Expression isNull(Expression input); - Expression isNotNull(Expression input); + Expression isNotNull(Expression input); - Expression caseWhen(List inputs, IType outputType); + Expression caseWhen(List inputs, IType outputType); - Expression ifExp(Expression condition, Expression trueValue, Expression falseValue, IType outputType); + Expression ifExp( + Expression condition, Expression trueValue, Expression falseValue, IType outputType); - Expression field(Expression input, int fieldIndex, IType outputType); + Expression field(Expression input, int fieldIndex, IType outputType); - Expression pathField(String label, int fieldIndex, IType outputType); + Expression pathField(String label, int fieldIndex, IType outputType); - Expression parameterField(int fieldIndex, IType outputType); + Expression parameterField(int fieldIndex, IType outputType); - Expression item(Expression input, Expression index); + Expression item(Expression input, Expression index); - Expression literal(Object value, IType outputType); + Expression literal(Object value, IType outputType); - Expression pi(); + Expression pi(); - Expression and(List inputs); + Expression and(List inputs); - Expression or(List inputs); + Expression or(List inputs); - Expression isFalse(Expression input); + Expression isFalse(Expression input); - Expression isNotFalse(Expression input); + Expression isNotFalse(Expression input); - Expression isTrue(Expression input); + Expression isTrue(Expression input); - Expression isNotTrue(Expression input); + Expression isNotTrue(Expression input); - Expression greaterThan(Expression left, Expression right); + Expression greaterThan(Expression left, Expression right); - Expression greaterEqThen(Expression left, Expression right); + Expression greaterEqThen(Expression left, Expression right); - Expression lessThan(Expression left, Expression right); + Expression lessThan(Expression left, Expression right); - Expression lessEqThan(Expression left, Expression right); + Expression lessEqThan(Expression left, Expression right); - Expression equal(Expression left, Expression right); + Expression equal(Expression left, Expression right); - Expression not(Expression input); + Expression not(Expression input); - Expression vertexConstruct(List inputs, List globalVariables, - VertexType vertexType); + Expression vertexConstruct( + List inputs, List globalVariables, VertexType vertexType); - Expression edgeConstruct(List inputs, EdgeType edgeType); + Expression edgeConstruct(List inputs, EdgeType edgeType); - Expression udf(List inputs, IType outputType, Class implementClass); + Expression udf(List inputs, IType outputType, Class implementClass); - Expression udtf(List inputs, IType outputType, Class implementClass); + Expression udtf( + List inputs, IType outputType, Class implementClass); - Expression buildIn(List inputs, IType outputType, String methodName); + Expression buildIn(List inputs, IType outputType, String methodName); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/ExpressionTranslator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/ExpressionTranslator.java index b1a70d9fd..3c4b16b8a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/ExpressionTranslator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/ExpressionTranslator.java @@ -19,10 +19,10 @@ package org.apache.geaflow.dsl.runtime.expression; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.*; @@ -72,471 +72,481 @@ import org.apache.geaflow.dsl.util.GQLRexUtil; import org.apache.geaflow.dsl.util.SqlTypeUtil; +import com.google.common.collect.Lists; + public class ExpressionTranslator implements RexVisitor { - private final ExpressionBuilder builder; + private final ExpressionBuilder builder; - private final RelDataType inputType; + private final RelDataType inputType; - private final StepLogicalPlanSet logicalPlanSet; + private final StepLogicalPlanSet logicalPlanSet; - private ExpressionTranslator(RelDataType inputType) { - this(inputType, null); - } + private ExpressionTranslator(RelDataType inputType) { + this(inputType, null); + } - private ExpressionTranslator(RelDataType inputType, StepLogicalPlanSet logicalPlanSet) { - this.inputType = inputType; - this.logicalPlanSet = logicalPlanSet; - this.builder = new DefaultExpressionBuilder(); - } + private ExpressionTranslator(RelDataType inputType, StepLogicalPlanSet logicalPlanSet) { + this.inputType = inputType; + this.logicalPlanSet = logicalPlanSet; + this.builder = new DefaultExpressionBuilder(); + } - public static ExpressionTranslator of(RelDataType inputType) { - return new ExpressionTranslator(inputType); - } + public static ExpressionTranslator of(RelDataType inputType) { + return new ExpressionTranslator(inputType); + } - public static ExpressionTranslator of(RelDataType inputType, StepLogicalPlanSet logicalPlanSet) { - return new ExpressionTranslator(inputType, logicalPlanSet); - } + public static ExpressionTranslator of(RelDataType inputType, StepLogicalPlanSet logicalPlanSet) { + return new ExpressionTranslator(inputType, logicalPlanSet); + } - public Expression translate(RexNode exp) { - return exp.accept(this); - } + public Expression translate(RexNode exp) { + return exp.accept(this); + } - public List translate(List rexNodes) { - return rexNodes.stream() - .map(exp -> exp.accept(this)) - .collect(Collectors.toList()); - } + public List translate(List rexNodes) { + return rexNodes.stream().map(exp -> exp.accept(this)).collect(Collectors.toList()); + } - @Override - public Expression visitInputRef(RexInputRef inputRef) { - int index = inputRef.getIndex(); - IType type = SqlTypeUtil.convertType(inputRef.getType()); - if (inputRef instanceof PathInputRef) { - PathInputRef pathInputRef = (PathInputRef) inputRef; - return builder.pathField(pathInputRef.getLabel(), index, type); - } - return builder.field(null, index, type); + @Override + public Expression visitInputRef(RexInputRef inputRef) { + int index = inputRef.getIndex(); + IType type = SqlTypeUtil.convertType(inputRef.getType()); + if (inputRef instanceof PathInputRef) { + PathInputRef pathInputRef = (PathInputRef) inputRef; + return builder.pathField(pathInputRef.getLabel(), index, type); } - - @Override - public Expression visitLocalRef(RexLocalRef localRef) { - throw new IllegalArgumentException("Illegal call"); + return builder.field(null, index, type); + } + + @Override + public Expression visitLocalRef(RexLocalRef localRef) { + throw new IllegalArgumentException("Illegal call"); + } + + @Override + public Expression visitLiteral(RexLiteral literal) { + Object value = GQLRexUtil.getLiteralValue(literal); + IType type = SqlTypeUtil.convertType(literal.getType()); + return builder.literal(value, type); + } + + @Override + public Expression visitOver(RexOver over) { + throw new IllegalArgumentException("Illegal call"); + } + + @Override + public Expression visitCall(RexCall call) { + if (isPathPatternSubQuery(call)) { + return processPathPatternSubQuery(call); } - - @Override - public Expression visitLiteral(RexLiteral literal) { - Object value = GQLRexUtil.getLiteralValue(literal); - IType type = SqlTypeUtil.convertType(literal.getType()); - return builder.literal(value, type); - } - - @Override - public Expression visitOver(RexOver over) { - throw new IllegalArgumentException("Illegal call"); + List operands = call.getOperands(); + List inputs = Lists.newArrayList(); + for (RexNode operand : operands) { + Expression input = operand.accept(this); + inputs.add(input); } - @Override - public Expression visitCall(RexCall call) { - if (isPathPatternSubQuery(call)) { - return processPathPatternSubQuery(call); + SqlSyntax syntax = call.getOperator().getSyntax(); + SqlKind callKind = call.getKind(); + IType outputType = SqlTypeUtil.convertType(call.getType()); + switch (syntax) { + case BINARY: + switch (callKind) { + case PLUS: + assert inputs.size() == 2; + return builder.plus(inputs.get(0), inputs.get(1), outputType); + case MINUS: + assert inputs.size() == 2; + return builder.minus(inputs.get(0), inputs.get(1), outputType); + case TIMES: + assert inputs.size() == 2; + return builder.multiply(inputs.get(0), inputs.get(1), outputType); + case DIVIDE: + assert inputs.size() == 2; + return builder.divide(inputs.get(0), inputs.get(1), outputType); + case AND: + return builder.and(inputs); + case OR: + return builder.or(inputs); + case LESS_THAN: + assert inputs.size() == 2; + return builder.lessThan(inputs.get(0), inputs.get(1)); + case LESS_THAN_OR_EQUAL: + assert inputs.size() == 2; + return builder.lessEqThan(inputs.get(0), inputs.get(1)); + case EQUALS: + assert inputs.size() == 2; + return builder.equal(inputs.get(0), inputs.get(1)); + case NOT_EQUALS: + assert inputs.size() == 2; + return builder.not(builder.equal(inputs.get(0), inputs.get(1))); + case GREATER_THAN: + assert inputs.size() == 2; + return builder.greaterThan(inputs.get(0), inputs.get(1)); + case GREATER_THAN_OR_EQUAL: + assert inputs.size() == 2; + return builder.greaterEqThen(inputs.get(0), inputs.get(1)); + case IS_NULL: + assert inputs.size() == 1; + return builder.isNull(inputs.get(0)); + case OTHER: + return processOtherTrans(inputs, call); + default: + break; } - List operands = call.getOperands(); - List inputs = Lists.newArrayList(); - for (RexNode operand : operands) { - Expression input = operand.accept(this); - inputs.add(input); - } - - SqlSyntax syntax = call.getOperator().getSyntax(); - SqlKind callKind = call.getKind(); - IType outputType = SqlTypeUtil.convertType(call.getType()); - switch (syntax) { - case BINARY: - switch (callKind) { - case PLUS: - assert inputs.size() == 2; - return builder.plus(inputs.get(0), inputs.get(1), outputType); - case MINUS: - assert inputs.size() == 2; - return builder.minus(inputs.get(0), inputs.get(1), outputType); - case TIMES: - assert inputs.size() == 2; - return builder.multiply(inputs.get(0), inputs.get(1), outputType); - case DIVIDE: - assert inputs.size() == 2; - return builder.divide(inputs.get(0), inputs.get(1), outputType); - case AND: - return builder.and(inputs); - case OR: - return builder.or(inputs); - case LESS_THAN: - assert inputs.size() == 2; - return builder.lessThan(inputs.get(0), inputs.get(1)); - case LESS_THAN_OR_EQUAL: - assert inputs.size() == 2; - return builder.lessEqThan(inputs.get(0), inputs.get(1)); - case EQUALS: - assert inputs.size() == 2; - return builder.equal(inputs.get(0), inputs.get(1)); - case NOT_EQUALS: - assert inputs.size() == 2; - return builder.not(builder.equal(inputs.get(0), inputs.get(1))); - case GREATER_THAN: - assert inputs.size() == 2; - return builder.greaterThan(inputs.get(0), inputs.get(1)); - case GREATER_THAN_OR_EQUAL: - assert inputs.size() == 2; - return builder.greaterEqThen(inputs.get(0), inputs.get(1)); - case IS_NULL: - assert inputs.size() == 1; - return builder.isNull(inputs.get(0)); - case OTHER: - return processOtherTrans(inputs, call); - default: - break; - } - break; - case FUNCTION: - case FUNCTION_ID: - switch (callKind) { - case FLOOR: - if (call.getOperands().size() == 1) { - return builder.buildIn(inputs, outputType, BuildInExpression.FLOOR); - } else if (call.getOperands().size() == 2) { - return builder.buildIn(inputs, outputType, BuildInExpression.TIMESTAMP_FLOOR); - } - break; - case CEIL: - if (call.getOperands().size() == 1) { - return builder.buildIn(inputs, outputType, BuildInExpression.CEIL); - } else if (call.getOperands().size() == 2) { - return builder.buildIn(inputs, outputType, BuildInExpression.TIMESTAMP_CEIL); - } - break; - case TRIM: - return builder.buildIn(inputs, outputType, BuildInExpression.TRIM); - case OTHER_FUNCTION: - return processOtherTrans(inputs, call); - default: - break; - } - break; - case SPECIAL: - switch (callKind) { - case CAST: - case REINTERPRET: - assert inputs.size() == 1; - return builder.cast(inputs.get(0), outputType); - case CASE: - return builder.caseWhen(inputs, outputType); - case LIKE: - return builder.udf(inputs, outputType, Like.class); - case SIMILAR: - return builder.buildIn(inputs, outputType, BuildInExpression.SIMILAR); - case VERTEX_VALUE_CONSTRUCTOR: - RexObjectConstruct objConstruct = (RexObjectConstruct) call; - List variableInfoList = objConstruct.getVariableInfo(); - List globalVariables = new ArrayList<>(); - - for (int i = 0; i < objConstruct.getOperands().size(); i++) { - VariableInfo variableInfo = variableInfoList.get(i); - if (variableInfo.isGlobal()) { - RexNode operand = objConstruct.getOperands().get(i); - IType type = SqlTypeUtil.convertType(operand.getType()); - globalVariables.add(new GlobalVariable(variableInfo.getName(), i, type)); - } - } - return builder.vertexConstruct(inputs, globalVariables, (VertexType) outputType); - case EDGE_VALUE_CONSTRUCTOR: - return builder.edgeConstruct(inputs, (EdgeType) outputType); - default: - break; - } - break; - case POSTFIX: - switch (callKind) { - case IS_NULL: - assert inputs.size() == 1; - return builder.isNull(inputs.get(0)); - case IS_NOT_NULL: - assert inputs.size() == 1; - return builder.isNotNull(inputs.get(0)); - case IS_FALSE: - assert inputs.size() == 1; - return builder.isFalse(inputs.get(0)); - case IS_NOT_FALSE: - assert inputs.size() == 1; - return builder.isNotFalse(inputs.get(0)); - case IS_TRUE: - assert inputs.size() == 1; - return builder.isTrue(inputs.get(0)); - case IS_NOT_TRUE: - assert inputs.size() == 1; - return builder.isNotTrue(inputs.get(0)); - case DESCENDING: - return call.operands.get(0).accept(this); - default: - break; - } - break; - case PREFIX: - switch (callKind) { - case NOT: - assert inputs.size() == 1; - return builder.not(inputs.get(0)); - case MINUS_PREFIX: - assert inputs.size() == 1; - return builder.minusPrefix(inputs.get(0), outputType); - default: - break; - } - break; - default: - break; + break; + case FUNCTION: + case FUNCTION_ID: + switch (callKind) { + case FLOOR: + if (call.getOperands().size() == 1) { + return builder.buildIn(inputs, outputType, BuildInExpression.FLOOR); + } else if (call.getOperands().size() == 2) { + return builder.buildIn(inputs, outputType, BuildInExpression.TIMESTAMP_FLOOR); + } + break; + case CEIL: + if (call.getOperands().size() == 1) { + return builder.buildIn(inputs, outputType, BuildInExpression.CEIL); + } else if (call.getOperands().size() == 2) { + return builder.buildIn(inputs, outputType, BuildInExpression.TIMESTAMP_CEIL); + } + break; + case TRIM: + return builder.buildIn(inputs, outputType, BuildInExpression.TRIM); + case OTHER_FUNCTION: + return processOtherTrans(inputs, call); + default: + break; } - return processOtherTrans(inputs, call); - } - - private Expression processOtherTrans(List inputs, RexCall call) { - SqlOperator sqlOperator = call.getOperator(); - // Upper operator name. - String operatorName = sqlOperator.getName().toUpperCase(); - String functionName = null; - IType outputType = SqlTypeUtil.convertType(call.getType()); - - switch (operatorName) { - case "NULL": - return builder.literal(null, outputType); - case "IF": - assert inputs.size() == 3; - return builder.ifExp(inputs.get(0), inputs.get(1), inputs.get(2), outputType); - case "ITEM": - assert inputs.size() == 2; - return builder.item(inputs.get(0), inputs.get(1)); - case "%": - case "MOD": - return builder.mod(inputs.get(0), inputs.get(1), outputType); - case "PI": - return builder.pi(); - case "||": - functionName = BuildInExpression.CONCAT; - break; - case "CHAR_LENGTH": - case "CHARACTER_LENGTH": - functionName = BuildInExpression.LENGTH; - break; - case "UPPER": - functionName = BuildInExpression.UPPER; - break; - case "LOWER": - functionName = BuildInExpression.LOWER; - break; - case "POSITION": - functionName = BuildInExpression.POSITION; - break; - case "OVERLAY": - functionName = BuildInExpression.OVERLAY; - break; - case "SUBSTRING": - functionName = BuildInExpression.SUBSTRING; - break; - case "INITCAP": - functionName = BuildInExpression.INITCAP; - break; - case "POWER": - functionName = BuildInExpression.POWER; - break; - case "ABS": - functionName = BuildInExpression.ABS; - break; - case "LN": - functionName = BuildInExpression.LN; - break; - case "LOG10": - functionName = BuildInExpression.LOG10; - break; - case "EXP": - functionName = BuildInExpression.EXP; - break; - case "SIN": - functionName = BuildInExpression.SIN; - break; - case "COS": - functionName = BuildInExpression.COS; - break; - case "TAN": - functionName = BuildInExpression.TAN; - break; - case "COT": - functionName = BuildInExpression.COT; - break; - case "ASIN": - functionName = BuildInExpression.ASIN; - break; - case "ACOS": - functionName = BuildInExpression.ACOS; - break; - case "ATAN": - functionName = BuildInExpression.ATAN; - break; - case "DEGREES": - functionName = BuildInExpression.DEGREES; - break; - case "RADIANS": - functionName = BuildInExpression.RADIANS; - break; - case "SIGN": - functionName = BuildInExpression.SIGN; - break; - case "RAND": - functionName = BuildInExpression.RAND; - break; - case "RAND_INTEGER": - functionName = BuildInExpression.RAND_INTEGER; - break; - case "LOCALTIMESTAMP": - case "CURRENT_TIMESTAMP": - functionName = BuildInExpression.CURRENT_TIMESTAMP; - break; - case "SAME": - functionName = BuildInExpression.SAME; - break; - default: + break; + case SPECIAL: + switch (callKind) { + case CAST: + case REINTERPRET: + assert inputs.size() == 1; + return builder.cast(inputs.get(0), outputType); + case CASE: + return builder.caseWhen(inputs, outputType); + case LIKE: + return builder.udf(inputs, outputType, Like.class); + case SIMILAR: + return builder.buildIn(inputs, outputType, BuildInExpression.SIMILAR); + case VERTEX_VALUE_CONSTRUCTOR: + RexObjectConstruct objConstruct = (RexObjectConstruct) call; + List variableInfoList = objConstruct.getVariableInfo(); + List globalVariables = new ArrayList<>(); + + for (int i = 0; i < objConstruct.getOperands().size(); i++) { + VariableInfo variableInfo = variableInfoList.get(i); + if (variableInfo.isGlobal()) { + RexNode operand = objConstruct.getOperands().get(i); + IType type = SqlTypeUtil.convertType(operand.getType()); + globalVariables.add(new GlobalVariable(variableInfo.getName(), i, type)); + } + } + return builder.vertexConstruct(inputs, globalVariables, (VertexType) outputType); + case EDGE_VALUE_CONSTRUCTOR: + return builder.edgeConstruct(inputs, (EdgeType) outputType); + default: + break; } - if (functionName != null) { - return builder.buildIn(inputs, outputType, functionName); + break; + case POSTFIX: + switch (callKind) { + case IS_NULL: + assert inputs.size() == 1; + return builder.isNull(inputs.get(0)); + case IS_NOT_NULL: + assert inputs.size() == 1; + return builder.isNotNull(inputs.get(0)); + case IS_FALSE: + assert inputs.size() == 1; + return builder.isFalse(inputs.get(0)); + case IS_NOT_FALSE: + assert inputs.size() == 1; + return builder.isNotFalse(inputs.get(0)); + case IS_TRUE: + assert inputs.size() == 1; + return builder.isTrue(inputs.get(0)); + case IS_NOT_TRUE: + assert inputs.size() == 1; + return builder.isNotTrue(inputs.get(0)); + case DESCENDING: + return call.operands.get(0).accept(this); + default: + break; } - - if (call.getOperator() instanceof GeaFlowUserDefinedTableFunction) { - GeaFlowUserDefinedTableFunction operator = (GeaFlowUserDefinedTableFunction) call.getOperator(); - return builder.udtf(inputs, outputType, operator.getImplementClass()); - } else if (call.getOperator() instanceof GeaFlowUserDefinedScalarFunction) { - GeaFlowUserDefinedScalarFunction operator = (GeaFlowUserDefinedScalarFunction) call.getOperator(); - return builder.udf(inputs, outputType, operator.getImplementClass()); + break; + case PREFIX: + switch (callKind) { + case NOT: + assert inputs.size() == 1; + return builder.not(inputs.get(0)); + case MINUS_PREFIX: + assert inputs.size() == 1; + return builder.minusPrefix(inputs.get(0), outputType); + default: + break; } - throw new GeaFlowDSLException("Not support expression: " + call); + break; + default: + break; } - - @SuppressWarnings("unchecked") - private Expression processPathPatternSubQuery(RexCall call) { - RexLambdaCall lambdaCall = (RexLambdaCall) call.operands.get(0); - SingleMatchNode matchNode = (SingleMatchNode) (lambdaCall.getInput()).rel; - - // generate sub query logical plan. - assert logicalPlanSet != null; - StepLogicalPlanTranslator planTranslator = new StepLogicalPlanTranslator(); - GraphSchema graphSchema = logicalPlanSet.getGraphSchema(); - GraphRecordType graphRecordType = (GraphRecordType) SqlTypeUtil.convertToRelType(graphSchema, false, - matchNode.getCluster().getTypeFactory()); - GraphScan emptyScan = LogicalGraphScan.emptyScan(matchNode.getCluster(), graphRecordType); - GraphMatch graphMatch = LogicalGraphMatch.create(matchNode.getCluster(), emptyScan, matchNode, - matchNode.getPathSchema()); - StepLogicalPlan matchPlan = planTranslator.translate(graphMatch, logicalPlanSet); - - SqlAggFunction aggFunction = (SqlAggFunction) call.getOperator(); - IType aggInputType = SqlTypeUtil.convertType(lambdaCall.type); - Class> udafClass = PhysicAggregateRelNode.findUDAF( - aggFunction, new IType[]{aggInputType}); - UDAF udaf = (UDAF) ClassUtil.newInstance(udafClass); - StepAggregateFunction stepAggFunction = new StepAggFunctionImpl(udaf, aggInputType); - - Expression valueExpression = lambdaCall.getValue().accept(this); - IType valueType = valueExpression.getOutputType(); - StructType singleValueType = StructType.singleValue(valueType, lambdaCall.getValue().getType().isNullable()); - StepLogicalPlan valuePlan = matchPlan.mapRow(new StepSingleValueMapFunction(valueExpression)) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(singleValueType); - - IType aggOutputType = SqlTypeUtil.convertType(call.getType()); - StepLogicalPlan aggPlan = valuePlan.aggregate(stepAggFunction).withOutputType(aggOutputType); - StepLogicalPlan returnPlan = aggPlan.ret(); - // add sub query plan to plan set. - logicalPlanSet.addSubLogicalPlan(returnPlan); - - // create call sub query expression. - RelDataTypeField startField = inputType.getFieldList().get(inputType.getFieldCount() - 1); - int startVertexIndex = startField.getIndex(); - VertexType startVertexType = (VertexType) SqlTypeUtil.convertType(startField.getType()); - assert matchPlan.getHeadPlan().getOperator() instanceof StepSubQueryStartOperator; - StepSubQueryStartOperator startOperator = (StepSubQueryStartOperator) matchPlan.getHeadPlan().getOperator(); - - List subQueryRefPathFields = startOperator.getOutputPathSchema().getFieldNames(); - List inputPathFields = inputType.getFieldNames(); - List refParentPathIndices = new ArrayList<>(); - for (int i = 0; i < inputPathFields.size(); i++) { - if (subQueryRefPathFields.contains(inputPathFields.get(i))) { - refParentPathIndices.add(i); - } - } - Object accumulator = stepAggFunction.createAccumulator(); - SingleValue defaultAggValue = stepAggFunction.getValue(accumulator); - return new CallQueryExpression(startOperator.getQueryName(), - startOperator.getId(), - startVertexIndex, - startVertexType, - aggOutputType, - ArrayUtil.toIntArray(refParentPathIndices), - defaultAggValue.getValue(aggOutputType)); - } - - private boolean isPathPatternSubQuery(RexCall call) { - return call.getOperator().isAggregator() - && call.operands.size() == 1 - && call.operands.get(0) instanceof RexLambdaCall - && (((RexLambdaCall) call.operands.get(0)).getInput()).rel instanceof IMatchNode - ; - } - - @Override - public Expression visitCorrelVariable(RexCorrelVariable correlVariable) { - return null; - } - - @Override - public Expression visitDynamicParam(RexDynamicParam dynamicParam) { - throw new GeaFlowDSLException("Not support expression: " + dynamicParam); - } - - @Override - public Expression visitRangeRef(RexRangeRef rangeRef) { - throw new GeaFlowDSLException("Not support expression: " + rangeRef); + return processOtherTrans(inputs, call); + } + + private Expression processOtherTrans(List inputs, RexCall call) { + SqlOperator sqlOperator = call.getOperator(); + // Upper operator name. + String operatorName = sqlOperator.getName().toUpperCase(); + String functionName = null; + IType outputType = SqlTypeUtil.convertType(call.getType()); + + switch (operatorName) { + case "NULL": + return builder.literal(null, outputType); + case "IF": + assert inputs.size() == 3; + return builder.ifExp(inputs.get(0), inputs.get(1), inputs.get(2), outputType); + case "ITEM": + assert inputs.size() == 2; + return builder.item(inputs.get(0), inputs.get(1)); + case "%": + case "MOD": + return builder.mod(inputs.get(0), inputs.get(1), outputType); + case "PI": + return builder.pi(); + case "||": + functionName = BuildInExpression.CONCAT; + break; + case "CHAR_LENGTH": + case "CHARACTER_LENGTH": + functionName = BuildInExpression.LENGTH; + break; + case "UPPER": + functionName = BuildInExpression.UPPER; + break; + case "LOWER": + functionName = BuildInExpression.LOWER; + break; + case "POSITION": + functionName = BuildInExpression.POSITION; + break; + case "OVERLAY": + functionName = BuildInExpression.OVERLAY; + break; + case "SUBSTRING": + functionName = BuildInExpression.SUBSTRING; + break; + case "INITCAP": + functionName = BuildInExpression.INITCAP; + break; + case "POWER": + functionName = BuildInExpression.POWER; + break; + case "ABS": + functionName = BuildInExpression.ABS; + break; + case "LN": + functionName = BuildInExpression.LN; + break; + case "LOG10": + functionName = BuildInExpression.LOG10; + break; + case "EXP": + functionName = BuildInExpression.EXP; + break; + case "SIN": + functionName = BuildInExpression.SIN; + break; + case "COS": + functionName = BuildInExpression.COS; + break; + case "TAN": + functionName = BuildInExpression.TAN; + break; + case "COT": + functionName = BuildInExpression.COT; + break; + case "ASIN": + functionName = BuildInExpression.ASIN; + break; + case "ACOS": + functionName = BuildInExpression.ACOS; + break; + case "ATAN": + functionName = BuildInExpression.ATAN; + break; + case "DEGREES": + functionName = BuildInExpression.DEGREES; + break; + case "RADIANS": + functionName = BuildInExpression.RADIANS; + break; + case "SIGN": + functionName = BuildInExpression.SIGN; + break; + case "RAND": + functionName = BuildInExpression.RAND; + break; + case "RAND_INTEGER": + functionName = BuildInExpression.RAND_INTEGER; + break; + case "LOCALTIMESTAMP": + case "CURRENT_TIMESTAMP": + functionName = BuildInExpression.CURRENT_TIMESTAMP; + break; + case "SAME": + functionName = BuildInExpression.SAME; + break; + default: } - - @Override - public Expression visitFieldAccess(RexFieldAccess fieldAccess) { - Expression input = fieldAccess.getReferenceExpr().accept(this); - int index = fieldAccess.getField().getIndex(); - IType type = SqlTypeUtil.convertType(fieldAccess.getField().getType()); - - return builder.field(input, index, type); + if (functionName != null) { + return builder.buildIn(inputs, outputType, functionName); } - @Override - public Expression visitSubQuery(RexSubQuery subQuery) { - throw new GeaFlowDSLException("Not support expression: " + subQuery); - } - - @Override - public Expression visitTableInputRef(RexTableInputRef fieldRef) { - throw new GeaFlowDSLException("Not support expression: " + fieldRef); + if (call.getOperator() instanceof GeaFlowUserDefinedTableFunction) { + GeaFlowUserDefinedTableFunction operator = + (GeaFlowUserDefinedTableFunction) call.getOperator(); + return builder.udtf(inputs, outputType, operator.getImplementClass()); + } else if (call.getOperator() instanceof GeaFlowUserDefinedScalarFunction) { + GeaFlowUserDefinedScalarFunction operator = + (GeaFlowUserDefinedScalarFunction) call.getOperator(); + return builder.udf(inputs, outputType, operator.getImplementClass()); } + throw new GeaFlowDSLException("Not support expression: " + call); + } + + @SuppressWarnings("unchecked") + private Expression processPathPatternSubQuery(RexCall call) { + RexLambdaCall lambdaCall = (RexLambdaCall) call.operands.get(0); + SingleMatchNode matchNode = (SingleMatchNode) (lambdaCall.getInput()).rel; + + // generate sub query logical plan. + assert logicalPlanSet != null; + StepLogicalPlanTranslator planTranslator = new StepLogicalPlanTranslator(); + GraphSchema graphSchema = logicalPlanSet.getGraphSchema(); + GraphRecordType graphRecordType = + (GraphRecordType) + SqlTypeUtil.convertToRelType( + graphSchema, false, matchNode.getCluster().getTypeFactory()); + GraphScan emptyScan = LogicalGraphScan.emptyScan(matchNode.getCluster(), graphRecordType); + GraphMatch graphMatch = + LogicalGraphMatch.create( + matchNode.getCluster(), emptyScan, matchNode, matchNode.getPathSchema()); + StepLogicalPlan matchPlan = planTranslator.translate(graphMatch, logicalPlanSet); + + SqlAggFunction aggFunction = (SqlAggFunction) call.getOperator(); + IType aggInputType = SqlTypeUtil.convertType(lambdaCall.type); + Class> udafClass = + PhysicAggregateRelNode.findUDAF(aggFunction, new IType[] {aggInputType}); + UDAF udaf = + (UDAF) ClassUtil.newInstance(udafClass); + StepAggregateFunction stepAggFunction = new StepAggFunctionImpl(udaf, aggInputType); + + Expression valueExpression = lambdaCall.getValue().accept(this); + IType valueType = valueExpression.getOutputType(); + StructType singleValueType = + StructType.singleValue(valueType, lambdaCall.getValue().getType().isNullable()); + StepLogicalPlan valuePlan = + matchPlan + .mapRow(new StepSingleValueMapFunction(valueExpression)) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(singleValueType); - @Override - public Expression visitPatternFieldRef(RexPatternFieldRef rexPatternFieldRef) { - throw new GeaFlowDSLException("Not support expression: " + rexPatternFieldRef); + IType aggOutputType = SqlTypeUtil.convertType(call.getType()); + StepLogicalPlan aggPlan = valuePlan.aggregate(stepAggFunction).withOutputType(aggOutputType); + StepLogicalPlan returnPlan = aggPlan.ret(); + // add sub query plan to plan set. + logicalPlanSet.addSubLogicalPlan(returnPlan); + + // create call sub query expression. + RelDataTypeField startField = inputType.getFieldList().get(inputType.getFieldCount() - 1); + int startVertexIndex = startField.getIndex(); + VertexType startVertexType = (VertexType) SqlTypeUtil.convertType(startField.getType()); + assert matchPlan.getHeadPlan().getOperator() instanceof StepSubQueryStartOperator; + StepSubQueryStartOperator startOperator = + (StepSubQueryStartOperator) matchPlan.getHeadPlan().getOperator(); + + List subQueryRefPathFields = startOperator.getOutputPathSchema().getFieldNames(); + List inputPathFields = inputType.getFieldNames(); + List refParentPathIndices = new ArrayList<>(); + for (int i = 0; i < inputPathFields.size(); i++) { + if (subQueryRefPathFields.contains(inputPathFields.get(i))) { + refParentPathIndices.add(i); + } } - - @Override - public Expression visitOther(RexNode other) { - if (other instanceof RexParameterRef) { - RexParameterRef rexParameterRef = (RexParameterRef) other; - IType outputType = SqlTypeUtil.convertType(rexParameterRef.getType()); - return builder.parameterField(rexParameterRef.getIndex(), outputType); - } else if (other instanceof RexSystemVariable) { - RexSystemVariable systemVariable = (RexSystemVariable) other; - return new SystemVariableExpression(SystemVariable.of(systemVariable.getName())); - } - throw new GeaFlowDSLException("Not support expression: " + other); + Object accumulator = stepAggFunction.createAccumulator(); + SingleValue defaultAggValue = stepAggFunction.getValue(accumulator); + return new CallQueryExpression( + startOperator.getQueryName(), + startOperator.getId(), + startVertexIndex, + startVertexType, + aggOutputType, + ArrayUtil.toIntArray(refParentPathIndices), + defaultAggValue.getValue(aggOutputType)); + } + + private boolean isPathPatternSubQuery(RexCall call) { + return call.getOperator().isAggregator() + && call.operands.size() == 1 + && call.operands.get(0) instanceof RexLambdaCall + && (((RexLambdaCall) call.operands.get(0)).getInput()).rel instanceof IMatchNode; + } + + @Override + public Expression visitCorrelVariable(RexCorrelVariable correlVariable) { + return null; + } + + @Override + public Expression visitDynamicParam(RexDynamicParam dynamicParam) { + throw new GeaFlowDSLException("Not support expression: " + dynamicParam); + } + + @Override + public Expression visitRangeRef(RexRangeRef rangeRef) { + throw new GeaFlowDSLException("Not support expression: " + rangeRef); + } + + @Override + public Expression visitFieldAccess(RexFieldAccess fieldAccess) { + Expression input = fieldAccess.getReferenceExpr().accept(this); + int index = fieldAccess.getField().getIndex(); + IType type = SqlTypeUtil.convertType(fieldAccess.getField().getType()); + + return builder.field(input, index, type); + } + + @Override + public Expression visitSubQuery(RexSubQuery subQuery) { + throw new GeaFlowDSLException("Not support expression: " + subQuery); + } + + @Override + public Expression visitTableInputRef(RexTableInputRef fieldRef) { + throw new GeaFlowDSLException("Not support expression: " + fieldRef); + } + + @Override + public Expression visitPatternFieldRef(RexPatternFieldRef rexPatternFieldRef) { + throw new GeaFlowDSLException("Not support expression: " + rexPatternFieldRef); + } + + @Override + public Expression visitOther(RexNode other) { + if (other instanceof RexParameterRef) { + RexParameterRef rexParameterRef = (RexParameterRef) other; + IType outputType = SqlTypeUtil.convertType(rexParameterRef.getType()); + return builder.parameterField(rexParameterRef.getIndex(), outputType); + } else if (other instanceof RexSystemVariable) { + RexSystemVariable systemVariable = (RexSystemVariable) other; + return new SystemVariableExpression(SystemVariable.of(systemVariable.getName())); } + throw new GeaFlowDSLException("Not support expression: " + other); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/UDFExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/UDFExpression.java index 2a658e8ae..d109be197 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/UDFExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/UDFExpression.java @@ -21,34 +21,35 @@ import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.function.Description; public class UDFExpression extends AbstractReflectCallExpression { - public static final String METHOD_NAME = "eval"; + public static final String METHOD_NAME = "eval"; - private final String udfName; + private final String udfName; - public UDFExpression(List inputs, IType outputType, Class implementClass) { - super(inputs, outputType, implementClass, METHOD_NAME); - this.udfName = implementClass.getAnnotation(Description.class).name(); - } + public UDFExpression(List inputs, IType outputType, Class implementClass) { + super(inputs, outputType, implementClass, METHOD_NAME); + this.udfName = implementClass.getAnnotation(Description.class).name(); + } - @Override - public String showExpression() { - return udfName + "(" - + inputs.stream().map(Expression::showExpression) - .collect(Collectors.joining(",")) - + ")"; - } + @Override + public String showExpression() { + return udfName + + "(" + + inputs.stream().map(Expression::showExpression).collect(Collectors.joining(",")) + + ")"; + } - @Override - public Expression copy(List inputs) { - return new UDFExpression(inputs, outputType, implementClass); - } + @Override + public Expression copy(List inputs) { + return new UDFExpression(inputs, outputType, implementClass); + } - public Class getUdfClass() { - return implementClass; - } + public Class getUdfClass() { + return implementClass; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/UDTFExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/UDTFExpression.java index 2955b3646..14e8df86d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/UDTFExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/UDTFExpression.java @@ -20,25 +20,25 @@ package org.apache.geaflow.dsl.runtime.expression; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.function.UDTF; public class UDTFExpression extends UDFExpression { - public UDTFExpression(List inputs, IType outputType, - Class implementClass) { - super(inputs, outputType, implementClass); - } + public UDTFExpression(List inputs, IType outputType, Class implementClass) { + super(inputs, outputType, implementClass); + } - @Override - public Object evaluate(Row row) { - super.evaluate(row); - return ((UDTF) implementInstance).getCollectData(); - } + @Override + public Object evaluate(Row row) { + super.evaluate(row); + return ((UDTF) implementInstance).getCollectData(); + } - @Override - public Expression copy(List inputs) { - return new UDTFExpression(inputs, outputType, implementClass); - } + @Override + public Expression copy(List inputs) { + return new UDTFExpression(inputs, outputType, implementClass); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/DivideExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/DivideExpression.java index 1c1062c82..b004673e3 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/DivideExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/DivideExpression.java @@ -19,29 +19,31 @@ package org.apache.geaflow.dsl.runtime.expression.binary; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.runtime.expression.AbstractReflectCallExpression; import org.apache.geaflow.dsl.runtime.expression.Expression; import org.apache.geaflow.dsl.schema.function.GeaFlowBuiltinFunctions; +import com.google.common.collect.Lists; + public class DivideExpression extends AbstractReflectCallExpression { - public static final String METHOD_NAME = "divide"; + public static final String METHOD_NAME = "divide"; - public DivideExpression(Expression left, Expression right, IType outputType) { - super(Lists.newArrayList(left, right), outputType, GeaFlowBuiltinFunctions.class, METHOD_NAME); - } + public DivideExpression(Expression left, Expression right, IType outputType) { + super(Lists.newArrayList(left, right), outputType, GeaFlowBuiltinFunctions.class, METHOD_NAME); + } - @Override - public String showExpression() { - return "(" + inputs.get(0).showExpression() + "/" + inputs.get(1).showExpression() + ")"; - } + @Override + public String showExpression() { + return "(" + inputs.get(0).showExpression() + "/" + inputs.get(1).showExpression() + ")"; + } - @Override - public Expression copy(List inputs) { - assert inputs.size() == 2; - return new DivideExpression(inputs.get(0), inputs.get(1), getOutputType()); - } + @Override + public Expression copy(List inputs) { + assert inputs.size() == 2; + return new DivideExpression(inputs.get(0), inputs.get(1), getOutputType()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/MinusExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/MinusExpression.java index e30cc6b28..dca670c35 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/MinusExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/MinusExpression.java @@ -19,29 +19,31 @@ package org.apache.geaflow.dsl.runtime.expression.binary; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.runtime.expression.AbstractReflectCallExpression; import org.apache.geaflow.dsl.runtime.expression.Expression; import org.apache.geaflow.dsl.schema.function.GeaFlowBuiltinFunctions; +import com.google.common.collect.Lists; + public class MinusExpression extends AbstractReflectCallExpression { - public static final String METHOD_NAME = "minus"; + public static final String METHOD_NAME = "minus"; - public MinusExpression(Expression left, Expression right, IType outputType) { - super(Lists.newArrayList(left, right), outputType, GeaFlowBuiltinFunctions.class, METHOD_NAME); - } + public MinusExpression(Expression left, Expression right, IType outputType) { + super(Lists.newArrayList(left, right), outputType, GeaFlowBuiltinFunctions.class, METHOD_NAME); + } - @Override - public String showExpression() { - return "(" + inputs.get(0).showExpression() + "-" + inputs.get(1).showExpression() + ")"; - } + @Override + public String showExpression() { + return "(" + inputs.get(0).showExpression() + "-" + inputs.get(1).showExpression() + ")"; + } - @Override - public Expression copy(List inputs) { - assert inputs.size() == 2; - return new MinusExpression(inputs.get(0), inputs.get(1), getOutputType()); - } + @Override + public Expression copy(List inputs) { + assert inputs.size() == 2; + return new MinusExpression(inputs.get(0), inputs.get(1), getOutputType()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/MinusPrefixExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/MinusPrefixExpression.java index 407bbaec4..89eb2c082 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/MinusPrefixExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/MinusPrefixExpression.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.runtime.expression.AbstractReflectCallExpression; import org.apache.geaflow.dsl.runtime.expression.Expression; @@ -28,19 +29,19 @@ public class MinusPrefixExpression extends AbstractReflectCallExpression { - public static final String METHOD_NAME = "minusPrefix"; + public static final String METHOD_NAME = "minusPrefix"; - public MinusPrefixExpression(Expression input, IType outputType) { - super(Collections.singletonList(input), outputType, GeaFlowBuiltinFunctions.class, METHOD_NAME); - } + public MinusPrefixExpression(Expression input, IType outputType) { + super(Collections.singletonList(input), outputType, GeaFlowBuiltinFunctions.class, METHOD_NAME); + } - @Override - public String showExpression() { - return "-" + inputs.get(0).showExpression(); - } + @Override + public String showExpression() { + return "-" + inputs.get(0).showExpression(); + } - @Override - public Expression copy(List inputs) { - return new MinusPrefixExpression(inputs.get(0), getOutputType()); - } + @Override + public Expression copy(List inputs) { + return new MinusPrefixExpression(inputs.get(0), getOutputType()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/ModExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/ModExpression.java index 625aecd7d..fc936f9df 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/ModExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/ModExpression.java @@ -19,29 +19,31 @@ package org.apache.geaflow.dsl.runtime.expression.binary; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.runtime.expression.AbstractReflectCallExpression; import org.apache.geaflow.dsl.runtime.expression.Expression; import org.apache.geaflow.dsl.schema.function.GeaFlowBuiltinFunctions; +import com.google.common.collect.Lists; + public class ModExpression extends AbstractReflectCallExpression { - public static final String METHOD_NAME = "mod"; + public static final String METHOD_NAME = "mod"; - public ModExpression(Expression left, Expression right, IType outputType) { - super(Lists.newArrayList(left, right), outputType, GeaFlowBuiltinFunctions.class, METHOD_NAME); - } + public ModExpression(Expression left, Expression right, IType outputType) { + super(Lists.newArrayList(left, right), outputType, GeaFlowBuiltinFunctions.class, METHOD_NAME); + } - @Override - public String showExpression() { - return inputs.get(0).showExpression() + "%" + inputs.get(1).showExpression(); - } + @Override + public String showExpression() { + return inputs.get(0).showExpression() + "%" + inputs.get(1).showExpression(); + } - @Override - public Expression copy(List inputs) { - assert inputs.size() == 2; - return new ModExpression(inputs.get(0), inputs.get(1), getOutputType()); - } + @Override + public Expression copy(List inputs) { + assert inputs.size() == 2; + return new ModExpression(inputs.get(0), inputs.get(1), getOutputType()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/MultiplyExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/MultiplyExpression.java index aaa3a0e4e..a6fd9b656 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/MultiplyExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/MultiplyExpression.java @@ -19,29 +19,31 @@ package org.apache.geaflow.dsl.runtime.expression.binary; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.runtime.expression.AbstractReflectCallExpression; import org.apache.geaflow.dsl.runtime.expression.Expression; import org.apache.geaflow.dsl.schema.function.GeaFlowBuiltinFunctions; +import com.google.common.collect.Lists; + public class MultiplyExpression extends AbstractReflectCallExpression { - public static final String METHOD_NAME = "times"; + public static final String METHOD_NAME = "times"; - public MultiplyExpression(Expression left, Expression right, IType outputType) { - super(Lists.newArrayList(left, right), outputType, GeaFlowBuiltinFunctions.class, METHOD_NAME); - } + public MultiplyExpression(Expression left, Expression right, IType outputType) { + super(Lists.newArrayList(left, right), outputType, GeaFlowBuiltinFunctions.class, METHOD_NAME); + } - @Override - public String showExpression() { - return inputs.get(0).showExpression() + "*" + inputs.get(1).showExpression(); - } + @Override + public String showExpression() { + return inputs.get(0).showExpression() + "*" + inputs.get(1).showExpression(); + } - @Override - public Expression copy(List inputs) { - assert inputs.size() == 2; - return new MultiplyExpression(inputs.get(0), inputs.get(1), getOutputType()); - } + @Override + public Expression copy(List inputs) { + assert inputs.size() == 2; + return new MultiplyExpression(inputs.get(0), inputs.get(1), getOutputType()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/PlusExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/PlusExpression.java index 12115d561..2fc84db27 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/PlusExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/binary/PlusExpression.java @@ -19,29 +19,31 @@ package org.apache.geaflow.dsl.runtime.expression.binary; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.runtime.expression.AbstractReflectCallExpression; import org.apache.geaflow.dsl.runtime.expression.Expression; import org.apache.geaflow.dsl.schema.function.GeaFlowBuiltinFunctions; +import com.google.common.collect.Lists; + public class PlusExpression extends AbstractReflectCallExpression { - public static final String METHOD_NAME = "plus"; + public static final String METHOD_NAME = "plus"; - public PlusExpression(Expression left, Expression right, IType outputType) { - super(Lists.newArrayList(left, right), outputType, GeaFlowBuiltinFunctions.class, METHOD_NAME); - } + public PlusExpression(Expression left, Expression right, IType outputType) { + super(Lists.newArrayList(left, right), outputType, GeaFlowBuiltinFunctions.class, METHOD_NAME); + } - @Override - public String showExpression() { - return "(" + inputs.get(0).showExpression() + "+" + inputs.get(1).showExpression() + ")"; - } + @Override + public String showExpression() { + return "(" + inputs.get(0).showExpression() + "+" + inputs.get(1).showExpression() + ")"; + } - @Override - public Expression copy(List inputs) { - assert inputs.size() == 2; - return new PlusExpression(inputs.get(0), inputs.get(1), getOutputType()); - } + @Override + public Expression copy(List inputs) { + assert inputs.size() == 2; + return new PlusExpression(inputs.get(0), inputs.get(1), getOutputType()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/cast/CastExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/cast/CastExpression.java index 5bdfa9b8b..9a28dc2f4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/cast/CastExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/cast/CastExpression.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.util.TypeCastUtil; @@ -30,26 +31,30 @@ public class CastExpression extends AbstractNonLeafExpression { - private final ITypeCast typeCast; - - public CastExpression(Expression input, IType outputType) { - super(Collections.singletonList(input), outputType); - this.typeCast = TypeCastUtil.getTypeCast(input.getOutputType(), outputType); - } - - @Override - public Object evaluate(Row row) { - Object inputValue = inputs.get(0).evaluate(row); - return typeCast.castTo(inputValue); - } - - @Override - public String showExpression() { - return "cast(" + inputs.get(0).showExpression() + " as " + getOutputType().getTypeClass().getSimpleName() + ")"; - } - - @Override - public Expression copy(List inputs) { - return new CastExpression(inputs.get(0), outputType); - } + private final ITypeCast typeCast; + + public CastExpression(Expression input, IType outputType) { + super(Collections.singletonList(input), outputType); + this.typeCast = TypeCastUtil.getTypeCast(input.getOutputType(), outputType); + } + + @Override + public Object evaluate(Row row) { + Object inputValue = inputs.get(0).evaluate(row); + return typeCast.castTo(inputValue); + } + + @Override + public String showExpression() { + return "cast(" + + inputs.get(0).showExpression() + + " as " + + getOutputType().getTypeClass().getSimpleName() + + ")"; + } + + @Override + public Expression copy(List inputs) { + return new CastExpression(inputs.get(0), outputType); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/condition/CaseExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/condition/CaseExpression.java index 76b7e25d1..49c553a33 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/condition/CaseExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/condition/CaseExpression.java @@ -19,51 +19,55 @@ package org.apache.geaflow.dsl.runtime.expression.condition; -import com.google.common.base.Preconditions; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.runtime.expression.AbstractNonLeafExpression; import org.apache.geaflow.dsl.runtime.expression.Expression; +import com.google.common.base.Preconditions; + public class CaseExpression extends AbstractNonLeafExpression { - public CaseExpression(List inputs, IType outputType) { - super(inputs, outputType); - Preconditions.checkArgument(inputs.size() % 2 == 1); - } + public CaseExpression(List inputs, IType outputType) { + super(inputs, outputType); + Preconditions.checkArgument(inputs.size() % 2 == 1); + } - @Override - public Object evaluate(Row row) { - int i = 0; - while (i < inputs.size() - 1) { - if (i % 2 == 0) { - Boolean condition = (Boolean) inputs.get(i).evaluate(row); - if (condition != null && condition) { - return inputs.get(i + 1).evaluate(row); - } - } - i++; + @Override + public Object evaluate(Row row) { + int i = 0; + while (i < inputs.size() - 1) { + if (i % 2 == 0) { + Boolean condition = (Boolean) inputs.get(i).evaluate(row); + if (condition != null && condition) { + return inputs.get(i + 1).evaluate(row); } - return inputs.get(i).evaluate(row); + } + i++; } + return inputs.get(i).evaluate(row); + } - @Override - public String showExpression() { - StringBuilder str = new StringBuilder(); - str.append("case "); - int i = 0; - while (i < inputs.size() - 1) { - str.append(" when ").append(inputs.get(i).showExpression()) - .append(" then ").append(inputs.get(i + 1).showExpression()); - i += 2; - } - str.append(" else ").append(inputs.get(i).showExpression()).append(" end"); - return str.toString(); + @Override + public String showExpression() { + StringBuilder str = new StringBuilder(); + str.append("case "); + int i = 0; + while (i < inputs.size() - 1) { + str.append(" when ") + .append(inputs.get(i).showExpression()) + .append(" then ") + .append(inputs.get(i + 1).showExpression()); + i += 2; } + str.append(" else ").append(inputs.get(i).showExpression()).append(" end"); + return str.toString(); + } - @Override - public Expression copy(List inputs) { - return new CaseExpression(inputs, getOutputType()); - } + @Override + public Expression copy(List inputs) { + return new CaseExpression(inputs, getOutputType()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/condition/IfExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/condition/IfExpression.java index 7dc3b751a..d47c7c236 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/condition/IfExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/condition/IfExpression.java @@ -19,37 +19,45 @@ package org.apache.geaflow.dsl.runtime.expression.condition; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.runtime.expression.AbstractNonLeafExpression; import org.apache.geaflow.dsl.runtime.expression.Expression; +import com.google.common.collect.Lists; + public class IfExpression extends AbstractNonLeafExpression { - public IfExpression(Expression condition, Expression trueValue, Expression falseValue, IType outputType) { - super(Lists.newArrayList(condition, trueValue, falseValue), outputType); - } + public IfExpression( + Expression condition, Expression trueValue, Expression falseValue, IType outputType) { + super(Lists.newArrayList(condition, trueValue, falseValue), outputType); + } - @Override - public Object evaluate(Row row) { - Boolean condition = (Boolean) inputs.get(0).evaluate(row); - if (condition != null && condition) { - return inputs.get(1).evaluate(row); - } - return inputs.get(2).evaluate(row); + @Override + public Object evaluate(Row row) { + Boolean condition = (Boolean) inputs.get(0).evaluate(row); + if (condition != null && condition) { + return inputs.get(1).evaluate(row); } + return inputs.get(2).evaluate(row); + } - @Override - public String showExpression() { - return "if(" + inputs.get(0).showExpression() + ", " + inputs.get(1).showExpression() - + ", " + inputs.get(2).showExpression() + ")"; - } + @Override + public String showExpression() { + return "if(" + + inputs.get(0).showExpression() + + ", " + + inputs.get(1).showExpression() + + ", " + + inputs.get(2).showExpression() + + ")"; + } - @Override - public Expression copy(List inputs) { - assert inputs.size() == 3; - return new IfExpression(inputs.get(0), inputs.get(1), inputs.get(2), getOutputType()); - } + @Override + public Expression copy(List inputs) { + assert inputs.size() == 3; + return new IfExpression(inputs.get(0), inputs.get(1), inputs.get(2), getOutputType()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/construct/EdgeConstructExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/construct/EdgeConstructExpression.java index 95993ec5b..1b9cf07b4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/construct/EdgeConstructExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/construct/EdgeConstructExpression.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.expression.construct; import java.util.List; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -33,67 +34,67 @@ public class EdgeConstructExpression extends AbstractNonLeafExpression { - private final Expression srcIdExpression; + private final Expression srcIdExpression; - private final Expression targetIdExpression; + private final Expression targetIdExpression; - private final Expression labelExpression; + private final Expression labelExpression; - private final Expression tsExpression; + private final Expression tsExpression; - private final EdgeType edgeType; + private final EdgeType edgeType; - public EdgeConstructExpression(List inputs, IType outputType) { - super(inputs, outputType); - this.srcIdExpression = inputs.get(EdgeType.SRC_ID_FIELD_POSITION); - this.targetIdExpression = inputs.get(EdgeType.TARGET_ID_FIELD_POSITION); - this.labelExpression = inputs.get(EdgeType.LABEL_FIELD_POSITION); - this.edgeType = (EdgeType) outputType; - if (this.edgeType.getTimestamp().isPresent()) { - tsExpression = inputs.get(EdgeType.TIME_FIELD_POSITION); - } else { - tsExpression = null; - } + public EdgeConstructExpression(List inputs, IType outputType) { + super(inputs, outputType); + this.srcIdExpression = inputs.get(EdgeType.SRC_ID_FIELD_POSITION); + this.targetIdExpression = inputs.get(EdgeType.TARGET_ID_FIELD_POSITION); + this.labelExpression = inputs.get(EdgeType.LABEL_FIELD_POSITION); + this.edgeType = (EdgeType) outputType; + if (this.edgeType.getTimestamp().isPresent()) { + tsExpression = inputs.get(EdgeType.TIME_FIELD_POSITION); + } else { + tsExpression = null; } + } - @Override - public Object evaluate(Row row) { - Object srcId = srcIdExpression.evaluate(row); - Object targetId = targetIdExpression.evaluate(row); - - Object[] values = new Object[edgeType.getValueSize()]; - for (int i = edgeType.getValueOffset(); i < edgeType.size(); i++) { - values[i - edgeType.getValueOffset()] = inputs.get(i).evaluate(row); - } - RowEdge edge = VertexEdgeFactory.createEdge((EdgeType) outputType); - BinaryString label = (BinaryString) labelExpression.evaluate(row); - edge.setSrcId(srcId); - edge.setTargetId(targetId); - edge.setBinaryLabel(label); - edge.setValue(ObjectRow.create(values)); + @Override + public Object evaluate(Row row) { + Object srcId = srcIdExpression.evaluate(row); + Object targetId = targetIdExpression.evaluate(row); - if (tsExpression != null) { - Long ts = (Long) tsExpression.evaluate(row); - assert ts != null; - ((IGraphElementWithTimeField) edge).setTime(ts); - } - return edge; + Object[] values = new Object[edgeType.getValueSize()]; + for (int i = edgeType.getValueOffset(); i < edgeType.size(); i++) { + values[i - edgeType.getValueOffset()] = inputs.get(i).evaluate(row); } + RowEdge edge = VertexEdgeFactory.createEdge((EdgeType) outputType); + BinaryString label = (BinaryString) labelExpression.evaluate(row); + edge.setSrcId(srcId); + edge.setTargetId(targetId); + edge.setBinaryLabel(label); + edge.setValue(ObjectRow.create(values)); - @Override - public String showExpression() { - StringBuilder str = new StringBuilder(); - for (int i = 0; i < edgeType.size(); i++) { - if (i > 0) { - str.append(","); - } - str.append(edgeType.getField(i).getName()).append(":").append(inputs.get(i).showExpression()); - } - return "Edge{" + str + "}"; + if (tsExpression != null) { + Long ts = (Long) tsExpression.evaluate(row); + assert ts != null; + ((IGraphElementWithTimeField) edge).setTime(ts); } + return edge; + } - @Override - public Expression copy(List inputs) { - return new EdgeConstructExpression(inputs, getOutputType()); + @Override + public String showExpression() { + StringBuilder str = new StringBuilder(); + for (int i = 0; i < edgeType.size(); i++) { + if (i > 0) { + str.append(","); + } + str.append(edgeType.getField(i).getName()).append(":").append(inputs.get(i).showExpression()); } + return "Edge{" + str + "}"; + } + + @Override + public Expression copy(List inputs) { + return new EdgeConstructExpression(inputs, getOutputType()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/construct/VertexConstructExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/construct/VertexConstructExpression.java index 6c8bf04fd..403f80542 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/construct/VertexConstructExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/construct/VertexConstructExpression.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.expression.construct; import java.util.List; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -33,56 +34,58 @@ public class VertexConstructExpression extends AbstractNonLeafExpression { - private final Expression idExpression; + private final Expression idExpression; - private final Expression labelExpression; + private final Expression labelExpression; - private final List globalVariables; + private final List globalVariables; - private final VertexType vertexType; + private final VertexType vertexType; - public VertexConstructExpression(List inputs, List globalVariables, - IType outputType) { - super(inputs, outputType); - this.idExpression = inputs.get(VertexType.ID_FIELD_POSITION); - this.labelExpression = inputs.get(VertexType.LABEL_FIELD_POSITION); - this.globalVariables = globalVariables; - this.vertexType = (VertexType) outputType; - } + public VertexConstructExpression( + List inputs, List globalVariables, IType outputType) { + super(inputs, outputType); + this.idExpression = inputs.get(VertexType.ID_FIELD_POSITION); + this.labelExpression = inputs.get(VertexType.LABEL_FIELD_POSITION); + this.globalVariables = globalVariables; + this.vertexType = (VertexType) outputType; + } - @Override - public Object evaluate(Row row) { - Object id = idExpression.evaluate(row); - Object label = labelExpression.evaluate(row); - Object[] values = new Object[vertexType.getValueSize()]; - for (int i = vertexType.getValueOffset(); i < vertexType.size(); i++) { - values[i - vertexType.getValueOffset()] = inputs.get(i).evaluate(row); - } - RowVertex vertex = VertexEdgeFactory.createVertex((VertexType) outputType); - vertex.setId(id); - vertex.setBinaryLabel((BinaryString) label); - vertex.setValue(ObjectRow.create(values)); - return vertex; + @Override + public Object evaluate(Row row) { + Object id = idExpression.evaluate(row); + Object label = labelExpression.evaluate(row); + Object[] values = new Object[vertexType.getValueSize()]; + for (int i = vertexType.getValueOffset(); i < vertexType.size(); i++) { + values[i - vertexType.getValueOffset()] = inputs.get(i).evaluate(row); } + RowVertex vertex = VertexEdgeFactory.createVertex((VertexType) outputType); + vertex.setId(id); + vertex.setBinaryLabel((BinaryString) label); + vertex.setValue(ObjectRow.create(values)); + return vertex; + } - @Override - public String showExpression() { - StringBuilder str = new StringBuilder(); - for (int i = 0; i < vertexType.size(); i++) { - if (i > 0) { - str.append(","); - } - str.append(vertexType.getField(i).getName()).append(":").append(inputs.get(i).showExpression()); - } - return "Vertex{" + str + "}"; + @Override + public String showExpression() { + StringBuilder str = new StringBuilder(); + for (int i = 0; i < vertexType.size(); i++) { + if (i > 0) { + str.append(","); + } + str.append(vertexType.getField(i).getName()) + .append(":") + .append(inputs.get(i).showExpression()); } + return "Vertex{" + str + "}"; + } - @Override - public Expression copy(List inputs) { - return new VertexConstructExpression(inputs, globalVariables, getOutputType()); - } + @Override + public Expression copy(List inputs) { + return new VertexConstructExpression(inputs, globalVariables, getOutputType()); + } - public List getGlobalVariables() { - return globalVariables; - } + public List getGlobalVariables() { + return globalVariables; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/field/FieldExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/field/FieldExpression.java index 669e14275..8a177378c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/field/FieldExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/field/FieldExpression.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.runtime.expression.AbstractExpression; @@ -28,70 +29,70 @@ public class FieldExpression extends AbstractExpression { - private final Expression input; + private final Expression input; - protected final int fieldIndex; + protected final int fieldIndex; - protected final IType outputType; + protected final IType outputType; - public FieldExpression(Expression input, int fieldIndex, IType outputType) { - this.input = input; - this.fieldIndex = fieldIndex; - this.outputType = outputType; - } + public FieldExpression(Expression input, int fieldIndex, IType outputType) { + this.input = input; + this.fieldIndex = fieldIndex; + this.outputType = outputType; + } - public FieldExpression(int fieldIndex, IType outputType) { - this(null, fieldIndex, outputType); - } + public FieldExpression(int fieldIndex, IType outputType) { + this(null, fieldIndex, outputType); + } - @Override - public Object evaluate(Row row) { - Row inputR = row; - if (input != null) { - inputR = (Row) input.evaluate(row); - } - if (inputR == null) { - return null; - } - return inputR.getField(fieldIndex, outputType); + @Override + public Object evaluate(Row row) { + Row inputR = row; + if (input != null) { + inputR = (Row) input.evaluate(row); } - - @Override - public String showExpression() { - if (input != null) { - return input.showExpression() + ".$" + fieldIndex; - } - return "$" + fieldIndex; + if (inputR == null) { + return null; } + return inputR.getField(fieldIndex, outputType); + } - @Override - public IType getOutputType() { - return outputType; + @Override + public String showExpression() { + if (input != null) { + return input.showExpression() + ".$" + fieldIndex; } + return "$" + fieldIndex; + } - @Override - public List getInputs() { - if (input != null) { - return Collections.singletonList(input); - } - return Collections.emptyList(); - } + @Override + public IType getOutputType() { + return outputType; + } - @Override - public Expression copy(List inputs) { - if (input == null) { - assert inputs.isEmpty(); - return new FieldExpression(fieldIndex, outputType); - } - assert inputs.size() == 1; - return new FieldExpression(inputs.get(0), fieldIndex, outputType); + @Override + public List getInputs() { + if (input != null) { + return Collections.singletonList(input); } + return Collections.emptyList(); + } - public int getFieldIndex() { - return fieldIndex; + @Override + public Expression copy(List inputs) { + if (input == null) { + assert inputs.isEmpty(); + return new FieldExpression(fieldIndex, outputType); } + assert inputs.size() == 1; + return new FieldExpression(inputs.get(0), fieldIndex, outputType); + } - public FieldExpression copy(int newIndex) { - return new FieldExpression(input, newIndex, outputType); - } + public int getFieldIndex() { + return fieldIndex; + } + + public FieldExpression copy(int newIndex) { + return new FieldExpression(input, newIndex, outputType); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/field/ParameterFieldExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/field/ParameterFieldExpression.java index 5dd7f1282..e9c9215c1 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/field/ParameterFieldExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/field/ParameterFieldExpression.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.ParameterizedRow; import org.apache.geaflow.dsl.common.data.Row; @@ -29,43 +30,43 @@ public class ParameterFieldExpression extends AbstractExpression { - protected final int fieldIndex; + protected final int fieldIndex; - protected final IType outputType; + protected final IType outputType; - public ParameterFieldExpression(int fieldIndex, IType outputType) { - this.fieldIndex = fieldIndex; - this.outputType = outputType; - } + public ParameterFieldExpression(int fieldIndex, IType outputType) { + this.fieldIndex = fieldIndex; + this.outputType = outputType; + } - @Override - public Object evaluate(Row row) { - ParameterizedRow parameterizedRow = (ParameterizedRow) row; - return parameterizedRow.getParameter().getField(fieldIndex, outputType); - } + @Override + public Object evaluate(Row row) { + ParameterizedRow parameterizedRow = (ParameterizedRow) row; + return parameterizedRow.getParameter().getField(fieldIndex, outputType); + } - @Override - public String showExpression() { - return "$$" + fieldIndex; - } + @Override + public String showExpression() { + return "$$" + fieldIndex; + } - @Override - public IType getOutputType() { - return outputType; - } + @Override + public IType getOutputType() { + return outputType; + } - @Override - public List getInputs() { - return Collections.emptyList(); - } + @Override + public List getInputs() { + return Collections.emptyList(); + } - @Override - public Expression copy(List inputs) { - assert inputs.size() == 0; - return new ParameterFieldExpression(fieldIndex, outputType); - } + @Override + public Expression copy(List inputs) { + assert inputs.size() == 0; + return new ParameterFieldExpression(fieldIndex, outputType); + } - public int getFieldIndex() { - return fieldIndex; - } + public int getFieldIndex() { + return fieldIndex; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/field/PathFieldExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/field/PathFieldExpression.java index 39bddc9d0..bb7597f0d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/field/PathFieldExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/field/PathFieldExpression.java @@ -20,37 +20,38 @@ package org.apache.geaflow.dsl.runtime.expression.field; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.runtime.expression.Expression; public class PathFieldExpression extends FieldExpression { - private final String label; - - public PathFieldExpression(String label, int fieldIndex, IType outputType) { - super(fieldIndex, outputType); - this.label = label; - } - - @Override - public Object evaluate(Row row) { - return row.getField(fieldIndex, outputType); - } - - @Override - public String showExpression() { - return label; - } - - @Override - public Expression copy(List inputs) { - assert inputs.size() == 0; - return new PathFieldExpression(label, fieldIndex, outputType); - } - - @Override - public PathFieldExpression copy(int newIndex) { - return new PathFieldExpression(label, newIndex, outputType); - } + private final String label; + + public PathFieldExpression(String label, int fieldIndex, IType outputType) { + super(fieldIndex, outputType); + this.label = label; + } + + @Override + public Object evaluate(Row row) { + return row.getField(fieldIndex, outputType); + } + + @Override + public String showExpression() { + return label; + } + + @Override + public Expression copy(List inputs) { + assert inputs.size() == 0; + return new PathFieldExpression(label, fieldIndex, outputType); + } + + @Override + public PathFieldExpression copy(int newIndex) { + return new PathFieldExpression(label, newIndex, outputType); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/field/SystemVariableExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/field/SystemVariableExpression.java index 45ddd6fee..6d6f98a62 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/field/SystemVariableExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/field/SystemVariableExpression.java @@ -22,6 +22,7 @@ import java.util.Collections; import java.util.List; import java.util.Objects; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.ParameterizedRow; import org.apache.geaflow.dsl.common.data.Row; @@ -32,36 +33,36 @@ public class SystemVariableExpression extends AbstractExpression { - private final SystemVariable variable; + private final SystemVariable variable; - public SystemVariableExpression(SystemVariable variable) { - this.variable = Objects.requireNonNull(variable); - } + public SystemVariableExpression(SystemVariable variable) { + this.variable = Objects.requireNonNull(variable); + } - @Override - public Object evaluate(Row row) { - ParameterizedRow parameterizedRow = (ParameterizedRow) row; - return parameterizedRow.getSystemVariables().getField(variable.getIndex(), getOutputType()); - } + @Override + public Object evaluate(Row row) { + ParameterizedRow parameterizedRow = (ParameterizedRow) row; + return parameterizedRow.getSystemVariables().getField(variable.getIndex(), getOutputType()); + } - @Override - public String showExpression() { - return variable.getName(); - } + @Override + public String showExpression() { + return variable.getName(); + } - @Override - public IType getOutputType() { - return SqlTypeUtil.ofTypeName(variable.getTypeName()); - } + @Override + public IType getOutputType() { + return SqlTypeUtil.ofTypeName(variable.getTypeName()); + } - @Override - public List getInputs() { - return Collections.emptyList(); - } + @Override + public List getInputs() { + return Collections.emptyList(); + } - @Override - public Expression copy(List inputs) { - assert inputs.size() == 0; - return new SystemVariableExpression(variable); - } + @Override + public Expression copy(List inputs) { + assert inputs.size() == 0; + return new SystemVariableExpression(variable); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/item/ItemExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/item/ItemExpression.java index 09265ac93..e8088dfd6 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/item/ItemExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/item/ItemExpression.java @@ -19,40 +19,43 @@ package org.apache.geaflow.dsl.runtime.expression.item; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.types.ArrayType; import org.apache.geaflow.dsl.runtime.expression.AbstractNonLeafExpression; import org.apache.geaflow.dsl.runtime.expression.Expression; +import com.google.common.collect.Lists; + public class ItemExpression extends AbstractNonLeafExpression { - public ItemExpression(Expression target, Expression index) { - super(Lists.newArrayList(target, index), ((ArrayType) target.getOutputType()).getComponentType()); - } + public ItemExpression(Expression target, Expression index) { + super( + Lists.newArrayList(target, index), ((ArrayType) target.getOutputType()).getComponentType()); + } - @Override - public Object evaluate(Row row) { - Object target = inputs.get(0).evaluate(row); - if (target == null) { - return null; - } - if (target instanceof Object[]) { - int index = (int) inputs.get(1).evaluate(row); - return ((Object[]) target)[index]; - } - throw new IllegalArgumentException("target is not a array object"); + @Override + public Object evaluate(Row row) { + Object target = inputs.get(0).evaluate(row); + if (target == null) { + return null; } - - @Override - public String showExpression() { - return inputs.get(0).showExpression() + "[" + inputs.get(1).showExpression() + "]"; + if (target instanceof Object[]) { + int index = (int) inputs.get(1).evaluate(row); + return ((Object[]) target)[index]; } + throw new IllegalArgumentException("target is not a array object"); + } - @Override - public Expression copy(List inputs) { - assert inputs.size() == 2; - return new ItemExpression(inputs.get(0), inputs.get(1)); - } + @Override + public String showExpression() { + return inputs.get(0).showExpression() + "[" + inputs.get(1).showExpression() + "]"; + } + + @Override + public Expression copy(List inputs) { + assert inputs.size() == 2; + return new ItemExpression(inputs.get(0), inputs.get(1)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/literal/LiteralExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/literal/LiteralExpression.java index 6ff4abc4e..c0d1deb83 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/literal/LiteralExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/literal/LiteralExpression.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.data.Row; @@ -31,51 +32,51 @@ public class LiteralExpression extends AbstractExpression { - private final Object value; - private final String literalString; - private final IType outputType; + private final Object value; + private final String literalString; + private final IType outputType; - public LiteralExpression(Object value, IType outputType) { - this.value = BinaryUtil.toBinaryForString(value); - this.outputType = outputType; - if (outputType == Types.STRING) { - this.literalString = StringLiteralUtil.escapeSQLString(String.valueOf(value)); - } else { - literalString = String.valueOf(value); - } + public LiteralExpression(Object value, IType outputType) { + this.value = BinaryUtil.toBinaryForString(value); + this.outputType = outputType; + if (outputType == Types.STRING) { + this.literalString = StringLiteralUtil.escapeSQLString(String.valueOf(value)); + } else { + literalString = String.valueOf(value); } + } - @Override - public Object evaluate(Row row) { - return value; - } + @Override + public Object evaluate(Row row) { + return value; + } - @Override - public String showExpression() { - return literalString; - } + @Override + public String showExpression() { + return literalString; + } - @Override - public IType getOutputType() { - return outputType; - } + @Override + public IType getOutputType() { + return outputType; + } - @Override - public List getInputs() { - return Collections.emptyList(); - } + @Override + public List getInputs() { + return Collections.emptyList(); + } - @Override - public Expression copy(List inputs) { - assert inputs.size() == 0; - return new LiteralExpression(value, outputType); - } + @Override + public Expression copy(List inputs) { + assert inputs.size() == 0; + return new LiteralExpression(value, outputType); + } - public Object getValue() { - return value; - } + public Object getValue() { + return value; + } - public static LiteralExpression createBoolean(boolean value) { - return new LiteralExpression(value, Types.BOOLEAN); - } + public static LiteralExpression createBoolean(boolean value) { + return new LiteralExpression(value, Types.BOOLEAN); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/literal/PIExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/literal/PIExpression.java index 3f4c72bac..3f395d993 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/literal/PIExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/literal/PIExpression.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.data.Row; @@ -29,29 +30,29 @@ public class PIExpression extends AbstractExpression { - @Override - public Object evaluate(Row input) { - return Math.PI; - } - - @Override - public IType getOutputType() { - return Types.DOUBLE; - } - - @Override - public String showExpression() { - return "PI"; - } - - @Override - public List getInputs() { - return Collections.emptyList(); - } - - @Override - public Expression copy(List inputs) { - assert inputs.size() == 0; - return this; - } + @Override + public Object evaluate(Row input) { + return Math.PI; + } + + @Override + public IType getOutputType() { + return Types.DOUBLE; + } + + @Override + public String showExpression() { + return "PI"; + } + + @Override + public List getInputs() { + return Collections.emptyList(); + } + + @Override + public Expression copy(List inputs) { + assert inputs.size() == 0; + return this; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/AndExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/AndExpression.java index f6635a9d3..cf1e35b33 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/AndExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/AndExpression.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.runtime.expression.AbstractNonLeafExpression; @@ -28,38 +29,39 @@ public class AndExpression extends AbstractNonLeafExpression { - public AndExpression(List inputs) { - super(inputs, Types.BOOLEAN); - } + public AndExpression(List inputs) { + super(inputs, Types.BOOLEAN); + } - @Override - public Object evaluate(Row row) { - boolean hasNull = false; + @Override + public Object evaluate(Row row) { + boolean hasNull = false; - for (Expression input : inputs) { - Boolean b = (Boolean) input.evaluate(row); - if (b != null && !b) { - return false; - } - if (b == null) { - hasNull = true; - } - } - if (hasNull) { - return null; - } else { - return true; - } + for (Expression input : inputs) { + Boolean b = (Boolean) input.evaluate(row); + if (b != null && !b) { + return false; + } + if (b == null) { + hasNull = true; + } } - - @Override - public String showExpression() { - return "And(" + inputs.stream().map(Expression::showExpression) - .collect(Collectors.joining(",")) + ")"; + if (hasNull) { + return null; + } else { + return true; } + } - @Override - public Expression copy(List inputs) { - return new AndExpression(inputs); - } + @Override + public String showExpression() { + return "And(" + + inputs.stream().map(Expression::showExpression).collect(Collectors.joining(",")) + + ")"; + } + + @Override + public Expression copy(List inputs) { + return new AndExpression(inputs); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/EqualExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/EqualExpression.java index 36ee64048..fa262acd0 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/EqualExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/EqualExpression.java @@ -19,29 +19,32 @@ package org.apache.geaflow.dsl.runtime.expression.logic; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.runtime.expression.AbstractReflectCallExpression; import org.apache.geaflow.dsl.runtime.expression.Expression; import org.apache.geaflow.dsl.schema.function.GeaFlowBuiltinFunctions; +import com.google.common.collect.Lists; + public class EqualExpression extends AbstractReflectCallExpression { - public static final String METHOD_NAME = "equal"; + public static final String METHOD_NAME = "equal"; - public EqualExpression(Expression left, Expression right) { - super(Lists.newArrayList(left, right), Types.BOOLEAN, GeaFlowBuiltinFunctions.class, METHOD_NAME); - } + public EqualExpression(Expression left, Expression right) { + super( + Lists.newArrayList(left, right), Types.BOOLEAN, GeaFlowBuiltinFunctions.class, METHOD_NAME); + } - @Override - public String showExpression() { - return inputs.get(0).showExpression() + " = " + inputs.get(1).showExpression(); - } + @Override + public String showExpression() { + return inputs.get(0).showExpression() + " = " + inputs.get(1).showExpression(); + } - @Override - public Expression copy(List inputs) { - assert inputs.size() == 2; - return new EqualExpression(inputs.get(0), inputs.get(1)); - } + @Override + public Expression copy(List inputs) { + assert inputs.size() == 2; + return new EqualExpression(inputs.get(0), inputs.get(1)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/GTEExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/GTEExpression.java index fc98b6b28..2d2cc354e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/GTEExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/GTEExpression.java @@ -19,29 +19,32 @@ package org.apache.geaflow.dsl.runtime.expression.logic; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.runtime.expression.AbstractReflectCallExpression; import org.apache.geaflow.dsl.runtime.expression.Expression; import org.apache.geaflow.dsl.schema.function.GeaFlowBuiltinFunctions; +import com.google.common.collect.Lists; + public class GTEExpression extends AbstractReflectCallExpression { - public static final String METHOD_NAME = "greaterThanEq"; + public static final String METHOD_NAME = "greaterThanEq"; - public GTEExpression(Expression left, Expression right) { - super(Lists.newArrayList(left, right), Types.BOOLEAN, GeaFlowBuiltinFunctions.class, METHOD_NAME); - } + public GTEExpression(Expression left, Expression right) { + super( + Lists.newArrayList(left, right), Types.BOOLEAN, GeaFlowBuiltinFunctions.class, METHOD_NAME); + } - @Override - public String showExpression() { - return inputs.get(0).showExpression() + " >= " + inputs.get(1).showExpression(); - } + @Override + public String showExpression() { + return inputs.get(0).showExpression() + " >= " + inputs.get(1).showExpression(); + } - @Override - public Expression copy(List inputs) { - assert inputs.size() == 2; - return new GTEExpression(inputs.get(0), inputs.get(1)); - } + @Override + public Expression copy(List inputs) { + assert inputs.size() == 2; + return new GTEExpression(inputs.get(0), inputs.get(1)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/GTExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/GTExpression.java index 54a8e9a83..0622a9b0f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/GTExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/GTExpression.java @@ -19,29 +19,32 @@ package org.apache.geaflow.dsl.runtime.expression.logic; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.runtime.expression.AbstractReflectCallExpression; import org.apache.geaflow.dsl.runtime.expression.Expression; import org.apache.geaflow.dsl.schema.function.GeaFlowBuiltinFunctions; +import com.google.common.collect.Lists; + public class GTExpression extends AbstractReflectCallExpression { - public static final String METHOD_NAME = "greaterThan"; + public static final String METHOD_NAME = "greaterThan"; - public GTExpression(Expression left, Expression right) { - super(Lists.newArrayList(left, right), Types.BOOLEAN, GeaFlowBuiltinFunctions.class, METHOD_NAME); - } + public GTExpression(Expression left, Expression right) { + super( + Lists.newArrayList(left, right), Types.BOOLEAN, GeaFlowBuiltinFunctions.class, METHOD_NAME); + } - @Override - public String showExpression() { - return inputs.get(0).showExpression() + " > " + inputs.get(1).showExpression(); - } + @Override + public String showExpression() { + return inputs.get(0).showExpression() + " > " + inputs.get(1).showExpression(); + } - @Override - public Expression copy(List inputs) { - assert inputs.size() == 2; - return new GTExpression(inputs.get(0), inputs.get(1)); - } + @Override + public Expression copy(List inputs) { + assert inputs.size() == 2; + return new GTExpression(inputs.get(0), inputs.get(1)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsFalseExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsFalseExpression.java index 658e61f1c..18a60f2d4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsFalseExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsFalseExpression.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.runtime.expression.AbstractNonLeafExpression; @@ -28,24 +29,24 @@ public class IsFalseExpression extends AbstractNonLeafExpression { - public IsFalseExpression(Expression input) { - super(Collections.singletonList(input), Types.BOOLEAN); - } - - @Override - public Object evaluate(Row row) { - Boolean condition = (Boolean) inputs.get(0).evaluate(row); - return condition != null && !condition; - } - - @Override - public String showExpression() { - return inputs.get(0).showExpression() + " is false"; - } - - @Override - public Expression copy(List inputs) { - assert inputs.size() == 1; - return new IsFalseExpression(inputs.get(0)); - } + public IsFalseExpression(Expression input) { + super(Collections.singletonList(input), Types.BOOLEAN); + } + + @Override + public Object evaluate(Row row) { + Boolean condition = (Boolean) inputs.get(0).evaluate(row); + return condition != null && !condition; + } + + @Override + public String showExpression() { + return inputs.get(0).showExpression() + " is false"; + } + + @Override + public Expression copy(List inputs) { + assert inputs.size() == 1; + return new IsFalseExpression(inputs.get(0)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsNotFalseExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsNotFalseExpression.java index 6024eab36..29aed819c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsNotFalseExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsNotFalseExpression.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.runtime.expression.AbstractNonLeafExpression; @@ -28,24 +29,24 @@ public class IsNotFalseExpression extends AbstractNonLeafExpression { - public IsNotFalseExpression(Expression input) { - super(Collections.singletonList(input), Types.BOOLEAN); - } - - @Override - public Object evaluate(Row row) { - Boolean condition = (Boolean) inputs.get(0).evaluate(row); - return condition == null || condition; - } - - @Override - public String showExpression() { - return inputs.get(0).showExpression() + " is not false"; - } - - @Override - public Expression copy(List inputs) { - assert inputs.size() == 1; - return new IsNotFalseExpression(inputs.get(0)); - } + public IsNotFalseExpression(Expression input) { + super(Collections.singletonList(input), Types.BOOLEAN); + } + + @Override + public Object evaluate(Row row) { + Boolean condition = (Boolean) inputs.get(0).evaluate(row); + return condition == null || condition; + } + + @Override + public String showExpression() { + return inputs.get(0).showExpression() + " is not false"; + } + + @Override + public Expression copy(List inputs) { + assert inputs.size() == 1; + return new IsNotFalseExpression(inputs.get(0)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsNotNullExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsNotNullExpression.java index 5e9b42ef9..c99eebcad 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsNotNullExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsNotNullExpression.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.runtime.expression.AbstractNonLeafExpression; @@ -28,23 +29,23 @@ public class IsNotNullExpression extends AbstractNonLeafExpression { - public IsNotNullExpression(Expression input) { - super(Collections.singletonList(input), Types.BOOLEAN); - } - - @Override - public Object evaluate(Row row) { - return inputs.get(0).evaluate(row) != null; - } - - @Override - public String showExpression() { - return inputs.get(0).showExpression() + " is not null"; - } - - @Override - public Expression copy(List inputs) { - assert inputs.size() == 1; - return new IsNotNullExpression(inputs.get(0)); - } + public IsNotNullExpression(Expression input) { + super(Collections.singletonList(input), Types.BOOLEAN); + } + + @Override + public Object evaluate(Row row) { + return inputs.get(0).evaluate(row) != null; + } + + @Override + public String showExpression() { + return inputs.get(0).showExpression() + " is not null"; + } + + @Override + public Expression copy(List inputs) { + assert inputs.size() == 1; + return new IsNotNullExpression(inputs.get(0)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsNotTrueExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsNotTrueExpression.java index 1bc2bf91b..390817bb0 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsNotTrueExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsNotTrueExpression.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.runtime.expression.AbstractNonLeafExpression; @@ -28,24 +29,24 @@ public class IsNotTrueExpression extends AbstractNonLeafExpression { - public IsNotTrueExpression(Expression input) { - super(Collections.singletonList(input), Types.BOOLEAN); - } - - @Override - public Object evaluate(Row row) { - Boolean condition = (Boolean) inputs.get(0).evaluate(row); - return condition == null || !condition; - } - - @Override - public String showExpression() { - return inputs.get(0).showExpression() + " is not true"; - } - - @Override - public Expression copy(List inputs) { - assert inputs.size() == 1; - return new IsNotTrueExpression(inputs.get(0)); - } + public IsNotTrueExpression(Expression input) { + super(Collections.singletonList(input), Types.BOOLEAN); + } + + @Override + public Object evaluate(Row row) { + Boolean condition = (Boolean) inputs.get(0).evaluate(row); + return condition == null || !condition; + } + + @Override + public String showExpression() { + return inputs.get(0).showExpression() + " is not true"; + } + + @Override + public Expression copy(List inputs) { + assert inputs.size() == 1; + return new IsNotTrueExpression(inputs.get(0)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsNullExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsNullExpression.java index a8fa0ee62..23b808907 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsNullExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsNullExpression.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.runtime.expression.AbstractNonLeafExpression; @@ -28,23 +29,23 @@ public class IsNullExpression extends AbstractNonLeafExpression { - public IsNullExpression(Expression input) { - super(Collections.singletonList(input), Types.BOOLEAN); - } - - @Override - public Object evaluate(Row row) { - return inputs.get(0).evaluate(row) == null; - } - - @Override - public String showExpression() { - return inputs.get(0).showExpression() + " is null"; - } - - @Override - public Expression copy(List inputs) { - assert inputs.size() == 1; - return new IsNullExpression(inputs.get(0)); - } + public IsNullExpression(Expression input) { + super(Collections.singletonList(input), Types.BOOLEAN); + } + + @Override + public Object evaluate(Row row) { + return inputs.get(0).evaluate(row) == null; + } + + @Override + public String showExpression() { + return inputs.get(0).showExpression() + " is null"; + } + + @Override + public Expression copy(List inputs) { + assert inputs.size() == 1; + return new IsNullExpression(inputs.get(0)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsTrueExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsTrueExpression.java index 0adf8b5fd..d832ac48e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsTrueExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/IsTrueExpression.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.runtime.expression.AbstractNonLeafExpression; @@ -28,24 +29,24 @@ public class IsTrueExpression extends AbstractNonLeafExpression { - public IsTrueExpression(Expression input) { - super(Collections.singletonList(input), Types.BOOLEAN); - } - - @Override - public Object evaluate(Row row) { - Boolean condition = (Boolean) inputs.get(0).evaluate(row); - return condition != null && condition; - } - - @Override - public String showExpression() { - return inputs.get(0).showExpression() + " is true"; - } - - @Override - public Expression copy(List inputs) { - assert inputs.size() == 1; - return new IsTrueExpression(inputs.get(0)); - } + public IsTrueExpression(Expression input) { + super(Collections.singletonList(input), Types.BOOLEAN); + } + + @Override + public Object evaluate(Row row) { + Boolean condition = (Boolean) inputs.get(0).evaluate(row); + return condition != null && condition; + } + + @Override + public String showExpression() { + return inputs.get(0).showExpression() + " is true"; + } + + @Override + public Expression copy(List inputs) { + assert inputs.size() == 1; + return new IsTrueExpression(inputs.get(0)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/LTEExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/LTEExpression.java index f8d34d082..a31f2744e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/LTEExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/LTEExpression.java @@ -19,29 +19,32 @@ package org.apache.geaflow.dsl.runtime.expression.logic; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.runtime.expression.AbstractReflectCallExpression; import org.apache.geaflow.dsl.runtime.expression.Expression; import org.apache.geaflow.dsl.schema.function.GeaFlowBuiltinFunctions; +import com.google.common.collect.Lists; + public class LTEExpression extends AbstractReflectCallExpression { - public static final String METHOD_NAME = "lessThanEq"; + public static final String METHOD_NAME = "lessThanEq"; - public LTEExpression(Expression left, Expression right) { - super(Lists.newArrayList(left, right), Types.BOOLEAN, GeaFlowBuiltinFunctions.class, METHOD_NAME); - } + public LTEExpression(Expression left, Expression right) { + super( + Lists.newArrayList(left, right), Types.BOOLEAN, GeaFlowBuiltinFunctions.class, METHOD_NAME); + } - @Override - public String showExpression() { - return inputs.get(0).showExpression() + " <= " + inputs.get(1).showExpression(); - } + @Override + public String showExpression() { + return inputs.get(0).showExpression() + " <= " + inputs.get(1).showExpression(); + } - @Override - public Expression copy(List inputs) { - assert inputs.size() == 2; - return new LTEExpression(inputs.get(0), inputs.get(1)); - } + @Override + public Expression copy(List inputs) { + assert inputs.size() == 2; + return new LTEExpression(inputs.get(0), inputs.get(1)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/LTExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/LTExpression.java index 3870b96b9..74e9a7bec 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/LTExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/LTExpression.java @@ -19,29 +19,32 @@ package org.apache.geaflow.dsl.runtime.expression.logic; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.runtime.expression.AbstractReflectCallExpression; import org.apache.geaflow.dsl.runtime.expression.Expression; import org.apache.geaflow.dsl.schema.function.GeaFlowBuiltinFunctions; +import com.google.common.collect.Lists; + public class LTExpression extends AbstractReflectCallExpression { - public static final String METHOD_NAME = "lessThan"; + public static final String METHOD_NAME = "lessThan"; - public LTExpression(Expression left, Expression right) { - super(Lists.newArrayList(left, right), Types.BOOLEAN, GeaFlowBuiltinFunctions.class, METHOD_NAME); - } + public LTExpression(Expression left, Expression right) { + super( + Lists.newArrayList(left, right), Types.BOOLEAN, GeaFlowBuiltinFunctions.class, METHOD_NAME); + } - @Override - public String showExpression() { - return inputs.get(0).showExpression() + " < " + inputs.get(1).showExpression(); - } + @Override + public String showExpression() { + return inputs.get(0).showExpression() + " < " + inputs.get(1).showExpression(); + } - @Override - public Expression copy(List inputs) { - assert inputs.size() == 2; - return new LTExpression(inputs.get(0), inputs.get(1)); - } + @Override + public Expression copy(List inputs) { + assert inputs.size() == 2; + return new LTExpression(inputs.get(0), inputs.get(1)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/NotExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/NotExpression.java index 8bf056055..26ad98015 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/NotExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/NotExpression.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.runtime.expression.AbstractNonLeafExpression; @@ -28,27 +29,27 @@ public class NotExpression extends AbstractNonLeafExpression { - public NotExpression(Expression input) { - super(Collections.singletonList(input), Types.BOOLEAN); - } + public NotExpression(Expression input) { + super(Collections.singletonList(input), Types.BOOLEAN); + } - @Override - public Object evaluate(Row row) { - Boolean condition = (Boolean) inputs.get(0).evaluate(row); - if (condition == null) { - return null; - } - return !condition; + @Override + public Object evaluate(Row row) { + Boolean condition = (Boolean) inputs.get(0).evaluate(row); + if (condition == null) { + return null; } + return !condition; + } - @Override - public String showExpression() { - return "Not(" + inputs.get(0).showExpression() + ")"; - } + @Override + public String showExpression() { + return "Not(" + inputs.get(0).showExpression() + ")"; + } - @Override - public Expression copy(List inputs) { - assert inputs.size() == 1; - return new NotExpression(inputs.get(0)); - } + @Override + public Expression copy(List inputs) { + assert inputs.size() == 1; + return new NotExpression(inputs.get(0)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/OrExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/OrExpression.java index 18202d554..bcb6ad591 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/OrExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/logic/OrExpression.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.runtime.expression.AbstractNonLeafExpression; @@ -28,37 +29,38 @@ public class OrExpression extends AbstractNonLeafExpression { - public OrExpression(List inputs) { - super(inputs, Types.BOOLEAN); - } + public OrExpression(List inputs) { + super(inputs, Types.BOOLEAN); + } - @Override - public Object evaluate(Row row) { - boolean hasNull = false; - for (Expression input : inputs) { - Boolean b = (Boolean) input.evaluate(row); - if (b != null && b) { - return true; - } - if (b == null) { - hasNull = true; - } - } - if (hasNull) { - return null; - } else { - return false; - } + @Override + public Object evaluate(Row row) { + boolean hasNull = false; + for (Expression input : inputs) { + Boolean b = (Boolean) input.evaluate(row); + if (b != null && b) { + return true; + } + if (b == null) { + hasNull = true; + } } - - @Override - public String showExpression() { - return "Or(" + inputs.stream().map(Expression::showExpression) - .collect(Collectors.joining(",")) + ")"; + if (hasNull) { + return null; + } else { + return false; } + } - @Override - public Expression copy(List inputs) { - return new OrExpression(inputs); - } + @Override + public String showExpression() { + return "Or(" + + inputs.stream().map(Expression::showExpression).collect(Collectors.joining(",")) + + ")"; + } + + @Override + public Expression copy(List inputs) { + return new OrExpression(inputs); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/subquery/CallContext.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/subquery/CallContext.java index 6a79a15af..aeb3943c8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/subquery/CallContext.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/subquery/CallContext.java @@ -24,52 +24,54 @@ import java.util.List; import java.util.Map; import java.util.Objects; + import org.apache.geaflow.dsl.runtime.traversal.data.ParameterRequest; import org.apache.geaflow.dsl.runtime.traversal.path.ITreePath; public class CallContext { - // requestId -> (vertexId, treePath) - private final Map> paths; + // requestId -> (vertexId, treePath) + private final Map> paths; - // vertexId -> request list - private final Map> requests; + // vertexId -> request list + private final Map> requests; - public CallContext(Map> paths, Map> requests) { - this.paths = Objects.requireNonNull(paths); - this.requests = Objects.requireNonNull(requests); - } + public CallContext( + Map> paths, Map> requests) { + this.paths = Objects.requireNonNull(paths); + this.requests = Objects.requireNonNull(requests); + } - public CallContext() { - this(new HashMap<>(), new HashMap<>()); - } + public CallContext() { + this(new HashMap<>(), new HashMap<>()); + } - public void addPath(Object requestId, Object vertexId, ITreePath treePath) { - paths.computeIfAbsent(requestId, r -> new HashMap<>()).put(vertexId, treePath); - } + public void addPath(Object requestId, Object vertexId, ITreePath treePath) { + paths.computeIfAbsent(requestId, r -> new HashMap<>()).put(vertexId, treePath); + } - public ITreePath getPath(Object requestId, Object vertexId) { - if (paths.containsKey(requestId)) { - Map vertexTreePaths = paths.get(requestId); - if (vertexTreePaths != null) { - return vertexTreePaths.get(vertexId); - } - } - return null; + public ITreePath getPath(Object requestId, Object vertexId) { + if (paths.containsKey(requestId)) { + Map vertexTreePaths = paths.get(requestId); + if (vertexTreePaths != null) { + return vertexTreePaths.get(vertexId); + } } + return null; + } - public void addRequest(Object vertexId, ParameterRequest request) { - if (request != null) { - requests.computeIfAbsent(vertexId, v -> new ArrayList<>()).add(request); - } + public void addRequest(Object vertexId, ParameterRequest request) { + if (request != null) { + requests.computeIfAbsent(vertexId, v -> new ArrayList<>()).add(request); } + } - public List getRequests(Object vertexId) { - return requests.get(vertexId); - } + public List getRequests(Object vertexId) { + return requests.get(vertexId); + } - public void reset() { - paths.clear(); - requests.clear(); - } + public void reset() { + paths.clear(); + requests.clear(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/subquery/CallQueryExpression.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/subquery/CallQueryExpression.java index 1a5b66afa..2e2a1d649 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/subquery/CallQueryExpression.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/subquery/CallQueryExpression.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.utils.ArrayUtil; import org.apache.geaflow.dsl.common.data.Path; @@ -45,135 +46,137 @@ public class CallQueryExpression extends AbstractExpression implements ICallQuery { - private final String queryName; - - /** - * The operator id of the start operator for the sub query dag. - */ - private final long queryId; - - private final int startVertexIndex; - - private final VertexType startVertexType; - - private final IType returnType; - - private TraversalRuntimeContext context; - - private CallState callState; - - private final int[] refParentPathIndices; - - private final Object defaultAggValue; - - public CallQueryExpression(String queryName, long queryId, - int startVertexIndex, VertexType startVertexType, - IType returnType, int[] refParentPathIndices, - Object defaultAggValue) { - this.queryName = queryName; - this.queryId = queryId; - this.startVertexIndex = startVertexIndex; - this.startVertexType = startVertexType; - this.returnType = returnType; - this.refParentPathIndices = refParentPathIndices; - this.defaultAggValue = defaultAggValue; - } - - @Override - public void open(TraversalRuntimeContext context) { - this.context = context; - this.open(FunctionContext.of(context.getConfig())); - this.callState = CallState.INIT; - } - - @Override - public void setCallState(CallState callState) { - this.callState = callState; - } - - @Override - public CallState getCallState() { - return callState; - } - - @Override - public Object evaluate(Row row) { - RowVertex startVertex = (RowVertex) row.getField(startVertexIndex, startVertexType); - Path path = (Path) row; - if (callState == CallState.WAITING) { - ParameterRequestMessage requestMessage = new ParameterRequestMessage(); - Row parameter = context.getParameters(); - long uniquePathId = context.createUniqueId(path.getId()); - CallRequestId callRequestId = new CallRequestId(uniquePathId, context.getCurrentOpId(), - startVertex.getId()); - ParameterRequest request = new ParameterRequest(callRequestId, startVertex.getId(), parameter); - - requestMessage.addRequest(request); - // send request message to sub query plan's start query operator id. - context.sendMessage(startVertex.getId(), requestMessage, queryId); - Path pathMessage = ((Path) row).subPath(refParentPathIndices); - ITreePath treePath = TreePaths.createTreePath(Collections.singletonList(pathMessage)); - treePath.setRequestIdForTree(callRequestId); - context.sendMessage(startVertex.getId(), treePath, queryId); - return null; - } else if (callState == CallState.RETURNING) { - ReturnMessage returnMessage = context.getMessage(MessageType.RETURN_VALUE); - long uniquePathId = context.createUniqueId(path.getId()); - ReturnKey returnKey = new ReturnKey(uniquePathId, queryId); - SingleValue singleValue = returnMessage.getValue(returnKey); - if (singleValue == null) { - return defaultAggValue; - } - return singleValue.getValue(returnType); - } - throw new IllegalArgumentException("Illegal call state: " + callState + " for evaluate() method"); - } - - public void sendEod() { - EndOfData eod = EndOfData.of(context.getCurrentOpId(), context.getCurrentOpId()); - context.broadcast(EODMessage.of(eod), queryId); - } - - @Override - public void finishCall() { - - } - - public String getQueryName() { - return queryName; - } - - @Override - public String showExpression() { - return "Call(" + queryName + ")"; - } - - @Override - public IType getOutputType() { - return returnType; - } - - @Override - public List getInputs() { - return Collections.emptyList(); - } - - @Override - public Expression copy(List inputs) { - assert inputs.isEmpty(); - return this; - } - - @Override - public List getRefPathFieldIndices() { - return ArrayUtil.toList(refParentPathIndices); - } - - public enum CallState { - INIT, - CALLING, - WAITING, - RETURNING, - FINISH + private final String queryName; + + /** The operator id of the start operator for the sub query dag. */ + private final long queryId; + + private final int startVertexIndex; + + private final VertexType startVertexType; + + private final IType returnType; + + private TraversalRuntimeContext context; + + private CallState callState; + + private final int[] refParentPathIndices; + + private final Object defaultAggValue; + + public CallQueryExpression( + String queryName, + long queryId, + int startVertexIndex, + VertexType startVertexType, + IType returnType, + int[] refParentPathIndices, + Object defaultAggValue) { + this.queryName = queryName; + this.queryId = queryId; + this.startVertexIndex = startVertexIndex; + this.startVertexType = startVertexType; + this.returnType = returnType; + this.refParentPathIndices = refParentPathIndices; + this.defaultAggValue = defaultAggValue; + } + + @Override + public void open(TraversalRuntimeContext context) { + this.context = context; + this.open(FunctionContext.of(context.getConfig())); + this.callState = CallState.INIT; + } + + @Override + public void setCallState(CallState callState) { + this.callState = callState; + } + + @Override + public CallState getCallState() { + return callState; + } + + @Override + public Object evaluate(Row row) { + RowVertex startVertex = (RowVertex) row.getField(startVertexIndex, startVertexType); + Path path = (Path) row; + if (callState == CallState.WAITING) { + ParameterRequestMessage requestMessage = new ParameterRequestMessage(); + Row parameter = context.getParameters(); + long uniquePathId = context.createUniqueId(path.getId()); + CallRequestId callRequestId = + new CallRequestId(uniquePathId, context.getCurrentOpId(), startVertex.getId()); + ParameterRequest request = + new ParameterRequest(callRequestId, startVertex.getId(), parameter); + + requestMessage.addRequest(request); + // send request message to sub query plan's start query operator id. + context.sendMessage(startVertex.getId(), requestMessage, queryId); + Path pathMessage = ((Path) row).subPath(refParentPathIndices); + ITreePath treePath = TreePaths.createTreePath(Collections.singletonList(pathMessage)); + treePath.setRequestIdForTree(callRequestId); + context.sendMessage(startVertex.getId(), treePath, queryId); + return null; + } else if (callState == CallState.RETURNING) { + ReturnMessage returnMessage = context.getMessage(MessageType.RETURN_VALUE); + long uniquePathId = context.createUniqueId(path.getId()); + ReturnKey returnKey = new ReturnKey(uniquePathId, queryId); + SingleValue singleValue = returnMessage.getValue(returnKey); + if (singleValue == null) { + return defaultAggValue; + } + return singleValue.getValue(returnType); } + throw new IllegalArgumentException( + "Illegal call state: " + callState + " for evaluate() method"); + } + + public void sendEod() { + EndOfData eod = EndOfData.of(context.getCurrentOpId(), context.getCurrentOpId()); + context.broadcast(EODMessage.of(eod), queryId); + } + + @Override + public void finishCall() {} + + public String getQueryName() { + return queryName; + } + + @Override + public String showExpression() { + return "Call(" + queryName + ")"; + } + + @Override + public IType getOutputType() { + return returnType; + } + + @Override + public List getInputs() { + return Collections.emptyList(); + } + + @Override + public Expression copy(List inputs) { + assert inputs.isEmpty(); + return this; + } + + @Override + public List getRefPathFieldIndices() { + return ArrayUtil.toList(refParentPathIndices); + } + + public enum CallState { + INIT, + CALLING, + WAITING, + RETURNING, + FINISH + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/subquery/CallQueryProxy.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/subquery/CallQueryProxy.java index 28fe71423..fb3ede9bc 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/subquery/CallQueryProxy.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/subquery/CallQueryProxy.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.utils.ArrayUtil; import org.apache.geaflow.dsl.common.data.Path; @@ -38,196 +39,201 @@ public class CallQueryProxy extends AbstractExpression implements ICallQuery { - private static final Logger LOGGER = LoggerFactory.getLogger(CallQueryProxy.class); + private static final Logger LOGGER = LoggerFactory.getLogger(CallQueryProxy.class); - private final CallQueryExpression[] queryCalls; + private final CallQueryExpression[] queryCalls; - private final Expression rewriteExpression; + private final Expression rewriteExpression; - private final PlaceHolderExpression[] placeHolderExpressions; + private final PlaceHolderExpression[] placeHolderExpressions; - private CallState callState; + private CallState callState; - private Map>> stashPaths; + private Map>> stashPaths; - private TraversalRuntimeContext context; + private TraversalRuntimeContext context; - private CallQueryProxy(CallQueryExpression[] queryCalls, - Expression rewriteExpression) { - this.queryCalls = Objects.requireNonNull(queryCalls); - this.rewriteExpression = Objects.requireNonNull(rewriteExpression); - this.placeHolderExpressions = rewriteExpression.collect(exp -> exp instanceof PlaceHolderExpression) - .toArray(new PlaceHolderExpression[]{}); - } + private CallQueryProxy(CallQueryExpression[] queryCalls, Expression rewriteExpression) { + this.queryCalls = Objects.requireNonNull(queryCalls); + this.rewriteExpression = Objects.requireNonNull(rewriteExpression); + this.placeHolderExpressions = + rewriteExpression + .collect(exp -> exp instanceof PlaceHolderExpression) + .toArray(new PlaceHolderExpression[] {}); + } - public static Expression from(Expression expression) { - List calls = - ArrayUtil.castList(expression.collect(exp -> exp instanceof CallQueryExpression)); - if (calls.isEmpty()) { - return expression; - } - Map callIndexMap = new HashMap<>(); - for (int i = 0; i < calls.size(); i++) { - callIndexMap.put(calls.get(i), i); - } - CallQueryExpression[] callArrays = calls.toArray(new CallQueryExpression[]{}); - Expression rewriteExpression = expression.replace(exp -> { - if (exp instanceof CallQueryExpression) { + public static Expression from(Expression expression) { + List calls = + ArrayUtil.castList(expression.collect(exp -> exp instanceof CallQueryExpression)); + if (calls.isEmpty()) { + return expression; + } + Map callIndexMap = new HashMap<>(); + for (int i = 0; i < calls.size(); i++) { + callIndexMap.put(calls.get(i), i); + } + CallQueryExpression[] callArrays = calls.toArray(new CallQueryExpression[] {}); + Expression rewriteExpression = + expression.replace( + exp -> { + if (exp instanceof CallQueryExpression) { int index = callIndexMap.get(exp); return new PlaceHolderExpression(callArrays, index); - } - return exp; - }); - return new CallQueryProxy(callArrays, rewriteExpression); + } + return exp; + }); + return new CallQueryProxy(callArrays, rewriteExpression); + } + + @Override + public void open(TraversalRuntimeContext context) { + for (CallQueryExpression call : queryCalls) { + call.open(context); + } + this.context = context; + this.stashPaths = new HashMap<>(); + this.callState = CallState.INIT; + } + + @Override + public Object evaluate(Row row) { + if (callState == CallState.CALLING) { + // stash paths. + stashPaths + .computeIfAbsent(context.getRequest(), k -> new HashMap<>()) + .computeIfAbsent(context.getVertex().getId(), id -> new ArrayList<>()) + .add(((Path) row).copy()); + return null; } - @Override - public void open(TraversalRuntimeContext context) { - for (CallQueryExpression call : queryCalls) { - call.open(context); - } - this.context = context; - this.stashPaths = new HashMap<>(); - this.callState = CallState.INIT; + if (callState == CallState.RETURNING) { + Object[] callResults = new Object[queryCalls.length]; + for (int i = 0; i < queryCalls.length; i++) { + callResults[i] = queryCalls[i].evaluate(row); + } + for (PlaceHolderExpression placeHolderExpression : placeHolderExpressions) { + placeHolderExpression.setResults(callResults); + } + return rewriteExpression.evaluate(row); + } + throw new IllegalArgumentException( + "Illegal call state: " + callState + " for evaluate() method"); + } + + @Override + public String showExpression() { + return rewriteExpression.showExpression(); + } + + @Override + public IType getOutputType() { + return rewriteExpression.getOutputType(); + } + + @Override + public List getInputs() { + return rewriteExpression.getInputs(); + } + + @Override + public Expression copy(List inputs) { + return new CallQueryProxy(queryCalls, rewriteExpression.copy(inputs)); + } + + @Override + public void setCallState(CallState callState) { + this.callState = callState; + for (CallQueryExpression call : queryCalls) { + call.setCallState(callState); } + } - @Override - public Object evaluate(Row row) { - if (callState == CallState.CALLING) { - // stash paths. - stashPaths.computeIfAbsent(context.getRequest(), k -> new HashMap<>()) - .computeIfAbsent(context.getVertex().getId(), id -> new ArrayList<>()) - .add(((Path) row).copy()); - return null; - } + @Override + public CallState getCallState() { + return callState; + } - if (callState == CallState.RETURNING) { - Object[] callResults = new Object[queryCalls.length]; - for (int i = 0; i < queryCalls.length; i++) { - callResults[i] = queryCalls[i].evaluate(row); - } - for (PlaceHolderExpression placeHolderExpression : placeHolderExpressions) { - placeHolderExpression.setResults(callResults); + public CallQueryExpression[] getQueryCalls() { + return queryCalls; + } + + public List getSubQueryNames() { + List names = new ArrayList<>(); + for (CallQueryExpression queryCall : queryCalls) { + names.add(queryCall.getQueryName()); + } + return names; + } + + @Override + public void finishCall() { + if (callState == CallState.WAITING) { // call sub query + for (Map.Entry>> entry : stashPaths.entrySet()) { + ParameterRequest request = entry.getKey(); + Map> vertexPaths = entry.getValue(); + for (Map.Entry> vertexPathEntry : vertexPaths.entrySet()) { + List paths = vertexPathEntry.getValue(); + context.setRequest(request); + for (Path path : paths) { + for (CallQueryExpression queryCall : queryCalls) { + queryCall.evaluate(path); } - return rewriteExpression.evaluate(row); - } - throw new IllegalArgumentException("Illegal call state: " + callState + " for evaluate() method"); + } + } + } + // send eod to the sub query after call finish. + for (CallQueryExpression queryCall : queryCalls) { + queryCall.sendEod(); + } + stashPaths.clear(); } - @Override - public String showExpression() { - return rewriteExpression.showExpression(); + for (ICallQuery queryCall : queryCalls) { + queryCall.finishCall(); } + } - @Override - public IType getOutputType() { - return rewriteExpression.getOutputType(); - } + private static class PlaceHolderExpression extends AbstractExpression { - @Override - public List getInputs() { - return rewriteExpression.getInputs(); - } + private final Expression[] expressions; - @Override - public Expression copy(List inputs) { - return new CallQueryProxy(queryCalls, rewriteExpression.copy(inputs)); + private final int placeHolderIndex; + + private Object[] results; + + public PlaceHolderExpression(Expression[] expressions, int placeHolderIndex) { + this.expressions = expressions; + this.placeHolderIndex = placeHolderIndex; } - @Override - public void setCallState(CallState callState) { - this.callState = callState; - for (CallQueryExpression call : queryCalls) { - call.setCallState(callState); - } + public void setResults(Object[] results) { + this.results = results; } @Override - public CallState getCallState() { - return callState; + public Object evaluate(Row row) { + return results[placeHolderIndex]; } - public CallQueryExpression[] getQueryCalls() { - return queryCalls; + @Override + public String showExpression() { + return expressions[placeHolderIndex].showExpression(); } - public List getSubQueryNames() { - List names = new ArrayList<>(); - for (CallQueryExpression queryCall : queryCalls) { - names.add(queryCall.getQueryName()); - } - return names; + @Override + public IType getOutputType() { + return expressions[placeHolderIndex].getOutputType(); } @Override - public void finishCall() { - if (callState == CallState.WAITING) { // call sub query - for (Map.Entry>> entry : stashPaths.entrySet()) { - ParameterRequest request = entry.getKey(); - Map> vertexPaths = entry.getValue(); - for (Map.Entry> vertexPathEntry : vertexPaths.entrySet()) { - List paths = vertexPathEntry.getValue(); - context.setRequest(request); - for (Path path : paths) { - for (CallQueryExpression queryCall : queryCalls) { - queryCall.evaluate(path); - } - } - } - } - // send eod to the sub query after call finish. - for (CallQueryExpression queryCall : queryCalls) { - queryCall.sendEod(); - } - stashPaths.clear(); - } - - for (ICallQuery queryCall : queryCalls) { - queryCall.finishCall(); - } + public List getInputs() { + return expressions[placeHolderIndex].getInputs(); } - private static class PlaceHolderExpression extends AbstractExpression { - - private final Expression[] expressions; - - private final int placeHolderIndex; - - private Object[] results; - - public PlaceHolderExpression(Expression[] expressions, int placeHolderIndex) { - this.expressions = expressions; - this.placeHolderIndex = placeHolderIndex; - } - - public void setResults(Object[] results) { - this.results = results; - } - - @Override - public Object evaluate(Row row) { - return results[placeHolderIndex]; - } - - @Override - public String showExpression() { - return expressions[placeHolderIndex].showExpression(); - } - - @Override - public IType getOutputType() { - return expressions[placeHolderIndex].getOutputType(); - } - - @Override - public List getInputs() { - return expressions[placeHolderIndex].getInputs(); - } - - @Override - public Expression copy(List inputs) { - expressions[placeHolderIndex] = expressions[placeHolderIndex].copy(inputs); - return new PlaceHolderExpression(expressions, placeHolderIndex); - } + @Override + public Expression copy(List inputs) { + expressions[placeHolderIndex] = expressions[placeHolderIndex].copy(inputs); + return new PlaceHolderExpression(expressions, placeHolderIndex); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/subquery/ICallQuery.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/subquery/ICallQuery.java index 82a61477c..d6d633932 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/subquery/ICallQuery.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/expression/subquery/ICallQuery.java @@ -24,11 +24,11 @@ public interface ICallQuery { - void open(TraversalRuntimeContext context); + void open(TraversalRuntimeContext context); - void finishCall(); + void finishCall(); - void setCallState(CallState callState); + void setCallState(CallState callState); - CallState getCallState(); + CallState getCallState(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/FunctionSchemas.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/FunctionSchemas.java index 155ec9d9f..1406680bb 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/FunctionSchemas.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/FunctionSchemas.java @@ -20,67 +20,69 @@ package org.apache.geaflow.dsl.runtime.function.graph; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.types.GraphSchema; import org.apache.geaflow.dsl.common.types.PathType; public class FunctionSchemas { - private final List inputPathSchemas; + private final List inputPathSchemas; - private final PathType outputPathSchema; + private final PathType outputPathSchema; - private final IType outputType; + private final IType outputType; - private final GraphSchema graphSchema; + private final GraphSchema graphSchema; - private final GraphSchema modifyGraphSchema; + private final GraphSchema modifyGraphSchema; - private final IType[] addingVertexFieldTypes; + private final IType[] addingVertexFieldTypes; - private final String[] addingVertexFieldNames; + private final String[] addingVertexFieldNames; - public FunctionSchemas(List inputPathSchemas, - PathType outputPathSchema, - IType outputType, - GraphSchema graphSchema, - GraphSchema modifyGraphSchema, - IType[] addingVertexFieldTypes, - String[] addingVertexFieldNames) { - this.inputPathSchemas = inputPathSchemas; - this.outputPathSchema = outputPathSchema; - this.outputType = outputType; - this.graphSchema = graphSchema; - this.modifyGraphSchema = modifyGraphSchema; - this.addingVertexFieldTypes = addingVertexFieldTypes; - this.addingVertexFieldNames = addingVertexFieldNames; - } + public FunctionSchemas( + List inputPathSchemas, + PathType outputPathSchema, + IType outputType, + GraphSchema graphSchema, + GraphSchema modifyGraphSchema, + IType[] addingVertexFieldTypes, + String[] addingVertexFieldNames) { + this.inputPathSchemas = inputPathSchemas; + this.outputPathSchema = outputPathSchema; + this.outputType = outputType; + this.graphSchema = graphSchema; + this.modifyGraphSchema = modifyGraphSchema; + this.addingVertexFieldTypes = addingVertexFieldTypes; + this.addingVertexFieldNames = addingVertexFieldNames; + } - public List getInputPathSchemas() { - return inputPathSchemas; - } + public List getInputPathSchemas() { + return inputPathSchemas; + } - public PathType getOutputPathSchema() { - return outputPathSchema; - } + public PathType getOutputPathSchema() { + return outputPathSchema; + } - public IType getOutputType() { - return outputType; - } + public IType getOutputType() { + return outputType; + } - public GraphSchema getGraphSchema() { - return graphSchema; - } + public GraphSchema getGraphSchema() { + return graphSchema; + } - public GraphSchema getModifyGraphSchema() { - return modifyGraphSchema; - } + public GraphSchema getModifyGraphSchema() { + return modifyGraphSchema; + } - public IType[] getAddingVertexFieldTypes() { - return addingVertexFieldTypes; - } + public IType[] getAddingVertexFieldTypes() { + return addingVertexFieldTypes; + } - public String[] getAddingVertexFieldNames() { - return addingVertexFieldNames; - } + public String[] getAddingVertexFieldNames() { + return addingVertexFieldNames; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchEdgeFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchEdgeFunction.java index 64cc2b9a3..db4c88867 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchEdgeFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchEdgeFunction.java @@ -20,17 +20,18 @@ package org.apache.geaflow.dsl.runtime.function.graph; import java.util.Set; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.dsl.sqlnode.SqlMatchEdge.EdgeDirection; import org.apache.geaflow.state.pushdown.filter.IFilter; public interface MatchEdgeFunction extends StepFunction { - String getLabel(); + String getLabel(); - EdgeDirection getDirection(); + EdgeDirection getDirection(); - IFilter getEdgesFilter(); + IFilter getEdgesFilter(); - Set getEdgeTypes(); + Set getEdgeTypes(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchEdgeFunctionImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchEdgeFunctionImpl.java index 7bffa5d7d..c30985cfc 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchEdgeFunctionImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchEdgeFunctionImpl.java @@ -24,6 +24,7 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.dsl.common.data.StepRecord; import org.apache.geaflow.dsl.runtime.expression.Expression; @@ -38,109 +39,115 @@ public class MatchEdgeFunctionImpl implements MatchEdgeFunction { - private final EdgeDirection direction; - - private final Set edgeTypes; - - private final String label; - - private IFilter edgesFilter; - - private final boolean isOptionalMatchEdge; - - public MatchEdgeFunctionImpl(EdgeDirection direction, Set edgeTypes, - boolean isOptionalMatchEdge, String label, IFilter edgeFilter) { - this.direction = direction; - this.edgeTypes = edgeTypes; - this.label = label; - this.edgesFilter = edgeFilter; - this.isOptionalMatchEdge = isOptionalMatchEdge; - } - - public MatchEdgeFunctionImpl(EdgeDirection direction, Set edgeTypes, String label, - IFilter edgeFilter) { - this(direction, edgeTypes, false, label, edgeFilter); - } - - public MatchEdgeFunctionImpl(EdgeDirection direction, Set edgeTypes, String label, - IFilter... pushDownFilter) { - this(direction, edgeTypes, false, label, pushDownFilter); - } - - public MatchEdgeFunctionImpl(EdgeDirection direction, Set edgeTypes, - boolean isOptionalMatchEdge, String label, - IFilter... pushDownFilter) { - this.direction = direction; - this.edgeTypes = Objects.requireNonNull(edgeTypes); - this.label = label; - IFilter directionFilter; - switch (direction) { - case OUT: - directionFilter = OutEdgeFilter.getInstance(); - break; - case IN: - directionFilter = InEdgeFilter.getInstance(); - break; - case BOTH: - directionFilter = EmptyFilter.getInstance(); - break; - default: - throw new IllegalArgumentException("Illegal edge direction: " + direction); - } - this.edgesFilter = directionFilter; - if (!edgeTypes.isEmpty()) { - this.edgesFilter.and(new EdgeLabelFilter(edgeTypes.stream().map(BinaryString::toString) - .collect(Collectors.toSet()))); - } - for (IFilter andFilter : pushDownFilter) { - this.edgesFilter = this.edgesFilter == null ? andFilter : - this.edgesFilter.and(andFilter); - } - this.isOptionalMatchEdge = isOptionalMatchEdge; - } - - @Override - public String getLabel() { - return label; - } - - @Override - public EdgeDirection getDirection() { - return direction; - } - - @Override - public Set getEdgeTypes() { - return edgeTypes; - } - - public boolean isOptionalMatchEdge() { - return isOptionalMatchEdge; + private final EdgeDirection direction; + + private final Set edgeTypes; + + private final String label; + + private IFilter edgesFilter; + + private final boolean isOptionalMatchEdge; + + public MatchEdgeFunctionImpl( + EdgeDirection direction, + Set edgeTypes, + boolean isOptionalMatchEdge, + String label, + IFilter edgeFilter) { + this.direction = direction; + this.edgeTypes = edgeTypes; + this.label = label; + this.edgesFilter = edgeFilter; + this.isOptionalMatchEdge = isOptionalMatchEdge; + } + + public MatchEdgeFunctionImpl( + EdgeDirection direction, Set edgeTypes, String label, IFilter edgeFilter) { + this(direction, edgeTypes, false, label, edgeFilter); + } + + public MatchEdgeFunctionImpl( + EdgeDirection direction, + Set edgeTypes, + String label, + IFilter... pushDownFilter) { + this(direction, edgeTypes, false, label, pushDownFilter); + } + + public MatchEdgeFunctionImpl( + EdgeDirection direction, + Set edgeTypes, + boolean isOptionalMatchEdge, + String label, + IFilter... pushDownFilter) { + this.direction = direction; + this.edgeTypes = Objects.requireNonNull(edgeTypes); + this.label = label; + IFilter directionFilter; + switch (direction) { + case OUT: + directionFilter = OutEdgeFilter.getInstance(); + break; + case IN: + directionFilter = InEdgeFilter.getInstance(); + break; + case BOTH: + directionFilter = EmptyFilter.getInstance(); + break; + default: + throw new IllegalArgumentException("Illegal edge direction: " + direction); } - - @Override - public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { - + this.edgesFilter = directionFilter; + if (!edgeTypes.isEmpty()) { + this.edgesFilter.and( + new EdgeLabelFilter( + edgeTypes.stream().map(BinaryString::toString).collect(Collectors.toSet()))); } - - @Override - public void finish(StepCollector collector) { - - } - - @Override - public IFilter getEdgesFilter() { - return edgesFilter; - } - - @Override - public List getExpressions() { - return Collections.emptyList(); - } - - @Override - public StepFunction copy(List expressions) { - assert expressions.isEmpty(); - return new MatchEdgeFunctionImpl(direction, edgeTypes, label, edgesFilter); + for (IFilter andFilter : pushDownFilter) { + this.edgesFilter = this.edgesFilter == null ? andFilter : this.edgesFilter.and(andFilter); } + this.isOptionalMatchEdge = isOptionalMatchEdge; + } + + @Override + public String getLabel() { + return label; + } + + @Override + public EdgeDirection getDirection() { + return direction; + } + + @Override + public Set getEdgeTypes() { + return edgeTypes; + } + + public boolean isOptionalMatchEdge() { + return isOptionalMatchEdge; + } + + @Override + public void open(TraversalRuntimeContext context, FunctionSchemas schemas) {} + + @Override + public void finish(StepCollector collector) {} + + @Override + public IFilter getEdgesFilter() { + return edgesFilter; + } + + @Override + public List getExpressions() { + return Collections.emptyList(); + } + + @Override + public StepFunction copy(List expressions) { + assert expressions.isEmpty(); + return new MatchEdgeFunctionImpl(direction, edgeTypes, label, edgesFilter); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchVertexFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchVertexFunction.java index 6c0b41f85..a5d764eca 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchVertexFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchVertexFunction.java @@ -20,14 +20,15 @@ package org.apache.geaflow.dsl.runtime.function.graph; import java.util.Set; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.state.pushdown.filter.IFilter; public interface MatchVertexFunction extends StepFunction { - String getLabel(); + String getLabel(); - IFilter getVertexFilter(); + IFilter getVertexFilter(); - Set getVertexTypes(); + Set getVertexTypes(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchVertexFunctionImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchVertexFunctionImpl.java index 7e1f6919e..b4970a7cf 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchVertexFunctionImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchVertexFunctionImpl.java @@ -25,6 +25,7 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.dsl.common.data.StepRecord; import org.apache.geaflow.dsl.runtime.expression.Expression; @@ -36,92 +37,96 @@ public class MatchVertexFunctionImpl implements MatchVertexFunction { - private final Set vertexTypes; - - private final String label; - - private IFilter vertexFilter; - - private final boolean isOptionalMatchVertex; - - private Set idSet; - - public MatchVertexFunctionImpl(Set vertexTypes, String label, IFilter vertexFilter) { - this(vertexTypes, false, label, vertexFilter); - } - - public MatchVertexFunctionImpl(Set vertexTypes, boolean isOptionalMatchVertex, - String label, IFilter vertexFilter) { - this.vertexTypes = vertexTypes; - this.label = label; - this.vertexFilter = vertexFilter; - this.isOptionalMatchVertex = isOptionalMatchVertex; - } - - public MatchVertexFunctionImpl(Set vertexTypes, String label, - IFilter... pushDownFilters) { - this(vertexTypes, false, label, new HashSet<>(), pushDownFilters); - } - - public MatchVertexFunctionImpl(Set vertexTypes, boolean isOptionalMatchVertex, - String label, Set idSet, IFilter... pushDownFilters) { - this.vertexTypes = Objects.requireNonNull(vertexTypes); - this.isOptionalMatchVertex = isOptionalMatchVertex; - this.label = label; - if (!vertexTypes.isEmpty()) { - this.vertexFilter = new VertexLabelFilter( - vertexTypes.stream().map(BinaryString::toString) - .collect(Collectors.toSet())); - } else { - this.vertexFilter = EmptyFilter.getInstance(); - } - for (IFilter filter : pushDownFilters) { - this.vertexFilter = this.vertexFilter == null ? filter : this.vertexFilter.and(filter); - } - this.idSet = idSet; - } - - @Override - public String getLabel() { - return label; - } - - @Override - public Set getVertexTypes() { - return vertexTypes; + private final Set vertexTypes; + + private final String label; + + private IFilter vertexFilter; + + private final boolean isOptionalMatchVertex; + + private Set idSet; + + public MatchVertexFunctionImpl( + Set vertexTypes, String label, IFilter vertexFilter) { + this(vertexTypes, false, label, vertexFilter); + } + + public MatchVertexFunctionImpl( + Set vertexTypes, + boolean isOptionalMatchVertex, + String label, + IFilter vertexFilter) { + this.vertexTypes = vertexTypes; + this.label = label; + this.vertexFilter = vertexFilter; + this.isOptionalMatchVertex = isOptionalMatchVertex; + } + + public MatchVertexFunctionImpl( + Set vertexTypes, String label, IFilter... pushDownFilters) { + this(vertexTypes, false, label, new HashSet<>(), pushDownFilters); + } + + public MatchVertexFunctionImpl( + Set vertexTypes, + boolean isOptionalMatchVertex, + String label, + Set idSet, + IFilter... pushDownFilters) { + this.vertexTypes = Objects.requireNonNull(vertexTypes); + this.isOptionalMatchVertex = isOptionalMatchVertex; + this.label = label; + if (!vertexTypes.isEmpty()) { + this.vertexFilter = + new VertexLabelFilter( + vertexTypes.stream().map(BinaryString::toString).collect(Collectors.toSet())); + } else { + this.vertexFilter = EmptyFilter.getInstance(); } - - @Override - public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { - - } - - @Override - public void finish(StepCollector collector) { - - } - - public boolean isOptionalMatchVertex() { - return isOptionalMatchVertex; - } - - public Set getIdSet() { - return idSet; - } - - @Override - public IFilter getVertexFilter() { - return vertexFilter; - } - - @Override - public List getExpressions() { - return Collections.emptyList(); - } - - @Override - public StepFunction copy(List expressions) { - assert expressions.isEmpty(); - return new MatchVertexFunctionImpl(vertexTypes, label, vertexFilter); + for (IFilter filter : pushDownFilters) { + this.vertexFilter = this.vertexFilter == null ? filter : this.vertexFilter.and(filter); } + this.idSet = idSet; + } + + @Override + public String getLabel() { + return label; + } + + @Override + public Set getVertexTypes() { + return vertexTypes; + } + + @Override + public void open(TraversalRuntimeContext context, FunctionSchemas schemas) {} + + @Override + public void finish(StepCollector collector) {} + + public boolean isOptionalMatchVertex() { + return isOptionalMatchVertex; + } + + public Set getIdSet() { + return idSet; + } + + @Override + public IFilter getVertexFilter() { + return vertexFilter; + } + + @Override + public List getExpressions() { + return Collections.emptyList(); + } + + @Override + public StepFunction copy(List expressions) { + assert expressions.isEmpty(); + return new MatchVertexFunctionImpl(vertexTypes, label, vertexFilter); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchVirtualEdgeFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchVirtualEdgeFunction.java index 2f9f92f55..af81366fb 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchVirtualEdgeFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchVirtualEdgeFunction.java @@ -20,14 +20,15 @@ package org.apache.geaflow.dsl.runtime.function.graph; import java.util.List; + import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.runtime.traversal.path.ITreePath; public interface MatchVirtualEdgeFunction extends StepFunction { - List computeTargetId(Path path); + List computeTargetId(Path path); - default ITreePath computeTargetPath(Object targetId, ITreePath currentPath) { - return currentPath; - } + default ITreePath computeTargetPath(Object targetId, ITreePath currentPath) { + return currentPath; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchVirtualEdgeFunctionImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchVirtualEdgeFunctionImpl.java index b9b28b588..9aa189815 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchVirtualEdgeFunctionImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/MatchVirtualEdgeFunctionImpl.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.StepRecord; import org.apache.geaflow.dsl.runtime.expression.Expression; @@ -29,36 +30,32 @@ public class MatchVirtualEdgeFunctionImpl implements MatchVirtualEdgeFunction { - private final Expression targetId; - - public MatchVirtualEdgeFunctionImpl(Expression targetId) { - this.targetId = targetId; - } - - @Override - public List computeTargetId(Path path) { - Object targetId = this.targetId.evaluate(path); - return Collections.singletonList(targetId); - } + private final Expression targetId; - @Override - public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { + public MatchVirtualEdgeFunctionImpl(Expression targetId) { + this.targetId = targetId; + } - } + @Override + public List computeTargetId(Path path) { + Object targetId = this.targetId.evaluate(path); + return Collections.singletonList(targetId); + } - @Override - public void finish(StepCollector collector) { + @Override + public void open(TraversalRuntimeContext context, FunctionSchemas schemas) {} - } + @Override + public void finish(StepCollector collector) {} - @Override - public List getExpressions() { - return Collections.singletonList(targetId); - } + @Override + public List getExpressions() { + return Collections.singletonList(targetId); + } - @Override - public StepFunction copy(List expressions) { - assert expressions.size() == 1; - return new MatchVirtualEdgeFunctionImpl(expressions.get(0)); - } + @Override + public StepFunction copy(List expressions) { + assert expressions.size() == 1; + return new MatchVirtualEdgeFunctionImpl(expressions.get(0)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepAggExpressionFunctionImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepAggExpressionFunctionImpl.java index 27e7a8283..b48c6c044 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepAggExpressionFunctionImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepAggExpressionFunctionImpl.java @@ -28,6 +28,7 @@ import java.util.List; import java.util.Objects; import java.util.stream.Collectors; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.utils.ClassUtil; import org.apache.geaflow.dsl.common.data.Row; @@ -46,201 +47,197 @@ public class StepAggExpressionFunctionImpl implements StepAggregateFunction { - private final StepAggCall[] aggFunctionCalls; + private final StepAggCall[] aggFunctionCalls; - private final IType[] aggOutputTypes; + private final IType[] aggOutputTypes; - private UDAF[] udafs; + private UDAF[] udafs; - private final int[] pathPruneIndices; + private final int[] pathPruneIndices; - private final IType[] inputPathTypes; + private final IType[] inputPathTypes; - private final IType[] pathPruneTypes; + private final IType[] pathPruneTypes; - public StepAggExpressionFunctionImpl(int[] pathPruneIndices, - IType[] pathPruneTypes, - IType[] inputPathTypes, - List aggFunctionCalls, - List> aggOutputTypes) { - this.pathPruneIndices = Objects.requireNonNull(pathPruneIndices); - this.pathPruneTypes = Objects.requireNonNull(pathPruneTypes); - this.inputPathTypes = Objects.requireNonNull(inputPathTypes); - this.aggFunctionCalls = Objects.requireNonNull(aggFunctionCalls).toArray(new StepAggCall[]{}); - this.aggOutputTypes = Objects.requireNonNull(aggOutputTypes).toArray(new IType[]{}); - } + public StepAggExpressionFunctionImpl( + int[] pathPruneIndices, + IType[] pathPruneTypes, + IType[] inputPathTypes, + List aggFunctionCalls, + List> aggOutputTypes) { + this.pathPruneIndices = Objects.requireNonNull(pathPruneIndices); + this.pathPruneTypes = Objects.requireNonNull(pathPruneTypes); + this.inputPathTypes = Objects.requireNonNull(inputPathTypes); + this.aggFunctionCalls = Objects.requireNonNull(aggFunctionCalls).toArray(new StepAggCall[] {}); + this.aggOutputTypes = Objects.requireNonNull(aggOutputTypes).toArray(new IType[] {}); + } - public static class Accumulator implements Serializable { + public static class Accumulator implements Serializable { - private final Object[] accumulators; + private final Object[] accumulators; - public Accumulator(Object[] accumulators) { - this.accumulators = accumulators; - } + public Accumulator(Object[] accumulators) { + this.accumulators = accumulators; + } - public Object getAcc(int index) { - return accumulators[index]; - } + public Object getAcc(int index) { + return accumulators[index]; } + } - public static class StepAggCall implements Serializable { - - /** - * The name of the agg function, e.g. SUM、COUNT. - */ - private final String name; - - /** - * The argument field expressions. - */ - private final Expression[] argExpressions; - - /** - * The argument field type. - */ - private final IType[] argFieldTypes; - - /** - * The UDAF implement class. - */ - private final Class> udafClass; - - /** - * The UDAF input class. - */ - private final Class udafInputClass; - /** - * The distinct flag. - */ - private final boolean isDistinct; - - public StepAggCall(String name, Expression[] argExpressions, IType[] argFieldTypes, - Class> udafClass, boolean isDistinct) { - this.name = name; - this.argExpressions = argExpressions; - this.argFieldTypes = argFieldTypes; - this.udafClass = udafClass; - Type[] genericTypes = getUDAFGenericTypes(udafClass); - this.udafInputClass = (Class) genericTypes[0]; - this.isDistinct = isDistinct; - } + public static class StepAggCall implements Serializable { - public String getName() { - return name; - } + /** The name of the agg function, e.g. SUM、COUNT. */ + private final String name; - public Expression[] getArgExpressions() { - return argExpressions; - } + /** The argument field expressions. */ + private final Expression[] argExpressions; - public IType[] getArgFieldTypes() { - return argFieldTypes; - } + /** The argument field type. */ + private final IType[] argFieldTypes; - public Class> getUdafClass() { - return udafClass; - } + /** The UDAF implement class. */ + private final Class> udafClass; - public boolean isDistinct() { - return isDistinct; - } - } + /** The UDAF input class. */ + private final Class udafInputClass; - @Override - public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { - udafs = new UDAF[aggFunctionCalls.length]; - for (int i = 0; i < aggFunctionCalls.length; i++) { - try { - udafs[i] = (UDAF) aggFunctionCalls[i].getUdafClass().newInstance(); - if (aggFunctionCalls[i].isDistinct) { - udafs[i] = new DistinctUDAF(udafs[i]); - } - udafs[i].open(FunctionContext.of(context.getConfig())); - } catch (InstantiationException | IllegalAccessException e) { - throw new GeaFlowDSLException("Error in create UDAF: " + aggFunctionCalls[i].getUdafClass()); - } - } - } - - @Override - public void finish(StepCollector collector) { + /** The distinct flag. */ + private final boolean isDistinct; + public StepAggCall( + String name, + Expression[] argExpressions, + IType[] argFieldTypes, + Class> udafClass, + boolean isDistinct) { + this.name = name; + this.argExpressions = argExpressions; + this.argFieldTypes = argFieldTypes; + this.udafClass = udafClass; + Type[] genericTypes = getUDAFGenericTypes(udafClass); + this.udafInputClass = (Class) genericTypes[0]; + this.isDistinct = isDistinct; } - public IType[] getAggOutputTypes() { - return aggOutputTypes; + public String getName() { + return name; } - public int[] getPathPruneIndices() { - return pathPruneIndices; + public Expression[] getArgExpressions() { + return argExpressions; } - public IType[] getPathPruneTypes() { - return pathPruneTypes; + public IType[] getArgFieldTypes() { + return argFieldTypes; } - public IType[] getInputPathTypes() { - return inputPathTypes; + public Class> getUdafClass() { + return udafClass; } - @Override - public List getExpressions() { - return Collections.emptyList(); + public boolean isDistinct() { + return isDistinct; } - - @Override - public StepFunction copy(List expressions) { - return new StepAggExpressionFunctionImpl(pathPruneIndices, pathPruneTypes, inputPathTypes, - Arrays.stream(aggFunctionCalls).collect(Collectors.toList()), - Arrays.stream(aggOutputTypes).collect(Collectors.toList())); - } - - @Override - public Object createAccumulator() { - Object[] accumulators = new Object[udafs.length]; - for (int i = 0; i < accumulators.length; i++) { - accumulators[i] = udafs[i].createAccumulator(); + } + + @Override + public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { + udafs = new UDAF[aggFunctionCalls.length]; + for (int i = 0; i < aggFunctionCalls.length; i++) { + try { + udafs[i] = (UDAF) aggFunctionCalls[i].getUdafClass().newInstance(); + if (aggFunctionCalls[i].isDistinct) { + udafs[i] = new DistinctUDAF(udafs[i]); } - return new Accumulator(accumulators); + udafs[i].open(FunctionContext.of(context.getConfig())); + } catch (InstantiationException | IllegalAccessException e) { + throw new GeaFlowDSLException( + "Error in create UDAF: " + aggFunctionCalls[i].getUdafClass()); + } } - - @Override - public void add(Row row, Object accumulator) { - for (int i = 0; i < udafs.length; i++) { - StepAggCall aggInfo = aggFunctionCalls[i]; - Object argValue; - if (aggInfo.argExpressions.length == 0) { // for count() without input parameter. - argValue = row; - } else if (aggInfo.argExpressions.length == 1) { - argValue = aggInfo.argExpressions[0].evaluate(row); - } else { // for agg with multi-parameters - assert UDAFArguments.class.isAssignableFrom(aggInfo.udafInputClass); - Object[] parameters = new Object[aggInfo.argExpressions.length]; - for (int p = 0; p < parameters.length; p++) { - parameters[p] = aggInfo.argExpressions[p].evaluate(row); - } - argValue = ClassUtil.newInstance(aggInfo.udafInputClass); - ((UDAFArguments) argValue).setParams(parameters); - } - udafs[i].accumulate(((Accumulator) accumulator).accumulators[i], argValue); - } + } + + @Override + public void finish(StepCollector collector) {} + + public IType[] getAggOutputTypes() { + return aggOutputTypes; + } + + public int[] getPathPruneIndices() { + return pathPruneIndices; + } + + public IType[] getPathPruneTypes() { + return pathPruneTypes; + } + + public IType[] getInputPathTypes() { + return inputPathTypes; + } + + @Override + public List getExpressions() { + return Collections.emptyList(); + } + + @Override + public StepFunction copy(List expressions) { + return new StepAggExpressionFunctionImpl( + pathPruneIndices, + pathPruneTypes, + inputPathTypes, + Arrays.stream(aggFunctionCalls).collect(Collectors.toList()), + Arrays.stream(aggOutputTypes).collect(Collectors.toList())); + } + + @Override + public Object createAccumulator() { + Object[] accumulators = new Object[udafs.length]; + for (int i = 0; i < accumulators.length; i++) { + accumulators[i] = udafs[i].createAccumulator(); } - - @Override - public void merge(Object acc, Object otherAcc) { - for (int i = 0; i < udafs.length; i++) { - udafs[i].merge( - ((Accumulator) acc).getAcc(i), - Collections.singletonList(((Accumulator) otherAcc).getAcc(i))); + return new Accumulator(accumulators); + } + + @Override + public void add(Row row, Object accumulator) { + for (int i = 0; i < udafs.length; i++) { + StepAggCall aggInfo = aggFunctionCalls[i]; + Object argValue; + if (aggInfo.argExpressions.length == 0) { // for count() without input parameter. + argValue = row; + } else if (aggInfo.argExpressions.length == 1) { + argValue = aggInfo.argExpressions[0].evaluate(row); + } else { // for agg with multi-parameters + assert UDAFArguments.class.isAssignableFrom(aggInfo.udafInputClass); + Object[] parameters = new Object[aggInfo.argExpressions.length]; + for (int p = 0; p < parameters.length; p++) { + parameters[p] = aggInfo.argExpressions[p].evaluate(row); } + argValue = ClassUtil.newInstance(aggInfo.udafInputClass); + ((UDAFArguments) argValue).setParams(parameters); + } + udafs[i].accumulate(((Accumulator) accumulator).accumulators[i], argValue); } + } + + @Override + public void merge(Object acc, Object otherAcc) { + for (int i = 0; i < udafs.length; i++) { + udafs[i].merge( + ((Accumulator) acc).getAcc(i), + Collections.singletonList(((Accumulator) otherAcc).getAcc(i))); + } + } - @Override - public SingleValue getValue(Object accumulator) { - Object[] aggValues = new Object[udafs.length]; - for (int i = 0; i < udafs.length; i++) { - aggValues[i] = udafs[i].getValue(((Accumulator) accumulator).accumulators[i]); - } - return ObjectSingleValue.of(ObjectRow.create(aggValues)); + @Override + public SingleValue getValue(Object accumulator) { + Object[] aggValues = new Object[udafs.length]; + for (int i = 0; i < udafs.length; i++) { + aggValues[i] = udafs[i].getValue(((Accumulator) accumulator).accumulators[i]); } + return ObjectSingleValue.of(ObjectRow.create(aggValues)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepAggFunctionImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepAggFunctionImpl.java index 72f0857b5..238b5f9d8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepAggFunctionImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepAggFunctionImpl.java @@ -19,10 +19,10 @@ package org.apache.geaflow.dsl.runtime.function.graph; -import com.google.common.collect.Lists; import java.util.Collections; import java.util.List; import java.util.Objects; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.StepRecord; @@ -34,56 +34,56 @@ import org.apache.geaflow.dsl.runtime.traversal.data.ObjectSingleValue; import org.apache.geaflow.dsl.runtime.traversal.data.SingleValue; -public class StepAggFunctionImpl implements StepAggregateFunction { +import com.google.common.collect.Lists; - private final UDAF udaf; +public class StepAggFunctionImpl implements StepAggregateFunction { - private final IType inputType; + private final UDAF udaf; - public StepAggFunctionImpl(UDAF udaf, IType inputType) { - this.udaf = Objects.requireNonNull(udaf); - this.inputType = Objects.requireNonNull(inputType); - } + private final IType inputType; - @Override - public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { - udaf.open(FunctionContext.of(context.getConfig())); - } + public StepAggFunctionImpl(UDAF udaf, IType inputType) { + this.udaf = Objects.requireNonNull(udaf); + this.inputType = Objects.requireNonNull(inputType); + } - @Override - public void finish(StepCollector collector) { + @Override + public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { + udaf.open(FunctionContext.of(context.getConfig())); + } - } + @Override + public void finish(StepCollector collector) {} - @Override - public List getExpressions() { - return Collections.emptyList(); - } + @Override + public List getExpressions() { + return Collections.emptyList(); + } - @Override - public StepFunction copy(List expressions) { - return new StepAggFunctionImpl(udaf, inputType); - } + @Override + public StepFunction copy(List expressions) { + return new StepAggFunctionImpl(udaf, inputType); + } - @Override - public Object createAccumulator() { - return udaf.createAccumulator(); - } + @Override + public Object createAccumulator() { + return udaf.createAccumulator(); + } - @Override - public void add(Row row, Object accumulator) { - Object input = row.getField(0, inputType); - udaf.accumulate(accumulator, input); - } + @Override + public void add(Row row, Object accumulator) { + Object input = row.getField(0, inputType); + udaf.accumulate(accumulator, input); + } - @Override - public void merge(Object acc, Object otherAcc) { - udaf.merge(acc, Lists.newArrayList(otherAcc)); - } + @Override + public void merge(Object acc, Object otherAcc) { + udaf.merge(acc, Lists.newArrayList(otherAcc)); + } - @Override - public SingleValue getValue(Object accumulator) { - Object value = udaf.getValue(accumulator); - return ObjectSingleValue.of(value); - } + @Override + public SingleValue getValue(Object accumulator) { + Object value = udaf.getValue(accumulator); + return ObjectSingleValue.of(value); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepAggregateFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepAggregateFunction.java index b88f4f636..99b4fd368 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepAggregateFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepAggregateFunction.java @@ -24,11 +24,11 @@ public interface StepAggregateFunction extends StepFunction { - Object createAccumulator(); + Object createAccumulator(); - void add(Row row, Object accumulator); + void add(Row row, Object accumulator); - void merge(Object acc, Object otherAcc); + void merge(Object acc, Object otherAcc); - SingleValue getValue(Object accumulator); + SingleValue getValue(Object accumulator); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepBoolFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepBoolFunction.java index cb597b44c..770ebec93 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepBoolFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepBoolFunction.java @@ -23,5 +23,5 @@ public interface StepBoolFunction extends StepFunction { - boolean filter(Row path); + boolean filter(Row path); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepBoolFunctionImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepBoolFunctionImpl.java index 2c4454bd5..82bf5a902 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepBoolFunctionImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepBoolFunctionImpl.java @@ -22,6 +22,7 @@ import java.util.Collections; import java.util.List; import java.util.Objects; + import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.StepRecord; import org.apache.geaflow.dsl.runtime.expression.Expression; @@ -30,41 +31,38 @@ public class StepBoolFunctionImpl implements StepBoolFunction { - private final Expression condition; - - public StepBoolFunctionImpl(Expression condition) { - this.condition = Objects.requireNonNull(condition); - } - - @Override - public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { - StepFunction.openExpression(condition, context); - } + private final Expression condition; - @Override - public void finish(StepCollector collector) { + public StepBoolFunctionImpl(Expression condition) { + this.condition = Objects.requireNonNull(condition); + } - } + @Override + public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { + StepFunction.openExpression(condition, context); + } - @Override - public boolean filter(Row record) { - Boolean accept = (Boolean) condition.evaluate(record); - return accept != null && accept; - } + @Override + public void finish(StepCollector collector) {} - public Expression getCondition() { - return condition; - } + @Override + public boolean filter(Row record) { + Boolean accept = (Boolean) condition.evaluate(record); + return accept != null && accept; + } - @Override - public List getExpressions() { - return Collections.singletonList(condition); - } + public Expression getCondition() { + return condition; + } - @Override - public StepFunction copy(List expressions) { - assert expressions.size() == 1; - return new StepBoolFunctionImpl(expressions.get(0)); - } + @Override + public List getExpressions() { + return Collections.singletonList(condition); + } + @Override + public StepFunction copy(List expressions) { + assert expressions.size() == 1; + return new StepBoolFunctionImpl(expressions.get(0)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepFlatMapFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepFlatMapFunction.java index 11dcc11ba..587a90b45 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepFlatMapFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepFlatMapFunction.java @@ -25,5 +25,5 @@ public interface StepFlatMapFunction extends StepFunction { - void process(Row record, StepCollector collector); + void process(Row record, StepCollector collector); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepFunction.java index a771bc0e9..578892be2 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepFunction.java @@ -22,6 +22,7 @@ import java.io.Serializable; import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.dsl.common.data.StepRecord; import org.apache.geaflow.dsl.common.function.FunctionContext; import org.apache.geaflow.dsl.runtime.expression.Expression; @@ -32,26 +33,26 @@ public interface StepFunction extends Serializable { - void open(TraversalRuntimeContext context, FunctionSchemas schemas); - - void finish(StepCollector collector); + void open(TraversalRuntimeContext context, FunctionSchemas schemas); - List getExpressions(); + void finish(StepCollector collector); - StepFunction copy(List expressions); + List getExpressions(); - static void openExpression(Expression expression, TraversalRuntimeContext context) { - if (expression instanceof ICallQuery) { - ((ICallQuery) expression).open(context); - } else { - expression.open(FunctionContext.of(context.getConfig())); - } - } + StepFunction copy(List expressions); - default List getCallQueryProxies() { - return getExpressions().stream() - .filter(exp -> exp instanceof CallQueryProxy) - .map(exp -> (CallQueryProxy) exp) - .collect(Collectors.toList()); + static void openExpression(Expression expression, TraversalRuntimeContext context) { + if (expression instanceof ICallQuery) { + ((ICallQuery) expression).open(context); + } else { + expression.open(FunctionContext.of(context.getConfig())); } + } + + default List getCallQueryProxies() { + return getExpressions().stream() + .filter(exp -> exp instanceof CallQueryProxy) + .map(exp -> (CallQueryProxy) exp) + .collect(Collectors.toList()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepGroupFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepGroupFunction.java index b3db5c866..c570eae6a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepGroupFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepGroupFunction.java @@ -25,5 +25,5 @@ public interface StepGroupFunction extends StepFunction { - void process(EdgeGroup edges, StepCollector collector); + void process(EdgeGroup edges, StepCollector collector); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepJoinFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepJoinFunction.java index a5576b1ca..adc952e96 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepJoinFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepJoinFunction.java @@ -24,7 +24,7 @@ public interface StepJoinFunction extends StepFunction { - Path join(Path left, Path right); + Path join(Path left, Path right); - JoinRelType getJoinType(); + JoinRelType getJoinType(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepJoinFunctionImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepJoinFunctionImpl.java index 4c588e944..978d1d189 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepJoinFunctionImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepJoinFunctionImpl.java @@ -22,6 +22,7 @@ import java.util.Collections; import java.util.List; import java.util.Objects; + import org.apache.calcite.rel.core.JoinRelType; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Path; @@ -35,106 +36,102 @@ public class StepJoinFunctionImpl implements StepJoinFunction { - private final JoinRelType joinType; + private final JoinRelType joinType; - private final IType[] leftTypes; + private final IType[] leftTypes; - private final IType[] rightTypes; + private final IType[] rightTypes; - private final Expression condition; + private final Expression condition; - public StepJoinFunctionImpl(JoinRelType joinType, IType[] leftTypes, IType[] rightTypes) { - this(joinType, leftTypes, rightTypes, null); - } + public StepJoinFunctionImpl(JoinRelType joinType, IType[] leftTypes, IType[] rightTypes) { + this(joinType, leftTypes, rightTypes, null); + } - public StepJoinFunctionImpl(JoinRelType joinType, IType[] leftTypes, - IType[] rightTypes, Expression condition) { - this.joinType = joinType; - this.leftTypes = Objects.requireNonNull(leftTypes); - this.rightTypes = Objects.requireNonNull(rightTypes); - this.condition = condition; - } + public StepJoinFunctionImpl( + JoinRelType joinType, IType[] leftTypes, IType[] rightTypes, Expression condition) { + this.joinType = joinType; + this.leftTypes = Objects.requireNonNull(leftTypes); + this.rightTypes = Objects.requireNonNull(rightTypes); + this.condition = condition; + } - @Override - public Path join(Path left, Path right) { - switch (joinType) { - case INNER: - if (left == null || right == null) { - return null; - } - Path innerJoinPath = joinPath(left, right); - if (condition == null || Boolean.valueOf(true).equals(condition.evaluate(innerJoinPath))) { - return innerJoinPath; - } else { - return null; - } - case LEFT: - if (left == null) { - return null; - } - if (right == null) { - right = new DefaultPath(new Row[rightTypes.length]); - } - Path leftJoinPath = joinPath(left, right); - if (condition == null || Boolean.valueOf(true).equals(condition.evaluate(leftJoinPath))) { - return leftJoinPath; - } else { - return null; - } - case RIGHT: - if (right == null) { - return null; - } - if (left == null) { - left = new DefaultPath(new Row[leftTypes.length]); - } - Path rightJoinPath = joinPath(left, right); - if (condition == null || Boolean.valueOf(true).equals(condition.evaluate(rightJoinPath))) { - return rightJoinPath; - } else { - return null; - } - default: - throw new GeaFlowDSLException("JoinType: " + joinType + " is not support"); + @Override + public Path join(Path left, Path right) { + switch (joinType) { + case INNER: + if (left == null || right == null) { + return null; } - } - - @Override - public JoinRelType getJoinType() { - return joinType; - } - - private Path joinPath(Path left, Path right) { - Row[] joinPaths = new Row[leftTypes.length + rightTypes.length]; - int i; - for (i = 0; i < leftTypes.length; i++) { - joinPaths[i] = left.getField(i, leftTypes[i]); + Path innerJoinPath = joinPath(left, right); + if (condition == null || Boolean.valueOf(true).equals(condition.evaluate(innerJoinPath))) { + return innerJoinPath; + } else { + return null; + } + case LEFT: + if (left == null) { + return null; + } + if (right == null) { + right = new DefaultPath(new Row[rightTypes.length]); } - for (; i < leftTypes.length + rightTypes.length; i++) { - int rightIndex = i - leftTypes.length; - joinPaths[i] = right.getField(rightIndex, rightTypes[rightIndex]); + Path leftJoinPath = joinPath(left, right); + if (condition == null || Boolean.valueOf(true).equals(condition.evaluate(leftJoinPath))) { + return leftJoinPath; + } else { + return null; } - return new DefaultPath(joinPaths); + case RIGHT: + if (right == null) { + return null; + } + if (left == null) { + left = new DefaultPath(new Row[leftTypes.length]); + } + Path rightJoinPath = joinPath(left, right); + if (condition == null || Boolean.valueOf(true).equals(condition.evaluate(rightJoinPath))) { + return rightJoinPath; + } else { + return null; + } + default: + throw new GeaFlowDSLException("JoinType: " + joinType + " is not support"); } - - @Override - public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { - + } + + @Override + public JoinRelType getJoinType() { + return joinType; + } + + private Path joinPath(Path left, Path right) { + Row[] joinPaths = new Row[leftTypes.length + rightTypes.length]; + int i; + for (i = 0; i < leftTypes.length; i++) { + joinPaths[i] = left.getField(i, leftTypes[i]); } + for (; i < leftTypes.length + rightTypes.length; i++) { + int rightIndex = i - leftTypes.length; + joinPaths[i] = right.getField(rightIndex, rightTypes[rightIndex]); + } + return new DefaultPath(joinPaths); + } - @Override - public void finish(StepCollector collector) { + @Override + public void open(TraversalRuntimeContext context, FunctionSchemas schemas) {} - } + @Override + public void finish(StepCollector collector) {} - @Override - public List getExpressions() { - return Collections.emptyList(); - } + @Override + public List getExpressions() { + return Collections.emptyList(); + } - @Override - public StepFunction copy(List expressions) { - assert expressions.isEmpty(); - return new StepJoinFunctionImpl(joinType, leftTypes, rightTypes, condition); - } + @Override + public StepFunction copy(List expressions) { + assert expressions.isEmpty(); + return new StepJoinFunctionImpl(joinType, leftTypes, rightTypes, condition); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepKeyExpressionFunctionImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepKeyExpressionFunctionImpl.java index 8edf664e7..91b1725eb 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepKeyExpressionFunctionImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepKeyExpressionFunctionImpl.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.RowKey; @@ -33,42 +34,38 @@ public class StepKeyExpressionFunctionImpl implements StepKeyFunction { - private final Expression[] keyExpressions; - - private final IType[] keyTypes; - - public StepKeyExpressionFunctionImpl(Expression[] keyIndices, IType[] keyTypes) { - this.keyExpressions = keyIndices; - this.keyTypes = keyTypes; - assert keyIndices.length == keyTypes.length; - } + private final Expression[] keyExpressions; - @Override - public RowKey getKey(Row row) { - Object[] keys = new Object[keyExpressions.length]; - for (int i = 0; i < keys.length; i++) { - keys[i] = keyExpressions[i].evaluate(row); - } - return ObjectRowKey.of(keys); - } + private final IType[] keyTypes; - @Override - public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { + public StepKeyExpressionFunctionImpl(Expression[] keyIndices, IType[] keyTypes) { + this.keyExpressions = keyIndices; + this.keyTypes = keyTypes; + assert keyIndices.length == keyTypes.length; + } + @Override + public RowKey getKey(Row row) { + Object[] keys = new Object[keyExpressions.length]; + for (int i = 0; i < keys.length; i++) { + keys[i] = keyExpressions[i].evaluate(row); } + return ObjectRowKey.of(keys); + } - @Override - public void finish(StepCollector collector) { + @Override + public void open(TraversalRuntimeContext context, FunctionSchemas schemas) {} - } + @Override + public void finish(StepCollector collector) {} - @Override - public List getExpressions() { - return Arrays.stream(keyExpressions).collect(Collectors.toList()); - } + @Override + public List getExpressions() { + return Arrays.stream(keyExpressions).collect(Collectors.toList()); + } - @Override - public StepFunction copy(List expressions) { - return new StepKeyExpressionFunctionImpl(keyExpressions, keyTypes); - } + @Override + public StepFunction copy(List expressions) { + return new StepKeyExpressionFunctionImpl(keyExpressions, keyTypes); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepKeyFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepKeyFunction.java index ef0afe6fb..e84a1c601 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepKeyFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepKeyFunction.java @@ -24,5 +24,5 @@ public interface StepKeyFunction extends StepFunction { - RowKey getKey(Row row); + RowKey getKey(Row row); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepKeyFunctionImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepKeyFunctionImpl.java index 88facb314..8ea92bd00 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepKeyFunctionImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepKeyFunctionImpl.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.RowKey; @@ -32,43 +33,39 @@ public class StepKeyFunctionImpl implements StepKeyFunction { - private final int[] keyIndices; - - private final IType[] keyTypes; - - public StepKeyFunctionImpl(int[] keyIndices, IType[] keyTypes) { - this.keyIndices = keyIndices; - this.keyTypes = keyTypes; - assert keyIndices.length == keyTypes.length; - } + private final int[] keyIndices; - @Override - public RowKey getKey(Row row) { - Object[] keys = new Object[keyIndices.length]; - for (int i = 0; i < keys.length; i++) { - keys[i] = row.getField(keyIndices[i], keyTypes[i]); - } - return ObjectRowKey.of(keys); - } + private final IType[] keyTypes; - @Override - public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { + public StepKeyFunctionImpl(int[] keyIndices, IType[] keyTypes) { + this.keyIndices = keyIndices; + this.keyTypes = keyTypes; + assert keyIndices.length == keyTypes.length; + } + @Override + public RowKey getKey(Row row) { + Object[] keys = new Object[keyIndices.length]; + for (int i = 0; i < keys.length; i++) { + keys[i] = row.getField(keyIndices[i], keyTypes[i]); } + return ObjectRowKey.of(keys); + } - @Override - public void finish(StepCollector collector) { + @Override + public void open(TraversalRuntimeContext context, FunctionSchemas schemas) {} - } + @Override + public void finish(StepCollector collector) {} - @Override - public List getExpressions() { - return Collections.emptyList(); - } + @Override + public List getExpressions() { + return Collections.emptyList(); + } - @Override - public StepFunction copy(List expressions) { - assert expressions.isEmpty(); - return new StepKeyFunctionImpl(keyIndices, keyTypes); - } + @Override + public StepFunction copy(List expressions) { + assert expressions.isEmpty(); + return new StepKeyFunctionImpl(keyIndices, keyTypes); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepMapFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepMapFunction.java index 94a2c47d8..1bde0f195 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepMapFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepMapFunction.java @@ -24,5 +24,5 @@ public interface StepMapFunction extends StepFunction { - Path map(Row record); + Path map(Row record); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepMapRowFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepMapRowFunction.java index f3a5e1640..e3767d11f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepMapRowFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepMapRowFunction.java @@ -23,5 +23,5 @@ public interface StepMapRowFunction extends StepFunction { - Row map(Row record); + Row map(Row record); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepNodeFilterFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepNodeFilterFunction.java index db8b59647..630675959 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepNodeFilterFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepNodeFilterFunction.java @@ -23,5 +23,5 @@ public interface StepNodeFilterFunction extends StepFunction { - boolean filter(RowVertex vertex); + boolean filter(RowVertex vertex); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepNodeTypeFilterFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepNodeTypeFilterFunction.java index ebd7f3e1d..c67110efe 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepNodeTypeFilterFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepNodeTypeFilterFunction.java @@ -22,6 +22,7 @@ import java.util.Collections; import java.util.List; import java.util.Set; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.dsl.common.data.RowVertex; import org.apache.geaflow.dsl.common.data.StepRecord; @@ -31,35 +32,31 @@ public class StepNodeTypeFilterFunction implements StepNodeFilterFunction { - private final Set nodeTypes; - - public StepNodeTypeFilterFunction(Set nodeTypes) { - this.nodeTypes = nodeTypes; - } - - @Override - public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { + private final Set nodeTypes; - } + public StepNodeTypeFilterFunction(Set nodeTypes) { + this.nodeTypes = nodeTypes; + } - @Override - public void finish(StepCollector collector) { + @Override + public void open(TraversalRuntimeContext context, FunctionSchemas schemas) {} - } + @Override + public void finish(StepCollector collector) {} - @Override - public List getExpressions() { - return Collections.emptyList(); - } + @Override + public List getExpressions() { + return Collections.emptyList(); + } - @Override - public StepFunction copy(List expressions) { - assert expressions.isEmpty(); - return new StepNodeTypeFilterFunction(nodeTypes); - } + @Override + public StepFunction copy(List expressions) { + assert expressions.isEmpty(); + return new StepNodeTypeFilterFunction(nodeTypes); + } - @Override - public boolean filter(RowVertex vertex) { - return nodeTypes.contains(vertex.getBinaryLabel()); - } + @Override + public boolean filter(RowVertex vertex) { + return nodeTypes.contains(vertex.getBinaryLabel()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepPathModifyFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepPathModifyFunction.java index c757d756e..ca8194eea 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepPathModifyFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepPathModifyFunction.java @@ -19,10 +19,10 @@ package org.apache.geaflow.dsl.runtime.function.graph; -import com.google.common.collect.ImmutableList; import java.util.Arrays; import java.util.List; import java.util.Objects; + import org.apache.commons.lang3.ArrayUtils; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Path; @@ -37,97 +37,103 @@ import org.apache.geaflow.dsl.runtime.traversal.collector.StepCollector; import org.apache.geaflow.dsl.runtime.traversal.data.GlobalVariable; -public class StepPathModifyFunction implements StepMapFunction { - - protected final int[] updatePathIndices; +import com.google.common.collect.ImmutableList; - protected final Expression[] modifyExpressions; +public class StepPathModifyFunction implements StepMapFunction { - protected final IType[] fieldTypes; + protected final int[] updatePathIndices; - protected TraversalRuntimeContext context; + protected final Expression[] modifyExpressions; - private FunctionSchemas schemas; + protected final IType[] fieldTypes; - private final int newFieldNum; + protected TraversalRuntimeContext context; - public StepPathModifyFunction(int[] updatePathIndices, - Expression[] modifyExpressions, - IType[] fieldTypes) { - this.updatePathIndices = Objects.requireNonNull(updatePathIndices); - this.modifyExpressions = Objects.requireNonNull(modifyExpressions); - assert updatePathIndices.length == modifyExpressions.length; - this.fieldTypes = Objects.requireNonNull(fieldTypes); - this.newFieldNum = this.updatePathIndices.length > 0 ? Math.max(fieldTypes.length, - 1 + Arrays.stream(updatePathIndices).max().getAsInt()) : fieldTypes.length; - } + private FunctionSchemas schemas; - @Override - public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { - this.context = context; - for (Expression expression : modifyExpressions) { - StepFunction.openExpression(expression, context); - } - this.schemas = schemas; - for (int i = 0; i < updatePathIndices.length; i++) { - if (modifyExpressions[i] instanceof VertexConstructExpression) { - VertexConstructExpression vertexConstruct = (VertexConstructExpression) modifyExpressions[i]; - List globalVariables = vertexConstruct.getGlobalVariables(); - for (GlobalVariable gv : globalVariables) { - gv.setAddFieldIndex(ArrayUtils.indexOf(schemas.getAddingVertexFieldNames(), gv.getName())); - } - } - } - } + private final int newFieldNum; - @Override - public void finish(StepCollector collector) { + public StepPathModifyFunction( + int[] updatePathIndices, Expression[] modifyExpressions, IType[] fieldTypes) { + this.updatePathIndices = Objects.requireNonNull(updatePathIndices); + this.modifyExpressions = Objects.requireNonNull(modifyExpressions); + assert updatePathIndices.length == modifyExpressions.length; + this.fieldTypes = Objects.requireNonNull(fieldTypes); + this.newFieldNum = + this.updatePathIndices.length > 0 + ? Math.max(fieldTypes.length, 1 + Arrays.stream(updatePathIndices).max().getAsInt()) + : fieldTypes.length; + } + @Override + public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { + this.context = context; + for (Expression expression : modifyExpressions) { + StepFunction.openExpression(expression, context); } - - @Override - public Path map(Row record) { - Row[] values = new Row[newFieldNum]; - for (int i = 0; i < fieldTypes.length; i++) { - values[i] = (Row) record.getField(i, fieldTypes[i]); - } - for (int i = 0; i < updatePathIndices.length; i++) { - Row value = (Row) modifyExpressions[i].evaluate(record); - updateGlobalVariable(modifyExpressions[i], value); - values[updatePathIndices[i]] = value; + this.schemas = schemas; + for (int i = 0; i < updatePathIndices.length; i++) { + if (modifyExpressions[i] instanceof VertexConstructExpression) { + VertexConstructExpression vertexConstruct = + (VertexConstructExpression) modifyExpressions[i]; + List globalVariables = vertexConstruct.getGlobalVariables(); + for (GlobalVariable gv : globalVariables) { + gv.setAddFieldIndex( + ArrayUtils.indexOf(schemas.getAddingVertexFieldNames(), gv.getName())); } - return new DefaultPath(values); + } } + } - private void updateGlobalVariable(Expression modifyExpression, Row value) { - // modify global variable to vertex. - if (modifyExpression instanceof VertexConstructExpression) { - VertexConstructExpression vertexConstruct = (VertexConstructExpression) modifyExpression; - - List globalVariables = vertexConstruct.getGlobalVariables(); - for (GlobalVariable gv : globalVariables) { - // index of the global variable - int index = gv.getIndex(); - VertexType vertexType = ((VertexType) vertexConstruct.getOutputType()); - IType fieldType = vertexType.getType(index); - Object fieldValue = value.getField(index, fieldType); - // add field to vertex which will affect all the computing with this vertexId - int updateIndex = gv.getAddFieldIndex(); - updateIndex = updateIndex >= 0 ? updateIndex : ArrayUtils.indexOf( - schemas.getAddingVertexFieldNames(), gv.getName()); - context.addFieldToVertex(((RowVertex) value).getId(), updateIndex, fieldValue); - } - } - } + @Override + public void finish(StepCollector collector) {} - @Override - public List getExpressions() { - return ImmutableList.copyOf(modifyExpressions); + @Override + public Path map(Row record) { + Row[] values = new Row[newFieldNum]; + for (int i = 0; i < fieldTypes.length; i++) { + values[i] = (Row) record.getField(i, fieldTypes[i]); } - - @Override - public StepFunction copy(List expressions) { - assert expressions.size() == this.modifyExpressions.length; - return new StepPathModifyFunction(updatePathIndices, expressions.toArray(new Expression[]{}), fieldTypes); + for (int i = 0; i < updatePathIndices.length; i++) { + Row value = (Row) modifyExpressions[i].evaluate(record); + updateGlobalVariable(modifyExpressions[i], value); + values[updatePathIndices[i]] = value; + } + return new DefaultPath(values); + } + + private void updateGlobalVariable(Expression modifyExpression, Row value) { + // modify global variable to vertex. + if (modifyExpression instanceof VertexConstructExpression) { + VertexConstructExpression vertexConstruct = (VertexConstructExpression) modifyExpression; + + List globalVariables = vertexConstruct.getGlobalVariables(); + for (GlobalVariable gv : globalVariables) { + // index of the global variable + int index = gv.getIndex(); + VertexType vertexType = ((VertexType) vertexConstruct.getOutputType()); + IType fieldType = vertexType.getType(index); + Object fieldValue = value.getField(index, fieldType); + // add field to vertex which will affect all the computing with this vertexId + int updateIndex = gv.getAddFieldIndex(); + updateIndex = + updateIndex >= 0 + ? updateIndex + : ArrayUtils.indexOf(schemas.getAddingVertexFieldNames(), gv.getName()); + context.addFieldToVertex(((RowVertex) value).getId(), updateIndex, fieldValue); + } } + } + + @Override + public List getExpressions() { + return ImmutableList.copyOf(modifyExpressions); + } + + @Override + public StepFunction copy(List expressions) { + assert expressions.size() == this.modifyExpressions.length; + return new StepPathModifyFunction( + updatePathIndices, expressions.toArray(new Expression[] {}), fieldTypes); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepSingleValueMapFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepSingleValueMapFunction.java index ebf7b0ced..6791f1f6c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepSingleValueMapFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepSingleValueMapFunction.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.StepRecord; import org.apache.geaflow.dsl.common.data.impl.ObjectRow; @@ -30,36 +31,34 @@ public class StepSingleValueMapFunction implements StepMapRowFunction { - private final Expression valueExpression; - - public StepSingleValueMapFunction(Expression valueExpression) { - this.valueExpression = valueExpression; - } + private final Expression valueExpression; - @Override - public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { - StepFunction.openExpression(valueExpression, context); - } + public StepSingleValueMapFunction(Expression valueExpression) { + this.valueExpression = valueExpression; + } - @Override - public void finish(StepCollector collector) { + @Override + public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { + StepFunction.openExpression(valueExpression, context); + } - } + @Override + public void finish(StepCollector collector) {} - @Override - public Row map(Row record) { - Object value = valueExpression.evaluate(record); - return ObjectRow.create(value); - } + @Override + public Row map(Row record) { + Object value = valueExpression.evaluate(record); + return ObjectRow.create(value); + } - @Override - public List getExpressions() { - return Collections.singletonList(valueExpression); - } + @Override + public List getExpressions() { + return Collections.singletonList(valueExpression); + } - @Override - public StepFunction copy(List expressions) { - assert expressions.size() == 1; - return new StepSingleValueMapFunction(expressions.get(0)); - } + @Override + public StepFunction copy(List expressions) { + assert expressions.size() == 1; + return new StepSingleValueMapFunction(expressions.get(0)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepSortFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepSortFunction.java index 311fcd1bf..edcdacb4b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepSortFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepSortFunction.java @@ -24,5 +24,5 @@ public interface StepSortFunction extends StepFunction { - void process(RowVertex currentVertex, Path path); + void process(RowVertex currentVertex, Path path); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepSortFunctionImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepSortFunctionImpl.java index b6f2ce077..3c0f2e45b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepSortFunctionImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/StepSortFunctionImpl.java @@ -19,7 +19,6 @@ package org.apache.geaflow.dsl.runtime.function.graph; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; @@ -28,6 +27,7 @@ import java.util.Map; import java.util.PriorityQueue; import java.util.stream.Collectors; + import org.apache.commons.lang3.tuple.Pair; import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.Row; @@ -43,127 +43,131 @@ import org.apache.geaflow.dsl.runtime.traversal.data.VertexRecord; import org.apache.geaflow.dsl.runtime.traversal.path.TreePaths; +import com.google.common.collect.Lists; + public class StepSortFunctionImpl implements StepSortFunction { - private final SortInfo sortInfo; + private final SortInfo sortInfo; - private PriorityQueue> topNQueue; + private PriorityQueue> topNQueue; - private List> paths; + private List> paths; - private TopNRowComparator topNComparator; + private TopNRowComparator topNComparator; - private final boolean isGlobalSortFunction; + private final boolean isGlobalSortFunction; - public StepSortFunctionImpl(SortInfo sortInfo) { - this.sortInfo = sortInfo; - this.isGlobalSortFunction = false; - } + public StepSortFunctionImpl(SortInfo sortInfo) { + this.sortInfo = sortInfo; + this.isGlobalSortFunction = false; + } - private StepSortFunctionImpl(SortInfo sortInfo, boolean isGlobalSortFunction) { - this.sortInfo = sortInfo; - this.isGlobalSortFunction = isGlobalSortFunction; - } + private StepSortFunctionImpl(SortInfo sortInfo, boolean isGlobalSortFunction) { + this.sortInfo = sortInfo; + this.isGlobalSortFunction = isGlobalSortFunction; + } - public StepSortFunctionImpl copy(boolean isGlobalSortFunction) { - return new StepSortFunctionImpl(sortInfo, isGlobalSortFunction); - } + public StepSortFunctionImpl copy(boolean isGlobalSortFunction) { + return new StepSortFunctionImpl(sortInfo, isGlobalSortFunction); + } - @Override - public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { - this.topNComparator = new TopNRowComparator<>(sortInfo); - if (sortInfo.fetch > 0) { - this.topNQueue = new PriorityQueue<>(sortInfo.fetch, - new RowPairComparator(topNComparator.getNegativeComparator())); - } else { - this.paths = new ArrayList<>(); - } + @Override + public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { + this.topNComparator = new TopNRowComparator<>(sortInfo); + if (sortInfo.fetch > 0) { + this.topNQueue = + new PriorityQueue<>( + sortInfo.fetch, new RowPairComparator(topNComparator.getNegativeComparator())); + } else { + this.paths = new ArrayList<>(); } - - @Override - public List getExpressions() { - return sortInfo.orderByFields.stream() - .map(field -> field.expression) - .collect(Collectors.toList()); + } + + @Override + public List getExpressions() { + return sortInfo.orderByFields.stream() + .map(field -> field.expression) + .collect(Collectors.toList()); + } + + @Override + public StepFunction copy(List expressions) { + assert sortInfo.orderByFields.size() == expressions.size(); + + List newOrderByFields = new ArrayList<>(sortInfo.orderByFields.size()); + for (int i = 0; i < expressions.size(); i++) { + OrderByField newOrderByField = sortInfo.orderByFields.get(i).copy(expressions.get(i)); + newOrderByFields.add(newOrderByField); } + return new StepSortFunctionImpl(sortInfo.copy(newOrderByFields), isGlobalSortFunction); + } - @Override - public StepFunction copy(List expressions) { - assert sortInfo.orderByFields.size() == expressions.size(); - - List newOrderByFields = new ArrayList<>(sortInfo.orderByFields.size()); - for (int i = 0; i < expressions.size(); i++) { - OrderByField newOrderByField = sortInfo.orderByFields.get(i).copy(expressions.get(i)); - newOrderByFields.add(newOrderByField); - } - return new StepSortFunctionImpl(sortInfo.copy(newOrderByFields), isGlobalSortFunction); + @Override + public void process(RowVertex currentVertex, Path path) { + if (sortInfo.fetch == 0) { + return; } - - @Override - public void process(RowVertex currentVertex, Path path) { - if (sortInfo.fetch == 0) { - return; + Pair newPair = Pair.of(currentVertex, path); + if (topNQueue != null) { + if (topNQueue.size() == sortInfo.fetch) { + if (sortInfo.orderByFields.isEmpty()) { + return; } - Pair newPair = Pair.of(currentVertex, path); - if (topNQueue != null) { - if (topNQueue.size() == sortInfo.fetch) { - if (sortInfo.orderByFields.isEmpty()) { - return; - } - Pair top = topNQueue.peek(); - if (topNQueue.comparator().compare(top, newPair) < 0) { - topNQueue.remove(); - topNQueue.add(newPair); - } - } else { - topNQueue.add(newPair); - } - } else { - paths.add(newPair); + Pair top = topNQueue.peek(); + if (topNQueue.comparator().compare(top, newPair) < 0) { + topNQueue.remove(); + topNQueue.add(newPair); } + } else { + topNQueue.add(newPair); + } + } else { + paths.add(newPair); } + } + + @Override + public void finish(StepCollector collector) { + List> topNPaths; + if (topNQueue != null) { + topNPaths = Lists.newArrayList(topNQueue.iterator()); + topNQueue.clear(); + } else { + topNPaths = Lists.newArrayList(paths); + paths.clear(); + } + topNPaths.sort(new RowPairComparator(topNComparator)); - @Override - public void finish(StepCollector collector) { - List> topNPaths; - if (topNQueue != null) { - topNPaths = Lists.newArrayList(topNQueue.iterator()); - topNQueue.clear(); + Map> head2PathsMap = new HashMap<>(); + for (Pair pair : topNPaths) { + head2PathsMap.computeIfAbsent(pair.getKey(), x -> new ArrayList<>()); + head2PathsMap.get(pair.getKey()).add(pair.getValue()); + } + for (RowVertex currentVertex : head2PathsMap.keySet()) { + for (Path path : head2PathsMap.get(currentVertex)) { + if (isGlobalSortFunction) { + collector.collect( + VertexRecord.of( + currentVertex, TreePaths.createTreePath(Collections.singletonList(path)))); } else { - topNPaths = Lists.newArrayList(paths); - paths.clear(); - } - topNPaths.sort(new RowPairComparator(topNComparator)); - - Map> head2PathsMap = new HashMap<>(); - for (Pair pair : topNPaths) { - head2PathsMap.computeIfAbsent(pair.getKey(), x -> new ArrayList<>()); - head2PathsMap.get(pair.getKey()).add(pair.getValue()); - } - for (RowVertex currentVertex : head2PathsMap.keySet()) { - for (Path path : head2PathsMap.get(currentVertex)) { - if (isGlobalSortFunction) { - collector.collect(VertexRecord.of(currentVertex, - TreePaths.createTreePath(Collections.singletonList(path)))); - } else { - Row row = ObjectRow.create(currentVertex, path); - collector.collect(row); - } - } + Row row = ObjectRow.create(currentVertex, path); + collector.collect(row); } + } } + } - protected class RowPairComparator implements Comparator> { + protected class RowPairComparator implements Comparator> { - private final Comparator comparator; + private final Comparator comparator; - public RowPairComparator(Comparator comparator) { - this.comparator = comparator; - } + public RowPairComparator(Comparator comparator) { + this.comparator = comparator; + } - @Override - public int compare(Pair a, Pair b) { - return comparator.compare(a.getValue(), b.getValue()); - } + @Override + public int compare(Pair a, Pair b) { + return comparator.compare(a.getValue(), b.getValue()); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/TraversalFromVertexFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/TraversalFromVertexFunction.java index 75089f5cc..f2867ea67 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/TraversalFromVertexFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/TraversalFromVertexFunction.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.Row; @@ -33,45 +34,41 @@ public class TraversalFromVertexFunction implements MatchVirtualEdgeFunction { - private final int vertexFieldIndex; - private final IType vertexFieldType; - - public TraversalFromVertexFunction(int vertexFieldIndex, IType vertexFieldType) { - this.vertexFieldIndex = vertexFieldIndex; - this.vertexFieldType = vertexFieldType; - } - - @Override - public List computeTargetId(Path path) { - Row node = path.getField(vertexFieldIndex, vertexFieldType); - RowVertex vertex = (RowVertex) node; - return Collections.singletonList(vertex.getId()); - } + private final int vertexFieldIndex; + private final IType vertexFieldType; - @Override - public ITreePath computeTargetPath(Object targetId, ITreePath currentPath) { - // just traversal from the vertex and cannot carry the paths. - return null; - } + public TraversalFromVertexFunction(int vertexFieldIndex, IType vertexFieldType) { + this.vertexFieldIndex = vertexFieldIndex; + this.vertexFieldType = vertexFieldType; + } - @Override - public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { + @Override + public List computeTargetId(Path path) { + Row node = path.getField(vertexFieldIndex, vertexFieldType); + RowVertex vertex = (RowVertex) node; + return Collections.singletonList(vertex.getId()); + } - } + @Override + public ITreePath computeTargetPath(Object targetId, ITreePath currentPath) { + // just traversal from the vertex and cannot carry the paths. + return null; + } - @Override - public void finish(StepCollector collector) { + @Override + public void open(TraversalRuntimeContext context, FunctionSchemas schemas) {} - } + @Override + public void finish(StepCollector collector) {} - @Override - public List getExpressions() { - return Collections.emptyList(); - } + @Override + public List getExpressions() { + return Collections.emptyList(); + } - @Override - public StepFunction copy(List expressions) { - assert expressions.isEmpty(); - return new TraversalFromVertexFunction(vertexFieldIndex, vertexFieldType); - } + @Override + public StepFunction copy(List expressions) { + assert expressions.isEmpty(); + return new TraversalFromVertexFunction(vertexFieldIndex, vertexFieldType); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/source/AbstractVertexScanSourceFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/source/AbstractVertexScanSourceFunction.java index 4d35b7c92..8cb1f4952 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/source/AbstractVertexScanSourceFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/source/AbstractVertexScanSourceFunction.java @@ -19,11 +19,11 @@ package org.apache.geaflow.dsl.runtime.function.graph.source; -import com.google.common.base.Preconditions; import java.io.IOException; import java.util.Iterator; import java.util.Objects; import java.util.concurrent.atomic.AtomicInteger; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.RichFunction; import org.apache.geaflow.api.function.io.SourceFunction; @@ -46,123 +46,137 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class AbstractVertexScanSourceFunction extends RichFunction implements - SourceFunction { - - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractVertexScanSourceFunction.class); - - protected transient RuntimeContext runtimeContext; - - protected GraphViewDesc graphViewDesc; - - protected transient GraphState graphState; - - private Iterator idIterator; - - private long windSize; - - private static final AtomicInteger storeCounter = new AtomicInteger(0); - - public AbstractVertexScanSourceFunction(GraphViewDesc graphViewDesc) { - this.graphViewDesc = Objects.requireNonNull(graphViewDesc); - } - - @Override - public void open(RuntimeContext runtimeContext) { - this.runtimeContext = runtimeContext; - this.windSize = this.runtimeContext.getConfiguration().getLong(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE); - Configuration rewriteConfiguration = runtimeContext.getConfiguration(); - String jobName = rewriteConfiguration.getString(ExecutionConfigKeys.JOB_APP_NAME); - // A read-only graph copy will be created locally for the VertexScan. - // To avoid conflicts with other VertexScans or Ops, an independent copy name is - // constructed using the job name to differentiate the storage path. - rewriteConfiguration.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), - "VertexScanSourceFunction_" + jobName + "_" + storeCounter.getAndIncrement()); - GraphStateDescriptor desc = buildGraphStateDesc(); - desc.withMetricGroup(runtimeContext.getMetric()); - this.graphState = StateFactory.buildGraphState(desc, runtimeContext.getConfiguration()); - recover(); - this.idIterator = buildIdIterator(); - } - - protected abstract Iterator buildIdIterator(); - - protected void recover() { - LOGGER.info("Task: {} will do recover, windowId: {}", - this.runtimeContext.getTaskArgs().getTaskId(), this.runtimeContext.getWindowId()); - long lastCheckPointId = getLatestViewVersion(); - if (lastCheckPointId >= 0) { - LOGGER.info("Task: {} do recover to state VersionId: {}", this.runtimeContext.getTaskArgs().getTaskId(), - lastCheckPointId); - graphState.manage().operate().setCheckpointId(lastCheckPointId); - graphState.manage().operate().recover(); - } - } - - @Override - public void init(int parallel, int index) { - - } +import com.google.common.base.Preconditions; - protected GraphStateDescriptor buildGraphStateDesc() { - int taskIndex = runtimeContext.getTaskArgs().getTaskIndex(); - int taskPara = runtimeContext.getTaskArgs().getParallelism(); - BackendType backendType = graphViewDesc.getBackend(); - GraphStateDescriptor desc = GraphStateDescriptor.build(graphViewDesc.getName() - , backendType.name()); - - int maxPara = graphViewDesc.getShardNum(); - Preconditions.checkArgument(taskPara <= maxPara, - String.format("task parallelism '%s' must be <= shard num(max parallelism) '%s'", - taskPara, maxPara)); - - KeyGroup keyGroup = KeyGroupAssignment.computeKeyGroupRangeForOperatorIndex(maxPara, taskPara, taskIndex); - IKeyGroupAssigner keyGroupAssigner = - KeyGroupAssignerFactory.createKeyGroupAssigner(keyGroup, taskIndex, maxPara); - desc.withKeyGroup(keyGroup); - desc.withKeyGroupAssigner(keyGroupAssigner); - - long taskId = runtimeContext.getTaskArgs().getTaskId(); - int containerNum = runtimeContext.getConfiguration().getInteger(ExecutionConfigKeys.CONTAINER_NUM); - LOGGER.info("Task:{} taskId:{} taskIndex:{} keyGroup:{} containerNum:{} real taskIndex:{}", - this.runtimeContext.getTaskArgs().getTaskName(), - taskId, - taskIndex, - desc.getKeyGroup(), containerNum, runtimeContext.getTaskArgs().getTaskIndex()); - return desc; +public abstract class AbstractVertexScanSourceFunction extends RichFunction + implements SourceFunction { + + private static final Logger LOGGER = + LoggerFactory.getLogger(AbstractVertexScanSourceFunction.class); + + protected transient RuntimeContext runtimeContext; + + protected GraphViewDesc graphViewDesc; + + protected transient GraphState graphState; + + private Iterator idIterator; + + private long windSize; + + private static final AtomicInteger storeCounter = new AtomicInteger(0); + + public AbstractVertexScanSourceFunction(GraphViewDesc graphViewDesc) { + this.graphViewDesc = Objects.requireNonNull(graphViewDesc); + } + + @Override + public void open(RuntimeContext runtimeContext) { + this.runtimeContext = runtimeContext; + this.windSize = + this.runtimeContext.getConfiguration().getLong(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE); + Configuration rewriteConfiguration = runtimeContext.getConfiguration(); + String jobName = rewriteConfiguration.getString(ExecutionConfigKeys.JOB_APP_NAME); + // A read-only graph copy will be created locally for the VertexScan. + // To avoid conflicts with other VertexScans or Ops, an independent copy name is + // constructed using the job name to differentiate the storage path. + rewriteConfiguration.put( + ExecutionConfigKeys.JOB_APP_NAME.getKey(), + "VertexScanSourceFunction_" + jobName + "_" + storeCounter.getAndIncrement()); + GraphStateDescriptor desc = buildGraphStateDesc(); + desc.withMetricGroup(runtimeContext.getMetric()); + this.graphState = StateFactory.buildGraphState(desc, runtimeContext.getConfiguration()); + recover(); + this.idIterator = buildIdIterator(); + } + + protected abstract Iterator buildIdIterator(); + + protected void recover() { + LOGGER.info( + "Task: {} will do recover, windowId: {}", + this.runtimeContext.getTaskArgs().getTaskId(), + this.runtimeContext.getWindowId()); + long lastCheckPointId = getLatestViewVersion(); + if (lastCheckPointId >= 0) { + LOGGER.info( + "Task: {} do recover to state VersionId: {}", + this.runtimeContext.getTaskArgs().getTaskId(), + lastCheckPointId); + graphState.manage().operate().setCheckpointId(lastCheckPointId); + graphState.manage().operate().recover(); } - - protected long getLatestViewVersion() { - long lastCheckPointId; - try { - ViewMetaBookKeeper keeper = new ViewMetaBookKeeper(graphViewDesc.getName(), - this.runtimeContext.getConfiguration()); - lastCheckPointId = keeper.getLatestViewVersion(graphViewDesc.getName()); - LOGGER.info("Task: {} will do recover or load, ViewMetaBookKeeper version: {}", - runtimeContext.getTaskArgs().getTaskId(), lastCheckPointId); - } catch (IOException e) { - throw new GeaflowRuntimeException(e); - } - return lastCheckPointId; + } + + @Override + public void init(int parallel, int index) {} + + protected GraphStateDescriptor buildGraphStateDesc() { + int taskIndex = runtimeContext.getTaskArgs().getTaskIndex(); + int taskPara = runtimeContext.getTaskArgs().getParallelism(); + BackendType backendType = graphViewDesc.getBackend(); + GraphStateDescriptor desc = + GraphStateDescriptor.build(graphViewDesc.getName(), backendType.name()); + + int maxPara = graphViewDesc.getShardNum(); + Preconditions.checkArgument( + taskPara <= maxPara, + String.format( + "task parallelism '%s' must be <= shard num(max parallelism) '%s'", taskPara, maxPara)); + + KeyGroup keyGroup = + KeyGroupAssignment.computeKeyGroupRangeForOperatorIndex(maxPara, taskPara, taskIndex); + IKeyGroupAssigner keyGroupAssigner = + KeyGroupAssignerFactory.createKeyGroupAssigner(keyGroup, taskIndex, maxPara); + desc.withKeyGroup(keyGroup); + desc.withKeyGroupAssigner(keyGroupAssigner); + + long taskId = runtimeContext.getTaskArgs().getTaskId(); + int containerNum = + runtimeContext.getConfiguration().getInteger(ExecutionConfigKeys.CONTAINER_NUM); + LOGGER.info( + "Task:{} taskId:{} taskIndex:{} keyGroup:{} containerNum:{} real taskIndex:{}", + this.runtimeContext.getTaskArgs().getTaskName(), + taskId, + taskIndex, + desc.getKeyGroup(), + containerNum, + runtimeContext.getTaskArgs().getTaskIndex()); + return desc; + } + + protected long getLatestViewVersion() { + long lastCheckPointId; + try { + ViewMetaBookKeeper keeper = + new ViewMetaBookKeeper(graphViewDesc.getName(), this.runtimeContext.getConfiguration()); + lastCheckPointId = keeper.getLatestViewVersion(graphViewDesc.getName()); + LOGGER.info( + "Task: {} will do recover or load, ViewMetaBookKeeper version: {}", + runtimeContext.getTaskArgs().getTaskId(), + lastCheckPointId); + } catch (IOException e) { + throw new GeaflowRuntimeException(e); } - - @Override - public boolean fetch(IWindow window, SourceContext ctx) throws Exception { - int count = 0; - while (idIterator.hasNext()) { - K id = idIterator.next(); - IdOnlyRequest idOnlyRequest = new IdOnlyRequest(id); - ctx.collect(idOnlyRequest); - count++; - if (count == windSize) { - break; - } - } - return count == windSize; + return lastCheckPointId; + } + + @Override + public boolean fetch(IWindow window, SourceContext ctx) + throws Exception { + int count = 0; + while (idIterator.hasNext()) { + K id = idIterator.next(); + IdOnlyRequest idOnlyRequest = new IdOnlyRequest(id); + ctx.collect(idOnlyRequest); + count++; + if (count == windSize) { + break; + } } + return count == windSize; + } - @Override - public void close() { - - } + @Override + public void close() {} } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/source/DynamicGraphVertexScanSourceFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/source/DynamicGraphVertexScanSourceFunction.java index 643049fde..09dbd018a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/source/DynamicGraphVertexScanSourceFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/source/DynamicGraphVertexScanSourceFunction.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.function.graph.source; import java.util.Iterator; + import org.apache.geaflow.model.graph.meta.GraphMeta; import org.apache.geaflow.state.DataModel; import org.apache.geaflow.state.descriptor.GraphStateDescriptor; @@ -28,21 +29,21 @@ public class DynamicGraphVertexScanSourceFunction extends AbstractVertexScanSourceFunction { - public DynamicGraphVertexScanSourceFunction(GraphViewDesc graphViewDesc) { - super(graphViewDesc); - } + public DynamicGraphVertexScanSourceFunction(GraphViewDesc graphViewDesc) { + super(graphViewDesc); + } - @Override - protected Iterator buildIdIterator() { - return graphState.dynamicGraph().V().idIterator(); - } + @Override + protected Iterator buildIdIterator() { + return graphState.dynamicGraph().V().idIterator(); + } - @Override - protected GraphStateDescriptor buildGraphStateDesc() { - GraphStateDescriptor desc = super.buildGraphStateDesc(); - desc.withDataModel(DataModel.DYNAMIC_GRAPH); - desc.withStateMode(StateMode.RW); - desc.withGraphMeta(new GraphMeta(graphViewDesc.getGraphMetaType())); - return desc; - } + @Override + protected GraphStateDescriptor buildGraphStateDesc() { + GraphStateDescriptor desc = super.buildGraphStateDesc(); + desc.withDataModel(DataModel.DYNAMIC_GRAPH); + desc.withStateMode(StateMode.RW); + desc.withGraphMeta(new GraphMeta(graphViewDesc.getGraphMetaType())); + return desc; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/source/StaticGraphVertexScanSourceFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/source/StaticGraphVertexScanSourceFunction.java index fa388e51f..6b0424bb6 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/source/StaticGraphVertexScanSourceFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/graph/source/StaticGraphVertexScanSourceFunction.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.function.graph.source; import java.util.Iterator; + import org.apache.geaflow.model.graph.meta.GraphMeta; import org.apache.geaflow.state.DataModel; import org.apache.geaflow.state.descriptor.GraphStateDescriptor; @@ -28,20 +29,20 @@ public class StaticGraphVertexScanSourceFunction extends AbstractVertexScanSourceFunction { - public StaticGraphVertexScanSourceFunction(GraphViewDesc graphViewDesc) { - super(graphViewDesc); - } + public StaticGraphVertexScanSourceFunction(GraphViewDesc graphViewDesc) { + super(graphViewDesc); + } - @Override - protected Iterator buildIdIterator() { - return graphState.staticGraph().V().idIterator(); - } + @Override + protected Iterator buildIdIterator() { + return graphState.staticGraph().V().idIterator(); + } - protected GraphStateDescriptor buildGraphStateDesc() { - GraphStateDescriptor desc = super.buildGraphStateDesc(); - desc.withDataModel(DataModel.STATIC_GRAPH); - desc.withStateMode(StateMode.RW); - desc.withGraphMeta(new GraphMeta(graphViewDesc.getGraphMetaType())); - return desc; - } + protected GraphStateDescriptor buildGraphStateDesc() { + GraphStateDescriptor desc = super.buildGraphStateDesc(); + desc.withDataModel(DataModel.STATIC_GRAPH); + desc.withStateMode(StateMode.RW); + desc.withGraphMeta(new GraphMeta(graphViewDesc.getGraphMetaType())); + return desc; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/AggFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/AggFunction.java index a25ddfd44..3a9c3aa4a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/AggFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/AggFunction.java @@ -20,23 +20,24 @@ package org.apache.geaflow.dsl.runtime.function.table; import java.io.Serializable; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.function.FunctionContext; public interface AggFunction extends Serializable { - void open(FunctionContext context); + void open(FunctionContext context); - Object createAccumulator(); + Object createAccumulator(); - void add(Row row, Object accumulator); + void add(Row row, Object accumulator); - void reset(Row row, Object accumulator); + void reset(Row row, Object accumulator); - void merge(Object acc, Object otherAcc); + void merge(Object acc, Object otherAcc); - Row getValue(Object accumulator); + Row getValue(Object accumulator); - IType[] getValueTypes(); + IType[] getValueTypes(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/CorrelateFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/CorrelateFunction.java index e583c98dc..d39776664 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/CorrelateFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/CorrelateFunction.java @@ -20,17 +20,18 @@ package org.apache.geaflow.dsl.runtime.function.table; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.function.FunctionContext; public interface CorrelateFunction { - void open(FunctionContext context); + void open(FunctionContext context); - List process(Row row); + List process(Row row); - List> getLeftOutputTypes(); + List> getLeftOutputTypes(); - List> getRightOutputTypes(); + List> getRightOutputTypes(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/CorrelateFunctionImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/CorrelateFunctionImpl.java index e1812c4ec..7443e8d2d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/CorrelateFunctionImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/CorrelateFunctionImpl.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.impl.ObjectRow; @@ -31,56 +32,57 @@ public class CorrelateFunctionImpl implements CorrelateFunction { - private final UDTFExpression expression; + private final UDTFExpression expression; - private final Expression filterExpression; + private final Expression filterExpression; - private final List> leftOutputTypes; + private final List> leftOutputTypes; - private final List> rightOutputTypes; + private final List> rightOutputTypes; - public CorrelateFunctionImpl(UDTFExpression expression, - Expression filterExpression, - List> leftOutputTypes, - List> rightOutputTypes) { - this.expression = Objects.requireNonNull(expression, - "CorrelateFunctionImpl: expression is null"); - this.filterExpression = filterExpression; - this.leftOutputTypes = Objects.requireNonNull(leftOutputTypes, - "CorrelateFunctionImpl: output type is null"); - this.rightOutputTypes = Objects.requireNonNull(rightOutputTypes, - "CorrelateFunctionImpl: output type is null"); - } + public CorrelateFunctionImpl( + UDTFExpression expression, + Expression filterExpression, + List> leftOutputTypes, + List> rightOutputTypes) { + this.expression = + Objects.requireNonNull(expression, "CorrelateFunctionImpl: expression is null"); + this.filterExpression = filterExpression; + this.leftOutputTypes = + Objects.requireNonNull(leftOutputTypes, "CorrelateFunctionImpl: output type is null"); + this.rightOutputTypes = + Objects.requireNonNull(rightOutputTypes, "CorrelateFunctionImpl: output type is null"); + } - @Override - public void open(FunctionContext context) { - expression.open(context); - } + @Override + public void open(FunctionContext context) { + expression.open(context); + } - @Override - public List process(Row row) { - List results = new ArrayList<>(); - for (Object[] value : (List) expression.evaluate(row)) { - Row newRow = ObjectRow.create(value); - if (filterExpression == null) { - results.add(newRow); - } else { - Object accept = filterExpression.evaluate(newRow); - if (accept != null && ((Boolean) accept)) { - results.add(newRow); - } - } + @Override + public List process(Row row) { + List results = new ArrayList<>(); + for (Object[] value : (List) expression.evaluate(row)) { + Row newRow = ObjectRow.create(value); + if (filterExpression == null) { + results.add(newRow); + } else { + Object accept = filterExpression.evaluate(newRow); + if (accept != null && ((Boolean) accept)) { + results.add(newRow); } - return results; + } } + return results; + } - @Override - public List> getLeftOutputTypes() { - return leftOutputTypes; - } + @Override + public List> getLeftOutputTypes() { + return leftOutputTypes; + } - @Override - public List> getRightOutputTypes() { - return rightOutputTypes; - } + @Override + public List> getRightOutputTypes() { + return rightOutputTypes; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/ExpandFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/ExpandFunction.java index 36842fee1..f39feae45 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/ExpandFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/ExpandFunction.java @@ -20,9 +20,10 @@ package org.apache.geaflow.dsl.runtime.function.table; import java.io.Serializable; + import org.apache.geaflow.dsl.common.data.Row; public interface ExpandFunction extends Serializable { - Iterable expand(Row row); + Iterable expand(Row row); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/GroupByFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/GroupByFunction.java index f055622eb..7f6c85592 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/GroupByFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/GroupByFunction.java @@ -20,15 +20,16 @@ package org.apache.geaflow.dsl.runtime.function.table; import java.io.Serializable; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.RowKey; public interface GroupByFunction extends Serializable { - RowKey getRowKey(Row row); + RowKey getRowKey(Row row); - int[] getKeyFieldIndices(); + int[] getKeyFieldIndices(); - IType[] getFieldTypes(); + IType[] getFieldTypes(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/GroupByFunctionImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/GroupByFunctionImpl.java index dba4302fc..c133e293f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/GroupByFunctionImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/GroupByFunctionImpl.java @@ -28,37 +28,37 @@ public class GroupByFunctionImpl implements GroupByFunction { - private final int[] keyFieldIndices; + private final int[] keyFieldIndices; - private final IType[] keyFieldTypes; + private final IType[] keyFieldTypes; - public GroupByFunctionImpl(int[] keyFieldIndices, IType[] keyFieldTypes) { - assert keyFieldIndices.length == keyFieldTypes.length; + public GroupByFunctionImpl(int[] keyFieldIndices, IType[] keyFieldTypes) { + assert keyFieldIndices.length == keyFieldTypes.length; - this.keyFieldIndices = keyFieldIndices; - this.keyFieldTypes = keyFieldTypes; - } + this.keyFieldIndices = keyFieldIndices; + this.keyFieldTypes = keyFieldTypes; + } - @Override - public RowKey getRowKey(Row row) { - Object[] keys = new Object[keyFieldIndices.length]; - for (int i = 0; i < keys.length; i++) { - keys[i] = row.getField(keyFieldIndices[i], keyFieldTypes[i]); - } - RowKey key = ObjectRowKey.of(keys); - if (row instanceof ParameterizedRow) { - return new DefaultRowKeyWithRequestId(((ParameterizedRow) row).getRequestId(), key); - } - return key; + @Override + public RowKey getRowKey(Row row) { + Object[] keys = new Object[keyFieldIndices.length]; + for (int i = 0; i < keys.length; i++) { + keys[i] = row.getField(keyFieldIndices[i], keyFieldTypes[i]); } - - @Override - public IType[] getFieldTypes() { - return keyFieldTypes; + RowKey key = ObjectRowKey.of(keys); + if (row instanceof ParameterizedRow) { + return new DefaultRowKeyWithRequestId(((ParameterizedRow) row).getRequestId(), key); } + return key; + } - @Override - public int[] getKeyFieldIndices() { - return keyFieldIndices; - } + @Override + public IType[] getFieldTypes() { + return keyFieldTypes; + } + + @Override + public int[] getKeyFieldIndices() { + return keyFieldIndices; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/JoinTableFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/JoinTableFunction.java index d9cbf6fdc..40384fe23 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/JoinTableFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/JoinTableFunction.java @@ -20,10 +20,11 @@ package org.apache.geaflow.dsl.runtime.function.table; import java.io.Serializable; + import org.apache.calcite.sql.JoinType; import org.apache.geaflow.dsl.common.data.Row; public interface JoinTableFunction extends Serializable { - Row join(Row left, Row right, JoinType joinType); + Row join(Row left, Row right, JoinType joinType); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/OrderByFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/OrderByFunction.java index f89c7b672..dcf280652 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/OrderByFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/OrderByFunction.java @@ -20,14 +20,15 @@ package org.apache.geaflow.dsl.runtime.function.table; import java.io.Serializable; + import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.function.FunctionContext; public interface OrderByFunction extends Serializable { - void open(FunctionContext context); + void open(FunctionContext context); - void process(Row row); + void process(Row row); - Iterable finish(); + Iterable finish(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/OrderByFunctionImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/OrderByFunctionImpl.java index 140aa450f..b08505945 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/OrderByFunctionImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/OrderByFunctionImpl.java @@ -19,75 +19,77 @@ package org.apache.geaflow.dsl.runtime.function.table; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.List; import java.util.PriorityQueue; + import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.function.FunctionContext; import org.apache.geaflow.dsl.runtime.function.table.order.SortInfo; import org.apache.geaflow.dsl.runtime.function.table.order.TopNRowComparator; +import com.google.common.collect.Lists; + public class OrderByFunctionImpl implements OrderByFunction { - private final SortInfo sortInfo; + private final SortInfo sortInfo; - private PriorityQueue topNQueue; + private PriorityQueue topNQueue; - private List allRows; + private List allRows; - private TopNRowComparator topNRowComparator; + private TopNRowComparator topNRowComparator; - public OrderByFunctionImpl(SortInfo sortInfo) { - this.sortInfo = sortInfo; - } + public OrderByFunctionImpl(SortInfo sortInfo) { + this.sortInfo = sortInfo; + } - @Override - public void open(FunctionContext context) { - this.topNRowComparator = new TopNRowComparator<>(sortInfo); - if (sortInfo.fetch > 0) { - this.topNQueue = new PriorityQueue<>( - sortInfo.fetch, topNRowComparator.getNegativeComparator()); - } else { - this.allRows = new ArrayList<>(); - } + @Override + public void open(FunctionContext context) { + this.topNRowComparator = new TopNRowComparator<>(sortInfo); + if (sortInfo.fetch > 0) { + this.topNQueue = + new PriorityQueue<>(sortInfo.fetch, topNRowComparator.getNegativeComparator()); + } else { + this.allRows = new ArrayList<>(); } + } - @Override - public void process(Row row) { - if (sortInfo.fetch == 0) { - return; + @Override + public void process(Row row) { + if (sortInfo.fetch == 0) { + return; + } + if (topNQueue != null) { + if (topNQueue.size() == sortInfo.fetch) { + if (sortInfo.orderByFields.isEmpty()) { + return; } - if (topNQueue != null) { - if (topNQueue.size() == sortInfo.fetch) { - if (sortInfo.orderByFields.isEmpty()) { - return; - } - Row top = topNQueue.peek(); - if (topNQueue.comparator().compare(top, row) < 0) { - topNQueue.remove(); - topNQueue.add(row); - } - } else { - topNQueue.add(row); - } - } else { - allRows.add(row); + Row top = topNQueue.peek(); + if (topNQueue.comparator().compare(top, row) < 0) { + topNQueue.remove(); + topNQueue.add(row); } + } else { + topNQueue.add(row); + } + } else { + allRows.add(row); } + } - @Override - public Iterable finish() { - if (topNQueue != null) { - List results = Lists.newArrayList(topNQueue.iterator()); - results.sort(topNRowComparator); - topNQueue.clear(); - return results; - } else { - List sortedRows = new ArrayList<>(allRows); - sortedRows.sort(topNRowComparator); - allRows.clear(); - return sortedRows; - } + @Override + public Iterable finish() { + if (topNQueue != null) { + List results = Lists.newArrayList(topNQueue.iterator()); + results.sort(topNRowComparator); + topNQueue.clear(); + return results; + } else { + List sortedRows = new ArrayList<>(allRows); + sortedRows.sort(topNRowComparator); + allRows.clear(); + return sortedRows; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/OrderByHeapSort.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/OrderByHeapSort.java index dc57095ce..6b5d2925e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/OrderByHeapSort.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/OrderByHeapSort.java @@ -23,6 +23,7 @@ import java.util.Collections; import java.util.List; import java.util.PriorityQueue; + import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.function.FunctionContext; import org.apache.geaflow.dsl.runtime.function.table.order.SortInfo; @@ -30,47 +31,46 @@ public class OrderByHeapSort implements OrderByFunction { - private final SortInfo sortInfo; + private final SortInfo sortInfo; - private PriorityQueue topNQueue; + private PriorityQueue topNQueue; - private TopNRowComparator topNRowComparator; + private TopNRowComparator topNRowComparator; - public OrderByHeapSort(SortInfo sortInfo) { - this.sortInfo = sortInfo; - } + public OrderByHeapSort(SortInfo sortInfo) { + this.sortInfo = sortInfo; + } - @Override - public void open(FunctionContext context) { - this.topNRowComparator = new TopNRowComparator<>(sortInfo); - this.topNQueue = new PriorityQueue<>( - sortInfo.fetch, topNRowComparator.getNegativeComparator()); - } + @Override + public void open(FunctionContext context) { + this.topNRowComparator = new TopNRowComparator<>(sortInfo); + this.topNQueue = new PriorityQueue<>(sortInfo.fetch, topNRowComparator.getNegativeComparator()); + } - @Override - public void process(Row row) { - if (topNQueue.size() == sortInfo.fetch) { - if (sortInfo.orderByFields.isEmpty()) { - return; - } - Row top = topNQueue.peek(); - if (topNQueue.comparator().compare(top, row) < 0) { - topNQueue.remove(); - topNQueue.add(row); - } - } else { - topNQueue.add(row); - } + @Override + public void process(Row row) { + if (topNQueue.size() == sortInfo.fetch) { + if (sortInfo.orderByFields.isEmpty()) { + return; + } + Row top = topNQueue.peek(); + if (topNQueue.comparator().compare(top, row) < 0) { + topNQueue.remove(); + topNQueue.add(row); + } + } else { + topNQueue.add(row); } + } - @Override - public Iterable finish() { - List results = new ArrayList<>(); - while (!topNQueue.isEmpty()) { - results.add(topNQueue.remove()); - } - Collections.reverse(results); - topNQueue.clear(); - return results; + @Override + public Iterable finish() { + List results = new ArrayList<>(); + while (!topNQueue.isEmpty()) { + results.add(topNQueue.remove()); } + Collections.reverse(results); + topNQueue.clear(); + return results; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/OrderByRadixSort.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/OrderByRadixSort.java index 515d2c48a..93640f223 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/OrderByRadixSort.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/OrderByRadixSort.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.function.FunctionContext; import org.apache.geaflow.dsl.runtime.function.table.order.MultiFieldRadixSort; @@ -28,32 +29,32 @@ public class OrderByRadixSort implements OrderByFunction { - private final SortInfo sortInfo; + private final SortInfo sortInfo; - private List allRows; + private List allRows; - public OrderByRadixSort(SortInfo sortInfo) { - this.sortInfo = sortInfo; - } + public OrderByRadixSort(SortInfo sortInfo) { + this.sortInfo = sortInfo; + } - @Override - public void open(FunctionContext context) { - this.allRows = new ArrayList<>(); - } - - @Override - public void process(Row row) { - if (sortInfo.fetch == 0) { - return; - } - allRows.add(row); - } + @Override + public void open(FunctionContext context) { + this.allRows = new ArrayList<>(); + } - @Override - public Iterable finish() { - List sortedRows = new ArrayList<>(allRows); - MultiFieldRadixSort.multiFieldRadixSort(sortedRows, sortInfo); - allRows.clear(); - return sortedRows; + @Override + public void process(Row row) { + if (sortInfo.fetch == 0) { + return; } + allRows.add(row); + } + + @Override + public Iterable finish() { + List sortedRows = new ArrayList<>(allRows); + MultiFieldRadixSort.multiFieldRadixSort(sortedRows, sortInfo); + allRows.clear(); + return sortedRows; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/OrderByTimSort.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/OrderByTimSort.java index 6b4e216d2..a8d727150 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/OrderByTimSort.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/OrderByTimSort.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.function.FunctionContext; import org.apache.geaflow.dsl.runtime.function.table.order.SortInfo; @@ -28,35 +29,35 @@ public class OrderByTimSort implements OrderByFunction { - private final SortInfo sortInfo; + private final SortInfo sortInfo; - private List allRows; + private List allRows; - private TopNRowComparator topNRowComparator; + private TopNRowComparator topNRowComparator; - public OrderByTimSort(SortInfo sortInfo) { - this.sortInfo = sortInfo; - } + public OrderByTimSort(SortInfo sortInfo) { + this.sortInfo = sortInfo; + } - @Override - public void open(FunctionContext context) { - this.topNRowComparator = new TopNRowComparator<>(sortInfo); - this.allRows = new ArrayList<>(); - } - - @Override - public void process(Row row) { - if (sortInfo.fetch == 0) { - return; - } - allRows.add(row); - } + @Override + public void open(FunctionContext context) { + this.topNRowComparator = new TopNRowComparator<>(sortInfo); + this.allRows = new ArrayList<>(); + } - @Override - public Iterable finish() { - List sortedRows = new ArrayList<>(allRows); - sortedRows.sort(topNRowComparator); - allRows.clear(); - return sortedRows; + @Override + public void process(Row row) { + if (sortInfo.fetch == 0) { + return; } + allRows.add(row); + } + + @Override + public Iterable finish() { + List sortedRows = new ArrayList<>(allRows); + sortedRows.sort(topNRowComparator); + allRows.clear(); + return sortedRows; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/ProjectFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/ProjectFunction.java index 5e7a4d5d5..710294611 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/ProjectFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/ProjectFunction.java @@ -20,9 +20,10 @@ package org.apache.geaflow.dsl.runtime.function.table; import java.io.Serializable; + import org.apache.geaflow.dsl.common.data.Row; public interface ProjectFunction extends Serializable { - Row project(Row row); + Row project(Row row); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/ProjectFunctionImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/ProjectFunctionImpl.java index 3ae549681..0fa8f73dc 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/ProjectFunctionImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/ProjectFunctionImpl.java @@ -20,24 +20,25 @@ package org.apache.geaflow.dsl.runtime.function.table; import java.util.List; + import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.impl.ObjectRow; import org.apache.geaflow.dsl.runtime.expression.Expression; public class ProjectFunctionImpl implements ProjectFunction { - private final List projects; + private final List projects; - public ProjectFunctionImpl(List projects) { - this.projects = projects; - } + public ProjectFunctionImpl(List projects) { + this.projects = projects; + } - @Override - public Row project(Row row) { - Object[] values = new Object[projects.size()]; - for (int i = 0; i < values.length; i++) { - values[i] = projects.get(i).evaluate(row); - } - return ObjectRow.create(values); + @Override + public Row project(Row row) { + Object[] values = new Object[projects.size()]; + for (int i = 0; i < values.length; i++) { + values[i] = projects.get(i).evaluate(row); } + return ObjectRow.create(values); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/TableDecodeFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/TableDecodeFunction.java index 3a383bb29..73a30551c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/TableDecodeFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/TableDecodeFunction.java @@ -20,9 +20,10 @@ package org.apache.geaflow.dsl.runtime.function.table; import java.io.Serializable; + import org.apache.geaflow.dsl.common.data.Row; public interface TableDecodeFunction extends Serializable { - Row decode(Row row); + Row decode(Row row); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/TableDecodeFunctionImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/TableDecodeFunctionImpl.java index cb1100069..f4566f6f7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/TableDecodeFunctionImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/TableDecodeFunctionImpl.java @@ -26,14 +26,14 @@ public class TableDecodeFunctionImpl implements TableDecodeFunction { - private final RowDecoder rowDecoder; + private final RowDecoder rowDecoder; - public TableDecodeFunctionImpl(StructType rowType) { - this.rowDecoder = new DefaultRowDecoder(rowType); - } + public TableDecodeFunctionImpl(StructType rowType) { + this.rowDecoder = new DefaultRowDecoder(rowType); + } - @Override - public Row decode(Row row) { - return rowDecoder.decode(row); - } + @Override + public Row decode(Row row) { + return rowDecoder.decode(row); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/WhereFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/WhereFunction.java index d1b6b6d6c..3b05fdd2c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/WhereFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/WhereFunction.java @@ -20,12 +20,13 @@ package org.apache.geaflow.dsl.runtime.function.table; import java.io.Serializable; + import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.runtime.expression.Expression; public interface WhereFunction extends Serializable { - boolean filter(Row row); + boolean filter(Row row); - Expression getCondition(); + Expression getCondition(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/WhereFunctionImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/WhereFunctionImpl.java index 7f2eee7be..c7ce84cac 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/WhereFunctionImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/WhereFunctionImpl.java @@ -24,20 +24,20 @@ public class WhereFunctionImpl implements WhereFunction { - private final Expression condition; + private final Expression condition; - public WhereFunctionImpl(Expression condition) { - this.condition = condition; - } + public WhereFunctionImpl(Expression condition) { + this.condition = condition; + } - @Override - public boolean filter(Row row) { - Boolean accept = (Boolean) condition.evaluate(row); - return accept != null && accept; - } + @Override + public boolean filter(Row row) { + Boolean accept = (Boolean) condition.evaluate(row); + return accept != null && accept; + } - @Override - public Expression getCondition() { - return condition; - } + @Override + public Expression getCondition() { + return condition; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/order/MultiFieldRadixSort.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/order/MultiFieldRadixSort.java index 223ebecb3..8bbb0a31b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/order/MultiFieldRadixSort.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/order/MultiFieldRadixSort.java @@ -20,245 +20,236 @@ package org.apache.geaflow.dsl.runtime.function.table.order; import java.util.List; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.dsl.common.data.Row; public class MultiFieldRadixSort { - private static final ThreadLocal dataSize = new ThreadLocal<>(); - private static final ThreadLocal intValues = new ThreadLocal<>(); - private static final ThreadLocal sortedIntValues = new ThreadLocal<>(); - private static final ThreadLocal charCodes = new ThreadLocal<>(); - private static final ThreadLocal digits = new ThreadLocal<>(); - private static final ThreadLocal stringValues = new ThreadLocal<>(); - private static final ThreadLocal sortedStringValues = new ThreadLocal<>(); - private static final ThreadLocal srcData = new ThreadLocal<>(); - private static final ThreadLocal dstData = new ThreadLocal<>(); - - /** - * Multi-field radix sort. - */ - public static void multiFieldRadixSort(List data, SortInfo sortInfo) { - if (data == null || data.size() <= 1) { - return; - } - int size = data.size(); - - try { - dataSize.set(size); - intValues.set(new int[size]); - sortedIntValues.set(new int[size]); - charCodes.set(new int[size]); - digits.set(new byte[size]); - stringValues.set(new String[size]); - sortedStringValues.set(new String[size]); - srcData.set(data.toArray(new Row[0])); - dstData.set(new Row[size]); - - // Sort by field with the lowest priority. - List fields = sortInfo.orderByFields; - - for (int i = fields.size() - 1; i >= 0; i--) { - OrderByField field = fields.get(i); - if (field.expression.getOutputType().getTypeClass() == Integer.class) { - radixSortByIntField(field); - } else { - radixSortByStringField(field); - } - } - - Row[] finalData = srcData.get(); - for (int j = 0; j < size; j++) { - data.set(j, finalData[j]); - } - } finally { - dataSize.remove(); - intValues.remove(); - sortedIntValues.remove(); - charCodes.remove(); - digits.remove(); - stringValues.remove(); - sortedStringValues.remove(); - srcData.remove(); - dstData.remove(); + private static final ThreadLocal dataSize = new ThreadLocal<>(); + private static final ThreadLocal intValues = new ThreadLocal<>(); + private static final ThreadLocal sortedIntValues = new ThreadLocal<>(); + private static final ThreadLocal charCodes = new ThreadLocal<>(); + private static final ThreadLocal digits = new ThreadLocal<>(); + private static final ThreadLocal stringValues = new ThreadLocal<>(); + private static final ThreadLocal sortedStringValues = new ThreadLocal<>(); + private static final ThreadLocal srcData = new ThreadLocal<>(); + private static final ThreadLocal dstData = new ThreadLocal<>(); + + /** Multi-field radix sort. */ + public static void multiFieldRadixSort(List data, SortInfo sortInfo) { + if (data == null || data.size() <= 1) { + return; + } + int size = data.size(); + + try { + dataSize.set(size); + intValues.set(new int[size]); + sortedIntValues.set(new int[size]); + charCodes.set(new int[size]); + digits.set(new byte[size]); + stringValues.set(new String[size]); + sortedStringValues.set(new String[size]); + srcData.set(data.toArray(new Row[0])); + dstData.set(new Row[size]); + + // Sort by field with the lowest priority. + List fields = sortInfo.orderByFields; + + for (int i = fields.size() - 1; i >= 0; i--) { + OrderByField field = fields.get(i); + if (field.expression.getOutputType().getTypeClass() == Integer.class) { + radixSortByIntField(field); + } else { + radixSortByStringField(field); } + } + + Row[] finalData = srcData.get(); + for (int j = 0; j < size; j++) { + data.set(j, finalData[j]); + } + } finally { + dataSize.remove(); + intValues.remove(); + sortedIntValues.remove(); + charCodes.remove(); + digits.remove(); + stringValues.remove(); + sortedStringValues.remove(); + srcData.remove(); + dstData.remove(); + } + } + + /** Radix sort by integer field. */ + private static void radixSortByIntField(OrderByField field) { + int size = dataSize.get(); + int[] intVals = intValues.get(); + byte[] digs = digits.get(); + Row[] src = srcData.get(); + + // Determine the number of digits. + int max = Integer.MIN_VALUE; + int min = Integer.MAX_VALUE; + boolean hasNull = false; + + for (int i = 0; i < size; i++) { + Integer value = (Integer) field.expression.evaluate(src[i]); + if (value != null) { + intVals[i] = value; + max = value > max ? value : max; + min = value < min ? value : min; + } else { + intVals[i] = Integer.MIN_VALUE; + hasNull = true; + } + } + if (hasNull) { + min--; } - /** - * Radix sort by integer field. - */ - private static void radixSortByIntField(OrderByField field) { - int size = dataSize.get(); - int[] intVals = intValues.get(); - byte[] digs = digits.get(); - Row[] src = srcData.get(); - - // Determine the number of digits. - int max = Integer.MIN_VALUE; - int min = Integer.MAX_VALUE; - boolean hasNull = false; - - for (int i = 0; i < size; i++) { - Integer value = (Integer) field.expression.evaluate(src[i]); - if (value != null) { - intVals[i] = value; - max = value > max ? value : max; - min = value < min ? value : min; - } else { - intVals[i] = Integer.MIN_VALUE; - hasNull = true; - } - } - if (hasNull) { - min--; - } - - // Handling negative numbers: Add the offset to all numbers to make them positive. - final int offset = min < 0 ? -min : 0; - max += offset; - - for (int i = 0; i < size; i++) { - if (intVals[i] == Integer.MIN_VALUE) { - intVals[i] = min; - } - intVals[i] += offset; - } + // Handling negative numbers: Add the offset to all numbers to make them positive. + final int offset = min < 0 ? -min : 0; + max += offset; - // Bitwise sorting. - for (int exp = 1; max / exp > 0; exp *= 10) { - for (int j = 0; j < size; j++) { - digs[j] = (byte) (intVals[j] / exp % 10); - } - countingSortByDigit(field.order.value > 0); - } + for (int i = 0; i < size; i++) { + if (intVals[i] == Integer.MIN_VALUE) { + intVals[i] = min; + } + intVals[i] += offset; } - /** - * Radix sorting by string field. - */ - private static void radixSortByStringField(OrderByField field) { - int size = dataSize.get(); - String[] strVals = stringValues.get(); - Row[] src = srcData.get(); - - // Precompute all strings to avoid repeated evaluation and toString. - int maxLength = 0; - - for (int i = 0; i < size; i++) { - BinaryString binaryString = (BinaryString) field.expression.evaluate(src[i]); - strVals[i] = binaryString != null ? binaryString.toString() : ""; - maxLength = Math.max(maxLength, strVals[i].length()); - } - - // Sort from the last digit of the string. - for (int pos = maxLength - 1; pos >= 0; pos--) { - countingSortByChar(field.order.value > 0, pos); - } + // Bitwise sorting. + for (int exp = 1; max / exp > 0; exp *= 10) { + for (int j = 0; j < size; j++) { + digs[j] = (byte) (intVals[j] / exp % 10); + } + countingSortByDigit(field.order.value > 0); } + } - /** - * Sort by the specified number of digits (integer). - */ - private static void countingSortByDigit(boolean ascending) { - int size = dataSize.get(); - byte[] digs = digits.get(); - int[] intVals = intValues.get(); - int[] sortedIntVals = sortedIntValues.get(); - Row[] src = srcData.get(); - Row[] dst = dstData.get(); - - int[] count = new int[10]; - - // Count the number of times each number appears. - for (int i = 0; i < size; i++) { - count[digs[i]]++; - } + /** Radix sorting by string field. */ + private static void radixSortByStringField(OrderByField field) { + int size = dataSize.get(); + String[] strVals = stringValues.get(); + Row[] src = srcData.get(); - // Calculate cumulative count. - if (ascending) { - for (int i = 1; i < 10; i++) { - count[i] += count[i - 1]; - } - } else { - for (int i = 8; i >= 0; i--) { - count[i] += count[i + 1]; - } - } + // Precompute all strings to avoid repeated evaluation and toString. + int maxLength = 0; - // Build the output array from back to front (to ensure stability). - for (int i = size - 1; i >= 0; i--) { - int index = --count[digs[i]]; - dst[index] = src[i]; - sortedIntVals[index] = intVals[i]; - } + for (int i = 0; i < size; i++) { + BinaryString binaryString = (BinaryString) field.expression.evaluate(src[i]); + strVals[i] = binaryString != null ? binaryString.toString() : ""; + maxLength = Math.max(maxLength, strVals[i].length()); + } - int[] intTmp = intVals; - intValues.set(sortedIntVals); - sortedIntValues.set(intTmp); - - Row[] rowTmp = src; - srcData.set(dst); - dstData.set(rowTmp); + // Sort from the last digit of the string. + for (int pos = maxLength - 1; pos >= 0; pos--) { + countingSortByChar(field.order.value > 0, pos); + } + } + + /** Sort by the specified number of digits (integer). */ + private static void countingSortByDigit(boolean ascending) { + int size = dataSize.get(); + byte[] digs = digits.get(); + int[] intVals = intValues.get(); + int[] sortedIntVals = sortedIntValues.get(); + Row[] src = srcData.get(); + Row[] dst = dstData.get(); + + int[] count = new int[10]; + + // Count the number of times each number appears. + for (int i = 0; i < size; i++) { + count[digs[i]]++; } - /** - * Sort by the specified number of digits (string). - */ - private static void countingSortByChar(boolean ascending, int pos) { - int size = dataSize.get(); - String[] strVals = stringValues.get(); - String[] sortedStrVals = sortedStringValues.get(); - int[] charCds = charCodes.get(); - Row[] src = srcData.get(); - Row[] dst = dstData.get(); - - // Precompute all strings and character codes to avoid repeated evaluate and toString. - int minChar = Integer.MAX_VALUE; - int maxChar = Integer.MIN_VALUE; - - for (int i = 0; i < size; i++) { - String value = strVals[i]; - if (pos < value.length()) { - int charCode = value.codePointAt(pos); - charCds[i] = charCode; - minChar = Math.min(minChar, charCode); - maxChar = Math.max(maxChar, charCode); - } - } - int range = maxChar - minChar + 2; - int[] count = new int[range]; - - for (int i = 0; i < size; i++) { - if (pos < strVals[i].length()) { - charCds[i] -= (minChar - 1); - } else { - charCds[i] = 0; // null character - } - count[charCds[i]]++; - } + // Calculate cumulative count. + if (ascending) { + for (int i = 1; i < 10; i++) { + count[i] += count[i - 1]; + } + } else { + for (int i = 8; i >= 0; i--) { + count[i] += count[i + 1]; + } + } - if (ascending) { - for (int i = 1; i < range; i++) { - count[i] += count[i - 1]; - } - } else { - for (int i = range - 2; i >= 0; i--) { - count[i] += count[i + 1]; - } - } + // Build the output array from back to front (to ensure stability). + for (int i = size - 1; i >= 0; i--) { + int index = --count[digs[i]]; + dst[index] = src[i]; + sortedIntVals[index] = intVals[i]; + } - for (int i = size - 1; i >= 0; i--) { - int index = --count[charCds[i]]; - dst[index] = src[i]; - sortedStrVals[index] = strVals[i]; - } + int[] intTmp = intVals; + intValues.set(sortedIntVals); + sortedIntValues.set(intTmp); + + Row[] rowTmp = src; + srcData.set(dst); + dstData.set(rowTmp); + } + + /** Sort by the specified number of digits (string). */ + private static void countingSortByChar(boolean ascending, int pos) { + int size = dataSize.get(); + String[] strVals = stringValues.get(); + String[] sortedStrVals = sortedStringValues.get(); + int[] charCds = charCodes.get(); + Row[] src = srcData.get(); + Row[] dst = dstData.get(); + + // Precompute all strings and character codes to avoid repeated evaluate and toString. + int minChar = Integer.MAX_VALUE; + int maxChar = Integer.MIN_VALUE; + + for (int i = 0; i < size; i++) { + String value = strVals[i]; + if (pos < value.length()) { + int charCode = value.codePointAt(pos); + charCds[i] = charCode; + minChar = Math.min(minChar, charCode); + maxChar = Math.max(maxChar, charCode); + } + } + int range = maxChar - minChar + 2; + int[] count = new int[range]; + + for (int i = 0; i < size; i++) { + if (pos < strVals[i].length()) { + charCds[i] -= (minChar - 1); + } else { + charCds[i] = 0; // null character + } + count[charCds[i]]++; + } - String[] stringTmp = strVals; - stringValues.set(sortedStrVals); - sortedStringValues.set(stringTmp); - - Row[] rowTmp = src; - srcData.set(dst); - dstData.set(rowTmp); + if (ascending) { + for (int i = 1; i < range; i++) { + count[i] += count[i - 1]; + } + } else { + for (int i = range - 2; i >= 0; i--) { + count[i] += count[i + 1]; + } } -} \ No newline at end of file + + for (int i = size - 1; i >= 0; i--) { + int index = --count[charCds[i]]; + dst[index] = src[i]; + sortedStrVals[index] = strVals[i]; + } + + String[] stringTmp = strVals; + stringValues.set(sortedStrVals); + sortedStringValues.set(stringTmp); + + Row[] rowTmp = src; + srcData.set(dst); + dstData.set(rowTmp); + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/order/OrderByField.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/order/OrderByField.java index 756f31821..f023c031e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/order/OrderByField.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/order/OrderByField.java @@ -20,31 +20,31 @@ package org.apache.geaflow.dsl.runtime.function.table.order; import java.io.Serializable; + import org.apache.geaflow.dsl.runtime.expression.Expression; public class OrderByField implements Serializable { - public Expression expression; - - public ORDER order; + public Expression expression; - public enum ORDER { + public ORDER order; - ASC(1), + public enum ORDER { + ASC(1), - DESC(-1); + DESC(-1); - public final int value; - - ORDER(int value) { - this.value = value; - } - } + public final int value; - public OrderByField copy(Expression expression) { - OrderByField field = new OrderByField(); - field.expression = expression; - field.order = order; - return field; + ORDER(int value) { + this.value = value; } + } + + public OrderByField copy(Expression expression) { + OrderByField field = new OrderByField(); + field.expression = expression; + field.order = order; + return field; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/order/SortInfo.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/order/SortInfo.java index c90c9e21d..1d1f9c3d7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/order/SortInfo.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/order/SortInfo.java @@ -19,41 +19,44 @@ package org.apache.geaflow.dsl.runtime.function.table.order; -import com.google.common.collect.Lists; import java.io.Serializable; import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.primitive.BinaryStringType; +import com.google.common.collect.Lists; + public class SortInfo implements Serializable { - public List orderByFields = new ArrayList<>(); + public List orderByFields = new ArrayList<>(); - public int fetch = -1; + public int fetch = -1; - public SortInfo copy(List orderByFields) { - SortInfo sortInfo = new SortInfo(); - sortInfo.orderByFields = Lists.newArrayList(orderByFields); - sortInfo.fetch = this.fetch; - return sortInfo; - } + public SortInfo copy(List orderByFields) { + SortInfo sortInfo = new SortInfo(); + sortInfo.orderByFields = Lists.newArrayList(orderByFields); + sortInfo.fetch = this.fetch; + return sortInfo; + } - public boolean isRadixSortable() { - for (int i = 0; i < this.orderByFields.size(); i++) { - OrderByField field = this.orderByFields.get(i); - IType orderType = field.expression.getOutputType(); - if (orderType.getTypeClass() != Integer.class && orderType.getTypeClass() != BinaryString.class) { - return false; - } else if (orderType.getTypeClass() == BinaryString.class) { - int precision = ((BinaryStringType) orderType).getPrecision(); - // MongoDB ObjectId: 24-character hexadecimal - if (precision > 24 || precision < 0) { - return false; - } - } + public boolean isRadixSortable() { + for (int i = 0; i < this.orderByFields.size(); i++) { + OrderByField field = this.orderByFields.get(i); + IType orderType = field.expression.getOutputType(); + if (orderType.getTypeClass() != Integer.class + && orderType.getTypeClass() != BinaryString.class) { + return false; + } else if (orderType.getTypeClass() == BinaryString.class) { + int precision = ((BinaryStringType) orderType).getPrecision(); + // MongoDB ObjectId: 24-character hexadecimal + if (precision > 24 || precision < 0) { + return false; } - return true; + } } + return true; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/order/TopNRowComparator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/order/TopNRowComparator.java index ed50ea4f4..fa5bfe08d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/order/TopNRowComparator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/function/table/order/TopNRowComparator.java @@ -21,50 +21,51 @@ import java.util.Comparator; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; public class TopNRowComparator implements Comparator { - private final SortInfo sortInfo; - - public TopNRowComparator(SortInfo sortInfo) { - this.sortInfo = sortInfo; - } + private final SortInfo sortInfo; - public Comparator getNegativeComparator() { - return new NegativeTopNRowComparator(this.sortInfo); - } + public TopNRowComparator(SortInfo sortInfo) { + this.sortInfo = sortInfo; + } - private class NegativeTopNRowComparator extends TopNRowComparator { + public Comparator getNegativeComparator() { + return new NegativeTopNRowComparator(this.sortInfo); + } - public NegativeTopNRowComparator(SortInfo sortInfo) { - super(sortInfo); - } + private class NegativeTopNRowComparator extends TopNRowComparator { - @Override - public int compare(IN a, IN b) { - return -super.compare(a, b); - } + public NegativeTopNRowComparator(SortInfo sortInfo) { + super(sortInfo); } @Override public int compare(IN a, IN b) { - List fields = sortInfo.orderByFields; + return -super.compare(a, b); + } + } + + @Override + public int compare(IN a, IN b) { + List fields = sortInfo.orderByFields; - Object[] aOrders = new Object[fields.size()]; - Object[] bOrders = new Object[fields.size()]; + Object[] aOrders = new Object[fields.size()]; + Object[] bOrders = new Object[fields.size()]; - for (int i = 0; i < fields.size(); i++) { - OrderByField field = fields.get(i); - aOrders[i] = field.expression.evaluate(a); - bOrders[i] = field.expression.evaluate(b); - IType orderType = field.expression.getOutputType(); - int comparator = orderType.compare(aOrders[i], bOrders[i]); - if (comparator != 0) { - return comparator * (fields.get(i).order.value); - } - } - return 0; + for (int i = 0; i < fields.size(); i++) { + OrderByField field = fields.get(i); + aOrders[i] = field.expression.evaluate(a); + bOrders[i] = field.expression.evaluate(b); + IType orderType = field.expression.getOutputType(); + int comparator = orderType.compare(aOrders[i], bOrders[i]); + if (comparator != 0) { + return comparator * (fields.get(i).order.value); + } } + return 0; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicAggregateRelNode.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicAggregateRelNode.java index 0d88993e9..34e16a207 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicAggregateRelNode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicAggregateRelNode.java @@ -21,7 +21,6 @@ import static org.apache.geaflow.dsl.common.util.FunctionCallUtils.getUDAFGenericTypes; -import com.google.common.collect.Lists; import java.io.Serializable; import java.lang.reflect.Type; import java.util.ArrayList; @@ -32,6 +31,7 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; + import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptCost; import org.apache.calcite.plan.RelOptPlanner; @@ -86,431 +86,423 @@ import org.apache.geaflow.dsl.udf.table.agg.SumLong; import org.apache.geaflow.dsl.util.SqlTypeUtil; -public class PhysicAggregateRelNode extends Aggregate implements PhysicRelNode { +import com.google.common.collect.Lists; - public static final String UDAF_COUNT = "COUNT"; +public class PhysicAggregateRelNode extends Aggregate implements PhysicRelNode { - public static final String UDAF_SUM = "SUM"; + public static final String UDAF_COUNT = "COUNT"; + + public static final String UDAF_SUM = "SUM"; + + public static final String UDAF_AVG = "AVG"; + + public static final String UDAF_MAX = "MAX"; + + public static final String UDAF_MIN = "MIN"; + + public static final String UDAF_STDDEV_SAMP = "STDDEV_SAMP"; + + public static final String UDAF_PERCENTILE = "PERCENTILE"; + + public PhysicAggregateRelNode( + RelOptCluster cluster, + RelTraitSet traits, + RelNode child, + boolean indicator, + ImmutableBitSet groupSet, + List groupSets, + List aggCalls) { + super(cluster, traits, child, indicator, groupSet, groupSets, aggCalls); + } + + @Override + public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) { + return super.computeSelfCost(planner, mq); + } + + @Override + public Aggregate copy( + RelTraitSet traitSet, + RelNode input, + boolean indicator, + ImmutableBitSet groupSet, + List groupSets, + List aggCalls) { + return new PhysicAggregateRelNode( + getCluster(), traitSet, input, indicator, groupSet, groupSets, aggCalls); + } + + @Override + public RuntimeTable translate(QueryContext context) { + int[] keyFieldIndices = ArrayUtil.toIntArray(Lists.newArrayList(BitSets.toIter(groupSet))); + IType[] keyFieldTypes = getInputFieldTypes(keyFieldIndices); + GroupByFunction groupByFn = new GroupByFunctionImpl(keyFieldIndices, keyFieldTypes); + + List aggFnCalls = buildAggFunctionCalls(); + List> aggOutputTypes = + aggCalls.stream() + .map(call -> SqlTypeUtil.convertType(call.getType())) + .collect(Collectors.toList()); + AggFunction aggFn = new AggFunctionImpl(aggFnCalls, aggOutputTypes); - public static final String UDAF_AVG = "AVG"; + // translate input node. + RDataView dataView = ((PhysicRelNode) getInput()).translate(context); - public static final String UDAF_MAX = "MAX"; + if (dataView.getType() == ViewType.TABLE) { + RuntimeTable runtimeTable = (RuntimeTable) dataView; + return runtimeTable.aggregate(groupByFn, aggFn); + } else if (dataView.getType() == ViewType.GRAPH) { + RuntimeGraph runtimeGraph = (RuntimeGraph) dataView; + return runtimeGraph.getPathTable().aggregate(groupByFn, aggFn); + } + throw new GeaFlowDSLException("DataView: " + dataView.getType() + " cannot support Aggregate"); + } - public static final String UDAF_MIN = "MIN"; + private static class AggFunctionImpl implements AggFunction { - public static final String UDAF_STDDEV_SAMP = "STDDEV_SAMP"; + private final AggFunctionCall[] aggFunctionCalls; - public static final String UDAF_PERCENTILE = "PERCENTILE"; + private final IType[] aggOutputTypes; + private UDAF[] udafs; - public PhysicAggregateRelNode( - RelOptCluster cluster, - RelTraitSet traits, - RelNode child, - boolean indicator, - ImmutableBitSet groupSet, - List groupSets, - List aggCalls) { - super(cluster, traits, child, indicator, groupSet, groupSets, aggCalls); + public AggFunctionImpl(List aggFunctionCalls, List> aggOutputTypes) { + this.aggFunctionCalls = + Objects.requireNonNull(aggFunctionCalls).toArray(new AggFunctionCall[] {}); + this.aggOutputTypes = Objects.requireNonNull(aggOutputTypes).toArray(new IType[] {}); } + @SuppressWarnings("unchecked") @Override - public RelOptCost computeSelfCost(RelOptPlanner planner, - RelMetadataQuery mq) { - return super.computeSelfCost(planner, mq); + public void open(FunctionContext context) { + udafs = new UDAF[aggFunctionCalls.length]; + for (int i = 0; i < aggFunctionCalls.length; i++) { + try { + udafs[i] = + (UDAF) aggFunctionCalls[i].getUdafClass().newInstance(); + if (aggFunctionCalls[i].isDistinct) { + udafs[i] = new DistinctUDAF(udafs[i]); + } + udafs[i].open(context); + } catch (InstantiationException | IllegalAccessException e) { + throw new GeaFlowDSLException( + "Error in create UDAF: " + aggFunctionCalls[i].getUdafClass()); + } + } } @Override - public Aggregate copy( - RelTraitSet traitSet, - RelNode input, - boolean indicator, - ImmutableBitSet groupSet, - List groupSets, - List aggCalls) { - return new PhysicAggregateRelNode(getCluster(), traitSet, input, indicator, - groupSet, groupSets, aggCalls); + public Accumulator createAccumulator() { + Object[] accumulators = new Object[udafs.length]; + for (int i = 0; i < accumulators.length; i++) { + accumulators[i] = udafs[i].createAccumulator(); + } + return new Accumulator(accumulators); } @Override - public RuntimeTable translate(QueryContext context) { - int[] keyFieldIndices = ArrayUtil.toIntArray(Lists.newArrayList(BitSets.toIter(groupSet))); - IType[] keyFieldTypes = getInputFieldTypes(keyFieldIndices); - GroupByFunction groupByFn = new GroupByFunctionImpl(keyFieldIndices, keyFieldTypes); - - List aggFnCalls = buildAggFunctionCalls(); - List> aggOutputTypes = aggCalls.stream() - .map(call -> SqlTypeUtil.convertType(call.getType())) - .collect(Collectors.toList()); - AggFunction aggFn = new AggFunctionImpl(aggFnCalls, aggOutputTypes); - - // translate input node. - RDataView dataView = ((PhysicRelNode) getInput()).translate(context); - - if (dataView.getType() == ViewType.TABLE) { - RuntimeTable runtimeTable = (RuntimeTable) dataView; - return runtimeTable.aggregate(groupByFn, aggFn); - } else if (dataView.getType() == ViewType.GRAPH) { - RuntimeGraph runtimeGraph = (RuntimeGraph) dataView; - return runtimeGraph.getPathTable().aggregate(groupByFn, aggFn); + public void add(Row row, Object accumulator) { + for (int i = 0; i < udafs.length; i++) { + AggFunctionCall aggInfo = aggFunctionCalls[i]; + Object argValue; + if (aggInfo.argFieldIndices.length == 0) { // for count() without input parameter. + argValue = row; + } else if (aggInfo.argFieldIndices.length == 1) { + argValue = row.getField(aggInfo.argFieldIndices[0], aggInfo.argFieldTypes[0]); + } else { // for agg with multi-parameters + assert UDAFArguments.class.isAssignableFrom(aggInfo.udafInputClass); + Object[] parameters = new Object[aggInfo.argFieldIndices.length]; + for (int p = 0; p < parameters.length; p++) { + parameters[p] = row.getField(aggInfo.argFieldIndices[p], aggInfo.argFieldTypes[p]); + } + argValue = ClassUtil.newInstance(aggInfo.udafInputClass); + ((UDAFArguments) argValue).setParams(parameters); } - throw new GeaFlowDSLException("DataView: " + dataView.getType() + " cannot support Aggregate"); + udafs[i].accumulate(((Accumulator) accumulator).accumulators[i], argValue); + } } - private static class AggFunctionImpl implements AggFunction { - - private final AggFunctionCall[] aggFunctionCalls; - - private final IType[] aggOutputTypes; - - private UDAF[] udafs; - - public AggFunctionImpl(List aggFunctionCalls, List> aggOutputTypes) { - this.aggFunctionCalls = Objects.requireNonNull(aggFunctionCalls).toArray(new AggFunctionCall[]{}); - this.aggOutputTypes = Objects.requireNonNull(aggOutputTypes).toArray(new IType[]{}); - } - - @SuppressWarnings("unchecked") - @Override - public void open(FunctionContext context) { - udafs = new UDAF[aggFunctionCalls.length]; - for (int i = 0; i < aggFunctionCalls.length; i++) { - try { - udafs[i] = (UDAF) aggFunctionCalls[i].getUdafClass().newInstance(); - if (aggFunctionCalls[i].isDistinct) { - udafs[i] = new DistinctUDAF(udafs[i]); - } - udafs[i].open(context); - } catch (InstantiationException | IllegalAccessException e) { - throw new GeaFlowDSLException("Error in create UDAF: " + aggFunctionCalls[i].getUdafClass()); - } - } - } - - @Override - public Accumulator createAccumulator() { - Object[] accumulators = new Object[udafs.length]; - for (int i = 0; i < accumulators.length; i++) { - accumulators[i] = udafs[i].createAccumulator(); - } - return new Accumulator(accumulators); - } - - @Override - public void add(Row row, Object accumulator) { - for (int i = 0; i < udafs.length; i++) { - AggFunctionCall aggInfo = aggFunctionCalls[i]; - Object argValue; - if (aggInfo.argFieldIndices.length == 0) { // for count() without input parameter. - argValue = row; - } else if (aggInfo.argFieldIndices.length == 1) { - argValue = row.getField(aggInfo.argFieldIndices[0], aggInfo.argFieldTypes[0]); - } else { // for agg with multi-parameters - assert UDAFArguments.class.isAssignableFrom(aggInfo.udafInputClass); - Object[] parameters = new Object[aggInfo.argFieldIndices.length]; - for (int p = 0; p < parameters.length; p++) { - parameters[p] = row.getField(aggInfo.argFieldIndices[p], - aggInfo.argFieldTypes[p]); - } - argValue = ClassUtil.newInstance(aggInfo.udafInputClass); - ((UDAFArguments) argValue).setParams(parameters); - } - udafs[i].accumulate(((Accumulator) accumulator).accumulators[i], argValue); - } - } - - @Override - public void reset(Row row, Object accumulator) { - Accumulator acc = (Accumulator) accumulator; - for (int i = 0; i < acc.accumulators.length; i++) { - udafs[i].resetAccumulator(acc.accumulators[i]); - } - } - - @Override - public Row getValue(Object accumulator) { - Object[] aggValues = new Object[udafs.length]; - for (int i = 0; i < udafs.length; i++) { - aggValues[i] = udafs[i].getValue(((Accumulator) accumulator).accumulators[i]); - } - return ObjectRow.create(aggValues); - } - - @Override - public void merge(Object accA, Object accB) { - for (int i = 0; i < udafs.length; i++) { - udafs[i].merge( - ((Accumulator) accA).getAcc(i), - Arrays.asList(((Accumulator) accB).getAcc(i)) - ); - } - } - - @Override - public IType[] getValueTypes() { - return aggOutputTypes; - } - - private static class Accumulator implements Serializable { + @Override + public void reset(Row row, Object accumulator) { + Accumulator acc = (Accumulator) accumulator; + for (int i = 0; i < acc.accumulators.length; i++) { + udafs[i].resetAccumulator(acc.accumulators[i]); + } + } - private final Object[] accumulators; + @Override + public Row getValue(Object accumulator) { + Object[] aggValues = new Object[udafs.length]; + for (int i = 0; i < udafs.length; i++) { + aggValues[i] = udafs[i].getValue(((Accumulator) accumulator).accumulators[i]); + } + return ObjectRow.create(aggValues); + } - public Accumulator(Object[] accumulators) { - this.accumulators = accumulators; - } + @Override + public void merge(Object accA, Object accB) { + for (int i = 0; i < udafs.length; i++) { + udafs[i].merge( + ((Accumulator) accA).getAcc(i), Arrays.asList(((Accumulator) accB).getAcc(i))); + } + } - public Object getAcc(int index) { - return accumulators[index]; - } - } + @Override + public IType[] getValueTypes() { + return aggOutputTypes; } - /** - * A wrapper for distinct aggregate function. - */ - public static class DistinctUDAF extends UDAF { + private static class Accumulator implements Serializable { - private final UDAF baseUDAF; + private final Object[] accumulators; - public DistinctUDAF(UDAF baseUDAF) { - this.baseUDAF = baseUDAF; - } + public Accumulator(Object[] accumulators) { + this.accumulators = accumulators; + } - @Override - public void open(FunctionContext context) { - baseUDAF.open(context); - } + public Object getAcc(int index) { + return accumulators[index]; + } + } + } - @Override - public Object createAccumulator() { - return new HashSet<>(); - } + /** A wrapper for distinct aggregate function. */ + public static class DistinctUDAF extends UDAF { - @SuppressWarnings("unchecked") - @Override - public void accumulate(Object accumulator, Object input) { - if (input != null) { - ((Set) accumulator).add(input); - } - } + private final UDAF baseUDAF; - @SuppressWarnings("unchecked") - @Override - public void merge(Object accumulator, Iterable its) { - Set setAcc = (Set) accumulator; - for (Object it : its) { - Set otherSet = (Set) it; - setAcc.addAll(otherSet); - } - } + public DistinctUDAF(UDAF baseUDAF) { + this.baseUDAF = baseUDAF; + } - @Override - public void resetAccumulator(Object accumulator) { - ((Set) accumulator).clear(); - } + @Override + public void open(FunctionContext context) { + baseUDAF.open(context); + } - @SuppressWarnings("unchecked") - @Override - public Object getValue(Object accumulator) { - Set setAcc = (Set) accumulator; - Object acc = baseUDAF.createAccumulator(); - for (Object input : setAcc) { - baseUDAF.accumulate(acc, input); - } - return baseUDAF.getValue(acc); - } + @Override + public Object createAccumulator() { + return new HashSet<>(); } - private IType[] getInputFieldTypes(int[] fieldIndices) { - RelDataType inputType = getInput().getRowType(); - IType[] fieldTypes = new IType[fieldIndices.length]; + @SuppressWarnings("unchecked") + @Override + public void accumulate(Object accumulator, Object input) { + if (input != null) { + ((Set) accumulator).add(input); + } + } - for (int i = 0; i < fieldTypes.length; i++) { - RelDataType relType = inputType.getFieldList().get(fieldIndices[i]).getType(); - IType type = SqlTypeUtil.convertType(relType); - fieldTypes[i] = type; - } - return fieldTypes; + @SuppressWarnings("unchecked") + @Override + public void merge(Object accumulator, Iterable its) { + Set setAcc = (Set) accumulator; + for (Object it : its) { + Set otherSet = (Set) it; + setAcc.addAll(otherSet); + } } - private static class AggFunctionCall implements Serializable { - - /** - * The name of the agg function, e.g. SUM, COUNT. - */ - private final String name; - - /** - * The argument field index. - */ - private final int[] argFieldIndices; - - /** - * The argument field type. - */ - private final IType[] argFieldTypes; - - /** - * The UDAF implement class. - */ - private final Class> udafClass; - - /** - * The UDAF input class. - */ - private final Class udafInputClass; - /** - * The distinct flag. - */ - private final boolean isDistinct; - - public AggFunctionCall(String name, int[] argFieldIndices, IType[] argFieldTypes, - Class> udafClass, boolean isDistinct) { - this.name = name; - this.argFieldIndices = argFieldIndices; - this.argFieldTypes = argFieldTypes; - this.udafClass = udafClass; - Type[] genericTypes = getUDAFGenericTypes(udafClass); - this.udafInputClass = (Class) genericTypes[0]; - this.isDistinct = isDistinct; - } + @Override + public void resetAccumulator(Object accumulator) { + ((Set) accumulator).clear(); + } - public String getName() { - return name; - } + @SuppressWarnings("unchecked") + @Override + public Object getValue(Object accumulator) { + Set setAcc = (Set) accumulator; + Object acc = baseUDAF.createAccumulator(); + for (Object input : setAcc) { + baseUDAF.accumulate(acc, input); + } + return baseUDAF.getValue(acc); + } + } - public int[] getArgFieldIndices() { - return argFieldIndices; - } + private IType[] getInputFieldTypes(int[] fieldIndices) { + RelDataType inputType = getInput().getRowType(); + IType[] fieldTypes = new IType[fieldIndices.length]; - public IType[] getArgFieldTypes() { - return argFieldTypes; - } + for (int i = 0; i < fieldTypes.length; i++) { + RelDataType relType = inputType.getFieldList().get(fieldIndices[i]).getType(); + IType type = SqlTypeUtil.convertType(relType); + fieldTypes[i] = type; + } + return fieldTypes; + } + + private static class AggFunctionCall implements Serializable { + + /** The name of the agg function, e.g. SUM, COUNT. */ + private final String name; + + /** The argument field index. */ + private final int[] argFieldIndices; + + /** The argument field type. */ + private final IType[] argFieldTypes; + + /** The UDAF implement class. */ + private final Class> udafClass; + + /** The UDAF input class. */ + private final Class udafInputClass; + + /** The distinct flag. */ + private final boolean isDistinct; + + public AggFunctionCall( + String name, + int[] argFieldIndices, + IType[] argFieldTypes, + Class> udafClass, + boolean isDistinct) { + this.name = name; + this.argFieldIndices = argFieldIndices; + this.argFieldTypes = argFieldTypes; + this.udafClass = udafClass; + Type[] genericTypes = getUDAFGenericTypes(udafClass); + this.udafInputClass = (Class) genericTypes[0]; + this.isDistinct = isDistinct; + } - public Class> getUdafClass() { - return udafClass; - } + public String getName() { + return name; + } - public boolean isDistinct() { - return isDistinct; - } + public int[] getArgFieldIndices() { + return argFieldIndices; } - private List buildAggFunctionCalls() { - List aggFunctionCalls = new ArrayList<>(); - for (AggregateCall aggCall : aggCalls) { - String name = aggCall.getName(); - int[] argFieldIndices = ArrayUtil.toIntArray(aggCall.getArgList()); - IType[] argFieldTypes = getInputFieldTypes(argFieldIndices); - Class> udafClass = findUDAF(aggCall.getAggregation(), argFieldTypes); - - AggFunctionCall functionCall = new AggFunctionCall(name, argFieldIndices, argFieldTypes, udafClass, - aggCall.isDistinct()); - aggFunctionCalls.add(functionCall); - } - return aggFunctionCalls; + public IType[] getArgFieldTypes() { + return argFieldTypes; } - public static Class> findUDAF(SqlAggFunction aggFunction, IType[] argFieldTypes) { - List>> aggClasses = new ArrayList<>(); - String aggName = aggFunction.getName().toUpperCase(Locale.ROOT); - // User-defined aggregate function - if (aggFunction instanceof GeaFlowUserDefinedAggFunction) { - GeaFlowUserDefinedAggFunction function = (GeaFlowUserDefinedAggFunction) aggFunction; - aggClasses = function.getUdafClasses(); - } else { - // Build-in aggregate function - switch (aggName) { - case UDAF_COUNT: - aggClasses.add(Count.class); - break; - case UDAF_SUM: - aggClasses.add(SumLong.class); - aggClasses.add(SumDouble.class); - aggClasses.add(SumInteger.class); - break; - case UDAF_AVG: - aggClasses.add(AvgDouble.class); - aggClasses.add(AvgLong.class); - aggClasses.add(AvgInteger.class); - break; - case UDAF_MAX: - aggClasses.add(MaxLong.class); - aggClasses.add(MaxDouble.class); - aggClasses.add(MaxInteger.class); - aggClasses.add(MaxBinaryString.class); - break; - case UDAF_MIN: - aggClasses.add(MinLong.class); - aggClasses.add(MinDouble.class); - aggClasses.add(MinInteger.class); - aggClasses.add(MinBinaryString.class); - break; - case UDAF_STDDEV_SAMP: - aggClasses.add(StdDevSampLong.class); - aggClasses.add(StdDevSampDouble.class); - aggClasses.add(StdDevSampInteger.class); - break; - case UDAF_PERCENTILE: - aggClasses.add(PercentileLong.class); - aggClasses.add(PercentileInteger.class); - aggClasses.add(PercentileDouble.class); - break; - default: - throw new GeaFlowDSLException("Not support aggregate function " + aggName); - } - } + public Class> getUdafClass() { + return udafClass; + } - Class> aggClass; - if (UDAF_COUNT.equals(aggName)) { // process case for count(*) - assert aggClasses.size() == 1; - aggClass = aggClasses.get(0); - } else { - List> argTypes = Arrays.stream(argFieldTypes) - .map(IType::getTypeClass) - .collect(Collectors.toList()); - // Find the best udaf implement class according to the argument types. - aggClass = FunctionCallUtils.findMatchUDAF(aggName, aggClasses, argTypes); - } - return aggClass; + public boolean isDistinct() { + return isDistinct; + } + } + + private List buildAggFunctionCalls() { + List aggFunctionCalls = new ArrayList<>(); + for (AggregateCall aggCall : aggCalls) { + String name = aggCall.getName(); + int[] argFieldIndices = ArrayUtil.toIntArray(aggCall.getArgList()); + IType[] argFieldTypes = getInputFieldTypes(argFieldIndices); + Class> udafClass = findUDAF(aggCall.getAggregation(), argFieldTypes); + + AggFunctionCall functionCall = + new AggFunctionCall( + name, argFieldIndices, argFieldTypes, udafClass, aggCall.isDistinct()); + aggFunctionCalls.add(functionCall); + } + return aggFunctionCalls; + } + + public static Class> findUDAF( + SqlAggFunction aggFunction, IType[] argFieldTypes) { + List>> aggClasses = new ArrayList<>(); + String aggName = aggFunction.getName().toUpperCase(Locale.ROOT); + // User-defined aggregate function + if (aggFunction instanceof GeaFlowUserDefinedAggFunction) { + GeaFlowUserDefinedAggFunction function = (GeaFlowUserDefinedAggFunction) aggFunction; + aggClasses = function.getUdafClasses(); + } else { + // Build-in aggregate function + switch (aggName) { + case UDAF_COUNT: + aggClasses.add(Count.class); + break; + case UDAF_SUM: + aggClasses.add(SumLong.class); + aggClasses.add(SumDouble.class); + aggClasses.add(SumInteger.class); + break; + case UDAF_AVG: + aggClasses.add(AvgDouble.class); + aggClasses.add(AvgLong.class); + aggClasses.add(AvgInteger.class); + break; + case UDAF_MAX: + aggClasses.add(MaxLong.class); + aggClasses.add(MaxDouble.class); + aggClasses.add(MaxInteger.class); + aggClasses.add(MaxBinaryString.class); + break; + case UDAF_MIN: + aggClasses.add(MinLong.class); + aggClasses.add(MinDouble.class); + aggClasses.add(MinInteger.class); + aggClasses.add(MinBinaryString.class); + break; + case UDAF_STDDEV_SAMP: + aggClasses.add(StdDevSampLong.class); + aggClasses.add(StdDevSampDouble.class); + aggClasses.add(StdDevSampInteger.class); + break; + case UDAF_PERCENTILE: + aggClasses.add(PercentileLong.class); + aggClasses.add(PercentileInteger.class); + aggClasses.add(PercentileDouble.class); + break; + default: + throw new GeaFlowDSLException("Not support aggregate function " + aggName); + } } - @Override - public String showSQL() { - StringBuilder sql = new StringBuilder(); - RelDataType inputRowType = getInput().getRowType(); - sql.append("group by "); - boolean first = true; - for (int groupByIndex : BitSets.toIter(groupSet)) { - if (first) { - first = false; - } else { - sql.append(","); - } - sql.append(inputRowType.getFieldNames().get(groupByIndex)); - } - sql.append("\n"); - sql.append("aggFunction[ "); - for (int i = 0; i < aggCalls.size(); i++) { - if (i > 0) { - sql.append(","); - } - AggregateCall call = aggCalls.get(i); - String funcName = call.getAggregation().getName(); - sql.append(funcName).append("("); - if (call.isDistinct()) { - sql.append("distinct "); - } - for (int k = 0; k < call.getArgList().size(); k++) { - if (k > 0) { - sql.append(","); - } - sql.append( - inputRowType.getFieldNames().get(call.getArgList().get(k))); - } - sql.append(")"); + Class> aggClass; + if (UDAF_COUNT.equals(aggName)) { // process case for count(*) + assert aggClasses.size() == 1; + aggClass = aggClasses.get(0); + } else { + List> argTypes = + Arrays.stream(argFieldTypes).map(IType::getTypeClass).collect(Collectors.toList()); + // Find the best udaf implement class according to the argument types. + aggClass = FunctionCallUtils.findMatchUDAF(aggName, aggClasses, argTypes); + } + return aggClass; + } + + @Override + public String showSQL() { + StringBuilder sql = new StringBuilder(); + RelDataType inputRowType = getInput().getRowType(); + sql.append("group by "); + boolean first = true; + for (int groupByIndex : BitSets.toIter(groupSet)) { + if (first) { + first = false; + } else { + sql.append(","); + } + sql.append(inputRowType.getFieldNames().get(groupByIndex)); + } + sql.append("\n"); + sql.append("aggFunction[ "); + for (int i = 0; i < aggCalls.size(); i++) { + if (i > 0) { + sql.append(","); + } + AggregateCall call = aggCalls.get(i); + String funcName = call.getAggregation().getName(); + sql.append(funcName).append("("); + if (call.isDistinct()) { + sql.append("distinct "); + } + for (int k = 0; k < call.getArgList().size(); k++) { + if (k > 0) { + sql.append(","); } - sql.append(" ]"); - return sql.toString(); + sql.append(inputRowType.getFieldNames().get(call.getArgList().get(k))); + } + sql.append(")"); } + sql.append(" ]"); + return sql.toString(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicConstructGraphRelNode.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicConstructGraphRelNode.java index 6f139de4b..3d41653ee 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicConstructGraphRelNode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicConstructGraphRelNode.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.plan; import java.util.List; + import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.RelNode; @@ -28,22 +29,25 @@ import org.apache.geaflow.dsl.runtime.QueryContext; import org.apache.geaflow.dsl.runtime.RuntimeGraph; -public class PhysicConstructGraphRelNode extends ConstructGraph implements PhysicRelNode { +public class PhysicConstructGraphRelNode extends ConstructGraph + implements PhysicRelNode { - public PhysicConstructGraphRelNode(RelOptCluster cluster, - RelTraitSet traits, - RelNode input, List labelNames, - RelDataType rowType) { - super(cluster, traits, input, labelNames, rowType); - } + public PhysicConstructGraphRelNode( + RelOptCluster cluster, + RelTraitSet traits, + RelNode input, + List labelNames, + RelDataType rowType) { + super(cluster, traits, input, labelNames, rowType); + } - @Override - public RuntimeGraph translate(QueryContext context) { - return null; - } + @Override + public RuntimeGraph translate(QueryContext context) { + return null; + } - @Override - public String showSQL() { - return null; - } + @Override + public String showSQL() { + return null; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicConvention.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicConvention.java index bc6afe07e..0c2401bdd 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicConvention.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicConvention.java @@ -23,42 +23,38 @@ public class PhysicConvention implements Convention { - public static final PhysicConvention INSTANCE = new PhysicConvention(); - - @Override - public Class getInterface() { - return PhysicRelNode.class; - } - - @Override - public String getName() { - return "physic"; - } - - @Override - public boolean canConvertConvention(Convention toConvention) { - return false; - } - - @Override - public boolean useAbstractConvertersForConversion(RelTraitSet fromTraits, - RelTraitSet toTraits) { - return false; - } - - @Override - public RelTraitDef getTraitDef() { - return ConventionTraitDef.INSTANCE; - } - - @Override - public boolean satisfies(RelTrait trait) { - return this.equals(trait); - } - - @Override - public void register(RelOptPlanner planner) { - - } - + public static final PhysicConvention INSTANCE = new PhysicConvention(); + + @Override + public Class getInterface() { + return PhysicRelNode.class; + } + + @Override + public String getName() { + return "physic"; + } + + @Override + public boolean canConvertConvention(Convention toConvention) { + return false; + } + + @Override + public boolean useAbstractConvertersForConversion(RelTraitSet fromTraits, RelTraitSet toTraits) { + return false; + } + + @Override + public RelTraitDef getTraitDef() { + return ConventionTraitDef.INSTANCE; + } + + @Override + public boolean satisfies(RelTrait trait) { + return this.equals(trait); + } + + @Override + public void register(RelOptPlanner planner) {} } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicCorrelateRelNode.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicCorrelateRelNode.java index 073c9c38b..c4f2d8d07 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicCorrelateRelNode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicCorrelateRelNode.java @@ -19,10 +19,9 @@ package org.apache.geaflow.dsl.runtime.plan; -import com.google.common.base.Joiner; -import com.google.common.base.Preconditions; import java.util.List; import java.util.stream.Collectors; + import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptCost; import org.apache.calcite.plan.RelOptPlanner; @@ -49,104 +48,108 @@ import org.apache.geaflow.dsl.util.GQLRelUtil; import org.apache.geaflow.dsl.util.SqlTypeUtil; +import com.google.common.base.Joiner; +import com.google.common.base.Preconditions; + public class PhysicCorrelateRelNode extends Correlate implements PhysicRelNode { - public PhysicCorrelateRelNode(RelOptCluster cluster, - RelTraitSet traits, - RelNode left, - RelNode right, - CorrelationId correlationId, - ImmutableBitSet requiredColumns, - SemiJoinType joinType) { - super(cluster, traits, left, right, correlationId, requiredColumns, joinType); - } + public PhysicCorrelateRelNode( + RelOptCluster cluster, + RelTraitSet traits, + RelNode left, + RelNode right, + CorrelationId correlationId, + ImmutableBitSet requiredColumns, + SemiJoinType joinType) { + super(cluster, traits, left, right, correlationId, requiredColumns, joinType); + } - @Override - public RelOptCost computeSelfCost(RelOptPlanner planner, - RelMetadataQuery mq) { - return super.computeSelfCost(planner, mq); - } + @Override + public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) { + return super.computeSelfCost(planner, mq); + } - @Override - public Correlate copy(RelTraitSet traitSet, - RelNode left, - RelNode right, - CorrelationId correlationId, - ImmutableBitSet requiredColumns, - SemiJoinType joinType) { - return new PhysicCorrelateRelNode( - super.getCluster(), - traitSet, - left, - right, - correlationId, - requiredColumns, - joinType); - } + @Override + public Correlate copy( + RelTraitSet traitSet, + RelNode left, + RelNode right, + CorrelationId correlationId, + ImmutableBitSet requiredColumns, + SemiJoinType joinType) { + return new PhysicCorrelateRelNode( + super.getCluster(), traitSet, left, right, correlationId, requiredColumns, joinType); + } - @Override - public RuntimeTable translate(QueryContext context) { - PhysicRelNode rightInput = ((PhysicRelNode) getInput(1)); - RelNode right = GQLRelUtil.toRel(rightInput); - Filter filterRelNode = null; - while (!(right instanceof TableFunctionScan)) { - Preconditions.checkArgument(rightInput.getInputs().size() == 1); - Preconditions.checkArgument(filterRelNode == null); - Preconditions.checkArgument(rightInput instanceof Filter); - filterRelNode = (Filter) right; - right = GQLRelUtil.toRel(right.getInput(0)); - } + @Override + public RuntimeTable translate(QueryContext context) { + PhysicRelNode rightInput = ((PhysicRelNode) getInput(1)); + RelNode right = GQLRelUtil.toRel(rightInput); + Filter filterRelNode = null; + while (!(right instanceof TableFunctionScan)) { + Preconditions.checkArgument(rightInput.getInputs().size() == 1); + Preconditions.checkArgument(filterRelNode == null); + Preconditions.checkArgument(rightInput instanceof Filter); + filterRelNode = (Filter) right; + right = GQLRelUtil.toRel(right.getInput(0)); + } - Preconditions.checkArgument(right instanceof PhysicTableFunctionScanRelNode); - Expression tableExpression = ExpressionTranslator.of(right.getRowType()) + Preconditions.checkArgument(right instanceof PhysicTableFunctionScanRelNode); + Expression tableExpression = + ExpressionTranslator.of(right.getRowType()) .translate(((PhysicTableFunctionScanRelNode) right).getCall()); - Preconditions.checkArgument(tableExpression instanceof UDTFExpression); - UDTFExpression udtfExpression = (UDTFExpression) tableExpression; - List> correlateLeftOutputTypes = getLeft().getRowType().getFieldList().stream() - .map(field -> SqlTypeUtil.convertType(field.getType())).collect(Collectors.toList()); - List> correlateRightOutputTypes = getRight().getRowType().getFieldList().stream() - .map(field -> SqlTypeUtil.convertType(field.getType())).collect(Collectors.toList()); + Preconditions.checkArgument(tableExpression instanceof UDTFExpression); + UDTFExpression udtfExpression = (UDTFExpression) tableExpression; + List> correlateLeftOutputTypes = + getLeft().getRowType().getFieldList().stream() + .map(field -> SqlTypeUtil.convertType(field.getType())) + .collect(Collectors.toList()); + List> correlateRightOutputTypes = + getRight().getRowType().getFieldList().stream() + .map(field -> SqlTypeUtil.convertType(field.getType())) + .collect(Collectors.toList()); - Expression condition = null; - if (filterRelNode != null) { - ExpressionTranslator translator = ExpressionTranslator.of(this.getRowType()); - condition = translator.translate(filterRelNode.getCondition()); - } - CorrelateFunction correlateFunction = new CorrelateFunctionImpl(udtfExpression, - condition, correlateLeftOutputTypes, correlateRightOutputTypes); - RDataView input = ((PhysicRelNode) getInput(0)).translate(context); - if (input.getType() == ViewType.TABLE) { - return ((RuntimeTable) input).correlate(correlateFunction); - } else if (input.getType() == ViewType.GRAPH) { - RuntimeGraph runtimeGraph = (RuntimeGraph) input; - return runtimeGraph.getPathTable().correlate(correlateFunction); - } - throw new GeaFlowDSLException("DataView: " + input.getType() + " cannot support " - + "correlate"); + Expression condition = null; + if (filterRelNode != null) { + ExpressionTranslator translator = ExpressionTranslator.of(this.getRowType()); + condition = translator.translate(filterRelNode.getCondition()); + } + CorrelateFunction correlateFunction = + new CorrelateFunctionImpl( + udtfExpression, condition, correlateLeftOutputTypes, correlateRightOutputTypes); + RDataView input = ((PhysicRelNode) getInput(0)).translate(context); + if (input.getType() == ViewType.TABLE) { + return ((RuntimeTable) input).correlate(correlateFunction); + } else if (input.getType() == ViewType.GRAPH) { + RuntimeGraph runtimeGraph = (RuntimeGraph) input; + return runtimeGraph.getPathTable().correlate(correlateFunction); } + throw new GeaFlowDSLException( + "DataView: " + input.getType() + " cannot support " + "correlate"); + } - @SuppressWarnings("unchecked") - @Override - public String showSQL() { - StringBuilder sql = new StringBuilder(); + @SuppressWarnings("unchecked") + @Override + public String showSQL() { + StringBuilder sql = new StringBuilder(); - PhysicRelNode left = (PhysicRelNode) getLeft(); - PhysicRelNode right = (PhysicRelNode) getRight(); - if (left instanceof TableScan) { - String tableName = Joiner.on('.').join(left.getTable().getQualifiedName()); - sql.append(tableName); - } else { - sql.append("SubQuery[").append(left.showSQL()).append("]"); - } - sql.append(" ").append(joinType.name().toLowerCase()).append(" join "); + PhysicRelNode left = (PhysicRelNode) getLeft(); + PhysicRelNode right = (PhysicRelNode) getRight(); + if (left instanceof TableScan) { + String tableName = Joiner.on('.').join(left.getTable().getQualifiedName()); + sql.append(tableName); + } else { + sql.append("SubQuery[").append(left.showSQL()).append("]"); + } + sql.append(" ").append(joinType.name().toLowerCase()).append(" join "); - if (right instanceof TableFunctionScan) { - TableFunctionScan tableFunctionScan = (TableFunctionScan) right; - RexNode call = tableFunctionScan.getCall(); - sql.append(ExpressionUtil.showExpression(call, null, left.getRowType())); - } else { - sql.append(right.showSQL()); - } - return sql.toString(); + if (right instanceof TableFunctionScan) { + TableFunctionScan tableFunctionScan = (TableFunctionScan) right; + RexNode call = tableFunctionScan.getCall(); + sql.append(ExpressionUtil.showExpression(call, null, left.getRowType())); + } else { + sql.append(right.showSQL()); } + return sql.toString(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicExchangeRelNode.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicExchangeRelNode.java index f5d73ce06..62819cd05 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicExchangeRelNode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicExchangeRelNode.java @@ -33,44 +33,39 @@ public class PhysicExchangeRelNode extends Exchange implements PhysicRelNode { - public PhysicExchangeRelNode(RelOptCluster cluster, - RelTraitSet traitSet, - RelNode input, - RelDistribution distribution) { - super(cluster, traitSet, input, distribution); - } + public PhysicExchangeRelNode( + RelOptCluster cluster, RelTraitSet traitSet, RelNode input, RelDistribution distribution) { + super(cluster, traitSet, input, distribution); + } - @Override - public RelOptCost computeSelfCost(RelOptPlanner planner, - RelMetadataQuery mq) { - return super.computeSelfCost(planner, mq); - } + @Override + public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) { + return super.computeSelfCost(planner, mq); + } - @Override - public Exchange copy(RelTraitSet traitSet, RelNode input, - RelDistribution distribution) { - return new PhysicExchangeRelNode(super.getCluster(), - traitSet, input, distribution); - } + @Override + public Exchange copy(RelTraitSet traitSet, RelNode input, RelDistribution distribution) { + return new PhysicExchangeRelNode(super.getCluster(), traitSet, input, distribution); + } - @Override - public RuntimeTable translate(QueryContext context) { - return null; - } + @Override + public RuntimeTable translate(QueryContext context) { + return null; + } - @Override - public String showSQL() { - StringBuilder sql = new StringBuilder(); + @Override + public String showSQL() { + StringBuilder sql = new StringBuilder(); - sql.append("PARTITION BY "); - RelDataType inputRowType = getInput().getRowType(); - for (int i = 0; i < distribution.getKeys().size(); i++) { - if (i > 0) { - sql.append(", "); - } - int key = distribution.getKeys().get(i); - sql.append(inputRowType.getFieldNames().get(key)); - } - return sql.toString(); + sql.append("PARTITION BY "); + RelDataType inputRowType = getInput().getRowType(); + for (int i = 0; i < distribution.getKeys().size(); i++) { + if (i > 0) { + sql.append(", "); + } + int key = distribution.getKeys().get(i); + sql.append(inputRowType.getFieldNames().get(key)); } + return sql.toString(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicFilterRelNode.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicFilterRelNode.java index 2eadf970e..4264d44ef 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicFilterRelNode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicFilterRelNode.java @@ -42,53 +42,49 @@ public class PhysicFilterRelNode extends Filter implements PhysicRelNode { - public PhysicFilterRelNode(RelOptCluster cluster, - RelTraitSet traits, - RelNode child, - RexNode condition) { - super(cluster, traits, child, condition); - } + public PhysicFilterRelNode( + RelOptCluster cluster, RelTraitSet traits, RelNode child, RexNode condition) { + super(cluster, traits, child, condition); + } - @Override - public RelOptCost computeSelfCost(RelOptPlanner planner, - RelMetadataQuery mq) { - return super.computeSelfCost(planner, mq); - } + @Override + public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) { + return super.computeSelfCost(planner, mq); + } - @Override - public Filter copy(RelTraitSet traitSet, RelNode input, RexNode condition) { - return new PhysicFilterRelNode(super.getCluster(), traitSet, input, condition); - } + @Override + public Filter copy(RelTraitSet traitSet, RelNode input, RexNode condition) { + return new PhysicFilterRelNode(super.getCluster(), traitSet, input, condition); + } - @Override - public RDataView translate(QueryContext context) { - ExpressionTranslator translator = ExpressionTranslator.of(getInput().getRowType()); - Expression condition = translator.translate(getCondition()); - Expression preFilter = context.getPushFilter(); - if (getInput() instanceof TableScan) { - context.setPushFilter(condition); - } - RDataView dataView = ((PhysicRelNode) getInput()).translate(context); - context.setPushFilter(preFilter); - RuntimeTable runtimeTable; - if (dataView.getType() == ViewType.TABLE) { - runtimeTable = (RuntimeTable) dataView; - } else if (dataView.getType() == ViewType.GRAPH) { - RuntimeGraph runtimeGraph = (RuntimeGraph) dataView; - runtimeTable = runtimeGraph.getPathTable(); - } else { - throw new GeaFlowDSLException("DataView: " + dataView.getType() + " cannot support filter"); - } - WhereFunction whereFunction = new WhereFunctionImpl(condition); - return runtimeTable.filter(whereFunction); + @Override + public RDataView translate(QueryContext context) { + ExpressionTranslator translator = ExpressionTranslator.of(getInput().getRowType()); + Expression condition = translator.translate(getCondition()); + Expression preFilter = context.getPushFilter(); + if (getInput() instanceof TableScan) { + context.setPushFilter(condition); } - - @Override - public String showSQL() { - StringBuilder sql = new StringBuilder(); - sql.append("WHERE "); - sql.append(ExpressionUtil.showExpression(condition, null, - getInput().getRowType())); - return sql.toString(); + RDataView dataView = ((PhysicRelNode) getInput()).translate(context); + context.setPushFilter(preFilter); + RuntimeTable runtimeTable; + if (dataView.getType() == ViewType.TABLE) { + runtimeTable = (RuntimeTable) dataView; + } else if (dataView.getType() == ViewType.GRAPH) { + RuntimeGraph runtimeGraph = (RuntimeGraph) dataView; + runtimeTable = runtimeGraph.getPathTable(); + } else { + throw new GeaFlowDSLException("DataView: " + dataView.getType() + " cannot support filter"); } + WhereFunction whereFunction = new WhereFunctionImpl(condition); + return runtimeTable.filter(whereFunction); + } + + @Override + public String showSQL() { + StringBuilder sql = new StringBuilder(); + sql.append("WHERE "); + sql.append(ExpressionUtil.showExpression(condition, null, getInput().getRowType())); + return sql.toString(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicGraphAlgorithm.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicGraphAlgorithm.java index e1b3f82ed..0f383961e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicGraphAlgorithm.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicGraphAlgorithm.java @@ -30,30 +30,33 @@ public class PhysicGraphAlgorithm extends GraphAlgorithm implements PhysicRelNode { - public PhysicGraphAlgorithm(RelOptCluster cluster, - RelTraitSet traits, - RelNode input, - Class userFunctionClass, - Object[] params) { - super(cluster, traits, input, userFunctionClass, params); - } + public PhysicGraphAlgorithm( + RelOptCluster cluster, + RelTraitSet traits, + RelNode input, + Class userFunctionClass, + Object[] params) { + super(cluster, traits, input, userFunctionClass, params); + } - @Override - public GraphAlgorithm copy(RelTraitSet traitSet, RelNode input, - Class userFunctionClass, - Object[] params) { - return new PhysicGraphAlgorithm(input.getCluster(), traitSet, input, userFunctionClass, params); - } + @Override + public GraphAlgorithm copy( + RelTraitSet traitSet, + RelNode input, + Class userFunctionClass, + Object[] params) { + return new PhysicGraphAlgorithm(input.getCluster(), traitSet, input, userFunctionClass, params); + } - @Override - public RuntimeTable translate(QueryContext context) { - PhysicRelNode input = (PhysicRelNode) getInput(); - RuntimeGraph inputGraph = input.translate(context); - return inputGraph.runAlgorithm(this); - } + @Override + public RuntimeTable translate(QueryContext context) { + PhysicRelNode input = (PhysicRelNode) getInput(); + RuntimeGraph inputGraph = input.translate(context); + return inputGraph.runAlgorithm(this); + } - @Override - public String showSQL() { - return null; - } + @Override + public String showSQL() { + return null; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicGraphMatchRelNode.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicGraphMatchRelNode.java index 6d7ba8775..a396a18f0 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicGraphMatchRelNode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicGraphMatchRelNode.java @@ -30,29 +30,31 @@ public class PhysicGraphMatchRelNode extends GraphMatch implements PhysicRelNode { - public PhysicGraphMatchRelNode(RelOptCluster cluster, - RelTraitSet traits, - RelNode input, - IMatchNode pathPattern, - RelDataType rowType) { - super(cluster, traits, input, pathPattern, rowType); - } + public PhysicGraphMatchRelNode( + RelOptCluster cluster, + RelTraitSet traits, + RelNode input, + IMatchNode pathPattern, + RelDataType rowType) { + super(cluster, traits, input, pathPattern, rowType); + } - @SuppressWarnings("unchecked") - @Override - public RuntimeGraph translate(QueryContext context) { - PhysicRelNode input = (PhysicRelNode) getInput(); - RuntimeGraph inputGraph = input.translate(context); - return inputGraph.traversal(this); - } + @SuppressWarnings("unchecked") + @Override + public RuntimeGraph translate(QueryContext context) { + PhysicRelNode input = (PhysicRelNode) getInput(); + RuntimeGraph inputGraph = input.translate(context); + return inputGraph.traversal(this); + } - @Override - public GraphMatch copy(RelTraitSet traitSet, RelNode input, IMatchNode pathPattern, RelDataType rowType) { - return new PhysicGraphMatchRelNode(getCluster(), traitSet, input, pathPattern, rowType); - } + @Override + public GraphMatch copy( + RelTraitSet traitSet, RelNode input, IMatchNode pathPattern, RelDataType rowType) { + return new PhysicGraphMatchRelNode(getCluster(), traitSet, input, pathPattern, rowType); + } - @Override - public String showSQL() { - return null; - } + @Override + public String showSQL() { + return null; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicGraphModifyRelNode.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicGraphModifyRelNode.java index 0bda1120e..2f28d13cb 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicGraphModifyRelNode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicGraphModifyRelNode.java @@ -33,30 +33,30 @@ public class PhysicGraphModifyRelNode extends GraphModify implements PhysicRelNode { - public PhysicGraphModifyRelNode(RelOptCluster cluster, RelTraitSet traitSet, - GeaFlowGraph graph, RelNode input) { - super(cluster, traitSet, graph, input); + public PhysicGraphModifyRelNode( + RelOptCluster cluster, RelTraitSet traitSet, GeaFlowGraph graph, RelNode input) { + super(cluster, traitSet, graph, input); + } + + @Override + public GraphModify copy(RelTraitSet traitSet, GeaFlowGraph graph, RelNode input) { + return new PhysicGraphModifyRelNode(getCluster(), traitSet, graph, input); + } + + @Override + public SinkDataView translate(QueryContext context) { + context.addReferTargetGraph(graph); + + RDataView dataView = ((PhysicRelNode) getInput()).translate(context); + if (dataView.getType() == ViewType.TABLE) { + RuntimeTable runtimeTable = (RuntimeTable) dataView; + return runtimeTable.write(graph, context); } + throw new GeaFlowDSLException("DataView type: {} cannot insert to graph", dataView.getType()); + } - @Override - public GraphModify copy(RelTraitSet traitSet, GeaFlowGraph graph, RelNode input) { - return new PhysicGraphModifyRelNode(getCluster(), traitSet, graph, input); - } - - @Override - public SinkDataView translate(QueryContext context) { - context.addReferTargetGraph(graph); - - RDataView dataView = ((PhysicRelNode) getInput()).translate(context); - if (dataView.getType() == ViewType.TABLE) { - RuntimeTable runtimeTable = (RuntimeTable) dataView; - return runtimeTable.write(graph, context); - } - throw new GeaFlowDSLException("DataView type: {} cannot insert to graph", dataView.getType()); - } - - @Override - public String showSQL() { - return null; - } + @Override + public String showSQL() { + return null; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicGraphScanRelNode.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicGraphScanRelNode.java index 3f57ae7a8..052fed2a3 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicGraphScanRelNode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicGraphScanRelNode.java @@ -29,31 +29,29 @@ public class PhysicGraphScanRelNode extends GraphScan implements PhysicRelNode { - public PhysicGraphScanRelNode(RelOptCluster cluster, - RelTraitSet traitSet, - RelOptTable table) { - super(cluster, traitSet, table); - } + public PhysicGraphScanRelNode(RelOptCluster cluster, RelTraitSet traitSet, RelOptTable table) { + super(cluster, traitSet, table); + } - @Override - public RuntimeGraph translate(QueryContext context) { - GeaFlowGraph graph = table.unwrap(GeaFlowGraph.class); - RuntimeGraph runtimeGraph = context.getRuntimeGraph(graph.getName()); - if (runtimeGraph == null) { - context.addReferSourceGraph(graph); - runtimeGraph = context.getEngineContext().createRuntimeGraph(context, graph); - context.putRuntimeGraph(graph.getName(), runtimeGraph); - } - return runtimeGraph; + @Override + public RuntimeGraph translate(QueryContext context) { + GeaFlowGraph graph = table.unwrap(GeaFlowGraph.class); + RuntimeGraph runtimeGraph = context.getRuntimeGraph(graph.getName()); + if (runtimeGraph == null) { + context.addReferSourceGraph(graph); + runtimeGraph = context.getEngineContext().createRuntimeGraph(context, graph); + context.putRuntimeGraph(graph.getName(), runtimeGraph); } + return runtimeGraph; + } - @Override - public PhysicGraphScanRelNode copy(RelTraitSet traitSet, RelOptTable table) { - return new PhysicGraphScanRelNode(getCluster(), traitSet, table); - } + @Override + public PhysicGraphScanRelNode copy(RelTraitSet traitSet, RelOptTable table) { + return new PhysicGraphScanRelNode(getCluster(), traitSet, table); + } - @Override - public String showSQL() { - return table.unwrap(GeaFlowGraph.class).toString(); - } + @Override + public String showSQL() { + return table.unwrap(GeaFlowGraph.class).toString(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicParameterizedRelNode.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicParameterizedRelNode.java index 09400526d..c95ad8c5d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicParameterizedRelNode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicParameterizedRelNode.java @@ -23,6 +23,7 @@ import java.util.HashSet; import java.util.List; import java.util.Set; + import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.RelNode; @@ -41,106 +42,111 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class PhysicParameterizedRelNode extends ParameterizedRelNode implements PhysicRelNode { +public class PhysicParameterizedRelNode extends ParameterizedRelNode + implements PhysicRelNode { - private static final Logger LOGGER = LoggerFactory.getLogger(PhysicParameterizedRelNode.class); + private static final Logger LOGGER = LoggerFactory.getLogger(PhysicParameterizedRelNode.class); - public PhysicParameterizedRelNode(RelOptCluster cluster, - RelTraitSet traitSet, - RelNode parameter, RelNode query) { - super(cluster, traitSet, parameter, query); - } + public PhysicParameterizedRelNode( + RelOptCluster cluster, RelTraitSet traitSet, RelNode parameter, RelNode query) { + super(cluster, traitSet, parameter, query); + } - @Override - public PhysicParameterizedRelNode copy(RelTraitSet traitSet, List inputs) { - assert inputs.size() == 2; - return new PhysicParameterizedRelNode(getCluster(), traitSet, inputs.get(0), inputs.get(1)); + @Override + public PhysicParameterizedRelNode copy(RelTraitSet traitSet, List inputs) { + assert inputs.size() == 2; + return new PhysicParameterizedRelNode(getCluster(), traitSet, inputs.get(0), inputs.get(1)); + } + + @SuppressWarnings("unchecked") + @Override + public RDataView translate(QueryContext context) { + RuntimeTable requestTable = + ((PhysicRelNode) getParameterNode()).translate(context); + RelNode newQueryNode = isIdOnlyRequest(); + boolean idOnlyRequest = newQueryNode != getQueryNode(); + if (idOnlyRequest) { + LOGGER.info("It is id only parameter request."); } + RuntimeTable preRequestTable = context.setRequestTable(requestTable); + boolean preIsIdOnlyRequest = context.setIdOnlyRequest(idOnlyRequest); + RDataView dataView = ((PhysicRelNode) newQueryNode).translate(context); - @SuppressWarnings("unchecked") - @Override - public RDataView translate(QueryContext context) { - RuntimeTable requestTable = ((PhysicRelNode) getParameterNode()).translate(context); - RelNode newQueryNode = isIdOnlyRequest(); - boolean idOnlyRequest = newQueryNode != getQueryNode(); - if (idOnlyRequest) { - LOGGER.info("It is id only parameter request."); - } - RuntimeTable preRequestTable = context.setRequestTable(requestTable); - boolean preIsIdOnlyRequest = context.setIdOnlyRequest(idOnlyRequest); - RDataView dataView = ((PhysicRelNode) newQueryNode).translate(context); + context.setRequestTable(preRequestTable); + context.setIdOnlyRequest(preIsIdOnlyRequest); + return dataView; + } - context.setRequestTable(preRequestTable); - context.setIdOnlyRequest(preIsIdOnlyRequest); - return dataView; + private RelNode isIdOnlyRequest() { + Set idReferences = new HashSet<>(); + RelNode newQueryNode = isIdOnlyRequest(getQueryNode(), idReferences); + // Only one id parameter has referred. + if (idReferences.size() == 1 && newQueryNode != null) { + return newQueryNode; } + return getQueryNode(); + } - private RelNode isIdOnlyRequest() { - Set idReferences = new HashSet<>(); - RelNode newQueryNode = isIdOnlyRequest(getQueryNode(), idReferences); - // Only one id parameter has referred. - if (idReferences.size() == 1 && newQueryNode != null) { - return newQueryNode; - } - return getQueryNode(); + private RelNode isIdOnlyRequest(RelNode node, Set idReferences) { + if (node instanceof GraphMatch) { + GraphMatch match = (GraphMatch) node; + IMatchNode newPathPattern = + (IMatchNode) isIdOnlyRequest(match.getPathPattern(), idReferences); + if (newPathPattern == null) { + return null; + } + return match.copy(newPathPattern); } + RelNode newNode = node; + List newInputs = new ArrayList<>(node.getInputs().size()); - private RelNode isIdOnlyRequest(RelNode node, Set idReferences) { - if (node instanceof GraphMatch) { - GraphMatch match = (GraphMatch) node; - IMatchNode newPathPattern = (IMatchNode) isIdOnlyRequest(match.getPathPattern(), idReferences); - if (newPathPattern == null) { - return null; - } - return match.copy(newPathPattern); - } - RelNode newNode = node; - List newInputs = new ArrayList<>(node.getInputs().size()); + if (node instanceof MatchFilter + && ((MatchFilter) node).getInput() instanceof VertexMatch + && ((MatchFilter) node).getInput().getInputs().isEmpty()) { + MatchFilter filter = (MatchFilter) node; + VertexRecordType vertexRecordType = + (VertexRecordType) ((VertexMatch) filter.getInput()).getNodeType(); + RexNode conditionRemoveId = + GQLRexUtil.removeIdCondition(filter.getCondition(), vertexRecordType); - if (node instanceof MatchFilter - && ((MatchFilter) node).getInput() instanceof VertexMatch - && ((MatchFilter) node).getInput().getInputs().isEmpty()) { - MatchFilter filter = (MatchFilter) node; - VertexRecordType vertexRecordType = (VertexRecordType) ((VertexMatch) filter.getInput()).getNodeType(); - RexNode conditionRemoveId = GQLRexUtil.removeIdCondition(filter.getCondition(), vertexRecordType); + Set ids = GQLRexUtil.findVertexIds(filter.getCondition(), vertexRecordType); + idReferences.addAll(ids); + // It contains parameter reference except the id request. + boolean isIdOnlyRef = + conditionRemoveId == null + || !GQLRexUtil.contain(conditionRemoveId, RexParameterRef.class); + VertexMatch vertexMatch = (VertexMatch) filter.getInput(); + // push filter to vertex-match. + newInputs.add(vertexMatch.copy(filter.getCondition())); - Set ids = GQLRexUtil.findVertexIds(filter.getCondition(), vertexRecordType); - idReferences.addAll(ids); - // It contains parameter reference except the id request. - boolean isIdOnlyRef = conditionRemoveId == null || !GQLRexUtil.contain(conditionRemoveId, - RexParameterRef.class); - VertexMatch vertexMatch = (VertexMatch) filter.getInput(); - // push filter to vertex-match. - newInputs.add(vertexMatch.copy(filter.getCondition())); - - if (isIdOnlyRef) { - if (conditionRemoveId != null) { - newNode = filter.copy(filter.getTraitSet(), filter.getInput(), conditionRemoveId); - } else { // remove current filter. - return newInputs.get(0); - } - } else { - return null; - } - } else { - boolean containParameterRef = - !GQLRexUtil.collect(node, rexNode -> rexNode instanceof RexParameterRef).isEmpty(); - if (containParameterRef) { - return null; - } - for (RelNode input : node.getInputs()) { - RelNode newInput = isIdOnlyRequest(input, idReferences); - if (newInput == null) { - return null; - } - newInputs.add(newInput); - } + if (isIdOnlyRef) { + if (conditionRemoveId != null) { + newNode = filter.copy(filter.getTraitSet(), filter.getInput(), conditionRemoveId); + } else { // remove current filter. + return newInputs.get(0); } - return newNode.copy(node.getTraitSet(), newInputs); - } - - @Override - public String showSQL() { + } else { return null; + } + } else { + boolean containParameterRef = + !GQLRexUtil.collect(node, rexNode -> rexNode instanceof RexParameterRef).isEmpty(); + if (containParameterRef) { + return null; + } + for (RelNode input : node.getInputs()) { + RelNode newInput = isIdOnlyRequest(input, idReferences); + if (newInput == null) { + return null; + } + newInputs.add(newInput); + } } + return newNode.copy(node.getTraitSet(), newInputs); + } + + @Override + public String showSQL() { + return null; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicProjectRelNode.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicProjectRelNode.java index 6083815c4..543994af4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicProjectRelNode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicProjectRelNode.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.plan; import java.util.List; + import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptCost; import org.apache.calcite.plan.RelOptPlanner; @@ -42,53 +43,52 @@ public class PhysicProjectRelNode extends Project implements PhysicRelNode { - public PhysicProjectRelNode(RelOptCluster cluster, - RelTraitSet traits, - RelNode input, - List projects, - RelDataType rowType) { - super(cluster, traits, input, projects, rowType); - } + public PhysicProjectRelNode( + RelOptCluster cluster, + RelTraitSet traits, + RelNode input, + List projects, + RelDataType rowType) { + super(cluster, traits, input, projects, rowType); + } - @Override - public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) { - return super.computeSelfCost(planner, mq); - } + @Override + public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) { + return super.computeSelfCost(planner, mq); + } - @Override - public Project copy(RelTraitSet traitSet, - RelNode input, - List projects, - RelDataType rowType) { - return new PhysicProjectRelNode(super.getCluster(), traitSet, input, projects, rowType); - } + @Override + public Project copy( + RelTraitSet traitSet, RelNode input, List projects, RelDataType rowType) { + return new PhysicProjectRelNode(super.getCluster(), traitSet, input, projects, rowType); + } - @Override - public RuntimeTable translate(QueryContext context) { - List projects = ExpressionTranslator.of(getInput().getRowType()).translate(getProjects()); - ProjectFunction projectFunction = new ProjectFunctionImpl(projects); + @Override + public RuntimeTable translate(QueryContext context) { + List projects = + ExpressionTranslator.of(getInput().getRowType()).translate(getProjects()); + ProjectFunction projectFunction = new ProjectFunctionImpl(projects); - RDataView dataView = ((PhysicRelNode) getInput()).translate(context); - if (dataView.getType() == ViewType.TABLE) { - RuntimeTable runtimeTable = (RuntimeTable) dataView; - return runtimeTable.project(projectFunction); - } else { // project path for graph. - RuntimeGraph runtimeGraph = (RuntimeGraph) dataView; - return runtimeGraph.getPathTable().project(projectFunction); - } + RDataView dataView = ((PhysicRelNode) getInput()).translate(context); + if (dataView.getType() == ViewType.TABLE) { + RuntimeTable runtimeTable = (RuntimeTable) dataView; + return runtimeTable.project(projectFunction); + } else { // project path for graph. + RuntimeGraph runtimeGraph = (RuntimeGraph) dataView; + return runtimeGraph.getPathTable().project(projectFunction); } + } - @Override - public String showSQL() { - StringBuilder sql = new StringBuilder(); - sql.append("SELECT \n"); - for (int i = 0; i < exps.size(); i++) { - if (i > 0) { - sql.append("\n"); - } - sql.append(ExpressionUtil.showExpression(exps.get(i), null, - input.getRowType())); - } - return sql.toString(); + @Override + public String showSQL() { + StringBuilder sql = new StringBuilder(); + sql.append("SELECT \n"); + for (int i = 0; i < exps.size(); i++) { + if (i > 0) { + sql.append("\n"); + } + sql.append(ExpressionUtil.showExpression(exps.get(i), null, input.getRowType())); } + return sql.toString(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicRelNode.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicRelNode.java index 04711de6e..5a241be4b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicRelNode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicRelNode.java @@ -25,72 +25,46 @@ public interface PhysicRelNode extends RelNode { - D translate(QueryContext context); + D translate(QueryContext context); - String showSQL(); + String showSQL(); - enum PhysicRelNodeName { - /** - * Name for table scan node. - */ - TABLE_SCAN("TableScan"), - /** - * Name for project node. - */ - PROJECT("Project"), - /** - * Name for filter node. - */ - FILTER("Filter"), - /** - * Name for aggregate node. - */ - AGGREGATE("Aggregate"), - /** - * Name for join node. - */ - JOIN("Join"), - /** - * Name for table function node. - */ - TABLE_FUNCTION("TableFunction"), - /** - * Name for union node. - */ - UNION("Union"), - /** - * Name for sort node. - */ - SORT("Sort"), - /** - * Name for correlate node. - */ - CORRELATE("Correlate"), - /** - * Name for value scan node. - */ - VALUE_SCAN("ValueScan"), - /** - * Name for graph scan node. - */ - GRAPH_SCAN("GraphScan"), - /** - * Name for graph match node. - */ - GRAPH_MATCH("GraphMatch"), - /** - * Name for table sink node. - */ - TABLE_SINK("TableSink"); + enum PhysicRelNodeName { + /** Name for table scan node. */ + TABLE_SCAN("TableScan"), + /** Name for project node. */ + PROJECT("Project"), + /** Name for filter node. */ + FILTER("Filter"), + /** Name for aggregate node. */ + AGGREGATE("Aggregate"), + /** Name for join node. */ + JOIN("Join"), + /** Name for table function node. */ + TABLE_FUNCTION("TableFunction"), + /** Name for union node. */ + UNION("Union"), + /** Name for sort node. */ + SORT("Sort"), + /** Name for correlate node. */ + CORRELATE("Correlate"), + /** Name for value scan node. */ + VALUE_SCAN("ValueScan"), + /** Name for graph scan node. */ + GRAPH_SCAN("GraphScan"), + /** Name for graph match node. */ + GRAPH_MATCH("GraphMatch"), + /** Name for table sink node. */ + TABLE_SINK("TableSink"); - private final String namePrefix; + private final String namePrefix; - PhysicRelNodeName(String namePrefix) { - this.namePrefix = namePrefix; - } + PhysicRelNodeName(String namePrefix) { + this.namePrefix = namePrefix; + } - public String getName(Object suffix) { - return namePrefix + "-" + suffix; - } + public String getName(Object suffix) { + return namePrefix + "-" + suffix; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicSortRelNode.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicSortRelNode.java index 7823eceef..ec9045878 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicSortRelNode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicSortRelNode.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.plan; import java.util.List; + import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.RelCollation; @@ -48,84 +49,88 @@ public class PhysicSortRelNode extends Sort implements PhysicRelNode { - public PhysicSortRelNode(RelOptCluster cluster, - RelTraitSet traits, - RelNode child, - RelCollation collation, - RexNode offset, - RexNode fetch) { - super(cluster, traits, child, collation, offset, fetch); - } + public PhysicSortRelNode( + RelOptCluster cluster, + RelTraitSet traits, + RelNode child, + RelCollation collation, + RexNode offset, + RexNode fetch) { + super(cluster, traits, child, collation, offset, fetch); + } - @Override - public Sort copy(RelTraitSet traitSet, - RelNode newInput, - RelCollation newCollation, - RexNode offset, - RexNode fetch) { - return new PhysicSortRelNode(super.getCluster(), traitSet, newInput, - newCollation, offset, fetch); - } + @Override + public Sort copy( + RelTraitSet traitSet, + RelNode newInput, + RelCollation newCollation, + RexNode offset, + RexNode fetch) { + return new PhysicSortRelNode( + super.getCluster(), traitSet, newInput, newCollation, offset, fetch); + } - @Override - public RuntimeTable translate(QueryContext context) { - SortInfo sortInfo = buildSortInfo(); - RDataView dataView = ((PhysicRelNode) getInput()).translate(context); - - OrderByFunction orderByFunction; - if (sortInfo.fetch > 0) { - orderByFunction = new OrderByHeapSort(sortInfo); - } else if (sortInfo.isRadixSortable()) { - orderByFunction = new OrderByRadixSort(sortInfo); - } else { - orderByFunction = new OrderByTimSort(sortInfo); - } - if (dataView.getType() == ViewType.TABLE) { - return ((RuntimeTable) dataView).orderBy(orderByFunction); - } else { - assert dataView instanceof RuntimeGraph; - RuntimeGraph runtimeGraph = (RuntimeGraph) dataView; - return runtimeGraph.getPathTable().orderBy(orderByFunction); - } - } + @Override + public RuntimeTable translate(QueryContext context) { + SortInfo sortInfo = buildSortInfo(); + RDataView dataView = ((PhysicRelNode) getInput()).translate(context); - @Override - public String showSQL() { - StringBuilder sql = new StringBuilder(); - sql.append("Order By "); - collation.getFieldCollations().forEach(fc -> { - String name = getInput().getRowType().getFieldNames().get(fc.getFieldIndex()); - sql.append(name).append(" ").append(fc.direction.shortString); - }); - return sql.toString(); + OrderByFunction orderByFunction; + if (sortInfo.fetch > 0) { + orderByFunction = new OrderByHeapSort(sortInfo); + } else if (sortInfo.isRadixSortable()) { + orderByFunction = new OrderByRadixSort(sortInfo); + } else { + orderByFunction = new OrderByTimSort(sortInfo); + } + if (dataView.getType() == ViewType.TABLE) { + return ((RuntimeTable) dataView).orderBy(orderByFunction); + } else { + assert dataView instanceof RuntimeGraph; + RuntimeGraph runtimeGraph = (RuntimeGraph) dataView; + return runtimeGraph.getPathTable().orderBy(orderByFunction); } + } + + @Override + public String showSQL() { + StringBuilder sql = new StringBuilder(); + sql.append("Order By "); + collation + .getFieldCollations() + .forEach( + fc -> { + String name = getInput().getRowType().getFieldNames().get(fc.getFieldIndex()); + sql.append(name).append(" ").append(fc.direction.shortString); + }); + return sql.toString(); + } - private SortInfo buildSortInfo() { - SortInfo sortInfo = new SortInfo(); - for (RelFieldCollation fc : collation.getFieldCollations()) { - List fieldList = getRowType().getFieldList(); - IType fieldType = SqlTypeUtil.convertType(fieldList.get(fc.getFieldIndex()).getType()); + private SortInfo buildSortInfo() { + SortInfo sortInfo = new SortInfo(); + for (RelFieldCollation fc : collation.getFieldCollations()) { + List fieldList = getRowType().getFieldList(); + IType fieldType = SqlTypeUtil.convertType(fieldList.get(fc.getFieldIndex()).getType()); - OrderByField orderByField = new OrderByField(); - orderByField.expression = new FieldExpression(fc.getFieldIndex(), fieldType); - switch (fc.getDirection()) { - case ASCENDING: - orderByField.order = ORDER.ASC; - break; - case DESCENDING: - orderByField.order = ORDER.DESC; - break; - default: - throw new UnsupportedOperationException( - "UnSupport sort type: " + fc.getDirection()); - } - sortInfo.orderByFields.add(orderByField); - } - ExpressionTranslator translator = ExpressionTranslator.of(getInput().getRowType()); - sortInfo.fetch = fetch == null ? -1 : - (int) TypeCastUtil.cast( - translator.translate(fetch).evaluate(null), - Integer.class); - return sortInfo; + OrderByField orderByField = new OrderByField(); + orderByField.expression = new FieldExpression(fc.getFieldIndex(), fieldType); + switch (fc.getDirection()) { + case ASCENDING: + orderByField.order = ORDER.ASC; + break; + case DESCENDING: + orderByField.order = ORDER.DESC; + break; + default: + throw new UnsupportedOperationException("UnSupport sort type: " + fc.getDirection()); + } + sortInfo.orderByFields.add(orderByField); } + ExpressionTranslator translator = ExpressionTranslator.of(getInput().getRowType()); + sortInfo.fetch = + fetch == null + ? -1 + : (int) TypeCastUtil.cast(translator.translate(fetch).evaluate(null), Integer.class); + return sortInfo; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicTableFunctionScanRelNode.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicTableFunctionScanRelNode.java index 7ac248c9e..bc271b44a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicTableFunctionScanRelNode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicTableFunctionScanRelNode.java @@ -22,6 +22,7 @@ import java.lang.reflect.Type; import java.util.List; import java.util.Set; + import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptCost; import org.apache.calcite.plan.RelOptPlanner; @@ -37,45 +38,44 @@ import org.apache.geaflow.dsl.runtime.RuntimeTable; import org.apache.geaflow.dsl.util.ExpressionUtil; -public class PhysicTableFunctionScanRelNode extends TableFunctionScan implements PhysicRelNode { - - public PhysicTableFunctionScanRelNode(RelOptCluster cluster, - RelTraitSet traits, - List inputs, RexNode rexCall, - Type elementType, RelDataType rowType, - Set columnMappings) { - super(cluster, traits, inputs, rexCall, elementType, rowType, columnMappings); - } +public class PhysicTableFunctionScanRelNode extends TableFunctionScan + implements PhysicRelNode { - @Override - public RelOptCost computeSelfCost(RelOptPlanner planner, - RelMetadataQuery mq) { - return super.computeSelfCost(planner, mq); - } + public PhysicTableFunctionScanRelNode( + RelOptCluster cluster, + RelTraitSet traits, + List inputs, + RexNode rexCall, + Type elementType, + RelDataType rowType, + Set columnMappings) { + super(cluster, traits, inputs, rexCall, elementType, rowType, columnMappings); + } - @Override - public TableFunctionScan copy(RelTraitSet traitSet, List inputs, - RexNode rexCall, - Type elementType, RelDataType rowType, - Set columnMappings) { - return new PhysicTableFunctionScanRelNode( - super.getCluster(), - traitSet, - inputs, - rexCall, - elementType, - rowType, - columnMappings); - } + @Override + public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) { + return super.computeSelfCost(planner, mq); + } - @Override - public RuntimeTable translate(QueryContext context) { - throw new GeaFlowDSLException("Table function is not support"); - } + @Override + public TableFunctionScan copy( + RelTraitSet traitSet, + List inputs, + RexNode rexCall, + Type elementType, + RelDataType rowType, + Set columnMappings) { + return new PhysicTableFunctionScanRelNode( + super.getCluster(), traitSet, inputs, rexCall, elementType, rowType, columnMappings); + } - @Override - public String showSQL() { - return ExpressionUtil.showExpression(getCall(), null, null); - } + @Override + public RuntimeTable translate(QueryContext context) { + throw new GeaFlowDSLException("Table function is not support"); + } + @Override + public String showSQL() { + return ExpressionUtil.showExpression(getCall(), null, null); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicTableModifyRelNode.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicTableModifyRelNode.java index 54aae6762..326033c92 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicTableModifyRelNode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicTableModifyRelNode.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.plan; import java.util.List; + import org.apache.calcite.plan.*; import org.apache.calcite.prepare.Prepare; import org.apache.calcite.rel.RelNode; @@ -33,51 +34,60 @@ public class PhysicTableModifyRelNode extends TableModify implements PhysicRelNode { - public PhysicTableModifyRelNode(RelOptCluster cluster, - RelTraitSet traitSet, - RelOptTable table, - Prepare.CatalogReader catalogReader, - RelNode input, - Operation operation, - List updateColumnList, - List sourceExpressionList, - boolean flattened) { - super(cluster, traitSet, table, catalogReader, input, operation, - updateColumnList, sourceExpressionList, flattened); - } + public PhysicTableModifyRelNode( + RelOptCluster cluster, + RelTraitSet traitSet, + RelOptTable table, + Prepare.CatalogReader catalogReader, + RelNode input, + Operation operation, + List updateColumnList, + List sourceExpressionList, + boolean flattened) { + super( + cluster, + traitSet, + table, + catalogReader, + input, + operation, + updateColumnList, + sourceExpressionList, + flattened); + } - @Override - public RelOptCost computeSelfCost(RelOptPlanner planner, - RelMetadataQuery mq) { - return super.computeSelfCost(planner, mq); - } + @Override + public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) { + return super.computeSelfCost(planner, mq); + } - @Override - public RelNode copy(RelTraitSet traitSet, List inputs) { - return new PhysicTableModifyRelNode(super.getCluster(), - traitSet, - super.getTable(), - super.catalogReader, - sole(inputs), - super.getOperation(), - super.getUpdateColumnList(), - super.getSourceExpressionList(), - super.isFlattened()); - } + @Override + public RelNode copy(RelTraitSet traitSet, List inputs) { + return new PhysicTableModifyRelNode( + super.getCluster(), + traitSet, + super.getTable(), + super.catalogReader, + sole(inputs), + super.getOperation(), + super.getUpdateColumnList(), + super.getSourceExpressionList(), + super.isFlattened()); + } - @SuppressWarnings("unchecked") - @Override - public SinkDataView translate(QueryContext context) { - GeaFlowTable table = getTable().unwrap(GeaFlowTable.class); - context.addReferTargetTable(table); + @SuppressWarnings("unchecked") + @Override + public SinkDataView translate(QueryContext context) { + GeaFlowTable table = getTable().unwrap(GeaFlowTable.class); + context.addReferTargetTable(table); - RuntimeTable runtimeTable = ((PhysicRelNode) getInput()).translate(context); - return runtimeTable.write(table); - } + RuntimeTable runtimeTable = ((PhysicRelNode) getInput()).translate(context); + return runtimeTable.write(table); + } - @Override - public String showSQL() { - GeaFlowTable table = super.getTable().unwrap(GeaFlowTable.class); - return table.toString(); - } + @Override + public String showSQL() { + GeaFlowTable table = super.getTable().unwrap(GeaFlowTable.class); + return table.toString(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicTableScanRelNode.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicTableScanRelNode.java index 891ac62c7..7de03c523 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicTableScanRelNode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicTableScanRelNode.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.plan; import java.util.List; + import org.apache.calcite.plan.*; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.TableScan; @@ -31,39 +32,37 @@ public class PhysicTableScanRelNode extends TableScan implements PhysicRelNode { - public PhysicTableScanRelNode(RelOptCluster cluster, - RelTraitSet traitSet, - RelOptTable table) { - super(cluster, traitSet, table); - } + public PhysicTableScanRelNode(RelOptCluster cluster, RelTraitSet traitSet, RelOptTable table) { + super(cluster, traitSet, table); + } - @Override - public RelOptCost computeSelfCost(RelOptPlanner planner, - RelMetadataQuery mq) { - return super.computeSelfCost(planner, mq); - } + @Override + public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) { + return super.computeSelfCost(planner, mq); + } - @Override - public RelNode copy(RelTraitSet traitSet, List inputs) { - return new PhysicTableScanRelNode(getCluster(), traitSet, getTable()); - } + @Override + public RelNode copy(RelTraitSet traitSet, List inputs) { + return new PhysicTableScanRelNode(getCluster(), traitSet, getTable()); + } - @Override - public RuntimeTable translate(QueryContext context) { - GeaFlowTable geaFlowTable = table.unwrap(GeaFlowTable.class); - RuntimeTable runtimeTable = context.getRuntimeTable(geaFlowTable.getName()); - if (runtimeTable == null) { - Expression pushFilter = context.getPushFilter(); - context.addReferSourceTable(geaFlowTable); - runtimeTable = context.getEngineContext().createRuntimeTable(context, geaFlowTable, pushFilter); - context.putRuntimeTable(geaFlowTable.getName(), runtimeTable); - } - return runtimeTable; + @Override + public RuntimeTable translate(QueryContext context) { + GeaFlowTable geaFlowTable = table.unwrap(GeaFlowTable.class); + RuntimeTable runtimeTable = context.getRuntimeTable(geaFlowTable.getName()); + if (runtimeTable == null) { + Expression pushFilter = context.getPushFilter(); + context.addReferSourceTable(geaFlowTable); + runtimeTable = + context.getEngineContext().createRuntimeTable(context, geaFlowTable, pushFilter); + context.putRuntimeTable(geaFlowTable.getName(), runtimeTable); } + return runtimeTable; + } - @Override - public String showSQL() { - GeaFlowTable table = super.getTable().unwrap(GeaFlowTable.class); - return table.toString(); - } + @Override + public String showSQL() { + GeaFlowTable table = super.getTable().unwrap(GeaFlowTable.class); + return table.toString(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicUnionRelNode.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicUnionRelNode.java index fe07746e8..a202998f5 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicUnionRelNode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicUnionRelNode.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptCost; import org.apache.calcite.plan.RelOptPlanner; @@ -35,60 +36,56 @@ import org.apache.geaflow.dsl.runtime.RDataView.ViewType; import org.apache.geaflow.dsl.runtime.RuntimeTable; -/** - * UNION. - */ +/** UNION. */ public class PhysicUnionRelNode extends Union implements PhysicRelNode { - public PhysicUnionRelNode(RelOptCluster cluster, - RelTraitSet traits, - List inputs, - boolean all) { - super(cluster, traits, inputs, all); - } + public PhysicUnionRelNode( + RelOptCluster cluster, RelTraitSet traits, List inputs, boolean all) { + super(cluster, traits, inputs, all); + } - @Override - public RelOptCost computeSelfCost(RelOptPlanner planner, - RelMetadataQuery mq) { - return super.computeSelfCost(planner, mq); - } + @Override + public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) { + return super.computeSelfCost(planner, mq); + } - @Override - public SetOp copy(RelTraitSet traitSet, List inputs, boolean all) { - return new PhysicUnionRelNode(super.getCluster(), traitSet, inputs, all); - } + @Override + public SetOp copy(RelTraitSet traitSet, List inputs, boolean all) { + return new PhysicUnionRelNode(super.getCluster(), traitSet, inputs, all); + } - @Override - public RuntimeTable translate(QueryContext context) { - List dataViews = new ArrayList<>(); - for (RelNode input : getInputs()) { - dataViews.add(((PhysicRelNode) input).translate(context)); - if (dataViews.get(dataViews.size() - 1).getType() != ViewType.TABLE) { - throw new GeaFlowDSLException("DataView: " - + dataViews.get(dataViews.size() - 1).getType() + " cannot support SQL union"); - } - } - if (!dataViews.isEmpty()) { - RuntimeTable output = (RuntimeTable) dataViews.get(0); - for (int i = 1; i < dataViews.size(); i++) { - output = output.union((RuntimeTable) dataViews.get(i)); - } - return output; - } else { - throw new GeaFlowDSLException("Union inputs cannot be empty."); - } + @Override + public RuntimeTable translate(QueryContext context) { + List dataViews = new ArrayList<>(); + for (RelNode input : getInputs()) { + dataViews.add(((PhysicRelNode) input).translate(context)); + if (dataViews.get(dataViews.size() - 1).getType() != ViewType.TABLE) { + throw new GeaFlowDSLException( + "DataView: " + + dataViews.get(dataViews.size() - 1).getType() + + " cannot support SQL union"); + } } - - @Override - public String showSQL() { - StringBuilder sb = new StringBuilder(); - for (int i = 0; i < getInputs().size(); i++) { - if (i > 0) { - sb.append(" ").append(kind.name()).append(" "); - } - sb.append(((PhysicRelNode) getInputs().get(i)).showSQL()); - } - return sb.toString(); + if (!dataViews.isEmpty()) { + RuntimeTable output = (RuntimeTable) dataViews.get(0); + for (int i = 1; i < dataViews.size(); i++) { + output = output.union((RuntimeTable) dataViews.get(i)); + } + return output; + } else { + throw new GeaFlowDSLException("Union inputs cannot be empty."); } + } + @Override + public String showSQL() { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < getInputs().size(); i++) { + if (i > 0) { + sb.append(" ").append(kind.name()).append(" "); + } + sb.append(((PhysicRelNode) getInputs().get(i)).showSQL()); + } + return sb.toString(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicValuesRelNode.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicValuesRelNode.java index 14d5a75be..71b651f8e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicValuesRelNode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicValuesRelNode.java @@ -19,9 +19,9 @@ package org.apache.geaflow.dsl.runtime.plan; -import com.google.common.collect.ImmutableList; import java.util.ArrayList; import java.util.List; + import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptCost; import org.apache.calcite.plan.RelOptPlanner; @@ -37,54 +37,64 @@ import org.apache.geaflow.dsl.runtime.RuntimeTable; import org.apache.geaflow.dsl.util.GQLRexUtil; +import com.google.common.collect.ImmutableList; + public class PhysicValuesRelNode extends Values implements PhysicRelNode { - public PhysicValuesRelNode(RelOptCluster cluster, - RelDataType rowType, - ImmutableList> tuples, - RelTraitSet traits) { - super(cluster, rowType, tuples, traits); - } + public PhysicValuesRelNode( + RelOptCluster cluster, + RelDataType rowType, + ImmutableList> tuples, + RelTraitSet traits) { + super(cluster, rowType, tuples, traits); + } - @Override - public RelOptCost computeSelfCost(RelOptPlanner planner, - RelMetadataQuery mq) { - return super.computeSelfCost(planner, mq); - } + @Override + public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) { + return super.computeSelfCost(planner, mq); + } - @Override - public RelNode copy(RelTraitSet traitSet, List inputs) { - return new PhysicValuesRelNode(getCluster(), getRowType(), getTuples(), traitSet); - } + @Override + public RelNode copy(RelTraitSet traitSet, List inputs) { + return new PhysicValuesRelNode(getCluster(), getRowType(), getTuples(), traitSet); + } - @Override - public RuntimeTable translate(QueryContext context) { - List values = convertValues(); - return context.getEngineContext().createRuntimeTable(context, values); - } + @Override + public RuntimeTable translate(QueryContext context) { + List values = convertValues(); + return context.getEngineContext().createRuntimeTable(context, values); + } - private List convertValues() { - List values = new ArrayList<>(); - for (List literals : getTuples()) { - Object[] fields = new Object[literals.size()]; - for (int i = 0; i < fields.length; i++) { - fields[i] = GQLRexUtil.getLiteralValue(literals.get(i)); - } - Row row = ObjectRow.create(fields); - values.add(row); - } - return values; + private List convertValues() { + List values = new ArrayList<>(); + for (List literals : getTuples()) { + Object[] fields = new Object[literals.size()]; + for (int i = 0; i < fields.length; i++) { + fields[i] = GQLRexUtil.getLiteralValue(literals.get(i)); + } + Row row = ObjectRow.create(fields); + values.add(row); } + return values; + } - @Override - public String showSQL() { - StringBuilder sql = new StringBuilder(); - sql.append("Values(") - .append(getTuples().stream() - .map(literals -> "(" + literals.stream().map(literal -> - literal.toString()).reduce((a, b) -> a + "," + b).get() + ")") - .reduce((a, b) -> a + " " + b).get()) - .append(")"); - return sql.toString(); - } + @Override + public String showSQL() { + StringBuilder sql = new StringBuilder(); + sql.append("Values(") + .append( + getTuples().stream() + .map( + literals -> + "(" + + literals.stream() + .map(literal -> literal.toString()) + .reduce((a, b) -> a + "," + b) + .get() + + ")") + .reduce((a, b) -> a + " " + b) + .get()) + .append(")"); + return sql.toString(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicViewRelNode.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicViewRelNode.java index d862a2d46..6f1a50f0f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicViewRelNode.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/PhysicViewRelNode.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.plan; import java.util.List; + import org.apache.calcite.plan.*; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.TableScan; @@ -30,30 +31,28 @@ public class PhysicViewRelNode extends TableScan implements PhysicRelNode { - public PhysicViewRelNode(RelOptCluster cluster, RelTraitSet traitSet, - RelOptTable table) { - super(cluster, traitSet, table); - } - - @Override - public RelOptCost computeSelfCost(RelOptPlanner planner, - RelMetadataQuery mq) { - return super.computeSelfCost(planner, mq); - } - - @Override - public RelNode copy(RelTraitSet traitSet, List inputs) { - return new PhysicViewRelNode(getCluster(), traitSet, getTable()); - } - - @Override - public RuntimeTable translate(QueryContext context) { - GeaFlowView view = table.unwrap(GeaFlowView.class); - return (RuntimeTable) context.getDataViewByViewName(view.getName()); - } - - @Override - public String showSQL() { - return table.toString(); - } + public PhysicViewRelNode(RelOptCluster cluster, RelTraitSet traitSet, RelOptTable table) { + super(cluster, traitSet, table); + } + + @Override + public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) { + return super.computeSelfCost(planner, mq); + } + + @Override + public RelNode copy(RelTraitSet traitSet, List inputs) { + return new PhysicViewRelNode(getCluster(), traitSet, getTable()); + } + + @Override + public RuntimeTable translate(QueryContext context) { + GeaFlowView view = table.unwrap(GeaFlowView.class); + return (RuntimeTable) context.getDataViewByViewName(view.getName()); + } + + @Override + public String showSQL() { + return table.toString(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertAggregateRule.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertAggregateRule.java index 948f75bc1..093e0192f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertAggregateRule.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertAggregateRule.java @@ -19,7 +19,6 @@ package org.apache.geaflow.dsl.runtime.plan.converters; -import com.google.common.base.Preconditions; import org.apache.calcite.plan.Convention; import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelTraitSet; @@ -30,33 +29,46 @@ import org.apache.geaflow.dsl.runtime.plan.PhysicAggregateRelNode; import org.apache.geaflow.dsl.runtime.plan.PhysicConvention; +import com.google.common.base.Preconditions; + public class ConvertAggregateRule extends ConverterRule { - public static final ConverterRule INSTANCE = new ConvertAggregateRule(); - - private ConvertAggregateRule() { - super(LogicalAggregate.class, Convention.NONE, PhysicConvention.INSTANCE, - ConvertAggregateRule.class.getSimpleName()); - } - - @Override - public boolean matches(RelOptRuleCall call) { - LogicalAggregate aggregate = call.rel(0); - Preconditions.checkArgument(aggregate.getGroupType() == Aggregate.Group.SIMPLE, - "Only support Aggregate.Group.SIMPLE, current group type " + aggregate.getGroupType()); - return true; - } - - @Override - public RelNode convert(RelNode rel) { - LogicalAggregate aggregate = (LogicalAggregate) rel; - - RelTraitSet relTraitSet = aggregate.getTraitSet().replace(PhysicConvention.INSTANCE); - RelNode convertedInput = convert(aggregate.getInput(), aggregate.getInput().getTraitSet() - .replace(PhysicConvention.INSTANCE)); - - return new PhysicAggregateRelNode(aggregate.getCluster(), relTraitSet, - convertedInput, aggregate.indicator, aggregate.getGroupSet(), - aggregate.getGroupSets(), aggregate.getAggCallList()); - } + public static final ConverterRule INSTANCE = new ConvertAggregateRule(); + + private ConvertAggregateRule() { + super( + LogicalAggregate.class, + Convention.NONE, + PhysicConvention.INSTANCE, + ConvertAggregateRule.class.getSimpleName()); + } + + @Override + public boolean matches(RelOptRuleCall call) { + LogicalAggregate aggregate = call.rel(0); + Preconditions.checkArgument( + aggregate.getGroupType() == Aggregate.Group.SIMPLE, + "Only support Aggregate.Group.SIMPLE, current group type " + aggregate.getGroupType()); + return true; + } + + @Override + public RelNode convert(RelNode rel) { + LogicalAggregate aggregate = (LogicalAggregate) rel; + + RelTraitSet relTraitSet = aggregate.getTraitSet().replace(PhysicConvention.INSTANCE); + RelNode convertedInput = + convert( + aggregate.getInput(), + aggregate.getInput().getTraitSet().replace(PhysicConvention.INSTANCE)); + + return new PhysicAggregateRelNode( + aggregate.getCluster(), + relTraitSet, + convertedInput, + aggregate.indicator, + aggregate.getGroupSet(), + aggregate.getGroupSets(), + aggregate.getAggCallList()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertConstructGraphRule.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertConstructGraphRule.java index a0a4dff30..e83713915 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertConstructGraphRule.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertConstructGraphRule.java @@ -29,22 +29,31 @@ public class ConvertConstructGraphRule extends ConverterRule { - public static final ConvertConstructGraphRule INSTANCE = new ConvertConstructGraphRule(); - - private ConvertConstructGraphRule() { - super(LogicalConstructGraph.class, Convention.NONE, - PhysicConvention.INSTANCE, ConvertConstructGraphRule.class.getSimpleName()); - } - - @Override - public RelNode convert(RelNode rel) { - LogicalConstructGraph constructGraph = (LogicalConstructGraph) rel; - - RelTraitSet relTraitSet = constructGraph.getTraitSet().replace(PhysicConvention.INSTANCE); - RelNode convertedInput = convert(constructGraph.getInput(), + public static final ConvertConstructGraphRule INSTANCE = new ConvertConstructGraphRule(); + + private ConvertConstructGraphRule() { + super( + LogicalConstructGraph.class, + Convention.NONE, + PhysicConvention.INSTANCE, + ConvertConstructGraphRule.class.getSimpleName()); + } + + @Override + public RelNode convert(RelNode rel) { + LogicalConstructGraph constructGraph = (LogicalConstructGraph) rel; + + RelTraitSet relTraitSet = constructGraph.getTraitSet().replace(PhysicConvention.INSTANCE); + RelNode convertedInput = + convert( + constructGraph.getInput(), constructGraph.getInput().getTraitSet().replace(PhysicConvention.INSTANCE)); - return new PhysicConstructGraphRelNode(constructGraph.getCluster(), relTraitSet, convertedInput, - constructGraph.getLabelNames(), constructGraph.getRowType()); - } + return new PhysicConstructGraphRelNode( + constructGraph.getCluster(), + relTraitSet, + convertedInput, + constructGraph.getLabelNames(), + constructGraph.getRowType()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertCorrelateRule.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertCorrelateRule.java index 080d5ef82..bd06b9049 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertCorrelateRule.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertCorrelateRule.java @@ -31,32 +31,40 @@ public class ConvertCorrelateRule extends ConverterRule { - public static final ConverterRule INSTANCE = new ConvertCorrelateRule(); - - private ConvertCorrelateRule() { - super(LogicalCorrelate.class, Convention.NONE, - PhysicConvention.INSTANCE, ConvertCorrelateRule.class.getSimpleName()); - } - - @Override - public boolean matches(RelOptRuleCall call) { - LogicalCorrelate join = call.rel(0); - RelNode right = join.getRight(); - return GQLRelUtil.findTableFunctionScan(right); - } - - @Override - public RelNode convert(RelNode relNode) { - LogicalCorrelate join = (LogicalCorrelate) relNode; - RelTraitSet relTraitSet = relNode.getTraitSet().replace(PhysicConvention.INSTANCE); - - RelNode convertedLeft = convert(join.getLeft(), - join.getLeft().getTraitSet().replace(PhysicConvention.INSTANCE)); - RelNode convertedRight = convert(join.getRight(), - join.getRight().getTraitSet().replace(PhysicConvention.INSTANCE)); - - return new PhysicCorrelateRelNode(join.getCluster(), relTraitSet, - convertedLeft, convertedRight, join.getCorrelationId(), - join.getRequiredColumns(), join.getJoinType()); - } + public static final ConverterRule INSTANCE = new ConvertCorrelateRule(); + + private ConvertCorrelateRule() { + super( + LogicalCorrelate.class, + Convention.NONE, + PhysicConvention.INSTANCE, + ConvertCorrelateRule.class.getSimpleName()); + } + + @Override + public boolean matches(RelOptRuleCall call) { + LogicalCorrelate join = call.rel(0); + RelNode right = join.getRight(); + return GQLRelUtil.findTableFunctionScan(right); + } + + @Override + public RelNode convert(RelNode relNode) { + LogicalCorrelate join = (LogicalCorrelate) relNode; + RelTraitSet relTraitSet = relNode.getTraitSet().replace(PhysicConvention.INSTANCE); + + RelNode convertedLeft = + convert(join.getLeft(), join.getLeft().getTraitSet().replace(PhysicConvention.INSTANCE)); + RelNode convertedRight = + convert(join.getRight(), join.getRight().getTraitSet().replace(PhysicConvention.INSTANCE)); + + return new PhysicCorrelateRelNode( + join.getCluster(), + relTraitSet, + convertedLeft, + convertedRight, + join.getCorrelationId(), + join.getRequiredColumns(), + join.getJoinType()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertExchangeRule.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertExchangeRule.java index f72fa2a4e..a517bb84a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertExchangeRule.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertExchangeRule.java @@ -30,27 +30,32 @@ public class ConvertExchangeRule extends ConverterRule { - public static final ConverterRule INSTANCE = new ConvertExchangeRule(); - - private ConvertExchangeRule() { - super(LogicalExchange.class, Convention.NONE, - PhysicConvention.INSTANCE, ConvertExchangeRule.class.getSimpleName()); - } - - @Override - public boolean matches(RelOptRuleCall call) { - return super.matches(call); - } - - @Override - public RelNode convert(RelNode rel) { - LogicalExchange exchange = (LogicalExchange) rel; - - RelTraitSet relTraitSet = exchange.getTraitSet().replace(PhysicConvention.INSTANCE); - RelNode convertedInput = convert(exchange.getInput(), + public static final ConverterRule INSTANCE = new ConvertExchangeRule(); + + private ConvertExchangeRule() { + super( + LogicalExchange.class, + Convention.NONE, + PhysicConvention.INSTANCE, + ConvertExchangeRule.class.getSimpleName()); + } + + @Override + public boolean matches(RelOptRuleCall call) { + return super.matches(call); + } + + @Override + public RelNode convert(RelNode rel) { + LogicalExchange exchange = (LogicalExchange) rel; + + RelTraitSet relTraitSet = exchange.getTraitSet().replace(PhysicConvention.INSTANCE); + RelNode convertedInput = + convert( + exchange.getInput(), exchange.getInput().getTraitSet().replace(PhysicConvention.INSTANCE)); - return new PhysicExchangeRelNode(exchange.getCluster(), relTraitSet, - convertedInput, exchange.getDistribution()); - } + return new PhysicExchangeRelNode( + exchange.getCluster(), relTraitSet, convertedInput, exchange.getDistribution()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertFilterRule.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertFilterRule.java index 5a0b98012..7527736c9 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertFilterRule.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertFilterRule.java @@ -29,22 +29,26 @@ public class ConvertFilterRule extends ConverterRule { - public static final ConverterRule INSTANCE = new ConvertFilterRule(); - - private ConvertFilterRule() { - super(Filter.class, Convention.NONE, - PhysicConvention.INSTANCE, ConvertFilterRule.class.getSimpleName()); - } - - @Override - public RelNode convert(RelNode rel) { - Filter filter = (Filter) rel; - RelTraitSet relTraitSet = filter.getTraitSet().replace(PhysicConvention.INSTANCE); - - RelNode convertedInput = convert(filter.getInput(), - filter.getInput().getTraitSet().replace(PhysicConvention.INSTANCE)); - - return new PhysicFilterRelNode(filter.getCluster(), relTraitSet, - convertedInput, filter.getCondition()); - } + public static final ConverterRule INSTANCE = new ConvertFilterRule(); + + private ConvertFilterRule() { + super( + Filter.class, + Convention.NONE, + PhysicConvention.INSTANCE, + ConvertFilterRule.class.getSimpleName()); + } + + @Override + public RelNode convert(RelNode rel) { + Filter filter = (Filter) rel; + RelTraitSet relTraitSet = filter.getTraitSet().replace(PhysicConvention.INSTANCE); + + RelNode convertedInput = + convert( + filter.getInput(), filter.getInput().getTraitSet().replace(PhysicConvention.INSTANCE)); + + return new PhysicFilterRelNode( + filter.getCluster(), relTraitSet, convertedInput, filter.getCondition()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertGraphAlgorithmRule.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertGraphAlgorithmRule.java index c10ea9f63..5a7b4f1d2 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertGraphAlgorithmRule.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertGraphAlgorithmRule.java @@ -29,22 +29,31 @@ public class ConvertGraphAlgorithmRule extends ConverterRule { - public static final ConvertGraphAlgorithmRule INSTANCE = new ConvertGraphAlgorithmRule(); - - private ConvertGraphAlgorithmRule() { - super(LogicalGraphAlgorithm.class, Convention.NONE, - PhysicConvention.INSTANCE, ConvertGraphAlgorithmRule.class.getSimpleName()); - } - - @Override - public RelNode convert(RelNode rel) { - LogicalGraphAlgorithm graphAlgorithm = (LogicalGraphAlgorithm) rel; - - RelTraitSet relTraitSet = graphAlgorithm.getTraitSet().replace(PhysicConvention.INSTANCE); - RelNode convertedInput = convert(graphAlgorithm.getInput(), + public static final ConvertGraphAlgorithmRule INSTANCE = new ConvertGraphAlgorithmRule(); + + private ConvertGraphAlgorithmRule() { + super( + LogicalGraphAlgorithm.class, + Convention.NONE, + PhysicConvention.INSTANCE, + ConvertGraphAlgorithmRule.class.getSimpleName()); + } + + @Override + public RelNode convert(RelNode rel) { + LogicalGraphAlgorithm graphAlgorithm = (LogicalGraphAlgorithm) rel; + + RelTraitSet relTraitSet = graphAlgorithm.getTraitSet().replace(PhysicConvention.INSTANCE); + RelNode convertedInput = + convert( + graphAlgorithm.getInput(), graphAlgorithm.getInput().getTraitSet().replace(PhysicConvention.INSTANCE)); - return new PhysicGraphAlgorithm(graphAlgorithm.getCluster(), relTraitSet, - convertedInput, graphAlgorithm.getUserFunctionClass(), graphAlgorithm.getParams()); - } + return new PhysicGraphAlgorithm( + graphAlgorithm.getCluster(), + relTraitSet, + convertedInput, + graphAlgorithm.getUserFunctionClass(), + graphAlgorithm.getParams()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertGraphMatchRule.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertGraphMatchRule.java index c363f91c8..14c542884 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertGraphMatchRule.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertGraphMatchRule.java @@ -29,22 +29,31 @@ public class ConvertGraphMatchRule extends ConverterRule { - public static final ConvertGraphMatchRule INSTANCE = new ConvertGraphMatchRule(); - - private ConvertGraphMatchRule() { - super(LogicalGraphMatch.class, Convention.NONE, - PhysicConvention.INSTANCE, ConvertGraphMatchRule.class.getSimpleName()); - } - - @Override - public RelNode convert(RelNode rel) { - LogicalGraphMatch graphMatch = (LogicalGraphMatch) rel; - - RelTraitSet relTraitSet = graphMatch.getTraitSet().replace(PhysicConvention.INSTANCE); - RelNode convertedInput = convert(graphMatch.getInput(), + public static final ConvertGraphMatchRule INSTANCE = new ConvertGraphMatchRule(); + + private ConvertGraphMatchRule() { + super( + LogicalGraphMatch.class, + Convention.NONE, + PhysicConvention.INSTANCE, + ConvertGraphMatchRule.class.getSimpleName()); + } + + @Override + public RelNode convert(RelNode rel) { + LogicalGraphMatch graphMatch = (LogicalGraphMatch) rel; + + RelTraitSet relTraitSet = graphMatch.getTraitSet().replace(PhysicConvention.INSTANCE); + RelNode convertedInput = + convert( + graphMatch.getInput(), graphMatch.getInput().getTraitSet().replace(PhysicConvention.INSTANCE)); - return new PhysicGraphMatchRelNode(graphMatch.getCluster(), relTraitSet, - convertedInput, graphMatch.getPathPattern(), graphMatch.getRowType()); - } + return new PhysicGraphMatchRelNode( + graphMatch.getCluster(), + relTraitSet, + convertedInput, + graphMatch.getPathPattern(), + graphMatch.getRowType()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertGraphModifyRule.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertGraphModifyRule.java index af4b2fb8f..7a57dcb46 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertGraphModifyRule.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertGraphModifyRule.java @@ -29,25 +29,27 @@ public class ConvertGraphModifyRule extends ConverterRule { - public static final ConvertGraphModifyRule INSTANCE = new ConvertGraphModifyRule(); - - private ConvertGraphModifyRule() { - super(LogicalGraphModify.class, Convention.NONE, - PhysicConvention.INSTANCE, ConvertGraphModifyRule.class.getSimpleName()); - } - - @Override - public RelNode convert(RelNode rel) { - LogicalGraphModify graphModify = (LogicalGraphModify) rel; - - RelTraitSet relTraitSet = graphModify.getTraitSet().replace(PhysicConvention.INSTANCE); - RelNode convertedInput = convert(graphModify.getInput(), + public static final ConvertGraphModifyRule INSTANCE = new ConvertGraphModifyRule(); + + private ConvertGraphModifyRule() { + super( + LogicalGraphModify.class, + Convention.NONE, + PhysicConvention.INSTANCE, + ConvertGraphModifyRule.class.getSimpleName()); + } + + @Override + public RelNode convert(RelNode rel) { + LogicalGraphModify graphModify = (LogicalGraphModify) rel; + + RelTraitSet relTraitSet = graphModify.getTraitSet().replace(PhysicConvention.INSTANCE); + RelNode convertedInput = + convert( + graphModify.getInput(), graphModify.getInput().getTraitSet().replace(PhysicConvention.INSTANCE)); - return new PhysicGraphModifyRelNode( - graphModify.getCluster(), - relTraitSet, - graphModify.getGraph(), - convertedInput); - } + return new PhysicGraphModifyRelNode( + graphModify.getCluster(), relTraitSet, graphModify.getGraph(), convertedInput); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertGraphScanRule.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertGraphScanRule.java index dff81ec16..248a6394d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertGraphScanRule.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertGraphScanRule.java @@ -31,25 +31,28 @@ public class ConvertGraphScanRule extends ConverterRule { - public static final ConvertGraphScanRule INSTANCE = new ConvertGraphScanRule(); - - private ConvertGraphScanRule() { - super(LogicalGraphScan.class, Convention.NONE, - PhysicConvention.INSTANCE, ConvertGraphScanRule.class.getSimpleName()); - } - - @Override - public boolean matches(RelOptRuleCall call) { - LogicalGraphScan scan = call.rel(0); - GeaFlowGraph table = scan.getTable().unwrap(GeaFlowGraph.class); - return table != null; - } - - @Override - public RelNode convert(RelNode rel) { - LogicalGraphScan graphScan = (LogicalGraphScan) rel; - - RelTraitSet relTraitSet = graphScan.getTraitSet().replace(PhysicConvention.INSTANCE); - return new PhysicGraphScanRelNode(graphScan.getCluster(), relTraitSet, graphScan.getTable()); - } + public static final ConvertGraphScanRule INSTANCE = new ConvertGraphScanRule(); + + private ConvertGraphScanRule() { + super( + LogicalGraphScan.class, + Convention.NONE, + PhysicConvention.INSTANCE, + ConvertGraphScanRule.class.getSimpleName()); + } + + @Override + public boolean matches(RelOptRuleCall call) { + LogicalGraphScan scan = call.rel(0); + GeaFlowGraph table = scan.getTable().unwrap(GeaFlowGraph.class); + return table != null; + } + + @Override + public RelNode convert(RelNode rel) { + LogicalGraphScan graphScan = (LogicalGraphScan) rel; + + RelTraitSet relTraitSet = graphScan.getTraitSet().replace(PhysicConvention.INSTANCE); + return new PhysicGraphScanRelNode(graphScan.getCluster(), relTraitSet, graphScan.getTable()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertParameterizedRelNodeRule.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertParameterizedRelNodeRule.java index 5514ea8d8..e493468de 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertParameterizedRelNodeRule.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertParameterizedRelNodeRule.java @@ -29,30 +29,30 @@ public class ConvertParameterizedRelNodeRule extends ConverterRule { - public static final ConvertParameterizedRelNodeRule INSTANCE = new ConvertParameterizedRelNodeRule(); - - private ConvertParameterizedRelNodeRule() { - super(LogicalParameterizedRelNode.class, Convention.NONE, - PhysicConvention.INSTANCE, ConvertParameterizedRelNodeRule.class.getSimpleName()); - } - - @Override - public RelNode convert(RelNode rel) { - LogicalParameterizedRelNode parameterizedRelNode = (LogicalParameterizedRelNode) rel; - - RelNode parameter = parameterizedRelNode.getParameterNode(); - RelNode convertedParameter = convert(parameter, - parameter.getTraitSet().replace(PhysicConvention.INSTANCE)); - - RelNode query = parameterizedRelNode.getQueryNode(); - RelNode convertedQuery = convert(query, - query.getTraitSet().replace(PhysicConvention.INSTANCE)); - - RelTraitSet traitSet = rel.getTraitSet().replace(PhysicConvention.INSTANCE); - return new PhysicParameterizedRelNode( - rel.getCluster(), - traitSet, - convertedParameter, - convertedQuery); - } + public static final ConvertParameterizedRelNodeRule INSTANCE = + new ConvertParameterizedRelNodeRule(); + + private ConvertParameterizedRelNodeRule() { + super( + LogicalParameterizedRelNode.class, + Convention.NONE, + PhysicConvention.INSTANCE, + ConvertParameterizedRelNodeRule.class.getSimpleName()); + } + + @Override + public RelNode convert(RelNode rel) { + LogicalParameterizedRelNode parameterizedRelNode = (LogicalParameterizedRelNode) rel; + + RelNode parameter = parameterizedRelNode.getParameterNode(); + RelNode convertedParameter = + convert(parameter, parameter.getTraitSet().replace(PhysicConvention.INSTANCE)); + + RelNode query = parameterizedRelNode.getQueryNode(); + RelNode convertedQuery = convert(query, query.getTraitSet().replace(PhysicConvention.INSTANCE)); + + RelTraitSet traitSet = rel.getTraitSet().replace(PhysicConvention.INSTANCE); + return new PhysicParameterizedRelNode( + rel.getCluster(), traitSet, convertedParameter, convertedQuery); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertProjectRule.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertProjectRule.java index e215aa782..bc2e2e87d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertProjectRule.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertProjectRule.java @@ -29,23 +29,31 @@ public class ConvertProjectRule extends ConverterRule { - public static final ConverterRule INSTANCE = new ConvertProjectRule(); - - private ConvertProjectRule() { - super(LogicalProject.class, Convention.NONE, - PhysicConvention.INSTANCE, ConvertProjectRule.class.getSimpleName()); - } - - @Override - public RelNode convert(RelNode rel) { - LogicalProject project = (LogicalProject) rel; - - RelTraitSet relTraitSet = project.getTraitSet().replace(PhysicConvention.INSTANCE); - RelNode convertedInput = convert(project.getInput(), + public static final ConverterRule INSTANCE = new ConvertProjectRule(); + + private ConvertProjectRule() { + super( + LogicalProject.class, + Convention.NONE, + PhysicConvention.INSTANCE, + ConvertProjectRule.class.getSimpleName()); + } + + @Override + public RelNode convert(RelNode rel) { + LogicalProject project = (LogicalProject) rel; + + RelTraitSet relTraitSet = project.getTraitSet().replace(PhysicConvention.INSTANCE); + RelNode convertedInput = + convert( + project.getInput(), project.getInput().getTraitSet().replace(PhysicConvention.INSTANCE)); - return new PhysicProjectRelNode(project.getCluster(), relTraitSet, convertedInput, - project.getProjects(), project.getRowType()); - } - + return new PhysicProjectRelNode( + project.getCluster(), + relTraitSet, + convertedInput, + project.getProjects(), + project.getRowType()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertRules.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertRules.java index c3620195f..5ad483bf7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertRules.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertRules.java @@ -19,29 +19,30 @@ package org.apache.geaflow.dsl.runtime.plan.converters; -import com.google.common.collect.ImmutableList; import org.apache.calcite.rel.convert.ConverterRule; +import com.google.common.collect.ImmutableList; + public class ConvertRules { - public static ImmutableList TRANSFORM_RULES = ImmutableList.of( - ConvertAggregateRule.INSTANCE, - ConvertCorrelateRule.INSTANCE, - ConvertFilterRule.INSTANCE, - ConvertTableSortRule.INSTANCE, - ConvertProjectRule.INSTANCE, - ConvertTableModifyRule.INSTANCE, - ConvertTableScanRule.INSTANCE, - ConvertViewRule.INSTANCE, - ConvertTableFunctionScanRule.INSTANCE, - ConvertUnionRule.INSTANCE, - ConvertValuesRule.INSTANCE, - ConvertExchangeRule.INSTANCE, - ConvertGraphScanRule.INSTANCE, - ConvertGraphMatchRule.INSTANCE, - ConvertConstructGraphRule.INSTANCE, - ConvertParameterizedRelNodeRule.INSTANCE, - ConvertGraphModifyRule.INSTANCE, - ConvertGraphAlgorithmRule.INSTANCE - ); + public static ImmutableList TRANSFORM_RULES = + ImmutableList.of( + ConvertAggregateRule.INSTANCE, + ConvertCorrelateRule.INSTANCE, + ConvertFilterRule.INSTANCE, + ConvertTableSortRule.INSTANCE, + ConvertProjectRule.INSTANCE, + ConvertTableModifyRule.INSTANCE, + ConvertTableScanRule.INSTANCE, + ConvertViewRule.INSTANCE, + ConvertTableFunctionScanRule.INSTANCE, + ConvertUnionRule.INSTANCE, + ConvertValuesRule.INSTANCE, + ConvertExchangeRule.INSTANCE, + ConvertGraphScanRule.INSTANCE, + ConvertGraphMatchRule.INSTANCE, + ConvertConstructGraphRule.INSTANCE, + ConvertParameterizedRelNodeRule.INSTANCE, + ConvertGraphModifyRule.INSTANCE, + ConvertGraphAlgorithmRule.INSTANCE); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertTableFunctionScanRule.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertTableFunctionScanRule.java index 33eac466c..008a4a1a9 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertTableFunctionScanRule.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertTableFunctionScanRule.java @@ -30,34 +30,34 @@ public class ConvertTableFunctionScanRule extends ConverterRule { - public static final ConverterRule INSTANCE = new ConvertTableFunctionScanRule(); - - private ConvertTableFunctionScanRule() { - super(LogicalTableFunctionScan.class, Convention.NONE, - PhysicConvention.INSTANCE, ConvertTableFunctionScanRule.class.getSimpleName()); - } - - @Override - public boolean matches(RelOptRuleCall call) { - return super.matches(call); - } - - @Override - public RelNode convert(RelNode relNode) { - LogicalTableFunctionScan tableFunctionScan = (LogicalTableFunctionScan) relNode; - - RelTraitSet relTraitSet = - tableFunctionScan.getTraitSet() - .replace(PhysicConvention.INSTANCE); - - return new PhysicTableFunctionScanRelNode( - tableFunctionScan.getCluster(), - relTraitSet, - tableFunctionScan.getInputs(), - tableFunctionScan.getCall(), - tableFunctionScan.getElementType(), - tableFunctionScan.getRowType(), - tableFunctionScan.getColumnMappings()); - } - + public static final ConverterRule INSTANCE = new ConvertTableFunctionScanRule(); + + private ConvertTableFunctionScanRule() { + super( + LogicalTableFunctionScan.class, + Convention.NONE, + PhysicConvention.INSTANCE, + ConvertTableFunctionScanRule.class.getSimpleName()); + } + + @Override + public boolean matches(RelOptRuleCall call) { + return super.matches(call); + } + + @Override + public RelNode convert(RelNode relNode) { + LogicalTableFunctionScan tableFunctionScan = (LogicalTableFunctionScan) relNode; + + RelTraitSet relTraitSet = tableFunctionScan.getTraitSet().replace(PhysicConvention.INSTANCE); + + return new PhysicTableFunctionScanRelNode( + tableFunctionScan.getCluster(), + relTraitSet, + tableFunctionScan.getInputs(), + tableFunctionScan.getCall(), + tableFunctionScan.getElementType(), + tableFunctionScan.getRowType(), + tableFunctionScan.getColumnMappings()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertTableModifyRule.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertTableModifyRule.java index 45baab994..78605687e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertTableModifyRule.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertTableModifyRule.java @@ -29,29 +29,31 @@ public class ConvertTableModifyRule extends ConverterRule { - public static final ConverterRule INSTANCE = new ConvertTableModifyRule(); - - private ConvertTableModifyRule() { - super(LogicalTableModify.class, Convention.NONE, - PhysicConvention.INSTANCE, ConvertTableModifyRule.class.getSimpleName()); - } - - @Override - public RelNode convert(RelNode relNode) { - LogicalTableModify modify = (LogicalTableModify) relNode; - - RelTraitSet relTraitSet = modify.getTraitSet().replace(PhysicConvention.INSTANCE); - - return new PhysicTableModifyRelNode( - modify.getCluster(), - relTraitSet, - modify.getTable(), - modify.getCatalogReader(), - convert(modify.getInput(), relTraitSet), - modify.getOperation(), - modify.getUpdateColumnList(), - modify.getSourceExpressionList(), - modify.isFlattened()); - } - + public static final ConverterRule INSTANCE = new ConvertTableModifyRule(); + + private ConvertTableModifyRule() { + super( + LogicalTableModify.class, + Convention.NONE, + PhysicConvention.INSTANCE, + ConvertTableModifyRule.class.getSimpleName()); + } + + @Override + public RelNode convert(RelNode relNode) { + LogicalTableModify modify = (LogicalTableModify) relNode; + + RelTraitSet relTraitSet = modify.getTraitSet().replace(PhysicConvention.INSTANCE); + + return new PhysicTableModifyRelNode( + modify.getCluster(), + relTraitSet, + modify.getTable(), + modify.getCatalogReader(), + convert(modify.getInput(), relTraitSet), + modify.getOperation(), + modify.getUpdateColumnList(), + modify.getSourceExpressionList(), + modify.isFlattened()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertTableScanRule.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertTableScanRule.java index fedd1ea82..d69b5a495 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertTableScanRule.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertTableScanRule.java @@ -31,25 +31,28 @@ public class ConvertTableScanRule extends ConverterRule { - public static final ConverterRule INSTANCE = new ConvertTableScanRule(); - - private ConvertTableScanRule() { - super(LogicalTableScan.class, Convention.NONE, - PhysicConvention.INSTANCE, ConvertTableScanRule.class.getSimpleName()); - } - - @Override - public boolean matches(RelOptRuleCall call) { - LogicalTableScan scan = call.rel(0); - GeaFlowTable table = scan.getTable().unwrap(GeaFlowTable.class); - return table != null; - } - - @Override - public RelNode convert(RelNode relNode) { - LogicalTableScan scan = (LogicalTableScan) relNode; - - RelTraitSet relTraitSet = relNode.getTraitSet().replace(PhysicConvention.INSTANCE); - return new PhysicTableScanRelNode(relNode.getCluster(), relTraitSet, scan.getTable()); - } + public static final ConverterRule INSTANCE = new ConvertTableScanRule(); + + private ConvertTableScanRule() { + super( + LogicalTableScan.class, + Convention.NONE, + PhysicConvention.INSTANCE, + ConvertTableScanRule.class.getSimpleName()); + } + + @Override + public boolean matches(RelOptRuleCall call) { + LogicalTableScan scan = call.rel(0); + GeaFlowTable table = scan.getTable().unwrap(GeaFlowTable.class); + return table != null; + } + + @Override + public RelNode convert(RelNode relNode) { + LogicalTableScan scan = (LogicalTableScan) relNode; + + RelTraitSet relTraitSet = relNode.getTraitSet().replace(PhysicConvention.INSTANCE); + return new PhysicTableScanRelNode(relNode.getCluster(), relTraitSet, scan.getTable()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertTableSortRule.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertTableSortRule.java index 1d0d94d5a..eb918fead 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertTableSortRule.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertTableSortRule.java @@ -29,26 +29,32 @@ public class ConvertTableSortRule extends ConverterRule { - public static final ConverterRule INSTANCE = new ConvertTableSortRule(); - - private ConvertTableSortRule() { - super(LogicalSort.class, Convention.NONE, PhysicConvention.INSTANCE, - ConvertTableSortRule.class.getSimpleName()); - } - - @Override - public boolean matches(RelOptRuleCall call) { - return super.matches(call); - } - - @Override - public RelNode convert(RelNode rel) { - LogicalSort sort = (LogicalSort) rel; - RelNode input = sort.getInput(); - - return new PhysicSortRelNode(sort.getCluster(), - sort.getTraitSet().replace(PhysicConvention.INSTANCE), - convert(input, input.getTraitSet().replace(PhysicConvention.INSTANCE)), - sort.getCollation(), sort.offset, sort.fetch); - } + public static final ConverterRule INSTANCE = new ConvertTableSortRule(); + + private ConvertTableSortRule() { + super( + LogicalSort.class, + Convention.NONE, + PhysicConvention.INSTANCE, + ConvertTableSortRule.class.getSimpleName()); + } + + @Override + public boolean matches(RelOptRuleCall call) { + return super.matches(call); + } + + @Override + public RelNode convert(RelNode rel) { + LogicalSort sort = (LogicalSort) rel; + RelNode input = sort.getInput(); + + return new PhysicSortRelNode( + sort.getCluster(), + sort.getTraitSet().replace(PhysicConvention.INSTANCE), + convert(input, input.getTraitSet().replace(PhysicConvention.INSTANCE)), + sort.getCollation(), + sort.offset, + sort.fetch); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertUnionRule.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertUnionRule.java index 74b91b350..aa6d2fdff 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertUnionRule.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertUnionRule.java @@ -19,8 +19,8 @@ package org.apache.geaflow.dsl.runtime.plan.converters; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.calcite.plan.Convention; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.RelNode; @@ -29,26 +29,31 @@ import org.apache.geaflow.dsl.runtime.plan.PhysicConvention; import org.apache.geaflow.dsl.runtime.plan.PhysicUnionRelNode; -public class ConvertUnionRule extends ConverterRule { - - public static final ConverterRule INSTANCE = new ConvertUnionRule(); +import com.google.common.collect.Lists; - public ConvertUnionRule() { - super(LogicalUnion.class, Convention.NONE, - PhysicConvention.INSTANCE, ConvertUnionRule.class.getSimpleName()); - } +public class ConvertUnionRule extends ConverterRule { - @Override - public RelNode convert(RelNode rel) { - LogicalUnion union = (LogicalUnion) rel; - RelTraitSet relTraitSet = union.getTraitSet().replace(PhysicConvention.INSTANCE); - - List convertedInputs = Lists.newArrayList(); - for (RelNode input : union.getInputs()) { - RelNode convertedInput = convert(input, input.getTraitSet().replace(PhysicConvention.INSTANCE)); - convertedInputs.add(convertedInput); - } - return new PhysicUnionRelNode(union.getCluster(), relTraitSet, convertedInputs, union.all); + public static final ConverterRule INSTANCE = new ConvertUnionRule(); + + public ConvertUnionRule() { + super( + LogicalUnion.class, + Convention.NONE, + PhysicConvention.INSTANCE, + ConvertUnionRule.class.getSimpleName()); + } + + @Override + public RelNode convert(RelNode rel) { + LogicalUnion union = (LogicalUnion) rel; + RelTraitSet relTraitSet = union.getTraitSet().replace(PhysicConvention.INSTANCE); + + List convertedInputs = Lists.newArrayList(); + for (RelNode input : union.getInputs()) { + RelNode convertedInput = + convert(input, input.getTraitSet().replace(PhysicConvention.INSTANCE)); + convertedInputs.add(convertedInput); } - + return new PhysicUnionRelNode(union.getCluster(), relTraitSet, convertedInputs, union.all); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertValuesRule.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertValuesRule.java index 2fdefc2a0..2b3e40d6f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertValuesRule.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertValuesRule.java @@ -29,17 +29,21 @@ public class ConvertValuesRule extends ConverterRule { - public static final ConverterRule INSTANCE = new ConvertValuesRule(); + public static final ConverterRule INSTANCE = new ConvertValuesRule(); - private ConvertValuesRule() { - super(LogicalValues.class, Convention.NONE, - PhysicConvention.INSTANCE, ConvertValuesRule.class.getSimpleName()); - } + private ConvertValuesRule() { + super( + LogicalValues.class, + Convention.NONE, + PhysicConvention.INSTANCE, + ConvertValuesRule.class.getSimpleName()); + } - @Override - public RelNode convert(RelNode rel) { - LogicalValues values = (LogicalValues) rel; - RelTraitSet relTraitSet = values.getTraitSet().replace(PhysicConvention.INSTANCE); - return new PhysicValuesRelNode(values.getCluster(), values.getRowType(), values.getTuples(), relTraitSet); - } + @Override + public RelNode convert(RelNode rel) { + LogicalValues values = (LogicalValues) rel; + RelTraitSet relTraitSet = values.getTraitSet().replace(PhysicConvention.INSTANCE); + return new PhysicValuesRelNode( + values.getCluster(), values.getRowType(), values.getTuples(), relTraitSet); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertViewRule.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertViewRule.java index 1ef2de2e7..f62c689ef 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertViewRule.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/plan/converters/ConvertViewRule.java @@ -31,26 +31,28 @@ public class ConvertViewRule extends ConverterRule { - public static final ConverterRule INSTANCE = new ConvertViewRule(); - - private ConvertViewRule() { - super(LogicalTableScan.class, Convention.NONE, - PhysicConvention.INSTANCE, ConvertViewRule.class.getSimpleName()); - } - - @Override - public boolean matches(RelOptRuleCall call) { - LogicalTableScan scan = call.rel(0); - GeaFlowView view = scan.getTable().unwrap(GeaFlowView.class); - return view != null; - } - - @Override - public RelNode convert(RelNode relNode) { - LogicalTableScan scan = (LogicalTableScan) relNode; - RelTraitSet relTraitSet = relNode.getTraitSet().replace(PhysicConvention.INSTANCE); - - return new PhysicViewRelNode(relNode.getCluster(), relTraitSet, scan.getTable()); - } - + public static final ConverterRule INSTANCE = new ConvertViewRule(); + + private ConvertViewRule() { + super( + LogicalTableScan.class, + Convention.NONE, + PhysicConvention.INSTANCE, + ConvertViewRule.class.getSimpleName()); + } + + @Override + public boolean matches(RelOptRuleCall call) { + LogicalTableScan scan = call.rel(0); + GeaFlowView view = scan.getTable().unwrap(GeaFlowView.class); + return view != null; + } + + @Override + public RelNode convert(RelNode relNode) { + LogicalTableScan scan = (LogicalTableScan) relNode; + RelTraitSet relTraitSet = relNode.getTraitSet().replace(PhysicConvention.INSTANCE); + + return new PhysicViewRelNode(relNode.getCluster(), relTraitSet, scan.getTable()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/DagGroupBuilder.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/DagGroupBuilder.java index 8288cff5c..e7cc986b2 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/DagGroupBuilder.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/DagGroupBuilder.java @@ -21,33 +21,34 @@ import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.dsl.runtime.traversal.operator.StepEndOperator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class DagGroupBuilder { - private static final Logger LOGGER = LoggerFactory.getLogger(DagGroupBuilder.class); - - public static final String MAIN_QUERY_NAME = "Main"; - - public DagTopologyGroup buildDagGroup(StepLogicalPlanSet logicalPlanSet) { - logicalPlanSet = logicalPlanSet.markChainable(); - assert logicalPlanSet.getMainPlan().getOperator() instanceof StepEndOperator; - LOGGER.info("[DGB]Step logical plan description:\n{}", logicalPlanSet.getPlanSetDesc()); - DagTopology mainDag = DagTopology.build(MAIN_QUERY_NAME, logicalPlanSet.getMainPlan()); - Map subDags = new HashMap<>(); - for (Map.Entry entry : logicalPlanSet.getSubPlans().entrySet()) { - String queryName = entry.getKey(); - StepLogicalPlan subPlan = entry.getValue(); - DagTopology subDag = DagTopology.build(queryName, subPlan); - subDags.put(queryName, subDag); - } - return new DagTopologyGroup(mainDag, subDags); - } + private static final Logger LOGGER = LoggerFactory.getLogger(DagGroupBuilder.class); - public ExecuteDagGroup buildExecuteDagGroup(StepLogicalPlanSet logicalPlanSet) { - DagTopologyGroup dagGroup = buildDagGroup(logicalPlanSet); - return new ExecuteDagGroupImpl(dagGroup); + public static final String MAIN_QUERY_NAME = "Main"; + + public DagTopologyGroup buildDagGroup(StepLogicalPlanSet logicalPlanSet) { + logicalPlanSet = logicalPlanSet.markChainable(); + assert logicalPlanSet.getMainPlan().getOperator() instanceof StepEndOperator; + LOGGER.info("[DGB]Step logical plan description:\n{}", logicalPlanSet.getPlanSetDesc()); + DagTopology mainDag = DagTopology.build(MAIN_QUERY_NAME, logicalPlanSet.getMainPlan()); + Map subDags = new HashMap<>(); + for (Map.Entry entry : logicalPlanSet.getSubPlans().entrySet()) { + String queryName = entry.getKey(); + StepLogicalPlan subPlan = entry.getValue(); + DagTopology subDag = DagTopology.build(queryName, subPlan); + subDags.put(queryName, subDag); } + return new DagTopologyGroup(mainDag, subDags); + } + + public ExecuteDagGroup buildExecuteDagGroup(StepLogicalPlanSet logicalPlanSet) { + DagTopologyGroup dagGroup = buildDagGroup(logicalPlanSet); + return new ExecuteDagGroupImpl(dagGroup); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/DagTopology.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/DagTopology.java index c5a608fb9..24a271bfb 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/DagTopology.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/DagTopology.java @@ -26,143 +26,146 @@ import java.util.List; import java.util.Map; import java.util.Set; + import org.apache.commons.lang3.tuple.Pair; import org.apache.geaflow.dsl.common.data.StepRecord; import org.apache.geaflow.dsl.runtime.traversal.operator.StepOperator; public class DagTopology implements Serializable { - /** - * The name of query that this DAG represents. - */ - private final String queryName; - - private final StepOperator entryOperator; - - /** - * The output node ids for each node in the DAG. - */ - private final Map> opId2OutputIds; - - /** - * The input node ids for each node in the DAG. - */ - private final Map> opId2InputIds; - - private final Set> chainableOpIdPairs; - - private final Map> opId2Operators; - - private DagTopology(String queryName, - StepOperator entryOperator, - Set> chainableOpIdPairs, - Map> opId2InputIds, - Map> opId2OutputIds, - Map> opId2Operators) { - this.queryName = queryName; - this.entryOperator = entryOperator; - this.chainableOpIdPairs = chainableOpIdPairs; - this.opId2InputIds = opId2InputIds; - this.opId2OutputIds = opId2OutputIds; - this.opId2Operators = opId2Operators; - } - - @SuppressWarnings("unchecked") - public static DagTopology build(String queryName, StepLogicalPlan logicalPlan) { - Set> chainableOpIdPairs = new HashSet<>(); - Map> opId2InputIds = new HashMap<>(); - Map> opId2OutputIds = new HashMap<>(); - Map> opId2Operators = new HashMap<>(); - - generateChainOpIds(logicalPlan, chainableOpIdPairs); - generateLogicalDependency(logicalPlan, opId2InputIds, opId2OutputIds); - addNextOperators(logicalPlan, opId2Operators); - return new DagTopology(queryName, logicalPlan.getHeadPlan().getOperator(), chainableOpIdPairs, - opId2InputIds, opId2OutputIds, opId2Operators); + /** The name of query that this DAG represents. */ + private final String queryName; + + private final StepOperator entryOperator; + + /** The output node ids for each node in the DAG. */ + private final Map> opId2OutputIds; + + /** The input node ids for each node in the DAG. */ + private final Map> opId2InputIds; + + private final Set> chainableOpIdPairs; + + private final Map> opId2Operators; + + private DagTopology( + String queryName, + StepOperator entryOperator, + Set> chainableOpIdPairs, + Map> opId2InputIds, + Map> opId2OutputIds, + Map> opId2Operators) { + this.queryName = queryName; + this.entryOperator = entryOperator; + this.chainableOpIdPairs = chainableOpIdPairs; + this.opId2InputIds = opId2InputIds; + this.opId2OutputIds = opId2OutputIds; + this.opId2Operators = opId2Operators; + } + + @SuppressWarnings("unchecked") + public static DagTopology build(String queryName, StepLogicalPlan logicalPlan) { + Set> chainableOpIdPairs = new HashSet<>(); + Map> opId2InputIds = new HashMap<>(); + Map> opId2OutputIds = new HashMap<>(); + Map> opId2Operators = new HashMap<>(); + + generateChainOpIds(logicalPlan, chainableOpIdPairs); + generateLogicalDependency(logicalPlan, opId2InputIds, opId2OutputIds); + addNextOperators(logicalPlan, opId2Operators); + return new DagTopology( + queryName, + logicalPlan.getHeadPlan().getOperator(), + chainableOpIdPairs, + opId2InputIds, + opId2OutputIds, + opId2Operators); + } + + private static void generateLogicalDependency( + StepLogicalPlan endPlan, + Map> opId2InputIds, + Map> opId2OutputIds) { + for (StepLogicalPlan inputPlan : endPlan.getInputs()) { + List outputIds = + opId2OutputIds.computeIfAbsent(inputPlan.getId(), k -> new ArrayList<>()); + if (!outputIds.contains(endPlan.getId())) { + outputIds.add(endPlan.getId()); + opId2OutputIds.put(inputPlan.getId(), outputIds); + } + List inputIds = opId2InputIds.computeIfAbsent(endPlan.getId(), k -> new ArrayList<>()); + if (!inputIds.contains(inputPlan.getId())) { + inputIds.add(inputPlan.getId()); + opId2InputIds.put(endPlan.getId(), inputIds); + } + + generateLogicalDependency(inputPlan, opId2InputIds, opId2OutputIds); } + } - private static void generateLogicalDependency(StepLogicalPlan endPlan, - Map> opId2InputIds, - Map> opId2OutputIds) { - for (StepLogicalPlan inputPlan : endPlan.getInputs()) { - List outputIds = opId2OutputIds.computeIfAbsent(inputPlan.getId(), - k -> new ArrayList<>()); - if (!outputIds.contains(endPlan.getId())) { - outputIds.add(endPlan.getId()); - opId2OutputIds.put(inputPlan.getId(), outputIds); - } - List inputIds = opId2InputIds.computeIfAbsent(endPlan.getId(), - k -> new ArrayList<>()); - if (!inputIds.contains(inputPlan.getId())) { - inputIds.add(inputPlan.getId()); - opId2InputIds.put(endPlan.getId(), inputIds); - } - - generateLogicalDependency(inputPlan, opId2InputIds, opId2OutputIds); - } + private static void generateChainOpIds( + StepLogicalPlan endPlan, Set> chainableOpIdPairs) { + for (StepLogicalPlan inputPlan : endPlan.getInputs()) { + generateChainOpIds(inputPlan, chainableOpIdPairs); } - private static void generateChainOpIds(StepLogicalPlan endPlan, Set> chainableOpIdPairs) { - for (StepLogicalPlan inputPlan : endPlan.getInputs()) { - generateChainOpIds(inputPlan, chainableOpIdPairs); + if (endPlan.getInputs().isEmpty()) { + chainableOpIdPairs.add(Pair.of(endPlan.getId(), endPlan.getId())); + } else { + for (StepLogicalPlan inputPlan : endPlan.getInputs()) { + if (inputPlan.isAllowChain()) { + chainableOpIdPairs.add(Pair.of(endPlan.getId(), inputPlan.getId())); } - - if (endPlan.getInputs().isEmpty()) { - chainableOpIdPairs.add(Pair.of(endPlan.getId(), endPlan.getId())); - } else { - for (StepLogicalPlan inputPlan : endPlan.getInputs()) { - if (inputPlan.isAllowChain()) { - chainableOpIdPairs.add(Pair.of(endPlan.getId(), inputPlan.getId())); - } - } - } - } - - @SuppressWarnings("unchecked") - private static void addNextOperators(StepLogicalPlan endPlan, - Map> opId2Operators) { - opId2Operators.put(endPlan.getId(), endPlan.getOperator()); - for (StepLogicalPlan inputPlan : endPlan.getInputs()) { - inputPlan.getOperator().addNextOperator(endPlan.getOperator()); - addNextOperators(inputPlan, opId2Operators); - } - } - - public String getQueryName() { - return queryName; - } - - public StepOperator getEntryOperator() { - return entryOperator; + } } - - public boolean contains(long opId) { - return opId2Operators.containsKey(opId); - } - - public List getOutputIds(long opId) { - return opId2OutputIds.getOrDefault(opId, new ArrayList<>()); - } - - public List getInputIds(long opId) { - return opId2InputIds.getOrDefault(opId, new ArrayList<>()); - } - - public boolean isChained(long opId1, long opId2) { - return chainableOpIdPairs.contains(Pair.of(opId1, opId2)) - || chainableOpIdPairs.contains(Pair.of(opId2, opId1)); - } - - public Map> getOpId2Operators() { - return opId2Operators; - } - - @SuppressWarnings("unchecked") - public StepOperator getOperator(long opId) { - return (StepOperator) opId2Operators.get(opId); - } - - public long getEntryOpId() { - return entryOperator.getId(); + } + + @SuppressWarnings("unchecked") + private static void addNextOperators( + StepLogicalPlan endPlan, Map> opId2Operators) { + opId2Operators.put(endPlan.getId(), endPlan.getOperator()); + for (StepLogicalPlan inputPlan : endPlan.getInputs()) { + inputPlan.getOperator().addNextOperator(endPlan.getOperator()); + addNextOperators(inputPlan, opId2Operators); } + } + + public String getQueryName() { + return queryName; + } + + public StepOperator getEntryOperator() { + return entryOperator; + } + + public boolean contains(long opId) { + return opId2Operators.containsKey(opId); + } + + public List getOutputIds(long opId) { + return opId2OutputIds.getOrDefault(opId, new ArrayList<>()); + } + + public List getInputIds(long opId) { + return opId2InputIds.getOrDefault(opId, new ArrayList<>()); + } + + public boolean isChained(long opId1, long opId2) { + return chainableOpIdPairs.contains(Pair.of(opId1, opId2)) + || chainableOpIdPairs.contains(Pair.of(opId2, opId1)); + } + + public Map> getOpId2Operators() { + return opId2Operators; + } + + @SuppressWarnings("unchecked") + public StepOperator getOperator( + long opId) { + return (StepOperator) opId2Operators.get(opId); + } + + public long getEntryOpId() { + return entryOperator.getId(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/DagTopologyGroup.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/DagTopologyGroup.java index 3b852c8ab..942cb49e5 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/DagTopologyGroup.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/DagTopologyGroup.java @@ -19,170 +19,174 @@ package org.apache.geaflow.dsl.runtime.traversal; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.dsl.common.data.StepRecord; import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; import org.apache.geaflow.dsl.runtime.traversal.operator.StepLoopUntilOperator; import org.apache.geaflow.dsl.runtime.traversal.operator.StepOperator; +import com.google.common.collect.Lists; + public class DagTopologyGroup { - private final DagTopology mainDag; - - private final Map subDags; - - private final Map> globalOpId2Operators; - - public DagTopologyGroup(DagTopology mainDag, - Map subDags) { - this.mainDag = mainDag; - this.subDags = subDags; - this.globalOpId2Operators = new HashMap<>(); - this.globalOpId2Operators.putAll(mainDag.getOpId2Operators()); - - for (DagTopology subDag : subDags.values()) { - Map> id2Operators = subDag.getOpId2Operators(); - for (Map.Entry> entry : id2Operators.entrySet()) { - long opId = entry.getKey(); - if (globalOpId2Operators.containsKey(opId)) { - throw new GeaFlowDSLException("Operator id: " + opId + " in sub dag: " + subDag.getQueryName() - + " is conflict with other dag."); - } - globalOpId2Operators.put(opId, entry.getValue()); - } + private final DagTopology mainDag; + + private final Map subDags; + + private final Map> globalOpId2Operators; + + public DagTopologyGroup(DagTopology mainDag, Map subDags) { + this.mainDag = mainDag; + this.subDags = subDags; + this.globalOpId2Operators = new HashMap<>(); + this.globalOpId2Operators.putAll(mainDag.getOpId2Operators()); + + for (DagTopology subDag : subDags.values()) { + Map> id2Operators = subDag.getOpId2Operators(); + for (Map.Entry> entry : id2Operators.entrySet()) { + long opId = entry.getKey(); + if (globalOpId2Operators.containsKey(opId)) { + throw new GeaFlowDSLException( + "Operator id: " + + opId + + " in sub dag: " + + subDag.getQueryName() + + " is conflict with other dag."); } + globalOpId2Operators.put(opId, entry.getValue()); + } } + } - public List getOutputIds(long opId) { - if (mainDag.contains(opId)) { - return mainDag.getOutputIds(opId); - } - for (DagTopology subDag : subDags.values()) { - if (subDag.contains(opId)) { - return subDag.getOutputIds(opId); - } - } - throw new IllegalArgumentException("Illegal opId: " + opId); + public List getOutputIds(long opId) { + if (mainDag.contains(opId)) { + return mainDag.getOutputIds(opId); } - - public List getInputIds(long opId) { - if (mainDag.contains(opId)) { - return mainDag.getInputIds(opId); - } - for (DagTopology subDag : subDags.values()) { - if (subDag.contains(opId)) { - return subDag.getInputIds(opId); - } - } - throw new IllegalArgumentException("Illegal opId: " + opId); + for (DagTopology subDag : subDags.values()) { + if (subDag.contains(opId)) { + return subDag.getOutputIds(opId); + } } + throw new IllegalArgumentException("Illegal opId: " + opId); + } - public DagTopology getDagTopology(long opId) { - if (mainDag.contains(opId)) { - return mainDag; - } - for (DagTopology subDag : subDags.values()) { - if (subDag.contains(opId)) { - return subDag; - } - } - throw new IllegalArgumentException("Illegal opId: " + opId); + public List getInputIds(long opId) { + if (mainDag.contains(opId)) { + return mainDag.getInputIds(opId); } - - public boolean isChained(long opId1, long opId2) { - if (mainDag.contains(opId1) && mainDag.contains(opId2)) { - return mainDag.isChained(opId1, opId2); - } - for (DagTopology subDag : subDags.values()) { - if (subDag.contains(opId1) && subDag.contains(opId2)) { - return subDag.isChained(opId1, opId2); - } - } - return false; + for (DagTopology subDag : subDags.values()) { + if (subDag.contains(opId)) { + return subDag.getInputIds(opId); + } } + throw new IllegalArgumentException("Illegal opId: " + opId); + } - public boolean belongMainDag(long opId) { - return mainDag.contains(opId); + public DagTopology getDagTopology(long opId) { + if (mainDag.contains(opId)) { + return mainDag; } - - @SuppressWarnings("unchecked") - public StepOperator getOperator(long opId) { - return globalOpId2Operators.get(opId); + for (DagTopology subDag : subDags.values()) { + if (subDag.contains(opId)) { + return subDag; + } } + throw new IllegalArgumentException("Illegal opId: " + opId); + } - public DagTopology getMainDag() { - return mainDag; + public boolean isChained(long opId1, long opId2) { + if (mainDag.contains(opId1) && mainDag.contains(opId2)) { + return mainDag.isChained(opId1, opId2); } - - public List getAllDagTopology() { - List dagTopologies = new ArrayList<>(); - dagTopologies.add(mainDag); - dagTopologies.addAll(subDags.values()); - return dagTopologies; + for (DagTopology subDag : subDags.values()) { + if (subDag.contains(opId1) && subDag.contains(opId2)) { + return subDag.isChained(opId1, opId2); + } } - - public List getSubDagTopologies() { - return Lists.newArrayList(subDags.values()); + return false; + } + + public boolean belongMainDag(long opId) { + return mainDag.contains(opId); + } + + @SuppressWarnings("unchecked") + public StepOperator getOperator(long opId) { + return globalOpId2Operators.get(opId); + } + + public DagTopology getMainDag() { + return mainDag; + } + + public List getAllDagTopology() { + List dagTopologies = new ArrayList<>(); + dagTopologies.add(mainDag); + dagTopologies.addAll(subDags.values()); + return dagTopologies; + } + + public List getSubDagTopologies() { + return Lists.newArrayList(subDags.values()); + } + + public Collection> getAllOperators() { + return globalOpId2Operators.values(); + } + + public int getIterationCount(int currentDepth, StepOperator stepOperator) { + List subQueryNames = stepOperator.getSubQueryNames(); + int maxSubDagIteration = 0; + for (String subQueryName : subQueryNames) { + DagTopology subDag = this.subDags.get(subQueryName); + assert subDag != null; + int subDagIterationCount = addIteration(getIterationCount(1, subDag.getEntryOperator()), 1); + if (subDagIterationCount > maxSubDagIteration) { + maxSubDagIteration = subDagIterationCount; + } } + currentDepth = addIteration(currentDepth, maxSubDagIteration); - public Collection> getAllOperators() { - return globalOpId2Operators.values(); + if (stepOperator instanceof StepLoopUntilOperator) { + StepLoopUntilOperator loopUntilOperator = (StepLoopUntilOperator) stepOperator; + currentDepth = + addIteration(currentDepth, addIteration(loopUntilOperator.getMaxLoopCount(), 1)); } - - public int getIterationCount(int currentDepth, StepOperator stepOperator) { - List subQueryNames = stepOperator.getSubQueryNames(); - int maxSubDagIteration = 0; - for (String subQueryName : subQueryNames) { - DagTopology subDag = this.subDags.get(subQueryName); - assert subDag != null; - int subDagIterationCount = - addIteration(getIterationCount(1, subDag.getEntryOperator()), 1); - if (subDagIterationCount > maxSubDagIteration) { - maxSubDagIteration = subDagIterationCount; - } - } - currentDepth = addIteration(currentDepth, maxSubDagIteration); - - if (stepOperator instanceof StepLoopUntilOperator) { - StepLoopUntilOperator loopUntilOperator = (StepLoopUntilOperator) stepOperator; - currentDepth = addIteration(currentDepth, - addIteration(loopUntilOperator.getMaxLoopCount(), 1)); - } - int depth = currentDepth; - for (Object op : stepOperator.getNextOperators()) { - StepOperator next = (StepOperator) op; - int branchDepth = getIterationCount(currentDepth, next); - if (!isChained(stepOperator.getId(), next.getId())) { - branchDepth = addIteration(branchDepth, 1); - } - if (branchDepth > depth) { - depth = branchDepth; - } - } - return depth; + int depth = currentDepth; + for (Object op : stepOperator.getNextOperators()) { + StepOperator next = (StepOperator) op; + int branchDepth = getIterationCount(currentDepth, next); + if (!isChained(stepOperator.getId(), next.getId())) { + branchDepth = addIteration(branchDepth, 1); + } + if (branchDepth > depth) { + depth = branchDepth; + } } + return depth; + } - private static int addIteration(int iteration, int delta) { - if (iteration == Integer.MAX_VALUE || iteration < 0 || delta == 0) { - return iteration; - } - if (delta > 0) { - if (Integer.MAX_VALUE - iteration >= delta) { - return iteration + delta; - } else { - return Integer.MAX_VALUE; - } - } else { - if (iteration + delta >= 0) { - return iteration + delta; - } else { - return iteration; - } - } + private static int addIteration(int iteration, int delta) { + if (iteration == Integer.MAX_VALUE || iteration < 0 || delta == 0) { + return iteration; + } + if (delta > 0) { + if (Integer.MAX_VALUE - iteration >= delta) { + return iteration + delta; + } else { + return Integer.MAX_VALUE; + } + } else { + if (iteration + delta >= 0) { + return iteration + delta; + } else { + return iteration; + } } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/ExecuteDagGroup.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/ExecuteDagGroup.java index c1a446eeb..61ee969e2 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/ExecuteDagGroup.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/ExecuteDagGroup.java @@ -23,19 +23,19 @@ public interface ExecuteDagGroup { - void open(TraversalRuntimeContext context); + void open(TraversalRuntimeContext context); - void execute(Object vertexId, long... receiverOpIds); + void execute(Object vertexId, long... receiverOpIds); - void finishIteration(long iterationId); + void finishIteration(long iterationId); - void processBroadcast(MessageBox messageBox); + void processBroadcast(MessageBox messageBox); - void close(); + void close(); - long getEntryOpId(); + long getEntryOpId(); - DagTopology getMainDag(); + DagTopology getMainDag(); - int getMaxIterationCount(); + int getMaxIterationCount(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/ExecuteDagGroupImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/ExecuteDagGroupImpl.java index cf51f5d41..a5d4eb078 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/ExecuteDagGroupImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/ExecuteDagGroupImpl.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; + import org.apache.geaflow.dsl.common.data.RowVertex; import org.apache.geaflow.dsl.common.data.StepRecord; import org.apache.geaflow.dsl.runtime.traversal.data.EndOfData; @@ -39,102 +40,101 @@ public class ExecuteDagGroupImpl implements ExecuteDagGroup { - private final DagTopologyGroup dagGroup; + private final DagTopologyGroup dagGroup; - private TraversalRuntimeContext context; + private TraversalRuntimeContext context; - private final List broadcastMessages = new ArrayList<>(); + private final List broadcastMessages = new ArrayList<>(); - public ExecuteDagGroupImpl(DagTopologyGroup dagGroup) { - this.dagGroup = dagGroup; - } + public ExecuteDagGroupImpl(DagTopologyGroup dagGroup) { + this.dagGroup = dagGroup; + } - @Override - public void open(TraversalRuntimeContext context) { - context.setTopology(dagGroup); - for (DagTopology dagTopology : dagGroup.getAllDagTopology()) { - dagTopology.getEntryOperator().open(context); - } - this.context = context; + @Override + public void open(TraversalRuntimeContext context) { + context.setTopology(dagGroup); + for (DagTopology dagTopology : dagGroup.getAllDagTopology()) { + dagTopology.getEntryOperator().open(context); } - - @Override - public void execute(Object vertexId, long... receiverOpIds) { - RowVertex vertex = IdOnlyVertex.of(vertexId); - for (long receiverOpId : receiverOpIds) { - StepOperator operator = dagGroup.getOperator(receiverOpId); - // set current process operator id. - context.setCurrentOpId(operator.getId()); - context.setVertex(vertex); - ParameterRequestMessage requestMessage = context.getMessage(MessageType.PARAMETER_REQUEST); - if (requestMessage != null && !requestMessage.isEmpty()) { // execute for each request id. - requestMessage.forEach(request -> doExecute(operator, vertex, request)); - } else { // execute for the case without request message. - doExecute(operator, vertex, null); - } - } - } - - private void doExecute(StepOperator operator, - RowVertex vertex, - ParameterRequest request) { - // set current request - context.setRequest(request); - context.setCurrentOpId(operator.getId()); - context.setVertex(vertex); - IPathMessage pathMessage = context.getMessage(MessageType.PATH); - ITreePath treePath = pathMessage == null ? EmptyTreePath.INSTANCE : (ITreePath) pathMessage; - - operator.process(VertexRecord.of(vertex, treePath)); + this.context = context; + } + + @Override + public void execute(Object vertexId, long... receiverOpIds) { + RowVertex vertex = IdOnlyVertex.of(vertexId); + for (long receiverOpId : receiverOpIds) { + StepOperator operator = dagGroup.getOperator(receiverOpId); + // set current process operator id. + context.setCurrentOpId(operator.getId()); + context.setVertex(vertex); + ParameterRequestMessage requestMessage = context.getMessage(MessageType.PARAMETER_REQUEST); + if (requestMessage != null && !requestMessage.isEmpty()) { // execute for each request id. + requestMessage.forEach(request -> doExecute(operator, vertex, request)); + } else { // execute for the case without request message. + doExecute(operator, vertex, null); + } } - - public void finishIteration(long iterationId) { - StepOperator mainOp = dagGroup.getMainDag().getEntryOperator(); - if (iterationId == 1) { - mainOp.process(EndOfData.of(mainOp.getId())); - } else { - // process broadcast message after other normal message has processed. - for (MessageBox messageBox : broadcastMessages) { - long[] receiverOpIds = messageBox.getReceiverIds(); - for (long receiverOpId : receiverOpIds) { - StepOperator operator = dagGroup.getOperator(receiverOpId); - EODMessage eodMessage = messageBox.getMessage(receiverOpId, MessageType.EOD); - if (eodMessage != null) { - for (EndOfData endOfData : eodMessage.getEodData()) { - operator.process(endOfData); - } - } - } + } + + private void doExecute( + StepOperator operator, RowVertex vertex, ParameterRequest request) { + // set current request + context.setRequest(request); + context.setCurrentOpId(operator.getId()); + context.setVertex(vertex); + IPathMessage pathMessage = context.getMessage(MessageType.PATH); + ITreePath treePath = pathMessage == null ? EmptyTreePath.INSTANCE : (ITreePath) pathMessage; + + operator.process(VertexRecord.of(vertex, treePath)); + } + + public void finishIteration(long iterationId) { + StepOperator mainOp = dagGroup.getMainDag().getEntryOperator(); + if (iterationId == 1) { + mainOp.process(EndOfData.of(mainOp.getId())); + } else { + // process broadcast message after other normal message has processed. + for (MessageBox messageBox : broadcastMessages) { + long[] receiverOpIds = messageBox.getReceiverIds(); + for (long receiverOpId : receiverOpIds) { + StepOperator operator = dagGroup.getOperator(receiverOpId); + EODMessage eodMessage = messageBox.getMessage(receiverOpId, MessageType.EOD); + if (eodMessage != null) { + for (EndOfData endOfData : eodMessage.getEodData()) { + operator.process(endOfData); } - broadcastMessages.clear(); + } } + } + broadcastMessages.clear(); } - - @Override - public void processBroadcast(MessageBox messageBox) { - broadcastMessages.add(messageBox); - } - - @Override - public void close() { - Collection> operators = dagGroup.getAllOperators(); - for (StepOperator operator : operators) { - operator.close(); - } - } - - @Override - public long getEntryOpId() { - return dagGroup.getMainDag().getEntryOpId(); - } - - @Override - public DagTopology getMainDag() { - return dagGroup.getMainDag(); - } - - @Override - public int getMaxIterationCount() { - return dagGroup.getIterationCount(1, dagGroup.getMainDag().getEntryOperator()); + } + + @Override + public void processBroadcast(MessageBox messageBox) { + broadcastMessages.add(messageBox); + } + + @Override + public void close() { + Collection> operators = dagGroup.getAllOperators(); + for (StepOperator operator : operators) { + operator.close(); } + } + + @Override + public long getEntryOpId() { + return dagGroup.getMainDag().getEntryOpId(); + } + + @Override + public DagTopology getMainDag() { + return dagGroup.getMainDag(); + } + + @Override + public int getMaxIterationCount() { + return dagGroup.getIterationCount(1, dagGroup.getMainDag().getEntryOperator()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/StepLogicalPlan.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/StepLogicalPlan.java index 8510ed22b..ead7e73f5 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/StepLogicalPlan.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/StepLogicalPlan.java @@ -19,9 +19,6 @@ package org.apache.geaflow.dsl.runtime.traversal; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; -import com.google.common.collect.Sets; import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; @@ -34,6 +31,7 @@ import java.util.Set; import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Collectors; + import org.apache.calcite.rex.RexFieldAccess; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.types.EdgeType; @@ -85,539 +83,547 @@ import org.apache.geaflow.dsl.runtime.traversal.operator.StepSubQueryStartOperator; import org.apache.geaflow.dsl.runtime.traversal.operator.StepUnionOperator; -/** - * Logical plan for traversal step. - */ -public class StepLogicalPlan implements Serializable { - - private static final AtomicLong idCounter = new AtomicLong(0L); - - private List inputs; - - private final StepOperator operator; - - private final List outputs = new ArrayList<>(); - - private boolean allowChain = false; - - public StepLogicalPlan(List inputs, StepOperator operator) { - this.operator = Objects.requireNonNull(operator); - setInputs(inputs); - } - - public StepLogicalPlan(StepLogicalPlan input, StepOperator operator) { - this(input == null ? Collections.emptyList() : Collections.singletonList(input), operator); - } - - public StepLogicalPlan(StepOperator operator) { - this(Collections.emptyList(), operator); - } - - @SuppressWarnings("unchecked") - private void setInputs(List inputs) { - this.inputs = Lists.newArrayList(Objects.requireNonNull(inputs)); - GraphSchema modifyGraphSchema = null; - for (StepLogicalPlan input : inputs) { - input.addOutput(this); - input.operator.addNextOperator((StepOperator) operator); - if (modifyGraphSchema == null) { - modifyGraphSchema = input.operator.getModifyGraphSchema(); - } else { - modifyGraphSchema = modifyGraphSchema.merge(input.operator.getModifyGraphSchema()); - } - } - // inherit modify graph schema from the input. - operator.withModifyGraphSchema(modifyGraphSchema); - } - - /** - * Create a start operator without start ids which means traversal all. - * - * @return The logical plan. - */ - public static StepLogicalPlan start() { - return start(new HashSet<>()); - } - - public static StepLogicalPlan start(Object... ids) { - StartId[] startIds = new StartId[ids.length]; - for (int i = 0; i < startIds.length; i++) { - startIds[i] = new ConstantStartId(ids[i]); - } - return start(Sets.newHashSet(startIds)); - } - - /** - * Create a start operator with start ids which will be the head operator in the DAG. - * - * @param startIds The start ids for traversal. Empty start ids means traversal all. - * @return The logical plan. - */ - public static StepLogicalPlan start(Set startIds) { - StepSourceOperator startOp = new StepSourceOperator(nextPlanId(), startIds); - return new StepLogicalPlan(startOp); - } - - public StepLogicalPlan end() { - StepEndOperator operator = new StepEndOperator(nextPlanId()); - return new StepLogicalPlan(this, operator) - .withGraphSchema(this.getGraphSchema()) - .withInputPathSchema(this.getOutputPathSchema()) - .withOutputPathSchema(this.getOutputPathSchema()) - .withOutputType(VoidType.INSTANCE) - ; - } - - public static StepLogicalPlan subQueryStart(String queryName) { - StepSubQueryStartOperator operator = new StepSubQueryStartOperator(nextPlanId(), - queryName); - return new StepLogicalPlan(Collections.emptyList(), operator); - } - - public StepLogicalPlan vertexMatch(MatchVertexFunction function) { - MatchVertexOperator operator = new MatchVertexOperator(nextPlanId(), function); - return new StepLogicalPlan(this, operator) - .withGraphSchema(this.getGraphSchema()) - .withInputPathSchema(this.getOutputPathSchema()); - } - - public StepLogicalPlan edgeMatch(MatchEdgeFunction function) { - MatchEdgeOperator operator = new MatchEdgeOperator(nextPlanId(), function); - return new StepLogicalPlan(this, operator) - .withGraphSchema(this.getGraphSchema()) - .withInputPathSchema(this.getOutputPathSchema()); - } - - public StepLogicalPlan virtualEdgeMatch(MatchVirtualEdgeFunction function) { - MatchVirtualEdgeOperator operator = new MatchVirtualEdgeOperator(nextPlanId(), function); - return new StepLogicalPlan(this, operator) - .withGraphSchema(this.getGraphSchema()) - .withInputPathSchema(this.getOutputPathSchema()); - } - - public StepLogicalPlan startFrom(String label) { - int fieldIndex = getOutputPathSchema().indexOf(label); - if (fieldIndex != -1) { // start from exist label. - IType fieldType = getOutputPathSchema().getType(fieldIndex); - if (!(fieldType instanceof VertexType)) { - throw new IllegalArgumentException( - "Only can start traversal from vertex, current type is: " + fieldType); - } - return this.virtualEdgeMatch(new TraversalFromVertexFunction(fieldIndex, fieldType)) - .withGraphSchema(this.getGraphSchema()) - .withInputPathSchema(this.getOutputPathSchema()) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(EdgeType.emptyEdge(getGraphSchema().getIdType())) - ; - } else { // start from a new label. - return this.getHeadPlan(); - } - } - - public StepLogicalPlan filter(StepBoolFunction function) { - StepFilterOperator operator = new StepFilterOperator(nextPlanId(), function); - return new StepLogicalPlan(this, operator) - .withGraphSchema(this.getGraphSchema()) - .withInputPathSchema(this.getOutputPathSchema()) - .withOutputType(this.getOutputType()) - ; - } - - public StepLogicalPlan filterNode(StepNodeFilterFunction function) { - StepNodeFilterOperator operator = new StepNodeFilterOperator(nextPlanId(), function); - return new StepLogicalPlan(this, operator) - .withGraphSchema(this.getGraphSchema()) - .withInputPathSchema(this.getOutputPathSchema()) - .withOutputPathSchema(this.getOutputPathSchema()) - .withOutputType(this.getOutputType()); - } - - public StepLogicalPlan distinct(StepKeyFunction keyFunction) { - StepDistinctOperator localOperator = new StepDistinctOperator(nextPlanId(), keyFunction); - StepLogicalPlan localDistinct = new StepLogicalPlan(this, localOperator) - .withName("StepLocalDistinct-" + localOperator.getId()) - .withGraphSchema(this.getGraphSchema()) - .withInputPathSchema(this.getOutputPathSchema()) - .withOutputPathSchema(this.getOutputPathSchema()) - .withOutputType(this.getOutputType()); - - StepLogicalPlan exchange = localDistinct.exchange(keyFunction); - - StepDistinctOperator globalOperator = new StepDistinctOperator(nextPlanId(), keyFunction); - return new StepLogicalPlan(exchange, globalOperator) - .withName("StepGlobalDistinct-" + globalOperator.getId()) - .withGraphSchema(this.getGraphSchema()) - .withInputPathSchema(this.getOutputPathSchema()) - .withOutputPathSchema(this.getOutputPathSchema()) - .withOutputType(this.getOutputType()); - } - - public StepLogicalPlan loopUtil(StepLogicalPlan loopBody, StepBoolFunction utilCondition, - int minLoopCount, int maxLoopCount, - int loopStartPathFieldCount, int loopBodyPathFieldCount) { - if (!loopBody.getOutputs().isEmpty()) { - throw new IllegalArgumentException("loopBody should be the last node"); - } - StepLogicalPlan bodyStart = loopBody.getHeadPlan(); - if (bodyStart.getOperator() instanceof StepSourceOperator) { - throw new IllegalArgumentException("Loop body cannot be a StepSourceOperator"); - } - // append loop body the current plan. - bodyStart.setInputs(Collections.singletonList(this)); - StepLoopUntilOperator operator = new StepLoopUntilOperator( - nextPlanId(), - bodyStart.getId(), - loopBody.getId(), - utilCondition, - minLoopCount, - maxLoopCount, - loopStartPathFieldCount, - loopBodyPathFieldCount); - List inputs = new ArrayList<>(); - inputs.add(loopBody); - if (minLoopCount == 0) { - inputs.add(this); - } - return new StepLogicalPlan(inputs, operator) - .withGraphSchema(this.getGraphSchema()) - .withInputPathSchema(loopBody.getOutputPathSchema()); - } - - public StepLogicalPlan map(StepPathModifyFunction function, boolean isGlobal) { - StepMapOperator operator = new StepMapOperator(nextPlanId(), function, isGlobal); - return new StepLogicalPlan(this, operator) - .withGraphSchema(this.getGraphSchema()) - .withInputPathSchema(this.getOutputPathSchema()) - ; - } - - public StepLogicalPlan mapRow(StepMapRowFunction function) { - StepMapRowOperator operator = new StepMapRowOperator(nextPlanId(), function); - return new StepLogicalPlan(this, operator) - .withGraphSchema(this.getGraphSchema()) - .withInputPathSchema(this.getOutputPathSchema()) - ; - } - - public StepLogicalPlan union(List inputs) { - StepUnionOperator operator = new StepUnionOperator(nextPlanId()); - - List totalInputs = new ArrayList<>(); - totalInputs.add(this); - totalInputs.addAll(inputs); - List inputPathTypes = totalInputs.stream() - .map(StepLogicalPlan::getOutputPathSchema) - .collect(Collectors.toList()); - return new StepLogicalPlan(totalInputs, operator) - .withGraphSchema(getGraphSchema()) - .withInputPathSchema(inputPathTypes) - ; - } - - public StepLogicalPlan exchange(StepKeyFunction keyFunction) { - StepExchangeOperator exchange = new StepExchangeOperator(nextPlanId(), keyFunction); - return new StepLogicalPlan(this, exchange) - .withGraphSchema(getGraphSchema()) - .withInputPathSchema(getOutputPathSchema()) - .withOutputPathSchema(getOutputPathSchema()) - .withOutputType(getOutputType()) - ; - } - - public StepLogicalPlan localExchange(StepKeyFunction keyFunction) { - StepLocalExchangeOperator exchange = new StepLocalExchangeOperator(nextPlanId(), keyFunction); - return new StepLogicalPlan(this, exchange) - .withGraphSchema(getGraphSchema()) - .withInputPathSchema(getOutputPathSchema()) - .withOutputPathSchema(getOutputPathSchema()) - .withOutputType(getOutputType()) - ; - } - - public StepLogicalPlan join(StepLogicalPlan right, StepKeyFunction leftKey, - StepKeyFunction rightKey, StepJoinFunction joinFunction, - PathType inputJoinPathSchema, boolean isLocalJoin) { - StepLogicalPlan leftExchange = isLocalJoin - ? this.localExchange(leftKey) : this.exchange(leftKey); - StepLogicalPlan rightExchange = isLocalJoin - ? right.localExchange(rightKey) : right.exchange(rightKey); - - List joinInputPaths = Lists.newArrayList(leftExchange.getOutputPathSchema(), - rightExchange.getOutputPathSchema()); - - StepJoinOperator joinOperator = new StepJoinOperator(nextPlanId(), joinFunction, - inputJoinPathSchema, joinInputPaths, isLocalJoin); - return new StepLogicalPlan(Lists.newArrayList(leftExchange, rightExchange), joinOperator) - .withGraphSchema(getGraphSchema()) - .withInputPathSchema(joinInputPaths) - .withOutputType(VertexType.emptyVertex(getGraphSchema().getIdType())) - ; - } - - public StepLogicalPlan sort(StepSortFunction sortFunction) { - StepSortOperator localSortOperator = new StepSortOperator(nextPlanId(), sortFunction); - StepLogicalPlan localSortPlan = new StepLogicalPlan(this, localSortOperator) - .withGraphSchema(this.getGraphSchema()) - .withInputPathSchema(this.getOutputPathSchema()) - .withOutputPathSchema(this.getOutputPathSchema()) - .withOutputType(this.getOutputType()); - StepLogicalPlan exchangePlan = localSortPlan.exchange(new StepKeyFunctionImpl(new int[0], new IType[0])); - StepSortFunction globalSortFunction = ((StepSortFunctionImpl) sortFunction).copy(true); - StepGlobalSortOperator globalSortOperator = new StepGlobalSortOperator(nextPlanId(), - globalSortFunction, this.getOutputType(), this.getOutputPathSchema()); - - return new StepLogicalPlan(exchangePlan, globalSortOperator) - .withGraphSchema(this.getGraphSchema()) - .withInputPathSchema(this.getOutputPathSchema()) - .withOutputPathSchema(this.getOutputPathSchema()) - .withOutputType(this.getOutputType()); - } - - public StepLogicalPlan aggregate(StepAggregateFunction aggFunction) { - StepLocalSingleValueAggregateOperator localAggOp = new StepLocalSingleValueAggregateOperator(nextPlanId(), aggFunction); - IType localAggOutputType = ObjectType.INSTANCE; - StepLogicalPlan localAggPlan = new StepLogicalPlan(this, localAggOp) - .withGraphSchema(this.getGraphSchema()) - .withInputPathSchema(this.getOutputPathSchema()) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(StructType.singleValue(localAggOutputType, false)); - - StepLogicalPlan exchangePlan = localAggPlan.exchange(new StepKeyFunctionImpl(new int[0], new IType[0])); - StepGlobalSingleValueAggregateOperator globalAggOp = new StepGlobalSingleValueAggregateOperator(nextPlanId(), localAggOutputType, - aggFunction); - - return new StepLogicalPlan(exchangePlan, globalAggOp) - .withGraphSchema(this.getGraphSchema()) - .withInputPathSchema(exchangePlan.getOutputPathSchema()) - .withOutputPathSchema(PathType.EMPTY) - ; - } - - public StepLogicalPlan aggregate(PathType inputPath, PathType outputPath, - StepKeyFunction keyFunction, - StepAggregateFunction aggFn) { - StepLocalAggregateOperator localAggOp = new StepLocalAggregateOperator(nextPlanId(), - keyFunction, aggFn); - StepLogicalPlan localAggPlan = new StepLogicalPlan(this, localAggOp) - .withGraphSchema(this.getGraphSchema()) - .withInputPathSchema(inputPath) - .withOutputPathSchema(inputPath) - .withOutputType(getOutputType()); - StepLogicalPlan exchangePlan = localAggPlan.exchange(keyFunction); - StepGlobalAggregateOperator globalAggOp = new StepGlobalAggregateOperator(nextPlanId(), - keyFunction, aggFn); - - return new StepLogicalPlan(exchangePlan, globalAggOp) - .withGraphSchema(this.getGraphSchema()) - .withInputPathSchema(outputPath) - .withOutputPathSchema(outputPath) - .withOutputType(this.getOutputType()) - ; - } - - public StepLogicalPlan ret() { - StepReturnOperator returnOperator = new StepReturnOperator(nextPlanId()); - return new StepLogicalPlan(this, returnOperator) - .withGraphSchema(this.getGraphSchema()) - .withInputPathSchema(this.getOutputPathSchema()) - .withOutputPathSchema(this.getOutputPathSchema()) - .withOutputType(this.getOutputType()) - ; - } - - private void addOutput(StepLogicalPlan output) { - assert !outputs.contains(output) : "Output has already added"; - outputs.add(output); - } - - public StepLogicalPlan withFilteredFields(Set fields) { - if (operator instanceof FilteredFieldsOperator) { - ((FilteredFieldsOperator) operator).withFilteredFields(fields); - } - return this; - } - - public StepLogicalPlan withName(String name) { - operator.withName(name); - return this; - } - - public StepLogicalPlan withOutputPathSchema(PathType outputPath) { - operator.withOutputPathSchema(outputPath); - return this; - } - - public StepLogicalPlan withInputPathSchema(List inputPaths) { - operator.withInputPathSchema(inputPaths); - return this; - } - - public StepLogicalPlan withInputPathSchema(PathType inputPath) { - operator.withInputPathSchema(inputPath); - return this; - } - - public StepLogicalPlan withOutputType(IType outputType) { - operator.withOutputType(outputType); - return this; - } - - public StepLogicalPlan withGraphSchema(GraphSchema graphSchema) { - operator.withGraphSchema(graphSchema); - return this; - } - - public StepLogicalPlan withModifyGraphSchema(GraphSchema modifyGraphSchema) { - operator.withModifyGraphSchema(modifyGraphSchema); - return this; - } - - public PathType getOutputPathSchema() { - return operator.getOutputPathSchema(); - } - - public PathType getInputPathSchema() { - assert operator.getInputPathSchemas().size() == 1; - return operator.getInputPathSchemas().get(0); - } - - public List getInputPathSchemas() { - return operator.getInputPathSchemas(); - } - - public GraphSchema getGraphSchema() { - return operator.getGraphSchema(); - } - - public GraphSchema getModifyGraphSchema() { - return operator.getModifyGraphSchema(); - } - - public IType getOutputType() { - return operator.getOutputType(); - } - - public String getPlanDesc(boolean onlyContent) { - StringBuilder graphviz = new StringBuilder(); - if (!onlyContent) { - graphviz.append("digraph G {\n"); - } - generatePlanEdge(graphviz, this, new HashSet<>()); - generatePlanVertex(graphviz, this, new HashSet<>()); - if (!onlyContent) { - graphviz.append("}"); - } - return graphviz.toString(); - } - - public String getPlanDesc() { - return getPlanDesc(false); - } - - private void generatePlanEdge(StringBuilder graphviz, StepLogicalPlan plan, Set visited) { - if (visited.contains(plan.getId())) { - return; - } - visited.add(plan.getId()); - for (StepLogicalPlan input : plan.getInputs()) { - String edgeDesc = ""; - if (!input.isAllowChain()) { - edgeDesc = "chain = false"; - } - graphviz.append(String.format("%d -> %d [label= \"%s\"]\n", input.getId(), plan.getId(), edgeDesc)); - generatePlanEdge(graphviz, input, visited); - } - } - - private void generatePlanVertex(StringBuilder graphviz, StepLogicalPlan plan, Set visited) { - if (visited.contains(plan.getId())) { - return; - } - visited.add(plan.getId()); - - String vertexStr = plan.getOperator().toString(); - graphviz.append(String.format("%d [label= \"%s\"]\n", plan.getId(), vertexStr)); - for (StepLogicalPlan input : plan.getInputs()) { - generatePlanVertex(graphviz, input, visited); - } - } - - private static long nextPlanId() { - return idCounter.getAndIncrement(); - } - - public boolean isAllowChain() { - return allowChain; - } - - public void setAllowChain(boolean allowChain) { - this.allowChain = allowChain; - } - - public List getInputs() { - return inputs; - } - - public StepOperator getOperator() { - return operator; - } - - public List getOutputs() { - return outputs; - } - - public long getId() { - return operator.getId(); - } - - public List getFinalPlans() { - if (outputs.isEmpty()) { - return Collections.singletonList(this); - } - Set finalPlans = new LinkedHashSet<>(); - for (StepLogicalPlan output : outputs) { - finalPlans.addAll(output.getFinalPlans()); - } - return ImmutableList.copyOf(finalPlans); - } - - public StepLogicalPlan getHeadPlan() { - if (getInputs().isEmpty()) { - return this; - } - StepLogicalPlan headPlan = null; - for (StepLogicalPlan input : inputs) { - if (headPlan == null) { - headPlan = input.getHeadPlan(); - } else if (headPlan != input.getHeadPlan()) { - throw new IllegalArgumentException("Illegal plan with multi-head plan"); - } - } - return headPlan; - } - - public StepLogicalPlan copy() { - Map copyPlanCache = new LinkedHashMap<>(); - for (StepLogicalPlan finalPlan : getFinalPlans()) { - finalPlan.copy(copyPlanCache); - } - return copyPlanCache.values().iterator().next().getHeadPlan(); - } +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; - private StepLogicalPlan copy(Map copyPlanCache) { - List inputsCopy = inputs.stream() - .map(input -> input.copy(copyPlanCache)) - .collect(Collectors.toList()); - - if (copyPlanCache.containsKey(getId())) { - return copyPlanCache.get(getId()); - } - StepLogicalPlan copyPlan = new StepLogicalPlan(inputsCopy, operator.copy()); - copyPlanCache.put(getId(), copyPlan); - return copyPlan; - } +/** Logical plan for traversal step. */ +public class StepLogicalPlan implements Serializable { - public static void clearCounter() { - idCounter.set(0); + private static final AtomicLong idCounter = new AtomicLong(0L); + + private List inputs; + + private final StepOperator operator; + + private final List outputs = new ArrayList<>(); + + private boolean allowChain = false; + + public StepLogicalPlan(List inputs, StepOperator operator) { + this.operator = Objects.requireNonNull(operator); + setInputs(inputs); + } + + public StepLogicalPlan(StepLogicalPlan input, StepOperator operator) { + this(input == null ? Collections.emptyList() : Collections.singletonList(input), operator); + } + + public StepLogicalPlan(StepOperator operator) { + this(Collections.emptyList(), operator); + } + + @SuppressWarnings("unchecked") + private void setInputs(List inputs) { + this.inputs = Lists.newArrayList(Objects.requireNonNull(inputs)); + GraphSchema modifyGraphSchema = null; + for (StepLogicalPlan input : inputs) { + input.addOutput(this); + input.operator.addNextOperator((StepOperator) operator); + if (modifyGraphSchema == null) { + modifyGraphSchema = input.operator.getModifyGraphSchema(); + } else { + modifyGraphSchema = modifyGraphSchema.merge(input.operator.getModifyGraphSchema()); + } + } + // inherit modify graph schema from the input. + operator.withModifyGraphSchema(modifyGraphSchema); + } + + /** + * Create a start operator without start ids which means traversal all. + * + * @return The logical plan. + */ + public static StepLogicalPlan start() { + return start(new HashSet<>()); + } + + public static StepLogicalPlan start(Object... ids) { + StartId[] startIds = new StartId[ids.length]; + for (int i = 0; i < startIds.length; i++) { + startIds[i] = new ConstantStartId(ids[i]); + } + return start(Sets.newHashSet(startIds)); + } + + /** + * Create a start operator with start ids which will be the head operator in the DAG. + * + * @param startIds The start ids for traversal. Empty start ids means traversal all. + * @return The logical plan. + */ + public static StepLogicalPlan start(Set startIds) { + StepSourceOperator startOp = new StepSourceOperator(nextPlanId(), startIds); + return new StepLogicalPlan(startOp); + } + + public StepLogicalPlan end() { + StepEndOperator operator = new StepEndOperator(nextPlanId()); + return new StepLogicalPlan(this, operator) + .withGraphSchema(this.getGraphSchema()) + .withInputPathSchema(this.getOutputPathSchema()) + .withOutputPathSchema(this.getOutputPathSchema()) + .withOutputType(VoidType.INSTANCE); + } + + public static StepLogicalPlan subQueryStart(String queryName) { + StepSubQueryStartOperator operator = new StepSubQueryStartOperator(nextPlanId(), queryName); + return new StepLogicalPlan(Collections.emptyList(), operator); + } + + public StepLogicalPlan vertexMatch(MatchVertexFunction function) { + MatchVertexOperator operator = new MatchVertexOperator(nextPlanId(), function); + return new StepLogicalPlan(this, operator) + .withGraphSchema(this.getGraphSchema()) + .withInputPathSchema(this.getOutputPathSchema()); + } + + public StepLogicalPlan edgeMatch(MatchEdgeFunction function) { + MatchEdgeOperator operator = new MatchEdgeOperator(nextPlanId(), function); + return new StepLogicalPlan(this, operator) + .withGraphSchema(this.getGraphSchema()) + .withInputPathSchema(this.getOutputPathSchema()); + } + + public StepLogicalPlan virtualEdgeMatch(MatchVirtualEdgeFunction function) { + MatchVirtualEdgeOperator operator = new MatchVirtualEdgeOperator(nextPlanId(), function); + return new StepLogicalPlan(this, operator) + .withGraphSchema(this.getGraphSchema()) + .withInputPathSchema(this.getOutputPathSchema()); + } + + public StepLogicalPlan startFrom(String label) { + int fieldIndex = getOutputPathSchema().indexOf(label); + if (fieldIndex != -1) { // start from exist label. + IType fieldType = getOutputPathSchema().getType(fieldIndex); + if (!(fieldType instanceof VertexType)) { + throw new IllegalArgumentException( + "Only can start traversal from vertex, current type is: " + fieldType); + } + return this.virtualEdgeMatch(new TraversalFromVertexFunction(fieldIndex, fieldType)) + .withGraphSchema(this.getGraphSchema()) + .withInputPathSchema(this.getOutputPathSchema()) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(EdgeType.emptyEdge(getGraphSchema().getIdType())); + } else { // start from a new label. + return this.getHeadPlan(); + } + } + + public StepLogicalPlan filter(StepBoolFunction function) { + StepFilterOperator operator = new StepFilterOperator(nextPlanId(), function); + return new StepLogicalPlan(this, operator) + .withGraphSchema(this.getGraphSchema()) + .withInputPathSchema(this.getOutputPathSchema()) + .withOutputType(this.getOutputType()); + } + + public StepLogicalPlan filterNode(StepNodeFilterFunction function) { + StepNodeFilterOperator operator = new StepNodeFilterOperator(nextPlanId(), function); + return new StepLogicalPlan(this, operator) + .withGraphSchema(this.getGraphSchema()) + .withInputPathSchema(this.getOutputPathSchema()) + .withOutputPathSchema(this.getOutputPathSchema()) + .withOutputType(this.getOutputType()); + } + + public StepLogicalPlan distinct(StepKeyFunction keyFunction) { + StepDistinctOperator localOperator = new StepDistinctOperator(nextPlanId(), keyFunction); + StepLogicalPlan localDistinct = + new StepLogicalPlan(this, localOperator) + .withName("StepLocalDistinct-" + localOperator.getId()) + .withGraphSchema(this.getGraphSchema()) + .withInputPathSchema(this.getOutputPathSchema()) + .withOutputPathSchema(this.getOutputPathSchema()) + .withOutputType(this.getOutputType()); + + StepLogicalPlan exchange = localDistinct.exchange(keyFunction); + + StepDistinctOperator globalOperator = new StepDistinctOperator(nextPlanId(), keyFunction); + return new StepLogicalPlan(exchange, globalOperator) + .withName("StepGlobalDistinct-" + globalOperator.getId()) + .withGraphSchema(this.getGraphSchema()) + .withInputPathSchema(this.getOutputPathSchema()) + .withOutputPathSchema(this.getOutputPathSchema()) + .withOutputType(this.getOutputType()); + } + + public StepLogicalPlan loopUtil( + StepLogicalPlan loopBody, + StepBoolFunction utilCondition, + int minLoopCount, + int maxLoopCount, + int loopStartPathFieldCount, + int loopBodyPathFieldCount) { + if (!loopBody.getOutputs().isEmpty()) { + throw new IllegalArgumentException("loopBody should be the last node"); + } + StepLogicalPlan bodyStart = loopBody.getHeadPlan(); + if (bodyStart.getOperator() instanceof StepSourceOperator) { + throw new IllegalArgumentException("Loop body cannot be a StepSourceOperator"); + } + // append loop body the current plan. + bodyStart.setInputs(Collections.singletonList(this)); + StepLoopUntilOperator operator = + new StepLoopUntilOperator( + nextPlanId(), + bodyStart.getId(), + loopBody.getId(), + utilCondition, + minLoopCount, + maxLoopCount, + loopStartPathFieldCount, + loopBodyPathFieldCount); + List inputs = new ArrayList<>(); + inputs.add(loopBody); + if (minLoopCount == 0) { + inputs.add(this); + } + return new StepLogicalPlan(inputs, operator) + .withGraphSchema(this.getGraphSchema()) + .withInputPathSchema(loopBody.getOutputPathSchema()); + } + + public StepLogicalPlan map(StepPathModifyFunction function, boolean isGlobal) { + StepMapOperator operator = new StepMapOperator(nextPlanId(), function, isGlobal); + return new StepLogicalPlan(this, operator) + .withGraphSchema(this.getGraphSchema()) + .withInputPathSchema(this.getOutputPathSchema()); + } + + public StepLogicalPlan mapRow(StepMapRowFunction function) { + StepMapRowOperator operator = new StepMapRowOperator(nextPlanId(), function); + return new StepLogicalPlan(this, operator) + .withGraphSchema(this.getGraphSchema()) + .withInputPathSchema(this.getOutputPathSchema()); + } + + public StepLogicalPlan union(List inputs) { + StepUnionOperator operator = new StepUnionOperator(nextPlanId()); + + List totalInputs = new ArrayList<>(); + totalInputs.add(this); + totalInputs.addAll(inputs); + List inputPathTypes = + totalInputs.stream().map(StepLogicalPlan::getOutputPathSchema).collect(Collectors.toList()); + return new StepLogicalPlan(totalInputs, operator) + .withGraphSchema(getGraphSchema()) + .withInputPathSchema(inputPathTypes); + } + + public StepLogicalPlan exchange(StepKeyFunction keyFunction) { + StepExchangeOperator exchange = new StepExchangeOperator(nextPlanId(), keyFunction); + return new StepLogicalPlan(this, exchange) + .withGraphSchema(getGraphSchema()) + .withInputPathSchema(getOutputPathSchema()) + .withOutputPathSchema(getOutputPathSchema()) + .withOutputType(getOutputType()); + } + + public StepLogicalPlan localExchange(StepKeyFunction keyFunction) { + StepLocalExchangeOperator exchange = new StepLocalExchangeOperator(nextPlanId(), keyFunction); + return new StepLogicalPlan(this, exchange) + .withGraphSchema(getGraphSchema()) + .withInputPathSchema(getOutputPathSchema()) + .withOutputPathSchema(getOutputPathSchema()) + .withOutputType(getOutputType()); + } + + public StepLogicalPlan join( + StepLogicalPlan right, + StepKeyFunction leftKey, + StepKeyFunction rightKey, + StepJoinFunction joinFunction, + PathType inputJoinPathSchema, + boolean isLocalJoin) { + StepLogicalPlan leftExchange = + isLocalJoin ? this.localExchange(leftKey) : this.exchange(leftKey); + StepLogicalPlan rightExchange = + isLocalJoin ? right.localExchange(rightKey) : right.exchange(rightKey); + + List joinInputPaths = + Lists.newArrayList(leftExchange.getOutputPathSchema(), rightExchange.getOutputPathSchema()); + + StepJoinOperator joinOperator = + new StepJoinOperator( + nextPlanId(), joinFunction, inputJoinPathSchema, joinInputPaths, isLocalJoin); + return new StepLogicalPlan(Lists.newArrayList(leftExchange, rightExchange), joinOperator) + .withGraphSchema(getGraphSchema()) + .withInputPathSchema(joinInputPaths) + .withOutputType(VertexType.emptyVertex(getGraphSchema().getIdType())); + } + + public StepLogicalPlan sort(StepSortFunction sortFunction) { + StepSortOperator localSortOperator = new StepSortOperator(nextPlanId(), sortFunction); + StepLogicalPlan localSortPlan = + new StepLogicalPlan(this, localSortOperator) + .withGraphSchema(this.getGraphSchema()) + .withInputPathSchema(this.getOutputPathSchema()) + .withOutputPathSchema(this.getOutputPathSchema()) + .withOutputType(this.getOutputType()); + StepLogicalPlan exchangePlan = + localSortPlan.exchange(new StepKeyFunctionImpl(new int[0], new IType[0])); + StepSortFunction globalSortFunction = ((StepSortFunctionImpl) sortFunction).copy(true); + StepGlobalSortOperator globalSortOperator = + new StepGlobalSortOperator( + nextPlanId(), globalSortFunction, this.getOutputType(), this.getOutputPathSchema()); + + return new StepLogicalPlan(exchangePlan, globalSortOperator) + .withGraphSchema(this.getGraphSchema()) + .withInputPathSchema(this.getOutputPathSchema()) + .withOutputPathSchema(this.getOutputPathSchema()) + .withOutputType(this.getOutputType()); + } + + public StepLogicalPlan aggregate(StepAggregateFunction aggFunction) { + StepLocalSingleValueAggregateOperator localAggOp = + new StepLocalSingleValueAggregateOperator(nextPlanId(), aggFunction); + IType localAggOutputType = ObjectType.INSTANCE; + StepLogicalPlan localAggPlan = + new StepLogicalPlan(this, localAggOp) + .withGraphSchema(this.getGraphSchema()) + .withInputPathSchema(this.getOutputPathSchema()) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(StructType.singleValue(localAggOutputType, false)); + + StepLogicalPlan exchangePlan = + localAggPlan.exchange(new StepKeyFunctionImpl(new int[0], new IType[0])); + StepGlobalSingleValueAggregateOperator globalAggOp = + new StepGlobalSingleValueAggregateOperator(nextPlanId(), localAggOutputType, aggFunction); + + return new StepLogicalPlan(exchangePlan, globalAggOp) + .withGraphSchema(this.getGraphSchema()) + .withInputPathSchema(exchangePlan.getOutputPathSchema()) + .withOutputPathSchema(PathType.EMPTY); + } + + public StepLogicalPlan aggregate( + PathType inputPath, + PathType outputPath, + StepKeyFunction keyFunction, + StepAggregateFunction aggFn) { + StepLocalAggregateOperator localAggOp = + new StepLocalAggregateOperator(nextPlanId(), keyFunction, aggFn); + StepLogicalPlan localAggPlan = + new StepLogicalPlan(this, localAggOp) + .withGraphSchema(this.getGraphSchema()) + .withInputPathSchema(inputPath) + .withOutputPathSchema(inputPath) + .withOutputType(getOutputType()); + StepLogicalPlan exchangePlan = localAggPlan.exchange(keyFunction); + StepGlobalAggregateOperator globalAggOp = + new StepGlobalAggregateOperator(nextPlanId(), keyFunction, aggFn); + + return new StepLogicalPlan(exchangePlan, globalAggOp) + .withGraphSchema(this.getGraphSchema()) + .withInputPathSchema(outputPath) + .withOutputPathSchema(outputPath) + .withOutputType(this.getOutputType()); + } + + public StepLogicalPlan ret() { + StepReturnOperator returnOperator = new StepReturnOperator(nextPlanId()); + return new StepLogicalPlan(this, returnOperator) + .withGraphSchema(this.getGraphSchema()) + .withInputPathSchema(this.getOutputPathSchema()) + .withOutputPathSchema(this.getOutputPathSchema()) + .withOutputType(this.getOutputType()); + } + + private void addOutput(StepLogicalPlan output) { + assert !outputs.contains(output) : "Output has already added"; + outputs.add(output); + } + + public StepLogicalPlan withFilteredFields(Set fields) { + if (operator instanceof FilteredFieldsOperator) { + ((FilteredFieldsOperator) operator).withFilteredFields(fields); + } + return this; + } + + public StepLogicalPlan withName(String name) { + operator.withName(name); + return this; + } + + public StepLogicalPlan withOutputPathSchema(PathType outputPath) { + operator.withOutputPathSchema(outputPath); + return this; + } + + public StepLogicalPlan withInputPathSchema(List inputPaths) { + operator.withInputPathSchema(inputPaths); + return this; + } + + public StepLogicalPlan withInputPathSchema(PathType inputPath) { + operator.withInputPathSchema(inputPath); + return this; + } + + public StepLogicalPlan withOutputType(IType outputType) { + operator.withOutputType(outputType); + return this; + } + + public StepLogicalPlan withGraphSchema(GraphSchema graphSchema) { + operator.withGraphSchema(graphSchema); + return this; + } + + public StepLogicalPlan withModifyGraphSchema(GraphSchema modifyGraphSchema) { + operator.withModifyGraphSchema(modifyGraphSchema); + return this; + } + + public PathType getOutputPathSchema() { + return operator.getOutputPathSchema(); + } + + public PathType getInputPathSchema() { + assert operator.getInputPathSchemas().size() == 1; + return operator.getInputPathSchemas().get(0); + } + + public List getInputPathSchemas() { + return operator.getInputPathSchemas(); + } + + public GraphSchema getGraphSchema() { + return operator.getGraphSchema(); + } + + public GraphSchema getModifyGraphSchema() { + return operator.getModifyGraphSchema(); + } + + public IType getOutputType() { + return operator.getOutputType(); + } + + public String getPlanDesc(boolean onlyContent) { + StringBuilder graphviz = new StringBuilder(); + if (!onlyContent) { + graphviz.append("digraph G {\n"); + } + generatePlanEdge(graphviz, this, new HashSet<>()); + generatePlanVertex(graphviz, this, new HashSet<>()); + if (!onlyContent) { + graphviz.append("}"); + } + return graphviz.toString(); + } + + public String getPlanDesc() { + return getPlanDesc(false); + } + + private void generatePlanEdge(StringBuilder graphviz, StepLogicalPlan plan, Set visited) { + if (visited.contains(plan.getId())) { + return; + } + visited.add(plan.getId()); + for (StepLogicalPlan input : plan.getInputs()) { + String edgeDesc = ""; + if (!input.isAllowChain()) { + edgeDesc = "chain = false"; + } + graphviz.append( + String.format("%d -> %d [label= \"%s\"]\n", input.getId(), plan.getId(), edgeDesc)); + generatePlanEdge(graphviz, input, visited); + } + } + + private void generatePlanVertex(StringBuilder graphviz, StepLogicalPlan plan, Set visited) { + if (visited.contains(plan.getId())) { + return; + } + visited.add(plan.getId()); + + String vertexStr = plan.getOperator().toString(); + graphviz.append(String.format("%d [label= \"%s\"]\n", plan.getId(), vertexStr)); + for (StepLogicalPlan input : plan.getInputs()) { + generatePlanVertex(graphviz, input, visited); + } + } + + private static long nextPlanId() { + return idCounter.getAndIncrement(); + } + + public boolean isAllowChain() { + return allowChain; + } + + public void setAllowChain(boolean allowChain) { + this.allowChain = allowChain; + } + + public List getInputs() { + return inputs; + } + + public StepOperator getOperator() { + return operator; + } + + public List getOutputs() { + return outputs; + } + + public long getId() { + return operator.getId(); + } + + public List getFinalPlans() { + if (outputs.isEmpty()) { + return Collections.singletonList(this); + } + Set finalPlans = new LinkedHashSet<>(); + for (StepLogicalPlan output : outputs) { + finalPlans.addAll(output.getFinalPlans()); + } + return ImmutableList.copyOf(finalPlans); + } + + public StepLogicalPlan getHeadPlan() { + if (getInputs().isEmpty()) { + return this; + } + StepLogicalPlan headPlan = null; + for (StepLogicalPlan input : inputs) { + if (headPlan == null) { + headPlan = input.getHeadPlan(); + } else if (headPlan != input.getHeadPlan()) { + throw new IllegalArgumentException("Illegal plan with multi-head plan"); + } + } + return headPlan; + } + + public StepLogicalPlan copy() { + Map copyPlanCache = new LinkedHashMap<>(); + for (StepLogicalPlan finalPlan : getFinalPlans()) { + finalPlan.copy(copyPlanCache); } + return copyPlanCache.values().iterator().next().getHeadPlan(); + } + + private StepLogicalPlan copy(Map copyPlanCache) { + List inputsCopy = + inputs.stream().map(input -> input.copy(copyPlanCache)).collect(Collectors.toList()); + + if (copyPlanCache.containsKey(getId())) { + return copyPlanCache.get(getId()); + } + StepLogicalPlan copyPlan = new StepLogicalPlan(inputsCopy, operator.copy()); + copyPlanCache.put(getId(), copyPlan); + return copyPlan; + } + + public static void clearCounter() { + idCounter.set(0); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/StepLogicalPlanSet.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/StepLogicalPlanSet.java index cfae2dba9..1b564a47f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/StepLogicalPlanSet.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/StepLogicalPlanSet.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.dsl.common.types.GraphSchema; import org.apache.geaflow.dsl.runtime.traversal.operator.MatchVertexOperator; @@ -36,124 +37,124 @@ public class StepLogicalPlanSet { - private final GraphSchema graphSchema; + private final GraphSchema graphSchema; - private StepLogicalPlan mainLogicalPlan; + private StepLogicalPlan mainLogicalPlan; - private final Map subLogicalPlans = new HashMap<>(); + private final Map subLogicalPlans = new HashMap<>(); - public StepLogicalPlanSet(GraphSchema graphSchema) { - this.graphSchema = Objects.requireNonNull(graphSchema); - } + public StepLogicalPlanSet(GraphSchema graphSchema) { + this.graphSchema = Objects.requireNonNull(graphSchema); + } - public StepLogicalPlanSet(StepLogicalPlan mainLogicalPlan) { - this(mainLogicalPlan.getGraphSchema()); - this.mainLogicalPlan = mainLogicalPlan; - } + public StepLogicalPlanSet(StepLogicalPlan mainLogicalPlan) { + this(mainLogicalPlan.getGraphSchema()); + this.mainLogicalPlan = mainLogicalPlan; + } - public void addSubLogicalPlan(StepLogicalPlan subLogicalPlan) { - StepLogicalPlan headPlan = subLogicalPlan.getHeadPlan(); - assert headPlan.getOperator() instanceof StepSubQueryStartOperator; - StepSubQueryStartOperator startOperator = (StepSubQueryStartOperator) headPlan.getOperator(); + public void addSubLogicalPlan(StepLogicalPlan subLogicalPlan) { + StepLogicalPlan headPlan = subLogicalPlan.getHeadPlan(); + assert headPlan.getOperator() instanceof StepSubQueryStartOperator; + StepSubQueryStartOperator startOperator = (StepSubQueryStartOperator) headPlan.getOperator(); - subLogicalPlans.put(startOperator.getQueryName(), subLogicalPlan); - } + subLogicalPlans.put(startOperator.getQueryName(), subLogicalPlan); + } - public GraphSchema getGraphSchema() { - return graphSchema; - } + public GraphSchema getGraphSchema() { + return graphSchema; + } - public StepLogicalPlan getMainPlan() { - return mainLogicalPlan; - } + public StepLogicalPlan getMainPlan() { + return mainLogicalPlan; + } - public void setMainPlan(StepLogicalPlan mainPlan) { - this.mainLogicalPlan = Objects.requireNonNull(mainPlan); - } + public void setMainPlan(StepLogicalPlan mainPlan) { + this.mainLogicalPlan = Objects.requireNonNull(mainPlan); + } - public Map getSubPlans() { - return subLogicalPlans; - } + public Map getSubPlans() { + return subLogicalPlans; + } - public String getPlanSetDesc() { - StringBuilder graphviz = new StringBuilder(); - graphviz.append("digraph G {\n"); - String mainPlanDesc = mainLogicalPlan.getPlanDesc(true); - graphviz.append(mainPlanDesc).append("\n"); + public String getPlanSetDesc() { + StringBuilder graphviz = new StringBuilder(); + graphviz.append("digraph G {\n"); + String mainPlanDesc = mainLogicalPlan.getPlanDesc(true); + graphviz.append(mainPlanDesc).append("\n"); - for (StepLogicalPlan subPlan : subLogicalPlans.values()) { - graphviz.append(subPlan.getPlanDesc(true)).append("\n"); - } - String str = StringUtils.stripEnd(graphviz.toString(), "\n"); - return str + "\n}"; + for (StepLogicalPlan subPlan : subLogicalPlans.values()) { + graphviz.append(subPlan.getPlanDesc(true)).append("\n"); } - - public StepLogicalPlanSet markChainable() { - this.mainLogicalPlan = mainLogicalPlan.end(); - markChainable(mainLogicalPlan); - for (StepLogicalPlan subPlan : subLogicalPlans.values()) { - markChainable(subPlan); - } - return this; + String str = StringUtils.stripEnd(graphviz.toString(), "\n"); + return str + "\n}"; + } + + public StepLogicalPlanSet markChainable() { + this.mainLogicalPlan = mainLogicalPlan.end(); + markChainable(mainLogicalPlan); + for (StepLogicalPlan subPlan : subLogicalPlans.values()) { + markChainable(subPlan); } - - private static StepLogicalPlan markChainable(StepLogicalPlan endPlan) { - Map visitedPlanNumMV = new HashMap<>(); - markChainable(endPlan, visitedPlanNumMV); - return endPlan; + return this; + } + + private static StepLogicalPlan markChainable(StepLogicalPlan endPlan) { + Map visitedPlanNumMV = new HashMap<>(); + markChainable(endPlan, visitedPlanNumMV); + return endPlan; + } + + private static void markChainable(StepLogicalPlan plan, Map visitedPlanNumMV) { + if (visitedPlanNumMV.containsKey(plan.getId())) { + return; } - - private static void markChainable(StepLogicalPlan plan, Map visitedPlanNumMV) { - if (visitedPlanNumMV.containsKey(plan.getId())) { - return; - } - plan.getInputs().forEach(input -> markChainable(input, visitedPlanNumMV)); - List inputsNumMV = plan.getInputs() - .stream().map(input -> visitedPlanNumMV.get(input.getId())) + plan.getInputs().forEach(input -> markChainable(input, visitedPlanNumMV)); + List inputsNumMV = + plan.getInputs().stream() + .map(input -> visitedPlanNumMV.get(input.getId())) .collect(Collectors.toList()); - // init allow chain - plan.setAllowChain(true); - if (inputsNumMV.size() > 0) { - int maxNumMatchVertex = inputsNumMV.stream().max(Integer::compareTo).get(); - if (plan.getOperator() instanceof MatchVertexOperator) { - if (maxNumMatchVertex == 1) { //For VC - plan.getInputs().forEach(input -> input.setAllowChain(false)); - visitedPlanNumMV.put(plan.getId(), 1); - } else { - visitedPlanNumMV.put(plan.getId(), maxNumMatchVertex + 1); - } - } else if (plan.getOperator() instanceof StepEndOperator) { - if (inputsNumMV.size() > 1) { - // If input size of end operator is > 1 ,then it cannot chain with the inputs. - plan.getInputs().forEach(input -> input.setAllowChain(false)); - } - } else if (plan.getOperator() instanceof StepExchangeOperator) { - plan.setAllowChain(false); - visitedPlanNumMV.put(plan.getId(), 0); - } else if (plan.getOperator() instanceof StepGlobalSortOperator) { - // after global sort, we should send the vertex back, so it cannot chain with the follow op. - // This is an implicit vertex load, so the init mv number should be 1. - plan.setAllowChain(false); - visitedPlanNumMV.put(plan.getId(), 1); - } else if (plan.getOperator() instanceof StepGlobalAggregateOperator) { - // after global aggregate, we should send the vertex back, so it cannot chain with - // the follow op. This is an implicit vertex load, so the init mv number should - // be 1. - plan.setAllowChain(false); - visitedPlanNumMV.put(plan.getId(), 1); - } else if (plan.getOperator() instanceof MatchVirtualEdgeOperator) { - plan.setAllowChain(false); - visitedPlanNumMV.put(plan.getId(), 1); - } else if (plan.getOperator() instanceof StepSubQueryStartOperator) { - plan.getInputs().forEach(input -> input.setAllowChain(false)); - visitedPlanNumMV.put(plan.getId(), maxNumMatchVertex); - } else { - visitedPlanNumMV.put(plan.getId(), maxNumMatchVertex); - } + // init allow chain + plan.setAllowChain(true); + if (inputsNumMV.size() > 0) { + int maxNumMatchVertex = inputsNumMV.stream().max(Integer::compareTo).get(); + if (plan.getOperator() instanceof MatchVertexOperator) { + if (maxNumMatchVertex == 1) { // For VC + plan.getInputs().forEach(input -> input.setAllowChain(false)); + visitedPlanNumMV.put(plan.getId(), 1); } else { - visitedPlanNumMV.put(plan.getId(), 0); + visitedPlanNumMV.put(plan.getId(), maxNumMatchVertex + 1); } + } else if (plan.getOperator() instanceof StepEndOperator) { + if (inputsNumMV.size() > 1) { + // If input size of end operator is > 1 ,then it cannot chain with the inputs. + plan.getInputs().forEach(input -> input.setAllowChain(false)); + } + } else if (plan.getOperator() instanceof StepExchangeOperator) { + plan.setAllowChain(false); + visitedPlanNumMV.put(plan.getId(), 0); + } else if (plan.getOperator() instanceof StepGlobalSortOperator) { + // after global sort, we should send the vertex back, so it cannot chain with the follow op. + // This is an implicit vertex load, so the init mv number should be 1. + plan.setAllowChain(false); + visitedPlanNumMV.put(plan.getId(), 1); + } else if (plan.getOperator() instanceof StepGlobalAggregateOperator) { + // after global aggregate, we should send the vertex back, so it cannot chain with + // the follow op. This is an implicit vertex load, so the init mv number should + // be 1. + plan.setAllowChain(false); + visitedPlanNumMV.put(plan.getId(), 1); + } else if (plan.getOperator() instanceof MatchVirtualEdgeOperator) { + plan.setAllowChain(false); + visitedPlanNumMV.put(plan.getId(), 1); + } else if (plan.getOperator() instanceof StepSubQueryStartOperator) { + plan.getInputs().forEach(input -> input.setAllowChain(false)); + visitedPlanNumMV.put(plan.getId(), maxNumMatchVertex); + } else { + visitedPlanNumMV.put(plan.getId(), maxNumMatchVertex); + } + } else { + visitedPlanNumMV.put(plan.getId(), 0); } - + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/StepLogicalPlanTranslator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/StepLogicalPlanTranslator.java index ea74d2d2b..88e33c5bb 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/StepLogicalPlanTranslator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/StepLogicalPlanTranslator.java @@ -32,6 +32,7 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; + import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.JoinInfo; import org.apache.calcite.rel.type.RelDataType; @@ -119,543 +120,621 @@ public class StepLogicalPlanTranslator { - /** - * Translate path pattern to {@link StepLogicalPlan}. - * - * @param graphMatch The path pattern to translate. - * @return The last node of the {@link StepLogicalPlan}. - */ - public StepLogicalPlan translate(GraphMatch graphMatch, - StepLogicalPlanSet logicalPlanSet) { - // do the plan translate. - LogicalPlanTranslatorVisitor translator = - new LogicalPlanTranslatorVisitor(logicalPlanSet); - return translator.translate(graphMatch.getPathPattern()); - } - - /** - * Translate the {@link RelNode} in graph match to {@link StepLogicalPlan}. - **/ - private static class LogicalPlanTranslatorVisitor extends AbstractMatchNodeVisitor { + /** + * Translate path pattern to {@link StepLogicalPlan}. + * + * @param graphMatch The path pattern to translate. + * @return The last node of the {@link StepLogicalPlan}. + */ + public StepLogicalPlan translate(GraphMatch graphMatch, StepLogicalPlanSet logicalPlanSet) { + // do the plan translate. + LogicalPlanTranslatorVisitor translator = new LogicalPlanTranslatorVisitor(logicalPlanSet); + return translator.translate(graphMatch.getPathPattern()); + } - private final GraphSchema graphSchema; + /** Translate the {@link RelNode} in graph match to {@link StepLogicalPlan}. */ + private static class LogicalPlanTranslatorVisitor + extends AbstractMatchNodeVisitor { - private final StepLogicalPlanSet logicalPlanSet; + private final GraphSchema graphSchema; - private final GraphSchema modifyGraphSchema; + private final StepLogicalPlanSet logicalPlanSet; - // label -> plan - private Map planCache = new HashMap<>(); + private final GraphSchema modifyGraphSchema; - private StepLogicalPlan logicalPlanHead = null; + // label -> plan + private Map planCache = new HashMap<>(); - private final Map nodePushDownFilters; - - public LogicalPlanTranslatorVisitor(StepLogicalPlanSet logicalPlanSet) { - this(logicalPlanSet, new HashMap<>()); - } + private StepLogicalPlan logicalPlanHead = null; - private LogicalPlanTranslatorVisitor(StepLogicalPlanSet logicalPlanSet, - Map nodePushDownFilters) { - this.graphSchema = logicalPlanSet.getGraphSchema(); - this.logicalPlanSet = Objects.requireNonNull(logicalPlanSet); - this.modifyGraphSchema = graphSchema; - this.nodePushDownFilters = Objects.requireNonNull(nodePushDownFilters); - } + private final Map nodePushDownFilters; - public StepLogicalPlan translate(RelNode pathPattern) { - return this.visit(pathPattern); - } + public LogicalPlanTranslatorVisitor(StepLogicalPlanSet logicalPlanSet) { + this(logicalPlanSet, new HashMap<>()); + } - @Override - public StepLogicalPlan visitVertexMatch(VertexMatch vertexMatch) { - String label = vertexMatch.getLabel(); - RexNode filter = nodePushDownFilters.get(vertexMatch); - // TODO use optimizer rule to push the filter to the vertex-match. - if (vertexMatch.getPushDownFilter() != null) { - filter = vertexMatch.getPushDownFilter(); - } - Set startIds = new HashSet<>(); - if (vertexMatch.getInput() == null && filter != null) { - Set ids = GQLRexUtil.findVertexIds(filter, (VertexRecordType) vertexMatch.getNodeType()); - startIds = toStartIds(ids); - } else if (!vertexMatch.getIdSet().isEmpty()) { - startIds = vertexMatch.getIdSet().stream().map(id -> new ConstantStartId(id)).collect( - Collectors.toSet()); - } - Set nodeTypes = vertexMatch.getTypes().stream() - .map(s -> (BinaryString) BinaryUtil.toBinaryForString(s)) - .collect(Collectors.toSet()); + private LogicalPlanTranslatorVisitor( + StepLogicalPlanSet logicalPlanSet, Map nodePushDownFilters) { + this.graphSchema = logicalPlanSet.getGraphSchema(); + this.logicalPlanSet = Objects.requireNonNull(logicalPlanSet); + this.modifyGraphSchema = graphSchema; + this.nodePushDownFilters = Objects.requireNonNull(nodePushDownFilters); + } - // If this head label node has generated in other branch, just reuse it and push down the startIds. - if (vertexMatch.getInput() == null && planCache.containsKey(label)) { - StepLogicalPlan plan = planCache.get(label); - // push start ids to StepSourceOperator - assert plan.getInputs().size() == 1; - if (plan.getInputs().get(0).getOperator() instanceof StepSourceOperator) { - StepSourceOperator sourceOp = (StepSourceOperator) plan.getInputs().get(0).getOperator(); - sourceOp.joinStartId(startIds); - } - if (vertexMatch.getTypes().size() > 0) { - return plan.filterNode(new StepNodeTypeFilterFunction(nodeTypes)); - } - return plan; - } - IType nodeType = SqlTypeUtil.convertType(vertexMatch.getNodeType()); - // generate input plan. - StepLogicalPlan input; - if (vertexMatch.getInput() != null) { - input = this.visit(vertexMatch.getInput()); - } else { - if (logicalPlanHead == null) { // create start plan for the first time - input = StepLogicalPlan.start(startIds) - .withGraphSchema(graphSchema) - .withModifyGraphSchema(modifyGraphSchema) - .withInputPathSchema(PathType.EMPTY) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(nodeType); - logicalPlanHead = input; - } else { // start from the exists start plan. - StepLogicalPlan startPlan = logicalPlanHead; - assert startPlan.getOperator() instanceof StepSourceOperator : - "Start plan should be StepSourceOperator"; - // push startIds of this branch to the StepSourceOperator. - ((StepSourceOperator) startPlan.getOperator()).unionStartId(startIds); - input = startPlan; - } - } - PathType outputPath = (PathType) SqlTypeUtil.convertType(vertexMatch.getPathSchema()); - boolean isOptionalMatch = vertexMatch instanceof OptionalVertexMatch - && SqlTypeUtil.convertType(vertexMatch.getNodeType()) != null; - MatchVertexFunction mvf = new MatchVertexFunctionImpl(nodeTypes, isOptionalMatch, - label, vertexMatch.getIdSet()); - StepLogicalPlan plan = input.vertexMatch(mvf) - .withModifyGraphSchema(input.getModifyGraphSchema()) - .withOutputPathSchema(outputPath) - .withOutputType(nodeType) - .withFilteredFields(vertexMatch.getFields()); - planCache.put(label, plan); - return plan; - } + public StepLogicalPlan translate(RelNode pathPattern) { + return this.visit(pathPattern); + } - @Override - public StepLogicalPlan visitEdgeMatch(EdgeMatch edgeMatch) { - String label = edgeMatch.getLabel(); - if (planCache.containsKey(label)) { - return planCache.get(label); - } - if (edgeMatch.getInput() == null) { - throw new GeaFlowDSLException("Graph match should start from a vertex"); - } - StepLogicalPlan input = this.visit(edgeMatch.getInput()); - - IType nodeType = SqlTypeUtil.convertType(edgeMatch.getNodeType()); - PathType outputPath = (PathType) SqlTypeUtil.convertType(edgeMatch.getPathSchema()); - - IFilter[] pushDownFilter = null; - RexNode filter = nodePushDownFilters.get(edgeMatch); - if (filter != null) { - // push down edge timestamp condition - IFilter tsRangeFilter = null; - List tsRanges = FilterPushDownUtil.findTsRange(filter, - (EdgeRecordType) edgeMatch.getNodeType()).stream().collect(Collectors.toList()); - if (!tsRanges.isEmpty()) { - for (TimeRange timeRange : tsRanges) { - if (tsRangeFilter != null) { - tsRangeFilter = tsRangeFilter.or(new EdgeTsFilter(timeRange)); - } else { - tsRangeFilter = new EdgeTsFilter(timeRange); - } - } - } - if (tsRangeFilter != null) { - pushDownFilter = new IFilter[]{tsRangeFilter}; - } - } - Set edgeTypes = edgeMatch.getTypes().stream() - .map(s -> (BinaryString) BinaryUtil.toBinaryForString(s)) + @Override + public StepLogicalPlan visitVertexMatch(VertexMatch vertexMatch) { + String label = vertexMatch.getLabel(); + RexNode filter = nodePushDownFilters.get(vertexMatch); + // TODO use optimizer rule to push the filter to the vertex-match. + if (vertexMatch.getPushDownFilter() != null) { + filter = vertexMatch.getPushDownFilter(); + } + Set startIds = new HashSet<>(); + if (vertexMatch.getInput() == null && filter != null) { + Set ids = + GQLRexUtil.findVertexIds(filter, (VertexRecordType) vertexMatch.getNodeType()); + startIds = toStartIds(ids); + } else if (!vertexMatch.getIdSet().isEmpty()) { + startIds = + vertexMatch.getIdSet().stream() + .map(id -> new ConstantStartId(id)) .collect(Collectors.toSet()); - boolean isOptionalMatch = edgeMatch instanceof OptionalEdgeMatch - && SqlTypeUtil.convertType(edgeMatch.getNodeType()) != null; - MatchEdgeFunction mef = - pushDownFilter == null ? new MatchEdgeFunctionImpl(edgeMatch.getDirection(), - edgeTypes, isOptionalMatch, label) : - new MatchEdgeFunctionImpl(edgeMatch.getDirection(), edgeTypes, isOptionalMatch, - label, pushDownFilter); - - StepLogicalPlan plan = input.edgeMatch(mef) - .withModifyGraphSchema(input.getModifyGraphSchema()) - .withOutputPathSchema(outputPath) - .withOutputType(nodeType) - .withFilteredFields(edgeMatch.getFields());; - planCache.put(label, plan); - return plan; + } + Set nodeTypes = + vertexMatch.getTypes().stream() + .map(s -> (BinaryString) BinaryUtil.toBinaryForString(s)) + .collect(Collectors.toSet()); + + // If this head label node has generated in other branch, just reuse it and push down the + // startIds. + if (vertexMatch.getInput() == null && planCache.containsKey(label)) { + StepLogicalPlan plan = planCache.get(label); + // push start ids to StepSourceOperator + assert plan.getInputs().size() == 1; + if (plan.getInputs().get(0).getOperator() instanceof StepSourceOperator) { + StepSourceOperator sourceOp = (StepSourceOperator) plan.getInputs().get(0).getOperator(); + sourceOp.joinStartId(startIds); } - - @Override - public StepLogicalPlan visitVirtualEdgeMatch(VirtualEdgeMatch virtualEdgeMatch) { - StepLogicalPlan input = this.visit(virtualEdgeMatch.getInput()); - PathRecordType inputPath = ((IMatchNode) virtualEdgeMatch.getInput()).getPathSchema(); - Expression targetId = ExpressionTranslator.of(inputPath, logicalPlanSet) - .translate(virtualEdgeMatch.getTargetId()); - PathType outputPath = (PathType) SqlTypeUtil.convertType(virtualEdgeMatch.getPathSchema()); - MatchVirtualEdgeFunction virtualEdgeFunction = new MatchVirtualEdgeFunctionImpl(targetId); - return input.virtualEdgeMatch(virtualEdgeFunction) - .withModifyGraphSchema(input.getModifyGraphSchema()) - .withOutputPathSchema(outputPath) - .withOutputType(SqlTypeUtil.convertType(virtualEdgeMatch.getNodeType())); + if (vertexMatch.getTypes().size() > 0) { + return plan.filterNode(new StepNodeTypeFilterFunction(nodeTypes)); } - - @Override - public StepLogicalPlan visitFilter(MatchFilter filter) { - // push down filter condition - nodePushDownFilters.put(filter.getInput(), filter.getCondition()); - StepLogicalPlan input = this.visit(filter.getInput()); - PathType outputPath = (PathType) SqlTypeUtil.convertType(filter.getPathSchema()); - PathRecordType inputPath = ((IMatchNode) filter.getInput()).getPathSchema(); - - Expression condition = - ExpressionTranslator.of(inputPath, logicalPlanSet).translate(filter.getCondition()); - StepBoolFunction fn = new StepBoolFunctionImpl(condition); - return input.filter(fn).withModifyGraphSchema(input.getModifyGraphSchema()) - .withOutputPathSchema(outputPath); + return plan; + } + IType nodeType = SqlTypeUtil.convertType(vertexMatch.getNodeType()); + // generate input plan. + StepLogicalPlan input; + if (vertexMatch.getInput() != null) { + input = this.visit(vertexMatch.getInput()); + } else { + if (logicalPlanHead == null) { // create start plan for the first time + input = + StepLogicalPlan.start(startIds) + .withGraphSchema(graphSchema) + .withModifyGraphSchema(modifyGraphSchema) + .withInputPathSchema(PathType.EMPTY) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(nodeType); + logicalPlanHead = input; + } else { // start from the exists start plan. + StepLogicalPlan startPlan = logicalPlanHead; + assert startPlan.getOperator() instanceof StepSourceOperator + : "Start plan should be StepSourceOperator"; + // push startIds of this branch to the StepSourceOperator. + ((StepSourceOperator) startPlan.getOperator()).unionStartId(startIds); + input = startPlan; } + } + PathType outputPath = (PathType) SqlTypeUtil.convertType(vertexMatch.getPathSchema()); + boolean isOptionalMatch = + vertexMatch instanceof OptionalVertexMatch + && SqlTypeUtil.convertType(vertexMatch.getNodeType()) != null; + MatchVertexFunction mvf = + new MatchVertexFunctionImpl(nodeTypes, isOptionalMatch, label, vertexMatch.getIdSet()); + StepLogicalPlan plan = + input + .vertexMatch(mvf) + .withModifyGraphSchema(input.getModifyGraphSchema()) + .withOutputPathSchema(outputPath) + .withOutputType(nodeType) + .withFilteredFields(vertexMatch.getFields()); + planCache.put(label, plan); + return plan; + } - @Override - public StepLogicalPlan visitJoin(MatchJoin join) { - JoinInfo joinInfo = join.analyzeCondition(); - PathRecordType leftPathType = ((IMatchNode) join.getLeft()).getPathSchema(); - PathRecordType rightPathType = ((IMatchNode) join.getRight()).getPathSchema(); - - IType[] leftKeyTypes = joinInfo.leftKeys.stream() - .map(index -> - SqlTypeUtil.convertType(leftPathType - .getFieldList().get(index).getType())) - .collect(Collectors.toList()) - .toArray(new IType[]{}); - IType[] rightKeyTypes = joinInfo.rightKeys.stream() - .map(index -> - SqlTypeUtil.convertType(rightPathType - .getFieldList().get(index).getType())) - .collect(Collectors.toList()) - .toArray(new IType[]{}); - StepKeyFunction leftKeyFn = new StepKeyFunctionImpl(toIntArray(joinInfo.leftKeys), leftKeyTypes); - StepKeyFunction rightKeyFn = new StepKeyFunctionImpl(toIntArray(joinInfo.rightKeys), rightKeyTypes); - - StepLogicalPlan leftPlan = visit(join.getLeft()); - StepLogicalPlan rightPlan = visit(join.getRight()); - IType[] leftPathTypes = leftPlan.getOutputPathSchema().getTypes(); - IType[] rightPathTypes = rightPlan.getOutputPathSchema().getTypes(); - - Expression joinConditionExp = - ExpressionTranslator.of(join.getPathSchema()).translate(join.getCondition()); - StepJoinFunction joinFunction = new StepJoinFunctionImpl(join.getJoinType(), - leftPathTypes, rightPathTypes, joinConditionExp); - - PathType inputJoinPath = (PathType) SqlTypeUtil.convertType(leftPathType.join(rightPathType, - join.getCluster().getTypeFactory())); - PathType joinOutputPath = (PathType) SqlTypeUtil.convertType(join.getPathSchema()); - - List leftChainableVertex = - StepLogicalPlanTranslator.getChainableVertexMatch(leftPlan); - List rightChainableVertex = - StepLogicalPlanTranslator.getChainableVertexMatch(rightPlan); - boolean isLocalJoin = false; - if (leftChainableVertex.size() == 1 - && rightChainableVertex.size() == 1 - && joinInfo.leftKeys.size() == 1 && joinInfo.rightKeys.size() == 1) { - String leftVertexLabel = ((MatchVertexOperator) leftChainableVertex.get(0).getOperator()).getLabel(); - String rightVertexLabel = ((MatchVertexOperator) rightChainableVertex.get(0).getOperator()).getLabel(); - if (leftPathType.getFieldList().get(joinInfo.leftKeys.get(0)).getName().equals(leftVertexLabel) - && rightPathType.getFieldList().get(joinInfo.rightKeys.get(0)).getName().equals(rightVertexLabel)) { - isLocalJoin = true; - } + @Override + public StepLogicalPlan visitEdgeMatch(EdgeMatch edgeMatch) { + String label = edgeMatch.getLabel(); + if (planCache.containsKey(label)) { + return planCache.get(label); + } + if (edgeMatch.getInput() == null) { + throw new GeaFlowDSLException("Graph match should start from a vertex"); + } + StepLogicalPlan input = this.visit(edgeMatch.getInput()); + + IType nodeType = SqlTypeUtil.convertType(edgeMatch.getNodeType()); + PathType outputPath = (PathType) SqlTypeUtil.convertType(edgeMatch.getPathSchema()); + + IFilter[] pushDownFilter = null; + RexNode filter = nodePushDownFilters.get(edgeMatch); + if (filter != null) { + // push down edge timestamp condition + IFilter tsRangeFilter = null; + List tsRanges = + FilterPushDownUtil.findTsRange(filter, (EdgeRecordType) edgeMatch.getNodeType()) + .stream() + .collect(Collectors.toList()); + if (!tsRanges.isEmpty()) { + for (TimeRange timeRange : tsRanges) { + if (tsRangeFilter != null) { + tsRangeFilter = tsRangeFilter.or(new EdgeTsFilter(timeRange)); + } else { + tsRangeFilter = new EdgeTsFilter(timeRange); } - return leftPlan - .join(rightPlan, leftKeyFn, rightKeyFn, joinFunction, inputJoinPath, isLocalJoin) - .withOutputPathSchema(joinOutputPath); + } } - - @Override - public StepLogicalPlan visitDistinct(MatchDistinct distinct) { - RelNode input = distinct.getInput(0); - IType[] types = ((IMatchNode) input).getPathSchema().getFieldList().stream() - .map(field -> SqlTypeUtil.convertType(field.getType())) - .collect(Collectors.toList()).toArray(new IType[]{}); - int[] keyIndices = new int[types.length]; - for (int i = 0, size = types.length; i < size; i++) { - keyIndices[i] = i; - } - StepKeyFunction keyFunction = new StepKeyFunctionImpl(keyIndices, types); - PathType distinctPathType = (PathType) SqlTypeUtil.convertType(distinct.getPathSchema()); - IType nodeType = SqlTypeUtil.convertType(distinct.getNodeType()); - return visit(input).distinct(keyFunction) - .withOutputPathSchema(distinctPathType) - .withOutputType(nodeType); + if (tsRangeFilter != null) { + pushDownFilter = new IFilter[] {tsRangeFilter}; } + } + Set edgeTypes = + edgeMatch.getTypes().stream() + .map(s -> (BinaryString) BinaryUtil.toBinaryForString(s)) + .collect(Collectors.toSet()); + boolean isOptionalMatch = + edgeMatch instanceof OptionalEdgeMatch + && SqlTypeUtil.convertType(edgeMatch.getNodeType()) != null; + MatchEdgeFunction mef = + pushDownFilter == null + ? new MatchEdgeFunctionImpl( + edgeMatch.getDirection(), edgeTypes, isOptionalMatch, label) + : new MatchEdgeFunctionImpl( + edgeMatch.getDirection(), edgeTypes, isOptionalMatch, label, pushDownFilter); + + StepLogicalPlan plan = + input + .edgeMatch(mef) + .withModifyGraphSchema(input.getModifyGraphSchema()) + .withOutputPathSchema(outputPath) + .withOutputType(nodeType) + .withFilteredFields(edgeMatch.getFields()); + ; + planCache.put(label, plan); + return plan; + } - @Override - public StepLogicalPlan visitUnion(MatchUnion union) { - List inputPlans = new ArrayList<>(); - - for (int i = 0, size = union.getInputs().size(); i < size; i++) { - // The input of union should not referer the plan cache generated by each other. - // So we create a new plan cache for each input. - Map prePlanCache = planCache; - planCache = new HashMap<>(planCache); - inputPlans.add(visit(union.getInput(i))); - // recover pre-plan cache. - planCache = prePlanCache; - } - - StepLogicalPlan firstPlan = inputPlans.get(0); - PathType unionPathType = (PathType) SqlTypeUtil.convertType(union.getPathSchema()); - IType nodeType = SqlTypeUtil.convertType(union.getNodeType()); + @Override + public StepLogicalPlan visitVirtualEdgeMatch(VirtualEdgeMatch virtualEdgeMatch) { + StepLogicalPlan input = this.visit(virtualEdgeMatch.getInput()); + PathRecordType inputPath = ((IMatchNode) virtualEdgeMatch.getInput()).getPathSchema(); + Expression targetId = + ExpressionTranslator.of(inputPath, logicalPlanSet) + .translate(virtualEdgeMatch.getTargetId()); + PathType outputPath = (PathType) SqlTypeUtil.convertType(virtualEdgeMatch.getPathSchema()); + MatchVirtualEdgeFunction virtualEdgeFunction = new MatchVirtualEdgeFunctionImpl(targetId); + return input + .virtualEdgeMatch(virtualEdgeFunction) + .withModifyGraphSchema(input.getModifyGraphSchema()) + .withOutputPathSchema(outputPath) + .withOutputType(SqlTypeUtil.convertType(virtualEdgeMatch.getNodeType())); + } - StepLogicalPlan unionPlan = firstPlan.union(inputPlans.subList(1, inputPlans.size())) - .withModifyGraphSchema(firstPlan.getModifyGraphSchema()) - .withOutputPathSchema(unionPathType) - .withOutputType(nodeType); - if (union.all) { - return unionPlan; - } else { - IType[] types = unionPlan.getOutputPathSchema().getFields().stream() - .map(TableField::getType) - .collect(Collectors.toList()).toArray(new IType[]{}); - int[] keyIndices = new int[types.length]; - for (int i = 0, size = types.length; i < size; i++) { - keyIndices[i] = i; - } - StepKeyFunction keyFunction = new StepKeyFunctionImpl(keyIndices, types); - return unionPlan.distinct(keyFunction) - .withModifyGraphSchema(unionPlan.getModifyGraphSchema()) - .withOutputPathSchema(unionPlan.getOutputPathSchema()) - .withOutputType(unionPlan.getOutputType()); - } - } + @Override + public StepLogicalPlan visitFilter(MatchFilter filter) { + // push down filter condition + nodePushDownFilters.put(filter.getInput(), filter.getCondition()); + StepLogicalPlan input = this.visit(filter.getInput()); + PathType outputPath = (PathType) SqlTypeUtil.convertType(filter.getPathSchema()); + PathRecordType inputPath = ((IMatchNode) filter.getInput()).getPathSchema(); + + Expression condition = + ExpressionTranslator.of(inputPath, logicalPlanSet).translate(filter.getCondition()); + StepBoolFunction fn = new StepBoolFunctionImpl(condition); + return input + .filter(fn) + .withModifyGraphSchema(input.getModifyGraphSchema()) + .withOutputPathSchema(outputPath); + } - @Override - public StepLogicalPlan visitLoopMatch(LoopUntilMatch loopMatch) { - StepLogicalPlan loopStart = visit(loopMatch.getInput()); - StepLogicalPlan loopBody = visit(loopMatch.getLoopBody()); - for (StepLogicalPlan plan : loopBody.getFinalPlans()) { - plan.withModifyGraphSchema(loopStart.getModifyGraphSchema()); - } - ExpressionTranslator translator = ExpressionTranslator.of(loopMatch.getLoopBody().getPathSchema()); - Expression utilCondition = translator.translate(loopMatch.getUtilCondition()); - - PathType outputPath = (PathType) SqlTypeUtil.convertType(loopMatch.getPathSchema()); - IType nodeType = SqlTypeUtil.convertType(loopMatch.getNodeType()); - int loopStartPathFieldCount = loopStart.getOutputPathSchema().size(); - int loopBodyPathFieldCount = loopBody.getOutputPathSchema().size() - loopStartPathFieldCount; - return loopStart.loopUtil(loopBody, new StepBoolFunctionImpl(utilCondition), - loopMatch.getMinLoopCount(), loopMatch.getMaxLoopCount(), - loopStartPathFieldCount, loopBodyPathFieldCount) - .withModifyGraphSchema(loopStart.getModifyGraphSchema()) - .withOutputPathSchema(outputPath) - .withOutputType(nodeType) - ; + @Override + public StepLogicalPlan visitJoin(MatchJoin join) { + JoinInfo joinInfo = join.analyzeCondition(); + PathRecordType leftPathType = ((IMatchNode) join.getLeft()).getPathSchema(); + PathRecordType rightPathType = ((IMatchNode) join.getRight()).getPathSchema(); + + IType[] leftKeyTypes = + joinInfo.leftKeys.stream() + .map( + index -> + SqlTypeUtil.convertType(leftPathType.getFieldList().get(index).getType())) + .collect(Collectors.toList()) + .toArray(new IType[] {}); + IType[] rightKeyTypes = + joinInfo.rightKeys.stream() + .map( + index -> + SqlTypeUtil.convertType(rightPathType.getFieldList().get(index).getType())) + .collect(Collectors.toList()) + .toArray(new IType[] {}); + StepKeyFunction leftKeyFn = + new StepKeyFunctionImpl(toIntArray(joinInfo.leftKeys), leftKeyTypes); + StepKeyFunction rightKeyFn = + new StepKeyFunctionImpl(toIntArray(joinInfo.rightKeys), rightKeyTypes); + + StepLogicalPlan leftPlan = visit(join.getLeft()); + StepLogicalPlan rightPlan = visit(join.getRight()); + IType[] leftPathTypes = leftPlan.getOutputPathSchema().getTypes(); + IType[] rightPathTypes = rightPlan.getOutputPathSchema().getTypes(); + + Expression joinConditionExp = + ExpressionTranslator.of(join.getPathSchema()).translate(join.getCondition()); + StepJoinFunction joinFunction = + new StepJoinFunctionImpl( + join.getJoinType(), leftPathTypes, rightPathTypes, joinConditionExp); + + PathType inputJoinPath = + (PathType) + SqlTypeUtil.convertType( + leftPathType.join(rightPathType, join.getCluster().getTypeFactory())); + PathType joinOutputPath = (PathType) SqlTypeUtil.convertType(join.getPathSchema()); + + List leftChainableVertex = + StepLogicalPlanTranslator.getChainableVertexMatch(leftPlan); + List rightChainableVertex = + StepLogicalPlanTranslator.getChainableVertexMatch(rightPlan); + boolean isLocalJoin = false; + if (leftChainableVertex.size() == 1 + && rightChainableVertex.size() == 1 + && joinInfo.leftKeys.size() == 1 + && joinInfo.rightKeys.size() == 1) { + String leftVertexLabel = + ((MatchVertexOperator) leftChainableVertex.get(0).getOperator()).getLabel(); + String rightVertexLabel = + ((MatchVertexOperator) rightChainableVertex.get(0).getOperator()).getLabel(); + if (leftPathType + .getFieldList() + .get(joinInfo.leftKeys.get(0)) + .getName() + .equals(leftVertexLabel) + && rightPathType + .getFieldList() + .get(joinInfo.rightKeys.get(0)) + .getName() + .equals(rightVertexLabel)) { + isLocalJoin = true; } + } + return leftPlan + .join(rightPlan, leftKeyFn, rightKeyFn, joinFunction, inputJoinPath, isLocalJoin) + .withOutputPathSchema(joinOutputPath); + } - @Override - public StepLogicalPlan visitSubQueryStart(SubQueryStart subQueryStart) { - PathType pathType = (PathType) SqlTypeUtil.convertType(subQueryStart.getPathSchema()); - - return StepLogicalPlan.subQueryStart(subQueryStart.getQueryName()) - .withGraphSchema(graphSchema) - .withInputPathSchema(pathType) - .withOutputPathSchema(pathType) - .withOutputType(SqlTypeUtil.convertType(subQueryStart.getNodeType())); - } + @Override + public StepLogicalPlan visitDistinct(MatchDistinct distinct) { + RelNode input = distinct.getInput(0); + IType[] types = + ((IMatchNode) input) + .getPathSchema().getFieldList().stream() + .map(field -> SqlTypeUtil.convertType(field.getType())) + .collect(Collectors.toList()) + .toArray(new IType[] {}); + int[] keyIndices = new int[types.length]; + for (int i = 0, size = types.length; i < size; i++) { + keyIndices[i] = i; + } + StepKeyFunction keyFunction = new StepKeyFunctionImpl(keyIndices, types); + PathType distinctPathType = (PathType) SqlTypeUtil.convertType(distinct.getPathSchema()); + IType nodeType = SqlTypeUtil.convertType(distinct.getNodeType()); + return visit(input) + .distinct(keyFunction) + .withOutputPathSchema(distinctPathType) + .withOutputType(nodeType); + } - @Override - public StepLogicalPlan visitPathModify(MatchPathModify pathModify) { - StepLogicalPlan input = visit(pathModify.getInput()); - List modifyExpressions = pathModify.getExpressions(); - int[] updatePathIndices = new int[modifyExpressions.size()]; - Expression[] updateExpressions = new Expression[modifyExpressions.size()]; - - ExpressionTranslator translator = ExpressionTranslator.of(pathModify.getInput().getRowType(), - logicalPlanSet); - for (int i = 0; i < modifyExpressions.size(); i++) { - PathModifyExpression modifyExpression = modifyExpressions.get(i); - updatePathIndices[i] = modifyExpression.getIndex(); - updateExpressions[i] = translator.translate(modifyExpression.getObjectConstruct()); - } - IType[] inputFieldTypes = input.getOutputPathSchema().getFields() - .stream() + @Override + public StepLogicalPlan visitUnion(MatchUnion union) { + List inputPlans = new ArrayList<>(); + + for (int i = 0, size = union.getInputs().size(); i < size; i++) { + // The input of union should not referer the plan cache generated by each other. + // So we create a new plan cache for each input. + Map prePlanCache = planCache; + planCache = new HashMap<>(planCache); + inputPlans.add(visit(union.getInput(i))); + // recover pre-plan cache. + planCache = prePlanCache; + } + + StepLogicalPlan firstPlan = inputPlans.get(0); + PathType unionPathType = (PathType) SqlTypeUtil.convertType(union.getPathSchema()); + IType nodeType = SqlTypeUtil.convertType(union.getNodeType()); + + StepLogicalPlan unionPlan = + firstPlan + .union(inputPlans.subList(1, inputPlans.size())) + .withModifyGraphSchema(firstPlan.getModifyGraphSchema()) + .withOutputPathSchema(unionPathType) + .withOutputType(nodeType); + if (union.all) { + return unionPlan; + } else { + IType[] types = + unionPlan.getOutputPathSchema().getFields().stream() .map(TableField::getType) .collect(Collectors.toList()) - .toArray(new IType[]{}); - GraphSchema modifyGraphSchema = (GraphSchema) SqlTypeUtil.convertType(pathModify.getModifyGraphType()); - StepPathModifyFunction modifyFunction = new StepPathModifyFunction(updatePathIndices, - updateExpressions, inputFieldTypes); - boolean isGlobal = pathModify.getExpressions().stream().anyMatch(exp -> { - return exp.getObjectConstruct().getVariableInfo().stream().anyMatch(VariableInfo::isGlobal); - }); - return input.map(modifyFunction, isGlobal) - .withGraphSchema(graphSchema) - .withModifyGraphSchema(modifyGraphSchema) - .withInputPathSchema(input.getOutputPathSchema()) - .withOutputPathSchema((PathType) SqlTypeUtil.convertType(pathModify.getRowType())) - .withOutputType(input.getOutputType()); + .toArray(new IType[] {}); + int[] keyIndices = new int[types.length]; + for (int i = 0, size = types.length; i < size; i++) { + keyIndices[i] = i; } + StepKeyFunction keyFunction = new StepKeyFunctionImpl(keyIndices, types); + return unionPlan + .distinct(keyFunction) + .withModifyGraphSchema(unionPlan.getModifyGraphSchema()) + .withOutputPathSchema(unionPlan.getOutputPathSchema()) + .withOutputType(unionPlan.getOutputType()); + } + } - @Override - public StepLogicalPlan visitExtend(MatchExtend matchExtend) { - StepLogicalPlan input = visit(matchExtend.getInput()); - List modifyExpressions = matchExtend.getExpressions(); - int[] updatePathIndices = new int[modifyExpressions.size()]; - Expression[] updateExpressions = new Expression[modifyExpressions.size()]; - - ExpressionTranslator translator = ExpressionTranslator.of( - matchExtend.getInput().getRowType(), logicalPlanSet); - int offset = 0; - for (int i = 0; i < modifyExpressions.size(); i++) { - PathModifyExpression modifyExpression = modifyExpressions.get(i); - if (matchExtend.getRewriteFields().contains(modifyExpression.getLeftVar().getLabel())) { - updatePathIndices[i] = modifyExpression.getIndex(); - } else { - updatePathIndices[i] = input.getOutputPathSchema().size() + offset; - offset++; - } - updateExpressions[i] = translator.translate(modifyExpression.getObjectConstruct()); - } - IType[] inputFieldTypes = input.getOutputPathSchema().getFields() - .stream() - .map(TableField::getType) - .collect(Collectors.toList()) - .toArray(new IType[]{}); - GraphSchema modifyGraphSchema = (GraphSchema) SqlTypeUtil.convertType(matchExtend.getModifyGraphType()); - StepPathModifyFunction modifyFunction = new StepPathModifyFunction(updatePathIndices, - updateExpressions, inputFieldTypes); - return input.map(modifyFunction, false) - .withGraphSchema(graphSchema) - .withModifyGraphSchema(modifyGraphSchema) - .withInputPathSchema(input.getOutputPathSchema()) - .withOutputPathSchema((PathType) SqlTypeUtil.convertType(matchExtend.getRowType())) - .withOutputType(input.getOutputType()); - } + @Override + public StepLogicalPlan visitLoopMatch(LoopUntilMatch loopMatch) { + StepLogicalPlan loopStart = visit(loopMatch.getInput()); + StepLogicalPlan loopBody = visit(loopMatch.getLoopBody()); + for (StepLogicalPlan plan : loopBody.getFinalPlans()) { + plan.withModifyGraphSchema(loopStart.getModifyGraphSchema()); + } + ExpressionTranslator translator = + ExpressionTranslator.of(loopMatch.getLoopBody().getPathSchema()); + Expression utilCondition = translator.translate(loopMatch.getUtilCondition()); + + PathType outputPath = (PathType) SqlTypeUtil.convertType(loopMatch.getPathSchema()); + IType nodeType = SqlTypeUtil.convertType(loopMatch.getNodeType()); + int loopStartPathFieldCount = loopStart.getOutputPathSchema().size(); + int loopBodyPathFieldCount = loopBody.getOutputPathSchema().size() - loopStartPathFieldCount; + return loopStart + .loopUtil( + loopBody, + new StepBoolFunctionImpl(utilCondition), + loopMatch.getMinLoopCount(), + loopMatch.getMaxLoopCount(), + loopStartPathFieldCount, + loopBodyPathFieldCount) + .withModifyGraphSchema(loopStart.getModifyGraphSchema()) + .withOutputPathSchema(outputPath) + .withOutputType(nodeType); + } - @Override - public StepLogicalPlan visitSort(MatchPathSort pathSort) { - StepLogicalPlan input = visit(pathSort.getInput()); - SortInfo sortInfo = buildSortInfo(pathSort); - StepSortFunction orderByFunction = new StepSortFunctionImpl(sortInfo); - PathType inputPath = input.getOutputPathSchema(); - return input.sort(orderByFunction) - .withModifyGraphSchema(input.getModifyGraphSchema()) - .withInputPathSchema(inputPath) - .withOutputPathSchema(inputPath).withOutputType(inputPath); - } + @Override + public StepLogicalPlan visitSubQueryStart(SubQueryStart subQueryStart) { + PathType pathType = (PathType) SqlTypeUtil.convertType(subQueryStart.getPathSchema()); - @Override - public StepLogicalPlan visitAggregate(MatchAggregate matchAggregate) { - StepLogicalPlan input = visit(matchAggregate.getInput()); - List groupList = matchAggregate.getGroupSet(); - RelDataType inputRelDataType = matchAggregate.getInput().getRowType(); - List groupListExpressions = groupList.stream().map(rex -> - ExpressionTranslator.of(inputRelDataType, logicalPlanSet).translate(rex)).collect( - Collectors.toList()); - StepKeyFunction keyFunction = new StepKeyExpressionFunctionImpl( - groupListExpressions.toArray(new Expression[0]), - groupListExpressions.stream().map(Expression::getOutputType).toArray(IType[]::new)); - - List aggCalls = matchAggregate.getAggCalls(); - List aggFnCalls = new ArrayList<>(); - for (MatchAggregateCall aggCall : aggCalls) { - String name = aggCall.getName(); - Expression[] argFields = aggCall.getArgList().stream().map( - rex -> ExpressionTranslator.of(inputRelDataType, logicalPlanSet).translate(rex)) - .toArray(Expression[]::new); - IType[] argFieldTypes = Arrays.stream(argFields).map(Expression::getOutputType) - .toArray(IType[]::new); - Class> udafClass = - PhysicAggregateRelNode.findUDAF(aggCall.getAggregation(), argFieldTypes); - StepAggCall functionCall = new StepAggCall(name, argFields, argFieldTypes, udafClass, - aggCall.isDistinct()); - aggFnCalls.add(functionCall); - } + return StepLogicalPlan.subQueryStart(subQueryStart.getQueryName()) + .withGraphSchema(graphSchema) + .withInputPathSchema(pathType) + .withOutputPathSchema(pathType) + .withOutputType(SqlTypeUtil.convertType(subQueryStart.getNodeType())); + } - List> aggOutputTypes = aggCalls.stream() - .map(call -> SqlTypeUtil.convertType(call.getType())) - .collect(Collectors.toList()); - int[] pathPruneIndices = inputRelDataType.getFieldList().stream().filter( - f -> matchAggregate.getPathSchema().getFieldNames().contains(f.getName()) - ).map(RelDataTypeField::getIndex).mapToInt(Integer::intValue).toArray(); - IType[] inputPathTypes = inputRelDataType.getFieldList().stream() - .map(f -> SqlTypeUtil.convertType(f.getType())).toArray(IType[]::new); - IType[] pathPruneTypes = matchAggregate.getPathSchema().getFieldList().stream() - .map(f -> SqlTypeUtil.convertType(f.getType())).toArray(IType[]::new); - StepAggregateFunction aggFn = new StepAggExpressionFunctionImpl(pathPruneIndices, - pathPruneTypes, inputPathTypes, aggFnCalls, aggOutputTypes); - - PathType inputPath = input.getOutputPathSchema(); - PathType outputPath = (PathType) SqlTypeUtil.convertType(matchAggregate.getRowType()); - return input.aggregate(inputPath, outputPath, keyFunction, aggFn); - } + @Override + public StepLogicalPlan visitPathModify(MatchPathModify pathModify) { + StepLogicalPlan input = visit(pathModify.getInput()); + List modifyExpressions = pathModify.getExpressions(); + int[] updatePathIndices = new int[modifyExpressions.size()]; + Expression[] updateExpressions = new Expression[modifyExpressions.size()]; + + ExpressionTranslator translator = + ExpressionTranslator.of(pathModify.getInput().getRowType(), logicalPlanSet); + for (int i = 0; i < modifyExpressions.size(); i++) { + PathModifyExpression modifyExpression = modifyExpressions.get(i); + updatePathIndices[i] = modifyExpression.getIndex(); + updateExpressions[i] = translator.translate(modifyExpression.getObjectConstruct()); + } + IType[] inputFieldTypes = + input.getOutputPathSchema().getFields().stream() + .map(TableField::getType) + .collect(Collectors.toList()) + .toArray(new IType[] {}); + GraphSchema modifyGraphSchema = + (GraphSchema) SqlTypeUtil.convertType(pathModify.getModifyGraphType()); + StepPathModifyFunction modifyFunction = + new StepPathModifyFunction(updatePathIndices, updateExpressions, inputFieldTypes); + boolean isGlobal = + pathModify.getExpressions().stream() + .anyMatch( + exp -> { + return exp.getObjectConstruct().getVariableInfo().stream() + .anyMatch(VariableInfo::isGlobal); + }); + return input + .map(modifyFunction, isGlobal) + .withGraphSchema(graphSchema) + .withModifyGraphSchema(modifyGraphSchema) + .withInputPathSchema(input.getOutputPathSchema()) + .withOutputPathSchema((PathType) SqlTypeUtil.convertType(pathModify.getRowType())) + .withOutputType(input.getOutputType()); + } - private SortInfo buildSortInfo(PathSort sort) { - SortInfo sortInfo = new SortInfo(); - ExpressionTranslator translator = ExpressionTranslator.of(sort.getRowType()); - for (RexNode fd : sort.getOrderByExpressions()) { - OrderByField orderByField = new OrderByField(); - if (fd.getKind() == DESCENDING) { - orderByField.order = ORDER.DESC; - } else { - orderByField.order = ORDER.ASC; - } - orderByField.expression = translator.translate(fd); - sortInfo.orderByFields.add(orderByField); - } - sortInfo.fetch = sort.getLimit() == null ? -1 : - (int) TypeCastUtil.cast( - translator.translate(sort.getLimit()).evaluate(null), - Integer.class); - return sortInfo; + @Override + public StepLogicalPlan visitExtend(MatchExtend matchExtend) { + StepLogicalPlan input = visit(matchExtend.getInput()); + List modifyExpressions = matchExtend.getExpressions(); + int[] updatePathIndices = new int[modifyExpressions.size()]; + Expression[] updateExpressions = new Expression[modifyExpressions.size()]; + + ExpressionTranslator translator = + ExpressionTranslator.of(matchExtend.getInput().getRowType(), logicalPlanSet); + int offset = 0; + for (int i = 0; i < modifyExpressions.size(); i++) { + PathModifyExpression modifyExpression = modifyExpressions.get(i); + if (matchExtend.getRewriteFields().contains(modifyExpression.getLeftVar().getLabel())) { + updatePathIndices[i] = modifyExpression.getIndex(); + } else { + updatePathIndices[i] = input.getOutputPathSchema().size() + offset; + offset++; } + updateExpressions[i] = translator.translate(modifyExpression.getObjectConstruct()); + } + IType[] inputFieldTypes = + input.getOutputPathSchema().getFields().stream() + .map(TableField::getType) + .collect(Collectors.toList()) + .toArray(new IType[] {}); + GraphSchema modifyGraphSchema = + (GraphSchema) SqlTypeUtil.convertType(matchExtend.getModifyGraphType()); + StepPathModifyFunction modifyFunction = + new StepPathModifyFunction(updatePathIndices, updateExpressions, inputFieldTypes); + return input + .map(modifyFunction, false) + .withGraphSchema(graphSchema) + .withModifyGraphSchema(modifyGraphSchema) + .withInputPathSchema(input.getOutputPathSchema()) + .withOutputPathSchema((PathType) SqlTypeUtil.convertType(matchExtend.getRowType())) + .withOutputType(input.getOutputType()); } - private static Set toStartIds(Set ids) { - return ids.stream() - .map(StepLogicalPlanTranslator::toStartId) - .collect(Collectors.toSet()); + @Override + public StepLogicalPlan visitSort(MatchPathSort pathSort) { + StepLogicalPlan input = visit(pathSort.getInput()); + SortInfo sortInfo = buildSortInfo(pathSort); + StepSortFunction orderByFunction = new StepSortFunctionImpl(sortInfo); + PathType inputPath = input.getOutputPathSchema(); + return input + .sort(orderByFunction) + .withModifyGraphSchema(input.getModifyGraphSchema()) + .withInputPathSchema(inputPath) + .withOutputPathSchema(inputPath) + .withOutputType(inputPath); } - private static StartId toStartId(RexNode id) { - List nonLiteralLeafNodes = GQLRexUtil.collect(id, - child -> !(child instanceof RexCall) && !(child instanceof RexLiteral)); + @Override + public StepLogicalPlan visitAggregate(MatchAggregate matchAggregate) { + StepLogicalPlan input = visit(matchAggregate.getInput()); + List groupList = matchAggregate.getGroupSet(); + RelDataType inputRelDataType = matchAggregate.getInput().getRowType(); + List groupListExpressions = + groupList.stream() + .map(rex -> ExpressionTranslator.of(inputRelDataType, logicalPlanSet).translate(rex)) + .collect(Collectors.toList()); + StepKeyFunction keyFunction = + new StepKeyExpressionFunctionImpl( + groupListExpressions.toArray(new Expression[0]), + groupListExpressions.stream() + .map(Expression::getOutputType) + .toArray(IType[]::new)); + + List aggCalls = matchAggregate.getAggCalls(); + List aggFnCalls = new ArrayList<>(); + for (MatchAggregateCall aggCall : aggCalls) { + String name = aggCall.getName(); + Expression[] argFields = + aggCall.getArgList().stream() + .map( + rex -> ExpressionTranslator.of(inputRelDataType, logicalPlanSet).translate(rex)) + .toArray(Expression[]::new); + IType[] argFieldTypes = + Arrays.stream(argFields).map(Expression::getOutputType).toArray(IType[]::new); + Class> udafClass = + PhysicAggregateRelNode.findUDAF(aggCall.getAggregation(), argFieldTypes); + StepAggCall functionCall = + new StepAggCall(name, argFields, argFieldTypes, udafClass, aggCall.isDistinct()); + aggFnCalls.add(functionCall); + } + + List> aggOutputTypes = + aggCalls.stream() + .map(call -> SqlTypeUtil.convertType(call.getType())) + .collect(Collectors.toList()); + int[] pathPruneIndices = + inputRelDataType.getFieldList().stream() + .filter(f -> matchAggregate.getPathSchema().getFieldNames().contains(f.getName())) + .map(RelDataTypeField::getIndex) + .mapToInt(Integer::intValue) + .toArray(); + IType[] inputPathTypes = + inputRelDataType.getFieldList().stream() + .map(f -> SqlTypeUtil.convertType(f.getType())) + .toArray(IType[]::new); + IType[] pathPruneTypes = + matchAggregate.getPathSchema().getFieldList().stream() + .map(f -> SqlTypeUtil.convertType(f.getType())) + .toArray(IType[]::new); + StepAggregateFunction aggFn = + new StepAggExpressionFunctionImpl( + pathPruneIndices, pathPruneTypes, inputPathTypes, aggFnCalls, aggOutputTypes); + + PathType inputPath = input.getOutputPathSchema(); + PathType outputPath = (PathType) SqlTypeUtil.convertType(matchAggregate.getRowType()); + return input.aggregate(inputPath, outputPath, keyFunction, aggFn); + } - Expression expression = ExpressionTranslator.of(null).translate(id); - if (nonLiteralLeafNodes.isEmpty()) { // all the leaf node is constant. - Object constantValue = expression.evaluate(null); - return new ConstantStartId(constantValue); + private SortInfo buildSortInfo(PathSort sort) { + SortInfo sortInfo = new SortInfo(); + ExpressionTranslator translator = ExpressionTranslator.of(sort.getRowType()); + for (RexNode fd : sort.getOrderByExpressions()) { + OrderByField orderByField = new OrderByField(); + if (fd.getKind() == DESCENDING) { + orderByField.order = ORDER.DESC; } else { - Expression idExpression = expression.replace(exp -> { + orderByField.order = ORDER.ASC; + } + orderByField.expression = translator.translate(fd); + sortInfo.orderByFields.add(orderByField); + } + sortInfo.fetch = + sort.getLimit() == null + ? -1 + : (int) + TypeCastUtil.cast( + translator.translate(sort.getLimit()).evaluate(null), Integer.class); + return sortInfo; + } + } + + private static Set toStartIds(Set ids) { + return ids.stream().map(StepLogicalPlanTranslator::toStartId).collect(Collectors.toSet()); + } + + private static StartId toStartId(RexNode id) { + List nonLiteralLeafNodes = + GQLRexUtil.collect( + id, child -> !(child instanceof RexCall) && !(child instanceof RexLiteral)); + + Expression expression = ExpressionTranslator.of(null).translate(id); + if (nonLiteralLeafNodes.isEmpty()) { // all the leaf node is constant. + Object constantValue = expression.evaluate(null); + return new ConstantStartId(constantValue); + } else { + Expression idExpression = + expression.replace( + exp -> { if (exp instanceof ParameterFieldExpression) { - ParameterFieldExpression field = (ParameterFieldExpression) exp; - return new FieldExpression(field.getFieldIndex(), field.getOutputType()); + ParameterFieldExpression field = (ParameterFieldExpression) exp; + return new FieldExpression(field.getFieldIndex(), field.getOutputType()); } return exp; - }); - return new ParameterStartId(idExpression); - } + }); + return new ParameterStartId(idExpression); } + } - public static List getChainableVertexMatch(StepLogicalPlan startPlan) { - if (startPlan == null) { - return Collections.emptyList(); - } - if (startPlan.getOperator() instanceof MatchVertexOperator) { - return Collections.singletonList(startPlan); - } else if (startPlan.getOperator() instanceof MatchEdgeOperator - || startPlan.getOperator() instanceof StepNodeFilterOperator - || startPlan.getOperator() instanceof StepLocalExchangeOperator - || startPlan.getOperator() instanceof StepLocalSingleValueAggregateOperator) { - return startPlan.getInputs().stream().flatMap( - input -> StepLogicalPlanTranslator.getChainableVertexMatch(input).stream() - ).collect(Collectors.toList()); - } - return Collections.emptyList(); + public static List getChainableVertexMatch(StepLogicalPlan startPlan) { + if (startPlan == null) { + return Collections.emptyList(); + } + if (startPlan.getOperator() instanceof MatchVertexOperator) { + return Collections.singletonList(startPlan); + } else if (startPlan.getOperator() instanceof MatchEdgeOperator + || startPlan.getOperator() instanceof StepNodeFilterOperator + || startPlan.getOperator() instanceof StepLocalExchangeOperator + || startPlan.getOperator() instanceof StepLocalSingleValueAggregateOperator) { + return startPlan.getInputs().stream() + .flatMap(input -> StepLogicalPlanTranslator.getChainableVertexMatch(input).stream()) + .collect(Collectors.toList()); } + return Collections.emptyList(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/TraversalRuntimeContext.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/TraversalRuntimeContext.java index 1f7dddb36..0ef29c54c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/TraversalRuntimeContext.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/TraversalRuntimeContext.java @@ -41,78 +41,82 @@ public interface TraversalRuntimeContext { - Configuration getConfig(); + Configuration getConfig(); - DagTopologyGroup getTopology(); + DagTopologyGroup getTopology(); - void setTopology(DagTopologyGroup topology); + void setTopology(DagTopologyGroup topology); - long getIterationId(); + long getIterationId(); - RowVertex getVertex(); + RowVertex getVertex(); - void setVertex(RowVertex vertex); + void setVertex(RowVertex vertex); - Object getRequestId(); + Object getRequestId(); - void setCurrentOpId(long operatorId); + void setCurrentOpId(long operatorId); - long getCurrentOpId(); + long getCurrentOpId(); - void setRequest(ParameterRequest parameterRequest); + void setRequest(ParameterRequest parameterRequest); - ParameterRequest getRequest(); + ParameterRequest getRequest(); - default Row getParameters() { - if (getRequest() != null) { - return getRequest().getParameters(); - } - return null; + default Row getParameters() { + if (getRequest() != null) { + return getRequest().getParameters(); } + return null; + } - void setMessageBox(MessageBox messageBox); + void setMessageBox(MessageBox messageBox); - M getMessage(MessageType messageType); + M getMessage(MessageType messageType); - EdgeGroup loadEdges(IFilter loadEdgesFilter); + EdgeGroup loadEdges(IFilter loadEdgesFilter); - RowVertex loadVertex(Object vertexId, IFilter loadVertexFilter, GraphSchema graphSchema, IType[] addingVertexFieldTypes); + RowVertex loadVertex( + Object vertexId, + IFilter loadVertexFilter, + GraphSchema graphSchema, + IType[] addingVertexFieldTypes); - CloseableIterator loadAllVertex(); + CloseableIterator loadAllVertex(); - void sendMessage(Object vertexId, IMessage message, long receiverId, long... otherReceiveIds); + void sendMessage(Object vertexId, IMessage message, long receiverId, long... otherReceiveIds); - void broadcast(IMessage message, long receiverId, long... otherReceiveIds); + void broadcast(IMessage message, long receiverId, long... otherReceiveIds); - void takePath(ITreePath treePath); + void takePath(ITreePath treePath); - void sendCoordinator(String name, Object value); + void sendCoordinator(String name, Object value); - VertexCentricAggContext getAggContext(); + VertexCentricAggContext getAggContext(); - void setAggContext(VertexCentricAggContext aggContext); + void setAggContext(VertexCentricAggContext aggContext); - RuntimeContext getRuntimeContext(); + RuntimeContext getRuntimeContext(); - MetricGroup getMetric(); + MetricGroup getMetric(); - int getNumTasks(); + int getNumTasks(); - int getTaskIndex(); + int getTaskIndex(); - void push(long opId, CallContext callContext); + void push(long opId, CallContext callContext); - void pop(long opId); + void pop(long opId); - void stashCallRequestId(CallRequestId requestId); + void stashCallRequestId(CallRequestId requestId); - Iterable takeCallRequestIds(); + Iterable takeCallRequestIds(); - void setInputOperatorId(long id); + void setInputOperatorId(long id); - long getInputOperatorId(); + long getInputOperatorId(); - void addFieldToVertex(Object vertexId, int index, Object value); + void addFieldToVertex(Object vertexId, int index, Object value); - long createUniqueId(long idInTask); + long createUniqueId(long idInTask); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/ChainOperatorCollector.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/ChainOperatorCollector.java index 80c8f1d5c..2bc0c6019 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/ChainOperatorCollector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/ChainOperatorCollector.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.dsl.common.data.StepRecord; import org.apache.geaflow.dsl.common.data.StepRecord.StepRecordType; import org.apache.geaflow.dsl.runtime.traversal.TraversalRuntimeContext; @@ -28,31 +29,31 @@ public class ChainOperatorCollector implements StepCollector { - private final List> nextOperators; - - private final TraversalRuntimeContext context; - - public ChainOperatorCollector(List> nextOperators, - TraversalRuntimeContext context) { - this.nextOperators = nextOperators; - this.context = context; - } - - public ChainOperatorCollector(StepOperator nextOperator, - TraversalRuntimeContext context) { - this(Collections.singletonList(nextOperator), context); - } - - @Override - public void collect(StepRecord record) { - if (record.getType() == StepRecordType.EOD) { - for (StepOperator nextOperator : nextOperators) { - nextOperator.process(record); - } - } else { - for (StepOperator nextOperator : nextOperators) { - nextOperator.process(record); - } - } + private final List> nextOperators; + + private final TraversalRuntimeContext context; + + public ChainOperatorCollector( + List> nextOperators, TraversalRuntimeContext context) { + this.nextOperators = nextOperators; + this.context = context; + } + + public ChainOperatorCollector( + StepOperator nextOperator, TraversalRuntimeContext context) { + this(Collections.singletonList(nextOperator), context); + } + + @Override + public void collect(StepRecord record) { + if (record.getType() == StepRecordType.EOD) { + for (StepOperator nextOperator : nextOperators) { + nextOperator.process(record); + } + } else { + for (StepOperator nextOperator : nextOperators) { + nextOperator.process(record); + } } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepBroadcastCollector.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepBroadcastCollector.java index bea540cd6..16399f72c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepBroadcastCollector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepBroadcastCollector.java @@ -21,20 +21,21 @@ import java.util.List; import java.util.Objects; + import org.apache.geaflow.dsl.common.data.StepRecord; public class StepBroadcastCollector implements StepCollector { - private final List> collectors; + private final List> collectors; - public StepBroadcastCollector(List> collectors) { - this.collectors = Objects.requireNonNull(collectors); - } + public StepBroadcastCollector(List> collectors) { + this.collectors = Objects.requireNonNull(collectors); + } - @Override - public void collect(OUT record) { - for (StepCollector collector : collectors) { - collector.collect(record); - } + @Override + public void collect(OUT record) { + for (StepCollector collector : collectors) { + collector.collect(record); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepCollector.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepCollector.java index e14a33443..c082ecc6f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepCollector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepCollector.java @@ -23,5 +23,5 @@ public interface StepCollector { - void collect(OUT record); + void collect(OUT record); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepEndCollector.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepEndCollector.java index a99b550cc..936e660db 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepEndCollector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepEndCollector.java @@ -28,29 +28,31 @@ public class StepEndCollector implements StepCollector { - public static final String TRAVERSAL_FINISH = "TraversalFinish"; + public static final String TRAVERSAL_FINISH = "TraversalFinish"; - private final TraversalRuntimeContext context; + private final TraversalRuntimeContext context; - public StepEndCollector(TraversalRuntimeContext context) { - this.context = context; - } + public StepEndCollector(TraversalRuntimeContext context) { + this.context = context; + } - @Override - public void collect(StepRecord record) { - if (record.getType() == StepRecordType.EOD) { - // The StepEndOperator has received all the EOD, send TRAVERSAL_FINISH to the coordinator - // to finish the iteration. - context.sendCoordinator(TRAVERSAL_FINISH, 1); - } else { - StepRecordWithPath recordWithPath = (StepRecordWithPath) record; - for (ITreePath path : recordWithPath.getPaths()) { - if (context.getParameters() != null) { - // If current is request with parameter, carry the request id and request parameter out. - path = new ParameterizedTreePath(path, context.getRequest().getRequestId(), context.getParameters()); - } - context.takePath(path); - } + @Override + public void collect(StepRecord record) { + if (record.getType() == StepRecordType.EOD) { + // The StepEndOperator has received all the EOD, send TRAVERSAL_FINISH to the coordinator + // to finish the iteration. + context.sendCoordinator(TRAVERSAL_FINISH, 1); + } else { + StepRecordWithPath recordWithPath = (StepRecordWithPath) record; + for (ITreePath path : recordWithPath.getPaths()) { + if (context.getParameters() != null) { + // If current is request with parameter, carry the request id and request parameter out. + path = + new ParameterizedTreePath( + path, context.getRequest().getRequestId(), context.getParameters()); } + context.takePath(path); + } } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepJumpCollector.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepJumpCollector.java index 63f206546..d54c52d58 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepJumpCollector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepJumpCollector.java @@ -25,31 +25,28 @@ public class StepJumpCollector implements StepCollector { - private final long jumpOpId; - - private final StepCollector baseCollector; - - public StepJumpCollector(long senderId, long jumpOpId, TraversalRuntimeContext context) { - this.jumpOpId = jumpOpId; - boolean isChained = context.getTopology().isChained(senderId, jumpOpId); - if (isChained) { - StepOperator jumpOp = context.getTopology().getOperator(jumpOpId); - this.baseCollector = new ChainOperatorCollector(jumpOp, context); - } else { - this.baseCollector = new StepNextCollector(senderId, jumpOpId, context); - } + private final long jumpOpId; + + private final StepCollector baseCollector; + + public StepJumpCollector(long senderId, long jumpOpId, TraversalRuntimeContext context) { + this.jumpOpId = jumpOpId; + boolean isChained = context.getTopology().isChained(senderId, jumpOpId); + if (isChained) { + StepOperator jumpOp = context.getTopology().getOperator(jumpOpId); + this.baseCollector = new ChainOperatorCollector(jumpOp, context); + } else { + this.baseCollector = new StepNextCollector(senderId, jumpOpId, context); } + } - @Override - public void collect(StepRecord record) { - baseCollector.collect(record); - } + @Override + public void collect(StepRecord record) { + baseCollector.collect(record); + } - @Override - public String toString() { - return "StepJumpCollector{" - + "jumpOpId=" + jumpOpId - + ", baseCollector=" + baseCollector - + '}'; - } + @Override + public String toString() { + return "StepJumpCollector{" + "jumpOpId=" + jumpOpId + ", baseCollector=" + baseCollector + '}'; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepNextCollector.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepNextCollector.java index 60cace2f9..7e60b1936 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepNextCollector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepNextCollector.java @@ -19,9 +19,9 @@ package org.apache.geaflow.dsl.runtime.traversal.collector; -import com.google.common.collect.Lists; import java.util.HashSet; import java.util.Set; + import org.apache.geaflow.dsl.common.data.RowEdge; import org.apache.geaflow.dsl.common.data.RowKey; import org.apache.geaflow.dsl.common.data.StepRecord; @@ -41,94 +41,93 @@ import org.apache.geaflow.dsl.runtime.traversal.path.EmptyTreePath; import org.apache.geaflow.dsl.runtime.traversal.path.ITreePath; +import com.google.common.collect.Lists; + public class StepNextCollector implements StepCollector { - private final long senderId; + private final long senderId; - /** - * The op id of the receiver. - */ - private final long receiverOpId; + /** The op id of the receiver. */ + private final long receiverOpId; - private final TraversalRuntimeContext context; + private final TraversalRuntimeContext context; - private final boolean nextIsJoin; + private final boolean nextIsJoin; - public StepNextCollector(long senderId, long receiverOpId, - TraversalRuntimeContext context) { - this.senderId = senderId; - this.receiverOpId = receiverOpId; - this.context = context; - this.nextIsJoin = context.getTopology().getOperator(receiverOpId) - instanceof StepJoinOperator; - } + public StepNextCollector(long senderId, long receiverOpId, TraversalRuntimeContext context) { + this.senderId = senderId; + this.receiverOpId = receiverOpId; + this.context = context; + this.nextIsJoin = context.getTopology().getOperator(receiverOpId) instanceof StepJoinOperator; + } - @Override - public void collect(StepRecord record) { - switch (record.getType()) { - case VERTEX: - VertexRecord vertexRecord = (VertexRecord) record; - sendPathMessage(vertexRecord.getVertex().getId(), vertexRecord.getTreePath()); - break; - case EDGE_GROUP: - EdgeGroupRecord edgeGroupRecord = (EdgeGroupRecord) record; - Set targetIds = new HashSet<>(); - for (RowEdge edge : edgeGroupRecord.getEdgeGroup()) { - targetIds.add(edge.getTargetId()); - } - for (Object targetId : targetIds) { - sendPathMessage(targetId, edgeGroupRecord.getPathById(targetId)); - sendRequest(targetId); - } - break; - case EOD: - EndOfData eod = (EndOfData) record; - // broadcast EOD to all the tasks. - context.broadcast(EODMessage.of(eod), receiverOpId); - break; - case KEY_RECORD: - StepKeyRecord keyRecord = (StepKeyRecord) record; - RowKey rowKey = keyRecord.getKey(); - KeyGroupMessage keyGroupMessage = new KeyGroupMessageImpl(Lists.newArrayList(keyRecord.getValue())); - context.sendMessage(rowKey, keyGroupMessage, receiverOpId); - sendRequest(rowKey); - break; - default: - throw new IllegalArgumentException("Illegal record type: " + record.getType()); + @Override + public void collect(StepRecord record) { + switch (record.getType()) { + case VERTEX: + VertexRecord vertexRecord = (VertexRecord) record; + sendPathMessage(vertexRecord.getVertex().getId(), vertexRecord.getTreePath()); + break; + case EDGE_GROUP: + EdgeGroupRecord edgeGroupRecord = (EdgeGroupRecord) record; + Set targetIds = new HashSet<>(); + for (RowEdge edge : edgeGroupRecord.getEdgeGroup()) { + targetIds.add(edge.getTargetId()); + } + for (Object targetId : targetIds) { + sendPathMessage(targetId, edgeGroupRecord.getPathById(targetId)); + sendRequest(targetId); } + break; + case EOD: + EndOfData eod = (EndOfData) record; + // broadcast EOD to all the tasks. + context.broadcast(EODMessage.of(eod), receiverOpId); + break; + case KEY_RECORD: + StepKeyRecord keyRecord = (StepKeyRecord) record; + RowKey rowKey = keyRecord.getKey(); + KeyGroupMessage keyGroupMessage = + new KeyGroupMessageImpl(Lists.newArrayList(keyRecord.getValue())); + context.sendMessage(rowKey, keyGroupMessage, receiverOpId); + sendRequest(rowKey); + break; + default: + throw new IllegalArgumentException("Illegal record type: " + record.getType()); } + } - /** - * Send path messages to target vertex id. - * - * @param targetId The target vertex id. - * @param treePath The path to send. - */ - private void sendPathMessage(Object targetId, ITreePath treePath) { - if (treePath == null) { - treePath = EmptyTreePath.INSTANCE; - } - if (context.getRequest() != null) { - // set requestId for tree path. - treePath.setRequestIdForTree(context.getRequest().getRequestId()); - } - IMessage pathMessage; - if (nextIsJoin) { // If next op is join, add the senderId to the message. - pathMessage = JoinPathMessage.from(senderId, treePath); - } else { - pathMessage = treePath; - } - // Send path message. - context.sendMessage(targetId, pathMessage, receiverOpId); + /** + * Send path messages to target vertex id. + * + * @param targetId The target vertex id. + * @param treePath The path to send. + */ + private void sendPathMessage(Object targetId, ITreePath treePath) { + if (treePath == null) { + treePath = EmptyTreePath.INSTANCE; } + if (context.getRequest() != null) { + // set requestId for tree path. + treePath.setRequestIdForTree(context.getRequest().getRequestId()); + } + IMessage pathMessage; + if (nextIsJoin) { // If next op is join, add the senderId to the message. + pathMessage = JoinPathMessage.from(senderId, treePath); + } else { + pathMessage = treePath; + } + // Send path message. + context.sendMessage(targetId, pathMessage, receiverOpId); + } - private void sendRequest(Object targetId) { - // Send request message. - if (context.getRequest() != null) { - ParameterRequest request = context.getRequest(); - ParameterRequestMessage requestMessage = new ParameterRequestMessage(); - requestMessage.addRequest(request); - context.sendMessage(targetId, requestMessage, receiverOpId); - } + private void sendRequest(Object targetId) { + // Send request message. + if (context.getRequest() != null) { + ParameterRequest request = context.getRequest(); + ParameterRequestMessage requestMessage = new ParameterRequestMessage(); + requestMessage.addRequest(request); + context.sendMessage(targetId, requestMessage, receiverOpId); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepPathPruneCollector.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepPathPruneCollector.java index 8ff9c54b3..9cf3d7ef5 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepPathPruneCollector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepPathPruneCollector.java @@ -25,27 +25,29 @@ public class StepPathPruneCollector implements StepCollector { - private final TraversalRuntimeContext context; - - private final StepCollector baseCollector; - - private final int[] outputPathFieldIndices; - - public StepPathPruneCollector(TraversalRuntimeContext context, - StepCollector baseCollector, int[] outputPathFieldIndices) { - this.context = context; - this.baseCollector = baseCollector; - this.outputPathFieldIndices = outputPathFieldIndices; - } - - @Override - public void collect(OUT record) { - if (record instanceof StepRecordWithPath) { - StepRecordWithPath recordWithPath = (StepRecordWithPath) record; - StepRecordWithPath prunePathRecord = recordWithPath.subPathSet(outputPathFieldIndices); - baseCollector.collect((OUT) prunePathRecord); - } else { - baseCollector.collect(record); - } + private final TraversalRuntimeContext context; + + private final StepCollector baseCollector; + + private final int[] outputPathFieldIndices; + + public StepPathPruneCollector( + TraversalRuntimeContext context, + StepCollector baseCollector, + int[] outputPathFieldIndices) { + this.context = context; + this.baseCollector = baseCollector; + this.outputPathFieldIndices = outputPathFieldIndices; + } + + @Override + public void collect(OUT record) { + if (record instanceof StepRecordWithPath) { + StepRecordWithPath recordWithPath = (StepRecordWithPath) record; + StepRecordWithPath prunePathRecord = recordWithPath.subPathSet(outputPathFieldIndices); + baseCollector.collect((OUT) prunePathRecord); + } else { + baseCollector.collect(record); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepReturnCollector.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepReturnCollector.java index 3b000cdcd..08845f55e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepReturnCollector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepReturnCollector.java @@ -35,48 +35,52 @@ public class StepReturnCollector implements StepCollector { - private static final Logger LOGGER = LoggerFactory.getLogger(StepReturnCollector.class); + private static final Logger LOGGER = LoggerFactory.getLogger(StepReturnCollector.class); - private final TraversalRuntimeContext context; + private final TraversalRuntimeContext context; - private final long queryId; + private final long queryId; - public StepReturnCollector(TraversalRuntimeContext context, long currentOpId) { - this.context = context; - this.queryId = context.getTopology().getDagTopology(currentOpId).getEntryOpId(); - } + public StepReturnCollector(TraversalRuntimeContext context, long currentOpId) { + this.context = context; + this.queryId = context.getTopology().getDagTopology(currentOpId).getEntryOpId(); + } - @Override - public void collect(StepRecord record) { - StepRecordType recordType = record.getType(); - if (recordType == StepRecordType.SINGLE_VALUE) { - ParameterRequest request = context.getRequest(); - CallRequestId callRequestId = (CallRequestId) request.getRequestId(); - long callOpId = callRequestId.getCallOpId(); - sendReturnValue(callOpId, callRequestId.getPathId(), request.getVertexId(), (SingleValue) record); - } else if (recordType == StepRecordType.EOD) { - Iterable callRequestIds = context.takeCallRequestIds(); - // Send default return value: 'null' to the caller as the request may filter by the middle operator - // and cannot reach to the return operator. If the request can reach the return operator, the return value - // will update the null value. - for (CallRequestId callRequestId : callRequestIds) { - long callOpId = callRequestId.getCallOpId(); - Object startVertexId = callRequestId.getVertexId(); - sendReturnValue(callOpId, callRequestId.getPathId(), startVertexId, null); - } - // send eod to the caller after call return. - EndOfData eod = (EndOfData) record; - assert eod.getCallOpId() >= 0 : "Illegal caller op id: " + eod.getCallOpId(); - long callerId = eod.getCallOpId(); - eod = EndOfData.of(-1, eod.getSenderId()); - context.broadcast(EODMessage.of(eod), callerId); - } + @Override + public void collect(StepRecord record) { + StepRecordType recordType = record.getType(); + if (recordType == StepRecordType.SINGLE_VALUE) { + ParameterRequest request = context.getRequest(); + CallRequestId callRequestId = (CallRequestId) request.getRequestId(); + long callOpId = callRequestId.getCallOpId(); + sendReturnValue( + callOpId, callRequestId.getPathId(), request.getVertexId(), (SingleValue) record); + } else if (recordType == StepRecordType.EOD) { + Iterable callRequestIds = context.takeCallRequestIds(); + // Send default return value: 'null' to the caller as the request may filter by the middle + // operator + // and cannot reach to the return operator. If the request can reach the return operator, the + // return value + // will update the null value. + for (CallRequestId callRequestId : callRequestIds) { + long callOpId = callRequestId.getCallOpId(); + Object startVertexId = callRequestId.getVertexId(); + sendReturnValue(callOpId, callRequestId.getPathId(), startVertexId, null); + } + // send eod to the caller after call return. + EndOfData eod = (EndOfData) record; + assert eod.getCallOpId() >= 0 : "Illegal caller op id: " + eod.getCallOpId(); + long callerId = eod.getCallOpId(); + eod = EndOfData.of(-1, eod.getSenderId()); + context.broadcast(EODMessage.of(eod), callerId); } + } - private void sendReturnValue(long callerOpId, long pathId, Object startVertexId, SingleValue value) { - // send back the result value. - ReturnMessage returnMessage = new ReturnMessageImpl(); - returnMessage.putValue(new ReturnKey(pathId, queryId), value); - context.sendMessage(startVertexId, returnMessage, callerOpId); - } + private void sendReturnValue( + long callerOpId, long pathId, Object startVertexId, SingleValue value) { + // send back the result value. + ReturnMessage returnMessage = new ReturnMessageImpl(); + returnMessage.putValue(new ReturnKey(pathId, queryId), value); + context.sendMessage(startVertexId, returnMessage, callerOpId); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepWaitCallQueryCollector.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepWaitCallQueryCollector.java index 056ea320a..a781ee869 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepWaitCallQueryCollector.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/collector/StepWaitCallQueryCollector.java @@ -25,21 +25,22 @@ public class StepWaitCallQueryCollector implements StepCollector { - private final CallQueryProxy callQueryProxy; + private final CallQueryProxy callQueryProxy; - private final StepCollector baseCollector; + private final StepCollector baseCollector; - public StepWaitCallQueryCollector(CallQueryProxy callQueryProxy, StepCollector baseCollector) { - this.callQueryProxy = callQueryProxy; - this.baseCollector = baseCollector; - } + public StepWaitCallQueryCollector( + CallQueryProxy callQueryProxy, StepCollector baseCollector) { + this.callQueryProxy = callQueryProxy; + this.baseCollector = baseCollector; + } - @Override - public void collect(OUT record) { - // Only when the calling query has returned, we can collect the record to the next operator. - if (callQueryProxy.getCallState() == CallState.RETURNING - || callQueryProxy.getCallState() == CallState.FINISH) { - baseCollector.collect(record); - } + @Override + public void collect(OUT record) { + // Only when the calling query has returned, we can collect the record to the next operator. + if (callQueryProxy.getCallState() == CallState.RETURNING + || callQueryProxy.getCallState() == CallState.FINISH) { + baseCollector.collect(record); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/BroadcastId.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/BroadcastId.java index 83359fc26..17af78c4b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/BroadcastId.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/BroadcastId.java @@ -20,30 +20,31 @@ package org.apache.geaflow.dsl.runtime.traversal.data; import java.util.Objects; + import org.apache.geaflow.dsl.common.data.VirtualId; public class BroadcastId implements VirtualId { - private final int fromTaskIndex; + private final int fromTaskIndex; - public BroadcastId(int fromTaskIndex) { - this.fromTaskIndex = fromTaskIndex; - } + public BroadcastId(int fromTaskIndex) { + this.fromTaskIndex = fromTaskIndex; + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof BroadcastId)) { - return false; - } - BroadcastId that = (BroadcastId) o; - return fromTaskIndex == that.fromTaskIndex; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(fromTaskIndex); + if (!(o instanceof BroadcastId)) { + return false; } + BroadcastId that = (BroadcastId) o; + return fromTaskIndex == that.fromTaskIndex; + } + + @Override + public int hashCode() { + return Objects.hash(fromTaskIndex); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/CallRequestId.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/CallRequestId.java index d0d343984..44ddb8a4a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/CallRequestId.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/CallRequestId.java @@ -24,59 +24,60 @@ public class CallRequestId implements Serializable { - /** - * The request id for sub query calling. It's the path id for each request path. - */ - private final long pathId; + /** The request id for sub query calling. It's the path id for each request path. */ + private final long pathId; - private final long callOpId; + private final long callOpId; - private final Object vertexId; + private final Object vertexId; - public CallRequestId(long pathId, long callOpId, Object vertexId) { - if (pathId < 0) { - throw new IllegalArgumentException("Illegal pathId: " + pathId); - } - this.pathId = pathId; - this.callOpId = callOpId; - this.vertexId = vertexId; + public CallRequestId(long pathId, long callOpId, Object vertexId) { + if (pathId < 0) { + throw new IllegalArgumentException("Illegal pathId: " + pathId); } + this.pathId = pathId; + this.callOpId = callOpId; + this.vertexId = vertexId; + } - public long getPathId() { - return pathId; - } + public long getPathId() { + return pathId; + } - public long getCallOpId() { - return callOpId; - } + public long getCallOpId() { + return callOpId; + } - public Object getVertexId() { - return vertexId; - } + public Object getVertexId() { + return vertexId; + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof CallRequestId)) { - return false; - } - CallRequestId that = (CallRequestId) o; - return pathId == that.pathId && callOpId == that.callOpId && vertexId.equals(that.vertexId); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(pathId, callOpId, vertexId); + if (!(o instanceof CallRequestId)) { + return false; } + CallRequestId that = (CallRequestId) o; + return pathId == that.pathId && callOpId == that.callOpId && vertexId.equals(that.vertexId); + } - @Override - public String toString() { - return "CallRequestId{" - + "pathId=" + pathId - + ", callOpId=" + callOpId - + ", vertexId=" + vertexId - + '}'; - } + @Override + public int hashCode() { + return Objects.hash(pathId, callOpId, vertexId); + } + + @Override + public String toString() { + return "CallRequestId{" + + "pathId=" + + pathId + + ", callOpId=" + + callOpId + + ", vertexId=" + + vertexId + + '}'; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/EdgeGroup.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/EdgeGroup.java index b5905f2bb..f59c4a7ad 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/EdgeGroup.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/EdgeGroup.java @@ -19,79 +19,82 @@ package org.apache.geaflow.dsl.runtime.traversal.data; -import com.google.common.collect.AbstractIterator; -import com.google.common.collect.FluentIterable; -import com.google.common.collect.Iterables; -import com.google.common.collect.Lists; import java.util.Iterator; import java.util.List; import java.util.function.Function; import java.util.function.Predicate; + import org.apache.geaflow.dsl.common.data.RowEdge; +import com.google.common.collect.AbstractIterator; +import com.google.common.collect.FluentIterable; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; + public class EdgeGroup implements Iterable { - private final Iterable edges; - - private EdgeGroup(Iterable edges) { - this.edges = edges; - } - - public static EdgeGroup of(List edges) { - return new EdgeGroup(edges); - } - - public static EdgeGroup of(Iterable edges) { - return new EdgeGroup(edges); - } - - public static EdgeGroup of(Iterator edges) { - return new EdgeGroup(Lists.newArrayList(edges)); - } - - @Override - public Iterator iterator() { - return edges.iterator(); - } - - public EdgeGroup filter(Predicate predicate) { - Iterable filterEdges = Iterables.filter(edges, predicate::test); - return EdgeGroup.of(filterEdges); - } - - public EdgeGroup map(Function function) { - Iterable mapEdges = Iterables.transform(edges, function::apply); - return EdgeGroup.of(mapEdges); - } - - public Iterable flatMap(Function> function) { - Iterator edgeIterator = edges.iterator(); - Iterator flatIterator = new AbstractIterator() { - - private Iterator current = null; - - @Override - protected E computeNext() { - if (current == null || !current.hasNext()) { - if (edgeIterator.hasNext()) { - RowEdge edge = edges.iterator().next(); - current = function.apply(edge); - } else { - current = null; - } - } - if (current == null) { - return this.endOfData(); - } - return current.next(); - } - }; + private final Iterable edges; - return new FluentIterable() { - @Override - public Iterator iterator() { - return flatIterator; + private EdgeGroup(Iterable edges) { + this.edges = edges; + } + + public static EdgeGroup of(List edges) { + return new EdgeGroup(edges); + } + + public static EdgeGroup of(Iterable edges) { + return new EdgeGroup(edges); + } + + public static EdgeGroup of(Iterator edges) { + return new EdgeGroup(Lists.newArrayList(edges)); + } + + @Override + public Iterator iterator() { + return edges.iterator(); + } + + public EdgeGroup filter(Predicate predicate) { + Iterable filterEdges = Iterables.filter(edges, predicate::test); + return EdgeGroup.of(filterEdges); + } + + public EdgeGroup map(Function function) { + Iterable mapEdges = Iterables.transform(edges, function::apply); + return EdgeGroup.of(mapEdges); + } + + public Iterable flatMap(Function> function) { + Iterator edgeIterator = edges.iterator(); + Iterator flatIterator = + new AbstractIterator() { + + private Iterator current = null; + + @Override + protected E computeNext() { + if (current == null || !current.hasNext()) { + if (edgeIterator.hasNext()) { + RowEdge edge = edges.iterator().next(); + current = function.apply(edge); + } else { + current = null; + } } + if (current == null) { + return this.endOfData(); + } + return current.next(); + } }; - } + + return new FluentIterable() { + @Override + public Iterator iterator() { + return flatIterator; + } + }; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/EdgeGroupRecord.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/EdgeGroupRecord.java index a614d6ddd..54750f1ef 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/EdgeGroupRecord.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/EdgeGroupRecord.java @@ -25,6 +25,7 @@ import java.util.Map; import java.util.Objects; import java.util.function.Function; + import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.runtime.traversal.path.ITreePath; import org.apache.geaflow.dsl.runtime.traversal.path.ITreePath.PathFilterFunction; @@ -32,120 +33,120 @@ public class EdgeGroupRecord implements StepRecordWithPath { - private final EdgeGroup edgeGroup; - - private final Map targetId2TreePaths; - - private EdgeGroupRecord(EdgeGroup edgeGroup, Map targetId2TreePaths) { - this.edgeGroup = Objects.requireNonNull(edgeGroup); - this.targetId2TreePaths = Objects.requireNonNull(targetId2TreePaths); - } - - public static EdgeGroupRecord of(EdgeGroup edgeGroup, Map targetId2TreePaths) { - return new EdgeGroupRecord(edgeGroup, targetId2TreePaths); - } - - public EdgeGroup getEdgeGroup() { - return edgeGroup; - } - - @Override - public ITreePath getPathById(Object vertexId) { - return targetId2TreePaths.get(vertexId); - } - - @Override - public Iterable getPaths() { - return targetId2TreePaths.values(); - } - - @Override - public Iterable getVertexIds() { - return targetId2TreePaths.keySet(); - } - - @Override - public StepRecordWithPath filter(PathFilterFunction function, int[] refPathIndices) { - Map filterTreePaths = new HashMap<>(); - for (Map.Entry entry : targetId2TreePaths.entrySet()) { - Object targetId = entry.getKey(); - ITreePath treePath = entry.getValue(); - ITreePath filterPath = treePath.filter(function, refPathIndices); - - if (!filterPath.isEmpty()) { - filterTreePaths.put(targetId, filterPath); - } - } - EdgeGroup filterEg; - if (filterTreePaths.size() == targetId2TreePaths.size()) { - filterEg = edgeGroup; - } else { - filterEg = edgeGroup.filter(edge -> filterTreePaths.containsKey(edge.getTargetId())); - } - return new EdgeGroupRecord(filterEg, filterTreePaths); + private final EdgeGroup edgeGroup; + + private final Map targetId2TreePaths; + + private EdgeGroupRecord(EdgeGroup edgeGroup, Map targetId2TreePaths) { + this.edgeGroup = Objects.requireNonNull(edgeGroup); + this.targetId2TreePaths = Objects.requireNonNull(targetId2TreePaths); + } + + public static EdgeGroupRecord of(EdgeGroup edgeGroup, Map targetId2TreePaths) { + return new EdgeGroupRecord(edgeGroup, targetId2TreePaths); + } + + public EdgeGroup getEdgeGroup() { + return edgeGroup; + } + + @Override + public ITreePath getPathById(Object vertexId) { + return targetId2TreePaths.get(vertexId); + } + + @Override + public Iterable getPaths() { + return targetId2TreePaths.values(); + } + + @Override + public Iterable getVertexIds() { + return targetId2TreePaths.keySet(); + } + + @Override + public StepRecordWithPath filter(PathFilterFunction function, int[] refPathIndices) { + Map filterTreePaths = new HashMap<>(); + for (Map.Entry entry : targetId2TreePaths.entrySet()) { + Object targetId = entry.getKey(); + ITreePath treePath = entry.getValue(); + ITreePath filterPath = treePath.filter(function, refPathIndices); + + if (!filterPath.isEmpty()) { + filterTreePaths.put(targetId, filterPath); + } } - - @Override - public List map(PathMapFunction function, int[] refPathIndices) { - List results = new ArrayList<>(); - for (ITreePath treePath : targetId2TreePaths.values()) { - List treeResults = treePath.map(function); - if (treeResults != null) { - results.addAll(treeResults); - } - } - return results; + EdgeGroup filterEg; + if (filterTreePaths.size() == targetId2TreePaths.size()) { + filterEg = edgeGroup; + } else { + filterEg = edgeGroup.filter(edge -> filterTreePaths.containsKey(edge.getTargetId())); } - - @Override - public StepRecordWithPath mapPath(PathMapFunction function, int[] refPathIndices) { - Map mapTreePaths = new HashMap<>(); - for (Map.Entry entry : targetId2TreePaths.entrySet()) { - Object targetId = entry.getKey(); - ITreePath treePath = entry.getValue(); - ITreePath mapPath = treePath.mapTree(function); - mapTreePaths.put(targetId, mapPath); - } - return new EdgeGroupRecord(edgeGroup, mapTreePaths); + return new EdgeGroupRecord(filterEg, filterTreePaths); + } + + @Override + public List map(PathMapFunction function, int[] refPathIndices) { + List results = new ArrayList<>(); + for (ITreePath treePath : targetId2TreePaths.values()) { + List treeResults = treePath.map(function); + if (treeResults != null) { + results.addAll(treeResults); + } } - - @Override - public StepRecordWithPath mapTreePath(Function function) { - Map mapTreePaths = new HashMap<>(); - for (Map.Entry entry : targetId2TreePaths.entrySet()) { - Object targetId = entry.getKey(); - ITreePath treePath = entry.getValue(); - ITreePath mapPath = function.apply(treePath); - mapTreePaths.put(targetId, mapPath); - } - return new EdgeGroupRecord(edgeGroup, mapTreePaths); + return results; + } + + @Override + public StepRecordWithPath mapPath(PathMapFunction function, int[] refPathIndices) { + Map mapTreePaths = new HashMap<>(); + for (Map.Entry entry : targetId2TreePaths.entrySet()) { + Object targetId = entry.getKey(); + ITreePath treePath = entry.getValue(); + ITreePath mapPath = treePath.mapTree(function); + mapTreePaths.put(targetId, mapPath); } - - @Override - public StepRecordWithPath subPathSet(int[] pathIndices) { - if (pathIndices.length == 0) { - return new EdgeGroupRecord(edgeGroup, new HashMap<>()); - } - Map subTreePaths = new HashMap<>(); - for (Map.Entry entry : targetId2TreePaths.entrySet()) { - Object targetId = entry.getKey(); - ITreePath treePath = entry.getValue(); - ITreePath subTreePath = treePath.subPath(pathIndices); - - if (!subTreePath.isEmpty()) { - subTreePaths.put(targetId, subTreePath); - } - } - return new EdgeGroupRecord(edgeGroup, subTreePaths); + return new EdgeGroupRecord(edgeGroup, mapTreePaths); + } + + @Override + public StepRecordWithPath mapTreePath(Function function) { + Map mapTreePaths = new HashMap<>(); + for (Map.Entry entry : targetId2TreePaths.entrySet()) { + Object targetId = entry.getKey(); + ITreePath treePath = entry.getValue(); + ITreePath mapPath = function.apply(treePath); + mapTreePaths.put(targetId, mapPath); } + return new EdgeGroupRecord(edgeGroup, mapTreePaths); + } - @Override - public boolean isPathEmpty() { - return targetId2TreePaths.isEmpty(); + @Override + public StepRecordWithPath subPathSet(int[] pathIndices) { + if (pathIndices.length == 0) { + return new EdgeGroupRecord(edgeGroup, new HashMap<>()); } - - @Override - public StepRecordType getType() { - return StepRecordType.EDGE_GROUP; + Map subTreePaths = new HashMap<>(); + for (Map.Entry entry : targetId2TreePaths.entrySet()) { + Object targetId = entry.getKey(); + ITreePath treePath = entry.getValue(); + ITreePath subTreePath = treePath.subPath(pathIndices); + + if (!subTreePath.isEmpty()) { + subTreePaths.put(targetId, subTreePath); + } } + return new EdgeGroupRecord(edgeGroup, subTreePaths); + } + + @Override + public boolean isPathEmpty() { + return targetId2TreePaths.isEmpty(); + } + + @Override + public StepRecordType getType() { + return StepRecordType.EDGE_GROUP; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/EndOfData.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/EndOfData.java index 1627c3307..d041e6d65 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/EndOfData.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/EndOfData.java @@ -23,56 +23,47 @@ public class EndOfData implements StepRecord { - /** - * The caller operator id. -1 means main dag execution. - */ - private final long callOpId; + /** The caller operator id. -1 means main dag execution. */ + private final long callOpId; - /** - * The sender operator id who send this data. - */ - private final long senderId; + /** The sender operator id who send this data. */ + private final long senderId; - /** - * No data been processed between two eod cycles. - */ - public boolean isGlobalEmptyCycle; + /** No data been processed between two eod cycles. */ + public boolean isGlobalEmptyCycle; - private EndOfData(long callOpId, long senderId) { - this.callOpId = callOpId; - this.senderId = senderId; - } + private EndOfData(long callOpId, long senderId) { + this.callOpId = callOpId; + this.senderId = senderId; + } - private EndOfData(long senderId) { - this(-1L, senderId); - } + private EndOfData(long senderId) { + this(-1L, senderId); + } - public static EndOfData of(long senderId) { - return new EndOfData(senderId); - } + public static EndOfData of(long senderId) { + return new EndOfData(senderId); + } - public static EndOfData of(long callOpId, long senderId) { - return new EndOfData(callOpId, senderId); - } + public static EndOfData of(long callOpId, long senderId) { + return new EndOfData(callOpId, senderId); + } - @Override - public StepRecordType getType() { - return StepRecordType.EOD; - } + @Override + public StepRecordType getType() { + return StepRecordType.EOD; + } - public long getCallOpId() { - return callOpId; - } + public long getCallOpId() { + return callOpId; + } - public long getSenderId() { - return senderId; - } + public long getSenderId() { + return senderId; + } - @Override - public String toString() { - return "EndOfData{" - + "callOpId=" + callOpId - + ", senderId=" + senderId - + '}'; - } + @Override + public String toString() { + return "EndOfData{" + "callOpId=" + callOpId + ", senderId=" + senderId + '}'; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignDoubleEdge.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignDoubleEdge.java index cf3add319..29eb16336 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignDoubleEdge.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignDoubleEdge.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.Objects; + import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.RowEdge; import org.apache.geaflow.dsl.common.data.impl.types.DoubleEdge; @@ -28,55 +29,59 @@ public class FieldAlignDoubleEdge extends FieldAlignEdge implements RowEdge { - private final DoubleEdge baseEdge; + private final DoubleEdge baseEdge; - private final int[] fieldMapping; + private final int[] fieldMapping; - public FieldAlignDoubleEdge(DoubleEdge baseEdge, int[] fieldMapping) { - super(baseEdge, fieldMapping); - this.baseEdge = baseEdge; - this.fieldMapping = fieldMapping; - } + public FieldAlignDoubleEdge(DoubleEdge baseEdge, int[] fieldMapping) { + super(baseEdge, fieldMapping); + this.baseEdge = baseEdge; + this.fieldMapping = fieldMapping; + } - @Override - public IEdge withValue(Row value) { - return new FieldAlignDoubleEdge(baseEdge.withValue(value), fieldMapping); - } + @Override + public IEdge withValue(Row value) { + return new FieldAlignDoubleEdge(baseEdge.withValue(value), fieldMapping); + } - @Override - public IEdge reverse() { - return new FieldAlignDoubleEdge(baseEdge.reverse(), fieldMapping); - } + @Override + public IEdge reverse() { + return new FieldAlignDoubleEdge(baseEdge.reverse(), fieldMapping); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof RowEdge)) { - return false; - } - if (o instanceof DoubleEdge) { - DoubleEdge that = (DoubleEdge) o; - return baseEdge.srcId == that.srcId && baseEdge.targetId == that.targetId - && baseEdge.direction == that.direction - && Objects.equals(baseEdge.getBinaryLabel(), that.getBinaryLabel()); - } else if (o instanceof FieldAlignDoubleEdge) { - FieldAlignDoubleEdge that = (FieldAlignDoubleEdge) o; - return baseEdge.srcId == that.baseEdge.srcId && baseEdge.targetId == that.baseEdge.targetId - && baseEdge.direction == that.baseEdge.direction - && Objects.equals(baseEdge.getBinaryLabel(), that.getBinaryLabel()); - } else { - RowEdge that = (RowEdge) o; - return Objects.equals(getSrcId(), that.getSrcId()) && Objects.equals(getTargetId(), - that.getTargetId()) && getDirect() == that.getDirect() && Objects.equals(getBinaryLabel(), that.getBinaryLabel()); - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - int result = Objects.hash(baseEdge); - result = 31 * result + Arrays.hashCode(fieldMapping); - return result; + if (!(o instanceof RowEdge)) { + return false; + } + if (o instanceof DoubleEdge) { + DoubleEdge that = (DoubleEdge) o; + return baseEdge.srcId == that.srcId + && baseEdge.targetId == that.targetId + && baseEdge.direction == that.direction + && Objects.equals(baseEdge.getBinaryLabel(), that.getBinaryLabel()); + } else if (o instanceof FieldAlignDoubleEdge) { + FieldAlignDoubleEdge that = (FieldAlignDoubleEdge) o; + return baseEdge.srcId == that.baseEdge.srcId + && baseEdge.targetId == that.baseEdge.targetId + && baseEdge.direction == that.baseEdge.direction + && Objects.equals(baseEdge.getBinaryLabel(), that.getBinaryLabel()); + } else { + RowEdge that = (RowEdge) o; + return Objects.equals(getSrcId(), that.getSrcId()) + && Objects.equals(getTargetId(), that.getTargetId()) + && getDirect() == that.getDirect() + && Objects.equals(getBinaryLabel(), that.getBinaryLabel()); } + } + + @Override + public int hashCode() { + int result = Objects.hash(baseEdge); + result = 31 * result + Arrays.hashCode(fieldMapping); + return result; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignDoubleVertex.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignDoubleVertex.java index 880ee62ce..7ad41a448 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignDoubleVertex.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignDoubleVertex.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.traversal.data; import java.util.Objects; + import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.RowVertex; import org.apache.geaflow.dsl.common.data.impl.types.DoubleVertex; @@ -27,50 +28,51 @@ public class FieldAlignDoubleVertex extends FieldAlignVertex implements RowVertex { - private final DoubleVertex baseVertex; + private final DoubleVertex baseVertex; - private final int[] fieldMapping; + private final int[] fieldMapping; - public FieldAlignDoubleVertex(DoubleVertex baseVertex, int[] fieldMapping) { - super(baseVertex, fieldMapping); - this.baseVertex = baseVertex; - this.fieldMapping = fieldMapping; - } + public FieldAlignDoubleVertex(DoubleVertex baseVertex, int[] fieldMapping) { + super(baseVertex, fieldMapping); + this.baseVertex = baseVertex; + this.fieldMapping = fieldMapping; + } - @Override - public IVertex withValue(Row value) { - return new FieldAlignDoubleVertex(baseVertex.withValue(value), fieldMapping); - } + @Override + public IVertex withValue(Row value) { + return new FieldAlignDoubleVertex(baseVertex.withValue(value), fieldMapping); + } - @Override - public IVertex withLabel(String label) { - return new FieldAlignDoubleVertex(baseVertex.withLabel(label), fieldMapping); - } + @Override + public IVertex withLabel(String label) { + return new FieldAlignDoubleVertex(baseVertex.withLabel(label), fieldMapping); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof RowVertex)) { - return false; - } - if (o instanceof DoubleVertex) { - DoubleVertex that = (DoubleVertex) o; - return Double.compare(baseVertex.id, that.id) == 0 - && Objects.equals(baseVertex.getBinaryLabel(), that.getBinaryLabel()); - } else if (o instanceof FieldAlignDoubleVertex) { - FieldAlignDoubleVertex that = (FieldAlignDoubleVertex) o; - return Double.compare(baseVertex.id, that.baseVertex.id) == 0 - && Objects.equals(baseVertex.getBinaryLabel(), that.getBinaryLabel()); - } else { - RowVertex that = (RowVertex) o; - return getId().equals(that.getId()) && Objects.equals(getBinaryLabel(), that.getBinaryLabel()); - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(baseVertex.id, getBinaryLabel()); + if (!(o instanceof RowVertex)) { + return false; + } + if (o instanceof DoubleVertex) { + DoubleVertex that = (DoubleVertex) o; + return Double.compare(baseVertex.id, that.id) == 0 + && Objects.equals(baseVertex.getBinaryLabel(), that.getBinaryLabel()); + } else if (o instanceof FieldAlignDoubleVertex) { + FieldAlignDoubleVertex that = (FieldAlignDoubleVertex) o; + return Double.compare(baseVertex.id, that.baseVertex.id) == 0 + && Objects.equals(baseVertex.getBinaryLabel(), that.getBinaryLabel()); + } else { + RowVertex that = (RowVertex) o; + return getId().equals(that.getId()) + && Objects.equals(getBinaryLabel(), that.getBinaryLabel()); } + } + + @Override + public int hashCode() { + return Objects.hash(baseVertex.id, getBinaryLabel()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignEdge.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignEdge.java index 861bf9c3a..716b675f7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignEdge.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignEdge.java @@ -19,12 +19,9 @@ package org.apache.geaflow.dsl.runtime.traversal.data; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.Serializer; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Arrays; import java.util.Objects; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -35,172 +32,187 @@ import org.apache.geaflow.model.graph.edge.EdgeDirection; import org.apache.geaflow.model.graph.edge.IEdge; -public class FieldAlignEdge implements RowEdge { - - private final RowEdge baseEdge; - - private final int[] fieldMapping; - - public FieldAlignEdge(RowEdge baseEdge, int[] fieldMapping) { - this.baseEdge = baseEdge; - this.fieldMapping = fieldMapping; - } - - @Override - public Object getField(int i, IType type) { - int mappingIndex = fieldMapping[i]; - if (mappingIndex < 0) { - return null; - } - return baseEdge.getField(mappingIndex, type); - } - - @Override - public void setValue(Row value) { - baseEdge.setValue(value); - } - - @Override - public RowEdge withDirection(EdgeDirection direction) { - return new FieldAlignEdge(baseEdge.withDirection(direction), fieldMapping); - } - - @Override - public RowEdge identityReverse() { - return new FieldAlignEdge(baseEdge.identityReverse(), fieldMapping); - } - - @Override - public String getLabel() { - return baseEdge.getLabel(); - } - - @Override - public void setLabel(String label) { - baseEdge.setLabel(label); - } - - @Override - public Object getSrcId() { - return baseEdge.getSrcId(); - } - - @Override - public void setSrcId(Object srcId) { - baseEdge.setSrcId(srcId); - } - - @Override - public Object getTargetId() { - return baseEdge.getTargetId(); - } - - @Override - public void setTargetId(Object targetId) { - baseEdge.setTargetId(targetId); - } - - @Override - public EdgeDirection getDirect() { - return baseEdge.getDirect(); - } - - @Override - public void setDirect(EdgeDirection direction) { - baseEdge.setDirect(direction); - } - - @Override - public Row getValue() { - return baseEdge.getValue(); - } - - @Override - public IEdge withValue(Row value) { - return new FieldAlignEdge((RowEdge) baseEdge.withValue(value), fieldMapping); - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; - @Override - public IEdge reverse() { - return new FieldAlignEdge((RowEdge) baseEdge.reverse(), fieldMapping); - } +public class FieldAlignEdge implements RowEdge { - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof RowEdge)) { - return false; + private final RowEdge baseEdge; + + private final int[] fieldMapping; + + public FieldAlignEdge(RowEdge baseEdge, int[] fieldMapping) { + this.baseEdge = baseEdge; + this.fieldMapping = fieldMapping; + } + + @Override + public Object getField(int i, IType type) { + int mappingIndex = fieldMapping[i]; + if (mappingIndex < 0) { + return null; + } + return baseEdge.getField(mappingIndex, type); + } + + @Override + public void setValue(Row value) { + baseEdge.setValue(value); + } + + @Override + public RowEdge withDirection(EdgeDirection direction) { + return new FieldAlignEdge(baseEdge.withDirection(direction), fieldMapping); + } + + @Override + public RowEdge identityReverse() { + return new FieldAlignEdge(baseEdge.identityReverse(), fieldMapping); + } + + @Override + public String getLabel() { + return baseEdge.getLabel(); + } + + @Override + public void setLabel(String label) { + baseEdge.setLabel(label); + } + + @Override + public Object getSrcId() { + return baseEdge.getSrcId(); + } + + @Override + public void setSrcId(Object srcId) { + baseEdge.setSrcId(srcId); + } + + @Override + public Object getTargetId() { + return baseEdge.getTargetId(); + } + + @Override + public void setTargetId(Object targetId) { + baseEdge.setTargetId(targetId); + } + + @Override + public EdgeDirection getDirect() { + return baseEdge.getDirect(); + } + + @Override + public void setDirect(EdgeDirection direction) { + baseEdge.setDirect(direction); + } + + @Override + public Row getValue() { + return baseEdge.getValue(); + } + + @Override + public IEdge withValue(Row value) { + return new FieldAlignEdge((RowEdge) baseEdge.withValue(value), fieldMapping); + } + + @Override + public IEdge reverse() { + return new FieldAlignEdge((RowEdge) baseEdge.reverse(), fieldMapping); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof RowEdge)) { + return false; + } + RowEdge that = (RowEdge) o; + return Objects.equals(getSrcId(), that.getSrcId()) + && Objects.equals(getTargetId(), that.getTargetId()) + && getDirect() == that.getDirect() + && Objects.equals(getBinaryLabel(), that.getBinaryLabel()); + } + + @Override + public int hashCode() { + int result = Objects.hash(baseEdge); + result = 31 * result + Arrays.hashCode(fieldMapping); + return result; + } + + @Override + public BinaryString getBinaryLabel() { + return baseEdge.getBinaryLabel(); + } + + @Override + public void setBinaryLabel(BinaryString label) { + baseEdge.setBinaryLabel(label); + } + + @Override + public String toString() { + return getSrcId() + + "#" + + getTargetId() + + "#" + + getBinaryLabel() + + "#" + + getDirect() + + "#" + + getValue(); + } + + public static RowEdge createFieldAlignedEdge(RowEdge baseEdge, int[] fieldMapping) { + if (baseEdge instanceof LongEdge) { + return new FieldAlignLongEdge((LongEdge) baseEdge, fieldMapping); + } else if (baseEdge instanceof IntEdge) { + return new FieldAlignIntEdge((IntEdge) baseEdge, fieldMapping); + } else if (baseEdge instanceof DoubleEdge) { + return new FieldAlignDoubleEdge((DoubleEdge) baseEdge, fieldMapping); + } + return new FieldAlignEdge(baseEdge, fieldMapping); + } + + public static class FieldAlignEdgeSerializer extends Serializer { + + @Override + public void write(Kryo kryo, Output output, FieldAlignEdge object) { + kryo.writeClassAndObject(output, object.baseEdge); + if (object.fieldMapping != null) { + output.writeInt(object.fieldMapping.length, true); + for (int i : object.fieldMapping) { + output.writeInt(i); } - RowEdge that = (RowEdge) o; - return Objects.equals(getSrcId(), that.getSrcId()) && Objects.equals(getTargetId(), - that.getTargetId()) && getDirect() == that.getDirect() && Objects.equals(getBinaryLabel(), that.getBinaryLabel()); - } - - @Override - public int hashCode() { - int result = Objects.hash(baseEdge); - result = 31 * result + Arrays.hashCode(fieldMapping); - return result; + } else { + output.writeInt(0, true); + } } @Override - public BinaryString getBinaryLabel() { - return baseEdge.getBinaryLabel(); + public FieldAlignEdge read(Kryo kryo, Input input, Class type) { + RowEdge baseEdge = (RowEdge) kryo.readClassAndObject(input); + int[] fieldMapping = new int[input.readInt(true)]; + for (int i = 0; i < fieldMapping.length; i++) { + fieldMapping[i] = input.readInt(); + } + return new FieldAlignEdge(baseEdge, fieldMapping); } @Override - public void setBinaryLabel(BinaryString label) { - baseEdge.setBinaryLabel(label); - } - - @Override - public String toString() { - return getSrcId() + "#" + getTargetId() + "#" + getBinaryLabel() + "#" + getDirect() + "#" + getValue(); - } - - public static RowEdge createFieldAlignedEdge(RowEdge baseEdge, int[] fieldMapping) { - if (baseEdge instanceof LongEdge) { - return new FieldAlignLongEdge((LongEdge) baseEdge, fieldMapping); - } else if (baseEdge instanceof IntEdge) { - return new FieldAlignIntEdge((IntEdge) baseEdge, fieldMapping); - } else if (baseEdge instanceof DoubleEdge) { - return new FieldAlignDoubleEdge((DoubleEdge) baseEdge, fieldMapping); - } - return new FieldAlignEdge(baseEdge, fieldMapping); - } - - public static class FieldAlignEdgeSerializer extends Serializer { - - @Override - public void write(Kryo kryo, Output output, FieldAlignEdge object) { - kryo.writeClassAndObject(output, object.baseEdge); - if (object.fieldMapping != null) { - output.writeInt(object.fieldMapping.length, true); - for (int i : object.fieldMapping) { - output.writeInt(i); - } - } else { - output.writeInt(0, true); - } - } - - @Override - public FieldAlignEdge read(Kryo kryo, Input input, Class type) { - RowEdge baseEdge = (RowEdge) kryo.readClassAndObject(input); - int[] fieldMapping = new int[input.readInt(true)]; - for (int i = 0; i < fieldMapping.length; i++) { - fieldMapping[i] = input.readInt(); - } - return new FieldAlignEdge(baseEdge, fieldMapping); - } - - @Override - public FieldAlignEdge copy(Kryo kryo, FieldAlignEdge original) { - return new FieldAlignEdge(kryo.copy(original.baseEdge), Arrays.copyOf(original.fieldMapping, original.fieldMapping.length)); - } - + public FieldAlignEdge copy(Kryo kryo, FieldAlignEdge original) { + return new FieldAlignEdge( + kryo.copy(original.baseEdge), + Arrays.copyOf(original.fieldMapping, original.fieldMapping.length)); } - + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignIntEdge.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignIntEdge.java index f04cc3d54..6f4cf11c9 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignIntEdge.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignIntEdge.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.Objects; + import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.RowEdge; import org.apache.geaflow.dsl.common.data.impl.types.IntEdge; @@ -28,55 +29,59 @@ public class FieldAlignIntEdge extends FieldAlignEdge implements RowEdge { - private final IntEdge baseEdge; + private final IntEdge baseEdge; - private final int[] fieldMapping; + private final int[] fieldMapping; - public FieldAlignIntEdge(IntEdge baseEdge, int[] fieldMapping) { - super(baseEdge, fieldMapping); - this.baseEdge = baseEdge; - this.fieldMapping = fieldMapping; - } + public FieldAlignIntEdge(IntEdge baseEdge, int[] fieldMapping) { + super(baseEdge, fieldMapping); + this.baseEdge = baseEdge; + this.fieldMapping = fieldMapping; + } - @Override - public IEdge withValue(Row value) { - return new FieldAlignIntEdge(baseEdge.withValue(value), fieldMapping); - } + @Override + public IEdge withValue(Row value) { + return new FieldAlignIntEdge(baseEdge.withValue(value), fieldMapping); + } - @Override - public IEdge reverse() { - return new FieldAlignIntEdge(baseEdge.reverse(), fieldMapping); - } + @Override + public IEdge reverse() { + return new FieldAlignIntEdge(baseEdge.reverse(), fieldMapping); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof RowEdge)) { - return false; - } - if (o instanceof IntEdge) { - IntEdge that = (IntEdge) o; - return baseEdge.srcId == that.srcId && baseEdge.targetId == that.targetId - && baseEdge.direction == that.direction - && Objects.equals(baseEdge.getBinaryLabel(), that.getBinaryLabel()); - } else if (o instanceof FieldAlignIntEdge) { - FieldAlignIntEdge that = (FieldAlignIntEdge) o; - return baseEdge.srcId == that.baseEdge.srcId && baseEdge.targetId == that.baseEdge.targetId - && baseEdge.direction == that.baseEdge.direction - && Objects.equals(baseEdge.getBinaryLabel(), that.getBinaryLabel()); - } else { - RowEdge that = (RowEdge) o; - return Objects.equals(getSrcId(), that.getSrcId()) && Objects.equals(getTargetId(), - that.getTargetId()) && getDirect() == that.getDirect() && Objects.equals(getBinaryLabel(), that.getBinaryLabel()); - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - int result = Objects.hash(baseEdge); - result = 31 * result + Arrays.hashCode(fieldMapping); - return result; + if (!(o instanceof RowEdge)) { + return false; + } + if (o instanceof IntEdge) { + IntEdge that = (IntEdge) o; + return baseEdge.srcId == that.srcId + && baseEdge.targetId == that.targetId + && baseEdge.direction == that.direction + && Objects.equals(baseEdge.getBinaryLabel(), that.getBinaryLabel()); + } else if (o instanceof FieldAlignIntEdge) { + FieldAlignIntEdge that = (FieldAlignIntEdge) o; + return baseEdge.srcId == that.baseEdge.srcId + && baseEdge.targetId == that.baseEdge.targetId + && baseEdge.direction == that.baseEdge.direction + && Objects.equals(baseEdge.getBinaryLabel(), that.getBinaryLabel()); + } else { + RowEdge that = (RowEdge) o; + return Objects.equals(getSrcId(), that.getSrcId()) + && Objects.equals(getTargetId(), that.getTargetId()) + && getDirect() == that.getDirect() + && Objects.equals(getBinaryLabel(), that.getBinaryLabel()); } + } + + @Override + public int hashCode() { + int result = Objects.hash(baseEdge); + result = 31 * result + Arrays.hashCode(fieldMapping); + return result; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignIntVertex.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignIntVertex.java index d3ce124a4..76064bf4a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignIntVertex.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignIntVertex.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.traversal.data; import java.util.Objects; + import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.RowVertex; import org.apache.geaflow.dsl.common.data.impl.types.IntVertex; @@ -27,50 +28,51 @@ public class FieldAlignIntVertex extends FieldAlignVertex implements RowVertex { - private final IntVertex baseVertex; + private final IntVertex baseVertex; - private final int[] fieldMapping; + private final int[] fieldMapping; - public FieldAlignIntVertex(IntVertex baseVertex, int[] fieldMapping) { - super(baseVertex, fieldMapping); - this.baseVertex = baseVertex; - this.fieldMapping = fieldMapping; - } + public FieldAlignIntVertex(IntVertex baseVertex, int[] fieldMapping) { + super(baseVertex, fieldMapping); + this.baseVertex = baseVertex; + this.fieldMapping = fieldMapping; + } - @Override - public IVertex withValue(Row value) { - return new FieldAlignIntVertex(baseVertex.withValue(value), fieldMapping); - } + @Override + public IVertex withValue(Row value) { + return new FieldAlignIntVertex(baseVertex.withValue(value), fieldMapping); + } - @Override - public IVertex withLabel(String label) { - return new FieldAlignIntVertex(baseVertex.withLabel(label), fieldMapping); - } + @Override + public IVertex withLabel(String label) { + return new FieldAlignIntVertex(baseVertex.withLabel(label), fieldMapping); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof RowVertex)) { - return false; - } - if (o instanceof IntVertex) { - IntVertex that = (IntVertex) o; - return baseVertex.id == that.id && Objects.equals(baseVertex.getBinaryLabel(), - that.getBinaryLabel()); - } else if (o instanceof FieldAlignIntVertex) { - FieldAlignIntVertex that = (FieldAlignIntVertex) o; - return baseVertex.id == that.baseVertex.id && Objects.equals(baseVertex.getBinaryLabel(), - that.getBinaryLabel()); - } else { - RowVertex that = (RowVertex) o; - return getId().equals(that.getId()) && Objects.equals(getBinaryLabel(), that.getBinaryLabel()); - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(baseVertex.id, getBinaryLabel()); + if (!(o instanceof RowVertex)) { + return false; + } + if (o instanceof IntVertex) { + IntVertex that = (IntVertex) o; + return baseVertex.id == that.id + && Objects.equals(baseVertex.getBinaryLabel(), that.getBinaryLabel()); + } else if (o instanceof FieldAlignIntVertex) { + FieldAlignIntVertex that = (FieldAlignIntVertex) o; + return baseVertex.id == that.baseVertex.id + && Objects.equals(baseVertex.getBinaryLabel(), that.getBinaryLabel()); + } else { + RowVertex that = (RowVertex) o; + return getId().equals(that.getId()) + && Objects.equals(getBinaryLabel(), that.getBinaryLabel()); } + } + + @Override + public int hashCode() { + return Objects.hash(baseVertex.id, getBinaryLabel()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignLongEdge.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignLongEdge.java index 82cc3e0fe..720f73f88 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignLongEdge.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignLongEdge.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.Objects; + import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.RowEdge; import org.apache.geaflow.dsl.common.data.impl.types.LongEdge; @@ -28,55 +29,59 @@ public class FieldAlignLongEdge extends FieldAlignEdge implements RowEdge { - private final LongEdge baseEdge; + private final LongEdge baseEdge; - private final int[] fieldMapping; + private final int[] fieldMapping; - public FieldAlignLongEdge(LongEdge baseEdge, int[] fieldMapping) { - super(baseEdge, fieldMapping); - this.baseEdge = baseEdge; - this.fieldMapping = fieldMapping; - } + public FieldAlignLongEdge(LongEdge baseEdge, int[] fieldMapping) { + super(baseEdge, fieldMapping); + this.baseEdge = baseEdge; + this.fieldMapping = fieldMapping; + } - @Override - public IEdge withValue(Row value) { - return new FieldAlignLongEdge(baseEdge.withValue(value), fieldMapping); - } + @Override + public IEdge withValue(Row value) { + return new FieldAlignLongEdge(baseEdge.withValue(value), fieldMapping); + } - @Override - public IEdge reverse() { - return new FieldAlignLongEdge(baseEdge.reverse(), fieldMapping); - } + @Override + public IEdge reverse() { + return new FieldAlignLongEdge(baseEdge.reverse(), fieldMapping); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof RowEdge)) { - return false; - } - if (o instanceof LongEdge) { - LongEdge that = (LongEdge) o; - return baseEdge.srcId == that.srcId && baseEdge.targetId == that.targetId - && baseEdge.direction == that.direction - && Objects.equals(baseEdge.getBinaryLabel(), that.getBinaryLabel()); - } else if (o instanceof FieldAlignLongEdge) { - FieldAlignLongEdge that = (FieldAlignLongEdge) o; - return baseEdge.srcId == that.baseEdge.srcId && baseEdge.targetId == that.baseEdge.targetId - && baseEdge.direction == that.baseEdge.direction - && Objects.equals(baseEdge.getBinaryLabel(), that.getBinaryLabel()); - } else { - RowEdge that = (RowEdge) o; - return Objects.equals(getSrcId(), that.getSrcId()) && Objects.equals(getTargetId(), - that.getTargetId()) && getDirect() == that.getDirect() && Objects.equals(getBinaryLabel(), that.getBinaryLabel()); - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - int result = Objects.hash(baseEdge); - result = 31 * result + Arrays.hashCode(fieldMapping); - return result; + if (!(o instanceof RowEdge)) { + return false; + } + if (o instanceof LongEdge) { + LongEdge that = (LongEdge) o; + return baseEdge.srcId == that.srcId + && baseEdge.targetId == that.targetId + && baseEdge.direction == that.direction + && Objects.equals(baseEdge.getBinaryLabel(), that.getBinaryLabel()); + } else if (o instanceof FieldAlignLongEdge) { + FieldAlignLongEdge that = (FieldAlignLongEdge) o; + return baseEdge.srcId == that.baseEdge.srcId + && baseEdge.targetId == that.baseEdge.targetId + && baseEdge.direction == that.baseEdge.direction + && Objects.equals(baseEdge.getBinaryLabel(), that.getBinaryLabel()); + } else { + RowEdge that = (RowEdge) o; + return Objects.equals(getSrcId(), that.getSrcId()) + && Objects.equals(getTargetId(), that.getTargetId()) + && getDirect() == that.getDirect() + && Objects.equals(getBinaryLabel(), that.getBinaryLabel()); } + } + + @Override + public int hashCode() { + int result = Objects.hash(baseEdge); + result = 31 * result + Arrays.hashCode(fieldMapping); + return result; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignLongVertex.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignLongVertex.java index 5eca76260..80060efd2 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignLongVertex.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignLongVertex.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.traversal.data; import java.util.Objects; + import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.RowVertex; import org.apache.geaflow.dsl.common.data.impl.types.LongVertex; @@ -27,50 +28,51 @@ public class FieldAlignLongVertex extends FieldAlignVertex implements RowVertex { - private final LongVertex baseVertex; + private final LongVertex baseVertex; - private final int[] fieldMapping; + private final int[] fieldMapping; - public FieldAlignLongVertex(LongVertex baseVertex, int[] fieldMapping) { - super(baseVertex, fieldMapping); - this.baseVertex = baseVertex; - this.fieldMapping = fieldMapping; - } + public FieldAlignLongVertex(LongVertex baseVertex, int[] fieldMapping) { + super(baseVertex, fieldMapping); + this.baseVertex = baseVertex; + this.fieldMapping = fieldMapping; + } - @Override - public IVertex withValue(Row value) { - return new FieldAlignLongVertex(baseVertex.withValue(value), fieldMapping); - } + @Override + public IVertex withValue(Row value) { + return new FieldAlignLongVertex(baseVertex.withValue(value), fieldMapping); + } - @Override - public IVertex withLabel(String label) { - return new FieldAlignLongVertex(baseVertex.withLabel(label), fieldMapping); - } + @Override + public IVertex withLabel(String label) { + return new FieldAlignLongVertex(baseVertex.withLabel(label), fieldMapping); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof RowVertex)) { - return false; - } - if (o instanceof LongVertex) { - LongVertex that = (LongVertex) o; - return baseVertex.id == that.id && Objects.equals(baseVertex.getBinaryLabel(), - that.getBinaryLabel()); - } else if (o instanceof FieldAlignLongVertex) { - FieldAlignLongVertex that = (FieldAlignLongVertex) o; - return baseVertex.id == that.baseVertex.id && Objects.equals(baseVertex.getBinaryLabel(), - that.getBinaryLabel()); - } else { - RowVertex that = (RowVertex) o; - return getId().equals(that.getId()) && Objects.equals(getBinaryLabel(), that.getBinaryLabel()); - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(baseVertex.id, getBinaryLabel()); + if (!(o instanceof RowVertex)) { + return false; + } + if (o instanceof LongVertex) { + LongVertex that = (LongVertex) o; + return baseVertex.id == that.id + && Objects.equals(baseVertex.getBinaryLabel(), that.getBinaryLabel()); + } else if (o instanceof FieldAlignLongVertex) { + FieldAlignLongVertex that = (FieldAlignLongVertex) o; + return baseVertex.id == that.baseVertex.id + && Objects.equals(baseVertex.getBinaryLabel(), that.getBinaryLabel()); + } else { + RowVertex that = (RowVertex) o; + return getId().equals(that.getId()) + && Objects.equals(getBinaryLabel(), that.getBinaryLabel()); } + } + + @Override + public int hashCode() { + return Objects.hash(baseVertex.id, getBinaryLabel()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignPath.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignPath.java index 6776221b7..e8ec2ae7c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignPath.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignPath.java @@ -19,121 +19,122 @@ package org.apache.geaflow.dsl.runtime.traversal.data; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.Serializer; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.ArrayList; import java.util.Collection; import java.util.List; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.utils.ArrayUtil; import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.impl.DefaultPath; -public class FieldAlignPath implements Path { - - private final Path basePath; - - private final int[] fieldMapping; - - public FieldAlignPath(Path basePath, int[] fieldMapping) { - this.basePath = basePath; - this.fieldMapping = fieldMapping; - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; - @Override - public void addNode(Row node) { - throw new IllegalArgumentException("Read only path, addNode() method is not supported"); - } +public class FieldAlignPath implements Path { - @Override - public void remove(int index) { - throw new IllegalArgumentException("Read only path, remove() method is not supported"); + private final Path basePath; + + private final int[] fieldMapping; + + public FieldAlignPath(Path basePath, int[] fieldMapping) { + this.basePath = basePath; + this.fieldMapping = fieldMapping; + } + + @Override + public void addNode(Row node) { + throw new IllegalArgumentException("Read only path, addNode() method is not supported"); + } + + @Override + public void remove(int index) { + throw new IllegalArgumentException("Read only path, remove() method is not supported"); + } + + @Override + public Path copy() { + return new FieldAlignPath(basePath.copy(), fieldMapping); + } + + @Override + public int size() { + return fieldMapping.length; + } + + @Override + public Path subPath(Collection indices) { + return subPath(ArrayUtil.toIntArray(indices)); + } + + @Override + public Path subPath(int[] indices) { + Path subPath = new DefaultPath(); + for (int index : indices) { + subPath.addNode(getField(index, null)); } - - @Override - public Path copy() { - return new FieldAlignPath(basePath.copy(), fieldMapping); + return subPath; + } + + @Override + public long getId() { + return basePath.getId(); + } + + @Override + public void setId(long id) { + basePath.setId(id); + } + + @Override + public Row getField(int i, IType type) { + int mappingIndex = fieldMapping[i]; + if (mappingIndex < 0) { + return null; } - - @Override - public int size() { - return fieldMapping.length; + return basePath.getField(mappingIndex, type); + } + + @Override + public List getPathNodes() { + List pathNodes = new ArrayList<>(fieldMapping.length); + for (int i = 0; i < fieldMapping.length; i++) { + pathNodes.add(getField(i, null)); } + return pathNodes; + } - @Override - public Path subPath(Collection indices) { - return subPath(ArrayUtil.toIntArray(indices)); - } + public static class FieldAlignPathSerializer extends Serializer { @Override - public Path subPath(int[] indices) { - Path subPath = new DefaultPath(); - for (int index : indices) { - subPath.addNode(getField(index, null)); + public void write(Kryo kryo, Output output, FieldAlignPath object) { + kryo.writeClassAndObject(output, object.basePath); + if (object.fieldMapping != null) { + output.writeInt(object.fieldMapping.length, true); + for (int i : object.fieldMapping) { + output.writeInt(i); } - return subPath; - } - - @Override - public long getId() { - return basePath.getId(); - } - - @Override - public void setId(long id) { - basePath.setId(id); + } else { + output.writeInt(0, true); + } } @Override - public Row getField(int i, IType type) { - int mappingIndex = fieldMapping[i]; - if (mappingIndex < 0) { - return null; - } - return basePath.getField(mappingIndex, type); + public FieldAlignPath read(Kryo kryo, Input input, Class type) { + Path basePath = (Path) kryo.readClassAndObject(input); + int[] fieldMapping = new int[input.readInt(true)]; + for (int i = 0; i < fieldMapping.length; i++) { + fieldMapping[i] = input.readInt(); + } + return new FieldAlignPath(basePath, fieldMapping); } @Override - public List getPathNodes() { - List pathNodes = new ArrayList<>(fieldMapping.length); - for (int i = 0; i < fieldMapping.length; i++) { - pathNodes.add(getField(i, null)); - } - return pathNodes; - } - - public static class FieldAlignPathSerializer extends Serializer { - - @Override - public void write(Kryo kryo, Output output, FieldAlignPath object) { - kryo.writeClassAndObject(output, object.basePath); - if (object.fieldMapping != null) { - output.writeInt(object.fieldMapping.length, true); - for (int i : object.fieldMapping) { - output.writeInt(i); - } - } else { - output.writeInt(0, true); - } - } - - @Override - public FieldAlignPath read(Kryo kryo, Input input, Class type) { - Path basePath = (Path) kryo.readClassAndObject(input); - int[] fieldMapping = new int[input.readInt(true)]; - for (int i = 0; i < fieldMapping.length; i++) { - fieldMapping[i] = input.readInt(); - } - return new FieldAlignPath(basePath, fieldMapping); - } - - @Override - public FieldAlignPath copy(Kryo kryo, FieldAlignPath original) { - return (FieldAlignPath) original.copy(); - } + public FieldAlignPath copy(Kryo kryo, FieldAlignPath original) { + return (FieldAlignPath) original.copy(); } - + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignVertex.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignVertex.java index e0199b316..951730f58 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignVertex.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/FieldAlignVertex.java @@ -19,12 +19,9 @@ package org.apache.geaflow.dsl.runtime.traversal.data; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.Serializer; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Arrays; import java.util.Objects; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -34,148 +31,154 @@ import org.apache.geaflow.dsl.common.data.impl.types.LongVertex; import org.apache.geaflow.model.graph.vertex.IVertex; -public class FieldAlignVertex implements RowVertex { - - private final RowVertex baseVertex; - - private final int[] fieldMapping; - - public FieldAlignVertex(RowVertex baseVertex, int[] fieldMapping) { - this.baseVertex = baseVertex; - this.fieldMapping = fieldMapping; - } - - @Override - public Object getField(int i, IType type) { - int mappingIndex = fieldMapping[i]; - if (mappingIndex < 0) { - return null; - } - return baseVertex.getField(mappingIndex, type); - } - - @Override - public void setValue(Row value) { - baseVertex.setValue(value); - } - - @Override - public String getLabel() { - return baseVertex.getLabel(); - } - - @Override - public void setLabel(String label) { - baseVertex.setLabel(label); - } - - @Override - public Object getId() { - return baseVertex.getId(); - } - - @Override - public void setId(Object id) { - baseVertex.setId(id); - } - - @Override - public Row getValue() { - return baseVertex.getValue(); - } - - @Override - public IVertex withValue(Row value) { - return new FieldAlignVertex((RowVertex) baseVertex.withValue(value), fieldMapping); - } - - @Override - public IVertex withLabel(String label) { - return new FieldAlignVertex((RowVertex) baseVertex.withLabel(label), fieldMapping); - } - - @Override - public IVertex withTime(long time) { - return new FieldAlignVertex((RowVertex) baseVertex.withTime(time), fieldMapping); - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; - @Override - public int compareTo(Object o) { - return baseVertex.compareTo(o); - } +public class FieldAlignVertex implements RowVertex { - @Override - public boolean equals(Object o) { - if (this == o) { - return true; + private final RowVertex baseVertex; + + private final int[] fieldMapping; + + public FieldAlignVertex(RowVertex baseVertex, int[] fieldMapping) { + this.baseVertex = baseVertex; + this.fieldMapping = fieldMapping; + } + + @Override + public Object getField(int i, IType type) { + int mappingIndex = fieldMapping[i]; + if (mappingIndex < 0) { + return null; + } + return baseVertex.getField(mappingIndex, type); + } + + @Override + public void setValue(Row value) { + baseVertex.setValue(value); + } + + @Override + public String getLabel() { + return baseVertex.getLabel(); + } + + @Override + public void setLabel(String label) { + baseVertex.setLabel(label); + } + + @Override + public Object getId() { + return baseVertex.getId(); + } + + @Override + public void setId(Object id) { + baseVertex.setId(id); + } + + @Override + public Row getValue() { + return baseVertex.getValue(); + } + + @Override + public IVertex withValue(Row value) { + return new FieldAlignVertex((RowVertex) baseVertex.withValue(value), fieldMapping); + } + + @Override + public IVertex withLabel(String label) { + return new FieldAlignVertex((RowVertex) baseVertex.withLabel(label), fieldMapping); + } + + @Override + public IVertex withTime(long time) { + return new FieldAlignVertex((RowVertex) baseVertex.withTime(time), fieldMapping); + } + + @Override + public int compareTo(Object o) { + return baseVertex.compareTo(o); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof RowVertex)) { + return false; + } + RowVertex that = (RowVertex) o; + return getId().equals(that.getId()) && Objects.equals(getBinaryLabel(), that.getBinaryLabel()); + } + + @Override + public int hashCode() { + return Objects.hash(getId(), getBinaryLabel()); + } + + @Override + public BinaryString getBinaryLabel() { + return baseVertex.getBinaryLabel(); + } + + @Override + public void setBinaryLabel(BinaryString label) { + baseVertex.setBinaryLabel(label); + } + + @Override + public String toString() { + return getId() + "#" + getBinaryLabel() + "#" + getValue(); + } + + public static RowVertex createFieldAlignedVertex(RowVertex baseVertex, int[] fieldMapping) { + if (baseVertex instanceof LongVertex) { + return new FieldAlignLongVertex((LongVertex) baseVertex, fieldMapping); + } else if (baseVertex instanceof IntVertex) { + return new FieldAlignIntVertex((IntVertex) baseVertex, fieldMapping); + } else if (baseVertex instanceof DoubleVertex) { + return new FieldAlignDoubleVertex((DoubleVertex) baseVertex, fieldMapping); + } + return new FieldAlignVertex(baseVertex, fieldMapping); + } + + public static class FieldAlignVertexSerializer extends Serializer { + + @Override + public void write(Kryo kryo, Output output, FieldAlignVertex object) { + kryo.writeClassAndObject(output, object.baseVertex); + if (object.fieldMapping != null) { + output.writeInt(object.fieldMapping.length, true); + for (int i : object.fieldMapping) { + output.writeInt(i); } - if (!(o instanceof RowVertex)) { - return false; - } - RowVertex that = (RowVertex) o; - return getId().equals(that.getId()) && Objects.equals(getBinaryLabel(), that.getBinaryLabel()); + } else { + output.writeInt(0, true); + } } @Override - public int hashCode() { - return Objects.hash(getId(), getBinaryLabel()); + public FieldAlignVertex read(Kryo kryo, Input input, Class type) { + RowVertex baseVertex = (RowVertex) kryo.readClassAndObject(input); + int[] fieldMapping = new int[input.readInt(true)]; + for (int i = 0; i < fieldMapping.length; i++) { + fieldMapping[i] = input.readInt(); + } + return new FieldAlignVertex(baseVertex, fieldMapping); } @Override - public BinaryString getBinaryLabel() { - return baseVertex.getBinaryLabel(); + public FieldAlignVertex copy(Kryo kryo, FieldAlignVertex original) { + return new FieldAlignVertex( + kryo.copy(original.baseVertex), + Arrays.copyOf(original.fieldMapping, original.fieldMapping.length)); } - - @Override - public void setBinaryLabel(BinaryString label) { - baseVertex.setBinaryLabel(label); - } - - @Override - public String toString() { - return getId() + "#" + getBinaryLabel() + "#" + getValue(); - } - - public static RowVertex createFieldAlignedVertex(RowVertex baseVertex, int[] fieldMapping) { - if (baseVertex instanceof LongVertex) { - return new FieldAlignLongVertex((LongVertex) baseVertex, fieldMapping); - } else if (baseVertex instanceof IntVertex) { - return new FieldAlignIntVertex((IntVertex) baseVertex, fieldMapping); - } else if (baseVertex instanceof DoubleVertex) { - return new FieldAlignDoubleVertex((DoubleVertex) baseVertex, fieldMapping); - } - return new FieldAlignVertex(baseVertex, fieldMapping); - } - - public static class FieldAlignVertexSerializer extends Serializer { - - @Override - public void write(Kryo kryo, Output output, FieldAlignVertex object) { - kryo.writeClassAndObject(output, object.baseVertex); - if (object.fieldMapping != null) { - output.writeInt(object.fieldMapping.length, true); - for (int i : object.fieldMapping) { - output.writeInt(i); - } - } else { - output.writeInt(0, true); - } - } - - @Override - public FieldAlignVertex read(Kryo kryo, Input input, Class type) { - RowVertex baseVertex = (RowVertex) kryo.readClassAndObject(input); - int[] fieldMapping = new int[input.readInt(true)]; - for (int i = 0; i < fieldMapping.length; i++) { - fieldMapping[i] = input.readInt(); - } - return new FieldAlignVertex(baseVertex, fieldMapping); - } - - @Override - public FieldAlignVertex copy(Kryo kryo, FieldAlignVertex original) { - return new FieldAlignVertex(kryo.copy(original.baseVertex), Arrays.copyOf(original.fieldMapping, original.fieldMapping.length)); - } - } - + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/GlobalVariable.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/GlobalVariable.java index 8e1210470..edb3f7ab7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/GlobalVariable.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/GlobalVariable.java @@ -23,37 +23,37 @@ public class GlobalVariable { - private final String name; + private final String name; - private final int index; + private final int index; - private final IType type; + private final IType type; - private int addFieldIndex = -1; + private int addFieldIndex = -1; - public GlobalVariable(String name, int index, IType type) { - this.name = name; - this.index = index; - this.type = type; - } + public GlobalVariable(String name, int index, IType type) { + this.name = name; + this.index = index; + this.type = type; + } - public String getName() { - return name; - } + public String getName() { + return name; + } - public int getIndex() { - return index; - } + public int getIndex() { + return index; + } - public IType getType() { - return type; - } + public IType getType() { + return type; + } - public int getAddFieldIndex() { - return addFieldIndex; - } + public int getAddFieldIndex() { + return addFieldIndex; + } - public void setAddFieldIndex(int addFieldIndex) { - this.addFieldIndex = addFieldIndex; - } + public void setAddFieldIndex(int addFieldIndex) { + this.addFieldIndex = addFieldIndex; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/IdOnlyRequest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/IdOnlyRequest.java index cd39facbd..ccb478990 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/IdOnlyRequest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/IdOnlyRequest.java @@ -24,31 +24,29 @@ public class IdOnlyRequest implements ITraversalRequest { - private final Object vertexId; - - public IdOnlyRequest(Object vertexId) { - this.vertexId = vertexId; - } - - @Override - public long getRequestId() { - return -1; - } - - @Override - public Object getVId() { - return vertexId; - } - - @Override - public RequestType getType() { - return RequestType.Vertex; - } - - @Override - public String toString() { - return "IdOnlyRequest{" - + "vertexId=" + vertexId - + '}'; - } + private final Object vertexId; + + public IdOnlyRequest(Object vertexId) { + this.vertexId = vertexId; + } + + @Override + public long getRequestId() { + return -1; + } + + @Override + public Object getVId() { + return vertexId; + } + + @Override + public RequestType getType() { + return RequestType.Vertex; + } + + @Override + public String toString() { + return "IdOnlyRequest{" + "vertexId=" + vertexId + '}'; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/IdOnlyVertex.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/IdOnlyVertex.java index 9882b9385..5c77bb890 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/IdOnlyVertex.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/IdOnlyVertex.java @@ -19,10 +19,6 @@ package org.apache.geaflow.dsl.runtime.traversal.data; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; @@ -30,102 +26,105 @@ import org.apache.geaflow.dsl.common.types.VertexType; import org.apache.geaflow.model.graph.vertex.IVertex; -public class IdOnlyVertex implements RowVertex, KryoSerializable { - - private Object id; - - private IdOnlyVertex() { - - } - - private IdOnlyVertex(Object id) { - this.id = id; - } - - public static IdOnlyVertex of(Object id) { - return new IdOnlyVertex(id); - } - - @Override - public Object getField(int i, IType type) { - if (i == VertexType.ID_FIELD_POSITION) { - return id; - } - throw new IllegalArgumentException("Index out of range: " + i); - } - - @Override - public void setValue(Row value) { - throw new IllegalArgumentException("Illegal call on setValue"); - } - - @Override - public String getLabel() { - throw new IllegalArgumentException("Illegal call on getLabel"); - } - - @Override - public void setLabel(String label) { - throw new IllegalArgumentException("Illegal call on setLabel"); - } - - @Override - public Object getId() { - return id; - } - - @Override - public void setId(Object id) { - this.id = id; - } - - @Override - public Row getValue() { - return Row.EMPTY; - } - - @Override - public IVertex withValue(Row value) { - throw new IllegalArgumentException("Illegal call on withValue"); - } - - @Override - public IVertex withLabel(String label) { - throw new IllegalArgumentException("Illegal call on withLabel"); - } - - @Override - public IVertex withTime(long time) { - throw new IllegalArgumentException("Illegal call on withTime"); - } - - @Override - public int compareTo(Object o) { - return 0; - } - - @Override - public String toString() { - return String.valueOf(id); - } - - @Override - public BinaryString getBinaryLabel() { - throw new IllegalArgumentException("Illegal call on getBinaryLabel"); - } - - @Override - public void setBinaryLabel(BinaryString label) { - throw new IllegalArgumentException("Illegal call on setBinaryLabel"); - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; - @Override - public void write(Kryo kryo, Output output) { - kryo.writeClassAndObject(output, this.id); - } +public class IdOnlyVertex implements RowVertex, KryoSerializable { - @Override - public void read(Kryo kryo, Input input) { - this.setId(kryo.readClassAndObject(input)); - } + private Object id; + + private IdOnlyVertex() {} + + private IdOnlyVertex(Object id) { + this.id = id; + } + + public static IdOnlyVertex of(Object id) { + return new IdOnlyVertex(id); + } + + @Override + public Object getField(int i, IType type) { + if (i == VertexType.ID_FIELD_POSITION) { + return id; + } + throw new IllegalArgumentException("Index out of range: " + i); + } + + @Override + public void setValue(Row value) { + throw new IllegalArgumentException("Illegal call on setValue"); + } + + @Override + public String getLabel() { + throw new IllegalArgumentException("Illegal call on getLabel"); + } + + @Override + public void setLabel(String label) { + throw new IllegalArgumentException("Illegal call on setLabel"); + } + + @Override + public Object getId() { + return id; + } + + @Override + public void setId(Object id) { + this.id = id; + } + + @Override + public Row getValue() { + return Row.EMPTY; + } + + @Override + public IVertex withValue(Row value) { + throw new IllegalArgumentException("Illegal call on withValue"); + } + + @Override + public IVertex withLabel(String label) { + throw new IllegalArgumentException("Illegal call on withLabel"); + } + + @Override + public IVertex withTime(long time) { + throw new IllegalArgumentException("Illegal call on withTime"); + } + + @Override + public int compareTo(Object o) { + return 0; + } + + @Override + public String toString() { + return String.valueOf(id); + } + + @Override + public BinaryString getBinaryLabel() { + throw new IllegalArgumentException("Illegal call on getBinaryLabel"); + } + + @Override + public void setBinaryLabel(BinaryString label) { + throw new IllegalArgumentException("Illegal call on setBinaryLabel"); + } + + @Override + public void write(Kryo kryo, Output output) { + kryo.writeClassAndObject(output, this.id); + } + + @Override + public void read(Kryo kryo, Input input) { + this.setId(kryo.readClassAndObject(input)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/InitParameterRequest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/InitParameterRequest.java index d7a19b7b6..77e77d457 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/InitParameterRequest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/InitParameterRequest.java @@ -25,43 +25,46 @@ public class InitParameterRequest implements ITraversalRequest { - private final long requestId; + private final long requestId; - private final Object vertexId; + private final Object vertexId; - private final Row parameters; + private final Row parameters; - public InitParameterRequest(long requestId, Object vertexId, Row parameters) { - this.requestId = requestId; - this.vertexId = vertexId; - this.parameters = parameters; - } + public InitParameterRequest(long requestId, Object vertexId, Row parameters) { + this.requestId = requestId; + this.vertexId = vertexId; + this.parameters = parameters; + } - @Override - public long getRequestId() { - return requestId; - } + @Override + public long getRequestId() { + return requestId; + } - @Override - public Object getVId() { - return vertexId; - } + @Override + public Object getVId() { + return vertexId; + } - @Override - public RequestType getType() { - return RequestType.Vertex; - } + @Override + public RequestType getType() { + return RequestType.Vertex; + } - public Row getParameters() { - return parameters; - } + public Row getParameters() { + return parameters; + } - @Override - public String toString() { - return "InitParameterRequest{" - + "requestId=" + requestId - + ", vertexId=" + vertexId - + ", parameters=" + parameters - + '}'; - } + @Override + public String toString() { + return "InitParameterRequest{" + + "requestId=" + + requestId + + ", vertexId=" + + vertexId + + ", parameters=" + + parameters + + '}'; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/ObjectSingleValue.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/ObjectSingleValue.java index f9606419f..4e37daf7a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/ObjectSingleValue.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/ObjectSingleValue.java @@ -23,23 +23,23 @@ public class ObjectSingleValue implements SingleValue { - private final Object value; + private final Object value; - private ObjectSingleValue(Object value) { - this.value = value; - } + private ObjectSingleValue(Object value) { + this.value = value; + } - public static ObjectSingleValue of(Object value) { - return new ObjectSingleValue(value); - } + public static ObjectSingleValue of(Object value) { + return new ObjectSingleValue(value); + } - @Override - public Object getValue(IType type) { - return value; - } + @Override + public Object getValue(IType type) { + return value; + } - @Override - public String toString() { - return String.valueOf(value); - } + @Override + public String toString() { + return String.valueOf(value); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/ParameterRequest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/ParameterRequest.java index ce7f54ee1..99f7f64bb 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/ParameterRequest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/ParameterRequest.java @@ -21,57 +21,61 @@ import java.io.Serializable; import java.util.Objects; + import org.apache.geaflow.dsl.common.data.Row; public class ParameterRequest implements Serializable { - private final Object requestId; + private final Object requestId; - private final Object vertexId; + private final Object vertexId; - private final Row parameters; + private final Row parameters; - public ParameterRequest(Object requestId, Object vertexId, Row parameters) { - this.requestId = requestId; - this.vertexId = vertexId; - this.parameters = parameters; - } + public ParameterRequest(Object requestId, Object vertexId, Row parameters) { + this.requestId = requestId; + this.vertexId = vertexId; + this.parameters = parameters; + } - public Object getRequestId() { - return requestId; - } + public Object getRequestId() { + return requestId; + } - public Object getVertexId() { - return vertexId; - } + public Object getVertexId() { + return vertexId; + } - public Row getParameters() { - return parameters; - } + public Row getParameters() { + return parameters; + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof ParameterRequest)) { - return false; - } - ParameterRequest that = (ParameterRequest) o; - return Objects.equals(requestId, that.requestId); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(requestId); + if (!(o instanceof ParameterRequest)) { + return false; } + ParameterRequest that = (ParameterRequest) o; + return Objects.equals(requestId, that.requestId); + } - @Override - public String toString() { - return "ParameterRequest{" - + "requestId=" + requestId - + ", vertexId=" + vertexId - + ", parameters=" + parameters - + '}'; - } + @Override + public int hashCode() { + return Objects.hash(requestId); + } + + @Override + public String toString() { + return "ParameterRequest{" + + "requestId=" + + requestId + + ", vertexId=" + + vertexId + + ", parameters=" + + parameters + + '}'; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/SingleValue.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/SingleValue.java index d37836995..aee0131f1 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/SingleValue.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/SingleValue.java @@ -26,9 +26,9 @@ public interface SingleValue extends StepRecord { - default StepRecordType getType() { - return SINGLE_VALUE; - } + default StepRecordType getType() { + return SINGLE_VALUE; + } - Object getValue(IType type); + Object getValue(IType type); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/StepKeyRecord.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/StepKeyRecord.java index 3da84436e..19cfea069 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/StepKeyRecord.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/StepKeyRecord.java @@ -25,11 +25,11 @@ public interface StepKeyRecord extends StepRecord { - RowKey getKey(); + RowKey getKey(); - Row getValue(); + Row getValue(); - default StepRecordType getType() { - return StepRecordType.KEY_RECORD; - } + default StepRecordType getType() { + return StepRecordType.KEY_RECORD; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/StepKeyRecordImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/StepKeyRecordImpl.java index 9ee9e2a37..c90f3b7d2 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/StepKeyRecordImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/StepKeyRecordImpl.java @@ -24,22 +24,22 @@ public class StepKeyRecordImpl implements StepKeyRecord { - private final RowKey key; + private final RowKey key; - private final Row value; + private final Row value; - public StepKeyRecordImpl(RowKey key, Row value) { - this.key = key; - this.value = value; - } + public StepKeyRecordImpl(RowKey key, Row value) { + this.key = key; + this.value = value; + } - @Override - public RowKey getKey() { - return key; - } + @Override + public RowKey getKey() { + return key; + } - @Override - public Row getValue() { - return value; - } + @Override + public Row getValue() { + return value; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/StepRecordWithPath.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/StepRecordWithPath.java index c4b7c1469..b467dc85b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/StepRecordWithPath.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/StepRecordWithPath.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.function.Function; + import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.StepRecord; import org.apache.geaflow.dsl.runtime.traversal.path.ITreePath; @@ -29,21 +30,21 @@ public interface StepRecordWithPath extends StepRecord { - ITreePath getPathById(Object vertexId); + ITreePath getPathById(Object vertexId); - Iterable getPaths(); + Iterable getPaths(); - Iterable getVertexIds(); + Iterable getVertexIds(); - StepRecordWithPath filter(PathFilterFunction function, int[] refPathIndices); + StepRecordWithPath filter(PathFilterFunction function, int[] refPathIndices); - StepRecordWithPath mapPath(PathMapFunction function, int[] refPathIndices); + StepRecordWithPath mapPath(PathMapFunction function, int[] refPathIndices); - StepRecordWithPath mapTreePath(Function function); + StepRecordWithPath mapTreePath(Function function); - List map(PathMapFunction function, int[] refPathIndices); + List map(PathMapFunction function, int[] refPathIndices); - StepRecordWithPath subPathSet(int[] pathIndices); + StepRecordWithPath subPathSet(int[] pathIndices); - boolean isPathEmpty(); + boolean isPathEmpty(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/TraversalAll.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/TraversalAll.java index 499cbc7e3..88966cbf2 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/TraversalAll.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/TraversalAll.java @@ -23,14 +23,12 @@ public class TraversalAll implements Serializable { - public static final TraversalAll INSTANCE = new TraversalAll(); + public static final TraversalAll INSTANCE = new TraversalAll(); - private TraversalAll() { + private TraversalAll() {} - } - - @Override - public boolean equals(Object o) { - return o instanceof TraversalAll; - } + @Override + public boolean equals(Object o) { + return o instanceof TraversalAll; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/VertexRecord.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/VertexRecord.java index bcf428661..41de1d272 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/VertexRecord.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/data/VertexRecord.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Objects; import java.util.function.Function; + import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.RowVertex; import org.apache.geaflow.dsl.runtime.traversal.path.EmptyTreePath; @@ -32,81 +33,81 @@ public class VertexRecord implements StepRecordWithPath { - private final RowVertex vertex; - - private final ITreePath treePath; - - private VertexRecord(RowVertex vertex, ITreePath treePath) { - this.vertex = vertex; - this.treePath = treePath == null ? EmptyTreePath.INSTANCE : treePath; - } - - public static VertexRecord of(RowVertex vertex, ITreePath treePath) { - return new VertexRecord(vertex, treePath); - } - - public RowVertex getVertex() { - return vertex; - } - - public ITreePath getTreePath() { - return treePath; - } - - @Override - public StepRecordType getType() { - return StepRecordType.VERTEX; - } + private final RowVertex vertex; - @Override - public ITreePath getPathById(Object vertexId) { - if (Objects.equals(vertexId, vertex.getId())) { - return treePath; - } - return null; - } + private final ITreePath treePath; - @Override - public Iterable getPaths() { - return Collections.singletonList(treePath); - } + private VertexRecord(RowVertex vertex, ITreePath treePath) { + this.vertex = vertex; + this.treePath = treePath == null ? EmptyTreePath.INSTANCE : treePath; + } - @Override - public Iterable getVertexIds() { - return Collections.singletonList(vertex.getId()); - } + public static VertexRecord of(RowVertex vertex, ITreePath treePath) { + return new VertexRecord(vertex, treePath); + } - @Override - public StepRecordWithPath filter(PathFilterFunction function, int[] refPathIndices) { - ITreePath filterTreePath = treePath.filter(function, refPathIndices); - return new VertexRecord(vertex, filterTreePath); - } + public RowVertex getVertex() { + return vertex; + } - @Override - public StepRecordWithPath mapPath(PathMapFunction function, int[] refPathIndices) { - ITreePath mapTreePath = treePath.mapTree(function); - return new VertexRecord(vertex, mapTreePath); - } + public ITreePath getTreePath() { + return treePath; + } - @Override - public StepRecordWithPath mapTreePath(Function function) { - ITreePath mapTreePath = function.apply(treePath); - return new VertexRecord(vertex, mapTreePath); - } - - @Override - public List map(PathMapFunction function, int[] refPathIndices) { - return treePath.map(function); - } - - @Override - public StepRecordWithPath subPathSet(int[] pathIndices) { - ITreePath subTreePath = treePath.subPath(pathIndices); - return new VertexRecord(vertex, subTreePath); - } + @Override + public StepRecordType getType() { + return StepRecordType.VERTEX; + } - @Override - public boolean isPathEmpty() { - return treePath.isEmpty(); + @Override + public ITreePath getPathById(Object vertexId) { + if (Objects.equals(vertexId, vertex.getId())) { + return treePath; } + return null; + } + + @Override + public Iterable getPaths() { + return Collections.singletonList(treePath); + } + + @Override + public Iterable getVertexIds() { + return Collections.singletonList(vertex.getId()); + } + + @Override + public StepRecordWithPath filter(PathFilterFunction function, int[] refPathIndices) { + ITreePath filterTreePath = treePath.filter(function, refPathIndices); + return new VertexRecord(vertex, filterTreePath); + } + + @Override + public StepRecordWithPath mapPath(PathMapFunction function, int[] refPathIndices) { + ITreePath mapTreePath = treePath.mapTree(function); + return new VertexRecord(vertex, mapTreePath); + } + + @Override + public StepRecordWithPath mapTreePath(Function function) { + ITreePath mapTreePath = function.apply(treePath); + return new VertexRecord(vertex, mapTreePath); + } + + @Override + public List map(PathMapFunction function, int[] refPathIndices) { + return treePath.map(function); + } + + @Override + public StepRecordWithPath subPathSet(int[] pathIndices) { + ITreePath subTreePath = treePath.subPath(pathIndices); + return new VertexRecord(vertex, subTreePath); + } + + @Override + public boolean isPathEmpty() { + return treePath.isEmpty(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/EODMessage.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/EODMessage.java index ebd7d3177..5ec75ebe7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/EODMessage.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/EODMessage.java @@ -19,73 +19,74 @@ package org.apache.geaflow.dsl.runtime.traversal.message; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import org.apache.geaflow.dsl.runtime.traversal.data.EndOfData; + import com.esotericsoftware.kryo.Kryo; import com.esotericsoftware.kryo.Serializer; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; import com.google.common.collect.Lists; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; -import org.apache.geaflow.dsl.runtime.traversal.data.EndOfData; public class EODMessage implements IMessage { - private final List endOfDatas; + private final List endOfDatas; - private EODMessage() { - this(new ArrayList<>()); - } + private EODMessage() { + this(new ArrayList<>()); + } - private EODMessage(List endOfDatas) { - this.endOfDatas = Objects.requireNonNull(endOfDatas); - } + private EODMessage(List endOfDatas) { + this.endOfDatas = Objects.requireNonNull(endOfDatas); + } - public static EODMessage of(EndOfData data) { - return new EODMessage(Lists.newArrayList(data)); - } + public static EODMessage of(EndOfData data) { + return new EODMessage(Lists.newArrayList(data)); + } - @Override - public MessageType getType() { - return MessageType.EOD; - } + @Override + public MessageType getType() { + return MessageType.EOD; + } - @Override - public IMessage combine(IMessage other) { - EODMessage newMessage = this.copy(); - EODMessage otherEod = (EODMessage) other; - newMessage.endOfDatas.addAll(otherEod.endOfDatas); - return newMessage; - } + @Override + public IMessage combine(IMessage other) { + EODMessage newMessage = this.copy(); + EODMessage otherEod = (EODMessage) other; + newMessage.endOfDatas.addAll(otherEod.endOfDatas); + return newMessage; + } + + @Override + public EODMessage copy() { + return new EODMessage(new ArrayList<>(endOfDatas)); + } + + public List getEodData() { + return endOfDatas; + } + + public static class EODMessageSerializer extends Serializer { @Override - public EODMessage copy() { - return new EODMessage(new ArrayList<>(endOfDatas)); + public void write(Kryo kryo, Output output, EODMessage object) { + // serialize endOfDatas using default CollectionSerializer + kryo.writeObject(output, object.getEodData()); } - public List getEodData() { - return endOfDatas; + @Override + public EODMessage read(Kryo kryo, Input input, Class type) { + // deserialize endOfDatas using default CollectionSerializer + List endOfDatas = kryo.readObject(input, ArrayList.class); + return new EODMessage(endOfDatas); } - public static class EODMessageSerializer extends Serializer { - - @Override - public void write(Kryo kryo, Output output, EODMessage object) { - // serialize endOfDatas using default CollectionSerializer - kryo.writeObject(output, object.getEodData()); - } - - @Override - public EODMessage read(Kryo kryo, Input input, Class type) { - // deserialize endOfDatas using default CollectionSerializer - List endOfDatas = kryo.readObject(input, ArrayList.class); - return new EODMessage(endOfDatas); - } - - @Override - public EODMessage copy(Kryo kryo, EODMessage original) { - return new EODMessage(new ArrayList<>(original.getEodData())); - } + @Override + public EODMessage copy(Kryo kryo, EODMessage original) { + return new EODMessage(new ArrayList<>(original.getEodData())); } - + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/EODMessageBox.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/EODMessageBox.java index f030fe882..fe093b990 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/EODMessageBox.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/EODMessageBox.java @@ -21,16 +21,15 @@ public class EODMessageBox extends SingleMessageBox { - public EODMessageBox() { - } + public EODMessageBox() {} - @Override - public MessageType getMessageType() { - return MessageType.EOD; - } + @Override + public MessageType getMessageType() { + return MessageType.EOD; + } - @Override - protected SingleMessageBox create() { - return new EODMessageBox(); - } + @Override + protected SingleMessageBox create() { + return new EODMessageBox(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/EvolveVertexMessage.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/EvolveVertexMessage.java index fc42b2c1b..d214eb4ed 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/EvolveVertexMessage.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/EvolveVertexMessage.java @@ -21,23 +21,20 @@ public class EvolveVertexMessage implements IMessage { - public EvolveVertexMessage() { - - } - - @Override - public MessageType getType() { - return MessageType.EVOLVE_VERTEX; - } - - @Override - public IMessage combine(IMessage other) { - return new EvolveVertexMessage(); - } - - @Override - public EvolveVertexMessage copy() { - return new EvolveVertexMessage(); - } - + public EvolveVertexMessage() {} + + @Override + public MessageType getType() { + return MessageType.EVOLVE_VERTEX; + } + + @Override + public IMessage combine(IMessage other) { + return new EvolveVertexMessage(); + } + + @Override + public EvolveVertexMessage copy() { + return new EvolveVertexMessage(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/EvolveVertexMessageBox.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/EvolveVertexMessageBox.java index ca56de8b2..c4c7245b0 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/EvolveVertexMessageBox.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/EvolveVertexMessageBox.java @@ -21,13 +21,13 @@ public class EvolveVertexMessageBox extends SingleMessageBox { - @Override - public MessageType getMessageType() { - return MessageType.EVOLVE_VERTEX; - } + @Override + public MessageType getMessageType() { + return MessageType.EVOLVE_VERTEX; + } - @Override - protected SingleMessageBox create() { - return new EvolveVertexMessageBox(); - } + @Override + protected SingleMessageBox create() { + return new EvolveVertexMessageBox(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/IMessage.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/IMessage.java index 470726291..03f3ea26f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/IMessage.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/IMessage.java @@ -23,10 +23,9 @@ public interface IMessage extends Serializable { - MessageType getType(); + MessageType getType(); - IMessage combine(IMessage other); - - IMessage copy(); + IMessage combine(IMessage other); + IMessage copy(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/IPathMessage.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/IPathMessage.java index f96446925..d566abb21 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/IPathMessage.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/IPathMessage.java @@ -19,6 +19,4 @@ package org.apache.geaflow.dsl.runtime.traversal.message; -public interface IPathMessage extends RequestIsolationMessage { - -} +public interface IPathMessage extends RequestIsolationMessage {} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/IPathMessageBox.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/IPathMessageBox.java index 905cfdec4..43d3d41ce 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/IPathMessageBox.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/IPathMessageBox.java @@ -23,5 +23,5 @@ public interface IPathMessageBox extends MessageBox { - ITreePath[] getPathMessages(); + ITreePath[] getPathMessages(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ITraversalAgg.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ITraversalAgg.java index 2f48b5f17..fbb919b60 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ITraversalAgg.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ITraversalAgg.java @@ -19,6 +19,4 @@ package org.apache.geaflow.dsl.runtime.traversal.message; -public interface ITraversalAgg { - -} +public interface ITraversalAgg {} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/JoinPathMessage.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/JoinPathMessage.java index 90b052b94..a77e6d17e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/JoinPathMessage.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/JoinPathMessage.java @@ -19,100 +19,102 @@ package org.apache.geaflow.dsl.runtime.traversal.message; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.Serializer; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.HashMap; import java.util.Map; import java.util.Set; + import org.apache.geaflow.dsl.runtime.traversal.path.ITreePath; -public class JoinPathMessage implements IPathMessage { +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; - private final Map senderId2Paths; +public class JoinPathMessage implements IPathMessage { - public JoinPathMessage(Map senderId2Paths) { - this.senderId2Paths = senderId2Paths; + private final Map senderId2Paths; + + public JoinPathMessage(Map senderId2Paths) { + this.senderId2Paths = senderId2Paths; + } + + public JoinPathMessage() { + this(new HashMap<>()); + } + + public static JoinPathMessage from(long senderId, ITreePath treePath) { + Map senderId2Paths = new HashMap<>(); + senderId2Paths.put(senderId, treePath); + return new JoinPathMessage(senderId2Paths); + } + + @Override + public MessageType getType() { + return MessageType.JOIN_PATH; + } + + @Override + public IMessage combine(IMessage other) { + JoinPathMessage combinedTreePath = this.copy(); + JoinPathMessage otherTreePath = (JoinPathMessage) other; + + for (Map.Entry entry : otherTreePath.senderId2Paths.entrySet()) { + long senderId = entry.getKey(); + ITreePath treePath = entry.getValue(); + if (combinedTreePath.senderId2Paths.containsKey(senderId)) { + ITreePath mergeTree = combinedTreePath.senderId2Paths.get(senderId).merge(treePath); + combinedTreePath.senderId2Paths.put(senderId, mergeTree); + } else { + combinedTreePath.senderId2Paths.put(senderId, treePath); + } } - - public JoinPathMessage() { - this(new HashMap<>()); + return combinedTreePath; + } + + @Override + public JoinPathMessage copy() { + return new JoinPathMessage(new HashMap<>(senderId2Paths)); + } + + @Override + public IMessage getMessageByRequestId(Object requestId) { + Map requestTreePaths = new HashMap<>(senderId2Paths.size()); + for (Map.Entry entry : senderId2Paths.entrySet()) { + long senderId = entry.getKey(); + ITreePath treePath = (ITreePath) entry.getValue().getMessageByRequestId(requestId); + requestTreePaths.put(senderId, treePath); } + return new JoinPathMessage(requestTreePaths); + } - public static JoinPathMessage from(long senderId, ITreePath treePath) { - Map senderId2Paths = new HashMap<>(); - senderId2Paths.put(senderId, treePath); - return new JoinPathMessage(senderId2Paths); - } + public ITreePath getTreePath(long senderId) { + return senderId2Paths.get(senderId); + } - @Override - public MessageType getType() { - return MessageType.JOIN_PATH; - } + public boolean isEmpty() { + return senderId2Paths.isEmpty(); + } - @Override - public IMessage combine(IMessage other) { - JoinPathMessage combinedTreePath = this.copy(); - JoinPathMessage otherTreePath = (JoinPathMessage) other; - - for (Map.Entry entry : otherTreePath.senderId2Paths.entrySet()) { - long senderId = entry.getKey(); - ITreePath treePath = entry.getValue(); - if (combinedTreePath.senderId2Paths.containsKey(senderId)) { - ITreePath mergeTree = combinedTreePath.senderId2Paths.get(senderId).merge(treePath); - combinedTreePath.senderId2Paths.put(senderId, mergeTree); - } else { - combinedTreePath.senderId2Paths.put(senderId, treePath); - } - } - return combinedTreePath; - } + public Set getSenders() { + return senderId2Paths.keySet(); + } - @Override - public JoinPathMessage copy() { - return new JoinPathMessage(new HashMap<>(senderId2Paths)); - } + public static class JoinPathMessageSerializer extends Serializer { @Override - public IMessage getMessageByRequestId(Object requestId) { - Map requestTreePaths = new HashMap<>(senderId2Paths.size()); - for (Map.Entry entry : senderId2Paths.entrySet()) { - long senderId = entry.getKey(); - ITreePath treePath = (ITreePath) entry.getValue().getMessageByRequestId(requestId); - requestTreePaths.put(senderId, treePath); - } - return new JoinPathMessage(requestTreePaths); - } - - public ITreePath getTreePath(long senderId) { - return senderId2Paths.get(senderId); + public void write(Kryo kryo, Output output, JoinPathMessage object) { + kryo.writeClassAndObject(output, object.senderId2Paths); } - public boolean isEmpty() { - return senderId2Paths.isEmpty(); - } - - public Set getSenders() { - return senderId2Paths.keySet(); + @Override + public JoinPathMessage read(Kryo kryo, Input input, Class type) { + Map senderId2Paths = (Map) kryo.readClassAndObject(input); + return new JoinPathMessage(senderId2Paths); } - public static class JoinPathMessageSerializer extends Serializer { - - @Override - public void write(Kryo kryo, Output output, JoinPathMessage object) { - kryo.writeClassAndObject(output, object.senderId2Paths); - } - - @Override - public JoinPathMessage read(Kryo kryo, Input input, Class type) { - Map senderId2Paths = (Map) kryo.readClassAndObject(input); - return new JoinPathMessage(senderId2Paths); - } - - @Override - public JoinPathMessage copy(Kryo kryo, JoinPathMessage original) { - return original.copy(); - } + @Override + public JoinPathMessage copy(Kryo kryo, JoinPathMessage original) { + return original.copy(); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/JoinPathMessageBox.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/JoinPathMessageBox.java index 2340d8db9..bd9c14d5e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/JoinPathMessageBox.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/JoinPathMessageBox.java @@ -21,13 +21,13 @@ public class JoinPathMessageBox extends SingleMessageBox { - @Override - public MessageType getMessageType() { - return MessageType.JOIN_PATH; - } + @Override + public MessageType getMessageType() { + return MessageType.JOIN_PATH; + } - @Override - protected SingleMessageBox create() { - return new JoinPathMessageBox(); - } + @Override + protected SingleMessageBox create() { + return new JoinPathMessageBox(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/KVTraversalAgg.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/KVTraversalAgg.java index d68b5955b..769f517b4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/KVTraversalAgg.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/KVTraversalAgg.java @@ -25,42 +25,42 @@ public class KVTraversalAgg implements ITraversalAgg { - private final Map map; + private final Map map; - public KVTraversalAgg() { - this.map = new HashMap<>(); - } + public KVTraversalAgg() { + this.map = new HashMap<>(); + } - public KVTraversalAgg(Map map) { - this.map = new HashMap<>(map); - } + public KVTraversalAgg(Map map) { + this.map = new HashMap<>(map); + } - public KVTraversalAgg(K key, V value) { - this.map = Collections.singletonMap(key, value); - } + public KVTraversalAgg(K key, V value) { + this.map = Collections.singletonMap(key, value); + } - public Map getMap() { - return map; - } + public Map getMap() { + return map; + } - public V get(K key) { - return this.map.get(key); - } + public V get(K key) { + return this.map.get(key); + } - public void clear() { - this.map.clear(); - } + public void clear() { + this.map.clear(); + } - public KVTraversalAgg copy() { - return new KVTraversalAgg<>(map); - } + public KVTraversalAgg copy() { + return new KVTraversalAgg<>(map); + } - public static KVTraversalAgg empty() { - return new KVTraversalAgg(); - } + public static KVTraversalAgg empty() { + return new KVTraversalAgg(); + } - @Override - public String toString() { - return "KVTraversalAgg{" + "map=" + map + '}'; - } + @Override + public String toString() { + return "KVTraversalAgg{" + "map=" + map + '}'; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/KeyGroupMessage.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/KeyGroupMessage.java index 4e3a613d1..6d18450f4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/KeyGroupMessage.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/KeyGroupMessage.java @@ -20,9 +20,10 @@ package org.apache.geaflow.dsl.runtime.traversal.message; import java.util.List; + import org.apache.geaflow.dsl.common.data.Row; public interface KeyGroupMessage extends IMessage { - List getGroupRows(); + List getGroupRows(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/KeyGroupMessageBox.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/KeyGroupMessageBox.java index 07489bd44..e460b0ecf 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/KeyGroupMessageBox.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/KeyGroupMessageBox.java @@ -21,13 +21,13 @@ public class KeyGroupMessageBox extends SingleMessageBox { - @Override - public MessageType getMessageType() { - return MessageType.KEY_GROUP; - } + @Override + public MessageType getMessageType() { + return MessageType.KEY_GROUP; + } - @Override - protected SingleMessageBox create() { - return new KeyGroupMessageBox(); - } + @Override + protected SingleMessageBox create() { + return new KeyGroupMessageBox(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/KeyGroupMessageImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/KeyGroupMessageImpl.java index 20a340d73..be745e004 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/KeyGroupMessageImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/KeyGroupMessageImpl.java @@ -19,63 +19,64 @@ package org.apache.geaflow.dsl.runtime.traversal.message; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.Serializer; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.ArrayList; import java.util.List; import java.util.Objects; + import org.apache.geaflow.dsl.common.data.Row; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + public class KeyGroupMessageImpl implements KeyGroupMessage { - private final List groupRows; + private final List groupRows; - public KeyGroupMessageImpl(List groupRows) { - this.groupRows = Objects.requireNonNull(groupRows); - } + public KeyGroupMessageImpl(List groupRows) { + this.groupRows = Objects.requireNonNull(groupRows); + } - @Override - public MessageType getType() { - return MessageType.KEY_GROUP; - } + @Override + public MessageType getType() { + return MessageType.KEY_GROUP; + } - @Override - public IMessage combine(IMessage other) { - List combineRows = new ArrayList<>(groupRows); - KeyGroupMessage groupMessage = (KeyGroupMessage) other; - combineRows.addAll(groupMessage.getGroupRows()); - return new KeyGroupMessageImpl(combineRows); - } + @Override + public IMessage combine(IMessage other) { + List combineRows = new ArrayList<>(groupRows); + KeyGroupMessage groupMessage = (KeyGroupMessage) other; + combineRows.addAll(groupMessage.getGroupRows()); + return new KeyGroupMessageImpl(combineRows); + } + + @Override + public KeyGroupMessage copy() { + return new KeyGroupMessageImpl(new ArrayList<>(groupRows)); + } + + @Override + public List getGroupRows() { + return groupRows; + } + + public static class KeyGroupMessageImplSerializer extends Serializer { @Override - public KeyGroupMessage copy() { - return new KeyGroupMessageImpl(new ArrayList<>(groupRows)); + public void write(Kryo kryo, Output output, KeyGroupMessageImpl object) { + kryo.writeObject(output, object.groupRows); } @Override - public List getGroupRows() { - return groupRows; + public KeyGroupMessageImpl read(Kryo kryo, Input input, Class type) { + List groupRows = kryo.readObject(input, ArrayList.class); + return new KeyGroupMessageImpl(groupRows); } - public static class KeyGroupMessageImplSerializer extends Serializer { - - @Override - public void write(Kryo kryo, Output output, KeyGroupMessageImpl object) { - kryo.writeObject(output, object.groupRows); - } - - @Override - public KeyGroupMessageImpl read(Kryo kryo, Input input, Class type) { - List groupRows = kryo.readObject(input, ArrayList.class); - return new KeyGroupMessageImpl(groupRows); - } - - @Override - public KeyGroupMessageImpl copy(Kryo kryo, KeyGroupMessageImpl original) { - return new KeyGroupMessageImpl(new ArrayList<>(original.getGroupRows())); - } + @Override + public KeyGroupMessageImpl copy(Kryo kryo, KeyGroupMessageImpl original) { + return new KeyGroupMessageImpl(new ArrayList<>(original.getGroupRows())); } - + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/MessageBox.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/MessageBox.java index 29847f416..d90e2127b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/MessageBox.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/MessageBox.java @@ -23,17 +23,17 @@ public interface MessageBox extends Serializable { - void addMessage(long receiverId, IMessage message); + void addMessage(long receiverId, IMessage message); - M getMessage(long receiverId, MessageType messageType); + M getMessage(long receiverId, MessageType messageType); - long[] getReceiverIds(); + long[] getReceiverIds(); - MessageBox combine(MessageBox other); + MessageBox combine(MessageBox other); - MessageBox copy(); + MessageBox copy(); - default boolean isEmpty() { - return getReceiverIds() == null || getReceiverIds().length == 0; - } + default boolean isEmpty() { + return getReceiverIds() == null || getReceiverIds().length == 0; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/MessageType.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/MessageType.java index 8912332ef..40026dce2 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/MessageType.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/MessageType.java @@ -23,57 +23,54 @@ import org.apache.geaflow.dsl.runtime.traversal.operator.StepJoinOperator; public enum MessageType { - /** - * The message type for {@link IPathMessage} which is the path message send to the vertex. - */ - PATH, - /** - * The message type for {@link EODMessage} which is end-of-data message sending to the next - * by the step operator when finished processing all the data. - */ - EOD, - /** - * The message type for {@link ParameterRequestMessage} which is the request message with the parameters. - */ - PARAMETER_REQUEST, - /** - * The path message type for {@link JoinPathMessage} which for the - * {@link StepJoinOperator}. - */ - JOIN_PATH, - /** - * The message type for {@link KeyGroupMessage} which is the message send to the - * {@link VirtualId} to do same relation operations - * e.g. aggregate. - */ - KEY_GROUP, - /** - * The message type for {@link ReturnMessage} which is the result returning from the sub-query calling. - */ - RETURN_VALUE, - /** - * The message type for {@link EvolveVertexMessage} which is the message to evolve vertices in incrMatch. - */ - EVOLVE_VERTEX; + /** The message type for {@link IPathMessage} which is the path message send to the vertex. */ + PATH, + /** + * The message type for {@link EODMessage} which is end-of-data message sending to the next by the + * step operator when finished processing all the data. + */ + EOD, + /** + * The message type for {@link ParameterRequestMessage} which is the request message with the + * parameters. + */ + PARAMETER_REQUEST, + /** The path message type for {@link JoinPathMessage} which for the {@link StepJoinOperator}. */ + JOIN_PATH, + /** + * The message type for {@link KeyGroupMessage} which is the message send to the {@link VirtualId} + * to do same relation operations e.g. aggregate. + */ + KEY_GROUP, + /** + * The message type for {@link ReturnMessage} which is the result returning from the sub-query + * calling. + */ + RETURN_VALUE, + /** + * The message type for {@link EvolveVertexMessage} which is the message to evolve vertices in + * incrMatch. + */ + EVOLVE_VERTEX; - public SingleMessageBox createMessageBox() { - switch (this) { - case PATH: - return new PathMessageBox(); - case EOD: - return new EODMessageBox(); - case PARAMETER_REQUEST: - return new ParameterRequestMessageBox(); - case JOIN_PATH: - return new JoinPathMessageBox(); - case KEY_GROUP: - return new KeyGroupMessageBox(); - case RETURN_VALUE: - return new ReturnMessageBox(); - case EVOLVE_VERTEX: - return new EvolveVertexMessageBox(); - default: - throw new IllegalArgumentException("Failing to create message box for: " + this); - } + public SingleMessageBox createMessageBox() { + switch (this) { + case PATH: + return new PathMessageBox(); + case EOD: + return new EODMessageBox(); + case PARAMETER_REQUEST: + return new ParameterRequestMessageBox(); + case JOIN_PATH: + return new JoinPathMessageBox(); + case KEY_GROUP: + return new KeyGroupMessageBox(); + case RETURN_VALUE: + return new ReturnMessageBox(); + case EVOLVE_VERTEX: + return new EvolveVertexMessageBox(); + default: + throw new IllegalArgumentException("Failing to create message box for: " + this); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ParameterRequestMessage.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ParameterRequestMessage.java index 3c07bd41b..ff76c26fd 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ParameterRequestMessage.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ParameterRequestMessage.java @@ -19,80 +19,83 @@ package org.apache.geaflow.dsl.runtime.traversal.message; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.Serializer; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.HashSet; import java.util.Set; import java.util.function.Consumer; + import org.apache.geaflow.dsl.runtime.traversal.data.ParameterRequest; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + public class ParameterRequestMessage implements IMessage { - private final Set requests; + private final Set requests; - public ParameterRequestMessage(Set requests) { - this.requests = requests; - } + public ParameterRequestMessage(Set requests) { + this.requests = requests; + } - public ParameterRequestMessage() { - this(new HashSet<>()); - } + public ParameterRequestMessage() { + this(new HashSet<>()); + } - @Override - public MessageType getType() { - return MessageType.PARAMETER_REQUEST; - } + @Override + public MessageType getType() { + return MessageType.PARAMETER_REQUEST; + } - public void addRequest(ParameterRequest request) { - requests.add(request); - } + public void addRequest(ParameterRequest request) { + requests.add(request); + } - public void forEach(Consumer consumer) { - requests.forEach(consumer); - } + public void forEach(Consumer consumer) { + requests.forEach(consumer); + } - @Override - public IMessage combine(IMessage other) { - - Set requests = new HashSet<>(); - if (this.requests != null) { - requests.addAll(this.requests); - } - Set thatRequests = ((ParameterRequestMessage) other).requests; - if (thatRequests != null) { - requests.addAll(thatRequests); - } - return new ParameterRequestMessage(requests); - } + @Override + public IMessage combine(IMessage other) { - @Override - public IMessage copy() { - return new ParameterRequestMessage(new HashSet<>(requests)); + Set requests = new HashSet<>(); + if (this.requests != null) { + requests.addAll(this.requests); } - - public boolean isEmpty() { - return requests.isEmpty(); + Set thatRequests = ((ParameterRequestMessage) other).requests; + if (thatRequests != null) { + requests.addAll(thatRequests); } + return new ParameterRequestMessage(requests); + } - public static class ParameterRequestMessageSerializer extends Serializer { + @Override + public IMessage copy() { + return new ParameterRequestMessage(new HashSet<>(requests)); + } - @Override - public void write(Kryo kryo, Output output, ParameterRequestMessage object) { - kryo.writeClassAndObject(output, object.requests); - } + public boolean isEmpty() { + return requests.isEmpty(); + } - @Override - public ParameterRequestMessage read(Kryo kryo, Input input, Class type) { - Set requests = (Set) kryo.readClassAndObject(input); - return new ParameterRequestMessage(requests); - } + public static class ParameterRequestMessageSerializer + extends Serializer { - @Override - public ParameterRequestMessage copy(Kryo kryo, ParameterRequestMessage original) { - return new ParameterRequestMessage(new HashSet<>(original.requests)); - } + @Override + public void write(Kryo kryo, Output output, ParameterRequestMessage object) { + kryo.writeClassAndObject(output, object.requests); + } + + @Override + public ParameterRequestMessage read( + Kryo kryo, Input input, Class type) { + Set requests = (Set) kryo.readClassAndObject(input); + return new ParameterRequestMessage(requests); } + @Override + public ParameterRequestMessage copy(Kryo kryo, ParameterRequestMessage original) { + return new ParameterRequestMessage(new HashSet<>(original.requests)); + } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ParameterRequestMessageBox.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ParameterRequestMessageBox.java index b99cf5156..b844da9df 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ParameterRequestMessageBox.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ParameterRequestMessageBox.java @@ -21,13 +21,13 @@ public class ParameterRequestMessageBox extends SingleMessageBox { - @Override - public MessageType getMessageType() { - return MessageType.PARAMETER_REQUEST; - } + @Override + public MessageType getMessageType() { + return MessageType.PARAMETER_REQUEST; + } - @Override - protected SingleMessageBox create() { - return new ParameterRequestMessageBox(); - } + @Override + protected SingleMessageBox create() { + return new ParameterRequestMessageBox(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/PathMessageBox.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/PathMessageBox.java index c81b7ca69..0cf115469 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/PathMessageBox.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/PathMessageBox.java @@ -23,23 +23,23 @@ public class PathMessageBox extends SingleMessageBox implements IPathMessageBox { - @Override - public MessageType getMessageType() { - return MessageType.PATH; - } + @Override + public MessageType getMessageType() { + return MessageType.PATH; + } - @Override - protected SingleMessageBox create() { - return new PathMessageBox(); - } + @Override + protected SingleMessageBox create() { + return new PathMessageBox(); + } - @Override - public ITreePath[] getPathMessages() { - if (messages == null) { - return null; - } - ITreePath[] paths = new ITreePath[messages.length]; - System.arraycopy(messages, 0, paths, 0, messages.length); - return paths; + @Override + public ITreePath[] getPathMessages() { + if (messages == null) { + return null; } + ITreePath[] paths = new ITreePath[messages.length]; + System.arraycopy(messages, 0, paths, 0, messages.length); + return paths; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/RequestIsolationMessage.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/RequestIsolationMessage.java index 699ea8a2d..8134f0295 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/RequestIsolationMessage.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/RequestIsolationMessage.java @@ -21,5 +21,5 @@ public interface RequestIsolationMessage extends IMessage { - IMessage getMessageByRequestId(Object requestId); + IMessage getMessageByRequestId(Object requestId); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ReturnMessage.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ReturnMessage.java index 61e70aa8e..1cdd13b4c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ReturnMessage.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ReturnMessage.java @@ -20,14 +20,15 @@ package org.apache.geaflow.dsl.runtime.traversal.message; import java.util.Map; + import org.apache.geaflow.dsl.runtime.traversal.data.SingleValue; import org.apache.geaflow.dsl.runtime.traversal.message.ReturnMessageImpl.ReturnKey; public interface ReturnMessage extends IMessage { - Map getReturnKey2Values(); + Map getReturnKey2Values(); - void putValue(ReturnKey returnKey, SingleValue value); + void putValue(ReturnKey returnKey, SingleValue value); - SingleValue getValue(ReturnKey returnKey); + SingleValue getValue(ReturnKey returnKey); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ReturnMessageBox.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ReturnMessageBox.java index d0c3bc6c5..f09fbff5b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ReturnMessageBox.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ReturnMessageBox.java @@ -21,13 +21,13 @@ public class ReturnMessageBox extends SingleMessageBox { - @Override - public MessageType getMessageType() { - return MessageType.RETURN_VALUE; - } + @Override + public MessageType getMessageType() { + return MessageType.RETURN_VALUE; + } - @Override - protected SingleMessageBox create() { - return new ReturnMessageBox(); - } + @Override + protected SingleMessageBox create() { + return new ReturnMessageBox(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ReturnMessageImpl.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ReturnMessageImpl.java index c972b8d6c..c8567c463 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ReturnMessageImpl.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/ReturnMessageImpl.java @@ -19,127 +19,124 @@ package org.apache.geaflow.dsl.runtime.traversal.message; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.Serializer; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.io.Serializable; import java.util.HashMap; import java.util.Map; import java.util.Objects; + import org.apache.geaflow.dsl.runtime.traversal.data.SingleValue; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + public class ReturnMessageImpl implements ReturnMessage { - private final Map returnKey2Value; + private final Map returnKey2Value; - public ReturnMessageImpl(Map returnKey2Value) { - this.returnKey2Value = returnKey2Value; - } + public ReturnMessageImpl(Map returnKey2Value) { + this.returnKey2Value = returnKey2Value; + } - public ReturnMessageImpl() { - this(new HashMap<>()); - } + public ReturnMessageImpl() { + this(new HashMap<>()); + } - @Override - public MessageType getType() { - return MessageType.RETURN_VALUE; - } + @Override + public MessageType getType() { + return MessageType.RETURN_VALUE; + } - @Override - public ReturnMessage combine(IMessage other) { - ReturnMessage copy = this.copy(); - ReturnMessage otherMsg = (ReturnMessage) other; - - for (Map.Entry entry : otherMsg.getReturnKey2Values().entrySet()) { - ReturnKey returnKey = entry.getKey(); - SingleValue value = entry.getValue(); - if (value != null) { - copy.putValue(returnKey, value); - } - } - return copy; + @Override + public ReturnMessage combine(IMessage other) { + ReturnMessage copy = this.copy(); + ReturnMessage otherMsg = (ReturnMessage) other; + + for (Map.Entry entry : otherMsg.getReturnKey2Values().entrySet()) { + ReturnKey returnKey = entry.getKey(); + SingleValue value = entry.getValue(); + if (value != null) { + copy.putValue(returnKey, value); + } } + return copy; + } - @Override - public ReturnMessage copy() { - return new ReturnMessageImpl(new HashMap<>(returnKey2Value)); + @Override + public ReturnMessage copy() { + return new ReturnMessageImpl(new HashMap<>(returnKey2Value)); + } + + @Override + public Map getReturnKey2Values() { + return returnKey2Value; + } + + @Override + public void putValue(ReturnKey returnKey, SingleValue value) { + returnKey2Value.put(returnKey, value); + } + + @Override + public SingleValue getValue(ReturnKey returnKey) { + return returnKey2Value.get(returnKey); + } + + public static class ReturnKey implements Serializable { + + private final long pathId; + + private final long queryId; + + public ReturnKey(long pathId, long queryId) { + if (pathId < 0) { + throw new IllegalArgumentException("Illegal pathId: " + pathId); + } + this.pathId = pathId; + this.queryId = queryId; } @Override - public Map getReturnKey2Values() { - return returnKey2Value; + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ReturnKey)) { + return false; + } + ReturnKey returnKey = (ReturnKey) o; + return queryId == returnKey.queryId && Objects.equals(pathId, returnKey.pathId); } @Override - public void putValue(ReturnKey returnKey, SingleValue value) { - returnKey2Value.put(returnKey, value); + public int hashCode() { + return Objects.hash(pathId, queryId); } @Override - public SingleValue getValue(ReturnKey returnKey) { - return returnKey2Value.get(returnKey); + public String toString() { + return "ReturnKey{" + "pathId=" + pathId + ", queryId=" + queryId + '}'; } + } + public static class ReturnMessageImplSerializer extends Serializer { - public static class ReturnKey implements Serializable { - - private final long pathId; - - private final long queryId; - - public ReturnKey(long pathId, long queryId) { - if (pathId < 0) { - throw new IllegalArgumentException("Illegal pathId: " + pathId); - } - this.pathId = pathId; - this.queryId = queryId; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof ReturnKey)) { - return false; - } - ReturnKey returnKey = (ReturnKey) o; - return queryId == returnKey.queryId && Objects.equals(pathId, returnKey.pathId); - } - - @Override - public int hashCode() { - return Objects.hash(pathId, queryId); - } - - @Override - public String toString() { - return "ReturnKey{" - + "pathId=" + pathId - + ", queryId=" + queryId - + '}'; - } + @Override + public void write(Kryo kryo, Output output, ReturnMessageImpl object) { + kryo.writeClassAndObject(output, object.returnKey2Value); } - public static class ReturnMessageImplSerializer extends Serializer { - - @Override - public void write(Kryo kryo, Output output, ReturnMessageImpl object) { - kryo.writeClassAndObject(output, object.returnKey2Value); - } - - @Override - public ReturnMessageImpl read(Kryo kryo, Input input, Class type) { - Map returnKey2Value = - (Map) kryo.readClassAndObject(input); - return new ReturnMessageImpl(returnKey2Value); - } - - @Override - public ReturnMessageImpl copy(Kryo kryo, ReturnMessageImpl original) { - return new ReturnMessageImpl(new HashMap<>(original.returnKey2Value)); - } + @Override + public ReturnMessageImpl read(Kryo kryo, Input input, Class type) { + Map returnKey2Value = + (Map) kryo.readClassAndObject(input); + return new ReturnMessageImpl(returnKey2Value); } + @Override + public ReturnMessageImpl copy(Kryo kryo, ReturnMessageImpl original) { + return new ReturnMessageImpl(new HashMap<>(original.returnKey2Value)); + } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/SingleMessageBox.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/SingleMessageBox.java index 1c5ded6f0..ed86deaa7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/SingleMessageBox.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/SingleMessageBox.java @@ -20,101 +20,102 @@ package org.apache.geaflow.dsl.runtime.traversal.message; import java.util.Arrays; + import org.apache.geaflow.common.utils.ArrayUtil; public abstract class SingleMessageBox implements MessageBox { - // The operator id who receive the message. - protected long[] receiverIds; + // The operator id who receive the message. + protected long[] receiverIds; - // Messages for each receiver. - protected IMessage[] messages; + // Messages for each receiver. + protected IMessage[] messages; - @Override - public void addMessage(long receiverId, IMessage message) { - int index; - IMessage oldMessage = null; - if (receiverIds == null) { - receiverIds = new long[1]; - messages = new IMessage[1]; - index = 0; - } else { - index = ArrayUtil.indexOf(receiverIds, receiverId); - if (index == -1) { - index = receiverIds.length; - receiverIds = ArrayUtil.grow(receiverIds, 1); - IMessage[] newMessages = new IMessage[messages.length + 1]; - System.arraycopy(messages, 0, newMessages, 0, messages.length); - messages = newMessages; - } else { - oldMessage = messages[index]; - } - } - receiverIds[index] = receiverId; - if (oldMessage == null) { - // Copy the first message as the combine() method will modify the old message. - messages[index] = message.copy(); - } else { - messages[index] = oldMessage.combine(message); - } + @Override + public void addMessage(long receiverId, IMessage message) { + int index; + IMessage oldMessage = null; + if (receiverIds == null) { + receiverIds = new long[1]; + messages = new IMessage[1]; + index = 0; + } else { + index = ArrayUtil.indexOf(receiverIds, receiverId); + if (index == -1) { + index = receiverIds.length; + receiverIds = ArrayUtil.grow(receiverIds, 1); + IMessage[] newMessages = new IMessage[messages.length + 1]; + System.arraycopy(messages, 0, newMessages, 0, messages.length); + messages = newMessages; + } else { + oldMessage = messages[index]; + } } - - @SuppressWarnings("unchecked") - @Override - public M getMessage(long receiverId, MessageType messageType) { - if (messageType != getMessageType()) { - return null; - } - int index = ArrayUtil.indexOf(receiverIds, receiverId); - if (index == -1) { - return null; - } - return (M) messages[index]; + receiverIds[index] = receiverId; + if (oldMessage == null) { + // Copy the first message as the combine() method will modify the old message. + messages[index] = message.copy(); + } else { + messages[index] = oldMessage.combine(message); } + } - @Override - public long[] getReceiverIds() { - return receiverIds; + @SuppressWarnings("unchecked") + @Override + public M getMessage(long receiverId, MessageType messageType) { + if (messageType != getMessageType()) { + return null; } + int index = ArrayUtil.indexOf(receiverIds, receiverId); + if (index == -1) { + return null; + } + return (M) messages[index]; + } - public abstract MessageType getMessageType(); + @Override + public long[] getReceiverIds() { + return receiverIds; + } - @Override - public MessageBox combine(MessageBox other) { - if (other.isEmpty()) { - return this; - } - if (other instanceof SingleMessageBox && canMerge((SingleMessageBox) other)) { - MessageBox newBox = this.copy(); - SingleMessageBox otherBox = (SingleMessageBox) other; - for (long receiverId : otherBox.getReceiverIds()) { - IMessage message = otherBox.getMessage(receiverId, getMessageType()); - newBox.addMessage(receiverId, message); - } - return newBox; - } else { - UnionMessageBox unionBox = new UnionMessageBox(); - unionBox.addMessageBox(this); - unionBox.addMessageBox(other); - return unionBox; - } - } + public abstract MessageType getMessageType(); - protected boolean canMerge(SingleMessageBox other) { - return other.getMessageType() == this.getMessageType(); + @Override + public MessageBox combine(MessageBox other) { + if (other.isEmpty()) { + return this; } + if (other instanceof SingleMessageBox && canMerge((SingleMessageBox) other)) { + MessageBox newBox = this.copy(); + SingleMessageBox otherBox = (SingleMessageBox) other; + for (long receiverId : otherBox.getReceiverIds()) { + IMessage message = otherBox.getMessage(receiverId, getMessageType()); + newBox.addMessage(receiverId, message); + } + return newBox; + } else { + UnionMessageBox unionBox = new UnionMessageBox(); + unionBox.addMessageBox(this); + unionBox.addMessageBox(other); + return unionBox; + } + } + + protected boolean canMerge(SingleMessageBox other) { + return other.getMessageType() == this.getMessageType(); + } - @Override - public MessageBox copy() { - SingleMessageBox newMessage = create(); - if (receiverIds != null) { - newMessage.receiverIds = ArrayUtil.copy(receiverIds); - } - if (messages != null) { - newMessage.messages = Arrays.copyOf(messages, messages.length); - } - return newMessage; + @Override + public MessageBox copy() { + SingleMessageBox newMessage = create(); + if (receiverIds != null) { + newMessage.receiverIds = ArrayUtil.copy(receiverIds); + } + if (messages != null) { + newMessage.messages = Arrays.copyOf(messages, messages.length); } + return newMessage; + } - protected abstract SingleMessageBox create(); + protected abstract SingleMessageBox create(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/UnionMessageBox.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/UnionMessageBox.java index 71e1e6ff8..9a1e640e4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/UnionMessageBox.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/message/UnionMessageBox.java @@ -19,82 +19,85 @@ package org.apache.geaflow.dsl.runtime.traversal.message; -import it.unimi.dsi.fastutil.longs.LongArraySet; import java.util.HashMap; import java.util.Map; import java.util.Objects; +import it.unimi.dsi.fastutil.longs.LongArraySet; + public class UnionMessageBox implements MessageBox { - private final Map messageBoxes; + private final Map messageBoxes; - public UnionMessageBox(Map messageBoxes) { - this.messageBoxes = Objects.requireNonNull(messageBoxes); - } + public UnionMessageBox(Map messageBoxes) { + this.messageBoxes = Objects.requireNonNull(messageBoxes); + } - public UnionMessageBox() { - this(new HashMap<>()); - } + public UnionMessageBox() { + this(new HashMap<>()); + } - @Override - public void addMessage(long receiverId, IMessage message) { - MessageType messageType = message.getType(); - MessageBox messageBox = messageBoxes.computeIfAbsent(messageType, m -> messageType.createMessageBox()); - messageBox.addMessage(receiverId, message); - } + @Override + public void addMessage(long receiverId, IMessage message) { + MessageType messageType = message.getType(); + MessageBox messageBox = + messageBoxes.computeIfAbsent(messageType, m -> messageType.createMessageBox()); + messageBox.addMessage(receiverId, message); + } - @Override - public M getMessage(long receiverId, MessageType messageType) { - MessageBox messageBox = messageBoxes.get(messageType); - if (messageBox != null) { - return messageBox.getMessage(receiverId, messageType); - } - return null; + @Override + public M getMessage(long receiverId, MessageType messageType) { + MessageBox messageBox = messageBoxes.get(messageType); + if (messageBox != null) { + return messageBox.getMessage(receiverId, messageType); } + return null; + } - @Override - public long[] getReceiverIds() { - LongArraySet receiverIds = new LongArraySet(); - for (SingleMessageBox messageBox : messageBoxes.values()) { - for (long id : messageBox.getReceiverIds()) { - receiverIds.add(id); - } - } - return receiverIds.toLongArray(); + @Override + public long[] getReceiverIds() { + LongArraySet receiverIds = new LongArraySet(); + for (SingleMessageBox messageBox : messageBoxes.values()) { + for (long id : messageBox.getReceiverIds()) { + receiverIds.add(id); + } } + return receiverIds.toLongArray(); + } - public void addMessageBox(MessageBox other) { - if (other instanceof UnionMessageBox) { - UnionMessageBox otherUnionBox = (UnionMessageBox) other; - for (MessageType messageType : otherUnionBox.messageBoxes.keySet()) { - SingleMessageBox otherBox = otherUnionBox.getMessageBox(messageType); - addMessageBox(otherBox); - } - } else { // add single message box. - SingleMessageBox singleBox = (SingleMessageBox) other; - MessageType messageType = singleBox.getMessageType(); - if (messageBoxes.containsKey(messageType)) { - SingleMessageBox combineBox = (SingleMessageBox) messageBoxes.get(messageType).combine(other); - messageBoxes.put(messageType, combineBox); - } else { - messageBoxes.put(messageType, singleBox); - } - } + public void addMessageBox(MessageBox other) { + if (other instanceof UnionMessageBox) { + UnionMessageBox otherUnionBox = (UnionMessageBox) other; + for (MessageType messageType : otherUnionBox.messageBoxes.keySet()) { + SingleMessageBox otherBox = otherUnionBox.getMessageBox(messageType); + addMessageBox(otherBox); + } + } else { // add single message box. + SingleMessageBox singleBox = (SingleMessageBox) other; + MessageType messageType = singleBox.getMessageType(); + if (messageBoxes.containsKey(messageType)) { + SingleMessageBox combineBox = + (SingleMessageBox) messageBoxes.get(messageType).combine(other); + messageBoxes.put(messageType, combineBox); + } else { + messageBoxes.put(messageType, singleBox); + } } + } - public SingleMessageBox getMessageBox(MessageType messageType) { - return messageBoxes.get(messageType); - } + public SingleMessageBox getMessageBox(MessageType messageType) { + return messageBoxes.get(messageType); + } - @Override - public MessageBox combine(MessageBox other) { - UnionMessageBox newBox = this.copy(); - newBox.addMessageBox(other); - return newBox; - } + @Override + public MessageBox combine(MessageBox other) { + UnionMessageBox newBox = this.copy(); + newBox.addMessageBox(other); + return newBox; + } - @Override - public UnionMessageBox copy() { - return new UnionMessageBox(new HashMap<>(messageBoxes)); - } + @Override + public UnionMessageBox copy() { + return new UnionMessageBox(new HashMap<>(messageBoxes)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/AbstractStepOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/AbstractStepOperator.java index 4686bca37..66e75ee7c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/AbstractStepOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/AbstractStepOperator.java @@ -21,7 +21,6 @@ import static org.apache.geaflow.common.utils.ArrayUtil.castList; -import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -29,6 +28,7 @@ import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.common.type.IType; @@ -78,535 +78,569 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class AbstractStepOperator +import com.google.common.base.Preconditions; + +public abstract class AbstractStepOperator< + FUNC extends StepFunction, IN extends StepRecord, OUT extends StepRecord> implements StepOperator { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractStepOperator.class); + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractStepOperator.class); - protected final long id; + protected final long id; - protected String name; - protected final FUNC function; - private final Map> caller2ReceiveEods = new HashMap<>(); - protected List inputPathSchemas; - protected PathType outputPathSchema; - protected IType outputType; - protected GraphSchema graphSchema; - protected GraphSchema modifyGraphSchema; - protected StepRecordType outputRecordType; - protected TraversalRuntimeContext context; - protected List> nextOperators = new ArrayList<>(); - protected StepCollector collector; + protected String name; + protected final FUNC function; + private final Map> caller2ReceiveEods = new HashMap<>(); + protected List inputPathSchemas; + protected PathType outputPathSchema; + protected IType outputType; + protected GraphSchema graphSchema; + protected GraphSchema modifyGraphSchema; + protected StepRecordType outputRecordType; + protected TraversalRuntimeContext context; + protected List> nextOperators = new ArrayList<>(); + protected StepCollector collector; - protected boolean needAddToPath; + protected boolean needAddToPath; - protected IType[] addingVertexFieldTypes; + protected IType[] addingVertexFieldTypes; - protected String[] addingVertexFieldNames; + protected String[] addingVertexFieldNames; - protected List callQueryProxies; + protected List callQueryProxies; - private final int numCallQueries; + private final int numCallQueries; - private CallContext callContext = null; + private CallContext callContext = null; - private CallState callState = null; + private CallState callState = null; - protected int numTasks; - private int numReceiveEods; - protected long numProcessRecords; - protected boolean isGlobalEmptyCycle; + protected int numTasks; + private int numReceiveEods; + protected long numProcessRecords; + protected boolean isGlobalEmptyCycle; - protected MetricGroup metricGroup; - private Counter inputCounter; - private Counter outputCounter; - private Meter inputTps; - private Meter outputTps; - private Histogram processRt; - private Counter inputEodCounter; + protected MetricGroup metricGroup; + private Counter inputCounter; + private Counter outputCounter; + private Meter inputTps; + private Meter outputTps; + private Histogram processRt; + private Counter inputEodCounter; - public AbstractStepOperator(long id, FUNC function) { - this.id = id; - this.name = generateName(); - this.function = Objects.requireNonNull(createCallQueryProxy(function)); - this.callQueryProxies = this.function.getCallQueryProxies(); - this.numCallQueries = this.callQueryProxies.stream() + public AbstractStepOperator(long id, FUNC function) { + this.id = id; + this.name = generateName(); + this.function = Objects.requireNonNull(createCallQueryProxy(function)); + this.callQueryProxies = this.function.getCallQueryProxies(); + this.numCallQueries = + this.callQueryProxies.stream() .map(proxy -> proxy.getQueryCalls().length) - .reduce(Integer::sum).orElse(0); - if (this.callQueryProxies.size() > 0) { - this.callContext = new CallContext(); - this.callState = CallState.INIT; - } - } - - @SuppressWarnings("unchecked") - private FUNC createCallQueryProxy(FUNC function) { - List rewriteExpressions = function.getExpressions().stream() - .map(CallQueryProxy::from) - .collect(Collectors.toList()); - return (FUNC) function.copy(rewriteExpressions); - } - - @SuppressWarnings("unchecked") - @Override - public void open(TraversalRuntimeContext context) { - Preconditions.checkArgument(inputPathSchemas != null, "inputPathSchemas is null"); - Preconditions.checkArgument(outputPathSchema != null, "outputPathSchema is null"); - Preconditions.checkArgument(outputType != null, "outputType is null"); - - if (context.getConfig().getBoolean(ExecutionConfigKeys.ENABLE_DETAIL_METRIC)) { - metricGroup = MetricGroupRegistry.getInstance().getMetricGroup(MetricConstants.MODULE_DSL); + .reduce(Integer::sum) + .orElse(0); + if (this.callQueryProxies.size() > 0) { + this.callContext = new CallContext(); + this.callState = CallState.INIT; + } + } + + @SuppressWarnings("unchecked") + private FUNC createCallQueryProxy(FUNC function) { + List rewriteExpressions = + function.getExpressions().stream().map(CallQueryProxy::from).collect(Collectors.toList()); + return (FUNC) function.copy(rewriteExpressions); + } + + @SuppressWarnings("unchecked") + @Override + public void open(TraversalRuntimeContext context) { + Preconditions.checkArgument(inputPathSchemas != null, "inputPathSchemas is null"); + Preconditions.checkArgument(outputPathSchema != null, "outputPathSchema is null"); + Preconditions.checkArgument(outputType != null, "outputType is null"); + + if (context.getConfig().getBoolean(ExecutionConfigKeys.ENABLE_DETAIL_METRIC)) { + metricGroup = MetricGroupRegistry.getInstance().getMetricGroup(MetricConstants.MODULE_DSL); + } else { + metricGroup = BlackHoleMetricGroup.INSTANCE; + } + inputCounter = metricGroup.counter(MetricNameFormatter.stepInputRecordName(getName())); + inputTps = metricGroup.meter(MetricNameFormatter.stepInputRowTpsName(getName())); + outputCounter = metricGroup.counter(MetricNameFormatter.stepOutputRecordName(getName())); + outputTps = metricGroup.meter(MetricNameFormatter.stepOutputRowTpsName(getName())); + processRt = metricGroup.histogram(MetricNameFormatter.stepProcessTimeRtName(getName())); + inputEodCounter = metricGroup.counter(MetricNameFormatter.stepInputEodName(getName())); + + this.context = context; + this.numTasks = context.getNumTasks(); + this.numReceiveEods = getNumReceiveEods(); + this.numProcessRecords = 0L; + this.isGlobalEmptyCycle = true; + for (StepOperator nextOp : nextOperators) { + nextOp.open(context); + } + if (nextOperators.isEmpty()) { + if (context.getTopology().belongMainDag(id)) { + assert this instanceof StepEndOperator; + // The collectors are empty means current is the StepEndOperator for main dag, use the + // StepEndCollector. + this.collector = (StepCollector) new StepEndCollector(context); + } else { // current is StepReturnOperator. + assert this instanceof StepReturnOperator; + this.collector = (StepCollector) new StepReturnCollector(context, id); + } + } else { + List> outputCollectors = new ArrayList<>(); + for (StepOperator nextOp : nextOperators) { + StepCollector collector; + if (context.getTopology().isChained(id, nextOp.getId())) { + collector = + (StepCollector) + new ChainOperatorCollector(castList(Collections.singletonList(nextOp)), context); } else { - metricGroup = BlackHoleMetricGroup.INSTANCE; + collector = (StepCollector) new StepNextCollector(id, nextOp.getId(), context); } - inputCounter = metricGroup.counter(MetricNameFormatter.stepInputRecordName(getName())); - inputTps = metricGroup.meter(MetricNameFormatter.stepInputRowTpsName(getName())); - outputCounter = metricGroup.counter(MetricNameFormatter.stepOutputRecordName(getName())); - outputTps = metricGroup.meter(MetricNameFormatter.stepOutputRowTpsName(getName())); - processRt = metricGroup.histogram(MetricNameFormatter.stepProcessTimeRtName(getName())); - inputEodCounter = metricGroup.counter(MetricNameFormatter.stepInputEodName(getName())); - - this.context = context; - this.numTasks = context.getNumTasks(); - this.numReceiveEods = getNumReceiveEods(); - this.numProcessRecords = 0L; - this.isGlobalEmptyCycle = true; - for (StepOperator nextOp : nextOperators) { - nextOp.open(context); - } - if (nextOperators.isEmpty()) { - if (context.getTopology().belongMainDag(id)) { - assert this instanceof StepEndOperator; - // The collectors are empty means current is the StepEndOperator for main dag, use the StepEndCollector. - this.collector = (StepCollector) new StepEndCollector(context); - } else { // current is StepReturnOperator. - assert this instanceof StepReturnOperator; - this.collector = (StepCollector) new StepReturnCollector(context, id); - } - } else { - List> outputCollectors = new ArrayList<>(); - for (StepOperator nextOp : nextOperators) { - StepCollector collector; - if (context.getTopology().isChained(id, nextOp.getId())) { - collector = (StepCollector) new ChainOperatorCollector( - castList(Collections.singletonList(nextOp)), context); - } else { - collector = (StepCollector) new StepNextCollector(id, nextOp.getId(), context); - } - outputCollectors.add(collector); - } - this.collector = new StepBroadcastCollector<>(outputCollectors); - } - if (callQueryProxies.size() > 0) { - this.collector = new StepWaitCallQueryCollector<>(callQueryProxies.get(0), this.collector); - } - - PathType concatInputType = concatInputPathType(); - if (concatInputType != null) { - List appendFields = new ArrayList<>(); - int[] outputPathFieldIndices = ArrayUtil.toIntArray( - outputPathSchema.getFieldNames().stream().map(name -> { - int index = concatInputType.indexOf(name); - if (index != -1) { + outputCollectors.add(collector); + } + this.collector = new StepBroadcastCollector<>(outputCollectors); + } + if (callQueryProxies.size() > 0) { + this.collector = new StepWaitCallQueryCollector<>(callQueryProxies.get(0), this.collector); + } + + PathType concatInputType = concatInputPathType(); + if (concatInputType != null) { + List appendFields = new ArrayList<>(); + int[] outputPathFieldIndices = + ArrayUtil.toIntArray( + outputPathSchema.getFieldNames().stream() + .map( + name -> { + int index = concatInputType.indexOf(name); + if (index != -1) { + return index; + } + // If the last output field is not exist in the input, then it is a new + // append field by MatchVertex Or MatchEdge Or MatchExtend operator, return + // the + // field index of outputPathSchema. + index = concatInputType.size() + appendFields.size(); + appendFields.add(name); return index; - } - // If the last output field is not exist in the input, then it is a new - // append field by MatchVertex Or MatchEdge Or MatchExtend operator, return the - // field index of outputPathSchema. - index = concatInputType.size() + appendFields.size(); - appendFields.add(name); - return index; - }).collect(Collectors.toList())); - // If any of the input field is not exist in the output, then we should prune the output path. - boolean needPathPrune = concatInputType.getFieldNames().stream() - .anyMatch(input -> !outputPathSchema.contain(input)); - - if (needPathPrune) { - this.collector = new StepPathPruneCollector<>(context, this.collector, - outputPathFieldIndices); - } - } - this.needAddToPath = this instanceof LabeledStepOperator + }) + .collect(Collectors.toList())); + // If any of the input field is not exist in the output, then we should prune the output path. + boolean needPathPrune = + concatInputType.getFieldNames().stream() + .anyMatch(input -> !outputPathSchema.contain(input)); + + if (needPathPrune) { + this.collector = + new StepPathPruneCollector<>(context, this.collector, outputPathFieldIndices); + } + } + this.needAddToPath = + this instanceof LabeledStepOperator && outputPathSchema.contain(((LabeledStepOperator) this).getLabel()) && !isSubQueryStartLabel(); - List addingVertexFields = getModifyGraphSchema().getAddingFields(graphSchema); - this.addingVertexFieldTypes = addingVertexFields.stream() + List addingVertexFields = getModifyGraphSchema().getAddingFields(graphSchema); + this.addingVertexFieldTypes = + addingVertexFields.stream() .map(TableField::getType) .collect(Collectors.toList()) - .toArray(new IType[]{}); + .toArray(new IType[] {}); - this.addingVertexFieldNames = addingVertexFields.stream() + this.addingVertexFieldNames = + addingVertexFields.stream() .map(TableField::getName) .collect(Collectors.toList()) - .toArray(new String[]{}); - - function.open(context, new FunctionSchemas(inputPathSchemas, outputPathSchema, - outputType, graphSchema, modifyGraphSchema, addingVertexFieldTypes, addingVertexFieldNames)); - } - - private boolean isSubQueryStartLabel() { - boolean isSubDag = !context.getTopology().belongMainDag(id); - if (!isSubDag) { - return false; - } - if (!(this instanceof LabeledStepOperator)) { - return false; - } - List inputOpIds = context.getTopology().getInputIds(id); - return inputOpIds.size() == 1 && context.getTopology() - .getOperator(inputOpIds.get(0)) instanceof StepSubQueryStartOperator; - } - - private int getNumReceiveEods() { - DagTopologyGroup topologyGroup = context.getTopology(); - List inputOpIds = topologyGroup.getInputIds(id); - int numReceiveEods = 0; - for (long inputOpId : inputOpIds) { - if (topologyGroup.isChained(id, inputOpId)) { - numReceiveEods += 1; - } else { - numReceiveEods += numTasks; - } - } - return numReceiveEods; - } - - protected PathType concatInputPathType() { - if (inputPathSchemas.isEmpty()) { - return null; - } - if (inputPathSchemas.size() == 1) { - return inputPathSchemas.get(0); - } - throw new IllegalArgumentException( - this.getClass().getSimpleName() + " should override concatInputPathType() method."); - } - - public final void process(IN record) { - if (callState == CallState.FINISH) { - throw new IllegalArgumentException("task index:" + context.getTaskIndex() - + ", op id: " + id + " in illegal call state: " + callState); - } - long startTs = System.nanoTime(); - - // set current operator id. - context.setCurrentOpId(id); - if (record.getType() == StepRecordType.EOD) { - inputEodCounter.inc(); - EndOfData eod = (EndOfData) record; - processEod(eod); - } else { - if (callState == CallState.INIT) { - setCallState(CallState.CALLING); - } else if (callState == CallState.WAITING) { - setCallState(CallState.RETURNING); - } - // set current vertex. - if (record.getType() == StepRecordType.VERTEX) { - VertexRecord vertexRecord = (VertexRecord) record; - RowVertex vertex = vertexRecord.getVertex(); - context.setVertex(vertex); - if (callContext != null) { - callContext.addPath(context.getRequestId(), vertex.getId(), - vertexRecord.getPathById(vertex.getId())); - callContext.addRequest(vertex.getId(), context.getRequest()); - } - } else { - assert callContext == null : "Calling sub query on non-vertex record is not allowed."; - } - numProcessRecords++; - processRecord(record); - processRt.update((System.nanoTime() - startTs) / 1000L); - inputCounter.inc(); - inputTps.mark(); - } - } - - protected void processEod(EndOfData eod) { - caller2ReceiveEods.computeIfAbsent(eod.getCallOpId(), k -> new ArrayList<>()) - .add(eod); - boolean inputEmptyCycle = eod.isGlobalEmptyCycle; - this.isGlobalEmptyCycle &= inputEmptyCycle; - // Receive all EOD from input operators. - for (Map.Entry> entry : caller2ReceiveEods.entrySet()) { - long callerOpId = entry.getKey(); - List receiveEods = entry.getValue(); - if (hasReceivedAllEod(receiveEods)) { - LOGGER.info("Step op: {} task: {} received all eods. Iterations: {}", - this.getName(), context.getTaskIndex(), context.getIterationId()); - onReceiveAllEOD(callerOpId, receiveEods); - } - } - } - - protected boolean hasReceivedAllEod(List receiveEods) { - if (numCallQueries > 0) { - if (callQueryProxies.get(0).getCallState() == CallState.RETURNING - || callQueryProxies.get(0).getCallState() == CallState.WAITING) { - // When receiving all eod from the sub-query call, trigger the onReceiveAllEOD. - return (numTasks * numCallQueries) == receiveEods.size(); - } else if (callQueryProxies.get(0).getCallState() == CallState.CALLING - || callQueryProxies.get(0).getCallState() == CallState.INIT) { - return numReceiveEods == receiveEods.size(); - } - return false; - } else { - // For source operator, the input is empty, so if it has received eod, - // it will trigger the onReceiveAllEOD. For other operator, - // the count of eod should equal to the input size. - return numReceiveEods == receiveEods.size(); - } - } - - protected abstract void processRecord(IN record); - - protected void onReceiveAllEOD(long callerOpId, List receiveEods) { - finish(); - // Send EOD to output operators. - collectEOD(callerOpId); - receiveEods.clear(); - - if (callState == CallState.FINISH) { - LOGGER.info("step operator: {} finished, task id is: {}", getName(), context.getTaskIndex()); - this.setCallState(CallState.INIT); - } - } - - private void setCallState(CallState callState) { - this.callState = callState; - for (CallQueryProxy callQueryProxy : callQueryProxies) { - callQueryProxy.setCallState(callState); - } - } - - public void finish() { - if (callQueryProxies.size() > 0) { - switch (callState) { - case INIT: - case CALLING: - this.setCallState(CallState.WAITING); - // push call context to the stack when finish calling. - context.push(id, callContext); - break; - case WAITING: - case RETURNING: - this.setCallState(CallState.FINISH); - // pop call context from stack when all calling has returned from sub query - context.pop(id); - // reset call context after pop from the stack - callContext.reset(); - break; - default: - throw new GeaFlowDSLException("Illegal call state: {}", callState); - } - for (CallQueryProxy callQueryProxy : callQueryProxies) { - callQueryProxy.finishCall(); - } - } - function.finish((StepCollector) collector); - - LOGGER.info("Step op: {} task: {} finished. Iterations: {}", this.getName(), - this.getContext().getTaskIndex(), context.getIterationId()); - } - - protected void collect(OUT record) { - context.setInputOperatorId(id); - collector.collect(record); - outputCounter.inc(); - outputTps.mark(); - } - - @SuppressWarnings("unchecked") - protected void collectEOD(long callerOpId) { - this.isGlobalEmptyCycle &= numProcessRecords == 0L; - EndOfData eod = EndOfData.of(callerOpId, id); - eod.isGlobalEmptyCycle = isGlobalEmptyCycle; - collector.collect((OUT) eod); - this.isGlobalEmptyCycle = true; - this.numProcessRecords = 0L; - } - - @Override - public void close() { - } - - @Override - public void addNextOperator(StepOperator nextOperator) { - if (nextOperators.contains(nextOperator)) { - return; + .toArray(new String[] {}); + + function.open( + context, + new FunctionSchemas( + inputPathSchemas, + outputPathSchema, + outputType, + graphSchema, + modifyGraphSchema, + addingVertexFieldTypes, + addingVertexFieldNames)); + } + + private boolean isSubQueryStartLabel() { + boolean isSubDag = !context.getTopology().belongMainDag(id); + if (!isSubDag) { + return false; + } + if (!(this instanceof LabeledStepOperator)) { + return false; + } + List inputOpIds = context.getTopology().getInputIds(id); + return inputOpIds.size() == 1 + && context.getTopology().getOperator(inputOpIds.get(0)) + instanceof StepSubQueryStartOperator; + } + + private int getNumReceiveEods() { + DagTopologyGroup topologyGroup = context.getTopology(); + List inputOpIds = topologyGroup.getInputIds(id); + int numReceiveEods = 0; + for (long inputOpId : inputOpIds) { + if (topologyGroup.isChained(id, inputOpId)) { + numReceiveEods += 1; + } else { + numReceiveEods += numTasks; + } + } + return numReceiveEods; + } + + protected PathType concatInputPathType() { + if (inputPathSchemas.isEmpty()) { + return null; + } + if (inputPathSchemas.size() == 1) { + return inputPathSchemas.get(0); + } + throw new IllegalArgumentException( + this.getClass().getSimpleName() + " should override concatInputPathType() method."); + } + + public final void process(IN record) { + if (callState == CallState.FINISH) { + throw new IllegalArgumentException( + "task index:" + + context.getTaskIndex() + + ", op id: " + + id + + " in illegal call state: " + + callState); + } + long startTs = System.nanoTime(); + + // set current operator id. + context.setCurrentOpId(id); + if (record.getType() == StepRecordType.EOD) { + inputEodCounter.inc(); + EndOfData eod = (EndOfData) record; + processEod(eod); + } else { + if (callState == CallState.INIT) { + setCallState(CallState.CALLING); + } else if (callState == CallState.WAITING) { + setCallState(CallState.RETURNING); + } + // set current vertex. + if (record.getType() == StepRecordType.VERTEX) { + VertexRecord vertexRecord = (VertexRecord) record; + RowVertex vertex = vertexRecord.getVertex(); + context.setVertex(vertex); + if (callContext != null) { + callContext.addPath( + context.getRequestId(), vertex.getId(), vertexRecord.getPathById(vertex.getId())); + callContext.addRequest(vertex.getId(), context.getRequest()); } - this.nextOperators.add(nextOperator); - } - - @Override - public List> getNextOperators() { - return this.nextOperators; - } - - @Override - public StepOperator withOutputPathSchema(PathType pathSchema) { - this.outputPathSchema = Objects.requireNonNull(pathSchema); - return this; - } - - @Override - public StepOperator withInputPathSchema(List inputPaths) { - this.inputPathSchemas = Objects.requireNonNull(inputPaths); - return this; - } - - @Override - public StepOperator withOutputType(IType outputType) { - this.outputType = Objects.requireNonNull(outputType); - if (outputType instanceof VertexType) { - this.outputRecordType = StepRecordType.VERTEX; - } else { - this.outputRecordType = StepRecordType.EDGE_GROUP; - } - return this; - } - - @Override - public StepOperator withGraphSchema(GraphSchema graph) { - this.graphSchema = Objects.requireNonNull(graph); - return this; - } - - @Override - public StepOperator withModifyGraphSchema(GraphSchema modifyGraphSchema) { - this.modifyGraphSchema = modifyGraphSchema; - return this; - } - - @Override - public List getInputPathSchemas() { - return inputPathSchemas; - } - - @Override - public PathType getOutputPathSchema() { - return outputPathSchema; - } - - @Override - public IType getOutputType() { - return outputType; - } - - @Override - public GraphSchema getGraphSchema() { - return graphSchema; - } - - @Override - public GraphSchema getModifyGraphSchema() { - if (modifyGraphSchema == null) { - return graphSchema; - } - return modifyGraphSchema; - } - - public StepOperator copy() { - return copyInternal() - .withGraphSchema(graphSchema) - .withInputPathSchema(inputPathSchemas) - .withOutputPathSchema(outputPathSchema) - .withOutputType(outputType); - } - - public FUNC getFunction() { - return function; - } - - public TraversalRuntimeContext getContext() { - return context; - } - - @Override - public int hashCode() { - return Objects.hash(id); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof AbstractStepOperator)) { - return false; - } - AbstractStepOperator that = (AbstractStepOperator) o; - return id == that.id; - } - - @Override - public List getSubQueryNames() { - return function.getCallQueryProxies().stream() - .flatMap(proxy -> proxy.getSubQueryNames().stream()) - .collect(Collectors.toList()); - } - - @Override - public String toString() { - StringBuilder str = new StringBuilder(); - str.append(getName()); - List subQueryNames = getSubQueryNames(); - if (subQueryNames.size() > 0) { - str.append("[").append(StringUtils.join(subQueryNames, ",")).append("]"); - } - return str.toString(); - } - - public abstract StepOperator copyInternal(); - - protected RowVertex alignToOutputSchema(RowVertex vertex) { - if (vertex instanceof IdOnlyVertex) { - return vertex; - } - String label = vertex.getLabel(); - VertexType inputVertexType = getModifyGraphSchema().getVertex(label); - VertexType outputVertexType = ((VertexType) getOutputType()); - return SchemaUtil.alignToVertexSchema(vertex, inputVertexType, outputVertexType); - } - - protected RowEdge alignToOutputSchema(RowEdge edge) { - String label = edge.getLabel(); - EdgeType inputEdgeType = getModifyGraphSchema().getEdge(label); - EdgeType outputEdgeType = ((EdgeType) getOutputType()); - return SchemaUtil.alignToEdgeSchema(edge, inputEdgeType, outputEdgeType); - } - - protected Row withParameter(Row row) { - Row parameters = context.getParameters(); - if (parameters != null) { - if (row instanceof Path) { - return new DefaultParameterizedPath((Path) row, context.getRequestId(), context.getParameters()); - } else { - return new DefaultParameterizedRow(row, context.getRequestId(), context.getParameters()); - } - } - return row; - } - - @Override - public long getId() { - return id; - } - - private String generateName() { - String className = getClass().getSimpleName(); - return className.substring(0, className.length() - "Operator".length()) + "-" + getId(); - } - - @Override - public String getName() { - return name; - } - - @Override - public StepOperator withName(String name) { - this.name = Objects.requireNonNull(name); - return this; - } + } else { + assert callContext == null : "Calling sub query on non-vertex record is not allowed."; + } + numProcessRecords++; + processRecord(record); + processRt.update((System.nanoTime() - startTs) / 1000L); + inputCounter.inc(); + inputTps.mark(); + } + } + + protected void processEod(EndOfData eod) { + caller2ReceiveEods.computeIfAbsent(eod.getCallOpId(), k -> new ArrayList<>()).add(eod); + boolean inputEmptyCycle = eod.isGlobalEmptyCycle; + this.isGlobalEmptyCycle &= inputEmptyCycle; + // Receive all EOD from input operators. + for (Map.Entry> entry : caller2ReceiveEods.entrySet()) { + long callerOpId = entry.getKey(); + List receiveEods = entry.getValue(); + if (hasReceivedAllEod(receiveEods)) { + LOGGER.info( + "Step op: {} task: {} received all eods. Iterations: {}", + this.getName(), + context.getTaskIndex(), + context.getIterationId()); + onReceiveAllEOD(callerOpId, receiveEods); + } + } + } + + protected boolean hasReceivedAllEod(List receiveEods) { + if (numCallQueries > 0) { + if (callQueryProxies.get(0).getCallState() == CallState.RETURNING + || callQueryProxies.get(0).getCallState() == CallState.WAITING) { + // When receiving all eod from the sub-query call, trigger the onReceiveAllEOD. + return (numTasks * numCallQueries) == receiveEods.size(); + } else if (callQueryProxies.get(0).getCallState() == CallState.CALLING + || callQueryProxies.get(0).getCallState() == CallState.INIT) { + return numReceiveEods == receiveEods.size(); + } + return false; + } else { + // For source operator, the input is empty, so if it has received eod, + // it will trigger the onReceiveAllEOD. For other operator, + // the count of eod should equal to the input size. + return numReceiveEods == receiveEods.size(); + } + } + + protected abstract void processRecord(IN record); + + protected void onReceiveAllEOD(long callerOpId, List receiveEods) { + finish(); + // Send EOD to output operators. + collectEOD(callerOpId); + receiveEods.clear(); + + if (callState == CallState.FINISH) { + LOGGER.info("step operator: {} finished, task id is: {}", getName(), context.getTaskIndex()); + this.setCallState(CallState.INIT); + } + } + + private void setCallState(CallState callState) { + this.callState = callState; + for (CallQueryProxy callQueryProxy : callQueryProxies) { + callQueryProxy.setCallState(callState); + } + } + + public void finish() { + if (callQueryProxies.size() > 0) { + switch (callState) { + case INIT: + case CALLING: + this.setCallState(CallState.WAITING); + // push call context to the stack when finish calling. + context.push(id, callContext); + break; + case WAITING: + case RETURNING: + this.setCallState(CallState.FINISH); + // pop call context from stack when all calling has returned from sub query + context.pop(id); + // reset call context after pop from the stack + callContext.reset(); + break; + default: + throw new GeaFlowDSLException("Illegal call state: {}", callState); + } + for (CallQueryProxy callQueryProxy : callQueryProxies) { + callQueryProxy.finishCall(); + } + } + function.finish((StepCollector) collector); + + LOGGER.info( + "Step op: {} task: {} finished. Iterations: {}", + this.getName(), + this.getContext().getTaskIndex(), + context.getIterationId()); + } + + protected void collect(OUT record) { + context.setInputOperatorId(id); + collector.collect(record); + outputCounter.inc(); + outputTps.mark(); + } + + @SuppressWarnings("unchecked") + protected void collectEOD(long callerOpId) { + this.isGlobalEmptyCycle &= numProcessRecords == 0L; + EndOfData eod = EndOfData.of(callerOpId, id); + eod.isGlobalEmptyCycle = isGlobalEmptyCycle; + collector.collect((OUT) eod); + this.isGlobalEmptyCycle = true; + this.numProcessRecords = 0L; + } + + @Override + public void close() {} + + @Override + public void addNextOperator(StepOperator nextOperator) { + if (nextOperators.contains(nextOperator)) { + return; + } + this.nextOperators.add(nextOperator); + } + + @Override + public List> getNextOperators() { + return this.nextOperators; + } + + @Override + public StepOperator withOutputPathSchema(PathType pathSchema) { + this.outputPathSchema = Objects.requireNonNull(pathSchema); + return this; + } + + @Override + public StepOperator withInputPathSchema(List inputPaths) { + this.inputPathSchemas = Objects.requireNonNull(inputPaths); + return this; + } + + @Override + public StepOperator withOutputType(IType outputType) { + this.outputType = Objects.requireNonNull(outputType); + if (outputType instanceof VertexType) { + this.outputRecordType = StepRecordType.VERTEX; + } else { + this.outputRecordType = StepRecordType.EDGE_GROUP; + } + return this; + } + + @Override + public StepOperator withGraphSchema(GraphSchema graph) { + this.graphSchema = Objects.requireNonNull(graph); + return this; + } + + @Override + public StepOperator withModifyGraphSchema(GraphSchema modifyGraphSchema) { + this.modifyGraphSchema = modifyGraphSchema; + return this; + } + + @Override + public List getInputPathSchemas() { + return inputPathSchemas; + } + + @Override + public PathType getOutputPathSchema() { + return outputPathSchema; + } + + @Override + public IType getOutputType() { + return outputType; + } + + @Override + public GraphSchema getGraphSchema() { + return graphSchema; + } + + @Override + public GraphSchema getModifyGraphSchema() { + if (modifyGraphSchema == null) { + return graphSchema; + } + return modifyGraphSchema; + } + + public StepOperator copy() { + return copyInternal() + .withGraphSchema(graphSchema) + .withInputPathSchema(inputPathSchemas) + .withOutputPathSchema(outputPathSchema) + .withOutputType(outputType); + } + + public FUNC getFunction() { + return function; + } + + public TraversalRuntimeContext getContext() { + return context; + } + + @Override + public int hashCode() { + return Objects.hash(id); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof AbstractStepOperator)) { + return false; + } + AbstractStepOperator that = (AbstractStepOperator) o; + return id == that.id; + } + + @Override + public List getSubQueryNames() { + return function.getCallQueryProxies().stream() + .flatMap(proxy -> proxy.getSubQueryNames().stream()) + .collect(Collectors.toList()); + } + + @Override + public String toString() { + StringBuilder str = new StringBuilder(); + str.append(getName()); + List subQueryNames = getSubQueryNames(); + if (subQueryNames.size() > 0) { + str.append("[").append(StringUtils.join(subQueryNames, ",")).append("]"); + } + return str.toString(); + } + + public abstract StepOperator copyInternal(); + + protected RowVertex alignToOutputSchema(RowVertex vertex) { + if (vertex instanceof IdOnlyVertex) { + return vertex; + } + String label = vertex.getLabel(); + VertexType inputVertexType = getModifyGraphSchema().getVertex(label); + VertexType outputVertexType = ((VertexType) getOutputType()); + return SchemaUtil.alignToVertexSchema(vertex, inputVertexType, outputVertexType); + } + + protected RowEdge alignToOutputSchema(RowEdge edge) { + String label = edge.getLabel(); + EdgeType inputEdgeType = getModifyGraphSchema().getEdge(label); + EdgeType outputEdgeType = ((EdgeType) getOutputType()); + return SchemaUtil.alignToEdgeSchema(edge, inputEdgeType, outputEdgeType); + } + + protected Row withParameter(Row row) { + Row parameters = context.getParameters(); + if (parameters != null) { + if (row instanceof Path) { + return new DefaultParameterizedPath( + (Path) row, context.getRequestId(), context.getParameters()); + } else { + return new DefaultParameterizedRow(row, context.getRequestId(), context.getParameters()); + } + } + return row; + } + + @Override + public long getId() { + return id; + } + + private String generateName() { + String className = getClass().getSimpleName(); + return className.substring(0, className.length() - "Operator".length()) + "-" + getId(); + } + + @Override + public String getName() { + return name; + } + + @Override + public StepOperator withName(String name) { + this.name = Objects.requireNonNull(name); + return this; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/FilteredFieldsOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/FilteredFieldsOperator.java index 55c427314..ec5e6a8d5 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/FilteredFieldsOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/FilteredFieldsOperator.java @@ -20,8 +20,9 @@ package org.apache.geaflow.dsl.runtime.traversal.operator; import java.util.Set; + import org.apache.calcite.rex.RexFieldAccess; public interface FilteredFieldsOperator { - StepOperator withFilteredFields(Set fields); + StepOperator withFilteredFields(Set fields); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/LabeledStepOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/LabeledStepOperator.java index 753e44443..f034b531c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/LabeledStepOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/LabeledStepOperator.java @@ -21,5 +21,5 @@ public interface LabeledStepOperator { - String getLabel(); + String getLabel(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/MatchEdgeOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/MatchEdgeOperator.java index 374f618ff..7e6c9968c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/MatchEdgeOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/MatchEdgeOperator.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.traversal.operator; import java.util.*; + import org.apache.calcite.rex.RexFieldAccess; import org.apache.geaflow.dsl.common.data.RowEdge; import org.apache.geaflow.dsl.runtime.function.graph.MatchEdgeFunction; @@ -34,117 +35,114 @@ import org.apache.geaflow.metrics.common.MetricNameFormatter; import org.apache.geaflow.metrics.common.api.Histogram; -public class MatchEdgeOperator extends AbstractStepOperator +public class MatchEdgeOperator + extends AbstractStepOperator implements FilteredFieldsOperator, LabeledStepOperator { - private Histogram loadEdgeHg; - private Histogram loadEdgeRt; + private Histogram loadEdgeHg; + private Histogram loadEdgeRt; - private final boolean isOptionMatch; + private final boolean isOptionMatch; - private Set fields; + private Set fields; - private EdgeProjectorUtil edgeProjector = null; + private EdgeProjectorUtil edgeProjector = null; - @Override - public StepOperator withFilteredFields(Set fields) { - this.fields = fields; - return this; - } + @Override + public StepOperator withFilteredFields( + Set fields) { + this.fields = fields; + return this; + } - public MatchEdgeOperator(long id, MatchEdgeFunction function) { - super(id, function); - isOptionMatch = function instanceof MatchEdgeFunctionImpl + public MatchEdgeOperator(long id, MatchEdgeFunction function) { + super(id, function); + isOptionMatch = + function instanceof MatchEdgeFunctionImpl && ((MatchEdgeFunctionImpl) function).isOptionalMatchEdge(); + } + + @Override + public void open(TraversalRuntimeContext context) { + super.open(context); + this.loadEdgeHg = metricGroup.histogram(MetricNameFormatter.loadEdgeCountRtName(getName())); + this.loadEdgeRt = metricGroup.histogram(MetricNameFormatter.loadEdgeTimeRtName(getName())); + } + + @Override + public void processRecord(VertexRecord vertex) { + long startTs = System.currentTimeMillis(); + EdgeGroup loadEdges = context.loadEdges(function.getEdgesFilter()); + loadEdgeRt.update(System.currentTimeMillis() - startTs); + loadEdges = loadEdges.map(this::alignToOutputSchema); + // filter by edge types if exists. + EdgeGroup edgeGroup = loadEdges; + if (!function.getEdgeTypes().isEmpty()) { + edgeGroup = loadEdges.filter(edge -> function.getEdgeTypes().contains(edge.getBinaryLabel())); } - - @Override - public void open(TraversalRuntimeContext context) { - super.open(context); - this.loadEdgeHg = metricGroup.histogram(MetricNameFormatter.loadEdgeCountRtName(getName())); - this.loadEdgeRt = metricGroup.histogram(MetricNameFormatter.loadEdgeTimeRtName(getName())); - } - - @Override - public void processRecord(VertexRecord vertex) { - long startTs = System.currentTimeMillis(); - EdgeGroup loadEdges = context.loadEdges(function.getEdgesFilter()); - loadEdgeRt.update(System.currentTimeMillis() - startTs); - loadEdges = loadEdges.map(this::alignToOutputSchema); - // filter by edge types if exists. - EdgeGroup edgeGroup = loadEdges; - if (!function.getEdgeTypes().isEmpty()) { - edgeGroup = loadEdges.filter(edge -> - - function.getEdgeTypes().contains(edge.getBinaryLabel())); + Map targetTreePaths = new HashMap<>(); + + // generate new paths. + if (needAddToPath) { + int numEdge = 0; + for (RowEdge edge : edgeGroup) { + if (edgeProjector == null) { + edgeProjector = new EdgeProjectorUtil(graphSchema, fields, getOutputType()); } - Map targetTreePaths = new HashMap<>(); - - // generate new paths. - if (needAddToPath) { - int numEdge = 0; - for (RowEdge edge : edgeGroup) { - if (edgeProjector == null) { - edgeProjector = new EdgeProjectorUtil( - graphSchema, - fields, - getOutputType() - ); - } - if (fields != null && !fields.isEmpty()) { - edge = edgeProjector.projectEdge(edge); - } - - // add edge to path. - if (!targetTreePaths.containsKey(edge.getTargetId())) { - ITreePath newPath = vertex.getTreePath().extendTo(edge); - targetTreePaths.put(edge.getTargetId(), newPath); - } else { - ITreePath treePath = targetTreePaths.get(edge.getTargetId()); - treePath.getEdgeSet().addEdge(edge); - } - numEdge++; - } - if (numEdge == 0 && isOptionMatch) { - ITreePath newPath = vertex.getTreePath().extendTo((RowEdge) null); - targetTreePaths.put(null, newPath); - } - loadEdgeHg.update(numEdge); - } else { - if (!vertex.isPathEmpty()) { // inherit input path. - int numEdge = 0; - for (RowEdge edge : edgeGroup) { - targetTreePaths.put(edge.getTargetId(), vertex.getTreePath()); - numEdge++; - } - if (numEdge == 0 && isOptionMatch) { - targetTreePaths.put(null, vertex.getTreePath()); - } - loadEdgeHg.update(numEdge); - } + if (fields != null && !fields.isEmpty()) { + edge = edgeProjector.projectEdge(edge); } - EdgeGroupRecord edgeGroupRecord = EdgeGroupRecord.of(edgeGroup, targetTreePaths); - collect(edgeGroupRecord); - } - @Override - public String getLabel() { - return function.getLabel(); - } - - @Override - public StepOperator copyInternal() { - return new MatchEdgeOperator(id, function); - } - - @Override - public String toString() { - StringBuilder str = new StringBuilder(); - str.append(getName()); - EdgeDirection direction = getFunction().getDirection(); - str.append("(").append(direction).append(")"); - String label = getLabel(); - str.append(" [").append(label).append("]"); - return str.toString(); + // add edge to path. + if (!targetTreePaths.containsKey(edge.getTargetId())) { + ITreePath newPath = vertex.getTreePath().extendTo(edge); + targetTreePaths.put(edge.getTargetId(), newPath); + } else { + ITreePath treePath = targetTreePaths.get(edge.getTargetId()); + treePath.getEdgeSet().addEdge(edge); + } + numEdge++; + } + if (numEdge == 0 && isOptionMatch) { + ITreePath newPath = vertex.getTreePath().extendTo((RowEdge) null); + targetTreePaths.put(null, newPath); + } + loadEdgeHg.update(numEdge); + } else { + if (!vertex.isPathEmpty()) { // inherit input path. + int numEdge = 0; + for (RowEdge edge : edgeGroup) { + targetTreePaths.put(edge.getTargetId(), vertex.getTreePath()); + numEdge++; + } + if (numEdge == 0 && isOptionMatch) { + targetTreePaths.put(null, vertex.getTreePath()); + } + loadEdgeHg.update(numEdge); + } } + EdgeGroupRecord edgeGroupRecord = EdgeGroupRecord.of(edgeGroup, targetTreePaths); + collect(edgeGroupRecord); + } + + @Override + public String getLabel() { + return function.getLabel(); + } + + @Override + public StepOperator copyInternal() { + return new MatchEdgeOperator(id, function); + } + + @Override + public String toString() { + StringBuilder str = new StringBuilder(); + str.append(getName()); + EdgeDirection direction = getFunction().getDirection(); + str.append("(").append(direction).append(")"); + String label = getLabel(); + str.append(" [").append(label).append("]"); + return str.toString(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/MatchVertexOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/MatchVertexOperator.java index 47c845358..1ae465e1b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/MatchVertexOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/MatchVertexOperator.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.traversal.operator; import java.util.*; + import org.apache.calcite.rex.RexFieldAccess; import org.apache.geaflow.dsl.common.data.*; import org.apache.geaflow.dsl.common.data.StepRecord.StepRecordType; @@ -37,154 +38,150 @@ import org.apache.geaflow.metrics.common.MetricNameFormatter; import org.apache.geaflow.metrics.common.api.Histogram; -public class MatchVertexOperator extends AbstractStepOperator implements FilteredFieldsOperator, LabeledStepOperator { - - private Histogram loadVertexRt; - - private final boolean isOptionMatch; - - private Set idSet; +public class MatchVertexOperator + extends AbstractStepOperator + implements FilteredFieldsOperator, LabeledStepOperator { - private Set fields; + private Histogram loadVertexRt; - private VertexProjectorUtil vertexProjector = null; + private final boolean isOptionMatch; - @Override - public StepOperator withFilteredFields(Set fields) { - this.fields = fields; - return this; - } - - public MatchVertexOperator(long id, MatchVertexFunction function) { - super(id, function); - if (function instanceof MatchVertexFunctionImpl) { - isOptionMatch = ((MatchVertexFunctionImpl) function).isOptionalMatchVertex(); - idSet = ((MatchVertexFunctionImpl) function).getIdSet(); - } else { - isOptionMatch = false; - } - } + private Set idSet; - @Override - public void open(TraversalRuntimeContext context) { - super.open(context); - loadVertexRt = metricGroup.histogram(MetricNameFormatter.loadVertexTimeRtName(getName())); - } + private Set fields; - @SuppressWarnings("unchecked") - @Override - protected void processRecord(StepRecord record) { - if (record.getType() == StepRecordType.VERTEX) { - processVertex((VertexRecord) record); - } else { - EdgeGroupRecord edgeGroupRecord = (EdgeGroupRecord) record; - processEdgeGroup(edgeGroupRecord); - } - } + private VertexProjectorUtil vertexProjector = null; + @Override + public StepOperator withFilteredFields(Set fields) { + this.fields = fields; + return this; + } - private void processVertex(VertexRecord vertexRecord) { - RowVertex vertex = vertexRecord.getVertex(); - if (vertex instanceof IdOnlyVertex && needLoadVertex(vertex.getId())) { - long startTs = System.currentTimeMillis(); - vertex = context.loadVertex(vertex.getId(), - function.getVertexFilter(), - graphSchema, - addingVertexFieldTypes); - loadVertexRt.update(System.currentTimeMillis() - startTs); - - if (vertexProjector == null) { - vertexProjector = new VertexProjectorUtil( - graphSchema, - fields, - addingVertexFieldNames, - addingVertexFieldTypes - ); - } - if (fields != null && !fields.isEmpty()) { - vertex = vertexProjector.projectVertex(vertex); - } - - if (vertex == null && !isOptionMatch) { - // load a non-exists vertex, just skip. - return; - } - } - - if (vertex != null) { - if (!function.getVertexTypes().isEmpty() - && !function.getVertexTypes().contains(vertex.getBinaryLabel())) { - // filter by the vertex types. - return; - } - if (!idSet.isEmpty() && !idSet.contains(vertex.getId())) { - return; - } - vertex = alignToOutputSchema(vertex); - } - - ITreePath currentPath; - if (needAddToPath) { - currentPath = vertexRecord.getTreePath().extendTo(vertex); - } else { - currentPath = vertexRecord.getTreePath(); - } - if (vertex == null) { - vertex = VertexEdgeFactory.createVertex((VertexType) getOutputType()); - } - collect(VertexRecord.of(vertex, currentPath)); + public MatchVertexOperator(long id, MatchVertexFunction function) { + super(id, function); + if (function instanceof MatchVertexFunctionImpl) { + isOptionMatch = ((MatchVertexFunctionImpl) function).isOptionalMatchVertex(); + idSet = ((MatchVertexFunctionImpl) function).getIdSet(); + } else { + isOptionMatch = false; } - - private void processEdgeGroup(EdgeGroupRecord edgeGroupRecord) { - EdgeGroup edgeGroup = edgeGroupRecord.getEdgeGroup(); - for (RowEdge edge : edgeGroup) { - Object targetId = edge.getTargetId(); - // load targetId. - RowVertex vertex = context.loadVertex(targetId, function.getVertexFilter(), graphSchema, addingVertexFieldTypes); - if (vertex != null) { - ITreePath treePath = edgeGroupRecord.getPathById(targetId); - // set current vertex. - context.setVertex(vertex); - // process new vertex. - processVertex(VertexRecord.of(vertex, treePath)); - } else if (isOptionMatch) { - vertex = VertexEdgeFactory.createVertex((VertexType) getOutputType()); - ITreePath treePath = edgeGroupRecord.getPathById(targetId); - // set current vertex. - context.setVertex(vertex); - // process new vertex. - processVertex(VertexRecord.of(null, treePath)); - } - } + } + + @Override + public void open(TraversalRuntimeContext context) { + super.open(context); + loadVertexRt = metricGroup.histogram(MetricNameFormatter.loadVertexTimeRtName(getName())); + } + + @SuppressWarnings("unchecked") + @Override + protected void processRecord(StepRecord record) { + if (record.getType() == StepRecordType.VERTEX) { + processVertex((VertexRecord) record); + } else { + EdgeGroupRecord edgeGroupRecord = (EdgeGroupRecord) record; + processEdgeGroup(edgeGroupRecord); } - - private boolean needLoadVertex(Object vertexId) { - // skip load virtual id. - return !(vertexId instanceof VirtualId); + } + + private void processVertex(VertexRecord vertexRecord) { + RowVertex vertex = vertexRecord.getVertex(); + if (vertex instanceof IdOnlyVertex && needLoadVertex(vertex.getId())) { + long startTs = System.currentTimeMillis(); + vertex = + context.loadVertex( + vertex.getId(), function.getVertexFilter(), graphSchema, addingVertexFieldTypes); + loadVertexRt.update(System.currentTimeMillis() - startTs); + + if (vertexProjector == null) { + vertexProjector = + new VertexProjectorUtil( + graphSchema, fields, addingVertexFieldNames, addingVertexFieldTypes); + } + if (fields != null && !fields.isEmpty()) { + vertex = vertexProjector.projectVertex(vertex); + } + + if (vertex == null && !isOptionMatch) { + // load a non-exists vertex, just skip. + return; + } } - @Override - public void close() { - + if (vertex != null) { + if (!function.getVertexTypes().isEmpty() + && !function.getVertexTypes().contains(vertex.getBinaryLabel())) { + // filter by the vertex types. + return; + } + if (!idSet.isEmpty() && !idSet.contains(vertex.getId())) { + return; + } + vertex = alignToOutputSchema(vertex); } - @Override - public StepOperator copyInternal() { - return new MatchVertexOperator(id, function); + ITreePath currentPath; + if (needAddToPath) { + currentPath = vertexRecord.getTreePath().extendTo(vertex); + } else { + currentPath = vertexRecord.getTreePath(); } - - @Override - public String getLabel() { - return function.getLabel(); + if (vertex == null) { + vertex = VertexEdgeFactory.createVertex((VertexType) getOutputType()); } - - @Override - public String toString() { - StringBuilder str = new StringBuilder(); - str.append(getName()); - String label = getLabel(); - str.append(" [").append(label).append("]"); - return str.toString(); + collect(VertexRecord.of(vertex, currentPath)); + } + + private void processEdgeGroup(EdgeGroupRecord edgeGroupRecord) { + EdgeGroup edgeGroup = edgeGroupRecord.getEdgeGroup(); + for (RowEdge edge : edgeGroup) { + Object targetId = edge.getTargetId(); + // load targetId. + RowVertex vertex = + context.loadVertex( + targetId, function.getVertexFilter(), graphSchema, addingVertexFieldTypes); + if (vertex != null) { + ITreePath treePath = edgeGroupRecord.getPathById(targetId); + // set current vertex. + context.setVertex(vertex); + // process new vertex. + processVertex(VertexRecord.of(vertex, treePath)); + } else if (isOptionMatch) { + vertex = VertexEdgeFactory.createVertex((VertexType) getOutputType()); + ITreePath treePath = edgeGroupRecord.getPathById(targetId); + // set current vertex. + context.setVertex(vertex); + // process new vertex. + processVertex(VertexRecord.of(null, treePath)); + } } + } + + private boolean needLoadVertex(Object vertexId) { + // skip load virtual id. + return !(vertexId instanceof VirtualId); + } + + @Override + public void close() {} + + @Override + public StepOperator copyInternal() { + return new MatchVertexOperator(id, function); + } + + @Override + public String getLabel() { + return function.getLabel(); + } + + @Override + public String toString() { + StringBuilder str = new StringBuilder(); + str.append(getName()); + String label = getLabel(); + str.append(" [").append(label).append("]"); + return str.toString(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/MatchVirtualEdgeOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/MatchVirtualEdgeOperator.java index 039576717..6542ab8bf 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/MatchVirtualEdgeOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/MatchVirtualEdgeOperator.java @@ -23,6 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.dsl.common.data.RowEdge; import org.apache.geaflow.dsl.common.data.impl.types.ObjectEdge; import org.apache.geaflow.dsl.runtime.function.graph.MatchVirtualEdgeFunction; @@ -31,39 +32,39 @@ import org.apache.geaflow.dsl.runtime.traversal.data.VertexRecord; import org.apache.geaflow.dsl.runtime.traversal.path.ITreePath; -public class MatchVirtualEdgeOperator extends AbstractStepOperator { +public class MatchVirtualEdgeOperator + extends AbstractStepOperator { - public MatchVirtualEdgeOperator(long id, MatchVirtualEdgeFunction function) { - super(id, function); - } + public MatchVirtualEdgeOperator(long id, MatchVirtualEdgeFunction function) { + super(id, function); + } - @Override - protected void processRecord(VertexRecord vertexRecord) { - ITreePath treePath = vertexRecord.getTreePath(); - if (treePath != null) { - List targetIds = treePath.flatMap(function::computeTargetId); - if (targetIds != null) { - Map targetIdPaths = new HashMap<>(); - List edges = new ArrayList<>(targetIds.size()); - for (Object targetId : targetIds) { - RowEdge edge = new ObjectEdge(vertexRecord.getVertex().getId(), targetId); - edges.add(edge); + @Override + protected void processRecord(VertexRecord vertexRecord) { + ITreePath treePath = vertexRecord.getTreePath(); + if (treePath != null) { + List targetIds = treePath.flatMap(function::computeTargetId); + if (targetIds != null) { + Map targetIdPaths = new HashMap<>(); + List edges = new ArrayList<>(targetIds.size()); + for (Object targetId : targetIds) { + RowEdge edge = new ObjectEdge(vertexRecord.getVertex().getId(), targetId); + edges.add(edge); - ITreePath targetPath = function.computeTargetPath(targetId, treePath); - if (targetPath != null && !treePath.isEmpty()) { - targetIdPaths.put(targetId, targetPath); - } - EdgeGroup edgeGroup = EdgeGroup.of(edges); - EdgeGroupRecord edgeGroupRecord = EdgeGroupRecord.of(edgeGroup, targetIdPaths); - collect(edgeGroupRecord); - } - } + ITreePath targetPath = function.computeTargetPath(targetId, treePath); + if (targetPath != null && !treePath.isEmpty()) { + targetIdPaths.put(targetId, targetPath); + } + EdgeGroup edgeGroup = EdgeGroup.of(edges); + EdgeGroupRecord edgeGroupRecord = EdgeGroupRecord.of(edgeGroup, targetIdPaths); + collect(edgeGroupRecord); } + } } + } - @Override - public StepOperator copyInternal() { - return new MatchVirtualEdgeOperator(id, function); - } + @Override + public StepOperator copyInternal() { + return new MatchVirtualEdgeOperator(id, function); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepDistinctOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepDistinctOperator.java index 843563b0b..d3bcb95e8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepDistinctOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepDistinctOperator.java @@ -21,54 +21,59 @@ import java.util.HashSet; import java.util.Set; + import org.apache.geaflow.dsl.common.data.RowKey; import org.apache.geaflow.dsl.runtime.function.graph.StepKeyFunction; import org.apache.geaflow.dsl.runtime.traversal.TraversalRuntimeContext; import org.apache.geaflow.dsl.runtime.traversal.data.StepRecordWithPath; -public class StepDistinctOperator extends AbstractStepOperator { +public class StepDistinctOperator + extends AbstractStepOperator { - private final Set distinctKeys = new HashSet<>(); + private final Set distinctKeys = new HashSet<>(); - private int[] refPathIndices; + private int[] refPathIndices; - public StepDistinctOperator(long id, StepKeyFunction function) { - super(id, function); - } + public StepDistinctOperator(long id, StepKeyFunction function) { + super(id, function); + } - @Override - public void open(TraversalRuntimeContext context) { - super.open(context); - assert inputPathSchemas.size() == 1; - refPathIndices = new int[inputPathSchemas.get(0).size()]; - // refer all the input path fields - for (int i = 0; i < refPathIndices.length; i++) { - refPathIndices[i] = i; - } + @Override + public void open(TraversalRuntimeContext context) { + super.open(context); + assert inputPathSchemas.size() == 1; + refPathIndices = new int[inputPathSchemas.get(0).size()]; + // refer all the input path fields + for (int i = 0; i < refPathIndices.length; i++) { + refPathIndices[i] = i; } + } - @Override - protected void processRecord(StepRecordWithPath record) { - StepRecordWithPath distinctRecord = record.filter(path -> { - RowKey key = function.getKey(path); - if (!distinctKeys.contains(key)) { + @Override + protected void processRecord(StepRecordWithPath record) { + StepRecordWithPath distinctRecord = + record.filter( + path -> { + RowKey key = function.getKey(path); + if (!distinctKeys.contains(key)) { distinctKeys.add(key); return true; - } - return false; - }, refPathIndices); + } + return false; + }, + refPathIndices); - collect(distinctRecord); - } + collect(distinctRecord); + } - @Override - public StepOperator copyInternal() { - return new StepDistinctOperator(id, function); - } + @Override + public StepOperator copyInternal() { + return new StepDistinctOperator(id, function); + } - @Override - public void finish() { - distinctKeys.clear(); - super.finish(); - } + @Override + public void finish() { + distinctKeys.clear(); + super.finish(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepEndOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepEndOperator.java index 1a4571e9b..632e9235f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepEndOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepEndOperator.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.dsl.common.data.StepRecord; import org.apache.geaflow.dsl.runtime.expression.Expression; import org.apache.geaflow.dsl.runtime.function.graph.FunctionSchemas; @@ -30,43 +31,39 @@ public class StepEndOperator extends AbstractStepOperator { - public StepEndOperator(long id) { - super(id, EndStepFunction.INSTANCE); - } + public StepEndOperator(long id) { + super(id, EndStepFunction.INSTANCE); + } - @Override - protected void processRecord(StepRecord record) { - collect(record); - } - - @Override - public StepOperator copyInternal() { - return new StepEndOperator(id); - } + @Override + protected void processRecord(StepRecord record) { + collect(record); + } - private static class EndStepFunction implements StepFunction { + @Override + public StepOperator copyInternal() { + return new StepEndOperator(id); + } - public static final StepFunction INSTANCE = new EndStepFunction(); + private static class EndStepFunction implements StepFunction { - @Override - public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { + public static final StepFunction INSTANCE = new EndStepFunction(); - } - - @Override - public void finish(StepCollector collector) { + @Override + public void open(TraversalRuntimeContext context, FunctionSchemas schemas) {} - } + @Override + public void finish(StepCollector collector) {} - @Override - public List getExpressions() { - return Collections.emptyList(); - } + @Override + public List getExpressions() { + return Collections.emptyList(); + } - @Override - public StepFunction copy(List expressions) { - assert expressions.isEmpty(); - return new EndStepFunction(); - } + @Override + public StepFunction copy(List expressions) { + assert expressions.isEmpty(); + return new EndStepFunction(); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepExchangeOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepExchangeOperator.java index 601c51984..02e21c95f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepExchangeOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepExchangeOperator.java @@ -19,8 +19,8 @@ package org.apache.geaflow.dsl.runtime.traversal.operator; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.RowKey; @@ -37,39 +37,41 @@ import org.apache.geaflow.dsl.runtime.traversal.path.ITreePath; import org.apache.geaflow.dsl.runtime.traversal.path.TreePaths; -public class StepExchangeOperator extends AbstractStepOperator { +import com.google.common.collect.Lists; +public class StepExchangeOperator + extends AbstractStepOperator { - public StepExchangeOperator(long id, StepKeyFunction function) { - super(id, function); - } + public StepExchangeOperator(long id, StepKeyFunction function) { + super(id, function); + } - @Override - protected void processRecord(StepRecord record) { - if (record.getType() == StepRecordType.ROW) { - Row row = (Row) record; - RowKey key = function.getKey(row); - if (context.getRequest() != null) { // append requestId to the key. - key = new DefaultRowKeyWithRequestId(context.getRequestId(), key); - } - StepKeyRecord keyRecord = new StepKeyRecordImpl(key, row); - collect(keyRecord); - } else { - StepRecordWithPath recordWithPath = (StepRecordWithPath) record; - for (ITreePath treePath : recordWithPath.getPaths()) { - List paths = treePath.toList(); - for (Path path : paths) { - RowKey key = function.getKey(path); - RowVertex virtualVertex = IdOnlyVertex.of(key); - ITreePath virtualVertexPath = TreePaths.createTreePath(Lists.newArrayList(path)); - collect(VertexRecord.of(virtualVertex, virtualVertexPath)); - } - } + @Override + protected void processRecord(StepRecord record) { + if (record.getType() == StepRecordType.ROW) { + Row row = (Row) record; + RowKey key = function.getKey(row); + if (context.getRequest() != null) { // append requestId to the key. + key = new DefaultRowKeyWithRequestId(context.getRequestId(), key); + } + StepKeyRecord keyRecord = new StepKeyRecordImpl(key, row); + collect(keyRecord); + } else { + StepRecordWithPath recordWithPath = (StepRecordWithPath) record; + for (ITreePath treePath : recordWithPath.getPaths()) { + List paths = treePath.toList(); + for (Path path : paths) { + RowKey key = function.getKey(path); + RowVertex virtualVertex = IdOnlyVertex.of(key); + ITreePath virtualVertexPath = TreePaths.createTreePath(Lists.newArrayList(path)); + collect(VertexRecord.of(virtualVertex, virtualVertexPath)); } + } } + } - @Override - public StepOperator copyInternal() { - return new StepExchangeOperator(id, function); - } + @Override + public StepOperator copyInternal() { + return new StepExchangeOperator(id, function); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepFilterOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepFilterOperator.java index df504b52b..686ba7d73 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepFilterOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepFilterOperator.java @@ -23,25 +23,27 @@ import org.apache.geaflow.dsl.runtime.traversal.data.StepRecordWithPath; import org.apache.geaflow.dsl.runtime.util.StepFunctionUtil; -public class StepFilterOperator extends AbstractStepOperator { +public class StepFilterOperator + extends AbstractStepOperator { - private final int[] refPathIndices; + private final int[] refPathIndices; - public StepFilterOperator(long id, StepBoolFunction function) { - super(id, function); - this.refPathIndices = StepFunctionUtil.getRefPathIndices(function); - } + public StepFilterOperator(long id, StepBoolFunction function) { + super(id, function); + this.refPathIndices = StepFunctionUtil.getRefPathIndices(function); + } - @Override - public void processRecord(StepRecordWithPath record) { - StepRecordWithPath filterRecord = record.filter(path -> function.filter(withParameter(path)), refPathIndices); - if (!filterRecord.isPathEmpty()) { - collect(filterRecord); - } + @Override + public void processRecord(StepRecordWithPath record) { + StepRecordWithPath filterRecord = + record.filter(path -> function.filter(withParameter(path)), refPathIndices); + if (!filterRecord.isPathEmpty()) { + collect(filterRecord); } + } - @Override - public StepOperator copyInternal() { - return new StepFilterOperator(id, function); - } + @Override + public StepOperator copyInternal() { + return new StepFilterOperator(id, function); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepGlobalAggregateOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepGlobalAggregateOperator.java index 8990c462d..9f0d73276 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepGlobalAggregateOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepGlobalAggregateOperator.java @@ -23,6 +23,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Objects; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.Row; @@ -47,135 +48,148 @@ import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; -public class StepGlobalAggregateOperator extends AbstractStepOperator { +public class StepGlobalAggregateOperator + extends AbstractStepOperator { - private final StepKeyFunction groupByFunction; - private Map> requestId2Path; - private Map> requestId2Accumulators; - private final int[] pathPruneIndices; - private final IType[] pathPruneTypes; - private final IType[] inputPathTypes; - boolean isPathHeadVertexType; - boolean isPathTailVertexType; - private final IType[] aggregateNodeTypes; - private final IType[] aggOutputTypes; + private final StepKeyFunction groupByFunction; + private Map> requestId2Path; + private Map> requestId2Accumulators; + private final int[] pathPruneIndices; + private final IType[] pathPruneTypes; + private final IType[] inputPathTypes; + boolean isPathHeadVertexType; + boolean isPathTailVertexType; + private final IType[] aggregateNodeTypes; + private final IType[] aggOutputTypes; - public StepGlobalAggregateOperator(long id, StepKeyFunction keyFunction, - StepAggregateFunction function) { - super(id, function); - this.groupByFunction = Objects.requireNonNull(keyFunction); - this.pathPruneIndices = ((StepAggExpressionFunctionImpl) function).getPathPruneIndices(); - this.pathPruneTypes = ((StepAggExpressionFunctionImpl) function).getPathPruneTypes(); - this.inputPathTypes = ((StepAggExpressionFunctionImpl) function).getInputPathTypes(); - this.aggOutputTypes = ((StepAggExpressionFunctionImpl) function).getAggOutputTypes(); - assert pathPruneIndices.length > 0 && pathPruneTypes.length == pathPruneIndices.length; - this.isPathHeadVertexType = pathPruneTypes[0] instanceof VertexType; - this.isPathTailVertexType = pathPruneTypes[pathPruneTypes.length - 1] instanceof VertexType; - this.aggregateNodeTypes = isPathHeadVertexType + public StepGlobalAggregateOperator( + long id, StepKeyFunction keyFunction, StepAggregateFunction function) { + super(id, function); + this.groupByFunction = Objects.requireNonNull(keyFunction); + this.pathPruneIndices = ((StepAggExpressionFunctionImpl) function).getPathPruneIndices(); + this.pathPruneTypes = ((StepAggExpressionFunctionImpl) function).getPathPruneTypes(); + this.inputPathTypes = ((StepAggExpressionFunctionImpl) function).getInputPathTypes(); + this.aggOutputTypes = ((StepAggExpressionFunctionImpl) function).getAggOutputTypes(); + assert pathPruneIndices.length > 0 && pathPruneTypes.length == pathPruneIndices.length; + this.isPathHeadVertexType = pathPruneTypes[0] instanceof VertexType; + this.isPathTailVertexType = pathPruneTypes[pathPruneTypes.length - 1] instanceof VertexType; + this.aggregateNodeTypes = + isPathHeadVertexType ? ((VertexType) inputPathTypes[pathPruneIndices[0]]).getValueTypes() : ((EdgeType) inputPathTypes[pathPruneIndices[0]]).getValueTypes(); - } + } - @Override - public void open(TraversalRuntimeContext context) { - super.open(context); - requestId2Path = new HashMap<>(); - requestId2Accumulators = new HashMap<>(); - } + @Override + public void open(TraversalRuntimeContext context) { + super.open(context); + requestId2Path = new HashMap<>(); + requestId2Accumulators = new HashMap<>(); + } - @Override - protected void processRecord(VertexRecord record) { - ParameterRequest request = context.getRequest(); - if (!requestId2Accumulators.containsKey(request)) { - requestId2Path.put(request, new HashMap<>()); - requestId2Accumulators.put(request, new HashMap<>()); - } - record.mapPath(path -> { - RowKey key = groupByFunction.getKey(path); - Map key2Path = requestId2Path.get(request); - Map key2Acc = requestId2Accumulators.get(request); - Path prunedPath = path.subPath(pathPruneIndices); - Object partAccumulator; - if (isPathHeadVertexType) { - partAccumulator = ((IVertex) prunedPath.getPathNodes().get(0)) - .getValue().getField(aggregateNodeTypes.length, ObjectType.INSTANCE); - } else { - partAccumulator = ((IEdge) prunedPath.getPathNodes().get(0)) - .getValue().getField(aggregateNodeTypes.length, ObjectType.INSTANCE); - } - if (!key2Acc.containsKey(key)) { - key2Acc.put(key, partAccumulator); - key2Path.put(key, prunedPath); - } else { - Object accumulator = key2Acc.get(key); - function.merge(accumulator, partAccumulator); - key2Acc.put(key, accumulator); - } - return path; - }, null); + @Override + protected void processRecord(VertexRecord record) { + ParameterRequest request = context.getRequest(); + if (!requestId2Accumulators.containsKey(request)) { + requestId2Path.put(request, new HashMap<>()); + requestId2Accumulators.put(request, new HashMap<>()); } + record.mapPath( + path -> { + RowKey key = groupByFunction.getKey(path); + Map key2Path = requestId2Path.get(request); + Map key2Acc = requestId2Accumulators.get(request); + Path prunedPath = path.subPath(pathPruneIndices); + Object partAccumulator; + if (isPathHeadVertexType) { + partAccumulator = + ((IVertex) prunedPath.getPathNodes().get(0)) + .getValue() + .getField(aggregateNodeTypes.length, ObjectType.INSTANCE); + } else { + partAccumulator = + ((IEdge) prunedPath.getPathNodes().get(0)) + .getValue() + .getField(aggregateNodeTypes.length, ObjectType.INSTANCE); + } + if (!key2Acc.containsKey(key)) { + key2Acc.put(key, partAccumulator); + key2Path.put(key, prunedPath); + } else { + Object accumulator = key2Acc.get(key); + function.merge(accumulator, partAccumulator); + key2Acc.put(key, accumulator); + } + return path; + }, + null); + } - @Override - public void finish() { - for (Map.Entry> entry : requestId2Accumulators.entrySet()) { - ParameterRequest request = entry.getKey(); - context.setRequest(request); - Map key2Acc = entry.getValue(); - Map key2Path = requestId2Path.get(request); - for (Entry rowKeyObjectEntry : key2Acc.entrySet()) { - RowKey rowKey = rowKeyObjectEntry.getKey(); - Path path = (Path) key2Path.get(rowKey); - Row[] values = new Row[pathPruneIndices.length]; - for (int i = 0; i < pathPruneIndices.length; i++) { - values[i] = path.getField(i, inputPathTypes[pathPruneIndices[i]]); - } - Row aggregateNodeValue; - int aggregateNodeIndex = 0; - if (isPathHeadVertexType) { - aggregateNodeValue = ((IVertex) values[aggregateNodeIndex]).getValue(); - } else { - aggregateNodeValue = ((IEdge) values[aggregateNodeIndex]).getValue(); - } - Object[] aggregateNodeValues = new Object[aggregateNodeTypes.length - + rowKey.getKeys().length + aggOutputTypes.length]; - for (int i = 0; i < aggregateNodeTypes.length; i++) { - aggregateNodeValues[i] = aggregateNodeValue.getField(i, aggregateNodeTypes[i]); - } - int offset = aggregateNodeTypes.length; - for (int j = 0; j < rowKey.getKeys().length; j++) { - aggregateNodeValues[offset + j] = rowKey.getKeys()[j]; - } - offset = aggregateNodeTypes.length + rowKey.getKeys().length; - Object accumulator = rowKeyObjectEntry.getValue(); - ObjectRow accumulatorValues = (ObjectRow) function.getValue(accumulator).getValue(ObjectType.INSTANCE); - for (int j = 0; j < aggOutputTypes.length; j++) { - aggregateNodeValues[offset + j] = accumulatorValues.getField(j, aggOutputTypes[j]); - } - if (isPathHeadVertexType) { - values[aggregateNodeIndex] = (Row) ((IVertex) values[aggregateNodeIndex]) - .withValue(ObjectRow.create(aggregateNodeValues)); - } else { - values[aggregateNodeIndex] = (Row) ((IEdge) values[aggregateNodeIndex]) - .withValue(ObjectRow.create(aggregateNodeValues)); - } - ITreePath globalAggPath = TreePaths.singletonPath(new DefaultPath(values)); - if (isPathTailVertexType) { - collect(VertexRecord.of(IdOnlyVertex.of(globalAggPath.getVertexId()), globalAggPath)); - } else { - Map targetId2TreePaths = new HashMap<>(); - targetId2TreePaths.put(globalAggPath.getEdgeSet().getTargetId(), globalAggPath); - collect(EdgeGroupRecord.of(EdgeGroup.of(globalAggPath.getEdgeSet()), - targetId2TreePaths)); - } - } + @Override + public void finish() { + for (Map.Entry> entry : + requestId2Accumulators.entrySet()) { + ParameterRequest request = entry.getKey(); + context.setRequest(request); + Map key2Acc = entry.getValue(); + Map key2Path = requestId2Path.get(request); + for (Entry rowKeyObjectEntry : key2Acc.entrySet()) { + RowKey rowKey = rowKeyObjectEntry.getKey(); + Path path = (Path) key2Path.get(rowKey); + Row[] values = new Row[pathPruneIndices.length]; + for (int i = 0; i < pathPruneIndices.length; i++) { + values[i] = path.getField(i, inputPathTypes[pathPruneIndices[i]]); + } + Row aggregateNodeValue; + int aggregateNodeIndex = 0; + if (isPathHeadVertexType) { + aggregateNodeValue = ((IVertex) values[aggregateNodeIndex]).getValue(); + } else { + aggregateNodeValue = ((IEdge) values[aggregateNodeIndex]).getValue(); + } + Object[] aggregateNodeValues = + new Object[aggregateNodeTypes.length + rowKey.getKeys().length + aggOutputTypes.length]; + for (int i = 0; i < aggregateNodeTypes.length; i++) { + aggregateNodeValues[i] = aggregateNodeValue.getField(i, aggregateNodeTypes[i]); } - requestId2Accumulators.clear(); - requestId2Path.clear(); - super.finish(); + int offset = aggregateNodeTypes.length; + for (int j = 0; j < rowKey.getKeys().length; j++) { + aggregateNodeValues[offset + j] = rowKey.getKeys()[j]; + } + offset = aggregateNodeTypes.length + rowKey.getKeys().length; + Object accumulator = rowKeyObjectEntry.getValue(); + ObjectRow accumulatorValues = + (ObjectRow) function.getValue(accumulator).getValue(ObjectType.INSTANCE); + for (int j = 0; j < aggOutputTypes.length; j++) { + aggregateNodeValues[offset + j] = accumulatorValues.getField(j, aggOutputTypes[j]); + } + if (isPathHeadVertexType) { + values[aggregateNodeIndex] = + (Row) + ((IVertex) values[aggregateNodeIndex]) + .withValue(ObjectRow.create(aggregateNodeValues)); + } else { + values[aggregateNodeIndex] = + (Row) + ((IEdge) values[aggregateNodeIndex]) + .withValue(ObjectRow.create(aggregateNodeValues)); + } + ITreePath globalAggPath = TreePaths.singletonPath(new DefaultPath(values)); + if (isPathTailVertexType) { + collect(VertexRecord.of(IdOnlyVertex.of(globalAggPath.getVertexId()), globalAggPath)); + } else { + Map targetId2TreePaths = new HashMap<>(); + targetId2TreePaths.put(globalAggPath.getEdgeSet().getTargetId(), globalAggPath); + collect(EdgeGroupRecord.of(EdgeGroup.of(globalAggPath.getEdgeSet()), targetId2TreePaths)); + } + } } + requestId2Accumulators.clear(); + requestId2Path.clear(); + super.finish(); + } - @Override - public StepOperator copyInternal() { - return new StepGlobalAggregateOperator(id, groupByFunction, function); - } + @Override + public StepOperator copyInternal() { + return new StepGlobalAggregateOperator(id, groupByFunction, function); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepGlobalSingleValueAggregateOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepGlobalSingleValueAggregateOperator.java index 6f8cffbb1..9e5eb5c8e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepGlobalSingleValueAggregateOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepGlobalSingleValueAggregateOperator.java @@ -21,6 +21,7 @@ import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.runtime.function.graph.StepAggregateFunction; @@ -31,51 +32,53 @@ import org.apache.geaflow.dsl.runtime.traversal.message.KeyGroupMessage; import org.apache.geaflow.dsl.runtime.traversal.message.MessageType; -public class StepGlobalSingleValueAggregateOperator extends AbstractStepOperator { +public class StepGlobalSingleValueAggregateOperator + extends AbstractStepOperator { - private final IType inputType; + private final IType inputType; - private Map requestId2Accumulators; + private Map requestId2Accumulators; - public StepGlobalSingleValueAggregateOperator(long id, IType inputType, StepAggregateFunction function) { - super(id, function); - this.inputType = inputType; - } + public StepGlobalSingleValueAggregateOperator( + long id, IType inputType, StepAggregateFunction function) { + super(id, function); + this.inputType = inputType; + } - @Override - public void open(TraversalRuntimeContext context) { - super.open(context); - requestId2Accumulators = new HashMap<>(); - } + @Override + public void open(TraversalRuntimeContext context) { + super.open(context); + requestId2Accumulators = new HashMap<>(); + } - @Override - protected void processRecord(VertexRecord record) { - KeyGroupMessage groupMessage = context.getMessage(MessageType.KEY_GROUP); - ParameterRequest request = context.getRequest(); - Object accumulator = requestId2Accumulators.computeIfAbsent(request, - r -> function.createAccumulator()); + @Override + protected void processRecord(VertexRecord record) { + KeyGroupMessage groupMessage = context.getMessage(MessageType.KEY_GROUP); + ParameterRequest request = context.getRequest(); + Object accumulator = + requestId2Accumulators.computeIfAbsent(request, r -> function.createAccumulator()); - for (Row row : groupMessage.getGroupRows()) { - Object inputAcc = row.getField(0, inputType); - function.merge(accumulator, inputAcc); - } + for (Row row : groupMessage.getGroupRows()) { + Object inputAcc = row.getField(0, inputType); + function.merge(accumulator, inputAcc); } + } - @Override - public void finish() { - for (Map.Entry entry : requestId2Accumulators.entrySet()) { - ParameterRequest request = entry.getKey(); - Object accumulator = entry.getValue(); - SingleValue value = function.getValue(accumulator); - context.setRequest(request); - collect(value); - } - requestId2Accumulators.clear(); - super.finish(); + @Override + public void finish() { + for (Map.Entry entry : requestId2Accumulators.entrySet()) { + ParameterRequest request = entry.getKey(); + Object accumulator = entry.getValue(); + SingleValue value = function.getValue(accumulator); + context.setRequest(request); + collect(value); } + requestId2Accumulators.clear(); + super.finish(); + } - @Override - public StepOperator copyInternal() { - return new StepGlobalSingleValueAggregateOperator(id, inputType, function); - } + @Override + public StepOperator copyInternal() { + return new StepGlobalSingleValueAggregateOperator(id, inputType, function); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepGlobalSortOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepGlobalSortOperator.java index 69274d327..895092d2f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepGlobalSortOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepGlobalSortOperator.java @@ -28,32 +28,32 @@ import org.apache.geaflow.dsl.runtime.traversal.message.KeyGroupMessage; import org.apache.geaflow.dsl.runtime.traversal.message.MessageType; -public class StepGlobalSortOperator extends AbstractStepOperator { - - private final IType rowVertexType; - private final IType inputType; - - public StepGlobalSortOperator(long id, StepSortFunction function, IType rowVertexType, - IType inputType) { - super(id, function); - this.rowVertexType = rowVertexType; - this.inputType = inputType; - } - - @Override - protected void processRecord(VertexRecord record) { - KeyGroupMessage groupMessage = context.getMessage(MessageType.KEY_GROUP); - for (Row row : groupMessage.getGroupRows()) { - RowVertex head = (RowVertex) row.getField(0, rowVertexType); - Path path = (Path) row.getField(1, inputType); - - function.process(head, path); - } +public class StepGlobalSortOperator + extends AbstractStepOperator { + + private final IType rowVertexType; + private final IType inputType; + + public StepGlobalSortOperator( + long id, StepSortFunction function, IType rowVertexType, IType inputType) { + super(id, function); + this.rowVertexType = rowVertexType; + this.inputType = inputType; + } + + @Override + protected void processRecord(VertexRecord record) { + KeyGroupMessage groupMessage = context.getMessage(MessageType.KEY_GROUP); + for (Row row : groupMessage.getGroupRows()) { + RowVertex head = (RowVertex) row.getField(0, rowVertexType); + Path path = (Path) row.getField(1, inputType); + + function.process(head, path); } + } - @Override - public StepOperator copyInternal() { - return new StepGlobalSortOperator(id, function, rowVertexType, inputType); - } + @Override + public StepOperator copyInternal() { + return new StepGlobalSortOperator(id, function, rowVertexType, inputType); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepJoinOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepJoinOperator.java index 915236419..a1d9df8b4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepJoinOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepJoinOperator.java @@ -26,6 +26,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; + import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.RowKey; import org.apache.geaflow.dsl.common.data.RowVertex; @@ -40,141 +41,150 @@ import org.apache.geaflow.dsl.runtime.traversal.path.TreePaths; import org.apache.geaflow.dsl.runtime.util.SchemaUtil; -public class StepJoinOperator extends AbstractStepOperator { - - private final PathType inputJoinPathSchema; - - private final List inputPathSchemas; - - // requestId -> jonKey -> (leftTreePath, rightTreePath) - private Map> cachedLeftAndRightTreePaths; - - private long leftInputOpId; - - private long rightInputOpId; - - private boolean isLocalJoin; - - public StepJoinOperator(long id, StepJoinFunction function, PathType inputJoinPathSchema, - List inputPathSchemas, boolean isLocalJoin) { - super(id, function); - this.inputJoinPathSchema = Objects.requireNonNull(inputJoinPathSchema); - this.inputPathSchemas = Objects.requireNonNull(inputPathSchemas); - assert inputPathSchemas.size() == 2; - this.isLocalJoin = isLocalJoin; - } - - @Override - public void open(TraversalRuntimeContext context) { - super.open(context); - this.cachedLeftAndRightTreePaths = new HashMap<>(); - this.leftInputOpId = context.getTopology().getInputIds(id).get(0); - this.rightInputOpId = context.getTopology().getInputIds(id).get(1); - } - - @Override - protected void processRecord(VertexRecord record) { - RowVertex vertex = record.getVertex(); - RowKey joinKey = (RowKey) vertex.getId(); - ITreePath leftTree; - ITreePath rightTree; - if (isLocalJoin) { - leftTree = context.getInputOperatorId() == leftInputOpId ? record.getTreePath() : null; - rightTree = context.getInputOperatorId() == rightInputOpId ? record.getTreePath() : null; - } else { - JoinPathMessage pathMessage = context.getMessage(MessageType.JOIN_PATH); - leftTree = pathMessage.getTreePath(leftInputOpId); - rightTree = pathMessage.getTreePath(rightInputOpId); - } - cachedLeftAndRightTreePaths - .computeIfAbsent(context.getRequestId(), k -> new HashMap<>()) - .computeIfAbsent(joinKey, k -> new JoinTree()) - .addLeftTree(leftTree) - .addRightTree(rightTree); +public class StepJoinOperator + extends AbstractStepOperator { + + private final PathType inputJoinPathSchema; + + private final List inputPathSchemas; + + // requestId -> jonKey -> (leftTreePath, rightTreePath) + private Map> cachedLeftAndRightTreePaths; + + private long leftInputOpId; + + private long rightInputOpId; + + private boolean isLocalJoin; + + public StepJoinOperator( + long id, + StepJoinFunction function, + PathType inputJoinPathSchema, + List inputPathSchemas, + boolean isLocalJoin) { + super(id, function); + this.inputJoinPathSchema = Objects.requireNonNull(inputJoinPathSchema); + this.inputPathSchemas = Objects.requireNonNull(inputPathSchemas); + assert inputPathSchemas.size() == 2; + this.isLocalJoin = isLocalJoin; + } + + @Override + public void open(TraversalRuntimeContext context) { + super.open(context); + this.cachedLeftAndRightTreePaths = new HashMap<>(); + this.leftInputOpId = context.getTopology().getInputIds(id).get(0); + this.rightInputOpId = context.getTopology().getInputIds(id).get(1); + } + + @Override + protected void processRecord(VertexRecord record) { + RowVertex vertex = record.getVertex(); + RowKey joinKey = (RowKey) vertex.getId(); + ITreePath leftTree; + ITreePath rightTree; + if (isLocalJoin) { + leftTree = context.getInputOperatorId() == leftInputOpId ? record.getTreePath() : null; + rightTree = context.getInputOperatorId() == rightInputOpId ? record.getTreePath() : null; + } else { + JoinPathMessage pathMessage = context.getMessage(MessageType.JOIN_PATH); + leftTree = pathMessage.getTreePath(leftInputOpId); + rightTree = pathMessage.getTreePath(rightInputOpId); } - - @Override - public void finish() { - Set requestIds = cachedLeftAndRightTreePaths.keySet(); - for (Object requestId : requestIds) { - Map joinTreeMap = cachedLeftAndRightTreePaths.get(requestId); - - for (Map.Entry entry : joinTreeMap.entrySet()) { - RowKey joinKey = entry.getKey(); - JoinTree joinTree = entry.getValue(); - - List joinPaths = new ArrayList<>(); - List leftPaths = - joinTree.leftTree == null ? Collections.singletonList(null) : joinTree.leftTree.toList(); - List rightPaths = - joinTree.rightTree == null ? Collections.singletonList(null) : joinTree.rightTree.toList(); - - for (Path leftPath : leftPaths) { - int numJoinEdge = 0; - for (Path rightPath : rightPaths) { - Path joinPath = function.join(leftPath, rightPath); - if (joinPath != null) { - numJoinEdge++; - joinPaths.add(joinPath); - } - } - if (numJoinEdge == 0) { - switch (function.getJoinType()) { - case LEFT: - joinPaths.add(SchemaUtil.alignToPathSchema(leftPath, - inputPathSchemas.get(0), getOutputPathSchema())); - break; - default: - } - } - } - for (Path joinPath : joinPaths) { - ITreePath joinTreePath = TreePaths.singletonPath(joinPath); - collect(VertexRecord.of(IdOnlyVertex.of(joinKey), joinTreePath)); - } + cachedLeftAndRightTreePaths + .computeIfAbsent(context.getRequestId(), k -> new HashMap<>()) + .computeIfAbsent(joinKey, k -> new JoinTree()) + .addLeftTree(leftTree) + .addRightTree(rightTree); + } + + @Override + public void finish() { + Set requestIds = cachedLeftAndRightTreePaths.keySet(); + for (Object requestId : requestIds) { + Map joinTreeMap = cachedLeftAndRightTreePaths.get(requestId); + + for (Map.Entry entry : joinTreeMap.entrySet()) { + RowKey joinKey = entry.getKey(); + JoinTree joinTree = entry.getValue(); + + List joinPaths = new ArrayList<>(); + List leftPaths = + joinTree.leftTree == null + ? Collections.singletonList(null) + : joinTree.leftTree.toList(); + List rightPaths = + joinTree.rightTree == null + ? Collections.singletonList(null) + : joinTree.rightTree.toList(); + + for (Path leftPath : leftPaths) { + int numJoinEdge = 0; + for (Path rightPath : rightPaths) { + Path joinPath = function.join(leftPath, rightPath); + if (joinPath != null) { + numJoinEdge++; + joinPaths.add(joinPath); } + } + if (numJoinEdge == 0) { + switch (function.getJoinType()) { + case LEFT: + joinPaths.add( + SchemaUtil.alignToPathSchema( + leftPath, inputPathSchemas.get(0), getOutputPathSchema())); + break; + default: + } + } + } + for (Path joinPath : joinPaths) { + ITreePath joinTreePath = TreePaths.singletonPath(joinPath); + collect(VertexRecord.of(IdOnlyVertex.of(joinKey), joinTreePath)); } - cachedLeftAndRightTreePaths.clear(); - super.finish(); + } } + cachedLeftAndRightTreePaths.clear(); + super.finish(); + } - @Override - public StepOperator copyInternal() { - return new StepJoinOperator(id, function, inputJoinPathSchema, inputPathSchemas, - isLocalJoin); - } + @Override + public StepOperator copyInternal() { + return new StepJoinOperator(id, function, inputJoinPathSchema, inputPathSchemas, isLocalJoin); + } - @Override - protected PathType concatInputPathType() { - return inputJoinPathSchema; - } + @Override + protected PathType concatInputPathType() { + return inputJoinPathSchema; + } - private static class JoinTree { + private static class JoinTree { - public ITreePath leftTree; + public ITreePath leftTree; - public ITreePath rightTree; + public ITreePath rightTree; - public JoinTree addLeftTree(ITreePath leftTree) { - if (leftTree != null) { - if (this.leftTree == null) { - this.leftTree = leftTree; - } else { - this.leftTree = this.leftTree.merge(leftTree); - } - } - return this; + public JoinTree addLeftTree(ITreePath leftTree) { + if (leftTree != null) { + if (this.leftTree == null) { + this.leftTree = leftTree; + } else { + this.leftTree = this.leftTree.merge(leftTree); } + } + return this; + } - public JoinTree addRightTree(ITreePath rightTree) { - if (rightTree != null) { - if (this.rightTree == null) { - this.rightTree = rightTree; - } else { - this.rightTree = this.rightTree.merge(rightTree); - } - } - return this; + public JoinTree addRightTree(ITreePath rightTree) { + if (rightTree != null) { + if (this.rightTree == null) { + this.rightTree = rightTree; + } else { + this.rightTree = this.rightTree.merge(rightTree); } + } + return this; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepLocalAggregateOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepLocalAggregateOperator.java index 6883da269..31ee3839d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepLocalAggregateOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepLocalAggregateOperator.java @@ -23,6 +23,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Objects; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.Row; @@ -47,106 +48,112 @@ public class StepLocalAggregateOperator extends AbstractStepOperator { - private final StepKeyFunction groupByFunction; - private Map> requestId2Path; - private Map> requestId2Accumulators; - private final int[] pathPruneIndices; - private final IType[] pathPruneTypes; - private final IType[] inputPathTypes; - boolean isVertexType; - private final IType[] aggregateNodeTypes; + private final StepKeyFunction groupByFunction; + private Map> requestId2Path; + private Map> requestId2Accumulators; + private final int[] pathPruneIndices; + private final IType[] pathPruneTypes; + private final IType[] inputPathTypes; + boolean isVertexType; + private final IType[] aggregateNodeTypes; - public StepLocalAggregateOperator(long id, StepKeyFunction keyFunction, StepAggregateFunction function) { - super(id, function); - this.groupByFunction = Objects.requireNonNull(keyFunction); - this.pathPruneIndices = ((StepAggExpressionFunctionImpl) function).getPathPruneIndices(); - this.pathPruneTypes = ((StepAggExpressionFunctionImpl) function).getPathPruneTypes(); - this.inputPathTypes = ((StepAggExpressionFunctionImpl) function).getInputPathTypes(); - assert pathPruneIndices.length > 0 && pathPruneTypes.length == pathPruneIndices.length; - this.isVertexType = pathPruneTypes[0] instanceof VertexType; - this.aggregateNodeTypes = isVertexType + public StepLocalAggregateOperator( + long id, StepKeyFunction keyFunction, StepAggregateFunction function) { + super(id, function); + this.groupByFunction = Objects.requireNonNull(keyFunction); + this.pathPruneIndices = ((StepAggExpressionFunctionImpl) function).getPathPruneIndices(); + this.pathPruneTypes = ((StepAggExpressionFunctionImpl) function).getPathPruneTypes(); + this.inputPathTypes = ((StepAggExpressionFunctionImpl) function).getInputPathTypes(); + assert pathPruneIndices.length > 0 && pathPruneTypes.length == pathPruneIndices.length; + this.isVertexType = pathPruneTypes[0] instanceof VertexType; + this.aggregateNodeTypes = + isVertexType ? ((VertexType) inputPathTypes[pathPruneIndices[0]]).getValueTypes() : ((EdgeType) inputPathTypes[pathPruneIndices[0]]).getValueTypes(); - } + } + @Override + public void open(TraversalRuntimeContext context) { + super.open(context); + requestId2Path = new HashMap<>(); + requestId2Accumulators = new HashMap<>(); + } - @Override - public void open(TraversalRuntimeContext context) { - super.open(context); - requestId2Path = new HashMap<>(); - requestId2Accumulators = new HashMap<>(); + @Override + protected void processRecord(StepRecordWithPath record) { + ParameterRequest request = context.getRequest(); + if (!requestId2Accumulators.containsKey(request)) { + requestId2Path.put(request, new HashMap<>()); + requestId2Accumulators.put(request, new HashMap<>()); } + record.mapPath( + path -> { + RowKey key = groupByFunction.getKey(path); + Map key2Path = requestId2Path.get(request); + Map key2Acc = requestId2Accumulators.get(request); + if (!key2Acc.containsKey(key)) { + key2Acc.put(key, function.createAccumulator()); + key2Path.put(key, path); + } + Object accumulator = key2Acc.get(key); + function.add(path, accumulator); + key2Acc.put(key, accumulator); + return path; + }, + null); + } - @Override - protected void processRecord(StepRecordWithPath record) { - ParameterRequest request = context.getRequest(); - if (!requestId2Accumulators.containsKey(request)) { - requestId2Path.put(request, new HashMap<>()); - requestId2Accumulators.put(request, new HashMap<>()); + @Override + public void finish() { + for (Map.Entry> entry : + requestId2Accumulators.entrySet()) { + ParameterRequest request = entry.getKey(); + context.setRequest(request); + Map key2Acc = entry.getValue(); + Map key2Path = requestId2Path.get(request); + for (Entry rowKeyObjectEntry : key2Acc.entrySet()) { + RowKey rowKey = rowKeyObjectEntry.getKey(); + Path path = (Path) key2Path.get(rowKey); + Row[] values = new Row[inputPathTypes.length]; + for (int i = 0; i < inputPathTypes.length; i++) { + values[i] = path.getField(i, inputPathTypes[i]); } - record.mapPath(path -> { - RowKey key = groupByFunction.getKey(path); - Map key2Path = requestId2Path.get(request); - Map key2Acc = requestId2Accumulators.get(request); - if (!key2Acc.containsKey(key)) { - key2Acc.put(key, function.createAccumulator()); - key2Path.put(key, path); - } - Object accumulator = key2Acc.get(key); - function.add(path, accumulator); - key2Acc.put(key, accumulator); - return path; - }, null); - - } - - @Override - public void finish() { - for (Map.Entry> entry : requestId2Accumulators.entrySet()) { - ParameterRequest request = entry.getKey(); - context.setRequest(request); - Map key2Acc = entry.getValue(); - Map key2Path = requestId2Path.get(request); - for (Entry rowKeyObjectEntry : key2Acc.entrySet()) { - RowKey rowKey = rowKeyObjectEntry.getKey(); - Path path = (Path) key2Path.get(rowKey); - Row[] values = new Row[inputPathTypes.length]; - for (int i = 0; i < inputPathTypes.length; i++) { - values[i] = path.getField(i, inputPathTypes[i]); - } - Row aggregateNodeValue; - int aggregateNodeIndex = pathPruneIndices[0]; - if (isVertexType) { - aggregateNodeValue = ((IVertex) values[aggregateNodeIndex]).getValue(); - } else { - aggregateNodeValue = ((IEdge) values[aggregateNodeIndex]).getValue(); - } - //The last offset of aggregate node is accumulator - Object[] aggregateNodeValues = - new Object[aggregateNodeTypes.length + 1]; - for (int j = 0; j < aggregateNodeTypes.length; j++) { - aggregateNodeValues[j] = aggregateNodeValue.getField(j, aggregateNodeTypes[j]); - } - Object accumulator = rowKeyObjectEntry.getValue(); - aggregateNodeValues[aggregateNodeValues.length - 1] = accumulator; - if (isVertexType) { - values[aggregateNodeIndex] = (Row) ((IVertex) values[aggregateNodeIndex]) - .withValue(ObjectRow.create(aggregateNodeValues)); - } else { - values[aggregateNodeIndex] = (Row) ((IEdge) values[aggregateNodeIndex]) - .withValue(ObjectRow.create(aggregateNodeValues)); - } - ITreePath localAggPath = TreePaths.singletonPath(new DefaultPath(values)); - collect(VertexRecord.of(IdOnlyVertex.of(rowKey), localAggPath)); - } + Row aggregateNodeValue; + int aggregateNodeIndex = pathPruneIndices[0]; + if (isVertexType) { + aggregateNodeValue = ((IVertex) values[aggregateNodeIndex]).getValue(); + } else { + aggregateNodeValue = ((IEdge) values[aggregateNodeIndex]).getValue(); + } + // The last offset of aggregate node is accumulator + Object[] aggregateNodeValues = new Object[aggregateNodeTypes.length + 1]; + for (int j = 0; j < aggregateNodeTypes.length; j++) { + aggregateNodeValues[j] = aggregateNodeValue.getField(j, aggregateNodeTypes[j]); } - requestId2Accumulators.clear(); - requestId2Path.clear(); - super.finish(); + Object accumulator = rowKeyObjectEntry.getValue(); + aggregateNodeValues[aggregateNodeValues.length - 1] = accumulator; + if (isVertexType) { + values[aggregateNodeIndex] = + (Row) + ((IVertex) values[aggregateNodeIndex]) + .withValue(ObjectRow.create(aggregateNodeValues)); + } else { + values[aggregateNodeIndex] = + (Row) + ((IEdge) values[aggregateNodeIndex]) + .withValue(ObjectRow.create(aggregateNodeValues)); + } + ITreePath localAggPath = TreePaths.singletonPath(new DefaultPath(values)); + collect(VertexRecord.of(IdOnlyVertex.of(rowKey), localAggPath)); + } } + requestId2Accumulators.clear(); + requestId2Path.clear(); + super.finish(); + } - @Override - public StepOperator copyInternal() { - return new StepLocalAggregateOperator(id, groupByFunction, function); - } + @Override + public StepOperator copyInternal() { + return new StepLocalAggregateOperator(id, groupByFunction, function); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepLocalExchangeOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepLocalExchangeOperator.java index 6a37a2cc3..8bbab7d63 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepLocalExchangeOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepLocalExchangeOperator.java @@ -19,8 +19,8 @@ package org.apache.geaflow.dsl.runtime.traversal.operator; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.RowKey; @@ -37,39 +37,41 @@ import org.apache.geaflow.dsl.runtime.traversal.path.ITreePath; import org.apache.geaflow.dsl.runtime.traversal.path.TreePaths; -public class StepLocalExchangeOperator extends AbstractStepOperator { +import com.google.common.collect.Lists; +public class StepLocalExchangeOperator + extends AbstractStepOperator { - public StepLocalExchangeOperator(long id, StepKeyFunction function) { - super(id, function); - } + public StepLocalExchangeOperator(long id, StepKeyFunction function) { + super(id, function); + } - @Override - protected void processRecord(StepRecord record) { - if (record.getType() == StepRecordType.ROW) { - Row row = (Row) record; - RowKey key = function.getKey(row); - if (context.getRequest() != null) { // append requestId to the key. - key = new DefaultRowKeyWithRequestId(context.getRequestId(), key); - } - StepKeyRecord keyRecord = new StepKeyRecordImpl(key, row); - collect(keyRecord); - } else { - StepRecordWithPath recordWithPath = (StepRecordWithPath) record; - for (ITreePath treePath : recordWithPath.getPaths()) { - List paths = treePath.toList(); - for (Path path : paths) { - RowKey key = function.getKey(path); - RowVertex virtualVertex = IdOnlyVertex.of(key); - ITreePath virtualVertexPath = TreePaths.createTreePath(Lists.newArrayList(path)); - collect(VertexRecord.of(virtualVertex, virtualVertexPath)); - } - } + @Override + protected void processRecord(StepRecord record) { + if (record.getType() == StepRecordType.ROW) { + Row row = (Row) record; + RowKey key = function.getKey(row); + if (context.getRequest() != null) { // append requestId to the key. + key = new DefaultRowKeyWithRequestId(context.getRequestId(), key); + } + StepKeyRecord keyRecord = new StepKeyRecordImpl(key, row); + collect(keyRecord); + } else { + StepRecordWithPath recordWithPath = (StepRecordWithPath) record; + for (ITreePath treePath : recordWithPath.getPaths()) { + List paths = treePath.toList(); + for (Path path : paths) { + RowKey key = function.getKey(path); + RowVertex virtualVertex = IdOnlyVertex.of(key); + ITreePath virtualVertexPath = TreePaths.createTreePath(Lists.newArrayList(path)); + collect(VertexRecord.of(virtualVertex, virtualVertexPath)); } + } } + } - @Override - public StepOperator copyInternal() { - return new StepLocalExchangeOperator(id, function); - } + @Override + public StepOperator copyInternal() { + return new StepLocalExchangeOperator(id, function); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepLocalSingleValueAggregateOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepLocalSingleValueAggregateOperator.java index c3a15a78a..7b70e6f23 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepLocalSingleValueAggregateOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepLocalSingleValueAggregateOperator.java @@ -21,50 +21,52 @@ import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.impl.ObjectRow; import org.apache.geaflow.dsl.runtime.function.graph.StepAggregateFunction; import org.apache.geaflow.dsl.runtime.traversal.TraversalRuntimeContext; import org.apache.geaflow.dsl.runtime.traversal.data.ParameterRequest; -public class StepLocalSingleValueAggregateOperator extends AbstractStepOperator { +public class StepLocalSingleValueAggregateOperator + extends AbstractStepOperator { - private Map requestId2Accumulators; + private Map requestId2Accumulators; - public StepLocalSingleValueAggregateOperator(long id, StepAggregateFunction function) { - super(id, function); - } + public StepLocalSingleValueAggregateOperator(long id, StepAggregateFunction function) { + super(id, function); + } - @Override - public void open(TraversalRuntimeContext context) { - super.open(context); - requestId2Accumulators = new HashMap<>(); - } + @Override + public void open(TraversalRuntimeContext context) { + super.open(context); + requestId2Accumulators = new HashMap<>(); + } - @Override - protected void processRecord(Row record) { - ParameterRequest request = context.getRequest(); - Object accumulator = requestId2Accumulators.computeIfAbsent(request, - r -> function.createAccumulator()); - function.add(record, accumulator); - } + @Override + protected void processRecord(Row record) { + ParameterRequest request = context.getRequest(); + Object accumulator = + requestId2Accumulators.computeIfAbsent(request, r -> function.createAccumulator()); + function.add(record, accumulator); + } - @Override - public void finish() { - for (Map.Entry entry : requestId2Accumulators.entrySet()) { - ParameterRequest request = entry.getKey(); - Object accumulator = entry.getValue(); - Row localResult = ObjectRow.create(accumulator); + @Override + public void finish() { + for (Map.Entry entry : requestId2Accumulators.entrySet()) { + ParameterRequest request = entry.getKey(); + Object accumulator = entry.getValue(); + Row localResult = ObjectRow.create(accumulator); - context.setRequest(request); - collect(localResult); - } - requestId2Accumulators.clear(); - super.finish(); + context.setRequest(request); + collect(localResult); } + requestId2Accumulators.clear(); + super.finish(); + } - @Override - public StepOperator copyInternal() { - return new StepLocalSingleValueAggregateOperator(id, function); - } + @Override + public StepOperator copyInternal() { + return new StepLocalSingleValueAggregateOperator(id, function); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepLoopUntilOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepLoopUntilOperator.java index d96b77a27..ba5a2a6ff 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepLoopUntilOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepLoopUntilOperator.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.traversal.operator; import java.util.List; + import org.apache.geaflow.dsl.common.data.RowEdge; import org.apache.geaflow.dsl.common.data.RowVertex; import org.apache.geaflow.dsl.common.data.StepRecord; @@ -35,129 +36,143 @@ import org.apache.geaflow.dsl.runtime.traversal.path.ITreePath; import org.apache.geaflow.state.pushdown.filter.EmptyFilter; -public class StepLoopUntilOperator extends AbstractStepOperator { - - private final long loopStartOpId; - private final long loopBodyOpId; - private StepCollector loopStartCollector; - private int loopCounter; - - private final int minLoopCount; - private final int maxLoopCount; - - private final int loopStartPathFieldCount; - private final int loopBodyPathFieldCount; - - private int[] pathIndices; - - public StepLoopUntilOperator(long id, long loopStartOpId, - long loopBodyOpId, StepBoolFunction function, - int minLoopCount, int maxLoopCount, - int loopStartPathFieldCount, int loopBodyPathFieldCount) { - super(id, function); - this.loopStartOpId = loopStartOpId; - this.loopBodyOpId = loopBodyOpId; - this.minLoopCount = minLoopCount; - this.maxLoopCount = maxLoopCount == -1 ? Integer.MAX_VALUE : maxLoopCount; - this.loopStartPathFieldCount = loopStartPathFieldCount; - this.loopBodyPathFieldCount = loopBodyPathFieldCount; - } +public class StepLoopUntilOperator + extends AbstractStepOperator { - @Override - public void open(TraversalRuntimeContext context) { - super.open(context); - this.loopStartCollector = new StepJumpCollector(id, loopStartOpId, context); - // When reach the loop-util, we have already looped 1 time, so the init loop counter should be 1. - this.loopCounter = 1; - this.pathIndices = new int[loopBodyPathFieldCount + loopStartPathFieldCount]; - } + private final long loopStartOpId; + private final long loopBodyOpId; + private StepCollector loopStartCollector; + private int loopCounter; - @Override - public StepOperator copyInternal() { - return new StepLoopUntilOperator(id, loopStartOpId, loopBodyOpId, function, minLoopCount, maxLoopCount, - loopStartPathFieldCount, loopBodyPathFieldCount); - } + private final int minLoopCount; + private final int maxLoopCount; - @Override - protected void processRecord(StepRecordWithPath record) { - boolean fromZeroLoop = context.getInputOperatorId() != loopBodyOpId; - StepRecordWithPath lastLoopPath = selectLastLoopPath(record, fromZeroLoop); - if (fromZeroLoop) { // from the loop start (for loop 0) - collect(lastLoopPath); - } else { // from the loop body - if (loopCounter >= minLoopCount) { - collect(lastLoopPath); - } - if (loopCounter < maxLoopCount) { - loopStartCollector.collect(lastLoopPath); - } - } - } + private final int loopStartPathFieldCount; + private final int loopBodyPathFieldCount; - private StepRecordWithPath selectLastLoopPath(StepRecordWithPath record, boolean fromZeroLoop) { - RowVertex vertexId = ((VertexRecord) record).getVertex(); - if (fromZeroLoop) { - final RowVertex vertexRecord; - if (vertexId instanceof IdOnlyVertex && !(vertexId.getId() instanceof VirtualId)) { - vertexRecord = context.loadVertex(vertexId.getId(), - EmptyFilter.getInstance(), - graphSchema, - addingVertexFieldTypes); - } else { - vertexRecord = vertexId; - } - return record.mapTreePath(treePath -> { - ITreePath newTreePath = treePath; - for (int i = 0; i < loopBodyPathFieldCount - 1; i++) { - newTreePath = newTreePath.extendTo((RowEdge) null); - } - return newTreePath.extendTo(vertexRecord); - }); - } else { - for (int i = 0; i < loopStartPathFieldCount; i++) { - pathIndices[i] = i; - } - for (int i = 0; i < loopBodyPathFieldCount; i++) { - // When calculating the index for the loopBody fields, when - // loopCounter is 1, the first offset is used for input values. After that, - // values generated by the loop are placed starting from an offset of 1 - pathIndices[i + loopStartPathFieldCount] = loopStartPathFieldCount - + Math.min(loopCounter - 1, 1) * loopBodyPathFieldCount + i; - } - return record.subPathSet(pathIndices); - } + private int[] pathIndices; + + public StepLoopUntilOperator( + long id, + long loopStartOpId, + long loopBodyOpId, + StepBoolFunction function, + int minLoopCount, + int maxLoopCount, + int loopStartPathFieldCount, + int loopBodyPathFieldCount) { + super(id, function); + this.loopStartOpId = loopStartOpId; + this.loopBodyOpId = loopBodyOpId; + this.minLoopCount = minLoopCount; + this.maxLoopCount = maxLoopCount == -1 ? Integer.MAX_VALUE : maxLoopCount; + this.loopStartPathFieldCount = loopStartPathFieldCount; + this.loopBodyPathFieldCount = loopBodyPathFieldCount; + } + + @Override + public void open(TraversalRuntimeContext context) { + super.open(context); + this.loopStartCollector = new StepJumpCollector(id, loopStartOpId, context); + // When reach the loop-util, we have already looped 1 time, so the init loop counter should be + // 1. + this.loopCounter = 1; + this.pathIndices = new int[loopBodyPathFieldCount + loopStartPathFieldCount]; + } + + @Override + public StepOperator copyInternal() { + return new StepLoopUntilOperator( + id, + loopStartOpId, + loopBodyOpId, + function, + minLoopCount, + maxLoopCount, + loopStartPathFieldCount, + loopBodyPathFieldCount); + } + + @Override + protected void processRecord(StepRecordWithPath record) { + boolean fromZeroLoop = context.getInputOperatorId() != loopBodyOpId; + StepRecordWithPath lastLoopPath = selectLastLoopPath(record, fromZeroLoop); + if (fromZeroLoop) { // from the loop start (for loop 0) + collect(lastLoopPath); + } else { // from the loop body + if (loopCounter >= minLoopCount) { + collect(lastLoopPath); + } + if (loopCounter < maxLoopCount) { + loopStartCollector.collect(lastLoopPath); + } } + } - @Override - protected void onReceiveAllEOD(long callerOpId, List receiveEods) { - boolean isGlobalEmptyCycle = true; - for (EndOfData eod : receiveEods) { - if (eod.getSenderId() == this.loopBodyOpId) { - isGlobalEmptyCycle &= eod.isGlobalEmptyCycle; + private StepRecordWithPath selectLastLoopPath(StepRecordWithPath record, boolean fromZeroLoop) { + RowVertex vertexId = ((VertexRecord) record).getVertex(); + if (fromZeroLoop) { + final RowVertex vertexRecord; + if (vertexId instanceof IdOnlyVertex && !(vertexId.getId() instanceof VirtualId)) { + vertexRecord = + context.loadVertex( + vertexId.getId(), EmptyFilter.getInstance(), graphSchema, addingVertexFieldTypes); + } else { + vertexRecord = vertexId; + } + return record.mapTreePath( + treePath -> { + ITreePath newTreePath = treePath; + for (int i = 0; i < loopBodyPathFieldCount - 1; i++) { + newTreePath = newTreePath.extendTo((RowEdge) null); } - } - if (loopCounter < maxLoopCount && !isGlobalEmptyCycle) { - // remove eod from the loop body. - receiveEods.removeIf(eod -> eod.getSenderId() == this.loopBodyOpId); - // send EOD to the loop start. - EndOfData eod = EndOfData.of(callerOpId, id); - eod.isGlobalEmptyCycle = numProcessRecords == 0; - loopStartCollector.collect(eod); - } else { // If no data in the loop, it means the whole loop has finished. Just send EOD to the next. - super.onReceiveAllEOD(callerOpId, receiveEods); - receiveEods.clear(); - } - this.isGlobalEmptyCycle = true; - this.numProcessRecords = 0L; - this.loopCounter++; + return newTreePath.extendTo(vertexRecord); + }); + } else { + for (int i = 0; i < loopStartPathFieldCount; i++) { + pathIndices[i] = i; + } + for (int i = 0; i < loopBodyPathFieldCount; i++) { + // When calculating the index for the loopBody fields, when + // loopCounter is 1, the first offset is used for input values. After that, + // values generated by the loop are placed starting from an offset of 1 + pathIndices[i + loopStartPathFieldCount] = + loopStartPathFieldCount + Math.min(loopCounter - 1, 1) * loopBodyPathFieldCount + i; + } + return record.subPathSet(pathIndices); } + } - public int getMinLoopCount() { - return this.minLoopCount; + @Override + protected void onReceiveAllEOD(long callerOpId, List receiveEods) { + boolean isGlobalEmptyCycle = true; + for (EndOfData eod : receiveEods) { + if (eod.getSenderId() == this.loopBodyOpId) { + isGlobalEmptyCycle &= eod.isGlobalEmptyCycle; + } } - - public int getMaxLoopCount() { - return this.maxLoopCount; + if (loopCounter < maxLoopCount && !isGlobalEmptyCycle) { + // remove eod from the loop body. + receiveEods.removeIf(eod -> eod.getSenderId() == this.loopBodyOpId); + // send EOD to the loop start. + EndOfData eod = EndOfData.of(callerOpId, id); + eod.isGlobalEmptyCycle = numProcessRecords == 0; + loopStartCollector.collect(eod); + } else { // If no data in the loop, it means the whole loop has finished. Just send EOD to the + // next. + super.onReceiveAllEOD(callerOpId, receiveEods); + receiveEods.clear(); } + this.isGlobalEmptyCycle = true; + this.numProcessRecords = 0L; + this.loopCounter++; + } + + public int getMinLoopCount() { + return this.minLoopCount; + } + + public int getMaxLoopCount() { + return this.maxLoopCount; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepMapOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepMapOperator.java index a9e7a8911..fd3992e62 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepMapOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepMapOperator.java @@ -21,52 +21,53 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.dsl.runtime.function.graph.StepMapFunction; import org.apache.geaflow.dsl.runtime.traversal.TraversalRuntimeContext; import org.apache.geaflow.dsl.runtime.traversal.data.StepRecordWithPath; -public class StepMapOperator extends AbstractStepOperator { +public class StepMapOperator + extends AbstractStepOperator { - private List cacheRecords; - private final boolean isGlobal; + private List cacheRecords; + private final boolean isGlobal; - public StepMapOperator(long id, StepMapFunction function, boolean isGlobal) { - super(id, function); - this.isGlobal = isGlobal; - } + public StepMapOperator(long id, StepMapFunction function, boolean isGlobal) { + super(id, function); + this.isGlobal = isGlobal; + } - @Override - public void open(TraversalRuntimeContext context) { - super.open(context); - if (isGlobal) { - cacheRecords = new ArrayList<>(); - } + @Override + public void open(TraversalRuntimeContext context) { + super.open(context); + if (isGlobal) { + cacheRecords = new ArrayList<>(); } + } - @Override - public void finish() { - if (isGlobal) { - for (StepRecordWithPath record : cacheRecords) { - collect(record); - } - cacheRecords.clear(); - } - super.finish(); + @Override + public void finish() { + if (isGlobal) { + for (StepRecordWithPath record : cacheRecords) { + collect(record); + } + cacheRecords.clear(); } + super.finish(); + } - @Override - protected void processRecord(StepRecordWithPath record) { - StepRecordWithPath mapRecord = record.mapPath(path -> function.map(withParameter(path)), - null); - if (isGlobal) { - cacheRecords.add(mapRecord); - } else { - collect(mapRecord); - } + @Override + protected void processRecord(StepRecordWithPath record) { + StepRecordWithPath mapRecord = record.mapPath(path -> function.map(withParameter(path)), null); + if (isGlobal) { + cacheRecords.add(mapRecord); + } else { + collect(mapRecord); } + } - @Override - public StepOperator copyInternal() { - return new StepMapOperator(id, function, isGlobal); - } + @Override + public StepOperator copyInternal() { + return new StepMapOperator(id, function, isGlobal); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepMapRowOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepMapRowOperator.java index 560686e39..cbea68060 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepMapRowOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepMapRowOperator.java @@ -20,6 +20,7 @@ package org.apache.geaflow.dsl.runtime.traversal.operator; import java.util.List; + import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.StepRecord; import org.apache.geaflow.dsl.common.data.StepRecord.StepRecordType; @@ -28,28 +29,28 @@ public class StepMapRowOperator extends AbstractStepOperator { - public StepMapRowOperator(long id, StepMapRowFunction function) { - super(id, function); - } + public StepMapRowOperator(long id, StepMapRowFunction function) { + super(id, function); + } - @Override - protected void processRecord(StepRecord record) { - if (record.getType() == StepRecordType.ROW) { - Row result = function.map((Row) record); - collect(result); - } else { - StepRecordWithPath recordWithPath = (StepRecordWithPath) record; - List rows = recordWithPath.map(path -> function.map(withParameter(path)), null); - if (rows != null) { - for (Row row : rows) { - collect(row); - } - } + @Override + protected void processRecord(StepRecord record) { + if (record.getType() == StepRecordType.ROW) { + Row result = function.map((Row) record); + collect(result); + } else { + StepRecordWithPath recordWithPath = (StepRecordWithPath) record; + List rows = recordWithPath.map(path -> function.map(withParameter(path)), null); + if (rows != null) { + for (Row row : rows) { + collect(row); } + } } + } - @Override - public StepOperator copyInternal() { - return new StepMapRowOperator(id, function); - } + @Override + public StepOperator copyInternal() { + return new StepMapRowOperator(id, function); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepNodeFilterOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepNodeFilterOperator.java index 4b0d3dfd1..d64cfc55a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepNodeFilterOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepNodeFilterOperator.java @@ -22,23 +22,22 @@ import org.apache.geaflow.dsl.runtime.function.graph.StepNodeFilterFunction; import org.apache.geaflow.dsl.runtime.traversal.data.VertexRecord; -public class StepNodeFilterOperator extends AbstractStepOperator { +public class StepNodeFilterOperator + extends AbstractStepOperator { - public StepNodeFilterOperator(long id, StepNodeFilterFunction function) { - super(id, function); - } - - @Override - protected void processRecord(VertexRecord record) { - if (function.filter(record.getVertex())) { - collect(record); - } - } + public StepNodeFilterOperator(long id, StepNodeFilterFunction function) { + super(id, function); + } - @Override - public StepOperator copyInternal() { - return new StepNodeFilterOperator(id, function); + @Override + protected void processRecord(VertexRecord record) { + if (function.filter(record.getVertex())) { + collect(record); } + } + @Override + public StepOperator copyInternal() { + return new StepNodeFilterOperator(id, function); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepOperator.java index f0496942f..c503774c3 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepOperator.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.util.*; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.StepRecord; import org.apache.geaflow.dsl.common.types.GraphSchema; @@ -29,79 +30,81 @@ public interface StepOperator extends Serializable { - /** - * The operator id. - */ - long getId(); + /** The operator id. */ + long getId(); - /** - * The operator name. - */ - String getName(); + /** The operator name. */ + String getName(); - /** - * The init method for step operator. - * @param context The context for traversal. - */ - void open(TraversalRuntimeContext context); + /** + * The init method for step operator. + * + * @param context The context for traversal. + */ + void open(TraversalRuntimeContext context); - /** - * Process input record. - * @param record The input record. - */ - void process(IN record); + /** + * Process input record. + * + * @param record The input record. + */ + void process(IN record); - void finish(); + void finish(); - void close(); + void close(); - void addNextOperator(StepOperator nextOperator); + void addNextOperator(StepOperator nextOperator); - List> getNextOperators(); + List> getNextOperators(); - StepOperator withName(String name); + StepOperator withName(String name); - /** - * Set the output path schema for the operator. - * @param outputPath The output path schema. - */ - StepOperator withOutputPathSchema(PathType outputPath); + /** + * Set the output path schema for the operator. + * + * @param outputPath The output path schema. + */ + StepOperator withOutputPathSchema(PathType outputPath); - /** - * Set the input path schema for the operator. - * @param inputPaths The input path schemas for each input. - */ - StepOperator withInputPathSchema(List inputPaths); + /** + * Set the input path schema for the operator. + * + * @param inputPaths The input path schemas for each input. + */ + StepOperator withInputPathSchema(List inputPaths); - default StepOperator withInputPathSchema(PathType pathType) { - return withInputPathSchema(Collections.singletonList(Objects.requireNonNull(pathType))); - } + default StepOperator withInputPathSchema(PathType pathType) { + return withInputPathSchema(Collections.singletonList(Objects.requireNonNull(pathType))); + } - StepOperator withOutputType(IType outputType); + StepOperator withOutputType(IType outputType); - /** - * Set the origin graph schema. - * @param graphSchema The origin graph schema defined in the DDL. - */ - StepOperator withGraphSchema(GraphSchema graphSchema); + /** + * Set the origin graph schema. + * + * @param graphSchema The origin graph schema defined in the DDL. + */ + StepOperator withGraphSchema(GraphSchema graphSchema); - /** - * Set the modified graph schema after the let-global-statement. - * @param modifyGraphSchema The modified graph schema. - */ - StepOperator withModifyGraphSchema(GraphSchema modifyGraphSchema); + /** + * Set the modified graph schema after the let-global-statement. + * + * @param modifyGraphSchema The modified graph schema. + */ + StepOperator withModifyGraphSchema(GraphSchema modifyGraphSchema); - List getInputPathSchemas(); + List getInputPathSchemas(); - PathType getOutputPathSchema(); + PathType getOutputPathSchema(); - IType getOutputType(); + IType getOutputType(); - GraphSchema getGraphSchema(); + GraphSchema getGraphSchema(); - GraphSchema getModifyGraphSchema(); + GraphSchema getModifyGraphSchema(); - List getSubQueryNames(); + List getSubQueryNames(); - StepOperator copy(); + StepOperator copy(); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepReturnOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepReturnOperator.java index 4f867c8eb..b8f4c6156 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepReturnOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepReturnOperator.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.dsl.common.data.StepRecord; import org.apache.geaflow.dsl.runtime.expression.Expression; import org.apache.geaflow.dsl.runtime.function.graph.FunctionSchemas; @@ -30,49 +31,44 @@ import org.apache.geaflow.dsl.runtime.traversal.data.SingleValue; import org.apache.geaflow.dsl.runtime.traversal.operator.StepReturnOperator.StepReturnFunction; -public class StepReturnOperator extends AbstractStepOperator { - - public StepReturnOperator(long id) { - super(id, StepReturnFunction.INSTANCE); - } - - @Override - protected void processRecord(SingleValue record) { - collect(record); - } - - @Override - public StepOperator copyInternal() { - return new StepReturnOperator(id); - } +public class StepReturnOperator + extends AbstractStepOperator { - static class StepReturnFunction implements StepFunction { + public StepReturnOperator(long id) { + super(id, StepReturnFunction.INSTANCE); + } - public static final StepReturnFunction INSTANCE = new StepReturnFunction(); + @Override + protected void processRecord(SingleValue record) { + collect(record); + } - private StepReturnFunction() { + @Override + public StepOperator copyInternal() { + return new StepReturnOperator(id); + } - } + static class StepReturnFunction implements StepFunction { - @Override - public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { + public static final StepReturnFunction INSTANCE = new StepReturnFunction(); - } + private StepReturnFunction() {} - @Override - public void finish(StepCollector collector) { + @Override + public void open(TraversalRuntimeContext context, FunctionSchemas schemas) {} - } + @Override + public void finish(StepCollector collector) {} - @Override - public List getExpressions() { - return Collections.emptyList(); - } + @Override + public List getExpressions() { + return Collections.emptyList(); + } - @Override - public StepFunction copy(List expressions) { - assert expressions.isEmpty(); - return new StepReturnFunction(); - } + @Override + public StepFunction copy(List expressions) { + assert expressions.isEmpty(); + return new StepReturnFunction(); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepSortOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepSortOperator.java index 44a27d1a1..8f2abb868 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepSortOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepSortOperator.java @@ -26,19 +26,19 @@ public class StepSortOperator extends AbstractStepOperator { - public StepSortOperator(long id, StepSortFunction function) { - super(id, function); - } + public StepSortOperator(long id, StepSortFunction function) { + super(id, function); + } - @Override - protected void processRecord(VertexRecord record) { - for (Path path : record.getTreePath().toList()) { - function.process(record.getVertex(), path); - } + @Override + protected void processRecord(VertexRecord record) { + for (Path path : record.getTreePath().toList()) { + function.process(record.getVertex(), path); } + } - @Override - public StepOperator copyInternal() { - return new StepSortOperator(id, function); - } + @Override + public StepOperator copyInternal() { + return new StepSortOperator(id, function); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepSourceOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepSourceOperator.java index c3273b5eb..75c1db654 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepSourceOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepSourceOperator.java @@ -19,7 +19,6 @@ package org.apache.geaflow.dsl.runtime.traversal.operator; -import com.google.common.collect.Sets; import java.io.Serializable; import java.util.Collection; import java.util.Collections; @@ -27,6 +26,7 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.dsl.common.data.StepRecord; import org.apache.geaflow.dsl.common.types.PathType; @@ -38,180 +38,173 @@ import org.apache.geaflow.dsl.runtime.traversal.data.EndOfData; import org.apache.geaflow.dsl.runtime.traversal.data.VertexRecord; -public class StepSourceOperator extends AbstractStepOperator { +import com.google.common.collect.Sets; - private final Set startIds; +public class StepSourceOperator + extends AbstractStepOperator { - public StepSourceOperator(long id, Set startIds) { - super(id, SourceStepFunction.INSTANCE); - this.startIds = Sets.newHashSet(Objects.requireNonNull(startIds)); - } + private final Set startIds; + + public StepSourceOperator(long id, Set startIds) { + super(id, SourceStepFunction.INSTANCE); + this.startIds = Sets.newHashSet(Objects.requireNonNull(startIds)); + } + + @Override + protected void processRecord(VertexRecord record) { + collect(record); + } + + @Override + public PathType getOutputPathSchema() { + return PathType.EMPTY; + } + + @Override + protected boolean hasReceivedAllEod(List receiveEods) { + // For source operator, the input is empty, so if it has received eod, + // it will trigger the onReceiveAllEOD. + return !receiveEods.isEmpty(); + } + + @Override + public StepOperator copyInternal() { + return new StepSourceOperator(id, Sets.newHashSet(startIds)); + } + + private static class SourceStepFunction implements StepFunction { + + public static final StepFunction INSTANCE = new SourceStepFunction(); @Override - protected void processRecord(VertexRecord record) { - collect(record); - } + public void open(TraversalRuntimeContext context, FunctionSchemas schemas) {} @Override - public PathType getOutputPathSchema() { - return PathType.EMPTY; - } + public void finish(StepCollector collector) {} @Override - protected boolean hasReceivedAllEod(List receiveEods) { - // For source operator, the input is empty, so if it has received eod, - // it will trigger the onReceiveAllEOD. - return !receiveEods.isEmpty(); + public List getExpressions() { + return Collections.emptyList(); } @Override - public StepOperator copyInternal() { - return new StepSourceOperator(id, Sets.newHashSet(startIds)); + public StepFunction copy(List expressions) { + assert expressions.isEmpty(); + return new SourceStepFunction(); } + } + + public Set getStartIds() { + return startIds; + } + + public void addStartIds(Collection ids) { + this.startIds.addAll(ids); + } + + public void unionStartId(Collection ids) { + if (ids.isEmpty()) { + // If same branch need traversal all, the startIds should be empty. + this.startIds.clear(); + } else { + if (!this.startIds.isEmpty()) { + this.startIds.addAll(ids); + } + } + } - private static class SourceStepFunction implements StepFunction { - - public static final StepFunction INSTANCE = new SourceStepFunction(); + public void joinStartId(Collection ids) { + if (ids.isEmpty()) { // empty start id list means traversal all. + return; + } + if (this.startIds.isEmpty()) { + this.startIds.addAll(ids); + } - @Override - public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { + Set intersections = + this.startIds.stream().filter(ids::contains).collect(Collectors.toSet()); + this.startIds.clear(); + this.startIds.addAll(intersections); + } - } + @Override + public String toString() { + StringBuilder str = new StringBuilder(); + str.append(getName()); + String startId = StringUtils.join(startIds, ","); + str.append("(").append(startId).append(")"); + return str.toString(); + } - @Override - public void finish(StepCollector collector) { + public interface StartId extends Serializable {} - } + public static class ConstantStartId implements StartId { - @Override - public List getExpressions() { - return Collections.emptyList(); - } + private final Object value; - @Override - public StepFunction copy(List expressions) { - assert expressions.isEmpty(); - return new SourceStepFunction(); - } + public ConstantStartId(Object value) { + this.value = value; } - public Set getStartIds() { - return startIds; + public Object getValue() { + return value; } - public void addStartIds(Collection ids) { - this.startIds.addAll(ids); + @Override + public String toString() { + return Objects.toString(value); } - public void unionStartId(Collection ids) { - if (ids.isEmpty()) { - // If same branch need traversal all, the startIds should be empty. - this.startIds.clear(); - } else { - if (!this.startIds.isEmpty()) { - this.startIds.addAll(ids); - } - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ConstantStartId)) { + return false; + } + ConstantStartId that = (ConstantStartId) o; + return Objects.equals(value, that.value); } - public void joinStartId(Collection ids) { - if (ids.isEmpty()) { // empty start id list means traversal all. - return; - } - if (this.startIds.isEmpty()) { - this.startIds.addAll(ids); - } - - Set intersections = this.startIds.stream() - .filter(ids::contains) - .collect(Collectors.toSet()); - this.startIds.clear(); - this.startIds.addAll(intersections); + @Override + public int hashCode() { + return Objects.hash(value); } + } - @Override - public String toString() { - StringBuilder str = new StringBuilder(); - str.append(getName()); - String startId = StringUtils.join(startIds, ","); - str.append("(").append(startId).append(")"); - return str.toString(); + public static class ParameterStartId implements StartId { + + private final Expression idExpression; + + public ParameterStartId(Expression idExpression) { + this.idExpression = idExpression; } - public interface StartId extends Serializable { + public Expression getIdExpression() { + return idExpression; + } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ParameterStartId)) { + return false; + } + ParameterStartId that = (ParameterStartId) o; + return Objects.equals(idExpression, that.idExpression); } - public static class ConstantStartId implements StartId { - - private final Object value; - - public ConstantStartId(Object value) { - this.value = value; - } - - public Object getValue() { - return value; - } - - @Override - public String toString() { - return Objects.toString(value); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof ConstantStartId)) { - return false; - } - ConstantStartId that = (ConstantStartId) o; - return Objects.equals(value, that.value); - } - - @Override - public int hashCode() { - return Objects.hash(value); - } + @Override + public int hashCode() { + return Objects.hash(idExpression); } - public static class ParameterStartId implements StartId { - - private final Expression idExpression; - - public ParameterStartId(Expression idExpression) { - this.idExpression = idExpression; - } - - - public Expression getIdExpression() { - return idExpression; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof ParameterStartId)) { - return false; - } - ParameterStartId that = (ParameterStartId) o; - return Objects.equals(idExpression, that.idExpression); - } - - @Override - public int hashCode() { - return Objects.hash(idExpression); - } - - @Override - public String toString() { - return "ParameterStartId{" - + "startIdExpression=" + idExpression.showExpression() - + '}'; - } + @Override + public String toString() { + return "ParameterStartId{" + "startIdExpression=" + idExpression.showExpression() + '}'; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepSubQueryStartOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepSubQueryStartOperator.java index bc84f5cf3..fec601daa 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepSubQueryStartOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepSubQueryStartOperator.java @@ -22,6 +22,7 @@ import java.util.Collections; import java.util.List; import java.util.Objects; + import org.apache.geaflow.dsl.common.data.StepRecord; import org.apache.geaflow.dsl.runtime.expression.Expression; import org.apache.geaflow.dsl.runtime.function.graph.FunctionSchemas; @@ -32,69 +33,66 @@ import org.apache.geaflow.dsl.runtime.traversal.data.EndOfData; import org.apache.geaflow.dsl.runtime.traversal.data.VertexRecord; -public class StepSubQueryStartOperator extends AbstractStepOperator { +public class StepSubQueryStartOperator + extends AbstractStepOperator { - private final String queryName; + private final String queryName; - public StepSubQueryStartOperator(long id, String queryName) { - super(id, SubQueryStartFunction.INSTANCE); - this.queryName = Objects.requireNonNull(queryName); - } + public StepSubQueryStartOperator(long id, String queryName) { + super(id, SubQueryStartFunction.INSTANCE); + this.queryName = Objects.requireNonNull(queryName); + } - @Override - protected void processRecord(VertexRecord record) { - Object requestId = context.getRequestId(); - if (requestId instanceof CallRequestId) { - context.stashCallRequestId((CallRequestId) requestId); - } - collect(record); + @Override + protected void processRecord(VertexRecord record) { + Object requestId = context.getRequestId(); + if (requestId instanceof CallRequestId) { + context.stashCallRequestId((CallRequestId) requestId); } + collect(record); + } - @Override - protected boolean hasReceivedAllEod(List receiveEods) { - return receiveEods.size() == numTasks; - } + @Override + protected boolean hasReceivedAllEod(List receiveEods) { + return receiveEods.size() == numTasks; + } - @Override - public StepOperator copyInternal() { - return new StepSubQueryStartOperator(id, queryName); - } + @Override + public StepOperator copyInternal() { + return new StepSubQueryStartOperator(id, queryName); + } - @Override - public String toString() { - StringBuilder str = new StringBuilder(); - str.append(getName()); - str.append("(name=").append(queryName).append(")"); - return str.toString(); - } + @Override + public String toString() { + StringBuilder str = new StringBuilder(); + str.append(getName()); + str.append("(name=").append(queryName).append(")"); + return str.toString(); + } - public String getQueryName() { - return queryName; - } - - private static class SubQueryStartFunction implements StepFunction { + public String getQueryName() { + return queryName; + } - public static final SubQueryStartFunction INSTANCE = new SubQueryStartFunction(); + private static class SubQueryStartFunction implements StepFunction { - @Override - public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { + public static final SubQueryStartFunction INSTANCE = new SubQueryStartFunction(); - } - - @Override - public void finish(StepCollector collector) { + @Override + public void open(TraversalRuntimeContext context, FunctionSchemas schemas) {} - } + @Override + public void finish(StepCollector collector) {} - @Override - public List getExpressions() { - return Collections.emptyList(); - } + @Override + public List getExpressions() { + return Collections.emptyList(); + } - @Override - public StepFunction copy(List expressions) { - assert expressions.isEmpty(); - return new SubQueryStartFunction(); - } + @Override + public StepFunction copy(List expressions) { + assert expressions.isEmpty(); + return new SubQueryStartFunction(); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepUnionOperator.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepUnionOperator.java index 431bb19e5..4c562f41f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepUnionOperator.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/operator/StepUnionOperator.java @@ -23,6 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.dsl.common.data.StepRecord; import org.apache.geaflow.dsl.common.types.PathType; import org.apache.geaflow.dsl.runtime.expression.Expression; @@ -34,69 +35,65 @@ import org.apache.geaflow.dsl.runtime.traversal.operator.StepUnionOperator.StepUnionFunction; import org.apache.geaflow.dsl.runtime.util.SchemaUtil; -public class StepUnionOperator extends AbstractStepOperator { +public class StepUnionOperator + extends AbstractStepOperator { - private final Map inputSchemas = new HashMap<>(); + private final Map inputSchemas = new HashMap<>(); - public StepUnionOperator(long id) { - super(id, StepUnionFunction.INSTANCE); - } + public StepUnionOperator(long id) { + super(id, StepUnionFunction.INSTANCE); + } - @Override - public void open(TraversalRuntimeContext context) { - super.open(context); - List inputIds = context.getTopology().getInputIds(id); - for (long inputId : inputIds) { - PathType pathType = context.getTopology().getOperator(inputId).getOutputPathSchema(); - this.inputSchemas.put(inputId, pathType); - } + @Override + public void open(TraversalRuntimeContext context) { + super.open(context); + List inputIds = context.getTopology().getInputIds(id); + for (long inputId : inputIds) { + PathType pathType = context.getTopology().getOperator(inputId).getOutputPathSchema(); + this.inputSchemas.put(inputId, pathType); } + } - @Override - protected void processRecord(StepRecordWithPath record) { - final long inputId = context.getInputOperatorId(); - PathType pathType = inputSchemas.get(inputId); - StepRecordWithPath unionPath = record.mapPath(path -> - SchemaUtil.alignToPathSchema(path, pathType, outputPathSchema), null); - collect(unionPath); - } + @Override + protected void processRecord(StepRecordWithPath record) { + final long inputId = context.getInputOperatorId(); + PathType pathType = inputSchemas.get(inputId); + StepRecordWithPath unionPath = + record.mapPath( + path -> SchemaUtil.alignToPathSchema(path, pathType, outputPathSchema), null); + collect(unionPath); + } - @Override - protected PathType concatInputPathType() { - return this.getOutputPathSchema(); - } + @Override + protected PathType concatInputPathType() { + return this.getOutputPathSchema(); + } - @Override - public StepOperator copyInternal() { - return new StepUnionOperator(id); - } + @Override + public StepOperator copyInternal() { + return new StepUnionOperator(id); + } - public static class StepUnionFunction implements StepFunction { + public static class StepUnionFunction implements StepFunction { - public static final StepUnionFunction INSTANCE = new StepUnionFunction(); + public static final StepUnionFunction INSTANCE = new StepUnionFunction(); - private StepUnionFunction() { + private StepUnionFunction() {} - } - - @Override - public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { - - } - - @Override - public void finish(StepCollector collector) { + @Override + public void open(TraversalRuntimeContext context, FunctionSchemas schemas) {} - } + @Override + public void finish(StepCollector collector) {} - @Override - public List getExpressions() { - return Collections.emptyList(); - } + @Override + public List getExpressions() { + return Collections.emptyList(); + } - @Override - public StepFunction copy(List expressions) { - return new StepUnionFunction(); - } + @Override + public StepFunction copy(List expressions) { + return new StepUnionFunction(); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/AbstractSingleTreePath.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/AbstractSingleTreePath.java index 1da81aaf6..f2edc5178 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/AbstractSingleTreePath.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/AbstractSingleTreePath.java @@ -19,40 +19,39 @@ package org.apache.geaflow.dsl.runtime.traversal.path; -import com.google.common.collect.Sets; import java.util.Collection; import java.util.HashSet; import java.util.Set; +import com.google.common.collect.Sets; + public abstract class AbstractSingleTreePath extends AbstractTreePath { - /** - * The request ids of this node belongs to. - */ - protected Set requestIds; - - @Override - public void setRequestId(Object requestId) { - if (requestId == null) { - requestIds = null; - } else { - requestIds = Sets.newHashSet(requestId); - } - } + /** The request ids of this node belongs to. */ + protected Set requestIds; - @Override - public Set getRequestIds() { - return requestIds; + @Override + public void setRequestId(Object requestId) { + if (requestId == null) { + requestIds = null; + } else { + requestIds = Sets.newHashSet(requestId); } + } + + @Override + public Set getRequestIds() { + return requestIds; + } - @Override - public void addRequestIds(Collection requestIds) { - if (requestIds == null) { - return; - } - if (this.requestIds == null) { - this.requestIds = new HashSet<>(); - } - this.requestIds.addAll(requestIds); + @Override + public void addRequestIds(Collection requestIds) { + if (requestIds == null) { + return; + } + if (this.requestIds == null) { + this.requestIds = new HashSet<>(); } + this.requestIds.addAll(requestIds); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/AbstractTreePath.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/AbstractTreePath.java index d5a9e6c58..625198c86 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/AbstractTreePath.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/AbstractTreePath.java @@ -21,7 +21,6 @@ import static org.apache.geaflow.dsl.runtime.traversal.path.TreePaths.createTreePath; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -29,6 +28,7 @@ import java.util.List; import java.util.Objects; import java.util.Set; + import org.apache.geaflow.common.utils.ArrayUtil; import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.Row; @@ -40,444 +40,463 @@ import org.apache.geaflow.dsl.runtime.traversal.message.IMessage; import org.apache.geaflow.dsl.runtime.traversal.message.MessageType; +import com.google.common.collect.Lists; + public abstract class AbstractTreePath implements ITreePath { - @Override - public boolean isEmpty() { - return size() == 0; - } + @Override + public boolean isEmpty() { + return size() == 0; + } - @Override - public void addParent(ITreePath parent) { - if (parent.getNodeType() != NodeType.EMPTY_TREE) { - getParents().add(parent); - } + @Override + public void addParent(ITreePath parent) { + if (parent.getNodeType() != NodeType.EMPTY_TREE) { + getParents().add(parent); } - - @SuppressWarnings("unchecked") - @Override - public ITreePath merge(ITreePath other) { - if (other.getNodeType() == NodeType.UNION_TREE) { - return other.merge(this); - } else if (other.getNodeType() == NodeType.EMPTY_TREE) { - return this; - } - if (other.getNodeType() != getNodeType() - && !(getNodeType() == NodeType.VERTEX_TREE && getVertex() == null) - && !(other.getNodeType() == NodeType.VERTEX_TREE && other.getVertex() == null)) { - throw new GeaFlowDSLException("Merge with different tree kinds: " + getNodeType() - + " and " + other.getNodeType()); - } - - if (this.equalNode(other) && getDepth() == other.getDepth()) { - for (ITreePath parent : other.getParents()) { - addParent(parent); - } - this.addRequestIds(other.getRequestIds()); - return this; - } else { - return UnionTreePath.create(Lists.newArrayList(this, other)).optimize(); - } + } + + @SuppressWarnings("unchecked") + @Override + public ITreePath merge(ITreePath other) { + if (other.getNodeType() == NodeType.UNION_TREE) { + return other.merge(this); + } else if (other.getNodeType() == NodeType.EMPTY_TREE) { + return this; } - - @Override - public ITreePath optimize() { - if (getParents().isEmpty()) { - return this; - } - List optimizedParents = new ArrayList<>(getParents().size()); - for (ITreePath parent : getParents()) { - optimizedParents.add(parent.optimize()); - } - - List mergedParents = new ArrayList<>(optimizedParents.size()); - mergedParents.add(optimizedParents.get(0).copy()); - - for (int i = 1; i < optimizedParents.size(); i++) { - ITreePath parent = optimizedParents.get(i); - mergeParent(parent, mergedParents); - } - return copy(mergedParents); + if (other.getNodeType() != getNodeType() + && !(getNodeType() == NodeType.VERTEX_TREE && getVertex() == null) + && !(other.getNodeType() == NodeType.VERTEX_TREE && other.getVertex() == null)) { + throw new GeaFlowDSLException( + "Merge with different tree kinds: " + getNodeType() + " and " + other.getNodeType()); } - private void mergeParent(ITreePath parent, List mergedParents) { - boolean hasMerged = false; - if (parent.getNodeType() == NodeType.EDGE_TREE) { - for (ITreePath mergedParent : mergedParents) { - if (mergedParent.getNodeType() == NodeType.EDGE_TREE - && mergedParent.getEdgeSet().like(parent.getEdgeSet()) - && Objects.equals(mergedParent.getParents(), parent.getParents())) { - mergedParent.getEdgeSet().addEdges(parent.getEdgeSet()); - hasMerged = true; - break; - } - } - } - if (!hasMerged) { - mergedParents.add(parent); - } + if (this.equalNode(other) && getDepth() == other.getDepth()) { + for (ITreePath parent : other.getParents()) { + addParent(parent); + } + this.addRequestIds(other.getRequestIds()); + return this; + } else { + return UnionTreePath.create(Lists.newArrayList(this, other)).optimize(); } + } - - @Override - public ITreePath extendTo(Set requestIds, List edges) { - EdgeSet edgeSet = new DefaultEdgeSet(edges); - ITreePath newTreePath = EdgeTreePath.of(requestIds, edgeSet); - newTreePath.addParent(this); - return newTreePath; + @Override + public ITreePath optimize() { + if (getParents().isEmpty()) { + return this; } - - @Override - public ITreePath extendTo(Set requestIds, RowVertex vertex) { - ITreePath newTreePath = VertexTreePath.of(requestIds, vertex); - newTreePath.addParent(this); - return newTreePath; + List optimizedParents = new ArrayList<>(getParents().size()); + for (ITreePath parent : getParents()) { + optimizedParents.add(parent.optimize()); } - @Override - public ITreePath limit(int n) { - if (n < 0 || n == Integer.MAX_VALUE) { - return this; - } - final List limitPaths = new ArrayList<>(); - walkTree(paths -> { - int rest = n - limitPaths.size(); - if (rest > 0) { - limitPaths.addAll(paths.subList(0, Math.min(rest, paths.size()))); - return true; - } else { - return false; - } - }); - return createTreePath(limitPaths); - } + List mergedParents = new ArrayList<>(optimizedParents.size()); + mergedParents.add(optimizedParents.get(0).copy()); - @Override - public ITreePath filter(PathFilterFunction filterFunction, int[] refPathIndices) { - int depth = getDepth(); - int[] mapping = createMappingIndices(refPathIndices, depth); - return filter(filterFunction, refPathIndices, mapping, new DefaultPath(), depth, new PathIdCounter()); + for (int i = 1; i < optimizedParents.size(); i++) { + ITreePath parent = optimizedParents.get(i); + mergeParent(parent, mergedParents); } - - /** - * Filter on the tree path.This is a fast implementation which filter the tree without - * expand this tree and rebuild the filter paths to a tree. - */ - protected ITreePath filter(PathFilterFunction filterFunction, - int[] refPathIndices, - int[] fieldMapping, - Path currentPath, - int maxDepth, - PathIdCounter pathId) { - if (refPathIndices.length == 0) { - // filter function has not referred any fields in the path. - if (filterFunction.accept(null)) { - return this; - } - return EmptyTreePath.of(); - } - int pathIndex = maxDepth - currentPath.size() - 1; - int parentSize = getParents().size(); - switch (getNodeType()) { - case VERTEX_TREE: - currentPath.addNode(getVertex()); - break; - case EDGE_TREE: - // If this edge set is referred by the filter function, do filter - // for each edge in the set. - if (Arrays.binarySearch(refPathIndices, pathIndex) >= 0) { - EdgeSet edges = getEdgeSet(); - List filterTrees = new ArrayList<>(); - for (RowEdge edge : edges) { - currentPath.addNode(edge); - // if the parent is empty or reach the last referred path node in the filter function. - if (parentSize == 0 || pathIndex == refPathIndices[0]) { - // Align the field indices of the current path with the referred index in the function. - FieldAlignPath alignPath = new FieldAlignPath(currentPath, fieldMapping); - alignPath.setId(pathId.getAndInc()); - if (filterFunction.accept(alignPath)) { - EdgeTreePath edgeTreePath = EdgeTreePath.of(null, edge); - if (parentSize > 0) { - edgeTreePath.addParent(getParents().get(0)); - } - filterTrees.add(edgeTreePath); - } - } else if (parentSize >= 1) { - for (ITreePath parent : getParents()) { - ITreePath filterTree = ((AbstractTreePath) parent).filter(filterFunction, - refPathIndices, fieldMapping, currentPath, maxDepth, pathId); - if (!filterTree.isEmpty()) { - filterTrees.add(filterTree.extendTo(edge)); - } - } - } - currentPath.remove(currentPath.size() - 1); - } - return UnionTreePath.create(filterTrees); - } else { // edge is not referred in the filter function, so add null to the current path. - currentPath.addNode(null); - } - break; - default: - throw new IllegalArgumentException("Illegal tree node: " + getNodeType()); + return copy(mergedParents); + } + + private void mergeParent(ITreePath parent, List mergedParents) { + boolean hasMerged = false; + if (parent.getNodeType() == NodeType.EDGE_TREE) { + for (ITreePath mergedParent : mergedParents) { + if (mergedParent.getNodeType() == NodeType.EDGE_TREE + && mergedParent.getEdgeSet().like(parent.getEdgeSet()) + && Objects.equals(mergedParent.getParents(), parent.getParents())) { + mergedParent.getEdgeSet().addEdges(parent.getEdgeSet()); + hasMerged = true; + break; } - // reach the last referred path node. (refPathIndices is sorted, so refPathIndices[0] is the - // last referred path field). - if (pathIndex == refPathIndices[0]) { - // Align the field indices of the current path with the referred index in the function. - FieldAlignPath alignPath = new FieldAlignPath(currentPath, fieldMapping); - alignPath.setId(pathId.getAndInc()); - boolean accept = filterFunction.accept(alignPath); - // remove current node before return. - currentPath.remove(currentPath.size() - 1); - if (accept) { - return this; - } - return EmptyTreePath.of(); - } - // filter parent tree - List filterParents = new ArrayList<>(parentSize); - for (ITreePath parent : getParents()) { - ITreePath filterTree = ((AbstractTreePath) parent).filter(filterFunction, refPathIndices, - fieldMapping, currentPath, maxDepth, pathId); - if (!filterTree.isEmpty()) { - filterParents.add(filterTree); - } - } - // remove current node before return. - currentPath.remove(currentPath.size() - 1); - if (filterParents.size() > 0) { - return copy(filterParents); - } - // If all the parents has be filtered, then this tree will be filtered and just return an empty tree - // to the child node. - return EmptyTreePath.of(); + } } - - /** - * Create field mapping for the referred path indices. - * - *

e.g. The refPathIndices is: [1, 3], the total path field is: 4, then the path layout is: [3, 2, 1, 0] - * which is the reverse order, Then the mapping index is: [-1, 2, -1, 0] which will mapping $3 to $0 in the path - * layout, mapping $1 to $2, for other field not exists in the referring indices, will map to -1 which - * means the field not exists.You can also see {@link FieldAlignPath} for more information. - */ - private int[] createMappingIndices(int[] refPathIndices, int totalPathField) { - if (refPathIndices.length == 0) { - return new int[0]; - } - int[] mapping = new int[refPathIndices[refPathIndices.length - 1] + 1]; - Arrays.fill(mapping, -1); - for (int i = 0; i < refPathIndices.length; i++) { - mapping[refPathIndices[i]] = totalPathField - refPathIndices[i] - 1; - } - return mapping; + if (!hasMerged) { + mergedParents.add(parent); } - - @Override - public ITreePath mapTree(PathMapFunction mapFunction) { - List mapPaths = map(mapFunction); - return createTreePath(mapPaths); + } + + @Override + public ITreePath extendTo(Set requestIds, List edges) { + EdgeSet edgeSet = new DefaultEdgeSet(edges); + ITreePath newTreePath = EdgeTreePath.of(requestIds, edgeSet); + newTreePath.addParent(this); + return newTreePath; + } + + @Override + public ITreePath extendTo(Set requestIds, RowVertex vertex) { + ITreePath newTreePath = VertexTreePath.of(requestIds, vertex); + newTreePath.addParent(this); + return newTreePath; + } + + @Override + public ITreePath limit(int n) { + if (n < 0 || n == Integer.MAX_VALUE) { + return this; } - - @Override - public List toList() { - final List pathList = new ArrayList<>(); - walkTree(paths -> { - pathList.addAll(paths); + final List limitPaths = new ArrayList<>(); + walkTree( + paths -> { + int rest = n - limitPaths.size(); + if (rest > 0) { + limitPaths.addAll(paths.subList(0, Math.min(rest, paths.size()))); return true; + } else { + return false; + } }); - return pathList; + return createTreePath(limitPaths); + } + + @Override + public ITreePath filter(PathFilterFunction filterFunction, int[] refPathIndices) { + int depth = getDepth(); + int[] mapping = createMappingIndices(refPathIndices, depth); + return filter( + filterFunction, refPathIndices, mapping, new DefaultPath(), depth, new PathIdCounter()); + } + + /** + * Filter on the tree path.This is a fast implementation which filter the tree without expand this + * tree and rebuild the filter paths to a tree. + */ + protected ITreePath filter( + PathFilterFunction filterFunction, + int[] refPathIndices, + int[] fieldMapping, + Path currentPath, + int maxDepth, + PathIdCounter pathId) { + if (refPathIndices.length == 0) { + // filter function has not referred any fields in the path. + if (filterFunction.accept(null)) { + return this; + } + return EmptyTreePath.of(); } - - @Override - public List select(int... pathIndices) { - if (ArrayUtil.isEmpty(pathIndices)) { - return new ArrayList<>(); - } - int minIndex = pathIndices[0]; - int maxIndex = pathIndices[0]; - for (int index : pathIndices) { - if (index < minIndex) { - minIndex = index; - } - if (index > maxIndex) { - maxIndex = index; + int pathIndex = maxDepth - currentPath.size() - 1; + int parentSize = getParents().size(); + switch (getNodeType()) { + case VERTEX_TREE: + currentPath.addNode(getVertex()); + break; + case EDGE_TREE: + // If this edge set is referred by the filter function, do filter + // for each edge in the set. + if (Arrays.binarySearch(refPathIndices, pathIndex) >= 0) { + EdgeSet edges = getEdgeSet(); + List filterTrees = new ArrayList<>(); + for (RowEdge edge : edges) { + currentPath.addNode(edge); + // if the parent is empty or reach the last referred path node in the filter function. + if (parentSize == 0 || pathIndex == refPathIndices[0]) { + // Align the field indices of the current path with the referred index in the + // function. + FieldAlignPath alignPath = new FieldAlignPath(currentPath, fieldMapping); + alignPath.setId(pathId.getAndInc()); + if (filterFunction.accept(alignPath)) { + EdgeTreePath edgeTreePath = EdgeTreePath.of(null, edge); + if (parentSize > 0) { + edgeTreePath.addParent(getParents().get(0)); + } + filterTrees.add(edgeTreePath); + } + } else if (parentSize >= 1) { + for (ITreePath parent : getParents()) { + ITreePath filterTree = + ((AbstractTreePath) parent) + .filter( + filterFunction, + refPathIndices, + fieldMapping, + currentPath, + maxDepth, + pathId); + if (!filterTree.isEmpty()) { + filterTrees.add(filterTree.extendTo(edge)); + } + } } + currentPath.remove(currentPath.size() - 1); + } + return UnionTreePath.create(filterTrees); + } else { // edge is not referred in the filter function, so add null to the current path. + currentPath.addNode(null); } - //if select index is out boundary return empty list - if (maxIndex >= getDepth()) { - return new ArrayList<>(); - } - int maxDepth = getDepth() - minIndex; - // 调整原来的index - List newIndices = new ArrayList<>(); - for (int index : pathIndices) { - newIndices.add(index - minIndex); - } - - Set selectPaths = new HashSet<>(); - walkTree(paths -> { - for (Path path : paths) { - selectPaths.add(path.subPath(newIndices)); - } - return true; - }, maxDepth); - - return Lists.newArrayList(selectPaths); + break; + default: + throw new IllegalArgumentException("Illegal tree node: " + getNodeType()); } - - @Override - public ITreePath subPath(int... pathIndices) { - List paths = select(pathIndices); - return TreePaths.createTreePath(paths); + // reach the last referred path node. (refPathIndices is sorted, so refPathIndices[0] is the + // last referred path field). + if (pathIndex == refPathIndices[0]) { + // Align the field indices of the current path with the referred index in the function. + FieldAlignPath alignPath = new FieldAlignPath(currentPath, fieldMapping); + alignPath.setId(pathId.getAndInc()); + boolean accept = filterFunction.accept(alignPath); + // remove current node before return. + currentPath.remove(currentPath.size() - 1); + if (accept) { + return this; + } + return EmptyTreePath.of(); } - - public int getDepth() { - if (getParents().isEmpty()) { - return 1; - } - return getParents().get(0).getDepth() + 1; + // filter parent tree + List filterParents = new ArrayList<>(parentSize); + for (ITreePath parent : getParents()) { + ITreePath filterTree = + ((AbstractTreePath) parent) + .filter(filterFunction, refPathIndices, fieldMapping, currentPath, maxDepth, pathId); + if (!filterTree.isEmpty()) { + filterParents.add(filterTree); + } } - - protected void walkTree(WalkFunction walkFunction) { - walkTree(walkFunction, -1); + // remove current node before return. + currentPath.remove(currentPath.size() - 1); + if (filterParents.size() > 0) { + return copy(filterParents); } - - protected void walkTree(WalkFunction walkFunction, int maxDepth) { - walkTree(new ArrayList<>(), walkFunction, maxDepth, new PathIdCounter()); + // If all the parents has be filtered, then this tree will be filtered and just return an empty + // tree + // to the child node. + return EmptyTreePath.of(); + } + + /** + * Create field mapping for the referred path indices. + * + *

e.g. The refPathIndices is: [1, 3], the total path field is: 4, then the path layout is: [3, + * 2, 1, 0] which is the reverse order, Then the mapping index is: [-1, 2, -1, 0] which will + * mapping $3 to $0 in the path layout, mapping $1 to $2, for other field not exists in the + * referring indices, will map to -1 which means the field not exists.You can also see {@link + * FieldAlignPath} for more information. + */ + private int[] createMappingIndices(int[] refPathIndices, int totalPathField) { + if (refPathIndices.length == 0) { + return new int[0]; } - - @SuppressWarnings("unchecked") - public boolean walkTree(List pathNodes, WalkFunction walkFunction, int maxDepth, PathIdCounter pathId) { - boolean isContinue = true; - switch (getNodeType()) { - case VERTEX_TREE: - pathNodes.add(getVertex()); - break; - case EDGE_TREE: - pathNodes.add(getEdgeSet()); - break; - default: - throw new IllegalArgumentException("Cannot walk on this kind of tree:" + getNodeType()); - } - // Reach the last node - if (getParents().isEmpty() || pathNodes.size() == maxDepth) { - List paths = new ArrayList<>(); - paths.add(new DefaultPath()); - - for (int i = pathNodes.size() - 1; i >= 0; i--) { - Object pathNode = pathNodes.get(i); - if (pathNode instanceof RowVertex) { - for (Path path : paths) { - path.addNode((Row) pathNode); - } - } else if (pathNode instanceof EdgeSet) { - EdgeSet edgeSet = (EdgeSet) pathNode; - List newPaths = new ArrayList<>(paths.size() * edgeSet.size()); - for (Path path : paths) { - for (RowEdge edge : edgeSet) { - Path newPath = path.copy(); - newPath.addNode(edge); - newPaths.add(newPath); - } - } - paths = newPaths; - } else if (pathNode == null) { - for (Path path : paths) { - path.addNode(null); - } - } else { - throw new IllegalArgumentException("Illegal path node: " + pathNode); - } - } - // set id to the path. - for (Path path : paths) { - path.setId(pathId.getAndInc()); - } - isContinue = walkFunction.onWalk(paths); - } else { - for (ITreePath parent : getParents()) { - isContinue = parent.walkTree(pathNodes, walkFunction, maxDepth, pathId); - if (!isContinue) { - break; - } - } - } - - pathNodes.remove(pathNodes.size() - 1); - return isContinue; + int[] mapping = new int[refPathIndices[refPathIndices.length - 1] + 1]; + Arrays.fill(mapping, -1); + for (int i = 0; i < refPathIndices.length; i++) { + mapping[refPathIndices[i]] = totalPathField - refPathIndices[i] - 1; } + return mapping; + } + + @Override + public ITreePath mapTree(PathMapFunction mapFunction) { + List mapPaths = map(mapFunction); + return createTreePath(mapPaths); + } + + @Override + public List toList() { + final List pathList = new ArrayList<>(); + walkTree( + paths -> { + pathList.addAll(paths); + return true; + }); + return pathList; + } - @Override - public void setRequestIdForTree(Object requestId) { - for (ITreePath parent : getParents()) { - parent.setRequestIdForTree(requestId); - } - setRequestId(requestId); + @Override + public List select(int... pathIndices) { + if (ArrayUtil.isEmpty(pathIndices)) { + return new ArrayList<>(); } - - @Override - public MessageType getType() { - return MessageType.PATH; + int minIndex = pathIndices[0]; + int maxIndex = pathIndices[0]; + for (int index : pathIndices) { + if (index < minIndex) { + minIndex = index; + } + if (index > maxIndex) { + maxIndex = index; + } + } + // if select index is out boundary return empty list + if (maxIndex >= getDepth()) { + return new ArrayList<>(); + } + int maxDepth = getDepth() - minIndex; + // 调整原来的index + List newIndices = new ArrayList<>(); + for (int index : pathIndices) { + newIndices.add(index - minIndex); } - @Override - public IMessage combine(IMessage other) { - if (this.getNodeType() == NodeType.EMPTY_TREE) { - return other; - } - if (((ITreePath) other).getNodeType() == NodeType.EMPTY_TREE) { - return this; - } - List nodes = new ArrayList<>(); - if (this instanceof UnionTreePath) { - nodes.addAll(((UnionTreePath) this).getNodes()); + Set selectPaths = new HashSet<>(); + walkTree( + paths -> { + for (Path path : paths) { + selectPaths.add(path.subPath(newIndices)); + } + return true; + }, + maxDepth); + + return Lists.newArrayList(selectPaths); + } + + @Override + public ITreePath subPath(int... pathIndices) { + List paths = select(pathIndices); + return TreePaths.createTreePath(paths); + } + + public int getDepth() { + if (getParents().isEmpty()) { + return 1; + } + return getParents().get(0).getDepth() + 1; + } + + protected void walkTree(WalkFunction walkFunction) { + walkTree(walkFunction, -1); + } + + protected void walkTree(WalkFunction walkFunction, int maxDepth) { + walkTree(new ArrayList<>(), walkFunction, maxDepth, new PathIdCounter()); + } + + @SuppressWarnings("unchecked") + public boolean walkTree( + List pathNodes, WalkFunction walkFunction, int maxDepth, PathIdCounter pathId) { + boolean isContinue = true; + switch (getNodeType()) { + case VERTEX_TREE: + pathNodes.add(getVertex()); + break; + case EDGE_TREE: + pathNodes.add(getEdgeSet()); + break; + default: + throw new IllegalArgumentException("Cannot walk on this kind of tree:" + getNodeType()); + } + // Reach the last node + if (getParents().isEmpty() || pathNodes.size() == maxDepth) { + List paths = new ArrayList<>(); + paths.add(new DefaultPath()); + + for (int i = pathNodes.size() - 1; i >= 0; i--) { + Object pathNode = pathNodes.get(i); + if (pathNode instanceof RowVertex) { + for (Path path : paths) { + path.addNode((Row) pathNode); + } + } else if (pathNode instanceof EdgeSet) { + EdgeSet edgeSet = (EdgeSet) pathNode; + List newPaths = new ArrayList<>(paths.size() * edgeSet.size()); + for (Path path : paths) { + for (RowEdge edge : edgeSet) { + Path newPath = path.copy(); + newPath.addNode(edge); + newPaths.add(newPath); + } + } + paths = newPaths; + } else if (pathNode == null) { + for (Path path : paths) { + path.addNode(null); + } } else { - nodes.add(this); + throw new IllegalArgumentException("Illegal path node: " + pathNode); } - if (other instanceof UnionTreePath) { - nodes.addAll(((UnionTreePath) other).getNodes()); - } else { - nodes.add((ITreePath) other); + } + // set id to the path. + for (Path path : paths) { + path.setId(pathId.getAndInc()); + } + isContinue = walkFunction.onWalk(paths); + } else { + for (ITreePath parent : getParents()) { + isContinue = parent.walkTree(pathNodes, walkFunction, maxDepth, pathId); + if (!isContinue) { + break; } - return UnionTreePath.create(nodes); + } } - @Override - public ITreePath getMessageByRequestId(Object requestId) { - return this.getTreePath(requestId); - } + pathNodes.remove(pathNodes.size() - 1); + return isContinue; + } - @Override - public List map(PathMapFunction mapFunction) { - List paths = toList(); - List results = new ArrayList<>(); - for (Path path : paths) { - O result = mapFunction.map(path); - results.add(result); - } - return results; + @Override + public void setRequestIdForTree(Object requestId) { + for (ITreePath parent : getParents()) { + parent.setRequestIdForTree(requestId); } - - @Override - public List flatMap(PathFlatMapFunction flatMapFunction) { - List paths = toList(); - List finalResults = new ArrayList<>(); - for (Path path : paths) { - Collection results = flatMapFunction.flatMap(path); - finalResults.addAll(results); - } - return finalResults; + setRequestId(requestId); + } + + @Override + public MessageType getType() { + return MessageType.PATH; + } + + @Override + public IMessage combine(IMessage other) { + if (this.getNodeType() == NodeType.EMPTY_TREE) { + return other; } - - @Override - public ITreePath extendTo(RowEdge edge) { - return extendTo(null, Lists.newArrayList(edge)); + if (((ITreePath) other).getNodeType() == NodeType.EMPTY_TREE) { + return this; } - - @Override - public ITreePath extendTo(RowVertex vertex) { - return extendTo(null, vertex); + List nodes = new ArrayList<>(); + if (this instanceof UnionTreePath) { + nodes.addAll(((UnionTreePath) this).getNodes()); + } else { + nodes.add(this); + } + if (other instanceof UnionTreePath) { + nodes.addAll(((UnionTreePath) other).getNodes()); + } else { + nodes.add((ITreePath) other); + } + return UnionTreePath.create(nodes); + } + + @Override + public ITreePath getMessageByRequestId(Object requestId) { + return this.getTreePath(requestId); + } + + @Override + public List map(PathMapFunction mapFunction) { + List paths = toList(); + List results = new ArrayList<>(); + for (Path path : paths) { + O result = mapFunction.map(path); + results.add(result); + } + return results; + } + + @Override + public List flatMap(PathFlatMapFunction flatMapFunction) { + List paths = toList(); + List finalResults = new ArrayList<>(); + for (Path path : paths) { + Collection results = flatMapFunction.flatMap(path); + finalResults.addAll(results); } -} \ No newline at end of file + return finalResults; + } + + @Override + public ITreePath extendTo(RowEdge edge) { + return extendTo(null, Lists.newArrayList(edge)); + } + + @Override + public ITreePath extendTo(RowVertex vertex) { + return extendTo(null, vertex); + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/DefaultEdgeSet.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/DefaultEdgeSet.java index 756921293..f54819405 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/DefaultEdgeSet.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/DefaultEdgeSet.java @@ -19,87 +19,89 @@ package org.apache.geaflow.dsl.runtime.traversal.path; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Objects; -import org.apache.geaflow.dsl.common.data.RowEdge; - -public class DefaultEdgeSet implements EdgeSet { - - private final List edges; - - public DefaultEdgeSet(List edges) { - this.edges = Objects.requireNonNull(edges, "edges is null"); - } - - public DefaultEdgeSet() { - this(new ArrayList<>()); - } - - @Override - public Iterator iterator() { - return edges.iterator(); - } - - @Override - public void addEdge(RowEdge edge) { - edges.add(edge); - } - @Override - public int size() { - return edges.size(); - } - - @Override - public Object getSrcId() { - return edges.get(0).getSrcId(); - } - - @Override - public Object getTargetId() { - return edges.get(0).getTargetId(); - } +import org.apache.geaflow.dsl.common.data.RowEdge; - @Override - public EdgeSet copy() { - return new DefaultEdgeSet(Lists.newArrayList(edges)); - } +import com.google.common.collect.Lists; - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof DefaultEdgeSet)) { - return false; - } - DefaultEdgeSet that = (DefaultEdgeSet) o; - return Objects.equals(edges, that.edges); - } +public class DefaultEdgeSet implements EdgeSet { - @Override - public int hashCode() { - return Objects.hash(edges); + private final List edges; + + public DefaultEdgeSet(List edges) { + this.edges = Objects.requireNonNull(edges, "edges is null"); + } + + public DefaultEdgeSet() { + this(new ArrayList<>()); + } + + @Override + public Iterator iterator() { + return edges.iterator(); + } + + @Override + public void addEdge(RowEdge edge) { + edges.add(edge); + } + + @Override + public int size() { + return edges.size(); + } + + @Override + public Object getSrcId() { + return edges.get(0).getSrcId(); + } + + @Override + public Object getTargetId() { + return edges.get(0).getTargetId(); + } + + @Override + public EdgeSet copy() { + return new DefaultEdgeSet(Lists.newArrayList(edges)); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public String toString() { - return getSrcId() + "#" + getTargetId() + "(size:" + size() + ")"; + if (!(o instanceof DefaultEdgeSet)) { + return false; } - - @Override - public void addEdges(EdgeSet edgeSet) { - for (RowEdge edge : edgeSet) { - addEdge(edge); - } + DefaultEdgeSet that = (DefaultEdgeSet) o; + return Objects.equals(edges, that.edges); + } + + @Override + public int hashCode() { + return Objects.hash(edges); + } + + @Override + public String toString() { + return getSrcId() + "#" + getTargetId() + "(size:" + size() + ")"; + } + + @Override + public void addEdges(EdgeSet edgeSet) { + for (RowEdge edge : edgeSet) { + addEdge(edge); } + } - @Override - public boolean like(EdgeSet other) { - return Objects.equals(getSrcId(), other.getSrcId()) - && Objects.equals(getTargetId(), other.getTargetId()); - } + @Override + public boolean like(EdgeSet other) { + return Objects.equals(getSrcId(), other.getSrcId()) + && Objects.equals(getTargetId(), other.getTargetId()); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/EdgeSet.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/EdgeSet.java index 45cf71c87..74f996831 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/EdgeSet.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/EdgeSet.java @@ -23,17 +23,17 @@ public interface EdgeSet extends Iterable { - void addEdge(RowEdge edge); + void addEdge(RowEdge edge); - int size(); + int size(); - Object getSrcId(); + Object getSrcId(); - Object getTargetId(); + Object getTargetId(); - void addEdges(EdgeSet edgeSet); + void addEdges(EdgeSet edgeSet); - EdgeSet copy(); + EdgeSet copy(); - boolean like(EdgeSet other); + boolean like(EdgeSet other); } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/EdgeTreePath.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/EdgeTreePath.java index 9f2dfb398..95f56d1ed 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/EdgeTreePath.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/EdgeTreePath.java @@ -19,178 +19,175 @@ package org.apache.geaflow.dsl.runtime.traversal.path; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.Serializer; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; -import com.google.common.collect.Sets; import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.Set; + import org.apache.geaflow.common.utils.ArrayUtil; import org.apache.geaflow.dsl.common.data.RowEdge; import org.apache.geaflow.dsl.common.data.RowVertex; -public class EdgeTreePath extends AbstractSingleTreePath { - - /** - * The parent nodes. - */ - private List parents = new ArrayList<>(); - - /** - * Edge sets with same (srcId, targetId) pair. - */ - private final EdgeSet edges; - - private EdgeTreePath(Set requestIds, EdgeSet edges) { - this.edges = Objects.requireNonNull(edges, "edges is null"); - this.requestIds = requestIds; - } - - public static EdgeTreePath of(Set requestIds, EdgeSet edges) { - return new EdgeTreePath(requestIds, edges); - } - - public static EdgeTreePath of(Object requestId, RowEdge edge) { - if (requestId == null) { - return of(null, edge); - } - return of(Sets.newHashSet(requestId), edge); - } - - public static EdgeTreePath of(Set requestIds, RowEdge edge) { - EdgeSet edgeSet = new DefaultEdgeSet(); - edgeSet.addEdge(edge); - return new EdgeTreePath(requestIds, edgeSet); - } - - @Override - public RowVertex getVertex() { - throw new IllegalArgumentException("Illegal call"); - } - - @Override - public void setVertex(RowVertex vertex) { - throw new IllegalArgumentException("Illegal call"); - } - - @Override - public Object getVertexId() { - return edges.getTargetId(); - } - - @Override - public NodeType getNodeType() { - return NodeType.EDGE_TREE; - } - - @Override - public List getParents() { - return parents; - } - - @Override - public EdgeSet getEdgeSet() { - return edges; - } - - @Override - public ITreePath copy() { - EdgeTreePath copyTree = new EdgeTreePath(ArrayUtil.copySet(requestIds), edges.copy()); - List copyParents = new ArrayList<>(); - for (ITreePath parent : parents) { - copyParents.add(parent.copy()); - } - copyTree.parents = copyParents; - return copyTree; - } - - @Override - public ITreePath copy(List parents) { - EdgeTreePath copyTree = new EdgeTreePath(ArrayUtil.copySet(requestIds), edges.copy()); - copyTree.parents = parents; - return copyTree; - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import com.google.common.collect.Sets; - @Override - public ITreePath getTreePath(Object sessionId) { - if (requestIds != null && !requestIds.contains(sessionId)) { - return null; - } - ITreePath treePathOnSession = new EdgeTreePath(Sets.newHashSet(sessionId), edges.copy()); - for (ITreePath parent : parents) { - ITreePath sessionParent = parent.getTreePath(sessionId); - if (sessionParent != null) { - treePathOnSession.addParent(sessionParent); - } - } - return treePathOnSession; - } +public class EdgeTreePath extends AbstractSingleTreePath { - @Override - public int size() { - if (parents.isEmpty()) { - return edges.size(); - } - int parentSize = 0; - for (ITreePath parent : parents) { - parentSize += parent.size(); - } - return edges.size() * parentSize; - } + /** The parent nodes. */ + private List parents = new ArrayList<>(); + + /** Edge sets with same (srcId, targetId) pair. */ + private final EdgeSet edges; + + private EdgeTreePath(Set requestIds, EdgeSet edges) { + this.edges = Objects.requireNonNull(edges, "edges is null"); + this.requestIds = requestIds; + } + + public static EdgeTreePath of(Set requestIds, EdgeSet edges) { + return new EdgeTreePath(requestIds, edges); + } + + public static EdgeTreePath of(Object requestId, RowEdge edge) { + if (requestId == null) { + return of(null, edge); + } + return of(Sets.newHashSet(requestId), edge); + } + + public static EdgeTreePath of(Set requestIds, RowEdge edge) { + EdgeSet edgeSet = new DefaultEdgeSet(); + edgeSet.addEdge(edge); + return new EdgeTreePath(requestIds, edgeSet); + } + + @Override + public RowVertex getVertex() { + throw new IllegalArgumentException("Illegal call"); + } + + @Override + public void setVertex(RowVertex vertex) { + throw new IllegalArgumentException("Illegal call"); + } + + @Override + public Object getVertexId() { + return edges.getTargetId(); + } + + @Override + public NodeType getNodeType() { + return NodeType.EDGE_TREE; + } + + @Override + public List getParents() { + return parents; + } + + @Override + public EdgeSet getEdgeSet() { + return edges; + } + + @Override + public ITreePath copy() { + EdgeTreePath copyTree = new EdgeTreePath(ArrayUtil.copySet(requestIds), edges.copy()); + List copyParents = new ArrayList<>(); + for (ITreePath parent : parents) { + copyParents.add(parent.copy()); + } + copyTree.parents = copyParents; + return copyTree; + } + + @Override + public ITreePath copy(List parents) { + EdgeTreePath copyTree = new EdgeTreePath(ArrayUtil.copySet(requestIds), edges.copy()); + copyTree.parents = parents; + return copyTree; + } + + @Override + public ITreePath getTreePath(Object sessionId) { + if (requestIds != null && !requestIds.contains(sessionId)) { + return null; + } + ITreePath treePathOnSession = new EdgeTreePath(Sets.newHashSet(sessionId), edges.copy()); + for (ITreePath parent : parents) { + ITreePath sessionParent = parent.getTreePath(sessionId); + if (sessionParent != null) { + treePathOnSession.addParent(sessionParent); + } + } + return treePathOnSession; + } + + @Override + public int size() { + if (parents.isEmpty()) { + return edges.size(); + } + int parentSize = 0; + for (ITreePath parent : parents) { + parentSize += parent.size(); + } + return edges.size() * parentSize; + } + + @Override + public boolean equalNode(ITreePath other) { + if (other.getNodeType() == NodeType.EDGE_TREE) { + return Objects.equals(edges, other.getEdgeSet()); + } + return false; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof EdgeTreePath)) { + return false; + } + EdgeTreePath that = (EdgeTreePath) o; + return Objects.equals(parents, that.parents) + && Objects.equals(edges, that.edges) + && Objects.equals(requestIds, that.requestIds); + } + + @Override + public int hashCode() { + return Objects.hash(parents, edges, requestIds); + } + + public static class EdgeTreePathSerializer extends Serializer { @Override - public boolean equalNode(ITreePath other) { - if (other.getNodeType() == NodeType.EDGE_TREE) { - return Objects.equals(edges, other.getEdgeSet()); - } - return false; + public void write(Kryo kryo, Output output, EdgeTreePath object) { + kryo.writeClassAndObject(output, object.getRequestIds()); + kryo.writeClassAndObject(output, object.getParents()); + kryo.writeClassAndObject(output, object.getEdgeSet()); } @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof EdgeTreePath)) { - return false; - } - EdgeTreePath that = (EdgeTreePath) o; - return Objects.equals(parents, that.parents) && Objects.equals(edges, that.edges) - && Objects.equals(requestIds, that.requestIds); + public EdgeTreePath read(Kryo kryo, Input input, Class type) { + Set requestIds = (Set) kryo.readClassAndObject(input); + List parents = (List) kryo.readClassAndObject(input); + EdgeSet edges = (EdgeSet) kryo.readClassAndObject(input); + EdgeTreePath treePath = EdgeTreePath.of(requestIds, edges); + treePath.parents.addAll(parents); + return treePath; } @Override - public int hashCode() { - return Objects.hash(parents, edges, requestIds); - } - - - public static class EdgeTreePathSerializer extends Serializer { - - @Override - public void write(Kryo kryo, Output output, EdgeTreePath object) { - kryo.writeClassAndObject(output, object.getRequestIds()); - kryo.writeClassAndObject(output, object.getParents()); - kryo.writeClassAndObject(output, object.getEdgeSet()); - } - - @Override - public EdgeTreePath read(Kryo kryo, Input input, Class type) { - Set requestIds = (Set) kryo.readClassAndObject(input); - List parents = (List) kryo.readClassAndObject(input); - EdgeSet edges = (EdgeSet) kryo.readClassAndObject(input); - EdgeTreePath treePath = EdgeTreePath.of(requestIds, edges); - treePath.parents.addAll(parents); - return treePath; - } - - @Override - public EdgeTreePath copy(Kryo kryo, EdgeTreePath original) { - return (EdgeTreePath) original.copy(); - } + public EdgeTreePath copy(Kryo kryo, EdgeTreePath original) { + return (EdgeTreePath) original.copy(); } - + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/EmptyTreePath.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/EmptyTreePath.java index 5f81b9ef7..a77106c10 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/EmptyTreePath.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/EmptyTreePath.java @@ -19,136 +19,140 @@ package org.apache.geaflow.dsl.runtime.traversal.path; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Collections; import java.util.List; import java.util.Set; + import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.RowEdge; import org.apache.geaflow.dsl.common.data.RowVertex; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + public class EmptyTreePath extends AbstractSingleTreePath implements KryoSerializable { - public static final ITreePath INSTANCE = new EmptyTreePath(); - - private EmptyTreePath() { - - } - - @SuppressWarnings("unchecked") - public static EmptyTreePath of() { - return (EmptyTreePath) INSTANCE; - } - - @Override - public RowVertex getVertex() { - throw new IllegalArgumentException("Illegal call"); - } - - @Override - public void setVertex(RowVertex vertex) { - throw new IllegalArgumentException("Illegal call"); - } - - @Override - public Object getVertexId() { - throw new IllegalArgumentException("Illegal call"); - } - - @Override - public NodeType getNodeType() { - return NodeType.EMPTY_TREE; - } - - @Override - public List getParents() { - return Collections.emptyList(); - } - - @Override - public void addParent(ITreePath parent) { - throw new IllegalArgumentException("Illegal call"); - } - - @Override - public ITreePath merge(ITreePath other) { - return other; - } - - @Override - public EdgeSet getEdgeSet() { - throw new IllegalArgumentException("Illegal call"); - } - - @Override - public ITreePath copy() { - return new EmptyTreePath(); - } - - @Override - public ITreePath copy(List parents) { - return copy(); - } - - @Override - public ITreePath getTreePath(Object requestId) { - return INSTANCE; - } - - @Override - public boolean isEmpty() { - return true; - } - - @Override - public int size() { - return 0; - } - - @Override - public ITreePath extendTo(Set requestIds, List edges) { - EdgeSet edgeSet = new DefaultEdgeSet(edges); - return SourceEdgeTreePath.of(requestIds, edgeSet); - } - - @Override - public ITreePath extendTo(Set requestIds, RowVertex vertex) { - return SourceVertexTreePath.of(requestIds, vertex); - } - - @Override - public List select(int... pathIndices) { - return Collections.emptyList(); - } - - @Override - public boolean walkTree(List pathNodes, WalkFunction walkFunction, int maxDepth, PathIdCounter pathId) { - return false; - } - - @Override - protected ITreePath filter(PathFilterFunction filterFunction, - int[] refPathIndices, int[] fieldMapping, - Path currentPath, int maxDepth, PathIdCounter pathId) { - return EmptyTreePath.of(); - } - - @Override - public boolean equalNode(ITreePath other) { - return other.getNodeType() == NodeType.EMPTY_TREE; - } - - @Override - public void write(Kryo kryo, Output output) { - // no fields to serialize - } - - @Override - public void read(Kryo kryo, Input input) { - // no fields to deserialize - } - -} \ No newline at end of file + public static final ITreePath INSTANCE = new EmptyTreePath(); + + private EmptyTreePath() {} + + @SuppressWarnings("unchecked") + public static EmptyTreePath of() { + return (EmptyTreePath) INSTANCE; + } + + @Override + public RowVertex getVertex() { + throw new IllegalArgumentException("Illegal call"); + } + + @Override + public void setVertex(RowVertex vertex) { + throw new IllegalArgumentException("Illegal call"); + } + + @Override + public Object getVertexId() { + throw new IllegalArgumentException("Illegal call"); + } + + @Override + public NodeType getNodeType() { + return NodeType.EMPTY_TREE; + } + + @Override + public List getParents() { + return Collections.emptyList(); + } + + @Override + public void addParent(ITreePath parent) { + throw new IllegalArgumentException("Illegal call"); + } + + @Override + public ITreePath merge(ITreePath other) { + return other; + } + + @Override + public EdgeSet getEdgeSet() { + throw new IllegalArgumentException("Illegal call"); + } + + @Override + public ITreePath copy() { + return new EmptyTreePath(); + } + + @Override + public ITreePath copy(List parents) { + return copy(); + } + + @Override + public ITreePath getTreePath(Object requestId) { + return INSTANCE; + } + + @Override + public boolean isEmpty() { + return true; + } + + @Override + public int size() { + return 0; + } + + @Override + public ITreePath extendTo(Set requestIds, List edges) { + EdgeSet edgeSet = new DefaultEdgeSet(edges); + return SourceEdgeTreePath.of(requestIds, edgeSet); + } + + @Override + public ITreePath extendTo(Set requestIds, RowVertex vertex) { + return SourceVertexTreePath.of(requestIds, vertex); + } + + @Override + public List select(int... pathIndices) { + return Collections.emptyList(); + } + + @Override + public boolean walkTree( + List pathNodes, WalkFunction walkFunction, int maxDepth, PathIdCounter pathId) { + return false; + } + + @Override + protected ITreePath filter( + PathFilterFunction filterFunction, + int[] refPathIndices, + int[] fieldMapping, + Path currentPath, + int maxDepth, + PathIdCounter pathId) { + return EmptyTreePath.of(); + } + + @Override + public boolean equalNode(ITreePath other) { + return other.getNodeType() == NodeType.EMPTY_TREE; + } + + @Override + public void write(Kryo kryo, Output output) { + // no fields to serialize + } + + @Override + public void read(Kryo kryo, Input input) { + // no fields to deserialize + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/ITreePath.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/ITreePath.java index 949c07f5b..6a3b4e72e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/ITreePath.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/ITreePath.java @@ -22,6 +22,7 @@ import java.util.Collection; import java.util.List; import java.util.Set; + import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.RowEdge; import org.apache.geaflow.dsl.common.data.RowVertex; @@ -29,118 +30,115 @@ public interface ITreePath extends IPathMessage { - RowVertex getVertex(); + RowVertex getVertex(); - void setVertex(RowVertex vertex); + void setVertex(RowVertex vertex); - Object getVertexId(); + Object getVertexId(); - /** - * Get node type. - */ - NodeType getNodeType(); + /** Get node type. */ + NodeType getNodeType(); - List getParents(); + List getParents(); - void addParent(ITreePath parent); + void addParent(ITreePath parent); - ITreePath merge(ITreePath other); + ITreePath merge(ITreePath other); - EdgeSet getEdgeSet(); + EdgeSet getEdgeSet(); - ITreePath copy(); + ITreePath copy(); - ITreePath copy(List parents); + ITreePath copy(List parents); - ITreePath getTreePath(Object requestId); + ITreePath getTreePath(Object requestId); - Set getRequestIds(); + Set getRequestIds(); - void addRequestIds(Collection requestIds); + void addRequestIds(Collection requestIds); - boolean isEmpty(); + boolean isEmpty(); - int size(); + int size(); - ITreePath limit(int n); + ITreePath limit(int n); - ITreePath filter(PathFilterFunction filterFunction, int[] refPathIndices); + ITreePath filter(PathFilterFunction filterFunction, int[] refPathIndices); - ITreePath mapTree(PathMapFunction mapFunction); + ITreePath mapTree(PathMapFunction mapFunction); - List map(PathMapFunction mapFunction); + List map(PathMapFunction mapFunction); - List flatMap(PathFlatMapFunction flatMapFunction); + List flatMap(PathFlatMapFunction flatMapFunction); - List toList(); + List toList(); - ITreePath extendTo(Set requestIds, List edges); + ITreePath extendTo(Set requestIds, List edges); - ITreePath extendTo(RowEdge edge); + ITreePath extendTo(RowEdge edge); - ITreePath extendTo(Set requestIds, RowVertex vertex); + ITreePath extendTo(Set requestIds, RowVertex vertex); - ITreePath extendTo(RowVertex vertex); + ITreePath extendTo(RowVertex vertex); - List select(int... pathIndices); + List select(int... pathIndices); - ITreePath subPath(int... pathIndices); + ITreePath subPath(int... pathIndices); - int getDepth(); + int getDepth(); - boolean walkTree(List pathNodes, WalkFunction walkFunction, int maxDepth, PathIdCounter pathId); + boolean walkTree( + List pathNodes, WalkFunction walkFunction, int maxDepth, PathIdCounter pathId); - boolean equalNode(ITreePath other); + boolean equalNode(ITreePath other); - ITreePath optimize(); + ITreePath optimize(); - /** - * Set request id to all the nodes in the tree. - */ - void setRequestIdForTree(Object requestId); + /** Set request id to all the nodes in the tree. */ + void setRequestIdForTree(Object requestId); - void setRequestId(Object requestId); + void setRequestId(Object requestId); - enum NodeType { - EMPTY_TREE, - VERTEX_TREE, - EDGE_TREE, - UNION_TREE - } + enum NodeType { + EMPTY_TREE, + VERTEX_TREE, + EDGE_TREE, + UNION_TREE + } - interface WalkFunction { + interface WalkFunction { - boolean onWalk(List paths); - } + boolean onWalk(List paths); + } - interface PathFilterFunction { + interface PathFilterFunction { - boolean accept(Path path); - } + boolean accept(Path path); + } - interface PathMapFunction { + interface PathMapFunction { - O map(Path path); - } + O map(Path path); + } - interface PathFlatMapFunction { + interface PathFlatMapFunction { - Collection flatMap(Path path); - } + Collection flatMap(Path path); + } - class PathIdCounter { - private long counter; + class PathIdCounter { + private long counter; - public PathIdCounter(long counter) { - this.counter = counter; - } + public PathIdCounter(long counter) { + this.counter = counter; + } - public PathIdCounter() { - this(0L); - } + public PathIdCounter() { + this(0L); + } - public long getAndInc() { - return counter++; - } + public long getAndInc() { + return counter++; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/ParameterizedTreePath.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/ParameterizedTreePath.java index fc28800c0..29c0015b7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/ParameterizedTreePath.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/ParameterizedTreePath.java @@ -19,13 +19,10 @@ package org.apache.geaflow.dsl.runtime.traversal.path; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.Serializer; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; import java.util.Collection; import java.util.List; import java.util.Set; + import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.RowEdge; @@ -33,209 +30,215 @@ import org.apache.geaflow.dsl.runtime.traversal.message.IMessage; import org.apache.geaflow.dsl.runtime.traversal.message.MessageType; -public class ParameterizedTreePath extends AbstractTreePath { - - private final ITreePath baseTreePath; - - private final Object requestId; - - private final Row parameter; - - public ParameterizedTreePath(ITreePath baseTreePath, Object requestId, Row parameter) { - this.baseTreePath = baseTreePath; - this.requestId = requestId; - this.parameter = parameter; - } - - @Override - public MessageType getType() { - return baseTreePath.getType(); - } - - @Override - public IMessage combine(IMessage other) { - throw new IllegalArgumentException("read only tree path, combine is not support"); - } - - @Override - public ITreePath getMessageByRequestId(Object requestId) { - return (ITreePath) baseTreePath.getMessageByRequestId(requestId); - } - - @Override - public RowVertex getVertex() { - return baseTreePath.getVertex(); - } - - @Override - public void setVertex(RowVertex vertex) { - throw new IllegalArgumentException("read only tree path, setVertex is not support"); - } - - @Override - public Object getVertexId() { - return baseTreePath.getVertexId(); - } - - @Override - public NodeType getNodeType() { - return baseTreePath.getNodeType(); - } - - @Override - public List getParents() { - return baseTreePath.getParents(); - } - - @Override - public void addParent(ITreePath parent) { - throw new IllegalArgumentException("read only tree path, addParent is not support"); - } - - @Override - public ITreePath merge(ITreePath other) { - throw new IllegalArgumentException("read only tree path, merge is not support"); - } - - @Override - public EdgeSet getEdgeSet() { - return baseTreePath.getEdgeSet(); - } - - @Override - public ITreePath copy() { - return new ParameterizedTreePath(baseTreePath.copy(), requestId, parameter); - } - - @Override - public ITreePath copy(List parents) { - return new ParameterizedTreePath(baseTreePath.copy(parents), requestId, parameter); - } - - @Override - public ITreePath getTreePath(Object requestId) { - return new ParameterizedTreePath(baseTreePath.getTreePath(requestId), requestId, parameter); - } - - @Override - public Set getRequestIds() { - return baseTreePath.getRequestIds(); - } - - @Override - public void addRequestIds(Collection requestIds) { - throw new IllegalArgumentException("read only tree path, addRequestIds is not support"); - } - - @Override - public boolean isEmpty() { - return baseTreePath.isEmpty(); - } - - @Override - public int size() { - return baseTreePath.size(); - } - - @Override - public ITreePath limit(int n) { - return new ParameterizedTreePath(baseTreePath.limit(n), requestId, parameter); - } - - @Override - public ITreePath filter(PathFilterFunction filterFunction, int[] refPathIndices) { - return new ParameterizedTreePath(baseTreePath.filter(filterFunction, refPathIndices), requestId, parameter); - } - - @Override - public ITreePath mapTree(PathMapFunction mapFunction) { - return new ParameterizedTreePath(baseTreePath.mapTree(mapFunction), requestId, parameter); - } - - @Override - public List toList() { - return baseTreePath.toList(); - } - - @Override - public ITreePath extendTo(Set requestIds, List edges) { - throw new IllegalArgumentException("read only tree path, extendTo is not support"); - } - - @Override - public ITreePath extendTo(Set requestIds, RowVertex vertex) { - throw new IllegalArgumentException("read only tree path, extendTo is not support"); - } - - @Override - public List select(int... pathIndices) { - return baseTreePath.select(pathIndices); - } - - @Override - public ITreePath subPath(int... labelIndices) { - return new ParameterizedTreePath(baseTreePath.subPath(labelIndices), requestId, parameter); - } - - @Override - public int getDepth() { - return baseTreePath.getDepth(); - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; - @Override - public boolean walkTree(List pathNodes, WalkFunction walkFunction, int maxDepth, PathIdCounter pathId) { - return baseTreePath.walkTree(pathNodes, walkFunction, maxDepth, pathId); - } +public class ParameterizedTreePath extends AbstractTreePath { - @Override - public boolean equalNode(ITreePath other) { - return baseTreePath.equalNode(other); - } + private final ITreePath baseTreePath; + + private final Object requestId; + + private final Row parameter; + + public ParameterizedTreePath(ITreePath baseTreePath, Object requestId, Row parameter) { + this.baseTreePath = baseTreePath; + this.requestId = requestId; + this.parameter = parameter; + } + + @Override + public MessageType getType() { + return baseTreePath.getType(); + } + + @Override + public IMessage combine(IMessage other) { + throw new IllegalArgumentException("read only tree path, combine is not support"); + } + + @Override + public ITreePath getMessageByRequestId(Object requestId) { + return (ITreePath) baseTreePath.getMessageByRequestId(requestId); + } + + @Override + public RowVertex getVertex() { + return baseTreePath.getVertex(); + } + + @Override + public void setVertex(RowVertex vertex) { + throw new IllegalArgumentException("read only tree path, setVertex is not support"); + } + + @Override + public Object getVertexId() { + return baseTreePath.getVertexId(); + } + + @Override + public NodeType getNodeType() { + return baseTreePath.getNodeType(); + } + + @Override + public List getParents() { + return baseTreePath.getParents(); + } + + @Override + public void addParent(ITreePath parent) { + throw new IllegalArgumentException("read only tree path, addParent is not support"); + } + + @Override + public ITreePath merge(ITreePath other) { + throw new IllegalArgumentException("read only tree path, merge is not support"); + } + + @Override + public EdgeSet getEdgeSet() { + return baseTreePath.getEdgeSet(); + } + + @Override + public ITreePath copy() { + return new ParameterizedTreePath(baseTreePath.copy(), requestId, parameter); + } + + @Override + public ITreePath copy(List parents) { + return new ParameterizedTreePath(baseTreePath.copy(parents), requestId, parameter); + } + + @Override + public ITreePath getTreePath(Object requestId) { + return new ParameterizedTreePath(baseTreePath.getTreePath(requestId), requestId, parameter); + } + + @Override + public Set getRequestIds() { + return baseTreePath.getRequestIds(); + } + + @Override + public void addRequestIds(Collection requestIds) { + throw new IllegalArgumentException("read only tree path, addRequestIds is not support"); + } + + @Override + public boolean isEmpty() { + return baseTreePath.isEmpty(); + } + + @Override + public int size() { + return baseTreePath.size(); + } + + @Override + public ITreePath limit(int n) { + return new ParameterizedTreePath(baseTreePath.limit(n), requestId, parameter); + } + + @Override + public ITreePath filter(PathFilterFunction filterFunction, int[] refPathIndices) { + return new ParameterizedTreePath( + baseTreePath.filter(filterFunction, refPathIndices), requestId, parameter); + } + + @Override + public ITreePath mapTree(PathMapFunction mapFunction) { + return new ParameterizedTreePath(baseTreePath.mapTree(mapFunction), requestId, parameter); + } + + @Override + public List toList() { + return baseTreePath.toList(); + } + + @Override + public ITreePath extendTo(Set requestIds, List edges) { + throw new IllegalArgumentException("read only tree path, extendTo is not support"); + } + + @Override + public ITreePath extendTo(Set requestIds, RowVertex vertex) { + throw new IllegalArgumentException("read only tree path, extendTo is not support"); + } + + @Override + public List select(int... pathIndices) { + return baseTreePath.select(pathIndices); + } + + @Override + public ITreePath subPath(int... labelIndices) { + return new ParameterizedTreePath(baseTreePath.subPath(labelIndices), requestId, parameter); + } + + @Override + public int getDepth() { + return baseTreePath.getDepth(); + } + + @Override + public boolean walkTree( + List pathNodes, WalkFunction walkFunction, int maxDepth, PathIdCounter pathId) { + return baseTreePath.walkTree(pathNodes, walkFunction, maxDepth, pathId); + } + + @Override + public boolean equalNode(ITreePath other) { + return baseTreePath.equalNode(other); + } + + @Override + public ITreePath optimize() { + throw new IllegalArgumentException("read only tree path, optimize is not support"); + } + + @Override + public void setRequestIdForTree(Object requestId) { + throw new IllegalArgumentException("read only tree path, setRequestIdForTree is not support"); + } + + @Override + public void setRequestId(Object requestId) { + throw new IllegalArgumentException("read only tree path, setRequestId is not support"); + } + + public Object getRequestId() { + return requestId; + } + + public Row getParameter() { + return parameter; + } - @Override - public ITreePath optimize() { - throw new IllegalArgumentException("read only tree path, optimize is not support"); + public static class ParameterizedTreePathSerializer extends Serializer { + + @Override + public void write(Kryo kryo, Output output, ParameterizedTreePath object) { + kryo.writeClassAndObject(output, object.baseTreePath); + kryo.writeClassAndObject(output, object.getRequestId()); + kryo.writeClassAndObject(output, object.getParameter()); } @Override - public void setRequestIdForTree(Object requestId) { - throw new IllegalArgumentException("read only tree path, setRequestIdForTree is not support"); + public ParameterizedTreePath read(Kryo kryo, Input input, Class type) { + ITreePath baseTreePath = (ITreePath) kryo.readClassAndObject(input); + Object requestId = kryo.readClassAndObject(input); + Row parameter = (Row) kryo.readClassAndObject(input); + return new ParameterizedTreePath(baseTreePath, requestId, parameter); } @Override - public void setRequestId(Object requestId) { - throw new IllegalArgumentException("read only tree path, setRequestId is not support"); + public ParameterizedTreePath copy(Kryo kryo, ParameterizedTreePath original) { + return (ParameterizedTreePath) original.copy(); } - - public Object getRequestId() { - return requestId; - } - - public Row getParameter() { - return parameter; - } - - public static class ParameterizedTreePathSerializer extends Serializer { - - @Override - public void write(Kryo kryo, Output output, ParameterizedTreePath object) { - kryo.writeClassAndObject(output, object.baseTreePath); - kryo.writeClassAndObject(output, object.getRequestId()); - kryo.writeClassAndObject(output, object.getParameter()); - } - - @Override - public ParameterizedTreePath read(Kryo kryo, Input input, Class type) { - ITreePath baseTreePath = (ITreePath) kryo.readClassAndObject(input); - Object requestId = kryo.readClassAndObject(input); - Row parameter = (Row) kryo.readClassAndObject(input); - return new ParameterizedTreePath(baseTreePath, requestId, parameter); - } - - @Override - public ParameterizedTreePath copy(Kryo kryo, ParameterizedTreePath original) { - return (ParameterizedTreePath) original.copy(); - } - } - + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/SourceEdgeTreePath.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/SourceEdgeTreePath.java index 1be889743..211c78dc0 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/SourceEdgeTreePath.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/SourceEdgeTreePath.java @@ -19,139 +19,138 @@ package org.apache.geaflow.dsl.runtime.traversal.path; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.Serializer; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; -import com.google.common.collect.Sets; import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.Set; + import org.apache.geaflow.common.utils.ArrayUtil; import org.apache.geaflow.dsl.common.data.RowVertex; -public class SourceEdgeTreePath extends AbstractSingleTreePath { - - /** - * Edge sets with same (srcId, targetId) pair. - */ - private final EdgeSet edges; - - private SourceEdgeTreePath(Set requestIds, EdgeSet edges) { - this.edges = edges; - this.requestIds = requestIds; - } - - public static SourceEdgeTreePath of(Set requestIds, EdgeSet edges) { - return new SourceEdgeTreePath(requestIds, edges); - } - - public static SourceEdgeTreePath of(Object requestId, EdgeSet edges) { - if (requestId == null) { - return of(null, edges); - } - return of(Sets.newHashSet(requestId), edges); - } - - @Override - public RowVertex getVertex() { - throw new IllegalArgumentException("Illegal call"); - } - - @Override - public void setVertex(RowVertex vertex) { - throw new IllegalArgumentException("Illegal call"); - } - - @Override - public Object getVertexId() { - return edges.getTargetId(); - } - - @Override - public NodeType getNodeType() { - return NodeType.EDGE_TREE; - } - - @Override - public List getParents() { - return new ArrayList<>(); - } - - @Override - public EdgeSet getEdgeSet() { - return edges; - } - - @Override - public ITreePath copy() { - return new SourceEdgeTreePath(ArrayUtil.copySet(requestIds), edges.copy()); - } - - @Override - public ITreePath copy(List parents) { - assert parents.isEmpty(); - return copy(); - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import com.google.common.collect.Sets; - @Override - public ITreePath getTreePath(Object requestId) { - if (requestIds != null && !requestIds.contains(requestId)) { - return null; - } - return of(requestId, edges.copy()); - } +public class SourceEdgeTreePath extends AbstractSingleTreePath { - @Override - public int size() { - return edges.size(); - } + /** Edge sets with same (srcId, targetId) pair. */ + private final EdgeSet edges; + + private SourceEdgeTreePath(Set requestIds, EdgeSet edges) { + this.edges = edges; + this.requestIds = requestIds; + } + + public static SourceEdgeTreePath of(Set requestIds, EdgeSet edges) { + return new SourceEdgeTreePath(requestIds, edges); + } + + public static SourceEdgeTreePath of(Object requestId, EdgeSet edges) { + if (requestId == null) { + return of(null, edges); + } + return of(Sets.newHashSet(requestId), edges); + } + + @Override + public RowVertex getVertex() { + throw new IllegalArgumentException("Illegal call"); + } + + @Override + public void setVertex(RowVertex vertex) { + throw new IllegalArgumentException("Illegal call"); + } + + @Override + public Object getVertexId() { + return edges.getTargetId(); + } + + @Override + public NodeType getNodeType() { + return NodeType.EDGE_TREE; + } + + @Override + public List getParents() { + return new ArrayList<>(); + } + + @Override + public EdgeSet getEdgeSet() { + return edges; + } + + @Override + public ITreePath copy() { + return new SourceEdgeTreePath(ArrayUtil.copySet(requestIds), edges.copy()); + } + + @Override + public ITreePath copy(List parents) { + assert parents.isEmpty(); + return copy(); + } + + @Override + public ITreePath getTreePath(Object requestId) { + if (requestIds != null && !requestIds.contains(requestId)) { + return null; + } + return of(requestId, edges.copy()); + } + + @Override + public int size() { + return edges.size(); + } + + @Override + public boolean equalNode(ITreePath other) { + if (other.getNodeType() == NodeType.EDGE_TREE) { + return Objects.equals(edges, other.getEdgeSet()); + } + return false; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof SourceEdgeTreePath)) { + return false; + } + SourceEdgeTreePath that = (SourceEdgeTreePath) o; + return Objects.equals(edges, that.edges) && Objects.equals(requestIds, that.requestIds); + } + + @Override + public int hashCode() { + return Objects.hash(edges, requestIds); + } + + public static class SourceEdgeTreePathSerializer extends Serializer { @Override - public boolean equalNode(ITreePath other) { - if (other.getNodeType() == NodeType.EDGE_TREE) { - return Objects.equals(edges, other.getEdgeSet()); - } - return false; + public void write(Kryo kryo, Output output, SourceEdgeTreePath object) { + kryo.writeClassAndObject(output, object.getRequestIds()); + kryo.writeClassAndObject(output, object.getEdgeSet()); } @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof SourceEdgeTreePath)) { - return false; - } - SourceEdgeTreePath that = (SourceEdgeTreePath) o; - return Objects.equals(edges, that.edges) && Objects.equals(requestIds, that.requestIds); + public SourceEdgeTreePath read(Kryo kryo, Input input, Class type) { + Set requestIds = (Set) kryo.readClassAndObject(input); + EdgeSet edges = (EdgeSet) kryo.readClassAndObject(input); + return SourceEdgeTreePath.of(requestIds, edges); } @Override - public int hashCode() { - return Objects.hash(edges, requestIds); + public SourceEdgeTreePath copy(Kryo kryo, SourceEdgeTreePath original) { + return (SourceEdgeTreePath) original.copy(); } - - public static class SourceEdgeTreePathSerializer extends Serializer { - - @Override - public void write(Kryo kryo, Output output, SourceEdgeTreePath object) { - kryo.writeClassAndObject(output, object.getRequestIds()); - kryo.writeClassAndObject(output, object.getEdgeSet()); - } - - @Override - public SourceEdgeTreePath read(Kryo kryo, Input input, Class type) { - Set requestIds = (Set) kryo.readClassAndObject(input); - EdgeSet edges = (EdgeSet) kryo.readClassAndObject(input); - return SourceEdgeTreePath.of(requestIds, edges); - } - - @Override - public SourceEdgeTreePath copy(Kryo kryo, SourceEdgeTreePath original) { - return (SourceEdgeTreePath) original.copy(); - } - } - + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/SourceVertexTreePath.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/SourceVertexTreePath.java index 5d466d6a7..9ee0b89bd 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/SourceVertexTreePath.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/SourceVertexTreePath.java @@ -19,147 +19,145 @@ package org.apache.geaflow.dsl.runtime.traversal.path; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.Serializer; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; -import com.google.common.collect.Sets; import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Set; + import org.apache.geaflow.common.utils.ArrayUtil; import org.apache.geaflow.dsl.common.data.RowVertex; -public class SourceVertexTreePath extends AbstractSingleTreePath { - - /** - * Vertex. - */ - private RowVertex vertex; - - private SourceVertexTreePath(Set requestIds, RowVertex vertex) { - this.vertex = vertex; - this.requestIds = requestIds; - } - - public static SourceVertexTreePath of(Set requestIds, RowVertex vertex) { - return new SourceVertexTreePath(requestIds, vertex); - } - - public static SourceVertexTreePath of(Object requestId, RowVertex vertex) { - return new SourceVertexTreePath(Sets.newHashSet(requestId), vertex); - } - - @Override - public RowVertex getVertex() { - return vertex; - } - - @Override - public void setVertex(RowVertex vertex) { - this.vertex = Objects.requireNonNull(vertex, "vertex is null"); - } - - @Override - public Object getVertexId() { - return vertex.getId(); - } - - @Override - public NodeType getNodeType() { - return NodeType.VERTEX_TREE; - } - - @Override - public List getParents() { - return Collections.emptyList(); - } - - @Override - public void addParent(ITreePath parent) { - throw new IllegalArgumentException("Illegal call"); - } - - @Override - public EdgeSet getEdgeSet() { - throw new IllegalArgumentException("Illegal call"); - } - - @Override - public ITreePath copy() { - return new SourceVertexTreePath(ArrayUtil.copySet(requestIds), vertex); - } - - @Override - public ITreePath copy(List parents) { - assert parents.isEmpty(); - return copy(); - } - - @Override - public ITreePath getTreePath(Object requestId) { - if (requestIds != null && !requestIds.contains(requestId)) { - return null; - } - return of(requestId, vertex); - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import com.google.common.collect.Sets; - @Override - public boolean isEmpty() { - return false; - } +public class SourceVertexTreePath extends AbstractSingleTreePath { - @Override - public int size() { - return 1; - } + /** Vertex. */ + private RowVertex vertex; + + private SourceVertexTreePath(Set requestIds, RowVertex vertex) { + this.vertex = vertex; + this.requestIds = requestIds; + } + + public static SourceVertexTreePath of(Set requestIds, RowVertex vertex) { + return new SourceVertexTreePath(requestIds, vertex); + } + + public static SourceVertexTreePath of(Object requestId, RowVertex vertex) { + return new SourceVertexTreePath(Sets.newHashSet(requestId), vertex); + } + + @Override + public RowVertex getVertex() { + return vertex; + } + + @Override + public void setVertex(RowVertex vertex) { + this.vertex = Objects.requireNonNull(vertex, "vertex is null"); + } + + @Override + public Object getVertexId() { + return vertex.getId(); + } + + @Override + public NodeType getNodeType() { + return NodeType.VERTEX_TREE; + } + + @Override + public List getParents() { + return Collections.emptyList(); + } + + @Override + public void addParent(ITreePath parent) { + throw new IllegalArgumentException("Illegal call"); + } + + @Override + public EdgeSet getEdgeSet() { + throw new IllegalArgumentException("Illegal call"); + } + + @Override + public ITreePath copy() { + return new SourceVertexTreePath(ArrayUtil.copySet(requestIds), vertex); + } + + @Override + public ITreePath copy(List parents) { + assert parents.isEmpty(); + return copy(); + } + + @Override + public ITreePath getTreePath(Object requestId) { + if (requestIds != null && !requestIds.contains(requestId)) { + return null; + } + return of(requestId, vertex); + } + + @Override + public boolean isEmpty() { + return false; + } + + @Override + public int size() { + return 1; + } + + @Override + public boolean equalNode(ITreePath other) { + if (other.getNodeType() == NodeType.VERTEX_TREE) { + return Objects.equals(vertex, other.getVertex()); + } + return false; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof SourceVertexTreePath)) { + return false; + } + SourceVertexTreePath that = (SourceVertexTreePath) o; + return Objects.equals(vertex, that.vertex) && Objects.equals(requestIds, that.requestIds); + } + + @Override + public int hashCode() { + return Objects.hash(vertex, requestIds); + } + + public static class SourceVertexTreePathSerializer extends Serializer { @Override - public boolean equalNode(ITreePath other) { - if (other.getNodeType() == NodeType.VERTEX_TREE) { - return Objects.equals(vertex, other.getVertex()); - } - return false; + public void write(Kryo kryo, Output output, SourceVertexTreePath object) { + kryo.writeClassAndObject(output, object.getRequestIds()); + kryo.writeClassAndObject(output, object.getVertex()); } @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof SourceVertexTreePath)) { - return false; - } - SourceVertexTreePath that = (SourceVertexTreePath) o; - return Objects.equals(vertex, that.vertex) - && Objects.equals(requestIds, that.requestIds); + public SourceVertexTreePath read(Kryo kryo, Input input, Class type) { + Set requestIds = (Set) kryo.readClassAndObject(input); + RowVertex vertex = (RowVertex) kryo.readClassAndObject(input); + return SourceVertexTreePath.of(requestIds, vertex); } @Override - public int hashCode() { - return Objects.hash(vertex, requestIds); + public SourceVertexTreePath copy(Kryo kryo, SourceVertexTreePath original) { + return (SourceVertexTreePath) original.copy(); } - - public static class SourceVertexTreePathSerializer extends Serializer { - - @Override - public void write(Kryo kryo, Output output, SourceVertexTreePath object) { - kryo.writeClassAndObject(output, object.getRequestIds()); - kryo.writeClassAndObject(output, object.getVertex()); - } - - @Override - public SourceVertexTreePath read(Kryo kryo, Input input, Class type) { - Set requestIds = (Set) kryo.readClassAndObject(input); - RowVertex vertex = (RowVertex) kryo.readClassAndObject(input); - return SourceVertexTreePath.of(requestIds, vertex); - } - - @Override - public SourceVertexTreePath copy(Kryo kryo, SourceVertexTreePath original) { - return (SourceVertexTreePath) original.copy(); - } - } - + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/TreePaths.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/TreePaths.java index 3a54f1e78..5c21c35d0 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/TreePaths.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/TreePaths.java @@ -19,74 +19,76 @@ package org.apache.geaflow.dsl.runtime.traversal.path; -import com.google.common.collect.Lists; import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.RowEdge; import org.apache.geaflow.dsl.common.data.RowVertex; +import com.google.common.collect.Lists; + public class TreePaths { - public static ITreePath createTreePath(Iterable paths) { - return createTreePath(Lists.newArrayList(paths)); - } + public static ITreePath createTreePath(Iterable paths) { + return createTreePath(Lists.newArrayList(paths)); + } - public static ITreePath singletonPath(Path path) { - return createTreePath(Collections.singletonList(path)); - } + public static ITreePath singletonPath(Path path) { + return createTreePath(Collections.singletonList(path)); + } - public static ITreePath createTreePath(List paths) { - return createTreePath(paths, true); - } + public static ITreePath createTreePath(List paths) { + return createTreePath(paths, true); + } - public static ITreePath createTreePath(List paths, boolean optimize) { - if (paths == null) { - throw new NullPointerException("paths is null"); - } - if (paths.isEmpty()) { - return EmptyTreePath.of(); - } - ITreePath treePath = null; - for (Path path : paths) { - ITreePath currentTree = createTreePath(path); - if (treePath == null) { - treePath = currentTree; - } else { - treePath = treePath.merge(currentTree); - } - } - if (optimize) { - return treePath.optimize(); - } - return treePath; + public static ITreePath createTreePath(List paths, boolean optimize) { + if (paths == null) { + throw new NullPointerException("paths is null"); + } + if (paths.isEmpty()) { + return EmptyTreePath.of(); + } + ITreePath treePath = null; + for (Path path : paths) { + ITreePath currentTree = createTreePath(path); + if (treePath == null) { + treePath = currentTree; + } else { + treePath = treePath.merge(currentTree); + } + } + if (optimize) { + return treePath.optimize(); } + return treePath; + } - @SuppressWarnings("unchecked") - private static ITreePath createTreePath(Path path) { - ITreePath lastTree = EmptyTreePath.of(); - for (int i = 0; i < path.size(); i++) { - Row node = path.getField(i, null); - lastTree = createTreePath(lastTree, node); - } - return lastTree; + @SuppressWarnings("unchecked") + private static ITreePath createTreePath(Path path) { + ITreePath lastTree = EmptyTreePath.of(); + for (int i = 0; i < path.size(); i++) { + Row node = path.getField(i, null); + lastTree = createTreePath(lastTree, node); } + return lastTree; + } - private static ITreePath createTreePath(ITreePath lastTree, Row node) { - ITreePath treePath; - if (node instanceof RowVertex) { - RowVertex vertex = (RowVertex) node; - treePath = lastTree.extendTo(null, vertex); - } else if (node instanceof RowEdge) { - RowEdge edge = (RowEdge) node; - treePath = lastTree.extendTo(null, Lists.newArrayList(edge)); - } else if (node == null) { - treePath = lastTree.extendTo((RowVertex) null); - } else { - throw new GeaflowRuntimeException("TreePath cannot be extended to node: " + node); - } - return treePath; + private static ITreePath createTreePath(ITreePath lastTree, Row node) { + ITreePath treePath; + if (node instanceof RowVertex) { + RowVertex vertex = (RowVertex) node; + treePath = lastTree.extendTo(null, vertex); + } else if (node instanceof RowEdge) { + RowEdge edge = (RowEdge) node; + treePath = lastTree.extendTo(null, Lists.newArrayList(edge)); + } else if (node == null) { + treePath = lastTree.extendTo((RowVertex) null); + } else { + throw new GeaflowRuntimeException("TreePath cannot be extended to node: " + node); } + return treePath; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/UnionTreePath.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/UnionTreePath.java index 0f59e4114..7a425ec1b 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/UnionTreePath.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/UnionTreePath.java @@ -19,11 +19,6 @@ package org.apache.geaflow.dsl.runtime.traversal.path; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.Serializer; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; @@ -31,324 +26,337 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; + import org.apache.geaflow.dsl.common.data.Path; import org.apache.geaflow.dsl.common.data.RowEdge; import org.apache.geaflow.dsl.common.data.RowVertex; -public class UnionTreePath extends AbstractTreePath { - - private final List nodes; - - private UnionTreePath(List nodes) { - this.nodes = Objects.requireNonNull(nodes); - } - - public static ITreePath create(List nodes) { - List notEmptyTreePath = - Objects.requireNonNull(nodes).stream().filter(n -> n.getNodeType() != NodeType.EMPTY_TREE).collect(Collectors.toList()); - if (notEmptyTreePath.isEmpty()) { - return EmptyTreePath.of(); - } else if (notEmptyTreePath.size() == 1) { - return notEmptyTreePath.get(0); - } else { - return new UnionTreePath(notEmptyTreePath); - } - } - - private UnionTreePath() { - this(new ArrayList<>()); - } - - @Override - public RowVertex getVertex() { - throw new IllegalArgumentException("Illegal call"); - } - - @Override - public void setVertex(RowVertex vertex) { - for (ITreePath node : nodes) { - node.setVertex(vertex); - } - } - - @Override - public Object getVertexId() { - throw new IllegalArgumentException("Illegal call"); - } - - - @Override - public NodeType getNodeType() { - return NodeType.UNION_TREE; - } - - @Override - public List getParents() { - throw new IllegalArgumentException("Illegal call"); - } - - @Override - public int getDepth() { - if (nodes.isEmpty()) { - return 0; - } - return nodes.get(0).getDepth(); - } - - @Override - public ITreePath merge(ITreePath other) { - if (other instanceof UnionTreePath) { - UnionTreePath unionTreePath = (UnionTreePath) other; - for (ITreePath otherNode : unionTreePath.nodes) { - addNode(otherNode, false); - } - } else { - addNode(other, false); - } - return this; - } - - private void addNode(ITreePath node, boolean mergeEdgeSet) { - ITreePath existNode = null; - for (ITreePath thisNode : nodes) { - if (thisNode.equalNode(node) && thisNode.getDepth() == node.getDepth()) { - existNode = thisNode; - break; - } - } - if (existNode != null) { - ITreePath merged = existNode.merge(node); - assert merged == existNode; - } else { - boolean hasMerged = false; - if (mergeEdgeSet) { - // merge edge with same source and target id. - if (node.getNodeType() == NodeType.EDGE_TREE) { - for (ITreePath thisNode : nodes) { - if (thisNode.getNodeType() == NodeType.EDGE_TREE - && thisNode.getEdgeSet().like(node.getEdgeSet()) - && Objects.equals(thisNode.getParents(), node.getParents())) { - thisNode.getEdgeSet().addEdges(node.getEdgeSet()); - hasMerged = true; - break; - } - } - } - } - if (!hasMerged) { - nodes.add(node); - } - } - } - - @Override - public ITreePath optimize() { - if (nodes.isEmpty()) { - return this; - } - UnionTreePath unionTreePath = new UnionTreePath(); - - for (int i = 0; i < nodes.size(); i++) { - ITreePath node = nodes.get(i); - if (i == 0) { - node = node.copy(); - } - unionTreePath.addNode(node, true); - } - if (unionTreePath.nodes.size() == 1) { - return unionTreePath.nodes.get(0); - } else if (unionTreePath.nodes.size() == 0) { - return EmptyTreePath.of(); - } - return unionTreePath; - } - - @Override - public void setRequestId(Object requestId) { - for (ITreePath node : nodes) { - node.setRequestId(requestId); - } - } - - @Override - public void setRequestIdForTree(Object requestId) { - for (ITreePath node : nodes) { - node.setRequestIdForTree(requestId); - } - } - - @Override - public EdgeSet getEdgeSet() { - throw new IllegalArgumentException("Illegal call"); - } - - @Override - public ITreePath copy() { - List copyNodes = new ArrayList<>(nodes.size()); - for (ITreePath node : nodes) { - copyNodes.add(node.copy()); - } - return new UnionTreePath(copyNodes); - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import com.google.common.collect.Lists; - @Override - public ITreePath copy(List parents) { - throw new IllegalArgumentException("Illegal call"); - } +public class UnionTreePath extends AbstractTreePath { - @Override - public ITreePath getTreePath(Object sessionId) { - List sessionNodes = new ArrayList<>(nodes.size()); - for (ITreePath node : nodes) { - ITreePath nodeOnSession = node.getTreePath(sessionId); - if (nodeOnSession != null) { - sessionNodes.add(nodeOnSession.getTreePath(sessionId)); + private final List nodes; + + private UnionTreePath(List nodes) { + this.nodes = Objects.requireNonNull(nodes); + } + + public static ITreePath create(List nodes) { + List notEmptyTreePath = + Objects.requireNonNull(nodes).stream() + .filter(n -> n.getNodeType() != NodeType.EMPTY_TREE) + .collect(Collectors.toList()); + if (notEmptyTreePath.isEmpty()) { + return EmptyTreePath.of(); + } else if (notEmptyTreePath.size() == 1) { + return notEmptyTreePath.get(0); + } else { + return new UnionTreePath(notEmptyTreePath); + } + } + + private UnionTreePath() { + this(new ArrayList<>()); + } + + @Override + public RowVertex getVertex() { + throw new IllegalArgumentException("Illegal call"); + } + + @Override + public void setVertex(RowVertex vertex) { + for (ITreePath node : nodes) { + node.setVertex(vertex); + } + } + + @Override + public Object getVertexId() { + throw new IllegalArgumentException("Illegal call"); + } + + @Override + public NodeType getNodeType() { + return NodeType.UNION_TREE; + } + + @Override + public List getParents() { + throw new IllegalArgumentException("Illegal call"); + } + + @Override + public int getDepth() { + if (nodes.isEmpty()) { + return 0; + } + return nodes.get(0).getDepth(); + } + + @Override + public ITreePath merge(ITreePath other) { + if (other instanceof UnionTreePath) { + UnionTreePath unionTreePath = (UnionTreePath) other; + for (ITreePath otherNode : unionTreePath.nodes) { + addNode(otherNode, false); + } + } else { + addNode(other, false); + } + return this; + } + + private void addNode(ITreePath node, boolean mergeEdgeSet) { + ITreePath existNode = null; + for (ITreePath thisNode : nodes) { + if (thisNode.equalNode(node) && thisNode.getDepth() == node.getDepth()) { + existNode = thisNode; + break; + } + } + if (existNode != null) { + ITreePath merged = existNode.merge(node); + assert merged == existNode; + } else { + boolean hasMerged = false; + if (mergeEdgeSet) { + // merge edge with same source and target id. + if (node.getNodeType() == NodeType.EDGE_TREE) { + for (ITreePath thisNode : nodes) { + if (thisNode.getNodeType() == NodeType.EDGE_TREE + && thisNode.getEdgeSet().like(node.getEdgeSet()) + && Objects.equals(thisNode.getParents(), node.getParents())) { + thisNode.getEdgeSet().addEdges(node.getEdgeSet()); + hasMerged = true; + break; } + } } - return UnionTreePath.create(sessionNodes); - } - - @Override - public Set getRequestIds() { - throw new IllegalArgumentException("Illegal call"); - } - - @Override - public void addRequestIds(Collection requestIds) { - for (ITreePath node : nodes) { - node.addRequestIds(requestIds); + } + if (!hasMerged) { + nodes.add(node); + } + } + } + + @Override + public ITreePath optimize() { + if (nodes.isEmpty()) { + return this; + } + UnionTreePath unionTreePath = new UnionTreePath(); + + for (int i = 0; i < nodes.size(); i++) { + ITreePath node = nodes.get(i); + if (i == 0) { + node = node.copy(); + } + unionTreePath.addNode(node, true); + } + if (unionTreePath.nodes.size() == 1) { + return unionTreePath.nodes.get(0); + } else if (unionTreePath.nodes.size() == 0) { + return EmptyTreePath.of(); + } + return unionTreePath; + } + + @Override + public void setRequestId(Object requestId) { + for (ITreePath node : nodes) { + node.setRequestId(requestId); + } + } + + @Override + public void setRequestIdForTree(Object requestId) { + for (ITreePath node : nodes) { + node.setRequestIdForTree(requestId); + } + } + + @Override + public EdgeSet getEdgeSet() { + throw new IllegalArgumentException("Illegal call"); + } + + @Override + public ITreePath copy() { + List copyNodes = new ArrayList<>(nodes.size()); + for (ITreePath node : nodes) { + copyNodes.add(node.copy()); + } + return new UnionTreePath(copyNodes); + } + + @Override + public ITreePath copy(List parents) { + throw new IllegalArgumentException("Illegal call"); + } + + @Override + public ITreePath getTreePath(Object sessionId) { + List sessionNodes = new ArrayList<>(nodes.size()); + for (ITreePath node : nodes) { + ITreePath nodeOnSession = node.getTreePath(sessionId); + if (nodeOnSession != null) { + sessionNodes.add(nodeOnSession.getTreePath(sessionId)); + } + } + return UnionTreePath.create(sessionNodes); + } + + @Override + public Set getRequestIds() { + throw new IllegalArgumentException("Illegal call"); + } + + @Override + public void addRequestIds(Collection requestIds) { + for (ITreePath node : nodes) { + node.addRequestIds(requestIds); + } + } + + @Override + public int size() { + int size = 0; + for (ITreePath node : nodes) { + size += node.size(); + } + return size; + } + + @Override + public ITreePath extendTo(Set requestIds, List edges) { + EdgeSet edgeSet = new DefaultEdgeSet(edges); + ITreePath newTreePath = EdgeTreePath.of(requestIds, edgeSet); + for (ITreePath parent : nodes) { + newTreePath.addParent(parent); + } + return newTreePath; + } + + @Override + public ITreePath extendTo(Set requestIds, RowVertex vertex) { + ITreePath newTreePath = VertexTreePath.of(requestIds, vertex); + for (ITreePath parent : nodes) { + newTreePath.addParent(parent); + } + return newTreePath; + } + + @Override + public List select(int... pathIndices) { + Set selectPaths = new HashSet<>(); + for (ITreePath node : nodes) { + selectPaths.addAll(node.select(pathIndices)); + } + return Lists.newArrayList(selectPaths); + } + + @Override + public boolean walkTree( + List pathNodes, WalkFunction walkFunction, int maxDepth, PathIdCounter pathId) { + for (ITreePath node : nodes) { + if (!node.walkTree(pathNodes, walkFunction, maxDepth, pathId)) { + return false; + } + } + return true; + } + + @Override + protected ITreePath filter( + PathFilterFunction filterFunction, + int[] refPathIndices, + int[] fieldMapping, + Path currentPath, + int maxDepth, + PathIdCounter pathId) { + List filterNodes = new ArrayList<>(); + for (ITreePath node : nodes) { + ITreePath filterNode = + ((AbstractTreePath) node) + .filter(filterFunction, refPathIndices, fieldMapping, currentPath, maxDepth, pathId); + if (filterNode != null) { + filterNodes.add(filterNode); + } + } + return UnionTreePath.create(filterNodes); + } + + @Override + public boolean equalNode(ITreePath other) { + if (other.getNodeType() == NodeType.UNION_TREE) { + UnionTreePath unionTreePath = (UnionTreePath) other; + if (nodes.size() != unionTreePath.nodes.size()) { + return false; + } + for (int i = 0; i < nodes.size(); i++) { + if (!nodes.get(i).equalNode(unionTreePath.nodes.get(i))) { + return false; } + } + return true; } + return false; + } - @Override - public int size() { - int size = 0; - for (ITreePath node : nodes) { - size += node.size(); - } - return size; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public ITreePath extendTo(Set requestIds, List edges) { - EdgeSet edgeSet = new DefaultEdgeSet(edges); - ITreePath newTreePath = EdgeTreePath.of(requestIds, edgeSet); - for (ITreePath parent : nodes) { - newTreePath.addParent(parent); - } - return newTreePath; + if (!(o instanceof UnionTreePath)) { + return false; } + UnionTreePath that = (UnionTreePath) o; + return Objects.equals(nodes, that.nodes); + } - @Override - public ITreePath extendTo(Set requestIds, RowVertex vertex) { - ITreePath newTreePath = VertexTreePath.of(requestIds, vertex); - for (ITreePath parent : nodes) { - newTreePath.addParent(parent); - } - return newTreePath; - } + @Override + public int hashCode() { + return Objects.hash(nodes); + } - @Override - public List select(int... pathIndices) { - Set selectPaths = new HashSet<>(); - for (ITreePath node : nodes) { - selectPaths.addAll(node.select(pathIndices)); - } - return Lists.newArrayList(selectPaths); - } + public List getNodes() { + return nodes; + } - @Override - public boolean walkTree(List pathNodes, WalkFunction walkFunction, int maxDepth, PathIdCounter pathId) { - for (ITreePath node : nodes) { - if (!node.walkTree(pathNodes, walkFunction, maxDepth, pathId)) { - return false; - } - } - return true; + public List expand() { + List paths = new ArrayList<>(); + for (ITreePath node : nodes) { + if (node instanceof UnionTreePath) { + paths.addAll(((UnionTreePath) node).expand()); + } else { + paths.add(node); + } } + return paths; + } - @Override - protected ITreePath filter(PathFilterFunction filterFunction, - int[] refPathIndices, int[] fieldMapping, - Path currentPath, int maxDepth, PathIdCounter pathId) { - List filterNodes = new ArrayList<>(); - for (ITreePath node : nodes) { - ITreePath filterNode = ((AbstractTreePath) node).filter(filterFunction, - refPathIndices, fieldMapping, currentPath, maxDepth, pathId); - if (filterNode != null) { - filterNodes.add(filterNode); - } - } - return UnionTreePath.create(filterNodes); - } + public static class UnionTreePathSerializer extends Serializer { @Override - public boolean equalNode(ITreePath other) { - if (other.getNodeType() == NodeType.UNION_TREE) { - UnionTreePath unionTreePath = (UnionTreePath) other; - if (nodes.size() != unionTreePath.nodes.size()) { - return false; - } - for (int i = 0; i < nodes.size(); i++) { - if (!nodes.get(i).equalNode(unionTreePath.nodes.get(i))) { - return false; - } - } - return true; - } - return false; + public void write(Kryo kryo, Output output, UnionTreePath object) { + kryo.writeClassAndObject(output, object.getNodes()); } @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof UnionTreePath)) { - return false; - } - UnionTreePath that = (UnionTreePath) o; - return Objects.equals(nodes, that.nodes); + public UnionTreePath read(Kryo kryo, Input input, Class type) { + List nodes = (List) kryo.readClassAndObject(input); + return new UnionTreePath(nodes); } @Override - public int hashCode() { - return Objects.hash(nodes); + public UnionTreePath copy(Kryo kryo, UnionTreePath original) { + return (UnionTreePath) original.copy(); } - - public List getNodes() { - return nodes; - } - - public List expand() { - List paths = new ArrayList<>(); - for (ITreePath node : nodes) { - if (node instanceof UnionTreePath) { - paths.addAll(((UnionTreePath) node).expand()); - } else { - paths.add(node); - } - } - return paths; - } - - public static class UnionTreePathSerializer extends Serializer { - - @Override - public void write(Kryo kryo, Output output, UnionTreePath object) { - kryo.writeClassAndObject(output, object.getNodes()); - } - - @Override - public UnionTreePath read(Kryo kryo, Input input, Class type) { - List nodes = (List) kryo.readClassAndObject(input); - return new UnionTreePath(nodes); - } - - @Override - public UnionTreePath copy(Kryo kryo, UnionTreePath original) { - return (UnionTreePath) original.copy(); - } - } - -} \ No newline at end of file + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/VertexTreePath.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/VertexTreePath.java index 22c37deb6..85cffa762 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/VertexTreePath.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/traversal/path/VertexTreePath.java @@ -19,170 +19,168 @@ package org.apache.geaflow.dsl.runtime.traversal.path; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.Serializer; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; -import com.google.common.collect.Sets; import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.Set; + import org.apache.geaflow.common.utils.ArrayUtil; import org.apache.geaflow.dsl.common.data.RowVertex; -public class VertexTreePath extends AbstractSingleTreePath { - - /** - * Parent node. - */ - private List parents = new ArrayList<>(); - - /** - * Vertex. - */ - private RowVertex vertex; - - private VertexTreePath(Object sessionId, RowVertex vertex) { - this(Sets.newHashSet(sessionId), vertex); - } - - private VertexTreePath(Set requestIds, RowVertex vertex) { - this.requestIds = requestIds; - this.vertex = vertex; - } - - public static VertexTreePath of(Object sessionId, RowVertex vertex) { - return new VertexTreePath(sessionId, vertex); - } - - public static VertexTreePath of(Set requestIds, RowVertex vertex) { - return new VertexTreePath(requestIds, vertex); - } - - public RowVertex getVertex() { - return vertex; - } - - @Override - public void setVertex(RowVertex vertex) { - this.vertex = Objects.requireNonNull(vertex, "vertex is null"); - } - - @Override - public Object getVertexId() { - return vertex != null ? vertex.getId() : null; - } - - @Override - public List getParents() { - return parents; - } - - @Override - public EdgeSet getEdgeSet() { - throw new IllegalArgumentException("Illegal call"); - } - - @Override - public VertexTreePath copy() { - VertexTreePath copyTree = new VertexTreePath(ArrayUtil.copySet(requestIds), vertex); - List parentsCopy = new ArrayList<>(); - for (ITreePath parent : parents) { - parentsCopy.add(parent.copy()); - } - copyTree.parents = parentsCopy; - return copyTree; - } - - @Override - public ITreePath copy(List parents) { - VertexTreePath copyTree = new VertexTreePath(ArrayUtil.copySet(requestIds), vertex); - copyTree.parents = parents; - return copyTree; - } - - @Override - public ITreePath getTreePath(Object requestId) { - if (requestIds != null && !requestIds.contains(requestId)) { - return null; - } - ITreePath treePathOnSession = of(requestId, vertex); - for (ITreePath parent : parents) { - ITreePath sessionParent = parent.getTreePath(requestId); - if (sessionParent != null) { - treePathOnSession.addParent(sessionParent); - } - } - return treePathOnSession; - } +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import com.google.common.collect.Sets; - @Override - public int size() { - if (parents.isEmpty()) { - return 1; - } - int size = 0; - for (ITreePath parent : parents) { - size += parent.size(); - } - return size; - } +public class VertexTreePath extends AbstractSingleTreePath { - @Override - public boolean equalNode(ITreePath other) { - if (other.getNodeType() == NodeType.VERTEX_TREE) { - return Objects.equals(vertex, other.getVertex()); - } - return false; - } + /** Parent node. */ + private List parents = new ArrayList<>(); + + /** Vertex. */ + private RowVertex vertex; + + private VertexTreePath(Object sessionId, RowVertex vertex) { + this(Sets.newHashSet(sessionId), vertex); + } + + private VertexTreePath(Set requestIds, RowVertex vertex) { + this.requestIds = requestIds; + this.vertex = vertex; + } + + public static VertexTreePath of(Object sessionId, RowVertex vertex) { + return new VertexTreePath(sessionId, vertex); + } + + public static VertexTreePath of(Set requestIds, RowVertex vertex) { + return new VertexTreePath(requestIds, vertex); + } + + public RowVertex getVertex() { + return vertex; + } + + @Override + public void setVertex(RowVertex vertex) { + this.vertex = Objects.requireNonNull(vertex, "vertex is null"); + } + + @Override + public Object getVertexId() { + return vertex != null ? vertex.getId() : null; + } + + @Override + public List getParents() { + return parents; + } + + @Override + public EdgeSet getEdgeSet() { + throw new IllegalArgumentException("Illegal call"); + } + + @Override + public VertexTreePath copy() { + VertexTreePath copyTree = new VertexTreePath(ArrayUtil.copySet(requestIds), vertex); + List parentsCopy = new ArrayList<>(); + for (ITreePath parent : parents) { + parentsCopy.add(parent.copy()); + } + copyTree.parents = parentsCopy; + return copyTree; + } + + @Override + public ITreePath copy(List parents) { + VertexTreePath copyTree = new VertexTreePath(ArrayUtil.copySet(requestIds), vertex); + copyTree.parents = parents; + return copyTree; + } + + @Override + public ITreePath getTreePath(Object requestId) { + if (requestIds != null && !requestIds.contains(requestId)) { + return null; + } + ITreePath treePathOnSession = of(requestId, vertex); + for (ITreePath parent : parents) { + ITreePath sessionParent = parent.getTreePath(requestId); + if (sessionParent != null) { + treePathOnSession.addParent(sessionParent); + } + } + return treePathOnSession; + } + + @Override + public int size() { + if (parents.isEmpty()) { + return 1; + } + int size = 0; + for (ITreePath parent : parents) { + size += parent.size(); + } + return size; + } + + @Override + public boolean equalNode(ITreePath other) { + if (other.getNodeType() == NodeType.VERTEX_TREE) { + return Objects.equals(vertex, other.getVertex()); + } + return false; + } + + @Override + public NodeType getNodeType() { + return NodeType.VERTEX_TREE; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof VertexTreePath)) { + return false; + } + VertexTreePath that = (VertexTreePath) o; + return Objects.equals(parents, that.parents) + && Objects.equals(vertex, that.vertex) + && Objects.equals(requestIds, that.requestIds); + } + + @Override + public int hashCode() { + return Objects.hash(parents, vertex, requestIds); + } + + public static class VertexTreePathSerializer extends Serializer { @Override - public NodeType getNodeType() { - return NodeType.VERTEX_TREE; + public void write(Kryo kryo, Output output, VertexTreePath object) { + kryo.writeClassAndObject(output, object.getRequestIds()); + kryo.writeClassAndObject(output, object.getParents()); + kryo.writeClassAndObject(output, object.getVertex()); } @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof VertexTreePath)) { - return false; - } - VertexTreePath that = (VertexTreePath) o; - return Objects.equals(parents, that.parents) && Objects.equals(vertex, that.vertex) - && Objects.equals(requestIds, that.requestIds); + public VertexTreePath read(Kryo kryo, Input input, Class type) { + Set requestIds = (Set) kryo.readClassAndObject(input); + List parents = (List) kryo.readClassAndObject(input); + RowVertex vertex = (RowVertex) kryo.readClassAndObject(input); + VertexTreePath vertexTreePath = VertexTreePath.of(requestIds, vertex); + vertexTreePath.parents.addAll(parents); + return vertexTreePath; } @Override - public int hashCode() { - return Objects.hash(parents, vertex, requestIds); + public VertexTreePath copy(Kryo kryo, VertexTreePath original) { + return original.copy(); } - - public static class VertexTreePathSerializer extends Serializer { - - @Override - public void write(Kryo kryo, Output output, VertexTreePath object) { - kryo.writeClassAndObject(output, object.getRequestIds()); - kryo.writeClassAndObject(output, object.getParents()); - kryo.writeClassAndObject(output, object.getVertex()); - } - - @Override - public VertexTreePath read(Kryo kryo, Input input, Class type) { - Set requestIds = (Set) kryo.readClassAndObject(input); - List parents = (List) kryo.readClassAndObject(input); - RowVertex vertex = (RowVertex) kryo.readClassAndObject(input); - VertexTreePath vertexTreePath = VertexTreePath.of(requestIds, vertex); - vertexTreePath.parents.addAll(parents); - return vertexTreePath; - } - - @Override - public VertexTreePath copy(Kryo kryo, VertexTreePath original) { - return original.copy(); - } - } - + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/AnalyticsResultFormatter.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/AnalyticsResultFormatter.java index 4c2cbbd27..5ecc7aa89 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/AnalyticsResultFormatter.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/AnalyticsResultFormatter.java @@ -19,10 +19,6 @@ package org.apache.geaflow.dsl.runtime.util; -import com.alibaba.fastjson.JSON; -import com.alibaba.fastjson.JSONArray; -import com.alibaba.fastjson.JSONObject; -import com.alibaba.fastjson.serializer.SerializerFeature; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; @@ -30,6 +26,7 @@ import java.util.Map; import java.util.TreeSet; import java.util.stream.Collectors; + import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.geaflow.cluster.response.ResponseResult; @@ -39,172 +36,192 @@ import org.apache.geaflow.dsl.common.data.impl.ObjectRow; import org.apache.geaflow.model.graph.IGraphElementWithLabelField; -public class AnalyticsResultFormatter { - - public static String formatResult(Object queryResult, RelDataType currentResultType) { - final JSONObject finalResult = new JSONObject(); - JSONArray jsonResult = new JSONArray(); - JSONObject viewResult = new JSONObject(); - List vertices = new ArrayList<>(); - List edges = new ArrayList<>(); - List> list = (List>) queryResult; - - for (List responseResults : list) { - for (ResponseResult responseResult : responseResults) { - for (Object o : responseResult.getResponse()) { - jsonResult.add(formatRow(o, currentResultType, vertices, edges)); - } - } - } - - List filteredVertices = - vertices.stream().collect(Collectors.collectingAndThen(Collectors.toCollection(() -> new TreeSet<>( - Comparator.comparing(ViewVertex::getId))), ArrayList::new)); - - viewResult.put("nodes", filteredVertices); - viewResult.put("edges", edges); - finalResult.put("viewResult", viewResult); - finalResult.put("jsonResult", jsonResult); - return JSON.toJSONString(finalResult, SerializerFeature.DisableCircularReferenceDetect); - } - - private static Object formatRow(Object o, RelDataType currentResultType, List vertices, List edges) { - if (o == null) { - return null; - } - if (o instanceof ObjectRow) { - JSONObject jsonObject = new JSONObject(); - ObjectRow objectRow = (ObjectRow) o; - Object[] fields = objectRow.getFields(); - for (int i = 0; i < fields.length; i++) { - RelDataTypeField relDataTypeField = currentResultType.getFieldList().get(i); - Object field = fields[i]; - Object formatResult; - if (field instanceof RowVertex) { - RowVertex vertex = (RowVertex) field; - ObjectRow vertexValue = (ObjectRow) vertex.getValue(); - Map properties = new HashMap<>(); - if (vertexValue != null) { - Object[] vertexProperties = vertexValue.getFields(); - int metaFieldCount = getMetaFieldCount(relDataTypeField.getType()); - List typeList = relDataTypeField.getType().getFieldList(); - // find the correspond key in properties - for (int j = 0; j < vertexProperties.length; j++) { - properties.put(typeList.get(j + metaFieldCount).getName(), vertexProperties[j]); - } - } - - formatResult = new ViewVertex(String.valueOf(vertex.getId()), getLabel(vertex), properties); - vertices.add((ViewVertex) formatResult); - } else if (field instanceof RowEdge) { - RowEdge edge = (RowEdge) field; - ObjectRow edgeValue = (ObjectRow) edge.getValue(); - Map properties = new HashMap<>(); - if (edgeValue != null) { - Object[] edgeProperties = edgeValue.getFields(); - int metaFieldCount = getMetaFieldCount(relDataTypeField.getType()); - List typeList = relDataTypeField.getType().getFieldList(); - for (int j = 0; j < edgeProperties.length; j++) { - properties.put(typeList.get(j + metaFieldCount).getName(), edgeProperties[j]); - } - } - formatResult = new ViewEdge(String.valueOf(edge.getSrcId()), String.valueOf(edge.getTargetId()), - getLabel(edge), properties, edge.getDirect().name()); - edges.add((ViewEdge) formatResult); - } else { - formatResult = field; - } - - jsonObject.put(relDataTypeField.getKey(), formatResult); - } +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONArray; +import com.alibaba.fastjson.JSONObject; +import com.alibaba.fastjson.serializer.SerializerFeature; - return jsonObject; +public class AnalyticsResultFormatter { - } else { - return o.toString(); + public static String formatResult(Object queryResult, RelDataType currentResultType) { + final JSONObject finalResult = new JSONObject(); + JSONArray jsonResult = new JSONArray(); + JSONObject viewResult = new JSONObject(); + List vertices = new ArrayList<>(); + List edges = new ArrayList<>(); + List> list = (List>) queryResult; + + for (List responseResults : list) { + for (ResponseResult responseResult : responseResults) { + for (Object o : responseResult.getResponse()) { + jsonResult.add(formatRow(o, currentResultType, vertices, edges)); } + } } - private static String getLabel(IGraphElementWithLabelField field) { - try { - return field.getLabel(); - } catch (Exception e) { - return null; - } + List filteredVertices = + vertices.stream() + .collect( + Collectors.collectingAndThen( + Collectors.toCollection( + () -> new TreeSet<>(Comparator.comparing(ViewVertex::getId))), + ArrayList::new)); + + viewResult.put("nodes", filteredVertices); + viewResult.put("edges", edges); + finalResult.put("viewResult", viewResult); + finalResult.put("jsonResult", jsonResult); + return JSON.toJSONString(finalResult, SerializerFeature.DisableCircularReferenceDetect); + } + + private static Object formatRow( + Object o, RelDataType currentResultType, List vertices, List edges) { + if (o == null) { + return null; } - - private static int getMetaFieldCount(RelDataType type) { - List fieldList = type.getFieldList(); - int count = 0; - for (RelDataTypeField relDataTypeField : fieldList) { - if (!(relDataTypeField.getType() instanceof MetaFieldType)) { - break; + if (o instanceof ObjectRow) { + JSONObject jsonObject = new JSONObject(); + ObjectRow objectRow = (ObjectRow) o; + Object[] fields = objectRow.getFields(); + for (int i = 0; i < fields.length; i++) { + RelDataTypeField relDataTypeField = currentResultType.getFieldList().get(i); + Object field = fields[i]; + Object formatResult; + if (field instanceof RowVertex) { + RowVertex vertex = (RowVertex) field; + ObjectRow vertexValue = (ObjectRow) vertex.getValue(); + Map properties = new HashMap<>(); + if (vertexValue != null) { + Object[] vertexProperties = vertexValue.getFields(); + int metaFieldCount = getMetaFieldCount(relDataTypeField.getType()); + List typeList = relDataTypeField.getType().getFieldList(); + // find the correspond key in properties + for (int j = 0; j < vertexProperties.length; j++) { + properties.put(typeList.get(j + metaFieldCount).getName(), vertexProperties[j]); + } + } + + formatResult = + new ViewVertex(String.valueOf(vertex.getId()), getLabel(vertex), properties); + vertices.add((ViewVertex) formatResult); + } else if (field instanceof RowEdge) { + RowEdge edge = (RowEdge) field; + ObjectRow edgeValue = (ObjectRow) edge.getValue(); + Map properties = new HashMap<>(); + if (edgeValue != null) { + Object[] edgeProperties = edgeValue.getFields(); + int metaFieldCount = getMetaFieldCount(relDataTypeField.getType()); + List typeList = relDataTypeField.getType().getFieldList(); + for (int j = 0; j < edgeProperties.length; j++) { + properties.put(typeList.get(j + metaFieldCount).getName(), edgeProperties[j]); } - count++; + } + formatResult = + new ViewEdge( + String.valueOf(edge.getSrcId()), + String.valueOf(edge.getTargetId()), + getLabel(edge), + properties, + edge.getDirect().name()); + edges.add((ViewEdge) formatResult); + } else { + formatResult = field; } - return count; - } - private static class ViewVertex { + jsonObject.put(relDataTypeField.getKey(), formatResult); + } - private final String id; - private final String label; - private final Map properties; + return jsonObject; - public ViewVertex(String identifier, String label, Map properties) { - this.id = identifier; - this.label = label; - this.properties = properties; - } + } else { + return o.toString(); + } + } - public String getId() { - return id; - } + private static String getLabel(IGraphElementWithLabelField field) { + try { + return field.getLabel(); + } catch (Exception e) { + return null; + } + } + + private static int getMetaFieldCount(RelDataType type) { + List fieldList = type.getFieldList(); + int count = 0; + for (RelDataTypeField relDataTypeField : fieldList) { + if (!(relDataTypeField.getType() instanceof MetaFieldType)) { + break; + } + count++; + } + return count; + } - public String getLabel() { - return label; - } + private static class ViewVertex { - public Map getProperties() { - return properties; - } + private final String id; + private final String label; + private final Map properties; + public ViewVertex(String identifier, String label, Map properties) { + this.id = identifier; + this.label = label; + this.properties = properties; } - private static class ViewEdge { + public String getId() { + return id; + } - private final String source; - private final String target; - private final String label; - private final String direction; - private final Map properties; + public String getLabel() { + return label; + } - public ViewEdge(String source, String target, String label, Map properties, String direction) { - this.source = source; - this.target = target; - this.label = label; - this.properties = properties; - this.direction = direction; - } + public Map getProperties() { + return properties; + } + } + + private static class ViewEdge { + + private final String source; + private final String target; + private final String label; + private final String direction; + private final Map properties; + + public ViewEdge( + String source, + String target, + String label, + Map properties, + String direction) { + this.source = source; + this.target = target; + this.label = label; + this.properties = properties; + this.direction = direction; + } - public String getSource() { - return source; - } + public String getSource() { + return source; + } - public String getTarget() { - return target; - } + public String getTarget() { + return target; + } - public String getLabel() { - return label; - } + public String getLabel() { + return label; + } - public String getDirection() { - return direction; - } + public String getDirection() { + return direction; + } - public Map getProperties() { - return properties; - } + public Map getProperties() { + return properties; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/EdgeProjectorUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/EdgeProjectorUtil.java index fbd1052d5..12c18114a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/EdgeProjectorUtil.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/EdgeProjectorUtil.java @@ -21,6 +21,7 @@ import java.util.*; import java.util.stream.Collectors; + import org.apache.calcite.rex.RexFieldAccess; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.binary.encoder.DefaultEdgeEncoder; @@ -37,140 +38,136 @@ import org.apache.geaflow.dsl.runtime.function.table.ProjectFunction; import org.apache.geaflow.dsl.runtime.function.table.ProjectFunctionImpl; -/** - * Utility class for projecting edges with field pruning. - */ +/** Utility class for projecting edges with field pruning. */ public class EdgeProjectorUtil { - private static final String SOURCE_ID = "srcId"; - private static final String TARGET_ID = "targetId"; - private static final String LABEL = "~label"; - - private final Map projectFunctions; - private final Map> tableOutputTypes; - private final GraphSchema graphSchema; - private final Set fields; - private final IType outputType; - - /** - * Constructs an EdgeProjector with specified parameters. - * - * @param graphSchema The graph schema containing all vertex and edge type definitions - * @param fields The set of fields to be included in the projection, null means no filtering - * @param outputType The output type of the edge, must be an EdgeType - */ - public EdgeProjectorUtil(GraphSchema graphSchema, - Set fields, - IType outputType) { - this.graphSchema = graphSchema; - this.fields = fields; - this.outputType = outputType; - this.projectFunctions = new HashMap<>(); - this.tableOutputTypes = new HashMap<>(); - - if (!(outputType instanceof EdgeType)) { - throw new IllegalArgumentException("Unsupported type: " + outputType.getClass()); - } + private static final String SOURCE_ID = "srcId"; + private static final String TARGET_ID = "targetId"; + private static final String LABEL = "~label"; + + private final Map projectFunctions; + private final Map> tableOutputTypes; + private final GraphSchema graphSchema; + private final Set fields; + private final IType outputType; + + /** + * Constructs an EdgeProjector with specified parameters. + * + * @param graphSchema The graph schema containing all vertex and edge type definitions + * @param fields The set of fields to be included in the projection, null means no filtering + * @param outputType The output type of the edge, must be an EdgeType + */ + public EdgeProjectorUtil( + GraphSchema graphSchema, Set fields, IType outputType) { + this.graphSchema = graphSchema; + this.fields = fields; + this.outputType = outputType; + this.projectFunctions = new HashMap<>(); + this.tableOutputTypes = new HashMap<>(); + + if (!(outputType instanceof EdgeType)) { + throw new IllegalArgumentException("Unsupported type: " + outputType.getClass()); + } + } + + /** + * Projects an edge by filtering fields based on the required field set. + * + * @param edge The input edge to be projected + * @return The projected edge with only required fields, or null if input is null + */ + public RowEdge projectEdge(RowEdge edge) { + if (edge == null) { + return null; } - /** - * Projects an edge by filtering fields based on the required field set. - * - * @param edge The input edge to be projected - * @return The projected edge with only required fields, or null if input is null - */ - public RowEdge projectEdge(RowEdge edge) { - if (edge == null) { - return null; - } - - String edgeLabel = edge.getLabel(); - - // Initialize project function for this edge label if not exists - if (this.projectFunctions.get(edgeLabel) == null) { - initializeProject( - edge, // edge: The edge instance used for schema inference - edgeLabel // edgeLabel: The label of the edge for unique identification - ); - } - - // Utilize project functions to filter fields - ProjectFunction currentProjectFunction = this.projectFunctions.get(edgeLabel); - ObjectRow projectEdge = (ObjectRow) currentProjectFunction.project(edge); - RowEdge edgeDecoded = (RowEdge) projectEdge.getField(0, null); + String edgeLabel = edge.getLabel(); - EdgeType edgeType = new EdgeType(this.tableOutputTypes.get(edgeLabel), false); - EdgeEncoder encoder = new DefaultEdgeEncoder(edgeType); - return encoder.encode(edgeDecoded); + // Initialize project function for this edge label if not exists + if (this.projectFunctions.get(edgeLabel) == null) { + initializeProject( + edge, // edge: The edge instance used for schema inference + edgeLabel // edgeLabel: The label of the edge for unique identification + ); } - /** - * Initializes the project function for a given edge label. - * - * @param edge The edge instance used to determine the schema and label - * @param edgeLabel The label of the edge for unique identification - */ - private void initializeProject(RowEdge edge, String edgeLabel) { - List graphSchemaFieldList = graphSchema.getFields(); - - // Get fields of the output edge type - List fieldsOfTable = ((EdgeType) outputType).getFields(); - - // Extract field names from RexFieldAccess list into a set - Set fieldNames = (this.fields == null) - ? Collections.emptySet() - : this.fields.stream() - .map(e -> e.getField().getName()) - .collect(Collectors.toSet()); - - List expressions = new ArrayList<>(); - List tableOutputType = null; - - // Enumerate list of fields in every table - for (TableField tableField : graphSchemaFieldList) { - if (edgeLabel.equals(tableField.getName())) { - List inputs = new ArrayList<>(); - tableOutputType = new ArrayList<>(); - - // Enumerate list of fields in the targeted table - for (int i = 0; i < fieldsOfTable.size(); i++) { - TableField column = fieldsOfTable.get(i); - String columnName = column.getName(); - - // Normalize: convert fields like `knowsCreationDate` to `creationDate` - if (columnName.startsWith(edgeLabel)) { - String suffix = columnName.substring(edgeLabel.length()); - if (!suffix.isEmpty()) { - suffix = Character.toLowerCase(suffix.charAt(0)) + suffix.substring(1); - columnName = suffix; - } - } - - if (fieldNames.contains(columnName) - || columnName.equals(SOURCE_ID) - || columnName.equals(TARGET_ID)) { - // Include a field if it's in fieldNames or is source/target ID column - inputs.add(new FieldExpression(null, i, column.getType())); - tableOutputType.add(column); - } else if (columnName.equals(LABEL)) { - // Add edge label for LABEL column - inputs.add(new LiteralExpression(edge.getLabel(), column.getType())); - tableOutputType.add(column); - } else { - // Use null placeholder for excluded fields - inputs.add(new LiteralExpression(null, column.getType())); - tableOutputType.add(column); - } - } - - expressions.add(new EdgeConstructExpression(inputs, new EdgeType(tableOutputType, false))); + // Utilize project functions to filter fields + ProjectFunction currentProjectFunction = this.projectFunctions.get(edgeLabel); + ObjectRow projectEdge = (ObjectRow) currentProjectFunction.project(edge); + RowEdge edgeDecoded = (RowEdge) projectEdge.getField(0, null); + + EdgeType edgeType = new EdgeType(this.tableOutputTypes.get(edgeLabel), false); + EdgeEncoder encoder = new DefaultEdgeEncoder(edgeType); + return encoder.encode(edgeDecoded); + } + + /** + * Initializes the project function for a given edge label. + * + * @param edge The edge instance used to determine the schema and label + * @param edgeLabel The label of the edge for unique identification + */ + private void initializeProject(RowEdge edge, String edgeLabel) { + List graphSchemaFieldList = graphSchema.getFields(); + + // Get fields of the output edge type + List fieldsOfTable = ((EdgeType) outputType).getFields(); + + // Extract field names from RexFieldAccess list into a set + Set fieldNames = + (this.fields == null) + ? Collections.emptySet() + : this.fields.stream().map(e -> e.getField().getName()).collect(Collectors.toSet()); + + List expressions = new ArrayList<>(); + List tableOutputType = null; + + // Enumerate list of fields in every table + for (TableField tableField : graphSchemaFieldList) { + if (edgeLabel.equals(tableField.getName())) { + List inputs = new ArrayList<>(); + tableOutputType = new ArrayList<>(); + + // Enumerate list of fields in the targeted table + for (int i = 0; i < fieldsOfTable.size(); i++) { + TableField column = fieldsOfTable.get(i); + String columnName = column.getName(); + + // Normalize: convert fields like `knowsCreationDate` to `creationDate` + if (columnName.startsWith(edgeLabel)) { + String suffix = columnName.substring(edgeLabel.length()); + if (!suffix.isEmpty()) { + suffix = Character.toLowerCase(suffix.charAt(0)) + suffix.substring(1); + columnName = suffix; } + } + + if (fieldNames.contains(columnName) + || columnName.equals(SOURCE_ID) + || columnName.equals(TARGET_ID)) { + // Include a field if it's in fieldNames or is source/target ID column + inputs.add(new FieldExpression(null, i, column.getType())); + tableOutputType.add(column); + } else if (columnName.equals(LABEL)) { + // Add edge label for LABEL column + inputs.add(new LiteralExpression(edge.getLabel(), column.getType())); + tableOutputType.add(column); + } else { + // Use null placeholder for excluded fields + inputs.add(new LiteralExpression(null, column.getType())); + tableOutputType.add(column); + } } - ProjectFunction projectFunction = new ProjectFunctionImpl(expressions); - - // Store project function and output type for this edge label - this.projectFunctions.put(edgeLabel, projectFunction); - this.tableOutputTypes.put(edgeLabel, tableOutputType); + expressions.add(new EdgeConstructExpression(inputs, new EdgeType(tableOutputType, false))); + } } + + ProjectFunction projectFunction = new ProjectFunctionImpl(expressions); + + // Store project function and output type for this edge label + this.projectFunctions.put(edgeLabel, projectFunction); + this.tableOutputTypes.put(edgeLabel, tableOutputType); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/FilterPushDownUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/FilterPushDownUtil.java index aa096d927..ce98c46b5 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/FilterPushDownUtil.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/FilterPushDownUtil.java @@ -23,6 +23,7 @@ import java.util.Collections; import java.util.Comparator; import java.util.List; + import org.apache.calcite.rex.*; import org.apache.calcite.sql.SqlKind; import org.apache.geaflow.dsl.calcite.EdgeRecordType; @@ -33,242 +34,247 @@ public class FilterPushDownUtil { - /** - * Find timestamp range condition in the expression. - */ - public static List findTsRange(RexNode rexNode, - EdgeRecordType edgeRecordType) { - int tsFieldIndex = edgeRecordType.getTimestampIndex(); - if (tsFieldIndex < 0) { + /** Find timestamp range condition in the expression. */ + public static List findTsRange(RexNode rexNode, EdgeRecordType edgeRecordType) { + int tsFieldIndex = edgeRecordType.getTimestampIndex(); + if (tsFieldIndex < 0) { + return new ArrayList<>(); + } + return rexNode.accept( + new RexVisitor>() { + @Override + public List visitInputRef(RexInputRef rexInputRef) { return new ArrayList<>(); - } - return rexNode.accept(new RexVisitor>() { - @Override - public List visitInputRef(RexInputRef rexInputRef) { - return new ArrayList<>(); - } + } - @Override - public List visitLocalRef(RexLocalRef rexLocalRef) { - return new ArrayList<>(); - } + @Override + public List visitLocalRef(RexLocalRef rexLocalRef) { + return new ArrayList<>(); + } - @Override - public List visitLiteral(RexLiteral rexLiteral) { - return new ArrayList<>(); - } + @Override + public List visitLiteral(RexLiteral rexLiteral) { + return new ArrayList<>(); + } - @Override - public List visitCall(RexCall call) { - SqlKind kind = call.getKind(); - switch (kind) { - case BETWEEN: - RexNode node = call.operands.get(0); - RexNode leftValue = call.operands.get(1); - RexNode rightValue = call.operands.get(2); - if (node instanceof RexFieldAccess) { - if (tsFieldIndex == ((RexFieldAccess) node).getField().getIndex()) { - Long leftTsValue = null; - Long rightTsValue = null; - if (GQLRexUtil.isLiteralOrParameter(leftValue, true)) { - leftTsValue = toTsLongValue(leftValue); - } - if (GQLRexUtil.isLiteralOrParameter(rightValue, true)) { - rightTsValue = toTsLongValue(rightValue); - } - TimeRange result = TimeRange.of(leftTsValue == null ? Long.MIN_VALUE : leftTsValue, rightTsValue == null ? Long.MAX_VALUE : rightTsValue - ); - return new ArrayList<>(Collections.singletonList(result)); - } - } - return new ArrayList<>(); + @Override + public List visitCall(RexCall call) { + SqlKind kind = call.getKind(); + switch (kind) { + case BETWEEN: + RexNode node = call.operands.get(0); + RexNode leftValue = call.operands.get(1); + RexNode rightValue = call.operands.get(2); + if (node instanceof RexFieldAccess) { + if (tsFieldIndex == ((RexFieldAccess) node).getField().getIndex()) { + Long leftTsValue = null; + Long rightTsValue = null; + if (GQLRexUtil.isLiteralOrParameter(leftValue, true)) { + leftTsValue = toTsLongValue(leftValue); + } + if (GQLRexUtil.isLiteralOrParameter(rightValue, true)) { + rightTsValue = toTsLongValue(rightValue); + } + TimeRange result = + TimeRange.of( + leftTsValue == null ? Long.MIN_VALUE : leftTsValue, + rightTsValue == null ? Long.MAX_VALUE : rightTsValue); + return new ArrayList<>(Collections.singletonList(result)); + } + } + return new ArrayList<>(); + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + case EQUALS: + RexNode rangeValue = null; + RexNode left = call.operands.get(0); + RexNode right = call.operands.get(1); + SqlKind realKind = kind; + if (left instanceof RexFieldAccess + && GQLRexUtil.isLiteralOrParameter(right, true)) { + RexFieldAccess leftAccess = (RexFieldAccess) left; + int index = leftAccess.getField().getIndex(); + if (tsFieldIndex == index) { + rangeValue = right; + } + } else if (right instanceof RexFieldAccess + && GQLRexUtil.isLiteralOrParameter(left, true)) { + RexFieldAccess rightAccess = (RexFieldAccess) right; + int index = rightAccess.getField().getIndex(); + if (tsFieldIndex == index) { + switch (kind) { + case LESS_THAN: + realKind = SqlKind.GREATER_THAN; + break; + case LESS_THAN_OR_EQUAL: + realKind = SqlKind.GREATER_THAN_OR_EQUAL; + break; + case GREATER_THAN: + realKind = SqlKind.LESS_THAN; + break; + case GREATER_THAN_OR_EQUAL: + realKind = SqlKind.LESS_THAN_OR_EQUAL; + break; + default: + } + rangeValue = left; + } + } + if (rangeValue != null) { + Long ts = toTsLongValue(rangeValue); + if (ts == null) { + return new ArrayList<>(); + } + switch (realKind) { case LESS_THAN: + return Collections.singletonList(TimeRange.of(Long.MIN_VALUE, ts)); case LESS_THAN_OR_EQUAL: + if (ts < Long.MAX_VALUE) { + return Collections.singletonList(TimeRange.of(Long.MIN_VALUE, ts + 1)); + } + return Collections.singletonList( + TimeRange.of(Long.MIN_VALUE, Long.MAX_VALUE)); case GREATER_THAN: + if (ts < Long.MAX_VALUE) { + return Collections.singletonList(TimeRange.of(ts + 1, Long.MAX_VALUE)); + } + return Collections.singletonList( + TimeRange.of(Long.MAX_VALUE, Long.MAX_VALUE)); case GREATER_THAN_OR_EQUAL: + return Collections.singletonList(TimeRange.of(ts, Long.MAX_VALUE)); case EQUALS: - RexNode rangeValue = null; - RexNode left = call.operands.get(0); - RexNode right = call.operands.get(1); - SqlKind realKind = kind; - if (left instanceof RexFieldAccess && GQLRexUtil.isLiteralOrParameter(right, - true)) { - RexFieldAccess leftAccess = (RexFieldAccess) left; - int index = leftAccess.getField().getIndex(); - if (tsFieldIndex == index) { - rangeValue = right; - } - } else if (right instanceof RexFieldAccess && GQLRexUtil.isLiteralOrParameter(left, true)) { - RexFieldAccess rightAccess = (RexFieldAccess) right; - int index = rightAccess.getField().getIndex(); - if (tsFieldIndex == index) { - switch (kind) { - case LESS_THAN: - realKind = SqlKind.GREATER_THAN; - break; - case LESS_THAN_OR_EQUAL: - realKind = SqlKind.GREATER_THAN_OR_EQUAL; - break; - case GREATER_THAN: - realKind = SqlKind.LESS_THAN; - break; - case GREATER_THAN_OR_EQUAL: - realKind = SqlKind.LESS_THAN_OR_EQUAL; - break; - default: - } - rangeValue = left; - } - } - if (rangeValue != null) { - Long ts = toTsLongValue(rangeValue); - if (ts == null) { - return new ArrayList<>(); - } - switch (realKind) { - case LESS_THAN: - return Collections.singletonList(TimeRange.of(Long.MIN_VALUE, ts)); - case LESS_THAN_OR_EQUAL: - if (ts < Long.MAX_VALUE) { - return Collections.singletonList(TimeRange.of(Long.MIN_VALUE, ts + 1)); - } - return Collections.singletonList(TimeRange.of(Long.MIN_VALUE, Long.MAX_VALUE)); - case GREATER_THAN: - if (ts < Long.MAX_VALUE) { - return Collections.singletonList(TimeRange.of(ts + 1, Long.MAX_VALUE)); - } - return Collections.singletonList(TimeRange.of(Long.MAX_VALUE, Long.MAX_VALUE)); - case GREATER_THAN_OR_EQUAL: - return Collections.singletonList(TimeRange.of(ts, Long.MAX_VALUE)); - case EQUALS: - if (ts < Long.MAX_VALUE) { - return Collections.singletonList(TimeRange.of(ts, ts + 1)); - } else { - return Collections.singletonList(TimeRange.of(ts, Long.MAX_VALUE)); - } - default: - } - } - return new ArrayList<>(); - case AND: - return call.operands.stream() - .map(operand -> operand.accept(this)) - .filter(list -> list != null && !list.isEmpty()) - .reduce(FilterPushDownUtil::timeRangeIntersection) - .orElse(new ArrayList<>()); - case OR: - return call.operands.stream() - .map(operand -> operand.accept(this)) - .filter(list -> list != null && !list.isEmpty()) - .reduce(FilterPushDownUtil::timeRangeUnion) - .orElse(new ArrayList<>()); + if (ts < Long.MAX_VALUE) { + return Collections.singletonList(TimeRange.of(ts, ts + 1)); + } else { + return Collections.singletonList(TimeRange.of(ts, Long.MAX_VALUE)); + } default: - return new ArrayList<>(); + } } - } - - @Override - public List visitOver(RexOver rexOver) { return new ArrayList<>(); - } - - @Override - public List visitCorrelVariable(RexCorrelVariable rexCorrelVariable) { + case AND: + return call.operands.stream() + .map(operand -> operand.accept(this)) + .filter(list -> list != null && !list.isEmpty()) + .reduce(FilterPushDownUtil::timeRangeIntersection) + .orElse(new ArrayList<>()); + case OR: + return call.operands.stream() + .map(operand -> operand.accept(this)) + .filter(list -> list != null && !list.isEmpty()) + .reduce(FilterPushDownUtil::timeRangeUnion) + .orElse(new ArrayList<>()); + default: return new ArrayList<>(); } + } - @Override - public List visitDynamicParam(RexDynamicParam rexDynamicParam) { - return new ArrayList<>(); - } + @Override + public List visitOver(RexOver rexOver) { + return new ArrayList<>(); + } - @Override - public List visitRangeRef(RexRangeRef rexRangeRef) { - return new ArrayList<>(); - } + @Override + public List visitCorrelVariable(RexCorrelVariable rexCorrelVariable) { + return new ArrayList<>(); + } - @Override - public List visitFieldAccess(RexFieldAccess rexFieldAccess) { - return new ArrayList<>(); - } + @Override + public List visitDynamicParam(RexDynamicParam rexDynamicParam) { + return new ArrayList<>(); + } - @Override - public List visitSubQuery(RexSubQuery rexSubQuery) { - return new ArrayList<>(); - } + @Override + public List visitRangeRef(RexRangeRef rexRangeRef) { + return new ArrayList<>(); + } - @Override - public List visitTableInputRef(RexTableInputRef rexTableInputRef) { - return new ArrayList<>(); - } + @Override + public List visitFieldAccess(RexFieldAccess rexFieldAccess) { + return new ArrayList<>(); + } - @Override - public List visitPatternFieldRef(RexPatternFieldRef rexPatternFieldRef) { - return new ArrayList<>(); - } + @Override + public List visitSubQuery(RexSubQuery rexSubQuery) { + return new ArrayList<>(); + } - @Override - public List visitOther(RexNode other) { - return new ArrayList<>(); - } + @Override + public List visitTableInputRef(RexTableInputRef rexTableInputRef) { + return new ArrayList<>(); + } + + @Override + public List visitPatternFieldRef(RexPatternFieldRef rexPatternFieldRef) { + return new ArrayList<>(); + } + + @Override + public List visitOther(RexNode other) { + return new ArrayList<>(); + } }); - } + } - private static Long toTsLongValue(RexNode rangeValue) { - List nonLiteralLeafNodes = GQLRexUtil.collect(rangeValue, - child -> !(child instanceof RexCall) && !(child instanceof RexLiteral)); - ExpressionTranslator translator = ExpressionTranslator.of(null); - Expression expression = translator.translate(rangeValue); - if (nonLiteralLeafNodes.isEmpty()) { // all the leaf node is constant. - Object constantValue = expression.evaluate(null); - assert constantValue instanceof Number : "Not Number timestamp range."; - return ((Number) constantValue).longValue(); - } - //todo Parameter timestamp range not support push down currently. - return null; + private static Long toTsLongValue(RexNode rangeValue) { + List nonLiteralLeafNodes = + GQLRexUtil.collect( + rangeValue, child -> !(child instanceof RexCall) && !(child instanceof RexLiteral)); + ExpressionTranslator translator = ExpressionTranslator.of(null); + Expression expression = translator.translate(rangeValue); + if (nonLiteralLeafNodes.isEmpty()) { // all the leaf node is constant. + Object constantValue = expression.evaluate(null); + assert constantValue instanceof Number : "Not Number timestamp range."; + return ((Number) constantValue).longValue(); } + // todo Parameter timestamp range not support push down currently. + return null; + } - public static List timeRangeIntersection(final List ranges, - final List others) { - List tmpList = new ArrayList<>(); - for (TimeRange a : ranges) { - for (TimeRange b : ranges) { - long maxStart = Math.max(a.getStart(), b.getStart()); - long maxEnd = Math.max(a.getEnd(), b.getEnd()); - TimeRange range = TimeRange.of(maxStart, maxEnd); - tmpList.add(range); - } - } - return new ArrayList<>(mergeTsRanges(tmpList)); + public static List timeRangeIntersection( + final List ranges, final List others) { + List tmpList = new ArrayList<>(); + for (TimeRange a : ranges) { + for (TimeRange b : ranges) { + long maxStart = Math.max(a.getStart(), b.getStart()); + long maxEnd = Math.max(a.getEnd(), b.getEnd()); + TimeRange range = TimeRange.of(maxStart, maxEnd); + tmpList.add(range); + } } + return new ArrayList<>(mergeTsRanges(tmpList)); + } - public static List mergeTsRanges(List list) { - list.sort(Comparator.comparing(TimeRange::getStart)); - for (int i = 0; i < list.size() - 1; i++) { - TimeRange outer; - TimeRange inner; - for (int j = i + 1; j < list.size(); j++) { - outer = list.get(i); - inner = list.get(j); - long end = list.get(i).getEnd(); - long end2 = inner.getEnd(); - if (end >= inner.getStart() && end <= end2) { - TimeRange tmpRange = TimeRange.of(outer.getStart(), end2); - list.set(i, tmpRange); - list.remove(j); - j--; - } else if (end >= end2 || outer.getStart() == end2) { - list.remove(j--); - } - } + public static List mergeTsRanges(List list) { + list.sort(Comparator.comparing(TimeRange::getStart)); + for (int i = 0; i < list.size() - 1; i++) { + TimeRange outer; + TimeRange inner; + for (int j = i + 1; j < list.size(); j++) { + outer = list.get(i); + inner = list.get(j); + long end = list.get(i).getEnd(); + long end2 = inner.getEnd(); + if (end >= inner.getStart() && end <= end2) { + TimeRange tmpRange = TimeRange.of(outer.getStart(), end2); + list.set(i, tmpRange); + list.remove(j); + j--; + } else if (end >= end2 || outer.getStart() == end2) { + list.remove(j--); } - return list; + } } + return list; + } - public static List timeRangeUnion(final List ranges, final List others) { - List tmpList = new ArrayList<>(); - tmpList.addAll(ranges); - tmpList.addAll(others); - return new ArrayList<>(mergeTsRanges(tmpList)); - } + public static List timeRangeUnion( + final List ranges, final List others) { + List tmpList = new ArrayList<>(); + tmpList.addAll(ranges); + tmpList.addAll(others); + return new ArrayList<>(mergeTsRanges(tmpList)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/IDUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/IDUtil.java index 05f4232d6..8985c55a2 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/IDUtil.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/IDUtil.java @@ -20,15 +20,15 @@ package org.apache.geaflow.dsl.runtime.util; public class IDUtil { - /** - * generate a unique id between all the tasks. - * - * @param numTask The number of total tasks. - * @param taskIndex The index for current task. - * @param idInTask The unique id in the task. - * @return A unique id between all the tasks. - */ - public static long uniqueId(int numTask, int taskIndex, long idInTask) { - return numTask * idInTask + taskIndex; - } + /** + * generate a unique id between all the tasks. + * + * @param numTask The number of total tasks. + * @param taskIndex The index for current task. + * @param idInTask The unique id in the task. + * @return A unique id between all the tasks. + */ + public static long uniqueId(int numTask, int taskIndex, long idInTask) { + return numTask * idInTask + taskIndex; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/QueryUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/QueryUtil.java index f8e328f05..7ebe88166 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/QueryUtil.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/QueryUtil.java @@ -25,6 +25,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.calcite.schema.Table; import org.apache.calcite.sql.SqlIdentifier; import org.apache.calcite.sql.SqlInsert; @@ -50,91 +51,98 @@ public class QueryUtil { - private static final Logger LOGGER = LoggerFactory.getLogger(QueryUtil.class); + private static final Logger LOGGER = LoggerFactory.getLogger(QueryUtil.class); - public static PreCompileResult preCompile(String script, Configuration config) { - GeaFlowDSLParser parser = new GeaFlowDSLParser(); - GQLContext gqlContext = GQLContext.create(config, false); - PreCompileResult preCompileResult = new PreCompileResult(); - try { - List sqlNodes = parser.parseMultiStatement(script); - List createTablesInScript = new ArrayList<>(); - Map createGraphs = new HashMap<>(); - for (SqlNode sqlNode : sqlNodes) { - if (sqlNode instanceof SqlSetOption) { - SqlSetOption sqlSetOption = (SqlSetOption) sqlNode; - String key = StringLiteralUtil.toJavaString(sqlSetOption.getName()); - String value = StringLiteralUtil.toJavaString(sqlSetOption.getValue()); - config.put(key, value); - } else if (sqlNode instanceof SqlCreateTable) { - createTablesInScript.add(gqlContext.convertToTable((SqlCreateTable) sqlNode)); - } else if (sqlNode instanceof SqlCreateGraph) { - SqlCreateGraph createGraph = (SqlCreateGraph) sqlNode; - SqlIdentifier graphName = gqlContext.completeCatalogObjName(createGraph.getName()); - if (createGraph.getVertices().getList().stream().anyMatch(node -> node instanceof SqlVertexUsing) - || createGraph.getEdges().getList().stream().anyMatch(node -> node instanceof SqlEdgeUsing)) { - GeaFlowGraph graph = gqlContext.convertToGraph(createGraph, createTablesInScript); - Configuration globalConfig = graph.getConfigWithGlobal(config); - if (!QueryUtil.isGraphExists(graph, globalConfig)) { - LOGGER.info("insertGraphs: {}", graph.getUniqueName()); - preCompileResult.addGraph(SchemaUtil.buildGraphViewDesc(graph, globalConfig)); - } - } else { - createGraphs.put(graphName.toString(), createGraph); - } - } else if (sqlNode instanceof SqlInsert) { - SqlInsert insert = (SqlInsert) sqlNode; - SqlIdentifier insertName = gqlContext.completeCatalogObjName( - (SqlIdentifier) insert.getTargetTable()); - SqlIdentifier insertGraphName = GQLNodeUtil.getGraphTableName(insertName); - String simpleGraphName = insertName.getComponent(1, 2).getSimple(); - LOGGER.info("insertGraphName: {}, insertName:{}, simpleGraphName: {}", - insertGraphName, insertName, simpleGraphName); - if (createGraphs.containsKey(insertGraphName.toString())) { - SqlCreateGraph createGraph = createGraphs.get(insertGraphName.toString()); - GeaFlowGraph graph = gqlContext.convertToGraph(createGraph); - LOGGER.info("insertGraphs: {}", graph.getUniqueName()); - preCompileResult.addGraph(SchemaUtil.buildGraphViewDesc(graph, config)); - } else { - Table graph = gqlContext.getCatalog().getGraph( - gqlContext.getCurrentInstance(), simpleGraphName); - if (graph != null) { - GeaFlowGraph geaFlowGraph = (GeaFlowGraph) graph; - geaFlowGraph.getConfig().putAll(gqlContext.keyMapping(geaFlowGraph.getConfig().getConfigMap())); - LOGGER.info("insertGraphs: {}", geaFlowGraph.getUniqueName()); - preCompileResult.addGraph(SchemaUtil.buildGraphViewDesc(geaFlowGraph, config)); - } - } - } + public static PreCompileResult preCompile(String script, Configuration config) { + GeaFlowDSLParser parser = new GeaFlowDSLParser(); + GQLContext gqlContext = GQLContext.create(config, false); + PreCompileResult preCompileResult = new PreCompileResult(); + try { + List sqlNodes = parser.parseMultiStatement(script); + List createTablesInScript = new ArrayList<>(); + Map createGraphs = new HashMap<>(); + for (SqlNode sqlNode : sqlNodes) { + if (sqlNode instanceof SqlSetOption) { + SqlSetOption sqlSetOption = (SqlSetOption) sqlNode; + String key = StringLiteralUtil.toJavaString(sqlSetOption.getName()); + String value = StringLiteralUtil.toJavaString(sqlSetOption.getValue()); + config.put(key, value); + } else if (sqlNode instanceof SqlCreateTable) { + createTablesInScript.add(gqlContext.convertToTable((SqlCreateTable) sqlNode)); + } else if (sqlNode instanceof SqlCreateGraph) { + SqlCreateGraph createGraph = (SqlCreateGraph) sqlNode; + SqlIdentifier graphName = gqlContext.completeCatalogObjName(createGraph.getName()); + if (createGraph.getVertices().getList().stream() + .anyMatch(node -> node instanceof SqlVertexUsing) + || createGraph.getEdges().getList().stream() + .anyMatch(node -> node instanceof SqlEdgeUsing)) { + GeaFlowGraph graph = gqlContext.convertToGraph(createGraph, createTablesInScript); + Configuration globalConfig = graph.getConfigWithGlobal(config); + if (!QueryUtil.isGraphExists(graph, globalConfig)) { + LOGGER.info("insertGraphs: {}", graph.getUniqueName()); + preCompileResult.addGraph(SchemaUtil.buildGraphViewDesc(graph, globalConfig)); + } + } else { + createGraphs.put(graphName.toString(), createGraph); + } + } else if (sqlNode instanceof SqlInsert) { + SqlInsert insert = (SqlInsert) sqlNode; + SqlIdentifier insertName = + gqlContext.completeCatalogObjName((SqlIdentifier) insert.getTargetTable()); + SqlIdentifier insertGraphName = GQLNodeUtil.getGraphTableName(insertName); + String simpleGraphName = insertName.getComponent(1, 2).getSimple(); + LOGGER.info( + "insertGraphName: {}, insertName:{}, simpleGraphName: {}", + insertGraphName, + insertName, + simpleGraphName); + if (createGraphs.containsKey(insertGraphName.toString())) { + SqlCreateGraph createGraph = createGraphs.get(insertGraphName.toString()); + GeaFlowGraph graph = gqlContext.convertToGraph(createGraph); + LOGGER.info("insertGraphs: {}", graph.getUniqueName()); + preCompileResult.addGraph(SchemaUtil.buildGraphViewDesc(graph, config)); + } else { + Table graph = + gqlContext.getCatalog().getGraph(gqlContext.getCurrentInstance(), simpleGraphName); + if (graph != null) { + GeaFlowGraph geaFlowGraph = (GeaFlowGraph) graph; + geaFlowGraph + .getConfig() + .putAll(gqlContext.keyMapping(geaFlowGraph.getConfig().getConfigMap())); + LOGGER.info("insertGraphs: {}", geaFlowGraph.getUniqueName()); + preCompileResult.addGraph(SchemaUtil.buildGraphViewDesc(geaFlowGraph, config)); } - return preCompileResult; - } catch (SqlParseException e) { - throw new GeaFlowDSLException(e); + } } + } + return preCompileResult; + } catch (SqlParseException e) { + throw new GeaFlowDSLException(e); } + } - public static boolean isGraphExists(GeaFlowGraph graph, Configuration globalConfig) { - boolean graphExists; - try { - ViewMetaBookKeeper keeper = new ViewMetaBookKeeper(graph.getUniqueName(), globalConfig); - long lastCheckPointId = keeper.getLatestViewVersion(graph.getUniqueName()); - graphExists = lastCheckPointId >= 0; - } catch (IOException e) { - throw new GeaFlowDSLException(e); - } - return graphExists; + public static boolean isGraphExists(GeaFlowGraph graph, Configuration globalConfig) { + boolean graphExists; + try { + ViewMetaBookKeeper keeper = new ViewMetaBookKeeper(graph.getUniqueName(), globalConfig); + long lastCheckPointId = keeper.getLatestViewVersion(graph.getUniqueName()); + graphExists = lastCheckPointId >= 0; + } catch (IOException e) { + throw new GeaFlowDSLException(e); } + return graphExists; + } - public static class PreCompileResult implements Serializable { + public static class PreCompileResult implements Serializable { - private final List insertGraphs = new ArrayList<>(); + private final List insertGraphs = new ArrayList<>(); - public void addGraph(GraphViewDesc graphViewDesc) { - insertGraphs.add(graphViewDesc); - } + public void addGraph(GraphViewDesc graphViewDesc) { + insertGraphs.add(graphViewDesc); + } - public List getInsertGraphs() { - return insertGraphs; - } + public List getInsertGraphs() { + return insertGraphs; } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/SchemaUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/SchemaUtil.java index 06c9155ac..5b268a267 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/SchemaUtil.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/SchemaUtil.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Map; import java.util.function.Supplier; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.data.Path; @@ -55,155 +56,159 @@ public class SchemaUtil { - public static final String VERTEX_EDGE_CONSTRUCTOR_FIELD = "CONSTRUCTOR"; - - public static GraphMetaType buildGraphMeta(GeaFlowGraph graph) { - Map vertexTypes = getVertexTypes(graph); - Map edgeTypes = getEdgeTypes(graph); - RowVertex vertex = VertexEdgeFactory.createVertex(vertexTypes.values().iterator().next()); - RowEdge edge = VertexEdgeFactory.createEdge(edgeTypes.values().iterator().next()); - - try { - Field vertexConstructorField = vertex.getClass().getField(VERTEX_EDGE_CONSTRUCTOR_FIELD); - Field edgeConstructorField = edge.getClass().getField(VERTEX_EDGE_CONSTRUCTOR_FIELD); - Supplier vertexConstructor = (Supplier) vertexConstructorField.get(null); - Supplier edgeConstructor = (Supplier) edgeConstructorField.get(null); - return new GraphMetaType(graph.getIdType(), vertex.getClass(), vertexConstructor, - BinaryRow.class, edge.getClass(), edgeConstructor, BinaryRow.class); - } catch (NoSuchFieldException | IllegalAccessException e) { - throw new GeaFlowDSLException("Fail to get vertex or edge constructor", e); - } - } - - - public static GraphViewDesc buildGraphViewDesc(GeaFlowGraph graph, Configuration conf) { - Configuration graphConfig = new Configuration(graph.getConfig().getConfigMap()); - long latestVersion = getGraphLatestVersion(graph, conf); - BackendType storeType = BackendType.of(graph.getStoreType()); - return GraphViewBuilder.createGraphView(graph.getUniqueName()) - .withShardNum(graph.getShardCount()) - .withBackend(storeType) - .withLatestVersion(latestVersion) - .withProps(graphConfig.getConfigMap()) - .withSchema(buildGraphMeta(graph)) - .build(); + public static final String VERTEX_EDGE_CONSTRUCTOR_FIELD = "CONSTRUCTOR"; + + public static GraphMetaType buildGraphMeta(GeaFlowGraph graph) { + Map vertexTypes = getVertexTypes(graph); + Map edgeTypes = getEdgeTypes(graph); + RowVertex vertex = VertexEdgeFactory.createVertex(vertexTypes.values().iterator().next()); + RowEdge edge = VertexEdgeFactory.createEdge(edgeTypes.values().iterator().next()); + + try { + Field vertexConstructorField = vertex.getClass().getField(VERTEX_EDGE_CONSTRUCTOR_FIELD); + Field edgeConstructorField = edge.getClass().getField(VERTEX_EDGE_CONSTRUCTOR_FIELD); + Supplier vertexConstructor = (Supplier) vertexConstructorField.get(null); + Supplier edgeConstructor = (Supplier) edgeConstructorField.get(null); + return new GraphMetaType( + graph.getIdType(), + vertex.getClass(), + vertexConstructor, + BinaryRow.class, + edge.getClass(), + edgeConstructor, + BinaryRow.class); + } catch (NoSuchFieldException | IllegalAccessException e) { + throw new GeaFlowDSLException("Fail to get vertex or edge constructor", e); } - - private static long getGraphLatestVersion(GeaFlowGraph graph, Configuration conf) { - try { - Configuration globalConfig = graph.getConfigWithGlobal(conf); - ViewMetaBookKeeper keeper = new ViewMetaBookKeeper(graph.getUniqueName(), globalConfig); - return keeper.getLatestViewVersion(graph.getUniqueName()); - } catch (IOException e) { - throw new GeaFlowDSLException(e); - } + } + + public static GraphViewDesc buildGraphViewDesc(GeaFlowGraph graph, Configuration conf) { + Configuration graphConfig = new Configuration(graph.getConfig().getConfigMap()); + long latestVersion = getGraphLatestVersion(graph, conf); + BackendType storeType = BackendType.of(graph.getStoreType()); + return GraphViewBuilder.createGraphView(graph.getUniqueName()) + .withShardNum(graph.getShardCount()) + .withBackend(storeType) + .withLatestVersion(latestVersion) + .withProps(graphConfig.getConfigMap()) + .withSchema(buildGraphMeta(graph)) + .build(); + } + + private static long getGraphLatestVersion(GeaFlowGraph graph, Configuration conf) { + try { + Configuration globalConfig = graph.getConfigWithGlobal(conf); + ViewMetaBookKeeper keeper = new ViewMetaBookKeeper(graph.getUniqueName(), globalConfig); + return keeper.getLatestViewVersion(graph.getUniqueName()); + } catch (IOException e) { + throw new GeaFlowDSLException(e); } + } - public static Map getVertexTypes(GeaFlowGraph graph) { - GQLJavaTypeFactory typeFactory = GQLJavaTypeFactory.create(); + public static Map getVertexTypes(GeaFlowGraph graph) { + GQLJavaTypeFactory typeFactory = GQLJavaTypeFactory.create(); - Map vertexTypes = new HashMap<>(); - for (VertexTable vertexTable : graph.getVertexTables()) { - VertexType vertexType = (VertexType) SqlTypeUtil.convertType(vertexTable.getRowType(typeFactory)); - vertexTypes.put(vertexTable.getTypeName(), vertexType); - } - return vertexTypes; + Map vertexTypes = new HashMap<>(); + for (VertexTable vertexTable : graph.getVertexTables()) { + VertexType vertexType = + (VertexType) SqlTypeUtil.convertType(vertexTable.getRowType(typeFactory)); + vertexTypes.put(vertexTable.getTypeName(), vertexType); } + return vertexTypes; + } - public static Map getEdgeTypes(GeaFlowGraph graph) { - GQLJavaTypeFactory typeFactory = GQLJavaTypeFactory.create(); + public static Map getEdgeTypes(GeaFlowGraph graph) { + GQLJavaTypeFactory typeFactory = GQLJavaTypeFactory.create(); - Map edgeTypes = new HashMap<>(); - for (EdgeTable edgeTable : graph.getEdgeTables()) { - EdgeType edgeType = (EdgeType) SqlTypeUtil.convertType(edgeTable.getRowType(typeFactory)); - edgeTypes.put(edgeTable.getTypeName(), edgeType); - } - return edgeTypes; + Map edgeTypes = new HashMap<>(); + for (EdgeTable edgeTable : graph.getEdgeTables()) { + EdgeType edgeType = (EdgeType) SqlTypeUtil.convertType(edgeTable.getRowType(typeFactory)); + edgeTypes.put(edgeTable.getTypeName(), edgeType); } - - private static int[] getFieldMappingIndices(StructType inputType, StructType outputType) { - int[] mapping = new int[outputType.size()]; - for (int i = 0; i < outputType.size(); i++) { - String outputField = outputType.getField(i).getName(); - mapping[i] = inputType.indexOf(outputField); - if (mapping[i] < 0) { - switch (outputField) { - case VertexType.DEFAULT_ID_FIELD_NAME: - mapping[i] = VertexType.ID_FIELD_POSITION; - break; - case EdgeType.DEFAULT_SRC_ID_NAME: - mapping[i] = EdgeType.SRC_ID_FIELD_POSITION; - break; - case EdgeType.DEFAULT_TARGET_ID_NAME: - mapping[i] = EdgeType.TARGET_ID_FIELD_POSITION; - break; - case EdgeType.DEFAULT_TS_NAME: - mapping[i] = EdgeType.TIME_FIELD_POSITION; - break; - default: - if (mapping[i] < -1) { - throw new GeaFlowDSLException("Cannot find field {}, illegal index {}", - outputField, mapping[i]); - } - } + return edgeTypes; + } + + private static int[] getFieldMappingIndices(StructType inputType, StructType outputType) { + int[] mapping = new int[outputType.size()]; + for (int i = 0; i < outputType.size(); i++) { + String outputField = outputType.getField(i).getName(); + mapping[i] = inputType.indexOf(outputField); + if (mapping[i] < 0) { + switch (outputField) { + case VertexType.DEFAULT_ID_FIELD_NAME: + mapping[i] = VertexType.ID_FIELD_POSITION; + break; + case EdgeType.DEFAULT_SRC_ID_NAME: + mapping[i] = EdgeType.SRC_ID_FIELD_POSITION; + break; + case EdgeType.DEFAULT_TARGET_ID_NAME: + mapping[i] = EdgeType.TARGET_ID_FIELD_POSITION; + break; + case EdgeType.DEFAULT_TS_NAME: + mapping[i] = EdgeType.TIME_FIELD_POSITION; + break; + default: + if (mapping[i] < -1) { + throw new GeaFlowDSLException( + "Cannot find field {}, illegal index {}", outputField, mapping[i]); } } - return mapping; + } } + return mapping; + } - public static RowVertex alignToVertexSchema(RowVertex vertex, VertexType inputVertexType, - VertexType outputVertexType) { - if (vertex == null) { - return null; - } - if (inputVertexType.equals(outputVertexType)) { - return vertex; - } - int[] mapping = getFieldMappingIndices(inputVertexType, outputVertexType); - return FieldAlignVertex.createFieldAlignedVertex(vertex, mapping); + public static RowVertex alignToVertexSchema( + RowVertex vertex, VertexType inputVertexType, VertexType outputVertexType) { + if (vertex == null) { + return null; } - - public static RowEdge alignToEdgeSchema(RowEdge edge, EdgeType inputEdgeType, - EdgeType outputEdgeType) { - if (edge == null) { - return null; - } - if (inputEdgeType.equals(outputEdgeType)) { - return edge; - } - int[] mapping = getFieldMappingIndices(inputEdgeType, outputEdgeType); - return FieldAlignEdge.createFieldAlignedEdge(edge, mapping); + if (inputVertexType.equals(outputVertexType)) { + return vertex; } + int[] mapping = getFieldMappingIndices(inputVertexType, outputVertexType); + return FieldAlignVertex.createFieldAlignedVertex(vertex, mapping); + } + + public static RowEdge alignToEdgeSchema( + RowEdge edge, EdgeType inputEdgeType, EdgeType outputEdgeType) { + if (edge == null) { + return null; + } + if (inputEdgeType.equals(outputEdgeType)) { + return edge; + } + int[] mapping = getFieldMappingIndices(inputEdgeType, outputEdgeType); + return FieldAlignEdge.createFieldAlignedEdge(edge, mapping); + } - public static Path alignToPathSchema(Path path, PathType inputPathType, - PathType outputPathType) { - if (path == null) { - return null; - } - int[] mapping = getFieldMappingIndices(inputPathType, outputPathType); - - List pathNodes = new ArrayList<>(path.size()); - for (int i = 0; i < mapping.length; i++) { - int index = mapping[i]; - if (index >= 0) { - IType nodeType = inputPathType.getType(index); - IType outputType = outputPathType.getType(i); - Row node = path.getField(index, nodeType); - - if (nodeType instanceof VertexType) { - node = alignToVertexSchema((RowVertex) node, (VertexType) nodeType, - (VertexType) outputType); - } else if (nodeType instanceof EdgeType) { - node = alignToEdgeSchema((RowEdge) node, (EdgeType) nodeType, - (EdgeType) outputType); - } else { - throw new IllegalArgumentException("Illegal node type: " + nodeType); - } - pathNodes.add(node); - } else { - pathNodes.add(null); - } + public static Path alignToPathSchema(Path path, PathType inputPathType, PathType outputPathType) { + if (path == null) { + return null; + } + int[] mapping = getFieldMappingIndices(inputPathType, outputPathType); + + List pathNodes = new ArrayList<>(path.size()); + for (int i = 0; i < mapping.length; i++) { + int index = mapping[i]; + if (index >= 0) { + IType nodeType = inputPathType.getType(index); + IType outputType = outputPathType.getType(i); + Row node = path.getField(index, nodeType); + + if (nodeType instanceof VertexType) { + node = + alignToVertexSchema((RowVertex) node, (VertexType) nodeType, (VertexType) outputType); + } else if (nodeType instanceof EdgeType) { + node = alignToEdgeSchema((RowEdge) node, (EdgeType) nodeType, (EdgeType) outputType); + } else { + throw new IllegalArgumentException("Illegal node type: " + nodeType); } - return new DefaultPath(pathNodes); + pathNodes.add(node); + } else { + pathNodes.add(null); + } } + return new DefaultPath(pathNodes); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/StepFunctionUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/StepFunctionUtil.java index 56c1940e5..93630380f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/StepFunctionUtil.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/StepFunctionUtil.java @@ -20,19 +20,18 @@ package org.apache.geaflow.dsl.runtime.util; import java.util.stream.Collectors; + import org.apache.geaflow.common.utils.ArrayUtil; import org.apache.geaflow.dsl.runtime.function.graph.StepFunction; public class StepFunctionUtil { - public static int[] getRefPathIndices(StepFunction function) { - return ArrayUtil.toIntArray( - function.getExpressions() - .stream() - .flatMap(expression -> expression.getRefPathFieldIndices().stream()) - .distinct() - .sorted() - .collect(Collectors.toList()) - ); - } + public static int[] getRefPathIndices(StepFunction function) { + return ArrayUtil.toIntArray( + function.getExpressions().stream() + .flatMap(expression -> expression.getRefPathFieldIndices().stream()) + .distinct() + .sorted() + .collect(Collectors.toList())); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/VertexProjectorUtil.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/VertexProjectorUtil.java index 1e87ac6f3..97e54c088 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/VertexProjectorUtil.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/util/VertexProjectorUtil.java @@ -21,6 +21,7 @@ import java.util.*; import java.util.stream.Collectors; + import org.apache.calcite.rex.RexFieldAccess; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.dsl.common.binary.encoder.DefaultVertexEncoder; @@ -37,147 +38,152 @@ import org.apache.geaflow.dsl.runtime.function.table.ProjectFunction; import org.apache.geaflow.dsl.runtime.function.table.ProjectFunctionImpl; -/** - * Utility class for projecting vertices with field pruning. - */ +/** Utility class for projecting vertices with field pruning. */ public class VertexProjectorUtil { - private static final String ID = "id"; - private static final String LABEL = "~label"; - - private final Map projectFunctions; - private final Map> tableOutputTypes; - private final GraphSchema graphSchema; - private final Set fields; - private final String[] addingVertexFieldNames; - private final IType[] addingVertexFieldTypes; - - /** - * Constructs a VertexProjector with specified parameters. - * - * @param graphSchema The graph schema containing all vertex and edge type definitions - * @param fields The set of fields to be included in the projection, null means no filtering - * @param addingVertexFieldNames The names of additional fields to be added to vertices (e.g., global variables) - * @param addingVertexFieldTypes The types of additional fields corresponding to addingVertexFieldNames - */ - public VertexProjectorUtil(GraphSchema graphSchema, - Set fields, - String[] addingVertexFieldNames, - IType[] addingVertexFieldTypes) { - this.graphSchema = graphSchema; - this.fields = fields; - this.addingVertexFieldNames = addingVertexFieldNames != null ? addingVertexFieldNames : new String[0]; - this.addingVertexFieldTypes = addingVertexFieldTypes != null ? addingVertexFieldTypes : new IType[0]; - this.projectFunctions = new HashMap<>(); - this.tableOutputTypes = new HashMap<>(); + private static final String ID = "id"; + private static final String LABEL = "~label"; + + private final Map projectFunctions; + private final Map> tableOutputTypes; + private final GraphSchema graphSchema; + private final Set fields; + private final String[] addingVertexFieldNames; + private final IType[] addingVertexFieldTypes; + + /** + * Constructs a VertexProjector with specified parameters. + * + * @param graphSchema The graph schema containing all vertex and edge type definitions + * @param fields The set of fields to be included in the projection, null means no filtering + * @param addingVertexFieldNames The names of additional fields to be added to vertices (e.g., + * global variables) + * @param addingVertexFieldTypes The types of additional fields corresponding to + * addingVertexFieldNames + */ + public VertexProjectorUtil( + GraphSchema graphSchema, + Set fields, + String[] addingVertexFieldNames, + IType[] addingVertexFieldTypes) { + this.graphSchema = graphSchema; + this.fields = fields; + this.addingVertexFieldNames = + addingVertexFieldNames != null ? addingVertexFieldNames : new String[0]; + this.addingVertexFieldTypes = + addingVertexFieldTypes != null ? addingVertexFieldTypes : new IType[0]; + this.projectFunctions = new HashMap<>(); + this.tableOutputTypes = new HashMap<>(); + } + + /** + * Projects a vertex by filtering fields based on the required field set. + * + * @param vertex The input vertex to be projected + * @return The projected vertex with only required fields, or null if input is null + */ + public RowVertex projectVertex(RowVertex vertex) { + if (vertex == null) { + return null; } - /** - * Projects a vertex by filtering fields based on the required field set. - * - * @param vertex The input vertex to be projected - * @return The projected vertex with only required fields, or null if input is null - */ - public RowVertex projectVertex(RowVertex vertex) { - if (vertex == null) { - return null; - } - - // Handle the case of global variables - String compactedVertexLabel = vertex.getLabel(); - for (String addingName : addingVertexFieldNames) { - compactedVertexLabel += "_" + addingName; - } - - // Initialize - if (this.projectFunctions.get(compactedVertexLabel) == null) { - initializeProject(vertex, compactedVertexLabel, addingVertexFieldTypes, addingVertexFieldNames); - } - - // Utilize project functions to filter fields - ProjectFunction currentProjectFunction = this.projectFunctions.get(compactedVertexLabel); - ObjectRow projectVertex = (ObjectRow) currentProjectFunction.project(vertex); - RowVertex vertexDecoded = (RowVertex) projectVertex.getField(0, null); + // Handle the case of global variables + String compactedVertexLabel = vertex.getLabel(); + for (String addingName : addingVertexFieldNames) { + compactedVertexLabel += "_" + addingName; + } - VertexType vertexType = new VertexType(this.tableOutputTypes.get(compactedVertexLabel)); - VertexEncoder encoder = new DefaultVertexEncoder(vertexType); - return encoder.encode(vertexDecoded); + // Initialize + if (this.projectFunctions.get(compactedVertexLabel) == null) { + initializeProject( + vertex, compactedVertexLabel, addingVertexFieldTypes, addingVertexFieldNames); } - /** - * Initializes the project function for a given vertex label. - * - * @param vertex The vertex instance used to determine the schema and label - * @param compactedLabel The vertex label with additional field names appended for unique identification - * @param globalTypes The types of global variables to be added to the vertex - * @param globalNames The names of global variables to be added to the vertex - */ - private void initializeProject(RowVertex vertex, String compactedLabel, - IType[] globalTypes, String[] globalNames) { - List graphSchemaFieldList = graphSchema.getFields(); - List fieldsOfTable; - List tableOutputType = new ArrayList<>(); - - // Extract field names from RexFieldAccess list into a set - Set fieldNames = (this.fields == null) - ? Collections.emptySet() - : this.fields.stream() - .map(e -> e.getField().getName()) - .collect(Collectors.toSet()); - - List expressions = new ArrayList<>(); - String vertexLabel = vertex.getLabel(); - - for (TableField tableField : graphSchemaFieldList) { - if (vertexLabel.equals(tableField.getName())) { - List inputs = new ArrayList<>(); - fieldsOfTable = ((VertexType) tableField.getType()).getFields(); - - for (int i = 0; i < fieldsOfTable.size(); i++) { - TableField column = fieldsOfTable.get(i); - String columnName = column.getName(); - - // Normalize: convert fields like `personId` to `id` - if (columnName.startsWith(vertexLabel)) { - String suffix = columnName.substring(vertexLabel.length()); - if (!suffix.isEmpty()) { - suffix = Character.toLowerCase(suffix.charAt(0)) + suffix.substring(1); - columnName = suffix; - } - } - - if (fieldNames.contains(columnName) || columnName.equals(ID)) { - // Include a field if it's in fieldNames or is ID column - inputs.add(new FieldExpression(null, i, column.getType())); - tableOutputType.add(column); - } else if (columnName.equals(LABEL)) { - // Add vertex label for LABEL column - inputs.add(new LiteralExpression(vertex.getLabel(), column.getType())); - tableOutputType.add(column); - } else { - // Use null placeholder for excluded fields - inputs.add(new LiteralExpression(null, column.getType())); - tableOutputType.add(column); - } - } - - // Handle additional mapping when all global variables exist - if (globalNames.length > 0) { - for (int j = 0; j < globalNames.length; j++) { - int fieldIndex = j + fieldsOfTable.size(); - inputs.add(new FieldExpression(null, fieldIndex, globalTypes[j])); - tableOutputType.add(new TableField(globalNames[j], globalTypes[j])); - } - } - - expressions.add(new VertexConstructExpression(inputs, null, new VertexType(tableOutputType))); + // Utilize project functions to filter fields + ProjectFunction currentProjectFunction = this.projectFunctions.get(compactedVertexLabel); + ObjectRow projectVertex = (ObjectRow) currentProjectFunction.project(vertex); + RowVertex vertexDecoded = (RowVertex) projectVertex.getField(0, null); + + VertexType vertexType = new VertexType(this.tableOutputTypes.get(compactedVertexLabel)); + VertexEncoder encoder = new DefaultVertexEncoder(vertexType); + return encoder.encode(vertexDecoded); + } + + /** + * Initializes the project function for a given vertex label. + * + * @param vertex The vertex instance used to determine the schema and label + * @param compactedLabel The vertex label with additional field names appended for unique + * identification + * @param globalTypes The types of global variables to be added to the vertex + * @param globalNames The names of global variables to be added to the vertex + */ + private void initializeProject( + RowVertex vertex, String compactedLabel, IType[] globalTypes, String[] globalNames) { + List graphSchemaFieldList = graphSchema.getFields(); + List fieldsOfTable; + List tableOutputType = new ArrayList<>(); + + // Extract field names from RexFieldAccess list into a set + Set fieldNames = + (this.fields == null) + ? Collections.emptySet() + : this.fields.stream().map(e -> e.getField().getName()).collect(Collectors.toSet()); + + List expressions = new ArrayList<>(); + String vertexLabel = vertex.getLabel(); + + for (TableField tableField : graphSchemaFieldList) { + if (vertexLabel.equals(tableField.getName())) { + List inputs = new ArrayList<>(); + fieldsOfTable = ((VertexType) tableField.getType()).getFields(); + + for (int i = 0; i < fieldsOfTable.size(); i++) { + TableField column = fieldsOfTable.get(i); + String columnName = column.getName(); + + // Normalize: convert fields like `personId` to `id` + if (columnName.startsWith(vertexLabel)) { + String suffix = columnName.substring(vertexLabel.length()); + if (!suffix.isEmpty()) { + suffix = Character.toLowerCase(suffix.charAt(0)) + suffix.substring(1); + columnName = suffix; } + } + + if (fieldNames.contains(columnName) || columnName.equals(ID)) { + // Include a field if it's in fieldNames or is ID column + inputs.add(new FieldExpression(null, i, column.getType())); + tableOutputType.add(column); + } else if (columnName.equals(LABEL)) { + // Add vertex label for LABEL column + inputs.add(new LiteralExpression(vertex.getLabel(), column.getType())); + tableOutputType.add(column); + } else { + // Use null placeholder for excluded fields + inputs.add(new LiteralExpression(null, column.getType())); + tableOutputType.add(column); + } } - ProjectFunction projectFunction = new ProjectFunctionImpl(expressions); + // Handle additional mapping when all global variables exist + if (globalNames.length > 0) { + for (int j = 0; j < globalNames.length; j++) { + int fieldIndex = j + fieldsOfTable.size(); + inputs.add(new FieldExpression(null, fieldIndex, globalTypes[j])); + tableOutputType.add(new TableField(globalNames[j], globalTypes[j])); + } + } - // Store project functions - this.projectFunctions.put(compactedLabel, projectFunction); - this.tableOutputTypes.put(compactedLabel, tableOutputType); + expressions.add( + new VertexConstructExpression(inputs, null, new VertexType(tableOutputType))); + } } + + ProjectFunction projectFunction = new ProjectFunctionImpl(expressions); + + // Store project functions + this.projectFunctions.put(compactedLabel, projectFunction); + this.tableOutputTypes.put(compactedLabel, tableOutputType); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/CompilerTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/CompilerTest.java index b18d273ef..048fd2287 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/CompilerTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/CompilerTest.java @@ -21,8 +21,6 @@ import static org.apache.geaflow.common.config.keys.DSLConfigKeys.GEAFLOW_DSL_COMPILE_PHYSICAL_PLAN_ENABLE; -import com.google.common.collect.Sets; -import com.google.gson.Gson; import java.io.IOException; import java.nio.charset.Charset; import java.util.HashMap; @@ -30,6 +28,7 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; + import org.apache.commons.io.IOUtils; import org.apache.geaflow.dsl.common.compile.CompileContext; import org.apache.geaflow.dsl.common.compile.CompileResult; @@ -39,97 +38,104 @@ import org.testng.Assert; import org.testng.annotations.Test; -public class CompilerTest { +import com.google.common.collect.Sets; +import com.google.gson.Gson; - @Test - public void testCompile() throws IOException { - QueryCompiler compiler = new QueryClient(); - String script = IOUtils.resourceToString("/query/compile.sql", Charset.defaultCharset()); - CompileContext context = new CompileContext(); - Map config = new HashMap<>(); - config.put(GEAFLOW_DSL_COMPILE_PHYSICAL_PLAN_ENABLE.getKey(), Boolean.TRUE.toString()); - context.setConfig(config); - context.setParallelisms(new HashMap<>()); +public class CompilerTest { - CompileResult result = compiler.compile(script, context); - Assert.assertEquals(result.getSourceGraphs().stream().map(g -> g.getGraphName()).collect( - Collectors.toSet()), Sets.newHashSet("modern")); - Assert.assertEquals(result.getTargetGraphs(), Sets.newHashSet()); + @Test + public void testCompile() throws IOException { + QueryCompiler compiler = new QueryClient(); + String script = IOUtils.resourceToString("/query/compile.sql", Charset.defaultCharset()); + CompileContext context = new CompileContext(); + Map config = new HashMap<>(); + config.put(GEAFLOW_DSL_COMPILE_PHYSICAL_PLAN_ENABLE.getKey(), Boolean.TRUE.toString()); + context.setConfig(config); + context.setParallelisms(new HashMap<>()); - Assert.assertEquals(result.getSourceTables(), Sets.newHashSet()); - Assert.assertEquals(result.getTargetTables().stream().map(t -> t.getTableName()).collect( - Collectors.toSet()), Sets.newHashSet("tbl_result")); - Gson gson = new Gson(); - Assert.assertEquals(gson.toJson(result.getPhysicPlan()), - "{\"vertices\":{\"1\":{\"vertexType\":\"source\",\"id\":\"1\",\"parallelism\":1,\"parents\":[]," - + "\"innerPlan\":{\"vertices\":{\"1-1\":{\"id\":\"1-1\",\"parallelism\":1," - + "\"operator\":\"WindowSourceOperator\",\"operatorName\":\"1\",\"parents\":[]}," - + "\"1-4\":{\"id\":\"1-4\",\"parallelism\":1,\"operator\":\"KeySelectorOperator\"," - + "\"operatorName\":\"4\",\"parents\":[{\"id\":\"1-1\"}]}}}},\"2\":{\"vertexType\":\"source\"," - + "\"id\":\"2\",\"parallelism\":1,\"parents\":[]," - + "\"innerPlan\":{\"vertices\":{\"2-2\":{\"id\":\"2-2\",\"parallelism\":1," - + "\"operator\":\"WindowSourceOperator\",\"operatorName\":\"2\",\"parents\":[]}," - + "\"2-5\":{\"id\":\"2-5\",\"parallelism\":1,\"operator\":\"KeySelectorOperator\"," - + "\"operatorName\":\"5\",\"parents\":[{\"id\":\"2-2\"}]}}}}," - + "\"3\":{\"vertexType\":\"vertex_centric\",\"id\":\"3\",\"parallelism\":2," - + "\"operator\":\"StaticGraphVertexCentricTraversalAllOp\"," - + "\"operatorName\":\"GeaFlowStaticVCTraversal\",\"parents\":[{\"id\":\"3\"," - + "\"partitionType\":\"key\"},{\"id\":\"2\",\"partitionType\":\"key\"},{\"id\":\"1\"," - + "\"partitionType\":\"key\"}]},\"6\":{\"vertexType\":\"process\",\"id\":\"6\",\"parallelism\":2," - + "\"parents\":[{\"id\":\"3\",\"partitionType\":\"forward\"}]," - + "\"innerPlan\":{\"vertices\":{\"6-7\":{\"id\":\"6-7\",\"parallelism\":2,\"operator\":\"MapOperator\"," - + "\"operatorName\":\"Project-1\",\"parents\":[{\"id\":\"6-6\"}]},\"6-8\":{\"id\":\"6-8\"," - + "\"parallelism\":2,\"operator\":\"SinkOperator\",\"operatorName\":\"TableSink-2\"," - + "\"parents\":[{\"id\":\"6-7\"}]},\"6-6\":{\"id\":\"6-6\",\"parallelism\":2," - + "\"operator\":\"FlatMapOperator\",\"operatorName\":\"TraversalResponseToRow-0\",\"parents\":[]}}}}}}" - ); - } + CompileResult result = compiler.compile(script, context); + Assert.assertEquals( + result.getSourceGraphs().stream().map(g -> g.getGraphName()).collect(Collectors.toSet()), + Sets.newHashSet("modern")); + Assert.assertEquals(result.getTargetGraphs(), Sets.newHashSet()); + Assert.assertEquals(result.getSourceTables(), Sets.newHashSet()); + Assert.assertEquals( + result.getTargetTables().stream().map(t -> t.getTableName()).collect(Collectors.toSet()), + Sets.newHashSet("tbl_result")); + Gson gson = new Gson(); + Assert.assertEquals( + gson.toJson(result.getPhysicPlan()), + "{\"vertices\":{\"1\":{\"vertexType\":\"source\",\"id\":\"1\",\"parallelism\":1,\"parents\":[]," + + "\"innerPlan\":{\"vertices\":{\"1-1\":{\"id\":\"1-1\",\"parallelism\":1," + + "\"operator\":\"WindowSourceOperator\",\"operatorName\":\"1\",\"parents\":[]}," + + "\"1-4\":{\"id\":\"1-4\",\"parallelism\":1,\"operator\":\"KeySelectorOperator\"," + + "\"operatorName\":\"4\",\"parents\":[{\"id\":\"1-1\"}]}}}},\"2\":{\"vertexType\":\"source\"," + + "\"id\":\"2\",\"parallelism\":1,\"parents\":[]," + + "\"innerPlan\":{\"vertices\":{\"2-2\":{\"id\":\"2-2\",\"parallelism\":1," + + "\"operator\":\"WindowSourceOperator\",\"operatorName\":\"2\",\"parents\":[]}," + + "\"2-5\":{\"id\":\"2-5\",\"parallelism\":1,\"operator\":\"KeySelectorOperator\"," + + "\"operatorName\":\"5\",\"parents\":[{\"id\":\"2-2\"}]}}}}," + + "\"3\":{\"vertexType\":\"vertex_centric\",\"id\":\"3\",\"parallelism\":2," + + "\"operator\":\"StaticGraphVertexCentricTraversalAllOp\"," + + "\"operatorName\":\"GeaFlowStaticVCTraversal\",\"parents\":[{\"id\":\"3\"," + + "\"partitionType\":\"key\"},{\"id\":\"2\",\"partitionType\":\"key\"},{\"id\":\"1\"," + + "\"partitionType\":\"key\"}]},\"6\":{\"vertexType\":\"process\",\"id\":\"6\",\"parallelism\":2," + + "\"parents\":[{\"id\":\"3\",\"partitionType\":\"forward\"}]," + + "\"innerPlan\":{\"vertices\":{\"6-7\":{\"id\":\"6-7\",\"parallelism\":2,\"operator\":\"MapOperator\"," + + "\"operatorName\":\"Project-1\",\"parents\":[{\"id\":\"6-6\"}]},\"6-8\":{\"id\":\"6-8\"," + + "\"parallelism\":2,\"operator\":\"SinkOperator\",\"operatorName\":\"TableSink-2\"," + + "\"parents\":[{\"id\":\"6-7\"}]},\"6-6\":{\"id\":\"6-6\",\"parallelism\":2," + + "\"operator\":\"FlatMapOperator\",\"operatorName\":\"TraversalResponseToRow-0\",\"parents\":[]}}}}}}"); + } - @Test - public void testFindUnResolvedFunctions() { - QueryCompiler compiler = new QueryClient(); - CompileContext context = new CompileContext(); + @Test + public void testFindUnResolvedFunctions() { + QueryCompiler compiler = new QueryClient(); + CompileContext context = new CompileContext(); - String script = "create function f0 as 'com.antgroup.udf.TestUdf';" + String script = + "create function f0 as 'com.antgroup.udf.TestUdf';" + "select f1(name), substr(name, 1, 10) from t0;" + "use instance instance0;" + "select f2(id), max(id) from t0;" + "create view v0(c0, c1) as select f3(id) as c0, c1 from t1"; - Set unResolvedFunctions = compiler.getUnResolvedFunctions(script, context); - Assert.assertEquals(unResolvedFunctions.size(), 4); - List functions = - unResolvedFunctions.stream().map(FunctionInfo::toString).collect(Collectors.toList()); - Assert.assertEquals(functions.get(0), "instance0.f3"); - Assert.assertEquals(functions.get(1), "instance0.f2"); - Assert.assertEquals(functions.get(2), "default.f1"); - Assert.assertEquals(functions.get(3), "default.f0"); - } + Set unResolvedFunctions = compiler.getUnResolvedFunctions(script, context); + Assert.assertEquals(unResolvedFunctions.size(), 4); + List functions = + unResolvedFunctions.stream().map(FunctionInfo::toString).collect(Collectors.toList()); + Assert.assertEquals(functions.get(0), "instance0.f3"); + Assert.assertEquals(functions.get(1), "instance0.f2"); + Assert.assertEquals(functions.get(2), "default.f1"); + Assert.assertEquals(functions.get(3), "default.f0"); + } - @Test - public void testFindTables() { - QueryCompiler compiler = new QueryClient(); - CompileContext context = new CompileContext(); + @Test + public void testFindTables() { + QueryCompiler compiler = new QueryClient(); + CompileContext context = new CompileContext(); - String script = "insert into t1(id,name) select 1,\"tom\";\n" + String script = + "insert into t1(id,name) select 1,\"tom\";\n" + "insert into t2 select id,name from t1;\n" + "insert into t4 select * from t3;"; - Set tables = compiler.getUnResolvedTables(script, context); - Assert.assertEquals(tables.size(), 4); - Assert.assertTrue(tables.contains(new TableInfo("default", "t1"))); - Assert.assertTrue(tables.contains(new TableInfo("default", "t2"))); - Assert.assertTrue(tables.contains(new TableInfo("default", "t3"))); - Assert.assertTrue(tables.contains(new TableInfo("default", "t4"))); - } + Set tables = compiler.getUnResolvedTables(script, context); + Assert.assertEquals(tables.size(), 4); + Assert.assertTrue(tables.contains(new TableInfo("default", "t1"))); + Assert.assertTrue(tables.contains(new TableInfo("default", "t2"))); + Assert.assertTrue(tables.contains(new TableInfo("default", "t3"))); + Assert.assertTrue(tables.contains(new TableInfo("default", "t4"))); + } - @Test - public void testFindTables2() { - QueryCompiler compiler = new QueryClient(); - CompileContext context = new CompileContext(); + @Test + public void testFindTables2() { + QueryCompiler compiler = new QueryClient(); + CompileContext context = new CompileContext(); - String script = "CREATE GRAPH dy_modern (\n" + String script = + "CREATE GRAPH dy_modern (\n" + "\tVertex person (\n" + "\t id bigint ID,\n" + "\t name varchar,\n" @@ -190,47 +196,48 @@ public void testFindTables2() { + "\n" + "INSERT INTO t1 select * from t2;\n"; - Set tables = compiler.getUnResolvedTables(script, context); - Assert.assertEquals(tables.size(), 2); - Assert.assertTrue(tables.contains(new TableInfo("default", "t1"))); - Assert.assertTrue(tables.contains(new TableInfo("default", "t2"))); - Assert.assertFalse(tables.contains(new TableInfo("default", "tbl_result"))); - Assert.assertFalse(tables.contains(new TableInfo("default", "dy_modern"))); - } + Set tables = compiler.getUnResolvedTables(script, context); + Assert.assertEquals(tables.size(), 2); + Assert.assertTrue(tables.contains(new TableInfo("default", "t1"))); + Assert.assertTrue(tables.contains(new TableInfo("default", "t2"))); + Assert.assertFalse(tables.contains(new TableInfo("default", "tbl_result"))); + Assert.assertFalse(tables.contains(new TableInfo("default", "dy_modern"))); + } - @Test - public void testFindTables3() { - QueryCompiler compiler = new QueryClient(); - CompileContext context = new CompileContext(); + @Test + public void testFindTables3() { + QueryCompiler compiler = new QueryClient(); + CompileContext context = new CompileContext(); - String script = "insert into t1(id) select * from (select name from t2);"; + String script = "insert into t1(id) select * from (select name from t2);"; - Set tables = compiler.getUnResolvedTables(script, context); - Assert.assertEquals(tables.size(), 2); - Assert.assertTrue(tables.contains(new TableInfo("default", "t1"))); - Assert.assertTrue(tables.contains(new TableInfo("default", "t2"))); - } + Set tables = compiler.getUnResolvedTables(script, context); + Assert.assertEquals(tables.size(), 2); + Assert.assertTrue(tables.contains(new TableInfo("default", "t1"))); + Assert.assertTrue(tables.contains(new TableInfo("default", "t2"))); + } - @Test - public void testFindTables4() { - QueryCompiler compiler = new QueryClient(); - CompileContext context = new CompileContext(); + @Test + public void testFindTables4() { + QueryCompiler compiler = new QueryClient(); + CompileContext context = new CompileContext(); - String script = "insert into t1(id) select t2.id from t2 join t3 on t2.id = t3.id;"; + String script = "insert into t1(id) select t2.id from t2 join t3 on t2.id = t3.id;"; - Set tables = compiler.getUnResolvedTables(script, context); - Assert.assertEquals(tables.size(), 3); - Assert.assertTrue(tables.contains(new TableInfo("default", "t1"))); - Assert.assertTrue(tables.contains(new TableInfo("default", "t2"))); - Assert.assertTrue(tables.contains(new TableInfo("default", "t3"))); - } + Set tables = compiler.getUnResolvedTables(script, context); + Assert.assertEquals(tables.size(), 3); + Assert.assertTrue(tables.contains(new TableInfo("default", "t1"))); + Assert.assertTrue(tables.contains(new TableInfo("default", "t2"))); + Assert.assertTrue(tables.contains(new TableInfo("default", "t3"))); + } - @Test - public void testFindTables5() { - QueryCompiler compiler = new QueryClient(); - CompileContext context = new CompileContext(); + @Test + public void testFindTables5() { + QueryCompiler compiler = new QueryClient(); + CompileContext context = new CompileContext(); - String script = "CREATE TABLE tbl_result (\n" + String script = + "CREATE TABLE tbl_result (\n" + " a_id bigint,\n" + " b_id bigint,\n" + " weight double\n" @@ -254,10 +261,10 @@ public void testFindTables5() { + " RETURN a.id as a_id, e.weight as weight, b.id as b_id\n" + ")"; - Set tables = compiler.getUnResolvedTables(script, context); - Assert.assertEquals(tables.size(), 2); - Assert.assertTrue(tables.contains(new TableInfo("default", "p"))); - Assert.assertTrue(tables.contains(new TableInfo("default", "t2"))); - Assert.assertFalse(tables.contains(new TableInfo("default", "tbl_result"))); - } + Set tables = compiler.getUnResolvedTables(script, context); + Assert.assertEquals(tables.size(), 2); + Assert.assertTrue(tables.contains(new TableInfo("default", "p"))); + Assert.assertTrue(tables.contains(new TableInfo("default", "t2"))); + Assert.assertFalse(tables.contains(new TableInfo("default", "tbl_result"))); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/FieldAlignTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/FieldAlignTest.java index 473a5d79f..312f4224a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/FieldAlignTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/FieldAlignTest.java @@ -34,69 +34,69 @@ public class FieldAlignTest { - @Test - public void testFieldAlignEdge() { - RowEdge edge = new IntEdge(1, 2); - edge.setValue(ObjectRow.create(1, 2, 3)); - edge.setDirect(EdgeDirection.OUT); - - FieldAlignEdge alignEdge = new FieldAlignEdge(edge, new int[]{0, 1, 2, 3, -1, 4, 5}); - - Assert.assertEquals(alignEdge.getSrcId(), edge.getSrcId()); - Assert.assertEquals(alignEdge.getTargetId(), edge.getTargetId()); - Assert.assertEquals(alignEdge.getDirect(), edge.getDirect()); - - Assert.assertEquals(alignEdge.getField(3, Types.INTEGER), 1); - Assert.assertNull(alignEdge.getField(4, Types.INTEGER)); - Assert.assertEquals(alignEdge.getField(6, Types.INTEGER), 3); - - alignEdge.setLabel("l0"); - Assert.assertEquals(alignEdge.getLabel(), "l0"); - alignEdge.setBinaryLabel(BinaryString.fromString("l1")); - Assert.assertEquals(alignEdge.getBinaryLabel(), BinaryString.fromString("l1")); - - alignEdge = (FieldAlignEdge) alignEdge.withValue(ObjectRow.EMPTY); - Assert.assertEquals(alignEdge.getValue(), ObjectRow.EMPTY); - - alignEdge = (FieldAlignEdge) alignEdge.withDirection(EdgeDirection.IN); - Assert.assertEquals(alignEdge.getDirect(), EdgeDirection.IN); - - alignEdge.setSrcId(2); - alignEdge.setTargetId(1); - alignEdge.setDirect(EdgeDirection.OUT); - Assert.assertEquals(alignEdge.getSrcId(), 2); - Assert.assertEquals(alignEdge.getTargetId(), 1); - Assert.assertEquals(alignEdge.getDirect(), EdgeDirection.OUT); - - alignEdge = (FieldAlignEdge) alignEdge.identityReverse(); - Assert.assertEquals(alignEdge.getDirect(), EdgeDirection.IN); - Assert.assertEquals(alignEdge.getSrcId(), 1); - Assert.assertEquals(alignEdge.getTargetId(), 2); - } - - @Test - public void testAlignVertex() { - RowVertex vertex = new IntVertex(1); - vertex.setValue(ObjectRow.create(1, 2, 3)); - - FieldAlignVertex alignVertex = new FieldAlignVertex(vertex, new int[]{-1, 2, 1, 0}); - Assert.assertEquals(alignVertex.getId(), vertex.getId()); - Assert.assertNull(alignVertex.getField(0, Types.INTEGER)); - Assert.assertEquals(alignVertex.getField(1, Types.INTEGER), 1); - - alignVertex.setId(2); - Assert.assertEquals(alignVertex.getId(), 2); - alignVertex.setLabel("l0"); - Assert.assertEquals(alignVertex.getLabel(), "l0"); - alignVertex.setBinaryLabel(BinaryString.fromString("l1")); - Assert.assertEquals(alignVertex.getBinaryLabel(), BinaryString.fromString("l1")); - - alignVertex.setValue(ObjectRow.EMPTY); - Assert.assertEquals(alignVertex.getValue(), ObjectRow.EMPTY); - alignVertex = (FieldAlignVertex) alignVertex.withValue(ObjectRow.EMPTY); - Assert.assertEquals(alignVertex.getValue(), ObjectRow.EMPTY); - - alignVertex = (FieldAlignVertex) alignVertex.withLabel("l2"); - Assert.assertEquals(alignVertex.getBinaryLabel(), BinaryString.fromString("l2")); - } + @Test + public void testFieldAlignEdge() { + RowEdge edge = new IntEdge(1, 2); + edge.setValue(ObjectRow.create(1, 2, 3)); + edge.setDirect(EdgeDirection.OUT); + + FieldAlignEdge alignEdge = new FieldAlignEdge(edge, new int[] {0, 1, 2, 3, -1, 4, 5}); + + Assert.assertEquals(alignEdge.getSrcId(), edge.getSrcId()); + Assert.assertEquals(alignEdge.getTargetId(), edge.getTargetId()); + Assert.assertEquals(alignEdge.getDirect(), edge.getDirect()); + + Assert.assertEquals(alignEdge.getField(3, Types.INTEGER), 1); + Assert.assertNull(alignEdge.getField(4, Types.INTEGER)); + Assert.assertEquals(alignEdge.getField(6, Types.INTEGER), 3); + + alignEdge.setLabel("l0"); + Assert.assertEquals(alignEdge.getLabel(), "l0"); + alignEdge.setBinaryLabel(BinaryString.fromString("l1")); + Assert.assertEquals(alignEdge.getBinaryLabel(), BinaryString.fromString("l1")); + + alignEdge = (FieldAlignEdge) alignEdge.withValue(ObjectRow.EMPTY); + Assert.assertEquals(alignEdge.getValue(), ObjectRow.EMPTY); + + alignEdge = (FieldAlignEdge) alignEdge.withDirection(EdgeDirection.IN); + Assert.assertEquals(alignEdge.getDirect(), EdgeDirection.IN); + + alignEdge.setSrcId(2); + alignEdge.setTargetId(1); + alignEdge.setDirect(EdgeDirection.OUT); + Assert.assertEquals(alignEdge.getSrcId(), 2); + Assert.assertEquals(alignEdge.getTargetId(), 1); + Assert.assertEquals(alignEdge.getDirect(), EdgeDirection.OUT); + + alignEdge = (FieldAlignEdge) alignEdge.identityReverse(); + Assert.assertEquals(alignEdge.getDirect(), EdgeDirection.IN); + Assert.assertEquals(alignEdge.getSrcId(), 1); + Assert.assertEquals(alignEdge.getTargetId(), 2); + } + + @Test + public void testAlignVertex() { + RowVertex vertex = new IntVertex(1); + vertex.setValue(ObjectRow.create(1, 2, 3)); + + FieldAlignVertex alignVertex = new FieldAlignVertex(vertex, new int[] {-1, 2, 1, 0}); + Assert.assertEquals(alignVertex.getId(), vertex.getId()); + Assert.assertNull(alignVertex.getField(0, Types.INTEGER)); + Assert.assertEquals(alignVertex.getField(1, Types.INTEGER), 1); + + alignVertex.setId(2); + Assert.assertEquals(alignVertex.getId(), 2); + alignVertex.setLabel("l0"); + Assert.assertEquals(alignVertex.getLabel(), "l0"); + alignVertex.setBinaryLabel(BinaryString.fromString("l1")); + Assert.assertEquals(alignVertex.getBinaryLabel(), BinaryString.fromString("l1")); + + alignVertex.setValue(ObjectRow.EMPTY); + Assert.assertEquals(alignVertex.getValue(), ObjectRow.EMPTY); + alignVertex = (FieldAlignVertex) alignVertex.withValue(ObjectRow.EMPTY); + Assert.assertEquals(alignVertex.getValue(), ObjectRow.EMPTY); + + alignVertex = (FieldAlignVertex) alignVertex.withLabel("l2"); + Assert.assertEquals(alignVertex.getBinaryLabel(), BinaryString.fromString("l2")); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/benchmark/OrderMemoryBenchmark.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/benchmark/OrderMemoryBenchmark.java index 94a8a6fd6..def4cfb5a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/benchmark/OrderMemoryBenchmark.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/benchmark/OrderMemoryBenchmark.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.Random; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.type.Types; import org.apache.geaflow.dsl.common.data.Row; import org.apache.geaflow.dsl.common.data.impl.ObjectRow; @@ -51,128 +52,129 @@ @Measurement(iterations = 10) @Fork(1) public class OrderMemoryBenchmark { - - @Param({"10000", "100000", "1000000"}) - private int dataSize; - - @Param({"100", "1000", "10000"}) - private int topN; - - private OrderByFunction orderByFunction; - private List testData; - private SortInfo sortInfo = new SortInfo(); - - @Setup(Level.Trial) - public void setup() { - // Create sort expression - Expression expression = new FieldExpression(0, Types.INTEGER); - - OrderByField orderByField = new OrderByField(); - orderByField.expression = expression; - orderByField.order = ORDER.ASC; - - List orderByFields = new ArrayList<>(1); - orderByFields.add(orderByField); - - sortInfo.orderByFields = orderByFields; - sortInfo.fetch = topN; - - // Generate test data - testData = generateTestData(); + + @Param({"10000", "100000", "1000000"}) + private int dataSize; + + @Param({"100", "1000", "10000"}) + private int topN; + + private OrderByFunction orderByFunction; + private List testData; + private SortInfo sortInfo = new SortInfo(); + + @Setup(Level.Trial) + public void setup() { + // Create sort expression + Expression expression = new FieldExpression(0, Types.INTEGER); + + OrderByField orderByField = new OrderByField(); + orderByField.expression = expression; + orderByField.order = ORDER.ASC; + + List orderByFields = new ArrayList<>(1); + orderByFields.add(orderByField); + + sortInfo.orderByFields = orderByFields; + sortInfo.fetch = topN; + + // Generate test data + testData = generateTestData(); + } + + private List generateTestData() { + List data = new ArrayList<>(dataSize); + Random random = new Random(42); + + for (int i = 0; i < dataSize; i++) { + Object[] values = {random.nextInt(dataSize * 10)}; + data.add(ObjectRow.create(values)); + } + + return data; + } + + @Benchmark + public Iterable benchmarkHeapSortMemory() { + // Create a copy of the input data to avoid state pollution + List inputData = new ArrayList<>(testData); + + orderByFunction = new OrderByHeapSort(sortInfo); + orderByFunction.open(null); + + for (int i = 0; i < dataSize; i++) { + orderByFunction.process(inputData.get(i)); } - - private List generateTestData() { - List data = new ArrayList<>(dataSize); - Random random = new Random(42); - - for (int i = 0; i < dataSize; i++) { - Object[] values = {random.nextInt(dataSize * 10)}; - data.add(ObjectRow.create(values)); - } - - return data; + + // Perform Top-N sorting + return orderByFunction.finish(); + } + + @Benchmark + public Iterable benchmarkRadixSortMemory() { + List inputData = new ArrayList<>(testData); + + orderByFunction = new OrderByRadixSort(sortInfo); + orderByFunction.open(null); + + for (int i = 0; i < dataSize; i++) { + orderByFunction.process(inputData.get(i)); } - - @Benchmark - public Iterable benchmarkHeapSortMemory() { - // Create a copy of the input data to avoid state pollution - List inputData = new ArrayList<>(testData); - - orderByFunction = new OrderByHeapSort(sortInfo); - orderByFunction.open(null); - - for (int i = 0; i < dataSize; i++) { - orderByFunction.process(inputData.get(i)); - } - - // Perform Top-N sorting - return orderByFunction.finish(); + + return orderByFunction.finish(); + } + + @Benchmark + public Iterable benchmarkTimSortMemory() { + List inputData = new ArrayList<>(testData); + + orderByFunction = new OrderByTimSort(sortInfo); + orderByFunction.open(null); + + for (int i = 0; i < dataSize; i++) { + orderByFunction.process(inputData.get(i)); } - @Benchmark - public Iterable benchmarkRadixSortMemory() { - List inputData = new ArrayList<>(testData); - - orderByFunction = new OrderByRadixSort(sortInfo); - orderByFunction.open(null); - - for (int i = 0; i < dataSize; i++) { - orderByFunction.process(inputData.get(i)); - } - - return orderByFunction.finish(); + return orderByFunction.finish(); + } + + public static void main(String[] args) throws RunnerException { + // Run a verification first + OrderMemoryBenchmark benchmark = new OrderMemoryBenchmark(); + benchmark.dataSize = 10; + benchmark.topN = 10; + benchmark.setup(); + Iterable heapResults = benchmark.benchmarkHeapSortMemory(); + System.out.println("===HEAP_SORT==="); + for (Row result : heapResults) { + System.out.print(result); } - - @Benchmark - public Iterable benchmarkTimSortMemory() { - List inputData = new ArrayList<>(testData); - - orderByFunction = new OrderByTimSort(sortInfo); - orderByFunction.open(null); - - for (int i = 0; i < dataSize; i++) { - orderByFunction.process(inputData.get(i)); - } - - return orderByFunction.finish(); + System.out.println(); + System.out.println("===RADIX_SORT==="); + Iterable radixResults = benchmark.benchmarkRadixSortMemory(); + for (Row result : radixResults) { + System.out.print(result); } - - public static void main(String[] args) throws RunnerException { - // Run a verification first - OrderMemoryBenchmark benchmark = new OrderMemoryBenchmark(); - benchmark.dataSize = 10; - benchmark.topN = 10; - benchmark.setup(); - Iterable heapResults = benchmark.benchmarkHeapSortMemory(); - System.out.println("===HEAP_SORT==="); - for (Row result: heapResults) { - System.out.print(result); - } - System.out.println(); - System.out.println("===RADIX_SORT==="); - Iterable radixResults = benchmark.benchmarkRadixSortMemory(); - for (Row result: radixResults) { - System.out.print(result); - } - System.out.println(); - System.out.println("===TIM_SORT==="); - Iterable timResults = benchmark.benchmarkTimSortMemory(); - for (Row result: timResults) { - System.out.print(result); - } - System.out.println(); - - String timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss").format(new Date()); - String resultFile = "target/benchmark-results/memory-" + timestamp + ".json"; - - Options opt = new OptionsBuilder() + System.out.println(); + System.out.println("===TIM_SORT==="); + Iterable timResults = benchmark.benchmarkTimSortMemory(); + for (Row result : timResults) { + System.out.print(result); + } + System.out.println(); + + String timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss").format(new Date()); + String resultFile = "target/benchmark-results/memory-" + timestamp + ".json"; + + Options opt = + new OptionsBuilder() .include(OrderMemoryBenchmark.class.getSimpleName()) .addProfiler(GCProfiler.class) .jvmArgs("-Xms2g", "-Xmx4g") .result(resultFile) .resultFormat(ResultFormatType.JSON) .build(); - - new Runner(opt).run(); - } -} \ No newline at end of file + + new Runner(opt).run(); + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/benchmark/OrderTimeBenchmark.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/benchmark/OrderTimeBenchmark.java index 96bdf77e8..278733260 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/benchmark/OrderTimeBenchmark.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/benchmark/OrderTimeBenchmark.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.Random; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.common.type.Types; @@ -53,255 +54,255 @@ @Measurement(iterations = 5, time = 2, timeUnit = TimeUnit.SECONDS) @Fork(2) public class OrderTimeBenchmark { - private static final String CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; - - @Param({"1000"}) - private int dataSize; - - @Param({"1000"}) - private int topN; - - @Param({"RANDOM", "SORTED", "REVERSE_SORTED", "PARTIAL_SORTED", "DUPLICATED"}) - private String dataPattern; - - @Param({"STRING"}) - private String dataType; - - private OrderByFunction orderByFunction; - private List testData; - private SortInfo sortInfo = new SortInfo(); - - @Setup(Level.Trial) - public void setupBenchmark() { - // Create sort expression - setupOrderByExpressions(); - - // Generate test data - testData = generateTestData(); - } - - private void setupOrderByExpressions() { - IType fieldType; - switch (dataType) { - case "INTEGER": - fieldType = Types.INTEGER; - break; - case "DOUBLE": - fieldType = Types.DOUBLE; - break; - case "STRING": - fieldType = Types.BINARY_STRING; - break; - default: - fieldType = Types.INTEGER; - } - - List orderByFields = new ArrayList<>(2); - - // Primary sort field - Expression expression1 = new FieldExpression(0, fieldType); - OrderByField orderByField1 = new OrderByField(); - orderByField1.expression = expression1; - orderByField1.order = ORDER.ASC; - orderByFields.add(orderByField1); - - // Add a secondary sort field (for testing multi-field sorting performance) - Expression expression2 = new FieldExpression(1, Types.INTEGER); - OrderByField orderByField2 = new OrderByField(); - orderByField2.expression = expression2; - orderByField2.order = ORDER.ASC; - orderByFields.add(orderByField2); - - sortInfo.orderByFields = orderByFields; - sortInfo.fetch = topN; + private static final String CHARS = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + + @Param({"1000"}) + private int dataSize; + + @Param({"1000"}) + private int topN; + + @Param({"RANDOM", "SORTED", "REVERSE_SORTED", "PARTIAL_SORTED", "DUPLICATED"}) + private String dataPattern; + + @Param({"STRING"}) + private String dataType; + + private OrderByFunction orderByFunction; + private List testData; + private SortInfo sortInfo = new SortInfo(); + + @Setup(Level.Trial) + public void setupBenchmark() { + // Create sort expression + setupOrderByExpressions(); + + // Generate test data + testData = generateTestData(); + } + + private void setupOrderByExpressions() { + IType fieldType; + switch (dataType) { + case "INTEGER": + fieldType = Types.INTEGER; + break; + case "DOUBLE": + fieldType = Types.DOUBLE; + break; + case "STRING": + fieldType = Types.BINARY_STRING; + break; + default: + fieldType = Types.INTEGER; } - - private List generateTestData() { - List data = new ArrayList<>(dataSize); - Random random = new Random(42); // Fixed seeds ensure reproducibility - - for (int i = 0; i < dataSize; i++) { - Object[] values = new Object[2]; - - // Generate the value of the primary sort field - switch (dataType) { - case "INTEGER": - values[0] = generateIntegerValue(i, random); - break; - case "DOUBLE": - values[0] = generateDoubleValue(i, random); - break; - case "STRING": - values[0] = BinaryString.fromString(generateStringValue(i, random)); - break; - default: - return data; - } - - // Generate the value of the secondary sort field - values[1] = random.nextInt(100); - - data.add(ObjectRow.create(values)); - } - - return data; + + List orderByFields = new ArrayList<>(2); + + // Primary sort field + Expression expression1 = new FieldExpression(0, fieldType); + OrderByField orderByField1 = new OrderByField(); + orderByField1.expression = expression1; + orderByField1.order = ORDER.ASC; + orderByFields.add(orderByField1); + + // Add a secondary sort field (for testing multi-field sorting performance) + Expression expression2 = new FieldExpression(1, Types.INTEGER); + OrderByField orderByField2 = new OrderByField(); + orderByField2.expression = expression2; + orderByField2.order = ORDER.ASC; + orderByFields.add(orderByField2); + + sortInfo.orderByFields = orderByFields; + sortInfo.fetch = topN; + } + + private List generateTestData() { + List data = new ArrayList<>(dataSize); + Random random = new Random(42); // Fixed seeds ensure reproducibility + + for (int i = 0; i < dataSize; i++) { + Object[] values = new Object[2]; + + // Generate the value of the primary sort field + switch (dataType) { + case "INTEGER": + values[0] = generateIntegerValue(i, random); + break; + case "DOUBLE": + values[0] = generateDoubleValue(i, random); + break; + case "STRING": + values[0] = BinaryString.fromString(generateStringValue(i, random)); + break; + default: + return data; + } + + // Generate the value of the secondary sort field + values[1] = random.nextInt(100); + + data.add(ObjectRow.create(values)); } - - private Integer generateIntegerValue(int index, Random random) { - switch (dataPattern) { - case "RANDOM": - return random.nextInt(dataSize * 10); - case "SORTED": - return index; - case "REVERSE_SORTED": - return dataSize - index; - case "PARTIAL_SORTED": - // 70% ordered, 30% random - return index < dataSize * 0.7 ? index : random.nextInt(dataSize); - case "DUPLICATED": - // Generate a large number of repeated values - return random.nextInt(dataSize / 10); - default: - return random.nextInt(dataSize); - } + + return data; + } + + private Integer generateIntegerValue(int index, Random random) { + switch (dataPattern) { + case "RANDOM": + return random.nextInt(dataSize * 10); + case "SORTED": + return index; + case "REVERSE_SORTED": + return dataSize - index; + case "PARTIAL_SORTED": + // 70% ordered, 30% random + return index < dataSize * 0.7 ? index : random.nextInt(dataSize); + case "DUPLICATED": + // Generate a large number of repeated values + return random.nextInt(dataSize / 10); + default: + return random.nextInt(dataSize); } - - private Double generateDoubleValue(int index, Random random) { - switch (dataPattern) { - case "RANDOM": - return random.nextDouble() * dataSize * 10; - case "SORTED": - return (double) index + random.nextDouble(); - case "REVERSE_SORTED": - return (double) (dataSize - index) + random.nextDouble(); - case "PARTIAL_SORTED": - return index < dataSize * 0.7 - ? (double) index + random.nextDouble() - : random.nextDouble() * dataSize; - case "DUPLICATED": - return (double) (random.nextInt(dataSize / 10)) + random.nextDouble(); - default: - return random.nextDouble() * dataSize; - } + } + + private Double generateDoubleValue(int index, Random random) { + switch (dataPattern) { + case "RANDOM": + return random.nextDouble() * dataSize * 10; + case "SORTED": + return (double) index + random.nextDouble(); + case "REVERSE_SORTED": + return (double) (dataSize - index) + random.nextDouble(); + case "PARTIAL_SORTED": + return index < dataSize * 0.7 + ? (double) index + random.nextDouble() + : random.nextDouble() * dataSize; + case "DUPLICATED": + return (double) (random.nextInt(dataSize / 10)) + random.nextDouble(); + default: + return random.nextDouble() * dataSize; } - - private String generateStringValue(int index, Random random) { - String[] prefixes = {"A", "B", "C", "D", "E", "F", "G", "H", "I", "J"}; - - switch (dataPattern) { - case "RANDOM": - return generateRandomString(1, 101, random); - case "SORTED": - return String.format("R%0100d", index); - case "REVERSE_SORTED": - return String.format("R%0100d", dataSize - index); - case "PARTIAL_SORTED": - return index < dataSize * 0.7 - ? String.format("R%0100d", index) - : generateRandomString(1, 101, random); - case "DUPLICATED": - return prefixes[random.nextInt(3)] - + String.format("%0100d", random.nextInt(dataSize / 10)); - default: - return String.format("R%0100d", random.nextInt(dataSize)); - } + } + + private String generateStringValue(int index, Random random) { + String[] prefixes = {"A", "B", "C", "D", "E", "F", "G", "H", "I", "J"}; + + switch (dataPattern) { + case "RANDOM": + return generateRandomString(1, 101, random); + case "SORTED": + return String.format("R%0100d", index); + case "REVERSE_SORTED": + return String.format("R%0100d", dataSize - index); + case "PARTIAL_SORTED": + return index < dataSize * 0.7 + ? String.format("R%0100d", index) + : generateRandomString(1, 101, random); + case "DUPLICATED": + return prefixes[random.nextInt(3)] + String.format("%0100d", random.nextInt(dataSize / 10)); + default: + return String.format("R%0100d", random.nextInt(dataSize)); } + } - private String generateRandomString(int length, Random random) { - StringBuilder sb = new StringBuilder(length); - for (int i = 0; i < length; i++) { - sb.append(CHARS.charAt(random.nextInt(CHARS.length()))); - } - return sb.toString(); + private String generateRandomString(int length, Random random) { + StringBuilder sb = new StringBuilder(length); + for (int i = 0; i < length; i++) { + sb.append(CHARS.charAt(random.nextInt(CHARS.length()))); } + return sb.toString(); + } + + private String generateRandomString(int minLength, int maxLength, Random random) { + int length = minLength + random.nextInt(maxLength - minLength + 1); + return generateRandomString(length, random); + } - private String generateRandomString(int minLength, int maxLength, Random random) { - int length = minLength + random.nextInt(maxLength - minLength + 1); - return generateRandomString(length, random); + @Benchmark + public Iterable benchmarkHeapSort() { + // Create a copy of the input data to avoid state pollution + List inputData = new ArrayList<>(testData); + + orderByFunction = new OrderByHeapSort(sortInfo); + orderByFunction.open(null); + + for (int i = 0; i < dataSize; i++) { + orderByFunction.process(inputData.get(i)); } - - - @Benchmark - public Iterable benchmarkHeapSort() { - // Create a copy of the input data to avoid state pollution - List inputData = new ArrayList<>(testData); - - orderByFunction = new OrderByHeapSort(sortInfo); - orderByFunction.open(null); - - for (int i = 0; i < dataSize; i++) { - orderByFunction.process(inputData.get(i)); - } - - // Perform Top-N sorting - return orderByFunction.finish(); + + // Perform Top-N sorting + return orderByFunction.finish(); + } + + @Benchmark + public Iterable benchmarkRadixSort() { + List inputData = new ArrayList<>(testData); + + orderByFunction = new OrderByRadixSort(sortInfo); + orderByFunction.open(null); + + for (int i = 0; i < dataSize; i++) { + orderByFunction.process(inputData.get(i)); } - @Benchmark - public Iterable benchmarkRadixSort() { - List inputData = new ArrayList<>(testData); - - orderByFunction = new OrderByRadixSort(sortInfo); - orderByFunction.open(null); + return orderByFunction.finish(); + } + + @Benchmark + public Iterable benchmarkTimSort() { + List inputData = new ArrayList<>(testData); + + orderByFunction = new OrderByTimSort(sortInfo); + orderByFunction.open(null); + + for (int i = 0; i < dataSize; i++) { + orderByFunction.process(inputData.get(i)); + } - for (int i = 0; i < dataSize; i++) { - orderByFunction.process(inputData.get(i)); - } + return orderByFunction.finish(); + } - return orderByFunction.finish(); + public static void main(String[] args) throws RunnerException { + // Run a verification first + OrderTimeBenchmark benchmark = new OrderTimeBenchmark(); + benchmark.dataSize = 10; + benchmark.topN = 10; + benchmark.dataPattern = "RANDOM"; + benchmark.dataType = "INTEGER"; + benchmark.setupBenchmark(); + Iterable heapResults = benchmark.benchmarkHeapSort(); + System.out.println("===HEAP_SORT==="); + for (Row result : heapResults) { + System.out.print(result); } - - @Benchmark - public Iterable benchmarkTimSort() { - List inputData = new ArrayList<>(testData); - - orderByFunction = new OrderByTimSort(sortInfo); - orderByFunction.open(null); - - for (int i = 0; i < dataSize; i++) { - orderByFunction.process(inputData.get(i)); - } - - return orderByFunction.finish(); + System.out.println(); + System.out.println("===RADIX_SORT==="); + Iterable radixResults = benchmark.benchmarkRadixSort(); + for (Row result : radixResults) { + System.out.print(result); } - - public static void main(String[] args) throws RunnerException { - // Run a verification first - OrderTimeBenchmark benchmark = new OrderTimeBenchmark(); - benchmark.dataSize = 10; - benchmark.topN = 10; - benchmark.dataPattern = "RANDOM"; - benchmark.dataType = "INTEGER"; - benchmark.setupBenchmark(); - Iterable heapResults = benchmark.benchmarkHeapSort(); - System.out.println("===HEAP_SORT==="); - for (Row result: heapResults) { - System.out.print(result); - } - System.out.println(); - System.out.println("===RADIX_SORT==="); - Iterable radixResults = benchmark.benchmarkRadixSort(); - for (Row result: radixResults) { - System.out.print(result); - } - System.out.println(); - System.out.println("===TIM_SORT==="); - Iterable timResults = benchmark.benchmarkTimSort(); - for (Row result: timResults) { - System.out.print(result); - } - System.out.println(); - - String timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss").format(new Date()); - String resultFile = "target/benchmark-results/time-" + timestamp + ".json"; - - Options opt = new OptionsBuilder() + System.out.println(); + System.out.println("===TIM_SORT==="); + Iterable timResults = benchmark.benchmarkTimSort(); + for (Row result : timResults) { + System.out.print(result); + } + System.out.println(); + + String timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss").format(new Date()); + String resultFile = "target/benchmark-results/time-" + timestamp + ".json"; + + Options opt = + new OptionsBuilder() .include(OrderTimeBenchmark.class.getSimpleName()) .jvmArgs("-Xms4g", "-Xmx8g", "-XX:+UseG1GC") .result(resultFile) .resultFormat(ResultFormatType.JSON) .build(); - - new Runner(opt).run(); - } -} \ No newline at end of file + + new Runner(opt).run(); + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/data/RuntimeEdgeTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/data/RuntimeEdgeTest.java index 81148fce2..1c65f0209 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/data/RuntimeEdgeTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/data/RuntimeEdgeTest.java @@ -32,58 +32,57 @@ public class RuntimeEdgeTest { - @Test - public void testLongEdge() { - LongEdge test = new LongEdge(1, 2); - test.withDirection(EdgeDirection.OUT); - test.withValue(ObjectRow.EMPTY); - LongEdge test2 = new LongEdge(2, 1); - test.withDirection(EdgeDirection.OUT); - test2.withValue(ObjectRow.EMPTY); - test2 = test2.reverse(); - Assert.assertEquals(test, test2); + @Test + public void testLongEdge() { + LongEdge test = new LongEdge(1, 2); + test.withDirection(EdgeDirection.OUT); + test.withValue(ObjectRow.EMPTY); + LongEdge test2 = new LongEdge(2, 1); + test.withDirection(EdgeDirection.OUT); + test2.withValue(ObjectRow.EMPTY); + test2 = test2.reverse(); + Assert.assertEquals(test, test2); - FieldAlignLongEdge test3 = new FieldAlignLongEdge(test2, new int[]{0, 1}); - test3.withValue(ObjectRow.EMPTY); - Assert.assertEquals(test, test3.reverse().reverse()); - FieldAlignLongEdge test4 = new FieldAlignLongEdge(test, new int[]{0, 1}); - Assert.assertEquals(test3, test4); - } + FieldAlignLongEdge test3 = new FieldAlignLongEdge(test2, new int[] {0, 1}); + test3.withValue(ObjectRow.EMPTY); + Assert.assertEquals(test, test3.reverse().reverse()); + FieldAlignLongEdge test4 = new FieldAlignLongEdge(test, new int[] {0, 1}); + Assert.assertEquals(test3, test4); + } - @Test - public void testIntEdge() { - IntEdge test = new IntEdge(1, 2); - test.withDirection(EdgeDirection.OUT); - test.withValue(ObjectRow.EMPTY); - IntEdge test2 = new IntEdge(2, 1); - test.withDirection(EdgeDirection.OUT); - test2.withValue(ObjectRow.EMPTY); - test2 = test2.reverse(); - Assert.assertEquals(test, test2); + @Test + public void testIntEdge() { + IntEdge test = new IntEdge(1, 2); + test.withDirection(EdgeDirection.OUT); + test.withValue(ObjectRow.EMPTY); + IntEdge test2 = new IntEdge(2, 1); + test.withDirection(EdgeDirection.OUT); + test2.withValue(ObjectRow.EMPTY); + test2 = test2.reverse(); + Assert.assertEquals(test, test2); - FieldAlignIntEdge test3 = new FieldAlignIntEdge(test2, new int[]{0, 1}); - test3.withValue(ObjectRow.EMPTY); - Assert.assertEquals(test, test3.reverse().reverse()); - FieldAlignIntEdge test4 = new FieldAlignIntEdge(test, new int[]{0, 1}); - Assert.assertEquals(test3, test4); - } + FieldAlignIntEdge test3 = new FieldAlignIntEdge(test2, new int[] {0, 1}); + test3.withValue(ObjectRow.EMPTY); + Assert.assertEquals(test, test3.reverse().reverse()); + FieldAlignIntEdge test4 = new FieldAlignIntEdge(test, new int[] {0, 1}); + Assert.assertEquals(test3, test4); + } - @Test - public void testDoubleEdge() { - DoubleEdge test = new DoubleEdge(1, 2); - test.withDirection(EdgeDirection.OUT); - test.withValue(ObjectRow.EMPTY); - DoubleEdge test2 = new DoubleEdge(2, 1); - test.withDirection(EdgeDirection.OUT); - test2.withValue(ObjectRow.EMPTY); - test2 = test2.reverse(); - Assert.assertEquals(test, test2); + @Test + public void testDoubleEdge() { + DoubleEdge test = new DoubleEdge(1, 2); + test.withDirection(EdgeDirection.OUT); + test.withValue(ObjectRow.EMPTY); + DoubleEdge test2 = new DoubleEdge(2, 1); + test.withDirection(EdgeDirection.OUT); + test2.withValue(ObjectRow.EMPTY); + test2 = test2.reverse(); + Assert.assertEquals(test, test2); - FieldAlignDoubleEdge test3 = new FieldAlignDoubleEdge(test2, new int[]{0, 1}); - test3.withValue(ObjectRow.EMPTY); - Assert.assertEquals(test, test3.reverse().reverse()); - FieldAlignDoubleEdge test4 = new FieldAlignDoubleEdge(test, new int[]{0, 1}); - Assert.assertEquals(test3, test4); - } + FieldAlignDoubleEdge test3 = new FieldAlignDoubleEdge(test2, new int[] {0, 1}); + test3.withValue(ObjectRow.EMPTY); + Assert.assertEquals(test, test3.reverse().reverse()); + FieldAlignDoubleEdge test4 = new FieldAlignDoubleEdge(test, new int[] {0, 1}); + Assert.assertEquals(test3, test4); + } } - diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/data/RuntimeVertexTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/data/RuntimeVertexTest.java index e8290b4f5..82bcf596d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/data/RuntimeVertexTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/data/RuntimeVertexTest.java @@ -31,64 +31,63 @@ public class RuntimeVertexTest { - @Test - public void testLongVertex() { - LongVertex test = new LongVertex(1); - test.withLabel("v"); - test.withValue(ObjectRow.EMPTY); - LongVertex test2 = new LongVertex(1); - test2.withLabel("v"); - test2.withValue(ObjectRow.EMPTY); - Assert.assertEquals(test.hashCode(), test2.hashCode()); - Assert.assertEquals(test, test2); + @Test + public void testLongVertex() { + LongVertex test = new LongVertex(1); + test.withLabel("v"); + test.withValue(ObjectRow.EMPTY); + LongVertex test2 = new LongVertex(1); + test2.withLabel("v"); + test2.withValue(ObjectRow.EMPTY); + Assert.assertEquals(test.hashCode(), test2.hashCode()); + Assert.assertEquals(test, test2); - FieldAlignLongVertex test3 = new FieldAlignLongVertex(test2, new int[]{0, 1}); - test3.withLabel("v"); - test3.withValue(ObjectRow.EMPTY); - Assert.assertEquals(test.hashCode(), test3.hashCode()); - Assert.assertEquals(test, test3); - FieldAlignLongVertex test4 = new FieldAlignLongVertex(test, new int[]{0, 1}); - Assert.assertEquals(test3, test4); - } + FieldAlignLongVertex test3 = new FieldAlignLongVertex(test2, new int[] {0, 1}); + test3.withLabel("v"); + test3.withValue(ObjectRow.EMPTY); + Assert.assertEquals(test.hashCode(), test3.hashCode()); + Assert.assertEquals(test, test3); + FieldAlignLongVertex test4 = new FieldAlignLongVertex(test, new int[] {0, 1}); + Assert.assertEquals(test3, test4); + } - @Test - public void testIntVertex() { - IntVertex test = new IntVertex(1); - test.withLabel("v"); - test.withValue(ObjectRow.EMPTY); - IntVertex test2 = new IntVertex(1); - test2.withLabel("v"); - test2.withValue(ObjectRow.EMPTY); - Assert.assertEquals(test.hashCode(), test2.hashCode()); - Assert.assertEquals(test, test2); + @Test + public void testIntVertex() { + IntVertex test = new IntVertex(1); + test.withLabel("v"); + test.withValue(ObjectRow.EMPTY); + IntVertex test2 = new IntVertex(1); + test2.withLabel("v"); + test2.withValue(ObjectRow.EMPTY); + Assert.assertEquals(test.hashCode(), test2.hashCode()); + Assert.assertEquals(test, test2); - FieldAlignIntVertex test3 = new FieldAlignIntVertex(test2, new int[]{0, 1}); - test3.withLabel("v"); - test3.withValue(ObjectRow.EMPTY); - Assert.assertEquals(test.hashCode(), test3.hashCode()); - Assert.assertEquals(test, test3); - FieldAlignIntVertex test4 = new FieldAlignIntVertex(test, new int[]{0, 1}); - Assert.assertEquals(test3, test4); - } + FieldAlignIntVertex test3 = new FieldAlignIntVertex(test2, new int[] {0, 1}); + test3.withLabel("v"); + test3.withValue(ObjectRow.EMPTY); + Assert.assertEquals(test.hashCode(), test3.hashCode()); + Assert.assertEquals(test, test3); + FieldAlignIntVertex test4 = new FieldAlignIntVertex(test, new int[] {0, 1}); + Assert.assertEquals(test3, test4); + } - @Test - public void testDoubleVertex() { - DoubleVertex test = new DoubleVertex(1); - test.withLabel("v"); - test.withValue(ObjectRow.EMPTY); - DoubleVertex test2 = new DoubleVertex(1); - test2.withLabel("v"); - test2.withValue(ObjectRow.EMPTY); - Assert.assertEquals(test.hashCode(), test2.hashCode()); - Assert.assertEquals(test, test2); + @Test + public void testDoubleVertex() { + DoubleVertex test = new DoubleVertex(1); + test.withLabel("v"); + test.withValue(ObjectRow.EMPTY); + DoubleVertex test2 = new DoubleVertex(1); + test2.withLabel("v"); + test2.withValue(ObjectRow.EMPTY); + Assert.assertEquals(test.hashCode(), test2.hashCode()); + Assert.assertEquals(test, test2); - FieldAlignDoubleVertex test3 = new FieldAlignDoubleVertex(test2, new int[]{0, 1}); - test3.withLabel("v"); - test3.withValue(ObjectRow.EMPTY); - Assert.assertEquals(test.hashCode(), test3.hashCode()); - Assert.assertEquals(test, test3); - FieldAlignDoubleVertex test4 = new FieldAlignDoubleVertex(test, new int[]{0, 1}); - Assert.assertEquals(test3, test4); - } + FieldAlignDoubleVertex test3 = new FieldAlignDoubleVertex(test2, new int[] {0, 1}); + test3.withLabel("v"); + test3.withValue(ObjectRow.EMPTY); + Assert.assertEquals(test.hashCode(), test3.hashCode()); + Assert.assertEquals(test, test3); + FieldAlignDoubleVertex test4 = new FieldAlignDoubleVertex(test, new int[] {0, 1}); + Assert.assertEquals(test3, test4); + } } - diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/plan/DagTopologyGroupTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/plan/DagTopologyGroupTest.java index 8a17ce286..43cb31009 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/plan/DagTopologyGroupTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/plan/DagTopologyGroupTest.java @@ -19,12 +19,11 @@ package org.apache.geaflow.dsl.runtime.plan; -import com.google.common.collect.Lists; -import com.google.common.collect.Sets; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.dsl.calcite.GraphRecordType; import org.apache.geaflow.dsl.common.data.Row; @@ -57,139 +56,162 @@ import org.testng.Assert; import org.testng.annotations.Test; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + public class DagTopologyGroupTest { - private static final Logger LOGGER = LoggerFactory.getLogger(DagTopologyGroupTest.class); - - @Test - public void testDagTopologyGroup() { - StepLogicalPlan.clearCounter(); - StepLogicalPlan mainPlan = - StepLogicalPlan.start() - .withInputPathSchema(PathType.EMPTY) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(VoidType.INSTANCE) - .withGraphSchema(createGraph()) - .vertexMatch(new MatchVertexFunctionImpl(Sets.newHashSet(BinaryString.fromString( - "person")), "a")) - .withInputPathSchema(PathType.EMPTY) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(VoidType.INSTANCE) - .withGraphSchema(createGraph()) - .edgeMatch(new MatchEdgeFunctionImpl(EdgeDirection.OUT, Sets.newHashSet(), "e")) - .withInputPathSchema(PathType.EMPTY) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(VoidType.INSTANCE) - .withGraphSchema(createGraph()) - .vertexMatch(new MatchVertexFunctionImpl(Sets.newHashSet(BinaryString.fromString( - "person")), "b")) - .withInputPathSchema(PathType.EMPTY) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(VoidType.INSTANCE) - .withGraphSchema(createGraph()); - - StepLogicalPlan subPlan = - StepLogicalPlan.subQueryStart("SubQuery1") - .withInputPathSchema(PathType.EMPTY) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(VoidType.INSTANCE) - .withGraphSchema(createGraph()) - .vertexMatch(new MatchVertexFunctionImpl(Sets.newHashSet(BinaryString.fromString( - "person")), "b")) - .withInputPathSchema(PathType.EMPTY) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(VoidType.INSTANCE) - .withGraphSchema(createGraph()) - .edgeMatch(new MatchEdgeFunctionImpl(EdgeDirection.OUT, Sets.newHashSet(), "e1")) - .withInputPathSchema(PathType.EMPTY) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(VoidType.INSTANCE) - .withGraphSchema(createGraph()) - .aggregate(new TestStepAggFunction()); - - StepLogicalPlanSet logicalPlanSet = new StepLogicalPlanSet(mainPlan); - logicalPlanSet.addSubLogicalPlan(subPlan); - DagGroupBuilder builder = new DagGroupBuilder(); - DagTopologyGroup dagGroup = builder.buildDagGroup(logicalPlanSet); - - String planSetDesc = logicalPlanSet.getPlanSetDesc(); - LOGGER.info("Plan Desc:\n{}", planSetDesc); - Assert.assertEquals(planSetDesc, "digraph G {\n" + "3 -> 10 [label= \"\"]\n" - + "2 -> 3 [label= \"chain = false\"]\n" + "1 -> 2 [label= \"\"]\n" - + "0 -> 1 [label= \"\"]\n" + "10 [label= \"StepEnd-10\"]\n" - + "3 [label= \"MatchVertex-3 [b]\"]\n" + "2 [label= \"MatchEdge-2(OUT) [e]\"]\n" - + "1 [label= \"MatchVertex-1 [a]\"]\n" + "0 [label= \"StepSource-0()\"]\n" + "\n" - + "8 -> 9 [label= \"chain = false\"]\n" + "7 -> 8 [label= \"\"]\n" - + "6 -> 7 [label= \"\"]\n" + "5 -> 6 [label= \"\"]\n" + "4 -> 5 [label= \"\"]\n" + private static final Logger LOGGER = LoggerFactory.getLogger(DagTopologyGroupTest.class); + + @Test + public void testDagTopologyGroup() { + StepLogicalPlan.clearCounter(); + StepLogicalPlan mainPlan = + StepLogicalPlan.start() + .withInputPathSchema(PathType.EMPTY) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(VoidType.INSTANCE) + .withGraphSchema(createGraph()) + .vertexMatch( + new MatchVertexFunctionImpl( + Sets.newHashSet(BinaryString.fromString("person")), "a")) + .withInputPathSchema(PathType.EMPTY) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(VoidType.INSTANCE) + .withGraphSchema(createGraph()) + .edgeMatch(new MatchEdgeFunctionImpl(EdgeDirection.OUT, Sets.newHashSet(), "e")) + .withInputPathSchema(PathType.EMPTY) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(VoidType.INSTANCE) + .withGraphSchema(createGraph()) + .vertexMatch( + new MatchVertexFunctionImpl( + Sets.newHashSet(BinaryString.fromString("person")), "b")) + .withInputPathSchema(PathType.EMPTY) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(VoidType.INSTANCE) + .withGraphSchema(createGraph()); + + StepLogicalPlan subPlan = + StepLogicalPlan.subQueryStart("SubQuery1") + .withInputPathSchema(PathType.EMPTY) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(VoidType.INSTANCE) + .withGraphSchema(createGraph()) + .vertexMatch( + new MatchVertexFunctionImpl( + Sets.newHashSet(BinaryString.fromString("person")), "b")) + .withInputPathSchema(PathType.EMPTY) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(VoidType.INSTANCE) + .withGraphSchema(createGraph()) + .edgeMatch(new MatchEdgeFunctionImpl(EdgeDirection.OUT, Sets.newHashSet(), "e1")) + .withInputPathSchema(PathType.EMPTY) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(VoidType.INSTANCE) + .withGraphSchema(createGraph()) + .aggregate(new TestStepAggFunction()); + + StepLogicalPlanSet logicalPlanSet = new StepLogicalPlanSet(mainPlan); + logicalPlanSet.addSubLogicalPlan(subPlan); + DagGroupBuilder builder = new DagGroupBuilder(); + DagTopologyGroup dagGroup = builder.buildDagGroup(logicalPlanSet); + + String planSetDesc = logicalPlanSet.getPlanSetDesc(); + LOGGER.info("Plan Desc:\n{}", planSetDesc); + Assert.assertEquals( + planSetDesc, + "digraph G {\n" + + "3 -> 10 [label= \"\"]\n" + + "2 -> 3 [label= \"chain = false\"]\n" + + "1 -> 2 [label= \"\"]\n" + + "0 -> 1 [label= \"\"]\n" + + "10 [label= \"StepEnd-10\"]\n" + + "3 [label= \"MatchVertex-3 [b]\"]\n" + + "2 [label= \"MatchEdge-2(OUT) [e]\"]\n" + + "1 [label= \"MatchVertex-1 [a]\"]\n" + + "0 [label= \"StepSource-0()\"]\n" + + "\n" + + "8 -> 9 [label= \"chain = false\"]\n" + + "7 -> 8 [label= \"\"]\n" + + "6 -> 7 [label= \"\"]\n" + + "5 -> 6 [label= \"\"]\n" + + "4 -> 5 [label= \"\"]\n" + "9 [label= \"StepGlobalSingleValueAggregate-9\"]\n" - + "8 [label= \"StepExchange-8\"]\n" + "7 [label= \"StepLocalSingleValueAggregate-7\"]\n" - + "6 [label= \"MatchEdge-6(OUT) [e1]\"]\n" + "5 [label= \"MatchVertex-5 [b]\"]\n" - + "4 [label= \"StepSubQueryStart-4(name=SubQuery1)\"]\n" + "}"); - DagTopology mainDag = dagGroup.getMainDag(); - Assert.assertTrue(mainDag.getEntryOperator().getClass().isAssignableFrom(StepSourceOperator.class)); - Assert.assertTrue(mainDag.isChained(0, 1)); - Assert.assertTrue(mainDag.isChained(1, 2)); - Assert.assertFalse(mainDag.isChained(2, 3)); - Assert.assertTrue(mainDag.isChained(3, 10)); - - DagTopology subDag = dagGroup.getSubDagTopologies().get(0); - Assert.assertTrue(subDag.getEntryOperator().getClass().isAssignableFrom(StepSubQueryStartOperator.class)); - Assert.assertEquals(subDag.getInputIds(5), Lists.newArrayList(4L)); - Assert.assertEquals(subDag.getOutputIds(5), Lists.newArrayList(6L)); - Assert.assertEquals(subDag.getOutputIds(11), Lists.newArrayList()); - - Assert.assertEquals(dagGroup.isChained(0, 1), mainDag.isChained(0, 1)); - Assert.assertEquals(dagGroup.isChained(4, 5), subDag.isChained(4, 5)); - Assert.assertEquals(dagGroup.getInputIds(4), subDag.getInputIds(4)); - } - - private GraphSchema createGraph() { - GeaFlowGraph graph = new GeaFlowGraph("default", "test", new ArrayList<>(), - new ArrayList<>(), new HashMap<>(), new HashMap<>(), false, false); - GraphRecordType graphRecordType = (GraphRecordType) graph.getRowType(GQLJavaTypeFactory.create()); - return (GraphSchema) SqlTypeUtil.convertType(graphRecordType); + + "8 [label= \"StepExchange-8\"]\n" + + "7 [label= \"StepLocalSingleValueAggregate-7\"]\n" + + "6 [label= \"MatchEdge-6(OUT) [e1]\"]\n" + + "5 [label= \"MatchVertex-5 [b]\"]\n" + + "4 [label= \"StepSubQueryStart-4(name=SubQuery1)\"]\n" + + "}"); + DagTopology mainDag = dagGroup.getMainDag(); + Assert.assertTrue( + mainDag.getEntryOperator().getClass().isAssignableFrom(StepSourceOperator.class)); + Assert.assertTrue(mainDag.isChained(0, 1)); + Assert.assertTrue(mainDag.isChained(1, 2)); + Assert.assertFalse(mainDag.isChained(2, 3)); + Assert.assertTrue(mainDag.isChained(3, 10)); + + DagTopology subDag = dagGroup.getSubDagTopologies().get(0); + Assert.assertTrue( + subDag.getEntryOperator().getClass().isAssignableFrom(StepSubQueryStartOperator.class)); + Assert.assertEquals(subDag.getInputIds(5), Lists.newArrayList(4L)); + Assert.assertEquals(subDag.getOutputIds(5), Lists.newArrayList(6L)); + Assert.assertEquals(subDag.getOutputIds(11), Lists.newArrayList()); + + Assert.assertEquals(dagGroup.isChained(0, 1), mainDag.isChained(0, 1)); + Assert.assertEquals(dagGroup.isChained(4, 5), subDag.isChained(4, 5)); + Assert.assertEquals(dagGroup.getInputIds(4), subDag.getInputIds(4)); + } + + private GraphSchema createGraph() { + GeaFlowGraph graph = + new GeaFlowGraph( + "default", + "test", + new ArrayList<>(), + new ArrayList<>(), + new HashMap<>(), + new HashMap<>(), + false, + false); + GraphRecordType graphRecordType = + (GraphRecordType) graph.getRowType(GQLJavaTypeFactory.create()); + return (GraphSchema) SqlTypeUtil.convertType(graphRecordType); + } + + private static class TestStepAggFunction implements StepAggregateFunction { + + @Override + public Object createAccumulator() { + return null; } - private static class TestStepAggFunction implements StepAggregateFunction { + @Override + public void add(Row row, Object accumulator) {} - @Override - public Object createAccumulator() { - return null; - } + @Override + public void merge(Object acc, Object otherAcc) {} - @Override - public void add(Row row, Object accumulator) { - - } - - @Override - public void merge(Object acc, Object otherAcc) { - - } - - @Override - public SingleValue getValue(Object accumulator) { - return null; - } - - @Override - public void open(TraversalRuntimeContext context, FunctionSchemas schemas) { - - } + @Override + public SingleValue getValue(Object accumulator) { + return null; + } - @Override - public void finish(StepCollector collector) { + @Override + public void open(TraversalRuntimeContext context, FunctionSchemas schemas) {} - } + @Override + public void finish(StepCollector collector) {} - @Override - public List getExpressions() { - return Collections.emptyList(); - } + @Override + public List getExpressions() { + return Collections.emptyList(); + } - @Override - public StepFunction copy(List expressions) { - return new TestStepAggFunction(); - } + @Override + public StepFunction copy(List expressions) { + return new TestStepAggFunction(); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/plan/ExpressionTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/plan/ExpressionTest.java index b056cb23a..93cc5d9a3 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/plan/ExpressionTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/plan/ExpressionTest.java @@ -22,10 +22,10 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import com.google.common.collect.Lists; import java.util.List; import java.util.stream.Collectors; import java.util.stream.Stream; + import org.apache.geaflow.common.type.primitive.BooleanType; import org.apache.geaflow.common.type.primitive.IntegerType; import org.apache.geaflow.dsl.common.data.Row; @@ -36,189 +36,169 @@ import org.apache.geaflow.dsl.runtime.expression.ExpressionBuilder; import org.testng.annotations.Test; -public class ExpressionTest { - - @Test - public void testEqual() { - ExpressionBuilder builder = new DefaultExpressionBuilder(); - Expression expression = builder.equal( - builder.literal(1, IntegerType.INSTANCE), - builder.literal(1, IntegerType.INSTANCE) - ); - assertEquals(expression.showExpression(), "1 = 1"); - Expression newExpr = expression.copy(expression.getInputs()); - assertEquals(newExpr.evaluate(Row.EMPTY), true); - } - - @Test - public void testGTE() { - ExpressionBuilder builder = new DefaultExpressionBuilder(); - Expression expression = builder.greaterEqThen( - builder.literal(1, IntegerType.INSTANCE), - builder.literal(1, IntegerType.INSTANCE) - ); - assertEquals(expression.showExpression(), "1 >= 1"); - Expression newExpr = expression.copy(expression.getInputs()); - assertEquals(newExpr.evaluate(Row.EMPTY), true); - } - - @Test - public void testGT() { - ExpressionBuilder builder = new DefaultExpressionBuilder(); - Expression expression = builder.greaterThan( - builder.literal(1, IntegerType.INSTANCE), - builder.literal(1, IntegerType.INSTANCE) - ); - assertEquals(expression.showExpression(), "1 > 1"); - Expression newExpr = expression.copy(expression.getInputs()); - assertEquals(newExpr.evaluate(Row.EMPTY), false); - } - - @Test - public void testIsFalse() { - ExpressionBuilder builder = new DefaultExpressionBuilder(); - Expression expression = builder.isFalse( - builder.literal(false, BooleanType.INSTANCE) - ); - assertEquals(expression.showExpression(), "false is false"); - Expression newExpr = expression.copy(expression.getInputs()); - assertEquals(newExpr.evaluate(Row.EMPTY), true); - } - - @Test - public void testIsNotFalse() { - ExpressionBuilder builder = new DefaultExpressionBuilder(); - Expression expression = builder.isNotFalse( - builder.literal(false, BooleanType.INSTANCE) - ); - assertEquals(expression.showExpression(), "false is not false"); - Expression newExpr = expression.copy(expression.getInputs()); - assertEquals(newExpr.evaluate(Row.EMPTY), false); - } - - @Test - public void testIsNotNull() { - ExpressionBuilder builder = new DefaultExpressionBuilder(); - Expression expression = builder.isNotNull( - builder.literal(false, BooleanType.INSTANCE) - ); - assertEquals(expression.showExpression(), "false is not null"); - Expression newExpr = expression.copy(expression.getInputs()); - assertEquals(newExpr.evaluate(Row.EMPTY), true); - } - - @Test - public void testIsNotTrue() { - ExpressionBuilder builder = new DefaultExpressionBuilder(); - Expression expression = builder.isNotTrue( - builder.literal(false, BooleanType.INSTANCE) - ); - assertEquals(expression.showExpression(), "false is not true"); - Expression newExpr = expression.copy(expression.getInputs()); - assertEquals(newExpr.evaluate(Row.EMPTY), true); - } - - @Test - public void testIsNull() { - ExpressionBuilder builder = new DefaultExpressionBuilder(); - Expression expression = builder.isNull( - builder.literal(false, BooleanType.INSTANCE) - ); - assertEquals(expression.showExpression(), "false is null"); - Expression newExpr = expression.copy(expression.getInputs()); - assertEquals(newExpr.evaluate(Row.EMPTY), false); - } - - @Test - public void testIsTrue() { - ExpressionBuilder builder = new DefaultExpressionBuilder(); - Expression expression = builder.isTrue( - builder.literal(false, BooleanType.INSTANCE) - ); - assertEquals(expression.showExpression(), "false is true"); - Expression newExpr = expression.copy(expression.getInputs()); - assertEquals(newExpr.evaluate(Row.EMPTY), false); - } - - @Test - public void testLTE() { - ExpressionBuilder builder = new DefaultExpressionBuilder(); - Expression expression = builder.lessEqThan( - builder.literal(1, IntegerType.INSTANCE), - builder.literal(2, IntegerType.INSTANCE) - ); - assertEquals(expression.showExpression(), "1 <= 2"); - Expression newExpr = expression.copy(expression.getInputs()); - assertEquals(newExpr.evaluate(Row.EMPTY), true); - } - - @Test - public void testLT() { - ExpressionBuilder builder = new DefaultExpressionBuilder(); - Expression expression = builder.lessThan( - builder.literal(1, IntegerType.INSTANCE), - builder.literal(1, IntegerType.INSTANCE) - ); - assertEquals(expression.showExpression(), "1 < 1"); - Expression newExpr = expression.copy(expression.getInputs()); - assertEquals(newExpr.evaluate(Row.EMPTY), false); - } - - @Test - public void testNot() { - ExpressionBuilder builder = new DefaultExpressionBuilder(); - Expression expression = builder.not( - builder.literal(false, BooleanType.INSTANCE) - ); - assertEquals(expression.showExpression(), "Not(false)"); - Expression newExpr = expression.copy(expression.getInputs()); - assertEquals(newExpr.evaluate(Row.EMPTY), true); - } - - @Test - public void testOr() { - ExpressionBuilder builder = new DefaultExpressionBuilder(); - Expression expression = builder.or(Lists.newArrayList( - builder.literal(false, BooleanType.INSTANCE), - builder.literal(true, BooleanType.INSTANCE)) - ); - assertEquals(expression.showExpression(), "Or(false,true)"); - Expression newExpr = expression.copy(expression.getInputs()); - assertEquals(newExpr.evaluate(Row.EMPTY), true); - } - - @Test - public void testItem() { - ExpressionBuilder builder = new DefaultExpressionBuilder(); - Expression expression = builder.item( - builder.literal(new Integer[]{0, 1}, new ArrayType(IntegerType.INSTANCE)), - builder.literal(1, IntegerType.INSTANCE) - ); - assertTrue(expression.showExpression().contains("java.lang.Integer")); - assertTrue(expression.showExpression().contains("[1]")); - Expression newExpr = expression.copy(expression.getInputs()); - try { - throw new GeaFlowDSLException(String.valueOf(newExpr.evaluate(Row.EMPTY))); - } catch (Exception e) { - assertEquals(e.getMessage(), "1"); - } +import com.google.common.collect.Lists; - } +public class ExpressionTest { - @Test - public void testSplitByEnd() { - ExpressionBuilder builder = new DefaultExpressionBuilder(); - Expression expression1 = builder.not( - builder.literal(false, BooleanType.INSTANCE) - ); - Expression expression2 = builder.not( - builder.literal(true, BooleanType.INSTANCE) - ); - Expression expression = builder.and(Stream.of(expression1, expression2).collect(Collectors.toList())); - assertEquals(expression.showExpression(), "And(Not(false),Not(true))"); - List exprList = expression.splitByAnd(); - assertEquals(exprList.size(), 2); - assertEquals(exprList.get(0), expression1); - assertEquals(exprList.get(1), expression2); + @Test + public void testEqual() { + ExpressionBuilder builder = new DefaultExpressionBuilder(); + Expression expression = + builder.equal( + builder.literal(1, IntegerType.INSTANCE), builder.literal(1, IntegerType.INSTANCE)); + assertEquals(expression.showExpression(), "1 = 1"); + Expression newExpr = expression.copy(expression.getInputs()); + assertEquals(newExpr.evaluate(Row.EMPTY), true); + } + + @Test + public void testGTE() { + ExpressionBuilder builder = new DefaultExpressionBuilder(); + Expression expression = + builder.greaterEqThen( + builder.literal(1, IntegerType.INSTANCE), builder.literal(1, IntegerType.INSTANCE)); + assertEquals(expression.showExpression(), "1 >= 1"); + Expression newExpr = expression.copy(expression.getInputs()); + assertEquals(newExpr.evaluate(Row.EMPTY), true); + } + + @Test + public void testGT() { + ExpressionBuilder builder = new DefaultExpressionBuilder(); + Expression expression = + builder.greaterThan( + builder.literal(1, IntegerType.INSTANCE), builder.literal(1, IntegerType.INSTANCE)); + assertEquals(expression.showExpression(), "1 > 1"); + Expression newExpr = expression.copy(expression.getInputs()); + assertEquals(newExpr.evaluate(Row.EMPTY), false); + } + + @Test + public void testIsFalse() { + ExpressionBuilder builder = new DefaultExpressionBuilder(); + Expression expression = builder.isFalse(builder.literal(false, BooleanType.INSTANCE)); + assertEquals(expression.showExpression(), "false is false"); + Expression newExpr = expression.copy(expression.getInputs()); + assertEquals(newExpr.evaluate(Row.EMPTY), true); + } + + @Test + public void testIsNotFalse() { + ExpressionBuilder builder = new DefaultExpressionBuilder(); + Expression expression = builder.isNotFalse(builder.literal(false, BooleanType.INSTANCE)); + assertEquals(expression.showExpression(), "false is not false"); + Expression newExpr = expression.copy(expression.getInputs()); + assertEquals(newExpr.evaluate(Row.EMPTY), false); + } + + @Test + public void testIsNotNull() { + ExpressionBuilder builder = new DefaultExpressionBuilder(); + Expression expression = builder.isNotNull(builder.literal(false, BooleanType.INSTANCE)); + assertEquals(expression.showExpression(), "false is not null"); + Expression newExpr = expression.copy(expression.getInputs()); + assertEquals(newExpr.evaluate(Row.EMPTY), true); + } + + @Test + public void testIsNotTrue() { + ExpressionBuilder builder = new DefaultExpressionBuilder(); + Expression expression = builder.isNotTrue(builder.literal(false, BooleanType.INSTANCE)); + assertEquals(expression.showExpression(), "false is not true"); + Expression newExpr = expression.copy(expression.getInputs()); + assertEquals(newExpr.evaluate(Row.EMPTY), true); + } + + @Test + public void testIsNull() { + ExpressionBuilder builder = new DefaultExpressionBuilder(); + Expression expression = builder.isNull(builder.literal(false, BooleanType.INSTANCE)); + assertEquals(expression.showExpression(), "false is null"); + Expression newExpr = expression.copy(expression.getInputs()); + assertEquals(newExpr.evaluate(Row.EMPTY), false); + } + + @Test + public void testIsTrue() { + ExpressionBuilder builder = new DefaultExpressionBuilder(); + Expression expression = builder.isTrue(builder.literal(false, BooleanType.INSTANCE)); + assertEquals(expression.showExpression(), "false is true"); + Expression newExpr = expression.copy(expression.getInputs()); + assertEquals(newExpr.evaluate(Row.EMPTY), false); + } + + @Test + public void testLTE() { + ExpressionBuilder builder = new DefaultExpressionBuilder(); + Expression expression = + builder.lessEqThan( + builder.literal(1, IntegerType.INSTANCE), builder.literal(2, IntegerType.INSTANCE)); + assertEquals(expression.showExpression(), "1 <= 2"); + Expression newExpr = expression.copy(expression.getInputs()); + assertEquals(newExpr.evaluate(Row.EMPTY), true); + } + + @Test + public void testLT() { + ExpressionBuilder builder = new DefaultExpressionBuilder(); + Expression expression = + builder.lessThan( + builder.literal(1, IntegerType.INSTANCE), builder.literal(1, IntegerType.INSTANCE)); + assertEquals(expression.showExpression(), "1 < 1"); + Expression newExpr = expression.copy(expression.getInputs()); + assertEquals(newExpr.evaluate(Row.EMPTY), false); + } + + @Test + public void testNot() { + ExpressionBuilder builder = new DefaultExpressionBuilder(); + Expression expression = builder.not(builder.literal(false, BooleanType.INSTANCE)); + assertEquals(expression.showExpression(), "Not(false)"); + Expression newExpr = expression.copy(expression.getInputs()); + assertEquals(newExpr.evaluate(Row.EMPTY), true); + } + + @Test + public void testOr() { + ExpressionBuilder builder = new DefaultExpressionBuilder(); + Expression expression = + builder.or( + Lists.newArrayList( + builder.literal(false, BooleanType.INSTANCE), + builder.literal(true, BooleanType.INSTANCE))); + assertEquals(expression.showExpression(), "Or(false,true)"); + Expression newExpr = expression.copy(expression.getInputs()); + assertEquals(newExpr.evaluate(Row.EMPTY), true); + } + + @Test + public void testItem() { + ExpressionBuilder builder = new DefaultExpressionBuilder(); + Expression expression = + builder.item( + builder.literal(new Integer[] {0, 1}, new ArrayType(IntegerType.INSTANCE)), + builder.literal(1, IntegerType.INSTANCE)); + assertTrue(expression.showExpression().contains("java.lang.Integer")); + assertTrue(expression.showExpression().contains("[1]")); + Expression newExpr = expression.copy(expression.getInputs()); + try { + throw new GeaFlowDSLException(String.valueOf(newExpr.evaluate(Row.EMPTY))); + } catch (Exception e) { + assertEquals(e.getMessage(), "1"); } + } + + @Test + public void testSplitByEnd() { + ExpressionBuilder builder = new DefaultExpressionBuilder(); + Expression expression1 = builder.not(builder.literal(false, BooleanType.INSTANCE)); + Expression expression2 = builder.not(builder.literal(true, BooleanType.INSTANCE)); + Expression expression = + builder.and(Stream.of(expression1, expression2).collect(Collectors.toList())); + assertEquals(expression.showExpression(), "And(Not(false),Not(true))"); + List exprList = expression.splitByAnd(); + assertEquals(exprList.size(), 2); + assertEquals(exprList.get(0), expression1); + assertEquals(exprList.get(1), expression2); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/plan/StepPlanTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/plan/StepPlanTest.java index 82d72af5c..9c6302d61 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/plan/StepPlanTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/plan/StepPlanTest.java @@ -19,11 +19,10 @@ package org.apache.geaflow.dsl.runtime.plan; -import com.google.common.collect.Lists; -import com.google.common.collect.Sets; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; + import org.apache.calcite.rel.core.JoinRelType; import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.IType; @@ -52,139 +51,183 @@ import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + public class StepPlanTest { - private static final Logger LOGGER = LoggerFactory.getLogger(StepPlanTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(StepPlanTest.class); - @Test - public void testLogicalPlan() { - StepLogicalPlan logicalPlan = - StepLogicalPlan.start() - .withInputPathSchema(PathType.EMPTY) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(VoidType.INSTANCE) - .withGraphSchema(createGraph()) - .vertexMatch(new MatchVertexFunctionImpl(Sets.newHashSet(BinaryString.fromString( - "person")), "a", EmptyFilter.getInstance())) - .withInputPathSchema(PathType.EMPTY) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(VoidType.INSTANCE) - .withGraphSchema(createGraph()) - .edgeMatch(new MatchEdgeFunctionImpl(EdgeDirection.OUT, Sets.newHashSet(), "e")) - .withInputPathSchema(PathType.EMPTY) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(VoidType.INSTANCE) - .withGraphSchema(createGraph()) - .vertexMatch(new MatchVertexFunctionImpl(Sets.newHashSet(BinaryString.fromString( - "person")), "b", EmptyFilter.getInstance())) - .withInputPathSchema(PathType.EMPTY) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(VoidType.INSTANCE) - .withGraphSchema(createGraph()) - ; - String planDesc = logicalPlan.getPlanDesc(); - LOGGER.info("Logical plan:\n{}", planDesc); - Assert.assertEquals(planDesc, - "digraph G {\n" - + "2 -> 3 [label= \"chain = false\"]\n" - + "1 -> 2 [label= \"chain = false\"]\n" - + "0 -> 1 [label= \"chain = false\"]\n" - + "3 [label= \"MatchVertex-3 [b]\"]\n" - + "2 [label= \"MatchEdge-2(OUT) [e]\"]\n" - + "1 [label= \"MatchVertex-1 [a]\"]\n" - + "0 [label= \"StepSource-0()\"]\n" - + "}"); - } + @Test + public void testLogicalPlan() { + StepLogicalPlan logicalPlan = + StepLogicalPlan.start() + .withInputPathSchema(PathType.EMPTY) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(VoidType.INSTANCE) + .withGraphSchema(createGraph()) + .vertexMatch( + new MatchVertexFunctionImpl( + Sets.newHashSet(BinaryString.fromString("person")), + "a", + EmptyFilter.getInstance())) + .withInputPathSchema(PathType.EMPTY) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(VoidType.INSTANCE) + .withGraphSchema(createGraph()) + .edgeMatch(new MatchEdgeFunctionImpl(EdgeDirection.OUT, Sets.newHashSet(), "e")) + .withInputPathSchema(PathType.EMPTY) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(VoidType.INSTANCE) + .withGraphSchema(createGraph()) + .vertexMatch( + new MatchVertexFunctionImpl( + Sets.newHashSet(BinaryString.fromString("person")), + "b", + EmptyFilter.getInstance())) + .withInputPathSchema(PathType.EMPTY) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(VoidType.INSTANCE) + .withGraphSchema(createGraph()); + String planDesc = logicalPlan.getPlanDesc(); + LOGGER.info("Logical plan:\n{}", planDesc); + Assert.assertEquals( + planDesc, + "digraph G {\n" + + "2 -> 3 [label= \"chain = false\"]\n" + + "1 -> 2 [label= \"chain = false\"]\n" + + "0 -> 1 [label= \"chain = false\"]\n" + + "3 [label= \"MatchVertex-3 [b]\"]\n" + + "2 [label= \"MatchEdge-2(OUT) [e]\"]\n" + + "1 [label= \"MatchVertex-1 [a]\"]\n" + + "0 [label= \"StepSource-0()\"]\n" + + "}"); + } - @Test - public void testMultiOutputLogicalPlan() { - StepLogicalPlan logicalPlan = - StepLogicalPlan.start() - .withInputPathSchema(PathType.EMPTY) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(VoidType.INSTANCE) - .withGraphSchema(createGraph()) - .vertexMatch(new MatchVertexFunctionImpl(Sets.newHashSet(BinaryString.fromString( - "person")), "a")) - .withInputPathSchema(PathType.EMPTY) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(VoidType.INSTANCE) - .withGraphSchema(createGraph()) - .edgeMatch(new MatchEdgeFunctionImpl(EdgeDirection.OUT, Sets.newHashSet(), "e")) - .withInputPathSchema(PathType.EMPTY) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(VoidType.INSTANCE) - .withGraphSchema(createGraph()) - .vertexMatch(new MatchVertexFunctionImpl(Sets.newHashSet(BinaryString.fromString( - "person")), "b")) - .withInputPathSchema(PathType.EMPTY) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(VoidType.INSTANCE) - .withGraphSchema(createGraph()); + @Test + public void testMultiOutputLogicalPlan() { + StepLogicalPlan logicalPlan = + StepLogicalPlan.start() + .withInputPathSchema(PathType.EMPTY) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(VoidType.INSTANCE) + .withGraphSchema(createGraph()) + .vertexMatch( + new MatchVertexFunctionImpl( + Sets.newHashSet(BinaryString.fromString("person")), "a")) + .withInputPathSchema(PathType.EMPTY) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(VoidType.INSTANCE) + .withGraphSchema(createGraph()) + .edgeMatch(new MatchEdgeFunctionImpl(EdgeDirection.OUT, Sets.newHashSet(), "e")) + .withInputPathSchema(PathType.EMPTY) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(VoidType.INSTANCE) + .withGraphSchema(createGraph()) + .vertexMatch( + new MatchVertexFunctionImpl( + Sets.newHashSet(BinaryString.fromString("person")), "b")) + .withInputPathSchema(PathType.EMPTY) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(VoidType.INSTANCE) + .withGraphSchema(createGraph()); - StepLogicalPlan leftPlan = - logicalPlan.edgeMatch(new MatchEdgeFunctionImpl(EdgeDirection.OUT, Sets.newHashSet(), "f")) - .withInputPathSchema(PathType.EMPTY) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(VoidType.INSTANCE) - .withGraphSchema(createGraph()) - .vertexMatch(new MatchVertexFunctionImpl(Sets.newHashSet(BinaryString.fromString( - "person")), "c")) - .withInputPathSchema(PathType.EMPTY) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(VoidType.INSTANCE) - .withGraphSchema(createGraph()); + StepLogicalPlan leftPlan = + logicalPlan + .edgeMatch(new MatchEdgeFunctionImpl(EdgeDirection.OUT, Sets.newHashSet(), "f")) + .withInputPathSchema(PathType.EMPTY) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(VoidType.INSTANCE) + .withGraphSchema(createGraph()) + .vertexMatch( + new MatchVertexFunctionImpl( + Sets.newHashSet(BinaryString.fromString("person")), "c")) + .withInputPathSchema(PathType.EMPTY) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(VoidType.INSTANCE) + .withGraphSchema(createGraph()); - StepLogicalPlan rightPlan = - logicalPlan.edgeMatch(new MatchEdgeFunctionImpl(EdgeDirection.OUT, Sets.newHashSet(), "g")) - .withInputPathSchema(PathType.EMPTY) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(VoidType.INSTANCE) - .withGraphSchema(createGraph()) - .vertexMatch(new MatchVertexFunctionImpl(Sets.newHashSet(BinaryString.fromString( - "person")), "d")) - .withInputPathSchema(PathType.EMPTY) - .withOutputPathSchema(PathType.EMPTY) - .withOutputType(VoidType.INSTANCE) - .withGraphSchema(createGraph()); - StepKeyFunction keyFunction = new StepKeyFunctionImpl(new int[]{}, new IType[]{}); - StepLogicalPlan joinPlan = leftPlan.join(rightPlan, keyFunction, keyFunction, - new StepJoinFunctionImpl(JoinRelType.INNER, new IType[]{}, new IType[]{}), - PathType.EMPTY, false) + StepLogicalPlan rightPlan = + logicalPlan + .edgeMatch(new MatchEdgeFunctionImpl(EdgeDirection.OUT, Sets.newHashSet(), "g")) + .withInputPathSchema(PathType.EMPTY) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(VoidType.INSTANCE) + .withGraphSchema(createGraph()) + .vertexMatch( + new MatchVertexFunctionImpl( + Sets.newHashSet(BinaryString.fromString("person")), "d")) + .withInputPathSchema(PathType.EMPTY) + .withOutputPathSchema(PathType.EMPTY) + .withOutputType(VoidType.INSTANCE) + .withGraphSchema(createGraph()); + StepKeyFunction keyFunction = new StepKeyFunctionImpl(new int[] {}, new IType[] {}); + StepLogicalPlan joinPlan = + leftPlan + .join( + rightPlan, + keyFunction, + keyFunction, + new StepJoinFunctionImpl(JoinRelType.INNER, new IType[] {}, new IType[] {}), + PathType.EMPTY, + false) .withOutputPathSchema(new PathType()); - StepLogicalPlanSet logicalPlanSet = new StepLogicalPlanSet(joinPlan); - logicalPlanSet.markChainable(); - String planDesc = logicalPlanSet.getPlanSetDesc(); - LOGGER.info("Logical plan:\n{}", planDesc); - Assert.assertEquals(planDesc, - "digraph G {\n" + "10 -> 11 [label= \"\"]\n" + "8 -> 10 [label= \"chain = false\"]\n" - + "5 -> 8 [label= \"\"]\n" + "4 -> 5 [label= \"chain = false\"]\n" - + "3 -> 4 [label= \"\"]\n" + "2 -> 3 [label= \"chain = false\"]\n" - + "1 -> 2 [label= \"\"]\n" + "0 -> 1 [label= \"\"]\n" - + "9 -> 10 [label= \"chain = false\"]\n" + "7 -> 9 [label= \"\"]\n" - + "6 -> 7 [label= \"chain = false\"]\n" + "3 -> 6 [label= \"\"]\n" - + "11 [label= \"StepEnd-11\"]\n" + "10 [label= \"StepJoin-10\"]\n" - + "8 [label= \"StepExchange-8\"]\n" + "5 [label= \"MatchVertex-5 [c]\"]\n" - + "4 [label= \"MatchEdge-4(OUT) [f]\"]\n" + "3 [label= \"MatchVertex-3 [b]\"]\n" - + "2 [label= \"MatchEdge-2(OUT) [e]\"]\n" + "1 [label= \"MatchVertex-1 [a]\"]\n" - + "0 [label= \"StepSource-0()\"]\n" + "9 [label= \"StepExchange-9\"]\n" - + "7 [label= \"MatchVertex-7 [d]\"]\n" + "6 [label= \"MatchEdge-6(OUT) [g]\"]\n" - + "}" - ); - } + StepLogicalPlanSet logicalPlanSet = new StepLogicalPlanSet(joinPlan); + logicalPlanSet.markChainable(); + String planDesc = logicalPlanSet.getPlanSetDesc(); + LOGGER.info("Logical plan:\n{}", planDesc); + Assert.assertEquals( + planDesc, + "digraph G {\n" + + "10 -> 11 [label= \"\"]\n" + + "8 -> 10 [label= \"chain = false\"]\n" + + "5 -> 8 [label= \"\"]\n" + + "4 -> 5 [label= \"chain = false\"]\n" + + "3 -> 4 [label= \"\"]\n" + + "2 -> 3 [label= \"chain = false\"]\n" + + "1 -> 2 [label= \"\"]\n" + + "0 -> 1 [label= \"\"]\n" + + "9 -> 10 [label= \"chain = false\"]\n" + + "7 -> 9 [label= \"\"]\n" + + "6 -> 7 [label= \"chain = false\"]\n" + + "3 -> 6 [label= \"\"]\n" + + "11 [label= \"StepEnd-11\"]\n" + + "10 [label= \"StepJoin-10\"]\n" + + "8 [label= \"StepExchange-8\"]\n" + + "5 [label= \"MatchVertex-5 [c]\"]\n" + + "4 [label= \"MatchEdge-4(OUT) [f]\"]\n" + + "3 [label= \"MatchVertex-3 [b]\"]\n" + + "2 [label= \"MatchEdge-2(OUT) [e]\"]\n" + + "1 [label= \"MatchVertex-1 [a]\"]\n" + + "0 [label= \"StepSource-0()\"]\n" + + "9 [label= \"StepExchange-9\"]\n" + + "7 [label= \"MatchVertex-7 [d]\"]\n" + + "6 [label= \"MatchEdge-6(OUT) [g]\"]\n" + + "}"); + } - private GraphSchema createGraph() { - TableField idField = new TableField("id", Types.of("Long", -1), false); - VertexTable vTable = new VertexTable("default", "testV", Collections.singletonList(idField), "id"); - GeaFlowGraph graph = new GeaFlowGraph("default", "test", Lists.newArrayList(vTable), - new ArrayList<>(), new HashMap<>(), new HashMap<>(), false, false); - GraphRecordType graphRecordType = (GraphRecordType) graph.getRowType(GQLJavaTypeFactory.create()); - return (GraphSchema) SqlTypeUtil.convertType(graphRecordType); - } + private GraphSchema createGraph() { + TableField idField = new TableField("id", Types.of("Long", -1), false); + VertexTable vTable = + new VertexTable("default", "testV", Collections.singletonList(idField), "id"); + GeaFlowGraph graph = + new GeaFlowGraph( + "default", + "test", + Lists.newArrayList(vTable), + new ArrayList<>(), + new HashMap<>(), + new HashMap<>(), + false, + false); + GraphRecordType graphRecordType = + (GraphRecordType) graph.getRowType(GQLJavaTypeFactory.create()); + return (GraphSchema) SqlTypeUtil.convertType(graphRecordType); + } - @BeforeMethod - public void setup() { - StepLogicalPlan.clearCounter(); - } + @BeforeMethod + public void setup() { + StepLogicalPlan.clearCounter(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/AggregateTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/AggregateTest.java index 7e8ea37fa..57821f9ae 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/AggregateTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/AggregateTest.java @@ -23,121 +23,71 @@ public class AggregateTest { - @Test - public void testAggregate_001() throws Exception { - QueryTester - .build() - .withQueryPath("/query/aggregate_001.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAggregate_002() throws Exception { - QueryTester - .build() - .withQueryPath("/query/aggregate_002.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAggregate_003() throws Exception { - QueryTester - .build() - .withQueryPath("/query/aggregate_003.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAggregate_004() throws Exception { - QueryTester - .build() - .withQueryPath("/query/aggregate_004.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAggregate_005() throws Exception { - QueryTester - .build() - .withQueryPath("/query/aggregate_005.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAggregate_006() throws Exception { - QueryTester - .build() - .withQueryPath("/query/aggregate_006.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAggregate_007() throws Exception { - QueryTester - .build() - .withQueryPath("/query/aggregate_007.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAggregate_008() throws Exception { - QueryTester - .build() - .withQueryPath("/query/aggregate_008.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAggregate_009() throws Exception { - QueryTester - .build() - .withQueryPath("/query/aggregate_009.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAggregate_010() throws Exception { - QueryTester - .build() - .withQueryPath("/query/aggregate_010.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAggregate_011() throws Exception { - QueryTester - .build() - .withQueryPath("/query/aggregate_011.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAggregate_012() throws Exception { - QueryTester - .build() - .withQueryPath("/query/aggregate_012.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testStreamAggregate_001() throws Exception { - QueryTester - .build() - .withQueryPath("/query/stream_aggregate_001.sql") - .execute() - .checkSinkResult(); - } - + @Test + public void testAggregate_001() throws Exception { + QueryTester.build().withQueryPath("/query/aggregate_001.sql").execute().checkSinkResult(); + } + + @Test + public void testAggregate_002() throws Exception { + QueryTester.build().withQueryPath("/query/aggregate_002.sql").execute().checkSinkResult(); + } + + @Test + public void testAggregate_003() throws Exception { + QueryTester.build().withQueryPath("/query/aggregate_003.sql").execute().checkSinkResult(); + } + + @Test + public void testAggregate_004() throws Exception { + QueryTester.build().withQueryPath("/query/aggregate_004.sql").execute().checkSinkResult(); + } + + @Test + public void testAggregate_005() throws Exception { + QueryTester.build().withQueryPath("/query/aggregate_005.sql").execute().checkSinkResult(); + } + + @Test + public void testAggregate_006() throws Exception { + QueryTester.build().withQueryPath("/query/aggregate_006.sql").execute().checkSinkResult(); + } + + @Test + public void testAggregate_007() throws Exception { + QueryTester.build().withQueryPath("/query/aggregate_007.sql").execute().checkSinkResult(); + } + + @Test + public void testAggregate_008() throws Exception { + QueryTester.build().withQueryPath("/query/aggregate_008.sql").execute().checkSinkResult(); + } + + @Test + public void testAggregate_009() throws Exception { + QueryTester.build().withQueryPath("/query/aggregate_009.sql").execute().checkSinkResult(); + } + + @Test + public void testAggregate_010() throws Exception { + QueryTester.build().withQueryPath("/query/aggregate_010.sql").execute().checkSinkResult(); + } + + @Test + public void testAggregate_011() throws Exception { + QueryTester.build().withQueryPath("/query/aggregate_011.sql").execute().checkSinkResult(); + } + + @Test + public void testAggregate_012() throws Exception { + QueryTester.build().withQueryPath("/query/aggregate_012.sql").execute().checkSinkResult(); + } + + @Test + public void testStreamAggregate_001() throws Exception { + QueryTester.build() + .withQueryPath("/query/stream_aggregate_001.sql") + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/CorrelateTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/CorrelateTest.java index 1aa89b72c..8db971c76 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/CorrelateTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/CorrelateTest.java @@ -23,48 +23,28 @@ public class CorrelateTest { - @Test - public void testCorrelate_001() throws Exception { - QueryTester - .build() - .withQueryPath("/query/correlate_001.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testCorrelate_002() throws Exception { - QueryTester - .build() - .withQueryPath("/query/correlate_002.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testCorrelate_003() throws Exception { - QueryTester - .build() - .withQueryPath("/query/correlate_003.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testCorrelate_004() throws Exception { - QueryTester - .build() - .withQueryPath("/query/correlate_004.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testCorrelate_005() throws Exception { - QueryTester - .build() - .withQueryPath("/query/correlate_005.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testCorrelate_001() throws Exception { + QueryTester.build().withQueryPath("/query/correlate_001.sql").execute().checkSinkResult(); + } + + @Test + public void testCorrelate_002() throws Exception { + QueryTester.build().withQueryPath("/query/correlate_002.sql").execute().checkSinkResult(); + } + + @Test + public void testCorrelate_003() throws Exception { + QueryTester.build().withQueryPath("/query/correlate_003.sql").execute().checkSinkResult(); + } + + @Test + public void testCorrelate_004() throws Exception { + QueryTester.build().withQueryPath("/query/correlate_004.sql").execute().checkSinkResult(); + } + + @Test + public void testCorrelate_005() throws Exception { + QueryTester.build().withQueryPath("/query/correlate_005.sql").execute().checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/DistinctTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/DistinctTest.java index aaf873e0c..503ac8d91 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/DistinctTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/DistinctTest.java @@ -24,68 +24,46 @@ public class DistinctTest { - @Test - public void testDistinct_001() throws Exception { - QueryTester - .build() - .withQueryPath("/query/distinct_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testDistinct_001() throws Exception { + QueryTester.build().withQueryPath("/query/distinct_001.sql").execute().checkSinkResult(); + } - @Test - public void testDistinct_002() throws Exception { - QueryTester - .build() - .withQueryPath("/query/distinct_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testDistinct_002() throws Exception { + QueryTester.build().withQueryPath("/query/distinct_002.sql").execute().checkSinkResult(); + } - @Test - public void testDistinct_003() throws Exception { - QueryTester - .build() - .withQueryPath("/query/distinct_003.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testDistinct_003() throws Exception { + QueryTester.build().withQueryPath("/query/distinct_003.sql").execute().checkSinkResult(); + } - @Test - public void testDistinct_004() throws Exception { - QueryTester - .build() - .withQueryPath("/query/distinct_004.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testDistinct_004() throws Exception { + QueryTester.build().withQueryPath("/query/distinct_004.sql").execute().checkSinkResult(); + } - @Test - public void testDistinct_005() throws Exception { - QueryTester - .build() - .withQueryPath("/query/distinct_005.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testDistinct_005() throws Exception { + QueryTester.build().withQueryPath("/query/distinct_005.sql").execute().checkSinkResult(); + } - @Test - public void testDistinct_006() throws Exception { - QueryTester - .build() - .withConfig(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), "4") - .withQueryPath("/query/distinct_006.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testDistinct_006() throws Exception { + QueryTester.build() + .withConfig(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), "4") + .withQueryPath("/query/distinct_006.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testDistinct_007() throws Exception { - QueryTester - .build() - .withConfig(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), "4") - .withQueryPath("/query/distinct_007.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testDistinct_007() throws Exception { + QueryTester.build() + .withConfig(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), "4") + .withQueryPath("/query/distinct_007.sql") + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/FilterTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/FilterTest.java index 55434f17b..bafe90443 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/FilterTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/FilterTest.java @@ -23,57 +23,33 @@ public class FilterTest { - @Test - public void testSqlFilter_001() throws Exception { - QueryTester - .build() - .withQueryPath("/query/filter_001.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testSqlFilter_002() throws Exception { - QueryTester - .build() - .withQueryPath("/query/filter_002.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testSqlFilter_003() throws Exception { - QueryTester - .build() - .withQueryPath("/query/filter_003.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testSqlFilter_005() throws Exception { - QueryTester - .build() - .withQueryPath("/query/filter_005.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testSqlFilter_006() throws Exception { - QueryTester - .build() - .withQueryPath("/query/filter_006.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testSqlFilter_007() throws Exception { - QueryTester - .build() - .withQueryPath("/query/filter_007.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSqlFilter_001() throws Exception { + QueryTester.build().withQueryPath("/query/filter_001.sql").execute().checkSinkResult(); + } + + @Test + public void testSqlFilter_002() throws Exception { + QueryTester.build().withQueryPath("/query/filter_002.sql").execute().checkSinkResult(); + } + + @Test + public void testSqlFilter_003() throws Exception { + QueryTester.build().withQueryPath("/query/filter_003.sql").execute().checkSinkResult(); + } + + @Test + public void testSqlFilter_005() throws Exception { + QueryTester.build().withQueryPath("/query/filter_005.sql").execute().checkSinkResult(); + } + + @Test + public void testSqlFilter_006() throws Exception { + QueryTester.build().withQueryPath("/query/filter_006.sql").execute().checkSinkResult(); + } + + @Test + public void testSqlFilter_007() throws Exception { + QueryTester.build().withQueryPath("/query/filter_007.sql").execute().checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLAlgorithmTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLAlgorithmTest.java index 04c62509f..ef7ee7d3c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLAlgorithmTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLAlgorithmTest.java @@ -21,6 +21,7 @@ import java.io.File; import java.io.IOException; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.keys.DSLConfigKeys; import org.apache.geaflow.common.config.keys.FrameworkConfigKeys; @@ -29,314 +30,256 @@ public class GQLAlgorithmTest { - private final String TEST_GRAPH_PATH = "/tmp/geaflow/dsl/algorithm/test/graph"; - - @Test - public void testAlgorithm_001() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_algorithm_001.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAlgorithm_002() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_algorithm_002.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAlgorithm_003() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_algorithm_003.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAlgorithm_004() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_algorithm_004.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAlgorithm_005() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_algorithm_005.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAlgorithm_006() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_algorithm_006.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAlgorithm_008() throws Exception { - QueryTester - .build() - .withQueryPath("/query/find_loop.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAlgorithmKHop() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_algorithm_007.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAlgorithmKCore() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_algorithm_kcore.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAlgorithmClosenessCentrality() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_algorithm_closeness_centrality.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAlgorithmWeakConnectedComponents() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_algorithm_wcc.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAlgorithmTriangleCount() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_algorithm_tc.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAlgorithmClusterCoefficient() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_algorithm_cluster_coefficient.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAlgorithmClusterCoefficientWithParams() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_algorithm_cluster_coefficient_with_params.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAlgorithmClusterCoefficientMedium() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_algorithm_cluster_coefficient_medium.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAlgorithmClusterCoefficientLarge() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_algorithm_cluster_coefficient_large.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAlgorithmLabelPropagation() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_algorithm_lpa.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAlgorithmConnectedComponents() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_algorithm_cc.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testIncGraphAlgorithm_001() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_algorithm_inc_001.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testIncGraphAlgorithm_002() throws Exception { - clearGraph(); - QueryTester - .build() - .withConfig(FileConfigKeys.ROOT.getKey(), TEST_GRAPH_PATH) - .withConfig(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT.getKey(), 1) - .withQueryPath("/query/gql_using_001_ddl.sql") - .execute() - .withConfig(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), 1) - .withQueryPath("/query/gql_algorithm_inc_002.sql") - .execute() - .checkSinkResult(); - clearGraph(); - } - - @Test - public void testIncGraphAlgorithm_003() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_algorithm_inc_003.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testIncGraphAlgorithm_004() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_algorithm_inc_004.sql") - .execute() - .compareWithOrder() - .checkSinkResult(); - } - - @Test - public void testIncGraphAlgorithm_assp() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_algorithm_inc_assp.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testIncWccVsSpark() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_test_demo_case_vs_spark.sql") - .execute() - .checkSinkResult(); - } - - public void testIncGraphAlgorithm_005() throws Exception { - QueryTester - .build() - .withWorkerNum(20) - .withDedupe(true) - .withQueryPath("/query/gql_algorithm_inc_005.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testIncGraphAlgorithm_006() throws Exception { - QueryTester - .build() - .withWorkerNum(20) - .withDedupe(true) - .withQueryPath("/query/gql_algorithm_inc_006.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testIncGraphAlgorithm_007() throws Exception { - QueryTester - .build() - .withWorkerNum(20) - .withDedupe(true) - .withQueryPath("/query/gql_algorithm_inc_007.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testIncGraphAlgorithm_008() throws Exception { - QueryTester - .build() - .withWorkerNum(20) - .withDedupe(true) - .withQueryPath("/query/gql_algorithm_inc_008.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAlgorithmCommonNeighbors() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_algorithm_common_neighbors.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAlgorithmJaccardSimilarity() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_algorithm_jaccard_similarity.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testEdgeIterator() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_edge_iterator_test.sql") - .execute() - .checkSinkResult(); - } - - private void clearGraph() throws IOException { - File file = new File(TEST_GRAPH_PATH); - if (file.exists()) { - FileUtils.deleteDirectory(file); - } - } + private final String TEST_GRAPH_PATH = "/tmp/geaflow/dsl/algorithm/test/graph"; + + @Test + public void testAlgorithm_001() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_algorithm_001.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testAlgorithm_002() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_algorithm_002.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testAlgorithm_003() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_algorithm_003.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testAlgorithm_004() throws Exception { + QueryTester.build().withQueryPath("/query/gql_algorithm_004.sql").execute().checkSinkResult(); + } + + @Test + public void testAlgorithm_005() throws Exception { + QueryTester.build().withQueryPath("/query/gql_algorithm_005.sql").execute().checkSinkResult(); + } + + @Test + public void testAlgorithm_006() throws Exception { + QueryTester.build().withQueryPath("/query/gql_algorithm_006.sql").execute().checkSinkResult(); + } + + @Test + public void testAlgorithm_008() throws Exception { + QueryTester.build().withQueryPath("/query/find_loop.sql").execute().checkSinkResult(); + } + + @Test + public void testAlgorithmKHop() throws Exception { + QueryTester.build().withQueryPath("/query/gql_algorithm_007.sql").execute().checkSinkResult(); + } + + @Test + public void testAlgorithmKCore() throws Exception { + QueryTester.build().withQueryPath("/query/gql_algorithm_kcore.sql").execute().checkSinkResult(); + } + + @Test + public void testAlgorithmClosenessCentrality() throws Exception { + QueryTester.build() + .withQueryPath("/query/gql_algorithm_closeness_centrality.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testAlgorithmWeakConnectedComponents() throws Exception { + QueryTester.build().withQueryPath("/query/gql_algorithm_wcc.sql").execute().checkSinkResult(); + } + + @Test + public void testAlgorithmTriangleCount() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_algorithm_tc.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testAlgorithmClusterCoefficient() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_algorithm_cluster_coefficient.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testAlgorithmClusterCoefficientWithParams() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_algorithm_cluster_coefficient_with_params.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testAlgorithmClusterCoefficientMedium() throws Exception { + QueryTester.build() + .withQueryPath("/query/gql_algorithm_cluster_coefficient_medium.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testAlgorithmClusterCoefficientLarge() throws Exception { + QueryTester.build() + .withQueryPath("/query/gql_algorithm_cluster_coefficient_large.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testAlgorithmLabelPropagation() throws Exception { + QueryTester.build().withQueryPath("/query/gql_algorithm_lpa.sql").execute().checkSinkResult(); + } + + @Test + public void testAlgorithmConnectedComponents() throws Exception { + QueryTester.build().withQueryPath("/query/gql_algorithm_cc.sql").execute().checkSinkResult(); + } + + @Test + public void testIncGraphAlgorithm_001() throws Exception { + QueryTester.build() + .withQueryPath("/query/gql_algorithm_inc_001.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testIncGraphAlgorithm_002() throws Exception { + clearGraph(); + QueryTester.build() + .withConfig(FileConfigKeys.ROOT.getKey(), TEST_GRAPH_PATH) + .withConfig(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT.getKey(), 1) + .withQueryPath("/query/gql_using_001_ddl.sql") + .execute() + .withConfig(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), 1) + .withQueryPath("/query/gql_algorithm_inc_002.sql") + .execute() + .checkSinkResult(); + clearGraph(); + } + + @Test + public void testIncGraphAlgorithm_003() throws Exception { + QueryTester.build() + .withQueryPath("/query/gql_algorithm_inc_003.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testIncGraphAlgorithm_004() throws Exception { + QueryTester.build() + .withQueryPath("/query/gql_algorithm_inc_004.sql") + .execute() + .compareWithOrder() + .checkSinkResult(); + } + + @Test + public void testIncGraphAlgorithm_assp() throws Exception { + QueryTester.build() + .withQueryPath("/query/gql_algorithm_inc_assp.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testIncWccVsSpark() throws Exception { + QueryTester.build() + .withQueryPath("/query/gql_test_demo_case_vs_spark.sql") + .execute() + .checkSinkResult(); + } + + public void testIncGraphAlgorithm_005() throws Exception { + QueryTester.build() + .withWorkerNum(20) + .withDedupe(true) + .withQueryPath("/query/gql_algorithm_inc_005.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testIncGraphAlgorithm_006() throws Exception { + QueryTester.build() + .withWorkerNum(20) + .withDedupe(true) + .withQueryPath("/query/gql_algorithm_inc_006.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testIncGraphAlgorithm_007() throws Exception { + QueryTester.build() + .withWorkerNum(20) + .withDedupe(true) + .withQueryPath("/query/gql_algorithm_inc_007.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testIncGraphAlgorithm_008() throws Exception { + QueryTester.build() + .withWorkerNum(20) + .withDedupe(true) + .withQueryPath("/query/gql_algorithm_inc_008.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testAlgorithmCommonNeighbors() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_algorithm_common_neighbors.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testAlgorithmJaccardSimilarity() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_algorithm_jaccard_similarity.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testEdgeIterator() throws Exception { + QueryTester.build() + .withQueryPath("/query/gql_edge_iterator_test.sql") + .execute() + .checkSinkResult(); + } + + private void clearGraph() throws IOException { + File file = new File(TEST_GRAPH_PATH); + if (file.exists()) { + FileUtils.deleteDirectory(file); + } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLConstraintTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLConstraintTest.java index 5f97d675c..2ae359809 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLConstraintTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLConstraintTest.java @@ -23,19 +23,13 @@ public class GQLConstraintTest { - @Test - public void testConstraint_001() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_constraint_001.sql") - .execute(); - } + @Test + public void testConstraint_001() throws Exception { + QueryTester.build().withQueryPath("/query/gql_constraint_001.sql").execute(); + } - @Test - public void testConstraint_002() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_constraint_002.sql") - .execute(); - } + @Test + public void testConstraint_002() throws Exception { + QueryTester.build().withQueryPath("/query/gql_constraint_002.sql").execute(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLContinueMatchTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLContinueMatchTest.java index c91bbc16d..1ee6501e8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLContinueMatchTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLContinueMatchTest.java @@ -23,53 +23,48 @@ public class GQLContinueMatchTest { - @Test - public void testContinueMatch_001() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_continue_match_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testContinueMatch_001() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_continue_match_001.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testContinueMatch_002() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_continue_match_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testContinueMatch_002() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_continue_match_002.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testContinueMatch_003() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_continue_match_003.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testContinueMatch_003() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_continue_match_003.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testContinueMatch_004() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_continue_match_004.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testContinueMatch_004() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_continue_match_004.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testContinueMatch_005() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_continue_match_005.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testContinueMatch_005() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_continue_match_005.sql") + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLDistinctTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLDistinctTest.java index 48f295509..04851cce6 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLDistinctTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLDistinctTest.java @@ -23,83 +23,75 @@ public class GQLDistinctTest { - @Test - public void testDistinct_001() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_distinct_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testDistinct_001() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_distinct_001.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testDistinct_002() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_distinct_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testDistinct_002() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_distinct_002.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testDistinct_003() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_distinct_003.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testDistinct_003() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_distinct_003.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testDistinct_004() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_distinct_004.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testDistinct_004() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_distinct_004.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testDistinct_005() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_distinct_005.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testDistinct_005() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_distinct_005.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testDistinct_006() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_distinct_006.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testDistinct_006() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_distinct_006.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testDistinct_007() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_distinct_007.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testDistinct_007() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_distinct_007.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testDistinct_008() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_distinct_008.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testDistinct_008() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_distinct_008.sql") + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLFilterTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLFilterTest.java index cb23aae1a..6e570cf37 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLFilterTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLFilterTest.java @@ -23,83 +23,75 @@ public class GQLFilterTest { - @Test - public void testFilter_001() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_filter_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testFilter_001() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_filter_001.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testFilter_002() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_filter_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testFilter_002() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_filter_002.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testFilter_003() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_filter_003.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testFilter_003() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_filter_003.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testFilter_004() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_filter_004.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testFilter_004() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_filter_004.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testFilter_005() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_filter_005.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testFilter_005() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_filter_005.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testFilter_006() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_filter_006.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testFilter_006() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_filter_006.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testFilter_007() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_filter_007.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testFilter_007() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_filter_007.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testFilter_008() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_filter_008.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testFilter_008() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_filter_008.sql") + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLGlobalVariableTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLGlobalVariableTest.java index f656f1308..a70de05a0 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLGlobalVariableTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLGlobalVariableTest.java @@ -23,13 +23,12 @@ public class GQLGlobalVariableTest { - @Test - public void testVertexGlobalVariable_001() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_global_variable_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testVertexGlobalVariable_001() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_global_variable_001.sql") + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLInsertTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLInsertTest.java index 2b750f3a7..c425d6a73 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLInsertTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLInsertTest.java @@ -26,114 +26,102 @@ public class GQLInsertTest { - @Test - public void testInsertAndQuery_001() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_insert_and_graph_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testInsertAndQuery_001() throws Exception { + QueryTester.build() + .withQueryPath("/query/gql_insert_and_graph_001.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testInsertAndQuery_002() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_insert_and_graph_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testInsertAndQuery_002() throws Exception { + QueryTester.build() + .withQueryPath("/query/gql_insert_and_graph_002.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testInsertAndQuery_003() throws Exception { - QueryTester - .build() - .withConfig(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT.getKey(), "2") - .withQueryPath("/query/gql_insert_and_graph_003.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testInsertAndQuery_003() throws Exception { + QueryTester.build() + .withConfig(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT.getKey(), "2") + .withQueryPath("/query/gql_insert_and_graph_003.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testInsertAndQuery_004() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_insert_and_graph_004.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testInsertAndQuery_004() throws Exception { + QueryTester.build() + .withQueryPath("/query/gql_insert_and_graph_004.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testInsertAndQueryWithRequest_001() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_insert_and_query_with_request_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testInsertAndQueryWithRequest_001() throws Exception { + QueryTester.build() + .withQueryPath("/query/gql_insert_and_query_with_request_001.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testInsertAndQueryWithRequest_002() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_insert_and_query_with_request_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testInsertAndQueryWithRequest_002() throws Exception { + QueryTester.build() + .withQueryPath("/query/gql_insert_and_query_with_request_002.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testInsertAndQueryWithRequest_003() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_insert_and_query_with_request_003.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testInsertAndQueryWithRequest_003() throws Exception { + QueryTester.build() + .withQueryPath("/query/gql_insert_and_query_with_request_003.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testInsertAndQueryWithSubQuery_001() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_insert_and_query_with_subquery_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testInsertAndQueryWithSubQuery_001() throws Exception { + QueryTester.build() + .withQueryPath("/query/gql_insert_and_query_with_subquery_001.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testInsertAndQueryWithSubQuery_002() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_insert_and_query_with_subquery_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testInsertAndQueryWithSubQuery_002() throws Exception { + QueryTester.build() + .withQueryPath("/query/gql_insert_and_query_with_subquery_002.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testInsertAndQuery_005() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_insert_and_graph_005.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testInsertAndQuery_005() throws Exception { + QueryTester.build() + .withQueryPath("/query/gql_insert_and_graph_005.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testInsertAndQuery_006() throws Exception { - QueryTester - .build() - .withConfig(PaimonConfigKeys.PAIMON_STORE_TABLE_AUTO_CREATE_ENABLE.getKey(), "true") - .withQueryPath("/query/gql_insert_and_graph_006.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testInsertAndQuery_006() throws Exception { + QueryTester.build() + .withConfig(PaimonConfigKeys.PAIMON_STORE_TABLE_AUTO_CREATE_ENABLE.getKey(), "true") + .withQueryPath("/query/gql_insert_and_graph_006.sql") + .execute() + .checkSinkResult(); + } - @Test(expectedExceptions = GeaflowRuntimeException.class) - public void testInsertAndQuery_007() throws Exception { - QueryTester - .build() - .withConfig(PaimonConfigKeys.PAIMON_STORE_TABLE_AUTO_CREATE_ENABLE.getKey(), "true") - .withQueryPath("/query/gql_insert_and_graph_007.sql") - .execute() - .checkSinkResult(); - } + @Test(expectedExceptions = GeaflowRuntimeException.class) + public void testInsertAndQuery_007() throws Exception { + QueryTester.build() + .withConfig(PaimonConfigKeys.PAIMON_STORE_TABLE_AUTO_CREATE_ENABLE.getKey(), "true") + .withQueryPath("/query/gql_insert_and_graph_007.sql") + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLJoinTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLJoinTest.java index 845b0fa30..63c8a5583 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLJoinTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLJoinTest.java @@ -23,73 +23,66 @@ public class GQLJoinTest { - @Test - public void testJoin_001() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_join_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testJoin_001() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_join_001.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testJoin_002() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_join_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testJoin_002() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_join_002.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testJoin_003() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_join_003.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testJoin_003() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_join_003.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testJoin_004() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_join_004.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testJoin_004() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_join_004.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testJoin_005() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_join_005.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testJoin_005() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_join_005.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testJoin_006() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_join_006.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testJoin_006() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_join_006.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testJoin_007() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_join_007.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testJoin_007() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_join_007.sql") + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLLetTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLLetTest.java index d233b2a9e..bc1447b5a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLLetTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLLetTest.java @@ -23,33 +23,30 @@ public class GQLLetTest { - @Test - public void testLet_001() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_let_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testLet_001() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_let_001.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testLet_002() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_let_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testLet_002() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_let_002.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testLet_003() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_let_003.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testLet_003() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_let_003.sql") + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLLoopMatchTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLLoopMatchTest.java index 2379e99bf..7992289da 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLLoopMatchTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLLoopMatchTest.java @@ -23,23 +23,21 @@ public class GQLLoopMatchTest { - @Test - public void testMatchLoop_001() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_loop_match_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testMatchLoop_001() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_loop_match_001.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testMatchLoop_002() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_loop_match_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testMatchLoop_002() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_loop_match_002.sql") + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLMatchTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLMatchTest.java index 90b790d66..ef51f058f 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLMatchTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLMatchTest.java @@ -23,139 +23,113 @@ public class GQLMatchTest { - @Test - public void testMatch_001() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_match_001.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatch_002() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_match_002.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatch_003() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_match_003.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatch_004() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_match_004.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatch_005() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_match_005.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatch_006() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_match_006.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatch_007() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_match_007.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatch_008() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_match_008.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatch_009() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_match_009.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatch_010() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_match_010.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatch_011() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_match_011.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatch_012() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_match_012.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatch_013() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_match_013.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatch_014() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_match_014.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testMatch_001() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_match_001.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testMatch_002() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_match_002.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testMatch_003() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_match_003.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testMatch_004() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_match_004.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testMatch_005() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_match_005.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testMatch_006() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_match_006.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testMatch_007() throws Exception { + QueryTester.build().withQueryPath("/query/gql_match_007.sql").execute().checkSinkResult(); + } + + @Test + public void testMatch_008() throws Exception { + QueryTester.build().withQueryPath("/query/gql_match_008.sql").execute().checkSinkResult(); + } + + @Test + public void testMatch_009() throws Exception { + QueryTester.build().withQueryPath("/query/gql_match_009.sql").execute().checkSinkResult(); + } + + @Test + public void testMatch_010() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_match_010.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testMatch_011() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_match_011.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testMatch_012() throws Exception { + QueryTester.build().withQueryPath("/query/gql_match_012.sql").execute().checkSinkResult(); + } + + @Test + public void testMatch_013() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_match_013.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testMatch_014() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_match_014.sql") + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLParameterRequestTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLParameterRequestTest.java index 258c39478..b69001995 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLParameterRequestTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLParameterRequestTest.java @@ -23,63 +23,57 @@ public class GQLParameterRequestTest { - @Test - public void testParameterRequest_001() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_parameter_request_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testParameterRequest_001() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_parameter_request_001.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testParameterRequest_002() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_parameter_request_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testParameterRequest_002() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_parameter_request_002.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testParameterRequest_003() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_parameter_request_003.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testParameterRequest_003() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_parameter_request_003.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testIdOnlyParameterRequest_001() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_idonly_parameter_request_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testIdOnlyParameterRequest_001() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_idonly_parameter_request_001.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testIdOnlyParameterRequest_002() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_idonly_parameter_request_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testIdOnlyParameterRequest_002() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_idonly_parameter_request_002.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testIdOnlyParameterRequest_003() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_idonly_parameter_request_003.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testIdOnlyParameterRequest_003() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_idonly_parameter_request_003.sql") + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLReturnTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLReturnTest.java index 26ded1b0a..511b689ce 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLReturnTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLReturnTest.java @@ -23,164 +23,148 @@ public class GQLReturnTest { - @Test - public void testReturn_001() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_return_001.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testReturn_002() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_return_002.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testReturn_003() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_return_003.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testReturn_004() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_return_004.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testReturn_005() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_return_005.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testReturn_006() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_return_006.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testReturn_007() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_return_007.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testReturn_008() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_return_008.sql") - .compareWithOrder() - .execute() - .checkSinkResult(); - } - - @Test - public void testReturn_009() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_return_009.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testReturn_010() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_return_010.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testReturn_011() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_return_011.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testReturn_012() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_return_012.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testReturn_013() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_return_013.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testReturn_014() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_return_014.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testReturn_015() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_return_015.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testReturn_016() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_return_016.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testReturn_001() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_return_001.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testReturn_002() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_return_002.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testReturn_003() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_return_003.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testReturn_004() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_return_004.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testReturn_005() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_return_005.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testReturn_006() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_return_006.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testReturn_007() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_return_007.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testReturn_008() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_return_008.sql") + .compareWithOrder() + .execute() + .checkSinkResult(); + } + + @Test + public void testReturn_009() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_return_009.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testReturn_010() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_return_010.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testReturn_011() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_return_011.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testReturn_012() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_return_012.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testReturn_013() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_return_013.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testReturn_014() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_return_014.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testReturn_015() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_return_015.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testReturn_016() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_return_016.sql") + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLSortTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLSortTest.java index b15634126..51aedfd73 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLSortTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLSortTest.java @@ -23,107 +23,97 @@ public class GQLSortTest { - @Test - public void testSort_001() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_sort_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSort_001() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_sort_001.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testSort_002() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_sort_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSort_002() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_sort_002.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testSort_003() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_sort_003.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSort_003() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_sort_003.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testSort_004() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_sort_004.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSort_004() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_sort_004.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testSort_005() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_sort_005.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSort_005() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_sort_005.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testSort_006() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_sort_006.sql") - .compareWithOrder() - .execute() - .checkSinkResult(); - } + @Test + public void testSort_006() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_sort_006.sql") + .compareWithOrder() + .execute() + .checkSinkResult(); + } - @Test - public void testSort_007() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_sort_007.sql") - .compareWithOrder() - .execute() - .checkSinkResult(); - } + @Test + public void testSort_007() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_sort_007.sql") + .compareWithOrder() + .execute() + .checkSinkResult(); + } - @Test - public void testSort_008() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_sort_008.sql") - .compareWithOrder() - .execute() - .checkSinkResult(); - } + @Test + public void testSort_008() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_sort_008.sql") + .compareWithOrder() + .execute() + .checkSinkResult(); + } - @Test - public void testSort_009() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_sort_009.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSort_009() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_sort_009.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testSort_010() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_sort_010.sql") - .compareWithOrder() - .execute() - .checkSinkResult(); - } + @Test + public void testSort_010() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_sort_010.sql") + .compareWithOrder() + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLSourceDestinationTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLSourceDestinationTest.java index fee0fd1cc..9630ccc18 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLSourceDestinationTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLSourceDestinationTest.java @@ -23,53 +23,48 @@ public class GQLSourceDestinationTest { - @Test - public void testSourceDestination_001() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_source_destination_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSourceDestination_001() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_source_destination_001.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testSourceDestination_002() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_source_destination_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSourceDestination_002() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_source_destination_002.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testSourceDestination_003() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_source_destination_003.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSourceDestination_003() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_source_destination_003.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testSourceDestination_004() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_source_destination_004.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSourceDestination_004() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_source_destination_004.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testSourceDestination_005() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_source_destination_005.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSourceDestination_005() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_source_destination_005.sql") + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLStandardTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLStandardTest.java index be6a69221..dc791393c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLStandardTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLStandardTest.java @@ -23,113 +23,102 @@ public class GQLStandardTest { - @Test - public void testStandard_001() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/movie_graph.sql") - .withQueryPath("/query/standard/standard_gql_001.sql") - .execute() - .checkSinkResult("standard"); - } + @Test + public void testStandard_001() throws Exception { + QueryTester.build() + .withGraphDefine("/query/movie_graph.sql") + .withQueryPath("/query/standard/standard_gql_001.sql") + .execute() + .checkSinkResult("standard"); + } - @Test - public void testStandard_002() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/movie_graph.sql") - .withQueryPath("/query/standard/standard_gql_002.sql") - .execute() - .checkSinkResult("standard"); - } + @Test + public void testStandard_002() throws Exception { + QueryTester.build() + .withGraphDefine("/query/movie_graph.sql") + .withQueryPath("/query/standard/standard_gql_002.sql") + .execute() + .checkSinkResult("standard"); + } - @Test - public void testStandard_003() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/movie_graph.sql") - .withQueryPath("/query/standard/standard_gql_003.sql") - .execute() - .checkSinkResult("standard"); - } + @Test + public void testStandard_003() throws Exception { + QueryTester.build() + .withGraphDefine("/query/movie_graph.sql") + .withQueryPath("/query/standard/standard_gql_003.sql") + .execute() + .checkSinkResult("standard"); + } - @Test - public void testStandard_004() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/movie_graph.sql") - .withQueryPath("/query/standard/standard_gql_004.sql") - .execute() - .checkSinkResult("standard"); - } + @Test + public void testStandard_004() throws Exception { + QueryTester.build() + .withGraphDefine("/query/movie_graph.sql") + .withQueryPath("/query/standard/standard_gql_004.sql") + .execute() + .checkSinkResult("standard"); + } - @Test - public void testStandard_005() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/movie_graph.sql") - .withQueryPath("/query/standard/standard_gql_005.sql") - .execute() - .checkSinkResult("standard"); - } + @Test + public void testStandard_005() throws Exception { + QueryTester.build() + .withGraphDefine("/query/movie_graph.sql") + .withQueryPath("/query/standard/standard_gql_005.sql") + .execute() + .checkSinkResult("standard"); + } - @Test - public void testStandard_006() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/movie_graph.sql") - .withQueryPath("/query/standard/standard_gql_006.sql") - .execute() - .checkSinkResult("standard"); - } + @Test + public void testStandard_006() throws Exception { + QueryTester.build() + .withGraphDefine("/query/movie_graph.sql") + .withQueryPath("/query/standard/standard_gql_006.sql") + .execute() + .checkSinkResult("standard"); + } - @Test - public void testStandard_007() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/movie_graph.sql") - .withQueryPath("/query/standard/standard_gql_007.sql") - .execute() - .checkSinkResult("standard"); - } + @Test + public void testStandard_007() throws Exception { + QueryTester.build() + .withGraphDefine("/query/movie_graph.sql") + .withQueryPath("/query/standard/standard_gql_007.sql") + .execute() + .checkSinkResult("standard"); + } - @Test - public void testStandard_008() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/movie_graph.sql") - .withQueryPath("/query/standard/standard_gql_008.sql") - .execute() - .checkSinkResult("standard"); - } + @Test + public void testStandard_008() throws Exception { + QueryTester.build() + .withGraphDefine("/query/movie_graph.sql") + .withQueryPath("/query/standard/standard_gql_008.sql") + .execute() + .checkSinkResult("standard"); + } - @Test - public void testStandard_009() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/movie_graph.sql") - .withQueryPath("/query/standard/standard_gql_009.sql") - .execute() - .checkSinkResult("standard"); - } + @Test + public void testStandard_009() throws Exception { + QueryTester.build() + .withGraphDefine("/query/movie_graph.sql") + .withQueryPath("/query/standard/standard_gql_009.sql") + .execute() + .checkSinkResult("standard"); + } - @Test - public void testStandard_010() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/movie_graph.sql") - .withQueryPath("/query/standard/standard_gql_010.sql") - .execute() - .checkSinkResult("standard"); - } + @Test + public void testStandard_010() throws Exception { + QueryTester.build() + .withGraphDefine("/query/movie_graph.sql") + .withQueryPath("/query/standard/standard_gql_010.sql") + .execute() + .checkSinkResult("standard"); + } - @Test - public void testStandard_011() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/movie_graph.sql") - .withQueryPath("/query/standard/standard_gql_011.sql") - .execute() - .checkSinkResult("standard"); - } + @Test + public void testStandard_011() throws Exception { + QueryTester.build() + .withGraphDefine("/query/movie_graph.sql") + .withQueryPath("/query/standard/standard_gql_011.sql") + .execute() + .checkSinkResult("standard"); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLSubQueryTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLSubQueryTest.java index ca71f8077..220423bc0 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLSubQueryTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLSubQueryTest.java @@ -23,114 +23,102 @@ public class GQLSubQueryTest { - @Test - public void testSubQuery_001() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_subquery_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSubQuery_001() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_subquery_001.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testSubQuery_002() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_subquery_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSubQuery_002() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_subquery_002.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testSubQuery_003() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_subquery_003.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSubQuery_003() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_subquery_003.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testSubQuery_004() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_subquery_004.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSubQuery_004() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_subquery_004.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testSubQuery_005() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_subquery_005.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSubQuery_005() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_subquery_005.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testSubQuery_006() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_subquery_006.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSubQuery_006() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_subquery_006.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testSubQuery_007() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_subquery_007.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSubQuery_007() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_subquery_007.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testSubQuery_008() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_subquery_008.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSubQuery_008() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_subquery_008.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testSubQuery_009() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_subquery_009.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSubQuery_009() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_subquery_009.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testSubQuery_010() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_subquery_010.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testSubQuery_011() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_subquery_011.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSubQuery_010() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_subquery_010.sql") + .execute() + .checkSinkResult(); + } + @Test + public void testSubQuery_011() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_subquery_011.sql") + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLUnionTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLUnionTest.java index 7039a7732..b70a178be 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLUnionTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLUnionTest.java @@ -23,143 +23,129 @@ public class GQLUnionTest { - @Test - public void testUnion_001() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_union_001.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testUnion_002() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_union_002.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testUnion_003() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_union_003.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testUnion_004() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_union_004.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testUnion_005() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_union_005.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testUnion_006() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_union_006.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testUnion_007() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_union_007.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testUnion_008() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_union_008.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testUnion_009() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_union_009.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testUnion_010() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_union_010.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testUnion_011() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_union_011.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testUnion_012() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_union_012.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testUnion_013() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_union_013.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testUnion_014() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_union_014.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testUnion_001() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_union_001.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testUnion_002() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_union_002.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testUnion_003() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_union_003.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testUnion_004() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_union_004.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testUnion_005() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_union_005.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testUnion_006() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_union_006.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testUnion_007() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_union_007.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testUnion_008() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_union_008.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testUnion_009() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_union_009.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testUnion_010() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_union_010.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testUnion_011() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_union_011.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testUnion_012() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_union_012.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testUnion_013() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_union_013.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testUnion_014() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_union_014.sql") + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLUsingTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLUsingTest.java index cfabc146b..5652f8292 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLUsingTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLUsingTest.java @@ -24,42 +24,29 @@ public class GQLUsingTest { - @Test - public void testUsing_001() throws Exception { - QueryTester - .build() - .withConfig(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT.getKey(), 1) - .withQueryPath("/query/gql_using_001_ddl.sql") - .execute() - .withQueryPath("/query/gql_using_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testUsing_001() throws Exception { + QueryTester.build() + .withConfig(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT.getKey(), 1) + .withQueryPath("/query/gql_using_001_ddl.sql") + .execute() + .withQueryPath("/query/gql_using_001.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testUsing_002() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_using_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testUsing_002() throws Exception { + QueryTester.build().withQueryPath("/query/gql_using_002.sql").execute().checkSinkResult(); + } - @Test - public void testUsingTemporary_001() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_temporary_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testUsingTemporary_001() throws Exception { + QueryTester.build().withQueryPath("/query/gql_temporary_001.sql").execute().checkSinkResult(); + } - @Test - public void testUsingTemporary_002() throws Exception { - QueryTester - .build() - .withQueryPath("/query/gql_temporary_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testUsingTemporary_002() throws Exception { + QueryTester.build().withQueryPath("/query/gql_temporary_002.sql").execute().checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLWriteAndReadTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLWriteAndReadTest.java index 60a4bd03e..dc7b0f743 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLWriteAndReadTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLWriteAndReadTest.java @@ -25,42 +25,39 @@ public class GQLWriteAndReadTest { - @Test - public void testInsertDynamicGraph_001() throws Exception { - QueryTester - .build() - .withConfig(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), 1) - .withQueryPath("/query/gql_graph_write_001.sql") - .execute() // write data to graph - .withQueryPath("/query/gql_graph_read_001.sql") - .execute() // query the graph - .checkSinkResult(); - } + @Test + public void testInsertDynamicGraph_001() throws Exception { + QueryTester.build() + .withConfig(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), 1) + .withQueryPath("/query/gql_graph_write_001.sql") + .execute() // write data to graph + .withQueryPath("/query/gql_graph_read_001.sql") + .execute() // query the graph + .checkSinkResult(); + } - @Test - public void testInsertDynamicGraph_002() throws Exception { - QueryTester - .build() - .withConfig(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), -1) - .withQueryPath("/query/gql_graph_write_002.sql") - .withConfig(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT.getKey(), 1) - .execute() // write data to graph - .withQueryPath("/query/gql_graph_write_003.sql") - .execute() // write data to graph - .withQueryPath("/query/gql_graph_read_002.sql") - .execute() // query the graph - .checkSinkResult(); - } + @Test + public void testInsertDynamicGraph_002() throws Exception { + QueryTester.build() + .withConfig(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), -1) + .withQueryPath("/query/gql_graph_write_002.sql") + .withConfig(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT.getKey(), 1) + .execute() // write data to graph + .withQueryPath("/query/gql_graph_write_003.sql") + .execute() // write data to graph + .withQueryPath("/query/gql_graph_read_002.sql") + .execute() // query the graph + .checkSinkResult(); + } - @Test - public void testInsertStaticGraph_001() throws Exception { - QueryTester - .build() - .withConfig(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT.getKey(), 1) - .withQueryPath("/query/gql_static_graph_001.sql") - .execute() // write data to graph - .withQueryPath("/query/gql_static_graph_read_001.sql") - .execute() // query the graph - .checkSinkResult(); - } + @Test + public void testInsertStaticGraph_001() throws Exception { + QueryTester.build() + .withConfig(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT.getKey(), 1) + .withQueryPath("/query/gql_static_graph_001.sql") + .execute() // write data to graph + .withQueryPath("/query/gql_static_graph_read_001.sql") + .execute() // query the graph + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/IncMSTPerformanceTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/IncMSTPerformanceTest.java index 44d7ab7bc..cd7208afe 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/IncMSTPerformanceTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/IncMSTPerformanceTest.java @@ -19,215 +19,210 @@ package org.apache.geaflow.dsl.runtime.query; -import org.apache.geaflow.common.config.keys.DSLConfigKeys; -import org.apache.geaflow.common.config.keys.FrameworkConfigKeys; -import org.apache.geaflow.file.FileConfigKeys; import java.io.File; import java.io.IOException; import java.util.concurrent.TimeUnit; + import org.apache.commons.io.FileUtils; -import org.testng.annotations.Test; -import org.testng.annotations.BeforeClass; import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; /** - * IncMST algorithm performance test class - * Test algorithm performance in large graph scenarios - * + * IncMST algorithm performance test class Test algorithm performance in large graph scenarios + * * @author Geaflow Team */ public class IncMSTPerformanceTest { - private final String TEST_GRAPH_PATH = "/tmp/geaflow/dsl/inc_mst/perf_test/graph"; - private long startTime; - private long endTime; - - @BeforeClass - public void setUp() throws IOException { - // Clean up test directory - FileUtils.deleteDirectory(new File(TEST_GRAPH_PATH)); - System.out.println("=== IncMST Performance Test Setup Complete ==="); - } - - @AfterClass - public void tearDown() throws IOException { - // Clean up test directory - FileUtils.deleteDirectory(new File(TEST_GRAPH_PATH)); - System.out.println("=== IncMST Performance Test Cleanup Complete ==="); - } - - @Test - public void testIncMST_001_SmallGraphPerformance() throws Exception { - System.out.println("Starting Small Graph Performance Test..."); - startTime = System.nanoTime(); - - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_inc_mst_perf_001.sql") - .execute() - .checkSinkResult(); - - endTime = System.nanoTime(); - printPerformanceMetrics("Small Graph (Modern)", startTime, endTime); - } - - @Test - public void testIncMST_002_MediumGraphPerformance() throws Exception { - System.out.println("Starting Medium Graph Performance Test..."); - startTime = System.nanoTime(); - - QueryTester - .build() - .withGraphDefine("/query/medium_graph.sql") - .withQueryPath("/query/gql_inc_mst_perf_002.sql") - .execute() - .checkSinkResult(); - - endTime = System.nanoTime(); - printPerformanceMetrics("Medium Graph (1K vertices)", startTime, endTime); - } - - @Test - public void testIncMST_003_LargeGraphPerformance() throws Exception { - System.out.println("Starting Large Graph Performance Test..."); - startTime = System.nanoTime(); - - QueryTester - .build() - .withGraphDefine("/query/large_graph.sql") - .withQueryPath("/query/gql_inc_mst_perf_003.sql") - .execute() - .checkSinkResult(); - - endTime = System.nanoTime(); - printPerformanceMetrics("Large Graph (10K vertices)", startTime, endTime); - } - - @Test - public void testIncMST_004_IncrementalUpdatePerformance() throws Exception { - System.out.println("Starting Incremental Update Performance Test..."); - startTime = System.nanoTime(); - - QueryTester - .build() - .withGraphDefine("/query/dynamic_graph.sql") - .withQueryPath("/query/gql_inc_mst_perf_004.sql") - .execute() - .checkSinkResult(); - - endTime = System.nanoTime(); - printPerformanceMetrics("Incremental Update", startTime, endTime); - } - - @Test - public void testIncMST_005_ConvergencePerformance() throws Exception { - System.out.println("Starting Convergence Performance Test..."); - startTime = System.nanoTime(); - - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_inc_mst_perf_005.sql") - .execute() - .checkSinkResult(); - - endTime = System.nanoTime(); - printPerformanceMetrics("Convergence Test", startTime, endTime); - } - - @Test - public void testIncMST_006_MemoryEfficiency() throws Exception { - System.out.println("Starting Memory Efficiency Test..."); - long initialMemory = getCurrentMemoryUsage(); - startTime = System.nanoTime(); - - QueryTester - .build() - .withGraphDefine("/query/large_graph.sql") - .withQueryPath("/query/gql_inc_mst_perf_006.sql") - .execute() - .checkSinkResult(); - - endTime = System.nanoTime(); - long finalMemory = getCurrentMemoryUsage(); - printPerformanceMetrics("Memory Efficiency", startTime, endTime); - printMemoryMetrics("Memory Efficiency", initialMemory, finalMemory); - } - - @Test - public void testIncMST_007_ScalabilityTest() throws Exception { - System.out.println("Starting Scalability Test..."); - startTime = System.nanoTime(); - - QueryTester - .build() - .withGraphDefine("/query/scalability_graph.sql") - .withQueryPath("/query/gql_inc_mst_perf_007.sql") - .execute() - .checkSinkResult(); - - endTime = System.nanoTime(); - printPerformanceMetrics("Scalability Test (100K vertices)", startTime, endTime); - } - - /** - * Print performance metrics - * @param testName Test name - * @param startTime Start time (nanoseconds) - * @param endTime End time (nanoseconds) - */ - private void printPerformanceMetrics(String testName, long startTime, long endTime) { - long durationNano = endTime - startTime; - long durationMs = TimeUnit.NANOSECONDS.toMillis(durationNano); - long durationSec = TimeUnit.NANOSECONDS.toSeconds(durationNano); - - System.out.println("=== Performance Metrics for " + testName + " ==="); - System.out.println("Execution Time: " + durationMs + " ms (" + durationSec + " seconds)"); - System.out.println("Throughput: " + String.format("%.2f", 1000.0 / durationMs) + " operations/ms"); - System.out.println("========================================"); - } - - /** - * Print memory metrics - * @param testName Test name - * @param initialMemory Initial memory usage (bytes) - * @param finalMemory Final memory usage (bytes) - */ - private void printMemoryMetrics(String testName, long initialMemory, long finalMemory) { - long memoryUsed = finalMemory - initialMemory; - double memoryUsedMB = memoryUsed / (1024.0 * 1024.0); - - System.out.println("=== Memory Metrics for " + testName + " ==="); - System.out.println("Memory Used: " + String.format("%.2f", memoryUsedMB) + " MB"); - System.out.println("Initial Memory: " + formatMemorySize(initialMemory)); - System.out.println("Final Memory: " + formatMemorySize(finalMemory)); - System.out.println("========================================"); - } - - /** - * Get current memory usage - * @return Memory usage (bytes) - */ - private long getCurrentMemoryUsage() { - Runtime runtime = Runtime.getRuntime(); - return runtime.totalMemory() - runtime.freeMemory(); - } - - /** - * Format memory size - * @param bytes Number of bytes - * @return Formatted string - */ - private String formatMemorySize(long bytes) { - if (bytes < 1024) { - return bytes + " B"; - } else if (bytes < 1024 * 1024) { - return String.format("%.2f KB", bytes / 1024.0); - } else if (bytes < 1024 * 1024 * 1024) { - return String.format("%.2f MB", bytes / (1024.0 * 1024.0)); - } else { - return String.format("%.2f GB", bytes / (1024.0 * 1024.0 * 1024.0)); - } + private final String TEST_GRAPH_PATH = "/tmp/geaflow/dsl/inc_mst/perf_test/graph"; + private long startTime; + private long endTime; + + @BeforeClass + public void setUp() throws IOException { + // Clean up test directory + FileUtils.deleteDirectory(new File(TEST_GRAPH_PATH)); + System.out.println("=== IncMST Performance Test Setup Complete ==="); + } + + @AfterClass + public void tearDown() throws IOException { + // Clean up test directory + FileUtils.deleteDirectory(new File(TEST_GRAPH_PATH)); + System.out.println("=== IncMST Performance Test Cleanup Complete ==="); + } + + @Test + public void testIncMST_001_SmallGraphPerformance() throws Exception { + System.out.println("Starting Small Graph Performance Test..."); + startTime = System.nanoTime(); + + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_inc_mst_perf_001.sql") + .execute() + .checkSinkResult(); + + endTime = System.nanoTime(); + printPerformanceMetrics("Small Graph (Modern)", startTime, endTime); + } + + @Test + public void testIncMST_002_MediumGraphPerformance() throws Exception { + System.out.println("Starting Medium Graph Performance Test..."); + startTime = System.nanoTime(); + + QueryTester.build() + .withGraphDefine("/query/medium_graph.sql") + .withQueryPath("/query/gql_inc_mst_perf_002.sql") + .execute() + .checkSinkResult(); + + endTime = System.nanoTime(); + printPerformanceMetrics("Medium Graph (1K vertices)", startTime, endTime); + } + + @Test + public void testIncMST_003_LargeGraphPerformance() throws Exception { + System.out.println("Starting Large Graph Performance Test..."); + startTime = System.nanoTime(); + + QueryTester.build() + .withGraphDefine("/query/large_graph.sql") + .withQueryPath("/query/gql_inc_mst_perf_003.sql") + .execute() + .checkSinkResult(); + + endTime = System.nanoTime(); + printPerformanceMetrics("Large Graph (10K vertices)", startTime, endTime); + } + + @Test + public void testIncMST_004_IncrementalUpdatePerformance() throws Exception { + System.out.println("Starting Incremental Update Performance Test..."); + startTime = System.nanoTime(); + + QueryTester.build() + .withGraphDefine("/query/dynamic_graph.sql") + .withQueryPath("/query/gql_inc_mst_perf_004.sql") + .execute() + .checkSinkResult(); + + endTime = System.nanoTime(); + printPerformanceMetrics("Incremental Update", startTime, endTime); + } + + @Test + public void testIncMST_005_ConvergencePerformance() throws Exception { + System.out.println("Starting Convergence Performance Test..."); + startTime = System.nanoTime(); + + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_inc_mst_perf_005.sql") + .execute() + .checkSinkResult(); + + endTime = System.nanoTime(); + printPerformanceMetrics("Convergence Test", startTime, endTime); + } + + @Test + public void testIncMST_006_MemoryEfficiency() throws Exception { + System.out.println("Starting Memory Efficiency Test..."); + long initialMemory = getCurrentMemoryUsage(); + startTime = System.nanoTime(); + + QueryTester.build() + .withGraphDefine("/query/large_graph.sql") + .withQueryPath("/query/gql_inc_mst_perf_006.sql") + .execute() + .checkSinkResult(); + + endTime = System.nanoTime(); + long finalMemory = getCurrentMemoryUsage(); + printPerformanceMetrics("Memory Efficiency", startTime, endTime); + printMemoryMetrics("Memory Efficiency", initialMemory, finalMemory); + } + + @Test + public void testIncMST_007_ScalabilityTest() throws Exception { + System.out.println("Starting Scalability Test..."); + startTime = System.nanoTime(); + + QueryTester.build() + .withGraphDefine("/query/scalability_graph.sql") + .withQueryPath("/query/gql_inc_mst_perf_007.sql") + .execute() + .checkSinkResult(); + + endTime = System.nanoTime(); + printPerformanceMetrics("Scalability Test (100K vertices)", startTime, endTime); + } + + /** + * Print performance metrics + * + * @param testName Test name + * @param startTime Start time (nanoseconds) + * @param endTime End time (nanoseconds) + */ + private void printPerformanceMetrics(String testName, long startTime, long endTime) { + long durationNano = endTime - startTime; + long durationMs = TimeUnit.NANOSECONDS.toMillis(durationNano); + long durationSec = TimeUnit.NANOSECONDS.toSeconds(durationNano); + + System.out.println("=== Performance Metrics for " + testName + " ==="); + System.out.println("Execution Time: " + durationMs + " ms (" + durationSec + " seconds)"); + System.out.println( + "Throughput: " + String.format("%.2f", 1000.0 / durationMs) + " operations/ms"); + System.out.println("========================================"); + } + + /** + * Print memory metrics + * + * @param testName Test name + * @param initialMemory Initial memory usage (bytes) + * @param finalMemory Final memory usage (bytes) + */ + private void printMemoryMetrics(String testName, long initialMemory, long finalMemory) { + long memoryUsed = finalMemory - initialMemory; + double memoryUsedMB = memoryUsed / (1024.0 * 1024.0); + + System.out.println("=== Memory Metrics for " + testName + " ==="); + System.out.println("Memory Used: " + String.format("%.2f", memoryUsedMB) + " MB"); + System.out.println("Initial Memory: " + formatMemorySize(initialMemory)); + System.out.println("Final Memory: " + formatMemorySize(finalMemory)); + System.out.println("========================================"); + } + + /** + * Get current memory usage + * + * @return Memory usage (bytes) + */ + private long getCurrentMemoryUsage() { + Runtime runtime = Runtime.getRuntime(); + return runtime.totalMemory() - runtime.freeMemory(); + } + + /** + * Format memory size + * + * @param bytes Number of bytes + * @return Formatted string + */ + private String formatMemorySize(long bytes) { + if (bytes < 1024) { + return bytes + " B"; + } else if (bytes < 1024 * 1024) { + return String.format("%.2f KB", bytes / 1024.0); + } else if (bytes < 1024 * 1024 * 1024) { + return String.format("%.2f MB", bytes / (1024.0 * 1024.0)); + } else { + return String.format("%.2f GB", bytes / (1024.0 * 1024.0 * 1024.0)); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/IncMSTTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/IncMSTTest.java index c165b6213..0a691f6cf 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/IncMSTTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/IncMSTTest.java @@ -19,121 +19,105 @@ package org.apache.geaflow.dsl.runtime.query; -import org.apache.geaflow.common.config.keys.DSLConfigKeys; -import org.apache.geaflow.common.config.keys.FrameworkConfigKeys; -import org.apache.geaflow.file.FileConfigKeys; -import java.io.File; -import java.io.IOException; -import org.apache.commons.io.FileUtils; import org.testng.annotations.Test; /** - * Incremental Minimum Spanning Tree algorithm test class - * Includes basic functionality tests, incremental update tests, connectivity validation, etc. - * + * Incremental Minimum Spanning Tree algorithm test class Includes basic functionality tests, + * incremental update tests, connectivity validation, etc. + * * @author Geaflow Team */ public class IncMSTTest { - private final String TEST_GRAPH_PATH = "/tmp/geaflow/dsl/inc_mst/test/graph"; + private final String TEST_GRAPH_PATH = "/tmp/geaflow/dsl/inc_mst/test/graph"; - @Test - public void testIncMST_001_Basic() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_inc_mst_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testIncMST_001_Basic() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_inc_mst_001.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testIncMST_002_IncrementalUpdate() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_inc_mst_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testIncMST_002_IncrementalUpdate() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_inc_mst_002.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testIncMST_003_LargeGraph() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/large_graph.sql") - .withQueryPath("/query/gql_inc_mst_003.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testIncMST_003_LargeGraph() throws Exception { + QueryTester.build() + .withGraphDefine("/query/large_graph.sql") + .withQueryPath("/query/gql_inc_mst_003.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testIncMST_004_EdgeAddition() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/dynamic_graph.sql") - .withQueryPath("/query/gql_inc_mst_004.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testIncMST_004_EdgeAddition() throws Exception { + QueryTester.build() + .withGraphDefine("/query/dynamic_graph.sql") + .withQueryPath("/query/gql_inc_mst_004.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testIncMST_005_EdgeDeletion() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/dynamic_graph.sql") - .withQueryPath("/query/gql_inc_mst_005.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testIncMST_005_EdgeDeletion() throws Exception { + QueryTester.build() + .withGraphDefine("/query/dynamic_graph.sql") + .withQueryPath("/query/gql_inc_mst_005.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testIncMST_006_ConnectedComponents() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/disconnected_graph.sql") - .withQueryPath("/query/gql_inc_mst_006.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testIncMST_006_ConnectedComponents() throws Exception { + QueryTester.build() + .withGraphDefine("/query/disconnected_graph.sql") + .withQueryPath("/query/gql_inc_mst_006.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testIncMST_007_Performance() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/performance_graph.sql") - .withQueryPath("/query/gql_inc_mst_007.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testIncMST_007_Performance() throws Exception { + QueryTester.build() + .withGraphDefine("/query/performance_graph.sql") + .withQueryPath("/query/gql_inc_mst_007.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testIncMST_008_Convergence() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_inc_mst_008.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testIncMST_008_Convergence() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_inc_mst_008.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testIncMST_009_CustomParameters() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_inc_mst_009.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testIncMST_009_CustomParameters() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_inc_mst_009.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testIncMST_010_ComplexTopology() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/complex_graph.sql") - .withQueryPath("/query/gql_inc_mst_010.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testIncMST_010_ComplexTopology() throws Exception { + QueryTester.build() + .withGraphDefine("/query/complex_graph.sql") + .withQueryPath("/query/gql_inc_mst_010.sql") + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/IncrMatchTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/IncrMatchTest.java index bac6b1f93..875d6f740 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/IncrMatchTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/IncrMatchTest.java @@ -28,197 +28,202 @@ import java.util.List; import java.util.Random; import java.util.Set; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.keys.DSLConfigKeys; import org.apache.geaflow.common.config.keys.FrameworkConfigKeys; import org.testng.Assert; import org.testng.annotations.Test; -import scala.Tuple2; +import scala.Tuple2; public class IncrMatchTest { - private int vertexNum = 80; - private int edgeNum = 400; - - private final String lineSplit = "----"; - - private void RunTest(String queryPath) throws Exception { - - QueryTester.build() - .withConfig(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT.getKey(), "1") - .withQueryPath(queryPath) - .withConfig(DSLConfigKeys.ENABLE_INCR_TRAVERSAL.getKey(), "true") - .withConfig(DSLConfigKeys.TABLE_SINK_SPLIT_LINE.getKey(), lineSplit) - .execute(); - - String incr = getTargetPath(queryPath); - List> incrRes = readRes(incr, true); - - QueryTester.build() - .withConfig(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT.getKey(), "1") - .withQueryPath(queryPath) - .withConfig(DSLConfigKeys.ENABLE_INCR_TRAVERSAL.getKey(), "false") - .withConfig(DSLConfigKeys.TABLE_SINK_SPLIT_LINE.getKey(), lineSplit) - .execute(); - - String allPath = getTargetPath(queryPath); - List> allRes = readRes(allPath, false); - - // Ensure both results have the same number of windows - Assert.assertEquals(incrRes.size(), allRes.size(), - "Incremental and full traversal should have same number of windows"); - - // For incremental traversal, each window contains cumulative results (all results from window 0 to current) - // For full traversal, each window contains only results from that specific window - // So we need to compare: incremental[i] should equal union of full[0] to full[i] - Set cumulativeFull = new HashSet<>(); - for (int i = 0; i < incrRes.size(); i++) { - Set incrSet = incrRes.get(i); - Set fullSet = allRes.get(i); - - // Add current window's results to cumulative set - cumulativeFull.addAll(fullSet); - - // Compare incremental result (cumulative) with cumulative full result - Assert.assertEquals(incrSet, cumulativeFull, - String.format("Window %d mismatch: incremental (cumulative)=%s, full (cumulative)=%s", - i, incrSet, cumulativeFull)); - } + private int vertexNum = 80; + private int edgeNum = 400; + + private final String lineSplit = "----"; + + private void RunTest(String queryPath) throws Exception { + + QueryTester.build() + .withConfig(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT.getKey(), "1") + .withQueryPath(queryPath) + .withConfig(DSLConfigKeys.ENABLE_INCR_TRAVERSAL.getKey(), "true") + .withConfig(DSLConfigKeys.TABLE_SINK_SPLIT_LINE.getKey(), lineSplit) + .execute(); + + String incr = getTargetPath(queryPath); + List> incrRes = readRes(incr, true); + + QueryTester.build() + .withConfig(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT.getKey(), "1") + .withQueryPath(queryPath) + .withConfig(DSLConfigKeys.ENABLE_INCR_TRAVERSAL.getKey(), "false") + .withConfig(DSLConfigKeys.TABLE_SINK_SPLIT_LINE.getKey(), lineSplit) + .execute(); + + String allPath = getTargetPath(queryPath); + List> allRes = readRes(allPath, false); + + // Ensure both results have the same number of windows + Assert.assertEquals( + incrRes.size(), + allRes.size(), + "Incremental and full traversal should have same number of windows"); + + // For incremental traversal, each window contains cumulative results (all results from window 0 + // to current) + // For full traversal, each window contains only results from that specific window + // So we need to compare: incremental[i] should equal union of full[0] to full[i] + Set cumulativeFull = new HashSet<>(); + for (int i = 0; i < incrRes.size(); i++) { + Set incrSet = incrRes.get(i); + Set fullSet = allRes.get(i); + + // Add current window's results to cumulative set + cumulativeFull.addAll(fullSet); + + // Compare incremental result (cumulative) with cumulative full result + Assert.assertEquals( + incrSet, + cumulativeFull, + String.format( + "Window %d mismatch: incremental (cumulative)=%s, full (cumulative)=%s", + i, incrSet, cumulativeFull)); } - - @Test - public void testIncrMatch0() throws Exception { - RunTest("/query/gql_incr_match.sql"); + } + + @Test + public void testIncrMatch0() throws Exception { + RunTest("/query/gql_incr_match.sql"); + } + + @Test + public void testIncrMatchMultiParall() throws Exception { + QueryTester.build() + .withConfig(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT.getKey(), "1") + .withQueryPath("/query/gql_incr_match_multi_parall.sql") + .withConfig(DSLConfigKeys.ENABLE_INCR_TRAVERSAL.getKey(), "true") + .execute() + .checkSinkResult(); + } + + @Test + public void testIncrMatchRandom() throws Exception { + createData(); + RunTest("/query/gql_incr_match_random.sql"); + } + + private void createData() { + createVertex(); + createEdge(); + } + + private void createEdge() { + File edgeFile = new File("/tmp/geaflow-test/incr_modern_edge.txt"); + Set> edges = new HashSet<>(); + + // Use fixed seed for deterministic test results to avoid flaky tests + Random r = new Random(42L); + while (edges.size() < edgeNum) { + int src = r.nextInt(vertexNum) + 1; // 1 to vertexNum + int dst = r.nextInt(vertexNum) + 1; // 1 to vertexNum + while (src == dst) { + dst = r.nextInt(vertexNum) + 1; + } + edges.add(new Tuple2<>(src, dst)); } - @Test - public void testIncrMatchMultiParall() throws Exception { - QueryTester.build() - .withConfig(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT.getKey(), "1") - .withQueryPath("/query/gql_incr_match_multi_parall.sql") - .withConfig(DSLConfigKeys.ENABLE_INCR_TRAVERSAL.getKey(), "true") - .execute() - .checkSinkResult(); + List edgeString = new ArrayList<>(); + for (Tuple2 edge : edges) { + edgeString.add(String.format("%s,%s,knows,0.5", edge._1, edge._2)); } - @Test - public void testIncrMatchRandom() throws Exception { - createData(); - RunTest("/query/gql_incr_match_random.sql"); + try { + FileUtils.writeLines(edgeFile, edgeString); + } catch (IOException e) { + throw new RuntimeException(e); } + } - private void createData() { - createVertex(); - createEdge(); + private void createVertex() { + File file = new File("/tmp/geaflow-test/incr_modern_vertex.txt"); + List vertices = new ArrayList<>(); + for (int i = 1; i <= vertexNum; i++) { + vertices.add(String.format("%s,person,name1,1", i)); } - private void createEdge() { - File edgeFile = new File("/tmp/geaflow-test/incr_modern_edge.txt"); - Set> edges = new HashSet<>(); - - // Use fixed seed for deterministic test results to avoid flaky tests - Random r = new Random(42L); - while (edges.size() < edgeNum) { - int src = r.nextInt(vertexNum) + 1; // 1 to vertexNum - int dst = r.nextInt(vertexNum) + 1; // 1 to vertexNum - while (src == dst) { - dst = r.nextInt(vertexNum) + 1; + try { + FileUtils.writeLines(file, vertices); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static String getTargetPath(String queryPath) { + assert queryPath != null; + String[] paths = queryPath.split("/"); + String lastPath = paths[paths.length - 1]; + String targetPath = "target/" + lastPath.split("\\.")[0]; + String currentPath = new File(".").getAbsolutePath(); + targetPath = currentPath.substring(0, currentPath.length() - 1) + targetPath + "/partition_0"; + return targetPath; + } + + private List> readRes(String path, boolean isIncr) throws IOException { + List> res = new ArrayList<>(); + Set curWindow = new HashSet<>(); + Set allHistoryRes = new HashSet<>(); + BufferedReader reader = null; + try { + reader = new BufferedReader(new FileReader(path)); + String currentLine; + while ((currentLine = reader.readLine()) != null) { + if (currentLine.equals(lineSplit)) { + // Process window separator + if (curWindow.isEmpty()) { + // Empty window: for incremental, keep history; for full, use empty + if (isIncr) { + res.add(new HashSet<>(allHistoryRes)); + } else { + res.add(new HashSet<>()); } - edges.add(new Tuple2<>(src, dst)); - } - - List edgeString = new ArrayList<>(); - for (Tuple2 edge : edges) { - edgeString.add(String.format("%s,%s,knows,0.5", edge._1, edge._2)); - } - - try { - FileUtils.writeLines(edgeFile, edgeString); - } catch (IOException e) { - throw new RuntimeException(e); + } else { + // Non-empty window + if (isIncr) { + allHistoryRes.addAll(curWindow); + res.add(new HashSet<>(allHistoryRes)); + } else { + res.add(new HashSet<>(curWindow)); + } + } + curWindow = new HashSet<>(); + } else { + curWindow.add(currentLine); } - } - - private void createVertex() { - File file = new File("/tmp/geaflow-test/incr_modern_vertex.txt"); - List vertices = new ArrayList<>(); - for (int i = 1; i <= vertexNum; i++) { - vertices.add(String.format("%s,person,name1,1", i)); + } + // Handle last window if file doesn't end with separator + if (!curWindow.isEmpty()) { + if (isIncr) { + allHistoryRes.addAll(curWindow); + res.add(new HashSet<>(allHistoryRes)); + } else { + res.add(new HashSet<>(curWindow)); } - - try { - FileUtils.writeLines(file, vertices); - } catch (IOException e) { - throw new RuntimeException(e); + } + } catch (IOException e) { + e.printStackTrace(); + } finally { + try { + if (reader != null) { + reader.close(); } - + } catch (IOException ex) { + ex.printStackTrace(); + } } - private static String getTargetPath(String queryPath) { - assert queryPath != null; - String[] paths = queryPath.split("/"); - String lastPath = paths[paths.length - 1]; - String targetPath = "target/" + lastPath.split("\\.")[0]; - String currentPath = new File(".").getAbsolutePath(); - targetPath = currentPath.substring(0, currentPath.length() - 1) + targetPath + "/partition_0"; - return targetPath; - } - - - private List> readRes(String path, boolean isIncr) throws IOException { - List> res = new ArrayList<>(); - Set curWindow = new HashSet<>(); - Set allHistoryRes = new HashSet<>(); - BufferedReader reader = null; - try { - reader = new BufferedReader(new FileReader(path)); - String currentLine; - while ((currentLine = reader.readLine()) != null) { - if (currentLine.equals(lineSplit)) { - // Process window separator - if (curWindow.isEmpty()) { - // Empty window: for incremental, keep history; for full, use empty - if (isIncr) { - res.add(new HashSet<>(allHistoryRes)); - } else { - res.add(new HashSet<>()); - } - } else { - // Non-empty window - if (isIncr) { - allHistoryRes.addAll(curWindow); - res.add(new HashSet<>(allHistoryRes)); - } else { - res.add(new HashSet<>(curWindow)); - } - } - curWindow = new HashSet<>(); - } else { - curWindow.add(currentLine); - } - } - // Handle last window if file doesn't end with separator - if (!curWindow.isEmpty()) { - if (isIncr) { - allHistoryRes.addAll(curWindow); - res.add(new HashSet<>(allHistoryRes)); - } else { - res.add(new HashSet<>(curWindow)); - } - } - } catch (IOException e) { - e.printStackTrace(); - } finally { - try { - if (reader != null) { - reader.close(); - } - } catch (IOException ex) { - ex.printStackTrace(); - } - } - - return res; - } + return res; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/IncrementalKCoreTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/IncrementalKCoreTest.java index d432cd39e..e8a03511e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/IncrementalKCoreTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/IncrementalKCoreTest.java @@ -17,82 +17,69 @@ package org.apache.geaflow.dsl.runtime.query; -import java.io.File; -import java.io.IOException; - -import org.apache.commons.io.FileUtils; -import org.apache.geaflow.common.config.keys.DSLConfigKeys; -import org.apache.geaflow.common.config.keys.FrameworkConfigKeys; -import org.apache.geaflow.file.FileConfigKeys; import org.testng.annotations.Test; /** - * Incremental K-Core algorithm test class - * Includes basic functionality tests, incremental update tests, dynamic graph tests, etc. - * + * Incremental K-Core algorithm test class Includes basic functionality tests, incremental update + * tests, dynamic graph tests, etc. + * * @author TuGraph Analytics Team */ public class IncrementalKCoreTest { - @Test - public void testIncrementalKCore_001_Basic() throws Exception { - // Note: Currently only this test can run stably - // Other tests are disabled due to GeaFlow framework RPC communication issues - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_inc_kcore_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testIncrementalKCore_001_Basic() throws Exception { + // Note: Currently only this test can run stably + // Other tests are disabled due to GeaFlow framework RPC communication issues + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_inc_kcore_001.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testIncrementalKCore_002_IncrementalUpdate() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/modern_graph.sql") - .withQueryPath("/query/gql_inc_kcore_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testIncrementalKCore_002_IncrementalUpdate() throws Exception { + QueryTester.build() + .withGraphDefine("/query/modern_graph.sql") + .withQueryPath("/query/gql_inc_kcore_002.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testIncrementalKCore_003_EdgeAddition() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/dynamic_graph.sql") - .withQueryPath("/query/gql_inc_kcore_003.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testIncrementalKCore_003_EdgeAddition() throws Exception { + QueryTester.build() + .withGraphDefine("/query/dynamic_graph.sql") + .withQueryPath("/query/gql_inc_kcore_003.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testIncrementalKCore_004_Performance() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/large_graph.sql") - .withQueryPath("/query/gql_inc_kcore_007.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testIncrementalKCore_004_Performance() throws Exception { + QueryTester.build() + .withGraphDefine("/query/large_graph.sql") + .withQueryPath("/query/gql_inc_kcore_007.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testIncrementalKCore_005_ComplexTopology() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/complex_graph.sql") - .withQueryPath("/query/gql_inc_kcore_009.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testIncrementalKCore_005_ComplexTopology() throws Exception { + QueryTester.build() + .withGraphDefine("/query/complex_graph.sql") + .withQueryPath("/query/gql_inc_kcore_009.sql") + .execute() + .checkSinkResult(); + } - @Test - public void testIncrementalKCore_006_DisconnectedComponents() throws Exception { - QueryTester - .build() - .withGraphDefine("/query/disconnected_graph.sql") - .withQueryPath("/query/gql_inc_kcore_010.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testIncrementalKCore_006_DisconnectedComponents() throws Exception { + QueryTester.build() + .withGraphDefine("/query/disconnected_graph.sql") + .withQueryPath("/query/gql_inc_kcore_010.sql") + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/JDBCTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/JDBCTest.java index 6babfe52f..a7f591fa3 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/JDBCTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/JDBCTest.java @@ -23,6 +23,7 @@ import java.sql.Statement; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.config.keys.DSLConfigKeys; import org.apache.geaflow.dsl.runtime.testenv.SourceFunctionNoPartitionCheck; import org.h2.jdbcx.JdbcDataSource; @@ -32,33 +33,34 @@ import org.testng.annotations.Test; public class JDBCTest { - private static final Logger LOGGER = LoggerFactory.getLogger(JDBCTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(JDBCTest.class); - private final String URL = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1"; - private final String username = "h2_user"; - private final String password = "h2_pwd"; + private final String URL = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1"; + private final String username = "h2_user"; + private final String password = "h2_pwd"; - @BeforeClass - public void setup() throws SQLException { - LOGGER.info("start h2 database."); - JdbcDataSource dataSource = new JdbcDataSource(); - dataSource.setURL(URL); - dataSource.setUser(username); - dataSource.setPassword(password); + @BeforeClass + public void setup() throws SQLException { + LOGGER.info("start h2 database."); + JdbcDataSource dataSource = new JdbcDataSource(); + dataSource.setURL(URL); + dataSource.setUser(username); + dataSource.setPassword(password); - Statement statement = dataSource.getConnection().createStatement(); - statement.execute("CREATE TABLE test (user_name VARCHAR(255) primary key, count INT);"); - statement.execute("CREATE TABLE users (id INT primary key, name VARCHAR(255), age INT);"); - } + Statement statement = dataSource.getConnection().createStatement(); + statement.execute("CREATE TABLE test (user_name VARCHAR(255) primary key, count INT);"); + statement.execute("CREATE TABLE users (id INT primary key, name VARCHAR(255), age INT);"); + } - @Test - public void testJDBC_001() throws Exception { - Map config = new HashMap<>(); - config.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), String.valueOf(1L)); - config.put(DSLConfigKeys.GEAFLOW_DSL_CUSTOM_SOURCE_FUNCTION.getKey(), - SourceFunctionNoPartitionCheck.class.getName()); - QueryTester tester = QueryTester - .build() + @Test + public void testJDBC_001() throws Exception { + Map config = new HashMap<>(); + config.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), String.valueOf(1L)); + config.put( + DSLConfigKeys.GEAFLOW_DSL_CUSTOM_SOURCE_FUNCTION.getKey(), + SourceFunctionNoPartitionCheck.class.getName()); + QueryTester tester = + QueryTester.build() .withQueryPath("/query/jdbc_write_001.sql") .withConfig(config) .withTestTimeWaitSeconds(60) @@ -68,19 +70,19 @@ public void testJDBC_001() throws Exception { .withTestTimeWaitSeconds(60) .execute(); - tester.checkSinkResult(); - } + tester.checkSinkResult(); + } - @Test - public void testJDBC_002() throws Exception { - QueryTester tester = QueryTester - .build() + @Test + public void testJDBC_002() throws Exception { + QueryTester tester = + QueryTester.build() .withQueryPath("/query/jdbc_write_002.sql") .withTestTimeWaitSeconds(60) .execute() .withQueryPath("/query/jdbc_scan_002.sql") .withTestTimeWaitSeconds(60) .execute(); - tester.checkSinkResult(); - } + tester.checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/KafkaFoTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/KafkaFoTest.java index c165f2898..bf2f3f700 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/KafkaFoTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/KafkaFoTest.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.config.keys.DSLConfigKeys; import org.apache.geaflow.common.utils.DateTimeUtil; import org.apache.geaflow.dsl.connector.api.util.ConnectorConstants; @@ -35,76 +36,77 @@ public class KafkaFoTest { - private static final Logger LOGGER = LoggerFactory.getLogger(KafkaFoTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(KafkaFoTest.class); - public static int injectExceptionTimes = 0; + public static int injectExceptionTimes = 0; - @BeforeClass - public void startKafkaServer() throws IOException { - LOGGER.info("startKafkaServer"); - KafkaTestEnv.get().startKafkaServer(); - LOGGER.info("startKafkaServer done"); - } + @BeforeClass + public void startKafkaServer() throws IOException { + LOGGER.info("startKafkaServer"); + KafkaTestEnv.get().startKafkaServer(); + LOGGER.info("startKafkaServer done"); + } - @AfterClass - public void shutdownKafkaServer() throws IOException { - LOGGER.info("shutdownKafkaServer"); - KafkaTestEnv.get().shutdownKafkaServer(); - LOGGER.info("shutdownKafkaServer done"); - } + @AfterClass + public void shutdownKafkaServer() throws IOException { + LOGGER.info("shutdownKafkaServer"); + KafkaTestEnv.get().shutdownKafkaServer(); + LOGGER.info("shutdownKafkaServer done"); + } - @Test - public void testKafka_001() throws Exception { - Map config = new HashMap<>(); - config.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), String.valueOf(1L)); - config.put(DSLConfigKeys.GEAFLOW_DSL_CUSTOM_SOURCE_FUNCTION.getKey(), - SourceFunctionNoPartitionCheck.class.getName()); - KafkaTestEnv.get().createTopic("sink-test"); - QueryTester - .build() - .withQueryPath("/query/kafka_write_001.sql") - .withConfig(config) - .withTestTimeWaitSeconds(60) - .execute(); - QueryTester tester = QueryTester - .build() + @Test + public void testKafka_001() throws Exception { + Map config = new HashMap<>(); + config.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), String.valueOf(1L)); + config.put( + DSLConfigKeys.GEAFLOW_DSL_CUSTOM_SOURCE_FUNCTION.getKey(), + SourceFunctionNoPartitionCheck.class.getName()); + KafkaTestEnv.get().createTopic("sink-test"); + QueryTester.build() + .withQueryPath("/query/kafka_write_001.sql") + .withConfig(config) + .withTestTimeWaitSeconds(60) + .execute(); + QueryTester tester = + QueryTester.build() .withQueryPath("/query/kafka_scan_001.sql") .withConfig(config) .withTestTimeWaitSeconds(60); - try { - tester.execute(); - } catch (Exception e) { - LOGGER.info("Kafka unbounded stream finish with timeout."); - } - tester.checkSinkResult(); + try { + tester.execute(); + } catch (Exception e) { + LOGGER.info("Kafka unbounded stream finish with timeout."); } + tester.checkSinkResult(); + } - @Test - public void testKafka_002() throws Exception { - Map config = new HashMap<>(); - config.put(DSLConfigKeys.GEAFLOW_DSL_CUSTOM_SOURCE_FUNCTION.getKey(), - SourceFunctionNoPartitionCheck.class.getName()); - String startTime = DateTimeUtil.fromUnixTime(System.currentTimeMillis() - 120 * 1000, ConnectorConstants.START_TIME_FORMAT); - config.put("startTime", startTime); - KafkaTestEnv.get().createTopic("scan_002"); - QueryTester - .build() - .withQueryPath("/query/kafka_write_002.sql") - .withConfig(config) - .withTestTimeWaitSeconds(60) - .execute(); - QueryTester tester = QueryTester - .build() + @Test + public void testKafka_002() throws Exception { + Map config = new HashMap<>(); + config.put( + DSLConfigKeys.GEAFLOW_DSL_CUSTOM_SOURCE_FUNCTION.getKey(), + SourceFunctionNoPartitionCheck.class.getName()); + String startTime = + DateTimeUtil.fromUnixTime( + System.currentTimeMillis() - 120 * 1000, ConnectorConstants.START_TIME_FORMAT); + config.put("startTime", startTime); + KafkaTestEnv.get().createTopic("scan_002"); + QueryTester.build() + .withQueryPath("/query/kafka_write_002.sql") + .withConfig(config) + .withTestTimeWaitSeconds(60) + .execute(); + QueryTester tester = + QueryTester.build() .withQueryPath("/query/kafka_scan_002.sql") .withCustomWindow() .withConfig(config) .withTestTimeWaitSeconds(60); - try { - tester.execute(); - } catch (Exception e) { - LOGGER.info("Kafka unbounded stream finish with timeout."); - } - tester.checkSinkResult(); + try { + tester.execute(); + } catch (Exception e) { + LOGGER.info("Kafka unbounded stream finish with timeout."); } - + tester.checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/LdbcTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/LdbcTest.java index 0ba9d309b..aff10813c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/LdbcTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/LdbcTest.java @@ -22,6 +22,7 @@ import java.io.File; import java.util.HashMap; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.keys.DSLConfigKeys; import org.apache.geaflow.file.FileConfigKeys; @@ -31,534 +32,480 @@ public class LdbcTest { - private final String TEST_GRAPH_PATH = "/tmp/geaflow/dsl/bi/test/graph"; + private final String TEST_GRAPH_PATH = "/tmp/geaflow/dsl/bi/test/graph"; - private final Map testConfig = new HashMap() { + private final Map testConfig = + new HashMap() { { - put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "DFS"); - put(FileConfigKeys.ROOT.getKey(), TEST_GRAPH_PATH); - put(FileConfigKeys.JSON_CONFIG.getKey(), "{\"fs.defaultFS\":\"local\"}"); - } - }; - - @BeforeClass - public void prepare() throws Exception { - File file = new File(TEST_GRAPH_PATH); - if (file.exists()) { - FileUtils.deleteDirectory(file); - } - QueryTester - .build() - .withConfig(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), "1") - .withConfig(FileConfigKeys.PERSISTENT_TYPE.getKey(), "DFS") - .withConfig(FileConfigKeys.ROOT.getKey(), TEST_GRAPH_PATH) - .withConfig(FileConfigKeys.JSON_CONFIG.getKey(), "{\"fs.defaultFS\":\"local\"}") - .withQueryPath("/ldbc/bi_insert_01.sql") - .execute() - .withQueryPath("/ldbc/bi_insert_02.sql") - .execute() - .withQueryPath("/ldbc/bi_insert_03.sql") - .execute() - .withQueryPath("/ldbc/bi_insert_04.sql") - .execute() - .withQueryPath("/ldbc/bi_insert_05.sql") - .execute() - .withQueryPath("/ldbc/bi_insert_06.sql") - .execute(); - } - - @AfterClass - public void tearDown() throws Exception { - File file = new File(TEST_GRAPH_PATH); - if (file.exists()) { - FileUtils.deleteDirectory(file); + put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "DFS"); + put(FileConfigKeys.ROOT.getKey(), TEST_GRAPH_PATH); + put(FileConfigKeys.JSON_CONFIG.getKey(), "{\"fs.defaultFS\":\"local\"}"); } - } - - @Test - public void testLdbcBi_01() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/bi_01.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/bi_01.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcBi_02() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/bi_02.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/bi_02.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcBi_03() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/bi_03.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/bi_03.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test(enabled = false) - public void testLdbcBi_04() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/bi_04.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/bi_04.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcBi_05() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/bi_05.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/bi_05.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcBi_06() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/bi_06.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/bi_06.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcBi_07() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/bi_07.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/bi_07.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcBi_08() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/bi_08.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/bi_08.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcBi_09() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/bi_09.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/bi_09.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcBi_10() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/bi_10.sql") - .withConfig(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), "-1") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/bi_10.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcBi_11() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/bi_11.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/bi_11.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } + }; - @Test - public void testLdbcBi_12() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/bi_12.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/bi_12.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); + @BeforeClass + public void prepare() throws Exception { + File file = new File(TEST_GRAPH_PATH); + if (file.exists()) { + FileUtils.deleteDirectory(file); } - - @Test - public void testLdbcBi_13() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/bi_13.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/bi_13.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcBi_14() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/bi_14.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/bi_14.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcBi_15() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/bi_15.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/bi_15.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcBi_16() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/bi_16.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/bi_16.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcBi_17() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/bi_17.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/bi_17.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcBi_18() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/bi_18.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/bi_18.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcBi_19() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/bi_19.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/bi_19.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcBi_20() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/bi_20.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/bi_20.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcIs_01() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/is_01.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/is_01.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcIs_02() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/is_02.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/is_02.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcIs_03() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/is_03.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/is_03.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcIs_04() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/is_04.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/is_04.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcIs_05() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/is_05.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/is_05.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcIs_06() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/is_06.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/is_06.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); - } - - @Test - public void testLdbcIs_07() throws Exception { - QueryTester - .build() - .withQueryPath("/ldbc/is_07.sql") - .withConfig(testConfig) - .execute() - .checkSinkResult(); - - QueryTester - .build() - .withQueryPath("/ldbc/is_07.sql") - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) - .execute() - .checkSinkResult(); + QueryTester.build() + .withConfig(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), "1") + .withConfig(FileConfigKeys.PERSISTENT_TYPE.getKey(), "DFS") + .withConfig(FileConfigKeys.ROOT.getKey(), TEST_GRAPH_PATH) + .withConfig(FileConfigKeys.JSON_CONFIG.getKey(), "{\"fs.defaultFS\":\"local\"}") + .withQueryPath("/ldbc/bi_insert_01.sql") + .execute() + .withQueryPath("/ldbc/bi_insert_02.sql") + .execute() + .withQueryPath("/ldbc/bi_insert_03.sql") + .execute() + .withQueryPath("/ldbc/bi_insert_04.sql") + .execute() + .withQueryPath("/ldbc/bi_insert_05.sql") + .execute() + .withQueryPath("/ldbc/bi_insert_06.sql") + .execute(); + } + + @AfterClass + public void tearDown() throws Exception { + File file = new File(TEST_GRAPH_PATH); + if (file.exists()) { + FileUtils.deleteDirectory(file); } + } + + @Test + public void testLdbcBi_01() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/bi_01.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/bi_01.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcBi_02() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/bi_02.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/bi_02.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcBi_03() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/bi_03.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/bi_03.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test(enabled = false) + public void testLdbcBi_04() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/bi_04.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/bi_04.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcBi_05() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/bi_05.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/bi_05.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcBi_06() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/bi_06.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/bi_06.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcBi_07() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/bi_07.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/bi_07.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcBi_08() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/bi_08.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/bi_08.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcBi_09() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/bi_09.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/bi_09.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcBi_10() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/bi_10.sql") + .withConfig(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), "-1") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/bi_10.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcBi_11() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/bi_11.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/bi_11.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcBi_12() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/bi_12.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/bi_12.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcBi_13() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/bi_13.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/bi_13.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcBi_14() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/bi_14.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/bi_14.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcBi_15() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/bi_15.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/bi_15.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcBi_16() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/bi_16.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/bi_16.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcBi_17() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/bi_17.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/bi_17.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcBi_18() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/bi_18.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/bi_18.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcBi_19() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/bi_19.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/bi_19.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcBi_20() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/bi_20.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/bi_20.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcIs_01() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/is_01.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/is_01.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcIs_02() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/is_02.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/is_02.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcIs_03() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/is_03.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/is_03.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcIs_04() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/is_04.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/is_04.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcIs_05() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/is_05.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/is_05.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcIs_06() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/is_06.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/is_06.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } + + @Test + public void testLdbcIs_07() throws Exception { + QueryTester.build() + .withQueryPath("/ldbc/is_07.sql") + .withConfig(testConfig) + .execute() + .checkSinkResult(); + + QueryTester.build() + .withQueryPath("/ldbc/is_07.sql") + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_TRAVERSAL_SPLIT_ENABLE.getKey(), String.valueOf(true)) + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/MultiFieldRadixSortTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/MultiFieldRadixSortTest.java index 1b2b4502f..e3ce0fa1e 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/MultiFieldRadixSortTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/MultiFieldRadixSortTest.java @@ -23,11 +23,12 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.common.type.primitive.BinaryStringType; import org.apache.geaflow.common.type.primitive.IntegerType; -import org.apache.geaflow.dsl.common.data.impl.ObjectRow; import org.apache.geaflow.dsl.common.data.Row; +import org.apache.geaflow.dsl.common.data.impl.ObjectRow; import org.apache.geaflow.dsl.runtime.expression.field.FieldExpression; import org.apache.geaflow.dsl.runtime.function.table.order.MultiFieldRadixSort; import org.apache.geaflow.dsl.runtime.function.table.order.OrderByField; @@ -37,261 +38,274 @@ public class MultiFieldRadixSortTest { - @Test - public void testSortEmptyList() { - List data = new ArrayList<>(); - - SortInfo sortInfo = new SortInfo(); - sortInfo.orderByFields = new ArrayList<>(1); - - OrderByField intField = new OrderByField(); - intField.expression = new FieldExpression(0, IntegerType.INSTANCE); - intField.order = ORDER.ASC; - sortInfo.orderByFields.add(intField); - - MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); - - assertEquals(data.size(), 0); - } + @Test + public void testSortEmptyList() { + List data = new ArrayList<>(); - @Test - public void testSortSingleElement() { - List data = new ArrayList<>(1); - data.add(ObjectRow.create(new Object[]{1, BinaryString.fromString("test")})); - - SortInfo sortInfo = new SortInfo(); - sortInfo.orderByFields = new ArrayList<>(1); - - OrderByField intField = new OrderByField(); - intField.expression = new FieldExpression(0, IntegerType.INSTANCE); - intField.order = ORDER.ASC; - sortInfo.orderByFields.add(intField); - - MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); - - assertEquals(data.size(), 1); - assertEquals(data.get(0).getField(0, IntegerType.INSTANCE), Integer.valueOf(1)); - } + SortInfo sortInfo = new SortInfo(); + sortInfo.orderByFields = new ArrayList<>(1); - @Test - public void testSortByIntegerFieldAscending() { - List data = new ArrayList<>(3); - data.add(ObjectRow.create(new Object[]{3, BinaryString.fromString("c")})); - data.add(ObjectRow.create(new Object[]{1, BinaryString.fromString("a")})); - data.add(ObjectRow.create(new Object[]{2, BinaryString.fromString("b")})); - - SortInfo sortInfo = new SortInfo(); - sortInfo.orderByFields = new ArrayList<>(1); - OrderByField intField = new OrderByField(); - intField.expression = new FieldExpression(0, IntegerType.INSTANCE); - intField.order = ORDER.ASC; - sortInfo.orderByFields.add(intField); - - MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); - - assertEquals(data.size(), 3); - assertEquals(data.get(0).getField(0, IntegerType.INSTANCE), Integer.valueOf(1)); - assertEquals(data.get(1).getField(0, IntegerType.INSTANCE), Integer.valueOf(2)); - assertEquals(data.get(2).getField(0, IntegerType.INSTANCE), Integer.valueOf(3)); - } + OrderByField intField = new OrderByField(); + intField.expression = new FieldExpression(0, IntegerType.INSTANCE); + intField.order = ORDER.ASC; + sortInfo.orderByFields.add(intField); - @Test - public void testSortByIntegerFieldDescending() { - List data = new ArrayList<>(3); - data.add(ObjectRow.create(new Object[]{1, BinaryString.fromString("a")})); - data.add(ObjectRow.create(new Object[]{3, BinaryString.fromString("c")})); - data.add(ObjectRow.create(new Object[]{2, BinaryString.fromString("b")})); - - SortInfo sortInfo = new SortInfo(); - sortInfo.orderByFields = new ArrayList<>(1); - OrderByField intField = new OrderByField(); - intField.expression = new FieldExpression(0, IntegerType.INSTANCE); - intField.order = ORDER.DESC; - sortInfo.orderByFields.add(intField); - - MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); - - assertEquals(data.size(), 3); - assertEquals(data.get(0).getField(0, IntegerType.INSTANCE), Integer.valueOf(3)); - assertEquals(data.get(1).getField(0, IntegerType.INSTANCE), Integer.valueOf(2)); - assertEquals(data.get(2).getField(0, IntegerType.INSTANCE), Integer.valueOf(1)); - } + MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); - @Test - public void testSortByStringFieldAscending() { - List data = new ArrayList<>(3); - data.add(ObjectRow.create(new Object[]{3, BinaryString.fromString("zebra")})); - data.add(ObjectRow.create(new Object[]{1, BinaryString.fromString("apple")})); - data.add(ObjectRow.create(new Object[]{2, BinaryString.fromString("banana")})); - - SortInfo sortInfo = new SortInfo(); - sortInfo.orderByFields = new ArrayList<>(1); - OrderByField stringField = new OrderByField(); - stringField.expression = new FieldExpression(1, BinaryStringType.INSTANCE); - stringField.order = ORDER.ASC; - sortInfo.orderByFields.add(stringField); - - MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); - - assertEquals(data.size(), 3); - assertEquals(((BinaryString)data.get(0).getField(1, BinaryStringType.INSTANCE)).toString(), "apple"); - assertEquals(((BinaryString)data.get(1).getField(1, BinaryStringType.INSTANCE)).toString(), "banana"); - assertEquals(((BinaryString)data.get(2).getField(1, BinaryStringType.INSTANCE)).toString(), "zebra"); - } + assertEquals(data.size(), 0); + } - @Test - public void testSortByStringFieldDescending() { - List data = new ArrayList<>(3); - data.add(ObjectRow.create(new Object[]{1, BinaryString.fromString("apple")})); - data.add(ObjectRow.create(new Object[]{3, BinaryString.fromString("zebra")})); - data.add(ObjectRow.create(new Object[]{2, BinaryString.fromString("banana")})); - - SortInfo sortInfo = new SortInfo(); - sortInfo.orderByFields = new ArrayList<>(1); - OrderByField stringField = new OrderByField(); - stringField.expression = new FieldExpression(1, BinaryStringType.INSTANCE); - stringField.order = ORDER.DESC; - sortInfo.orderByFields.add(stringField); - - MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); - - assertEquals(data.size(), 3); - assertEquals(((BinaryString)data.get(0).getField(1, BinaryStringType.INSTANCE)).toString(), "zebra"); - assertEquals(((BinaryString)data.get(1).getField(1, BinaryStringType.INSTANCE)).toString(), "banana"); - assertEquals(((BinaryString)data.get(2).getField(1, BinaryStringType.INSTANCE)).toString(), "apple"); - } + @Test + public void testSortSingleElement() { + List data = new ArrayList<>(1); + data.add(ObjectRow.create(new Object[] {1, BinaryString.fromString("test")})); - @Test - public void testMultiFieldSort() { - List data = new ArrayList<>(4); - data.add(ObjectRow.create(new Object[]{1, BinaryString.fromString("b")})); - data.add(ObjectRow.create(new Object[]{2, BinaryString.fromString("a")})); - data.add(ObjectRow.create(new Object[]{1, BinaryString.fromString("a")})); - data.add(ObjectRow.create(new Object[]{2, BinaryString.fromString("b")})); - - SortInfo sortInfo = new SortInfo(); - sortInfo.orderByFields = new ArrayList<>(2); - - // First sort by integer field (ascending) - OrderByField intField = new OrderByField(); - intField.expression = new FieldExpression(0, IntegerType.INSTANCE); - intField.order = ORDER.ASC; - sortInfo.orderByFields.add(intField); - - // Then sort by string field (ascending) - OrderByField stringField = new OrderByField(); - stringField.expression = new FieldExpression(1, BinaryStringType.INSTANCE); - stringField.order = ORDER.ASC; - sortInfo.orderByFields.add(stringField); - - MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); - - assertEquals(data.size(), 4); - // Expected order: (1, "a"), (1, "b"), (2, "a"), (2, "b") - assertEquals(data.get(0).getField(0, IntegerType.INSTANCE), Integer.valueOf(1)); - assertEquals(((BinaryString)data.get(0).getField(1, BinaryStringType.INSTANCE)).toString(), "a"); - - assertEquals(data.get(1).getField(0, IntegerType.INSTANCE), Integer.valueOf(1)); - assertEquals(((BinaryString)data.get(1).getField(1, BinaryStringType.INSTANCE)).toString(), "b"); - - assertEquals(data.get(2).getField(0, IntegerType.INSTANCE), Integer.valueOf(2)); - assertEquals(((BinaryString)data.get(2).getField(1, BinaryStringType.INSTANCE)).toString(), "a"); - - assertEquals(data.get(3).getField(0, IntegerType.INSTANCE), Integer.valueOf(2)); - assertEquals(((BinaryString)data.get(3).getField(1, BinaryStringType.INSTANCE)).toString(), "b"); - } + SortInfo sortInfo = new SortInfo(); + sortInfo.orderByFields = new ArrayList<>(1); - @Test - public void testSortWithNullValues() { - List data = new ArrayList<>(3); - data.add(ObjectRow.create(new Object[]{null, BinaryString.fromString("b")})); - data.add(ObjectRow.create(new Object[]{2, BinaryString.fromString("a")})); - data.add(ObjectRow.create(new Object[]{1, BinaryString.fromString("c")})); - - SortInfo sortInfo = new SortInfo(); - sortInfo.orderByFields = new ArrayList<>(1); - OrderByField intField = new OrderByField(); - intField.expression = new FieldExpression(0, IntegerType.INSTANCE); - intField.order = ORDER.ASC; - sortInfo.orderByFields.add(intField); - - MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); - - assertEquals(data.size(), 3); - // Null values should appear first in ascending order - assertEquals(data.get(0).getField(0, IntegerType.INSTANCE), null); - assertEquals(data.get(1).getField(0, IntegerType.INSTANCE), Integer.valueOf(1)); - assertEquals(data.get(2).getField(0, IntegerType.INSTANCE), Integer.valueOf(2)); - } + OrderByField intField = new OrderByField(); + intField.expression = new FieldExpression(0, IntegerType.INSTANCE); + intField.order = ORDER.ASC; + sortInfo.orderByFields.add(intField); - @Test - public void testSortWithNegativeNumbers() { - List data = new ArrayList<>(4); - data.add(ObjectRow.create(new Object[]{-1, BinaryString.fromString("a")})); - data.add(ObjectRow.create(new Object[]{3, BinaryString.fromString("b")})); - data.add(ObjectRow.create(new Object[]{-5, BinaryString.fromString("c")})); - data.add(ObjectRow.create(new Object[]{0, BinaryString.fromString("d")})); - - SortInfo sortInfo = new SortInfo(); - sortInfo.orderByFields = new ArrayList<>(1); - OrderByField intField = new OrderByField(); - intField.expression = new FieldExpression(0, IntegerType.INSTANCE); - intField.order = ORDER.ASC; - sortInfo.orderByFields.add(intField); - - MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); - - assertEquals(data.size(), 4); - assertEquals(data.get(0).getField(0, IntegerType.INSTANCE), Integer.valueOf(-5)); - assertEquals(data.get(1).getField(0, IntegerType.INSTANCE), Integer.valueOf(-1)); - assertEquals(data.get(2).getField(0, IntegerType.INSTANCE), Integer.valueOf(0)); - assertEquals(data.get(3).getField(0, IntegerType.INSTANCE), Integer.valueOf(3)); - } + MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); - @Test - public void testSortWithEmptyStrings() { - List data = new ArrayList<>(3); - data.add(ObjectRow.create(new Object[]{1, BinaryString.fromString("")})); - data.add(ObjectRow.create(new Object[]{2, BinaryString.fromString("hello")})); - data.add(ObjectRow.create(new Object[]{3, BinaryString.fromString("a")})); - - SortInfo sortInfo = new SortInfo(); - sortInfo.orderByFields = new ArrayList<>(1); - OrderByField stringField = new OrderByField(); - stringField.expression = new FieldExpression(1, BinaryStringType.INSTANCE); - stringField.order = ORDER.ASC; - sortInfo.orderByFields.add(stringField); - - MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); - - assertEquals(data.size(), 3); - // Empty string should come first - assertEquals(((BinaryString)data.get(0).getField(1, BinaryStringType.INSTANCE)).toString(), ""); - assertEquals(((BinaryString)data.get(1).getField(1, BinaryStringType.INSTANCE)).toString(), "a"); - assertEquals(((BinaryString)data.get(2).getField(1, BinaryStringType.INSTANCE)).toString(), "hello"); - } + assertEquals(data.size(), 1); + assertEquals(data.get(0).getField(0, IntegerType.INSTANCE), Integer.valueOf(1)); + } + + @Test + public void testSortByIntegerFieldAscending() { + List data = new ArrayList<>(3); + data.add(ObjectRow.create(new Object[] {3, BinaryString.fromString("c")})); + data.add(ObjectRow.create(new Object[] {1, BinaryString.fromString("a")})); + data.add(ObjectRow.create(new Object[] {2, BinaryString.fromString("b")})); + + SortInfo sortInfo = new SortInfo(); + sortInfo.orderByFields = new ArrayList<>(1); + OrderByField intField = new OrderByField(); + intField.expression = new FieldExpression(0, IntegerType.INSTANCE); + intField.order = ORDER.ASC; + sortInfo.orderByFields.add(intField); + + MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); + + assertEquals(data.size(), 3); + assertEquals(data.get(0).getField(0, IntegerType.INSTANCE), Integer.valueOf(1)); + assertEquals(data.get(1).getField(0, IntegerType.INSTANCE), Integer.valueOf(2)); + assertEquals(data.get(2).getField(0, IntegerType.INSTANCE), Integer.valueOf(3)); + } + + @Test + public void testSortByIntegerFieldDescending() { + List data = new ArrayList<>(3); + data.add(ObjectRow.create(new Object[] {1, BinaryString.fromString("a")})); + data.add(ObjectRow.create(new Object[] {3, BinaryString.fromString("c")})); + data.add(ObjectRow.create(new Object[] {2, BinaryString.fromString("b")})); + + SortInfo sortInfo = new SortInfo(); + sortInfo.orderByFields = new ArrayList<>(1); + OrderByField intField = new OrderByField(); + intField.expression = new FieldExpression(0, IntegerType.INSTANCE); + intField.order = ORDER.DESC; + sortInfo.orderByFields.add(intField); + + MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); + + assertEquals(data.size(), 3); + assertEquals(data.get(0).getField(0, IntegerType.INSTANCE), Integer.valueOf(3)); + assertEquals(data.get(1).getField(0, IntegerType.INSTANCE), Integer.valueOf(2)); + assertEquals(data.get(2).getField(0, IntegerType.INSTANCE), Integer.valueOf(1)); + } + + @Test + public void testSortByStringFieldAscending() { + List data = new ArrayList<>(3); + data.add(ObjectRow.create(new Object[] {3, BinaryString.fromString("zebra")})); + data.add(ObjectRow.create(new Object[] {1, BinaryString.fromString("apple")})); + data.add(ObjectRow.create(new Object[] {2, BinaryString.fromString("banana")})); + + SortInfo sortInfo = new SortInfo(); + sortInfo.orderByFields = new ArrayList<>(1); + OrderByField stringField = new OrderByField(); + stringField.expression = new FieldExpression(1, BinaryStringType.INSTANCE); + stringField.order = ORDER.ASC; + sortInfo.orderByFields.add(stringField); + + MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); + + assertEquals(data.size(), 3); + assertEquals( + ((BinaryString) data.get(0).getField(1, BinaryStringType.INSTANCE)).toString(), "apple"); + assertEquals( + ((BinaryString) data.get(1).getField(1, BinaryStringType.INSTANCE)).toString(), "banana"); + assertEquals( + ((BinaryString) data.get(2).getField(1, BinaryStringType.INSTANCE)).toString(), "zebra"); + } + + @Test + public void testSortByStringFieldDescending() { + List data = new ArrayList<>(3); + data.add(ObjectRow.create(new Object[] {1, BinaryString.fromString("apple")})); + data.add(ObjectRow.create(new Object[] {3, BinaryString.fromString("zebra")})); + data.add(ObjectRow.create(new Object[] {2, BinaryString.fromString("banana")})); + + SortInfo sortInfo = new SortInfo(); + sortInfo.orderByFields = new ArrayList<>(1); + OrderByField stringField = new OrderByField(); + stringField.expression = new FieldExpression(1, BinaryStringType.INSTANCE); + stringField.order = ORDER.DESC; + sortInfo.orderByFields.add(stringField); + + MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); + + assertEquals(data.size(), 3); + assertEquals( + ((BinaryString) data.get(0).getField(1, BinaryStringType.INSTANCE)).toString(), "zebra"); + assertEquals( + ((BinaryString) data.get(1).getField(1, BinaryStringType.INSTANCE)).toString(), "banana"); + assertEquals( + ((BinaryString) data.get(2).getField(1, BinaryStringType.INSTANCE)).toString(), "apple"); + } + + @Test + public void testMultiFieldSort() { + List data = new ArrayList<>(4); + data.add(ObjectRow.create(new Object[] {1, BinaryString.fromString("b")})); + data.add(ObjectRow.create(new Object[] {2, BinaryString.fromString("a")})); + data.add(ObjectRow.create(new Object[] {1, BinaryString.fromString("a")})); + data.add(ObjectRow.create(new Object[] {2, BinaryString.fromString("b")})); + + SortInfo sortInfo = new SortInfo(); + sortInfo.orderByFields = new ArrayList<>(2); + + // First sort by integer field (ascending) + OrderByField intField = new OrderByField(); + intField.expression = new FieldExpression(0, IntegerType.INSTANCE); + intField.order = ORDER.ASC; + sortInfo.orderByFields.add(intField); + + // Then sort by string field (ascending) + OrderByField stringField = new OrderByField(); + stringField.expression = new FieldExpression(1, BinaryStringType.INSTANCE); + stringField.order = ORDER.ASC; + sortInfo.orderByFields.add(stringField); + + MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); + + assertEquals(data.size(), 4); + // Expected order: (1, "a"), (1, "b"), (2, "a"), (2, "b") + assertEquals(data.get(0).getField(0, IntegerType.INSTANCE), Integer.valueOf(1)); + assertEquals( + ((BinaryString) data.get(0).getField(1, BinaryStringType.INSTANCE)).toString(), "a"); + + assertEquals(data.get(1).getField(0, IntegerType.INSTANCE), Integer.valueOf(1)); + assertEquals( + ((BinaryString) data.get(1).getField(1, BinaryStringType.INSTANCE)).toString(), "b"); + + assertEquals(data.get(2).getField(0, IntegerType.INSTANCE), Integer.valueOf(2)); + assertEquals( + ((BinaryString) data.get(2).getField(1, BinaryStringType.INSTANCE)).toString(), "a"); + + assertEquals(data.get(3).getField(0, IntegerType.INSTANCE), Integer.valueOf(2)); + assertEquals( + ((BinaryString) data.get(3).getField(1, BinaryStringType.INSTANCE)).toString(), "b"); + } + + @Test + public void testSortWithNullValues() { + List data = new ArrayList<>(3); + data.add(ObjectRow.create(new Object[] {null, BinaryString.fromString("b")})); + data.add(ObjectRow.create(new Object[] {2, BinaryString.fromString("a")})); + data.add(ObjectRow.create(new Object[] {1, BinaryString.fromString("c")})); + + SortInfo sortInfo = new SortInfo(); + sortInfo.orderByFields = new ArrayList<>(1); + OrderByField intField = new OrderByField(); + intField.expression = new FieldExpression(0, IntegerType.INSTANCE); + intField.order = ORDER.ASC; + sortInfo.orderByFields.add(intField); + + MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); + + assertEquals(data.size(), 3); + // Null values should appear first in ascending order + assertEquals(data.get(0).getField(0, IntegerType.INSTANCE), null); + assertEquals(data.get(1).getField(0, IntegerType.INSTANCE), Integer.valueOf(1)); + assertEquals(data.get(2).getField(0, IntegerType.INSTANCE), Integer.valueOf(2)); + } + + @Test + public void testSortWithNegativeNumbers() { + List data = new ArrayList<>(4); + data.add(ObjectRow.create(new Object[] {-1, BinaryString.fromString("a")})); + data.add(ObjectRow.create(new Object[] {3, BinaryString.fromString("b")})); + data.add(ObjectRow.create(new Object[] {-5, BinaryString.fromString("c")})); + data.add(ObjectRow.create(new Object[] {0, BinaryString.fromString("d")})); + + SortInfo sortInfo = new SortInfo(); + sortInfo.orderByFields = new ArrayList<>(1); + OrderByField intField = new OrderByField(); + intField.expression = new FieldExpression(0, IntegerType.INSTANCE); + intField.order = ORDER.ASC; + sortInfo.orderByFields.add(intField); + + MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); + + assertEquals(data.size(), 4); + assertEquals(data.get(0).getField(0, IntegerType.INSTANCE), Integer.valueOf(-5)); + assertEquals(data.get(1).getField(0, IntegerType.INSTANCE), Integer.valueOf(-1)); + assertEquals(data.get(2).getField(0, IntegerType.INSTANCE), Integer.valueOf(0)); + assertEquals(data.get(3).getField(0, IntegerType.INSTANCE), Integer.valueOf(3)); + } + + @Test + public void testSortWithEmptyStrings() { + List data = new ArrayList<>(3); + data.add(ObjectRow.create(new Object[] {1, BinaryString.fromString("")})); + data.add(ObjectRow.create(new Object[] {2, BinaryString.fromString("hello")})); + data.add(ObjectRow.create(new Object[] {3, BinaryString.fromString("a")})); + + SortInfo sortInfo = new SortInfo(); + sortInfo.orderByFields = new ArrayList<>(1); + OrderByField stringField = new OrderByField(); + stringField.expression = new FieldExpression(1, BinaryStringType.INSTANCE); + stringField.order = ORDER.ASC; + sortInfo.orderByFields.add(stringField); + + MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); + + assertEquals(data.size(), 3); + // Empty string should come first + assertEquals( + ((BinaryString) data.get(0).getField(1, BinaryStringType.INSTANCE)).toString(), ""); + assertEquals( + ((BinaryString) data.get(1).getField(1, BinaryStringType.INSTANCE)).toString(), "a"); + assertEquals( + ((BinaryString) data.get(2).getField(1, BinaryStringType.INSTANCE)).toString(), "hello"); + } + + @Test + public void testSortStability() { + List data = new ArrayList<>(3); + // Create rows with same sort key but different secondary values + data.add(ObjectRow.create(new Object[] {1, BinaryString.fromString("first")})); + data.add(ObjectRow.create(new Object[] {1, BinaryString.fromString("second")})); + data.add(ObjectRow.create(new Object[] {1, BinaryString.fromString("third")})); + + SortInfo sortInfo = new SortInfo(); + sortInfo.orderByFields = new ArrayList<>(1); + OrderByField intField = new OrderByField(); + intField.expression = new FieldExpression(0, IntegerType.INSTANCE); + intField.order = ORDER.ASC; + sortInfo.orderByFields.add(intField); + + MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); - @Test - public void testSortStability() { - List data = new ArrayList<>(3); - // Create rows with same sort key but different secondary values - data.add(ObjectRow.create(new Object[]{1, BinaryString.fromString("first")})); - data.add(ObjectRow.create(new Object[]{1, BinaryString.fromString("second")})); - data.add(ObjectRow.create(new Object[]{1, BinaryString.fromString("third")})); - - SortInfo sortInfo = new SortInfo(); - sortInfo.orderByFields = new ArrayList<>(1); - OrderByField intField = new OrderByField(); - intField.expression = new FieldExpression(0, IntegerType.INSTANCE); - intField.order = ORDER.ASC; - sortInfo.orderByFields.add(intField); - - MultiFieldRadixSort.multiFieldRadixSort(data, sortInfo); - - assertEquals(data.size(), 3); - // All should have the same integer value - for (Row row : data) { - assertEquals(row.getField(0, IntegerType.INSTANCE), Integer.valueOf(1)); - } + assertEquals(data.size(), 3); + // All should have the same integer value + for (Row row : data) { + assertEquals(row.getField(0, IntegerType.INSTANCE), Integer.valueOf(1)); } -} \ No newline at end of file + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/ProjectTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/ProjectTest.java index 3631f8c21..71917d3a7 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/ProjectTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/ProjectTest.java @@ -23,58 +23,33 @@ public class ProjectTest { - @Test - public void testProject_001() throws Exception { - QueryTester - .build() - .withQueryPath("/query/project_001.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testProject_002() throws Exception { - QueryTester - .build() - .withQueryPath("/query/project_002.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testProject_003() throws Exception { - QueryTester - .build() - .withQueryPath("/query/project_003.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testProject_004() throws Exception { - QueryTester - .build() - .withQueryPath("/query/project_004.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testProject_005() throws Exception { - QueryTester - .build() - .withQueryPath("/query/project_005.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testProject_006() throws Exception { - QueryTester - .build() - .withQueryPath("/query/project_006.sql") - .execute() - .checkSinkResult(); - } - + @Test + public void testProject_001() throws Exception { + QueryTester.build().withQueryPath("/query/project_001.sql").execute().checkSinkResult(); + } + + @Test + public void testProject_002() throws Exception { + QueryTester.build().withQueryPath("/query/project_002.sql").execute().checkSinkResult(); + } + + @Test + public void testProject_003() throws Exception { + QueryTester.build().withQueryPath("/query/project_003.sql").execute().checkSinkResult(); + } + + @Test + public void testProject_004() throws Exception { + QueryTester.build().withQueryPath("/query/project_004.sql").execute().checkSinkResult(); + } + + @Test + public void testProject_005() throws Exception { + QueryTester.build().withQueryPath("/query/project_005.sql").execute().checkSinkResult(); + } + + @Test + public void testProject_006() throws Exception { + QueryTester.build().withQueryPath("/query/project_006.sql").execute().checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/QueryTester.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/QueryTester.java index 6ddcd691c..d77612486 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/QueryTester.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/QueryTester.java @@ -19,7 +19,6 @@ package org.apache.geaflow.dsl.runtime.query; -import com.google.common.base.Preconditions; import java.io.File; import java.io.IOException; import java.io.Serializable; @@ -32,6 +31,7 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; + import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.StringUtils; @@ -51,255 +51,260 @@ import org.apache.geaflow.runtime.core.scheduler.resource.ScheduledWorkerManagerFactory; import org.testng.Assert; +import com.google.common.base.Preconditions; + public class QueryTester implements Serializable { - private int testTimeWaitSeconds = 0; + private int testTimeWaitSeconds = 0; - public static final String INIT_DDL = "/query/modern_graph.sql"; - public static final String DSL_STATE_REMOTE_PATH = "/tmp/dsl/"; + public static final String INIT_DDL = "/query/modern_graph.sql"; + public static final String DSL_STATE_REMOTE_PATH = "/tmp/dsl/"; - private String queryPath; + private String queryPath; - private boolean compareWithOrder = false; + private boolean compareWithOrder = false; - private String graphDefinePath; + private String graphDefinePath; - private boolean hasCustomWindowConfig = false; + private boolean hasCustomWindowConfig = false; - protected boolean dedupe = false; + protected boolean dedupe = false; - private int workerNum = (int) ExecutionConfigKeys.CONTAINER_WORKER_NUM.getDefaultValue(); + private int workerNum = (int) ExecutionConfigKeys.CONTAINER_WORKER_NUM.getDefaultValue(); - private final Map config = new HashMap<>(); + private final Map config = new HashMap<>(); - private QueryTester() { - try { - initRemotePath(); - } catch (IOException e) { - throw new RuntimeException(e); - } + private QueryTester() { + try { + initRemotePath(); + } catch (IOException e) { + throw new RuntimeException(e); } - - public static QueryTester build() { - return new QueryTester(); + } + + public static QueryTester build() { + return new QueryTester(); + } + + public QueryTester withQueryPath(String queryPath) { + this.queryPath = queryPath; + return this; + } + + public QueryTester withTestTimeWaitSeconds(int testTimeWaitSeconds) { + this.testTimeWaitSeconds = testTimeWaitSeconds; + return this; + } + + public QueryTester withDedupe(boolean dedupe) { + this.dedupe = dedupe; + return this; + } + + public QueryTester compareWithOrder() { + this.compareWithOrder = true; + return this; + } + + public QueryTester withConfig(Map config) { + this.config.putAll(config); + return this; + } + + public QueryTester withConfig(String key, Object value) { + this.config.put(key, String.valueOf(value)); + return this; + } + + public QueryTester withCustomWindow() { + hasCustomWindowConfig = true; + return this; + } + + public QueryTester withWorkerNum(int workerNum) { + this.workerNum = workerNum; + return this; + } + + public QueryTester execute() throws Exception { + if (queryPath == null) { + throw new IllegalArgumentException("You should call withQueryPath() before execute()."); } - - - public QueryTester withQueryPath(String queryPath) { - this.queryPath = queryPath; - return this; + Map config = new HashMap<>(); + if (!hasCustomWindowConfig) { + config.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), String.valueOf(-1L)); } - - public QueryTester withTestTimeWaitSeconds(int testTimeWaitSeconds) { - this.testTimeWaitSeconds = testTimeWaitSeconds; - return this; + config.put(FileConfigKeys.ROOT.getKey(), DSL_STATE_REMOTE_PATH); + config.put( + DSLConfigKeys.GEAFLOW_DSL_QUERY_PATH.getKey(), + FileConstants.PREFIX_JAVA_RESOURCE + queryPath); + config.put(ExecutionConfigKeys.CONTAINER_WORKER_NUM.getKey(), String.valueOf(workerNum)); + config.putAll(this.config); + initResultDirectory(); + + Environment environment = EnvironmentFactory.onLocalEnvironment(); + environment.getEnvironmentContext().withConfig(config); + + GQLPipeLine gqlPipeLine = new GQLPipeLine(environment, testTimeWaitSeconds); + + String graphDefinePath = null; + if (this.graphDefinePath != null) { + graphDefinePath = this.graphDefinePath; } - - public QueryTester withDedupe(boolean dedupe) { - this.dedupe = dedupe; - return this; + gqlPipeLine.setPipelineHook(new TestGQLPipelineHook(graphDefinePath, queryPath)); + try { + gqlPipeLine.execute(); + } finally { + environment.shutdown(); + ClusterMetaStore.close(); + ScheduledWorkerManagerFactory.clear(); } - - public QueryTester compareWithOrder() { - this.compareWithOrder = true; - return this; + return this; + } + + private void initResultDirectory() throws Exception { + // delete target file path + String targetPath = getTargetPath(queryPath); + File targetFile = new File(targetPath); + if (targetFile.exists()) { + FileUtils.forceDelete(targetFile); } + } - public QueryTester withConfig(Map config) { - this.config.putAll(config); - return this; + private void initRemotePath() throws IOException { + // delete state remote path + File stateRemoteFile = new File(DSL_STATE_REMOTE_PATH); + if (stateRemoteFile.exists()) { + FileUtils.forceDelete(stateRemoteFile); } - - public QueryTester withConfig(String key, Object value) { - this.config.put(key, String.valueOf(value)); - return this; + } + + public void checkSinkResult() throws Exception { + checkSinkResult(null); + } + + public void checkSinkResult(String dict) throws Exception { + String[] paths = queryPath.split("/"); + String lastPath = paths[paths.length - 1]; + String exceptPath = + dict != null + ? "/expect/" + dict + "/" + lastPath.split("\\.")[0] + ".txt" + : "/expect/" + lastPath.split("\\.")[0] + ".txt"; + String targetPath = getTargetPath(queryPath); + String expectResult = IOUtils.resourceToString(exceptPath, Charset.defaultCharset()).trim(); + String actualResult = readFile(targetPath); + compareResult(actualResult, expectResult); + } + + private void compareResult(String actualResult, String expectResult) { + if (compareWithOrder) { + Assert.assertEquals(actualResult, expectResult); + } else { + String[] actualLines = actualResult.split("\n"); + String[] expectLines = expectResult.split("\n"); + if (dedupe) { + List actualLinesDedupe = + Arrays.asList(actualLines).stream().distinct().collect(Collectors.toList()); + actualLines = actualLinesDedupe.toArray(new String[0]); + List expectLinesDedupe = + Arrays.asList(expectLines).stream().distinct().collect(Collectors.toList()); + expectLines = expectLinesDedupe.toArray(new String[0]); + } + Arrays.sort(actualLines); + Arrays.sort(expectLines); + + String actualSort = StringUtils.join(actualLines, "\n"); + String expectSort = StringUtils.join(expectLines, "\n"); + if (!Objects.equals(actualSort, expectSort)) { + Assert.assertEquals(actualResult, expectResult); + } } + } - public QueryTester withCustomWindow() { - hasCustomWindowConfig = true; - return this; + private String readFile(String path) throws IOException { + File file = new File(path); + if (file.isHidden()) { + return ""; } - - public QueryTester withWorkerNum(int workerNum) { - this.workerNum = workerNum; - return this; + if (file.isFile()) { + return IOUtils.toString(new File(path).toURI(), Charset.defaultCharset()).trim(); } - - public QueryTester execute() throws Exception { - if (queryPath == null) { - throw new IllegalArgumentException("You should call withQueryPath() before execute()."); + File[] files = file.listFiles(); + StringBuilder content = new StringBuilder(); + if (files != null) { + for (File subFile : files) { + String readText = readFile(subFile.getAbsolutePath()); + if (StringUtils.isBlank(readText)) { + continue; } - Map config = new HashMap<>(); - if (!hasCustomWindowConfig) { - config.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), String.valueOf(-1L)); + if (content.length() > 0) { + content.append("\n"); } - config.put(FileConfigKeys.ROOT.getKey(), DSL_STATE_REMOTE_PATH); - config.put(DSLConfigKeys.GEAFLOW_DSL_QUERY_PATH.getKey(), FileConstants.PREFIX_JAVA_RESOURCE + queryPath); - config.put(ExecutionConfigKeys.CONTAINER_WORKER_NUM.getKey(), String.valueOf(workerNum)); - config.putAll(this.config); - initResultDirectory(); + content.append(readText); + } + } + return content.toString().trim(); + } - Environment environment = EnvironmentFactory.onLocalEnvironment(); - environment.getEnvironmentContext().withConfig(config); + private static String getTargetPath(String queryPath) { + assert queryPath != null; + String[] paths = queryPath.split("/"); + String lastPath = paths[paths.length - 1]; + String targetPath = "target/" + lastPath.split("\\.")[0]; + String currentPath = new File(".").getAbsolutePath(); + targetPath = currentPath.substring(0, currentPath.length() - 1) + targetPath; + return targetPath; + } - GQLPipeLine gqlPipeLine = new GQLPipeLine(environment, testTimeWaitSeconds); + public QueryTester withGraphDefine(String graphDefinePath) { + this.graphDefinePath = Objects.requireNonNull(graphDefinePath); + return this; + } - String graphDefinePath = null; - if (this.graphDefinePath != null) { - graphDefinePath = this.graphDefinePath; - } - gqlPipeLine.setPipelineHook(new TestGQLPipelineHook(graphDefinePath, queryPath)); - try { - gqlPipeLine.execute(); - } finally { - environment.shutdown(); - ClusterMetaStore.close(); - ScheduledWorkerManagerFactory.clear(); - } - return this; - } + private static class TestGQLPipelineHook implements GQLPipelineHook { - private void initResultDirectory() throws Exception { - // delete target file path - String targetPath = getTargetPath(queryPath); - File targetFile = new File(targetPath); - if (targetFile.exists()) { - FileUtils.forceDelete(targetFile); - } - } + private final String graphDefinePath; - private void initRemotePath() throws IOException { - // delete state remote path - File stateRemoteFile = new File(DSL_STATE_REMOTE_PATH); - if (stateRemoteFile.exists()) { - FileUtils.forceDelete(stateRemoteFile); - } - } - - public void checkSinkResult() throws Exception { - checkSinkResult(null); - } + private final String queryPath; - public void checkSinkResult(String dict) throws Exception { - String[] paths = queryPath.split("/"); - String lastPath = paths[paths.length - 1]; - String exceptPath = dict != null ? "/expect/" + dict + "/" + lastPath.split("\\.")[0] + ".txt" - : "/expect/" + lastPath.split("\\.")[0] + ".txt"; - String targetPath = getTargetPath(queryPath); - String expectResult = IOUtils.resourceToString(exceptPath, Charset.defaultCharset()).trim(); - String actualResult = readFile(targetPath); - compareResult(actualResult, expectResult); + public TestGQLPipelineHook(String graphDefinePath, String queryPath) { + this.graphDefinePath = graphDefinePath; + this.queryPath = queryPath; } - private void compareResult(String actualResult, String expectResult) { - if (compareWithOrder) { - Assert.assertEquals(actualResult, expectResult); + @Override + public String rewriteScript(String script, Configuration configuration) { + String result = script; + String regex = "\\$\\{[^}]+}"; + Pattern pattern = Pattern.compile(regex); + Matcher matcher = pattern.matcher(result); + while (matcher.find()) { + String matchedField = matcher.group(); + String replaceKey = matchedField.substring(2, matchedField.length() - 1); + if (replaceKey.equals("target")) { + result = result.replace(matchedField, getTargetPath(queryPath)); } else { - String[] actualLines = actualResult.split("\n"); - String[] expectLines = expectResult.split("\n"); - if (dedupe) { - List actualLinesDedupe = Arrays.asList(actualLines).stream().distinct().collect(Collectors.toList()); - actualLines = actualLinesDedupe.toArray(new String[0]); - List expectLinesDedupe = Arrays.asList(expectLines).stream().distinct().collect(Collectors.toList()); - expectLines = expectLinesDedupe.toArray(new String[0]); - } - Arrays.sort(actualLines); - Arrays.sort(expectLines); - - String actualSort = StringUtils.join(actualLines, "\n"); - String expectSort = StringUtils.join(expectLines, "\n"); - if (!Objects.equals(actualSort, expectSort)) { - Assert.assertEquals(actualResult, expectResult); - } + String replaceData = configuration.getString(replaceKey); + Preconditions.checkState(replaceData != null, "Not found replace key:{}", replaceKey); + result = result.replace(matchedField, replaceData); } + } + return result; } - private String readFile(String path) throws IOException { - File file = new File(path); - if (file.isHidden()) { - return ""; - } - if (file.isFile()) { - return IOUtils.toString(new File(path).toURI(), Charset.defaultCharset()).trim(); - } - File[] files = file.listFiles(); - StringBuilder content = new StringBuilder(); - if (files != null) { - for (File subFile : files) { - String readText = readFile(subFile.getAbsolutePath()); - if (StringUtils.isBlank(readText)) { - continue; - } - if (content.length() > 0) { - content.append("\n"); - } - content.append(readText); - } + @Override + public void beforeExecute(QueryClient queryClient, QueryContext queryContext) { + if (graphDefinePath != null) { + try { + String ddl = IOUtils.resourceToString(graphDefinePath, Charset.defaultCharset()); + queryClient.executeQuery(ddl, queryContext); + } catch (IOException e) { + throw new GeaFlowDSLException(e); } - return content.toString().trim(); - } - - private static String getTargetPath(String queryPath) { - assert queryPath != null; - String[] paths = queryPath.split("/"); - String lastPath = paths[paths.length - 1]; - String targetPath = "target/" + lastPath.split("\\.")[0]; - String currentPath = new File(".").getAbsolutePath(); - targetPath = currentPath.substring(0, currentPath.length() - 1) + targetPath; - return targetPath; - } - - public QueryTester withGraphDefine(String graphDefinePath) { - this.graphDefinePath = Objects.requireNonNull(graphDefinePath); - return this; + } } - private static class TestGQLPipelineHook implements GQLPipelineHook { - - private final String graphDefinePath; - - private final String queryPath; - - public TestGQLPipelineHook(String graphDefinePath, String queryPath) { - this.graphDefinePath = graphDefinePath; - this.queryPath = queryPath; - } - - @Override - public String rewriteScript(String script, Configuration configuration) { - String result = script; - String regex = "\\$\\{[^}]+}"; - Pattern pattern = Pattern.compile(regex); - Matcher matcher = pattern.matcher(result); - while (matcher.find()) { - String matchedField = matcher.group(); - String replaceKey = matchedField.substring(2, matchedField.length() - 1); - if (replaceKey.equals("target")) { - result = result.replace(matchedField, getTargetPath(queryPath)); - } else { - String replaceData = configuration.getString(replaceKey); - Preconditions.checkState(replaceData != null, "Not found replace key:{}", replaceKey); - result = result.replace(matchedField, replaceData); - } - } - return result; - } - - @Override - public void beforeExecute(QueryClient queryClient, QueryContext queryContext) { - if (graphDefinePath != null) { - try { - String ddl = IOUtils.resourceToString(graphDefinePath, Charset.defaultCharset()); - queryClient.executeQuery(ddl, queryContext); - } catch (IOException e) { - throw new GeaFlowDSLException(e); - } - } - } - - @Override - public void afterExecute(QueryClient queryClient, QueryContext queryContext) { - - } - } + @Override + public void afterExecute(QueryClient queryClient, QueryContext queryContext) {} + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/SortTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/SortTest.java index f5e957862..e2ebfdc9d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/SortTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/SortTest.java @@ -23,39 +23,23 @@ public class SortTest { - @Test - public void testSort_001() throws Exception { - QueryTester - .build() - .withQueryPath("/query/sort_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSort_001() throws Exception { + QueryTester.build().withQueryPath("/query/sort_001.sql").execute().checkSinkResult(); + } - @Test - public void testSort_002() throws Exception { - QueryTester - .build() - .withQueryPath("/query/sort_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSort_002() throws Exception { + QueryTester.build().withQueryPath("/query/sort_002.sql").execute().checkSinkResult(); + } - @Test - public void testSort_003() throws Exception { - QueryTester - .build() - .withQueryPath("/query/sort_003.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSort_003() throws Exception { + QueryTester.build().withQueryPath("/query/sort_003.sql").execute().checkSinkResult(); + } - @Test - public void testSort_004() throws Exception { - QueryTester - .build() - .withQueryPath("/query/sort_004.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testSort_004() throws Exception { + QueryTester.build().withQueryPath("/query/sort_004.sql").execute().checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/TableScanTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/TableScanTest.java index 8ef97889e..78838aa45 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/TableScanTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/TableScanTest.java @@ -23,48 +23,28 @@ public class TableScanTest { - @Test - public void testScan_001() throws Exception { - QueryTester - .build() - .withQueryPath("/query/scan_001.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testScan_002() throws Exception { - QueryTester - .build() - .withQueryPath("/query/scan_002.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testScan_003() throws Exception { - QueryTester - .build() - .withQueryPath("/query/scan_003.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testScan_004() throws Exception { - QueryTester - .build() - .withQueryPath("/query/scan_004.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testScan_005() throws Exception { - QueryTester - .build() - .withQueryPath("/query/scan_005.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testScan_001() throws Exception { + QueryTester.build().withQueryPath("/query/scan_001.sql").execute().checkSinkResult(); + } + + @Test + public void testScan_002() throws Exception { + QueryTester.build().withQueryPath("/query/scan_002.sql").execute().checkSinkResult(); + } + + @Test + public void testScan_003() throws Exception { + QueryTester.build().withQueryPath("/query/scan_003.sql").execute().checkSinkResult(); + } + + @Test + public void testScan_004() throws Exception { + QueryTester.build().withQueryPath("/query/scan_004.sql").execute().checkSinkResult(); + } + + @Test + public void testScan_005() throws Exception { + QueryTester.build().withQueryPath("/query/scan_005.sql").execute().checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/TypesTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/TypesTest.java index 23751fb8b..724139ef4 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/TypesTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/TypesTest.java @@ -23,30 +23,18 @@ public class TypesTest { - @Test - public void testBooleanType_001() throws Exception { - QueryTester - .build() - .withQueryPath("/query/type_boolean_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testBooleanType_001() throws Exception { + QueryTester.build().withQueryPath("/query/type_boolean_001.sql").execute().checkSinkResult(); + } - @Test - public void testTimestampType_001() throws Exception { - QueryTester - .build() - .withQueryPath("/query/type_timestamp_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testTimestampType_001() throws Exception { + QueryTester.build().withQueryPath("/query/type_timestamp_001.sql").execute().checkSinkResult(); + } - @Test - public void testDateType_001() throws Exception { - QueryTester - .build() - .withQueryPath("/query/type_date_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testDateType_001() throws Exception { + QueryTester.build().withQueryPath("/query/type_date_001.sql").execute().checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/UnionTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/UnionTest.java index 60d8bfc79..543a2d11a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/UnionTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/UnionTest.java @@ -23,21 +23,13 @@ public class UnionTest { - @Test - public void testUnion_001() throws Exception { - QueryTester - .build() - .withQueryPath("/query/union_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testUnion_001() throws Exception { + QueryTester.build().withQueryPath("/query/union_001.sql").execute().checkSinkResult(); + } - @Test - public void testUnion_002() throws Exception { - QueryTester - .build() - .withQueryPath("/query/union_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testUnion_002() throws Exception { + QueryTester.build().withQueryPath("/query/union_002.sql").execute().checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/ValuesTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/ValuesTest.java index 699747ea5..7b32d48a8 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/ValuesTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/ValuesTest.java @@ -23,21 +23,13 @@ public class ValuesTest { - @Test - public void testValues_001() throws Exception { - QueryTester - .build() - .withQueryPath("/query/values_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testValues_001() throws Exception { + QueryTester.build().withQueryPath("/query/values_001.sql").execute().checkSinkResult(); + } - @Test - public void testValues_002() throws Exception { - QueryTester - .build() - .withQueryPath("/query/values_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testValues_002() throws Exception { + QueryTester.build().withQueryPath("/query/values_002.sql").execute().checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/ViewTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/ViewTest.java index 46927bada..c3bc9752d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/ViewTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/ViewTest.java @@ -23,39 +23,23 @@ public class ViewTest { - @Test - public void testView_001() throws Exception { - QueryTester - .build() - .withQueryPath("/query/view_001.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testView_001() throws Exception { + QueryTester.build().withQueryPath("/query/view_001.sql").execute().checkSinkResult(); + } - @Test - public void testView_002() throws Exception { - QueryTester - .build() - .withQueryPath("/query/view_002.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testView_002() throws Exception { + QueryTester.build().withQueryPath("/query/view_002.sql").execute().checkSinkResult(); + } - @Test - public void testView_003() throws Exception { - QueryTester - .build() - .withQueryPath("/query/view_003.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testView_003() throws Exception { + QueryTester.build().withQueryPath("/query/view_003.sql").execute().checkSinkResult(); + } - @Test - public void testView_004() throws Exception { - QueryTester - .build() - .withQueryPath("/query/view_004.sql") - .execute() - .checkSinkResult(); - } + @Test + public void testView_004() throws Exception { + QueryTester.build().withQueryPath("/query/view_004.sql").execute().checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/udf/JsonParserTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/udf/JsonParserTest.java index 2775bacff..a26ecdf32 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/udf/JsonParserTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/udf/JsonParserTest.java @@ -24,13 +24,8 @@ public class JsonParserTest { - @Test - public void testJsonPathGet() throws Exception { - QueryTester - .build() - .withQueryPath("/query/json_path_get_001.sql") - .execute() - .checkSinkResult(); - } - + @Test + public void testJsonPathGet() throws Exception { + QueryTester.build().withQueryPath("/query/json_path_get_001.sql").execute().checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/udf/MyCount.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/udf/MyCount.java index 797cdf523..eebeb2241 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/udf/MyCount.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/udf/MyCount.java @@ -22,6 +22,7 @@ import java.io.Serializable; import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.common.binary.BinaryString; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDAF; @@ -32,67 +33,62 @@ @Description(name = "count", description = "custom count function for test") public class MyCount extends UDAF { - @Override - public Accumulator createAccumulator() { - return new Accumulator(0); - } + @Override + public Accumulator createAccumulator() { + return new Accumulator(0); + } - @Override - public void accumulate(Accumulator accumulator, MultiArguments input) { - if (input.getParam(0).toString().equalsIgnoreCase("jim")) { - accumulator.value += 10000L; - } - accumulator.value += (long) input.getParam(1); + @Override + public void accumulate(Accumulator accumulator, MultiArguments input) { + if (input.getParam(0).toString().equalsIgnoreCase("jim")) { + accumulator.value += 10000L; } + accumulator.value += (long) input.getParam(1); + } - @Override - public void merge(Accumulator accumulator, Iterable its) { - for (Accumulator toMerge : its) { - accumulator.value += toMerge.value; - } + @Override + public void merge(Accumulator accumulator, Iterable its) { + for (Accumulator toMerge : its) { + accumulator.value += toMerge.value; } + } - @Override - public void resetAccumulator(Accumulator accumulator) { - accumulator.value = 0L; - } + @Override + public void resetAccumulator(Accumulator accumulator) { + accumulator.value = 0L; + } - @Override - public Long getValue(Accumulator accumulator) { - return accumulator.value; - } + @Override + public Long getValue(Accumulator accumulator) { + return accumulator.value; + } - public static class Accumulator implements Serializable { + public static class Accumulator implements Serializable { - public Accumulator() { - } + public Accumulator() {} - public long value = 0; + public long value = 0; - public Accumulator(long value) { - this.value = value; - } + public Accumulator(long value) { + this.value = value; + } - @Override - public String toString() { - return "Accumulator{" - + "value=" + value - + '}'; - } + @Override + public String toString() { + return "Accumulator{" + "value=" + value + '}'; } + } - public static class MultiArguments extends UDAFArguments { + public static class MultiArguments extends UDAFArguments { - public MultiArguments() { - } + public MultiArguments() {} - @Override - public List> getParamTypes() { - List> types = new ArrayList<>(); - types.add(BinaryString.class); - types.add(Long.class); - return types; - } + @Override + public List> getParamTypes() { + List> types = new ArrayList<>(); + types.add(BinaryString.class); + types.add(Long.class); + return types; } + } } - diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/udf/TestEdgeIteratorUdf.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/udf/TestEdgeIteratorUdf.java index 815d0d559..614dbfacc 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/udf/TestEdgeIteratorUdf.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/udf/TestEdgeIteratorUdf.java @@ -19,11 +19,10 @@ package org.apache.geaflow.dsl.runtime.query.udf; -import com.google.common.base.Preconditions; -import com.google.common.collect.Lists; import java.util.Iterator; import java.util.List; import java.util.Optional; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.common.type.primitive.LongType; import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; @@ -40,72 +39,77 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; + @Description(name = "test_edge_iterator", description = "built-in udga for WeakConnectedComponents") public class TestEdgeIteratorUdf implements AlgorithmUserFunction { - private static final Logger LOGGER = LoggerFactory.getLogger(TestEdgeIteratorUdf.class); + private static final Logger LOGGER = LoggerFactory.getLogger(TestEdgeIteratorUdf.class); - private AlgorithmRuntimeContext context; + private AlgorithmRuntimeContext context; - private int iteration = 5; + private int iteration = 5; - private int edgeLimit = 100; + private int edgeLimit = 100; - @Override - public void init(AlgorithmRuntimeContext context, Object[] parameters) { - this.context = context; - if (parameters.length > 0) { - iteration = Integer.parseInt(String.valueOf(parameters[0])); - } - if (parameters.length > 1) { - edgeLimit = Integer.parseInt(String.valueOf(parameters[1])); - } + @Override + public void init(AlgorithmRuntimeContext context, Object[] parameters) { + this.context = context; + if (parameters.length > 0) { + iteration = Integer.parseInt(String.valueOf(parameters[0])); } + if (parameters.length > 1) { + edgeLimit = Integer.parseInt(String.valueOf(parameters[1])); + } + } - @Override - public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { - updatedValues.ifPresent(vertex::setValue); - CloseableIterator edgesIterator = context.loadStaticEdgesIterator(EdgeDirection.BOTH); - if (context.getCurrentIterationId() < iteration) { - int count = 0; - while (edgesIterator.hasNext() && count < edgeLimit) { - RowEdge next = edgesIterator.next(); - context.sendMessage(next.getTargetId(), context.getCurrentIterationId()); - count++; - } - } - List bothEdge = Lists.newArrayList(context.loadStaticEdgesIterator(EdgeDirection.BOTH)); - List inEdge = Lists.newArrayList(context.loadStaticEdgesIterator(EdgeDirection.IN)); - List outEdge = Lists.newArrayList(context.loadStaticEdgesIterator(EdgeDirection.OUT)); - Preconditions.checkState(bothEdge.size() == inEdge.size() + outEdge.size(), "Static edge not equal"); - List bothEdgeList = context.loadStaticEdges(EdgeDirection.BOTH); - Preconditions.checkState(bothEdgeList.size() == inEdge.size() + outEdge.size(), "Static edge not equal"); - - bothEdge = Lists.newArrayList(context.loadDynamicEdgesIterator(EdgeDirection.BOTH)); - inEdge = Lists.newArrayList(context.loadDynamicEdgesIterator(EdgeDirection.IN)); - outEdge = Lists.newArrayList(context.loadDynamicEdgesIterator(EdgeDirection.OUT)); - Preconditions.checkState(bothEdge.size() == inEdge.size() + outEdge.size(), "Dynamic edge not equal"); - - bothEdge = Lists.newArrayList(context.loadEdgesIterator(EdgeDirection.BOTH)); - inEdge = Lists.newArrayList(context.loadEdgesIterator(EdgeDirection.IN)); - outEdge = Lists.newArrayList(context.loadEdgesIterator(EdgeDirection.OUT)); - Preconditions.checkState(bothEdge.size() == inEdge.size() + outEdge.size(), "History edge not equal"); + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + updatedValues.ifPresent(vertex::setValue); + CloseableIterator edgesIterator = context.loadStaticEdgesIterator(EdgeDirection.BOTH); + if (context.getCurrentIterationId() < iteration) { + int count = 0; + while (edgesIterator.hasNext() && count < edgeLimit) { + RowEdge next = edgesIterator.next(); + context.sendMessage(next.getTargetId(), context.getCurrentIterationId()); + count++; + } + } + List bothEdge = + Lists.newArrayList(context.loadStaticEdgesIterator(EdgeDirection.BOTH)); + List inEdge = Lists.newArrayList(context.loadStaticEdgesIterator(EdgeDirection.IN)); + List outEdge = Lists.newArrayList(context.loadStaticEdgesIterator(EdgeDirection.OUT)); + Preconditions.checkState( + bothEdge.size() == inEdge.size() + outEdge.size(), "Static edge not equal"); + List bothEdgeList = context.loadStaticEdges(EdgeDirection.BOTH); + Preconditions.checkState( + bothEdgeList.size() == inEdge.size() + outEdge.size(), "Static edge not equal"); + bothEdge = Lists.newArrayList(context.loadDynamicEdgesIterator(EdgeDirection.BOTH)); + inEdge = Lists.newArrayList(context.loadDynamicEdgesIterator(EdgeDirection.IN)); + outEdge = Lists.newArrayList(context.loadDynamicEdgesIterator(EdgeDirection.OUT)); + Preconditions.checkState( + bothEdge.size() == inEdge.size() + outEdge.size(), "Dynamic edge not equal"); - } + bothEdge = Lists.newArrayList(context.loadEdgesIterator(EdgeDirection.BOTH)); + inEdge = Lists.newArrayList(context.loadEdgesIterator(EdgeDirection.IN)); + outEdge = Lists.newArrayList(context.loadEdgesIterator(EdgeDirection.OUT)); + Preconditions.checkState( + bothEdge.size() == inEdge.size() + outEdge.size(), "History edge not equal"); + } - @Override - public void finish(RowVertex graphVertex, Optional updatedValues) { - updatedValues.ifPresent(graphVertex::setValue); - long iteration = (long) graphVertex.getValue().getField(0, LongType.INSTANCE); - context.take(ObjectRow.create(graphVertex.getId(), iteration)); - } + @Override + public void finish(RowVertex graphVertex, Optional updatedValues) { + updatedValues.ifPresent(graphVertex::setValue); + long iteration = (long) graphVertex.getValue().getField(0, LongType.INSTANCE); + context.take(ObjectRow.create(graphVertex.getId(), iteration)); + } - @Override - public StructType getOutputType(GraphSchema graphSchema) { - return new StructType( - new TableField("id", graphSchema.getIdType(), false), - new TableField("iteration", LongType.INSTANCE, false) - ); - } + @Override + public StructType getOutputType(GraphSchema graphSchema) { + return new StructType( + new TableField("id", graphSchema.getIdType(), false), + new TableField("iteration", LongType.INSTANCE, false)); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/udtf/Split2.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/udtf/Split2.java index 34e67890d..a9021b95d 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/udtf/Split2.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/udtf/Split2.java @@ -19,51 +19,53 @@ package org.apache.geaflow.dsl.runtime.query.udtf; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDTF; +import com.google.common.collect.Lists; + @Description(name = "split2", description = "") public class Split2 extends UDTF { - private String columnDelimiter = ","; - private String lineDelimiter = "\n"; + private String columnDelimiter = ","; + private String lineDelimiter = "\n"; - public void eval(String data) { - evalInternal(data, columnDelimiter, lineDelimiter); - } + public void eval(String data) { + evalInternal(data, columnDelimiter, lineDelimiter); + } - public void eval(String data, String columnDelimiter) { - evalInternal(data, columnDelimiter, lineDelimiter); - } + public void eval(String data, String columnDelimiter) { + evalInternal(data, columnDelimiter, lineDelimiter); + } - public void eval(String data, String columnDelimiter, String lineDelimiter) { - evalInternal(data, columnDelimiter, lineDelimiter); - } + public void eval(String data, String columnDelimiter, String lineDelimiter) { + evalInternal(data, columnDelimiter, lineDelimiter); + } - private void evalInternal(String data, String columnDelimiter, String lineDelimiter) { - String[] rows = StringUtils.split(data, lineDelimiter); - for (String row : rows) { - String[] split = StringUtils.split(row, columnDelimiter); - collect(split); - } + private void evalInternal(String data, String columnDelimiter, String lineDelimiter) { + String[] rows = StringUtils.split(data, lineDelimiter); + for (String row : rows) { + String[] split = StringUtils.split(row, columnDelimiter); + collect(split); } + } - @Override - public List> getReturnType(List> paramTypes, List udtfReturnFields) { + @Override + public List> getReturnType(List> paramTypes, List udtfReturnFields) { - List> clazzs = Lists.newArrayList(); + List> clazzs = Lists.newArrayList(); - if (udtfReturnFields == null) { - clazzs.add(String.class); - clazzs.add(String.class); - return clazzs; - } + if (udtfReturnFields == null) { + clazzs.add(String.class); + clazzs.add(String.class); + return clazzs; + } - for (int i = 0; i < udtfReturnFields.size(); i++) { - clazzs.add(String.class); - } - return clazzs; + for (int i = 0; i < udtfReturnFields.size(); i++) { + clazzs.add(String.class); } + return clazzs; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/udtf/SplitMap.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/udtf/SplitMap.java index 771d76136..a9299e103 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/udtf/SplitMap.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/udtf/SplitMap.java @@ -19,49 +19,51 @@ package org.apache.geaflow.dsl.runtime.query.udtf; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.commons.lang.StringUtils; import org.apache.geaflow.dsl.common.function.Description; import org.apache.geaflow.dsl.common.function.UDTF; +import com.google.common.collect.Lists; + @Description(name = "split_map", description = "") public class SplitMap extends UDTF { - private String splitChar = ","; + private String splitChar = ","; - public void eval(String arg0) { - evalInternal(arg0); - } + public void eval(String arg0) { + evalInternal(arg0); + } - public void eval(String arg0, String arg1) { - evalInternal(arg0, arg1); - } + public void eval(String arg0, String arg1) { + evalInternal(arg0, arg1); + } - private void evalInternal(String... args) { - if (args != null && (args.length == 1 || args.length == 2)) { - if (args.length == 2 && StringUtils.isNotEmpty(args[1])) { - splitChar = args[1]; - } - String[] lines = StringUtils.split(args[0], splitChar); - collect(new Object[]{lines[0]}); - } + private void evalInternal(String... args) { + if (args != null && (args.length == 1 || args.length == 2)) { + if (args.length == 2 && StringUtils.isNotEmpty(args[1])) { + splitChar = args[1]; + } + String[] lines = StringUtils.split(args[0], splitChar); + collect(new Object[] {lines[0]}); } + } - @Override - public List> getReturnType(List> paramTypes, List udtfReturnFields) { - - List> clazzs = Lists.newArrayList(); + @Override + public List> getReturnType(List> paramTypes, List udtfReturnFields) { - if (udtfReturnFields == null) { - clazzs.add(String.class); - return clazzs; - } + List> clazzs = Lists.newArrayList(); - for (int i = 0; i < udtfReturnFields.size(); i++) { - clazzs.add(String.class); - } + if (udtfReturnFields == null) { + clazzs.add(String.class); + return clazzs; + } - return clazzs; + for (int i = 0; i < udtfReturnFields.size(); i++) { + clazzs.add(String.class); } + + return clazzs; + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/sql2graph/JoinToGraphTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/sql2graph/JoinToGraphTest.java index a0e27edeb..5a6fd90e0 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/sql2graph/JoinToGraphTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/sql2graph/JoinToGraphTest.java @@ -22,6 +22,7 @@ import java.io.File; import java.util.HashMap; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.keys.DSLConfigKeys; import org.apache.geaflow.dsl.runtime.query.QueryTester; @@ -32,474 +33,437 @@ public class JoinToGraphTest { - private final String TEST_GRAPH_PATH = "/tmp/geaflow/dsl/join2Graph/test/graph"; + private final String TEST_GRAPH_PATH = "/tmp/geaflow/dsl/join2Graph/test/graph"; - private final Map testConfig = new HashMap() { + private final Map testConfig = + new HashMap() { { - put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "DFS"); - put(FileConfigKeys.ROOT.getKey(), TEST_GRAPH_PATH); - put(FileConfigKeys.JSON_CONFIG.getKey(), "{\"fs.defaultFS\":\"local\"}"); - // If the test is conducted using the console catalog, the appended config is required. - // put(DSLConfigKeys.GEAFLOW_DSL_CATALOG_TYPE.getKey(), "console"); - // put(DSLConfigKeys.GEAFLOW_DSL_CATALOG_TOKEN_KEY.getKey(), ""); - // put(DSLConfigKeys.GEAFLOW_DSL_CATALOG_INSTANCE_NAME.getKey(), "test1"); - // put(ExecutionConfigKeys.GEAFLOW_GW_ENDPOINT.getKey(), "http://127.0.0.1:8888"); - } - }; - - @BeforeClass - public void prepare() throws Exception { - File file = new File(TEST_GRAPH_PATH); - if (file.exists()) { - FileUtils.deleteDirectory(file); - } - QueryTester - .build() - .withConfig(testConfig) - .withConfig(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), "1") - .withConfig(FileConfigKeys.PERSISTENT_TYPE.getKey(), "DFS") - .withConfig(FileConfigKeys.ROOT.getKey(), TEST_GRAPH_PATH) - .withConfig(FileConfigKeys.JSON_CONFIG.getKey(), "{\"fs.defaultFS\":\"local\"}") - .withQueryPath("/sql2graph/graph_student_v_insert.sql").execute() - .withQueryPath("/sql2graph/graph_student_e_insert.sql").execute(); - } - - @AfterClass - public void tearDown() throws Exception { - File file = new File(TEST_GRAPH_PATH); - if (file.exists()) { - FileUtils.deleteDirectory(file); + put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "DFS"); + put(FileConfigKeys.ROOT.getKey(), TEST_GRAPH_PATH); + put(FileConfigKeys.JSON_CONFIG.getKey(), "{\"fs.defaultFS\":\"local\"}"); + // If the test is conducted using the console catalog, the appended config is required. + // put(DSLConfigKeys.GEAFLOW_DSL_CATALOG_TYPE.getKey(), "console"); + // put(DSLConfigKeys.GEAFLOW_DSL_CATALOG_TOKEN_KEY.getKey(), ""); + // put(DSLConfigKeys.GEAFLOW_DSL_CATALOG_INSTANCE_NAME.getKey(), "test1"); + // put(ExecutionConfigKeys.GEAFLOW_GW_ENDPOINT.getKey(), "http://127.0.0.1:8888"); } - } - - @Test - public void testVertexJoinEdge_001() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/vertex_join_edge_001.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testVertexJoinEdge_002() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/vertex_join_edge_002.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testEdgeJoinVertex_001() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/edge_join_vertex_001.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testEdgeJoinVertex_002() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/edge_join_vertex_002.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatchJoinVertex_001() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/match_join_vertex_001.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatchJoinVertex_002() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/match_join_vertex_002.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatchJoinVertex_003() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/match_join_vertex_003.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatchJoinEdge_001() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/match_join_edge_001.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatchJoinEdge_002() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/match_join_edge_002.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatchJoinEdge_003() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/match_join_edge_003.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatchJoinEdge_004() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/match_join_edge_004.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testMatchJoinEdge_005() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/match_join_edge_005.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testJoinToMatch_001() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/join_to_match_001.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testJoinToMatch_002() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/join_to_match_002.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testJoinToMatch_003() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/join_to_match_003.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testJoinToMatch_004() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/join_to_match_004.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testJoinToMatch_005() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/join_to_match_005.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testJoinToMatch_006() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/join_to_match_006.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testJoinToMatch_007() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/join_to_match_007.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testJoinToMatch_008() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/join_to_match_008.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testJoinToMatch_009() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/join_to_match_009.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testJoinToMatch_010() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/join_to_match_010.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testJoinToMatch_011() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/join_to_match_011.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAggregateToMatch_001() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/aggregate_to_match_001.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAggregateToMatch_002() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/aggregate_to_match_002.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testAggregateToMatch_003() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/aggregate_to_match_003.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testLeftJoin_001() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/left_join_001.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testLeftJoin_002() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/left_join_002.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testLeftJoin_003() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/left_join_003.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testLeftJoin_004() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/left_join_004.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testLeftJoin_005() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/left_join_005.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testLeftJoin_006() throws Exception { - //di_join_001 - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/left_join_006.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testLeftJoin_007() throws Exception { - //di_join_0011 - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/left_join_007.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testLeftJoin_008() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/left_join_008.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testTableScan_001() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/table_scan_001.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testTableScan_002() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/table_scan_002.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testTableScan_003() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/table_scan_003.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testJoinEdgeWithFilter_001() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/match_join_edge_with_filter_001.sql") - .execute() - .checkSinkResult(); - } - - @Test - public void testJoinEdgeWithGroup_001() throws Exception { - QueryTester - .build() - .withConfig(testConfig) - .withGraphDefine("/sql2graph/graph_student.sql") - .withQueryPath("/sql2graph/match_join_edge_with_group_001.sql") - .execute() - .checkSinkResult(); - } + }; + + @BeforeClass + public void prepare() throws Exception { + File file = new File(TEST_GRAPH_PATH); + if (file.exists()) { + FileUtils.deleteDirectory(file); + } + QueryTester.build() + .withConfig(testConfig) + .withConfig(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), "1") + .withConfig(FileConfigKeys.PERSISTENT_TYPE.getKey(), "DFS") + .withConfig(FileConfigKeys.ROOT.getKey(), TEST_GRAPH_PATH) + .withConfig(FileConfigKeys.JSON_CONFIG.getKey(), "{\"fs.defaultFS\":\"local\"}") + .withQueryPath("/sql2graph/graph_student_v_insert.sql") + .execute() + .withQueryPath("/sql2graph/graph_student_e_insert.sql") + .execute(); + } + + @AfterClass + public void tearDown() throws Exception { + File file = new File(TEST_GRAPH_PATH); + if (file.exists()) { + FileUtils.deleteDirectory(file); + } + } + + @Test + public void testVertexJoinEdge_001() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/vertex_join_edge_001.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testVertexJoinEdge_002() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/vertex_join_edge_002.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testEdgeJoinVertex_001() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/edge_join_vertex_001.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testEdgeJoinVertex_002() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/edge_join_vertex_002.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testMatchJoinVertex_001() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/match_join_vertex_001.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testMatchJoinVertex_002() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/match_join_vertex_002.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testMatchJoinVertex_003() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/match_join_vertex_003.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testMatchJoinEdge_001() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/match_join_edge_001.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testMatchJoinEdge_002() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/match_join_edge_002.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testMatchJoinEdge_003() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/match_join_edge_003.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testMatchJoinEdge_004() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/match_join_edge_004.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testMatchJoinEdge_005() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/match_join_edge_005.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testJoinToMatch_001() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/join_to_match_001.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testJoinToMatch_002() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/join_to_match_002.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testJoinToMatch_003() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/join_to_match_003.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testJoinToMatch_004() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/join_to_match_004.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testJoinToMatch_005() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/join_to_match_005.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testJoinToMatch_006() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/join_to_match_006.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testJoinToMatch_007() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/join_to_match_007.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testJoinToMatch_008() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/join_to_match_008.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testJoinToMatch_009() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/join_to_match_009.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testJoinToMatch_010() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/join_to_match_010.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testJoinToMatch_011() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/join_to_match_011.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testAggregateToMatch_001() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/aggregate_to_match_001.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testAggregateToMatch_002() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/aggregate_to_match_002.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testAggregateToMatch_003() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/aggregate_to_match_003.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testLeftJoin_001() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/left_join_001.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testLeftJoin_002() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/left_join_002.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testLeftJoin_003() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/left_join_003.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testLeftJoin_004() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/left_join_004.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testLeftJoin_005() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/left_join_005.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testLeftJoin_006() throws Exception { + // di_join_001 + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/left_join_006.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testLeftJoin_007() throws Exception { + // di_join_0011 + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/left_join_007.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testLeftJoin_008() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/left_join_008.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testTableScan_001() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/table_scan_001.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testTableScan_002() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/table_scan_002.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testTableScan_003() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/table_scan_003.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testJoinEdgeWithFilter_001() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/match_join_edge_with_filter_001.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testJoinEdgeWithGroup_001() throws Exception { + QueryTester.build() + .withConfig(testConfig) + .withGraphDefine("/sql2graph/graph_student.sql") + .withQueryPath("/sql2graph/match_join_edge_with_group_001.sql") + .execute() + .checkSinkResult(); + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/testenv/FoGeaFlowTableSinkFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/testenv/FoGeaFlowTableSinkFunction.java index ceef92c0e..bf4463e67 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/testenv/FoGeaFlowTableSinkFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/testenv/FoGeaFlowTableSinkFunction.java @@ -30,25 +30,24 @@ public class FoGeaFlowTableSinkFunction extends GeaFlowTableSinkFunction { - private static final Logger LOGGER = LoggerFactory.getLogger( - FoGeaFlowTableSinkFunction.class); + private static final Logger LOGGER = LoggerFactory.getLogger(FoGeaFlowTableSinkFunction.class); - public FoGeaFlowTableSinkFunction(GeaFlowTable table, TableSink tableSink) { - super(table, tableSink); - } + public FoGeaFlowTableSinkFunction(GeaFlowTable table, TableSink tableSink) { + super(table, tableSink); + } - @Override - public void open(RuntimeContext runtimeContext) { - LOGGER.info("open fo sink function."); - super.open(runtimeContext); - } + @Override + public void open(RuntimeContext runtimeContext) { + LOGGER.info("open fo sink function."); + super.open(runtimeContext); + } - @Override - public void write(Row row) throws Exception { - super.write(row); - if (KafkaFoTest.injectExceptionTimes > 0) { - --KafkaFoTest.injectExceptionTimes; - throw new RuntimeException("break for fo test"); - } + @Override + public void write(Row row) throws Exception { + super.write(row); + if (KafkaFoTest.injectExceptionTimes > 0) { + --KafkaFoTest.injectExceptionTimes; + throw new RuntimeException("break for fo test"); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/testenv/KafkaTestEnv.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/testenv/KafkaTestEnv.java index c0c3b63ea..c584a2b32 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/testenv/KafkaTestEnv.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/testenv/KafkaTestEnv.java @@ -19,7 +19,6 @@ package org.apache.geaflow.dsl.runtime.testenv; -import com.google.common.util.concurrent.ThreadFactoryBuilder; import java.io.File; import java.io.IOException; import java.security.Permission; @@ -31,8 +30,7 @@ import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; -import kafka.server.KafkaConfig; -import kafka.server.KafkaServer; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.utils.SleepUtils; import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; @@ -45,168 +43,177 @@ import org.apache.zookeeper.server.ZooKeeperServerMain; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + +import com.google.common.util.concurrent.ThreadFactoryBuilder; + +import kafka.server.KafkaConfig; +import kafka.server.KafkaServer; import scala.Option; import scala.collection.mutable.ArraySeq; public class KafkaTestEnv { - private static final Logger LOGGER = LoggerFactory.getLogger(KafkaTestEnv.class); + private static final Logger LOGGER = LoggerFactory.getLogger(KafkaTestEnv.class); - private static final KafkaTestEnv INSTANCE = new KafkaTestEnv(); + private static final KafkaTestEnv INSTANCE = new KafkaTestEnv(); - private static SecurityManager securityManager; + private static SecurityManager securityManager; - private KafkaTestEnv() { - } + private KafkaTestEnv() {} - public static KafkaTestEnv get() { - return INSTANCE; - } + public static KafkaTestEnv get() { + return INSTANCE; + } - private KafkaServer server; - - private ExecutorService zkServer; - - private static final String KAFKA_LOGS_PATH = "/tmp/kafka-logs/dsl-kafka-connector-test"; - - private static final String ZOOKEEPER_LOGS_PATH = "/tmp/zookeeper-kafka-test"; - - public void startZkServer() throws IOException { - cleanZk(); - if (zkServer == null) { - ThreadFactory namedThreadFactory = new ThreadFactoryBuilder().setNameFormat( - "zookeeperServerThread" + "-%d").build(); - ExecutorService singleThreadPool = new ThreadPoolExecutor(5, 5, 0L, - TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>(5), namedThreadFactory); - singleThreadPool.execute(() -> { - String clientPort = "2181"; - String tickTime = "16000"; - ServerConfig config = new ServerConfig(); - config.parse(new String[]{clientPort, ZOOKEEPER_LOGS_PATH, tickTime}); - ZooKeeperServerMain zk = new ZooKeeperServerMain(); - try { - zk.runFromConfig(config); - } catch (Exception e) { - throw new GeaFlowDSLException("Error in run zookeeper.", e); - } - }); - zkServer = singleThreadPool; - } - SleepUtils.sleepSecond(10); - } + private KafkaServer server; - public void startKafkaServer() throws IOException { - cleanKafka(); - if (server != null) { - LOGGER.info("server has been established."); - return; - } - startZkServer(); - Properties props = new Properties(); - props.put("broker.id", "0"); - props.put("log.dirs", KAFKA_LOGS_PATH); - props.put("num.partitions", "1"); - props.put("zookeeper.connect", "localhost:2181"); - props.put("offsets.topic.replication.factor", "1"); - props.put("zookeeper.connection.timeout.ms", "10000"); - props.put("retries", "5"); - server = new KafkaServer(new KafkaConfig(props), Time.SYSTEM, Option.empty(), new ArraySeq<>(0)); - LOGGER.info("valid kafka server config."); - server.startup(); - SleepUtils.sleepMilliSecond(1000); - LOGGER.info("server startup."); - } + private ExecutorService zkServer; - public void createTopic(String topic) { - Properties props = new Properties(); - props.setProperty(KafkaConstants.KAFKA_BOOTSTRAP_SERVERS, "localhost:9092"); - props.setProperty(KafkaConstants.KAFKA_KEY_DESERIALIZER, - "org.apache.kafka.common.serialization.StringDeserializer"); - props.setProperty(KafkaConstants.KAFKA_VALUE_DESERIALIZER, - "org.apache.kafka.common.serialization.StringDeserializer"); - props.setProperty(KafkaConstants.KAFKA_MAX_POLL_RECORDS, - String.valueOf(500)); - props.setProperty(KafkaConstants.KAFKA_GROUP_ID, "geaflow-dsl-kafka-source-default-group-id"); - KafkaConsumer consumer = new KafkaConsumer<>(props); - Map> topic2PartitionInfo = consumer.listTopics(); - if (!topic2PartitionInfo.containsKey(topic)) { - Properties producerProps = new Properties(); - producerProps.setProperty(KafkaConstants.KAFKA_BOOTSTRAP_SERVERS, "localhost:9092"); - producerProps.setProperty(KafkaConstants.KAFKA_KEY_SERIALIZER, - "org.apache.kafka.common.serialization.StringSerializer"); - producerProps.setProperty(KafkaConstants.KAFKA_VALUE_SERIALIZER, - "org.apache.kafka.common.serialization.StringSerializer"); - - KafkaProducer producer = new KafkaProducer<>(producerProps); - producer.partitionsFor(topic); - producer.close(); - } - consumer.close(); - } + private static final String KAFKA_LOGS_PATH = "/tmp/kafka-logs/dsl-kafka-connector-test"; - public void shutdownKafkaServer() throws IOException { - shutdownKafka(); - cleanKafka(); - cleanZookeeper(); - } + private static final String ZOOKEEPER_LOGS_PATH = "/tmp/zookeeper-kafka-test"; - private void shutdownKafka() throws IOException { - if (server == null) { - LOGGER.warn("null server cannot be shutdownKafka."); - } else { - server.shutdown(); - server = null; - cleanKafka(); - LOGGER.info("server shutdownKafka."); - } + public void startZkServer() throws IOException { + cleanZk(); + if (zkServer == null) { + ThreadFactory namedThreadFactory = + new ThreadFactoryBuilder().setNameFormat("zookeeperServerThread" + "-%d").build(); + ExecutorService singleThreadPool = + new ThreadPoolExecutor( + 5, 5, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>(5), namedThreadFactory); + singleThreadPool.execute( + () -> { + String clientPort = "2181"; + String tickTime = "16000"; + ServerConfig config = new ServerConfig(); + config.parse(new String[] {clientPort, ZOOKEEPER_LOGS_PATH, tickTime}); + ZooKeeperServerMain zk = new ZooKeeperServerMain(); + try { + zk.runFromConfig(config); + } catch (Exception e) { + throw new GeaFlowDSLException("Error in run zookeeper.", e); + } + }); + zkServer = singleThreadPool; } - - private void cleanZookeeper() throws IOException { - if (zkServer != null) { - zkServer.shutdownNow(); - zkServer = null; - cleanZk(); - LOGGER.info("zk closed."); - } + SleepUtils.sleepSecond(10); + } + + public void startKafkaServer() throws IOException { + cleanKafka(); + if (server != null) { + LOGGER.info("server has been established."); + return; } - - private void cleanZk() throws IOException { - File file = new File(ZOOKEEPER_LOGS_PATH); - if (file.exists()) { - FileUtils.deleteDirectory(file); - } + startZkServer(); + Properties props = new Properties(); + props.put("broker.id", "0"); + props.put("log.dirs", KAFKA_LOGS_PATH); + props.put("num.partitions", "1"); + props.put("zookeeper.connect", "localhost:2181"); + props.put("offsets.topic.replication.factor", "1"); + props.put("zookeeper.connection.timeout.ms", "10000"); + props.put("retries", "5"); + server = + new KafkaServer( + new KafkaConfig(props), Time.SYSTEM, Option.empty(), new ArraySeq<>(0)); + LOGGER.info("valid kafka server config."); + server.startup(); + SleepUtils.sleepMilliSecond(1000); + LOGGER.info("server startup."); + } + + public void createTopic(String topic) { + Properties props = new Properties(); + props.setProperty(KafkaConstants.KAFKA_BOOTSTRAP_SERVERS, "localhost:9092"); + props.setProperty( + KafkaConstants.KAFKA_KEY_DESERIALIZER, + "org.apache.kafka.common.serialization.StringDeserializer"); + props.setProperty( + KafkaConstants.KAFKA_VALUE_DESERIALIZER, + "org.apache.kafka.common.serialization.StringDeserializer"); + props.setProperty(KafkaConstants.KAFKA_MAX_POLL_RECORDS, String.valueOf(500)); + props.setProperty(KafkaConstants.KAFKA_GROUP_ID, "geaflow-dsl-kafka-source-default-group-id"); + KafkaConsumer consumer = new KafkaConsumer<>(props); + Map> topic2PartitionInfo = consumer.listTopics(); + if (!topic2PartitionInfo.containsKey(topic)) { + Properties producerProps = new Properties(); + producerProps.setProperty(KafkaConstants.KAFKA_BOOTSTRAP_SERVERS, "localhost:9092"); + producerProps.setProperty( + KafkaConstants.KAFKA_KEY_SERIALIZER, + "org.apache.kafka.common.serialization.StringSerializer"); + producerProps.setProperty( + KafkaConstants.KAFKA_VALUE_SERIALIZER, + "org.apache.kafka.common.serialization.StringSerializer"); + + KafkaProducer producer = new KafkaProducer<>(producerProps); + producer.partitionsFor(topic); + producer.close(); } - - private void cleanKafka() throws IOException { - File file = new File(KAFKA_LOGS_PATH); - if (file.exists()) { - FileUtils.deleteDirectory(file); - } + consumer.close(); + } + + public void shutdownKafkaServer() throws IOException { + shutdownKafka(); + cleanKafka(); + cleanZookeeper(); + } + + private void shutdownKafka() throws IOException { + if (server == null) { + LOGGER.warn("null server cannot be shutdownKafka."); + } else { + server.shutdown(); + server = null; + cleanKafka(); + LOGGER.info("server shutdownKafka."); } + } + + private void cleanZookeeper() throws IOException { + if (zkServer != null) { + zkServer.shutdownNow(); + zkServer = null; + cleanZk(); + LOGGER.info("zk closed."); + } + } - public static void before() { - securityManager = System.getSecurityManager(); - System.setSecurityManager(new SystemExitIgnoreSecurityManager()); + private void cleanZk() throws IOException { + File file = new File(ZOOKEEPER_LOGS_PATH); + if (file.exists()) { + FileUtils.deleteDirectory(file); } + } - public static void after() { - System.setSecurityManager(securityManager); + private void cleanKafka() throws IOException { + File file = new File(KAFKA_LOGS_PATH); + if (file.exists()) { + FileUtils.deleteDirectory(file); } + } + + public static void before() { + securityManager = System.getSecurityManager(); + System.setSecurityManager(new SystemExitIgnoreSecurityManager()); + } + + public static void after() { + System.setSecurityManager(securityManager); + } + + private static class SystemExitIgnoreSecurityManager extends SecurityManager { + @Override + public void checkPermission(Permission perm) {} + + @Override + public void checkPermission(Permission perm, Object context) {} - private static class SystemExitIgnoreSecurityManager extends SecurityManager { - @Override - public void checkPermission(Permission perm) { - } - - @Override - public void checkPermission(Permission perm, Object context) { - } - - @Override - public void checkExit(int status) { - super.checkExit(status); - LOGGER.info("check exit {}", status); - throw new RuntimeException("throw exception instead of exit process"); - } + @Override + public void checkExit(int status) { + super.checkExit(status); + LOGGER.info("check exit {}", status); + throw new RuntimeException("throw exception instead of exit process"); } + } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/testenv/SourceFunctionNoPartitionCheck.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/testenv/SourceFunctionNoPartitionCheck.java index 2011cc243..45a6ed191 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/testenv/SourceFunctionNoPartitionCheck.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/testenv/SourceFunctionNoPartitionCheck.java @@ -20,18 +20,19 @@ package org.apache.geaflow.dsl.runtime.testenv; import java.util.concurrent.ExecutorService; + import org.apache.geaflow.dsl.connector.api.TableSource; import org.apache.geaflow.dsl.connector.api.function.GeaFlowTableSourceFunction; import org.apache.geaflow.dsl.schema.GeaFlowTable; public class SourceFunctionNoPartitionCheck extends GeaFlowTableSourceFunction { - public SourceFunctionNoPartitionCheck(GeaFlowTable table, TableSource tableSource) { - super(table, tableSource); - } + public SourceFunctionNoPartitionCheck(GeaFlowTable table, TableSource tableSource) { + super(table, tableSource); + } - @Override - protected ExecutorService startPartitionCompareThread() { - return null; - } + @Override + protected ExecutorService startPartitionCompareThread() { + return null; + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/config/ExampleConfigKeys.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/config/ExampleConfigKeys.java index 82cadf03a..67b922718 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/config/ExampleConfigKeys.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/config/ExampleConfigKeys.java @@ -20,45 +20,40 @@ package org.apache.geaflow.example.config; import java.io.Serializable; + import org.apache.geaflow.common.config.ConfigKey; import org.apache.geaflow.common.config.ConfigKeys; public class ExampleConfigKeys implements Serializable { - public static final ConfigKey SOURCE_PARALLELISM = ConfigKeys - .key("geaflow.source.parallelism") - .defaultValue(1) - .description("job source parallelism"); - - public static final ConfigKey SINK_PARALLELISM = ConfigKeys - .key("geaflow.sink.parallelism") - .defaultValue(1) - .description("job sink parallelism"); + public static final ConfigKey SOURCE_PARALLELISM = + ConfigKeys.key("geaflow.source.parallelism") + .defaultValue(1) + .description("job source parallelism"); - public static final ConfigKey MAP_PARALLELISM = ConfigKeys - .key("geaflow.map.parallelism") - .defaultValue(1) - .description("job map parallelism"); + public static final ConfigKey SINK_PARALLELISM = + ConfigKeys.key("geaflow.sink.parallelism") + .defaultValue(1) + .description("job sink parallelism"); - public static final ConfigKey REDUCE_PARALLELISM = ConfigKeys - .key("geaflow.reduce.parallelism") - .defaultValue(1) - .description("job reduce parallelism"); + public static final ConfigKey MAP_PARALLELISM = + ConfigKeys.key("geaflow.map.parallelism").defaultValue(1).description("job map parallelism"); - public static final ConfigKey ITERATOR_PARALLELISM = ConfigKeys - .key("geaflow.iterator.parallelism") - .defaultValue(1) - .description("job iterator parallelism"); + public static final ConfigKey REDUCE_PARALLELISM = + ConfigKeys.key("geaflow.reduce.parallelism") + .defaultValue(1) + .description("job reduce parallelism"); - public static final ConfigKey AGG_PARALLELISM = ConfigKeys - .key("geaflow.agg.parallelism") - .defaultValue(1) - .description("job agg parallelism"); + public static final ConfigKey ITERATOR_PARALLELISM = + ConfigKeys.key("geaflow.iterator.parallelism") + .defaultValue(1) + .description("job iterator parallelism"); - public static final ConfigKey GEAFLOW_SINK_TYPE = ConfigKeys - .key("geaflow.sink.type") - .defaultValue("console") - .description("job sink type, console or file"); + public static final ConfigKey AGG_PARALLELISM = + ConfigKeys.key("geaflow.agg.parallelism").defaultValue(1).description("job agg parallelism"); + public static final ConfigKey GEAFLOW_SINK_TYPE = + ConfigKeys.key("geaflow.sink.type") + .defaultValue("console") + .description("job sink type, console or file"); } - diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/data/GraphDataSet.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/data/GraphDataSet.java index a2c60a4d4..ae9b6e46c 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/data/GraphDataSet.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/data/GraphDataSet.java @@ -22,6 +22,7 @@ import java.io.Serializable; import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.edge.impl.ValueEdge; import org.apache.geaflow.model.graph.vertex.IVertex; @@ -29,78 +30,77 @@ public class GraphDataSet implements Serializable { - public static final String DATASET_FILE = "data/input/web-google-mini"; - - public static List> getIncVertices() { - List> vertices = new ArrayList<>(); - vertices.add(new ValueVertex<>(1, 0)); - vertices.add(new ValueVertex<>(2, 0)); - vertices.add(new ValueVertex<>(3, 0)); - vertices.add(new ValueVertex<>(4, 0)); - vertices.add(new ValueVertex<>(5, 0)); - return vertices; - } + public static final String DATASET_FILE = "data/input/web-google-mini"; - public static List> getIncEdges() { - List> edges = new ArrayList<>(); - edges.add(new ValueEdge<>(1, 2, 12)); - edges.add(new ValueEdge<>(1, 3, 13)); - edges.add(new ValueEdge<>(1, 4, 13)); - edges.add(new ValueEdge<>(2, 4, 23)); - edges.add(new ValueEdge<>(2, 1, 23)); - edges.add(new ValueEdge<>(3, 5, 34)); - edges.add(new ValueEdge<>(4, 5, 35)); - edges.add(new ValueEdge<>(4, 2, 73)); - edges.add(new ValueEdge<>(2, 5, 45)); - edges.add(new ValueEdge<>(5, 1, 51)); - return edges; - } + public static List> getIncVertices() { + List> vertices = new ArrayList<>(); + vertices.add(new ValueVertex<>(1, 0)); + vertices.add(new ValueVertex<>(2, 0)); + vertices.add(new ValueVertex<>(3, 0)); + vertices.add(new ValueVertex<>(4, 0)); + vertices.add(new ValueVertex<>(5, 0)); + return vertices; + } - public static List> getPRVertices() { - List> vertices = new ArrayList<>(); - vertices.add(new ValueVertex<>(1, 0.2)); - vertices.add(new ValueVertex<>(2, 0.2)); - vertices.add(new ValueVertex<>(3, 0.2)); - vertices.add(new ValueVertex<>(4, 0.2)); - vertices.add(new ValueVertex<>(5, 0.2)); - return vertices; - } + public static List> getIncEdges() { + List> edges = new ArrayList<>(); + edges.add(new ValueEdge<>(1, 2, 12)); + edges.add(new ValueEdge<>(1, 3, 13)); + edges.add(new ValueEdge<>(1, 4, 13)); + edges.add(new ValueEdge<>(2, 4, 23)); + edges.add(new ValueEdge<>(2, 1, 23)); + edges.add(new ValueEdge<>(3, 5, 34)); + edges.add(new ValueEdge<>(4, 5, 35)); + edges.add(new ValueEdge<>(4, 2, 73)); + edges.add(new ValueEdge<>(2, 5, 45)); + edges.add(new ValueEdge<>(5, 1, 51)); + return edges; + } - public static List> getPREdges() { - List> edges = new ArrayList<>(); - edges.add(new ValueEdge<>(1, 2, 12)); - edges.add(new ValueEdge<>(1, 3, 13)); - edges.add(new ValueEdge<>(1, 4, 13)); - edges.add(new ValueEdge<>(2, 4, 23)); - edges.add(new ValueEdge<>(3, 5, 34)); - edges.add(new ValueEdge<>(4, 5, 35)); - edges.add(new ValueEdge<>(2, 5, 45)); - edges.add(new ValueEdge<>(5, 1, 51)); - return edges; - } + public static List> getPRVertices() { + List> vertices = new ArrayList<>(); + vertices.add(new ValueVertex<>(1, 0.2)); + vertices.add(new ValueVertex<>(2, 0.2)); + vertices.add(new ValueVertex<>(3, 0.2)); + vertices.add(new ValueVertex<>(4, 0.2)); + vertices.add(new ValueVertex<>(5, 0.2)); + return vertices; + } - public static List> getSSSPVertices() { - List> vertices = new ArrayList<>(); - vertices.add(new ValueVertex<>(1, Integer.MAX_VALUE)); - vertices.add(new ValueVertex<>(2, Integer.MAX_VALUE)); - vertices.add(new ValueVertex<>(3, Integer.MAX_VALUE)); - vertices.add(new ValueVertex<>(4, Integer.MAX_VALUE)); - vertices.add(new ValueVertex<>(5, Integer.MAX_VALUE)); - vertices.add(new ValueVertex<>(6, Integer.MAX_VALUE)); - return vertices; - } + public static List> getPREdges() { + List> edges = new ArrayList<>(); + edges.add(new ValueEdge<>(1, 2, 12)); + edges.add(new ValueEdge<>(1, 3, 13)); + edges.add(new ValueEdge<>(1, 4, 13)); + edges.add(new ValueEdge<>(2, 4, 23)); + edges.add(new ValueEdge<>(3, 5, 34)); + edges.add(new ValueEdge<>(4, 5, 35)); + edges.add(new ValueEdge<>(2, 5, 45)); + edges.add(new ValueEdge<>(5, 1, 51)); + return edges; + } - public static List> getSSSPEdges() { - List> edges = new ArrayList<>(); - edges.add(new ValueEdge<>(1, 3, 10)); - edges.add(new ValueEdge<>(1, 5, 30)); - edges.add(new ValueEdge<>(1, 6, 100)); - edges.add(new ValueEdge<>(2, 3, 5)); - edges.add(new ValueEdge<>(3, 4, 50)); - edges.add(new ValueEdge<>(4, 6, 10)); - edges.add(new ValueEdge<>(5, 4, 20)); - edges.add(new ValueEdge<>(5, 6, 60)); - return edges; - } + public static List> getSSSPVertices() { + List> vertices = new ArrayList<>(); + vertices.add(new ValueVertex<>(1, Integer.MAX_VALUE)); + vertices.add(new ValueVertex<>(2, Integer.MAX_VALUE)); + vertices.add(new ValueVertex<>(3, Integer.MAX_VALUE)); + vertices.add(new ValueVertex<>(4, Integer.MAX_VALUE)); + vertices.add(new ValueVertex<>(5, Integer.MAX_VALUE)); + vertices.add(new ValueVertex<>(6, Integer.MAX_VALUE)); + return vertices; + } + public static List> getSSSPEdges() { + List> edges = new ArrayList<>(); + edges.add(new ValueEdge<>(1, 3, 10)); + edges.add(new ValueEdge<>(1, 5, 30)); + edges.add(new ValueEdge<>(1, 6, 100)); + edges.add(new ValueEdge<>(2, 3, 5)); + edges.add(new ValueEdge<>(3, 4, 50)); + edges.add(new ValueEdge<>(4, 6, 10)); + edges.add(new ValueEdge<>(5, 4, 20)); + edges.add(new ValueEdge<>(5, 6, 60)); + return edges; + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/function/AbstractVcFunc.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/function/AbstractVcFunc.java index 7fa704537..b69e4819f 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/function/AbstractVcFunc.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/function/AbstractVcFunc.java @@ -21,17 +21,16 @@ import org.apache.geaflow.api.graph.function.vc.VertexCentricComputeFunction; -public abstract class AbstractVcFunc implements VertexCentricComputeFunction { +public abstract class AbstractVcFunc + implements VertexCentricComputeFunction { - protected VertexCentricComputeFuncContext context; + protected VertexCentricComputeFuncContext context; - @Override - public void init(VertexCentricComputeFuncContext vertexCentricFuncContext) { - this.context = vertexCentricFuncContext; - } - - @Override - public void finish() { - } + @Override + public void init(VertexCentricComputeFuncContext vertexCentricFuncContext) { + this.context = vertexCentricFuncContext; + } + @Override + public void finish() {} } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/function/ConsoleSink.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/function/ConsoleSink.java index 970dddf57..8c6a892dd 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/function/ConsoleSink.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/function/ConsoleSink.java @@ -27,22 +27,20 @@ public class ConsoleSink extends RichFunction implements SinkFunction { - private static final Logger LOGGER = LoggerFactory.getLogger(ConsoleSink.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ConsoleSink.class); - private int taskIndex; + private int taskIndex; - @Override - public void open(RuntimeContext runtimeContext) { - this.taskIndex = runtimeContext.getTaskArgs().getTaskIndex(); - } + @Override + public void open(RuntimeContext runtimeContext) { + this.taskIndex = runtimeContext.getTaskArgs().getTaskIndex(); + } - @Override - public void write(OUT out) throws Exception { - LOGGER.info("sink {} got result {}", this.taskIndex, out); - } - - @Override - public void close() { - } + @Override + public void write(OUT out) throws Exception { + LOGGER.info("sink {} got result {}", this.taskIndex, out); + } + @Override + public void close() {} } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/function/FileSink.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/function/FileSink.java index 54bfb9154..0edc87479 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/function/FileSink.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/function/FileSink.java @@ -22,6 +22,7 @@ import java.io.File; import java.io.IOException; import java.nio.charset.Charset; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.RichFunction; @@ -31,57 +32,59 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; - public class FileSink extends RichFunction implements SinkFunction { - private static final Logger LOGGER = LoggerFactory.getLogger(FileSink.class); + private static final Logger LOGGER = LoggerFactory.getLogger(FileSink.class); - public static final String OUTPUT_DIR = "output.dir"; - public static final String FILE_OUTPUT_APPEND_ENABLE = "file.append.enable"; + public static final String OUTPUT_DIR = "output.dir"; + public static final String FILE_OUTPUT_APPEND_ENABLE = "file.append.enable"; - private File file; + private File file; - public FileSink() { - } + public FileSink() {} - @Override - public void open(RuntimeContext runtimeContext) { - String filePath = String.format("%s/result_%s", - runtimeContext.getConfiguration().getString(OUTPUT_DIR), runtimeContext.getTaskArgs().getTaskIndex()); - LOGGER.info("sink file name {}", filePath); - boolean append = runtimeContext.getConfiguration().getBoolean(new ConfigKey(FILE_OUTPUT_APPEND_ENABLE, true)); - file = new File(filePath); + @Override + public void open(RuntimeContext runtimeContext) { + String filePath = + String.format( + "%s/result_%s", + runtimeContext.getConfiguration().getString(OUTPUT_DIR), + runtimeContext.getTaskArgs().getTaskIndex()); + LOGGER.info("sink file name {}", filePath); + boolean append = + runtimeContext + .getConfiguration() + .getBoolean(new ConfigKey(FILE_OUTPUT_APPEND_ENABLE, true)); + file = new File(filePath); + try { + if (!append && file.exists()) { try { - if (!append && file.exists()) { - try { - FileUtils.forceDelete(file); - } catch (Exception e) { - // ignore - } - } + FileUtils.forceDelete(file); + } catch (Exception e) { + // ignore + } + } - if (!file.exists()) { - if (!file.getParentFile().exists()) { - file.getParentFile().mkdirs(); - } - file.createNewFile(); - } - } catch (IOException e) { - throw new GeaflowRuntimeException(e); + if (!file.exists()) { + if (!file.getParentFile().exists()) { + file.getParentFile().mkdirs(); } + file.createNewFile(); + } + } catch (IOException e) { + throw new GeaflowRuntimeException(e); } + } - @Override - public void close() { + @Override + public void close() {} + @Override + public void write(OUT out) throws Exception { + try { + FileUtils.write(file, out + "\n", Charset.defaultCharset(), true); + } catch (IOException e) { + throw new RuntimeException(e); } - - @Override - public void write(OUT out) throws Exception { - try { - FileUtils.write(file, out + "\n", Charset.defaultCharset(), true); - } catch (IOException e) { - throw new RuntimeException(e); - } - } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/function/FileSource.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/function/FileSource.java index 5dc52ef12..572d8b398 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/function/FileSource.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/function/FileSource.java @@ -19,13 +19,13 @@ package org.apache.geaflow.example.function; -import com.google.common.io.Resources; import java.io.IOException; import java.io.Serializable; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Collection; import java.util.List; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.function.RichFunction; @@ -34,90 +34,98 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.io.Resources; + public class FileSource extends RichFunction implements SourceFunction { - private static final Logger LOGGER = LoggerFactory.getLogger(FileSource.class); + private static final Logger LOGGER = LoggerFactory.getLogger(FileSource.class); - protected String filePath; - protected List records; - protected Integer readPos = null; - protected FileLineParser parser; - protected transient RuntimeContext runtimeContext; + protected String filePath; + protected List records; + protected Integer readPos = null; + protected FileLineParser parser; + protected transient RuntimeContext runtimeContext; - public FileSource(String filePath, FileLineParser parser) { - this.filePath = filePath; - this.parser = parser; - } + public FileSource(String filePath, FileLineParser parser) { + this.filePath = filePath; + this.parser = parser; + } - @Override - public void open(RuntimeContext runtimeContext) { - this.runtimeContext = runtimeContext; - } + @Override + public void open(RuntimeContext runtimeContext) { + this.runtimeContext = runtimeContext; + } - @Override - public void init(int parallel, int index) { - this.records = readFileLines(filePath); - if (parallel != 1) { - List allRecords = records; - records = new ArrayList<>(); - for (int i = 0; i < allRecords.size(); i++) { - if (i % parallel == index) { - records.add(allRecords.get(i)); - } - } + @Override + public void init(int parallel, int index) { + this.records = readFileLines(filePath); + if (parallel != 1) { + List allRecords = records; + records = new ArrayList<>(); + for (int i = 0; i < allRecords.size(); i++) { + if (i % parallel == index) { + records.add(allRecords.get(i)); } + } } + } - @Override - public boolean fetch(IWindow window, SourceContext ctx) throws Exception { - LOGGER.info("collection source fetch taskId:{}, batchId:{}, start readPos {}, totalSize {}", - runtimeContext.getTaskArgs().getTaskId(), window.windowId(), readPos, records.size()); - if (readPos == null) { - readPos = 0; - } - while (readPos < records.size()) { - OUT out = records.get(readPos); - long windowId = window.assignWindow(out); - if (window.windowId() == windowId) { - ctx.collect(out); - readPos++; - } else { - break; - } - } - boolean result = false; - if (readPos < records.size()) { - result = true; - } - LOGGER.info("collection source fetch batchId:{}, current readPos {}, result {}", - window.windowId(), readPos, result); - return result; + @Override + public boolean fetch(IWindow window, SourceContext ctx) throws Exception { + LOGGER.info( + "collection source fetch taskId:{}, batchId:{}, start readPos {}, totalSize {}", + runtimeContext.getTaskArgs().getTaskId(), + window.windowId(), + readPos, + records.size()); + if (readPos == null) { + readPos = 0; } - - @Override - public void close() { + while (readPos < records.size()) { + OUT out = records.get(readPos); + long windowId = window.assignWindow(out); + if (window.windowId() == windowId) { + ctx.collect(out); + readPos++; + } else { + break; + } } + boolean result = false; + if (readPos < records.size()) { + result = true; + } + LOGGER.info( + "collection source fetch batchId:{}, current readPos {}, result {}", + window.windowId(), + readPos, + result); + return result; + } + + @Override + public void close() {} - private List readFileLines(String filePath) { - try { - List lines = Resources.readLines(Resources.getResource(filePath), - Charset.defaultCharset()); - List result = new ArrayList<>(); + private List readFileLines(String filePath) { + try { + List lines = + Resources.readLines(Resources.getResource(filePath), Charset.defaultCharset()); + List result = new ArrayList<>(); - for (String line : lines) { - if (StringUtils.isBlank(line)) { - continue; - } - Collection collection = parser.parse(line); - result.addAll(collection); - } - return result; - } catch (IOException e) { - throw new RuntimeException("error in read resource file: " + filePath, e); + for (String line : lines) { + if (StringUtils.isBlank(line)) { + continue; } + Collection collection = parser.parse(line); + result.addAll(collection); + } + return result; + } catch (IOException e) { + throw new RuntimeException("error in read resource file: " + filePath, e); } + } - public interface FileLineParser extends Serializable { - Collection parse(String line); - } + public interface FileLineParser extends Serializable { + Collection parse(String line); + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/function/RecoverableFileSource.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/function/RecoverableFileSource.java index 38d427c9c..28d5568e9 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/function/RecoverableFileSource.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/function/RecoverableFileSource.java @@ -27,6 +27,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.regex.Matcher; import java.util.regex.Pattern; + import org.apache.geaflow.api.context.RuntimeContext; import org.apache.geaflow.api.window.IWindow; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -47,169 +48,181 @@ public class RecoverableFileSource extends FileSource { - private static final Logger LOGGER = LoggerFactory.getLogger(RecoverableFileSource.class); - private static final Pattern P = Pattern.compile("(PipelineTask#\\d+).*(.*cycle#\\d+)-\\d+(.*)"); - - private KeyValueState offsetState; - private long checkpointDuration; - - public RecoverableFileSource(String filePath, FileLineParser parser) { - super(filePath, parser); - } - - @Override - public void open(RuntimeContext runtimeContext) { - super.open(runtimeContext); - this.offsetState = OffsetKVStateKeeper.build(runtimeContext); - this.checkpointDuration = this.runtimeContext.getConfiguration().getLong(BATCH_NUMBER_PER_CHECKPOINT); - LOGGER.info("open file source, taskIndex {}, taskName {}, checkpointDuration {}", - runtimeContext.getTaskArgs().getTaskId(), runtimeContext.getTaskArgs().getTaskName(), checkpointDuration); - } - - @Override - public boolean fetch(IWindow window, SourceContext ctx) throws Exception { - if (window.windowId() == 1) { - readPos = 0; - LOGGER.info("init readPos to {} for windowId {}", 0, window.windowId()); + private static final Logger LOGGER = LoggerFactory.getLogger(RecoverableFileSource.class); + private static final Pattern P = Pattern.compile("(PipelineTask#\\d+).*(.*cycle#\\d+)-\\d+(.*)"); + + private KeyValueState offsetState; + private long checkpointDuration; + + public RecoverableFileSource(String filePath, FileLineParser parser) { + super(filePath, parser); + } + + @Override + public void open(RuntimeContext runtimeContext) { + super.open(runtimeContext); + this.offsetState = OffsetKVStateKeeper.build(runtimeContext); + this.checkpointDuration = + this.runtimeContext.getConfiguration().getLong(BATCH_NUMBER_PER_CHECKPOINT); + LOGGER.info( + "open file source, taskIndex {}, taskName {}, checkpointDuration {}", + runtimeContext.getTaskArgs().getTaskId(), + runtimeContext.getTaskArgs().getTaskName(), + checkpointDuration); + } + + @Override + public boolean fetch(IWindow window, SourceContext ctx) throws Exception { + if (window.windowId() == 1) { + readPos = 0; + LOGGER.info("init readPos to {} for windowId {}", 0, window.windowId()); + } else { + if (readPos == null) { + long lastWindowId = window.windowId() - 1; + LOGGER.info("need recover readPos for last windowId {}", lastWindowId); + offsetState.manage().operate().setCheckpointId(lastWindowId); + offsetState.manage().operate().recover(); + Integer pos = offsetState.get(runtimeContext.getTaskArgs().getTaskIndex()); + if (pos != null) { + LOGGER.info("windowId{} recover readPos {}", window.windowId(), pos); + readPos = pos; } else { - if (readPos == null) { - long lastWindowId = window.windowId() - 1; - LOGGER.info("need recover readPos for last windowId {}", lastWindowId); - offsetState.manage().operate().setCheckpointId(lastWindowId); - offsetState.manage().operate().recover(); - Integer pos = offsetState.get(runtimeContext.getTaskArgs().getTaskIndex()); - if (pos != null) { - LOGGER.info("windowId{} recover readPos {}", window.windowId(), pos); - readPos = pos; - } else { - LOGGER.info("not found readPos set to {}", 0); - readPos = 0; - } - } else { - LOGGER.info("current windowId {} readPos {}", window.windowId(), readPos); - } - } - boolean result = super.fetch(window, ctx); - offsetState.put(runtimeContext.getTaskArgs().getTaskIndex(), readPos); - - long batchId = window.windowId(); - if (CheckpointUtil.needDoCheckpoint(batchId, checkpointDuration)) { - offsetState.manage().operate().setCheckpointId(batchId); - offsetState.manage().operate().finish(); - offsetState.manage().operate().archive(); - LOGGER.info("do checkpoint windowId {} readPos {}", batchId, readPos); + LOGGER.info("not found readPos set to {}", 0); + readPos = 0; } - return result; + } else { + LOGGER.info("current windowId {} readPos {}", window.windowId(), readPos); + } } - - @Override - public void close() { - + boolean result = super.fetch(window, ctx); + offsetState.put(runtimeContext.getTaskArgs().getTaskIndex(), readPos); + + long batchId = window.windowId(); + if (CheckpointUtil.needDoCheckpoint(batchId, checkpointDuration)) { + offsetState.manage().operate().setCheckpointId(batchId); + offsetState.manage().operate().finish(); + offsetState.manage().operate().archive(); + LOGGER.info("do checkpoint windowId {} readPos {}", batchId, readPos); } - - static class OffsetKVStateKeeper { - - private static volatile Map> KV_STATE_MAP = new ConcurrentHashMap<>(); - - public static KeyValueState build(RuntimeContext runtimeContext) { - // Pipeline task name pattern PipelineTask#0-[windowId] cycle#1-[iterationId], - // e.g. PipelineTask#0-1 cycle#1-1. - // We only extract name without any windowId or iterationId for store name. - Matcher matcher = P.matcher(runtimeContext.getTaskArgs().getTaskName()); - String taskName; - if (matcher.find()) { - taskName = matcher.group(1) + matcher.group(2); - } else { - taskName = runtimeContext.getTaskArgs().getTaskName(); + return result; + } + + @Override + public void close() {} + + static class OffsetKVStateKeeper { + + private static volatile Map> KV_STATE_MAP = + new ConcurrentHashMap<>(); + + public static KeyValueState build(RuntimeContext runtimeContext) { + // Pipeline task name pattern PipelineTask#0-[windowId] cycle#1-[iterationId], + // e.g. PipelineTask#0-1 cycle#1-1. + // We only extract name without any windowId or iterationId for store name. + Matcher matcher = P.matcher(runtimeContext.getTaskArgs().getTaskName()); + String taskName; + if (matcher.find()) { + taskName = matcher.group(1) + matcher.group(2); + } else { + taskName = runtimeContext.getTaskArgs().getTaskName(); + } + + taskName = + String.format( + "%s_%s_%s", + runtimeContext.getConfiguration().getString(ExecutionConfigKeys.JOB_UNIQUE_ID), + taskName, + runtimeContext.getTaskArgs().getTaskId()); + if (KV_STATE_MAP.get(taskName) == null) { + synchronized (OffsetKVStateKeeper.class) { + if (KV_STATE_MAP.get(taskName) == null) { + KeyValueStateDescriptor descriptor = + KeyValueStateDescriptor.build( + taskName, + runtimeContext.getConfiguration().getString(SYSTEM_OFFSET_BACKEND_TYPE)); + for (Entry entry : + runtimeContext.getConfiguration().getConfigMap().entrySet()) { + LOGGER.info("runtime key {} value {}", entry.getKey(), entry.getValue()); } - - taskName = String.format("%s_%s_%s", - runtimeContext.getConfiguration().getString(ExecutionConfigKeys.JOB_UNIQUE_ID), - taskName, runtimeContext.getTaskArgs().getTaskId()); - if (KV_STATE_MAP.get(taskName) == null) { - synchronized (OffsetKVStateKeeper.class) { - if (KV_STATE_MAP.get(taskName) == null) { - KeyValueStateDescriptor descriptor = KeyValueStateDescriptor.build( - taskName, - runtimeContext.getConfiguration().getString(SYSTEM_OFFSET_BACKEND_TYPE)); - for (Entry entry : runtimeContext.getConfiguration().getConfigMap().entrySet()) { - LOGGER.info("runtime key {} value {}", entry.getKey(), entry.getValue()); - } - if (descriptor.getStoreType().equalsIgnoreCase(new RocksdbStoreBuilder().getStoreDesc().name())) { - throw new GeaflowRuntimeException("GeaFlow offset not support ROCKSDB storage and should " - + "be configured as JDBC or MEMORY"); - } - int taskIndex = runtimeContext.getTaskArgs().getTaskIndex(); - int maxParallelism = runtimeContext.getTaskArgs().getMaxParallelism(); - KeyGroup keyGroup = KeyGroupAssignment.computeKeyGroupRangeForOperatorIndex(maxParallelism, - runtimeContext.getTaskArgs().getParallelism(), - taskIndex); - descriptor.withKeyGroup(keyGroup); - descriptor.withKVSerializer(new OffsetKvSerializer()); - descriptor.withKeyGroupAssigner(new IntKeyGroupAssigner(maxParallelism)); - KV_STATE_MAP.put(taskName, StateFactory.buildKeyValueState(descriptor, runtimeContext.getConfiguration())); - } - } + if (descriptor + .getStoreType() + .equalsIgnoreCase(new RocksdbStoreBuilder().getStoreDesc().name())) { + throw new GeaflowRuntimeException( + "GeaFlow offset not support ROCKSDB storage and should " + + "be configured as JDBC or MEMORY"); } - return KV_STATE_MAP.get(taskName); + int taskIndex = runtimeContext.getTaskArgs().getTaskIndex(); + int maxParallelism = runtimeContext.getTaskArgs().getMaxParallelism(); + KeyGroup keyGroup = + KeyGroupAssignment.computeKeyGroupRangeForOperatorIndex( + maxParallelism, runtimeContext.getTaskArgs().getParallelism(), taskIndex); + descriptor.withKeyGroup(keyGroup); + descriptor.withKVSerializer(new OffsetKvSerializer()); + descriptor.withKeyGroupAssigner(new IntKeyGroupAssigner(maxParallelism)); + KV_STATE_MAP.put( + taskName, + StateFactory.buildKeyValueState(descriptor, runtimeContext.getConfiguration())); + } } - + } + return KV_STATE_MAP.get(taskName); } + } - static class IntKeyGroupAssigner implements IKeyGroupAssigner { + static class IntKeyGroupAssigner implements IKeyGroupAssigner { - private int maxPara; + private int maxPara; - public IntKeyGroupAssigner(int maxPara) { - this.maxPara = maxPara; - } + public IntKeyGroupAssigner(int maxPara) { + this.maxPara = maxPara; + } - @Override - public int getKeyGroupNumber() { - return this.maxPara; - } + @Override + public int getKeyGroupNumber() { + return this.maxPara; + } - @Override - public int assign(Object key) { - if (key == null) { - return -1; - } - return Math.abs(((int) key) % maxPara); - } + @Override + public int assign(Object key) { + if (key == null) { + return -1; + } + return Math.abs(((int) key) % maxPara); } + } - static class OffsetKvSerializer implements IKVSerializer { + static class OffsetKvSerializer implements IKVSerializer { - private final ISerializer serializer; + private final ISerializer serializer; - public OffsetKvSerializer() { - this.serializer = SerializerFactory.getKryoSerializer(); - } + public OffsetKvSerializer() { + this.serializer = SerializerFactory.getKryoSerializer(); + } - @Override - public byte[] serializeKey(Integer key) { - return serializer.serialize(key); - } + @Override + public byte[] serializeKey(Integer key) { + return serializer.serialize(key); + } - @Override - public Integer deserializeKey(byte[] array) { - if (array == null) { - return null; - } - return (Integer) serializer.deserialize(array); - } + @Override + public Integer deserializeKey(byte[] array) { + if (array == null) { + return null; + } + return (Integer) serializer.deserialize(array); + } - @Override - public byte[] serializeValue(Integer value) { - return serializer.serialize(value); - } + @Override + public byte[] serializeValue(Integer value) { + return serializer.serialize(value); + } - @Override - public Integer deserializeValue(byte[] valueArray) { - if (valueArray == null) { - return null; - } - return (Integer) serializer.deserialize(valueArray); - } + @Override + public Integer deserializeValue(byte[] valueArray) { + if (valueArray == null) { + return null; + } + return (Integer) serializer.deserialize(valueArray); } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/IncrGraphCompute.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/IncrGraphCompute.java index 0a4132f2c..1d7f61fdf 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/IncrGraphCompute.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/IncrGraphCompute.java @@ -25,6 +25,7 @@ import java.util.Comparator; import java.util.Iterator; import java.util.List; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.compute.IncVertexCentricCompute; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricComputeFunction; @@ -60,182 +61,202 @@ public class IncrGraphCompute { - private static final Logger LOGGER = LoggerFactory.getLogger(IncrGraphCompute.class); + private static final Logger LOGGER = LoggerFactory.getLogger(IncrGraphCompute.class); - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/incr_graph"; - public static final String REF_FILE_PATH = "data/reference/incr_graph"; + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/incr_graph"; + public static final String REF_FILE_PATH = "data/reference/incr_graph"; - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - submit(environment); - } + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + submit(environment); + } - public static IPipelineResult submit(Environment environment) { - final Pipeline pipeline = PipelineFactory.buildPipeline(environment); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - ResultValidator.cleanResult(RESULT_FILE_PATH); + public static IPipelineResult submit(Environment environment) { + final Pipeline pipeline = PipelineFactory.buildPipeline(environment); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + ResultValidator.cleanResult(RESULT_FILE_PATH); - //build graph view - final String graphName = "graph_view_name"; - GraphViewDesc graphViewDesc = GraphViewBuilder.createGraphView(graphName) + // build graph view + final String graphName = "graph_view_name"; + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(graphName) .withShardNum(envConfig.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM)) .withBackend(BackendType.RocksDB) - .withSchema(new GraphMetaType(IntegerType.INSTANCE, ValueVertex.class, Integer.class, - ValueEdge.class, IntegerType.class)) + .withSchema( + new GraphMetaType( + IntegerType.INSTANCE, + ValueVertex.class, + Integer.class, + ValueEdge.class, + IntegerType.class)) .build(); - pipeline.withView(graphName, graphViewDesc); - - pipeline.submit(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - Configuration conf = pipelineTaskCxt.getConfig(); - PWindowSource> vertices = - // extract vertex from edge file - pipelineTaskCxt.buildSource(new RecoverableFileSource<>("data/input/email_edge", - line -> { - String[] fields = line.split(","); - IVertex vertex1 = new ValueVertex<>( - Integer.valueOf(fields[0]), 1); - IVertex vertex2 = new ValueVertex<>( - Integer.valueOf(fields[1]), 1); - return Arrays.asList(vertex1, vertex2); - }), SizeTumblingWindow.of(10000)) - .withParallelism(pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); - - PWindowSource> edges = - pipelineTaskCxt.buildSource(new RecoverableFileSource<>("data/input/email_edge", + pipeline.withView(graphName, graphViewDesc); + + pipeline.submit( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { + Configuration conf = pipelineTaskCxt.getConfig(); + PWindowSource> vertices = + // extract vertex from edge file + pipelineTaskCxt + .buildSource( + new RecoverableFileSource<>( + "data/input/email_edge", + line -> { + String[] fields = line.split(","); + IVertex vertex1 = + new ValueVertex<>(Integer.valueOf(fields[0]), 1); + IVertex vertex2 = + new ValueVertex<>(Integer.valueOf(fields[1]), 1); + return Arrays.asList(vertex1, vertex2); + }), + SizeTumblingWindow.of(10000)) + .withParallelism( + pipelineTaskCxt + .getConfig() + .getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); + + PWindowSource> edges = + pipelineTaskCxt.buildSource( + new RecoverableFileSource<>( + "data/input/email_edge", line -> { - String[] fields = line.split(","); - IEdge edge = new ValueEdge<>(Integer.valueOf(fields[0]), - Integer.valueOf(fields[1]), 1); - return Collections.singletonList(edge); - }), SizeTumblingWindow.of(5000)); - - - PGraphView fundGraphView = - pipelineTaskCxt.getGraphView(graphName); - - PIncGraphView incGraphView = - fundGraphView.appendGraph(vertices, edges); - int mapParallelism = pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.MAP_PARALLELISM); - int sinkParallelism = pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.SINK_PARALLELISM); - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); - incGraphView.incrementalCompute(new IncGraphAlgorithms(3)) - .getVertices() - .map(v -> String.format("%s,%s", v.getId(), v.getValue())) - .withParallelism(mapParallelism) - .sink(sink) - .withParallelism(sinkParallelism); - } + String[] fields = line.split(","); + IEdge edge = + new ValueEdge<>( + Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); + return Collections.singletonList(edge); + }), + SizeTumblingWindow.of(5000)); + + PGraphView fundGraphView = + pipelineTaskCxt.getGraphView(graphName); + + PIncGraphView incGraphView = + fundGraphView.appendGraph(vertices, edges); + int mapParallelism = + pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.MAP_PARALLELISM); + int sinkParallelism = + pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.SINK_PARALLELISM); + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); + incGraphView + .incrementalCompute(new IncGraphAlgorithms(3)) + .getVertices() + .map(v -> String.format("%s,%s", v.getId(), v.getValue())) + .withParallelism(mapParallelism) + .sink(sink) + .withParallelism(sinkParallelism); + } }); - return pipeline.execute(); - } + return pipeline.execute(); + } - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH, - Comparator.comparingLong(IncrGraphCompute::parseSumValue)); - } + public static void validateResult() throws IOException { + ResultValidator.validateMapResult( + REF_FILE_PATH, RESULT_FILE_PATH, Comparator.comparingLong(IncrGraphCompute::parseSumValue)); + } - private static long parseSumValue(String result) { - String sumValue = result.split(",")[1]; - return Long.parseLong(sumValue); + private static long parseSumValue(String result) { + String sumValue = result.split(",")[1]; + return Long.parseLong(sumValue); + } + + public static class IncGraphAlgorithms + extends IncVertexCentricCompute { + + public IncGraphAlgorithms(long iterations) { + super(iterations); } - public static class IncGraphAlgorithms extends IncVertexCentricCompute { + @Override + public IncVertexCentricComputeFunction + getIncComputeFunction() { + return new PRVertexCentricComputeFunction(); + } - public IncGraphAlgorithms(long iterations) { - super(iterations); - } + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; + } + } - @Override - public IncVertexCentricComputeFunction getIncComputeFunction() { - return new PRVertexCentricComputeFunction(); - } + public static class PRVertexCentricComputeFunction + implements IncVertexCentricComputeFunction { - @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; - } + private IncGraphComputeContext graphContext; + @Override + public void init(IncGraphComputeContext graphContext) { + this.graphContext = graphContext; } - public static class PRVertexCentricComputeFunction implements - IncVertexCentricComputeFunction { + @Override + public void evolve(Integer vertexId, TemporaryGraph temporaryGraph) { + long lastVersionId = 0L; + IVertex vertex = temporaryGraph.getVertex(); + HistoricalGraph historicalGraph = + graphContext.getHistoricalGraph(); + if (vertex == null) { + vertex = historicalGraph.getSnapShot(lastVersionId).vertex().get(); + } - private IncGraphComputeContext graphContext; + if (vertex != null) { + List> newEs = temporaryGraph.getEdges(); + List> oldEs = + historicalGraph.getSnapShot(lastVersionId).edges().getOutEdges(); + if (newEs != null) { + for (IEdge edge : newEs) { - @Override - public void init(IncGraphComputeContext graphContext) { - this.graphContext = graphContext; + graphContext.sendMessage(edge.getTargetId(), vertexId); + } } - - @Override - public void evolve(Integer vertexId, - TemporaryGraph temporaryGraph) { - long lastVersionId = 0L; - IVertex vertex = temporaryGraph.getVertex(); - HistoricalGraph historicalGraph = graphContext - .getHistoricalGraph(); - if (vertex == null) { - vertex = historicalGraph.getSnapShot(lastVersionId).vertex().get(); - } - - if (vertex != null) { - List> newEs = temporaryGraph.getEdges(); - List> oldEs = historicalGraph.getSnapShot(lastVersionId) - .edges().getOutEdges(); - if (newEs != null) { - for (IEdge edge : newEs) { - - graphContext.sendMessage(edge.getTargetId(), vertexId); - } - } - if (oldEs != null) { - for (IEdge edge : oldEs) { - graphContext.sendMessage(edge.getTargetId(), vertexId); - } - } - } - + if (oldEs != null) { + for (IEdge edge : oldEs) { + graphContext.sendMessage(edge.getTargetId(), vertexId); + } } + } + } - @Override - public void compute(Integer vertexId, Iterator messageIterator) { - int max = 0; - while (messageIterator.hasNext()) { - int value = messageIterator.next(); - max = max > value ? max : value; - } - IVertex vertex = graphContext.getTemporaryGraph().getVertex(); - IVertex historyVertex = graphContext.getHistoricalGraph().getSnapShot(0).vertex().get(); - if (vertex != null && max < vertex.getValue()) { - max = vertex.getValue(); - } - if (historyVertex != null && max < historyVertex.getValue()) { - max = historyVertex.getValue(); - } - graphContext.getTemporaryGraph().updateVertexValue(max); - } + @Override + public void compute(Integer vertexId, Iterator messageIterator) { + int max = 0; + while (messageIterator.hasNext()) { + int value = messageIterator.next(); + max = max > value ? max : value; + } + IVertex vertex = graphContext.getTemporaryGraph().getVertex(); + IVertex historyVertex = + graphContext.getHistoricalGraph().getSnapShot(0).vertex().get(); + if (vertex != null && max < vertex.getValue()) { + max = vertex.getValue(); + } + if (historyVertex != null && max < historyVertex.getValue()) { + max = historyVertex.getValue(); + } + graphContext.getTemporaryGraph().updateVertexValue(max); + } - @Override - public void finish(Integer vertexId, MutableGraph mutableGraph) { - IVertex vertex = graphContext.getTemporaryGraph().getVertex(); - List> edges = graphContext.getTemporaryGraph().getEdges(); - if (vertex != null) { - mutableGraph.addVertex(0, vertex); - graphContext.collect(vertex); - } else { - LOGGER.info("not found vertex {} in temporaryGraph ", vertexId); - } - if (edges != null) { - edges.stream().forEach(edge -> { - mutableGraph.addEdge(0, edge); + @Override + public void finish(Integer vertexId, MutableGraph mutableGraph) { + IVertex vertex = graphContext.getTemporaryGraph().getVertex(); + List> edges = graphContext.getTemporaryGraph().getEdges(); + if (vertex != null) { + mutableGraph.addVertex(0, vertex); + graphContext.collect(vertex); + } else { + LOGGER.info("not found vertex {} in temporaryGraph ", vertexId); + } + if (edges != null) { + edges.stream() + .forEach( + edge -> { + mutableGraph.addEdge(0, edge); }); - } - } + } } - + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/traversal/IncrGraphAggTraversalAlgorithm.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/traversal/IncrGraphAggTraversalAlgorithm.java index 25874b9f9..378d7fa47 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/traversal/IncrGraphAggTraversalAlgorithm.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/traversal/IncrGraphAggTraversalAlgorithm.java @@ -23,6 +23,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; + import org.apache.geaflow.api.graph.function.vc.IncVertexCentricAggTraversalFunction; import org.apache.geaflow.api.graph.function.vc.VertexCentricAggregateFunction; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -37,215 +38,239 @@ import org.slf4j.LoggerFactory; public class IncrGraphAggTraversalAlgorithm - extends IncVertexCentricAggTraversal, Tuple, - Tuple, Integer> { - - private static final Logger LOGGER = LoggerFactory.getLogger(IncrGraphAggTraversalAlgorithm.class); - - public IncrGraphAggTraversalAlgorithm(long iterations) { - super(iterations); - } - - @Override - public IncVertexCentricAggTraversalFunction getIncTraversalFunction() { + extends IncVertexCentricAggTraversal< + Integer, + Integer, + Integer, + Integer, + Integer, + Integer, + Tuple, + Tuple, + Tuple, + Integer> { + + private static final Logger LOGGER = + LoggerFactory.getLogger(IncrGraphAggTraversalAlgorithm.class); + + public IncrGraphAggTraversalAlgorithm(long iterations) { + super(iterations); + } + + @Override + public IncVertexCentricAggTraversalFunction< + Integer, Integer, Integer, Integer, Integer, Integer, Integer> + getIncTraversalFunction() { + + return new IncVertexCentricAggTraversalFunction< + Integer, Integer, Integer, Integer, Integer, Integer, Integer>() { + + private IncVertexCentricTraversalFuncContext + vertexCentricFuncContext; + private VertexCentricAggContext aggContext; + private Map vertexValue; + + @Override + public void initContext(VertexCentricAggContext aggContext) { + this.aggContext = aggContext; + } + + @Override + public void open( + IncVertexCentricTraversalFuncContext + vertexCentricFuncContext) { + this.vertexCentricFuncContext = vertexCentricFuncContext; + this.vertexValue = new HashMap<>(); + } + + @Override + public void evolve( + Integer vertexId, TemporaryGraph temporaryGraph) { + MutableGraph mutableGraph = + this.vertexCentricFuncContext.getMutableGraph(); + IVertex vertex = temporaryGraph.getVertex(); + if (vertex != null) { + vertex.withValue(0); + mutableGraph.addVertex(0, vertex); + } + List> edges = temporaryGraph.getEdges(); + if (edges != null) { + for (IEdge edge : edges) { + mutableGraph.addEdge(0, edge); + } + } + } + + @Override + public void init(ITraversalRequest traversalRequest) { + List> edges = + this.vertexCentricFuncContext.getHistoricalGraph().getSnapShot(0).edges().getEdges(); + if (edges != null) { + for (IEdge edge : edges) { + this.vertexCentricFuncContext.sendMessage(edge.getTargetId(), edges.size()); + } + aggContext.aggregate(edges.size()); + vertexValue.put(traversalRequest.getVId(), 0); + } + } - return new IncVertexCentricAggTraversalFunction() { + @Override + public void finish() {} - private IncVertexCentricTraversalFuncContext vertexCentricFuncContext; - private VertexCentricAggContext aggContext; - private Map vertexValue; + @Override + public void close() {} - @Override - public void initContext(VertexCentricAggContext aggContext) { - this.aggContext = aggContext; - } + @Override + public void compute(Integer vertexId, Iterator messageIterator) { - @Override - public void open(IncVertexCentricTraversalFuncContext vertexCentricFuncContext) { - this.vertexCentricFuncContext = vertexCentricFuncContext; - this.vertexValue = new HashMap<>(); - } + int sum = 0; + while (messageIterator.hasNext()) { + sum += messageIterator.next(); + } + List> edges = + this.vertexCentricFuncContext.getHistoricalGraph().getSnapShot(0).edges().getEdges(); - @Override - public void evolve(Integer vertexId, - TemporaryGraph temporaryGraph) { - MutableGraph mutableGraph = - this.vertexCentricFuncContext.getMutableGraph(); - IVertex vertex = temporaryGraph.getVertex(); - if (vertex != null) { - vertex.withValue(0); - mutableGraph.addVertex(0, vertex); - } - List> edges = temporaryGraph.getEdges(); - if (edges != null) { - for (IEdge edge : edges) { - mutableGraph.addEdge(0, edge); - } - } - } + if (edges == null || edges.isEmpty()) { + aggContext.aggregate(0); + return; + } - @Override - public void init(ITraversalRequest traversalRequest) { - List> edges = - this.vertexCentricFuncContext.getHistoricalGraph().getSnapShot(0).edges().getEdges(); - if (edges != null) { - for (IEdge edge : edges) { - this.vertexCentricFuncContext.sendMessage(edge.getTargetId(), edges.size()); - } - aggContext.aggregate(edges.size()); - vertexValue.put(traversalRequest.getVId(), 0); - } - } + int average = sum / edges.size(); + IVertex vertex = + this.vertexCentricFuncContext.getTemporaryGraph().getVertex(); + + if (vertex != null) { + for (IEdge edge : edges) { + this.vertexCentricFuncContext.sendMessage(edge.getTargetId(), average); + } + aggContext.aggregate(edges.size()); + vertexValue.put(vertexId, vertex.getValue()); + } else { + aggContext.aggregate(0); + vertexValue.put(vertexId, 0); + } + } + + @Override + public void finish(Integer vertexId, MutableGraph mutableGraph) { + this.vertexCentricFuncContext.takeResponse( + new TraversalResponse( + vertexId, Math.toIntExact(vertexCentricFuncContext.getIterationId()))); + } + }; + } + + @Override + public VertexCentricAggregateFunction< + Integer, + Tuple, + Tuple, + Tuple, + Integer> + getAggregateFunction() { + return new VertexCentricAggregateFunction< + Integer, + Tuple, + Tuple, + Tuple, + Integer>() { + @Override + public IPartialGraphAggFunction, Tuple> + getPartialAggregation() { + return new IPartialGraphAggFunction< + Integer, Tuple, Tuple>() { + + private IPartialAggContext> partialAggContext; + + @Override + public Tuple create( + IPartialAggContext> partialAggContext) { + this.partialAggContext = partialAggContext; + return Tuple.of(0, 0); + } + + @Override + public Tuple aggregate( + Integer integer, Tuple result) { + result.f0 += 1; + result.f1 += integer; + return result; + } + + @Override + public void finish(Tuple result) { + partialAggContext.collect(result); + } + }; + } - @Override - public void finish() { + @Override + public IGraphAggregateFunction, Tuple, Integer> + getGlobalAggregation() { + return new IGraphAggregateFunction< + Tuple, Tuple, Integer>() { + private IGlobalGraphAggContext globalGraphAggContext; + + @Override + public Tuple create( + IGlobalGraphAggContext globalGraphAggContext) { + this.globalGraphAggContext = globalGraphAggContext; + return Tuple.of(0, 0); + } + + @Override + public Integer aggregate( + Tuple integerIntegerTuple2, + Tuple integerIntegerTuple22) { + return (integerIntegerTuple22.f1 + integerIntegerTuple2.f1) / 1005; + } + + @Override + public void finish(Integer value) { + long iterationId = this.globalGraphAggContext.getIteration(); + if (iterationId == 2) { + LOGGER.info("current iterationId:{} value is {}, do terminate", iterationId, value); + this.globalGraphAggContext.terminate(); + } else { + LOGGER.info("current iterationId:{} value is {}, do broadcast", iterationId, value); + // Set a dummy value, without any global result + this.globalGraphAggContext.broadcast(0); } + } + }; + } + }; + } - @Override - public void close() { + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; + } - } + static class TraversalResponse implements ITraversalResponse { - @Override - public void compute(Integer vertexId, - Iterator messageIterator) { - - int sum = 0; - while (messageIterator.hasNext()) { - sum += messageIterator.next(); - } - List> edges = - this.vertexCentricFuncContext.getHistoricalGraph().getSnapShot(0).edges().getEdges(); - - if (edges == null || edges.isEmpty()) { - aggContext.aggregate(0); - return; - } - - int average = sum / edges.size(); - IVertex vertex = this.vertexCentricFuncContext.getTemporaryGraph().getVertex(); - - if (vertex != null) { - for (IEdge edge : edges) { - this.vertexCentricFuncContext.sendMessage(edge.getTargetId(), average); - } - aggContext.aggregate(edges.size()); - vertexValue.put(vertexId, vertex.getValue()); - } else { - aggContext.aggregate(0); - vertexValue.put(vertexId, 0); - } - } + private long responseId; + private int value; - @Override - public void finish(Integer vertexId, - MutableGraph mutableGraph) { - this.vertexCentricFuncContext.takeResponse(new TraversalResponse(vertexId, - Math.toIntExact(vertexCentricFuncContext.getIterationId()))); - } - }; + public TraversalResponse(long responseId, int value) { + this.responseId = responseId; + this.value = value; } @Override - public VertexCentricAggregateFunction, Tuple, Tuple, Integer> getAggregateFunction() { - return new VertexCentricAggregateFunction, - Tuple, Tuple, Integer>() { - @Override - public IPartialGraphAggFunction, Tuple> getPartialAggregation() { - return new IPartialGraphAggFunction, Tuple>() { - - private IPartialAggContext> partialAggContext; - - @Override - public Tuple create( - IPartialAggContext> partialAggContext) { - this.partialAggContext = partialAggContext; - return Tuple.of(0, 0); - } - - - @Override - public Tuple aggregate(Integer integer, Tuple result) { - result.f0 += 1; - result.f1 += integer; - return result; - } - - @Override - public void finish(Tuple result) { - partialAggContext.collect(result); - } - }; - } - - @Override - public IGraphAggregateFunction, Tuple, - Integer> getGlobalAggregation() { - return new IGraphAggregateFunction, Tuple, Integer>() { - - private IGlobalGraphAggContext globalGraphAggContext; - - @Override - public Tuple create( - IGlobalGraphAggContext globalGraphAggContext) { - this.globalGraphAggContext = globalGraphAggContext; - return Tuple.of(0, 0); - } - - @Override - public Integer aggregate(Tuple integerIntegerTuple2, - Tuple integerIntegerTuple22) { - return (integerIntegerTuple22.f1 + integerIntegerTuple2.f1) / 1005; - } - - @Override - public void finish(Integer value) { - long iterationId = this.globalGraphAggContext.getIteration(); - if (iterationId == 2) { - LOGGER.info("current iterationId:{} value is {}, do terminate", iterationId, value); - this.globalGraphAggContext.terminate(); - } else { - LOGGER.info("current iterationId:{} value is {}, do broadcast", iterationId, value); - // Set a dummy value, without any global result - this.globalGraphAggContext.broadcast(0); - } - } - }; - } - }; + public long getResponseId() { + return responseId; } @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; + public Integer getResponse() { + return value; } - static class TraversalResponse implements ITraversalResponse { - - private long responseId; - private int value; - - public TraversalResponse(long responseId, int value) { - this.responseId = responseId; - this.value = value; - } - - @Override - public long getResponseId() { - return responseId; - } - - @Override - public Integer getResponse() { - return value; - } - - @Override - public ResponseType getType() { - return ResponseType.Vertex; - } + @Override + public ResponseType getType() { + return ResponseType.Vertex; } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/traversal/IncrGraphAggTraversalAll.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/traversal/IncrGraphAggTraversalAll.java index c0456947b..f3a92720e 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/traversal/IncrGraphAggTraversalAll.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/traversal/IncrGraphAggTraversalAll.java @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.Comparator; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.pdata.stream.window.PWindowSource; import org.apache.geaflow.api.window.impl.SizeTumblingWindow; @@ -53,86 +54,105 @@ public class IncrGraphAggTraversalAll { - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/incr_graph_traversal_all_agg"; - public static final String REF_FILE_PATH = "data/reference/incr_graph_traversal_all_agg"; - - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - submit(environment); - } - - public static IPipelineResult submit(Environment environment) { - final Pipeline pipeline = PipelineFactory.buildPipeline(environment); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - ResultValidator.cleanResult(RESULT_FILE_PATH); - - //build graph view - final String graphName = "graph_view_name"; - GraphViewDesc graphViewDesc = GraphViewBuilder.createGraphView(graphName) + public static final String RESULT_FILE_PATH = + "./target/tmp/data/result/incr_graph_traversal_all_agg"; + public static final String REF_FILE_PATH = "data/reference/incr_graph_traversal_all_agg"; + + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + submit(environment); + } + + public static IPipelineResult submit(Environment environment) { + final Pipeline pipeline = PipelineFactory.buildPipeline(environment); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + ResultValidator.cleanResult(RESULT_FILE_PATH); + + // build graph view + final String graphName = "graph_view_name"; + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(graphName) .withShardNum(envConfig.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM)) .withBackend(BackendType.RocksDB) - .withSchema(new GraphMetaType(IntegerType.INSTANCE, ValueVertex.class, - Integer.class, ValueEdge.class, IntegerType.class)) + .withSchema( + new GraphMetaType( + IntegerType.INSTANCE, + ValueVertex.class, + Integer.class, + ValueEdge.class, + IntegerType.class)) .build(); - pipeline.withView(graphName, graphViewDesc); - - pipeline.submit(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - Configuration conf = pipelineTaskCxt.getConfig(); - PWindowSource> vertices = - // extract vertex from edge file - pipelineTaskCxt.buildSource(new RecoverableFileSource<>("data/input/email_edge", - line -> { - String[] fields = line.split(","); - IVertex vertex1 = new ValueVertex<>( - Integer.valueOf(fields[0]), 1); - IVertex vertex2 = new ValueVertex<>( - Integer.valueOf(fields[1]), 1); - return Arrays.asList(vertex1, vertex2); - }), SizeTumblingWindow.of(10000)) - .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); - - PWindowSource> edges = - pipelineTaskCxt.buildSource(new RecoverableFileSource<>("data/input/email_edge", + pipeline.withView(graphName, graphViewDesc); + + pipeline.submit( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { + Configuration conf = pipelineTaskCxt.getConfig(); + PWindowSource> vertices = + // extract vertex from edge file + pipelineTaskCxt + .buildSource( + new RecoverableFileSource<>( + "data/input/email_edge", + line -> { + String[] fields = line.split(","); + IVertex vertex1 = + new ValueVertex<>(Integer.valueOf(fields[0]), 1); + IVertex vertex2 = + new ValueVertex<>(Integer.valueOf(fields[1]), 1); + return Arrays.asList(vertex1, vertex2); + }), + SizeTumblingWindow.of(10000)) + .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); + + PWindowSource> edges = + pipelineTaskCxt.buildSource( + new RecoverableFileSource<>( + "data/input/email_edge", line -> { - String[] fields = line.split(","); - IEdge edge = new ValueEdge<>(Integer.valueOf(fields[0]), - Integer.valueOf(fields[1]), 1); - return Collections.singletonList(edge); - }), SizeTumblingWindow.of(5000)); - - - PGraphView fundGraphView = - pipelineTaskCxt.getGraphView(graphName); - - PIncGraphView incGraphView = - fundGraphView.appendGraph(vertices, edges); - - int mapParallelism = pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.MAP_PARALLELISM); - int sinkParallelism = pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.SINK_PARALLELISM); - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); - incGraphView.incrementalTraversal(new IncrGraphAggTraversalAlgorithm(5)) - .start() - .map(res -> String.format("%s,%s", res.getResponseId(), res.getResponse())) - .withParallelism(mapParallelism) - .sink(sink) - .withParallelism(sinkParallelism); - - } + String[] fields = line.split(","); + IEdge edge = + new ValueEdge<>( + Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); + return Collections.singletonList(edge); + }), + SizeTumblingWindow.of(5000)); + + PGraphView fundGraphView = + pipelineTaskCxt.getGraphView(graphName); + + PIncGraphView incGraphView = + fundGraphView.appendGraph(vertices, edges); + + int mapParallelism = + pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.MAP_PARALLELISM); + int sinkParallelism = + pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.SINK_PARALLELISM); + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); + incGraphView + .incrementalTraversal(new IncrGraphAggTraversalAlgorithm(5)) + .start() + .map(res -> String.format("%s,%s", res.getResponseId(), res.getResponse())) + .withParallelism(mapParallelism) + .sink(sink) + .withParallelism(sinkParallelism); + } }); - return pipeline.execute(); - } + return pipeline.execute(); + } - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH, - Comparator.comparingLong(IncrGraphAggTraversalAll::parseSumValue)); - } + public static void validateResult() throws IOException { + ResultValidator.validateMapResult( + REF_FILE_PATH, + RESULT_FILE_PATH, + Comparator.comparingLong(IncrGraphAggTraversalAll::parseSumValue)); + } - private static long parseSumValue(String result) { - String sumValue = result.split(",")[1]; - return Long.parseLong(sumValue); - } + private static long parseSumValue(String result) { + String sumValue = result.split(",")[1]; + return Long.parseLong(sumValue); + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/traversal/IncrGraphTraversalAll.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/traversal/IncrGraphTraversalAll.java index a50b211c1..c967ce355 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/traversal/IncrGraphTraversalAll.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/traversal/IncrGraphTraversalAll.java @@ -25,6 +25,7 @@ import java.util.Comparator; import java.util.Iterator; import java.util.List; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricTraversalFunction; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -63,203 +64,212 @@ public class IncrGraphTraversalAll { - private static final Logger LOGGER = LoggerFactory.getLogger(IncrGraphTraversalAll.class); + private static final Logger LOGGER = LoggerFactory.getLogger(IncrGraphTraversalAll.class); - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/incr_graph_traversal_all"; - public static final String REF_FILE_PATH = "data/reference/incr_graph_traversal_all"; + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/incr_graph_traversal_all"; + public static final String REF_FILE_PATH = "data/reference/incr_graph_traversal_all"; - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - submit(environment); - } + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + submit(environment); + } - public static IPipelineResult submit(Environment environment) { - final Pipeline pipeline = PipelineFactory.buildPipeline(environment); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - ResultValidator.cleanResult(RESULT_FILE_PATH); + public static IPipelineResult submit(Environment environment) { + final Pipeline pipeline = PipelineFactory.buildPipeline(environment); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + ResultValidator.cleanResult(RESULT_FILE_PATH); - //build graph view - final String graphName = "graph_view_name"; - GraphViewDesc graphViewDesc = GraphViewBuilder.createGraphView(graphName) + // build graph view + final String graphName = "graph_view_name"; + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(graphName) .withShardNum(envConfig.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM)) .withBackend(BackendType.RocksDB) - .withSchema(new GraphMetaType(IntegerType.INSTANCE, ValueVertex.class, - Integer.class, ValueEdge.class, IntegerType.class)) + .withSchema( + new GraphMetaType( + IntegerType.INSTANCE, + ValueVertex.class, + Integer.class, + ValueEdge.class, + IntegerType.class)) .build(); - pipeline.withView(graphName, graphViewDesc); - - pipeline.submit(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - Configuration conf = pipelineTaskCxt.getConfig(); - PWindowSource> vertices = - // extract vertex from edge file - pipelineTaskCxt.buildSource(new RecoverableFileSource<>("data/input/email_edge", - line -> { - String[] fields = line.split(","); - IVertex vertex1 = new ValueVertex<>( - Integer.valueOf(fields[0]), 1); - IVertex vertex2 = new ValueVertex<>( - Integer.valueOf(fields[1]), 1); - return Arrays.asList(vertex1, vertex2); - }), SizeTumblingWindow.of(10000)) - .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); - - PWindowSource> edges = - pipelineTaskCxt.buildSource(new RecoverableFileSource<>("data/input/email_edge", + pipeline.withView(graphName, graphViewDesc); + + pipeline.submit( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { + Configuration conf = pipelineTaskCxt.getConfig(); + PWindowSource> vertices = + // extract vertex from edge file + pipelineTaskCxt + .buildSource( + new RecoverableFileSource<>( + "data/input/email_edge", + line -> { + String[] fields = line.split(","); + IVertex vertex1 = + new ValueVertex<>(Integer.valueOf(fields[0]), 1); + IVertex vertex2 = + new ValueVertex<>(Integer.valueOf(fields[1]), 1); + return Arrays.asList(vertex1, vertex2); + }), + SizeTumblingWindow.of(10000)) + .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); + + PWindowSource> edges = + pipelineTaskCxt.buildSource( + new RecoverableFileSource<>( + "data/input/email_edge", line -> { - String[] fields = line.split(","); - IEdge edge = new ValueEdge<>(Integer.valueOf(fields[0]), - Integer.valueOf(fields[1]), 1); - return Collections.singletonList(edge); - }), SizeTumblingWindow.of(5000)); + String[] fields = line.split(","); + IEdge edge = + new ValueEdge<>( + Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); + return Collections.singletonList(edge); + }), + SizeTumblingWindow.of(5000)); + + PGraphView fundGraphView = + pipelineTaskCxt.getGraphView(graphName); + + PIncGraphView incGraphView = + fundGraphView.appendGraph(vertices, edges); + + int mapParallelism = + pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.MAP_PARALLELISM); + int sinkParallelism = + pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.SINK_PARALLELISM); + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); + incGraphView + .incrementalTraversal(new IncGraphTraversalAlgorithms(3)) + .start() + .map(res -> res.toString()) + .withParallelism(mapParallelism) + .sink(sink) + .withParallelism(sinkParallelism); + } + }); + return pipeline.execute(); + } - PGraphView fundGraphView = - pipelineTaskCxt.getGraphView(graphName); + public static void validateResult() throws IOException { + ResultValidator.validateMapResult( + REF_FILE_PATH, + RESULT_FILE_PATH, + Comparator.comparingLong(IncrGraphTraversalAll::parseSumValue)); + } - PIncGraphView incGraphView = - fundGraphView.appendGraph(vertices, edges); + private static long parseSumValue(String result) { + String sumValue = result.split(",")[1]; + return Long.parseLong(sumValue); + } - int mapParallelism = pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.MAP_PARALLELISM); - int sinkParallelism = pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.SINK_PARALLELISM); - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); - incGraphView.incrementalTraversal(new IncGraphTraversalAlgorithms(3)) - .start() - .map(res -> res.toString()) - .withParallelism(mapParallelism) - .sink(sink) - .withParallelism(sinkParallelism); + public static class IncGraphTraversalAlgorithms + extends IncVertexCentricTraversal { - } - }); - - return pipeline.execute(); + public IncGraphTraversalAlgorithms(long iterations) { + super(iterations); } - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH, - Comparator.comparingLong(IncrGraphTraversalAll::parseSumValue)); - } + @Override + public IncVertexCentricTraversalFunction + getIncTraversalFunction() { + return new IncVertexCentricTraversalFunction() { - private static long parseSumValue(String result) { - String sumValue = result.split(",")[1]; - return Long.parseLong(sumValue); - } + private IncVertexCentricTraversalFuncContext + vertexCentricFuncContext; - public static class IncGraphTraversalAlgorithms extends IncVertexCentricTraversal { - - public IncGraphTraversalAlgorithms(long iterations) { - super(iterations); + @Override + public void open( + IncVertexCentricTraversalFuncContext + vertexCentricFuncContext) { + this.vertexCentricFuncContext = vertexCentricFuncContext; } @Override - public IncVertexCentricTraversalFunction getIncTraversalFunction() { - return new IncVertexCentricTraversalFunction() { - - private IncVertexCentricTraversalFuncContext vertexCentricFuncContext; - - @Override - public void open( - IncVertexCentricTraversalFuncContext vertexCentricFuncContext) { - this.vertexCentricFuncContext = vertexCentricFuncContext; - } - - @Override - public void evolve(Integer vertexId, - TemporaryGraph temporaryGraph) { - LOGGER.debug("evolve vId:{}", vertexId); - MutableGraph mutableGraph = this.vertexCentricFuncContext.getMutableGraph(); - IVertex vertex = temporaryGraph.getVertex(); - if (vertex != null) { - mutableGraph.addVertex(0, vertex); - } - List> edges = temporaryGraph.getEdges(); - if (edges != null) { - for (IEdge edge : edges) { - mutableGraph.addEdge(0, edge); - } - } - } - - @Override - public void init(ITraversalRequest traversalRequest) { - int requestId = traversalRequest.getVId(); - List> edges = - this.vertexCentricFuncContext.getHistoricalGraph().getSnapShot(0).edges().getEdges(); - int sum = 0; - if (edges != null) { - for (IEdge edge : edges) { - sum += edge.getValue(); - } - } - this.vertexCentricFuncContext.takeResponse(new TraversalResponse(requestId, - sum)); - } - - @Override - public void finish() { - - } - - @Override - public void close() { - - } - - @Override - public void compute(Integer vertexId, - Iterator messageIterator) { - } - - @Override - public void finish(Integer vertexId, - MutableGraph mutableGraph) { - } - }; + public void evolve( + Integer vertexId, TemporaryGraph temporaryGraph) { + LOGGER.debug("evolve vId:{}", vertexId); + MutableGraph mutableGraph = + this.vertexCentricFuncContext.getMutableGraph(); + IVertex vertex = temporaryGraph.getVertex(); + if (vertex != null) { + mutableGraph.addVertex(0, vertex); + } + List> edges = temporaryGraph.getEdges(); + if (edges != null) { + for (IEdge edge : edges) { + mutableGraph.addEdge(0, edge); + } + } } @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; + public void init(ITraversalRequest traversalRequest) { + int requestId = traversalRequest.getVId(); + List> edges = + this.vertexCentricFuncContext.getHistoricalGraph().getSnapShot(0).edges().getEdges(); + int sum = 0; + if (edges != null) { + for (IEdge edge : edges) { + sum += edge.getValue(); + } + } + this.vertexCentricFuncContext.takeResponse(new TraversalResponse(requestId, sum)); } + @Override + public void finish() {} + + @Override + public void close() {} + + @Override + public void compute(Integer vertexId, Iterator messageIterator) {} + + @Override + public void finish( + Integer vertexId, MutableGraph mutableGraph) {} + }; } - static class TraversalResponse implements ITraversalResponse { + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; + } + } - private long responseId; - private int value; + static class TraversalResponse implements ITraversalResponse { - public TraversalResponse(long responseId, int value) { - this.responseId = responseId; - this.value = value; - } + private long responseId; + private int value; - @Override - public long getResponseId() { - return responseId; - } + public TraversalResponse(long responseId, int value) { + this.responseId = responseId; + this.value = value; + } - @Override - public Integer getResponse() { - return value; - } + @Override + public long getResponseId() { + return responseId; + } - @Override - public ResponseType getType() { - return ResponseType.Vertex; - } + @Override + public Integer getResponse() { + return value; + } - @Override - public String toString() { - return responseId + "," + value; - } + @Override + public ResponseType getType() { + return ResponseType.Vertex; } + @Override + public String toString() { + return responseId + "," + value; + } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/traversal/IncrGraphTraversalByStartIds.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/traversal/IncrGraphTraversalByStartIds.java index 587e074a8..98db6cc02 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/traversal/IncrGraphTraversalByStartIds.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/traversal/IncrGraphTraversalByStartIds.java @@ -24,6 +24,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricTraversalFunction; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -62,190 +63,198 @@ public class IncrGraphTraversalByStartIds { - private static final Logger LOGGER = LoggerFactory.getLogger(IncrGraphTraversalByStartIds.class); + private static final Logger LOGGER = LoggerFactory.getLogger(IncrGraphTraversalByStartIds.class); - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/incr_graph_traversal"; - public static final String REF_FILE_PATH = "data/reference/incr_graph_traversal"; + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/incr_graph_traversal"; + public static final String REF_FILE_PATH = "data/reference/incr_graph_traversal"; - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - submit(environment); - } + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + submit(environment); + } - public static IPipelineResult submit(Environment environment) { - final Pipeline pipeline = PipelineFactory.buildPipeline(environment); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - ResultValidator.cleanResult(RESULT_FILE_PATH); + public static IPipelineResult submit(Environment environment) { + final Pipeline pipeline = PipelineFactory.buildPipeline(environment); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + ResultValidator.cleanResult(RESULT_FILE_PATH); - //build graph view - final String graphName = "graph_view_name"; - GraphViewDesc graphViewDesc = GraphViewBuilder.createGraphView(graphName) + // build graph view + final String graphName = "graph_view_name"; + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(graphName) .withShardNum(envConfig.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM)) .withBackend(BackendType.RocksDB) - .withSchema(new GraphMetaType(IntegerType.INSTANCE, ValueVertex.class, - Integer.class, ValueEdge.class, IntegerType.class)) + .withSchema( + new GraphMetaType( + IntegerType.INSTANCE, + ValueVertex.class, + Integer.class, + ValueEdge.class, + IntegerType.class)) .build(); - pipeline.withView(graphName, graphViewDesc); - - pipeline.submit(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - Configuration conf = pipelineTaskCxt.getConfig(); - PWindowSource> vertices = - // extract vertex from edge file - pipelineTaskCxt.buildSource(new RecoverableFileSource<>("data/input/email_edge", + pipeline.withView(graphName, graphViewDesc); + + pipeline.submit( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { + Configuration conf = pipelineTaskCxt.getConfig(); + PWindowSource> vertices = + // extract vertex from edge file + pipelineTaskCxt + .buildSource( + new RecoverableFileSource<>( + "data/input/email_edge", + line -> { + String[] fields = line.split(","); + IVertex vertex1 = + new ValueVertex<>(Integer.valueOf(fields[0]), 1); + IVertex vertex2 = + new ValueVertex<>(Integer.valueOf(fields[1]), 1); + return Arrays.asList(vertex1, vertex2); + }), + SizeTumblingWindow.of(10000)) + .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); + + PWindowSource> edges = + pipelineTaskCxt.buildSource( + new RecoverableFileSource<>( + "data/input/email_edge", line -> { - String[] fields = line.split(","); - IVertex vertex1 = new ValueVertex<>( - Integer.valueOf(fields[0]), 1); - IVertex vertex2 = new ValueVertex<>( - Integer.valueOf(fields[1]), 1); - return Arrays.asList(vertex1, vertex2); - }), SizeTumblingWindow.of(10000)) - .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); - - PWindowSource> edges = - pipelineTaskCxt.buildSource(new RecoverableFileSource<>("data/input/email_edge", - line -> { - String[] fields = line.split(","); - IEdge edge = new ValueEdge<>(Integer.valueOf(fields[0]), - Integer.valueOf(fields[1]), 1); - return Collections.singletonList(edge); - }), SizeTumblingWindow.of(5000)); - - - PGraphView fundGraphView = - pipelineTaskCxt.getGraphView(graphName); + String[] fields = line.split(","); + IEdge edge = + new ValueEdge<>( + Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); + return Collections.singletonList(edge); + }), + SizeTumblingWindow.of(5000)); + + PGraphView fundGraphView = + pipelineTaskCxt.getGraphView(graphName); + + PIncGraphView incGraphView = + fundGraphView.appendGraph(vertices, edges); + + int mapParallelism = + pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.MAP_PARALLELISM); + int sinkParallelism = + pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.SINK_PARALLELISM); + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); + incGraphView + .incrementalTraversal(new IncGraphTraversalAlgorithms(3)) + .start(3) + .map(res -> String.format("%s,%s", res.getResponseId(), res.getResponse())) + .withParallelism(mapParallelism) + .sink(sink) + .withParallelism(sinkParallelism); + } + }); - PIncGraphView incGraphView = - fundGraphView.appendGraph(vertices, edges); + return pipeline.execute(); + } - int mapParallelism = pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.MAP_PARALLELISM); - int sinkParallelism = pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.SINK_PARALLELISM); - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); - incGraphView.incrementalTraversal(new IncGraphTraversalAlgorithms(3)) - .start(3) - .map(res -> String.format("%s,%s", res.getResponseId(), res.getResponse())) - .withParallelism(mapParallelism) - .sink(sink) - .withParallelism(sinkParallelism); + public static void validateResult() throws IOException { + ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH); + } - } - }); + public static class IncGraphTraversalAlgorithms + extends IncVertexCentricTraversal { - return pipeline.execute(); + public IncGraphTraversalAlgorithms(long iterations) { + super(iterations); } - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH); - } + @Override + public IncVertexCentricTraversalFunction + getIncTraversalFunction() { + return new IncVertexCentricTraversalFunction() { + private IncVertexCentricTraversalFuncContext + vertexCentricFuncContext; - public static class IncGraphTraversalAlgorithms extends IncVertexCentricTraversal { + @Override + public void open( + IncVertexCentricTraversalFuncContext + vertexCentricFuncContext) { + this.vertexCentricFuncContext = vertexCentricFuncContext; + } - public IncGraphTraversalAlgorithms(long iterations) { - super(iterations); + @Override + public void evolve( + Integer vertexId, TemporaryGraph temporaryGraph) { + LOGGER.debug("evolve vId:{}", vertexId); + MutableGraph mutableGraph = + this.vertexCentricFuncContext.getMutableGraph(); + IVertex vertex = temporaryGraph.getVertex(); + if (vertex != null) { + mutableGraph.addVertex(0, vertex); + } + List> edges = temporaryGraph.getEdges(); + if (edges != null) { + for (IEdge edge : edges) { + mutableGraph.addEdge(0, edge); + } + } } @Override - public IncVertexCentricTraversalFunction getIncTraversalFunction() { - return new IncVertexCentricTraversalFunction() { - - private IncVertexCentricTraversalFuncContext vertexCentricFuncContext; - - @Override - public void open( - IncVertexCentricTraversalFuncContext vertexCentricFuncContext) { - this.vertexCentricFuncContext = vertexCentricFuncContext; - } - - @Override - public void evolve(Integer vertexId, - TemporaryGraph temporaryGraph) { - LOGGER.debug("evolve vId:{}", vertexId); - MutableGraph mutableGraph = - this.vertexCentricFuncContext.getMutableGraph(); - IVertex vertex = temporaryGraph.getVertex(); - if (vertex != null) { - mutableGraph.addVertex(0, vertex); - } - List> edges = temporaryGraph.getEdges(); - if (edges != null) { - for (IEdge edge : edges) { - mutableGraph.addEdge(0, edge); - } - } - } - - @Override - public void init(ITraversalRequest traversalRequest) { - int requestId = traversalRequest.getVId(); - this.vertexCentricFuncContext.sendMessageToNeighbors(requestId); - } - - @Override - public void finish() { - - } - - @Override - public void close() { - - } - - @Override - public void compute(Integer vertexId, - Iterator messageIterator) { - int sum = 0; - while (messageIterator.hasNext()) { - sum += messageIterator.next(); - } - this.vertexCentricFuncContext.takeResponse(new TraversalResponse(vertexId, sum)); - LOGGER.info("take response:{}-{}", vertexId, sum); - } - - @Override - public void finish(Integer vertexId, - MutableGraph mutableGraph) { - } - }; + public void init(ITraversalRequest traversalRequest) { + int requestId = traversalRequest.getVId(); + this.vertexCentricFuncContext.sendMessageToNeighbors(requestId); } @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; + public void finish() {} + + @Override + public void close() {} + + @Override + public void compute(Integer vertexId, Iterator messageIterator) { + int sum = 0; + while (messageIterator.hasNext()) { + sum += messageIterator.next(); + } + this.vertexCentricFuncContext.takeResponse(new TraversalResponse(vertexId, sum)); + LOGGER.info("take response:{}-{}", vertexId, sum); } + @Override + public void finish( + Integer vertexId, MutableGraph mutableGraph) {} + }; } - static class TraversalResponse implements ITraversalResponse { + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; + } + } - private long responseId; - private int value; + static class TraversalResponse implements ITraversalResponse { - public TraversalResponse(long responseId, int value) { - this.responseId = responseId; - this.value = value; - } + private long responseId; + private int value; - @Override - public long getResponseId() { - return responseId; - } + public TraversalResponse(long responseId, int value) { + this.responseId = responseId; + this.value = value; + } - @Override - public Integer getResponse() { - return value; - } + @Override + public long getResponseId() { + return responseId; + } - @Override - public ResponseType getType() { - return ResponseType.Vertex; - } + @Override + public Integer getResponse() { + return value; } + @Override + public ResponseType getType() { + return ResponseType.Vertex; + } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/traversal/IncrGraphTraversalByStream.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/traversal/IncrGraphTraversalByStream.java index 80ef4ee42..38460237f 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/traversal/IncrGraphTraversalByStream.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/dynamic/traversal/IncrGraphTraversalByStream.java @@ -25,6 +25,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; + import org.apache.geaflow.api.function.internal.CollectionSource; import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricTraversalFunction; @@ -65,205 +66,211 @@ public class IncrGraphTraversalByStream { - private static final Logger LOGGER = LoggerFactory.getLogger(IncrGraphTraversalByStream.class); + private static final Logger LOGGER = LoggerFactory.getLogger(IncrGraphTraversalByStream.class); - public static final String RESULT_FILE_PATH = - "./target/tmp/data/result/incr_stream_graph_traversal"; - public static final String REF_FILE_PATH = "data/reference/incr_graph_traversal"; + public static final String RESULT_FILE_PATH = + "./target/tmp/data/result/incr_stream_graph_traversal"; + public static final String REF_FILE_PATH = "data/reference/incr_graph_traversal"; - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - submit(environment); - } + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + submit(environment); + } - public static IPipelineResult submit(Environment environment) { - final Pipeline pipeline = PipelineFactory.buildPipeline(environment); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - ResultValidator.cleanResult(RESULT_FILE_PATH); + public static IPipelineResult submit(Environment environment) { + final Pipeline pipeline = PipelineFactory.buildPipeline(environment); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + ResultValidator.cleanResult(RESULT_FILE_PATH); - //build graph view - final String graphName = "graph_view_name"; - GraphViewDesc graphViewDesc = GraphViewBuilder.createGraphView(graphName) + // build graph view + final String graphName = "graph_view_name"; + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(graphName) .withShardNum(envConfig.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM)) .withBackend(BackendType.RocksDB) - .withSchema(new GraphMetaType(IntegerType.INSTANCE, ValueVertex.class, - Integer.class, ValueEdge.class, IntegerType.class)) + .withSchema( + new GraphMetaType( + IntegerType.INSTANCE, + ValueVertex.class, + Integer.class, + ValueEdge.class, + IntegerType.class)) .build(); - pipeline.withView(graphName, graphViewDesc); - - pipeline.submit(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - Configuration conf = pipelineTaskCxt.getConfig(); - PWindowSource> vertices = - // extract vertex from edge file - pipelineTaskCxt.buildSource(new RecoverableFileSource<>("data/input/email_edge", + pipeline.withView(graphName, graphViewDesc); + + pipeline.submit( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { + Configuration conf = pipelineTaskCxt.getConfig(); + PWindowSource> vertices = + // extract vertex from edge file + pipelineTaskCxt + .buildSource( + new RecoverableFileSource<>( + "data/input/email_edge", + line -> { + String[] fields = line.split(","); + IVertex vertex1 = + new ValueVertex<>(Integer.valueOf(fields[0]), 1); + IVertex vertex2 = + new ValueVertex<>(Integer.valueOf(fields[1]), 1); + return Arrays.asList(vertex1, vertex2); + }), + SizeTumblingWindow.of(10000)) + .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); + + PWindowSource> edges = + pipelineTaskCxt.buildSource( + new RecoverableFileSource<>( + "data/input/email_edge", line -> { - String[] fields = line.split(","); - IVertex vertex1 = new ValueVertex<>( - Integer.valueOf(fields[0]), 1); - IVertex vertex2 = new ValueVertex<>( - Integer.valueOf(fields[1]), 1); - return Arrays.asList(vertex1, vertex2); - }), SizeTumblingWindow.of(10000)) - .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); - - PWindowSource> edges = - pipelineTaskCxt.buildSource(new RecoverableFileSource<>("data/input/email_edge", - line -> { - String[] fields = line.split(","); - IEdge edge = new ValueEdge<>(Integer.valueOf(fields[0]), - Integer.valueOf(fields[1]), 1); - return Collections.singletonList(edge); - }), SizeTumblingWindow.of(5000)); - - - PGraphView fundGraphView = - pipelineTaskCxt.getGraphView(graphName); - - PIncGraphView incGraphView = - fundGraphView.appendGraph(vertices, edges); - - PWindowSource> triggerSource = - pipelineTaskCxt.buildSource( - new CollectionSource<>(getTraversalRequests()), - SizeTumblingWindow.of(1) - ); - - int mapParallelism = pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.MAP_PARALLELISM); - int sinkParallelism = pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.SINK_PARALLELISM); - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); - incGraphView.incrementalTraversal(new IncGraphTraversalAlgorithms(3)) - .start(triggerSource) - .map(res -> String.format("%s,%s", res.getResponseId(), res.getResponse())) - .withParallelism(mapParallelism) - .sink(sink) - .withParallelism(sinkParallelism); - - } + String[] fields = line.split(","); + IEdge edge = + new ValueEdge<>( + Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); + return Collections.singletonList(edge); + }), + SizeTumblingWindow.of(5000)); + + PGraphView fundGraphView = + pipelineTaskCxt.getGraphView(graphName); + + PIncGraphView incGraphView = + fundGraphView.appendGraph(vertices, edges); + + PWindowSource> triggerSource = + pipelineTaskCxt.buildSource( + new CollectionSource<>(getTraversalRequests()), SizeTumblingWindow.of(1)); + + int mapParallelism = + pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.MAP_PARALLELISM); + int sinkParallelism = + pipelineTaskCxt.getConfig().getInteger(ExampleConfigKeys.SINK_PARALLELISM); + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); + incGraphView + .incrementalTraversal(new IncGraphTraversalAlgorithms(3)) + .start(triggerSource) + .map(res -> String.format("%s,%s", res.getResponseId(), res.getResponse())) + .withParallelism(mapParallelism) + .sink(sink) + .withParallelism(sinkParallelism); + } }); - return pipeline.execute(); - } + return pipeline.execute(); + } - private static List> getTraversalRequests() { - List> list = new ArrayList<>(); - for (int i = 0; i < 10; i++) { - list.add(new VertexBeginTraversalRequest<>(3)); - } - return list; + private static List> getTraversalRequests() { + List> list = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + list.add(new VertexBeginTraversalRequest<>(3)); } + return list; + } + + public static void validateResult() throws IOException { + ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH); + } - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH); + public static class IncGraphTraversalAlgorithms + extends IncVertexCentricTraversal { + + public IncGraphTraversalAlgorithms(long iterations) { + super(iterations); } + @Override + public IncVertexCentricTraversalFunction + getIncTraversalFunction() { + return new IncVertexCentricTraversalFunction() { - public static class IncGraphTraversalAlgorithms extends IncVertexCentricTraversal { + private IncVertexCentricTraversalFuncContext + vertexCentricFuncContext; - public IncGraphTraversalAlgorithms(long iterations) { - super(iterations); + @Override + public void open( + IncVertexCentricTraversalFuncContext + vertexCentricFuncContext) { + this.vertexCentricFuncContext = vertexCentricFuncContext; } @Override - public IncVertexCentricTraversalFunction getIncTraversalFunction() { - return new IncVertexCentricTraversalFunction() { - - private IncVertexCentricTraversalFuncContext vertexCentricFuncContext; - - @Override - public void open( - IncVertexCentricTraversalFuncContext vertexCentricFuncContext) { - this.vertexCentricFuncContext = vertexCentricFuncContext; - } - - @Override - public void evolve(Integer vertexId, - TemporaryGraph temporaryGraph) { - LOGGER.debug("evolve vId:{}", vertexId); - } - - @Override - public void init(ITraversalRequest traversalRequest) { - int requestId = traversalRequest.getVId(); - this.vertexCentricFuncContext.sendMessageToNeighbors(requestId); - } - - @Override - public void finish() { - - } - - @Override - public void close() { - - } - - @Override - public void compute(Integer vertexId, - Iterator messageIterator) { - int sum = 0; - while (messageIterator.hasNext()) { - sum += messageIterator.next(); - } - this.vertexCentricFuncContext.takeResponse(new TraversalResponse(vertexId, sum)); - LOGGER.info("take response:{}-{}", vertexId, sum); - } - - @Override - public void finish(Integer vertexId, - MutableGraph mutableGraph) { - TemporaryGraph temporaryGraph = - this.vertexCentricFuncContext.getTemporaryGraph(); - IVertex vertex = temporaryGraph.getVertex(); - if (vertex != null) { - mutableGraph.addVertex(0, vertex); - } - List> edges = temporaryGraph.getEdges(); - if (edges != null) { - for (IEdge edge : edges) { - mutableGraph.addEdge(0, edge); - } - } - } - }; + public void evolve( + Integer vertexId, TemporaryGraph temporaryGraph) { + LOGGER.debug("evolve vId:{}", vertexId); } @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; + public void init(ITraversalRequest traversalRequest) { + int requestId = traversalRequest.getVId(); + this.vertexCentricFuncContext.sendMessageToNeighbors(requestId); } - } - - static class TraversalResponse implements ITraversalResponse { - - private long responseId; - private int value; - - public TraversalResponse(long responseId, int value) { - this.responseId = responseId; - this.value = value; - } + @Override + public void finish() {} @Override - public long getResponseId() { - return responseId; - } + public void close() {} @Override - public Integer getResponse() { - return value; + public void compute(Integer vertexId, Iterator messageIterator) { + int sum = 0; + while (messageIterator.hasNext()) { + sum += messageIterator.next(); + } + this.vertexCentricFuncContext.takeResponse(new TraversalResponse(vertexId, sum)); + LOGGER.info("take response:{}-{}", vertexId, sum); } @Override - public ResponseType getType() { - return ResponseType.Vertex; + public void finish(Integer vertexId, MutableGraph mutableGraph) { + TemporaryGraph temporaryGraph = + this.vertexCentricFuncContext.getTemporaryGraph(); + IVertex vertex = temporaryGraph.getVertex(); + if (vertex != null) { + mutableGraph.addVertex(0, vertex); + } + List> edges = temporaryGraph.getEdges(); + if (edges != null) { + for (IEdge edge : edges) { + mutableGraph.addEdge(0, edge); + } + } } + }; + } + + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; + } + } + + static class TraversalResponse implements ITraversalResponse { + + private long responseId; + private int value; + + public TraversalResponse(long responseId, int value) { + this.responseId = responseId; + this.value = value; + } + + @Override + public long getResponseId() { + return responseId; } + @Override + public Integer getResponse() { + return value; + } + + @Override + public ResponseType getType() { + return ResponseType.Vertex; + } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/allshortestpath/AllShortestPath.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/allshortestpath/AllShortestPath.java index 73502c617..bf390ed55 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/allshortestpath/AllShortestPath.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/allshortestpath/AllShortestPath.java @@ -28,6 +28,7 @@ import java.util.List; import java.util.Map; import java.util.Set; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.compute.VertexCentricCompute; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -62,168 +63,182 @@ public class AllShortestPath { - private static final Logger LOGGER = LoggerFactory.getLogger(AllShortestPath.class); - - public static final String REFERENCE_FILE_PATH = "data/reference/allshortestpath"; - - public static final String RESULT_FILE_DIR = "./target/tmp/data/result/allshortestpath"; - - private static final int SOURCE_ID = 1525; - - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - IPipelineResult result = AllShortestPath.submit(environment); - PipelineResultCollect.get(result); - environment.shutdown(); + private static final Logger LOGGER = LoggerFactory.getLogger(AllShortestPath.class); + + public static final String REFERENCE_FILE_PATH = "data/reference/allshortestpath"; + + public static final String RESULT_FILE_DIR = "./target/tmp/data/result/allshortestpath"; + + private static final int SOURCE_ID = 1525; + + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + IPipelineResult result = AllShortestPath.submit(environment); + PipelineResultCollect.get(result); + environment.shutdown(); + } + + public static IPipelineResult submit(Environment environment) { + ResultValidator.cleanResult(RESULT_FILE_DIR); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); + + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + Configuration config = pipelineTaskCxt.getConfig(); + int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); + int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); + int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); + LOGGER.info( + "with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); + + FileSource>> vSource = + new FileSource<>( + GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserObjectMap); + PWindowStream>> vertices = + pipelineTaskCxt + .buildSource(vSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + FileSource>> eSource = + new FileSource<>(GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserMap); + PWindowStream>> edges = + pipelineTaskCxt + .buildSource(eSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + .withShardNum(iterationParallelism) + .withBackend(BackendType.Memory) + .build(); + PWindowStream>> result = + pipelineTaskCxt + .buildWindowStreamGraph(vertices, edges, graphViewDesc) + .compute(new AllShortestPathAlgorithms(SOURCE_ID, 10)) + .compute(iterationParallelism) + .getVertices(); + + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); + result + .map(v -> String.format("%s,%s", v.getId(), v.getValue())) + .sink(sink) + .withParallelism(sinkParallelism); + }); + + return pipeline.execute(); + } + + public static void validateResult() throws IOException { + ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + } + + public static class AllShortestPathAlgorithms + extends VertexCentricCompute< + Integer, Map, Map, Tuple>>> { + + private final int sourceId; + + public AllShortestPathAlgorithms(int sourceId, long iterations) { + super(iterations); + this.sourceId = sourceId; } - public static IPipelineResult submit(Environment environment) { - ResultValidator.cleanResult(RESULT_FILE_DIR); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); - - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - Configuration config = pipelineTaskCxt.getConfig(); - int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); - int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); - int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); - LOGGER.info("with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); - - FileSource>> vSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserObjectMap); - PWindowStream>> vertices = - pipelineTaskCxt.buildSource(vSource, AllWindow.getInstance()).withParallelism(sourceParallelism); - - FileSource>> eSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserMap); - PWindowStream>> edges = - pipelineTaskCxt.buildSource(eSource, AllWindow.getInstance()).withParallelism(sourceParallelism); - - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) - .withShardNum(iterationParallelism) - .withBackend(BackendType.Memory) - .build(); - PWindowStream>> result = - pipelineTaskCxt.buildWindowStreamGraph(vertices, edges, graphViewDesc) - .compute(new AllShortestPathAlgorithms(SOURCE_ID, 10)) - .compute(iterationParallelism) - .getVertices(); - - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); - result.map(v -> String.format("%s,%s", v.getId(), v.getValue())).sink(sink).withParallelism(sinkParallelism); - }); - - return pipeline.execute(); + @Override + public VertexCentricComputeFunction< + Integer, Map, Map, Tuple>>> + getComputeFunction() { + return new AllShortestPathVertexCentricComputeFunction(this.sourceId, EdgeDirection.OUT); } - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + @Override + public VertexCentricCombineFunction>>> getCombineFunction() { + return null; } + } - public static class AllShortestPathAlgorithms extends VertexCentricCompute, Map, Tuple>>> { + public static class AllShortestPathVertexCentricComputeFunction + extends AbstractVcFunc< + Integer, Map, Map, Tuple>>> { - private final int sourceId; - - public AllShortestPathAlgorithms(int sourceId, long iterations) { - super(iterations); - this.sourceId = sourceId; - } - - @Override - public VertexCentricComputeFunction, Map, Tuple>>> getComputeFunction() { - return new AllShortestPathVertexCentricComputeFunction(this.sourceId, EdgeDirection.OUT); - } - - @Override - public VertexCentricCombineFunction>>> getCombineFunction() { - return null; - } + private static final String KEY_FIELD = "dis"; + private static final String KEY_ALL_PATHS = "allPaths"; + private final int sourceId; + private final EdgeDirection edgeType; + public AllShortestPathVertexCentricComputeFunction(int sourceId, EdgeDirection edgeType) { + this.sourceId = sourceId; + this.edgeType = edgeType; } - public static class AllShortestPathVertexCentricComputeFunction extends AbstractVcFunc, Map, Tuple>>> { - - private static final String KEY_FIELD = "dis"; - private static final String KEY_ALL_PATHS = "allPaths"; - private final int sourceId; - private final EdgeDirection edgeType; - - public AllShortestPathVertexCentricComputeFunction(int sourceId, EdgeDirection edgeType) { - this.sourceId = sourceId; - this.edgeType = edgeType; + @Override + public void compute( + Integer vertexId, Iterator>>> messageIterator) { + IVertex> vertex = this.context.vertex().get(); + Map property = vertex.getValue(); + int dis = Integer.MAX_VALUE; + Set> path = new HashSet<>(); + if (this.context.getIterationId() == 1) { + if (vertex.getId().equals(sourceId)) { + dis = 0; + path.add(new ArrayList<>(Collections.singletonList(sourceId))); + sendMessage(dis, path); } - - @Override - public void compute(Integer vertexId, - Iterator>>> messageIterator) { - IVertex> vertex = this.context.vertex().get(); - Map property = vertex.getValue(); - int dis = Integer.MAX_VALUE; - Set> path = new HashSet<>(); - if (this.context.getIterationId() == 1) { - if (vertex.getId().equals(sourceId)) { - dis = 0; - path.add(new ArrayList<>(Collections.singletonList(sourceId))); - sendMessage(dis, path); - } - property.put(KEY_ALL_PATHS, path); - property.put(KEY_FIELD, dis); - this.context.setNewVertexValue(property); - return; - } - - dis = (int) property.get(KEY_FIELD); - Tuple>> shortestDis = messageIterator.next(); - while (messageIterator.hasNext()) { - Tuple>> msg = messageIterator.next(); - if (shortestDis.getF0() > msg.getF0()) { - shortestDis = msg; - } else if (shortestDis.getF0().equals(msg.getF0())) { - for (List l : msg.getF1()) { - shortestDis.getF1().add(new ArrayList<>(l)); - } - } - } - - if (shortestDis.getF0() <= dis) { - if (shortestDis.getF0() == dis) { - path.addAll((Collection>) property.get(KEY_ALL_PATHS)); - } - dis = shortestDis.getF0(); - property.put(KEY_FIELD, dis); - for (List p : shortestDis.getF1()) { - List list = new ArrayList<>(p); - list.add(vertex.getId()); - path.add(list); - } - property.put(KEY_ALL_PATHS, path); - this.context.setNewVertexValue(property); - sendMessage(dis, path); - } + property.put(KEY_ALL_PATHS, path); + property.put(KEY_FIELD, dis); + this.context.setNewVertexValue(property); + return; + } + + dis = (int) property.get(KEY_FIELD); + Tuple>> shortestDis = messageIterator.next(); + while (messageIterator.hasNext()) { + Tuple>> msg = messageIterator.next(); + if (shortestDis.getF0() > msg.getF0()) { + shortestDis = msg; + } else if (shortestDis.getF0().equals(msg.getF0())) { + for (List l : msg.getF1()) { + shortestDis.getF1().add(new ArrayList<>(l)); + } } + } - private void sendMessage(int dis, Set> path) { - switch (this.edgeType) { - case OUT: - for (IEdge> edge : this.context.edges() - .getOutEdges()) { - this.context.sendMessage(edge.getTargetId(), new Tuple<>(dis + edge.getValue().get(KEY_FIELD), path)); - } - break; - case IN: - for (IEdge> edge : this.context.edges() - .getInEdges()) { - this.context.sendMessage(edge.getTargetId(), new Tuple<>(dis + edge.getValue().get(KEY_FIELD), path)); - } - break; - default: - break; - } + if (shortestDis.getF0() <= dis) { + if (shortestDis.getF0() == dis) { + path.addAll((Collection>) property.get(KEY_ALL_PATHS)); } - + dis = shortestDis.getF0(); + property.put(KEY_FIELD, dis); + for (List p : shortestDis.getF1()) { + List list = new ArrayList<>(p); + list.add(vertex.getId()); + path.add(list); + } + property.put(KEY_ALL_PATHS, path); + this.context.setNewVertexValue(property); + sendMessage(dis, path); + } } + private void sendMessage(int dis, Set> path) { + switch (this.edgeType) { + case OUT: + for (IEdge> edge : this.context.edges().getOutEdges()) { + this.context.sendMessage( + edge.getTargetId(), new Tuple<>(dis + edge.getValue().get(KEY_FIELD), path)); + } + break; + case IN: + for (IEdge> edge : this.context.edges().getInEdges()) { + this.context.sendMessage( + edge.getTargetId(), new Tuple<>(dis + edge.getValue().get(KEY_FIELD), path)); + } + break; + default: + break; + } + } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/averagedegree/AverageDegree.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/averagedegree/AverageDegree.java index c4ab6c281..87c931970 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/averagedegree/AverageDegree.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/averagedegree/AverageDegree.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.Collections; import java.util.Iterator; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.PGraphWindow; import org.apache.geaflow.api.graph.compute.VertexCentricAggCompute; @@ -56,199 +57,237 @@ public class AverageDegree { - private static final Logger LOGGER = LoggerFactory.getLogger(AverageDegree.class); - - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/average"; - public static final String REF_FILE_PATH = "data/reference/averagedegree"; - - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - IPipelineResult result = AverageDegree.submit(environment); - PipelineResultCollect.get(result); - environment.shutdown(); + private static final Logger LOGGER = LoggerFactory.getLogger(AverageDegree.class); + + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/average"; + public static final String REF_FILE_PATH = "data/reference/averagedegree"; + + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + IPipelineResult result = AverageDegree.submit(environment); + PipelineResultCollect.get(result); + environment.shutdown(); + } + + public static IPipelineResult submit(Environment environment) { + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + ResultValidator.cleanResult(RESULT_FILE_PATH); + + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + Configuration conf = pipelineTaskCxt.getConfig(); + PWindowSource> prVertices = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_vertex", + line -> { + String[] fields = line.split(","); + IVertex vertex = + new ValueVertex<>( + Integer.valueOf(fields[0]), Double.valueOf(fields[1])); + return Collections.singletonList(vertex); + }), + AllWindow.getInstance()) + .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); + + PWindowSource> prEdges = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_edge", + line -> { + String[] fields = line.split(","); + IEdge edge = + new ValueEdge<>( + Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); + return Collections.singletonList(edge); + }), + AllWindow.getInstance()) + .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); + + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + .withShardNum(2) + .withBackend(BackendType.Memory) + .build(); + + PGraphWindow graphWindow = + pipelineTaskCxt.buildWindowStreamGraph(prVertices, prEdges, graphViewDesc); + + SinkFunction> sink = + ExampleSinkFunctionFactory.getSinkFunction(conf); + graphWindow + .compute(new AverageDegreeAlgorithms(3)) + .compute(conf.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM)) + .getVertices() + .sink(sink) + .withParallelism(conf.getInteger(ExampleConfigKeys.SINK_PARALLELISM)); + }); + + return pipeline.execute(); + } + + public static void validateResult() throws IOException { + ResultValidator.validateResult(REF_FILE_PATH, RESULT_FILE_PATH); + } + + public static class AverageDegreeAlgorithms + extends VertexCentricAggCompute< + Integer, + Double, + Integer, + Integer, + Integer, + Tuple, + Tuple, + Tuple, + Integer> { + + public AverageDegreeAlgorithms(long iterations) { + super(iterations); } - public static IPipelineResult submit(Environment environment) { - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - ResultValidator.cleanResult(RESULT_FILE_PATH); - - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - Configuration conf = pipelineTaskCxt.getConfig(); - PWindowSource> prVertices = - pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_vertex", - line -> { - String[] fields = line.split(","); - IVertex vertex = new ValueVertex<>( - Integer.valueOf(fields[0]), Double.valueOf(fields[1])); - return Collections.singletonList(vertex); - }), AllWindow.getInstance()) - .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); - - PWindowSource> prEdges = pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_edge", - line -> { - String[] fields = line.split(","); - IEdge edge = new ValueEdge<>(Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); - return Collections.singletonList(edge); - }), AllWindow.getInstance()) - .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); - - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) - .withShardNum(2) - .withBackend(BackendType.Memory) - .build(); - - PGraphWindow graphWindow = - pipelineTaskCxt.buildWindowStreamGraph(prVertices, prEdges, graphViewDesc); - - SinkFunction> sink = ExampleSinkFunctionFactory.getSinkFunction(conf); - graphWindow.compute(new AverageDegreeAlgorithms(3)) - .compute(conf.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM)) - .getVertices() - .sink(sink) - .withParallelism(conf.getInteger(ExampleConfigKeys.SINK_PARALLELISM)); - }); - - return pipeline.execute(); - } + @Override + public VertexCentricAggComputeFunction + getComputeFunction() { + return new VertexCentricAggComputeFunction< + Integer, Double, Integer, Integer, Integer, Integer>() { - public static void validateResult() throws IOException { - ResultValidator.validateResult(REF_FILE_PATH, RESULT_FILE_PATH); - } + private VertexCentricComputeFuncContext + vertexCentricFuncContext; + private VertexCentricAggContext aggContext; - public static class AverageDegreeAlgorithms extends VertexCentricAggCompute, Tuple, - Tuple, Integer> { + @Override + public void init( + VertexCentricComputeFuncContext + vertexCentricFuncContext) { + this.vertexCentricFuncContext = vertexCentricFuncContext; + } - public AverageDegreeAlgorithms(long iterations) { - super(iterations); + @Override + public void initContext(VertexCentricAggContext aggContext) { + this.aggContext = aggContext; } @Override - public VertexCentricAggComputeFunction getComputeFunction() { - return new VertexCentricAggComputeFunction() { - - private VertexCentricComputeFuncContext vertexCentricFuncContext; - private VertexCentricAggContext aggContext; - - @Override - public void init(VertexCentricComputeFuncContext vertexCentricFuncContext) { - this.vertexCentricFuncContext = vertexCentricFuncContext; - } - - @Override - public void initContext(VertexCentricAggContext aggContext) { - this.aggContext = aggContext; - } - - - @Override - public void compute(Integer vertex, Iterator messageIterator) { - if (vertexCentricFuncContext.getIterationId() == 1) { - int degreeSize = vertexCentricFuncContext.edges().getOutEdges().size(); - vertexCentricFuncContext.setNewVertexValue(Double.valueOf(degreeSize)); - aggContext.aggregate(degreeSize); - vertexCentricFuncContext.sendMessage(vertex, degreeSize); - } else { - int sum = 0; - while (messageIterator.hasNext()) { - sum += messageIterator.next(); - } - aggContext.aggregate(sum); - vertexCentricFuncContext.setNewVertexValue(Double.valueOf(aggContext.getAggregateResult())); - } - } - - @Override - public void finish() { - - } - }; + public void compute(Integer vertex, Iterator messageIterator) { + if (vertexCentricFuncContext.getIterationId() == 1) { + int degreeSize = vertexCentricFuncContext.edges().getOutEdges().size(); + vertexCentricFuncContext.setNewVertexValue(Double.valueOf(degreeSize)); + aggContext.aggregate(degreeSize); + vertexCentricFuncContext.sendMessage(vertex, degreeSize); + } else { + int sum = 0; + while (messageIterator.hasNext()) { + sum += messageIterator.next(); + } + aggContext.aggregate(sum); + vertexCentricFuncContext.setNewVertexValue( + Double.valueOf(aggContext.getAggregateResult())); + } } @Override - public VertexCentricAggregateFunction, Tuple, Tuple, Integer> getAggregateFunction() { - return new VertexCentricAggregateFunction, - Tuple, Tuple, Integer>() { - @Override - public IPartialGraphAggFunction, Tuple> getPartialAggregation() { - return new IPartialGraphAggFunction, Tuple>() { - - private IPartialAggContext> partialAggContext; - - @Override - public Tuple create( - IPartialAggContext> partialAggContext) { - this.partialAggContext = partialAggContext; - return Tuple.of(0, 0); - } - - - @Override - public Tuple aggregate(Integer integer, Tuple result) { - result.f0 += 1; - result.f1 += integer; - return result; - } - - @Override - public void finish(Tuple result) { - partialAggContext.collect(result); - } - }; - } - - @Override - public IGraphAggregateFunction, Tuple, - Integer> getGlobalAggregation() { - return new IGraphAggregateFunction, Tuple, Integer>() { - - private IGlobalGraphAggContext globalGraphAggContext; - - @Override - public Tuple create( - IGlobalGraphAggContext globalGraphAggContext) { - this.globalGraphAggContext = globalGraphAggContext; - return Tuple.of(0, 0); - } - - @Override - public Integer aggregate(Tuple integerIntegerTuple2, - Tuple integerIntegerTuple22) { - integerIntegerTuple22.f0 += integerIntegerTuple2.f0; - integerIntegerTuple22.f1 += integerIntegerTuple2.f1; - return (int) (integerIntegerTuple22.f1 / integerIntegerTuple22.f0); - } - - @Override - public void finish(Integer value) { - long iterationId = this.globalGraphAggContext.getIteration(); - if (value == 0) { - LOGGER.info("current iterationId:{} value is {}, do terminate", iterationId, value); - this.globalGraphAggContext.terminate(); - } else { - LOGGER.info("current iterationId:{} value is {}, do broadcast", iterationId, value); - this.globalGraphAggContext.broadcast(value); - } - } - }; - } - }; + public void finish() {} + }; + } + + @Override + public VertexCentricAggregateFunction< + Integer, + Tuple, + Tuple, + Tuple, + Integer> + getAggregateFunction() { + return new VertexCentricAggregateFunction< + Integer, + Tuple, + Tuple, + Tuple, + Integer>() { + @Override + public IPartialGraphAggFunction, Tuple> + getPartialAggregation() { + return new IPartialGraphAggFunction< + Integer, Tuple, Tuple>() { + + private IPartialAggContext> partialAggContext; + + @Override + public Tuple create( + IPartialAggContext> partialAggContext) { + this.partialAggContext = partialAggContext; + return Tuple.of(0, 0); + } + + @Override + public Tuple aggregate( + Integer integer, Tuple result) { + result.f0 += 1; + result.f1 += integer; + return result; + } + + @Override + public void finish(Tuple result) { + partialAggContext.collect(result); + } + }; } @Override - public VertexCentricCombineFunction getCombineFunction() { - return new VertexCentricCombineFunction() { - @Override - public Integer combine(Integer oldMessage, Integer newMessage) { - return oldMessage + newMessage; - } - }; + public IGraphAggregateFunction, Tuple, Integer> + getGlobalAggregation() { + return new IGraphAggregateFunction< + Tuple, Tuple, Integer>() { + + private IGlobalGraphAggContext globalGraphAggContext; + + @Override + public Tuple create( + IGlobalGraphAggContext globalGraphAggContext) { + this.globalGraphAggContext = globalGraphAggContext; + return Tuple.of(0, 0); + } + + @Override + public Integer aggregate( + Tuple integerIntegerTuple2, + Tuple integerIntegerTuple22) { + integerIntegerTuple22.f0 += integerIntegerTuple2.f0; + integerIntegerTuple22.f1 += integerIntegerTuple2.f1; + return (int) (integerIntegerTuple22.f1 / integerIntegerTuple22.f0); + } + + @Override + public void finish(Integer value) { + long iterationId = this.globalGraphAggContext.getIteration(); + if (value == 0) { + LOGGER.info("current iterationId:{} value is {}, do terminate", iterationId, value); + this.globalGraphAggContext.terminate(); + } else { + LOGGER.info("current iterationId:{} value is {}, do broadcast", iterationId, value); + this.globalGraphAggContext.broadcast(value); + } + } + }; + } + }; + } + + @Override + public VertexCentricCombineFunction getCombineFunction() { + return new VertexCentricCombineFunction() { + @Override + public Integer combine(Integer oldMessage, Integer newMessage) { + return oldMessage + newMessage; } + }; } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/closenesscentrality/ClosenessCentrality.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/closenesscentrality/ClosenessCentrality.java index cb9c915c7..1534dfabb 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/closenesscentrality/ClosenessCentrality.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/closenesscentrality/ClosenessCentrality.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Iterator; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.compute.VertexCentricCompute; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -54,127 +55,137 @@ public class ClosenessCentrality { - private static final Logger LOGGER = LoggerFactory.getLogger(ClosenessCentrality.class); - public static final int SOURCE_ID = 11342; - - public static final String REFERENCE_FILE_PATH = "data/reference/closeness"; - - public static final String RESULT_FILE_DIR = "./target/tmp/data/result/closeness"; - - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - IPipelineResult result = ClosenessCentrality.submit(environment); - PipelineResultCollect.get(result); - environment.shutdown(); + private static final Logger LOGGER = LoggerFactory.getLogger(ClosenessCentrality.class); + public static final int SOURCE_ID = 11342; + + public static final String REFERENCE_FILE_PATH = "data/reference/closeness"; + + public static final String RESULT_FILE_DIR = "./target/tmp/data/result/closeness"; + + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + IPipelineResult result = ClosenessCentrality.submit(environment); + PipelineResultCollect.get(result); + environment.shutdown(); + } + + public static IPipelineResult submit(Environment environment) { + ResultValidator.cleanResult(RESULT_FILE_DIR); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); + + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + Configuration config = pipelineTaskCxt.getConfig(); + int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); + int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); + int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); + LOGGER.info( + "with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); + + FileSource>> vSource = + new FileSource<>(GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserTuple); + PWindowStream>> vertices = + pipelineTaskCxt + .buildSource(vSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + FileSource> eSource = + new FileSource<>(GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserInteger); + PWindowStream> edges = + pipelineTaskCxt + .buildSource(eSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + .withShardNum(iterationParallelism) + .withBackend(BackendType.Memory) + .build(); + PWindowStream>> result = + pipelineTaskCxt + .buildWindowStreamGraph(vertices, edges, graphViewDesc) + .compute(new ClosenessCentralityAlgorithm(SOURCE_ID, 10)) + .compute(iterationParallelism) + .getVertices(); + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); + result + .map(v -> String.format("%s,%s", v.getId(), v.getValue())) + .sink(sink) + .withParallelism(sinkParallelism); + }); + + return pipeline.execute(); + } + + public static void validateResult() throws IOException { + ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + } + + public static class ClosenessCentralityAlgorithm + extends VertexCentricCompute, Integer, Integer> { + + private final int srcId; + + public ClosenessCentralityAlgorithm(int srcId, long iterations) { + super(iterations); + this.srcId = srcId; } - public static IPipelineResult submit(Environment environment) { - ResultValidator.cleanResult(RESULT_FILE_DIR); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); - - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - Configuration config = pipelineTaskCxt.getConfig(); - int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); - int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); - int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); - LOGGER.info("with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); - - FileSource>> vSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserTuple); - PWindowStream>> vertices = - pipelineTaskCxt.buildSource(vSource, AllWindow.getInstance()).withParallelism(sourceParallelism); - - FileSource> eSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserInteger); - PWindowStream> edges = - pipelineTaskCxt.buildSource(eSource, AllWindow.getInstance()).withParallelism(sourceParallelism); - - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) - .withShardNum(iterationParallelism) - .withBackend(BackendType.Memory) - .build(); - PWindowStream>> result = - pipelineTaskCxt.buildWindowStreamGraph(vertices, edges, graphViewDesc) - .compute(new ClosenessCentralityAlgorithm(SOURCE_ID, 10)) - .compute(iterationParallelism) - .getVertices(); - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); - result.map(v -> String.format("%s,%s", v.getId(), v.getValue())) - .sink(sink).withParallelism(sinkParallelism); - }); - - return pipeline.execute(); + @Override + public VertexCentricComputeFunction, Integer, Integer> + getComputeFunction() { + return new ClosenessCentralityVCCFunction(this.srcId); } - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; } + } - public static class ClosenessCentralityAlgorithm extends VertexCentricCompute, Integer, Integer> { + public static class ClosenessCentralityVCCFunction + extends AbstractVcFunc, Integer, Integer> { - private final int srcId; - - public ClosenessCentralityAlgorithm(int srcId, long iterations) { - super(iterations); - this.srcId = srcId; - } - - @Override - public VertexCentricComputeFunction, Integer, Integer> getComputeFunction() { - return new ClosenessCentralityVCCFunction(this.srcId); - } - - @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; - } + private final int srcId; + public ClosenessCentralityVCCFunction(int srcId) { + this.srcId = srcId; } - public static class ClosenessCentralityVCCFunction extends AbstractVcFunc, Integer, Integer> { - - private final int srcId; - - public ClosenessCentralityVCCFunction(int srcId) { - this.srcId = srcId; + @Override + public void compute(Integer vertexId, Iterator messageIterator) { + IVertex> vertex = this.context.vertex().get(); + if (this.context.getIterationId() == 1) { + this.context.setNewVertexValue(Tuple.of(0.0, 0)); + if (vertex.getId().equals(this.srcId)) { + this.context.setNewVertexValue(Tuple.of(1.0, 0)); + this.context.sendMessageToNeighbors(1); } - - @Override - public void compute(Integer vertexId, Iterator messageIterator) { - IVertex> vertex = this.context.vertex().get(); - if (this.context.getIterationId() == 1) { - this.context.setNewVertexValue(Tuple.of(0.0, 0)); - if (vertex.getId().equals(this.srcId)) { - this.context.setNewVertexValue(Tuple.of(1.0, 0)); - this.context.sendMessageToNeighbors(1); - } - } else { - if (vertex.getId().equals(this.srcId)) { - int vertexNum = vertex.getValue().getF1(); - double sum = vertexNum / vertex.getValue().getF0(); - while (messageIterator.hasNext()) { - sum += messageIterator.next(); - vertexNum++; - } - this.context.setNewVertexValue(Tuple.of(vertexNum / sum, vertexNum)); - } else { - if (vertex.getValue().getF1() < 1) { - Integer msg = messageIterator.next(); - this.context.sendMessage(this.srcId, msg); - for (IEdge edge : this.context.edges().getOutEdges()) { - if (!edge.getTargetId().equals(this.srcId)) { - this.context.sendMessage(edge.getTargetId(), msg + 1); - } - } - this.context.setNewVertexValue(Tuple.of(0.0, 1)); - } - } + } else { + if (vertex.getId().equals(this.srcId)) { + int vertexNum = vertex.getValue().getF1(); + double sum = vertexNum / vertex.getValue().getF0(); + while (messageIterator.hasNext()) { + sum += messageIterator.next(); + vertexNum++; + } + this.context.setNewVertexValue(Tuple.of(vertexNum / sum, vertexNum)); + } else { + if (vertex.getValue().getF1() < 1) { + Integer msg = messageIterator.next(); + this.context.sendMessage(this.srcId, msg); + for (IEdge edge : this.context.edges().getOutEdges()) { + if (!edge.getTargetId().equals(this.srcId)) { + this.context.sendMessage(edge.getTargetId(), msg + 1); + } } + this.context.setNewVertexValue(Tuple.of(0.0, 1)); + } } - + } } - + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/clustercoefficient/ClusterCoefficient.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/clustercoefficient/ClusterCoefficient.java index 0cd3c8b5f..ce8255286 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/clustercoefficient/ClusterCoefficient.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/clustercoefficient/ClusterCoefficient.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Iterator; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.compute.VertexCentricCompute; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -53,135 +54,145 @@ public class ClusterCoefficient { - private static final Logger LOGGER = LoggerFactory.getLogger(ClusterCoefficient.class); - - public static final int SOURCE_ID = 11342; - - public static final String REFERENCE_FILE_PATH = "data/reference/clustercoefficient"; - - public static final String RESULT_FILE_DIR = "./target/tmp/data/result/clustercoefficient"; - - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - IPipelineResult result = ClusterCoefficient.submit(environment); - PipelineResultCollect.get(result); - environment.shutdown(); + private static final Logger LOGGER = LoggerFactory.getLogger(ClusterCoefficient.class); + + public static final int SOURCE_ID = 11342; + + public static final String REFERENCE_FILE_PATH = "data/reference/clustercoefficient"; + + public static final String RESULT_FILE_DIR = "./target/tmp/data/result/clustercoefficient"; + + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + IPipelineResult result = ClusterCoefficient.submit(environment); + PipelineResultCollect.get(result); + environment.shutdown(); + } + + public static IPipelineResult submit(Environment environment) { + ResultValidator.cleanResult(RESULT_FILE_DIR); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); + + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + Configuration config = pipelineTaskCxt.getConfig(); + int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); + int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); + int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); + LOGGER.info( + "with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); + + FileSource> vSource = + new FileSource<>(GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserDouble); + PWindowStream> vertices = + pipelineTaskCxt + .buildSource(vSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + FileSource> eSource = + new FileSource<>(GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserInteger); + PWindowStream> edges = + pipelineTaskCxt + .buildSource(eSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + .withShardNum(iterationParallelism) + .withBackend(BackendType.Memory) + .build(); + PWindowStream> result = + pipelineTaskCxt + .buildWindowStreamGraph(vertices, edges, graphViewDesc) + .compute(new ClusterCoefficientAlgorithm(SOURCE_ID, 10)) + .compute(iterationParallelism) + .getVertices(); + + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); + result + .map(v -> String.format("%s,%s", v.getId(), v.getValue())) + .sink(sink) + .withParallelism(sinkParallelism); + }); + + return pipeline.execute(); + } + + public static void validateResult() throws IOException { + ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + } + + public static class ClusterCoefficientAlgorithm + extends VertexCentricCompute { + + private final int srcId; + + public ClusterCoefficientAlgorithm(int srcId, long iterations) { + super(iterations); + this.srcId = srcId; } - public static IPipelineResult submit(Environment environment) { - ResultValidator.cleanResult(RESULT_FILE_DIR); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); - - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - Configuration config = pipelineTaskCxt.getConfig(); - int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); - int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); - int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); - LOGGER.info("with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); - - FileSource> vSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserDouble); - PWindowStream> vertices = - pipelineTaskCxt.buildSource(vSource, AllWindow.getInstance()).withParallelism(sourceParallelism); - - FileSource> eSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserInteger); - PWindowStream> edges = - pipelineTaskCxt.buildSource(eSource, AllWindow.getInstance()).withParallelism(sourceParallelism); - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) - .withShardNum(iterationParallelism) - .withBackend(BackendType.Memory) - .build(); - PWindowStream> result = - pipelineTaskCxt.buildWindowStreamGraph(vertices, edges, graphViewDesc) - .compute(new ClusterCoefficientAlgorithm(SOURCE_ID, 10)) - .compute(iterationParallelism) - .getVertices(); - - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); - result.map(v -> String.format("%s,%s", v.getId(), v.getValue())) - .sink(sink).withParallelism(sinkParallelism); - }); - - return pipeline.execute(); + @Override + public VertexCentricComputeFunction getComputeFunction() { + return new ClusterCoefficientVCCFunction(this.srcId); } - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; } + } - public static class ClusterCoefficientAlgorithm extends VertexCentricCompute { + public static class ClusterCoefficientVCCFunction + extends AbstractVcFunc { - private final int srcId; - - public ClusterCoefficientAlgorithm(int srcId, long iterations) { - super(iterations); - this.srcId = srcId; - } - - @Override - public VertexCentricComputeFunction getComputeFunction() { - return new ClusterCoefficientVCCFunction(this.srcId); - } - - @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; - } + private final int srcId; + public ClusterCoefficientVCCFunction(int srcId) { + this.srcId = srcId; } - public static class ClusterCoefficientVCCFunction extends AbstractVcFunc { - - private final int srcId; - - public ClusterCoefficientVCCFunction(int srcId) { - this.srcId = srcId; + @Override + public void compute(Integer vertexId, Iterator messageIterator) { + IVertex vertex = this.context.vertex().get(); + if (this.context.getIterationId() == 1L) { + if (vertex.getId().equals(this.srcId)) { + sendMessage(1); } - - @Override - public void compute(Integer vertexId, Iterator messageIterator) { - IVertex vertex = this.context.vertex().get(); - if (this.context.getIterationId() == 1L) { - if (vertex.getId().equals(this.srcId)) { - sendMessage(1); - } - } else if (this.context.getIterationId() == 2L) { - sendMessage(1); - } else if (this.context.getIterationId() == 3L) { - if (!vertex.getId().equals(this.srcId)) { - int meg = 0; - while (messageIterator.hasNext()) { - messageIterator.next(); - meg++; - } - sendMessage(meg); - } - } else if (this.context.getIterationId() == 4L) { - if (vertex.getId().equals(this.srcId)) { - int edgeNum = 0; - while (messageIterator.hasNext()) { - int value = messageIterator.next(); - edgeNum += value; - } - int degree = this.context.edges().getInEdges().size() + this.context.edges().getOutEdges().size(); - this.context.setNewVertexValue(((double) edgeNum) / (degree * (degree - 1))); - } - } + } else if (this.context.getIterationId() == 2L) { + sendMessage(1); + } else if (this.context.getIterationId() == 3L) { + if (!vertex.getId().equals(this.srcId)) { + int meg = 0; + while (messageIterator.hasNext()) { + messageIterator.next(); + meg++; + } + sendMessage(meg); } - - private void sendMessage(int meg) { - for (IEdge edge : this.context.edges().getInEdges()) { - this.context.sendMessage(edge.getTargetId(), meg); - } - for (IEdge edge : this.context.edges().getOutEdges()) { - this.context.sendMessage(edge.getTargetId(), meg); - } + } else if (this.context.getIterationId() == 4L) { + if (vertex.getId().equals(this.srcId)) { + int edgeNum = 0; + while (messageIterator.hasNext()) { + int value = messageIterator.next(); + edgeNum += value; + } + int degree = + this.context.edges().getInEdges().size() + this.context.edges().getOutEdges().size(); + this.context.setNewVertexValue(((double) edgeNum) / (degree * (degree - 1))); } - + } } + private void sendMessage(int meg) { + for (IEdge edge : this.context.edges().getInEdges()) { + this.context.sendMessage(edge.getTargetId(), meg); + } + for (IEdge edge : this.context.edges().getOutEdges()) { + this.context.sendMessage(edge.getTargetId(), meg); + } + } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/commonneighbors/CommonNeighbors.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/commonneighbors/CommonNeighbors.java index 1d5d89b91..9a2434e32 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/commonneighbors/CommonNeighbors.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/commonneighbors/CommonNeighbors.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Iterator; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.compute.VertexCentricCompute; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -53,127 +54,136 @@ public class CommonNeighbors { - private static final Logger LOGGER = LoggerFactory.getLogger(CommonNeighbors.class); - - public static final String REFERENCE_FILE_PATH = "data/reference/commonneighbor"; - - public static final String RESULT_FILE_DIR = "./target/tmp/data/result/commonneighbor"; - - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - IPipelineResult result = CommonNeighbors.submit(environment); - PipelineResultCollect.get(result); - environment.shutdown(); + private static final Logger LOGGER = LoggerFactory.getLogger(CommonNeighbors.class); + + public static final String REFERENCE_FILE_PATH = "data/reference/commonneighbor"; + + public static final String RESULT_FILE_DIR = "./target/tmp/data/result/commonneighbor"; + + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + IPipelineResult result = CommonNeighbors.submit(environment); + PipelineResultCollect.get(result); + environment.shutdown(); + } + + public static IPipelineResult submit(Environment environment) { + ResultValidator.cleanResult(RESULT_FILE_DIR); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); + + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + Configuration config = pipelineTaskCxt.getConfig(); + int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); + int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); + int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); + LOGGER.info( + "with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); + + FileSource> vSource = + new FileSource<>( + GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserInteger); + PWindowStream> vertices = + pipelineTaskCxt + .buildSource(vSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + FileSource> eSource = + new FileSource<>(GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserInteger); + PWindowStream> edges = + pipelineTaskCxt + .buildSource(eSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + .withShardNum(iterationParallelism) + .withBackend(BackendType.Memory) + .build(); + PWindowStream> result = + pipelineTaskCxt + .buildWindowStreamGraph(vertices, edges, graphViewDesc) + .compute(new CommonNeighborsAlgorithm(398191, 722800, 50)) + .compute(iterationParallelism) + .getVertices(); + + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); + result + .map(v -> String.format("%s,%s", v.getId(), v.getValue())) + .sink(sink) + .withParallelism(sinkParallelism); + }); + + return pipeline.execute(); + } + + public static void validateResult() throws IOException { + ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + } + + public static class CommonNeighborsAlgorithm + extends VertexCentricCompute { + + private final int id1; + private final int id2; + + public CommonNeighborsAlgorithm(int id1, int id2, long iterations) { + super(iterations); + this.id1 = id1; + this.id2 = id2; } - public static IPipelineResult submit(Environment environment) { - ResultValidator.cleanResult(RESULT_FILE_DIR); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); - - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - Configuration config = pipelineTaskCxt.getConfig(); - int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); - int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); - int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); - LOGGER.info("with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); - - FileSource> vSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserInteger); - PWindowStream> vertices = - pipelineTaskCxt.buildSource(vSource, AllWindow.getInstance()).withParallelism(sourceParallelism); - - FileSource> eSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserInteger); - PWindowStream> edges = - pipelineTaskCxt.buildSource(eSource, AllWindow.getInstance()).withParallelism(sourceParallelism); - - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) - .withShardNum(iterationParallelism) - .withBackend(BackendType.Memory) - .build(); - PWindowStream> result = - pipelineTaskCxt.buildWindowStreamGraph(vertices, edges, graphViewDesc) - .compute(new CommonNeighborsAlgorithm(398191, 722800, 50)) - .compute(iterationParallelism) - .getVertices(); - - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); - result.map(v -> String.format("%s,%s", v.getId(), v.getValue())) - .sink(sink).withParallelism(sinkParallelism); - }); - - return pipeline.execute(); + @Override + public VertexCentricComputeFunction getComputeFunction() { + return new CommonNeighborsVCCFunction(this.id1, this.id2); } - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; } + } + public static class CommonNeighborsVCCFunction + extends AbstractVcFunc { - public static class CommonNeighborsAlgorithm extends VertexCentricCompute { - - private final int id1; - private final int id2; - - public CommonNeighborsAlgorithm(int id1, int id2, long iterations) { - super(iterations); - this.id1 = id1; - this.id2 = id2; - } - - @Override - public VertexCentricComputeFunction getComputeFunction() { - return new CommonNeighborsVCCFunction(this.id1, this.id2); - } - - @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; - } + private final int vertexId1; + private final int vertexId2; + public CommonNeighborsVCCFunction(int vertexId1, int vertexId2) { + this.vertexId1 = vertexId1; + this.vertexId2 = vertexId2; } - public static class CommonNeighborsVCCFunction extends AbstractVcFunc { - - private final int vertexId1; - private final int vertexId2; - - public CommonNeighborsVCCFunction(int vertexId1, int vertexId2) { - this.vertexId1 = vertexId1; - this.vertexId2 = vertexId2; + @Override + public void compute(Integer vertexId, Iterator messageIterator) { + IVertex vertex = this.context.vertex().get(); + if (this.context.getIterationId() == 1) { + this.context.setNewVertexValue(0); + if (vertex.getId().equals(this.vertexId1) || vertex.getId().equals(this.vertexId2)) { + sendMessage(vertex.getId()); } - - @Override - public void compute(Integer vertexId, Iterator messageIterator) { - IVertex vertex = this.context.vertex().get(); - if (this.context.getIterationId() == 1) { - this.context.setNewVertexValue(0); - if (vertex.getId().equals(this.vertexId1) || vertex.getId().equals(this.vertexId2)) { - sendMessage(vertex.getId()); - } - } else { - int meg = messageIterator.next(); - while (messageIterator.hasNext()) { - if (meg != messageIterator.next()) { - this.context.setNewVertexValue(1); - break; - } - } - } + } else { + int meg = messageIterator.next(); + while (messageIterator.hasNext()) { + if (meg != messageIterator.next()) { + this.context.setNewVertexValue(1); + break; + } } - - private void sendMessage(int meg) { - for (IEdge edge : this.context.edges().getInEdges()) { - this.context.sendMessage(edge.getTargetId(), meg); - } - for (IEdge edge : this.context.edges().getOutEdges()) { - this.context.sendMessage(edge.getTargetId(), meg); - } - } - + } } + private void sendMessage(int meg) { + for (IEdge edge : this.context.edges().getInEdges()) { + this.context.sendMessage(edge.getTargetId(), meg); + } + for (IEdge edge : this.context.edges().getOutEdges()) { + this.context.sendMessage(edge.getTargetId(), meg); + } + } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/kcore/KCore.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/kcore/KCore.java index 2d0533fa7..d67262160 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/kcore/KCore.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/kcore/KCore.java @@ -23,6 +23,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.compute.VertexCentricCompute; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -55,137 +56,145 @@ public class KCore { - private static final Logger LOGGER = LoggerFactory.getLogger(KCore.class); - - public static final String REFERENCE_FILE_PATH = "data/reference/kcore"; - - public static final String RESULT_FILE_DIR = "./target/tmp/data/result/kcore"; - - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - IPipelineResult result = KCore.submit(environment); - PipelineResultCollect.get(result); - environment.shutdown(); + private static final Logger LOGGER = LoggerFactory.getLogger(KCore.class); + + public static final String REFERENCE_FILE_PATH = "data/reference/kcore"; + + public static final String RESULT_FILE_DIR = "./target/tmp/data/result/kcore"; + + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + IPipelineResult result = KCore.submit(environment); + PipelineResultCollect.get(result); + environment.shutdown(); + } + + public static IPipelineResult submit(Environment environment) { + ResultValidator.cleanResult(RESULT_FILE_DIR); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); + + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + Configuration config = pipelineTaskCxt.getConfig(); + int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); + int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); + int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); + LOGGER.info( + "with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); + + FileSource> vSource = + new FileSource<>( + GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserBoolean); + PWindowStream> vertices = + pipelineTaskCxt + .buildSource(vSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + FileSource> eSource = + new FileSource<>(GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserBoolean); + PWindowStream> edges = + pipelineTaskCxt + .buildSource(eSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + .withShardNum(iterationParallelism) + .withBackend(BackendType.Memory) + .build(); + PWindowStream> result = + pipelineTaskCxt + .buildWindowStreamGraph(vertices, edges, graphViewDesc) + .compute(new KCoreAlgorithm(2, 50)) + .compute(iterationParallelism) + .getVertices(); + + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); + result + .map(v -> String.format("%s,%s", v.getId(), v.getValue())) + .sink(sink) + .withParallelism(sinkParallelism); + }); + + return pipeline.execute(); + } + + public static void validateResult() throws IOException { + ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + } + + public static class KCoreAlgorithm + extends VertexCentricCompute { + + private final int core; + + public KCoreAlgorithm(int core, long iterations) { + super(iterations); + this.core = core; } - public static IPipelineResult submit(Environment environment) { - ResultValidator.cleanResult(RESULT_FILE_DIR); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); - - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - Configuration config = pipelineTaskCxt.getConfig(); - int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); - int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); - int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); - LOGGER.info("with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); - - FileSource> vSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserBoolean); - PWindowStream> vertices = - pipelineTaskCxt.buildSource(vSource, AllWindow.getInstance()).withParallelism(sourceParallelism); - - FileSource> eSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserBoolean); - PWindowStream> edges = - pipelineTaskCxt.buildSource(eSource, AllWindow.getInstance()).withParallelism(sourceParallelism); - - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) - .withShardNum(iterationParallelism) - .withBackend(BackendType.Memory) - .build(); - PWindowStream> result = - pipelineTaskCxt.buildWindowStreamGraph(vertices, edges, graphViewDesc) - .compute(new KCoreAlgorithm(2, 50)) - .compute(iterationParallelism) - .getVertices(); - - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); - result.map(v -> String.format("%s,%s", v.getId(), v.getValue())) - .sink(sink).withParallelism(sinkParallelism); - }); - - return pipeline.execute(); + @Override + public VertexCentricComputeFunction getComputeFunction() { + return new KCoreVCCFunction(core); } - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; } + } - public static class KCoreAlgorithm extends VertexCentricCompute { + public static class KCoreVCCFunction extends AbstractVcFunc { - private final int core; - - public KCoreAlgorithm(int core, long iterations) { - super(iterations); - this.core = core; - } - - @Override - public VertexCentricComputeFunction getComputeFunction() { - return new KCoreVCCFunction(core); - } - - @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; - } + private final int core; + public KCoreVCCFunction(int core) { + this.core = core; } - public static class KCoreVCCFunction extends AbstractVcFunc { - - private final int core; - - public KCoreVCCFunction(int core) { - this.core = core; + @Override + public void compute(Integer vertexId, Iterator messageIterator) { + IVertex vertex = this.context.vertex().get(); + if (this.context.getIterationId() == 1) { + this.context.setNewVertexValue(true); + } else { + if (!vertex.getValue()) { + return; } - - @Override - public void compute(Integer vertexId, Iterator messageIterator) { - IVertex vertex = this.context.vertex().get(); - if (this.context.getIterationId() == 1) { - this.context.setNewVertexValue(true); - } else { - if (!vertex.getValue()) { - return; - } - List removedVertexIds = new ArrayList<>(); - while (messageIterator.hasNext()) { - removedVertexIds.add(messageIterator.next()); - } - List> outEdges = this.context.edges().getOutEdges(); - for (IEdge outEdge : outEdges) { - if (removedVertexIds.contains(outEdge.getTargetId())) { - outEdge.withValue(false); - } - } - } - boolean vertexAvailable = vertexEdgeAvailableCheck(); - if (!vertexAvailable) { - this.context.sendMessageToNeighbors(vertex.getId()); - } + List removedVertexIds = new ArrayList<>(); + while (messageIterator.hasNext()) { + removedVertexIds.add(messageIterator.next()); } - - private boolean vertexEdgeAvailableCheck() { - List> outEdges = this.context.edges().getOutEdges(); - int availableEdgeCount; - if (outEdges != null) { - availableEdgeCount = (int) outEdges.stream().filter(IEdge::getValue).count(); - } else { - availableEdgeCount = 0; - } - if (availableEdgeCount < core) { - this.context.setNewVertexValue(false); - return false; - } else { - return true; - } + List> outEdges = this.context.edges().getOutEdges(); + for (IEdge outEdge : outEdges) { + if (removedVertexIds.contains(outEdge.getTargetId())) { + outEdge.withValue(false); + } } - + } + boolean vertexAvailable = vertexEdgeAvailableCheck(); + if (!vertexAvailable) { + this.context.sendMessageToNeighbors(vertex.getId()); + } } + private boolean vertexEdgeAvailableCheck() { + List> outEdges = this.context.edges().getOutEdges(); + int availableEdgeCount; + if (outEdges != null) { + availableEdgeCount = (int) outEdges.stream().filter(IEdge::getValue).count(); + } else { + availableEdgeCount = 0; + } + if (availableEdgeCount < core) { + this.context.setNewVertexValue(false); + return false; + } else { + return true; + } + } + } } - diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/khop/KHop.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/khop/KHop.java index 20b557e8f..8e39bbadf 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/khop/KHop.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/khop/KHop.java @@ -23,6 +23,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.Objects; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.compute.VertexCentricCompute; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -55,114 +56,129 @@ public class KHop { - private static final Logger LOGGER = LoggerFactory.getLogger(KHop.class); - - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/KHop"; - public static final String REF_FILE_PATH = "data/reference/KHop"; - - private static int k = 2; - private static Object srcId = 990; - - public KHop(Object inputId, int inputK) { - srcId = inputId; - k = inputK; + private static final Logger LOGGER = LoggerFactory.getLogger(KHop.class); + + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/KHop"; + public static final String REF_FILE_PATH = "data/reference/KHop"; + + private static int k = 2; + private static Object srcId = 990; + + public KHop(Object inputId, int inputK) { + srcId = inputId; + k = inputK; + } + + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + submit(environment); + } + + public static IPipelineResult submit(Environment environment) { + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + ResultValidator.cleanResult(RESULT_FILE_PATH); + + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + Configuration conf = pipelineTaskCxt.getConfig(); + int sinkParallelism = conf.getInteger(ExampleConfigKeys.SINK_PARALLELISM); + PWindowSource> vertices = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_vertex", + line -> { + String[] fields = line.split(","); + IVertex vertex = + new ValueVertex<>(fields[0], Integer.valueOf(fields[1])); + return Collections.singletonList(vertex); + }), + AllWindow.getInstance()) + .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); + + PWindowSource> edges = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_edge", + line -> { + String[] fields = line.split(","); + IEdge edge = + new ValueEdge<>(fields[0], fields[1], 1); + return Collections.singletonList(edge); + }), + AllWindow.getInstance()) + .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); + + int iterationParallelism = conf.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + .withShardNum(2) + .withBackend(BackendType.Memory) + .build(); + + PWindowStream> result = + pipelineTaskCxt + .buildWindowStreamGraph(vertices, edges, graphViewDesc) + .compute(new KHAlgorithms(k + 1)) + .compute(iterationParallelism) + .getVertices(); + + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); + result + .filter(v -> v.getValue() < k + 1) + .map(v -> String.format("%s,%s", v.getId(), v.getValue())) + .sink(sink) + .withParallelism(sinkParallelism); + }); + + return pipeline.execute(); + } + + public static void validateResult() throws IOException { + ResultValidator.validateResult(REF_FILE_PATH, RESULT_FILE_PATH); + } + + public static class KHAlgorithms extends VertexCentricCompute { + + public KHAlgorithms(long iterations) { + super(iterations); } - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - submit(environment); + @Override + public VertexCentricComputeFunction getComputeFunction() { + return new KHVertexCentricComputeFunction(); } - public static IPipelineResult submit(Environment environment) { - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - ResultValidator.cleanResult(RESULT_FILE_PATH); - - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - Configuration conf = pipelineTaskCxt.getConfig(); - int sinkParallelism = conf.getInteger(ExampleConfigKeys.SINK_PARALLELISM); - PWindowSource> vertices = - pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_vertex", - line -> { - String[] fields = line.split(","); - IVertex vertex = new ValueVertex<>( - fields[0], Integer.valueOf(fields[1])); - return Collections.singletonList(vertex); - }), AllWindow.getInstance()) - .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); - - PWindowSource> edges = pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_edge", - line -> { - String[] fields = line.split(","); - IEdge edge = new ValueEdge<>(fields[0], fields[1], 1); - return Collections.singletonList(edge); - }), AllWindow.getInstance()) - .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); - - int iterationParallelism = conf.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) - .withShardNum(2) - .withBackend(BackendType.Memory) - .build(); - - PWindowStream> result = - pipelineTaskCxt.buildWindowStreamGraph(vertices, edges, graphViewDesc) - .compute(new KHAlgorithms(k + 1)) - .compute(iterationParallelism) - .getVertices(); - - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); - result.filter(v -> v.getValue() < k + 1).map(v -> String.format("%s,%s", v.getId(), v.getValue())) - .sink(sink).withParallelism(sinkParallelism); - }); - - return pipeline.execute(); + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; } - - public static void validateResult() throws IOException { - ResultValidator.validateResult(REF_FILE_PATH, RESULT_FILE_PATH); - } - - public static class KHAlgorithms extends VertexCentricCompute { - - public KHAlgorithms(long iterations) { - super(iterations); + } + + public static class KHVertexCentricComputeFunction + extends AbstractVcFunc { + + @Override + public void compute(Object vertexId, Iterator messageIterator) { + IVertex vertex = this.context.vertex().get(); + if (this.context.getIterationId() == 1L) { + if (Objects.equals(vertex.getId(), srcId)) { + this.context.sendMessageToNeighbors(1); + this.context.setNewVertexValue(0); + } else { + this.context.setNewVertexValue(Integer.MAX_VALUE); } - - @Override - public VertexCentricComputeFunction getComputeFunction() { - return new KHVertexCentricComputeFunction(); - } - - @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; - } - - } - - public static class KHVertexCentricComputeFunction extends AbstractVcFunc { - - @Override - public void compute(Object vertexId, - Iterator messageIterator) { - IVertex vertex = this.context.vertex().get(); - if (this.context.getIterationId() == 1L) { - if (Objects.equals(vertex.getId(), srcId)) { - this.context.sendMessageToNeighbors(1); - this.context.setNewVertexValue(0); - } else { - this.context.setNewVertexValue(Integer.MAX_VALUE); - } - } else { - if (vertex.getValue() == Integer.MAX_VALUE && messageIterator.hasNext()) { - int value = messageIterator.next(); - this.context.sendMessageToNeighbors(value + 1); - this.context.setNewVertexValue(value); - } - } + } else { + if (vertex.getValue() == Integer.MAX_VALUE && messageIterator.hasNext()) { + int value = messageIterator.next(); + this.context.sendMessageToNeighbors(value + 1); + this.context.setNewVertexValue(value); } + } } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/labelpropagation/LabelPropagation.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/labelpropagation/LabelPropagation.java index 1c71bbc04..ce1fbfbca 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/labelpropagation/LabelPropagation.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/labelpropagation/LabelPropagation.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Iterator; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.compute.VertexCentricCompute; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -53,112 +54,121 @@ public class LabelPropagation { - private static final Logger LOGGER = LoggerFactory.getLogger(LabelPropagation.class); - - public static final String REFERENCE_FILE_PATH = "data/reference/labelpropagation"; - - public static final String RESULT_FILE_DIR = "./target/tmp/data/result/labelpropagation"; - - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - IPipelineResult result = LabelPropagation.submit(environment); - PipelineResultCollect.get(result); - environment.shutdown(); + private static final Logger LOGGER = LoggerFactory.getLogger(LabelPropagation.class); + + public static final String REFERENCE_FILE_PATH = "data/reference/labelpropagation"; + + public static final String RESULT_FILE_DIR = "./target/tmp/data/result/labelpropagation"; + + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + IPipelineResult result = LabelPropagation.submit(environment); + PipelineResultCollect.get(result); + environment.shutdown(); + } + + public static IPipelineResult submit(Environment environment) { + ResultValidator.cleanResult(RESULT_FILE_DIR); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); + + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + Configuration config = pipelineTaskCxt.getConfig(); + int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); + int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); + int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); + LOGGER.info( + "with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); + + FileSource> vSource = + new FileSource<>( + GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserInteger); + PWindowStream> vertices = + pipelineTaskCxt + .buildSource(vSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + FileSource> eSource = + new FileSource<>(GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserInteger); + PWindowStream> edges = + pipelineTaskCxt + .buildSource(eSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + .withShardNum(iterationParallelism) + .withBackend(BackendType.Memory) + .build(); + PWindowStream> result = + pipelineTaskCxt + .buildWindowStreamGraph(vertices, edges, graphViewDesc) + .compute(new LabelPropagationAlgorithm(50)) + .compute(iterationParallelism) + .getVertices(); + + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); + result + .map(v -> String.format("%s,%s", v.getId(), v.getValue())) + .sink(sink) + .withParallelism(sinkParallelism); + }); + + return pipeline.execute(); + } + + public static void validateResult() throws IOException { + ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + } + + public static class LabelPropagationAlgorithm + extends VertexCentricCompute { + + public LabelPropagationAlgorithm(long iterations) { + super(iterations); } - public static IPipelineResult submit(Environment environment) { - ResultValidator.cleanResult(RESULT_FILE_DIR); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); - - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - Configuration config = pipelineTaskCxt.getConfig(); - int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); - int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); - int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); - LOGGER.info("with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); - - FileSource> vSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserInteger); - PWindowStream> vertices = - pipelineTaskCxt.buildSource(vSource, AllWindow.getInstance()).withParallelism(sourceParallelism); - - FileSource> eSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserInteger); - PWindowStream> edges = pipelineTaskCxt.buildSource(eSource, - AllWindow.getInstance()) - .withParallelism(sourceParallelism); - - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) - .withShardNum(iterationParallelism) - .withBackend(BackendType.Memory) - .build(); - PWindowStream> result = - pipelineTaskCxt.buildWindowStreamGraph(vertices, edges, graphViewDesc) - .compute(new LabelPropagationAlgorithm(50)) - .compute(iterationParallelism) - .getVertices(); - - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); - result.map(v -> String.format("%s,%s", v.getId(), v.getValue())) - .sink(sink).withParallelism(sinkParallelism); - }); - - return pipeline.execute(); + @Override + public VertexCentricComputeFunction getComputeFunction() { + return new LabelPropagationVCCFunction(); } - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; } - - public static class LabelPropagationAlgorithm extends VertexCentricCompute { - - public LabelPropagationAlgorithm(long iterations) { - super(iterations); - } - - @Override - public VertexCentricComputeFunction getComputeFunction() { - return new LabelPropagationVCCFunction(); + } + + public static class LabelPropagationVCCFunction + extends AbstractVcFunc { + + @Override + public void compute(Integer vertexId, Iterator messageIterator) { + IVertex vertex = this.context.vertex().get(); + if (this.context.getIterationId() == 1) { + this.context.setNewVertexValue(vertex.getId()); + sendMessage(vertex.getId()); + } else { + Integer newLabel = Integer.MIN_VALUE; + while (messageIterator.hasNext()) { + Integer next = messageIterator.next(); + newLabel = Math.max(newLabel, next); } - @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; + if (vertex.getValue() < newLabel) { + this.context.setNewVertexValue(newLabel); + sendMessage(newLabel); } - + } } - public static class LabelPropagationVCCFunction extends AbstractVcFunc { - - @Override - public void compute(Integer vertexId, Iterator messageIterator) { - IVertex vertex = this.context.vertex().get(); - if (this.context.getIterationId() == 1) { - this.context.setNewVertexValue(vertex.getId()); - sendMessage(vertex.getId()); - } else { - Integer newLabel = Integer.MIN_VALUE; - while (messageIterator.hasNext()) { - Integer next = messageIterator.next(); - newLabel = Math.max(newLabel, next); - } - - if (vertex.getValue() < newLabel) { - this.context.setNewVertexValue(newLabel); - sendMessage(newLabel); - } - } - } - - private void sendMessage(Integer label) { - for (IEdge edge : this.context.edges().getOutEdges()) { - this.context.sendMessage(edge.getTargetId(), label); - } - } - + private void sendMessage(Integer label) { + for (IEdge edge : this.context.edges().getOutEdges()) { + this.context.sendMessage(edge.getTargetId(), label); + } } - + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/linkprediction/LinkPrediction.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/linkprediction/LinkPrediction.java index 0e211ab7b..7628b73b3 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/linkprediction/LinkPrediction.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/linkprediction/LinkPrediction.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.HashSet; import java.util.Iterator; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.compute.VertexCentricCompute; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -54,140 +55,149 @@ import org.slf4j.LoggerFactory; public class LinkPrediction { - private static final Logger LOGGER = LoggerFactory.getLogger(LinkPrediction.class); - - public static final String REFERENCE_FILE_PATH = "data/reference/linkprediction"; - - public static final String RESULT_FILE_DIR = "./target/tmp/data/result/linkprediction"; - - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - IPipelineResult result = LinkPrediction.submit(environment); - PipelineResultCollect.get(result); - environment.shutdown(); + private static final Logger LOGGER = LoggerFactory.getLogger(LinkPrediction.class); + + public static final String REFERENCE_FILE_PATH = "data/reference/linkprediction"; + + public static final String RESULT_FILE_DIR = "./target/tmp/data/result/linkprediction"; + + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + IPipelineResult result = LinkPrediction.submit(environment); + PipelineResultCollect.get(result); + environment.shutdown(); + } + + public static IPipelineResult submit(Environment environment) { + ResultValidator.cleanResult(RESULT_FILE_DIR); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); + + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + Configuration config = pipelineTaskCxt.getConfig(); + int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); + int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); + int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); + LOGGER.info( + "with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); + + FileSource> vSource = + new FileSource<>(GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserDouble); + PWindowSource> vertices = + pipelineTaskCxt + .buildSource(vSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + FileSource> eSource = + new FileSource<>(GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserInteger); + PWindowSource> edges = + pipelineTaskCxt + .buildSource(eSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + .withShardNum(iterationParallelism) + .withBackend(BackendType.Memory) + .build(); + PWindowStream> result = + pipelineTaskCxt + .buildWindowStreamGraph(vertices, edges, graphViewDesc) + .compute(new LinkPredictionAlgorithm(398191, 722800, 50)) + .compute(iterationParallelism) + .getVertices(); + + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); + result + .map(v -> String.format("%s,%s", v.getId(), v.getValue())) + .sink(sink) + .withParallelism(sinkParallelism); + }); + + return pipeline.execute(); + } + + public static void validateResult() throws IOException { + ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + } + + public static class LinkPredictionAlgorithm + extends VertexCentricCompute { + + private final int id1; + private final int id2; + + public LinkPredictionAlgorithm(int id1, int id2, long iterations) { + super(iterations); + this.id1 = id1; + this.id2 = id2; } - public static IPipelineResult submit(Environment environment) { - ResultValidator.cleanResult(RESULT_FILE_DIR); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); - - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - Configuration config = pipelineTaskCxt.getConfig(); - int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); - int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); - int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); - LOGGER.info("with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); - - FileSource> vSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserDouble); - PWindowSource> vertices = - pipelineTaskCxt.buildSource(vSource, AllWindow.getInstance()).withParallelism(sourceParallelism); - - FileSource> eSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserInteger); - PWindowSource> edges = - pipelineTaskCxt.buildSource(eSource, AllWindow.getInstance()).withParallelism(sourceParallelism); - - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) - .withShardNum(iterationParallelism) - .withBackend(BackendType.Memory) - .build(); - PWindowStream> result = - pipelineTaskCxt.buildWindowStreamGraph(vertices, edges, graphViewDesc) - .compute(new LinkPredictionAlgorithm(398191, 722800, 50)) - .compute(iterationParallelism) - .getVertices(); - - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); - result.map(v -> String.format("%s,%s", v.getId(), v.getValue())) - .sink(sink).withParallelism(sinkParallelism); - }); - - return pipeline.execute(); + @Override + public VertexCentricComputeFunction getComputeFunction() { + return new LinkPredictionVCCFunction(this.id1, this.id2); } - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; } + } + public static class LinkPredictionVCCFunction + extends AbstractVcFunc { - public static class LinkPredictionAlgorithm extends VertexCentricCompute { + private final int vertexId1; + private final int vertexId2; - private final int id1; - private final int id2; + public LinkPredictionVCCFunction(int vertexId1, int vertexId2) { + this.vertexId1 = vertexId1; + this.vertexId2 = vertexId2; + } - public LinkPredictionAlgorithm(int id1, int id2, long iterations) { - super(iterations); - this.id1 = id1; - this.id2 = id2; + @Override + public void compute(Integer vertexId, Iterator messageIterator) { + IVertex vertex = this.context.vertex().get(); + if (this.context.getIterationId() == 1) { + this.context.setNewVertexValue(0.0); + if (vertex.getId().equals(this.vertexId1)) { + for (IEdge edge : this.context.edges().getInEdges()) { + this.context.sendMessage(this.vertexId2, edge.getTargetId()); + } + for (IEdge edge : this.context.edges().getOutEdges()) { + this.context.sendMessage(this.vertexId2, edge.getTargetId()); + } } - - @Override - public VertexCentricComputeFunction getComputeFunction() { - return new LinkPredictionVCCFunction(this.id1, this.id2); + } else if (this.context.getIterationId() == 2) { + HashSet neighbors = new HashSet<>(); + while (messageIterator.hasNext()) { + neighbors.add(messageIterator.next()); } - - @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; + for (IEdge edge : this.context.edges().getInEdges()) { + if (neighbors.contains(edge.getTargetId())) { + this.context.sendMessage(edge.getTargetId(), 1); + } } - - } - - public static class LinkPredictionVCCFunction extends AbstractVcFunc { - - private final int vertexId1; - private final int vertexId2; - - public LinkPredictionVCCFunction(int vertexId1, int vertexId2) { - this.vertexId1 = vertexId1; - this.vertexId2 = vertexId2; + for (IEdge edge : this.context.edges().getOutEdges()) { + if (neighbors.contains(edge.getTargetId())) { + this.context.sendMessage(edge.getTargetId(), 1); + } } - - @Override - public void compute(Integer vertexId, Iterator messageIterator) { - IVertex vertex = this.context.vertex().get(); - if (this.context.getIterationId() == 1) { - this.context.setNewVertexValue(0.0); - if (vertex.getId().equals(this.vertexId1)) { - for (IEdge edge : this.context.edges().getInEdges()) { - this.context.sendMessage(this.vertexId2, edge.getTargetId()); - } - for (IEdge edge : this.context.edges().getOutEdges()) { - this.context.sendMessage(this.vertexId2, edge.getTargetId()); - } - } - } else if (this.context.getIterationId() == 2) { - HashSet neighbors = new HashSet<>(); - while (messageIterator.hasNext()) { - neighbors.add(messageIterator.next()); - } - for (IEdge edge : this.context.edges().getInEdges()) { - if (neighbors.contains(edge.getTargetId())) { - this.context.sendMessage(edge.getTargetId(), 1); - } - } - for (IEdge edge : this.context.edges().getOutEdges()) { - if (neighbors.contains(edge.getTargetId())) { - this.context.sendMessage(edge.getTargetId(), 1); - } - } - } else if (this.context.getIterationId() == 3) { - Integer meg = this.context.edges().getInEdges().size() + this.context.edges().getOutEdges().size(); - this.context.sendMessage(this.vertexId1, meg); - this.context.sendMessage(this.vertexId2, meg); - } else if (this.context.getIterationId() == 4) { - double lv = 0.0; - while (messageIterator.hasNext()) { - lv += 1.0 / messageIterator.next(); - } - this.context.setNewVertexValue(lv); - } + } else if (this.context.getIterationId() == 3) { + Integer meg = + this.context.edges().getInEdges().size() + this.context.edges().getOutEdges().size(); + this.context.sendMessage(this.vertexId1, meg); + this.context.sendMessage(this.vertexId2, meg); + } else if (this.context.getIterationId() == 4) { + double lv = 0.0; + while (messageIterator.hasNext()) { + lv += 1.0 / messageIterator.next(); } - + this.context.setNewVertexValue(lv); + } } - + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/npaths/NPaths.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/npaths/NPaths.java index cfb94a453..0a5339459 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/npaths/NPaths.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/npaths/NPaths.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.compute.VertexCentricCompute; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -59,167 +60,178 @@ public class NPaths { - private static final Logger LOGGER = LoggerFactory.getLogger(NPaths.class); - - private static final int SOURCE_ID = 11342; - - private static final int TARGET_ID = 748615; - - public static final String REFERENCE_FILE_PATH = "data/reference/npath"; - - public static final String RESULT_FILE_DIR = "./target/tmp/data/result/npath"; - - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - IPipelineResult result = NPaths.submit(environment); - PipelineResultCollect.get(result); - environment.shutdown(); + private static final Logger LOGGER = LoggerFactory.getLogger(NPaths.class); + + private static final int SOURCE_ID = 11342; + + private static final int TARGET_ID = 748615; + + public static final String REFERENCE_FILE_PATH = "data/reference/npath"; + + public static final String RESULT_FILE_DIR = "./target/tmp/data/result/npath"; + + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + IPipelineResult result = NPaths.submit(environment); + PipelineResultCollect.get(result); + environment.shutdown(); + } + + public static IPipelineResult submit(Environment environment) { + ResultValidator.cleanResult(RESULT_FILE_DIR); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); + + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + Configuration config = pipelineTaskCxt.getConfig(); + int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); + int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); + int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); + LOGGER.info( + "with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); + + FileSource>> vSource = + new FileSource<>(GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserMap); + PWindowStream>> vertices = + pipelineTaskCxt + .buildSource(vSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + FileSource>> eSource = + new FileSource<>(GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserMap); + PWindowStream>> edges = + pipelineTaskCxt + .buildSource(eSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + .withShardNum(iterationParallelism) + .withBackend(BackendType.Memory) + .build(); + PWindowStream>> result = + pipelineTaskCxt + .buildWindowStreamGraph(vertices, edges, graphViewDesc) + .compute(new NPathsAlgorithms(SOURCE_ID, TARGET_ID, 10)) + .compute(iterationParallelism) + .getVertices(); + + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); + result + .map(v -> String.format("%s,%s", v.getId(), v.getValue())) + .sink(sink) + .withParallelism(sinkParallelism); + }); + + return pipeline.execute(); + } + + public static void validateResult() throws IOException { + ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + } + + public static class NPathsAlgorithms + extends VertexCentricCompute< + Integer, Map, Map, List> { + + private final int sourceId; + + private final int targetId; + + public NPathsAlgorithms(int sourceId, int targetId, long iterations) { + super(iterations); + this.sourceId = sourceId; + this.targetId = targetId; } - public static IPipelineResult submit(Environment environment) { - ResultValidator.cleanResult(RESULT_FILE_DIR); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); - - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - Configuration config = pipelineTaskCxt.getConfig(); - int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); - int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); - int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); - LOGGER.info("with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); - - FileSource>> vSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserMap); - PWindowStream>> vertices = - pipelineTaskCxt.buildSource(vSource, AllWindow.getInstance()).withParallelism(sourceParallelism); - - FileSource>> eSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserMap); - PWindowStream>> edges = pipelineTaskCxt.buildSource(eSource, AllWindow.getInstance()).withParallelism(sourceParallelism); - - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) - .withShardNum(iterationParallelism) - .withBackend(BackendType.Memory) - .build(); - PWindowStream>> result = - pipelineTaskCxt.buildWindowStreamGraph(vertices, edges, graphViewDesc) - .compute(new NPathsAlgorithms(SOURCE_ID, TARGET_ID, 10)) - .compute(iterationParallelism) - .getVertices(); - - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); - result.map(v -> String.format("%s,%s", v.getId(), v.getValue())) - .sink(sink).withParallelism(sinkParallelism); - }); - - return pipeline.execute(); + @Override + public VertexCentricComputeFunction< + Integer, Map, Map, List> + getComputeFunction() { + return new NPathsVCFunction(sourceId, targetId, EdgeDirection.OUT); } - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + @Override + public VertexCentricCombineFunction> getCombineFunction() { + return null; } + } - public static class NPathsAlgorithms extends VertexCentricCompute, Map, List> { + public static class NPathsVCFunction + extends AbstractVcFunc, Map, List> { - private final int sourceId; + private static final String KEY_FIELD = "dis"; - private final int targetId; + private final int sourceId; - public NPathsAlgorithms(int sourceId, int targetId, long iterations) { - super(iterations); - this.sourceId = sourceId; - this.targetId = targetId; - } + private final int targetId; - @Override - public VertexCentricComputeFunction, Map, List> getComputeFunction() { - return new NPathsVCFunction(sourceId, targetId, EdgeDirection.OUT); - } + private final EdgeDirection edgeType; - @Override - public VertexCentricCombineFunction> getCombineFunction() { - return null; - } + private final AtomicInteger pathNum = new AtomicInteger(10); + public NPathsVCFunction(int sourceId, int targetId, EdgeDirection edgeType) { + this.edgeType = edgeType; + this.sourceId = sourceId; + this.targetId = targetId; } - public static class NPathsVCFunction extends AbstractVcFunc, Map, List> { - - private static final String KEY_FIELD = "dis"; - - private final int sourceId; - - private final int targetId; - - private final EdgeDirection edgeType; - - private final AtomicInteger pathNum = new AtomicInteger(10); - - - public NPathsVCFunction(int sourceId, int targetId, EdgeDirection edgeType) { - this.edgeType = edgeType; - this.sourceId = sourceId; - this.targetId = targetId; + @Override + public void compute(Integer vertexId, Iterator> messageIterator) { + IVertex> vertex = this.context.vertex().get(); + Map property = vertex.getValue(); + if (this.context.getIterationId() == 1) { + property.put(KEY_FIELD, ""); + this.context.setNewVertexValue(property); + if (vertex.getId().equals(sourceId)) { + sendMessage(Collections.singletonList(String.valueOf(sourceId))); } - - @Override - public void compute(Integer vertexId, Iterator> messageIterator) { - IVertex> vertex = this.context.vertex().get(); - Map property = vertex.getValue(); - if (this.context.getIterationId() == 1) { - property.put(KEY_FIELD, ""); - this.context.setNewVertexValue(property); - if (vertex.getId().equals(sourceId)) { - sendMessage(Collections.singletonList(String.valueOf(sourceId))); - } - } else { - if (vertex.getId().equals(targetId)) { - StringBuilder builder = new StringBuilder(); - builder.append(property.get(KEY_FIELD)); - while (messageIterator.hasNext() && pathNum.get() > 0) { - List tmp = new LinkedList<>(messageIterator.next()); - tmp.add(String.valueOf(vertex.getId())); - String sep = ";"; - builder.append(tmp).append(sep); - pathNum.getAndDecrement(); - } - property.put(KEY_FIELD, builder.toString()); - this.context.setNewVertexValue(property); - } else { - while (messageIterator.hasNext()) { - List tmp = new LinkedList<>(messageIterator.next()); - tmp.add(String.valueOf(vertex.getId())); - sendMessage(tmp); - } - } - } - } - - private void sendMessage(List msg) { - switch (this.edgeType) { - case OUT: - for (IEdge> edge : this.context.edges() - .getOutEdges()) { - this.context.sendMessage(edge.getTargetId(), msg); - } - break; - case IN: - for (IEdge> edge : this.context.edges() - .getInEdges()) { - this.context.sendMessage(edge.getTargetId(), msg); - } - break; - case BOTH: - for (IEdge> edge : this.context.edges().getEdges()) { - this.context.sendMessage(edge.getTargetId(), msg); - } - break; - default: - break; - } + } else { + if (vertex.getId().equals(targetId)) { + StringBuilder builder = new StringBuilder(); + builder.append(property.get(KEY_FIELD)); + while (messageIterator.hasNext() && pathNum.get() > 0) { + List tmp = new LinkedList<>(messageIterator.next()); + tmp.add(String.valueOf(vertex.getId())); + String sep = ";"; + builder.append(tmp).append(sep); + pathNum.getAndDecrement(); + } + property.put(KEY_FIELD, builder.toString()); + this.context.setNewVertexValue(property); + } else { + while (messageIterator.hasNext()) { + List tmp = new LinkedList<>(messageIterator.next()); + tmp.add(String.valueOf(vertex.getId())); + sendMessage(tmp); + } } + } } + private void sendMessage(List msg) { + switch (this.edgeType) { + case OUT: + for (IEdge> edge : this.context.edges().getOutEdges()) { + this.context.sendMessage(edge.getTargetId(), msg); + } + break; + case IN: + for (IEdge> edge : this.context.edges().getInEdges()) { + this.context.sendMessage(edge.getTargetId(), msg); + } + break; + case BOTH: + for (IEdge> edge : this.context.edges().getEdges()) { + this.context.sendMessage(edge.getTargetId(), msg); + } + break; + default: + break; + } + } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/pagerank/PageRank.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/pagerank/PageRank.java index 79bd744e7..f9a6dfddb 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/pagerank/PageRank.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/pagerank/PageRank.java @@ -23,6 +23,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.PGraphWindow; import org.apache.geaflow.api.graph.compute.VertexCentricCompute; @@ -54,116 +55,132 @@ public class PageRank { - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/pagerank"; - public static final String REF_FILE_PATH = "data/reference/pagerank"; - - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - IPipelineResult result = PageRank.submit(environment); - PipelineResultCollect.get(result); - environment.shutdown(); + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/pagerank"; + public static final String REF_FILE_PATH = "data/reference/pagerank"; + + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + IPipelineResult result = PageRank.submit(environment); + PipelineResultCollect.get(result); + environment.shutdown(); + } + + public static IPipelineResult submit(Environment environment) { + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + ResultValidator.cleanResult(RESULT_FILE_PATH); + + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + Configuration conf = pipelineTaskCxt.getConfig(); + PWindowSource> prVertices = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_vertex", + line -> { + String[] fields = line.split(","); + IVertex vertex = + new ValueVertex<>( + Integer.valueOf(fields[0]), Double.valueOf(fields[1])); + return Collections.singletonList(vertex); + }), + AllWindow.getInstance()) + .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); + + PWindowSource> prEdges = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_edge", + line -> { + String[] fields = line.split(","); + IEdge edge = + new ValueEdge<>( + Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); + return Collections.singletonList(edge); + }), + AllWindow.getInstance()) + .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); + + int iterationParallelism = conf.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + .withShardNum(iterationParallelism) + .withBackend(BackendType.Memory) + .build(); + PGraphWindow graphWindow = + pipelineTaskCxt.buildWindowStreamGraph(prVertices, prEdges, graphViewDesc); + + SinkFunction> sink = + ExampleSinkFunctionFactory.getSinkFunction(conf); + graphWindow + .compute(new PRAlgorithms(10, 0.85)) + .compute(iterationParallelism) + .getVertices() + .map( + e -> { + double value = Double.parseDouble(String.format("%.2f", e.getValue())); + return e.withValue(value); + }) + .sink(sink) + .withParallelism(conf.getInteger(ExampleConfigKeys.SINK_PARALLELISM)); + }); + + return pipeline.execute(); + } + + public static void validateResult() throws IOException { + ResultValidator.validateResult(REF_FILE_PATH, RESULT_FILE_PATH); + } + + public static class PRAlgorithms extends VertexCentricCompute { + + private double alpha; + + public PRAlgorithms(long iterations, double alpha) { + super(iterations); + this.alpha = alpha; } - public static IPipelineResult submit(Environment environment) { - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - ResultValidator.cleanResult(RESULT_FILE_PATH); - - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - Configuration conf = pipelineTaskCxt.getConfig(); - PWindowSource> prVertices = pipelineTaskCxt.buildSource( - new FileSource<>("data/input/email_vertex", line -> { - String[] fields = line.split(","); - IVertex vertex = new ValueVertex<>( - Integer.valueOf(fields[0]), Double.valueOf(fields[1])); - return Collections.singletonList(vertex); - }), AllWindow.getInstance()) - .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); - - PWindowSource> prEdges = pipelineTaskCxt.buildSource( - new FileSource<>("data/input/email_edge", line -> { - String[] fields = line.split(","); - IEdge edge = new ValueEdge<>(Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); - return Collections.singletonList(edge); - }), AllWindow.getInstance()) - .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); - - int iterationParallelism = conf.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) - .withShardNum(iterationParallelism) - .withBackend(BackendType.Memory) - .build(); - PGraphWindow graphWindow = - pipelineTaskCxt.buildWindowStreamGraph(prVertices, prEdges, graphViewDesc); - - SinkFunction> sink = ExampleSinkFunctionFactory.getSinkFunction(conf); - graphWindow.compute(new PRAlgorithms(10, 0.85)) - .compute(iterationParallelism) - .getVertices() - .map(e -> { - double value = Double.parseDouble(String.format("%.2f", e.getValue())); - return e.withValue(value); - }) - .sink(sink) - .withParallelism(conf.getInteger(ExampleConfigKeys.SINK_PARALLELISM)); - }); - - return pipeline.execute(); + @Override + public VertexCentricComputeFunction getComputeFunction() { + return new PRVertexCentricComputeFunction(); } - public static void validateResult() throws IOException { - ResultValidator.validateResult(REF_FILE_PATH, RESULT_FILE_PATH); + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; } - public static class PRAlgorithms extends VertexCentricCompute { - - private double alpha; - - public PRAlgorithms(long iterations, double alpha) { - super(iterations); - this.alpha = alpha; - } - - @Override - public VertexCentricComputeFunction getComputeFunction() { - return new PRVertexCentricComputeFunction(); - } - - @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; - } - - private class PRVertexCentricComputeFunction extends AbstractVcFunc { - - @Override - public void compute(Integer vertexId, - Iterator messageIterator) { - IVertex vertex = this.context.vertex().get(); - List> outEdges = context.edges().getOutEdges(); - if (this.context.getIterationId() == 1) { - if (!outEdges.isEmpty()) { - this.context.sendMessageToNeighbors(vertex.getValue() / outEdges.size()); - } - - } else { - double sum = 0; - while (messageIterator.hasNext()) { - double value = messageIterator.next(); - sum += value; - } - double pr = sum * alpha + (1 - alpha); - this.context.setNewVertexValue(pr); - - if (!outEdges.isEmpty()) { - this.context.sendMessageToNeighbors(pr / outEdges.size()); - } - } - } + private class PRVertexCentricComputeFunction + extends AbstractVcFunc { + + @Override + public void compute(Integer vertexId, Iterator messageIterator) { + IVertex vertex = this.context.vertex().get(); + List> outEdges = context.edges().getOutEdges(); + if (this.context.getIterationId() == 1) { + if (!outEdges.isEmpty()) { + this.context.sendMessageToNeighbors(vertex.getValue() / outEdges.size()); + } + + } else { + double sum = 0; + while (messageIterator.hasNext()) { + double value = messageIterator.next(); + sum += value; + } + double pr = sum * alpha + (1 - alpha); + this.context.setNewVertexValue(pr); + + if (!outEdges.isEmpty()) { + this.context.sendMessageToNeighbors(pr / outEdges.size()); + } } + } } - - + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/personalrank/PersonalRank.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/personalrank/PersonalRank.java index 00718429f..b3ad82ef8 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/personalrank/PersonalRank.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/personalrank/PersonalRank.java @@ -23,6 +23,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.PGraphWindow; import org.apache.geaflow.api.graph.compute.VertexCentricCompute; @@ -56,154 +57,167 @@ public class PersonalRank { - private static final Logger LOGGER = LoggerFactory.getLogger(PersonalRank.class); - - private static final int ROOT_ID = 1; - - private static final int MAX_ITERATIONS = 10; - - private static final double ALPHA = 0.85; - - private static final double CONVERGENCE = 0.00001; - - public static final String RESULT_FILE_DIR = "./target/tmp/data/result/personalrank"; - - public static final String REF_FILE_DIR = "data/reference/personalrank"; - - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - IPipelineResult result = submit(environment); - PipelineResultCollect.get(result); - environment.shutdown(); + private static final Logger LOGGER = LoggerFactory.getLogger(PersonalRank.class); + + private static final int ROOT_ID = 1; + + private static final int MAX_ITERATIONS = 10; + + private static final double ALPHA = 0.85; + + private static final double CONVERGENCE = 0.00001; + + public static final String RESULT_FILE_DIR = "./target/tmp/data/result/personalrank"; + + public static final String REF_FILE_DIR = "data/reference/personalrank"; + + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + IPipelineResult result = submit(environment); + PipelineResultCollect.get(result); + environment.shutdown(); + } + + public static IPipelineResult submit(Environment environment) { + ResultValidator.cleanResult(RESULT_FILE_DIR); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); + + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + Configuration config = pipelineTaskCxt.getConfig(); + int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); + int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); + int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); + LOGGER.info( + "with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); + + PWindowSource> prVertices = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_vertex", + line -> { + String[] fields = line.split(","); + IVertex vertex = + new ValueVertex<>( + Integer.valueOf(fields[0]), Double.valueOf(fields[1])); + return Collections.singletonList(vertex); + }), + AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + PWindowSource> prEdges = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_edge", + line -> { + String[] fields = line.split(","); + IEdge edge = + new ValueEdge<>( + Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); + return Collections.singletonList(edge); + }), + AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + .withShardNum(iterationParallelism) + .withBackend(BackendType.Memory) + .build(); + PGraphWindow graphWindow = + pipelineTaskCxt.buildWindowStreamGraph(prVertices, prEdges, graphViewDesc); + + SinkFunction> sink = + ExampleSinkFunctionFactory.getSinkFunction(config); + graphWindow + .compute(new PersonalRankAlgorithms(ROOT_ID, MAX_ITERATIONS, ALPHA, CONVERGENCE)) + .compute(iterationParallelism) + .getVertices() + .sink(sink) + .withParallelism(sinkParallelism); + }); + + return pipeline.execute(); + } + + public static void validateResult() throws IOException { + ResultValidator.validateResult(REF_FILE_DIR, RESULT_FILE_DIR); + } + + public static class PersonalRankAlgorithms + extends VertexCentricCompute { + + private final int rootId; + + private final double alpha; + + private final double convergence; + + public PersonalRankAlgorithms(int rootId, long iterations, double alpha, double convergence) { + super(iterations); + this.rootId = rootId; + this.alpha = alpha; + this.convergence = convergence; } - public static IPipelineResult submit(Environment environment) { - ResultValidator.cleanResult(RESULT_FILE_DIR); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); - - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - Configuration config = pipelineTaskCxt.getConfig(); - int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); - int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); - int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); - LOGGER.info("with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); - - PWindowSource> prVertices = - pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_vertex", - line -> { - String[] fields = line.split(","); - IVertex vertex = new ValueVertex<>( - Integer.valueOf(fields[0]), Double.valueOf(fields[1])); - return Collections.singletonList(vertex); - }), AllWindow.getInstance()) - .withParallelism(sourceParallelism); - - PWindowSource> prEdges = - pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_edge", - line -> { - String[] fields = line.split(","); - IEdge edge = new ValueEdge<>( - Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); - return Collections.singletonList(edge); - }), AllWindow.getInstance()) - .withParallelism(sourceParallelism); - - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) - .withShardNum(iterationParallelism) - .withBackend(BackendType.Memory) - .build(); - PGraphWindow graphWindow = - pipelineTaskCxt.buildWindowStreamGraph(prVertices, prEdges, graphViewDesc); - - SinkFunction> sink = ExampleSinkFunctionFactory.getSinkFunction(config); - graphWindow - .compute(new PersonalRankAlgorithms(ROOT_ID, MAX_ITERATIONS, ALPHA, CONVERGENCE)) - .compute(iterationParallelism) - .getVertices() - .sink(sink) - .withParallelism(sinkParallelism); - }); - - return pipeline.execute(); + @Override + public VertexCentricComputeFunction getComputeFunction() { + return new PersonalRankVertexCentricComputeFunction(rootId, alpha, convergence); } - public static void validateResult() throws IOException { - ResultValidator.validateResult(REF_FILE_DIR, RESULT_FILE_DIR); + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; } + } - public static class PersonalRankAlgorithms extends VertexCentricCompute { + public static class PersonalRankVertexCentricComputeFunction + extends AbstractVcFunc { - private final int rootId; + private final int rootId; - private final double alpha; + private final double alpha; - private final double convergence; - - public PersonalRankAlgorithms(int rootId, long iterations, double alpha, double convergence) { - super(iterations); - this.rootId = rootId; - this.alpha = alpha; - this.convergence = convergence; - } - - @Override - public VertexCentricComputeFunction getComputeFunction() { - return new PersonalRankVertexCentricComputeFunction(rootId, alpha, convergence); - } - - @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; - } + private final double convergence; + public PersonalRankVertexCentricComputeFunction(int rootId, double alpha, double convergence) { + this.rootId = rootId; + this.convergence = convergence; + this.alpha = alpha; } - public static class PersonalRankVertexCentricComputeFunction extends AbstractVcFunc { - - private final int rootId; - - private final double alpha; - - private final double convergence; - - public PersonalRankVertexCentricComputeFunction(int rootId, double alpha, double convergence) { - this.rootId = rootId; - this.convergence = convergence; - this.alpha = alpha; + @Override + public void compute(Integer vertexId, Iterator messageIterator) { + IVertex vertex = this.context.vertex().get(); + if (this.context.getIterationId() == 1) { + if (vertex.getId() == rootId) { + double score = 1.0; + calc(score); + } else { + this.context.setNewVertexValue(0.0); } - - @Override - public void compute(Integer vertexId, - Iterator messageIterator) { - IVertex vertex = this.context.vertex().get(); - if (this.context.getIterationId() == 1) { - if (vertex.getId() == rootId) { - double score = 1.0; - calc(score); - } else { - this.context.setNewVertexValue(0.0); - } - } else { - double score = vertex.getValue(); - while (messageIterator.hasNext()) { - score += messageIterator.next(); - } - calc(score); - } + } else { + double score = vertex.getValue(); + while (messageIterator.hasNext()) { + score += messageIterator.next(); } + calc(score); + } + } - private void calc(double score) { - if (score > convergence) { - this.context.setNewVertexValue(score * (1 - alpha)); - List> outEdges = this.context.edges().getOutEdges(); - if (outEdges.size() > 0) { - this.context.sendMessageToNeighbors((score * alpha) / outEdges.size()); - } - } + private void calc(double score) { + if (score > convergence) { + this.context.setNewVertexValue(score * (1 - alpha)); + List> outEdges = this.context.edges().getOutEdges(); + if (outEdges.size() > 0) { + this.context.sendMessageToNeighbors((score * alpha) / outEdges.size()); } + } } - + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/shortestpath/ShortestPath.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/shortestpath/ShortestPath.java index e11d0f0f3..5ab77b413 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/shortestpath/ShortestPath.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/shortestpath/ShortestPath.java @@ -24,6 +24,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.compute.VertexCentricCompute; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -58,150 +59,162 @@ public class ShortestPath { - private static final Logger LOGGER = LoggerFactory.getLogger(ShortestPath.class); - - private static final int SOURCE_ID = 498021; - - public static final String REFERENCE_FILE_PATH = "data/reference/shortestpath"; - - public static final String RESULT_FILE_DIR = "./target/tmp/data/result/shortestpath"; - - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - IPipelineResult result = submit(environment); - PipelineResultCollect.get(result); - environment.shutdown(); + private static final Logger LOGGER = LoggerFactory.getLogger(ShortestPath.class); + + private static final int SOURCE_ID = 498021; + + public static final String REFERENCE_FILE_PATH = "data/reference/shortestpath"; + + public static final String RESULT_FILE_DIR = "./target/tmp/data/result/shortestpath"; + + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + IPipelineResult result = submit(environment); + PipelineResultCollect.get(result); + environment.shutdown(); + } + + public static IPipelineResult submit(Environment environment) { + ResultValidator.cleanResult(RESULT_FILE_DIR); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); + + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + Configuration config = pipelineTaskCxt.getConfig(); + int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); + int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); + int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); + LOGGER.info( + "with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); + + FileSource>> vSource = + new FileSource<>(GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserMap); + PWindowStream>> vertices = + pipelineTaskCxt + .buildSource(vSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + FileSource>> eSource = + new FileSource<>(GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserMap); + PWindowStream>> edges = + pipelineTaskCxt + .buildSource(eSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + .withShardNum(iterationParallelism) + .withBackend(BackendType.Memory) + .build(); + PWindowStream>> result = + pipelineTaskCxt + .buildWindowStreamGraph(vertices, edges, graphViewDesc) + .compute(new ShortestPathAlgorithms(SOURCE_ID, 10)) + .compute(iterationParallelism) + .getVertices(); + + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); + result + .map(v -> String.format("%s,%s", v.getId(), v.getValue())) + .sink(sink) + .withParallelism(sinkParallelism); + }); + + return pipeline.execute(); + } + + public static void validateResult() throws IOException { + ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + } + + public static class ShortestPathAlgorithms + extends VertexCentricCompute< + Integer, Map, Map, Tuple>> { + + private final int sourceId; + + public ShortestPathAlgorithms(int sourceId, long iterations) { + super(iterations); + this.sourceId = sourceId; } - public static IPipelineResult submit(Environment environment) { - ResultValidator.cleanResult(RESULT_FILE_DIR); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); - - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - Configuration config = pipelineTaskCxt.getConfig(); - int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); - int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); - int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); - LOGGER.info("with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); - - FileSource>> vSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserMap); - PWindowStream>> vertices = - pipelineTaskCxt.buildSource(vSource, AllWindow.getInstance()).withParallelism(sourceParallelism); - - FileSource>> eSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserMap); - PWindowStream>> edges = - pipelineTaskCxt.buildSource(eSource, AllWindow.getInstance()).withParallelism(sourceParallelism); - - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) - .withShardNum(iterationParallelism) - .withBackend(BackendType.Memory) - .build(); - PWindowStream>> result = - pipelineTaskCxt.buildWindowStreamGraph(vertices, edges, graphViewDesc) - .compute(new ShortestPathAlgorithms(SOURCE_ID, 10)) - .compute(iterationParallelism) - .getVertices(); - - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); - result.map(v -> String.format("%s,%s", v.getId(), v.getValue())) - .sink(sink).withParallelism(sinkParallelism); - }); - - return pipeline.execute(); + @Override + public VertexCentricComputeFunction< + Integer, Map, Map, Tuple>> + getComputeFunction() { + return new ShortestPathVertexCentricComputeFunction(sourceId, EdgeDirection.OUT); } - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + @Override + public VertexCentricCombineFunction>> getCombineFunction() { + return null; } + } - public static class ShortestPathAlgorithms extends VertexCentricCompute, Map, Tuple>> { + public static class ShortestPathVertexCentricComputeFunction + extends AbstractVcFunc< + Integer, Map, Map, Tuple>> { - private final int sourceId; + private static final String KEY_FIELD = "dis"; - public ShortestPathAlgorithms(int sourceId, long iterations) { - super(iterations); - this.sourceId = sourceId; - } - - @Override - public VertexCentricComputeFunction, Map, Tuple>> getComputeFunction() { - return new ShortestPathVertexCentricComputeFunction(sourceId, EdgeDirection.OUT); - } + private static final String PATH = "paths"; - @Override - public VertexCentricCombineFunction>> getCombineFunction() { - return null; - } + private final int sourceId; + private final EdgeDirection edgeType; + public ShortestPathVertexCentricComputeFunction(int sourceId, EdgeDirection edgeType) { + this.sourceId = sourceId; + this.edgeType = edgeType; } - public static class ShortestPathVertexCentricComputeFunction extends AbstractVcFunc, Map, Tuple>> { - - private static final String KEY_FIELD = "dis"; - - private static final String PATH = "paths"; - - private final int sourceId; - private final EdgeDirection edgeType; - - public ShortestPathVertexCentricComputeFunction(int sourceId, EdgeDirection edgeType) { - this.sourceId = sourceId; - this.edgeType = edgeType; + @Override + public void compute(Integer vertexId, Iterator>> messageIterator) { + IVertex> vertex = this.context.vertex().get(); + Map vertexValue = vertex.getValue(); + int dis = vertex.getId().equals(sourceId) ? 0 : Integer.MAX_VALUE; + if (this.context.getIterationId() == 1) { + vertexValue.put(PATH, "[]"); + vertexValue.put(KEY_FIELD, String.valueOf(Integer.MAX_VALUE)); + this.context.setNewVertexValue(vertexValue); + return; + } + List path = new LinkedList<>(); + while (messageIterator.hasNext()) { + Tuple> msg = messageIterator.next(); + int tmp = msg.getF0(); + if (tmp < dis) { + path.clear(); + dis = tmp; + path.addAll(msg.getF1()); } - - @Override - public void compute(Integer vertexId, - Iterator>> messageIterator) { - IVertex> vertex = this.context.vertex().get(); - Map vertexValue = vertex.getValue(); - int dis = vertex.getId().equals(sourceId) ? 0 : Integer.MAX_VALUE; - if (this.context.getIterationId() == 1) { - vertexValue.put(PATH, "[]"); - vertexValue.put(KEY_FIELD, String.valueOf(Integer.MAX_VALUE)); - this.context.setNewVertexValue(vertexValue); - return; - } - List path = new LinkedList<>(); - while (messageIterator.hasNext()) { - Tuple> msg = messageIterator.next(); - int tmp = msg.getF0(); - if (tmp < dis) { - path.clear(); - dis = tmp; - path.addAll(msg.getF1()); - } + } + if (dis < Integer.parseInt(String.valueOf(vertexValue.get(KEY_FIELD)))) { + vertexValue.put(KEY_FIELD, String.valueOf(dis)); + path.add(String.valueOf(vertex.getId())); + vertexValue.put(PATH, path.toString()); + this.context.setNewVertexValue(vertexValue); + switch (edgeType) { + case IN: + for (IEdge> inEdge : this.context.edges().getInEdges()) { + this.context.sendMessage( + inEdge.getTargetId(), new Tuple<>(dis + inEdge.getValue().get(KEY_FIELD), path)); } - if (dis < Integer.parseInt(String.valueOf(vertexValue.get(KEY_FIELD)))) { - vertexValue.put(KEY_FIELD, String.valueOf(dis)); - path.add(String.valueOf(vertex.getId())); - vertexValue.put(PATH, path.toString()); - this.context.setNewVertexValue(vertexValue); - switch (edgeType) { - case IN: - for (IEdge> inEdge : this.context.edges() - .getInEdges()) { - this.context.sendMessage(inEdge.getTargetId(), - new Tuple<>(dis + inEdge.getValue().get(KEY_FIELD), path)); - } - break; - case OUT: - for (IEdge> outEdge : this.context.edges() - .getOutEdges()) { - this.context.sendMessage(outEdge.getTargetId(), - new Tuple<>(dis + outEdge.getValue().get(KEY_FIELD), path)); - } - break; - default: - break; - } + break; + case OUT: + for (IEdge> outEdge : + this.context.edges().getOutEdges()) { + this.context.sendMessage( + outEdge.getTargetId(), + new Tuple<>(dis + outEdge.getValue().get(KEY_FIELD), path)); } + break; + default: + break; } - + } } - + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/shortestpathofvertexsets/ShortestPathOfVertexSet.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/shortestpathofvertexsets/ShortestPathOfVertexSet.java index 440b985cb..66d3969f9 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/shortestpathofvertexsets/ShortestPathOfVertexSet.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/shortestpathofvertexsets/ShortestPathOfVertexSet.java @@ -29,6 +29,7 @@ import java.util.List; import java.util.Map; import java.util.Set; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.compute.VertexCentricCompute; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -63,177 +64,197 @@ public class ShortestPathOfVertexSet { - private static final Logger LOGGER = LoggerFactory.getLogger(ShortestPathOfVertexSet.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ShortestPathOfVertexSet.class); - private static final Set SOURCE_ID = new HashSet<>(Arrays.asList(11342, 30957)); + private static final Set SOURCE_ID = new HashSet<>(Arrays.asList(11342, 30957)); - public static final String REFERENCE_FILE_PATH = "data/reference/shortestpathvertex"; + public static final String REFERENCE_FILE_PATH = "data/reference/shortestpathvertex"; - public static final String RESULT_FILE_DIR = "./target/tmp/data/result/shortestpathvertex"; + public static final String RESULT_FILE_DIR = "./target/tmp/data/result/shortestpathvertex"; - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - IPipelineResult result = submit(environment); - PipelineResultCollect.get(result); - environment.shutdown(); - } + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + IPipelineResult result = submit(environment); + PipelineResultCollect.get(result); + environment.shutdown(); + } - public static IPipelineResult submit(Environment environment) { - ResultValidator.cleanResult(RESULT_FILE_DIR); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); - - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - Configuration config = pipelineTaskCxt.getConfig(); - int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); - int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); - int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); - LOGGER.info("with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); - - FileSource>>> vSource = - new FileSource<>(GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserMapMap); - PWindowStream>>> vertices = - pipelineTaskCxt.buildSource(vSource, AllWindow.getInstance()).withParallelism(sourceParallelism); - - FileSource>> eSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserMap); - PWindowStream>> edges = - pipelineTaskCxt.buildSource(eSource, AllWindow.getInstance()).withParallelism(sourceParallelism); - - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) - .withShardNum(iterationParallelism) - .withBackend(BackendType.Memory) - .build(); - PWindowStream>>> result = - pipelineTaskCxt.buildWindowStreamGraph(vertices, edges, graphViewDesc) - .compute(new ShortestPathOfVertexSetAlgorithms(SOURCE_ID, 10)) - .compute(iterationParallelism) - .getVertices(); - - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); - result.map(v -> String.format("%s,%s", v.getId(), v.getValue())) - .sink(sink).withParallelism(sinkParallelism); - }); - - return pipeline.execute(); - } + public static IPipelineResult submit(Environment environment) { + ResultValidator.cleanResult(RESULT_FILE_DIR); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); - } + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + Configuration config = pipelineTaskCxt.getConfig(); + int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); + int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); + int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); + LOGGER.info( + "with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); - public static class ShortestPathOfVertexSetAlgorithms extends VertexCentricCompute>, Map, Triple>> { + FileSource>>> vSource = + new FileSource<>(GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserMapMap); + PWindowStream>>> vertices = + pipelineTaskCxt + .buildSource(vSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); - private final Set sourceId; + FileSource>> eSource = + new FileSource<>(GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserMap); + PWindowStream>> edges = + pipelineTaskCxt + .buildSource(eSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); - public ShortestPathOfVertexSetAlgorithms(Set sourceId, long iterations) { - super(iterations); - this.sourceId = sourceId; - } + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + .withShardNum(iterationParallelism) + .withBackend(BackendType.Memory) + .build(); + PWindowStream>>> result = + pipelineTaskCxt + .buildWindowStreamGraph(vertices, edges, graphViewDesc) + .compute(new ShortestPathOfVertexSetAlgorithms(SOURCE_ID, 10)) + .compute(iterationParallelism) + .getVertices(); - @Override - public VertexCentricComputeFunction>, Map, Triple>> getComputeFunction() { - return new ShortestPathOfVertexSetVCFunction(sourceId, EdgeDirection.OUT); - } + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); + result + .map(v -> String.format("%s,%s", v.getId(), v.getValue())) + .sink(sink) + .withParallelism(sinkParallelism); + }); - @Override - public VertexCentricCombineFunction>> getCombineFunction() { - return null; - } + return pipeline.execute(); + } + + public static void validateResult() throws IOException { + ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + } + + public static class ShortestPathOfVertexSetAlgorithms + extends VertexCentricCompute< + Integer, + Map>, + Map, + Triple>> { + + private final Set sourceId; + public ShortestPathOfVertexSetAlgorithms(Set sourceId, long iterations) { + super(iterations); + this.sourceId = sourceId; } - public static class ShortestPathOfVertexSetVCFunction extends AbstractVcFunc>, Map, Triple>> { + @Override + public VertexCentricComputeFunction< + Integer, + Map>, + Map, + Triple>> + getComputeFunction() { + return new ShortestPathOfVertexSetVCFunction(sourceId, EdgeDirection.OUT); + } - private static final String KEY_FIELD = "dis"; + @Override + public VertexCentricCombineFunction>> + getCombineFunction() { + return null; + } + } - private static final String PATH = "paths"; + public static class ShortestPathOfVertexSetVCFunction + extends AbstractVcFunc< + Integer, + Map>, + Map, + Triple>> { - private final Set sourceId; - private final EdgeDirection edgeType; + private static final String KEY_FIELD = "dis"; - public ShortestPathOfVertexSetVCFunction(Set sourceId, EdgeDirection edgeType) { - this.sourceId = sourceId; - this.edgeType = edgeType; - } + private static final String PATH = "paths"; - @Override - public void compute(Integer vertexId, - Iterator>> messageIterator) { - IVertex>> vertex = this.context.vertex().get(); - Map> property = vertex.getValue(); - if (this.context.getIterationId() == 1) { - Map dis = new HashMap<>(); //Map - Map path = new HashMap<>(); //Map> - for (Integer id : this.sourceId) { - if (vertex.getId().equals(id)) { - sendMessage(id, 0, new ArrayList<>( - Collections.singletonList(vertex.getId()))); - dis.put(id, 0); - path.put(id, new ArrayList<>(Collections.singletonList(id))); - } else { - dis.put(id, Integer.MAX_VALUE); - } - } - property.put(KEY_FIELD, dis); - property.put(PATH, path); - this.context.setNewVertexValue(property); - } else { - Map newDisMap = new HashMap<>(2); - Map> newPathMap = new HashMap<>(2); - while (messageIterator.hasNext()) { - Triple> meg = messageIterator.next(); - if (meg.getF1() < newDisMap.getOrDefault(meg.getF0(), Integer.MAX_VALUE)) { - newDisMap.put(meg.getF0(), meg.getF1()); - newPathMap.put(meg.getF0(), meg.getF2()); - } - } - if (!newDisMap.isEmpty()) { - Map curDisMap = property.get(KEY_FIELD); - Map curPathMap = property.get(PATH); - for (Map.Entry kv : newDisMap.entrySet()) { - if (kv.getValue() < (Integer) curDisMap.getOrDefault(kv.getKey(), Integer.MAX_VALUE)) { - curDisMap.put(kv.getKey(), kv.getValue()); - List tmp = new ArrayList<>(newPathMap.get(kv.getKey())); - tmp.add(vertex.getId()); - curPathMap.put(kv.getKey(), tmp); - sendMessage(kv.getKey(), kv.getValue(), tmp); - } - } - Map dis = new HashMap<>(curDisMap);//Map - Map path = new HashMap<>(curPathMap);//Map> - property.put(KEY_FIELD, dis); - property.put(PATH, path); - this.context.setNewVertexValue(property); - } - } - } + private final Set sourceId; + private final EdgeDirection edgeType; - private void sendMessage(Integer id, Integer distance, List path) { - switch (edgeType) { - case IN: - for (IEdge> edge : this.context.edges() - .getInEdges()) { - this.context.sendMessage(edge.getTargetId(), new Triple<>(id, distance + edge.getValue().get(KEY_FIELD), path)); - } - break; - case OUT: - for (IEdge> edge : this.context.edges() - .getOutEdges()) { - this.context.sendMessage(edge.getTargetId(), new Triple<>(id, distance + edge.getValue().get(KEY_FIELD), path)); - } - break; - default: - break; + public ShortestPathOfVertexSetVCFunction(Set sourceId, EdgeDirection edgeType) { + this.sourceId = sourceId; + this.edgeType = edgeType; + } + + @Override + public void compute( + Integer vertexId, Iterator>> messageIterator) { + IVertex>> vertex = this.context.vertex().get(); + Map> property = vertex.getValue(); + if (this.context.getIterationId() == 1) { + Map dis = new HashMap<>(); // Map + Map path = new HashMap<>(); // Map> + for (Integer id : this.sourceId) { + if (vertex.getId().equals(id)) { + sendMessage(id, 0, new ArrayList<>(Collections.singletonList(vertex.getId()))); + dis.put(id, 0); + path.put(id, new ArrayList<>(Collections.singletonList(id))); + } else { + dis.put(id, Integer.MAX_VALUE); + } + } + property.put(KEY_FIELD, dis); + property.put(PATH, path); + this.context.setNewVertexValue(property); + } else { + Map newDisMap = new HashMap<>(2); + Map> newPathMap = new HashMap<>(2); + while (messageIterator.hasNext()) { + Triple> meg = messageIterator.next(); + if (meg.getF1() < newDisMap.getOrDefault(meg.getF0(), Integer.MAX_VALUE)) { + newDisMap.put(meg.getF0(), meg.getF1()); + newPathMap.put(meg.getF0(), meg.getF2()); + } + } + if (!newDisMap.isEmpty()) { + Map curDisMap = property.get(KEY_FIELD); + Map curPathMap = property.get(PATH); + for (Map.Entry kv : newDisMap.entrySet()) { + if (kv.getValue() < (Integer) curDisMap.getOrDefault(kv.getKey(), Integer.MAX_VALUE)) { + curDisMap.put(kv.getKey(), kv.getValue()); + List tmp = new ArrayList<>(newPathMap.get(kv.getKey())); + tmp.add(vertex.getId()); + curPathMap.put(kv.getKey(), tmp); + sendMessage(kv.getKey(), kv.getValue(), tmp); } + } + Map dis = new HashMap<>(curDisMap); // Map + Map path = new HashMap<>(curPathMap); // Map> + property.put(KEY_FIELD, dis); + property.put(PATH, path); + this.context.setNewVertexValue(property); } + } + } + private void sendMessage(Integer id, Integer distance, List path) { + switch (edgeType) { + case IN: + for (IEdge> edge : this.context.edges().getInEdges()) { + this.context.sendMessage( + edge.getTargetId(), + new Triple<>(id, distance + edge.getValue().get(KEY_FIELD), path)); + } + break; + case OUT: + for (IEdge> edge : this.context.edges().getOutEdges()) { + this.context.sendMessage( + edge.getTargetId(), + new Triple<>(id, distance + edge.getValue().get(KEY_FIELD), path)); + } + break; + default: + break; + } } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/sssp/SSSP.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/sssp/SSSP.java index 678f5d710..d8d5a7e79 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/sssp/SSSP.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/sssp/SSSP.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.Collections; import java.util.Iterator; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.PGraphWindow; import org.apache.geaflow.api.graph.compute.VertexCentricCompute; @@ -55,115 +56,129 @@ public class SSSP { - private static final Logger LOGGER = LoggerFactory.getLogger(SSSP.class); - - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/sssp"; - public static final String REF_FILE_PATH = "data/reference/sssp"; - - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - IPipelineResult result = submit(environment); - PipelineResultCollect.get(result); - environment.shutdown(); + private static final Logger LOGGER = LoggerFactory.getLogger(SSSP.class); + + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/sssp"; + public static final String REF_FILE_PATH = "data/reference/sssp"; + + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + IPipelineResult result = submit(environment); + PipelineResultCollect.get(result); + environment.shutdown(); + } + + public static IPipelineResult submit(Environment environment) { + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + ResultValidator.cleanResult(RESULT_FILE_PATH); + + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + Configuration conf = pipelineTaskCxt.getConfig(); + PWindowSource> prVertices = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_vertex", + line -> { + String[] fields = line.split(","); + IVertex vertex = + new ValueVertex<>( + Integer.valueOf(fields[0]), Integer.valueOf(fields[1])); + return Collections.singletonList(vertex); + }), + AllWindow.getInstance()) + .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); + + PWindowSource> prEdges = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_edge", + line -> { + String[] fields = line.split(","); + IEdge edge = + new ValueEdge<>( + Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); + return Collections.singletonList(edge); + }), + AllWindow.getInstance()) + .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); + + int iterationParallelism = conf.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + .withShardNum(iterationParallelism) + .withBackend(BackendType.Memory) + .build(); + PGraphWindow graphWindow = + pipelineTaskCxt.buildWindowStreamGraph(prVertices, prEdges, graphViewDesc); + SinkFunction> sink = + ExampleSinkFunctionFactory.getSinkFunction(conf); + graphWindow + .compute(new SSSPAlgorithm(1, 10)) + .compute(iterationParallelism) + .getVertices() + .sink(sink) + .withParallelism(conf.getInteger(ExampleConfigKeys.SINK_PARALLELISM)); + }); + + return pipeline.execute(); + } + + public static void validateResult() throws IOException { + ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH); + } + + public static class SSSPAlgorithm + extends VertexCentricCompute { + + private final int srcId; + + public SSSPAlgorithm(int srcId, long iterations) { + super(iterations); + this.srcId = srcId; } - public static IPipelineResult submit(Environment environment) { - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - ResultValidator.cleanResult(RESULT_FILE_PATH); - - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - Configuration conf = pipelineTaskCxt.getConfig(); - PWindowSource> prVertices = - pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_vertex", - line -> { - String[] fields = line.split(","); - IVertex vertex = new ValueVertex<>( - Integer.valueOf(fields[0]), Integer.valueOf(fields[1])); - return Collections.singletonList(vertex); - }), AllWindow.getInstance()) - .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); - - PWindowSource> prEdges = pipelineTaskCxt.buildSource(new FileSource<>( - "data/input/email_edge", line -> { - String[] fields = line.split(","); - IEdge edge = new ValueEdge<>(Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); - return Collections.singletonList(edge); - }), AllWindow.getInstance()) - .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); - - int iterationParallelism = conf.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) - .withShardNum(iterationParallelism) - .withBackend(BackendType.Memory) - .build(); - PGraphWindow graphWindow = - pipelineTaskCxt.buildWindowStreamGraph(prVertices, prEdges, graphViewDesc); - SinkFunction> sink = ExampleSinkFunctionFactory.getSinkFunction(conf); - graphWindow.compute(new SSSPAlgorithm(1, 10)) - .compute(iterationParallelism) - .getVertices() - .sink(sink) - .withParallelism(conf.getInteger(ExampleConfigKeys.SINK_PARALLELISM)); - - }); - - return pipeline.execute(); + @Override + public VertexCentricComputeFunction getComputeFunction() { + return new SSSP.SSSPVertexCentricComputeFunction(srcId); } - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH); + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; } + } - public static class SSSPAlgorithm extends VertexCentricCompute { - - private final int srcId; + public static class SSSPVertexCentricComputeFunction + extends AbstractVcFunc { - public SSSPAlgorithm(int srcId, long iterations) { - super(iterations); - this.srcId = srcId; - } - - @Override - public VertexCentricComputeFunction getComputeFunction() { - return new SSSP.SSSPVertexCentricComputeFunction(srcId); - } - - @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; - } + private final int srcId; + public SSSPVertexCentricComputeFunction(int srcId) { + this.srcId = srcId; } - public static class SSSPVertexCentricComputeFunction extends AbstractVcFunc { - - private final int srcId; - - public SSSPVertexCentricComputeFunction(int srcId) { - this.srcId = srcId; + @Override + public void compute(Integer vertexId, Iterator messageIterator) { + IVertex vertex = this.context.vertex().get(); + int minDistance = vertex.getId() == srcId ? 0 : Integer.MAX_VALUE; + if (messageIterator != null) { + while (messageIterator.hasNext()) { + Integer value = messageIterator.next(); + minDistance = Math.min(minDistance, value); } - - @Override - public void compute(Integer vertexId, - Iterator messageIterator) { - IVertex vertex = this.context.vertex().get(); - int minDistance = vertex.getId() == srcId ? 0 : Integer.MAX_VALUE; - if (messageIterator != null) { - while (messageIterator.hasNext()) { - Integer value = messageIterator.next(); - minDistance = Math.min(minDistance, value); - } - } - if (minDistance < vertex.getValue()) { - this.context.setNewVertexValue(minDistance); - for (IEdge edge : this.context.edges().getOutEdges()) { - this.context.sendMessage(edge.getTargetId(), minDistance + edge.getValue()); - } - } + } + if (minDistance < vertex.getValue()) { + this.context.setNewVertexValue(minDistance); + for (IEdge edge : this.context.edges().getOutEdges()) { + this.context.sendMessage(edge.getTargetId(), minDistance + edge.getValue()); } - + } } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/weakconnectedcomponents/WeakConnectedComponents.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/weakconnectedcomponents/WeakConnectedComponents.java index 68137543c..f5f21fbd8 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/weakconnectedcomponents/WeakConnectedComponents.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/compute/weakconnectedcomponents/WeakConnectedComponents.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Iterator; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.compute.VertexCentricCompute; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -53,117 +54,125 @@ public class WeakConnectedComponents { - private static final Logger LOGGER = LoggerFactory.getLogger(WeakConnectedComponents.class); - - public static final String REFERENCE_FILE_PATH = "data/reference/wcc"; - - public static final String RESULT_FILE_DIR = "./target/tmp/data/result/wcc"; - - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - IPipelineResult result = submit(environment); - PipelineResultCollect.get(result); - environment.shutdown(); + private static final Logger LOGGER = LoggerFactory.getLogger(WeakConnectedComponents.class); + + public static final String REFERENCE_FILE_PATH = "data/reference/wcc"; + + public static final String RESULT_FILE_DIR = "./target/tmp/data/result/wcc"; + + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + IPipelineResult result = submit(environment); + PipelineResultCollect.get(result); + environment.shutdown(); + } + + public static IPipelineResult submit(Environment environment) { + ResultValidator.cleanResult(RESULT_FILE_DIR); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); + + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + Configuration config = pipelineTaskCxt.getConfig(); + int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); + int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); + int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); + LOGGER.info( + "with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); + + FileSource> vSource = + new FileSource<>( + GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserInteger); + PWindowStream> vertices = + pipelineTaskCxt + .buildSource(vSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + FileSource> eSource = + new FileSource<>(GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserInteger); + PWindowStream> edges = + pipelineTaskCxt + .buildSource(eSource, AllWindow.getInstance()) + .withParallelism(sourceParallelism); + + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + .withShardNum(iterationParallelism) + .withBackend(BackendType.Memory) + .build(); + PWindowStream> result = + pipelineTaskCxt + .buildWindowStreamGraph(vertices, edges, graphViewDesc) + .compute(new WeakConnectedComponentsAlgorithm(50)) + .compute(iterationParallelism) + .getVertices(); + + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); + result + .map(v -> String.format("%s,%s", v.getId(), v.getValue())) + .sink(sink) + .withParallelism(sinkParallelism); + }); + + return pipeline.execute(); + } + + public static void validateResult() throws IOException { + ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + } + + public static class WeakConnectedComponentsAlgorithm + extends VertexCentricCompute { + + public WeakConnectedComponentsAlgorithm(long iterations) { + super(iterations); } - public static IPipelineResult submit(Environment environment) { - ResultValidator.cleanResult(RESULT_FILE_DIR); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_DIR); - - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - Configuration config = pipelineTaskCxt.getConfig(); - int sourceParallelism = config.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); - int iterationParallelism = config.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); - int sinkParallelism = config.getInteger(ExampleConfigKeys.SINK_PARALLELISM); - LOGGER.info("with {} {} {}", sourceParallelism, iterationParallelism, sinkParallelism); - - FileSource> vSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::vertexParserInteger); - PWindowStream> vertices = - pipelineTaskCxt.buildSource(vSource, AllWindow.getInstance()) - .withParallelism(sourceParallelism); - - FileSource> eSource = new FileSource<>( - GraphDataSet.DATASET_FILE, VertexEdgeParser::edgeParserInteger); - PWindowStream> edges = pipelineTaskCxt.buildSource(eSource, - AllWindow.getInstance()) - .withParallelism(sourceParallelism); - - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) - .withShardNum(iterationParallelism) - .withBackend(BackendType.Memory) - .build(); - PWindowStream> result = - pipelineTaskCxt.buildWindowStreamGraph(vertices, edges, graphViewDesc) - .compute(new WeakConnectedComponentsAlgorithm(50)) - .compute(iterationParallelism) - .getVertices(); - - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); - result.map(v -> String.format("%s,%s", v.getId(), v.getValue())) - .sink(sink).withParallelism(sinkParallelism); - }); - - return pipeline.execute(); + @Override + public VertexCentricComputeFunction getComputeFunction() { + return new WeakConnectedComponentsVCCFunction(); } - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_DIR); + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; } - - public static class WeakConnectedComponentsAlgorithm extends VertexCentricCompute { - - public WeakConnectedComponentsAlgorithm(long iterations) { - super(iterations); + } + + public static class WeakConnectedComponentsVCCFunction + extends AbstractVcFunc { + + @Override + public void compute(Integer vertexId, Iterator messageIterator) { + IVertex vertex = this.context.vertex().get(); + if (this.context.getIterationId() == 1) { + this.context.setNewVertexValue(vertex.getId()); + sendMessage(vertex.getId()); + } else { + Integer newComponent = messageIterator.next(); + while (messageIterator.hasNext()) { + Integer tmp = messageIterator.next(); + if (newComponent > tmp) { + newComponent = tmp; + } } - - @Override - public VertexCentricComputeFunction getComputeFunction() { - return new WeakConnectedComponentsVCCFunction(); + if (vertex.getValue() > newComponent) { + this.context.setNewVertexValue(newComponent); + sendMessage(newComponent); } - - @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; - } - + } } - public static class WeakConnectedComponentsVCCFunction extends AbstractVcFunc { - - @Override - public void compute(Integer vertexId, Iterator messageIterator) { - IVertex vertex = this.context.vertex().get(); - if (this.context.getIterationId() == 1) { - this.context.setNewVertexValue(vertex.getId()); - sendMessage(vertex.getId()); - } else { - Integer newComponent = messageIterator.next(); - while (messageIterator.hasNext()) { - Integer tmp = messageIterator.next(); - if (newComponent > tmp) { - newComponent = tmp; - } - } - if (vertex.getValue() > newComponent) { - this.context.setNewVertexValue(newComponent); - sendMessage(newComponent); - } - } - } - - private void sendMessage(Integer meg) { - for (IEdge edge : this.context.edges().getInEdges()) { - this.context.sendMessage(edge.getTargetId(), meg); - } - for (IEdge edge : this.context.edges().getOutEdges()) { - this.context.sendMessage(edge.getTargetId(), meg); - } - } - + private void sendMessage(Integer meg) { + for (IEdge edge : this.context.edges().getInEdges()) { + this.context.sendMessage(edge.getTargetId(), meg); + } + for (IEdge edge : this.context.edges().getOutEdges()) { + this.context.sendMessage(edge.getTargetId(), meg); + } } - + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphAggTraversalAlgorithm.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphAggTraversalAlgorithm.java index a70bae013..9063e0151 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphAggTraversalAlgorithm.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphAggTraversalAlgorithm.java @@ -20,6 +20,7 @@ package org.apache.geaflow.example.graph.statical.traversal; import java.util.Iterator; + import org.apache.geaflow.api.graph.function.vc.VertexCentricAggTraversalFunction; import org.apache.geaflow.api.graph.function.vc.VertexCentricAggregateFunction; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -29,149 +30,172 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class StaticGraphAggTraversalAlgorithm extends VertexCentricAggTraversal, Tuple, - Tuple, Integer> { +public class StaticGraphAggTraversalAlgorithm + extends VertexCentricAggTraversal< + Integer, + Double, + Integer, + Integer, + Integer, + Integer, + Tuple, + Tuple, + Tuple, + Integer> { + + private static final Logger LOGGER = + LoggerFactory.getLogger(StaticGraphAggTraversalAlgorithm.class); + + public StaticGraphAggTraversalAlgorithm(long iterations) { + super(iterations); + } + + @Override + public VertexCentricAggTraversalFunction< + Integer, Double, Integer, Integer, Integer, Integer, Integer> + getTraversalFunction() { + return new StaticGraphAggTraversalAlgorithm.TraversalFunction(); + } + + @Override + public VertexCentricAggregateFunction< + Integer, + Tuple, + Tuple, + Tuple, + Integer> + getAggregateFunction() { + return new VertexCentricAggregateFunction< + Integer, + Tuple, + Tuple, + Tuple, + Integer>() { + @Override + public IPartialGraphAggFunction, Tuple> + getPartialAggregation() { + return new IPartialGraphAggFunction< + Integer, Tuple, Tuple>() { + + private IPartialAggContext> partialAggContext; + + @Override + public Tuple create( + IPartialAggContext> partialAggContext) { + this.partialAggContext = partialAggContext; + return Tuple.of(0, 0); + } + + @Override + public Tuple aggregate( + Integer integer, Tuple result) { + result.f0 += 1; + result.f1 += integer; + return result; + } + + @Override + public void finish(Tuple result) { + partialAggContext.collect(result); + } + }; + } - private static final Logger LOGGER = LoggerFactory.getLogger(StaticGraphAggTraversalAlgorithm.class); + @Override + public IGraphAggregateFunction, Tuple, Integer> + getGlobalAggregation() { + return new IGraphAggregateFunction< + Tuple, Tuple, Integer>() { + private IGlobalGraphAggContext globalGraphAggContext; + + @Override + public Tuple create( + IGlobalGraphAggContext globalGraphAggContext) { + this.globalGraphAggContext = globalGraphAggContext; + return Tuple.of(0, 0); + } + + @Override + public Integer aggregate( + Tuple integerIntegerTuple2, + Tuple integerIntegerTuple22) { + integerIntegerTuple22.f0 += integerIntegerTuple2.f0; + integerIntegerTuple22.f1 += integerIntegerTuple2.f1; + return (int) (integerIntegerTuple22.f1 / integerIntegerTuple22.f0); + } + + @Override + public void finish(Integer value) { + long iterationId = this.globalGraphAggContext.getIteration(); + if (value > 0) { + LOGGER.info("current iterationId:{} value is {}, do terminate", iterationId, value); + this.globalGraphAggContext.terminate(); + } else { + LOGGER.info("current iterationId:{} value is {}, do broadcast", iterationId, value); + this.globalGraphAggContext.broadcast(value); + } + } + }; + } + }; + } + + @Override + public VertexCentricCombineFunction getCombineFunction() { + return new VertexCentricCombineFunction() { + @Override + public Integer combine(Integer oldMessage, Integer newMessage) { + return oldMessage + newMessage; + } + }; + } + + public class TraversalFunction + implements VertexCentricAggTraversalFunction< + Integer, Double, Integer, Integer, Integer, Integer, Integer> { + + private VertexCentricTraversalFuncContext + vertexCentricFuncContext; + private VertexCentricAggContext aggContext; - public StaticGraphAggTraversalAlgorithm(long iterations) { - super(iterations); + @Override + public void initContext(VertexCentricAggContext aggContext) { + this.aggContext = aggContext; } @Override - public VertexCentricAggTraversalFunction getTraversalFunction() { - return new StaticGraphAggTraversalAlgorithm.TraversalFunction(); + public void open( + VertexCentricTraversalFuncContext + vertexCentricFuncContext) { + this.vertexCentricFuncContext = vertexCentricFuncContext; } @Override - public VertexCentricAggregateFunction, Tuple, Tuple, Integer> getAggregateFunction() { - return new VertexCentricAggregateFunction, - Tuple, Tuple, Integer>() { - @Override - public IPartialGraphAggFunction, Tuple> getPartialAggregation() { - return new IPartialGraphAggFunction, Tuple>() { - - private IPartialAggContext> partialAggContext; - - @Override - public Tuple create( - IPartialAggContext> partialAggContext) { - this.partialAggContext = partialAggContext; - return Tuple.of(0, 0); - } - - - @Override - public Tuple aggregate(Integer integer, Tuple result) { - result.f0 += 1; - result.f1 += integer; - return result; - } - - @Override - public void finish(Tuple result) { - partialAggContext.collect(result); - } - }; - } + public void init(ITraversalRequest traversalRequest) { - @Override - public IGraphAggregateFunction, Tuple, - Integer> getGlobalAggregation() { - return new IGraphAggregateFunction, Tuple, Integer>() { - - private IGlobalGraphAggContext globalGraphAggContext; - - @Override - public Tuple create( - IGlobalGraphAggContext globalGraphAggContext) { - this.globalGraphAggContext = globalGraphAggContext; - return Tuple.of(0, 0); - } - - @Override - public Integer aggregate(Tuple integerIntegerTuple2, - Tuple integerIntegerTuple22) { - integerIntegerTuple22.f0 += integerIntegerTuple2.f0; - integerIntegerTuple22.f1 += integerIntegerTuple2.f1; - return (int) (integerIntegerTuple22.f1 / integerIntegerTuple22.f0); - } - - @Override - public void finish(Integer value) { - long iterationId = this.globalGraphAggContext.getIteration(); - if (value > 0) { - LOGGER.info("current iterationId:{} value is {}, do terminate", iterationId, value); - this.globalGraphAggContext.terminate(); - } else { - LOGGER.info("current iterationId:{} value is {}, do broadcast", iterationId, value); - this.globalGraphAggContext.broadcast(value); - } - } - }; - } - }; + int degreeSize = vertexCentricFuncContext.edges().getOutEdges().size(); + aggContext.aggregate(0); + vertexCentricFuncContext.sendMessageToNeighbors(degreeSize); } @Override - public VertexCentricCombineFunction getCombineFunction() { - return new VertexCentricCombineFunction() { - @Override - public Integer combine(Integer oldMessage, Integer newMessage) { - return oldMessage + newMessage; - } - }; + public void compute(Integer vertexId, Iterator messageIterator) { + + int sum = 0; + while (messageIterator.hasNext()) { + sum += messageIterator.next(); + } + aggContext.aggregate(sum); + this.vertexCentricFuncContext.takeResponse(new TraversalResponseExample(vertexId, sum)); + LOGGER.info("vertexId {} aggregate {}", vertexId, sum); + vertexCentricFuncContext.sendMessageToNeighbors(sum); } - public class TraversalFunction implements - VertexCentricAggTraversalFunction { - - private VertexCentricTraversalFuncContext vertexCentricFuncContext; - private VertexCentricAggContext aggContext; - - @Override - public void initContext(VertexCentricAggContext aggContext) { - this.aggContext = aggContext; - } - - @Override - public void open(VertexCentricTraversalFuncContext vertexCentricFuncContext) { - this.vertexCentricFuncContext = vertexCentricFuncContext; - } - - @Override - public void init(ITraversalRequest traversalRequest) { - - int degreeSize = vertexCentricFuncContext.edges().getOutEdges().size(); - aggContext.aggregate(0); - vertexCentricFuncContext.sendMessageToNeighbors(degreeSize); - } - - @Override - public void compute(Integer vertexId, Iterator messageIterator) { - - int sum = 0; - while (messageIterator.hasNext()) { - sum += messageIterator.next(); - } - aggContext.aggregate(sum); - this.vertexCentricFuncContext.takeResponse(new TraversalResponseExample(vertexId, sum)); - LOGGER.info("vertexId {} aggregate {}", vertexId, sum); - vertexCentricFuncContext.sendMessageToNeighbors(sum); - } - - @Override - public void finish() { - - } - - @Override - public void close() { + @Override + public void finish() {} - } - } -} \ No newline at end of file + @Override + public void close() {} + } +} diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphAggTraversalAllExample.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphAggTraversalAllExample.java index 18c8f0fa8..9bb50397d 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphAggTraversalAllExample.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphAggTraversalAllExample.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Collections; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.PGraphWindow; import org.apache.geaflow.api.pdata.stream.window.PWindowSource; @@ -47,61 +48,76 @@ public class StaticGraphAggTraversalAllExample { - public static final String REFERENCE_FILE_PATH = "data/reference/static_traversal_all_agg"; - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/static_traversal_all_agg"; - - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - StaticGraphAggTraversalAllExample.submit(environment); - } + public static final String REFERENCE_FILE_PATH = "data/reference/static_traversal_all_agg"; + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/static_traversal_all_agg"; - public static IPipelineResult submit(Environment environment) { - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - ResultValidator.cleanResult(RESULT_FILE_PATH); + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + StaticGraphAggTraversalAllExample.submit(environment); + } - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - Configuration conf = pipelineTaskCxt.getConfig(); - PWindowSource> prVertices = - pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_vertex", - line -> { - String[] fields = line.split(","); - IVertex vertex = new ValueVertex<>( - Integer.valueOf(fields[0]), Double.valueOf(fields[1])); - return Collections.singletonList(vertex); - }), AllWindow.getInstance()) - .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); + public static IPipelineResult submit(Environment environment) { + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + ResultValidator.cleanResult(RESULT_FILE_PATH); - PWindowSource> prEdges = pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_edge", - line -> { - String[] fields = line.split(","); - IEdge edge = new ValueEdge<>(Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); - return Collections.singletonList(edge); - }), AllWindow.getInstance()) - .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + Configuration conf = pipelineTaskCxt.getConfig(); + PWindowSource> prVertices = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_vertex", + line -> { + String[] fields = line.split(","); + IVertex vertex = + new ValueVertex<>( + Integer.valueOf(fields[0]), Double.valueOf(fields[1])); + return Collections.singletonList(vertex); + }), + AllWindow.getInstance()) + .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) - .withShardNum(2) - .withBackend(BackendType.Memory) - .build(); + PWindowSource> prEdges = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_edge", + line -> { + String[] fields = line.split(","); + IEdge edge = + new ValueEdge<>( + Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); + return Collections.singletonList(edge); + }), + AllWindow.getInstance()) + .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); - PGraphWindow graphWindow = - pipelineTaskCxt.buildWindowStreamGraph(prVertices, prEdges, graphViewDesc); + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + .withShardNum(2) + .withBackend(BackendType.Memory) + .build(); - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); - graphWindow.traversal(new StaticGraphAggTraversalAlgorithm(3)) - .start().withParallelism(1) - .map(x -> String.format("%s,%s", x.getResponseId(), x.getResponse())) - .sink(sink); - }); + PGraphWindow graphWindow = + pipelineTaskCxt.buildWindowStreamGraph(prVertices, prEdges, graphViewDesc); - return pipeline.execute(); - } + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); + graphWindow + .traversal(new StaticGraphAggTraversalAlgorithm(3)) + .start() + .withParallelism(1) + .map(x -> String.format("%s,%s", x.getResponseId(), x.getResponse())) + .sink(sink); + }); - public static void validateResult() throws IOException { - ResultValidator.validateResult(REFERENCE_FILE_PATH, RESULT_FILE_PATH); - } + return pipeline.execute(); + } + public static void validateResult() throws IOException { + ResultValidator.validateResult(REFERENCE_FILE_PATH, RESULT_FILE_PATH); + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphAggTraversalByIdExample.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphAggTraversalByIdExample.java index bc2b99a31..fd7a911c8 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphAggTraversalByIdExample.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphAggTraversalByIdExample.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Collections; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.PGraphWindow; import org.apache.geaflow.api.pdata.stream.window.PWindowSource; @@ -47,60 +48,76 @@ public class StaticGraphAggTraversalByIdExample { - public static final String REFERENCE_FILE_PATH = "data/reference/static_traversal_id_agg"; - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/static_traversal_id_agg"; + public static final String REFERENCE_FILE_PATH = "data/reference/static_traversal_id_agg"; + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/static_traversal_id_agg"; - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - StaticGraphAggTraversalByIdExample.submit(environment); - } + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + StaticGraphAggTraversalByIdExample.submit(environment); + } - public static IPipelineResult submit(Environment environment) { - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - ResultValidator.cleanResult(RESULT_FILE_PATH); + public static IPipelineResult submit(Environment environment) { + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + ResultValidator.cleanResult(RESULT_FILE_PATH); - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - Configuration conf = pipelineTaskCxt.getConfig(); - PWindowSource> prVertices = - pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_vertex", - line -> { - String[] fields = line.split(","); - IVertex vertex = new ValueVertex<>( - Integer.valueOf(fields[0]), Double.valueOf(fields[1])); - return Collections.singletonList(vertex); - }), AllWindow.getInstance()) - .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + Configuration conf = pipelineTaskCxt.getConfig(); + PWindowSource> prVertices = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_vertex", + line -> { + String[] fields = line.split(","); + IVertex vertex = + new ValueVertex<>( + Integer.valueOf(fields[0]), Double.valueOf(fields[1])); + return Collections.singletonList(vertex); + }), + AllWindow.getInstance()) + .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); - PWindowSource> prEdges = pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_edge", - line -> { - String[] fields = line.split(","); - IEdge edge = new ValueEdge<>(Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); - return Collections.singletonList(edge); - }), AllWindow.getInstance()) - .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); + PWindowSource> prEdges = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_edge", + line -> { + String[] fields = line.split(","); + IEdge edge = + new ValueEdge<>( + Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); + return Collections.singletonList(edge); + }), + AllWindow.getInstance()) + .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) - .withShardNum(2) - .withBackend(BackendType.Memory) - .build(); + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + .withShardNum(2) + .withBackend(BackendType.Memory) + .build(); - PGraphWindow graphWindow = - pipelineTaskCxt.buildWindowStreamGraph(prVertices, prEdges, graphViewDesc); + PGraphWindow graphWindow = + pipelineTaskCxt.buildWindowStreamGraph(prVertices, prEdges, graphViewDesc); - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); - graphWindow.traversal(new StaticGraphAggTraversalAlgorithm(3)) - .start(108).withParallelism(1) - .map(x -> String.format("%s,%s", x.getResponseId(), x.getResponse())) - .sink(sink); - }); + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); + graphWindow + .traversal(new StaticGraphAggTraversalAlgorithm(3)) + .start(108) + .withParallelism(1) + .map(x -> String.format("%s,%s", x.getResponseId(), x.getResponse())) + .sink(sink); + }); - return pipeline.execute(); - } + return pipeline.execute(); + } - public static void validateResult() throws IOException { - ResultValidator.validateResult(REFERENCE_FILE_PATH, RESULT_FILE_PATH); - } + public static void validateResult() throws IOException { + ResultValidator.validateResult(REFERENCE_FILE_PATH, RESULT_FILE_PATH); + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphTraversalAllExample.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphTraversalAllExample.java index d70d7f118..c78cd9f10 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphTraversalAllExample.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphTraversalAllExample.java @@ -23,6 +23,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.PGraphWindow; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; @@ -56,148 +57,160 @@ public class StaticGraphTraversalAllExample { - public static final String REFERENCE_FILE_PATH = "data/reference/static_graph_traversal_all"; - - public static final String RESULT_FILE_PATH = "./target/tmp/data/result" - + "/static_graph_traversal_all"; - - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - IPipelineResult result = submit(environment); - PipelineResultCollect.get(result); - environment.shutdown(); - } - - public static IPipelineResult submit(Environment environment) { - - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - ResultValidator.cleanResult(RESULT_FILE_PATH); - - pipeline.submit(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - PWindowSource> prVertices = - pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_vertex", - line -> { - String[] fields = line.split(","); - IVertex vertex = new ValueVertex<>(Integer.valueOf(fields[0]), - Integer.valueOf(fields[1])); - return Collections.singletonList(vertex); - }), AllWindow.getInstance()).withParallelism(1); - - PWindowSource> prEdges = - pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_edge", - line -> { - String[] fields = line.split(","); - IEdge edge = new ValueEdge<>(Integer.valueOf(fields[0]), - Integer.valueOf(fields[1]), 1); - return Collections.singletonList(edge); - }), AllWindow.getInstance()).withParallelism(1); - - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + public static final String REFERENCE_FILE_PATH = "data/reference/static_graph_traversal_all"; + + public static final String RESULT_FILE_PATH = + "./target/tmp/data/result" + "/static_graph_traversal_all"; + + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + IPipelineResult result = submit(environment); + PipelineResultCollect.get(result); + environment.shutdown(); + } + + public static IPipelineResult submit(Environment environment) { + + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + ResultValidator.cleanResult(RESULT_FILE_PATH); + + pipeline.submit( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { + PWindowSource> prVertices = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_vertex", + line -> { + String[] fields = line.split(","); + IVertex vertex = + new ValueVertex<>( + Integer.valueOf(fields[0]), Integer.valueOf(fields[1])); + return Collections.singletonList(vertex); + }), + AllWindow.getInstance()) + .withParallelism(1); + + PWindowSource> prEdges = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_edge", + line -> { + String[] fields = line.split(","); + IEdge edge = + new ValueEdge<>( + Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); + return Collections.singletonList(edge); + }), + AllWindow.getInstance()) + .withParallelism(1); + + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) .withShardNum(1) .withBackend(BackendType.Memory) .build(); - PGraphWindow graphWindow = - pipelineTaskCxt.buildWindowStreamGraph(prVertices, prEdges, graphViewDesc); - - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction( - pipelineTaskCxt.getConfig()); - - graphWindow.traversal(new VertexCentricTraversal(3) { - @Override - public VertexCentricTraversalFunction getTraversalFunction() { - return new VertexCentricTraversalFunction() { - - private VertexCentricTraversalFuncContext vertexCentricFuncContext; - - @Override - public void open( - VertexCentricTraversalFuncContext vertexCentricFuncContext) { - this.vertexCentricFuncContext = vertexCentricFuncContext; - } - - @Override - public void init(ITraversalRequest traversalRequest) { - List> outEdges = - this.vertexCentricFuncContext.edges().getOutEdges(); - int sum = outEdges.size(); - this.vertexCentricFuncContext.takeResponse( - new TraversalResponse(traversalRequest.getRequestId(), sum)); - } - - @Override - public void compute(Integer vertexId, - Iterator messageIterator) { - - } - - @Override - public void finish() { - - } - - @Override - public void close() { - - } + PGraphWindow graphWindow = + pipelineTaskCxt.buildWindowStreamGraph(prVertices, prEdges, graphViewDesc); + + SinkFunction sink = + ExampleSinkFunctionFactory.getSinkFunction(pipelineTaskCxt.getConfig()); + + graphWindow + .traversal( + new VertexCentricTraversal(3) { + @Override + public VertexCentricTraversalFunction< + Integer, Integer, Integer, Integer, Integer> + getTraversalFunction() { + return new VertexCentricTraversalFunction< + Integer, Integer, Integer, Integer, Integer>() { + + private VertexCentricTraversalFuncContext< + Integer, Integer, Integer, Integer, Integer> + vertexCentricFuncContext; + + @Override + public void open( + VertexCentricTraversalFuncContext< + Integer, Integer, Integer, Integer, Integer> + vertexCentricFuncContext) { + this.vertexCentricFuncContext = vertexCentricFuncContext; + } + + @Override + public void init(ITraversalRequest traversalRequest) { + List> outEdges = + this.vertexCentricFuncContext.edges().getOutEdges(); + int sum = outEdges.size(); + this.vertexCentricFuncContext.takeResponse( + new TraversalResponse(traversalRequest.getRequestId(), sum)); + } + + @Override + public void compute( + Integer vertexId, Iterator messageIterator) {} + + @Override + public void finish() {} + + @Override + public void close() {} }; - } + } - @Override - public VertexCentricCombineFunction getCombineFunction() { + @Override + public VertexCentricCombineFunction getCombineFunction() { return null; - } - - }).start().withParallelism(1).map(x -> x.toString()).sink(sink); - - } + } + }) + .start() + .withParallelism(1) + .map(x -> x.toString()) + .sink(sink); + } }); - return pipeline.execute(); - } - - public static class TraversalResponse implements ITraversalResponse { + return pipeline.execute(); + } - private long responseId; - private int response; + public static class TraversalResponse implements ITraversalResponse { - public TraversalResponse(long responseId, int response) { - this.responseId = responseId; - this.response = response; - } + private long responseId; + private int response; - @Override - public long getResponseId() { - return responseId; - } + public TraversalResponse(long responseId, int response) { + this.responseId = responseId; + this.response = response; + } - @Override - public Integer getResponse() { - return response; - } + @Override + public long getResponseId() { + return responseId; + } - @Override - public ResponseType getType() { - return ResponseType.Vertex; - } + @Override + public Integer getResponse() { + return response; + } - @Override - public String toString() { - return "TraversalResponse{" + "responseId=" + responseId + ", response=" + response - + '}'; - } + @Override + public ResponseType getType() { + return ResponseType.Vertex; } - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_PATH, String::compareTo); + @Override + public String toString() { + return "TraversalResponse{" + "responseId=" + responseId + ", response=" + response + '}'; } + } + public static void validateResult() throws IOException { + ResultValidator.validateMapResult(REFERENCE_FILE_PATH, RESULT_FILE_PATH, String::compareTo); + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphTraversalExample.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphTraversalExample.java index c7a2b63c9..5d323d8f0 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphTraversalExample.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphTraversalExample.java @@ -23,6 +23,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; + import org.apache.geaflow.api.graph.PGraphWindow; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; import org.apache.geaflow.api.graph.function.vc.VertexCentricTraversalFunction; @@ -53,137 +54,151 @@ public class StaticGraphTraversalExample { - private static final Logger LOGGER = LoggerFactory.getLogger(StaticGraphTraversalExample.class); - - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - IPipelineResult result = submit(environment); - PipelineResultCollect.get(result); - environment.shutdown(); - } - - public static IPipelineResult submit(Environment environment) { - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - - pipeline.submit(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - PWindowSource> prVertices = - pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_vertex", - line -> { - String[] fields = line.split(","); - IVertex vertex = new ValueVertex<>(Integer.valueOf(fields[0]), Integer.valueOf(fields[1])); - return Collections.singletonList(vertex); - }), AllWindow.getInstance()).withParallelism(1); - - PWindowSource> prEdges = - pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_edge", - line -> { - String[] fields = line.split(","); - IEdge edge = new ValueEdge<>(Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); - return Collections.singletonList(edge); - }), AllWindow.getInstance()).withParallelism(1); - - int iterationParallelism = 2; - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + private static final Logger LOGGER = LoggerFactory.getLogger(StaticGraphTraversalExample.class); + + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + IPipelineResult result = submit(environment); + PipelineResultCollect.get(result); + environment.shutdown(); + } + + public static IPipelineResult submit(Environment environment) { + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + + pipeline.submit( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { + PWindowSource> prVertices = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_vertex", + line -> { + String[] fields = line.split(","); + IVertex vertex = + new ValueVertex<>( + Integer.valueOf(fields[0]), Integer.valueOf(fields[1])); + return Collections.singletonList(vertex); + }), + AllWindow.getInstance()) + .withParallelism(1); + + PWindowSource> prEdges = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_edge", + line -> { + String[] fields = line.split(","); + IEdge edge = + new ValueEdge<>( + Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); + return Collections.singletonList(edge); + }), + AllWindow.getInstance()) + .withParallelism(1); + + int iterationParallelism = 2; + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) .withShardNum(iterationParallelism) .withBackend(BackendType.Memory) .build(); - PGraphWindow graphWindow = - pipelineTaskCxt.buildWindowStreamGraph(prVertices, prEdges, graphViewDesc); - - graphWindow.traversal(new VertexCentricTraversal(3) { - @Override - public VertexCentricTraversalFunction getTraversalFunction() { - return new VertexCentricTraversalFunction() { - - private VertexCentricTraversalFuncContext vertexCentricFuncContext; - - @Override - public void open( - VertexCentricTraversalFuncContext vertexCentricFuncContext) { - this.vertexCentricFuncContext = vertexCentricFuncContext; - } - - @Override - public void init(ITraversalRequest traversalRequest) { - List> outEdges = - this.vertexCentricFuncContext.edges().getOutEdges(); - for (IEdge edge : outEdges) { - LOGGER.info("out edge:{}", edge); - this.vertexCentricFuncContext.takeResponse(new TraversalResponse(traversalRequest.getRequestId(), edge.getTargetId())); - } + PGraphWindow graphWindow = + pipelineTaskCxt.buildWindowStreamGraph(prVertices, prEdges, graphViewDesc); + + graphWindow + .traversal( + new VertexCentricTraversal(3) { + @Override + public VertexCentricTraversalFunction< + Integer, Integer, Integer, Integer, Integer> + getTraversalFunction() { + return new VertexCentricTraversalFunction< + Integer, Integer, Integer, Integer, Integer>() { + + private VertexCentricTraversalFuncContext< + Integer, Integer, Integer, Integer, Integer> + vertexCentricFuncContext; + + @Override + public void open( + VertexCentricTraversalFuncContext< + Integer, Integer, Integer, Integer, Integer> + vertexCentricFuncContext) { + this.vertexCentricFuncContext = vertexCentricFuncContext; + } + + @Override + public void init(ITraversalRequest traversalRequest) { + List> outEdges = + this.vertexCentricFuncContext.edges().getOutEdges(); + for (IEdge edge : outEdges) { + LOGGER.info("out edge:{}", edge); + this.vertexCentricFuncContext.takeResponse( + new TraversalResponse( + traversalRequest.getRequestId(), edge.getTargetId())); } + } - @Override - public void compute(Integer vertexId, - Iterator messageIterator) { - - } + @Override + public void compute( + Integer vertexId, Iterator messageIterator) {} - @Override - public void finish() { - - } + @Override + public void finish() {} - @Override - public void close() { - - } + @Override + public void close() {} }; - } + } - @Override - public VertexCentricCombineFunction getCombineFunction() { + @Override + public VertexCentricCombineFunction getCombineFunction() { return null; - } - - }).start(300).withParallelism(iterationParallelism).sink(x -> LOGGER.info("x:{}", x)); - - } + } + }) + .start(300) + .withParallelism(iterationParallelism) + .sink(x -> LOGGER.info("x:{}", x)); + } }); - return pipeline.execute(); - } + return pipeline.execute(); + } - public static class TraversalResponse implements ITraversalResponse { + public static class TraversalResponse implements ITraversalResponse { - private long responseId; - private int response; + private long responseId; + private int response; - public TraversalResponse(long responseId, int response) { - this.responseId = responseId; - this.response = response; - } - - @Override - public long getResponseId() { - return responseId; - } + public TraversalResponse(long responseId, int response) { + this.responseId = responseId; + this.response = response; + } - @Override - public Integer getResponse() { - return response; - } + @Override + public long getResponseId() { + return responseId; + } - @Override - public ResponseType getType() { - return ResponseType.Vertex; - } + @Override + public Integer getResponse() { + return response; + } - @Override - public String toString() { - return "TraversalResponse{" + "responseId=" + responseId + ", response=" + response - + '}'; - } + @Override + public ResponseType getType() { + return ResponseType.Vertex; } - public static void validateResult() throws IOException { + @Override + public String toString() { + return "TraversalResponse{" + "responseId=" + responseId + ", response=" + response + '}'; } + } + public static void validateResult() throws IOException {} } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/TraversalResponseExample.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/TraversalResponseExample.java index dcf727e20..d88d1c1d7 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/TraversalResponseExample.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/graph/statical/traversal/TraversalResponseExample.java @@ -24,32 +24,31 @@ public class TraversalResponseExample implements ITraversalResponse { - private long responseId; - private T response; - - public TraversalResponseExample(long responseId, T response) { - this.responseId = responseId; - this.response = response; - } - - @Override - public long getResponseId() { - return responseId; - } - - @Override - public T getResponse() { - return response; - } - - @Override - public TraversalType.ResponseType getType() { - return TraversalType.ResponseType.Vertex; - } - - @Override - public String toString() { - return "TraversalResponse{" + "responseId=" + responseId + ", response=" + response - + '}'; - } + private long responseId; + private T response; + + public TraversalResponseExample(long responseId, T response) { + this.responseId = responseId; + this.response = response; + } + + @Override + public long getResponseId() { + return responseId; + } + + @Override + public T getResponse() { + return response; + } + + @Override + public TraversalType.ResponseType getType() { + return TraversalType.ResponseType.Vertex; + } + + @Override + public String toString() { + return "TraversalResponse{" + "responseId=" + responseId + ", response=" + response + '}'; + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/k8s/UnBoundedStreamWordCount.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/k8s/UnBoundedStreamWordCount.java index 740177959..7d202958a 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/k8s/UnBoundedStreamWordCount.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/k8s/UnBoundedStreamWordCount.java @@ -27,12 +27,12 @@ public class UnBoundedStreamWordCount { - private static final Logger LOGGER = LoggerFactory.getLogger(UnBoundedStreamWordCount.class); + private static final Logger LOGGER = LoggerFactory.getLogger(UnBoundedStreamWordCount.class); - public static void main(String[] args) { - Environment environment = EnvironmentFactory.onK8SEnvironment(args); + public static void main(String[] args) { + Environment environment = EnvironmentFactory.onK8SEnvironment(args); - StreamWordCountPipeline pipeline = new StreamWordCountPipeline(); - pipeline.submit(environment); - } + StreamWordCountPipeline pipeline = new StreamWordCountPipeline(); + pipeline.submit(environment); + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/service/QueryService.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/service/QueryService.java index 019c94a25..266ee3fed 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/service/QueryService.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/service/QueryService.java @@ -19,7 +19,6 @@ package org.apache.geaflow.example.service; -import com.google.common.base.Preconditions; import org.apache.geaflow.analytics.service.config.AnalyticsServiceConfigKeys; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.dsl.runtime.QueryClient; @@ -35,43 +34,51 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; + public class QueryService { - private static final Logger LOGGER = LoggerFactory.getLogger(QueryService.class); + private static final Logger LOGGER = LoggerFactory.getLogger(QueryService.class); - private static final String WARM_UP_PATTERN = "USE GRAPH %s ; MATCH (a) RETURN a limit 0"; + private static final String WARM_UP_PATTERN = "USE GRAPH %s ; MATCH (a) RETURN a limit 0"; - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - Configuration configuration = environment.getEnvironmentContext().getConfig(); + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + Configuration configuration = environment.getEnvironmentContext().getConfig(); - String graphViewName = configuration.getConfigMap().get("geaflow.analytics.graph.view.name"); - Preconditions.checkNotNull(graphViewName, "graph view name is null"); - configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_QUERY, String.format(WARM_UP_PATTERN, graphViewName)); - submit(environment); - LOGGER.info("query service start finish"); - synchronized (QueryService.class) { - try { - QueryService.class.wait(); - } catch (Throwable e) { - LOGGER.error("wait server failed"); - } - } + String graphViewName = configuration.getConfigMap().get("geaflow.analytics.graph.view.name"); + Preconditions.checkNotNull(graphViewName, "graph view name is null"); + configuration.put( + AnalyticsServiceConfigKeys.ANALYTICS_QUERY, String.format(WARM_UP_PATTERN, graphViewName)); + submit(environment); + LOGGER.info("query service start finish"); + synchronized (QueryService.class) { + try { + QueryService.class.wait(); + } catch (Throwable e) { + LOGGER.error("wait server failed"); + } } + } - public static IPipelineResult submit(Environment environment) { - Configuration configuration = environment.getEnvironmentContext().getConfig(); - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.start((PipelineService) pipelineServiceContext -> { - QueryClient queryManager = new QueryClient(); - QueryEngine engineContext = new GeaFlowQueryEngine(pipelineServiceContext); - QueryContext queryContext = QueryContext.builder() - .setEngineContext(engineContext) - .setTraversalParallelism(configuration.getInteger(AnalyticsServiceConfigKeys.ANALYTICS_QUERY_PARALLELISM)) - .setCompile(false) - .build(); - queryManager.executeQuery((String) pipelineServiceContext.getRequest(), queryContext); - }); - return pipeline.execute(); - } -} \ No newline at end of file + public static IPipelineResult submit(Environment environment) { + Configuration configuration = environment.getEnvironmentContext().getConfig(); + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.start( + (PipelineService) + pipelineServiceContext -> { + QueryClient queryManager = new QueryClient(); + QueryEngine engineContext = new GeaFlowQueryEngine(pipelineServiceContext); + QueryContext queryContext = + QueryContext.builder() + .setEngineContext(engineContext) + .setTraversalParallelism( + configuration.getInteger( + AnalyticsServiceConfigKeys.ANALYTICS_QUERY_PARALLELISM)) + .setCompile(false) + .build(); + queryManager.executeQuery((String) pipelineServiceContext.getRequest(), queryContext); + }); + return pipeline.execute(); + } +} diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/service/WordLengthService.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/service/WordLengthService.java index c83ef9cbc..88fbc2cc4 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/service/WordLengthService.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/service/WordLengthService.java @@ -19,7 +19,6 @@ package org.apache.geaflow.example.service; -import com.google.common.collect.Lists; import org.apache.geaflow.api.function.internal.CollectionSource; import org.apache.geaflow.api.pdata.PWindowCollect; import org.apache.geaflow.api.pdata.stream.window.PWindowSource; @@ -32,22 +31,28 @@ import org.apache.geaflow.pipeline.service.IPipelineServiceContext; import org.apache.geaflow.pipeline.service.PipelineService; +import com.google.common.collect.Lists; + public class WordLengthService { - public IPipelineResult submit(Environment environment) { - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.start(new PipelineService() { - @Override - public void execute(IPipelineServiceContext pipelineServiceContext) { - int sourceParallelism = pipelineServiceContext.getConfig().getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); - String word = (String) pipelineServiceContext.getRequest(); - PWindowSource windowSource = pipelineServiceContext - .buildSource(new CollectionSource<>(Lists.newArrayList(word)), AllWindow.getInstance()) + public IPipelineResult submit(Environment environment) { + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.start( + new PipelineService() { + @Override + public void execute(IPipelineServiceContext pipelineServiceContext) { + int sourceParallelism = + pipelineServiceContext.getConfig().getInteger(ExampleConfigKeys.SOURCE_PARALLELISM); + String word = (String) pipelineServiceContext.getRequest(); + PWindowSource windowSource = + pipelineServiceContext + .buildSource( + new CollectionSource<>(Lists.newArrayList(word)), AllWindow.getInstance()) .withParallelism(sourceParallelism); - PWindowCollect collect = windowSource.map(String::length).collect(); - pipelineServiceContext.response(collect); - } + PWindowCollect collect = windowSource.map(String::length).collect(); + pipelineServiceContext.response(collect); + } }); - return pipeline.execute(); - } + return pipeline.execute(); + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamAggPipeline.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamAggPipeline.java index fed995aca..b1b48bd74 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamAggPipeline.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamAggPipeline.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.io.Serializable; import java.util.Collections; + import org.apache.geaflow.api.collector.Collector; import org.apache.geaflow.api.function.base.AggregateFunction; import org.apache.geaflow.api.function.base.FlatMapFunction; @@ -42,76 +43,77 @@ public class StreamAggPipeline implements Serializable { - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/agg"; - public static final String REF_FILE_PATH = "data/reference/agg"; - public static final String SPLIT = ","; - - public IPipelineResult submit(Environment environment) { - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - ResultValidator.cleanResult(RESULT_FILE_PATH); - pipeline.submit(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - Configuration conf = pipelineTaskCxt.getConfig(); - PWindowSource streamSource = - pipelineTaskCxt.buildSource(new FileSource("data/input" - + "/email_edge", - Collections::singletonList) { - }, SizeTumblingWindow.of(5000)); - - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); - streamSource - .flatMap(new FlatMapFunction() { - @Override - public void flatMap(String value, Collector collector) { - String[] records = value.split(SPLIT); - for (String record : records) { - collector.partition(Long.valueOf(record)); - } + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/agg"; + public static final String REF_FILE_PATH = "data/reference/agg"; + public static final String SPLIT = ","; + + public IPipelineResult submit(Environment environment) { + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + ResultValidator.cleanResult(RESULT_FILE_PATH); + pipeline.submit( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { + Configuration conf = pipelineTaskCxt.getConfig(); + PWindowSource streamSource = + pipelineTaskCxt.buildSource( + new FileSource( + "data/input" + "/email_edge", Collections::singletonList) {}, + SizeTumblingWindow.of(5000)); + + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); + streamSource + .flatMap( + new FlatMapFunction() { + @Override + public void flatMap(String value, Collector collector) { + String[] records = value.split(SPLIT); + for (String record : records) { + collector.partition(Long.valueOf(record)); } + } }) - .keyBy(k -> 1) - .aggregate(new AggFunc()) - .map(v -> String.format("%s", v)) - .sink(sink); - } + .keyBy(k -> 1) + .aggregate(new AggFunc()) + .map(v -> String.format("%s", v)) + .sink(sink); + } }); - return pipeline.execute(); - } - - public static void validateResult() throws IOException { - ResultValidator.validateResult(REF_FILE_PATH, RESULT_FILE_PATH); - } + return pipeline.execute(); + } - public static class MutableLong implements Serializable { + public static void validateResult() throws IOException { + ResultValidator.validateResult(REF_FILE_PATH, RESULT_FILE_PATH); + } - long value; - } + public static class MutableLong implements Serializable { - public static class AggFunc implements AggregateFunction { + long value; + } - @Override - public MutableLong createAccumulator() { - return new MutableLong(); - } + public static class AggFunc implements AggregateFunction { - @Override - public void add(Long value, MutableLong accumulator) { - accumulator.value += value; - } + @Override + public MutableLong createAccumulator() { + return new MutableLong(); + } - @Override - public Long getResult(MutableLong accumulator) { - return accumulator.value; - } + @Override + public void add(Long value, MutableLong accumulator) { + accumulator.value += value; + } - @Override - public MutableLong merge(MutableLong a, MutableLong b) { - return null; - } + @Override + public Long getResult(MutableLong accumulator) { + return accumulator.value; } + @Override + public MutableLong merge(MutableLong a, MutableLong b) { + return null; + } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamKeyAggPipeline.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamKeyAggPipeline.java index c5e29b17f..a9496a7db 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamKeyAggPipeline.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamKeyAggPipeline.java @@ -25,6 +25,7 @@ import java.io.IOException; import java.io.Serializable; import java.util.Collections; + import org.apache.geaflow.api.collector.Collector; import org.apache.geaflow.api.function.base.AggregateFunction; import org.apache.geaflow.api.function.base.FlatMapFunction; @@ -46,74 +47,77 @@ public class StreamKeyAggPipeline implements Serializable { - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/agg2"; - public static final String REF_FILE_PATH = "data/reference/agg2"; - public static final String SPLIT = ","; + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/agg2"; + public static final String REF_FILE_PATH = "data/reference/agg2"; + public static final String SPLIT = ","; - public IPipelineResult submit(Environment environment) { - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - ResultValidator.cleanResult(RESULT_FILE_PATH); - pipeline.submit(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - Configuration conf = pipelineTaskCxt.getConfig(); - PWindowSource streamSource = - pipelineTaskCxt.buildSource(new FileSource("data/input" - + "/email_edge", Collections::singletonList) { - }, SizeTumblingWindow.of(5000)); + public IPipelineResult submit(Environment environment) { + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + ResultValidator.cleanResult(RESULT_FILE_PATH); + pipeline.submit( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { + Configuration conf = pipelineTaskCxt.getConfig(); + PWindowSource streamSource = + pipelineTaskCxt.buildSource( + new FileSource( + "data/input" + "/email_edge", Collections::singletonList) {}, + SizeTumblingWindow.of(5000)); - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); - streamSource - .flatMap(new FlatMapFunction() { - @Override - public void flatMap(String value, Collector collector) { - String[] records = value.split(SPLIT); - for (String record : records) { - collector.partition(Long.valueOf(record)); - } + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); + streamSource + .flatMap( + new FlatMapFunction() { + @Override + public void flatMap(String value, Collector collector) { + String[] records = value.split(SPLIT); + for (String record : records) { + collector.partition(Long.valueOf(record)); } + } }) - .map(p -> Tuple.of(p, p)) - .keyBy(p -> ((long) ((Tuple) p).f0) % 7) - .aggregate(new AggFunc()) - .withParallelism(conf.getInteger(AGG_PARALLELISM)) - .map(v -> String.format("%s,%s", ((Tuple) v).f0, ((Tuple) v).f1)) - .sink(sink).withParallelism(conf.getInteger(SINK_PARALLELISM)); - } + .map(p -> Tuple.of(p, p)) + .keyBy(p -> ((long) ((Tuple) p).f0) % 7) + .aggregate(new AggFunc()) + .withParallelism(conf.getInteger(AGG_PARALLELISM)) + .map(v -> String.format("%s,%s", ((Tuple) v).f0, ((Tuple) v).f1)) + .sink(sink) + .withParallelism(conf.getInteger(SINK_PARALLELISM)); + } }); - return pipeline.execute(); - } - - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH, String::compareTo); - } + return pipeline.execute(); + } - public static class AggFunc implements - AggregateFunction, Tuple, Tuple> { + public static void validateResult() throws IOException { + ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH, String::compareTo); + } - @Override - public Tuple createAccumulator() { - return Tuple.of(0L, 0L); - } + public static class AggFunc + implements AggregateFunction, Tuple, Tuple> { - @Override - public void add(Tuple value, Tuple accumulator) { - accumulator.setF0(value.f0); - accumulator.setF1(value.f1 + accumulator.f1); - } + @Override + public Tuple createAccumulator() { + return Tuple.of(0L, 0L); + } - @Override - public Tuple getResult(Tuple accumulator) { - return Tuple.of(accumulator.f0, accumulator.f1); - } + @Override + public void add(Tuple value, Tuple accumulator) { + accumulator.setF0(value.f0); + accumulator.setF1(value.f1 + accumulator.f1); + } - @Override - public Tuple merge(Tuple a, Tuple b) { - return null; - } + @Override + public Tuple getResult(Tuple accumulator) { + return Tuple.of(accumulator.f0, accumulator.f1); } + @Override + public Tuple merge(Tuple a, Tuple b) { + return null; + } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamUnionPipeline.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamUnionPipeline.java index eda479581..f408d8442 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamUnionPipeline.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamUnionPipeline.java @@ -26,6 +26,7 @@ import java.io.Serializable; import java.util.Collections; import java.util.Comparator; + import org.apache.geaflow.api.collector.Collector; import org.apache.geaflow.api.function.base.AggregateFunction; import org.apache.geaflow.api.function.base.FlatMapFunction; @@ -47,90 +48,94 @@ public class StreamUnionPipeline implements Serializable { - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/union"; - public static final String REF_FILE_PATH = "data/reference/union"; - public static final String SPLIT = ","; - - public IPipelineResult submit(Environment environment) { - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - ResultValidator.cleanResult(RESULT_FILE_PATH); - pipeline.submit(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - Configuration conf = pipelineTaskCxt.getConfig(); - PWindowSource streamSource = - pipelineTaskCxt.buildSource(new FileSource("data/input" - + "/email_edge", - Collections::singletonList) { - }, SizeTumblingWindow.of(5000)); - - PWindowSource streamSource2 = - pipelineTaskCxt.buildSource(new FileSource("data/input" - + "/email_edge", - Collections::singletonList) { - }, SizeTumblingWindow.of(5000)); - - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); - streamSource - .union(streamSource2) - .flatMap(new FlatMapFunction() { - @Override - public void flatMap(String value, Collector collector) { - String[] records = value.split(SPLIT); - for (String record : records) { - collector.partition(Long.valueOf(record)); - } + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/union"; + public static final String REF_FILE_PATH = "data/reference/union"; + public static final String SPLIT = ","; + + public IPipelineResult submit(Environment environment) { + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + ResultValidator.cleanResult(RESULT_FILE_PATH); + pipeline.submit( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { + Configuration conf = pipelineTaskCxt.getConfig(); + PWindowSource streamSource = + pipelineTaskCxt.buildSource( + new FileSource( + "data/input" + "/email_edge", Collections::singletonList) {}, + SizeTumblingWindow.of(5000)); + + PWindowSource streamSource2 = + pipelineTaskCxt.buildSource( + new FileSource( + "data/input" + "/email_edge", Collections::singletonList) {}, + SizeTumblingWindow.of(5000)); + + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); + streamSource + .union(streamSource2) + .flatMap( + new FlatMapFunction() { + @Override + public void flatMap(String value, Collector collector) { + String[] records = value.split(SPLIT); + for (String record : records) { + collector.partition(Long.valueOf(record)); } + } }) - .map(p -> Tuple.of(p, p)) - .keyBy(p -> p) - .materialize() - .aggregate(new AggFunc()) - .withParallelism(conf.getInteger(AGG_PARALLELISM)) - .map(v -> String.format("%s", v)) - .sink(sink) - .withParallelism(conf.getInteger(SINK_PARALLELISM)); - } + .map(p -> Tuple.of(p, p)) + .keyBy(p -> p) + .materialize() + .aggregate(new AggFunc()) + .withParallelism(conf.getInteger(AGG_PARALLELISM)) + .map(v -> String.format("%s", v)) + .sink(sink) + .withParallelism(conf.getInteger(SINK_PARALLELISM)); + } }); - return pipeline.execute(); + return pipeline.execute(); + } + + public static void validateResult() throws IOException { + ResultValidator.validateMapResult( + REF_FILE_PATH, + RESULT_FILE_PATH, + Comparator.comparingLong(StreamUnionPipeline::parseSumValue)); + } + + private static long parseSumValue(String result) { + String sumValue = result.split(",")[1]; + sumValue = sumValue.substring(0, sumValue.length() - 1); + return Long.parseLong(sumValue); + } + + public static class AggFunc + implements AggregateFunction, Tuple, Tuple> { + + @Override + public Tuple createAccumulator() { + return Tuple.of(0L, 0L); } - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH, - Comparator.comparingLong(StreamUnionPipeline::parseSumValue)); + @Override + public void add(Tuple value, Tuple accumulator) { + accumulator.setF0(value.f0); + accumulator.setF1(value.f1 + accumulator.f1); } - private static long parseSumValue(String result) { - String sumValue = result.split(",")[1]; - sumValue = sumValue.substring(0, sumValue.length() - 1); - return Long.parseLong(sumValue); + @Override + public Tuple getResult(Tuple accumulator) { + return Tuple.of(accumulator.f0, accumulator.f1); } - public static class AggFunc implements - AggregateFunction, Tuple, Tuple> { - - @Override - public Tuple createAccumulator() { - return Tuple.of(0L, 0L); - } - - @Override - public void add(Tuple value, Tuple accumulator) { - accumulator.setF0(value.f0); - accumulator.setF1(value.f1 + accumulator.f1); - } - - @Override - public Tuple getResult(Tuple accumulator) { - return Tuple.of(accumulator.f0, accumulator.f1); - } - - @Override - public Tuple merge(Tuple a, Tuple b) { - return null; - } + @Override + public Tuple merge(Tuple a, Tuple b) { + return null; } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamWordCountCallBackPipeline.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamWordCountCallBackPipeline.java index 98c95b66f..1a63cfa04 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamWordCountCallBackPipeline.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamWordCountCallBackPipeline.java @@ -24,6 +24,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.pdata.stream.window.PWindowSource; import org.apache.geaflow.api.window.impl.SizeTumblingWindow; @@ -49,34 +50,40 @@ public class StreamWordCountCallBackPipeline implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(StreamWordCountCallBackPipeline.class); + private static final Logger LOGGER = + LoggerFactory.getLogger(StreamWordCountCallBackPipeline.class); - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/wordcount"; - public static final String REF_FILE_PATH = "data/reference/wordcount"; - private static final List taskCallBackList = new ArrayList<>(); + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/wordcount"; + public static final String REF_FILE_PATH = "data/reference/wordcount"; + private static final List taskCallBackList = new ArrayList<>(); - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - submit(environment); - } + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + submit(environment); + } - public static IPipelineResult submit(Environment environment) { - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - ResultValidator.cleanResult(RESULT_FILE_PATH); - TaskCallBack taskCallBack = pipeline.submit(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { + public static IPipelineResult submit(Environment environment) { + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + ResultValidator.cleanResult(RESULT_FILE_PATH); + TaskCallBack taskCallBack = + pipeline.submit( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { Configuration conf = pipelineTaskCxt.getConfig(); - PWindowSource streamSource = pipelineTaskCxt.buildSource( - new RecoverableFileSource("data/input/email_edge", - line -> { - String[] fields = line.split(","); - return Collections.singletonList(fields[0]); - }) { - }, SizeTumblingWindow.of(5000)) - .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); + PWindowSource streamSource = + pipelineTaskCxt + .buildSource( + new RecoverableFileSource( + "data/input/email_edge", + line -> { + String[] fields = line.split(","); + return Collections.singletonList(fields[0]); + }) {}, + SizeTumblingWindow.of(5000)) + .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); streamSource @@ -87,28 +94,27 @@ public void execute(IPipelineTaskContext pipelineTaskCxt) { .map(v -> String.format("%s,%s", ((Tuple) v).f0, ((Tuple) v).f1)) .sink(sink) .withParallelism(conf.getInteger(ExampleConfigKeys.SINK_PARALLELISM)); - } - }); - - taskCallBackList.clear(); - taskCallBack.addCallBack(new ICallbackFunction() { - @Override - public void window(long windowId) { - LOGGER.info("finish windowId:{}", windowId); - taskCallBackList.add(windowId); - } + } + }); - @Override - public void terminal() { + taskCallBackList.clear(); + taskCallBack.addCallBack( + new ICallbackFunction() { + @Override + public void window(long windowId) { + LOGGER.info("finish windowId:{}", windowId); + taskCallBackList.add(windowId); + } - } + @Override + public void terminal() {} }); - return pipeline.execute(); - } + return pipeline.execute(); + } - public static void validateResult() throws IOException { - ResultValidator.validateResult(REF_FILE_PATH, RESULT_FILE_PATH); - Assert.assertTrue("task call back should handle window", !taskCallBackList.isEmpty()); - } + public static void validateResult() throws IOException { + ResultValidator.validateResult(REF_FILE_PATH, RESULT_FILE_PATH); + Assert.assertTrue("task call back should handle window", !taskCallBackList.isEmpty()); + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamWordCountPipeline.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamWordCountPipeline.java index 7f2b6816a..e9fd54d8c 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamWordCountPipeline.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamWordCountPipeline.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.io.Serializable; import java.util.Collections; + import org.apache.geaflow.api.function.base.KeySelector; import org.apache.geaflow.api.function.base.MapFunction; import org.apache.geaflow.api.function.base.ReduceFunction; @@ -47,76 +48,80 @@ public class StreamWordCountPipeline implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(StreamWordCountPipeline.class); - - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/wordcount"; - public static final String REF_FILE_PATH = "data/reference/wordcount"; - - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - submit(environment); - } - - public static IPipelineResult submit(Environment environment) { - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - ResultValidator.cleanResult(RESULT_FILE_PATH); - pipeline.submit(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - Configuration conf = pipelineTaskCxt.getConfig(); - PWindowSource streamSource = pipelineTaskCxt.buildSource( - new RecoverableFileSource("data/input/email_edge", + private static final Logger LOGGER = LoggerFactory.getLogger(StreamWordCountPipeline.class); + + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/wordcount"; + public static final String REF_FILE_PATH = "data/reference/wordcount"; + + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + submit(environment); + } + + public static IPipelineResult submit(Environment environment) { + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + ResultValidator.cleanResult(RESULT_FILE_PATH); + pipeline.submit( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { + Configuration conf = pipelineTaskCxt.getConfig(); + PWindowSource streamSource = + pipelineTaskCxt + .buildSource( + new RecoverableFileSource( + "data/input/email_edge", line -> { - String[] fields = line.split(","); - return Collections.singletonList(fields[0]); - }) { - }, SizeTumblingWindow.of(5000)) + String[] fields = line.split(","); + return Collections.singletonList(fields[0]); + }) {}, + SizeTumblingWindow.of(5000)) .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); - streamSource - .map(e -> Tuple.of(e, 1)) - .keyBy(new KeySelectorFunc()) - .reduce(new CountFunc()) - .withParallelism(conf.getInteger(ExampleConfigKeys.REDUCE_PARALLELISM)) - .map(v -> String.format("%s,%s", ((Tuple) v).f0, ((Tuple) v).f1)) - .sink(sink) - .withParallelism(conf.getInteger(ExampleConfigKeys.SINK_PARALLELISM)); - } + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); + streamSource + .map(e -> Tuple.of(e, 1)) + .keyBy(new KeySelectorFunc()) + .reduce(new CountFunc()) + .withParallelism(conf.getInteger(ExampleConfigKeys.REDUCE_PARALLELISM)) + .map(v -> String.format("%s,%s", ((Tuple) v).f0, ((Tuple) v).f1)) + .sink(sink) + .withParallelism(conf.getInteger(ExampleConfigKeys.SINK_PARALLELISM)); + } }); - return pipeline.execute(); - } - - public static void validateResult() throws IOException { - ResultValidator.validateResult(REF_FILE_PATH, RESULT_FILE_PATH); - } + return pipeline.execute(); + } + public static void validateResult() throws IOException { + ResultValidator.validateResult(REF_FILE_PATH, RESULT_FILE_PATH); + } - public static class MapFunc implements MapFunction> { + public static class MapFunc implements MapFunction> { - @Override - public Tuple map(String value) { - LOGGER.info("MapFunc process value: {}", value); - return Tuple.of(value, 1); - } + @Override + public Tuple map(String value) { + LOGGER.info("MapFunc process value: {}", value); + return Tuple.of(value, 1); } + } - public static class KeySelectorFunc implements KeySelector, Object> { + public static class KeySelectorFunc implements KeySelector, Object> { - @Override - public Object getKey(Tuple value) { - return value.f0; - } + @Override + public Object getKey(Tuple value) { + return value.f0; } + } - public static class CountFunc implements ReduceFunction> { + public static class CountFunc implements ReduceFunction> { - @Override - public Tuple reduce(Tuple oldValue, Tuple newValue) { - return Tuple.of(oldValue.f0, oldValue.f1 + newValue.f1); - } + @Override + public Tuple reduce( + Tuple oldValue, Tuple newValue) { + return Tuple.of(oldValue.f0, oldValue.f1 + newValue.f1); } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamWordFlatMapPipeline.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamWordFlatMapPipeline.java index eb0f61427..87c63aff2 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamWordFlatMapPipeline.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamWordFlatMapPipeline.java @@ -24,6 +24,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.api.collector.Collector; import org.apache.geaflow.api.function.base.FlatMapFunction; import org.apache.geaflow.api.function.io.SinkFunction; @@ -43,45 +44,47 @@ public class StreamWordFlatMapPipeline implements Serializable { - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/flatmap"; - public static final String REF_FILE_PATH = "data/reference/flatmap"; - public static final String SPLIT = ","; + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/flatmap"; + public static final String REF_FILE_PATH = "data/reference/flatmap"; + public static final String SPLIT = ","; - public IPipelineResult submit(Environment environment) { - Map config = new HashMap<>(); - config.put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - ResultValidator.cleanResult(RESULT_FILE_PATH); - environment.getEnvironmentContext().withConfig(config); - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.submit(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - Configuration config = pipelineTaskCxt.getConfig(); - PWindowSource streamSource = - pipelineTaskCxt.buildSource(new FileSource("data/input" - + "/email_edge", - Collections::singletonList) { - }, SizeTumblingWindow.of(5000)); + public IPipelineResult submit(Environment environment) { + Map config = new HashMap<>(); + config.put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + ResultValidator.cleanResult(RESULT_FILE_PATH); + environment.getEnvironmentContext().withConfig(config); + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.submit( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { + Configuration config = pipelineTaskCxt.getConfig(); + PWindowSource streamSource = + pipelineTaskCxt.buildSource( + new FileSource( + "data/input" + "/email_edge", Collections::singletonList) {}, + SizeTumblingWindow.of(5000)); - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); - streamSource - .flatMap(new FlatMapFunction() { - @Override - public void flatMap(String value, Collector collector) { - String[] records = value.split(SPLIT); - for (String record : records) { - collector.partition(record); - } + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); + streamSource + .flatMap( + new FlatMapFunction() { + @Override + public void flatMap(String value, Collector collector) { + String[] records = value.split(SPLIT); + for (String record : records) { + collector.partition(record); } - }).sink(sink); - } + } + }) + .sink(sink); + } }); - return pipeline.execute(); - } - - public static void validateResult() throws IOException { - ResultValidator.validateResult(REF_FILE_PATH, RESULT_FILE_PATH); - } + return pipeline.execute(); + } + public static void validateResult() throws IOException { + ResultValidator.validateResult(REF_FILE_PATH, RESULT_FILE_PATH); + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamWordPrintPipeline.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamWordPrintPipeline.java index 414842a1e..f17529d58 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamWordPrintPipeline.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/StreamWordPrintPipeline.java @@ -24,6 +24,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.pdata.stream.window.PWindowSource; import org.apache.geaflow.api.window.impl.SizeTumblingWindow; @@ -41,38 +42,39 @@ public class StreamWordPrintPipeline implements Serializable { + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/wordprint"; + public static final String REF_FILE_PATH = "data/reference/wordprint"; - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/wordprint"; - public static final String REF_FILE_PATH = "data/reference/wordprint"; - - public IPipelineResult submit(Environment environment) { - Map config = new HashMap<>(); - config.put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - ResultValidator.cleanResult(RESULT_FILE_PATH); - environment.getEnvironmentContext().withConfig(config); - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.submit(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - Configuration config = pipelineTaskCxt.getConfig(); - PWindowSource streamSource = pipelineTaskCxt.buildSource( - new FileSource("data/input" - + "/email_vertex", + public IPipelineResult submit(Environment environment) { + Map config = new HashMap<>(); + config.put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + ResultValidator.cleanResult(RESULT_FILE_PATH); + environment.getEnvironmentContext().withConfig(config); + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.submit( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { + Configuration config = pipelineTaskCxt.getConfig(); + PWindowSource streamSource = + pipelineTaskCxt.buildSource( + new FileSource( + "data/input" + "/email_vertex", line -> { - String[] fields = line.split(","); - return Collections.singletonList(fields[0]); - }) { - }, SizeTumblingWindow.of(5000)); + String[] fields = line.split(","); + return Collections.singletonList(fields[0]); + }) {}, + SizeTumblingWindow.of(5000)); - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); - streamSource.sink(sink); - } + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); + streamSource.sink(sink); + } }); - return pipeline.execute(); - } + return pipeline.execute(); + } - public static void validateResult() throws IOException { - ResultValidator.validateResult(REF_FILE_PATH, RESULT_FILE_PATH); - } + public static void validateResult() throws IOException { + ResultValidator.validateResult(REF_FILE_PATH, RESULT_FILE_PATH); + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/UnBoundedStreamFoTest.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/UnBoundedStreamFoTest.java index 3de5cf586..0a5272dff 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/UnBoundedStreamFoTest.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/UnBoundedStreamFoTest.java @@ -23,6 +23,7 @@ import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeoutException; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.utils.GsonUtil; import org.apache.geaflow.common.utils.SleepUtils; @@ -33,24 +34,27 @@ public class UnBoundedStreamFoTest { - private static final Logger LOGGER = LoggerFactory.getLogger(UnBoundedStreamFoTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(UnBoundedStreamFoTest.class); - public static void main(String[] args) throws ExecutionException, InterruptedException, TimeoutException { - Environment environment = EnvironmentFactory.onRayEnvironment(args); - Configuration configuration = environment.getEnvironmentContext().getConfig(); + public static void main(String[] args) + throws ExecutionException, InterruptedException, TimeoutException { + Environment environment = EnvironmentFactory.onRayEnvironment(args); + Configuration configuration = environment.getEnvironmentContext().getConfig(); - StreamWordCountPipeline pipeline = new StreamWordCountPipeline(); - Map hbaseConfig = new HashMap<>(); - hbaseConfig.put("hbase.zookeeper.quorum", "hbase-v51-nn-10001.eu95sqa.tbsite" + StreamWordCountPipeline pipeline = new StreamWordCountPipeline(); + Map hbaseConfig = new HashMap<>(); + hbaseConfig.put( + "hbase.zookeeper.quorum", + "hbase-v51-nn-10001.eu95sqa.tbsite" + ".net,hbase-v51-nn-10002.eu95sqa.tbsite.net,hbase-v51-nn-10003.eu95sqa.tbsite.net"); - hbaseConfig.put("zookeeper.znode.parent", "/hbase-eu95sqa-perf-test-ssd"); - hbaseConfig.put("hbase.zookeeper.property.clientPort", "2181"); - Map config = new HashMap<>(); - config.put("geaflow.store.hbase.config.json", GsonUtil.toJson(hbaseConfig)); - - configuration.putAll(config); - pipeline.submit(environment); - SleepUtils.sleepSecond(3); - LOGGER.error("main finished"); - } + hbaseConfig.put("zookeeper.znode.parent", "/hbase-eu95sqa-perf-test-ssd"); + hbaseConfig.put("hbase.zookeeper.property.clientPort", "2181"); + Map config = new HashMap<>(); + config.put("geaflow.store.hbase.config.json", GsonUtil.toJson(hbaseConfig)); + + configuration.putAll(config); + pipeline.submit(environment); + SleepUtils.sleepSecond(3); + LOGGER.error("main finished"); + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/UnBoundedStreamWordCount.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/UnBoundedStreamWordCount.java index 2eee15c3c..f272f1ccf 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/UnBoundedStreamWordCount.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/UnBoundedStreamWordCount.java @@ -24,10 +24,10 @@ public class UnBoundedStreamWordCount { - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); - StreamWordCountPipeline pipeline = new StreamWordCountPipeline(); - pipeline.submit(environment); - } + StreamWordCountPipeline pipeline = new StreamWordCountPipeline(); + pipeline.submit(environment); + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/UnBoundedStreamWordPrint.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/UnBoundedStreamWordPrint.java index 1f6edb207..f48f31d44 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/UnBoundedStreamWordPrint.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/UnBoundedStreamWordPrint.java @@ -24,11 +24,10 @@ public class UnBoundedStreamWordPrint { - public static void main(String[] args) { - Environment environment = EnvironmentFactory.onRayEnvironment(); - - StreamWordPrintPipeline pipeline = new StreamWordPrintPipeline(); - pipeline.submit(environment); - } + public static void main(String[] args) { + Environment environment = EnvironmentFactory.onRayEnvironment(); + StreamWordPrintPipeline pipeline = new StreamWordPrintPipeline(); + pipeline.submit(environment); + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/WindowStreamWordCount.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/WindowStreamWordCount.java index 710fb3d6e..95d99461f 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/WindowStreamWordCount.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/stream/WindowStreamWordCount.java @@ -19,8 +19,8 @@ package org.apache.geaflow.example.stream; -import com.google.common.collect.Lists; import java.util.List; + import org.apache.geaflow.api.function.internal.CollectionSource; import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.pdata.stream.window.PWindowSource; @@ -35,26 +35,28 @@ import org.apache.geaflow.pipeline.task.IPipelineTaskContext; import org.apache.geaflow.pipeline.task.PipelineTask; +import com.google.common.collect.Lists; + public class WindowStreamWordCount { - public static void main(String[] args) { - Environment environment = EnvironmentFactory.onRayEnvironment(args); - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.submit(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - Configuration config = pipelineTaskCxt.getConfig(); - List words = Lists.newArrayList("hello", "world", "hello", "word"); - PWindowSource streamSource = - pipelineTaskCxt.buildSource(new CollectionSource(words) {}, - SizeTumblingWindow.of(100)); - - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); - streamSource.window(WindowFactory.createSizeTumblingWindow(2)) - .sink(sink); - } + public static void main(String[] args) { + Environment environment = EnvironmentFactory.onRayEnvironment(args); + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.submit( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { + Configuration config = pipelineTaskCxt.getConfig(); + List words = Lists.newArrayList("hello", "world", "hello", "word"); + PWindowSource streamSource = + pipelineTaskCxt.buildSource( + new CollectionSource(words) {}, SizeTumblingWindow.of(100)); + + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(config); + streamSource.window(WindowFactory.createSizeTumblingWindow(2)).sink(sink); + } }); - pipeline.execute(); - } + pipeline.execute(); + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/util/EnvironmentUtil.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/util/EnvironmentUtil.java index 2923b412b..0a8cf958f 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/util/EnvironmentUtil.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/util/EnvironmentUtil.java @@ -22,6 +22,7 @@ import static org.apache.geaflow.cluster.constants.ClusterConstants.CLUSTER_TYPE; import java.util.Locale; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.env.Environment; import org.apache.geaflow.env.EnvironmentFactory; @@ -30,27 +31,26 @@ import org.slf4j.LoggerFactory; public class EnvironmentUtil { - private static final Logger LOGGER = LoggerFactory.getLogger(EnvironmentUtil.class); + private static final Logger LOGGER = LoggerFactory.getLogger(EnvironmentUtil.class); - public static Environment loadEnvironment(String[] args) { - EnvType clusterType = getClusterType(); - switch (clusterType) { - case K8S: - return EnvironmentFactory.onK8SEnvironment(args); - case RAY: - return EnvironmentFactory.onRayEnvironment(args); - default: - return EnvironmentFactory.onLocalEnvironment(args); - } + public static Environment loadEnvironment(String[] args) { + EnvType clusterType = getClusterType(); + switch (clusterType) { + case K8S: + return EnvironmentFactory.onK8SEnvironment(args); + case RAY: + return EnvironmentFactory.onRayEnvironment(args); + default: + return EnvironmentFactory.onLocalEnvironment(args); } + } - private static EnvType getClusterType() { - String clusterType = System.getProperty(CLUSTER_TYPE); - if (StringUtils.isBlank(clusterType)) { - LOGGER.warn("use local as default cluster"); - return EnvType.LOCAL; - } - return (EnvType.valueOf(clusterType.toUpperCase(Locale.ROOT))); + private static EnvType getClusterType() { + String clusterType = System.getProperty(CLUSTER_TYPE); + if (StringUtils.isBlank(clusterType)) { + LOGGER.warn("use local as default cluster"); + return EnvType.LOCAL; } - + return (EnvType.valueOf(clusterType.toUpperCase(Locale.ROOT))); + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/util/ExampleSinkFunctionFactory.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/util/ExampleSinkFunctionFactory.java index f275efd00..bd11a60b9 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/util/ExampleSinkFunctionFactory.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/util/ExampleSinkFunctionFactory.java @@ -27,23 +27,19 @@ public class ExampleSinkFunctionFactory { - public enum SinkType { + public enum SinkType { - /** - * result sink to console. - */ - CONSOLE_SINK, - /** - * result sink to local file. - */ - FILE_SINK - } + /** result sink to console. */ + CONSOLE_SINK, + /** result sink to local file. */ + FILE_SINK + } - public static SinkFunction getSinkFunction(Configuration configuration) { - String sinkType = configuration.getString(ExampleConfigKeys.GEAFLOW_SINK_TYPE); - if (sinkType.equalsIgnoreCase(SinkType.CONSOLE_SINK.name())) { - return new ConsoleSink<>(); - } - return new FileSink<>(); + public static SinkFunction getSinkFunction(Configuration configuration) { + String sinkType = configuration.getString(ExampleConfigKeys.GEAFLOW_SINK_TYPE); + if (sinkType.equalsIgnoreCase(SinkType.CONSOLE_SINK.name())) { + return new ConsoleSink<>(); } + return new FileSink<>(); + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/util/PipelineResultCollect.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/util/PipelineResultCollect.java index 64f078125..6b6170f55 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/util/PipelineResultCollect.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/util/PipelineResultCollect.java @@ -24,11 +24,10 @@ public class PipelineResultCollect { - public static T get(IPipelineResult result) { - if (!result.isSuccess()) { - throw new GeaflowRuntimeException("execute pipeline failed"); - } - return result.get(); + public static T get(IPipelineResult result) { + if (!result.isSuccess()) { + throw new GeaflowRuntimeException("execute pipeline failed"); } - + return result.get(); + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/util/ResultValidator.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/util/ResultValidator.java index 235d5c937..4047a6b4e 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/util/ResultValidator.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/util/ResultValidator.java @@ -19,8 +19,6 @@ package org.apache.geaflow.example.util; -import com.google.common.io.Files; -import com.google.common.io.Resources; import java.io.File; import java.io.IOException; import java.nio.charset.Charset; @@ -30,87 +28,90 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.junit.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.io.Files; +import com.google.common.io.Resources; + public class ResultValidator { - private static final Logger LOGGER = LoggerFactory.getLogger(ResultValidator.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ResultValidator.class); - public static void cleanResult(String resultPath) { - File dir = new File(resultPath); - FileUtils.deleteQuietly(dir); - } + public static void cleanResult(String resultPath) { + File dir = new File(resultPath); + FileUtils.deleteQuietly(dir); + } + + public static void validateResult(String refPath, String resultPath) throws IOException { + validateResult(refPath, resultPath, null); + } - public static void validateResult(String refPath, String resultPath) throws IOException { - validateResult(refPath, resultPath, null); + public static void validateResult( + String refPath, String resultPath, Comparator comparator) throws IOException { + List result = readFiles(resultPath); + List reference = readFiles(Resources.getResource(refPath).getFile()); + if (comparator == null) { + Collections.sort(result); + Collections.sort(reference); + } else { + Collections.sort(result, comparator); + Collections.sort(reference, comparator); } - public static void validateResult(String refPath, String resultPath, Comparator comparator) - throws IOException { + LOGGER.info("result size {}, reference size {}", result.size(), reference.size()); + Assert.assertEquals(reference.size(), result.size()); + Assert.assertEquals(reference, result); - List result = readFiles(resultPath); - List reference = readFiles(Resources.getResource(refPath).getFile()); - if (comparator == null) { - Collections.sort(result); - Collections.sort(reference); - } else { - Collections.sort(result, comparator); - Collections.sort(reference, comparator); - } + cleanResult(resultPath); + } - LOGGER.info("result size {}, reference size {}", result.size(), reference.size()); - Assert.assertEquals(reference.size(), result.size()); - Assert.assertEquals(reference, result); + public static void validateMapResult(String refPath, String resultPath) throws IOException { + validateMapResult(refPath, resultPath, null); + } - cleanResult(resultPath); + public static void validateMapResult( + String refPath, String resultPath, Comparator comparator) throws IOException { + List result = readFiles(resultPath); + if (comparator != null) { + Collections.sort(result, comparator); + } + Map resultMap = new HashMap<>(); + for (String temp : result) { + String[] values = temp.split(","); + resultMap.put(values[0].trim(), values[1].trim()); } - public static void validateMapResult(String refPath, String resultPath) throws IOException { - validateMapResult(refPath, resultPath, null); + List reference = readFiles(Resources.getResource(refPath).getFile()); + if (comparator != null) { + Collections.sort(reference, comparator); } - public static void validateMapResult(String refPath, String resultPath, Comparator comparator) - throws IOException { - List result = readFiles(resultPath); - if (comparator != null) { - Collections.sort(result, comparator); - } - Map resultMap = new HashMap<>(); - for (String temp : result) { - String[] values = temp.split(","); - resultMap.put(values[0].trim(), values[1].trim()); - } - - List reference = readFiles(Resources.getResource(refPath).getFile()); - if (comparator != null) { - Collections.sort(reference, comparator); - } - - Map referenceMap = new HashMap<>(); - for (String temp : reference) { - String[] values = temp.split(","); - referenceMap.put(values[0].trim(), values[1].trim()); - } - - LOGGER.info("result size {}, reference size {}", resultMap.size(), referenceMap.size()); - Assert.assertEquals(referenceMap.size(), resultMap.size()); - Assert.assertEquals(referenceMap, resultMap); - - cleanResult(resultPath); + Map referenceMap = new HashMap<>(); + for (String temp : reference) { + String[] values = temp.split(","); + referenceMap.put(values[0].trim(), values[1].trim()); } - private static List readFiles(String path) throws IOException { - File dir = new File(path); - List result = new ArrayList<>(); - File[] files = dir.listFiles(); - for (File file : files) { - List lines = Files.readLines(file, Charset.defaultCharset()); - result.addAll(lines); - } - return result; + LOGGER.info("result size {}, reference size {}", resultMap.size(), referenceMap.size()); + Assert.assertEquals(referenceMap.size(), resultMap.size()); + Assert.assertEquals(referenceMap, resultMap); + + cleanResult(resultPath); + } + + private static List readFiles(String path) throws IOException { + File dir = new File(path); + List result = new ArrayList<>(); + File[] files = dir.listFiles(); + for (File file : files) { + List lines = Files.readLines(file, Charset.defaultCharset()); + result.addAll(lines); } + return result; + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/util/VertexEdgeParser.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/util/VertexEdgeParser.java index faae85019..24342920b 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/util/VertexEdgeParser.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/util/VertexEdgeParser.java @@ -23,6 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.common.tuple.Tuple; import org.apache.geaflow.model.graph.edge.EdgeDirection; import org.apache.geaflow.model.graph.edge.IEdge; @@ -32,116 +33,125 @@ public class VertexEdgeParser { - public static List>> vertexParserMap(String line) { - Map vertexValue = new HashMap<>(2); - String[] split = line.split("\\s+"); - IVertex> v1 = new ValueVertex<>(Integer.parseInt(split[0]), vertexValue); - IVertex> v2 = new ValueVertex<>(Integer.parseInt(split[1]), vertexValue); - List>> vertices = new ArrayList<>(); - vertices.add(v1); - vertices.add(v2); - return vertices; - } - - public static List>> edgeParserMap(String line) { - Map edgeValue = new HashMap<>(2); - String[] split = line.split("\\s+"); - int src = Integer.parseInt(split[0]); - int dst = Integer.parseInt(split[1]); - edgeValue.put("dis", 1); - IEdge> e1 = new ValueEdge>(src, dst, edgeValue, EdgeDirection.OUT); - IEdge> e2 = new ValueEdge>(dst, src, edgeValue, EdgeDirection.IN); - List>> edges = new ArrayList<>(); - edges.add(e1); - edges.add(e2); - return edges; - } - - public static List>>> vertexParserMapMap(String line) { - Map> vertexValue = new HashMap<>(2); - String[] split = line.split("\\s+"); - IVertex>> v1 = new ValueVertex<>(Integer.parseInt(split[0]), vertexValue); - IVertex>> v2 = new ValueVertex<>(Integer.parseInt(split[1]), vertexValue); - List>>> vertices = new ArrayList<>(); - vertices.add(v1); - vertices.add(v2); - return vertices; - } + public static List>> vertexParserMap(String line) { + Map vertexValue = new HashMap<>(2); + String[] split = line.split("\\s+"); + IVertex> v1 = + new ValueVertex<>(Integer.parseInt(split[0]), vertexValue); + IVertex> v2 = + new ValueVertex<>(Integer.parseInt(split[1]), vertexValue); + List>> vertices = new ArrayList<>(); + vertices.add(v1); + vertices.add(v2); + return vertices; + } - public static List>> vertexParserObjectMap(String line) { - Map vertexValue = new HashMap<>(2); - String[] split = line.split("\\s+"); - IVertex> v1 = new ValueVertex<>(Integer.parseInt(split[0]), vertexValue); - IVertex> v2 = new ValueVertex<>(Integer.parseInt(split[1]), vertexValue); - List>> vertices = new ArrayList<>(); - vertices.add(v1); - vertices.add(v2); - return vertices; - } + public static List>> edgeParserMap(String line) { + Map edgeValue = new HashMap<>(2); + String[] split = line.split("\\s+"); + int src = Integer.parseInt(split[0]); + int dst = Integer.parseInt(split[1]); + edgeValue.put("dis", 1); + IEdge> e1 = + new ValueEdge>(src, dst, edgeValue, EdgeDirection.OUT); + IEdge> e2 = + new ValueEdge>(dst, src, edgeValue, EdgeDirection.IN); + List>> edges = new ArrayList<>(); + edges.add(e1); + edges.add(e2); + return edges; + } + public static List>>> vertexParserMapMap( + String line) { + Map> vertexValue = new HashMap<>(2); + String[] split = line.split("\\s+"); + IVertex>> v1 = + new ValueVertex<>(Integer.parseInt(split[0]), vertexValue); + IVertex>> v2 = + new ValueVertex<>(Integer.parseInt(split[1]), vertexValue); + List>>> vertices = new ArrayList<>(); + vertices.add(v1); + vertices.add(v2); + return vertices; + } - public static List> vertexParserInteger(String line) { - String[] split = line.split("\\s+"); - IVertex v1 = new ValueVertex<>(Integer.parseInt(split[0]), 0); - IVertex v2 = new ValueVertex<>(Integer.parseInt(split[1]), 0); - List> vertices = new ArrayList<>(); - vertices.add(v1); - vertices.add(v2); - return vertices; - } + public static List>> vertexParserObjectMap(String line) { + Map vertexValue = new HashMap<>(2); + String[] split = line.split("\\s+"); + IVertex> v1 = + new ValueVertex<>(Integer.parseInt(split[0]), vertexValue); + IVertex> v2 = + new ValueVertex<>(Integer.parseInt(split[1]), vertexValue); + List>> vertices = new ArrayList<>(); + vertices.add(v1); + vertices.add(v2); + return vertices; + } - public static List> edgeParserInteger(String line) { - String[] split = line.split("\\s+"); - int src = Integer.parseInt(split[0]); - int dst = Integer.parseInt(split[1]); - IEdge e1 = new ValueEdge<>(src, dst, 0, EdgeDirection.OUT); - IEdge e2 = new ValueEdge<>(dst, src, 0, EdgeDirection.IN); - List> edges = new ArrayList<>(); - edges.add(e1); - edges.add(e2); - return edges; - } + public static List> vertexParserInteger(String line) { + String[] split = line.split("\\s+"); + IVertex v1 = new ValueVertex<>(Integer.parseInt(split[0]), 0); + IVertex v2 = new ValueVertex<>(Integer.parseInt(split[1]), 0); + List> vertices = new ArrayList<>(); + vertices.add(v1); + vertices.add(v2); + return vertices; + } - public static List>> vertexParserTuple(String line) { - String[] split = line.split("\\s+"); - IVertex> v1 = new ValueVertex<>(Integer.parseInt(split[0]), Tuple.of(0.0, 0)); - IVertex> v2 = new ValueVertex<>(Integer.parseInt(split[1]), Tuple.of(0.0, 0)); - List>> vertices = new ArrayList<>(); - vertices.add(v1); - vertices.add(v2); - return vertices; - } + public static List> edgeParserInteger(String line) { + String[] split = line.split("\\s+"); + int src = Integer.parseInt(split[0]); + int dst = Integer.parseInt(split[1]); + IEdge e1 = new ValueEdge<>(src, dst, 0, EdgeDirection.OUT); + IEdge e2 = new ValueEdge<>(dst, src, 0, EdgeDirection.IN); + List> edges = new ArrayList<>(); + edges.add(e1); + edges.add(e2); + return edges; + } - public static List> vertexParserDouble(String line) { - String[] split = line.split("\\s+"); - IVertex v1 = new ValueVertex<>(Integer.parseInt(split[0]), 0.1); - IVertex v2 = new ValueVertex<>(Integer.parseInt(split[1]), 0.1); - List> vertices = new ArrayList<>(); - vertices.add(v1); - vertices.add(v2); - return vertices; - } + public static List>> vertexParserTuple(String line) { + String[] split = line.split("\\s+"); + IVertex> v1 = + new ValueVertex<>(Integer.parseInt(split[0]), Tuple.of(0.0, 0)); + IVertex> v2 = + new ValueVertex<>(Integer.parseInt(split[1]), Tuple.of(0.0, 0)); + List>> vertices = new ArrayList<>(); + vertices.add(v1); + vertices.add(v2); + return vertices; + } - public static List> vertexParserBoolean(String line) { - String[] split = line.split("\\s+"); - IVertex v1 = new ValueVertex<>(Integer.parseInt(split[0]), false); - IVertex v2 = new ValueVertex<>(Integer.parseInt(split[1]), false); - List> vertices = new ArrayList<>(); - vertices.add(v1); - vertices.add(v2); - return vertices; - } + public static List> vertexParserDouble(String line) { + String[] split = line.split("\\s+"); + IVertex v1 = new ValueVertex<>(Integer.parseInt(split[0]), 0.1); + IVertex v2 = new ValueVertex<>(Integer.parseInt(split[1]), 0.1); + List> vertices = new ArrayList<>(); + vertices.add(v1); + vertices.add(v2); + return vertices; + } - public static List> edgeParserBoolean(String line) { - String[] split = line.split("\\s+"); - int src = Integer.parseInt(split[0]); - int dst = Integer.parseInt(split[1]); - IEdge e1 = new ValueEdge<>(src, dst, true, EdgeDirection.OUT); - IEdge e2 = new ValueEdge<>(dst, src, true, EdgeDirection.IN); - List> edges = new ArrayList<>(); - edges.add(e1); - edges.add(e2); - return edges; - } + public static List> vertexParserBoolean(String line) { + String[] split = line.split("\\s+"); + IVertex v1 = new ValueVertex<>(Integer.parseInt(split[0]), false); + IVertex v2 = new ValueVertex<>(Integer.parseInt(split[1]), false); + List> vertices = new ArrayList<>(); + vertices.add(v1); + vertices.add(v2); + return vertices; + } + public static List> edgeParserBoolean(String line) { + String[] split = line.split("\\s+"); + int src = Integer.parseInt(split[0]); + int dst = Integer.parseInt(split[1]); + IEdge e1 = new ValueEdge<>(src, dst, true, EdgeDirection.OUT); + IEdge e2 = new ValueEdge<>(dst, src, true, EdgeDirection.IN); + List> edges = new ArrayList<>(); + edges.add(e1); + edges.add(e2); + return edges; + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/window/WindowCallBackPipeline.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/window/WindowCallBackPipeline.java index b6aac921d..ffcf7a3ee 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/window/WindowCallBackPipeline.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/window/WindowCallBackPipeline.java @@ -23,6 +23,7 @@ import java.io.Serializable; import java.util.Collections; import java.util.Comparator; + import org.apache.geaflow.api.function.base.KeySelector; import org.apache.geaflow.api.function.base.MapFunction; import org.apache.geaflow.api.function.base.ReduceFunction; @@ -50,87 +51,90 @@ public class WindowCallBackPipeline implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(WindowCallBackPipeline.class); - - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/count1"; - public static final String REF_FILE_PATH = "data/reference/count1"; - - public static void main(String[] args) { - Environment environment = EnvironmentUtil.loadEnvironment(args); - submit(environment); - } - - public static IPipelineResult submit(Environment environment) { - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - ResultValidator.cleanResult(RESULT_FILE_PATH); - TaskCallBack taskCallBack = pipeline.submit(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { + private static final Logger LOGGER = LoggerFactory.getLogger(WindowCallBackPipeline.class); + + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/count1"; + public static final String REF_FILE_PATH = "data/reference/count1"; + + public static void main(String[] args) { + Environment environment = EnvironmentUtil.loadEnvironment(args); + submit(environment); + } + + public static IPipelineResult submit(Environment environment) { + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + ResultValidator.cleanResult(RESULT_FILE_PATH); + TaskCallBack taskCallBack = + pipeline.submit( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { Configuration conf = pipelineTaskCxt.getConfig(); - PWindowSource streamSource = pipelineTaskCxt.buildSource( - new RecoverableFileSource("data/input/email_edge", - line -> { - String[] fields = line.split(","); - return Collections.singletonList(fields[0]); - }) { - }, SizeTumblingWindow.of(1000)) - .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); + PWindowSource streamSource = + pipelineTaskCxt + .buildSource( + new RecoverableFileSource( + "data/input/email_edge", + line -> { + String[] fields = line.split(","); + return Collections.singletonList(fields[0]); + }) {}, + SizeTumblingWindow.of(1000)) + .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); streamSource .filter(e -> Integer.parseInt(e) > 100) .map(e -> Tuple.of(e, 1)) .keyBy(new KeySelectorFunc()) - .map(v -> String.format("(%s,%s)", ((Tuple) v).f0, - ((Tuple) v).f1)) + .map(v -> String.format("(%s,%s)", ((Tuple) v).f0, ((Tuple) v).f1)) .sink(sink) .withParallelism(conf.getInteger(ExampleConfigKeys.SINK_PARALLELISM)); - } + } + }); + taskCallBack.addCallBack( + new ICallbackFunction() { + @Override + public void window(long windowId) { + LOGGER.info("finish windowId:{}", windowId); + } + + @Override + public void terminal() {} }); - taskCallBack.addCallBack(new ICallbackFunction() { - @Override - public void window(long windowId) { - LOGGER.info("finish windowId:{}", windowId); - } - - @Override - public void terminal() { - } - }); - - return pipeline.execute(); - } - - public static void validateResult(Comparator comparator) throws IOException { - ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH, comparator); - } + return pipeline.execute(); + } + public static void validateResult(Comparator comparator) throws IOException { + ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH, comparator); + } - public static class MapFunc implements MapFunction> { + public static class MapFunc implements MapFunction> { - @Override - public Tuple map(String value) { - LOGGER.info("MapFunc process value: {}", value); - return Tuple.of(value, 1); - } + @Override + public Tuple map(String value) { + LOGGER.info("MapFunc process value: {}", value); + return Tuple.of(value, 1); } + } - public static class KeySelectorFunc implements KeySelector, Object> { + public static class KeySelectorFunc implements KeySelector, Object> { - @Override - public Object getKey(Tuple value) { - return value.f0; - } + @Override + public Object getKey(Tuple value) { + return value.f0; } + } - public static class CountFunc implements ReduceFunction> { + public static class CountFunc implements ReduceFunction> { - @Override - public Tuple reduce(Tuple oldValue, Tuple newValue) { - return Tuple.of(oldValue.f0, oldValue.f1 + newValue.f1); - } + @Override + public Tuple reduce( + Tuple oldValue, Tuple newValue) { + return Tuple.of(oldValue.f0, oldValue.f1 + newValue.f1); } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/window/WindowKeyAggPipeline.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/window/WindowKeyAggPipeline.java index 366363ac0..131c62af7 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/window/WindowKeyAggPipeline.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/window/WindowKeyAggPipeline.java @@ -25,6 +25,7 @@ import java.io.IOException; import java.io.Serializable; import java.util.Collections; + import org.apache.geaflow.api.collector.Collector; import org.apache.geaflow.api.function.base.AggregateFunction; import org.apache.geaflow.api.function.base.FlatMapFunction; @@ -47,76 +48,81 @@ public class WindowKeyAggPipeline implements Serializable { - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/agg4"; - public static final String REF_FILE_PATH = "data/reference/agg4"; - public static final String SPLIT = ","; + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/agg4"; + public static final String REF_FILE_PATH = "data/reference/agg4"; + public static final String SPLIT = ","; - public IPipelineResult submit(Environment environment) { - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - envConfig.getConfigMap().put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE.getKey(), Boolean.TRUE.toString()); - ResultValidator.cleanResult(RESULT_FILE_PATH); - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.submit(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - Configuration conf = pipelineTaskCxt.getConfig(); - PWindowSource streamSource = - pipelineTaskCxt.buildSource(new FileSource("data/input" - + "/email_edge", Collections::singletonList) { - }, AllWindow.getInstance()); + public IPipelineResult submit(Environment environment) { + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + envConfig + .getConfigMap() + .put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE.getKey(), Boolean.TRUE.toString()); + ResultValidator.cleanResult(RESULT_FILE_PATH); + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.submit( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { + Configuration conf = pipelineTaskCxt.getConfig(); + PWindowSource streamSource = + pipelineTaskCxt.buildSource( + new FileSource( + "data/input" + "/email_edge", Collections::singletonList) {}, + AllWindow.getInstance()); - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); - streamSource - .flatMap(new FlatMapFunction() { - @Override - public void flatMap(String value, Collector collector) { - String[] records = value.split(SPLIT); - for (String record : records) { - collector.partition(Long.valueOf(record)); - } + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); + streamSource + .flatMap( + new FlatMapFunction() { + @Override + public void flatMap(String value, Collector collector) { + String[] records = value.split(SPLIT); + for (String record : records) { + collector.partition(Long.valueOf(record)); } + } }) - .map(p -> Tuple.of(p, p)) - .keyBy(p -> ((long) ((Tuple) p).f0) % 7) - .aggregate(new AggFunc()) - .withParallelism(conf.getInteger(AGG_PARALLELISM)) - .map(v -> String.format("%s,%s", ((Tuple) v).f0, ((Tuple) v).f1)) - .sink(sink).withParallelism(conf.getInteger(SINK_PARALLELISM)); - } + .map(p -> Tuple.of(p, p)) + .keyBy(p -> ((long) ((Tuple) p).f0) % 7) + .aggregate(new AggFunc()) + .withParallelism(conf.getInteger(AGG_PARALLELISM)) + .map(v -> String.format("%s,%s", ((Tuple) v).f0, ((Tuple) v).f1)) + .sink(sink) + .withParallelism(conf.getInteger(SINK_PARALLELISM)); + } }); - pipeline.shutdown(); - return pipeline.execute(); - } - - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH, String::compareTo); - } + pipeline.shutdown(); + return pipeline.execute(); + } - public static class AggFunc implements - AggregateFunction, Tuple, Tuple> { + public static void validateResult() throws IOException { + ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH, String::compareTo); + } - @Override - public Tuple createAccumulator() { - return Tuple.of(0L, 0L); - } + public static class AggFunc + implements AggregateFunction, Tuple, Tuple> { - @Override - public void add(Tuple value, Tuple accumulator) { - accumulator.setF0(value.f0); - accumulator.setF1(value.f1 + accumulator.f1); - } + @Override + public Tuple createAccumulator() { + return Tuple.of(0L, 0L); + } - @Override - public Tuple getResult(Tuple accumulator) { - return Tuple.of(accumulator.f0, accumulator.f1); - } + @Override + public void add(Tuple value, Tuple accumulator) { + accumulator.setF0(value.f0); + accumulator.setF1(value.f1 + accumulator.f1); + } - @Override - public Tuple merge(Tuple a, Tuple b) { - return null; - } + @Override + public Tuple getResult(Tuple accumulator) { + return Tuple.of(accumulator.f0, accumulator.f1); } + @Override + public Tuple merge(Tuple a, Tuple b) { + return null; + } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/window/WindowStreamKeyAggPipeline.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/window/WindowStreamKeyAggPipeline.java index c46bbf970..7825e8cee 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/window/WindowStreamKeyAggPipeline.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/window/WindowStreamKeyAggPipeline.java @@ -25,6 +25,7 @@ import java.io.IOException; import java.io.Serializable; import java.util.Collections; + import org.apache.geaflow.api.collector.Collector; import org.apache.geaflow.api.function.base.AggregateFunction; import org.apache.geaflow.api.function.base.FlatMapFunction; @@ -47,75 +48,80 @@ public class WindowStreamKeyAggPipeline implements Serializable { - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/agg3"; - public static final String REF_FILE_PATH = "data/reference/agg3"; - public static final String SPLIT = ","; + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/agg3"; + public static final String REF_FILE_PATH = "data/reference/agg3"; + public static final String SPLIT = ","; - public IPipelineResult submit(Environment environment) { - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - envConfig.getConfigMap().put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE.getKey(), Boolean.TRUE.toString()); - ResultValidator.cleanResult(RESULT_FILE_PATH); - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.submit(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - Configuration conf = pipelineTaskCxt.getConfig(); - PWindowSource streamSource = - pipelineTaskCxt.buildSource(new FileSource("data/input" - + "/email_edge", Collections::singletonList) { - }, SizeTumblingWindow.of(5000)); + public IPipelineResult submit(Environment environment) { + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + envConfig + .getConfigMap() + .put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE.getKey(), Boolean.TRUE.toString()); + ResultValidator.cleanResult(RESULT_FILE_PATH); + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.submit( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { + Configuration conf = pipelineTaskCxt.getConfig(); + PWindowSource streamSource = + pipelineTaskCxt.buildSource( + new FileSource( + "data/input" + "/email_edge", Collections::singletonList) {}, + SizeTumblingWindow.of(5000)); - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); - streamSource - .flatMap(new FlatMapFunction() { - @Override - public void flatMap(String value, Collector collector) { - String[] records = value.split(SPLIT); - for (String record : records) { - collector.partition(Long.valueOf(record)); - } + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); + streamSource + .flatMap( + new FlatMapFunction() { + @Override + public void flatMap(String value, Collector collector) { + String[] records = value.split(SPLIT); + for (String record : records) { + collector.partition(Long.valueOf(record)); } + } }) - .map(p -> Tuple.of(p, p)) - .keyBy(p -> ((long) ((Tuple) p).f0) % 7) - .aggregate(new AggFunc()) - .withParallelism(conf.getInteger(AGG_PARALLELISM)) - .map(v -> String.format("%s,%s", ((Tuple) v).f0, ((Tuple) v).f1)) - .sink(sink).withParallelism(conf.getInteger(SINK_PARALLELISM)); - } + .map(p -> Tuple.of(p, p)) + .keyBy(p -> ((long) ((Tuple) p).f0) % 7) + .aggregate(new AggFunc()) + .withParallelism(conf.getInteger(AGG_PARALLELISM)) + .map(v -> String.format("%s,%s", ((Tuple) v).f0, ((Tuple) v).f1)) + .sink(sink) + .withParallelism(conf.getInteger(SINK_PARALLELISM)); + } }); - return pipeline.execute(); - } - - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH, String::compareTo); - } + return pipeline.execute(); + } - public static class AggFunc implements - AggregateFunction, Tuple, Tuple> { + public static void validateResult() throws IOException { + ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH, String::compareTo); + } - @Override - public Tuple createAccumulator() { - return Tuple.of(0L, 0L); - } + public static class AggFunc + implements AggregateFunction, Tuple, Tuple> { - @Override - public void add(Tuple value, Tuple accumulator) { - accumulator.setF0(value.f0); - accumulator.setF1(value.f1 + accumulator.f1); - } + @Override + public Tuple createAccumulator() { + return Tuple.of(0L, 0L); + } - @Override - public Tuple getResult(Tuple accumulator) { - return Tuple.of(accumulator.f0, accumulator.f1); - } + @Override + public void add(Tuple value, Tuple accumulator) { + accumulator.setF0(value.f0); + accumulator.setF1(value.f1 + accumulator.f1); + } - @Override - public Tuple merge(Tuple a, Tuple b) { - return null; - } + @Override + public Tuple getResult(Tuple accumulator) { + return Tuple.of(accumulator.f0, accumulator.f1); } + @Override + public Tuple merge(Tuple a, Tuple b) { + return null; + } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/window/WindowStreamWordCountPipeline.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/window/WindowStreamWordCountPipeline.java index 8d4c1b8ff..167407aae 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/window/WindowStreamWordCountPipeline.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/window/WindowStreamWordCountPipeline.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.io.Serializable; import java.util.Collections; + import org.apache.geaflow.api.function.base.KeySelector; import org.apache.geaflow.api.function.base.MapFunction; import org.apache.geaflow.api.function.base.ReduceFunction; @@ -47,72 +48,78 @@ public class WindowStreamWordCountPipeline implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(WindowStreamWordCountPipeline.class); - - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/count2"; - public static final String REF_FILE_PATH = "data/reference/count2"; - - public IPipelineResult submit(Environment environment) { - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - envConfig.getConfigMap().put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE.getKey(), Boolean.TRUE.toString()); - ResultValidator.cleanResult(RESULT_FILE_PATH); - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.submit(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - Configuration conf = pipelineTaskCxt.getConfig(); - PWindowSource streamSource = pipelineTaskCxt.buildSource( - new FileSource("data/input/email_edge", + private static final Logger LOGGER = LoggerFactory.getLogger(WindowStreamWordCountPipeline.class); + + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/count2"; + public static final String REF_FILE_PATH = "data/reference/count2"; + + public IPipelineResult submit(Environment environment) { + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + envConfig + .getConfigMap() + .put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE.getKey(), Boolean.TRUE.toString()); + ResultValidator.cleanResult(RESULT_FILE_PATH); + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.submit( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { + Configuration conf = pipelineTaskCxt.getConfig(); + PWindowSource streamSource = + pipelineTaskCxt + .buildSource( + new FileSource( + "data/input/email_edge", line -> { - String[] fields = line.split(","); - return Collections.singletonList(fields[0]); - }) { - }, SizeTumblingWindow.of(5000)) + String[] fields = line.split(","); + return Collections.singletonList(fields[0]); + }) {}, + SizeTumblingWindow.of(5000)) .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); - streamSource - .map(e -> Tuple.of(e, 1)) - .keyBy(new KeySelectorFunc()) - .reduce(new CountFunc()) - .withParallelism(conf.getInteger(ExampleConfigKeys.REDUCE_PARALLELISM)) - .map(v -> String.format("(%s,%s)", ((Tuple) v).f0, - ((Tuple) v).f1)) - .sink(sink) - .withParallelism(conf.getInteger(ExampleConfigKeys.SINK_PARALLELISM)); - } + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); + streamSource + .map(e -> Tuple.of(e, 1)) + .keyBy(new KeySelectorFunc()) + .reduce(new CountFunc()) + .withParallelism(conf.getInteger(ExampleConfigKeys.REDUCE_PARALLELISM)) + .map(v -> String.format("(%s,%s)", ((Tuple) v).f0, ((Tuple) v).f1)) + .sink(sink) + .withParallelism(conf.getInteger(ExampleConfigKeys.SINK_PARALLELISM)); + } }); - return pipeline.execute(); - } + return pipeline.execute(); + } - public static void validateResult() throws IOException { - ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH, String::compareTo); - } + public static void validateResult() throws IOException { + ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH, String::compareTo); + } - public static class MapFunc implements MapFunction> { + public static class MapFunc implements MapFunction> { - @Override - public Tuple map(String value) { - LOGGER.info("MapFunc process value: {}", value); - return Tuple.of(value, 1); - } + @Override + public Tuple map(String value) { + LOGGER.info("MapFunc process value: {}", value); + return Tuple.of(value, 1); } + } - public static class KeySelectorFunc implements KeySelector, Object> { + public static class KeySelectorFunc implements KeySelector, Object> { - @Override - public Object getKey(Tuple value) { - return value.f0; - } + @Override + public Object getKey(Tuple value) { + return value.f0; } + } - public static class CountFunc implements ReduceFunction> { + public static class CountFunc implements ReduceFunction> { - @Override - public Tuple reduce(Tuple oldValue, Tuple newValue) { - return Tuple.of(oldValue.f0, oldValue.f1 + newValue.f1); - } + @Override + public Tuple reduce( + Tuple oldValue, Tuple newValue) { + return Tuple.of(oldValue.f0, oldValue.f1 + newValue.f1); } + } } diff --git a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/window/WindowWordCountPipeline.java b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/window/WindowWordCountPipeline.java index 8d73a616a..75eb031bf 100644 --- a/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/window/WindowWordCountPipeline.java +++ b/geaflow/geaflow-examples/src/main/java/org/apache/geaflow/example/window/WindowWordCountPipeline.java @@ -23,6 +23,7 @@ import java.io.Serializable; import java.util.Collections; import java.util.Comparator; + import org.apache.geaflow.api.function.base.KeySelector; import org.apache.geaflow.api.function.base.MapFunction; import org.apache.geaflow.api.function.base.ReduceFunction; @@ -48,73 +49,78 @@ public class WindowWordCountPipeline implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(WindowWordCountPipeline.class); - - public static final String RESULT_FILE_PATH = "./target/tmp/data/result/count1"; - public static final String REF_FILE_PATH = "data/reference/count1"; - - public IPipelineResult submit(Environment environment) { - Configuration envConfig = environment.getEnvironmentContext().getConfig(); - envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); - envConfig.getConfigMap().put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE.getKey(), Boolean.TRUE.toString()); - ResultValidator.cleanResult(RESULT_FILE_PATH); - Pipeline pipeline = PipelineFactory.buildPipeline(environment); - pipeline.submit(new PipelineTask() { - @Override - public void execute(IPipelineTaskContext pipelineTaskCxt) { - Configuration conf = pipelineTaskCxt.getConfig(); - PWindowSource streamSource = pipelineTaskCxt.buildSource( - new FileSource("data/input/email_edge", + private static final Logger LOGGER = LoggerFactory.getLogger(WindowWordCountPipeline.class); + + public static final String RESULT_FILE_PATH = "./target/tmp/data/result/count1"; + public static final String REF_FILE_PATH = "data/reference/count1"; + + public IPipelineResult submit(Environment environment) { + Configuration envConfig = environment.getEnvironmentContext().getConfig(); + envConfig.getConfigMap().put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); + envConfig + .getConfigMap() + .put(FrameworkConfigKeys.INC_STREAM_MATERIALIZE_DISABLE.getKey(), Boolean.TRUE.toString()); + ResultValidator.cleanResult(RESULT_FILE_PATH); + Pipeline pipeline = PipelineFactory.buildPipeline(environment); + pipeline.submit( + new PipelineTask() { + @Override + public void execute(IPipelineTaskContext pipelineTaskCxt) { + Configuration conf = pipelineTaskCxt.getConfig(); + PWindowSource streamSource = + pipelineTaskCxt + .buildSource( + new FileSource( + "data/input/email_edge", line -> { - String[] fields = line.split(","); - return Collections.singletonList(fields[0]); - }) { - }, AllWindow.getInstance()) + String[] fields = line.split(","); + return Collections.singletonList(fields[0]); + }) {}, + AllWindow.getInstance()) .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); - streamSource - .map(e -> Tuple.of(e, 1)) - .keyBy(new KeySelectorFunc()) - .reduce(new CountFunc()) - .withParallelism(conf.getInteger(ExampleConfigKeys.REDUCE_PARALLELISM)) - .map(v -> String.format("(%s,%s)", ((Tuple) v).f0, - ((Tuple) v).f1)) - .sink(sink) - .withParallelism(conf.getInteger(ExampleConfigKeys.SINK_PARALLELISM)); - } + SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(conf); + streamSource + .map(e -> Tuple.of(e, 1)) + .keyBy(new KeySelectorFunc()) + .reduce(new CountFunc()) + .withParallelism(conf.getInteger(ExampleConfigKeys.REDUCE_PARALLELISM)) + .map(v -> String.format("(%s,%s)", ((Tuple) v).f0, ((Tuple) v).f1)) + .sink(sink) + .withParallelism(conf.getInteger(ExampleConfigKeys.SINK_PARALLELISM)); + } }); - return pipeline.execute(); - } - - public static void validateResult(Comparator comparator) throws IOException { - ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH, comparator); - } + return pipeline.execute(); + } + public static void validateResult(Comparator comparator) throws IOException { + ResultValidator.validateMapResult(REF_FILE_PATH, RESULT_FILE_PATH, comparator); + } - public static class MapFunc implements MapFunction> { + public static class MapFunc implements MapFunction> { - @Override - public Tuple map(String value) { - LOGGER.info("MapFunc process value: {}", value); - return Tuple.of(value, 1); - } + @Override + public Tuple map(String value) { + LOGGER.info("MapFunc process value: {}", value); + return Tuple.of(value, 1); } + } - public static class KeySelectorFunc implements KeySelector, Object> { + public static class KeySelectorFunc implements KeySelector, Object> { - @Override - public Object getKey(Tuple value) { - return value.f0; - } + @Override + public Object getKey(Tuple value) { + return value.f0; } + } - public static class CountFunc implements ReduceFunction> { + public static class CountFunc implements ReduceFunction> { - @Override - public Tuple reduce(Tuple oldValue, Tuple newValue) { - return Tuple.of(oldValue.f0, oldValue.f1 + newValue.f1); - } + @Override + public Tuple reduce( + Tuple oldValue, Tuple newValue) { + return Tuple.of(oldValue.f0, oldValue.f1 + newValue.f1); } + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/base/BaseQueryTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/base/BaseQueryTest.java index 9bd4c0c7a..4e77b3d6c 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/base/BaseQueryTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/base/BaseQueryTest.java @@ -27,6 +27,7 @@ import java.util.HashMap; import java.util.Map; import java.util.Objects; + import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.StringUtils; @@ -47,216 +48,214 @@ public class BaseQueryTest implements Serializable { - private int testTimeWaitSeconds = 0; - - public static final String INIT_DDL = "/query/modern_graph.sql"; - public static final String DSL_STATE_REMOTE_PATH = "/tmp/dsl/"; - - private String queryPath; - - private boolean compareWithOrder = false; - - private boolean withOutPrefix = false; + private int testTimeWaitSeconds = 0; - private String graphDefinePath; + public static final String INIT_DDL = "/query/modern_graph.sql"; + public static final String DSL_STATE_REMOTE_PATH = "/tmp/dsl/"; - private final Map config = new HashMap<>(); + private String queryPath; - private BaseQueryTest() { - try { - initRemotePath(); - } catch (IOException e) { - throw new RuntimeException(e); - } - } + private boolean compareWithOrder = false; - public static BaseQueryTest build() { - return new BaseQueryTest(); - } + private boolean withOutPrefix = false; + private String graphDefinePath; - public BaseQueryTest withQueryPath(String queryPath) { - this.queryPath = queryPath; - return this; - } + private final Map config = new HashMap<>(); - public BaseQueryTest withTestTimeWaitSeconds(int testTimeWaitSeconds) { - this.testTimeWaitSeconds = testTimeWaitSeconds; - return this; + private BaseQueryTest() { + try { + initRemotePath(); + } catch (IOException e) { + throw new RuntimeException(e); } - - public BaseQueryTest compareWithOrder() { - this.compareWithOrder = true; - return this; + } + + public static BaseQueryTest build() { + return new BaseQueryTest(); + } + + public BaseQueryTest withQueryPath(String queryPath) { + this.queryPath = queryPath; + return this; + } + + public BaseQueryTest withTestTimeWaitSeconds(int testTimeWaitSeconds) { + this.testTimeWaitSeconds = testTimeWaitSeconds; + return this; + } + + public BaseQueryTest compareWithOrder() { + this.compareWithOrder = true; + return this; + } + + public BaseQueryTest withoutPrefix() { + this.withOutPrefix = true; + return this; + } + + public BaseQueryTest withConfig(Map config) { + this.config.putAll(config); + return this; + } + + public BaseQueryTest withConfig(String key, Object value) { + this.config.put(key, String.valueOf(value)); + return this; + } + + public BaseQueryTest execute() throws Exception { + if (queryPath == null) { + throw new IllegalArgumentException("You should call withQueryPath() before execute()."); } - - public BaseQueryTest withoutPrefix() { - this.withOutPrefix = true; - return this; + Map config = new HashMap<>(); + config.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), String.valueOf(-1L)); + config.put(FileConfigKeys.ROOT.getKey(), DSL_STATE_REMOTE_PATH); + config.put( + DSLConfigKeys.GEAFLOW_DSL_QUERY_PATH.getKey(), + withOutPrefix ? queryPath : FileConstants.PREFIX_JAVA_RESOURCE + queryPath); + config.putAll(this.config); + initResultDirectory(); + + Environment environment = EnvironmentFactory.onLocalEnvironment(); + environment.getEnvironmentContext().withConfig(config); + + GQLPipeLine gqlPipeLine = new GQLPipeLine(environment, testTimeWaitSeconds); + + String graphDefinePath = null; + if (this.graphDefinePath != null) { + graphDefinePath = this.graphDefinePath; } - - public BaseQueryTest withConfig(Map config) { - this.config.putAll(config); - return this; + gqlPipeLine.setPipelineHook(new TestGQLPipelineHook(graphDefinePath, queryPath)); + try { + gqlPipeLine.execute(); + } finally { + environment.shutdown(); + ClusterMetaStore.close(); + ScheduledWorkerManagerFactory.clear(); } - - public BaseQueryTest withConfig(String key, Object value) { - this.config.put(key, String.valueOf(value)); - return this; + return this; + } + + private void initResultDirectory() throws Exception { + // delete target file path + String targetPath = getTargetPath(queryPath); + File targetFile = new File(targetPath); + if (targetFile.exists()) { + FileUtils.forceDelete(targetFile); } + } - public BaseQueryTest execute() throws Exception { - if (queryPath == null) { - throw new IllegalArgumentException("You should call withQueryPath() before execute()."); - } - Map config = new HashMap<>(); - config.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), String.valueOf(-1L)); - config.put(FileConfigKeys.ROOT.getKey(), DSL_STATE_REMOTE_PATH); - config.put(DSLConfigKeys.GEAFLOW_DSL_QUERY_PATH.getKey(), - withOutPrefix ? queryPath : FileConstants.PREFIX_JAVA_RESOURCE + queryPath); - config.putAll(this.config); - initResultDirectory(); - - Environment environment = EnvironmentFactory.onLocalEnvironment(); - environment.getEnvironmentContext().withConfig(config); - - GQLPipeLine gqlPipeLine = new GQLPipeLine(environment, testTimeWaitSeconds); - - String graphDefinePath = null; - if (this.graphDefinePath != null) { - graphDefinePath = this.graphDefinePath; - } - gqlPipeLine.setPipelineHook(new TestGQLPipelineHook(graphDefinePath, queryPath)); - try { - gqlPipeLine.execute(); - } finally { - environment.shutdown(); - ClusterMetaStore.close(); - ScheduledWorkerManagerFactory.clear(); - } - return this; + private void initRemotePath() throws IOException { + // delete state remote path + File stateRemoteFile = new File(DSL_STATE_REMOTE_PATH); + if (stateRemoteFile.exists()) { + FileUtils.forceDelete(stateRemoteFile); } - - private void initResultDirectory() throws Exception { - // delete target file path - String targetPath = getTargetPath(queryPath); - File targetFile = new File(targetPath); - if (targetFile.exists()) { - FileUtils.forceDelete(targetFile); - } + } + + public void checkSinkResult(String... path) throws Exception { + String[] paths = queryPath.split("/"); + String lastPath = paths[paths.length - 1]; + String exceptPath = "/expect/" + lastPath.split("\\.")[0] + ".txt"; + String targetPath; + if (path != null && path.length > 0) { + targetPath = path[0]; + } else { + targetPath = getTargetPath(queryPath); } - - private void initRemotePath() throws IOException { - // delete state remote path - File stateRemoteFile = new File(DSL_STATE_REMOTE_PATH); - if (stateRemoteFile.exists()) { - FileUtils.forceDelete(stateRemoteFile); - } + String expectResult = IOUtils.resourceToString(exceptPath, Charset.defaultCharset()).trim(); + String actualResult = readFile(targetPath); + compareResult(actualResult, expectResult); + } + + private void compareResult(String actualResult, String expectResult) { + if (compareWithOrder) { + Assert.assertEquals(actualResult, expectResult); + } else { + String[] actualLines = actualResult.split("\n"); + Arrays.sort(actualLines); + String[] expectLines = expectResult.split("\n"); + Arrays.sort(expectLines); + + String actualSort = StringUtils.join(actualLines, "\n"); + String expectSort = StringUtils.join(expectLines, "\n"); + if (!Objects.equals(actualSort, expectSort)) { + Assert.assertEquals(actualResult, expectResult); + } } + } - public void checkSinkResult(String... path) throws Exception { - String[] paths = queryPath.split("/"); - String lastPath = paths[paths.length - 1]; - String exceptPath = "/expect/" + lastPath.split("\\.")[0] + ".txt"; - String targetPath; - if (path != null && path.length > 0) { - targetPath = path[0]; - } else { - targetPath = getTargetPath(queryPath); - } - String expectResult = IOUtils.resourceToString(exceptPath, Charset.defaultCharset()).trim(); - String actualResult = readFile(targetPath); - compareResult(actualResult, expectResult); + private String readFile(String path) throws IOException { + File file = new File(path); + if (file.isHidden()) { + return ""; } - - private void compareResult(String actualResult, String expectResult) { - if (compareWithOrder) { - Assert.assertEquals(actualResult, expectResult); - } else { - String[] actualLines = actualResult.split("\n"); - Arrays.sort(actualLines); - String[] expectLines = expectResult.split("\n"); - Arrays.sort(expectLines); - - String actualSort = StringUtils.join(actualLines, "\n"); - String expectSort = StringUtils.join(expectLines, "\n"); - if (!Objects.equals(actualSort, expectSort)) { - Assert.assertEquals(actualResult, expectResult); - } - } + if (file.isFile()) { + return IOUtils.toString(new File(path).toURI(), Charset.defaultCharset()).trim(); } - - private String readFile(String path) throws IOException { - File file = new File(path); - if (file.isHidden()) { - return ""; - } - if (file.isFile()) { - return IOUtils.toString(new File(path).toURI(), Charset.defaultCharset()).trim(); + File[] files = file.listFiles(); + StringBuilder content = new StringBuilder(); + if (files != null) { + for (File subFile : files) { + String readText = readFile(subFile.getAbsolutePath()); + if (StringUtils.isBlank(readText)) { + continue; } - File[] files = file.listFiles(); - StringBuilder content = new StringBuilder(); - if (files != null) { - for (File subFile : files) { - String readText = readFile(subFile.getAbsolutePath()); - if (StringUtils.isBlank(readText)) { - continue; - } - if (content.length() > 0) { - content.append("\n"); - } - content.append(readText); - } + if (content.length() > 0) { + content.append("\n"); } - return content.toString().trim(); - } - - private static String getTargetPath(String queryPath) { - assert queryPath != null; - String[] paths = queryPath.split("/"); - String lastPath = paths[paths.length - 1]; - String targetPath = "target/" + lastPath.split("\\.")[0]; - String currentPath = new File(".").getAbsolutePath(); - targetPath = currentPath.substring(0, currentPath.length() - 1) + targetPath; - return targetPath; + content.append(readText); + } } + return content.toString().trim(); + } - public BaseQueryTest withGraphDefine(String graphDefinePath) { - this.graphDefinePath = Objects.requireNonNull(graphDefinePath); - return this; - } + private static String getTargetPath(String queryPath) { + assert queryPath != null; + String[] paths = queryPath.split("/"); + String lastPath = paths[paths.length - 1]; + String targetPath = "target/" + lastPath.split("\\.")[0]; + String currentPath = new File(".").getAbsolutePath(); + targetPath = currentPath.substring(0, currentPath.length() - 1) + targetPath; + return targetPath; + } - private static class TestGQLPipelineHook implements GQLPipelineHook { + public BaseQueryTest withGraphDefine(String graphDefinePath) { + this.graphDefinePath = Objects.requireNonNull(graphDefinePath); + return this; + } - private final String graphDefinePath; + private static class TestGQLPipelineHook implements GQLPipelineHook { - private final String queryPath; + private final String graphDefinePath; - public TestGQLPipelineHook(String graphDefinePath, String queryPath) { - this.graphDefinePath = graphDefinePath; - this.queryPath = queryPath; - } + private final String queryPath; - @Override - public String rewriteScript(String script, Configuration configuration) { - return script.replace("${target}", getTargetPath(queryPath)); - } - - @Override - public void beforeExecute(QueryClient queryClient, QueryContext queryContext) { - if (graphDefinePath != null) { - try { - String ddl = IOUtils.resourceToString(graphDefinePath, Charset.defaultCharset()); - queryClient.executeQuery(ddl, queryContext); - } catch (IOException e) { - throw new GeaFlowDSLException(e); - } - } - } + public TestGQLPipelineHook(String graphDefinePath, String queryPath) { + this.graphDefinePath = graphDefinePath; + this.queryPath = queryPath; + } - @Override - public void afterExecute(QueryClient queryClient, QueryContext queryContext) { + @Override + public String rewriteScript(String script, Configuration configuration) { + return script.replace("${target}", getTargetPath(queryPath)); + } + @Override + public void beforeExecute(QueryClient queryClient, QueryContext queryContext) { + if (graphDefinePath != null) { + try { + String ddl = IOUtils.resourceToString(graphDefinePath, Charset.defaultCharset()); + queryClient.executeQuery(ddl, queryContext); + } catch (IOException e) { + throw new GeaFlowDSLException(e); } + } } + + @Override + public void afterExecute(QueryClient queryClient, QueryContext queryContext) {} + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/base/BaseTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/base/BaseTest.java index 423368b38..b53878120 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/base/BaseTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/base/BaseTest.java @@ -21,6 +21,7 @@ import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.cluster.system.ClusterMetaStore; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.env.Environment; @@ -34,35 +35,35 @@ public class BaseTest { - private static final Logger LOGGER = LoggerFactory.getLogger(BaseTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(BaseTest.class); - protected Map config; - protected Environment environment; + protected Map config; + protected Environment environment; - @BeforeMethod - public void cleanMetaStore() { - LOGGER.info("clean cluster meta store"); - ClusterMetaStore.close(); - environment = null; - } + @BeforeMethod + public void cleanMetaStore() { + LOGGER.info("clean cluster meta store"); + ClusterMetaStore.close(); + environment = null; + } - @BeforeMethod - public void setup() { - config = new HashMap<>(); - config.put(ExampleConfigKeys.GEAFLOW_SINK_TYPE.getKey(), SinkType.FILE_SINK.name()); - config.put(ExecutionConfigKeys.HTTP_REST_SERVICE_ENABLE.getKey(), "false"); - } + @BeforeMethod + public void setup() { + config = new HashMap<>(); + config.put(ExampleConfigKeys.GEAFLOW_SINK_TYPE.getKey(), SinkType.FILE_SINK.name()); + config.put(ExecutionConfigKeys.HTTP_REST_SERVICE_ENABLE.getKey(), "false"); + } - @AfterMethod - public void clean() { - if (environment != null) { - environment.shutdown(); - } - ClusterMetaStore.close(); - ScheduledWorkerManagerFactory.clear(); + @AfterMethod + public void clean() { + if (environment != null) { + environment.shutdown(); } + ClusterMetaStore.close(); + ScheduledWorkerManagerFactory.clear(); + } - public Map getConfig() { - return config; - } + public Map getConfig() { + return config; + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/dsl/DemoCaseTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/dsl/DemoCaseTest.java index 8c0b8c3cc..ea9cc0d44 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/dsl/DemoCaseTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/dsl/DemoCaseTest.java @@ -22,6 +22,7 @@ import java.io.File; import java.util.HashMap; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.keys.DSLConfigKeys; import org.apache.geaflow.example.base.BaseQueryTest; @@ -31,60 +32,58 @@ public class DemoCaseTest { - private final String TEST_GRAPH_PATH = "/tmp/geaflow/dsl/join2Graph/test/graph"; + private final String TEST_GRAPH_PATH = "/tmp/geaflow/dsl/join2Graph/test/graph"; - private final Map testConfig = new HashMap() { + private final Map testConfig = + new HashMap() { { - put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "DFS"); - put(FileConfigKeys.ROOT.getKey(), TEST_GRAPH_PATH); - put(FileConfigKeys.JSON_CONFIG.getKey(), "{\"fs.defaultFS\":\"local\"}"); - put(DSLConfigKeys.GEAFLOW_DSL_QUERY_PATH_TYPE.getKey(), "file_path"); - // If the test is conducted using the console catalog, the appended config is required. - // put(DSLConfigKeys.GEAFLOW_DSL_CATALOG_TYPE.getKey(), "console"); - // put(DSLConfigKeys.GEAFLOW_DSL_CATALOG_TOKEN_KEY.getKey(), ""); - // put(DSLConfigKeys.GEAFLOW_DSL_CATALOG_INSTANCE_NAME.getKey(), "test1"); - // put(ExecutionConfigKeys.GEAFLOW_GW_ENDPOINT.getKey(), "http://127.0.0.1:8080"); + put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "DFS"); + put(FileConfigKeys.ROOT.getKey(), TEST_GRAPH_PATH); + put(FileConfigKeys.JSON_CONFIG.getKey(), "{\"fs.defaultFS\":\"local\"}"); + put(DSLConfigKeys.GEAFLOW_DSL_QUERY_PATH_TYPE.getKey(), "file_path"); + // If the test is conducted using the console catalog, the appended config is required. + // put(DSLConfigKeys.GEAFLOW_DSL_CATALOG_TYPE.getKey(), "console"); + // put(DSLConfigKeys.GEAFLOW_DSL_CATALOG_TOKEN_KEY.getKey(), ""); + // put(DSLConfigKeys.GEAFLOW_DSL_CATALOG_INSTANCE_NAME.getKey(), "test1"); + // put(ExecutionConfigKeys.GEAFLOW_GW_ENDPOINT.getKey(), "http://127.0.0.1:8080"); } - }; + }; - @AfterClass - public void tearDown() throws Exception { - File file = new File(TEST_GRAPH_PATH); - if (file.exists()) { - FileUtils.deleteDirectory(file); - } + @AfterClass + public void tearDown() throws Exception { + File file = new File(TEST_GRAPH_PATH); + if (file.exists()) { + FileUtils.deleteDirectory(file); } + } - @Test - public void testQuickStartSqlJoinDemo_001() throws Exception { - String resultPath = "/tmp/geaflow/sql_join_to_graph_demo_result"; - File file = new File(resultPath); - if (file.exists()) { - FileUtils.deleteDirectory(file); - } - BaseQueryTest - .build() - .withoutPrefix() - .withConfig(testConfig) - .withQueryPath(System.getProperty("user.dir") + "/gql/sql_join_to_graph_demo.sql") - .execute() - .checkSinkResult(resultPath); + @Test + public void testQuickStartSqlJoinDemo_001() throws Exception { + String resultPath = "/tmp/geaflow/sql_join_to_graph_demo_result"; + File file = new File(resultPath); + if (file.exists()) { + FileUtils.deleteDirectory(file); } + BaseQueryTest.build() + .withoutPrefix() + .withConfig(testConfig) + .withQueryPath(System.getProperty("user.dir") + "/gql/sql_join_to_graph_demo.sql") + .execute() + .checkSinkResult(resultPath); + } - @Test - public void testQuickStartSqlJoinDemo_002() throws Exception { - String resultPath = "/tmp/geaflow/sql_join_to_graph_demo_02_result"; - File file = new File(resultPath); - if (file.exists()) { - FileUtils.deleteDirectory(file); - } - BaseQueryTest - .build() - .withoutPrefix() - .withConfig(testConfig) - .withQueryPath(System.getProperty("user.dir") + "/gql/sql_join_to_graph_demo_02.sql") - .execute() - .checkSinkResult(resultPath); + @Test + public void testQuickStartSqlJoinDemo_002() throws Exception { + String resultPath = "/tmp/geaflow/sql_join_to_graph_demo_02_result"; + File file = new File(resultPath); + if (file.exists()) { + FileUtils.deleteDirectory(file); } - + BaseQueryTest.build() + .withoutPrefix() + .withConfig(testConfig) + .withQueryPath(System.getProperty("user.dir") + "/gql/sql_join_to_graph_demo_02.sql") + .execute() + .checkSinkResult(resultPath); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/encoder/EncoderIntegrationTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/encoder/EncoderIntegrationTest.java index ec35c1822..caac5d06b 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/encoder/EncoderIntegrationTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/encoder/EncoderIntegrationTest.java @@ -19,13 +19,13 @@ package org.apache.geaflow.example.encoder; -import com.google.common.collect.Lists; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.util.Collections; import java.util.Iterator; import java.util.List; + import org.apache.geaflow.api.function.internal.CollectionSource; import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.PGraphWindow; @@ -69,430 +69,499 @@ import org.apache.geaflow.view.graph.PIncGraphView; import org.testng.annotations.Test; -public class EncoderIntegrationTest extends BaseTest { - - public static final String REF_PATH_PREFIX = "data/reference/encoder/"; - public static final String RES_PATH_PREFIX = "./target/tmp/data/result/encoder/"; +import com.google.common.collect.Lists; - public static final String TAG_KEY_BY = "key_by"; - public static final String TAG_VC = "vc"; - public static final String TAG_INC_VC = "inc_vc"; - public static final String TAG_STREAM = "stream"; +public class EncoderIntegrationTest extends BaseTest { - private static String getRefPath(String tag) { - return REF_PATH_PREFIX + tag; + public static final String REF_PATH_PREFIX = "data/reference/encoder/"; + public static final String RES_PATH_PREFIX = "./target/tmp/data/result/encoder/"; + + public static final String TAG_KEY_BY = "key_by"; + public static final String TAG_VC = "vc"; + public static final String TAG_INC_VC = "inc_vc"; + public static final String TAG_STREAM = "stream"; + + private static String getRefPath(String tag) { + return REF_PATH_PREFIX + tag; + } + + private static String getResPath(String tag) { + return RES_PATH_PREFIX + tag; + } + + @Test + public void testKeyByWithEncoder() throws Exception { + String tag = TAG_KEY_BY; + String resPath = getResPath(tag); + + ResultValidator.cleanResult(resPath); + this.environment = EnvironmentFactory.onLocalEnvironment(); + Configuration config = this.environment.getEnvironmentContext().getConfig(); + config.putAll(this.config); + config.put(FileSink.OUTPUT_DIR, resPath); + + Pipeline pipeline = PipelineFactory.buildPipeline(this.environment); + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + List words = Lists.newArrayList("hello", "world", "hello", "word"); + + IEncoder> tEncoder = + Encoders.tuple(Encoders.STRING, Encoders.INTEGER); + PStreamSource streamSource = + pipelineTaskCxt + .buildSource(new CollectionSource<>(words), SizeTumblingWindow.of(100)) + .withEncoder(Encoders.STRING) + .window(WindowFactory.createSizeTumblingWindow(4)); + + SinkFunction sink = + ExampleSinkFunctionFactory.getSinkFunction(pipelineTaskCxt.getConfig()); + streamSource + .map(s -> Tuple.of(s, 1)) + .withEncoder(tEncoder) + .keyBy(Tuple::getF0) + .withEncoder(tEncoder) + .map(String::valueOf) + .withEncoder(Encoders.STRING) + .sink(sink); + }); + + IPipelineResult result = pipeline.execute(); + result.get(); + ResultValidator.validateResult(getRefPath(tag), resPath); + } + + @Test + public void testVc() throws Exception { + String tag = TAG_VC; + String resPath = getResPath(tag); + + ResultValidator.cleanResult(resPath); + this.environment = EnvironmentFactory.onLocalEnvironment(); + Configuration config = this.environment.getEnvironmentContext().getConfig(); + config.putAll(this.config); + config.put(FileSink.OUTPUT_DIR, resPath); + + Pipeline pipeline = PipelineFactory.buildPipeline(this.environment); + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + int sourceParallelism = 3; + int iterationParallelism = 7; + int sinkParallelism = 5; + + PWindowStream> prVertices = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_vertex", + (FileSource.FileLineParser>) + line -> { + String[] fields = line.split(","); + IVertex vertex = + new ValueVertex<>( + Integer.valueOf(fields[0]), Double.valueOf(fields[1])); + return Collections.singletonList(vertex); + }), + AllWindow.getInstance()) + .withParallelism(sourceParallelism) + .withEncoder(VEncoder.INSTANCE); + + PWindowStream> prEdges = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_edge", + (FileSource.FileLineParser>) + line -> { + String[] fields = line.split(","); + IEdge edge = + new ValueEdge<>( + Integer.valueOf(fields[0]), + Integer.valueOf(fields[1]), + 1); + return Collections.singletonList(edge); + }), + AllWindow.getInstance()) + .withParallelism(sourceParallelism) + .withEncoder(EEncoder.INSTANCE); + + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) + .withShardNum(8) + .withBackend(IViewDesc.BackendType.Memory) + .build(); + PGraphWindow graphWindow = + pipelineTaskCxt.buildWindowStreamGraph(prVertices, prEdges, graphViewDesc); + + SinkFunction sink = + ExampleSinkFunctionFactory.getSinkFunction(pipelineTaskCxt.getConfig()); + graphWindow + .compute(new VcFunc(3)) + .compute(iterationParallelism) + .getVertices() + .withEncoder(VEncoder.INSTANCE) + .map(v -> v.getId() + " " + v.getValue()) + .withEncoder(Encoders.STRING) + .sink(sink) + .withParallelism(sinkParallelism); + }); + + IPipelineResult result = pipeline.execute(); + result.get(); + ResultValidator.validateResult(getRefPath(tag), resPath); + } + + private static class VcFunc extends VertexCentricCompute { + + public VcFunc(long iterations) { + super(iterations); } - private static String getResPath(String tag) { - return RES_PATH_PREFIX + tag; + @Override + public VertexCentricComputeFunction getComputeFunction() { + return new PRVertexCentricComputeFunction(); } - @Test - public void testKeyByWithEncoder() throws Exception { - String tag = TAG_KEY_BY; - String resPath = getResPath(tag); - - ResultValidator.cleanResult(resPath); - this.environment = EnvironmentFactory.onLocalEnvironment(); - Configuration config = this.environment.getEnvironmentContext().getConfig(); - config.putAll(this.config); - config.put(FileSink.OUTPUT_DIR, resPath); - - Pipeline pipeline = PipelineFactory.buildPipeline(this.environment); - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - List words = Lists.newArrayList("hello", "world", "hello", "word"); - - IEncoder> tEncoder = Encoders.tuple(Encoders.STRING, Encoders.INTEGER); - PStreamSource streamSource = - pipelineTaskCxt.buildSource(new CollectionSource<>(words), - SizeTumblingWindow.of(100)) - .withEncoder(Encoders.STRING) - .window(WindowFactory.createSizeTumblingWindow(4)); - - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(pipelineTaskCxt.getConfig()); - streamSource.map(s -> Tuple.of(s, 1)).withEncoder(tEncoder) - .keyBy(Tuple::getF0).withEncoder(tEncoder) - .map(String::valueOf).withEncoder(Encoders.STRING) - .sink(sink); - }); - - IPipelineResult result = pipeline.execute(); - result.get(); - ResultValidator.validateResult(getRefPath(tag), resPath); + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; } - @Test - public void testVc() throws Exception { - String tag = TAG_VC; - String resPath = getResPath(tag); - - ResultValidator.cleanResult(resPath); - this.environment = EnvironmentFactory.onLocalEnvironment(); - Configuration config = this.environment.getEnvironmentContext().getConfig(); - config.putAll(this.config); - config.put(FileSink.OUTPUT_DIR, resPath); - - Pipeline pipeline = PipelineFactory.buildPipeline(this.environment); - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - int sourceParallelism = 3; - int iterationParallelism = 7; - int sinkParallelism = 5; - - PWindowStream> prVertices = - pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_vertex", - (FileSource.FileLineParser>) line -> { - String[] fields = line.split(","); - IVertex vertex = new ValueVertex<>( - Integer.valueOf(fields[0]), Double.valueOf(fields[1])); - return Collections.singletonList(vertex); - }), AllWindow.getInstance()) - .withParallelism(sourceParallelism) - .withEncoder(VEncoder.INSTANCE); - - PWindowStream> prEdges = - pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_edge", - (FileSource.FileLineParser>) line -> { - String[] fields = line.split(","); - IEdge edge = new ValueEdge<>( - Integer.valueOf(fields[0]), Integer.valueOf(fields[1]), 1); - return Collections.singletonList(edge); - }), AllWindow.getInstance()) - .withParallelism(sourceParallelism) - .withEncoder(EEncoder.INSTANCE); - - GraphViewDesc graphViewDesc = GraphViewBuilder - .createGraphView(GraphViewBuilder.DEFAULT_GRAPH) - .withShardNum(8) - .withBackend(IViewDesc.BackendType.Memory) - .build(); - PGraphWindow graphWindow = - pipelineTaskCxt.buildWindowStreamGraph(prVertices, prEdges, graphViewDesc); - - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(pipelineTaskCxt.getConfig()); - graphWindow - .compute(new VcFunc(3)) - .compute(iterationParallelism) - .getVertices().withEncoder(VEncoder.INSTANCE) - .map(v -> v.getId() + " " + v.getValue()).withEncoder(Encoders.STRING) - .sink(sink) - .withParallelism(sinkParallelism); - }); - - IPipelineResult result = pipeline.execute(); - result.get(); - ResultValidator.validateResult(getRefPath(tag), resPath); + @Override + public IEncoder getKeyEncoder() { + return Encoders.INTEGER; } - private static class VcFunc extends VertexCentricCompute { - - public VcFunc(long iterations) { - super(iterations); - } - - @Override - public VertexCentricComputeFunction getComputeFunction() { - return new PRVertexCentricComputeFunction(); - } - - @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; - } - - @Override - public IEncoder getKeyEncoder() { - return Encoders.INTEGER; - } - - @Override - public IEncoder getMessageEncoder() { - return Encoders.DOUBLE; - } - + @Override + public IEncoder getMessageEncoder() { + return Encoders.DOUBLE; } + } - private static class PRVertexCentricComputeFunction - implements VertexCentricComputeFunction { + private static class PRVertexCentricComputeFunction + implements VertexCentricComputeFunction { - private VertexCentricComputeFunction.VertexCentricComputeFuncContext context; + private VertexCentricComputeFunction.VertexCentricComputeFuncContext< + Integer, Double, Integer, Double> + context; - @Override - public void init(VertexCentricComputeFuncContext context) { - this.context = context; - } - - @Override - public void compute(Integer vertexId, - Iterator messageIterator) { - IVertex vertex = this.context.vertex().get(); - if (this.context.getIterationId() == 1) { - this.context.sendMessageToNeighbors(vertex.getValue()); - } else { - double sum = 0; - while (messageIterator.hasNext()) { - double value = messageIterator.next(); - sum += value; - } - this.context.setNewVertexValue(sum); - } - } + @Override + public void init(VertexCentricComputeFuncContext context) { + this.context = context; + } - @Override - public void finish() { + @Override + public void compute(Integer vertexId, Iterator messageIterator) { + IVertex vertex = this.context.vertex().get(); + if (this.context.getIterationId() == 1) { + this.context.sendMessageToNeighbors(vertex.getValue()); + } else { + double sum = 0; + while (messageIterator.hasNext()) { + double value = messageIterator.next(); + sum += value; } - + this.context.setNewVertexValue(sum); + } } - @Test - public void testIncVC() throws Exception { - String tag = TAG_INC_VC; - String resPath = getResPath(tag); + @Override + public void finish() {} + } + + @Test + public void testIncVC() throws Exception { + String tag = TAG_INC_VC; + String resPath = getResPath(tag); - ResultValidator.cleanResult(resPath); - this.environment = EnvironmentFactory.onLocalEnvironment(); - Configuration config = this.environment.getEnvironmentContext().getConfig(); - config.putAll(this.config); - config.put(FileSink.OUTPUT_DIR, resPath); + ResultValidator.cleanResult(resPath); + this.environment = EnvironmentFactory.onLocalEnvironment(); + Configuration config = this.environment.getEnvironmentContext().getConfig(); + config.putAll(this.config); + config.put(FileSink.OUTPUT_DIR, resPath); - Pipeline pipeline = PipelineFactory.buildPipeline(this.environment); + Pipeline pipeline = PipelineFactory.buildPipeline(this.environment); - int sourceParallelism = 3; - int iterationParallelism = 4; - int sinkParallelism = 2; + int sourceParallelism = 3; + int iterationParallelism = 4; + int sinkParallelism = 2; - final String graphName = "graph_view_name"; - GraphViewDesc graphViewDesc = GraphViewBuilder.createGraphView(graphName) + final String graphName = "graph_view_name"; + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(graphName) .withShardNum(iterationParallelism) .withBackend(IViewDesc.BackendType.RocksDB) - .withSchema(new GraphMetaType<>(IntegerType.INSTANCE, ValueVertex.class, - Integer.class, ValueEdge.class, IntegerType.class)) + .withSchema( + new GraphMetaType<>( + IntegerType.INSTANCE, + ValueVertex.class, + Integer.class, + ValueEdge.class, + IntegerType.class)) .build(); - pipeline.withView(graphName, graphViewDesc); - - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - - PWindowSource> vertices = - pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_vertex", - (FileSource.FileLineParser>) line -> { - String[] fields = line.split(","); - ValueVertex vertex = new ValueVertex<>( - Integer.valueOf(fields[0]), Integer.valueOf(fields[1])); - return Collections.singletonList(vertex); - }), SizeTumblingWindow.of(500)) - .withParallelism(sourceParallelism).withEncoder(IncVEncoder.INSTANCE); - - PWindowSource> edges = - pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_edge", - (FileSource.FileLineParser>) line -> { - String[] fields = line.split(","); - IEdge edge = new ValueEdge<>(Integer.valueOf(fields[0]), - Integer.valueOf(fields[1]), 1); - return Collections.singletonList(edge); - }), SizeTumblingWindow.of(10000)).withParallelism(sourceParallelism).withEncoder(EEncoder.INSTANCE); - - PGraphView fundGraphView = - pipelineTaskCxt.getGraphView(graphName); - - PIncGraphView incGraphView = - fundGraphView.appendGraph(vertices, edges); - - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(pipelineTaskCxt.getConfig()); - incGraphView.incrementalCompute(new IncVcFunc(3)) - .getVertices().withEncoder(IncVEncoder.INSTANCE) - .map(v -> v.getId() + " " + v.getValue()).withEncoder(Encoders.STRING) - .sink(sink).withParallelism(sinkParallelism); - }); - - IPipelineResult result = pipeline.execute(); - result.get(); - ResultValidator.validateResult(getRefPath(tag), resPath); + pipeline.withView(graphName, graphViewDesc); + + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + PWindowSource> vertices = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_vertex", + (FileSource.FileLineParser>) + line -> { + String[] fields = line.split(","); + ValueVertex vertex = + new ValueVertex<>( + Integer.valueOf(fields[0]), Integer.valueOf(fields[1])); + return Collections.singletonList(vertex); + }), + SizeTumblingWindow.of(500)) + .withParallelism(sourceParallelism) + .withEncoder(IncVEncoder.INSTANCE); + + PWindowSource> edges = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_edge", + (FileSource.FileLineParser>) + line -> { + String[] fields = line.split(","); + IEdge edge = + new ValueEdge<>( + Integer.valueOf(fields[0]), + Integer.valueOf(fields[1]), + 1); + return Collections.singletonList(edge); + }), + SizeTumblingWindow.of(10000)) + .withParallelism(sourceParallelism) + .withEncoder(EEncoder.INSTANCE); + + PGraphView fundGraphView = + pipelineTaskCxt.getGraphView(graphName); + + PIncGraphView incGraphView = + fundGraphView.appendGraph(vertices, edges); + + SinkFunction sink = + ExampleSinkFunctionFactory.getSinkFunction(pipelineTaskCxt.getConfig()); + incGraphView + .incrementalCompute(new IncVcFunc(3)) + .getVertices() + .withEncoder(IncVEncoder.INSTANCE) + .map(v -> v.getId() + " " + v.getValue()) + .withEncoder(Encoders.STRING) + .sink(sink) + .withParallelism(sinkParallelism); + }); + + IPipelineResult result = pipeline.execute(); + result.get(); + ResultValidator.validateResult(getRefPath(tag), resPath); + } + + @Test + public void testStream() throws Exception { + String tag = TAG_STREAM; + String resPath = getResPath(tag); + + ResultValidator.cleanResult(resPath); + this.environment = EnvironmentFactory.onLocalEnvironment(); + Configuration config = this.environment.getEnvironmentContext().getConfig(); + config.putAll(this.config); + config.put(FileSink.OUTPUT_DIR, resPath); + + Pipeline pipeline = PipelineFactory.buildPipeline(this.environment); + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + int parallelism = 1; + IEncoder> tEncoder = + Encoders.tuple(Encoders.STRING, Encoders.INTEGER); + PStreamSource streamSource = + pipelineTaskCxt + .buildSource( + new FileSource<>( + "data/input/email_edge", + line -> { + String[] fields = line.split(","); + return Collections.singletonList(fields[0]); + }), + SizeTumblingWindow.of(5000)) + .withEncoder(Encoders.STRING) + .withParallelism(parallelism); + + SinkFunction sink = + ExampleSinkFunctionFactory.getSinkFunction(pipelineTaskCxt.getConfig()); + streamSource + .map(e -> Tuple.of(e, 1)) + .withEncoder(tEncoder) + .keyBy(Tuple::getF0) + .withEncoder(tEncoder) + .reduce(new StreamWordCountPipeline.CountFunc()) + .withEncoder(tEncoder) + .withParallelism(parallelism) + .map(String::valueOf) + .withEncoder(Encoders.STRING) + .withParallelism(parallelism) + .sink(sink) + .withParallelism(parallelism); + }); + + IPipelineResult result = pipeline.execute(); + result.get(); + ResultValidator.validateResult(getRefPath(tag), resPath); + } + + public static class IncVcFunc + extends IncVertexCentricCompute { + + public IncVcFunc(long iterations) { + super(iterations); } - @Test - public void testStream() throws Exception { - String tag = TAG_STREAM; - String resPath = getResPath(tag); - - ResultValidator.cleanResult(resPath); - this.environment = EnvironmentFactory.onLocalEnvironment(); - Configuration config = this.environment.getEnvironmentContext().getConfig(); - config.putAll(this.config); - config.put(FileSink.OUTPUT_DIR, resPath); - - Pipeline pipeline = PipelineFactory.buildPipeline(this.environment); - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - int parallelism = 1; - IEncoder> tEncoder = Encoders.tuple(Encoders.STRING, Encoders.INTEGER); - PStreamSource streamSource = pipelineTaskCxt.buildSource(new FileSource<>( - "data/input/email_edge", - line -> { - String[] fields = line.split(","); - return Collections.singletonList(fields[0]); - }), SizeTumblingWindow.of(5000)) - .withEncoder(Encoders.STRING) - .withParallelism(parallelism); - - SinkFunction sink = ExampleSinkFunctionFactory.getSinkFunction(pipelineTaskCxt.getConfig()); - streamSource - .map(e -> Tuple.of(e, 1)).withEncoder(tEncoder) - .keyBy(Tuple::getF0).withEncoder(tEncoder) - .reduce(new StreamWordCountPipeline.CountFunc()) - .withEncoder(tEncoder).withParallelism(parallelism) - .map(String::valueOf).withEncoder(Encoders.STRING).withParallelism(parallelism) - .sink(sink).withParallelism(parallelism); - }); - - IPipelineResult result = pipeline.execute(); - result.get(); - ResultValidator.validateResult(getRefPath(tag), resPath); + @Override + public IncVertexCentricComputeFunction + getIncComputeFunction() { + return new PRIncVcFunc(); } - public static class IncVcFunc extends IncVertexCentricCompute { - - public IncVcFunc(long iterations) { - super(iterations); - } - - @Override - public IncVertexCentricComputeFunction getIncComputeFunction() { - return new PRIncVcFunc(); - } - - @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; - } - - @Override - public IEncoder getKeyEncoder() { - return Encoders.INTEGER; - } - - @Override - public IEncoder getMessageEncoder() { - return Encoders.INTEGER; - } + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; + } + @Override + public IEncoder getKeyEncoder() { + return Encoders.INTEGER; } - public static class PRIncVcFunc implements IncVertexCentricComputeFunction { + @Override + public IEncoder getMessageEncoder() { + return Encoders.INTEGER; + } + } - private IncGraphComputeContext context; + public static class PRIncVcFunc + implements IncVertexCentricComputeFunction { - @Override - public void init(IncGraphComputeContext context) { - this.context = context; - } + private IncGraphComputeContext context; - @Override - public void evolve(Integer vertexId, - TemporaryGraph temporaryGraph) { - long lastVersionId = 0L; - IVertex vertex = temporaryGraph.getVertex(); - HistoricalGraph historicalGraph = this.context.getHistoricalGraph(); - if (vertex == null) { - vertex = historicalGraph.getSnapShot(lastVersionId).vertex().get(); - } - - if (vertex != null) { - List> newEs = temporaryGraph.getEdges(); - List> oldEs = - historicalGraph.getSnapShot(lastVersionId).edges().getOutEdges(); - if (newEs != null) { - for (IEdge edge : newEs) { - this.context.sendMessage(edge.getTargetId(), vertexId); - } - } - if (oldEs != null) { - for (IEdge edge : oldEs) { - this.context.sendMessage(edge.getTargetId(), vertexId); - } - } - } - } + @Override + public void init(IncGraphComputeContext context) { + this.context = context; + } - @Override - public void compute(Integer vertexId, Iterator messageIterator) { - int max = 0; - while (messageIterator.hasNext()) { - int value = messageIterator.next(); - max = Math.max(max, value); - } - this.context.getTemporaryGraph().updateVertexValue(max); + @Override + public void evolve(Integer vertexId, TemporaryGraph temporaryGraph) { + long lastVersionId = 0L; + IVertex vertex = temporaryGraph.getVertex(); + HistoricalGraph historicalGraph = + this.context.getHistoricalGraph(); + if (vertex == null) { + vertex = historicalGraph.getSnapShot(lastVersionId).vertex().get(); + } + + if (vertex != null) { + List> newEs = temporaryGraph.getEdges(); + List> oldEs = + historicalGraph.getSnapShot(lastVersionId).edges().getOutEdges(); + if (newEs != null) { + for (IEdge edge : newEs) { + this.context.sendMessage(edge.getTargetId(), vertexId); + } } - - @Override - public void finish(Integer vertexId, MutableGraph mutableGraph) { - IVertex vertex = this.context.getTemporaryGraph().getVertex(); - List> edges = this.context.getTemporaryGraph().getEdges(); - if (vertex != null) { - mutableGraph.addVertex(0, vertex); - this.context.collect(vertex); - } - if (edges != null) { - edges.forEach(edge -> mutableGraph.addEdge(0, edge)); - } + if (oldEs != null) { + for (IEdge edge : oldEs) { + this.context.sendMessage(edge.getTargetId(), vertexId); + } } - + } } - private static class VEncoder extends AbstractEncoder> { + @Override + public void compute(Integer vertexId, Iterator messageIterator) { + int max = 0; + while (messageIterator.hasNext()) { + int value = messageIterator.next(); + max = Math.max(max, value); + } + this.context.getTemporaryGraph().updateVertexValue(max); + } - private static final VEncoder INSTANCE = new VEncoder(); + @Override + public void finish(Integer vertexId, MutableGraph mutableGraph) { + IVertex vertex = this.context.getTemporaryGraph().getVertex(); + List> edges = this.context.getTemporaryGraph().getEdges(); + if (vertex != null) { + mutableGraph.addVertex(0, vertex); + this.context.collect(vertex); + } + if (edges != null) { + edges.forEach(edge -> mutableGraph.addEdge(0, edge)); + } + } + } - @Override - public void encode(IVertex data, OutputStream outputStream) throws IOException { - Encoders.INTEGER.encode(data.getId(), outputStream); - Encoders.DOUBLE.encode(data.getValue(), outputStream); - } + private static class VEncoder extends AbstractEncoder> { - @Override - public IVertex decode(InputStream inputStream) throws IOException { - Integer id = Encoders.INTEGER.decode(inputStream); - Double value = Encoders.DOUBLE.decode(inputStream); - return new ValueVertex<>(id, value); - } + private static final VEncoder INSTANCE = new VEncoder(); + @Override + public void encode(IVertex data, OutputStream outputStream) + throws IOException { + Encoders.INTEGER.encode(data.getId(), outputStream); + Encoders.DOUBLE.encode(data.getValue(), outputStream); } - private static class EEncoder extends AbstractEncoder> { + @Override + public IVertex decode(InputStream inputStream) throws IOException { + Integer id = Encoders.INTEGER.decode(inputStream); + Double value = Encoders.DOUBLE.decode(inputStream); + return new ValueVertex<>(id, value); + } + } - private static final EEncoder INSTANCE = new EEncoder(); + private static class EEncoder extends AbstractEncoder> { - @Override - public void encode(IEdge data, OutputStream outputStream) throws IOException { - Encoders.INTEGER.encode(data.getSrcId(), outputStream); - Encoders.INTEGER.encode(data.getTargetId(), outputStream); - Encoders.INTEGER.encode(data.getValue(), outputStream); - } - - @Override - public IEdge decode(InputStream inputStream) throws IOException { - Integer src = Encoders.INTEGER.decode(inputStream); - Integer dst = Encoders.INTEGER.decode(inputStream); - Integer value = Encoders.INTEGER.decode(inputStream); - return new ValueEdge<>(src, dst, value); - } + private static final EEncoder INSTANCE = new EEncoder(); + @Override + public void encode(IEdge data, OutputStream outputStream) throws IOException { + Encoders.INTEGER.encode(data.getSrcId(), outputStream); + Encoders.INTEGER.encode(data.getTargetId(), outputStream); + Encoders.INTEGER.encode(data.getValue(), outputStream); } - private static class IncVEncoder extends AbstractEncoder> { - - private static final IncVEncoder INSTANCE = new IncVEncoder(); + @Override + public IEdge decode(InputStream inputStream) throws IOException { + Integer src = Encoders.INTEGER.decode(inputStream); + Integer dst = Encoders.INTEGER.decode(inputStream); + Integer value = Encoders.INTEGER.decode(inputStream); + return new ValueEdge<>(src, dst, value); + } + } - @Override - public void encode(IVertex data, OutputStream outputStream) throws IOException { - Encoders.INTEGER.encode(data.getId(), outputStream); - Encoders.INTEGER.encode(data.getValue(), outputStream); - } + private static class IncVEncoder extends AbstractEncoder> { - @Override - public IVertex decode(InputStream inputStream) throws IOException { - Integer id = Encoders.INTEGER.decode(inputStream); - Integer value = Encoders.INTEGER.decode(inputStream); - return new ValueVertex<>(id, value); - } + private static final IncVEncoder INSTANCE = new IncVEncoder(); + @Override + public void encode(IVertex data, OutputStream outputStream) + throws IOException { + Encoders.INTEGER.encode(data.getId(), outputStream); + Encoders.INTEGER.encode(data.getValue(), outputStream); } + @Override + public IVertex decode(InputStream inputStream) throws IOException { + Integer id = Encoders.INTEGER.decode(inputStream); + Integer value = Encoders.INTEGER.decode(inputStream); + return new ValueVertex<>(id, value); + } + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/fo/BaseFoTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/fo/BaseFoTest.java index 3b9216aff..70d130eb6 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/fo/BaseFoTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/fo/BaseFoTest.java @@ -20,6 +20,7 @@ package org.apache.geaflow.example.fo; import java.security.Permission; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.testng.annotations.AfterClass; @@ -27,35 +28,32 @@ public class BaseFoTest { - private static final Logger LOGGER = LoggerFactory.getLogger(BaseFoTest.class); - private static SecurityManager securityManager; + private static final Logger LOGGER = LoggerFactory.getLogger(BaseFoTest.class); + private static SecurityManager securityManager; + @BeforeClass + public void before() { + securityManager = System.getSecurityManager(); + System.setSecurityManager(new SystemExitIgnoreSecurityManager()); + } - @BeforeClass - public void before() { - securityManager = System.getSecurityManager(); - System.setSecurityManager(new SystemExitIgnoreSecurityManager()); - } + @AfterClass + public void after() { + System.setSecurityManager(securityManager); + } - @AfterClass - public void after() { - System.setSecurityManager(securityManager); - } + public static class SystemExitIgnoreSecurityManager extends SecurityManager { + @Override + public void checkPermission(Permission perm) {} + + @Override + public void checkPermission(Permission perm, Object context) {} - public static class SystemExitIgnoreSecurityManager extends SecurityManager { - @Override - public void checkPermission(Permission perm) { - } - - @Override - public void checkPermission(Permission perm, Object context) { - } - - @Override - public void checkExit(int status) { - super.checkExit(status); - LOGGER.info("check exit {}", status); - throw new RuntimeException("throw exception instead of exit process"); - } + @Override + public void checkExit(int status) { + super.checkExit(status); + LOGGER.info("check exit {}", status); + throw new RuntimeException("throw exception instead of exit process"); } + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/fo/MockContext.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/fo/MockContext.java index ebb40a3a1..b15cc5fb9 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/fo/MockContext.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/fo/MockContext.java @@ -25,32 +25,31 @@ public class MockContext { - static class MockContainerContext extends ContainerContext { + static class MockContainerContext extends ContainerContext { - private boolean isRecover; + private boolean isRecover; - public MockContainerContext(int index, Configuration config, boolean isRecover) { - super(index, config); - this.isRecover = isRecover; - } - - public boolean isRecover() { - return isRecover; - } + public MockContainerContext(int index, Configuration config, boolean isRecover) { + super(index, config); + this.isRecover = isRecover; } - static class MockDriverContext extends DriverContext { + public boolean isRecover() { + return isRecover; + } + } - private boolean isRecover; + static class MockDriverContext extends DriverContext { - public MockDriverContext(int index, Configuration config, boolean isRecover) { - super(index, 0, config); - this.isRecover = isRecover; - } + private boolean isRecover; - public boolean isRecover() { - return isRecover; - } + public MockDriverContext(int index, Configuration config, boolean isRecover) { + super(index, 0, config); + this.isRecover = isRecover; + } + public boolean isRecover() { + return isRecover; } + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/GraphAlgorithmsTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/GraphAlgorithmsTest.java index cdd1d1349..4c8af7a36 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/GraphAlgorithmsTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/GraphAlgorithmsTest.java @@ -21,6 +21,7 @@ import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.env.EnvironmentFactory; @@ -50,243 +51,239 @@ @Test(singleThreaded = true) public class GraphAlgorithmsTest extends BaseTest { - private static final Logger LOGGER = LoggerFactory.getLogger(GraphAlgorithmsTest.class); - - private static final Map TEST_CONFIG = new HashMap<>(); - - static { - TEST_CONFIG.put(ExampleConfigKeys.SOURCE_PARALLELISM.getKey(), String.valueOf(3)); - TEST_CONFIG.put(ExampleConfigKeys.ITERATOR_PARALLELISM.getKey(), String.valueOf(4)); - TEST_CONFIG.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), String.valueOf(5)); - TEST_CONFIG.put(ExampleConfigKeys.GEAFLOW_SINK_TYPE.getKey(), SinkType.FILE_SINK.name()); - } + private static final Logger LOGGER = LoggerFactory.getLogger(GraphAlgorithmsTest.class); - public static class GraphAlgorithmTestFactory { - @Factory - public Object[] factoryMethod() { - return new Object[]{ - new GraphAlgorithmsTest(true), - new GraphAlgorithmsTest(false), - }; - } - } + private static final Map TEST_CONFIG = new HashMap<>(); - private final boolean memoryPool; + static { + TEST_CONFIG.put(ExampleConfigKeys.SOURCE_PARALLELISM.getKey(), String.valueOf(3)); + TEST_CONFIG.put(ExampleConfigKeys.ITERATOR_PARALLELISM.getKey(), String.valueOf(4)); + TEST_CONFIG.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), String.valueOf(5)); + TEST_CONFIG.put(ExampleConfigKeys.GEAFLOW_SINK_TYPE.getKey(), SinkType.FILE_SINK.name()); + } - public GraphAlgorithmsTest(boolean memoryPool) { - this.memoryPool = memoryPool; + public static class GraphAlgorithmTestFactory { + @Factory + public Object[] factoryMethod() { + return new Object[] { + new GraphAlgorithmsTest(true), new GraphAlgorithmsTest(false), + }; } + } - @Test - public void weakConnectedComponentsTest() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration config = environment.getEnvironmentContext().getConfig(); - config.putAll(TEST_CONFIG); - config.put(ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), - Boolean.toString(memoryPool)); - - IPipelineResult result = WeakConnectedComponents.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - WeakConnectedComponents.validateResult(); - } + private final boolean memoryPool; + public GraphAlgorithmsTest(boolean memoryPool) { + this.memoryPool = memoryPool; + } - @Test - public void allShortestPathTest() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration config = environment.getEnvironmentContext().getConfig(); - config.putAll(TEST_CONFIG); - config.put(ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), - Boolean.toString(memoryPool)); + @Test + public void weakConnectedComponentsTest() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration config = environment.getEnvironmentContext().getConfig(); + config.putAll(TEST_CONFIG); + config.put( + ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), Boolean.toString(memoryPool)); - IPipelineResult result = AllShortestPath.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - AllShortestPath.validateResult(); + IPipelineResult result = WeakConnectedComponents.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void closenessCentralityTest() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration config = environment.getEnvironmentContext().getConfig(); - config.putAll(TEST_CONFIG); - config.put(ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), - Boolean.toString(memoryPool)); - - IPipelineResult result = ClosenessCentrality.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - ClosenessCentrality.validateResult(); + WeakConnectedComponents.validateResult(); + } + + @Test + public void allShortestPathTest() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration config = environment.getEnvironmentContext().getConfig(); + config.putAll(TEST_CONFIG); + config.put( + ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), Boolean.toString(memoryPool)); + + IPipelineResult result = AllShortestPath.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void clusterCoefficientTest() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration config = environment.getEnvironmentContext().getConfig(); - config.putAll(TEST_CONFIG); - config.put(ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), - Boolean.toString(memoryPool)); - - IPipelineResult result = ClusterCoefficient.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - ClusterCoefficient.validateResult(); + AllShortestPath.validateResult(); + } + + @Test + public void closenessCentralityTest() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration config = environment.getEnvironmentContext().getConfig(); + config.putAll(TEST_CONFIG); + config.put( + ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), Boolean.toString(memoryPool)); + + IPipelineResult result = ClosenessCentrality.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void commonNeighborsTest() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration config = environment.getEnvironmentContext().getConfig(); - config.putAll(TEST_CONFIG); - config.put(ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), - Boolean.toString(memoryPool)); - - IPipelineResult result = CommonNeighbors.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - CommonNeighbors.validateResult(); + ClosenessCentrality.validateResult(); + } + + @Test + public void clusterCoefficientTest() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration config = environment.getEnvironmentContext().getConfig(); + config.putAll(TEST_CONFIG); + config.put( + ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), Boolean.toString(memoryPool)); + + IPipelineResult result = ClusterCoefficient.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void kCoreTest() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration config = environment.getEnvironmentContext().getConfig(); - config.putAll(TEST_CONFIG); - config.put(ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), - Boolean.toString(memoryPool)); - - IPipelineResult result = KCore.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - KCore.validateResult(); + ClusterCoefficient.validateResult(); + } + + @Test + public void commonNeighborsTest() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration config = environment.getEnvironmentContext().getConfig(); + config.putAll(TEST_CONFIG); + config.put( + ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), Boolean.toString(memoryPool)); + + IPipelineResult result = CommonNeighbors.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void labelPropagationTest() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration config = environment.getEnvironmentContext().getConfig(); - config.putAll(TEST_CONFIG); - config.put(ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), - Boolean.toString(memoryPool)); - - IPipelineResult result = LabelPropagation.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - LabelPropagation.validateResult(); + CommonNeighbors.validateResult(); + } + + @Test + public void kCoreTest() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration config = environment.getEnvironmentContext().getConfig(); + config.putAll(TEST_CONFIG); + config.put( + ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), Boolean.toString(memoryPool)); + + IPipelineResult result = KCore.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - - @Test - public void linkPredictionTest() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration config = environment.getEnvironmentContext().getConfig(); - config.putAll(TEST_CONFIG); - config.put(ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), - Boolean.toString(memoryPool)); - - IPipelineResult result = LinkPrediction.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - LinkPrediction.validateResult(); + KCore.validateResult(); + } + + @Test + public void labelPropagationTest() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration config = environment.getEnvironmentContext().getConfig(); + config.putAll(TEST_CONFIG); + config.put( + ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), Boolean.toString(memoryPool)); + + IPipelineResult result = LabelPropagation.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void nPathsTest() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration config = environment.getEnvironmentContext().getConfig(); - config.putAll(TEST_CONFIG); - config.put(ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), - Boolean.toString(memoryPool)); - - IPipelineResult result = NPaths.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - NPaths.validateResult(); + LabelPropagation.validateResult(); + } + + @Test + public void linkPredictionTest() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration config = environment.getEnvironmentContext().getConfig(); + config.putAll(TEST_CONFIG); + config.put( + ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), Boolean.toString(memoryPool)); + + IPipelineResult result = LinkPrediction.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void pageRankTest() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration config = environment.getEnvironmentContext().getConfig(); - config.putAll(TEST_CONFIG); - config.put(ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), - Boolean.toString(memoryPool)); - - IPipelineResult result = PageRank.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - PageRank.validateResult(); + LinkPrediction.validateResult(); + } + + @Test + public void nPathsTest() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration config = environment.getEnvironmentContext().getConfig(); + config.putAll(TEST_CONFIG); + config.put( + ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), Boolean.toString(memoryPool)); + + IPipelineResult result = NPaths.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void personalRankTest() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration config = environment.getEnvironmentContext().getConfig(); - config.putAll(TEST_CONFIG); - config.put(ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), - Boolean.toString(memoryPool)); - - IPipelineResult result = PersonalRank.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - PersonalRank.validateResult(); + NPaths.validateResult(); + } + + @Test + public void pageRankTest() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration config = environment.getEnvironmentContext().getConfig(); + config.putAll(TEST_CONFIG); + config.put( + ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), Boolean.toString(memoryPool)); + + IPipelineResult result = PageRank.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void shortestPathTest() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration config = environment.getEnvironmentContext().getConfig(); - config.putAll(TEST_CONFIG); - config.put(ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), - Boolean.toString(memoryPool)); - - IPipelineResult result = ShortestPath.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - ShortestPath.validateResult(); + PageRank.validateResult(); + } + + @Test + public void personalRankTest() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration config = environment.getEnvironmentContext().getConfig(); + config.putAll(TEST_CONFIG); + config.put( + ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), Boolean.toString(memoryPool)); + + IPipelineResult result = PersonalRank.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void shortestPathOfVertexSetTest() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration config = environment.getEnvironmentContext().getConfig(); - config.putAll(TEST_CONFIG); - config.put(ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), - Boolean.toString(memoryPool)); - - IPipelineResult result = ShortestPathOfVertexSet.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - ShortestPathOfVertexSet.validateResult(); + PersonalRank.validateResult(); + } + + @Test + public void shortestPathTest() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration config = environment.getEnvironmentContext().getConfig(); + config.putAll(TEST_CONFIG); + config.put( + ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), Boolean.toString(memoryPool)); + + IPipelineResult result = ShortestPath.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void SSSPTest() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration config = environment.getEnvironmentContext().getConfig(); - config.putAll(TEST_CONFIG); - config.put(ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), - Boolean.toString(memoryPool)); - - IPipelineResult result = SSSP.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - SSSP.validateResult(); + ShortestPath.validateResult(); + } + + @Test + public void shortestPathOfVertexSetTest() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration config = environment.getEnvironmentContext().getConfig(); + config.putAll(TEST_CONFIG); + config.put( + ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), Boolean.toString(memoryPool)); + + IPipelineResult result = ShortestPathOfVertexSet.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - + ShortestPathOfVertexSet.validateResult(); + } + + @Test + public void SSSPTest() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration config = environment.getEnvironmentContext().getConfig(); + config.putAll(TEST_CONFIG); + config.put( + ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE.getKey(), Boolean.toString(memoryPool)); + + IPipelineResult result = SSSP.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); + } + SSSP.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/KHopTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/KHopTest.java index ad74312c3..efc15d489 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/KHopTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/KHopTest.java @@ -32,23 +32,24 @@ public class KHopTest extends BaseTest { - @Test - public void testMainInvoke() { - System.setProperty(CLUSTER_TYPE, LOCAL_CLUSTER); - KHop.main(null); - } + @Test + public void testMainInvoke() { + System.setProperty(CLUSTER_TYPE, LOCAL_CLUSTER); + KHop.main(null); + } - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = ((EnvironmentContext) environment.getEnvironmentContext()).getConfig(); - configuration.putAll(config); + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = + ((EnvironmentContext) environment.getEnvironmentContext()).getConfig(); + configuration.putAll(config); - KHop pipeline = new KHop("0", 2); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + KHop pipeline = new KHop("0", 2); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } + pipeline.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/PageRankTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/PageRankTest.java index 5e623e0b6..95f567cf6 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/PageRankTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/PageRankTest.java @@ -36,151 +36,149 @@ public class PageRankTest extends BaseTest { - public static class TestPageRankFactory { - - @Factory - public Object[] factoryMethod() { - return new Object[]{ - new PageRankTest(true), - new PageRankTest(false), - }; - } + public static class TestPageRankFactory { + @Factory + public Object[] factoryMethod() { + return new Object[] { + new PageRankTest(true), new PageRankTest(false), + }; } - - private final boolean prefetch; - - public PageRankTest(boolean prefetch) { - this.prefetch = prefetch; + } + + private final boolean prefetch; + + public PageRankTest(boolean prefetch) { + this.prefetch = prefetch; + } + + @Test + public void testMainInvoke() { + System.setProperty(CLUSTER_TYPE, LOCAL_CLUSTER); + PageRank.main(null); + } + + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); + configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); + configuration.put(ExecutionConfigKeys.CONTAINER_WORKER_NUM, String.valueOf(1)); + + PageRank pipeline = new PageRank(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void testMainInvoke() { - System.setProperty(CLUSTER_TYPE, LOCAL_CLUSTER); - PageRank.main(null); + pipeline.validateResult(); + } + + @Test + public void test1() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); + + PageRank pipeline = new PageRank(); + config.put(SOURCE_PARALLELISM.getKey(), "2"); + config.put(ITERATOR_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "4"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); - configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); - configuration.put(ExecutionConfigKeys.CONTAINER_WORKER_NUM, String.valueOf(1)); - - PageRank pipeline = new PageRank(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void test2() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); + + PageRank pipeline = new PageRank(); + config.put(SOURCE_PARALLELISM.getKey(), "1"); + config.put(ITERATOR_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "4"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void test1() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); - - PageRank pipeline = new PageRank(); - config.put(SOURCE_PARALLELISM.getKey(), "2"); - config.put(ITERATOR_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "4"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void test3() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); + + PageRank pipeline = new PageRank(); + config.put(SOURCE_PARALLELISM.getKey(), "1"); + config.put(ITERATOR_PARALLELISM.getKey(), "1"); + config.put(SINK_PARALLELISM.getKey(), "4"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void test2() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); - - PageRank pipeline = new PageRank(); - config.put(SOURCE_PARALLELISM.getKey(), "1"); - config.put(ITERATOR_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "4"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void test4() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); + + PageRank pipeline = new PageRank(); + config.put(SOURCE_PARALLELISM.getKey(), "2"); + config.put(ITERATOR_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "2"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void test3() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); - - PageRank pipeline = new PageRank(); - config.put(SOURCE_PARALLELISM.getKey(), "1"); - config.put(ITERATOR_PARALLELISM.getKey(), "1"); - config.put(SINK_PARALLELISM.getKey(), "4"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); - } - - @Test - public void test4() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); - - PageRank pipeline = new PageRank(); - config.put(SOURCE_PARALLELISM.getKey(), "2"); - config.put(ITERATOR_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "2"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void test5() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); + + PageRank pipeline = new PageRank(); + config.put(SOURCE_PARALLELISM.getKey(), "1"); + config.put(ITERATOR_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "2"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void test5() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); - - PageRank pipeline = new PageRank(); - config.put(SOURCE_PARALLELISM.getKey(), "1"); - config.put(ITERATOR_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "2"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); - } - - @Test - public void test6() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); - - PageRank pipeline = new PageRank(); - config.put(SOURCE_PARALLELISM.getKey(), "1"); - config.put(ITERATOR_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "1"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void test6() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); + + PageRank pipeline = new PageRank(); + config.put(SOURCE_PARALLELISM.getKey(), "1"); + config.put(ITERATOR_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "1"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } + pipeline.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/SSSPTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/SSSPTest.java index 346bf65f7..c6ee32f24 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/SSSPTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/SSSPTest.java @@ -34,125 +34,125 @@ public class SSSPTest extends BaseTest { - @Test - public void testMainInvoke() { - System.setProperty(CLUSTER_TYPE, LOCAL_CLUSTER); - SSSP.main(null); + @Test + public void testMainInvoke() { + System.setProperty(CLUSTER_TYPE, LOCAL_CLUSTER); + SSSP.main(null); + } + + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); + + SSSP pipeline = new SSSP(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); - - SSSP pipeline = new SSSP(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); - } - - @Test - public void test1() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - SSSP pipeline = new SSSP(); - config.put(SOURCE_PARALLELISM.getKey(), "2"); - config.put(ITERATOR_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "4"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void test1() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + SSSP pipeline = new SSSP(); + config.put(SOURCE_PARALLELISM.getKey(), "2"); + config.put(ITERATOR_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "4"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void test2() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - SSSP pipeline = new SSSP(); - config.put(SOURCE_PARALLELISM.getKey(), "1"); - config.put(ITERATOR_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "4"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void test2() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + SSSP pipeline = new SSSP(); + config.put(SOURCE_PARALLELISM.getKey(), "1"); + config.put(ITERATOR_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "4"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void test3() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - SSSP pipeline = new SSSP(); - config.put(SOURCE_PARALLELISM.getKey(), "1"); - config.put(ITERATOR_PARALLELISM.getKey(), "1"); - config.put(SINK_PARALLELISM.getKey(), "4"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void test3() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + SSSP pipeline = new SSSP(); + config.put(SOURCE_PARALLELISM.getKey(), "1"); + config.put(ITERATOR_PARALLELISM.getKey(), "1"); + config.put(SINK_PARALLELISM.getKey(), "4"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void test4() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - SSSP pipeline = new SSSP(); - config.put(SOURCE_PARALLELISM.getKey(), "2"); - config.put(ITERATOR_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "2"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void test4() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + SSSP pipeline = new SSSP(); + config.put(SOURCE_PARALLELISM.getKey(), "2"); + config.put(ITERATOR_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "2"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void test5() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - SSSP pipeline = new SSSP(); - config.put(SOURCE_PARALLELISM.getKey(), "1"); - config.put(ITERATOR_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "2"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void test5() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + SSSP pipeline = new SSSP(); + config.put(SOURCE_PARALLELISM.getKey(), "1"); + config.put(ITERATOR_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "2"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void test6() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - SSSP pipeline = new SSSP(); - config.put(SOURCE_PARALLELISM.getKey(), "1"); - config.put(ITERATOR_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "1"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void test6() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + SSSP pipeline = new SSSP(); + config.put(SOURCE_PARALLELISM.getKey(), "1"); + config.put(ITERATOR_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "1"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } + pipeline.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/IncGraphComputeTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/IncGraphComputeTest.java index 5e7968cf7..375944e81 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/IncGraphComputeTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/IncGraphComputeTest.java @@ -27,6 +27,7 @@ import java.io.File; import java.util.HashMap; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -42,142 +43,144 @@ public class IncGraphComputeTest extends BaseTest { - private static final Logger LOGGER = LoggerFactory.getLogger(IncGraphComputeTest.class); - - private Map config; - - @BeforeMethod - public void setUp() { - config = new HashMap<>(); - config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), getClass().getSimpleName()); - config.put(FileConfigKeys.ROOT.getKey(), "/tmp/"); - String path = config.get(FileConfigKeys.ROOT.getKey()) + config.get(ExecutionConfigKeys.JOB_APP_NAME.getKey()); - FileUtils.deleteQuietly(new File(path)); - } - - @Test - public void test1ShardWithSingleConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphCompute pipeline = new IncrGraphCompute(); - - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithTwoSourceVCMapOneSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphCompute pipeline = new IncrGraphCompute(); - - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(MAP_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(1)); - - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithTwoSourceVCFourMapSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphCompute pipeline = new IncrGraphCompute(); - - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(MAP_PARALLELISM.getKey(), String.valueOf(4)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); - - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithOneSourceVCMapFourSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphCompute pipeline = new IncrGraphCompute(); - - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); - - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test1ShardWithOneSourceVCMapFourSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphCompute pipeline = new IncrGraphCompute(); - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(1)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithTwoSourceVCOneMapTwoSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphCompute pipeline = new IncrGraphCompute(); - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(2)); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithOneSourceVCOneMapTwoSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphCompute pipeline = new IncrGraphCompute(); - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(2)); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithOneSourceVCMapSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphCompute pipeline = new IncrGraphCompute(); - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(1)); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } + private static final Logger LOGGER = LoggerFactory.getLogger(IncGraphComputeTest.class); + + private Map config; + + @BeforeMethod + public void setUp() { + config = new HashMap<>(); + config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), getClass().getSimpleName()); + config.put(FileConfigKeys.ROOT.getKey(), "/tmp/"); + String path = + config.get(FileConfigKeys.ROOT.getKey()) + + config.get(ExecutionConfigKeys.JOB_APP_NAME.getKey()); + FileUtils.deleteQuietly(new File(path)); + } + + @Test + public void test1ShardWithSingleConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphCompute pipeline = new IncrGraphCompute(); + + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithTwoSourceVCMapOneSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphCompute pipeline = new IncrGraphCompute(); + + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(MAP_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(1)); + + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithTwoSourceVCFourMapSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphCompute pipeline = new IncrGraphCompute(); + + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(MAP_PARALLELISM.getKey(), String.valueOf(4)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); + + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithOneSourceVCMapFourSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphCompute pipeline = new IncrGraphCompute(); + + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); + + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test1ShardWithOneSourceVCMapFourSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphCompute pipeline = new IncrGraphCompute(); + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(1)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithTwoSourceVCOneMapTwoSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphCompute pipeline = new IncrGraphCompute(); + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(2)); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithOneSourceVCOneMapTwoSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphCompute pipeline = new IncrGraphCompute(); + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(2)); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithOneSourceVCMapSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphCompute pipeline = new IncrGraphCompute(); + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(1)); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/IncGraphOperatorTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/IncGraphOperatorTest.java index d540bcfd8..8da1fefcf 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/IncGraphOperatorTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/IncGraphOperatorTest.java @@ -19,11 +19,11 @@ package org.apache.geaflow.example.graph.inc; -import com.google.common.collect.Lists; import java.lang.reflect.Field; import java.util.Iterator; import java.util.List; import java.util.Set; + import org.apache.geaflow.api.function.internal.CollectionSource; import org.apache.geaflow.api.graph.compute.IncVertexCentricCompute; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricComputeFunction; @@ -61,163 +61,177 @@ import org.testng.Assert; import org.testng.annotations.Test; +import com.google.common.collect.Lists; + public class IncGraphOperatorTest extends BaseTest { - private static final Logger LOGGER = LoggerFactory.getLogger(IncGraphOperatorTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(IncGraphOperatorTest.class); - @Test - public void testDynamicGraphVertexCentricComputeOp() { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration config = environment.getEnvironmentContext().getConfig(); - config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), getClass().getSimpleName()); - config.put(FileConfigKeys.ROOT.getKey(), "/tmp/"); + @Test + public void testDynamicGraphVertexCentricComputeOp() { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration config = environment.getEnvironmentContext().getConfig(); + config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), getClass().getSimpleName()); + config.put(FileConfigKeys.ROOT.getKey(), "/tmp/"); - Pipeline pipeline = PipelineFactory.buildPipeline(environment); + Pipeline pipeline = PipelineFactory.buildPipeline(environment); - int sourceParallelism = 1; - int shardNum = 1; - int iterationParallelism = 1; - int sinkParallelism = 1; //build graph view + int sourceParallelism = 1; + int shardNum = 1; + int iterationParallelism = 1; + int sinkParallelism = 1; // build graph view - final String graphName = "graph_view_name"; - GraphViewDesc graphViewDesc = GraphViewBuilder.createGraphView(graphName) + final String graphName = "graph_view_name"; + GraphViewDesc graphViewDesc = + GraphViewBuilder.createGraphView(graphName) .withShardNum(shardNum) .withBackend(IViewDesc.BackendType.RocksDB) - .withSchema(new GraphMetaType<>(IntegerType.INSTANCE, ValueVertex.class, - Object.class, ValueEdge.class, Object.class)) + .withSchema( + new GraphMetaType<>( + IntegerType.INSTANCE, + ValueVertex.class, + Object.class, + ValueEdge.class, + Object.class)) .build(); - pipeline.withView(graphName, graphViewDesc); - - pipeline.submit((PipelineTask) pipelineTaskCxt -> { - List> vertexSource = Lists.newArrayList( - new ValueVertex<>(1), new ValueVertex<>(2), - new ValueVertex<>(3), new ValueVertex<>(4), - new ValueVertex<>(5), new ValueVertex<>(6)); - - List> edgeSource = Lists.newArrayList( - new ValueEdge<>(1, 2), - new ValueEdge<>(3, 4), - new ValueEdge<>(5, 6)); - - PWindowSource> vertices = - pipelineTaskCxt.buildSource(new CollectionSource<>(vertexSource), - SizeTumblingWindow.of(2)) - .window(WindowFactory.createSizeTumblingWindow(3)) - .withParallelism(sourceParallelism); - - PWindowSource> edges = - pipelineTaskCxt.buildSource(new CollectionSource<>(edgeSource), - SizeTumblingWindow.of(1)) - .window(WindowFactory.createSizeTumblingWindow(3)) - .withParallelism(sourceParallelism); - - PGraphView fundGraphView = - pipelineTaskCxt.getGraphView(graphName); - - PIncGraphView incGraphView = - fundGraphView.appendGraph(vertices, edges); - - incGraphView.incrementalCompute(new IncGraphAlgorithms(2)) - // incremental compute operator with 2 parallelism. - .compute(iterationParallelism) - .getVertices() - .sink(v -> { - }) - .withParallelism(sinkParallelism); - - }); - IPipelineResult result = pipeline.execute(); - result.get(); + pipeline.withView(graphName, graphViewDesc); + + pipeline.submit( + (PipelineTask) + pipelineTaskCxt -> { + List> vertexSource = + Lists.newArrayList( + new ValueVertex<>(1), new ValueVertex<>(2), + new ValueVertex<>(3), new ValueVertex<>(4), + new ValueVertex<>(5), new ValueVertex<>(6)); + + List> edgeSource = + Lists.newArrayList( + new ValueEdge<>(1, 2), new ValueEdge<>(3, 4), new ValueEdge<>(5, 6)); + + PWindowSource> vertices = + pipelineTaskCxt + .buildSource(new CollectionSource<>(vertexSource), SizeTumblingWindow.of(2)) + .window(WindowFactory.createSizeTumblingWindow(3)) + .withParallelism(sourceParallelism); + + PWindowSource> edges = + pipelineTaskCxt + .buildSource(new CollectionSource<>(edgeSource), SizeTumblingWindow.of(1)) + .window(WindowFactory.createSizeTumblingWindow(3)) + .withParallelism(sourceParallelism); + + PGraphView fundGraphView = + pipelineTaskCxt.getGraphView(graphName); + + PIncGraphView incGraphView = + fundGraphView.appendGraph(vertices, edges); + + incGraphView + .incrementalCompute(new IncGraphAlgorithms(2)) + // incremental compute operator with 2 parallelism. + .compute(iterationParallelism) + .getVertices() + .sink(v -> {}) + .withParallelism(sinkParallelism); + }); + IPipelineResult result = pipeline.execute(); + result.get(); + } + + public static class IncGraphAlgorithms + extends IncVertexCentricCompute { + + public IncGraphAlgorithms(long iterations) { + super(iterations); } - public static class IncGraphAlgorithms extends IncVertexCentricCompute { - - public IncGraphAlgorithms(long iterations) { - super(iterations); - } + @Override + public IncVertexCentricComputeFunction + getIncComputeFunction() { + return new PRVertexCentricComputeFunction(); + } - @Override - public IncVertexCentricComputeFunction getIncComputeFunction() { - return new PRVertexCentricComputeFunction(); - } + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; + } + } + + public static class PRVertexCentricComputeFunction + implements IncVertexCentricComputeFunction { + + private IncGraphComputeContext graphContext; + + @Override + public void init(IncGraphComputeContext graphContext) { + this.graphContext = graphContext; + TemporaryGraph temporaryGraph = + this.graphContext.getTemporaryGraph(); + TemporaryGraphCache temporaryGraphCache = + (TemporaryGraphCache) + getObjectField(temporaryGraph, IncTemporaryGraph.class, "temporaryGraphCache"); + Set allEvolveVId = temporaryGraphCache.getAllEvolveVId(); + Assert.assertEquals(allEvolveVId.size(), 0); + } - @Override - public VertexCentricCombineFunction getCombineFunction() { - return null; - } + @Override + public void evolve(Integer vertexId, TemporaryGraph temporaryGraph) { + this.graphContext.sendMessage(999, 1); + TemporaryGraphCache temporaryGraphCache = + (TemporaryGraphCache) + getObjectField(temporaryGraph, IncTemporaryGraph.class, "temporaryGraphCache"); + Set allEvolveVId = temporaryGraphCache.getAllEvolveVId(); + Assert.assertEquals(allEvolveVId.size(), 2); + + long nestedWindowId = this.graphContext.getRuntimeContext().getWindowId(); + HistoricalGraph historicalGraph = + this.graphContext.getHistoricalGraph(); + GraphState graphState = + (GraphState) getObjectField(historicalGraph, IncHistoricalGraph.class, "graphState"); + if (nestedWindowId == 2L) { + List allVersions = graphState.dynamicGraph().V().getAllVersions(1); + Assert.assertEquals(allVersions.size(), 1); + } + } + @Override + public void compute(Integer vertexId, Iterator messageIterator) { + int sum = 0; + Assert.assertEquals((int) vertexId, 999); + while (messageIterator.hasNext()) { + sum += messageIterator.next(); + } + Assert.assertEquals(sum, 2); } - public static class PRVertexCentricComputeFunction implements IncVertexCentricComputeFunction { - - private IncGraphComputeContext graphContext; - - @Override - public void init(IncGraphComputeContext graphContext) { - this.graphContext = graphContext; - TemporaryGraph temporaryGraph = - this.graphContext.getTemporaryGraph(); - TemporaryGraphCache temporaryGraphCache = (TemporaryGraphCache) getObjectField(temporaryGraph, - IncTemporaryGraph.class, "temporaryGraphCache"); - Set allEvolveVId = temporaryGraphCache.getAllEvolveVId(); - Assert.assertEquals(allEvolveVId.size(), 0); - } - - @Override - public void evolve(Integer vertexId, TemporaryGraph temporaryGraph) { - this.graphContext.sendMessage(999, 1); - TemporaryGraphCache temporaryGraphCache = (TemporaryGraphCache) getObjectField(temporaryGraph, - IncTemporaryGraph.class, "temporaryGraphCache"); - Set allEvolveVId = temporaryGraphCache.getAllEvolveVId(); - Assert.assertEquals(allEvolveVId.size(), 2); - - long nestedWindowId = this.graphContext.getRuntimeContext().getWindowId(); - HistoricalGraph historicalGraph = this.graphContext.getHistoricalGraph(); - GraphState graphState = (GraphState) getObjectField(historicalGraph, - IncHistoricalGraph.class, "graphState"); - if (nestedWindowId == 2L) { - List allVersions = graphState.dynamicGraph().V().getAllVersions(1); - Assert.assertEquals(allVersions.size(), 1); - } - } - - @Override - public void compute(Integer vertexId, Iterator messageIterator) { - int sum = 0; - Assert.assertEquals((int) vertexId, 999); - while (messageIterator.hasNext()) { - sum += messageIterator.next(); - } - Assert.assertEquals(sum, 2); - } - - @Override - public void finish(Integer vertexId, MutableGraph mutableGraph) { - IVertex vertex = graphContext.getTemporaryGraph().getVertex(); - List> edges = graphContext.getTemporaryGraph().getEdges(); - if (vertex != null) { - mutableGraph.addVertex(0, vertex); - graphContext.collect(vertex); - } else { - LOGGER.info("not found vertex {} in temporaryGraph ", vertexId); - } - if (edges != null) { - edges.stream().forEach(edge -> { - mutableGraph.addEdge(0, edge); + @Override + public void finish(Integer vertexId, MutableGraph mutableGraph) { + IVertex vertex = graphContext.getTemporaryGraph().getVertex(); + List> edges = graphContext.getTemporaryGraph().getEdges(); + if (vertex != null) { + mutableGraph.addVertex(0, vertex); + graphContext.collect(vertex); + } else { + LOGGER.info("not found vertex {} in temporaryGraph ", vertexId); + } + if (edges != null) { + edges.stream() + .forEach( + edge -> { + mutableGraph.addEdge(0, edge); }); - } - } - + } } - - private static Object getObjectField(Object instance, Class clazz, String fieldName) { - try { - Field field = clazz.getDeclaredField(fieldName); - field.setAccessible(true); - return field.get(instance); - } catch (IllegalAccessException | NoSuchFieldException e) { - throw new GeaflowRuntimeException(e); - } + } + + private static Object getObjectField(Object instance, Class clazz, String fieldName) { + try { + Field field = clazz.getDeclaredField(fieldName); + field.setAccessible(true); + return field.get(instance); + } catch (IllegalAccessException | NoSuchFieldException e) { + throw new GeaflowRuntimeException(e); } - + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/traversal/IncGraphAggTraversalAllTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/traversal/IncGraphAggTraversalAllTest.java index 58959f670..2b2404164 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/traversal/IncGraphAggTraversalAllTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/traversal/IncGraphAggTraversalAllTest.java @@ -25,6 +25,7 @@ import static org.apache.geaflow.example.config.ExampleConfigKeys.SOURCE_PARALLELISM; import java.io.File; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -40,138 +41,142 @@ public class IncGraphAggTraversalAllTest extends BaseTest { - private static final Logger LOGGER = LoggerFactory.getLogger(IncGraphAggTraversalAllTest.class); - - @BeforeMethod - public void setUp() { - config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), getClass().getSimpleName() + "-" + System.currentTimeMillis()); - config.put(FileConfigKeys.ROOT.getKey(), "/tmp/"); - String path = config.get(FileConfigKeys.ROOT.getKey()) + config.get(ExecutionConfigKeys.JOB_APP_NAME.getKey()); - FileUtils.deleteQuietly(new File(path)); - } - - @Test - public void test1ShardWithSingleConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphAggTraversalAll pipeline = new IncrGraphAggTraversalAll(); - - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithTwoSourceVCMapOneSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphAggTraversalAll pipeline = new IncrGraphAggTraversalAll(); - - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); - - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithTwoSourceVCFourMapSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphAggTraversalAll pipeline = new IncrGraphAggTraversalAll(); - - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(MAP_PARALLELISM.getKey(), String.valueOf(4)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); - - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithOneSourceVCMapFourSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphAggTraversalAll pipeline = new IncrGraphAggTraversalAll(); - - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); - - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test1ShardWithOneSourceVCMapFourSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphAggTraversalAll pipeline = new IncrGraphAggTraversalAll(); - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(1)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithTwoSourceVCOneMapTwoSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphAggTraversalAll pipeline = new IncrGraphAggTraversalAll(); - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(2)); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithOneSourceVCOneMapTwoSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphAggTraversalAll pipeline = new IncrGraphAggTraversalAll(); - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(2)); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithOneSourceVCMapSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphAggTraversalAll pipeline = new IncrGraphAggTraversalAll(); - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(1)); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } + private static final Logger LOGGER = LoggerFactory.getLogger(IncGraphAggTraversalAllTest.class); + + @BeforeMethod + public void setUp() { + config.put( + ExecutionConfigKeys.JOB_APP_NAME.getKey(), + getClass().getSimpleName() + "-" + System.currentTimeMillis()); + config.put(FileConfigKeys.ROOT.getKey(), "/tmp/"); + String path = + config.get(FileConfigKeys.ROOT.getKey()) + + config.get(ExecutionConfigKeys.JOB_APP_NAME.getKey()); + FileUtils.deleteQuietly(new File(path)); + } + + @Test + public void test1ShardWithSingleConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphAggTraversalAll pipeline = new IncrGraphAggTraversalAll(); + + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithTwoSourceVCMapOneSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphAggTraversalAll pipeline = new IncrGraphAggTraversalAll(); + + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); + + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithTwoSourceVCFourMapSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphAggTraversalAll pipeline = new IncrGraphAggTraversalAll(); + + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(MAP_PARALLELISM.getKey(), String.valueOf(4)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); + + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithOneSourceVCMapFourSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphAggTraversalAll pipeline = new IncrGraphAggTraversalAll(); + + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); + + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test1ShardWithOneSourceVCMapFourSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphAggTraversalAll pipeline = new IncrGraphAggTraversalAll(); + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(1)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithTwoSourceVCOneMapTwoSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphAggTraversalAll pipeline = new IncrGraphAggTraversalAll(); + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(2)); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithOneSourceVCOneMapTwoSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphAggTraversalAll pipeline = new IncrGraphAggTraversalAll(); + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(2)); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithOneSourceVCMapSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphAggTraversalAll pipeline = new IncrGraphAggTraversalAll(); + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(1)); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/traversal/IncGraphTraversalAllTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/traversal/IncGraphTraversalAllTest.java index 5728356e7..10abbe424 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/traversal/IncGraphTraversalAllTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/traversal/IncGraphTraversalAllTest.java @@ -27,6 +27,7 @@ import java.io.File; import java.util.HashMap; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -42,141 +43,145 @@ public class IncGraphTraversalAllTest extends BaseTest { - private static final Logger LOGGER = LoggerFactory.getLogger(IncGraphTraversalAllTest.class); - - private Map config; - - @BeforeMethod - public void setUp() { - config = new HashMap<>(); - config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), getClass().getSimpleName() + "-" + System.currentTimeMillis()); - config.put(FileConfigKeys.ROOT.getKey(), "/tmp/"); - String path = config.get(FileConfigKeys.ROOT.getKey()) + config.get(ExecutionConfigKeys.JOB_APP_NAME.getKey()); - FileUtils.deleteQuietly(new File(path)); - } - - @Test - public void test1ShardWithSingleConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalAll pipeline = new IncrGraphTraversalAll(); - - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithTwoSourceVCMapOneSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalAll pipeline = new IncrGraphTraversalAll(); - - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); - - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithTwoSourceVCFourMapSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalAll pipeline = new IncrGraphTraversalAll(); - - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(MAP_PARALLELISM.getKey(), String.valueOf(4)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); - - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithOneSourceVCMapFourSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalAll pipeline = new IncrGraphTraversalAll(); - - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); - - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test1ShardWithOneSourceVCMapFourSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalAll pipeline = new IncrGraphTraversalAll(); - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(1)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithTwoSourceVCOneMapTwoSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalAll pipeline = new IncrGraphTraversalAll(); - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(2)); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithOneSourceVCOneMapTwoSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalAll pipeline = new IncrGraphTraversalAll(); - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(2)); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithOneSourceVCMapSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalAll pipeline = new IncrGraphTraversalAll(); - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(1)); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } + private static final Logger LOGGER = LoggerFactory.getLogger(IncGraphTraversalAllTest.class); + + private Map config; + + @BeforeMethod + public void setUp() { + config = new HashMap<>(); + config.put( + ExecutionConfigKeys.JOB_APP_NAME.getKey(), + getClass().getSimpleName() + "-" + System.currentTimeMillis()); + config.put(FileConfigKeys.ROOT.getKey(), "/tmp/"); + String path = + config.get(FileConfigKeys.ROOT.getKey()) + + config.get(ExecutionConfigKeys.JOB_APP_NAME.getKey()); + FileUtils.deleteQuietly(new File(path)); + } + + @Test + public void test1ShardWithSingleConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalAll pipeline = new IncrGraphTraversalAll(); + + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithTwoSourceVCMapOneSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalAll pipeline = new IncrGraphTraversalAll(); + + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); + + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithTwoSourceVCFourMapSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalAll pipeline = new IncrGraphTraversalAll(); + + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(MAP_PARALLELISM.getKey(), String.valueOf(4)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); + + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithOneSourceVCMapFourSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalAll pipeline = new IncrGraphTraversalAll(); + + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); + + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test1ShardWithOneSourceVCMapFourSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalAll pipeline = new IncrGraphTraversalAll(); + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(1)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithTwoSourceVCOneMapTwoSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalAll pipeline = new IncrGraphTraversalAll(); + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(2)); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithOneSourceVCOneMapTwoSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalAll pipeline = new IncrGraphTraversalAll(); + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(2)); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithOneSourceVCMapSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalAll pipeline = new IncrGraphTraversalAll(); + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(1)); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/traversal/IncGraphTraversalByStartIdsTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/traversal/IncGraphTraversalByStartIdsTest.java index fb3a43b3c..94a2ddb98 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/traversal/IncGraphTraversalByStartIdsTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/traversal/IncGraphTraversalByStartIdsTest.java @@ -27,6 +27,7 @@ import java.io.File; import java.util.HashMap; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -42,141 +43,144 @@ public class IncGraphTraversalByStartIdsTest extends BaseTest { - private static final Logger LOGGER = LoggerFactory.getLogger(IncGraphTraversalByStartIdsTest.class); - - private Map config; - - @BeforeMethod - public void setUp() { - config = new HashMap<>(); - config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), getClass().getSimpleName()); - config.put(FileConfigKeys.ROOT.getKey(), "/tmp/"); - String path = config.get(FileConfigKeys.ROOT.getKey()) + config.get(ExecutionConfigKeys.JOB_APP_NAME.getKey()); - FileUtils.deleteQuietly(new File(path)); - } - - @Test - public void test1ShardWithSingleConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalByStartIds pipeline = new IncrGraphTraversalByStartIds(); - - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithTwoSourceVCMapOneSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalByStartIds pipeline = new IncrGraphTraversalByStartIds(); - - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); - - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithTwoSourceVCFourMapSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalByStartIds pipeline = new IncrGraphTraversalByStartIds(); - - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(MAP_PARALLELISM.getKey(), String.valueOf(4)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); - - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithOneSourceVCMapFourSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalByStartIds pipeline = new IncrGraphTraversalByStartIds(); - - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); - - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test1ShardWithOneSourceVCMapFourSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalByStartIds pipeline = new IncrGraphTraversalByStartIds(); - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(1)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithTwoSourceVCOneMapTwoSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalByStartIds pipeline = new IncrGraphTraversalByStartIds(); - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(2)); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithOneSourceVCOneMapTwoSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalByStartIds pipeline = new IncrGraphTraversalByStartIds(); - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(2)); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithOneSourceVCMapSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalByStartIds pipeline = new IncrGraphTraversalByStartIds(); - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(1)); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } + private static final Logger LOGGER = + LoggerFactory.getLogger(IncGraphTraversalByStartIdsTest.class); + + private Map config; + + @BeforeMethod + public void setUp() { + config = new HashMap<>(); + config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), getClass().getSimpleName()); + config.put(FileConfigKeys.ROOT.getKey(), "/tmp/"); + String path = + config.get(FileConfigKeys.ROOT.getKey()) + + config.get(ExecutionConfigKeys.JOB_APP_NAME.getKey()); + FileUtils.deleteQuietly(new File(path)); + } + + @Test + public void test1ShardWithSingleConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalByStartIds pipeline = new IncrGraphTraversalByStartIds(); + + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithTwoSourceVCMapOneSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalByStartIds pipeline = new IncrGraphTraversalByStartIds(); + + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); + + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithTwoSourceVCFourMapSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalByStartIds pipeline = new IncrGraphTraversalByStartIds(); + + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(MAP_PARALLELISM.getKey(), String.valueOf(4)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); + + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithOneSourceVCMapFourSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalByStartIds pipeline = new IncrGraphTraversalByStartIds(); + + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); + + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test1ShardWithOneSourceVCMapFourSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalByStartIds pipeline = new IncrGraphTraversalByStartIds(); + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(1)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithTwoSourceVCOneMapTwoSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalByStartIds pipeline = new IncrGraphTraversalByStartIds(); + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(2)); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithOneSourceVCOneMapTwoSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalByStartIds pipeline = new IncrGraphTraversalByStartIds(); + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(2)); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithOneSourceVCMapSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalByStartIds pipeline = new IncrGraphTraversalByStartIds(); + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(1)); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/traversal/IncGraphTraversalByStreamTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/traversal/IncGraphTraversalByStreamTest.java index 7706a0ca8..c44abcfce 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/traversal/IncGraphTraversalByStreamTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/inc/traversal/IncGraphTraversalByStreamTest.java @@ -27,6 +27,7 @@ import java.io.File; import java.util.HashMap; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -42,142 +43,143 @@ public class IncGraphTraversalByStreamTest extends BaseTest { - private static final Logger LOGGER = LoggerFactory.getLogger(IncGraphTraversalByStreamTest.class); - - private Map config; - - @BeforeMethod - public void setUp() { - config = new HashMap<>(); - config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), getClass().getSimpleName()); - config.put(FileConfigKeys.ROOT.getKey(), "/tmp/"); - String path = config.get(FileConfigKeys.ROOT.getKey()) + config.get(ExecutionConfigKeys.JOB_APP_NAME.getKey()); - FileUtils.deleteQuietly(new File(path)); - } - - @Test - public void test1ShardWithSingleConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalByStream pipeline = new IncrGraphTraversalByStream(); - - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithTwoSourceVCMapOneSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalByStream pipeline = new IncrGraphTraversalByStream(); - - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); - - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithTwoSourceVCFourMapSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalByStream pipeline = new IncrGraphTraversalByStream(); - - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(MAP_PARALLELISM.getKey(), String.valueOf(4)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); - - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - - @Test - public void test2ShardWithOneSourceVCMapFourSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalByStream pipeline = new IncrGraphTraversalByStream(); - - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); - - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test1ShardWithOneSourceVCMapFourSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalByStream pipeline = new IncrGraphTraversalByStream(); - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(1)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithTwoSourceVCOneMapTwoSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalByStream pipeline = new IncrGraphTraversalByStream(); - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(2)); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithOneSourceVCOneMapTwoSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalByStream pipeline = new IncrGraphTraversalByStream(); - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(2)); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } - - @Test - public void test2ShardWithOneSourceVCMapSinkConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - IncrGraphTraversalByStream pipeline = new IncrGraphTraversalByStream(); - config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); - config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); - config.put(SINK_PARALLELISM.getKey(), String.valueOf(1)); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - result.get(); - pipeline.validateResult(); - } + private static final Logger LOGGER = LoggerFactory.getLogger(IncGraphTraversalByStreamTest.class); + + private Map config; + + @BeforeMethod + public void setUp() { + config = new HashMap<>(); + config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), getClass().getSimpleName()); + config.put(FileConfigKeys.ROOT.getKey(), "/tmp/"); + String path = + config.get(FileConfigKeys.ROOT.getKey()) + + config.get(ExecutionConfigKeys.JOB_APP_NAME.getKey()); + FileUtils.deleteQuietly(new File(path)); + } + + @Test + public void test1ShardWithSingleConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalByStream pipeline = new IncrGraphTraversalByStream(); + + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithTwoSourceVCMapOneSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalByStream pipeline = new IncrGraphTraversalByStream(); + + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); + + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithTwoSourceVCFourMapSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalByStream pipeline = new IncrGraphTraversalByStream(); + + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(MAP_PARALLELISM.getKey(), String.valueOf(4)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); + + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithOneSourceVCMapFourSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalByStream pipeline = new IncrGraphTraversalByStream(); + + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); + + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test1ShardWithOneSourceVCMapFourSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalByStream pipeline = new IncrGraphTraversalByStream(); + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(1)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(4)); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithTwoSourceVCOneMapTwoSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalByStream pipeline = new IncrGraphTraversalByStream(); + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(2)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(2)); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithOneSourceVCOneMapTwoSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalByStream pipeline = new IncrGraphTraversalByStream(); + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(2)); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } + + @Test + public void test2ShardWithOneSourceVCMapSinkConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + IncrGraphTraversalByStream pipeline = new IncrGraphTraversalByStream(); + config.put(SOURCE_PARALLELISM.getKey(), String.valueOf(1)); + config.put(ITERATOR_PARALLELISM.getKey(), String.valueOf(2)); + config.put(SINK_PARALLELISM.getKey(), String.valueOf(1)); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + result.get(); + pipeline.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/AllShortestPathTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/AllShortestPathTest.java index d6c5453f4..5b2d35968 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/AllShortestPathTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/AllShortestPathTest.java @@ -28,16 +28,16 @@ public class AllShortestPathTest extends BaseTest { - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - IPipelineResult result = AllShortestPath.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - AllShortestPath.validateResult(); + IPipelineResult result = AllShortestPath.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } + AllShortestPath.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/AverageDegreeTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/AverageDegreeTest.java index ac76a37dd..9750294de 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/AverageDegreeTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/AverageDegreeTest.java @@ -28,16 +28,16 @@ public class AverageDegreeTest extends BaseTest { - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - IPipelineResult result = AverageDegree.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - AverageDegree.validateResult(); + IPipelineResult result = AverageDegree.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } + AverageDegree.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/ClosenessCentralityTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/ClosenessCentralityTest.java index 595279569..b3ac3d2ee 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/ClosenessCentralityTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/ClosenessCentralityTest.java @@ -28,16 +28,16 @@ public class ClosenessCentralityTest extends BaseTest { - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - IPipelineResult result = ClosenessCentrality.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - ClosenessCentrality.validateResult(); + IPipelineResult result = ClosenessCentrality.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } + ClosenessCentrality.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/ClusterCoefficientTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/ClusterCoefficientTest.java index c2fb9bf43..f6ea2bf4a 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/ClusterCoefficientTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/ClusterCoefficientTest.java @@ -28,16 +28,16 @@ public class ClusterCoefficientTest extends BaseTest { - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - IPipelineResult result = ClusterCoefficient.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - ClusterCoefficient.validateResult(); + IPipelineResult result = ClusterCoefficient.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } + ClusterCoefficient.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/CommonNeighborsTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/CommonNeighborsTest.java index 99c15f5dc..0acbd9b1a 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/CommonNeighborsTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/CommonNeighborsTest.java @@ -28,16 +28,16 @@ public class CommonNeighborsTest extends BaseTest { - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - IPipelineResult result = CommonNeighbors.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - CommonNeighbors.validateResult(); + IPipelineResult result = CommonNeighbors.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } + CommonNeighbors.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/KCoreTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/KCoreTest.java index da8dfece6..2275bea1d 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/KCoreTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/KCoreTest.java @@ -28,16 +28,16 @@ public class KCoreTest extends BaseTest { - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - IPipelineResult result = KCore.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - KCore.validateResult(); + IPipelineResult result = KCore.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } + KCore.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/LabelPropagationTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/LabelPropagationTest.java index 6014046ff..91104fa32 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/LabelPropagationTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/LabelPropagationTest.java @@ -28,16 +28,16 @@ public class LabelPropagationTest extends BaseTest { - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - IPipelineResult result = LabelPropagation.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - LabelPropagation.validateResult(); + IPipelineResult result = LabelPropagation.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } + LabelPropagation.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/LinkPredictionTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/LinkPredictionTest.java index 08930004c..9b3473121 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/LinkPredictionTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/LinkPredictionTest.java @@ -28,16 +28,16 @@ public class LinkPredictionTest extends BaseTest { - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - IPipelineResult result = LinkPrediction.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - LinkPrediction.validateResult(); + IPipelineResult result = LinkPrediction.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } + LinkPrediction.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/NPathsTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/NPathsTest.java index d0b4bfc67..0760f77e7 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/NPathsTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/NPathsTest.java @@ -28,16 +28,16 @@ public class NPathsTest extends BaseTest { - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - IPipelineResult result = NPaths.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - NPaths.validateResult(); + IPipelineResult result = NPaths.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } + NPaths.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/PersonalRankRankTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/PersonalRankRankTest.java index f5d8fa29f..e1c67b2c3 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/PersonalRankRankTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/PersonalRankRankTest.java @@ -29,16 +29,16 @@ public class PersonalRankRankTest extends BaseTest { - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - IPipelineResult result = PersonalRank.submit(environment); - if (!result.isSuccess()) { - throw new GeaflowRuntimeException("execute failed"); - } - PersonalRank.validateResult(); + IPipelineResult result = PersonalRank.submit(environment); + if (!result.isSuccess()) { + throw new GeaflowRuntimeException("execute failed"); } + PersonalRank.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/ShortestPathOfVertexSetTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/ShortestPathOfVertexSetTest.java index 8e302d26d..546b851ab 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/ShortestPathOfVertexSetTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/ShortestPathOfVertexSetTest.java @@ -28,16 +28,16 @@ public class ShortestPathOfVertexSetTest extends BaseTest { - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - IPipelineResult result = ShortestPathOfVertexSet.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - ShortestPathOfVertexSet.validateResult(); + IPipelineResult result = ShortestPathOfVertexSet.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } + ShortestPathOfVertexSet.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/ShortestPathTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/ShortestPathTest.java index bb1f9a6b1..ba4e4c816 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/ShortestPathTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/ShortestPathTest.java @@ -28,16 +28,16 @@ public class ShortestPathTest extends BaseTest { - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - IPipelineResult result = ShortestPath.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - ShortestPath.validateResult(); + IPipelineResult result = ShortestPath.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } + ShortestPath.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/WeakConnectedComponentsTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/WeakConnectedComponentsTest.java index 87a8545e3..c24cc604f 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/WeakConnectedComponentsTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/WeakConnectedComponentsTest.java @@ -28,17 +28,16 @@ public class WeakConnectedComponentsTest extends BaseTest { - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - IPipelineResult result = WeakConnectedComponents.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - WeakConnectedComponents.validateResult(); + IPipelineResult result = WeakConnectedComponents.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - + WeakConnectedComponents.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphAggTraversalAllTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphAggTraversalAllTest.java index d27cfa646..6dc11fb49 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphAggTraversalAllTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphAggTraversalAllTest.java @@ -20,6 +20,7 @@ package org.apache.geaflow.example.graph.statical.traversal; import java.io.IOException; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.env.EnvironmentFactory; @@ -29,16 +30,16 @@ public class StaticGraphAggTraversalAllTest extends BaseTest { - @Test - public void test() throws IOException { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void test() throws IOException { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - IPipelineResult result = StaticGraphAggTraversalAllExample.submit(environment); - if (!result.isSuccess()) { - throw new GeaflowRuntimeException("execute failed"); - } - StaticGraphAggTraversalAllExample.validateResult(); + IPipelineResult result = StaticGraphAggTraversalAllExample.submit(environment); + if (!result.isSuccess()) { + throw new GeaflowRuntimeException("execute failed"); } + StaticGraphAggTraversalAllExample.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphAggTraversalByIdTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphAggTraversalByIdTest.java index 57f75af1d..712340a7f 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphAggTraversalByIdTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphAggTraversalByIdTest.java @@ -20,6 +20,7 @@ package org.apache.geaflow.example.graph.statical.traversal; import java.io.IOException; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.env.EnvironmentFactory; @@ -29,17 +30,16 @@ public class StaticGraphAggTraversalByIdTest extends BaseTest { - @Test - public void test() throws IOException { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void test() throws IOException { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - IPipelineResult result = StaticGraphAggTraversalByIdExample.submit(environment); - if (!result.isSuccess()) { - throw new GeaflowRuntimeException("execute failed"); - } - StaticGraphAggTraversalByIdExample.validateResult(); + IPipelineResult result = StaticGraphAggTraversalByIdExample.submit(environment); + if (!result.isSuccess()) { + throw new GeaflowRuntimeException("execute failed"); } - + StaticGraphAggTraversalByIdExample.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphTraversalAllTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphTraversalAllTest.java index 993a21514..9cfcf0ac2 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphTraversalAllTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphTraversalAllTest.java @@ -20,6 +20,7 @@ package org.apache.geaflow.example.graph.statical.traversal; import java.io.IOException; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.env.EnvironmentFactory; @@ -29,17 +30,16 @@ public class StaticGraphTraversalAllTest extends BaseTest { - @Test - public void test() throws IOException { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void test() throws IOException { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - IPipelineResult result = StaticGraphTraversalAllExample.submit(environment); - if (!result.isSuccess()) { - throw new GeaflowRuntimeException("execute failed"); - } - StaticGraphTraversalAllExample.validateResult(); + IPipelineResult result = StaticGraphTraversalAllExample.submit(environment); + if (!result.isSuccess()) { + throw new GeaflowRuntimeException("execute failed"); } - + StaticGraphTraversalAllExample.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphTraversalTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphTraversalTest.java index 99f497995..a461263be 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphTraversalTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/graph/statical/traversal/StaticGraphTraversalTest.java @@ -20,6 +20,7 @@ package org.apache.geaflow.example.graph.statical.traversal; import java.io.IOException; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.env.EnvironmentFactory; @@ -29,16 +30,16 @@ public class StaticGraphTraversalTest extends BaseTest { - @Test - public void test() throws IOException { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void test() throws IOException { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - IPipelineResult result = StaticGraphTraversalExample.submit(environment); - if (!result.isSuccess()) { - throw new GeaflowRuntimeException("execute failed"); - } - StaticGraphTraversalExample.validateResult(); + IPipelineResult result = StaticGraphTraversalExample.submit(environment); + if (!result.isSuccess()) { + throw new GeaflowRuntimeException("execute failed"); } + StaticGraphTraversalExample.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/k8s/KubernetesTestBase.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/k8s/KubernetesTestBase.java index 1d4375c27..c3a1802ed 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/k8s/KubernetesTestBase.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/k8s/KubernetesTestBase.java @@ -34,6 +34,7 @@ import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.cluster.k8s.config.KubernetesConfig; import org.apache.geaflow.cluster.k8s.config.KubernetesConfigKeys; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -41,47 +42,46 @@ public class KubernetesTestBase { - protected static final String APP_ID = "k8s-cluster-1234"; - - protected static final String CONTAINER_IMAGE = "geaflow-arm:0.1"; + protected static final String APP_ID = "k8s-cluster-1234"; - protected static final String MASTER_URL = "http://127.0.0.1:54448/"; + protected static final String CONTAINER_IMAGE = "geaflow-arm:0.1"; - protected static final String REDIS_HOST = "host.minikube.internal"; + protected static final String MASTER_URL = "http://127.0.0.1:54448/"; - protected Map config; + protected static final String REDIS_HOST = "host.minikube.internal"; - protected String localConfDir; + protected Map config; - public void setup() { - config = new HashMap<>(); - config.put(KubernetesConfig.CLIENT_MASTER_URL, MASTER_URL); - config.put(ExecutionConfigKeys.CLUSTER_ID.getKey(), APP_ID); - config.put(ExecutionConfigKeys.JOB_UNIQUE_ID.getKey(), APP_ID); - config.put(KubernetesConfigKeys.CONTAINER_IMAGE.getKey(), CONTAINER_IMAGE); - config.put(MASTER_VCORES.getKey(), "1"); - config.put(CLIENT_VCORES.getKey(), "1"); - config.put(MASTER_MEMORY_MB.getKey(), "512"); - config.put(MASTER_JVM_OPTIONS.getKey(), - "-Xmx256m,-Xms256m,-Xmn64m,-XX:MaxDirectMemorySize=158m"); - config.put(CONTAINER_MEMORY_MB.getKey(), "512"); - config.put(CONTAINER_JVM_OPTION.getKey(), - "-Xmx128m,-Xms128m,-Xmn32m,-XX:MaxDirectMemorySize=128m"); - config.put(DRIVER_MEMORY_MB.getKey(), "512"); - config.put(DRIVER_JVM_OPTION.getKey(), - "-Xmx128m,-Xms128m,-Xmn32m,-XX:MaxDirectMemorySize=128m"); - config.put(CLIENT_MEMORY_MB.getKey(), "512"); - config.put(CLIENT_JVM_OPTIONS.getKey(), - "-Xmx256m,-Xms256m,-Xmn64m,-XX:MaxDirectMemorySize=158m"); - config.put(RedisConfigKeys.REDIS_HOST.getKey(), REDIS_HOST); - config.put(DEFAULT_RESOURCE_EPHEMERAL_STORAGE_SIZE.getKey(), "1Gi"); - config.put(ROOT.getKey(), "/tmp/geaflow/chk"); + protected String localConfDir; - localConfDir = this.getClass().getResource("/").getPath(); - } + public void setup() { + config = new HashMap<>(); + config.put(KubernetesConfig.CLIENT_MASTER_URL, MASTER_URL); + config.put(ExecutionConfigKeys.CLUSTER_ID.getKey(), APP_ID); + config.put(ExecutionConfigKeys.JOB_UNIQUE_ID.getKey(), APP_ID); + config.put(KubernetesConfigKeys.CONTAINER_IMAGE.getKey(), CONTAINER_IMAGE); + config.put(MASTER_VCORES.getKey(), "1"); + config.put(CLIENT_VCORES.getKey(), "1"); + config.put(MASTER_MEMORY_MB.getKey(), "512"); + config.put( + MASTER_JVM_OPTIONS.getKey(), "-Xmx256m,-Xms256m,-Xmn64m,-XX:MaxDirectMemorySize=158m"); + config.put(CONTAINER_MEMORY_MB.getKey(), "512"); + config.put( + CONTAINER_JVM_OPTION.getKey(), "-Xmx128m,-Xms128m,-Xmn32m,-XX:MaxDirectMemorySize=128m"); + config.put(DRIVER_MEMORY_MB.getKey(), "512"); + config.put( + DRIVER_JVM_OPTION.getKey(), "-Xmx128m,-Xms128m,-Xmn32m,-XX:MaxDirectMemorySize=128m"); + config.put(CLIENT_MEMORY_MB.getKey(), "512"); + config.put( + CLIENT_JVM_OPTIONS.getKey(), "-Xmx256m,-Xms256m,-Xmn64m,-XX:MaxDirectMemorySize=158m"); + config.put(RedisConfigKeys.REDIS_HOST.getKey(), REDIS_HOST); + config.put(DEFAULT_RESOURCE_EPHEMERAL_STORAGE_SIZE.getKey(), "1Gi"); + config.put(ROOT.getKey(), "/tmp/geaflow/chk"); - public Map getConfig() { - return config; - } + localConfDir = this.getClass().getResource("/").getPath(); + } + public Map getConfig() { + return config; + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/k8s/StreamWordCountTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/k8s/StreamWordCountTest.java index a40c52d7f..1c44c7e70 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/k8s/StreamWordCountTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/k8s/StreamWordCountTest.java @@ -24,9 +24,9 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.CONTAINER_WORKER_NUM; import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.REPORTER_LIST; -import com.alibaba.fastjson.JSON; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.cluster.k8s.client.KubernetesJobClient; import org.apache.geaflow.cluster.k8s.config.KubernetesConfig; import org.apache.geaflow.cluster.k8s.config.KubernetesConfigKeys; @@ -34,31 +34,34 @@ import org.apache.geaflow.example.util.ExampleSinkFunctionFactory.SinkType; import org.apache.geaflow.metrics.common.reporter.ReporterRegistry; +import com.alibaba.fastjson.JSON; + public class StreamWordCountTest extends KubernetesTestBase { - public StreamWordCountTest() { - super.setup(); - config.put(CONTAINER_NUM.getKey(), "1"); - config.put(CONTAINER_WORKER_NUM.getKey(), "2"); - config.put(REPORTER_LIST.getKey(), ReporterRegistry.SLF4J_REPORTER); - config.put(ExampleConfigKeys.GEAFLOW_SINK_TYPE.getKey(), SinkType.FILE_SINK.name()); - config.put(CLIENT_EXIT_WAIT_SECONDS.getKey(), "120"); - } + public StreamWordCountTest() { + super.setup(); + config.put(CONTAINER_NUM.getKey(), "1"); + config.put(CONTAINER_WORKER_NUM.getKey(), "2"); + config.put(REPORTER_LIST.getKey(), ReporterRegistry.SLF4J_REPORTER); + config.put(ExampleConfigKeys.GEAFLOW_SINK_TYPE.getKey(), SinkType.FILE_SINK.name()); + config.put(CLIENT_EXIT_WAIT_SECONDS.getKey(), "120"); + } - public void submit() { - config.remove(KubernetesConfig.CLIENT_MASTER_URL); - Map clientConfig = new HashMap<>(); - clientConfig.put("job", config); - String clientArgs = JSON.toJSONString(clientConfig); - config.put(KubernetesConfigKeys.USER_MAIN_CLASS.getKey(), - "org.apache.geaflow.example.k8s.UnBoundedStreamWordCount"); - config.put(KubernetesConfigKeys.USER_CLASS_ARGS.getKey(), clientArgs); - KubernetesJobClient jobClient = new KubernetesJobClient(config, MASTER_URL); - jobClient.submitJob(); - } + public void submit() { + config.remove(KubernetesConfig.CLIENT_MASTER_URL); + Map clientConfig = new HashMap<>(); + clientConfig.put("job", config); + String clientArgs = JSON.toJSONString(clientConfig); + config.put( + KubernetesConfigKeys.USER_MAIN_CLASS.getKey(), + "org.apache.geaflow.example.k8s.UnBoundedStreamWordCount"); + config.put(KubernetesConfigKeys.USER_CLASS_ARGS.getKey(), clientArgs); + KubernetesJobClient jobClient = new KubernetesJobClient(config, MASTER_URL); + jobClient.submitJob(); + } - public static void main(String[] args) { - StreamWordCountTest test = new StreamWordCountTest(); - test.submit(); - } + public static void main(String[] args) { + StreamWordCountTest test = new StreamWordCountTest(); + test.submit(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/AnalyticsClientTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/AnalyticsClientTest.java index 1ed329723..b3fc036ce 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/AnalyticsClientTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/AnalyticsClientTest.java @@ -18,14 +18,13 @@ */ package org.apache.geaflow.example.service; - import static org.apache.geaflow.analytics.service.config.AnalyticsServiceConfigKeys.ANALYTICS_COMPILE_SCHEMA_ENABLE; import static org.apache.geaflow.analytics.service.config.AnalyticsServiceConfigKeys.ANALYTICS_QUERY; import static org.apache.geaflow.analytics.service.config.AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_REGISTER_ENABLE; -import com.github.fppt.jedismock.RedisServer; import java.io.IOException; import java.util.List; + import org.apache.geaflow.analytics.service.client.AnalyticsClient; import org.apache.geaflow.analytics.service.config.AnalyticsServiceConfigKeys; import org.apache.geaflow.analytics.service.query.QueryResults; @@ -41,112 +40,115 @@ import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; -public class AnalyticsClientTest extends BaseTest { - - private static final String QUERY = "hello world"; - private static final String HOST_NAME = "localhost"; - private static final int DEFAULT_WAITING_TIME = 5; +import com.github.fppt.jedismock.RedisServer; - @BeforeMethod - public void setUp() { - config.put(ANALYTICS_COMPILE_SCHEMA_ENABLE.getKey(), String.valueOf(false)); - config.put(ANALYTICS_SERVICE_REGISTER_ENABLE.getKey(), Boolean.FALSE.toString()); - config.put(ANALYTICS_QUERY.getKey(), QUERY); - } +public class AnalyticsClientTest extends BaseTest { - @Test - public void testWordLengthWithAnalyticsClientByRpc() { - int testPort = 8091; - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); - configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_PORT, String.valueOf(testPort)); - configuration.put(FrameworkConfigKeys.SERVICE_SERVER_TYPE, "analytics_rpc"); - WordLengthService wordLengthService = new WordLengthService(); - wordLengthService.submit(environment); - SleepUtils.sleepSecond(DEFAULT_WAITING_TIME); - - AnalyticsClient analyticsClient = AnalyticsClient.builder() + private static final String QUERY = "hello world"; + private static final String HOST_NAME = "localhost"; + private static final int DEFAULT_WAITING_TIME = 5; + + @BeforeMethod + public void setUp() { + config.put(ANALYTICS_COMPILE_SCHEMA_ENABLE.getKey(), String.valueOf(false)); + config.put(ANALYTICS_SERVICE_REGISTER_ENABLE.getKey(), Boolean.FALSE.toString()); + config.put(ANALYTICS_QUERY.getKey(), QUERY); + } + + @Test + public void testWordLengthWithAnalyticsClientByRpc() { + int testPort = 8091; + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); + configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_PORT, String.valueOf(testPort)); + configuration.put(FrameworkConfigKeys.SERVICE_SERVER_TYPE, "analytics_rpc"); + WordLengthService wordLengthService = new WordLengthService(); + wordLengthService.submit(environment); + SleepUtils.sleepSecond(DEFAULT_WAITING_TIME); + + AnalyticsClient analyticsClient = + AnalyticsClient.builder() .withHost(HOST_NAME) .withPort(testPort) .withConfiguration(configuration) .withRetryNum(3) .build(); - QueryResults queryResults = analyticsClient.executeQuery(QUERY); - Assert.assertNotNull(queryResults); - List> rawData = queryResults.getRawData(); - Assert.assertEquals(rawData.size(), 1); - Assert.assertEquals((int) rawData.get(0).get(0), 11); - analyticsClient.shutdown(); - } - - @Test - public void testWordLengthWithAnalyticsClientByHttp() { - int testPort = 8092; - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); - configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_PORT, String.valueOf(testPort)); - configuration.put(FrameworkConfigKeys.SERVICE_SERVER_TYPE, "analytics_http"); - environment.getEnvironmentContext().withConfig(configuration.getConfigMap()); - WordLengthService wordLengthService = new WordLengthService(); - wordLengthService.submit(environment); - SleepUtils.sleepSecond(DEFAULT_WAITING_TIME); - - AnalyticsClient analyticsClient = AnalyticsClient.builder() + QueryResults queryResults = analyticsClient.executeQuery(QUERY); + Assert.assertNotNull(queryResults); + List> rawData = queryResults.getRawData(); + Assert.assertEquals(rawData.size(), 1); + Assert.assertEquals((int) rawData.get(0).get(0), 11); + analyticsClient.shutdown(); + } + + @Test + public void testWordLengthWithAnalyticsClientByHttp() { + int testPort = 8092; + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); + configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_PORT, String.valueOf(testPort)); + configuration.put(FrameworkConfigKeys.SERVICE_SERVER_TYPE, "analytics_http"); + environment.getEnvironmentContext().withConfig(configuration.getConfigMap()); + WordLengthService wordLengthService = new WordLengthService(); + wordLengthService.submit(environment); + SleepUtils.sleepSecond(DEFAULT_WAITING_TIME); + + AnalyticsClient analyticsClient = + AnalyticsClient.builder() .withHost(HOST_NAME) .withPort(testPort) .withConfiguration(configuration) .withRetryNum(3) .build(); - QueryResults queryResults = analyticsClient.executeQuery(QUERY); - Assert.assertNotNull(queryResults); - List> rawData = queryResults.getRawData(); - Assert.assertEquals(rawData.size(), 1); - Assert.assertEquals((int) rawData.get(0).get(0), 11); + QueryResults queryResults = analyticsClient.executeQuery(QUERY); + Assert.assertNotNull(queryResults); + List> rawData = queryResults.getRawData(); + Assert.assertEquals(rawData.size(), 1); + Assert.assertEquals((int) rawData.get(0).get(0), 11); + analyticsClient.shutdown(); + } + + @Test + public void testWordLengthWithRedisMetaService() throws IOException { + RedisServer redisServer = null; + AnalyticsClient analyticsClient = null; + try { + redisServer = RedisServer.newRedisServer().start(); + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); + configuration.put( + AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_REGISTER_ENABLE, Boolean.TRUE.toString()); + configuration.put(ANALYTICS_COMPILE_SCHEMA_ENABLE, Boolean.FALSE.toString()); + configuration.put(ExecutionConfigKeys.JOB_MODE.getKey(), JobMode.OLAP_SERVICE.toString()); + configuration.put(ExecutionConfigKeys.SERVICE_DISCOVERY_TYPE, "redis"); + configuration.put(FrameworkConfigKeys.SERVICE_SERVER_TYPE, "analytics_rpc"); + configuration.put(RedisConfigKeys.REDIS_HOST, redisServer.getHost()); + configuration.put(RedisConfigKeys.REDIS_PORT, String.valueOf(redisServer.getBindPort())); + + WordLengthService wordLengthService = new WordLengthService(); + wordLengthService.submit(environment); + SleepUtils.sleepSecond(DEFAULT_WAITING_TIME); + + analyticsClient = + AnalyticsClient.builder().withRetryNum(3).withConfiguration(configuration).build(); + + QueryResults queryResults = analyticsClient.executeQuery(QUERY); + Assert.assertNotNull(queryResults); + List> list = queryResults.getRawData(); + Assert.assertEquals(list.size(), 1); + Assert.assertEquals((int) list.get(0).get(0), 11); + } finally { + if (redisServer != null) { + redisServer.stop(); + } + if (analyticsClient != null) { analyticsClient.shutdown(); + } } - - @Test - public void testWordLengthWithRedisMetaService() throws IOException { - RedisServer redisServer = null; - AnalyticsClient analyticsClient = null; - try { - redisServer = RedisServer.newRedisServer().start(); - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); - configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_REGISTER_ENABLE, Boolean.TRUE.toString()); - configuration.put(ANALYTICS_COMPILE_SCHEMA_ENABLE, Boolean.FALSE.toString()); - configuration.put(ExecutionConfigKeys.JOB_MODE.getKey(), JobMode.OLAP_SERVICE.toString()); - configuration.put(ExecutionConfigKeys.SERVICE_DISCOVERY_TYPE, "redis"); - configuration.put(FrameworkConfigKeys.SERVICE_SERVER_TYPE, "analytics_rpc"); - configuration.put(RedisConfigKeys.REDIS_HOST, redisServer.getHost()); - configuration.put(RedisConfigKeys.REDIS_PORT, String.valueOf(redisServer.getBindPort())); - - WordLengthService wordLengthService = new WordLengthService(); - wordLengthService.submit(environment); - SleepUtils.sleepSecond(DEFAULT_WAITING_TIME); - - analyticsClient = AnalyticsClient.builder() - .withRetryNum(3) - .withConfiguration(configuration) - .build(); - - QueryResults queryResults = analyticsClient.executeQuery(QUERY); - Assert.assertNotNull(queryResults); - List> list = queryResults.getRawData(); - Assert.assertEquals(list.size(), 1); - Assert.assertEquals((int) list.get(0).get(0), 11); - } finally { - if (redisServer != null) { - redisServer.stop(); - } - if (analyticsClient != null) { - analyticsClient.shutdown(); - } - } - } + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/BaseServiceTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/BaseServiceTest.java index c67c1eee4..37fa6c060 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/BaseServiceTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/BaseServiceTest.java @@ -21,12 +21,9 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.CONTAINER_WORKER_NUM; -import com.github.fppt.jedismock.RedisServer; -import io.grpc.ManagedChannel; -import io.grpc.netty.NettyChannelBuilder; -import io.netty.channel.ChannelOption; import java.io.File; import java.io.IOException; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.analytics.service.config.AnalyticsServiceConfigKeys; import org.apache.geaflow.cluster.system.ClusterMetaStore; @@ -44,251 +41,258 @@ import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeClass; -public class BaseServiceTest { +import com.github.fppt.jedismock.RedisServer; - protected static final String TEST_GRAPH_PATH = "/tmp/geaflow/analytics/test/graph"; - protected static final String HOST_NAME = "localhost"; - protected static final int DEFAULT_WAITING_TIME = 5; - protected RedisServer server; - protected MetaServer metaServer; - protected Configuration defaultConfig; - protected Environment environment; +import io.grpc.ManagedChannel; +import io.grpc.netty.NettyChannelBuilder; +import io.netty.channel.ChannelOption; - public final String graphView = - "CREATE GRAPH bi (\n" - + " --static\n" - + " --Place\n" - + " Vertex Country (\n" - + " id bigint ID,\n" - + " name varchar,\n" - + " url varchar\n" - + " ),\n" - + " Vertex City (\n" - + " id bigint ID,\n" - + " name varchar,\n" - + " url varchar\n" - + " ),\n" - + " Vertex Continent (\n" - + " id bigint ID,\n" - + " name varchar,\n" - + " url varchar\n" - + " ),\n" - + " --Organisation\n" - + " Vertex Company (\n" - + " id bigint ID,\n" - + " name varchar,\n" - + " url varchar\n" - + " ),\n" - + " Vertex University (\n" - + " id bigint ID,\n" - + " name varchar,\n" - + " url varchar\n" - + " ),\n" - + " --Tag\n" - + "\tVertex TagClass (\n" - + "\t id bigint ID,\n" - + "\t name varchar,\n" - + "\t url varchar\n" - + "\t),\n" - + "\tVertex Tag (\n" - + "\t id bigint ID,\n" - + "\t name varchar,\n" - + "\t url varchar\n" - + "\t),\n" - + "\n" - + " --dynamic\n" - + " Vertex Person (\n" - + " id bigint ID,\n" - + " creationDate bigint,\n" - + " firstName varchar,\n" - + " lastName varchar,\n" - + " gender varchar,\n" - + " --birthday Date,\n" - + " --email {varchar},\n" - + " --speaks {varchar},\n" - + " browserUsed varchar,\n" - + " locationIP varchar\n" - + " ),\n" - + " Vertex Forum (\n" - + " id bigint ID,\n" - + " creationDate bigint,\n" - + " title varchar\n" - + " ),\n" - + " --Message\n" - + " Vertex Post (\n" - + " id bigint ID,\n" - + " creationDate bigint,\n" - + " browserUsed varchar,\n" - + " locationIP varchar,\n" - + " content varchar,\n" - + " length bigint,\n" - + " lang varchar,\n" - + " imageFile varchar\n" - + " ),\n" - + " Vertex Comment (\n" - + " id bigint ID,\n" - + " creationDate bigint,\n" - + " browserUsed varchar,\n" - + " locationIP varchar,\n" - + " content varchar,\n" - + " length bigint\n" - + " ),\n" - + "\n" - + " --relations\n" - + " --static\n" - + "\tEdge isLocatedIn (\n" - + "\t srcId bigint SOURCE ID,\n" - + "\t targetId bigint DESTINATION ID\n" - + "\t),\n" - + "\tEdge isPartOf (\n" - + "\t srcId bigint SOURCE ID,\n" - + "\t targetId bigint DESTINATION ID\n" - + "\t),\n" - + " Edge isSubclassOf (\n" - + " srcId bigint SOURCE ID,\n" - + " targetId bigint DESTINATION ID\n" - + " ),\n" - + " Edge hasType (\n" - + " srcId bigint SOURCE ID,\n" - + " targetId bigint DESTINATION ID\n" - + " ),\n" - + "\n" - + " --dynamic\n" - + "\tEdge hasModerator (\n" - + "\t srcId bigint SOURCE ID,\n" - + "\t targetId bigint DESTINATION ID\n" - + "\t),\n" - + "\tEdge containerOf (\n" - + "\t srcId bigint SOURCE ID,\n" - + "\t targetId bigint DESTINATION ID\n" - + "\t),\n" - + "\tEdge replyOf (\n" - + "\t srcId bigint SOURCE ID,\n" - + "\t targetId bigint DESTINATION ID\n" - + "\t),\n" - + "\tEdge hasTag (\n" - + "\t srcId bigint SOURCE ID,\n" - + "\t targetId bigint DESTINATION ID\n" - + "\t),\n" - + " Edge hasInterest (\n" - + " srcId bigint SOURCE ID,\n" - + " targetId bigint DESTINATION ID\n" - + " ),\n" - + " Edge hasCreator (\n" - + " srcId bigint SOURCE ID,\n" - + " targetId bigint DESTINATION ID\n" - + " ),\n" - + " Edge workAt (\n" - + " srcId bigint SOURCE ID,\n" - + " targetId bigint DESTINATION ID,\n" - + " workForm bigint\n" - + " ),\n" - + " Edge studyAt (\n" - + " srcId bigint SOURCE ID,\n" - + " targetId bigint DESTINATION ID,\n" - + " classYear bigint\n" - + " ),\n" - + "\n" - + " --temporary\n" - + " Edge hasMember (\n" - + " srcId bigint SOURCE ID,\n" - + " targetId bigint DESTINATION ID,\n" - + " creationDate bigint\n" - + " ),\n" - + " Edge likes (\n" - + " srcId bigint SOURCE ID,\n" - + " targetId bigint DESTINATION ID,\n" - + " creationDate bigint\n" - + " ),\n" - + " Edge knows (\n" - + " srcId bigint SOURCE ID,\n" - + " targetId bigint DESTINATION ID,\n" - + " creationDate bigint\n" - + " )\n" - + ") WITH (\n" - + " \t\tstoreType='rocksdb',\n" - + " \tshardCount = 4\n" - + " );\n" - + "\n" - + "USE GRAPH bi;"; +public class BaseServiceTest { - public final String analyticsQuery = graphView + "MATCH (a) RETURN a limit 0"; - public final String executeQuery = - graphView + "MATCH (person:Person where id = 1100001)-[:isLocatedIn]->(city:City)\n" - + "RETURN person.id, person.firstName, person.lastName"; + protected static final String TEST_GRAPH_PATH = "/tmp/geaflow/analytics/test/graph"; + protected static final String HOST_NAME = "localhost"; + protected static final int DEFAULT_WAITING_TIME = 5; + protected RedisServer server; + protected MetaServer metaServer; + protected Configuration defaultConfig; + protected Environment environment; - @BeforeClass - public void beforeClass() throws Exception { - File file = new File(TEST_GRAPH_PATH); - if (file.exists()) { - FileUtils.deleteDirectory(file); - } + public final String graphView = + "CREATE GRAPH bi (\n" + + " --static\n" + + " --Place\n" + + " Vertex Country (\n" + + " id bigint ID,\n" + + " name varchar,\n" + + " url varchar\n" + + " ),\n" + + " Vertex City (\n" + + " id bigint ID,\n" + + " name varchar,\n" + + " url varchar\n" + + " ),\n" + + " Vertex Continent (\n" + + " id bigint ID,\n" + + " name varchar,\n" + + " url varchar\n" + + " ),\n" + + " --Organisation\n" + + " Vertex Company (\n" + + " id bigint ID,\n" + + " name varchar,\n" + + " url varchar\n" + + " ),\n" + + " Vertex University (\n" + + " id bigint ID,\n" + + " name varchar,\n" + + " url varchar\n" + + " ),\n" + + " --Tag\n" + + "\tVertex TagClass (\n" + + "\t id bigint ID,\n" + + "\t name varchar,\n" + + "\t url varchar\n" + + "\t),\n" + + "\tVertex Tag (\n" + + "\t id bigint ID,\n" + + "\t name varchar,\n" + + "\t url varchar\n" + + "\t),\n" + + "\n" + + " --dynamic\n" + + " Vertex Person (\n" + + " id bigint ID,\n" + + " creationDate bigint,\n" + + " firstName varchar,\n" + + " lastName varchar,\n" + + " gender varchar,\n" + + " --birthday Date,\n" + + " --email {varchar},\n" + + " --speaks {varchar},\n" + + " browserUsed varchar,\n" + + " locationIP varchar\n" + + " ),\n" + + " Vertex Forum (\n" + + " id bigint ID,\n" + + " creationDate bigint,\n" + + " title varchar\n" + + " ),\n" + + " --Message\n" + + " Vertex Post (\n" + + " id bigint ID,\n" + + " creationDate bigint,\n" + + " browserUsed varchar,\n" + + " locationIP varchar,\n" + + " content varchar,\n" + + " length bigint,\n" + + " lang varchar,\n" + + " imageFile varchar\n" + + " ),\n" + + " Vertex Comment (\n" + + " id bigint ID,\n" + + " creationDate bigint,\n" + + " browserUsed varchar,\n" + + " locationIP varchar,\n" + + " content varchar,\n" + + " length bigint\n" + + " ),\n" + + "\n" + + " --relations\n" + + " --static\n" + + "\tEdge isLocatedIn (\n" + + "\t srcId bigint SOURCE ID,\n" + + "\t targetId bigint DESTINATION ID\n" + + "\t),\n" + + "\tEdge isPartOf (\n" + + "\t srcId bigint SOURCE ID,\n" + + "\t targetId bigint DESTINATION ID\n" + + "\t),\n" + + " Edge isSubclassOf (\n" + + " srcId bigint SOURCE ID,\n" + + " targetId bigint DESTINATION ID\n" + + " ),\n" + + " Edge hasType (\n" + + " srcId bigint SOURCE ID,\n" + + " targetId bigint DESTINATION ID\n" + + " ),\n" + + "\n" + + " --dynamic\n" + + "\tEdge hasModerator (\n" + + "\t srcId bigint SOURCE ID,\n" + + "\t targetId bigint DESTINATION ID\n" + + "\t),\n" + + "\tEdge containerOf (\n" + + "\t srcId bigint SOURCE ID,\n" + + "\t targetId bigint DESTINATION ID\n" + + "\t),\n" + + "\tEdge replyOf (\n" + + "\t srcId bigint SOURCE ID,\n" + + "\t targetId bigint DESTINATION ID\n" + + "\t),\n" + + "\tEdge hasTag (\n" + + "\t srcId bigint SOURCE ID,\n" + + "\t targetId bigint DESTINATION ID\n" + + "\t),\n" + + " Edge hasInterest (\n" + + " srcId bigint SOURCE ID,\n" + + " targetId bigint DESTINATION ID\n" + + " ),\n" + + " Edge hasCreator (\n" + + " srcId bigint SOURCE ID,\n" + + " targetId bigint DESTINATION ID\n" + + " ),\n" + + " Edge workAt (\n" + + " srcId bigint SOURCE ID,\n" + + " targetId bigint DESTINATION ID,\n" + + " workForm bigint\n" + + " ),\n" + + " Edge studyAt (\n" + + " srcId bigint SOURCE ID,\n" + + " targetId bigint DESTINATION ID,\n" + + " classYear bigint\n" + + " ),\n" + + "\n" + + " --temporary\n" + + " Edge hasMember (\n" + + " srcId bigint SOURCE ID,\n" + + " targetId bigint DESTINATION ID,\n" + + " creationDate bigint\n" + + " ),\n" + + " Edge likes (\n" + + " srcId bigint SOURCE ID,\n" + + " targetId bigint DESTINATION ID,\n" + + " creationDate bigint\n" + + " ),\n" + + " Edge knows (\n" + + " srcId bigint SOURCE ID,\n" + + " targetId bigint DESTINATION ID,\n" + + " creationDate bigint\n" + + " )\n" + + ") WITH (\n" + + " \t\tstoreType='rocksdb',\n" + + " \tshardCount = 4\n" + + " );\n" + + "\n" + + "USE GRAPH bi;"; - BaseQueryTest - .build() - .withConfig(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), "1") - .withConfig(FileConfigKeys.PERSISTENT_TYPE.getKey(), "DFS") - .withConfig(CONTAINER_WORKER_NUM.getKey(), String.valueOf(20)) - .withConfig(FileConfigKeys.ROOT.getKey(), TEST_GRAPH_PATH) - .withConfig(AnalyticsServiceConfigKeys.ANALYTICS_COMPILE_SCHEMA_ENABLE.getKey(), - String.valueOf(false)) - .withConfig(FileConfigKeys.JSON_CONFIG.getKey(), "{\"fs.defaultFS\":\"local\"}") - .withQueryPath("/ldbc/bi_insert_01.sql") - .execute() - .withQueryPath("/ldbc/bi_insert_02.sql") - .execute() - .withQueryPath("/ldbc/bi_insert_03.sql") - .execute() - .withQueryPath("/ldbc/bi_insert_04.sql") - .execute() - .withQueryPath("/ldbc/bi_insert_05.sql") - .execute() - .withQueryPath("/ldbc/bi_insert_06.sql") - .execute(); - } + public final String analyticsQuery = graphView + "MATCH (a) RETURN a limit 0"; + public final String executeQuery = + graphView + + "MATCH (person:Person where id = 1100001)-[:isLocatedIn]->(city:City)\n" + + "RETURN person.id, person.firstName, person.lastName"; - public void before() throws Exception { - String jobName = "test_analytics_" + System.currentTimeMillis(); - server = RedisServer.newRedisServer().start(); - defaultConfig = new Configuration(); - defaultConfig.put(RedisConfigKeys.REDIS_HOST, server.getHost()); - defaultConfig.put(RedisConfigKeys.REDIS_PORT, String.valueOf(server.getBindPort())); - defaultConfig.put(ExecutionConfigKeys.JOB_APP_NAME, jobName); - metaServer = new MetaServer(); - metaServer.init(new MetaServerContext(defaultConfig)); + @BeforeClass + public void beforeClass() throws Exception { + File file = new File(TEST_GRAPH_PATH); + if (file.exists()) { + FileUtils.deleteDirectory(file); } - @AfterClass - public void cleanGraphStore() throws IOException { - File file = new File(TEST_GRAPH_PATH); - if (file.exists()) { - FileUtils.deleteDirectory(file); - } - } + BaseQueryTest.build() + .withConfig(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE.getKey(), "1") + .withConfig(FileConfigKeys.PERSISTENT_TYPE.getKey(), "DFS") + .withConfig(CONTAINER_WORKER_NUM.getKey(), String.valueOf(20)) + .withConfig(FileConfigKeys.ROOT.getKey(), TEST_GRAPH_PATH) + .withConfig( + AnalyticsServiceConfigKeys.ANALYTICS_COMPILE_SCHEMA_ENABLE.getKey(), + String.valueOf(false)) + .withConfig(FileConfigKeys.JSON_CONFIG.getKey(), "{\"fs.defaultFS\":\"local\"}") + .withQueryPath("/ldbc/bi_insert_01.sql") + .execute() + .withQueryPath("/ldbc/bi_insert_02.sql") + .execute() + .withQueryPath("/ldbc/bi_insert_03.sql") + .execute() + .withQueryPath("/ldbc/bi_insert_04.sql") + .execute() + .withQueryPath("/ldbc/bi_insert_05.sql") + .execute() + .withQueryPath("/ldbc/bi_insert_06.sql") + .execute(); + } - @AfterMethod - public void clean() throws IOException { - if (metaServer != null) { - metaServer.close(); - } - if (server != null) { - server.stop(); - } + public void before() throws Exception { + String jobName = "test_analytics_" + System.currentTimeMillis(); + server = RedisServer.newRedisServer().start(); + defaultConfig = new Configuration(); + defaultConfig.put(RedisConfigKeys.REDIS_HOST, server.getHost()); + defaultConfig.put(RedisConfigKeys.REDIS_PORT, String.valueOf(server.getBindPort())); + defaultConfig.put(ExecutionConfigKeys.JOB_APP_NAME, jobName); + metaServer = new MetaServer(); + metaServer.init(new MetaServerContext(defaultConfig)); + } - if (environment != null) { - environment.shutdown(); - environment = null; - } - ClusterMetaStore.close(); - ScheduledWorkerManagerFactory.clear(); + @AfterClass + public void cleanGraphStore() throws IOException { + File file = new File(TEST_GRAPH_PATH); + if (file.exists()) { + FileUtils.deleteDirectory(file); } + } - protected static ManagedChannel buildChannel(String host, int port, int timeoutMs) { - return NettyChannelBuilder.forAddress(host, port) - .withOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, timeoutMs) - // Channels are secure by default (via SSL/TLS). For the example we disable TLS to avoid - // needing certificates. - .usePlaintext() - .build(); + @AfterMethod + public void clean() throws IOException { + if (metaServer != null) { + metaServer.close(); } + if (server != null) { + server.stop(); + } + + if (environment != null) { + environment.shutdown(); + environment = null; + } + ClusterMetaStore.close(); + ScheduledWorkerManagerFactory.clear(); + } + + protected static ManagedChannel buildChannel(String host, int port, int timeoutMs) { + return NettyChannelBuilder.forAddress(host, port) + .withOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, timeoutMs) + // Channels are secure by default (via SSL/TLS). For the example we disable TLS to avoid + // needing certificates. + .usePlaintext() + .build(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/ConcurrencyQueryServiceTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/ConcurrencyQueryServiceTest.java index cfff86a6e..1d87a6f7b 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/ConcurrencyQueryServiceTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/ConcurrencyQueryServiceTest.java @@ -21,12 +21,11 @@ import static org.apache.geaflow.file.FileConfigKeys.ROOT; -import com.google.protobuf.ByteString; -import io.grpc.ManagedChannel; import java.util.LinkedList; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; + import org.apache.geaflow.analytics.service.config.AnalyticsServiceConfigKeys; import org.apache.geaflow.analytics.service.query.QueryResults; import org.apache.geaflow.common.config.Configuration; @@ -45,69 +44,78 @@ import org.slf4j.LoggerFactory; import org.testng.annotations.Test; +import com.google.protobuf.ByteString; + +import io.grpc.ManagedChannel; + public class ConcurrencyQueryServiceTest extends BaseServiceTest { - private static final Logger LOGGER = LoggerFactory.getLogger(ConcurrencyQueryServiceTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ConcurrencyQueryServiceTest.class); - private static final Integer THREAD_NUM = 1; + private static final Integer THREAD_NUM = 1; - @Test - public void testQueryService() throws Exception { - int port = 8093; - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_PORT, String.valueOf(port)); - configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_REGISTER_ENABLE, Boolean.FALSE.toString()); - configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_QUERY, analyticsQuery); - configuration.put(ROOT, TEST_GRAPH_PATH); - configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_QUERY_PARALLELISM, String.valueOf(4)); - // Collection source must be set all window size in order to only execute one batch. - configuration.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE, "-1"); + @Test + public void testQueryService() throws Exception { + int port = 8093; + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_PORT, String.valueOf(port)); + configuration.put( + AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_REGISTER_ENABLE, Boolean.FALSE.toString()); + configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_QUERY, analyticsQuery); + configuration.put(ROOT, TEST_GRAPH_PATH); + configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_QUERY_PARALLELISM, String.valueOf(4)); + // Collection source must be set all window size in order to only execute one batch. + configuration.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE, "-1"); - configuration.put(AnalyticsServiceConfigKeys.MAX_REQUEST_PER_SERVER, String.valueOf(THREAD_NUM)); - configuration.put(ExecutionConfigKeys.RPC_IO_THREAD_NUM, "32"); - configuration.put(ExecutionConfigKeys.RPC_WORKER_THREAD_NUM, "32"); + configuration.put( + AnalyticsServiceConfigKeys.MAX_REQUEST_PER_SERVER, String.valueOf(THREAD_NUM)); + configuration.put(ExecutionConfigKeys.RPC_IO_THREAD_NUM, "32"); + configuration.put(ExecutionConfigKeys.RPC_WORKER_THREAD_NUM, "32"); - QueryService queryService = new QueryService(); - queryService.submit(environment); + QueryService queryService = new QueryService(); + queryService.submit(environment); - testQuery(port); - } + testQuery(port); + } - private void testQuery(int port) throws InterruptedException, ExecutionException { - SleepUtils.sleepSecond(DEFAULT_WAITING_TIME); + private void testQuery(int port) throws InterruptedException, ExecutionException { + SleepUtils.sleepSecond(DEFAULT_WAITING_TIME); - ManagedChannel channel = buildChannel("localhost", port, 3000); - ByteString.Output output = ByteString.newOutput(); - SerializerFactory.getKryoSerializer().serialize("request", output); - QueryRequest request = QueryRequest.newBuilder() + ManagedChannel channel = buildChannel("localhost", port, 3000); + ByteString.Output output = ByteString.newOutput(); + SerializerFactory.getKryoSerializer().serialize("request", output); + QueryRequest request = + QueryRequest.newBuilder() .setQuery(executeQuery) .setQueryConfig(output.toByteString()) .build(); - List> resultFutureList = new LinkedList<>(); - for (int i = 0; i < THREAD_NUM; i++) { - AnalyticsServiceBlockingStub stub = AnalyticsServiceGrpc.newBlockingStub(channel); - LOGGER.info("stub {}: {}", i, stub); - resultFutureList.add(CompletableFuture.supplyAsync(() -> stub.executeQuery(request))); - } - - LOGGER.info("resultFutureList: {}", resultFutureList); - - for (int i = 0; i < THREAD_NUM; i++) { - CompletableFuture future = resultFutureList.get(i); - QueryResult queryResult = future.get(); - Assert.assertNotNull(queryResult); - QueryResults result = RpcMessageEncoder.decode(queryResult.getQueryResult()); - Assert.assertNotNull(result); - Object defaultFormattedResult = result.getFormattedData(); - List> rawData = result.getRawData(); - Assert.assertEquals(1, rawData.size()); - Assert.assertEquals(3, rawData.get(0).size()); - Assert.assertEquals(rawData.get(0).get(0), 1100001L); - Assert.assertEquals(rawData.get(0).get(1).toString(), "一"); - Assert.assertEquals(rawData.get(0).get(2).toString(), "王"); - Assert.assertEquals(defaultFormattedResult.toString(), "{\"viewResult\":{\"nodes\":[],\"edges\":[]},\"jsonResult\":[{\"firstName\":\"一\",\"lastName\":\"王\",\"id\":\"1100001\"}]}"); - } + List> resultFutureList = new LinkedList<>(); + for (int i = 0; i < THREAD_NUM; i++) { + AnalyticsServiceBlockingStub stub = AnalyticsServiceGrpc.newBlockingStub(channel); + LOGGER.info("stub {}: {}", i, stub); + resultFutureList.add(CompletableFuture.supplyAsync(() -> stub.executeQuery(request))); + } + + LOGGER.info("resultFutureList: {}", resultFutureList); + + for (int i = 0; i < THREAD_NUM; i++) { + CompletableFuture future = resultFutureList.get(i); + QueryResult queryResult = future.get(); + Assert.assertNotNull(queryResult); + QueryResults result = RpcMessageEncoder.decode(queryResult.getQueryResult()); + Assert.assertNotNull(result); + Object defaultFormattedResult = result.getFormattedData(); + List> rawData = result.getRawData(); + Assert.assertEquals(1, rawData.size()); + Assert.assertEquals(3, rawData.get(0).size()); + Assert.assertEquals(rawData.get(0).get(0), 1100001L); + Assert.assertEquals(rawData.get(0).get(1).toString(), "一"); + Assert.assertEquals(rawData.get(0).get(2).toString(), "王"); + Assert.assertEquals( + defaultFormattedResult.toString(), + "{\"viewResult\":{\"nodes\":[],\"edges\":[]},\"jsonResult\":[{\"firstName\":\"一\",\"lastName\":\"王\",\"id\":\"1100001\"}]}"); } + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/ConcurrencyWordLengthServiceTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/ConcurrencyWordLengthServiceTest.java index 4ad75c83a..9135c8026 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/ConcurrencyWordLengthServiceTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/ConcurrencyWordLengthServiceTest.java @@ -26,6 +26,7 @@ import java.util.LinkedList; import java.util.List; import java.util.concurrent.CompletableFuture; + import org.apache.geaflow.analytics.service.client.AnalyticsClient; import org.apache.geaflow.analytics.service.config.AnalyticsServiceConfigKeys; import org.apache.geaflow.analytics.service.query.QueryResults; @@ -43,62 +44,60 @@ public class ConcurrencyWordLengthServiceTest extends BaseTest { - private static final Logger LOGGER = LoggerFactory.getLogger(ConcurrencyQueryServiceTest.class); - private static final String QUERY = "hello world"; - private static final Integer THREAD_NUM = 2; - private static final String TEST_HOST = "localhost"; + private static final Logger LOGGER = LoggerFactory.getLogger(ConcurrencyQueryServiceTest.class); + private static final String QUERY = "hello world"; + private static final Integer THREAD_NUM = 2; + private static final String TEST_HOST = "localhost"; - @BeforeMethod - public void setUp() { - config.put(ANALYTICS_COMPILE_SCHEMA_ENABLE.getKey(), String.valueOf(false)); - config.put(ANALYTICS_SERVICE_REGISTER_ENABLE.getKey(), Boolean.FALSE.toString()); - config.put(ANALYTICS_QUERY.getKey(), QUERY); - ClusterMetaStore.close(); - } + @BeforeMethod + public void setUp() { + config.put(ANALYTICS_COMPILE_SCHEMA_ENABLE.getKey(), String.valueOf(false)); + config.put(ANALYTICS_SERVICE_REGISTER_ENABLE.getKey(), Boolean.FALSE.toString()); + config.put(ANALYTICS_QUERY.getKey(), QUERY); + ClusterMetaStore.close(); + } - @Test - public void testWordLengthWithHttpServer() throws Exception { - int port = 8091; - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_PORT, String.valueOf(port)); - configuration.put(AnalyticsServiceConfigKeys.MAX_REQUEST_PER_SERVER, String.valueOf(THREAD_NUM)); - configuration.put(ExecutionConfigKeys.RPC_IO_THREAD_NUM, "32"); - configuration.put(ExecutionConfigKeys.RPC_WORKER_THREAD_NUM, "32"); - configuration.putAll(config); - WordLengthService wordLengthService = new WordLengthService(); - wordLengthService.submit(environment); + @Test + public void testWordLengthWithHttpServer() throws Exception { + int port = 8091; + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_PORT, String.valueOf(port)); + configuration.put( + AnalyticsServiceConfigKeys.MAX_REQUEST_PER_SERVER, String.valueOf(THREAD_NUM)); + configuration.put(ExecutionConfigKeys.RPC_IO_THREAD_NUM, "32"); + configuration.put(ExecutionConfigKeys.RPC_WORKER_THREAD_NUM, "32"); + configuration.putAll(config); + WordLengthService wordLengthService = new WordLengthService(); + wordLengthService.submit(environment); - testHttpServiceServer(port); - } + testHttpServiceServer(port); + } - private static void testHttpServiceServer(int port) throws Exception { - SleepUtils.sleepSecond(3); + private static void testHttpServiceServer(int port) throws Exception { + SleepUtils.sleepSecond(3); - List> resultFutureList = new LinkedList<>(); - List clientLinkedList = new LinkedList<>(); - for (int i = 0; i < THREAD_NUM; i++) { - AnalyticsClient client = AnalyticsClient.builder() - .withHost(TEST_HOST) - .withPort(port) - .withRetryNum(3) - .build(); + List> resultFutureList = new LinkedList<>(); + List clientLinkedList = new LinkedList<>(); + for (int i = 0; i < THREAD_NUM; i++) { + AnalyticsClient client = + AnalyticsClient.builder().withHost(TEST_HOST).withPort(port).withRetryNum(3).build(); - LOGGER.info("client {}: {}", i, client); - clientLinkedList.add(client); - resultFutureList.add(CompletableFuture.supplyAsync(() -> client.executeQuery(QUERY))); - } + LOGGER.info("client {}: {}", i, client); + clientLinkedList.add(client); + resultFutureList.add(CompletableFuture.supplyAsync(() -> client.executeQuery(QUERY))); + } - LOGGER.info("resultFutureList: {}", resultFutureList); + LOGGER.info("resultFutureList: {}", resultFutureList); - for (int i = 0; i < THREAD_NUM; i++) { - CompletableFuture future = resultFutureList.get(i); - QueryResults queryResults = future.get(); - Assert.assertNotNull(queryResults); - List> rawData = queryResults.getRawData(); - Assert.assertEquals(rawData.size(), 1); - Assert.assertEquals((int) rawData.get(0).get(0), 11); - } - clientLinkedList.forEach(AnalyticsClient::shutdown); + for (int i = 0; i < THREAD_NUM; i++) { + CompletableFuture future = resultFutureList.get(i); + QueryResults queryResults = future.get(); + Assert.assertNotNull(queryResults); + List> rawData = queryResults.getRawData(); + Assert.assertEquals(rawData.size(), 1); + Assert.assertEquals((int) rawData.get(0).get(0), 11); } + clientLinkedList.forEach(AnalyticsClient::shutdown); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/MultiQueryServiceTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/MultiQueryServiceTest.java index e234a99e3..a31d6a0e4 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/MultiQueryServiceTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/MultiQueryServiceTest.java @@ -21,9 +21,8 @@ import static org.apache.geaflow.file.FileConfigKeys.ROOT; -import com.google.protobuf.ByteString; -import io.grpc.ManagedChannel; import java.util.List; + import org.apache.geaflow.analytics.service.config.AnalyticsServiceConfigKeys; import org.apache.geaflow.analytics.service.query.QueryResults; import org.apache.geaflow.common.config.Configuration; @@ -45,58 +44,65 @@ import org.testng.Assert; import org.testng.annotations.Test; +import com.google.protobuf.ByteString; + +import io.grpc.ManagedChannel; + public class MultiQueryServiceTest extends BaseServiceTest { - private static final Logger LOGGER = LoggerFactory.getLogger(MultiQueryServiceTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(MultiQueryServiceTest.class); - @Test - public void testQueryService() throws Exception { - before(); - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.put(ExecutionConfigKeys.DRIVER_NUM, "2"); - configuration.put(FrameworkConfigKeys.SERVICE_SHARE_ENABLE, Boolean.TRUE.toString()); - configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_QUERY, analyticsQuery); - configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_QUERY_PARALLELISM, "2"); - configuration.put(ROOT, TEST_GRAPH_PATH); - // Collection source must be set all window size in order to only execute one batch. - configuration.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE, "-1"); - configuration.putAll(defaultConfig.getConfigMap()); + @Test + public void testQueryService() throws Exception { + before(); + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.put(ExecutionConfigKeys.DRIVER_NUM, "2"); + configuration.put(FrameworkConfigKeys.SERVICE_SHARE_ENABLE, Boolean.TRUE.toString()); + configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_QUERY, analyticsQuery); + configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_QUERY_PARALLELISM, "2"); + configuration.put(ROOT, TEST_GRAPH_PATH); + // Collection source must be set all window size in order to only execute one batch. + configuration.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE, "-1"); + configuration.putAll(defaultConfig.getConfigMap()); - QueryService queryService = new QueryService(); - queryService.submit(environment); - testQuery(); - } + QueryService queryService = new QueryService(); + queryService.submit(environment); + testQuery(); + } - private void testQuery() { - SleepUtils.sleepSecond(DEFAULT_WAITING_TIME); + private void testQuery() { + SleepUtils.sleepSecond(DEFAULT_WAITING_TIME); - MetaServerQueryClient queryClient = new MetaServerQueryClient(defaultConfig); - List serviceInfos = queryClient.queryAllServices(NamespaceType.DEFAULT); - for (HostAndPort hostAndPort : serviceInfos) { - LOGGER.info("host and port {}", hostAndPort); - ManagedChannel channel = buildChannel(hostAndPort.getHost(), hostAndPort.getPort(), 3000); - AnalyticsServiceGrpc.AnalyticsServiceBlockingStub stub = AnalyticsServiceGrpc.newBlockingStub(channel); + MetaServerQueryClient queryClient = new MetaServerQueryClient(defaultConfig); + List serviceInfos = queryClient.queryAllServices(NamespaceType.DEFAULT); + for (HostAndPort hostAndPort : serviceInfos) { + LOGGER.info("host and port {}", hostAndPort); + ManagedChannel channel = buildChannel(hostAndPort.getHost(), hostAndPort.getPort(), 3000); + AnalyticsServiceGrpc.AnalyticsServiceBlockingStub stub = + AnalyticsServiceGrpc.newBlockingStub(channel); - ByteString.Output output = ByteString.newOutput(); - SerializerFactory.getKryoSerializer().serialize("request", output); - QueryRequest request = QueryRequest.newBuilder() - .setQuery(executeQuery) - .setQueryConfig(output.toByteString()) - .build(); - QueryResult queryResult = stub.executeQuery(request); - Assert.assertNotNull(queryResult); - QueryResults result = RpcMessageEncoder.decode(queryResult.getQueryResult()); - Object defaultFormattedResult = result.getFormattedData(); - List> rawData = result.getRawData(); - Assert.assertEquals(1, rawData.size()); - Assert.assertEquals(3, rawData.get(0).size()); - Assert.assertEquals(rawData.get(0).get(0), 1100001L); - Assert.assertEquals(rawData.get(0).get(1).toString(), "一"); - Assert.assertEquals(rawData.get(0).get(2).toString(), "王"); - Assert.assertEquals(defaultFormattedResult.toString(), "{\"viewResult\":{\"nodes\":[],\"edges\":[]},\"jsonResult\":[{\"firstName\":\"一\",\"lastName\":\"王\",\"id\":\"1100001\"}]}"); - } - queryClient.close(); + ByteString.Output output = ByteString.newOutput(); + SerializerFactory.getKryoSerializer().serialize("request", output); + QueryRequest request = + QueryRequest.newBuilder() + .setQuery(executeQuery) + .setQueryConfig(output.toByteString()) + .build(); + QueryResult queryResult = stub.executeQuery(request); + Assert.assertNotNull(queryResult); + QueryResults result = RpcMessageEncoder.decode(queryResult.getQueryResult()); + Object defaultFormattedResult = result.getFormattedData(); + List> rawData = result.getRawData(); + Assert.assertEquals(1, rawData.size()); + Assert.assertEquals(3, rawData.get(0).size()); + Assert.assertEquals(rawData.get(0).get(0), 1100001L); + Assert.assertEquals(rawData.get(0).get(1).toString(), "一"); + Assert.assertEquals(rawData.get(0).get(2).toString(), "王"); + Assert.assertEquals( + defaultFormattedResult.toString(), + "{\"viewResult\":{\"nodes\":[],\"edges\":[]},\"jsonResult\":[{\"firstName\":\"一\",\"lastName\":\"王\",\"id\":\"1100001\"}]}"); } - + queryClient.close(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/QueryServiceTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/QueryServiceTest.java index 175a140a7..cfd016511 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/QueryServiceTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/service/QueryServiceTest.java @@ -29,6 +29,7 @@ import java.util.List; import java.util.Map; import java.util.Properties; + import org.apache.geaflow.analytics.service.client.AnalyticsClient; import org.apache.geaflow.analytics.service.client.jdbc.AnalyticsDriver; import org.apache.geaflow.analytics.service.client.jdbc.AnalyticsResultSet; @@ -45,88 +46,89 @@ public class QueryServiceTest extends BaseServiceTest { - @Test - public void testQueryService() { - int port = 8093; - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_PORT, String.valueOf(port)); - configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_REGISTER_ENABLE, Boolean.FALSE.toString()); - configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_QUERY, analyticsQuery); - configuration.put(ROOT, TEST_GRAPH_PATH); - configuration.put(CONTAINER_WORKER_NUM, String.valueOf(4)); - configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_QUERY_PARALLELISM, - String.valueOf(4)); - // Collection source must be set all window size in order to only execute one batch. - configuration.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE, "-1"); + @Test + public void testQueryService() { + int port = 8093; + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_PORT, String.valueOf(port)); + configuration.put( + AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_REGISTER_ENABLE, Boolean.FALSE.toString()); + configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_QUERY, analyticsQuery); + configuration.put(ROOT, TEST_GRAPH_PATH); + configuration.put(CONTAINER_WORKER_NUM, String.valueOf(4)); + configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_QUERY_PARALLELISM, String.valueOf(4)); + // Collection source must be set all window size in order to only execute one batch. + configuration.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE, "-1"); - QueryService queryService = new QueryService(); - queryService.submit(environment); - SleepUtils.sleepSecond(DEFAULT_WAITING_TIME); + QueryService queryService = new QueryService(); + queryService.submit(environment); + SleepUtils.sleepSecond(DEFAULT_WAITING_TIME); - AnalyticsClient analyticsClient = AnalyticsClient.builder().withHost(HOST_NAME) - .withPort(port).withRetryNum(3).build(); + AnalyticsClient analyticsClient = + AnalyticsClient.builder().withHost(HOST_NAME).withPort(port).withRetryNum(3).build(); - QueryResults queryResults = analyticsClient.executeQuery(executeQuery); - Assert.assertNotNull(queryResults); - Object defaultFormattedResult = queryResults.getFormattedData(); - List> rawData = queryResults.getRawData(); - Assert.assertEquals(1, rawData.size()); - Assert.assertEquals(3, rawData.get(0).size()); - Assert.assertEquals(rawData.get(0).get(0), 1100001L); - Assert.assertEquals(rawData.get(0).get(1).toString(), "一"); - Assert.assertEquals(rawData.get(0).get(2).toString(), "王"); - Assert.assertEquals(defaultFormattedResult.toString(), - "{\"viewResult\":{\"nodes\":[],\"edges\":[]},\"jsonResult\":[{\"firstName\":\"一\"," - + "\"lastName\":\"王\",\"id\":\"1100001\"}]}"); - analyticsClient.shutdown(); - } + QueryResults queryResults = analyticsClient.executeQuery(executeQuery); + Assert.assertNotNull(queryResults); + Object defaultFormattedResult = queryResults.getFormattedData(); + List> rawData = queryResults.getRawData(); + Assert.assertEquals(1, rawData.size()); + Assert.assertEquals(3, rawData.get(0).size()); + Assert.assertEquals(rawData.get(0).get(0), 1100001L); + Assert.assertEquals(rawData.get(0).get(1).toString(), "一"); + Assert.assertEquals(rawData.get(0).get(2).toString(), "王"); + Assert.assertEquals( + defaultFormattedResult.toString(), + "{\"viewResult\":{\"nodes\":[],\"edges\":[]},\"jsonResult\":[{\"firstName\":\"一\"," + + "\"lastName\":\"王\",\"id\":\"1100001\"}]}"); + analyticsClient.shutdown(); + } - @Test - public void testJDBCResultSet() { - int port = 8094; - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_PORT, String.valueOf(port)); - configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_REGISTER_ENABLE, - Boolean.FALSE.toString()); - configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_QUERY, analyticsQuery); - configuration.put(ROOT, TEST_GRAPH_PATH); - configuration.put(CONTAINER_WORKER_NUM, String.valueOf(4)); - configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_QUERY_PARALLELISM, - String.valueOf(4)); - // Collection source must be set all window size in order to only execute one batch. - configuration.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE, "-1"); + @Test + public void testJDBCResultSet() { + int port = 8094; + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_PORT, String.valueOf(port)); + configuration.put( + AnalyticsServiceConfigKeys.ANALYTICS_SERVICE_REGISTER_ENABLE, Boolean.FALSE.toString()); + configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_QUERY, analyticsQuery); + configuration.put(ROOT, TEST_GRAPH_PATH); + configuration.put(CONTAINER_WORKER_NUM, String.valueOf(4)); + configuration.put(AnalyticsServiceConfigKeys.ANALYTICS_QUERY_PARALLELISM, String.valueOf(4)); + // Collection source must be set all window size in order to only execute one batch. + configuration.put(DSLConfigKeys.GEAFLOW_DSL_WINDOW_SIZE, "-1"); - QueryService queryService = new QueryService(); - queryService.submit(environment); - SleepUtils.sleepSecond(DEFAULT_WAITING_TIME); - String url = DRIVER_URL_START + HOST_NAME + ":" + port; - Properties properties = new Properties(); - properties.put("user", "analytics_test"); - String testQuery = graphView + "MATCH (person:Person where id = 1100001)-[:isLocatedIn]->(city:City)\n" + QueryService queryService = new QueryService(); + queryService.submit(environment); + SleepUtils.sleepSecond(DEFAULT_WAITING_TIME); + String url = DRIVER_URL_START + HOST_NAME + ":" + port; + Properties properties = new Properties(); + properties.put("user", "analytics_test"); + String testQuery = + graphView + + "MATCH (person:Person where id = 1100001)-[:isLocatedIn]->(city:City)\n" + "RETURN person, person.id, person.firstName, person.lastName"; - try { - Class.forName(AnalyticsDriver.class.getCanonicalName()); - Connection connection = DriverManager.getConnection(url, properties); - Statement statement = connection.createStatement(); - AnalyticsResultSet resultSet = (AnalyticsResultSet) statement.executeQuery(testQuery); - Assert.assertNotNull(resultSet); - Assert.assertTrue(resultSet.next()); - long personId = resultSet.getLong(2); - Assert.assertEquals(1100001L, personId); - Assert.assertEquals(resultSet.getLong("id"), personId); - String personFirstName = resultSet.getString(3); - Assert.assertEquals("一", personFirstName); - Assert.assertEquals(resultSet.getString("firstName"), personFirstName); - IVertex> personVertexByLabel = resultSet.getVertex("person"); - IVertex> personVertexByIndex = resultSet.getVertex(1); - Assert.assertNotNull(personVertexByLabel); - Assert.assertEquals(personVertexByLabel, personVertexByIndex); - Assert.assertFalse(resultSet.next()); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + try { + Class.forName(AnalyticsDriver.class.getCanonicalName()); + Connection connection = DriverManager.getConnection(url, properties); + Statement statement = connection.createStatement(); + AnalyticsResultSet resultSet = (AnalyticsResultSet) statement.executeQuery(testQuery); + Assert.assertNotNull(resultSet); + Assert.assertTrue(resultSet.next()); + long personId = resultSet.getLong(2); + Assert.assertEquals(1100001L, personId); + Assert.assertEquals(resultSet.getLong("id"), personId); + String personFirstName = resultSet.getString(3); + Assert.assertEquals("一", personFirstName); + Assert.assertEquals(resultSet.getString("firstName"), personFirstName); + IVertex> personVertexByLabel = resultSet.getVertex("person"); + IVertex> personVertexByIndex = resultSet.getVertex(1); + Assert.assertNotNull(personVertexByLabel); + Assert.assertEquals(personVertexByLabel, personVertexByIndex); + Assert.assertFalse(resultSet.next()); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } - + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/util/EnvironmentUtilTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/util/EnvironmentUtilTest.java index 7a0e1ce9f..336f45e13 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/util/EnvironmentUtilTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/util/EnvironmentUtilTest.java @@ -29,14 +29,13 @@ public class EnvironmentUtilTest { - @Test - public void test() { - Environment env = EnvironmentUtil.loadEnvironment(null); - Assert.assertTrue(env instanceof LocalEnvironment); - - System.setProperty(CLUSTER_TYPE, LOCAL_CLUSTER); - env = EnvironmentUtil.loadEnvironment(null); - Assert.assertTrue(env instanceof LocalEnvironment); - } + @Test + public void test() { + Environment env = EnvironmentUtil.loadEnvironment(null); + Assert.assertTrue(env instanceof LocalEnvironment); + System.setProperty(CLUSTER_TYPE, LOCAL_CLUSTER); + env = EnvironmentUtil.loadEnvironment(null); + Assert.assertTrue(env instanceof LocalEnvironment); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/util/ExampleSinkFunctionFactoryTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/util/ExampleSinkFunctionFactoryTest.java index af41a294e..8a6e38681 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/util/ExampleSinkFunctionFactoryTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/util/ExampleSinkFunctionFactoryTest.java @@ -30,21 +30,21 @@ public class ExampleSinkFunctionFactoryTest { - @Test - public void test() { - Configuration configuration = new Configuration(); - configuration.put(ExampleConfigKeys.GEAFLOW_SINK_TYPE.getKey(), SinkType.CONSOLE_SINK.name()); + @Test + public void test() { + Configuration configuration = new Configuration(); + configuration.put(ExampleConfigKeys.GEAFLOW_SINK_TYPE.getKey(), SinkType.CONSOLE_SINK.name()); - SinkFunction sinkFunction = ExampleSinkFunctionFactory.getSinkFunction(configuration); - Assert.assertTrue(sinkFunction instanceof ConsoleSink); - } + SinkFunction sinkFunction = ExampleSinkFunctionFactory.getSinkFunction(configuration); + Assert.assertTrue(sinkFunction instanceof ConsoleSink); + } - @Test - public void test1() { - Configuration configuration = new Configuration(); - configuration.put(ExampleConfigKeys.GEAFLOW_SINK_TYPE.getKey(), SinkType.FILE_SINK.name()); + @Test + public void test1() { + Configuration configuration = new Configuration(); + configuration.put(ExampleConfigKeys.GEAFLOW_SINK_TYPE.getKey(), SinkType.FILE_SINK.name()); - SinkFunction sinkFunction = ExampleSinkFunctionFactory.getSinkFunction(configuration); - Assert.assertTrue(sinkFunction instanceof FileSink); - } + SinkFunction sinkFunction = ExampleSinkFunctionFactory.getSinkFunction(configuration); + Assert.assertTrue(sinkFunction instanceof FileSink); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/CallBackTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/CallBackTest.java index 78dab872c..036532938 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/CallBackTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/CallBackTest.java @@ -29,20 +29,18 @@ public class CallBackTest extends BaseTest { - private static final Logger LOGGER = - LoggerFactory.getLogger(CallBackTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(CallBackTest.class); - @Test - public void testWindowCallBack() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void testWindowCallBack() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - WindowCallBackPipeline pipeline = new WindowCallBackPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } + WindowCallBackPipeline pipeline = new WindowCallBackPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/FlatMapTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/FlatMapTest.java index b6b13c5f2..33cf950c6 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/FlatMapTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/FlatMapTest.java @@ -30,21 +30,19 @@ public class FlatMapTest extends BaseTest { - private static final Logger LOGGER = - LoggerFactory.getLogger(FlatMapTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(FlatMapTest.class); - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - StreamWordFlatMapPipeline pipeline = new StreamWordFlatMapPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + StreamWordFlatMapPipeline pipeline = new StreamWordFlatMapPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - + pipeline.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/KeyAggTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/KeyAggTest.java index 71bdaf22a..05a17ca3a 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/KeyAggTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/KeyAggTest.java @@ -21,6 +21,7 @@ import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.env.EnvironmentFactory; import org.apache.geaflow.example.base.BaseTest; @@ -33,91 +34,89 @@ public class KeyAggTest extends BaseTest { - private static final Logger LOGGER = - LoggerFactory.getLogger(KeyAggTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(KeyAggTest.class); - private Map config; + private Map config; - @BeforeMethod - public void setUp() { - config = new HashMap<>(); - } + @BeforeMethod + public void setUp() { + config = new HashMap<>(); + } - @Test - public void testAllSingleConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); - - WindowKeyAggPipeline pipeline = new WindowKeyAggPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); - } + @Test + public void testAllSingleConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - @Test - public void testAggThreeConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); - configuration.putAll(config); - - WindowKeyAggPipeline pipeline = new WindowKeyAggPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + WindowKeyAggPipeline pipeline = new WindowKeyAggPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void testAggAndSinkThreeConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); - config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "3"); - configuration.putAll(config); - - WindowKeyAggPipeline pipeline = new WindowKeyAggPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void testAggThreeConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); + configuration.putAll(config); + + WindowKeyAggPipeline pipeline = new WindowKeyAggPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void testAggThreeAndSinkTwoConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); - config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "2"); - configuration.putAll(config); - - WindowKeyAggPipeline pipeline = new WindowKeyAggPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void testAggAndSinkThreeConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); + config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "3"); + configuration.putAll(config); + + WindowKeyAggPipeline pipeline = new WindowKeyAggPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void testAggThreeAndSinkFourConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); - config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "4"); - configuration.putAll(config); - - WindowKeyAggPipeline pipeline = new WindowKeyAggPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void testAggThreeAndSinkTwoConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); + config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "2"); + configuration.putAll(config); + + WindowKeyAggPipeline pipeline = new WindowKeyAggPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - + pipeline.validateResult(); + } + + @Test + public void testAggThreeAndSinkFourConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); + config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "4"); + configuration.putAll(config); + + WindowKeyAggPipeline pipeline = new WindowKeyAggPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); + } + pipeline.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/UnionTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/UnionTest.java index 306c6a040..4cf076f04 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/UnionTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/UnionTest.java @@ -21,6 +21,7 @@ import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.env.EnvironmentFactory; import org.apache.geaflow.example.base.BaseTest; @@ -34,91 +35,89 @@ public class UnionTest extends BaseTest { - private static final Logger LOGGER = - LoggerFactory.getLogger(UnionTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(UnionTest.class); - private Map config; + private Map config; - @BeforeMethod - public void setUp() { - config = new HashMap<>(); - } + @BeforeMethod + public void setUp() { + config = new HashMap<>(); + } - @Test - public void testUnionWithAggSingleConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); - - StreamUnionPipeline pipeline = new StreamUnionPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); - } + @Test + public void testUnionWithAggSingleConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - @Test - public void testUnionWithAggThreeConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); - configuration.putAll(config); - - StreamUnionPipeline pipeline = new StreamUnionPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + StreamUnionPipeline pipeline = new StreamUnionPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void testUnionWithAggThreeAndSinkThreeConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); - config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "3"); - configuration.putAll(config); - - StreamUnionPipeline pipeline = new StreamUnionPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void testUnionWithAggThreeConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); + configuration.putAll(config); + + StreamUnionPipeline pipeline = new StreamUnionPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void testUnionWithAggThreeAndSinkTwoConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); - config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "2"); - configuration.putAll(config); - - StreamUnionPipeline pipeline = new StreamUnionPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void testUnionWithAggThreeAndSinkThreeConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); + config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "3"); + configuration.putAll(config); + + StreamUnionPipeline pipeline = new StreamUnionPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void testUnionWithAggThreeAndSinkFourConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); - config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "4"); - configuration.putAll(config); - - StreamUnionPipeline pipeline = new StreamUnionPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void testUnionWithAggThreeAndSinkTwoConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); + config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "2"); + configuration.putAll(config); + + StreamUnionPipeline pipeline = new StreamUnionPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - + pipeline.validateResult(); + } + + @Test + public void testUnionWithAggThreeAndSinkFourConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); + config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "4"); + configuration.putAll(config); + + StreamUnionPipeline pipeline = new StreamUnionPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); + } + pipeline.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/WindowKeyAggTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/WindowKeyAggTest.java index 2b22e5396..e92f3b99f 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/WindowKeyAggTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/WindowKeyAggTest.java @@ -21,6 +21,7 @@ import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.env.EnvironmentFactory; import org.apache.geaflow.example.base.BaseTest; @@ -33,91 +34,89 @@ public class WindowKeyAggTest extends BaseTest { - private static final Logger LOGGER = - LoggerFactory.getLogger(WindowKeyAggTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(WindowKeyAggTest.class); - private Map config; + private Map config; - @BeforeMethod - public void setUp() { - config = new HashMap<>(); - } + @BeforeMethod + public void setUp() { + config = new HashMap<>(); + } - @Test - public void testAllSingleConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); - - WindowStreamKeyAggPipeline pipeline = new WindowStreamKeyAggPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); - } + @Test + public void testAllSingleConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - @Test - public void testAggThreeConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); - configuration.putAll(config); - - WindowStreamKeyAggPipeline pipeline = new WindowStreamKeyAggPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + WindowStreamKeyAggPipeline pipeline = new WindowStreamKeyAggPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void testAggAndSinkThreeConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); - config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "3"); - configuration.putAll(config); - - WindowStreamKeyAggPipeline pipeline = new WindowStreamKeyAggPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void testAggThreeConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); + configuration.putAll(config); + + WindowStreamKeyAggPipeline pipeline = new WindowStreamKeyAggPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void testAggThreeAndSinkTwoConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); - config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "2"); - configuration.putAll(config); - - WindowStreamKeyAggPipeline pipeline = new WindowStreamKeyAggPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void testAggAndSinkThreeConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); + config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "3"); + configuration.putAll(config); + + WindowStreamKeyAggPipeline pipeline = new WindowStreamKeyAggPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void testAggThreeAndSinkFourConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); - config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "4"); - configuration.putAll(config); - - WindowStreamKeyAggPipeline pipeline = new WindowStreamKeyAggPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void testAggThreeAndSinkTwoConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); + config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "2"); + configuration.putAll(config); + + WindowStreamKeyAggPipeline pipeline = new WindowStreamKeyAggPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - + pipeline.validateResult(); + } + + @Test + public void testAggThreeAndSinkFourConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); + config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "4"); + configuration.putAll(config); + + WindowStreamKeyAggPipeline pipeline = new WindowStreamKeyAggPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); + } + pipeline.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/WindowStreamWordCountTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/WindowStreamWordCountTest.java index e998586bc..223faa3e5 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/WindowStreamWordCountTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/WindowStreamWordCountTest.java @@ -31,91 +31,85 @@ public class WindowStreamWordCountTest extends BaseTest { - @Test - public void testSingleConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - WindowStreamWordCountPipeline pipeline = new WindowStreamWordCountPipeline(); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); - + @Test + public void testSingleConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + WindowStreamWordCountPipeline pipeline = new WindowStreamWordCountPipeline(); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void testReduceTwoAndSinkFourConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - WindowStreamWordCountPipeline pipeline = new WindowStreamWordCountPipeline(); - config.put(SOURCE_PARALLELISM.getKey(), "1"); - config.put(REDUCE_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "4"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); - + pipeline.validateResult(); + } + + @Test + public void testReduceTwoAndSinkFourConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + WindowStreamWordCountPipeline pipeline = new WindowStreamWordCountPipeline(); + config.put(SOURCE_PARALLELISM.getKey(), "1"); + config.put(REDUCE_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "4"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void testReduceOneAndSinkFourConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - WindowStreamWordCountPipeline pipeline = new WindowStreamWordCountPipeline(); - config.put(SOURCE_PARALLELISM.getKey(), "1"); - config.put(REDUCE_PARALLELISM.getKey(), "1"); - config.put(SINK_PARALLELISM.getKey(), "4"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); - + pipeline.validateResult(); + } + + @Test + public void testReduceOneAndSinkFourConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + WindowStreamWordCountPipeline pipeline = new WindowStreamWordCountPipeline(); + config.put(SOURCE_PARALLELISM.getKey(), "1"); + config.put(REDUCE_PARALLELISM.getKey(), "1"); + config.put(SINK_PARALLELISM.getKey(), "4"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void testReduceTwoAndSinkTwoConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - WindowStreamWordCountPipeline pipeline = new WindowStreamWordCountPipeline(); - config.put(SOURCE_PARALLELISM.getKey(), "1"); - config.put(REDUCE_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "2"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); - + pipeline.validateResult(); + } + + @Test + public void testReduceTwoAndSinkTwoConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + WindowStreamWordCountPipeline pipeline = new WindowStreamWordCountPipeline(); + config.put(SOURCE_PARALLELISM.getKey(), "1"); + config.put(REDUCE_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "2"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void testReduceTwoAndSinkOneConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - WindowStreamWordCountPipeline pipeline = new WindowStreamWordCountPipeline(); - config.put(SOURCE_PARALLELISM.getKey(), "1"); - config.put(REDUCE_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "1"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); - + pipeline.validateResult(); + } + + @Test + public void testReduceTwoAndSinkOneConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + WindowStreamWordCountPipeline pipeline = new WindowStreamWordCountPipeline(); + config.put(SOURCE_PARALLELISM.getKey(), "1"); + config.put(REDUCE_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "1"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - + pipeline.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/WindowWordCountTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/WindowWordCountTest.java index 784d107c1..a40c3502d 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/WindowWordCountTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/WindowWordCountTest.java @@ -24,6 +24,7 @@ import static org.apache.geaflow.example.config.ExampleConfigKeys.SOURCE_PARALLELISM; import java.util.Comparator; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.env.EnvironmentFactory; @@ -34,144 +35,136 @@ public class WindowWordCountTest extends BaseTest { - public static class TestWindowWordCountFactory { - - @Factory - public Object[] factoryMethod() { - return new Object[]{ - new WindowWordCountTest(true), - new WindowWordCountTest(false), - }; - } + public static class TestWindowWordCountFactory { + @Factory + public Object[] factoryMethod() { + return new Object[] { + new WindowWordCountTest(true), new WindowWordCountTest(false), + }; } + } - private final boolean prefetch; + private final boolean prefetch; - public WindowWordCountTest(boolean prefetch) { - this.prefetch = prefetch; - } + public WindowWordCountTest(boolean prefetch) { + this.prefetch = prefetch; + } - @Test - public void testSingleConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); - configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); - - WindowWordCountPipeline pipeline = new WindowWordCountPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(new DataComparator()); + @Test + public void testSingleConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); + configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); + WindowWordCountPipeline pipeline = new WindowWordCountPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void testReduceTwoAndSinkFourConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); - - WindowWordCountPipeline pipeline = new WindowWordCountPipeline(); - config.put(SOURCE_PARALLELISM.getKey(), "1"); - config.put(REDUCE_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "4"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(new DataComparator()); - + pipeline.validateResult(new DataComparator()); + } + + @Test + public void testReduceTwoAndSinkFourConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); + + WindowWordCountPipeline pipeline = new WindowWordCountPipeline(); + config.put(SOURCE_PARALLELISM.getKey(), "1"); + config.put(REDUCE_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "4"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void testReduceOneAndSinkFourConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); - - WindowWordCountPipeline pipeline = new WindowWordCountPipeline(); - config.put(SOURCE_PARALLELISM.getKey(), "1"); - config.put(REDUCE_PARALLELISM.getKey(), "1"); - config.put(SINK_PARALLELISM.getKey(), "4"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(new DataComparator()); - + pipeline.validateResult(new DataComparator()); + } + + @Test + public void testReduceOneAndSinkFourConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); + + WindowWordCountPipeline pipeline = new WindowWordCountPipeline(); + config.put(SOURCE_PARALLELISM.getKey(), "1"); + config.put(REDUCE_PARALLELISM.getKey(), "1"); + config.put(SINK_PARALLELISM.getKey(), "4"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void testReduceTwoAndSinkTwoConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); - - WindowWordCountPipeline pipeline = new WindowWordCountPipeline(); - config.put(SOURCE_PARALLELISM.getKey(), "1"); - config.put(REDUCE_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "2"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(new DataComparator()); - + pipeline.validateResult(new DataComparator()); + } + + @Test + public void testReduceTwoAndSinkTwoConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); + + WindowWordCountPipeline pipeline = new WindowWordCountPipeline(); + config.put(SOURCE_PARALLELISM.getKey(), "1"); + config.put(REDUCE_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "2"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void testReduceTwoAndSinkOneConcurrency() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); - - WindowWordCountPipeline pipeline = new WindowWordCountPipeline(); - config.put(SOURCE_PARALLELISM.getKey(), "1"); - config.put(REDUCE_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "1"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(new DataComparator()); - + pipeline.validateResult(new DataComparator()); + } + + @Test + public void testReduceTwoAndSinkOneConcurrency() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.put(ExecutionConfigKeys.SHUFFLE_PREFETCH, String.valueOf(this.prefetch)); + + WindowWordCountPipeline pipeline = new WindowWordCountPipeline(); + config.put(SOURCE_PARALLELISM.getKey(), "1"); + config.put(REDUCE_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "1"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } + pipeline.validateResult(new DataComparator()); + } - static class DataComparator implements Comparator { - - @Override - public int compare(String o1, String o2) { - if (o1 != null && o2 != null) { - String[] data1 = o1.replace(")", "").split(","); - String[] data2 = o2.replace(")", "").split(","); + static class DataComparator implements Comparator { - if (!data1[0].equals(data2[0])) { - return o1.compareTo(o2); - } else { - return Integer.valueOf(data1[1]).compareTo(Integer.valueOf(data2[1])); - } - } + @Override + public int compare(String o1, String o2) { + if (o1 != null && o2 != null) { + String[] data1 = o1.replace(")", "").split(","); + String[] data2 = o2.replace(")", "").split(","); - return 0; + if (!data1[0].equals(data2[0])) { + return o1.compareTo(o2); + } else { + return Integer.valueOf(data1[1]).compareTo(Integer.valueOf(data2[1])); } + } - @Override - public boolean equals(Object obj) { - if (obj == this) { - return true; - } - if (obj instanceof String) { - return obj.equals(this); - } - return false; - } + return 0; } + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (obj instanceof String) { + return obj.equals(this); + } + return false; + } + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/WordPrintTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/WordPrintTest.java index 26e6ecc35..582730fe4 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/WordPrintTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/WordPrintTest.java @@ -30,21 +30,19 @@ public class WordPrintTest extends BaseTest { - private static final Logger LOGGER = - LoggerFactory.getLogger(WordPrintTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(WordPrintTest.class); - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - StreamWordPrintPipeline pipeline = new StreamWordPrintPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + StreamWordPrintPipeline pipeline = new StreamWordPrintPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - + pipeline.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/incremental/IncrAggTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/incremental/IncrAggTest.java index 4ce31ca69..03b79fabf 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/incremental/IncrAggTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/incremental/IncrAggTest.java @@ -30,20 +30,19 @@ public class IncrAggTest extends BaseTest { - private static final Logger LOGGER = LoggerFactory.getLogger(IncrAggTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(IncrAggTest.class); - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - StreamAggPipeline pipeline = new StreamAggPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + StreamAggPipeline pipeline = new StreamAggPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - + pipeline.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/incremental/IncrKeyAggTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/incremental/IncrKeyAggTest.java index 6217206ab..26206233e 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/incremental/IncrKeyAggTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/incremental/IncrKeyAggTest.java @@ -21,6 +21,7 @@ import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.env.EnvironmentFactory; import org.apache.geaflow.example.base.BaseTest; @@ -32,88 +33,87 @@ public class IncrKeyAggTest extends BaseTest { - private Map config; + private Map config; - @BeforeMethod - public void setUp() { - config = new HashMap<>(); - } + @BeforeMethod + public void setUp() { + config = new HashMap<>(); + } - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); - - StreamKeyAggPipeline pipeline = new StreamKeyAggPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); - } + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); - @Test - public void test1() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); - configuration.putAll(config); - - StreamKeyAggPipeline pipeline = new StreamKeyAggPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + StreamKeyAggPipeline pipeline = new StreamKeyAggPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void test2() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); - config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "3"); - configuration.putAll(config); - - StreamKeyAggPipeline pipeline = new StreamKeyAggPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void test1() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); + configuration.putAll(config); + + StreamKeyAggPipeline pipeline = new StreamKeyAggPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void test3() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); - config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "2"); - configuration.putAll(config); - - StreamKeyAggPipeline pipeline = new StreamKeyAggPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void test2() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); + config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "3"); + configuration.putAll(config); + + StreamKeyAggPipeline pipeline = new StreamKeyAggPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void test4() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); - config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "4"); - configuration.putAll(config); - - StreamKeyAggPipeline pipeline = new StreamKeyAggPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + pipeline.validateResult(); + } + + @Test + public void test3() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); + config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "2"); + configuration.putAll(config); + + StreamKeyAggPipeline pipeline = new StreamKeyAggPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - + pipeline.validateResult(); + } + + @Test + public void test4() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + config.put(ExampleConfigKeys.AGG_PARALLELISM.getKey(), "3"); + config.put(ExampleConfigKeys.SINK_PARALLELISM.getKey(), "4"); + configuration.putAll(config); + + StreamKeyAggPipeline pipeline = new StreamKeyAggPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); + } + pipeline.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/incremental/IncrWordCountCallBackTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/incremental/IncrWordCountCallBackTest.java index 973527464..50167e90b 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/incremental/IncrWordCountCallBackTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/incremental/IncrWordCountCallBackTest.java @@ -30,21 +30,19 @@ public class IncrWordCountCallBackTest extends BaseTest { - private static final Logger LOGGER = - LoggerFactory.getLogger(IncrWordCountCallBackTest.class); + private static final Logger LOGGER = LoggerFactory.getLogger(IncrWordCountCallBackTest.class); - @Test - public void testTaskCallBack() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); - - StreamWordCountCallBackPipeline pipeline = new StreamWordCountCallBackPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); + @Test + public void testTaskCallBack() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); + StreamWordCountCallBackPipeline pipeline = new StreamWordCountCallBackPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } + pipeline.validateResult(); + } } diff --git a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/incremental/IncrWordCountTest.java b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/incremental/IncrWordCountTest.java index b84b320fe..8b60cbc3c 100644 --- a/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/incremental/IncrWordCountTest.java +++ b/geaflow/geaflow-examples/src/test/java/org/apache/geaflow/example/window/incremental/IncrWordCountTest.java @@ -34,131 +34,121 @@ public class IncrWordCountTest extends BaseTest { - private static final Logger LOGGER = - LoggerFactory.getLogger(IncrWordCountTest.class); - - - @Test - public void test() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - configuration.putAll(config); - - StreamWordCountPipeline pipeline = new StreamWordCountPipeline(); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); - + private static final Logger LOGGER = LoggerFactory.getLogger(IncrWordCountTest.class); + + @Test + public void test() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + configuration.putAll(config); + + StreamWordCountPipeline pipeline = new StreamWordCountPipeline(); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void test1() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - StreamWordCountPipeline pipeline = new StreamWordCountPipeline(); - config.put(SOURCE_PARALLELISM.getKey(), "2"); - config.put(REDUCE_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "4"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); - + pipeline.validateResult(); + } + + @Test + public void test1() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + StreamWordCountPipeline pipeline = new StreamWordCountPipeline(); + config.put(SOURCE_PARALLELISM.getKey(), "2"); + config.put(REDUCE_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "4"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void test2() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - StreamWordCountPipeline pipeline = new StreamWordCountPipeline(); - config.put(SOURCE_PARALLELISM.getKey(), "1"); - config.put(REDUCE_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "4"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); - + pipeline.validateResult(); + } + + @Test + public void test2() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + StreamWordCountPipeline pipeline = new StreamWordCountPipeline(); + config.put(SOURCE_PARALLELISM.getKey(), "1"); + config.put(REDUCE_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "4"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void test3() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - StreamWordCountPipeline pipeline = new StreamWordCountPipeline(); - config.put(SOURCE_PARALLELISM.getKey(), "1"); - config.put(REDUCE_PARALLELISM.getKey(), "1"); - config.put(SINK_PARALLELISM.getKey(), "4"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); - + pipeline.validateResult(); + } + + @Test + public void test3() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + StreamWordCountPipeline pipeline = new StreamWordCountPipeline(); + config.put(SOURCE_PARALLELISM.getKey(), "1"); + config.put(REDUCE_PARALLELISM.getKey(), "1"); + config.put(SINK_PARALLELISM.getKey(), "4"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void test4() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - StreamWordCountPipeline pipeline = new StreamWordCountPipeline(); - config.put(SOURCE_PARALLELISM.getKey(), "2"); - config.put(REDUCE_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "2"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); - + pipeline.validateResult(); + } + + @Test + public void test4() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + StreamWordCountPipeline pipeline = new StreamWordCountPipeline(); + config.put(SOURCE_PARALLELISM.getKey(), "2"); + config.put(REDUCE_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "2"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void test5() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - StreamWordCountPipeline pipeline = new StreamWordCountPipeline(); - config.put(SOURCE_PARALLELISM.getKey(), "1"); - config.put(REDUCE_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "2"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); - + pipeline.validateResult(); + } + + @Test + public void test5() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + StreamWordCountPipeline pipeline = new StreamWordCountPipeline(); + config.put(SOURCE_PARALLELISM.getKey(), "1"); + config.put(REDUCE_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "2"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - - @Test - public void test6() throws Exception { - environment = EnvironmentFactory.onLocalEnvironment(); - Configuration configuration = environment.getEnvironmentContext().getConfig(); - - StreamWordCountPipeline pipeline = new StreamWordCountPipeline(); - config.put(SOURCE_PARALLELISM.getKey(), "1"); - config.put(REDUCE_PARALLELISM.getKey(), "2"); - config.put(SINK_PARALLELISM.getKey(), "1"); - configuration.putAll(config); - IPipelineResult result = pipeline.submit(environment); - if (!result.isSuccess()) { - throw new Exception("execute failed"); - } - pipeline.validateResult(); - + pipeline.validateResult(); + } + + @Test + public void test6() throws Exception { + environment = EnvironmentFactory.onLocalEnvironment(); + Configuration configuration = environment.getEnvironmentContext().getConfig(); + + StreamWordCountPipeline pipeline = new StreamWordCountPipeline(); + config.put(SOURCE_PARALLELISM.getKey(), "1"); + config.put(REDUCE_PARALLELISM.getKey(), "2"); + config.put(SINK_PARALLELISM.getKey(), "1"); + configuration.putAll(config); + IPipelineResult result = pipeline.submit(environment); + if (!result.isSuccess()) { + throw new Exception("execute failed"); } - + pipeline.validateResult(); + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferContext.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferContext.java index 0289c1985..82e7d0060 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferContext.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferContext.java @@ -20,9 +20,9 @@ import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.INFER_ENV_USER_TRANSFORM_CLASSNAME; -import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.infer.exchange.DataExchangeContext; @@ -30,74 +30,76 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; + public class InferContext implements AutoCloseable { - private static final Logger LOGGER = LoggerFactory.getLogger(InferContext.class); - private final DataExchangeContext shareMemoryContext; - private final String userDataTransformClass; - private final String sendQueueKey; + private static final Logger LOGGER = LoggerFactory.getLogger(InferContext.class); + private final DataExchangeContext shareMemoryContext; + private final String userDataTransformClass; + private final String sendQueueKey; - private final String receiveQueueKey; - private InferTaskRunImpl inferTaskRunner; - private InferDataBridgeImpl dataBridge; + private final String receiveQueueKey; + private InferTaskRunImpl inferTaskRunner; + private InferDataBridgeImpl dataBridge; - public InferContext(Configuration config) { - this.shareMemoryContext = new DataExchangeContext(config); - this.receiveQueueKey = shareMemoryContext.getReceiveQueueKey(); - this.sendQueueKey = shareMemoryContext.getSendQueueKey(); - this.userDataTransformClass = config.getString(INFER_ENV_USER_TRANSFORM_CLASSNAME); - Preconditions.checkNotNull(userDataTransformClass, - INFER_ENV_USER_TRANSFORM_CLASSNAME.getKey() + " param must be not null"); - this.dataBridge = new InferDataBridgeImpl<>(shareMemoryContext); - init(); - } + public InferContext(Configuration config) { + this.shareMemoryContext = new DataExchangeContext(config); + this.receiveQueueKey = shareMemoryContext.getReceiveQueueKey(); + this.sendQueueKey = shareMemoryContext.getSendQueueKey(); + this.userDataTransformClass = config.getString(INFER_ENV_USER_TRANSFORM_CLASSNAME); + Preconditions.checkNotNull( + userDataTransformClass, + INFER_ENV_USER_TRANSFORM_CLASSNAME.getKey() + " param must be not null"); + this.dataBridge = new InferDataBridgeImpl<>(shareMemoryContext); + init(); + } - private void init() { - try { - InferEnvironmentContext inferEnvironmentContext = getInferEnvironmentContext(); - runInferTask(inferEnvironmentContext); - } catch (Exception e) { - throw new GeaflowRuntimeException("infer context init failed", e); - } + private void init() { + try { + InferEnvironmentContext inferEnvironmentContext = getInferEnvironmentContext(); + runInferTask(inferEnvironmentContext); + } catch (Exception e) { + throw new GeaflowRuntimeException("infer context init failed", e); } + } - public OUT infer(Object... feature) throws Exception { - try { - dataBridge.write(feature); - return dataBridge.read(); - } catch (Exception e) { - inferTaskRunner.stop(); - LOGGER.error("model infer read result error, python process stopped", e); - throw new GeaflowRuntimeException("receive infer result exception", e); - } + public OUT infer(Object... feature) throws Exception { + try { + dataBridge.write(feature); + return dataBridge.read(); + } catch (Exception e) { + inferTaskRunner.stop(); + LOGGER.error("model infer read result error, python process stopped", e); + throw new GeaflowRuntimeException("receive infer result exception", e); } + } - - private InferEnvironmentContext getInferEnvironmentContext() { - boolean initFinished = InferEnvironmentManager.checkInferEnvironmentStatus(); - while (!initFinished) { - InferEnvironmentManager.checkError(); - initFinished = InferEnvironmentManager.checkInferEnvironmentStatus(); - } - return InferEnvironmentManager.getEnvironmentContext(); + private InferEnvironmentContext getInferEnvironmentContext() { + boolean initFinished = InferEnvironmentManager.checkInferEnvironmentStatus(); + while (!initFinished) { + InferEnvironmentManager.checkError(); + initFinished = InferEnvironmentManager.checkInferEnvironmentStatus(); } + return InferEnvironmentManager.getEnvironmentContext(); + } - private void runInferTask(InferEnvironmentContext inferEnvironmentContext) { - inferTaskRunner = new InferTaskRunImpl(inferEnvironmentContext); - List runCommands = new ArrayList<>(); - runCommands.add(inferEnvironmentContext.getPythonExec()); - runCommands.add(inferEnvironmentContext.getInferScript()); - runCommands.add(inferEnvironmentContext.getInferTFClassNameParam(this.userDataTransformClass)); - runCommands.add(inferEnvironmentContext.getInferShareMemoryInputParam(receiveQueueKey)); - runCommands.add(inferEnvironmentContext.getInferShareMemoryOutputParam(sendQueueKey)); - inferTaskRunner.run(runCommands); - } + private void runInferTask(InferEnvironmentContext inferEnvironmentContext) { + inferTaskRunner = new InferTaskRunImpl(inferEnvironmentContext); + List runCommands = new ArrayList<>(); + runCommands.add(inferEnvironmentContext.getPythonExec()); + runCommands.add(inferEnvironmentContext.getInferScript()); + runCommands.add(inferEnvironmentContext.getInferTFClassNameParam(this.userDataTransformClass)); + runCommands.add(inferEnvironmentContext.getInferShareMemoryInputParam(receiveQueueKey)); + runCommands.add(inferEnvironmentContext.getInferShareMemoryOutputParam(sendQueueKey)); + inferTaskRunner.run(runCommands); + } - @Override - public void close() { - if (inferTaskRunner != null) { - inferTaskRunner.stop(); - LOGGER.info("infer task stop after close"); - } + @Override + public void close() { + if (inferTaskRunner != null) { + inferTaskRunner.stop(); + LOGGER.info("infer task stop after close"); } + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferDependencyManager.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferDependencyManager.java index 3fee2c1cf..44f5e3d01 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferDependencyManager.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferDependencyManager.java @@ -25,6 +25,7 @@ import java.nio.file.Path; import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.infer.util.InferFileUtils; @@ -33,62 +34,69 @@ public class InferDependencyManager { - private static final Logger LOGGER = LoggerFactory.getLogger(InferDependencyManager.class); - - private static final String ENV_RUNNER_SH = "infer/env/install-infer-env.sh"; + private static final Logger LOGGER = LoggerFactory.getLogger(InferDependencyManager.class); - private static final String INFER_RUNTIME_PATH = "infer/inferRuntime"; + private static final String ENV_RUNNER_SH = "infer/env/install-infer-env.sh"; - private static final String FILE_IN_JAR_PREFIX = "/"; + private static final String INFER_RUNTIME_PATH = "infer/inferRuntime"; - private final InferEnvironmentContext environmentContext; + private static final String FILE_IN_JAR_PREFIX = "/"; - private final Configuration config; + private final InferEnvironmentContext environmentContext; - private String buildInferEnvShellPath; - private String inferEnvRequirementsPath; + private final Configuration config; - public InferDependencyManager(InferEnvironmentContext environmentContext) { - this.environmentContext = environmentContext; - this.config = environmentContext.getJobConfig(); - init(); - } + private String buildInferEnvShellPath; + private String inferEnvRequirementsPath; - private void init() { - List inferRuntimeFiles = buildInferRuntimeFiles(); - for (String inferRuntimeFile : inferRuntimeFiles) { - InferFileUtils.copyInferFileByURL(environmentContext.getInferFilesDirectory(), inferRuntimeFile); - } - String pythonFilesDirectory = environmentContext.getInferFilesDirectory(); - InferFileUtils.prepareInferFilesFromJars(pythonFilesDirectory); - this.inferEnvRequirementsPath = pythonFilesDirectory + File.separator + REQUIREMENTS_TXT; - this.buildInferEnvShellPath = InferFileUtils.copyInferFileByURL(environmentContext.getVirtualEnvDirectory(), ENV_RUNNER_SH); - } + public InferDependencyManager(InferEnvironmentContext environmentContext) { + this.environmentContext = environmentContext; + this.config = environmentContext.getJobConfig(); + init(); + } - public String getBuildInferEnvShellPath() { - return buildInferEnvShellPath; + private void init() { + List inferRuntimeFiles = buildInferRuntimeFiles(); + for (String inferRuntimeFile : inferRuntimeFiles) { + InferFileUtils.copyInferFileByURL( + environmentContext.getInferFilesDirectory(), inferRuntimeFile); } - - public String getInferEnvRequirementsPath() { - return inferEnvRequirementsPath; - } - - private List buildInferRuntimeFiles() { - List runtimeFiles; - try { - List filePaths = InferFileUtils.getPathsFromResourceJAR(INFER_RUNTIME_PATH); - runtimeFiles = filePaths.stream().map(path -> { - String filePath = path.toString(); - if (filePath.startsWith(FILE_IN_JAR_PREFIX)) { - filePath = filePath.substring(1); - } - LOGGER.info("infer runtime file name is {}", filePath); - return filePath; - }).collect(Collectors.toList()); - } catch (Exception e) { - LOGGER.error("get infer runtime files error", e); - throw new GeaflowRuntimeException("get infer runtime files failed", e); - } - return runtimeFiles; + String pythonFilesDirectory = environmentContext.getInferFilesDirectory(); + InferFileUtils.prepareInferFilesFromJars(pythonFilesDirectory); + this.inferEnvRequirementsPath = pythonFilesDirectory + File.separator + REQUIREMENTS_TXT; + this.buildInferEnvShellPath = + InferFileUtils.copyInferFileByURL( + environmentContext.getVirtualEnvDirectory(), ENV_RUNNER_SH); + } + + public String getBuildInferEnvShellPath() { + return buildInferEnvShellPath; + } + + public String getInferEnvRequirementsPath() { + return inferEnvRequirementsPath; + } + + private List buildInferRuntimeFiles() { + List runtimeFiles; + try { + List filePaths = InferFileUtils.getPathsFromResourceJAR(INFER_RUNTIME_PATH); + runtimeFiles = + filePaths.stream() + .map( + path -> { + String filePath = path.toString(); + if (filePath.startsWith(FILE_IN_JAR_PREFIX)) { + filePath = filePath.substring(1); + } + LOGGER.info("infer runtime file name is {}", filePath); + return filePath; + }) + .collect(Collectors.toList()); + } catch (Exception e) { + LOGGER.error("get infer runtime files error", e); + throw new GeaflowRuntimeException("get infer runtime files failed", e); } + return runtimeFiles; + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferEnvironmentContext.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferEnvironmentContext.java index 569b19ada..025e8e497 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferEnvironmentContext.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferEnvironmentContext.java @@ -22,123 +22,123 @@ import java.lang.management.ManagementFactory; import java.lang.management.RuntimeMXBean; import java.net.InetAddress; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class InferEnvironmentContext { - private static final String PROCESS_ID_FLAG = "@"; + private static final String PROCESS_ID_FLAG = "@"; - private static final String HOST_SEPARATOR = ":"; + private static final String HOST_SEPARATOR = ":"; - private static final String LIB_PATH = "/conda/lib"; + private static final String LIB_PATH = "/conda/lib"; - private static final String INFER_SCRIPT_FILE = "/infer_server.py"; + private static final String INFER_SCRIPT_FILE = "/infer_server.py"; - private static final String PYTHON_EXEC = "/conda/bin/python3"; + private static final String PYTHON_EXEC = "/conda/bin/python3"; - // Start infer process parameter. - private static final String TF_CLASSNAME_KEY = "--tfClassName="; + // Start infer process parameter. + private static final String TF_CLASSNAME_KEY = "--tfClassName="; - private static final String SHARE_MEMORY_INPUT_KEY = "--input_queue_shm_id="; + private static final String SHARE_MEMORY_INPUT_KEY = "--input_queue_shm_id="; - private static final String SHARE_MEMORY_OUTPUT_KEY = "--output_queue_shm_id="; + private static final String SHARE_MEMORY_OUTPUT_KEY = "--output_queue_shm_id="; - private final String virtualEnvDirectory; + private final String virtualEnvDirectory; - private final String inferFilesDirectory; + private final String inferFilesDirectory; - private final String inferLibPath; + private final String inferLibPath; - private Boolean envFinished; + private Boolean envFinished; - private final String roleNameIndex; + private final String roleNameIndex; - private final Configuration configuration; + private final Configuration configuration; - private String inferScript; + private String inferScript; - private String pythonExec; + private String pythonExec; + public InferEnvironmentContext( + String virtualEnvDirectory, String pythonFilesDirectory, Configuration configuration) { + this.virtualEnvDirectory = virtualEnvDirectory; + this.inferFilesDirectory = pythonFilesDirectory; + this.inferLibPath = virtualEnvDirectory + LIB_PATH; + this.pythonExec = virtualEnvDirectory + PYTHON_EXEC; + this.inferScript = pythonFilesDirectory + INFER_SCRIPT_FILE; + this.roleNameIndex = queryRoleNameIndex(); + this.configuration = configuration; + this.envFinished = false; + } - public InferEnvironmentContext(String virtualEnvDirectory, String pythonFilesDirectory, - Configuration configuration) { - this.virtualEnvDirectory = virtualEnvDirectory; - this.inferFilesDirectory = pythonFilesDirectory; - this.inferLibPath = virtualEnvDirectory + LIB_PATH; - this.pythonExec = virtualEnvDirectory + PYTHON_EXEC; - this.inferScript = pythonFilesDirectory + INFER_SCRIPT_FILE; - this.roleNameIndex = queryRoleNameIndex(); - this.configuration = configuration; - this.envFinished = false; + private String queryRoleNameIndex() { + try { + InetAddress address = InetAddress.getLocalHost(); + String hostName = address.getHostName(); + RuntimeMXBean runtime = ManagementFactory.getRuntimeMXBean(); + String name = runtime.getName(); + int processId = Integer.parseInt(name.substring(0, name.indexOf(PROCESS_ID_FLAG))); + return hostName + HOST_SEPARATOR + processId; + } catch (Exception e) { + throw new GeaflowRuntimeException("query role name and index failed", e); } + } - private String queryRoleNameIndex() { - try { - InetAddress address = InetAddress.getLocalHost(); - String hostName = address.getHostName(); - RuntimeMXBean runtime = ManagementFactory.getRuntimeMXBean(); - String name = runtime.getName(); - int processId = Integer.parseInt(name.substring(0, name.indexOf(PROCESS_ID_FLAG))); - return hostName + HOST_SEPARATOR + processId; - } catch (Exception e) { - throw new GeaflowRuntimeException("query role name and index failed", e); - } - } + public String getVirtualEnvDirectory() { + return virtualEnvDirectory; + } - public String getVirtualEnvDirectory() { - return virtualEnvDirectory; - } + public String getInferFilesDirectory() { + return inferFilesDirectory; + } - public String getInferFilesDirectory() { - return inferFilesDirectory; - } + public Boolean enableFinished() { + return envFinished; + } - public Boolean enableFinished() { - return envFinished; - } + public void setPythonExec(String pythonExec) { + this.pythonExec = pythonExec; + } - public void setPythonExec(String pythonExec) { - this.pythonExec = pythonExec; - } + public void setInferScript(String inferScript) { + this.inferScript = inferScript; + } - public void setInferScript(String inferScript) { - this.inferScript = inferScript; - } + public void setFinished(Boolean envFinished) { + this.envFinished = envFinished; + } - public void setFinished(Boolean envFinished) { - this.envFinished = envFinished; - } + public String getRoleNameIndex() { + return roleNameIndex; + } - public String getRoleNameIndex() { - return roleNameIndex; - } + public String getInferLibPath() { + return inferLibPath; + } - public String getInferLibPath() { - return inferLibPath; - } - - public String getPythonExec() { - return pythonExec; - } + public String getPythonExec() { + return pythonExec; + } - public Configuration getJobConfig() { - return configuration; - } + public Configuration getJobConfig() { + return configuration; + } - public String getInferTFClassNameParam(String udfClassName) { - return TF_CLASSNAME_KEY + udfClassName; - } + public String getInferTFClassNameParam(String udfClassName) { + return TF_CLASSNAME_KEY + udfClassName; + } - public String getInferShareMemoryInputParam(String shareMemoryInputKey) { - return SHARE_MEMORY_INPUT_KEY + shareMemoryInputKey; - } + public String getInferShareMemoryInputParam(String shareMemoryInputKey) { + return SHARE_MEMORY_INPUT_KEY + shareMemoryInputKey; + } - public String getInferShareMemoryOutputParam(String shareMemoryOutputKey) { - return SHARE_MEMORY_OUTPUT_KEY + shareMemoryOutputKey; - } + public String getInferShareMemoryOutputParam(String shareMemoryOutputKey) { + return SHARE_MEMORY_OUTPUT_KEY + shareMemoryOutputKey; + } - public String getInferScript() { - return inferScript; - } + public String getInferScript() { + return inferScript; + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferEnvironmentManager.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferEnvironmentManager.java index 46795beb4..da1b11327 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferEnvironmentManager.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferEnvironmentManager.java @@ -21,7 +21,6 @@ import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.INFER_ENV_CONDA_URL; import static org.apache.geaflow.infer.util.InferFileUtils.releaseLock; -import com.google.common.base.Joiner; import java.io.File; import java.nio.channels.FileLock; import java.time.Duration; @@ -33,6 +32,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.FrameworkConfigKeys; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -41,184 +41,196 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class InferEnvironmentManager implements AutoCloseable { - - private static final Logger LOGGER = LoggerFactory.getLogger(InferEnvironmentManager.class); +import com.google.common.base.Joiner; - private static final String LOCK_FILE = "_lock"; +public class InferEnvironmentManager implements AutoCloseable { - private static final String SHELL_START = "/bin/bash"; + private static final Logger LOGGER = LoggerFactory.getLogger(InferEnvironmentManager.class); - private static final long TIMEOUT_SECOND = 10; + private static final String LOCK_FILE = "_lock"; - private static final String SCRIPT_SEPARATOR = " "; + private static final String SHELL_START = "/bin/bash"; - private static final String CHMOD_CMD = "chmod"; + private static final long TIMEOUT_SECOND = 10; - private static final String CHMOD_PERMISSION = "755"; + private static final String SCRIPT_SEPARATOR = " "; - private static final String FINISH_FILE = "_finish"; + private static final String CHMOD_CMD = "chmod"; - private static final String FAILED_FILE = "_failed"; + private static final String CHMOD_PERMISSION = "755"; - private static final String EXEC_POOL_PREFIX = "create-infer-env-"; + private static final String FINISH_FILE = "_finish"; - private static final String VIRTUAL_ENV_DIR = "inferEnv"; + private static final String FAILED_FILE = "_failed"; - private static final String INFER_FILES_DIR = "inferFiles"; + private static final String EXEC_POOL_PREFIX = "create-infer-env-"; - private static final AtomicReference ERROR_CASE = new AtomicReference<>(); + private static final String VIRTUAL_ENV_DIR = "inferEnv"; - private static final AtomicInteger THREAD_IDX_GENERATOR = new AtomicInteger(); + private static final String INFER_FILES_DIR = "inferFiles"; - private static final AtomicBoolean INITIALIZED = new AtomicBoolean(false); + private static final AtomicReference ERROR_CASE = new AtomicReference<>(); - private static final AtomicBoolean SUCCESS_FLAG = new AtomicBoolean(false); + private static final AtomicInteger THREAD_IDX_GENERATOR = new AtomicInteger(); - private static InferEnvironmentManager INSTANCE; + private static final AtomicBoolean INITIALIZED = new AtomicBoolean(false); - private static InferEnvironmentContext environmentContext; + private static final AtomicBoolean SUCCESS_FLAG = new AtomicBoolean(false); - private final Configuration configuration; + private static InferEnvironmentManager INSTANCE; - private final transient ExecutorService executorService; + private static InferEnvironmentContext environmentContext; + private final Configuration configuration; - public static synchronized InferEnvironmentManager buildInferEnvironmentManager(Configuration config) { - if (INSTANCE == null) { - INSTANCE = new InferEnvironmentManager(config); - } - return INSTANCE; - } + private final transient ExecutorService executorService; - private InferEnvironmentManager(Configuration config) { - this.configuration = config; - this.executorService = Executors.newSingleThreadExecutor(r -> { - Thread t = new Thread(r); - String name = String.format(EXEC_POOL_PREFIX + THREAD_IDX_GENERATOR.getAndIncrement()); - t.setName(name); - t.setDaemon(true); - return t; - }); + public static synchronized InferEnvironmentManager buildInferEnvironmentManager( + Configuration config) { + if (INSTANCE == null) { + INSTANCE = new InferEnvironmentManager(config); } - - public void createEnvironment() { - if (INITIALIZED.compareAndSet(false, true)) { - executorService.execute(() -> { - try { - environmentContext = constructInferEnvironment(configuration); - if (environmentContext.enableFinished()) { - SUCCESS_FLAG.set(true); - LOGGER.info("{} create infer environment finished", - environmentContext.getRoleNameIndex()); - } - } catch (Throwable e) { - SUCCESS_FLAG.set(false); - ERROR_CASE.set(e); - LOGGER.error("execute install infer environment error", e); - } + return INSTANCE; + } + + private InferEnvironmentManager(Configuration config) { + this.configuration = config; + this.executorService = + Executors.newSingleThreadExecutor( + r -> { + Thread t = new Thread(r); + String name = + String.format(EXEC_POOL_PREFIX + THREAD_IDX_GENERATOR.getAndIncrement()); + t.setName(name); + t.setDaemon(true); + return t; }); - } - } - - private InferEnvironmentContext constructInferEnvironment(Configuration configuration) { - String inferEnvDirectory = InferFileUtils.createTargetDir(VIRTUAL_ENV_DIR, configuration); - String inferFilesDirectory = InferFileUtils.createTargetDir(INFER_FILES_DIR, configuration); - - InferEnvironmentContext environmentContext = - new InferEnvironmentContext(inferEnvDirectory, inferFilesDirectory, configuration); - File lockFile; - FileLock lock = null; - try { - lockFile = new File(inferEnvDirectory + File.separator + LOCK_FILE); - if (!lockFile.exists()) { - boolean createLock = lockFile.createNewFile(); - LOGGER.info("{} create lock file result {}", - environmentContext.getRoleNameIndex(), createLock); - } - - lock = InferFileUtils.addLock(lockFile); - File finishFile = new File(inferEnvDirectory + File.separator + FINISH_FILE); - File failedFile = new File(inferEnvDirectory + File.separator + FAILED_FILE); - if (failedFile.exists()) { - environmentContext.setFinished(false); - LOGGER.warn("{} create infer environment failed", environmentContext.getRoleNameIndex()); - return environmentContext; - } - if (finishFile.exists()) { - environmentContext.setFinished(true); - LOGGER.info("{} create infer environment finished", environmentContext.getRoleNameIndex()); - return environmentContext; + } + + public void createEnvironment() { + if (INITIALIZED.compareAndSet(false, true)) { + executorService.execute( + () -> { + try { + environmentContext = constructInferEnvironment(configuration); + if (environmentContext.enableFinished()) { + SUCCESS_FLAG.set(true); + LOGGER.info( + "{} create infer environment finished", environmentContext.getRoleNameIndex()); + } + } catch (Throwable e) { + SUCCESS_FLAG.set(false); + ERROR_CASE.set(e); + LOGGER.error("execute install infer environment error", e); } - InferDependencyManager inferDependencyManager = new InferDependencyManager(environmentContext); - boolean createFinished = createInferVirtualEnv(inferDependencyManager, environmentContext.getVirtualEnvDirectory()); - environmentContext.setFinished(createFinished); - if (createFinished) { - finishFile.createNewFile(); - } else { - failedFile.createNewFile(); - throw new GeaflowRuntimeException("execute virtual env shell failed"); - } - } catch (Throwable e) { - ERROR_CASE.set(e); - LOGGER.error("construct infer environment failed", e); - } finally { - if (lock != null) { - releaseLock(lock); - } - } - return environmentContext; + }); } - - private boolean createInferVirtualEnv(InferDependencyManager dependencyManager, String workingDir) { - String shellPath = dependencyManager.getBuildInferEnvShellPath(); - List execParams = new ArrayList<>(); - String requirementsPath = dependencyManager.getInferEnvRequirementsPath(); - execParams.add(workingDir); - execParams.add(requirementsPath); - String conda = configuration.getString(INFER_ENV_CONDA_URL); - execParams.add(conda); - List shellCommand = new ArrayList<>(Arrays.asList(SHELL_START, shellPath)); - shellCommand.addAll(execParams); - String cmd = Joiner.on(" ").join(shellCommand); - LOGGER.info("create infer virtual env {}", cmd); - - // Run "chmod 755 $shellPath" - List runCommands = new ArrayList<>(); - runCommands.add(CHMOD_CMD); - runCommands.add(CHMOD_PERMISSION); - runCommands.add(shellPath); - String chmodCmd = Joiner.on(SCRIPT_SEPARATOR).join(runCommands); - LOGGER.info("change {} permission run command is {}", shellPath, chmodCmd); - int installEnvTimeOut = configuration.getInteger(FrameworkConfigKeys.INFER_ENV_INIT_TIMEOUT_SEC); - if (!ShellExecUtils.run(chmodCmd, Duration.ofSeconds(installEnvTimeOut), LOGGER::info, LOGGER::error)) { - return false; - } - return ShellExecUtils.run(cmd, Duration.ofSeconds(installEnvTimeOut), LOGGER::info, LOGGER::error, workingDir); - } - - @Override - public void close() throws Exception { - if (executorService != null) { - executorService.shutdownNow(); - } + } + + private InferEnvironmentContext constructInferEnvironment(Configuration configuration) { + String inferEnvDirectory = InferFileUtils.createTargetDir(VIRTUAL_ENV_DIR, configuration); + String inferFilesDirectory = InferFileUtils.createTargetDir(INFER_FILES_DIR, configuration); + + InferEnvironmentContext environmentContext = + new InferEnvironmentContext(inferEnvDirectory, inferFilesDirectory, configuration); + File lockFile; + FileLock lock = null; + try { + lockFile = new File(inferEnvDirectory + File.separator + LOCK_FILE); + if (!lockFile.exists()) { + boolean createLock = lockFile.createNewFile(); + LOGGER.info( + "{} create lock file result {}", environmentContext.getRoleNameIndex(), createLock); + } + + lock = InferFileUtils.addLock(lockFile); + File finishFile = new File(inferEnvDirectory + File.separator + FINISH_FILE); + File failedFile = new File(inferEnvDirectory + File.separator + FAILED_FILE); + if (failedFile.exists()) { + environmentContext.setFinished(false); + LOGGER.warn("{} create infer environment failed", environmentContext.getRoleNameIndex()); + return environmentContext; + } + if (finishFile.exists()) { + environmentContext.setFinished(true); + LOGGER.info("{} create infer environment finished", environmentContext.getRoleNameIndex()); + return environmentContext; + } + InferDependencyManager inferDependencyManager = + new InferDependencyManager(environmentContext); + boolean createFinished = + createInferVirtualEnv( + inferDependencyManager, environmentContext.getVirtualEnvDirectory()); + environmentContext.setFinished(createFinished); + if (createFinished) { + finishFile.createNewFile(); + } else { + failedFile.createNewFile(); + throw new GeaflowRuntimeException("execute virtual env shell failed"); + } + } catch (Throwable e) { + ERROR_CASE.set(e); + LOGGER.error("construct infer environment failed", e); + } finally { + if (lock != null) { + releaseLock(lock); + } } - - public static Boolean checkInferEnvironmentStatus() { - return SUCCESS_FLAG.get(); + return environmentContext; + } + + private boolean createInferVirtualEnv( + InferDependencyManager dependencyManager, String workingDir) { + String shellPath = dependencyManager.getBuildInferEnvShellPath(); + List execParams = new ArrayList<>(); + String requirementsPath = dependencyManager.getInferEnvRequirementsPath(); + execParams.add(workingDir); + execParams.add(requirementsPath); + String conda = configuration.getString(INFER_ENV_CONDA_URL); + execParams.add(conda); + List shellCommand = new ArrayList<>(Arrays.asList(SHELL_START, shellPath)); + shellCommand.addAll(execParams); + String cmd = Joiner.on(" ").join(shellCommand); + LOGGER.info("create infer virtual env {}", cmd); + + // Run "chmod 755 $shellPath" + List runCommands = new ArrayList<>(); + runCommands.add(CHMOD_CMD); + runCommands.add(CHMOD_PERMISSION); + runCommands.add(shellPath); + String chmodCmd = Joiner.on(SCRIPT_SEPARATOR).join(runCommands); + LOGGER.info("change {} permission run command is {}", shellPath, chmodCmd); + int installEnvTimeOut = + configuration.getInteger(FrameworkConfigKeys.INFER_ENV_INIT_TIMEOUT_SEC); + if (!ShellExecUtils.run( + chmodCmd, Duration.ofSeconds(installEnvTimeOut), LOGGER::info, LOGGER::error)) { + return false; } - - public static InferEnvironmentContext getEnvironmentContext() { - return environmentContext; + return ShellExecUtils.run( + cmd, Duration.ofSeconds(installEnvTimeOut), LOGGER::info, LOGGER::error, workingDir); + } + + @Override + public void close() throws Exception { + if (executorService != null) { + executorService.shutdownNow(); } - - public static void checkError() { - final Throwable exception = ERROR_CASE.get(); - if (exception != null) { - String message = "create infer environment failed: " + exception.getMessage(); - LOGGER.error(message); - throw new GeaflowRuntimeException(message, exception); - } + } + + public static Boolean checkInferEnvironmentStatus() { + return SUCCESS_FLAG.get(); + } + + public static InferEnvironmentContext getEnvironmentContext() { + return environmentContext; + } + + public static void checkError() { + final Throwable exception = ERROR_CASE.get(); + if (exception != null) { + String message = "create infer environment failed: " + exception.getMessage(); + LOGGER.error(message); + throw new GeaflowRuntimeException(message, exception); } - + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferTaskRun.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferTaskRun.java index d7c4b8b75..de9c6a315 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferTaskRun.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferTaskRun.java @@ -23,13 +23,9 @@ public interface InferTaskRun { - /** - * Run infer task. - */ - void run(List script); + /** Run infer task. */ + void run(List script); - /** - * Stop infer task. - */ - void stop(); + /** Stop infer task. */ + void stop(); } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferTaskRunImpl.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferTaskRunImpl.java index bfd02c7a4..60a33b5a2 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferTaskRunImpl.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferTaskRunImpl.java @@ -21,7 +21,6 @@ import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.INFER_USER_DEFINE_LIB_PATH; import static org.apache.geaflow.infer.InferTaskStatus.FAILED; -import com.google.common.base.Joiner; import java.io.File; import java.util.ArrayList; import java.util.Arrays; @@ -29,6 +28,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.infer.log.ProcessLoggerManager; @@ -36,114 +36,118 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Joiner; + public class InferTaskRunImpl implements InferTaskRun { - private static final Logger LOGGER = LoggerFactory.getLogger(InferTaskRunImpl.class); + private static final Logger LOGGER = LoggerFactory.getLogger(InferTaskRunImpl.class); - private static final File NULL_FILE = new File((System.getProperty("os.name").startsWith( - "Windows") ? "NUL" : "/dev/null")); + private static final File NULL_FILE = + new File((System.getProperty("os.name").startsWith("Windows") ? "NUL" : "/dev/null")); - private static final long TIMEOUT_SECOND = 10; - private static final String SCRIPT_SEPARATOR = " "; - private static final String LD_LIBRARY_PATH = "LD_LIBRARY_PATH"; - private static final String PATH = "PATH"; - private static final String PATH_REGEX = ":"; - private static final String PYTHON_PATH = "PYTHONPATH"; - private final InferEnvironmentContext inferEnvironmentContext; - private final Configuration jobConfig; - private final String virtualEnvPath; - private final String inferFilePath; - private final String executePath; - private Process inferTask; - private String inferScript; + private static final long TIMEOUT_SECOND = 10; + private static final String SCRIPT_SEPARATOR = " "; + private static final String LD_LIBRARY_PATH = "LD_LIBRARY_PATH"; + private static final String PATH = "PATH"; + private static final String PATH_REGEX = ":"; + private static final String PYTHON_PATH = "PYTHONPATH"; + private final InferEnvironmentContext inferEnvironmentContext; + private final Configuration jobConfig; + private final String virtualEnvPath; + private final String inferFilePath; + private final String executePath; + private Process inferTask; + private String inferScript; - private InferTaskStatus inferTaskStatus; + private InferTaskStatus inferTaskStatus; - public InferTaskRunImpl(InferEnvironmentContext inferEnvironmentContext) { - this.inferEnvironmentContext = inferEnvironmentContext; - this.jobConfig = inferEnvironmentContext.getJobConfig(); - this.inferFilePath = inferEnvironmentContext.getInferFilesDirectory(); - this.virtualEnvPath = inferEnvironmentContext.getVirtualEnvDirectory(); - this.executePath = this.virtualEnvPath + "/bin"; - } + public InferTaskRunImpl(InferEnvironmentContext inferEnvironmentContext) { + this.inferEnvironmentContext = inferEnvironmentContext; + this.jobConfig = inferEnvironmentContext.getJobConfig(); + this.inferFilePath = inferEnvironmentContext.getInferFilesDirectory(); + this.virtualEnvPath = inferEnvironmentContext.getVirtualEnvDirectory(); + this.executePath = this.virtualEnvPath + "/bin"; + } - @Override - public void run(List script) { - inferScript = Joiner.on(SCRIPT_SEPARATOR).join(script); - LOGGER.info("infer task run command is {}", inferScript); - ProcessBuilder inferTaskBuilder = new ProcessBuilder(script); - buildInferTaskBuilder(inferTaskBuilder); - try { - inferTask = inferTaskBuilder.start(); - this.inferTaskStatus = InferTaskStatus.RUNNING; - try (ProcessLoggerManager processLogger = new ProcessLoggerManager(inferTask, new Slf4JProcessOutputConsumer(this.getClass().getSimpleName()))) { - processLogger.startLogging(); - int exitValue = 0; - if (inferTask.waitFor(TIMEOUT_SECOND, TimeUnit.SECONDS)) { - exitValue = inferTask.exitValue(); - this.inferTaskStatus = FAILED; - } else { - this.inferTaskStatus = InferTaskStatus.RUNNING; - } - if (exitValue != 0) { - throw new GeaflowRuntimeException( - String.format("infer task [%s] run failed, exitCode is %d, message is " - + "%s", inferScript, exitValue, processLogger.getErrorOutputLogger().get())); - } - } - } catch (Exception e) { - throw new GeaflowRuntimeException("infer task run failed", e); - } finally { - if (inferTask != null && inferTaskStatus.equals(FAILED)) { - inferTask.destroyForcibly(); - } + @Override + public void run(List script) { + inferScript = Joiner.on(SCRIPT_SEPARATOR).join(script); + LOGGER.info("infer task run command is {}", inferScript); + ProcessBuilder inferTaskBuilder = new ProcessBuilder(script); + buildInferTaskBuilder(inferTaskBuilder); + try { + inferTask = inferTaskBuilder.start(); + this.inferTaskStatus = InferTaskStatus.RUNNING; + try (ProcessLoggerManager processLogger = + new ProcessLoggerManager( + inferTask, new Slf4JProcessOutputConsumer(this.getClass().getSimpleName()))) { + processLogger.startLogging(); + int exitValue = 0; + if (inferTask.waitFor(TIMEOUT_SECOND, TimeUnit.SECONDS)) { + exitValue = inferTask.exitValue(); + this.inferTaskStatus = FAILED; + } else { + this.inferTaskStatus = InferTaskStatus.RUNNING; } - } - - @Override - public void stop() { - if (inferTask != null) { - inferTask.destroyForcibly(); + if (exitValue != 0) { + throw new GeaflowRuntimeException( + String.format( + "infer task [%s] run failed, exitCode is %d, message is " + "%s", + inferScript, exitValue, processLogger.getErrorOutputLogger().get())); } + } + } catch (Exception e) { + throw new GeaflowRuntimeException("infer task run failed", e); + } finally { + if (inferTask != null && inferTaskStatus.equals(FAILED)) { + inferTask.destroyForcibly(); + } } + } - private void buildInferTaskBuilder(ProcessBuilder processBuilder) { - Map environment = processBuilder.environment(); - environment.put(PATH, executePath); - processBuilder.directory(new File(this.inferFilePath)); - processBuilder.redirectErrorStream(true); - setLibraryPath(processBuilder); - environment.computeIfAbsent(PYTHON_PATH, k -> virtualEnvPath); - processBuilder.redirectOutput(NULL_FILE); + @Override + public void stop() { + if (inferTask != null) { + inferTask.destroyForcibly(); } + } + private void buildInferTaskBuilder(ProcessBuilder processBuilder) { + Map environment = processBuilder.environment(); + environment.put(PATH, executePath); + processBuilder.directory(new File(this.inferFilePath)); + processBuilder.redirectErrorStream(true); + setLibraryPath(processBuilder); + environment.computeIfAbsent(PYTHON_PATH, k -> virtualEnvPath); + processBuilder.redirectOutput(NULL_FILE); + } - private void setLibraryPath(ProcessBuilder processBuilder) { - List userDefineLibPath = getUserDefineLibPath(); - StringBuilder libBuilder = new StringBuilder(); - libBuilder.append(this.inferEnvironmentContext.getInferLibPath()); - libBuilder.append(PATH_REGEX); - for (String ldLibraryPath : userDefineLibPath) { - libBuilder.append(ldLibraryPath); - libBuilder.append(PATH_REGEX); - } - String ldLibraryPathEnvVar = System.getenv(LD_LIBRARY_PATH); - libBuilder.append(ldLibraryPathEnvVar); - processBuilder.environment().put(LD_LIBRARY_PATH, libBuilder.toString()); + private void setLibraryPath(ProcessBuilder processBuilder) { + List userDefineLibPath = getUserDefineLibPath(); + StringBuilder libBuilder = new StringBuilder(); + libBuilder.append(this.inferEnvironmentContext.getInferLibPath()); + libBuilder.append(PATH_REGEX); + for (String ldLibraryPath : userDefineLibPath) { + libBuilder.append(ldLibraryPath); + libBuilder.append(PATH_REGEX); } + String ldLibraryPathEnvVar = System.getenv(LD_LIBRARY_PATH); + libBuilder.append(ldLibraryPathEnvVar); + processBuilder.environment().put(LD_LIBRARY_PATH, libBuilder.toString()); + } - private List getUserDefineLibPath() { - String userLibPath = jobConfig.getString(INFER_USER_DEFINE_LIB_PATH.getKey()); - List result = new ArrayList<>(); - if (userLibPath != null) { - String[] libs = userLibPath.split(","); - Iterator iterator = Arrays.stream(libs).iterator(); - while (iterator.hasNext()) { - String libPath = this.inferFilePath + File.separator + iterator.next().trim(); - LOGGER.info("define infer lib path is {}", libPath); - result.add(libPath); - } - } - return result; + private List getUserDefineLibPath() { + String userLibPath = jobConfig.getString(INFER_USER_DEFINE_LIB_PATH.getKey()); + List result = new ArrayList<>(); + if (userLibPath != null) { + String[] libs = userLibPath.split(","); + Iterator iterator = Arrays.stream(libs).iterator(); + while (iterator.hasNext()) { + String libPath = this.inferFilePath + File.separator + iterator.next().trim(); + LOGGER.info("define infer lib path is {}", libPath); + result.add(libPath); + } } + return result; + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferTaskStatus.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferTaskStatus.java index 2445a7f90..e69e266ec 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferTaskStatus.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/InferTaskStatus.java @@ -21,24 +21,15 @@ public enum InferTaskStatus { - /** - * Prepare infer task. - */ - RUNNING, + /** Prepare infer task. */ + RUNNING, - /** - * Infer task run succeed. - */ - SUCCEED, + /** Infer task run succeed. */ + SUCCEED, - /** - * Infer task run failed. - */ - FAILED, - - /** - * Stop infer task. - */ - KILLED + /** Infer task run failed. */ + FAILED, + /** Stop infer task. */ + KILLED } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeContext.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeContext.java index 417e72703..d1a6a9985 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeContext.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeContext.java @@ -24,6 +24,7 @@ import java.io.Closeable; import java.io.File; import java.io.IOException; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -31,82 +32,83 @@ public class DataExchangeContext implements Closeable { - private static final String FILE_KEY_PREFIX = "queue-"; - - private static final int INT_SIZE = 4; - - private static final String MMAP_KEY_PREFIX = "queue://"; - private static final String SHARE_MEMORY_DIR = "/infer_data"; - private static final String KEY_SEPARATOR = ":"; - - private static final String MMAP_INPUT_KEY_SUFFIX = ".input"; - private static final String MMAP_OUTPUT_KEY_SUFFIX = ".output"; - private final File localDirectory; - private final long queueEndpoint; - private final DataExchangeQueue receiveQueue; - private final DataExchangeQueue sendQueue; - - private final File receiveQueueFile; - private final File sendQueueFile; - private String receivePath; - private String sendPath; - - public DataExchangeContext(Configuration config) { - this.localDirectory = new File(InferFileUtils.getInferDirectory(config) + SHARE_MEMORY_DIR); - this.queueEndpoint = UnSafeUtils.UNSAFE.allocateMemory(INT_SIZE); - UnSafeUtils.UNSAFE.setMemory(queueEndpoint, INT_SIZE, (byte) 0); - this.receiveQueueFile = createTempFile(FILE_KEY_PREFIX, MMAP_INPUT_KEY_SUFFIX); - this.sendQueueFile = createTempFile(FILE_KEY_PREFIX, MMAP_OUTPUT_KEY_SUFFIX); - this.receivePath = receiveQueueFile.getAbsolutePath(); - this.sendPath = sendQueueFile.getAbsolutePath(); - int queueCapacity = config.getInteger(INFER_ENV_SHARE_MEMORY_QUEUE_SIZE); - this.receiveQueue = new DataExchangeQueue(receivePath, queueCapacity, true); - this.sendQueue = new DataExchangeQueue(sendPath, queueCapacity, true); - Runtime.getRuntime().addShutdownHook(new Thread(() -> UnSafeUtils.UNSAFE.freeMemory(queueEndpoint))); + private static final String FILE_KEY_PREFIX = "queue-"; + + private static final int INT_SIZE = 4; + + private static final String MMAP_KEY_PREFIX = "queue://"; + private static final String SHARE_MEMORY_DIR = "/infer_data"; + private static final String KEY_SEPARATOR = ":"; + + private static final String MMAP_INPUT_KEY_SUFFIX = ".input"; + private static final String MMAP_OUTPUT_KEY_SUFFIX = ".output"; + private final File localDirectory; + private final long queueEndpoint; + private final DataExchangeQueue receiveQueue; + private final DataExchangeQueue sendQueue; + + private final File receiveQueueFile; + private final File sendQueueFile; + private String receivePath; + private String sendPath; + + public DataExchangeContext(Configuration config) { + this.localDirectory = new File(InferFileUtils.getInferDirectory(config) + SHARE_MEMORY_DIR); + this.queueEndpoint = UnSafeUtils.UNSAFE.allocateMemory(INT_SIZE); + UnSafeUtils.UNSAFE.setMemory(queueEndpoint, INT_SIZE, (byte) 0); + this.receiveQueueFile = createTempFile(FILE_KEY_PREFIX, MMAP_INPUT_KEY_SUFFIX); + this.sendQueueFile = createTempFile(FILE_KEY_PREFIX, MMAP_OUTPUT_KEY_SUFFIX); + this.receivePath = receiveQueueFile.getAbsolutePath(); + this.sendPath = sendQueueFile.getAbsolutePath(); + int queueCapacity = config.getInteger(INFER_ENV_SHARE_MEMORY_QUEUE_SIZE); + this.receiveQueue = new DataExchangeQueue(receivePath, queueCapacity, true); + this.sendQueue = new DataExchangeQueue(sendPath, queueCapacity, true); + Runtime.getRuntime() + .addShutdownHook(new Thread(() -> UnSafeUtils.UNSAFE.freeMemory(queueEndpoint))); + } + + public String getReceiveQueueKey() { + return MMAP_KEY_PREFIX + receivePath + KEY_SEPARATOR + receiveQueue.getMemoryMapSize(); + } + + public String getSendQueueKey() { + return MMAP_KEY_PREFIX + sendPath + KEY_SEPARATOR + sendQueue.getMemoryMapSize(); + } + + @Override + public synchronized void close() throws IOException { + if (receiveQueue != null) { + receiveQueue.close(); } - - public String getReceiveQueueKey() { - return MMAP_KEY_PREFIX + receivePath + KEY_SEPARATOR + receiveQueue.getMemoryMapSize(); + if (sendQueue != null) { + sendQueue.close(); } - - public String getSendQueueKey() { - return MMAP_KEY_PREFIX + sendPath + KEY_SEPARATOR + sendQueue.getMemoryMapSize(); + if (receiveQueueFile != null) { + receiveQueueFile.delete(); } - - @Override - public synchronized void close() throws IOException { - if (receiveQueue != null) { - receiveQueue.close(); - } - if (sendQueue != null) { - sendQueue.close(); - } - if (receiveQueueFile != null) { - receiveQueueFile.delete(); - } - if (sendQueueFile != null) { - sendQueueFile.delete(); - } - UnSafeUtils.UNSAFE.freeMemory(this.queueEndpoint); - FileUtils.deleteQuietly(localDirectory); + if (sendQueueFile != null) { + sendQueueFile.delete(); } - - public DataExchangeQueue getReceiveQueue() { - return receiveQueue; - } - - public DataExchangeQueue getSendQueue() { - return sendQueue; - } - - private File createTempFile(String prefix, String suffix) { - try { - if (!localDirectory.exists()) { - InferFileUtils.forceMkdir(localDirectory); - } - return File.createTempFile(prefix, suffix, localDirectory); - } catch (IOException e) { - throw new GeaflowRuntimeException("create temp file on infer directory failed ", e); - } + UnSafeUtils.UNSAFE.freeMemory(this.queueEndpoint); + FileUtils.deleteQuietly(localDirectory); + } + + public DataExchangeQueue getReceiveQueue() { + return receiveQueue; + } + + public DataExchangeQueue getSendQueue() { + return sendQueue; + } + + private File createTempFile(String prefix, String suffix) { + try { + if (!localDirectory.exists()) { + InferFileUtils.forceMkdir(localDirectory); + } + return File.createTempFile(prefix, suffix, localDirectory); + } catch (IOException e) { + throw new GeaflowRuntimeException("create temp file on infer directory failed ", e); } + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeQueue.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeQueue.java index 29057f60e..4eb926cd9 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeQueue.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataExchangeQueue.java @@ -21,148 +21,147 @@ import java.io.Closeable; import java.util.concurrent.atomic.AtomicBoolean; + import org.jctools.util.PortableJvmInfo; import org.jctools.util.Pow2; public final class DataExchangeQueue implements Closeable { - private static final AtomicBoolean CLOSED = new AtomicBoolean(false); - private final long outputNextAddress; - private final long capacityAddress; - private final long outputAddress; - private final long inputNextAddress; - private final long barrierAddress; - private final long currentBufferAddress; - private final long mapAddress; - private final int queueCapacity; - private final int bufferCapacity; - private final long initialRawAddress; - - private final long startPointAddress; - - private final long endPointAddress; - private final MemoryMapper memoryMapper; - - public DataExchangeQueue(String mapKey, int capacity, boolean reset) { - this.bufferCapacity = getBufferCapacity(capacity); - this.memoryMapper = new MemoryMapper(mapKey, - bufferCapacity + PortableJvmInfo.CACHE_LINE_SIZE); - this.mapAddress = memoryMapper.getMapAddress(); - this.queueCapacity = Pow2.roundToPowerOfTwo(capacity); - this.initialRawAddress = Pow2.align(mapAddress, PortableJvmInfo.CACHE_LINE_SIZE); - this.startPointAddress = initialRawAddress; - this.capacityAddress = startPointAddress + PortableJvmInfo.CACHE_LINE_SIZE; - this.outputAddress = startPointAddress + 2L * PortableJvmInfo.CACHE_LINE_SIZE; - this.inputNextAddress = outputAddress + 8; - this.outputNextAddress = startPointAddress + 8; - this.endPointAddress = outputAddress + PortableJvmInfo.CACHE_LINE_SIZE; - this.barrierAddress = endPointAddress + 8; - this.currentBufferAddress = barrierAddress + 8; - if (reset) { - reset(); - } - } - - - @Override - public synchronized void close() { - CLOSED.set(true); - if (memoryMapper != null) { - memoryMapper.close(); - } - UnSafeUtils.UNSAFE.freeMemory(mapAddress); - } - - public long getMemoryMapSize() { - if (memoryMapper == null) { - return 0; - } - return memoryMapper.getMapSize(); - } - - public long getInputPointer() { - return UnSafeUtils.UNSAFE.getLong(null, startPointAddress); - } - - public long getInputPointerByVolatile() { - return UnSafeUtils.UNSAFE.getLongVolatile(null, startPointAddress); - } - - public void setInputPointer(long value) { - UnSafeUtils.UNSAFE.putOrderedLong(null, startPointAddress, value); - } - - public long getOutputPointer() { - return UnSafeUtils.UNSAFE.getLong(null, outputAddress); - } - - public long getOutputPointerByVolatile() { - return UnSafeUtils.UNSAFE.getLongVolatile(null, outputAddress); - } - - public void setOutputPointer(long value) { - UnSafeUtils.UNSAFE.putOrderedLong(null, outputAddress, value); - } - - public long getInputNextPointer() { - return UnSafeUtils.UNSAFE.getLong(null, inputNextAddress); - } - - public void setInputNextPointer(final long value) { - UnSafeUtils.UNSAFE.putLong(inputNextAddress, value); - } - - public long getOutputNextPointer() { - return UnSafeUtils.UNSAFE.getLong(null, outputNextAddress); - } - - public void setOutputNextPointer(final long value) { - UnSafeUtils.UNSAFE.putLong(outputNextAddress, value); - } - - public long getBarrierAddress() { - return UnSafeUtils.UNSAFE.getLong(barrierAddress); - } - - public long getCurrentBufferAddress() { - return UnSafeUtils.UNSAFE.getLong(currentBufferAddress); - } - - public boolean enableFinished() { - return UnSafeUtils.UNSAFE.getLongVolatile(null, endPointAddress) != 0; - } - - public synchronized void markFinished() { - if (!CLOSED.get()) { - UnSafeUtils.UNSAFE.putOrderedLong(null, endPointAddress, -1); - } - } - - public long getInitialQueueAddress() { - return initialRawAddress + 4L * PortableJvmInfo.CACHE_LINE_SIZE; - } - - public int getQueueMask() { - return this.queueCapacity - 1; - } - - public int getQueueCapacity() { - return queueCapacity; - } - - public int getBufferCapacity(int capacity) { - return (Pow2.roundToPowerOfTwo(capacity)) + 4 * PortableJvmInfo.CACHE_LINE_SIZE; - } - - public void reset() { - UnSafeUtils.UNSAFE.setMemory(initialRawAddress, bufferCapacity, (byte) 0); - UnSafeUtils.UNSAFE.putLongVolatile(null, capacityAddress, queueCapacity); - } + private static final AtomicBoolean CLOSED = new AtomicBoolean(false); + private final long outputNextAddress; + private final long capacityAddress; + private final long outputAddress; + private final long inputNextAddress; + private final long barrierAddress; + private final long currentBufferAddress; + private final long mapAddress; + private final int queueCapacity; + private final int bufferCapacity; + private final long initialRawAddress; + + private final long startPointAddress; + + private final long endPointAddress; + private final MemoryMapper memoryMapper; + + public DataExchangeQueue(String mapKey, int capacity, boolean reset) { + this.bufferCapacity = getBufferCapacity(capacity); + this.memoryMapper = new MemoryMapper(mapKey, bufferCapacity + PortableJvmInfo.CACHE_LINE_SIZE); + this.mapAddress = memoryMapper.getMapAddress(); + this.queueCapacity = Pow2.roundToPowerOfTwo(capacity); + this.initialRawAddress = Pow2.align(mapAddress, PortableJvmInfo.CACHE_LINE_SIZE); + this.startPointAddress = initialRawAddress; + this.capacityAddress = startPointAddress + PortableJvmInfo.CACHE_LINE_SIZE; + this.outputAddress = startPointAddress + 2L * PortableJvmInfo.CACHE_LINE_SIZE; + this.inputNextAddress = outputAddress + 8; + this.outputNextAddress = startPointAddress + 8; + this.endPointAddress = outputAddress + PortableJvmInfo.CACHE_LINE_SIZE; + this.barrierAddress = endPointAddress + 8; + this.currentBufferAddress = barrierAddress + 8; + if (reset) { + reset(); + } + } + + @Override + public synchronized void close() { + CLOSED.set(true); + if (memoryMapper != null) { + memoryMapper.close(); + } + UnSafeUtils.UNSAFE.freeMemory(mapAddress); + } + + public long getMemoryMapSize() { + if (memoryMapper == null) { + return 0; + } + return memoryMapper.getMapSize(); + } + + public long getInputPointer() { + return UnSafeUtils.UNSAFE.getLong(null, startPointAddress); + } + + public long getInputPointerByVolatile() { + return UnSafeUtils.UNSAFE.getLongVolatile(null, startPointAddress); + } + + public void setInputPointer(long value) { + UnSafeUtils.UNSAFE.putOrderedLong(null, startPointAddress, value); + } + + public long getOutputPointer() { + return UnSafeUtils.UNSAFE.getLong(null, outputAddress); + } + + public long getOutputPointerByVolatile() { + return UnSafeUtils.UNSAFE.getLongVolatile(null, outputAddress); + } + + public void setOutputPointer(long value) { + UnSafeUtils.UNSAFE.putOrderedLong(null, outputAddress, value); + } + + public long getInputNextPointer() { + return UnSafeUtils.UNSAFE.getLong(null, inputNextAddress); + } + + public void setInputNextPointer(final long value) { + UnSafeUtils.UNSAFE.putLong(inputNextAddress, value); + } + + public long getOutputNextPointer() { + return UnSafeUtils.UNSAFE.getLong(null, outputNextAddress); + } + + public void setOutputNextPointer(final long value) { + UnSafeUtils.UNSAFE.putLong(outputNextAddress, value); + } + + public long getBarrierAddress() { + return UnSafeUtils.UNSAFE.getLong(barrierAddress); + } + + public long getCurrentBufferAddress() { + return UnSafeUtils.UNSAFE.getLong(currentBufferAddress); + } + + public boolean enableFinished() { + return UnSafeUtils.UNSAFE.getLongVolatile(null, endPointAddress) != 0; + } + + public synchronized void markFinished() { + if (!CLOSED.get()) { + UnSafeUtils.UNSAFE.putOrderedLong(null, endPointAddress, -1); + } + } + + public long getInitialQueueAddress() { + return initialRawAddress + 4L * PortableJvmInfo.CACHE_LINE_SIZE; + } + + public int getQueueMask() { + return this.queueCapacity - 1; + } + + public int getQueueCapacity() { + return queueCapacity; + } + + public int getBufferCapacity(int capacity) { + return (Pow2.roundToPowerOfTwo(capacity)) + 4 * PortableJvmInfo.CACHE_LINE_SIZE; + } + + public void reset() { + UnSafeUtils.UNSAFE.setMemory(initialRawAddress, bufferCapacity, (byte) 0); + UnSafeUtils.UNSAFE.putLongVolatile(null, capacityAddress, queueCapacity); + } - public static long getNextPointIndex(long v, int capacity) { - if ((v & (capacity - 1)) == 0) { - return v + capacity; - } - return Pow2.align(v, capacity); + public static long getNextPointIndex(long v, int capacity) { + if ((v & (capacity - 1)) == 0) { + return v + capacity; } -} \ No newline at end of file + return Pow2.align(v, capacity); + } +} diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataQueueInputStream.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataQueueInputStream.java index e64970038..fda636b4b 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataQueueInputStream.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataQueueInputStream.java @@ -26,145 +26,146 @@ import java.io.InterruptedIOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class DataQueueInputStream extends InputStream { - private static final Logger LOGGER = LoggerFactory.getLogger(DataQueueInputStream.class); - private static final int BUFFER_SIZE = 10 * 1024; - private static final int INT_SIZE = 4; - private static final int SHORT_SIZE = 2; - private static final int LONG_SIZE = 8; - private final DataExchangeQueue dataExchangeQueue; - private final byte[] dataBufferArray; - private final ByteBuffer buffer; - private final int queueCapacity; - private final long initialAddress; - private final int queueMask; - - public DataQueueInputStream(DataExchangeQueue dataExchangeQueue) { - this.dataExchangeQueue = dataExchangeQueue; - this.queueCapacity = dataExchangeQueue.getQueueCapacity(); - this.initialAddress = dataExchangeQueue.getInitialQueueAddress(); - this.queueMask = dataExchangeQueue.getQueueMask(); - this.dataBufferArray = new byte[BUFFER_SIZE]; - this.buffer = ByteBuffer.wrap(dataBufferArray, 0, dataBufferArray.length); - buffer.order(ByteOrder.LITTLE_ENDIAN); + private static final Logger LOGGER = LoggerFactory.getLogger(DataQueueInputStream.class); + private static final int BUFFER_SIZE = 10 * 1024; + private static final int INT_SIZE = 4; + private static final int SHORT_SIZE = 2; + private static final int LONG_SIZE = 8; + private final DataExchangeQueue dataExchangeQueue; + private final byte[] dataBufferArray; + private final ByteBuffer buffer; + private final int queueCapacity; + private final long initialAddress; + private final int queueMask; + + public DataQueueInputStream(DataExchangeQueue dataExchangeQueue) { + this.dataExchangeQueue = dataExchangeQueue; + this.queueCapacity = dataExchangeQueue.getQueueCapacity(); + this.initialAddress = dataExchangeQueue.getInitialQueueAddress(); + this.queueMask = dataExchangeQueue.getQueueMask(); + this.dataBufferArray = new byte[BUFFER_SIZE]; + this.buffer = ByteBuffer.wrap(dataBufferArray, 0, dataBufferArray.length); + buffer.order(ByteOrder.LITTLE_ENDIAN); + } + + @Override + public int read() throws IOException { + int r = read(dataBufferArray, 0, 1); + if (r == 1) { + return dataBufferArray[0] & 0xFF; } - - @Override - public int read() throws IOException { - int r = read(dataBufferArray, 0, 1); - if (r == 1) { - return dataBufferArray[0] & 0xFF; - } - return -1; + return -1; + } + + @Override + public int read(byte[] buffer, int offset, int length) throws IOException { + int currentIndex = 0; + while (currentIndex < length) { + int currentLength; + try { + currentLength = readFully(buffer, currentIndex + offset, length - currentIndex); + } catch (InterruptedException e) { + InterruptedIOException interruptedIOException = new InterruptedIOException(e.getMessage()); + interruptedIOException.bytesTransferred = currentIndex; + LOGGER.error("read infer data failed", e); + throw interruptedIOException; + } + if (currentLength < 0) { + return currentIndex > 0 ? currentIndex : -1; + } + currentIndex += currentLength; } - - @Override - public int read(byte[] buffer, int offset, int length) throws IOException { - int currentIndex = 0; - while (currentIndex < length) { - int currentLength; - try { - currentLength = readFully(buffer, currentIndex + offset, length - currentIndex); - } catch (InterruptedException e) { - InterruptedIOException interruptedIOException = new InterruptedIOException(e.getMessage()); - interruptedIOException.bytesTransferred = currentIndex; - LOGGER.error("read infer data failed", e); - throw interruptedIOException; - } - if (currentLength < 0) { - return currentIndex > 0 ? currentIndex : -1; - } - currentIndex += currentLength; + return currentIndex; + } + + public int read(byte[] b, int size) throws IOException { + return read(b, 0, size); + } + + private int readFully(byte[] buffer, int offset, int length) throws InterruptedException { + long inputPointer = dataExchangeQueue.getInputPointer(); + long outputNextPointer = dataExchangeQueue.getOutputNextPointer(); + + while (inputPointer >= outputNextPointer) { + long outputPointer = dataExchangeQueue.getOutputPointerByVolatile(); + dataExchangeQueue.setOutputNextPointer(outputPointer); + outputNextPointer = dataExchangeQueue.getOutputNextPointer(); + if (inputPointer >= outputNextPointer) { + if (dataExchangeQueue.enableFinished()) { + long outputPointerByVolatile = dataExchangeQueue.getOutputPointerByVolatile(); + dataExchangeQueue.setOutputNextPointer(outputPointerByVolatile); + outputNextPointer = dataExchangeQueue.getOutputNextPointer(); + if (inputPointer >= outputNextPointer) { + return -1; + } + break; } - return currentIndex; + } } - - public int read(byte[] b, int size) throws IOException { - return read(b, 0, size); + long nextPointIndex = getNextPointIndex(inputPointer, queueCapacity); + int remainByteNum; + if (outputNextPointer > nextPointIndex) { + remainByteNum = (int) (nextPointIndex - inputPointer); + } else { + remainByteNum = (int) (outputNextPointer - inputPointer); } - - private int readFully(byte[] buffer, int offset, int length) throws InterruptedException { - long inputPointer = dataExchangeQueue.getInputPointer(); - long outputNextPointer = dataExchangeQueue.getOutputNextPointer(); - - while (inputPointer >= outputNextPointer) { - long outputPointer = dataExchangeQueue.getOutputPointerByVolatile(); - dataExchangeQueue.setOutputNextPointer(outputPointer); - outputNextPointer = dataExchangeQueue.getOutputNextPointer(); - if (inputPointer >= outputNextPointer) { - if (dataExchangeQueue.enableFinished()) { - long outputPointerByVolatile = dataExchangeQueue.getOutputPointerByVolatile(); - dataExchangeQueue.setOutputNextPointer(outputPointerByVolatile); - outputNextPointer = dataExchangeQueue.getOutputNextPointer(); - if (inputPointer >= outputNextPointer) { - return -1; - } - break; - } - } - } - long nextPointIndex = getNextPointIndex(inputPointer, queueCapacity); - int remainByteNum; - if (outputNextPointer > nextPointIndex) { - remainByteNum = (int) (nextPointIndex - inputPointer); - } else { - remainByteNum = (int) (outputNextPointer - inputPointer); - } - int readableNum = Math.min(remainByteNum, length); - long left = this.initialAddress + (inputPointer & this.queueMask); - int right = sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET + offset; - UnSafeUtils.UNSAFE.copyMemory(null, left, buffer, right, readableNum); - dataExchangeQueue.setInputPointer(inputPointer + readableNum); - return readableNum; - } - - @Override - public int available() { - final long currentRead = dataExchangeQueue.getInputPointer(); - long writeCache = dataExchangeQueue.getOutputNextPointer(); - if (currentRead >= writeCache) { - dataExchangeQueue.setOutputNextPointer(dataExchangeQueue.getOutputPointerByVolatile()); - writeCache = dataExchangeQueue.getOutputNextPointer(); - } - - int availRead = (int) (writeCache - currentRead); - if (availRead > 0) { - return availRead; - } - return 0; - } - - public int getInt() throws IOException { - read(dataBufferArray, INT_SIZE); - buffer.clear(); - return buffer.getInt(); - } - - public short getShort() throws IOException { - read(dataBufferArray, SHORT_SIZE); - buffer.clear(); - return buffer.getShort(); - } - - public long getLong() throws IOException { - read(dataBufferArray, LONG_SIZE); - buffer.clear(); - return buffer.getLong(); - } - - public double getDouble() throws IOException { - read(dataBufferArray, LONG_SIZE); - buffer.clear(); - return buffer.getDouble(); + int readableNum = Math.min(remainByteNum, length); + long left = this.initialAddress + (inputPointer & this.queueMask); + int right = sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET + offset; + UnSafeUtils.UNSAFE.copyMemory(null, left, buffer, right, readableNum); + dataExchangeQueue.setInputPointer(inputPointer + readableNum); + return readableNum; + } + + @Override + public int available() { + final long currentRead = dataExchangeQueue.getInputPointer(); + long writeCache = dataExchangeQueue.getOutputNextPointer(); + if (currentRead >= writeCache) { + dataExchangeQueue.setOutputNextPointer(dataExchangeQueue.getOutputPointerByVolatile()); + writeCache = dataExchangeQueue.getOutputNextPointer(); } - public float getFloat() throws IOException { - read(dataBufferArray, INT_SIZE); - buffer.clear(); - return buffer.getFloat(); + int availRead = (int) (writeCache - currentRead); + if (availRead > 0) { + return availRead; } -} \ No newline at end of file + return 0; + } + + public int getInt() throws IOException { + read(dataBufferArray, INT_SIZE); + buffer.clear(); + return buffer.getInt(); + } + + public short getShort() throws IOException { + read(dataBufferArray, SHORT_SIZE); + buffer.clear(); + return buffer.getShort(); + } + + public long getLong() throws IOException { + read(dataBufferArray, LONG_SIZE); + buffer.clear(); + return buffer.getLong(); + } + + public double getDouble() throws IOException { + read(dataBufferArray, LONG_SIZE); + buffer.clear(); + return buffer.getDouble(); + } + + public float getFloat() throws IOException { + read(dataBufferArray, INT_SIZE); + buffer.clear(); + return buffer.getFloat(); + } +} diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataQueueOutputStream.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataQueueOutputStream.java index b3114ab56..f5f71fd86 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataQueueOutputStream.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/DataQueueOutputStream.java @@ -25,78 +25,78 @@ import java.io.OutputStream; import java.nio.ByteBuffer; import java.nio.ByteOrder; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class DataQueueOutputStream extends OutputStream { - private static final int BUFFER_SIZE = 10 * 1024; - private final DataExchangeQueue dataExchangeQueue; - private final byte[] dataBufferArray; - private final ByteBuffer buffer; + private static final int BUFFER_SIZE = 10 * 1024; + private final DataExchangeQueue dataExchangeQueue; + private final byte[] dataBufferArray; + private final ByteBuffer buffer; - private final int queueCapacity; - private final int queueMask; + private final int queueCapacity; + private final int queueMask; - public DataQueueOutputStream(DataExchangeQueue dataExchangeQueue) { - this.dataExchangeQueue = dataExchangeQueue; - this.queueCapacity = dataExchangeQueue.getQueueCapacity(); - this.dataBufferArray = new byte[BUFFER_SIZE]; - this.queueMask = dataExchangeQueue.getQueueMask(); - this.buffer = ByteBuffer.wrap(dataBufferArray, 0, dataBufferArray.length); - buffer.order(ByteOrder.LITTLE_ENDIAN); - } + public DataQueueOutputStream(DataExchangeQueue dataExchangeQueue) { + this.dataExchangeQueue = dataExchangeQueue; + this.queueCapacity = dataExchangeQueue.getQueueCapacity(); + this.dataBufferArray = new byte[BUFFER_SIZE]; + this.queueMask = dataExchangeQueue.getQueueMask(); + this.buffer = ByteBuffer.wrap(dataBufferArray, 0, dataBufferArray.length); + buffer.order(ByteOrder.LITTLE_ENDIAN); + } - @Override - public void write(int b) throws IOException { - dataBufferArray[0] = (byte) (b & 0xff); - write(dataBufferArray, 0, 1); - } + @Override + public void write(int b) throws IOException { + dataBufferArray[0] = (byte) (b & 0xff); + write(dataBufferArray, 0, 1); + } - @Override - public void write(byte[] buffer, int offset, int size) throws IOException { - long outputPointer = dataExchangeQueue.getOutputPointer(); - long currentInputIndex = outputPointer - (queueCapacity - size); - while (dataExchangeQueue.getInputNextPointer() <= currentInputIndex - || dataExchangeQueue.getBarrierAddress() > dataExchangeQueue.getCurrentBufferAddress()) { + @Override + public void write(byte[] buffer, int offset, int size) throws IOException { + long outputPointer = dataExchangeQueue.getOutputPointer(); + long currentInputIndex = outputPointer - (queueCapacity - size); + while (dataExchangeQueue.getInputNextPointer() <= currentInputIndex + || dataExchangeQueue.getBarrierAddress() > dataExchangeQueue.getCurrentBufferAddress()) { - dataExchangeQueue.setInputNextPointer(dataExchangeQueue.getInputPointerByVolatile()); - if (dataExchangeQueue.getInputNextPointer() <= currentInputIndex - || dataExchangeQueue.getBarrierAddress() > dataExchangeQueue.getCurrentBufferAddress()) { - if (dataExchangeQueue.enableFinished()) { - throw new GeaflowRuntimeException("output queue is marked finished"); - } - Thread.yield(); - } + dataExchangeQueue.setInputNextPointer(dataExchangeQueue.getInputPointerByVolatile()); + if (dataExchangeQueue.getInputNextPointer() <= currentInputIndex + || dataExchangeQueue.getBarrierAddress() > dataExchangeQueue.getCurrentBufferAddress()) { + if (dataExchangeQueue.enableFinished()) { + throw new GeaflowRuntimeException("output queue is marked finished"); } + Thread.yield(); + } + } - int currentOutputNum = 0; - while (currentOutputNum < size) { - long nextPointIndex = getNextPointIndex(outputPointer, queueCapacity); - int remainNum = (int) (nextPointIndex - outputPointer); - int bytesToWrite = Math.min(size - currentOutputNum, remainNum); - int left = sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET + offset + currentOutputNum; - long right = dataExchangeQueue.getInitialQueueAddress() + (outputPointer & queueMask); + int currentOutputNum = 0; + while (currentOutputNum < size) { + long nextPointIndex = getNextPointIndex(outputPointer, queueCapacity); + int remainNum = (int) (nextPointIndex - outputPointer); + int bytesToWrite = Math.min(size - currentOutputNum, remainNum); + int left = sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET + offset + currentOutputNum; + long right = dataExchangeQueue.getInitialQueueAddress() + (outputPointer & queueMask); - UnSafeUtils.UNSAFE.copyMemory(buffer, left, null, right, bytesToWrite); - dataExchangeQueue.setOutputPointer(outputPointer + bytesToWrite); - currentOutputNum += bytesToWrite; - outputPointer += bytesToWrite; - } - dataExchangeQueue.setOutputPointer(outputPointer); + UnSafeUtils.UNSAFE.copyMemory(buffer, left, null, right, bytesToWrite); + dataExchangeQueue.setOutputPointer(outputPointer + bytesToWrite); + currentOutputNum += bytesToWrite; + outputPointer += bytesToWrite; } + dataExchangeQueue.setOutputPointer(outputPointer); + } - public boolean tryReserveBeforeWrite(int len) { - long outputPointer = dataExchangeQueue.getOutputPointer(); - long currentInputIndex = outputPointer - (queueCapacity - len); - if (dataExchangeQueue.getInputNextPointer() <= currentInputIndex) { - dataExchangeQueue.setInputNextPointer(dataExchangeQueue.getInputPointerByVolatile()); - } - long inputNextPointer = dataExchangeQueue.getInputNextPointer(); - return inputNextPointer > currentInputIndex; + public boolean tryReserveBeforeWrite(int len) { + long outputPointer = dataExchangeQueue.getOutputPointer(); + long currentInputIndex = outputPointer - (queueCapacity - len); + if (dataExchangeQueue.getInputNextPointer() <= currentInputIndex) { + dataExchangeQueue.setInputNextPointer(dataExchangeQueue.getInputPointerByVolatile()); } + long inputNextPointer = dataExchangeQueue.getInputNextPointer(); + return inputNextPointer > currentInputIndex; + } - @Override - public void close() { - dataExchangeQueue.markFinished(); - } + @Override + public void close() { + dataExchangeQueue.markFinished(); + } } - diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/IDataBridge.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/IDataBridge.java index 64c909bf5..4a90dac84 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/IDataBridge.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/IDataBridge.java @@ -24,8 +24,7 @@ public interface IDataBridge extends Closeable { - boolean write(Object... obj) throws IOException; - - OUT read() throws IOException; + boolean write(Object... obj) throws IOException; + OUT read() throws IOException; } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/IDecoder.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/IDecoder.java index e6f9d9b5c..e1ec8fdea 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/IDecoder.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/IDecoder.java @@ -23,5 +23,5 @@ public interface IDecoder extends Closeable { - T decode(byte[] bytes); + T decode(byte[] bytes); } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/IEncoder.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/IEncoder.java index e4b824fe1..0477f5740 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/IEncoder.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/IEncoder.java @@ -23,5 +23,5 @@ public interface IEncoder extends Closeable { - byte[] encode(Object object); + byte[] encode(Object object); } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/InferDataReader.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/InferDataReader.java index f3a661950..8adaed1cb 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/InferDataReader.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/InferDataReader.java @@ -19,7 +19,6 @@ package org.apache.geaflow.infer.exchange; -import com.google.common.base.Preconditions; import java.io.Closeable; import java.io.DataInputStream; import java.io.EOFException; @@ -27,55 +26,59 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.concurrent.atomic.AtomicBoolean; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; + public class InferDataReader implements Closeable { - private static final Logger LOGGER = LoggerFactory.getLogger(InferDataReader.class); - private static final AtomicBoolean END = new AtomicBoolean(false); - private static final int HEADER_LENGTH = 4; - private final DataInputStream input; + private static final Logger LOGGER = LoggerFactory.getLogger(InferDataReader.class); + private static final AtomicBoolean END = new AtomicBoolean(false); + private static final int HEADER_LENGTH = 4; + private final DataInputStream input; - public InferDataReader(DataExchangeQueue dataExchangeQueue) { - DataQueueInputStream dataQueueInputStream = new DataQueueInputStream(dataExchangeQueue); - this.input = new DataInputStream(dataQueueInputStream); - } + public InferDataReader(DataExchangeQueue dataExchangeQueue) { + DataQueueInputStream dataQueueInputStream = new DataQueueInputStream(dataExchangeQueue); + this.input = new DataInputStream(dataQueueInputStream); + } - public byte[] read() throws IOException { - byte[] buffer = new byte[HEADER_LENGTH]; - int bytesNum; - try { - bytesNum = input.read(buffer); - } catch (EOFException e) { - LOGGER.error("read infer data fail", e); - END.set(true); - return null; - } - if (bytesNum < 0) { - LOGGER.warn("read infer data size is {}", bytesNum); - END.set(true); - return null; - } - if (bytesNum < buffer.length) { - input.readFully(buffer, bytesNum, buffer.length - bytesNum); - } - int len = fromInt32LE(buffer); - byte[] data = new byte[len]; - input.readFully(data); - return data; + public byte[] read() throws IOException { + byte[] buffer = new byte[HEADER_LENGTH]; + int bytesNum; + try { + bytesNum = input.read(buffer); + } catch (EOFException e) { + LOGGER.error("read infer data fail", e); + END.set(true); + return null; } - - private int fromInt32LE(byte[] data) { - Preconditions.checkState(data.length == HEADER_LENGTH, String.format("read data header " - + "size %d, must be %d", data.length, HEADER_LENGTH)); - ByteBuffer byteBuffer = ByteBuffer.wrap(data); - byteBuffer.order(ByteOrder.LITTLE_ENDIAN); - return byteBuffer.getInt(); + if (bytesNum < 0) { + LOGGER.warn("read infer data size is {}", bytesNum); + END.set(true); + return null; } - - @Override - public void close() throws IOException { - input.close(); + if (bytesNum < buffer.length) { + input.readFully(buffer, bytesNum, buffer.length - bytesNum); } + int len = fromInt32LE(buffer); + byte[] data = new byte[len]; + input.readFully(data); + return data; + } + + private int fromInt32LE(byte[] data) { + Preconditions.checkState( + data.length == HEADER_LENGTH, + String.format("read data header " + "size %d, must be %d", data.length, HEADER_LENGTH)); + ByteBuffer byteBuffer = ByteBuffer.wrap(data); + byteBuffer.order(ByteOrder.LITTLE_ENDIAN); + return byteBuffer.getInt(); + } + + @Override + public void close() throws IOException { + input.close(); + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/InferDataWriter.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/InferDataWriter.java index 27a679101..ccf267d83 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/InferDataWriter.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/InferDataWriter.java @@ -26,41 +26,41 @@ public class InferDataWriter implements Closeable { - private static final int HEADER_LENGTH = 4; - private final DataQueueOutputStream outputStream; - private final byte[] dataHeaderBytes; - private final ByteBuffer headerByteBuffer; + private static final int HEADER_LENGTH = 4; + private final DataQueueOutputStream outputStream; + private final byte[] dataHeaderBytes; + private final ByteBuffer headerByteBuffer; - public InferDataWriter(DataExchangeQueue dataExchangeQueue) { - this.outputStream = new DataQueueOutputStream(dataExchangeQueue); - this.dataHeaderBytes = new byte[HEADER_LENGTH]; - this.headerByteBuffer = ByteBuffer.wrap(dataHeaderBytes); - this.headerByteBuffer.order(ByteOrder.LITTLE_ENDIAN); - } + public InferDataWriter(DataExchangeQueue dataExchangeQueue) { + this.outputStream = new DataQueueOutputStream(dataExchangeQueue); + this.dataHeaderBytes = new byte[HEADER_LENGTH]; + this.headerByteBuffer = ByteBuffer.wrap(dataHeaderBytes); + this.headerByteBuffer.order(ByteOrder.LITTLE_ENDIAN); + } - public boolean write(byte[] record, int offset, int length) throws IOException { - int outputSize = HEADER_LENGTH + (length - offset); - if (!outputStream.tryReserveBeforeWrite((outputSize))) { - return false; - } - byte[] headerData = extractHeaderData(length); - outputStream.write(headerData, 0, HEADER_LENGTH); - outputStream.write(record, offset, length); - return true; + public boolean write(byte[] record, int offset, int length) throws IOException { + int outputSize = HEADER_LENGTH + (length - offset); + if (!outputStream.tryReserveBeforeWrite((outputSize))) { + return false; } + byte[] headerData = extractHeaderData(length); + outputStream.write(headerData, 0, HEADER_LENGTH); + outputStream.write(record, offset, length); + return true; + } - public boolean write(byte[] record) throws IOException { - return write(record, 0, record.length); - } + public boolean write(byte[] record) throws IOException { + return write(record, 0, record.length); + } - private byte[] extractHeaderData(int data) { - headerByteBuffer.clear(); - headerByteBuffer.putInt(data); - return dataHeaderBytes; - } + private byte[] extractHeaderData(int data) { + headerByteBuffer.clear(); + headerByteBuffer.putInt(data); + return dataHeaderBytes; + } - @Override - public void close() throws IOException { - outputStream.close(); - } + @Override + public void close() throws IOException { + outputStream.close(); + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/MemoryMapper.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/MemoryMapper.java index dca358be9..2e0e70e12 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/MemoryMapper.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/MemoryMapper.java @@ -23,87 +23,86 @@ import java.io.RandomAccessFile; import java.lang.reflect.Method; import java.nio.channels.FileChannel; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class MemoryMapper implements Closeable { - private static final Logger LOGGER = LoggerFactory.getLogger(MemoryMapper.class); - private static final String MAP_0 = "map0"; - private static final String UNMAP_0 = "unmap0"; - private static final Method MEMORY_MAP_METHOD; - private static final String SHARE_MEMORY_MODE = "rw"; - private static final Method MEMORY_UN_MAP_METHOD; - private long mapAddress; - private final long mapSize; - private final String mapKey; - - static { - try { - MEMORY_MAP_METHOD = getMethod(sun.nio.ch.FileChannelImpl.class, MAP_0, int.class, long.class, long.class); - MEMORY_UN_MAP_METHOD = getMethod(sun.nio.ch.FileChannelImpl.class, UNMAP_0, long.class, long.class); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + private static final Logger LOGGER = LoggerFactory.getLogger(MemoryMapper.class); + private static final String MAP_0 = "map0"; + private static final String UNMAP_0 = "unmap0"; + private static final Method MEMORY_MAP_METHOD; + private static final String SHARE_MEMORY_MODE = "rw"; + private static final Method MEMORY_UN_MAP_METHOD; + private long mapAddress; + private final long mapSize; + private final String mapKey; + + static { + try { + MEMORY_MAP_METHOD = + getMethod(sun.nio.ch.FileChannelImpl.class, MAP_0, int.class, long.class, long.class); + MEMORY_UN_MAP_METHOD = + getMethod(sun.nio.ch.FileChannelImpl.class, UNMAP_0, long.class, long.class); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } - - public MemoryMapper(String mapKey, long size) { - this.mapKey = mapKey; - this.mapSize = roundTo4096(size); - map(mapKey, mapSize); - } - - - private void map(String mapKey, long mapSize) { - try { - RandomAccessFile backingFile = new RandomAccessFile(mapKey, SHARE_MEMORY_MODE); - backingFile.setLength(mapSize); - FileChannel ch = backingFile.getChannel(); - this.mapAddress = (long) MEMORY_MAP_METHOD.invoke(ch, 1, 0L, mapSize); - ch.close(); - backingFile.close(); - } catch (Throwable e) { - LOGGER.error("map memory key {} size {} failed", mapKey, mapSize); - throw new GeaflowRuntimeException("memory map failed", e); - } + } + + public MemoryMapper(String mapKey, long size) { + this.mapKey = mapKey; + this.mapSize = roundTo4096(size); + map(mapKey, mapSize); + } + + private void map(String mapKey, long mapSize) { + try { + RandomAccessFile backingFile = new RandomAccessFile(mapKey, SHARE_MEMORY_MODE); + backingFile.setLength(mapSize); + FileChannel ch = backingFile.getChannel(); + this.mapAddress = (long) MEMORY_MAP_METHOD.invoke(ch, 1, 0L, mapSize); + ch.close(); + backingFile.close(); + } catch (Throwable e) { + LOGGER.error("map memory key {} size {} failed", mapKey, mapSize); + throw new GeaflowRuntimeException("memory map failed", e); } - - - public void unMap() { - if (mapAddress != 0) { - try { - MEMORY_UN_MAP_METHOD.invoke(null, mapAddress, this.mapSize); - } catch (Throwable e) { - LOGGER.error("un map error"); - mapAddress = 0; - } - } + } + + public void unMap() { + if (mapAddress != 0) { + try { + MEMORY_UN_MAP_METHOD.invoke(null, mapAddress, this.mapSize); + } catch (Throwable e) { + LOGGER.error("un map error"); mapAddress = 0; + } } - - @Override - public void close() { - unMap(); - } - - public long getMapSize() { - return mapSize; - } - - public long getMapAddress() { - return mapAddress; - } - - - private static Method getMethod(Class cls, String name, Class... params) - throws Exception { - Method m = cls.getDeclaredMethod(name, params); - m.setAccessible(true); - return m; - } - - private long roundTo4096(long i) { - return (i + 0xfffL) & ~0xfffL; - } + mapAddress = 0; + } + + @Override + public void close() { + unMap(); + } + + public long getMapSize() { + return mapSize; + } + + public long getMapAddress() { + return mapAddress; + } + + private static Method getMethod(Class cls, String name, Class... params) throws Exception { + Method m = cls.getDeclaredMethod(name, params); + m.setAccessible(true); + return m; + } + + private long roundTo4096(long i) { + return (i + 0xfffL) & ~0xfffL; + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/UnSafeUtils.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/UnSafeUtils.java index 07d72fa01..1edd86288 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/UnSafeUtils.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/UnSafeUtils.java @@ -21,37 +21,38 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Field; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class UnSafeUtils { - private static final Logger LOGGER = LoggerFactory.getLogger(UnSafeUtils.class); - - private static final String THE_UNSAFE = "theUnsafe"; - public static final sun.misc.Unsafe UNSAFE; - - static { - sun.misc.Unsafe instance; - try { - Field field = sun.misc.Unsafe.class.getDeclaredField(THE_UNSAFE); - field.setAccessible(true); - instance = (sun.misc.Unsafe) field.get(null); - } catch (Exception e) { - LOGGER.error("get unsafe field failed", e); - instance = initDeclaredConstructor(); - } - UNSAFE = instance; + private static final Logger LOGGER = LoggerFactory.getLogger(UnSafeUtils.class); + + private static final String THE_UNSAFE = "theUnsafe"; + public static final sun.misc.Unsafe UNSAFE; + + static { + sun.misc.Unsafe instance; + try { + Field field = sun.misc.Unsafe.class.getDeclaredField(THE_UNSAFE); + field.setAccessible(true); + instance = (sun.misc.Unsafe) field.get(null); + } catch (Exception e) { + LOGGER.error("get unsafe field failed", e); + instance = initDeclaredConstructor(); } + UNSAFE = instance; + } - private static sun.misc.Unsafe initDeclaredConstructor() { - try { - Constructor c = sun.misc.Unsafe.class.getDeclaredConstructor(); - c.setAccessible(true); - return c.newInstance(); - } catch (Exception e) { - throw new GeaflowRuntimeException("init unsafe declared constructor failed", e); - } + private static sun.misc.Unsafe initDeclaredConstructor() { + try { + Constructor c = sun.misc.Unsafe.class.getDeclaredConstructor(); + c.setAccessible(true); + return c.newInstance(); + } catch (Exception e) { + throw new GeaflowRuntimeException("init unsafe declared constructor failed", e); } + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/impl/DataExchangeDeCoderImpl.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/impl/DataExchangeDeCoderImpl.java index 0746ce1e0..e137d35e7 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/impl/DataExchangeDeCoderImpl.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/impl/DataExchangeDeCoderImpl.java @@ -20,30 +20,31 @@ package org.apache.geaflow.infer.exchange.impl; import java.io.IOException; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.infer.exchange.IDecoder; import org.apache.geaflow.infer.exchange.serialize.Unpickler; public class DataExchangeDeCoderImpl implements IDecoder { - private final transient Unpickler unpickler; + private final transient Unpickler unpickler; - public DataExchangeDeCoderImpl() { - this.unpickler = new Unpickler(); - } + public DataExchangeDeCoderImpl() { + this.unpickler = new Unpickler(); + } - @Override - public T decode(byte[] bytes) { - try { - Object res = unpickler.loads(bytes); - return (T) res; - } catch (Exception e) { - throw new GeaflowRuntimeException("unpick object error", e); - } + @Override + public T decode(byte[] bytes) { + try { + Object res = unpickler.loads(bytes); + return (T) res; + } catch (Exception e) { + throw new GeaflowRuntimeException("unpick object error", e); } + } - @Override - public void close() throws IOException { - unpickler.close(); - } + @Override + public void close() throws IOException { + unpickler.close(); + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/impl/DataExchangeEnCoderImpl.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/impl/DataExchangeEnCoderImpl.java index 27ec94437..6703ec564 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/impl/DataExchangeEnCoderImpl.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/impl/DataExchangeEnCoderImpl.java @@ -20,29 +20,30 @@ package org.apache.geaflow.infer.exchange.impl; import java.io.IOException; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.infer.exchange.IEncoder; import org.apache.geaflow.infer.exchange.serialize.Pickler; public class DataExchangeEnCoderImpl implements IEncoder { - private final transient Pickler pickler; + private final transient Pickler pickler; - public DataExchangeEnCoderImpl() { - this.pickler = new Pickler(); - } + public DataExchangeEnCoderImpl() { + this.pickler = new Pickler(); + } - @Override - public byte[] encode(Object object) { - try { - return pickler.dumps(object); - } catch (Exception e) { - throw new GeaflowRuntimeException("pick object error: " + object, e); - } + @Override + public byte[] encode(Object object) { + try { + return pickler.dumps(object); + } catch (Exception e) { + throw new GeaflowRuntimeException("pick object error: " + object, e); } + } - @Override - public void close() throws IOException { - pickler.close(); - } + @Override + public void close() throws IOException { + pickler.close(); + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/impl/InferDataBridgeImpl.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/impl/InferDataBridgeImpl.java index d5f173143..4b90193ac 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/impl/InferDataBridgeImpl.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/impl/InferDataBridgeImpl.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; + import org.apache.geaflow.infer.exchange.DataExchangeContext; import org.apache.geaflow.infer.exchange.DataExchangeQueue; import org.apache.geaflow.infer.exchange.IDataBridge; @@ -33,69 +34,69 @@ public class InferDataBridgeImpl implements IDataBridge { - private static final int HEADER_LENGTH = 4; - private final byte[] bufferArray; - private final ByteBuffer byteBuffer; - private final InferDataWriter inferDataWriter; - private final InferDataReader inferDataReader; - private final IEncoder encoder; - private final IDecoder decoder; + private static final int HEADER_LENGTH = 4; + private final byte[] bufferArray; + private final ByteBuffer byteBuffer; + private final InferDataWriter inferDataWriter; + private final InferDataReader inferDataReader; + private final IEncoder encoder; + private final IDecoder decoder; - public InferDataBridgeImpl(DataExchangeContext shareMemoryContext) { - DataExchangeQueue receiveQueue = shareMemoryContext.getReceiveQueue(); - DataExchangeQueue sendQueue = shareMemoryContext.getSendQueue(); - this.inferDataReader = new InferDataReader(receiveQueue); - this.inferDataWriter = new InferDataWriter(sendQueue); - this.encoder = new DataExchangeEnCoderImpl(); - this.decoder = new DataExchangeDeCoderImpl(); - this.bufferArray = new byte[HEADER_LENGTH]; - this.byteBuffer = ByteBuffer.wrap(bufferArray); - this.byteBuffer.order(ByteOrder.LITTLE_ENDIAN); - } + public InferDataBridgeImpl(DataExchangeContext shareMemoryContext) { + DataExchangeQueue receiveQueue = shareMemoryContext.getReceiveQueue(); + DataExchangeQueue sendQueue = shareMemoryContext.getSendQueue(); + this.inferDataReader = new InferDataReader(receiveQueue); + this.inferDataWriter = new InferDataWriter(sendQueue); + this.encoder = new DataExchangeEnCoderImpl(); + this.decoder = new DataExchangeDeCoderImpl(); + this.bufferArray = new byte[HEADER_LENGTH]; + this.byteBuffer = ByteBuffer.wrap(bufferArray); + this.byteBuffer.order(ByteOrder.LITTLE_ENDIAN); + } - @Override - public boolean write(Object... inputs) throws IOException { - int inputsSize = inputs.length; - ByteArrayOutputStream result = new ByteArrayOutputStream(); - byte[] dataSizeBytes = toInt32LE(inputsSize); - result.write(dataSizeBytes); - for (Object element : inputs) { - result.write(transformBytes(element)); - } - byte[] byteArray = result.toByteArray(); - return inferDataWriter.write(byteArray); + @Override + public boolean write(Object... inputs) throws IOException { + int inputsSize = inputs.length; + ByteArrayOutputStream result = new ByteArrayOutputStream(); + byte[] dataSizeBytes = toInt32LE(inputsSize); + result.write(dataSizeBytes); + for (Object element : inputs) { + result.write(transformBytes(element)); } + byte[] byteArray = result.toByteArray(); + return inferDataWriter.write(byteArray); + } - @Override - public OUT read() throws IOException { - byte[] result = inferDataReader.read(); - if (result != null) { - return this.decoder.decode(result); - } - return null; + @Override + public OUT read() throws IOException { + byte[] result = inferDataReader.read(); + if (result != null) { + return this.decoder.decode(result); } + return null; + } - private byte[] transformBytes(Object obj) { - byte[] dataBytes = this.encoder.encode(obj); - int dataLength = dataBytes.length; - byte[] lenBytes = toInt32LE(dataLength); - ByteBuffer buffer = ByteBuffer.allocate(HEADER_LENGTH + dataLength); - buffer.put(lenBytes); - buffer.put(dataBytes); - return buffer.array(); - } + private byte[] transformBytes(Object obj) { + byte[] dataBytes = this.encoder.encode(obj); + int dataLength = dataBytes.length; + byte[] lenBytes = toInt32LE(dataLength); + ByteBuffer buffer = ByteBuffer.allocate(HEADER_LENGTH + dataLength); + buffer.put(lenBytes); + buffer.put(dataBytes); + return buffer.array(); + } - private byte[] toInt32LE(int data) { - this.byteBuffer.clear(); - this.byteBuffer.putInt(data); - return bufferArray; - } + private byte[] toInt32LE(int data) { + this.byteBuffer.clear(); + this.byteBuffer.putInt(data); + return bufferArray; + } - @Override - public void close() throws IOException { - inferDataReader.close(); - inferDataWriter.close(); - encoder.close(); - decoder.close(); - } + @Override + public void close() throws IOException { + inferDataReader.close(); + inferDataWriter.close(); + encoder.close(); + decoder.close(); + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/AnyClassConstructor.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/AnyClassConstructor.java index ed5715caa..8bcdfd6b0 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/AnyClassConstructor.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/AnyClassConstructor.java @@ -23,26 +23,26 @@ public abstract class AnyClassConstructor implements IObjectConstructor { - protected final Class type; + protected final Class type; - public AnyClassConstructor(Class type) { - this.type = type; - } + public AnyClassConstructor(Class type) { + this.type = type; + } - @Override - public Object construct(Object[] args) { - try { - Class[] paramTypes = new Class[args.length]; - for (int i = 0; i < args.length; ++i) { - paramTypes[i] = args[i].getClass(); - } - Constructor cons = type.getConstructor(paramTypes); - initClassImpl(cons, args); - return cons.newInstance(args); - } catch (Exception e) { - throw new PickleException("problem construction object: " + e); - } + @Override + public Object construct(Object[] args) { + try { + Class[] paramTypes = new Class[args.length]; + for (int i = 0; i < args.length; ++i) { + paramTypes[i] = args[i].getClass(); + } + Constructor cons = type.getConstructor(paramTypes); + initClassImpl(cons, args); + return cons.newInstance(args); + } catch (Exception e) { + throw new PickleException("problem construction object: " + e); } + } - protected abstract Object initClassImpl(Constructor cons, Object[] args) throws Exception; + protected abstract Object initClassImpl(Constructor cons, Object[] args) throws Exception; } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/ArrayConstructor.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/ArrayConstructor.java index 89abdf1ef..00e642a75 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/ArrayConstructor.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/ArrayConstructor.java @@ -23,399 +23,411 @@ public class ArrayConstructor implements IObjectConstructor { - public Object construct(Object[] args) throws PickleException { - if (args.length == 4) { - ArrayConstructor constructor = (ArrayConstructor) args[0]; - char typeCode = ((String) args[1]).charAt(0); - int machineCodeType = (Integer) args[2]; - byte[] data = (byte[]) args[3]; - return constructor.construct(typeCode, machineCodeType, data); + public Object construct(Object[] args) throws PickleException { + if (args.length == 4) { + ArrayConstructor constructor = (ArrayConstructor) args[0]; + char typeCode = ((String) args[1]).charAt(0); + int machineCodeType = (Integer) args[2]; + byte[] data = (byte[]) args[3]; + return constructor.construct(typeCode, machineCodeType, data); + } + if (args.length != 2) { + throw new PickleException( + "invalid pickle data for array; expected 2 args, got " + args.length); + } + + String typeCode = (String) args[0]; + if (args[1] instanceof String) { + throw new PickleException("unsupported Python 2.6 array pickle format"); + } + ArrayList values = (ArrayList) args[1]; + + switch (typeCode.charAt(0)) { + case 'c': + case 'u': + { + char[] result = new char[values.size()]; + int i = 0; + for (Object c : values) { + result[i++] = ((String) c).charAt(0); + } + return result; } - if (args.length != 2) { - throw new PickleException( - "invalid pickle data for array; expected 2 args, got " + args.length); + case 'b': + { + byte[] result = new byte[values.size()]; + int i = 0; + for (Object c : values) { + result[i++] = ((Number) c).byteValue(); + } + return result; } - - String typeCode = (String) args[0]; - if (args[1] instanceof String) { - throw new PickleException("unsupported Python 2.6 array pickle format"); + case 'B': + case 'h': + { + short[] result = new short[values.size()]; + int i = 0; + for (Object c : values) { + result[i++] = ((Number) c).shortValue(); + } + return result; + } + // 列表元素为int类型 + case 'H': + case 'i': + case 'l': + { + int[] result = new int[values.size()]; + int i = 0; + for (Object c : values) { + result[i++] = ((Number) c).intValue(); + } + return result; + } + case 'I': + case 'L': + { + long[] result = new long[values.size()]; + int i = 0; + for (Object c : values) { + result[i++] = ((Number) c).longValue(); + } + return result; + } + case 'f': + { + float[] result = new float[values.size()]; + int i = 0; + for (Object c : values) { + result[i++] = ((Number) c).floatValue(); + } + return result; } - ArrayList values = (ArrayList) args[1]; - - switch (typeCode.charAt(0)) { - case 'c': - case 'u': { - char[] result = new char[values.size()]; - int i = 0; - for (Object c : values) { - result[i++] = ((String) c).charAt(0); - } - return result; - } - case 'b': { - byte[] result = new byte[values.size()]; - int i = 0; - for (Object c : values) { - result[i++] = ((Number) c).byteValue(); - } - return result; - } - case 'B': - case 'h': { - short[] result = new short[values.size()]; - int i = 0; - for (Object c : values) { - result[i++] = ((Number) c).shortValue(); - } - return result; - } - // 列表元素为int类型 - case 'H': - case 'i': - case 'l': { - int[] result = new int[values.size()]; - int i = 0; - for (Object c : values) { - result[i++] = ((Number) c).intValue(); - } - return result; - } - case 'I': - case 'L': { - long[] result = new long[values.size()]; - int i = 0; - for (Object c : values) { - result[i++] = ((Number) c).longValue(); - } - return result; - } - case 'f': { - float[] result = new float[values.size()]; - int i = 0; - for (Object c : values) { - result[i++] = ((Number) c).floatValue(); - } - return result; - } - case 'd': { - double[] result = new double[values.size()]; - int i = 0; - for (Object c : values) { - result[i++] = ((Number) c).doubleValue(); - } - return result; - } - default: - throw new PickleException("invalid array typecode: " + typeCode); + case 'd': + { + double[] result = new double[values.size()]; + int i = 0; + for (Object c : values) { + result[i++] = ((Number) c).doubleValue(); + } + return result; } + default: + throw new PickleException("invalid array typecode: " + typeCode); } + } - - public Object construct(char typeCode, int machineCode, byte[] data) throws PickleException { - if (machineCode < 0) { - throw new PickleException("unknown machine type format"); - } - switch (typeCode) { - case 'c': - case 'u': { - if (machineCode != 18 && machineCode != 19 && machineCode != 20 - && machineCode != 21) { - throw new PickleException("for c/u type must be 18/19/20/21"); - } - if (machineCode == 18 || machineCode == 19) { - if (data.length % 2 != 0) { - throw new PickleException("data size alignment error"); - } - return constructCharArrayUTF16(machineCode, data); - } else { - if (data.length % 4 != 0) { - throw new PickleException("data size alignment error"); - } - return constructCharArrayUTF32(machineCode, data); - } - } - case 'b': { - if (machineCode != 1) { - throw new PickleException("for b type must be 1"); - } - return data; - } - case 'B': { - if (machineCode != 0) { - throw new PickleException("for B type must be 0"); - } - return constructShortArrayFromUByte(data); - } - case 'h': { - if (machineCode != 4 && machineCode != 5) { - throw new PickleException("for h type must be 4/5"); - } - if (data.length % 2 != 0) { - throw new PickleException("data size alignment error"); - } - return constructShortArraySigned(machineCode, data); - } - case 'H': { - if (machineCode != 2 && machineCode != 3) { - throw new PickleException("for H type must be 2/3"); - } - if (data.length % 2 != 0) { - throw new PickleException("data size alignment error"); - } - return constructIntArrayFromUShort(machineCode, data); - } - case 'i': { - if (machineCode != 8 && machineCode != 9) { - throw new PickleException("for i type must be 8/9"); - } - if (data.length % 4 != 0) { - throw new PickleException("data size alignment error"); - } - return constructIntArrayFromInt32(machineCode, data); - } - case 'l': { - if (machineCode != 8 && machineCode != 9 && machineCode != 12 - && machineCode != 13) { - throw new PickleException("for l type must be 8/9/12/13"); - } - if ((machineCode == 8 || machineCode == 9) && (data.length % 4 != 0)) { - throw new PickleException("data size alignment error"); - } - if ((machineCode == 12 || machineCode == 13) && (data.length % 8 != 0)) { - throw new PickleException("data size alignment error"); - } - if (machineCode == 8 || machineCode == 9) { - return constructIntArrayFromInt32(machineCode, data); - } else { - return constructLongArrayFromInt64(machineCode, data); - } - } - case 'I': { - if (machineCode != 6 && machineCode != 7) { - throw new PickleException("for I type must be 6/7"); - } - if (data.length % 4 != 0) { - throw new PickleException("data size alignment error"); - } - return constructLongArrayFromUInt32(machineCode, data); - } - case 'L': { - if (machineCode != 6 && machineCode != 7 && machineCode != 10 - && machineCode != 11) { - throw new PickleException("for L type must be 6/7/10/11"); - } - if ((machineCode == 6 || machineCode == 7) && (data.length % 4 != 0)) { - throw new PickleException("data size alignment error"); - } - if ((machineCode == 10 || machineCode == 11) && (data.length % 8 != 0)) { - throw new PickleException("data size alignment error"); - } - if (machineCode == 6 || machineCode == 7) { - // 32 bits - return constructLongArrayFromUInt32(machineCode, data); - } else { - // 64 bits - return constructLongArrayFromUInt64(machineCode, data); - } - } - case 'f': { - if (machineCode != 14 && machineCode != 15) { - throw new PickleException("for f type must be 14/15"); - } - if (data.length % 4 != 0) { - throw new PickleException("data size alignment error"); - } - return constructFloatArray(machineCode, data); + public Object construct(char typeCode, int machineCode, byte[] data) throws PickleException { + if (machineCode < 0) { + throw new PickleException("unknown machine type format"); + } + switch (typeCode) { + case 'c': + case 'u': + { + if (machineCode != 18 && machineCode != 19 && machineCode != 20 && machineCode != 21) { + throw new PickleException("for c/u type must be 18/19/20/21"); + } + if (machineCode == 18 || machineCode == 19) { + if (data.length % 2 != 0) { + throw new PickleException("data size alignment error"); } - case 'd': { - if (machineCode != 16 && machineCode != 17) { - throw new PickleException("for d type must be 16/17"); - } - if (data.length % 8 != 0) { - throw new PickleException("data size alignment error"); - } - return constructDoubleArray(machineCode, data); + return constructCharArrayUTF16(machineCode, data); + } else { + if (data.length % 4 != 0) { + throw new PickleException("data size alignment error"); } - default: - throw new PickleException("invalid array typecode: " + typeCode); + return constructCharArrayUTF32(machineCode, data); + } } - } - - protected int[] constructIntArrayFromInt32(int machineCode, byte[] data) { - int[] result = new int[data.length / 4]; - byte[] bigEnd = new byte[4]; - for (int i = 0; i < data.length / 4; i++) { - if (machineCode == 8) { - result[i] = PickleUtils.bytes2Integer(data, i * 4, 4); - } else { - bigEnd[0] = data[3 + i * 4]; - bigEnd[1] = data[2 + i * 4]; - bigEnd[2] = data[1 + i * 4]; - bigEnd[3] = data[i * 4]; - result[i] = PickleUtils.bytes2Integer(bigEnd); - } + case 'b': + { + if (machineCode != 1) { + throw new PickleException("for b type must be 1"); + } + return data; + } + case 'B': + { + if (machineCode != 0) { + throw new PickleException("for B type must be 0"); + } + return constructShortArrayFromUByte(data); + } + case 'h': + { + if (machineCode != 4 && machineCode != 5) { + throw new PickleException("for h type must be 4/5"); + } + if (data.length % 2 != 0) { + throw new PickleException("data size alignment error"); + } + return constructShortArraySigned(machineCode, data); + } + case 'H': + { + if (machineCode != 2 && machineCode != 3) { + throw new PickleException("for H type must be 2/3"); + } + if (data.length % 2 != 0) { + throw new PickleException("data size alignment error"); + } + return constructIntArrayFromUShort(machineCode, data); + } + case 'i': + { + if (machineCode != 8 && machineCode != 9) { + throw new PickleException("for i type must be 8/9"); + } + if (data.length % 4 != 0) { + throw new PickleException("data size alignment error"); + } + return constructIntArrayFromInt32(machineCode, data); + } + case 'l': + { + if (machineCode != 8 && machineCode != 9 && machineCode != 12 && machineCode != 13) { + throw new PickleException("for l type must be 8/9/12/13"); + } + if ((machineCode == 8 || machineCode == 9) && (data.length % 4 != 0)) { + throw new PickleException("data size alignment error"); + } + if ((machineCode == 12 || machineCode == 13) && (data.length % 8 != 0)) { + throw new PickleException("data size alignment error"); + } + if (machineCode == 8 || machineCode == 9) { + return constructIntArrayFromInt32(machineCode, data); + } else { + return constructLongArrayFromInt64(machineCode, data); + } + } + case 'I': + { + if (machineCode != 6 && machineCode != 7) { + throw new PickleException("for I type must be 6/7"); + } + if (data.length % 4 != 0) { + throw new PickleException("data size alignment error"); + } + return constructLongArrayFromUInt32(machineCode, data); } - return result; + case 'L': + { + if (machineCode != 6 && machineCode != 7 && machineCode != 10 && machineCode != 11) { + throw new PickleException("for L type must be 6/7/10/11"); + } + if ((machineCode == 6 || machineCode == 7) && (data.length % 4 != 0)) { + throw new PickleException("data size alignment error"); + } + if ((machineCode == 10 || machineCode == 11) && (data.length % 8 != 0)) { + throw new PickleException("data size alignment error"); + } + if (machineCode == 6 || machineCode == 7) { + // 32 bits + return constructLongArrayFromUInt32(machineCode, data); + } else { + // 64 bits + return constructLongArrayFromUInt64(machineCode, data); + } + } + case 'f': + { + if (machineCode != 14 && machineCode != 15) { + throw new PickleException("for f type must be 14/15"); + } + if (data.length % 4 != 0) { + throw new PickleException("data size alignment error"); + } + return constructFloatArray(machineCode, data); + } + case 'd': + { + if (machineCode != 16 && machineCode != 17) { + throw new PickleException("for d type must be 16/17"); + } + if (data.length % 8 != 0) { + throw new PickleException("data size alignment error"); + } + return constructDoubleArray(machineCode, data); + } + default: + throw new PickleException("invalid array typecode: " + typeCode); } + } - protected long[] constructLongArrayFromUInt32(int machineCode, byte[] data) { - long[] result = new long[data.length / 4]; - byte[] bigEnd = new byte[4]; - for (int i = 0; i < data.length / 4; i++) { - if (machineCode == 6) { - result[i] = PickleUtils.bytes2Uint(data, i * 4); - } else { - bigEnd[0] = data[3 + i * 4]; - bigEnd[1] = data[2 + i * 4]; - bigEnd[2] = data[1 + i * 4]; - bigEnd[3] = data[i * 4]; - result[i] = PickleUtils.bytes2Uint(bigEnd, 0); - } - } - return result; + protected int[] constructIntArrayFromInt32(int machineCode, byte[] data) { + int[] result = new int[data.length / 4]; + byte[] bigEnd = new byte[4]; + for (int i = 0; i < data.length / 4; i++) { + if (machineCode == 8) { + result[i] = PickleUtils.bytes2Integer(data, i * 4, 4); + } else { + bigEnd[0] = data[3 + i * 4]; + bigEnd[1] = data[2 + i * 4]; + bigEnd[2] = data[1 + i * 4]; + bigEnd[3] = data[i * 4]; + result[i] = PickleUtils.bytes2Integer(bigEnd); + } } + return result; + } - protected long[] constructLongArrayFromUInt64(int machineCode, byte[] data) { - throw new PickleException("unsupported datatype: 64-bits unsigned long"); + protected long[] constructLongArrayFromUInt32(int machineCode, byte[] data) { + long[] result = new long[data.length / 4]; + byte[] bigEnd = new byte[4]; + for (int i = 0; i < data.length / 4; i++) { + if (machineCode == 6) { + result[i] = PickleUtils.bytes2Uint(data, i * 4); + } else { + bigEnd[0] = data[3 + i * 4]; + bigEnd[1] = data[2 + i * 4]; + bigEnd[2] = data[1 + i * 4]; + bigEnd[3] = data[i * 4]; + result[i] = PickleUtils.bytes2Uint(bigEnd, 0); + } } + return result; + } - protected long[] constructLongArrayFromInt64(int machineCode, byte[] data) { - long[] result = new long[data.length / 8]; - byte[] bigEnd = new byte[8]; - for (int i = 0; i < data.length / 8; i++) { - if (machineCode == 12) { - result[i] = PickleUtils.bytes2Long(data, i * 8); - } else { - bigEnd[0] = data[7 + i * 8]; - bigEnd[1] = data[6 + i * 8]; - bigEnd[2] = data[5 + i * 8]; - bigEnd[3] = data[4 + i * 8]; - bigEnd[4] = data[3 + i * 8]; - bigEnd[5] = data[2 + i * 8]; - bigEnd[6] = data[1 + i * 8]; - bigEnd[7] = data[i * 8]; - result[i] = PickleUtils.bytes2Long(bigEnd, 0); - } - } - return result; + protected long[] constructLongArrayFromUInt64(int machineCode, byte[] data) { + throw new PickleException("unsupported datatype: 64-bits unsigned long"); + } + + protected long[] constructLongArrayFromInt64(int machineCode, byte[] data) { + long[] result = new long[data.length / 8]; + byte[] bigEnd = new byte[8]; + for (int i = 0; i < data.length / 8; i++) { + if (machineCode == 12) { + result[i] = PickleUtils.bytes2Long(data, i * 8); + } else { + bigEnd[0] = data[7 + i * 8]; + bigEnd[1] = data[6 + i * 8]; + bigEnd[2] = data[5 + i * 8]; + bigEnd[3] = data[4 + i * 8]; + bigEnd[4] = data[3 + i * 8]; + bigEnd[5] = data[2 + i * 8]; + bigEnd[6] = data[1 + i * 8]; + bigEnd[7] = data[i * 8]; + result[i] = PickleUtils.bytes2Long(bigEnd, 0); + } } + return result; + } - protected double[] constructDoubleArray(int machineCode, byte[] data) { - double[] result = new double[data.length / 8]; - byte[] bigEnd = new byte[8]; - for (int i = 0; i < data.length / 8; ++i) { - if (machineCode == 17) { - result[i] = PickleUtils.bytes2Double(data, i * 8); - } else { - bigEnd[0] = data[7 + i * 8]; - bigEnd[1] = data[6 + i * 8]; - bigEnd[2] = data[5 + i * 8]; - bigEnd[3] = data[4 + i * 8]; - bigEnd[4] = data[3 + i * 8]; - bigEnd[5] = data[2 + i * 8]; - bigEnd[6] = data[1 + i * 8]; - bigEnd[7] = data[i * 8]; - result[i] = PickleUtils.bytes2Double(bigEnd, 0); - } - } - return result; + protected double[] constructDoubleArray(int machineCode, byte[] data) { + double[] result = new double[data.length / 8]; + byte[] bigEnd = new byte[8]; + for (int i = 0; i < data.length / 8; ++i) { + if (machineCode == 17) { + result[i] = PickleUtils.bytes2Double(data, i * 8); + } else { + bigEnd[0] = data[7 + i * 8]; + bigEnd[1] = data[6 + i * 8]; + bigEnd[2] = data[5 + i * 8]; + bigEnd[3] = data[4 + i * 8]; + bigEnd[4] = data[3 + i * 8]; + bigEnd[5] = data[2 + i * 8]; + bigEnd[6] = data[1 + i * 8]; + bigEnd[7] = data[i * 8]; + result[i] = PickleUtils.bytes2Double(bigEnd, 0); + } } + return result; + } - protected float[] constructFloatArray(int machineCode, byte[] data) { - float[] result = new float[data.length / 4]; - byte[] bigEnd = new byte[4]; - for (int i = 0; i < data.length / 4; ++i) { - if (machineCode == 15) { - result[i] = PickleUtils.bytes2Float(data, i * 4); - } else { - bigEnd[0] = data[3 + i * 4]; - bigEnd[1] = data[2 + i * 4]; - bigEnd[2] = data[1 + i * 4]; - bigEnd[3] = data[i * 4]; - result[i] = PickleUtils.bytes2Float(bigEnd, 0); - } - } - return result; + protected float[] constructFloatArray(int machineCode, byte[] data) { + float[] result = new float[data.length / 4]; + byte[] bigEnd = new byte[4]; + for (int i = 0; i < data.length / 4; ++i) { + if (machineCode == 15) { + result[i] = PickleUtils.bytes2Float(data, i * 4); + } else { + bigEnd[0] = data[3 + i * 4]; + bigEnd[1] = data[2 + i * 4]; + bigEnd[2] = data[1 + i * 4]; + bigEnd[3] = data[i * 4]; + result[i] = PickleUtils.bytes2Float(bigEnd, 0); + } } + return result; + } - protected int[] constructIntArrayFromUShort(int machineCode, byte[] data) { - int[] result = new int[data.length / 2]; - for (int i = 0; i < data.length / 2; ++i) { - int b1 = data[i * 2] & 0xff; - int b2 = data[1 + i * 2] & 0xff; - if (machineCode == 2) { - result[i] = (b2 << 8) | b1; - } else { - result[i] = (b1 << 8) | b2; - } - } - return result; + protected int[] constructIntArrayFromUShort(int machineCode, byte[] data) { + int[] result = new int[data.length / 2]; + for (int i = 0; i < data.length / 2; ++i) { + int b1 = data[i * 2] & 0xff; + int b2 = data[1 + i * 2] & 0xff; + if (machineCode == 2) { + result[i] = (b2 << 8) | b1; + } else { + result[i] = (b1 << 8) | b2; + } } + return result; + } - protected short[] constructShortArraySigned(int machineCode, byte[] data) { - short[] result = new short[data.length / 2]; - for (int i = 0; i < data.length / 2; ++i) { - byte b1 = data[i * 2]; - byte b2 = data[1 + i * 2]; - if (machineCode == 4) { - result[i] = (short) ((b2 << 8) | (b1 & 0xff)); - } else { - result[i] = (short) ((b1 << 8) | (b2 & 0xff)); - } - } - return result; + protected short[] constructShortArraySigned(int machineCode, byte[] data) { + short[] result = new short[data.length / 2]; + for (int i = 0; i < data.length / 2; ++i) { + byte b1 = data[i * 2]; + byte b2 = data[1 + i * 2]; + if (machineCode == 4) { + result[i] = (short) ((b2 << 8) | (b1 & 0xff)); + } else { + result[i] = (short) ((b1 << 8) | (b2 & 0xff)); + } } + return result; + } - protected short[] constructShortArrayFromUByte(byte[] data) { - short[] result = new short[data.length]; - for (int i = 0; i < data.length; ++i) { - result[i] = (short) (data[i] & 0xff); - } - return result; + protected short[] constructShortArrayFromUByte(byte[] data) { + short[] result = new short[data.length]; + for (int i = 0; i < data.length; ++i) { + result[i] = (short) (data[i] & 0xff); } + return result; + } - protected char[] constructCharArrayUTF32(int machineCode, byte[] data) { - char[] result = new char[data.length / 4]; - byte[] bigEndian = new byte[4]; - for (int index = 0; index < data.length / 4; ++index) { - if (machineCode == 20) { - int codepoint = PickleUtils.bytes2Integer(data, index * 4, 4); - char[] cc = Character.toChars(codepoint); - if (cc.length > 1) { - throw new PickleException( - "cannot process UTF-32 character codepoint " + codepoint); - } - result[index] = cc[0]; - } else { - bigEndian[0] = data[3 + index * 4]; - bigEndian[1] = data[2 + index * 4]; - bigEndian[2] = data[1 + index * 4]; - bigEndian[3] = data[index * 4]; - int codepoint = PickleUtils.bytes2Integer(bigEndian); - char[] cc = Character.toChars(codepoint); - if (cc.length > 1) { - throw new PickleException( - "cannot process UTF-32 character codepoint " + codepoint); - } - result[index] = cc[0]; - } + protected char[] constructCharArrayUTF32(int machineCode, byte[] data) { + char[] result = new char[data.length / 4]; + byte[] bigEndian = new byte[4]; + for (int index = 0; index < data.length / 4; ++index) { + if (machineCode == 20) { + int codepoint = PickleUtils.bytes2Integer(data, index * 4, 4); + char[] cc = Character.toChars(codepoint); + if (cc.length > 1) { + throw new PickleException("cannot process UTF-32 character codepoint " + codepoint); } - return result; + result[index] = cc[0]; + } else { + bigEndian[0] = data[3 + index * 4]; + bigEndian[1] = data[2 + index * 4]; + bigEndian[2] = data[1 + index * 4]; + bigEndian[3] = data[index * 4]; + int codepoint = PickleUtils.bytes2Integer(bigEndian); + char[] cc = Character.toChars(codepoint); + if (cc.length > 1) { + throw new PickleException("cannot process UTF-32 character codepoint " + codepoint); + } + result[index] = cc[0]; + } } + return result; + } - protected char[] constructCharArrayUTF16(int machineCode, byte[] data) { - char[] result = new char[data.length / 2]; - byte[] bigEndian = new byte[2]; - for (int index = 0; index < data.length / 2; ++index) { - if (machineCode == 18) { - result[index] = (char) PickleUtils.bytes2Integer(data, index * 2, 2); - } else { - bigEndian[0] = data[1 + index * 2]; - bigEndian[1] = data[index * 2]; - result[index] = (char) PickleUtils.bytes2Integer(bigEndian); - } - } - return result; + protected char[] constructCharArrayUTF16(int machineCode, byte[] data) { + char[] result = new char[data.length / 2]; + byte[] bigEndian = new byte[2]; + for (int index = 0; index < data.length / 2; ++index) { + if (machineCode == 18) { + result[index] = (char) PickleUtils.bytes2Integer(data, index * 2, 2); + } else { + bigEndian[0] = data[1 + index * 2]; + bigEndian[1] = data[index * 2]; + result[index] = (char) PickleUtils.bytes2Integer(bigEndian); + } } + return result; + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/ByteArrayConstructor.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/ByteArrayConstructor.java index 28c0a1dc9..0e88d0319 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/ByteArrayConstructor.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/ByteArrayConstructor.java @@ -24,42 +24,42 @@ public class ByteArrayConstructor implements IObjectConstructor { - private static final String LATIN = "latin-"; + private static final String LATIN = "latin-"; - private static final String ISO_KEY = "ISO-8859-"; + private static final String ISO_KEY = "ISO-8859-"; - public Object construct(Object[] args) throws PickleException { - if (args.length > 2) { - throw new PickleException( - "invalid pickle data for bytearray; expected 0, 1 or 2 args, got " + args.length); - } + public Object construct(Object[] args) throws PickleException { + if (args.length > 2) { + throw new PickleException( + "invalid pickle data for bytearray; expected 0, 1 or 2 args, got " + args.length); + } - if (args.length == 0) { - return new byte[0]; - } + if (args.length == 0) { + return new byte[0]; + } - if (args.length == 1) { - if (args[0] instanceof byte[]) { - return args[0]; - } + if (args.length == 1) { + if (args[0] instanceof byte[]) { + return args[0]; + } - ArrayList values = (ArrayList) args[0]; - byte[] data = new byte[values.size()]; - for (int i = 0; i < data.length; ++i) { - data[i] = values.get(i).byteValue(); - } - return data; - } else { - String data = (String) args[0]; - String encoding = (String) args[1]; - if (encoding.startsWith(LATIN)) { - encoding = ISO_KEY + encoding.substring(6); - } - try { - return data.getBytes(encoding); - } catch (UnsupportedEncodingException e) { - throw new PickleException("error creating bytearray: " + e); - } - } + ArrayList values = (ArrayList) args[0]; + byte[] data = new byte[values.size()]; + for (int i = 0; i < data.length; ++i) { + data[i] = values.get(i).byteValue(); + } + return data; + } else { + String data = (String) args[0]; + String encoding = (String) args[1]; + if (encoding.startsWith(LATIN)) { + encoding = ISO_KEY + encoding.substring(6); + } + try { + return data.getBytes(encoding); + } catch (UnsupportedEncodingException e) { + throw new PickleException("error creating bytearray: " + e); + } } + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/ComplexNumber.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/ComplexNumber.java index 3651790cd..a08fe709b 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/ComplexNumber.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/ComplexNumber.java @@ -23,82 +23,82 @@ public class ComplexNumber implements Serializable { - private static final long serialVersionUID = 4668080260997226513L; - private static final String ADD_FLAG = "+"; - private static final String I_FLAG = "i"; - - private final double r; - private final double i; - - public ComplexNumber(double rr, double ii) { - r = rr; - i = ii; - } - - public ComplexNumber(Double rr, Double ii) { - r = rr; - i = ii; - } - - public String toString() { - StringBuffer sb = new StringBuffer().append(r); - if (i >= 0) { - sb.append(ADD_FLAG); - } - return sb.append(i).append(I_FLAG).toString(); - } - - public double getReal() { - return r; - } - - public double getImaginary() { - return i; - } - - public double magnitude() { - return Math.sqrt(r * r + i * i); - } - - public ComplexNumber add(ComplexNumber other) { - return add(this, other); + private static final long serialVersionUID = 4668080260997226513L; + private static final String ADD_FLAG = "+"; + private static final String I_FLAG = "i"; + + private final double r; + private final double i; + + public ComplexNumber(double rr, double ii) { + r = rr; + i = ii; + } + + public ComplexNumber(Double rr, Double ii) { + r = rr; + i = ii; + } + + public String toString() { + StringBuffer sb = new StringBuffer().append(r); + if (i >= 0) { + sb.append(ADD_FLAG); } - - public static ComplexNumber add(ComplexNumber c1, ComplexNumber c2) { - return new ComplexNumber(c1.r + c2.r, c1.i + c2.i); - } - - public ComplexNumber subtract(ComplexNumber other) { - return subtract(this, other); - } - - public static ComplexNumber subtract(ComplexNumber c1, ComplexNumber c2) { - return new ComplexNumber(c1.r - c2.r, c1.i - c2.i); - } - - public ComplexNumber multiply(ComplexNumber other) { - return multiply(this, other); + return sb.append(i).append(I_FLAG).toString(); + } + + public double getReal() { + return r; + } + + public double getImaginary() { + return i; + } + + public double magnitude() { + return Math.sqrt(r * r + i * i); + } + + public ComplexNumber add(ComplexNumber other) { + return add(this, other); + } + + public static ComplexNumber add(ComplexNumber c1, ComplexNumber c2) { + return new ComplexNumber(c1.r + c2.r, c1.i + c2.i); + } + + public ComplexNumber subtract(ComplexNumber other) { + return subtract(this, other); + } + + public static ComplexNumber subtract(ComplexNumber c1, ComplexNumber c2) { + return new ComplexNumber(c1.r - c2.r, c1.i - c2.i); + } + + public ComplexNumber multiply(ComplexNumber other) { + return multiply(this, other); + } + + public static ComplexNumber multiply(ComplexNumber c1, ComplexNumber c2) { + return new ComplexNumber(c1.r * c2.r - c1.i * c2.i, c1.r * c2.i + c1.i * c2.r); + } + + public static ComplexNumber divide(ComplexNumber c1, ComplexNumber c2) { + double value = c2.r * c2.r + c2.i * c2.i; + return new ComplexNumber( + (c1.r * c2.r + c1.i * c2.i) / (value), (c1.i * c2.r - c1.r * c2.i) / (value)); + } + + public boolean equals(Object o) { + if (!(o instanceof ComplexNumber)) { + return false; } + ComplexNumber other = (ComplexNumber) o; + return r == other.r && i == other.i; + } - public static ComplexNumber multiply(ComplexNumber c1, ComplexNumber c2) { - return new ComplexNumber(c1.r * c2.r - c1.i * c2.i, c1.r * c2.i + c1.i * c2.r); - } - - public static ComplexNumber divide(ComplexNumber c1, ComplexNumber c2) { - double value = c2.r * c2.r + c2.i * c2.i; - return new ComplexNumber((c1.r * c2.r + c1.i * c2.i) / (value), - (c1.i * c2.r - c1.r * c2.i) / (value)); - } - - public boolean equals(Object o) { - if (!(o instanceof ComplexNumber)) { - return false; - } - ComplexNumber other = (ComplexNumber) o; - return r == other.r && i == other.i; - } - - public int hashCode() { - return (Double.valueOf(r).hashCode()) ^ (Double.valueOf(i).hashCode()); - } + public int hashCode() { + return (Double.valueOf(r).hashCode()) ^ (Double.valueOf(i).hashCode()); + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/ComplexNumberConstructor.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/ComplexNumberConstructor.java index 38e6db64e..e3803d11d 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/ComplexNumberConstructor.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/ComplexNumberConstructor.java @@ -23,19 +23,19 @@ import java.math.BigDecimal; public class ComplexNumberConstructor extends AnyClassConstructor { - private static final String NAN = "nan"; + private static final String NAN = "nan"; - public ComplexNumberConstructor() { - super(ComplexNumber.class); - } + public ComplexNumberConstructor() { + super(ComplexNumber.class); + } - protected Object initClassImpl(Constructor cons, Object[] args) throws Exception { - if (this.type == BigDecimal.class && args.length == 1) { - String nan = (String) args[0]; - if (nan.equalsIgnoreCase(NAN)) { - return Double.NaN; - } - } - return cons.newInstance(args); + protected Object initClassImpl(Constructor cons, Object[] args) throws Exception { + if (this.type == BigDecimal.class && args.length == 1) { + String nan = (String) args[0]; + if (nan.equalsIgnoreCase(NAN)) { + return Double.NaN; + } } + return cons.newInstance(args); + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/Dictionary.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/Dictionary.java index ea7fb1879..e982fcb35 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/Dictionary.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/Dictionary.java @@ -23,31 +23,31 @@ public class Dictionary extends HashMap { - private static final long serialVersionUID = 6157715596627049511L; + private static final long serialVersionUID = 6157715596627049511L; - private static final String CLASS_KEY = "__class__"; + private static final String CLASS_KEY = "__class__"; - private static final String CLASS_NAME_REGEX = "."; + private static final String CLASS_NAME_REGEX = "."; - private final String className; + private final String className; - public Dictionary(String moduleName, String classname) { - if (moduleName == null) { - this.className = classname; - } else { - this.className = moduleName + CLASS_NAME_REGEX + classname; - } - - this.put(CLASS_KEY, this.className); + public Dictionary(String moduleName, String classname) { + if (moduleName == null) { + this.className = classname; + } else { + this.className = moduleName + CLASS_NAME_REGEX + classname; } - public void setState(HashMap values) { - this.clear(); - this.put(CLASS_KEY, this.className); - this.putAll(values); - } + this.put(CLASS_KEY, this.className); + } - public String getClassName() { - return this.className; - } + public void setState(HashMap values) { + this.clear(); + this.put(CLASS_KEY, this.className); + this.putAll(values); + } + + public String getClassName() { + return this.className; + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/DictionaryConstructor.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/DictionaryConstructor.java index f6b841dc8..d02d9fecb 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/DictionaryConstructor.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/DictionaryConstructor.java @@ -20,21 +20,25 @@ package org.apache.geaflow.infer.exchange.serialize; public class DictionaryConstructor implements IObjectConstructor { - private final String module; - private final String name; + private final String module; + private final String name; - public DictionaryConstructor(String module, String name) { - this.module = module; - this.name = name; - } + public DictionaryConstructor(String module, String name) { + this.module = module; + this.name = name; + } - public Object construct(Object[] args) { - if (args.length > 0) { - throw new PickleException("expected zero arguments for construction of ClassDict (for " + module + "." + name - + "). This happens when an unsupported/unregistered class is being unpickled that" - + " requires construction arguments. Fix it by registering a custom IObjectConstructor for this class."); - } - return new Dictionary(module, name); + public Object construct(Object[] args) { + if (args.length > 0) { + throw new PickleException( + "expected zero arguments for construction of ClassDict (for " + + module + + "." + + name + + "). This happens when an unsupported/unregistered class is being unpickled that" + + " requires construction arguments. Fix it by registering a custom" + + " IObjectConstructor for this class."); } + return new Dictionary(module, name); + } } - diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/ExceptionConstructor.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/ExceptionConstructor.java index 5345a5eef..f9a0103d5 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/ExceptionConstructor.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/ExceptionConstructor.java @@ -21,57 +21,58 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Field; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class ExceptionConstructor implements IObjectConstructor { - private static final String PYTHON_EXCEPTION_TYPE = "pythonExceptionType"; + private static final String PYTHON_EXCEPTION_TYPE = "pythonExceptionType"; - private static final String MODULE_SUFFIX = "."; + private static final String MODULE_SUFFIX = "."; - private static final String BRACKETS_LEFT = "["; + private static final String BRACKETS_LEFT = "["; - private static final String BRACKETS_RIGHT = "]"; + private static final String BRACKETS_RIGHT = "]"; - private final Class type; + private final Class type; - private final String pythonExceptionType; + private final String pythonExceptionType; - public ExceptionConstructor(Class type, String module, String name) { - if (module != null) { - pythonExceptionType = module + MODULE_SUFFIX + name; - } else { - pythonExceptionType = name; - } - this.type = type; + public ExceptionConstructor(Class type, String module, String name) { + if (module != null) { + pythonExceptionType = module + MODULE_SUFFIX + name; + } else { + pythonExceptionType = name; } + this.type = type; + } - public Object construct(Object[] args) { - try { - if (pythonExceptionType != null) { - if (args == null || args.length == 0) { - args = new String[]{BRACKETS_LEFT + pythonExceptionType + BRACKETS_RIGHT}; - } else { - String msg = BRACKETS_LEFT + pythonExceptionType + BRACKETS_RIGHT + args[0]; - args = new String[]{msg}; - } - } - Class[] paramTypes = new Class[args.length]; - for (int i = 0; i < args.length; ++i) { - paramTypes[i] = args[i].getClass(); - } - Constructor cons = type.getConstructor(paramTypes); - Object ex = cons.newInstance(args); - - try { - Field prop = ex.getClass().getField(PYTHON_EXCEPTION_TYPE); - prop.set(ex, pythonExceptionType); - } catch (NoSuchFieldException e) { - throw new GeaflowRuntimeException(e); - } - return ex; - } catch (Exception x) { - throw new PickleException("problem construction object: " + x); + public Object construct(Object[] args) { + try { + if (pythonExceptionType != null) { + if (args == null || args.length == 0) { + args = new String[] {BRACKETS_LEFT + pythonExceptionType + BRACKETS_RIGHT}; + } else { + String msg = BRACKETS_LEFT + pythonExceptionType + BRACKETS_RIGHT + args[0]; + args = new String[] {msg}; } + } + Class[] paramTypes = new Class[args.length]; + for (int i = 0; i < args.length; ++i) { + paramTypes[i] = args[i].getClass(); + } + Constructor cons = type.getConstructor(paramTypes); + Object ex = cons.newInstance(args); + + try { + Field prop = ex.getClass().getField(PYTHON_EXCEPTION_TYPE); + prop.set(ex, pythonExceptionType); + } catch (NoSuchFieldException e) { + throw new GeaflowRuntimeException(e); + } + return ex; + } catch (Exception x) { + throw new PickleException("problem construction object: " + x); } + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/IObjectConstructor.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/IObjectConstructor.java index c525154ad..4b761bfcb 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/IObjectConstructor.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/IObjectConstructor.java @@ -21,5 +21,5 @@ public interface IObjectConstructor { - Object construct(Object[] args) throws PickleException; + Object construct(Object[] args) throws PickleException; } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/IObjectPickler.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/IObjectPickler.java index a905d7861..6b9a8f759 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/IObjectPickler.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/IObjectPickler.java @@ -23,5 +23,6 @@ import java.io.OutputStream; public interface IObjectPickler { - void pickle(Object o, OutputStream out, Pickler currentPickler) throws PickleException, IOException; + void pickle(Object o, OutputStream out, Pickler currentPickler) + throws PickleException, IOException; } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/InvalidOpcodeException.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/InvalidOpcodeException.java index 270b4b927..0c1aeff56 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/InvalidOpcodeException.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/InvalidOpcodeException.java @@ -23,14 +23,13 @@ public class InvalidOpcodeException extends GeaflowRuntimeException { - private static final long serialVersionUID = -7691944009311968713L; + private static final long serialVersionUID = -7691944009311968713L; - public InvalidOpcodeException(String message, Throwable cause) { - super(message, cause); - } - - public InvalidOpcodeException(String message) { - super(message); - } + public InvalidOpcodeException(String message, Throwable cause) { + super(message, cause); + } + public InvalidOpcodeException(String message) { + super(message); + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/OpCodeConstant.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/OpCodeConstant.java index b563e93f1..c28a20486 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/OpCodeConstant.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/OpCodeConstant.java @@ -21,75 +21,75 @@ public class OpCodeConstant { - public static final short MARK = '('; - public static final short STOP = '.'; - public static final short POP = '0'; - public static final short POP_MARK = '1'; - public static final short DUP = '2'; - public static final short FLOAT = 'F'; - public static final short INT = 'I'; - public static final short BININT = 'J'; - public static final short BININT1 = 'K'; - public static final short LONG = 'L'; - public static final short BININT2 = 'M'; - public static final short NONE = 'N'; - public static final short PERSID = 'P'; - public static final short BINPERSID = 'Q'; - public static final short REDUCE = 'R'; - public static final short STRING = 'S'; - public static final short BINSTRING = 'T'; - public static final short SHORT_BINSTRING = 'U'; - public static final short UNICODE = 'V'; - public static final short BINUNICODE = 'X'; - public static final short APPEND = 'a'; - public static final short BUILD = 'b'; - public static final short GLOBAL = 'c'; - public static final short DICT = 'd'; - public static final short EMPTY_DICT = '}'; - public static final short APPENDS = 'e'; - public static final short GET = 'g'; - public static final short BINGET = 'h'; - public static final short INST = 'i'; - public static final short LONG_BINGET = 'j'; - public static final short LIST = 'l'; - public static final short EMPTY_LIST = ']'; - public static final short OBJ = 'o'; - public static final short PUT = 'p'; - public static final short BINPUT = 'q'; - public static final short LONG_BINPUT = 'r'; - public static final short SETITEM = 's'; - public static final short TUPLE = 't'; - public static final short EMPTY_TUPLE = ')'; - public static final short SETITEMS = 'u'; - public static final short BINFLOAT = 'G'; - public static final String TRUE = "I01\n"; - public static final String FALSE = "I00\n"; - public static final short PROTO = 0x80; - public static final short NEWOBJ = 0x81; - public static final short EXT1 = 0x82; - public static final short EXT2 = 0x83; - public static final short EXT4 = 0x84; - public static final short TUPLE1 = 0x85; - public static final short TUPLE2 = 0x86; - public static final short TUPLE3 = 0x87; - public static final short NEWTRUE = 0x88; - public static final short NEWFALSE = 0x89; - public static final short LONG1 = 0x8a; - public static final short LONG4 = 0x8b; - public static final short BINBYTES = 'B'; - public static final short SHORT_BINBYTES = 'C'; - public static final short SHORT_BINUNICODE = 0x8c; - public static final short BINUNICODE8 = 0x8d; - public static final short BINBYTES8 = 0x8e; - public static final short EMPTY_SET = 0x8f; - public static final short ADDITEMS = 0x90; - public static final short FROZENSET = 0x91; - public static final short MEMOIZE = 0x94; - public static final short FRAME = 0x95; - public static final short NEWOBJ_EX = 0x92; - public static final short STACK_GLOBAL = 0x93; + public static final short MARK = '('; + public static final short STOP = '.'; + public static final short POP = '0'; + public static final short POP_MARK = '1'; + public static final short DUP = '2'; + public static final short FLOAT = 'F'; + public static final short INT = 'I'; + public static final short BININT = 'J'; + public static final short BININT1 = 'K'; + public static final short LONG = 'L'; + public static final short BININT2 = 'M'; + public static final short NONE = 'N'; + public static final short PERSID = 'P'; + public static final short BINPERSID = 'Q'; + public static final short REDUCE = 'R'; + public static final short STRING = 'S'; + public static final short BINSTRING = 'T'; + public static final short SHORT_BINSTRING = 'U'; + public static final short UNICODE = 'V'; + public static final short BINUNICODE = 'X'; + public static final short APPEND = 'a'; + public static final short BUILD = 'b'; + public static final short GLOBAL = 'c'; + public static final short DICT = 'd'; + public static final short EMPTY_DICT = '}'; + public static final short APPENDS = 'e'; + public static final short GET = 'g'; + public static final short BINGET = 'h'; + public static final short INST = 'i'; + public static final short LONG_BINGET = 'j'; + public static final short LIST = 'l'; + public static final short EMPTY_LIST = ']'; + public static final short OBJ = 'o'; + public static final short PUT = 'p'; + public static final short BINPUT = 'q'; + public static final short LONG_BINPUT = 'r'; + public static final short SETITEM = 's'; + public static final short TUPLE = 't'; + public static final short EMPTY_TUPLE = ')'; + public static final short SETITEMS = 'u'; + public static final short BINFLOAT = 'G'; + public static final String TRUE = "I01\n"; + public static final String FALSE = "I00\n"; + public static final short PROTO = 0x80; + public static final short NEWOBJ = 0x81; + public static final short EXT1 = 0x82; + public static final short EXT2 = 0x83; + public static final short EXT4 = 0x84; + public static final short TUPLE1 = 0x85; + public static final short TUPLE2 = 0x86; + public static final short TUPLE3 = 0x87; + public static final short NEWTRUE = 0x88; + public static final short NEWFALSE = 0x89; + public static final short LONG1 = 0x8a; + public static final short LONG4 = 0x8b; + public static final short BINBYTES = 'B'; + public static final short SHORT_BINBYTES = 'C'; + public static final short SHORT_BINUNICODE = 0x8c; + public static final short BINUNICODE8 = 0x8d; + public static final short BINBYTES8 = 0x8e; + public static final short EMPTY_SET = 0x8f; + public static final short ADDITEMS = 0x90; + public static final short FROZENSET = 0x91; + public static final short MEMOIZE = 0x94; + public static final short FRAME = 0x95; + public static final short NEWOBJ_EX = 0x92; + public static final short STACK_GLOBAL = 0x93; - public static final short BYTEARRAY8 = 0x96; - public static final short NEXT_BUFFER = 0x97; - public static final short READONLY_BUFFER = 0x98; + public static final short BYTEARRAY8 = 0x96; + public static final short NEXT_BUFFER = 0x97; + public static final short READONLY_BUFFER = 0x98; } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/PickleException.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/PickleException.java index 036574d36..0514a5087 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/PickleException.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/PickleException.java @@ -23,13 +23,13 @@ public class PickleException extends GeaflowRuntimeException { - private static final long serialVersionUID = -5870448664938735316L; + private static final long serialVersionUID = -5870448664938735316L; - public PickleException(String message, Throwable cause) { - super(message, cause); - } + public PickleException(String message, Throwable cause) { + super(message, cause); + } - public PickleException(String message) { - super(message); - } + public PickleException(String message) { + super(message); + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/PickleUtils.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/PickleUtils.java index b00d1b880..bc7795f08 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/PickleUtils.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/PickleUtils.java @@ -26,364 +26,367 @@ public class PickleUtils { - public static String readLine(InputStream input) throws IOException { - return readLine(input, false); - } - - public static String readLine(InputStream input, boolean includeLF) throws IOException { - StringBuilder sb = new StringBuilder(); - while (true) { - int c = input.read(); - if (c == -1) { - if (sb.length() == 0) { - throw new IOException("premature end of file"); - } - break; - } - if (c != '\n' || includeLF) { - sb.append((char) c); - } - if (c == '\n') { - break; - } + public static String readLine(InputStream input) throws IOException { + return readLine(input, false); + } + + public static String readLine(InputStream input, boolean includeLF) throws IOException { + StringBuilder sb = new StringBuilder(); + while (true) { + int c = input.read(); + if (c == -1) { + if (sb.length() == 0) { + throw new IOException("premature end of file"); } - return sb.toString(); + break; + } + if (c != '\n' || includeLF) { + sb.append((char) c); + } + if (c == '\n') { + break; + } } - - public static short readByte(InputStream input) throws IOException { - int b = input.read(); - return (short) b; + return sb.toString(); + } + + public static short readByte(InputStream input) throws IOException { + int b = input.read(); + return (short) b; + } + + public static byte[] readBytes(InputStream input, int n) throws IOException { + byte[] buffer = new byte[n]; + readBytesInto(input, buffer, 0, n); + return buffer; + } + + public static byte[] readBytes(InputStream input, long n) throws IOException { + if (n > Integer.MAX_VALUE) { + throw new PickleException("pickle too large, can't read more than maxint"); } - - public static byte[] readBytes(InputStream input, int n) throws IOException { - byte[] buffer = new byte[n]; - readBytesInto(input, buffer, 0, n); - return buffer; + return readBytes(input, (int) n); + } + + public static void readBytesInto(InputStream input, byte[] buffer, int offset, int length) + throws IOException { + while (length > 0) { + int read = input.read(buffer, offset, length); + if (read == -1) { + throw new IOException("expected more bytes in input stream"); + } + offset += read; + length -= read; } - - public static byte[] readBytes(InputStream input, long n) throws IOException { - if (n > Integer.MAX_VALUE) { - throw new PickleException("pickle too large, can't read more than maxint"); - } - return readBytes(input, (int) n); + } + + public static int bytes2Integer(byte[] bytes) { + return bytes2Integer(bytes, 0, bytes.length); + } + + public static int bytes2Integer(byte[] bytes, int offset, int size) { + if (size == 2) { + int i = bytes[1 + offset] & 0xff; + i <<= 8; + i |= bytes[offset] & 0xff; + return i; + } else if (size == 4) { + int i = bytes[3 + offset]; + i <<= 8; + i |= bytes[2 + offset] & 0xff; + i <<= 8; + i |= bytes[1 + offset] & 0xff; + i <<= 8; + i |= bytes[offset] & 0xff; + return i; + } else { + throw new PickleException("invalid amount of bytes to convert to int: " + size); } + } - public static void readBytesInto(InputStream input, byte[] buffer, int offset, int length) throws IOException { - while (length > 0) { - int read = input.read(buffer, offset, length); - if (read == -1) { - throw new IOException("expected more bytes in input stream"); - } - offset += read; - length -= read; - } + public static long bytes2Long(byte[] bytes, int offset) { + if (bytes.length - offset < 8) { + throw new PickleException("too few bytes to convert to long"); } - - public static int bytes2Integer(byte[] bytes) { - return bytes2Integer(bytes, 0, bytes.length); + long i = bytes[7 + offset] & 0xff; + i <<= 8; + i |= bytes[6 + offset] & 0xff; + i <<= 8; + i |= bytes[5 + offset] & 0xff; + i <<= 8; + i |= bytes[4 + offset] & 0xff; + i <<= 8; + i |= bytes[3 + offset] & 0xff; + i <<= 8; + i |= bytes[2 + offset] & 0xff; + i <<= 8; + i |= bytes[1 + offset] & 0xff; + i <<= 8; + i |= bytes[offset] & 0xff; + return i; + } + + public static long bytes2Uint(byte[] bytes, int offset) { + if (bytes.length - offset < 4) { + throw new PickleException("too few bytes to convert to long"); } - - public static int bytes2Integer(byte[] bytes, int offset, int size) { - if (size == 2) { - int i = bytes[1 + offset] & 0xff; - i <<= 8; - i |= bytes[offset] & 0xff; - return i; - } else if (size == 4) { - int i = bytes[3 + offset]; - i <<= 8; - i |= bytes[2 + offset] & 0xff; - i <<= 8; - i |= bytes[1 + offset] & 0xff; - i <<= 8; - i |= bytes[offset] & 0xff; - return i; - } else { - throw new PickleException("invalid amount of bytes to convert to int: " + size); - } + long i = bytes[3 + offset] & 0xff; + i <<= 8; + i |= bytes[2 + offset] & 0xff; + i <<= 8; + i |= bytes[1 + offset] & 0xff; + i <<= 8; + i |= bytes[offset] & 0xff; + return i; + } + + public static byte[] integer2Bytes(int i) { + final byte[] b = new byte[4]; + b[0] = (byte) (i & 0xff); + i >>= 8; + b[1] = (byte) (i & 0xff); + i >>= 8; + b[2] = (byte) (i & 0xff); + i >>= 8; + b[3] = (byte) (i & 0xff); + return b; + } + + public static byte[] double2Bytes(double d) { + long bits = Double.doubleToRawLongBits(d); + final byte[] b = new byte[8]; + b[7] = (byte) (bits & 0xff); + bits >>= 8; + b[6] = (byte) (bits & 0xff); + bits >>= 8; + b[5] = (byte) (bits & 0xff); + bits >>= 8; + b[4] = (byte) (bits & 0xff); + bits >>= 8; + b[3] = (byte) (bits & 0xff); + bits >>= 8; + b[2] = (byte) (bits & 0xff); + bits >>= 8; + b[1] = (byte) (bits & 0xff); + bits >>= 8; + b[0] = (byte) (bits & 0xff); + return b; + } + + public static double bytes2Double(byte[] bytes, int offset) { + try { + long result = bytes[offset] & 0xff; + result <<= 8; + result |= bytes[1 + offset] & 0xff; + result <<= 8; + result |= bytes[2 + offset] & 0xff; + result <<= 8; + result |= bytes[3 + offset] & 0xff; + result <<= 8; + result |= bytes[4 + offset] & 0xff; + result <<= 8; + result |= bytes[5 + offset] & 0xff; + result <<= 8; + result |= bytes[6 + offset] & 0xff; + result <<= 8; + result |= bytes[7 + offset] & 0xff; + return Double.longBitsToDouble(result); + } catch (IndexOutOfBoundsException x) { + throw new PickleException("decoding double: too few bytes"); } - - - public static long bytes2Long(byte[] bytes, int offset) { - if (bytes.length - offset < 8) { - throw new PickleException("too few bytes to convert to long"); - } - long i = bytes[7 + offset] & 0xff; - i <<= 8; - i |= bytes[6 + offset] & 0xff; - i <<= 8; - i |= bytes[5 + offset] & 0xff; - i <<= 8; - i |= bytes[4 + offset] & 0xff; - i <<= 8; - i |= bytes[3 + offset] & 0xff; - i <<= 8; - i |= bytes[2 + offset] & 0xff; - i <<= 8; - i |= bytes[1 + offset] & 0xff; - i <<= 8; - i |= bytes[offset] & 0xff; - return i; + } + + public static float bytes2Float(byte[] bytes, int offset) { + try { + int result = bytes[offset] & 0xff; + result <<= 8; + result |= bytes[1 + offset] & 0xff; + result <<= 8; + result |= bytes[2 + offset] & 0xff; + result <<= 8; + result |= bytes[3 + offset] & 0xff; + return Float.intBitsToFloat(result); + } catch (IndexOutOfBoundsException x) { + throw new PickleException(String.format("decoding float: too few bytes, %s", x)); } + } - - public static long bytes2Uint(byte[] bytes, int offset) { - if (bytes.length - offset < 4) { - throw new PickleException("too few bytes to convert to long"); - } - long i = bytes[3 + offset] & 0xff; - i <<= 8; - i |= bytes[2 + offset] & 0xff; - i <<= 8; - i |= bytes[1 + offset] & 0xff; - i <<= 8; - i |= bytes[offset] & 0xff; - return i; + public static Number decodeLong(byte[] data) { + if (data.length == 0) { + return 0L; } - - - public static byte[] integer2Bytes(int i) { - final byte[] b = new byte[4]; - b[0] = (byte) (i & 0xff); - i >>= 8; - b[1] = (byte) (i & 0xff); - i >>= 8; - b[2] = (byte) (i & 0xff); - i >>= 8; - b[3] = (byte) (i & 0xff); - return b; + byte[] data2 = new byte[data.length]; + for (int i = 0; i < data.length; ++i) { + data2[data.length - i - 1] = data[i]; } - - public static byte[] double2Bytes(double d) { - long bits = Double.doubleToRawLongBits(d); - final byte[] b = new byte[8]; - b[7] = (byte) (bits & 0xff); - bits >>= 8; - b[6] = (byte) (bits & 0xff); - bits >>= 8; - b[5] = (byte) (bits & 0xff); - bits >>= 8; - b[4] = (byte) (bits & 0xff); - bits >>= 8; - b[3] = (byte) (bits & 0xff); - bits >>= 8; - b[2] = (byte) (bits & 0xff); - bits >>= 8; - b[1] = (byte) (bits & 0xff); - bits >>= 8; - b[0] = (byte) (bits & 0xff); - return b; + BigInteger bigint = new BigInteger(data2); + return optimizeBigint(bigint); + } + + public static byte[] encodeLong(BigInteger big) { + byte[] data = big.toByteArray(); + byte[] data2 = new byte[data.length]; + for (int i = 0; i < data.length; ++i) { + data2[data.length - i - 1] = data[i]; } - - - public static double bytes2Double(byte[] bytes, int offset) { - try { - long result = bytes[offset] & 0xff; - result <<= 8; - result |= bytes[1 + offset] & 0xff; - result <<= 8; - result |= bytes[2 + offset] & 0xff; - result <<= 8; - result |= bytes[3 + offset] & 0xff; - result <<= 8; - result |= bytes[4 + offset] & 0xff; - result <<= 8; - result |= bytes[5 + offset] & 0xff; - result <<= 8; - result |= bytes[6 + offset] & 0xff; - result <<= 8; - result |= bytes[7 + offset] & 0xff; - return Double.longBitsToDouble(result); - } catch (IndexOutOfBoundsException x) { - throw new PickleException("decoding double: too few bytes"); + return data2; + } + + public static Number optimizeBigint(BigInteger bigint) { + final BigInteger maxLong = BigInteger.valueOf(Long.MAX_VALUE); + final BigInteger minLong = BigInteger.valueOf(Long.MIN_VALUE); + switch (bigint.signum()) { + case 0: + return 0L; + case 1: + if (bigint.compareTo(maxLong) <= 0) { + return bigint.longValue(); } - } - - - public static float bytes2Float(byte[] bytes, int offset) { - try { - int result = bytes[offset] & 0xff; - result <<= 8; - result |= bytes[1 + offset] & 0xff; - result <<= 8; - result |= bytes[2 + offset] & 0xff; - result <<= 8; - result |= bytes[3 + offset] & 0xff; - return Float.intBitsToFloat(result); - } catch (IndexOutOfBoundsException x) { - throw new PickleException(String.format("decoding float: too few bytes, %s", x)); + break; + case -1: + if (bigint.compareTo(minLong) >= 0) { + return bigint.longValue(); } + break; + default: + break; } + return bigint; + } - - public static Number decodeLong(byte[] data) { - if (data.length == 0) { - return 0L; - } - byte[] data2 = new byte[data.length]; - for (int i = 0; i < data.length; ++i) { - data2[data.length - i - 1] = data[i]; - } - BigInteger bigint = new BigInteger(data2); - return optimizeBigint(bigint); + public static String rawStringFromBytes(byte[] data) { + StringBuilder str = new StringBuilder(data.length); + for (byte b : data) { + str.append((char) (b & 0xff)); } - - - public static byte[] encodeLong(BigInteger big) { - byte[] data = big.toByteArray(); - byte[] data2 = new byte[data.length]; - for (int i = 0; i < data.length; ++i) { - data2[data.length - i - 1] = data[i]; - } - return data2; + return str.toString(); + } + + public static byte[] str2bytes(String str) throws IOException { + byte[] b = new byte[str.length()]; + for (int i = 0; i < str.length(); ++i) { + char c = str.charAt(i); + if (c > 255) { + throw new UnsupportedEncodingException( + "string contained a char > 255," + " cannot convert to bytes"); + } + b[i] = (byte) c; } + return b; + } - public static Number optimizeBigint(BigInteger bigint) { - final BigInteger maxLong = BigInteger.valueOf(Long.MAX_VALUE); - final BigInteger minLong = BigInteger.valueOf(Long.MIN_VALUE); - switch (bigint.signum()) { - case 0: - return 0L; - case 1: - if (bigint.compareTo(maxLong) <= 0) { - return bigint.longValue(); - } - break; - case -1: - if (bigint.compareTo(minLong) >= 0) { - return bigint.longValue(); - } - break; - default: - break; - } - return bigint; + public static String decodeEscaped(String str) { + if (str.indexOf('\\') == -1) { + return str; } - - public static String rawStringFromBytes(byte[] data) { - StringBuilder str = new StringBuilder(data.length); - for (byte b : data) { - str.append((char) (b & 0xff)); - } - return str.toString(); - } - - public static byte[] str2bytes(String str) throws IOException { - byte[] b = new byte[str.length()]; - for (int i = 0; i < str.length(); ++i) { - char c = str.charAt(i); - if (c > 255) { - throw new UnsupportedEncodingException("string contained a char > 255," - + " cannot convert to bytes"); + StringBuilder sb = new StringBuilder(str.length()); + for (int i = 0; i < str.length(); ++i) { + char c = str.charAt(i); + if (c == '\\') { + char c2 = str.charAt(++i); + switch (c2) { + case '\\': + sb.append(c); + break; + case 'x': + char h1 = str.charAt(++i); + char h2 = str.charAt(++i); + c2 = (char) Integer.parseInt("" + h1 + h2, 16); + sb.append(c2); + break; + case 'n': + sb.append('\n'); + break; + case 'r': + sb.append('\r'); + break; + case 't': + sb.append('\t'); + break; + case '\'': + sb.append('\''); + break; + default: + if (str.length() > 80) { + str = str.substring(0, 80); } - b[i] = (byte) c; + throw new PickleException( + "invalid escape sequence char \'" + + (c2) + + "\' in string \"" + + str + + " [...]\" (possibly truncated)"); } - return b; + } else { + sb.append(str.charAt(i)); + } } + return sb.toString(); + } - - public static String decodeEscaped(String str) { - if (str.indexOf('\\') == -1) { - return str; - } - StringBuilder sb = new StringBuilder(str.length()); - for (int i = 0; i < str.length(); ++i) { - char c = str.charAt(i); - if (c == '\\') { - char c2 = str.charAt(++i); - switch (c2) { - case '\\': - sb.append(c); - break; - case 'x': - char h1 = str.charAt(++i); - char h2 = str.charAt(++i); - c2 = (char) Integer.parseInt("" + h1 + h2, 16); - sb.append(c2); - break; - case 'n': - sb.append('\n'); - break; - case 'r': - sb.append('\r'); - break; - case 't': - sb.append('\t'); - break; - case '\'': - sb.append('\''); - break; - default: - if (str.length() > 80) { - str = str.substring(0, 80); - } - throw new PickleException("invalid escape sequence char \'" - + (c2) + "\' in string \"" + str + " [...]\" (possibly truncated)"); - } - } else { - sb.append(str.charAt(i)); - } - } - return sb.toString(); + public static String decodeUnicodeEscaped(String str) { + if (str.indexOf('\\') == -1) { + return str; } - - - public static String decodeUnicodeEscaped(String str) { - if (str.indexOf('\\') == -1) { - return str; - } - StringBuilder sb = new StringBuilder(str.length()); - for (int i = 0; i < str.length(); ++i) { - char c = str.charAt(i); - if (c == '\\') { - char c2 = str.charAt(++i); - switch (c2) { - case '\\': - sb.append(c); - break; - case 'u': { - char h1 = str.charAt(++i); - char h2 = str.charAt(++i); - char h3 = str.charAt(++i); - char h4 = str.charAt(++i); - c2 = (char) Integer.parseInt("" + h1 + h2 + h3 + h4, 16); - sb.append(c2); - break; - } - case 'U': { - char h1 = str.charAt(++i); - char h2 = str.charAt(++i); - char h3 = str.charAt(++i); - char h4 = str.charAt(++i); - char h5 = str.charAt(++i); - char h6 = str.charAt(++i); - char h7 = str.charAt(++i); - char h8 = str.charAt(++i); - String encoded = "" + h1 + h2 + h3 + h4 + h5 + h6 + h7 + h8; - String s = new String(Character.toChars(Integer.parseInt(encoded, 16))); - sb.append(s); - break; - } - case 'n': - sb.append('\n'); - break; - case 'r': - sb.append('\r'); - break; - case 't': - sb.append('\t'); - break; - default: - if (str.length() > 80) { - str = str.substring(0, 80); - } - throw new PickleException("invalid escape sequence char " - + "\'" + (c2) + "\' in string \"" + str + " [...]\" (possibly truncated)"); - } - } else { - sb.append(str.charAt(i)); + StringBuilder sb = new StringBuilder(str.length()); + for (int i = 0; i < str.length(); ++i) { + char c = str.charAt(i); + if (c == '\\') { + char c2 = str.charAt(++i); + switch (c2) { + case '\\': + sb.append(c); + break; + case 'u': + { + char h1 = str.charAt(++i); + char h2 = str.charAt(++i); + char h3 = str.charAt(++i); + char h4 = str.charAt(++i); + c2 = (char) Integer.parseInt("" + h1 + h2 + h3 + h4, 16); + sb.append(c2); + break; + } + case 'U': + { + char h1 = str.charAt(++i); + char h2 = str.charAt(++i); + char h3 = str.charAt(++i); + char h4 = str.charAt(++i); + char h5 = str.charAt(++i); + char h6 = str.charAt(++i); + char h7 = str.charAt(++i); + char h8 = str.charAt(++i); + String encoded = "" + h1 + h2 + h3 + h4 + h5 + h6 + h7 + h8; + String s = new String(Character.toChars(Integer.parseInt(encoded, 16))); + sb.append(s); + break; + } + case 'n': + sb.append('\n'); + break; + case 'r': + sb.append('\r'); + break; + case 't': + sb.append('\t'); + break; + default: + if (str.length() > 80) { + str = str.substring(0, 80); } + throw new PickleException( + "invalid escape sequence char " + + "\'" + + (c2) + + "\' in string \"" + + str + + " [...]\" (possibly truncated)"); } - return sb.toString(); + } else { + sb.append(str.charAt(i)); + } } + return sb.toString(); + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/Pickler.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/Pickler.java index d7ccbfe05..5ef754300 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/Pickler.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/Pickler.java @@ -37,509 +37,507 @@ public class Pickler { - private static final String GET = "get"; + private static final String GET = "get"; - private static final String GET_CLASS = "getClass"; + private static final String GET_CLASS = "getClass"; - private static final String IS = "is"; + private static final String IS = "is"; - private static final String CLASS_KEY = "__class__"; + private static final String CLASS_KEY = "__class__"; - private static final int PROTOCOL = 2; + private static final int PROTOCOL = 2; - private static class Memo { + private static class Memo { - public Object obj; - public int index; + public Object obj; + public int index; - public Memo(Object obj, int index) { - this.obj = obj; - this.index = index; - } + public Memo(Object obj, int index) { + this.obj = obj; + this.index = index; } + } - private static final int MAX_RECURSE_DEPTH = 1000; + private static final int MAX_RECURSE_DEPTH = 1000; - private int recurse = 0; + private int recurse = 0; - private OutputStream out; + private OutputStream out; - private static final Map, IObjectPickler> CUSTOM_PICKLER = new HashMap<>(); + private static final Map, IObjectPickler> CUSTOM_PICKLER = new HashMap<>(); - private final boolean useMemo; + private final boolean useMemo; - private final boolean valueCompare; + private final boolean valueCompare; - protected HashMap memo; + protected HashMap memo; - public Pickler() { - this(true); - } + public Pickler() { + this(true); + } - public Pickler(boolean useMemo) { - this(useMemo, true); - } + public Pickler(boolean useMemo) { + this(useMemo, true); + } - public Pickler(boolean useMemo, boolean valueCompare) { - this.useMemo = useMemo; - this.valueCompare = valueCompare; - } + public Pickler(boolean useMemo, boolean valueCompare) { + this.useMemo = useMemo; + this.valueCompare = valueCompare; + } - public void close() throws IOException { - memo = null; - if (out != null) { - out.flush(); - out.close(); - } + public void close() throws IOException { + memo = null; + if (out != null) { + out.flush(); + out.close(); } - - public static synchronized void registerCustomPickler(Class clazz, IObjectPickler pickler) { - CUSTOM_PICKLER.put(clazz, pickler); + } + + public static synchronized void registerCustomPickler(Class clazz, IObjectPickler pickler) { + CUSTOM_PICKLER.put(clazz, pickler); + } + + public byte[] dumps(Object o) throws PickleException, IOException { + ByteArrayOutputStream bo = new ByteArrayOutputStream(); + dump(o, bo); + bo.flush(); + return bo.toByteArray(); + } + + public void dump(Object o, OutputStream stream) throws IOException, PickleException { + out = stream; + recurse = 0; + if (useMemo) { + memo = new HashMap<>(); } - - public byte[] dumps(Object o) throws PickleException, IOException { - ByteArrayOutputStream bo = new ByteArrayOutputStream(); - dump(o, bo); - bo.flush(); - return bo.toByteArray(); + out.write(OpCodeConstant.PROTO); + out.write(PROTOCOL); + save(o); + memo = null; + out.write(OpCodeConstant.STOP); + out.flush(); + if (recurse != 0) { + throw new PickleException("recursive structure error, please report this problem"); } + } - public void dump(Object o, OutputStream stream) throws IOException, PickleException { - out = stream; - recurse = 0; - if (useMemo) { - memo = new HashMap<>(); - } - out.write(OpCodeConstant.PROTO); - out.write(PROTOCOL); - save(o); - memo = null; - out.write(OpCodeConstant.STOP); - out.flush(); - if (recurse != 0) { - throw new PickleException("recursive structure error, please report this problem"); - } + public void save(Object o) throws PickleException, IOException { + recurse++; + if (recurse > MAX_RECURSE_DEPTH) { + throw new StackOverflowError( + "recursion too deep in Pickler.save (>" + MAX_RECURSE_DEPTH + ")"); } - - public void save(Object o) throws PickleException, IOException { - recurse++; - if (recurse > MAX_RECURSE_DEPTH) { - throw new StackOverflowError("recursion too deep in Pickler.save (>" + MAX_RECURSE_DEPTH + ")"); - } - - if (o == null) { - out.write(OpCodeConstant.NONE); - recurse--; - return; - } - - Class t = o.getClass(); - if (lookupMemo(t, o) || dispatch(t, o)) { - recurse--; - return; - } - - throw new PickleException("couldn't pickle object of type " + t); + if (o == null) { + out.write(OpCodeConstant.NONE); + recurse--; + return; } - protected void writeMemo(Object obj) throws IOException { - if (!this.useMemo) { - return; - } - int hash = valueCompare ? obj.hashCode() : System.identityHashCode(obj); - if (!memo.containsKey(hash)) { - int memoIndex = memo.size(); - memo.put(hash, new Memo(obj, memoIndex)); - if (memoIndex <= 0xFF) { - out.write(OpCodeConstant.BINPUT); - out.write((byte) memoIndex); - } else { - out.write(OpCodeConstant.LONG_BINPUT); - byte[] indexBytes = PickleUtils.integer2Bytes(memoIndex); - out.write(indexBytes, 0, 4); - } - } + Class t = o.getClass(); + if (lookupMemo(t, o) || dispatch(t, o)) { + recurse--; + return; } + throw new PickleException("couldn't pickle object of type " + t); + } - private boolean lookupMemo(Class objectType, Object obj) throws IOException { - if (!this.useMemo) { - return false; - } - if (!objectType.isPrimitive()) { - int hash = valueCompare ? obj.hashCode() : System.identityHashCode(obj); - if (memo.containsKey(hash) - && (valueCompare ? memo.get(hash).obj.equals(obj) : memo.get(hash).obj == obj)) { - int memoIndex = memo.get(hash).index; - if (memoIndex <= 0xff) { - out.write(OpCodeConstant.BINGET); - out.write((byte) memoIndex); - } else { - out.write(OpCodeConstant.LONG_BINGET); - byte[] indexBytes = PickleUtils.integer2Bytes(memoIndex); - out.write(indexBytes, 0, 4); - } - return true; - } - } - return false; + protected void writeMemo(Object obj) throws IOException { + if (!this.useMemo) { + return; } - - private boolean dispatch(Class t, Object o) throws IOException { - Class componentType = t.getComponentType(); - if (componentType != null) { - if (componentType.isPrimitive()) { - putArrayOfPrimitives(componentType, o); - } else { - putArrayOfObjects((Object[]) o); - } - return true; - } - - if (o instanceof Boolean || t.equals(Boolean.TYPE)) { - putBool((Boolean) o); - return true; - } - if (o instanceof Byte || t.equals(Byte.TYPE)) { - putLong(((Byte) o).longValue()); - return true; - } - if (o instanceof Short || t.equals(Short.TYPE)) { - putLong(((Short) o).longValue()); - return true; - } - if (o instanceof Integer || t.equals(Integer.TYPE)) { - putLong(((Integer) o).longValue()); - return true; - } - if (o instanceof Long || t.equals(Long.TYPE)) { - putLong(((Long) o).longValue()); - return true; - } - if (o instanceof Float || t.equals(Float.TYPE)) { - putFloat(((Float) o).doubleValue()); - return true; - } - if (o instanceof Double || t.equals(Double.TYPE)) { - putFloat(((Double) o).doubleValue()); - return true; - } - if (o instanceof Character || t.equals(Character.TYPE)) { - putString("" + o); - return true; - } - - IObjectPickler customPickler = getCustomPickler(t); - if (customPickler != null) { - customPickler.pickle(o, this.out, this); - writeMemo(o); - return true; - } - - if (o instanceof String) { - putString((String) o); - return true; - } - if (o instanceof BigInteger) { - putBigint((BigInteger) o); - return true; - } - if (o instanceof BigDecimal) { - putDecimal((BigDecimal) o); - return true; - } - if (o instanceof Enum) { - putString(o.toString()); - return true; - } - if (o instanceof Set) { - putSet((Set) o); - return true; - } - if (o instanceof Map) { - putMap((Map) o); - return true; - } - if (o instanceof List) { - putCollection((List) o); - return true; - } - if (o instanceof Collection) { - putCollection((Collection) o); - return true; - } - if (o instanceof java.io.Serializable) { - putJavabean(o); - return true; - } - return false; + int hash = valueCompare ? obj.hashCode() : System.identityHashCode(obj); + if (!memo.containsKey(hash)) { + int memoIndex = memo.size(); + memo.put(hash, new Memo(obj, memoIndex)); + if (memoIndex <= 0xFF) { + out.write(OpCodeConstant.BINPUT); + out.write((byte) memoIndex); + } else { + out.write(OpCodeConstant.LONG_BINPUT); + byte[] indexBytes = PickleUtils.integer2Bytes(memoIndex); + out.write(indexBytes, 0, 4); + } } + } - private synchronized IObjectPickler getCustomPickler(Class t) { - IObjectPickler pickler = CUSTOM_PICKLER.get(t); - if (pickler != null) { - return pickler; - } - for (Entry, IObjectPickler> x : CUSTOM_PICKLER.entrySet()) { - if (x.getKey().isAssignableFrom(t)) { - return x.getValue(); - } - } - return null; + private boolean lookupMemo(Class objectType, Object obj) throws IOException { + if (!this.useMemo) { + return false; } - - private void putCollection(Collection list) throws IOException { - out.write(OpCodeConstant.EMPTY_LIST); - writeMemo(list); - out.write(OpCodeConstant.MARK); - for (Object o : list) { - save(o); + if (!objectType.isPrimitive()) { + int hash = valueCompare ? obj.hashCode() : System.identityHashCode(obj); + if (memo.containsKey(hash) + && (valueCompare ? memo.get(hash).obj.equals(obj) : memo.get(hash).obj == obj)) { + int memoIndex = memo.get(hash).index; + if (memoIndex <= 0xff) { + out.write(OpCodeConstant.BINGET); + out.write((byte) memoIndex); + } else { + out.write(OpCodeConstant.LONG_BINGET); + byte[] indexBytes = PickleUtils.integer2Bytes(memoIndex); + out.write(indexBytes, 0, 4); } - out.write(OpCodeConstant.APPENDS); + return true; + } } - - private void putMap(Map o) throws IOException { - out.write(OpCodeConstant.EMPTY_DICT); - writeMemo(o); - out.write(OpCodeConstant.MARK); - for (Object k : o.keySet()) { - save(k); - save(o.get(k)); - } - out.write(OpCodeConstant.SETITEMS); + return false; + } + + private boolean dispatch(Class t, Object o) throws IOException { + Class componentType = t.getComponentType(); + if (componentType != null) { + if (componentType.isPrimitive()) { + putArrayOfPrimitives(componentType, o); + } else { + putArrayOfObjects((Object[]) o); + } + return true; } - private void putSet(Set o) throws IOException { - out.write(OpCodeConstant.GLOBAL); - out.write("__builtin__\nset\n".getBytes()); - out.write(OpCodeConstant.EMPTY_LIST); - out.write(OpCodeConstant.MARK); - for (Object x : o) { - save(x); - } - out.write(OpCodeConstant.APPENDS); - out.write(OpCodeConstant.TUPLE1); - out.write(OpCodeConstant.REDUCE); - writeMemo(o); - } - - private void putArrayOfObjects(Object[] array) throws IOException { - if (array.length == 0) { - out.write(OpCodeConstant.EMPTY_TUPLE); - } else if (array.length == 1) { - if (array[0] == array) { - throw new PickleException("recursive array not supported, use list"); - } - save(array[0]); - out.write(OpCodeConstant.TUPLE1); - } else if (array.length == 2) { - if (array[0] == array || array[1] == array) { - throw new PickleException("recursive array not supported, use list"); - } - save(array[0]); - save(array[1]); - out.write(OpCodeConstant.TUPLE2); - } else if (array.length == 3) { - if (array[0] == array || array[1] == array || array[2] == array) { - throw new PickleException("recursive array not supported, use list"); - } - save(array[0]); - save(array[1]); - save(array[2]); - out.write(OpCodeConstant.TUPLE3); - } else { - out.write(OpCodeConstant.MARK); - for (Object o : array) { - if (o == array) { - throw new PickleException("recursive array not supported, use list"); - } - save(o); - } - out.write(OpCodeConstant.TUPLE); - } - writeMemo(array); + if (o instanceof Boolean || t.equals(Boolean.TYPE)) { + putBool((Boolean) o); + return true; } - - private void putArrayOfPrimitives(Class t, Object array) throws IOException { - if (t.equals(Boolean.TYPE)) { - boolean[] source = (boolean[]) array; - Boolean[] boolArray = new Boolean[source.length]; - for (int i = 0; i < source.length; ++i) { - boolArray[i] = source[i]; - } - putArrayOfObjects(boolArray); - return; - } - if (t.equals(Character.TYPE)) { - String s = new String((char[]) array); - putString(s); - return; - } - if (t.equals(Byte.TYPE)) { - out.write(OpCodeConstant.GLOBAL); - out.write("__builtin__\nbytearray\n".getBytes()); - String str = PickleUtils.rawStringFromBytes((byte[]) array); - putString(str); - putString("latin-1"); - out.write(OpCodeConstant.TUPLE2); - out.write(OpCodeConstant.REDUCE); - writeMemo(array); - return; - } - - out.write(OpCodeConstant.GLOBAL); - out.write("array\narray\n".getBytes()); - out.write(OpCodeConstant.SHORT_BINSTRING); - out.write(1); - - if (t.equals(Short.TYPE)) { - out.write('h'); - out.write(OpCodeConstant.EMPTY_LIST); - out.write(OpCodeConstant.MARK); - for (short s : (short[]) array) { - save(s); - } - } else if (t.equals(Integer.TYPE)) { - out.write('i'); - out.write(OpCodeConstant.EMPTY_LIST); - out.write(OpCodeConstant.MARK); - for (int i : (int[]) array) { - save(i); - } - } else if (t.equals(Long.TYPE)) { - out.write('l'); - out.write(OpCodeConstant.EMPTY_LIST); - out.write(OpCodeConstant.MARK); - for (long v : (long[]) array) { - save(v); - } - } else if (t.equals(Float.TYPE)) { - out.write('f'); - out.write(OpCodeConstant.EMPTY_LIST); - out.write(OpCodeConstant.MARK); - for (float f : (float[]) array) { - save(f); - } - } else if (t.equals(Double.TYPE)) { - out.write('d'); - out.write(OpCodeConstant.EMPTY_LIST); - out.write(OpCodeConstant.MARK); - for (double d : (double[]) array) { - save(d); - } - } - - out.write(OpCodeConstant.APPENDS); - out.write(OpCodeConstant.TUPLE2); - out.write(OpCodeConstant.REDUCE); - - writeMemo(array); + if (o instanceof Byte || t.equals(Byte.TYPE)) { + putLong(((Byte) o).longValue()); + return true; } - - private void putDecimal(BigDecimal d) throws IOException { - out.write(OpCodeConstant.GLOBAL); - out.write("decimal\nDecimal\n".getBytes()); - putString(d.toEngineeringString()); - out.write(OpCodeConstant.TUPLE1); - out.write(OpCodeConstant.REDUCE); - writeMemo(d); + if (o instanceof Short || t.equals(Short.TYPE)) { + putLong(((Short) o).longValue()); + return true; } - - - private void putBigint(BigInteger i) throws IOException { - byte[] b = PickleUtils.encodeLong(i); - if (b.length <= 0xff) { - out.write(OpCodeConstant.LONG1); - out.write(b.length); - out.write(b); - } else { - out.write(OpCodeConstant.LONG4); - out.write(PickleUtils.integer2Bytes(b.length)); - out.write(b); - } - writeMemo(i); + if (o instanceof Integer || t.equals(Integer.TYPE)) { + putLong(((Integer) o).longValue()); + return true; + } + if (o instanceof Long || t.equals(Long.TYPE)) { + putLong(((Long) o).longValue()); + return true; + } + if (o instanceof Float || t.equals(Float.TYPE)) { + putFloat(((Float) o).doubleValue()); + return true; + } + if (o instanceof Double || t.equals(Double.TYPE)) { + putFloat(((Double) o).doubleValue()); + return true; + } + if (o instanceof Character || t.equals(Character.TYPE)) { + putString("" + o); + return true; } - private void putString(String string) throws IOException { - byte[] encoded = string.getBytes(StandardCharsets.UTF_8); - out.write(OpCodeConstant.BINUNICODE); - out.write(PickleUtils.integer2Bytes(encoded.length)); - out.write(encoded); - writeMemo(string); + IObjectPickler customPickler = getCustomPickler(t); + if (customPickler != null) { + customPickler.pickle(o, this.out, this); + writeMemo(o); + return true; } - private void putFloat(double d) throws IOException { - out.write(OpCodeConstant.BINFLOAT); - out.write(PickleUtils.double2Bytes(d)); + if (o instanceof String) { + putString((String) o); + return true; } + if (o instanceof BigInteger) { + putBigint((BigInteger) o); + return true; + } + if (o instanceof BigDecimal) { + putDecimal((BigDecimal) o); + return true; + } + if (o instanceof Enum) { + putString(o.toString()); + return true; + } + if (o instanceof Set) { + putSet((Set) o); + return true; + } + if (o instanceof Map) { + putMap((Map) o); + return true; + } + if (o instanceof List) { + putCollection((List) o); + return true; + } + if (o instanceof Collection) { + putCollection((Collection) o); + return true; + } + if (o instanceof java.io.Serializable) { + putJavabean(o); + return true; + } + return false; + } - private void putLong(long v) throws IOException { - if (v >= 0) { - if (v <= 0xff) { - out.write(OpCodeConstant.BININT1); - out.write((int) v); - return; - } - if (v <= 0xffff) { - out.write(OpCodeConstant.BININT2); - out.write((int) v & 0xff); - out.write((int) v >> 8); - return; - } + private synchronized IObjectPickler getCustomPickler(Class t) { + IObjectPickler pickler = CUSTOM_PICKLER.get(t); + if (pickler != null) { + return pickler; + } + for (Entry, IObjectPickler> x : CUSTOM_PICKLER.entrySet()) { + if (x.getKey().isAssignableFrom(t)) { + return x.getValue(); + } + } + return null; + } + + private void putCollection(Collection list) throws IOException { + out.write(OpCodeConstant.EMPTY_LIST); + writeMemo(list); + out.write(OpCodeConstant.MARK); + for (Object o : list) { + save(o); + } + out.write(OpCodeConstant.APPENDS); + } + + private void putMap(Map o) throws IOException { + out.write(OpCodeConstant.EMPTY_DICT); + writeMemo(o); + out.write(OpCodeConstant.MARK); + for (Object k : o.keySet()) { + save(k); + save(o.get(k)); + } + out.write(OpCodeConstant.SETITEMS); + } + + private void putSet(Set o) throws IOException { + out.write(OpCodeConstant.GLOBAL); + out.write("__builtin__\nset\n".getBytes()); + out.write(OpCodeConstant.EMPTY_LIST); + out.write(OpCodeConstant.MARK); + for (Object x : o) { + save(x); + } + out.write(OpCodeConstant.APPENDS); + out.write(OpCodeConstant.TUPLE1); + out.write(OpCodeConstant.REDUCE); + writeMemo(o); + } + + private void putArrayOfObjects(Object[] array) throws IOException { + if (array.length == 0) { + out.write(OpCodeConstant.EMPTY_TUPLE); + } else if (array.length == 1) { + if (array[0] == array) { + throw new PickleException("recursive array not supported, use list"); + } + save(array[0]); + out.write(OpCodeConstant.TUPLE1); + } else if (array.length == 2) { + if (array[0] == array || array[1] == array) { + throw new PickleException("recursive array not supported, use list"); + } + save(array[0]); + save(array[1]); + out.write(OpCodeConstant.TUPLE2); + } else if (array.length == 3) { + if (array[0] == array || array[1] == array || array[2] == array) { + throw new PickleException("recursive array not supported, use list"); + } + save(array[0]); + save(array[1]); + save(array[2]); + out.write(OpCodeConstant.TUPLE3); + } else { + out.write(OpCodeConstant.MARK); + for (Object o : array) { + if (o == array) { + throw new PickleException("recursive array not supported, use list"); } + save(o); + } + out.write(OpCodeConstant.TUPLE); + } + writeMemo(array); + } + + private void putArrayOfPrimitives(Class t, Object array) throws IOException { + if (t.equals(Boolean.TYPE)) { + boolean[] source = (boolean[]) array; + Boolean[] boolArray = new Boolean[source.length]; + for (int i = 0; i < source.length; ++i) { + boolArray[i] = source[i]; + } + putArrayOfObjects(boolArray); + return; + } + if (t.equals(Character.TYPE)) { + String s = new String((char[]) array); + putString(s); + return; + } + if (t.equals(Byte.TYPE)) { + out.write(OpCodeConstant.GLOBAL); + out.write("__builtin__\nbytearray\n".getBytes()); + String str = PickleUtils.rawStringFromBytes((byte[]) array); + putString(str); + putString("latin-1"); + out.write(OpCodeConstant.TUPLE2); + out.write(OpCodeConstant.REDUCE); + writeMemo(array); + return; + } - long highBits = v >> 31; - if (highBits == 0 || highBits == -1) { - out.write(OpCodeConstant.BININT); - out.write(PickleUtils.integer2Bytes((int) v)); - return; - } - putBigint(BigInteger.valueOf(v)); + out.write(OpCodeConstant.GLOBAL); + out.write("array\narray\n".getBytes()); + out.write(OpCodeConstant.SHORT_BINSTRING); + out.write(1); + + if (t.equals(Short.TYPE)) { + out.write('h'); + out.write(OpCodeConstant.EMPTY_LIST); + out.write(OpCodeConstant.MARK); + for (short s : (short[]) array) { + save(s); + } + } else if (t.equals(Integer.TYPE)) { + out.write('i'); + out.write(OpCodeConstant.EMPTY_LIST); + out.write(OpCodeConstant.MARK); + for (int i : (int[]) array) { + save(i); + } + } else if (t.equals(Long.TYPE)) { + out.write('l'); + out.write(OpCodeConstant.EMPTY_LIST); + out.write(OpCodeConstant.MARK); + for (long v : (long[]) array) { + save(v); + } + } else if (t.equals(Float.TYPE)) { + out.write('f'); + out.write(OpCodeConstant.EMPTY_LIST); + out.write(OpCodeConstant.MARK); + for (float f : (float[]) array) { + save(f); + } + } else if (t.equals(Double.TYPE)) { + out.write('d'); + out.write(OpCodeConstant.EMPTY_LIST); + out.write(OpCodeConstant.MARK); + for (double d : (double[]) array) { + save(d); + } } - private void putBool(boolean b) throws IOException { - if (b) { - out.write(OpCodeConstant.NEWTRUE); - } else { - out.write(OpCodeConstant.NEWFALSE); - } + out.write(OpCodeConstant.APPENDS); + out.write(OpCodeConstant.TUPLE2); + out.write(OpCodeConstant.REDUCE); + + writeMemo(array); + } + + private void putDecimal(BigDecimal d) throws IOException { + out.write(OpCodeConstant.GLOBAL); + out.write("decimal\nDecimal\n".getBytes()); + putString(d.toEngineeringString()); + out.write(OpCodeConstant.TUPLE1); + out.write(OpCodeConstant.REDUCE); + writeMemo(d); + } + + private void putBigint(BigInteger i) throws IOException { + byte[] b = PickleUtils.encodeLong(i); + if (b.length <= 0xff) { + out.write(OpCodeConstant.LONG1); + out.write(b.length); + out.write(b); + } else { + out.write(OpCodeConstant.LONG4); + out.write(PickleUtils.integer2Bytes(b.length)); + out.write(b); + } + writeMemo(i); + } + + private void putString(String string) throws IOException { + byte[] encoded = string.getBytes(StandardCharsets.UTF_8); + out.write(OpCodeConstant.BINUNICODE); + out.write(PickleUtils.integer2Bytes(encoded.length)); + out.write(encoded); + writeMemo(string); + } + + private void putFloat(double d) throws IOException { + out.write(OpCodeConstant.BINFLOAT); + out.write(PickleUtils.double2Bytes(d)); + } + + private void putLong(long v) throws IOException { + if (v >= 0) { + if (v <= 0xff) { + out.write(OpCodeConstant.BININT1); + out.write((int) v); + return; + } + if (v <= 0xffff) { + out.write(OpCodeConstant.BININT2); + out.write((int) v & 0xff); + out.write((int) v >> 8); + return; + } } - private void putJavabean(Object o) throws PickleException, IOException { - Map map = new HashMap<>(); - try { - for (Method m : o.getClass().getMethods()) { - int modifiers = m.getModifiers(); - if ((modifiers & Modifier.PUBLIC) != 0 && (modifiers & Modifier.STATIC) == 0) { - String methodName = m.getName(); - int prefixLen; - if (methodName.equals(GET_CLASS)) { - continue; - } - if (methodName.startsWith(GET)) { - prefixLen = 3; - } else if (methodName.startsWith(IS)) { - prefixLen = 2; - } else { - continue; - } - Object value = m.invoke(o); - String name = methodName.substring(prefixLen); - if (name.length() == 1) { - name = name.toLowerCase(); - } else { - if (!Character.isUpperCase(name.charAt(1))) { - name = Character.toLowerCase(name.charAt(0)) + name.substring(1); - } - } - map.put(name, value); - } + long highBits = v >> 31; + if (highBits == 0 || highBits == -1) { + out.write(OpCodeConstant.BININT); + out.write(PickleUtils.integer2Bytes((int) v)); + return; + } + putBigint(BigInteger.valueOf(v)); + } + + private void putBool(boolean b) throws IOException { + if (b) { + out.write(OpCodeConstant.NEWTRUE); + } else { + out.write(OpCodeConstant.NEWFALSE); + } + } + + private void putJavabean(Object o) throws PickleException, IOException { + Map map = new HashMap<>(); + try { + for (Method m : o.getClass().getMethods()) { + int modifiers = m.getModifiers(); + if ((modifiers & Modifier.PUBLIC) != 0 && (modifiers & Modifier.STATIC) == 0) { + String methodName = m.getName(); + int prefixLen; + if (methodName.equals(GET_CLASS)) { + continue; + } + if (methodName.startsWith(GET)) { + prefixLen = 3; + } else if (methodName.startsWith(IS)) { + prefixLen = 2; + } else { + continue; + } + Object value = m.invoke(o); + String name = methodName.substring(prefixLen); + if (name.length() == 1) { + name = name.toLowerCase(); + } else { + if (!Character.isUpperCase(name.charAt(1))) { + name = Character.toLowerCase(name.charAt(0)) + name.substring(1); } - map.put(CLASS_KEY, o.getClass().getName()); - save(map); - } catch (IllegalArgumentException | IllegalAccessException | InvocationTargetException e) { - throw new PickleException("couldn't introspect javabean: " + e); - } + } + map.put(name, value); + } + } + map.put(CLASS_KEY, o.getClass().getName()); + save(map); + } catch (IllegalArgumentException | IllegalAccessException | InvocationTargetException e) { + throw new PickleException("couldn't introspect javabean: " + e); } + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/PythonException.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/PythonException.java index 66d95b3ac..1e7e162a4 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/PythonException.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/PythonException.java @@ -21,40 +21,41 @@ import java.util.HashMap; import java.util.List; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class PythonException extends GeaflowRuntimeException { - private static final long serialVersionUID = 4884843316742683086L; + private static final long serialVersionUID = 4884843316742683086L; - private static final String TRACEBACK = "_pyroTraceback"; + private static final String TRACEBACK = "_pyroTraceback"; - public String errorTraceback; + public String errorTraceback; - public String pythonExceptionType; + public String pythonExceptionType; - public PythonException(String message, Throwable cause) { - super(message, cause); - } + public PythonException(String message, Throwable cause) { + super(message, cause); + } - public PythonException(String message) { - super(message); - } + public PythonException(String message) { + super(message); + } - public PythonException(Throwable cause) { - super(cause); - } + public PythonException(Throwable cause) { + super(cause); + } - public void setState(HashMap args) { - Object tb = args.get(TRACEBACK); - if (tb instanceof List) { - StringBuilder sb = new StringBuilder(); - for (Object line : (List) tb) { - sb.append(line); - } - errorTraceback = sb.toString(); - } else { - errorTraceback = (String) tb; - } + public void setState(HashMap args) { + Object tb = args.get(TRACEBACK); + if (tb instanceof List) { + StringBuilder sb = new StringBuilder(); + for (Object line : (List) tb) { + sb.append(line); + } + errorTraceback = sb.toString(); + } else { + errorTraceback = (String) tb; } + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/UnpickleStack.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/UnpickleStack.java index 4af0234e7..bf6a18348 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/UnpickleStack.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/UnpickleStack.java @@ -26,56 +26,56 @@ public class UnpickleStack implements Serializable { - private static final long serialVersionUID = 5032718425413805422L; - private final ArrayList stack; - protected Object marker; + private static final long serialVersionUID = 5032718425413805422L; + private final ArrayList stack; + protected Object marker; - public UnpickleStack() { - stack = new ArrayList<>(); - marker = new Object(); - } + public UnpickleStack() { + stack = new ArrayList<>(); + marker = new Object(); + } - public void add(Object o) { - this.stack.add(o); - } + public void add(Object o) { + this.stack.add(o); + } - public void addMark() { - this.stack.add(this.marker); - } + public void addMark() { + this.stack.add(this.marker); + } - public Object pop() { - int size = this.stack.size(); - Object result = this.stack.get(size - 1); - this.stack.remove(size - 1); - return result; - } + public Object pop() { + int size = this.stack.size(); + Object result = this.stack.get(size - 1); + this.stack.remove(size - 1); + return result; + } - public List popAllSinceMarker() { - ArrayList result = new ArrayList<>(); - Object o = pop(); - while (o != this.marker) { - result.add(o); - o = pop(); - } - result.trimToSize(); - Collections.reverse(result); - return result; + public List popAllSinceMarker() { + ArrayList result = new ArrayList<>(); + Object o = pop(); + while (o != this.marker) { + result.add(o); + o = pop(); } + result.trimToSize(); + Collections.reverse(result); + return result; + } - public Object peek() { - return this.stack.get(this.stack.size() - 1); - } + public Object peek() { + return this.stack.get(this.stack.size() - 1); + } - public void trim() { - this.stack.trimToSize(); - } + public void trim() { + this.stack.trimToSize(); + } - public int size() { - return this.stack.size(); - } + public int size() { + return this.stack.size(); + } - public void clear() { - this.stack.clear(); - this.stack.trimToSize(); - } + public void clear() { + this.stack.clear(); + this.stack.trimToSize(); + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/Unpickler.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/Unpickler.java index 1c70b350a..7a87b2958 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/Unpickler.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/exchange/serialize/Unpickler.java @@ -30,766 +30,772 @@ import java.util.HashSet; import java.util.List; import java.util.Map; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class Unpickler { - private static final Logger LOGGER = LoggerFactory.getLogger(Unpickler.class); - - private static final String SET_STATE = "setState"; - - private static final String PYTHON_EXCEPTION = "python_exception"; - - private static final String MODULE_SUFFIX = "."; - - private static final String LONG_END_WITH = "L"; - - private static final String EXCEPTIONS = "exceptions"; - - private static final String ERROR = "Error"; - - private static final String EXCEPTION = "Exception"; - - private static final String WARNING = "Warning"; - - private static final String SYSTEM_EXIT = "SystemExit"; - - private static final String GENERATOR_EXIT = "GeneratorExit"; - - private static final String KEYBOARD_INTERRUPT = "KeyboardInterrupt"; - - private static final String STOP_ITERATION = "StopIteration"; - - private static final String BUILTINS = "builtins"; - - private static final String BUILTIN = "__builtin__"; - - protected static final Object NO_RETURN_VALUE = new Object(); - - protected static final int HIGHEST_PROTOCOL = 5; - - protected Map memo; - - protected UnpickleStack stack; - - protected InputStream input; - - protected static Map objectConstructors; - - static { - objectConstructors = new HashMap<>(); - objectConstructors.put("__builtin__.complex", new ComplexNumberConstructor()); - objectConstructors.put("builtins.complex", new ComplexNumberConstructor()); - objectConstructors.put("array.array", new ArrayConstructor()); - objectConstructors.put("array._array_reconstructor", new ArrayConstructor()); - objectConstructors.put("__builtin__.bytearray", new ByteArrayConstructor()); - objectConstructors.put("builtins.bytearray", new ByteArrayConstructor()); - objectConstructors.put("__builtin__.bytes", new ByteArrayConstructor()); - objectConstructors.put("_codecs.encode", new ByteArrayConstructor()); - } - - public Unpickler() { - memo = new HashMap<>(); - stack = new UnpickleStack(); - } - - public static void registerConstructor(String module, String classname, IObjectConstructor constructor) { - objectConstructors.put(module + MODULE_SUFFIX + classname, constructor); - } - - public Object load(InputStream stream) throws PickleException, IOException { - input = stream; - while (true) { - short key = PickleUtils.readByte(input); - if (key == -1) { - throw new IOException("premature end of file"); - } - Object value = dispatch(key); - if (value != NO_RETURN_VALUE) { - return value; - } - } - } - - - public Object loads(byte[] pickleData) throws Exception { - Object loadResult = load(new ByteArrayInputStream(pickleData)); - if (loadResult instanceof String) { - if (loadResult.toString().startsWith(PYTHON_EXCEPTION)) { - throw new RuntimeException(loadResult.toString()); - } - } - return loadResult; - } - - public void close() { - if (stack != null) { - stack.clear(); - } - if (memo != null) { - memo.clear(); - } - if (input != null) { - try { - input.close(); - } catch (IOException e) { - LOGGER.error("input closed fail"); - } - } - } - - protected Object nextBuffer() throws PickleException { - throw new PickleException("pickle stream refers to out-of-band data " - + "but no user-overridden nextBuffer() method is used"); - } - - protected Object dispatch(short key) throws PickleException, IOException { - switch (key) { - case OpCodeConstant.MARK: - loadMark(); - break; - case OpCodeConstant.STOP: - Object value = stack.pop(); - stack.clear(); - memo.clear(); - return value; - case OpCodeConstant.POP: - loadPop(); - break; - case OpCodeConstant.POP_MARK: - loadPopMark(); - break; - case OpCodeConstant.DUP: - loadDup(); - break; - case OpCodeConstant.FLOAT: - loadFloat(); - break; - case OpCodeConstant.INT: - loadInt(); - break; - case OpCodeConstant.BININT: - loadBinint(); - break; - case OpCodeConstant.BININT1: - loadBinint1(); - break; - case OpCodeConstant.LONG: - loadLong(); - break; - case OpCodeConstant.BININT2: - loadBinint2(); - break; - case OpCodeConstant.NONE: - loadNone(); - break; - case OpCodeConstant.PERSID: - loadPersid(); - break; - case OpCodeConstant.BINPERSID: - loadBinpersid(); - break; - case OpCodeConstant.REDUCE: - loadReduce(); - break; - case OpCodeConstant.STRING: - loadString(); - break; - case OpCodeConstant.BINSTRING: - loadBinstring(); - break; - case OpCodeConstant.SHORT_BINSTRING: - loadShortBinstring(); - break; - case OpCodeConstant.UNICODE: - loadUnicode(); - break; - case OpCodeConstant.BINUNICODE: - loadBinunicode(); - break; - case OpCodeConstant.APPEND: - loadAppend(); - break; - case OpCodeConstant.BUILD: - loadBuild(); - break; - case OpCodeConstant.GLOBAL: - loadGlobal(); - break; - case OpCodeConstant.DICT: - loadDict(); - break; - case OpCodeConstant.EMPTY_DICT: - loadEmptyDictionary(); - break; - case OpCodeConstant.APPENDS: - loadAppends(); - break; - case OpCodeConstant.GET: - loadGet(); - break; - case OpCodeConstant.BINGET: - loadBinget(); - break; - case OpCodeConstant.INST: - loadInst(); - break; - case OpCodeConstant.LONG_BINGET: - loadLongBinget(); - break; - case OpCodeConstant.LIST: - loadList(); - break; - case OpCodeConstant.EMPTY_LIST: - loadEmptyList(); - break; - case OpCodeConstant.OBJ: - loadObj(); - break; - case OpCodeConstant.PUT: - loadPut(); - break; - case OpCodeConstant.BINPUT: - loadBinput(); - break; - case OpCodeConstant.LONG_BINPUT: - loadLongBinput(); - break; - case OpCodeConstant.SETITEM: - loadSetitem(); - break; - case OpCodeConstant.TUPLE: - loadTuple(); - break; - case OpCodeConstant.EMPTY_TUPLE: - loadEmptyTuple(); - break; - case OpCodeConstant.SETITEMS: - loadSetitems(); - break; - case OpCodeConstant.BINFLOAT: - loadBinfloat(); - break; - - case OpCodeConstant.PROTO: - loadProto(); - break; - case OpCodeConstant.NEWOBJ: - loadNewBbj(); - break; - case OpCodeConstant.EXT1: - case OpCodeConstant.EXT2: - case OpCodeConstant.EXT4: - throw new PickleException("Unimplemented opcode EXT1/EXT2/EXT4 encountered."); - case OpCodeConstant.TUPLE1: - loadTuple1(); - break; - case OpCodeConstant.TUPLE2: - loadTuple2(); - break; - case OpCodeConstant.TUPLE3: - loadTuple3(); - break; - case OpCodeConstant.NEWTRUE: - loadTrue(); - break; - case OpCodeConstant.NEWFALSE: - loadFalse(); - break; - case OpCodeConstant.LONG1: - loadLong1(); - break; - case OpCodeConstant.LONG4: - loadLong4(); - break; - - case OpCodeConstant.BINBYTES: - loadBinBytes(); - break; - case OpCodeConstant.SHORT_BINBYTES: - loadShortBinBytes(); - break; - - case OpCodeConstant.BINUNICODE8: - loadBinunicode8(); - break; - case OpCodeConstant.SHORT_BINUNICODE: - loadShortBinunicode(); - break; - case OpCodeConstant.BINBYTES8: - loadBinBytes8(); - break; - case OpCodeConstant.EMPTY_SET: - loadEmptySet(); - break; - case OpCodeConstant.ADDITEMS: - loadAddItems(); - break; - case OpCodeConstant.FROZENSET: - loadFrozenset(); - break; - case OpCodeConstant.MEMOIZE: - loadMemoize(); - break; - case OpCodeConstant.FRAME: - loadFrame(); - break; - case OpCodeConstant.NEWOBJ_EX: - loadNewObjEx(); - break; - case OpCodeConstant.STACK_GLOBAL: - loadStackGlobal(); - break; - - case OpCodeConstant.BYTEARRAY8: - loadByteArray8(); - break; - case OpCodeConstant.READONLY_BUFFER: - loadReadonlyBuffer(); - break; - case OpCodeConstant.NEXT_BUFFER: - loadNextBuffer(); - break; - - default: - throw new InvalidOpcodeException("invalid pickle opcode: " + key); - } + private static final Logger LOGGER = LoggerFactory.getLogger(Unpickler.class); - return NO_RETURN_VALUE; - } + private static final String SET_STATE = "setState"; - private void loadReadonlyBuffer() { - } + private static final String PYTHON_EXCEPTION = "python_exception"; - private void loadNextBuffer() throws PickleException { - stack.add(nextBuffer()); - } + private static final String MODULE_SUFFIX = "."; - private void loadByteArray8() throws IOException { - long len = PickleUtils.bytes2Long(PickleUtils.readBytes(input, 8), 0); - stack.add(PickleUtils.readBytes(input, len)); - } + private static final String LONG_END_WITH = "L"; - private void loadBuild() { - Object args = stack.pop(); - Object target = stack.peek(); - try { - Method setStateMethod = target.getClass().getMethod(SET_STATE, args.getClass()); - setStateMethod.invoke(target, args); - } catch (Exception e) { - throw new PickleException("failed to setState()", e); - } - } - - private void loadProto() throws IOException { - short proto = PickleUtils.readByte(input); - if (proto < 0 || proto > HIGHEST_PROTOCOL) { - throw new PickleException("unsupported pickle protocol: " + proto); - } - } - - private void loadNone() { - stack.add(null); - } - - private void loadFalse() { - stack.add(false); - } - - private void loadTrue() { - stack.add(true); - } - - private void loadInt() throws IOException { - String data = PickleUtils.readLine(input, true); - Object val; - if (data.equals(OpCodeConstant.FALSE.substring(1))) { - val = false; - } else if (data.equals(OpCodeConstant.TRUE.substring(1))) { - val = true; - } else { - String number = data.substring(0, data.length() - 1); - try { - val = Integer.parseInt(number, 10); - } catch (NumberFormatException x) { - val = Long.parseLong(number, 10); - } - } - stack.add(val); - } + private static final String EXCEPTIONS = "exceptions"; - private void loadBinint() throws IOException { - int integer = PickleUtils.bytes2Integer(PickleUtils.readBytes(input, 4)); - stack.add(integer); - } + private static final String ERROR = "Error"; - private void loadBinint1() throws IOException { - stack.add((int) PickleUtils.readByte(input)); - } + private static final String EXCEPTION = "Exception"; - private void loadBinint2() throws IOException { - int integer = PickleUtils.bytes2Integer(PickleUtils.readBytes(input, 2)); - stack.add(integer); - } + private static final String WARNING = "Warning"; - private void loadLong() throws IOException { - String val = PickleUtils.readLine(input); - if (val.endsWith(LONG_END_WITH)) { - val = val.substring(0, val.length() - 1); - } - BigInteger bi = new BigInteger(val); - stack.add(PickleUtils.optimizeBigint(bi)); - } + private static final String SYSTEM_EXIT = "SystemExit"; - private void loadLong1() throws IOException { - short n = PickleUtils.readByte(input); - byte[] data = PickleUtils.readBytes(input, n); - stack.add(PickleUtils.decodeLong(data)); - } + private static final String GENERATOR_EXIT = "GeneratorExit"; - private void loadLong4() throws IOException { - int n = PickleUtils.bytes2Integer(PickleUtils.readBytes(input, 4)); - byte[] data = PickleUtils.readBytes(input, n); - stack.add(PickleUtils.decodeLong(data)); - } + private static final String KEYBOARD_INTERRUPT = "KeyboardInterrupt"; - private void loadFloat() throws IOException { - String val = PickleUtils.readLine(input, true); - stack.add(Double.parseDouble(val)); - } + private static final String STOP_ITERATION = "StopIteration"; - private void loadBinfloat() throws IOException { - double val = PickleUtils.bytes2Double(PickleUtils.readBytes(input, 8), 0); - stack.add(val); - } + private static final String BUILTINS = "builtins"; - private void loadString() throws IOException { - String rep = PickleUtils.readLine(input); - boolean quotesOk = false; - for (String q : new String[]{"\"", "'"}) { - if (rep.startsWith(q)) { - if (!rep.endsWith(q)) { - throw new PickleException("insecure string pickle"); - } - rep = rep.substring(1, rep.length() - 1); - quotesOk = true; - break; - } - } + private static final String BUILTIN = "__builtin__"; - if (!quotesOk) { - throw new PickleException("insecure string pickle"); - } + protected static final Object NO_RETURN_VALUE = new Object(); - stack.add(PickleUtils.decodeEscaped(rep)); - } + protected static final int HIGHEST_PROTOCOL = 5; - private void loadBinstring() throws IOException { - int len = PickleUtils.bytes2Integer(PickleUtils.readBytes(input, 4)); - byte[] data = PickleUtils.readBytes(input, len); - stack.add(PickleUtils.rawStringFromBytes(data)); - } + protected Map memo; - private void loadBinBytes() throws IOException { - int len = PickleUtils.bytes2Integer(PickleUtils.readBytes(input, 4)); - stack.add(PickleUtils.readBytes(input, len)); - } + protected UnpickleStack stack; - private void loadBinBytes8() throws IOException { - long len = PickleUtils.bytes2Long(PickleUtils.readBytes(input, 8), 0); - stack.add(PickleUtils.readBytes(input, len)); - } + protected InputStream input; - private void loadUnicode() throws IOException { - String str = PickleUtils.decodeUnicodeEscaped(PickleUtils.readLine(input)); - stack.add(str); - } + protected static Map objectConstructors; - private void loadBinunicode() throws IOException { - int len = PickleUtils.bytes2Integer(PickleUtils.readBytes(input, 4)); - byte[] data = PickleUtils.readBytes(input, len); - stack.add(new String(data, StandardCharsets.UTF_8)); - } + static { + objectConstructors = new HashMap<>(); + objectConstructors.put("__builtin__.complex", new ComplexNumberConstructor()); + objectConstructors.put("builtins.complex", new ComplexNumberConstructor()); + objectConstructors.put("array.array", new ArrayConstructor()); + objectConstructors.put("array._array_reconstructor", new ArrayConstructor()); + objectConstructors.put("__builtin__.bytearray", new ByteArrayConstructor()); + objectConstructors.put("builtins.bytearray", new ByteArrayConstructor()); + objectConstructors.put("__builtin__.bytes", new ByteArrayConstructor()); + objectConstructors.put("_codecs.encode", new ByteArrayConstructor()); + } - private void loadBinunicode8() throws IOException { - long len = PickleUtils.bytes2Long(PickleUtils.readBytes(input, 8), 0); - byte[] data = PickleUtils.readBytes(input, len); - stack.add(new String(data, StandardCharsets.UTF_8)); - } + public Unpickler() { + memo = new HashMap<>(); + stack = new UnpickleStack(); + } - private void loadShortBinunicode() throws IOException { - int len = PickleUtils.readByte(input); - byte[] data = PickleUtils.readBytes(input, len); - stack.add(new String(data, StandardCharsets.UTF_8)); - } + public static void registerConstructor( + String module, String classname, IObjectConstructor constructor) { + objectConstructors.put(module + MODULE_SUFFIX + classname, constructor); + } - private void loadShortBinstring() throws IOException { - short len = PickleUtils.readByte(input); - byte[] data = PickleUtils.readBytes(input, len); - stack.add(PickleUtils.rawStringFromBytes(data)); + public Object load(InputStream stream) throws PickleException, IOException { + input = stream; + while (true) { + short key = PickleUtils.readByte(input); + if (key == -1) { + throw new IOException("premature end of file"); + } + Object value = dispatch(key); + if (value != NO_RETURN_VALUE) { + return value; + } } + } - private void loadShortBinBytes() throws IOException { - short len = PickleUtils.readByte(input); - stack.add(PickleUtils.readBytes(input, len)); + public Object loads(byte[] pickleData) throws Exception { + Object loadResult = load(new ByteArrayInputStream(pickleData)); + if (loadResult instanceof String) { + if (loadResult.toString().startsWith(PYTHON_EXCEPTION)) { + throw new RuntimeException(loadResult.toString()); + } } + return loadResult; + } - private void loadTuple() { - List top = stack.popAllSinceMarker(); - stack.add(top.toArray()); + public void close() { + if (stack != null) { + stack.clear(); } - - private void loadEmptyTuple() { - stack.add(new Object[0]); + if (memo != null) { + memo.clear(); } - - private void loadTuple1() { - stack.add(new Object[]{stack.pop()}); + if (input != null) { + try { + input.close(); + } catch (IOException e) { + LOGGER.error("input closed fail"); + } } + } - private void loadTuple2() { - Object o2 = stack.pop(); - Object o1 = stack.pop(); - stack.add(new Object[]{o1, o2}); - } - - private void loadTuple3() { - Object o3 = stack.pop(); - Object o2 = stack.pop(); - Object o1 = stack.pop(); - stack.add(new Object[]{o1, o2, o3}); - } - - private void loadEmptyList() { - stack.add(new ArrayList<>(0)); - } - - private void loadEmptyDictionary() { - stack.add(new HashMap<>(0)); - } - - private void loadEmptySet() { - stack.add(new HashSet<>()); - } - - private void loadList() { - List top = stack.popAllSinceMarker(); - stack.add(top); - } - - private void loadDict() { - List top = stack.popAllSinceMarker(); - HashMap map = new HashMap<>(top.size()); - for (int i = 0; i < top.size(); i += 2) { - Object key = top.get(i); - Object value = top.get(i + 1); - map.put(key, value); - } - stack.add(map); - } - - private void loadFrozenset() { - List top = stack.popAllSinceMarker(); - HashSet set = new HashSet<>(); - set.addAll(top); - stack.add(set); - } - - private void loadAddItems() { - List top = stack.popAllSinceMarker(); - HashSet set = (HashSet) stack.pop(); - set.addAll(top); - stack.add(set); - } - - private void loadGlobal() throws IOException { - String module = PickleUtils.readLine(input); - String name = PickleUtils.readLine(input); - loadGlobalSub(module, name); - } - - private void loadStackGlobal() { - String name = (String) stack.pop(); - String module = (String) stack.pop(); - loadGlobalSub(module, name); - } - - private void loadGlobalSub(String module, String name) { - IObjectConstructor constructor = objectConstructors.get(module + MODULE_SUFFIX + name); - if (constructor == null) { - if (module.equals(EXCEPTIONS)) { - constructor = new ExceptionConstructor(PythonException.class, module, name); - } else if (module.equals(BUILTIN) || module.equals(BUILTINS)) { - if (name.endsWith(ERROR) || name.endsWith(WARNING) - || name.endsWith(EXCEPTION) || name.equals(KEYBOARD_INTERRUPT) - || name.equals(STOP_ITERATION) || name.equals(GENERATOR_EXIT) - || name.equals(SYSTEM_EXIT)) { - constructor = new ExceptionConstructor(PythonException.class, module, name); - } else { - constructor = new DictionaryConstructor(module, name); - } - } else { - constructor = new DictionaryConstructor(module, name); - } - } - stack.add(constructor); - } - - private void loadPop() { - stack.pop(); - } - - private void loadPopMark() { - Object o; - do { - o = stack.pop(); - } while (o != stack.marker); - stack.trim(); - } - - private void loadDup() { - stack.add(stack.peek()); - } - - private void loadGet() throws IOException { - int i = Integer.parseInt(PickleUtils.readLine(input), 10); - if (!memo.containsKey(i)) { - throw new PickleException("invalid memo key"); - } - stack.add(memo.get(i)); - } - - private void loadBinget() throws IOException { - int i = PickleUtils.readByte(input); - if (!memo.containsKey(i)) { - throw new PickleException("invalid memo key"); - } - stack.add(memo.get(i)); - } - - private void loadLongBinget() throws IOException { - int i = PickleUtils.bytes2Integer(PickleUtils.readBytes(input, 4)); - if (!memo.containsKey(i)) { - throw new PickleException("invalid memo key"); - } - stack.add(memo.get(i)); - } - - private void loadPut() throws IOException { - int i = Integer.parseInt(PickleUtils.readLine(input), 10); - memo.put(i, stack.peek()); - } - - private void loadBinput() throws IOException { - int i = PickleUtils.readByte(input); - memo.put(i, stack.peek()); - } - - private void loadLongBinput() throws IOException { - int i = PickleUtils.bytes2Integer(PickleUtils.readBytes(input, 4)); - memo.put(i, stack.peek()); - } - - private void loadMemoize() { - memo.put(memo.size(), stack.peek()); - } - - private void loadAppend() { - Object value = stack.pop(); - ArrayList list = (ArrayList) stack.peek(); - list.add(value); - } - - private void loadAppends() { - List top = stack.popAllSinceMarker(); - ArrayList list = (ArrayList) stack.peek(); - list.addAll(top); - list.trimToSize(); - } - - private void loadSetitem() { - Object value = stack.pop(); - Object key = stack.pop(); - Map dict = (Map) stack.peek(); - dict.put(key, value); - } - - private void loadSetitems() { - HashMap newItems = new HashMap<>(); + protected Object nextBuffer() throws PickleException { + throw new PickleException( + "pickle stream refers to out-of-band data " + + "but no user-overridden nextBuffer() method is used"); + } + + protected Object dispatch(short key) throws PickleException, IOException { + switch (key) { + case OpCodeConstant.MARK: + loadMark(); + break; + case OpCodeConstant.STOP: Object value = stack.pop(); - while (value != stack.marker) { - Object key = stack.pop(); - newItems.put(key, value); - value = stack.pop(); - } - - Map dict = (Map) stack.peek(); - dict.putAll(newItems); - } - - private void loadMark() { - stack.addMark(); - } - - private void loadReduce() { - Object[] args = (Object[]) stack.pop(); - IObjectConstructor constructor = (IObjectConstructor) stack.pop(); - stack.add(constructor.construct(args)); - } - - private void loadNewBbj() { + stack.clear(); + memo.clear(); + return value; + case OpCodeConstant.POP: + loadPop(); + break; + case OpCodeConstant.POP_MARK: + loadPopMark(); + break; + case OpCodeConstant.DUP: + loadDup(); + break; + case OpCodeConstant.FLOAT: + loadFloat(); + break; + case OpCodeConstant.INT: + loadInt(); + break; + case OpCodeConstant.BININT: + loadBinint(); + break; + case OpCodeConstant.BININT1: + loadBinint1(); + break; + case OpCodeConstant.LONG: + loadLong(); + break; + case OpCodeConstant.BININT2: + loadBinint2(); + break; + case OpCodeConstant.NONE: + loadNone(); + break; + case OpCodeConstant.PERSID: + loadPersid(); + break; + case OpCodeConstant.BINPERSID: + loadBinpersid(); + break; + case OpCodeConstant.REDUCE: loadReduce(); - } - - private void loadNewObjEx() { - HashMap kwargs = (HashMap) stack.pop(); - Object[] args = (Object[]) stack.pop(); - IObjectConstructor constructor = (IObjectConstructor) stack.pop(); - if (kwargs.size() == 0) { - stack.add(constructor.construct(args)); - } else { - throw new PickleException("loadNewObjEx with keyword arguments not supported"); + break; + case OpCodeConstant.STRING: + loadString(); + break; + case OpCodeConstant.BINSTRING: + loadBinstring(); + break; + case OpCodeConstant.SHORT_BINSTRING: + loadShortBinstring(); + break; + case OpCodeConstant.UNICODE: + loadUnicode(); + break; + case OpCodeConstant.BINUNICODE: + loadBinunicode(); + break; + case OpCodeConstant.APPEND: + loadAppend(); + break; + case OpCodeConstant.BUILD: + loadBuild(); + break; + case OpCodeConstant.GLOBAL: + loadGlobal(); + break; + case OpCodeConstant.DICT: + loadDict(); + break; + case OpCodeConstant.EMPTY_DICT: + loadEmptyDictionary(); + break; + case OpCodeConstant.APPENDS: + loadAppends(); + break; + case OpCodeConstant.GET: + loadGet(); + break; + case OpCodeConstant.BINGET: + loadBinget(); + break; + case OpCodeConstant.INST: + loadInst(); + break; + case OpCodeConstant.LONG_BINGET: + loadLongBinget(); + break; + case OpCodeConstant.LIST: + loadList(); + break; + case OpCodeConstant.EMPTY_LIST: + loadEmptyList(); + break; + case OpCodeConstant.OBJ: + loadObj(); + break; + case OpCodeConstant.PUT: + loadPut(); + break; + case OpCodeConstant.BINPUT: + loadBinput(); + break; + case OpCodeConstant.LONG_BINPUT: + loadLongBinput(); + break; + case OpCodeConstant.SETITEM: + loadSetitem(); + break; + case OpCodeConstant.TUPLE: + loadTuple(); + break; + case OpCodeConstant.EMPTY_TUPLE: + loadEmptyTuple(); + break; + case OpCodeConstant.SETITEMS: + loadSetitems(); + break; + case OpCodeConstant.BINFLOAT: + loadBinfloat(); + break; + + case OpCodeConstant.PROTO: + loadProto(); + break; + case OpCodeConstant.NEWOBJ: + loadNewBbj(); + break; + case OpCodeConstant.EXT1: + case OpCodeConstant.EXT2: + case OpCodeConstant.EXT4: + throw new PickleException("Unimplemented opcode EXT1/EXT2/EXT4 encountered."); + case OpCodeConstant.TUPLE1: + loadTuple1(); + break; + case OpCodeConstant.TUPLE2: + loadTuple2(); + break; + case OpCodeConstant.TUPLE3: + loadTuple3(); + break; + case OpCodeConstant.NEWTRUE: + loadTrue(); + break; + case OpCodeConstant.NEWFALSE: + loadFalse(); + break; + case OpCodeConstant.LONG1: + loadLong1(); + break; + case OpCodeConstant.LONG4: + loadLong4(); + break; + + case OpCodeConstant.BINBYTES: + loadBinBytes(); + break; + case OpCodeConstant.SHORT_BINBYTES: + loadShortBinBytes(); + break; + + case OpCodeConstant.BINUNICODE8: + loadBinunicode8(); + break; + case OpCodeConstant.SHORT_BINUNICODE: + loadShortBinunicode(); + break; + case OpCodeConstant.BINBYTES8: + loadBinBytes8(); + break; + case OpCodeConstant.EMPTY_SET: + loadEmptySet(); + break; + case OpCodeConstant.ADDITEMS: + loadAddItems(); + break; + case OpCodeConstant.FROZENSET: + loadFrozenset(); + break; + case OpCodeConstant.MEMOIZE: + loadMemoize(); + break; + case OpCodeConstant.FRAME: + loadFrame(); + break; + case OpCodeConstant.NEWOBJ_EX: + loadNewObjEx(); + break; + case OpCodeConstant.STACK_GLOBAL: + loadStackGlobal(); + break; + + case OpCodeConstant.BYTEARRAY8: + loadByteArray8(); + break; + case OpCodeConstant.READONLY_BUFFER: + loadReadonlyBuffer(); + break; + case OpCodeConstant.NEXT_BUFFER: + loadNextBuffer(); + break; + + default: + throw new InvalidOpcodeException("invalid pickle opcode: " + key); + } + + return NO_RETURN_VALUE; + } + + private void loadReadonlyBuffer() {} + + private void loadNextBuffer() throws PickleException { + stack.add(nextBuffer()); + } + + private void loadByteArray8() throws IOException { + long len = PickleUtils.bytes2Long(PickleUtils.readBytes(input, 8), 0); + stack.add(PickleUtils.readBytes(input, len)); + } + + private void loadBuild() { + Object args = stack.pop(); + Object target = stack.peek(); + try { + Method setStateMethod = target.getClass().getMethod(SET_STATE, args.getClass()); + setStateMethod.invoke(target, args); + } catch (Exception e) { + throw new PickleException("failed to setState()", e); + } + } + + private void loadProto() throws IOException { + short proto = PickleUtils.readByte(input); + if (proto < 0 || proto > HIGHEST_PROTOCOL) { + throw new PickleException("unsupported pickle protocol: " + proto); + } + } + + private void loadNone() { + stack.add(null); + } + + private void loadFalse() { + stack.add(false); + } + + private void loadTrue() { + stack.add(true); + } + + private void loadInt() throws IOException { + String data = PickleUtils.readLine(input, true); + Object val; + if (data.equals(OpCodeConstant.FALSE.substring(1))) { + val = false; + } else if (data.equals(OpCodeConstant.TRUE.substring(1))) { + val = true; + } else { + String number = data.substring(0, data.length() - 1); + try { + val = Integer.parseInt(number, 10); + } catch (NumberFormatException x) { + val = Long.parseLong(number, 10); + } + } + stack.add(val); + } + + private void loadBinint() throws IOException { + int integer = PickleUtils.bytes2Integer(PickleUtils.readBytes(input, 4)); + stack.add(integer); + } + + private void loadBinint1() throws IOException { + stack.add((int) PickleUtils.readByte(input)); + } + + private void loadBinint2() throws IOException { + int integer = PickleUtils.bytes2Integer(PickleUtils.readBytes(input, 2)); + stack.add(integer); + } + + private void loadLong() throws IOException { + String val = PickleUtils.readLine(input); + if (val.endsWith(LONG_END_WITH)) { + val = val.substring(0, val.length() - 1); + } + BigInteger bi = new BigInteger(val); + stack.add(PickleUtils.optimizeBigint(bi)); + } + + private void loadLong1() throws IOException { + short n = PickleUtils.readByte(input); + byte[] data = PickleUtils.readBytes(input, n); + stack.add(PickleUtils.decodeLong(data)); + } + + private void loadLong4() throws IOException { + int n = PickleUtils.bytes2Integer(PickleUtils.readBytes(input, 4)); + byte[] data = PickleUtils.readBytes(input, n); + stack.add(PickleUtils.decodeLong(data)); + } + + private void loadFloat() throws IOException { + String val = PickleUtils.readLine(input, true); + stack.add(Double.parseDouble(val)); + } + + private void loadBinfloat() throws IOException { + double val = PickleUtils.bytes2Double(PickleUtils.readBytes(input, 8), 0); + stack.add(val); + } + + private void loadString() throws IOException { + String rep = PickleUtils.readLine(input); + boolean quotesOk = false; + for (String q : new String[] {"\"", "'"}) { + if (rep.startsWith(q)) { + if (!rep.endsWith(q)) { + throw new PickleException("insecure string pickle"); } - } - - private void loadFrame() throws IOException { - PickleUtils.readBytes(input, 8); - } - - private void loadPersid() throws IOException { - String pid = PickleUtils.readLine(input); - stack.add(persistentLoad(pid)); - } - - private void loadBinpersid() { - String pid = stack.pop().toString(); - stack.add(persistentLoad(pid)); - } - - private void loadObj() { - List args = stack.popAllSinceMarker(); - IObjectConstructor constructor = (IObjectConstructor) args.get(0); - args = args.subList(1, args.size()); - Object object = constructor.construct(args.toArray()); - stack.add(object); - } - - private void loadInst() throws IOException { - String module = PickleUtils.readLine(input); - String classname = PickleUtils.readLine(input); - List args = stack.popAllSinceMarker(); - IObjectConstructor constructor = objectConstructors.get(module + MODULE_SUFFIX + classname); - if (constructor == null) { - constructor = new DictionaryConstructor(module, classname); - args.clear(); + rep = rep.substring(1, rep.length() - 1); + quotesOk = true; + break; + } + } + + if (!quotesOk) { + throw new PickleException("insecure string pickle"); + } + + stack.add(PickleUtils.decodeEscaped(rep)); + } + + private void loadBinstring() throws IOException { + int len = PickleUtils.bytes2Integer(PickleUtils.readBytes(input, 4)); + byte[] data = PickleUtils.readBytes(input, len); + stack.add(PickleUtils.rawStringFromBytes(data)); + } + + private void loadBinBytes() throws IOException { + int len = PickleUtils.bytes2Integer(PickleUtils.readBytes(input, 4)); + stack.add(PickleUtils.readBytes(input, len)); + } + + private void loadBinBytes8() throws IOException { + long len = PickleUtils.bytes2Long(PickleUtils.readBytes(input, 8), 0); + stack.add(PickleUtils.readBytes(input, len)); + } + + private void loadUnicode() throws IOException { + String str = PickleUtils.decodeUnicodeEscaped(PickleUtils.readLine(input)); + stack.add(str); + } + + private void loadBinunicode() throws IOException { + int len = PickleUtils.bytes2Integer(PickleUtils.readBytes(input, 4)); + byte[] data = PickleUtils.readBytes(input, len); + stack.add(new String(data, StandardCharsets.UTF_8)); + } + + private void loadBinunicode8() throws IOException { + long len = PickleUtils.bytes2Long(PickleUtils.readBytes(input, 8), 0); + byte[] data = PickleUtils.readBytes(input, len); + stack.add(new String(data, StandardCharsets.UTF_8)); + } + + private void loadShortBinunicode() throws IOException { + int len = PickleUtils.readByte(input); + byte[] data = PickleUtils.readBytes(input, len); + stack.add(new String(data, StandardCharsets.UTF_8)); + } + + private void loadShortBinstring() throws IOException { + short len = PickleUtils.readByte(input); + byte[] data = PickleUtils.readBytes(input, len); + stack.add(PickleUtils.rawStringFromBytes(data)); + } + + private void loadShortBinBytes() throws IOException { + short len = PickleUtils.readByte(input); + stack.add(PickleUtils.readBytes(input, len)); + } + + private void loadTuple() { + List top = stack.popAllSinceMarker(); + stack.add(top.toArray()); + } + + private void loadEmptyTuple() { + stack.add(new Object[0]); + } + + private void loadTuple1() { + stack.add(new Object[] {stack.pop()}); + } + + private void loadTuple2() { + Object o2 = stack.pop(); + Object o1 = stack.pop(); + stack.add(new Object[] {o1, o2}); + } + + private void loadTuple3() { + Object o3 = stack.pop(); + Object o2 = stack.pop(); + Object o1 = stack.pop(); + stack.add(new Object[] {o1, o2, o3}); + } + + private void loadEmptyList() { + stack.add(new ArrayList<>(0)); + } + + private void loadEmptyDictionary() { + stack.add(new HashMap<>(0)); + } + + private void loadEmptySet() { + stack.add(new HashSet<>()); + } + + private void loadList() { + List top = stack.popAllSinceMarker(); + stack.add(top); + } + + private void loadDict() { + List top = stack.popAllSinceMarker(); + HashMap map = new HashMap<>(top.size()); + for (int i = 0; i < top.size(); i += 2) { + Object key = top.get(i); + Object value = top.get(i + 1); + map.put(key, value); + } + stack.add(map); + } + + private void loadFrozenset() { + List top = stack.popAllSinceMarker(); + HashSet set = new HashSet<>(); + set.addAll(top); + stack.add(set); + } + + private void loadAddItems() { + List top = stack.popAllSinceMarker(); + HashSet set = (HashSet) stack.pop(); + set.addAll(top); + stack.add(set); + } + + private void loadGlobal() throws IOException { + String module = PickleUtils.readLine(input); + String name = PickleUtils.readLine(input); + loadGlobalSub(module, name); + } + + private void loadStackGlobal() { + String name = (String) stack.pop(); + String module = (String) stack.pop(); + loadGlobalSub(module, name); + } + + private void loadGlobalSub(String module, String name) { + IObjectConstructor constructor = objectConstructors.get(module + MODULE_SUFFIX + name); + if (constructor == null) { + if (module.equals(EXCEPTIONS)) { + constructor = new ExceptionConstructor(PythonException.class, module, name); + } else if (module.equals(BUILTIN) || module.equals(BUILTINS)) { + if (name.endsWith(ERROR) + || name.endsWith(WARNING) + || name.endsWith(EXCEPTION) + || name.equals(KEYBOARD_INTERRUPT) + || name.equals(STOP_ITERATION) + || name.equals(GENERATOR_EXIT) + || name.equals(SYSTEM_EXIT)) { + constructor = new ExceptionConstructor(PythonException.class, module, name); + } else { + constructor = new DictionaryConstructor(module, name); } - Object object = constructor.construct(args.toArray()); - stack.add(object); - } - - protected Object persistentLoad(String pid) { - throw new PickleException("A load persistent id instruction was encountered, " - + "but no persistentLoad function was specified. pid: " + pid); - } + } else { + constructor = new DictionaryConstructor(module, name); + } + } + stack.add(constructor); + } + + private void loadPop() { + stack.pop(); + } + + private void loadPopMark() { + Object o; + do { + o = stack.pop(); + } while (o != stack.marker); + stack.trim(); + } + + private void loadDup() { + stack.add(stack.peek()); + } + + private void loadGet() throws IOException { + int i = Integer.parseInt(PickleUtils.readLine(input), 10); + if (!memo.containsKey(i)) { + throw new PickleException("invalid memo key"); + } + stack.add(memo.get(i)); + } + + private void loadBinget() throws IOException { + int i = PickleUtils.readByte(input); + if (!memo.containsKey(i)) { + throw new PickleException("invalid memo key"); + } + stack.add(memo.get(i)); + } + + private void loadLongBinget() throws IOException { + int i = PickleUtils.bytes2Integer(PickleUtils.readBytes(input, 4)); + if (!memo.containsKey(i)) { + throw new PickleException("invalid memo key"); + } + stack.add(memo.get(i)); + } + + private void loadPut() throws IOException { + int i = Integer.parseInt(PickleUtils.readLine(input), 10); + memo.put(i, stack.peek()); + } + + private void loadBinput() throws IOException { + int i = PickleUtils.readByte(input); + memo.put(i, stack.peek()); + } + + private void loadLongBinput() throws IOException { + int i = PickleUtils.bytes2Integer(PickleUtils.readBytes(input, 4)); + memo.put(i, stack.peek()); + } + + private void loadMemoize() { + memo.put(memo.size(), stack.peek()); + } + + private void loadAppend() { + Object value = stack.pop(); + ArrayList list = (ArrayList) stack.peek(); + list.add(value); + } + + private void loadAppends() { + List top = stack.popAllSinceMarker(); + ArrayList list = (ArrayList) stack.peek(); + list.addAll(top); + list.trimToSize(); + } + + private void loadSetitem() { + Object value = stack.pop(); + Object key = stack.pop(); + Map dict = (Map) stack.peek(); + dict.put(key, value); + } + + private void loadSetitems() { + HashMap newItems = new HashMap<>(); + Object value = stack.pop(); + while (value != stack.marker) { + Object key = stack.pop(); + newItems.put(key, value); + value = stack.pop(); + } + + Map dict = (Map) stack.peek(); + dict.putAll(newItems); + } + + private void loadMark() { + stack.addMark(); + } + + private void loadReduce() { + Object[] args = (Object[]) stack.pop(); + IObjectConstructor constructor = (IObjectConstructor) stack.pop(); + stack.add(constructor.construct(args)); + } + + private void loadNewBbj() { + loadReduce(); + } + + private void loadNewObjEx() { + HashMap kwargs = (HashMap) stack.pop(); + Object[] args = (Object[]) stack.pop(); + IObjectConstructor constructor = (IObjectConstructor) stack.pop(); + if (kwargs.size() == 0) { + stack.add(constructor.construct(args)); + } else { + throw new PickleException("loadNewObjEx with keyword arguments not supported"); + } + } + + private void loadFrame() throws IOException { + PickleUtils.readBytes(input, 8); + } + + private void loadPersid() throws IOException { + String pid = PickleUtils.readLine(input); + stack.add(persistentLoad(pid)); + } + + private void loadBinpersid() { + String pid = stack.pop().toString(); + stack.add(persistentLoad(pid)); + } + + private void loadObj() { + List args = stack.popAllSinceMarker(); + IObjectConstructor constructor = (IObjectConstructor) args.get(0); + args = args.subList(1, args.size()); + Object object = constructor.construct(args.toArray()); + stack.add(object); + } + + private void loadInst() throws IOException { + String module = PickleUtils.readLine(input); + String classname = PickleUtils.readLine(input); + List args = stack.popAllSinceMarker(); + IObjectConstructor constructor = objectConstructors.get(module + MODULE_SUFFIX + classname); + if (constructor == null) { + constructor = new DictionaryConstructor(module, classname); + args.clear(); + } + Object object = constructor.construct(args.toArray()); + stack.add(object); + } + + protected Object persistentLoad(String pid) { + throw new PickleException( + "A load persistent id instruction was encountered, " + + "but no persistentLoad function was specified. pid: " + + pid); + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/log/ProcessErrorOutputLogger.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/log/ProcessErrorOutputLogger.java index 5f5c731bf..85352da9c 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/log/ProcessErrorOutputLogger.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/log/ProcessErrorOutputLogger.java @@ -19,42 +19,43 @@ package org.apache.geaflow.infer.log; -import com.google.common.base.Throwables; import java.io.BufferedReader; import java.io.InputStream; import java.io.InputStreamReader; import java.util.function.Consumer; +import com.google.common.base.Throwables; + public class ProcessErrorOutputLogger implements Runnable { - private static final String LINE_SIGNAL = "\n"; + private static final String LINE_SIGNAL = "\n"; - private final StringBuilder buffer = new StringBuilder(); + private final StringBuilder buffer = new StringBuilder(); - private final InputStream inputStream; + private final InputStream inputStream; - private final Consumer consumer; + private final Consumer consumer; - public ProcessErrorOutputLogger(InputStream inputStream, Consumer consumer) { - this.inputStream = inputStream; - this.consumer = consumer; - } + public ProcessErrorOutputLogger(InputStream inputStream, Consumer consumer) { + this.inputStream = inputStream; + this.consumer = consumer; + } - @Override - public void run() { - try { - BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream)); - String line; - while ((line = reader.readLine()) != null) { - buffer.append(line).append(LINE_SIGNAL); - consumer.accept(line); - } - } catch (Exception e) { - consumer.accept(Throwables.getStackTraceAsString(e)); - } + @Override + public void run() { + try { + BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream)); + String line; + while ((line = reader.readLine()) != null) { + buffer.append(line).append(LINE_SIGNAL); + consumer.accept(line); + } + } catch (Exception e) { + consumer.accept(Throwables.getStackTraceAsString(e)); } + } - public String get() { - return buffer.toString(); - } -} \ No newline at end of file + public String get() { + return buffer.toString(); + } +} diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/log/ProcessLoggerManager.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/log/ProcessLoggerManager.java index f6369fbe8..9140fccb3 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/log/ProcessLoggerManager.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/log/ProcessLoggerManager.java @@ -21,46 +21,50 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; + import org.apache.geaflow.common.utils.ThreadUtil; public class ProcessLoggerManager implements AutoCloseable { - private static final String PROCESS_LOG_PREFIX = "infer-process-log"; + private static final String PROCESS_LOG_PREFIX = "infer-process-log"; - private static final int PROCESS_THREAD_NUM = 2; + private static final int PROCESS_THREAD_NUM = 2; - private final Process process; + private final Process process; - private final ProcessOutputConsumer processOutputConsumer; + private final ProcessOutputConsumer processOutputConsumer; - private ProcessErrorOutputLogger errorOutputLogger; + private ProcessErrorOutputLogger errorOutputLogger; - private final ExecutorService executor; + private final ExecutorService executor; - public ProcessLoggerManager(Process process, - ProcessOutputConsumer processOutputConsumer) { - this.process = process; - this.processOutputConsumer = processOutputConsumer; - this.executor = Executors.newFixedThreadPool(PROCESS_THREAD_NUM, ThreadUtil.namedThreadFactory(true, PROCESS_LOG_PREFIX)); - } + public ProcessLoggerManager(Process process, ProcessOutputConsumer processOutputConsumer) { + this.process = process; + this.processOutputConsumer = processOutputConsumer; + this.executor = + Executors.newFixedThreadPool( + PROCESS_THREAD_NUM, ThreadUtil.namedThreadFactory(true, PROCESS_LOG_PREFIX)); + } - public void startLogging() { - this.executor.execute(new ProcessStdOutputLogger(process.getInputStream(), - processOutputConsumer.getStdOutConsumer())); + public void startLogging() { + this.executor.execute( + new ProcessStdOutputLogger( + process.getInputStream(), processOutputConsumer.getStdOutConsumer())); - errorOutputLogger = new ProcessErrorOutputLogger(process.getErrorStream(), - processOutputConsumer.getStdErrConsumer()); - this.executor.execute(errorOutputLogger); - } + errorOutputLogger = + new ProcessErrorOutputLogger( + process.getErrorStream(), processOutputConsumer.getStdErrConsumer()); + this.executor.execute(errorOutputLogger); + } - public ProcessErrorOutputLogger getErrorOutputLogger() { - return errorOutputLogger; - } + public ProcessErrorOutputLogger getErrorOutputLogger() { + return errorOutputLogger; + } - @Override - public void close() { - if (executor != null) { - executor.shutdown(); - } + @Override + public void close() { + if (executor != null) { + executor.shutdown(); } + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/log/ProcessOutputConsumer.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/log/ProcessOutputConsumer.java index 689944d6f..d0de04077 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/log/ProcessOutputConsumer.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/log/ProcessOutputConsumer.java @@ -22,13 +22,9 @@ public interface ProcessOutputConsumer { - /** - * Process std output consumer. - */ - Consumer getStdOutConsumer(); + /** Process std output consumer. */ + Consumer getStdOutConsumer(); - /** - * Process std error output consumer. - */ - Consumer getStdErrConsumer(); + /** Process std error output consumer. */ + Consumer getStdErrConsumer(); } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/log/ProcessStdOutputLogger.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/log/ProcessStdOutputLogger.java index eebee82e2..b09c77dd5 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/log/ProcessStdOutputLogger.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/log/ProcessStdOutputLogger.java @@ -18,33 +18,34 @@ */ package org.apache.geaflow.infer.log; -import com.google.common.base.Throwables; import java.io.BufferedReader; import java.io.InputStream; import java.io.InputStreamReader; import java.util.function.Consumer; +import com.google.common.base.Throwables; + public class ProcessStdOutputLogger implements Runnable { - private final InputStream inputStream; + private final InputStream inputStream; - private final Consumer consumer; + private final Consumer consumer; - public ProcessStdOutputLogger(InputStream inputStream, Consumer consumer) { - this.inputStream = inputStream; - this.consumer = consumer; - } + public ProcessStdOutputLogger(InputStream inputStream, Consumer consumer) { + this.inputStream = inputStream; + this.consumer = consumer; + } - @Override - public void run() { - try { - BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream)); - String line; - while ((line = reader.readLine()) != null) { - consumer.accept(line); - } - } catch (Exception e) { - consumer.accept(Throwables.getStackTraceAsString(e)); - } + @Override + public void run() { + try { + BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream)); + String line; + while ((line = reader.readLine()) != null) { + consumer.accept(line); + } + } catch (Exception e) { + consumer.accept(Throwables.getStackTraceAsString(e)); } + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/log/Slf4JProcessOutputConsumer.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/log/Slf4JProcessOutputConsumer.java index 352c821b0..bdad0d195 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/log/Slf4JProcessOutputConsumer.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/log/Slf4JProcessOutputConsumer.java @@ -19,24 +19,25 @@ package org.apache.geaflow.infer.log; import java.util.function.Consumer; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class Slf4JProcessOutputConsumer implements ProcessOutputConsumer { - private final Logger logger; + private final Logger logger; - public Slf4JProcessOutputConsumer(String loggerName) { - logger = LoggerFactory.getLogger(loggerName); - } + public Slf4JProcessOutputConsumer(String loggerName) { + logger = LoggerFactory.getLogger(loggerName); + } - @Override - public Consumer getStdOutConsumer() { - return logger::info; - } + @Override + public Consumer getStdOutConsumer() { + return logger::info; + } - @Override - public Consumer getStdErrConsumer() { - return logger::error; - } + @Override + public Consumer getStdErrConsumer() { + return logger::error; + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/util/InferFileUtils.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/util/InferFileUtils.java index a7a570cc2..1e020a41a 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/util/InferFileUtils.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/util/InferFileUtils.java @@ -22,8 +22,6 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.JOB_WORK_PATH; import static org.apache.geaflow.file.FileConfigKeys.USER_NAME; -import com.google.common.base.Preconditions; -import com.google.common.io.Resources; import java.io.File; import java.io.FileFilter; import java.io.FileOutputStream; @@ -46,6 +44,7 @@ import java.util.jar.JarEntry; import java.util.jar.JarFile; import java.util.stream.Collectors; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -53,228 +52,230 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; +import com.google.common.io.Resources; + public class InferFileUtils { - private static final Logger LOGGER = LoggerFactory.getLogger(InferFileUtils.class); + private static final Logger LOGGER = LoggerFactory.getLogger(InferFileUtils.class); - public static final String INFER_DIRECTORY = "infer"; + public static final String INFER_DIRECTORY = "infer"; - private static final String FILE_SUFFIX = "/"; + private static final String FILE_SUFFIX = "/"; - private static final String JAR_FILE = "jar:file:"; + private static final String JAR_FILE = "jar:file:"; - private static final String UDF_RESOURCE_PATH = "."; + private static final String UDF_RESOURCE_PATH = "."; - public static final String PY_FILE_EXTENSION = ".py"; + public static final String PY_FILE_EXTENSION = ".py"; - public static final String JAR_FILE_EXTENSION = ".jar"; + public static final String JAR_FILE_EXTENSION = ".jar"; - private static final int DEFAULT_BUFFER_SIZE = 1024; + private static final int DEFAULT_BUFFER_SIZE = 1024; - public static final String REQUIREMENTS_TXT = "requirements.txt"; + public static final String REQUIREMENTS_TXT = "requirements.txt"; - public static void releaseLock(FileLock fileLock) { - try { - fileLock.release(); - fileLock.channel().close(); - } catch (IOException e) { - LOGGER.error("release file lock failed", e); - } + public static void releaseLock(FileLock fileLock) { + try { + fileLock.release(); + fileLock.channel().close(); + } catch (IOException e) { + LOGGER.error("release file lock failed", e); } + } - public static FileLock addLock(File file) { - try { - FileChannel channel = FileChannel.open(file.toPath(), StandardOpenOption.WRITE); - return channel.lock(); - } catch (Exception e) { - LOGGER.error("add lock file {} error", file.toPath(), e); - throw new GeaflowRuntimeException("get file lock failed", e); - } + public static FileLock addLock(File file) { + try { + FileChannel channel = FileChannel.open(file.toPath(), StandardOpenOption.WRITE); + return channel.lock(); + } catch (Exception e) { + LOGGER.error("add lock file {} error", file.toPath(), e); + throw new GeaflowRuntimeException("get file lock failed", e); } + } - public static void forceMkdir(File directory) throws IOException { - String message; - if (directory.exists()) { - if (!directory.isDirectory()) { - message = String.format("File %s exists and is not a directory. Unable to " - + "create directory.", directory); - throw new IOException(message); - } - } else if (!directory.mkdirs() && !directory.isDirectory()) { - message = "Unable to create directory " + directory; - throw new IOException(message); - } + public static void forceMkdir(File directory) throws IOException { + String message; + if (directory.exists()) { + if (!directory.isDirectory()) { + message = + String.format( + "File %s exists and is not a directory. Unable to " + "create directory.", + directory); + throw new IOException(message); + } + } else if (!directory.mkdirs() && !directory.isDirectory()) { + message = "Unable to create directory " + directory; + throw new IOException(message); } + } - public static String copyPythonFile(String parentDir, File resourceFile) { - parentDir = parentDir.endsWith(FILE_SUFFIX) ? parentDir : parentDir + FILE_SUFFIX; - File targetFile = new File(parentDir + resourceFile.getName()); - if (targetFile.exists()) { - targetFile.delete(); - LOGGER.info("{} file is existed, delete", targetFile.getAbsolutePath()); - } - try { - FileUtils.copyFile(resourceFile, targetFile); - } catch (Exception e) { - throw new GeaflowRuntimeException( - String.format("prepare python file [%s] failed, %s", resourceFile.getName(), - e.getMessage())); - } - LOGGER.info("prepare python file [{}] finish", resourceFile.getName()); - Preconditions.checkState(targetFile.exists()); - return targetFile.getAbsolutePath(); + public static String copyPythonFile(String parentDir, File resourceFile) { + parentDir = parentDir.endsWith(FILE_SUFFIX) ? parentDir : parentDir + FILE_SUFFIX; + File targetFile = new File(parentDir + resourceFile.getName()); + if (targetFile.exists()) { + targetFile.delete(); + LOGGER.info("{} file is existed, delete", targetFile.getAbsolutePath()); } - - - public static String copyInferFileByURL(String parentPath, String resourceFilePath) { - File sourceFile = new File(resourceFilePath.trim()); - parentPath = parentPath.endsWith(FILE_SUFFIX) ? parentPath : parentPath + FILE_SUFFIX; - File file = new File(parentPath + sourceFile.getName()); - if (file.exists()) { - file.delete(); - LOGGER.info("{} file is existed, delete", file.getAbsolutePath()); - } - try { - URL url = InferFileUtils.class.getClassLoader().getResource(resourceFilePath); - Preconditions.checkNotNull(url, - String.format("Cannot find resource file [%s] in " + "classpath", resourceFilePath)); - File tmpFile = new File(file.getParent() + File.separator + "tmp_" + file.getName()); - FileUtils.copyURLToFile(url, tmpFile); - tmpFile.renameTo(file); - } catch (Exception e) { - throw new GeaflowRuntimeException( - String.format("prepare python file [%s] failed by url, %s", resourceFilePath, - e.getMessage())); - } - LOGGER.info("prepare python file [{}] finish by url", resourceFilePath); - return file.getAbsolutePath(); + try { + FileUtils.copyFile(resourceFile, targetFile); + } catch (Exception e) { + throw new GeaflowRuntimeException( + String.format( + "prepare python file [%s] failed, %s", resourceFile.getName(), e.getMessage())); } + LOGGER.info("prepare python file [{}] finish", resourceFile.getName()); + Preconditions.checkState(targetFile.exists()); + return targetFile.getAbsolutePath(); + } - public static String copyPythonFile(String parentDir, File resourceFile, String reName) { - parentDir = parentDir.endsWith(FILE_SUFFIX) ? parentDir : parentDir + FILE_SUFFIX; - File targetFile = new File(parentDir + reName); - if (targetFile.exists()) { - targetFile.delete(); - LOGGER.info("{} file is existed, delete", targetFile.getAbsolutePath()); - } - try { - FileUtils.copyFile(resourceFile, targetFile); - } catch (Exception e) { - throw new GeaflowRuntimeException( - String.format("prepare python file [%s] and rename [%s] failed, %s", - resourceFile.getName(), reName, e.getMessage())); - } - LOGGER.info("prepare python file [{}] and rename [{}] finish", resourceFile.getName(), - reName); - Preconditions.checkState(targetFile.exists()); - return targetFile.getAbsolutePath(); + public static String copyInferFileByURL(String parentPath, String resourceFilePath) { + File sourceFile = new File(resourceFilePath.trim()); + parentPath = parentPath.endsWith(FILE_SUFFIX) ? parentPath : parentPath + FILE_SUFFIX; + File file = new File(parentPath + sourceFile.getName()); + if (file.exists()) { + file.delete(); + LOGGER.info("{} file is existed, delete", file.getAbsolutePath()); } + try { + URL url = InferFileUtils.class.getClassLoader().getResource(resourceFilePath); + Preconditions.checkNotNull( + url, String.format("Cannot find resource file [%s] in " + "classpath", resourceFilePath)); + File tmpFile = new File(file.getParent() + File.separator + "tmp_" + file.getName()); + FileUtils.copyURLToFile(url, tmpFile); + tmpFile.renameTo(file); + } catch (Exception e) { + throw new GeaflowRuntimeException( + String.format( + "prepare python file [%s] failed by url, %s", resourceFilePath, e.getMessage())); + } + LOGGER.info("prepare python file [{}] finish by url", resourceFilePath); + return file.getAbsolutePath(); + } - public static List getPythonFilesByCondition(FileFilter fileFilter) { - String jobPackagePath = Resources.getResource(UDF_RESOURCE_PATH).getPath(); - File folder = new File(jobPackagePath); - for (File file : folder.listFiles()) { - LOGGER.info("folder {} sub file {}", folder.getAbsolutePath(), file.getName()); - } - File[] subFiles = folder.listFiles(fileFilter); - if (subFiles == null) { - return Collections.emptyList(); - } - return Arrays.asList(subFiles); + public static String copyPythonFile(String parentDir, File resourceFile, String reName) { + parentDir = parentDir.endsWith(FILE_SUFFIX) ? parentDir : parentDir + FILE_SUFFIX; + File targetFile = new File(parentDir + reName); + if (targetFile.exists()) { + targetFile.delete(); + LOGGER.info("{} file is existed, delete", targetFile.getAbsolutePath()); + } + try { + FileUtils.copyFile(resourceFile, targetFile); + } catch (Exception e) { + throw new GeaflowRuntimeException( + String.format( + "prepare python file [%s] and rename [%s] failed, %s", + resourceFile.getName(), reName, e.getMessage())); } + LOGGER.info("prepare python file [{}] and rename [{}] finish", resourceFile.getName(), reName); + Preconditions.checkState(targetFile.exists()); + return targetFile.getAbsolutePath(); + } - public static File getUserJobJarFile() { - String jobPackagePath = Resources.getResource(UDF_RESOURCE_PATH).getPath(); - File folder = new File(jobPackagePath); - for (File file : folder.listFiles()) { - if (file.isFile() && file.getName().endsWith(JAR_FILE_EXTENSION)) { - LOGGER.info("folder {} user job jar is {}", folder.getAbsolutePath(), - file.getName()); - return file; - } - } - return null; + public static List getPythonFilesByCondition(FileFilter fileFilter) { + String jobPackagePath = Resources.getResource(UDF_RESOURCE_PATH).getPath(); + File folder = new File(jobPackagePath); + for (File file : folder.listFiles()) { + LOGGER.info("folder {} sub file {}", folder.getAbsolutePath(), file.getName()); } + File[] subFiles = folder.listFiles(fileFilter); + if (subFiles == null) { + return Collections.emptyList(); + } + return Arrays.asList(subFiles); + } + public static File getUserJobJarFile() { + String jobPackagePath = Resources.getResource(UDF_RESOURCE_PATH).getPath(); + File folder = new File(jobPackagePath); + for (File file : folder.listFiles()) { + if (file.isFile() && file.getName().endsWith(JAR_FILE_EXTENSION)) { + LOGGER.info("folder {} user job jar is {}", folder.getAbsolutePath(), file.getName()); + return file; + } + } + return null; + } - public static String getInferDirectory(Configuration configuration) { - String workPath = configuration.getString(JOB_WORK_PATH); - String userName = configuration.getString(USER_NAME); - String inferPath = workPath + File.separator + userName + File.separator + INFER_DIRECTORY; - String jobUniqueId = configuration.getString(ExecutionConfigKeys.JOB_UNIQUE_ID); - File inferFile = new File(inferPath + File.separator + jobUniqueId); - if (!inferFile.exists()) { - inferFile.mkdirs(); - } - return inferFile.getAbsolutePath(); + public static String getInferDirectory(Configuration configuration) { + String workPath = configuration.getString(JOB_WORK_PATH); + String userName = configuration.getString(USER_NAME); + String inferPath = workPath + File.separator + userName + File.separator + INFER_DIRECTORY; + String jobUniqueId = configuration.getString(ExecutionConfigKeys.JOB_UNIQUE_ID); + File inferFile = new File(inferPath + File.separator + jobUniqueId); + if (!inferFile.exists()) { + inferFile.mkdirs(); } + return inferFile.getAbsolutePath(); + } - public static String createTargetDir(String dirName, Configuration configuration) { - String inferDirectory = getInferDirectory(configuration); - File userFilesDirFile = new File(inferDirectory, dirName); - if (!userFilesDirFile.exists()) { - userFilesDirFile.mkdirs(); - } - String absolutePath = userFilesDirFile.getAbsolutePath(); - LOGGER.info("create infer directory is {}", absolutePath); - return absolutePath; + public static String createTargetDir(String dirName, Configuration configuration) { + String inferDirectory = getInferDirectory(configuration); + File userFilesDirFile = new File(inferDirectory, dirName); + if (!userFilesDirFile.exists()) { + userFilesDirFile.mkdirs(); } + String absolutePath = userFilesDirFile.getAbsolutePath(); + LOGGER.info("create infer directory is {}", absolutePath); + return absolutePath; + } - public static List getPathsFromResourceJAR(String folder) throws URISyntaxException, IOException { - List result; - String jarPath = InferFileUtils.class.getProtectionDomain() - .getCodeSource() - .getLocation() - .toURI() - .getPath(); - LOGGER.info("jar path {}", jarPath); - URI uri = URI.create(JAR_FILE + jarPath); - try (FileSystem fs = FileSystems.newFileSystem(uri, Collections.emptyMap())) { - result = Files.walk(fs.getPath(folder)) - .filter(Files::isRegularFile) - .collect(Collectors.toList()); - } - return result; + public static List getPathsFromResourceJAR(String folder) + throws URISyntaxException, IOException { + List result; + String jarPath = + InferFileUtils.class.getProtectionDomain().getCodeSource().getLocation().toURI().getPath(); + LOGGER.info("jar path {}", jarPath); + URI uri = URI.create(JAR_FILE + jarPath); + try (FileSystem fs = FileSystems.newFileSystem(uri, Collections.emptyMap())) { + result = + Files.walk(fs.getPath(folder)).filter(Files::isRegularFile).collect(Collectors.toList()); } + return result; + } - public static void prepareInferFilesFromJars(String targetDirectory) { - File userJobJarFile = getUserJobJarFile(); - Preconditions.checkNotNull(userJobJarFile); - try { - JarFile jarFile = new JarFile(userJobJarFile); - Enumeration entries = jarFile.entries(); - while (entries.hasMoreElements()) { - JarEntry entry = entries.nextElement(); - String entryName = entry.getName(); - if (!entry.isDirectory()) { - String inferFile = extractFile(targetDirectory, entryName, entry, jarFile); - LOGGER.info("cp infer file {} to {} from jar file {}", entryName, inferFile, userJobJarFile.getName()); - } else { - File entryDestination = new File(targetDirectory, entry.getName()); - if (!entryDestination.exists()) { - entryDestination.mkdirs(); - } - LOGGER.info("create infer directory is {}", entryDestination); - } - } - jarFile.close(); - } catch (IOException e) { - LOGGER.error("open jar file {} failed", userJobJarFile.getName()); + public static void prepareInferFilesFromJars(String targetDirectory) { + File userJobJarFile = getUserJobJarFile(); + Preconditions.checkNotNull(userJobJarFile); + try { + JarFile jarFile = new JarFile(userJobJarFile); + Enumeration entries = jarFile.entries(); + while (entries.hasMoreElements()) { + JarEntry entry = entries.nextElement(); + String entryName = entry.getName(); + if (!entry.isDirectory()) { + String inferFile = extractFile(targetDirectory, entryName, entry, jarFile); + LOGGER.info( + "cp infer file {} to {} from jar file {}", + entryName, + inferFile, + userJobJarFile.getName()); + } else { + File entryDestination = new File(targetDirectory, entry.getName()); + if (!entryDestination.exists()) { + entryDestination.mkdirs(); + } + LOGGER.info("create infer directory is {}", entryDestination); } + } + jarFile.close(); + } catch (IOException e) { + LOGGER.error("open jar file {} failed", userJobJarFile.getName()); } + } - - private static String extractFile(String targetDirectory, String fileName, JarEntry entry, - JarFile jarFile) throws IOException { - String targetFilePath = targetDirectory + File.separator + fileName; - try (InputStream inputStream = jarFile.getInputStream(entry); - FileOutputStream outputStream = new FileOutputStream(targetFilePath)) { - byte[] buffer = new byte[DEFAULT_BUFFER_SIZE]; - int bytesRead; - while ((bytesRead = inputStream.read(buffer)) != -1) { - outputStream.write(buffer, 0, bytesRead); - } - } - return targetFilePath; + private static String extractFile( + String targetDirectory, String fileName, JarEntry entry, JarFile jarFile) throws IOException { + String targetFilePath = targetDirectory + File.separator + fileName; + try (InputStream inputStream = jarFile.getInputStream(entry); + FileOutputStream outputStream = new FileOutputStream(targetFilePath)) { + byte[] buffer = new byte[DEFAULT_BUFFER_SIZE]; + int bytesRead; + while ((bytesRead = inputStream.read(buffer)) != -1) { + outputStream.write(buffer, 0, bytesRead); + } } + return targetFilePath; + } } diff --git a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/util/ShellExecUtils.java b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/util/ShellExecUtils.java index c747efb9a..7b53696ea 100644 --- a/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/util/ShellExecUtils.java +++ b/geaflow/geaflow-infer/src/main/java/org/apache/geaflow/infer/util/ShellExecUtils.java @@ -19,7 +19,6 @@ package org.apache.geaflow.infer.util; -import com.google.common.util.concurrent.ThreadFactoryBuilder; import java.io.File; import java.time.Duration; import java.util.concurrent.ExecutorService; @@ -27,6 +26,7 @@ import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; + import org.apache.geaflow.common.config.keys.FrameworkConfigKeys; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.infer.log.ProcessErrorOutputLogger; @@ -34,88 +34,93 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public final class ShellExecUtils { +import com.google.common.util.concurrent.ThreadFactoryBuilder; - private static final Logger LOGGER = LoggerFactory.getLogger(ShellExecUtils.class); +public final class ShellExecUtils { - private static final String SHELL_KEY = "sh"; + private static final Logger LOGGER = LoggerFactory.getLogger(ShellExecUtils.class); - private static final String SHELL_PARAM = "-c"; + private static final String SHELL_KEY = "sh"; - private static final Consumer DUMMY_CONSUMER = s -> { - }; + private static final String SHELL_PARAM = "-c"; - private static final ExecutorService LOGGER_POOL = - new ThreadPoolExecutor( - 2, - 20, - 10, - TimeUnit.SECONDS, - new LinkedBlockingQueue<>(1024), - new ThreadFactoryBuilder() - .setNameFormat("infer-task-log-%d") - .setDaemon(true) - .build()); + private static final Consumer DUMMY_CONSUMER = s -> {}; - private ShellExecUtils() { - } + private static final ExecutorService LOGGER_POOL = + new ThreadPoolExecutor( + 2, + 20, + 10, + TimeUnit.SECONDS, + new LinkedBlockingQueue<>(1024), + new ThreadFactoryBuilder().setNameFormat("infer-task-log-%d").setDaemon(true).build()); + private ShellExecUtils() {} - public static boolean run(String cmd, Consumer stdOutputConsumer, - Consumer errOutputConsumer, Duration timeout, - boolean allowFailure, String workingDir) { - ProcessBuilder builder = new ProcessBuilder(SHELL_KEY, SHELL_PARAM, cmd); - if (workingDir != null) { - builder.directory(new File(workingDir)); - } - Process process = null; - int exitCode = 0; - try { - process = builder.start(); - if (stdOutputConsumer == null) { - stdOutputConsumer = DUMMY_CONSUMER; - } - if (errOutputConsumer == null) { - errOutputConsumer = DUMMY_CONSUMER; - } - ProcessErrorOutputLogger processErrorOutputLogger = - new ProcessErrorOutputLogger(process.getErrorStream(), errOutputConsumer); - LOGGER_POOL.execute(new ProcessStdOutputLogger(process.getInputStream(), stdOutputConsumer)); - LOGGER_POOL.execute(processErrorOutputLogger); - boolean finished = process.waitFor(timeout.toMillis(), TimeUnit.MILLISECONDS); - exitCode = process.exitValue(); - boolean success = finished && exitCode == 0; - if (!success && !allowFailure) { - if (!finished) { - throw new GeaflowRuntimeException(String.format("Command %s didn't finish in " - + "time, please try increase the " - + "timeout %s", cmd, FrameworkConfigKeys.INFER_ENV_INIT_TIMEOUT_SEC)); - } else { - LOGGER.error("Command {} exec failed, error message {}", cmd, processErrorOutputLogger.get()); - } - } - return success; - } catch (Exception e) { - LOGGER.error("error running {}, exit code is {}", cmd, exitCode, e); - throw new GeaflowRuntimeException("running shell exception", e); - } finally { - if (process != null) { - process.destroyForcibly(); - } - } + public static boolean run( + String cmd, + Consumer stdOutputConsumer, + Consumer errOutputConsumer, + Duration timeout, + boolean allowFailure, + String workingDir) { + ProcessBuilder builder = new ProcessBuilder(SHELL_KEY, SHELL_PARAM, cmd); + if (workingDir != null) { + builder.directory(new File(workingDir)); } - - public static boolean run(String cmd, Duration timeout, Consumer stdOutputConsumer, - Consumer errOutputConsumer) { - return run(cmd, stdOutputConsumer, errOutputConsumer, - timeout, false, null); + Process process = null; + int exitCode = 0; + try { + process = builder.start(); + if (stdOutputConsumer == null) { + stdOutputConsumer = DUMMY_CONSUMER; + } + if (errOutputConsumer == null) { + errOutputConsumer = DUMMY_CONSUMER; + } + ProcessErrorOutputLogger processErrorOutputLogger = + new ProcessErrorOutputLogger(process.getErrorStream(), errOutputConsumer); + LOGGER_POOL.execute(new ProcessStdOutputLogger(process.getInputStream(), stdOutputConsumer)); + LOGGER_POOL.execute(processErrorOutputLogger); + boolean finished = process.waitFor(timeout.toMillis(), TimeUnit.MILLISECONDS); + exitCode = process.exitValue(); + boolean success = finished && exitCode == 0; + if (!success && !allowFailure) { + if (!finished) { + throw new GeaflowRuntimeException( + String.format( + "Command %s didn't finish in " + "time, please try increase the " + "timeout %s", + cmd, FrameworkConfigKeys.INFER_ENV_INIT_TIMEOUT_SEC)); + } else { + LOGGER.error( + "Command {} exec failed, error message {}", cmd, processErrorOutputLogger.get()); + } + } + return success; + } catch (Exception e) { + LOGGER.error("error running {}, exit code is {}", cmd, exitCode, e); + throw new GeaflowRuntimeException("running shell exception", e); + } finally { + if (process != null) { + process.destroyForcibly(); + } } + } - public static boolean run(String cmd, Duration timeout, - Consumer stdOutputConsumer, - Consumer errOutputConsumer, String workDir) { - return run(cmd, stdOutputConsumer, errOutputConsumer, - timeout, false, workDir); - } + public static boolean run( + String cmd, + Duration timeout, + Consumer stdOutputConsumer, + Consumer errOutputConsumer) { + return run(cmd, stdOutputConsumer, errOutputConsumer, timeout, false, null); + } + public static boolean run( + String cmd, + Duration timeout, + Consumer stdOutputConsumer, + Consumer errOutputConsumer, + String workDir) { + return run(cmd, stdOutputConsumer, errOutputConsumer, timeout, false, workDir); + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/AbstractMemoryPool.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/AbstractMemoryPool.java index 11298e70c..ef6bca84d 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/AbstractMemoryPool.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/AbstractMemoryPool.java @@ -19,304 +19,307 @@ package org.apache.geaflow.memory; -import com.google.common.collect.Lists; -import com.google.common.collect.Maps; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; + import org.apache.geaflow.memory.metric.ChunkListMetric; import org.apache.geaflow.memory.metric.PoolMetric; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class AbstractMemoryPool implements PoolMetric { - - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractMemoryPool.class); - - final int maxOrder; - final int pageSize; - final int pageShifts; - final int chunkSize; - - protected final ChunkList q050; - protected final ChunkList q025; - protected final ChunkList q000; - protected final ChunkList qInit; - protected final ChunkList q075; - protected final ChunkList q100; - - protected final List chunkListMetrics = Lists.newArrayList(); - - private Map> subpages = Maps.newHashMap(); - - Map groupThreadCounter = Maps.newHashMap(); - - final MemoryManager memoryManager; - private AtomicLong usedMemory; - - protected int currentChunkNum = 0; - protected long allocateMemory; - - public AbstractMemoryPool(MemoryManager memoryManager, int pageSize, int maxOrder, int pageShifts, int chunkSize) { - this.pageSize = pageSize; - this.maxOrder = maxOrder; - this.pageShifts = pageShifts; - this.chunkSize = chunkSize; - this.memoryManager = memoryManager; - - q100 = new ChunkList(this, 100, Integer.MAX_VALUE); - q075 = new ChunkList(this, 75, 100); - q050 = new ChunkList(this, 50, 100); - q025 = new ChunkList(this, 25, 75); - q000 = new ChunkList(this, 1, 50); - qInit = new ChunkList(this, Integer.MIN_VALUE, 25); - - q100.setPreList(q075); - q100.setNextList(null); - q075.setPreList(q050); - q075.setNextList(q100); - q050.setPreList(q025); - q050.setNextList(q075); - q025.setPreList(q000); - q025.setNextList(q050); - q000.setPreList(null); - q000.setNextList(q025); - qInit.setPreList(qInit); - qInit.setNextList(q000); - - ESegmentSize slot = ESegmentSize.valueOf(pageSize); - for (int i = 0; i <= slot.index(); i++) { - subpages.put(ESegmentSize.upValues[i], newPageHead(pageSize)); - } +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; - chunkListMetrics.add(qInit); - chunkListMetrics.add(q000); - chunkListMetrics.add(q025); - chunkListMetrics.add(q050); - chunkListMetrics.add(q075); - chunkListMetrics.add(q100); +public abstract class AbstractMemoryPool implements PoolMetric { - usedMemory = new AtomicLong(0); + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractMemoryPool.class); + + final int maxOrder; + final int pageSize; + final int pageShifts; + final int chunkSize; + + protected final ChunkList q050; + protected final ChunkList q025; + protected final ChunkList q000; + protected final ChunkList qInit; + protected final ChunkList q075; + protected final ChunkList q100; + + protected final List chunkListMetrics = Lists.newArrayList(); + + private Map> subpages = Maps.newHashMap(); + + Map groupThreadCounter = Maps.newHashMap(); + + final MemoryManager memoryManager; + private AtomicLong usedMemory; + + protected int currentChunkNum = 0; + protected long allocateMemory; + + public AbstractMemoryPool( + MemoryManager memoryManager, int pageSize, int maxOrder, int pageShifts, int chunkSize) { + this.pageSize = pageSize; + this.maxOrder = maxOrder; + this.pageShifts = pageShifts; + this.chunkSize = chunkSize; + this.memoryManager = memoryManager; + + q100 = new ChunkList(this, 100, Integer.MAX_VALUE); + q075 = new ChunkList(this, 75, 100); + q050 = new ChunkList(this, 50, 100); + q025 = new ChunkList(this, 25, 75); + q000 = new ChunkList(this, 1, 50); + qInit = new ChunkList(this, Integer.MIN_VALUE, 25); + + q100.setPreList(q075); + q100.setNextList(null); + q075.setPreList(q050); + q075.setNextList(q100); + q050.setPreList(q025); + q050.setNextList(q075); + q025.setPreList(q000); + q025.setNextList(q050); + q000.setPreList(null); + q000.setNextList(q025); + qInit.setPreList(qInit); + qInit.setNextList(q000); + + ESegmentSize slot = ESegmentSize.valueOf(pageSize); + for (int i = 0; i <= slot.index(); i++) { + subpages.put(ESegmentSize.upValues[i], newPageHead(pageSize)); } - private Page newPageHead(int pageSize) { - Page head = new Page<>(pageSize); - head.prev = head; - head.next = head; - return head; + chunkListMetrics.add(qInit); + chunkListMetrics.add(q000); + chunkListMetrics.add(q025); + chunkListMetrics.add(q050); + chunkListMetrics.add(q075); + chunkListMetrics.add(q100); + + usedMemory = new AtomicLong(0); + } + + private Page newPageHead(int pageSize) { + Page head = new Page<>(pageSize); + head.prev = head; + head.next = head; + return head; + } + + List> allocate(int reqCapacity, MemoryGroup group) { + int size = 1; + if (reqCapacity > chunkSize) { + size = (reqCapacity / chunkSize) + (reqCapacity % chunkSize > 0 ? 1 : 0); } - - List> allocate(int reqCapacity, MemoryGroup group) { - int size = 1; - if (reqCapacity > chunkSize) { - size = (reqCapacity / chunkSize) + (reqCapacity % chunkSize > 0 ? 1 : 0); - } - List> bufs = new ArrayList<>(size); - boolean failed = false; - for (int i = 0; i < size - 1; i++) { - ByteBuf buf = oneAllocate(chunkSize, group); - if (buf.allocateFailed()) { - failed = true; - break; - } - bufs.add(buf); - } - if (!failed) { - ByteBuf buf = oneAllocate(reqCapacity - (size - 1) * chunkSize, group); - if (!buf.allocateFailed()) { - bufs.add(buf); - } - } - return bufs; + List> bufs = new ArrayList<>(size); + boolean failed = false; + for (int i = 0; i < size - 1; i++) { + ByteBuf buf = oneAllocate(chunkSize, group); + if (buf.allocateFailed()) { + failed = true; + break; + } + bufs.add(buf); } - - ByteBuf oneAllocate(int reqCapacity, MemoryGroup group) { - final int normCapacity = normalizeCapacity(reqCapacity); - ByteBuf byteBuf = newByteBuf(normCapacity); - allocate(byteBuf, normCapacity, group); - usedMemory.getAndAdd(byteBuf.length); - return byteBuf; + if (!failed) { + ByteBuf buf = oneAllocate(reqCapacity - (size - 1) * chunkSize, group); + if (!buf.allocateFailed()) { + bufs.add(buf); + } } - - private void allocate(ByteBuf byteBuf, final int normCapacity, MemoryGroup group) { - - if (normCapacity < pageSize) { - ESegmentSize slot = normalize(normCapacity); - - final Page head = subpages.get(slot); - synchronized (head) { - final Page s = head.next; - if (s != head) { - assert s.doNotDestroy && s.elemSize == normCapacity; - long handle = s.allocate(); - assert handle >= 0; - if (group.allocate(normCapacity, getMemoryMode())) { - s.chunk.initBufWithSubpage(byteBuf, handle, normCapacity); - } - if (byteBuf.allocateFailed()) { - group.free(normCapacity, getMemoryMode()); - } - return; - } - } - - synchronized (this) { - allocateBytebuf(byteBuf, normCapacity, group); - } - return; + return bufs; + } + + ByteBuf oneAllocate(int reqCapacity, MemoryGroup group) { + final int normCapacity = normalizeCapacity(reqCapacity); + ByteBuf byteBuf = newByteBuf(normCapacity); + allocate(byteBuf, normCapacity, group); + usedMemory.getAndAdd(byteBuf.length); + return byteBuf; + } + + private void allocate(ByteBuf byteBuf, final int normCapacity, MemoryGroup group) { + + if (normCapacity < pageSize) { + ESegmentSize slot = normalize(normCapacity); + + final Page head = subpages.get(slot); + synchronized (head) { + final Page s = head.next; + if (s != head) { + assert s.doNotDestroy && s.elemSize == normCapacity; + long handle = s.allocate(); + assert handle >= 0; + if (group.allocate(normCapacity, getMemoryMode())) { + s.chunk.initBufWithSubpage(byteBuf, handle, normCapacity); + } + if (byteBuf.allocateFailed()) { + group.free(normCapacity, getMemoryMode()); + } + return; } + } - if (normCapacity <= chunkSize) { - synchronized (this) { - allocateBytebuf(byteBuf, normCapacity, group); - } - } else { - //todo allocate huge size - LOGGER.warn(String.format("not support huge size:%d!", normCapacity)); - } + synchronized (this) { + allocateBytebuf(byteBuf, normCapacity, group); + } + return; } - // Method must be called inside synchronized(this) { ... } block - protected abstract void allocateBytebuf(ByteBuf buf, int normCapacity, MemoryGroup group); - - void free(Chunk chunk, long handle, long length) { - freeChunk(chunk, handle, length); + if (normCapacity <= chunkSize) { + synchronized (this) { + allocateBytebuf(byteBuf, normCapacity, group); + } + } else { + // todo allocate huge size + LOGGER.warn(String.format("not support huge size:%d!", normCapacity)); } + } - void freeChunk(Chunk chunk, long handle, long length) { - final boolean needDestroyChunk; - synchronized (this) { - needDestroyChunk = !chunk.parent.free(chunk, handle); - } - if (needDestroyChunk) { - // destroyChunk not need to be called while holding the synchronized lock. - destroyChunk(chunk); - } - usedMemory.getAndAdd(-1 * length); - } + // Method must be called inside synchronized(this) { ... } block + protected abstract void allocateBytebuf(ByteBuf buf, int normCapacity, MemoryGroup group); - int normalizeCapacity(int reqCapacity) { - if (reqCapacity < 0) { - throw new IllegalArgumentException("capacity: " + reqCapacity + " (expected: 0+)"); - } - - if (reqCapacity >= chunkSize) { - return reqCapacity; - } + void free(Chunk chunk, long handle, long length) { + freeChunk(chunk, handle, length); + } - if (reqCapacity <= ESegmentSize.smallest().size()) { - return ESegmentSize.smallest().size(); - } - - int normalizedCapacity = reqCapacity; - normalizedCapacity--; - normalizedCapacity |= normalizedCapacity >>> 1; - normalizedCapacity |= normalizedCapacity >>> 2; - normalizedCapacity |= normalizedCapacity >>> 4; - normalizedCapacity |= normalizedCapacity >>> 8; - normalizedCapacity |= normalizedCapacity >>> 16; - normalizedCapacity++; - - if (normalizedCapacity < 0) { - normalizedCapacity >>>= 1; - } - - return normalizedCapacity; + void freeChunk(Chunk chunk, long handle, long length) { + final boolean needDestroyChunk; + synchronized (this) { + needDestroyChunk = !chunk.parent.free(chunk, handle); } - - static ESegmentSize normalize(int noramlCapacity) { - return ESegmentSize.valueOf(noramlCapacity); + if (needDestroyChunk) { + // destroyChunk not need to be called while holding the synchronized lock. + destroyChunk(chunk); } + usedMemory.getAndAdd(-1 * length); + } - abstract MemoryMode getMemoryMode(); - - Page getSuitablePageHead(int size) { - int normCapacity = normalizeCapacity(size); - return subpages.get(normalize(normCapacity)); + int normalizeCapacity(int reqCapacity) { + if (reqCapacity < 0) { + throw new IllegalArgumentException("capacity: " + reqCapacity + " (expected: 0+)"); } - protected void destroy() { - subpages.forEach((k, v) -> v.destroy()); - destroyChunkList(qInit, q000, q025, q050, q075, q100); + if (reqCapacity >= chunkSize) { + return reqCapacity; } - private static void destroyChunkList(ChunkList... lists) { - for (ChunkList chunkList : lists) { - chunkList.destroy(); - } + if (reqCapacity <= ESegmentSize.smallest().size()) { + return ESegmentSize.smallest().size(); } - abstract boolean canExpandCapacity(); + int normalizedCapacity = reqCapacity; + normalizedCapacity--; + normalizedCapacity |= normalizedCapacity >>> 1; + normalizedCapacity |= normalizedCapacity >>> 2; + normalizedCapacity |= normalizedCapacity >>> 4; + normalizedCapacity |= normalizedCapacity >>> 8; + normalizedCapacity |= normalizedCapacity >>> 16; + normalizedCapacity++; + + if (normalizedCapacity < 0) { + normalizedCapacity >>>= 1; + } - abstract void shrinkCapacity(); + return normalizedCapacity; + } - abstract void destroyChunk(Chunk chunk); + static ESegmentSize normalize(int noramlCapacity) { + return ESegmentSize.valueOf(noramlCapacity); + } - protected abstract ByteBuf newByteBuf(int size); + abstract MemoryMode getMemoryMode(); - protected abstract Chunk newChunk(int pageSize, int maxOrder, int pageShifts, int chunkSize); + Page getSuitablePageHead(int size) { + int normCapacity = normalizeCapacity(size); + return subpages.get(normalize(normCapacity)); + } - protected abstract String dump(); + protected void destroy() { + subpages.forEach((k, v) -> v.destroy()); + destroyChunkList(qInit, q000, q025, q050, q075, q100); + } - AtomicInteger groupThreads(MemoryGroup group) { - if (!groupThreadCounter.containsKey(group)) { - groupThreadCounter.put(group, new AtomicInteger(0)); - } - return groupThreadCounter.get(group); + private static void destroyChunkList(ChunkList... lists) { + for (ChunkList chunkList : lists) { + chunkList.destroy(); } + } - @Override - public int numThreadCaches() { - int num = 0; - synchronized (groupThreadCounter) { - for (Entry a : groupThreadCounter.entrySet()) { - num += ((AtomicInteger) a.getValue()).intValue(); - } - } - return num; - } + abstract boolean canExpandCapacity(); - @Override - public long numAllocations() { - return 0; - } + abstract void shrinkCapacity(); - @Override - public long allocateBytes() { - return this.allocateMemory; - } + abstract void destroyChunk(Chunk chunk); - @Override - public long numActiveBytes() { - return usedMemory.get(); - } + protected abstract ByteBuf newByteBuf(int size); - @Override - public long freeBytes() { - return allocateBytes() - numActiveBytes(); - } + protected abstract Chunk newChunk(int pageSize, int maxOrder, int pageShifts, int chunkSize); - protected void reloadMemoryStatics(MemoryMode memoryMode) { - if (memoryMode == MemoryMode.OFF_HEAP) { - long totalOffHeapMemory = memoryManager.totalAllocateOffHeapMemory(); - if (totalOffHeapMemory > 0) { - MemoryGroupManger.getInstance() - .resetMemory(totalOffHeapMemory, chunkSize, memoryMode); - } - } - } + protected abstract String dump(); - protected void updateAllocateMemory() { - this.allocateMemory = (long) this.currentChunkNum * chunkSize; + AtomicInteger groupThreads(MemoryGroup group) { + if (!groupThreadCounter.containsKey(group)) { + groupThreadCounter.put(group, new AtomicInteger(0)); } + return groupThreadCounter.get(group); + } + + @Override + public int numThreadCaches() { + int num = 0; + synchronized (groupThreadCounter) { + for (Entry a : groupThreadCounter.entrySet()) { + num += ((AtomicInteger) a.getValue()).intValue(); + } + } + return num; + } + + @Override + public long numAllocations() { + return 0; + } + + @Override + public long allocateBytes() { + return this.allocateMemory; + } + + @Override + public long numActiveBytes() { + return usedMemory.get(); + } + + @Override + public long freeBytes() { + return allocateBytes() - numActiveBytes(); + } + + protected void reloadMemoryStatics(MemoryMode memoryMode) { + if (memoryMode == MemoryMode.OFF_HEAP) { + long totalOffHeapMemory = memoryManager.totalAllocateOffHeapMemory(); + if (totalOffHeapMemory > 0) { + MemoryGroupManger.getInstance().resetMemory(totalOffHeapMemory, chunkSize, memoryMode); + } + } + } + + protected void updateAllocateMemory() { + this.allocateMemory = (long) this.currentChunkNum * chunkSize; + } - @Override - public synchronized String toString() { - String newLine = "\n"; - StringBuilder buf = new StringBuilder() + @Override + public synchronized String toString() { + String newLine = "\n"; + StringBuilder buf = + new StringBuilder() .append("Chunk(s) at 0~25%:") .append(qInit) .append(newLine) @@ -338,7 +341,6 @@ public synchronized String toString() { .append(dump()) .append(newLine); - return buf.toString(); - } - + return buf.toString(); + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/ByteBuf.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/ByteBuf.java index ce3ca6144..02b95f119 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/ByteBuf.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/ByteBuf.java @@ -23,102 +23,105 @@ public class ByteBuf { - protected Chunk chunk; - protected long handle; - protected int startOffset; - protected int length; - - private ByteBuffer bf; - - public ByteBuf() { - } - - public ByteBuf(Chunk chunk, long handle, int startOffset, int length) { - init(chunk, handle, startOffset, length); - } - - void init(Chunk chunk, long handle, int startOffset, int length) { - this.chunk = chunk; - this.handle = handle; - this.startOffset = startOffset; - this.length = length; - this.bf = internalBuffer(); - } - - public ByteBuffer getBf() { - return bf; - } - - protected ByteBuffer internalBuffer() { - - ByteBuffer tmp = this.bf; - if (tmp == null) { - tmp = newInternalBuffer(); - } - return tmp; - } - - protected ByteBuffer newInternalBuffer() { - ByteBuffer buffer; - if (chunk.pool.getMemoryMode() == MemoryMode.OFF_HEAP) { - buffer = (ByteBuffer) ((ByteBuffer) chunk.memory).duplicate().position(startOffset) - .limit(startOffset + length); - } else { - buffer = ByteBuffer.wrap((byte[]) chunk.memory, startOffset, length); - } - return buffer.slice(); - } - - public void free(MemoryGroup group) { - chunk.pool.free(chunk, handle, length); - group.free(length, chunk.pool.getMemoryMode()); - } - - public Chunk getChunk() { - return chunk; - } - - byte get(int pos) { - return bf.get(pos); - } - - int contentSize() { - return bf.position(); - } - - public int getRemain() { - return bf.remaining(); - } - - public long getHandle() { - return handle; - } - - public int getStartOffset() { - return startOffset; - } - - public int getLength() { - return length; - } - - public boolean allocateFailed() { - return chunk == null; - } - - public void reset() { - bf.clear(); - } - - public void position(int pos) { - bf.position(pos); - } - - public ByteBuffer duplicate() { - return (ByteBuffer) getBf().duplicate().position(0).limit(bf.position()); - } - - public boolean hasRemaining() { - return bf.hasRemaining(); - } + protected Chunk chunk; + protected long handle; + protected int startOffset; + protected int length; + + private ByteBuffer bf; + + public ByteBuf() {} + + public ByteBuf(Chunk chunk, long handle, int startOffset, int length) { + init(chunk, handle, startOffset, length); + } + + void init(Chunk chunk, long handle, int startOffset, int length) { + this.chunk = chunk; + this.handle = handle; + this.startOffset = startOffset; + this.length = length; + this.bf = internalBuffer(); + } + + public ByteBuffer getBf() { + return bf; + } + + protected ByteBuffer internalBuffer() { + + ByteBuffer tmp = this.bf; + if (tmp == null) { + tmp = newInternalBuffer(); + } + return tmp; + } + + protected ByteBuffer newInternalBuffer() { + ByteBuffer buffer; + if (chunk.pool.getMemoryMode() == MemoryMode.OFF_HEAP) { + buffer = + (ByteBuffer) + ((ByteBuffer) chunk.memory) + .duplicate() + .position(startOffset) + .limit(startOffset + length); + } else { + buffer = ByteBuffer.wrap((byte[]) chunk.memory, startOffset, length); + } + return buffer.slice(); + } + + public void free(MemoryGroup group) { + chunk.pool.free(chunk, handle, length); + group.free(length, chunk.pool.getMemoryMode()); + } + + public Chunk getChunk() { + return chunk; + } + + byte get(int pos) { + return bf.get(pos); + } + + int contentSize() { + return bf.position(); + } + + public int getRemain() { + return bf.remaining(); + } + + public long getHandle() { + return handle; + } + + public int getStartOffset() { + return startOffset; + } + + public int getLength() { + return length; + } + + public boolean allocateFailed() { + return chunk == null; + } + + public void reset() { + bf.clear(); + } + + public void position(int pos) { + bf.position(pos); + } + + public ByteBuffer duplicate() { + return (ByteBuffer) getBf().duplicate().position(0).limit(bf.position()); + } + + public boolean hasRemaining() { + return bf.hasRemaining(); + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/Chunk.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/Chunk.java index afe2032c1..378f3b3e8 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/Chunk.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/Chunk.java @@ -18,379 +18,384 @@ import org.apache.geaflow.memory.metric.ChunkMetric; -/** - * This class is an adaptation of Netty's io.netty.buffer.PoolChunk. - */ +/** This class is an adaptation of Netty's io.netty.buffer.PoolChunk. */ public class Chunk implements ChunkMetric { - private static final int INTEGER_SIZE_MINUS_ONE = Integer.SIZE - 1; - - final AbstractMemoryPool pool; - final T memory; - final int offset; - - private final byte[] memoryMap; - private final byte[] depthMap; - private final Page[] subpages; - /** - * Used to determine if the requested capacity is equal to or greater than pageSize. - */ - private final int subpageOverflowMask; - private final int pageSize; - private final int pageShifts; - private final int maxOrder; - private final int chunkSize; - private final int log2ChunkSize; - private final int maxSubpageAllocs; - /** - * Used to mark memory as unusable. - */ - private final byte unusable; - - private int freeBytes; - - ChunkList parent; - Chunk prev; - Chunk next; - - Chunk(AbstractMemoryPool pool, T memory, int pageSize, int maxOrder, int pageShifts, int chunkSize, int offset) { - this.pool = pool; - this.memory = memory; - this.pageSize = pageSize; - this.pageShifts = pageShifts; - this.maxOrder = maxOrder; - this.chunkSize = chunkSize; - this.offset = offset; - unusable = (byte) (maxOrder + 1); - log2ChunkSize = log2(chunkSize); - subpageOverflowMask = ~(pageSize - 1); - freeBytes = chunkSize; - - assert maxOrder < 30 : "maxOrder should be < 30, but is: " + maxOrder; - maxSubpageAllocs = 1 << maxOrder; - - // Generate the memory map. - memoryMap = new byte[maxSubpageAllocs << 1]; - depthMap = new byte[memoryMap.length]; - int memoryMapIndex = 1; - for (int d = 0; d <= maxOrder; ++d) { // move down the tree one level at a time - int depth = 1 << d; - for (int p = 0; p < depth; ++p) { - // in each level traverse left to right and set value to the depth of subtree - memoryMap[memoryMapIndex] = (byte) d; - depthMap[memoryMapIndex] = (byte) d; - memoryMapIndex++; - } - } - - subpages = newPageArray(maxSubpageAllocs); + private static final int INTEGER_SIZE_MINUS_ONE = Integer.SIZE - 1; + + final AbstractMemoryPool pool; + final T memory; + final int offset; + + private final byte[] memoryMap; + private final byte[] depthMap; + private final Page[] subpages; + + /** Used to determine if the requested capacity is equal to or greater than pageSize. */ + private final int subpageOverflowMask; + + private final int pageSize; + private final int pageShifts; + private final int maxOrder; + private final int chunkSize; + private final int log2ChunkSize; + private final int maxSubpageAllocs; + + /** Used to mark memory as unusable. */ + private final byte unusable; + + private int freeBytes; + + ChunkList parent; + Chunk prev; + Chunk next; + + Chunk( + AbstractMemoryPool pool, + T memory, + int pageSize, + int maxOrder, + int pageShifts, + int chunkSize, + int offset) { + this.pool = pool; + this.memory = memory; + this.pageSize = pageSize; + this.pageShifts = pageShifts; + this.maxOrder = maxOrder; + this.chunkSize = chunkSize; + this.offset = offset; + unusable = (byte) (maxOrder + 1); + log2ChunkSize = log2(chunkSize); + subpageOverflowMask = ~(pageSize - 1); + freeBytes = chunkSize; + + assert maxOrder < 30 : "maxOrder should be < 30, but is: " + maxOrder; + maxSubpageAllocs = 1 << maxOrder; + + // Generate the memory map. + memoryMap = new byte[maxSubpageAllocs << 1]; + depthMap = new byte[memoryMap.length]; + int memoryMapIndex = 1; + for (int d = 0; d <= maxOrder; ++d) { // move down the tree one level at a time + int depth = 1 << d; + for (int p = 0; p < depth; ++p) { + // in each level traverse left to right and set value to the depth of subtree + memoryMap[memoryMapIndex] = (byte) d; + depthMap[memoryMapIndex] = (byte) d; + memoryMapIndex++; + } } - @SuppressWarnings("unchecked") - private Page[] newPageArray(int size) { - return new Page[size]; - } + subpages = newPageArray(maxSubpageAllocs); + } - @Override - public int usage() { - final int freeBytes; - synchronized (pool) { - freeBytes = this.freeBytes; - } - return usage(freeBytes); - } + @SuppressWarnings("unchecked") + private Page[] newPageArray(int size) { + return new Page[size]; + } - boolean isFree() { - return this.freeBytes == chunkSize; + @Override + public int usage() { + final int freeBytes; + synchronized (pool) { + freeBytes = this.freeBytes; } + return usage(freeBytes); + } - private int usage(int freeBytes) { - if (freeBytes == 0) { - return 100; - } - - if (freeBytes == chunkSize) { - return 0; - } + boolean isFree() { + return this.freeBytes == chunkSize; + } - int freePercentage = (int) (freeBytes * 100L / chunkSize); - if (freePercentage == 0) { - return 99; - } - return 100 - freePercentage; + private int usage(int freeBytes) { + if (freeBytes == 0) { + return 100; } - long allocate(int normCapacity) { - if ((normCapacity & subpageOverflowMask) != 0) { // >= pageSize - return allocateRun(normCapacity); - } else { - return allocateSubpage(normCapacity); - } + if (freeBytes == chunkSize) { + return 0; } - /** - * Update method used by allocate. - * This is triggered only when a successor is allocated and all its predecessors - * need to update their state - * The minimal depth at which subtree rooted at id has some free space - * - * @param id id - */ - private void updateParentsAlloc(int id) { - while (id > 1) { - int parentId = id >>> 1; - byte val1 = value(id); - byte val2 = value(id ^ 1); - byte val = val1 < val2 ? val1 : val2; - setValue(parentId, val); - id = parentId; - } + int freePercentage = (int) (freeBytes * 100L / chunkSize); + if (freePercentage == 0) { + return 99; } - - /** - * Update method used by free. - * This needs to handle the special case when both children are completely free - * in which case parent be directly allocated on request of size = child-size * 2 - * - * @param id id - */ - private void updateParentsFree(int id) { - int logChild = depth(id) + 1; - while (id > 1) { - int parentId = id >>> 1; - byte val1 = value(id); - byte val2 = value(id ^ 1); - logChild -= 1; // in first iteration equals log, subsequently reduce 1 from logChild as we traverse up - - if (val1 == logChild && val2 == logChild) { - setValue(parentId, (byte) (logChild - 1)); - } else { - byte val = val1 < val2 ? val1 : val2; - setValue(parentId, val); - } - - id = parentId; - } + return 100 - freePercentage; + } + + long allocate(int normCapacity) { + if ((normCapacity & subpageOverflowMask) != 0) { // >= pageSize + return allocateRun(normCapacity); + } else { + return allocateSubpage(normCapacity); } - - /** - * Algorithm to allocate an index in memoryMap when we query for a free node. - * at depth d - * - * @param d depth - * @return index in memoryMap - */ - private int allocateNode(int d) { - int id = 1; - int initial = -(1 << d); // has last d bits = 0 and rest all = 1 - byte val = value(id); - if (val > d) { // unusable - return -1; - } - while (val < d || (id & initial) == 0) { // id & initial == 1 << d for all ids at depth d, for < d it is 0 - id <<= 1; - val = value(id); - if (val > d) { - id ^= 1; - val = value(id); - } - } - byte value = value(id); - assert value == d && (id & initial) == 1 << d : String.format("val = %d, id & initial = %d, d = %d", - value, id & initial, d); - setValue(id, unusable); // mark as unusable - updateParentsAlloc(id); - return id; + } + + /** + * Update method used by allocate. This is triggered only when a successor is allocated and all + * its predecessors need to update their state The minimal depth at which subtree rooted at id has + * some free space + * + * @param id id + */ + private void updateParentsAlloc(int id) { + while (id > 1) { + int parentId = id >>> 1; + byte val1 = value(id); + byte val2 = value(id ^ 1); + byte val = val1 < val2 ? val1 : val2; + setValue(parentId, val); + id = parentId; } - - /** - * Allocate a run of pages (>=1). - * - * @param normCapacity normalized capacity - * @return index in memoryMap - */ - private long allocateRun(int normCapacity) { - int d = maxOrder - (log2(normCapacity) - pageShifts); - int id = allocateNode(d); - if (id < 0) { - return id; - } - freeBytes -= runLength(id); - return id; + } + + /** + * Update method used by free. This needs to handle the special case when both children are + * completely free in which case parent be directly allocated on request of size = child-size * 2 + * + * @param id id + */ + private void updateParentsFree(int id) { + int logChild = depth(id) + 1; + while (id > 1) { + int parentId = id >>> 1; + byte val1 = value(id); + byte val2 = value(id ^ 1); + logChild -= + 1; // in first iteration equals log, subsequently reduce 1 from logChild as we traverse up + + if (val1 == logChild && val2 == logChild) { + setValue(parentId, (byte) (logChild - 1)); + } else { + byte val = val1 < val2 ? val1 : val2; + setValue(parentId, val); + } + + id = parentId; } - - /** - * Create/ initialize a new PoolSubpage of normCapacity. - * Any PoolSubpage created/ initialized here is added to subpage pool in the PoolArena that owns this PoolChunk - * - * @param normCapacity normalized capacity - * @return index in memoryMap - */ - private long allocateSubpage(int normCapacity) { - - Page head = pool.getSuitablePageHead(normCapacity); - - synchronized (head) { - int d = maxOrder; // subpages are only be allocated from pages i.e., leaves - int id = allocateNode(d); - if (id < 0) { - return id; - } - - final Page[] subpages = this.subpages; - final int pageSize = this.pageSize; - - freeBytes -= pageSize; - - int subpageIdx = subpageIdx(id); - Page subpage = subpages[subpageIdx]; - if (subpage == null) { - subpage = new Page<>(head, this, id, runOffset(id), pageSize, normCapacity); - subpages[subpageIdx] = subpage; - } else { - subpage.init(head, normCapacity); - } - return subpage.allocate(); - } + } + + /** + * Algorithm to allocate an index in memoryMap when we query for a free node. at depth d + * + * @param d depth + * @return index in memoryMap + */ + private int allocateNode(int d) { + int id = 1; + int initial = -(1 << d); // has last d bits = 0 and rest all = 1 + byte val = value(id); + if (val > d) { // unusable + return -1; } - - /** - * Free a subpage or a run of pages. - * When a subpage is freed from PoolSubpage, it might be added back to subpage pool of the owning PoolArena - * If the subpage pool in PoolArena has at least one other PoolSubpage of given elemSize, we can - * completely free the owning Page so it is available for subsequent allocations - */ - void free(long handle) { - int memoryMapIdx = memoryMapIdx(handle); - int bitmapIdx = bitmapIdx(handle); - - if (bitmapIdx != 0) { // free a subpage - Page subpage = subpages[subpageIdx(memoryMapIdx)]; - assert subpage != null && subpage.doNotDestroy; - - // Obtain the head of the PoolSubPage pool that is owned by the PoolArena and synchronize on it. - // This is need as we may add it back and so alter the linked-list structure. - Page head = pool.getSuitablePageHead(subpage.elemSize); - synchronized (head) { - if (subpage.free(head, bitmapIdx & 0x3FFFFFFF)) { - return; - } - } - } - freeBytes += runLength(memoryMapIdx); - setValue(memoryMapIdx, depth(memoryMapIdx)); - updateParentsFree(memoryMapIdx); + while (val < d + || (id & initial) == 0) { // id & initial == 1 << d for all ids at depth d, for < d it is 0 + id <<= 1; + val = value(id); + if (val > d) { + id ^= 1; + val = value(id); + } } - - void initBuf(ByteBuf buf, long handle, int reqCapacity) { - int memoryMapIdx = memoryMapIdx(handle); - int bitmapIdx = bitmapIdx(handle); - if (bitmapIdx == 0) { - byte val = value(memoryMapIdx); - assert val == unusable : String.valueOf(val); - buf.init(this, handle, runOffset(memoryMapIdx) + offset, reqCapacity); - } else { - initBufWithSubpage(buf, handle, bitmapIdx, reqCapacity); - } + byte value = value(id); + assert value == d && (id & initial) == 1 << d + : String.format("val = %d, id & initial = %d, d = %d", value, id & initial, d); + setValue(id, unusable); // mark as unusable + updateParentsAlloc(id); + return id; + } + + /** + * Allocate a run of pages (>=1). + * + * @param normCapacity normalized capacity + * @return index in memoryMap + */ + private long allocateRun(int normCapacity) { + int d = maxOrder - (log2(normCapacity) - pageShifts); + int id = allocateNode(d); + if (id < 0) { + return id; } - - void initBufWithSubpage(ByteBuf buf, long handle, int reqCapacity) { - initBufWithSubpage(buf, handle, bitmapIdx(handle), reqCapacity); - } - - private void initBufWithSubpage(ByteBuf buf, long handle, int bitmapIdx, int reqCapacity) { - assert bitmapIdx != 0; - - int memoryMapIdx = memoryMapIdx(handle); - - Page subpage = subpages[subpageIdx(memoryMapIdx)]; - assert subpage.doNotDestroy; - assert reqCapacity <= subpage.elemSize; - - buf.init( - this, handle, - runOffset(memoryMapIdx) + (bitmapIdx & 0x3FFFFFFF) * subpage.elemSize + offset, - reqCapacity); - } - - private byte value(int id) { - return memoryMap[id]; - } - - private void setValue(int id, byte val) { - memoryMap[id] = val; - } - - private byte depth(int id) { - return depthMap[id]; - } - - private static int log2(int val) { - // compute the (0-based, with lsb = 0) position of highest set bit i.e, log2 - return INTEGER_SIZE_MINUS_ONE - Integer.numberOfLeadingZeros(val); - } - - private int runLength(int id) { - // represents the size in #bytes supported by node 'id' in the tree - return 1 << log2ChunkSize - depth(id); - } - - private int runOffset(int id) { - // represents the 0-based offset in #bytes from start of the byte-array chunk - int shift = id ^ 1 << depth(id); - return shift * runLength(id); - } - - private int subpageIdx(int memoryMapIdx) { - return memoryMapIdx ^ maxSubpageAllocs; // remove highest set bit, to get offset - } - - private static int memoryMapIdx(long handle) { - return (int) handle; - } - - private static int bitmapIdx(long handle) { - return (int) (handle >>> Integer.SIZE); - } - - @Override - public int chunkSize() { - return chunkSize; + freeBytes -= runLength(id); + return id; + } + + /** + * Create/ initialize a new PoolSubpage of normCapacity. Any PoolSubpage created/ initialized here + * is added to subpage pool in the PoolArena that owns this PoolChunk + * + * @param normCapacity normalized capacity + * @return index in memoryMap + */ + private long allocateSubpage(int normCapacity) { + + Page head = pool.getSuitablePageHead(normCapacity); + + synchronized (head) { + int d = maxOrder; // subpages are only be allocated from pages i.e., leaves + int id = allocateNode(d); + if (id < 0) { + return id; + } + + final Page[] subpages = this.subpages; + final int pageSize = this.pageSize; + + freeBytes -= pageSize; + + int subpageIdx = subpageIdx(id); + Page subpage = subpages[subpageIdx]; + if (subpage == null) { + subpage = new Page<>(head, this, id, runOffset(id), pageSize, normCapacity); + subpages[subpageIdx] = subpage; + } else { + subpage.init(head, normCapacity); + } + return subpage.allocate(); } - - @Override - public int freeBytes() { - synchronized (pool) { - return freeBytes; + } + + /** + * Free a subpage or a run of pages. When a subpage is freed from PoolSubpage, it might be added + * back to subpage pool of the owning PoolArena If the subpage pool in PoolArena has at least one + * other PoolSubpage of given elemSize, we can completely free the owning Page so it is available + * for subsequent allocations + */ + void free(long handle) { + int memoryMapIdx = memoryMapIdx(handle); + int bitmapIdx = bitmapIdx(handle); + + if (bitmapIdx != 0) { // free a subpage + Page subpage = subpages[subpageIdx(memoryMapIdx)]; + assert subpage != null && subpage.doNotDestroy; + + // Obtain the head of the PoolSubPage pool that is owned by the PoolArena and synchronize on + // it. + // This is need as we may add it back and so alter the linked-list structure. + Page head = pool.getSuitablePageHead(subpage.elemSize); + synchronized (head) { + if (subpage.free(head, bitmapIdx & 0x3FFFFFFF)) { + return; } + } } - - @Override - public int activeBytes() { - return chunkSize() - freeBytes(); + freeBytes += runLength(memoryMapIdx); + setValue(memoryMapIdx, depth(memoryMapIdx)); + updateParentsFree(memoryMapIdx); + } + + void initBuf(ByteBuf buf, long handle, int reqCapacity) { + int memoryMapIdx = memoryMapIdx(handle); + int bitmapIdx = bitmapIdx(handle); + if (bitmapIdx == 0) { + byte val = value(memoryMapIdx); + assert val == unusable : String.valueOf(val); + buf.init(this, handle, runOffset(memoryMapIdx) + offset, reqCapacity); + } else { + initBufWithSubpage(buf, handle, bitmapIdx, reqCapacity); } - - @Override - public String toString() { - final int freeBytes; - synchronized (pool) { - freeBytes = this.freeBytes; - } - - return new StringBuilder() - .append("Chunk(") - .append(Integer.toHexString(System.identityHashCode(this))) - .append(": ") - .append(usage(freeBytes)) - .append("%, ") - .append(chunkSize - freeBytes) - .append('/') - .append(chunkSize) - .append(')') - .toString(); + } + + void initBufWithSubpage(ByteBuf buf, long handle, int reqCapacity) { + initBufWithSubpage(buf, handle, bitmapIdx(handle), reqCapacity); + } + + private void initBufWithSubpage(ByteBuf buf, long handle, int bitmapIdx, int reqCapacity) { + assert bitmapIdx != 0; + + int memoryMapIdx = memoryMapIdx(handle); + + Page subpage = subpages[subpageIdx(memoryMapIdx)]; + assert subpage.doNotDestroy; + assert reqCapacity <= subpage.elemSize; + + buf.init( + this, + handle, + runOffset(memoryMapIdx) + (bitmapIdx & 0x3FFFFFFF) * subpage.elemSize + offset, + reqCapacity); + } + + private byte value(int id) { + return memoryMap[id]; + } + + private void setValue(int id, byte val) { + memoryMap[id] = val; + } + + private byte depth(int id) { + return depthMap[id]; + } + + private static int log2(int val) { + // compute the (0-based, with lsb = 0) position of highest set bit i.e, log2 + return INTEGER_SIZE_MINUS_ONE - Integer.numberOfLeadingZeros(val); + } + + private int runLength(int id) { + // represents the size in #bytes supported by node 'id' in the tree + return 1 << log2ChunkSize - depth(id); + } + + private int runOffset(int id) { + // represents the 0-based offset in #bytes from start of the byte-array chunk + int shift = id ^ 1 << depth(id); + return shift * runLength(id); + } + + private int subpageIdx(int memoryMapIdx) { + return memoryMapIdx ^ maxSubpageAllocs; // remove highest set bit, to get offset + } + + private static int memoryMapIdx(long handle) { + return (int) handle; + } + + private static int bitmapIdx(long handle) { + return (int) (handle >>> Integer.SIZE); + } + + @Override + public int chunkSize() { + return chunkSize; + } + + @Override + public int freeBytes() { + synchronized (pool) { + return freeBytes; } - - void destroy() { - pool.destroyChunk(this); + } + + @Override + public int activeBytes() { + return chunkSize() - freeBytes(); + } + + @Override + public String toString() { + final int freeBytes; + synchronized (pool) { + freeBytes = this.freeBytes; } + + return new StringBuilder() + .append("Chunk(") + .append(Integer.toHexString(System.identityHashCode(this))) + .append(": ") + .append(usage(freeBytes)) + .append("%, ") + .append(chunkSize - freeBytes) + .append('/') + .append(chunkSize) + .append(')') + .toString(); + } + + void destroy() { + pool.destroyChunk(this); + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/ChunkList.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/ChunkList.java index 4a2f08de2..97b3d3c51 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/ChunkList.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/ChunkList.java @@ -19,223 +19,226 @@ package org.apache.geaflow.memory; -import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; + import org.apache.geaflow.memory.metric.ChunkListMetric; import org.apache.geaflow.memory.metric.ChunkMetric; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class ChunkList implements ChunkListMetric { - - private static final Logger LOGGER = LoggerFactory.getLogger(ChunkList.class); - - private static final Iterator EMPTY_METRICS = Collections.emptyList().iterator(); - - private final AbstractMemoryPool pool; +import com.google.common.base.Preconditions; - private ChunkList nextList; - private ChunkList preList; +public class ChunkList implements ChunkListMetric { - private Chunk head; + private static final Logger LOGGER = LoggerFactory.getLogger(ChunkList.class); - private final int minUsage; - private final int maxUsage; - AtomicInteger freeChunk; + private static final Iterator EMPTY_METRICS = + Collections.emptyList().iterator(); - public ChunkList(AbstractMemoryPool pool, int minUsage, int maxUsage) { - this.pool = pool; - this.minUsage = minUsage; - this.maxUsage = maxUsage; - freeChunk = new AtomicInteger(0); - } + private final AbstractMemoryPool pool; - public void setNextList(ChunkList nextList) { - this.nextList = nextList; - } + private ChunkList nextList; + private ChunkList preList; - public void setPreList(ChunkList preList) { - this.preList = preList; - } + private Chunk head; - boolean allocate(ByteBuf buf, int capacity) { - - try { - if (head == null) { - return false; - } - - for (Chunk chunk = head; ; ) { - if (chunk.isFree()) { - freeChunk.addAndGet(-1); - } - long handle = chunk.allocate(capacity); - if (handle < 0) { - chunk = chunk.next; - if (chunk == null) { - return false; - } - } else { - chunk.initBuf(buf, handle, capacity); - if (chunk.usage() >= maxUsage()) { - remove(chunk); - nextList.add(chunk); - } - return true; - } - } - } catch (Throwable t) { - LOGGER.error(String.format("max=%d,min=%d allocate failed!", maxUsage, minUsage), t); - throw t; - } - } + private final int minUsage; + private final int maxUsage; + AtomicInteger freeChunk; - boolean free(Chunk chunk, long handle) { - chunk.free(handle); - if (chunk.usage() < minUsage) { - remove(chunk); - return internalMove(chunk); - } - return true; - } + public ChunkList(AbstractMemoryPool pool, int minUsage, int maxUsage) { + this.pool = pool; + this.minUsage = minUsage; + this.maxUsage = maxUsage; + freeChunk = new AtomicInteger(0); + } - private boolean move(Chunk chunk) { - assert chunk.usage() < maxUsage; + public void setNextList(ChunkList nextList) { + this.nextList = nextList; + } - if (chunk.usage() < minUsage) { - return internalMove(chunk); - } + public void setPreList(ChunkList preList) { + this.preList = preList; + } - internalAdd(chunk); - return true; - } + boolean allocate(ByteBuf buf, int capacity) { - void add(Chunk chunk) { - if (chunk.usage() >= maxUsage) { - nextList.add(chunk); - return; - } - internalAdd(chunk); - } + try { + if (head == null) { + return false; + } - void internalAdd(Chunk chunk) { - chunk.parent = this; - if (head == null) { - head = chunk; - chunk.prev = null; - chunk.next = null; - } else { - chunk.prev = null; - chunk.next = head; - head.prev = chunk; - head = chunk; - } + for (Chunk chunk = head; ; ) { if (chunk.isFree()) { - freeChunk.addAndGet(1); - } - } - - private boolean internalMove(Chunk chunk) { - if (preList == null) { - // 堆内内存直接释放 - if (pool.getMemoryMode() == MemoryMode.ON_HEAP) { - Preconditions.checkArgument(chunk.usage() == 0); - return false; - } else { - // 堆外先放回,等待触发缩容 - internalAdd(chunk); - //缩容入口 - pool.shrinkCapacity(); - - return true; - } + freeChunk.addAndGet(-1); } - - return preList.move(chunk); - } - - void remove(Chunk cur) { - if (cur == head) { - head = cur.next; - if (head != null) { - head.prev = null; - } + long handle = chunk.allocate(capacity); + if (handle < 0) { + chunk = chunk.next; + if (chunk == null) { + return false; + } } else { - Chunk next = cur.next; - cur.prev.next = next; - if (next != null) { - next.prev = cur.prev; - } - } - } - - @Override - public int minUsage() { - return Math.max(1, this.minUsage); - } - - @Override - public int maxUsage() { - return Math.min(this.maxUsage, 100); - } - - @Override - public Iterator iterator() { - synchronized (pool) { - if (head == null) { - return EMPTY_METRICS; - } - List metrics = new ArrayList<>(); - for (Chunk cur = head; ; ) { - metrics.add(cur); - cur = cur.next; - if (cur == null) { - break; - } - } - return metrics.iterator(); + chunk.initBuf(buf, handle, capacity); + if (chunk.usage() >= maxUsage()) { + remove(chunk); + nextList.add(chunk); + } + return true; } - } + } + } catch (Throwable t) { + LOGGER.error(String.format("max=%d,min=%d allocate failed!", maxUsage, minUsage), t); + throw t; + } + } + + boolean free(Chunk chunk, long handle) { + chunk.free(handle); + if (chunk.usage() < minUsage) { + remove(chunk); + return internalMove(chunk); + } + return true; + } + + private boolean move(Chunk chunk) { + assert chunk.usage() < maxUsage; + + if (chunk.usage() < minUsage) { + return internalMove(chunk); + } + + internalAdd(chunk); + return true; + } + + void add(Chunk chunk) { + if (chunk.usage() >= maxUsage) { + nextList.add(chunk); + return; + } + internalAdd(chunk); + } + + void internalAdd(Chunk chunk) { + chunk.parent = this; + if (head == null) { + head = chunk; + chunk.prev = null; + chunk.next = null; + } else { + chunk.prev = null; + chunk.next = head; + head.prev = chunk; + head = chunk; + } + if (chunk.isFree()) { + freeChunk.addAndGet(1); + } + } + + private boolean internalMove(Chunk chunk) { + if (preList == null) { + // 堆内内存直接释放 + if (pool.getMemoryMode() == MemoryMode.ON_HEAP) { + Preconditions.checkArgument(chunk.usage() == 0); + return false; + } else { + // 堆外先放回,等待触发缩容 + internalAdd(chunk); + // 缩容入口 + pool.shrinkCapacity(); - void destroy() { - Chunk cur = head; - while (cur != null) { - pool.destroyChunk(cur); - cur = cur.next; + return true; + } + } + + return preList.move(chunk); + } + + void remove(Chunk cur) { + if (cur == head) { + head = cur.next; + if (head != null) { + head.prev = null; + } + } else { + Chunk next = cur.next; + cur.prev.next = next; + if (next != null) { + next.prev = cur.prev; + } + } + } + + @Override + public int minUsage() { + return Math.max(1, this.minUsage); + } + + @Override + public int maxUsage() { + return Math.min(this.maxUsage, 100); + } + + @Override + public Iterator iterator() { + synchronized (pool) { + if (head == null) { + return EMPTY_METRICS; + } + List metrics = new ArrayList<>(); + for (Chunk cur = head; ; ) { + metrics.add(cur); + cur = cur.next; + if (cur == null) { + break; } - head = null; - } - - int freeChunkNum() { - return freeChunk.get(); - } - - Chunk getHead() { - return head; - } - - @Override - public String toString() { - StringBuilder buf = new StringBuilder(); - synchronized (pool) { - if (head == null) { - return "[none]"; - } - buf.append("[free: ").append(freeChunkNum()); - int size = 0; - for (Chunk cur = head; ; ) { - size++; - cur = cur.next; - if (cur == null) { - break; - } - } - buf.append(", total: ").append(size).append("]"); + } + return metrics.iterator(); + } + } + + void destroy() { + Chunk cur = head; + while (cur != null) { + pool.destroyChunk(cur); + cur = cur.next; + } + head = null; + } + + int freeChunkNum() { + return freeChunk.get(); + } + + Chunk getHead() { + return head; + } + + @Override + public String toString() { + StringBuilder buf = new StringBuilder(); + synchronized (pool) { + if (head == null) { + return "[none]"; + } + buf.append("[free: ").append(freeChunkNum()); + int size = 0; + for (Chunk cur = head; ; ) { + size++; + cur = cur.next; + if (cur == null) { + break; } - return buf.toString(); + } + buf.append(", total: ").append(size).append("]"); } + return buf.toString(); + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/DirectMemory.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/DirectMemory.java index 300ce40c4..3d4f1e8a4 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/DirectMemory.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/DirectMemory.java @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.regex.Matcher; import java.util.regex.Pattern; + import org.apache.geaflow.common.utils.SystemArgsUtil; import org.apache.geaflow.memory.cleaner.Cleaner; import org.apache.geaflow.memory.cleaner.CleanerJava6; @@ -31,381 +32,373 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * This class is an adaptation of Netty's io.netty.util.internal.DirectMemory. - */ +/** This class is an adaptation of Netty's io.netty.util.internal.DirectMemory. */ public final class DirectMemory { - private static final Logger LOGGER = LoggerFactory.getLogger(DirectMemory.class); - - private static final Pattern MAX_DIRECT_MEMORY_SIZE_ARG_PATTERN = Pattern - .compile("\\s*-XX:MaxDirectMemorySize\\s*=\\s*([0-9]+)\\s*([kKmMgG]?)\\s*$"); - private static final Throwable UNSAFE_UNAVAILABILITY_CAUSE = unsafeUnavailabilityCause0(); - private static final long MAX_DIRECT_MEMORY = maxDirectMemory0(); - private static final long BYTE_ARRAY_BASE_OFFSET = byteArrayBaseOffset0(); - private static final boolean USE_DIRECT_BUFFER_NO_CLEANER; - private static final AtomicLong DIRECT_MEMORY_COUNTER; - private static final long DIRECT_MEMORY_LIMIT; - private static final Cleaner CLEANER; - private static final Cleaner NOOP = buffer -> { + private static final Logger LOGGER = LoggerFactory.getLogger(DirectMemory.class); + + private static final Pattern MAX_DIRECT_MEMORY_SIZE_ARG_PATTERN = + Pattern.compile("\\s*-XX:MaxDirectMemorySize\\s*=\\s*([0-9]+)\\s*([kKmMgG]?)\\s*$"); + private static final Throwable UNSAFE_UNAVAILABILITY_CAUSE = unsafeUnavailabilityCause0(); + private static final long MAX_DIRECT_MEMORY = maxDirectMemory0(); + private static final long BYTE_ARRAY_BASE_OFFSET = byteArrayBaseOffset0(); + private static final boolean USE_DIRECT_BUFFER_NO_CLEANER; + private static final AtomicLong DIRECT_MEMORY_COUNTER; + private static final long DIRECT_MEMORY_LIMIT; + private static final Cleaner CLEANER; + private static final Cleaner NOOP = + buffer -> { // NOOP - }; - - static { - // Here is how the system property is used: - // - // * < 0 - Don't use cleaner, and inherit max direct memory from java. In this case the - // "practical max direct memory" would be 2 * max memory as defined by the JDK. - // * == 0 - Use cleaner, Netty will not enforce max memory, and instead will defer to JDK. - // * > 0 - Don't use cleaner. This will limit Netty's total direct memory - // (note: that JDK's direct memory limit is independent of this). - long maxDirectMemory = -1; - - if (!hasUnsafe() || !PlatformDependent - .hasDirectBufferNoCleanerConstructor()) { - USE_DIRECT_BUFFER_NO_CLEANER = false; - DIRECT_MEMORY_COUNTER = null; - } else { - USE_DIRECT_BUFFER_NO_CLEANER = true; - maxDirectMemory = MAX_DIRECT_MEMORY; - if (maxDirectMemory <= 0) { - DIRECT_MEMORY_COUNTER = null; - } else { - DIRECT_MEMORY_COUNTER = new AtomicLong(); - } - } - DIRECT_MEMORY_LIMIT = maxDirectMemory >= 1 ? maxDirectMemory : MAX_DIRECT_MEMORY; - - // only direct to method if we are not running on android. - // See https://github.com/netty/netty/issues/2604 - if (javaVersion() >= 9) { - CLEANER = CleanerJava9.isSupported() ? new CleanerJava9() : NOOP; - } else { - CLEANER = CleanerJava6.isSupported() ? new CleanerJava6() : NOOP; - } - - /* - * We do not want to log this message if unsafe is explicitly disabled. Do not remove the - * explicit no unsafe - * guard. - */ - if (CLEANER == NOOP) { - LOGGER.info( - "Your platform does not provide complete low-level API for accessing direct buffers reliably. " - , "Unless explicitly requested, heap buffer will always be preferred to avoid potential system instability."); - } - } - - private DirectMemory() { - // only static method supported - } - - /** - * Return the version of Java under which this library is used. - */ - public static int javaVersion() { - return PlatformDependent.javaVersion(); - } - - /** - * Return {@code true} if {@code sun.misc.Unsafe} was found on the classpath and can be used for - * accelerated. - * direct memory access. - */ - public static boolean hasUnsafe() { - return UNSAFE_UNAVAILABILITY_CAUSE == null; - } - - /** - * Raises an exception bypassing compiler checks for checked exceptions. - */ - public static void throwException(Throwable t) { - if (hasUnsafe()) { - PlatformDependent.throwException(t); - } else { - DirectMemory.throwException0(t); - } - } - - @SuppressWarnings("unchecked") - private static void throwException0(Throwable t) throws E { - throw (E) t; - } - - /** - * Try to deallocate the specified direct {@link ByteBuffer}. Please note this method does - * nothing if - * the current platform does not support this operation or the specified buffer is not a direct - * buffer. - */ - public static void freeDirectBuffer(ByteBuffer buffer) { - CLEANER.freeDirectBuffer(buffer); - } - - public static long directBufferAddress(ByteBuffer buffer) { - return PlatformDependent.directBufferAddress(buffer); - } - - public static Object getObject(Object object, long fieldOffset) { - return PlatformDependent.getObject(object, fieldOffset); - } - - - public static byte getByte(long address) { - return PlatformDependent.getByte(address); - } - - public static short getShort(long address) { - return PlatformDependent.getShort(address); - } - - public static int getInt(long address) { - return PlatformDependent.getInt(address); - } - - public static long getLong(long address) { - return PlatformDependent.getLong(address); - } - - public static void putByte(long address, byte value) { - PlatformDependent.putByte(address, value); - } - - public static void putShort(long address, short value) { - PlatformDependent.putShort(address, value); - } - - public static void putInt(long address, int value) { - PlatformDependent.putInt(address, value); - } - - public static void putLong(long address, long value) { - PlatformDependent.putLong(address, value); - } - - public static long objectFieldOffset(Field field) { - return PlatformDependent.objectFieldOffset(field); - } - - public static void copyMemory(long srcAddr, long dstAddr, long length) { - PlatformDependent.copyMemory(srcAddr, dstAddr, length); - } - - public static void copyMemory(long srcAddr, byte[] dst, int dstIndex, long length) { - PlatformDependent.copyMemory(null, srcAddr, dst, BYTE_ARRAY_BASE_OFFSET + dstIndex, length); - } - - public static void copyMemory(byte[] src, int srcIndex, long dstAddr, long length) { - PlatformDependent.copyMemory(src, BYTE_ARRAY_BASE_OFFSET + srcIndex, null, dstAddr, length); - } - - public static void freeMemory(long address) { - PlatformDependent.freeMemory(address); - } - - public static void setMemory(long address, long bytes, byte value) { - PlatformDependent.setMemory(address, bytes, value); - } - - static ByteBuffer[] splitBuffer(ByteBuffer byteBuffer, int len) { - long address = directBufferAddress(byteBuffer); - int inSize = byteBuffer.capacity() / len; - ByteBuffer[] bfs = new ByteBuffer[inSize]; - int offset = 0; - - for (int i = 0; i < inSize; i++) { - bfs[i] = PlatformDependent.newDirectBuffer(address + offset, len); - offset += len; - } - return bfs; - } - - static ByteBuffer mergeBuffer(ByteBuffer bf1, ByteBuffer bf2) { - return PlatformDependent - .newDirectBuffer(directBufferAddress(bf1), bf1.capacity() + bf2.capacity()); - } - - /** - * Allocate a new {@link ByteBuffer} with the given {@code capacity}. {@link ByteBuffer}s - * allocated with - * this method MUST be deallocated via - * {@link #freeDirectNoCleaner(ByteBuffer)}. - */ - public static ByteBuffer allocateDirectNoCleaner(int capacity) { - assert USE_DIRECT_BUFFER_NO_CLEANER; - - incrementMemoryCounter(capacity); - try { - return PlatformDependent.allocateDirectNoCleaner(capacity); - } catch (Throwable e) { - decrementMemoryCounter(capacity); - throwException(e); - return null; - } - } - - /** - * This method MUST only be called for {@link ByteBuffer}s that were allocated - * via - * {@link #allocateDirectNoCleaner(int)}. - */ - public static void freeDirectNoCleaner(ByteBuffer buffer) { - assert USE_DIRECT_BUFFER_NO_CLEANER; - - int capacity = buffer.capacity(); - PlatformDependent.freeMemory(PlatformDependent.directBufferAddress(buffer)); - decrementMemoryCounter(capacity); - } - - private static void incrementMemoryCounter(int capacity) { - if (DIRECT_MEMORY_COUNTER != null) { - long newUsedMemory = DIRECT_MEMORY_COUNTER.addAndGet(capacity); - if (newUsedMemory > DIRECT_MEMORY_LIMIT) { - DIRECT_MEMORY_COUNTER.addAndGet(-capacity); - throw new GeaflowOutOfMemoryException( - "failed to allocate " + capacity + " byte(s) of direct memory (used: " + ( - newUsedMemory - capacity) + ", max: " + DIRECT_MEMORY_LIMIT + ')'); - } - } - } - - private static void decrementMemoryCounter(int capacity) { - if (DIRECT_MEMORY_COUNTER != null) { - long usedMemory = DIRECT_MEMORY_COUNTER.addAndGet(-capacity); - assert usedMemory >= 0; - } - } - - public static boolean useDirectBufferNoCleaner() { - return USE_DIRECT_BUFFER_NO_CLEANER; - } - - public static sun.misc.Unsafe unsafe() { - return PlatformDependent.UNSAFE; - } - - /** - * Return the system {@link ClassLoader}. + }; + + static { + // Here is how the system property is used: + // + // * < 0 - Don't use cleaner, and inherit max direct memory from java. In this case the + // "practical max direct memory" would be 2 * max memory as defined by the JDK. + // * == 0 - Use cleaner, Netty will not enforce max memory, and instead will defer to JDK. + // * > 0 - Don't use cleaner. This will limit Netty's total direct memory + // (note: that JDK's direct memory limit is independent of this). + long maxDirectMemory = -1; + + if (!hasUnsafe() || !PlatformDependent.hasDirectBufferNoCleanerConstructor()) { + USE_DIRECT_BUFFER_NO_CLEANER = false; + DIRECT_MEMORY_COUNTER = null; + } else { + USE_DIRECT_BUFFER_NO_CLEANER = true; + maxDirectMemory = MAX_DIRECT_MEMORY; + if (maxDirectMemory <= 0) { + DIRECT_MEMORY_COUNTER = null; + } else { + DIRECT_MEMORY_COUNTER = new AtomicLong(); + } + } + DIRECT_MEMORY_LIMIT = maxDirectMemory >= 1 ? maxDirectMemory : MAX_DIRECT_MEMORY; + + // only direct to method if we are not running on android. + // See https://github.com/netty/netty/issues/2604 + if (javaVersion() >= 9) { + CLEANER = CleanerJava9.isSupported() ? new CleanerJava9() : NOOP; + } else { + CLEANER = CleanerJava6.isSupported() ? new CleanerJava6() : NOOP; + } + + /* + * We do not want to log this message if unsafe is explicitly disabled. Do not remove the + * explicit no unsafe + * guard. */ - public static ClassLoader getSystemClassLoader() { - return PlatformDependent.getSystemClassLoader(); - } - - private static Throwable unsafeUnavailabilityCause0() { - Throwable cause = PlatformDependent.getUnsafeUnavailabilityCause(); - if (cause != null) { - return cause; + if (CLEANER == NOOP) { + LOGGER.info( + "Your platform does not provide complete low-level API for accessing direct buffers" + + " reliably. ", + "Unless explicitly requested, heap buffer will always be preferred to avoid potential" + + " system instability."); + } + } + + private DirectMemory() { + // only static method supported + } + + /** Return the version of Java under which this library is used. */ + public static int javaVersion() { + return PlatformDependent.javaVersion(); + } + + /** + * Return {@code true} if {@code sun.misc.Unsafe} was found on the classpath and can be used for + * accelerated. direct memory access. + */ + public static boolean hasUnsafe() { + return UNSAFE_UNAVAILABILITY_CAUSE == null; + } + + /** Raises an exception bypassing compiler checks for checked exceptions. */ + public static void throwException(Throwable t) { + if (hasUnsafe()) { + PlatformDependent.throwException(t); + } else { + DirectMemory.throwException0(t); + } + } + + @SuppressWarnings("unchecked") + private static void throwException0(Throwable t) throws E { + throw (E) t; + } + + /** + * Try to deallocate the specified direct {@link ByteBuffer}. Please note this method does nothing + * if the current platform does not support this operation or the specified buffer is not a direct + * buffer. + */ + public static void freeDirectBuffer(ByteBuffer buffer) { + CLEANER.freeDirectBuffer(buffer); + } + + public static long directBufferAddress(ByteBuffer buffer) { + return PlatformDependent.directBufferAddress(buffer); + } + + public static Object getObject(Object object, long fieldOffset) { + return PlatformDependent.getObject(object, fieldOffset); + } + + public static byte getByte(long address) { + return PlatformDependent.getByte(address); + } + + public static short getShort(long address) { + return PlatformDependent.getShort(address); + } + + public static int getInt(long address) { + return PlatformDependent.getInt(address); + } + + public static long getLong(long address) { + return PlatformDependent.getLong(address); + } + + public static void putByte(long address, byte value) { + PlatformDependent.putByte(address, value); + } + + public static void putShort(long address, short value) { + PlatformDependent.putShort(address, value); + } + + public static void putInt(long address, int value) { + PlatformDependent.putInt(address, value); + } + + public static void putLong(long address, long value) { + PlatformDependent.putLong(address, value); + } + + public static long objectFieldOffset(Field field) { + return PlatformDependent.objectFieldOffset(field); + } + + public static void copyMemory(long srcAddr, long dstAddr, long length) { + PlatformDependent.copyMemory(srcAddr, dstAddr, length); + } + + public static void copyMemory(long srcAddr, byte[] dst, int dstIndex, long length) { + PlatformDependent.copyMemory(null, srcAddr, dst, BYTE_ARRAY_BASE_OFFSET + dstIndex, length); + } + + public static void copyMemory(byte[] src, int srcIndex, long dstAddr, long length) { + PlatformDependent.copyMemory(src, BYTE_ARRAY_BASE_OFFSET + srcIndex, null, dstAddr, length); + } + + public static void freeMemory(long address) { + PlatformDependent.freeMemory(address); + } + + public static void setMemory(long address, long bytes, byte value) { + PlatformDependent.setMemory(address, bytes, value); + } + + static ByteBuffer[] splitBuffer(ByteBuffer byteBuffer, int len) { + long address = directBufferAddress(byteBuffer); + int inSize = byteBuffer.capacity() / len; + ByteBuffer[] bfs = new ByteBuffer[inSize]; + int offset = 0; + + for (int i = 0; i < inSize; i++) { + bfs[i] = PlatformDependent.newDirectBuffer(address + offset, len); + offset += len; + } + return bfs; + } + + static ByteBuffer mergeBuffer(ByteBuffer bf1, ByteBuffer bf2) { + return PlatformDependent.newDirectBuffer( + directBufferAddress(bf1), bf1.capacity() + bf2.capacity()); + } + + /** + * Allocate a new {@link ByteBuffer} with the given {@code capacity}. {@link ByteBuffer}s + * allocated with this method MUST be deallocated via {@link + * #freeDirectNoCleaner(ByteBuffer)}. + */ + public static ByteBuffer allocateDirectNoCleaner(int capacity) { + assert USE_DIRECT_BUFFER_NO_CLEANER; + + incrementMemoryCounter(capacity); + try { + return PlatformDependent.allocateDirectNoCleaner(capacity); + } catch (Throwable e) { + decrementMemoryCounter(capacity); + throwException(e); + return null; + } + } + + /** + * This method MUST only be called for {@link ByteBuffer}s that were allocated + * via {@link #allocateDirectNoCleaner(int)}. + */ + public static void freeDirectNoCleaner(ByteBuffer buffer) { + assert USE_DIRECT_BUFFER_NO_CLEANER; + + int capacity = buffer.capacity(); + PlatformDependent.freeMemory(PlatformDependent.directBufferAddress(buffer)); + decrementMemoryCounter(capacity); + } + + private static void incrementMemoryCounter(int capacity) { + if (DIRECT_MEMORY_COUNTER != null) { + long newUsedMemory = DIRECT_MEMORY_COUNTER.addAndGet(capacity); + if (newUsedMemory > DIRECT_MEMORY_LIMIT) { + DIRECT_MEMORY_COUNTER.addAndGet(-capacity); + throw new GeaflowOutOfMemoryException( + "failed to allocate " + + capacity + + " byte(s) of direct memory (used: " + + (newUsedMemory - capacity) + + ", max: " + + DIRECT_MEMORY_LIMIT + + ')'); + } + } + } + + private static void decrementMemoryCounter(int capacity) { + if (DIRECT_MEMORY_COUNTER != null) { + long usedMemory = DIRECT_MEMORY_COUNTER.addAndGet(-capacity); + assert usedMemory >= 0; + } + } + + public static boolean useDirectBufferNoCleaner() { + return USE_DIRECT_BUFFER_NO_CLEANER; + } + + public static sun.misc.Unsafe unsafe() { + return PlatformDependent.UNSAFE; + } + + /** Return the system {@link ClassLoader}. */ + public static ClassLoader getSystemClassLoader() { + return PlatformDependent.getSystemClassLoader(); + } + + private static Throwable unsafeUnavailabilityCause0() { + Throwable cause = PlatformDependent.getUnsafeUnavailabilityCause(); + if (cause != null) { + return cause; + } + + try { + boolean hasUnsafe = PlatformDependent.hasUnsafe(); + LOGGER.debug("sun.misc.Unsafe: {}", hasUnsafe ? "available" : "unavailable"); + return hasUnsafe ? null : PlatformDependent.getUnsafeUnavailabilityCause(); + } catch (Throwable t) { + LOGGER.trace("Could not determine if Unsafe is available", t); + // Probably failed to initialize PlatformDependent0. + return new UnsupportedOperationException("Could not determine if Unsafe is available", t); + } + } + + public static long maxDirectMemory0() { + long maxDirectMemory = 0; + + ClassLoader systemClassLoader = null; + try { + systemClassLoader = getSystemClassLoader(); + + // When using IBM J9 / Eclipse OpenJ9 we should not use VM.maxDirectMemory() as it + // not reflects the + // correct value. + // See: + // - https://github.com/netty/netty/issues/7654 + String vmName = SystemArgsUtil.get("java.vm.name", "").toLowerCase(); + if (!vmName.startsWith("ibm j9") + // https://github.com/eclipse/openj9/blob/openj9-0.8 + // .0/runtime/include/vendor_version.h#L53 + && !vmName.startsWith("eclipse openj9")) { + // Try to build from sun.misc.VM.maxDirectMemory() which should be most accurate. + Class vmClass = Class.forName("sun.misc.VM", true, systemClassLoader); + Method m = vmClass.getDeclaredMethod("maxDirectMemory"); + maxDirectMemory = ((Number) m.invoke(null)).longValue(); + } + } catch (Throwable ignored) { + LOGGER.warn("fail maxDirectMemory0", ignored); + } + + if (maxDirectMemory > 0) { + LOGGER.info("maxDirectMemory: {} bytes from sun.misc.VM", maxDirectMemory); + return maxDirectMemory; + } + + List vmArgs = null; + try { + // Now try to build the JVM option (-XX:MaxDirectMemorySize) and parse it. + // Note that we are using reflection because Android doesn't have these classes. + Class mgmtFactoryClass = + Class.forName("java.lang.management.ManagementFactory", true, systemClassLoader); + Class runtimeClass = + Class.forName("java.lang.management.RuntimeMXBean", true, systemClassLoader); + + Object runtime = mgmtFactoryClass.getDeclaredMethod("getRuntimeMXBean").invoke(null); + + vmArgs = (List) runtimeClass.getDeclaredMethod("getInputArguments").invoke(runtime); + } catch (Throwable ignored) { + LOGGER.warn("fail maxDirectMemory0", ignored); + } + + return maxDirectMemoryFromJVMOption(vmArgs); + } + + public static long maxDirectMemoryFromJVMOption(List vmArgs) { + long maxDirectMemory = 0; + try { + for (int i = vmArgs.size() - 1; i >= 0; i--) { + Matcher m = MAX_DIRECT_MEMORY_SIZE_ARG_PATTERN.matcher(vmArgs.get(i)); + if (!m.matches()) { + continue; } - try { - boolean hasUnsafe = PlatformDependent.hasUnsafe(); - LOGGER.debug("sun.misc.Unsafe: {}", hasUnsafe ? "available" : "unavailable"); - return hasUnsafe ? null : PlatformDependent.getUnsafeUnavailabilityCause(); - } catch (Throwable t) { - LOGGER.trace("Could not determine if Unsafe is available", t); - // Probably failed to initialize PlatformDependent0. - return new UnsupportedOperationException("Could not determine if Unsafe is available", - t); + maxDirectMemory = Long.parseLong(m.group(1)); + switch (m.group(2).charAt(0)) { + case 'k': + case 'K': + maxDirectMemory *= 1024; + break; + case 'm': + case 'M': + maxDirectMemory *= 1024 * 1024; + break; + case 'g': + case 'G': + maxDirectMemory *= 1024 * 1024 * 1024; + break; + default: + throw new IllegalAccessException(); } + break; + } + } catch (Throwable ignored) { + LOGGER.warn("fail maxDirectMemory0", ignored); } - public static long maxDirectMemory0() { - long maxDirectMemory = 0; - - ClassLoader systemClassLoader = null; - try { - systemClassLoader = getSystemClassLoader(); - - // When using IBM J9 / Eclipse OpenJ9 we should not use VM.maxDirectMemory() as it - // not reflects the - // correct value. - // See: - // - https://github.com/netty/netty/issues/7654 - String vmName = SystemArgsUtil.get("java.vm.name", "").toLowerCase(); - if (!vmName.startsWith("ibm j9") - // https://github.com/eclipse/openj9/blob/openj9-0.8 - // .0/runtime/include/vendor_version.h#L53 - && !vmName.startsWith("eclipse openj9")) { - // Try to build from sun.misc.VM.maxDirectMemory() which should be most accurate. - Class vmClass = Class.forName("sun.misc.VM", true, systemClassLoader); - Method m = vmClass.getDeclaredMethod("maxDirectMemory"); - maxDirectMemory = ((Number) m.invoke(null)).longValue(); - } - } catch (Throwable ignored) { - LOGGER.warn("fail maxDirectMemory0", ignored); - } - - if (maxDirectMemory > 0) { - LOGGER.info("maxDirectMemory: {} bytes from sun.misc.VM", maxDirectMemory); - return maxDirectMemory; - } - - List vmArgs = null; - try { - // Now try to build the JVM option (-XX:MaxDirectMemorySize) and parse it. - // Note that we are using reflection because Android doesn't have these classes. - Class mgmtFactoryClass = Class - .forName("java.lang.management.ManagementFactory", true, systemClassLoader); - Class runtimeClass = Class - .forName("java.lang.management.RuntimeMXBean", true, systemClassLoader); - - Object runtime = mgmtFactoryClass.getDeclaredMethod("getRuntimeMXBean").invoke(null); - - vmArgs = (List) runtimeClass - .getDeclaredMethod("getInputArguments").invoke(runtime); - } catch (Throwable ignored) { - LOGGER.warn("fail maxDirectMemory0", ignored); - } - - return maxDirectMemoryFromJVMOption(vmArgs); + if (maxDirectMemory <= 0) { + maxDirectMemory = Runtime.getRuntime().maxMemory(); + LOGGER.info("maxDirectMemory: {} bytes (maybe) from jvm", maxDirectMemory); + } else { + LOGGER.info("maxDirectMemory: {} bytes from args", maxDirectMemory); } - public static long maxDirectMemoryFromJVMOption(List vmArgs) { - long maxDirectMemory = 0; - try { - for (int i = vmArgs.size() - 1; i >= 0; i--) { - Matcher m = MAX_DIRECT_MEMORY_SIZE_ARG_PATTERN.matcher(vmArgs.get(i)); - if (!m.matches()) { - continue; - } - - maxDirectMemory = Long.parseLong(m.group(1)); - switch (m.group(2).charAt(0)) { - case 'k': - case 'K': - maxDirectMemory *= 1024; - break; - case 'm': - case 'M': - maxDirectMemory *= 1024 * 1024; - break; - case 'g': - case 'G': - maxDirectMemory *= 1024 * 1024 * 1024; - break; - default: - throw new IllegalAccessException(); - } - break; - } - } catch (Throwable ignored) { - LOGGER.warn("fail maxDirectMemory0", ignored); - } - - if (maxDirectMemory <= 0) { - maxDirectMemory = Runtime.getRuntime().maxMemory(); - LOGGER.info("maxDirectMemory: {} bytes (maybe) from jvm", maxDirectMemory); - } else { - LOGGER.info("maxDirectMemory: {} bytes from args", maxDirectMemory); - } + return maxDirectMemory; + } - return maxDirectMemory; - } - - private static long byteArrayBaseOffset0() { - if (!hasUnsafe()) { - return -1; - } - return PlatformDependent.byteArrayBaseOffset(); + private static long byteArrayBaseOffset0() { + if (!hasUnsafe()) { + return -1; } + return PlatformDependent.byteArrayBaseOffset(); + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/DirectMemoryPool.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/DirectMemoryPool.java index 4176130f5..d9c693be8 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/DirectMemoryPool.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/DirectMemoryPool.java @@ -20,187 +20,205 @@ package org.apache.geaflow.memory; import java.nio.ByteBuffer; + import org.apache.geaflow.memory.config.MemoryConfigKeys; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class DirectMemoryPool extends AbstractMemoryPool { - private static final Logger LOGGER = LoggerFactory.getLogger(DirectMemoryPool.class); + private static final Logger LOGGER = LoggerFactory.getLogger(DirectMemoryPool.class); + + private final int maxChunkNum; + private final int initChunkNum; + private FreeChunkStatistics statistics; + private int count = 0; + private static final int MAX_COUNT = 10; + private static final double SHRINK_RATIO = 0.2d; + // enlarge the memory by 8 * chunkSize + private static final int EXPANSION_STEP_SIZE = 8; + + public DirectMemoryPool( + MemoryManager memoryManager, + int pageSize, + int maxOrder, + int pageShifts, + int chunkSize, + int initChunkNum, + int maxChunkNum) { + super(memoryManager, pageSize, maxOrder, pageShifts, chunkSize); + + for (int i = 0; i < initChunkNum; i++) { + qInit.add(newChunk(pageSize, maxOrder, pageShifts, chunkSize)); + } + this.currentChunkNum = initChunkNum; + this.maxChunkNum = maxChunkNum; + this.initChunkNum = initChunkNum; + statistics = + new FreeChunkStatistics( + memoryManager.config.getInteger(MemoryConfigKeys.MEMORY_TRIM_GAP_MINUTE)); + updateAllocateMemory(); + } - private final int maxChunkNum; - private final int initChunkNum; - private FreeChunkStatistics statistics; - private int count = 0; - private static final int MAX_COUNT = 10; - private static final double SHRINK_RATIO = 0.2d; - // enlarge the memory by 8 * chunkSize - private static final int EXPANSION_STEP_SIZE = 8; + @Override + MemoryMode getMemoryMode() { + return MemoryMode.OFF_HEAP; + } - public DirectMemoryPool(MemoryManager memoryManager, int pageSize, int maxOrder, int pageShifts, - int chunkSize, int initChunkNum, int maxChunkNum) { - super(memoryManager, pageSize, maxOrder, pageShifts, chunkSize); + @Override + void shrinkCapacity() { - for (int i = 0; i < initChunkNum; i++) { - qInit.add(newChunk(pageSize, maxOrder, pageShifts, chunkSize)); - } - this.currentChunkNum = initChunkNum; - this.maxChunkNum = maxChunkNum; - this.initChunkNum = initChunkNum; - statistics = new FreeChunkStatistics( - memoryManager.config.getInteger(MemoryConfigKeys.MEMORY_TRIM_GAP_MINUTE)); - updateAllocateMemory(); + if (this.currentChunkNum == this.initChunkNum || !statistics.isFull()) { + return; } - - @Override - MemoryMode getMemoryMode() { - return MemoryMode.OFF_HEAP; + int minFreeChunk = statistics.getMinFree(); + if (((1.0 * minFreeChunk) / this.currentChunkNum) < SHRINK_RATIO) { + return; } - - @Override - void shrinkCapacity() { - - if (this.currentChunkNum == this.initChunkNum || !statistics.isFull()) { - return; - } - int minFreeChunk = statistics.getMinFree(); - if (((1.0 * minFreeChunk) / this.currentChunkNum) < SHRINK_RATIO) { - return; - } - int needRemove = Math.min(minFreeChunk, this.currentChunkNum - this.initChunkNum); - int removed = 0; - synchronized (this) { - removed = removeChunk(qInit, needRemove); - if (removed < needRemove) { - removed += removeChunk(q000, needRemove - removed); - } - } - LOGGER.info("direct memory shrink capacity, need remove:{}, removed:{}, current " - + "chunkNum:{}, max chunkNum:{} ", needRemove, removed, currentChunkNum, maxChunkNum); - updateAllocateMemory(); - reloadMemoryStatics(getMemoryMode()); - statistics.clear(); + int needRemove = Math.min(minFreeChunk, this.currentChunkNum - this.initChunkNum); + int removed = 0; + synchronized (this) { + removed = removeChunk(qInit, needRemove); + if (removed < needRemove) { + removed += removeChunk(q000, needRemove - removed); + } } - - private int removeChunk(ChunkList chunkList, int needRemove) { - Chunk chunk = chunkList.getHead(); - if (chunk == null) { - return 0; - } - int removed = 0; - for (; ; ) { - if (chunk.isFree()) { - chunkList.remove(chunk); - chunkList.freeChunk.getAndAdd(-1); - chunk.destroy(); - currentChunkNum--; - removed++; - } - if (removed == needRemove) { - break; - } - chunk = chunk.next; - if (chunk == null) { - break; - } - } - return removed; + LOGGER.info( + "direct memory shrink capacity, need remove:{}, removed:{}, current " + + "chunkNum:{}, max chunkNum:{} ", + needRemove, + removed, + currentChunkNum, + maxChunkNum); + updateAllocateMemory(); + reloadMemoryStatics(getMemoryMode()); + statistics.clear(); + } + + private int removeChunk(ChunkList chunkList, int needRemove) { + Chunk chunk = chunkList.getHead(); + if (chunk == null) { + return 0; } - - @Override - protected void allocateBytebuf(ByteBuf buf, int normCapacity, MemoryGroup group) { - - if (group.allocate(normCapacity, getMemoryMode())) { - if (q050.allocate(buf, normCapacity) || q025.allocate(buf, normCapacity) - || q000.allocate(buf, normCapacity) || qInit.allocate(buf, normCapacity) - || q075.allocate(buf, normCapacity)) { - if (++count % MAX_COUNT == 0) { - statistics.update(qInit.freeChunkNum() + q000.freeChunkNum()); - count = 0; - } - return; - } else { - group.free(normCapacity, getMemoryMode()); - } - } - - if (!canExpandCapacity()) { - return; - } - - int expansionSize = getExpansionSize(); - - for (int i = 0; i < expansionSize; i++) { - qInit.add(newChunk(pageSize, maxOrder, pageShifts, chunkSize)); - } - - this.currentChunkNum += expansionSize; - statistics.clear(); - updateAllocateMemory(); - reloadMemoryStatics(getMemoryMode()); - - boolean allocateSuccess = false; - - if (group.allocate(normCapacity, getMemoryMode())) { - allocateSuccess = qInit.allocate(buf, normCapacity); - } - if (!allocateSuccess) { - group.free(normCapacity, getMemoryMode()); - } - - LOGGER.info("direct memory expand capacity, expand chunkNum:{}, current chunkNum:{}, max " - + "chunkNum:{}", expansionSize, currentChunkNum, maxChunkNum); + int removed = 0; + for (; ; ) { + if (chunk.isFree()) { + chunkList.remove(chunk); + chunkList.freeChunk.getAndAdd(-1); + chunk.destroy(); + currentChunkNum--; + removed++; + } + if (removed == needRemove) { + break; + } + chunk = chunk.next; + if (chunk == null) { + break; + } } - - @Override - boolean canExpandCapacity() { - return this.maxChunkNum > this.currentChunkNum; + return removed; + } + + @Override + protected void allocateBytebuf(ByteBuf buf, int normCapacity, MemoryGroup group) { + + if (group.allocate(normCapacity, getMemoryMode())) { + if (q050.allocate(buf, normCapacity) + || q025.allocate(buf, normCapacity) + || q000.allocate(buf, normCapacity) + || qInit.allocate(buf, normCapacity) + || q075.allocate(buf, normCapacity)) { + if (++count % MAX_COUNT == 0) { + statistics.update(qInit.freeChunkNum() + q000.freeChunkNum()); + count = 0; + } + return; + } else { + group.free(normCapacity, getMemoryMode()); + } } - private int getExpansionSize() { - int size = (maxChunkNum - initChunkNum) / EXPANSION_STEP_SIZE; - if (currentChunkNum + size > maxChunkNum) { - size = maxChunkNum - currentChunkNum; - } - return size > 0 ? size : 1; + if (!canExpandCapacity()) { + return; } - @Override - void destroyChunk(Chunk chunk) { + int expansionSize = getExpansionSize(); - if (DirectMemory.useDirectBufferNoCleaner()) { - DirectMemory.freeDirectNoCleaner(chunk.memory); - } else { - DirectMemory.freeDirectBuffer(chunk.memory); - } + for (int i = 0; i < expansionSize; i++) { + qInit.add(newChunk(pageSize, maxOrder, pageShifts, chunkSize)); } + this.currentChunkNum += expansionSize; + statistics.clear(); + updateAllocateMemory(); + reloadMemoryStatics(getMemoryMode()); - @Override - protected ByteBuf newByteBuf(int size) { - return new ByteBuf(); - } + boolean allocateSuccess = false; - @Override - protected Chunk newChunk(int pageSize, int maxOrder, int pageShifts, - int chunkSize) { - ByteBuffer buffer = allocateDirect(chunkSize); - return new Chunk<>(this, buffer, pageSize, maxOrder, pageShifts, chunkSize, 0); + if (group.allocate(normCapacity, getMemoryMode())) { + allocateSuccess = qInit.allocate(buf, normCapacity); } - - private static ByteBuffer allocateDirect(int capacity) { - return DirectMemory.useDirectBufferNoCleaner() ? DirectMemory.allocateDirectNoCleaner( - capacity) : ByteBuffer.allocateDirect(capacity); + if (!allocateSuccess) { + group.free(normCapacity, getMemoryMode()); } - void setStatistics(FreeChunkStatistics statistics) { - this.statistics = statistics; + LOGGER.info( + "direct memory expand capacity, expand chunkNum:{}, current chunkNum:{}, max " + + "chunkNum:{}", + expansionSize, + currentChunkNum, + maxChunkNum); + } + + @Override + boolean canExpandCapacity() { + return this.maxChunkNum > this.currentChunkNum; + } + + private int getExpansionSize() { + int size = (maxChunkNum - initChunkNum) / EXPANSION_STEP_SIZE; + if (currentChunkNum + size > maxChunkNum) { + size = maxChunkNum - currentChunkNum; } + return size > 0 ? size : 1; + } + @Override + void destroyChunk(Chunk chunk) { - protected String dump() { - return String.format("direct memory capacity statistics, full:%b, minFreeChunk:%d, " - + "currentChunkNum:%d, statistic:%s", statistics.isFull(), statistics.getMinFree(), - this.currentChunkNum, statistics.toString()); + if (DirectMemory.useDirectBufferNoCleaner()) { + DirectMemory.freeDirectNoCleaner(chunk.memory); + } else { + DirectMemory.freeDirectBuffer(chunk.memory); } + } + + @Override + protected ByteBuf newByteBuf(int size) { + return new ByteBuf(); + } + + @Override + protected Chunk newChunk(int pageSize, int maxOrder, int pageShifts, int chunkSize) { + ByteBuffer buffer = allocateDirect(chunkSize); + return new Chunk<>(this, buffer, pageSize, maxOrder, pageShifts, chunkSize, 0); + } + + private static ByteBuffer allocateDirect(int capacity) { + return DirectMemory.useDirectBufferNoCleaner() + ? DirectMemory.allocateDirectNoCleaner(capacity) + : ByteBuffer.allocateDirect(capacity); + } + + void setStatistics(FreeChunkStatistics statistics) { + this.statistics = statistics; + } + + protected String dump() { + return String.format( + "direct memory capacity statistics, full:%b, minFreeChunk:%d, " + + "currentChunkNum:%d, statistic:%s", + statistics.isFull(), statistics.getMinFree(), this.currentChunkNum, statistics.toString()); + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/ESegmentSize.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/ESegmentSize.java index aa2046a6f..060781858 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/ESegmentSize.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/ESegmentSize.java @@ -20,144 +20,83 @@ package org.apache.geaflow.memory; public enum ESegmentSize { - /** - * 16 bytes. - */ - S16(16, 4), - /** - * 32 bytes. - */ - S32(32, 5), - /** - * 64 bytes. - */ - S64(64, 6), - /** - * 128 bytes. - */ - S128(128, 7), - /** - * 256 bytes. - */ - S256(256, 8), - /** - * 512 bytes. - */ - S512(512, 9), - /** - * 1 KB. - */ - S1024(1024, 10), - /** - * 2 KB. - */ - S2048(2048, 11), - /** - * 4 KB. - */ - S4096(4096, 12), - /** - * 8 KB. - */ - S8192(8192, 13), - /** - * 16 KB. - */ - S16384(16384, 14), - /** - * 32 KB. - */ - S32768(32768, 15), - /** - * 64 KB. - */ - S65536(65536, 16), - /** - * 128 KB. - */ - S131072(131072, 17), - /** - * 256 KB. - */ - S262144(262144, 18), - /** - * 512 KB. - */ - S524288(524288, 19), - /** - * 1 MB. - */ - S1048576(1048576, 20), - /** - * 2 MB. - */ - S2097152(2097152, 21), - /** - * 4 MB. - */ - S4194304(4194304, 22), - /** - * 8 MB. - */ - S8388608(8388608, 23), - /** - * 16 MB. - */ - S16777216(16777216, 24); + /** 16 bytes. */ + S16(16, 4), + /** 32 bytes. */ + S32(32, 5), + /** 64 bytes. */ + S64(64, 6), + /** 128 bytes. */ + S128(128, 7), + /** 256 bytes. */ + S256(256, 8), + /** 512 bytes. */ + S512(512, 9), + /** 1 KB. */ + S1024(1024, 10), + /** 2 KB. */ + S2048(2048, 11), + /** 4 KB. */ + S4096(4096, 12), + /** 8 KB. */ + S8192(8192, 13), + /** 16 KB. */ + S16384(16384, 14), + /** 32 KB. */ + S32768(32768, 15), + /** 64 KB. */ + S65536(65536, 16), + /** 128 KB. */ + S131072(131072, 17), + /** 256 KB. */ + S262144(262144, 18), + /** 512 KB. */ + S524288(524288, 19), + /** 1 MB. */ + S1048576(1048576, 20), + /** 2 MB. */ + S2097152(2097152, 21), + /** 4 MB. */ + S4194304(4194304, 22), + /** 8 MB. */ + S8388608(8388608, 23), + /** 16 MB. */ + S16777216(16777216, 24); - public static ESegmentSize[] upValues = new ESegmentSize[]{ - S16, - S32, - S64, - S128, - S256, - S512, - S1024, - S2048, - S4096, - S8192, - S16384, - S32768, - S65536, - S131072, - S262144, - S524288, - S1048576, - S2097152, - S4194304, - S8388608, - S16777216 - }; + public static ESegmentSize[] upValues = + new ESegmentSize[] { + S16, S32, S64, S128, S256, S512, S1024, S2048, S4096, S8192, S16384, S32768, S65536, + S131072, S262144, S524288, S1048576, S2097152, S4194304, S8388608, S16777216 + }; - private static int baseSigPos = upValues[0].mostSigPos; + private static int baseSigPos = upValues[0].mostSigPos; - private final int size; - private final int mostSigPos; + private final int size; + private final int mostSigPos; - ESegmentSize(int size, int mostSigPos) { - this.size = size; - this.mostSigPos = mostSigPos; - } + ESegmentSize(int size, int mostSigPos) { + this.size = size; + this.mostSigPos = mostSigPos; + } - public static ESegmentSize valueOf(int len) { - int n = (int) (Math.log(len) / Math.log(2)); - return upValues[n - smallest().mostSigPos]; - } + public static ESegmentSize valueOf(int len) { + int n = (int) (Math.log(len) / Math.log(2)); + return upValues[n - smallest().mostSigPos]; + } - public static ESegmentSize largest() { - return upValues[upValues.length - 1]; - } + public static ESegmentSize largest() { + return upValues[upValues.length - 1]; + } - public static ESegmentSize smallest() { - return upValues[0]; - } + public static ESegmentSize smallest() { + return upValues[0]; + } - public int size() { - return this.size; - } - - public int index() { - return mostSigPos - baseSigPos; - } + public int size() { + return this.size; + } + public int index() { + return mostSigPos - baseSigPos; + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/FreeChunkStatistics.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/FreeChunkStatistics.java index 362dbed5f..565250ef2 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/FreeChunkStatistics.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/FreeChunkStatistics.java @@ -19,66 +19,77 @@ package org.apache.geaflow.memory; -import com.google.common.collect.EvictingQueue; import java.util.Queue; -public class FreeChunkStatistics { +import com.google.common.collect.EvictingQueue; - private Queue timeQueue; - private long lastTime; - private int free = Integer.MAX_VALUE; - private int queueSize; - private int diffTime = 60 * 1000; +public class FreeChunkStatistics { - public FreeChunkStatistics(int queueSize) { - timeQueue = EvictingQueue.create(queueSize); - this.queueSize = queueSize; - } + private Queue timeQueue; + private long lastTime; + private int free = Integer.MAX_VALUE; + private int queueSize; + private int diffTime = 60 * 1000; - public FreeChunkStatistics(int queueSize, int diffTime) { - this.queueSize = queueSize; - this.diffTime = diffTime; - timeQueue = EvictingQueue.create(queueSize); - } + public FreeChunkStatistics(int queueSize) { + timeQueue = EvictingQueue.create(queueSize); + this.queueSize = queueSize; + } - public void update(int free) { - long time = System.currentTimeMillis(); + public FreeChunkStatistics(int queueSize, int diffTime) { + this.queueSize = queueSize; + this.diffTime = diffTime; + timeQueue = EvictingQueue.create(queueSize); + } - if (lastTime == 0) { - lastTime = time; - } + public void update(int free) { + long time = System.currentTimeMillis(); - if (time - lastTime > diffTime) { - timeQueue.add(this.free); - lastTime = time; - this.free = Integer.MAX_VALUE; - } + if (lastTime == 0) { + lastTime = time; + } - this.free = Math.min(free, this.free); + if (time - lastTime > diffTime) { + timeQueue.add(this.free); + lastTime = time; + this.free = Integer.MAX_VALUE; } - public int getMinFree() { - int minFree = free; + this.free = Math.min(free, this.free); + } - for (int free : timeQueue) { - minFree = Math.min(minFree, free); - } - return minFree; - } + public int getMinFree() { + int minFree = free; - public boolean isFull() { - return timeQueue.size() + 1 >= queueSize; + for (int free : timeQueue) { + minFree = Math.min(minFree, free); } + return minFree; + } - public void clear() { - lastTime = 0; - free = Integer.MAX_VALUE; - timeQueue.clear(); - } + public boolean isFull() { + return timeQueue.size() + 1 >= queueSize; + } - @Override - public String toString() { - return "FreeChunkStatistics{" + "timeQueue=" + timeQueue + ", lastTime=" + lastTime - + ", free=" + free + ", queueSize=" + queueSize + ", diffTime=" + diffTime + '}'; - } + public void clear() { + lastTime = 0; + free = Integer.MAX_VALUE; + timeQueue.clear(); + } + + @Override + public String toString() { + return "FreeChunkStatistics{" + + "timeQueue=" + + timeQueue + + ", lastTime=" + + lastTime + + ", free=" + + free + + ", queueSize=" + + queueSize + + ", diffTime=" + + diffTime + + '}'; + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/HeapMemoryPool.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/HeapMemoryPool.java index 4195eb0b3..ccea475e9 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/HeapMemoryPool.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/HeapMemoryPool.java @@ -21,67 +21,66 @@ public class HeapMemoryPool extends AbstractMemoryPool { - public HeapMemoryPool(MemoryManager memoryManager, int pageSize, int maxOrder, int pageShifts, - int chunkSize) { - super(memoryManager, pageSize, maxOrder, pageShifts, chunkSize); - } - - @Override - protected void allocateBytebuf(ByteBuf buf, int normCapacity, MemoryGroup group) { + public HeapMemoryPool( + MemoryManager memoryManager, int pageSize, int maxOrder, int pageShifts, int chunkSize) { + super(memoryManager, pageSize, maxOrder, pageShifts, chunkSize); + } - group.allocate(normCapacity, getMemoryMode()); - if (q050.allocate(buf, normCapacity) || q025.allocate(buf, normCapacity) || q000 - .allocate(buf, normCapacity) || qInit.allocate(buf, normCapacity) || q075 - .allocate(buf, normCapacity)) { - return; - } + @Override + protected void allocateBytebuf(ByteBuf buf, int normCapacity, MemoryGroup group) { - Chunk c = newChunk(pageSize, maxOrder, pageShifts, chunkSize); - long handle = c.allocate(normCapacity); - assert handle > 0; - c.initBuf(buf, handle, normCapacity); - qInit.add(c); - currentChunkNum++; - updateAllocateMemory(); - reloadMemoryStatics(getMemoryMode()); + group.allocate(normCapacity, getMemoryMode()); + if (q050.allocate(buf, normCapacity) + || q025.allocate(buf, normCapacity) + || q000.allocate(buf, normCapacity) + || qInit.allocate(buf, normCapacity) + || q075.allocate(buf, normCapacity)) { + return; } - @Override - MemoryMode getMemoryMode() { - return MemoryMode.ON_HEAP; - } + Chunk c = newChunk(pageSize, maxOrder, pageShifts, chunkSize); + long handle = c.allocate(normCapacity); + assert handle > 0; + c.initBuf(buf, handle, normCapacity); + qInit.add(c); + currentChunkNum++; + updateAllocateMemory(); + reloadMemoryStatics(getMemoryMode()); + } - @Override - boolean canExpandCapacity() { - return true; - } + @Override + MemoryMode getMemoryMode() { + return MemoryMode.ON_HEAP; + } - @Override - void shrinkCapacity() { + @Override + boolean canExpandCapacity() { + return true; + } - } + @Override + void shrinkCapacity() {} - @Override - protected String dump() { - return null; - } + @Override + protected String dump() { + return null; + } - @Override - void destroyChunk(Chunk chunk) { - //wait gc - currentChunkNum--; - updateAllocateMemory(); - reloadMemoryStatics(getMemoryMode()); - } + @Override + void destroyChunk(Chunk chunk) { + // wait gc + currentChunkNum--; + updateAllocateMemory(); + reloadMemoryStatics(getMemoryMode()); + } - @Override - protected ByteBuf newByteBuf(int size) { - return new ByteBuf<>(); - } + @Override + protected ByteBuf newByteBuf(int size) { + return new ByteBuf<>(); + } - @Override - protected Chunk newChunk(int pageSize, int maxOrder, int pageShifts, int chunkSize) { - return new Chunk<>(this, new byte[chunkSize], pageSize, maxOrder, pageShifts, - chunkSize, 0); - } + @Override + protected Chunk newChunk(int pageSize, int maxOrder, int pageShifts, int chunkSize) { + return new Chunk<>(this, new byte[chunkSize], pageSize, maxOrder, pageShifts, chunkSize, 0); + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryGroup.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryGroup.java index 60d3e2307..904b2a29c 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryGroup.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryGroup.java @@ -19,286 +19,294 @@ package org.apache.geaflow.memory; -import com.google.common.base.Preconditions; -import com.google.common.collect.Maps; import java.io.Serializable; import java.util.Map; import java.util.Objects; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; + import org.apache.geaflow.memory.metric.MemoryGroupMetric; +import com.google.common.base.Preconditions; +import com.google.common.collect.Maps; + public final class MemoryGroup implements MemoryGroupMetric, Serializable { - private final String name; - private final AtomicInteger threads = new AtomicInteger(0); - private final int spanSize; - private final AtomicLong byteBufCount = new AtomicLong(0); - private double ratio; - private final Map counterMap; - - MemoryGroup(String name, int spanSize) { - this.name = name; - this.spanSize = spanSize; - counterMap = Maps.newHashMap(); - counterMap.put(MemoryMode.OFF_HEAP, new OffHeapCounter()); - counterMap.put(MemoryMode.ON_HEAP, new OnHeapCounter()); + private final String name; + private final AtomicInteger threads = new AtomicInteger(0); + private final int spanSize; + private final AtomicLong byteBufCount = new AtomicLong(0); + private double ratio; + private final Map counterMap; + + MemoryGroup(String name, int spanSize) { + this.name = name; + this.spanSize = spanSize; + counterMap = Maps.newHashMap(); + counterMap.put(MemoryMode.OFF_HEAP, new OffHeapCounter()); + counterMap.put(MemoryMode.ON_HEAP, new OnHeapCounter()); + } + + public String getName() { + return name; + } + + public boolean allocate(long bytes, MemoryMode memoryMode) { + return counterMap.get(memoryMode).allocate(bytes); + } + + public boolean free(long bytes, MemoryMode memoryMode) { + return counterMap.get(memoryMode).free(bytes); + } + + public void updateByteBufCount(long count) { + byteBufCount.addAndGet(count); + } + + public void setSharedFreeBytes(AtomicLong sharedFreeBytes, MemoryMode memoryMode) { + counterMap.get(memoryMode).setSharedFreeBytes(sharedFreeBytes); + } + + @Override + public long usedBytes() { + return counterMap.values().stream().mapToLong(Counter::usedBytes).sum(); + } + + @Override + public long usedOnHeapBytes() { + return counterMap.get(MemoryMode.ON_HEAP).usedBytes(); + } + + @Override + public long usedOffHeapBytes() { + return counterMap.get(MemoryMode.OFF_HEAP).usedBytes(); + } + + @Override + public long baseBytes() { + return counterMap.values().stream().mapToLong(Counter::baseBytes).sum(); + } + + public void increaseThreads() { + threads.incrementAndGet(); + } + + public void decreaseThreads() { + threads.decrementAndGet(); + } + + public int getThreads() { + return threads.get(); + } + + public int getSpanSize() { + return spanSize; + } + + public void setRatio(String decimalRatio) { + int dr = "*".equals(decimalRatio) ? 0 : Integer.parseInt(decimalRatio); + Preconditions.checkArgument(dr <= 100, "group ratio expect[0-100]"); + this.ratio = 0.01 * dr; + } + + public double getRatio() { + return ratio; + } + + public void setBaseBytes(long totalMemory, int chunkSize, MemoryMode memoryMode) { + long baseBytes = normalize((long) (totalMemory * ratio), chunkSize); + counterMap.get(memoryMode).setBaseBytes(baseBytes); + } + + private long normalize(long size, int chunkSize) { + return (size & (long) (chunkSize - 1)) != 0 + ? ((size >>> (int) (Math.log(chunkSize) / Math.log(2.0))) + 1) * chunkSize + : size; + } + + @Override + public double usage() { + if (baseBytes() <= 0) { + return 0.0; } - - public String getName() { - return name; + return (double) usedBytes() / baseBytes(); + } + + @Override + public long byteBufNum() { + return byteBufCount.get(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - public boolean allocate(long bytes, MemoryMode memoryMode) { - return counterMap.get(memoryMode).allocate(bytes); + if (o == null || getClass() != o.getClass()) { + return false; } + MemoryGroup that = (MemoryGroup) o; + return this.name.equals(that.name); + } - public boolean free(long bytes, MemoryMode memoryMode) { - return counterMap.get(memoryMode).free(bytes); - } + @Override + public int hashCode() { + return Objects.hash(name); + } - public void updateByteBufCount(long count) { - byteBufCount.addAndGet(count); - } + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); - public void setSharedFreeBytes(AtomicLong sharedFreeBytes, MemoryMode memoryMode) { - counterMap.get(memoryMode).setSharedFreeBytes(sharedFreeBytes); - } + sb.append("[") + .append(name) + .append(",") + .append(baseBytes()) + .append(",") + .append(usedBytes()) + .append(",") + .append(usage()) + .append(",") + .append(byteBufNum()) + .append(",") + .append(threads.get()) + .append("]"); - @Override - public long usedBytes() { - return counterMap.values().stream().mapToLong(Counter::usedBytes).sum(); - } + return sb.toString(); + } - @Override - public long usedOnHeapBytes() { - return counterMap.get(MemoryMode.ON_HEAP).usedBytes(); - } + public void reset() { + byteBufCount.set(0); + threads.set(0); + counterMap.forEach((k, v) -> v.reset()); + } - @Override - public long usedOffHeapBytes() { - return counterMap.get(MemoryMode.OFF_HEAP).usedBytes(); - } + interface Counter { - @Override - public long baseBytes() { - return counterMap.values().stream().mapToLong(Counter::baseBytes).sum(); - } + boolean allocate(long bytes); - public void increaseThreads() { - threads.incrementAndGet(); - } + boolean free(long bytes); - public void decreaseThreads() { - threads.decrementAndGet(); - } + long usedBytes(); - public int getThreads() { - return threads.get(); - } + void setSharedFreeBytes(AtomicLong sharedFreeBytes); - public int getSpanSize() { - return spanSize; - } + void setBaseBytes(long bytes); - public void setRatio(String decimalRatio) { - int dr = "*".equals(decimalRatio) ? 0 : Integer.parseInt(decimalRatio); - Preconditions.checkArgument(dr <= 100, "group ratio expect[0-100]"); - this.ratio = 0.01 * dr; - } + long baseBytes(); - public double getRatio() { - return ratio; - } + void reset(); + } - public void setBaseBytes(long totalMemory, int chunkSize, MemoryMode memoryMode) { - long baseBytes = normalize((long) (totalMemory * ratio), chunkSize); - counterMap.get(memoryMode).setBaseBytes(baseBytes); - } + class OffHeapCounter implements Counter { - private long normalize(long size, int chunkSize) { - return (size & (long) (chunkSize - 1)) != 0 - ? ((size >>> (int) (Math.log(chunkSize) / Math.log(2.0))) + 1) * chunkSize : size; - } + private AtomicLong usedSharedBytes = new AtomicLong(0); + private AtomicLong baseFreeBytes = new AtomicLong(0); + private AtomicLong sharedFreeBytes; + private long baseBytes; @Override - public double usage() { - if (baseBytes() <= 0) { - return 0.0; + public boolean allocate(long bytes) { + if (baseBytes > 0) { + if (baseFreeBytes.addAndGet(-1 * bytes) >= 0) { + return true; + } else { + baseFreeBytes.getAndAdd(bytes); } - return (double) usedBytes() / baseBytes(); + } + + if (sharedFreeBytes.get() < bytes) { + return false; + } + + if (sharedFreeBytes.addAndGet(-1 * bytes) < 0) { + sharedFreeBytes.getAndAdd(bytes); + return false; + } + usedSharedBytes.getAndAdd(bytes); + return true; } @Override - public long byteBufNum() { - return byteBufCount.get(); + public boolean free(long bytes) { + if (baseBytes > 0) { + if (baseFreeBytes.addAndGet(bytes) <= baseBytes) { + return true; + } else { + baseFreeBytes.getAndAdd(-1 * bytes); + } + } + + sharedFreeBytes.getAndAdd(bytes); + usedSharedBytes.getAndAdd(-1 * bytes); + return true; } @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - MemoryGroup that = (MemoryGroup) o; - return this.name.equals(that.name); + public long usedBytes() { + return usedSharedBytes.get() + baseBytes - baseFreeBytes.get(); } @Override - public int hashCode() { - return Objects.hash(name); + public void setSharedFreeBytes(AtomicLong sharedFreeBytes) { + this.sharedFreeBytes = sharedFreeBytes; } @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - - sb.append("[").append(name).append(",").append(baseBytes()).append(",").append(usedBytes()) - .append(",").append(usage()).append(",").append(byteBufNum()).append(",") - .append(threads.get()).append("]"); + public void setBaseBytes(long bytes) { + long diff = bytes - this.baseBytes; + this.baseBytes = bytes; + this.baseFreeBytes.getAndAdd(diff); + } - return sb.toString(); + @Override + public long baseBytes() { + return baseBytes; } + @Override public void reset() { - byteBufCount.set(0); - threads.set(0); - counterMap.forEach((k, v) -> v.reset()); + usedSharedBytes.set(0); + baseFreeBytes.set(0); + baseBytes = 0; } + } - interface Counter { - - boolean allocate(long bytes); - - boolean free(long bytes); - - long usedBytes(); + static class OnHeapCounter implements Counter { - void setSharedFreeBytes(AtomicLong sharedFreeBytes); + private AtomicLong usedBytes = new AtomicLong(0); - void setBaseBytes(long bytes); - - long baseBytes(); - - void reset(); + @Override + public boolean allocate(long bytes) { + usedBytes.getAndAdd(bytes); + return true; } - class OffHeapCounter implements Counter { - - private AtomicLong usedSharedBytes = new AtomicLong(0); - private AtomicLong baseFreeBytes = new AtomicLong(0); - private AtomicLong sharedFreeBytes; - private long baseBytes; - - @Override - public boolean allocate(long bytes) { - if (baseBytes > 0) { - if (baseFreeBytes.addAndGet(-1 * bytes) >= 0) { - return true; - } else { - baseFreeBytes.getAndAdd(bytes); - } - } - - if (sharedFreeBytes.get() < bytes) { - return false; - } - - if (sharedFreeBytes.addAndGet(-1 * bytes) < 0) { - sharedFreeBytes.getAndAdd(bytes); - return false; - } - usedSharedBytes.getAndAdd(bytes); - return true; - } - - @Override - public boolean free(long bytes) { - if (baseBytes > 0) { - if (baseFreeBytes.addAndGet(bytes) <= baseBytes) { - return true; - } else { - baseFreeBytes.getAndAdd(-1 * bytes); - } - } - - sharedFreeBytes.getAndAdd(bytes); - usedSharedBytes.getAndAdd(-1 * bytes); - return true; - } - - @Override - public long usedBytes() { - return usedSharedBytes.get() + baseBytes - baseFreeBytes.get(); - } - - @Override - public void setSharedFreeBytes(AtomicLong sharedFreeBytes) { - this.sharedFreeBytes = sharedFreeBytes; - } - - @Override - public void setBaseBytes(long bytes) { - long diff = bytes - this.baseBytes; - this.baseBytes = bytes; - this.baseFreeBytes.getAndAdd(diff); - } - - @Override - public long baseBytes() { - return baseBytes; - } - - @Override - public void reset() { - usedSharedBytes.set(0); - baseFreeBytes.set(0); - baseBytes = 0; - } + @Override + public boolean free(long bytes) { + usedBytes.getAndAdd(-1 * bytes); + return true; } - static class OnHeapCounter implements Counter { - - private AtomicLong usedBytes = new AtomicLong(0); - - @Override - public boolean allocate(long bytes) { - usedBytes.getAndAdd(bytes); - return true; - } - - @Override - public boolean free(long bytes) { - usedBytes.getAndAdd(-1 * bytes); - return true; - } - - @Override - public long usedBytes() { - return usedBytes.get(); - } - - @Override - public void setSharedFreeBytes(AtomicLong sharedFreeBytes) { - - } - - @Override - public void setBaseBytes(long bytes) { + @Override + public long usedBytes() { + return usedBytes.get(); + } - } + @Override + public void setSharedFreeBytes(AtomicLong sharedFreeBytes) {} - @Override - public long baseBytes() { - return 0; - } + @Override + public void setBaseBytes(long bytes) {} - @Override - public void reset() { - usedBytes.set(0); - } + @Override + public long baseBytes() { + return 0; } + @Override + public void reset() { + usedBytes.set(0); + } + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryGroupManger.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryGroupManger.java index e7e731ba2..1b6fb5f09 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryGroupManger.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryGroupManger.java @@ -19,105 +19,112 @@ package org.apache.geaflow.memory; -import com.google.common.base.Preconditions; -import com.google.common.collect.Lists; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicLong; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.utils.MemoryUtils; import org.apache.geaflow.memory.config.MemoryConfigKeys; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public final class MemoryGroupManger { +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; - private static final Logger LOGGER = LoggerFactory.getLogger(MemoryGroupManger.class); +public final class MemoryGroupManger { - public static MemoryGroup DEFAULT = new MemoryGroup("default", -1); + private static final Logger LOGGER = LoggerFactory.getLogger(MemoryGroupManger.class); - public static MemoryGroup SHUFFLE = new MemoryGroup("shuffle", (int) (16 * MemoryUtils.KB)); + public static MemoryGroup DEFAULT = new MemoryGroup("default", -1); - public static MemoryGroup STATE = new MemoryGroup("state", -1); + public static MemoryGroup SHUFFLE = new MemoryGroup("shuffle", (int) (16 * MemoryUtils.KB)); - private static final Map GROUPS = new LinkedHashMap<>(); - private static final AtomicLong SHARED_OFF_HEAP_FREE_BYTES = new AtomicLong(0); - private static long totalOffHeapSharedMemory; + public static MemoryGroup STATE = new MemoryGroup("state", -1); - private static MemoryGroupManger manger; + private static final Map GROUPS = new LinkedHashMap<>(); + private static final AtomicLong SHARED_OFF_HEAP_FREE_BYTES = new AtomicLong(0); + private static long totalOffHeapSharedMemory; - public static synchronized MemoryGroupManger getInstance() { - if (manger == null) { - manger = new MemoryGroupManger(); + private static MemoryGroupManger manger; - DEFAULT.setSharedFreeBytes(SHARED_OFF_HEAP_FREE_BYTES, MemoryMode.OFF_HEAP); - SHUFFLE.setSharedFreeBytes(SHARED_OFF_HEAP_FREE_BYTES, MemoryMode.OFF_HEAP); - STATE.setSharedFreeBytes(SHARED_OFF_HEAP_FREE_BYTES, MemoryMode.OFF_HEAP); + public static synchronized MemoryGroupManger getInstance() { + if (manger == null) { + manger = new MemoryGroupManger(); - // 顺序和ratio一致 - register(SHUFFLE); - register(STATE); - register(DEFAULT); - } - return manger; - } + DEFAULT.setSharedFreeBytes(SHARED_OFF_HEAP_FREE_BYTES, MemoryMode.OFF_HEAP); + SHUFFLE.setSharedFreeBytes(SHARED_OFF_HEAP_FREE_BYTES, MemoryMode.OFF_HEAP); + STATE.setSharedFreeBytes(SHARED_OFF_HEAP_FREE_BYTES, MemoryMode.OFF_HEAP); - void load(Configuration config) { - String groupRatios = config.getString(MemoryConfigKeys.MEMORY_GROUP_RATIO); - String[] ratios = groupRatios.split(":"); - Preconditions.checkArgument(ratios.length == GROUPS.size(), - MemoryConfigKeys.MEMORY_GROUP_RATIO.getKey() + " group ratio is not equals with group size"); - List groups = memoryGroups(); - for (int i = 0; i < groups.size(); i++) { - groups.get(i).setRatio(ratios[i]); - } - LOGGER.info("MemoryGroup ratio : {}", groupRatios); + // 顺序和ratio一致 + register(SHUFFLE); + register(STATE); + register(DEFAULT); } - - synchronized void resetMemory(long memorySize, int chunkSize, MemoryMode memoryMode) { - if (memoryMode != MemoryMode.OFF_HEAP) { - return; - } - - List groupList = memoryGroups(); - groupList.forEach(e -> e.setBaseBytes(memorySize, chunkSize, memoryMode)); - - long oldShared = totalOffHeapSharedMemory; - totalOffHeapSharedMemory = - memorySize - groupList.stream().mapToLong(e -> e.baseBytes()).sum(); - SHARED_OFF_HEAP_FREE_BYTES.getAndAdd(totalOffHeapSharedMemory - oldShared); - - LOGGER.info("[default:{}, shuffle:{}, state:{}, totalShared:{}, sharedFreeBytes:{}]", - DEFAULT.baseBytes(), SHUFFLE.baseBytes(), STATE.baseBytes(), totalOffHeapSharedMemory, - SHARED_OFF_HEAP_FREE_BYTES.get()); + return manger; + } + + void load(Configuration config) { + String groupRatios = config.getString(MemoryConfigKeys.MEMORY_GROUP_RATIO); + String[] ratios = groupRatios.split(":"); + Preconditions.checkArgument( + ratios.length == GROUPS.size(), + MemoryConfigKeys.MEMORY_GROUP_RATIO.getKey() + + " group ratio is not equals with group size"); + List groups = memoryGroups(); + for (int i = 0; i < groups.size(); i++) { + groups.get(i).setRatio(ratios[i]); } + LOGGER.info("MemoryGroup ratio : {}", groupRatios); + } - private static void register(MemoryGroup group) { - GROUPS.put(group.getName(), group); + synchronized void resetMemory(long memorySize, int chunkSize, MemoryMode memoryMode) { + if (memoryMode != MemoryMode.OFF_HEAP) { + return; } - private static String dump() { - StringBuilder sb = new StringBuilder(); - GROUPS.forEach((k, v) -> { - sb.append(v.toString()).append("\n"); + List groupList = memoryGroups(); + groupList.forEach(e -> e.setBaseBytes(memorySize, chunkSize, memoryMode)); + + long oldShared = totalOffHeapSharedMemory; + totalOffHeapSharedMemory = memorySize - groupList.stream().mapToLong(e -> e.baseBytes()).sum(); + SHARED_OFF_HEAP_FREE_BYTES.getAndAdd(totalOffHeapSharedMemory - oldShared); + + LOGGER.info( + "[default:{}, shuffle:{}, state:{}, totalShared:{}, sharedFreeBytes:{}]", + DEFAULT.baseBytes(), + SHUFFLE.baseBytes(), + STATE.baseBytes(), + totalOffHeapSharedMemory, + SHARED_OFF_HEAP_FREE_BYTES.get()); + } + + private static void register(MemoryGroup group) { + GROUPS.put(group.getName(), group); + } + + private static String dump() { + StringBuilder sb = new StringBuilder(); + GROUPS.forEach( + (k, v) -> { + sb.append(v.toString()).append("\n"); }); - return sb.toString(); - } - - void clear() { - GROUPS.forEach((k, v) -> v.reset()); - SHARED_OFF_HEAP_FREE_BYTES.set(0); - totalOffHeapSharedMemory = 0; - manger = null; - } - - public List memoryGroups() { - return Lists.newArrayList(GROUPS.values()); - } - - public long getCurrentSharedFreeBytes() { - return SHARED_OFF_HEAP_FREE_BYTES.get(); - } - + return sb.toString(); + } + + void clear() { + GROUPS.forEach((k, v) -> v.reset()); + SHARED_OFF_HEAP_FREE_BYTES.set(0); + totalOffHeapSharedMemory = 0; + manger = null; + } + + public List memoryGroups() { + return Lists.newArrayList(GROUPS.values()); + } + + public long getCurrentSharedFreeBytes() { + return SHARED_OFF_HEAP_FREE_BYTES.get(); + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryManager.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryManager.java index 4a85f0378..3c83ee2b1 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryManager.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryManager.java @@ -19,9 +19,9 @@ package org.apache.geaflow.memory; -import com.google.common.base.Preconditions; import java.io.Serializable; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.utils.MemoryUtils; import org.apache.geaflow.memory.config.MemoryConfigKeys; @@ -29,356 +29,381 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public final class MemoryManager implements Serializable { +import com.google.common.base.Preconditions; - private static final Logger LOGGER = LoggerFactory.getLogger(MemoryManager.class); +public final class MemoryManager implements Serializable { - public static final int MIN_PAGE_SIZE = 4 * 1024; - private static final int MAX_CHUNK_SIZE = (int) (((long) Integer.MAX_VALUE + 1) / 2); + private static final Logger LOGGER = LoggerFactory.getLogger(MemoryManager.class); - private AbstractMemoryPool[] pools; + public static final int MIN_PAGE_SIZE = 4 * 1024; + private static final int MAX_CHUNK_SIZE = (int) (((long) Integer.MAX_VALUE + 1) / 2); - private boolean autoAdaptEnable; - private AbstractMemoryPool[] adaptivePools; + private AbstractMemoryPool[] pools; - private static MemoryManager memoryManager; - private final BaseMemoryGroupThreadLocal threadLocal; - private final int chunkSize; - private final long maxMemory; + private boolean autoAdaptEnable; + private AbstractMemoryPool[] adaptivePools; - protected Configuration config; + private static MemoryManager memoryManager; + private final BaseMemoryGroupThreadLocal threadLocal; + private final int chunkSize; + private final long maxMemory; - public static synchronized MemoryManager build(Configuration config) { - if (memoryManager == null) { - memoryManager = new MemoryManager(config); - } - return memoryManager; - } + protected Configuration config; - public static MemoryManager getInstance() { - return memoryManager; + public static synchronized MemoryManager build(Configuration config) { + if (memoryManager == null) { + memoryManager = new MemoryManager(config); } + return memoryManager; + } - private MemoryManager(Configuration config) { + public static MemoryManager getInstance() { + return memoryManager; + } - this.config = config; + private MemoryManager(Configuration config) { - this.autoAdaptEnable = config.getBoolean(MemoryConfigKeys.MEMORY_AUTO_ADAPT_ENABLE); + this.config = config; - long offHeapSize = config.getLong(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB) * MemoryUtils.MB; - long heapSize = config.getLong(MemoryConfigKeys.ON_HEAP_MEMORY_SIZE_MB) * MemoryUtils.MB; + this.autoAdaptEnable = config.getBoolean(MemoryConfigKeys.MEMORY_AUTO_ADAPT_ENABLE); - int pageSize = config.getInteger(MemoryConfigKeys.MEMORY_PAGE_SIZE); - int maxOrder = config.getInteger(MemoryConfigKeys.MEMORY_MAX_ORDER); - int pageShifts = validateAndCalculatePageShifts(pageSize); + long offHeapSize = config.getLong(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB) * MemoryUtils.MB; + long heapSize = config.getLong(MemoryConfigKeys.ON_HEAP_MEMORY_SIZE_MB) * MemoryUtils.MB; - if (offHeapSize == 0 && heapSize == 0) { - offHeapSize = (long) (MemoryConfigKeys.JVM_MAX_DIRECT_MEMORY * 0.3); - } - - chunkSize = validateAndCalculateChunkSize(pageSize, maxOrder); - int poolSize = config.getInteger(MemoryConfigKeys.MEMORY_POOL_SIZE); - int initChunkNum = 0; - int maxChunkNum = 0; - long maxOffHeapSize = config.getLong(MemoryConfigKeys.MAX_DIRECT_MEMORY_SIZE); - - Preconditions.checkArgument(offHeapSize <= maxOffHeapSize, - "OffHeapSize:%s is greater than maxOffHeapSize:%s", offHeapSize, maxOffHeapSize); - - if (offHeapSize > 0) { - poolSize = poolSize > 0 ? poolSize : caculatePoolSize(offHeapSize, pageSize, maxOrder); - pools = new AbstractMemoryPool[poolSize]; - initChunkNum = caculateChunkNum(poolSize, chunkSize, offHeapSize); - maxChunkNum = caculateChunkNum(poolSize, chunkSize, maxOffHeapSize); - maxChunkNum = Math.max(maxChunkNum, initChunkNum); - - for (int i = 0; i < poolSize; i++) { - pools[i] = new DirectMemoryPool(this, pageSize, maxOrder, pageShifts, chunkSize, - initChunkNum, maxChunkNum); - } - - if (autoAdaptEnable) { - adaptivePools = new AbstractMemoryPool[poolSize]; - for (int i = 0; i < poolSize; i++) { - adaptivePools[i] = new HeapMemoryPool(this, pageSize, maxOrder, pageShifts, chunkSize); - } - } - - MemoryGroupManger.getInstance().load(config); - MemoryGroupManger.getInstance().resetMemory((long) poolSize * initChunkNum * chunkSize, - chunkSize, MemoryMode.OFF_HEAP); - } + int pageSize = config.getInteger(MemoryConfigKeys.MEMORY_PAGE_SIZE); + int maxOrder = config.getInteger(MemoryConfigKeys.MEMORY_MAX_ORDER); + int pageShifts = validateAndCalculatePageShifts(pageSize); - if (heapSize > 0) { - poolSize = poolSize > 0 ? poolSize : caculatePoolSize(heapSize, pageSize, maxOrder); - pools = new AbstractMemoryPool[poolSize]; - for (int i = 0; i < poolSize; i++) { - pools[i] = new HeapMemoryPool(this, pageSize, maxOrder, pageShifts, chunkSize); - } - } - - this.maxMemory = (long) chunkSize * poolSize * maxChunkNum; - - LOGGER.info("MemoryManager init, offHeapSize:{},maxOffHeapSize:{},heapSize:{}," - + "pageSize:{},maxOrder:{},pageShifts:{},chunkSize:{},poolSize:{},initChunkNum:{}," - + "maxChunkNum:{}", offHeapSize, this.maxMemory, heapSize, pageSize, maxOrder, - pageShifts, chunkSize, poolSize, initChunkNum, maxChunkNum); - - threadLocal = new BaseMemoryGroupThreadLocal() { - @Override - protected synchronized ThreadLocalCache initialValue(MemoryGroup group) { - int index = leastUsedPool(pools, group); - AbstractMemoryPool localPool = pools[index]; - // get one adaptive pool for the local pool - AbstractMemoryPool adaptivePool = - (autoAdaptEnable && adaptivePools != null) ? adaptivePools[index] : null; - return new ThreadLocalCache(pools, localPool, adaptivePool, group); - } - - @Override - protected void notifyRemove(ThreadLocalCache threadLocalCache) { - threadLocalCache.free(); - } - }; + if (offHeapSize == 0 && heapSize == 0) { + offHeapSize = (long) (MemoryConfigKeys.JVM_MAX_DIRECT_MEMORY * 0.3); } - private int leastUsedPool(AbstractMemoryPool[] pools, MemoryGroup group) { - if (pools == null || pools.length == 0) { - throw new IllegalArgumentException("pool size is empty"); + chunkSize = validateAndCalculateChunkSize(pageSize, maxOrder); + int poolSize = config.getInteger(MemoryConfigKeys.MEMORY_POOL_SIZE); + int initChunkNum = 0; + int maxChunkNum = 0; + long maxOffHeapSize = config.getLong(MemoryConfigKeys.MAX_DIRECT_MEMORY_SIZE); + + Preconditions.checkArgument( + offHeapSize <= maxOffHeapSize, + "OffHeapSize:%s is greater than maxOffHeapSize:%s", + offHeapSize, + maxOffHeapSize); + + if (offHeapSize > 0) { + poolSize = poolSize > 0 ? poolSize : caculatePoolSize(offHeapSize, pageSize, maxOrder); + pools = new AbstractMemoryPool[poolSize]; + initChunkNum = caculateChunkNum(poolSize, chunkSize, offHeapSize); + maxChunkNum = caculateChunkNum(poolSize, chunkSize, maxOffHeapSize); + maxChunkNum = Math.max(maxChunkNum, initChunkNum); + + for (int i = 0; i < poolSize; i++) { + pools[i] = + new DirectMemoryPool( + this, pageSize, maxOrder, pageShifts, chunkSize, initChunkNum, maxChunkNum); + } + + if (autoAdaptEnable) { + adaptivePools = new AbstractMemoryPool[poolSize]; + for (int i = 0; i < poolSize; i++) { + adaptivePools[i] = new HeapMemoryPool(this, pageSize, maxOrder, pageShifts, chunkSize); } + } - int index = 0; - for (int i = 1; i < pools.length; i++) { - AbstractMemoryPool pool = pools[i]; - if (pool.groupThreads(group).intValue() < pools[index].groupThreads(group).intValue()) { - index = i; - } - } - return index; + MemoryGroupManger.getInstance().load(config); + MemoryGroupManger.getInstance() + .resetMemory((long) poolSize * initChunkNum * chunkSize, chunkSize, MemoryMode.OFF_HEAP); } - private static int validateAndCalculateChunkSize(int pageSize, int maxOrder) { - if (maxOrder > 14) { - throw new IllegalArgumentException("maxOrder: " + maxOrder + " (expected: 0-14)"); - } - - // Ensure the resulting chunkSize does not overflow. - int chunkSize = pageSize; - for (int i = maxOrder; i > 0; i--) { - if (chunkSize > MAX_CHUNK_SIZE / 2) { - throw new IllegalArgumentException(String.format( - "pageSize (%d) << maxOrder (%d) must not exceed %d", pageSize, maxOrder, MAX_CHUNK_SIZE)); - } - chunkSize <<= 1; - } - return chunkSize; + if (heapSize > 0) { + poolSize = poolSize > 0 ? poolSize : caculatePoolSize(heapSize, pageSize, maxOrder); + pools = new AbstractMemoryPool[poolSize]; + for (int i = 0; i < poolSize; i++) { + pools[i] = new HeapMemoryPool(this, pageSize, maxOrder, pageShifts, chunkSize); + } } - static int validateAndCalculatePageShifts(int pageSize) { - if (pageSize < MIN_PAGE_SIZE) { - throw new IllegalArgumentException("pageSize: " + pageSize + " (expected: " + MIN_PAGE_SIZE + ")"); - } - - if ((pageSize & pageSize - 1) != 0) { - throw new IllegalArgumentException("pageSize: " + pageSize + " (expected: power of 2)"); - } + this.maxMemory = (long) chunkSize * poolSize * maxChunkNum; + + LOGGER.info( + "MemoryManager init, offHeapSize:{},maxOffHeapSize:{},heapSize:{}," + + "pageSize:{},maxOrder:{},pageShifts:{},chunkSize:{},poolSize:{},initChunkNum:{}," + + "maxChunkNum:{}", + offHeapSize, + this.maxMemory, + heapSize, + pageSize, + maxOrder, + pageShifts, + chunkSize, + poolSize, + initChunkNum, + maxChunkNum); + + threadLocal = + new BaseMemoryGroupThreadLocal() { + @Override + protected synchronized ThreadLocalCache initialValue(MemoryGroup group) { + int index = leastUsedPool(pools, group); + AbstractMemoryPool localPool = pools[index]; + // get one adaptive pool for the local pool + AbstractMemoryPool adaptivePool = + (autoAdaptEnable && adaptivePools != null) ? adaptivePools[index] : null; + return new ThreadLocalCache(pools, localPool, adaptivePool, group); + } + + @Override + protected void notifyRemove(ThreadLocalCache threadLocalCache) { + threadLocalCache.free(); + } + }; + } - // Logarithm base 2. At this point we know that pageSize is a power of two. - return Integer.SIZE - 1 - Integer.numberOfLeadingZeros(pageSize); + private int leastUsedPool(AbstractMemoryPool[] pools, MemoryGroup group) { + if (pools == null || pools.length == 0) { + throw new IllegalArgumentException("pool size is empty"); } - private int caculatePoolSize(long heapSize, int pageSize, int maxOrder) { - int poolSize = (int) (heapSize / (pageSize << maxOrder) / 16); - return poolSize > 0 ? poolSize : 1; + int index = 0; + for (int i = 1; i < pools.length; i++) { + AbstractMemoryPool pool = pools[i]; + if (pool.groupThreads(group).intValue() < pools[index].groupThreads(group).intValue()) { + index = i; + } } + return index; + } - private int caculateChunkNum(int poolSize, int chunkSize, long initSize) { - int chunkNum = (int) (initSize / poolSize / chunkSize); - return chunkNum > 0 ? chunkNum : 1; + private static int validateAndCalculateChunkSize(int pageSize, int maxOrder) { + if (maxOrder > 14) { + throw new IllegalArgumentException("maxOrder: " + maxOrder + " (expected: 0-14)"); } - ThreadLocalCache localCache(MemoryGroup group) { - ThreadLocalCache localCache = threadLocal.get(group); - assert localCache != null; - return localCache; + // Ensure the resulting chunkSize does not overflow. + int chunkSize = pageSize; + for (int i = maxOrder; i > 0; i--) { + if (chunkSize > MAX_CHUNK_SIZE / 2) { + throw new IllegalArgumentException( + String.format( + "pageSize (%d) << maxOrder (%d) must not exceed %d", + pageSize, maxOrder, MAX_CHUNK_SIZE)); + } + chunkSize <<= 1; } + return chunkSize; + } - void removeCache() { - BaseMemoryGroupThreadLocal.removeAll(); + static int validateAndCalculatePageShifts(int pageSize) { + if (pageSize < MIN_PAGE_SIZE) { + throw new IllegalArgumentException( + "pageSize: " + pageSize + " (expected: " + MIN_PAGE_SIZE + ")"); } - public MemoryView requireMemory(int size, MemoryGroup group) { - return new MemoryView(requireBufs(size, group), group); + if ((pageSize & pageSize - 1) != 0) { + throw new IllegalArgumentException("pageSize: " + pageSize + " (expected: power of 2)"); } - ByteBuf requireBuf(int size, MemoryGroup group) { - ThreadLocalCache localCache = localCache(group); - ByteBuf byteBuf = localCache.requireBuf(size); - return byteBuf; + // Logarithm base 2. At this point we know that pageSize is a power of two. + return Integer.SIZE - 1 - Integer.numberOfLeadingZeros(pageSize); + } + + private int caculatePoolSize(long heapSize, int pageSize, int maxOrder) { + int poolSize = (int) (heapSize / (pageSize << maxOrder) / 16); + return poolSize > 0 ? poolSize : 1; + } + + private int caculateChunkNum(int poolSize, int chunkSize, long initSize) { + int chunkNum = (int) (initSize / poolSize / chunkSize); + return chunkNum > 0 ? chunkNum : 1; + } + + ThreadLocalCache localCache(MemoryGroup group) { + ThreadLocalCache localCache = threadLocal.get(group); + assert localCache != null; + return localCache; + } + + void removeCache() { + BaseMemoryGroupThreadLocal.removeAll(); + } + + public MemoryView requireMemory(int size, MemoryGroup group) { + return new MemoryView(requireBufs(size, group), group); + } + + ByteBuf requireBuf(int size, MemoryGroup group) { + ThreadLocalCache localCache = localCache(group); + ByteBuf byteBuf = localCache.requireBuf(size); + return byteBuf; + } + + List requireBufs(int size, MemoryGroup group) { + ThreadLocalCache localCache = localCache(group); + List bufs = localCache.requireBufs(size); + group.updateByteBufCount(bufs.size()); + return bufs; + } + + public void dispose() { + try { + memoryManager.removeCache(); + for (int i = 0; i < pools.length; i++) { + pools[i].destroy(); + pools[i] = null; + } + memoryManager = null; + MemoryGroupManger.getInstance().clear(); + System.gc(); + } catch (Throwable t) { + throw new RuntimeException(t); } - - List requireBufs(int size, MemoryGroup group) { - ThreadLocalCache localCache = localCache(group); - List bufs = localCache.requireBufs(size); - group.updateByteBufCount(bufs.size()); - return bufs; + } + + public String dumpStats() { + + StringBuilder buf = new StringBuilder(512); + buf.append(this.getClass().getSimpleName()) + .append("(usedMemory: ") + .append(usedMemory()) + .append("; allocatedMemory: ") + .append(totalAllocateMemory()) + .append("; numPools: ") + .append(poolSize()) + .append("; numThreadLocalCaches: ") + .append(numThreadLocalCaches()) + .append("; chunkSize: ") + .append(chunkSize()) + .append(')') + .append("\n"); + + int len = pools == null ? 0 : pools.length; + if (len > 0) { + boolean heap = pools[0].getMemoryMode() == MemoryMode.ON_HEAP; + buf.append(len); + if (heap) { + buf.append(" heap pool(s):").append("\n"); + } else { + buf.append(" direct pool(s):").append("\n"); + } + + for (AbstractMemoryPool a : pools) { + buf.append(a); + } + } else { + buf.append(" none pool!"); } - public void dispose() { - try { - memoryManager.removeCache(); - for (int i = 0; i < pools.length; i++) { - pools[i].destroy(); - pools[i] = null; - } - memoryManager = null; - MemoryGroupManger.getInstance().clear(); - System.gc(); - } catch (Throwable t) { - throw new RuntimeException(t); - } - } + return buf.toString(); + } - public String dumpStats() { - - StringBuilder buf = new StringBuilder(512); - buf.append(this.getClass().getSimpleName()) - .append("(usedMemory: ").append(usedMemory()) - .append("; allocatedMemory: ").append(totalAllocateMemory()) - .append("; numPools: ").append(poolSize()) - .append("; numThreadLocalCaches: ").append(numThreadLocalCaches()) - .append("; chunkSize: ").append(chunkSize()).append(')').append("\n"); - - int len = pools == null ? 0 : pools.length; - if (len > 0) { - boolean heap = pools[0].getMemoryMode() == MemoryMode.ON_HEAP; - buf.append(len); - if (heap) { - buf.append(" heap pool(s):") - .append("\n"); - } else { - buf.append(" direct pool(s):") - .append("\n"); - } - - for (AbstractMemoryPool a : pools) { - buf.append(a); - } - } else { - buf.append(" none pool!"); - } + public long totalAllocateMemory() { + return totalAllocateMemory(pools) + totalAllocateMemory(adaptivePools); + } - return buf.toString(); - } + public long usedMemory() { + return usedMemory(pools) + usedMemory(adaptivePools); + } - public long totalAllocateMemory() { - return totalAllocateMemory(pools) + totalAllocateMemory(adaptivePools); - } + public long maxMemory() { + return this.maxMemory; + } - public long usedMemory() { - return usedMemory(pools) + usedMemory(adaptivePools); + public long totalAllocateHeapMemory() { + if (pools == null) { + return 0; } - public long maxMemory() { - return this.maxMemory; + if (pools[0].getMemoryMode() == MemoryMode.ON_HEAP) { + return totalAllocateMemory(pools); } + return totalAllocateMemory(adaptivePools); + } - public long totalAllocateHeapMemory() { - if (pools == null) { - return 0; - } - - if (pools[0].getMemoryMode() == MemoryMode.ON_HEAP) { - return totalAllocateMemory(pools); - } - return totalAllocateMemory(adaptivePools); + public long totalAllocateOffHeapMemory() { + if (pools == null) { + return 0; } - public long totalAllocateOffHeapMemory() { - if (pools == null) { - return 0; - } - - if (pools.length > 0 && pools[0].getMemoryMode() == MemoryMode.OFF_HEAP) { - return totalAllocateMemory(pools); - } - return 0; + if (pools.length > 0 && pools[0].getMemoryMode() == MemoryMode.OFF_HEAP) { + return totalAllocateMemory(pools); } + return 0; + } - public long usedHeapMemory() { - if (pools == null) { - return 0; - } - - if (pools.length > 0 && pools[0].getMemoryMode() == MemoryMode.ON_HEAP) { - return usedMemory(pools); - } - return usedMemory(adaptivePools); + public long usedHeapMemory() { + if (pools == null) { + return 0; } - public long usedOffHeapMemory() { - if (pools == null) { - return 0; - } - - if (pools.length > 0 && pools[0].getMemoryMode() == MemoryMode.OFF_HEAP) { - return usedMemory(pools); - } - return 0; + if (pools.length > 0 && pools[0].getMemoryMode() == MemoryMode.ON_HEAP) { + return usedMemory(pools); } + return usedMemory(adaptivePools); + } - private static long usedMemory(AbstractMemoryPool... pools) { - if (pools == null) { - return 0; - } - long used = 0; - for (AbstractMemoryPool arena : pools) { - used += arena.numActiveBytes(); - if (used < 0) { - return Long.MAX_VALUE; - } - } - return used; + public long usedOffHeapMemory() { + if (pools == null) { + return 0; } - private static long totalAllocateMemory(AbstractMemoryPool... pools) { - if (pools == null) { - return 0; - } - long used = 0; - for (AbstractMemoryPool arena : pools) { - used += arena.allocateBytes(); - if (used < 0) { - return Long.MAX_VALUE; - } - } - return used; + if (pools.length > 0 && pools[0].getMemoryMode() == MemoryMode.OFF_HEAP) { + return usedMemory(pools); } + return 0; + } - public int numThreadLocalCaches() { - AbstractMemoryPool[] pools = this.pools; - if (pools == null) { - return 0; - } - - int total = 0; - for (AbstractMemoryPool pool : pools) { - total += pool.numThreadCaches(); - } - - return total; + private static long usedMemory(AbstractMemoryPool... pools) { + if (pools == null) { + return 0; } + long used = 0; + for (AbstractMemoryPool arena : pools) { + used += arena.numActiveBytes(); + if (used < 0) { + return Long.MAX_VALUE; + } + } + return used; + } - public int poolSize() { - return pools == null ? 0 : pools.length; + private static long totalAllocateMemory(AbstractMemoryPool... pools) { + if (pools == null) { + return 0; + } + long used = 0; + for (AbstractMemoryPool arena : pools) { + used += arena.allocateBytes(); + if (used < 0) { + return Long.MAX_VALUE; + } } + return used; + } - public int chunkSize() { - return chunkSize; + public int numThreadLocalCaches() { + AbstractMemoryPool[] pools = this.pools; + if (pools == null) { + return 0; } - @Override - public String toString() { - return dumpStats(); + int total = 0; + for (AbstractMemoryPool pool : pools) { + total += pool.numThreadCaches(); } + + return total; + } + + public int poolSize() { + return pools == null ? 0 : pools.length; + } + + public int chunkSize() { + return chunkSize; + } + + @Override + public String toString() { + return dumpStats(); + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryMode.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryMode.java index a1b1a1759..7c2fd3195 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryMode.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryMode.java @@ -20,12 +20,8 @@ package org.apache.geaflow.memory; enum MemoryMode { - /** - * ON HEAP MEMORY. - */ - ON_HEAP, - /** - * OFF HEAP MEMORY. - */ - OFF_HEAP + /** ON HEAP MEMORY. */ + ON_HEAP, + /** OFF HEAP MEMORY. */ + OFF_HEAP } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryView.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryView.java index d38ce4093..c33c07f2c 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryView.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryView.java @@ -19,148 +19,150 @@ package org.apache.geaflow.memory; -import com.google.common.base.Preconditions; import java.io.Closeable; import java.io.Serializable; import java.util.ArrayList; import java.util.List; -import org.apache.geaflow.memory.channel.ByteBufferWritableChannel; -public class MemoryView implements Serializable, Closeable { - - final List bufList; - int contentSize; - int currIndex = 0; // -1 indicates released +import org.apache.geaflow.memory.channel.ByteBufferWritableChannel; - final MemoryGroup group; +import com.google.common.base.Preconditions; - private transient MemoryViewWriter writer; // only one writer +public class MemoryView implements Serializable, Closeable { - public MemoryView(List bufList, MemoryGroup memoryGroup) { - Preconditions.checkArgument(bufList != null && bufList.size() > 0); - this.bufList = new ArrayList<>(bufList); - this.group = memoryGroup; + final List bufList; + int contentSize; + int currIndex = 0; // -1 indicates released + + final MemoryGroup group; + + private transient MemoryViewWriter writer; // only one writer + + public MemoryView(List bufList, MemoryGroup memoryGroup) { + Preconditions.checkArgument(bufList != null && bufList.size() > 0); + this.bufList = new ArrayList<>(bufList); + this.group = memoryGroup; + } + + public List getBufList() { + return bufList; + } + + public byte[] toArray() { + checkAvail(); + ByteBufferWritableChannel channel = new ByteBufferWritableChannel(contentSize); + writeFully(channel); + channel.close(); + return channel.getData(); + } + + private void writeFully(ByteBufferWritableChannel channel) { + // https://www.evanjones.ca/java-bytebuffer-leak.html + // carefully set write number + int maxIndex = Math.min(bufList.size() - 1, currIndex); + for (int i = 0; i <= maxIndex; i++) { + synchronized (bufList.get(i)) { + channel.write(bufList.get(i).getBf(), 0); + } } + } - public List getBufList() { - return bufList; - } + void checkAvail() { + Preconditions.checkArgument(currIndex != -1, "view has been released"); + } - public byte[] toArray() { - checkAvail(); - ByteBufferWritableChannel channel = new ByteBufferWritableChannel(contentSize); - writeFully(channel); - channel.close(); - return channel.getData(); + public int remain() { + checkAvail(); + int remain = 0; + for (int i = currIndex; i < bufList.size(); i++) { + remain += bufList.get(i).getBf().remaining(); } - - private void writeFully(ByteBufferWritableChannel channel) { - // https://www.evanjones.ca/java-bytebuffer-leak.html - // carefully set write number - int maxIndex = Math.min(bufList.size() - 1, currIndex); - for (int i = 0; i <= maxIndex; i++) { - synchronized (bufList.get(i)) { - channel.write(bufList.get(i).getBf(), 0); - } - } + return remain; + } + + public int contentSize() { + return contentSize; + } + + @Override + public int hashCode() { + int result = 1; + int maxIndex = Math.min(bufList.size() - 1, currIndex); + for (int i = 0; i <= maxIndex; i++) { + ByteBuf buf = bufList.get(i); + for (int j = 0; j < buf.getBf().capacity(); j++) { + result = 31 * result + buf.getBf().get(j); + } } + return result; + } - void checkAvail() { - Preconditions.checkArgument(currIndex != -1, "view has been released"); + @Override + public boolean equals(Object other) { + if (!(other instanceof MemoryView)) { + return false; } - - public int remain() { - checkAvail(); - int remain = 0; - for (int i = currIndex; i < bufList.size(); i++) { - remain += bufList.get(i).getBf().remaining(); - } - return remain; + if (this == other) { + return true; } - public int contentSize() { - return contentSize; - } + MemoryView otherView = (MemoryView) other; - @Override - public int hashCode() { - int result = 1; - int maxIndex = Math.min(bufList.size() - 1, currIndex); - for (int i = 0; i <= maxIndex; i++) { - ByteBuf buf = bufList.get(i); - for (int j = 0; j < buf.getBf().capacity(); j++) { - result = 31 * result + buf.getBf().get(j); - } - } - return result; + if (group != otherView.group) { + return false; } - @Override - public boolean equals(Object other) { - if (!(other instanceof MemoryView)) { - return false; - } - if (this == other) { - return true; - } - - MemoryView otherView = (MemoryView) other; - - if (group != otherView.group) { - return false; - } - - if (contentSize != otherView.contentSize) { - return false; - } - - MemoryViewReader reader = new MemoryViewReader(this); - MemoryViewReader reader2 = new MemoryViewReader(otherView); + if (contentSize != otherView.contentSize) { + return false; + } - for (int i = 0; i < contentSize; i++) { - if (reader.read() != reader2.read()) { - return false; - } - } + MemoryViewReader reader = new MemoryViewReader(this); + MemoryViewReader reader2 = new MemoryViewReader(otherView); - return true; + for (int i = 0; i < contentSize; i++) { + if (reader.read() != reader2.read()) { + return false; + } } - @Override - public void close() { - if (currIndex >= 0) { - currIndex = -1; // must be after untegister - bufList.forEach(buf -> buf.free(group)); - group.updateByteBufCount(-1 * bufList.size()); - bufList.clear(); - } - } + return true; + } - public MemoryGroup getGroup() { - return this.group; + @Override + public void close() { + if (currIndex >= 0) { + currIndex = -1; // must be after untegister + bufList.forEach(buf -> buf.free(group)); + group.updateByteBufCount(-1 * bufList.size()); + bufList.clear(); } + } - public MemoryViewReader getReader() { - return new MemoryViewReader(this); - } + public MemoryGroup getGroup() { + return this.group; + } - public MemoryViewWriter getWriter() { - if (writer == null) { - writer = new MemoryViewWriter(this, group.getSpanSize()); - } - return writer; - } + public MemoryViewReader getReader() { + return new MemoryViewReader(this); + } - public MemoryViewWriter getWriter(int spanSize) { - if (writer == null) { - writer = new MemoryViewWriter(this, spanSize); - } - return writer; + public MemoryViewWriter getWriter() { + if (writer == null) { + writer = new MemoryViewWriter(this, group.getSpanSize()); } + return writer; + } - public void reset() { - currIndex = 0; - contentSize = 0; - bufList.forEach(ByteBuf::reset); + public MemoryViewWriter getWriter(int spanSize) { + if (writer == null) { + writer = new MemoryViewWriter(this, spanSize); } + return writer; + } + + public void reset() { + currIndex = 0; + contentSize = 0; + bufList.forEach(ByteBuf::reset); + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryViewReader.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryViewReader.java index e4351fdf3..14436d67a 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryViewReader.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryViewReader.java @@ -25,132 +25,130 @@ public class MemoryViewReader { - private int readBufPos = 0; - private int readIndex = 0; - private int readPos = 0; - private final MemoryView view; - - public MemoryViewReader(MemoryView view) { - this.view = view; - view.checkAvail(); + private int readBufPos = 0; + private int readIndex = 0; + private int readPos = 0; + private final MemoryView view; + + public MemoryViewReader(MemoryView view) { + this.view = view; + view.checkAvail(); + } + + public boolean hasNext() { + return view.contentSize() > readPos; + } + + public byte read() { + if (readPos >= view.contentSize()) { + throw new BufferUnderflowException(); } + byte res = view.bufList.get(readIndex).get(readBufPos++); - public boolean hasNext() { - return view.contentSize() > readPos; + if (readBufPos == view.bufList.get(readIndex).getLength()) { + readIndex++; + readBufPos = 0; } - - public byte read() { - if (readPos >= view.contentSize()) { - throw new BufferUnderflowException(); - } - byte res = view.bufList.get(readIndex).get(readBufPos++); - - if (readBufPos == view.bufList.get(readIndex).getLength()) { - readIndex++; - readBufPos = 0; - } - readPos++; - return res; + readPos++; + return res; + } + + public int read(byte[] b) { + return read(b, 0, b.length); + } + + public int read(OutputStream outputStream) throws IOException { + int readLen = 0; + while (readIndex < view.bufList.size()) { + ByteBuf buf = view.bufList.get(readIndex); + int toRead = buf.contentSize() - readBufPos; + readLen += toRead; + + if (buf.chunk.pool.getMemoryMode() == MemoryMode.ON_HEAP) { + outputStream.write(buf.getBf().array(), buf.getStartOffset(), toRead); + } else { + byte[] b = new byte[toRead]; + buf.duplicate().get(b, 0, toRead); + outputStream.write(b, 0, toRead); + } + + readBufPos = 0; + readIndex++; } - - public int read(byte[] b) { - return read(b, 0, b.length); + readPos += readLen; + return readLen; + } + + public int read(byte[] b, int off, int len) { + if (b == null) { + throw new NullPointerException(); + } else if (off < 0 || len < 0 || len > b.length - off) { + throw new IndexOutOfBoundsException(); + } else if (len == 0) { + return 0; } - public int read(OutputStream outputStream) throws IOException { - int readLen = 0; - while (readIndex < view.bufList.size()) { - ByteBuf buf = view.bufList.get(readIndex); - int toRead = buf.contentSize() - readBufPos; - readLen += toRead; - - if (buf.chunk.pool.getMemoryMode() == MemoryMode.ON_HEAP) { - outputStream.write(buf.getBf().array(), buf.getStartOffset(), toRead); - } else { - byte[] b = new byte[toRead]; - buf.duplicate().get(b, 0, toRead); - outputStream.write(b, 0, toRead); - } - - readBufPos = 0; - readIndex++; - } - readPos += readLen; - return readLen; + int offset = off; + int remain = len; + + while (remain > 0 && readIndex < view.bufList.size()) { + ByteBuf buf = view.bufList.get(readIndex); + int r = buf.contentSize() - readBufPos; + if (r == 0) { + break; + } + int toRead = Math.min(r, remain); + if (buf.getChunk().pool.getMemoryMode() == MemoryMode.ON_HEAP) { + System.arraycopy(buf.getBf().array(), buf.getStartOffset() + readBufPos, b, offset, toRead); + } else { + DirectMemory.copyMemory( + DirectMemory.directBufferAddress(buf.getBf()) + readBufPos, b, offset, toRead); + } + remain -= toRead; + offset += toRead; + readBufPos += toRead; + + if (toRead == r && !buf.hasRemaining()) { + readBufPos = 0; + readIndex++; + } } - - public int read(byte[] b, int off, int len) { - if (b == null) { - throw new NullPointerException(); - } else if (off < 0 || len < 0 || len > b.length - off) { - throw new IndexOutOfBoundsException(); - } else if (len == 0) { - return 0; - } - - int offset = off; - int remain = len; - - while (remain > 0 && readIndex < view.bufList.size()) { - ByteBuf buf = view.bufList.get(readIndex); - int r = buf.contentSize() - readBufPos; - if (r == 0) { - break; - } - int toRead = Math.min(r, remain); - if (buf.getChunk().pool.getMemoryMode() == MemoryMode.ON_HEAP) { - System.arraycopy(buf.getBf().array(), buf.getStartOffset() + readBufPos, b, offset, - toRead); - } else { - DirectMemory.copyMemory(DirectMemory.directBufferAddress(buf.getBf()) + readBufPos, - b, offset, toRead); - } - remain -= toRead; - offset += toRead; - readBufPos += toRead; - - if (toRead == r && !buf.hasRemaining()) { - readBufPos = 0; - readIndex++; - } - + int readLen = len - remain; + readPos += readLen; + return readLen == 0 ? -1 : readLen; + } + + public long skip(long n) { + int remain = (int) n; + while (remain > 0 && readIndex < view.bufList.size()) { + ByteBuf buf = view.bufList.get(readIndex); + int toRead = Math.min(buf.contentSize() - readBufPos, remain); + remain -= toRead; + readBufPos += toRead; + if (readBufPos == buf.contentSize()) { + if (buf.hasRemaining()) { + break; } - int readLen = len - remain; - readPos += readLen; - return readLen == 0 ? -1 : readLen; - } - - public long skip(long n) { - int remain = (int) n; - while (remain > 0 && readIndex < view.bufList.size()) { - ByteBuf buf = view.bufList.get(readIndex); - int toRead = Math.min(buf.contentSize() - readBufPos, remain); - remain -= toRead; - readBufPos += toRead; - if (readBufPos == buf.contentSize()) { - if (buf.hasRemaining()) { - break; - } - readBufPos = 0; - readIndex++; - } - } - long len = n - remain; - readPos += len; - return len; - } - - public void reset() { - this.readIndex = 0; - this.readBufPos = 0; - this.readPos = 0; - } - - public int readPos() { - return readPos; - } - - public int available() { - return view.contentSize - readPos; + readBufPos = 0; + readIndex++; + } } + long len = n - remain; + readPos += len; + return len; + } + + public void reset() { + this.readIndex = 0; + this.readBufPos = 0; + this.readPos = 0; + } + + public int readPos() { + return readPos; + } + + public int available() { + return view.contentSize - readPos; + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryViewReference.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryViewReference.java index 23bf0d9b0..8b13b9122 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryViewReference.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryViewReference.java @@ -24,44 +24,44 @@ public class MemoryViewReference implements Closeable { - private MemoryView memoryView; - private final AtomicInteger refCnt; + private MemoryView memoryView; + private final AtomicInteger refCnt; - public MemoryViewReference(MemoryView memoryView) { - this.memoryView = memoryView; - refCnt = new AtomicInteger(1); - } + public MemoryViewReference(MemoryView memoryView) { + this.memoryView = memoryView; + refCnt = new AtomicInteger(1); + } - public void incRef() { - refCnt.addAndGet(1); - } + public void incRef() { + refCnt.addAndGet(1); + } - public boolean decRef() { - if (refCnt.addAndGet(-1) <= 0) { - close(); - return true; - } - return false; + public boolean decRef() { + if (refCnt.addAndGet(-1) <= 0) { + close(); + return true; } + return false; + } - public MemoryView getMemoryView() { - return memoryView; - } + public MemoryView getMemoryView() { + return memoryView; + } - public int refCnt() { - return refCnt.get(); - } + public int refCnt() { + return refCnt.get(); + } - @Override - public void close() { - if (memoryView == null) { - return; - } - synchronized (this) { - if (memoryView != null) { - memoryView.close(); - memoryView = null; - } - } + @Override + public void close() { + if (memoryView == null) { + return; + } + synchronized (this) { + if (memoryView != null) { + memoryView.close(); + memoryView = null; + } } + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryViewWriter.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryViewWriter.java index f4d0a7b77..fd069f6b6 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryViewWriter.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/MemoryViewWriter.java @@ -19,186 +19,197 @@ package org.apache.geaflow.memory; -import com.google.common.base.Preconditions; import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; import java.util.List; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; + public class MemoryViewWriter { - private static final Logger logger = LoggerFactory.getLogger(MemoryViewWriter.class); + private static final Logger logger = LoggerFactory.getLogger(MemoryViewWriter.class); - private final int spanSize; - private final MemoryView view; + private final int spanSize; + private final MemoryView view; - public MemoryViewWriter(MemoryView view) { - this(view, -1); - } + public MemoryViewWriter(MemoryView view) { + this(view, -1); + } - public MemoryViewWriter(MemoryView view, int spanSize) { - this.view = view; - this.spanSize = spanSize; - } + public MemoryViewWriter(MemoryView view, int spanSize) { + this.view = view; + this.spanSize = spanSize; + } - private void checkContentBounds(int len) { - if (Integer.MAX_VALUE - view.contentSize < len) { - throw new IllegalArgumentException(String.format( - "current size %d plus write size %d should not be greater than Integer.MAX_VALUE %d", - view.contentSize, len, Integer.MAX_VALUE)); - } + private void checkContentBounds(int len) { + if (Integer.MAX_VALUE - view.contentSize < len) { + throw new IllegalArgumentException( + String.format( + "current size %d plus write size %d should not be greater than Integer.MAX_VALUE %d", + view.contentSize, len, Integer.MAX_VALUE)); } + } - public void write(int b) { - view.checkAvail(); - checkContentBounds(1); - if (view.currIndex == view.bufList.size()) { - span(1); - } - view.bufList.get(view.currIndex).getBf().put((byte) b); - - if (view.bufList.get(view.currIndex).contentSize() == view.bufList.get(view.currIndex) - .getLength()) { - view.currIndex++; - } - view.contentSize++; + public void write(int b) { + view.checkAvail(); + checkContentBounds(1); + if (view.currIndex == view.bufList.size()) { + span(1); } + view.bufList.get(view.currIndex).getBf().put((byte) b); - public void write(byte[] bytes) { - write(bytes, 0, bytes.length); + if (view.bufList.get(view.currIndex).contentSize() + == view.bufList.get(view.currIndex).getLength()) { + view.currIndex++; } + view.contentSize++; + } - public void write(byte[] bytes, int off, int len) { - view.checkAvail(); - if (len <= 0) { - return; - } - checkContentBounds(len); - int offset = off; - int remain = len; - while (remain > 0) { - if (view.currIndex == view.bufList.size()) { - span(remain); - } - int r = view.bufList.get(view.currIndex).getRemain(); - int toWrite = Math.min(r, remain); - view.bufList.get(view.currIndex).getBf().put(bytes, offset, toWrite); - view.contentSize += toWrite; - - remain -= toWrite; - offset += toWrite; - - if (toWrite == r) { - view.currIndex++; - } - } + public void write(byte[] bytes) { + write(bytes, 0, bytes.length); + } + public void write(byte[] bytes, int off, int len) { + view.checkAvail(); + if (len <= 0) { + return; } - - public void write(ByteBuffer buffer, int len) throws IOException { - view.checkAvail(); - checkContentBounds(len); - - ByteBuffer readBuf = buffer; - int wrote = readBuf.position(); - int limit = readBuf.position() + len; - Preconditions.checkArgument(limit <= readBuf.limit(), - "length should not be bigger than buffer limit!"); - - while (wrote < limit) { - int remain = limit - wrote; - if (view.currIndex == view.bufList.size()) { - span(remain); - } - try { - int r = view.bufList.get(view.currIndex).getRemain(); - int toWrite = Math.min(r, remain); - ByteBuffer newBuf = (ByteBuffer) readBuf.duplicate().limit(wrote + toWrite); - view.bufList.get(view.currIndex).getBf().put(newBuf); - view.contentSize += toWrite; - - wrote += toWrite; - readBuf.position(wrote); - if (toWrite == r) { - view.currIndex++; - } - } catch (Throwable t) { - logger.error(String.format("currIndex=%d,bufSize=%d,currBuf=%s", view.currIndex, - view.bufList.size(), view.bufList.get(view.currIndex).getBf() == null), t); - throw t; - } - } + checkContentBounds(len); + int offset = off; + int remain = len; + while (remain > 0) { + if (view.currIndex == view.bufList.size()) { + span(remain); + } + int r = view.bufList.get(view.currIndex).getRemain(); + int toWrite = Math.min(r, remain); + view.bufList.get(view.currIndex).getBf().put(bytes, offset, toWrite); + view.contentSize += toWrite; + + remain -= toWrite; + offset += toWrite; + + if (toWrite == r) { + view.currIndex++; + } } - - public void write(InputStream inputStream, int len) throws IOException { - view.checkAvail(); - if (len <= 0) { - return; - } - checkContentBounds(len); - - int remain = len; - byte[] buffer = new byte[1024]; - while (remain > 0) { - if (view.currIndex == view.bufList.size()) { - span(remain); - } - try { - int r = view.bufList.get(view.currIndex).getRemain(); - int toWrite = Math.min(Math.min(r, remain), 1024); - - inputStream.read(buffer, 0, toWrite); - view.bufList.get(view.currIndex).getBf().put(buffer, 0, toWrite); - view.contentSize += toWrite; - remain -= toWrite; - - if (toWrite == r) { - view.currIndex++; - } - } catch (Throwable t) { - logger.error(String.format("currIndex=%d,bufSize=%d,currBuf=%s", view.currIndex, - view.bufList.size(), view.bufList.get(view.currIndex).getBf() == null), t); - throw t; - } + } + + public void write(ByteBuffer buffer, int len) throws IOException { + view.checkAvail(); + checkContentBounds(len); + + ByteBuffer readBuf = buffer; + int wrote = readBuf.position(); + int limit = readBuf.position() + len; + Preconditions.checkArgument( + limit <= readBuf.limit(), "length should not be bigger than buffer limit!"); + + while (wrote < limit) { + int remain = limit - wrote; + if (view.currIndex == view.bufList.size()) { + span(remain); + } + try { + int r = view.bufList.get(view.currIndex).getRemain(); + int toWrite = Math.min(r, remain); + ByteBuffer newBuf = (ByteBuffer) readBuf.duplicate().limit(wrote + toWrite); + view.bufList.get(view.currIndex).getBf().put(newBuf); + view.contentSize += toWrite; + + wrote += toWrite; + readBuf.position(wrote); + if (toWrite == r) { + view.currIndex++; } + } catch (Throwable t) { + logger.error( + String.format( + "currIndex=%d,bufSize=%d,currBuf=%s", + view.currIndex, + view.bufList.size(), + view.bufList.get(view.currIndex).getBf() == null), + t); + throw t; + } } + } - private void span(int size) { - List bufs = MemoryManager.getInstance() - .requireBufs(spanSize > 0 ? spanSize : size, view.group); - view.bufList.addAll(bufs); + public void write(InputStream inputStream, int len) throws IOException { + view.checkAvail(); + if (len <= 0) { + return; } - - public void position(int position) { - if (position > view.contentSize || position < 0) { - throw new IllegalArgumentException( - String.format("position %s not in [0,%s]", position, view.contentSize)); - } - if (position == view.contentSize) { - return; - } - if (position == 0) { - view.reset(); - return; - } - int i = 0; - int totalSize = 0; - for (; i < view.bufList.size(); i++) { - int length = view.bufList.get(i).getLength(); - if (length + totalSize > position) { - break; - } - totalSize += length; - } - view.currIndex = i; - view.contentSize = position; - - view.bufList.get(i).position(position - totalSize); - for (int j = i + 1; j < view.bufList.size(); j++) { - view.bufList.get(j).reset(); + checkContentBounds(len); + + int remain = len; + byte[] buffer = new byte[1024]; + while (remain > 0) { + if (view.currIndex == view.bufList.size()) { + span(remain); + } + try { + int r = view.bufList.get(view.currIndex).getRemain(); + int toWrite = Math.min(Math.min(r, remain), 1024); + + inputStream.read(buffer, 0, toWrite); + view.bufList.get(view.currIndex).getBf().put(buffer, 0, toWrite); + view.contentSize += toWrite; + remain -= toWrite; + + if (toWrite == r) { + view.currIndex++; } + } catch (Throwable t) { + logger.error( + String.format( + "currIndex=%d,bufSize=%d,currBuf=%s", + view.currIndex, + view.bufList.size(), + view.bufList.get(view.currIndex).getBf() == null), + t); + throw t; + } + } + } + + private void span(int size) { + List bufs = + MemoryManager.getInstance().requireBufs(spanSize > 0 ? spanSize : size, view.group); + view.bufList.addAll(bufs); + } + + public void position(int position) { + if (position > view.contentSize || position < 0) { + throw new IllegalArgumentException( + String.format("position %s not in [0,%s]", position, view.contentSize)); + } + if (position == view.contentSize) { + return; } + if (position == 0) { + view.reset(); + return; + } + int i = 0; + int totalSize = 0; + for (; i < view.bufList.size(); i++) { + int length = view.bufList.get(i).getLength(); + if (length + totalSize > position) { + break; + } + totalSize += length; + } + view.currIndex = i; + view.contentSize = position; + view.bufList.get(i).position(position - totalSize); + for (int j = i + 1; j < view.bufList.size(); j++) { + view.bufList.get(j).reset(); + } + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/Page.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/Page.java index 18a9cc954..b8b2727fd 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/Page.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/Page.java @@ -18,211 +18,208 @@ import org.apache.geaflow.memory.metric.PageMetric; -/** - * This class is an adaptation of Netty's io.netty.buffer.PoolSubpage. - */ +/** This class is an adaptation of Netty's io.netty.buffer.PoolSubpage. */ public class Page implements PageMetric { - final Chunk chunk; - private final int memoryMapIdx; - private final int runOffset; - private final int pageSize; - private final long[] bitmap; - - Page prev; - Page next; - - boolean doNotDestroy; - int elemSize; - private int maxNumElems; - private int bitmapLength; - private int nextAvail; - private int numAvail; - - public Page(int pageSize) { - this.chunk = null; - this.memoryMapIdx = -1; - this.runOffset = -1; - this.pageSize = pageSize; - this.bitmap = null; - this.elemSize = -1; - } - - Page(Page head, Chunk chunk, int memoryMapIdx, int runOffset, int pageSize, int elemSize) { - this.chunk = chunk; - this.memoryMapIdx = memoryMapIdx; - this.runOffset = runOffset; - this.pageSize = pageSize; - bitmap = new long[pageSize >>> 10]; // pageSize / 16 / 64 - init(head, elemSize); - } - - void init(Page head, int elemSize) { - doNotDestroy = true; - this.elemSize = elemSize; - if (elemSize != 0) { - maxNumElems = numAvail = pageSize / elemSize; - nextAvail = 0; - bitmapLength = maxNumElems >>> 6; - if ((maxNumElems & 63) != 0) { - bitmapLength++; - } - - for (int i = 0; i < bitmapLength; i++) { - bitmap[i] = 0; - } - } - addToPool(head); - } - - /** - * Returns the bitmap index of the subpage allocation. - */ - long allocate() { - if (elemSize == 0) { - return toHandle(0); - } - - if (numAvail == 0 || !doNotDestroy) { - return -1; - } - - final int bitmapIdx = getNextAvail(); - int q = bitmapIdx >>> 6; - int r = bitmapIdx & 63; - assert (bitmap[q] >>> r & 1) == 0; - bitmap[q] |= 1L << r; - - if (--numAvail == 0) { - removeFromPool(); - } - - return toHandle(bitmapIdx); - } - - /** - * Free page. - * @return {@code true} if this subpage is in use. - * {@code false} if this subpage is not used by its chunk and thus it's OK to be released. - */ - boolean free(Page head, int bitmapIdx) { - if (elemSize == 0) { - return true; - } - int q = bitmapIdx >>> 6; - int r = bitmapIdx & 63; - assert (bitmap[q] >>> r & 1) != 0; - bitmap[q] ^= 1L << r; - - setNextAvail(bitmapIdx); - - if (numAvail++ == 0) { - addToPool(head); - return true; - } - - if (numAvail != maxNumElems) { - return true; + final Chunk chunk; + private final int memoryMapIdx; + private final int runOffset; + private final int pageSize; + private final long[] bitmap; + + Page prev; + Page next; + + boolean doNotDestroy; + int elemSize; + private int maxNumElems; + private int bitmapLength; + private int nextAvail; + private int numAvail; + + public Page(int pageSize) { + this.chunk = null; + this.memoryMapIdx = -1; + this.runOffset = -1; + this.pageSize = pageSize; + this.bitmap = null; + this.elemSize = -1; + } + + Page(Page head, Chunk chunk, int memoryMapIdx, int runOffset, int pageSize, int elemSize) { + this.chunk = chunk; + this.memoryMapIdx = memoryMapIdx; + this.runOffset = runOffset; + this.pageSize = pageSize; + bitmap = new long[pageSize >>> 10]; // pageSize / 16 / 64 + init(head, elemSize); + } + + void init(Page head, int elemSize) { + doNotDestroy = true; + this.elemSize = elemSize; + if (elemSize != 0) { + maxNumElems = numAvail = pageSize / elemSize; + nextAvail = 0; + bitmapLength = maxNumElems >>> 6; + if ((maxNumElems & 63) != 0) { + bitmapLength++; + } + + for (int i = 0; i < bitmapLength; i++) { + bitmap[i] = 0; + } + } + addToPool(head); + } + + /** Returns the bitmap index of the subpage allocation. */ + long allocate() { + if (elemSize == 0) { + return toHandle(0); + } + + if (numAvail == 0 || !doNotDestroy) { + return -1; + } + + final int bitmapIdx = getNextAvail(); + int q = bitmapIdx >>> 6; + int r = bitmapIdx & 63; + assert (bitmap[q] >>> r & 1) == 0; + bitmap[q] |= 1L << r; + + if (--numAvail == 0) { + removeFromPool(); + } + + return toHandle(bitmapIdx); + } + + /** + * Free page. + * + * @return {@code true} if this subpage is in use. {@code false} if this subpage is not used by + * its chunk and thus it's OK to be released. + */ + boolean free(Page head, int bitmapIdx) { + if (elemSize == 0) { + return true; + } + int q = bitmapIdx >>> 6; + int r = bitmapIdx & 63; + assert (bitmap[q] >>> r & 1) != 0; + bitmap[q] ^= 1L << r; + + setNextAvail(bitmapIdx); + + if (numAvail++ == 0) { + addToPool(head); + return true; + } + + if (numAvail != maxNumElems) { + return true; + } else { + // Subpage not in use (numAvail == maxNumElems) + if (prev == next) { + // Do not remove if this subpage is the only one left in the pool. + return true; + } + + // Remove this subpage from the pool if there are other subpages left in the pool. + doNotDestroy = false; + removeFromPool(); + return false; + } + } + + private void addToPool(Page head) { + assert prev == null && next == null; + prev = head; + next = head.next; + next.prev = this; + head.next = this; + } + + private void removeFromPool() { + assert prev != null && next != null; + prev.next = next; + next.prev = prev; + next = null; + prev = null; + } + + private void setNextAvail(int bitmapIdx) { + nextAvail = bitmapIdx; + } + + private int getNextAvail() { + int nextAvail = this.nextAvail; + if (nextAvail >= 0) { + this.nextAvail = -1; + return nextAvail; + } + return findNextAvail(); + } + + private int findNextAvail() { + final long[] bitmap = this.bitmap; + final int bitmapLength = this.bitmapLength; + for (int i = 0; i < bitmapLength; i++) { + long bits = bitmap[i]; + if (~bits != 0) { + return findNextAvail0(i, bits); + } + } + return -1; + } + + private int findNextAvail0(int i, long bits) { + final int maxNumElems = this.maxNumElems; + final int baseVal = i << 6; + + for (int j = 0; j < 64; j++) { + if ((bits & 1) == 0) { + int val = baseVal | j; + if (val < maxNumElems) { + return val; } else { - // Subpage not in use (numAvail == maxNumElems) - if (prev == next) { - // Do not remove if this subpage is the only one left in the pool. - return true; - } - - // Remove this subpage from the pool if there are other subpages left in the pool. - doNotDestroy = false; - removeFromPool(); - return false; + break; } + } + bits >>>= 1; } + return -1; + } - private void addToPool(Page head) { - assert prev == null && next == null; - prev = head; - next = head.next; - next.prev = this; - head.next = this; - } + private long toHandle(int bitmapIdx) { + return 0x4000000000000000L | (long) bitmapIdx << 32 | memoryMapIdx; + } - private void removeFromPool() { - assert prev != null && next != null; - prev.next = next; - next.prev = prev; - next = null; - prev = null; + void destroy() { + if (chunk != null) { + chunk.destroy(); } + } - private void setNextAvail(int bitmapIdx) { - nextAvail = bitmapIdx; - } + @Override + public int maxNumElements() { + return maxNumElems; + } - private int getNextAvail() { - int nextAvail = this.nextAvail; - if (nextAvail >= 0) { - this.nextAvail = -1; - return nextAvail; - } - return findNextAvail(); - } + @Override + public int numAvailable() { + return numAvail; + } - private int findNextAvail() { - final long[] bitmap = this.bitmap; - final int bitmapLength = this.bitmapLength; - for (int i = 0; i < bitmapLength; i++) { - long bits = bitmap[i]; - if (~bits != 0) { - return findNextAvail0(i, bits); - } - } - return -1; - } - - private int findNextAvail0(int i, long bits) { - final int maxNumElems = this.maxNumElems; - final int baseVal = i << 6; - - for (int j = 0; j < 64; j++) { - if ((bits & 1) == 0) { - int val = baseVal | j; - if (val < maxNumElems) { - return val; - } else { - break; - } - } - bits >>>= 1; - } - return -1; - } - - private long toHandle(int bitmapIdx) { - return 0x4000000000000000L | (long) bitmapIdx << 32 | memoryMapIdx; - } + @Override + public int elementSize() { + return elemSize; + } - void destroy() { - if (chunk != null) { - chunk.destroy(); - } - } - - @Override - public int maxNumElements() { - return maxNumElems; - } - - @Override - public int numAvailable() { - return numAvail; - } - - @Override - public int elementSize() { - return elemSize; - } - - @Override - public int pageSize() { - return pageSize; - } + @Override + public int pageSize() { + return pageSize; + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/PlatformDependent.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/PlatformDependent.java index 3bc7f5340..703e31d59 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/PlatformDependent.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/PlatformDependent.java @@ -18,7 +18,6 @@ import static java.util.Objects.requireNonNull; -import com.google.common.base.Preconditions; import java.io.Serializable; import java.lang.reflect.Constructor; import java.lang.reflect.Field; @@ -27,351 +26,368 @@ import java.nio.ByteBuffer; import java.security.AccessController; import java.security.PrivilegedAction; + import org.apache.geaflow.common.utils.ReflectionUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * This class is an adaptation of Netty's io.netty.util.internal.PlatformDependent. - */ +import com.google.common.base.Preconditions; + +/** This class is an adaptation of Netty's io.netty.util.internal.PlatformDependent. */ public final class PlatformDependent implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(PlatformDependent.class); - - static final sun.misc.Unsafe UNSAFE; - // constants borrowed from murmur3 - static final int HASH_CODE_ASCII_SEED = 0xc2b2ae35; - static final int HASH_CODE_C1 = 0xcc9e2d51; - static final int HASH_CODE_C2 = 0x1b873593; - private static final long ADDRESS_FIELD_OFFSET; - private static final long BYTE_ARRAY_BASE_OFFSET; - private static final Constructor DIRECT_BUFFER_CONSTRUCTOR; - private static final Throwable UNSAFE_UNAVAILABILITY_CAUSE; - /** - * Limits the number of bytes to copy per {@link sun.misc.Unsafe#copyMemory(long, long, long)} to allow - * safepoint polling - * during a large copy. - */ - private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L; - - static { - final ByteBuffer direct; - Field addressField = null; - Throwable unsafeUnavailabilityCause = null; - sun.misc.Unsafe unsafe; - - direct = ByteBuffer.allocateDirect(1); - - // attempt to access field Unsafe#theUnsafe - final Object maybeUnsafe = AccessController - .doPrivileged((PrivilegedAction) () -> { - try { + private static final Logger LOGGER = LoggerFactory.getLogger(PlatformDependent.class); + + static final sun.misc.Unsafe UNSAFE; + // constants borrowed from murmur3 + static final int HASH_CODE_ASCII_SEED = 0xc2b2ae35; + static final int HASH_CODE_C1 = 0xcc9e2d51; + static final int HASH_CODE_C2 = 0x1b873593; + private static final long ADDRESS_FIELD_OFFSET; + private static final long BYTE_ARRAY_BASE_OFFSET; + private static final Constructor DIRECT_BUFFER_CONSTRUCTOR; + private static final Throwable UNSAFE_UNAVAILABILITY_CAUSE; + + /** + * Limits the number of bytes to copy per {@link sun.misc.Unsafe#copyMemory(long, long, long)} to + * allow safepoint polling during a large copy. + */ + private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L; + + static { + final ByteBuffer direct; + Field addressField = null; + Throwable unsafeUnavailabilityCause = null; + sun.misc.Unsafe unsafe; + + direct = ByteBuffer.allocateDirect(1); + + // attempt to access field Unsafe#theUnsafe + final Object maybeUnsafe = + AccessController.doPrivileged( + (PrivilegedAction) + () -> { + try { final Field unsafeField = sun.misc.Unsafe.class.getDeclaredField("theUnsafe"); // We always want to try using Unsafe as the access still works on java9 // as well and // we need it for out native-transports and many optimizations. Throwable cause = ReflectionUtil.trySetAccessible(unsafeField, false); if (cause != null) { - return cause; + return cause; } // the unsafe instance return unsafeField.get(null); - } catch (NoSuchFieldException | SecurityException | IllegalAccessException | NoClassDefFoundError e) { + } catch (NoSuchFieldException + | SecurityException + | IllegalAccessException + | NoClassDefFoundError e) { return e; - } // Also catch NoClassDefFoundError in case someone uses for example OSGI - // and it made - // Unsafe unloadable. - }); - - // the conditional check here can not be replaced with checking that maybeUnsafe - // is an instanceof Unsafe and reversing the if and else blocks; this is because an - // instanceof check against Unsafe will trigger a class load and we might not have - // the runtime permission accessClassInPackage.sun.misc - if (maybeUnsafe instanceof Throwable) { - unsafe = null; - unsafeUnavailabilityCause = (Throwable) maybeUnsafe; - } else { - unsafe = (sun.misc.Unsafe) maybeUnsafe; - } + } // Also catch NoClassDefFoundError in case someone uses for example OSGI + // and it made + // Unsafe unloadable. + }); - // ensure the unsafe supports all necessary methods to work around the mistake in the - // latest OpenJDK - // https://github.com/netty/netty/issues/1061 - // http://www.mail-archive.com/jdk6-dev@openjdk.java.net/msg00698.html - if (unsafe != null) { - final sun.misc.Unsafe finalUnsafe = unsafe; - final Object maybeException = AccessController - .doPrivileged((PrivilegedAction) () -> { + // the conditional check here can not be replaced with checking that maybeUnsafe + // is an instanceof Unsafe and reversing the if and else blocks; this is because an + // instanceof check against Unsafe will trigger a class load and we might not have + // the runtime permission accessClassInPackage.sun.misc + if (maybeUnsafe instanceof Throwable) { + unsafe = null; + unsafeUnavailabilityCause = (Throwable) maybeUnsafe; + } else { + unsafe = (sun.misc.Unsafe) maybeUnsafe; + } + + // ensure the unsafe supports all necessary methods to work around the mistake in the + // latest OpenJDK + // https://github.com/netty/netty/issues/1061 + // http://www.mail-archive.com/jdk6-dev@openjdk.java.net/msg00698.html + if (unsafe != null) { + final sun.misc.Unsafe finalUnsafe = unsafe; + final Object maybeException = + AccessController.doPrivileged( + (PrivilegedAction) + () -> { try { - finalUnsafe.getClass() - .getDeclaredMethod("copyMemory", Object.class, long.class, - Object.class, long.class, long.class); - return null; + finalUnsafe + .getClass() + .getDeclaredMethod( + "copyMemory", + Object.class, + long.class, + Object.class, + long.class, + long.class); + return null; } catch (NoSuchMethodException | SecurityException e) { - return e; + return e; } - }); + }); - if (maybeException != null) { - unsafe = null; - unsafeUnavailabilityCause = (Throwable) maybeException; - } - } + if (maybeException != null) { + unsafe = null; + unsafeUnavailabilityCause = (Throwable) maybeException; + } + } - if (unsafe != null) { - final sun.misc.Unsafe finalUnsafe = unsafe; + if (unsafe != null) { + final sun.misc.Unsafe finalUnsafe = unsafe; - // attempt to access field Buffer#address - final Object maybeAddressField = AccessController - .doPrivileged((PrivilegedAction) () -> { + // attempt to access field Buffer#address + final Object maybeAddressField = + AccessController.doPrivileged( + (PrivilegedAction) + () -> { try { - final Field field = Buffer.class.getDeclaredField("address"); - // Use Unsafe to read value of the address field. This way it will - // not fail on JDK9+ which - // will forbid changing the access level via reflection. - final long offset = finalUnsafe.objectFieldOffset(field); - final long address = finalUnsafe.getLong(direct, offset); - - // if direct really is a direct buffer, address will be non-zero - if (address == 0) { - return null; - } - return field; + final Field field = Buffer.class.getDeclaredField("address"); + // Use Unsafe to read value of the address field. This way it will + // not fail on JDK9+ which + // will forbid changing the access level via reflection. + final long offset = finalUnsafe.objectFieldOffset(field); + final long address = finalUnsafe.getLong(direct, offset); + + // if direct really is a direct buffer, address will be non-zero + if (address == 0) { + return null; + } + return field; } catch (NoSuchFieldException | SecurityException e) { - return e; + return e; } - }); - - if (maybeAddressField instanceof Field) { - addressField = (Field) maybeAddressField; - } else { - unsafeUnavailabilityCause = (Throwable) maybeAddressField; - - // If we cannot access the address of a direct buffer, there's no point of - // using unsafe. - // Let's just pretend unsafe is unavailable for overall simplicity. - unsafe = null; - } - } - - if (unsafe != null) { - // There are assumptions made where ever BYTE_ARRAY_BASE_OFFSET is used (equals, - // hashCodeAscii, and - // primitive accessors) that arrayIndexScale == 1, and results are undefined if - // this is not the case. - long byteArrayIndexScale = unsafe.arrayIndexScale(byte[].class); - if (byteArrayIndexScale != 1) { - unsafeUnavailabilityCause = new UnsupportedOperationException("Unexpected unsafe.arrayIndexScale"); - unsafe = null; - } - } - UNSAFE_UNAVAILABILITY_CAUSE = unsafeUnavailabilityCause; - UNSAFE = unsafe; + }); + + if (maybeAddressField instanceof Field) { + addressField = (Field) maybeAddressField; + } else { + unsafeUnavailabilityCause = (Throwable) maybeAddressField; + + // If we cannot access the address of a direct buffer, there's no point of + // using unsafe. + // Let's just pretend unsafe is unavailable for overall simplicity. + unsafe = null; + } + } - if (unsafe == null) { - ADDRESS_FIELD_OFFSET = -1; - BYTE_ARRAY_BASE_OFFSET = -1; - DIRECT_BUFFER_CONSTRUCTOR = null; - } else { - Constructor directBufferConstructor; - long address = -1; - try { - final Object maybeDirectBufferConstructor = AccessController - .doPrivileged((PrivilegedAction) () -> { - try { - final Constructor constructor = direct.getClass() - .getDeclaredConstructor(long.class, int.class); - Throwable cause = ReflectionUtil.trySetAccessible(constructor, false); - if (cause != null) { - return cause; - } - return constructor; - } catch (NoSuchMethodException | SecurityException e) { - return e; + if (unsafe != null) { + // There are assumptions made where ever BYTE_ARRAY_BASE_OFFSET is used (equals, + // hashCodeAscii, and + // primitive accessors) that arrayIndexScale == 1, and results are undefined if + // this is not the case. + long byteArrayIndexScale = unsafe.arrayIndexScale(byte[].class); + if (byteArrayIndexScale != 1) { + unsafeUnavailabilityCause = + new UnsupportedOperationException("Unexpected unsafe.arrayIndexScale"); + unsafe = null; + } + } + UNSAFE_UNAVAILABILITY_CAUSE = unsafeUnavailabilityCause; + UNSAFE = unsafe; + + if (unsafe == null) { + ADDRESS_FIELD_OFFSET = -1; + BYTE_ARRAY_BASE_OFFSET = -1; + DIRECT_BUFFER_CONSTRUCTOR = null; + } else { + Constructor directBufferConstructor; + long address = -1; + try { + final Object maybeDirectBufferConstructor = + AccessController.doPrivileged( + (PrivilegedAction) + () -> { + try { + final Constructor constructor = + direct.getClass().getDeclaredConstructor(long.class, int.class); + Throwable cause = ReflectionUtil.trySetAccessible(constructor, false); + if (cause != null) { + return cause; } + return constructor; + } catch (NoSuchMethodException | SecurityException e) { + return e; + } }); - if (maybeDirectBufferConstructor instanceof Constructor) { - address = UNSAFE.allocateMemory(1); - // try to use the constructor now - try { - ((Constructor) maybeDirectBufferConstructor).newInstance(address, 1); - directBufferConstructor = (Constructor) maybeDirectBufferConstructor; - } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { - directBufferConstructor = null; - } - } else { - directBufferConstructor = null; - } - } finally { - if (address != -1) { - UNSAFE.freeMemory(address); - } - } - DIRECT_BUFFER_CONSTRUCTOR = directBufferConstructor; - ADDRESS_FIELD_OFFSET = objectFieldOffset(addressField); - BYTE_ARRAY_BASE_OFFSET = UNSAFE.arrayBaseOffset(byte[].class); + if (maybeDirectBufferConstructor instanceof Constructor) { + address = UNSAFE.allocateMemory(1); + // try to use the constructor now + try { + ((Constructor) maybeDirectBufferConstructor).newInstance(address, 1); + directBufferConstructor = (Constructor) maybeDirectBufferConstructor; + } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { + directBufferConstructor = null; + } + } else { + directBufferConstructor = null; } - } - - private PlatformDependent() { - } - - static boolean hasUnsafe() { - return UNSAFE != null; - } - - static Throwable getUnsafeUnavailabilityCause() { - return UNSAFE_UNAVAILABILITY_CAUSE; - } - - public static void throwException(Throwable cause) { - // JVM has been observed to crash when passing a null argument. See https://github - // .com/netty/netty/issues/4131. - UNSAFE.throwException(requireNonNull(cause, "cause")); - } - - static boolean hasDirectBufferNoCleanerConstructor() { - return DIRECT_BUFFER_CONSTRUCTOR != null; - } - - static ByteBuffer allocateDirectNoCleaner(int capacity) { - // Calling malloc with capacity of 0 may return a null ptr or a memory address that can - // be used. - // Just use 1 to make it safe to use in all cases: - // See: http://pubs.opengroup.org/onlinepubs/009695399/functions/malloc.html - long addr = UNSAFE.allocateMemory(Math.max(1, capacity)); - return newDirectBuffer(addr, capacity); - } - - static ByteBuffer newDirectBuffer(long address, int capacity) { - Preconditions.checkArgument(capacity > 0, "capacity must > 0"); - - try { - return (ByteBuffer) DIRECT_BUFFER_CONSTRUCTOR.newInstance(address, capacity); - } catch (Throwable cause) { - // Not expected to ever throw! - if (cause instanceof Error) { - throw (Error) cause; - } - throw new Error(cause); + } finally { + if (address != -1) { + UNSAFE.freeMemory(address); } + } + DIRECT_BUFFER_CONSTRUCTOR = directBufferConstructor; + ADDRESS_FIELD_OFFSET = objectFieldOffset(addressField); + BYTE_ARRAY_BASE_OFFSET = UNSAFE.arrayBaseOffset(byte[].class); } - - static long directBufferAddress(ByteBuffer buffer) { - return getLong(buffer, ADDRESS_FIELD_OFFSET); - } - - static long byteArrayBaseOffset() { - return BYTE_ARRAY_BASE_OFFSET; - } - - static Object getObject(Object object, long fieldOffset) { - return UNSAFE.getObject(object, fieldOffset); - } - - private static long getLong(Object object, long fieldOffset) { - return UNSAFE.getLong(object, fieldOffset); - } - - static long objectFieldOffset(Field field) { - return UNSAFE.objectFieldOffset(field); - } - - static byte getByte(long address) { - return UNSAFE.getByte(address); + } + + private PlatformDependent() {} + + static boolean hasUnsafe() { + return UNSAFE != null; + } + + static Throwable getUnsafeUnavailabilityCause() { + return UNSAFE_UNAVAILABILITY_CAUSE; + } + + public static void throwException(Throwable cause) { + // JVM has been observed to crash when passing a null argument. See https://github + // .com/netty/netty/issues/4131. + UNSAFE.throwException(requireNonNull(cause, "cause")); + } + + static boolean hasDirectBufferNoCleanerConstructor() { + return DIRECT_BUFFER_CONSTRUCTOR != null; + } + + static ByteBuffer allocateDirectNoCleaner(int capacity) { + // Calling malloc with capacity of 0 may return a null ptr or a memory address that can + // be used. + // Just use 1 to make it safe to use in all cases: + // See: http://pubs.opengroup.org/onlinepubs/009695399/functions/malloc.html + long addr = UNSAFE.allocateMemory(Math.max(1, capacity)); + return newDirectBuffer(addr, capacity); + } + + static ByteBuffer newDirectBuffer(long address, int capacity) { + Preconditions.checkArgument(capacity > 0, "capacity must > 0"); + + try { + return (ByteBuffer) DIRECT_BUFFER_CONSTRUCTOR.newInstance(address, capacity); + } catch (Throwable cause) { + // Not expected to ever throw! + if (cause instanceof Error) { + throw (Error) cause; + } + throw new Error(cause); } - - static short getShort(long address) { - return UNSAFE.getShort(address); + } + + static long directBufferAddress(ByteBuffer buffer) { + return getLong(buffer, ADDRESS_FIELD_OFFSET); + } + + static long byteArrayBaseOffset() { + return BYTE_ARRAY_BASE_OFFSET; + } + + static Object getObject(Object object, long fieldOffset) { + return UNSAFE.getObject(object, fieldOffset); + } + + private static long getLong(Object object, long fieldOffset) { + return UNSAFE.getLong(object, fieldOffset); + } + + static long objectFieldOffset(Field field) { + return UNSAFE.objectFieldOffset(field); + } + + static byte getByte(long address) { + return UNSAFE.getByte(address); + } + + static short getShort(long address) { + return UNSAFE.getShort(address); + } + + static int getInt(long address) { + return UNSAFE.getInt(address); + } + + static long getLong(long address) { + return UNSAFE.getLong(address); + } + + static void putByte(long address, byte value) { + UNSAFE.putByte(address, value); + } + + static void putShort(long address, short value) { + UNSAFE.putShort(address, value); + } + + static void putInt(long address, int value) { + UNSAFE.putInt(address, value); + } + + static void putLong(long address, long value) { + UNSAFE.putLong(address, value); + } + + static void copyMemory(long srcAddr, long dstAddr, long length) { + // Manual safe-point polling is only needed prior Java9: + // See https://bugs.openjdk.java.net/browse/JDK-8149596 + if (javaVersion() <= 8) { + copyMemoryWithSafePointPolling(srcAddr, dstAddr, length); + } else { + UNSAFE.copyMemory(srcAddr, dstAddr, length); } - - static int getInt(long address) { - return UNSAFE.getInt(address); + } + + private static void copyMemoryWithSafePointPolling(long srcAddr, long dstAddr, long length) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + UNSAFE.copyMemory(srcAddr, dstAddr, size); + length -= size; + srcAddr += size; + dstAddr += size; } - - static long getLong(long address) { - return UNSAFE.getLong(address); + } + + static void copyMemory(Object src, long srcOffset, Object dst, long dstOffset, long length) { + UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, length); + } + + private static void copyMemoryWithSafePointPolling( + Object src, long srcOffset, Object dst, long dstOffset, long length) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; } + } - static void putByte(long address, byte value) { - UNSAFE.putByte(address, value); - } + static void setMemory(long address, long bytes, byte value) { + UNSAFE.setMemory(address, bytes, value); + } - static void putShort(long address, short value) { - UNSAFE.putShort(address, value); + static ClassLoader getClassLoader(final Class clazz) { + if (System.getSecurityManager() == null) { + return clazz.getClassLoader(); + } else { + return AccessController.doPrivileged((PrivilegedAction) clazz::getClassLoader); } - - static void putInt(long address, int value) { - UNSAFE.putInt(address, value); + } + + static ClassLoader getSystemClassLoader() { + if (System.getSecurityManager() == null) { + return ClassLoader.getSystemClassLoader(); + } else { + return AccessController.doPrivileged( + (PrivilegedAction) ClassLoader::getSystemClassLoader); } + } - static void putLong(long address, long value) { - UNSAFE.putLong(address, value); - } - - static void copyMemory(long srcAddr, long dstAddr, long length) { - // Manual safe-point polling is only needed prior Java9: - // See https://bugs.openjdk.java.net/browse/JDK-8149596 - if (javaVersion() <= 8) { - copyMemoryWithSafePointPolling(srcAddr, dstAddr, length); - } else { - UNSAFE.copyMemory(srcAddr, dstAddr, length); - } - } + static void freeMemory(long address) { + UNSAFE.freeMemory(address); + } - private static void copyMemoryWithSafePointPolling(long srcAddr, long dstAddr, long length) { - while (length > 0) { - long size = Math.min(length, UNSAFE_COPY_THRESHOLD); - UNSAFE.copyMemory(srcAddr, dstAddr, size); - length -= size; - srcAddr += size; - dstAddr += size; - } - } - - static void copyMemory(Object src, long srcOffset, Object dst, long dstOffset, long length) { - UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, length); - } - - private static void copyMemoryWithSafePointPolling(Object src, long srcOffset, Object dst, - long dstOffset, long length) { - while (length > 0) { - long size = Math.min(length, UNSAFE_COPY_THRESHOLD); - UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); - length -= size; - srcOffset += size; - dstOffset += size; - } - } - - static void setMemory(long address, long bytes, byte value) { - UNSAFE.setMemory(address, bytes, value); - } - - static ClassLoader getClassLoader(final Class clazz) { - if (System.getSecurityManager() == null) { - return clazz.getClassLoader(); - } else { - return AccessController - .doPrivileged((PrivilegedAction) clazz::getClassLoader); - } - } - - static ClassLoader getSystemClassLoader() { - if (System.getSecurityManager() == null) { - return ClassLoader.getSystemClassLoader(); - } else { - return AccessController - .doPrivileged((PrivilegedAction) ClassLoader::getSystemClassLoader); - } - } - - static void freeMemory(long address) { - UNSAFE.freeMemory(address); - } - - static int javaVersion() { - return ReflectionUtil.JAVA_VERSION; - } + static int javaVersion() { + return ReflectionUtil.JAVA_VERSION; + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/ThreadLocalCache.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/ThreadLocalCache.java index 018ede527..05e439444 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/ThreadLocalCache.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/ThreadLocalCache.java @@ -19,228 +19,238 @@ package org.apache.geaflow.memory; -import com.google.common.collect.Lists; -import com.google.common.collect.Sets; import java.util.List; import java.util.Set; + import org.apache.geaflow.common.utils.MemoryUtils; import org.apache.geaflow.memory.exception.GeaflowOutOfMemoryException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public final class ThreadLocalCache { - private static final Logger LOGGER = LoggerFactory.getLogger(ThreadLocalCache.class); - - // local pool which binds to the ThreadLocalCache is the first priority to allocate memory. - private AbstractMemoryPool localPool; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; - // temp pool which route to a suitable pool when local pool is full. - private AbstractMemoryPool tempPool; +public final class ThreadLocalCache { + private static final Logger LOGGER = LoggerFactory.getLogger(ThreadLocalCache.class); - // try to allocate memory form adaptive pool when all pools are full. - private AbstractMemoryPool adaptivePool; + // local pool which binds to the ThreadLocalCache is the first priority to allocate memory. + private AbstractMemoryPool localPool; - private final AbstractMemoryPool[] pools; + // temp pool which route to a suitable pool when local pool is full. + private AbstractMemoryPool tempPool; - private MemoryGroup memoryGroup; + // try to allocate memory form adaptive pool when all pools are full. + private AbstractMemoryPool adaptivePool; - public ThreadLocalCache(AbstractMemoryPool[] pools, AbstractMemoryPool localPool, - AbstractMemoryPool adaptivePool, MemoryGroup memoryGroup) { + private final AbstractMemoryPool[] pools; - this.pools = pools; + private MemoryGroup memoryGroup; - this.localPool = localPool; - this.adaptivePool = adaptivePool; + public ThreadLocalCache( + AbstractMemoryPool[] pools, + AbstractMemoryPool localPool, + AbstractMemoryPool adaptivePool, + MemoryGroup memoryGroup) { - this.memoryGroup = memoryGroup; - this.memoryGroup.increaseThreads(); - } + this.pools = pools; - private int suitablePoolIndex(int size, Set visited) { - if (pools == null || pools.length == 0) { - return -1; - } + this.localPool = localPool; + this.adaptivePool = adaptivePool; - int index = -1; - AbstractMemoryPool minPool = null; - for (int i = 0; i < pools.length; i++) { - if (visited.contains(i)) { - continue; - } - AbstractMemoryPool pool = pools[i]; - if (pool.freeBytes() < size && !pool.canExpandCapacity()) { - visited.add(i); - continue; - } - if (minPool == null || pool.groupThreads(memoryGroup).intValue() < minPool.groupThreads(memoryGroup).intValue()) { - minPool = pool; - index = i; - } - } + this.memoryGroup = memoryGroup; + this.memoryGroup.increaseThreads(); + } - return index; + private int suitablePoolIndex(int size, Set visited) { + if (pools == null || pools.length == 0) { + return -1; } - ByteBuf requireBuf(int size) { - - if (this.localPool.getMemoryMode() == MemoryMode.ON_HEAP) { - return this.localPool.oneAllocate(size, memoryGroup); - } - ByteBuf byteBuf = null; - if (canAllocate(this.localPool, size)) { - byteBuf = this.localPool.oneAllocate(size, memoryGroup); - } - - if (oneAllocateFailed(byteBuf) && this.tempPool != null && canAllocate(this.tempPool, size)) { - byteBuf = this.tempPool.oneAllocate(size, memoryGroup); - } - - Set visited = Sets.newHashSet(); - while (byteBuf == null || byteBuf.allocateFailed()) { - - int i = suitablePoolIndex(size, visited); - if (i == -1) { - break; - } - byteBuf = pools[i].oneAllocate(size, memoryGroup); - visited.add(i); - - if (!byteBuf.allocateFailed()) { - updateTempPool(pools[i]); - break; - } + int index = -1; + AbstractMemoryPool minPool = null; + for (int i = 0; i < pools.length; i++) { + if (visited.contains(i)) { + continue; + } + AbstractMemoryPool pool = pools[i]; + if (pool.freeBytes() < size && !pool.canExpandCapacity()) { + visited.add(i); + continue; + } + if (minPool == null + || pool.groupThreads(memoryGroup).intValue() + < minPool.groupThreads(memoryGroup).intValue()) { + minPool = pool; + index = i; + } + } - if (visited.size() == pools.length) { - updateTempPool(null); - break; - } - } + return index; + } - byteBuf = tryOneAllocate(byteBuf, size); - if (oneAllocateFailed(byteBuf)) { - reportOutOfMemoryException(); - throw new GeaflowOutOfMemoryException("out of memory"); - } + ByteBuf requireBuf(int size) { - return byteBuf; + if (this.localPool.getMemoryMode() == MemoryMode.ON_HEAP) { + return this.localPool.oneAllocate(size, memoryGroup); } - - private boolean oneAllocateFailed(ByteBuf byteBuf) { - return (byteBuf == null || byteBuf.allocateFailed()); + ByteBuf byteBuf = null; + if (canAllocate(this.localPool, size)) { + byteBuf = this.localPool.oneAllocate(size, memoryGroup); } - private ByteBuf tryOneAllocate(ByteBuf buf, int size) { - if (adaptivePool != null && oneAllocateFailed(buf)) { - buf = adaptivePool.oneAllocate(size, memoryGroup); - } - return buf; + if (oneAllocateFailed(byteBuf) && this.tempPool != null && canAllocate(this.tempPool, size)) { + byteBuf = this.tempPool.oneAllocate(size, memoryGroup); } - private void updateTempPool(AbstractMemoryPool newPool) { - if (this.tempPool != null) { - this.tempPool.groupThreads(memoryGroup).decrementAndGet(); - } - - this.tempPool = newPool; - - if (this.tempPool != null) { - this.tempPool.groupThreads(memoryGroup).incrementAndGet(); - } + Set visited = Sets.newHashSet(); + while (byteBuf == null || byteBuf.allocateFailed()) { + + int i = suitablePoolIndex(size, visited); + if (i == -1) { + break; + } + byteBuf = pools[i].oneAllocate(size, memoryGroup); + visited.add(i); + + if (!byteBuf.allocateFailed()) { + updateTempPool(pools[i]); + break; + } + + if (visited.size() == pools.length) { + updateTempPool(null); + break; + } } - List requireBufs(int size) { + byteBuf = tryOneAllocate(byteBuf, size); + if (oneAllocateFailed(byteBuf)) { + reportOutOfMemoryException(); + throw new GeaflowOutOfMemoryException("out of memory"); + } - if (this.localPool.getMemoryMode() == MemoryMode.ON_HEAP) { - return this.localPool.allocate(size, memoryGroup); - } + return byteBuf; + } - List byteBufs = Lists.newArrayList(); - if (canAllocate(this.localPool, size)) { - byteBufs.addAll(this.localPool.allocate(size, memoryGroup)); - } + private boolean oneAllocateFailed(ByteBuf byteBuf) { + return (byteBuf == null || byteBuf.allocateFailed()); + } - int remain = allocateFailedSize(byteBufs, size); + private ByteBuf tryOneAllocate(ByteBuf buf, int size) { + if (adaptivePool != null && oneAllocateFailed(buf)) { + buf = adaptivePool.oneAllocate(size, memoryGroup); + } + return buf; + } - if (remain > 0 && this.tempPool != null && canAllocate(this.tempPool, remain)) { - byteBufs.addAll(this.tempPool.allocate(remain, memoryGroup)); - remain = allocateFailedSize(byteBufs, size); - } + private void updateTempPool(AbstractMemoryPool newPool) { + if (this.tempPool != null) { + this.tempPool.groupThreads(memoryGroup).decrementAndGet(); + } - Set visited = Sets.newHashSet(); - while (remain > 0) { - int i = suitablePoolIndex(remain > this.localPool.chunkSize ? this.localPool.chunkSize : remain, visited); - if (i == -1) { - break; - } - byteBufs.addAll(pools[i].allocate(remain, memoryGroup)); - visited.add(i); + this.tempPool = newPool; - remain = allocateFailedSize(byteBufs, size); + if (this.tempPool != null) { + this.tempPool.groupThreads(memoryGroup).incrementAndGet(); + } + } - if (remain <= 0) { - updateTempPool(pools[i]); - break; - } + List requireBufs(int size) { - if (visited.size() == pools.length) { - updateTempPool(null); - break; - } - } + if (this.localPool.getMemoryMode() == MemoryMode.ON_HEAP) { + return this.localPool.allocate(size, memoryGroup); + } - // 如果部分失败,则尝试从adaptivePool继续申请剩余部分内存 - if (remain > 0 && adaptivePool != null) { - byteBufs.addAll(adaptivePool.allocate(remain, memoryGroup)); - remain = allocateFailedSize(byteBufs, size); - } + List byteBufs = Lists.newArrayList(); + if (canAllocate(this.localPool, size)) { + byteBufs.addAll(this.localPool.allocate(size, memoryGroup)); + } - if (remain > 0) { - reportOutOfMemoryException(); - throw new GeaflowOutOfMemoryException("out of memory"); - } + int remain = allocateFailedSize(byteBufs, size); - return byteBufs; + if (remain > 0 && this.tempPool != null && canAllocate(this.tempPool, remain)) { + byteBufs.addAll(this.tempPool.allocate(remain, memoryGroup)); + remain = allocateFailedSize(byteBufs, size); } - private boolean canAllocate(AbstractMemoryPool pool, int size) { - return pool.freeBytes() >= size || pool.canExpandCapacity(); + Set visited = Sets.newHashSet(); + while (remain > 0) { + int i = + suitablePoolIndex( + remain > this.localPool.chunkSize ? this.localPool.chunkSize : remain, visited); + if (i == -1) { + break; + } + byteBufs.addAll(pools[i].allocate(remain, memoryGroup)); + visited.add(i); + + remain = allocateFailedSize(byteBufs, size); + + if (remain <= 0) { + updateTempPool(pools[i]); + break; + } + + if (visited.size() == pools.length) { + updateTempPool(null); + break; + } } - private int allocateFailedSize(List byteBufs, int totalSize) { - int size = 0; - for (ByteBuf buf : byteBufs) { - size += buf.getLength(); - } - return totalSize - size; + // 如果部分失败,则尝试从adaptivePool继续申请剩余部分内存 + if (remain > 0 && adaptivePool != null) { + byteBufs.addAll(adaptivePool.allocate(remain, memoryGroup)); + remain = allocateFailedSize(byteBufs, size); } - public AbstractMemoryPool getLocalPool() { - return this.localPool; + if (remain > 0) { + reportOutOfMemoryException(); + throw new GeaflowOutOfMemoryException("out of memory"); } - /** - * Should be called if the Thread that uses this cache is about to exist to - * release resources out of the cache. - */ - void free() { - if (localPool != null) { - localPool.groupThreads(memoryGroup).decrementAndGet(); - } - memoryGroup.decreaseThreads(); - } + return byteBufs; + } - protected void reportOutOfMemoryException() { - try { - MemoryManager memoryManager = MemoryManager.getInstance(); - String msg = String.format("Direct memory used %s, max direct memory %s. Please " - + "reduce -Xmx or set -XX:MaxDirectMemorySize to enlarge direct memory.", - MemoryUtils.humanReadableByteCount(memoryManager.usedMemory()), - MemoryUtils.humanReadableByteCount(memoryManager.maxMemory())); - LOGGER.warn("{}", msg); + private boolean canAllocate(AbstractMemoryPool pool, int size) { + return pool.freeBytes() >= size || pool.canExpandCapacity(); + } - } catch (Throwable t) { - LOGGER.error("report exception failed", t); - } + private int allocateFailedSize(List byteBufs, int totalSize) { + int size = 0; + for (ByteBuf buf : byteBufs) { + size += buf.getLength(); } - + return totalSize - size; + } + + public AbstractMemoryPool getLocalPool() { + return this.localPool; + } + + /** + * Should be called if the Thread that uses this cache is about to exist to release resources out + * of the cache. + */ + void free() { + if (localPool != null) { + localPool.groupThreads(memoryGroup).decrementAndGet(); + } + memoryGroup.decreaseThreads(); + } + + protected void reportOutOfMemoryException() { + try { + MemoryManager memoryManager = MemoryManager.getInstance(); + String msg = + String.format( + "Direct memory used %s, max direct memory %s. Please " + + "reduce -Xmx or set -XX:MaxDirectMemorySize to enlarge direct memory.", + MemoryUtils.humanReadableByteCount(memoryManager.usedMemory()), + MemoryUtils.humanReadableByteCount(memoryManager.maxMemory())); + LOGGER.warn("{}", msg); + + } catch (Throwable t) { + LOGGER.error("report exception failed", t); + } + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/array/ByteArray.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/array/ByteArray.java index 08d189f3e..94da30fbd 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/array/ByteArray.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/array/ByteArray.java @@ -21,28 +21,18 @@ import java.nio.ByteBuffer; -/** - * This interface describes a unified view for bytes management. - */ +/** This interface describes a unified view for bytes management. */ public interface ByteArray { - /** - * get array size. - */ - int size(); + /** get array size. */ + int size(); - /** - * get byte array. - */ - byte[] array(); + /** get byte array. */ + byte[] array(); - /** - * convert to ByteBuffer. - */ - ByteBuffer toByteBuffer(); + /** convert to ByteBuffer. */ + ByteBuffer toByteBuffer(); - /** - * release byte array. - */ - boolean release(); + /** release byte array. */ + boolean release(); } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/array/ByteArrayBuilder.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/array/ByteArrayBuilder.java index 9853d2c37..46b087f55 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/array/ByteArrayBuilder.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/array/ByteArrayBuilder.java @@ -23,15 +23,15 @@ public class ByteArrayBuilder { - public static ByteArray of(MemoryViewReference viewRef, int offset, int len) { - return new ByteArrayWithMemoryView(viewRef, offset, len); - } + public static ByteArray of(MemoryViewReference viewRef, int offset, int len) { + return new ByteArrayWithMemoryView(viewRef, offset, len); + } - public static ByteArray of(byte[] bytes) { - return new DefaultByteArray(bytes); - } + public static ByteArray of(byte[] bytes) { + return new DefaultByteArray(bytes); + } - public static ByteArray of(byte[] bytes, int offset, int len) { - return new DefaultByteArray(bytes, offset, len); - } + public static ByteArray of(byte[] bytes, int offset, int len) { + return new DefaultByteArray(bytes, offset, len); + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/array/ByteArrayRefUtil.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/array/ByteArrayRefUtil.java index a787cf77a..0e6727b37 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/array/ByteArrayRefUtil.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/array/ByteArrayRefUtil.java @@ -22,19 +22,17 @@ import java.util.HashSet; import java.util.Set; -/** - * This Util is a benchmark test util for memory leak detection. - */ +/** This Util is a benchmark test util for memory leak detection. */ public class ByteArrayRefUtil { - public static Set array = new HashSet<>(); - public static volatile boolean needCheck = false; + public static Set array = new HashSet<>(); + public static volatile boolean needCheck = false; - public static synchronized void add(ByteArrayWithMemoryView view) { - array.add(view); - } + public static synchronized void add(ByteArrayWithMemoryView view) { + array.add(view); + } - public static synchronized void remove(ByteArrayWithMemoryView view) { - array.remove(view); - } + public static synchronized void remove(ByteArrayWithMemoryView view) { + array.remove(view); + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/array/ByteArrayWithMemoryView.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/array/ByteArrayWithMemoryView.java index 9104208ed..496968abf 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/array/ByteArrayWithMemoryView.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/array/ByteArrayWithMemoryView.java @@ -20,63 +20,64 @@ package org.apache.geaflow.memory.array; import java.nio.ByteBuffer; + import org.apache.geaflow.memory.MemoryViewReader; import org.apache.geaflow.memory.MemoryViewReference; public class ByteArrayWithMemoryView implements ByteArray { - private int offset; - private int len; - private MemoryViewReference viewRef; + private int offset; + private int len; + private MemoryViewReference viewRef; - public ByteArrayWithMemoryView(MemoryViewReference viewRef, int offset, int len) { - this.offset = offset; - this.len = len; - this.viewRef = viewRef; - this.viewRef.incRef(); - if (ByteArrayRefUtil.needCheck) { - ByteArrayRefUtil.add(this); - } + public ByteArrayWithMemoryView(MemoryViewReference viewRef, int offset, int len) { + this.offset = offset; + this.len = len; + this.viewRef = viewRef; + this.viewRef.incRef(); + if (ByteArrayRefUtil.needCheck) { + ByteArrayRefUtil.add(this); } + } - @Override - public int size() { - return this.len; - } + @Override + public int size() { + return this.len; + } - @Override - public byte[] array() { - byte[] bytes = new byte[len]; - getReader().read(bytes, 0, len); - return bytes; - } + @Override + public byte[] array() { + byte[] bytes = new byte[len]; + getReader().read(bytes, 0, len); + return bytes; + } - @Override - public ByteBuffer toByteBuffer() { - return ByteBuffer.wrap(array(), 0, len); - } + @Override + public ByteBuffer toByteBuffer() { + return ByteBuffer.wrap(array(), 0, len); + } - private MemoryViewReader getReader() { - MemoryViewReader reader = viewRef.getMemoryView().getReader(); - reader.skip(offset); - return reader; - } + private MemoryViewReader getReader() { + MemoryViewReader reader = viewRef.getMemoryView().getReader(); + reader.skip(offset); + return reader; + } - @Override - public boolean release() { - if (len < 0) { - return true; - } + @Override + public boolean release() { + if (len < 0) { + return true; + } - synchronized (viewRef) { - if (len > 0) { - if (ByteArrayRefUtil.needCheck) { - ByteArrayRefUtil.remove(this); - } - len = -1; - return viewRef.decRef(); - } + synchronized (viewRef) { + if (len > 0) { + if (ByteArrayRefUtil.needCheck) { + ByteArrayRefUtil.remove(this); } - return false; + len = -1; + return viewRef.decRef(); + } } + return false; + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/array/DefaultByteArray.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/array/DefaultByteArray.java index a65d540de..c53f948da 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/array/DefaultByteArray.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/array/DefaultByteArray.java @@ -23,38 +23,38 @@ public class DefaultByteArray implements ByteArray { - private byte[] data; - private final int offset; - private final int length; - - public DefaultByteArray(byte[] data) { - this(data, 0, data.length); - } - - public DefaultByteArray(byte[] data, int offset, int length) { - this.data = data; - this.offset = offset; - this.length = length; - } - - @Override - public int size() { - return length; - } - - @Override - public byte[] array() { - return data; - } - - @Override - public ByteBuffer toByteBuffer() { - return ByteBuffer.wrap(data, offset, length); - } - - @Override - public boolean release() { - data = null; - return true; - } + private byte[] data; + private final int offset; + private final int length; + + public DefaultByteArray(byte[] data) { + this(data, 0, data.length); + } + + public DefaultByteArray(byte[] data, int offset, int length) { + this.data = data; + this.offset = offset; + this.length = length; + } + + @Override + public int size() { + return length; + } + + @Override + public byte[] array() { + return data; + } + + @Override + public ByteBuffer toByteBuffer() { + return ByteBuffer.wrap(data, offset, length); + } + + @Override + public boolean release() { + data = null; + return true; + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/channel/ByteArrayInputStream.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/channel/ByteArrayInputStream.java index f06366ffa..59755d499 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/channel/ByteArrayInputStream.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/channel/ByteArrayInputStream.java @@ -21,59 +21,60 @@ import java.io.IOException; import java.io.InputStream; + import org.apache.geaflow.memory.MemoryView; import org.apache.geaflow.memory.MemoryViewReader; public class ByteArrayInputStream extends InputStream { - private final MemoryViewReader reader; + private final MemoryViewReader reader; - public ByteArrayInputStream(MemoryView view) { - reader = view.getReader(); - } + public ByteArrayInputStream(MemoryView view) { + reader = view.getReader(); + } - @Override - public int read() throws IOException { - return reader.hasNext() ? reader.read() & 0xff : -1; - } + @Override + public int read() throws IOException { + return reader.hasNext() ? reader.read() & 0xff : -1; + } - @Override - public int read(byte[] b) throws IOException { - return read(b, 0, b.length); - } + @Override + public int read(byte[] b) throws IOException { + return read(b, 0, b.length); + } - @Override - public int read(byte[] b, int off, int len) throws IOException { - return reader.read(b, off, len); - } + @Override + public int read(byte[] b, int off, int len) throws IOException { + return reader.read(b, off, len); + } - @Override - public long skip(long n) throws IOException { - return reader.skip(n); - } + @Override + public long skip(long n) throws IOException { + return reader.skip(n); + } - @Override - public int available() throws IOException { - return reader.available(); - } + @Override + public int available() throws IOException { + return reader.available(); + } - @Override - public void close() throws IOException { - super.close(); - } + @Override + public void close() throws IOException { + super.close(); + } - @Override - public synchronized void mark(int readlimit) { - super.mark(readlimit); - } + @Override + public synchronized void mark(int readlimit) { + super.mark(readlimit); + } - @Override - public synchronized void reset() throws IOException { - super.reset(); - } + @Override + public synchronized void reset() throws IOException { + super.reset(); + } - @Override - public boolean markSupported() { - return super.markSupported(); - } + @Override + public boolean markSupported() { + return super.markSupported(); + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/channel/ByteArrayOutputStream.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/channel/ByteArrayOutputStream.java index a654f9f6b..1ba054afc 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/channel/ByteArrayOutputStream.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/channel/ByteArrayOutputStream.java @@ -21,42 +21,43 @@ import java.io.IOException; import java.io.OutputStream; + import org.apache.geaflow.memory.MemoryView; import org.apache.geaflow.memory.MemoryViewWriter; public class ByteArrayOutputStream extends OutputStream { - private MemoryView view = null; - private MemoryViewWriter writer = null; - - public ByteArrayOutputStream(MemoryView view) { - if (view != null) { - this.view = view; - this.writer = view.getWriter(); - } - } - - @Override - public void write(int b) throws IOException { - writer.write(b); - } - - @Override - public void write(byte[] b, int off, int len) throws IOException { - writer.write(b, off, len); - } - - @Override - public void flush() throws IOException { - super.flush(); - } - - @Override - public void close() throws IOException { - super.close(); - } + private MemoryView view = null; + private MemoryViewWriter writer = null; - public MemoryView getView() { - return view; + public ByteArrayOutputStream(MemoryView view) { + if (view != null) { + this.view = view; + this.writer = view.getWriter(); } + } + + @Override + public void write(int b) throws IOException { + writer.write(b); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + writer.write(b, off, len); + } + + @Override + public void flush() throws IOException { + super.flush(); + } + + @Override + public void close() throws IOException { + super.close(); + } + + public MemoryView getView() { + return view; + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/channel/ByteBufferWritableChannel.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/channel/ByteBufferWritableChannel.java index 38fe62adb..fd3db6166 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/channel/ByteBufferWritableChannel.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/channel/ByteBufferWritableChannel.java @@ -22,48 +22,41 @@ import java.nio.ByteBuffer; import java.nio.channels.WritableByteChannel; -/** - * This class is a writable channel that stores the written data in a byte array in memory. - */ +/** This class is a writable channel that stores the written data in a byte array in memory. */ public class ByteBufferWritableChannel implements WritableByteChannel { - private final byte[] data; - private int offset; - - public ByteBufferWritableChannel(int size) { - this.data = new byte[size]; - } - - public byte[] getData() { - return data; - } - - /** - * Reads from the given buffer into the internal byte array. - */ - @Override - public int write(ByteBuffer src) { - return write(src, 0); - } - - public int write(ByteBuffer src, int startOffset) { - int position = src.position(); - int len = position - startOffset; - src.position(startOffset); - src.get(data, offset, len); - offset += len; - src.position(position); - return position; - } - - @Override - public boolean isOpen() { - return true; - } - - @Override - public void close() { - - } - + private final byte[] data; + private int offset; + + public ByteBufferWritableChannel(int size) { + this.data = new byte[size]; + } + + public byte[] getData() { + return data; + } + + /** Reads from the given buffer into the internal byte array. */ + @Override + public int write(ByteBuffer src) { + return write(src, 0); + } + + public int write(ByteBuffer src, int startOffset) { + int position = src.position(); + int len = position - startOffset; + src.position(startOffset); + src.get(data, offset, len); + offset += len; + src.position(position); + return position; + } + + @Override + public boolean isOpen() { + return true; + } + + @Override + public void close() {} } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/cleaner/Cleaner.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/cleaner/Cleaner.java index 452849a14..e83475a20 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/cleaner/Cleaner.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/cleaner/Cleaner.java @@ -19,13 +19,11 @@ import java.nio.ByteBuffer; /** - * This class is an adaptation of Netty's io.netty.util.internal.Cleaner. - * Allows to free direct {@link ByteBuffer}s. + * This class is an adaptation of Netty's io.netty.util.internal.Cleaner. Allows to free direct + * {@link ByteBuffer}s. */ public interface Cleaner { - /** - * Free a direct {@link ByteBuffer} if possible. - */ - void freeDirectBuffer(ByteBuffer buffer); + /** Free a direct {@link ByteBuffer} if possible. */ + void freeDirectBuffer(ByteBuffer buffer); } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/cleaner/CleanerJava6.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/cleaner/CleanerJava6.java index a740c4eec..84f09b129 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/cleaner/CleanerJava6.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/cleaner/CleanerJava6.java @@ -21,121 +21,127 @@ import java.nio.ByteBuffer; import java.security.AccessController; import java.security.PrivilegedAction; + import org.apache.geaflow.memory.DirectMemory; import org.apache.geaflow.memory.PlatformDependent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** - * This class is an adaptation of Netty's io.netty.util.internal.CleanerJava6. - * Allows to free direct {@link ByteBuffer} by using Cleaner for Java version less than 9. + * This class is an adaptation of Netty's io.netty.util.internal.CleanerJava6. Allows to free direct + * {@link ByteBuffer} by using Cleaner for Java version less than 9. */ public final class CleanerJava6 implements Cleaner { - private static final long CLEANER_FIELD_OFFSET; - private static final Method CLEAN_METHOD; - private static final Field CLEANER_FIELD; + private static final long CLEANER_FIELD_OFFSET; + private static final Method CLEAN_METHOD; + private static final Field CLEANER_FIELD; - private static final Logger logger = LoggerFactory.getLogger(CleanerJava6.class); + private static final Logger logger = LoggerFactory.getLogger(CleanerJava6.class); - static { - long fieldOffset; - Method clean; - Field cleanerField; - final ByteBuffer direct = ByteBuffer.allocateDirect(1); - try { - Object mayBeCleanerField = AccessController - .doPrivileged((PrivilegedAction) () -> { + static { + long fieldOffset; + Method clean; + Field cleanerField; + final ByteBuffer direct = ByteBuffer.allocateDirect(1); + try { + Object mayBeCleanerField = + AccessController.doPrivileged( + (PrivilegedAction) + () -> { try { - Field cleanerField1 = direct.getClass().getDeclaredField("cleaner"); - if (!DirectMemory.hasUnsafe()) { - // We need to make it accessible if we do not use Unsafe as we will - // access it via - // reflection. - cleanerField1.setAccessible(true); - } - return cleanerField1; + Field cleanerField1 = direct.getClass().getDeclaredField("cleaner"); + if (!DirectMemory.hasUnsafe()) { + // We need to make it accessible if we do not use Unsafe as we will + // access it via + // reflection. + cleanerField1.setAccessible(true); + } + return cleanerField1; } catch (Throwable cause) { - return cause; + return cause; } - }); - if (mayBeCleanerField instanceof Throwable) { - throw (Throwable) mayBeCleanerField; - } + }); + if (mayBeCleanerField instanceof Throwable) { + throw (Throwable) mayBeCleanerField; + } - cleanerField = (Field) mayBeCleanerField; + cleanerField = (Field) mayBeCleanerField; - final Object cleaner; + final Object cleaner; - // If we have sun.misc.Unsafe we will use it as its faster then using reflection, - // otherwise let us try reflection as last resort. - if (DirectMemory.hasUnsafe()) { - fieldOffset = DirectMemory.objectFieldOffset(cleanerField); - cleaner = DirectMemory.getObject(direct, fieldOffset); - } else { - fieldOffset = -1; - cleaner = cleanerField.get(direct); - } - clean = cleaner.getClass().getDeclaredMethod("clean"); - clean.invoke(cleaner); - } catch (Throwable t) { - // We don't have ByteBuffer.cleaner(). - fieldOffset = -1; - clean = null; - cleanerField = null; - } - - CLEANER_FIELD = cleanerField; - CLEANER_FIELD_OFFSET = fieldOffset; - CLEAN_METHOD = clean; + // If we have sun.misc.Unsafe we will use it as its faster then using reflection, + // otherwise let us try reflection as last resort. + if (DirectMemory.hasUnsafe()) { + fieldOffset = DirectMemory.objectFieldOffset(cleanerField); + cleaner = DirectMemory.getObject(direct, fieldOffset); + } else { + fieldOffset = -1; + cleaner = cleanerField.get(direct); + } + clean = cleaner.getClass().getDeclaredMethod("clean"); + clean.invoke(cleaner); + } catch (Throwable t) { + // We don't have ByteBuffer.cleaner(). + fieldOffset = -1; + clean = null; + cleanerField = null; } - public static boolean isSupported() { - return CLEANER_FIELD_OFFSET != -1 || CLEANER_FIELD != null; - } + CLEANER_FIELD = cleanerField; + CLEANER_FIELD_OFFSET = fieldOffset; + CLEAN_METHOD = clean; + } - private static void freeDirectBufferPrivileged(final ByteBuffer buffer) { - Throwable cause = AccessController.doPrivileged((PrivilegedAction) () -> { - try { - freeDirectBuffer0(buffer); - return null; - } catch (Throwable cause1) { - return cause1; - } - }); - if (cause != null) { - PlatformDependent.throwException(cause); - } + public static boolean isSupported() { + return CLEANER_FIELD_OFFSET != -1 || CLEANER_FIELD != null; + } + + private static void freeDirectBufferPrivileged(final ByteBuffer buffer) { + Throwable cause = + AccessController.doPrivileged( + (PrivilegedAction) + () -> { + try { + freeDirectBuffer0(buffer); + return null; + } catch (Throwable cause1) { + return cause1; + } + }); + if (cause != null) { + PlatformDependent.throwException(cause); } + } - private static void freeDirectBuffer0(ByteBuffer buffer) throws Exception { - final Object cleaner; - // If CLEANER_FIELD_OFFSET == -1 we need to use reflection to access the cleaner, - // otherwise we can use - // sun.misc.Unsafe. - if (CLEANER_FIELD_OFFSET == -1) { - cleaner = CLEANER_FIELD.get(buffer); - } else { - cleaner = DirectMemory.getObject(buffer, CLEANER_FIELD_OFFSET); - } - if (cleaner != null) { - CLEAN_METHOD.invoke(cleaner); - } + private static void freeDirectBuffer0(ByteBuffer buffer) throws Exception { + final Object cleaner; + // If CLEANER_FIELD_OFFSET == -1 we need to use reflection to access the cleaner, + // otherwise we can use + // sun.misc.Unsafe. + if (CLEANER_FIELD_OFFSET == -1) { + cleaner = CLEANER_FIELD.get(buffer); + } else { + cleaner = DirectMemory.getObject(buffer, CLEANER_FIELD_OFFSET); } + if (cleaner != null) { + CLEAN_METHOD.invoke(cleaner); + } + } - @Override - public void freeDirectBuffer(ByteBuffer buffer) { - if (!buffer.isDirect()) { - return; - } - if (System.getSecurityManager() == null) { - try { - freeDirectBuffer0(buffer); - } catch (Throwable cause) { - PlatformDependent.throwException(cause); - } - } else { - freeDirectBufferPrivileged(buffer); - } + @Override + public void freeDirectBuffer(ByteBuffer buffer) { + if (!buffer.isDirect()) { + return; + } + if (System.getSecurityManager() == null) { + try { + freeDirectBuffer0(buffer); + } catch (Throwable cause) { + PlatformDependent.throwException(cause); + } + } else { + freeDirectBufferPrivileged(buffer); } + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/cleaner/CleanerJava9.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/cleaner/CleanerJava9.java index d1d4ed6ec..836f5462b 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/cleaner/CleanerJava9.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/cleaner/CleanerJava9.java @@ -21,79 +21,89 @@ import java.nio.ByteBuffer; import java.security.AccessController; import java.security.PrivilegedAction; + import org.apache.geaflow.memory.DirectMemory; import org.apache.geaflow.memory.PlatformDependent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** - * This class is an adaptation of Netty's io.netty.util.internal.CleanerJava9. - * Allows to free direct {@link ByteBuffer} by using Cleaner for Java version equal or greater than 9. + * This class is an adaptation of Netty's io.netty.util.internal.CleanerJava9. Allows to free direct + * {@link ByteBuffer} by using Cleaner for Java version equal or greater than 9. */ public final class CleanerJava9 implements Cleaner { - private static final Logger logger = LoggerFactory.getLogger(CleanerJava9.class); + private static final Logger logger = LoggerFactory.getLogger(CleanerJava9.class); - private static final Method INVOKE_CLEANER; + private static final Method INVOKE_CLEANER; - static { - final Method method; - if (DirectMemory.hasUnsafe()) { - final ByteBuffer buffer = ByteBuffer.allocateDirect(1); - Object maybeInvokeMethod = AccessController - .doPrivileged((PrivilegedAction) () -> { + static { + final Method method; + if (DirectMemory.hasUnsafe()) { + final ByteBuffer buffer = ByteBuffer.allocateDirect(1); + Object maybeInvokeMethod = + AccessController.doPrivileged( + (PrivilegedAction) + () -> { try { - // See https://bugs.openjdk.java.net/browse/JDK-8171377 - Method m = DirectMemory.unsafe().getClass() - .getDeclaredMethod("invokeCleaner", ByteBuffer.class); - m.invoke(DirectMemory.unsafe(), buffer); - return m; - } catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) { - return e; + // See https://bugs.openjdk.java.net/browse/JDK-8171377 + Method m = + DirectMemory.unsafe() + .getClass() + .getDeclaredMethod("invokeCleaner", ByteBuffer.class); + m.invoke(DirectMemory.unsafe(), buffer); + return m; + } catch (NoSuchMethodException + | InvocationTargetException + | IllegalAccessException e) { + return e; } - }); + }); - if (maybeInvokeMethod instanceof Throwable) { - method = null; - } else { - method = (Method) maybeInvokeMethod; - } - } else { - method = null; - } - INVOKE_CLEANER = method; + if (maybeInvokeMethod instanceof Throwable) { + method = null; + } else { + method = (Method) maybeInvokeMethod; + } + } else { + method = null; } + INVOKE_CLEANER = method; + } - public static boolean isSupported() { - return INVOKE_CLEANER != null; - } + public static boolean isSupported() { + return INVOKE_CLEANER != null; + } - private static void freeDirectBufferPrivileged(final ByteBuffer buffer) { - Exception error = AccessController.doPrivileged((PrivilegedAction) () -> { - try { - INVOKE_CLEANER.invoke(DirectMemory.unsafe(), buffer); - } catch (InvocationTargetException | IllegalAccessException e) { - return e; - } - return null; - }); - if (error != null) { - PlatformDependent.throwException(error); - } + private static void freeDirectBufferPrivileged(final ByteBuffer buffer) { + Exception error = + AccessController.doPrivileged( + (PrivilegedAction) + () -> { + try { + INVOKE_CLEANER.invoke(DirectMemory.unsafe(), buffer); + } catch (InvocationTargetException | IllegalAccessException e) { + return e; + } + return null; + }); + if (error != null) { + PlatformDependent.throwException(error); } + } - @Override - public void freeDirectBuffer(ByteBuffer buffer) { - // Try to minimize overhead when there is no SecurityManager present. - // See https://bugs.openjdk.java.net/browse/JDK-8191053. - if (System.getSecurityManager() == null) { - try { - INVOKE_CLEANER.invoke(DirectMemory.unsafe(), buffer); - } catch (Throwable cause) { - PlatformDependent.throwException(cause); - } - } else { - freeDirectBufferPrivileged(buffer); - } + @Override + public void freeDirectBuffer(ByteBuffer buffer) { + // Try to minimize overhead when there is no SecurityManager present. + // See https://bugs.openjdk.java.net/browse/JDK-8191053. + if (System.getSecurityManager() == null) { + try { + INVOKE_CLEANER.invoke(DirectMemory.unsafe(), buffer); + } catch (Throwable cause) { + PlatformDependent.throwException(cause); + } + } else { + freeDirectBufferPrivileged(buffer); } + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/compress/Lz4.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/compress/Lz4.java index ca958e958..b6163f214 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/compress/Lz4.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/compress/Lz4.java @@ -19,11 +19,6 @@ package org.apache.geaflow.memory.compress; -import net.jpountz.lz4.LZ4BlockInputStream; -import net.jpountz.lz4.LZ4BlockOutputStream; -import net.jpountz.lz4.LZ4Compressor; -import net.jpountz.lz4.LZ4Factory; -import net.jpountz.lz4.LZ4FastDecompressor; import org.apache.geaflow.memory.MemoryGroupManger; import org.apache.geaflow.memory.MemoryManager; import org.apache.geaflow.memory.MemoryView; @@ -31,46 +26,52 @@ import org.apache.geaflow.memory.channel.ByteArrayInputStream; import org.apache.geaflow.memory.channel.ByteArrayOutputStream; +import net.jpountz.lz4.LZ4BlockInputStream; +import net.jpountz.lz4.LZ4BlockOutputStream; +import net.jpountz.lz4.LZ4Compressor; +import net.jpountz.lz4.LZ4Factory; +import net.jpountz.lz4.LZ4FastDecompressor; + public class Lz4 { - private static final int BLOCK_SIZE = 1024 * 8; - private static final LZ4Compressor COMPRESSOR = LZ4Factory.safeInstance().fastCompressor(); - private static final LZ4FastDecompressor DECOMPRESSOR = LZ4Factory.safeInstance().fastDecompressor(); + private static final int BLOCK_SIZE = 1024 * 8; + private static final LZ4Compressor COMPRESSOR = LZ4Factory.safeInstance().fastCompressor(); + private static final LZ4FastDecompressor DECOMPRESSOR = + LZ4Factory.safeInstance().fastDecompressor(); - public static MemoryView compress(MemoryView view) { - MemoryView v = - MemoryManager.getInstance().requireMemory(view.contentSize() / 2, MemoryGroupManger.STATE); - ByteArrayOutputStream baos = new ByteArrayOutputStream(v); + public static MemoryView compress(MemoryView view) { + MemoryView v = + MemoryManager.getInstance().requireMemory(view.contentSize() / 2, MemoryGroupManger.STATE); + ByteArrayOutputStream baos = new ByteArrayOutputStream(v); - LZ4BlockOutputStream compressedOutput = new LZ4BlockOutputStream( - baos, BLOCK_SIZE, COMPRESSOR); - try { - compressedOutput.write(view.toArray(), 0, view.contentSize()); - compressedOutput.close(); - } catch (Throwable throwable) { - throw new RuntimeException("compress fail", throwable); - } - return baos.getView(); + LZ4BlockOutputStream compressedOutput = new LZ4BlockOutputStream(baos, BLOCK_SIZE, COMPRESSOR); + try { + compressedOutput.write(view.toArray(), 0, view.contentSize()); + compressedOutput.close(); + } catch (Throwable throwable) { + throw new RuntimeException("compress fail", throwable); } + return baos.getView(); + } - public static MemoryView uncompress(MemoryView view) { - return uncompress(view, view.contentSize() * 2); - } + public static MemoryView uncompress(MemoryView view) { + return uncompress(view, view.contentSize() * 2); + } - public static MemoryView uncompress(MemoryView view, int initSize) { - MemoryView v = MemoryManager.getInstance().requireMemory(initSize, MemoryGroupManger.STATE); - MemoryViewWriter writer = v.getWriter(); - int count; - byte[] buffer = new byte[BLOCK_SIZE]; + public static MemoryView uncompress(MemoryView view, int initSize) { + MemoryView v = MemoryManager.getInstance().requireMemory(initSize, MemoryGroupManger.STATE); + MemoryViewWriter writer = v.getWriter(); + int count; + byte[] buffer = new byte[BLOCK_SIZE]; - try (LZ4BlockInputStream lzis = new LZ4BlockInputStream( - new ByteArrayInputStream(view), DECOMPRESSOR)) { - while ((count = lzis.read(buffer)) != -1) { - writer.write(buffer, 0, count); - } - } catch (Throwable throwable) { - throw new RuntimeException("uncompress fail", throwable); - } - return v; + try (LZ4BlockInputStream lzis = + new LZ4BlockInputStream(new ByteArrayInputStream(view), DECOMPRESSOR)) { + while ((count = lzis.read(buffer)) != -1) { + writer.write(buffer, 0, count); + } + } catch (Throwable throwable) { + throw new RuntimeException("uncompress fail", throwable); } + return v; + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/compress/Snappy.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/compress/Snappy.java index 2b7d40009..0d5b0febc 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/compress/Snappy.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/compress/Snappy.java @@ -20,6 +20,7 @@ package org.apache.geaflow.memory.compress; import java.io.IOException; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.memory.MemoryGroupManger; import org.apache.geaflow.memory.MemoryManager; @@ -33,52 +34,50 @@ public class Snappy { - private static final int BUFFER_SIZE = 1024 * 8; + private static final int BUFFER_SIZE = 1024 * 8; - public static MemoryView compress(MemoryView view) throws IOException { - MemoryView v = - MemoryManager.getInstance().requireMemory(view.contentSize() / 2, MemoryGroupManger.STATE); - ByteArrayOutputStream baos = - new ByteArrayOutputStream(v); + public static MemoryView compress(MemoryView view) throws IOException { + MemoryView v = + MemoryManager.getInstance().requireMemory(view.contentSize() / 2, MemoryGroupManger.STATE); + ByteArrayOutputStream baos = new ByteArrayOutputStream(v); - try (SnappyFramedOutputStream sos = new SnappyFramedOutputStream(baos)) { - MemoryViewReader reader = view.getReader(); - byte[] buffer = new byte[BUFFER_SIZE]; - while (true) { - int count = reader.read(buffer); - if (count <= 0) { - break; - } - sos.write(buffer, 0, count); - } - sos.flush(); - return baos.getView(); - } catch (Exception ex) { - throw new GeaflowRuntimeException("uncompress fail", ex); + try (SnappyFramedOutputStream sos = new SnappyFramedOutputStream(baos)) { + MemoryViewReader reader = view.getReader(); + byte[] buffer = new byte[BUFFER_SIZE]; + while (true) { + int count = reader.read(buffer); + if (count <= 0) { + break; } + sos.write(buffer, 0, count); + } + sos.flush(); + return baos.getView(); + } catch (Exception ex) { + throw new GeaflowRuntimeException("uncompress fail", ex); } + } - public static MemoryView uncompress(MemoryView view) { - return uncompress(view, view.contentSize()); - } + public static MemoryView uncompress(MemoryView view) { + return uncompress(view, view.contentSize()); + } - public static MemoryView uncompress(MemoryView view, int initSize) { - byte[] buffer = new byte[BUFFER_SIZE]; - try (SnappyFramedInputStream sis = - new SnappyFramedInputStream(new ByteArrayInputStream(view))) { - MemoryView v = MemoryManager.getInstance().requireMemory(initSize, MemoryGroupManger.STATE); - MemoryViewWriter writer = v.getWriter(); - while (true) { - int count = sis.read(buffer); - if (count <= 0) { - break; - } - writer.write(buffer, 0, count); - } - return v; - } catch (Exception ex) { - throw new GeaflowRuntimeException("uncompress fail", ex); + public static MemoryView uncompress(MemoryView view, int initSize) { + byte[] buffer = new byte[BUFFER_SIZE]; + try (SnappyFramedInputStream sis = + new SnappyFramedInputStream(new ByteArrayInputStream(view))) { + MemoryView v = MemoryManager.getInstance().requireMemory(initSize, MemoryGroupManger.STATE); + MemoryViewWriter writer = v.getWriter(); + while (true) { + int count = sis.read(buffer); + if (count <= 0) { + break; } + writer.write(buffer, 0, count); + } + return v; + } catch (Exception ex) { + throw new GeaflowRuntimeException("uncompress fail", ex); } - + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/config/MemoryConfigKeys.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/config/MemoryConfigKeys.java index 8b8d94cf6..d9a35936b 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/config/MemoryConfigKeys.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/config/MemoryConfigKeys.java @@ -20,61 +20,62 @@ package org.apache.geaflow.memory.config; import java.io.Serializable; + import org.apache.geaflow.common.config.ConfigKey; import org.apache.geaflow.common.config.ConfigKeys; import org.apache.geaflow.memory.DirectMemory; public class MemoryConfigKeys implements Serializable { - public static final long JVM_MAX_DIRECT_MEMORY = DirectMemory.maxDirectMemory0(); + public static final long JVM_MAX_DIRECT_MEMORY = DirectMemory.maxDirectMemory0(); - public static final ConfigKey OFF_HEAP_MEMORY_SIZE_MB = ConfigKeys - .key("geaflow.memory.off.heap.mb") - .defaultValue(0L) - .description("off heap memory mb, default 0"); + public static final ConfigKey OFF_HEAP_MEMORY_SIZE_MB = + ConfigKeys.key("geaflow.memory.off.heap.mb") + .defaultValue(0L) + .description("off heap memory mb, default 0"); - public static final ConfigKey ON_HEAP_MEMORY_SIZE_MB = ConfigKeys - .key("geaflow.memory.on.heap.mb") - .defaultValue(0L) - .description("on heap memory mb, default 0"); + public static final ConfigKey ON_HEAP_MEMORY_SIZE_MB = + ConfigKeys.key("geaflow.memory.on.heap.mb") + .defaultValue(0L) + .description("on heap memory mb, default 0"); - public static final ConfigKey MAX_DIRECT_MEMORY_SIZE = ConfigKeys - .key("geaflow.memory.max.direct.size") - .defaultValue((long) (JVM_MAX_DIRECT_MEMORY * 0.8)) - .description("max direct memory size, default 0.8 * JVM_MAX_DIRECT_MEMORY"); + public static final ConfigKey MAX_DIRECT_MEMORY_SIZE = + ConfigKeys.key("geaflow.memory.max.direct.size") + .defaultValue((long) (JVM_MAX_DIRECT_MEMORY * 0.8)) + .description("max direct memory size, default 0.8 * JVM_MAX_DIRECT_MEMORY"); - public static final ConfigKey MEMORY_PAGE_SIZE = ConfigKeys - .key("geaflow.memory.page.size") - .defaultValue(8192) - .description("memory page size, default 8192 Byte"); + public static final ConfigKey MEMORY_PAGE_SIZE = + ConfigKeys.key("geaflow.memory.page.size") + .defaultValue(8192) + .description("memory page size, default 8192 Byte"); - public static final ConfigKey MEMORY_MAX_ORDER = ConfigKeys - .key("geaflow.memory.max.order") - .defaultValue(11) - .description("memory max order, default 11"); + public static final ConfigKey MEMORY_MAX_ORDER = + ConfigKeys.key("geaflow.memory.max.order") + .defaultValue(11) + .description("memory max order, default 11"); - public static final ConfigKey MEMORY_POOL_SIZE = ConfigKeys - .key("geaflow.memory.pool.size") - .defaultValue(0) - .description("inner memory pool size, default 0"); + public static final ConfigKey MEMORY_POOL_SIZE = + ConfigKeys.key("geaflow.memory.pool.size") + .defaultValue(0) + .description("inner memory pool size, default 0"); - public static final ConfigKey MEMORY_GROUP_RATIO = ConfigKeys - .key("geaflow.memory.group.ratio") - .defaultValue("10:*:*") - .description("format shuffle:state:default=10:*:*, shuffle=10%, *=shared memory"); + public static final ConfigKey MEMORY_GROUP_RATIO = + ConfigKeys.key("geaflow.memory.group.ratio") + .defaultValue("10:*:*") + .description("format shuffle:state:default=10:*:*, shuffle=10%, *=shared memory"); - public static final ConfigKey MEMORY_TRIM_GAP_MINUTE = ConfigKeys - .key("geaflow.memory.trim.gap.minute") - .defaultValue(30) - .description("auto check memory, and trim gap, default 30min"); + public static final ConfigKey MEMORY_TRIM_GAP_MINUTE = + ConfigKeys.key("geaflow.memory.trim.gap.minute") + .defaultValue(30) + .description("auto check memory, and trim gap, default 30min"); - public static final ConfigKey MEMORY_DEBUG_ENABLE = ConfigKeys - .key("geaflow.memory.debug.enable") - .defaultValue(false) - .description("memory manager mist print, default false"); + public static final ConfigKey MEMORY_DEBUG_ENABLE = + ConfigKeys.key("geaflow.memory.debug.enable") + .defaultValue(false) + .description("memory manager mist print, default false"); - public static final ConfigKey MEMORY_AUTO_ADAPT_ENABLE = ConfigKeys - .key("geaflow.memory.auto.adapt.enable") - .defaultValue(true) - .description("auto memory scale from direct to onHeap, default true"); + public static final ConfigKey MEMORY_AUTO_ADAPT_ENABLE = + ConfigKeys.key("geaflow.memory.auto.adapt.enable") + .defaultValue(true) + .description("auto memory scale from direct to onHeap, default true"); } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/exception/GeaflowOutOfMemoryException.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/exception/GeaflowOutOfMemoryException.java index dfbffef53..cfe11990f 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/exception/GeaflowOutOfMemoryException.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/exception/GeaflowOutOfMemoryException.java @@ -21,9 +21,9 @@ public final class GeaflowOutOfMemoryException extends OutOfMemoryError { - private static final long serialVersionUID = 4228264016184011555L; + private static final long serialVersionUID = 4228264016184011555L; - public GeaflowOutOfMemoryException(String s) { - super(s); - } + public GeaflowOutOfMemoryException(String s) { + super(s); + } } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/metric/ChunkListMetric.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/metric/ChunkListMetric.java index bd4f5aa7a..b312b0ce0 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/metric/ChunkListMetric.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/metric/ChunkListMetric.java @@ -19,18 +19,17 @@ package org.apache.geaflow.memory.metric; -/** - * This interface is the metrics for a chunk list. - */ +/** This interface is the metrics for a chunk list. */ public interface ChunkListMetric extends Iterable { - /** - * Return the minimum usage of the chunk list before which chunks are promoted to the previous list. - */ - int minUsage(); + /** + * Return the minimum usage of the chunk list before which chunks are promoted to the previous + * list. + */ + int minUsage(); - /** - * Return the maximum usage of the chunk list after which chunks are promoted to the next list. - */ - int maxUsage(); + /** + * Return the maximum usage of the chunk list after which chunks are promoted to the next list. + */ + int maxUsage(); } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/metric/ChunkMetric.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/metric/ChunkMetric.java index 959d864ec..b730fef2f 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/metric/ChunkMetric.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/metric/ChunkMetric.java @@ -19,28 +19,21 @@ package org.apache.geaflow.memory.metric; -/** - * This interface is the metrics for a chunk. - */ +/** This interface is the metrics for a chunk. */ public interface ChunkMetric { - /** - * Return the percentage of the current usage of the chunk. - */ - int usage(); + /** Return the percentage of the current usage of the chunk. */ + int usage(); - /** - * Return the size of the chunk in bytes, this is the maximum of bytes that can be served out of the chunk. - */ - int chunkSize(); + /** + * Return the size of the chunk in bytes, this is the maximum of bytes that can be served out of + * the chunk. + */ + int chunkSize(); - /** - * Return the number of free bytes in the chunk. - */ - int freeBytes(); + /** Return the number of free bytes in the chunk. */ + int freeBytes(); - /** - * Return the number of active bytes in the chunk. - */ - int activeBytes(); + /** Return the number of active bytes in the chunk. */ + int activeBytes(); } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/metric/MemoryGroupMetric.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/metric/MemoryGroupMetric.java index c20176ba6..41176308c 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/metric/MemoryGroupMetric.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/metric/MemoryGroupMetric.java @@ -19,38 +19,24 @@ package org.apache.geaflow.memory.metric; -/** - * This interface is the metrics for a memory group. - */ +/** This interface is the metrics for a memory group. */ public interface MemoryGroupMetric { - /** - * return the number of bytes has been used by this group. - */ - long usedBytes(); - - /** - * return the on heap bytes used. - */ - long usedOnHeapBytes(); - - /** - * return the off heap bytes used. - */ - long usedOffHeapBytes(); - - /** - * return the baseline number of bytes can be allocated by this group. - */ - long baseBytes(); - - /** - * Return the percentage of the current usage of this group. - */ - double usage(); - - /** - * return the number of byteBuf allocated. - */ - long byteBufNum(); + /** return the number of bytes has been used by this group. */ + long usedBytes(); + + /** return the on heap bytes used. */ + long usedOnHeapBytes(); + + /** return the off heap bytes used. */ + long usedOffHeapBytes(); + + /** return the baseline number of bytes can be allocated by this group. */ + long baseBytes(); + + /** Return the percentage of the current usage of this group. */ + double usage(); + + /** return the number of byteBuf allocated. */ + long byteBufNum(); } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/metric/PageMetric.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/metric/PageMetric.java index 6c7d5ae89..54c1a6e18 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/metric/PageMetric.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/metric/PageMetric.java @@ -19,28 +19,18 @@ package org.apache.geaflow.memory.metric; -/** - * This interface is the metrics for a page. - */ +/** This interface is the metrics for a page. */ public interface PageMetric { - /** - * Return the number of maximal elements that can be allocated out of the sub-page. - */ - int maxNumElements(); + /** Return the number of maximal elements that can be allocated out of the sub-page. */ + int maxNumElements(); - /** - * Return the number of available elements to be allocated. - */ - int numAvailable(); + /** Return the number of available elements to be allocated. */ + int numAvailable(); - /** - * Return the size (in bytes) of the elements that will be allocated. - */ - int elementSize(); + /** Return the size (in bytes) of the elements that will be allocated. */ + int elementSize(); - /** - * Return the size (in bytes) of this page. - */ - int pageSize(); + /** Return the size (in bytes) of this page. */ + int pageSize(); } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/metric/PoolMetric.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/metric/PoolMetric.java index e5f9a13fc..2ea247d60 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/metric/PoolMetric.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/metric/PoolMetric.java @@ -19,33 +19,21 @@ package org.apache.geaflow.memory.metric; -/** - * This interface is the metrics for a pool. - */ +/** This interface is the metrics for a pool. */ public interface PoolMetric { - /** - * Returns the number of thread caches backed by this arena. - */ - int numThreadCaches(); + /** Returns the number of thread caches backed by this arena. */ + int numThreadCaches(); - /** - * Return the number of allocations done via the pool. This includes all sizes. - */ - long numAllocations(); + /** Return the number of allocations done via the pool. This includes all sizes. */ + long numAllocations(); - /** - * Returns the number of active bytes. - */ - long numActiveBytes(); + /** Returns the number of active bytes. */ + long numActiveBytes(); - /** - * Returns the number of allocated bytes. - */ - long allocateBytes(); + /** Returns the number of allocated bytes. */ + long allocateBytes(); - /** - * Returns the number of free bytes. - */ - long freeBytes(); + /** Returns the number of free bytes. */ + long freeBytes(); } diff --git a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/thread/BaseMemoryGroupThreadLocal.java b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/thread/BaseMemoryGroupThreadLocal.java index 4f6080e7a..a911d352b 100644 --- a/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/thread/BaseMemoryGroupThreadLocal.java +++ b/geaflow/geaflow-memory/src/main/java/org/apache/geaflow/memory/thread/BaseMemoryGroupThreadLocal.java @@ -19,99 +19,99 @@ package org.apache.geaflow.memory.thread; -import com.google.common.collect.Maps; -import io.netty.util.concurrent.FastThreadLocal; -import io.netty.util.concurrent.FastThreadLocalThread; -import io.netty.util.internal.InternalThreadLocalMap; import java.util.Collections; import java.util.IdentityHashMap; import java.util.Map; import java.util.Set; + import org.apache.geaflow.memory.MemoryGroup; import org.apache.geaflow.memory.PlatformDependent; -/** - * This Class is a subClass of Netty FastThreadLocal, and adapted to MemoryGroup. - */ +import com.google.common.collect.Maps; + +import io.netty.util.concurrent.FastThreadLocal; +import io.netty.util.concurrent.FastThreadLocalThread; +import io.netty.util.internal.InternalThreadLocalMap; + +/** This Class is a subClass of Netty FastThreadLocal, and adapted to MemoryGroup. */ public abstract class BaseMemoryGroupThreadLocal extends FastThreadLocal { - private static final int VARIABLES_TO_REMOVE_INDEX = InternalThreadLocalMap.nextVariableIndex(); - private final int index = InternalThreadLocalMap.nextVariableIndex(); - - /** - * Returns the current value for the current thread. - */ - @SuppressWarnings("unchecked") - public final V get(MemoryGroup group) { - InternalThreadLocalMap threadLocalMap = InternalThreadLocalMap.get(); - Object v = threadLocalMap.indexedVariable(index); - if (v != InternalThreadLocalMap.UNSET) { - Map map = (Map) v; - if (map.containsKey(group)) { - return map.get(group); - } - V value = null; - try { - value = initialValue(group); - map.put(group, value); - } catch (Exception e) { - PlatformDependent.throwException(e); - } - return value; - } - - V value = initialize(threadLocalMap, group); - registerCleaner(threadLocalMap); - return value; + private static final int VARIABLES_TO_REMOVE_INDEX = InternalThreadLocalMap.nextVariableIndex(); + private final int index = InternalThreadLocalMap.nextVariableIndex(); + + /** Returns the current value for the current thread. */ + @SuppressWarnings("unchecked") + public final V get(MemoryGroup group) { + InternalThreadLocalMap threadLocalMap = InternalThreadLocalMap.get(); + Object v = threadLocalMap.indexedVariable(index); + if (v != InternalThreadLocalMap.UNSET) { + Map map = (Map) v; + if (map.containsKey(group)) { + return map.get(group); + } + V value = null; + try { + value = initialValue(group); + map.put(group, value); + } catch (Exception e) { + PlatformDependent.throwException(e); + } + return value; } - private void registerCleaner(final InternalThreadLocalMap threadLocalMap) { - Thread current = Thread.currentThread(); - if (FastThreadLocalThread.willCleanupFastThreadLocals(current) || threadLocalMap.isCleanerFlagSet(index)) { - return; - } + V value = initialize(threadLocalMap, group); + registerCleaner(threadLocalMap); + return value; + } - threadLocalMap.setCleanerFlag(index); + private void registerCleaner(final InternalThreadLocalMap threadLocalMap) { + Thread current = Thread.currentThread(); + if (FastThreadLocalThread.willCleanupFastThreadLocals(current) + || threadLocalMap.isCleanerFlagSet(index)) { + return; } - private V initialize(InternalThreadLocalMap threadLocalMap, MemoryGroup group) { - V v = null; - try { - v = initialValue(group); - } catch (Exception e) { - PlatformDependent.throwException(e); - } - Map map = Maps.newHashMap(); - map.put(group, v); - - threadLocalMap.setIndexedVariable(index, map); - addToVariablesToRemove(threadLocalMap, this); - return v; + threadLocalMap.setCleanerFlag(index); + } + + private V initialize(InternalThreadLocalMap threadLocalMap, MemoryGroup group) { + V v = null; + try { + v = initialValue(group); + } catch (Exception e) { + PlatformDependent.throwException(e); } + Map map = Maps.newHashMap(); + map.put(group, v); - private static void addToVariablesToRemove(InternalThreadLocalMap threadLocalMap, FastThreadLocal variable) { - Object v = threadLocalMap.indexedVariable(VARIABLES_TO_REMOVE_INDEX); - Set> variablesToRemove; - if (v == InternalThreadLocalMap.UNSET || v == null) { - variablesToRemove = Collections.newSetFromMap(new IdentityHashMap, Boolean>()); - threadLocalMap.setIndexedVariable(VARIABLES_TO_REMOVE_INDEX, variablesToRemove); - } else { - variablesToRemove = (Set>) v; - } - - variablesToRemove.add(variable); + threadLocalMap.setIndexedVariable(index, map); + addToVariablesToRemove(threadLocalMap, this); + return v; + } + + private static void addToVariablesToRemove( + InternalThreadLocalMap threadLocalMap, FastThreadLocal variable) { + Object v = threadLocalMap.indexedVariable(VARIABLES_TO_REMOVE_INDEX); + Set> variablesToRemove; + if (v == InternalThreadLocalMap.UNSET || v == null) { + variablesToRemove = + Collections.newSetFromMap(new IdentityHashMap, Boolean>()); + threadLocalMap.setIndexedVariable(VARIABLES_TO_REMOVE_INDEX, variablesToRemove); + } else { + variablesToRemove = (Set>) v; } - /** - * Returns the initial value for this thread-local variable. - */ - protected abstract V initialValue(MemoryGroup group) throws Exception; + variablesToRemove.add(variable); + } + + /** Returns the initial value for this thread-local variable. */ + protected abstract V initialValue(MemoryGroup group) throws Exception; - protected abstract void notifyRemove(V v); + protected abstract void notifyRemove(V v); - protected void onRemoval(Object v) throws Exception { - if (v instanceof Map) { - ((Map) v).forEach((k, val) -> notifyRemove((V) val)); - } + protected void onRemoval(Object v) throws Exception { + if (v instanceof Map) { + ((Map) v).forEach((k, val) -> notifyRemove((V) val)); } + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/ByteArrayTest.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/ByteArrayTest.java index cf86799ca..87737898e 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/ByteArrayTest.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/ByteArrayTest.java @@ -21,8 +21,8 @@ import static org.apache.geaflow.memory.MemoryGroupManger.DEFAULT; -import com.google.common.collect.Maps; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.memory.array.ByteArray; import org.apache.geaflow.memory.array.ByteArrayBuilder; @@ -30,65 +30,65 @@ import org.testng.Assert; import org.testng.annotations.Test; -public class ByteArrayTest extends MemoryReleaseTest { - - @Test - public void testMemoryView() { +import com.google.common.collect.Maps; - Map conf = Maps.newHashMap(); - conf.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "128"); +public class ByteArrayTest extends MemoryReleaseTest { - MemoryManager.build(new Configuration(conf)); + @Test + public void testMemoryView() { - MemoryView memoryView = MemoryManager.getInstance().requireMemory(1024, DEFAULT); + Map conf = Maps.newHashMap(); + conf.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "128"); + MemoryManager.build(new Configuration(conf)); - String a = "MemoryView heapView = memoryManager.requireMemory(1025, MemoryMode.ON_HEAP"; + MemoryView memoryView = MemoryManager.getInstance().requireMemory(1024, DEFAULT); - memoryView.getWriter().write(a.getBytes()); + String a = "MemoryView heapView = memoryManager.requireMemory(1025, MemoryMode.ON_HEAP"; - MemoryViewReference reference = new MemoryViewReference(memoryView); + memoryView.getWriter().write(a.getBytes()); - ByteArray array1 = ByteArrayBuilder.of(reference, 0, memoryView.contentSize); + MemoryViewReference reference = new MemoryViewReference(memoryView); - int offset = memoryView.contentSize; + ByteArray array1 = ByteArrayBuilder.of(reference, 0, memoryView.contentSize); - String b = "MemoryView heapView = memoryManager.requireMemory(1025, MemoryMode.ON_HEAP)!"; + int offset = memoryView.contentSize; - reference.getMemoryView().getWriter().write(b.getBytes()); + String b = "MemoryView heapView = memoryManager.requireMemory(1025, MemoryMode.ON_HEAP)!"; - ByteArray array2 = ByteArrayBuilder.of(reference, offset, memoryView.contentSize - offset); + reference.getMemoryView().getWriter().write(b.getBytes()); - Assert.assertEquals(a.getBytes(), array1.array()); - Assert.assertEquals(a.getBytes().length, array1.size()); - Assert.assertEquals(a.getBytes().length, array1.toByteBuffer().array().length); + ByteArray array2 = ByteArrayBuilder.of(reference, offset, memoryView.contentSize - offset); - Assert.assertEquals(b.getBytes(), array2.array()); - Assert.assertEquals(b.getBytes().length, array2.size()); + Assert.assertEquals(a.getBytes(), array1.array()); + Assert.assertEquals(a.getBytes().length, array1.size()); + Assert.assertEquals(a.getBytes().length, array1.toByteBuffer().array().length); - array2.release(); - array1.release(); + Assert.assertEquals(b.getBytes(), array2.array()); + Assert.assertEquals(b.getBytes().length, array2.size()); - Assert.assertNotNull(reference.getMemoryView()); + array2.release(); + array1.release(); - reference.decRef(); - Assert.assertNull(reference.getMemoryView()); - } + Assert.assertNotNull(reference.getMemoryView()); + reference.decRef(); + Assert.assertNull(reference.getMemoryView()); + } - @Test - public void testDefaultArray() { + @Test + public void testDefaultArray() { - String a = "MemoryView heapView = memoryManager.requireMemory(1025, MemoryMode.ON_HEAP"; + String a = "MemoryView heapView = memoryManager.requireMemory(1025, MemoryMode.ON_HEAP"; - ByteArray array = ByteArrayBuilder.of(a.getBytes()); + ByteArray array = ByteArrayBuilder.of(a.getBytes()); - Assert.assertNotNull(array.array()); - Assert.assertEquals(a.getBytes(), array.array()); - Assert.assertEquals(a.getBytes().length, array.size()); - Assert.assertEquals(a.getBytes().length, array.toByteBuffer().array().length); + Assert.assertNotNull(array.array()); + Assert.assertEquals(a.getBytes(), array.array()); + Assert.assertEquals(a.getBytes().length, array.size()); + Assert.assertEquals(a.getBytes().length, array.toByteBuffer().array().length); - array.release(); - Assert.assertNull(array.array()); - } + array.release(); + Assert.assertNull(array.array()); + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/ChunkListTest.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/ChunkListTest.java index 48a2f05e2..443a5537b 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/ChunkListTest.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/ChunkListTest.java @@ -23,114 +23,140 @@ import static org.mockito.Mockito.when; import java.nio.ByteBuffer; + import org.testng.Assert; import org.testng.annotations.Test; public class ChunkListTest extends MemoryReleaseTest { - @Test - public void test1() { + @Test + public void test1() { - AbstractMemoryPool pool = mock(AbstractMemoryPool.class); - when(pool.getMemoryMode()).thenReturn(MemoryMode.ON_HEAP); + AbstractMemoryPool pool = mock(AbstractMemoryPool.class); + when(pool.getMemoryMode()).thenReturn(MemoryMode.ON_HEAP); - ChunkList c10 = new ChunkList<>(pool, 1, 10); - ChunkList c20 = new ChunkList<>(pool, 10, 100); + ChunkList c10 = new ChunkList<>(pool, 1, 10); + ChunkList c20 = new ChunkList<>(pool, 10, 100); - c10.setPreList(null); - c10.setNextList(c20); - c20.setNextList(null); - c20.setPreList(c10); + c10.setPreList(null); + c10.setNextList(c20); + c20.setNextList(null); + c20.setPreList(c10); - int pageShift = MemoryManager.validateAndCalculatePageShifts(8192); + int pageShift = MemoryManager.validateAndCalculatePageShifts(8192); - Chunk chunk = new Chunk<>(pool, new byte[ESegmentSize.S16777216.size()], 8192, 11, pageShift, + Chunk chunk = + new Chunk<>( + pool, + new byte[ESegmentSize.S16777216.size()], + 8192, + 11, + pageShift, ESegmentSize.S16777216.size(), 0); - long handle1 = chunk.allocate((int) (0.05 * ESegmentSize.S16777216.size())); + long handle1 = chunk.allocate((int) (0.05 * ESegmentSize.S16777216.size())); - c10.add(chunk); + c10.add(chunk); - Assert.assertEquals(c10.getHead(), chunk); + Assert.assertEquals(c10.getHead(), chunk); - Chunk chunk2 = new Chunk<>(pool, new byte[ESegmentSize.S16777216.size()], 8192, 11, pageShift, + Chunk chunk2 = + new Chunk<>( + pool, + new byte[ESegmentSize.S16777216.size()], + 8192, + 11, + pageShift, ESegmentSize.S16777216.size(), 0); - long handle2 = chunk2.allocate(ESegmentSize.S8388608.size()); - - c10.add(chunk2); - Assert.assertEquals(c20.getHead(), chunk2); + long handle2 = chunk2.allocate(ESegmentSize.S8388608.size()); - c10.allocate(new ByteBuf<>(), (int) (0.1 * ESegmentSize.S16777216.size())); + c10.add(chunk2); + Assert.assertEquals(c20.getHead(), chunk2); - Assert.assertNull(c10.getHead()); - Assert.assertEquals(c20.getHead().next, chunk2); + c10.allocate(new ByteBuf<>(), (int) (0.1 * ESegmentSize.S16777216.size())); - boolean needDestroy = !chunk2.parent.free(chunk2, handle2); + Assert.assertNull(c10.getHead()); + Assert.assertEquals(c20.getHead().next, chunk2); - Assert.assertTrue(needDestroy); + boolean needDestroy = !chunk2.parent.free(chunk2, handle2); - } + Assert.assertTrue(needDestroy); + } + @Test + public void test2() { - @Test - public void test2() { + AbstractMemoryPool pool = mock(AbstractMemoryPool.class); + when(pool.getMemoryMode()).thenReturn(MemoryMode.OFF_HEAP); - AbstractMemoryPool pool = mock(AbstractMemoryPool.class); - when(pool.getMemoryMode()).thenReturn(MemoryMode.OFF_HEAP); + ChunkList c10 = new ChunkList<>(pool, 1, 10); + ChunkList c20 = new ChunkList<>(pool, 10, 100); - ChunkList c10 = new ChunkList<>(pool, 1, 10); - ChunkList c20 = new ChunkList<>(pool, 10, 100); + c10.setPreList(null); + c10.setNextList(c20); + c20.setNextList(null); + c20.setPreList(c10); - c10.setPreList(null); - c10.setNextList(c20); - c20.setNextList(null); - c20.setPreList(c10); + int pageShift = MemoryManager.validateAndCalculatePageShifts(8192); - int pageShift = MemoryManager.validateAndCalculatePageShifts(8192); - - Chunk chunk = new Chunk<>(pool, - ByteBuffer.allocateDirect(ESegmentSize.S16777216.size()), 8192, 11, pageShift, + Chunk chunk = + new Chunk<>( + pool, + ByteBuffer.allocateDirect(ESegmentSize.S16777216.size()), + 8192, + 11, + pageShift, ESegmentSize.S16777216.size(), 0); - long handle1 = chunk.allocate((int) (0.05 * ESegmentSize.S16777216.size())); + long handle1 = chunk.allocate((int) (0.05 * ESegmentSize.S16777216.size())); - c10.add(chunk); + c10.add(chunk); - Assert.assertEquals(c10.getHead(), chunk); + Assert.assertEquals(c10.getHead(), chunk); - Chunk chunk2 = new Chunk<>(pool, - ByteBuffer.allocateDirect(ESegmentSize.S16777216.size()), 8192, 11, pageShift, + Chunk chunk2 = + new Chunk<>( + pool, + ByteBuffer.allocateDirect(ESegmentSize.S16777216.size()), + 8192, + 11, + pageShift, ESegmentSize.S16777216.size(), 0); - long handle2 = chunk2.allocate(ESegmentSize.S8388608.size()); + long handle2 = chunk2.allocate(ESegmentSize.S8388608.size()); - c10.add(chunk2); - Assert.assertEquals(c20.getHead(), chunk2); + c10.add(chunk2); + Assert.assertEquals(c20.getHead(), chunk2); - c10.allocate(new ByteBuf<>(), (int) (0.1 * ESegmentSize.S16777216.size())); + c10.allocate(new ByteBuf<>(), (int) (0.1 * ESegmentSize.S16777216.size())); - Assert.assertNull(c10.getHead()); - Assert.assertEquals(c20.getHead().next, chunk2); + Assert.assertNull(c10.getHead()); + Assert.assertEquals(c20.getHead().next, chunk2); - boolean needDestroy = !chunk2.parent.free(chunk2, handle2); + boolean needDestroy = !chunk2.parent.free(chunk2, handle2); - Assert.assertFalse(needDestroy); + Assert.assertFalse(needDestroy); - Chunk chunk3 = new Chunk<>(pool, new byte[ESegmentSize.S16777216.size()], 8192, 11, pageShift, + Chunk chunk3 = + new Chunk<>( + pool, + new byte[ESegmentSize.S16777216.size()], + 8192, + 11, + pageShift, ESegmentSize.S16777216.size(), 0); - long handle3 = chunk3.allocate(ESegmentSize.S8388608.size()); - c10.add(chunk3); - Assert.assertEquals(c20.getHead(), chunk3); - Assert.assertEquals(chunk3.next, chunk); - - needDestroy = !chunk3.parent.free(chunk3, handle3); + long handle3 = chunk3.allocate(ESegmentSize.S8388608.size()); + c10.add(chunk3); + Assert.assertEquals(c20.getHead(), chunk3); + Assert.assertEquals(chunk3.next, chunk); - Assert.assertFalse(needDestroy); + needDestroy = !chunk3.parent.free(chunk3, handle3); - } + Assert.assertFalse(needDestroy); + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/DirectMemoryTest.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/DirectMemoryTest.java index c729ef630..5453aa496 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/DirectMemoryTest.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/DirectMemoryTest.java @@ -22,51 +22,52 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.common.utils.MemoryUtils; import org.testng.Assert; import org.testng.annotations.Test; public class DirectMemoryTest { - @Test - public void test() { - List args = new ArrayList<>(); - args.add("-XX:MaxDirectMemorySize=2G"); - long s = DirectMemory.maxDirectMemoryFromJVMOption(args); - Assert.assertEquals(MemoryUtils.humanReadableByteCount(s), "2.00GB"); + @Test + public void test() { + List args = new ArrayList<>(); + args.add("-XX:MaxDirectMemorySize=2G"); + long s = DirectMemory.maxDirectMemoryFromJVMOption(args); + Assert.assertEquals(MemoryUtils.humanReadableByteCount(s), "2.00GB"); - ByteBuffer bf = DirectMemory.allocateDirectNoCleaner(20); - long addr = DirectMemory.directBufferAddress(bf); - DirectMemory.setMemory(addr, 20, (byte) 0); + ByteBuffer bf = DirectMemory.allocateDirectNoCleaner(20); + long addr = DirectMemory.directBufferAddress(bf); + DirectMemory.setMemory(addr, 20, (byte) 0); - DirectMemory.putInt(addr, 100); - DirectMemory.putLong(addr + 4, 1000L); - DirectMemory.putShort(addr + 12, (short) 10); - DirectMemory.putByte(addr + 14, (byte) 1); + DirectMemory.putInt(addr, 100); + DirectMemory.putLong(addr + 4, 1000L); + DirectMemory.putShort(addr + 12, (short) 10); + DirectMemory.putByte(addr + 14, (byte) 1); - Assert.assertEquals(DirectMemory.getInt(addr), 100); - Assert.assertEquals(DirectMemory.getLong(addr + 4), 1000L); - Assert.assertEquals(DirectMemory.getShort(addr + 12), 10); - Assert.assertEquals(DirectMemory.getByte(addr + 14), 1); - Assert.assertEquals(DirectMemory.getByte(addr + 15), 0); + Assert.assertEquals(DirectMemory.getInt(addr), 100); + Assert.assertEquals(DirectMemory.getLong(addr + 4), 1000L); + Assert.assertEquals(DirectMemory.getShort(addr + 12), 10); + Assert.assertEquals(DirectMemory.getByte(addr + 14), 1); + Assert.assertEquals(DirectMemory.getByte(addr + 15), 0); - ByteBuffer bf2 = DirectMemory.allocateDirectNoCleaner(20); - long addr2 = DirectMemory.directBufferAddress(bf2); - DirectMemory.copyMemory(addr, addr2, 20); + ByteBuffer bf2 = DirectMemory.allocateDirectNoCleaner(20); + long addr2 = DirectMemory.directBufferAddress(bf2); + DirectMemory.copyMemory(addr, addr2, 20); - Assert.assertEquals(DirectMemory.getInt(addr2), 100); - Assert.assertEquals(DirectMemory.getLong(addr2 + 4), 1000L); - Assert.assertEquals(DirectMemory.getShort(addr2 + 12), 10); - Assert.assertEquals(DirectMemory.getByte(addr2 + 14), 1); - Assert.assertEquals(DirectMemory.getByte(addr2 + 15), 0); + Assert.assertEquals(DirectMemory.getInt(addr2), 100); + Assert.assertEquals(DirectMemory.getLong(addr2 + 4), 1000L); + Assert.assertEquals(DirectMemory.getShort(addr2 + 12), 10); + Assert.assertEquals(DirectMemory.getByte(addr2 + 14), 1); + Assert.assertEquals(DirectMemory.getByte(addr2 + 15), 0); - ByteBuffer[] bfs = DirectMemory.splitBuffer(bf, 10); - Assert.assertEquals(bfs.length, 2); - ByteBuffer bf3 = DirectMemory.mergeBuffer(bfs[0], bfs[1]); - long addr3 = DirectMemory.directBufferAddress(bf3); - Assert.assertEquals(addr, addr3); + ByteBuffer[] bfs = DirectMemory.splitBuffer(bf, 10); + Assert.assertEquals(bfs.length, 2); + ByteBuffer bf3 = DirectMemory.mergeBuffer(bfs[0], bfs[1]); + long addr3 = DirectMemory.directBufferAddress(bf3); + Assert.assertEquals(addr, addr3); - DirectMemory.freeDirectBuffer(bf2); - DirectMemory.freeDirectBuffer(bf3); - } + DirectMemory.freeDirectBuffer(bf2); + DirectMemory.freeDirectBuffer(bf3); + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/FreeChunkStatisticsTest.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/FreeChunkStatisticsTest.java index b525b67b3..d42c305a7 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/FreeChunkStatisticsTest.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/FreeChunkStatisticsTest.java @@ -24,33 +24,32 @@ public class FreeChunkStatisticsTest extends MemoryReleaseTest { - @Test - public void test() throws InterruptedException { + @Test + public void test() throws InterruptedException { - FreeChunkStatistics statistics = new FreeChunkStatistics(10, 1000); + FreeChunkStatistics statistics = new FreeChunkStatistics(10, 1000); - statistics.update(1); + statistics.update(1); - statistics.update(2); + statistics.update(2); - Assert.assertFalse(statistics.isFull()); + Assert.assertFalse(statistics.isFull()); - Assert.assertEquals(statistics.getMinFree(), 1); + Assert.assertEquals(statistics.getMinFree(), 1); - for (int i = 0; i < 10; i++) { + for (int i = 0; i < 10; i++) { - Thread.sleep(1001); + Thread.sleep(1001); - statistics.update(i); - - } + statistics.update(i); + } - Assert.assertEquals(statistics.getMinFree(), 0); + Assert.assertEquals(statistics.getMinFree(), 0); - Assert.assertTrue(statistics.isFull()); + Assert.assertTrue(statistics.isFull()); - statistics.clear(); + statistics.clear(); - Assert.assertFalse(statistics.isFull()); - } + Assert.assertFalse(statistics.isFull()); + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/MemoryGroupTest.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/MemoryGroupTest.java index 19fda946d..e9c3aa136 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/MemoryGroupTest.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/MemoryGroupTest.java @@ -19,275 +19,281 @@ package org.apache.geaflow.memory; -import com.google.common.collect.Maps; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.utils.MemoryUtils; import org.apache.geaflow.memory.config.MemoryConfigKeys; import org.testng.Assert; import org.testng.annotations.Test; -public class MemoryGroupTest extends MemoryReleaseTest { - - - @Test - public void test1() { - - Map conf = Maps.newHashMap(); - conf.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "32"); - MemoryManager.build(new Configuration(conf)); +import com.google.common.collect.Maps; - MemoryView view = MemoryManager.getInstance() - .requireMemory(1024, MemoryGroupManger.DEFAULT); +public class MemoryGroupTest extends MemoryReleaseTest { - Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 1024); - System.out.println(MemoryGroupManger.DEFAULT.toString()); + @Test + public void test1() { - Assert.assertEquals(MemoryGroupManger.DEFAULT, new MemoryGroup("default", 1024)); + Map conf = Maps.newHashMap(); + conf.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "32"); + MemoryManager.build(new Configuration(conf)); - view.reset(); + MemoryView view = MemoryManager.getInstance().requireMemory(1024, MemoryGroupManger.DEFAULT); - Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 1024); + Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 1024); + System.out.println(MemoryGroupManger.DEFAULT.toString()); - view.close(); + Assert.assertEquals(MemoryGroupManger.DEFAULT, new MemoryGroup("default", 1024)); - Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 0); + view.reset(); - MemoryManager.getInstance().dispose(); + Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 1024); - Assert.assertEquals(MemoryGroupManger.DEFAULT.getThreads(), 0); - } + view.close(); - @Test - public void test2() { + Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 0); - Map conf = Maps.newHashMap(); - conf.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "32"); - MemoryManager.build(new Configuration(conf)); + MemoryManager.getInstance().dispose(); - MemoryView view = MemoryManager.getInstance() - .requireMemory(1024, MemoryGroupManger.DEFAULT); + Assert.assertEquals(MemoryGroupManger.DEFAULT.getThreads(), 0); + } - Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 1024); + @Test + public void test2() { - MemoryViewWriter writer = view.getWriter(1024); - writer.write(new byte[2048]); + Map conf = Maps.newHashMap(); + conf.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "32"); + MemoryManager.build(new Configuration(conf)); - Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 2048); + MemoryView view = MemoryManager.getInstance().requireMemory(1024, MemoryGroupManger.DEFAULT); - writer.write(new byte[1]); + Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 1024); - Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 3072); + MemoryViewWriter writer = view.getWriter(1024); + writer.write(new byte[2048]); - view.close(); + Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 2048); - Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 0); - } + writer.write(new byte[1]); - @Test - public void test3() { + Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 3072); - long memory = 32 * MemoryUtils.MB; + view.close(); - MemoryGroupManger.getInstance().resetMemory(memory, (int) (1 * MemoryUtils.MB), - MemoryMode.OFF_HEAP); + Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 0); + } - long shuffleBase = 4 * MemoryUtils.MB; + @Test + public void test3() { - Assert.assertEquals(MemoryGroupManger.DEFAULT.baseBytes(), 0); - Assert.assertEquals(MemoryGroupManger.STATE.baseBytes(), 0); - Assert.assertEquals(MemoryGroupManger.SHUFFLE.baseBytes(), shuffleBase); + long memory = 32 * MemoryUtils.MB; - Assert.assertEquals(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), - memory - shuffleBase); + MemoryGroupManger.getInstance() + .resetMemory(memory, (int) (1 * MemoryUtils.MB), MemoryMode.OFF_HEAP); - Assert.assertEquals(MemoryGroupManger.SHUFFLE.allocate(2048, MemoryMode.OFF_HEAP), true); - Assert.assertEquals(MemoryGroupManger.SHUFFLE.usedBytes(), 2048); - Assert.assertEquals(MemoryGroupManger.SHUFFLE.allocate(shuffleBase, MemoryMode.OFF_HEAP), - true); - Assert.assertEquals(MemoryGroupManger.SHUFFLE.usedBytes(), shuffleBase + 2048); + long shuffleBase = 4 * MemoryUtils.MB; - Assert.assertEquals(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), - 28 * MemoryUtils.MB - shuffleBase); + Assert.assertEquals(MemoryGroupManger.DEFAULT.baseBytes(), 0); + Assert.assertEquals(MemoryGroupManger.STATE.baseBytes(), 0); + Assert.assertEquals(MemoryGroupManger.SHUFFLE.baseBytes(), shuffleBase); - Assert.assertEquals(MemoryGroupManger.SHUFFLE.allocate(1024, MemoryMode.OFF_HEAP), true); + Assert.assertEquals( + MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), memory - shuffleBase); - MemoryGroupManger.DEFAULT.allocate(28 * MemoryUtils.MB - shuffleBase, MemoryMode.OFF_HEAP); - MemoryGroupManger.SHUFFLE.allocate(shuffleBase - 3072, MemoryMode.OFF_HEAP); + Assert.assertEquals(MemoryGroupManger.SHUFFLE.allocate(2048, MemoryMode.OFF_HEAP), true); + Assert.assertEquals(MemoryGroupManger.SHUFFLE.usedBytes(), 2048); + Assert.assertEquals(MemoryGroupManger.SHUFFLE.allocate(shuffleBase, MemoryMode.OFF_HEAP), true); + Assert.assertEquals(MemoryGroupManger.SHUFFLE.usedBytes(), shuffleBase + 2048); - Assert.assertEquals(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), 0); + Assert.assertEquals( + MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), + 28 * MemoryUtils.MB - shuffleBase); - Assert.assertEquals(MemoryGroupManger.DEFAULT.allocate(1, MemoryMode.OFF_HEAP), false); - Assert.assertEquals(MemoryGroupManger.SHUFFLE.allocate(1, MemoryMode.OFF_HEAP), false); + Assert.assertEquals(MemoryGroupManger.SHUFFLE.allocate(1024, MemoryMode.OFF_HEAP), true); - Assert.assertEquals(MemoryGroupManger.SHUFFLE.usedBytes(), 8 * MemoryUtils.MB); - Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 24 * MemoryUtils.MB); + MemoryGroupManger.DEFAULT.allocate(28 * MemoryUtils.MB - shuffleBase, MemoryMode.OFF_HEAP); + MemoryGroupManger.SHUFFLE.allocate(shuffleBase - 3072, MemoryMode.OFF_HEAP); - Assert.assertEquals(MemoryGroupManger.SHUFFLE.free(shuffleBase, MemoryMode.OFF_HEAP), true); - Assert.assertEquals(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), 0); - Assert.assertEquals(MemoryGroupManger.SHUFFLE.free(shuffleBase, MemoryMode.OFF_HEAP), true); - Assert.assertEquals(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), shuffleBase); + Assert.assertEquals(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), 0); - MemoryGroupManger.getInstance().clear(); + Assert.assertEquals(MemoryGroupManger.DEFAULT.allocate(1, MemoryMode.OFF_HEAP), false); + Assert.assertEquals(MemoryGroupManger.SHUFFLE.allocate(1, MemoryMode.OFF_HEAP), false); - Assert.assertEquals(MemoryGroupManger.SHUFFLE.usedBytes(), 0); - } + Assert.assertEquals(MemoryGroupManger.SHUFFLE.usedBytes(), 8 * MemoryUtils.MB); + Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 24 * MemoryUtils.MB); - @Test - public void test4() { - long memory = 32 * MemoryUtils.MB; - MemoryGroupManger.getInstance().resetMemory(memory, (int) (1 * MemoryUtils.MB), MemoryMode.OFF_HEAP); - long shuffleBase = 4 * MemoryUtils.MB; + Assert.assertEquals(MemoryGroupManger.SHUFFLE.free(shuffleBase, MemoryMode.OFF_HEAP), true); + Assert.assertEquals(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), 0); + Assert.assertEquals(MemoryGroupManger.SHUFFLE.free(shuffleBase, MemoryMode.OFF_HEAP), true); + Assert.assertEquals(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), shuffleBase); - Assert.assertEquals(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), - memory - shuffleBase); + MemoryGroupManger.getInstance().clear(); - Assert.assertTrue(MemoryGroupManger.DEFAULT.allocate(2048, MemoryMode.OFF_HEAP)); - Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 2048); - Assert.assertTrue(MemoryGroupManger.STATE.allocate(2048, MemoryMode.OFF_HEAP)); - Assert.assertEquals(MemoryGroupManger.STATE.usedBytes(), 2048); + Assert.assertEquals(MemoryGroupManger.SHUFFLE.usedBytes(), 0); + } - Assert.assertEquals(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), - memory - shuffleBase - 4096); + @Test + public void test4() { + long memory = 32 * MemoryUtils.MB; + MemoryGroupManger.getInstance() + .resetMemory(memory, (int) (1 * MemoryUtils.MB), MemoryMode.OFF_HEAP); + long shuffleBase = 4 * MemoryUtils.MB; - Assert.assertTrue( - MemoryGroupManger.STATE.allocate(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), - MemoryMode.OFF_HEAP)); + Assert.assertEquals( + MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), memory - shuffleBase); - Assert.assertEquals(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), 0); + Assert.assertTrue(MemoryGroupManger.DEFAULT.allocate(2048, MemoryMode.OFF_HEAP)); + Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 2048); + Assert.assertTrue(MemoryGroupManger.STATE.allocate(2048, MemoryMode.OFF_HEAP)); + Assert.assertEquals(MemoryGroupManger.STATE.usedBytes(), 2048); - Assert.assertFalse(MemoryGroupManger.DEFAULT.allocate(2048, MemoryMode.OFF_HEAP)); - Assert.assertFalse(MemoryGroupManger.STATE.allocate(2048, MemoryMode.OFF_HEAP)); + Assert.assertEquals( + MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), memory - shuffleBase - 4096); - MemoryGroupManger.getInstance().clear(); - } + Assert.assertTrue( + MemoryGroupManger.STATE.allocate( + MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), MemoryMode.OFF_HEAP)); - @Test - public void test5() { - long memory = 32 * MemoryUtils.MB; - MemoryGroupManger.getInstance().resetMemory(memory, (int) (1 * MemoryUtils.MB), MemoryMode.ON_HEAP); + Assert.assertEquals(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), 0); - Assert.assertTrue(MemoryGroupManger.DEFAULT.allocate(2048, MemoryMode.ON_HEAP)); - Assert.assertTrue(MemoryGroupManger.STATE.allocate(2048, MemoryMode.ON_HEAP)); - Assert.assertTrue(MemoryGroupManger.SHUFFLE.allocate(2048, MemoryMode.ON_HEAP)); + Assert.assertFalse(MemoryGroupManger.DEFAULT.allocate(2048, MemoryMode.OFF_HEAP)); + Assert.assertFalse(MemoryGroupManger.STATE.allocate(2048, MemoryMode.OFF_HEAP)); - Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 2048); - Assert.assertEquals(MemoryGroupManger.SHUFFLE.usedBytes(), 2048); - Assert.assertEquals(MemoryGroupManger.STATE.usedBytes(), 2048); + MemoryGroupManger.getInstance().clear(); + } - Assert.assertTrue(MemoryGroupManger.SHUFFLE.allocate(memory, MemoryMode.ON_HEAP)); + @Test + public void test5() { + long memory = 32 * MemoryUtils.MB; + MemoryGroupManger.getInstance() + .resetMemory(memory, (int) (1 * MemoryUtils.MB), MemoryMode.ON_HEAP); - Assert.assertEquals(MemoryGroupManger.SHUFFLE.usedBytes(), memory + 2048); + Assert.assertTrue(MemoryGroupManger.DEFAULT.allocate(2048, MemoryMode.ON_HEAP)); + Assert.assertTrue(MemoryGroupManger.STATE.allocate(2048, MemoryMode.ON_HEAP)); + Assert.assertTrue(MemoryGroupManger.SHUFFLE.allocate(2048, MemoryMode.ON_HEAP)); + Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 2048); + Assert.assertEquals(MemoryGroupManger.SHUFFLE.usedBytes(), 2048); + Assert.assertEquals(MemoryGroupManger.STATE.usedBytes(), 2048); - Assert.assertTrue(MemoryGroupManger.DEFAULT.allocate(memory, MemoryMode.ON_HEAP)); + Assert.assertTrue(MemoryGroupManger.SHUFFLE.allocate(memory, MemoryMode.ON_HEAP)); - Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), memory + 2048); + Assert.assertEquals(MemoryGroupManger.SHUFFLE.usedBytes(), memory + 2048); - Assert.assertEquals(MemoryGroupManger.DEFAULT.usedOffHeapBytes(), 0); - Assert.assertEquals(MemoryGroupManger.DEFAULT.usedOnHeapBytes(), memory + 2048); + Assert.assertTrue(MemoryGroupManger.DEFAULT.allocate(memory, MemoryMode.ON_HEAP)); - Assert.assertTrue(MemoryGroupManger.DEFAULT.free(memory, MemoryMode.ON_HEAP)); - Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 2048); - } + Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), memory + 2048); - @Test - public void testRatio() { - MemoryGroupManger.getInstance().clear(); + Assert.assertEquals(MemoryGroupManger.DEFAULT.usedOffHeapBytes(), 0); + Assert.assertEquals(MemoryGroupManger.DEFAULT.usedOnHeapBytes(), memory + 2048); - MemoryGroupManger.getInstance().load(new Configuration()); - Assert.assertEquals(MemoryGroupManger.SHUFFLE.getRatio(), 0.1); - Assert.assertEquals(MemoryGroupManger.STATE.getRatio(), 0.0); - Assert.assertEquals(MemoryGroupManger.DEFAULT.getRatio(), 0.0); + Assert.assertTrue(MemoryGroupManger.DEFAULT.free(memory, MemoryMode.ON_HEAP)); + Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 2048); + } - long memory = 32 * MemoryUtils.MB; - MemoryGroupManger.getInstance().resetMemory(memory, (int) (1 * MemoryUtils.MB), MemoryMode.OFF_HEAP); - long shuffleBase = 4 * MemoryUtils.MB; + @Test + public void testRatio() { + MemoryGroupManger.getInstance().clear(); - Assert.assertEquals(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), - memory - shuffleBase); + MemoryGroupManger.getInstance().load(new Configuration()); + Assert.assertEquals(MemoryGroupManger.SHUFFLE.getRatio(), 0.1); + Assert.assertEquals(MemoryGroupManger.STATE.getRatio(), 0.0); + Assert.assertEquals(MemoryGroupManger.DEFAULT.getRatio(), 0.0); - Assert.assertTrue(MemoryGroupManger.DEFAULT.allocate(2048, MemoryMode.OFF_HEAP)); - Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 2048); - Assert.assertTrue(MemoryGroupManger.STATE.allocate(2048, MemoryMode.OFF_HEAP)); - Assert.assertEquals(MemoryGroupManger.STATE.usedBytes(), 2048); + long memory = 32 * MemoryUtils.MB; + MemoryGroupManger.getInstance() + .resetMemory(memory, (int) (1 * MemoryUtils.MB), MemoryMode.OFF_HEAP); + long shuffleBase = 4 * MemoryUtils.MB; - Assert.assertEquals(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), - memory - shuffleBase - 4096); + Assert.assertEquals( + MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), memory - shuffleBase); - Assert.assertTrue( - MemoryGroupManger.STATE.allocate(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), - MemoryMode.OFF_HEAP)); + Assert.assertTrue(MemoryGroupManger.DEFAULT.allocate(2048, MemoryMode.OFF_HEAP)); + Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 2048); + Assert.assertTrue(MemoryGroupManger.STATE.allocate(2048, MemoryMode.OFF_HEAP)); + Assert.assertEquals(MemoryGroupManger.STATE.usedBytes(), 2048); - Assert.assertEquals(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), 0); + Assert.assertEquals( + MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), memory - shuffleBase - 4096); - Assert.assertFalse(MemoryGroupManger.DEFAULT.allocate(2048, MemoryMode.OFF_HEAP)); - Assert.assertFalse(MemoryGroupManger.STATE.allocate(2048, MemoryMode.OFF_HEAP)); + Assert.assertTrue( + MemoryGroupManger.STATE.allocate( + MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), MemoryMode.OFF_HEAP)); - MemoryGroupManger.getInstance().clear(); + Assert.assertEquals(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), 0); - } + Assert.assertFalse(MemoryGroupManger.DEFAULT.allocate(2048, MemoryMode.OFF_HEAP)); + Assert.assertFalse(MemoryGroupManger.STATE.allocate(2048, MemoryMode.OFF_HEAP)); - @Test - public void testRatio2() { - MemoryGroupManger.getInstance().clear(); + MemoryGroupManger.getInstance().clear(); + } - Map config = Maps.newHashMap(); - config.put(MemoryConfigKeys.MEMORY_GROUP_RATIO.getKey(), "10:20:*"); + @Test + public void testRatio2() { + MemoryGroupManger.getInstance().clear(); - MemoryGroupManger.getInstance().load(new Configuration(config)); - Assert.assertEquals(MemoryGroupManger.SHUFFLE.getRatio(), 0.1); - Assert.assertEquals(MemoryGroupManger.STATE.getRatio(), 0.2); - Assert.assertEquals(MemoryGroupManger.DEFAULT.getRatio(), 0.0); + Map config = Maps.newHashMap(); + config.put(MemoryConfigKeys.MEMORY_GROUP_RATIO.getKey(), "10:20:*"); - long memory = 32 * MemoryUtils.MB; - MemoryGroupManger.getInstance().resetMemory(memory, (int) (1 * MemoryUtils.MB), MemoryMode.OFF_HEAP); - long shuffleBase = 4 * MemoryUtils.MB; - long stateBase = 7 * MemoryUtils.MB; + MemoryGroupManger.getInstance().load(new Configuration(config)); + Assert.assertEquals(MemoryGroupManger.SHUFFLE.getRatio(), 0.1); + Assert.assertEquals(MemoryGroupManger.STATE.getRatio(), 0.2); + Assert.assertEquals(MemoryGroupManger.DEFAULT.getRatio(), 0.0); - Assert.assertEquals(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), - memory - shuffleBase - stateBase); + long memory = 32 * MemoryUtils.MB; + MemoryGroupManger.getInstance() + .resetMemory(memory, (int) (1 * MemoryUtils.MB), MemoryMode.OFF_HEAP); + long shuffleBase = 4 * MemoryUtils.MB; + long stateBase = 7 * MemoryUtils.MB; - Assert.assertTrue(MemoryGroupManger.DEFAULT.allocate(2048, MemoryMode.OFF_HEAP)); - Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 2048); - Assert.assertTrue(MemoryGroupManger.STATE.allocate(2048, MemoryMode.OFF_HEAP)); - Assert.assertEquals(MemoryGroupManger.STATE.usedBytes(), 2048); + Assert.assertEquals( + MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), + memory - shuffleBase - stateBase); - Assert.assertEquals(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), - memory - shuffleBase - stateBase - 2048); + Assert.assertTrue(MemoryGroupManger.DEFAULT.allocate(2048, MemoryMode.OFF_HEAP)); + Assert.assertEquals(MemoryGroupManger.DEFAULT.usedBytes(), 2048); + Assert.assertTrue(MemoryGroupManger.STATE.allocate(2048, MemoryMode.OFF_HEAP)); + Assert.assertEquals(MemoryGroupManger.STATE.usedBytes(), 2048); - Assert.assertTrue( - MemoryGroupManger.DEFAULT.allocate(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), - MemoryMode.OFF_HEAP)); + Assert.assertEquals( + MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), + memory - shuffleBase - stateBase - 2048); - Assert.assertEquals(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), 0); + Assert.assertTrue( + MemoryGroupManger.DEFAULT.allocate( + MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), MemoryMode.OFF_HEAP)); - Assert.assertFalse(MemoryGroupManger.DEFAULT.allocate(2048, MemoryMode.OFF_HEAP)); - Assert.assertTrue(MemoryGroupManger.STATE.allocate(2048, MemoryMode.OFF_HEAP)); + Assert.assertEquals(MemoryGroupManger.getInstance().getCurrentSharedFreeBytes(), 0); - Assert.assertEquals(MemoryGroupManger.STATE.usedBytes(), 4096); + Assert.assertFalse(MemoryGroupManger.DEFAULT.allocate(2048, MemoryMode.OFF_HEAP)); + Assert.assertTrue(MemoryGroupManger.STATE.allocate(2048, MemoryMode.OFF_HEAP)); - MemoryGroupManger.getInstance().clear(); + Assert.assertEquals(MemoryGroupManger.STATE.usedBytes(), 4096); - } + MemoryGroupManger.getInstance().clear(); + } - @Test - public void testRatio3() { - Assert.assertThrows(IllegalArgumentException.class, () -> { - Map config = Maps.newHashMap(); - config.put(MemoryConfigKeys.MEMORY_GROUP_RATIO.getKey(), "10:20"); + @Test + public void testRatio3() { + Assert.assertThrows( + IllegalArgumentException.class, + () -> { + Map config = Maps.newHashMap(); + config.put(MemoryConfigKeys.MEMORY_GROUP_RATIO.getKey(), "10:20"); - MemoryGroupManger.getInstance().load(new Configuration(config)); + MemoryGroupManger.getInstance().load(new Configuration(config)); }); - } + } - @Test - public void testRatio4() { - Assert.assertThrows(IllegalArgumentException.class, () -> { - Map config = Maps.newHashMap(); - config.put(MemoryConfigKeys.MEMORY_GROUP_RATIO.getKey(), "10:200:*"); + @Test + public void testRatio4() { + Assert.assertThrows( + IllegalArgumentException.class, + () -> { + Map config = Maps.newHashMap(); + config.put(MemoryConfigKeys.MEMORY_GROUP_RATIO.getKey(), "10:200:*"); - MemoryGroupManger.getInstance().load(new Configuration(config)); + MemoryGroupManger.getInstance().load(new Configuration(config)); }); - } + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/MemoryReleaseTest.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/MemoryReleaseTest.java index 8604809a6..0f580502c 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/MemoryReleaseTest.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/MemoryReleaseTest.java @@ -19,15 +19,14 @@ package org.apache.geaflow.memory; - import org.testng.annotations.AfterMethod; public class MemoryReleaseTest { - @AfterMethod - public void tearUp() { - if (MemoryManager.getInstance() != null) { - MemoryManager.getInstance().dispose(); - } + @AfterMethod + public void tearUp() { + if (MemoryManager.getInstance() != null) { + MemoryManager.getInstance().dispose(); } + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/MemoryViewReferenceTest.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/MemoryViewReferenceTest.java index 20f04cdb7..6963bba79 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/MemoryViewReferenceTest.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/MemoryViewReferenceTest.java @@ -21,39 +21,40 @@ import static org.apache.geaflow.memory.MemoryGroupManger.DEFAULT; -import com.google.common.collect.Maps; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.memory.config.MemoryConfigKeys; import org.testng.Assert; import org.testng.annotations.Test; -public class MemoryViewReferenceTest extends MemoryReleaseTest { +import com.google.common.collect.Maps; +public class MemoryViewReferenceTest extends MemoryReleaseTest { - @Test - public void test() { + @Test + public void test() { - Map conf = Maps.newHashMap(); - conf.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "128"); + Map conf = Maps.newHashMap(); + conf.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "128"); - MemoryManager.build(new Configuration(conf)); + MemoryManager.build(new Configuration(conf)); - MemoryView memoryView = MemoryManager.getInstance().requireMemory(1023, DEFAULT); + MemoryView memoryView = MemoryManager.getInstance().requireMemory(1023, DEFAULT); - MemoryViewReference reference = new MemoryViewReference(memoryView); + MemoryViewReference reference = new MemoryViewReference(memoryView); - reference.incRef(); - reference.incRef(); - Assert.assertNotNull(reference.getMemoryView()); + reference.incRef(); + reference.incRef(); + Assert.assertNotNull(reference.getMemoryView()); - reference.decRef(); - reference.decRef(); + reference.decRef(); + reference.decRef(); - Assert.assertNotNull(reference.getMemoryView()); + Assert.assertNotNull(reference.getMemoryView()); - reference.decRef(); + reference.decRef(); - Assert.assertNull(reference.getMemoryView()); - } + Assert.assertNull(reference.getMemoryView()); + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/MemoryViewTest.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/MemoryViewTest.java index 84178ec13..ce4e71b36 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/MemoryViewTest.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/MemoryViewTest.java @@ -19,8 +19,6 @@ package org.apache.geaflow.memory; -import com.google.common.collect.Lists; -import com.google.common.collect.Maps; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -33,697 +31,693 @@ import java.util.Random; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.utils.MemoryUtils; import org.apache.geaflow.memory.config.MemoryConfigKeys; import org.testng.Assert; import org.testng.annotations.Test; -public class MemoryViewTest extends MemoryReleaseTest { +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; - private static final int PAGE_SIZE = 8192; - private static final int PAGE_SHIFTS = 11; - - @Test - public void testView() { - - Map conf = Maps.newHashMap(); - conf.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "32"); - conf.put(MemoryConfigKeys.MEMORY_MAX_ORDER.getKey(), "" + PAGE_SHIFTS); - conf.put(MemoryConfigKeys.MEMORY_PAGE_SIZE.getKey(), "" + PAGE_SIZE); - MemoryManager.build(new Configuration(conf)); - System.out.println(MemoryManager.getInstance().toString()); - - MemoryView view = MemoryManager.getInstance() - .requireMemory(1200, MemoryGroupManger.SHUFFLE); - byte[] a = new byte[999]; - Arrays.fill(a, (byte) 1); - byte[] c = new byte[100]; - ByteBuffer bf = ByteBuffer.wrap(new byte[1100]); - bf.put(a); - bf.put(c); - bf.put((byte) 3); - - MemoryViewWriter writer = view.getWriter(); - writer.write(a); - writer.write(c); - writer.write(3); - Assert.assertEquals(view.remain(), 948); - byte[] b = view.toArray(); - Assert.assertEquals(b, bf.array()); - view.close(); - - view = MemoryManager.getInstance().requireMemory(1600, MemoryGroupManger.SHUFFLE); - writer = view.getWriter(); - writer.write(a); - writer.write(c); - writer.write(3); - Assert.assertEquals(view.remain(), 1024 * 2 - 1100); - b = view.toArray(); - Assert.assertEquals(b, bf.array()); - view.close(); - - MemoryManager.getInstance().dispose(); - - conf.remove(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey()); - conf.put(MemoryConfigKeys.ON_HEAP_MEMORY_SIZE_MB.getKey(), "32"); - MemoryManager.build(new Configuration(conf)); - - view = MemoryManager.getInstance().requireMemory(1200, MemoryGroupManger.SHUFFLE); - writer = view.getWriter(); - writer.write(a); - writer.write(c); - writer.write(3); - Assert.assertEquals(view.remain(), 1024 * 2 - 1100); - b = view.toArray(); - Assert.assertEquals(b, bf.array()); - view.close(); - - view = MemoryManager.getInstance().requireMemory(1600, MemoryGroupManger.SHUFFLE); - writer = view.getWriter(); - writer.write(a); - writer.write(c); - writer.write(3); - Assert.assertEquals(view.remain(), 1024 * 2 - 1100); - b = view.toArray(); - Assert.assertEquals(b, bf.array()); - - MemoryView view2 = MemoryManager.getInstance() - .requireMemory(bf.array().length, MemoryGroupManger.SHUFFLE); - writer = view2.getWriter(); - writer.write(bf.array()); - - Assert.assertEquals(view.hashCode(), view2.hashCode()); - Assert.assertEquals(view, view2); - view.close(); - view2.close(); - - System.out.println(MemoryManager.getInstance().totalAllocateHeapMemory()); - System.out.println(MemoryManager.getInstance().totalAllocateOffHeapMemory()); - System.out.println(MemoryManager.getInstance().usedHeapMemory()); - System.out.println(MemoryManager.getInstance().usedOffHeapMemory()); - } +public class MemoryViewTest extends MemoryReleaseTest { - @Test(priority = 1) - public void release() throws Throwable { - Assert.assertThrows(IllegalArgumentException.class, () -> { - Map conf = Maps.newHashMap(); - conf.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "32"); - conf.put(MemoryConfigKeys.MEMORY_MAX_ORDER.getKey(), "" + PAGE_SHIFTS); - conf.put(MemoryConfigKeys.MEMORY_PAGE_SIZE.getKey(), "" + PAGE_SIZE); - MemoryManager.build(new Configuration(conf)); - - byte[] a = new byte[999]; - for (int i = 0; i < a.length; i++) { - a[i] = 1; - } - byte[] c = new byte[100]; - for (int i = 0; i < c.length; i++) { - c[i] = 0; - } - ByteBuffer bf = ByteBuffer.wrap(new byte[10099]); - bf.put(c); - bf.put(a); - - MemoryView view = MemoryManager.getInstance() - .requireMemory(bf.array().length, MemoryGroupManger.SHUFFLE); - view.getWriter().write(bf.array()); - view.close(); - - view.remain(); + private static final int PAGE_SIZE = 8192; + private static final int PAGE_SHIFTS = 11; + + @Test + public void testView() { + + Map conf = Maps.newHashMap(); + conf.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "32"); + conf.put(MemoryConfigKeys.MEMORY_MAX_ORDER.getKey(), "" + PAGE_SHIFTS); + conf.put(MemoryConfigKeys.MEMORY_PAGE_SIZE.getKey(), "" + PAGE_SIZE); + MemoryManager.build(new Configuration(conf)); + System.out.println(MemoryManager.getInstance().toString()); + + MemoryView view = MemoryManager.getInstance().requireMemory(1200, MemoryGroupManger.SHUFFLE); + byte[] a = new byte[999]; + Arrays.fill(a, (byte) 1); + byte[] c = new byte[100]; + ByteBuffer bf = ByteBuffer.wrap(new byte[1100]); + bf.put(a); + bf.put(c); + bf.put((byte) 3); + + MemoryViewWriter writer = view.getWriter(); + writer.write(a); + writer.write(c); + writer.write(3); + Assert.assertEquals(view.remain(), 948); + byte[] b = view.toArray(); + Assert.assertEquals(b, bf.array()); + view.close(); + + view = MemoryManager.getInstance().requireMemory(1600, MemoryGroupManger.SHUFFLE); + writer = view.getWriter(); + writer.write(a); + writer.write(c); + writer.write(3); + Assert.assertEquals(view.remain(), 1024 * 2 - 1100); + b = view.toArray(); + Assert.assertEquals(b, bf.array()); + view.close(); + + MemoryManager.getInstance().dispose(); + + conf.remove(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey()); + conf.put(MemoryConfigKeys.ON_HEAP_MEMORY_SIZE_MB.getKey(), "32"); + MemoryManager.build(new Configuration(conf)); + + view = MemoryManager.getInstance().requireMemory(1200, MemoryGroupManger.SHUFFLE); + writer = view.getWriter(); + writer.write(a); + writer.write(c); + writer.write(3); + Assert.assertEquals(view.remain(), 1024 * 2 - 1100); + b = view.toArray(); + Assert.assertEquals(b, bf.array()); + view.close(); + + view = MemoryManager.getInstance().requireMemory(1600, MemoryGroupManger.SHUFFLE); + writer = view.getWriter(); + writer.write(a); + writer.write(c); + writer.write(3); + Assert.assertEquals(view.remain(), 1024 * 2 - 1100); + b = view.toArray(); + Assert.assertEquals(b, bf.array()); + + MemoryView view2 = + MemoryManager.getInstance().requireMemory(bf.array().length, MemoryGroupManger.SHUFFLE); + writer = view2.getWriter(); + writer.write(bf.array()); + + Assert.assertEquals(view.hashCode(), view2.hashCode()); + Assert.assertEquals(view, view2); + view.close(); + view2.close(); + + System.out.println(MemoryManager.getInstance().totalAllocateHeapMemory()); + System.out.println(MemoryManager.getInstance().totalAllocateOffHeapMemory()); + System.out.println(MemoryManager.getInstance().usedHeapMemory()); + System.out.println(MemoryManager.getInstance().usedOffHeapMemory()); + } + + @Test(priority = 1) + public void release() throws Throwable { + Assert.assertThrows( + IllegalArgumentException.class, + () -> { + Map conf = Maps.newHashMap(); + conf.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "32"); + conf.put(MemoryConfigKeys.MEMORY_MAX_ORDER.getKey(), "" + PAGE_SHIFTS); + conf.put(MemoryConfigKeys.MEMORY_PAGE_SIZE.getKey(), "" + PAGE_SIZE); + MemoryManager.build(new Configuration(conf)); + + byte[] a = new byte[999]; + for (int i = 0; i < a.length; i++) { + a[i] = 1; + } + byte[] c = new byte[100]; + for (int i = 0; i < c.length; i++) { + c[i] = 0; + } + ByteBuffer bf = ByteBuffer.wrap(new byte[10099]); + bf.put(c); + bf.put(a); + + MemoryView view = + MemoryManager.getInstance() + .requireMemory(bf.array().length, MemoryGroupManger.SHUFFLE); + view.getWriter().write(bf.array()); + view.close(); + + view.remain(); }); - } + } - @Test(priority = 2) - public void testReadWrite() throws Throwable { - Map conf = Maps.newHashMap(); - conf.put(MemoryConfigKeys.ON_HEAP_MEMORY_SIZE_MB.getKey(), "32"); - conf.put(MemoryConfigKeys.MEMORY_MAX_ORDER.getKey(), "" + PAGE_SHIFTS); - conf.put(MemoryConfigKeys.MEMORY_PAGE_SIZE.getKey(), "" + PAGE_SIZE); - MemoryManager.build(new Configuration(conf)); - - MemoryView heapView = MemoryManager.getInstance() - .requireMemory(64, MemoryGroupManger.SHUFFLE); - String a = "MemoryView heapView = memoryManager.requireMemory(1025, MemoryMode.ON_HEAP"; - - byte[] content = a.getBytes(); - MemoryViewWriter writer = heapView.getWriter(64); - writer.write(content); - Assert.assertEquals(new String(heapView.toArray()), a); - - writer.write(content, 0, 50); - writer.write(content, 50, content.length - 50); - - Assert.assertEquals(new String(heapView.toArray()), a + a); - - MemoryViewReader reader = heapView.getReader(); - byte[] readContent = new byte[content.length * 2]; - for (int i = 0; i < readContent.length; i++) { - readContent[i] = reader.read(); - } - Assert.assertEquals(new String(readContent), a + a); - reader.reset(); - - readContent = new byte[content.length]; - reader.read(readContent, 0, 50); - reader.read(readContent, 50, content.length - 50); - Assert.assertEquals(new String(readContent), a); - - heapView = MemoryManager.getInstance().requireMemory(1025, MemoryGroupManger.SHUFFLE); - writer = heapView.getWriter(); - writer.write(new byte[101]); - writer.write(content); - - readContent = new byte[content.length]; - reader = heapView.getReader(); - reader.skip(101); - reader.read(readContent); - Assert.assertEquals(new String(readContent), a); - - heapView = MemoryManager.getInstance().requireMemory(62, MemoryGroupManger.SHUFFLE); - writer = heapView.getWriter(); - for (int i = 0; i < 5; i++) { - writer.write(content); - } - Assert.assertEquals(new String(heapView.toArray()), a + a + a + a + a); - } + @Test(priority = 2) + public void testReadWrite() throws Throwable { + Map conf = Maps.newHashMap(); + conf.put(MemoryConfigKeys.ON_HEAP_MEMORY_SIZE_MB.getKey(), "32"); + conf.put(MemoryConfigKeys.MEMORY_MAX_ORDER.getKey(), "" + PAGE_SHIFTS); + conf.put(MemoryConfigKeys.MEMORY_PAGE_SIZE.getKey(), "" + PAGE_SIZE); + MemoryManager.build(new Configuration(conf)); - @Test(priority = 3) - public void testReadWrite2() throws Throwable { - Map conf = Maps.newHashMap(); - conf.put(MemoryConfigKeys.ON_HEAP_MEMORY_SIZE_MB.getKey(), "32"); - conf.put(MemoryConfigKeys.MEMORY_MAX_ORDER.getKey(), "" + PAGE_SHIFTS); - conf.put(MemoryConfigKeys.MEMORY_PAGE_SIZE.getKey(), "" + PAGE_SIZE); - MemoryManager.build(new Configuration(conf)); - - byte[] a = new byte[64 + 128]; - for (int i = 0; i < a.length; i++) { - a[i] = (byte) i; - } - MemoryView heapView = MemoryManager.getInstance() - .requireMemory(64, MemoryGroupManger.SHUFFLE); - MemoryViewWriter writer = heapView.getWriter(); - writer.write(a); - - MemoryViewReader reader = heapView.getReader(); - byte[] b = new byte[a.length]; - reader.read(b); - Assert.assertEquals(reader.read(b, 64, 0), 0); - - Assert.assertEquals(a, b); - - writer.write(0); - Assert.assertEquals(reader.read(), 0); - - writer.write(0); - writer.write(0); - writer.write(1); - Assert.assertEquals(heapView.contentSize, 64 + 128 + 4); - reader.skip(1); - reader.skip(1); - Assert.assertEquals(reader.read(), 1); - - Assert.assertEquals(reader.skip(1), 0); - - byte[] t = new byte[4]; - Assert.assertEquals(reader.read(t, 0, t.length), -1); - } - - @Test - public void testReadAndWriteStream() throws IOException { - StringBuilder sb = new StringBuilder(); - for (int i = 0; i < 100000; i++) { - sb.append("hello world"); - } - - String sentence = sb.toString(); - byte[] bytes = sentence.getBytes(); - ByteArrayInputStream inputStream = new ByteArrayInputStream(bytes); - - Map config = new HashMap<>(); - config.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "32"); - MemoryManager memoryManager = MemoryManager.build(new Configuration(config)); + MemoryView heapView = MemoryManager.getInstance().requireMemory(64, MemoryGroupManger.SHUFFLE); + String a = "MemoryView heapView = memoryManager.requireMemory(1025, MemoryMode.ON_HEAP"; - MemoryView heapView = memoryManager.requireMemory(bytes.length, MemoryGroupManger.SHUFFLE); - MemoryViewWriter writer = heapView.getWriter(); - writer.write(inputStream, bytes.length); + byte[] content = a.getBytes(); + MemoryViewWriter writer = heapView.getWriter(64); + writer.write(content); + Assert.assertEquals(new String(heapView.toArray()), a); - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(bytes.length); - MemoryViewReader reader = heapView.getReader(); - reader.read(outputStream); + writer.write(content, 0, 50); + writer.write(content, 50, content.length - 50); - String res = outputStream.toString(); - Assert.assertEquals(res, sentence); + Assert.assertEquals(new String(heapView.toArray()), a + a); + MemoryViewReader reader = heapView.getReader(); + byte[] readContent = new byte[content.length * 2]; + for (int i = 0; i < readContent.length; i++) { + readContent[i] = reader.read(); } - - - @Test - public void testMultiThread() { - Map config = new HashMap<>(); - config.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "32"); - ExecutorService executors = Executors.newFixedThreadPool(2); - MemoryManager memoryManager = MemoryManager.build(new Configuration(config)); - - executors.submit(() -> { - StringBuilder sb = new StringBuilder(); - for (int i = 0; i < 100000; i++) { - sb.append("hello world"); - } - - String sentence = sb.toString(); - byte[] bytes = sentence.getBytes(); - ByteArrayInputStream inputStream = new ByteArrayInputStream(bytes); - - MemoryView heapView = memoryManager.requireMemory(bytes.length, - MemoryGroupManger.SHUFFLE); - MemoryViewWriter writer = heapView.getWriter(); - try { - writer.write(inputStream, bytes.length); - } catch (IOException e) { - e.printStackTrace(); - } - - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(bytes.length); - MemoryViewReader reader = heapView.getReader(); - try { - reader.read(outputStream); - } catch (IOException e) { - e.printStackTrace(); - } - - String res = outputStream.toString(); - //outputStream.size() - Assert.assertEquals(res, sentence); - }); + Assert.assertEquals(new String(readContent), a + a); + reader.reset(); + + readContent = new byte[content.length]; + reader.read(readContent, 0, 50); + reader.read(readContent, 50, content.length - 50); + Assert.assertEquals(new String(readContent), a); + + heapView = MemoryManager.getInstance().requireMemory(1025, MemoryGroupManger.SHUFFLE); + writer = heapView.getWriter(); + writer.write(new byte[101]); + writer.write(content); + + readContent = new byte[content.length]; + reader = heapView.getReader(); + reader.skip(101); + reader.read(readContent); + Assert.assertEquals(new String(readContent), a); + + heapView = MemoryManager.getInstance().requireMemory(62, MemoryGroupManger.SHUFFLE); + writer = heapView.getWriter(); + for (int i = 0; i < 5; i++) { + writer.write(content); } - - @Test - public void testInitBufs() { - Map config = new HashMap<>(); - config.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "320"); - MemoryManager manager = MemoryManager.build(new Configuration(config)); - - MemoryView view = manager.requireMemory(28948888, MemoryGroupManger.SHUFFLE); - - Assert.assertTrue(view.getBufList().size() == 2); - Assert.assertTrue(view.getBufList().get(0).getLength() == ESegmentSize.S16777216.size()); - Assert.assertTrue(view.getBufList().get(1).getLength() == ESegmentSize.S16777216.size()); - - MemoryView view1 = manager.requireMemory( - ESegmentSize.S16777216.size() + ESegmentSize.S16.size() + 2, MemoryGroupManger.SHUFFLE); - - Assert.assertTrue(view1.getBufList().size() == 2); - Assert.assertTrue(view1.getBufList().get(0).getLength() == ESegmentSize.S16777216.size()); - Assert.assertTrue(view1.getBufList().get(1).getLength() == ESegmentSize.S32.size()); + Assert.assertEquals(new String(heapView.toArray()), a + a + a + a + a); + } + + @Test(priority = 3) + public void testReadWrite2() throws Throwable { + Map conf = Maps.newHashMap(); + conf.put(MemoryConfigKeys.ON_HEAP_MEMORY_SIZE_MB.getKey(), "32"); + conf.put(MemoryConfigKeys.MEMORY_MAX_ORDER.getKey(), "" + PAGE_SHIFTS); + conf.put(MemoryConfigKeys.MEMORY_PAGE_SIZE.getKey(), "" + PAGE_SIZE); + MemoryManager.build(new Configuration(conf)); + + byte[] a = new byte[64 + 128]; + for (int i = 0; i < a.length; i++) { + a[i] = (byte) i; + } + MemoryView heapView = MemoryManager.getInstance().requireMemory(64, MemoryGroupManger.SHUFFLE); + MemoryViewWriter writer = heapView.getWriter(); + writer.write(a); + + MemoryViewReader reader = heapView.getReader(); + byte[] b = new byte[a.length]; + reader.read(b); + Assert.assertEquals(reader.read(b, 64, 0), 0); + + Assert.assertEquals(a, b); + + writer.write(0); + Assert.assertEquals(reader.read(), 0); + + writer.write(0); + writer.write(0); + writer.write(1); + Assert.assertEquals(heapView.contentSize, 64 + 128 + 4); + reader.skip(1); + reader.skip(1); + Assert.assertEquals(reader.read(), 1); + + Assert.assertEquals(reader.skip(1), 0); + + byte[] t = new byte[4]; + Assert.assertEquals(reader.read(t, 0, t.length), -1); + } + + @Test + public void testReadAndWriteStream() throws IOException { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 100000; i++) { + sb.append("hello world"); } - - private void testReuse() { - - StringBuilder sb = new StringBuilder(); - for (int i = 0; i < 10000; i++) { + String sentence = sb.toString(); + byte[] bytes = sentence.getBytes(); + ByteArrayInputStream inputStream = new ByteArrayInputStream(bytes); + + Map config = new HashMap<>(); + config.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "32"); + MemoryManager memoryManager = MemoryManager.build(new Configuration(config)); + + MemoryView heapView = memoryManager.requireMemory(bytes.length, MemoryGroupManger.SHUFFLE); + MemoryViewWriter writer = heapView.getWriter(); + writer.write(inputStream, bytes.length); + + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(bytes.length); + MemoryViewReader reader = heapView.getReader(); + reader.read(outputStream); + + String res = outputStream.toString(); + Assert.assertEquals(res, sentence); + } + + @Test + public void testMultiThread() { + Map config = new HashMap<>(); + config.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "32"); + ExecutorService executors = Executors.newFixedThreadPool(2); + MemoryManager memoryManager = MemoryManager.build(new Configuration(config)); + + executors.submit( + () -> { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 100000; i++) { sb.append("hello world"); - } + } + + String sentence = sb.toString(); + byte[] bytes = sentence.getBytes(); + ByteArrayInputStream inputStream = new ByteArrayInputStream(bytes); + + MemoryView heapView = + memoryManager.requireMemory(bytes.length, MemoryGroupManger.SHUFFLE); + MemoryViewWriter writer = heapView.getWriter(); + try { + writer.write(inputStream, bytes.length); + } catch (IOException e) { + e.printStackTrace(); + } + + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(bytes.length); + MemoryViewReader reader = heapView.getReader(); + try { + reader.read(outputStream); + } catch (IOException e) { + e.printStackTrace(); + } + + String res = outputStream.toString(); + // outputStream.size() + Assert.assertEquals(res, sentence); + }); + } - byte[] bytes = sb.toString().getBytes(); - MemoryView view = MemoryManager.getInstance() - .requireMemory(bytes.length, MemoryGroupManger.SHUFFLE); + @Test + public void testInitBufs() { + Map config = new HashMap<>(); + config.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "320"); + MemoryManager manager = MemoryManager.build(new Configuration(config)); - view.getWriter().write(bytes); - view.getWriter().write(bytes); - view.getWriter().write(bytes); + MemoryView view = manager.requireMemory(28948888, MemoryGroupManger.SHUFFLE); - byte[] bytes1 = new byte[bytes.length]; - view.getReader().read(bytes1); - Assert.assertEquals(sb.toString(), new String(bytes1)); + Assert.assertTrue(view.getBufList().size() == 2); + Assert.assertTrue(view.getBufList().get(0).getLength() == ESegmentSize.S16777216.size()); + Assert.assertTrue(view.getBufList().get(1).getLength() == ESegmentSize.S16777216.size()); - view.reset(); - - Random random = new Random(2000); + MemoryView view1 = + manager.requireMemory( + ESegmentSize.S16777216.size() + ESegmentSize.S16.size() + 2, MemoryGroupManger.SHUFFLE); - sb = new StringBuilder(); - for (int i = 0; i < random.nextInt(2000); i++) { - sb.append("hello world"); - } + Assert.assertTrue(view1.getBufList().size() == 2); + Assert.assertTrue(view1.getBufList().get(0).getLength() == ESegmentSize.S16777216.size()); + Assert.assertTrue(view1.getBufList().get(1).getLength() == ESegmentSize.S32.size()); + } - bytes = sb.toString().getBytes(); + private void testReuse() { - view.getWriter().write(bytes); - view.getWriter().write(bytes); - view.getWriter().write(bytes); + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 10000; i++) { + sb.append("hello world"); + } - bytes1 = new byte[bytes.length]; - view.getReader().read(bytes1); - Assert.assertEquals(sb.toString(), new String(bytes1)); + byte[] bytes = sb.toString().getBytes(); + MemoryView view = + MemoryManager.getInstance().requireMemory(bytes.length, MemoryGroupManger.SHUFFLE); - view.reset(); + view.getWriter().write(bytes); + view.getWriter().write(bytes); + view.getWriter().write(bytes); - sb = new StringBuilder(); - for (int i = 0; i < random.nextInt(2000); i++) { - sb.append("hello world"); - } + byte[] bytes1 = new byte[bytes.length]; + view.getReader().read(bytes1); + Assert.assertEquals(sb.toString(), new String(bytes1)); - bytes = sb.toString().getBytes(); + view.reset(); - view.getWriter().write(bytes); - view.getWriter().write(bytes); - view.getWriter().write(bytes); + Random random = new Random(2000); - bytes1 = new byte[bytes.length + 10]; + sb = new StringBuilder(); + for (int i = 0; i < random.nextInt(2000); i++) { + sb.append("hello world"); + } - MemoryViewReader reader = view.getReader(); - reader.skip(bytes.length * 2); - reader.read(bytes1); + bytes = sb.toString().getBytes(); - byte[] bytes2 = Arrays.copyOf(bytes1, bytes.length); - Assert.assertEquals(sb.toString(), new String(bytes2)); + view.getWriter().write(bytes); + view.getWriter().write(bytes); + view.getWriter().write(bytes); - byte[] bytes3 = Arrays.copyOfRange(bytes1, bytes.length, bytes.length + 10); + bytes1 = new byte[bytes.length]; + view.getReader().read(bytes1); + Assert.assertEquals(sb.toString(), new String(bytes1)); - Assert.assertEquals(bytes3, new byte[10]); + view.reset(); - view.close(); + sb = new StringBuilder(); + for (int i = 0; i < random.nextInt(2000); i++) { + sb.append("hello world"); } - private void testReNew() { + bytes = sb.toString().getBytes(); - StringBuilder sb = new StringBuilder(); - for (int i = 0; i < 10000; i++) { - sb.append("hello world"); - } + view.getWriter().write(bytes); + view.getWriter().write(bytes); + view.getWriter().write(bytes); - byte[] bytes = sb.toString().getBytes(); - MemoryView view1 = MemoryManager.getInstance() - .requireMemory(bytes.length, MemoryGroupManger.SHUFFLE); + bytes1 = new byte[bytes.length + 10]; - view1.getWriter().write(bytes); - view1.getWriter().write(bytes); - view1.getWriter().write(bytes); + MemoryViewReader reader = view.getReader(); + reader.skip(bytes.length * 2); + reader.read(bytes1); - byte[] bytes1 = new byte[bytes.length]; - view1.getReader().read(bytes1); - Assert.assertEquals(sb.toString(), new String(bytes1)); + byte[] bytes2 = Arrays.copyOf(bytes1, bytes.length); + Assert.assertEquals(sb.toString(), new String(bytes2)); - Random random = new Random(2000); + byte[] bytes3 = Arrays.copyOfRange(bytes1, bytes.length, bytes.length + 10); - sb = new StringBuilder(); - for (int i = 0; i < random.nextInt(2000); i++) { - sb.append("hello world"); - } + Assert.assertEquals(bytes3, new byte[10]); - bytes = sb.toString().getBytes(); - MemoryView view = MemoryManager.getInstance() - .requireMemory(bytes.length, MemoryGroupManger.SHUFFLE); + view.close(); + } - view.getWriter().write(bytes); - view.getWriter().write(bytes); - view.getWriter().write(bytes); - - bytes1 = new byte[bytes.length]; - view.getReader().read(bytes1); - Assert.assertEquals(sb.toString(), new String(bytes1)); - - sb = new StringBuilder(); - for (int i = 0; i < random.nextInt(2000); i++) { - sb.append("hello world"); - } + private void testReNew() { - bytes = sb.toString().getBytes(); - MemoryView view2 = MemoryManager.getInstance() - .requireMemory(bytes.length, MemoryGroupManger.SHUFFLE); - - view2.getWriter().write(bytes); - view2.getWriter().write(bytes); - view2.getWriter().write(bytes); - - bytes1 = new byte[bytes.length + 10]; + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 10000; i++) { + sb.append("hello world"); + } - MemoryViewReader reader = view2.getReader(); - reader.skip(bytes.length * 2); - reader.read(bytes1); + byte[] bytes = sb.toString().getBytes(); + MemoryView view1 = + MemoryManager.getInstance().requireMemory(bytes.length, MemoryGroupManger.SHUFFLE); - byte[] bytes2 = Arrays.copyOf(bytes1, bytes.length); - Assert.assertEquals(sb.toString(), new String(bytes2)); + view1.getWriter().write(bytes); + view1.getWriter().write(bytes); + view1.getWriter().write(bytes); - byte[] bytes3 = Arrays.copyOfRange(bytes1, bytes.length, bytes.length + 10); + byte[] bytes1 = new byte[bytes.length]; + view1.getReader().read(bytes1); + Assert.assertEquals(sb.toString(), new String(bytes1)); - Assert.assertEquals(bytes3, new byte[10]); + Random random = new Random(2000); - view1.close(); - view.close(); - view2.close(); + sb = new StringBuilder(); + for (int i = 0; i < random.nextInt(2000); i++) { + sb.append("hello world"); } - @Test - public void testReuseWithLoop() { - Map config = new HashMap<>(); - config.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "320"); - MemoryManager manager = MemoryManager.build(new Configuration(config)); + bytes = sb.toString().getBytes(); + MemoryView view = + MemoryManager.getInstance().requireMemory(bytes.length, MemoryGroupManger.SHUFFLE); - long sum = 0; - long start = System.nanoTime(); - for (int i = 0; i < 1000; i++) { - testReuse(); - } + view.getWriter().write(bytes); + view.getWriter().write(bytes); + view.getWriter().write(bytes); - sum += (System.nanoTime() - start); + bytes1 = new byte[bytes.length]; + view.getReader().read(bytes1); + Assert.assertEquals(sb.toString(), new String(bytes1)); - System.out.println("reuse:" + sum); + sb = new StringBuilder(); + for (int i = 0; i < random.nextInt(2000); i++) { + sb.append("hello world"); } - @Test - public void testReNewWithLoop() { - Map config = new HashMap<>(); - config.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "320"); - MemoryManager manager = MemoryManager.build(new Configuration(config)); - long sum = 0; - long start = System.nanoTime(); - for (int i = 0; i < 1000; i++) { - testReNew(); - } - sum += (System.nanoTime() - start); + bytes = sb.toString().getBytes(); + MemoryView view2 = + MemoryManager.getInstance().requireMemory(bytes.length, MemoryGroupManger.SHUFFLE); - System.out.println("renew:" + sum); - - } + view2.getWriter().write(bytes); + view2.getWriter().write(bytes); + view2.getWriter().write(bytes); - @Test - public void testWriteBuffer() throws IOException { + bytes1 = new byte[bytes.length + 10]; - Map config = new HashMap<>(); - config.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "320"); - MemoryManager manager = MemoryManager.build(new Configuration(config)); - MemoryView view = manager.requireMemory(2049, MemoryGroupManger.SHUFFLE); + MemoryViewReader reader = view2.getReader(); + reader.skip(bytes.length * 2); + reader.read(bytes1); - ByteBuffer buffer = ByteBuffer.allocateDirect(1025); - byte[] a = new byte[999]; - for (int i = 0; i < a.length; i++) { - a[i] = 1; - } - buffer.put(a); - buffer.flip(); + byte[] bytes2 = Arrays.copyOf(bytes1, bytes.length); + Assert.assertEquals(sb.toString(), new String(bytes2)); - view.getWriter().write(buffer, a.length); + byte[] bytes3 = Arrays.copyOfRange(bytes1, bytes.length, bytes.length + 10); - byte[] read = new byte[a.length]; - view.getReader().read(read); - Assert.assertEquals(a, read); + Assert.assertEquals(bytes3, new byte[10]); - view.close(); + view1.close(); + view.close(); + view2.close(); + } - buffer.flip(); - view = manager.requireMemory(512, MemoryGroupManger.SHUFFLE); - view.getWriter(512).write(buffer, a.length); + @Test + public void testReuseWithLoop() { + Map config = new HashMap<>(); + config.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "320"); + MemoryManager manager = MemoryManager.build(new Configuration(config)); - read = new byte[a.length]; - view.getReader().read(read); - Assert.assertEquals(read, a); + long sum = 0; + long start = System.nanoTime(); + for (int i = 0; i < 1000; i++) { + testReuse(); } - @Test - public void testRollback() { - Map config = new HashMap<>(); - config.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "320"); - MemoryManager manager = MemoryManager.build(new Configuration(config)); - - MemoryView view1 = MemoryManager.getInstance() - .requireMemory(100, MemoryGroupManger.SHUFFLE); - - MemoryViewWriter viewWriter = view1.getWriter(100); - - byte[] bytes = new byte[100]; - Arrays.fill(bytes, (byte) 1); - viewWriter.write(bytes); - viewWriter.position(100); - - bytes = new byte[101]; - viewWriter.write(bytes); + sum += (System.nanoTime() - start); - viewWriter.position(101); + System.out.println("reuse:" + sum); + } - Assert.assertTrue(view1.contentSize() == 101); - view1.getReader().read(bytes); - Assert.assertEquals(bytes[100], 0); + @Test + public void testReNewWithLoop() { + Map config = new HashMap<>(); + config.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "320"); + MemoryManager manager = MemoryManager.build(new Configuration(config)); + long sum = 0; + long start = System.nanoTime(); + for (int i = 0; i < 1000; i++) { + testReNew(); + } + sum += (System.nanoTime() - start); - viewWriter.write(bytes); - Assert.assertTrue(view1.contentSize() == 202); - bytes = new byte[202]; - view1.getReader().read(bytes); - Assert.assertEquals(bytes[100], 0); - Assert.assertEquals(bytes[201], 0); + System.out.println("renew:" + sum); + } - viewWriter.position(200); - Assert.assertTrue(view1.contentSize() == 200); - bytes = new byte[201]; - view1.getReader().read(bytes); - Assert.assertEquals(bytes[199], 1); - Assert.assertEquals(bytes[200], 0); + @Test + public void testWriteBuffer() throws IOException { - viewWriter.position(0); - Assert.assertTrue(view1.contentSize() == 0); + Map config = new HashMap<>(); + config.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "320"); + MemoryManager manager = MemoryManager.build(new Configuration(config)); + MemoryView view = manager.requireMemory(2049, MemoryGroupManger.SHUFFLE); + ByteBuffer buffer = ByteBuffer.allocateDirect(1025); + byte[] a = new byte[999]; + for (int i = 0; i < a.length; i++) { + a[i] = 1; } + buffer.put(a); + buffer.flip(); + + view.getWriter().write(buffer, a.length); + + byte[] read = new byte[a.length]; + view.getReader().read(read); + Assert.assertEquals(a, read); + + view.close(); + + buffer.flip(); + view = manager.requireMemory(512, MemoryGroupManger.SHUFFLE); + view.getWriter(512).write(buffer, a.length); + + read = new byte[a.length]; + view.getReader().read(read); + Assert.assertEquals(read, a); + } + + @Test + public void testRollback() { + Map config = new HashMap<>(); + config.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "320"); + MemoryManager manager = MemoryManager.build(new Configuration(config)); + + MemoryView view1 = MemoryManager.getInstance().requireMemory(100, MemoryGroupManger.SHUFFLE); + + MemoryViewWriter viewWriter = view1.getWriter(100); + + byte[] bytes = new byte[100]; + Arrays.fill(bytes, (byte) 1); + viewWriter.write(bytes); + viewWriter.position(100); + + bytes = new byte[101]; + viewWriter.write(bytes); + + viewWriter.position(101); + + Assert.assertTrue(view1.contentSize() == 101); + view1.getReader().read(bytes); + Assert.assertEquals(bytes[100], 0); + + viewWriter.write(bytes); + Assert.assertTrue(view1.contentSize() == 202); + bytes = new byte[202]; + view1.getReader().read(bytes); + Assert.assertEquals(bytes[100], 0); + Assert.assertEquals(bytes[201], 0); + + viewWriter.position(200); + Assert.assertTrue(view1.contentSize() == 200); + bytes = new byte[201]; + view1.getReader().read(bytes); + Assert.assertEquals(bytes[199], 1); + Assert.assertEquals(bytes[200], 0); + + viewWriter.position(0); + Assert.assertTrue(view1.contentSize() == 0); + } + + @Test + public void testAllocate() throws IOException { + Map config = new HashMap<>(); + config.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "1"); + config.put(MemoryConfigKeys.MAX_DIRECT_MEMORY_SIZE.getKey(), "" + 1024 * 1024); + config.put(MemoryConfigKeys.MEMORY_POOL_SIZE.getKey(), "1"); + MemoryManager manager = MemoryManager.build(new Configuration(config)); + MemoryView view = manager.requireMemory(1024 * 1024, MemoryGroupManger.SHUFFLE); + int pos = 0; + for (int i = 0; i < 1000000; i++) { + try { + int len = new Random(1000).nextInt(1000); + byte[] bytes = new byte[len]; + Arrays.fill(bytes, (byte) 1); + pos = view.contentSize; + view.getWriter().write(bytes); + } catch (Throwable t) { + Assert.assertTrue(view.contentSize() == 16 * 1024 * 1024); + view.getWriter().position(pos); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(16 * 1024 * 1024); + MemoryViewReader reader = view.getReader(); + int len = reader.read(outputStream); + Assert.assertEquals(len, view.contentSize); + System.out.println(len); + System.out.println(view.contentSize()); - @Test - public void testAllocate() throws IOException { - Map config = new HashMap<>(); - config.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "1"); - config.put(MemoryConfigKeys.MAX_DIRECT_MEMORY_SIZE.getKey(), "" + 1024 * 1024); - config.put(MemoryConfigKeys.MEMORY_POOL_SIZE.getKey(), "1"); - MemoryManager manager = MemoryManager.build(new Configuration(config)); - MemoryView view = manager.requireMemory(1024 * 1024, MemoryGroupManger.SHUFFLE); - int pos = 0; - for (int i = 0; i < 1000000; i++) { - try { - int len = new Random(1000).nextInt(1000); - byte[] bytes = new byte[len]; - Arrays.fill(bytes, (byte) 1); - pos = view.contentSize; - view.getWriter().write(bytes); - } catch (Throwable t) { - Assert.assertTrue(view.contentSize() == 16 * 1024 * 1024); - view.getWriter().position(pos); - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(16 * 1024 * 1024); - MemoryViewReader reader = view.getReader(); - int len = reader.read(outputStream); - Assert.assertEquals(len, view.contentSize); - System.out.println(len); - System.out.println(view.contentSize()); - - Assert.assertEquals(outputStream.size(), len); - byte[] bytes = new byte[len]; - Arrays.fill(bytes, (byte) 1); - Assert.assertEquals(outputStream.toByteArray(), bytes); - break; - } - } - + Assert.assertEquals(outputStream.size(), len); + byte[] bytes = new byte[len]; + Arrays.fill(bytes, (byte) 1); + Assert.assertEquals(outputStream.toByteArray(), bytes); + break; + } + } + } + + @Test + public void testWriteOutOfBounds() throws NoSuchFieldException, IllegalAccessException { + + MemoryView view = new MemoryView(Lists.newArrayList(new ByteBuf()), MemoryGroupManger.SHUFFLE); + Field field = view.getClass().getDeclaredField("contentSize"); + field.setAccessible(true); + field.setInt(view, Integer.MAX_VALUE); + boolean hasExp = false; + try { + view.getWriter().write(1); + } catch (IllegalArgumentException e) { + hasExp = true; } - @Test - public void testWriteOutOfBounds() throws NoSuchFieldException, IllegalAccessException { - - MemoryView view = new MemoryView(Lists.newArrayList(new ByteBuf()), - MemoryGroupManger.SHUFFLE); - Field field = view.getClass().getDeclaredField("contentSize"); - field.setAccessible(true); - field.setInt(view, Integer.MAX_VALUE); - boolean hasExp = false; - try { - view.getWriter().write(1); - } catch (IllegalArgumentException e) { - hasExp = true; - } - - Assert.assertTrue(hasExp); - hasExp = false; - - try { - view.getWriter().write(new byte[10]); - } catch (IllegalArgumentException e) { - hasExp = true; - } - - Assert.assertTrue(hasExp); + Assert.assertTrue(hasExp); + hasExp = false; - hasExp = false; + try { + view.getWriter().write(new byte[10]); + } catch (IllegalArgumentException e) { + hasExp = true; + } - try { - view.getWriter().write(new byte[10], 0, 9); - } catch (IllegalArgumentException e) { - hasExp = true; - } + Assert.assertTrue(hasExp); - Assert.assertTrue(hasExp); + hasExp = false; - hasExp = false; + try { + view.getWriter().write(new byte[10], 0, 9); + } catch (IllegalArgumentException e) { + hasExp = true; + } - try { - view.getWriter().write(ByteBuffer.wrap(new byte[10]), 10); - } catch (IllegalArgumentException | IOException e) { - hasExp = true; - } + Assert.assertTrue(hasExp); - Assert.assertTrue(hasExp); + hasExp = false; - hasExp = false; + try { + view.getWriter().write(ByteBuffer.wrap(new byte[10]), 10); + } catch (IllegalArgumentException | IOException e) { + hasExp = true; + } - try { - ByteArrayInputStream inputStream = new ByteArrayInputStream(new byte[10]); - view.getWriter().write(inputStream, 10); - } catch (IllegalArgumentException | IOException e) { - hasExp = true; - } + Assert.assertTrue(hasExp); - Assert.assertTrue(hasExp); + hasExp = false; + try { + ByteArrayInputStream inputStream = new ByteArrayInputStream(new byte[10]); + view.getWriter().write(inputStream, 10); + } catch (IllegalArgumentException | IOException e) { + hasExp = true; } - @Test - public void testReader() throws IOException { - Map config = new HashMap<>(); - config.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "32"); - config.put(MemoryConfigKeys.MAX_DIRECT_MEMORY_SIZE.getKey(), "" + 32 * 1024 * 1024); - config.put(MemoryConfigKeys.MEMORY_POOL_SIZE.getKey(), "1"); - MemoryManager manager = MemoryManager.build(new Configuration(config)); + Assert.assertTrue(hasExp); + } - MemoryView view = manager.requireMemory((int) MemoryUtils.KB, MemoryGroupManger.DEFAULT); + @Test + public void testReader() throws IOException { + Map config = new HashMap<>(); + config.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "32"); + config.put(MemoryConfigKeys.MAX_DIRECT_MEMORY_SIZE.getKey(), "" + 32 * 1024 * 1024); + config.put(MemoryConfigKeys.MEMORY_POOL_SIZE.getKey(), "1"); + MemoryManager manager = MemoryManager.build(new Configuration(config)); - byte[] bytes = new byte[1024]; - Arrays.fill(bytes, (byte) 1); - view.getWriter().write(bytes); - view.getWriter().write(bytes); - view.getWriter().write(bytes); - bytes = new byte[128]; - Arrays.fill(bytes, (byte) 0); - view.getWriter().write(bytes); + MemoryView view = manager.requireMemory((int) MemoryUtils.KB, MemoryGroupManger.DEFAULT); - MemoryViewReader reader = view.getReader(); - Assert.assertEquals(reader.readPos(), 0); + byte[] bytes = new byte[1024]; + Arrays.fill(bytes, (byte) 1); + view.getWriter().write(bytes); + view.getWriter().write(bytes); + view.getWriter().write(bytes); + bytes = new byte[128]; + Arrays.fill(bytes, (byte) 0); + view.getWriter().write(bytes); - reader.read(new byte[1025]); - Assert.assertEquals(reader.readPos(), 1025); + MemoryViewReader reader = view.getReader(); + Assert.assertEquals(reader.readPos(), 0); - reader.read(); - Assert.assertEquals(reader.readPos(), 1026); + reader.read(new byte[1025]); + Assert.assertEquals(reader.readPos(), 1025); - reader.skip(1024); - Assert.assertEquals(reader.readPos(), 1026 + 1024); + reader.read(); + Assert.assertEquals(reader.readPos(), 1026); - reader.read(new ByteArrayOutputStream()); + reader.skip(1024); + Assert.assertEquals(reader.readPos(), 1026 + 1024); - Assert.assertEquals(reader.readPos(), 3 * 1024 + 128); + reader.read(new ByteArrayOutputStream()); - Assert.assertFalse(reader.hasNext()); + Assert.assertEquals(reader.readPos(), 3 * 1024 + 128); - boolean hasExp = false; - try { - reader.read(); - } catch (BufferUnderflowException e) { - hasExp = true; - } - Assert.assertTrue(hasExp); + Assert.assertFalse(reader.hasNext()); + boolean hasExp = false; + try { + reader.read(); + } catch (BufferUnderflowException e) { + hasExp = true; } + Assert.assertTrue(hasExp); + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/ThreadLocalCacheTest.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/ThreadLocalCacheTest.java index dac50d380..6eced20ce 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/ThreadLocalCacheTest.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/ThreadLocalCacheTest.java @@ -22,168 +22,173 @@ import static org.apache.geaflow.memory.MemoryGroupManger.DEFAULT; import static org.apache.geaflow.memory.MemoryGroupManger.SHUFFLE; -import com.google.common.collect.Maps; import java.util.List; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.memory.config.MemoryConfigKeys; import org.testng.Assert; import org.testng.annotations.Test; +import com.google.common.collect.Maps; + public class ThreadLocalCacheTest extends MemoryReleaseTest { - @Test - public void test1() { + @Test + public void test1() { - Map conf = Maps.newHashMap(); - conf.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "96"); - conf.put(MemoryConfigKeys.MAX_DIRECT_MEMORY_SIZE.getKey(), "" + 96 * 1024 * 1024); - conf.put(MemoryConfigKeys.MEMORY_PAGE_SIZE.getKey(), "" + 8 * 1024); - conf.put(MemoryConfigKeys.MEMORY_MAX_ORDER.getKey(), "11"); - conf.put(MemoryConfigKeys.MEMORY_POOL_SIZE.getKey(), "6"); + Map conf = Maps.newHashMap(); + conf.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "96"); + conf.put(MemoryConfigKeys.MAX_DIRECT_MEMORY_SIZE.getKey(), "" + 96 * 1024 * 1024); + conf.put(MemoryConfigKeys.MEMORY_PAGE_SIZE.getKey(), "" + 8 * 1024); + conf.put(MemoryConfigKeys.MEMORY_MAX_ORDER.getKey(), "11"); + conf.put(MemoryConfigKeys.MEMORY_POOL_SIZE.getKey(), "6"); - MemoryManager.build(new Configuration(conf)); + MemoryManager.build(new Configuration(conf)); - List byteBufs = MemoryManager.getInstance() + List byteBufs = + MemoryManager.getInstance() .requireBufs(ESegmentSize.largest().size(), MemoryGroupManger.DEFAULT); - Assert.assertEquals(byteBufs.size(), 1); + Assert.assertEquals(byteBufs.size(), 1); - AbstractMemoryPool pool = byteBufs.get(0).chunk.pool; + AbstractMemoryPool pool = byteBufs.get(0).chunk.pool; - byteBufs = MemoryManager.getInstance() + byteBufs = + MemoryManager.getInstance() .requireBufs(ESegmentSize.largest().size(), MemoryGroupManger.DEFAULT); - Assert.assertEquals(byteBufs.size(), 1); + Assert.assertEquals(byteBufs.size(), 1); - AbstractMemoryPool pool2 = byteBufs.get(0).chunk.pool; - Assert.assertNotSame(pool, pool2); + AbstractMemoryPool pool2 = byteBufs.get(0).chunk.pool; + Assert.assertNotSame(pool, pool2); - byteBufs = MemoryManager.getInstance() + byteBufs = + MemoryManager.getInstance() .requireBufs(ESegmentSize.smallest().size(), MemoryGroupManger.DEFAULT); - Assert.assertEquals(byteBufs.size(), 1); + Assert.assertEquals(byteBufs.size(), 1); - AbstractMemoryPool pool3 = byteBufs.get(0).chunk.pool; - Assert.assertNotSame(pool3, pool2); + AbstractMemoryPool pool3 = byteBufs.get(0).chunk.pool; + Assert.assertNotSame(pool3, pool2); - byteBufs = MemoryManager.getInstance() + byteBufs = + MemoryManager.getInstance() .requireBufs(ESegmentSize.smallest().size(), MemoryGroupManger.DEFAULT); - Assert.assertEquals(byteBufs.size(), 1); + Assert.assertEquals(byteBufs.size(), 1); - AbstractMemoryPool pool4 = byteBufs.get(0).chunk.pool; - Assert.assertEquals(pool3, pool4); + AbstractMemoryPool pool4 = byteBufs.get(0).chunk.pool; + Assert.assertEquals(pool3, pool4); - byteBufs = MemoryManager.getInstance() - .requireBufs(ESegmentSize.largest().size() + ESegmentSize.smallest().size(), + byteBufs = + MemoryManager.getInstance() + .requireBufs( + ESegmentSize.largest().size() + ESegmentSize.smallest().size(), MemoryGroupManger.DEFAULT); - Assert.assertEquals(byteBufs.size(), 2); - - Assert.assertNotSame(byteBufs.get(0).chunk.pool, byteBufs.get(1).chunk.pool); - - try { - byteBufs = MemoryManager.getInstance() - .requireBufs(ESegmentSize.largest().size() + ESegmentSize.smallest().size(), - MemoryGroupManger.DEFAULT); - } catch (Throwable t) { - Assert.assertTrue(t instanceof OutOfMemoryError); - } - + Assert.assertEquals(byteBufs.size(), 2); + + Assert.assertNotSame(byteBufs.get(0).chunk.pool, byteBufs.get(1).chunk.pool); + + try { + byteBufs = + MemoryManager.getInstance() + .requireBufs( + ESegmentSize.largest().size() + ESegmentSize.smallest().size(), + MemoryGroupManger.DEFAULT); + } catch (Throwable t) { + Assert.assertTrue(t instanceof OutOfMemoryError); } + } - @Test - public void test2() { - - Map conf = Maps.newHashMap(); - conf.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "48"); - conf.put(MemoryConfigKeys.MAX_DIRECT_MEMORY_SIZE.getKey(), "" + 48 * 1024 * 1024); - conf.put(MemoryConfigKeys.MEMORY_PAGE_SIZE.getKey(), "" + 8 * 1024); - conf.put(MemoryConfigKeys.MEMORY_MAX_ORDER.getKey(), "11"); - conf.put(MemoryConfigKeys.MEMORY_POOL_SIZE.getKey(), "3"); + @Test + public void test2() { - MemoryManager.build(new Configuration(conf)); + Map conf = Maps.newHashMap(); + conf.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "48"); + conf.put(MemoryConfigKeys.MAX_DIRECT_MEMORY_SIZE.getKey(), "" + 48 * 1024 * 1024); + conf.put(MemoryConfigKeys.MEMORY_PAGE_SIZE.getKey(), "" + 8 * 1024); + conf.put(MemoryConfigKeys.MEMORY_MAX_ORDER.getKey(), "11"); + conf.put(MemoryConfigKeys.MEMORY_POOL_SIZE.getKey(), "3"); - ByteBuf byteBuf = MemoryManager.getInstance() - .requireBuf(ESegmentSize.largest().size(), SHUFFLE); + MemoryManager.build(new Configuration(conf)); - AbstractMemoryPool pool = byteBuf.chunk.pool; + ByteBuf byteBuf = + MemoryManager.getInstance().requireBuf(ESegmentSize.largest().size(), SHUFFLE); - byteBuf = MemoryManager.getInstance().requireBuf(ESegmentSize.largest().size(), SHUFFLE); + AbstractMemoryPool pool = byteBuf.chunk.pool; - AbstractMemoryPool pool2 = byteBuf.chunk.pool; - Assert.assertNotSame(pool, pool2); + byteBuf = MemoryManager.getInstance().requireBuf(ESegmentSize.largest().size(), SHUFFLE); - byteBuf = MemoryManager.getInstance().requireBuf(ESegmentSize.smallest().size(), SHUFFLE); + AbstractMemoryPool pool2 = byteBuf.chunk.pool; + Assert.assertNotSame(pool, pool2); - AbstractMemoryPool pool3 = byteBuf.chunk.pool; - Assert.assertNotSame(pool3, pool2); + byteBuf = MemoryManager.getInstance().requireBuf(ESegmentSize.smallest().size(), SHUFFLE); - byteBuf = MemoryManager.getInstance().requireBuf(ESegmentSize.smallest().size(), SHUFFLE); + AbstractMemoryPool pool3 = byteBuf.chunk.pool; + Assert.assertNotSame(pool3, pool2); - AbstractMemoryPool pool4 = byteBuf.chunk.pool; - Assert.assertEquals(pool3, pool4); + byteBuf = MemoryManager.getInstance().requireBuf(ESegmentSize.smallest().size(), SHUFFLE); - try { - byteBuf = MemoryManager.getInstance() - .requireBuf(ESegmentSize.largest().size(), SHUFFLE); - } catch (Throwable t) { - Assert.assertTrue(t instanceof OutOfMemoryError); - } + AbstractMemoryPool pool4 = byteBuf.chunk.pool; + Assert.assertEquals(pool3, pool4); + try { + byteBuf = MemoryManager.getInstance().requireBuf(ESegmentSize.largest().size(), SHUFFLE); + } catch (Throwable t) { + Assert.assertTrue(t instanceof OutOfMemoryError); } + } - @Test - public void test3() { - Map conf = Maps.newHashMap(); - conf.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "16"); - conf.put(MemoryConfigKeys.MAX_DIRECT_MEMORY_SIZE.getKey(), "" + 16 * 1024 * 1024); - conf.put(MemoryConfigKeys.MEMORY_PAGE_SIZE.getKey(), "" + 8 * 1024); - conf.put(MemoryConfigKeys.MEMORY_MAX_ORDER.getKey(), "11"); - conf.put(MemoryConfigKeys.MEMORY_POOL_SIZE.getKey(), "1"); - conf.put(MemoryConfigKeys.MEMORY_AUTO_ADAPT_ENABLE.getKey(), "true"); - - MemoryManager.build(new Configuration(conf)); + @Test + public void test3() { + Map conf = Maps.newHashMap(); + conf.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "16"); + conf.put(MemoryConfigKeys.MAX_DIRECT_MEMORY_SIZE.getKey(), "" + 16 * 1024 * 1024); + conf.put(MemoryConfigKeys.MEMORY_PAGE_SIZE.getKey(), "" + 8 * 1024); + conf.put(MemoryConfigKeys.MEMORY_MAX_ORDER.getKey(), "11"); + conf.put(MemoryConfigKeys.MEMORY_POOL_SIZE.getKey(), "1"); + conf.put(MemoryConfigKeys.MEMORY_AUTO_ADAPT_ENABLE.getKey(), "true"); - ByteBuf byteBuf = MemoryManager.getInstance() - .requireBuf(ESegmentSize.largest().size(), SHUFFLE); - Assert.assertEquals(byteBuf.chunk.pool.getMemoryMode(), MemoryMode.OFF_HEAP); + MemoryManager.build(new Configuration(conf)); - byteBuf = MemoryManager.getInstance().requireBuf(ESegmentSize.largest().size(), SHUFFLE); + ByteBuf byteBuf = + MemoryManager.getInstance().requireBuf(ESegmentSize.largest().size(), SHUFFLE); + Assert.assertEquals(byteBuf.chunk.pool.getMemoryMode(), MemoryMode.OFF_HEAP); - Assert.assertEquals(byteBuf.chunk.pool.getMemoryMode(), MemoryMode.ON_HEAP); + byteBuf = MemoryManager.getInstance().requireBuf(ESegmentSize.largest().size(), SHUFFLE); - byteBuf.free(SHUFFLE); + Assert.assertEquals(byteBuf.chunk.pool.getMemoryMode(), MemoryMode.ON_HEAP); - } - - @Test - public void test4() { - Map conf = Maps.newHashMap(); - conf.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "16"); - conf.put(MemoryConfigKeys.MAX_DIRECT_MEMORY_SIZE.getKey(), "" + 16 * 1024 * 1024); - conf.put(MemoryConfigKeys.MEMORY_PAGE_SIZE.getKey(), "" + 8 * 1024); - conf.put(MemoryConfigKeys.MEMORY_MAX_ORDER.getKey(), "11"); - conf.put(MemoryConfigKeys.MEMORY_POOL_SIZE.getKey(), "1"); - conf.put(MemoryConfigKeys.MEMORY_AUTO_ADAPT_ENABLE.getKey(), "true"); + byteBuf.free(SHUFFLE); + } - MemoryManager.build(new Configuration(conf)); + @Test + public void test4() { + Map conf = Maps.newHashMap(); + conf.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "16"); + conf.put(MemoryConfigKeys.MAX_DIRECT_MEMORY_SIZE.getKey(), "" + 16 * 1024 * 1024); + conf.put(MemoryConfigKeys.MEMORY_PAGE_SIZE.getKey(), "" + 8 * 1024); + conf.put(MemoryConfigKeys.MEMORY_MAX_ORDER.getKey(), "11"); + conf.put(MemoryConfigKeys.MEMORY_POOL_SIZE.getKey(), "1"); + conf.put(MemoryConfigKeys.MEMORY_AUTO_ADAPT_ENABLE.getKey(), "true"); - List byteBufs = MemoryManager.getInstance() - .requireBufs(ESegmentSize.largest().size(), DEFAULT); - Assert.assertEquals(byteBufs.size(), 1); + MemoryManager.build(new Configuration(conf)); - Assert.assertEquals(byteBufs.get(0).chunk.pool.getMemoryMode(), MemoryMode.ON_HEAP); - byteBufs.get(0).free(SHUFFLE); + List byteBufs = + MemoryManager.getInstance().requireBufs(ESegmentSize.largest().size(), DEFAULT); + Assert.assertEquals(byteBufs.size(), 1); - byteBufs = MemoryManager.getInstance().requireBufs(ESegmentSize.largest().size(), SHUFFLE); - Assert.assertEquals(byteBufs.size(), 1); + Assert.assertEquals(byteBufs.get(0).chunk.pool.getMemoryMode(), MemoryMode.ON_HEAP); + byteBufs.get(0).free(SHUFFLE); - Assert.assertEquals(byteBufs.get(0).chunk.pool.getMemoryMode(), MemoryMode.OFF_HEAP); + byteBufs = MemoryManager.getInstance().requireBufs(ESegmentSize.largest().size(), SHUFFLE); + Assert.assertEquals(byteBufs.size(), 1); - byteBufs = MemoryManager.getInstance().requireBufs(ESegmentSize.largest().size(), SHUFFLE); + Assert.assertEquals(byteBufs.get(0).chunk.pool.getMemoryMode(), MemoryMode.OFF_HEAP); - Assert.assertEquals(byteBufs.size(), 1); + byteBufs = MemoryManager.getInstance().requireBufs(ESegmentSize.largest().size(), SHUFFLE); - Assert.assertEquals(byteBufs.get(0).chunk.pool.getMemoryMode(), MemoryMode.ON_HEAP); + Assert.assertEquals(byteBufs.size(), 1); - byteBufs.get(0).free(SHUFFLE); + Assert.assertEquals(byteBufs.get(0).chunk.pool.getMemoryMode(), MemoryMode.ON_HEAP); - } + byteBufs.get(0).free(SHUFFLE); + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/ThreadMemoryTest.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/ThreadMemoryTest.java index cf5f02be5..d1c8aa25a 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/ThreadMemoryTest.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/ThreadMemoryTest.java @@ -19,194 +19,204 @@ package org.apache.geaflow.memory; -import io.netty.util.concurrent.FastThreadLocalThread; import java.util.HashMap; import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.memory.config.MemoryConfigKeys; import org.apache.geaflow.memory.exception.GeaflowOutOfMemoryException; import org.testng.Assert; import org.testng.annotations.Test; -public class ThreadMemoryTest extends MemoryReleaseTest { - - private static final int PAGE_SIZE = 8192; - private static final int PAGE_SHIFTS = 11; - - @Test - public void testStream() throws Throwable { - Map conf = new HashMap<>(); - conf.put(MemoryConfigKeys.ON_HEAP_MEMORY_SIZE_MB.getKey(), "1024"); - conf.put(MemoryConfigKeys.MEMORY_MAX_ORDER.getKey(), "" + PAGE_SHIFTS); - conf.put(MemoryConfigKeys.MEMORY_PAGE_SIZE.getKey(), "" + PAGE_SIZE); +import io.netty.util.concurrent.FastThreadLocalThread; - MemoryManager memoryManager = MemoryManager.build(new Configuration(conf)); - LinkedBlockingQueue views = new LinkedBlockingQueue<>(10); +public class ThreadMemoryTest extends MemoryReleaseTest { - int threadNum = 8; - ExecutorService executor = Executors.newFixedThreadPool(threadNum); - executor.execute(new FastThreadLocalThread(() -> { - while (true) { + private static final int PAGE_SIZE = 8192; + private static final int PAGE_SHIFTS = 11; + + @Test + public void testStream() throws Throwable { + Map conf = new HashMap<>(); + conf.put(MemoryConfigKeys.ON_HEAP_MEMORY_SIZE_MB.getKey(), "1024"); + conf.put(MemoryConfigKeys.MEMORY_MAX_ORDER.getKey(), "" + PAGE_SHIFTS); + conf.put(MemoryConfigKeys.MEMORY_PAGE_SIZE.getKey(), "" + PAGE_SIZE); + + MemoryManager memoryManager = MemoryManager.build(new Configuration(conf)); + LinkedBlockingQueue views = new LinkedBlockingQueue<>(10); + + int threadNum = 8; + ExecutorService executor = Executors.newFixedThreadPool(threadNum); + executor.execute( + new FastThreadLocalThread( + () -> { + while (true) { MemoryView v = views.poll(); if (v != null) { - v.close(); + v.close(); } - } - })); + } + })); - final CountDownLatch latch = new CountDownLatch(threadNum); - long start = System.currentTimeMillis(); - for (int i = 0; i < threadNum; i++) { - executor.execute(new FastThreadLocalThread(() -> { + final CountDownLatch latch = new CountDownLatch(threadNum); + long start = System.currentTimeMillis(); + for (int i = 0; i < threadNum; i++) { + executor.execute( + new FastThreadLocalThread( + () -> { for (int j = 0; j < 10000; j++) { - try { - MemoryView view = memoryManager.requireMemory(ESegmentSize.S65536.size(), - MemoryGroupManger.SHUFFLE); - while (!views.offer(view)) { - ; - } - } catch (Exception ex) { - System.out.println(ex); + try { + MemoryView view = + memoryManager.requireMemory( + ESegmentSize.S65536.size(), MemoryGroupManger.SHUFFLE); + while (!views.offer(view)) { + ; } + } catch (Exception ex) { + System.out.println(ex); + } } latch.countDown(); - })); - } - latch.await(); - System.out.println(System.currentTimeMillis() - start); - - Thread.sleep(1000); - System.out.println(memoryManager); + })); } - - @Test - public void testBatch() throws Throwable { - Map config = new HashMap<>(); - config.put(MemoryConfigKeys.ON_HEAP_MEMORY_SIZE_MB.getKey(), "1024"); - config.put(MemoryConfigKeys.MEMORY_MAX_ORDER.getKey(), "" + PAGE_SHIFTS); - config.put(MemoryConfigKeys.MEMORY_PAGE_SIZE.getKey(), "" + PAGE_SIZE); - - MemoryManager memoryManager = MemoryManager.build(new Configuration(config)); - LinkedBlockingQueue views = new LinkedBlockingQueue<>(); - - int threadNum = 8; - ExecutorService executor = Executors.newFixedThreadPool(threadNum); - final CountDownLatch latch = new CountDownLatch(threadNum); - long start = System.currentTimeMillis(); - for (int i = 0; i < threadNum; i++) { - executor.execute(() -> { - for (int j = 0; j < 100; j++) { - try { - MemoryView view = memoryManager.requireMemory(32, - MemoryGroupManger.SHUFFLE); - while (!views.offer(view)) { - ; - } - } catch (Exception ex) { - System.out.println(ex); - } + latch.await(); + System.out.println(System.currentTimeMillis() - start); + + Thread.sleep(1000); + System.out.println(memoryManager); + } + + @Test + public void testBatch() throws Throwable { + Map config = new HashMap<>(); + config.put(MemoryConfigKeys.ON_HEAP_MEMORY_SIZE_MB.getKey(), "1024"); + config.put(MemoryConfigKeys.MEMORY_MAX_ORDER.getKey(), "" + PAGE_SHIFTS); + config.put(MemoryConfigKeys.MEMORY_PAGE_SIZE.getKey(), "" + PAGE_SIZE); + + MemoryManager memoryManager = MemoryManager.build(new Configuration(config)); + LinkedBlockingQueue views = new LinkedBlockingQueue<>(); + + int threadNum = 8; + ExecutorService executor = Executors.newFixedThreadPool(threadNum); + final CountDownLatch latch = new CountDownLatch(threadNum); + long start = System.currentTimeMillis(); + for (int i = 0; i < threadNum; i++) { + executor.execute( + () -> { + for (int j = 0; j < 100; j++) { + try { + MemoryView view = memoryManager.requireMemory(32, MemoryGroupManger.SHUFFLE); + while (!views.offer(view)) { + ; } - latch.countDown(); - memoryManager.localCache(MemoryGroupManger.SHUFFLE).free(); - }); - } - latch.await(); - System.out.println(System.currentTimeMillis() - start); - Thread.sleep(1000); - - System.out.println(memoryManager); - while (true) { - MemoryView v = views.poll(); - if (v != null) { - v.close(); - } else { - break; + } catch (Exception ex) { + System.out.println(ex); + } } - } - - System.out.println(memoryManager); - - final CountDownLatch latch2 = new CountDownLatch(threadNum); - start = System.currentTimeMillis(); - for (int i = 0; i < threadNum; i++) { - executor.execute(() -> { - for (int j = 0; j < 100; j++) { - try { - MemoryView view = memoryManager.requireMemory(32768, - MemoryGroupManger.SHUFFLE); - while (!views.offer(view)) { - ; - } - } catch (Exception ex) { - System.out.println(ex); - } + latch.countDown(); + memoryManager.localCache(MemoryGroupManger.SHUFFLE).free(); + }); + } + latch.await(); + System.out.println(System.currentTimeMillis() - start); + Thread.sleep(1000); + + System.out.println(memoryManager); + while (true) { + MemoryView v = views.poll(); + if (v != null) { + v.close(); + } else { + break; + } + } + + System.out.println(memoryManager); + + final CountDownLatch latch2 = new CountDownLatch(threadNum); + start = System.currentTimeMillis(); + for (int i = 0; i < threadNum; i++) { + executor.execute( + () -> { + for (int j = 0; j < 100; j++) { + try { + MemoryView view = memoryManager.requireMemory(32768, MemoryGroupManger.SHUFFLE); + while (!views.offer(view)) { + ; } - latch2.countDown(); - }); - } - latch2.await(); - System.out.println(System.currentTimeMillis() - start); - Thread.sleep(1000); - - while (true) { - MemoryView v = views.poll(); - if (v != null) { - v.close(); - } else { - break; + } catch (Exception ex) { + System.out.println(ex); + } } - } - - final CountDownLatch latch3 = new CountDownLatch(threadNum); - start = System.currentTimeMillis(); - for (int i = 0; i < threadNum; i++) { - executor.execute(() -> { - for (int j = 0; j < 100; j++) { - try { - MemoryView view = memoryManager.requireMemory(32768, - MemoryGroupManger.SHUFFLE); - while (!views.offer(view)) { - ; - } - } catch (Exception ex) { - System.out.println(ex); - } + latch2.countDown(); + }); + } + latch2.await(); + System.out.println(System.currentTimeMillis() - start); + Thread.sleep(1000); + + while (true) { + MemoryView v = views.poll(); + if (v != null) { + v.close(); + } else { + break; + } + } + + final CountDownLatch latch3 = new CountDownLatch(threadNum); + start = System.currentTimeMillis(); + for (int i = 0; i < threadNum; i++) { + executor.execute( + () -> { + for (int j = 0; j < 100; j++) { + try { + MemoryView view = memoryManager.requireMemory(32768, MemoryGroupManger.SHUFFLE); + while (!views.offer(view)) { + ; } - latch3.countDown(); - }); - } - latch3.await(); - System.out.println(System.currentTimeMillis() - start); - Thread.sleep(1000); - - while (true) { - MemoryView v = views.poll(); - if (v != null) { - v.close(); - } else { - break; + } catch (Exception ex) { + System.out.println(ex); + } } - } - - System.out.println(memoryManager); - Thread.sleep(1000); - System.out.println(memoryManager); + latch3.countDown(); + }); + } + latch3.await(); + System.out.println(System.currentTimeMillis() - start); + Thread.sleep(1000); + + while (true) { + MemoryView v = views.poll(); + if (v != null) { + v.close(); + } else { + break; + } } - @Test - public void testOOM() { - Assert.assertThrows(GeaflowOutOfMemoryException.class, () -> { - Map config = new HashMap<>(); - config.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "1"); - config.put(MemoryConfigKeys.MAX_DIRECT_MEMORY_SIZE.getKey(), "" + 1024 * 1024); - config.put(MemoryConfigKeys.MEMORY_AUTO_ADAPT_ENABLE.getKey(), "false"); - - MemoryManager memoryManager = MemoryManager.build(new Configuration(config)); - MemoryView view = memoryManager.requireMemory(1024 * 1024 * 10, MemoryGroupManger.DEFAULT); + System.out.println(memoryManager); + Thread.sleep(1000); + System.out.println(memoryManager); + } + + @Test + public void testOOM() { + Assert.assertThrows( + GeaflowOutOfMemoryException.class, + () -> { + Map config = new HashMap<>(); + config.put(MemoryConfigKeys.OFF_HEAP_MEMORY_SIZE_MB.getKey(), "1"); + config.put(MemoryConfigKeys.MAX_DIRECT_MEMORY_SIZE.getKey(), "" + 1024 * 1024); + config.put(MemoryConfigKeys.MEMORY_AUTO_ADAPT_ENABLE.getKey(), "false"); + + MemoryManager memoryManager = MemoryManager.build(new Configuration(config)); + MemoryView view = + memoryManager.requireMemory(1024 * 1024 * 10, MemoryGroupManger.DEFAULT); }); - } + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/cleaner/CleanerTest.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/cleaner/CleanerTest.java index 2a1179a20..83118d71a 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/cleaner/CleanerTest.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/cleaner/CleanerTest.java @@ -20,27 +20,28 @@ package org.apache.geaflow.memory.cleaner; import java.nio.ByteBuffer; + import org.apache.geaflow.memory.DirectMemory; import org.testng.Assert; import org.testng.annotations.Test; public class CleanerTest { - @Test - public void test() { - ByteBuffer bf = DirectMemory.allocateDirectNoCleaner(20); + @Test + public void test() { + ByteBuffer bf = DirectMemory.allocateDirectNoCleaner(20); - if (DirectMemory.javaVersion() < 9) { - Assert.assertTrue(CleanerJava6.isSupported()); - Assert.assertFalse(CleanerJava9.isSupported()); - CleanerJava6 cleanerJava = new CleanerJava6(); - cleanerJava.freeDirectBuffer(bf); - } else { - Assert.assertTrue(CleanerJava6.isSupported()); - Assert.assertTrue(CleanerJava9.isSupported()); + if (DirectMemory.javaVersion() < 9) { + Assert.assertTrue(CleanerJava6.isSupported()); + Assert.assertFalse(CleanerJava9.isSupported()); + CleanerJava6 cleanerJava = new CleanerJava6(); + cleanerJava.freeDirectBuffer(bf); + } else { + Assert.assertTrue(CleanerJava6.isSupported()); + Assert.assertTrue(CleanerJava9.isSupported()); - CleanerJava9 cleanerJava = new CleanerJava9(); - cleanerJava.freeDirectBuffer(bf); - } + CleanerJava9 cleanerJava = new CleanerJava9(); + cleanerJava.freeDirectBuffer(bf); } + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/compress/CompressTest.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/compress/CompressTest.java index 74dcfe852..fe2bccb8a 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/compress/CompressTest.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/compress/CompressTest.java @@ -27,6 +27,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.serialize.SerializerFactory; import org.apache.geaflow.memory.MemoryGroupManger; @@ -39,107 +40,117 @@ public class CompressTest { - @BeforeClass - public static void setUp() { - MemoryManager.build(new Configuration()); - } + @BeforeClass + public static void setUp() { + MemoryManager.build(new Configuration()); + } - @AfterClass - public static void after() { - MemoryManager.getInstance().dispose(); - } + @AfterClass + public static void after() { + MemoryManager.getInstance().dispose(); + } - @Test - public void testSnappy() throws IOException { + @Test + public void testSnappy() throws IOException { - Map map = new HashMap<>(); - for (int i = 10; i < 100; i++) { - map.put(Integer.toString(i), Integer.toString(i)); - } - byte[] bMap = SerializerFactory.getKryoSerializer().serialize(map); - MemoryView cc = MemoryManager.getInstance().requireMemory(bMap.length, MemoryGroupManger.DEFAULT); - cc.getWriter().write(bMap); - cc = Snappy.uncompress(Snappy.compress(cc)); - Assert.assertEquals(cc.toArray(), bMap); + Map map = new HashMap<>(); + for (int i = 10; i < 100; i++) { + map.put(Integer.toString(i), Integer.toString(i)); } + byte[] bMap = SerializerFactory.getKryoSerializer().serialize(map); + MemoryView cc = + MemoryManager.getInstance().requireMemory(bMap.length, MemoryGroupManger.DEFAULT); + cc.getWriter().write(bMap); + cc = Snappy.uncompress(Snappy.compress(cc)); + Assert.assertEquals(cc.toArray(), bMap); + } - @Test - public void testLz4() throws IOException { - Map map = new HashMap<>(); - for (int i = 10; i < 100; i++) { - map.put(Integer.toString(i), Integer.toString(i)); - } - byte[] bMap = SerializerFactory.getKryoSerializer().serialize(map); - MemoryView cc = MemoryManager.getInstance().requireMemory(bMap.length, MemoryGroupManger.DEFAULT); - cc.getWriter().write(bMap); - cc = Lz4.uncompress(Lz4.compress(cc)); - Assert.assertEquals(cc.toArray(), bMap); + @Test + public void testLz4() throws IOException { + Map map = new HashMap<>(); + for (int i = 10; i < 100; i++) { + map.put(Integer.toString(i), Integer.toString(i)); } + byte[] bMap = SerializerFactory.getKryoSerializer().serialize(map); + MemoryView cc = + MemoryManager.getInstance().requireMemory(bMap.length, MemoryGroupManger.DEFAULT); + cc.getWriter().write(bMap); + cc = Lz4.uncompress(Lz4.compress(cc)); + Assert.assertEquals(cc.toArray(), bMap); + } - @Test - public void multiThreadSnappy() { + @Test + public void multiThreadSnappy() { - Map map = new HashMap<>(); - for (int i = 0; i < 10000; i++) { - map.put(Integer.toString(i), Integer.toString(i)); - } - byte[] bMap = SerializerFactory.getKryoSerializer().serialize(map); + Map map = new HashMap<>(); + for (int i = 0; i < 10000; i++) { + map.put(Integer.toString(i), Integer.toString(i)); + } + byte[] bMap = SerializerFactory.getKryoSerializer().serialize(map); - int threadNum = 5; - ExecutorService executor = Executors.newFixedThreadPool(threadNum); - List list = new ArrayList<>(); - for (int i = 0; i < threadNum; i++) { - list.add(executor.submit(() -> { + int threadNum = 5; + ExecutorService executor = Executors.newFixedThreadPool(threadNum); + List list = new ArrayList<>(); + for (int i = 0; i < threadNum; i++) { + list.add( + executor.submit( + () -> { try { - MemoryView view = MemoryManager.getInstance().requireMemory(bMap.length, - MemoryGroupManger.DEFAULT); - view.getWriter().write(bMap); - byte[] after = Snappy.uncompress(Snappy.compress(view)).toArray(); - Assert.assertEquals(after, bMap); + MemoryView view = + MemoryManager.getInstance() + .requireMemory(bMap.length, MemoryGroupManger.DEFAULT); + view.getWriter().write(bMap); + byte[] after = Snappy.uncompress(Snappy.compress(view)).toArray(); + Assert.assertEquals(after, bMap); } catch (IOException e) { - throw new RuntimeException(e); + throw new RuntimeException(e); } - })); - } - list.forEach(c -> { - try { - c.get(); - } catch (Exception e) { - throw new RuntimeException(e); - } - }); + })); } + list.forEach( + c -> { + try { + c.get(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } - @Test - public void multiThreadLz4() { - Map map = new HashMap<>(); - for (int i = 0; i < 10000; i++) { - map.put(Integer.toString(i), Integer.toString(i)); - } - byte[] bMap = SerializerFactory.getKryoSerializer().serialize(map); + @Test + public void multiThreadLz4() { + Map map = new HashMap<>(); + for (int i = 0; i < 10000; i++) { + map.put(Integer.toString(i), Integer.toString(i)); + } + byte[] bMap = SerializerFactory.getKryoSerializer().serialize(map); - int threadNum = 5; - ExecutorService executor = Executors.newFixedThreadPool(threadNum); - List list = new ArrayList<>(); - for (int i = 0; i < threadNum; i++) { - list.add(executor.submit(() -> { + int threadNum = 5; + ExecutorService executor = Executors.newFixedThreadPool(threadNum); + List list = new ArrayList<>(); + for (int i = 0; i < threadNum; i++) { + list.add( + executor.submit( + () -> { try { - MemoryView view = MemoryManager.getInstance().requireMemory(bMap.length, - MemoryGroupManger.DEFAULT); - view.getWriter().write(bMap); - byte[] after = Lz4.uncompress(Lz4.compress(view)).toArray(); - Assert.assertEquals(after, bMap); + MemoryView view = + MemoryManager.getInstance() + .requireMemory(bMap.length, MemoryGroupManger.DEFAULT); + view.getWriter().write(bMap); + byte[] after = Lz4.uncompress(Lz4.compress(view)).toArray(); + Assert.assertEquals(after, bMap); } catch (Exception e) { - throw new RuntimeException(e); + throw new RuntimeException(e); } - })); - } - list.forEach(c -> { - try { - c.get(); - } catch (Exception e) { - throw new RuntimeException(e); - } - }); + })); } + list.forEach( + c -> { + try { + c.get(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/AllocationBenchmark.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/AllocationBenchmark.java index 5ad8a4398..edc1cb4a9 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/AllocationBenchmark.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/AllocationBenchmark.java @@ -27,14 +27,15 @@ public class AllocationBenchmark { - public static void main(String[] args) throws RunnerException { + public static void main(String[] args) throws RunnerException { - Options opt = new OptionsBuilder() + Options opt = + new OptionsBuilder() .include("Allocation*") .resultFormat(ResultFormatType.JSON) .result("allocation.json") .build(); - new Runner(opt).run(); - } + new Runner(opt).run(); + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/AllocationDruidBenchmark.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/AllocationDruidBenchmark.java index 295f4ca8c..ec826264c 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/AllocationDruidBenchmark.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/AllocationDruidBenchmark.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.concurrent.TimeUnit; + import org.apache.druid.collections.BlockingPool; import org.apache.druid.collections.DefaultBlockingPool; import org.apache.druid.collections.ReferenceCountingResourceHolder; @@ -42,24 +43,21 @@ @Fork(1) public class AllocationDruidBenchmark extends JMHParameter { - private BlockingPool pool; - - @Setup - public void setUp() { - pool = new DefaultBlockingPool(new OffheapBufferGenerator("1", PAGE_SIZE), 10240); - } - - @Benchmark - public void allocateAndFree() { - List referenceCountingResourceHolderList = - pool.takeBatch((allocateBytes + PAGE_SIZE - 1) / PAGE_SIZE); + private BlockingPool pool; - referenceCountingResourceHolderList.forEach(e -> e.close()); - } + @Setup + public void setUp() { + pool = new DefaultBlockingPool(new OffheapBufferGenerator("1", PAGE_SIZE), 10240); + } + @Benchmark + public void allocateAndFree() { + List referenceCountingResourceHolderList = + pool.takeBatch((allocateBytes + PAGE_SIZE - 1) / PAGE_SIZE); - @TearDown - public void finish() { - } + referenceCountingResourceHolderList.forEach(e -> e.close()); + } + @TearDown + public void finish() {} } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/AllocationFlinkBenchmark.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/AllocationFlinkBenchmark.java index 61854fa4e..99cf11577 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/AllocationFlinkBenchmark.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/AllocationFlinkBenchmark.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.concurrent.TimeUnit; + import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.runtime.memory.MemoryAllocationException; import org.apache.flink.runtime.memory.MemoryManager; @@ -41,25 +42,24 @@ @Fork(1) public class AllocationFlinkBenchmark extends JMHParameter { - private static final long MEMORY_SIZE = 512 * 1024 * 1024; - - private MemoryManager memoryManager; + private static final long MEMORY_SIZE = 512 * 1024 * 1024; - @Setup - public void setUp() { - memoryManager = MemoryManager.create(MEMORY_SIZE, PAGE_SIZE); - } + private MemoryManager memoryManager; - @Benchmark - public void allocateAndFree() throws MemoryAllocationException { - List segments = memoryManager.allocatePages(1, - (allocateBytes + PAGE_SIZE - 1) / PAGE_SIZE); - memoryManager.release(segments); - } + @Setup + public void setUp() { + memoryManager = MemoryManager.create(MEMORY_SIZE, PAGE_SIZE); + } - @TearDown - public void finish() { - memoryManager.shutdown(); - } + @Benchmark + public void allocateAndFree() throws MemoryAllocationException { + List segments = + memoryManager.allocatePages(1, (allocateBytes + PAGE_SIZE - 1) / PAGE_SIZE); + memoryManager.release(segments); + } + @TearDown + public void finish() { + memoryManager.shutdown(); + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/AllocationGeaflowBenchmark.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/AllocationGeaflowBenchmark.java index 534d1e6fd..a545ce3bc 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/AllocationGeaflowBenchmark.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/AllocationGeaflowBenchmark.java @@ -19,9 +19,9 @@ package org.apache.geaflow.memory.jmh; -import com.google.common.collect.Maps; import java.util.Map; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.memory.MemoryGroupManger; import org.apache.geaflow.memory.MemoryManager; @@ -36,6 +36,8 @@ import org.openjdk.jmh.annotations.TearDown; import org.openjdk.jmh.annotations.Warmup; +import com.google.common.collect.Maps; + @BenchmarkMode(Mode.AverageTime) @OutputTimeUnit(TimeUnit.NANOSECONDS) @Warmup(iterations = 5, time = 1) @@ -43,23 +45,23 @@ @Fork(1) public class AllocationGeaflowBenchmark extends JMHParameter { - @Setup - public void setUp() { - Map config = Maps.newHashMap(); - config.put("max.direct.memory.size.mb", "512"); - config.put("off.heap.memory.chunkSize.MB", "16"); - MemoryManager.build(new Configuration(config)); - } + @Setup + public void setUp() { + Map config = Maps.newHashMap(); + config.put("max.direct.memory.size.mb", "512"); + config.put("off.heap.memory.chunkSize.MB", "16"); + MemoryManager.build(new Configuration(config)); + } - @Benchmark - public void allocateAndFree() { - MemoryView view = MemoryManager.getInstance() - .requireMemory(allocateBytes, MemoryGroupManger.DEFAULT); - view.close(); - } + @Benchmark + public void allocateAndFree() { + MemoryView view = + MemoryManager.getInstance().requireMemory(allocateBytes, MemoryGroupManger.DEFAULT); + view.close(); + } - @TearDown - public void finish() { - MemoryManager.getInstance().dispose(); - } + @TearDown + public void finish() { + MemoryManager.getInstance().dispose(); + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/AllocationSparkBenchmark.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/AllocationSparkBenchmark.java index 681319539..38c1a07e1 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/AllocationSparkBenchmark.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/AllocationSparkBenchmark.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.concurrent.TimeUnit; + import org.apache.spark.SparkConf; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.MemoryMode; @@ -43,45 +44,43 @@ @Fork(1) public class AllocationSparkBenchmark extends JMHParameter { - private TaskMemoryManager memoryManager; - private MemoryConsumer consumer; - - @Setup - public void setUp() { - SparkConf conf = new SparkConf(); - conf.set("spark.memory.offHeap.enabled", "true"); - conf.set("spark.memory.offHeap.size", "1342177280"); - memoryManager = new TaskMemoryManager( - new UnifiedMemoryManager(conf, 134217728 * 20L, 1342177280L, 1), 1); - consumer = new LocalMemoryConsumer(memoryManager, MemoryMode.OFF_HEAP); - } - - @Benchmark - public void allocateAndFree() { - MemoryBlock block = memoryManager.allocatePage(allocateBytes, consumer); - memoryManager.freePage(block, consumer); - } + private TaskMemoryManager memoryManager; + private MemoryConsumer consumer; - public void finish() { - memoryManager.cleanUpAllAllocatedMemory(); - } + @Setup + public void setUp() { + SparkConf conf = new SparkConf(); + conf.set("spark.memory.offHeap.enabled", "true"); + conf.set("spark.memory.offHeap.size", "1342177280"); + memoryManager = + new TaskMemoryManager(new UnifiedMemoryManager(conf, 134217728 * 20L, 1342177280L, 1), 1); + consumer = new LocalMemoryConsumer(memoryManager, MemoryMode.OFF_HEAP); + } + @Benchmark + public void allocateAndFree() { + MemoryBlock block = memoryManager.allocatePage(allocateBytes, consumer); + memoryManager.freePage(block, consumer); + } - static class LocalMemoryConsumer extends MemoryConsumer { + public void finish() { + memoryManager.cleanUpAllAllocatedMemory(); + } - protected LocalMemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize, - MemoryMode mode) { - super(taskMemoryManager, pageSize, mode); - } + static class LocalMemoryConsumer extends MemoryConsumer { - protected LocalMemoryConsumer(TaskMemoryManager taskMemoryManager, MemoryMode mode) { - super(taskMemoryManager, mode); - } + protected LocalMemoryConsumer( + TaskMemoryManager taskMemoryManager, long pageSize, MemoryMode mode) { + super(taskMemoryManager, pageSize, mode); + } - @Override - public long spill(long size, MemoryConsumer trigger) throws IOException { - return 0; - } + protected LocalMemoryConsumer(TaskMemoryManager taskMemoryManager, MemoryMode mode) { + super(taskMemoryManager, mode); } + @Override + public long spill(long size, MemoryConsumer trigger) throws IOException { + return 0; + } + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/GeaflowHeapBenchmark.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/GeaflowHeapBenchmark.java index 838ff00ef..716fbebdf 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/GeaflowHeapBenchmark.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/GeaflowHeapBenchmark.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.memory.DirectMemory; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; @@ -41,21 +42,20 @@ @State(Scope.Benchmark) public class GeaflowHeapBenchmark { - @Param({"128", "512", "1024", "2048", "5120"}) - public int memBytes; + @Param({"128", "512", "1024", "2048", "5120"}) + public int memBytes; - @Benchmark - public void test() { - for (int i = 0; i < 100000; i++) { - byte[] bytes = new byte[memBytes]; - for (int j = 0; j < memBytes; j += 4) { - Arrays.fill(bytes, j, j + 4, (byte) 1); - } + @Benchmark + public void test() { + for (int i = 0; i < 100000; i++) { + byte[] bytes = new byte[memBytes]; + for (int j = 0; j < memBytes; j += 4) { + Arrays.fill(bytes, j, j + 4, (byte) 1); + } - long address = DirectMemory.unsafe().allocateMemory(memBytes); - DirectMemory.copyMemory(bytes, 0, address, memBytes); - DirectMemory.freeMemory(address); - } + long address = DirectMemory.unsafe().allocateMemory(memBytes); + DirectMemory.copyMemory(bytes, 0, address, memBytes); + DirectMemory.freeMemory(address); } - + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/GeaflowMemoryBenchmark.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/GeaflowMemoryBenchmark.java index 15e55cd7c..43b940499 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/GeaflowMemoryBenchmark.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/GeaflowMemoryBenchmark.java @@ -26,10 +26,9 @@ public class GeaflowMemoryBenchmark { - public static void main(String[] args) throws RunnerException { - Options opt = new OptionsBuilder().include("Geaflow*") - .build(); + public static void main(String[] args) throws RunnerException { + Options opt = new OptionsBuilder().include("Geaflow*").build(); - new Runner(opt).run(); - } + new Runner(opt).run(); + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/GeaflowMemoryViewReaderBenchmark.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/GeaflowMemoryViewReaderBenchmark.java index b9afb86de..0ce8bc54f 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/GeaflowMemoryViewReaderBenchmark.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/GeaflowMemoryViewReaderBenchmark.java @@ -19,11 +19,10 @@ package org.apache.geaflow.memory.jmh; - -import com.google.common.collect.Maps; import java.util.Arrays; import java.util.HashMap; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.memory.MemoryGroupManger; import org.apache.geaflow.memory.MemoryManager; @@ -41,6 +40,8 @@ import org.openjdk.jmh.annotations.TearDown; import org.openjdk.jmh.annotations.Warmup; +import com.google.common.collect.Maps; + @BenchmarkMode(Mode.AverageTime) @OutputTimeUnit(TimeUnit.NANOSECONDS) @Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS) @@ -49,35 +50,34 @@ @State(Scope.Benchmark) public class GeaflowMemoryViewReaderBenchmark { - @Param({"1", "32", "64", "128", "512", "1024", "2048", "4096", "8192", "10240"}) - public int memBytes; - - private MemoryView view; + @Param({"1", "32", "64", "128", "512", "1024", "2048", "4096", "8192", "10240"}) + public int memBytes; - @Setup - public void setup() { - HashMap config = Maps.newHashMap(); - config.put("max.direct.memory.size.mb", "512"); - config.put("off.heap.memory.chunkSize.MB", "16"); - MemoryManager.build(new Configuration(config)); - view = MemoryManager.getInstance().requireMemory(10240, MemoryGroupManger.DEFAULT); + private MemoryView view; - byte[] bytes = new byte[1024]; - Arrays.fill(bytes, (byte) 1); - view.getWriter().write(bytes); - } + @Setup + public void setup() { + HashMap config = Maps.newHashMap(); + config.put("max.direct.memory.size.mb", "512"); + config.put("off.heap.memory.chunkSize.MB", "16"); + MemoryManager.build(new Configuration(config)); + view = MemoryManager.getInstance().requireMemory(10240, MemoryGroupManger.DEFAULT); + byte[] bytes = new byte[1024]; + Arrays.fill(bytes, (byte) 1); + view.getWriter().write(bytes); + } - @Benchmark - public void read() { - for (int i = 0; i < memBytes; i++) { - view.getReader().read(new byte[1]); - } + @Benchmark + public void read() { + for (int i = 0; i < memBytes; i++) { + view.getReader().read(new byte[1]); } + } - @TearDown - public void finish() { - view.close(); - MemoryManager.getInstance().dispose(); - } + @TearDown + public void finish() { + view.close(); + MemoryManager.getInstance().dispose(); + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/GeaflowMemoryViewWriterBenchmark.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/GeaflowMemoryViewWriterBenchmark.java index 4fe588fe0..fe9a7b0f0 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/GeaflowMemoryViewWriterBenchmark.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/GeaflowMemoryViewWriterBenchmark.java @@ -19,9 +19,9 @@ package org.apache.geaflow.memory.jmh; -import com.google.common.collect.Maps; import java.util.HashMap; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.memory.MemoryGroupManger; import org.apache.geaflow.memory.MemoryManager; @@ -39,6 +39,8 @@ import org.openjdk.jmh.annotations.TearDown; import org.openjdk.jmh.annotations.Warmup; +import com.google.common.collect.Maps; + @BenchmarkMode(Mode.AverageTime) @OutputTimeUnit(TimeUnit.NANOSECONDS) @Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS) @@ -47,33 +49,32 @@ @State(Scope.Benchmark) public class GeaflowMemoryViewWriterBenchmark { + @Param({"1", "32", "64", "128", "512", "1024", "2048", "4096", "8192", "10240"}) + public int memBytes; - @Param({"1", "32", "64", "128", "512", "1024", "2048", "4096", "8192", "10240"}) - public int memBytes; - - private MemoryView view; - private MemoryViewWriter writer; - private int spanSize; + private MemoryView view; + private MemoryViewWriter writer; + private int spanSize; - @Benchmark - public void write() { - HashMap config = Maps.newHashMap(); - config.put("max.direct.memory.size.mb", "512"); - config.put("off.heap.memory.chunkSize.MB", "16"); - MemoryManager.build(new Configuration(config)); - view = MemoryManager.getInstance().requireMemory(memBytes, MemoryGroupManger.DEFAULT); + @Benchmark + public void write() { + HashMap config = Maps.newHashMap(); + config.put("max.direct.memory.size.mb", "512"); + config.put("off.heap.memory.chunkSize.MB", "16"); + MemoryManager.build(new Configuration(config)); + view = MemoryManager.getInstance().requireMemory(memBytes, MemoryGroupManger.DEFAULT); - spanSize = memBytes / 2 == 0 ? 1 : memBytes / 2; - writer = view.getWriter(spanSize); + spanSize = memBytes / 2 == 0 ? 1 : memBytes / 2; + writer = view.getWriter(spanSize); - for (int i = 0; i < memBytes; i++) { - writer.write(new byte[spanSize]); - } - view.close(); + for (int i = 0; i < memBytes; i++) { + writer.write(new byte[spanSize]); } + view.close(); + } - @TearDown - public void finish() { - MemoryManager.getInstance().dispose(); - } + @TearDown + public void finish() { + MemoryManager.getInstance().dispose(); + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/GeaflowOffHeapBenchmark.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/GeaflowOffHeapBenchmark.java index 555f2eeeb..63417bb88 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/GeaflowOffHeapBenchmark.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/GeaflowOffHeapBenchmark.java @@ -20,6 +20,7 @@ package org.apache.geaflow.memory.jmh; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.memory.DirectMemory; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; @@ -40,23 +41,19 @@ @State(Scope.Benchmark) public class GeaflowOffHeapBenchmark { + @Param({"128", "512", "1024", "2048", "5120"}) + public int memBytes; - @Param({"128", "512", "1024", "2048", "5120"}) - public int memBytes; - - - @Benchmark - public void test() { + @Benchmark + public void test() { - for (int i = 0; i < 100000; i++) { - long address = DirectMemory.unsafe().allocateMemory(memBytes); - for (int j = 0; j < memBytes; ) { - DirectMemory.setMemory(address + j, 4, (byte) 1); - j += 4; - } - DirectMemory.freeMemory(address); - } + for (int i = 0; i < 100000; i++) { + long address = DirectMemory.unsafe().allocateMemory(memBytes); + for (int j = 0; j < memBytes; ) { + DirectMemory.setMemory(address + j, 4, (byte) 1); + j += 4; + } + DirectMemory.freeMemory(address); } - - + } } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/JMHParameter.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/JMHParameter.java index c8bd3d073..356636057 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/JMHParameter.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/jmh/JMHParameter.java @@ -26,9 +26,8 @@ @State(Scope.Benchmark) public class JMHParameter { - protected static final int PAGE_SIZE = 4 * 1024; - - @Param({"100", "1500", "4096", "32768", "35000", "524288", "550000", "4096001", "10485760"}) - public int allocateBytes; + protected static final int PAGE_SIZE = 4 * 1024; + @Param({"100", "1500", "4096", "32768", "35000", "524288", "550000", "4096001", "10485760"}) + public int allocateBytes; } diff --git a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/util/MemoryUtilsTest.java b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/util/MemoryUtilsTest.java index 95416e92d..6495b0cb8 100644 --- a/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/util/MemoryUtilsTest.java +++ b/geaflow/geaflow-memory/src/test/java/org/apache/geaflow/memory/util/MemoryUtilsTest.java @@ -25,12 +25,12 @@ public class MemoryUtilsTest { - @Test - public void test() { - long memSize = 11 * MemoryUtils.MB; - Assert.assertEquals(MemoryUtils.humanReadableByteCount(memSize), "11.00MB"); + @Test + public void test() { + long memSize = 11 * MemoryUtils.MB; + Assert.assertEquals(MemoryUtils.humanReadableByteCount(memSize), "11.00MB"); - memSize = 1024 * MemoryUtils.TB; - Assert.assertEquals(MemoryUtils.humanReadableByteCount(memSize), "1.00PB"); - } + memSize = 1024 * MemoryUtils.TB; + Assert.assertEquals(MemoryUtils.humanReadableByteCount(memSize), "1.00PB"); + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/AggType.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/AggType.java index 737436090..488150268 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/AggType.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/AggType.java @@ -21,31 +21,22 @@ public enum AggType { - /** - * Average. - */ - AVG("avg"), - /** - * Sum. - */ - SUM("sum"), - /** - * Max value. - */ - MAX("max"), - /** - * Min value. - */ - MIN("min"); + /** Average. */ + AVG("avg"), + /** Sum. */ + SUM("sum"), + /** Max value. */ + MAX("max"), + /** Min value. */ + MIN("min"); - private final String value; + private final String value; - AggType(String value) { - this.value = value; - } - - public String getValue() { - return this.value; - } + AggType(String value) { + this.value = value; + } + public String getValue() { + return this.value; + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/DownSample.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/DownSample.java index ce2c9ffb0..f22a20870 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/DownSample.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/DownSample.java @@ -21,28 +21,23 @@ public enum DownSample { - /** - * Time serial down sample, average value in 60 seconds. - */ - AVG("60s-avg"), - /** - * Time serial down sample, sum value in 60 seconds. - */ - SUM("60s-sum"); + /** Time serial down sample, average value in 60 seconds. */ + AVG("60s-avg"), + /** Time serial down sample, sum value in 60 seconds. */ + SUM("60s-sum"); - private final String value; + private final String value; - DownSample(String value) { - this.value = value; - } + DownSample(String value) { + this.value = value; + } - public String getValue() { - return value; - } - - @Override - public String toString() { - return getValue(); - } + public String getValue() { + return value; + } + @Override + public String toString() { + return getValue(); + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/HistAggType.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/HistAggType.java index b4542e433..a8543c29c 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/HistAggType.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/HistAggType.java @@ -21,49 +21,34 @@ public enum HistAggType { - /** - * Default histogram aggregation types. - */ - DEFAULT(new String[]{"max", "p999", "p99", "p95", "p50"}), + /** Default histogram aggregation types. */ + DEFAULT(new String[] {"max", "p999", "p99", "p95", "p50"}), - /** - * Default histogram aggregation MAX. - */ - MAX(new String[]{"max"}), + /** Default histogram aggregation MAX. */ + MAX(new String[] {"max"}), - /** - * Default histogram aggregation MIN. - */ - MIN(new String[]{"min"}), + /** Default histogram aggregation MIN. */ + MIN(new String[] {"min"}), - /** - * Default histogram aggregation p999. - */ - P999(new String[]{"p999"}), + /** Default histogram aggregation p999. */ + P999(new String[] {"p999"}), - /** - * Default histogram aggregation p99. - */ - P99(new String[]{"p99"}), + /** Default histogram aggregation p99. */ + P99(new String[] {"p99"}), - /** - * Default histogram aggregation p95. - */ - P95(new String[]{"p95"}), + /** Default histogram aggregation p95. */ + P95(new String[] {"p95"}), - /** - * Default histogram aggregation p50. - */ - P50(new String[]{"p50"}); + /** Default histogram aggregation p50. */ + P50(new String[] {"p50"}); - private String[] aggTypes; + private String[] aggTypes; - HistAggType(String[] aggTypes) { - this.aggTypes = aggTypes; - } - - public String[] getAggTypes() { - return aggTypes; - } + HistAggType(String[] aggTypes) { + this.aggTypes = aggTypes; + } + public String[] getAggTypes() { + return aggTypes; + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricConfig.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricConfig.java index 311f6ff8d..978a98e69 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricConfig.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricConfig.java @@ -27,52 +27,53 @@ import java.io.Serializable; import java.util.Random; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; public class MetricConfig implements Serializable { - private static final int WAIT_SECONDS = 5; - private static final int RANDOM_SECONDS = 5; - private static final int PERIOD_SECONDS = 5; + private static final int WAIT_SECONDS = 5; + private static final int RANDOM_SECONDS = 5; + private static final int PERIOD_SECONDS = 5; - private static final String REPORTER_SCHEDULE_PERIOD = "geaflow.metric.%s.schedule.period.sec"; + private static final String REPORTER_SCHEDULE_PERIOD = "geaflow.metric.%s.schedule.period.sec"; - private final int schedulePeriod; - private final Configuration config; - private final Random random; + private final int schedulePeriod; + private final Configuration config; + private final Random random; - public MetricConfig(Configuration configuration) { - this.config = configuration; - this.random = new Random(); - this.schedulePeriod = configuration.getInteger(SCHEDULE_PERIOD); - } + public MetricConfig(Configuration configuration) { + this.config = configuration; + this.random = new Random(); + this.schedulePeriod = configuration.getInteger(SCHEDULE_PERIOD); + } - public String getReporterList() { - boolean isLocal = config.getBoolean(ExecutionConfigKeys.RUN_LOCAL_MODE); - if (isLocal) { - return config.getString(REPORTER_LIST, ""); - } else { - return config.getString(REPORTER_LIST); - } + public String getReporterList() { + boolean isLocal = config.getBoolean(ExecutionConfigKeys.RUN_LOCAL_MODE); + if (isLocal) { + return config.getString(REPORTER_LIST, ""); + } else { + return config.getString(REPORTER_LIST); } + } - public int getSchedulePeriodSec(String reporterName) { - String periodKey = String.format(REPORTER_SCHEDULE_PERIOD, reporterName); - return config.getInteger(periodKey, schedulePeriod); - } + public int getSchedulePeriodSec(String reporterName) { + String periodKey = String.format(REPORTER_SCHEDULE_PERIOD, reporterName); + return config.getInteger(periodKey, schedulePeriod); + } - public int getRandomDelaySec() { - int randomDelay = random.nextInt(RANDOM_SECONDS) + WAIT_SECONDS; - return config.getInteger(METRIC_META_REPORT_DELAY, randomDelay); - } + public int getRandomDelaySec() { + int randomDelay = random.nextInt(RANDOM_SECONDS) + WAIT_SECONDS; + return config.getInteger(METRIC_META_REPORT_DELAY, randomDelay); + } - public int getRandomPeriodSec() { - int randomPeriod = random.nextInt(PERIOD_SECONDS) + 1; - return config.getInteger(METRIC_META_REPORT_PERIOD, randomPeriod); - } + public int getRandomPeriodSec() { + int randomPeriod = random.nextInt(PERIOD_SECONDS) + 1; + return config.getInteger(METRIC_META_REPORT_PERIOD, randomPeriod); + } - public int getReportMaxRetries() { - return config.getInteger(METRIC_META_REPORT_RETRIES); - } + public int getReportMaxRetries() { + return config.getInteger(METRIC_META_REPORT_RETRIES); + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricConstants.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricConstants.java index 5f2725eb4..bbb1e5193 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricConstants.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricConstants.java @@ -21,67 +21,61 @@ public class MetricConstants { - public static final String GROUP_DELIMITER = "/"; + public static final String GROUP_DELIMITER = "/"; - /** - * Metric module. - */ - public static final String MODULE_DEFAULT = "default"; - public static final String MODULE_SYSTEM = "system"; - public static final String MODULE_DSL = "dsl"; - public static final String MODULE_FRAMEWORK = "framework"; + /** Metric module. */ + public static final String MODULE_DEFAULT = "default"; - /** - * System metric name. - */ - public static final String METRIC_TOTAL_HEAP = "totalUsedHeapMb"; - public static final String METRIC_TOTAL_MEMORY = "totalMemoryMb"; - public static final String METRIC_HEAP_USAGE_RATIO = "heapUsageRatio"; - public static final String METRIC_GC_TIME = "gcTime"; - public static final String METRIC_FGC_COUNT = "fgcCount"; - public static final String METRIC_FGC_TIME = "fgcTime"; + public static final String MODULE_SYSTEM = "system"; + public static final String MODULE_DSL = "dsl"; + public static final String MODULE_FRAMEWORK = "framework"; - /** - * Operator metric name. - */ - public static final String METRIC_INPUT_TPS = "inputTps"; - public static final String METRIC_OUTPUT_TPS = "outputTps"; - public static final String METRIC_VERTEX_TPS = "vertexTps"; - public static final String METRIC_EDGE_TPS = "edgeTps"; - public static final String METRIC_PROCESS_RT = "processRt"; - public static final String METRIC_ITERATION = "iteration"; - public static final String METRIC_ITERATION_MSG_TPS = "iterationMsgTps"; - public static final String METRIC_ITERATION_AGG_TPS = "iterationAggTps"; + /** System metric name. */ + public static final String METRIC_TOTAL_HEAP = "totalUsedHeapMb"; - /** - * Dsl metric name. - */ - public static final String METRIC_TABLE_INPUT_ROW = "tableInputRow"; - public static final String METRIC_STEP_INPUT_RECORD = "stepInputRecord"; - public static final String METRIC_STEP_OUTPUT_RECORD = "stepOutputRecord"; - public static final String METRIC_STEP_INPUT_EOD = "stepInputEod"; - public static final String METRIC_TABLE_OUTPUT_ROW_TPS = "tableOutputRowTps"; - public static final String METRIC_TABLE_INPUT_ROW_TPS = "tableInputRowTps"; - public static final String METRIC_TABLE_INPUT_BLOCK_TPS = "tableInputBlockTps"; - public static final String METRIC_STEP_INPUT_ROW_TPS = "stepInputRowTps"; - public static final String METRIC_STEP_OUTPUT_ROW_TPS = "stepOutputRowTps"; - public static final String METRIC_TABLE_WRITE_TIME_RT = "tableWriteTimeRt"; - public static final String METRIC_TABLE_FLUSH_TIME_RT = "tableFlushTimeRt"; - public static final String METRIC_TABLE_PARSER_TIME_RT = "tableParserTimeRt"; - public static final String METRIC_STEP_PROCESS_TIME_RT = "stepProcessTimeRt"; - public static final String METRIC_LOAD_EDGE_COUNT_RT = "loadEdgeCountRt"; - public static final String METRIC_LOAD_EDGE_TIME_RT = "loadEdgeTimeRt"; - public static final String METRIC_LOAD_VERTEX_TIME_RT = "loadVertexTimeRt"; + public static final String METRIC_TOTAL_MEMORY = "totalMemoryMb"; + public static final String METRIC_HEAP_USAGE_RATIO = "heapUsageRatio"; + public static final String METRIC_GC_TIME = "gcTime"; + public static final String METRIC_FGC_COUNT = "fgcCount"; + public static final String METRIC_FGC_TIME = "fgcTime"; - /** - * Metric unit. - */ - public static final String UNIT_N = "(N)"; - public static final String UNIT_S = "(s)"; - public static final String UNIT_MS = "(ms)"; - public static final String UNIT_US = "(us)"; - public static final String UNIT_NS = "(ns)"; - public static final String UNIT_ROW_PER_S = "(row/s)"; - public static final String UNIT_BLOCK_PER_S = "(block/s)"; + /** Operator metric name. */ + public static final String METRIC_INPUT_TPS = "inputTps"; + public static final String METRIC_OUTPUT_TPS = "outputTps"; + public static final String METRIC_VERTEX_TPS = "vertexTps"; + public static final String METRIC_EDGE_TPS = "edgeTps"; + public static final String METRIC_PROCESS_RT = "processRt"; + public static final String METRIC_ITERATION = "iteration"; + public static final String METRIC_ITERATION_MSG_TPS = "iterationMsgTps"; + public static final String METRIC_ITERATION_AGG_TPS = "iterationAggTps"; + + /** Dsl metric name. */ + public static final String METRIC_TABLE_INPUT_ROW = "tableInputRow"; + + public static final String METRIC_STEP_INPUT_RECORD = "stepInputRecord"; + public static final String METRIC_STEP_OUTPUT_RECORD = "stepOutputRecord"; + public static final String METRIC_STEP_INPUT_EOD = "stepInputEod"; + public static final String METRIC_TABLE_OUTPUT_ROW_TPS = "tableOutputRowTps"; + public static final String METRIC_TABLE_INPUT_ROW_TPS = "tableInputRowTps"; + public static final String METRIC_TABLE_INPUT_BLOCK_TPS = "tableInputBlockTps"; + public static final String METRIC_STEP_INPUT_ROW_TPS = "stepInputRowTps"; + public static final String METRIC_STEP_OUTPUT_ROW_TPS = "stepOutputRowTps"; + public static final String METRIC_TABLE_WRITE_TIME_RT = "tableWriteTimeRt"; + public static final String METRIC_TABLE_FLUSH_TIME_RT = "tableFlushTimeRt"; + public static final String METRIC_TABLE_PARSER_TIME_RT = "tableParserTimeRt"; + public static final String METRIC_STEP_PROCESS_TIME_RT = "stepProcessTimeRt"; + public static final String METRIC_LOAD_EDGE_COUNT_RT = "loadEdgeCountRt"; + public static final String METRIC_LOAD_EDGE_TIME_RT = "loadEdgeTimeRt"; + public static final String METRIC_LOAD_VERTEX_TIME_RT = "loadVertexTimeRt"; + + /** Metric unit. */ + public static final String UNIT_N = "(N)"; + + public static final String UNIT_S = "(s)"; + public static final String UNIT_MS = "(ms)"; + public static final String UNIT_US = "(us)"; + public static final String UNIT_NS = "(ns)"; + public static final String UNIT_ROW_PER_S = "(row/s)"; + public static final String UNIT_BLOCK_PER_S = "(block/s)"; } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricGroupRegistry.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricGroupRegistry.java index 4ecbd19fc..9c2f9a723 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricGroupRegistry.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricGroupRegistry.java @@ -19,8 +19,6 @@ package org.apache.geaflow.metrics.common; -import com.codahale.metrics.MetricRegistry; -import com.google.common.annotations.VisibleForTesting; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -29,6 +27,7 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.TimeUnit; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -41,101 +40,108 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class MetricGroupRegistry { - private static final Logger LOGGER = LoggerFactory.getLogger(MetricGroupRegistry.class); - - private static final String REPORTER_SEPARATOR = ","; - private static MetricGroupRegistry INSTANCE; - - private final MetricRegistry metricRegistry; - private final List reporterList; - private final ScheduledExecutorService executor; - private final Map metricGroupMap; - - private MetricGroupRegistry() { - this.metricRegistry = new MetricRegistry(); - this.reporterList = new ArrayList<>(); - this.executor = new ScheduledThreadPoolExecutor(1, - ThreadUtil.namedThreadFactory(true, "metricService")); - this.metricGroupMap = new ConcurrentHashMap<>(); - } - - public static synchronized MetricGroupRegistry getInstance(Configuration config) { - if (INSTANCE == null) { - INSTANCE = new MetricGroupRegistry(); - INSTANCE.open(config); - } - return INSTANCE; - } - - public static MetricGroupRegistry getInstance() { - return INSTANCE; - } - - public MetricGroup getMetricGroup() { - return metricGroupMap.computeIfAbsent(MetricConstants.MODULE_DEFAULT, - name -> new MetricGroupImpl(metricRegistry)); - } +import com.codahale.metrics.MetricRegistry; +import com.google.common.annotations.VisibleForTesting; - public MetricGroup getMetricGroup(String groupName) { - return metricGroupMap.computeIfAbsent(groupName, name -> new MetricGroupImpl(name, - metricRegistry)); +public class MetricGroupRegistry { + private static final Logger LOGGER = LoggerFactory.getLogger(MetricGroupRegistry.class); + + private static final String REPORTER_SEPARATOR = ","; + private static MetricGroupRegistry INSTANCE; + + private final MetricRegistry metricRegistry; + private final List reporterList; + private final ScheduledExecutorService executor; + private final Map metricGroupMap; + + private MetricGroupRegistry() { + this.metricRegistry = new MetricRegistry(); + this.reporterList = new ArrayList<>(); + this.executor = + new ScheduledThreadPoolExecutor(1, ThreadUtil.namedThreadFactory(true, "metricService")); + this.metricGroupMap = new ConcurrentHashMap<>(); + } + + public static synchronized MetricGroupRegistry getInstance(Configuration config) { + if (INSTANCE == null) { + INSTANCE = new MetricGroupRegistry(); + INSTANCE.open(config); } - - @VisibleForTesting - public List getReporterList() { - return reporterList; + return INSTANCE; + } + + public static MetricGroupRegistry getInstance() { + return INSTANCE; + } + + public MetricGroup getMetricGroup() { + return metricGroupMap.computeIfAbsent( + MetricConstants.MODULE_DEFAULT, name -> new MetricGroupImpl(metricRegistry)); + } + + public MetricGroup getMetricGroup(String groupName) { + return metricGroupMap.computeIfAbsent( + groupName, name -> new MetricGroupImpl(name, metricRegistry)); + } + + @VisibleForTesting + public List getReporterList() { + return reporterList; + } + + private void open(Configuration config) { + MetricConfig metricConfig = new MetricConfig(config); + + String reporterList = metricConfig.getReporterList(); + if (StringUtils.isEmpty(reporterList)) { + LOGGER.warn("report list is empty"); + return; } - - private void open(Configuration config) { - MetricConfig metricConfig = new MetricConfig(config); - - String reporterList = metricConfig.getReporterList(); - if (StringUtils.isEmpty(reporterList)) { - LOGGER.warn("report list is empty"); - return; - } - String[] reporters = reporterList.split(REPORTER_SEPARATOR); - try { - for (String reporter : reporters) { - MetricReporter metricReporter = MetricReporterFactory.getMetricReporter(reporter.toLowerCase()); - metricReporter.open(config, metricRegistry); - if (metricReporter instanceof ScheduledReporter) { - int period = metricConfig.getSchedulePeriodSec(reporter); - executor.scheduleWithFixedDelay(new ReporterTask( - (ScheduledReporter) metricReporter), period, period, TimeUnit.SECONDS); - LOGGER.info("schedule {} with duration {}s", reporter, period); - } - this.reporterList.add(metricReporter); - } - } catch (Exception e) { - LOGGER.error("failed to initialized metricReporters", e); - throw new GeaflowRuntimeException(e); + String[] reporters = reporterList.split(REPORTER_SEPARATOR); + try { + for (String reporter : reporters) { + MetricReporter metricReporter = + MetricReporterFactory.getMetricReporter(reporter.toLowerCase()); + metricReporter.open(config, metricRegistry); + if (metricReporter instanceof ScheduledReporter) { + int period = metricConfig.getSchedulePeriodSec(reporter); + executor.scheduleWithFixedDelay( + new ReporterTask((ScheduledReporter) metricReporter), + period, + period, + TimeUnit.SECONDS); + LOGGER.info("schedule {} with duration {}s", reporter, period); } + this.reporterList.add(metricReporter); + } + } catch (Exception e) { + LOGGER.error("failed to initialized metricReporters", e); + throw new GeaflowRuntimeException(e); } + } - public void close() { - LOGGER.info("close metric service"); - for (MetricReporter metricReporter : reporterList) { - metricReporter.close(); - } + public void close() { + LOGGER.info("close metric service"); + for (MetricReporter metricReporter : reporterList) { + metricReporter.close(); } + } - private static final class ReporterTask extends TimerTask { + private static final class ReporterTask extends TimerTask { - private final ScheduledReporter reporter; + private final ScheduledReporter reporter; - private ReporterTask(ScheduledReporter reporter) { - this.reporter = reporter; - } + private ReporterTask(ScheduledReporter reporter) { + this.reporter = reporter; + } - @Override - public void run() { - try { - reporter.report(); - } catch (Throwable t) { - LOGGER.warn("Error while reporting metrics", t); - } - } + @Override + public void run() { + try { + reporter.report(); + } catch (Throwable t) { + LOGGER.warn("Error while reporting metrics", t); + } } + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricMeta.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricMeta.java index be80bda65..94c5512be 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricMeta.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricMeta.java @@ -23,67 +23,67 @@ public class MetricMeta implements Serializable { - private String jobName; - private String metricName; - private String metricType; - private String queries; - private String metricGroup; - private String metaType; - private String metaGroup; - - public String getJobName() { - return this.jobName; - } - - public void setJobName(String jobName) { - this.jobName = jobName; - } - - public String getMetricName() { - return this.metricName; - } - - public void setMetricName(String metricName) { - this.metricName = metricName; - } - - public String getMetricType() { - return this.metricType; - } - - public void setMetricType(String metricType) { - this.metricType = metricType; - } - - public String getQueries() { - return this.queries; - } - - public void setQueries(String queries) { - this.queries = queries; - } - - public String getMetricGroup() { - return this.metricGroup; - } - - public void setMetricGroup(String metricGroup) { - this.metricGroup = metricGroup; - } - - public String getMetaType() { - return this.metaType; - } - - public void setMetaType(String metaType) { - this.metaType = metaType; - } - - public String getMetaGroup() { - return this.metaGroup; - } - - public void setMetaGroup(String metaGroup) { - this.metaGroup = metaGroup; - } + private String jobName; + private String metricName; + private String metricType; + private String queries; + private String metricGroup; + private String metaType; + private String metaGroup; + + public String getJobName() { + return this.jobName; + } + + public void setJobName(String jobName) { + this.jobName = jobName; + } + + public String getMetricName() { + return this.metricName; + } + + public void setMetricName(String metricName) { + this.metricName = metricName; + } + + public String getMetricType() { + return this.metricType; + } + + public void setMetricType(String metricType) { + this.metricType = metricType; + } + + public String getQueries() { + return this.queries; + } + + public void setQueries(String queries) { + this.queries = queries; + } + + public String getMetricGroup() { + return this.metricGroup; + } + + public void setMetricGroup(String metricGroup) { + this.metricGroup = metricGroup; + } + + public String getMetaType() { + return this.metaType; + } + + public void setMetaType(String metaType) { + this.metaType = metaType; + } + + public String getMetaGroup() { + return this.metaGroup; + } + + public void setMetaGroup(String metaGroup) { + this.metaGroup = metaGroup; + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricNameFormatter.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricNameFormatter.java index 71852ddd5..81557c880 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricNameFormatter.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricNameFormatter.java @@ -23,163 +23,171 @@ public class MetricNameFormatter { - ////////////////////////////// - // System - - /// /////////////////////////// - - public static String totalHeapMetricName() { - return MetricConstants.METRIC_TOTAL_HEAP + MetricConstants.UNIT_N; - } - - public static String totalMemoryMetricName() { - return MetricConstants.METRIC_TOTAL_MEMORY + MetricConstants.UNIT_N; - } - - public static String heapUsageRatioMetricName() { - return MetricConstants.METRIC_HEAP_USAGE_RATIO + MetricConstants.UNIT_N; - } - - public static String gcTimeMetricName() { - return MetricConstants.METRIC_GC_TIME + MetricConstants.UNIT_MS; - } - - public static String fgcCountMetricName() { - return MetricConstants.METRIC_FGC_COUNT + MetricConstants.UNIT_N; - } - - public static String fgcTimeMetricName() { - return MetricConstants.METRIC_FGC_TIME + MetricConstants.UNIT_MS; - } - - ////////////////////////////// - // Framework - - /// /////////////////////////// - - public static String inputTpsMetricName(Class opClass, int opId) { - String metricName = String.format("%s[%d]%s", opClass.getSimpleName(), opId, MetricConstants.UNIT_N); - return MetricRegistry.name(MetricConstants.METRIC_INPUT_TPS, metricName); - } - - public static String outputTpsMetricName(Class opClass, int opId) { - String metricName = String.format("%s[%d]%s", opClass.getSimpleName(), opId, MetricConstants.UNIT_N); - return MetricRegistry.name(MetricConstants.METRIC_OUTPUT_TPS, metricName); - } - - public static String rtMetricName(Class opClass, int opId) { - String metricName = String.format("%s[%d]%s", opClass.getSimpleName(), opId, MetricConstants.UNIT_US); - return MetricRegistry.name(MetricConstants.METRIC_PROCESS_RT, metricName); - } - - public static String vertexTpsMetricName(Class opClass, int opId) { - String metricName = String.format("%s[%d]%s", opClass.getSimpleName(), opId, MetricConstants.UNIT_N); - return MetricRegistry.name(MetricConstants.METRIC_VERTEX_TPS, metricName); - } - - public static String edgeTpsMetricName(Class opClass, int opId) { - String metricName = String.format("%s[%d]%s", opClass.getSimpleName(), opId, MetricConstants.UNIT_N); - return MetricRegistry.name(MetricConstants.METRIC_EDGE_TPS, metricName); - } - - public static String iterationFinishMetricName(Class opClass, int opId, long iterationId) { - String metricName = String.format("%s[%d:%d]%s", opClass.getSimpleName(), opId, iterationId, MetricConstants.UNIT_MS); - return MetricRegistry.name(MetricConstants.METRIC_ITERATION, metricName); - } - - public static String iterationMsgMetricName(Class opClass, int opId) { - String metricName = String.format("%s[%d]%s", opClass.getSimpleName(), opId, MetricConstants.UNIT_N); - return MetricRegistry.name(MetricConstants.METRIC_ITERATION_MSG_TPS, metricName); - } - - public static String iterationAggMetricName(Class opClass, int opId) { - String metricName = String.format("%s[%d]%s", opClass.getSimpleName(), opId, MetricConstants.UNIT_N); - return MetricRegistry.name(MetricConstants.METRIC_ITERATION_AGG_TPS, metricName); - } - - ////////////////////////////// - // Dsl - - /// /////////////////////////// - - public static String tableInputRowName(String name) { - String metricName = String.format("%s%s", name, MetricConstants.UNIT_N); - return MetricRegistry.name(MetricConstants.METRIC_TABLE_INPUT_ROW, metricName); - } - - public static String stepInputRecordName(String name) { - String metricName = String.format("%s%s", name, MetricConstants.UNIT_N); - return MetricRegistry.name(MetricConstants.METRIC_STEP_INPUT_RECORD, metricName); - } - - public static String stepOutputRecordName(String name) { - String metricName = String.format("%s%s", name, MetricConstants.UNIT_N); - return MetricRegistry.name(MetricConstants.METRIC_STEP_OUTPUT_RECORD, metricName); - } - - public static String stepInputEodName(String name) { - String metricName = String.format("%s%s", name, MetricConstants.UNIT_N); - return MetricRegistry.name(MetricConstants.METRIC_STEP_INPUT_EOD, metricName); - } - - public static String tableOutputRowTpsName(String name) { - String metricName = String.format("%s%s", name, MetricConstants.UNIT_ROW_PER_S); - return MetricRegistry.name(MetricConstants.METRIC_TABLE_OUTPUT_ROW_TPS, metricName); - } - - public static String tableInputRowTpsName(String tableName) { - String metricName = String.format("%s%s", tableName, MetricConstants.UNIT_ROW_PER_S); - return MetricRegistry.name(MetricConstants.METRIC_TABLE_INPUT_ROW_TPS, metricName); - } - - public static String tableInputBlockTpsName(String name) { - String metricName = String.format("%s%s", name, MetricConstants.UNIT_BLOCK_PER_S); - return MetricRegistry.name(MetricConstants.METRIC_TABLE_INPUT_BLOCK_TPS, metricName); - } - - public static String stepInputRowTpsName(String name) { - String metricName = String.format("%s%s", name, MetricConstants.UNIT_ROW_PER_S); - return MetricRegistry.name(MetricConstants.METRIC_STEP_INPUT_ROW_TPS, metricName); - } - - public static String stepOutputRowTpsName(String name) { - String metricName = String.format("%s%s", name, MetricConstants.UNIT_ROW_PER_S); - return MetricRegistry.name(MetricConstants.METRIC_STEP_OUTPUT_ROW_TPS, metricName); - } - - public static String tableWriteTimeRtName(String name) { - String metricName = String.format("%s%s", name, MetricConstants.UNIT_MS); - return MetricRegistry.name(MetricConstants.METRIC_TABLE_WRITE_TIME_RT, metricName); - } - - public static String tableFlushTimeRtName(String name) { - String metricName = String.format("%s%s", name, MetricConstants.UNIT_MS); - return MetricRegistry.name(MetricConstants.METRIC_TABLE_FLUSH_TIME_RT, metricName); - } - - public static String tableParserTimeRtName(String name) { - String metricName = String.format("%s%s", name, MetricConstants.UNIT_US); - return MetricRegistry.name(MetricConstants.METRIC_TABLE_PARSER_TIME_RT, metricName); - } - - public static String stepProcessTimeRtName(String name) { - String metricName = String.format("%s%s", name, MetricConstants.UNIT_US); - return MetricRegistry.name(MetricConstants.METRIC_STEP_PROCESS_TIME_RT, metricName); - } - - public static String loadEdgeCountRtName(String name) { - String metricName = String.format("%s%s", name, MetricConstants.UNIT_N); - return MetricRegistry.name(MetricConstants.METRIC_LOAD_EDGE_COUNT_RT, metricName); - } - - public static String loadEdgeTimeRtName(String name) { - String metricName = String.format("%s%s", name, MetricConstants.UNIT_MS); - return MetricRegistry.name(MetricConstants.METRIC_LOAD_EDGE_TIME_RT, metricName); - } - - public static String loadVertexTimeRtName(String name) { - String metricName = String.format("%s%s", name, MetricConstants.UNIT_MS); - return MetricRegistry.name(MetricConstants.METRIC_LOAD_VERTEX_TIME_RT, metricName); - } - + ////////////////////////////// + // System + + /// /////////////////////////// + + public static String totalHeapMetricName() { + return MetricConstants.METRIC_TOTAL_HEAP + MetricConstants.UNIT_N; + } + + public static String totalMemoryMetricName() { + return MetricConstants.METRIC_TOTAL_MEMORY + MetricConstants.UNIT_N; + } + + public static String heapUsageRatioMetricName() { + return MetricConstants.METRIC_HEAP_USAGE_RATIO + MetricConstants.UNIT_N; + } + + public static String gcTimeMetricName() { + return MetricConstants.METRIC_GC_TIME + MetricConstants.UNIT_MS; + } + + public static String fgcCountMetricName() { + return MetricConstants.METRIC_FGC_COUNT + MetricConstants.UNIT_N; + } + + public static String fgcTimeMetricName() { + return MetricConstants.METRIC_FGC_TIME + MetricConstants.UNIT_MS; + } + + ////////////////////////////// + // Framework + + /// /////////////////////////// + + public static String inputTpsMetricName(Class opClass, int opId) { + String metricName = + String.format("%s[%d]%s", opClass.getSimpleName(), opId, MetricConstants.UNIT_N); + return MetricRegistry.name(MetricConstants.METRIC_INPUT_TPS, metricName); + } + + public static String outputTpsMetricName(Class opClass, int opId) { + String metricName = + String.format("%s[%d]%s", opClass.getSimpleName(), opId, MetricConstants.UNIT_N); + return MetricRegistry.name(MetricConstants.METRIC_OUTPUT_TPS, metricName); + } + + public static String rtMetricName(Class opClass, int opId) { + String metricName = + String.format("%s[%d]%s", opClass.getSimpleName(), opId, MetricConstants.UNIT_US); + return MetricRegistry.name(MetricConstants.METRIC_PROCESS_RT, metricName); + } + + public static String vertexTpsMetricName(Class opClass, int opId) { + String metricName = + String.format("%s[%d]%s", opClass.getSimpleName(), opId, MetricConstants.UNIT_N); + return MetricRegistry.name(MetricConstants.METRIC_VERTEX_TPS, metricName); + } + + public static String edgeTpsMetricName(Class opClass, int opId) { + String metricName = + String.format("%s[%d]%s", opClass.getSimpleName(), opId, MetricConstants.UNIT_N); + return MetricRegistry.name(MetricConstants.METRIC_EDGE_TPS, metricName); + } + + public static String iterationFinishMetricName(Class opClass, int opId, long iterationId) { + String metricName = + String.format( + "%s[%d:%d]%s", opClass.getSimpleName(), opId, iterationId, MetricConstants.UNIT_MS); + return MetricRegistry.name(MetricConstants.METRIC_ITERATION, metricName); + } + + public static String iterationMsgMetricName(Class opClass, int opId) { + String metricName = + String.format("%s[%d]%s", opClass.getSimpleName(), opId, MetricConstants.UNIT_N); + return MetricRegistry.name(MetricConstants.METRIC_ITERATION_MSG_TPS, metricName); + } + + public static String iterationAggMetricName(Class opClass, int opId) { + String metricName = + String.format("%s[%d]%s", opClass.getSimpleName(), opId, MetricConstants.UNIT_N); + return MetricRegistry.name(MetricConstants.METRIC_ITERATION_AGG_TPS, metricName); + } + + ////////////////////////////// + // Dsl + + /// /////////////////////////// + + public static String tableInputRowName(String name) { + String metricName = String.format("%s%s", name, MetricConstants.UNIT_N); + return MetricRegistry.name(MetricConstants.METRIC_TABLE_INPUT_ROW, metricName); + } + + public static String stepInputRecordName(String name) { + String metricName = String.format("%s%s", name, MetricConstants.UNIT_N); + return MetricRegistry.name(MetricConstants.METRIC_STEP_INPUT_RECORD, metricName); + } + + public static String stepOutputRecordName(String name) { + String metricName = String.format("%s%s", name, MetricConstants.UNIT_N); + return MetricRegistry.name(MetricConstants.METRIC_STEP_OUTPUT_RECORD, metricName); + } + + public static String stepInputEodName(String name) { + String metricName = String.format("%s%s", name, MetricConstants.UNIT_N); + return MetricRegistry.name(MetricConstants.METRIC_STEP_INPUT_EOD, metricName); + } + + public static String tableOutputRowTpsName(String name) { + String metricName = String.format("%s%s", name, MetricConstants.UNIT_ROW_PER_S); + return MetricRegistry.name(MetricConstants.METRIC_TABLE_OUTPUT_ROW_TPS, metricName); + } + + public static String tableInputRowTpsName(String tableName) { + String metricName = String.format("%s%s", tableName, MetricConstants.UNIT_ROW_PER_S); + return MetricRegistry.name(MetricConstants.METRIC_TABLE_INPUT_ROW_TPS, metricName); + } + + public static String tableInputBlockTpsName(String name) { + String metricName = String.format("%s%s", name, MetricConstants.UNIT_BLOCK_PER_S); + return MetricRegistry.name(MetricConstants.METRIC_TABLE_INPUT_BLOCK_TPS, metricName); + } + + public static String stepInputRowTpsName(String name) { + String metricName = String.format("%s%s", name, MetricConstants.UNIT_ROW_PER_S); + return MetricRegistry.name(MetricConstants.METRIC_STEP_INPUT_ROW_TPS, metricName); + } + + public static String stepOutputRowTpsName(String name) { + String metricName = String.format("%s%s", name, MetricConstants.UNIT_ROW_PER_S); + return MetricRegistry.name(MetricConstants.METRIC_STEP_OUTPUT_ROW_TPS, metricName); + } + + public static String tableWriteTimeRtName(String name) { + String metricName = String.format("%s%s", name, MetricConstants.UNIT_MS); + return MetricRegistry.name(MetricConstants.METRIC_TABLE_WRITE_TIME_RT, metricName); + } + + public static String tableFlushTimeRtName(String name) { + String metricName = String.format("%s%s", name, MetricConstants.UNIT_MS); + return MetricRegistry.name(MetricConstants.METRIC_TABLE_FLUSH_TIME_RT, metricName); + } + + public static String tableParserTimeRtName(String name) { + String metricName = String.format("%s%s", name, MetricConstants.UNIT_US); + return MetricRegistry.name(MetricConstants.METRIC_TABLE_PARSER_TIME_RT, metricName); + } + + public static String stepProcessTimeRtName(String name) { + String metricName = String.format("%s%s", name, MetricConstants.UNIT_US); + return MetricRegistry.name(MetricConstants.METRIC_STEP_PROCESS_TIME_RT, metricName); + } + + public static String loadEdgeCountRtName(String name) { + String metricName = String.format("%s%s", name, MetricConstants.UNIT_N); + return MetricRegistry.name(MetricConstants.METRIC_LOAD_EDGE_COUNT_RT, metricName); + } + + public static String loadEdgeTimeRtName(String name) { + String metricName = String.format("%s%s", name, MetricConstants.UNIT_MS); + return MetricRegistry.name(MetricConstants.METRIC_LOAD_EDGE_TIME_RT, metricName); + } + + public static String loadVertexTimeRtName(String name) { + String metricName = String.format("%s%s", name, MetricConstants.UNIT_MS); + return MetricRegistry.name(MetricConstants.METRIC_LOAD_VERTEX_TIME_RT, metricName); + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricType.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricType.java index 4aec2f6cd..2233db7e0 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricType.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/MetricType.java @@ -21,21 +21,12 @@ public enum MetricType { - /** - * A gauge is the simplest metric type. It just returns a value. - */ - GAUGE, - /** - * A counter is a simple incrementing and decrementing 64-bit integer. - */ - COUNTER, - /** - * A meter measures the rate at which a set of events occur. - */ - METER, - /** - * A histogram measures the distribution of values in a stream of data. - */ - HISTOGRAM - + /** A gauge is the simplest metric type. It just returns a value. */ + GAUGE, + /** A counter is a simple incrementing and decrementing 64-bit integer. */ + COUNTER, + /** A meter measures the rate at which a set of events occur. */ + METER, + /** A histogram measures the distribution of values in a stream of data. */ + HISTOGRAM } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/BlackHoleMetricGroup.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/BlackHoleMetricGroup.java index c92f29178..266e36ac0 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/BlackHoleMetricGroup.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/BlackHoleMetricGroup.java @@ -21,145 +21,129 @@ public class BlackHoleMetricGroup implements MetricGroup { - public static final BlackHoleMetricGroup INSTANCE = new BlackHoleMetricGroup(); + public static final BlackHoleMetricGroup INSTANCE = new BlackHoleMetricGroup(); - @Override - public void register(String name, Gauge gauge) { - } + @Override + public void register(String name, Gauge gauge) {} - @SuppressWarnings("unchecked") - @Override - public Gauge gauge(String name) { - return (Gauge) BlackHoleGauge.INSTANCE; - } + @SuppressWarnings("unchecked") + @Override + public Gauge gauge(String name) { + return (Gauge) BlackHoleGauge.INSTANCE; + } - @Override - public Counter counter(String name) { - return BlackHoleCounter.INSTANCE; - } + @Override + public Counter counter(String name) { + return BlackHoleCounter.INSTANCE; + } - @Override - public Meter meter(String name) { - return BlackHoleMeter.INSTANCE; - } + @Override + public Meter meter(String name) { + return BlackHoleMeter.INSTANCE; + } - @Override - public Histogram histogram(String name) { - return BlackHoleHistogram.INSTANCE; - } + @Override + public Histogram histogram(String name) { + return BlackHoleHistogram.INSTANCE; + } - @Override - public void remove(String name) { - } - - public static class BlackHoleGauge implements Gauge { - - public static final BlackHoleGauge INSTANCE = new BlackHoleGauge<>(); + @Override + public void remove(String name) {} - @Override - public T getValue() { - return null; - } + public static class BlackHoleGauge implements Gauge { - @Override - public void setValue(T value) { - } + public static final BlackHoleGauge INSTANCE = new BlackHoleGauge<>(); + @Override + public T getValue() { + return null; } - public static class BlackHoleCounter implements Counter { - - public static final BlackHoleCounter INSTANCE = new BlackHoleCounter(); + @Override + public void setValue(T value) {} + } - @Override - public void inc() { - } + public static class BlackHoleCounter implements Counter { - @Override - public void inc(long n) { - } + public static final BlackHoleCounter INSTANCE = new BlackHoleCounter(); - @Override - public void dec() { - } + @Override + public void inc() {} - @Override - public void dec(long n) { - } + @Override + public void inc(long n) {} - @Override - public long getCount() { - return 0; - } + @Override + public void dec() {} - @Override - public long getCountAndReset() { - return 0; - } + @Override + public void dec(long n) {} + @Override + public long getCount() { + return 0; } - public static class BlackHoleMeter implements Meter { - - public static final BlackHoleMeter INSTANCE = new BlackHoleMeter(); + @Override + public long getCountAndReset() { + return 0; + } + } - @Override - public void mark() { - } + public static class BlackHoleMeter implements Meter { - @Override - public void mark(long n) { - } + public static final BlackHoleMeter INSTANCE = new BlackHoleMeter(); - @Override - public double getFifteenMinuteRate() { - return 0; - } + @Override + public void mark() {} - @Override - public double getFiveMinuteRate() { - return 0; - } + @Override + public void mark(long n) {} - @Override - public double getMeanRate() { - return 0; - } + @Override + public double getFifteenMinuteRate() { + return 0; + } - @Override - public double getOneMinuteRate() { - return 0; - } + @Override + public double getFiveMinuteRate() { + return 0; + } - @Override - public long getCount() { - return 0; - } + @Override + public double getMeanRate() { + return 0; + } - @Override - public long getCountAndReset() { - return 0; - } + @Override + public double getOneMinuteRate() { + return 0; + } + @Override + public long getCount() { + return 0; } - public static class BlackHoleHistogram implements Histogram { + @Override + public long getCountAndReset() { + return 0; + } + } - public static final BlackHoleHistogram INSTANCE = new BlackHoleHistogram(); + public static class BlackHoleHistogram implements Histogram { - @Override - public void update(int value) { - } + public static final BlackHoleHistogram INSTANCE = new BlackHoleHistogram(); - @Override - public void update(long value) { - } + @Override + public void update(int value) {} - @Override - public long getCount() { - return 0; - } + @Override + public void update(long value) {} + @Override + public long getCount() { + return 0; } - + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/Counter.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/Counter.java index 77b1db59c..071e3a2af 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/Counter.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/Counter.java @@ -21,15 +21,15 @@ public interface Counter { - void inc(); + void inc(); - void inc(long n); + void inc(long n); - void dec(); + void dec(); - void dec(long n); + void dec(long n); - long getCount(); + long getCount(); - long getCountAndReset(); + long getCountAndReset(); } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/CounterImpl.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/CounterImpl.java index bfbc6a22d..a8e159eee 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/CounterImpl.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/CounterImpl.java @@ -19,39 +19,37 @@ package org.apache.geaflow.metrics.common.api; -import com.codahale.metrics.Counter; import java.util.concurrent.atomic.LongAdder; -public class CounterImpl extends Counter implements - org.apache.geaflow.metrics.common.api.Counter { +import com.codahale.metrics.Counter; - private final LongAdder count = new LongAdder(); +public class CounterImpl extends Counter implements org.apache.geaflow.metrics.common.api.Counter { - public CounterImpl() { - } + private final LongAdder count = new LongAdder(); - public void inc() { - this.inc(1L); - } + public CounterImpl() {} - public void inc(long n) { - this.count.add(n); - } + public void inc() { + this.inc(1L); + } - public void dec() { - this.dec(1L); - } + public void inc(long n) { + this.count.add(n); + } - public void dec(long n) { - this.count.add(-n); - } + public void dec() { + this.dec(1L); + } - public long getCount() { - return this.count.sum(); - } + public void dec(long n) { + this.count.add(-n); + } - public long getCountAndReset() { - return this.count.sumThenReset(); - } + public long getCount() { + return this.count.sum(); + } + public long getCountAndReset() { + return this.count.sumThenReset(); + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/Gauge.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/Gauge.java index b47c5f6d5..5234647e8 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/Gauge.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/Gauge.java @@ -21,8 +21,7 @@ public interface Gauge extends com.codahale.metrics.Gauge { - T getValue(); - - void setValue(T value); + T getValue(); + void setValue(T value); } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/GaugeImpl.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/GaugeImpl.java index bf69b1ee8..ec1c565d7 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/GaugeImpl.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/GaugeImpl.java @@ -21,24 +21,23 @@ public class GaugeImpl implements Gauge { - private volatile T value; + private volatile T value; - public GaugeImpl() { - this(null); - } + public GaugeImpl() { + this(null); + } - public GaugeImpl(T defaultValue) { - this.value = defaultValue; - } + public GaugeImpl(T defaultValue) { + this.value = defaultValue; + } - @Override - public void setValue(T value) { - this.value = value; - } - - @Override - public T getValue() { - return value; - } + @Override + public void setValue(T value) { + this.value = value; + } + @Override + public T getValue() { + return value; + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/Histogram.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/Histogram.java index 1f7ae352c..c91531b07 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/Histogram.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/Histogram.java @@ -21,10 +21,9 @@ public interface Histogram { - void update(int value); + void update(int value); - void update(long value); - - long getCount(); + void update(long value); + long getCount(); } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/HistogramImpl.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/HistogramImpl.java index 4d51180f9..dcbbd192c 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/HistogramImpl.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/HistogramImpl.java @@ -21,12 +21,9 @@ import com.codahale.metrics.Reservoir; -public class HistogramImpl - extends com.codahale.metrics.Histogram - implements Histogram { - - public HistogramImpl(Reservoir reservoir) { - super(reservoir); - } +public class HistogramImpl extends com.codahale.metrics.Histogram implements Histogram { + public HistogramImpl(Reservoir reservoir) { + super(reservoir); + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/Meter.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/Meter.java index 06fab4d92..29c06232b 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/Meter.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/Meter.java @@ -21,19 +21,19 @@ public interface Meter { - void mark(); + void mark(); - void mark(long n); + void mark(long n); - double getFifteenMinuteRate(); + double getFifteenMinuteRate(); - double getFiveMinuteRate(); + double getFiveMinuteRate(); - double getMeanRate(); + double getMeanRate(); - double getOneMinuteRate(); + double getOneMinuteRate(); - long getCount(); + long getCount(); - long getCountAndReset(); + long getCountAndReset(); } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/MeterImpl.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/MeterImpl.java index 309165a32..179c88152 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/MeterImpl.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/MeterImpl.java @@ -19,79 +19,79 @@ package org.apache.geaflow.metrics.common.api; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.LongAdder; + import com.codahale.metrics.Clock; import com.codahale.metrics.ExponentialMovingAverages; import com.codahale.metrics.Meter; import com.codahale.metrics.MovingAverages; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.LongAdder; public class MeterImpl extends Meter implements org.apache.geaflow.metrics.common.api.Meter { - private final MovingAverages movingAverages; - private final LongAdder count; - private final long startTime; - private final Clock clock; - - public MeterImpl(MovingAverages movingAverages) { - this(movingAverages, Clock.defaultClock()); - } - - public MeterImpl() { - this(Clock.defaultClock()); - } - - public MeterImpl(Clock clock) { - this(new ExponentialMovingAverages(clock), clock); - } - - public MeterImpl(MovingAverages movingAverages, Clock clock) { - this.count = new LongAdder(); - this.movingAverages = movingAverages; - this.clock = clock; - this.startTime = this.clock.getTick(); - } - - public void mark() { - this.mark(1L); - } - - public void mark(long n) { - this.movingAverages.tickIfNecessary(); - this.count.add(n); - this.movingAverages.update(n); + private final MovingAverages movingAverages; + private final LongAdder count; + private final long startTime; + private final Clock clock; + + public MeterImpl(MovingAverages movingAverages) { + this(movingAverages, Clock.defaultClock()); + } + + public MeterImpl() { + this(Clock.defaultClock()); + } + + public MeterImpl(Clock clock) { + this(new ExponentialMovingAverages(clock), clock); + } + + public MeterImpl(MovingAverages movingAverages, Clock clock) { + this.count = new LongAdder(); + this.movingAverages = movingAverages; + this.clock = clock; + this.startTime = this.clock.getTick(); + } + + public void mark() { + this.mark(1L); + } + + public void mark(long n) { + this.movingAverages.tickIfNecessary(); + this.count.add(n); + this.movingAverages.update(n); + } + + public double getFifteenMinuteRate() { + this.movingAverages.tickIfNecessary(); + return this.movingAverages.getM15Rate(); + } + + public double getFiveMinuteRate() { + this.movingAverages.tickIfNecessary(); + return this.movingAverages.getM5Rate(); + } + + public double getMeanRate() { + if (this.getCount() == 0L) { + return 0.0D; + } else { + double elapsed = (double) (this.clock.getTick() - this.startTime); + return (double) this.getCount() / elapsed * (double) TimeUnit.SECONDS.toNanos(1L); } + } - public double getFifteenMinuteRate() { - this.movingAverages.tickIfNecessary(); - return this.movingAverages.getM15Rate(); - } + public double getOneMinuteRate() { + this.movingAverages.tickIfNecessary(); + return this.movingAverages.getM1Rate(); + } - public double getFiveMinuteRate() { - this.movingAverages.tickIfNecessary(); - return this.movingAverages.getM5Rate(); - } - - public double getMeanRate() { - if (this.getCount() == 0L) { - return 0.0D; - } else { - double elapsed = (double) (this.clock.getTick() - this.startTime); - return (double) this.getCount() / elapsed * (double) TimeUnit.SECONDS.toNanos(1L); - } - } - - public double getOneMinuteRate() { - this.movingAverages.tickIfNecessary(); - return this.movingAverages.getM1Rate(); - } - - public long getCount() { - return this.count.sum(); - } - - public long getCountAndReset() { - return this.count.sumThenReset(); - } + public long getCount() { + return this.count.sum(); + } + public long getCountAndReset() { + return this.count.sumThenReset(); + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/MetricGroup.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/MetricGroup.java index 0d353373b..2ebda341b 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/MetricGroup.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/MetricGroup.java @@ -21,56 +21,53 @@ import java.io.Serializable; -/** - * Interface to register or get metric. - */ +/** Interface to register or get metric. */ public interface MetricGroup extends Serializable { - /** - * Registers a new {@link Gauge}. - * - * @param name name of the gauge - * @param gauge gauge to register - */ - void register(String name, Gauge gauge); - - /** - * Register a {@link com.codahale.metrics.SettableGauge} or get a existing one. - * - * @param name gauge name - * @return gauge - */ - Gauge gauge(String name); + /** + * Registers a new {@link Gauge}. + * + * @param name name of the gauge + * @param gauge gauge to register + */ + void register(String name, Gauge gauge); - /** - * Register or get a {@link Counter}. - * - * @param name name of the counter - * @return the created counter - */ - Counter counter(String name); + /** + * Register a {@link com.codahale.metrics.SettableGauge} or get a existing one. + * + * @param name gauge name + * @return gauge + */ + Gauge gauge(String name); - /** - * Registers or get a {@link Meter}. - * - * @param name name of the meter - * @return the registered meter - */ - Meter meter(String name); + /** + * Register or get a {@link Counter}. + * + * @param name name of the counter + * @return the created counter + */ + Counter counter(String name); - /** - * Registers or get a {@link Histogram}. - * - * @param name name of the histogram - * @return the registered histogram - */ - Histogram histogram(String name); + /** + * Registers or get a {@link Meter}. + * + * @param name name of the meter + * @return the registered meter + */ + Meter meter(String name); - /** - * remove a metric by name. - * - * @param name metricName. - */ - void remove(String name); + /** + * Registers or get a {@link Histogram}. + * + * @param name name of the histogram + * @return the registered histogram + */ + Histogram histogram(String name); + /** + * remove a metric by name. + * + * @param name metricName. + */ + void remove(String name); } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/MetricGroupImpl.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/MetricGroupImpl.java index d71492c37..0481d17d6 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/MetricGroupImpl.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/api/MetricGroupImpl.java @@ -19,56 +19,58 @@ package org.apache.geaflow.metrics.common.api; +import org.apache.geaflow.metrics.common.MetricConstants; + import com.codahale.metrics.ExponentiallyDecayingReservoir; import com.codahale.metrics.MetricRegistry; -import org.apache.geaflow.metrics.common.MetricConstants; public class MetricGroupImpl implements MetricGroup { - private final String groupName; - private final MetricRegistry metricRegistry; - - public MetricGroupImpl(MetricRegistry metricRegistry) { - this(MetricConstants.MODULE_DEFAULT, metricRegistry); - } + private final String groupName; + private final MetricRegistry metricRegistry; - public MetricGroupImpl(String name, MetricRegistry metricRegistry) { - this.groupName = name == null ? MetricConstants.MODULE_DEFAULT : name; - this.metricRegistry = metricRegistry; - } + public MetricGroupImpl(MetricRegistry metricRegistry) { + this(MetricConstants.MODULE_DEFAULT, metricRegistry); + } - @Override - public void register(String name, Gauge gauge) { - metricRegistry.register(getMetricName(name), gauge); - } + public MetricGroupImpl(String name, MetricRegistry metricRegistry) { + this.groupName = name == null ? MetricConstants.MODULE_DEFAULT : name; + this.metricRegistry = metricRegistry; + } - @Override - public Gauge gauge(String name) { - return metricRegistry.gauge(getMetricName(name), GaugeImpl::new); - } + @Override + public void register(String name, Gauge gauge) { + metricRegistry.register(getMetricName(name), gauge); + } - @Override - public Counter counter(String name) { - return (Counter) metricRegistry.counter(getMetricName(name), CounterImpl::new); - } + @Override + public Gauge gauge(String name) { + return metricRegistry.gauge(getMetricName(name), GaugeImpl::new); + } - @Override - public Meter meter(String name) { - return (Meter) metricRegistry.meter(getMetricName(name), MeterImpl::new); - } + @Override + public Counter counter(String name) { + return (Counter) metricRegistry.counter(getMetricName(name), CounterImpl::new); + } - @Override - public Histogram histogram(String name) { - return (Histogram) metricRegistry.histogram(getMetricName(name), () -> new HistogramImpl(new ExponentiallyDecayingReservoir())); - } + @Override + public Meter meter(String name) { + return (Meter) metricRegistry.meter(getMetricName(name), MeterImpl::new); + } - @Override - public void remove(String name) { - metricRegistry.remove(getMetricName(name)); - } + @Override + public Histogram histogram(String name) { + return (Histogram) + metricRegistry.histogram( + getMetricName(name), () -> new HistogramImpl(new ExponentiallyDecayingReservoir())); + } - public String getMetricName(String name) { - return this.groupName + MetricConstants.GROUP_DELIMITER + name; - } + @Override + public void remove(String name) { + metricRegistry.remove(getMetricName(name)); + } + public String getMetricName(String name) { + return this.groupName + MetricConstants.GROUP_DELIMITER + name; + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/reporter/MetricReporter.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/reporter/MetricReporter.java index c00b1e6dc..4d8c9081a 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/reporter/MetricReporter.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/reporter/MetricReporter.java @@ -19,15 +19,15 @@ package org.apache.geaflow.metrics.common.reporter; -import com.codahale.metrics.MetricRegistry; import org.apache.geaflow.common.config.Configuration; -public interface MetricReporter { +import com.codahale.metrics.MetricRegistry; - void open(Configuration config, MetricRegistry metricRegistry); +public interface MetricReporter { - void close(); + void open(Configuration config, MetricRegistry metricRegistry); - String getReporterType(); + void close(); + String getReporterType(); } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/reporter/MetricReporterFactory.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/reporter/MetricReporterFactory.java index 36afb401b..e8588eda0 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/reporter/MetricReporterFactory.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/reporter/MetricReporterFactory.java @@ -20,17 +20,17 @@ package org.apache.geaflow.metrics.common.reporter; import java.util.ServiceLoader; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class MetricReporterFactory { - public static MetricReporter getMetricReporter(String reporterType) { - for (MetricReporter reporter : ServiceLoader.load(MetricReporter.class)) { - if (reporter.getReporterType().equalsIgnoreCase(reporterType)) { - return reporter; - } - } - throw new GeaflowRuntimeException("no metric reporter implement found for " + reporterType); + public static MetricReporter getMetricReporter(String reporterType) { + for (MetricReporter reporter : ServiceLoader.load(MetricReporter.class)) { + if (reporter.getReporterType().equalsIgnoreCase(reporterType)) { + return reporter; + } } - + throw new GeaflowRuntimeException("no metric reporter implement found for " + reporterType); + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/reporter/ReporterRegistry.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/reporter/ReporterRegistry.java index 21548db51..9421a2969 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/reporter/ReporterRegistry.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/reporter/ReporterRegistry.java @@ -24,23 +24,22 @@ public class ReporterRegistry { - public static final String SLF4J_REPORTER = "slf4j"; - public static final String TSDB_REPORTER = "tsdb"; - public static final String INFLUXDB_REPORTER = "influxdb"; - private static final Map NAME_CLASS_MAP = new HashMap<>(); - - static { - register(SLF4J_REPORTER, "org.apache.geaflow.metrics.slf4j.Slf4jReporter"); - register(TSDB_REPORTER, "org.apache.geaflow.metrics.tsdb.TsdbMetricReporter"); - register(INFLUXDB_REPORTER, "org.apache.geaflow.metrics.influxdb.InfluxdbReporter"); - } - - public static void register(String name, String className) { - NAME_CLASS_MAP.put(name, className); - } - - public static String getClassByName(String reporterName) { - return NAME_CLASS_MAP.get(reporterName); - } - + public static final String SLF4J_REPORTER = "slf4j"; + public static final String TSDB_REPORTER = "tsdb"; + public static final String INFLUXDB_REPORTER = "influxdb"; + private static final Map NAME_CLASS_MAP = new HashMap<>(); + + static { + register(SLF4J_REPORTER, "org.apache.geaflow.metrics.slf4j.Slf4jReporter"); + register(TSDB_REPORTER, "org.apache.geaflow.metrics.tsdb.TsdbMetricReporter"); + register(INFLUXDB_REPORTER, "org.apache.geaflow.metrics.influxdb.InfluxdbReporter"); + } + + public static void register(String name, String className) { + NAME_CLASS_MAP.put(name, className); + } + + public static String getClassByName(String reporterName) { + return NAME_CLASS_MAP.get(reporterName); + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/reporter/ScheduledReporter.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/reporter/ScheduledReporter.java index 7c913e16e..eece36930 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/reporter/ScheduledReporter.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/main/java/org/apache/geaflow/metrics/common/reporter/ScheduledReporter.java @@ -23,6 +23,5 @@ public interface ScheduledReporter extends Serializable { - void report(); - + void report(); } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/test/java/org/apache/geaflow/metrics/common/MetricGroupTest.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/test/java/org/apache/geaflow/metrics/common/MetricGroupTest.java index 75fe3f5ab..6d8cfc6e4 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/test/java/org/apache/geaflow/metrics/common/MetricGroupTest.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/test/java/org/apache/geaflow/metrics/common/MetricGroupTest.java @@ -19,7 +19,6 @@ package org.apache.geaflow.metrics.common; -import com.codahale.metrics.MetricRegistry; import org.apache.geaflow.metrics.common.api.Counter; import org.apache.geaflow.metrics.common.api.Gauge; import org.apache.geaflow.metrics.common.api.Meter; @@ -27,51 +26,52 @@ import org.testng.Assert; import org.testng.annotations.Test; -public class MetricGroupTest { +import com.codahale.metrics.MetricRegistry; - @Test - public void testGauge() { - MetricRegistry registry = new MetricRegistry(); - MetricGroupImpl metricGroup = new MetricGroupImpl(registry); +public class MetricGroupTest { - String gaugeName = "newGauge"; - Gauge gauge = metricGroup.gauge(metricGroup.getMetricName(gaugeName)); - gauge.setValue(1.0); - Assert.assertEquals(gauge.getValue(), 1.0); - } + @Test + public void testGauge() { + MetricRegistry registry = new MetricRegistry(); + MetricGroupImpl metricGroup = new MetricGroupImpl(registry); - @Test - public void testCounter() { - MetricRegistry registry = new MetricRegistry(); - MetricGroupImpl metricGroup = new MetricGroupImpl(registry); + String gaugeName = "newGauge"; + Gauge gauge = metricGroup.gauge(metricGroup.getMetricName(gaugeName)); + gauge.setValue(1.0); + Assert.assertEquals(gauge.getValue(), 1.0); + } - String counterName = "newCounter"; - Counter counter = metricGroup.counter(counterName); - counter.inc(); - counter.inc(1); - Assert.assertEquals(counter.getCount(), 2); - Assert.assertTrue(counter instanceof com.codahale.metrics.Counter); + @Test + public void testCounter() { + MetricRegistry registry = new MetricRegistry(); + MetricGroupImpl metricGroup = new MetricGroupImpl(registry); - Counter counter1 = (Counter) registry.getCounters().get(metricGroup.getMetricName(counterName)); - Assert.assertEquals(counter1.getCountAndReset(), 2); - Assert.assertEquals(counter1.getCount(), 0); - } + String counterName = "newCounter"; + Counter counter = metricGroup.counter(counterName); + counter.inc(); + counter.inc(1); + Assert.assertEquals(counter.getCount(), 2); + Assert.assertTrue(counter instanceof com.codahale.metrics.Counter); - @Test - public void testMeter() { - MetricRegistry registry = new MetricRegistry(); - MetricGroupImpl metricGroup = new MetricGroupImpl(registry); + Counter counter1 = (Counter) registry.getCounters().get(metricGroup.getMetricName(counterName)); + Assert.assertEquals(counter1.getCountAndReset(), 2); + Assert.assertEquals(counter1.getCount(), 0); + } - String meterName = "testMeter"; - Meter meter = metricGroup.meter(meterName); - meter.mark(); - meter.mark(1); - Assert.assertEquals(meter.getCount(), 2); - Assert.assertTrue(meter instanceof com.codahale.metrics.Meter); + @Test + public void testMeter() { + MetricRegistry registry = new MetricRegistry(); + MetricGroupImpl metricGroup = new MetricGroupImpl(registry); - Meter meter1 = (Meter) registry.getMeters().get(metricGroup.getMetricName(meterName)); - Assert.assertEquals(meter1.getCountAndReset(), 2); - Assert.assertEquals(meter1.getCount(), 0); - } + String meterName = "testMeter"; + Meter meter = metricGroup.meter(meterName); + meter.mark(); + meter.mark(1); + Assert.assertEquals(meter.getCount(), 2); + Assert.assertTrue(meter instanceof com.codahale.metrics.Meter); + Meter meter1 = (Meter) registry.getMeters().get(metricGroup.getMetricName(meterName)); + Assert.assertEquals(meter1.getCountAndReset(), 2); + Assert.assertEquals(meter1.getCount(), 0); + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-common/src/test/java/org/apache/geaflow/metrics/common/MetricNameFormatterTest.java b/geaflow/geaflow-metrics/geaflow-metrics-common/src/test/java/org/apache/geaflow/metrics/common/MetricNameFormatterTest.java index 64a092a5d..49c240f90 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-common/src/test/java/org/apache/geaflow/metrics/common/MetricNameFormatterTest.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-common/src/test/java/org/apache/geaflow/metrics/common/MetricNameFormatterTest.java @@ -24,23 +24,24 @@ public class MetricNameFormatterTest { - @Test - public void testMetricNameFormatter() { - int opId = 1; - String opInputMetricName = MetricNameFormatter.inputTpsMetricName(this.getClass(), opId); - Assert.assertEquals(opInputMetricName, "inputTps.MetricNameFormatterTest[1](N)"); - String opOutputMetricName = MetricNameFormatter.outputTpsMetricName(this.getClass(), opId); - Assert.assertEquals(opOutputMetricName, "outputTps.MetricNameFormatterTest[1](N)"); - String opRtMetricName = MetricNameFormatter.rtMetricName(this.getClass(), opId); - Assert.assertEquals(opRtMetricName, "processRt.MetricNameFormatterTest[1](us)"); - String vertexTpsMetricName = MetricNameFormatter.vertexTpsMetricName(this.getClass(), opId); - Assert.assertEquals(vertexTpsMetricName, "vertexTps.MetricNameFormatterTest[1](N)"); - String edgeTpsMetricName = MetricNameFormatter.edgeTpsMetricName(this.getClass(), opId); - Assert.assertEquals(edgeTpsMetricName, "edgeTps.MetricNameFormatterTest[1](N)"); - String iterationFinishMetricName = MetricNameFormatter.iterationFinishMetricName(this.getClass(), opId, 0); - Assert.assertEquals(iterationFinishMetricName, "iteration.MetricNameFormatterTest[1:0](ms)"); - String iterationMsgMetricName = MetricNameFormatter.iterationMsgMetricName(this.getClass(), opId); - Assert.assertEquals(iterationMsgMetricName, "iterationMsgTps.MetricNameFormatterTest[1](N)"); - } - + @Test + public void testMetricNameFormatter() { + int opId = 1; + String opInputMetricName = MetricNameFormatter.inputTpsMetricName(this.getClass(), opId); + Assert.assertEquals(opInputMetricName, "inputTps.MetricNameFormatterTest[1](N)"); + String opOutputMetricName = MetricNameFormatter.outputTpsMetricName(this.getClass(), opId); + Assert.assertEquals(opOutputMetricName, "outputTps.MetricNameFormatterTest[1](N)"); + String opRtMetricName = MetricNameFormatter.rtMetricName(this.getClass(), opId); + Assert.assertEquals(opRtMetricName, "processRt.MetricNameFormatterTest[1](us)"); + String vertexTpsMetricName = MetricNameFormatter.vertexTpsMetricName(this.getClass(), opId); + Assert.assertEquals(vertexTpsMetricName, "vertexTps.MetricNameFormatterTest[1](N)"); + String edgeTpsMetricName = MetricNameFormatter.edgeTpsMetricName(this.getClass(), opId); + Assert.assertEquals(edgeTpsMetricName, "edgeTps.MetricNameFormatterTest[1](N)"); + String iterationFinishMetricName = + MetricNameFormatter.iterationFinishMetricName(this.getClass(), opId, 0); + Assert.assertEquals(iterationFinishMetricName, "iteration.MetricNameFormatterTest[1:0](ms)"); + String iterationMsgMetricName = + MetricNameFormatter.iterationMsgMetricName(this.getClass(), opId); + Assert.assertEquals(iterationMsgMetricName, "iterationMsgTps.MetricNameFormatterTest[1](N)"); + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-influxdb/src/main/java/org/apache/geaflow/metrics/influxdb/InfluxdbConfig.java b/geaflow/geaflow-metrics/geaflow-metrics-influxdb/src/main/java/org/apache/geaflow/metrics/influxdb/InfluxdbConfig.java index a1f2ab56e..0c8b91e30 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-influxdb/src/main/java/org/apache/geaflow/metrics/influxdb/InfluxdbConfig.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-influxdb/src/main/java/org/apache/geaflow/metrics/influxdb/InfluxdbConfig.java @@ -23,56 +23,65 @@ public class InfluxdbConfig { - private final String url; - private final String token; - private final String org; - private final String bucket; - private final long connectTimeoutMs; - private final long writeTimeoutMs; + private final String url; + private final String token; + private final String org; + private final String bucket; + private final long connectTimeoutMs; + private final long writeTimeoutMs; - public InfluxdbConfig(Configuration config) { - this.url = config.getString(InfluxdbConfigKeys.URL); - this.token = config.getString(InfluxdbConfigKeys.TOKEN); - this.org = config.getString(InfluxdbConfigKeys.ORG); - this.bucket = config.getString(InfluxdbConfigKeys.BUCKET); - this.connectTimeoutMs = config.getLong(InfluxdbConfigKeys.CONNECT_TIMEOUT_MS); - this.writeTimeoutMs = config.getLong(InfluxdbConfigKeys.WRITE_TIMEOUT_MS); - } + public InfluxdbConfig(Configuration config) { + this.url = config.getString(InfluxdbConfigKeys.URL); + this.token = config.getString(InfluxdbConfigKeys.TOKEN); + this.org = config.getString(InfluxdbConfigKeys.ORG); + this.bucket = config.getString(InfluxdbConfigKeys.BUCKET); + this.connectTimeoutMs = config.getLong(InfluxdbConfigKeys.CONNECT_TIMEOUT_MS); + this.writeTimeoutMs = config.getLong(InfluxdbConfigKeys.WRITE_TIMEOUT_MS); + } - public String getUrl() { - return this.url; - } + public String getUrl() { + return this.url; + } - public String getToken() { - return this.token; - } + public String getToken() { + return this.token; + } - public String getOrg() { - return this.org; - } + public String getOrg() { + return this.org; + } - public String getBucket() { - return this.bucket; - } + public String getBucket() { + return this.bucket; + } - public long getConnectTimeoutMs() { - return this.connectTimeoutMs; - } + public long getConnectTimeoutMs() { + return this.connectTimeoutMs; + } - public long getWriteTimeoutMs() { - return this.writeTimeoutMs; - } - - @Override - public String toString() { - return "InfluxdbConfig{" - + "url='" + url + '\'' - + ", token='" + token + '\'' - + ", org='" + org + '\'' - + ", bucket='" + bucket + '\'' - + ", connectTimeoutMs=" + connectTimeoutMs - + ", writeTimeoutMs=" + writeTimeoutMs - + '}'; - } + public long getWriteTimeoutMs() { + return this.writeTimeoutMs; + } + @Override + public String toString() { + return "InfluxdbConfig{" + + "url='" + + url + + '\'' + + ", token='" + + token + + '\'' + + ", org='" + + org + + '\'' + + ", bucket='" + + bucket + + '\'' + + ", connectTimeoutMs=" + + connectTimeoutMs + + ", writeTimeoutMs=" + + writeTimeoutMs + + '}'; + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-influxdb/src/main/java/org/apache/geaflow/metrics/influxdb/InfluxdbConfigKeys.java b/geaflow/geaflow-metrics/geaflow-metrics-influxdb/src/main/java/org/apache/geaflow/metrics/influxdb/InfluxdbConfigKeys.java index 8a962f90b..edb9dfd55 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-influxdb/src/main/java/org/apache/geaflow/metrics/influxdb/InfluxdbConfigKeys.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-influxdb/src/main/java/org/apache/geaflow/metrics/influxdb/InfluxdbConfigKeys.java @@ -24,34 +24,33 @@ public class InfluxdbConfigKeys { - public static final ConfigKey URL = ConfigKeys - .key("geaflow.metric.influxdb.url") - .noDefaultValue() - .description("influxdb url, e.g. http://localhost:8086"); - - public static final ConfigKey TOKEN = ConfigKeys - .key("geaflow.metric.influxdb.token") - .noDefaultValue() - .description("influxdb token, for authorization"); - - public static final ConfigKey ORG = ConfigKeys - .key("geaflow.metric.influxdb.org") - .noDefaultValue() - .description("influxdb organization of the bucket"); - - public static final ConfigKey BUCKET = ConfigKeys - .key("geaflow.metric.influxdb.bucket") - .noDefaultValue() - .description("influxdb bucket name"); - - public static final ConfigKey CONNECT_TIMEOUT_MS = ConfigKeys - .key("geaflow.metric.influxdb.connect.timeout.ms") - .defaultValue(30_000L) - .description("influxdb connect timeout millis"); - - public static final ConfigKey WRITE_TIMEOUT_MS = ConfigKeys - .key("geaflow.metric.influxdb.write.timeout.ms") - .defaultValue(30_000L) - .description("influxdb write timeout millis"); - + public static final ConfigKey URL = + ConfigKeys.key("geaflow.metric.influxdb.url") + .noDefaultValue() + .description("influxdb url, e.g. http://localhost:8086"); + + public static final ConfigKey TOKEN = + ConfigKeys.key("geaflow.metric.influxdb.token") + .noDefaultValue() + .description("influxdb token, for authorization"); + + public static final ConfigKey ORG = + ConfigKeys.key("geaflow.metric.influxdb.org") + .noDefaultValue() + .description("influxdb organization of the bucket"); + + public static final ConfigKey BUCKET = + ConfigKeys.key("geaflow.metric.influxdb.bucket") + .noDefaultValue() + .description("influxdb bucket name"); + + public static final ConfigKey CONNECT_TIMEOUT_MS = + ConfigKeys.key("geaflow.metric.influxdb.connect.timeout.ms") + .defaultValue(30_000L) + .description("influxdb connect timeout millis"); + + public static final ConfigKey WRITE_TIMEOUT_MS = + ConfigKeys.key("geaflow.metric.influxdb.write.timeout.ms") + .defaultValue(30_000L) + .description("influxdb write timeout millis"); } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-influxdb/src/main/java/org/apache/geaflow/metrics/influxdb/InfluxdbReporter.java b/geaflow/geaflow-metrics/geaflow-metrics-influxdb/src/main/java/org/apache/geaflow/metrics/influxdb/InfluxdbReporter.java index 2ebe87de8..54c36f3d3 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-influxdb/src/main/java/org/apache/geaflow/metrics/influxdb/InfluxdbReporter.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-influxdb/src/main/java/org/apache/geaflow/metrics/influxdb/InfluxdbReporter.java @@ -19,22 +19,12 @@ package org.apache.geaflow.metrics.influxdb; -import com.codahale.metrics.Counter; -import com.codahale.metrics.Gauge; -import com.codahale.metrics.Histogram; -import com.codahale.metrics.Meter; -import com.codahale.metrics.MetricRegistry; -import com.google.common.annotations.VisibleForTesting; -import com.influxdb.client.InfluxDBClientOptions; -import com.influxdb.client.domain.WritePrecision; -import com.influxdb.client.internal.InfluxDBClientImpl; -import com.influxdb.client.write.Point; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; -import okhttp3.OkHttpClient; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.metrics.common.api.CounterImpl; import org.apache.geaflow.metrics.common.api.MeterImpl; @@ -43,93 +33,107 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.codahale.metrics.Counter; +import com.codahale.metrics.Gauge; +import com.codahale.metrics.Histogram; +import com.codahale.metrics.Meter; +import com.codahale.metrics.MetricRegistry; +import com.google.common.annotations.VisibleForTesting; +import com.influxdb.client.InfluxDBClientOptions; +import com.influxdb.client.domain.WritePrecision; +import com.influxdb.client.internal.InfluxDBClientImpl; +import com.influxdb.client.write.Point; + +import okhttp3.OkHttpClient; + public class InfluxdbReporter extends AbstractReporter implements ScheduledReporter { - private static final Logger LOGGER = LoggerFactory.getLogger(InfluxdbReporter.class); + private static final Logger LOGGER = LoggerFactory.getLogger(InfluxdbReporter.class); - private static final Map EMPTY_TAGS = Collections.emptyMap(); - private static final String TYPE_INFLUXDB = "influxdb"; - private static final String FIELD = "value"; - private InfluxDBClientImpl influxDB; + private static final Map EMPTY_TAGS = Collections.emptyMap(); + private static final String TYPE_INFLUXDB = "influxdb"; + private static final String FIELD = "value"; + private InfluxDBClientImpl influxDB; - @Override - public void open(Configuration config, MetricRegistry metricRegistry) { - super.open(config, metricRegistry); - InfluxdbConfig influxdbConfig = new InfluxdbConfig(config); - OkHttpClient.Builder httpClient = new OkHttpClient.Builder() + @Override + public void open(Configuration config, MetricRegistry metricRegistry) { + super.open(config, metricRegistry); + InfluxdbConfig influxdbConfig = new InfluxdbConfig(config); + OkHttpClient.Builder httpClient = + new OkHttpClient.Builder() .connectTimeout(influxdbConfig.getConnectTimeoutMs(), TimeUnit.MILLISECONDS) .writeTimeout(influxdbConfig.getWriteTimeoutMs(), TimeUnit.MILLISECONDS); - InfluxDBClientOptions options = InfluxDBClientOptions.builder() + InfluxDBClientOptions options = + InfluxDBClientOptions.builder() .okHttpClient(httpClient) .url(influxdbConfig.getUrl()) .org(influxdbConfig.getOrg()) .bucket(influxdbConfig.getBucket()) .authenticateToken(influxdbConfig.getToken().toCharArray()) .build(); - this.influxDB = new InfluxDBClientImpl(options); - this.addMetricRegisterListener(config); - LOGGER.info("load influxdb config: {}", influxdbConfig); - } + this.influxDB = new InfluxDBClientImpl(options); + this.addMetricRegisterListener(config); + LOGGER.info("load influxdb config: {}", influxdbConfig); + } - @Override - public void report() { - List points = new ArrayList<>(); - for (Map.Entry gauge : metricRegistry.getGauges().entrySet()) { - points.add(this.buildPoint(gauge.getKey(), gauge.getValue().getValue())); - } - for (Map.Entry meter : metricRegistry.getMeters().entrySet()) { - MeterImpl meterWrapper = (MeterImpl) meter.getValue(); - points.add(this.buildPoint(meter.getKey(), meterWrapper.getCountAndReset())); - } - for (Map.Entry counter : metricRegistry.getCounters().entrySet()) { - CounterImpl counterWrapper = (CounterImpl) counter.getValue(); - points.add(this.buildPoint(counter.getKey(), counterWrapper.getCountAndReset())); - } - for (Map.Entry histogram : metricRegistry.getHistograms().entrySet()) { - points.add(this.buildPoint(histogram.getKey(), histogram.getValue().getSnapshot().getMean())); - } - try { - this.writePoints(points); - } catch (Exception e) { - LOGGER.error("save metric to influxdb err", e); - } + @Override + public void report() { + List points = new ArrayList<>(); + for (Map.Entry gauge : metricRegistry.getGauges().entrySet()) { + points.add(this.buildPoint(gauge.getKey(), gauge.getValue().getValue())); } - - protected void writePoints(List points) { - this.influxDB.getWriteApiBlocking().writePoints(points); + for (Map.Entry meter : metricRegistry.getMeters().entrySet()) { + MeterImpl meterWrapper = (MeterImpl) meter.getValue(); + points.add(this.buildPoint(meter.getKey(), meterWrapper.getCountAndReset())); } - - private Point buildPoint(String name, Object value) { - if (value instanceof Number) { - return Point.measurement(name) - .addTags(this.globalTags) - .time(System.currentTimeMillis(), WritePrecision.MS) - .addField(FIELD, (Number) value); - } else { - return Point.measurement(name) - .addTags(this.globalTags) - .time(System.currentTimeMillis(), WritePrecision.MS) - .addField(FIELD, String.valueOf(value)); - } + for (Map.Entry counter : metricRegistry.getCounters().entrySet()) { + CounterImpl counterWrapper = (CounterImpl) counter.getValue(); + points.add(this.buildPoint(counter.getKey(), counterWrapper.getCountAndReset())); } - - @Override - public void close() { - super.close(); - if (this.influxDB != null) { - this.influxDB.close(); - LOGGER.info("close influxdb client"); - } + for (Map.Entry histogram : metricRegistry.getHistograms().entrySet()) { + points.add(this.buildPoint(histogram.getKey(), histogram.getValue().getSnapshot().getMean())); + } + try { + this.writePoints(points); + } catch (Exception e) { + LOGGER.error("save metric to influxdb err", e); } + } - @Override - public String getReporterType() { - return TYPE_INFLUXDB; + protected void writePoints(List points) { + this.influxDB.getWriteApiBlocking().writePoints(points); + } + + private Point buildPoint(String name, Object value) { + if (value instanceof Number) { + return Point.measurement(name) + .addTags(this.globalTags) + .time(System.currentTimeMillis(), WritePrecision.MS) + .addField(FIELD, (Number) value); + } else { + return Point.measurement(name) + .addTags(this.globalTags) + .time(System.currentTimeMillis(), WritePrecision.MS) + .addField(FIELD, String.valueOf(value)); } + } - @VisibleForTesting - protected InfluxDBClientImpl getInfluxDB() { - return this.influxDB; + @Override + public void close() { + super.close(); + if (this.influxDB != null) { + this.influxDB.close(); + LOGGER.info("close influxdb client"); } + } + + @Override + public String getReporterType() { + return TYPE_INFLUXDB; + } + @VisibleForTesting + protected InfluxDBClientImpl getInfluxDB() { + return this.influxDB; + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-influxdb/src/test/java/org/apache/geaflow/metrics/influxdb/InfluxdbReporterTest.java b/geaflow/geaflow-metrics/geaflow-metrics-influxdb/src/test/java/org/apache/geaflow/metrics/influxdb/InfluxdbReporterTest.java index d31a79500..6f40ca62c 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-influxdb/src/test/java/org/apache/geaflow/metrics/influxdb/InfluxdbReporterTest.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-influxdb/src/test/java/org/apache/geaflow/metrics/influxdb/InfluxdbReporterTest.java @@ -19,9 +19,8 @@ package org.apache.geaflow.metrics.influxdb; -import com.codahale.metrics.MetricRegistry; -import com.influxdb.client.write.Point; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.metrics.common.api.Counter; import org.apache.geaflow.metrics.common.api.Gauge; @@ -31,68 +30,64 @@ import org.testng.Assert; import org.testng.annotations.Test; -public class InfluxdbReporterTest { - - private static final String url = "http://localhost:8086"; - private static final String org = "geaflow"; - private static final String token = "test"; - private static final String bucket = "test-bucket"; - private static final Configuration config = new Configuration(); +import com.codahale.metrics.MetricRegistry; +import com.influxdb.client.write.Point; - static { - config.put(InfluxdbConfigKeys.URL, url); - config.put(InfluxdbConfigKeys.ORG, org); - config.put(InfluxdbConfigKeys.TOKEN, token); - config.put(InfluxdbConfigKeys.BUCKET, bucket); - } +public class InfluxdbReporterTest { - @Test - public void testInfluxdbConfig() { - InfluxdbConfig influxdbConfig = new InfluxdbConfig(config); - Assert.assertEquals(influxdbConfig.getUrl(), url); - Assert.assertEquals(influxdbConfig.getOrg(), org); - Assert.assertEquals(influxdbConfig.getToken(), token); - Assert.assertEquals(influxdbConfig.getBucket(), bucket); - Assert.assertEquals( - influxdbConfig.getConnectTimeoutMs(), - InfluxdbConfigKeys.CONNECT_TIMEOUT_MS.getDefaultValue() - ); - Assert.assertEquals( - influxdbConfig.getWriteTimeoutMs(), - InfluxdbConfigKeys.WRITE_TIMEOUT_MS.getDefaultValue() - ); - } + private static final String url = "http://localhost:8086"; + private static final String org = "geaflow"; + private static final String token = "test"; + private static final String bucket = "test-bucket"; + private static final Configuration config = new Configuration(); - @Test - public void testInfluxdbReporter() { - MetricRegistry metricRegistry = new MetricRegistry(); - MetricGroupImpl metricGroup = new MetricGroupImpl(metricRegistry); - Gauge gauge = metricGroup.gauge("test-gauge"); - gauge.setValue(1.0); - Counter counter = metricGroup.counter("test-counter"); - counter.inc(); - Meter meter = metricGroup.meter("test-meter"); - meter.mark(); - Histogram histogram = metricGroup.histogram("test-histogram"); - histogram.update(123); + static { + config.put(InfluxdbConfigKeys.URL, url); + config.put(InfluxdbConfigKeys.ORG, org); + config.put(InfluxdbConfigKeys.TOKEN, token); + config.put(InfluxdbConfigKeys.BUCKET, bucket); + } - InfluxdbReporter reporter = new MockInfluxdbReporter(); - reporter.open(config, metricRegistry); - reporter.report(); - reporter.close(); - Assert.assertNotNull(reporter.getInfluxDB()); - } + @Test + public void testInfluxdbConfig() { + InfluxdbConfig influxdbConfig = new InfluxdbConfig(config); + Assert.assertEquals(influxdbConfig.getUrl(), url); + Assert.assertEquals(influxdbConfig.getOrg(), org); + Assert.assertEquals(influxdbConfig.getToken(), token); + Assert.assertEquals(influxdbConfig.getBucket(), bucket); + Assert.assertEquals( + influxdbConfig.getConnectTimeoutMs(), + InfluxdbConfigKeys.CONNECT_TIMEOUT_MS.getDefaultValue()); + Assert.assertEquals( + influxdbConfig.getWriteTimeoutMs(), InfluxdbConfigKeys.WRITE_TIMEOUT_MS.getDefaultValue()); + } - private static class MockInfluxdbReporter extends InfluxdbReporter { + @Test + public void testInfluxdbReporter() { + MetricRegistry metricRegistry = new MetricRegistry(); + MetricGroupImpl metricGroup = new MetricGroupImpl(metricRegistry); + Gauge gauge = metricGroup.gauge("test-gauge"); + gauge.setValue(1.0); + Counter counter = metricGroup.counter("test-counter"); + counter.inc(); + Meter meter = metricGroup.meter("test-meter"); + meter.mark(); + Histogram histogram = metricGroup.histogram("test-histogram"); + histogram.update(123); - @Override - protected void writePoints(List points) { - } + InfluxdbReporter reporter = new MockInfluxdbReporter(); + reporter.open(config, metricRegistry); + reporter.report(); + reporter.close(); + Assert.assertNotNull(reporter.getInfluxDB()); + } - @Override - protected void addMetricRegisterListener(Configuration config) { - } + private static class MockInfluxdbReporter extends InfluxdbReporter { - } + @Override + protected void writePoints(List points) {} + @Override + protected void addMetricRegisterListener(Configuration config) {} + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-prometheus/src/main/java/org/apache/geaflow/metrics/prometheus/PrometheusConfigKeys.java b/geaflow/geaflow-metrics/geaflow-metrics-prometheus/src/main/java/org/apache/geaflow/metrics/prometheus/PrometheusConfigKeys.java index 2b2614dfa..a9c73164e 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-prometheus/src/main/java/org/apache/geaflow/metrics/prometheus/PrometheusConfigKeys.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-prometheus/src/main/java/org/apache/geaflow/metrics/prometheus/PrometheusConfigKeys.java @@ -24,24 +24,25 @@ public class PrometheusConfigKeys { - public static final ConfigKey GATEWAY_URL = ConfigKeys - .key("geaflow.metric.prometheus.gateway.url") - .noDefaultValue() - .description("prometheus gateway url"); + public static final ConfigKey GATEWAY_URL = + ConfigKeys.key("geaflow.metric.prometheus.gateway.url") + .noDefaultValue() + .description("prometheus gateway url"); - public static final ConfigKey USER = ConfigKeys - .key("geaflow.metric.prometheus.auth.user") - .noDefaultValue() - .description("prometheus auth user"); + public static final ConfigKey USER = + ConfigKeys.key("geaflow.metric.prometheus.auth.user") + .noDefaultValue() + .description("prometheus auth user"); - public static final ConfigKey PASSWORD = ConfigKeys - .key("geaflow.metric.prometheus.auth.password") - .noDefaultValue() - .description("prometheus auth password"); + public static final ConfigKey PASSWORD = + ConfigKeys.key("geaflow.metric.prometheus.auth.password") + .noDefaultValue() + .description("prometheus auth password"); - public static final ConfigKey JOB_NAME = ConfigKeys - .key("geaflow.metric.prometheus.job.name") - .defaultValue("DEFAULT") - .description("prometheus job name format. The origin metric name will replace the '%s' " - + "in the format."); + public static final ConfigKey JOB_NAME = + ConfigKeys.key("geaflow.metric.prometheus.job.name") + .defaultValue("DEFAULT") + .description( + "prometheus job name format. The origin metric name will replace the '%s' " + + "in the format."); } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-prometheus/src/main/java/org/apache/geaflow/metrics/prometheus/PrometheusReporter.java b/geaflow/geaflow-metrics/geaflow-metrics-prometheus/src/main/java/org/apache/geaflow/metrics/prometheus/PrometheusReporter.java index 0b64cd11a..a1d490ebc 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-prometheus/src/main/java/org/apache/geaflow/metrics/prometheus/PrometheusReporter.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-prometheus/src/main/java/org/apache/geaflow/metrics/prometheus/PrometheusReporter.java @@ -19,17 +19,11 @@ package org.apache.geaflow.metrics.prometheus; -import com.codahale.metrics.Counter; -import com.codahale.metrics.Gauge; -import com.codahale.metrics.Histogram; -import com.codahale.metrics.Meter; -import com.codahale.metrics.MetricRegistry; -import io.prometheus.client.exporter.BasicAuthHttpConnectionFactory; -import io.prometheus.client.exporter.PushGateway; import java.io.IOException; import java.net.URL; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.utils.ProcessUtil; @@ -40,112 +34,127 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class PrometheusReporter extends AbstractReporter implements ScheduledReporter { +import com.codahale.metrics.Counter; +import com.codahale.metrics.Gauge; +import com.codahale.metrics.Histogram; +import com.codahale.metrics.Meter; +import com.codahale.metrics.MetricRegistry; - private static final Logger LOGGER = LoggerFactory.getLogger(PrometheusReporter.class); +import io.prometheus.client.exporter.BasicAuthHttpConnectionFactory; +import io.prometheus.client.exporter.PushGateway; - private static final String TYPE_PROMETHEUS = "prometheus"; +public class PrometheusReporter extends AbstractReporter implements ScheduledReporter { - private static final String INSTANCE_KEY = "instance"; + private static final Logger LOGGER = LoggerFactory.getLogger(PrometheusReporter.class); - private static final String INVALID_METRIC_NAME_PATTERN = "[^a-zA-Z0-9_:]"; + private static final String TYPE_PROMETHEUS = "prometheus"; - private static final String METRIC_NAME_SEPARATOR = "_"; + private static final String INSTANCE_KEY = "instance"; - private Configuration configuration; + private static final String INVALID_METRIC_NAME_PATTERN = "[^a-zA-Z0-9_:]"; - private PushGateway pushGateway; + private static final String METRIC_NAME_SEPARATOR = "_"; - private String[] prometheusLabelNames; + private Configuration configuration; - private String[] prometheusLabelValues; + private PushGateway pushGateway; - private final Map gaugeMap = new HashMap<>(); + private String[] prometheusLabelNames; - @Override - public void open(Configuration config, MetricRegistry metricRegistry) { - this.configuration = config; - super.open(config, metricRegistry); - this.initPushGateway(config); - this.initLabels(); - this.addMetricRegisterListener(config); - } + private String[] prometheusLabelValues; - @Override - public void report() { - for (Map.Entry counter : metricRegistry.getCounters().entrySet()) { - CounterImpl counterWrapper = (CounterImpl) counter.getValue(); - doReport(counter.getKey(), counterWrapper.getCountAndReset()); - } - for (Map.Entry gaugeEntry : metricRegistry.getGauges().entrySet()) { - doReport(gaugeEntry.getKey(), gaugeEntry.getValue().getValue()); - } - for (Map.Entry meter : metricRegistry.getMeters().entrySet()) { - MeterImpl meterWrapper = (MeterImpl) meter.getValue(); - doReport(meter.getKey(), meterWrapper.getCountAndReset()); - } - for (Map.Entry histogram : metricRegistry.getHistograms().entrySet()) { - doReport(histogram.getKey(), histogram.getValue().getSnapshot().getMean()); - } - } + private final Map gaugeMap = new HashMap<>(); - private void doReport(String name, Object value) { - String metricName = getMetricName4Prometheus(name); - if (value instanceof Number) { - io.prometheus.client.Gauge gauge = getOrRegisterGauge(metricName); - String jobName = configuration.getString(PrometheusConfigKeys.JOB_NAME); - gauge.labels(prometheusLabelValues).set(((Number) value).doubleValue()); - try { - pushGateway.pushAdd(gauge, jobName); - } catch (IOException e) { - LOGGER.error("push metric {} to gateway failed. {}", metricName, e.getMessage(), e); - } - } - } + @Override + public void open(Configuration config, MetricRegistry metricRegistry) { + this.configuration = config; + super.open(config, metricRegistry); + this.initPushGateway(config); + this.initLabels(); + this.addMetricRegisterListener(config); + } - private io.prometheus.client.Gauge getOrRegisterGauge(String metricName) { - return gaugeMap.computeIfAbsent(metricName, - name -> io.prometheus.client.Gauge.build().name(metricName).help(metricName) - .labelNames(prometheusLabelNames).register()); + @Override + public void report() { + for (Map.Entry counter : metricRegistry.getCounters().entrySet()) { + CounterImpl counterWrapper = (CounterImpl) counter.getValue(); + doReport(counter.getKey(), counterWrapper.getCountAndReset()); } - - @Override - public void close() { - super.close(); + for (Map.Entry gaugeEntry : metricRegistry.getGauges().entrySet()) { + doReport(gaugeEntry.getKey(), gaugeEntry.getValue().getValue()); } - - @Override - public String getReporterType() { - return TYPE_PROMETHEUS; + for (Map.Entry meter : metricRegistry.getMeters().entrySet()) { + MeterImpl meterWrapper = (MeterImpl) meter.getValue(); + doReport(meter.getKey(), meterWrapper.getCountAndReset()); } - - private void initPushGateway(Configuration configuration) { - String gatewayUrl = configuration.getString(PrometheusConfigKeys.GATEWAY_URL); - try { - pushGateway = new PushGateway(new URL(gatewayUrl)); - LOGGER.info("Load prometheus push gateway: {}", gatewayUrl); - if (configuration.contains(PrometheusConfigKeys.USER) && configuration.contains( - PrometheusConfigKeys.PASSWORD)) { - pushGateway.setConnectionFactory(new BasicAuthHttpConnectionFactory( - configuration.getString(PrometheusConfigKeys.USER), - configuration.getString(PrometheusConfigKeys.PASSWORD))); - } - } catch (Exception e) { - LOGGER.error("Load push gateway of url {} error. {}", gatewayUrl, e.getMessage(), e); - throw new GeaflowRuntimeException(e); - } + for (Map.Entry histogram : metricRegistry.getHistograms().entrySet()) { + doReport(histogram.getKey(), histogram.getValue().getSnapshot().getMean()); } - - private void initLabels() { - String instanceName = ProcessUtil.getHostAndPid(); - Map prometheusLabels = new HashMap<>(globalTags); - prometheusLabels.put(INSTANCE_KEY, instanceName); - this.prometheusLabelNames = prometheusLabels.keySet().toArray(new String[0]); - this.prometheusLabelValues = prometheusLabels.values().toArray(new String[0]); + } + + private void doReport(String name, Object value) { + String metricName = getMetricName4Prometheus(name); + if (value instanceof Number) { + io.prometheus.client.Gauge gauge = getOrRegisterGauge(metricName); + String jobName = configuration.getString(PrometheusConfigKeys.JOB_NAME); + gauge.labels(prometheusLabelValues).set(((Number) value).doubleValue()); + try { + pushGateway.pushAdd(gauge, jobName); + } catch (IOException e) { + LOGGER.error("push metric {} to gateway failed. {}", metricName, e.getMessage(), e); + } } - - private String getMetricName4Prometheus(String metricName) { - metricName = metricName.replaceAll(INVALID_METRIC_NAME_PATTERN, METRIC_NAME_SEPARATOR); - return metricName; + } + + private io.prometheus.client.Gauge getOrRegisterGauge(String metricName) { + return gaugeMap.computeIfAbsent( + metricName, + name -> + io.prometheus.client.Gauge.build() + .name(metricName) + .help(metricName) + .labelNames(prometheusLabelNames) + .register()); + } + + @Override + public void close() { + super.close(); + } + + @Override + public String getReporterType() { + return TYPE_PROMETHEUS; + } + + private void initPushGateway(Configuration configuration) { + String gatewayUrl = configuration.getString(PrometheusConfigKeys.GATEWAY_URL); + try { + pushGateway = new PushGateway(new URL(gatewayUrl)); + LOGGER.info("Load prometheus push gateway: {}", gatewayUrl); + if (configuration.contains(PrometheusConfigKeys.USER) + && configuration.contains(PrometheusConfigKeys.PASSWORD)) { + pushGateway.setConnectionFactory( + new BasicAuthHttpConnectionFactory( + configuration.getString(PrometheusConfigKeys.USER), + configuration.getString(PrometheusConfigKeys.PASSWORD))); + } + } catch (Exception e) { + LOGGER.error("Load push gateway of url {} error. {}", gatewayUrl, e.getMessage(), e); + throw new GeaflowRuntimeException(e); } + } + + private void initLabels() { + String instanceName = ProcessUtil.getHostAndPid(); + Map prometheusLabels = new HashMap<>(globalTags); + prometheusLabels.put(INSTANCE_KEY, instanceName); + this.prometheusLabelNames = prometheusLabels.keySet().toArray(new String[0]); + this.prometheusLabelValues = prometheusLabels.values().toArray(new String[0]); + } + + private String getMetricName4Prometheus(String metricName) { + metricName = metricName.replaceAll(INVALID_METRIC_NAME_PATTERN, METRIC_NAME_SEPARATOR); + return metricName; + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-prometheus/src/test/java/org/apache/geaflow/metrics/prometheus/PrometheusReporterTest.java b/geaflow/geaflow-metrics/geaflow-metrics-prometheus/src/test/java/org/apache/geaflow/metrics/prometheus/PrometheusReporterTest.java index 7841345a0..826011b83 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-prometheus/src/test/java/org/apache/geaflow/metrics/prometheus/PrometheusReporterTest.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-prometheus/src/test/java/org/apache/geaflow/metrics/prometheus/PrometheusReporterTest.java @@ -24,7 +24,6 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.METRIC_META_REPORT_DELAY; import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.METRIC_META_REPORT_PERIOD; -import com.codahale.metrics.MetricRegistry; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.utils.SleepUtils; import org.apache.geaflow.metrics.common.api.Counter; @@ -35,34 +34,35 @@ import org.testng.Assert; import org.testng.annotations.Test; -public class PrometheusReporterTest { +import com.codahale.metrics.MetricRegistry; - @Test - public void test() { - MetricRegistry metricRegistry = new MetricRegistry(); - MetricGroup metricGroup = new MetricGroupImpl(metricRegistry); - Counter counter = metricGroup.counter("system/iteration.MetricNameFormatterTest[1:0](ms)"); - counter.inc(); +public class PrometheusReporterTest { - Gauge gauge = metricGroup.gauge("gaugeTest"); - gauge.setValue(10); + @Test + public void test() { + MetricRegistry metricRegistry = new MetricRegistry(); + MetricGroup metricGroup = new MetricGroupImpl(metricRegistry); + Counter counter = metricGroup.counter("system/iteration.MetricNameFormatterTest[1:0](ms)"); + counter.inc(); - Histogram histogram = metricGroup.histogram("histTest"); - histogram.update(1); + Gauge gauge = metricGroup.gauge("gaugeTest"); + gauge.setValue(10); - Configuration config = new Configuration(); - config.put(JOB_APP_NAME, "geaflow123"); - config.put(GEAFLOW_GW_ENDPOINT, "http://localhost:8888/"); - config.put(METRIC_META_REPORT_DELAY, "0"); - config.put(METRIC_META_REPORT_PERIOD, "1"); - config.put(PrometheusConfigKeys.GATEWAY_URL.getKey(), "http://localhost:9091"); - PrometheusReporter metricReporter = new PrometheusReporter(); - metricReporter.open(config, metricRegistry); - metricReporter.report(); - SleepUtils.sleepSecond(3); + Histogram histogram = metricGroup.histogram("histTest"); + histogram.update(1); - metricReporter.close(); - Assert.assertTrue(true); - } + Configuration config = new Configuration(); + config.put(JOB_APP_NAME, "geaflow123"); + config.put(GEAFLOW_GW_ENDPOINT, "http://localhost:8888/"); + config.put(METRIC_META_REPORT_DELAY, "0"); + config.put(METRIC_META_REPORT_PERIOD, "1"); + config.put(PrometheusConfigKeys.GATEWAY_URL.getKey(), "http://localhost:9091"); + PrometheusReporter metricReporter = new PrometheusReporter(); + metricReporter.open(config, metricRegistry); + metricReporter.report(); + SleepUtils.sleepSecond(3); + metricReporter.close(); + Assert.assertTrue(true); + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-reporter/src/main/java/org/apache/geaflow/metrics/reporter/AbstractReporter.java b/geaflow/geaflow-metrics/geaflow-metrics-reporter/src/main/java/org/apache/geaflow/metrics/reporter/AbstractReporter.java index 086e28467..042d2577a 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-reporter/src/main/java/org/apache/geaflow/metrics/reporter/AbstractReporter.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-reporter/src/main/java/org/apache/geaflow/metrics/reporter/AbstractReporter.java @@ -19,18 +19,9 @@ package org.apache.geaflow.metrics.reporter; -import com.alibaba.fastjson.JSON; -import com.alibaba.fastjson.JSONArray; -import com.alibaba.fastjson.JSONObject; -import com.alibaba.fastjson.serializer.SerializerFeature; -import com.codahale.metrics.Counter; -import com.codahale.metrics.Gauge; -import com.codahale.metrics.Histogram; -import com.codahale.metrics.Meter; -import com.codahale.metrics.MetricRegistry; -import com.codahale.metrics.MetricRegistryListener; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.common.utils.ProcessUtil; @@ -42,112 +33,124 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONArray; +import com.alibaba.fastjson.JSONObject; +import com.alibaba.fastjson.serializer.SerializerFeature; +import com.codahale.metrics.Counter; +import com.codahale.metrics.Gauge; +import com.codahale.metrics.Histogram; +import com.codahale.metrics.Meter; +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.MetricRegistryListener; + public abstract class AbstractReporter implements MetricReporter { - private static final Logger LOGGER = LoggerFactory.getLogger(AbstractReporter.class); - - public static final String TAG_JOB_NAME = "jobName"; - public static final String TAG_WORKER = "worker"; - public static final String TAG_ENGINE = "engine"; - - public static final String KEY_TAGS = "tags"; - public static final String KEY_METRIC = "metric"; - public static final String KEY_AGGREGATOR = "aggregator"; - public static final String KEY_DOWN_SAMPLE = "downsample"; - public static final String KEY_GEAFLOW = "Geaflow"; - - protected MetricRegistry metricRegistry; - private MetricMetaClient metricMetaClient; - protected Map globalTags; - protected String jobName; - - public void open(Configuration config, MetricRegistry metricRegistry) { - this.metricRegistry = metricRegistry; - this.jobName = config.getString(ExecutionConfigKeys.JOB_APP_NAME); - this.globalTags = new HashMap<>(); - this.globalTags.put(TAG_JOB_NAME, jobName); - this.globalTags.put(TAG_WORKER, ProcessUtil.getHostAndPid()); - this.globalTags.put(TAG_ENGINE, KEY_GEAFLOW); + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractReporter.class); + + public static final String TAG_JOB_NAME = "jobName"; + public static final String TAG_WORKER = "worker"; + public static final String TAG_ENGINE = "engine"; + + public static final String KEY_TAGS = "tags"; + public static final String KEY_METRIC = "metric"; + public static final String KEY_AGGREGATOR = "aggregator"; + public static final String KEY_DOWN_SAMPLE = "downsample"; + public static final String KEY_GEAFLOW = "Geaflow"; + + protected MetricRegistry metricRegistry; + private MetricMetaClient metricMetaClient; + protected Map globalTags; + protected String jobName; + + public void open(Configuration config, MetricRegistry metricRegistry) { + this.metricRegistry = metricRegistry; + this.jobName = config.getString(ExecutionConfigKeys.JOB_APP_NAME); + this.globalTags = new HashMap<>(); + this.globalTags.put(TAG_JOB_NAME, jobName); + this.globalTags.put(TAG_WORKER, ProcessUtil.getHostAndPid()); + this.globalTags.put(TAG_ENGINE, KEY_GEAFLOW); + } + + protected void addMetricRegisterListener(Configuration config) { + this.metricMetaClient = MetricMetaClient.getInstance(config); + this.metricRegistry.addListener(new MetricRegisterListener(this.metricMetaClient)); + LOGGER.info("add metric register listener"); + } + + @Override + public void close() { + if (this.metricMetaClient != null) { + this.metricMetaClient.close(); + LOGGER.info("close metric meta client"); + } + } + + protected class MetricRegisterListener extends MetricRegistryListener.Base { + + private final MetricMetaClient metricMetaClient; + + public MetricRegisterListener(MetricMetaClient metricMetaClient) { + this.metricMetaClient = metricMetaClient; + } + + @Override + public void onGaugeAdded(String metricName, Gauge gauge) { + String query = wrapQuery(metricName, DownSample.AVG, AggType.AVG); + metricMetaClient.registerMetricMeta(metricName, MetricType.GAUGE, query); } - protected void addMetricRegisterListener(Configuration config) { - this.metricMetaClient = MetricMetaClient.getInstance(config); - this.metricRegistry.addListener(new MetricRegisterListener(this.metricMetaClient)); - LOGGER.info("add metric register listener"); + @Override + public void onCounterAdded(String metricName, Counter counter) { + String query = wrapQuery(metricName, DownSample.SUM, AggType.SUM); + metricMetaClient.registerMetricMeta(metricName, MetricType.COUNTER, query); } @Override - public void close() { - if (this.metricMetaClient != null) { - this.metricMetaClient.close(); - LOGGER.info("close metric meta client"); - } + public void onMeterAdded(String metricName, Meter meter) { + String query = wrapQuery(metricName, DownSample.SUM, AggType.SUM); + metricMetaClient.registerMetricMeta(metricName, MetricType.METER, query); } - protected class MetricRegisterListener extends MetricRegistryListener.Base { - - private final MetricMetaClient metricMetaClient; - - public MetricRegisterListener(MetricMetaClient metricMetaClient) { - this.metricMetaClient = metricMetaClient; - } - - @Override - public void onGaugeAdded(String metricName, Gauge gauge) { - String query = wrapQuery(metricName, DownSample.AVG, AggType.AVG); - metricMetaClient.registerMetricMeta(metricName, MetricType.GAUGE, query); - } - - @Override - public void onCounterAdded(String metricName, Counter counter) { - String query = wrapQuery(metricName, DownSample.SUM, AggType.SUM); - metricMetaClient.registerMetricMeta(metricName, MetricType.COUNTER, query); - } - - @Override - public void onMeterAdded(String metricName, Meter meter) { - String query = wrapQuery(metricName, DownSample.SUM, AggType.SUM); - metricMetaClient.registerMetricMeta(metricName, MetricType.METER, query); - } - - @Override - public void onHistogramAdded(String metricName, Histogram histogram) { - JSONObject queryTags = buildQueryTags(); - - HistAggType aggType = HistAggType.DEFAULT; - JSONArray histogramQueries = new JSONArray(); - for (String aggregator : aggType.getAggTypes()) { - JSONObject query = new JSONObject(); - query.put(KEY_TAGS, queryTags); - query.put(KEY_METRIC, metricName); - query.put(KEY_AGGREGATOR, aggregator); - query.put(KEY_DOWN_SAMPLE, DownSample.AVG.getValue()); - histogramQueries.add(query); - } - - metricMetaClient.registerMetricMeta(metricName, MetricType.HISTOGRAM, - JSON.toJSONString(histogramQueries, SerializerFeature.DisableCircularReferenceDetect)); - } - - private String wrapQuery(String metricName, DownSample downSample, AggType aggregator) { - JSONObject query = new JSONObject(); - query.put(KEY_METRIC, metricName); - query.put(KEY_AGGREGATOR, aggregator.getValue()); - query.put(KEY_DOWN_SAMPLE, downSample.getValue()); - - JSONObject tags = buildQueryTags(); - query.put(KEY_TAGS, tags); - - JSONArray meterQueries = new JSONArray(); - meterQueries.add(query); - return meterQueries.toJSONString(); - } + @Override + public void onHistogramAdded(String metricName, Histogram histogram) { + JSONObject queryTags = buildQueryTags(); + + HistAggType aggType = HistAggType.DEFAULT; + JSONArray histogramQueries = new JSONArray(); + for (String aggregator : aggType.getAggTypes()) { + JSONObject query = new JSONObject(); + query.put(KEY_TAGS, queryTags); + query.put(KEY_METRIC, metricName); + query.put(KEY_AGGREGATOR, aggregator); + query.put(KEY_DOWN_SAMPLE, DownSample.AVG.getValue()); + histogramQueries.add(query); + } + + metricMetaClient.registerMetricMeta( + metricName, + MetricType.HISTOGRAM, + JSON.toJSONString(histogramQueries, SerializerFeature.DisableCircularReferenceDetect)); } - private JSONObject buildQueryTags() { - JSONObject tags = new JSONObject(); - tags.put(TAG_JOB_NAME, this.jobName); - return tags; + private String wrapQuery(String metricName, DownSample downSample, AggType aggregator) { + JSONObject query = new JSONObject(); + query.put(KEY_METRIC, metricName); + query.put(KEY_AGGREGATOR, aggregator.getValue()); + query.put(KEY_DOWN_SAMPLE, downSample.getValue()); + + JSONObject tags = buildQueryTags(); + query.put(KEY_TAGS, tags); + + JSONArray meterQueries = new JSONArray(); + meterQueries.add(query); + return meterQueries.toJSONString(); } + } + private JSONObject buildQueryTags() { + JSONObject tags = new JSONObject(); + tags.put(TAG_JOB_NAME, this.jobName); + return tags; + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-reporter/src/main/java/org/apache/geaflow/metrics/reporter/MetricMetaClient.java b/geaflow/geaflow-metrics/geaflow-metrics-reporter/src/main/java/org/apache/geaflow/metrics/reporter/MetricMetaClient.java index fec17f4ef..427fff85a 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-reporter/src/main/java/org/apache/geaflow/metrics/reporter/MetricMetaClient.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-reporter/src/main/java/org/apache/geaflow/metrics/reporter/MetricMetaClient.java @@ -26,6 +26,7 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.common.utils.SleepUtils; @@ -39,82 +40,89 @@ public class MetricMetaClient { - private static final Logger LOGGER = LoggerFactory.getLogger(MetricMetaClient.class); - private static MetricMetaClient reporterClient; + private static final Logger LOGGER = LoggerFactory.getLogger(MetricMetaClient.class); + private static MetricMetaClient reporterClient; - private int failNum = 0; - private final int maxRetries; - private final String jobName; - private final List metricList = new ArrayList<>(); - private final BlockingQueue metricMetaQueue = new LinkedBlockingQueue<>(); - private final MetricConfig metricConfig; - private final ScheduledExecutorService scheduledService; + private int failNum = 0; + private final int maxRetries; + private final String jobName; + private final List metricList = new ArrayList<>(); + private final BlockingQueue metricMetaQueue = new LinkedBlockingQueue<>(); + private final MetricConfig metricConfig; + private final ScheduledExecutorService scheduledService; - private MetricMetaClient(Configuration config) { - this.jobName = config.getString(ExecutionConfigKeys.JOB_APP_NAME); + private MetricMetaClient(Configuration config) { + this.jobName = config.getString(ExecutionConfigKeys.JOB_APP_NAME); - this.metricConfig = new MetricConfig(config); - this.maxRetries = metricConfig.getReportMaxRetries(); - this.scheduledService = new ScheduledThreadPoolExecutor(1, - ThreadUtil.namedThreadFactory(true, "async-metric-meta")); - scheduledService - .scheduleAtFixedRate(new RegisterTask(), metricConfig.getRandomDelaySec(), - metricConfig.getRandomPeriodSec(), TimeUnit.SECONDS); - } + this.metricConfig = new MetricConfig(config); + this.maxRetries = metricConfig.getReportMaxRetries(); + this.scheduledService = + new ScheduledThreadPoolExecutor( + 1, ThreadUtil.namedThreadFactory(true, "async-metric-meta")); + scheduledService.scheduleAtFixedRate( + new RegisterTask(), + metricConfig.getRandomDelaySec(), + metricConfig.getRandomPeriodSec(), + TimeUnit.SECONDS); + } - public static synchronized MetricMetaClient getInstance(Configuration config) { - if (reporterClient == null) { - reporterClient = new MetricMetaClient(config); - } - return reporterClient; + public static synchronized MetricMetaClient getInstance(Configuration config) { + if (reporterClient == null) { + reporterClient = new MetricMetaClient(config); } + return reporterClient; + } - public void registerMetricMeta(String metricName, MetricType metricType, String queries) { - MetricMeta metricMeta = new MetricMeta(); - metricMeta.setJobName(jobName); - metricMeta.setMetricName(metricName); - metricMeta.setQueries(queries); - metricMeta.setMetricType(metricType.name()); - metricMetaQueue.add(metricMeta); - } + public void registerMetricMeta(String metricName, MetricType metricType, String queries) { + MetricMeta metricMeta = new MetricMeta(); + metricMeta.setJobName(jobName); + metricMeta.setMetricName(metricName); + metricMeta.setQueries(queries); + metricMeta.setMetricType(metricType.name()); + metricMetaQueue.add(metricMeta); + } - public void close() { - if (scheduledService != null) { - scheduledService.shutdown(); - } + public void close() { + if (scheduledService != null) { + scheduledService.shutdown(); } + } - private class RegisterTask implements Runnable { + private class RegisterTask implements Runnable { - @Override - public void run() { - try { - if (metricMetaQueue.size() > 0 || metricList.size() > 0) { - if (metricList.size() == 0) { - metricMetaQueue.drainTo(metricList); - } - SleepUtils.sleepSecond(metricConfig.getRandomPeriodSec()); - for (MetricMeta metricMeta : metricList) { - StatsCollectorFactory.getInstance().getMetricMetaCollector().reportMetricMeta(metricMeta); - LOGGER.info("register {} with query: {}", metricMeta.getMetricName(), - metricMeta.getQueries()); - } - metricList.clear(); - failNum = 0; - } - } catch (RuntimeException ex) { - failNum++; - if (failNum < maxRetries) { - LOGGER.warn("register fail #{}, and retry in next round", failNum, ex); - } else { - LOGGER.warn("#{} retry exceeds {} times, discard {} metrics meta", failNum, - maxRetries, metricList.size()); - metricList.clear(); - failNum = 0; - } - throw ex; - } + @Override + public void run() { + try { + if (metricMetaQueue.size() > 0 || metricList.size() > 0) { + if (metricList.size() == 0) { + metricMetaQueue.drainTo(metricList); + } + SleepUtils.sleepSecond(metricConfig.getRandomPeriodSec()); + for (MetricMeta metricMeta : metricList) { + StatsCollectorFactory.getInstance() + .getMetricMetaCollector() + .reportMetricMeta(metricMeta); + LOGGER.info( + "register {} with query: {}", metricMeta.getMetricName(), metricMeta.getQueries()); + } + metricList.clear(); + failNum = 0; + } + } catch (RuntimeException ex) { + failNum++; + if (failNum < maxRetries) { + LOGGER.warn("register fail #{}, and retry in next round", failNum, ex); + } else { + LOGGER.warn( + "#{} retry exceeds {} times, discard {} metrics meta", + failNum, + maxRetries, + metricList.size()); + metricList.clear(); + failNum = 0; } + throw ex; + } } - + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-reporter/src/test/java/org/apache/geaflow/metrics/reporter/MetricGroupRegistryTest.java b/geaflow/geaflow-metrics/geaflow-metrics-reporter/src/test/java/org/apache/geaflow/metrics/reporter/MetricGroupRegistryTest.java index 14137ec73..bdab71547 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-reporter/src/test/java/org/apache/geaflow/metrics/reporter/MetricGroupRegistryTest.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-reporter/src/test/java/org/apache/geaflow/metrics/reporter/MetricGroupRegistryTest.java @@ -30,37 +30,38 @@ public class MetricGroupRegistryTest { - @Test - public void test() { - Configuration config = new Configuration(); - config.put(REPORTER_LIST, "mock"); - MetricGroupRegistry metricGroupRegistry = MetricGroupRegistry.getInstance(config); + @Test + public void test() { + Configuration config = new Configuration(); + config.put(REPORTER_LIST, "mock"); + MetricGroupRegistry metricGroupRegistry = MetricGroupRegistry.getInstance(config); - MetricGroup metricGroup = metricGroupRegistry.getMetricGroup(); - metricGroup.register("timestamp", new Gauge() { - @Override - public Object getValue() { - return System.currentTimeMillis(); - } + MetricGroup metricGroup = metricGroupRegistry.getMetricGroup(); + metricGroup.register( + "timestamp", + new Gauge() { + @Override + public Object getValue() { + return System.currentTimeMillis(); + } - @Override - public void setValue(Object value) { - } + @Override + public void setValue(Object value) {} }); - Assert.assertNotNull(metricGroup.gauge("timestamp")); + Assert.assertNotNull(metricGroup.gauge("timestamp")); - MetricGroup namedMetricGroup = metricGroupRegistry.getMetricGroup("group"); - namedMetricGroup.register("timestamp", new Gauge() { - @Override - public Object getValue() { - return System.currentTimeMillis(); - } + MetricGroup namedMetricGroup = metricGroupRegistry.getMetricGroup("group"); + namedMetricGroup.register( + "timestamp", + new Gauge() { + @Override + public Object getValue() { + return System.currentTimeMillis(); + } - @Override - public void setValue(Object value) { - } + @Override + public void setValue(Object value) {} }); - Assert.assertNotNull(namedMetricGroup.gauge("group.timestamp")); - } - + Assert.assertNotNull(namedMetricGroup.gauge("group.timestamp")); + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-reporter/src/test/java/org/apache/geaflow/metrics/reporter/MockMetricReporter.java b/geaflow/geaflow-metrics/geaflow-metrics-reporter/src/test/java/org/apache/geaflow/metrics/reporter/MockMetricReporter.java index a78aa9400..d5c416207 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-reporter/src/test/java/org/apache/geaflow/metrics/reporter/MockMetricReporter.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-reporter/src/test/java/org/apache/geaflow/metrics/reporter/MockMetricReporter.java @@ -23,13 +23,11 @@ public class MockMetricReporter extends AbstractReporter implements ScheduledReporter { - @Override - public String getReporterType() { - return "mock"; - } + @Override + public String getReporterType() { + return "mock"; + } - @Override - public void report() { - - } + @Override + public void report() {} } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-slf4j/src/main/java/org/apache/geaflow/metrics/slf4j/Slf4jReporter.java b/geaflow/geaflow-metrics/geaflow-metrics-slf4j/src/main/java/org/apache/geaflow/metrics/slf4j/Slf4jReporter.java index fb6aea520..87b867ae6 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-slf4j/src/main/java/org/apache/geaflow/metrics/slf4j/Slf4jReporter.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-slf4j/src/main/java/org/apache/geaflow/metrics/slf4j/Slf4jReporter.java @@ -19,108 +19,127 @@ package org.apache.geaflow.metrics.slf4j; -import com.codahale.metrics.Counter; -import com.codahale.metrics.Gauge; -import com.codahale.metrics.Histogram; -import com.codahale.metrics.Meter; -import com.codahale.metrics.MetricRegistry; -import com.codahale.metrics.Snapshot; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.metrics.common.reporter.ScheduledReporter; import org.apache.geaflow.metrics.reporter.AbstractReporter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.codahale.metrics.Counter; +import com.codahale.metrics.Gauge; +import com.codahale.metrics.Histogram; +import com.codahale.metrics.Meter; +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.Snapshot; + public class Slf4jReporter extends AbstractReporter implements ScheduledReporter { - private static final Logger LOGGER = LoggerFactory.getLogger(Slf4jReporter.class); + private static final Logger LOGGER = LoggerFactory.getLogger(Slf4jReporter.class); - private static final String TYPE_SLF4J = "slf4j"; - private static final String LINE_SEPARATOR = System.lineSeparator(); + private static final String TYPE_SLF4J = "slf4j"; + private static final String LINE_SEPARATOR = System.lineSeparator(); - @Override - public void open(Configuration jobConfig, MetricRegistry metricRegistry) { - super.open(jobConfig, metricRegistry); + @Override + public void open(Configuration jobConfig, MetricRegistry metricRegistry) { + super.open(jobConfig, metricRegistry); + } + + @Override + public void report() { + StringBuilder builder = new StringBuilder(); + builder + .append(LINE_SEPARATOR) + .append("=========================== Starting metrics report ===========================") + .append(LINE_SEPARATOR); + + builder + .append(LINE_SEPARATOR) + .append("-- Counters -------------------------------------------------------------------") + .append(LINE_SEPARATOR); + for (Map.Entry metric : metricRegistry.getCounters().entrySet()) { + builder + .append(metric.getKey()) + .append(": ") + .append(metric.getValue().getCount()) + .append(LINE_SEPARATOR); } - @Override - public void report() { - StringBuilder builder = new StringBuilder(); - builder - .append(LINE_SEPARATOR) - .append("=========================== Starting metrics report ===========================") - .append(LINE_SEPARATOR); - - builder - .append(LINE_SEPARATOR) - .append("-- Counters -------------------------------------------------------------------") - .append(LINE_SEPARATOR); - for (Map.Entry metric : metricRegistry.getCounters().entrySet()) { - builder - .append(metric.getKey()).append(": ").append(metric.getValue().getCount()) - .append(LINE_SEPARATOR); - } - - builder - .append(LINE_SEPARATOR) - .append("-- Gauges ---------------------------------------------------------------------") - .append(LINE_SEPARATOR); - for (Map.Entry metric : metricRegistry.getGauges().entrySet()) { - builder - .append(metric.getKey()).append(": ").append(metric.getValue().getValue()) - .append(LINE_SEPARATOR); - } - - builder - .append(LINE_SEPARATOR) - .append("-- Meters ---------------------------------------------------------------------") - .append(LINE_SEPARATOR); - for (Map.Entry metric : metricRegistry.getMeters().entrySet()) { - Meter meter = metric.getValue(); - builder - .append(metric.getKey()).append(": ") - .append(meter.getMeanRate()) - .append(", 1mRate=").append(meter.getOneMinuteRate()) - .append(", 5mRate=").append(meter.getFiveMinuteRate()) - .append(", 15mRate=").append(meter.getFifteenMinuteRate()) - .append(LINE_SEPARATOR); - } - - builder - .append(LINE_SEPARATOR) - .append("-- Histograms -----------------------------------------------------------------") - .append(LINE_SEPARATOR); - for (Map.Entry metric : metricRegistry.getHistograms().entrySet()) { - Snapshot stats = metric.getValue().getSnapshot(); - builder - .append(metric.getValue()).append(": count=").append(stats.size()) - .append(", min=").append(stats.getMin()) - .append(", max=").append(stats.getMax()) - .append(", mean=").append(stats.getMean()) - .append(", stddev=").append(stats.getStdDev()) - .append(", p75=").append(stats.get75thPercentile()) - .append(", p95=").append(stats.get95thPercentile()) - .append(", p98=").append(stats.get98thPercentile()) - .append(", p99=").append(stats.get99thPercentile()) - .append(", p999=").append(stats.get999thPercentile()) - .append(LINE_SEPARATOR); - } - - builder - .append(LINE_SEPARATOR) - .append("=========================== Finished metrics report ===========================") - .append(LINE_SEPARATOR); - LOGGER.info(builder.toString()); + builder + .append(LINE_SEPARATOR) + .append("-- Gauges ---------------------------------------------------------------------") + .append(LINE_SEPARATOR); + for (Map.Entry metric : metricRegistry.getGauges().entrySet()) { + builder + .append(metric.getKey()) + .append(": ") + .append(metric.getValue().getValue()) + .append(LINE_SEPARATOR); } - @Override - public void close() { + builder + .append(LINE_SEPARATOR) + .append("-- Meters ---------------------------------------------------------------------") + .append(LINE_SEPARATOR); + for (Map.Entry metric : metricRegistry.getMeters().entrySet()) { + Meter meter = metric.getValue(); + builder + .append(metric.getKey()) + .append(": ") + .append(meter.getMeanRate()) + .append(", 1mRate=") + .append(meter.getOneMinuteRate()) + .append(", 5mRate=") + .append(meter.getFiveMinuteRate()) + .append(", 15mRate=") + .append(meter.getFifteenMinuteRate()) + .append(LINE_SEPARATOR); } - @Override - public String getReporterType() { - return TYPE_SLF4J; + builder + .append(LINE_SEPARATOR) + .append("-- Histograms -----------------------------------------------------------------") + .append(LINE_SEPARATOR); + for (Map.Entry metric : metricRegistry.getHistograms().entrySet()) { + Snapshot stats = metric.getValue().getSnapshot(); + builder + .append(metric.getValue()) + .append(": count=") + .append(stats.size()) + .append(", min=") + .append(stats.getMin()) + .append(", max=") + .append(stats.getMax()) + .append(", mean=") + .append(stats.getMean()) + .append(", stddev=") + .append(stats.getStdDev()) + .append(", p75=") + .append(stats.get75thPercentile()) + .append(", p95=") + .append(stats.get95thPercentile()) + .append(", p98=") + .append(stats.get98thPercentile()) + .append(", p99=") + .append(stats.get99thPercentile()) + .append(", p999=") + .append(stats.get999thPercentile()) + .append(LINE_SEPARATOR); } + builder + .append(LINE_SEPARATOR) + .append("=========================== Finished metrics report ===========================") + .append(LINE_SEPARATOR); + LOGGER.info(builder.toString()); + } + + @Override + public void close() {} + + @Override + public String getReporterType() { + return TYPE_SLF4J; + } } diff --git a/geaflow/geaflow-metrics/geaflow-metrics-slf4j/src/test/java/org/apache/geaflow/metrics/slf4j/Slf4jReporterTest.java b/geaflow/geaflow-metrics/geaflow-metrics-slf4j/src/test/java/org/apache/geaflow/metrics/slf4j/Slf4jReporterTest.java index e4215a03c..c5c731c71 100644 --- a/geaflow/geaflow-metrics/geaflow-metrics-slf4j/src/test/java/org/apache/geaflow/metrics/slf4j/Slf4jReporterTest.java +++ b/geaflow/geaflow-metrics/geaflow-metrics-slf4j/src/test/java/org/apache/geaflow/metrics/slf4j/Slf4jReporterTest.java @@ -22,6 +22,7 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.REPORTER_LIST; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.metrics.common.MetricGroupRegistry; import org.apache.geaflow.metrics.common.api.Gauge; @@ -36,36 +37,36 @@ public class Slf4jReporterTest { - @Test - public void test() { - Configuration config = new Configuration(); - config.put(REPORTER_LIST, ReporterRegistry.SLF4J_REPORTER); - MetricGroupRegistry metricGroupRegistry = MetricGroupRegistry.getInstance(config); + @Test + public void test() { + Configuration config = new Configuration(); + config.put(REPORTER_LIST, ReporterRegistry.SLF4J_REPORTER); + MetricGroupRegistry metricGroupRegistry = MetricGroupRegistry.getInstance(config); - MetricGroup metricGroup = metricGroupRegistry.getMetricGroup(); - metricGroup.register("timestamp", new Gauge() { - @Override - public Object getValue() { - return System.currentTimeMillis(); - } + MetricGroup metricGroup = metricGroupRegistry.getMetricGroup(); + metricGroup.register( + "timestamp", + new Gauge() { + @Override + public Object getValue() { + return System.currentTimeMillis(); + } - @Override - public void setValue(Object value) { - } + @Override + public void setValue(Object value) {} }); - Assert.assertNotNull(metricGroup.gauge("timestamp")); + Assert.assertNotNull(metricGroup.gauge("timestamp")); - Histogram histogram = metricGroup.histogram("histTest"); - histogram.update(1); - Meter meter = metricGroup.meter("meterTest"); - meter.mark(); + Histogram histogram = metricGroup.histogram("histTest"); + histogram.update(1); + Meter meter = metricGroup.meter("meterTest"); + meter.mark(); - List reporterList = metricGroupRegistry.getReporterList(); - Assert.assertNotNull(reporterList); - for (MetricReporter reporter : reporterList) { - ((ScheduledReporter) reporter).report(); - reporter.close(); - } + List reporterList = metricGroupRegistry.getReporterList(); + Assert.assertNotNull(reporterList); + for (MetricReporter reporter : reporterList) { + ((ScheduledReporter) reporter).report(); + reporter.close(); } - + } } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/BaseStatsCollector.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/BaseStatsCollector.java index ccbf19a95..b579e31eb 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/BaseStatsCollector.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/BaseStatsCollector.java @@ -21,25 +21,25 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.stats.sink.IStatsWriter; public class BaseStatsCollector { - protected String jobName; - private final List statsWriters; + protected String jobName; + private final List statsWriters; - public BaseStatsCollector(IStatsWriter statsWriter, Configuration configuration) { - this.statsWriters = new ArrayList<>(); - this.statsWriters.add(statsWriter); - this.jobName = configuration.getString(ExecutionConfigKeys.JOB_APP_NAME); - } + public BaseStatsCollector(IStatsWriter statsWriter, Configuration configuration) { + this.statsWriters = new ArrayList<>(); + this.statsWriters.add(statsWriter); + this.jobName = configuration.getString(ExecutionConfigKeys.JOB_APP_NAME); + } - public void addToWriterQueue(String key, Object value) { - for (int i = 0; i < statsWriters.size(); i++) { - statsWriters.get(i).addMetric(key, value); - } + public void addToWriterQueue(String key, Object value) { + for (int i = 0; i < statsWriters.size(); i++) { + statsWriters.get(i).addMetric(key, value); } - + } } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/EventCollector.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/EventCollector.java index 101e7120b..4e94ea1ce 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/EventCollector.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/EventCollector.java @@ -29,23 +29,29 @@ public class EventCollector extends BaseStatsCollector { - EventCollector(IStatsWriter statsWriter, Configuration configuration) { - super(statsWriter, configuration); - } - - public void reportEvent(ExceptionLevel severityLevel, EventLabel label, String message) { - reportEvent(severityLevel, label.name(), message); - } - - public void reportEvent(ExceptionLevel severityLevel, String label, String message) { - EventInfo eventInfo = new EventInfo(ProcessUtil.getHostname(), - ProcessUtil.getHostIp(), ProcessUtil.getProcessId(), - message, severityLevel.name(), label); - addToWriterQueue(genEventKey(), eventInfo); - } - - private String genEventKey() { - return jobName + StatsMetricType.Event.getValue() + (Long.MAX_VALUE - System.currentTimeMillis()); - } - + EventCollector(IStatsWriter statsWriter, Configuration configuration) { + super(statsWriter, configuration); + } + + public void reportEvent(ExceptionLevel severityLevel, EventLabel label, String message) { + reportEvent(severityLevel, label.name(), message); + } + + public void reportEvent(ExceptionLevel severityLevel, String label, String message) { + EventInfo eventInfo = + new EventInfo( + ProcessUtil.getHostname(), + ProcessUtil.getHostIp(), + ProcessUtil.getProcessId(), + message, + severityLevel.name(), + label); + addToWriterQueue(genEventKey(), eventInfo); + } + + private String genEventKey() { + return jobName + + StatsMetricType.Event.getValue() + + (Long.MAX_VALUE - System.currentTimeMillis()); + } } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/ExceptionCollector.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/ExceptionCollector.java index 1b3824e39..595233b1c 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/ExceptionCollector.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/ExceptionCollector.java @@ -29,23 +29,28 @@ public class ExceptionCollector extends BaseStatsCollector { - ExceptionCollector(IStatsWriter statsWriter, Configuration configuration) { - super(statsWriter, configuration); - } - - public void reportException(Throwable e) { - reportException(ExceptionLevel.ERROR, e); - } - - public void reportException(ExceptionLevel severityLevel, Throwable e) { - ExceptionInfo log = new ExceptionInfo(ProcessUtil.getHostname(), - ProcessUtil.getHostIp(), ProcessUtil.getProcessId(), - ExceptionUtils.getStackTrace(e), severityLevel.name()); - addToWriterQueue(genExceptionKey(), log); - } - - private String genExceptionKey() { - return jobName + StatsMetricType.Exception.getValue() + (Long.MAX_VALUE - System.currentTimeMillis()); - } - + ExceptionCollector(IStatsWriter statsWriter, Configuration configuration) { + super(statsWriter, configuration); + } + + public void reportException(Throwable e) { + reportException(ExceptionLevel.ERROR, e); + } + + public void reportException(ExceptionLevel severityLevel, Throwable e) { + ExceptionInfo log = + new ExceptionInfo( + ProcessUtil.getHostname(), + ProcessUtil.getHostIp(), + ProcessUtil.getProcessId(), + ExceptionUtils.getStackTrace(e), + severityLevel.name()); + addToWriterQueue(genExceptionKey(), log); + } + + private String genExceptionKey() { + return jobName + + StatsMetricType.Exception.getValue() + + (Long.MAX_VALUE - System.currentTimeMillis()); + } } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/HeartbeatCollector.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/HeartbeatCollector.java index 999226e34..b0af2ae1e 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/HeartbeatCollector.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/HeartbeatCollector.java @@ -26,16 +26,15 @@ public class HeartbeatCollector extends BaseStatsCollector { - HeartbeatCollector(IStatsWriter statsWriter, Configuration configuration) { - super(statsWriter, configuration); - } + HeartbeatCollector(IStatsWriter statsWriter, Configuration configuration) { + super(statsWriter, configuration); + } - public void reportHeartbeat(HeartbeatInfo heartbeatInfo) { - addToWriterQueue(getHeartbeatKey(), heartbeatInfo); - } - - private String getHeartbeatKey() { - return jobName + StatsMetricType.Heartbeat.getValue(); - } + public void reportHeartbeat(HeartbeatInfo heartbeatInfo) { + addToWriterQueue(getHeartbeatKey(), heartbeatInfo); + } + private String getHeartbeatKey() { + return jobName + StatsMetricType.Heartbeat.getValue(); + } } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/MetricMetaCollector.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/MetricMetaCollector.java index 96221ba23..8c9107a15 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/MetricMetaCollector.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/MetricMetaCollector.java @@ -26,20 +26,19 @@ public class MetricMetaCollector extends BaseStatsCollector { - private static final String KEY_SPLIT = "_"; + private static final String KEY_SPLIT = "_"; - private static final String META_KEY = "META"; + private static final String META_KEY = "META"; - MetricMetaCollector(IStatsWriter statsWriter, Configuration configuration) { - super(statsWriter, configuration); - } + MetricMetaCollector(IStatsWriter statsWriter, Configuration configuration) { + super(statsWriter, configuration); + } - public void reportMetricMeta(MetricMeta metricMeta) { - addToWriterQueue(getMetricMetaKey(metricMeta.getMetricName()), metricMeta); - } - - private String getMetricMetaKey(String metricName) { - return jobName + StatsMetricType.Metrics.getValue() + META_KEY + KEY_SPLIT + metricName; - } + public void reportMetricMeta(MetricMeta metricMeta) { + addToWriterQueue(getMetricMetaKey(metricMeta.getMetricName()), metricMeta); + } + private String getMetricMetaKey(String metricName) { + return jobName + StatsMetricType.Metrics.getValue() + META_KEY + KEY_SPLIT + metricName; + } } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/PipelineStatsCollector.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/PipelineStatsCollector.java index 74894ef3f..9e83d614c 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/PipelineStatsCollector.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/PipelineStatsCollector.java @@ -28,32 +28,33 @@ public class PipelineStatsCollector extends BaseStatsCollector { - private static final String KEY_SPLIT = "_"; - private final MetricCache metricCache; - - PipelineStatsCollector(IStatsWriter statsWriter, Configuration configuration, - MetricCache metricCache) { - super(statsWriter, configuration); - this.metricCache = metricCache; - } - - public void reportPipelineMetrics(PipelineMetrics pipelineMetric) { - addToWriterQueue(genMetricKey(PipelineMetricsType.PIPELINE, pipelineMetric.getName()), pipelineMetric); - metricCache.addPipelineMetrics(pipelineMetric); - } - - public void reportCycleMetrics(CycleMetrics cycleMetrics) { - String name = cycleMetrics.getPipelineName() + KEY_SPLIT + cycleMetrics.getName(); - addToWriterQueue(genMetricKey(PipelineMetricsType.CYCLE, name), cycleMetrics); - metricCache.addCycleMetrics(cycleMetrics); - } - - private String genMetricKey(PipelineMetricsType type, String name) { - return jobName + StatsMetricType.Metrics.getValue() + type + KEY_SPLIT + name; - } - - private enum PipelineMetricsType { - PIPELINE, - CYCLE - } + private static final String KEY_SPLIT = "_"; + private final MetricCache metricCache; + + PipelineStatsCollector( + IStatsWriter statsWriter, Configuration configuration, MetricCache metricCache) { + super(statsWriter, configuration); + this.metricCache = metricCache; + } + + public void reportPipelineMetrics(PipelineMetrics pipelineMetric) { + addToWriterQueue( + genMetricKey(PipelineMetricsType.PIPELINE, pipelineMetric.getName()), pipelineMetric); + metricCache.addPipelineMetrics(pipelineMetric); + } + + public void reportCycleMetrics(CycleMetrics cycleMetrics) { + String name = cycleMetrics.getPipelineName() + KEY_SPLIT + cycleMetrics.getName(); + addToWriterQueue(genMetricKey(PipelineMetricsType.CYCLE, name), cycleMetrics); + metricCache.addCycleMetrics(cycleMetrics); + } + + private String genMetricKey(PipelineMetricsType type, String name) { + return jobName + StatsMetricType.Metrics.getValue() + type + KEY_SPLIT + name; + } + + private enum PipelineMetricsType { + PIPELINE, + CYCLE + } } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/ProcessStatsCollector.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/ProcessStatsCollector.java index 108520c83..860b01055 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/ProcessStatsCollector.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/ProcessStatsCollector.java @@ -30,6 +30,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.metric.ProcessMetrics; @@ -43,138 +44,138 @@ public class ProcessStatsCollector { - private static final List OLD_GEN_COLLECTOR_NAMES = Arrays.asList( - // Oracle (Sun) HotSpot - // -XX:+UseSerialGC - "MarkSweepCompact", - // -XX:+UseParallelGC and (-XX:+UseParallelOldGC or -XX:+UseParallelOldGCCompacting) - "PS MarkSweep", - // -XX:+UseConcMarkSweepGC - "ConcurrentMarkSweep", - - // Oracle (BEA) JRockit - // -XgcPrio:pausetime - "Garbage collection optimized for short pausetimes Old Collector", - // -XgcPrio:throughput - "Garbage collection optimized for throughput Old Collector", - // -XgcPrio:deterministic - "Garbage collection optimized for deterministic pausetimes Old Collector", - //UseG1GC - "G1 Old Generation"); - - private long preTimeNano = System.nanoTime(); - private long preCpuTimeNano = -1; - private Map threadMap = new ConcurrentHashMap<>(); - - private final Counter totalUsedHeapMB; - private final Counter totalMemoryMB; - private final Histogram usedHeapRatio; - private final Histogram gcTimeHistogram; - private final Histogram fgcTimeHistogram; - private final Histogram fgcCountHistogram; - - ProcessStatsCollector(Configuration configuration) { - MetricGroupRegistry metricGroupRegistry = MetricGroupRegistry.getInstance(configuration); - MetricGroup metricGroup = metricGroupRegistry.getMetricGroup(MetricConstants.MODULE_SYSTEM); - totalUsedHeapMB = metricGroup.counter(MetricNameFormatter.totalHeapMetricName()); - totalMemoryMB = metricGroup.counter(MetricNameFormatter.totalMemoryMetricName()); - usedHeapRatio = metricGroup.histogram(MetricNameFormatter.heapUsageRatioMetricName()); - gcTimeHistogram = metricGroup.histogram(MetricNameFormatter.gcTimeMetricName()); - fgcTimeHistogram = metricGroup.histogram(MetricNameFormatter.fgcCountMetricName()); - fgcCountHistogram = metricGroup.histogram(MetricNameFormatter.fgcTimeMetricName()); + private static final List OLD_GEN_COLLECTOR_NAMES = + Arrays.asList( + // Oracle (Sun) HotSpot + // -XX:+UseSerialGC + "MarkSweepCompact", + // -XX:+UseParallelGC and (-XX:+UseParallelOldGC or -XX:+UseParallelOldGCCompacting) + "PS MarkSweep", + // -XX:+UseConcMarkSweepGC + "ConcurrentMarkSweep", + + // Oracle (BEA) JRockit + // -XgcPrio:pausetime + "Garbage collection optimized for short pausetimes Old Collector", + // -XgcPrio:throughput + "Garbage collection optimized for throughput Old Collector", + // -XgcPrio:deterministic + "Garbage collection optimized for deterministic pausetimes Old Collector", + // UseG1GC + "G1 Old Generation"); + + private long preTimeNano = System.nanoTime(); + private long preCpuTimeNano = -1; + private Map threadMap = new ConcurrentHashMap<>(); + + private final Counter totalUsedHeapMB; + private final Counter totalMemoryMB; + private final Histogram usedHeapRatio; + private final Histogram gcTimeHistogram; + private final Histogram fgcTimeHistogram; + private final Histogram fgcCountHistogram; + + ProcessStatsCollector(Configuration configuration) { + MetricGroupRegistry metricGroupRegistry = MetricGroupRegistry.getInstance(configuration); + MetricGroup metricGroup = metricGroupRegistry.getMetricGroup(MetricConstants.MODULE_SYSTEM); + totalUsedHeapMB = metricGroup.counter(MetricNameFormatter.totalHeapMetricName()); + totalMemoryMB = metricGroup.counter(MetricNameFormatter.totalMemoryMetricName()); + usedHeapRatio = metricGroup.histogram(MetricNameFormatter.heapUsageRatioMetricName()); + gcTimeHistogram = metricGroup.histogram(MetricNameFormatter.gcTimeMetricName()); + fgcTimeHistogram = metricGroup.histogram(MetricNameFormatter.fgcCountMetricName()); + fgcCountHistogram = metricGroup.histogram(MetricNameFormatter.fgcTimeMetricName()); + } + + public ProcessMetrics collect() { + + MemoryMXBean memoryMXBean = ManagementFactory.getMemoryMXBean(); + MemoryUsage heapMemory = memoryMXBean.getHeapMemoryUsage(); + long committedMB = heapMemory.getCommitted() / FileUtils.ONE_MB; + long usedMB = heapMemory.getUsed() / FileUtils.ONE_MB; + + ProcessMetrics workerMetrics = new ProcessMetrics(); + workerMetrics.setHeapCommittedMB(committedMB); + workerMetrics.setHeapUsedMB(usedMB); + workerMetrics.setTotalMemoryMB(ProcessUtil.getTotalMemory()); + + if (committedMB != 0) { + double percentage = Math.round(usedMB * 100.0 / committedMB); + workerMetrics.setHeapUsedRatio(percentage); } - public ProcessMetrics collect() { - - MemoryMXBean memoryMXBean = ManagementFactory.getMemoryMXBean(); - MemoryUsage heapMemory = memoryMXBean.getHeapMemoryUsage(); - long committedMB = heapMemory.getCommitted() / FileUtils.ONE_MB; - long usedMB = heapMemory.getUsed() / FileUtils.ONE_MB; - - ProcessMetrics workerMetrics = new ProcessMetrics(); - workerMetrics.setHeapCommittedMB(committedMB); - workerMetrics.setHeapUsedMB(usedMB); - workerMetrics.setTotalMemoryMB(ProcessUtil.getTotalMemory()); - - if (committedMB != 0) { - double percentage = Math.round(usedMB * 100.0 / committedMB); - workerMetrics.setHeapUsedRatio(percentage); - } - - long fgcCount = 0L; - long fgcTime = 0L; - long gcTime = 0L; - long gcCount = 0; - - List mxBeans = ManagementFactory.getGarbageCollectorMXBeans(); - for (GarbageCollectorMXBean gmx : mxBeans) { - gcCount += gmx.getCollectionCount(); - gcTime += gmx.getCollectionTime(); - if (OLD_GEN_COLLECTOR_NAMES.contains(gmx.getName())) { - fgcCount += gmx.getCollectionCount(); - fgcTime += gmx.getCollectionTime(); - } - } - - workerMetrics.setFgcCount(fgcCount); - workerMetrics.setFgcTime(fgcTime); - workerMetrics.setGcTime(gcTime); - workerMetrics.setGcCount(gcCount); - - OperatingSystemMXBean systemMXBean = ManagementFactory.getOperatingSystemMXBean(); - String avgLoad = String.format("%.2f", systemMXBean.getSystemLoadAverage()); - workerMetrics.setAvgLoad(Double.parseDouble(avgLoad)); - - int availProcessors = systemMXBean.getAvailableProcessors(); - workerMetrics.setAvailCores(availProcessors); - - double cpuUsage = getCpuUsage(); - double avgCpuUsage = cpuUsage / availProcessors; - workerMetrics.setProcessCpu(Double.parseDouble(String.format("%.2f", avgCpuUsage))); - workerMetrics.setUsedCores(Double.parseDouble(String.format("%.2f", cpuUsage / 100.0))); - workerMetrics.setActiveThreads(threadMap.size()); - uploadMetrics(workerMetrics); - - return workerMetrics; + long fgcCount = 0L; + long fgcTime = 0L; + long gcTime = 0L; + long gcCount = 0; + + List mxBeans = ManagementFactory.getGarbageCollectorMXBeans(); + for (GarbageCollectorMXBean gmx : mxBeans) { + gcCount += gmx.getCollectionCount(); + gcTime += gmx.getCollectionTime(); + if (OLD_GEN_COLLECTOR_NAMES.contains(gmx.getName())) { + fgcCount += gmx.getCollectionCount(); + fgcTime += gmx.getCollectionTime(); + } } - private synchronized double getCpuUsage() { - ThreadMXBean threadMXBean = ManagementFactory.getThreadMXBean(); - - // total CPU time for threads in nanoseconds. - long totalTimeNano = 0; - Map currentThreadMap = new HashMap<>(); - for (long id : threadMXBean.getAllThreadIds()) { - long threadCpuTime = threadMXBean.getThreadCpuTime(id); - if (threadCpuTime > 0) { - totalTimeNano += threadCpuTime; - currentThreadMap.put(id, threadCpuTime); - threadMap.remove(id); - } - } - for (Map.Entry entry : threadMap.entrySet()) { - totalTimeNano += entry.getValue(); - } - threadMap.clear(); - threadMap = currentThreadMap; - - long curTimeNano = System.nanoTime(); - long usedCpuTime = preCpuTimeNano == -1 ? 0 : totalTimeNano - preCpuTimeNano; - long totalPassedTime = curTimeNano - preTimeNano; - preTimeNano = curTimeNano; - preCpuTimeNano = totalTimeNano; - - return usedCpuTime * 100.0 / totalPassedTime; + workerMetrics.setFgcCount(fgcCount); + workerMetrics.setFgcTime(fgcTime); + workerMetrics.setGcTime(gcTime); + workerMetrics.setGcCount(gcCount); + + OperatingSystemMXBean systemMXBean = ManagementFactory.getOperatingSystemMXBean(); + String avgLoad = String.format("%.2f", systemMXBean.getSystemLoadAverage()); + workerMetrics.setAvgLoad(Double.parseDouble(avgLoad)); + + int availProcessors = systemMXBean.getAvailableProcessors(); + workerMetrics.setAvailCores(availProcessors); + + double cpuUsage = getCpuUsage(); + double avgCpuUsage = cpuUsage / availProcessors; + workerMetrics.setProcessCpu(Double.parseDouble(String.format("%.2f", avgCpuUsage))); + workerMetrics.setUsedCores(Double.parseDouble(String.format("%.2f", cpuUsage / 100.0))); + workerMetrics.setActiveThreads(threadMap.size()); + uploadMetrics(workerMetrics); + + return workerMetrics; + } + + private synchronized double getCpuUsage() { + ThreadMXBean threadMXBean = ManagementFactory.getThreadMXBean(); + + // total CPU time for threads in nanoseconds. + long totalTimeNano = 0; + Map currentThreadMap = new HashMap<>(); + for (long id : threadMXBean.getAllThreadIds()) { + long threadCpuTime = threadMXBean.getThreadCpuTime(id); + if (threadCpuTime > 0) { + totalTimeNano += threadCpuTime; + currentThreadMap.put(id, threadCpuTime); + threadMap.remove(id); + } } - - private void uploadMetrics(ProcessMetrics metrics) { - totalUsedHeapMB.inc(metrics.getHeapUsedMB()); - totalMemoryMB.inc(metrics.getTotalMemoryMB()); - - gcTimeHistogram.update(metrics.getGcTime()); - fgcTimeHistogram.update(metrics.getFgcTime()); - fgcCountHistogram.update(metrics.getFgcCount()); - usedHeapRatio.update((int) (metrics.getHeapUsedRatio())); + for (Map.Entry entry : threadMap.entrySet()) { + totalTimeNano += entry.getValue(); } - + threadMap.clear(); + threadMap = currentThreadMap; + + long curTimeNano = System.nanoTime(); + long usedCpuTime = preCpuTimeNano == -1 ? 0 : totalTimeNano - preCpuTimeNano; + long totalPassedTime = curTimeNano - preTimeNano; + preTimeNano = curTimeNano; + preCpuTimeNano = totalTimeNano; + + return usedCpuTime * 100.0 / totalPassedTime; + } + + private void uploadMetrics(ProcessMetrics metrics) { + totalUsedHeapMB.inc(metrics.getHeapUsedMB()); + totalMemoryMB.inc(metrics.getTotalMemoryMB()); + + gcTimeHistogram.update(metrics.getGcTime()); + fgcTimeHistogram.update(metrics.getFgcTime()); + fgcCountHistogram.update(metrics.getFgcCount()); + usedHeapRatio.update((int) (metrics.getHeapUsedRatio())); + } } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/StatsCollectorFactory.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/StatsCollectorFactory.java index 32b1e2738..3157c5151 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/StatsCollectorFactory.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/collector/StatsCollectorFactory.java @@ -26,69 +26,69 @@ public class StatsCollectorFactory { - private final ExceptionCollector exceptionCollector; - private final EventCollector eventCollector; - private final PipelineStatsCollector pipelineStatsCollector; - private final ProcessStatsCollector processStatsCollector; - private final MetricMetaCollector metricMetaCollector; - private final HeartbeatCollector heartbeatCollector; - private final MetricCache metricCache; - private final IStatsWriter syncWriter; - private static StatsCollectorFactory INSTANCE; - - private StatsCollectorFactory(Configuration configuration) { - this.syncWriter = StatsWriterFactory.getStatsWriter(configuration, true); - this.exceptionCollector = new ExceptionCollector(syncWriter, configuration); - this.eventCollector = new EventCollector(syncWriter, configuration); - this.metricCache = new MetricCache(configuration); - IStatsWriter statsWriter = StatsWriterFactory.getStatsWriter(configuration); - this.pipelineStatsCollector = new PipelineStatsCollector(statsWriter, configuration, metricCache); - this.metricMetaCollector = new MetricMetaCollector(statsWriter, configuration); - this.processStatsCollector = new ProcessStatsCollector(configuration); - this.heartbeatCollector = new HeartbeatCollector(statsWriter, configuration); + private final ExceptionCollector exceptionCollector; + private final EventCollector eventCollector; + private final PipelineStatsCollector pipelineStatsCollector; + private final ProcessStatsCollector processStatsCollector; + private final MetricMetaCollector metricMetaCollector; + private final HeartbeatCollector heartbeatCollector; + private final MetricCache metricCache; + private final IStatsWriter syncWriter; + private static StatsCollectorFactory INSTANCE; + + private StatsCollectorFactory(Configuration configuration) { + this.syncWriter = StatsWriterFactory.getStatsWriter(configuration, true); + this.exceptionCollector = new ExceptionCollector(syncWriter, configuration); + this.eventCollector = new EventCollector(syncWriter, configuration); + this.metricCache = new MetricCache(configuration); + IStatsWriter statsWriter = StatsWriterFactory.getStatsWriter(configuration); + this.pipelineStatsCollector = + new PipelineStatsCollector(statsWriter, configuration, metricCache); + this.metricMetaCollector = new MetricMetaCollector(statsWriter, configuration); + this.processStatsCollector = new ProcessStatsCollector(configuration); + this.heartbeatCollector = new HeartbeatCollector(statsWriter, configuration); + } + + public static synchronized StatsCollectorFactory init(Configuration configuration) { + if (INSTANCE == null) { + INSTANCE = new StatsCollectorFactory(configuration); } + return INSTANCE; + } - public static synchronized StatsCollectorFactory init(Configuration configuration) { - if (INSTANCE == null) { - INSTANCE = new StatsCollectorFactory(configuration); - } - return INSTANCE; - } + public static StatsCollectorFactory getInstance() { + return INSTANCE; + } - public static StatsCollectorFactory getInstance() { - return INSTANCE; - } + public ExceptionCollector getExceptionCollector() { + return exceptionCollector; + } - public ExceptionCollector getExceptionCollector() { - return exceptionCollector; - } + public EventCollector getEventCollector() { + return eventCollector; + } - public EventCollector getEventCollector() { - return eventCollector; - } + public PipelineStatsCollector getPipelineStatsCollector() { + return pipelineStatsCollector; + } - public PipelineStatsCollector getPipelineStatsCollector() { - return pipelineStatsCollector; - } + public HeartbeatCollector getHeartbeatCollector() { + return heartbeatCollector; + } - public HeartbeatCollector getHeartbeatCollector() { - return heartbeatCollector; - } + public MetricMetaCollector getMetricMetaCollector() { + return metricMetaCollector; + } - public MetricMetaCollector getMetricMetaCollector() { - return metricMetaCollector; - } - - public ProcessStatsCollector getProcessStatsCollector() { - return processStatsCollector; - } + public ProcessStatsCollector getProcessStatsCollector() { + return processStatsCollector; + } - public MetricCache getMetricCache() { - return metricCache; - } - - public IStatsWriter getStatsWriter() { - return syncWriter; - } + public MetricCache getMetricCache() { + return metricCache; + } + public IStatsWriter getStatsWriter() { + return syncWriter; + } } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/EventInfo.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/EventInfo.java index df6605f5c..515f68d46 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/EventInfo.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/EventInfo.java @@ -21,24 +21,24 @@ public class EventInfo extends ExceptionInfo { - private String label; + private String label; - public EventInfo(String hostname, String ip, int processId, String message, String severity, - String label) { - super(hostname, ip, processId, message, severity); - this.label = label; - } + public EventInfo( + String hostname, String ip, int processId, String message, String severity, String label) { + super(hostname, ip, processId, message, severity); + this.label = label; + } - public String getLabel() { - return label; - } + public String getLabel() { + return label; + } - public void setLabel(String label) { - this.label = label; - } + public void setLabel(String label) { + this.label = label; + } - @Override - public String toString() { - return "EventInfo{" + "label='" + label + '\'' + ", super=" + super.toString() + '}'; - } + @Override + public String toString() { + return "EventInfo{" + "label='" + label + '\'' + ", super=" + super.toString() + '}'; + } } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/EventLabel.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/EventLabel.java index db6a1dad9..0a5d7f91d 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/EventLabel.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/EventLabel.java @@ -20,15 +20,13 @@ package org.apache.geaflow.stats.model; public enum EventLabel { + START_CLUSTER_SUCCESS, - START_CLUSTER_SUCCESS, + START_CLUSTER_FAILED, - START_CLUSTER_FAILED, + WORKER_PROCESS_EXITED, - WORKER_PROCESS_EXITED, - - FAILOVER_START, - - FAILOVER_FINISH + FAILOVER_START, + FAILOVER_FINISH } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/ExceptionInfo.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/ExceptionInfo.java index ba1faacf5..27697917b 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/ExceptionInfo.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/ExceptionInfo.java @@ -23,77 +23,89 @@ public class ExceptionInfo implements Serializable { - protected String hostname; - protected String ip; - protected int processId; - protected String severity; - protected String message; - protected long timestamp; - - public ExceptionInfo(String hostname, String ip, int processId, String message, - String severity) { - this.hostname = hostname; - this.ip = ip; - this.processId = processId; - this.message = message; - this.timestamp = System.currentTimeMillis(); - this.severity = severity; - } - - public String getIp() { - return ip; - } - - public void setIp(String ip) { - this.ip = ip; - } - - public String getHostname() { - return hostname; - } - - public void setHostname(String hostname) { - this.hostname = hostname; - } - - public int getProcessId() { - return processId; - } - - public void setProcessId(int processId) { - this.processId = processId; - } - - public String getMessage() { - return message; - } - - public void setMessage(String message) { - this.message = message; - } - - public long getTimestamp() { - return timestamp; - } - - public void setTimestamp(long timestamp) { - this.timestamp = timestamp; - } - - public String getSeverity() { - return severity; - } - - public void setSeverity(String severity) { - this.severity = severity; - } - - @Override - public String toString() { - return "ExceptionInfo{" + "hostname='" + hostname + '\'' + ", ip='" + ip + '\'' - + ", processId=" + processId + ", severity='" + severity + '\'' + ", message='" - + message + '\'' + ", timestamp=" + timestamp + '}'; - } + protected String hostname; + protected String ip; + protected int processId; + protected String severity; + protected String message; + protected long timestamp; + + public ExceptionInfo(String hostname, String ip, int processId, String message, String severity) { + this.hostname = hostname; + this.ip = ip; + this.processId = processId; + this.message = message; + this.timestamp = System.currentTimeMillis(); + this.severity = severity; + } + + public String getIp() { + return ip; + } + + public void setIp(String ip) { + this.ip = ip; + } + + public String getHostname() { + return hostname; + } + + public void setHostname(String hostname) { + this.hostname = hostname; + } + + public int getProcessId() { + return processId; + } + + public void setProcessId(int processId) { + this.processId = processId; + } + + public String getMessage() { + return message; + } + + public void setMessage(String message) { + this.message = message; + } + + public long getTimestamp() { + return timestamp; + } + + public void setTimestamp(long timestamp) { + this.timestamp = timestamp; + } + + public String getSeverity() { + return severity; + } + + public void setSeverity(String severity) { + this.severity = severity; + } + + @Override + public String toString() { + return "ExceptionInfo{" + + "hostname='" + + hostname + + '\'' + + ", ip='" + + ip + + '\'' + + ", processId=" + + processId + + ", severity='" + + severity + + '\'' + + ", message='" + + message + + '\'' + + ", timestamp=" + + timestamp + + '}'; + } } - - diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/ExceptionLevel.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/ExceptionLevel.java index 0175452e5..77c0801c5 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/ExceptionLevel.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/ExceptionLevel.java @@ -21,29 +21,18 @@ public enum ExceptionLevel { - /** - * fatal exception. For example, the process is exited. - */ - FATAL, + /** fatal exception. For example, the process is exited. */ + FATAL, - /** - * program in error state, the program is going to exit or not. - */ - ERROR, + /** program in error state, the program is going to exit or not. */ + ERROR, - /** - * warning exception. - */ - WARN, + /** warning exception. */ + WARN, - /** - * info message. - */ - INFO, + /** info message. */ + INFO, - /** - * unknown label. - */ - UNKNOWN + /** unknown label. */ + UNKNOWN } - diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/MetricCache.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/MetricCache.java index b6d52a123..c102c69af 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/MetricCache.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/MetricCache.java @@ -29,113 +29,115 @@ import java.util.LinkedHashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.metric.CycleMetrics; import org.apache.geaflow.common.metric.PipelineMetrics; public class MetricCache implements Serializable { - private static final int DEFAULT_MAX_JOBS = 50; - - private final int maxPipelines; - private final BoundedHashMap submittedPipelines; - private Map pipelineMetricCacheMap; - - public MetricCache() { - this(DEFAULT_MAX_JOBS); + private static final int DEFAULT_MAX_JOBS = 50; + + private final int maxPipelines; + private final BoundedHashMap submittedPipelines; + private Map pipelineMetricCacheMap; + + public MetricCache() { + this(DEFAULT_MAX_JOBS); + } + + public MetricCache(Configuration configuration) { + this(configuration.getInteger(METRIC_MAX_CACHED_PIPELINES)); + } + + public MetricCache(int maxSize) { + this.maxPipelines = maxSize; + this.submittedPipelines = new BoundedHashMap<>(maxSize); + this.pipelineMetricCacheMap = new ConcurrentHashMap<>(); + } + + public synchronized void addPipelineMetrics(PipelineMetrics pipelineMetrics) { + submittedPipelines.put(pipelineMetrics.getStartTime(), pipelineMetrics.getName()); + PipelineMetricCache cache = + pipelineMetricCacheMap.computeIfAbsent( + pipelineMetrics.getName(), key -> new PipelineMetricCache()); + cache.updatePipelineMetrics(pipelineMetrics); + if (pipelineMetricCacheMap.size() > maxPipelines) { + pipelineMetricCacheMap.keySet().retainAll(submittedPipelines.values()); } - - public MetricCache(Configuration configuration) { - this(configuration.getInteger(METRIC_MAX_CACHED_PIPELINES)); + } + + public synchronized void addCycleMetrics(CycleMetrics cycleMetrics) { + PipelineMetricCache cache = + pipelineMetricCacheMap.computeIfAbsent( + cycleMetrics.getPipelineName(), key -> new PipelineMetricCache()); + cache.addCycleMetrics(cycleMetrics); + } + + public Map getPipelineMetricCaches() { + return pipelineMetricCacheMap; + } + + private void readObject(ObjectInputStream inputStream) + throws ClassNotFoundException, IOException { + this.pipelineMetricCacheMap = (Map) inputStream.readObject(); + } + + private void writeObject(ObjectOutputStream outputStream) throws IOException { + outputStream.writeObject(pipelineMetricCacheMap); + } + + public void mergeMetricCache(MetricCache metricCache) { + this.pipelineMetricCacheMap.putAll(metricCache.pipelineMetricCacheMap); + } + + public void clearAll() { + this.pipelineMetricCacheMap.clear(); + } + + public static class PipelineMetricCache implements Serializable { + private PipelineMetrics pipelineMetrics; + private Map cycleMetricMap; + + public PipelineMetricCache() { + this.cycleMetricMap = new HashMap<>(); } - public MetricCache(int maxSize) { - this.maxPipelines = maxSize; - this.submittedPipelines = new BoundedHashMap<>(maxSize); - this.pipelineMetricCacheMap = new ConcurrentHashMap<>(); + public void updatePipelineMetrics(PipelineMetrics pipelineMetrics) { + this.pipelineMetrics = pipelineMetrics; } - public synchronized void addPipelineMetrics(PipelineMetrics pipelineMetrics) { - submittedPipelines.put(pipelineMetrics.getStartTime(), pipelineMetrics.getName()); - PipelineMetricCache cache = pipelineMetricCacheMap.computeIfAbsent(pipelineMetrics.getName(), - key -> new PipelineMetricCache()); - cache.updatePipelineMetrics(pipelineMetrics); - if (pipelineMetricCacheMap.size() > maxPipelines) { - pipelineMetricCacheMap.keySet().retainAll(submittedPipelines.values()); - } + public void addCycleMetrics(CycleMetrics cycleMetrics) { + this.cycleMetricMap.put(cycleMetrics.getName(), cycleMetrics); } - public synchronized void addCycleMetrics(CycleMetrics cycleMetrics) { - PipelineMetricCache cache = pipelineMetricCacheMap.computeIfAbsent(cycleMetrics.getPipelineName(), - key -> new PipelineMetricCache()); - cache.addCycleMetrics(cycleMetrics); + public PipelineMetrics getPipelineMetrics() { + return pipelineMetrics; } - public Map getPipelineMetricCaches() { - return pipelineMetricCacheMap; + public void setPipelineMetrics(PipelineMetrics pipelineMetrics) { + this.pipelineMetrics = pipelineMetrics; } - private void readObject(ObjectInputStream inputStream) throws ClassNotFoundException, - IOException { - this.pipelineMetricCacheMap = (Map) inputStream.readObject(); + public Map getCycleMetricList() { + return cycleMetricMap; } - private void writeObject(ObjectOutputStream outputStream) throws IOException { - outputStream.writeObject(pipelineMetricCacheMap); + public void setCycleMetricList(Map cycleMetricList) { + this.cycleMetricMap = cycleMetricList; } + } - public void mergeMetricCache(MetricCache metricCache) { - this.pipelineMetricCacheMap.putAll(metricCache.pipelineMetricCacheMap); - } + public static class BoundedHashMap extends LinkedHashMap { + private final int maxSize; - public void clearAll() { - this.pipelineMetricCacheMap.clear(); + public BoundedHashMap(int capacity) { + super(capacity, 0.75f, true); + this.maxSize = capacity; } - public static class PipelineMetricCache implements Serializable { - private PipelineMetrics pipelineMetrics; - private Map cycleMetricMap; - - public PipelineMetricCache() { - this.cycleMetricMap = new HashMap<>(); - } - - public void updatePipelineMetrics(PipelineMetrics pipelineMetrics) { - this.pipelineMetrics = pipelineMetrics; - } - - public void addCycleMetrics(CycleMetrics cycleMetrics) { - this.cycleMetricMap.put(cycleMetrics.getName(), cycleMetrics); - } - - public PipelineMetrics getPipelineMetrics() { - return pipelineMetrics; - } - - public void setPipelineMetrics(PipelineMetrics pipelineMetrics) { - this.pipelineMetrics = pipelineMetrics; - } - - public Map getCycleMetricList() { - return cycleMetricMap; - } - - public void setCycleMetricList(Map cycleMetricList) { - this.cycleMetricMap = cycleMetricList; - } - } - - public static class BoundedHashMap extends LinkedHashMap { - private final int maxSize; - - public BoundedHashMap(int capacity) { - super(capacity, 0.75f, true); - this.maxSize = capacity; - } - - @Override - protected boolean removeEldestEntry(Map.Entry eldest) { - return this.size() > maxSize; - } + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return this.size() > maxSize; } - + } } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/StatsMetricType.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/StatsMetricType.java index 845f9b22c..aad655983 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/StatsMetricType.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/StatsMetricType.java @@ -21,33 +21,25 @@ public enum StatsMetricType { - /** - * exception log. - */ - Exception("_exception_"), - - /** - * event log. - */ - Event("_event_"), - - /** - * runtime metrics. - */ - Metrics("_metrics_"), - - /** - * runtime heartbeat map. - */ - Heartbeat("_heartbeat_"); - - private final String value; - - StatsMetricType(String value) { - this.value = value; - } - - public String getValue() { - return value; - } + /** exception log. */ + Exception("_exception_"), + + /** event log. */ + Event("_event_"), + + /** runtime metrics. */ + Metrics("_metrics_"), + + /** runtime heartbeat map. */ + Heartbeat("_heartbeat_"); + + private final String value; + + StatsMetricType(String value) { + this.value = value; + } + + public String getValue() { + return value; + } } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/StatsStoreType.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/StatsStoreType.java index 2f9dcc507..415307b56 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/StatsStoreType.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/model/StatsStoreType.java @@ -21,19 +21,12 @@ public enum StatsStoreType { - /** - * store in memory. - */ - MEMORY, + /** store in memory. */ + MEMORY, - /** - * store in hbase kvStore. - */ - HBASE, - - /** - * store in rmdb, like mysql. - */ - JDBC + /** store in hbase kvStore. */ + HBASE, + /** store in rmdb, like mysql. */ + JDBC } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/sink/AsyncKvStoreWriter.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/sink/AsyncKvStoreWriter.java index 346f6ec2d..1875e319a 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/sink/AsyncKvStoreWriter.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/sink/AsyncKvStoreWriter.java @@ -21,13 +21,13 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.SYSTEM_META_TABLE; -import com.alibaba.fastjson.JSON; import java.util.LinkedList; import java.util.Queue; import java.util.concurrent.ExecutorService; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.common.tuple.Tuple; @@ -42,131 +42,136 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class AsyncKvStoreWriter implements IStatsWriter { - - private static final Logger LOGGER = LoggerFactory.getLogger(AsyncKvStoreWriter.class); - private static final int MAX_METRIC_QUEUE_SIZE = 1024; - private static final String DEFAULT_NAMESPACE = "default"; - - private final int batchFlushSize; - private final int flushIntervalMs; - private volatile boolean running; - - private final IKVStore kvStore; - private final Queue> metricQueue; - private final ExecutorService executorService; +import com.alibaba.fastjson.JSON; - public AsyncKvStoreWriter(Configuration configuration) { - this.batchFlushSize = configuration - .getInteger(ExecutionConfigKeys.STATS_METRIC_FLUSH_BATCH_SIZE); - this.flushIntervalMs = - configuration.getInteger(ExecutionConfigKeys.STATS_METRIC_FLUSH_INTERVAL_MS); - this.kvStore = createKvStore(configuration); +public class AsyncKvStoreWriter implements IStatsWriter { - this.metricQueue = new LinkedBlockingQueue<>(MAX_METRIC_QUEUE_SIZE); - int threadNum = configuration.getInteger(ExecutionConfigKeys.STATS_METRIC_FLUSH_THREADS); - this.executorService = new ThreadPoolExecutor(threadNum, threadNum, 30, TimeUnit.SECONDS, + private static final Logger LOGGER = LoggerFactory.getLogger(AsyncKvStoreWriter.class); + private static final int MAX_METRIC_QUEUE_SIZE = 1024; + private static final String DEFAULT_NAMESPACE = "default"; + + private final int batchFlushSize; + private final int flushIntervalMs; + private volatile boolean running; + + private final IKVStore kvStore; + private final Queue> metricQueue; + private final ExecutorService executorService; + + public AsyncKvStoreWriter(Configuration configuration) { + this.batchFlushSize = + configuration.getInteger(ExecutionConfigKeys.STATS_METRIC_FLUSH_BATCH_SIZE); + this.flushIntervalMs = + configuration.getInteger(ExecutionConfigKeys.STATS_METRIC_FLUSH_INTERVAL_MS); + this.kvStore = createKvStore(configuration); + + this.metricQueue = new LinkedBlockingQueue<>(MAX_METRIC_QUEUE_SIZE); + int threadNum = configuration.getInteger(ExecutionConfigKeys.STATS_METRIC_FLUSH_THREADS); + this.executorService = + new ThreadPoolExecutor( + threadNum, + threadNum, + 30, + TimeUnit.SECONDS, new LinkedBlockingQueue<>(threadNum), ThreadUtil.namedThreadFactory(true, "stats-flusher")); - for (int i = 0; i < threadNum; i++) { - this.executorService.submit(new MetricFlushTask()); - } - this.running = true; + for (int i = 0; i < threadNum; i++) { + this.executorService.submit(new MetricFlushTask()); } + this.running = true; + } - private IKVStore createKvStore(Configuration configuration) { - String namespace = DEFAULT_NAMESPACE; - if (configuration.contains(SYSTEM_META_TABLE)) { - namespace = configuration.getString(SYSTEM_META_TABLE); - } - StoreContext storeContext = new StoreContext(namespace); - storeContext.withKeySerializer(new DefaultKVSerializer(String.class, String.class)); - storeContext.withConfig(configuration); - - String storeType = configuration.getString(ExecutionConfigKeys.STATS_METRIC_STORE_TYPE); - IStoreBuilder builder = StoreBuilderFactory.build(storeType); - IKVStore kvStore = (IKVStore) builder.getStore(DataModel.KV, configuration); - kvStore.init(storeContext); - LOGGER.info("create stats store with type:{} namespace:{}", storeType, namespace); - return kvStore; + private IKVStore createKvStore(Configuration configuration) { + String namespace = DEFAULT_NAMESPACE; + if (configuration.contains(SYSTEM_META_TABLE)) { + namespace = configuration.getString(SYSTEM_META_TABLE); } - - @Override - public void addMetric(String key, Object value) { - Tuple tuple = Tuple.of(key, value); - boolean result = metricQueue.offer(tuple); - while (!result) { - Tuple expired = metricQueue.poll(); - if (expired != null) { - LOGGER.warn("discard metric: {} due to capacity limit", expired.getF0()); - } - result = metricQueue.offer(tuple); - } + StoreContext storeContext = new StoreContext(namespace); + storeContext.withKeySerializer(new DefaultKVSerializer(String.class, String.class)); + storeContext.withConfig(configuration); + + String storeType = configuration.getString(ExecutionConfigKeys.STATS_METRIC_STORE_TYPE); + IStoreBuilder builder = StoreBuilderFactory.build(storeType); + IKVStore kvStore = (IKVStore) builder.getStore(DataModel.KV, configuration); + kvStore.init(storeContext); + LOGGER.info("create stats store with type:{} namespace:{}", storeType, namespace); + return kvStore; + } + + @Override + public void addMetric(String key, Object value) { + Tuple tuple = Tuple.of(key, value); + boolean result = metricQueue.offer(tuple); + while (!result) { + Tuple expired = metricQueue.poll(); + if (expired != null) { + LOGGER.warn("discard metric: {} due to capacity limit", expired.getF0()); + } + result = metricQueue.offer(tuple); } + } - @Override - public void close() { - if (!running) { - return; - } - if (executorService != null) { - executorService.shutdown(); - } - running = false; + @Override + public void close() { + if (!running) { + return; + } + if (executorService != null) { + executorService.shutdown(); } + running = false; + } - public class MetricFlushTask implements Runnable { + public class MetricFlushTask implements Runnable { - private int flushSize; - private final Queue> buffers; + private int flushSize; + private final Queue> buffers; - public MetricFlushTask() { - this.flushSize = 0; - this.buffers = new LinkedList<>(); - } + public MetricFlushTask() { + this.flushSize = 0; + this.buffers = new LinkedList<>(); + } - @Override - public void run() { - while (true) { - try { - fillBuffers(); - if (flushSize > 0) { - doFlush(); - } - SleepUtils.sleepMilliSecond(flushIntervalMs); - } catch (Throwable e) { - LOGGER.warn("flush stats metrics failed:{}", e.getMessage(), e); - } - - } + @Override + public void run() { + while (true) { + try { + fillBuffers(); + if (flushSize > 0) { + doFlush(); + } + SleepUtils.sleepMilliSecond(flushIntervalMs); + } catch (Throwable e) { + LOGGER.warn("flush stats metrics failed:{}", e.getMessage(), e); } + } + } - private void fillBuffers() { - int count = 0; - while (count < batchFlushSize) { - Tuple tuple = metricQueue.poll(); - if (tuple == null) { - break; - } - buffers.add(tuple); - count++; - } - flushSize = count; + private void fillBuffers() { + int count = 0; + while (count < batchFlushSize) { + Tuple tuple = metricQueue.poll(); + if (tuple == null) { + break; } + buffers.add(tuple); + count++; + } + flushSize = count; + } - private void doFlush() { - try { - while (!buffers.isEmpty()) { - Tuple tuple = buffers.poll(); - kvStore.put(tuple.f0, JSON.toJSONString(tuple.f1)); - } - kvStore.flush(); - } catch (Throwable e) { - LOGGER.warn("discard {} metrics due to: {}", flushSize, e.getMessage(), e); - } finally { - flushSize = 0; - } + private void doFlush() { + try { + while (!buffers.isEmpty()) { + Tuple tuple = buffers.poll(); + kvStore.put(tuple.f0, JSON.toJSONString(tuple.f1)); } + kvStore.flush(); + } catch (Throwable e) { + LOGGER.warn("discard {} metrics due to: {}", flushSize, e.getMessage(), e); + } finally { + flushSize = 0; + } } - + } } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/sink/IStatsWriter.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/sink/IStatsWriter.java index 8a42c521f..4fe1ef908 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/sink/IStatsWriter.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/sink/IStatsWriter.java @@ -21,8 +21,7 @@ public interface IStatsWriter { - void addMetric(String key, Object value); - - void close(); + void addMetric(String key, Object value); + void close(); } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/sink/MemoryStatsWriter.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/sink/MemoryStatsWriter.java index 4f372be28..1ed3f2611 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/sink/MemoryStatsWriter.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/sink/MemoryStatsWriter.java @@ -23,14 +23,13 @@ import org.slf4j.LoggerFactory; public class MemoryStatsWriter implements IStatsWriter { - private static final Logger LOGGER = LoggerFactory.getLogger(MemoryStatsWriter.class); + private static final Logger LOGGER = LoggerFactory.getLogger(MemoryStatsWriter.class); - @Override - public void addMetric(String key, Object value) { - LOGGER.info("update metric: key={}, value={}", key, value); - } + @Override + public void addMetric(String key, Object value) { + LOGGER.info("update metric: key={}, value={}", key, value); + } - @Override - public void close() { - } + @Override + public void close() {} } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/sink/StatsWriterFactory.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/sink/StatsWriterFactory.java index be816fe09..f7c583970 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/sink/StatsWriterFactory.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/sink/StatsWriterFactory.java @@ -25,19 +25,18 @@ public class StatsWriterFactory { - public static IStatsWriter getStatsWriter(Configuration configuration) { - return getStatsWriter(configuration, false); - } + public static IStatsWriter getStatsWriter(Configuration configuration) { + return getStatsWriter(configuration, false); + } - public static IStatsWriter getStatsWriter(Configuration configuration, boolean isSync) { - String statStoreType = configuration.getString(ExecutionConfigKeys.STATS_METRIC_STORE_TYPE); - if (statStoreType.equalsIgnoreCase(StatsStoreType.MEMORY.name())) { - return new MemoryStatsWriter(); - } else if (statStoreType.equalsIgnoreCase(StatsStoreType.HBASE.name()) || statStoreType - .equalsIgnoreCase(StatsStoreType.JDBC.name())) { - return isSync ? new SyncKvStoreWriter(configuration) : - new AsyncKvStoreWriter(configuration); - } - throw new UnsupportedOperationException("unknown stats store type" + statStoreType); + public static IStatsWriter getStatsWriter(Configuration configuration, boolean isSync) { + String statStoreType = configuration.getString(ExecutionConfigKeys.STATS_METRIC_STORE_TYPE); + if (statStoreType.equalsIgnoreCase(StatsStoreType.MEMORY.name())) { + return new MemoryStatsWriter(); + } else if (statStoreType.equalsIgnoreCase(StatsStoreType.HBASE.name()) + || statStoreType.equalsIgnoreCase(StatsStoreType.JDBC.name())) { + return isSync ? new SyncKvStoreWriter(configuration) : new AsyncKvStoreWriter(configuration); } + throw new UnsupportedOperationException("unknown stats store type" + statStoreType); + } } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/sink/SyncKvStoreWriter.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/sink/SyncKvStoreWriter.java index 9b5900238..a444f351f 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/sink/SyncKvStoreWriter.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/main/java/org/apache/geaflow/stats/sink/SyncKvStoreWriter.java @@ -21,7 +21,6 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.SYSTEM_META_TABLE; -import com.alibaba.fastjson.JSON; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.state.DataModel; @@ -33,42 +32,42 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.alibaba.fastjson.JSON; + public class SyncKvStoreWriter implements IStatsWriter { - private static final Logger LOGGER = LoggerFactory.getLogger(SyncKvStoreWriter.class); - private static final String DEFAULT_NAMESPACE = "default"; + private static final Logger LOGGER = LoggerFactory.getLogger(SyncKvStoreWriter.class); + private static final String DEFAULT_NAMESPACE = "default"; - private final IKVStore kvStore; + private final IKVStore kvStore; - public SyncKvStoreWriter(Configuration configuration) { - this.kvStore = createKvStore(configuration); - } + public SyncKvStoreWriter(Configuration configuration) { + this.kvStore = createKvStore(configuration); + } - private IKVStore createKvStore(Configuration configuration) { - String namespace = DEFAULT_NAMESPACE; - if (configuration.contains(SYSTEM_META_TABLE)) { - namespace = configuration.getString(SYSTEM_META_TABLE); - } - StoreContext storeContext = new StoreContext(namespace); - storeContext.withKeySerializer(new DefaultKVSerializer(String.class, String.class)); - storeContext.withConfig(configuration); - - String storeType = configuration.getString(ExecutionConfigKeys.STATS_METRIC_STORE_TYPE); - IStoreBuilder builder = StoreBuilderFactory.build(storeType); - IKVStore kvStore = (IKVStore) builder.getStore(DataModel.KV, configuration); - kvStore.init(storeContext); - LOGGER.info("create stats store with type:{} namespace:{}", storeType, namespace); - return kvStore; + private IKVStore createKvStore(Configuration configuration) { + String namespace = DEFAULT_NAMESPACE; + if (configuration.contains(SYSTEM_META_TABLE)) { + namespace = configuration.getString(SYSTEM_META_TABLE); } + StoreContext storeContext = new StoreContext(namespace); + storeContext.withKeySerializer(new DefaultKVSerializer(String.class, String.class)); + storeContext.withConfig(configuration); - @Override - public void addMetric(String key, Object value) { - kvStore.put(key, JSON.toJSONString(value)); - kvStore.flush(); - } + String storeType = configuration.getString(ExecutionConfigKeys.STATS_METRIC_STORE_TYPE); + IStoreBuilder builder = StoreBuilderFactory.build(storeType); + IKVStore kvStore = (IKVStore) builder.getStore(DataModel.KV, configuration); + kvStore.init(storeContext); + LOGGER.info("create stats store with type:{} namespace:{}", storeType, namespace); + return kvStore; + } - @Override - public void close() { - } + @Override + public void addMetric(String key, Object value) { + kvStore.put(key, JSON.toJSONString(value)); + kvStore.flush(); + } + @Override + public void close() {} } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/test/java/org/apache/geaflow/stats/collector/ExceptionCollectorTest.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/test/java/org/apache/geaflow/stats/collector/ExceptionCollectorTest.java index 493908065..144a3138d 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/test/java/org/apache/geaflow/stats/collector/ExceptionCollectorTest.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/test/java/org/apache/geaflow/stats/collector/ExceptionCollectorTest.java @@ -28,16 +28,14 @@ public class ExceptionCollectorTest { - @Test - public void testException() { - Configuration configuration = new Configuration(); - configuration.put(ExecutionConfigKeys.STATS_METRIC_STORE_TYPE, - StatsStoreType.MEMORY.name()); - configuration.put(ExecutionConfigKeys.JOB_UNIQUE_ID, "1"); - IStatsWriter writer = Mockito.mock(IStatsWriter.class); - ExceptionCollector collector = new ExceptionCollector(writer, configuration); - collector.reportException(new RuntimeException()); - Mockito.verify(writer, Mockito.times(1)); - } - + @Test + public void testException() { + Configuration configuration = new Configuration(); + configuration.put(ExecutionConfigKeys.STATS_METRIC_STORE_TYPE, StatsStoreType.MEMORY.name()); + configuration.put(ExecutionConfigKeys.JOB_UNIQUE_ID, "1"); + IStatsWriter writer = Mockito.mock(IStatsWriter.class); + ExceptionCollector collector = new ExceptionCollector(writer, configuration); + collector.reportException(new RuntimeException()); + Mockito.verify(writer, Mockito.times(1)); + } } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/test/java/org/apache/geaflow/stats/model/MetricCacheTest.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/test/java/org/apache/geaflow/stats/model/MetricCacheTest.java index 7c2ae4192..3c568dd53 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/test/java/org/apache/geaflow/stats/model/MetricCacheTest.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/test/java/org/apache/geaflow/stats/model/MetricCacheTest.java @@ -26,33 +26,32 @@ public class MetricCacheTest { - @Test - public void testBoundedHashMap() { - BoundedHashMap hashMap = new BoundedHashMap(3); - hashMap.put(1, 1); - hashMap.put(2, 2); - hashMap.put(3, 3); - Assert.assertEquals(hashMap.size(), 3); - hashMap.put(4, 4); - Assert.assertEquals(hashMap.size(), 3); - Assert.assertFalse(hashMap.containsKey(1)); - } - - @Test - public void testMetricCache() { - MetricCache metricCache = new MetricCache(2); - PipelineMetrics metric1 = new PipelineMetrics("1"); - metric1.setStartTime(1); - PipelineMetrics metric2 = new PipelineMetrics("2"); - metric2.setStartTime(2); - metricCache.addPipelineMetrics(metric1); - metricCache.addPipelineMetrics(metric2); - Assert.assertEquals(metricCache.getPipelineMetricCaches().size(), 2); - PipelineMetrics metric3 = new PipelineMetrics("3"); - metric3.setStartTime(3); - metricCache.addPipelineMetrics(metric3); - Assert.assertEquals(metricCache.getPipelineMetricCaches().size(), 2); - Assert.assertFalse(metricCache.getPipelineMetricCaches().containsKey("1")); - } + @Test + public void testBoundedHashMap() { + BoundedHashMap hashMap = new BoundedHashMap(3); + hashMap.put(1, 1); + hashMap.put(2, 2); + hashMap.put(3, 3); + Assert.assertEquals(hashMap.size(), 3); + hashMap.put(4, 4); + Assert.assertEquals(hashMap.size(), 3); + Assert.assertFalse(hashMap.containsKey(1)); + } + @Test + public void testMetricCache() { + MetricCache metricCache = new MetricCache(2); + PipelineMetrics metric1 = new PipelineMetrics("1"); + metric1.setStartTime(1); + PipelineMetrics metric2 = new PipelineMetrics("2"); + metric2.setStartTime(2); + metricCache.addPipelineMetrics(metric1); + metricCache.addPipelineMetrics(metric2); + Assert.assertEquals(metricCache.getPipelineMetricCaches().size(), 2); + PipelineMetrics metric3 = new PipelineMetrics("3"); + metric3.setStartTime(3); + metricCache.addPipelineMetrics(metric3); + Assert.assertEquals(metricCache.getPipelineMetricCaches().size(), 2); + Assert.assertFalse(metricCache.getPipelineMetricCaches().containsKey("1")); + } } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/test/java/org/apache/geaflow/stats/sink/AsyncKvStoreWriterTest.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/test/java/org/apache/geaflow/stats/sink/AsyncKvStoreWriterTest.java index b9d006ad6..aa7230ee8 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/test/java/org/apache/geaflow/stats/sink/AsyncKvStoreWriterTest.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/test/java/org/apache/geaflow/stats/sink/AsyncKvStoreWriterTest.java @@ -27,16 +27,15 @@ public class AsyncKvStoreWriterTest { - @Test - public void test() { - Configuration configuration = new Configuration(); - configuration.put(ExecutionConfigKeys.STATS_METRIC_STORE_TYPE, StatsStoreType.MEMORY.name()); - AsyncKvStoreWriter writer = new AsyncKvStoreWriter(configuration); - for (int i = 0; i < 1500; i++) { - writer.addMetric(String.valueOf(i), i); - } - SleepUtils.sleepSecond(3); - writer.close(); + @Test + public void test() { + Configuration configuration = new Configuration(); + configuration.put(ExecutionConfigKeys.STATS_METRIC_STORE_TYPE, StatsStoreType.MEMORY.name()); + AsyncKvStoreWriter writer = new AsyncKvStoreWriter(configuration); + for (int i = 0; i < 1500; i++) { + writer.addMetric(String.valueOf(i), i); } - + SleepUtils.sleepSecond(3); + writer.close(); + } } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/test/java/org/apache/geaflow/stats/sink/StatsWriterFactoryTest.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/test/java/org/apache/geaflow/stats/sink/StatsWriterFactoryTest.java index dc68d2f73..65fce0823 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/test/java/org/apache/geaflow/stats/sink/StatsWriterFactoryTest.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/test/java/org/apache/geaflow/stats/sink/StatsWriterFactoryTest.java @@ -27,13 +27,11 @@ public class StatsWriterFactoryTest { - @Test - public void test() { - Configuration configuration = new Configuration(); - configuration.put(ExecutionConfigKeys.STATS_METRIC_STORE_TYPE, - StatsStoreType.MEMORY.name()); - IStatsWriter writer = StatsWriterFactory.getStatsWriter(configuration); - Assert.assertTrue(writer instanceof MemoryStatsWriter); - } - + @Test + public void test() { + Configuration configuration = new Configuration(); + configuration.put(ExecutionConfigKeys.STATS_METRIC_STORE_TYPE, StatsStoreType.MEMORY.name()); + IStatsWriter writer = StatsWriterFactory.getStatsWriter(configuration); + Assert.assertTrue(writer instanceof MemoryStatsWriter); + } } diff --git a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/test/java/org/apache/geaflow/stats/sink/SyncKvStoreWriterTest.java b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/test/java/org/apache/geaflow/stats/sink/SyncKvStoreWriterTest.java index 0a642a40f..92642446c 100644 --- a/geaflow/geaflow-metrics/geaflow-stats-metrics/src/test/java/org/apache/geaflow/stats/sink/SyncKvStoreWriterTest.java +++ b/geaflow/geaflow-metrics/geaflow-stats-metrics/src/test/java/org/apache/geaflow/stats/sink/SyncKvStoreWriterTest.java @@ -27,16 +27,15 @@ public class SyncKvStoreWriterTest { - @Test - public void testSyncStatsWriter() { - Configuration configuration = new Configuration(); - configuration.put(ExecutionConfigKeys.STATS_METRIC_STORE_TYPE, StatsStoreType.MEMORY.name()); - SyncKvStoreWriter writer = new SyncKvStoreWriter(configuration); - for (int i = 0; i < 1500; i++) { - writer.addMetric(String.valueOf(i), i); - } - SleepUtils.sleepSecond(3); - writer.close(); + @Test + public void testSyncStatsWriter() { + Configuration configuration = new Configuration(); + configuration.put(ExecutionConfigKeys.STATS_METRIC_STORE_TYPE, StatsStoreType.MEMORY.name()); + SyncKvStoreWriter writer = new SyncKvStoreWriter(configuration); + for (int i = 0; i < 1500; i++) { + writer.addMetric(String.valueOf(i), i); } - + SleepUtils.sleepSecond(3); + writer.close(); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/common/Null.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/common/Null.java index d7d045855..ef318a6e1 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/common/Null.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/common/Null.java @@ -23,13 +23,13 @@ public class Null implements Serializable { - @Override - public int hashCode() { - return 1; - } + @Override + public int hashCode() { + return 1; + } - @Override - public boolean equals(Object obj) { - return obj instanceof Null; - } + @Override + public boolean equals(Object obj) { + return obj instanceof Null; + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/GraphRecord.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/GraphRecord.java index d7f6f6244..0b734ffb5 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/GraphRecord.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/GraphRecord.java @@ -20,42 +20,42 @@ package org.apache.geaflow.model.graph; import java.io.Serializable; + import org.apache.geaflow.model.graph.vertex.IVertex; public class GraphRecord implements Serializable { - private Object record; - - public GraphRecord() { + private Object record; - } + public GraphRecord() {} - public GraphRecord(Object e) { - this.record = e; - } + public GraphRecord(Object e) { + this.record = e; + } - public ViewType getViewType() { - if (record instanceof IVertex) { - return ViewType.vertex; - } else { - return ViewType.edge; - } + public ViewType getViewType() { + if (record instanceof IVertex) { + return ViewType.vertex; + } else { + return ViewType.edge; } + } - public V getVertex() { - return (V) record; - } + public V getVertex() { + return (V) record; + } - public E getEdge() { - return (E) record; - } + public E getEdge() { + return (E) record; + } - @Override - public String toString() { - return record.toString(); - } + @Override + public String toString() { + return record.toString(); + } - public enum ViewType { - vertex, edge - } + public enum ViewType { + vertex, + edge + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/IGraphElementWithLabelField.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/IGraphElementWithLabelField.java index 62251912a..33d245e38 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/IGraphElementWithLabelField.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/IGraphElementWithLabelField.java @@ -21,18 +21,17 @@ public interface IGraphElementWithLabelField { - /** - * Get label. - * - * @return label - */ - String getLabel(); - - /** - * Set label. - * - * @param label label - */ - void setLabel(String label); + /** + * Get label. + * + * @return label + */ + String getLabel(); + /** + * Set label. + * + * @param label label + */ + void setLabel(String label); } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/IGraphElementWithTimeField.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/IGraphElementWithTimeField.java index d57e66050..2de61bc05 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/IGraphElementWithTimeField.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/IGraphElementWithTimeField.java @@ -21,18 +21,17 @@ public interface IGraphElementWithTimeField { - /** - * Get time. - * - * @return timestamp - */ - long getTime(); - - /** - * Set time. - * - * @param time timestamp - */ - void setTime(long time); + /** + * Get time. + * + * @return timestamp + */ + long getTime(); + /** + * Set time. + * + * @param time timestamp + */ + void setTime(long time); } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/aggregate/DefaultGraphAggMessage.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/aggregate/DefaultGraphAggMessage.java index 650c04f8f..5e4f608f8 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/aggregate/DefaultGraphAggMessage.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/aggregate/DefaultGraphAggMessage.java @@ -21,16 +21,14 @@ public class DefaultGraphAggMessage implements IGraphAggMessage { - private MESSAGE message; + private MESSAGE message; - public DefaultGraphAggMessage(MESSAGE message) { - this.message = message; - } - - - @Override - public MESSAGE getMessage() { - return message; - } + public DefaultGraphAggMessage(MESSAGE message) { + this.message = message; + } + @Override + public MESSAGE getMessage() { + return message; + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/aggregate/IGraphAggMessage.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/aggregate/IGraphAggMessage.java index 5559baf88..c6f3feae6 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/aggregate/IGraphAggMessage.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/aggregate/IGraphAggMessage.java @@ -23,6 +23,5 @@ public interface IGraphAggMessage extends Serializable { - MESSAGE getMessage(); - + MESSAGE getMessage(); } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/EdgeDirection.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/EdgeDirection.java index b83ef562d..6a1c437c1 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/EdgeDirection.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/EdgeDirection.java @@ -20,32 +20,24 @@ package org.apache.geaflow.model.graph.edge; public enum EdgeDirection { - /** - * The in edge. - */ - IN, - /** - * The out edge. - */ - OUT, - /** - * Include In and Out Edges. - */ - BOTH, + /** The in edge. */ + IN, + /** The out edge. */ + OUT, + /** Include In and Out Edges. */ + BOTH, - /** - * None direction edge. - */ - NONE; + /** None direction edge. */ + NONE; - public EdgeDirection reverse() { - switch (this) { - case IN: - return OUT; - case OUT: - return IN; - default: - return this; - } + public EdgeDirection reverse() { + switch (this) { + case IN: + return OUT; + case OUT: + return IN; + default: + return this; } + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/IEdge.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/IEdge.java index da4977a57..a786d0998 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/IEdge.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/IEdge.java @@ -23,67 +23,66 @@ public interface IEdge extends Serializable { - /** - * Get the source id of edge. - * - * @return source id - */ - K getSrcId(); + /** + * Get the source id of edge. + * + * @return source id + */ + K getSrcId(); - /** - * Set the source id for the edge. - * - * @param srcId source id - */ - void setSrcId(K srcId); + /** + * Set the source id for the edge. + * + * @param srcId source id + */ + void setSrcId(K srcId); - /** - * Get the target id of edge. - * - * @return target id - */ - K getTargetId(); + /** + * Get the target id of edge. + * + * @return target id + */ + K getTargetId(); - /** - * Set the target id for the edge. - * - * @param targetId target id - */ - void setTargetId(K targetId); + /** + * Set the target id for the edge. + * + * @param targetId target id + */ + void setTargetId(K targetId); - /** - * Get the direction of edge. - * - * @return direction - */ - EdgeDirection getDirect(); + /** + * Get the direction of edge. + * + * @return direction + */ + EdgeDirection getDirect(); - /** - * Set the direction for the edge. - * - * @param direction direction - */ - void setDirect(EdgeDirection direction); + /** + * Set the direction for the edge. + * + * @param direction direction + */ + void setDirect(EdgeDirection direction); - /** - * Get the value of edge. - * - * @return value - */ - EV getValue(); + /** + * Get the value of edge. + * + * @return value + */ + EV getValue(); - /** - * Reset the value for the edge. - * - * @param value value - */ - IEdge withValue(EV value); - - /** - * Reverse the source id and target id, and return new edge. - * - * @return edge - */ - IEdge reverse(); + /** + * Reset the value for the edge. + * + * @param value value + */ + IEdge withValue(EV value); + /** + * Reverse the source id and target id, and return new edge. + * + * @return edge + */ + IEdge reverse(); } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/IDEdge.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/IDEdge.java index 301a155a5..1c3d715af 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/IDEdge.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/IDEdge.java @@ -20,6 +20,7 @@ package org.apache.geaflow.model.graph.edge.impl; import java.util.Objects; + import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.model.graph.edge.EdgeDirection; @@ -27,85 +28,83 @@ public class IDEdge implements IEdge { - private K srcId; - private K targetId; - private EdgeDirection direction; - - public IDEdge() { - } - - public IDEdge(K srcId, K targetId) { - this(srcId, targetId, EdgeDirection.OUT); - } - - public IDEdge(K srcId, K targetId, EdgeDirection edgeDirection) { - this.srcId = srcId; - this.targetId = targetId; - this.direction = edgeDirection; - } - - @Override - public K getSrcId() { - return this.srcId; - } - - @Override - public void setSrcId(K srcId) { - this.srcId = srcId; - } - - @Override - public K getTargetId() { - return this.targetId; - } - - @Override - public void setTargetId(K targetId) { - this.targetId = targetId; + private K srcId; + private K targetId; + private EdgeDirection direction; + + public IDEdge() {} + + public IDEdge(K srcId, K targetId) { + this(srcId, targetId, EdgeDirection.OUT); + } + + public IDEdge(K srcId, K targetId, EdgeDirection edgeDirection) { + this.srcId = srcId; + this.targetId = targetId; + this.direction = edgeDirection; + } + + @Override + public K getSrcId() { + return this.srcId; + } + + @Override + public void setSrcId(K srcId) { + this.srcId = srcId; + } + + @Override + public K getTargetId() { + return this.targetId; + } + + @Override + public void setTargetId(K targetId) { + this.targetId = targetId; + } + + @Override + public IEdge reverse() { + return new IDEdge<>(this.targetId, this.srcId); + } + + @Override + public Object getValue() { + return null; + } + + @Override + public IEdge withValue(Object value) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + + @Override + public EdgeDirection getDirect() { + return this.direction; + } + + @Override + public void setDirect(EdgeDirection direction) { + this.direction = direction; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public IEdge reverse() { - return new IDEdge<>(this.targetId, this.srcId); - } - - @Override - public Object getValue() { - return null; - } - - @Override - public IEdge withValue(Object value) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } - - @Override - public EdgeDirection getDirect() { - return this.direction; + if (o == null || this.getClass() != o.getClass()) { + return false; } - - @Override - public void setDirect(EdgeDirection direction) { - this.direction = direction; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || this.getClass() != o.getClass()) { - return false; - } - IDEdge that = (IDEdge) o; - return Objects.equals(this.srcId, that.srcId) - && Objects.equals(this.targetId, that.targetId) - && this.direction == that.direction; - } - - @Override - public int hashCode() { - return Objects.hash(this.srcId, this.targetId, this.direction); - } - + IDEdge that = (IDEdge) o; + return Objects.equals(this.srcId, that.srcId) + && Objects.equals(this.targetId, that.targetId) + && this.direction == that.direction; + } + + @Override + public int hashCode() { + return Objects.hash(this.srcId, this.targetId, this.direction); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/IDLabelEdge.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/IDLabelEdge.java index b73fe4cdc..49a796675 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/IDLabelEdge.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/IDLabelEdge.java @@ -20,51 +20,50 @@ package org.apache.geaflow.model.graph.edge.impl; import java.util.Objects; + import org.apache.geaflow.model.graph.IGraphElementWithLabelField; import org.apache.geaflow.model.graph.edge.EdgeDirection; public class IDLabelEdge extends IDEdge implements IGraphElementWithLabelField { - private String label; + private String label; - public IDLabelEdge() { - } + public IDLabelEdge() {} - public IDLabelEdge(K src, K target) { - this(src, target, null); - } + public IDLabelEdge(K src, K target) { + this(src, target, null); + } - public IDLabelEdge(K src, K target, String label) { - this(src, target, EdgeDirection.OUT, label); - } + public IDLabelEdge(K src, K target, String label) { + this(src, target, EdgeDirection.OUT, label); + } - public IDLabelEdge(K srcId, K targetId, EdgeDirection edgeDirection, String label) { - super(srcId, targetId, edgeDirection); - this.label = label; - } + public IDLabelEdge(K srcId, K targetId, EdgeDirection edgeDirection, String label) { + super(srcId, targetId, edgeDirection); + this.label = label; + } - @Override - public String getLabel() { - return this.label; - } - - @Override - public void setLabel(String label) { - this.label = label; - } + @Override + public String getLabel() { + return this.label; + } - @Override - public boolean equals(Object o) { - if (!super.equals(o)) { - return false; - } - IDLabelEdge that = (IDLabelEdge) o; - return Objects.equals(this.label, that.label); - } + @Override + public void setLabel(String label) { + this.label = label; + } - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), this.label); + @Override + public boolean equals(Object o) { + if (!super.equals(o)) { + return false; } + IDLabelEdge that = (IDLabelEdge) o; + return Objects.equals(this.label, that.label); + } + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), this.label); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/IDLabelTimeEdge.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/IDLabelTimeEdge.java index 65fc4a9b7..40d8a58b4 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/IDLabelTimeEdge.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/IDLabelTimeEdge.java @@ -20,6 +20,7 @@ package org.apache.geaflow.model.graph.edge.impl; import java.util.Objects; + import org.apache.geaflow.model.graph.IGraphElementWithLabelField; import org.apache.geaflow.model.graph.IGraphElementWithTimeField; import org.apache.geaflow.model.graph.edge.EdgeDirection; @@ -27,58 +28,57 @@ public class IDLabelTimeEdge extends IDEdge implements IGraphElementWithLabelField, IGraphElementWithTimeField { - private String label; - private long time; + private String label; + private long time; - public IDLabelTimeEdge() { - } + public IDLabelTimeEdge() {} - public IDLabelTimeEdge(K src, K target) { - this(src, target, null, 0); - } + public IDLabelTimeEdge(K src, K target) { + this(src, target, null, 0); + } - public IDLabelTimeEdge(K src, K target, String label, long time) { - this(src, target, EdgeDirection.OUT, label, time); - } + public IDLabelTimeEdge(K src, K target, String label, long time) { + this(src, target, EdgeDirection.OUT, label, time); + } - public IDLabelTimeEdge(K srcId, K targetId, EdgeDirection edgeDirection, String label, long time) { - super(srcId, targetId, edgeDirection); - this.label = label; - this.time = time; - } + public IDLabelTimeEdge( + K srcId, K targetId, EdgeDirection edgeDirection, String label, long time) { + super(srcId, targetId, edgeDirection); + this.label = label; + this.time = time; + } - @Override - public String getLabel() { - return this.label; - } + @Override + public String getLabel() { + return this.label; + } - @Override - public void setLabel(String label) { - this.label = label; - } + @Override + public void setLabel(String label) { + this.label = label; + } - @Override - public long getTime() { - return this.time; - } + @Override + public long getTime() { + return this.time; + } - @Override - public void setTime(long time) { - this.time = time; - } - - @Override - public boolean equals(Object o) { - if (!super.equals(o)) { - return false; - } - IDLabelTimeEdge that = (IDLabelTimeEdge) o; - return this.time == that.time && Objects.equals(this.label, that.label); - } + @Override + public void setTime(long time) { + this.time = time; + } - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), this.label, this.time); + @Override + public boolean equals(Object o) { + if (!super.equals(o)) { + return false; } + IDLabelTimeEdge that = (IDLabelTimeEdge) o; + return this.time == that.time && Objects.equals(this.label, that.label); + } + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), this.label, this.time); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/IDTimeEdge.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/IDTimeEdge.java index d9e08299e..3c902a3fd 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/IDTimeEdge.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/IDTimeEdge.java @@ -20,51 +20,50 @@ package org.apache.geaflow.model.graph.edge.impl; import java.util.Objects; + import org.apache.geaflow.model.graph.IGraphElementWithTimeField; import org.apache.geaflow.model.graph.edge.EdgeDirection; public class IDTimeEdge extends IDEdge implements IGraphElementWithTimeField { - private long time; + private long time; - public IDTimeEdge() { - } + public IDTimeEdge() {} - public IDTimeEdge(K src, K target) { - this(src, target, 0); - } + public IDTimeEdge(K src, K target) { + this(src, target, 0); + } - public IDTimeEdge(K src, K target, long time) { - this(src, target, EdgeDirection.OUT, time); - } + public IDTimeEdge(K src, K target, long time) { + this(src, target, EdgeDirection.OUT, time); + } - public IDTimeEdge(K src, K target, EdgeDirection edgeDirection, long time) { - super(src, target, edgeDirection); - this.time = time; - } + public IDTimeEdge(K src, K target, EdgeDirection edgeDirection, long time) { + super(src, target, edgeDirection); + this.time = time; + } - @Override - public long getTime() { - return this.time; - } - - @Override - public void setTime(long time) { - this.time = time; - } + @Override + public long getTime() { + return this.time; + } - @Override - public boolean equals(Object o) { - if (!super.equals(o)) { - return false; - } - IDTimeEdge that = (IDTimeEdge) o; - return this.time == that.time; - } + @Override + public void setTime(long time) { + this.time = time; + } - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), this.time); + @Override + public boolean equals(Object o) { + if (!super.equals(o)) { + return false; } + IDTimeEdge that = (IDTimeEdge) o; + return this.time == that.time; + } + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), this.time); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/ValueEdge.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/ValueEdge.java index 425914c89..22df865b4 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/ValueEdge.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/ValueEdge.java @@ -20,102 +20,110 @@ package org.apache.geaflow.model.graph.edge.impl; import java.util.Objects; + import org.apache.geaflow.model.graph.edge.EdgeDirection; import org.apache.geaflow.model.graph.edge.IEdge; public class ValueEdge implements IEdge { - private K srcId; - private K targetId; - private EdgeDirection direction; - private EV value; - - public ValueEdge() { - } - - public ValueEdge(K srcId, K targetId) { - this(srcId, targetId, null); - } - - public ValueEdge(K srcId, K targetId, EV value) { - this(srcId, targetId, value, EdgeDirection.OUT); - } - - public ValueEdge(K srcId, K targetId, EV value, EdgeDirection edgeDirection) { - this.srcId = srcId; - this.targetId = targetId; - this.direction = edgeDirection; - this.value = value; - } - - @Override - public K getSrcId() { - return this.srcId; - } - - @Override - public void setSrcId(K srcId) { - this.srcId = srcId; - } - - @Override - public K getTargetId() { - return this.targetId; - } - - @Override - public void setTargetId(K targetId) { - this.targetId = targetId; - } - - @Override - public IEdge reverse() { - return new ValueEdge<>(this.targetId, this.srcId, this.value); - } - - @Override - public EV getValue() { - return this.value; - } - - @Override - public EdgeDirection getDirect() { - return this.direction; - } - - @Override - public IEdge withValue(EV value) { - this.value = value; - return this; - } - - @Override - public void setDirect(EdgeDirection direction) { - this.direction = direction; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || this.getClass() != o.getClass()) { - return false; - } - ValueEdge that = (ValueEdge) o; - return Objects.equals(this.srcId, that.srcId) - && Objects.equals(this.targetId, that.targetId) - && this.direction == that.direction; - } - - @Override - public int hashCode() { - return Objects.hash(this.srcId, this.targetId, this.direction); - } - - @Override - public String toString() { - return "ValueEdge{" + "srcId=" + srcId + ", targetId=" + targetId + ", direction=" - + direction + ", value=" + value + '}'; - } + private K srcId; + private K targetId; + private EdgeDirection direction; + private EV value; + + public ValueEdge() {} + + public ValueEdge(K srcId, K targetId) { + this(srcId, targetId, null); + } + + public ValueEdge(K srcId, K targetId, EV value) { + this(srcId, targetId, value, EdgeDirection.OUT); + } + + public ValueEdge(K srcId, K targetId, EV value, EdgeDirection edgeDirection) { + this.srcId = srcId; + this.targetId = targetId; + this.direction = edgeDirection; + this.value = value; + } + + @Override + public K getSrcId() { + return this.srcId; + } + + @Override + public void setSrcId(K srcId) { + this.srcId = srcId; + } + + @Override + public K getTargetId() { + return this.targetId; + } + + @Override + public void setTargetId(K targetId) { + this.targetId = targetId; + } + + @Override + public IEdge reverse() { + return new ValueEdge<>(this.targetId, this.srcId, this.value); + } + + @Override + public EV getValue() { + return this.value; + } + + @Override + public EdgeDirection getDirect() { + return this.direction; + } + + @Override + public IEdge withValue(EV value) { + this.value = value; + return this; + } + + @Override + public void setDirect(EdgeDirection direction) { + this.direction = direction; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || this.getClass() != o.getClass()) { + return false; + } + ValueEdge that = (ValueEdge) o; + return Objects.equals(this.srcId, that.srcId) + && Objects.equals(this.targetId, that.targetId) + && this.direction == that.direction; + } + + @Override + public int hashCode() { + return Objects.hash(this.srcId, this.targetId, this.direction); + } + + @Override + public String toString() { + return "ValueEdge{" + + "srcId=" + + srcId + + ", targetId=" + + targetId + + ", direction=" + + direction + + ", value=" + + value + + '}'; + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/ValueLabelEdge.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/ValueLabelEdge.java index e69cb1dd6..5d0548b74 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/ValueLabelEdge.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/ValueLabelEdge.java @@ -20,51 +20,50 @@ package org.apache.geaflow.model.graph.edge.impl; import java.util.Objects; + import org.apache.geaflow.model.graph.IGraphElementWithLabelField; import org.apache.geaflow.model.graph.edge.EdgeDirection; public class ValueLabelEdge extends ValueEdge implements IGraphElementWithLabelField { - private String label; + private String label; - public ValueLabelEdge() { - } + public ValueLabelEdge() {} - public ValueLabelEdge(K src, K target, EV value) { - this(src, target, value, null); - } + public ValueLabelEdge(K src, K target, EV value) { + this(src, target, value, null); + } - public ValueLabelEdge(K src, K target, EV value, String label) { - this(src, target, value, EdgeDirection.OUT, label); - } + public ValueLabelEdge(K src, K target, EV value, String label) { + this(src, target, value, EdgeDirection.OUT, label); + } - public ValueLabelEdge(K srcId, K targetId, EV value, EdgeDirection edgeDirection, String label) { - super(srcId, targetId, value, edgeDirection); - this.label = label; - } + public ValueLabelEdge(K srcId, K targetId, EV value, EdgeDirection edgeDirection, String label) { + super(srcId, targetId, value, edgeDirection); + this.label = label; + } - @Override - public String getLabel() { - return this.label; - } - - @Override - public void setLabel(String label) { - this.label = label; - } + @Override + public String getLabel() { + return this.label; + } - @Override - public boolean equals(Object o) { - if (!super.equals(o)) { - return false; - } - ValueLabelEdge that = (ValueLabelEdge) o; - return Objects.equals(this.label, that.label); - } + @Override + public void setLabel(String label) { + this.label = label; + } - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), this.label); + @Override + public boolean equals(Object o) { + if (!super.equals(o)) { + return false; } + ValueLabelEdge that = (ValueLabelEdge) o; + return Objects.equals(this.label, that.label); + } + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), this.label); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/ValueLabelTimeEdge.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/ValueLabelTimeEdge.java index 7c0ede5be..e96534d39 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/ValueLabelTimeEdge.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/ValueLabelTimeEdge.java @@ -20,6 +20,7 @@ package org.apache.geaflow.model.graph.edge.impl; import java.util.Objects; + import org.apache.geaflow.model.graph.IGraphElementWithLabelField; import org.apache.geaflow.model.graph.IGraphElementWithTimeField; import org.apache.geaflow.model.graph.edge.EdgeDirection; @@ -27,58 +28,57 @@ public class ValueLabelTimeEdge extends ValueEdge implements IGraphElementWithLabelField, IGraphElementWithTimeField { - private String label; - private long time; + private String label; + private long time; - public ValueLabelTimeEdge() { - } + public ValueLabelTimeEdge() {} - public ValueLabelTimeEdge(K src, K target, EV value) { - this(src, target, value, null, 0); - } + public ValueLabelTimeEdge(K src, K target, EV value) { + this(src, target, value, null, 0); + } - public ValueLabelTimeEdge(K src, K target, EV value, String label, long time) { - this(src, target, value, EdgeDirection.OUT, label, time); - } + public ValueLabelTimeEdge(K src, K target, EV value, String label, long time) { + this(src, target, value, EdgeDirection.OUT, label, time); + } - public ValueLabelTimeEdge(K src, K target, EV value, EdgeDirection edgeDirection, String label, long time) { - super(src, target, value, edgeDirection); - this.label = label; - this.time = time; - } + public ValueLabelTimeEdge( + K src, K target, EV value, EdgeDirection edgeDirection, String label, long time) { + super(src, target, value, edgeDirection); + this.label = label; + this.time = time; + } - @Override - public String getLabel() { - return this.label; - } + @Override + public String getLabel() { + return this.label; + } - @Override - public void setLabel(String label) { - this.label = label; - } + @Override + public void setLabel(String label) { + this.label = label; + } - @Override - public long getTime() { - return this.time; - } + @Override + public long getTime() { + return this.time; + } - @Override - public void setTime(long time) { - this.time = time; - } - - @Override - public boolean equals(Object o) { - if (!super.equals(o)) { - return false; - } - ValueLabelTimeEdge that = (ValueLabelTimeEdge) o; - return this.time == that.time && Objects.equals(this.label, that.label); - } + @Override + public void setTime(long time) { + this.time = time; + } - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), this.label, this.time); + @Override + public boolean equals(Object o) { + if (!super.equals(o)) { + return false; } + ValueLabelTimeEdge that = (ValueLabelTimeEdge) o; + return this.time == that.time && Objects.equals(this.label, that.label); + } + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), this.label, this.time); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/ValueTimeEdge.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/ValueTimeEdge.java index 9c3bba5a3..2f5ed7d9e 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/ValueTimeEdge.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/edge/impl/ValueTimeEdge.java @@ -20,51 +20,50 @@ package org.apache.geaflow.model.graph.edge.impl; import java.util.Objects; + import org.apache.geaflow.model.graph.IGraphElementWithTimeField; import org.apache.geaflow.model.graph.edge.EdgeDirection; public class ValueTimeEdge extends ValueEdge implements IGraphElementWithTimeField { - private long time; + private long time; - public ValueTimeEdge() { - } + public ValueTimeEdge() {} - public ValueTimeEdge(K src, K target, EV value) { - this(src, target, value, 0); - } + public ValueTimeEdge(K src, K target, EV value) { + this(src, target, value, 0); + } - public ValueTimeEdge(K src, K target, EV value, long time) { - this(src, target, value, EdgeDirection.OUT, time); - } + public ValueTimeEdge(K src, K target, EV value, long time) { + this(src, target, value, EdgeDirection.OUT, time); + } - public ValueTimeEdge(K src, K target, EV value, EdgeDirection edgeDirection, long time) { - super(src, target, value, edgeDirection); - this.time = time; - } + public ValueTimeEdge(K src, K target, EV value, EdgeDirection edgeDirection, long time) { + super(src, target, value, edgeDirection); + this.time = time; + } - @Override - public long getTime() { - return this.time; - } - - @Override - public void setTime(long time) { - this.time = time; - } + @Override + public long getTime() { + return this.time; + } - @Override - public boolean equals(Object o) { - if (!super.equals(o)) { - return false; - } - ValueTimeEdge that = (ValueTimeEdge) o; - return this.time == that.time; - } + @Override + public void setTime(long time) { + this.time = time; + } - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), this.time); + @Override + public boolean equals(Object o) { + if (!super.equals(o)) { + return false; } + ValueTimeEdge that = (ValueTimeEdge) o; + return this.time == that.time; + } + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), this.time); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/DefaultGraphMessage.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/DefaultGraphMessage.java index 58aa96e74..10129ed7e 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/DefaultGraphMessage.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/DefaultGraphMessage.java @@ -23,41 +23,40 @@ public class DefaultGraphMessage implements IGraphMessage { - private final K targetVId; - private MESSAGE message; - - public DefaultGraphMessage(K targetVId, MESSAGE message) { - this.targetVId = targetVId; - this.message = message; - } - - public MESSAGE getMessage() { - return this.message; - } - - @Override - public K getTargetVId() { - return this.targetVId; - } - - @Override - public boolean hasNext() { - return this.message != null; - } - - @Override - public MESSAGE next() { - if (this.message == null) { - throw new NoSuchElementException(); - } - MESSAGE msg = this.message; - this.message = null; - return msg; + private final K targetVId; + private MESSAGE message; + + public DefaultGraphMessage(K targetVId, MESSAGE message) { + this.targetVId = targetVId; + this.message = message; + } + + public MESSAGE getMessage() { + return this.message; + } + + @Override + public K getTargetVId() { + return this.targetVId; + } + + @Override + public boolean hasNext() { + return this.message != null; + } + + @Override + public MESSAGE next() { + if (this.message == null) { + throw new NoSuchElementException(); } - - @Override - public String toString() { - return "DefaultGraphMessage{" + "targetVId=" + targetVId + ", message=" + message + '}'; - } - + MESSAGE msg = this.message; + this.message = null; + return msg; + } + + @Override + public String toString() { + return "DefaultGraphMessage{" + "targetVId=" + targetVId + ", message=" + message + '}'; + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/IGraphMessage.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/IGraphMessage.java index 9c4857570..f9e103c2d 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/IGraphMessage.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/IGraphMessage.java @@ -24,6 +24,5 @@ public interface IGraphMessage extends Iterator, Serializable { - K getTargetVId(); - + K getTargetVId(); } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/ListGraphMessage.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/ListGraphMessage.java index 621ba5a6a..8f0435e7c 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/ListGraphMessage.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/ListGraphMessage.java @@ -23,42 +23,41 @@ public class ListGraphMessage implements IGraphMessage { - private final K targetVId; - private final List messages; - private final int size; - private int idx; - - public ListGraphMessage(K targetVId, List messages) { - this.targetVId = targetVId; - this.messages = messages; - this.size = this.messages.size(); - this.idx = 0; - } - - public List getMessages() { - return this.messages; - } - - @Override - public K getTargetVId() { - return this.targetVId; - } - - @Override - public boolean hasNext() { - return this.idx < this.size; - } - - @Override - public MESSAGE next() { - MESSAGE message = this.messages.get(this.idx); - this.idx++; - return message; - } - - @Override - public String toString() { - return "ListGraphMessage{" + "targetVId=" + targetVId + ", messages=" + messages + '}'; - } - + private final K targetVId; + private final List messages; + private final int size; + private int idx; + + public ListGraphMessage(K targetVId, List messages) { + this.targetVId = targetVId; + this.messages = messages; + this.size = this.messages.size(); + this.idx = 0; + } + + public List getMessages() { + return this.messages; + } + + @Override + public K getTargetVId() { + return this.targetVId; + } + + @Override + public boolean hasNext() { + return this.idx < this.size; + } + + @Override + public MESSAGE next() { + MESSAGE message = this.messages.get(this.idx); + this.idx++; + return message; + } + + @Override + public String toString() { + return "ListGraphMessage{" + "targetVId=" + targetVId + ", messages=" + messages + '}'; + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/encoder/AbstractGraphMessageEncoder.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/encoder/AbstractGraphMessageEncoder.java index d0fa95d9c..7b64432fd 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/encoder/AbstractGraphMessageEncoder.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/encoder/AbstractGraphMessageEncoder.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.io.OutputStream; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.encoder.IEncoder; import org.apache.geaflow.common.encoder.impl.AbstractEncoder; @@ -31,29 +32,28 @@ public abstract class AbstractGraphMessageEncoder> extends AbstractEncoder { - protected final IEncoder keyEncoder; - protected final IEncoder msgEncoder; - - public AbstractGraphMessageEncoder(IEncoder keyEncoder, IEncoder msgEncoder) { - this.keyEncoder = keyEncoder; - this.msgEncoder = msgEncoder; - } - - @Override - public void init(Configuration config) { - this.keyEncoder.init(config); - this.msgEncoder.init(config); + protected final IEncoder keyEncoder; + protected final IEncoder msgEncoder; + + public AbstractGraphMessageEncoder(IEncoder keyEncoder, IEncoder msgEncoder) { + this.keyEncoder = keyEncoder; + this.msgEncoder = msgEncoder; + } + + @Override + public void init(Configuration config) { + this.keyEncoder.init(config); + this.msgEncoder.init(config); + } + + @Override + public void encode(GRAPHMESSAGE data, OutputStream outputStream) throws IOException { + if (data == null) { + throw new GeaflowRuntimeException( + RuntimeErrors.INST.shuffleSerializeError("graph message can not be null")); } + this.doEncode(data, outputStream); + } - @Override - public void encode(GRAPHMESSAGE data, OutputStream outputStream) throws IOException { - if (data == null) { - throw new GeaflowRuntimeException( - RuntimeErrors.INST.shuffleSerializeError("graph message can not be null")); - } - this.doEncode(data, outputStream); - } - - public abstract void doEncode(GRAPHMESSAGE data, OutputStream outputStream) throws IOException; - + public abstract void doEncode(GRAPHMESSAGE data, OutputStream outputStream) throws IOException; } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/encoder/DefaultGraphMessageEncoder.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/encoder/DefaultGraphMessageEncoder.java index faeb34986..784292e35 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/encoder/DefaultGraphMessageEncoder.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/encoder/DefaultGraphMessageEncoder.java @@ -22,26 +22,28 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; + import org.apache.geaflow.common.encoder.IEncoder; import org.apache.geaflow.model.graph.message.DefaultGraphMessage; -public class DefaultGraphMessageEncoder extends AbstractGraphMessageEncoder> { - - public DefaultGraphMessageEncoder(IEncoder keyEncoder, IEncoder msgEncoder) { - super(keyEncoder, msgEncoder); - } +public class DefaultGraphMessageEncoder + extends AbstractGraphMessageEncoder> { - @Override - public void doEncode(DefaultGraphMessage data, OutputStream outputStream) throws IOException { - this.keyEncoder.encode(data.getTargetVId(), outputStream); - this.msgEncoder.encode(data.getMessage(), outputStream); - } + public DefaultGraphMessageEncoder(IEncoder keyEncoder, IEncoder msgEncoder) { + super(keyEncoder, msgEncoder); + } - @Override - public DefaultGraphMessage decode(InputStream inputStream) throws IOException { - K key = this.keyEncoder.decode(inputStream); - M msg = this.msgEncoder.decode(inputStream); - return new DefaultGraphMessage<>(key, msg); - } + @Override + public void doEncode(DefaultGraphMessage data, OutputStream outputStream) + throws IOException { + this.keyEncoder.encode(data.getTargetVId(), outputStream); + this.msgEncoder.encode(data.getMessage(), outputStream); + } + @Override + public DefaultGraphMessage decode(InputStream inputStream) throws IOException { + K key = this.keyEncoder.decode(inputStream); + M msg = this.msgEncoder.decode(inputStream); + return new DefaultGraphMessage<>(key, msg); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/encoder/GraphMessageEncoders.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/encoder/GraphMessageEncoders.java index 52a2d4bd6..f7772f65f 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/encoder/GraphMessageEncoders.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/encoder/GraphMessageEncoders.java @@ -24,12 +24,11 @@ public class GraphMessageEncoders { - public static IEncoder> build(IEncoder keyEncoder, - IEncoder msgEncoder) { - if (keyEncoder == null || msgEncoder == null) { - return null; - } - return new DefaultGraphMessageEncoder<>(keyEncoder, msgEncoder); + public static IEncoder> build( + IEncoder keyEncoder, IEncoder msgEncoder) { + if (keyEncoder == null || msgEncoder == null) { + return null; } - + return new DefaultGraphMessageEncoder<>(keyEncoder, msgEncoder); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/encoder/ListGraphMessageEncoder.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/encoder/ListGraphMessageEncoder.java index 6e720084f..2fde5b03d 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/encoder/ListGraphMessageEncoder.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/message/encoder/ListGraphMessageEncoder.java @@ -24,36 +24,37 @@ import java.io.OutputStream; import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.common.encoder.Encoders; import org.apache.geaflow.common.encoder.IEncoder; import org.apache.geaflow.model.graph.message.ListGraphMessage; -public class ListGraphMessageEncoder extends AbstractGraphMessageEncoder> { +public class ListGraphMessageEncoder + extends AbstractGraphMessageEncoder> { - public ListGraphMessageEncoder(IEncoder keyEncoder, IEncoder msgEncoder) { - super(keyEncoder, msgEncoder); - } + public ListGraphMessageEncoder(IEncoder keyEncoder, IEncoder msgEncoder) { + super(keyEncoder, msgEncoder); + } - @Override - public void doEncode(ListGraphMessage data, OutputStream outputStream) throws IOException { - this.keyEncoder.encode(data.getTargetVId(), outputStream); - List messages = data.getMessages(); - Encoders.INTEGER.encode(messages.size(), outputStream); - for (M msg : messages) { - this.msgEncoder.encode(msg, outputStream); - } + @Override + public void doEncode(ListGraphMessage data, OutputStream outputStream) throws IOException { + this.keyEncoder.encode(data.getTargetVId(), outputStream); + List messages = data.getMessages(); + Encoders.INTEGER.encode(messages.size(), outputStream); + for (M msg : messages) { + this.msgEncoder.encode(msg, outputStream); } + } - @Override - public ListGraphMessage decode(InputStream inputStream) throws IOException { - K vid = this.keyEncoder.decode(inputStream); - int size = Encoders.INTEGER.decode(inputStream); - List messages = new ArrayList<>(size); - for (int i = 0; i < size; i++) { - M msg = this.msgEncoder.decode(inputStream); - messages.add(msg); - } - return new ListGraphMessage<>(vid, messages); + @Override + public ListGraphMessage decode(InputStream inputStream) throws IOException { + K vid = this.keyEncoder.decode(inputStream); + int size = Encoders.INTEGER.decode(inputStream); + List messages = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + M msg = this.msgEncoder.decode(inputStream); + messages.add(msg); } - + return new ListGraphMessage<>(vid, messages); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/AbstractGraphElementMeta.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/AbstractGraphElementMeta.java index 1b245fd30..1226d6933 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/AbstractGraphElementMeta.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/AbstractGraphElementMeta.java @@ -20,52 +20,53 @@ package org.apache.geaflow.model.graph.meta; import java.util.function.Supplier; + import org.apache.geaflow.common.schema.ISchema; import org.apache.geaflow.common.type.IType; public abstract class AbstractGraphElementMeta implements IGraphElementMeta { - private final byte graphElementId; - private final Class elementClass; - private final Supplier elementConstruct; - private final Class propertyClass; - private final ISchema primitiveSchema; - - public AbstractGraphElementMeta(byte graphElementId, - IType keyType, - Class elementClass, - Supplier elementConstruct, - Class propertyClass) { - this.graphElementId = graphElementId; - this.elementClass = elementClass; - this.elementConstruct = elementConstruct; - this.propertyClass = propertyClass; - this.primitiveSchema = GraphElementSchemaFactory.newSchema(keyType, elementClass); - } + private final byte graphElementId; + private final Class elementClass; + private final Supplier elementConstruct; + private final Class propertyClass; + private final ISchema primitiveSchema; - @Override - public byte getGraphElementId() { - return this.graphElementId; - } + public AbstractGraphElementMeta( + byte graphElementId, + IType keyType, + Class elementClass, + Supplier elementConstruct, + Class propertyClass) { + this.graphElementId = graphElementId; + this.elementClass = elementClass; + this.elementConstruct = elementConstruct; + this.propertyClass = propertyClass; + this.primitiveSchema = GraphElementSchemaFactory.newSchema(keyType, elementClass); + } - @Override - public Class getGraphElementClass() { - return this.elementClass; - } + @Override + public byte getGraphElementId() { + return this.graphElementId; + } - @Override - public Supplier getGraphElementConstruct() { - return elementConstruct; - } + @Override + public Class getGraphElementClass() { + return this.elementClass; + } - @Override - public ISchema getGraphMeta() { - return this.primitiveSchema; - } + @Override + public Supplier getGraphElementConstruct() { + return elementConstruct; + } - @Override - public Class getPropertyClass() { - return this.propertyClass; - } + @Override + public ISchema getGraphMeta() { + return this.primitiveSchema; + } + @Override + public Class getPropertyClass() { + return this.propertyClass; + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphElementMetas.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphElementMetas.java index 0ee0422df..3372f7cab 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphElementMetas.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphElementMetas.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.Map; import java.util.function.Supplier; + import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.type.IType; @@ -33,154 +34,159 @@ @SuppressWarnings("ALL") public class GraphElementMetas { - private static final ThreadLocal, IGraphElementMeta>> META_CACHE = - ThreadLocal.withInitial(HashMap::new); - - public static IGraphElementMeta getMeta(IType keyType, Class elementClass, Supplier elementConstruct, - Class propertyClass) { - Map, IGraphElementMeta> map = META_CACHE.get(); - return map.computeIfAbsent(elementClass, - k -> newGraphElementMeta(keyType, elementClass, elementConstruct, propertyClass)); - } - - public static void clearCache() { - META_CACHE.remove(); + private static final ThreadLocal, IGraphElementMeta>> META_CACHE = + ThreadLocal.withInitial(HashMap::new); + + public static IGraphElementMeta getMeta( + IType keyType, Class elementClass, Supplier elementConstruct, Class propertyClass) { + Map, IGraphElementMeta> map = META_CACHE.get(); + return map.computeIfAbsent( + elementClass, + k -> newGraphElementMeta(keyType, elementClass, elementConstruct, propertyClass)); + } + + public static void clearCache() { + META_CACHE.remove(); + } + + private static IGraphElementMeta newGraphElementMeta( + IType keyType, Class elementClass, Supplier elementConstruct, Class propertyClass) { + GraphElementFlag flag = GraphElementFlag.build(elementClass); + byte graphElementId = flag.toGraphElementId(); + if (flag.isLabeled()) { + return new LabelElementMeta( + graphElementId, keyType, elementClass, elementConstruct, propertyClass); + } else if (flag.isTimed()) { + return new TimeElementMeta( + graphElementId, keyType, elementClass, elementConstruct, propertyClass); + } else if (flag.isLabeledAndTimed()) { + return new LabelTimeElementMeta( + graphElementId, keyType, elementClass, elementConstruct, propertyClass); + } else { + return new IdElementMeta( + graphElementId, keyType, elementClass, elementConstruct, propertyClass); } + } - private static IGraphElementMeta newGraphElementMeta(IType keyType, - Class elementClass, - Supplier elementConstruct, - Class propertyClass) { - GraphElementFlag flag = GraphElementFlag.build(elementClass); - byte graphElementId = flag.toGraphElementId(); - if (flag.isLabeled()) { - return new LabelElementMeta(graphElementId, keyType, elementClass, elementConstruct, propertyClass); - } else if (flag.isTimed()) { - return new TimeElementMeta(graphElementId, keyType, elementClass, elementConstruct, propertyClass); - } else if (flag.isLabeledAndTimed()) { - return new LabelTimeElementMeta(graphElementId, keyType, elementClass, elementConstruct, propertyClass); - } else { - return new IdElementMeta(graphElementId, keyType, elementClass, elementConstruct, propertyClass); - } - } - - public static class GraphElementFlag { - - private static final byte MASK_LABEL = 1 << 0; - private static final byte MASK_TIME = 1 << 1; + public static class GraphElementFlag { - private byte flag = 0; + private static final byte MASK_LABEL = 1 << 0; + private static final byte MASK_TIME = 1 << 1; - public void markLabeled() { - this.flag |= MASK_LABEL; - } - - public void markTimed() { - this.flag |= MASK_TIME; - } - - public boolean isLabeled() { - return (this.flag & MASK_LABEL) > 0 && (this.flag & MASK_TIME) == 0; - } - - public boolean isTimed() { - return (this.flag & MASK_LABEL) == 0 && (this.flag & MASK_TIME) > 0; - } - - public boolean isLabeledAndTimed() { - return (this.flag & MASK_LABEL) > 0 && (this.flag & MASK_TIME) > 0; - } - - public byte toGraphElementId() { - return this.flag; - } - - public static GraphElementFlag build(Class elementClass) { - if (!IVertex.class.isAssignableFrom(elementClass) && !IEdge.class.isAssignableFrom(elementClass)) { - String msg = "unrecognized graph element class: " + elementClass.getCanonicalName(); - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(msg)); - } - GraphElementFlag tag = new GraphElementFlag(); - if (IGraphElementWithLabelField.class.isAssignableFrom(elementClass)) { - tag.markLabeled(); - } - if (IGraphElementWithTimeField.class.isAssignableFrom(elementClass)) { - tag.markTimed(); - } - return tag; - } + private byte flag = 0; + public void markLabeled() { + this.flag |= MASK_LABEL; } - public static class IdElementMeta extends AbstractGraphElementMeta { - - public IdElementMeta(byte graphElementId, - IType keyType, - Class elementClass, - Supplier elementConstruct, - Class propertyClass) { - super(graphElementId, keyType, elementClass, elementConstruct, propertyClass); - } - - @Override - public IGraphFieldSerializer getGraphFieldSerializer() { - return GraphFieldSerializers.IdFieldSerializer.INSTANCE; - } - + public void markTimed() { + this.flag |= MASK_TIME; } - public static class LabelElementMeta - extends AbstractGraphElementMeta { - - public LabelElementMeta(byte graphElementId, - IType keyType, - Class elementClass, - Supplier elementConstruct, - Class propertyClass) { - super(graphElementId, keyType, elementClass, elementConstruct, propertyClass); - } + public boolean isLabeled() { + return (this.flag & MASK_LABEL) > 0 && (this.flag & MASK_TIME) == 0; + } - @Override - public IGraphFieldSerializer getGraphFieldSerializer() { - return GraphFieldSerializers.LabelFieldSerializer.INSTANCE; - } + public boolean isTimed() { + return (this.flag & MASK_LABEL) == 0 && (this.flag & MASK_TIME) > 0; + } + public boolean isLabeledAndTimed() { + return (this.flag & MASK_LABEL) > 0 && (this.flag & MASK_TIME) > 0; } - public static class TimeElementMeta - extends AbstractGraphElementMeta { + public byte toGraphElementId() { + return this.flag; + } - public TimeElementMeta(byte graphElementId, - IType keyType, - Class elementClass, - Supplier elementConstruct, - Class propertyClass) { - super(graphElementId, keyType, (Class) elementClass, elementConstruct, propertyClass); - } + public static GraphElementFlag build(Class elementClass) { + if (!IVertex.class.isAssignableFrom(elementClass) + && !IEdge.class.isAssignableFrom(elementClass)) { + String msg = "unrecognized graph element class: " + elementClass.getCanonicalName(); + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(msg)); + } + GraphElementFlag tag = new GraphElementFlag(); + if (IGraphElementWithLabelField.class.isAssignableFrom(elementClass)) { + tag.markLabeled(); + } + if (IGraphElementWithTimeField.class.isAssignableFrom(elementClass)) { + tag.markTimed(); + } + return tag; + } + } - @Override - public IGraphFieldSerializer getGraphFieldSerializer() { - return GraphFieldSerializers.TimeFieldSerializer.INSTANCE; - } + public static class IdElementMeta extends AbstractGraphElementMeta { + public IdElementMeta( + byte graphElementId, + IType keyType, + Class elementClass, + Supplier elementConstruct, + Class propertyClass) { + super(graphElementId, keyType, elementClass, elementConstruct, propertyClass); } - public static class LabelTimeElementMeta - extends AbstractGraphElementMeta { - - public LabelTimeElementMeta(byte graphElementId, - IType keyType, - Class elementClass, - Supplier elementConstruct, - Class propertyClass) { - super(graphElementId, keyType, (Class) elementClass, elementConstruct, propertyClass); - } + @Override + public IGraphFieldSerializer getGraphFieldSerializer() { + return GraphFieldSerializers.IdFieldSerializer.INSTANCE; + } + } + + public static class LabelElementMeta + extends AbstractGraphElementMeta { + + public LabelElementMeta( + byte graphElementId, + IType keyType, + Class elementClass, + Supplier elementConstruct, + Class propertyClass) { + super(graphElementId, keyType, elementClass, elementConstruct, propertyClass); + } - @Override - public IGraphFieldSerializer getGraphFieldSerializer() { - return GraphFieldSerializers.LabelTimeFieldSerializer.INSTANCE; - } + @Override + public IGraphFieldSerializer getGraphFieldSerializer() { + return GraphFieldSerializers.LabelFieldSerializer.INSTANCE; + } + } + + public static class TimeElementMeta + extends AbstractGraphElementMeta { + + public TimeElementMeta( + byte graphElementId, + IType keyType, + Class elementClass, + Supplier elementConstruct, + Class propertyClass) { + super( + graphElementId, keyType, (Class) elementClass, elementConstruct, propertyClass); + } + @Override + public IGraphFieldSerializer getGraphFieldSerializer() { + return GraphFieldSerializers.TimeFieldSerializer.INSTANCE; + } + } + + public static class LabelTimeElementMeta< + ELEMENT extends IGraphElementWithLabelField & IGraphElementWithTimeField> + extends AbstractGraphElementMeta { + + public LabelTimeElementMeta( + byte graphElementId, + IType keyType, + Class elementClass, + Supplier elementConstruct, + Class propertyClass) { + super( + graphElementId, keyType, (Class) elementClass, elementConstruct, propertyClass); } + @Override + public IGraphFieldSerializer getGraphFieldSerializer() { + return GraphFieldSerializers.LabelTimeFieldSerializer.INSTANCE; + } + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphElementSchemaFactory.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphElementSchemaFactory.java index a4727e47b..6ea51f311 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphElementSchemaFactory.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphElementSchemaFactory.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.schema.Field; @@ -34,45 +35,44 @@ public class GraphElementSchemaFactory { - public static final Field FIELD_LABEL = Field.newStringField(GraphFiledName.LABEL.name()); - public static final Field FIELD_TIME = Field.newLongField(GraphFiledName.TIME.name()); - public static final Field FIELD_DIRECTION = Field.newByteField(GraphFiledName.DIRECTION.name()); + public static final Field FIELD_LABEL = Field.newStringField(GraphFiledName.LABEL.name()); + public static final Field FIELD_TIME = Field.newLongField(GraphFiledName.TIME.name()); + public static final Field FIELD_DIRECTION = Field.newByteField(GraphFiledName.DIRECTION.name()); - public static ISchema newSchema(IType keyType, Class elementClass) { - if (IVertex.class.isAssignableFrom(elementClass)) { - return newVertexSchema(keyType, elementClass); - } - if (IEdge.class.isAssignableFrom(elementClass)) { - return newEdgeSchema(keyType, elementClass); - } - String msg = "unrecognized graph element class: " + elementClass.getCanonicalName(); - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(msg)); + public static ISchema newSchema(IType keyType, Class elementClass) { + if (IVertex.class.isAssignableFrom(elementClass)) { + return newVertexSchema(keyType, elementClass); } - - private static ISchema newVertexSchema(IType keyType, Class elementClass) { - List fields = new ArrayList<>(); - fields.add(new Field(GraphFiledName.ID.name(), keyType)); - if (IGraphElementWithLabelField.class.isAssignableFrom(elementClass)) { - fields.add(FIELD_LABEL); - } - if (IGraphElementWithTimeField.class.isAssignableFrom(elementClass)) { - fields.add(FIELD_TIME); - } - return new SchemaImpl(elementClass.getName(), fields); + if (IEdge.class.isAssignableFrom(elementClass)) { + return newEdgeSchema(keyType, elementClass); } + String msg = "unrecognized graph element class: " + elementClass.getCanonicalName(); + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(msg)); + } - private static ISchema newEdgeSchema(IType keyType, Class elementClass) { - List fields = new ArrayList<>(); - fields.add(new Field(GraphFiledName.SRC_ID.name(), keyType)); - fields.add(new Field(GraphFiledName.DST_ID.name(), keyType)); - fields.add(FIELD_DIRECTION); - if (IGraphElementWithLabelField.class.isAssignableFrom(elementClass)) { - fields.add(FIELD_LABEL); - } - if (IGraphElementWithTimeField.class.isAssignableFrom(elementClass)) { - fields.add(FIELD_TIME); - } - return new SchemaImpl(elementClass.getName(), fields); + private static ISchema newVertexSchema(IType keyType, Class elementClass) { + List fields = new ArrayList<>(); + fields.add(new Field(GraphFiledName.ID.name(), keyType)); + if (IGraphElementWithLabelField.class.isAssignableFrom(elementClass)) { + fields.add(FIELD_LABEL); + } + if (IGraphElementWithTimeField.class.isAssignableFrom(elementClass)) { + fields.add(FIELD_TIME); } + return new SchemaImpl(elementClass.getName(), fields); + } + private static ISchema newEdgeSchema(IType keyType, Class elementClass) { + List fields = new ArrayList<>(); + fields.add(new Field(GraphFiledName.SRC_ID.name(), keyType)); + fields.add(new Field(GraphFiledName.DST_ID.name(), keyType)); + fields.add(FIELD_DIRECTION); + if (IGraphElementWithLabelField.class.isAssignableFrom(elementClass)) { + fields.add(FIELD_LABEL); + } + if (IGraphElementWithTimeField.class.isAssignableFrom(elementClass)) { + fields.add(FIELD_TIME); + } + return new SchemaImpl(elementClass.getName(), fields); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphFieldSerializers.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphFieldSerializers.java index bb476d737..b502cae8d 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphFieldSerializers.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphFieldSerializers.java @@ -27,113 +27,111 @@ @SuppressWarnings("rawtypes") public class GraphFieldSerializers { - public static class IdFieldSerializer implements IGraphFieldSerializer { + public static class IdFieldSerializer implements IGraphFieldSerializer { - public static final IdFieldSerializer INSTANCE = new IdFieldSerializer(); + public static final IdFieldSerializer INSTANCE = new IdFieldSerializer(); - @Override - public Object getValue(ELEMENT element, GraphFiledName field) { - throw new GeaflowRuntimeException( - RuntimeErrors.INST.typeSysError(keyTypeErrMsg(element.getClass(), field.name()))); - } - - @Override - public void setValue(ELEMENT element, GraphFiledName field, Object value) { - throw new GeaflowRuntimeException( - RuntimeErrors.INST.typeSysError(keyTypeErrMsg(element.getClass(), field.name()))); - } + @Override + public Object getValue(ELEMENT element, GraphFiledName field) { + throw new GeaflowRuntimeException( + RuntimeErrors.INST.typeSysError(keyTypeErrMsg(element.getClass(), field.name()))); + } + @Override + public void setValue(ELEMENT element, GraphFiledName field, Object value) { + throw new GeaflowRuntimeException( + RuntimeErrors.INST.typeSysError(keyTypeErrMsg(element.getClass(), field.name()))); } + } - public static class LabelFieldSerializer implements IGraphFieldSerializer { - - public static final LabelFieldSerializer INSTANCE = new LabelFieldSerializer(); - - @Override - public Object getValue(ELEMENT element, GraphFiledName field) { - if (field == GraphFiledName.LABEL) { - return element.getLabel(); - } else { - throw new GeaflowRuntimeException( - RuntimeErrors.INST.typeSysError(keyTypeErrMsg(element.getClass(), field.name()))); - } - } - - @Override - public void setValue(ELEMENT element, GraphFiledName field, Object value) { - if (field == GraphFiledName.LABEL) { - element.setLabel((String) value); - } else { - throw new GeaflowRuntimeException( - RuntimeErrors.INST.typeSysError(keyTypeErrMsg(element.getClass(), field.name()))); - } - } + public static class LabelFieldSerializer + implements IGraphFieldSerializer { - } + public static final LabelFieldSerializer INSTANCE = new LabelFieldSerializer(); - public static class TimeFieldSerializer implements IGraphFieldSerializer { - - public static final TimeFieldSerializer INSTANCE = new TimeFieldSerializer(); - - @Override - public Object getValue(ELEMENT element, GraphFiledName field) { - if (field == GraphFiledName.TIME) { - return element.getTime(); - } else { - throw new GeaflowRuntimeException( - RuntimeErrors.INST.typeSysError(keyTypeErrMsg(element.getClass(), field.name()))); - } - } - - @Override - public void setValue(ELEMENT vertex, GraphFiledName field, Object value) { - if (field == GraphFiledName.TIME) { - vertex.setTime((long) value); - } else { - throw new GeaflowRuntimeException( - RuntimeErrors.INST.typeSysError(keyTypeErrMsg(vertex.getClass(), field.name()))); - } - } + @Override + public Object getValue(ELEMENT element, GraphFiledName field) { + if (field == GraphFiledName.LABEL) { + return element.getLabel(); + } else { + throw new GeaflowRuntimeException( + RuntimeErrors.INST.typeSysError(keyTypeErrMsg(element.getClass(), field.name()))); + } + } + @Override + public void setValue(ELEMENT element, GraphFiledName field, Object value) { + if (field == GraphFiledName.LABEL) { + element.setLabel((String) value); + } else { + throw new GeaflowRuntimeException( + RuntimeErrors.INST.typeSysError(keyTypeErrMsg(element.getClass(), field.name()))); + } } + } - public static class LabelTimeFieldSerializer - implements IGraphFieldSerializer { - - public static final LabelTimeFieldSerializer INSTANCE = new LabelTimeFieldSerializer(); - - @Override - public Object getValue(ELEMENT element, GraphFiledName field) { - switch (field) { - case LABEL: - return element.getLabel(); - case TIME: - return element.getTime(); - default: - throw new GeaflowRuntimeException( - RuntimeErrors.INST.typeSysError(keyTypeErrMsg(element.getClass(), field.name()))); - } - } - - @Override - public void setValue(ELEMENT element, GraphFiledName field, Object value) { - switch (field) { - case LABEL: - element.setLabel((String) value); - return; - case TIME: - element.setTime((long) value); - return; - default: - throw new GeaflowRuntimeException( - RuntimeErrors.INST.typeSysError(keyTypeErrMsg(element.getClass(), field.name()))); - } - } + public static class TimeFieldSerializer + implements IGraphFieldSerializer { + public static final TimeFieldSerializer INSTANCE = new TimeFieldSerializer(); + + @Override + public Object getValue(ELEMENT element, GraphFiledName field) { + if (field == GraphFiledName.TIME) { + return element.getTime(); + } else { + throw new GeaflowRuntimeException( + RuntimeErrors.INST.typeSysError(keyTypeErrMsg(element.getClass(), field.name()))); + } + } + + @Override + public void setValue(ELEMENT vertex, GraphFiledName field, Object value) { + if (field == GraphFiledName.TIME) { + vertex.setTime((long) value); + } else { + throw new GeaflowRuntimeException( + RuntimeErrors.INST.typeSysError(keyTypeErrMsg(vertex.getClass(), field.name()))); + } + } + } + + public static class LabelTimeFieldSerializer< + ELEMENT extends IGraphElementWithLabelField & IGraphElementWithTimeField> + implements IGraphFieldSerializer { + + public static final LabelTimeFieldSerializer INSTANCE = new LabelTimeFieldSerializer(); + + @Override + public Object getValue(ELEMENT element, GraphFiledName field) { + switch (field) { + case LABEL: + return element.getLabel(); + case TIME: + return element.getTime(); + default: + throw new GeaflowRuntimeException( + RuntimeErrors.INST.typeSysError(keyTypeErrMsg(element.getClass(), field.name()))); + } } - private static String keyTypeErrMsg(Class clazz, String key) { - return String.format("unrecognized key [%s] of [%s]", key, clazz.getCanonicalName()); + @Override + public void setValue(ELEMENT element, GraphFiledName field, Object value) { + switch (field) { + case LABEL: + element.setLabel((String) value); + return; + case TIME: + element.setTime((long) value); + return; + default: + throw new GeaflowRuntimeException( + RuntimeErrors.INST.typeSysError(keyTypeErrMsg(element.getClass(), field.name()))); + } } + } + private static String keyTypeErrMsg(Class clazz, String key) { + return String.format("unrecognized key [%s] of [%s]", key, clazz.getCanonicalName()); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphFiledName.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphFiledName.java index 6fdfc8e5e..df228ce42 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphFiledName.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphFiledName.java @@ -21,29 +21,16 @@ public enum GraphFiledName { - /** - * Vertex id. - */ - ID, - /** - * Edge src. - */ - SRC_ID, - /** - * Edge dst. - */ - DST_ID, - /** - * Edge direction. - */ - DIRECTION, - /** - * Label. - */ - LABEL, - /** - * Time. - */ - TIME - + /** Vertex id. */ + ID, + /** Edge src. */ + SRC_ID, + /** Edge dst. */ + DST_ID, + /** Edge direction. */ + DIRECTION, + /** Label. */ + LABEL, + /** Time. */ + TIME } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphMeta.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphMeta.java index 1bfae03d8..395e5ccdc 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphMeta.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphMeta.java @@ -23,30 +23,32 @@ public class GraphMeta { - private final IType keyType; - private final IGraphElementMeta vertexMeta; - private final IGraphElementMeta edgeMeta; - - public GraphMeta(GraphMetaType graphMetaType) { - this.keyType = graphMetaType.geKeyType(); - Class vClass = graphMetaType.getVertexClass(); - Class vvClass = graphMetaType.getVertexValueClass(); - Class eClass = graphMetaType.getEdgeClass(); - Class evClass = graphMetaType.getEdgeValueClass(); - this.vertexMeta = GraphElementMetas.getMeta(this.keyType, vClass, graphMetaType.getVertexConstruct(), vvClass); - this.edgeMeta = GraphElementMetas.getMeta(this.keyType, eClass, graphMetaType.getEdgeConstruct(), evClass); - } - - public IType getKeyType() { - return this.keyType; - } - - public IGraphElementMeta getVertexMeta() { - return this.vertexMeta; - } - - public IGraphElementMeta getEdgeMeta() { - return this.edgeMeta; - } - + private final IType keyType; + private final IGraphElementMeta vertexMeta; + private final IGraphElementMeta edgeMeta; + + public GraphMeta(GraphMetaType graphMetaType) { + this.keyType = graphMetaType.geKeyType(); + Class vClass = graphMetaType.getVertexClass(); + Class vvClass = graphMetaType.getVertexValueClass(); + Class eClass = graphMetaType.getEdgeClass(); + Class evClass = graphMetaType.getEdgeValueClass(); + this.vertexMeta = + GraphElementMetas.getMeta( + this.keyType, vClass, graphMetaType.getVertexConstruct(), vvClass); + this.edgeMeta = + GraphElementMetas.getMeta(this.keyType, eClass, graphMetaType.getEdgeConstruct(), evClass); + } + + public IType getKeyType() { + return this.keyType; + } + + public IGraphElementMeta getVertexMeta() { + return this.vertexMeta; + } + + public IGraphElementMeta getEdgeMeta() { + return this.edgeMeta; + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphMetaType.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphMetaType.java index a8c403734..d0fe62275 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphMetaType.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/GraphMetaType.java @@ -21,136 +21,149 @@ import java.io.Serializable; import java.util.function.Supplier; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; -public class GraphMetaType, E extends IEdge> implements Serializable { - - private IType keyType; - private Class vertexClass; - private Supplier vertexConstruct; - private Class vertexValueClass; - private Class edgeClass; - private Supplier edgeConstruct; - private Class edgeValueClass; - - public GraphMetaType() { - } - - public GraphMetaType(IType keyType, - Class vertexClass, - Supplier vertexConstruct, - Class vertexValueClass, - Class edgeClass, - Supplier edgeConstruct, - Class edgeValueClass) { - this.keyType = keyType; - this.vertexClass = vertexClass; - this.vertexConstruct = vertexConstruct; - this.vertexValueClass = vertexValueClass; - this.edgeClass = edgeClass; - this.edgeConstruct = edgeConstruct; - this.edgeValueClass = edgeValueClass; - } - - public GraphMetaType(IType keyType, - Class vertexClass, - Class vertexValueClass, - Class edgeClass, - Class edgeValueClass) { - this(keyType, vertexClass, - new DefaultObjectConstruct<>(vertexClass), - vertexValueClass, edgeClass, - new DefaultObjectConstruct<>(edgeClass), - edgeValueClass); - } - - public IType geKeyType() { - return this.keyType; - } - - public void setKeyType(IType keyType) { - this.keyType = keyType; - } - - public Class getVertexClass() { - return this.vertexClass; - } - - public void setVertexClass(Class vertexClass) { - this.vertexClass = vertexClass; - } - - public Class getVertexValueClass() { - return this.vertexValueClass; - } - - public void setVertexValueClass(Class vertexValueClass) { - this.vertexValueClass = vertexValueClass; - } - - public Class getEdgeClass() { - return this.edgeClass; - } - - public void setEdgeClass(Class edgeClass) { - this.edgeClass = edgeClass; - } - - public Class getEdgeValueClass() { - return this.edgeValueClass; - } - - public void setEdgeValueClass(Class edgeValueClass) { - this.edgeValueClass = edgeValueClass; - } - - public Supplier getVertexConstruct() { - return vertexConstruct; - } - - public void setVertexConstruct(Supplier vertexConstruct) { - this.vertexConstruct = vertexConstruct; - } - - public Supplier getEdgeConstruct() { - return edgeConstruct; - } - - public void setEdgeConstruct(Supplier edgeConstruct) { - this.edgeConstruct = edgeConstruct; +public class GraphMetaType, E extends IEdge> + implements Serializable { + + private IType keyType; + private Class vertexClass; + private Supplier vertexConstruct; + private Class vertexValueClass; + private Class edgeClass; + private Supplier edgeConstruct; + private Class edgeValueClass; + + public GraphMetaType() {} + + public GraphMetaType( + IType keyType, + Class vertexClass, + Supplier vertexConstruct, + Class vertexValueClass, + Class edgeClass, + Supplier edgeConstruct, + Class edgeValueClass) { + this.keyType = keyType; + this.vertexClass = vertexClass; + this.vertexConstruct = vertexConstruct; + this.vertexValueClass = vertexValueClass; + this.edgeClass = edgeClass; + this.edgeConstruct = edgeConstruct; + this.edgeValueClass = edgeValueClass; + } + + public GraphMetaType( + IType keyType, + Class vertexClass, + Class vertexValueClass, + Class edgeClass, + Class edgeValueClass) { + this( + keyType, + vertexClass, + new DefaultObjectConstruct<>(vertexClass), + vertexValueClass, + edgeClass, + new DefaultObjectConstruct<>(edgeClass), + edgeValueClass); + } + + public IType geKeyType() { + return this.keyType; + } + + public void setKeyType(IType keyType) { + this.keyType = keyType; + } + + public Class getVertexClass() { + return this.vertexClass; + } + + public void setVertexClass(Class vertexClass) { + this.vertexClass = vertexClass; + } + + public Class getVertexValueClass() { + return this.vertexValueClass; + } + + public void setVertexValueClass(Class vertexValueClass) { + this.vertexValueClass = vertexValueClass; + } + + public Class getEdgeClass() { + return this.edgeClass; + } + + public void setEdgeClass(Class edgeClass) { + this.edgeClass = edgeClass; + } + + public Class getEdgeValueClass() { + return this.edgeValueClass; + } + + public void setEdgeValueClass(Class edgeValueClass) { + this.edgeValueClass = edgeValueClass; + } + + public Supplier getVertexConstruct() { + return vertexConstruct; + } + + public void setVertexConstruct(Supplier vertexConstruct) { + this.vertexConstruct = vertexConstruct; + } + + public Supplier getEdgeConstruct() { + return edgeConstruct; + } + + public void setEdgeConstruct(Supplier edgeConstruct) { + this.edgeConstruct = edgeConstruct; + } + + @Override + public String toString() { + return "GraphMetaType{" + + "keyType=" + + keyType + + ", vertexClass=" + + vertexClass + + ", vertexConstruct=" + + vertexConstruct + + ", vertexValueClass=" + + vertexValueClass + + ", edgeClass=" + + edgeClass + + ", edgeConstruct=" + + edgeConstruct + + ", edgeValueClass=" + + edgeValueClass + + '}'; + } + + private static class DefaultObjectConstruct implements Supplier { + + private final Class clazz; + + public DefaultObjectConstruct(Class clazz) { + this.clazz = clazz; } @Override - public String toString() { - return "GraphMetaType{" - + "keyType=" + keyType - + ", vertexClass=" + vertexClass - + ", vertexConstruct=" + vertexConstruct - + ", vertexValueClass=" + vertexValueClass - + ", edgeClass=" + edgeClass - + ", edgeConstruct=" + edgeConstruct - + ", edgeValueClass=" + edgeValueClass - + '}'; - } - - private static class DefaultObjectConstruct implements Supplier { - - private final Class clazz; - - public DefaultObjectConstruct(Class clazz) { - this.clazz = clazz; - } - - @Override - public C get() { - try { - return clazz.newInstance(); - } catch (InstantiationException | IllegalAccessException e) { - throw new GeaflowRuntimeException("Error in create instance for class: " + clazz, e); - } - } - } + public C get() { + try { + return clazz.newInstance(); + } catch (InstantiationException | IllegalAccessException e) { + throw new GeaflowRuntimeException("Error in create instance for class: " + clazz, e); + } + } + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/IGraphElementMeta.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/IGraphElementMeta.java index 24d564539..4431aae6c 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/IGraphElementMeta.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/IGraphElementMeta.java @@ -21,44 +21,45 @@ import java.io.Serializable; import java.util.function.Supplier; + import org.apache.geaflow.common.schema.ISchema; public interface IGraphElementMeta extends Serializable { - /** - * Get graph element id. - * - * @return graph element id - */ - byte getGraphElementId(); + /** + * Get graph element id. + * + * @return graph element id + */ + byte getGraphElementId(); - /** - * Get graph element class. - * - * @return vertex or edge class - */ - Class getGraphElementClass(); + /** + * Get graph element class. + * + * @return vertex or edge class + */ + Class getGraphElementClass(); - /** - * Get meta of graph element, like id, label of vertex or src, dst, ts of edge. - * - * @return vertex or edge primitive schema - */ - ISchema getGraphMeta(); + /** + * Get meta of graph element, like id, label of vertex or src, dst, ts of edge. + * + * @return vertex or edge primitive schema + */ + ISchema getGraphMeta(); - /** - * Get graph field serializer. - * - * @return graph field serializer - */ - IGraphFieldSerializer getGraphFieldSerializer(); + /** + * Get graph field serializer. + * + * @return graph field serializer + */ + IGraphFieldSerializer getGraphFieldSerializer(); - /** - * Get property class of a vertex or edge. - * - * @return property class - */ - Class getPropertyClass(); + /** + * Get property class of a vertex or edge. + * + * @return property class + */ + Class getPropertyClass(); - Supplier getGraphElementConstruct(); + Supplier getGraphElementConstruct(); } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/IGraphFieldSerializer.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/IGraphFieldSerializer.java index e86b2dff7..00d3b3eee 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/IGraphFieldSerializer.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/meta/IGraphFieldSerializer.java @@ -23,22 +23,21 @@ public interface IGraphFieldSerializer extends Serializable { - /** - * Extract value of a key from a vertex or edge. - * - * @param element vertex or edge - * @param field graph field name - * @return value - */ - Object getValue(ELEMENT element, GraphFiledName field); - - /** - * Inject value of a key to a vertex or edge. - * - * @param element vertex or edge - * @param field graph field name - * @param value value - */ - void setValue(ELEMENT element, GraphFiledName field, Object value); + /** + * Extract value of a key from a vertex or edge. + * + * @param element vertex or edge + * @param field graph field name + * @return value + */ + Object getValue(ELEMENT element, GraphFiledName field); + /** + * Inject value of a key to a vertex or edge. + * + * @param element vertex or edge + * @param field graph field name + * @param value value + */ + void setValue(ELEMENT element, GraphFiledName field, Object value); } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/property/EmptyProperty.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/property/EmptyProperty.java index fbc5b98ea..930784f78 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/property/EmptyProperty.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/property/EmptyProperty.java @@ -21,6 +21,4 @@ import java.io.Serializable; -public class EmptyProperty implements Serializable { - -} +public class EmptyProperty implements Serializable {} diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/property/IPropertySerializable.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/property/IPropertySerializable.java index 12db093dd..1e8b1ecfa 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/property/IPropertySerializable.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/property/IPropertySerializable.java @@ -23,9 +23,9 @@ public interface IPropertySerializable extends Serializable, Cloneable { - IPropertySerializable fromBinary(byte[] bytes); + IPropertySerializable fromBinary(byte[] bytes); - byte[] toBytes(); + byte[] toBytes(); - IPropertySerializable clone(); + IPropertySerializable clone(); } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/IVertex.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/IVertex.java index 27fe0bae3..5f3fbd9ab 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/IVertex.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/IVertex.java @@ -23,49 +23,48 @@ public interface IVertex extends Serializable, Comparable { - /** - * Get id of vertex. - * - * @return id - */ - K getId(); + /** + * Get id of vertex. + * + * @return id + */ + K getId(); - /** - * Set id for the vertex. - * - * @param id id - */ - void setId(K id); + /** + * Set id for the vertex. + * + * @param id id + */ + void setId(K id); - /** - * Get value of vertex. - * - * @return value - */ - VV getValue(); + /** + * Get value of vertex. + * + * @return value + */ + VV getValue(); - /** - * Reset value for the vertex. - * - * @param value value - * @return vertex - */ - IVertex withValue(VV value); + /** + * Reset value for the vertex. + * + * @param value value + * @return vertex + */ + IVertex withValue(VV value); - /** - * Reset label value for the vertex. - * - * @param label label - * @return vertex - */ - IVertex withLabel(String label); - - /** - * Reset time value for the vertex. - * - * @param time time - * @return vertex - */ - IVertex withTime(long time); + /** + * Reset label value for the vertex. + * + * @param label label + * @return vertex + */ + IVertex withLabel(String label); + /** + * Reset time value for the vertex. + * + * @param time time + * @return vertex + */ + IVertex withTime(long time); } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/IDLabelTimeVertex.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/IDLabelTimeVertex.java index b86d4d499..b2f1b6e9e 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/IDLabelTimeVertex.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/IDLabelTimeVertex.java @@ -20,77 +20,76 @@ package org.apache.geaflow.model.graph.vertex.impl; import java.util.Objects; + import org.apache.geaflow.model.graph.IGraphElementWithLabelField; import org.apache.geaflow.model.graph.IGraphElementWithTimeField; public class IDLabelTimeVertex extends IDVertex implements IGraphElementWithLabelField, IGraphElementWithTimeField { - private String label; - private long time; - - public IDLabelTimeVertex() { - } - - public IDLabelTimeVertex(K id) { - super(id); + private String label; + private long time; + + public IDLabelTimeVertex() {} + + public IDLabelTimeVertex(K id) { + super(id); + } + + public IDLabelTimeVertex(K id, String label, long time) { + super(id); + this.label = label; + this.time = time; + } + + @Override + public String getLabel() { + return this.label; + } + + @Override + public void setLabel(String label) { + this.label = label; + } + + @Override + public long getTime() { + return this.time; + } + + @Override + public void setTime(long time) { + this.time = time; + } + + @Override + public ValueVertex withValue(Object value) { + return new ValueLabelTimeVertex<>(this.getId(), value, this.label, this.time); + } + + @Override + public IDVertex withLabel(String label) { + this.label = label; + return this; + } + + @Override + public IDVertex withTime(long time) { + this.time = time; + return this; + } + + @Override + public boolean equals(Object o) { + if (!super.equals(o)) { + return false; } - - public IDLabelTimeVertex(K id, String label, long time) { - super(id); - this.label = label; - this.time = time; - } - - @Override - public String getLabel() { - return this.label; - } - - @Override - public void setLabel(String label) { - this.label = label; - } - - @Override - public long getTime() { - return this.time; - } - - @Override - public void setTime(long time) { - this.time = time; - } - - @Override - public ValueVertex withValue(Object value) { - return new ValueLabelTimeVertex<>(this.getId(), value, this.label, this.time); - } - - @Override - public IDVertex withLabel(String label) { - this.label = label; - return this; - } - - @Override - public IDVertex withTime(long time) { - this.time = time; - return this; - } - - @Override - public boolean equals(Object o) { - if (!super.equals(o)) { - return false; - } - IDLabelTimeVertex that = (IDLabelTimeVertex) o; - return this.time == that.time && Objects.equals(this.label, that.label); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), this.label, this.time); - } - + IDLabelTimeVertex that = (IDLabelTimeVertex) o; + return this.time == that.time && Objects.equals(this.label, that.label); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), this.label, this.time); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/IDLabelVertex.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/IDLabelVertex.java index 5d098d561..6167a35f3 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/IDLabelVertex.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/IDLabelVertex.java @@ -20,62 +20,61 @@ package org.apache.geaflow.model.graph.vertex.impl; import java.util.Objects; + import org.apache.geaflow.model.graph.IGraphElementWithLabelField; public class IDLabelVertex extends IDVertex implements IGraphElementWithLabelField { - private String label; + private String label; - public IDLabelVertex() { - } + public IDLabelVertex() {} - public IDLabelVertex(K id) { - super(id); - } + public IDLabelVertex(K id) { + super(id); + } - public IDLabelVertex(K id, String label) { - super(id); - this.label = label; - } + public IDLabelVertex(K id, String label) { + super(id); + this.label = label; + } - @Override - public String getLabel() { - return this.label; - } + @Override + public String getLabel() { + return this.label; + } - @Override - public void setLabel(String label) { - this.label = label; - } + @Override + public void setLabel(String label) { + this.label = label; + } - @Override - public ValueVertex withValue(Object value) { - return new ValueLabelVertex<>(this.getId(), value, this.label); - } + @Override + public ValueVertex withValue(Object value) { + return new ValueLabelVertex<>(this.getId(), value, this.label); + } - @Override - public IDVertex withLabel(String label) { - this.label = label; - return this; - } + @Override + public IDVertex withLabel(String label) { + this.label = label; + return this; + } - @Override - public IDVertex withTime(long time) { - return new IDLabelTimeVertex<>(this.getId(), this.label, time); - } - - @Override - public boolean equals(Object o) { - if (!super.equals(o)) { - return false; - } - IDLabelVertex that = (IDLabelVertex) o; - return Objects.equals(this.label, that.label); - } + @Override + public IDVertex withTime(long time) { + return new IDLabelTimeVertex<>(this.getId(), this.label, time); + } - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), this.label); + @Override + public boolean equals(Object o) { + if (!super.equals(o)) { + return false; } + IDLabelVertex that = (IDLabelVertex) o; + return Objects.equals(this.label, that.label); + } + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), this.label); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/IDTimeVertex.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/IDTimeVertex.java index dd85a83ea..84ef7112d 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/IDTimeVertex.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/IDTimeVertex.java @@ -20,61 +20,61 @@ package org.apache.geaflow.model.graph.vertex.impl; import java.util.Objects; + import org.apache.geaflow.model.graph.IGraphElementWithTimeField; public class IDTimeVertex extends IDVertex implements IGraphElementWithTimeField { - private long time; + private long time; - public IDTimeVertex() { - } + public IDTimeVertex() {} - public IDTimeVertex(K id) { - super(id); - } + public IDTimeVertex(K id) { + super(id); + } - public IDTimeVertex(K id, long time) { - super(id); - this.time = time; - } + public IDTimeVertex(K id, long time) { + super(id); + this.time = time; + } - @Override - public long getTime() { - return time; - } + @Override + public long getTime() { + return time; + } - @Override - public void setTime(long time) { - this.time = time; - } + @Override + public void setTime(long time) { + this.time = time; + } - @Override - public ValueVertex withValue(Object value) { - return new ValueTimeVertex<>(this.getId(), value, this.time); - } + @Override + public ValueVertex withValue(Object value) { + return new ValueTimeVertex<>(this.getId(), value, this.time); + } - @Override - public IDVertex withLabel(String label) { - return new IDLabelTimeVertex<>(this.getId(), label, this.time); - } + @Override + public IDVertex withLabel(String label) { + return new IDLabelTimeVertex<>(this.getId(), label, this.time); + } - @Override - public IDTimeVertex withTime(long time) { - this.time = time; - return this; - } + @Override + public IDTimeVertex withTime(long time) { + this.time = time; + return this; + } - @Override - public boolean equals(Object o) { - if (!super.equals(o)) { - return false; - } - IDTimeVertex that = (IDTimeVertex) o; - return this.time == that.time; + @Override + public boolean equals(Object o) { + if (!super.equals(o)) { + return false; } + IDTimeVertex that = (IDTimeVertex) o; + return this.time == that.time; + } - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), this.time); - } + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), this.time); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/IDVertex.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/IDVertex.java index ab3f17015..1dc89e198 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/IDVertex.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/IDVertex.java @@ -20,79 +20,78 @@ package org.apache.geaflow.model.graph.vertex.impl; import java.util.Objects; + import org.apache.geaflow.model.graph.vertex.IVertex; public class IDVertex implements IVertex { - private K id; + private K id; - public IDVertex() { - } + public IDVertex() {} - public IDVertex(K id) { - this.id = id; - } + public IDVertex(K id) { + this.id = id; + } - @Override - public K getId() { - return this.id; - } + @Override + public K getId() { + return this.id; + } - @Override - public void setId(K id) { - this.id = id; - } - - @Override - public Object getValue() { - return null; - } + @Override + public void setId(K id) { + this.id = id; + } - @Override - public ValueVertex withValue(Object value) { - return new ValueVertex<>(this.id, value); - } + @Override + public Object getValue() { + return null; + } - @Override - public IDVertex withLabel(String label) { - return new IDLabelVertex<>(this.id, label); - } + @Override + public ValueVertex withValue(Object value) { + return new ValueVertex<>(this.id, value); + } - @Override - public IDVertex withTime(long time) { - return new IDTimeVertex<>(this.id, time); - } + @Override + public IDVertex withLabel(String label) { + return new IDLabelVertex<>(this.id, label); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - IDVertex idVertex = (IDVertex) o; - return Objects.equals(this.id, idVertex.id); - } + @Override + public IDVertex withTime(long time) { + return new IDTimeVertex<>(this.id, time); + } - @Override - public int hashCode() { - return Objects.hash(this.id); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public String toString() { - return String.format("IDVertex(vertexId:%s)", id); + if (o == null || getClass() != o.getClass()) { + return false; } - - @Override - public int compareTo(Object o) { - IDVertex vertex = (IDVertex) o; - if (id instanceof Comparable) { - return ((Comparable) id).compareTo(vertex.getId()); - } else { - return ((Integer) hashCode()).compareTo(vertex.hashCode()); - } + IDVertex idVertex = (IDVertex) o; + return Objects.equals(this.id, idVertex.id); + } + + @Override + public int hashCode() { + return Objects.hash(this.id); + } + + @Override + public String toString() { + return String.format("IDVertex(vertexId:%s)", id); + } + + @Override + public int compareTo(Object o) { + IDVertex vertex = (IDVertex) o; + if (id instanceof Comparable) { + return ((Comparable) id).compareTo(vertex.getId()); + } else { + return ((Integer) hashCode()).compareTo(vertex.hashCode()); } - + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/ValueLabelTimeVertex.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/ValueLabelTimeVertex.java index 30f8d43c1..54e126d9a 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/ValueLabelTimeVertex.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/ValueLabelTimeVertex.java @@ -20,72 +20,71 @@ package org.apache.geaflow.model.graph.vertex.impl; import java.util.Objects; + import org.apache.geaflow.model.graph.IGraphElementWithLabelField; import org.apache.geaflow.model.graph.IGraphElementWithTimeField; public class ValueLabelTimeVertex extends ValueVertex implements IGraphElementWithLabelField, IGraphElementWithTimeField { - private String label; - private long time; - - public ValueLabelTimeVertex() { - } - - public ValueLabelTimeVertex(K id) { - super(id); - } - - public ValueLabelTimeVertex(K id, VV value, String label, long time) { - super(id, value); - this.label = label; - this.time = time; - } - - @Override - public String getLabel() { - return this.label; - } - - @Override - public void setLabel(String label) { - this.label = label; - } - - @Override - public long getTime() { - return this.time; + private String label; + private long time; + + public ValueLabelTimeVertex() {} + + public ValueLabelTimeVertex(K id) { + super(id); + } + + public ValueLabelTimeVertex(K id, VV value, String label, long time) { + super(id, value); + this.label = label; + this.time = time; + } + + @Override + public String getLabel() { + return this.label; + } + + @Override + public void setLabel(String label) { + this.label = label; + } + + @Override + public long getTime() { + return this.time; + } + + @Override + public void setTime(long time) { + this.time = time; + } + + @Override + public ValueVertex withLabel(String label) { + this.label = label; + return this; + } + + @Override + public ValueVertex withTime(long time) { + this.time = time; + return this; + } + + @Override + public boolean equals(Object o) { + if (!super.equals(o)) { + return false; } - - @Override - public void setTime(long time) { - this.time = time; - } - - @Override - public ValueVertex withLabel(String label) { - this.label = label; - return this; - } - - @Override - public ValueVertex withTime(long time) { - this.time = time; - return this; - } - - @Override - public boolean equals(Object o) { - if (!super.equals(o)) { - return false; - } - ValueLabelTimeVertex that = (ValueLabelTimeVertex) o; - return this.time == that.time && Objects.equals(this.label, that.label); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), this.label, this.time); - } - + ValueLabelTimeVertex that = (ValueLabelTimeVertex) o; + return this.time == that.time && Objects.equals(this.label, that.label); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), this.label, this.time); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/ValueLabelVertex.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/ValueLabelVertex.java index 0d5b8880b..55b482a5d 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/ValueLabelVertex.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/ValueLabelVertex.java @@ -20,57 +20,57 @@ package org.apache.geaflow.model.graph.vertex.impl; import java.util.Objects; + import org.apache.geaflow.model.graph.IGraphElementWithLabelField; -public class ValueLabelVertex extends ValueVertex implements IGraphElementWithLabelField { +public class ValueLabelVertex extends ValueVertex + implements IGraphElementWithLabelField { - private String label; + private String label; - public ValueLabelVertex() { - } + public ValueLabelVertex() {} - public ValueLabelVertex(K id) { - super(id); - } - - public ValueLabelVertex(K id, VV value, String label) { - super(id, value); - this.label = label; - } + public ValueLabelVertex(K id) { + super(id); + } - @Override - public void setLabel(String label) { - this.label = label; - } + public ValueLabelVertex(K id, VV value, String label) { + super(id, value); + this.label = label; + } - @Override - public String getLabel() { - return this.label; - } + @Override + public void setLabel(String label) { + this.label = label; + } - @Override - public ValueVertex withLabel(String label) { - this.label = label; - return this; - } + @Override + public String getLabel() { + return this.label; + } - @Override - public ValueVertex withTime(long time) { - return new ValueLabelTimeVertex<>(this.getId(), this.getValue(), this.label, time); - } + @Override + public ValueVertex withLabel(String label) { + this.label = label; + return this; + } - @Override - public boolean equals(Object o) { - if (!super.equals(o)) { - return false; - } - ValueLabelVertex that = (ValueLabelVertex) o; - return Objects.equals(this.label, that.label); - } + @Override + public ValueVertex withTime(long time) { + return new ValueLabelTimeVertex<>(this.getId(), this.getValue(), this.label, time); + } - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), this.label); + @Override + public boolean equals(Object o) { + if (!super.equals(o)) { + return false; } + ValueLabelVertex that = (ValueLabelVertex) o; + return Objects.equals(this.label, that.label); + } + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), this.label); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/ValueTimeVertex.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/ValueTimeVertex.java index 1afce4dcd..4f12ab21f 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/ValueTimeVertex.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/ValueTimeVertex.java @@ -20,57 +20,57 @@ package org.apache.geaflow.model.graph.vertex.impl; import java.util.Objects; + import org.apache.geaflow.model.graph.IGraphElementWithTimeField; -public class ValueTimeVertex extends ValueVertex implements IGraphElementWithTimeField { +public class ValueTimeVertex extends ValueVertex + implements IGraphElementWithTimeField { - private long time; + private long time; - public ValueTimeVertex() { - } + public ValueTimeVertex() {} - public ValueTimeVertex(K id) { - super(id); - } - - public ValueTimeVertex(K id, VV value, long time) { - super(id, value); - this.time = time; - } + public ValueTimeVertex(K id) { + super(id); + } - @Override - public long getTime() { - return this.time; - } + public ValueTimeVertex(K id, VV value, long time) { + super(id, value); + this.time = time; + } - @Override - public void setTime(long time) { - this.time = time; - } + @Override + public long getTime() { + return this.time; + } - @Override - public ValueVertex withLabel(String label) { - return new ValueLabelTimeVertex<>(this.getId(), this.getValue(), label, this.time); - } + @Override + public void setTime(long time) { + this.time = time; + } - @Override - public ValueVertex withTime(long time) { - this.time = time; - return this; - } + @Override + public ValueVertex withLabel(String label) { + return new ValueLabelTimeVertex<>(this.getId(), this.getValue(), label, this.time); + } - @Override - public boolean equals(Object o) { - if (!super.equals(o)) { - return false; - } - ValueTimeVertex that = (ValueTimeVertex) o; - return this.time == that.time; - } + @Override + public ValueVertex withTime(long time) { + this.time = time; + return this; + } - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), this.time); + @Override + public boolean equals(Object o) { + if (!super.equals(o)) { + return false; } + ValueTimeVertex that = (ValueTimeVertex) o; + return this.time == that.time; + } + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), this.time); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/ValueVertex.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/ValueVertex.java index e56af27e5..8482b0932 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/ValueVertex.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/graph/vertex/impl/ValueVertex.java @@ -20,87 +20,85 @@ package org.apache.geaflow.model.graph.vertex.impl; import java.util.Objects; + import org.apache.geaflow.model.graph.vertex.IVertex; public class ValueVertex implements IVertex { - private K id; - private VV value; - - public ValueVertex() { - } - - public ValueVertex(K id) { - this.id = id; - } - - public ValueVertex(K id, VV value) { - this.id = id; - this.value = value; - } - - @Override - public K getId() { - return this.id; - } - - @Override - public void setId(K id) { - this.id = id; - } - - @Override - public VV getValue() { - return this.value; - } - - @Override - public ValueVertex withValue(VV value) { - this.value = value; - return this; - } - - @Override - public ValueVertex withLabel(String label) { - return new ValueLabelVertex<>(this.id, this.value, label); + private K id; + private VV value; + + public ValueVertex() {} + + public ValueVertex(K id) { + this.id = id; + } + + public ValueVertex(K id, VV value) { + this.id = id; + this.value = value; + } + + @Override + public K getId() { + return this.id; + } + + @Override + public void setId(K id) { + this.id = id; + } + + @Override + public VV getValue() { + return this.value; + } + + @Override + public ValueVertex withValue(VV value) { + this.value = value; + return this; + } + + @Override + public ValueVertex withLabel(String label) { + return new ValueLabelVertex<>(this.id, this.value, label); + } + + @Override + public ValueVertex withTime(long time) { + return new ValueTimeVertex<>(this.id, this.value, time); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public ValueVertex withTime(long time) { - return new ValueTimeVertex<>(this.id, this.value, time); + if (o == null || getClass() != o.getClass()) { + return false; } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - ValueVertex idVertex = (ValueVertex) o; - return Objects.equals(this.id, idVertex.id); - } - - @Override - public int hashCode() { - return Objects.hash(this.id); + ValueVertex idVertex = (ValueVertex) o; + return Objects.equals(this.id, idVertex.id); + } + + @Override + public int hashCode() { + return Objects.hash(this.id); + } + + @Override + public String toString() { + return String.format("ValueVertex(vertexId:%s, value:%s)", id, value); + } + + @Override + public int compareTo(Object o) { + ValueVertex vertex = (ValueVertex) o; + if (id instanceof Comparable) { + return ((Comparable) id).compareTo(vertex.getId()); + } else { + return ((Integer) hashCode()).compareTo(vertex.hashCode()); } - - @Override - public String toString() { - return String.format("ValueVertex(vertexId:%s, value:%s)", id, value); - } - - @Override - public int compareTo(Object o) { - ValueVertex vertex = (ValueVertex) o; - if (id instanceof Comparable) { - return ((Comparable) id).compareTo(vertex.getId()); - } else { - return ((Integer) hashCode()).compareTo(vertex.hashCode()); - } - } - - + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/BatchRecord.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/BatchRecord.java index 79a5f0f72..21569b94d 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/BatchRecord.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/BatchRecord.java @@ -24,31 +24,31 @@ public class BatchRecord implements Serializable { - private final RecordArgs recordArgs; - private Iterator messageIterator; + private final RecordArgs recordArgs; + private Iterator messageIterator; - public BatchRecord(RecordArgs recordArgs) { - this.recordArgs = recordArgs; - } + public BatchRecord(RecordArgs recordArgs) { + this.recordArgs = recordArgs; + } - public BatchRecord(RecordArgs recordArgs, Iterator messageIterator) { - this.recordArgs = recordArgs; - this.messageIterator = messageIterator; - } + public BatchRecord(RecordArgs recordArgs, Iterator messageIterator) { + this.recordArgs = recordArgs; + this.messageIterator = messageIterator; + } - public Iterator getMessageIterator() { - return messageIterator; - } + public Iterator getMessageIterator() { + return messageIterator; + } - public RecordArgs getRecordArgs() { - return recordArgs; - } + public RecordArgs getRecordArgs() { + return recordArgs; + } - public String getStreamName() { - return recordArgs.getName(); - } + public String getStreamName() { + return recordArgs.getName(); + } - public long getBatchId() { - return recordArgs.getWindowId(); - } + public long getBatchId() { + return recordArgs.getWindowId(); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/IKeyRecord.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/IKeyRecord.java index ec34ef49f..42dedb0b0 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/IKeyRecord.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/IKeyRecord.java @@ -21,6 +21,5 @@ public interface IKeyRecord extends IRecord { - KEY getKey(); - + KEY getKey(); } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/IRecord.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/IRecord.java index 28003d6b7..8d28528ef 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/IRecord.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/IRecord.java @@ -21,9 +21,7 @@ import java.io.Serializable; - public interface IRecord extends Serializable { - T getValue(); - + T getValue(); } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/RecordArgs.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/RecordArgs.java index e535b37aa..c9827a06f 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/RecordArgs.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/RecordArgs.java @@ -21,37 +21,41 @@ public class RecordArgs { - private long windowId; - private String name; - - public RecordArgs() { - this.windowId = 0; - this.name = ""; - } - - public RecordArgs(long windowId) { - this.windowId = windowId; - } - - public RecordArgs(long windowId, String name) { - this.windowId = windowId; - this.name = name; - } - - public long getWindowId() { - return windowId; - } - - public String getName() { - return name; - } - - @Override - public String toString() { - return "RecordArgs{" + "windowId=" + windowId + ", name='" + name + '\'' + '}'; - } - - public enum GraphRecordNames { - Vertex, Edge, Request, Message, Aggregate; - } + private long windowId; + private String name; + + public RecordArgs() { + this.windowId = 0; + this.name = ""; + } + + public RecordArgs(long windowId) { + this.windowId = windowId; + } + + public RecordArgs(long windowId, String name) { + this.windowId = windowId; + this.name = name; + } + + public long getWindowId() { + return windowId; + } + + public String getName() { + return name; + } + + @Override + public String toString() { + return "RecordArgs{" + "windowId=" + windowId + ", name='" + name + '\'' + '}'; + } + + public enum GraphRecordNames { + Vertex, + Edge, + Request, + Message, + Aggregate; + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/RecordFactory.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/RecordFactory.java index 0fcfe8fd1..f438afc54 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/RecordFactory.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/RecordFactory.java @@ -20,17 +20,17 @@ package org.apache.geaflow.model.record; import java.io.Serializable; + import org.apache.geaflow.model.record.impl.KeyRecord; import org.apache.geaflow.model.record.impl.Record; public class RecordFactory implements Serializable { - public static IRecord buildRecord(T value) { - return new Record<>(value); - } - - public static IKeyRecord buildRecord(KEY key, T value) { - return new KeyRecord<>(key, value); - } + public static IRecord buildRecord(T value) { + return new Record<>(value); + } + public static IKeyRecord buildRecord(KEY key, T value) { + return new KeyRecord<>(key, value); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/impl/KeyRecord.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/impl/KeyRecord.java index 5fd6eb2ef..6ab211424 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/impl/KeyRecord.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/impl/KeyRecord.java @@ -23,29 +23,29 @@ public class KeyRecord extends Record implements IKeyRecord { - protected KEY key; - - public KeyRecord(KEY key, Record record) { - super(record.getValue()); - this.key = key; - } - - public KeyRecord(KEY key, T value) { - super(value); - this.key = key; - } - - @Override - public KEY getKey() { - return key; - } - - public void setKey(KEY key) { - this.key = key; - } - - @Override - public String toString() { - return String.format("key:%s value:%s", key, super.toString()); - } + protected KEY key; + + public KeyRecord(KEY key, Record record) { + super(record.getValue()); + this.key = key; + } + + public KeyRecord(KEY key, T value) { + super(value); + this.key = key; + } + + @Override + public KEY getKey() { + return key; + } + + public void setKey(KEY key) { + this.key = key; + } + + @Override + public String toString() { + return String.format("key:%s value:%s", key, super.toString()); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/impl/Record.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/impl/Record.java index a8f86b5d3..5eb5db98f 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/impl/Record.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/record/impl/Record.java @@ -23,38 +23,36 @@ public class Record implements IRecord { - protected T value; + protected T value; - public Record() { - } + public Record() {} - public Record(T value) { - this.value = value; - } + public Record(T value) { + this.value = value; + } - @Override - public T getValue() { - return value; - } + @Override + public T getValue() { + return value; + } - public void setValue(T value) { - this.value = value; - } + public void setValue(T value) { + this.value = value; + } - @Override - public boolean equals(Object object) { - if (object instanceof Record) { - return ((Record) object).getValue().equals(this.value); - } - return false; + @Override + public boolean equals(Object object) { + if (object instanceof Record) { + return ((Record) object).getValue().equals(this.value); } + return false; + } - @Override - public String toString() { - if (value != null) { - return value.toString(); - } - return super.toString(); + @Override + public String toString() { + if (value != null) { + return value.toString(); } - + return super.toString(); + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/traversal/ITraversalRequest.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/traversal/ITraversalRequest.java index 6b9bee275..9039f8860 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/traversal/ITraversalRequest.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/traversal/ITraversalRequest.java @@ -20,14 +20,14 @@ package org.apache.geaflow.model.traversal; import java.io.Serializable; + import org.apache.geaflow.model.traversal.TraversalType.RequestType; public interface ITraversalRequest extends Serializable { - long getRequestId(); - - K getVId(); + long getRequestId(); - RequestType getType(); + K getVId(); + RequestType getType(); } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/traversal/ITraversalResponse.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/traversal/ITraversalResponse.java index 9e561fe41..1ce680d5d 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/traversal/ITraversalResponse.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/traversal/ITraversalResponse.java @@ -20,14 +20,14 @@ package org.apache.geaflow.model.traversal; import java.io.Serializable; + import org.apache.geaflow.model.traversal.TraversalType.ResponseType; public interface ITraversalResponse extends Serializable { - long getResponseId(); - - R getResponse(); + long getResponseId(); - ResponseType getType(); + R getResponse(); + ResponseType getType(); } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/traversal/TraversalType.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/traversal/TraversalType.java index 2913aecc7..025ab1f87 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/traversal/TraversalType.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/traversal/TraversalType.java @@ -21,12 +21,13 @@ public class TraversalType { - public enum RequestType { - Vertex, Edge; - } - - public enum ResponseType { - Vertex, Edge; - } + public enum RequestType { + Vertex, + Edge; + } + public enum ResponseType { + Vertex, + Edge; + } } diff --git a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/traversal/impl/VertexBeginTraversalRequest.java b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/traversal/impl/VertexBeginTraversalRequest.java index b6a7801e4..4be94ce5d 100644 --- a/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/traversal/impl/VertexBeginTraversalRequest.java +++ b/geaflow/geaflow-model/src/main/java/org/apache/geaflow/model/traversal/impl/VertexBeginTraversalRequest.java @@ -24,25 +24,24 @@ public class VertexBeginTraversalRequest implements ITraversalRequest { - private K vId; - - public VertexBeginTraversalRequest(K vId) { - this.vId = vId; - } - - @Override - public long getRequestId() { - return this.vId.hashCode(); - } - - @Override - public K getVId() { - return vId; - } - - - @Override - public RequestType getType() { - return RequestType.Vertex; - } + private K vId; + + public VertexBeginTraversalRequest(K vId) { + this.vId = vId; + } + + @Override + public long getRequestId() { + return this.vId.hashCode(); + } + + @Override + public K getVId() { + return vId; + } + + @Override + public RequestType getType() { + return RequestType.Vertex; + } } diff --git a/geaflow/geaflow-model/src/test/java/org/apache/geaflow/model/common/NullTest.java b/geaflow/geaflow-model/src/test/java/org/apache/geaflow/model/common/NullTest.java index 5f046c66b..aa6ee1ce0 100644 --- a/geaflow/geaflow-model/src/test/java/org/apache/geaflow/model/common/NullTest.java +++ b/geaflow/geaflow-model/src/test/java/org/apache/geaflow/model/common/NullTest.java @@ -24,12 +24,12 @@ public class NullTest { - @Test - public void testNull() { - Null nullObject = new Null(); + @Test + public void testNull() { + Null nullObject = new Null(); - Assert.assertEquals(1, nullObject.hashCode()); - Assert.assertTrue(nullObject.equals(new Null())); - Assert.assertFalse(nullObject.equals(new NullTest())); - } + Assert.assertEquals(1, nullObject.hashCode()); + Assert.assertTrue(nullObject.equals(new Null())); + Assert.assertFalse(nullObject.equals(new NullTest())); + } } diff --git a/geaflow/geaflow-model/src/test/java/org/apache/geaflow/model/graph/message/GraphMessageTest.java b/geaflow/geaflow-model/src/test/java/org/apache/geaflow/model/graph/message/GraphMessageTest.java index 16f8a39a7..ecddbaa3c 100644 --- a/geaflow/geaflow-model/src/test/java/org/apache/geaflow/model/graph/message/GraphMessageTest.java +++ b/geaflow/geaflow-model/src/test/java/org/apache/geaflow/model/graph/message/GraphMessageTest.java @@ -26,6 +26,7 @@ import java.util.Collections; import java.util.List; import java.util.NoSuchElementException; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.encoder.Encoders; import org.apache.geaflow.common.encoder.IEncoder; @@ -39,147 +40,149 @@ @SuppressWarnings("unchecked") public class GraphMessageTest { - @Test(expectedExceptions = {NoSuchElementException.class}) - public void testMessageException() { - DefaultGraphMessage msg = new DefaultGraphMessage<>(1, 2.0); + @Test(expectedExceptions = {NoSuchElementException.class}) + public void testMessageException() { + DefaultGraphMessage msg = new DefaultGraphMessage<>(1, 2.0); + Assert.assertTrue(msg.hasNext()); + Assert.assertEquals(msg.next(), 2.0); + Assert.assertFalse(msg.hasNext()); + msg.next(); + } + + @Test + public void testDirectEmitMsg() { + for (int i = 1; i <= 100; i++) { + DefaultGraphMessage msg = new DefaultGraphMessage<>(i, i * 2.0); + Assert.assertTrue(msg.hasNext()); + Assert.assertEquals(msg.next(), i * 2.0); + Assert.assertFalse(msg.hasNext()); + } + } + + @Test + public void testListMessage() { + for (int i = 1; i <= 100; i++) { + List msgList = new ArrayList<>(); + for (int j = 0; j < 10; j++) { + msgList.add(j * 2.0); + } + ListGraphMessage msg = new ListGraphMessage<>(i, msgList); + for (int j = 0; j < 10; j++) { Assert.assertTrue(msg.hasNext()); - Assert.assertEquals(msg.next(), 2.0); - Assert.assertFalse(msg.hasNext()); - msg.next(); + Assert.assertEquals(msg.next(), j * 2.0); + } + Assert.assertFalse(msg.hasNext()); } - - @Test - public void testDirectEmitMsg() { - for (int i = 1; i <= 100; i++) { - DefaultGraphMessage msg = new DefaultGraphMessage<>(i, i * 2.0); - Assert.assertTrue(msg.hasNext()); - Assert.assertEquals(msg.next(), i * 2.0); - Assert.assertFalse(msg.hasNext()); + } + + @Test + public void testDefaultMessageEncoder() { + IEncoder> msgEncoder = + (IEncoder>) + GraphMessageEncoders.build(Encoders.INTEGER, Encoders.DOUBLE); + msgEncoder.init(new Configuration()); + + try { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + for (int i = 1; i <= 100; i++) { + for (int j = 0; j < 10; j++) { + IGraphMessage msg = new DefaultGraphMessage<>(i, j * 2.0); + msgEncoder.encode(msg, bos); } - } - - @Test - public void testListMessage() { - for (int i = 1; i <= 100; i++) { - List msgList = new ArrayList<>(); - for (int j = 0; j < 10; j++) { - msgList.add(j * 2.0); - } - ListGraphMessage msg = new ListGraphMessage<>(i, msgList); - for (int j = 0; j < 10; j++) { - Assert.assertTrue(msg.hasNext()); - Assert.assertEquals(msg.next(), j * 2.0); - } - Assert.assertFalse(msg.hasNext()); + } + + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + + int i = 1; + int j = 0; + while (bis.available() > 0) { + IGraphMessage value = msgEncoder.decode(bis); + int vertexId = value.getTargetVId(); + Assert.assertEquals(vertexId, i); + while (value.hasNext()) { + Double next = value.next(); + Assert.assertEquals(next, j * 2.0); + j++; } - } - - @Test - public void testDefaultMessageEncoder() { - IEncoder> msgEncoder = - (IEncoder>) GraphMessageEncoders.build(Encoders.INTEGER, Encoders.DOUBLE); - msgEncoder.init(new Configuration()); - - try { - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - for (int i = 1; i <= 100; i++) { - for (int j = 0; j < 10; j++) { - IGraphMessage msg = new DefaultGraphMessage<>(i, j * 2.0); - msgEncoder.encode(msg, bos); - } - } - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - - int i = 1; - int j = 0; - while (bis.available() > 0) { - IGraphMessage value = msgEncoder.decode(bis); - int vertexId = value.getTargetVId(); - Assert.assertEquals(vertexId, i); - while (value.hasNext()) { - Double next = value.next(); - Assert.assertEquals(next, j * 2.0); - j++; - } - if (j == 10) { - i++; - j = 0; - } - } - } catch (IOException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); + if (j == 10) { + i++; + j = 0; } + } + } catch (IOException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); } - - @Test - public void testListMessageEncoder() { - IEncoder> msgEncoder = - new ListGraphMessageEncoder<>(Encoders.INTEGER, Encoders.DOUBLE); - msgEncoder.init(new Configuration()); - - try { - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - for (int i = 1; i <= 100; i++) { - List msgList = new ArrayList<>(); - for (int j = 0; j < 10; j++) { - msgList.add(j * 2.0); - } - ListGraphMessage msg = new ListGraphMessage<>(i, msgList); - msgEncoder.encode(msg, bos); - } - - byte[] arr = bos.toByteArray(); - ByteArrayInputStream bis = new ByteArrayInputStream(arr); - - int i = 1; - while (bis.available() > 0) { - ListGraphMessage value = msgEncoder.decode(bis); - int vertexId = value.getTargetVId(); - Assert.assertEquals(vertexId, i); - List messages = value.getMessages(); - for (int j = 0; j < 10; j++) { - Assert.assertEquals(messages.get(j), j * 2.0); - } - i++; - } - } catch (IOException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); + } + + @Test + public void testListMessageEncoder() { + IEncoder> msgEncoder = + new ListGraphMessageEncoder<>(Encoders.INTEGER, Encoders.DOUBLE); + msgEncoder.init(new Configuration()); + + try { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + for (int i = 1; i <= 100; i++) { + List msgList = new ArrayList<>(); + for (int j = 0; j < 10; j++) { + msgList.add(j * 2.0); } + ListGraphMessage msg = new ListGraphMessage<>(i, msgList); + msgEncoder.encode(msg, bos); + } + + byte[] arr = bos.toByteArray(); + ByteArrayInputStream bis = new ByteArrayInputStream(arr); + + int i = 1; + while (bis.available() > 0) { + ListGraphMessage value = msgEncoder.decode(bis); + int vertexId = value.getTargetVId(); + Assert.assertEquals(vertexId, i); + List messages = value.getMessages(); + for (int j = 0; j < 10; j++) { + Assert.assertEquals(messages.get(j), j * 2.0); + } + i++; + } + } catch (IOException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(e.getMessage()), e); } - - @Test - public void testNullEncoder() { - IEncoder> encoder1 = - GraphMessageEncoders.build(null, Encoders.INTEGER); - Assert.assertNull(encoder1); - IEncoder> encoder2 = - GraphMessageEncoders.build(Encoders.INTEGER, null); - Assert.assertNull(encoder2); - IEncoder> encoder3 = - GraphMessageEncoders.build(null, null); - Assert.assertNull(encoder3); - } - - @Test(expectedExceptions = {GeaflowRuntimeException.class}) - public void testEncoderException() throws IOException { - IEncoder> msgEncoder = - (IEncoder>) GraphMessageEncoders.build(Encoders.INTEGER, Encoders.DOUBLE); - msgEncoder.init(new Configuration()); - msgEncoder.encode(null, null); - } - - @Test - public void testDefaultMessageToString() { - DefaultGraphMessage msg = new DefaultGraphMessage<>(1, 1); - Assert.assertEquals(msg.toString(), "DefaultGraphMessage{targetVId=1, message=1}"); - } - - @Test - public void testListMsgToString() { - ListGraphMessage msg = new ListGraphMessage<>(1, Collections.singletonList(1)); - Assert.assertEquals(msg.toString(), "ListGraphMessage{targetVId=1, messages=[1]}"); - } - + } + + @Test + public void testNullEncoder() { + IEncoder> encoder1 = + GraphMessageEncoders.build(null, Encoders.INTEGER); + Assert.assertNull(encoder1); + IEncoder> encoder2 = + GraphMessageEncoders.build(Encoders.INTEGER, null); + Assert.assertNull(encoder2); + IEncoder> encoder3 = + GraphMessageEncoders.build(null, null); + Assert.assertNull(encoder3); + } + + @Test(expectedExceptions = {GeaflowRuntimeException.class}) + public void testEncoderException() throws IOException { + IEncoder> msgEncoder = + (IEncoder>) + GraphMessageEncoders.build(Encoders.INTEGER, Encoders.DOUBLE); + msgEncoder.init(new Configuration()); + msgEncoder.encode(null, null); + } + + @Test + public void testDefaultMessageToString() { + DefaultGraphMessage msg = new DefaultGraphMessage<>(1, 1); + Assert.assertEquals(msg.toString(), "DefaultGraphMessage{targetVId=1, message=1}"); + } + + @Test + public void testListMsgToString() { + ListGraphMessage msg = + new ListGraphMessage<>(1, Collections.singletonList(1)); + Assert.assertEquals(msg.toString(), "ListGraphMessage{targetVId=1, messages=[1]}"); + } } diff --git a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-common/src/main/java/org/apache/geaflow/file/FileConfigKeys.java b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-common/src/main/java/org/apache/geaflow/file/FileConfigKeys.java index 9ce421376..6a30adca8 100644 --- a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-common/src/main/java/org/apache/geaflow/file/FileConfigKeys.java +++ b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-common/src/main/java/org/apache/geaflow/file/FileConfigKeys.java @@ -24,89 +24,75 @@ public class FileConfigKeys { - protected static final int CORE_NUM = Runtime.getRuntime().availableProcessors(); - public static final ConfigKey PERSISTENT_TYPE = ConfigKeys - .key("geaflow.file.persistent.type") - .defaultValue("LOCAL") - .description("geaflow file persistent type"); - - public static final ConfigKey PERSISTENT_THREAD_SIZE = ConfigKeys - .key("geaflow.file.persistent.thread.size") - .defaultValue(CORE_NUM) - .description("geaflow file persistent thread size"); - - public static final ConfigKey USER_NAME = ConfigKeys - .key("geaflow.file.persistent.user.name") - .defaultValue("geaflow") - .description("geaflow file base url"); - - public static final ConfigKey JSON_CONFIG = ConfigKeys - .key("geaflow.file.persistent.config.json") - .defaultValue("") - .description("geaflow json config"); - - public static final ConfigKey ROOT = ConfigKeys - .key("geaflow.file.persistent.root") - .defaultValue("/geaflow/chk") - .description("geaflow file persistent root path"); - - /** - * oss config. - */ - public static final ConfigKey OSS_BUCKET_NAME = ConfigKeys - .key("geaflow.file.oss.bucket.name") - .defaultValue(null) - .description("oss bucket name"); - - public static final ConfigKey OSS_ENDPOINT = ConfigKeys - .key("geaflow.file.oss.endpoint") - .defaultValue(null) - .description("oss endpoint"); - - public static final ConfigKey OSS_ACCESS_ID = ConfigKeys - .key("geaflow.file.oss.access.id") - .defaultValue(null) - .description("oss access id"); - - public static final ConfigKey OSS_SECRET_KEY = ConfigKeys - .key("geaflow.file.oss.secret.key") - .defaultValue(null) - .description("oss secret key"); - - - public static final ConfigKey S3_BUCKET_NAME = ConfigKeys - .key("geaflow.file.s3.bucket.name") - .defaultValue(null) - .description("s3 bucket name"); - - public static final ConfigKey S3_ENDPOINT = ConfigKeys - .key("geaflow.file.s3.endpoint") - .defaultValue(null) - .description("s3 endpoint"); - - public static final ConfigKey S3_ACCESS_KEY_ID = ConfigKeys - .key("geaflow.file.s3.access.key.id") - .defaultValue(null) - .description("s3 access key id"); - - public static final ConfigKey S3_ACCESS_KEY = ConfigKeys - .key("geaflow.file.s3.access.key") - .defaultValue(null) - .description("s3 access key"); - - public static final ConfigKey S3_MIN_PART_SIZE = ConfigKeys - .key("geaflow.file.s3.min.part.size") - .defaultValue(5242880L) - .description("s3 input minimum part size in bytes"); - - public static final ConfigKey S3_INPUT_STREAM_CHUNK_SIZE = ConfigKeys - .key("geaflow.file.s3.input.stream.chunk.size") - .defaultValue(1048576) - .description("s3 input stream chunk size in bytes"); - - - public static final ConfigKey S3_REGION = ConfigKeys - .key("geaflow.file.s3.region") - .defaultValue("CN_NORTH_1") - .description("s3 region"); + protected static final int CORE_NUM = Runtime.getRuntime().availableProcessors(); + public static final ConfigKey PERSISTENT_TYPE = + ConfigKeys.key("geaflow.file.persistent.type") + .defaultValue("LOCAL") + .description("geaflow file persistent type"); + + public static final ConfigKey PERSISTENT_THREAD_SIZE = + ConfigKeys.key("geaflow.file.persistent.thread.size") + .defaultValue(CORE_NUM) + .description("geaflow file persistent thread size"); + + public static final ConfigKey USER_NAME = + ConfigKeys.key("geaflow.file.persistent.user.name") + .defaultValue("geaflow") + .description("geaflow file base url"); + + public static final ConfigKey JSON_CONFIG = + ConfigKeys.key("geaflow.file.persistent.config.json") + .defaultValue("") + .description("geaflow json config"); + + public static final ConfigKey ROOT = + ConfigKeys.key("geaflow.file.persistent.root") + .defaultValue("/geaflow/chk") + .description("geaflow file persistent root path"); + + /** oss config. */ + public static final ConfigKey OSS_BUCKET_NAME = + ConfigKeys.key("geaflow.file.oss.bucket.name") + .defaultValue(null) + .description("oss bucket name"); + + public static final ConfigKey OSS_ENDPOINT = + ConfigKeys.key("geaflow.file.oss.endpoint").defaultValue(null).description("oss endpoint"); + + public static final ConfigKey OSS_ACCESS_ID = + ConfigKeys.key("geaflow.file.oss.access.id").defaultValue(null).description("oss access id"); + + public static final ConfigKey OSS_SECRET_KEY = + ConfigKeys.key("geaflow.file.oss.secret.key") + .defaultValue(null) + .description("oss secret key"); + + public static final ConfigKey S3_BUCKET_NAME = + ConfigKeys.key("geaflow.file.s3.bucket.name") + .defaultValue(null) + .description("s3 bucket name"); + + public static final ConfigKey S3_ENDPOINT = + ConfigKeys.key("geaflow.file.s3.endpoint").defaultValue(null).description("s3 endpoint"); + + public static final ConfigKey S3_ACCESS_KEY_ID = + ConfigKeys.key("geaflow.file.s3.access.key.id") + .defaultValue(null) + .description("s3 access key id"); + + public static final ConfigKey S3_ACCESS_KEY = + ConfigKeys.key("geaflow.file.s3.access.key").defaultValue(null).description("s3 access key"); + + public static final ConfigKey S3_MIN_PART_SIZE = + ConfigKeys.key("geaflow.file.s3.min.part.size") + .defaultValue(5242880L) + .description("s3 input minimum part size in bytes"); + + public static final ConfigKey S3_INPUT_STREAM_CHUNK_SIZE = + ConfigKeys.key("geaflow.file.s3.input.stream.chunk.size") + .defaultValue(1048576) + .description("s3 input stream chunk size in bytes"); + + public static final ConfigKey S3_REGION = + ConfigKeys.key("geaflow.file.s3.region").defaultValue("CN_NORTH_1").description("s3 region"); } diff --git a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-common/src/main/java/org/apache/geaflow/file/FileInfo.java b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-common/src/main/java/org/apache/geaflow/file/FileInfo.java index d06e39042..621c69d4a 100644 --- a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-common/src/main/java/org/apache/geaflow/file/FileInfo.java +++ b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-common/src/main/java/org/apache/geaflow/file/FileInfo.java @@ -20,72 +20,72 @@ package org.apache.geaflow.file; import java.util.Objects; + import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.Path; public class FileInfo { - private Path path; - private long modifiedTime; - private long length; + private Path path; + private long modifiedTime; + private long length; - protected FileInfo() { - } + protected FileInfo() {} - public static FileInfo of(FileStatus fileStatus) { - return new FileInfo() - .withPath(fileStatus.getPath()) - .withLength(fileStatus.getLen()) - .withModifiedTime(fileStatus.getModificationTime()); - } + public static FileInfo of(FileStatus fileStatus) { + return new FileInfo() + .withPath(fileStatus.getPath()) + .withLength(fileStatus.getLen()) + .withModifiedTime(fileStatus.getModificationTime()); + } - public static FileInfo of() { - return new FileInfo(); - } + public static FileInfo of() { + return new FileInfo(); + } - public FileInfo withLength(long length) { - this.length = length; - return this; - } + public FileInfo withLength(long length) { + this.length = length; + return this; + } - public FileInfo withPath(Path path) { - this.path = path; - return this; - } + public FileInfo withPath(Path path) { + this.path = path; + return this; + } - public FileInfo withModifiedTime(long modifiedTime) { - this.modifiedTime = modifiedTime; - return this; - } + public FileInfo withModifiedTime(long modifiedTime) { + this.modifiedTime = modifiedTime; + return this; + } - public Path getPath() { - return path; - } + public Path getPath() { + return path; + } - public long getModificationTime() { - return modifiedTime; - } + public long getModificationTime() { + return modifiedTime; + } - public long getLength() { - return length; - } + public long getLength() { + return length; + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - FileInfo fileInfo = (FileInfo) o; - return length == fileInfo.getLength() - && modifiedTime == fileInfo.modifiedTime - && Objects.equals(path, fileInfo.path); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return Objects.hash(length, path, modifiedTime); + if (o == null || getClass() != o.getClass()) { + return false; } + FileInfo fileInfo = (FileInfo) o; + return length == fileInfo.getLength() + && modifiedTime == fileInfo.modifiedTime + && Objects.equals(path, fileInfo.path); + } + + @Override + public int hashCode() { + return Objects.hash(length, path, modifiedTime); + } } diff --git a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-common/src/main/java/org/apache/geaflow/file/IPersistentIO.java b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-common/src/main/java/org/apache/geaflow/file/IPersistentIO.java index e31acd761..86ca02799 100644 --- a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-common/src/main/java/org/apache/geaflow/file/IPersistentIO.java +++ b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-common/src/main/java/org/apache/geaflow/file/IPersistentIO.java @@ -22,41 +22,42 @@ import java.io.IOException; import java.io.InputStream; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.fs.PathFilter; public interface IPersistentIO { - void init(Configuration userConfig); + void init(Configuration userConfig); - List listFileName(Path path) throws IOException; + List listFileName(Path path) throws IOException; - boolean exists(Path path) throws IOException; + boolean exists(Path path) throws IOException; - boolean delete(Path path, boolean recursive) throws IOException; + boolean delete(Path path, boolean recursive) throws IOException; - boolean renameFile(Path from, Path to) throws IOException; + boolean renameFile(Path from, Path to) throws IOException; - boolean createNewFile(Path path) throws IOException; + boolean createNewFile(Path path) throws IOException; - void copyFromLocalFile(Path local, Path remote) throws IOException; + void copyFromLocalFile(Path local, Path remote) throws IOException; - void copyToLocalFile(Path remote, Path local) throws IOException; + void copyToLocalFile(Path remote, Path local) throws IOException; - long getFileSize(Path path) throws IOException; + long getFileSize(Path path) throws IOException; - long getFileCount(Path path) throws IOException; + long getFileCount(Path path) throws IOException; - FileInfo getFileInfo(Path path) throws IOException; + FileInfo getFileInfo(Path path) throws IOException; - FileInfo[] listFileInfo(Path path) throws IOException; + FileInfo[] listFileInfo(Path path) throws IOException; - FileInfo[] listFileInfo(Path path, PathFilter filter) throws IOException; + FileInfo[] listFileInfo(Path path, PathFilter filter) throws IOException; - InputStream open(Path path) throws IOException; + InputStream open(Path path) throws IOException; - void close() throws IOException; + void close() throws IOException; - PersistentType getPersistentType(); + PersistentType getPersistentType(); } diff --git a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-common/src/main/java/org/apache/geaflow/file/PersistentIOBuilder.java b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-common/src/main/java/org/apache/geaflow/file/PersistentIOBuilder.java index 086642c1f..a86dd5439 100644 --- a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-common/src/main/java/org/apache/geaflow/file/PersistentIOBuilder.java +++ b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-common/src/main/java/org/apache/geaflow/file/PersistentIOBuilder.java @@ -20,24 +20,24 @@ package org.apache.geaflow.file; import java.util.ServiceLoader; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class PersistentIOBuilder { - public static IPersistentIO build(Configuration userConfig) { - PersistentType type = PersistentType.valueOf( - userConfig.getString(FileConfigKeys.PERSISTENT_TYPE).toUpperCase()); - ServiceLoader serviceLoader = ServiceLoader.load(IPersistentIO.class); - for (IPersistentIO persistentIO : serviceLoader) { - if (persistentIO.getPersistentType() == type) { - persistentIO.init(userConfig); - return persistentIO; - } - } - - throw new GeaflowRuntimeException(RuntimeErrors.INST.spiNotFoundError(type.toString())); + public static IPersistentIO build(Configuration userConfig) { + PersistentType type = + PersistentType.valueOf(userConfig.getString(FileConfigKeys.PERSISTENT_TYPE).toUpperCase()); + ServiceLoader serviceLoader = ServiceLoader.load(IPersistentIO.class); + for (IPersistentIO persistentIO : serviceLoader) { + if (persistentIO.getPersistentType() == type) { + persistentIO.init(userConfig); + return persistentIO; + } } + throw new GeaflowRuntimeException(RuntimeErrors.INST.spiNotFoundError(type.toString())); + } } diff --git a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-common/src/main/java/org/apache/geaflow/file/PersistentType.java b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-common/src/main/java/org/apache/geaflow/file/PersistentType.java index bab295770..6c4240d9c 100644 --- a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-common/src/main/java/org/apache/geaflow/file/PersistentType.java +++ b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-common/src/main/java/org/apache/geaflow/file/PersistentType.java @@ -20,20 +20,12 @@ package org.apache.geaflow.file; public enum PersistentType { - /** - * oss file schema. - */ - OSS, - /** - * hdfs or pangu file schema. - */ - DFS, - /** - * s3 file schema. - */ - S3, - /** - * local file schema, for testing. - */ - LOCAL + /** oss file schema. */ + OSS, + /** hdfs or pangu file schema. */ + DFS, + /** s3 file schema. */ + S3, + /** local file schema, for testing. */ + LOCAL } diff --git a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-dfs/src/main/java/org/apache/geaflow/file/dfs/DfsIO.java b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-dfs/src/main/java/org/apache/geaflow/file/dfs/DfsIO.java index 4649175a1..90777d8ce 100644 --- a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-dfs/src/main/java/org/apache/geaflow/file/dfs/DfsIO.java +++ b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-dfs/src/main/java/org/apache/geaflow/file/dfs/DfsIO.java @@ -19,7 +19,6 @@ package org.apache.geaflow.file.dfs; -import com.google.common.base.Preconditions; import java.io.IOException; import java.io.InputStream; import java.net.URI; @@ -28,6 +27,7 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; + import org.apache.geaflow.common.utils.GsonUtil; import org.apache.geaflow.file.FileConfigKeys; import org.apache.geaflow.file.FileInfo; @@ -43,140 +43,142 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class DfsIO implements IPersistentIO { - - private static final Logger LOGGER = LoggerFactory.getLogger(DfsIO.class); - private static final String DFS_URI_KEY = "fs.defaultFS"; - protected static final String LOCAL_FILE_IMPL = "fs.file.impl"; - protected FileSystem fileSystem; - - public DfsIO() { - - } - - public void init(org.apache.geaflow.common.config.Configuration userConfig) { - String jsonConfig = Preconditions.checkNotNull(userConfig.getString(FileConfigKeys.JSON_CONFIG)); - Map persistConfig = GsonUtil.parse(jsonConfig); - Preconditions.checkArgument(persistConfig.containsKey(DFS_URI_KEY), DFS_URI_KEY + " must be set"); - Configuration conf = new Configuration(); - conf.set(LOCAL_FILE_IMPL, LocalFileSystem.class.getCanonicalName()); - for (Entry entry : persistConfig.entrySet()) { - conf.set(entry.getKey(), entry.getValue()); - } - try { - this.fileSystem = FileSystem.newInstance(getURIFromConf(conf), conf); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - protected URI getURIFromConf(Configuration configuration) throws URISyntaxException { - return new URI(configuration.get(DFS_URI_KEY)); - } - - @Override - public List listFileName(Path path) throws IOException { - List fileNames = new ArrayList<>(); - for (FileStatus status : fileSystem.listStatus(path)) { - fileNames.add(status.getPath().getName()); - } - return fileNames; - } - - @Override - public boolean exists(Path path) throws IOException { - return fileSystem.exists(path); - } - - @Override - public boolean delete(Path path, boolean recursive) throws IOException { - if (exists(path)) { - return this.fileSystem.delete(path, recursive); - } - - return false; - } - - @Override - public boolean renameFile(Path from, Path to) throws IOException { - if (exists(to)) { - this.fileSystem.delete(to, false); - } - return fileSystem.rename(from, to); - } - - @Override - public boolean createNewFile(Path path) throws IOException { - if (!fileSystem.exists(path.getParent())) { - fileSystem.mkdirs(path.getParent()); - } - return fileSystem.createNewFile(path); - } - - @Override - public void copyFromLocalFile(Path local, Path remote) throws IOException { - fileSystem.copyFromLocalFile(false, true, local, remote); - } - - @Override - public void copyToLocalFile(Path remote, Path local) throws IOException { - fileSystem.copyToLocalFile(false, remote, local, true); - } - - @Override - public long getFileSize(Path path) throws IOException { - FileStatus status = fileSystem.getFileStatus(path); - return status.getLen(); - } - - @Override - public long getFileCount(Path path) throws IOException { - ContentSummary summary = this.fileSystem.getContentSummary(path); - return summary.getFileCount(); - } - - @Override - public FileInfo getFileInfo(Path path) throws IOException { - FileStatus status = fileSystem.getFileStatus(path); - return FileInfo.of(status); - } - - @Override - public FileInfo[] listFileInfo(Path path, PathFilter filter) throws IOException { - FileStatus[] fileStatuses = fileSystem.listStatus(path, filter); - FileInfo[] fileInfos = new FileInfo[fileStatuses.length]; - for (int i = 0; i < fileStatuses.length; i++) { - fileInfos[i] = FileInfo.of(fileStatuses[i]); - } - return fileInfos; - } - - @Override - public FileInfo[] listFileInfo(Path path) throws IOException { - FileStatus[] fileStatuses = fileSystem.listStatus(path); - FileInfo[] fileInfos = new FileInfo[fileStatuses.length]; - for (int i = 0; i < fileStatuses.length; i++) { - fileInfos[i] = FileInfo.of(fileStatuses[i]); - } - return fileInfos; - } - - @Override - public InputStream open(Path path) throws IOException { - return fileSystem.open(path); - } - - @Override - public void close() throws IOException { - this.fileSystem.close(); - } +import com.google.common.base.Preconditions; - @Override - public PersistentType getPersistentType() { - return PersistentType.DFS; - } +public class DfsIO implements IPersistentIO { - public FileSystem getFileSystem() { - return fileSystem; - } + private static final Logger LOGGER = LoggerFactory.getLogger(DfsIO.class); + private static final String DFS_URI_KEY = "fs.defaultFS"; + protected static final String LOCAL_FILE_IMPL = "fs.file.impl"; + protected FileSystem fileSystem; + + public DfsIO() {} + + public void init(org.apache.geaflow.common.config.Configuration userConfig) { + String jsonConfig = + Preconditions.checkNotNull(userConfig.getString(FileConfigKeys.JSON_CONFIG)); + Map persistConfig = GsonUtil.parse(jsonConfig); + Preconditions.checkArgument( + persistConfig.containsKey(DFS_URI_KEY), DFS_URI_KEY + " must be set"); + Configuration conf = new Configuration(); + conf.set(LOCAL_FILE_IMPL, LocalFileSystem.class.getCanonicalName()); + for (Entry entry : persistConfig.entrySet()) { + conf.set(entry.getKey(), entry.getValue()); + } + try { + this.fileSystem = FileSystem.newInstance(getURIFromConf(conf), conf); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + protected URI getURIFromConf(Configuration configuration) throws URISyntaxException { + return new URI(configuration.get(DFS_URI_KEY)); + } + + @Override + public List listFileName(Path path) throws IOException { + List fileNames = new ArrayList<>(); + for (FileStatus status : fileSystem.listStatus(path)) { + fileNames.add(status.getPath().getName()); + } + return fileNames; + } + + @Override + public boolean exists(Path path) throws IOException { + return fileSystem.exists(path); + } + + @Override + public boolean delete(Path path, boolean recursive) throws IOException { + if (exists(path)) { + return this.fileSystem.delete(path, recursive); + } + + return false; + } + + @Override + public boolean renameFile(Path from, Path to) throws IOException { + if (exists(to)) { + this.fileSystem.delete(to, false); + } + return fileSystem.rename(from, to); + } + + @Override + public boolean createNewFile(Path path) throws IOException { + if (!fileSystem.exists(path.getParent())) { + fileSystem.mkdirs(path.getParent()); + } + return fileSystem.createNewFile(path); + } + + @Override + public void copyFromLocalFile(Path local, Path remote) throws IOException { + fileSystem.copyFromLocalFile(false, true, local, remote); + } + + @Override + public void copyToLocalFile(Path remote, Path local) throws IOException { + fileSystem.copyToLocalFile(false, remote, local, true); + } + + @Override + public long getFileSize(Path path) throws IOException { + FileStatus status = fileSystem.getFileStatus(path); + return status.getLen(); + } + + @Override + public long getFileCount(Path path) throws IOException { + ContentSummary summary = this.fileSystem.getContentSummary(path); + return summary.getFileCount(); + } + + @Override + public FileInfo getFileInfo(Path path) throws IOException { + FileStatus status = fileSystem.getFileStatus(path); + return FileInfo.of(status); + } + + @Override + public FileInfo[] listFileInfo(Path path, PathFilter filter) throws IOException { + FileStatus[] fileStatuses = fileSystem.listStatus(path, filter); + FileInfo[] fileInfos = new FileInfo[fileStatuses.length]; + for (int i = 0; i < fileStatuses.length; i++) { + fileInfos[i] = FileInfo.of(fileStatuses[i]); + } + return fileInfos; + } + + @Override + public FileInfo[] listFileInfo(Path path) throws IOException { + FileStatus[] fileStatuses = fileSystem.listStatus(path); + FileInfo[] fileInfos = new FileInfo[fileStatuses.length]; + for (int i = 0; i < fileStatuses.length; i++) { + fileInfos[i] = FileInfo.of(fileStatuses[i]); + } + return fileInfos; + } + + @Override + public InputStream open(Path path) throws IOException { + return fileSystem.open(path); + } + + @Override + public void close() throws IOException { + this.fileSystem.close(); + } + + @Override + public PersistentType getPersistentType() { + return PersistentType.DFS; + } + + public FileSystem getFileSystem() { + return fileSystem; + } } diff --git a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-dfs/src/main/java/org/apache/geaflow/file/dfs/LocalIO.java b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-dfs/src/main/java/org/apache/geaflow/file/dfs/LocalIO.java index d88f440b4..2a70899d0 100644 --- a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-dfs/src/main/java/org/apache/geaflow/file/dfs/LocalIO.java +++ b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-dfs/src/main/java/org/apache/geaflow/file/dfs/LocalIO.java @@ -22,6 +22,7 @@ import java.io.File; import java.io.IOException; import java.net.URI; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -35,37 +36,38 @@ public class LocalIO extends DfsIO { - private static final Logger LOGGER = LoggerFactory.getLogger(LocalIO.class); - private static final String LOCAL = "file:///"; - - @Override - public void init(Configuration userConfig) { - String root; - if (userConfig.contains(FileConfigKeys.ROOT)) { - root = userConfig.getString(FileConfigKeys.ROOT); - } else { - root = userConfig.getString(ExecutionConfigKeys.JOB_WORK_PATH) - + userConfig.getString(FileConfigKeys.ROOT); - } + private static final Logger LOGGER = LoggerFactory.getLogger(LocalIO.class); + private static final String LOCAL = "file:///"; - LOGGER.info("use local chk path {}", root); - try { - FileUtils.forceMkdir(new File(root)); - } catch (IOException e) { - throw new GeaflowRuntimeException("mkdir fail " + root, e); - } + @Override + public void init(Configuration userConfig) { + String root; + if (userConfig.contains(FileConfigKeys.ROOT)) { + root = userConfig.getString(FileConfigKeys.ROOT); + } else { + root = + userConfig.getString(ExecutionConfigKeys.JOB_WORK_PATH) + + userConfig.getString(FileConfigKeys.ROOT); + } - org.apache.hadoop.conf.Configuration conf = new org.apache.hadoop.conf.Configuration(); - conf.set(LOCAL_FILE_IMPL, LocalFileSystem.class.getCanonicalName()); - try { - this.fileSystem = FileSystem.newInstance(new URI(LOCAL), conf); - } catch (Exception e) { - throw new RuntimeException(e); - } + LOGGER.info("use local chk path {}", root); + try { + FileUtils.forceMkdir(new File(root)); + } catch (IOException e) { + throw new GeaflowRuntimeException("mkdir fail " + root, e); } - @Override - public PersistentType getPersistentType() { - return PersistentType.LOCAL; + org.apache.hadoop.conf.Configuration conf = new org.apache.hadoop.conf.Configuration(); + conf.set(LOCAL_FILE_IMPL, LocalFileSystem.class.getCanonicalName()); + try { + this.fileSystem = FileSystem.newInstance(new URI(LOCAL), conf); + } catch (Exception e) { + throw new RuntimeException(e); } + } + + @Override + public PersistentType getPersistentType() { + return PersistentType.LOCAL; + } } diff --git a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-dfs/src/test/java/org/apache/geaflow/file/dfs/DfsIOTest.java b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-dfs/src/test/java/org/apache/geaflow/file/dfs/DfsIOTest.java index c2a82fe69..fe01a759c 100644 --- a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-dfs/src/test/java/org/apache/geaflow/file/dfs/DfsIOTest.java +++ b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-dfs/src/test/java/org/apache/geaflow/file/dfs/DfsIOTest.java @@ -23,6 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.utils.GsonUtil; @@ -41,80 +42,78 @@ public class DfsIOTest { - private MiniDFSCluster hdfsCluster; - private String hdfsURI; - - @BeforeClass - public void createHDFS() throws Exception { - org.apache.hadoop.conf.Configuration hdConf = new org.apache.hadoop.conf.Configuration(); - File baseDir = new File("./target/hdfs/hdfsTest").getAbsoluteFile(); - FileUtils.deleteQuietly(baseDir); - hdConf.set(MiniDFSCluster.HDFS_MINIDFS_BASEDIR, baseDir.getAbsolutePath()); - MiniDFSCluster.Builder builder = new MiniDFSCluster.Builder(hdConf); - hdfsCluster = builder.build(); - - hdfsURI = "hdfs://" - + hdfsCluster.getURI().getHost() - + ":" - + hdfsCluster.getNameNodePort() - + "/"; - - Path hdPath = new Path("/test"); - FileSystem hdfs = hdPath.getFileSystem(hdConf); - FSDataOutputStream stream = hdfs.create(hdPath); - for (int i = 0; i < 10; i++) { - stream.write("Hello HDFS\n".getBytes()); - } - stream.close(); + private MiniDFSCluster hdfsCluster; + private String hdfsURI; + + @BeforeClass + public void createHDFS() throws Exception { + org.apache.hadoop.conf.Configuration hdConf = new org.apache.hadoop.conf.Configuration(); + File baseDir = new File("./target/hdfs/hdfsTest").getAbsoluteFile(); + FileUtils.deleteQuietly(baseDir); + hdConf.set(MiniDFSCluster.HDFS_MINIDFS_BASEDIR, baseDir.getAbsolutePath()); + MiniDFSCluster.Builder builder = new MiniDFSCluster.Builder(hdConf); + hdfsCluster = builder.build(); + + hdfsURI = + "hdfs://" + hdfsCluster.getURI().getHost() + ":" + hdfsCluster.getNameNodePort() + "/"; + + Path hdPath = new Path("/test"); + FileSystem hdfs = hdPath.getFileSystem(hdConf); + FSDataOutputStream stream = hdfs.create(hdPath); + for (int i = 0; i < 10; i++) { + stream.write("Hello HDFS\n".getBytes()); } + stream.close(); + } - @AfterClass - public void tearUp() throws Exception { - this.hdfsCluster.shutdown(true); - } + @AfterClass + public void tearUp() throws Exception { + this.hdfsCluster.shutdown(true); + } - private void test(Configuration configuration) throws Exception { - FileUtils.touch(new File("/tmp/README")); - - IPersistentIO persistentIO = PersistentIOBuilder.build(configuration); - String myName = "myName" + System.currentTimeMillis(); - persistentIO.delete(new Path("/geaflow/chk/" + myName + "2"), true); - - for (int i = 0; i < 101; i++) { - persistentIO.copyFromLocalFile(new Path("/tmp/README"), - new Path("/geaflow/chk/" + myName + "/datas/README" + i)); - } - persistentIO.copyFromLocalFile(new Path("/tmp/README"), new Path( - "/geaflow/chk/" + myName + "/0/README")); - persistentIO.copyFromLocalFile(new Path("/tmp/README"), new Path( - "/geaflow/chk/" + myName + "/1/README")); - - persistentIO.renameFile(new Path("/geaflow/chk/" + myName + "/"), new Path( - "/geaflow/chk/" + myName + "2")); - List list = persistentIO.listFileName(new Path("/geaflow/chk/" + myName + "2")); - Assert.assertEquals(list.size(), 3); - - FileInfo[] res = persistentIO.listFileInfo(new Path("/geaflow/chk/" + myName + "2/datas")); - Assert.assertEquals(res.length, 101); - - persistentIO.renameFile(new Path("/geaflow/chk/" + myName + "2/datas/README46"), - new Path("/geaflow/chk/" + myName + "2/datas/MYREADME46")); - Assert.assertTrue(persistentIO.exists(new Path( - "/geaflow/chk/" + myName + "2/datas/MYREADME46"))); - Assert.assertFalse(persistentIO.exists(new Path( - "/geaflow/chk/" + myName + "2/datas/README46"))); - - persistentIO.delete(new Path("/geaflow/chk/" + myName + "2"), true); - } + private void test(Configuration configuration) throws Exception { + FileUtils.touch(new File("/tmp/README")); - @Test - public void testHdfs() throws Exception { - Configuration configuration = new Configuration(); - configuration.put(FileConfigKeys.PERSISTENT_TYPE, "DFS"); + IPersistentIO persistentIO = PersistentIOBuilder.build(configuration); + String myName = "myName" + System.currentTimeMillis(); + persistentIO.delete(new Path("/geaflow/chk/" + myName + "2"), true); - Map config = new HashMap<>(); - config.put("fs.defaultFS", hdfsURI); - configuration.put(FileConfigKeys.JSON_CONFIG, GsonUtil.toJson(config)); - test(configuration); + for (int i = 0; i < 101; i++) { + persistentIO.copyFromLocalFile( + new Path("/tmp/README"), new Path("/geaflow/chk/" + myName + "/datas/README" + i)); } + persistentIO.copyFromLocalFile( + new Path("/tmp/README"), new Path("/geaflow/chk/" + myName + "/0/README")); + persistentIO.copyFromLocalFile( + new Path("/tmp/README"), new Path("/geaflow/chk/" + myName + "/1/README")); + + persistentIO.renameFile( + new Path("/geaflow/chk/" + myName + "/"), new Path("/geaflow/chk/" + myName + "2")); + List list = persistentIO.listFileName(new Path("/geaflow/chk/" + myName + "2")); + Assert.assertEquals(list.size(), 3); + + FileInfo[] res = persistentIO.listFileInfo(new Path("/geaflow/chk/" + myName + "2/datas")); + Assert.assertEquals(res.length, 101); + + persistentIO.renameFile( + new Path("/geaflow/chk/" + myName + "2/datas/README46"), + new Path("/geaflow/chk/" + myName + "2/datas/MYREADME46")); + Assert.assertTrue( + persistentIO.exists(new Path("/geaflow/chk/" + myName + "2/datas/MYREADME46"))); + Assert.assertFalse( + persistentIO.exists(new Path("/geaflow/chk/" + myName + "2/datas/README46"))); + + persistentIO.delete(new Path("/geaflow/chk/" + myName + "2"), true); + } + + @Test + public void testHdfs() throws Exception { + Configuration configuration = new Configuration(); + configuration.put(FileConfigKeys.PERSISTENT_TYPE, "DFS"); + + Map config = new HashMap<>(); + config.put("fs.defaultFS", hdfsURI); + configuration.put(FileConfigKeys.JSON_CONFIG, GsonUtil.toJson(config)); + test(configuration); + } } diff --git a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-dfs/src/test/java/org/apache/geaflow/file/dfs/LocalIOTest.java b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-dfs/src/test/java/org/apache/geaflow/file/dfs/LocalIOTest.java index 688d0c598..ffc6b89eb 100644 --- a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-dfs/src/test/java/org/apache/geaflow/file/dfs/LocalIOTest.java +++ b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-dfs/src/test/java/org/apache/geaflow/file/dfs/LocalIOTest.java @@ -21,6 +21,7 @@ import java.io.File; import java.util.List; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.file.FileConfigKeys; @@ -33,34 +34,38 @@ public class LocalIOTest { - @Test - public void test() throws Exception { - FileUtils.touch(new File("/tmp/README")); + @Test + public void test() throws Exception { + FileUtils.touch(new File("/tmp/README")); - Configuration configuration = new Configuration(); - configuration.put(FileConfigKeys.PERSISTENT_TYPE, "LOCAL"); - IPersistentIO persistentIO = PersistentIOBuilder.build(configuration); - persistentIO.delete(new Path("/tmp/geaflow/chk/myName2"), true); + Configuration configuration = new Configuration(); + configuration.put(FileConfigKeys.PERSISTENT_TYPE, "LOCAL"); + IPersistentIO persistentIO = PersistentIOBuilder.build(configuration); + persistentIO.delete(new Path("/tmp/geaflow/chk/myName2"), true); - for (int i = 0; i < 101; i++) { - persistentIO.copyFromLocalFile(new Path("/tmp/README"), - new Path("/tmp/geaflow/chk/myName/datas/README" + i)); - } - persistentIO.copyFromLocalFile(new Path("/tmp/README"), new Path("/tmp/geaflow/chk/myName/0/README")); - persistentIO.copyFromLocalFile(new Path("/tmp/README"), new Path("/tmp/geaflow/chk/myName/1/README")); + for (int i = 0; i < 101; i++) { + persistentIO.copyFromLocalFile( + new Path("/tmp/README"), new Path("/tmp/geaflow/chk/myName/datas/README" + i)); + } + persistentIO.copyFromLocalFile( + new Path("/tmp/README"), new Path("/tmp/geaflow/chk/myName/0/README")); + persistentIO.copyFromLocalFile( + new Path("/tmp/README"), new Path("/tmp/geaflow/chk/myName/1/README")); - persistentIO.renameFile(new Path("/tmp/geaflow/chk/myName/"), new Path("/tmp/geaflow/chk/myName2")); - List list = persistentIO.listFileName(new Path("/tmp/geaflow/chk/myName2")); - Assert.assertEquals(list.size(), 3); + persistentIO.renameFile( + new Path("/tmp/geaflow/chk/myName/"), new Path("/tmp/geaflow/chk/myName2")); + List list = persistentIO.listFileName(new Path("/tmp/geaflow/chk/myName2")); + Assert.assertEquals(list.size(), 3); - FileInfo[] res = persistentIO.listFileInfo(new Path("/tmp/geaflow/chk/myName2/datas")); - Assert.assertEquals(res.length, 101); + FileInfo[] res = persistentIO.listFileInfo(new Path("/tmp/geaflow/chk/myName2/datas")); + Assert.assertEquals(res.length, 101); - persistentIO.renameFile(new Path("/tmp/geaflow/chk/myName2/datas/README46"), - new Path("/tmp/geaflow/chk/myName2/datas/MYREADME46")); - Assert.assertTrue(persistentIO.exists(new Path("/tmp/geaflow/chk/myName2/datas/MYREADME46"))); - Assert.assertFalse(persistentIO.exists(new Path("/tmp/geaflow/chk/myName2/datas/README46"))); + persistentIO.renameFile( + new Path("/tmp/geaflow/chk/myName2/datas/README46"), + new Path("/tmp/geaflow/chk/myName2/datas/MYREADME46")); + Assert.assertTrue(persistentIO.exists(new Path("/tmp/geaflow/chk/myName2/datas/MYREADME46"))); + Assert.assertFalse(persistentIO.exists(new Path("/tmp/geaflow/chk/myName2/datas/README46"))); - persistentIO.delete(new Path("/tmp/geaflow/chk/myName2"), true); - } + persistentIO.delete(new Path("/tmp/geaflow/chk/myName2"), true); + } } diff --git a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-oss/src/main/java/org/apache/geaflow/file/oss/OssIO.java b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-oss/src/main/java/org/apache/geaflow/file/oss/OssIO.java index aaf5d17ce..fc7a1205f 100644 --- a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-oss/src/main/java/org/apache/geaflow/file/oss/OssIO.java +++ b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-oss/src/main/java/org/apache/geaflow/file/oss/OssIO.java @@ -19,14 +19,6 @@ package org.apache.geaflow.file.oss; -import com.aliyun.oss.OSSClient; -import com.aliyun.oss.model.DeleteObjectsRequest; -import com.aliyun.oss.model.ListObjectsRequest; -import com.aliyun.oss.model.OSSObject; -import com.aliyun.oss.model.OSSObjectSummary; -import com.aliyun.oss.model.ObjectListing; -import com.aliyun.oss.model.ObjectMetadata; -import com.google.common.base.Preconditions; import java.io.ByteArrayInputStream; import java.io.File; import java.io.IOException; @@ -38,6 +30,7 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.utils.GsonUtil; @@ -48,232 +41,239 @@ import org.apache.hadoop.fs.Path; import org.apache.hadoop.fs.PathFilter; -public class OssIO implements IPersistentIO { +import com.aliyun.oss.OSSClient; +import com.aliyun.oss.model.DeleteObjectsRequest; +import com.aliyun.oss.model.ListObjectsRequest; +import com.aliyun.oss.model.OSSObject; +import com.aliyun.oss.model.OSSObjectSummary; +import com.aliyun.oss.model.ObjectListing; +import com.aliyun.oss.model.ObjectMetadata; +import com.google.common.base.Preconditions; - private OSSClient ossClient; - private String bucketName; +public class OssIO implements IPersistentIO { - public OssIO() { + private OSSClient ossClient; + private String bucketName; - } + public OssIO() {} - @Override - public void init(Configuration userConfig) { - String jsonConfig = Preconditions.checkNotNull(userConfig.getString(FileConfigKeys.JSON_CONFIG)); - Map persistConfig = GsonUtil.parse(jsonConfig); + @Override + public void init(Configuration userConfig) { + String jsonConfig = + Preconditions.checkNotNull(userConfig.getString(FileConfigKeys.JSON_CONFIG)); + Map persistConfig = GsonUtil.parse(jsonConfig); - this.bucketName = Configuration.getString(FileConfigKeys.OSS_BUCKET_NAME, persistConfig); - String endpoint = Configuration.getString(FileConfigKeys.OSS_ENDPOINT, persistConfig); - String accessKeyId = Configuration.getString(FileConfigKeys.OSS_ACCESS_ID, persistConfig); - String accessKeySecret = Configuration.getString(FileConfigKeys.OSS_SECRET_KEY, persistConfig); - this.ossClient = new OSSClient(endpoint, accessKeyId, accessKeySecret); - } + this.bucketName = Configuration.getString(FileConfigKeys.OSS_BUCKET_NAME, persistConfig); + String endpoint = Configuration.getString(FileConfigKeys.OSS_ENDPOINT, persistConfig); + String accessKeyId = Configuration.getString(FileConfigKeys.OSS_ACCESS_ID, persistConfig); + String accessKeySecret = Configuration.getString(FileConfigKeys.OSS_SECRET_KEY, persistConfig); + this.ossClient = new OSSClient(endpoint, accessKeyId, accessKeySecret); + } + @Override + public List listFileName(Path path) throws IOException { + FileInfo[] infos = listFileInfo(path); + return Arrays.stream(infos).map(c -> c.getPath().getName()).collect(Collectors.toList()); + } - @Override - public List listFileName(Path path) throws IOException { - FileInfo[] infos = listFileInfo(path); - return Arrays.stream(infos).map(c -> c.getPath().getName()).collect(Collectors.toList()); + @Override + public boolean exists(Path path) throws IOException { + boolean existFile = ossClient.doesObjectExist(bucketName, pathToKey(path)); + if (!existFile) { + ObjectListing objectListing = ossClient.listObjects(bucketName, keyToPrefix(pathToKey(path))); + return objectListing.getObjectSummaries().size() > 0; } + return true; + } - @Override - public boolean exists(Path path) throws IOException { - boolean existFile = ossClient.doesObjectExist(bucketName, pathToKey(path)); - if (!existFile) { - ObjectListing objectListing = ossClient.listObjects(bucketName, keyToPrefix(pathToKey(path))); - return objectListing.getObjectSummaries().size() > 0; - } - return true; - } + @Override + public boolean delete(Path path, boolean recursive) throws IOException { + String key = pathToKey(path); + boolean deleteFlag = false; - @Override - public boolean delete(Path path, boolean recursive) throws IOException { - String key = pathToKey(path); - boolean deleteFlag = false; - - if (recursive) { - String nextMarker = null; - ObjectListing objectListing; - ListObjectsRequest request = new ListObjectsRequest(bucketName); - request.setPrefix(keyToPrefix(key)); - do { - request.setMarker(nextMarker); - Preconditions.checkArgument(request.getPrefix() != null && request.getPrefix().length() > 0); - objectListing = ossClient.listObjects(request); - List sums = objectListing.getObjectSummaries(); - List files = new ArrayList<>(); - for (OSSObjectSummary s : sums) { - files.add(s.getKey()); - } - nextMarker = objectListing.getNextMarker(); - if (!files.isEmpty()) { - ossClient.deleteObjects(new DeleteObjectsRequest(bucketName).withKeys(files)); - deleteFlag = true; - } - } while (objectListing.isTruncated()); - } else { - ossClient.deleteObject(bucketName, key); + if (recursive) { + String nextMarker = null; + ObjectListing objectListing; + ListObjectsRequest request = new ListObjectsRequest(bucketName); + request.setPrefix(keyToPrefix(key)); + do { + request.setMarker(nextMarker); + Preconditions.checkArgument( + request.getPrefix() != null && request.getPrefix().length() > 0); + objectListing = ossClient.listObjects(request); + List sums = objectListing.getObjectSummaries(); + List files = new ArrayList<>(); + for (OSSObjectSummary s : sums) { + files.add(s.getKey()); } - - return deleteFlag; + nextMarker = objectListing.getNextMarker(); + if (!files.isEmpty()) { + ossClient.deleteObjects(new DeleteObjectsRequest(bucketName).withKeys(files)); + deleteFlag = true; + } + } while (objectListing.isTruncated()); + } else { + ossClient.deleteObject(bucketName, key); } - @Override - public boolean renameFile(Path from, Path to) throws IOException { - String fromKey = pathToKey(from); - String toKey = pathToKey(to); - String nextMarker = null; - ObjectListing objectListing; - ListObjectsRequest request = new ListObjectsRequest(bucketName); - request.setPrefix(keyToPrefix(fromKey)); - do { - request.setMarker(nextMarker); - Preconditions.checkArgument(request.getPrefix() != null && request.getPrefix().length() > 0); - objectListing = ossClient.listObjects(request); - List sums = objectListing.getObjectSummaries(); - for (OSSObjectSummary s : sums) { - String key = s.getKey(); - String newKey = key.replace(fromKey, toKey); - ossClient.copyObject(bucketName, key, bucketName, newKey); - ossClient.deleteObject(bucketName, key); - } - nextMarker = objectListing.getNextMarker(); - } while (objectListing.isTruncated()); + return deleteFlag; + } - fromKey = pathToKey(from); - toKey = pathToKey(to); - if (!from.toString().endsWith("/") && !to.toString().endsWith("/") - && ossClient.doesObjectExist(bucketName, fromKey)) { - ossClient.copyObject(bucketName, fromKey, bucketName, toKey); - ossClient.deleteObject(bucketName, fromKey); - } - return true; - } + @Override + public boolean renameFile(Path from, Path to) throws IOException { + String fromKey = pathToKey(from); + String toKey = pathToKey(to); + String nextMarker = null; + ObjectListing objectListing; + ListObjectsRequest request = new ListObjectsRequest(bucketName); + request.setPrefix(keyToPrefix(fromKey)); + do { + request.setMarker(nextMarker); + Preconditions.checkArgument(request.getPrefix() != null && request.getPrefix().length() > 0); + objectListing = ossClient.listObjects(request); + List sums = objectListing.getObjectSummaries(); + for (OSSObjectSummary s : sums) { + String key = s.getKey(); + String newKey = key.replace(fromKey, toKey); + ossClient.copyObject(bucketName, key, bucketName, newKey); + ossClient.deleteObject(bucketName, key); + } + nextMarker = objectListing.getNextMarker(); + } while (objectListing.isTruncated()); - @Override - public boolean createNewFile(Path path) throws IOException { - if (exists(path)) { - return false; - } - ossClient.putObject(bucketName, pathToKey(path), new ByteArrayInputStream(new byte[]{})); - return true; + fromKey = pathToKey(from); + toKey = pathToKey(to); + if (!from.toString().endsWith("/") + && !to.toString().endsWith("/") + && ossClient.doesObjectExist(bucketName, fromKey)) { + ossClient.copyObject(bucketName, fromKey, bucketName, toKey); + ossClient.deleteObject(bucketName, fromKey); } + return true; + } - @Override - public void copyFromLocalFile(Path local, Path remote) throws IOException { - ossClient.putObject(bucketName, pathToKey(remote), new File(local.toString())); + @Override + public boolean createNewFile(Path path) throws IOException { + if (exists(path)) { + return false; } + ossClient.putObject(bucketName, pathToKey(path), new ByteArrayInputStream(new byte[] {})); + return true; + } - @Override - public void copyToLocalFile(Path remote, Path local) throws IOException { - FileUtils.copyInputStreamToFile(open(remote), new File(local.toString())); - } + @Override + public void copyFromLocalFile(Path local, Path remote) throws IOException { + ossClient.putObject(bucketName, pathToKey(remote), new File(local.toString())); + } - @Override - public long getFileSize(Path path) throws IOException { - OSSObject ossObject = ossClient.getObject(bucketName, pathToKey(path)); - return ossObject.getObjectMetadata().getContentLength(); - } + @Override + public void copyToLocalFile(Path remote, Path local) throws IOException { + FileUtils.copyInputStreamToFile(open(remote), new File(local.toString())); + } - @Override - public long getFileCount(Path path) throws IOException { - long count = 0; - String nextMarker = null; - ObjectListing objectListing; - ListObjectsRequest request = new ListObjectsRequest(bucketName); - request.setPrefix(keyToPrefix(pathToKey(path))); - do { - request.setMarker(nextMarker); - objectListing = ossClient.listObjects(request); - count += objectListing.getObjectSummaries().size(); - nextMarker = objectListing.getNextMarker(); - } while (objectListing.isTruncated()); + @Override + public long getFileSize(Path path) throws IOException { + OSSObject ossObject = ossClient.getObject(bucketName, pathToKey(path)); + return ossObject.getObjectMetadata().getContentLength(); + } - return count; - } + @Override + public long getFileCount(Path path) throws IOException { + long count = 0; + String nextMarker = null; + ObjectListing objectListing; + ListObjectsRequest request = new ListObjectsRequest(bucketName); + request.setPrefix(keyToPrefix(pathToKey(path))); + do { + request.setMarker(nextMarker); + objectListing = ossClient.listObjects(request); + count += objectListing.getObjectSummaries().size(); + nextMarker = objectListing.getNextMarker(); + } while (objectListing.isTruncated()); - @Override - public FileInfo getFileInfo(Path path) throws IOException { - ObjectMetadata obj = ossClient.getObjectMetadata(bucketName, pathToKey(path)); - return FileInfo.of() - .withPath(path) - .withLength(obj.getContentLength()) - .withModifiedTime(obj.getLastModified().getTime()); - } + return count; + } - @Override - public FileInfo[] listFileInfo(Path path, PathFilter filter) throws IOException { - List res = Arrays.asList(listFileInfo(path)); - return res.stream().filter(c -> filter.accept(c.getPath())).toArray(FileInfo[]::new); - } + @Override + public FileInfo getFileInfo(Path path) throws IOException { + ObjectMetadata obj = ossClient.getObjectMetadata(bucketName, pathToKey(path)); + return FileInfo.of() + .withPath(path) + .withLength(obj.getContentLength()) + .withModifiedTime(obj.getLastModified().getTime()); + } - @Override - public FileInfo[] listFileInfo(Path path) throws IOException { - Set res = new HashSet<>(); - String nextMarker = null; - ObjectListing objectListing; - ListObjectsRequest request = new ListObjectsRequest(bucketName); - request.setPrefix(keyToPrefix(pathToKey(path))); - int prefixLen = request.getPrefix().length(); - do { - request.setMarker(nextMarker); - objectListing = ossClient.listObjects(request); - List sums = objectListing.getObjectSummaries(); - for (OSSObjectSummary s : sums) { - String str = s.getKey().substring(prefixLen); - int nextPos = str.indexOf('/'); - Path filePath; - long modifiedTime; - if (nextPos == -1) { - filePath = new Path(keyToPath(s.getKey())); - modifiedTime = s.getLastModified().getTime(); - } else { - filePath = new Path(keyToPath(request.getPrefix() + str.substring(0, nextPos))); - modifiedTime = 0; - } - FileInfo fileInfo = FileInfo.of() - .withPath(filePath) - .withLength(s.getSize()) - .withModifiedTime(modifiedTime); - res.add(fileInfo); - } - nextMarker = objectListing.getNextMarker(); - } while (objectListing.isTruncated()); - return res.toArray(new FileInfo[0]); - } + @Override + public FileInfo[] listFileInfo(Path path, PathFilter filter) throws IOException { + List res = Arrays.asList(listFileInfo(path)); + return res.stream().filter(c -> filter.accept(c.getPath())).toArray(FileInfo[]::new); + } - @Override - public InputStream open(Path path) throws IOException { - OSSObject ossObject = ossClient.getObject(bucketName, pathToKey(path)); - return ossObject.getObjectContent(); - } + @Override + public FileInfo[] listFileInfo(Path path) throws IOException { + Set res = new HashSet<>(); + String nextMarker = null; + ObjectListing objectListing; + ListObjectsRequest request = new ListObjectsRequest(bucketName); + request.setPrefix(keyToPrefix(pathToKey(path))); + int prefixLen = request.getPrefix().length(); + do { + request.setMarker(nextMarker); + objectListing = ossClient.listObjects(request); + List sums = objectListing.getObjectSummaries(); + for (OSSObjectSummary s : sums) { + String str = s.getKey().substring(prefixLen); + int nextPos = str.indexOf('/'); + Path filePath; + long modifiedTime; + if (nextPos == -1) { + filePath = new Path(keyToPath(s.getKey())); + modifiedTime = s.getLastModified().getTime(); + } else { + filePath = new Path(keyToPath(request.getPrefix() + str.substring(0, nextPos))); + modifiedTime = 0; + } + FileInfo fileInfo = + FileInfo.of().withPath(filePath).withLength(s.getSize()).withModifiedTime(modifiedTime); + res.add(fileInfo); + } + nextMarker = objectListing.getNextMarker(); + } while (objectListing.isTruncated()); + return res.toArray(new FileInfo[0]); + } - @Override - public void close() throws IOException { - this.ossClient.shutdown(); - } + @Override + public InputStream open(Path path) throws IOException { + OSSObject ossObject = ossClient.getObject(bucketName, pathToKey(path)); + return ossObject.getObjectContent(); + } - @Override - public PersistentType getPersistentType() { - return PersistentType.OSS; - } + @Override + public void close() throws IOException { + this.ossClient.shutdown(); + } - private String keyToPath(String key) { - return "/" + key; - } + @Override + public PersistentType getPersistentType() { + return PersistentType.OSS; + } - private String pathToKey(Path path) { - String strPath = path.toUri().getPath(); - if (strPath.charAt(0) == '/') { - return strPath.substring(1); - } - return strPath; + private String keyToPath(String key) { + return "/" + key; + } + + private String pathToKey(Path path) { + String strPath = path.toUri().getPath(); + if (strPath.charAt(0) == '/') { + return strPath.substring(1); } + return strPath; + } - private String keyToPrefix(String key) { - if (key.charAt(key.length() - 1) == '/') { - return key; - } - return key + "/"; + private String keyToPrefix(String key) { + if (key.charAt(key.length() - 1) == '/') { + return key; } + return key + "/"; + } } diff --git a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-oss/src/test/java/org/apache/geaflow/file/oss/OssIOTest.java b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-oss/src/test/java/org/apache/geaflow/file/oss/OssIOTest.java index f157fd40b..f1fb60305 100644 --- a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-oss/src/test/java/org/apache/geaflow/file/oss/OssIOTest.java +++ b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-oss/src/test/java/org/apache/geaflow/file/oss/OssIOTest.java @@ -23,6 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.utils.GsonUtil; @@ -36,53 +37,53 @@ public class OssIOTest { - @Test(enabled = false) - public void test() throws Exception { - FileUtils.touch(new File("/tmp/README")); + @Test(enabled = false) + public void test() throws Exception { + FileUtils.touch(new File("/tmp/README")); - IPersistentIO persistentIO = PersistentIOBuilder.build(getOSSConfig()); - String myName = "myName" + System.currentTimeMillis(); - persistentIO.delete(new Path("/geaflow/chk/" + myName + "2"), true); + IPersistentIO persistentIO = PersistentIOBuilder.build(getOSSConfig()); + String myName = "myName" + System.currentTimeMillis(); + persistentIO.delete(new Path("/geaflow/chk/" + myName + "2"), true); - for (int i = 0; i < 101; i++) { - persistentIO.copyFromLocalFile(new Path("/tmp/README"), - new Path("/geaflow/chk/" + myName + "/datas/README" + i)); - } - persistentIO.copyFromLocalFile(new Path("/tmp/README"), new Path( - "/geaflow/chk/" + myName + "/0/README")); - persistentIO.copyFromLocalFile(new Path("/tmp/README"), new Path( - "/geaflow/chk/" + myName + "/1/README")); + for (int i = 0; i < 101; i++) { + persistentIO.copyFromLocalFile( + new Path("/tmp/README"), new Path("/geaflow/chk/" + myName + "/datas/README" + i)); + } + persistentIO.copyFromLocalFile( + new Path("/tmp/README"), new Path("/geaflow/chk/" + myName + "/0/README")); + persistentIO.copyFromLocalFile( + new Path("/tmp/README"), new Path("/geaflow/chk/" + myName + "/1/README")); - persistentIO.renameFile(new Path("/geaflow/chk/" + myName + "/"), new Path( - "/geaflow/chk/" + myName + "2")); - List list = persistentIO.listFileName(new Path("/geaflow/chk/" + myName + "2")); - Assert.assertEquals(list.size(), 3); + persistentIO.renameFile( + new Path("/geaflow/chk/" + myName + "/"), new Path("/geaflow/chk/" + myName + "2")); + List list = persistentIO.listFileName(new Path("/geaflow/chk/" + myName + "2")); + Assert.assertEquals(list.size(), 3); - FileInfo[] res = persistentIO.listFileInfo(new Path("/geaflow/chk/" + myName + "2/datas")); - Assert.assertEquals(res.length, 101); + FileInfo[] res = persistentIO.listFileInfo(new Path("/geaflow/chk/" + myName + "2/datas")); + Assert.assertEquals(res.length, 101); - persistentIO.renameFile(new Path("/geaflow/chk/" + myName + "2/datas/README46"), - new Path("/geaflow/chk/" + myName + "2/datas/MYREADME46")); - Assert.assertTrue(persistentIO.exists(new Path( - "/geaflow/chk/" + myName + "2/datas/MYREADME46"))); - Assert.assertFalse(persistentIO.exists(new Path( - "/geaflow/chk/" + myName + "2/datas/README46"))); + persistentIO.renameFile( + new Path("/geaflow/chk/" + myName + "2/datas/README46"), + new Path("/geaflow/chk/" + myName + "2/datas/MYREADME46")); + Assert.assertTrue( + persistentIO.exists(new Path("/geaflow/chk/" + myName + "2/datas/MYREADME46"))); + Assert.assertFalse( + persistentIO.exists(new Path("/geaflow/chk/" + myName + "2/datas/README46"))); - persistentIO.delete(new Path("/geaflow/chk/" + myName + "2"), true); - } + persistentIO.delete(new Path("/geaflow/chk/" + myName + "2"), true); + } - private static Configuration getOSSConfig() throws Exception { - Map config = new HashMap<>(); - // set private account info. - config.put(FileConfigKeys.OSS_BUCKET_NAME.getKey(), ""); - config.put(FileConfigKeys.OSS_ENDPOINT.getKey(), ""); - - config.put(FileConfigKeys.OSS_ACCESS_ID.getKey(), ""); - config.put(FileConfigKeys.OSS_SECRET_KEY.getKey(), ""); - Configuration configuration = new Configuration(); - configuration.put(FileConfigKeys.PERSISTENT_TYPE, "OSS"); - configuration.put(FileConfigKeys.JSON_CONFIG, GsonUtil.toJson(config)); - return configuration; - } + private static Configuration getOSSConfig() throws Exception { + Map config = new HashMap<>(); + // set private account info. + config.put(FileConfigKeys.OSS_BUCKET_NAME.getKey(), ""); + config.put(FileConfigKeys.OSS_ENDPOINT.getKey(), ""); + config.put(FileConfigKeys.OSS_ACCESS_ID.getKey(), ""); + config.put(FileConfigKeys.OSS_SECRET_KEY.getKey(), ""); + Configuration configuration = new Configuration(); + configuration.put(FileConfigKeys.PERSISTENT_TYPE, "OSS"); + configuration.put(FileConfigKeys.JSON_CONFIG, GsonUtil.toJson(config)); + return configuration; + } } diff --git a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-s3/src/main/java/org/apache/geaflow/file/s3/S3IO.java b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-s3/src/main/java/org/apache/geaflow/file/s3/S3IO.java index bb6757a19..c9b0b7223 100644 --- a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-s3/src/main/java/org/apache/geaflow/file/s3/S3IO.java +++ b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-s3/src/main/java/org/apache/geaflow/file/s3/S3IO.java @@ -19,8 +19,6 @@ package org.apache.geaflow.file.s3; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; import java.io.File; import java.io.IOException; import java.io.InputStream; @@ -34,6 +32,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.stream.Collectors; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.utils.GsonUtil; @@ -45,6 +44,10 @@ import org.apache.hadoop.fs.PathFilter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; + import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.core.async.AsyncRequestBody; @@ -73,477 +76,513 @@ public class S3IO implements IPersistentIO { - private static final Logger LOGGER = LoggerFactory.getLogger(S3IO.class); - - public static final int S3_MAX_RETURN_KEYS = 1000; - private static final String PATH_DELIMITER = "/"; - private static final String LOCAL_RENAME_DELIMITER = "-"; - private static final String LOCAL_RENAME_DIR = "/tmp"; - - private S3AsyncClient s3Client; - private S3TransferManager transferManager; - private String bucketName; - private int inputStreamChunkSize; - - - public static String filePathToKey(String bucketName, Path dirPath) { - String pathStr = dirPath.toString(); - pathStr = pathStr.startsWith(bucketName) ? pathStr.substring(bucketName.length()) : pathStr; - pathStr = pathStr.startsWith(PATH_DELIMITER) ? pathStr.substring(1) : pathStr; - - return pathStr; - } - - public static String dirPathToPrefix(String bucketName, Path dirPath) { - String pathStr = dirPath.toString(); - pathStr = pathStr.startsWith(bucketName) ? pathStr.substring(bucketName.length()) : pathStr; - pathStr = pathStr.startsWith(PATH_DELIMITER) ? pathStr.substring(1) : pathStr; - pathStr = pathStr.endsWith(PATH_DELIMITER) ? pathStr : pathStr + PATH_DELIMITER; - - return PATH_DELIMITER.equals(pathStr) ? "" : pathStr; - } - - public static Path keyToFilePath(String bucketName, String key) { - return new Path(bucketName + PATH_DELIMITER + key); - } - - @Override - public void init(Configuration userConfig) { - String jsonConfig = Preconditions.checkNotNull( - userConfig.getString(FileConfigKeys.JSON_CONFIG)); - Map persistConfig = GsonUtil.parse(jsonConfig); - - this.bucketName = Configuration.getString(FileConfigKeys.S3_BUCKET_NAME, persistConfig); - String endpoint = Configuration.getString(FileConfigKeys.S3_ENDPOINT, persistConfig); - String accessKeyId = Configuration.getString(FileConfigKeys.S3_ACCESS_KEY_ID, - persistConfig); - String accessKey = Configuration.getString(FileConfigKeys.S3_ACCESS_KEY, persistConfig); - long minimumPartSizeInBytes = Configuration.getLong(FileConfigKeys.S3_MIN_PART_SIZE, - persistConfig); - this.inputStreamChunkSize = Configuration.getInteger( - FileConfigKeys.S3_INPUT_STREAM_CHUNK_SIZE, persistConfig); - - AwsBasicCredentials awsCreds = AwsBasicCredentials.create(accessKeyId, accessKey); - - this.s3Client = S3AsyncClient.builder().endpointOverride(URI.create(endpoint)) + private static final Logger LOGGER = LoggerFactory.getLogger(S3IO.class); + + public static final int S3_MAX_RETURN_KEYS = 1000; + private static final String PATH_DELIMITER = "/"; + private static final String LOCAL_RENAME_DELIMITER = "-"; + private static final String LOCAL_RENAME_DIR = "/tmp"; + + private S3AsyncClient s3Client; + private S3TransferManager transferManager; + private String bucketName; + private int inputStreamChunkSize; + + public static String filePathToKey(String bucketName, Path dirPath) { + String pathStr = dirPath.toString(); + pathStr = pathStr.startsWith(bucketName) ? pathStr.substring(bucketName.length()) : pathStr; + pathStr = pathStr.startsWith(PATH_DELIMITER) ? pathStr.substring(1) : pathStr; + + return pathStr; + } + + public static String dirPathToPrefix(String bucketName, Path dirPath) { + String pathStr = dirPath.toString(); + pathStr = pathStr.startsWith(bucketName) ? pathStr.substring(bucketName.length()) : pathStr; + pathStr = pathStr.startsWith(PATH_DELIMITER) ? pathStr.substring(1) : pathStr; + pathStr = pathStr.endsWith(PATH_DELIMITER) ? pathStr : pathStr + PATH_DELIMITER; + + return PATH_DELIMITER.equals(pathStr) ? "" : pathStr; + } + + public static Path keyToFilePath(String bucketName, String key) { + return new Path(bucketName + PATH_DELIMITER + key); + } + + @Override + public void init(Configuration userConfig) { + String jsonConfig = + Preconditions.checkNotNull(userConfig.getString(FileConfigKeys.JSON_CONFIG)); + Map persistConfig = GsonUtil.parse(jsonConfig); + + this.bucketName = Configuration.getString(FileConfigKeys.S3_BUCKET_NAME, persistConfig); + String endpoint = Configuration.getString(FileConfigKeys.S3_ENDPOINT, persistConfig); + String accessKeyId = Configuration.getString(FileConfigKeys.S3_ACCESS_KEY_ID, persistConfig); + String accessKey = Configuration.getString(FileConfigKeys.S3_ACCESS_KEY, persistConfig); + long minimumPartSizeInBytes = + Configuration.getLong(FileConfigKeys.S3_MIN_PART_SIZE, persistConfig); + this.inputStreamChunkSize = + Configuration.getInteger(FileConfigKeys.S3_INPUT_STREAM_CHUNK_SIZE, persistConfig); + + AwsBasicCredentials awsCreds = AwsBasicCredentials.create(accessKeyId, accessKey); + + this.s3Client = + S3AsyncClient.builder() + .endpointOverride(URI.create(endpoint)) .region(Region.of(Configuration.getString(FileConfigKeys.S3_REGION, persistConfig))) - .credentialsProvider(StaticCredentialsProvider.create(awsCreds)).forcePathStyle(true) + .credentialsProvider(StaticCredentialsProvider.create(awsCreds)) + .forcePathStyle(true) .multipartConfiguration( - MultipartConfiguration.builder().minimumPartSizeInBytes(minimumPartSizeInBytes) - .apiCallBufferSizeInBytes(minimumPartSizeInBytes).build()) - .serviceConfiguration(builder -> builder.checksumValidationEnabled(false)).build(); - this.transferManager = S3TransferManager.builder().s3Client(s3Client).build(); - - checkAndCreateBucket(); - } - - public void checkAndCreateBucket() { - if (bucketNotExists()) { - CreateBucketRequest createBucketRequest = CreateBucketRequest.builder() - .bucket(bucketName).build(); - try { - s3Client.createBucket(createBucketRequest); - LOGGER.info("create new bucket success, bucket name: {}", bucketName); - } catch (Throwable e) { - throw new GeaflowRuntimeException("Failed to create bucket: " + bucketName, e); - } - } - } - - private boolean bucketNotExists() { - HeadBucketRequest headBucketRequest = HeadBucketRequest.builder().bucket(bucketName) + MultipartConfiguration.builder() + .minimumPartSizeInBytes(minimumPartSizeInBytes) + .apiCallBufferSizeInBytes(minimumPartSizeInBytes) + .build()) + .serviceConfiguration(builder -> builder.checksumValidationEnabled(false)) .build(); - CompletableFuture future = s3Client.headBucket(headBucketRequest) - .handle((response, error) -> { - if (error != null) { + this.transferManager = S3TransferManager.builder().s3Client(s3Client).build(); + + checkAndCreateBucket(); + } + + public void checkAndCreateBucket() { + if (bucketNotExists()) { + CreateBucketRequest createBucketRequest = + CreateBucketRequest.builder().bucket(bucketName).build(); + try { + s3Client.createBucket(createBucketRequest); + LOGGER.info("create new bucket success, bucket name: {}", bucketName); + } catch (Throwable e) { + throw new GeaflowRuntimeException("Failed to create bucket: " + bucketName, e); + } + } + } + + private boolean bucketNotExists() { + HeadBucketRequest headBucketRequest = HeadBucketRequest.builder().bucket(bucketName).build(); + CompletableFuture future = + s3Client + .headBucket(headBucketRequest) + .handle( + (response, error) -> { + if (error != null) { Throwable cause = error.getCause(); if (cause instanceof NoSuchBucketException) { - return true; + return true; } throw new GeaflowRuntimeException(error); - } - LOGGER.info("bucket already exists, bucket name: {}", bucketName); - return false; - }); - - try { - return future.get(); - } catch (InterruptedException | ExecutionException e) { - throw new GeaflowRuntimeException( - "Error checking if bucket exists, bucket name: " + bucketName, e); - } + } + LOGGER.info("bucket already exists, bucket name: {}", bucketName); + return false; + }); + + try { + return future.get(); + } catch (InterruptedException | ExecutionException e) { + throw new GeaflowRuntimeException( + "Error checking if bucket exists, bucket name: " + bucketName, e); } + } - @Override - public List listFileName(Path path) throws IOException { - FileInfo[] fileInfos = listFileInfo(path); - return Arrays.stream(fileInfos).map(c -> c.getPath().getName()) - .collect(Collectors.toList()); - } + @Override + public List listFileName(Path path) throws IOException { + FileInfo[] fileInfos = listFileInfo(path); + return Arrays.stream(fileInfos).map(c -> c.getPath().getName()).collect(Collectors.toList()); + } - @Override - public boolean exists(Path path) throws IOException { - return isFile(path) || isDirectory(path); - } + @Override + public boolean exists(Path path) throws IOException { + return isFile(path) || isDirectory(path); + } - public boolean isFile(Path path) throws IOException { - String key = filePathToKey(bucketName, path); + public boolean isFile(Path path) throws IOException { + String key = filePathToKey(bucketName, path); - if (key.isEmpty()) { - return false; - } + if (key.isEmpty()) { + return false; + } - HeadObjectRequest request = HeadObjectRequest.builder().bucket(bucketName).key(key).build(); + HeadObjectRequest request = HeadObjectRequest.builder().bucket(bucketName).key(key).build(); - CompletableFuture future = s3Client.headObject(request) - .handle((response, error) -> { - if (error != null) { + CompletableFuture future = + s3Client + .headObject(request) + .handle( + (response, error) -> { + if (error != null) { if (error.getCause() instanceof NoSuchKeyException) { - return false; + return false; } throw new GeaflowRuntimeException(error); - } - return true; - }); - - try { - return future.get(); - } catch (Throwable e) { - throw new IOException("Occur error in isDirectory, path: " + path, e); - } + } + return true; + }); + + try { + return future.get(); + } catch (Throwable e) { + throw new IOException("Occur error in isDirectory, path: " + path, e); } + } + + public boolean isDirectory(Path path) throws IOException { + String prefix = dirPathToPrefix(bucketName, path); + String key = filePathToKey(bucketName, path); + ListObjectsV2Request request = + ListObjectsV2Request.builder() + .bucket(bucketName) + .prefix(prefix) + .delimiter(PATH_DELIMITER) + .maxKeys(1) + .build(); - public boolean isDirectory(Path path) throws IOException { - String prefix = dirPathToPrefix(bucketName, path); - String key = filePathToKey(bucketName, path); - ListObjectsV2Request request = ListObjectsV2Request.builder().bucket(bucketName) - .prefix(prefix).delimiter(PATH_DELIMITER).maxKeys(1).build(); - - CompletableFuture future = s3Client.listObjectsV2(request); - ListObjectsV2Response response; + CompletableFuture future = s3Client.listObjectsV2(request); + ListObjectsV2Response response; - try { - response = future.get(); - } catch (Throwable e) { - throw new IOException("Occur error in isDirectory, path: " + path, e); - } - - return !response.commonPrefixes().isEmpty() || (!response.contents().isEmpty() - && !response.contents().get(0).key().equals(key)); + try { + response = future.get(); + } catch (Throwable e) { + throw new IOException("Occur error in isDirectory, path: " + path, e); } - @Override - public boolean delete(Path path, boolean recursive) throws IOException { - if (isFile(path)) { - deleteFileOrDirectory(path, true); - return true; - } - - if (isDirectory(path)) { - if (!recursive) { - throw new IOException("path is a directory, but recursive is false, path: " + path); - } - deleteFileOrDirectory(path, false); - return true; - } + return !response.commonPrefixes().isEmpty() + || (!response.contents().isEmpty() && !response.contents().get(0).key().equals(key)); + } - return false; + @Override + public boolean delete(Path path, boolean recursive) throws IOException { + if (isFile(path)) { + deleteFileOrDirectory(path, true); + return true; } - private void deleteFileOrDirectory(Path path, boolean isFile) throws IOException { - if (isFile) { - deleteFile(filePathToKey(bucketName, path)); - return; - } - - String prefix = dirPathToPrefix(bucketName, path); - String continuationToken = null; - - ListObjectsV2Request.Builder builder = ListObjectsV2Request.builder().bucket(bucketName) - .prefix(prefix).maxKeys(S3_MAX_RETURN_KEYS); - - do { - if (continuationToken != null) { - builder.continuationToken(continuationToken); - } - - CompletableFuture future = s3Client.listObjectsV2( - builder.build()); - ListObjectsV2Response response; - try { - response = future.get(); - } catch (Throwable e) { - throw new IOException("Failed to list objects, dir path: " + path, e); - } - - for (S3Object s3Object : response.contents()) { - deleteFile(s3Object.key()); - } - continuationToken = response.nextContinuationToken(); - } while (continuationToken != null); + if (isDirectory(path)) { + if (!recursive) { + throw new IOException("path is a directory, but recursive is false, path: " + path); + } + deleteFileOrDirectory(path, false); + return true; } - public void deleteFile(String key) throws IOException { - CompletableFuture future = s3Client.deleteObject( - DeleteObjectRequest.builder().bucket(bucketName).key(key).build()); + return false; + } - try { - future.get(); - } catch (Throwable e) { - throw new IOException( - "Failed to delete file, file path: " + keyToFilePath(bucketName, key), e); - } + private void deleteFileOrDirectory(Path path, boolean isFile) throws IOException { + if (isFile) { + deleteFile(filePathToKey(bucketName, path)); + return; } - @Override - public boolean createNewFile(Path path) throws IOException { - String key = filePathToKey(bucketName, path); + String prefix = dirPathToPrefix(bucketName, path); + String continuationToken = null; + + ListObjectsV2Request.Builder builder = + ListObjectsV2Request.builder() + .bucket(bucketName) + .prefix(prefix) + .maxKeys(S3_MAX_RETURN_KEYS); + + do { + if (continuationToken != null) { + builder.continuationToken(continuationToken); + } + + CompletableFuture future = s3Client.listObjectsV2(builder.build()); + ListObjectsV2Response response; + try { + response = future.get(); + } catch (Throwable e) { + throw new IOException("Failed to list objects, dir path: " + path, e); + } + + for (S3Object s3Object : response.contents()) { + deleteFile(s3Object.key()); + } + continuationToken = response.nextContinuationToken(); + } while (continuationToken != null); + } + + public void deleteFile(String key) throws IOException { + CompletableFuture future = + s3Client.deleteObject(DeleteObjectRequest.builder().bucket(bucketName).key(key).build()); + + try { + future.get(); + } catch (Throwable e) { + throw new IOException( + "Failed to delete file, file path: " + keyToFilePath(bucketName, key), e); + } + } - if (exists(path)) { - return false; - } + @Override + public boolean createNewFile(Path path) throws IOException { + String key = filePathToKey(bucketName, path); - PutObjectRequest request = PutObjectRequest.builder().bucket(bucketName).key(key).build(); + if (exists(path)) { + return false; + } - CompletableFuture future = s3Client.putObject(request, - AsyncRequestBody.fromBytes(new byte[0])); + PutObjectRequest request = PutObjectRequest.builder().bucket(bucketName).key(key).build(); - try { - future.get(); - } catch (Throwable e) { - throw new IOException("Fail to create a new file, file path: " + path, e); - } + CompletableFuture future = + s3Client.putObject(request, AsyncRequestBody.fromBytes(new byte[0])); - return true; + try { + future.get(); + } catch (Throwable e) { + throw new IOException("Fail to create a new file, file path: " + path, e); } - @Override - public long getFileSize(Path path) throws IOException { - return getFileInfo(path).getLength(); - } + return true; + } - @Override - public long getFileCount(Path path) throws IOException { - String key = filePathToKey(bucketName, path); - long fileCount = 0; - String continuationToken = null; + @Override + public long getFileSize(Path path) throws IOException { + return getFileInfo(path).getLength(); + } - ListObjectsV2Request.Builder builder = ListObjectsV2Request.builder().bucket(bucketName) - .prefix(key); + @Override + public long getFileCount(Path path) throws IOException { + String key = filePathToKey(bucketName, path); + long fileCount = 0; + String continuationToken = null; - do { - if (continuationToken != null) { - builder.continuationToken(continuationToken); - } + ListObjectsV2Request.Builder builder = + ListObjectsV2Request.builder().bucket(bucketName).prefix(key); - CompletableFuture future = s3Client.listObjectsV2( - builder.build()); - ListObjectsV2Response response; + do { + if (continuationToken != null) { + builder.continuationToken(continuationToken); + } - try { - response = future.get(); - } catch (Throwable e) { - throw new IOException("Failed to get content summary, dir path: " + path, e); - } + CompletableFuture future = s3Client.listObjectsV2(builder.build()); + ListObjectsV2Response response; - fileCount += response.contents().size(); - continuationToken = response.nextContinuationToken(); - } while (continuationToken != null); + try { + response = future.get(); + } catch (Throwable e) { + throw new IOException("Failed to get content summary, dir path: " + path, e); + } - return fileCount; - } + fileCount += response.contents().size(); + continuationToken = response.nextContinuationToken(); + } while (continuationToken != null); - @Override - public FileInfo getFileInfo(Path path) throws IOException { - String key = filePathToKey(bucketName, path); + return fileCount; + } - CompletableFuture future = s3Client.listObjectsV2( - ListObjectsV2Request.builder().bucket(bucketName).prefix(key).build()); - ListObjectsV2Response response; - - try { - response = future.get(); - } catch (Throwable e) { - throw new IOException("Failed to get file status, file path: " + path, e); - } - - if (response.contents().isEmpty()) { - throw new IOException("File not found: " + path); - } - S3Object s3Object = response.contents().get(0); + @Override + public FileInfo getFileInfo(Path path) throws IOException { + String key = filePathToKey(bucketName, path); - if (!s3Object.key().equals(key)) { - throw new IOException("File not found: " + path); - } + CompletableFuture future = + s3Client.listObjectsV2( + ListObjectsV2Request.builder().bucket(bucketName).prefix(key).build()); + ListObjectsV2Response response; - return FileInfo.of().withPath(path).withLength(s3Object.size()) - .withModifiedTime(s3Object.lastModified().toEpochMilli()); + try { + response = future.get(); + } catch (Throwable e) { + throw new IOException("Failed to get file status, file path: " + path, e); } - @Override - public FileInfo[] listFileInfo(Path path) throws IOException { - return listFileInfo(path, null); + if (response.contents().isEmpty()) { + throw new IOException("File not found: " + path); } + S3Object s3Object = response.contents().get(0); - @Override - public FileInfo[] listFileInfo(Path path, PathFilter filter) throws IOException { - String prefix = dirPathToPrefix(bucketName, path); - - List fileInfoList = new ArrayList<>(); - String continuationToken = null; - - ListObjectsV2Request.Builder builder = ListObjectsV2Request.builder().bucket(bucketName) - .prefix(prefix).delimiter(PATH_DELIMITER).maxKeys(S3_MAX_RETURN_KEYS); - - do { - if (continuationToken != null) { - builder.continuationToken(continuationToken); - } - - CompletableFuture future = s3Client.listObjectsV2( - builder.build()); - ListObjectsV2Response response; - - try { - response = future.get(); - } catch (Throwable e) { - throw new IOException("Failed to list file status, dir path: " + path, e); - } - - for (S3Object s3Object : response.contents()) { - Path filePath = keyToFilePath(bucketName, s3Object.key()); - if (filter == null || filter.accept(filePath)) { - fileInfoList.add(FileInfo.of().withPath(filePath).withLength(s3Object.size()) - .withModifiedTime(s3Object.lastModified().toEpochMilli())); - } - } - - for (CommonPrefix commonPrefix : response.commonPrefixes()) { - Path dirPath = keyToFilePath(bucketName, commonPrefix.prefix()); - if (filter == null || filter.accept(dirPath)) { - fileInfoList.add( - FileInfo.of().withPath(dirPath).withLength(0).withModifiedTime(0)); - } - } - continuationToken = response.nextContinuationToken(); - } while (continuationToken != null); - - return fileInfoList.toArray(new FileInfo[0]); + if (!s3Object.key().equals(key)) { + throw new IOException("File not found: " + path); } - - @Override - public void copyFromLocalFile(Path local, Path remote) throws IOException { - String key = filePathToKey(bucketName, remote); - java.nio.file.Path fromPath = Paths.get(local.toString()); - - int count = 0; - int maxTries = 3; - File localFile = new File(local.toString()); - long localFileLen = localFile.length(); - while (true) { - long start = System.currentTimeMillis(); - - UploadFileRequest uploadFileRequest = UploadFileRequest.builder() - .putObjectRequest(req -> req.bucket(bucketName).key(key)).source(fromPath).build(); - FileUpload upload = transferManager.uploadFile(uploadFileRequest); - - try { - upload.completionFuture().get(); - LOGGER.info("upload to s3: size {} KB took {} ms {} -> {}", - String.format("%.1f", localFileLen / 1024.0), - System.currentTimeMillis() - start, local, remote); - return; - } catch (Throwable ex) { - if (++count == maxTries) { - LOGGER.error("failed to upload file, from: {}, to: {}" + local, remote, ex); - throw new RuntimeException("failed to upload file", ex); - } - } + return FileInfo.of() + .withPath(path) + .withLength(s3Object.size()) + .withModifiedTime(s3Object.lastModified().toEpochMilli()); + } + + @Override + public FileInfo[] listFileInfo(Path path) throws IOException { + return listFileInfo(path, null); + } + + @Override + public FileInfo[] listFileInfo(Path path, PathFilter filter) throws IOException { + String prefix = dirPathToPrefix(bucketName, path); + + List fileInfoList = new ArrayList<>(); + String continuationToken = null; + + ListObjectsV2Request.Builder builder = + ListObjectsV2Request.builder() + .bucket(bucketName) + .prefix(prefix) + .delimiter(PATH_DELIMITER) + .maxKeys(S3_MAX_RETURN_KEYS); + + do { + if (continuationToken != null) { + builder.continuationToken(continuationToken); + } + + CompletableFuture future = s3Client.listObjectsV2(builder.build()); + ListObjectsV2Response response; + + try { + response = future.get(); + } catch (Throwable e) { + throw new IOException("Failed to list file status, dir path: " + path, e); + } + + for (S3Object s3Object : response.contents()) { + Path filePath = keyToFilePath(bucketName, s3Object.key()); + if (filter == null || filter.accept(filePath)) { + fileInfoList.add( + FileInfo.of() + .withPath(filePath) + .withLength(s3Object.size()) + .withModifiedTime(s3Object.lastModified().toEpochMilli())); } - } + } - @Override - public void copyToLocalFile(Path remote, Path local) throws IOException { - int count = 0; - int maxTries = 3; - FileInfo fileInfo = getFileInfo(remote); - - java.nio.file.Path toPath = Paths.get(local.toString()); - String key = filePathToKey(bucketName, remote); - - DownloadFileRequest downloadFileRequest = DownloadFileRequest.builder().getObjectRequest( - req -> req.bucket(bucketName).key(key) - .checksumMode(ChecksumMode.UNKNOWN_TO_SDK_VERSION)).destination(toPath).build(); - - while (true) { - FileDownload download = transferManager.downloadFile(downloadFileRequest); - - try { - download.completionFuture().get(); - File localFile = new File(local.toString()); - if (localFile.length() != fileInfo.getLength()) { - LOGGER.warn("download from s3: size not same {} -> {}", remote, local); - if (++count == maxTries) { - return; - } - } else { - LOGGER.info("download from dfs: {} -> {}", remote, local); - break; - } - } catch (Throwable ex) { - if (++count == maxTries) { - LOGGER.error("failed to download file, from: {}, to: {}", remote, local, ex); - throw new RuntimeException("failed to download file", ex); - } - } + for (CommonPrefix commonPrefix : response.commonPrefixes()) { + Path dirPath = keyToFilePath(bucketName, commonPrefix.prefix()); + if (filter == null || filter.accept(dirPath)) { + fileInfoList.add(FileInfo.of().withPath(dirPath).withLength(0).withModifiedTime(0)); } - } - - @Override - public boolean renameFile(Path from, Path to) throws IOException { - String srcKey = filePathToKey(bucketName, from); - String destKey = filePathToKey(bucketName, to); - String filePathStr = - LOCAL_RENAME_DIR + PATH_DELIMITER + keyToFilePath(bucketName, srcKey).toString() - .replace(PATH_DELIMITER, LOCAL_RENAME_DELIMITER); - Path localPath = new Path(filePathStr); - - copyToLocalFile(from, localPath); - try { - copyFromLocalFile(localPath, new Path(destKey)); - } finally { - Files.delete(Paths.get(localPath.toString())); + } + continuationToken = response.nextContinuationToken(); + } while (continuationToken != null); + + return fileInfoList.toArray(new FileInfo[0]); + } + + @Override + public void copyFromLocalFile(Path local, Path remote) throws IOException { + String key = filePathToKey(bucketName, remote); + java.nio.file.Path fromPath = Paths.get(local.toString()); + + int count = 0; + int maxTries = 3; + File localFile = new File(local.toString()); + long localFileLen = localFile.length(); + while (true) { + long start = System.currentTimeMillis(); + + UploadFileRequest uploadFileRequest = + UploadFileRequest.builder() + .putObjectRequest(req -> req.bucket(bucketName).key(key)) + .source(fromPath) + .build(); + FileUpload upload = transferManager.uploadFile(uploadFileRequest); + + try { + upload.completionFuture().get(); + LOGGER.info( + "upload to s3: size {} KB took {} ms {} -> {}", + String.format("%.1f", localFileLen / 1024.0), + System.currentTimeMillis() - start, + local, + remote); + return; + } catch (Throwable ex) { + if (++count == maxTries) { + LOGGER.error("failed to upload file, from: {}, to: {}" + local, remote, ex); + throw new RuntimeException("failed to upload file", ex); } - deleteFile(srcKey); - - return true; - } - - @Override - public InputStream open(Path path) throws IOException { - return new S3InputStream(s3Client, bucketName, filePathToKey(bucketName, path), - inputStreamChunkSize); + } } + } + + @Override + public void copyToLocalFile(Path remote, Path local) throws IOException { + int count = 0; + int maxTries = 3; + FileInfo fileInfo = getFileInfo(remote); + + java.nio.file.Path toPath = Paths.get(local.toString()); + String key = filePathToKey(bucketName, remote); + + DownloadFileRequest downloadFileRequest = + DownloadFileRequest.builder() + .getObjectRequest( + req -> + req.bucket(bucketName) + .key(key) + .checksumMode(ChecksumMode.UNKNOWN_TO_SDK_VERSION)) + .destination(toPath) + .build(); - @Override - public void close() throws IOException { - s3Client.close(); - } + while (true) { + FileDownload download = transferManager.downloadFile(downloadFileRequest); - @Override - public PersistentType getPersistentType() { - return PersistentType.S3; - } - - @VisibleForTesting - public S3AsyncClient getS3Client() { - return s3Client; + try { + download.completionFuture().get(); + File localFile = new File(local.toString()); + if (localFile.length() != fileInfo.getLength()) { + LOGGER.warn("download from s3: size not same {} -> {}", remote, local); + if (++count == maxTries) { + return; + } + } else { + LOGGER.info("download from dfs: {} -> {}", remote, local); + break; + } + } catch (Throwable ex) { + if (++count == maxTries) { + LOGGER.error("failed to download file, from: {}, to: {}", remote, local, ex); + throw new RuntimeException("failed to download file", ex); + } + } } + } + + @Override + public boolean renameFile(Path from, Path to) throws IOException { + String srcKey = filePathToKey(bucketName, from); + String destKey = filePathToKey(bucketName, to); + String filePathStr = + LOCAL_RENAME_DIR + + PATH_DELIMITER + + keyToFilePath(bucketName, srcKey) + .toString() + .replace(PATH_DELIMITER, LOCAL_RENAME_DELIMITER); + Path localPath = new Path(filePathStr); - @VisibleForTesting - public void setBucketName(String bucketName) { - this.bucketName = bucketName; + copyToLocalFile(from, localPath); + try { + copyFromLocalFile(localPath, new Path(destKey)); + } finally { + Files.delete(Paths.get(localPath.toString())); } + deleteFile(srcKey); + + return true; + } + + @Override + public InputStream open(Path path) throws IOException { + return new S3InputStream( + s3Client, bucketName, filePathToKey(bucketName, path), inputStreamChunkSize); + } + + @Override + public void close() throws IOException { + s3Client.close(); + } + + @Override + public PersistentType getPersistentType() { + return PersistentType.S3; + } + + @VisibleForTesting + public S3AsyncClient getS3Client() { + return s3Client; + } + + @VisibleForTesting + public void setBucketName(String bucketName) { + this.bucketName = bucketName; + } } diff --git a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-s3/src/main/java/org/apache/geaflow/file/s3/S3InputStream.java b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-s3/src/main/java/org/apache/geaflow/file/s3/S3InputStream.java index 7fd70b49f..8682e8832 100644 --- a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-s3/src/main/java/org/apache/geaflow/file/s3/S3InputStream.java +++ b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-s3/src/main/java/org/apache/geaflow/file/s3/S3InputStream.java @@ -22,7 +22,9 @@ import java.io.IOException; import java.io.InputStream; import java.util.concurrent.CompletableFuture; + import org.jetbrains.annotations.NotNull; + import software.amazon.awssdk.core.BytesWrapper; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.services.s3.S3AsyncClient; @@ -34,133 +36,143 @@ public class S3InputStream extends InputStream { - private final S3AsyncClient s3Client; - private final String bucketName; - private final String key; - private long curPos; - private byte[] buffer; - private int bufferPos; - private final int minChunkSize; - private final long fileLen; - private boolean closed; - - public S3InputStream(S3AsyncClient s3Client, String bucketName, String key, int minChunkSize) - throws IOException { - this.s3Client = s3Client; - this.bucketName = bucketName; - this.key = key; - this.curPos = 0; - this.buffer = new byte[0]; - this.bufferPos = 0; - this.minChunkSize = minChunkSize; - this.fileLen = checkFileExistAndGetFileLength(); - this.closed = false; - } - - @Override - public void close() { - closed = true; - } - - @Override - public int read() throws IOException { - if (closed) { - throw new IOException( - "s3 input stream closed, bucket name: " + bucketName + ", file path" + key); - } - - if (curPos >= fileLen) { - return -1; - } - - if (bufferPos >= buffer.length) { - long newPos = curPos + this.buffer.length; - if (newPos < fileLen) { - fetchChunk(newPos, minChunkSize); - } else { - return -1; - } - } - int byteRead = buffer[bufferPos] & 0xFF; - bufferPos++; - return byteRead; + private final S3AsyncClient s3Client; + private final String bucketName; + private final String key; + private long curPos; + private byte[] buffer; + private int bufferPos; + private final int minChunkSize; + private final long fileLen; + private boolean closed; + + public S3InputStream(S3AsyncClient s3Client, String bucketName, String key, int minChunkSize) + throws IOException { + this.s3Client = s3Client; + this.bucketName = bucketName; + this.key = key; + this.curPos = 0; + this.buffer = new byte[0]; + this.bufferPos = 0; + this.minChunkSize = minChunkSize; + this.fileLen = checkFileExistAndGetFileLength(); + this.closed = false; + } + + @Override + public void close() { + closed = true; + } + + @Override + public int read() throws IOException { + if (closed) { + throw new IOException( + "s3 input stream closed, bucket name: " + bucketName + ", file path" + key); } - @Override - public int read(@NotNull byte[] buffer) throws IOException { - return readInner(buffer, 0, buffer.length); + if (curPos >= fileLen) { + return -1; } - @Override - public int read(@NotNull byte[] buffer, int offset, int length) throws IOException { - return readInner(buffer, offset, length); + if (bufferPos >= buffer.length) { + long newPos = curPos + this.buffer.length; + if (newPos < fileLen) { + fetchChunk(newPos, minChunkSize); + } else { + return -1; + } } - - private int readInner(@NotNull byte[] buffer, int offset, int length) throws IOException { - int bytesRead = 0; - while (bytesRead < length) { - if (bufferPos >= this.buffer.length) { - long newPos = curPos + this.buffer.length; - if (newPos < fileLen) { - fetchChunk(newPos, Integer.max(length - bytesRead, minChunkSize)); - } else { - return bytesRead == 0 ? -1 : bytesRead; - } - } - int bytesToCopy = Math.min(length - bytesRead, this.buffer.length - bufferPos); - System.arraycopy(this.buffer, bufferPos, buffer, offset + bytesRead, bytesToCopy); - bufferPos += bytesToCopy; - bytesRead += bytesToCopy; + int byteRead = buffer[bufferPos] & 0xFF; + bufferPos++; + return byteRead; + } + + @Override + public int read(@NotNull byte[] buffer) throws IOException { + return readInner(buffer, 0, buffer.length); + } + + @Override + public int read(@NotNull byte[] buffer, int offset, int length) throws IOException { + return readInner(buffer, offset, length); + } + + private int readInner(@NotNull byte[] buffer, int offset, int length) throws IOException { + int bytesRead = 0; + while (bytesRead < length) { + if (bufferPos >= this.buffer.length) { + long newPos = curPos + this.buffer.length; + if (newPos < fileLen) { + fetchChunk(newPos, Integer.max(length - bytesRead, minChunkSize)); + } else { + return bytesRead == 0 ? -1 : bytesRead; } - return bytesRead; + } + int bytesToCopy = Math.min(length - bytesRead, this.buffer.length - bufferPos); + System.arraycopy(this.buffer, bufferPos, buffer, offset + bytesRead, bytesToCopy); + bufferPos += bytesToCopy; + bytesRead += bytesToCopy; } - - private void fetchChunk(long position, int chunkSize) throws IOException { - long endPosition = position + chunkSize - 1; - endPosition = Math.min(endPosition, fileLen - 1); - - String range = "bytes=" + position + "-" + endPosition; - GetObjectRequest request = GetObjectRequest.builder().bucket(bucketName).key(key) - .checksumMode(ChecksumMode.UNKNOWN_TO_SDK_VERSION).range(range).build(); - - CompletableFuture future = s3Client.getObject(request, - AsyncResponseTransformer.toBytes()).thenApply(BytesWrapper::asByteArray); - - try { - buffer = future.get(); - bufferPos = 0; - curPos = position; - } catch (Throwable e) { - throw new IOException( - "Failed to fetch chunk, bucket name: " + bucketName + ", file path: " + key - + ", position: " + position, e); - } + return bytesRead; + } + + private void fetchChunk(long position, int chunkSize) throws IOException { + long endPosition = position + chunkSize - 1; + endPosition = Math.min(endPosition, fileLen - 1); + + String range = "bytes=" + position + "-" + endPosition; + GetObjectRequest request = + GetObjectRequest.builder() + .bucket(bucketName) + .key(key) + .checksumMode(ChecksumMode.UNKNOWN_TO_SDK_VERSION) + .range(range) + .build(); + + CompletableFuture future = + s3Client + .getObject(request, AsyncResponseTransformer.toBytes()) + .thenApply(BytesWrapper::asByteArray); + + try { + buffer = future.get(); + bufferPos = 0; + curPos = position; + } catch (Throwable e) { + throw new IOException( + "Failed to fetch chunk, bucket name: " + + bucketName + + ", file path: " + + key + + ", position: " + + position, + e); } + } - private long checkFileExistAndGetFileLength() throws IOException { - CompletableFuture future = s3Client.listObjectsV2( + private long checkFileExistAndGetFileLength() throws IOException { + CompletableFuture future = + s3Client.listObjectsV2( ListObjectsV2Request.builder().bucket(bucketName).prefix(key).build()); - ListObjectsV2Response response; + ListObjectsV2Response response; - try { - response = future.get(); - } catch (Throwable e) { - throw new IOException( - "Failed to get file status, bucket name: " + bucketName + ", file path: " + key, e); - } + try { + response = future.get(); + } catch (Throwable e) { + throw new IOException( + "Failed to get file status, bucket name: " + bucketName + ", file path: " + key, e); + } - if (response.contents().isEmpty()) { - throw new IOException( - "File not found, bucket name: " + bucketName + ", file path: " + key); - } + if (response.contents().isEmpty()) { + throw new IOException("File not found, bucket name: " + bucketName + ", file path: " + key); + } - S3Object s3Object = response.contents().get(0); + S3Object s3Object = response.contents().get(0); - if (!s3Object.key().equals(key)) { - throw new IOException( - "File not found, bucket name: " + bucketName + ", file path: " + key); - } - return s3Object.size(); + if (!s3Object.key().equals(key)) { + throw new IOException("File not found, bucket name: " + bucketName + ", file path: " + key); } + return s3Object.size(); + } } - diff --git a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-s3/src/test/java/org/apache/geaflow/file/s3/S3IOTest.java b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-s3/src/test/java/org/apache/geaflow/file/s3/S3IOTest.java index 57220bfeb..4c1598974 100644 --- a/geaflow/geaflow-plugins/geaflow-file/geaflow-file-s3/src/test/java/org/apache/geaflow/file/s3/S3IOTest.java +++ b/geaflow/geaflow-plugins/geaflow-file/geaflow-file-s3/src/test/java/org/apache/geaflow/file/s3/S3IOTest.java @@ -32,6 +32,7 @@ import java.util.Random; import java.util.Set; import java.util.concurrent.CompletableFuture; + import org.apache.commons.io.IOUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.utils.GsonUtil; @@ -48,6 +49,7 @@ import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; + import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.services.s3.model.DeleteBucketRequest; import software.amazon.awssdk.services.s3.model.DeleteBucketResponse; @@ -59,432 +61,431 @@ public class S3IOTest { - private static final Logger LOGGER = LoggerFactory.getLogger(S3IOTest.class); - - private static S3IO s3IO; - private LocalFileSystem localFileSystem; - private String remoteRoot; - private String localRoot; - private String bucketName; - - @BeforeMethod - public void setUp() throws IOException { - String testName = "S3IOTest" + System.currentTimeMillis(); - bucketName = "geaflow-s3-io-test"; - remoteRoot = bucketName + "/" + testName; - - Map config = new HashMap<>(); - config.put(FileConfigKeys.S3_BUCKET_NAME.getKey(), bucketName); - - // Fill the following configuration for s3 if needing to run S3IOTest. - config.put(FileConfigKeys.S3_ENDPOINT.getKey(), ""); - config.put(FileConfigKeys.S3_ACCESS_KEY_ID.getKey(), ""); - config.put(FileConfigKeys.S3_ACCESS_KEY.getKey(), ""); - - Configuration configuration = new Configuration(); - configuration.put(FileConfigKeys.PERSISTENT_TYPE, "S3"); - configuration.put(FileConfigKeys.JSON_CONFIG, GsonUtil.toJson(config)); - s3IO = (S3IO) PersistentIOBuilder.build(configuration); - localFileSystem = FileSystem.getLocal(new org.apache.hadoop.conf.Configuration()); - localRoot = "/tmp/" + testName; - localFileSystem.mkdirs(new Path(localRoot)); - } - - @AfterMethod - public void tearDown() throws IOException { - localFileSystem.delete(new Path(localRoot), true); - s3IO.delete(new Path(remoteRoot), true); - s3IO.close(); + private static final Logger LOGGER = LoggerFactory.getLogger(S3IOTest.class); + + private static S3IO s3IO; + private LocalFileSystem localFileSystem; + private String remoteRoot; + private String localRoot; + private String bucketName; + + @BeforeMethod + public void setUp() throws IOException { + String testName = "S3IOTest" + System.currentTimeMillis(); + bucketName = "geaflow-s3-io-test"; + remoteRoot = bucketName + "/" + testName; + + Map config = new HashMap<>(); + config.put(FileConfigKeys.S3_BUCKET_NAME.getKey(), bucketName); + + // Fill the following configuration for s3 if needing to run S3IOTest. + config.put(FileConfigKeys.S3_ENDPOINT.getKey(), ""); + config.put(FileConfigKeys.S3_ACCESS_KEY_ID.getKey(), ""); + config.put(FileConfigKeys.S3_ACCESS_KEY.getKey(), ""); + + Configuration configuration = new Configuration(); + configuration.put(FileConfigKeys.PERSISTENT_TYPE, "S3"); + configuration.put(FileConfigKeys.JSON_CONFIG, GsonUtil.toJson(config)); + s3IO = (S3IO) PersistentIOBuilder.build(configuration); + localFileSystem = FileSystem.getLocal(new org.apache.hadoop.conf.Configuration()); + localRoot = "/tmp/" + testName; + localFileSystem.mkdirs(new Path(localRoot)); + } + + @AfterMethod + public void tearDown() throws IOException { + localFileSystem.delete(new Path(localRoot), true); + s3IO.delete(new Path(remoteRoot), true); + s3IO.close(); + } + + private void createNewFileWithFixedEmptyBytes(Path path, int bytesLen) throws IOException { + String key = filePathToKey(bucketName, path); + + if (s3IO.exists(path)) { + return; } - private void createNewFileWithFixedEmptyBytes(Path path, int bytesLen) throws IOException { - String key = filePathToKey(bucketName, path); - - if (s3IO.exists(path)) { - return; - } + PutObjectRequest request = PutObjectRequest.builder().bucket(bucketName).key(key).build(); - PutObjectRequest request = PutObjectRequest.builder().bucket(bucketName).key(key).build(); + CompletableFuture future = + s3IO.getS3Client().putObject(request, AsyncRequestBody.fromBytes(new byte[bytesLen])); - CompletableFuture future = s3IO.getS3Client() - .putObject(request, AsyncRequestBody.fromBytes(new byte[bytesLen])); - - try { - future.get(); - } catch (Throwable e) { - throw new IOException("Fail to create a new file, file path: " + path, e); - } + try { + future.get(); + } catch (Throwable e) { + throw new IOException("Fail to create a new file, file path: " + path, e); } + } - private void createNewFileWithFixedBytes(Path path, byte[] bytes) throws IOException { - String key = filePathToKey(bucketName, path); - - if (s3IO.exists(path)) { - return; - } - - PutObjectRequest request = PutObjectRequest.builder().bucket(bucketName).key(key).build(); - - CompletableFuture future = s3IO.getS3Client() - .putObject(request, AsyncRequestBody.fromBytes(bytes)); + private void createNewFileWithFixedBytes(Path path, byte[] bytes) throws IOException { + String key = filePathToKey(bucketName, path); - try { - future.get(); - } catch (Throwable e) { - throw new IOException("Fail to create a new file, file path: " + path, e); - } + if (s3IO.exists(path)) { + return; } - private void deleteBucket(String bucketName) throws IOException { - DeleteBucketRequest request = DeleteBucketRequest.builder().bucket(bucketName).build(); + PutObjectRequest request = PutObjectRequest.builder().bucket(bucketName).key(key).build(); - CompletableFuture future = s3IO.getS3Client().deleteBucket(request); + CompletableFuture future = + s3IO.getS3Client().putObject(request, AsyncRequestBody.fromBytes(bytes)); - try { - future.get(); - } catch (Throwable e) { - throw new IOException("Fail to create a new file, file path: " + bucketName, e); - } + try { + future.get(); + } catch (Throwable e) { + throw new IOException("Fail to create a new file, file path: " + path, e); } + } - @Test(enabled = false) - void testCreateBucket() throws IOException { - String testName = getTestName(); - String newBucketName = - "geaflow-" + testName.toLowerCase() + "-" + System.currentTimeMillis(); - s3IO.setBucketName(newBucketName); - s3IO.checkAndCreateBucket(); - Path path = new Path(testName + System.currentTimeMillis() + ".txt"); - s3IO.createNewFile(path); - s3IO.deleteFile(path.toString()); - - deleteBucket(newBucketName); - s3IO.setBucketName(bucketName); - } + private void deleteBucket(String bucketName) throws IOException { + DeleteBucketRequest request = DeleteBucketRequest.builder().bucket(bucketName).build(); - @Test(enabled = false) - void testCreateAndDelete() throws IOException { - String testName = getTestName(); - String testFileName = testName + ".txt"; - String pathStr = remoteRoot + "/" + testName + "/s3/1/" + testFileName; - Path filePath = new Path(pathStr); - - Assert.assertTrue(s3IO.createNewFile(filePath)); - Assert.assertFalse(s3IO.createNewFile(filePath)); - Assert.assertTrue(s3IO.delete(filePath, false)); - s3IO.createNewFile(filePath); - Assert.assertTrue(s3IO.delete(filePath, true)); - Assert.assertFalse(s3IO.exists(filePath)); - - pathStr = remoteRoot + "//" + testName + "/s3///1//" + testFileName + "/"; - filePath = new Path(pathStr); - Assert.assertTrue(s3IO.createNewFile(filePath)); - Assert.assertTrue(s3IO.delete(filePath, true)); - Assert.assertFalse(s3IO.exists(filePath)); - } + CompletableFuture future = s3IO.getS3Client().deleteBucket(request); - @Test(enabled = false) - void testGetFileCount() throws IOException { - String testName = getTestName(); - String root = remoteRoot + "/" + testName; - Path testFilePath1 = new Path(root + "/s3/1/" + testName + "1.txt"); - Path testFilePath2 = new Path(root + "/s3/1/" + testName + "2.txt"); - Path testFilePath3 = new Path(root + "/s3/2/" + testName + "1.txt"); - Path testFilePath4 = new Path(root + "/s3/3/" + testName + "1.txt"); - Path testFilePath5 = new Path(root + "/s3/4/" + testName + "3.txt"); - Path testFilePath6 = new Path(root + "/s3/" + testName + ".txt"); - - s3IO.createNewFile(testFilePath1); - s3IO.createNewFile(testFilePath2); - s3IO.createNewFile(testFilePath3); - s3IO.createNewFile(testFilePath4); - s3IO.createNewFile(testFilePath5); - s3IO.createNewFile(testFilePath6); - - Assert.assertEquals(s3IO.getFileCount(new Path(root)), 6); + try { + future.get(); + } catch (Throwable e) { + throw new IOException("Fail to create a new file, file path: " + bucketName, e); } - - @Test(enabled = false) - void testDeleteRecursive() throws IOException { - String testName = getTestName(); - Path testFilePath1 = new Path(remoteRoot + "/" + testName + "/s3/" + testName + "1.txt"); - Path testFilePath2 = new Path(remoteRoot + "/" + testName + "/s3/" + testName + "2.txt"); - Path testFilePath3 = new Path(remoteRoot + "/" + testName + "/s3/1/" + testName + "1.txt"); - Path testFilePath4 = new Path(remoteRoot + "/" + testName + "/s3/1/" + testName + "2.txt"); - Path testFilePath5 = new Path(remoteRoot + "/" + testName + "/s3/2/" + testName + "1.txt"); - Path testFilePath6 = new Path(remoteRoot + "/" + testName + "/s3/2/" + testName + "2.txt"); - Path testFilePath7 = new Path(remoteRoot + "/" + testName + "/s3/2/" + testName + "3.txt"); - Path testFilePath8 = new Path( - remoteRoot + "/" + testName + "/s3/2/1/" + testName + "1.txt"); - - s3IO.createNewFile(testFilePath1); - s3IO.createNewFile(testFilePath2); - s3IO.createNewFile(testFilePath3); - s3IO.createNewFile(testFilePath4); - s3IO.createNewFile(testFilePath5); - s3IO.createNewFile(testFilePath6); - s3IO.createNewFile(testFilePath7); - s3IO.createNewFile(testFilePath8); - - Path rootPath = new Path(remoteRoot); - Path dirPath1 = new Path(remoteRoot + "/" + testName + "/s3/1/"); - Path dirPath2 = new Path(remoteRoot + "/" + testName + "/s3"); - - Assert.assertEquals(s3IO.getFileCount(rootPath), 8); - Assert.assertFalse(s3IO.delete(new Path(testName + "/" + "s3_1"), true)); - Assert.assertFalse(s3IO.delete(new Path(testName + "/" + "s3_1"), false)); - Assert.assertFalse( - s3IO.delete(new Path(remoteRoot + "/" + testName + "/s3/1/" + testName + "3.txt"), - true)); - Assert.assertFalse( - s3IO.delete(new Path(remoteRoot + "/" + testName + "/s3/1/" + testName + "3.txt"), - false)); - - Assert.assertEquals(s3IO.getFileCount(rootPath), 8); - - try { - s3IO.delete(dirPath1, false); - } catch (IOException e) { - Assert.assertTrue(e.getMessage().contains("path is a directory, but " + "recursive")); - } - - Assert.assertTrue(s3IO.delete(testFilePath5, true)); - Assert.assertEquals(s3IO.getFileCount(rootPath), 7); - Assert.assertTrue(s3IO.delete(testFilePath6, false)); - Assert.assertEquals(s3IO.getFileCount(rootPath), 6); - - Assert.assertTrue(s3IO.delete(dirPath1, true)); - Assert.assertEquals(s3IO.getFileCount(rootPath), 4); - - Assert.assertTrue(s3IO.delete(dirPath2, true)); - Assert.assertEquals(s3IO.getFileCount(rootPath), 0); + } + + @Test(enabled = false) + void testCreateBucket() throws IOException { + String testName = getTestName(); + String newBucketName = "geaflow-" + testName.toLowerCase() + "-" + System.currentTimeMillis(); + s3IO.setBucketName(newBucketName); + s3IO.checkAndCreateBucket(); + Path path = new Path(testName + System.currentTimeMillis() + ".txt"); + s3IO.createNewFile(path); + s3IO.deleteFile(path.toString()); + + deleteBucket(newBucketName); + s3IO.setBucketName(bucketName); + } + + @Test(enabled = false) + void testCreateAndDelete() throws IOException { + String testName = getTestName(); + String testFileName = testName + ".txt"; + String pathStr = remoteRoot + "/" + testName + "/s3/1/" + testFileName; + Path filePath = new Path(pathStr); + + Assert.assertTrue(s3IO.createNewFile(filePath)); + Assert.assertFalse(s3IO.createNewFile(filePath)); + Assert.assertTrue(s3IO.delete(filePath, false)); + s3IO.createNewFile(filePath); + Assert.assertTrue(s3IO.delete(filePath, true)); + Assert.assertFalse(s3IO.exists(filePath)); + + pathStr = remoteRoot + "//" + testName + "/s3///1//" + testFileName + "/"; + filePath = new Path(pathStr); + Assert.assertTrue(s3IO.createNewFile(filePath)); + Assert.assertTrue(s3IO.delete(filePath, true)); + Assert.assertFalse(s3IO.exists(filePath)); + } + + @Test(enabled = false) + void testGetFileCount() throws IOException { + String testName = getTestName(); + String root = remoteRoot + "/" + testName; + Path testFilePath1 = new Path(root + "/s3/1/" + testName + "1.txt"); + Path testFilePath2 = new Path(root + "/s3/1/" + testName + "2.txt"); + Path testFilePath3 = new Path(root + "/s3/2/" + testName + "1.txt"); + Path testFilePath4 = new Path(root + "/s3/3/" + testName + "1.txt"); + Path testFilePath5 = new Path(root + "/s3/4/" + testName + "3.txt"); + Path testFilePath6 = new Path(root + "/s3/" + testName + ".txt"); + + s3IO.createNewFile(testFilePath1); + s3IO.createNewFile(testFilePath2); + s3IO.createNewFile(testFilePath3); + s3IO.createNewFile(testFilePath4); + s3IO.createNewFile(testFilePath5); + s3IO.createNewFile(testFilePath6); + + Assert.assertEquals(s3IO.getFileCount(new Path(root)), 6); + } + + @Test(enabled = false) + void testDeleteRecursive() throws IOException { + String testName = getTestName(); + Path testFilePath1 = new Path(remoteRoot + "/" + testName + "/s3/" + testName + "1.txt"); + Path testFilePath2 = new Path(remoteRoot + "/" + testName + "/s3/" + testName + "2.txt"); + Path testFilePath3 = new Path(remoteRoot + "/" + testName + "/s3/1/" + testName + "1.txt"); + Path testFilePath4 = new Path(remoteRoot + "/" + testName + "/s3/1/" + testName + "2.txt"); + Path testFilePath5 = new Path(remoteRoot + "/" + testName + "/s3/2/" + testName + "1.txt"); + Path testFilePath6 = new Path(remoteRoot + "/" + testName + "/s3/2/" + testName + "2.txt"); + Path testFilePath7 = new Path(remoteRoot + "/" + testName + "/s3/2/" + testName + "3.txt"); + Path testFilePath8 = new Path(remoteRoot + "/" + testName + "/s3/2/1/" + testName + "1.txt"); + + s3IO.createNewFile(testFilePath1); + s3IO.createNewFile(testFilePath2); + s3IO.createNewFile(testFilePath3); + s3IO.createNewFile(testFilePath4); + s3IO.createNewFile(testFilePath5); + s3IO.createNewFile(testFilePath6); + s3IO.createNewFile(testFilePath7); + s3IO.createNewFile(testFilePath8); + + Path rootPath = new Path(remoteRoot); + Path dirPath1 = new Path(remoteRoot + "/" + testName + "/s3/1/"); + Path dirPath2 = new Path(remoteRoot + "/" + testName + "/s3"); + + Assert.assertEquals(s3IO.getFileCount(rootPath), 8); + Assert.assertFalse(s3IO.delete(new Path(testName + "/" + "s3_1"), true)); + Assert.assertFalse(s3IO.delete(new Path(testName + "/" + "s3_1"), false)); + Assert.assertFalse( + s3IO.delete(new Path(remoteRoot + "/" + testName + "/s3/1/" + testName + "3.txt"), true)); + Assert.assertFalse( + s3IO.delete(new Path(remoteRoot + "/" + testName + "/s3/1/" + testName + "3.txt"), false)); + + Assert.assertEquals(s3IO.getFileCount(rootPath), 8); + + try { + s3IO.delete(dirPath1, false); + } catch (IOException e) { + Assert.assertTrue(e.getMessage().contains("path is a directory, but " + "recursive")); } - @Test(enabled = false) - void testGetFileInfo() throws IOException { - String testName = getTestName(); - Path testFilePath = new Path(remoteRoot + "/" + testName + "/s3/1/" + testName + "1.txt"); - createNewFileWithFixedEmptyBytes(testFilePath, 5); - FileInfo fileInfo = s3IO.getFileInfo(testFilePath); - Assert.assertEquals(fileInfo.getLength(), 5); + Assert.assertTrue(s3IO.delete(testFilePath5, true)); + Assert.assertEquals(s3IO.getFileCount(rootPath), 7); + Assert.assertTrue(s3IO.delete(testFilePath6, false)); + Assert.assertEquals(s3IO.getFileCount(rootPath), 6); + + Assert.assertTrue(s3IO.delete(dirPath1, true)); + Assert.assertEquals(s3IO.getFileCount(rootPath), 4); + + Assert.assertTrue(s3IO.delete(dirPath2, true)); + Assert.assertEquals(s3IO.getFileCount(rootPath), 0); + } + + @Test(enabled = false) + void testGetFileInfo() throws IOException { + String testName = getTestName(); + Path testFilePath = new Path(remoteRoot + "/" + testName + "/s3/1/" + testName + "1.txt"); + createNewFileWithFixedEmptyBytes(testFilePath, 5); + FileInfo fileInfo = s3IO.getFileInfo(testFilePath); + Assert.assertEquals(fileInfo.getLength(), 5); + Assert.assertTrue(fileInfo.getModificationTime() != 0); + Assert.assertEquals(fileInfo.getPath(), testFilePath); + + LOGGER.info("file {} last modify time {}", testFilePath, fileInfo.getModificationTime()); + } + + @Test(enabled = false) + void testListFileInfo() throws IOException { + String testName = getTestName(); + String root = remoteRoot + "/" + testName; + Set filePathSet1 = new HashSet<>(); + Set filePathSet2 = new HashSet<>(); + Set dirPathSet = new HashSet<>(); + + Path testFilePath1 = new Path(root + "/s3/" + testName + "1.txt"); + Path testFilePath2 = new Path(root + "/s3/" + testName + "2.txt"); + filePathSet1.add(testFilePath1); + filePathSet1.add(testFilePath2); + + Path testFilePath3 = new Path(root + "/s3/1/" + testName + "1.txt"); + Path testFilePath4 = new Path(root + "/s3/2/" + testName + "1.txt"); + Path testFilePath5 = new Path(root + "/s3/3/" + testName + "1.txt"); + dirPathSet.add(new Path(root + "/s3/1/")); + dirPathSet.add(new Path(root + "/s3/2/")); + dirPathSet.add(new Path(root + "/s3/3/")); + filePathSet2.add(testFilePath3); + filePathSet2.add(testFilePath4); + filePathSet2.add(testFilePath5); + + createNewFileWithFixedEmptyBytes(testFilePath1, 3); + createNewFileWithFixedEmptyBytes(testFilePath2, 3); + createNewFileWithFixedEmptyBytes(testFilePath3, 3); + createNewFileWithFixedEmptyBytes(testFilePath4, 3); + createNewFileWithFixedEmptyBytes(testFilePath5, 3); + + Path dirPath1 = new Path(root + "/s3/1/"); + Path dirPath2 = new Path(root + "/s3"); + + FileInfo[] fileInfos = s3IO.listFileInfo(dirPath2); + int fileCount = 0; + int dirCount = 0; + for (FileInfo fileInfo : fileInfos) { + if (fileInfo.getLength() != 0) { + Assert.assertEquals(fileInfo.getLength(), 3); Assert.assertTrue(fileInfo.getModificationTime() != 0); - Assert.assertEquals(fileInfo.getPath(), testFilePath); - - LOGGER.info("file {} last modify time {}", testFilePath, fileInfo.getModificationTime()); + Assert.assertTrue(filePathSet1.contains(fileInfo.getPath())); + fileCount++; + } else { + Assert.assertEquals(fileInfo.getLength(), 0); + Assert.assertEquals(fileInfo.getModificationTime(), 0); + Assert.assertTrue(dirPathSet.contains(fileInfo.getPath())); + dirCount++; + } } - - @Test(enabled = false) - void testListFileInfo() throws IOException { - String testName = getTestName(); - String root = remoteRoot + "/" + testName; - Set filePathSet1 = new HashSet<>(); - Set filePathSet2 = new HashSet<>(); - Set dirPathSet = new HashSet<>(); - - Path testFilePath1 = new Path(root + "/s3/" + testName + "1.txt"); - Path testFilePath2 = new Path(root + "/s3/" + testName + "2.txt"); - filePathSet1.add(testFilePath1); - filePathSet1.add(testFilePath2); - - Path testFilePath3 = new Path(root + "/s3/1/" + testName + "1.txt"); - Path testFilePath4 = new Path(root + "/s3/2/" + testName + "1.txt"); - Path testFilePath5 = new Path(root + "/s3/3/" + testName + "1.txt"); - dirPathSet.add(new Path(root + "/s3/1/")); - dirPathSet.add(new Path(root + "/s3/2/")); - dirPathSet.add(new Path(root + "/s3/3/")); - filePathSet2.add(testFilePath3); - filePathSet2.add(testFilePath4); - filePathSet2.add(testFilePath5); - - createNewFileWithFixedEmptyBytes(testFilePath1, 3); - createNewFileWithFixedEmptyBytes(testFilePath2, 3); - createNewFileWithFixedEmptyBytes(testFilePath3, 3); - createNewFileWithFixedEmptyBytes(testFilePath4, 3); - createNewFileWithFixedEmptyBytes(testFilePath5, 3); - - Path dirPath1 = new Path(root + "/s3/1/"); - Path dirPath2 = new Path(root + "/s3"); - - FileInfo[] fileInfos = s3IO.listFileInfo(dirPath2); - int fileCount = 0; - int dirCount = 0; - for (FileInfo fileInfo : fileInfos) { - if (fileInfo.getLength() != 0) { - Assert.assertEquals(fileInfo.getLength(), 3); - Assert.assertTrue(fileInfo.getModificationTime() != 0); - Assert.assertTrue(filePathSet1.contains(fileInfo.getPath())); - fileCount++; - } else { - Assert.assertEquals(fileInfo.getLength(), 0); - Assert.assertEquals(fileInfo.getModificationTime(), 0); - Assert.assertTrue(dirPathSet.contains(fileInfo.getPath())); - dirCount++; - } - } - Assert.assertEquals(fileCount, 2); - Assert.assertEquals(dirCount, 3); - List fileNames = s3IO.listFileName(dirPath2); - Assert.assertEquals(fileNames.size(), 5); - Assert.assertTrue(fileNames.contains("testListFileInfo1.txt")); - - fileInfos = s3IO.listFileInfo(dirPath1); - fileCount = 0; - dirCount = 0; - for (FileInfo fileInfo : fileInfos) { - if (fileInfo.getLength() != 0) { - Assert.assertEquals(fileInfo.getLength(), 3); - Assert.assertTrue(fileInfo.getModificationTime() != 0); - Assert.assertTrue(filePathSet2.contains(fileInfo.getPath())); - fileCount++; - } else { - dirCount++; - } - } - Assert.assertEquals(fileCount, 1); - Assert.assertEquals(dirCount, 0); - } - - @Test(enabled = false) - public void testIsFileAndIsDirectory() throws IOException { - String testName = getTestName(); - String root = remoteRoot + "/" + testName; - Path testFilePath = new Path(root + "/s3/1/" + testName + "1.txt"); - createNewFileWithFixedEmptyBytes(testFilePath, 3); - - Path filePath1 = new Path(root + "/s3/1/" + testName + "2.txt"); - - Path dirPath1 = new Path(root + "/s3/1/"); - Path dirPath2 = new Path(root + "/s3/"); - Path dirPath3 = new Path(root); - Path dirPath4 = new Path(remoteRoot); - - Assert.assertTrue(s3IO.isFile(testFilePath)); - Assert.assertFalse(s3IO.isFile(filePath1)); - Assert.assertFalse(s3IO.isFile(dirPath1)); - - Assert.assertFalse(s3IO.isDirectory(testFilePath)); - Assert.assertFalse(s3IO.isDirectory(filePath1)); - Assert.assertTrue(s3IO.isDirectory(dirPath1)); - Assert.assertTrue(s3IO.isDirectory(dirPath2)); - Assert.assertTrue(s3IO.isDirectory(dirPath3)); - Assert.assertTrue(s3IO.isDirectory(dirPath4)); - - Assert.assertTrue(s3IO.exists(testFilePath)); - Assert.assertTrue(s3IO.exists(dirPath1)); - Assert.assertTrue(s3IO.exists(dirPath2)); - Assert.assertTrue(s3IO.exists(dirPath3)); - Assert.assertTrue(s3IO.exists(dirPath4)); - Assert.assertFalse(s3IO.exists(filePath1)); + Assert.assertEquals(fileCount, 2); + Assert.assertEquals(dirCount, 3); + List fileNames = s3IO.listFileName(dirPath2); + Assert.assertEquals(fileNames.size(), 5); + Assert.assertTrue(fileNames.contains("testListFileInfo1.txt")); + + fileInfos = s3IO.listFileInfo(dirPath1); + fileCount = 0; + dirCount = 0; + for (FileInfo fileInfo : fileInfos) { + if (fileInfo.getLength() != 0) { + Assert.assertEquals(fileInfo.getLength(), 3); + Assert.assertTrue(fileInfo.getModificationTime() != 0); + Assert.assertTrue(filePathSet2.contains(fileInfo.getPath())); + fileCount++; + } else { + dirCount++; + } } - - @Test(enabled = false) - public void testRename() throws IOException { - String testName = getTestName(); - String root = remoteRoot + "/" + testName; - Path oldFilePath = new Path(root + "/s3/1/" + testName + "1.txt"); - Path newFilePath = new Path(root + "/s3/1/" + testName + "2.txt"); - createNewFileWithFixedEmptyBytes(oldFilePath, 3); - Assert.assertTrue(s3IO.exists(oldFilePath)); - Assert.assertFalse(s3IO.exists(newFilePath)); - - s3IO.renameFile(oldFilePath, newFilePath); - Assert.assertTrue(s3IO.exists(newFilePath)); - Assert.assertFalse(s3IO.exists(oldFilePath)); + Assert.assertEquals(fileCount, 1); + Assert.assertEquals(dirCount, 0); + } + + @Test(enabled = false) + public void testIsFileAndIsDirectory() throws IOException { + String testName = getTestName(); + String root = remoteRoot + "/" + testName; + Path testFilePath = new Path(root + "/s3/1/" + testName + "1.txt"); + createNewFileWithFixedEmptyBytes(testFilePath, 3); + + Path filePath1 = new Path(root + "/s3/1/" + testName + "2.txt"); + + Path dirPath1 = new Path(root + "/s3/1/"); + Path dirPath2 = new Path(root + "/s3/"); + Path dirPath3 = new Path(root); + Path dirPath4 = new Path(remoteRoot); + + Assert.assertTrue(s3IO.isFile(testFilePath)); + Assert.assertFalse(s3IO.isFile(filePath1)); + Assert.assertFalse(s3IO.isFile(dirPath1)); + + Assert.assertFalse(s3IO.isDirectory(testFilePath)); + Assert.assertFalse(s3IO.isDirectory(filePath1)); + Assert.assertTrue(s3IO.isDirectory(dirPath1)); + Assert.assertTrue(s3IO.isDirectory(dirPath2)); + Assert.assertTrue(s3IO.isDirectory(dirPath3)); + Assert.assertTrue(s3IO.isDirectory(dirPath4)); + + Assert.assertTrue(s3IO.exists(testFilePath)); + Assert.assertTrue(s3IO.exists(dirPath1)); + Assert.assertTrue(s3IO.exists(dirPath2)); + Assert.assertTrue(s3IO.exists(dirPath3)); + Assert.assertTrue(s3IO.exists(dirPath4)); + Assert.assertFalse(s3IO.exists(filePath1)); + } + + @Test(enabled = false) + public void testRename() throws IOException { + String testName = getTestName(); + String root = remoteRoot + "/" + testName; + Path oldFilePath = new Path(root + "/s3/1/" + testName + "1.txt"); + Path newFilePath = new Path(root + "/s3/1/" + testName + "2.txt"); + createNewFileWithFixedEmptyBytes(oldFilePath, 3); + Assert.assertTrue(s3IO.exists(oldFilePath)); + Assert.assertFalse(s3IO.exists(newFilePath)); + + s3IO.renameFile(oldFilePath, newFilePath); + Assert.assertTrue(s3IO.exists(newFilePath)); + Assert.assertFalse(s3IO.exists(oldFilePath)); + } + + @Test(enabled = false) + public void testCopyFromLocal() throws IOException { + String testName = getTestName(); + String root = remoteRoot + "/" + testName; + Path localFilePath = new Path(localRoot + "/" + testName + "1.txt"); + RandomAccessFile file = new RandomAccessFile(localFilePath.toUri().getPath(), "rw"); + int fileLen = 1024 * 1024; + byte[] bytes = new byte[fileLen]; + Random random = new Random(); + random.nextBytes(bytes); + file.write(bytes); + file.close(); + + Path remoteFilePath = new Path(root + "/s3/1/" + testName + "1" + ".txt"); + + s3IO.copyFromLocalFile(localFilePath, remoteFilePath); + Assert.assertEquals(s3IO.getFileSize(remoteFilePath), fileLen); + InputStream in = s3IO.open(remoteFilePath); + byte[] readBytes = IOUtils.toByteArray(in); + Assert.assertEquals(readBytes, bytes); + in.close(); + in = s3IO.open(remoteFilePath); + + for (int i = 0; i < 100; i++) { + readBytes[i] = (byte) in.read(); } - - @Test(enabled = false) - public void testCopyFromLocal() throws IOException { - String testName = getTestName(); - String root = remoteRoot + "/" + testName; - Path localFilePath = new Path(localRoot + "/" + testName + "1.txt"); - RandomAccessFile file = new RandomAccessFile(localFilePath.toUri().getPath(), "rw"); - int fileLen = 1024 * 1024; - byte[] bytes = new byte[fileLen]; - Random random = new Random(); - random.nextBytes(bytes); - file.write(bytes); - file.close(); - - Path remoteFilePath = new Path(root + "/s3/1/" + testName + "1" + ".txt"); - - s3IO.copyFromLocalFile(localFilePath, remoteFilePath); - Assert.assertEquals(s3IO.getFileSize(remoteFilePath), fileLen); - InputStream in = s3IO.open(remoteFilePath); - byte[] readBytes = IOUtils.toByteArray(in); - Assert.assertEquals(readBytes, bytes); - in.close(); - in = s3IO.open(remoteFilePath); - - for (int i = 0; i < 100; i++) { - readBytes[i] = (byte) in.read(); + int readFileLen = in.read(readBytes, 100, 2 * fileLen); + Assert.assertEquals(readBytes, bytes); + Assert.assertEquals(readFileLen, 1024 * 1024 - 100); + } + + @Test(enabled = false) + public void testCopyToLocal() throws IOException { + String testName = getTestName(); + String root = remoteRoot + "/" + testName; + Path localFilePath = new Path(localRoot + "/" + testName + "1.txt"); + Path remoteFilePath = new Path(root + "/s3/1/" + testName + "1" + ".txt"); + + int fileLen = 1024 * 1024; + byte[] bytes = new byte[fileLen]; + Random random = new Random(); + random.nextBytes(bytes); + + createNewFileWithFixedBytes(remoteFilePath, bytes); + + s3IO.copyToLocalFile(remoteFilePath, localFilePath); + Assert.assertEquals(s3IO.getFileSize(remoteFilePath), fileLen); + Assert.assertEquals(localFileSystem.getFileStatus(localFilePath).getLen(), fileLen); + + FSDataInputStream fsDataInputStream = localFileSystem.open(localFilePath); + byte[] readBytes = new byte[fileLen]; + int readLen = fsDataInputStream.read(readBytes); + Assert.assertEquals(readLen, fileLen); + Assert.assertEquals(readBytes, bytes); + } + + // This test is normally closed and is used to delete all data in test bucket. + @Test(enabled = false) + public void deleteAllObjects() { + String continuationToken = null; + + ListObjectsV2Request.Builder listRequestBuilder = + ListObjectsV2Request.builder().bucket(bucketName).maxKeys(S3_MAX_RETURN_KEYS); + + do { + if (continuationToken != null) { + listRequestBuilder.continuationToken(continuationToken); + } + + CompletableFuture future = + s3IO.getS3Client().listObjectsV2(listRequestBuilder.build()); + ListObjectsV2Response response; + try { + response = future.get(); + } catch (Throwable e) { + throw new RuntimeException("Failed to list objects in bucket: " + bucketName, e); + } + + for (S3Object s3Object : response.contents()) { + try { + s3IO.deleteFile(s3Object.key()); + } catch (IOException e) { + throw new RuntimeException( + "Failed to delete object in bucket: " + + bucketName + + ", object key: " + + s3Object.key(), + e); } - int readFileLen = in.read(readBytes, 100, 2 * fileLen); - Assert.assertEquals(readBytes, bytes); - Assert.assertEquals(readFileLen, 1024 * 1024 - 100); - } + } - @Test(enabled = false) - public void testCopyToLocal() throws IOException { - String testName = getTestName(); - String root = remoteRoot + "/" + testName; - Path localFilePath = new Path(localRoot + "/" + testName + "1.txt"); - Path remoteFilePath = new Path(root + "/s3/1/" + testName + "1" + ".txt"); - - int fileLen = 1024 * 1024; - byte[] bytes = new byte[fileLen]; - Random random = new Random(); - random.nextBytes(bytes); - - createNewFileWithFixedBytes(remoteFilePath, bytes); - - s3IO.copyToLocalFile(remoteFilePath, localFilePath); - Assert.assertEquals(s3IO.getFileSize(remoteFilePath), fileLen); - Assert.assertEquals(localFileSystem.getFileStatus(localFilePath).getLen(), fileLen); - - FSDataInputStream fsDataInputStream = localFileSystem.open(localFilePath); - byte[] readBytes = new byte[fileLen]; - int readLen = fsDataInputStream.read(readBytes); - Assert.assertEquals(readLen, fileLen); - Assert.assertEquals(readBytes, bytes); - } + continuationToken = response.nextContinuationToken(); + } while (continuationToken != null); + } - // This test is normally closed and is used to delete all data in test bucket. - @Test(enabled = false) - public void deleteAllObjects() { - String continuationToken = null; - - ListObjectsV2Request.Builder listRequestBuilder = ListObjectsV2Request.builder() - .bucket(bucketName).maxKeys(S3_MAX_RETURN_KEYS); - - do { - if (continuationToken != null) { - listRequestBuilder.continuationToken(continuationToken); - } - - CompletableFuture future = s3IO.getS3Client() - .listObjectsV2(listRequestBuilder.build()); - ListObjectsV2Response response; - try { - response = future.get(); - } catch (Throwable e) { - throw new RuntimeException("Failed to list objects in bucket: " + bucketName, e); - } - - for (S3Object s3Object : response.contents()) { - try { - s3IO.deleteFile(s3Object.key()); - } catch (IOException e) { - throw new RuntimeException( - "Failed to delete object in bucket: " + bucketName + ", object key: " - + s3Object.key(), e); - } - } - - continuationToken = response.nextContinuationToken(); - } while (continuationToken != null); - } - - private static String getTestName() { - return Thread.currentThread().getStackTrace()[2].getMethodName(); - } + private static String getTestName() { + return Thread.currentThread().getStackTrace()[2].getMethodName(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceBuilder.java b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceBuilder.java index 10b468508..aa25bf8db 100644 --- a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceBuilder.java +++ b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceBuilder.java @@ -23,9 +23,9 @@ public interface ServiceBuilder { - ServiceConsumer buildConsumer(Configuration configuration); + ServiceConsumer buildConsumer(Configuration configuration); - ServiceProvider buildProvider(Configuration configuration); + ServiceProvider buildProvider(Configuration configuration); - String serviceType(); + String serviceType(); } diff --git a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceBuilderFactory.java b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceBuilderFactory.java index 665c57204..18b34b5cc 100644 --- a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceBuilderFactory.java +++ b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceBuilderFactory.java @@ -22,29 +22,30 @@ import java.util.Map; import java.util.ServiceLoader; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; public class ServiceBuilderFactory { - private static final Map CONCURRENT_TYPE_MAP = new ConcurrentHashMap<>(); - - public static synchronized ServiceBuilder build(String serviceType) { - if (CONCURRENT_TYPE_MAP.containsKey(serviceType)) { - return CONCURRENT_TYPE_MAP.get(serviceType); - } - - ServiceLoader serviceLoader = ServiceLoader.load(ServiceBuilder.class); - for (ServiceBuilder storeBuilder : serviceLoader) { - if (storeBuilder.serviceType().equalsIgnoreCase(serviceType)) { - CONCURRENT_TYPE_MAP.put(serviceType, storeBuilder); - return storeBuilder; - } - } - throw new GeaflowRuntimeException(RuntimeErrors.INST.spiNotFoundError(serviceType)); + private static final Map CONCURRENT_TYPE_MAP = new ConcurrentHashMap<>(); + + public static synchronized ServiceBuilder build(String serviceType) { + if (CONCURRENT_TYPE_MAP.containsKey(serviceType)) { + return CONCURRENT_TYPE_MAP.get(serviceType); } - public static synchronized void clear() { - CONCURRENT_TYPE_MAP.clear(); + ServiceLoader serviceLoader = ServiceLoader.load(ServiceBuilder.class); + for (ServiceBuilder storeBuilder : serviceLoader) { + if (storeBuilder.serviceType().equalsIgnoreCase(serviceType)) { + CONCURRENT_TYPE_MAP.put(serviceType, storeBuilder); + return storeBuilder; + } } + throw new GeaflowRuntimeException(RuntimeErrors.INST.spiNotFoundError(serviceType)); + } + + public static synchronized void clear() { + CONCURRENT_TYPE_MAP.clear(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceConsumer.java b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceConsumer.java index bcaa8114c..c31945611 100644 --- a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceConsumer.java +++ b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceConsumer.java @@ -21,24 +21,15 @@ public interface ServiceConsumer { - /** - * Check if the specified path exists. - */ - boolean exists(String path); + /** Check if the specified path exists. */ + boolean exists(String path); - /** - * Get the data at the specified path and set a watch. - */ - byte[] getDataAndWatch(String path); + /** Get the data at the specified path and set a watch. */ + byte[] getDataAndWatch(String path); - /** - * Register the specified listener to receive updated events. - */ - default void register(ServiceListener listener) { - } + /** Register the specified listener to receive updated events. */ + default void register(ServiceListener listener) {} - /** - * close the consumer. - */ - void close(); + /** close the consumer. */ + void close(); } diff --git a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceDiscoveryType.java b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceDiscoveryType.java index b4eea0401..48629429f 100644 --- a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceDiscoveryType.java +++ b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceDiscoveryType.java @@ -23,22 +23,21 @@ import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; public enum ServiceDiscoveryType { + REDIS, - REDIS, + ZOOKEEPER; - ZOOKEEPER; - - public static ServiceDiscoveryType getEnum(String type) { - for (ServiceDiscoveryType serviceType : values()) { - if (serviceType.name().equalsIgnoreCase(type)) { - return serviceType; - } - } - return ZOOKEEPER; + public static ServiceDiscoveryType getEnum(String type) { + for (ServiceDiscoveryType serviceType : values()) { + if (serviceType.name().equalsIgnoreCase(type)) { + return serviceType; + } } + return ZOOKEEPER; + } - public static ServiceDiscoveryType getEnum(Configuration config) { - String type = config.getString(ExecutionConfigKeys.SERVICE_DISCOVERY_TYPE); - return getEnum(type); - } + public static ServiceDiscoveryType getEnum(Configuration config) { + String type = config.getString(ExecutionConfigKeys.SERVICE_DISCOVERY_TYPE); + return getEnum(type); + } } diff --git a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceListener.java b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceListener.java index fb809b2e9..b2ba05be0 100644 --- a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceListener.java +++ b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceListener.java @@ -21,23 +21,15 @@ public interface ServiceListener { - /** - * Called when a new node has been created. - */ - void nodeCreated(String path); + /** Called when a new node has been created. */ + void nodeCreated(String path); - /** - * Called when a node has been deleted. - */ - void nodeDeleted(String path); + /** Called when a node has been deleted. */ + void nodeDeleted(String path); - /** - * Called when an existing node has changed data. - */ - void nodeDataChanged(String path); + /** Called when an existing node has changed data. */ + void nodeDataChanged(String path); - /** - * Called when an existing node has a child node added or removed. - */ - void nodeChildrenChanged(String path); + /** Called when an existing node has a child node added or removed. */ + void nodeChildrenChanged(String path); } diff --git a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceProvider.java b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceProvider.java index aa43bfe37..df51a5ae1 100644 --- a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceProvider.java +++ b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-api/src/main/java/org/apache/geaflow/service/discovery/ServiceProvider.java @@ -21,25 +21,18 @@ public interface ServiceProvider extends ServiceConsumer { - /** - * Watch the specified path for updated events. - * If the node already exists, the method returns true. If the node does not exist, the method returns false. - */ - boolean watchAndCheckExists(String path); + /** + * Watch the specified path for updated events. If the node already exists, the method returns + * true. If the node does not exist, the method returns false. + */ + boolean watchAndCheckExists(String path); - /** - * Delete the specified path. - */ - void delete(String path); + /** Delete the specified path. */ + void delete(String path); - /** - * Creates the specified node with the specified data and watches it. - */ - boolean createAndWatch(String path, byte[] data); - - /** - * Update the specified node with the specified data. - */ - boolean update(String path, byte[] data); + /** Creates the specified node with the specified data and watches it. */ + boolean createAndWatch(String path, byte[] data); + /** Update the specified node with the specified data. */ + boolean update(String path, byte[] data); } diff --git a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-redis/src/main/java/org/apache/geaflow/service/discovery/RecoverableRedis.java b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-redis/src/main/java/org/apache/geaflow/service/discovery/RecoverableRedis.java index d354774ef..2fb368c0f 100644 --- a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-redis/src/main/java/org/apache/geaflow/service/discovery/RecoverableRedis.java +++ b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-redis/src/main/java/org/apache/geaflow/service/discovery/RecoverableRedis.java @@ -24,55 +24,63 @@ import org.apache.geaflow.common.utils.RetryCommand; import org.apache.geaflow.store.context.StoreContext; import org.apache.geaflow.store.redis.BaseRedisStore; + import redis.clients.jedis.Jedis; public class RecoverableRedis extends BaseRedisStore { - private ISerializer serializer; - private String namespace; - - @Override - public void init(StoreContext storeContext) { - super.init(storeContext); - this.serializer = SerializerFactory.getKryoSerializer(); - this.namespace = storeContext.getName() + BaseRedisStore.REDIS_NAMESPACE_SPLITTER; - } + private ISerializer serializer; + private String namespace; - public void setData(String key, byte[] valueArray) { - byte[] keyArray = this.serializer.serialize(key); - byte[] redisKey = getRedisKey(keyArray); - RetryCommand.run(() -> { - try (Jedis jedis = jedisPool.getResource()) { - jedis.set(redisKey, valueArray); - } - return null; - }, retryTimes, retryIntervalMs); - } + @Override + public void init(StoreContext storeContext) { + super.init(storeContext); + this.serializer = SerializerFactory.getKryoSerializer(); + this.namespace = storeContext.getName() + BaseRedisStore.REDIS_NAMESPACE_SPLITTER; + } - public byte[] getData(String key) { - byte[] keyArray = this.serializer.serialize(key); - byte[] redisKey = getRedisKey(keyArray); - return RetryCommand.run(() -> { - try (Jedis jedis = jedisPool.getResource()) { - return jedis.get(redisKey); - } - }, retryTimes, retryIntervalMs); - } + public void setData(String key, byte[] valueArray) { + byte[] keyArray = this.serializer.serialize(key); + byte[] redisKey = getRedisKey(keyArray); + RetryCommand.run( + () -> { + try (Jedis jedis = jedisPool.getResource()) { + jedis.set(redisKey, valueArray); + } + return null; + }, + retryTimes, + retryIntervalMs); + } - public void deleteData(String key) { - byte[] keyArray = this.serializer.serialize(key); - byte[] redisKey = getRedisKey(keyArray); + public byte[] getData(String key) { + byte[] keyArray = this.serializer.serialize(key); + byte[] redisKey = getRedisKey(keyArray); + return RetryCommand.run( + () -> { + try (Jedis jedis = jedisPool.getResource()) { + return jedis.get(redisKey); + } + }, + retryTimes, + retryIntervalMs); + } - RetryCommand.run(() -> { - try (Jedis jedis = jedisPool.getResource()) { - jedis.del(redisKey); - } - return null; - }, retryTimes, retryIntervalMs); - } + public void deleteData(String key) { + byte[] keyArray = this.serializer.serialize(key); + byte[] redisKey = getRedisKey(keyArray); - public String getNamespace() { - return this.namespace; - } + RetryCommand.run( + () -> { + try (Jedis jedis = jedisPool.getResource()) { + jedis.del(redisKey); + } + return null; + }, + retryTimes, + retryIntervalMs); + } + public String getNamespace() { + return this.namespace; + } } - diff --git a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-redis/src/main/java/org/apache/geaflow/service/discovery/RedisServiceBuilder.java b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-redis/src/main/java/org/apache/geaflow/service/discovery/RedisServiceBuilder.java index 9f5ea7139..3809baa52 100644 --- a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-redis/src/main/java/org/apache/geaflow/service/discovery/RedisServiceBuilder.java +++ b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-redis/src/main/java/org/apache/geaflow/service/discovery/RedisServiceBuilder.java @@ -23,20 +23,20 @@ public class RedisServiceBuilder implements ServiceBuilder { - private static final String SERVICE_TYPE = "redis"; + private static final String SERVICE_TYPE = "redis"; - @Override - public ServiceConsumer buildConsumer(Configuration configuration) { - return new RedisServiceConsumer(configuration); - } + @Override + public ServiceConsumer buildConsumer(Configuration configuration) { + return new RedisServiceConsumer(configuration); + } - @Override - public ServiceProvider buildProvider(Configuration configuration) { - return new RedisServiceProvider(configuration); - } + @Override + public ServiceProvider buildProvider(Configuration configuration) { + return new RedisServiceProvider(configuration); + } - @Override - public String serviceType() { - return SERVICE_TYPE; - } + @Override + public String serviceType() { + return SERVICE_TYPE; + } } diff --git a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-redis/src/main/java/org/apache/geaflow/service/discovery/RedisServiceConsumer.java b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-redis/src/main/java/org/apache/geaflow/service/discovery/RedisServiceConsumer.java index 4035e6f25..7ee43b5e8 100644 --- a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-redis/src/main/java/org/apache/geaflow/service/discovery/RedisServiceConsumer.java +++ b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-redis/src/main/java/org/apache/geaflow/service/discovery/RedisServiceConsumer.java @@ -28,45 +28,44 @@ import org.slf4j.LoggerFactory; public class RedisServiceConsumer implements ServiceConsumer { - private static final Logger LOGGER = LoggerFactory.getLogger(RedisServiceConsumer.class); + private static final Logger LOGGER = LoggerFactory.getLogger(RedisServiceConsumer.class); - private final RecoverableRedis recoverableRedis; + private final RecoverableRedis recoverableRedis; - private final String baseKey; + private final String baseKey; - private final String namespace; + private final String namespace; - public RedisServiceConsumer(Configuration configuration) { - this.recoverableRedis = new RecoverableRedis(); - String appName = configuration.getString(ExecutionConfigKeys.JOB_APP_NAME); - this.baseKey = appName.startsWith("/") ? appName : "/" + appName; - StoreContext storeContext = new StoreContext(baseKey); - storeContext.withKeySerializer(new DefaultKVSerializer(String.class, null)); - storeContext.withConfig(configuration); - this.recoverableRedis.init(storeContext); - this.namespace = recoverableRedis.getNamespace(); - LOGGER.info("redis service consumer base key is {}, namespace is {}", this.baseKey, this.namespace); - } - - @Override - public boolean exists(String path) { - if (StringUtils.isBlank(path)) { - return this.recoverableRedis.getData(this.namespace) != null; - } - return this.recoverableRedis.getData(path) != null; - } + public RedisServiceConsumer(Configuration configuration) { + this.recoverableRedis = new RecoverableRedis(); + String appName = configuration.getString(ExecutionConfigKeys.JOB_APP_NAME); + this.baseKey = appName.startsWith("/") ? appName : "/" + appName; + StoreContext storeContext = new StoreContext(baseKey); + storeContext.withKeySerializer(new DefaultKVSerializer(String.class, null)); + storeContext.withConfig(configuration); + this.recoverableRedis.init(storeContext); + this.namespace = recoverableRedis.getNamespace(); + LOGGER.info( + "redis service consumer base key is {}, namespace is {}", this.baseKey, this.namespace); + } - @Override - public byte[] getDataAndWatch(String path) { - return this.recoverableRedis.getData(path); + @Override + public boolean exists(String path) { + if (StringUtils.isBlank(path)) { + return this.recoverableRedis.getData(this.namespace) != null; } + return this.recoverableRedis.getData(path) != null; + } + @Override + public byte[] getDataAndWatch(String path) { + return this.recoverableRedis.getData(path); + } - @Override - public void close() { - if (this.recoverableRedis != null) { - this.recoverableRedis.close(); - } + @Override + public void close() { + if (this.recoverableRedis != null) { + this.recoverableRedis.close(); } - + } } diff --git a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-redis/src/main/java/org/apache/geaflow/service/discovery/RedisServiceProvider.java b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-redis/src/main/java/org/apache/geaflow/service/discovery/RedisServiceProvider.java index 282034b4a..f3062b2c6 100644 --- a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-redis/src/main/java/org/apache/geaflow/service/discovery/RedisServiceProvider.java +++ b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-redis/src/main/java/org/apache/geaflow/service/discovery/RedisServiceProvider.java @@ -28,63 +28,62 @@ import org.slf4j.LoggerFactory; public class RedisServiceProvider implements ServiceProvider { - private static final Logger LOGGER = LoggerFactory.getLogger(RedisServiceProvider.class); - private final RecoverableRedis recoverableRedis; - private final String baseKey; - private final String namespace; + private static final Logger LOGGER = LoggerFactory.getLogger(RedisServiceProvider.class); + private final RecoverableRedis recoverableRedis; + private final String baseKey; + private final String namespace; - public RedisServiceProvider(Configuration configuration) { - this.recoverableRedis = new RecoverableRedis(); - String appName = configuration.getString(ExecutionConfigKeys.JOB_APP_NAME); - this.baseKey = appName.startsWith("/") ? appName : "/" + appName; - StoreContext storeContext = new StoreContext(baseKey); - storeContext.withKeySerializer(new DefaultKVSerializer(String.class, null)); - storeContext.withConfig(configuration); - this.recoverableRedis.init(storeContext); - this.namespace = recoverableRedis.getNamespace(); - this.recoverableRedis.setData(namespace, new byte[0]); - } + public RedisServiceProvider(Configuration configuration) { + this.recoverableRedis = new RecoverableRedis(); + String appName = configuration.getString(ExecutionConfigKeys.JOB_APP_NAME); + this.baseKey = appName.startsWith("/") ? appName : "/" + appName; + StoreContext storeContext = new StoreContext(baseKey); + storeContext.withKeySerializer(new DefaultKVSerializer(String.class, null)); + storeContext.withConfig(configuration); + this.recoverableRedis.init(storeContext); + this.namespace = recoverableRedis.getNamespace(); + this.recoverableRedis.setData(namespace, new byte[0]); + } - @Override - public boolean exists(String path) { - if (StringUtils.isBlank(path)) { - return this.recoverableRedis.getData(this.namespace) != null; - } - return this.recoverableRedis.getData(path) != null; + @Override + public boolean exists(String path) { + if (StringUtils.isBlank(path)) { + return this.recoverableRedis.getData(this.namespace) != null; } + return this.recoverableRedis.getData(path) != null; + } - @Override - public byte[] getDataAndWatch(String path) { - return this.recoverableRedis.getData(path); - } + @Override + public byte[] getDataAndWatch(String path) { + return this.recoverableRedis.getData(path); + } - @Override - public boolean watchAndCheckExists(String path) { - return false; - } + @Override + public boolean watchAndCheckExists(String path) { + return false; + } - @Override - public void delete(String path) { - this.recoverableRedis.deleteData(path); - } + @Override + public void delete(String path) { + this.recoverableRedis.deleteData(path); + } - @Override - public boolean createAndWatch(String path, byte[] data) { - this.recoverableRedis.setData(path, data); - return true; - } + @Override + public boolean createAndWatch(String path, byte[] data) { + this.recoverableRedis.setData(path, data); + return true; + } - @Override - public boolean update(String path, byte[] data) { - this.recoverableRedis.setData(path, data); - return true; - } + @Override + public boolean update(String path, byte[] data) { + this.recoverableRedis.setData(path, data); + return true; + } - @Override - public void close() { - if (this.recoverableRedis != null) { - this.recoverableRedis.close(); - } + @Override + public void close() { + if (this.recoverableRedis != null) { + this.recoverableRedis.close(); } - + } } diff --git a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-redis/src/test/java/org/apache/geaflow/service/discovery/RedisTest.java b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-redis/src/test/java/org/apache/geaflow/service/discovery/RedisTest.java index dc605d700..7b9f07dd3 100644 --- a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-redis/src/test/java/org/apache/geaflow/service/discovery/RedisTest.java +++ b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-redis/src/test/java/org/apache/geaflow/service/discovery/RedisTest.java @@ -21,11 +21,10 @@ import static org.apache.geaflow.common.config.keys.ExecutionConfigKeys.SERVICE_DISCOVERY_TYPE; -import com.github.fppt.jedismock.RedisServer; -import com.google.common.primitives.Longs; import java.io.IOException; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.store.redis.RedisConfigKeys; @@ -34,106 +33,108 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -public class RedisTest { - - private static final String MASTER = "master"; - private RedisServer redisServer; - private Configuration configuration; - private ServiceBuilder serviceBuilder; - private ServiceConsumer consumer; - private ServiceProvider provider; - private String serviceType = "redis"; - - @BeforeClass - public void prepare() throws IOException { - redisServer = RedisServer.newRedisServer().start(); - this.configuration = new Configuration(); - this.configuration.put(RedisConfigKeys.REDIS_HOST, redisServer.getHost()); - this.configuration.put(RedisConfigKeys.REDIS_PORT, String.valueOf(redisServer.getBindPort())); - this.configuration.put(SERVICE_DISCOVERY_TYPE, "redis"); - this.configuration.put(ExecutionConfigKeys.JOB_APP_NAME, "testJob123"); - } - - @AfterClass - public void tearUp() throws IOException { - if (consumer != null) { - consumer.close(); - } - if (provider != null) { - provider.close(); - } - redisServer.stop(); - } - - @Test - public void testRedisServiceBuilder() { - serviceBuilder = ServiceBuilderFactory.build( - configuration.getString(SERVICE_DISCOVERY_TYPE)); - Assert.assertTrue(serviceBuilder instanceof RedisServiceBuilder); - this.consumer = ServiceBuilderFactory.build(serviceType).buildConsumer(configuration); - Assert.assertTrue(consumer instanceof RedisServiceConsumer); - this.provider = ServiceBuilderFactory.build(serviceType).buildProvider(configuration); - Assert.assertTrue(provider instanceof RedisServiceProvider); - } - - @Test - public void testCreateBaseNode() { - this.provider = ServiceBuilderFactory.build(serviceType).buildProvider(configuration); - Assert.assertTrue(provider.exists("")); - } - - @Test - public void testDelete() { - this.provider = ServiceBuilderFactory.build(serviceType).buildProvider(configuration); - boolean res = provider.createAndWatch(MASTER, "123".getBytes()); - Assert.assertTrue(res); - Assert.assertTrue(provider.exists(MASTER)); - this.provider.delete(MASTER); - Assert.assertFalse(provider.exists(MASTER)); - } +import com.github.fppt.jedismock.RedisServer; +import com.google.common.primitives.Longs; - @Test - public void testConsumer() { - this.provider = ServiceBuilderFactory.build(serviceType).buildProvider(configuration); - ServiceConsumer consumer = ServiceBuilderFactory.build(serviceType).buildConsumer(configuration); - boolean res = provider.createAndWatch(MASTER, "123".getBytes()); - Assert.assertTrue(res); - byte[] datas = consumer.getDataAndWatch(MASTER); - Assert.assertEquals(datas, "123".getBytes()); - provider.delete(MASTER); - consumer.close(); - } +public class RedisTest { - @Test - public void testUpdate() { - String version = "version"; - this.provider = ServiceBuilderFactory.build(serviceType).buildProvider(configuration); - ServiceConsumer consumer = ServiceBuilderFactory.build(serviceType) - .buildConsumer(configuration); - - Assert.assertFalse(consumer.exists(version)); - byte[] versionData = Longs.toByteArray(2); - - this.provider.update(version, versionData); - byte[] data = consumer.getDataAndWatch(version); - Assert.assertEquals(versionData, data); - - versionData = Longs.toByteArray(3); - this.provider.update(version, versionData); - data = consumer.getDataAndWatch(version); - Assert.assertEquals(versionData, data); - consumer.close(); + private static final String MASTER = "master"; + private RedisServer redisServer; + private Configuration configuration; + private ServiceBuilder serviceBuilder; + private ServiceConsumer consumer; + private ServiceProvider provider; + private String serviceType = "redis"; + + @BeforeClass + public void prepare() throws IOException { + redisServer = RedisServer.newRedisServer().start(); + this.configuration = new Configuration(); + this.configuration.put(RedisConfigKeys.REDIS_HOST, redisServer.getHost()); + this.configuration.put(RedisConfigKeys.REDIS_PORT, String.valueOf(redisServer.getBindPort())); + this.configuration.put(SERVICE_DISCOVERY_TYPE, "redis"); + this.configuration.put(ExecutionConfigKeys.JOB_APP_NAME, "testJob123"); + } + + @AfterClass + public void tearUp() throws IOException { + if (consumer != null) { + consumer.close(); } - - @Test - public void testBaseKey() { - Map config = configuration.getConfigMap(); - Configuration newConfig = new Configuration(new HashMap<>(config)); - newConfig.put(ExecutionConfigKeys.JOB_APP_NAME, "234"); - this.consumer = ServiceBuilderFactory.build(serviceType).buildConsumer(newConfig); - this.provider = ServiceBuilderFactory.build(serviceType).buildProvider(newConfig); - Assert.assertTrue(provider.exists(null)); - ServiceBuilderFactory.clear(); + if (provider != null) { + provider.close(); } - + redisServer.stop(); + } + + @Test + public void testRedisServiceBuilder() { + serviceBuilder = ServiceBuilderFactory.build(configuration.getString(SERVICE_DISCOVERY_TYPE)); + Assert.assertTrue(serviceBuilder instanceof RedisServiceBuilder); + this.consumer = ServiceBuilderFactory.build(serviceType).buildConsumer(configuration); + Assert.assertTrue(consumer instanceof RedisServiceConsumer); + this.provider = ServiceBuilderFactory.build(serviceType).buildProvider(configuration); + Assert.assertTrue(provider instanceof RedisServiceProvider); + } + + @Test + public void testCreateBaseNode() { + this.provider = ServiceBuilderFactory.build(serviceType).buildProvider(configuration); + Assert.assertTrue(provider.exists("")); + } + + @Test + public void testDelete() { + this.provider = ServiceBuilderFactory.build(serviceType).buildProvider(configuration); + boolean res = provider.createAndWatch(MASTER, "123".getBytes()); + Assert.assertTrue(res); + Assert.assertTrue(provider.exists(MASTER)); + this.provider.delete(MASTER); + Assert.assertFalse(provider.exists(MASTER)); + } + + @Test + public void testConsumer() { + this.provider = ServiceBuilderFactory.build(serviceType).buildProvider(configuration); + ServiceConsumer consumer = + ServiceBuilderFactory.build(serviceType).buildConsumer(configuration); + boolean res = provider.createAndWatch(MASTER, "123".getBytes()); + Assert.assertTrue(res); + byte[] datas = consumer.getDataAndWatch(MASTER); + Assert.assertEquals(datas, "123".getBytes()); + provider.delete(MASTER); + consumer.close(); + } + + @Test + public void testUpdate() { + String version = "version"; + this.provider = ServiceBuilderFactory.build(serviceType).buildProvider(configuration); + ServiceConsumer consumer = + ServiceBuilderFactory.build(serviceType).buildConsumer(configuration); + + Assert.assertFalse(consumer.exists(version)); + byte[] versionData = Longs.toByteArray(2); + + this.provider.update(version, versionData); + byte[] data = consumer.getDataAndWatch(version); + Assert.assertEquals(versionData, data); + + versionData = Longs.toByteArray(3); + this.provider.update(version, versionData); + data = consumer.getDataAndWatch(version); + Assert.assertEquals(versionData, data); + consumer.close(); + } + + @Test + public void testBaseKey() { + Map config = configuration.getConfigMap(); + Configuration newConfig = new Configuration(new HashMap<>(config)); + newConfig.put(ExecutionConfigKeys.JOB_APP_NAME, "234"); + this.consumer = ServiceBuilderFactory.build(serviceType).buildConsumer(newConfig); + this.provider = ServiceBuilderFactory.build(serviceType).buildProvider(newConfig); + Assert.assertTrue(provider.exists(null)); + ServiceBuilderFactory.clear(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/RecoverableZooKeeper.java b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/RecoverableZooKeeper.java index 7b1ac7d05..6b1cc1697 100644 --- a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/RecoverableZooKeeper.java +++ b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/RecoverableZooKeeper.java @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.List; import java.util.Objects; + import org.apache.geaflow.common.utils.ProcessUtil; import org.apache.geaflow.common.utils.RetryCommand; import org.apache.zookeeper.CreateMode; @@ -37,112 +38,118 @@ public class RecoverableZooKeeper { - private static final Logger LOGGER = LoggerFactory.getLogger(RecoverableZooKeeper.class); - private final ZooKeeper zk; - private final int retryIntervalMillis; - private final int maxRetries; - private final String identifier; - - public RecoverableZooKeeper(String quorumServers, int sessionTimeout, Watcher watcher, - int maxRetries, int retryIntervalMillis) throws IOException { - this.zk = new ZooKeeper(quorumServers, sessionTimeout, watcher); - this.retryIntervalMillis = retryIntervalMillis; - this.identifier = ProcessUtil.getHostAndPid(); - LOGGER.info("The identifier of this process is {}", identifier); - this.maxRetries = maxRetries; - } - - - public boolean delete(String path, int version) { - return Boolean.TRUE.equals(RetryCommand.run(() -> { - zk.delete(path, version); - return true; - }, maxRetries, retryIntervalMillis)); - } - - - public Stat exists(String path, Watcher watcher) { - return RetryCommand.run(() -> zk.exists(path, watcher), maxRetries, retryIntervalMillis); - } - - - public Stat exists(String path, boolean watch) { - return RetryCommand.run(() -> zk.exists(path, watch), maxRetries, retryIntervalMillis); - } - - public List getChildren(String path, Watcher watcher) { - return RetryCommand.run(() -> zk.getChildren(path, watcher), maxRetries, - retryIntervalMillis); - } - - public byte[] getData(String path, Watcher watcher, Stat stat) { - return RetryCommand.run(() -> zk.getData(path, watcher, stat), maxRetries, - retryIntervalMillis); + private static final Logger LOGGER = LoggerFactory.getLogger(RecoverableZooKeeper.class); + private final ZooKeeper zk; + private final int retryIntervalMillis; + private final int maxRetries; + private final String identifier; + + public RecoverableZooKeeper( + String quorumServers, + int sessionTimeout, + Watcher watcher, + int maxRetries, + int retryIntervalMillis) + throws IOException { + this.zk = new ZooKeeper(quorumServers, sessionTimeout, watcher); + this.retryIntervalMillis = retryIntervalMillis; + this.identifier = ProcessUtil.getHostAndPid(); + LOGGER.info("The identifier of this process is {}", identifier); + this.maxRetries = maxRetries; + } + + public boolean delete(String path, int version) { + return Boolean.TRUE.equals( + RetryCommand.run( + () -> { + zk.delete(path, version); + return true; + }, + maxRetries, + retryIntervalMillis)); + } + + public Stat exists(String path, Watcher watcher) { + return RetryCommand.run(() -> zk.exists(path, watcher), maxRetries, retryIntervalMillis); + } + + public Stat exists(String path, boolean watch) { + return RetryCommand.run(() -> zk.exists(path, watch), maxRetries, retryIntervalMillis); + } + + public List getChildren(String path, Watcher watcher) { + return RetryCommand.run(() -> zk.getChildren(path, watcher), maxRetries, retryIntervalMillis); + } + + public byte[] getData(String path, Watcher watcher, Stat stat) { + return RetryCommand.run(() -> zk.getData(path, watcher, stat), maxRetries, retryIntervalMillis); + } + + public String create(String path, byte[] data, List acl, CreateMode createMode) + throws KeeperException, InterruptedException { + + switch (createMode) { + case EPHEMERAL: + case PERSISTENT: + return createNonSequential(path, data, acl, createMode); + default: + throw new IllegalArgumentException("Unrecognized CreateMode: " + createMode); } - - public String create(String path, byte[] data, List acl, CreateMode createMode) - throws KeeperException, InterruptedException { - - switch (createMode) { - case EPHEMERAL: - case PERSISTENT: - return createNonSequential(path, data, acl, createMode); - default: - throw new IllegalArgumentException("Unrecognized CreateMode: " + createMode); - } + } + + private String createNonSequential(String path, byte[] data, List acl, CreateMode createMode) + throws KeeperException, InterruptedException { + if (exists(path, false) != null) { + byte[] currentData = zk.getData(path, false, null); + if (currentData != null && Arrays.equals(currentData, data)) { + // We successfully created a non-sequential node + return path; + } } - - private String createNonSequential(String path, byte[] data, List acl, - CreateMode createMode) - throws KeeperException, InterruptedException { + try { + return zk.create(path, data, acl, createMode); + } catch (KeeperException e) { + if (Objects.requireNonNull(e.code()) == Code.NODEEXISTS) { + // If the connection was lost, there is still a possibility that + // we have successfully created the node at our previous attempt, + // so we read the node and compare. if (exists(path, false) != null) { - byte[] currentData = zk.getData(path, false, null); - if (currentData != null && Arrays.equals(currentData, data)) { - // We successfully created a non-sequential node - return path; - } - } - try { - return zk.create(path, data, acl, createMode); - } catch (KeeperException e) { - if (Objects.requireNonNull(e.code()) == Code.NODEEXISTS) { - // If the connection was lost, there is still a possibility that - // we have successfully created the node at our previous attempt, - // so we read the node and compare. - if (exists(path, false) != null) { - byte[] currentData = zk.getData(path, false, null); - if (currentData != null && Arrays.equals(currentData, data)) { - // We successfully created a non-sequential node - return path; - } - LOGGER.error("Node {} already exists with {}, could not write {}", path, - Arrays.toString(currentData), Arrays.toString(data)); - } - } - throw e; + byte[] currentData = zk.getData(path, false, null); + if (currentData != null && Arrays.equals(currentData, data)) { + // We successfully created a non-sequential node + return path; + } + LOGGER.error( + "Node {} already exists with {}, could not write {}", + path, + Arrays.toString(currentData), + Arrays.toString(data)); } + } + throw e; } - - public boolean setData(String path, byte[] data) throws KeeperException, InterruptedException { - try { - if (exists(path, false) != null) { - byte[] currentData = zk.getData(path, false, null); - if (currentData != null && Arrays.equals(currentData, data)) { - return true; - } - } - zk.setData(path, data, -1); - } catch (KeeperException e) { - throw e; + } + + public boolean setData(String path, byte[] data) throws KeeperException, InterruptedException { + try { + if (exists(path, false) != null) { + byte[] currentData = zk.getData(path, false, null); + if (currentData != null && Arrays.equals(currentData, data)) { + return true; } - return true; + } + zk.setData(path, data, -1); + } catch (KeeperException e) { + throw e; } + return true; + } - public long getSessionId() { - return zk.getSessionId(); - } + public long getSessionId() { + return zk.getSessionId(); + } - public void close() throws InterruptedException { - zk.close(); - } + public void close() throws InterruptedException { + zk.close(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZKUtil.java b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZKUtil.java index 30537735b..a3413ede5 100644 --- a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZKUtil.java +++ b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZKUtil.java @@ -19,9 +19,9 @@ package org.apache.geaflow.service.discovery.zookeeper; -import com.google.common.collect.Lists; import java.io.IOException; import java.util.List; + import org.apache.commons.lang3.StringUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -35,134 +35,128 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class ZKUtil { - - private static final Logger LOG = LoggerFactory.getLogger(ZKUtil.class); - public static final char ZNODE_PATH_SEPARATOR = '/'; - public static final List CREATOR_ALL_AND_WORLD_READABLE = Lists.newArrayList( - new ACL(ZooDefs.Perms.ALL, Ids.ANYONE_ID_UNSAFE)); - - public static RecoverableZooKeeper connect(Configuration conf, Watcher watcher) - throws IOException { - String quorumServers = conf.getString(ZooKeeperConfigKeys.ZOOKEEPER_QUORUM_SERVERS); - if (quorumServers == null) { - throw new IllegalArgumentException("Not find zookeeper quorumServers"); - } - int timeout = conf.getInteger(ZooKeeperConfigKeys.ZOOKEEPER_SESSION_TIMEOUT); - LOG.info("opening connection to ZooKeeper with quorumServers {}", quorumServers); - int retry = conf.getInteger(ZooKeeperConfigKeys.ZOOKEEPER_RETRY); - int retryIntervalMillis = conf.getInteger( - ZooKeeperConfigKeys.ZOOKEEPER_RETRY_INTERVAL_MILL); - return new RecoverableZooKeeper(quorumServers, timeout, watcher, retry, - retryIntervalMillis); - } +import com.google.common.collect.Lists; +public class ZKUtil { - public static String joinZNode(String prefix, String suffix) { - if (StringUtils.isBlank(suffix)) { - return prefix; - } - return prefix + ZNODE_PATH_SEPARATOR + suffix; - } + private static final Logger LOG = LoggerFactory.getLogger(ZKUtil.class); + public static final char ZNODE_PATH_SEPARATOR = '/'; + public static final List CREATOR_ALL_AND_WORLD_READABLE = + Lists.newArrayList(new ACL(ZooDefs.Perms.ALL, Ids.ANYONE_ID_UNSAFE)); - /** - * Watch the specified znode for delete/create/change events. The watcher is - * set whether or not the node exists. If the node already exists, the method - * returns true. If the node does not exist, the method returns false. - */ - public static boolean watchAndCheckExists(ZooKeeperWatcher zkw, String znode) { - try { - Stat s = zkw.getRecoverableZooKeeper().exists(znode, zkw); - boolean exists = s != null; - if (exists) { - LOG.info("Set watcher on existing znode {}", znode); - } else { - LOG.info("{} does not exist. Watcher is set.", znode); - } - return exists; - } catch (Exception e) { - LOG.warn("Unable to set watcher on znode {}", znode, e); - return false; - } + public static RecoverableZooKeeper connect(Configuration conf, Watcher watcher) + throws IOException { + String quorumServers = conf.getString(ZooKeeperConfigKeys.ZOOKEEPER_QUORUM_SERVERS); + if (quorumServers == null) { + throw new IllegalArgumentException("Not find zookeeper quorumServers"); } - - /** - * Check if the specified node exists. Sets no watches. - */ - public static boolean exists(ZooKeeperWatcher zkw, String znode) { - try { - return zkw.getRecoverableZooKeeper().exists(znode, null) != null; - } catch (Exception e) { - LOG.warn("Unable to set watcher on znode ({})", znode, e); - return false; - } + int timeout = conf.getInteger(ZooKeeperConfigKeys.ZOOKEEPER_SESSION_TIMEOUT); + LOG.info("opening connection to ZooKeeper with quorumServers {}", quorumServers); + int retry = conf.getInteger(ZooKeeperConfigKeys.ZOOKEEPER_RETRY); + int retryIntervalMillis = conf.getInteger(ZooKeeperConfigKeys.ZOOKEEPER_RETRY_INTERVAL_MILL); + return new RecoverableZooKeeper(quorumServers, timeout, watcher, retry, retryIntervalMillis); + } + + public static String joinZNode(String prefix, String suffix) { + if (StringUtils.isBlank(suffix)) { + return prefix; } - - public static byte[] getDataAndWatch(ZooKeeperWatcher zkw, String znode) { - return getDataInternal(zkw, znode, null); + return prefix + ZNODE_PATH_SEPARATOR + suffix; + } + + /** + * Watch the specified znode for delete/create/change events. The watcher is set whether or not + * the node exists. If the node already exists, the method returns true. If the node does not + * exist, the method returns false. + */ + public static boolean watchAndCheckExists(ZooKeeperWatcher zkw, String znode) { + try { + Stat s = zkw.getRecoverableZooKeeper().exists(znode, zkw); + boolean exists = s != null; + if (exists) { + LOG.info("Set watcher on existing znode {}", znode); + } else { + LOG.info("{} does not exist. Watcher is set.", znode); + } + return exists; + } catch (Exception e) { + LOG.warn("Unable to set watcher on znode {}", znode, e); + return false; } - - private static byte[] getDataInternal(ZooKeeperWatcher zkw, String znode, Stat stat) { - try { - return zkw.getRecoverableZooKeeper().getData(znode, zkw, stat); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + } + + /** Check if the specified node exists. Sets no watches. */ + public static boolean exists(ZooKeeperWatcher zkw, String znode) { + try { + return zkw.getRecoverableZooKeeper().exists(znode, null) != null; + } catch (Exception e) { + LOG.warn("Unable to set watcher on znode ({})", znode, e); + return false; } + } + public static byte[] getDataAndWatch(ZooKeeperWatcher zkw, String znode) { + return getDataInternal(zkw, znode, null); + } - public static boolean createEphemeralNodeAndWatch(ZooKeeperWatcher zkw, String znode, - byte[] data) { - try { - zkw.getRecoverableZooKeeper() - .create(znode, data, CREATOR_ALL_AND_WORLD_READABLE, CreateMode.EPHEMERAL); - } catch (KeeperException.NodeExistsException nee) { - if (!watchAndCheckExists(zkw, znode)) { - // It did exist but now it doesn't, try again - return createEphemeralNodeAndWatch(zkw, znode, data); - } - return false; - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } - return true; + private static byte[] getDataInternal(ZooKeeperWatcher zkw, String znode, Stat stat) { + try { + return zkw.getRecoverableZooKeeper().getData(znode, zkw, stat); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } - - public static void deleteNode(ZooKeeperWatcher zkw, String node) { - zkw.getRecoverableZooKeeper().delete(node, -1); + } + + public static boolean createEphemeralNodeAndWatch( + ZooKeeperWatcher zkw, String znode, byte[] data) { + try { + zkw.getRecoverableZooKeeper() + .create(znode, data, CREATOR_ALL_AND_WORLD_READABLE, CreateMode.EPHEMERAL); + } catch (KeeperException.NodeExistsException nee) { + if (!watchAndCheckExists(zkw, znode)) { + // It did exist but now it doesn't, try again + return createEphemeralNodeAndWatch(zkw, znode, data); + } + return false; + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } - - public static void createPersistentNode(ZooKeeperWatcher zkw, String znode) { - - RecoverableZooKeeper zk = zkw.getRecoverableZooKeeper(); - try { - Stat stat = zk.exists(znode, false); - if (stat == null) { - String path = zk.create(znode, new byte[0], CREATOR_ALL_AND_WORLD_READABLE, - CreateMode.PERSISTENT); - LOG.info("{} create {} success", path, CreateMode.PERSISTENT); - } else { - LOG.info("{} exits, skip create", znode); - } - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + return true; + } + + public static void deleteNode(ZooKeeperWatcher zkw, String node) { + zkw.getRecoverableZooKeeper().delete(node, -1); + } + + public static void createPersistentNode(ZooKeeperWatcher zkw, String znode) { + + RecoverableZooKeeper zk = zkw.getRecoverableZooKeeper(); + try { + Stat stat = zk.exists(znode, false); + if (stat == null) { + String path = + zk.create(znode, new byte[0], CREATOR_ALL_AND_WORLD_READABLE, CreateMode.PERSISTENT); + LOG.info("{} create {} success", path, CreateMode.PERSISTENT); + } else { + LOG.info("{} exits, skip create", znode); + } + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } - - public static boolean updatePersistentNode(ZooKeeperWatcher zkw, String znode, byte[] data) { - RecoverableZooKeeper zk = zkw.getRecoverableZooKeeper(); - try { - if (zk.exists(znode, false) == null) { - String path = zk.create(znode, data, CREATOR_ALL_AND_WORLD_READABLE, - CreateMode.PERSISTENT); - LOG.info("{} create success", path); - return true; - } - zk.setData(znode, data); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } + } + + public static boolean updatePersistentNode(ZooKeeperWatcher zkw, String znode, byte[] data) { + RecoverableZooKeeper zk = zkw.getRecoverableZooKeeper(); + try { + if (zk.exists(znode, false) == null) { + String path = zk.create(znode, data, CREATOR_ALL_AND_WORLD_READABLE, CreateMode.PERSISTENT); + LOG.info("{} create success", path); return true; + } + zk.setData(znode, data); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); } - + return true; + } } diff --git a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperConfigKeys.java b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperConfigKeys.java index f4b83d657..0e0fb7b33 100644 --- a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperConfigKeys.java +++ b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperConfigKeys.java @@ -20,34 +20,34 @@ package org.apache.geaflow.service.discovery.zookeeper; import java.io.Serializable; + import org.apache.geaflow.common.config.ConfigKey; import org.apache.geaflow.common.config.ConfigKeys; public class ZooKeeperConfigKeys implements Serializable { - public static final ConfigKey ZOOKEEPER_BASE_NODE = ConfigKeys - .key("geaflow.zookeeper.znode.parent") - .noDefaultValue() - .description("zookeeper base node"); - - public static final ConfigKey ZOOKEEPER_QUORUM_SERVERS = ConfigKeys - .key("geaflow.zookeeper.quorum.servers") - .noDefaultValue() - .description("zookeeper quorum servers"); - - public static final ConfigKey ZOOKEEPER_SESSION_TIMEOUT = ConfigKeys - .key("geaflow.zookeeper.session.timeout") - .defaultValue(30 * 1000) - .description("zookeeper session timeout"); - - public static final ConfigKey ZOOKEEPER_RETRY = ConfigKeys - .key("geaflow.zookeeper.retry.count") - .defaultValue(5) - .description("zookeeper retry count"); - - public static final ConfigKey ZOOKEEPER_RETRY_INTERVAL_MILL = ConfigKeys - .key("geaflow.zookeeper.retry.interval.mill") - .defaultValue(1000) - .description("zookeeper retry interval"); - + public static final ConfigKey ZOOKEEPER_BASE_NODE = + ConfigKeys.key("geaflow.zookeeper.znode.parent") + .noDefaultValue() + .description("zookeeper base node"); + + public static final ConfigKey ZOOKEEPER_QUORUM_SERVERS = + ConfigKeys.key("geaflow.zookeeper.quorum.servers") + .noDefaultValue() + .description("zookeeper quorum servers"); + + public static final ConfigKey ZOOKEEPER_SESSION_TIMEOUT = + ConfigKeys.key("geaflow.zookeeper.session.timeout") + .defaultValue(30 * 1000) + .description("zookeeper session timeout"); + + public static final ConfigKey ZOOKEEPER_RETRY = + ConfigKeys.key("geaflow.zookeeper.retry.count") + .defaultValue(5) + .description("zookeeper retry count"); + + public static final ConfigKey ZOOKEEPER_RETRY_INTERVAL_MILL = + ConfigKeys.key("geaflow.zookeeper.retry.interval.mill") + .defaultValue(1000) + .description("zookeeper retry interval"); } diff --git a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperServiceBuilder.java b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperServiceBuilder.java index 357f3e20e..36fa8c821 100644 --- a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperServiceBuilder.java +++ b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperServiceBuilder.java @@ -26,20 +26,20 @@ public class ZooKeeperServiceBuilder implements ServiceBuilder { - private static final String SERVICE_TYPE = "zookeeper"; + private static final String SERVICE_TYPE = "zookeeper"; - @Override - public ServiceConsumer buildConsumer(Configuration configuration) { - return new ZooKeeperServiceConsumer(configuration); - } + @Override + public ServiceConsumer buildConsumer(Configuration configuration) { + return new ZooKeeperServiceConsumer(configuration); + } - @Override - public ServiceProvider buildProvider(Configuration configuration) { - return new ZooKeeperServiceProvider(configuration); - } + @Override + public ServiceProvider buildProvider(Configuration configuration) { + return new ZooKeeperServiceProvider(configuration); + } - @Override - public String serviceType() { - return SERVICE_TYPE; - } + @Override + public String serviceType() { + return SERVICE_TYPE; + } } diff --git a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperServiceConsumer.java b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperServiceConsumer.java index 93aae1c51..95e5d1eb1 100644 --- a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperServiceConsumer.java +++ b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperServiceConsumer.java @@ -25,31 +25,29 @@ public class ZooKeeperServiceConsumer implements ServiceConsumer { - private final ZooKeeperWatcher watcher; - - public ZooKeeperServiceConsumer(Configuration configuration) { - watcher = new ZooKeeperWatcher(configuration); - } - - @Override - public boolean exists(String path) { - return ZKUtil.exists(watcher, ZKUtil.joinZNode(watcher.baseZNode, path)); - } - - @Override - public byte[] getDataAndWatch(String path) { - return ZKUtil.getDataAndWatch(watcher, ZKUtil.joinZNode(watcher.baseZNode, path)); - } - - @Override - public void register(ServiceListener listener) { - watcher.registerListener(listener); - } - - @Override - public void close() { - watcher.close(); - } - - + private final ZooKeeperWatcher watcher; + + public ZooKeeperServiceConsumer(Configuration configuration) { + watcher = new ZooKeeperWatcher(configuration); + } + + @Override + public boolean exists(String path) { + return ZKUtil.exists(watcher, ZKUtil.joinZNode(watcher.baseZNode, path)); + } + + @Override + public byte[] getDataAndWatch(String path) { + return ZKUtil.getDataAndWatch(watcher, ZKUtil.joinZNode(watcher.baseZNode, path)); + } + + @Override + public void register(ServiceListener listener) { + watcher.registerListener(listener); + } + + @Override + public void close() { + watcher.close(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperServiceProvider.java b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperServiceProvider.java index b8d98bc19..7f09c7e88 100644 --- a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperServiceProvider.java +++ b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperServiceProvider.java @@ -25,51 +25,50 @@ public class ZooKeeperServiceProvider implements ServiceProvider { - private final ZooKeeperWatcher watcher; + private final ZooKeeperWatcher watcher; - public ZooKeeperServiceProvider(Configuration configuration) { - watcher = new ZooKeeperWatcher(configuration, true); - } + public ZooKeeperServiceProvider(Configuration configuration) { + watcher = new ZooKeeperWatcher(configuration, true); + } - @Override - public boolean exists(String path) { - return ZKUtil.exists(watcher, ZKUtil.joinZNode(watcher.baseZNode, path)); - } + @Override + public boolean exists(String path) { + return ZKUtil.exists(watcher, ZKUtil.joinZNode(watcher.baseZNode, path)); + } - @Override - public boolean watchAndCheckExists(String path) { - return ZKUtil.watchAndCheckExists(watcher, ZKUtil.joinZNode(watcher.baseZNode, path)); - } + @Override + public boolean watchAndCheckExists(String path) { + return ZKUtil.watchAndCheckExists(watcher, ZKUtil.joinZNode(watcher.baseZNode, path)); + } - @Override - public void delete(String path) { - ZKUtil.deleteNode(watcher, ZKUtil.joinZNode(watcher.baseZNode, path)); - } + @Override + public void delete(String path) { + ZKUtil.deleteNode(watcher, ZKUtil.joinZNode(watcher.baseZNode, path)); + } - @Override - public byte[] getDataAndWatch(String path) { - return ZKUtil.getDataAndWatch(watcher, ZKUtil.joinZNode(watcher.baseZNode, path)); - } + @Override + public byte[] getDataAndWatch(String path) { + return ZKUtil.getDataAndWatch(watcher, ZKUtil.joinZNode(watcher.baseZNode, path)); + } - @Override - public boolean createAndWatch(String path, byte[] data) { - return ZKUtil.createEphemeralNodeAndWatch(watcher, - ZKUtil.joinZNode(watcher.baseZNode, path), data); - } + @Override + public boolean createAndWatch(String path, byte[] data) { + return ZKUtil.createEphemeralNodeAndWatch( + watcher, ZKUtil.joinZNode(watcher.baseZNode, path), data); + } - @Override - public boolean update(String path, byte[] data) { - return ZKUtil.updatePersistentNode(watcher, ZKUtil.joinZNode(watcher.baseZNode, path), - data); - } + @Override + public boolean update(String path, byte[] data) { + return ZKUtil.updatePersistentNode(watcher, ZKUtil.joinZNode(watcher.baseZNode, path), data); + } - @Override - public void register(ServiceListener listener) { - watcher.registerListener(listener); - } + @Override + public void register(ServiceListener listener) { + watcher.registerListener(listener); + } - @Override - public void close() { - watcher.close(); - } + @Override + public void close() { + watcher.close(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperWatcher.java b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperWatcher.java index 60e3b2b2e..4c1c3d17c 100644 --- a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperWatcher.java +++ b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/main/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperWatcher.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -32,93 +33,93 @@ public class ZooKeeperWatcher implements Watcher { - private static final Logger LOGGER = LoggerFactory.getLogger(ZooKeeperWatcher.class); - - private final Configuration conf; - public String baseZNode; + private static final Logger LOGGER = LoggerFactory.getLogger(ZooKeeperWatcher.class); + + private final Configuration conf; + public String baseZNode; + + private final List listeners = new CopyOnWriteArrayList(); + private final RecoverableZooKeeper zooKeeper; + + public ZooKeeperWatcher(Configuration conf) { + this(conf, false); + } + + public ZooKeeperWatcher(Configuration conf, boolean canCreateBaseZNode) { + this.conf = conf; + + try { + this.zooKeeper = ZKUtil.connect(conf, this); + String jobName = conf.getString(ExecutionConfigKeys.JOB_APP_NAME); + baseZNode = conf.getString(ZooKeeperConfigKeys.ZOOKEEPER_BASE_NODE, "/" + jobName); + LOGGER.info("zk node {}", baseZNode); + if (canCreateBaseZNode) { + createBaseZNodes(); + } + } catch (Exception t) { + LOGGER.error("watcher init failed", t); + close(); + throw new GeaflowRuntimeException(t); + } + } - private final List listeners = new CopyOnWriteArrayList(); - private final RecoverableZooKeeper zooKeeper; + @Override + public void process(WatchedEvent event) { - public ZooKeeperWatcher(Configuration conf) { - this(conf, false); - } + switch (event.getType()) { - public ZooKeeperWatcher(Configuration conf, boolean canCreateBaseZNode) { - this.conf = conf; - - try { - this.zooKeeper = ZKUtil.connect(conf, this); - String jobName = conf.getString(ExecutionConfigKeys.JOB_APP_NAME); - baseZNode = conf.getString(ZooKeeperConfigKeys.ZOOKEEPER_BASE_NODE, "/" + jobName); - LOGGER.info("zk node {}", baseZNode); - if (canCreateBaseZNode) { - createBaseZNodes(); - } - } catch (Exception t) { - LOGGER.error("watcher init failed", t); - close(); - throw new GeaflowRuntimeException(t); + // Otherwise pass along to the listeners + case NodeCreated: + { + for (ServiceListener listener : listeners) { + listener.nodeCreated(event.getPath()); + } + break; } - } - @Override - public void process(WatchedEvent event) { - - switch (event.getType()) { - - // Otherwise pass along to the listeners - case NodeCreated: { - for (ServiceListener listener : listeners) { - listener.nodeCreated(event.getPath()); - } - break; - } - - case NodeDeleted: { - for (ServiceListener listener : listeners) { - listener.nodeDeleted(event.getPath()); - } - break; - } - case NodeDataChanged: { - for (ServiceListener listener : listeners) { - listener.nodeDataChanged(event.getPath()); - } - break; - } - default: - break; + case NodeDeleted: + { + for (ServiceListener listener : listeners) { + listener.nodeDeleted(event.getPath()); + } + break; } - - } - - protected void createBaseZNodes() { - try { - // Create all the necessary "directories" of znodes - ZKUtil.createPersistentNode(this, baseZNode); - } catch (Exception e) { - throw new GeaflowRuntimeException("Unexpected KeeperException creating base node", e); + case NodeDataChanged: + { + for (ServiceListener listener : listeners) { + listener.nodeDataChanged(event.getPath()); + } + break; } + default: + break; } - - public void close() { - try { - if (zooKeeper != null) { - zooKeeper.close(); - } - } catch (InterruptedException e) { - LOGGER.error("close exception", e); - } + } + + protected void createBaseZNodes() { + try { + // Create all the necessary "directories" of znodes + ZKUtil.createPersistentNode(this, baseZNode); + } catch (Exception e) { + throw new GeaflowRuntimeException("Unexpected KeeperException creating base node", e); } - - - public void registerListener(ServiceListener listener) { - listeners.add(listener); + } + + public void close() { + try { + if (zooKeeper != null) { + zooKeeper.close(); + } + } catch (InterruptedException e) { + LOGGER.error("close exception", e); } + } - public RecoverableZooKeeper getRecoverableZooKeeper() { - return zooKeeper; - } + public void registerListener(ServiceListener listener) { + listeners.add(listener); + } + public RecoverableZooKeeper getRecoverableZooKeeper() { + return zooKeeper; + } } diff --git a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/test/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperTest.java b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/test/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperTest.java index 8109b92ed..5bb3e6054 100644 --- a/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/test/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperTest.java +++ b/geaflow/geaflow-plugins/geaflow-service-discovery/geaflow-service-discovery-zookeeper/src/test/java/org/apache/geaflow/service/discovery/zookeeper/ZooKeeperTest.java @@ -19,11 +19,11 @@ package org.apache.geaflow.service.discovery.zookeeper; -import com.google.common.primitives.Longs; import java.io.File; import java.io.IOException; import java.util.HashMap; import java.util.Map; + import org.apache.curator.test.TestingServer; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -37,199 +37,204 @@ import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; -public class ZooKeeperTest { +import com.google.common.primitives.Longs; - private static final String MASTER = "master"; +public class ZooKeeperTest { - private Configuration configuration; + private static final String MASTER = "master"; - private ZooKeeperServiceProvider serviceProvider; + private Configuration configuration; - private TestListener listener; + private ZooKeeperServiceProvider serviceProvider; - private String serviceType = "zookeeper"; + private TestListener listener; - private String baseNode; + private String serviceType = "zookeeper"; - private File testDir; + private String baseNode; - private TestingServer server; + private File testDir; - @BeforeMethod - public void before() throws Exception { - testDir = new File("/tmp/zk" + System.currentTimeMillis()); - if (!testDir.exists()) { - testDir.mkdir(); - } - int port = PortUtil.getPort(5000, 6000); - server = new TestingServer(port, testDir); - server.start(); + private TestingServer server; - baseNode = "/test_zk" + System.currentTimeMillis(); - configuration = new Configuration(); - configuration.put(ZooKeeperConfigKeys.ZOOKEEPER_BASE_NODE, baseNode); - configuration.put(ZooKeeperConfigKeys.ZOOKEEPER_QUORUM_SERVERS, "localhost:" + port); + @BeforeMethod + public void before() throws Exception { + testDir = new File("/tmp/zk" + System.currentTimeMillis()); + if (!testDir.exists()) { + testDir.mkdir(); } + int port = PortUtil.getPort(5000, 6000); + server = new TestingServer(port, testDir); + server.start(); - @AfterMethod - public void tearDown() throws IOException { - server.stop(); - testDir.delete(); - } + baseNode = "/test_zk" + System.currentTimeMillis(); + configuration = new Configuration(); + configuration.put(ZooKeeperConfigKeys.ZOOKEEPER_BASE_NODE, baseNode); + configuration.put(ZooKeeperConfigKeys.ZOOKEEPER_QUORUM_SERVERS, "localhost:" + port); + } - @Test - public void testCreateBaseNode() { - serviceProvider = (ZooKeeperServiceProvider) ServiceBuilderFactory.build(serviceType) - .buildProvider(configuration); + @AfterMethod + public void tearDown() throws IOException { + server.stop(); + testDir.delete(); + } - Assert.assertTrue(serviceProvider.exists("")); - } + @Test + public void testCreateBaseNode() { + serviceProvider = + (ZooKeeperServiceProvider) + ServiceBuilderFactory.build(serviceType).buildProvider(configuration); - @Test - public void testCreateSequential() throws KeeperException { - serviceProvider = (ZooKeeperServiceProvider) ServiceBuilderFactory.build(serviceType) - .buildProvider(configuration); - listener = new TestListener(); - serviceProvider.register(listener); + Assert.assertTrue(serviceProvider.exists("")); + } - serviceProvider.watchAndCheckExists(MASTER); + @Test + public void testCreateSequential() throws KeeperException { + serviceProvider = + (ZooKeeperServiceProvider) + ServiceBuilderFactory.build(serviceType).buildProvider(configuration); + listener = new TestListener(); + serviceProvider.register(listener); - String data = "123"; - boolean res = serviceProvider.createAndWatch(MASTER, data.getBytes()); - Assert.assertTrue(res); - listener.updatePath = null; + serviceProvider.watchAndCheckExists(MASTER); - Assert.assertTrue(serviceProvider.exists(MASTER)); + String data = "123"; + boolean res = serviceProvider.createAndWatch(MASTER, data.getBytes()); + Assert.assertTrue(res); + listener.updatePath = null; - byte[] datas = serviceProvider.getDataAndWatch(MASTER); - Assert.assertEquals(datas, data.getBytes()); + Assert.assertTrue(serviceProvider.exists(MASTER)); - res = serviceProvider.createAndWatch(MASTER, "data".getBytes()); + byte[] datas = serviceProvider.getDataAndWatch(MASTER); + Assert.assertEquals(datas, data.getBytes()); - Assert.assertFalse(res); + res = serviceProvider.createAndWatch(MASTER, "data".getBytes()); - serviceProvider.delete(MASTER); + Assert.assertFalse(res); - res = serviceProvider.createAndWatch(MASTER, "data".getBytes()); + serviceProvider.delete(MASTER); - Assert.assertTrue(res); + res = serviceProvider.createAndWatch(MASTER, "data".getBytes()); - serviceProvider.delete(MASTER); - } + Assert.assertTrue(res); - @Test - public void testDelete() throws KeeperException { + serviceProvider.delete(MASTER); + } - serviceProvider = (ZooKeeperServiceProvider) ServiceBuilderFactory.build(serviceType) - .buildProvider(configuration); - listener = new TestListener(); - serviceProvider.register(listener); + @Test + public void testDelete() throws KeeperException { - boolean res = serviceProvider.createAndWatch(MASTER, "123".getBytes()); + serviceProvider = + (ZooKeeperServiceProvider) + ServiceBuilderFactory.build(serviceType).buildProvider(configuration); + listener = new TestListener(); + serviceProvider.register(listener); - Assert.assertTrue(res); + boolean res = serviceProvider.createAndWatch(MASTER, "123".getBytes()); - Assert.assertTrue(serviceProvider.exists(MASTER)); + Assert.assertTrue(res); - serviceProvider.delete(MASTER); + Assert.assertTrue(serviceProvider.exists(MASTER)); - Assert.assertFalse(serviceProvider.exists(MASTER)); + serviceProvider.delete(MASTER); - serviceProvider.close(); - } + Assert.assertFalse(serviceProvider.exists(MASTER)); - @Test - public void testConsumer() { + serviceProvider.close(); + } - serviceProvider = (ZooKeeperServiceProvider) ServiceBuilderFactory.build(serviceType) - .buildProvider(configuration); + @Test + public void testConsumer() { - ServiceConsumer consumer = ServiceBuilderFactory.build(serviceType) - .buildConsumer(configuration); + serviceProvider = + (ZooKeeperServiceProvider) + ServiceBuilderFactory.build(serviceType).buildProvider(configuration); - listener = new TestListener(); - consumer.register(listener); + ServiceConsumer consumer = + ServiceBuilderFactory.build(serviceType).buildConsumer(configuration); - boolean res = serviceProvider.createAndWatch(MASTER, "123".getBytes()); - Assert.assertTrue(res); + listener = new TestListener(); + consumer.register(listener); - byte[] datas = consumer.getDataAndWatch(MASTER); + boolean res = serviceProvider.createAndWatch(MASTER, "123".getBytes()); + Assert.assertTrue(res); - Assert.assertEquals(datas, "123".getBytes()); + byte[] datas = consumer.getDataAndWatch(MASTER); - serviceProvider.delete(MASTER); - } - - @Test - public void testUpdate() { - String version = "version"; - serviceProvider = (ZooKeeperServiceProvider) ServiceBuilderFactory.build(serviceType) - .buildProvider(configuration); + Assert.assertEquals(datas, "123".getBytes()); - ServiceConsumer consumer = ServiceBuilderFactory.build(serviceType) - .buildConsumer(configuration); - listener = new TestListener(); - consumer.register(listener); + serviceProvider.delete(MASTER); + } - Assert.assertFalse(consumer.exists(version)); - byte[] versionData = Longs.toByteArray(2); + @Test + public void testUpdate() { + String version = "version"; + serviceProvider = + (ZooKeeperServiceProvider) + ServiceBuilderFactory.build(serviceType).buildProvider(configuration); - serviceProvider.update(version, versionData); + ServiceConsumer consumer = + ServiceBuilderFactory.build(serviceType).buildConsumer(configuration); + listener = new TestListener(); + consumer.register(listener); - byte[] data = consumer.getDataAndWatch(version); + Assert.assertFalse(consumer.exists(version)); + byte[] versionData = Longs.toByteArray(2); - Assert.assertEquals(versionData, data); + serviceProvider.update(version, versionData); - versionData = Longs.toByteArray(3); - serviceProvider.update(version, versionData); + byte[] data = consumer.getDataAndWatch(version); - data = consumer.getDataAndWatch(version); - Assert.assertEquals(versionData, data); - - consumer.close(); - } + Assert.assertEquals(versionData, data); - @Test - public void testBaseNode() { + versionData = Longs.toByteArray(3); + serviceProvider.update(version, versionData); - Map config = configuration.getConfigMap(); - Configuration newConfig = new Configuration(new HashMap<>(config)); - newConfig.getConfigMap().remove(ZooKeeperConfigKeys.ZOOKEEPER_BASE_NODE.getKey()); - newConfig.put(ExecutionConfigKeys.JOB_APP_NAME, "234"); + data = consumer.getDataAndWatch(version); + Assert.assertEquals(versionData, data); - serviceProvider = (ZooKeeperServiceProvider) ServiceBuilderFactory.build(serviceType) - .buildProvider(newConfig); + consumer.close(); + } - Assert.assertTrue(serviceProvider.exists(null)); - serviceProvider.close(); - ServiceBuilderFactory.clear(); - } + @Test + public void testBaseNode() { + Map config = configuration.getConfigMap(); + Configuration newConfig = new Configuration(new HashMap<>(config)); + newConfig.getConfigMap().remove(ZooKeeperConfigKeys.ZOOKEEPER_BASE_NODE.getKey()); + newConfig.put(ExecutionConfigKeys.JOB_APP_NAME, "234"); - static class TestListener implements ServiceListener { + serviceProvider = + (ZooKeeperServiceProvider) + ServiceBuilderFactory.build(serviceType).buildProvider(newConfig); - private String updatePath; + Assert.assertTrue(serviceProvider.exists(null)); + serviceProvider.close(); + ServiceBuilderFactory.clear(); + } - @Override - public void nodeCreated(String path) { - updatePath = path; - } + static class TestListener implements ServiceListener { - @Override - public void nodeDeleted(String path) { - updatePath = path; - } + private String updatePath; - @Override - public void nodeDataChanged(String path) { - updatePath = path; - } + @Override + public void nodeCreated(String path) { + updatePath = path; + } - @Override - public void nodeChildrenChanged(String path) { - updatePath = path; - } + @Override + public void nodeDeleted(String path) { + updatePath = path; } + @Override + public void nodeDataChanged(String path) { + updatePath = path; + } + @Override + public void nodeChildrenChanged(String path) { + updatePath = path; + } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/IBaseStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/IBaseStore.java index 7cf02466e..34ad03e02 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/IBaseStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/IBaseStore.java @@ -21,23 +21,15 @@ import org.apache.geaflow.store.context.StoreContext; -/** - * Basic store interface for all types of store. - */ +/** Basic store interface for all types of store. */ public interface IBaseStore { - /** - * init the store context, including all setup properties. - */ - void init(StoreContext storeContext); + /** init the store context, including all setup properties. */ + void init(StoreContext storeContext); - /** - * flush memory data to disk. - */ - void flush(); + /** flush memory data to disk. */ + void flush(); - /** - * close the store handler and all other used resources. - */ - void close(); + /** close the store handler and all other used resources. */ + void close(); } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/IStatefulStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/IStatefulStore.java index d48dbd273..87685cd2a 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/IStatefulStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/IStatefulStore.java @@ -19,33 +19,21 @@ package org.apache.geaflow.store; -/** - * IStateful store is stateful, which means it ensure data HA and can be recovered. - */ +/** IStateful store is stateful, which means it ensure data HA and can be recovered. */ public interface IStatefulStore extends IBaseStore { - /** - * make a snapshot and ensure data HA. - */ - void archive(long checkpointId); + /** make a snapshot and ensure data HA. */ + void archive(long checkpointId); - /** - * recover the store data. - */ - void recovery(long checkpointId); + /** recover the store data. */ + void recovery(long checkpointId); - /** - * recover the latest store data. - */ - long recoveryLatest(); + /** recover the latest store data. */ + long recoveryLatest(); - /** - * trigger manual store data compaction. - */ - void compact(); + /** trigger manual store data compaction. */ + void compact(); - /** - * delete the store data. - */ - void drop(); + /** delete the store data. */ + void drop(); } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/IStoreBuilder.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/IStoreBuilder.java index 0acddbf39..188871014 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/IStoreBuilder.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/IStoreBuilder.java @@ -20,26 +20,19 @@ package org.apache.geaflow.store; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.state.DataModel; -/** - * The store builder interface is built by SPI, and used to get the specific store. - */ +/** The store builder interface is built by SPI, and used to get the specific store. */ public interface IStoreBuilder { - /** - * Returns the specific store by {@link DataModel}. - */ - IBaseStore getStore(DataModel type, Configuration config); + /** Returns the specific store by {@link DataModel}. */ + IBaseStore getStore(DataModel type, Configuration config); - /** - * Returns the store descriptor. - */ - StoreDesc getStoreDesc(); + /** Returns the store descriptor. */ + StoreDesc getStoreDesc(); - /** - * Returns the data models supported by the store builder. - */ - List supportedDataModel(); + /** Returns the data models supported by the store builder. */ + List supportedDataModel(); } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/StoreDesc.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/StoreDesc.java index 957b58275..14ead7f9f 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/StoreDesc.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/StoreDesc.java @@ -19,18 +19,12 @@ package org.apache.geaflow.store; -/** - * The store descriptor is used for telling the store attributes. - */ +/** The store descriptor is used for telling the store attributes. */ public interface StoreDesc { - /** - * Returns the store is local or not. - */ - boolean isLocalStore(); + /** Returns the store is local or not. */ + boolean isLocalStore(); - /** - * Returns the store type name. - */ - String name(); + /** Returns the store type name. */ + String name(); } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/StoreBuilderFactory.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/StoreBuilderFactory.java index f49b0d097..3e864bb9f 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/StoreBuilderFactory.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/StoreBuilderFactory.java @@ -22,27 +22,27 @@ import java.util.Map; import java.util.ServiceLoader; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.store.IStoreBuilder; public class StoreBuilderFactory { - private static final Map CONCURRENT_TYPE_MAP = new ConcurrentHashMap<>(); - - public static synchronized IStoreBuilder build(String storeType) { - if (CONCURRENT_TYPE_MAP.containsKey(storeType)) { - return CONCURRENT_TYPE_MAP.get(storeType); - } + private static final Map CONCURRENT_TYPE_MAP = new ConcurrentHashMap<>(); - ServiceLoader serviceLoader = ServiceLoader.load(IStoreBuilder.class); - for (IStoreBuilder storeBuilder : serviceLoader) { - if (storeBuilder.getStoreDesc().name().equalsIgnoreCase(storeType)) { - CONCURRENT_TYPE_MAP.put(storeType, storeBuilder); - return storeBuilder; - } - } - throw new GeaflowRuntimeException(RuntimeErrors.INST.spiNotFoundError(storeType)); + public static synchronized IStoreBuilder build(String storeType) { + if (CONCURRENT_TYPE_MAP.containsKey(storeType)) { + return CONCURRENT_TYPE_MAP.get(storeType); } + ServiceLoader serviceLoader = ServiceLoader.load(IStoreBuilder.class); + for (IStoreBuilder storeBuilder : serviceLoader) { + if (storeBuilder.getStoreDesc().name().equalsIgnoreCase(storeType)) { + CONCURRENT_TYPE_MAP.put(storeType, storeBuilder); + return storeBuilder; + } + } + throw new GeaflowRuntimeException(RuntimeErrors.INST.spiNotFoundError(storeType)); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/graph/BaseGraphStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/graph/BaseGraphStore.java index 4b7d94d99..b5def4f87 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/graph/BaseGraphStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/graph/BaseGraphStore.java @@ -27,18 +27,19 @@ public abstract class BaseGraphStore implements IPushDownStore { - protected StoreContext storeContext; - protected IFilterConverter filterConverter; + protected StoreContext storeContext; + protected IFilterConverter filterConverter; - public void init(StoreContext storeContext) { - this.storeContext = storeContext; - this.filterConverter = - storeContext.getConfig().getBoolean(StoreConfigKeys.STORE_FILTER_CODEGEN_ENABLE) - ? new CodeGenFilterConverter() : new DirectFilterConverter(); - } + public void init(StoreContext storeContext) { + this.storeContext = storeContext; + this.filterConverter = + storeContext.getConfig().getBoolean(StoreConfigKeys.STORE_FILTER_CODEGEN_ENABLE) + ? new CodeGenFilterConverter() + : new DirectFilterConverter(); + } - @Override - public IFilterConverter getFilterConverter() { - return filterConverter; - } + @Override + public IFilterConverter getFilterConverter() { + return filterConverter; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/graph/IDynamicGraphStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/graph/IDynamicGraphStore.java index 18c12eaf5..065885ecb 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/graph/IDynamicGraphStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/graph/IDynamicGraphStore.java @@ -22,7 +22,5 @@ import org.apache.geaflow.state.graph.DynamicGraphTrait; import org.apache.geaflow.store.IStatefulStore; -public interface IDynamicGraphStore extends DynamicGraphTrait, IStatefulStore, - IPushDownStore { - -} +public interface IDynamicGraphStore + extends DynamicGraphTrait, IStatefulStore, IPushDownStore {} diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/graph/IPushDownStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/graph/IPushDownStore.java index f4021f47b..856b22428 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/graph/IPushDownStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/graph/IPushDownStore.java @@ -23,8 +23,6 @@ public interface IPushDownStore { - /** - * Filter can be code generated or converted. - */ - IFilterConverter getFilterConverter(); + /** Filter can be code generated or converted. */ + IFilterConverter getFilterConverter(); } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/graph/IStaticGraphStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/graph/IStaticGraphStore.java index a53b902d0..2f1204ede 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/graph/IStaticGraphStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/graph/IStaticGraphStore.java @@ -22,7 +22,5 @@ import org.apache.geaflow.state.graph.StaticGraphTrait; import org.apache.geaflow.store.IStatefulStore; -public interface IStaticGraphStore extends StaticGraphTrait, IStatefulStore, - IPushDownStore { - -} +public interface IStaticGraphStore + extends StaticGraphTrait, IStatefulStore, IPushDownStore {} diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/key/IKListStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/key/IKListStore.java index a9330f5ab..339a1b099 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/key/IKListStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/key/IKListStore.java @@ -22,6 +22,4 @@ import org.apache.geaflow.state.key.KeyListTrait; import org.apache.geaflow.store.IBaseStore; -public interface IKListStore extends KeyListTrait, IBaseStore { - -} +public interface IKListStore extends KeyListTrait, IBaseStore {} diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/key/IKMapStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/key/IKMapStore.java index 2927067cc..47c78eade 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/key/IKMapStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/key/IKMapStore.java @@ -22,6 +22,4 @@ import org.apache.geaflow.state.key.KeyMapTrait; import org.apache.geaflow.store.IBaseStore; -public interface IKMapStore extends KeyMapTrait, IBaseStore { - -} +public interface IKMapStore extends KeyMapTrait, IBaseStore {} diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/key/IKVStatefulStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/key/IKVStatefulStore.java index dd10e37d5..3c381aeb9 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/key/IKVStatefulStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/key/IKVStatefulStore.java @@ -21,6 +21,4 @@ import org.apache.geaflow.store.IStatefulStore; -public interface IKVStatefulStore extends IKVStore, IStatefulStore { - -} +public interface IKVStatefulStore extends IKVStore, IStatefulStore {} diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/key/IKVStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/key/IKVStore.java index 2ee3161e7..9447b421a 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/key/IKVStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/api/key/IKVStore.java @@ -22,6 +22,4 @@ import org.apache.geaflow.state.key.KeyValueTrait; import org.apache.geaflow.store.IBaseStore; -public interface IKVStore extends KeyValueTrait, IBaseStore { - -} +public interface IKVStore extends KeyValueTrait, IBaseStore {} diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/config/StoreConfigKeys.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/config/StoreConfigKeys.java index 259f54009..0ce2cad72 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/config/StoreConfigKeys.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/config/StoreConfigKeys.java @@ -20,13 +20,14 @@ package org.apache.geaflow.store.config; import java.io.Serializable; + import org.apache.geaflow.common.config.ConfigKey; import org.apache.geaflow.common.config.ConfigKeys; public class StoreConfigKeys implements Serializable { - public static final ConfigKey STORE_FILTER_CODEGEN_ENABLE = ConfigKeys - .key("geaflow.store.filter.codegen.enable") - .defaultValue(true) - .description("store enable filter codegen, default true"); + public static final ConfigKey STORE_FILTER_CODEGEN_ENABLE = + ConfigKeys.key("geaflow.store.filter.codegen.enable") + .defaultValue(true) + .description("store enable filter codegen, default true"); } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/context/StoreContext.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/context/StoreContext.java index b34965664..705fc6950 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/context/StoreContext.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/context/StoreContext.java @@ -19,87 +19,88 @@ package org.apache.geaflow.store.context; -import com.google.common.base.Preconditions; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.metrics.common.api.MetricGroup; import org.apache.geaflow.state.schema.GraphDataSchema; import org.apache.geaflow.state.serializer.IKeySerializer; +import com.google.common.base.Preconditions; + public class StoreContext { - private String name; - private Configuration config; - private MetricGroup metricGroup; - - private int shardId; - private long version; - private GraphDataSchema graphSchema; - private IKeySerializer keySerializer; - - public StoreContext(String name) { - this.name = name; - } - - public StoreContext withConfig(Configuration config) { - this.config = config; - return this; - } - - public StoreContext withShardId(int shardId) { - this.shardId = shardId; - return this; - } - - public StoreContext withMetricGroup(MetricGroup metricGroup) { - this.metricGroup = metricGroup; - return this; - } - - public StoreContext withVersion(long version) { - this.version = version; - return this; - } - - public StoreContext withName(String name) { - this.name = name; - return this; - } - - public StoreContext withDataSchema(GraphDataSchema dataSchema) { - this.graphSchema = dataSchema; - return this; - } - - public StoreContext withKeySerializer(IKeySerializer keySerializer) { - this.keySerializer = keySerializer; - return this; - } - - public Configuration getConfig() { - return config; - } - - public int getShardId() { - return shardId; - } - - public MetricGroup getMetricGroup() { - return metricGroup; - } - - public long getVersion() { - return version; - } - - public String getName() { - return name; - } - - public GraphDataSchema getGraphSchema() { - return Preconditions.checkNotNull(graphSchema, "GraphMeta must be set"); - } - - public IKeySerializer getKeySerializer() { - return keySerializer; - } + private String name; + private Configuration config; + private MetricGroup metricGroup; + + private int shardId; + private long version; + private GraphDataSchema graphSchema; + private IKeySerializer keySerializer; + + public StoreContext(String name) { + this.name = name; + } + + public StoreContext withConfig(Configuration config) { + this.config = config; + return this; + } + + public StoreContext withShardId(int shardId) { + this.shardId = shardId; + return this; + } + + public StoreContext withMetricGroup(MetricGroup metricGroup) { + this.metricGroup = metricGroup; + return this; + } + + public StoreContext withVersion(long version) { + this.version = version; + return this; + } + + public StoreContext withName(String name) { + this.name = name; + return this; + } + + public StoreContext withDataSchema(GraphDataSchema dataSchema) { + this.graphSchema = dataSchema; + return this; + } + + public StoreContext withKeySerializer(IKeySerializer keySerializer) { + this.keySerializer = keySerializer; + return this; + } + + public Configuration getConfig() { + return config; + } + + public int getShardId() { + return shardId; + } + + public MetricGroup getMetricGroup() { + return metricGroup; + } + + public long getVersion() { + return version; + } + + public String getName() { + return name; + } + + public GraphDataSchema getGraphSchema() { + return Preconditions.checkNotNull(graphSchema, "GraphMeta must be set"); + } + + public IKeySerializer getKeySerializer() { + return keySerializer; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/data/AsyncFlushBuffer.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/data/AsyncFlushBuffer.java index 8ee0fce92..198a952b9 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/data/AsyncFlushBuffer.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/data/AsyncFlushBuffer.java @@ -26,6 +26,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; + import org.apache.commons.lang3.concurrent.BasicThreadFactory; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.StateConfigKeys; @@ -41,135 +42,143 @@ public class AsyncFlushBuffer { - private static final Logger LOGGER = LoggerFactory.getLogger(AsyncFlushBuffer.class); - private static final int SLEEP_MILLI_SECOND = 100; - - private final boolean deepCopy; - private final int bufferNum; - private final int bufferSize; - private final Consumer> flushFun; - private final ISerializer serializer; - private final ExceptionHandler flushError; - private List> buffers; - private int curWriteBufferIdx; - private ThreadPoolExecutor flushService; - - protected AtomicLong writeCounter = new AtomicLong(0); - - protected AtomicLong flushCounter = new AtomicLong(0); - - protected volatile Throwable exp; - - public AsyncFlushBuffer(Configuration config, - Consumer> flushFun, - ISerializer serializer) { - this.flushFun = flushFun; - this.serializer = serializer; - this.deepCopy = config.getBoolean(StateConfigKeys.STATE_WRITE_BUFFER_DEEP_COPY); - this.bufferNum = config.getInteger(StateConfigKeys.STATE_WRITE_BUFFER_NUMBER); - this.bufferSize = config.getInteger(StateConfigKeys.STATE_WRITE_BUFFER_SIZE); - this.flushError = exception -> { - exp = exception; - LOGGER.error("flush error", exception); + private static final Logger LOGGER = LoggerFactory.getLogger(AsyncFlushBuffer.class); + private static final int SLEEP_MILLI_SECOND = 100; + + private final boolean deepCopy; + private final int bufferNum; + private final int bufferSize; + private final Consumer> flushFun; + private final ISerializer serializer; + private final ExceptionHandler flushError; + private List> buffers; + private int curWriteBufferIdx; + private ThreadPoolExecutor flushService; + + protected AtomicLong writeCounter = new AtomicLong(0); + + protected AtomicLong flushCounter = new AtomicLong(0); + + protected volatile Throwable exp; + + public AsyncFlushBuffer( + Configuration config, + Consumer> flushFun, + ISerializer serializer) { + this.flushFun = flushFun; + this.serializer = serializer; + this.deepCopy = config.getBoolean(StateConfigKeys.STATE_WRITE_BUFFER_DEEP_COPY); + this.bufferNum = config.getInteger(StateConfigKeys.STATE_WRITE_BUFFER_NUMBER); + this.bufferSize = config.getInteger(StateConfigKeys.STATE_WRITE_BUFFER_SIZE); + this.flushError = + exception -> { + exp = exception; + LOGGER.error("flush error", exception); }; - initBuffer(); - } - - private void initBuffer() { - this.buffers = new ArrayList<>(bufferNum); - for (int i = 0; i < bufferNum; i++) { - this.buffers.add(new GraphWriteBuffer<>(bufferSize)); - } - this.curWriteBufferIdx = 0; - this.flushService = new ThreadPoolExecutor(1, 1, 0L, TimeUnit.MILLISECONDS, - new LinkedBlockingQueue<>(bufferNum + 1), new BasicThreadFactory.Builder().namingPattern("flush-%d").build()); - } - - public void flush() { - if (writeCounter.get() == 0) { - return; - } - for (int i = 0; i < bufferNum; i++) { - if (!buffers.get(i).isFlushing()) { - tryFlushBuffer(i, true); - } - } - - ExecutorUtil.spinLockMs(() -> flushCounter.get() == writeCounter.get(), - this::exceptionCheck, SLEEP_MILLI_SECOND); - flushCounter.set(0); - writeCounter.set(0); - exceptionCheck(); - } + initBuffer(); + } - private void exceptionCheck() { - if (exp != null) { - LOGGER.error("encounter exception"); - throw new GeaflowRuntimeException(RuntimeErrors.INST.runError(exp.getMessage()), exp); - } + private void initBuffer() { + this.buffers = new ArrayList<>(bufferNum); + for (int i = 0; i < bufferNum; i++) { + this.buffers.add(new GraphWriteBuffer<>(bufferSize)); } - - private void tryFlushBuffer(int idx, boolean force) { - GraphWriteBuffer buffer = buffers.get(idx); - boolean needFlush = buffer.needFlush() || (force && buffer.getSize() > 0); - if (!needFlush) { - return; - } - buffer.setFlushing(); - exceptionCheck(); - ExecutorUtil.execute(flushService, () -> flushWriteBuffer(buffer), flushError); - int toWriteBufferIdx = (idx + 1) % buffers.size(); - ExecutorUtil.spinLockMs(() -> !buffers.get(toWriteBufferIdx).needFlush(), - this::exceptionCheck, SLEEP_MILLI_SECOND); - - curWriteBufferIdx = toWriteBufferIdx; - } - - private void flushWriteBuffer(GraphWriteBuffer buffer) { - flushFun.accept(buffer); - flushCounter.addAndGet(buffer.getSize()); - buffer.clear(); + this.curWriteBufferIdx = 0; + this.flushService = + new ThreadPoolExecutor( + 1, + 1, + 0L, + TimeUnit.MILLISECONDS, + new LinkedBlockingQueue<>(bufferNum + 1), + new BasicThreadFactory.Builder().namingPattern("flush-%d").build()); + } + + public void flush() { + if (writeCounter.get() == 0) { + return; } - - public IVertex readBufferedVertex(K id) { - for (int i = 0; i < bufferNum; i++) { - int idx = (bufferNum + curWriteBufferIdx - i) % bufferNum; - GraphWriteBuffer buffer = buffers.get(idx); - IVertex vertex = buffer.getVertexId2Vertex().get(id); - if (vertex != null) { - return deepCopy ? serializer.copy(vertex) : vertex; - } - } - return null; + for (int i = 0; i < bufferNum; i++) { + if (!buffers.get(i).isFlushing()) { + tryFlushBuffer(i, true); + } } - public List> readBufferedEdges(K srcId) { - List> list = new ArrayList<>(); - for (int i = 0; i < bufferNum; i++) { - GraphWriteBuffer buffer = buffers.get(i); - List> edgeList = buffer.getVertexId2Edges().get(srcId); - if (edgeList != null) { - list.addAll(deepCopy ? serializer.copy(edgeList) : edgeList); - } - } - return list; + ExecutorUtil.spinLockMs( + () -> flushCounter.get() == writeCounter.get(), this::exceptionCheck, SLEEP_MILLI_SECOND); + flushCounter.set(0); + writeCounter.set(0); + exceptionCheck(); + } + + private void exceptionCheck() { + if (exp != null) { + LOGGER.error("encounter exception"); + throw new GeaflowRuntimeException(RuntimeErrors.INST.runError(exp.getMessage()), exp); } + } - public void addVertex(IVertex vertex) { - writeCounter.incrementAndGet(); - buffers.get(curWriteBufferIdx).addVertex(vertex); - tryFlushBuffer(curWriteBufferIdx, false); + private void tryFlushBuffer(int idx, boolean force) { + GraphWriteBuffer buffer = buffers.get(idx); + boolean needFlush = buffer.needFlush() || (force && buffer.getSize() > 0); + if (!needFlush) { + return; } - - public void addEdge(IEdge edge) { - writeCounter.incrementAndGet(); - buffers.get(curWriteBufferIdx).addEdge(edge); - tryFlushBuffer(curWriteBufferIdx, false); + buffer.setFlushing(); + exceptionCheck(); + ExecutorUtil.execute(flushService, () -> flushWriteBuffer(buffer), flushError); + int toWriteBufferIdx = (idx + 1) % buffers.size(); + ExecutorUtil.spinLockMs( + () -> !buffers.get(toWriteBufferIdx).needFlush(), this::exceptionCheck, SLEEP_MILLI_SECOND); + + curWriteBufferIdx = toWriteBufferIdx; + } + + private void flushWriteBuffer(GraphWriteBuffer buffer) { + flushFun.accept(buffer); + flushCounter.addAndGet(buffer.getSize()); + buffer.clear(); + } + + public IVertex readBufferedVertex(K id) { + for (int i = 0; i < bufferNum; i++) { + int idx = (bufferNum + curWriteBufferIdx - i) % bufferNum; + GraphWriteBuffer buffer = buffers.get(idx); + IVertex vertex = buffer.getVertexId2Vertex().get(id); + if (vertex != null) { + return deepCopy ? serializer.copy(vertex) : vertex; + } } - - public void close() { - flushService.shutdown(); - buffers.forEach(GraphWriteBuffer::clear); - buffers.clear(); + return null; + } + + public List> readBufferedEdges(K srcId) { + List> list = new ArrayList<>(); + for (int i = 0; i < bufferNum; i++) { + GraphWriteBuffer buffer = buffers.get(i); + List> edgeList = buffer.getVertexId2Edges().get(srcId); + if (edgeList != null) { + list.addAll(deepCopy ? serializer.copy(edgeList) : edgeList); + } } + return list; + } + + public void addVertex(IVertex vertex) { + writeCounter.incrementAndGet(); + buffers.get(curWriteBufferIdx).addVertex(vertex); + tryFlushBuffer(curWriteBufferIdx, false); + } + + public void addEdge(IEdge edge) { + writeCounter.incrementAndGet(); + buffers.get(curWriteBufferIdx).addEdge(edge); + tryFlushBuffer(curWriteBufferIdx, false); + } + + public void close() { + flushService.shutdown(); + buffers.forEach(GraphWriteBuffer::clear); + buffers.clear(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/data/AsyncFlushMultiVersionedBuffer.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/data/AsyncFlushMultiVersionedBuffer.java index e7511dda0..874b40291 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/data/AsyncFlushMultiVersionedBuffer.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/data/AsyncFlushMultiVersionedBuffer.java @@ -26,6 +26,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; + import org.apache.commons.lang3.concurrent.BasicThreadFactory; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.StateConfigKeys; @@ -41,131 +42,139 @@ public class AsyncFlushMultiVersionedBuffer { - private static final Logger LOGGER = LoggerFactory.getLogger(AsyncFlushMultiVersionedBuffer.class); - private static final int SLEEP_MILLI_SECOND = 100; - - private final boolean deepCopy; - private final int bufferNum; - private final int bufferSize; - private final Consumer> flushFun; - private final ISerializer serializer; - private final ExceptionHandler flushError; - private List> buffers; - private int curWriteBufferIdx; - private ThreadPoolExecutor flushService; - - protected AtomicLong writeCounter = new AtomicLong(0); - protected AtomicLong flushCounter = new AtomicLong(0); - protected volatile Throwable exp; - - public AsyncFlushMultiVersionedBuffer(Configuration config, - Consumer> flushFun, - ISerializer serializer) { - this.flushFun = flushFun; - this.serializer = serializer; - this.deepCopy = config.getBoolean(StateConfigKeys.STATE_WRITE_BUFFER_DEEP_COPY); - this.bufferNum = config.getInteger(StateConfigKeys.STATE_WRITE_BUFFER_NUMBER); - this.bufferSize = config.getInteger(StateConfigKeys.STATE_WRITE_BUFFER_SIZE); - this.flushError = exception -> { - exp = exception; - LOGGER.error("flush error", exception); + private static final Logger LOGGER = + LoggerFactory.getLogger(AsyncFlushMultiVersionedBuffer.class); + private static final int SLEEP_MILLI_SECOND = 100; + + private final boolean deepCopy; + private final int bufferNum; + private final int bufferSize; + private final Consumer> flushFun; + private final ISerializer serializer; + private final ExceptionHandler flushError; + private List> buffers; + private int curWriteBufferIdx; + private ThreadPoolExecutor flushService; + + protected AtomicLong writeCounter = new AtomicLong(0); + protected AtomicLong flushCounter = new AtomicLong(0); + protected volatile Throwable exp; + + public AsyncFlushMultiVersionedBuffer( + Configuration config, + Consumer> flushFun, + ISerializer serializer) { + this.flushFun = flushFun; + this.serializer = serializer; + this.deepCopy = config.getBoolean(StateConfigKeys.STATE_WRITE_BUFFER_DEEP_COPY); + this.bufferNum = config.getInteger(StateConfigKeys.STATE_WRITE_BUFFER_NUMBER); + this.bufferSize = config.getInteger(StateConfigKeys.STATE_WRITE_BUFFER_SIZE); + this.flushError = + exception -> { + exp = exception; + LOGGER.error("flush error", exception); }; - initBuffer(); - } + initBuffer(); + } - private void initBuffer() { - this.buffers = new ArrayList<>(bufferNum); - for (int i = 0; i < bufferNum; i++) { - this.buffers.add(new GraphWriteMultiVersionedBuffer<>(bufferSize)); - } - this.curWriteBufferIdx = 0; - this.flushService = new ThreadPoolExecutor(1, 1, 0L, TimeUnit.MILLISECONDS, + private void initBuffer() { + this.buffers = new ArrayList<>(bufferNum); + for (int i = 0; i < bufferNum; i++) { + this.buffers.add(new GraphWriteMultiVersionedBuffer<>(bufferSize)); + } + this.curWriteBufferIdx = 0; + this.flushService = + new ThreadPoolExecutor( + 1, + 1, + 0L, + TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>(bufferNum * 2), new BasicThreadFactory.Builder().namingPattern("flush-%d").build()); - } - - public void flush() { - for (int i = 0; i < bufferNum; i++) { - if (!buffers.get(i).isFlushing()) { - tryFlushBuffer(i, true); - } - } - - ExecutorUtil.spinLockMs(() -> flushCounter.get() == writeCounter.get(), - this::exceptionCheck, SLEEP_MILLI_SECOND); - // LOGGER.info("flushCount {}", flushCounter.get()); - flushCounter.set(0); - writeCounter.set(0); - exceptionCheck(); - } - - private void exceptionCheck() { - if (exp != null) { - LOGGER.error("encounter exception"); - throw new GeaflowRuntimeException(RuntimeErrors.INST.runError(exp.getMessage()), exp); - } - } - - private void tryFlushBuffer(int idx, boolean force) { - GraphWriteMultiVersionedBuffer buffer = buffers.get(idx); - boolean needFlush = buffer.needFlush() || (force && buffer.getSize() > 0); - if (!needFlush) { - return; - } - buffer.setFlushing(); - exceptionCheck(); - ExecutorUtil.execute(flushService, () -> flushWriteBuffer(buffer), flushError); - int toWriteBufferIdx = (idx + 1) % buffers.size(); - ExecutorUtil.spinLockMs(() -> !buffers.get(toWriteBufferIdx).needFlush(), - this::exceptionCheck, SLEEP_MILLI_SECOND); - - curWriteBufferIdx = toWriteBufferIdx; - } + } - private void flushWriteBuffer(GraphWriteMultiVersionedBuffer buffer) { - flushFun.accept(buffer); - flushCounter.addAndGet(buffer.getSize()); - buffer.clear(); + public void flush() { + for (int i = 0; i < bufferNum; i++) { + if (!buffers.get(i).isFlushing()) { + tryFlushBuffer(i, true); + } } - public IVertex readBufferedVertex(long version, K id) { - for (int i = 0; i < bufferNum; i++) { - int idx = (bufferNum + curWriteBufferIdx - i) % bufferNum; - IVertex vertex = buffers.get(idx).getVertex(version, id); - if (vertex != null) { - return deepCopy ? serializer.copy(vertex) : vertex; - } - } - return null; + ExecutorUtil.spinLockMs( + () -> flushCounter.get() == writeCounter.get(), this::exceptionCheck, SLEEP_MILLI_SECOND); + // LOGGER.info("flushCount {}", flushCounter.get()); + flushCounter.set(0); + writeCounter.set(0); + exceptionCheck(); + } + + private void exceptionCheck() { + if (exp != null) { + LOGGER.error("encounter exception"); + throw new GeaflowRuntimeException(RuntimeErrors.INST.runError(exp.getMessage()), exp); } + } - public List> readBufferedEdges(long version, K srcId) { - List> list = new ArrayList<>(); - for (int i = 0; i < bufferNum; i++) { - List> edgeList = buffers.get(i).getEdges(version, srcId); - if (edgeList != null && edgeList.size() > 0) { - list.addAll(deepCopy ? serializer.copy(edgeList) : edgeList); - } - } - return list; + private void tryFlushBuffer(int idx, boolean force) { + GraphWriteMultiVersionedBuffer buffer = buffers.get(idx); + boolean needFlush = buffer.needFlush() || (force && buffer.getSize() > 0); + if (!needFlush) { + return; } - - public void addVertex(long version, IVertex vertex) { - writeCounter.incrementAndGet(); - buffers.get(curWriteBufferIdx).addVertex(version, vertex); - tryFlushBuffer(curWriteBufferIdx, false); + buffer.setFlushing(); + exceptionCheck(); + ExecutorUtil.execute(flushService, () -> flushWriteBuffer(buffer), flushError); + int toWriteBufferIdx = (idx + 1) % buffers.size(); + ExecutorUtil.spinLockMs( + () -> !buffers.get(toWriteBufferIdx).needFlush(), this::exceptionCheck, SLEEP_MILLI_SECOND); + + curWriteBufferIdx = toWriteBufferIdx; + } + + private void flushWriteBuffer(GraphWriteMultiVersionedBuffer buffer) { + flushFun.accept(buffer); + flushCounter.addAndGet(buffer.getSize()); + buffer.clear(); + } + + public IVertex readBufferedVertex(long version, K id) { + for (int i = 0; i < bufferNum; i++) { + int idx = (bufferNum + curWriteBufferIdx - i) % bufferNum; + IVertex vertex = buffers.get(idx).getVertex(version, id); + if (vertex != null) { + return deepCopy ? serializer.copy(vertex) : vertex; + } } - - public void addEdge(long version, IEdge edge) { - writeCounter.incrementAndGet(); - buffers.get(curWriteBufferIdx).addEdge(version, edge); - tryFlushBuffer(curWriteBufferIdx, false); - } - - public void close() { - flush(); - flushService.shutdown(); - buffers.forEach(GraphWriteMultiVersionedBuffer::clear); - buffers.clear(); + return null; + } + + public List> readBufferedEdges(long version, K srcId) { + List> list = new ArrayList<>(); + for (int i = 0; i < bufferNum; i++) { + List> edgeList = buffers.get(i).getEdges(version, srcId); + if (edgeList != null && edgeList.size() > 0) { + list.addAll(deepCopy ? serializer.copy(edgeList) : edgeList); + } } + return list; + } + + public void addVertex(long version, IVertex vertex) { + writeCounter.incrementAndGet(); + buffers.get(curWriteBufferIdx).addVertex(version, vertex); + tryFlushBuffer(curWriteBufferIdx, false); + } + + public void addEdge(long version, IEdge edge) { + writeCounter.incrementAndGet(); + buffers.get(curWriteBufferIdx).addEdge(version, edge); + tryFlushBuffer(curWriteBufferIdx, false); + } + + public void close() { + flush(); + flushService.shutdown(); + buffers.forEach(GraphWriteMultiVersionedBuffer::clear); + buffers.clear(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/data/GraphWriteBuffer.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/data/GraphWriteBuffer.java index 41ab723ea..c56b3df8f 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/data/GraphWriteBuffer.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/data/GraphWriteBuffer.java @@ -24,75 +24,76 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; public class GraphWriteBuffer implements Serializable { - private final int thresholdSize; - private volatile int size; - private volatile boolean isFlushing; - private Map>> vertexId2Edges; - private Map> vertexId2Vertex; - - public GraphWriteBuffer(int thresholdSize) { - this.size = 0; - this.isFlushing = false; - this.thresholdSize = thresholdSize; - this.vertexId2Edges = new ConcurrentHashMap<>(); - this.vertexId2Vertex = new ConcurrentHashMap<>(); - } + private final int thresholdSize; + private volatile int size; + private volatile boolean isFlushing; + private Map>> vertexId2Edges; + private Map> vertexId2Vertex; - public void addVertex(IVertex vertex) { - vertexId2Vertex.put(vertex.getId(), vertex); - this.size++; - } + public GraphWriteBuffer(int thresholdSize) { + this.size = 0; + this.isFlushing = false; + this.thresholdSize = thresholdSize; + this.vertexId2Edges = new ConcurrentHashMap<>(); + this.vertexId2Vertex = new ConcurrentHashMap<>(); + } - public void addEdge(IEdge edge) { - List> list = vertexId2Edges.computeIfAbsent(edge.getSrcId(), - k -> new ArrayList<>()); - list.add(edge); - this.size++; - } + public void addVertex(IVertex vertex) { + vertexId2Vertex.put(vertex.getId(), vertex); + this.size++; + } - public void addEdges(List> edges) { - if (edges == null || edges.size() == 0) { - return; - } - List> list = - vertexId2Edges.computeIfAbsent(edges.get(0).getSrcId(), k -> new ArrayList<>()); - list.addAll(edges); - this.size += edges.size(); - } + public void addEdge(IEdge edge) { + List> list = + vertexId2Edges.computeIfAbsent(edge.getSrcId(), k -> new ArrayList<>()); + list.add(edge); + this.size++; + } - public void setFlushing() { - isFlushing = true; + public void addEdges(List> edges) { + if (edges == null || edges.size() == 0) { + return; } + List> list = + vertexId2Edges.computeIfAbsent(edges.get(0).getSrcId(), k -> new ArrayList<>()); + list.addAll(edges); + this.size += edges.size(); + } - public boolean isFlushing() { - return isFlushing; - } + public void setFlushing() { + isFlushing = true; + } - public boolean needFlush() { - return size > thresholdSize; - } + public boolean isFlushing() { + return isFlushing; + } - public int getSize() { - return size; - } + public boolean needFlush() { + return size > thresholdSize; + } - public Map>> getVertexId2Edges() { - return vertexId2Edges; - } + public int getSize() { + return size; + } - public Map> getVertexId2Vertex() { - return vertexId2Vertex; - } + public Map>> getVertexId2Edges() { + return vertexId2Edges; + } - public void clear() { - this.vertexId2Vertex.clear(); - this.vertexId2Edges.clear(); - this.size = 0; - this.isFlushing = false; - } + public Map> getVertexId2Vertex() { + return vertexId2Vertex; + } + + public void clear() { + this.vertexId2Vertex.clear(); + this.vertexId2Edges.clear(); + this.size = 0; + this.isFlushing = false; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/data/GraphWriteMultiVersionedBuffer.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/data/GraphWriteMultiVersionedBuffer.java index c62c33b11..2e176b9c6 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/data/GraphWriteMultiVersionedBuffer.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/data/GraphWriteMultiVersionedBuffer.java @@ -25,87 +25,87 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; public class GraphWriteMultiVersionedBuffer implements Serializable { - private final int thresholdSize; - private volatile int size; - private volatile boolean isFlushing; - private Map>>> vertexId2Edges; - private Map>> vertexId2Vertex; - - public GraphWriteMultiVersionedBuffer(int thresholdSize) { - this.size = 0; - this.isFlushing = false; - this.thresholdSize = thresholdSize; - this.vertexId2Edges = new ConcurrentHashMap<>(); - this.vertexId2Vertex = new ConcurrentHashMap<>(); - } - - public void addEdge(long version, IEdge edge) { - K srcId = edge.getSrcId(); - Map>> map = vertexId2Edges.computeIfAbsent(srcId, - k -> new HashMap<>()); - List> edges = map.computeIfAbsent(version, k -> new ArrayList<>()); - edges.add(edge); - this.size++; - } - - public void addVertex(long version, IVertex vertex) { - K id = vertex.getId(); - Map> map = vertexId2Vertex.computeIfAbsent(id, k -> new HashMap<>()); - map.put(version, vertex); - this.size++; - } - - public IVertex getVertex(long version, K sid) { - Map> map = vertexId2Vertex.get(sid); - IVertex vertex = null; - if (map != null) { - vertex = map.get(version); - } - return vertex; - } - - public List> getEdges(long version, K sid) { - Map>> map = vertexId2Edges.get(sid); - List> list = new ArrayList<>(); - if (map != null) { - list = map.getOrDefault(version, new ArrayList<>()); - } - return list; - } - - public Map>>> getVertexId2Edges() { - return vertexId2Edges; + private final int thresholdSize; + private volatile int size; + private volatile boolean isFlushing; + private Map>>> vertexId2Edges; + private Map>> vertexId2Vertex; + + public GraphWriteMultiVersionedBuffer(int thresholdSize) { + this.size = 0; + this.isFlushing = false; + this.thresholdSize = thresholdSize; + this.vertexId2Edges = new ConcurrentHashMap<>(); + this.vertexId2Vertex = new ConcurrentHashMap<>(); + } + + public void addEdge(long version, IEdge edge) { + K srcId = edge.getSrcId(); + Map>> map = vertexId2Edges.computeIfAbsent(srcId, k -> new HashMap<>()); + List> edges = map.computeIfAbsent(version, k -> new ArrayList<>()); + edges.add(edge); + this.size++; + } + + public void addVertex(long version, IVertex vertex) { + K id = vertex.getId(); + Map> map = vertexId2Vertex.computeIfAbsent(id, k -> new HashMap<>()); + map.put(version, vertex); + this.size++; + } + + public IVertex getVertex(long version, K sid) { + Map> map = vertexId2Vertex.get(sid); + IVertex vertex = null; + if (map != null) { + vertex = map.get(version); } - - public Map>> getVertexId2Vertex() { - return vertexId2Vertex; - } - - public void setFlushing() { - isFlushing = true; - } - - public boolean isFlushing() { - return isFlushing; - } - - public boolean needFlush() { - return size > thresholdSize; - } - - public int getSize() { - return size; - } - - public void clear() { - this.vertexId2Vertex.clear(); - this.vertexId2Edges.clear(); - this.size = 0; - this.isFlushing = false; + return vertex; + } + + public List> getEdges(long version, K sid) { + Map>> map = vertexId2Edges.get(sid); + List> list = new ArrayList<>(); + if (map != null) { + list = map.getOrDefault(version, new ArrayList<>()); } + return list; + } + + public Map>>> getVertexId2Edges() { + return vertexId2Edges; + } + + public Map>> getVertexId2Vertex() { + return vertexId2Vertex; + } + + public void setFlushing() { + isFlushing = true; + } + + public boolean isFlushing() { + return isFlushing; + } + + public boolean needFlush() { + return size > thresholdSize; + } + + public int getSize() { + return size; + } + + public void clear() { + this.vertexId2Vertex.clear(); + this.vertexId2Edges.clear(); + this.size = 0; + this.isFlushing = false; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/encoder/BaseEncoder.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/encoder/BaseEncoder.java index a9d5e8782..989be8a30 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/encoder/BaseEncoder.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/encoder/BaseEncoder.java @@ -20,44 +20,45 @@ package org.apache.geaflow.store.encoder; import java.util.function.Function; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.state.schema.GraphDataSchema; public abstract class BaseEncoder { - protected final GraphDataSchema dataSchema; - protected final IType keyType; - protected final Function valueSerializer; - protected final Function valueDeserializer; - protected final boolean emptyProperty; + protected final GraphDataSchema dataSchema; + protected final IType keyType; + protected final Function valueSerializer; + protected final Function valueDeserializer; + protected final boolean emptyProperty; - protected BaseEncoder(GraphDataSchema dataSchema) { - this.dataSchema = dataSchema; - this.keyType = dataSchema.getKeyType(); - this.valueSerializer = initValueSerializer(dataSchema); - this.valueDeserializer = initValueDeserializer(dataSchema); - this.emptyProperty = initEmptyProperty(dataSchema); - } + protected BaseEncoder(GraphDataSchema dataSchema) { + this.dataSchema = dataSchema; + this.keyType = dataSchema.getKeyType(); + this.valueSerializer = initValueSerializer(dataSchema); + this.valueDeserializer = initValueDeserializer(dataSchema); + this.emptyProperty = initEmptyProperty(dataSchema); + } - protected abstract Function initValueSerializer(GraphDataSchema dataSchema); + protected abstract Function initValueSerializer(GraphDataSchema dataSchema); - protected abstract Function initValueDeserializer(GraphDataSchema dataSchema); + protected abstract Function initValueDeserializer(GraphDataSchema dataSchema); - protected abstract boolean initEmptyProperty(GraphDataSchema dataSchema); + protected abstract boolean initEmptyProperty(GraphDataSchema dataSchema); - protected GraphDataSchema getDataSchema() { - return dataSchema; - } + protected GraphDataSchema getDataSchema() { + return dataSchema; + } - protected Function getValueSerializer() { - return valueSerializer; - } + protected Function getValueSerializer() { + return valueSerializer; + } - protected Function getValueDeserializer() { - return valueDeserializer; - } + protected Function getValueDeserializer() { + return valueDeserializer; + } - protected boolean isEmptyProperty() { - return emptyProperty; - } + protected boolean isEmptyProperty() { + return emptyProperty; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/iterator/EdgeListScanIterator.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/iterator/EdgeListScanIterator.java index a6ad44215..479a70ee3 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/iterator/EdgeListScanIterator.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/iterator/EdgeListScanIterator.java @@ -21,62 +21,63 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.model.graph.edge.IEdge; public class EdgeListScanIterator implements CloseableIterator>> { - private final CloseableIterator> edgeIterator; - private IEdge residualEdge; - private List> nextValue; + private final CloseableIterator> edgeIterator; + private IEdge residualEdge; + private List> nextValue; - public EdgeListScanIterator(CloseableIterator> edgeIterator) { - this.edgeIterator = edgeIterator; - } + public EdgeListScanIterator(CloseableIterator> edgeIterator) { + this.edgeIterator = edgeIterator; + } - @Override - public boolean hasNext() { - nextValue = getEdgesFromIterator(); - if (nextValue.size() == 0) { - return false; - } else { - return true; - } + @Override + public boolean hasNext() { + nextValue = getEdgesFromIterator(); + if (nextValue.size() == 0) { + return false; + } else { + return true; } + } - @Override - public List> next() { - return nextValue; - } + @Override + public List> next() { + return nextValue; + } - private List> getEdgesFromIterator() { - List> list = new ArrayList<>(); - final IEdge lastResidualEdge = residualEdge; - K key = null; - if (residualEdge != null) { - list.add(residualEdge); - key = residualEdge.getSrcId(); - } - while (edgeIterator.hasNext()) { - IEdge edge = edgeIterator.next(); - if (key == null) { - key = edge.getSrcId(); - } - if (edge.getSrcId().equals(key)) { - list.add(edge); - } else { - residualEdge = edge; - break; - } - } - if (lastResidualEdge == residualEdge) { - residualEdge = null; - } - return list; + private List> getEdgesFromIterator() { + List> list = new ArrayList<>(); + final IEdge lastResidualEdge = residualEdge; + K key = null; + if (residualEdge != null) { + list.add(residualEdge); + key = residualEdge.getSrcId(); } - - @Override - public void close() { - this.edgeIterator.close(); + while (edgeIterator.hasNext()) { + IEdge edge = edgeIterator.next(); + if (key == null) { + key = edge.getSrcId(); + } + if (edge.getSrcId().equals(key)) { + list.add(edge); + } else { + residualEdge = edge; + break; + } + } + if (lastResidualEdge == residualEdge) { + residualEdge = null; } + return list; + } + + @Override + public void close() { + this.edgeIterator.close(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/iterator/EdgeScanIterator.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/iterator/EdgeScanIterator.java index daae1063a..6bc94c748 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/iterator/EdgeScanIterator.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/iterator/EdgeScanIterator.java @@ -22,6 +22,7 @@ import java.util.Iterator; import java.util.function.BiFunction; import java.util.function.Supplier; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.common.tuple.Tuple; import org.apache.geaflow.model.graph.edge.IEdge; @@ -32,50 +33,48 @@ public class EdgeScanIterator implements CloseableIterator> { - private Supplier filterFun; - private final Iterator> iterator; - private final BiFunction> edgeDecoder; - private IEdge nextValue; - private K lastKey = null; - private IGraphFilter filter = null; + private Supplier filterFun; + private final Iterator> iterator; + private final BiFunction> edgeDecoder; + private IEdge nextValue; + private K lastKey = null; + private IGraphFilter filter = null; - public EdgeScanIterator( - Iterator> iterator, - IStatePushDown pushdown, - BiFunction> decoderFun) { + public EdgeScanIterator( + Iterator> iterator, + IStatePushDown pushdown, + BiFunction> decoderFun) { - IGraphFilter filter = (IGraphFilter) pushdown.getFilter(); - IEdgeLimit limit = pushdown.getEdgeLimit(); - filterFun = limit == null ? () -> filter : () -> LimitFilterBuilder.build(filter, limit); - this.iterator = iterator; - this.edgeDecoder = decoderFun; - } + IGraphFilter filter = (IGraphFilter) pushdown.getFilter(); + IEdgeLimit limit = pushdown.getEdgeLimit(); + filterFun = limit == null ? () -> filter : () -> LimitFilterBuilder.build(filter, limit); + this.iterator = iterator; + this.edgeDecoder = decoderFun; + } - @Override - public boolean hasNext() { - while (iterator.hasNext()) { - Tuple pair = iterator.next(); + @Override + public boolean hasNext() { + while (iterator.hasNext()) { + Tuple pair = iterator.next(); - nextValue = edgeDecoder.apply(pair.f0, pair.f1); - if (!nextValue.getSrcId().equals(lastKey)) { - filter = filterFun.get(); - lastKey = nextValue.getSrcId(); - } - if (!filter.filterEdge(nextValue)) { - continue; - } - return true; - } - return false; + nextValue = edgeDecoder.apply(pair.f0, pair.f1); + if (!nextValue.getSrcId().equals(lastKey)) { + filter = filterFun.get(); + lastKey = nextValue.getSrcId(); + } + if (!filter.filterEdge(nextValue)) { + continue; + } + return true; } + return false; + } - @Override - public IEdge next() { - return nextValue; - } + @Override + public IEdge next() { + return nextValue; + } - @Override - public void close() { - - } + @Override + public void close() {} } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/iterator/KeysIterator.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/iterator/KeysIterator.java index 8b816652f..bdaae6f49 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/iterator/KeysIterator.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/iterator/KeysIterator.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.function.BiFunction; import java.util.function.Function; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.state.pushdown.IStatePushDown; import org.apache.geaflow.state.pushdown.StatePushDown; @@ -30,47 +31,45 @@ public class KeysIterator implements CloseableIterator { - private final Iterator iterator; - private final BiFunction fetchFun; - private Function pushdownFun; - private R nextValue; - - public KeysIterator(List keys, BiFunction fetchFun, - IStatePushDown pushdown) { - this.fetchFun = fetchFun; - this.iterator = keys.iterator(); - if (pushdown.getFilters() != null) { - StatePushDown simpleKeyPushDown = StatePushDown.of() - .withEdgeLimit(pushdown.getEdgeLimit()) - .withOrderFields(pushdown.getOrderFields()); - this.pushdownFun = k -> simpleKeyPushDown.withFilter( - (IFilter) pushdown.getFilters().get(k)); - } else { - this.pushdownFun = k -> pushdown; - } - } + private final Iterator iterator; + private final BiFunction fetchFun; + private Function pushdownFun; + private R nextValue; - @Override - public boolean hasNext() { - while (iterator.hasNext()) { - K key = iterator.next(); - IStatePushDown pushdown = pushdownFun.apply(key); - nextValue = fetchFun.apply(key, pushdown); - if (nextValue == null) { - continue; - } - return true; - } - return false; + public KeysIterator( + List keys, BiFunction fetchFun, IStatePushDown pushdown) { + this.fetchFun = fetchFun; + this.iterator = keys.iterator(); + if (pushdown.getFilters() != null) { + StatePushDown simpleKeyPushDown = + StatePushDown.of() + .withEdgeLimit(pushdown.getEdgeLimit()) + .withOrderFields(pushdown.getOrderFields()); + this.pushdownFun = k -> simpleKeyPushDown.withFilter((IFilter) pushdown.getFilters().get(k)); + } else { + this.pushdownFun = k -> pushdown; } + } - @Override - public R next() { - return nextValue; + @Override + public boolean hasNext() { + while (iterator.hasNext()) { + K key = iterator.next(); + IStatePushDown pushdown = pushdownFun.apply(key); + nextValue = fetchFun.apply(key, pushdown); + if (nextValue == null) { + continue; + } + return true; } + return false; + } - @Override - public void close() { + @Override + public R next() { + return nextValue; + } - } + @Override + public void close() {} } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/iterator/OneDegreeGraphScanIterator.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/iterator/OneDegreeGraphScanIterator.java index 0582558fc..f13413dbf 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/iterator/OneDegreeGraphScanIterator.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/iterator/OneDegreeGraphScanIterator.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.model.graph.edge.IEdge; @@ -33,92 +34,99 @@ import org.apache.geaflow.state.pushdown.filter.inner.IGraphFilter; // TODO: Implement one degree graph scan iterator for graph proxy partitioned by label -public class OneDegreeGraphScanIterator implements - IOneDegreeGraphIterator { - - private final CloseableIterator> vertexIterator; - private final CloseableIterator>> edgeListIterator; - private final IType keyType; - private IGraphFilter filter; - private OneDegreeGraph nextValue; +public class OneDegreeGraphScanIterator implements IOneDegreeGraphIterator { - private List> candidateEdges; - private IVertex candidateVertex; + private final CloseableIterator> vertexIterator; + private final CloseableIterator>> edgeListIterator; + private final IType keyType; + private IGraphFilter filter; + private OneDegreeGraph nextValue; - public OneDegreeGraphScanIterator( - IType keyType, - CloseableIterator> vertexIterator, - CloseableIterator> edgeIterator, - IStatePushDown pushdown) { - this.keyType = keyType; - this.vertexIterator = vertexIterator; - this.edgeListIterator = new EdgeListScanIterator<>(edgeIterator); - this.filter = (IGraphFilter) pushdown.getFilter(); - } + private List> candidateEdges; + private IVertex candidateVertex; + public OneDegreeGraphScanIterator( + IType keyType, + CloseableIterator> vertexIterator, + CloseableIterator> edgeIterator, + IStatePushDown pushdown) { + this.keyType = keyType; + this.vertexIterator = vertexIterator; + this.edgeListIterator = new EdgeListScanIterator<>(edgeIterator); + this.filter = (IGraphFilter) pushdown.getFilter(); + } - private IVertex getVertexFromIterator() { - if (vertexIterator.hasNext()) { - return vertexIterator.next(); - } - return null; + private IVertex getVertexFromIterator() { + if (vertexIterator.hasNext()) { + return vertexIterator.next(); } + return null; + } - @Override - public boolean hasNext() { - do { - candidateVertex = candidateVertex == null ? getVertexFromIterator() : candidateVertex; - if (candidateEdges == null && edgeListIterator.hasNext()) { - candidateEdges = edgeListIterator.next(); - } else if (candidateEdges == null) { - candidateEdges = new ArrayList<>(); - } + @Override + public boolean hasNext() { + do { + candidateVertex = candidateVertex == null ? getVertexFromIterator() : candidateVertex; + if (candidateEdges == null && edgeListIterator.hasNext()) { + candidateEdges = edgeListIterator.next(); + } else if (candidateEdges == null) { + candidateEdges = new ArrayList<>(); + } - if (candidateEdges.size() > 0 && candidateVertex != null) { - K edgeKey = candidateEdges.get(0).getSrcId(); - K vertexKey = candidateVertex.getId(); - int res = keyType.compare(edgeKey, vertexKey); - if (res < 0) { - nextValue = new OneDegreeGraph<>(edgeKey, null, - IteratorWithClose.wrap(candidateEdges.iterator())); - candidateEdges = null; - } else if (res == 0) { - nextValue = new OneDegreeGraph<>(vertexKey, candidateVertex, - IteratorWithClose.wrap(candidateEdges.iterator())); - candidateVertex = null; - candidateEdges = null; - } else { - nextValue = new OneDegreeGraph<>(vertexKey, candidateVertex, - IteratorWithClose.wrap(Collections.emptyIterator())); - candidateVertex = null; - } - } else if (candidateEdges.size() > 0) { - nextValue = new OneDegreeGraph<>(candidateEdges.get(0).getSrcId(), null, - IteratorWithClose.wrap(candidateEdges.iterator())); - candidateEdges = null; - } else if (candidateVertex != null) { - nextValue = new OneDegreeGraph<>(candidateVertex.getId(), candidateVertex, - IteratorWithClose.wrap(Collections.emptyIterator())); - candidateVertex = null; - } else { - return false; - } + if (candidateEdges.size() > 0 && candidateVertex != null) { + K edgeKey = candidateEdges.get(0).getSrcId(); + K vertexKey = candidateVertex.getId(); + int res = keyType.compare(edgeKey, vertexKey); + if (res < 0) { + nextValue = + new OneDegreeGraph<>( + edgeKey, null, IteratorWithClose.wrap(candidateEdges.iterator())); + candidateEdges = null; + } else if (res == 0) { + nextValue = + new OneDegreeGraph<>( + vertexKey, candidateVertex, IteratorWithClose.wrap(candidateEdges.iterator())); + candidateVertex = null; + candidateEdges = null; + } else { + nextValue = + new OneDegreeGraph<>( + vertexKey, candidateVertex, IteratorWithClose.wrap(Collections.emptyIterator())); + candidateVertex = null; + } + } else if (candidateEdges.size() > 0) { + nextValue = + new OneDegreeGraph<>( + candidateEdges.get(0).getSrcId(), + null, + IteratorWithClose.wrap(candidateEdges.iterator())); + candidateEdges = null; + } else if (candidateVertex != null) { + nextValue = + new OneDegreeGraph<>( + candidateVertex.getId(), + candidateVertex, + IteratorWithClose.wrap(Collections.emptyIterator())); + candidateVertex = null; + } else { + return false; + } - if (!filter.filterOneDegreeGraph(nextValue)) { - continue; - } - return true; - } while (true); - } + if (!filter.filterOneDegreeGraph(nextValue)) { + continue; + } + return true; + } while (true); + } - @Override - public OneDegreeGraph next() { - return nextValue; - } + @Override + public OneDegreeGraph next() { + return nextValue; + } - @Override - public void close() { - this.vertexIterator.close(); - this.edgeListIterator.close(); - } + @Override + public void close() { + this.vertexIterator.close(); + this.edgeListIterator.close(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/iterator/VertexScanIterator.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/iterator/VertexScanIterator.java index 120625a2a..a24a3be09 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/iterator/VertexScanIterator.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-api/src/main/java/org/apache/geaflow/store/iterator/VertexScanIterator.java @@ -20,6 +20,7 @@ package org.apache.geaflow.store.iterator; import java.util.function.BiFunction; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.common.tuple.Tuple; import org.apache.geaflow.model.graph.vertex.IVertex; @@ -29,45 +30,46 @@ public class VertexScanIterator implements IVertexIterator { - private final CloseableIterator> iterator; - private final IGraphFilter filter; - private final BiFunction> vertexDecoder; - private boolean isClosed = false; - private IVertex nextValue; - - public VertexScanIterator(CloseableIterator> iterator, - IStatePushDown pushdown, - BiFunction> decoderFun) { - this.vertexDecoder = decoderFun; - this.iterator = iterator; - this.filter = (IGraphFilter) pushdown.getFilter(); - } - - @Override - public boolean hasNext() { - if (isClosed) { - return false; - } - while (iterator.hasNext()) { - Tuple pair = iterator.next(); - nextValue = vertexDecoder.apply(pair.f0, pair.f1); + private final CloseableIterator> iterator; + private final IGraphFilter filter; + private final BiFunction> vertexDecoder; + private boolean isClosed = false; + private IVertex nextValue; - if (!filter.filterVertex(nextValue)) { - continue; - } - return true; - } + public VertexScanIterator( + CloseableIterator> iterator, + IStatePushDown pushdown, + BiFunction> decoderFun) { + this.vertexDecoder = decoderFun; + this.iterator = iterator; + this.filter = (IGraphFilter) pushdown.getFilter(); + } - return false; + @Override + public boolean hasNext() { + if (isClosed) { + return false; } + while (iterator.hasNext()) { + Tuple pair = iterator.next(); + nextValue = vertexDecoder.apply(pair.f0, pair.f1); - @Override - public IVertex next() { - return nextValue; + if (!filter.filterVertex(nextValue)) { + continue; + } + return true; } - @Override - public void close() { - this.iterator.close(); - } + return false; + } + + @Override + public IVertex next() { + return nextValue; + } + + @Override + public void close() { + this.iterator.close(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-jdbc/src/main/java/org/apache/geaflow/store/jdbc/BaseJdbcStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-jdbc/src/main/java/org/apache/geaflow/store/jdbc/BaseJdbcStore.java index 963fa6bcc..73dead965 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-jdbc/src/main/java/org/apache/geaflow/store/jdbc/BaseJdbcStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-jdbc/src/main/java/org/apache/geaflow/store/jdbc/BaseJdbcStore.java @@ -19,9 +19,6 @@ package org.apache.geaflow.store.jdbc; -import com.google.common.base.Joiner; -import com.zaxxer.hikari.HikariConfig; -import com.zaxxer.hikari.HikariDataSource; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; @@ -29,6 +26,7 @@ import java.util.Collections; import java.util.Map; import java.util.Map.Entry; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.utils.GsonUtil; @@ -38,120 +36,127 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Joiner; +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; + public abstract class BaseJdbcStore implements IBaseStore { - private static final Logger LOGGER = LoggerFactory.getLogger(BaseJdbcStore.class); - private static final char SQL_SEPARATOR = ','; - - private static HikariDataSource ds; - - protected String tableName; - protected Configuration config; - protected int retries; - protected String pk; - private String updateFormat; - private String insertFormat; - private String deleteFormat; - private String queryFormat; - - public void init(StoreContext storeContext) { - this.tableName = storeContext.getName(); - this.config = storeContext.getConfig(); - this.retries = this.config.getInteger(JdbcConfigKeys.JDBC_MAX_RETRIES); - this.pk = this.config.getString(JdbcConfigKeys.JDBC_PK); - this.insertFormat = this.config.getString(JdbcConfigKeys.JDBC_INSERT_FORMAT); - this.updateFormat = this.config.getString(JdbcConfigKeys.JDBC_UPDATE_FORMAT); - this.deleteFormat = this.config.getString(JdbcConfigKeys.JDBC_DELETE_FORMAT); - this.queryFormat = this.config.getString(JdbcConfigKeys.JDBC_QUERY_FORMAT); - initConnectionPool(config); + private static final Logger LOGGER = LoggerFactory.getLogger(BaseJdbcStore.class); + private static final char SQL_SEPARATOR = ','; + + private static HikariDataSource ds; + + protected String tableName; + protected Configuration config; + protected int retries; + protected String pk; + private String updateFormat; + private String insertFormat; + private String deleteFormat; + private String queryFormat; + + public void init(StoreContext storeContext) { + this.tableName = storeContext.getName(); + this.config = storeContext.getConfig(); + this.retries = this.config.getInteger(JdbcConfigKeys.JDBC_MAX_RETRIES); + this.pk = this.config.getString(JdbcConfigKeys.JDBC_PK); + this.insertFormat = this.config.getString(JdbcConfigKeys.JDBC_INSERT_FORMAT); + this.updateFormat = this.config.getString(JdbcConfigKeys.JDBC_UPDATE_FORMAT); + this.deleteFormat = this.config.getString(JdbcConfigKeys.JDBC_DELETE_FORMAT); + this.queryFormat = this.config.getString(JdbcConfigKeys.JDBC_QUERY_FORMAT); + initConnectionPool(config); + } + + private synchronized void initConnectionPool(Configuration config) { + if (ds != null) { + return; } - - private synchronized void initConnectionPool(Configuration config) { - if (ds != null) { - return; - } - HikariConfig conf = new HikariConfig(); - conf.setJdbcUrl(config.getString(JdbcConfigKeys.JDBC_URL)); - conf.setUsername(config.getString(JdbcConfigKeys.JDBC_USER_NAME)); - conf.setPassword(config.getString(JdbcConfigKeys.JDBC_PASSWORD)); - conf.setDriverClassName(config.getString(JdbcConfigKeys.JDBC_DRIVER_CLASS)); - conf.setMaximumPoolSize(config.getInteger(JdbcConfigKeys.JDBC_CONNECTION_POOL_SIZE)); - String jsonConfig = this.config.getString(JdbcConfigKeys.JSON_CONFIG); - Map map = GsonUtil.parse(jsonConfig); - for (Entry entry : map.entrySet()) { - conf.addDataSourceProperty(entry.getKey(), entry.getValue()); - } - ds = new HikariDataSource(conf); + HikariConfig conf = new HikariConfig(); + conf.setJdbcUrl(config.getString(JdbcConfigKeys.JDBC_URL)); + conf.setUsername(config.getString(JdbcConfigKeys.JDBC_USER_NAME)); + conf.setPassword(config.getString(JdbcConfigKeys.JDBC_PASSWORD)); + conf.setDriverClassName(config.getString(JdbcConfigKeys.JDBC_DRIVER_CLASS)); + conf.setMaximumPoolSize(config.getInteger(JdbcConfigKeys.JDBC_CONNECTION_POOL_SIZE)); + String jsonConfig = this.config.getString(JdbcConfigKeys.JSON_CONFIG); + Map map = GsonUtil.parse(jsonConfig); + for (Entry entry : map.entrySet()) { + conf.addDataSourceProperty(entry.getKey(), entry.getValue()); } + ds = new HikariDataSource(conf); + } - protected boolean insert(String key, String[] columns, Object[] values) throws SQLException { - if (columns.length != values.length) { - throw new GeaflowRuntimeException("columns' size does not match values'"); - } + protected boolean insert(String key, String[] columns, Object[] values) throws SQLException { + if (columns.length != values.length) { + throw new GeaflowRuntimeException("columns' size does not match values'"); + } - String sql = String.format(insertFormat, this.tableName, this.pk, + String sql = + String.format( + insertFormat, + this.tableName, + this.pk, Joiner.on(SQL_SEPARATOR).join(columns), Joiner.on(SQL_SEPARATOR).join(Collections.nCopies(values.length, "?"))); - try (Connection conn = ds.getConnection(); - PreparedStatement ps = conn.prepareStatement(sql)) { - ps.setString(1, key); - for (int i = 0; i < values.length; i++) { - ps.setObject(i + 2, values[i]); - } - return 0 != RetryCommand.run(ps::executeUpdate, this.retries); - } - + try (Connection conn = ds.getConnection(); + PreparedStatement ps = conn.prepareStatement(sql)) { + ps.setString(1, key); + for (int i = 0; i < values.length; i++) { + ps.setObject(i + 2, values[i]); + } + return 0 != RetryCommand.run(ps::executeUpdate, this.retries); } + } - protected boolean update(String key, String[] columns, Object[] values) throws SQLException { - if (columns.length != values.length) { - throw new GeaflowRuntimeException("columns' size does not match values'"); - } + protected boolean update(String key, String[] columns, Object[] values) throws SQLException { + if (columns.length != values.length) { + throw new GeaflowRuntimeException("columns' size does not match values'"); + } - String sql = String.format(updateFormat, - this.tableName, Joiner.on("=?, ").join(columns), pk, key); + String sql = + String.format(updateFormat, this.tableName, Joiner.on("=?, ").join(columns), pk, key); - try (Connection conn = ds.getConnection(); - PreparedStatement ps = conn.prepareStatement(sql)) { - for (int i = 0; i < values.length; i++) { - ps.setObject(i + 1, values[i]); - } - return 0 != RetryCommand.run(ps::executeUpdate, this.retries); - } + try (Connection conn = ds.getConnection(); + PreparedStatement ps = conn.prepareStatement(sql)) { + for (int i = 0; i < values.length; i++) { + ps.setObject(i + 1, values[i]); + } + return 0 != RetryCommand.run(ps::executeUpdate, this.retries); } - - protected byte[] query(String key, String[] columns) throws SQLException { - String selectColumn = columns == null ? "*" : Joiner.on(SQL_SEPARATOR).join(columns); - String sql = String.format(queryFormat, selectColumn, this.tableName, pk, key); - try (Connection conn = ds.getConnection(); - PreparedStatement ps = conn.prepareStatement(sql)) { - try (ResultSet rs = RetryCommand.run(ps::executeQuery, this.retries)) { - if (rs.next()) { - return rs.getBytes(1); - } else { - return null; - } - } + } + + protected byte[] query(String key, String[] columns) throws SQLException { + String selectColumn = columns == null ? "*" : Joiner.on(SQL_SEPARATOR).join(columns); + String sql = String.format(queryFormat, selectColumn, this.tableName, pk, key); + try (Connection conn = ds.getConnection(); + PreparedStatement ps = conn.prepareStatement(sql)) { + try (ResultSet rs = RetryCommand.run(ps::executeQuery, this.retries)) { + if (rs.next()) { + return rs.getBytes(1); + } else { + return null; } + } } + } - protected boolean delete(String key) throws SQLException { - String sql = String.format(deleteFormat, this.tableName, pk, key); - try (Connection conn = ds.getConnection(); PreparedStatement ps = conn.prepareStatement( - sql)) { - return 0 != RetryCommand.run(ps::executeUpdate, this.retries); - } + protected boolean delete(String key) throws SQLException { + String sql = String.format(deleteFormat, this.tableName, pk, key); + try (Connection conn = ds.getConnection(); + PreparedStatement ps = conn.prepareStatement(sql)) { + return 0 != RetryCommand.run(ps::executeUpdate, this.retries); } + } - @Override - public void flush() { - LOGGER.info("flush"); - } + @Override + public void flush() { + LOGGER.info("flush"); + } - @Override - public synchronized void close() { - if (!ds.isClosed()) { - ds.close(); - } + @Override + public synchronized void close() { + if (!ds.isClosed()) { + ds.close(); } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-jdbc/src/main/java/org/apache/geaflow/store/jdbc/JdbcConfigKeys.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-jdbc/src/main/java/org/apache/geaflow/store/jdbc/JdbcConfigKeys.java index ad356d0a1..82bd4f27c 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-jdbc/src/main/java/org/apache/geaflow/store/jdbc/JdbcConfigKeys.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-jdbc/src/main/java/org/apache/geaflow/store/jdbc/JdbcConfigKeys.java @@ -24,63 +24,64 @@ public class JdbcConfigKeys { - public static final ConfigKey JSON_CONFIG = ConfigKeys - .key("geaflow.store.jdbc.connect.config.json") - .defaultValue("{}") - .description("geaflow jdbc json config"); + public static final ConfigKey JSON_CONFIG = + ConfigKeys.key("geaflow.store.jdbc.connect.config.json") + .defaultValue("{}") + .description("geaflow jdbc json config"); - public static final ConfigKey JDBC_USER_NAME = ConfigKeys - .key("geaflow.store.jdbc.user.name") - .defaultValue("") - .description("geaflow store jdbc user name"); + public static final ConfigKey JDBC_USER_NAME = + ConfigKeys.key("geaflow.store.jdbc.user.name") + .defaultValue("") + .description("geaflow store jdbc user name"); - public static final ConfigKey JDBC_DRIVER_CLASS = ConfigKeys - .key("geaflow.store.jdbc.driver.class") - .defaultValue("com.mysql.jdbc.Driver") - .description("geaflow store jdbc driver class name"); + public static final ConfigKey JDBC_DRIVER_CLASS = + ConfigKeys.key("geaflow.store.jdbc.driver.class") + .defaultValue("com.mysql.jdbc.Driver") + .description("geaflow store jdbc driver class name"); - public static final ConfigKey JDBC_URL = ConfigKeys - .key("geaflow.store.jdbc.url") - .noDefaultValue() - .description("geaflow store jdbc url"); + public static final ConfigKey JDBC_URL = + ConfigKeys.key("geaflow.store.jdbc.url") + .noDefaultValue() + .description("geaflow store jdbc url"); - public static final ConfigKey JDBC_PASSWORD = ConfigKeys - .key("geaflow.store.jdbc.password") - .defaultValue("") - .description("geaflow store jdbc password"); + public static final ConfigKey JDBC_PASSWORD = + ConfigKeys.key("geaflow.store.jdbc.password") + .defaultValue("") + .description("geaflow store jdbc password"); - public static final ConfigKey JDBC_MAX_RETRIES = ConfigKeys - .key("geaflow.store.jdbc.max.retries") - .defaultValue(3) - .description("geaflow store jdbc max retry"); + public static final ConfigKey JDBC_MAX_RETRIES = + ConfigKeys.key("geaflow.store.jdbc.max.retries") + .defaultValue(3) + .description("geaflow store jdbc max retry"); - public static final ConfigKey JDBC_CONNECTION_POOL_SIZE = ConfigKeys - .key("geaflow.store.jdbc.connection.pool.size") - .defaultValue(10) - .description("geaflow store jdbc connection pool size"); + public static final ConfigKey JDBC_CONNECTION_POOL_SIZE = + ConfigKeys.key("geaflow.store.jdbc.connection.pool.size") + .defaultValue(10) + .description("geaflow store jdbc connection pool size"); - public static final ConfigKey JDBC_PK = ConfigKeys - .key("geaflow.store.jdbc.pk") - .defaultValue("pk") - .description("geaflow store jdbc db pk"); + public static final ConfigKey JDBC_PK = + ConfigKeys.key("geaflow.store.jdbc.pk") + .defaultValue("pk") + .description("geaflow store jdbc db pk"); - public static final ConfigKey JDBC_INSERT_FORMAT = ConfigKeys - .key("geaflow.store.jdbc.insert.format") - .defaultValue("INSERT INTO %s(%s,%s, gmt_create, gmt_modified) VALUES (?, %s, now(), now())") - .description("geaflow store jdbc insert format"); + public static final ConfigKey JDBC_INSERT_FORMAT = + ConfigKeys.key("geaflow.store.jdbc.insert.format") + .defaultValue( + "INSERT INTO %s(%s,%s, gmt_create, gmt_modified) VALUES (?, %s, now(), now())") + .description("geaflow store jdbc insert format"); - public static final ConfigKey JDBC_UPDATE_FORMAT = ConfigKeys - .key("geaflow.store.jdbc.update.format") - .defaultValue("UPDATE %s SET %s=? , gmt_modified=now() WHERE %s='%s'") - .description("geaflow store jdbc update format"); + public static final ConfigKey JDBC_UPDATE_FORMAT = + ConfigKeys.key("geaflow.store.jdbc.update.format") + .defaultValue("UPDATE %s SET %s=? , gmt_modified=now() WHERE %s='%s'") + .description("geaflow store jdbc update format"); - public static final ConfigKey JDBC_DELETE_FORMAT = ConfigKeys - .key("geaflow.store.jdbc.delete.format") - .defaultValue("DELETE FROM %s WHERE %s='%s'") - .description("geaflow store jdbc delete format"); + public static final ConfigKey JDBC_DELETE_FORMAT = + ConfigKeys.key("geaflow.store.jdbc.delete.format") + .defaultValue("DELETE FROM %s WHERE %s='%s'") + .description("geaflow store jdbc delete format"); - public static final ConfigKey JDBC_QUERY_FORMAT = ConfigKeys - .key("geaflow.store.jdbc.query.format") - .defaultValue("SELECT %s FROM %s WHERE %s='%s'") - .description("geaflow store jdbc query format"); + public static final ConfigKey JDBC_QUERY_FORMAT = + ConfigKeys.key("geaflow.store.jdbc.query.format") + .defaultValue("SELECT %s FROM %s WHERE %s='%s'") + .description("geaflow store jdbc query format"); } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-jdbc/src/main/java/org/apache/geaflow/store/jdbc/JdbcKVStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-jdbc/src/main/java/org/apache/geaflow/store/jdbc/JdbcKVStore.java index de1907a25..605e1b7e9 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-jdbc/src/main/java/org/apache/geaflow/store/jdbc/JdbcKVStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-jdbc/src/main/java/org/apache/geaflow/store/jdbc/JdbcKVStore.java @@ -19,8 +19,8 @@ package org.apache.geaflow.store.jdbc; -import com.google.common.base.Preconditions; import java.sql.SQLException; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.utils.RetryCommand; import org.apache.geaflow.state.serializer.IKVSerializer; @@ -29,69 +29,75 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; + public class JdbcKVStore extends BaseJdbcStore implements IKVStore { - private static final Logger LOGGER = LoggerFactory.getLogger(JdbcKVStore.class); - private String[] columns = {"value"}; + private static final Logger LOGGER = LoggerFactory.getLogger(JdbcKVStore.class); + private String[] columns = {"value"}; - private IKVSerializer serializer; + private IKVSerializer serializer; - @Override - public void init(StoreContext storeContext) { - super.init(storeContext); - this.serializer = (IKVSerializer) Preconditions.checkNotNull(storeContext.getKeySerializer(), - "keySerializer must be set"); - } + @Override + public void init(StoreContext storeContext) { + super.init(storeContext); + this.serializer = + (IKVSerializer) + Preconditions.checkNotNull( + storeContext.getKeySerializer(), "keySerializer must be set"); + } - private java.lang.String getFromKey(K key) { - if (key.getClass() == byte[].class) { - return new String((byte[]) key); - } - return key.toString(); + private java.lang.String getFromKey(K key) { + if (key.getClass() == byte[].class) { + return new String((byte[]) key); } + return key.toString(); + } - @Override - public V get(K key) { - try { - byte[] res = query(getFromKey(key), columns); - if (res != null) { - return serializer.deserializeValue(res); - } else { - return null; - } - } catch (SQLException e) { - throw new GeaflowRuntimeException("get fail", e); - } + @Override + public V get(K key) { + try { + byte[] res = query(getFromKey(key), columns); + if (res != null) { + return serializer.deserializeValue(res); + } else { + return null; + } + } catch (SQLException e) { + throw new GeaflowRuntimeException("get fail", e); } + } - @Override - public void put(K key, V value) { - byte[] valueArray = serializer.serializeValue(value); - RetryCommand.run(() -> { - try { - String fromKey = getFromKey(key); - if (!update(fromKey, columns, new Object[]{valueArray})) { - LOGGER.info("key: {}, insert fail, try insert", key); - try { - insert(fromKey, columns, new Object[]{valueArray}); - } catch (Exception e) { - LOGGER.info("key: {}, insert fail", key); - throw new GeaflowRuntimeException("put fail"); - } - } - } catch (SQLException e) { - throw new GeaflowRuntimeException("put fail", e); + @Override + public void put(K key, V value) { + byte[] valueArray = serializer.serializeValue(value); + RetryCommand.run( + () -> { + try { + String fromKey = getFromKey(key); + if (!update(fromKey, columns, new Object[] {valueArray})) { + LOGGER.info("key: {}, insert fail, try insert", key); + try { + insert(fromKey, columns, new Object[] {valueArray}); + } catch (Exception e) { + LOGGER.info("key: {}, insert fail", key); + throw new GeaflowRuntimeException("put fail"); + } } - return true; - }, retries); - } + } catch (SQLException e) { + throw new GeaflowRuntimeException("put fail", e); + } + return true; + }, + retries); + } - @Override - public void remove(K key) { - try { - delete(getFromKey(key)); - } catch (SQLException e) { - throw new GeaflowRuntimeException("remove fail", e); - } + @Override + public void remove(K key) { + try { + delete(getFromKey(key)); + } catch (SQLException e) { + throw new GeaflowRuntimeException("remove fail", e); } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-jdbc/src/main/java/org/apache/geaflow/store/jdbc/JdbcStoreBuilder.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-jdbc/src/main/java/org/apache/geaflow/store/jdbc/JdbcStoreBuilder.java index 7acc5f49d..29fe7a4f7 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-jdbc/src/main/java/org/apache/geaflow/store/jdbc/JdbcStoreBuilder.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-jdbc/src/main/java/org/apache/geaflow/store/jdbc/JdbcStoreBuilder.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -32,38 +33,38 @@ public class JdbcStoreBuilder implements IStoreBuilder { - private static final StoreDesc STORE_DESC = new JdbcStoreDesc(); + private static final StoreDesc STORE_DESC = new JdbcStoreDesc(); - @Override - public IBaseStore getStore(DataModel type, Configuration config) { - switch (type) { - case KV: - return new JdbcKVStore<>(); - default: - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError("not support " + type)); - } + @Override + public IBaseStore getStore(DataModel type, Configuration config) { + switch (type) { + case KV: + return new JdbcKVStore<>(); + default: + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError("not support " + type)); } + } - @Override - public StoreDesc getStoreDesc() { - return STORE_DESC; - } + @Override + public StoreDesc getStoreDesc() { + return STORE_DESC; + } - @Override - public List supportedDataModel() { - return Arrays.asList(DataModel.KV); - } + @Override + public List supportedDataModel() { + return Arrays.asList(DataModel.KV); + } - public static class JdbcStoreDesc implements StoreDesc { + public static class JdbcStoreDesc implements StoreDesc { - @Override - public boolean isLocalStore() { - return false; - } + @Override + public boolean isLocalStore() { + return false; + } - @Override - public String name() { - return StoreType.JDBC.name(); - } + @Override + public String name() { + return StoreType.JDBC.name(); } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-jdbc/src/test/java/org/apache/geaflow/store/jdbc/JdbcStoreBuilderTest.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-jdbc/src/test/java/org/apache/geaflow/store/jdbc/JdbcStoreBuilderTest.java index 064a02e96..b6cbd922a 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-jdbc/src/test/java/org/apache/geaflow/store/jdbc/JdbcStoreBuilderTest.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-jdbc/src/test/java/org/apache/geaflow/store/jdbc/JdbcStoreBuilderTest.java @@ -23,6 +23,7 @@ import java.sql.Connection; import java.sql.DriverManager; import java.sql.Statement; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.state.DataModel; @@ -37,59 +38,60 @@ public class JdbcStoreBuilderTest { - private void prepareSqlite() throws Exception { - String createSqliteTable = "CREATE TABLE IF NOT EXISTS `store_test`(\n" + private void prepareSqlite() throws Exception { + String createSqliteTable = + "CREATE TABLE IF NOT EXISTS `store_test`(\n" + " `id` bigint UNSIGNED AUTO_INCREMENT,\n" + " `pk` VARCHAR(256) NOT NULL,\n" + " `value` LONGBLOB,\n" + " PRIMARY KEY ( `id`)\n" + ")"; - Class.forName("org.sqlite.JDBC"); - Connection c = DriverManager.getConnection("jdbc:sqlite:/tmp/test.db"); - Statement stmt = c.createStatement(); - stmt.execute(createSqliteTable); - stmt.close(); - c.close(); - } + Class.forName("org.sqlite.JDBC"); + Connection c = DriverManager.getConnection("jdbc:sqlite:/tmp/test.db"); + Statement stmt = c.createStatement(); + stmt.execute(createSqliteTable); + stmt.close(); + c.close(); + } - @Test - public void testSqliteKV() throws Exception { - prepareSqlite(); - Configuration configuration = new Configuration(); - configuration.put(JdbcConfigKeys.JDBC_DRIVER_CLASS.getKey(), "org.sqlite.JDBC"); - configuration.put(JdbcConfigKeys.JDBC_URL, "jdbc:sqlite:/tmp/test.db"); - configuration.put(JdbcConfigKeys.JDBC_INSERT_FORMAT, "INSERT INTO %s(%s,%s) VALUES (?, %s)"); - configuration.put(JdbcConfigKeys.JDBC_UPDATE_FORMAT, "UPDATE %s SET %s=? WHERE %s='%s'"); + @Test + public void testSqliteKV() throws Exception { + prepareSqlite(); + Configuration configuration = new Configuration(); + configuration.put(JdbcConfigKeys.JDBC_DRIVER_CLASS.getKey(), "org.sqlite.JDBC"); + configuration.put(JdbcConfigKeys.JDBC_URL, "jdbc:sqlite:/tmp/test.db"); + configuration.put(JdbcConfigKeys.JDBC_INSERT_FORMAT, "INSERT INTO %s(%s,%s) VALUES (?, %s)"); + configuration.put(JdbcConfigKeys.JDBC_UPDATE_FORMAT, "UPDATE %s SET %s=? WHERE %s='%s'"); - IStoreBuilder builder = StoreBuilderFactory.build(StoreType.JDBC.name()); - IKVStore kvStore = - (IKVStore) builder.getStore(DataModel.KV, configuration); - StoreContext storeContext = new StoreContext("store_test").withConfig(configuration); - storeContext.withKeySerializer(new DefaultKVSerializer<>(String.class, String.class)); - Assert.assertEquals(kvStore.getClass().getSimpleName(), JdbcKVStore.class.getSimpleName()); + IStoreBuilder builder = StoreBuilderFactory.build(StoreType.JDBC.name()); + IKVStore kvStore = + (IKVStore) builder.getStore(DataModel.KV, configuration); + StoreContext storeContext = new StoreContext("store_test").withConfig(configuration); + storeContext.withKeySerializer(new DefaultKVSerializer<>(String.class, String.class)); + Assert.assertEquals(kvStore.getClass().getSimpleName(), JdbcKVStore.class.getSimpleName()); - kvStore.init(storeContext); + kvStore.init(storeContext); - innerTestKV(kvStore); - FileUtils.deleteQuietly(new File("/tmp/test.db")); - kvStore.close(); - } + innerTestKV(kvStore); + FileUtils.deleteQuietly(new File("/tmp/test.db")); + kvStore.close(); + } - private static void innerTestKV(IKVStore kvStore) { - String key = "key1"; - kvStore.put(key, "foo"); - Assert.assertEquals(kvStore.get(key), "foo"); + private static void innerTestKV(IKVStore kvStore) { + String key = "key1"; + kvStore.put(key, "foo"); + Assert.assertEquals(kvStore.get(key), "foo"); - Assert.assertEquals(kvStore.get(key), "foo"); + Assert.assertEquals(kvStore.get(key), "foo"); - kvStore.remove(key); - Assert.assertNull(kvStore.get(key)); + kvStore.remove(key); + Assert.assertNull(kvStore.get(key)); - for (int i = 0; i < 5; i++) { - key = "key" + i; - kvStore.put(key, "foo"); - Assert.assertEquals(kvStore.get(key), "foo"); - } + for (int i = 0; i < 5; i++) { + key = "key" + i; + kvStore.put(key, "foo"); + Assert.assertEquals(kvStore.get(key), "foo"); } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/BaseStaticGraphMemoryStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/BaseStaticGraphMemoryStore.java index 98dbf467e..38bff2f68 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/BaseStaticGraphMemoryStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/BaseStaticGraphMemoryStore.java @@ -19,13 +19,13 @@ package org.apache.geaflow.store.memory; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.function.Function; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.common.tuple.Tuple; import org.apache.geaflow.model.graph.edge.IEdge; @@ -47,153 +47,167 @@ import org.apache.geaflow.store.memory.iterator.MemoryEdgeScanPushDownIterator; import org.apache.geaflow.store.memory.iterator.MemoryVertexScanIterator; -public abstract class BaseStaticGraphMemoryStore extends BaseGraphStore implements - IStaticGraphStore { - - @Override - public void init(StoreContext context) { - super.init(context); - } - - protected abstract IVertex getVertex(K sid); - - @Override - public IVertex getVertex(K sid, IStatePushDown pushdown) { - IVertex vertex = getVertex(sid); - return vertex != null && ((IGraphFilter) pushdown.getFilter()).filterVertex(vertex) ? vertex : null; - } - - protected abstract List> getEdges(K sid); - - protected List> pushdownEdges(List> list, IStatePushDown pushdown) { - if (pushdown.getOrderFields() != null) { - list.sort(EdgeAtom.getComparator(pushdown.getOrderFields())); - } - List> res = new ArrayList<>(list.size()); - Iterator> it = list.iterator(); - IGraphFilter filter = GraphFilter.of(pushdown.getFilter(), pushdown.getEdgeLimit()); - while (it.hasNext() && !filter.dropAllRemaining()) { - IEdge edge = it.next(); - if (filter.filterEdge(edge)) { - res.add(edge); - } - } - - return res; - } - - @Override - public List> getEdges(K sid, IStatePushDown pushdown) { - return pushdownEdges(getEdges(sid), pushdown); - } - - @Override - public OneDegreeGraph getOneDegreeGraph(K sid, IStatePushDown pushdown) { - IVertex vertex = getVertex(sid, pushdown); - List> edges = getEdges(sid, pushdown); - OneDegreeGraph oneDegreeGraph = new OneDegreeGraph<>(sid, vertex, - IteratorWithClose.wrap(edges.iterator())); - if (((IGraphFilter) pushdown.getFilter()).filterOneDegreeGraph(oneDegreeGraph)) { - return oneDegreeGraph; - } - return null; - } - - @Override - public CloseableIterator> getVertexIterator(IStatePushDown pushdown) { - boolean emptyFilter = pushdown.getFilter().getFilterType() == FilterType.EMPTY; - return emptyFilter ? getVertexIterator() - : new MemoryVertexScanIterator<>(getVertexIterator(), (IGraphFilter) pushdown.getFilter()); - } - - @Override - public CloseableIterator> getVertexIterator(List list, IStatePushDown pushdown) { - return new KeysIterator<>(list, this::getVertex, pushdown); - } - - @Override - public CloseableIterator> getEdgeIterator(IStatePushDown pushdown) { - CloseableIterator>> it = new MemoryEdgeScanPushDownIterator<>(getEdgesIterator(), pushdown); - return new IteratorWithFlatFn<>(it, List::iterator); - } - - @Override - public CloseableIterator> getEdgeIterator(List list, IStatePushDown pushdown) { - CloseableIterator>> it = new KeysIterator<>(list, this::getEdges, pushdown); - return new IteratorWithFlatFn<>(it, List::iterator); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(IStatePushDown pushdown) { - return new KeysIterator<>(Lists.newArrayList(getKeyIterator()), this::getOneDegreeGraph, pushdown); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(List keys, IStatePushDown pushdown) { - return new KeysIterator<>(keys, this::getOneDegreeGraph, pushdown); - } - - @Override - public CloseableIterator> getEdgeProjectIterator(IStatePushDown, R> pushdown) { - return new IteratorWithFn<>(getEdgeIterator(pushdown), - edge -> Tuple.of(edge.getSrcId(), pushdown.getProjector().project(edge))); - } - - @Override - public CloseableIterator> getEdgeProjectIterator(List keys, - IStatePushDown, R> pushdown) { - return new IteratorWithFn<>(getEdgeIterator(keys, pushdown), - edge -> Tuple.of(edge.getSrcId(), pushdown.getProjector().project(edge))); - } - - @Override - public Map getAggResult(IStatePushDown pushdown) { - Map res = new HashMap<>(); - Iterator keyIt = getKeyIterator(); - while (keyIt.hasNext()) { - K key = keyIt.next(); - - List> list = getEdges(key, pushdown); - res.put(key, (long) list.size()); - } - return res; - } - - @Override - public Map getAggResult(List keys, IStatePushDown pushdown) { - Map res = new HashMap<>(keys.size()); - - Function pushdownFun; - if (pushdown.getFilters() == null) { - pushdownFun = key -> pushdown; - } else { - pushdownFun = key -> StatePushDown.of().withFilter((IGraphFilter) pushdown.getFilters().get(key)); - } - - for (K key : keys) { - List> list = getEdges(key, pushdownFun.apply(key)); - res.put(key, (long) list.size()); - } - return res; - } - - protected abstract CloseableIterator>> getEdgesIterator(); - - protected abstract CloseableIterator> getVertexIterator(); - - @Override - public CloseableIterator vertexIDIterator() { - return new IteratorWithFn<>(getVertexIterator(), IVertex::getId); - } - - @Override - public CloseableIterator vertexIDIterator(IStatePushDown pushDown) { - if (pushDown.getFilter() == null) { - return vertexIDIterator(); - } else { - return new IteratorWithFn<>(getVertexIterator(pushDown), IVertex::getId); - } - } +import com.google.common.collect.Lists; - protected abstract CloseableIterator getKeyIterator(); +public abstract class BaseStaticGraphMemoryStore extends BaseGraphStore + implements IStaticGraphStore { + + @Override + public void init(StoreContext context) { + super.init(context); + } + + protected abstract IVertex getVertex(K sid); + + @Override + public IVertex getVertex(K sid, IStatePushDown pushdown) { + IVertex vertex = getVertex(sid); + return vertex != null && ((IGraphFilter) pushdown.getFilter()).filterVertex(vertex) + ? vertex + : null; + } + + protected abstract List> getEdges(K sid); + + protected List> pushdownEdges(List> list, IStatePushDown pushdown) { + if (pushdown.getOrderFields() != null) { + list.sort(EdgeAtom.getComparator(pushdown.getOrderFields())); + } + List> res = new ArrayList<>(list.size()); + Iterator> it = list.iterator(); + IGraphFilter filter = GraphFilter.of(pushdown.getFilter(), pushdown.getEdgeLimit()); + while (it.hasNext() && !filter.dropAllRemaining()) { + IEdge edge = it.next(); + if (filter.filterEdge(edge)) { + res.add(edge); + } + } + + return res; + } + + @Override + public List> getEdges(K sid, IStatePushDown pushdown) { + return pushdownEdges(getEdges(sid), pushdown); + } + + @Override + public OneDegreeGraph getOneDegreeGraph(K sid, IStatePushDown pushdown) { + IVertex vertex = getVertex(sid, pushdown); + List> edges = getEdges(sid, pushdown); + OneDegreeGraph oneDegreeGraph = + new OneDegreeGraph<>(sid, vertex, IteratorWithClose.wrap(edges.iterator())); + if (((IGraphFilter) pushdown.getFilter()).filterOneDegreeGraph(oneDegreeGraph)) { + return oneDegreeGraph; + } + return null; + } + + @Override + public CloseableIterator> getVertexIterator(IStatePushDown pushdown) { + boolean emptyFilter = pushdown.getFilter().getFilterType() == FilterType.EMPTY; + return emptyFilter + ? getVertexIterator() + : new MemoryVertexScanIterator<>(getVertexIterator(), (IGraphFilter) pushdown.getFilter()); + } + + @Override + public CloseableIterator> getVertexIterator( + List list, IStatePushDown pushdown) { + return new KeysIterator<>(list, this::getVertex, pushdown); + } + + @Override + public CloseableIterator> getEdgeIterator(IStatePushDown pushdown) { + CloseableIterator>> it = + new MemoryEdgeScanPushDownIterator<>(getEdgesIterator(), pushdown); + return new IteratorWithFlatFn<>(it, List::iterator); + } + + @Override + public CloseableIterator> getEdgeIterator(List list, IStatePushDown pushdown) { + CloseableIterator>> it = new KeysIterator<>(list, this::getEdges, pushdown); + return new IteratorWithFlatFn<>(it, List::iterator); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + IStatePushDown pushdown) { + return new KeysIterator<>( + Lists.newArrayList(getKeyIterator()), this::getOneDegreeGraph, pushdown); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + List keys, IStatePushDown pushdown) { + return new KeysIterator<>(keys, this::getOneDegreeGraph, pushdown); + } + + @Override + public CloseableIterator> getEdgeProjectIterator( + IStatePushDown, R> pushdown) { + return new IteratorWithFn<>( + getEdgeIterator(pushdown), + edge -> Tuple.of(edge.getSrcId(), pushdown.getProjector().project(edge))); + } + + @Override + public CloseableIterator> getEdgeProjectIterator( + List keys, IStatePushDown, R> pushdown) { + return new IteratorWithFn<>( + getEdgeIterator(keys, pushdown), + edge -> Tuple.of(edge.getSrcId(), pushdown.getProjector().project(edge))); + } + + @Override + public Map getAggResult(IStatePushDown pushdown) { + Map res = new HashMap<>(); + Iterator keyIt = getKeyIterator(); + while (keyIt.hasNext()) { + K key = keyIt.next(); + + List> list = getEdges(key, pushdown); + res.put(key, (long) list.size()); + } + return res; + } + + @Override + public Map getAggResult(List keys, IStatePushDown pushdown) { + Map res = new HashMap<>(keys.size()); + + Function pushdownFun; + if (pushdown.getFilters() == null) { + pushdownFun = key -> pushdown; + } else { + pushdownFun = + key -> StatePushDown.of().withFilter((IGraphFilter) pushdown.getFilters().get(key)); + } + + for (K key : keys) { + List> list = getEdges(key, pushdownFun.apply(key)); + res.put(key, (long) list.size()); + } + return res; + } + + protected abstract CloseableIterator>> getEdgesIterator(); + + protected abstract CloseableIterator> getVertexIterator(); + + @Override + public CloseableIterator vertexIDIterator() { + return new IteratorWithFn<>(getVertexIterator(), IVertex::getId); + } + + @Override + public CloseableIterator vertexIDIterator(IStatePushDown pushDown) { + if (pushDown.getFilter() == null) { + return vertexIDIterator(); + } else { + return new IteratorWithFn<>(getVertexIterator(pushDown), IVertex::getId); + } + } + + protected abstract CloseableIterator getKeyIterator(); } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/DynamicGraphMemoryStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/DynamicGraphMemoryStore.java index 9e2b3bbf5..b79a194ab 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/DynamicGraphMemoryStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/DynamicGraphMemoryStore.java @@ -19,7 +19,6 @@ package org.apache.geaflow.store.memory; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; @@ -32,6 +31,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.function.BiFunction; import java.util.stream.Collectors; + import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.iterator.CloseableIterator; @@ -52,260 +52,250 @@ import org.apache.geaflow.store.memory.iterator.MemoryEdgeScanPushDownIterator; import org.apache.geaflow.store.memory.iterator.MemoryVertexScanIterator; -public class DynamicGraphMemoryStore extends BaseGraphStore implements IDynamicGraphStore { - - private Map>>> vertexId2Edges = new ConcurrentHashMap<>(); - private Map>> vertexId2Vertex = new ConcurrentHashMap<>(); - - public DynamicGraphMemoryStore() { - } - - @Override - public void init(StoreContext context) { - super.init(context); - } +import com.google.common.collect.Lists; - @Override - public void archive(long checkpointId) { +public class DynamicGraphMemoryStore extends BaseGraphStore + implements IDynamicGraphStore { - } + private Map>>> vertexId2Edges = new ConcurrentHashMap<>(); + private Map>> vertexId2Vertex = new ConcurrentHashMap<>(); - @Override - public void recovery(long checkpointId) { - - } + public DynamicGraphMemoryStore() {} - @Override - public long recoveryLatest() { - return 0; - } + @Override + public void init(StoreContext context) { + super.init(context); + } - @Override - public void compact() { + @Override + public void archive(long checkpointId) {} - } + @Override + public void recovery(long checkpointId) {} - @Override - public void flush() { + @Override + public long recoveryLatest() { + return 0; + } - } + @Override + public void compact() {} - @Override - public void close() { - this.vertexId2Edges.clear(); - this.vertexId2Vertex.clear(); - } + @Override + public void flush() {} - @Override - public void drop() { - this.vertexId2Edges = null; - this.vertexId2Vertex = null; - } + @Override + public void close() { + this.vertexId2Edges.clear(); + this.vertexId2Vertex.clear(); + } - public Map>> getVertexId2Edges(long version) { - Map>> map = new HashMap<>(vertexId2Edges.size()); - for (Entry>>> entry : vertexId2Edges.entrySet()) { - if (entry.getValue().containsKey(version)) { - map.put(entry.getKey(), entry.getValue().get(version)); - } + @Override + public void drop() { + this.vertexId2Edges = null; + this.vertexId2Vertex = null; + } - } - return map; + public Map>> getVertexId2Edges(long version) { + Map>> map = new HashMap<>(vertexId2Edges.size()); + for (Entry>>> entry : vertexId2Edges.entrySet()) { + if (entry.getValue().containsKey(version)) { + map.put(entry.getKey(), entry.getValue().get(version)); + } } - - public Map> getVertexId2Vertex(long version) { - Map> map = new HashMap<>(vertexId2Vertex.size()); - for (Entry>> entry : vertexId2Vertex.entrySet()) { - if (entry.getValue().containsKey(version)) { - map.put(entry.getKey(), entry.getValue().get(version)); - } - } - return map; + return map; + } + + public Map> getVertexId2Vertex(long version) { + Map> map = new HashMap<>(vertexId2Vertex.size()); + for (Entry>> entry : vertexId2Vertex.entrySet()) { + if (entry.getValue().containsKey(version)) { + map.put(entry.getKey(), entry.getValue().get(version)); + } } - - private Iterator getKeyIterator() { - Set keys = new HashSet<>(vertexId2Vertex.keySet()); - keys.addAll(vertexId2Edges.keySet()); - return keys.iterator(); + return map; + } + + private Iterator getKeyIterator() { + Set keys = new HashSet<>(vertexId2Vertex.keySet()); + keys.addAll(vertexId2Edges.keySet()); + return keys.iterator(); + } + + @Override + public void addEdge(long version, IEdge edge) { + K srcId = edge.getSrcId(); + Map>> map = vertexId2Edges.computeIfAbsent(srcId, k -> new HashMap<>()); + List> edges = map.computeIfAbsent(version, k -> new ArrayList<>()); + edges.add(edge); + } + + @Override + public void addVertex(long version, IVertex vertex) { + K id = vertex.getId(); + Map> map = vertexId2Vertex.computeIfAbsent(id, k -> new HashMap<>()); + map.put(version, vertex); + } + + @Override + public IVertex getVertex(long version, K sid, IStatePushDown pushdown) { + Map> map = vertexId2Vertex.get(sid); + IVertex vertex = null; + if (map != null) { + vertex = map.get(version); } - - @Override - public void addEdge(long version, IEdge edge) { - K srcId = edge.getSrcId(); - Map>> map = vertexId2Edges.computeIfAbsent(srcId, - k -> new HashMap<>()); - List> edges = map.computeIfAbsent(version, k -> new ArrayList<>()); - edges.add(edge); + if (vertex != null && ((IGraphFilter) pushdown.getFilter()).filterVertex(vertex)) { + return vertex; } - - @Override - public void addVertex(long version, IVertex vertex) { - K id = vertex.getId(); - Map> map = vertexId2Vertex.computeIfAbsent(id, k -> new HashMap<>()); - map.put(version, vertex); + return null; + } + + @Override + public List> getEdges(long version, K sid, IStatePushDown pushdown) { + Map>> map = vertexId2Edges.get(sid); + List> list = new ArrayList<>(); + if (map != null) { + list = map.getOrDefault(version, new ArrayList<>()); } - - @Override - public IVertex getVertex(long version, K sid, IStatePushDown pushdown) { - Map> map = vertexId2Vertex.get(sid); - IVertex vertex = null; - if (map != null) { - vertex = map.get(version); - } - if (vertex != null && ((IGraphFilter) pushdown.getFilter()).filterVertex(vertex)) { - return vertex; - } - return null; + IGraphFilter filter = (IGraphFilter) pushdown.getFilter(); + return list.stream().filter(filter::filterEdge).collect(Collectors.toList()); + } + + @Override + public OneDegreeGraph getOneDegreeGraph(long version, K sid, IStatePushDown pushdown) { + IVertex vertex = getVertex(version, sid, pushdown); + List> edgeList = getEdges(version, sid, pushdown); + OneDegreeGraph oneDegreeGraph = + new OneDegreeGraph<>(sid, vertex, IteratorWithClose.wrap(edgeList.iterator())); + if (((IGraphFilter) pushdown.getFilter()).filterOneDegreeGraph(oneDegreeGraph)) { + return oneDegreeGraph; + } else { + return null; } - - @Override - public List> getEdges(long version, K sid, IStatePushDown pushdown) { - Map>> map = vertexId2Edges.get(sid); - List> list = new ArrayList<>(); - if (map != null) { - list = map.getOrDefault(version, new ArrayList<>()); - } - IGraphFilter filter = (IGraphFilter) pushdown.getFilter(); - return list.stream().filter(filter::filterEdge).collect(Collectors.toList()); + } + + @Override + public CloseableIterator vertexIDIterator() { + return IteratorWithClose.wrap(vertexId2Vertex.keySet().iterator()); + } + + @Override + public CloseableIterator vertexIDIterator(long version, IStatePushDown pushdown) { + if (pushdown.getFilter() == null) { + return new IteratorWithFilterThenFn<>( + vertexId2Vertex.entrySet().iterator(), + entry -> entry.getValue().containsKey(version), + Entry::getKey); + } else { + return new IteratorWithFn<>(getVertexIterator(version, pushdown), IVertex::getId); } - - @Override - public OneDegreeGraph getOneDegreeGraph(long version, K sid, - IStatePushDown pushdown) { - IVertex vertex = getVertex(version, sid, pushdown); - List> edgeList = getEdges(version, sid, pushdown); - OneDegreeGraph oneDegreeGraph = new OneDegreeGraph<>(sid, vertex, - IteratorWithClose.wrap(edgeList.iterator())); - if (((IGraphFilter) pushdown.getFilter()).filterOneDegreeGraph(oneDegreeGraph)) { - return oneDegreeGraph; - } else { - return null; - } - } - - @Override - public CloseableIterator vertexIDIterator() { - return IteratorWithClose.wrap(vertexId2Vertex.keySet().iterator()); - } - - @Override - public CloseableIterator vertexIDIterator(long version, IStatePushDown pushdown) { - if (pushdown.getFilter() == null) { - return new IteratorWithFilterThenFn<>(vertexId2Vertex.entrySet().iterator(), - entry -> entry.getValue().containsKey(version), Entry::getKey); - } else { - return new IteratorWithFn<>(getVertexIterator(version, pushdown), IVertex::getId); - } - } - - @Override - public CloseableIterator> getVertexIterator(long version, - IStatePushDown pushdown) { - return new MemoryVertexScanIterator<>(getVertexId2Vertex(version).values().iterator(), - (IGraphFilter) pushdown.getFilter()); - } - - @Override - public CloseableIterator> getVertexIterator(long version, List keys, - IStatePushDown pushdown) { - BiFunction> fetchFun = (k, p) -> getVertex(version, k, p); - return new KeysIterator<>(keys, fetchFun, pushdown); - } - - @Override - public CloseableIterator> getEdgeIterator(long version, IStatePushDown pushdown) { - Iterator>> it = new MemoryEdgeScanPushDownIterator<>( + } + + @Override + public CloseableIterator> getVertexIterator( + long version, IStatePushDown pushdown) { + return new MemoryVertexScanIterator<>( + getVertexId2Vertex(version).values().iterator(), (IGraphFilter) pushdown.getFilter()); + } + + @Override + public CloseableIterator> getVertexIterator( + long version, List keys, IStatePushDown pushdown) { + BiFunction> fetchFun = (k, p) -> getVertex(version, k, p); + return new KeysIterator<>(keys, fetchFun, pushdown); + } + + @Override + public CloseableIterator> getEdgeIterator(long version, IStatePushDown pushdown) { + Iterator>> it = + new MemoryEdgeScanPushDownIterator<>( getVertexId2Edges(version).values().iterator(), pushdown); - return new IteratorWithFlatFn<>(it, List::iterator); + return new IteratorWithFlatFn<>(it, List::iterator); + } + + @Override + public CloseableIterator> getEdgeIterator( + long version, List keys, IStatePushDown pushdown) { + BiFunction>> fetchFun = (k, p) -> getEdges(version, k, p); + Iterator>> it = new KeysIterator<>(keys, fetchFun, pushdown); + return new IteratorWithFlatFn<>(it, List::iterator); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + long version, IStatePushDown pushdown) { + BiFunction> fetchFun = + (k, p) -> getOneDegreeGraph(version, k, p); + + return new KeysIterator<>(Lists.newArrayList(getKeyIterator()), fetchFun, pushdown); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + long version, List keys, IStatePushDown pushdown) { + BiFunction> fetchFun = + (k, p) -> getOneDegreeGraph(version, k, p); + return new KeysIterator<>(keys, fetchFun, pushdown); + } + + @Override + public List getAllVersions(K id, DataType dataType) { + if (dataType == DataType.V || dataType == DataType.V_TOPO) { + Map> map = vertexId2Vertex.get(id); + if (map != null) { + return new ArrayList<>(map.keySet()); + } else { + return new ArrayList<>(); + } } - - @Override - public CloseableIterator> getEdgeIterator(long version, List keys, - IStatePushDown pushdown) { - BiFunction>> fetchFun = (k, p) -> getEdges(version, k, - p); - Iterator>> it = new KeysIterator<>(keys, fetchFun, pushdown); - return new IteratorWithFlatFn<>(it, List::iterator); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(long version, - IStatePushDown pushdown) { - BiFunction> fetchFun = - (k, p) -> getOneDegreeGraph( - version, k, p); - - return new KeysIterator<>(Lists.newArrayList(getKeyIterator()), fetchFun, pushdown); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(long version, - List keys, - IStatePushDown pushdown) { - BiFunction> fetchFun = - (k, p) -> getOneDegreeGraph( - version, k, p); - return new KeysIterator<>(keys, fetchFun, pushdown); + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + + @Override + public long getLatestVersion(K id, DataType dataType) { + if (dataType == DataType.V || dataType == DataType.V_TOPO) { + Map> map = vertexId2Vertex.get(id); + if (map != null) { + return map.keySet().stream().max(Long::compare).get(); + } else { + return -1; + } } - - @Override - public List getAllVersions(K id, DataType dataType) { - if (dataType == DataType.V || dataType == DataType.V_TOPO) { - Map> map = vertexId2Vertex.get(id); - if (map != null) { - return new ArrayList<>(map.keySet()); - } else { - return new ArrayList<>(); - } - } - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + + @Override + public Map> getAllVersionData( + K id, IStatePushDown pushdown, DataType dataType) { + if (dataType != DataType.V) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); } - - @Override - public long getLatestVersion(K id, DataType dataType) { - if (dataType == DataType.V || dataType == DataType.V_TOPO) { - Map> map = vertexId2Vertex.get(id); - if (map != null) { - return map.keySet().stream().max(Long::compare).get(); - } else { - return -1; - } - } - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + Map> map = vertexId2Vertex.get(id); + Map> res = new HashMap<>(map.size()); + IGraphFilter graphFilter = (IGraphFilter) pushdown.getFilter(); + for (Entry> entry : map.entrySet()) { + if (graphFilter.filterVertex(entry.getValue())) { + res.put(entry.getKey(), entry.getValue()); + } } - - @Override - public Map> getAllVersionData(K id, IStatePushDown pushdown, - DataType dataType) { - if (dataType != DataType.V) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } - Map> map = vertexId2Vertex.get(id); - Map> res = new HashMap<>(map.size()); - IGraphFilter graphFilter = (IGraphFilter) pushdown.getFilter(); - for (Entry> entry : map.entrySet()) { - if (graphFilter.filterVertex(entry.getValue())) { - res.put(entry.getKey(), entry.getValue()); - } - } - return res; + return res; + } + + @Override + public Map> getVersionData( + K id, Collection versions, IStatePushDown pushdown, DataType dataType) { + if (dataType != DataType.V) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); } - - @Override - public Map> getVersionData(K id, Collection versions, - IStatePushDown pushdown, DataType dataType) { - if (dataType != DataType.V) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } - Map> map = vertexId2Vertex.get(id); - Map> res = new HashMap<>(versions.size()); - IGraphFilter graphFilter = (IGraphFilter) pushdown.getFilter(); - for (long version : versions) { - IVertex vertex = map.get(version); - if (vertex != null && graphFilter.filterVertex(vertex)) { - res.put(version, vertex); - } - } - return res; + Map> map = vertexId2Vertex.get(id); + Map> res = new HashMap<>(versions.size()); + IGraphFilter graphFilter = (IGraphFilter) pushdown.getFilter(); + for (long version : versions) { + IVertex vertex = map.get(version); + if (vertex != null && graphFilter.filterVertex(vertex)) { + res.put(version, vertex); + } } + return res; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/KListMemoryStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/KListMemoryStore.java index 31c999f1f..2ffa190f8 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/KListMemoryStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/KListMemoryStore.java @@ -24,68 +24,57 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.store.IStatefulStore; import org.apache.geaflow.store.api.key.IKListStore; import org.apache.geaflow.store.context.StoreContext; public class KListMemoryStore implements IStatefulStore, IKListStore { - private Map> memoryStore = new HashMap<>(); - - @Override - public void add(K key, V... value) { - List list = memoryStore.computeIfAbsent(key, k -> new ArrayList<>()); - list.addAll(Arrays.asList(value)); - this.memoryStore.put(key, list); - } - - @Override - public void remove(K key) { - this.memoryStore.remove(key); - } - - @Override - public List get(K key) { - return this.memoryStore.getOrDefault(key, new ArrayList<>()); - } - - @Override - public void init(StoreContext storeContext) { - - } - - @Override - public void flush() { + private Map> memoryStore = new HashMap<>(); - } + @Override + public void add(K key, V... value) { + List list = memoryStore.computeIfAbsent(key, k -> new ArrayList<>()); + list.addAll(Arrays.asList(value)); + this.memoryStore.put(key, list); + } - @Override - public void close() { - this.memoryStore.clear(); - } + @Override + public void remove(K key) { + this.memoryStore.remove(key); + } - @Override - public void archive(long checkpointId) { + @Override + public List get(K key) { + return this.memoryStore.getOrDefault(key, new ArrayList<>()); + } - } + @Override + public void init(StoreContext storeContext) {} - @Override - public void recovery(long checkpointId) { + @Override + public void flush() {} - } + @Override + public void close() { + this.memoryStore.clear(); + } - @Override - public long recoveryLatest() { - return 0; - } + @Override + public void archive(long checkpointId) {} - @Override - public void compact() { + @Override + public void recovery(long checkpointId) {} - } + @Override + public long recoveryLatest() { + return 0; + } - @Override - public void drop() { + @Override + public void compact() {} - } + @Override + public void drop() {} } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/KMapMemoryStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/KMapMemoryStore.java index ca7a0f746..87a29462e 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/KMapMemoryStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/KMapMemoryStore.java @@ -24,89 +24,78 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.store.IStatefulStore; import org.apache.geaflow.store.api.key.IKMapStore; import org.apache.geaflow.store.context.StoreContext; public class KMapMemoryStore implements IStatefulStore, IKMapStore { - private Map> memoryStore = new HashMap<>(); - - @Override - public Map get(K key) { - return memoryStore.getOrDefault(key, new HashMap<>()); - } - - @Override - public List get(K key, UK... subKeys) { - Map map = get(key); - List list = new ArrayList<>(subKeys.length); - Arrays.stream(subKeys).forEach(c -> list.add(map.get(c))); - return list; - } - - @Override - public void add(K key, UK subKey, UV value) { - Map map = memoryStore.computeIfAbsent(key, k -> new HashMap<>()); - map.put(subKey, value); - } - - @Override - public void add(K key, Map map) { - Map tmp = memoryStore.computeIfAbsent(key, k -> new HashMap<>()); - tmp.putAll(map); - } - - @Override - public void remove(K key) { - this.memoryStore.remove(key); - } - - @Override - public void remove(K key, UK... subKeys) { - if (memoryStore.containsKey(key)) { - Map map = get(key); - Arrays.stream(subKeys).forEach(map::remove); - } - } - - @Override - public void init(StoreContext storeContext) { - + private Map> memoryStore = new HashMap<>(); + + @Override + public Map get(K key) { + return memoryStore.getOrDefault(key, new HashMap<>()); + } + + @Override + public List get(K key, UK... subKeys) { + Map map = get(key); + List list = new ArrayList<>(subKeys.length); + Arrays.stream(subKeys).forEach(c -> list.add(map.get(c))); + return list; + } + + @Override + public void add(K key, UK subKey, UV value) { + Map map = memoryStore.computeIfAbsent(key, k -> new HashMap<>()); + map.put(subKey, value); + } + + @Override + public void add(K key, Map map) { + Map tmp = memoryStore.computeIfAbsent(key, k -> new HashMap<>()); + tmp.putAll(map); + } + + @Override + public void remove(K key) { + this.memoryStore.remove(key); + } + + @Override + public void remove(K key, UK... subKeys) { + if (memoryStore.containsKey(key)) { + Map map = get(key); + Arrays.stream(subKeys).forEach(map::remove); } + } - @Override - public void flush() { + @Override + public void init(StoreContext storeContext) {} - } + @Override + public void flush() {} - @Override - public void close() { - this.memoryStore.clear(); - } + @Override + public void close() { + this.memoryStore.clear(); + } - @Override - public void archive(long checkpointId) { + @Override + public void archive(long checkpointId) {} - } + @Override + public void recovery(long checkpointId) {} - @Override - public void recovery(long checkpointId) { + @Override + public long recoveryLatest() { + return 0; + } - } + @Override + public void compact() {} - @Override - public long recoveryLatest() { - return 0; - } - - @Override - public void compact() { - - } - - @Override - public void drop() { - - } + @Override + public void drop() {} } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/KVMemoryStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/KVMemoryStore.java index fcd8500ec..8d212d5d6 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/KVMemoryStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/KVMemoryStore.java @@ -21,66 +21,55 @@ import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.store.IStatefulStore; import org.apache.geaflow.store.api.key.IKVStore; import org.apache.geaflow.store.context.StoreContext; public class KVMemoryStore implements IStatefulStore, IKVStore { - private Map memoryStore = new HashMap<>(); - - @Override - public void put(K key, V value) { - this.memoryStore.put(key, value); - } - - @Override - public void remove(K key) { - this.memoryStore.remove(key); - } - - @Override - public V get(K key) { - return this.memoryStore.get(key); - } - - @Override - public void init(StoreContext storeContext) { - - } - - @Override - public void flush() { + private Map memoryStore = new HashMap<>(); - } + @Override + public void put(K key, V value) { + this.memoryStore.put(key, value); + } - @Override - public void close() { - this.memoryStore.clear(); - } + @Override + public void remove(K key) { + this.memoryStore.remove(key); + } - @Override - public void archive(long checkpointId) { + @Override + public V get(K key) { + return this.memoryStore.get(key); + } - } + @Override + public void init(StoreContext storeContext) {} - @Override - public void recovery(long checkpointId) { + @Override + public void flush() {} - } + @Override + public void close() { + this.memoryStore.clear(); + } - @Override - public long recoveryLatest() { - return 0; - } + @Override + public void archive(long checkpointId) {} - @Override - public void compact() { + @Override + public void recovery(long checkpointId) {} - } + @Override + public long recoveryLatest() { + return 0; + } - @Override - public void drop() { + @Override + public void compact() {} - } + @Override + public void drop() {} } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/MemoryConfigKeys.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/MemoryConfigKeys.java index 6ea688f3a..870217cc8 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/MemoryConfigKeys.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/MemoryConfigKeys.java @@ -24,8 +24,8 @@ public class MemoryConfigKeys { - public static final ConfigKey CSR_MEMORY_ENABLE = ConfigKeys - .key("geaflow.store.memory.csr.enable") - .defaultValue(false) - .description("memory graph layout csr or not, default false"); + public static final ConfigKey CSR_MEMORY_ENABLE = + ConfigKeys.key("geaflow.store.memory.csr.enable") + .defaultValue(false) + .description("memory graph layout csr or not, default false"); } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/MemoryStoreBuilder.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/MemoryStoreBuilder.java index 3909a5418..6337f5fd0 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/MemoryStoreBuilder.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/MemoryStoreBuilder.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -32,46 +33,46 @@ public class MemoryStoreBuilder implements IStoreBuilder { - private static final StoreDesc STORE_DESC = new MemoryStoreDesc(); + private static final StoreDesc STORE_DESC = new MemoryStoreDesc(); - public IBaseStore getStore(DataModel type, Configuration config) { - switch (type) { - case DYNAMIC_GRAPH: - return new DynamicGraphMemoryStore<>(); - case STATIC_GRAPH: - boolean csrEnable = config.getBoolean(MemoryConfigKeys.CSR_MEMORY_ENABLE); - return csrEnable ? new StaticGraphMemoryCSRStore<>() : new StaticGraphMemoryStore<>(); - case KV: - return new KVMemoryStore<>(); - case KList: - return new KListMemoryStore<>(); - case KMap: - return new KMapMemoryStore<>(); - default: - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError("not support " + type)); - } + public IBaseStore getStore(DataModel type, Configuration config) { + switch (type) { + case DYNAMIC_GRAPH: + return new DynamicGraphMemoryStore<>(); + case STATIC_GRAPH: + boolean csrEnable = config.getBoolean(MemoryConfigKeys.CSR_MEMORY_ENABLE); + return csrEnable ? new StaticGraphMemoryCSRStore<>() : new StaticGraphMemoryStore<>(); + case KV: + return new KVMemoryStore<>(); + case KList: + return new KListMemoryStore<>(); + case KMap: + return new KMapMemoryStore<>(); + default: + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError("not support " + type)); } + } - @Override - public StoreDesc getStoreDesc() { - return STORE_DESC; - } + @Override + public StoreDesc getStoreDesc() { + return STORE_DESC; + } - @Override - public List supportedDataModel() { - return Arrays.asList(DataModel.values()); - } + @Override + public List supportedDataModel() { + return Arrays.asList(DataModel.values()); + } - public static class MemoryStoreDesc implements StoreDesc { + public static class MemoryStoreDesc implements StoreDesc { - @Override - public boolean isLocalStore() { - return false; - } + @Override + public boolean isLocalStore() { + return false; + } - @Override - public String name() { - return StoreType.MEMORY.name(); - } + @Override + public String name() { + return StoreType.MEMORY.name(); } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/StaticGraphMemoryCSRStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/StaticGraphMemoryCSRStore.java index 46b097646..747adf770 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/StaticGraphMemoryCSRStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/StaticGraphMemoryCSRStore.java @@ -19,7 +19,6 @@ package org.apache.geaflow.store.memory; -import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.Collections; import java.util.LinkedList; @@ -28,6 +27,7 @@ import java.util.Map.Entry; import java.util.Objects; import java.util.stream.IntStream; + import org.apache.geaflow.collection.array.PrimitiveArray; import org.apache.geaflow.collection.array.PrimitiveArrayFactory; import org.apache.geaflow.collection.map.MapFactory; @@ -49,235 +49,230 @@ import org.apache.geaflow.store.memory.csr.vertex.IVertexArray; import org.apache.geaflow.store.memory.csr.vertex.VertexArrayFactory; -public class StaticGraphMemoryCSRStore extends BaseStaticGraphMemoryStore { - - // inner csr store. - private CSRStore csrStore; - private boolean isBuilt; - private List> vertexList; - private List> edgesList; - - @Override - public void init(StoreContext context) { - super.init(context); - isBuilt = false; - vertexList = new ArrayList<>(); - edgesList = new ArrayList<>(); - csrStore = new CSRStore<>(context); - } - - @Override - public void addEdge(IEdge edge) { - Preconditions.checkArgument(!isBuilt, "cannot add vertex/edge after flush."); - edgesList.add(edge); - } - - @Override - public void addVertex(IVertex vertex) { - Preconditions.checkArgument(!isBuilt, "cannot add vertex/edge after flush."); - vertexList.add(vertex); - } - - @Override - protected IVertex getVertex(K sid) { - Preconditions.checkArgument(isBuilt, "flush first."); - return csrStore.getVertex(sid); - } - - @Override - protected List> getEdges(K sid) { - Preconditions.checkArgument(isBuilt, "flush first."); - return csrStore.getEdges(sid); - } - - @Override - public OneDegreeGraph getOneDegreeGraph(K sid, IStatePushDown pushdown) { - int pos = csrStore.getDictId(sid); - return getOneDegreeGraph(sid, pos, pushdown); - } - - public OneDegreeGraph getOneDegreeGraph(K sid, int pos, IStatePushDown pushdown) { - OneDegreeGraph oneDegreeGraph; - IGraphFilter filter = GraphFilter.of(pushdown.getFilter(), pushdown.getEdgeLimit()); - if (pos == CSRStore.NON_EXIST) { - return new OneDegreeGraph<>(sid, null, - IteratorWithClose.wrap(Collections.emptyIterator())); - } else { - IVertex vertex = csrStore.getVertex(sid, pos); - if (vertex == null || !filter.filterVertex(vertex)) { - vertex = null; - } - List> stream = pushdownEdges(csrStore.getEdges(sid, pos), pushdown); - oneDegreeGraph = new OneDegreeGraph<>(sid, vertex, - IteratorWithClose.wrap(stream.iterator())); - } - return filter.filterOneDegreeGraph(oneDegreeGraph) ? oneDegreeGraph : null; - } - - @Override - public void archive(long checkpointId) { - - } +import com.google.common.base.Preconditions; - @Override - public void recovery(long checkpointId) { +public class StaticGraphMemoryCSRStore extends BaseStaticGraphMemoryStore { + // inner csr store. + private CSRStore csrStore; + private boolean isBuilt; + private List> vertexList; + private List> edgesList; + + @Override + public void init(StoreContext context) { + super.init(context); + isBuilt = false; + vertexList = new ArrayList<>(); + edgesList = new ArrayList<>(); + csrStore = new CSRStore<>(context); + } + + @Override + public void addEdge(IEdge edge) { + Preconditions.checkArgument(!isBuilt, "cannot add vertex/edge after flush."); + edgesList.add(edge); + } + + @Override + public void addVertex(IVertex vertex) { + Preconditions.checkArgument(!isBuilt, "cannot add vertex/edge after flush."); + vertexList.add(vertex); + } + + @Override + protected IVertex getVertex(K sid) { + Preconditions.checkArgument(isBuilt, "flush first."); + return csrStore.getVertex(sid); + } + + @Override + protected List> getEdges(K sid) { + Preconditions.checkArgument(isBuilt, "flush first."); + return csrStore.getEdges(sid); + } + + @Override + public OneDegreeGraph getOneDegreeGraph(K sid, IStatePushDown pushdown) { + int pos = csrStore.getDictId(sid); + return getOneDegreeGraph(sid, pos, pushdown); + } + + public OneDegreeGraph getOneDegreeGraph(K sid, int pos, IStatePushDown pushdown) { + OneDegreeGraph oneDegreeGraph; + IGraphFilter filter = GraphFilter.of(pushdown.getFilter(), pushdown.getEdgeLimit()); + if (pos == CSRStore.NON_EXIST) { + return new OneDegreeGraph<>(sid, null, IteratorWithClose.wrap(Collections.emptyIterator())); + } else { + IVertex vertex = csrStore.getVertex(sid, pos); + if (vertex == null || !filter.filterVertex(vertex)) { + vertex = null; + } + List> stream = pushdownEdges(csrStore.getEdges(sid, pos), pushdown); + oneDegreeGraph = new OneDegreeGraph<>(sid, vertex, IteratorWithClose.wrap(stream.iterator())); } - - @Override - public long recoveryLatest() { - return 0; + return filter.filterOneDegreeGraph(oneDegreeGraph) ? oneDegreeGraph : null; + } + + @Override + public void archive(long checkpointId) {} + + @Override + public void recovery(long checkpointId) {} + + @Override + public long recoveryLatest() { + return 0; + } + + @Override + public void compact() {} + + @Override + public void flush() { + this.csrStore.build(vertexList, edgesList); + this.vertexList = null; + this.edgesList = null; + this.isBuilt = true; + } + + @Override + public void close() {} + + @Override + protected CloseableIterator>> getEdgesIterator() { + Preconditions.checkArgument(isBuilt, "flush first."); + return new IteratorWithFn<>( + IntStream.range(0, csrStore.getDict().size()).iterator(), + p -> csrStore.getEdges(csrStore.reverse.get(p), p)); + } + + @Override + protected CloseableIterator> getVertexIterator() { + Preconditions.checkArgument(isBuilt, "flush first."); + + return new IteratorWithFnThenFilter<>( + IntStream.range(0, csrStore.getDict().size()).iterator(), + p -> csrStore.getVertex(csrStore.reverse.get(p), p), + Objects::nonNull); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + IStatePushDown pushdown) { + Preconditions.checkArgument(isBuilt, "flush first."); + + return new IteratorWithFnThenFilter<>( + IntStream.range(0, csrStore.getDict().size()).iterator(), + p -> { + K k = csrStore.reverse.get(p); + return getOneDegreeGraph(k, p, pushdown); + }, + Objects::nonNull); + } + + @Override + protected CloseableIterator getKeyIterator() { + return IteratorWithClose.wrap(csrStore.getDict().keySet().iterator()); + } + + @Override + public void drop() { + csrStore.drop(); + } + + public static class CSRStore { + + public static final int NON_EXIST = -1; + private final Class keyClazz; + + private Map kDict; + private IVertexArray vertexArray; + private IEdgeArray edgeArray; + private PrimitiveArray reverse; + + public CSRStore(StoreContext context) { + GraphDataSchema graphDataSchema = context.getGraphSchema(); + this.keyClazz = graphDataSchema.getKeyType().getTypeClass(); + kDict = MapFactory.buildMap(keyClazz, Integer.TYPE); + vertexArray = VertexArrayFactory.getVertexArray(graphDataSchema); + edgeArray = EdgeArrayFactory.getEdgeArray(graphDataSchema); } - @Override - public void compact() { - + public int getDictIdOrRegister(K id) { + int res = kDict.computeIfAbsent(id, k -> kDict.size()); + return res; } - @Override - public void flush() { - this.csrStore.build(vertexList, edgesList); - this.vertexList = null; - this.edgesList = null; - this.isBuilt = true; + public int getDictId(K id) { + return kDict.getOrDefault(id, NON_EXIST); } - @Override - public void close() { - + public Map getDict() { + return kDict; } - @Override - protected CloseableIterator>> getEdgesIterator() { - Preconditions.checkArgument(isBuilt, "flush first."); - return new IteratorWithFn<>(IntStream.range(0, csrStore.getDict().size()).iterator(), - p -> csrStore.getEdges(csrStore.reverse.get(p), p)); + public List> getEdges(K sid) { + int pos = getDictId(sid); + return pos == NON_EXIST ? Collections.EMPTY_LIST : getEdges(sid, pos); } - @Override - protected CloseableIterator> getVertexIterator() { - Preconditions.checkArgument(isBuilt, "flush first."); - - return new IteratorWithFnThenFilter<>( - IntStream.range(0, csrStore.getDict().size()).iterator(), - p -> csrStore.getVertex(csrStore.reverse.get(p), p), Objects::nonNull); + public List> getEdges(K sid, int pos) { + Tuple edgePosRange = vertexArray.getEdgePosRange(pos); + return edgeArray.getRangeEdges(sid, edgePosRange.f0, edgePosRange.f1); } - @Override - public CloseableIterator> getOneDegreeGraphIterator( - IStatePushDown pushdown) { - Preconditions.checkArgument(isBuilt, "flush first."); - - return new IteratorWithFnThenFilter<>( - IntStream.range(0, csrStore.getDict().size()).iterator(), p -> { - K k = csrStore.reverse.get(p); - return getOneDegreeGraph(k, p, pushdown); - }, Objects::nonNull); + public IVertex getVertex(K id) { + int pos = getDictId(id); + return pos == NON_EXIST ? null : getVertex(id, pos); } - @Override - protected CloseableIterator getKeyIterator() { - return IteratorWithClose.wrap(csrStore.getDict().keySet().iterator()); + private IVertex getVertex(K id, int pos) { + return vertexArray.getVertex(id, pos); } - @Override - public void drop() { - csrStore.drop(); - } - - public static class CSRStore { - - public static final int NON_EXIST = -1; - private final Class keyClazz; - - private Map kDict; - private IVertexArray vertexArray; - private IEdgeArray edgeArray; - private PrimitiveArray reverse; - - public CSRStore(StoreContext context) { - GraphDataSchema graphDataSchema = context.getGraphSchema(); - this.keyClazz = graphDataSchema.getKeyType().getTypeClass(); - kDict = MapFactory.buildMap(keyClazz, Integer.TYPE); - vertexArray = VertexArrayFactory.getVertexArray(graphDataSchema); - edgeArray = EdgeArrayFactory.getEdgeArray(graphDataSchema); - } - - public int getDictIdOrRegister(K id) { - int res = kDict.computeIfAbsent(id, k -> kDict.size()); - return res; - } - - public int getDictId(K id) { - return kDict.getOrDefault(id, NON_EXIST); - } - - public Map getDict() { - return kDict; - } - - public List> getEdges(K sid) { - int pos = getDictId(sid); - return pos == NON_EXIST ? Collections.EMPTY_LIST : getEdges(sid, pos); - } - - public List> getEdges(K sid, int pos) { - Tuple edgePosRange = vertexArray.getEdgePosRange(pos); - return edgeArray.getRangeEdges(sid, edgePosRange.f0, edgePosRange.f1); - } - - public IVertex getVertex(K id) { - int pos = getDictId(id); - return pos == NON_EXIST ? null : getVertex(id, pos); - } - - private IVertex getVertex(K id, int pos) { - return vertexArray.getVertex(id, pos); - } - - public void build(List> vertexList, List> edgesList) { - List>> edgesListTmp = new ArrayList<>(); - for (IVertex vertex : vertexList) { - int dictId = getDictIdOrRegister(vertex.getId()); - edgesListTmp.add(dictId, new LinkedList<>()); - } - edgesList.forEach(edge -> { - int oldSize = kDict.size(); - int dictId = getDictIdOrRegister(edge.getSrcId()); - if (dictId > oldSize - 1) { // new register. - vertexList.add(dictId, null); - edgesListTmp.add(dictId, new LinkedList<>()); - } - edgesListTmp.get(dictId).add(edge); - }); - reverse = PrimitiveArrayFactory.getCustomArray(this.keyClazz, kDict.size()); - for (Entry entry : kDict.entrySet()) { - reverse.set(entry.getValue(), entry.getKey()); + public void build(List> vertexList, List> edgesList) { + List>> edgesListTmp = new ArrayList<>(); + for (IVertex vertex : vertexList) { + int dictId = getDictIdOrRegister(vertex.getId()); + edgesListTmp.add(dictId, new LinkedList<>()); + } + edgesList.forEach( + edge -> { + int oldSize = kDict.size(); + int dictId = getDictIdOrRegister(edge.getSrcId()); + if (dictId > oldSize - 1) { // new register. + vertexList.add(dictId, null); + edgesListTmp.add(dictId, new LinkedList<>()); } - - vertexArray.init(vertexList.size()); - int edgesNum = edgesListTmp.stream().mapToInt(value -> value != null ? value.size() : 0) - .sum(); - edgeArray.init(keyClazz, edgesNum); - for (int i = 0; i < vertexList.size(); i++) { - vertexArray.set(i, vertexList.get(i)); - int nextPos; - if (edgesListTmp.get(i) != null && !edgesListTmp.get(i).isEmpty()) { - for (IEdge edge : edgesListTmp.get(i)) { - nextPos = vertexArray.getNextPos(i); - edgeArray.set(nextPos, edge); - vertexArray.updateVId2EPos(i); - } - } - } - } - - public void drop() { - kDict = null; - vertexArray = null; - edgeArray = null; + edgesListTmp.get(dictId).add(edge); + }); + reverse = PrimitiveArrayFactory.getCustomArray(this.keyClazz, kDict.size()); + for (Entry entry : kDict.entrySet()) { + reverse.set(entry.getValue(), entry.getKey()); + } + + vertexArray.init(vertexList.size()); + int edgesNum = + edgesListTmp.stream().mapToInt(value -> value != null ? value.size() : 0).sum(); + edgeArray.init(keyClazz, edgesNum); + for (int i = 0; i < vertexList.size(); i++) { + vertexArray.set(i, vertexList.get(i)); + int nextPos; + if (edgesListTmp.get(i) != null && !edgesListTmp.get(i).isEmpty()) { + for (IEdge edge : edgesListTmp.get(i)) { + nextPos = vertexArray.getNextPos(i); + edgeArray.set(nextPos, edge); + vertexArray.updateVId2EPos(i); + } } + } } - + public void drop() { + kDict = null; + vertexArray = null; + edgeArray = null; + } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/StaticGraphMemoryStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/StaticGraphMemoryStore.java index 967e1c307..21a65048a 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/StaticGraphMemoryStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/StaticGraphMemoryStore.java @@ -25,6 +25,7 @@ import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.common.tuple.Tuple; import org.apache.geaflow.model.graph.edge.IEdge; @@ -35,93 +36,79 @@ public class StaticGraphMemoryStore extends BaseStaticGraphMemoryStore { - protected Map, List>>> map; - - public StaticGraphMemoryStore() { - } - - @Override - public void init(StoreContext context) { - super.init(context); - map = new ConcurrentHashMap<>(); - } - - @Override - public void addEdge(IEdge edge) { - K srcId = edge.getSrcId(); - Tuple, List>> v = map.computeIfAbsent(srcId, - k -> Tuple.of(null, new ArrayList<>())); - v.f1.add(edge); - } - - @Override - public void addVertex(IVertex vertex) { - K srcId = vertex.getId(); - Tuple, List>> v = map.computeIfAbsent(srcId, - k -> Tuple.of(null, new ArrayList<>())); - v.f0 = vertex; - } - - @Override - protected IVertex getVertex(K sid) { - Tuple, List>> v = map.get(sid); - return v == null ? null : v.f0; - } - - @Override - protected List> getEdges(K sid) { - Tuple, List>> v = map.get(sid); - return v == null ? Collections.EMPTY_LIST : v.f1; - } - - @Override - protected CloseableIterator>> getEdgesIterator() { - return IteratorWithClose.wrap(map.values().stream().map(c -> c.f1).iterator()); - } - - @Override - protected CloseableIterator> getVertexIterator() { - return new IteratorWithFnThenFilter<>(map.values().iterator(), c -> c.f0, Objects::nonNull); - } - - @Override - protected CloseableIterator getKeyIterator() { - return IteratorWithClose.wrap(map.keySet().iterator()); - } - - @Override - public void close() { - - } - - @Override - public void archive(long checkpointId) { - - } - - @Override - public void recovery(long checkpointId) { - - } - - @Override - public long recoveryLatest() { - return 0; - } - - @Override - public void compact() { - - } - - @Override - public void flush() { - - } - - @Override - public void drop() { - - } - + protected Map, List>>> map; + + public StaticGraphMemoryStore() {} + + @Override + public void init(StoreContext context) { + super.init(context); + map = new ConcurrentHashMap<>(); + } + + @Override + public void addEdge(IEdge edge) { + K srcId = edge.getSrcId(); + Tuple, List>> v = + map.computeIfAbsent(srcId, k -> Tuple.of(null, new ArrayList<>())); + v.f1.add(edge); + } + + @Override + public void addVertex(IVertex vertex) { + K srcId = vertex.getId(); + Tuple, List>> v = + map.computeIfAbsent(srcId, k -> Tuple.of(null, new ArrayList<>())); + v.f0 = vertex; + } + + @Override + protected IVertex getVertex(K sid) { + Tuple, List>> v = map.get(sid); + return v == null ? null : v.f0; + } + + @Override + protected List> getEdges(K sid) { + Tuple, List>> v = map.get(sid); + return v == null ? Collections.EMPTY_LIST : v.f1; + } + + @Override + protected CloseableIterator>> getEdgesIterator() { + return IteratorWithClose.wrap(map.values().stream().map(c -> c.f1).iterator()); + } + + @Override + protected CloseableIterator> getVertexIterator() { + return new IteratorWithFnThenFilter<>(map.values().iterator(), c -> c.f0, Objects::nonNull); + } + + @Override + protected CloseableIterator getKeyIterator() { + return IteratorWithClose.wrap(map.keySet().iterator()); + } + + @Override + public void close() {} + + @Override + public void archive(long checkpointId) {} + + @Override + public void recovery(long checkpointId) {} + + @Override + public long recoveryLatest() { + return 0; + } + + @Override + public void compact() {} + + @Override + public void flush() {} + + @Override + public void drop() {} } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/EdgeArrayFactory.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/EdgeArrayFactory.java index 7417ba688..f493c3709 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/EdgeArrayFactory.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/EdgeArrayFactory.java @@ -32,19 +32,19 @@ public class EdgeArrayFactory { - public static IEdgeArray getEdgeArray(GraphDataSchema dataSchema) { - boolean noProperty = dataSchema.isEmptyEdgeProperty(); - GraphElementFlag flag = GraphElementFlag.build(dataSchema.getEdgeMeta().getGraphElementClass()); - IEdgeArray edgeArray; - if (flag.isLabeledAndTimed()) { - edgeArray = noProperty ? new IDLabelTimeEdgeArray<>() : new ValueLabelTimeEdgeArray<>(); - } else if (flag.isLabeled()) { - edgeArray = noProperty ? new IDLabelEdgeArray<>() : new ValueLabelEdgeArray<>(); - } else if (flag.isTimed()) { - edgeArray = noProperty ? new IDTimeEdgeArray<>() : new ValueTimeEdgeArray<>(); - } else { - edgeArray = noProperty ? new IDEdgeArray<>() : new ValueEdgeArray<>(); - } - return edgeArray; + public static IEdgeArray getEdgeArray(GraphDataSchema dataSchema) { + boolean noProperty = dataSchema.isEmptyEdgeProperty(); + GraphElementFlag flag = GraphElementFlag.build(dataSchema.getEdgeMeta().getGraphElementClass()); + IEdgeArray edgeArray; + if (flag.isLabeledAndTimed()) { + edgeArray = noProperty ? new IDLabelTimeEdgeArray<>() : new ValueLabelTimeEdgeArray<>(); + } else if (flag.isLabeled()) { + edgeArray = noProperty ? new IDLabelEdgeArray<>() : new ValueLabelEdgeArray<>(); + } else if (flag.isTimed()) { + edgeArray = noProperty ? new IDTimeEdgeArray<>() : new ValueTimeEdgeArray<>(); + } else { + edgeArray = noProperty ? new IDEdgeArray<>() : new ValueEdgeArray<>(); } + return edgeArray; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/IEdgeArray.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/IEdgeArray.java index beca2d11f..37a170a88 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/IEdgeArray.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/IEdgeArray.java @@ -20,15 +20,16 @@ package org.apache.geaflow.store.memory.csr.edge; import java.util.List; + import org.apache.geaflow.model.graph.edge.IEdge; public interface IEdgeArray { - void init(Class keyType, int capacity); + void init(Class keyType, int capacity); - void set(int pos, IEdge edge); + void set(int pos, IEdge edge); - List> getRangeEdges(K sid, int start, int end); + List> getRangeEdges(K sid, int start, int end); - void drop(); + void drop(); } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/IDEdgeArray.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/IDEdgeArray.java index 96fbc321f..a45466491 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/IDEdgeArray.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/IDEdgeArray.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.collection.array.PrimitiveArray; import org.apache.geaflow.collection.array.PrimitiveArrayFactory; import org.apache.geaflow.model.graph.edge.EdgeDirection; @@ -30,45 +31,44 @@ public class IDEdgeArray implements IEdgeArray { - private PrimitiveArray dstIds; - private PrimitiveArray directions; - - public void init(Class keyType, int capacity) { - dstIds = PrimitiveArrayFactory.getCustomArray(keyType, capacity); - directions = PrimitiveArrayFactory.getCustomArray(Byte.class, capacity); - } + private PrimitiveArray dstIds; + private PrimitiveArray directions; - @Override - public List> getRangeEdges(K sid, int start, int end) { - List> edges = new ArrayList<>(end - start); - for (int i = start; i < end; i++) { - edges.add(getEdge(sid, i)); - } - return edges; - } + public void init(Class keyType, int capacity) { + dstIds = PrimitiveArrayFactory.getCustomArray(keyType, capacity); + directions = PrimitiveArrayFactory.getCustomArray(Byte.class, capacity); + } - protected IEdge getEdge(K sid, int pos) { - return new IDEdge<>(sid, getDstId(pos), getDirection(pos)); + @Override + public List> getRangeEdges(K sid, int start, int end) { + List> edges = new ArrayList<>(end - start); + for (int i = start; i < end; i++) { + edges.add(getEdge(sid, i)); } + return edges; + } - protected K getDstId(int pos) { - return dstIds.get(pos); - } + protected IEdge getEdge(K sid, int pos) { + return new IDEdge<>(sid, getDstId(pos), getDirection(pos)); + } - protected EdgeDirection getDirection(int pos) { - return EdgeDirection.values()[directions.get(pos)]; - } + protected K getDstId(int pos) { + return dstIds.get(pos); + } - @Override - public void drop() { - dstIds = null; - directions = null; - } + protected EdgeDirection getDirection(int pos) { + return EdgeDirection.values()[directions.get(pos)]; + } - @Override - public void set(int pos, IEdge edge) { - dstIds.set(pos, edge.getTargetId()); - directions.set(pos, (byte) edge.getDirect().ordinal()); - } + @Override + public void drop() { + dstIds = null; + directions = null; + } + @Override + public void set(int pos, IEdge edge) { + dstIds.set(pos, edge.getTargetId()); + directions.set(pos, (byte) edge.getDirect().ordinal()); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/IDLabelEdgeArray.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/IDLabelEdgeArray.java index aaad67748..aabf4068e 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/IDLabelEdgeArray.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/IDLabelEdgeArray.java @@ -25,28 +25,27 @@ public class IDLabelEdgeArray extends IDEdgeArray { - private String[] labels; - - public void init(Class keyType, int capacity) { - super.init(keyType, capacity); - labels = new String[capacity]; - } - - @Override - protected IEdge getEdge(K sid, int pos) { - return new IDLabelEdge<>(sid, getDstId(pos), getDirection(pos), labels[pos]); - } - - @Override - public void drop() { - super.drop(); - labels = null; - } - - @Override - public void set(int pos, IEdge edge) { - super.set(pos, edge); - labels[pos] = ((IGraphElementWithLabelField) edge).getLabel().intern(); - } - + private String[] labels; + + public void init(Class keyType, int capacity) { + super.init(keyType, capacity); + labels = new String[capacity]; + } + + @Override + protected IEdge getEdge(K sid, int pos) { + return new IDLabelEdge<>(sid, getDstId(pos), getDirection(pos), labels[pos]); + } + + @Override + public void drop() { + super.drop(); + labels = null; + } + + @Override + public void set(int pos, IEdge edge) { + super.set(pos, edge); + labels[pos] = ((IGraphElementWithLabelField) edge).getLabel().intern(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/IDLabelTimeEdgeArray.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/IDLabelTimeEdgeArray.java index f4d5f07ee..45e4533bb 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/IDLabelTimeEdgeArray.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/IDLabelTimeEdgeArray.java @@ -26,33 +26,31 @@ public class IDLabelTimeEdgeArray extends IDEdgeArray { - private String[] labels; - private long[] times; - - public void init(Class keyType, int capacity) { - super.init(keyType, capacity); - labels = new String[capacity]; - times = new long[capacity]; - } - - @Override - protected IEdge getEdge(K sid, int pos) { - return new IDLabelTimeEdge<>(sid, getDstId(pos), getDirection(pos), - labels[pos], times[pos]); - } - - @Override - public void drop() { - super.drop(); - labels = null; - times = null; - } - - @Override - public void set(int pos, IEdge edge) { - super.set(pos, edge); - labels[pos] = ((IGraphElementWithLabelField) edge).getLabel(); - times[pos] = ((IGraphElementWithTimeField) edge).getTime(); - } - + private String[] labels; + private long[] times; + + public void init(Class keyType, int capacity) { + super.init(keyType, capacity); + labels = new String[capacity]; + times = new long[capacity]; + } + + @Override + protected IEdge getEdge(K sid, int pos) { + return new IDLabelTimeEdge<>(sid, getDstId(pos), getDirection(pos), labels[pos], times[pos]); + } + + @Override + public void drop() { + super.drop(); + labels = null; + times = null; + } + + @Override + public void set(int pos, IEdge edge) { + super.set(pos, edge); + labels[pos] = ((IGraphElementWithLabelField) edge).getLabel(); + times[pos] = ((IGraphElementWithTimeField) edge).getTime(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/IDTimeEdgeArray.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/IDTimeEdgeArray.java index 9188ffa74..f58c6670b 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/IDTimeEdgeArray.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/IDTimeEdgeArray.java @@ -25,28 +25,27 @@ public class IDTimeEdgeArray extends IDEdgeArray { - private long[] times; - - public void init(Class keyType, int capacity) { - super.init(keyType, capacity); - times = new long[capacity]; - } - - @Override - protected IEdge getEdge(K sid, int pos) { - return new IDTimeEdge<>(sid, getDstId(pos), getDirection(pos), times[pos]); - } - - @Override - public void drop() { - super.drop(); - times = null; - } - - @Override - public void set(int pos, IEdge edge) { - super.set(pos, edge); - times[pos] = ((IGraphElementWithTimeField) edge).getTime(); - } - + private long[] times; + + public void init(Class keyType, int capacity) { + super.init(keyType, capacity); + times = new long[capacity]; + } + + @Override + protected IEdge getEdge(K sid, int pos) { + return new IDTimeEdge<>(sid, getDstId(pos), getDirection(pos), times[pos]); + } + + @Override + public void drop() { + super.drop(); + times = null; + } + + @Override + public void set(int pos, IEdge edge) { + super.set(pos, edge); + times[pos] = ((IGraphElementWithTimeField) edge).getTime(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/ValueEdgeArray.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/ValueEdgeArray.java index 2143728a6..0b95bbd8c 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/ValueEdgeArray.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/ValueEdgeArray.java @@ -24,31 +24,30 @@ public class ValueEdgeArray extends IDEdgeArray { - private Object[] values; - - public void init(Class keyType, int capacity) { - super.init(keyType, capacity); - values = new Object[capacity]; - } - - protected IEdge getEdge(K sid, int pos) { - return new ValueEdge<>(sid, getDstId(pos), getValue(pos), getDirection(pos)); - } - - protected Object getValue(int pos) { - return values[pos]; - } - - @Override - public void drop() { - super.drop(); - values = null; - } - - @Override - public void set(int pos, IEdge edge) { - super.set(pos, edge); - values[pos] = edge.getValue(); - } - + private Object[] values; + + public void init(Class keyType, int capacity) { + super.init(keyType, capacity); + values = new Object[capacity]; + } + + protected IEdge getEdge(K sid, int pos) { + return new ValueEdge<>(sid, getDstId(pos), getValue(pos), getDirection(pos)); + } + + protected Object getValue(int pos) { + return values[pos]; + } + + @Override + public void drop() { + super.drop(); + values = null; + } + + @Override + public void set(int pos, IEdge edge) { + super.set(pos, edge); + values[pos] = edge.getValue(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/ValueLabelEdgeArray.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/ValueLabelEdgeArray.java index 5548ca0c5..1352bd2b9 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/ValueLabelEdgeArray.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/ValueLabelEdgeArray.java @@ -25,29 +25,27 @@ public class ValueLabelEdgeArray extends ValueEdgeArray { - private String[] labels; - - public void init(Class keyType, int capacity) { - super.init(keyType, capacity); - labels = new String[capacity]; - } - - @Override - protected IEdge getEdge(K sid, int pos) { - return new ValueLabelEdge<>(sid, getDstId(pos), getValue(pos), - getDirection(pos), labels[pos]); - } - - @Override - public void drop() { - super.drop(); - labels = null; - } - - @Override - public void set(int pos, IEdge edge) { - super.set(pos, edge); - labels[pos] = ((IGraphElementWithLabelField) edge).getLabel().intern(); - } - + private String[] labels; + + public void init(Class keyType, int capacity) { + super.init(keyType, capacity); + labels = new String[capacity]; + } + + @Override + protected IEdge getEdge(K sid, int pos) { + return new ValueLabelEdge<>(sid, getDstId(pos), getValue(pos), getDirection(pos), labels[pos]); + } + + @Override + public void drop() { + super.drop(); + labels = null; + } + + @Override + public void set(int pos, IEdge edge) { + super.set(pos, edge); + labels[pos] = ((IGraphElementWithLabelField) edge).getLabel().intern(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/ValueLabelTimeEdgeArray.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/ValueLabelTimeEdgeArray.java index b8e1913ce..bf118618b 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/ValueLabelTimeEdgeArray.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/ValueLabelTimeEdgeArray.java @@ -26,33 +26,32 @@ public class ValueLabelTimeEdgeArray extends ValueEdgeArray { - private String[] labels; - private long[] times; - - public void init(Class keyType, int capacity) { - super.init(keyType, capacity); - labels = new String[capacity]; - times = new long[capacity]; - } - - @Override - protected IEdge getEdge(K sid, int pos) { - return new ValueLabelTimeEdge<>(sid, getDstId(pos), getValue(pos), - getDirection(pos), labels[pos], times[pos]); - } - - @Override - public void drop() { - super.drop(); - labels = null; - times = null; - } - - @Override - public void set(int pos, IEdge edge) { - super.set(pos, edge); - labels[pos] = ((IGraphElementWithLabelField) edge).getLabel().intern(); - times[pos] = ((IGraphElementWithTimeField) edge).getTime(); - } - + private String[] labels; + private long[] times; + + public void init(Class keyType, int capacity) { + super.init(keyType, capacity); + labels = new String[capacity]; + times = new long[capacity]; + } + + @Override + protected IEdge getEdge(K sid, int pos) { + return new ValueLabelTimeEdge<>( + sid, getDstId(pos), getValue(pos), getDirection(pos), labels[pos], times[pos]); + } + + @Override + public void drop() { + super.drop(); + labels = null; + times = null; + } + + @Override + public void set(int pos, IEdge edge) { + super.set(pos, edge); + labels[pos] = ((IGraphElementWithLabelField) edge).getLabel().intern(); + times[pos] = ((IGraphElementWithTimeField) edge).getTime(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/ValueTimeEdgeArray.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/ValueTimeEdgeArray.java index 393c15d7f..169d82d48 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/ValueTimeEdgeArray.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/edge/type/ValueTimeEdgeArray.java @@ -25,29 +25,27 @@ public class ValueTimeEdgeArray extends ValueEdgeArray { - private long[] times; - - public void init(Class keyType, int capacity) { - super.init(keyType, capacity); - times = new long[capacity]; - } - - @Override - protected IEdge getEdge(K sid, int pos) { - return new ValueTimeEdge<>(sid, getDstId(pos), getValue(pos), - getDirection(pos), times[pos]); - } - - @Override - public void drop() { - super.drop(); - times = null; - } - - @Override - public void set(int pos, IEdge edge) { - super.set(pos, edge); - times[pos] = ((IGraphElementWithTimeField) edge).getTime(); - } - + private long[] times; + + public void init(Class keyType, int capacity) { + super.init(keyType, capacity); + times = new long[capacity]; + } + + @Override + protected IEdge getEdge(K sid, int pos) { + return new ValueTimeEdge<>(sid, getDstId(pos), getValue(pos), getDirection(pos), times[pos]); + } + + @Override + public void drop() { + super.drop(); + times = null; + } + + @Override + public void set(int pos, IEdge edge) { + super.set(pos, edge); + times[pos] = ((IGraphElementWithTimeField) edge).getTime(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/IVertexArray.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/IVertexArray.java index 63b498760..2a7930799 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/IVertexArray.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/IVertexArray.java @@ -24,17 +24,17 @@ public interface IVertexArray { - void init(int capacity); + void init(int capacity); - void set(int pos, IVertex vertex); + void set(int pos, IVertex vertex); - void updateVId2EPos(int pos); + void updateVId2EPos(int pos); - int getNextPos(int pos); + int getNextPos(int pos); - Tuple getEdgePosRange(int pos); + Tuple getEdgePosRange(int pos); - IVertex getVertex(K key, int pos); + IVertex getVertex(K key, int pos); - void drop(); + void drop(); } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/VertexArrayFactory.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/VertexArrayFactory.java index b0a54d6dd..556514896 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/VertexArrayFactory.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/VertexArrayFactory.java @@ -32,20 +32,20 @@ public class VertexArrayFactory { - public static IVertexArray getVertexArray(GraphDataSchema dataSchema) { - boolean noProperty = dataSchema.isEmptyVertexProperty(); - GraphElementFlag flag = GraphElementFlag.build(dataSchema.getVertexMeta().getGraphElementClass()); - IVertexArray vertexArray; - if (flag.isLabeledAndTimed()) { - vertexArray = noProperty ? new IDLabelTimeVertexArray<>() : new ValueLabelTimeVertexArray<>(); - } else if (flag.isLabeled()) { - vertexArray = noProperty ? new IDLabelVertexArray<>() : new ValueLabelVertexArray<>(); - } else if (flag.isTimed()) { - vertexArray = noProperty ? new IDTimeVertexArray<>() : new ValueTimeVertexArray<>(); - } else { - vertexArray = noProperty ? new IDVertexArray<>() : new ValueVertexArray<>(); - } - return vertexArray; + public static IVertexArray getVertexArray(GraphDataSchema dataSchema) { + boolean noProperty = dataSchema.isEmptyVertexProperty(); + GraphElementFlag flag = + GraphElementFlag.build(dataSchema.getVertexMeta().getGraphElementClass()); + IVertexArray vertexArray; + if (flag.isLabeledAndTimed()) { + vertexArray = noProperty ? new IDLabelTimeVertexArray<>() : new ValueLabelTimeVertexArray<>(); + } else if (flag.isLabeled()) { + vertexArray = noProperty ? new IDLabelVertexArray<>() : new ValueLabelVertexArray<>(); + } else if (flag.isTimed()) { + vertexArray = noProperty ? new IDTimeVertexArray<>() : new ValueTimeVertexArray<>(); + } else { + vertexArray = noProperty ? new IDVertexArray<>() : new ValueVertexArray<>(); } - + return vertexArray; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/IDLabelTimeVertexArray.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/IDLabelTimeVertexArray.java index 55341bcd3..6607a6e3a 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/IDLabelTimeVertexArray.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/IDLabelTimeVertexArray.java @@ -26,34 +26,33 @@ public class IDLabelTimeVertexArray extends IDVertexArray { - private String[] labels; - private long[] times; - - public void init(int capacity) { - super.init(capacity); - labels = new String[capacity]; - times = new long[capacity]; - } - - @Override - public IVertex getVertex(K key, int pos) { - return containsVertex(pos) ? new IDLabelTimeVertex<>(key, labels[pos], times[pos]) : null; - } - - @Override - public void set(int pos, IVertex vertex) { - super.set(pos, vertex); - if (vertex != null) { - labels[pos] = ((IGraphElementWithLabelField) vertex).getLabel().intern(); - times[pos] = ((IGraphElementWithTimeField) vertex).getTime(); - } - } - - @Override - public void drop() { - super.drop(); - labels = null; - times = null; + private String[] labels; + private long[] times; + + public void init(int capacity) { + super.init(capacity); + labels = new String[capacity]; + times = new long[capacity]; + } + + @Override + public IVertex getVertex(K key, int pos) { + return containsVertex(pos) ? new IDLabelTimeVertex<>(key, labels[pos], times[pos]) : null; + } + + @Override + public void set(int pos, IVertex vertex) { + super.set(pos, vertex); + if (vertex != null) { + labels[pos] = ((IGraphElementWithLabelField) vertex).getLabel().intern(); + times[pos] = ((IGraphElementWithTimeField) vertex).getTime(); } - + } + + @Override + public void drop() { + super.drop(); + labels = null; + times = null; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/IDLabelVertexArray.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/IDLabelVertexArray.java index 35e5fc77b..9f23775d6 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/IDLabelVertexArray.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/IDLabelVertexArray.java @@ -25,30 +25,29 @@ public class IDLabelVertexArray extends IDVertexArray { - private String[] labels; - - public void init(int capacity) { - super.init(capacity); - labels = new String[capacity]; - } - - @Override - public IVertex getVertex(K key, int pos) { - return containsVertex(pos) ? new IDLabelVertex<>(key, labels[pos]) : null; - } - - @Override - public void set(int pos, IVertex vertex) { - super.set(pos, vertex); - if (vertex != null) { - labels[pos] = ((IGraphElementWithLabelField) vertex).getLabel().intern(); - } - } - - @Override - public void drop() { - super.drop(); - labels = null; + private String[] labels; + + public void init(int capacity) { + super.init(capacity); + labels = new String[capacity]; + } + + @Override + public IVertex getVertex(K key, int pos) { + return containsVertex(pos) ? new IDLabelVertex<>(key, labels[pos]) : null; + } + + @Override + public void set(int pos, IVertex vertex) { + super.set(pos, vertex); + if (vertex != null) { + labels[pos] = ((IGraphElementWithLabelField) vertex).getLabel().intern(); } + } + @Override + public void drop() { + super.drop(); + labels = null; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/IDTimeVertexArray.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/IDTimeVertexArray.java index 5d733bda9..a2d2a0539 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/IDTimeVertexArray.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/IDTimeVertexArray.java @@ -25,30 +25,29 @@ public class IDTimeVertexArray extends IDVertexArray { - private long[] times; - - public void init(int capacity) { - super.init(capacity); - times = new long[capacity]; - } - - @Override - public IVertex getVertex(K key, int pos) { - return containsVertex(pos) ? new IDTimeVertex<>(key, times[pos]) : null; - } - - @Override - public void set(int pos, IVertex vertex) { - super.set(pos, vertex); - if (vertex != null) { - times[pos] = ((IGraphElementWithTimeField) vertex).getTime(); - } - } - - @Override - public void drop() { - super.drop(); - times = null; + private long[] times; + + public void init(int capacity) { + super.init(capacity); + times = new long[capacity]; + } + + @Override + public IVertex getVertex(K key, int pos) { + return containsVertex(pos) ? new IDTimeVertex<>(key, times[pos]) : null; + } + + @Override + public void set(int pos, IVertex vertex) { + super.set(pos, vertex); + if (vertex != null) { + times[pos] = ((IGraphElementWithTimeField) vertex).getTime(); } + } + @Override + public void drop() { + super.drop(); + times = null; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/IDVertexArray.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/IDVertexArray.java index 3c7d4acf8..68d4ecde9 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/IDVertexArray.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/IDVertexArray.java @@ -20,6 +20,7 @@ package org.apache.geaflow.store.memory.csr.vertex.type; import java.util.BitSet; + import org.apache.geaflow.common.tuple.Tuple; import org.apache.geaflow.model.graph.vertex.IVertex; import org.apache.geaflow.model.graph.vertex.impl.IDVertex; @@ -27,53 +28,52 @@ public class IDVertexArray implements IVertexArray { - private int[] vId2EPos; - protected BitSet nullVertexBitSet; - - @Override - public void init(int capacity) { - vId2EPos = new int[capacity + 1]; - nullVertexBitSet = new BitSet(); - } + private int[] vId2EPos; + protected BitSet nullVertexBitSet; - @Override - public Tuple getEdgePosRange(int pos) { - if (pos < vId2EPos.length - 1) { - return Tuple.of(vId2EPos[pos], vId2EPos[pos + 1]); - } - return Tuple.of(0, 0); - } + @Override + public void init(int capacity) { + vId2EPos = new int[capacity + 1]; + nullVertexBitSet = new BitSet(); + } - @Override - public IVertex getVertex(K key, int pos) { - return containsVertex(pos) ? new IDVertex<>(key) : null; + @Override + public Tuple getEdgePosRange(int pos) { + if (pos < vId2EPos.length - 1) { + return Tuple.of(vId2EPos[pos], vId2EPos[pos + 1]); } + return Tuple.of(0, 0); + } - protected boolean containsVertex(int pos) { - return !nullVertexBitSet.get(pos); - } + @Override + public IVertex getVertex(K key, int pos) { + return containsVertex(pos) ? new IDVertex<>(key) : null; + } - @Override - public void drop() { - vId2EPos = null; - } + protected boolean containsVertex(int pos) { + return !nullVertexBitSet.get(pos); + } - @Override - public void set(int pos, IVertex vertex) { - if (vertex == null) { - nullVertexBitSet.set(pos); - } - vId2EPos[pos + 1] = vId2EPos[pos]; - } + @Override + public void drop() { + vId2EPos = null; + } - @Override - public int getNextPos(int pos) { - return vId2EPos[pos + 1]; + @Override + public void set(int pos, IVertex vertex) { + if (vertex == null) { + nullVertexBitSet.set(pos); } + vId2EPos[pos + 1] = vId2EPos[pos]; + } - @Override - public void updateVId2EPos(int pos) { - vId2EPos[pos + 1] = vId2EPos[pos + 1] + 1; - } + @Override + public int getNextPos(int pos) { + return vId2EPos[pos + 1]; + } + @Override + public void updateVId2EPos(int pos) { + vId2EPos[pos + 1] = vId2EPos[pos + 1] + 1; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/ValueLabelTimeVertexArray.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/ValueLabelTimeVertexArray.java index 38e06dc9e..e4fb24a72 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/ValueLabelTimeVertexArray.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/ValueLabelTimeVertexArray.java @@ -26,35 +26,35 @@ public class ValueLabelTimeVertexArray extends ValueVertexArray { - private String[] labels; - private long[] times; - - public void init(int capacity) { - super.init(capacity); - labels = new String[capacity]; - times = new long[capacity]; - } - - @Override - public IVertex getVertex(K key, int pos) { - return containsVertex(pos) - ? new ValueLabelTimeVertex<>(key, getValue(pos), labels[pos], times[pos]) : null; - } - - @Override - public void set(int pos, IVertex vertex) { - super.set(pos, vertex); - if (vertex != null) { - labels[pos] = ((IGraphElementWithLabelField) vertex).getLabel().intern(); - times[pos] = ((IGraphElementWithTimeField) vertex).getTime(); - } - } - - @Override - public void drop() { - super.drop(); - labels = null; - times = null; + private String[] labels; + private long[] times; + + public void init(int capacity) { + super.init(capacity); + labels = new String[capacity]; + times = new long[capacity]; + } + + @Override + public IVertex getVertex(K key, int pos) { + return containsVertex(pos) + ? new ValueLabelTimeVertex<>(key, getValue(pos), labels[pos], times[pos]) + : null; + } + + @Override + public void set(int pos, IVertex vertex) { + super.set(pos, vertex); + if (vertex != null) { + labels[pos] = ((IGraphElementWithLabelField) vertex).getLabel().intern(); + times[pos] = ((IGraphElementWithTimeField) vertex).getTime(); } - + } + + @Override + public void drop() { + super.drop(); + labels = null; + times = null; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/ValueLabelVertexArray.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/ValueLabelVertexArray.java index 13550cec7..6b11612c5 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/ValueLabelVertexArray.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/ValueLabelVertexArray.java @@ -25,30 +25,29 @@ public class ValueLabelVertexArray extends ValueVertexArray { - private String[] labels; - - public void init(int capacity) { - super.init(capacity); - labels = new String[capacity]; - } - - @Override - public IVertex getVertex(K key, int pos) { - return containsVertex(pos) ? new ValueLabelVertex<>(key, getValue(pos), labels[pos]) : null; - } - - @Override - public void set(int pos, IVertex vertex) { - super.set(pos, vertex); - if (vertex != null) { - labels[pos] = ((IGraphElementWithLabelField) vertex).getLabel().intern(); - } - } - - @Override - public void drop() { - super.drop(); - labels = null; + private String[] labels; + + public void init(int capacity) { + super.init(capacity); + labels = new String[capacity]; + } + + @Override + public IVertex getVertex(K key, int pos) { + return containsVertex(pos) ? new ValueLabelVertex<>(key, getValue(pos), labels[pos]) : null; + } + + @Override + public void set(int pos, IVertex vertex) { + super.set(pos, vertex); + if (vertex != null) { + labels[pos] = ((IGraphElementWithLabelField) vertex).getLabel().intern(); } + } + @Override + public void drop() { + super.drop(); + labels = null; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/ValueTimeVertexArray.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/ValueTimeVertexArray.java index 7ab753e03..3b864be02 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/ValueTimeVertexArray.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/ValueTimeVertexArray.java @@ -25,30 +25,29 @@ public class ValueTimeVertexArray extends ValueVertexArray { - private long[] times; - - public void init(int capacity) { - super.init(capacity); - times = new long[capacity]; - } - - @Override - public IVertex getVertex(K key, int pos) { - return containsVertex(pos) ? new ValueTimeVertex<>(key, getValue(pos), times[pos]) : null; - } - - @Override - public void set(int pos, IVertex vertex) { - super.set(pos, vertex); - if (vertex != null) { - times[pos] = ((IGraphElementWithTimeField) vertex).getTime(); - } - } - - @Override - public void drop() { - super.drop(); - times = null; + private long[] times; + + public void init(int capacity) { + super.init(capacity); + times = new long[capacity]; + } + + @Override + public IVertex getVertex(K key, int pos) { + return containsVertex(pos) ? new ValueTimeVertex<>(key, getValue(pos), times[pos]) : null; + } + + @Override + public void set(int pos, IVertex vertex) { + super.set(pos, vertex); + if (vertex != null) { + times[pos] = ((IGraphElementWithTimeField) vertex).getTime(); } + } + @Override + public void drop() { + super.drop(); + times = null; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/ValueVertexArray.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/ValueVertexArray.java index b1de7e237..ab8317845 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/ValueVertexArray.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/csr/vertex/type/ValueVertexArray.java @@ -24,33 +24,33 @@ public class ValueVertexArray extends IDVertexArray { - private Object[] values; - - public void init(int capacity) { - super.init(capacity); - values = new Object[capacity]; - } - - @Override - public IVertex getVertex(K key, int pos) { - return containsVertex(pos) ? new ValueVertex<>(key, getValue(pos)) : null; - } - - protected Object getValue(int pos) { - return values[pos]; - } - - @Override - public void drop() { - super.drop(); - values = null; - } - - @Override - public void set(int pos, IVertex vertex) { - super.set(pos, vertex); - if (vertex != null) { - values[pos] = vertex.getValue(); - } + private Object[] values; + + public void init(int capacity) { + super.init(capacity); + values = new Object[capacity]; + } + + @Override + public IVertex getVertex(K key, int pos) { + return containsVertex(pos) ? new ValueVertex<>(key, getValue(pos)) : null; + } + + protected Object getValue(int pos) { + return values[pos]; + } + + @Override + public void drop() { + super.drop(); + values = null; + } + + @Override + public void set(int pos, IVertex vertex) { + super.set(pos, vertex); + if (vertex != null) { + values[pos] = vertex.getValue(); } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/iterator/MemoryEdgeScanPushDownIterator.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/iterator/MemoryEdgeScanPushDownIterator.java index 68b02cef3..38777b80e 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/iterator/MemoryEdgeScanPushDownIterator.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/iterator/MemoryEdgeScanPushDownIterator.java @@ -19,11 +19,11 @@ package org.apache.geaflow.store.memory.iterator; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Comparator; import java.util.Iterator; import java.util.List; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.state.graph.encoder.EdgeAtom; @@ -32,54 +32,54 @@ import org.apache.geaflow.state.pushdown.filter.inner.IGraphFilter; import org.apache.geaflow.state.pushdown.limit.IEdgeLimit; -public class MemoryEdgeScanPushDownIterator implements CloseableIterator>> { +import com.google.common.collect.Lists; - private final Iterator>> iterator; - private final Comparator edgeComparator; - private final IEdgeLimit edgeLimit; - private final IGraphFilter filter; - private List> nextValue; +public class MemoryEdgeScanPushDownIterator + implements CloseableIterator>> { - public MemoryEdgeScanPushDownIterator( - Iterator>> iterator, - IStatePushDown pushdown) { - this.filter = (IGraphFilter) pushdown.getFilter(); - this.edgeLimit = pushdown.getEdgeLimit(); - List orderFields = pushdown.getOrderFields(); - this.edgeComparator = EdgeAtom.getComparator(orderFields); - this.iterator = iterator; - } + private final Iterator>> iterator; + private final Comparator edgeComparator; + private final IEdgeLimit edgeLimit; + private final IGraphFilter filter; + private List> nextValue; - @Override - public boolean hasNext() { - if (iterator.hasNext()) { - List> list = Lists.newArrayList(iterator.next()); - if (this.edgeComparator != null) { - list.sort(this.edgeComparator); - } - List> res = new ArrayList<>(list.size()); - Iterator> it = list.iterator(); - IGraphFilter filter = GraphFilter.of(this.filter, this.edgeLimit); - while (it.hasNext() && !filter.dropAllRemaining()) { - IEdge edge = it.next(); - if (filter.filterEdge(edge)) { - res.add(edge); - } - } + public MemoryEdgeScanPushDownIterator( + Iterator>> iterator, IStatePushDown pushdown) { + this.filter = (IGraphFilter) pushdown.getFilter(); + this.edgeLimit = pushdown.getEdgeLimit(); + List orderFields = pushdown.getOrderFields(); + this.edgeComparator = EdgeAtom.getComparator(orderFields); + this.iterator = iterator; + } - nextValue = res; - return true; + @Override + public boolean hasNext() { + if (iterator.hasNext()) { + List> list = Lists.newArrayList(iterator.next()); + if (this.edgeComparator != null) { + list.sort(this.edgeComparator); + } + List> res = new ArrayList<>(list.size()); + Iterator> it = list.iterator(); + IGraphFilter filter = GraphFilter.of(this.filter, this.edgeLimit); + while (it.hasNext() && !filter.dropAllRemaining()) { + IEdge edge = it.next(); + if (filter.filterEdge(edge)) { + res.add(edge); } - return false; - } + } - @Override - public List> next() { - return nextValue; + nextValue = res; + return true; } + return false; + } - @Override - public void close() { + @Override + public List> next() { + return nextValue; + } - } + @Override + public void close() {} } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/iterator/MemoryVertexScanIterator.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/iterator/MemoryVertexScanIterator.java index 554af3206..128476e84 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/iterator/MemoryVertexScanIterator.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/main/java/org/apache/geaflow/store/memory/iterator/MemoryVertexScanIterator.java @@ -20,40 +20,39 @@ package org.apache.geaflow.store.memory.iterator; import java.util.Iterator; + import org.apache.geaflow.model.graph.vertex.IVertex; import org.apache.geaflow.state.iterator.IVertexIterator; import org.apache.geaflow.state.pushdown.filter.inner.IGraphFilter; public class MemoryVertexScanIterator implements IVertexIterator { - private final Iterator> iterator; - private final IGraphFilter filter; - private IVertex nextValue; - - public MemoryVertexScanIterator(Iterator> iterator, IGraphFilter filter) { - this.iterator = iterator; - this.filter = filter; + private final Iterator> iterator; + private final IGraphFilter filter; + private IVertex nextValue; + + public MemoryVertexScanIterator(Iterator> iterator, IGraphFilter filter) { + this.iterator = iterator; + this.filter = filter; + } + + @Override + public boolean hasNext() { + while (iterator.hasNext()) { + nextValue = iterator.next(); + if (!filter.filterVertex(nextValue)) { + continue; + } + return true; } + return false; + } - @Override - public boolean hasNext() { - while (iterator.hasNext()) { - nextValue = iterator.next(); - if (!filter.filterVertex(nextValue)) { - continue; - } - return true; - } - return false; - } + @Override + public IVertex next() { + return nextValue; + } - @Override - public IVertex next() { - return nextValue; - } - - @Override - public void close() { - - } + @Override + public void close() {} } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/GraphJMHRunner.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/GraphJMHRunner.java index eeae1eb82..5b01896e1 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/GraphJMHRunner.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/GraphJMHRunner.java @@ -26,16 +26,16 @@ public class GraphJMHRunner { - public static void main(String[] args) throws RunnerException { + public static void main(String[] args) throws RunnerException { - Options opt = new OptionsBuilder() + Options opt = + new OptionsBuilder() // import test class. .include(StringMapGraphJMH.class.getSimpleName()) .include(IntMapGraphJMH.class.getSimpleName()) .include(StringCSRMapGraphJMH.class.getSimpleName()) .include(IntCSRMapGraphJMH.class.getSimpleName()) .build(); - new Runner(opt).run(); - } - + new Runner(opt).run(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/GraphMemoryStoreTest.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/GraphMemoryStoreTest.java index d33d7e9fc..347db3f69 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/GraphMemoryStoreTest.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/GraphMemoryStoreTest.java @@ -19,9 +19,6 @@ package org.apache.geaflow.store.memory; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Iterators; -import com.google.common.collect.Lists; import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; @@ -29,6 +26,7 @@ import java.util.Map; import java.util.Random; import java.util.function.Supplier; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.type.primitive.IntegerType; import org.apache.geaflow.common.type.primitive.StringType; @@ -69,212 +67,270 @@ import org.testng.annotations.Factory; import org.testng.annotations.Test; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; + public class GraphMemoryStoreTest { - private final Configuration config; - private final MemoryStoreBuilder builder; + private final Configuration config; + private final MemoryStoreBuilder builder; - public GraphMemoryStoreTest(Configuration config) { - this.config = config; - this.builder = new MemoryStoreBuilder(); - } + public GraphMemoryStoreTest(Configuration config) { + this.config = config; + this.builder = new MemoryStoreBuilder(); + } - public static class GraphMemoryStoreTestFactory { + public static class GraphMemoryStoreTestFactory { - @Factory - public Object[] factoryMethod() { - return new Object[]{ - new GraphMemoryStoreTest(new Configuration()), - new GraphMemoryStoreTest(new Configuration(ImmutableMap.of(MemoryConfigKeys.CSR_MEMORY_ENABLE.getKey(), "true"))), - }; - } + @Factory + public Object[] factoryMethod() { + return new Object[] { + new GraphMemoryStoreTest(new Configuration()), + new GraphMemoryStoreTest( + new Configuration( + ImmutableMap.of(MemoryConfigKeys.CSR_MEMORY_ENABLE.getKey(), "true"))), + }; } - - @Test - public void test() { - IStaticGraphStore store = - (IStaticGraphStore) builder.getStore(DataModel.STATIC_GRAPH, config); - StoreContext storeContext = new StoreContext("test") + } + + @Test + public void test() { + IStaticGraphStore store = + (IStaticGraphStore) + builder.getStore(DataModel.STATIC_GRAPH, config); + StoreContext storeContext = + new StoreContext("test") .withConfig(new Configuration()) - .withDataSchema(new GraphDataSchema(new GraphMeta( - new GraphMetaType<>(StringType.INSTANCE, ValueVertex.class, byte[].class, ValueEdge.class, byte[].class)))); - store.init(storeContext); - - for (int i = 0; i < 1000000; i++) { - String value = System.currentTimeMillis() + Math.random() + ""; - IVertex vertex = new ValueVertex<>("hello" + i, value.getBytes()); - store.addVertex(vertex); - IEdge edge = new ValueEdge<>("hello" + i, "hello" + (i + 1), - ("hello" + (i + 1)).getBytes()); - store.addEdge(edge); - } - store.flush(); - - Assert.assertEquals(Iterators.size(store.getVertexIterator(StatePushDown.of())), 1000000); - List> list = store.getEdges("hello" + 0, StatePushDown.of()); - Assert.assertEquals(list.size(), 1); - Assert.assertEquals(new String(list.get(0).getValue()), "hello1"); - } - - @Test - public void testGetOneDegreeGraphList() { - IStaticGraphStore store = - (IStaticGraphStore) builder.getStore(DataModel.STATIC_GRAPH, config); - StoreContext storeContext = - new StoreContext("test") - .withConfig(new Configuration()) - .withDataSchema(new GraphDataSchema(new GraphMeta( - new GraphMetaType<>(IntegerType.INSTANCE, IDVertex.class, EmptyProperty.class, - IDEdge.class, EmptyProperty.class)))); - store.init(storeContext); - - IVertex vertex = new IDVertex<>(1); - store.addVertex(vertex); - - IEdge edge = new IDEdge<>(2, 1); - edge.setDirect(EdgeDirection.IN); - store.addEdge(edge); - - store.flush(); - - List> list = - Lists.newArrayList(store.getOneDegreeGraphIterator(StatePushDown.of())); - Assert.assertEquals(list.size(), 2); - - list = - Lists.newArrayList(store.getOneDegreeGraphIterator(StatePushDown.of().withFilter( - GraphFilter.of(VertexMustContainFilter.getInstance())))); - Assert.assertEquals(list.size(), 1); - - list = - Lists.newArrayList(store.getOneDegreeGraphIterator(Arrays.asList(1, -1), StatePushDown.of())); - Assert.assertEquals(list.size(), 2); - } - - @Test - public void testGetVertexList() { - IStaticGraphStore store = - (IStaticGraphStore) builder.getStore(DataModel.STATIC_GRAPH, config); - StoreContext storeContext = - new StoreContext("test") - .withConfig(new Configuration()) - .withDataSchema(new GraphDataSchema(new GraphMeta( - new GraphMetaType<>(IntegerType.INSTANCE, IDVertex.class, EmptyProperty.class, - IDEdge.class, EmptyProperty.class)))); - store.init(storeContext); - - IVertex vertex = new IDVertex<>(1); - store.addVertex(vertex); - store.flush(); - - List> vertexList = - Lists.newArrayList(store.getVertexIterator(StatePushDown.of())); - Assert.assertEquals(vertexList.size(), 1); - Assert.assertEquals(vertexList.get(0).getId(), vertex.getId()); - vertexList.clear(); - - vertexList = Lists.newArrayList(store.getVertexIterator(StatePushDown.of())); - Assert.assertEquals(vertexList.size(), 1); - - vertexList = - Lists.newArrayList(store.getVertexIterator(Arrays.asList(1, -1), StatePushDown.of())); - Assert.assertEquals(vertexList.size(), 1); - - Map keyFilters = new HashMap<>(2); - keyFilters.put(0, GraphFilter.of((IVertexFilter) value -> value.getId() != 2)); - keyFilters.put(1, GraphFilter.of((IVertexFilter) value -> value.getId() != 2)); - vertexList = - Lists.newArrayList(store.getVertexIterator(Arrays.asList(1, -1), StatePushDown.of().withFilters(keyFilters))); - Assert.assertEquals(vertexList.size(), 1); - - Iterator> it = store.getEdgeIterator(StatePushDown.of()); - Assert.assertFalse(it.hasNext()); + .withDataSchema( + new GraphDataSchema( + new GraphMeta( + new GraphMetaType<>( + StringType.INSTANCE, + ValueVertex.class, + byte[].class, + ValueEdge.class, + byte[].class)))); + store.init(storeContext); + + for (int i = 0; i < 1000000; i++) { + String value = System.currentTimeMillis() + Math.random() + ""; + IVertex vertex = new ValueVertex<>("hello" + i, value.getBytes()); + store.addVertex(vertex); + IEdge edge = + new ValueEdge<>("hello" + i, "hello" + (i + 1), ("hello" + (i + 1)).getBytes()); + store.addEdge(edge); } - - @Test - public void testGetEdgeList() { - IStaticGraphStore store = - (IStaticGraphStore) builder.getStore(DataModel.STATIC_GRAPH, config); - StoreContext storeContext = - new StoreContext("test") - .withConfig(new Configuration()) - .withDataSchema(new GraphDataSchema(new GraphMeta( - new GraphMetaType<>(IntegerType.INSTANCE, ValueVertex.class, Integer.class, - ValueEdge.class, Integer.class)))); - store.init(storeContext); - - store.addEdge(new ValueEdge<>(1, 1, 1)); - store.flush(); - - List> edgeList = store.getEdges(1, StatePushDown.of()); - Assert.assertEquals(edgeList.size(), 1); - - edgeList = store.getEdges(0, StatePushDown.of()); - Assert.assertEquals(edgeList.size(), 0); - - Iterator> it = store.getEdgeIterator(Arrays.asList(1, -1), StatePushDown.of()); - edgeList = Lists.newArrayList(it); - Assert.assertEquals(edgeList.size(), 1); - - Iterator> vIt = store.getVertexIterator(StatePushDown.of()); - Assert.assertFalse(vIt.hasNext()); - } - - private IVertex getVertex(GraphMeta graphMeta) { - GraphElementFlag flag = GraphElementFlag.build(graphMeta.getVertexMeta().getGraphElementClass()); - boolean noProperty = graphMeta.getVertexMeta().getPropertyClass() == EmptyProperty.class; - Random random = new Random(); - IVertex vertex; - String label = Integer.toString(random.nextInt(10)); - long time = random.nextInt(); - int property = random.nextInt(); - int srcid = random.nextInt(); - - if (flag.isLabeledAndTimed()) { - vertex = noProperty ? new IDLabelTimeVertex<>(srcid, label, time) : - new ValueLabelTimeVertex<>(srcid, property, label, time); - } else if (flag.isLabeled()) { - vertex = noProperty ? new IDLabelVertex<>(srcid, label) : - new ValueLabelVertex<>(srcid, property, label); - } else if (flag.isTimed()) { - vertex = noProperty ? new IDTimeVertex<>(srcid, time) : new ValueTimeVertex<>(srcid, property, time); - } else { - vertex = noProperty ? new IDVertex<>(srcid) : new ValueVertex<>(srcid, property); - } - return vertex; + store.flush(); + + Assert.assertEquals(Iterators.size(store.getVertexIterator(StatePushDown.of())), 1000000); + List> list = store.getEdges("hello" + 0, StatePushDown.of()); + Assert.assertEquals(list.size(), 1); + Assert.assertEquals(new String(list.get(0).getValue()), "hello1"); + } + + @Test + public void testGetOneDegreeGraphList() { + IStaticGraphStore store = + (IStaticGraphStore) + builder.getStore(DataModel.STATIC_GRAPH, config); + StoreContext storeContext = + new StoreContext("test") + .withConfig(new Configuration()) + .withDataSchema( + new GraphDataSchema( + new GraphMeta( + new GraphMetaType<>( + IntegerType.INSTANCE, + IDVertex.class, + EmptyProperty.class, + IDEdge.class, + EmptyProperty.class)))); + store.init(storeContext); + + IVertex vertex = new IDVertex<>(1); + store.addVertex(vertex); + + IEdge edge = new IDEdge<>(2, 1); + edge.setDirect(EdgeDirection.IN); + store.addEdge(edge); + + store.flush(); + + List> list = + Lists.newArrayList(store.getOneDegreeGraphIterator(StatePushDown.of())); + Assert.assertEquals(list.size(), 2); + + list = + Lists.newArrayList( + store.getOneDegreeGraphIterator( + StatePushDown.of() + .withFilter(GraphFilter.of(VertexMustContainFilter.getInstance())))); + Assert.assertEquals(list.size(), 1); + + list = + Lists.newArrayList( + store.getOneDegreeGraphIterator(Arrays.asList(1, -1), StatePushDown.of())); + Assert.assertEquals(list.size(), 2); + } + + @Test + public void testGetVertexList() { + IStaticGraphStore store = + (IStaticGraphStore) + builder.getStore(DataModel.STATIC_GRAPH, config); + StoreContext storeContext = + new StoreContext("test") + .withConfig(new Configuration()) + .withDataSchema( + new GraphDataSchema( + new GraphMeta( + new GraphMetaType<>( + IntegerType.INSTANCE, + IDVertex.class, + EmptyProperty.class, + IDEdge.class, + EmptyProperty.class)))); + store.init(storeContext); + + IVertex vertex = new IDVertex<>(1); + store.addVertex(vertex); + store.flush(); + + List> vertexList = + Lists.newArrayList(store.getVertexIterator(StatePushDown.of())); + Assert.assertEquals(vertexList.size(), 1); + Assert.assertEquals(vertexList.get(0).getId(), vertex.getId()); + vertexList.clear(); + + vertexList = Lists.newArrayList(store.getVertexIterator(StatePushDown.of())); + Assert.assertEquals(vertexList.size(), 1); + + vertexList = + Lists.newArrayList(store.getVertexIterator(Arrays.asList(1, -1), StatePushDown.of())); + Assert.assertEquals(vertexList.size(), 1); + + Map keyFilters = new HashMap<>(2); + keyFilters.put( + 0, GraphFilter.of((IVertexFilter) value -> value.getId() != 2)); + keyFilters.put( + 1, GraphFilter.of((IVertexFilter) value -> value.getId() != 2)); + vertexList = + Lists.newArrayList( + store.getVertexIterator( + Arrays.asList(1, -1), StatePushDown.of().withFilters(keyFilters))); + Assert.assertEquals(vertexList.size(), 1); + + Iterator> it = store.getEdgeIterator(StatePushDown.of()); + Assert.assertFalse(it.hasNext()); + } + + @Test + public void testGetEdgeList() { + IStaticGraphStore store = + (IStaticGraphStore) + builder.getStore(DataModel.STATIC_GRAPH, config); + StoreContext storeContext = + new StoreContext("test") + .withConfig(new Configuration()) + .withDataSchema( + new GraphDataSchema( + new GraphMeta( + new GraphMetaType<>( + IntegerType.INSTANCE, + ValueVertex.class, + Integer.class, + ValueEdge.class, + Integer.class)))); + store.init(storeContext); + + store.addEdge(new ValueEdge<>(1, 1, 1)); + store.flush(); + + List> edgeList = store.getEdges(1, StatePushDown.of()); + Assert.assertEquals(edgeList.size(), 1); + + edgeList = store.getEdges(0, StatePushDown.of()); + Assert.assertEquals(edgeList.size(), 0); + + Iterator> it = + store.getEdgeIterator(Arrays.asList(1, -1), StatePushDown.of()); + edgeList = Lists.newArrayList(it); + Assert.assertEquals(edgeList.size(), 1); + + Iterator> vIt = store.getVertexIterator(StatePushDown.of()); + Assert.assertFalse(vIt.hasNext()); + } + + private IVertex getVertex(GraphMeta graphMeta) { + GraphElementFlag flag = + GraphElementFlag.build(graphMeta.getVertexMeta().getGraphElementClass()); + boolean noProperty = graphMeta.getVertexMeta().getPropertyClass() == EmptyProperty.class; + Random random = new Random(); + IVertex vertex; + String label = Integer.toString(random.nextInt(10)); + long time = random.nextInt(); + int property = random.nextInt(); + int srcid = random.nextInt(); + + if (flag.isLabeledAndTimed()) { + vertex = + noProperty + ? new IDLabelTimeVertex<>(srcid, label, time) + : new ValueLabelTimeVertex<>(srcid, property, label, time); + } else if (flag.isLabeled()) { + vertex = + noProperty + ? new IDLabelVertex<>(srcid, label) + : new ValueLabelVertex<>(srcid, property, label); + } else if (flag.isTimed()) { + vertex = + noProperty + ? new IDTimeVertex<>(srcid, time) + : new ValueTimeVertex<>(srcid, property, time); + } else { + vertex = noProperty ? new IDVertex<>(srcid) : new ValueVertex<>(srcid, property); } - - private IEdge getEdge(GraphMeta graphMeta) { - GraphElementFlag flag = GraphElementFlag.build(graphMeta.getEdgeMeta().getGraphElementClass()); - boolean noProperty = graphMeta.getEdgeMeta().getPropertyClass() == EmptyProperty.class; - Random random = new Random(); - IEdge edge; - String label = Integer.toString(random.nextInt(10)); - long time = random.nextInt(); - int property = random.nextInt(); - int id = random.nextInt(); - - if (flag.isLabeledAndTimed()) { - edge = noProperty ? new IDLabelTimeEdge<>(id, id, label, time) : - new ValueLabelTimeEdge<>(id, id, property, label, time); - } else if (flag.isLabeled()) { - edge = noProperty ? new IDLabelEdge<>(id, id, label) : new ValueLabelEdge<>(id, id, property, label); - } else if (flag.isTimed()) { - edge = noProperty ? new IDTimeEdge<>(id, id, time) : new ValueTimeEdge<>(id, id, property, time); - } else { - edge = noProperty ? new IDEdge<>(id, id) : new ValueEdge<>(id, property); - } - return edge; + return vertex; + } + + private IEdge getEdge(GraphMeta graphMeta) { + GraphElementFlag flag = GraphElementFlag.build(graphMeta.getEdgeMeta().getGraphElementClass()); + boolean noProperty = graphMeta.getEdgeMeta().getPropertyClass() == EmptyProperty.class; + Random random = new Random(); + IEdge edge; + String label = Integer.toString(random.nextInt(10)); + long time = random.nextInt(); + int property = random.nextInt(); + int id = random.nextInt(); + + if (flag.isLabeledAndTimed()) { + edge = + noProperty + ? new IDLabelTimeEdge<>(id, id, label, time) + : new ValueLabelTimeEdge<>(id, id, property, label, time); + } else if (flag.isLabeled()) { + edge = + noProperty + ? new IDLabelEdge<>(id, id, label) + : new ValueLabelEdge<>(id, id, property, label); + } else if (flag.isTimed()) { + edge = + noProperty ? new IDTimeEdge<>(id, id, time) : new ValueTimeEdge<>(id, id, property, time); + } else { + edge = noProperty ? new IDEdge<>(id, id) : new ValueEdge<>(id, property); } + return edge; + } - @Test - public void testDifferentType() { - IStaticGraphStore store = - (IStaticGraphStore) builder.getStore(DataModel.STATIC_GRAPH, - config); - + @Test + public void testDifferentType() { + IStaticGraphStore store = + (IStaticGraphStore) + builder.getStore(DataModel.STATIC_GRAPH, config); - List vertexClass = Arrays.asList( + List vertexClass = + Arrays.asList( ValueLabelTimeVertex.class, IDLabelTimeVertex.class, ValueLabelVertex.class, @@ -283,7 +339,8 @@ public void testDifferentType() { IDTimeVertex.class, ValueVertex.class, IDVertex.class); - List> vertexConstructs = Arrays.asList( + List> vertexConstructs = + Arrays.asList( ValueLabelTimeVertex::new, IDLabelTimeVertex::new, ValueLabelVertex::new, @@ -292,7 +349,8 @@ public void testDifferentType() { IDTimeVertex::new, ValueVertex::new, IDVertex::new); - List edgeClass = Arrays.asList( + List edgeClass = + Arrays.asList( ValueLabelTimeEdge.class, IDLabelTimeEdge.class, ValueLabelEdge.class, @@ -301,7 +359,8 @@ public void testDifferentType() { IDTimeEdge.class, ValueEdge.class, IDEdge.class); - List> edgeConstructs = Arrays.asList( + List> edgeConstructs = + Arrays.asList( ValueLabelTimeEdge::new, IDLabelTimeEdge::new, ValueLabelEdge::new, @@ -310,32 +369,41 @@ public void testDifferentType() { IDTimeEdge::new, ValueEdge::new, IDEdge::new); - for (int i = 0; i < vertexClass.size(); i++) { - Class propertyClazz = i % 2 == 0 ? Integer.class : EmptyProperty.class; - GraphMeta graphMeta = new GraphMeta( - new GraphMetaType(IntegerType.INSTANCE, vertexClass.get(i), vertexConstructs.get(i), - propertyClazz, edgeClass.get(i), edgeConstructs.get(i), propertyClazz)); - StoreContext storeContext = - new StoreContext("test").withConfig(new Configuration()).withDataSchema(new GraphDataSchema(graphMeta)); - store.init(storeContext); - - IVertex vertex = getVertex(graphMeta); - IEdge edge = getEdge(graphMeta); - store.addVertex(vertex); - store.addEdge(edge); - - store.flush(); - Iterator> it = store.getOneDegreeGraphIterator( - StatePushDown.of()); - while (it.hasNext()) { - OneDegreeGraph next = it.next(); - if (next.getVertex() != null) { - Assert.assertEquals(next.getVertex(), vertex); - } - if (next.getEdgeIterator().hasNext()) { - Assert.assertEquals(next.getEdgeIterator().next(), edge); - } - } + for (int i = 0; i < vertexClass.size(); i++) { + Class propertyClazz = i % 2 == 0 ? Integer.class : EmptyProperty.class; + GraphMeta graphMeta = + new GraphMeta( + new GraphMetaType( + IntegerType.INSTANCE, + vertexClass.get(i), + vertexConstructs.get(i), + propertyClazz, + edgeClass.get(i), + edgeConstructs.get(i), + propertyClazz)); + StoreContext storeContext = + new StoreContext("test") + .withConfig(new Configuration()) + .withDataSchema(new GraphDataSchema(graphMeta)); + store.init(storeContext); + + IVertex vertex = getVertex(graphMeta); + IEdge edge = getEdge(graphMeta); + store.addVertex(vertex); + store.addEdge(edge); + + store.flush(); + Iterator> it = + store.getOneDegreeGraphIterator(StatePushDown.of()); + while (it.hasNext()) { + OneDegreeGraph next = it.next(); + if (next.getVertex() != null) { + Assert.assertEquals(next.getVertex(), vertex); + } + if (next.getEdgeIterator().hasNext()) { + Assert.assertEquals(next.getEdgeIterator().next(), edge); } + } } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/IntCSRMapGraphJMH.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/IntCSRMapGraphJMH.java index 7ba84cbec..04cc6beec 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/IntCSRMapGraphJMH.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/IntCSRMapGraphJMH.java @@ -24,6 +24,7 @@ import java.util.Iterator; import java.util.Properties; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.type.primitive.IntegerType; import org.apache.geaflow.common.utils.SleepUtils; import org.apache.geaflow.model.graph.edge.EdgeDirection; @@ -64,98 +65,105 @@ @State(Scope.Benchmark) public class IntCSRMapGraphJMH { - private static final Logger LOGGER = LoggerFactory.getLogger(IntCSRMapGraphJMH.class); - - IStatePushDown pushdown = StatePushDown.of(); - IStaticGraphStore store; - StoreContext storeContext = new StoreContext("test").withDataSchema( - new GraphDataSchema(new GraphMeta( - new GraphMetaType(IntegerType.INSTANCE, IDVertex.class, IDVertex::new, EmptyProperty.class, - IDEdge.class, IDEdge::new, EmptyProperty.class)))); - - @Setup - public void setUp() { - Properties prop = new Properties(); - prop.setProperty("log4j.rootLogger", "INFO, stdout"); - prop.setProperty("log4j.appender.stdout", "org.apache.log4j.ConsoleAppender"); - prop.setProperty("log4j.appender.stdout.Target", "System.out"); - prop.setProperty("log4j.appender.stdout.layout", "org.apache.log4j.PatternLayout"); - prop.setProperty("log4j.appender.stdout.layout.ConversionPattern", - "%d{yyyy-MM-dd HH:mm:ss} [%t] %-5p %c{1}:%L - %m%n"); - PropertyConfigurator.configure(prop); - - store = new StaticGraphMemoryCSRStore(); - composeGraph(); - } + private static final Logger LOGGER = LoggerFactory.getLogger(IntCSRMapGraphJMH.class); - @Benchmark - public void composeGraph() { - store.init(storeContext); - for (int i = 0; i < 1000000; i++) { - IDVertex vertex = new IDVertex<>(i); - store.addVertex(vertex); - IDEdge edge = new IDEdge<>(i, i + 1); - edge.setDirect(EdgeDirection.IN); - store.addEdge(edge); - } - store.flush(); - } + IStatePushDown pushdown = StatePushDown.of(); + IStaticGraphStore store; + StoreContext storeContext = + new StoreContext("test") + .withDataSchema( + new GraphDataSchema( + new GraphMeta( + new GraphMetaType( + IntegerType.INSTANCE, + IDVertex.class, + IDVertex::new, + EmptyProperty.class, + IDEdge.class, + IDEdge::new, + EmptyProperty.class)))); - @Benchmark - public void getVertex() { - for (int i = 0; i < 100000; i++) { - store.getVertex(i * 10, pushdown); - } + @Setup + public void setUp() { + Properties prop = new Properties(); + prop.setProperty("log4j.rootLogger", "INFO, stdout"); + prop.setProperty("log4j.appender.stdout", "org.apache.log4j.ConsoleAppender"); + prop.setProperty("log4j.appender.stdout.Target", "System.out"); + prop.setProperty("log4j.appender.stdout.layout", "org.apache.log4j.PatternLayout"); + prop.setProperty( + "log4j.appender.stdout.layout.ConversionPattern", + "%d{yyyy-MM-dd HH:mm:ss} [%t] %-5p %c{1}:%L - %m%n"); + PropertyConfigurator.configure(prop); + + store = new StaticGraphMemoryCSRStore(); + composeGraph(); + } + + @Benchmark + public void composeGraph() { + store.init(storeContext); + for (int i = 0; i < 1000000; i++) { + IDVertex vertex = new IDVertex<>(i); + store.addVertex(vertex); + IDEdge edge = new IDEdge<>(i, i + 1); + edge.setDirect(EdgeDirection.IN); + store.addEdge(edge); } + store.flush(); + } - @Benchmark - public void getEdges() { - for (int i = 0; i < 100000; i++) { - store.getEdges(i * 10, pushdown); - } + @Benchmark + public void getVertex() { + for (int i = 0; i < 100000; i++) { + store.getVertex(i * 10, pushdown); } + } - @Benchmark - public void getOneGraph() { - for (int i = 0; i < 100000; i++) { - store.getOneDegreeGraph(i * 10, pushdown); - } + @Benchmark + public void getEdges() { + for (int i = 0; i < 100000; i++) { + store.getEdges(i * 10, pushdown); } + } - @Benchmark - public void getVertexIterator() { - Iterator> it = - store.getVertexIterator(pushdown); - while (it.hasNext()) { - it.next(); - } + @Benchmark + public void getOneGraph() { + for (int i = 0; i < 100000; i++) { + store.getOneDegreeGraph(i * 10, pushdown); } + } - @Benchmark - public void getEdgeIterator() { - Iterator> it = - store.getEdgeIterator(pushdown); - while (it.hasNext()) { - it.next(); - } + @Benchmark + public void getVertexIterator() { + Iterator> it = store.getVertexIterator(pushdown); + while (it.hasNext()) { + it.next(); } + } - @Benchmark - public void getOneGraphIterator() { - Iterator> it = - store.getOneDegreeGraphIterator( - pushdown); - while (it.hasNext()) { - it.next(); - } + @Benchmark + public void getEdgeIterator() { + Iterator> it = store.getEdgeIterator(pushdown); + while (it.hasNext()) { + it.next(); } + } - @Benchmark - public void memoryUsage() { - storeContext = null; - System.gc(); - SleepUtils.sleepSecond(1); - MemoryMXBean mm = ManagementFactory.getMemoryMXBean(); - LOGGER.info("map(MB): {}", mm.getHeapMemoryUsage().getUsed() / 1024 / 1024); + @Benchmark + public void getOneGraphIterator() { + Iterator> it = + store.getOneDegreeGraphIterator(pushdown); + while (it.hasNext()) { + it.next(); } + } + + @Benchmark + public void memoryUsage() { + storeContext = null; + System.gc(); + SleepUtils.sleepSecond(1); + MemoryMXBean mm = ManagementFactory.getMemoryMXBean(); + LOGGER.info("map(MB): {}", mm.getHeapMemoryUsage().getUsed() / 1024 / 1024); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/IntMapGraphJMH.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/IntMapGraphJMH.java index 0e81b96e3..040188d07 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/IntMapGraphJMH.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/IntMapGraphJMH.java @@ -24,6 +24,7 @@ import java.util.Iterator; import java.util.Properties; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.type.primitive.IntegerType; import org.apache.geaflow.common.utils.SleepUtils; import org.apache.geaflow.model.graph.edge.EdgeDirection; @@ -64,98 +65,105 @@ @State(Scope.Benchmark) public class IntMapGraphJMH { - private static final Logger LOGGER = LoggerFactory.getLogger(IntMapGraphJMH.class); - - IStatePushDown pushdown = StatePushDown.of(); - IStaticGraphStore store; - StoreContext storeContext = new StoreContext("test").withDataSchema( - new GraphDataSchema(new GraphMeta( - new GraphMetaType(IntegerType.INSTANCE, IDVertex.class, IDVertex::new, EmptyProperty.class, - IDEdge.class, IDEdge::new, EmptyProperty.class)))); - - @Setup - public void setUp() { - Properties prop = new Properties(); - prop.setProperty("log4j.rootLogger", "INFO, stdout"); - prop.setProperty("log4j.appender.stdout", "org.apache.log4j.ConsoleAppender"); - prop.setProperty("log4j.appender.stdout.Target", "System.out"); - prop.setProperty("log4j.appender.stdout.layout", "org.apache.log4j.PatternLayout"); - prop.setProperty("log4j.appender.stdout.layout.ConversionPattern", - "%d{yyyy-MM-dd HH:mm:ss} [%t] %-5p %c{1}:%L - %m%n"); - PropertyConfigurator.configure(prop); - - store = new StaticGraphMemoryStore(); - composeGraph(); - } + private static final Logger LOGGER = LoggerFactory.getLogger(IntMapGraphJMH.class); - @Benchmark - public void composeGraph() { - store.init(storeContext); - for (int i = 0; i < 1000000; i++) { - IDVertex vertex = new IDVertex<>(i); - store.addVertex(vertex); - IDEdge edge = new IDEdge<>(i, i + 1); - edge.setDirect(EdgeDirection.IN); - store.addEdge(edge); - } - store.flush(); - } + IStatePushDown pushdown = StatePushDown.of(); + IStaticGraphStore store; + StoreContext storeContext = + new StoreContext("test") + .withDataSchema( + new GraphDataSchema( + new GraphMeta( + new GraphMetaType( + IntegerType.INSTANCE, + IDVertex.class, + IDVertex::new, + EmptyProperty.class, + IDEdge.class, + IDEdge::new, + EmptyProperty.class)))); - @Benchmark - public void getVertex() { - for (int i = 0; i < 100000; i++) { - store.getVertex(i * 10, pushdown); - } + @Setup + public void setUp() { + Properties prop = new Properties(); + prop.setProperty("log4j.rootLogger", "INFO, stdout"); + prop.setProperty("log4j.appender.stdout", "org.apache.log4j.ConsoleAppender"); + prop.setProperty("log4j.appender.stdout.Target", "System.out"); + prop.setProperty("log4j.appender.stdout.layout", "org.apache.log4j.PatternLayout"); + prop.setProperty( + "log4j.appender.stdout.layout.ConversionPattern", + "%d{yyyy-MM-dd HH:mm:ss} [%t] %-5p %c{1}:%L - %m%n"); + PropertyConfigurator.configure(prop); + + store = new StaticGraphMemoryStore(); + composeGraph(); + } + + @Benchmark + public void composeGraph() { + store.init(storeContext); + for (int i = 0; i < 1000000; i++) { + IDVertex vertex = new IDVertex<>(i); + store.addVertex(vertex); + IDEdge edge = new IDEdge<>(i, i + 1); + edge.setDirect(EdgeDirection.IN); + store.addEdge(edge); } + store.flush(); + } - @Benchmark - public void getEdges() { - for (int i = 0; i < 100000; i++) { - store.getEdges(i * 10, pushdown); - } + @Benchmark + public void getVertex() { + for (int i = 0; i < 100000; i++) { + store.getVertex(i * 10, pushdown); } + } - @Benchmark - public void getOneGraph() { - for (int i = 0; i < 100000; i++) { - store.getOneDegreeGraph(i * 10, pushdown); - } + @Benchmark + public void getEdges() { + for (int i = 0; i < 100000; i++) { + store.getEdges(i * 10, pushdown); } + } - @Benchmark - public void getVertexIterator() { - Iterator> it = - store.getVertexIterator(pushdown); - while (it.hasNext()) { - it.next(); - } + @Benchmark + public void getOneGraph() { + for (int i = 0; i < 100000; i++) { + store.getOneDegreeGraph(i * 10, pushdown); } + } - @Benchmark - public void getEdgeIterator() { - Iterator> it = - store.getEdgeIterator(pushdown); - while (it.hasNext()) { - it.next(); - } + @Benchmark + public void getVertexIterator() { + Iterator> it = store.getVertexIterator(pushdown); + while (it.hasNext()) { + it.next(); } + } - @Benchmark - public void getOneGraphIterator() { - Iterator> it = - store.getOneDegreeGraphIterator( - pushdown); - while (it.hasNext()) { - it.next(); - } + @Benchmark + public void getEdgeIterator() { + Iterator> it = store.getEdgeIterator(pushdown); + while (it.hasNext()) { + it.next(); } + } - @Benchmark - public void memoryUsage() { - storeContext = null; - System.gc(); - SleepUtils.sleepSecond(1); - MemoryMXBean mm = ManagementFactory.getMemoryMXBean(); - LOGGER.info("map(MB): {}", mm.getHeapMemoryUsage().getUsed() / 1024 / 1024); + @Benchmark + public void getOneGraphIterator() { + Iterator> it = + store.getOneDegreeGraphIterator(pushdown); + while (it.hasNext()) { + it.next(); } + } + + @Benchmark + public void memoryUsage() { + storeContext = null; + System.gc(); + SleepUtils.sleepSecond(1); + MemoryMXBean mm = ManagementFactory.getMemoryMXBean(); + LOGGER.info("map(MB): {}", mm.getHeapMemoryUsage().getUsed() / 1024 / 1024); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/KeyMemoryStoreTest.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/KeyMemoryStoreTest.java index 7fdc5e16e..4434a1031 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/KeyMemoryStoreTest.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/KeyMemoryStoreTest.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.state.DataModel; import org.apache.geaflow.state.StoreType; @@ -36,72 +37,71 @@ public class KeyMemoryStoreTest { - @Test - public void testKV() { - IStoreBuilder builder = StoreBuilderFactory.build(StoreType.MEMORY.name()); - IKVStore kvStore = (IKVStore) builder.getStore(DataModel.KV, - new Configuration()); - - Configuration configuration = new Configuration(); - StoreContext storeContext = new StoreContext("mem").withConfig(configuration); - kvStore.init(storeContext); - kvStore.put("hello", "world"); - kvStore.put("foo", "bar"); - - Assert.assertEquals(kvStore.get("hello"), "world"); - Assert.assertEquals(kvStore.get("foo"), "bar"); - - kvStore.remove("foo"); - Assert.assertNull(kvStore.get("foo")); - } - - @Test - public void testKMap() { - IStoreBuilder builder = StoreBuilderFactory.build(StoreType.MEMORY.name()); - IKMapStore kMapStore = - (IKMapStore) builder.getStore( - DataModel.KMap, new Configuration()); - - Configuration configuration = new Configuration(); - StoreContext storeContext = new StoreContext("mem").withConfig(configuration); - kMapStore.init(storeContext); - - Map map = new HashMap<>(); - map.put("hello", "world"); - map.put("hello1", "world1"); - - kMapStore.add("hw", map); - - map.clear(); - map.put("foo", "bar"); - kMapStore.add("hw", map); - kMapStore.add("hw", "bar", "foo"); - - Assert.assertEquals(kMapStore.get("hw").size(), 4); - Assert.assertEquals(kMapStore.get("hw", "foo", "bar"), Arrays.asList("bar", "foo")); - - kMapStore.remove("hw", "bar"); - Assert.assertEquals(kMapStore.get("hw").size(), 3); - - kMapStore.remove("hw"); - Assert.assertEquals(kMapStore.get("hw").size(), 0); - } - - @Test - public void testKList() { - IStoreBuilder builder = StoreBuilderFactory.build(StoreType.MEMORY.name()); - IKListStore kListStore = (IKListStore) builder.getStore( - DataModel.KList, new Configuration()); - - Configuration configuration = new Configuration(); - StoreContext storeContext = new StoreContext("mem").withConfig(configuration); - kListStore.init(storeContext); - - kListStore.add("hw", "foo", "bar"); - kListStore.add("hw", "hello"); - - Assert.assertEquals(kListStore.get("hw").size(), 3); - kListStore.remove("hw"); - Assert.assertEquals(kListStore.get("hw").size(), 0); - } + @Test + public void testKV() { + IStoreBuilder builder = StoreBuilderFactory.build(StoreType.MEMORY.name()); + IKVStore kvStore = + (IKVStore) builder.getStore(DataModel.KV, new Configuration()); + + Configuration configuration = new Configuration(); + StoreContext storeContext = new StoreContext("mem").withConfig(configuration); + kvStore.init(storeContext); + kvStore.put("hello", "world"); + kvStore.put("foo", "bar"); + + Assert.assertEquals(kvStore.get("hello"), "world"); + Assert.assertEquals(kvStore.get("foo"), "bar"); + + kvStore.remove("foo"); + Assert.assertNull(kvStore.get("foo")); + } + + @Test + public void testKMap() { + IStoreBuilder builder = StoreBuilderFactory.build(StoreType.MEMORY.name()); + IKMapStore kMapStore = + (IKMapStore) builder.getStore(DataModel.KMap, new Configuration()); + + Configuration configuration = new Configuration(); + StoreContext storeContext = new StoreContext("mem").withConfig(configuration); + kMapStore.init(storeContext); + + Map map = new HashMap<>(); + map.put("hello", "world"); + map.put("hello1", "world1"); + + kMapStore.add("hw", map); + + map.clear(); + map.put("foo", "bar"); + kMapStore.add("hw", map); + kMapStore.add("hw", "bar", "foo"); + + Assert.assertEquals(kMapStore.get("hw").size(), 4); + Assert.assertEquals(kMapStore.get("hw", "foo", "bar"), Arrays.asList("bar", "foo")); + + kMapStore.remove("hw", "bar"); + Assert.assertEquals(kMapStore.get("hw").size(), 3); + + kMapStore.remove("hw"); + Assert.assertEquals(kMapStore.get("hw").size(), 0); + } + + @Test + public void testKList() { + IStoreBuilder builder = StoreBuilderFactory.build(StoreType.MEMORY.name()); + IKListStore kListStore = + (IKListStore) builder.getStore(DataModel.KList, new Configuration()); + + Configuration configuration = new Configuration(); + StoreContext storeContext = new StoreContext("mem").withConfig(configuration); + kListStore.init(storeContext); + + kListStore.add("hw", "foo", "bar"); + kListStore.add("hw", "hello"); + + Assert.assertEquals(kListStore.get("hw").size(), 3); + kListStore.remove("hw"); + Assert.assertEquals(kListStore.get("hw").size(), 0); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/StringCSRMapGraphJMH.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/StringCSRMapGraphJMH.java index 22714c23d..57a2b08d9 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/StringCSRMapGraphJMH.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/StringCSRMapGraphJMH.java @@ -24,6 +24,7 @@ import java.util.Iterator; import java.util.Properties; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.type.primitive.StringType; import org.apache.geaflow.common.utils.SleepUtils; import org.apache.geaflow.model.graph.edge.EdgeDirection; @@ -64,97 +65,104 @@ @State(Scope.Benchmark) public class StringCSRMapGraphJMH { - private static final Logger LOGGER = LoggerFactory.getLogger(StringCSRMapGraphJMH.class); - - IStatePushDown pushdown = StatePushDown.of(); - IStaticGraphStore store; - StoreContext storeContext = new StoreContext("test").withDataSchema( - new GraphDataSchema(new GraphMeta( - new GraphMetaType(StringType.INSTANCE, IDVertex.class, IDVertex::new, EmptyProperty.class, - IDEdge.class, IDEdge::new, EmptyProperty.class)))); - - @Setup - public void setUp() { - Properties prop = new Properties(); - prop.setProperty("log4j.rootLogger", "INFO, stdout"); - prop.setProperty("log4j.appender.stdout", "org.apache.log4j.ConsoleAppender"); - prop.setProperty("log4j.appender.stdout.Target", "System.out"); - prop.setProperty("log4j.appender.stdout.layout", "org.apache.log4j.PatternLayout"); - prop.setProperty("log4j.appender.stdout.layout.ConversionPattern", - "%d{yyyy-MM-dd HH:mm:ss} [%t] %-5p %c{1}:%L - %m%n"); - PropertyConfigurator.configure(prop); - - store = new StaticGraphMemoryCSRStore(); - composeGraph(); - } + private static final Logger LOGGER = LoggerFactory.getLogger(StringCSRMapGraphJMH.class); - @Benchmark - public void composeGraph() { - store.init(storeContext); - for (int i = 0; i < 1000000; i++) { - IDVertex vertex = new IDVertex<>(Integer.toString(i)); - store.addVertex(vertex); - IDEdge edge = new IDEdge<>(Integer.toString(i), Integer.toString(i + 1)); - edge.setDirect(EdgeDirection.IN); - store.addEdge(edge); - } - store.flush(); - } + IStatePushDown pushdown = StatePushDown.of(); + IStaticGraphStore store; + StoreContext storeContext = + new StoreContext("test") + .withDataSchema( + new GraphDataSchema( + new GraphMeta( + new GraphMetaType( + StringType.INSTANCE, + IDVertex.class, + IDVertex::new, + EmptyProperty.class, + IDEdge.class, + IDEdge::new, + EmptyProperty.class)))); - @Benchmark - public void getVertex() { - for (int i = 0; i < 100000; i++) { - store.getVertex(Integer.toString(i * 10), pushdown); - } + @Setup + public void setUp() { + Properties prop = new Properties(); + prop.setProperty("log4j.rootLogger", "INFO, stdout"); + prop.setProperty("log4j.appender.stdout", "org.apache.log4j.ConsoleAppender"); + prop.setProperty("log4j.appender.stdout.Target", "System.out"); + prop.setProperty("log4j.appender.stdout.layout", "org.apache.log4j.PatternLayout"); + prop.setProperty( + "log4j.appender.stdout.layout.ConversionPattern", + "%d{yyyy-MM-dd HH:mm:ss} [%t] %-5p %c{1}:%L - %m%n"); + PropertyConfigurator.configure(prop); + + store = new StaticGraphMemoryCSRStore(); + composeGraph(); + } + + @Benchmark + public void composeGraph() { + store.init(storeContext); + for (int i = 0; i < 1000000; i++) { + IDVertex vertex = new IDVertex<>(Integer.toString(i)); + store.addVertex(vertex); + IDEdge edge = new IDEdge<>(Integer.toString(i), Integer.toString(i + 1)); + edge.setDirect(EdgeDirection.IN); + store.addEdge(edge); } + store.flush(); + } - @Benchmark - public void getEdges() { - for (int i = 0; i < 100000; i++) { - store.getEdges(Integer.toString(i * 10), pushdown); - } + @Benchmark + public void getVertex() { + for (int i = 0; i < 100000; i++) { + store.getVertex(Integer.toString(i * 10), pushdown); } + } - @Benchmark - public void getOneGraph() { - for (int i = 0; i < 100000; i++) { - store.getOneDegreeGraph(Integer.toString(i * 10), pushdown); - } + @Benchmark + public void getEdges() { + for (int i = 0; i < 100000; i++) { + store.getEdges(Integer.toString(i * 10), pushdown); } + } - @Benchmark - public void getVertexIterator() { - Iterator> it = - store.getVertexIterator(pushdown); - while (it.hasNext()) { - it.next(); - } + @Benchmark + public void getOneGraph() { + for (int i = 0; i < 100000; i++) { + store.getOneDegreeGraph(Integer.toString(i * 10), pushdown); } + } - @Benchmark - public void getEdgeIterator() { - Iterator> it = - store.getEdgeIterator(pushdown); - while (it.hasNext()) { - it.next(); - } + @Benchmark + public void getVertexIterator() { + Iterator> it = store.getVertexIterator(pushdown); + while (it.hasNext()) { + it.next(); } + } - @Benchmark - public void getOneGraphIterator() { - Iterator> it = - store.getOneDegreeGraphIterator(pushdown); - while (it.hasNext()) { - it.next(); - } + @Benchmark + public void getEdgeIterator() { + Iterator> it = store.getEdgeIterator(pushdown); + while (it.hasNext()) { + it.next(); } + } - @Benchmark - public void memoryUsage() { - storeContext = null; - System.gc(); - SleepUtils.sleepSecond(1); - MemoryMXBean mm = ManagementFactory.getMemoryMXBean(); - LOGGER.info("map(MB): {}", mm.getHeapMemoryUsage().getUsed() / 1024 / 1024); + @Benchmark + public void getOneGraphIterator() { + Iterator> it = store.getOneDegreeGraphIterator(pushdown); + while (it.hasNext()) { + it.next(); } + } + + @Benchmark + public void memoryUsage() { + storeContext = null; + System.gc(); + SleepUtils.sleepSecond(1); + MemoryMXBean mm = ManagementFactory.getMemoryMXBean(); + LOGGER.info("map(MB): {}", mm.getHeapMemoryUsage().getUsed() / 1024 / 1024); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/StringMapGraphJMH.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/StringMapGraphJMH.java index 0487d0e5d..024ed640f 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/StringMapGraphJMH.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-memory/src/test/java/org/apache/geaflow/store/memory/StringMapGraphJMH.java @@ -24,6 +24,7 @@ import java.util.Iterator; import java.util.Properties; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.type.primitive.StringType; import org.apache.geaflow.common.utils.SleepUtils; import org.apache.geaflow.model.graph.edge.EdgeDirection; @@ -64,97 +65,104 @@ @State(Scope.Benchmark) public class StringMapGraphJMH { - private static final Logger LOGGER = LoggerFactory.getLogger(StringMapGraphJMH.class); - - IStatePushDown pushdown = StatePushDown.of(); - IStaticGraphStore store; - StoreContext storeContext = new StoreContext("test").withDataSchema( - new GraphDataSchema(new GraphMeta( - new GraphMetaType(StringType.INSTANCE, IDVertex.class, IDVertex::new, EmptyProperty.class, - IDEdge.class, IDEdge::new, EmptyProperty.class)))); - - @Setup - public void setUp() { - Properties prop = new Properties(); - prop.setProperty("log4j.rootLogger", "INFO, stdout"); - prop.setProperty("log4j.appender.stdout", "org.apache.log4j.ConsoleAppender"); - prop.setProperty("log4j.appender.stdout.Target", "System.out"); - prop.setProperty("log4j.appender.stdout.layout", "org.apache.log4j.PatternLayout"); - prop.setProperty("log4j.appender.stdout.layout.ConversionPattern", - "%d{yyyy-MM-dd HH:mm:ss} [%t] %-5p %c{1}:%L - %m%n"); - PropertyConfigurator.configure(prop); - - store = new StaticGraphMemoryStore(); - composeGraph(); - } + private static final Logger LOGGER = LoggerFactory.getLogger(StringMapGraphJMH.class); - @Benchmark - public void composeGraph() { - store.init(storeContext); - for (int i = 0; i < 1000000; i++) { - IDVertex vertex = new IDVertex<>(Integer.toString(i)); - store.addVertex(vertex); - IDEdge edge = new IDEdge<>(Integer.toString(i), Integer.toString(i + 1)); - edge.setDirect(EdgeDirection.IN); - store.addEdge(edge); - } - store.flush(); - } + IStatePushDown pushdown = StatePushDown.of(); + IStaticGraphStore store; + StoreContext storeContext = + new StoreContext("test") + .withDataSchema( + new GraphDataSchema( + new GraphMeta( + new GraphMetaType( + StringType.INSTANCE, + IDVertex.class, + IDVertex::new, + EmptyProperty.class, + IDEdge.class, + IDEdge::new, + EmptyProperty.class)))); - @Benchmark - public void getVertex() { - for (int i = 0; i < 100000; i++) { - store.getVertex(Integer.toString(i * 10), pushdown); - } + @Setup + public void setUp() { + Properties prop = new Properties(); + prop.setProperty("log4j.rootLogger", "INFO, stdout"); + prop.setProperty("log4j.appender.stdout", "org.apache.log4j.ConsoleAppender"); + prop.setProperty("log4j.appender.stdout.Target", "System.out"); + prop.setProperty("log4j.appender.stdout.layout", "org.apache.log4j.PatternLayout"); + prop.setProperty( + "log4j.appender.stdout.layout.ConversionPattern", + "%d{yyyy-MM-dd HH:mm:ss} [%t] %-5p %c{1}:%L - %m%n"); + PropertyConfigurator.configure(prop); + + store = new StaticGraphMemoryStore(); + composeGraph(); + } + + @Benchmark + public void composeGraph() { + store.init(storeContext); + for (int i = 0; i < 1000000; i++) { + IDVertex vertex = new IDVertex<>(Integer.toString(i)); + store.addVertex(vertex); + IDEdge edge = new IDEdge<>(Integer.toString(i), Integer.toString(i + 1)); + edge.setDirect(EdgeDirection.IN); + store.addEdge(edge); } + store.flush(); + } - @Benchmark - public void getEdges() { - for (int i = 0; i < 100000; i++) { - store.getEdges(Integer.toString(i * 10), pushdown); - } + @Benchmark + public void getVertex() { + for (int i = 0; i < 100000; i++) { + store.getVertex(Integer.toString(i * 10), pushdown); } + } - @Benchmark - public void getOneGraph() { - for (int i = 0; i < 100000; i++) { - store.getOneDegreeGraph(Integer.toString(i * 10), pushdown); - } + @Benchmark + public void getEdges() { + for (int i = 0; i < 100000; i++) { + store.getEdges(Integer.toString(i * 10), pushdown); } + } - @Benchmark - public void getVertexIterator() { - Iterator> it = store.getVertexIterator(pushdown); - while (it.hasNext()) { - it.next(); - } + @Benchmark + public void getOneGraph() { + for (int i = 0; i < 100000; i++) { + store.getOneDegreeGraph(Integer.toString(i * 10), pushdown); } + } - @Benchmark - public void getEdgeIterator() { - Iterator> it = - store.getEdgeIterator(pushdown); - while (it.hasNext()) { - it.next(); - } + @Benchmark + public void getVertexIterator() { + Iterator> it = store.getVertexIterator(pushdown); + while (it.hasNext()) { + it.next(); } + } - @Benchmark - public void getOneGraphIterator() { - Iterator> it = - store.getOneDegreeGraphIterator( - pushdown); - while (it.hasNext()) { - it.next(); - } + @Benchmark + public void getEdgeIterator() { + Iterator> it = store.getEdgeIterator(pushdown); + while (it.hasNext()) { + it.next(); } + } - @Benchmark - public void memoryUsage() { - storeContext = null; - System.gc(); - SleepUtils.sleepSecond(1); - MemoryMXBean mm = ManagementFactory.getMemoryMXBean(); - LOGGER.info("map(MB): {}", mm.getHeapMemoryUsage().getUsed() / 1024 / 1024); + @Benchmark + public void getOneGraphIterator() { + Iterator> it = store.getOneDegreeGraphIterator(pushdown); + while (it.hasNext()) { + it.next(); } + } + + @Benchmark + public void memoryUsage() { + storeContext = null; + System.gc(); + SleepUtils.sleepSecond(1); + MemoryMXBean mm = ManagementFactory.getMemoryMXBean(); + LOGGER.info("map(MB): {}", mm.getHeapMemoryUsage().getUsed() / 1024 / 1024); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/BasePaimonGraphStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/BasePaimonGraphStore.java index 49db1c585..cdfe36dd9 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/BasePaimonGraphStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/BasePaimonGraphStore.java @@ -36,58 +36,57 @@ public abstract class BasePaimonGraphStore extends BasePaimonStore implements IPushDownStore { - protected IFilterConverter filterConverter; - protected String vertexTable; - protected String edgeTable; - protected String indexTable; + protected IFilterConverter filterConverter; + protected String vertexTable; + protected String edgeTable; + protected String indexTable; - @Override - public void init(StoreContext storeContext) { - super.init(storeContext); - Configuration config = storeContext.getConfig(); - boolean codegenEnable = config.getBoolean(StoreConfigKeys.STORE_FILTER_CODEGEN_ENABLE); - filterConverter = - codegenEnable ? new CodeGenFilterConverter() : new DirectFilterConverter(); - // use distributed table when distributed mode or database is configured. - if (config.contains(PAIMON_STORE_DATABASE) || isDistributedMode) { - paimonStoreName = config.getString(PAIMON_STORE_DATABASE); - vertexTable = config.getString(PAIMON_STORE_VERTEX_TABLE); - edgeTable = config.getString(PAIMON_STORE_EDGE_TABLE); - indexTable = config.getString(PAIMON_STORE_INDEX_TABLE); - } + @Override + public void init(StoreContext storeContext) { + super.init(storeContext); + Configuration config = storeContext.getConfig(); + boolean codegenEnable = config.getBoolean(StoreConfigKeys.STORE_FILTER_CODEGEN_ENABLE); + filterConverter = codegenEnable ? new CodeGenFilterConverter() : new DirectFilterConverter(); + // use distributed table when distributed mode or database is configured. + if (config.contains(PAIMON_STORE_DATABASE) || isDistributedMode) { + paimonStoreName = config.getString(PAIMON_STORE_DATABASE); + vertexTable = config.getString(PAIMON_STORE_VERTEX_TABLE); + edgeTable = config.getString(PAIMON_STORE_EDGE_TABLE); + indexTable = config.getString(PAIMON_STORE_INDEX_TABLE); } + } - @Override - public IFilterConverter getFilterConverter() { - return filterConverter; - } + @Override + public IFilterConverter getFilterConverter() { + return filterConverter; + } - protected PaimonTableRWHandle createVertexTable(int shardId) { - String tableName = vertexTable; - if (StringUtils.isEmpty(tableName)) { - tableName = String.format("%s#%s", "vertex", shardId); - } - return createTable(tableName); + protected PaimonTableRWHandle createVertexTable(int shardId) { + String tableName = vertexTable; + if (StringUtils.isEmpty(tableName)) { + tableName = String.format("%s#%s", "vertex", shardId); } + return createTable(tableName); + } - protected PaimonTableRWHandle createEdgeTable(int shardId) { - String tableName = edgeTable; - if (StringUtils.isEmpty(edgeTable)) { - tableName = String.format("%s#%s", "edge", shardId); - } - return createTable(tableName); + protected PaimonTableRWHandle createEdgeTable(int shardId) { + String tableName = edgeTable; + if (StringUtils.isEmpty(edgeTable)) { + tableName = String.format("%s#%s", "edge", shardId); } + return createTable(tableName); + } - protected PaimonTableRWHandle createIndexTable(int shardId) { - String tableName = indexTable; - if (StringUtils.isEmpty(indexTable)) { - tableName = String.format("%s#%s", "vertex_index", shardId); - } - return createTable(tableName); + protected PaimonTableRWHandle createIndexTable(int shardId) { + String tableName = indexTable; + if (StringUtils.isEmpty(indexTable)) { + tableName = String.format("%s#%s", "vertex_index", shardId); } + return createTable(tableName); + } - private PaimonTableRWHandle createTable(String tableName) { - Identifier vertexIndexIdentifier = new Identifier(paimonStoreName, tableName); - return createKVTableHandle(vertexIndexIdentifier); - } + private PaimonTableRWHandle createTable(String tableName) { + Identifier vertexIndexIdentifier = new Identifier(paimonStoreName, tableName); + return createKVTableHandle(vertexIndexIdentifier); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/BasePaimonStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/BasePaimonStore.java index 767f1b08d..f62f0b0df 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/BasePaimonStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/BasePaimonStore.java @@ -35,96 +35,96 @@ public abstract class BasePaimonStore extends BaseGraphStore implements IStatefulStore { - protected static final String KEY_COLUMN_NAME = "key"; - protected static final String VALUE_COLUMN_NAME = "value"; - protected static final int KEY_COLUMN_INDEX = 0; - protected static final int VALUE_COLUMN_INDEX = 1; + protected static final String KEY_COLUMN_NAME = "key"; + protected static final String VALUE_COLUMN_NAME = "value"; + protected static final int KEY_COLUMN_INDEX = 0; + protected static final int VALUE_COLUMN_INDEX = 1; - // 新增的常量定义 - protected static final String TARGET_ID_COLUMN_NAME = "target_id"; - protected static final String SRC_ID_COLUMN_NAME = "src_id"; - protected static final String TS_COLUMN_NAME = "ts"; - protected static final String DIRECTION_COLUMN_NAME = "direction"; - protected static final String LABEL_COLUMN_NAME = "label"; + // 新增的常量定义 + protected static final String TARGET_ID_COLUMN_NAME = "target_id"; + protected static final String SRC_ID_COLUMN_NAME = "src_id"; + protected static final String TS_COLUMN_NAME = "ts"; + protected static final String DIRECTION_COLUMN_NAME = "direction"; + protected static final String LABEL_COLUMN_NAME = "label"; - protected PaimonCatalogClient client; - protected int shardId; - protected String jobName; - protected String paimonStoreName; - protected long lastCheckpointId; - protected boolean isDistributedMode; - protected boolean enableAutoCreate; + protected PaimonCatalogClient client; + protected int shardId; + protected String jobName; + protected String paimonStoreName; + protected long lastCheckpointId; + protected boolean isDistributedMode; + protected boolean enableAutoCreate; - @Override - public void init(StoreContext storeContext) { - this.shardId = storeContext.getShardId(); - this.jobName = storeContext.getConfig().getString(ExecutionConfigKeys.JOB_APP_NAME); - this.paimonStoreName = this.jobName + "#" + this.shardId; - this.client = PaimonCatalogManager.getCatalogClient(storeContext.getConfig()); - this.lastCheckpointId = Long.MAX_VALUE; - this.isDistributedMode = storeContext.getConfig() - .getBoolean(PAIMON_STORE_DISTRIBUTED_MODE_ENABLE); - this.enableAutoCreate = storeContext.getConfig() - .getBoolean(PAIMON_STORE_TABLE_AUTO_CREATE_ENABLE); - } - - @Override - public void close() { - this.client.close(); - } + @Override + public void init(StoreContext storeContext) { + this.shardId = storeContext.getShardId(); + this.jobName = storeContext.getConfig().getString(ExecutionConfigKeys.JOB_APP_NAME); + this.paimonStoreName = this.jobName + "#" + this.shardId; + this.client = PaimonCatalogManager.getCatalogClient(storeContext.getConfig()); + this.lastCheckpointId = Long.MAX_VALUE; + this.isDistributedMode = + storeContext.getConfig().getBoolean(PAIMON_STORE_DISTRIBUTED_MODE_ENABLE); + this.enableAutoCreate = + storeContext.getConfig().getBoolean(PAIMON_STORE_TABLE_AUTO_CREATE_ENABLE); + } - @Override - public void drop() { - this.client.dropDatabase(paimonStoreName); - } - - protected PaimonTableRWHandle createKVTableHandle(Identifier identifier) { - Table vertexTable; - try { - vertexTable = this.client.getTable(identifier); - } catch (TableNotExistException e) { - if (enableAutoCreate) { - Schema.Builder schemaBuilder = Schema.newBuilder(); - schemaBuilder.primaryKey(KEY_COLUMN_NAME); - schemaBuilder.column(KEY_COLUMN_NAME, DataTypes.BYTES()); - schemaBuilder.column(VALUE_COLUMN_NAME, DataTypes.BYTES()); - Schema schema = schemaBuilder.build(); - vertexTable = this.client.createTable(schema, identifier); - } else { - throw new GeaflowRuntimeException("Table " + identifier + " not exist."); - } - } + @Override + public void close() { + this.client.close(); + } - return new PaimonTableRWHandle(identifier, vertexTable, shardId, isDistributedMode); - } + @Override + public void drop() { + this.client.dropDatabase(paimonStoreName); + } - protected PaimonTableRWHandle createEdgeTableHandle(Identifier identifier) { + protected PaimonTableRWHandle createKVTableHandle(Identifier identifier) { + Table vertexTable; + try { + vertexTable = this.client.getTable(identifier); + } catch (TableNotExistException e) { + if (enableAutoCreate) { Schema.Builder schemaBuilder = Schema.newBuilder(); - schemaBuilder.primaryKey(SRC_ID_COLUMN_NAME); - schemaBuilder.primaryKey(TARGET_ID_COLUMN_NAME); - schemaBuilder.primaryKey(TS_COLUMN_NAME); - schemaBuilder.primaryKey(DIRECTION_COLUMN_NAME); - schemaBuilder.primaryKey(LABEL_COLUMN_NAME); - schemaBuilder.column(SRC_ID_COLUMN_NAME, DataTypes.BYTES()); - schemaBuilder.column(TARGET_ID_COLUMN_NAME, DataTypes.BYTES()); - schemaBuilder.column(TS_COLUMN_NAME, DataTypes.BIGINT()); - schemaBuilder.column(DIRECTION_COLUMN_NAME, DataTypes.SMALLINT()); - schemaBuilder.column(LABEL_COLUMN_NAME, DataTypes.BYTES()); + schemaBuilder.primaryKey(KEY_COLUMN_NAME); + schemaBuilder.column(KEY_COLUMN_NAME, DataTypes.BYTES()); schemaBuilder.column(VALUE_COLUMN_NAME, DataTypes.BYTES()); Schema schema = schemaBuilder.build(); - Table vertexTable = this.client.createTable(schema, identifier); - return new PaimonTableRWHandle(identifier, vertexTable, shardId); + vertexTable = this.client.createTable(schema, identifier); + } else { + throw new GeaflowRuntimeException("Table " + identifier + " not exist."); + } } - protected PaimonTableRWHandle createVertexTableHandle(Identifier identifier) { - Schema.Builder schemaBuilder = Schema.newBuilder(); - schemaBuilder.primaryKey(SRC_ID_COLUMN_NAME); - schemaBuilder.column(SRC_ID_COLUMN_NAME, DataTypes.BYTES()); - schemaBuilder.column(TS_COLUMN_NAME, DataTypes.BIGINT()); - schemaBuilder.column(LABEL_COLUMN_NAME, DataTypes.SMALLINT()); - schemaBuilder.column(VALUE_COLUMN_NAME, DataTypes.BYTES()); - Schema schema = schemaBuilder.build(); - Table vertexTable = this.client.createTable(schema, identifier); - return new PaimonTableRWHandle(identifier, vertexTable, shardId); - } + return new PaimonTableRWHandle(identifier, vertexTable, shardId, isDistributedMode); + } + + protected PaimonTableRWHandle createEdgeTableHandle(Identifier identifier) { + Schema.Builder schemaBuilder = Schema.newBuilder(); + schemaBuilder.primaryKey(SRC_ID_COLUMN_NAME); + schemaBuilder.primaryKey(TARGET_ID_COLUMN_NAME); + schemaBuilder.primaryKey(TS_COLUMN_NAME); + schemaBuilder.primaryKey(DIRECTION_COLUMN_NAME); + schemaBuilder.primaryKey(LABEL_COLUMN_NAME); + schemaBuilder.column(SRC_ID_COLUMN_NAME, DataTypes.BYTES()); + schemaBuilder.column(TARGET_ID_COLUMN_NAME, DataTypes.BYTES()); + schemaBuilder.column(TS_COLUMN_NAME, DataTypes.BIGINT()); + schemaBuilder.column(DIRECTION_COLUMN_NAME, DataTypes.SMALLINT()); + schemaBuilder.column(LABEL_COLUMN_NAME, DataTypes.BYTES()); + schemaBuilder.column(VALUE_COLUMN_NAME, DataTypes.BYTES()); + Schema schema = schemaBuilder.build(); + Table vertexTable = this.client.createTable(schema, identifier); + return new PaimonTableRWHandle(identifier, vertexTable, shardId); + } + + protected PaimonTableRWHandle createVertexTableHandle(Identifier identifier) { + Schema.Builder schemaBuilder = Schema.newBuilder(); + schemaBuilder.primaryKey(SRC_ID_COLUMN_NAME); + schemaBuilder.column(SRC_ID_COLUMN_NAME, DataTypes.BYTES()); + schemaBuilder.column(TS_COLUMN_NAME, DataTypes.BIGINT()); + schemaBuilder.column(LABEL_COLUMN_NAME, DataTypes.SMALLINT()); + schemaBuilder.column(VALUE_COLUMN_NAME, DataTypes.BYTES()); + Schema schema = schemaBuilder.build(); + Table vertexTable = this.client.createTable(schema, identifier); + return new PaimonTableRWHandle(identifier, vertexTable, shardId); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/DynamicGraphPaimonStoreBase.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/DynamicGraphPaimonStoreBase.java index 9c4d202d7..7fab9c132 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/DynamicGraphPaimonStoreBase.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/DynamicGraphPaimonStoreBase.java @@ -22,6 +22,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; @@ -35,150 +36,152 @@ import org.apache.geaflow.store.paimon.proxy.IGraphMultiVersionedPaimonProxy; import org.apache.geaflow.store.paimon.proxy.PaimonProxyBuilder; -public class DynamicGraphPaimonStoreBase extends BasePaimonGraphStore implements - IDynamicGraphStore { - - private IGraphMultiVersionedPaimonProxy proxy; - - @Override - public void init(StoreContext storeContext) { - super.init(storeContext); - int[] projection = new int[]{KEY_COLUMN_INDEX, VALUE_COLUMN_INDEX}; - - // TODO: Use graph schema to create table instead of KV table. - PaimonTableRWHandle vertexHandle = createVertexTable(shardId); - PaimonTableRWHandle vertexIndexHandle = createIndexTable(shardId); - PaimonTableRWHandle edgeHandle = createEdgeTable(shardId); - - IGraphKVEncoder encoder = GraphKVEncoderFactory.build(storeContext.getConfig(), - storeContext.getGraphSchema()); - this.proxy = PaimonProxyBuilder.buildMultiVersioned(storeContext.getConfig(), vertexHandle, - vertexIndexHandle, edgeHandle, projection, encoder); - } - - @Override - public void archive(long checkpointId) { - this.proxy.archive(checkpointId); - } - - @Override - public void recovery(long checkpointId) { - // TODO: Not implemented yet. - this.proxy.recover(checkpointId); - } - - @Override - public long recoveryLatest() { - return this.proxy.recoverLatest(); - } - - @Override - public void compact() { - - } - - @Override - public void addEdge(long version, IEdge edge) { - this.proxy.addEdge(version, edge); - } - - @Override - public void addVertex(long version, IVertex vertex) { - this.proxy.addVertex(version, vertex); - } - - @Override - public IVertex getVertex(long sliceId, K sid, IStatePushDown pushdown) { - return this.proxy.getVertex(sliceId, sid, pushdown); - } - - @Override - public List> getEdges(long sliceId, K sid, IStatePushDown pushdown) { - return this.proxy.getEdges(sliceId, sid, pushdown); - } - - @Override - public OneDegreeGraph getOneDegreeGraph(long sliceId, K sid, - IStatePushDown pushdown) { - return this.proxy.getOneDegreeGraph(sliceId, sid, pushdown); - } - - @Override - public CloseableIterator vertexIDIterator() { - return this.proxy.vertexIDIterator(); - } - - @Override - public CloseableIterator vertexIDIterator(long version, IStatePushDown pushdown) { - return this.proxy.vertexIDIterator(version, pushdown); - } - - @Override - public CloseableIterator> getVertexIterator(long version, - IStatePushDown pushdown) { - return proxy.getVertexIterator(version, pushdown); - } - - @Override - public CloseableIterator> getVertexIterator(long version, List keys, - IStatePushDown pushdown) { - return proxy.getVertexIterator(version, keys, pushdown); - } - - @Override - public CloseableIterator> getEdgeIterator(long version, IStatePushDown pushdown) { - return proxy.getEdgeIterator(version, pushdown); - } - - @Override - public CloseableIterator> getEdgeIterator(long version, List keys, - IStatePushDown pushdown) { - return proxy.getEdgeIterator(version, keys, pushdown); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(long version, - IStatePushDown pushdown) { - return proxy.getOneDegreeGraphIterator(version, pushdown); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(long version, - List keys, - IStatePushDown pushdown) { - return proxy.getOneDegreeGraphIterator(version, keys, pushdown); - } - - @Override - public List getAllVersions(K id, DataType dataType) { - return this.proxy.getAllVersions(id, dataType); - } - - @Override - public long getLatestVersion(K id, DataType dataType) { - return this.proxy.getLatestVersion(id, dataType); - } - - @Override - public Map> getAllVersionData(K id, IStatePushDown pushdown, - DataType dataType) { - return this.proxy.getAllVersionData(id, pushdown, dataType); - } - - @Override - public Map> getVersionData(K id, Collection slices, - IStatePushDown pushdown, DataType dataType) { - return this.proxy.getVersionData(id, slices, pushdown, dataType); - } - - @Override - public void flush() { - proxy.flush(); - } - - @Override - public void close() { - proxy.close(); - super.close(); - } +public class DynamicGraphPaimonStoreBase extends BasePaimonGraphStore + implements IDynamicGraphStore { + + private IGraphMultiVersionedPaimonProxy proxy; + + @Override + public void init(StoreContext storeContext) { + super.init(storeContext); + int[] projection = new int[] {KEY_COLUMN_INDEX, VALUE_COLUMN_INDEX}; + + // TODO: Use graph schema to create table instead of KV table. + PaimonTableRWHandle vertexHandle = createVertexTable(shardId); + PaimonTableRWHandle vertexIndexHandle = createIndexTable(shardId); + PaimonTableRWHandle edgeHandle = createEdgeTable(shardId); + + IGraphKVEncoder encoder = + GraphKVEncoderFactory.build(storeContext.getConfig(), storeContext.getGraphSchema()); + this.proxy = + PaimonProxyBuilder.buildMultiVersioned( + storeContext.getConfig(), + vertexHandle, + vertexIndexHandle, + edgeHandle, + projection, + encoder); + } + + @Override + public void archive(long checkpointId) { + this.proxy.archive(checkpointId); + } + + @Override + public void recovery(long checkpointId) { + // TODO: Not implemented yet. + this.proxy.recover(checkpointId); + } + + @Override + public long recoveryLatest() { + return this.proxy.recoverLatest(); + } + + @Override + public void compact() {} + + @Override + public void addEdge(long version, IEdge edge) { + this.proxy.addEdge(version, edge); + } + + @Override + public void addVertex(long version, IVertex vertex) { + this.proxy.addVertex(version, vertex); + } + + @Override + public IVertex getVertex(long sliceId, K sid, IStatePushDown pushdown) { + return this.proxy.getVertex(sliceId, sid, pushdown); + } + + @Override + public List> getEdges(long sliceId, K sid, IStatePushDown pushdown) { + return this.proxy.getEdges(sliceId, sid, pushdown); + } + + @Override + public OneDegreeGraph getOneDegreeGraph(long sliceId, K sid, IStatePushDown pushdown) { + return this.proxy.getOneDegreeGraph(sliceId, sid, pushdown); + } + + @Override + public CloseableIterator vertexIDIterator() { + return this.proxy.vertexIDIterator(); + } + + @Override + public CloseableIterator vertexIDIterator(long version, IStatePushDown pushdown) { + return this.proxy.vertexIDIterator(version, pushdown); + } + + @Override + public CloseableIterator> getVertexIterator( + long version, IStatePushDown pushdown) { + return proxy.getVertexIterator(version, pushdown); + } + + @Override + public CloseableIterator> getVertexIterator( + long version, List keys, IStatePushDown pushdown) { + return proxy.getVertexIterator(version, keys, pushdown); + } + + @Override + public CloseableIterator> getEdgeIterator(long version, IStatePushDown pushdown) { + return proxy.getEdgeIterator(version, pushdown); + } + + @Override + public CloseableIterator> getEdgeIterator( + long version, List keys, IStatePushDown pushdown) { + return proxy.getEdgeIterator(version, keys, pushdown); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + long version, IStatePushDown pushdown) { + return proxy.getOneDegreeGraphIterator(version, pushdown); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + long version, List keys, IStatePushDown pushdown) { + return proxy.getOneDegreeGraphIterator(version, keys, pushdown); + } + + @Override + public List getAllVersions(K id, DataType dataType) { + return this.proxy.getAllVersions(id, dataType); + } + + @Override + public long getLatestVersion(K id, DataType dataType) { + return this.proxy.getLatestVersion(id, dataType); + } + + @Override + public Map> getAllVersionData( + K id, IStatePushDown pushdown, DataType dataType) { + return this.proxy.getAllVersionData(id, pushdown, dataType); + } + + @Override + public Map> getVersionData( + K id, Collection slices, IStatePushDown pushdown, DataType dataType) { + return this.proxy.getVersionData(id, slices, pushdown, dataType); + } + + @Override + public void flush() { + proxy.flush(); + } + + @Override + public void close() { + proxy.close(); + super.close(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/KVPaimonStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/KVPaimonStore.java index e2c028415..79ca12227 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/KVPaimonStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/KVPaimonStore.java @@ -21,7 +21,6 @@ import static java.util.Collections.singletonList; -import com.google.common.base.Preconditions; import org.apache.geaflow.common.tuple.Tuple; import org.apache.geaflow.state.serializer.IKVSerializer; import org.apache.geaflow.store.api.key.IKVStatefulStore; @@ -37,95 +36,102 @@ import org.apache.paimon.types.RowKind; import org.apache.paimon.types.RowType; -public class KVPaimonStore extends BasePaimonStore implements IKVStatefulStore { - - private static final String TABLE_NAME_PREFIX = "KVTable"; - - private IKVSerializer kvSerializer; - - private PaimonTableRWHandle tableHandle; - - private int[] projection; - - @Override - public void init(StoreContext storeContext) { - super.init(storeContext); - this.kvSerializer = (IKVSerializer) Preconditions.checkNotNull( - storeContext.getKeySerializer(), "keySerializer must be set"); - String tableName = TABLE_NAME_PREFIX + "#" + shardId; - - this.projection = new int[]{KEY_COLUMN_INDEX, VALUE_COLUMN_INDEX}; - Identifier identifier = new Identifier(paimonStoreName, tableName); - this.tableHandle = createKVTableHandle(identifier); - } - - @Override - public void archive(long checkpointId) { - this.lastCheckpointId = checkpointId; - this.tableHandle.commit(lastCheckpointId); - } - - @Override - public void recovery(long checkpointId) { - throw new UnsupportedOperationException(); - } - - @Override - public long recoveryLatest() { - throw new UnsupportedOperationException(); - } - - @Override - public void compact() { - - } - - @Override - public V get(K key) { - byte[] binaryKey = this.kvSerializer.serializeKey(key); - RowType rowType = this.tableHandle.getTable().rowType(); - Predicate predicate = new LeafPredicate(Equal.INSTANCE, rowType.getTypeAt(0), 0, - rowType.getField(0).name(), singletonList(binaryKey)); - RecordReaderIterator iterator = this.tableHandle.getIterator(predicate, null, - projection); - try (PaimonIterator paimonIterator = new PaimonIterator(iterator)) { - if (paimonIterator.hasNext()) { - Tuple row = paimonIterator.next(); - return this.kvSerializer.deserializeValue(row.getF1()); - } - return null; - } - } - - @Override - public void put(K key, V value) { - byte[] keyArray = this.kvSerializer.serializeKey(key); - byte[] valueArray = this.kvSerializer.serializeValue(value); - GenericRow record = GenericRow.of(keyArray, valueArray); - this.tableHandle.write(record); - } - - @Override - public void remove(K key) { - byte[] keyArray = this.kvSerializer.serializeKey(key); - GenericRow record = GenericRow.ofKind(RowKind.DELETE, keyArray, null); - this.tableHandle.write(record); - } - - @Override - public void drop() { - this.client.dropTable(tableHandle.getIdentifier()); - super.drop(); - } +import com.google.common.base.Preconditions; - @Override - public void flush() { - this.tableHandle.flush(lastCheckpointId); - } +public class KVPaimonStore extends BasePaimonStore implements IKVStatefulStore { - @Override - public void close() { - this.tableHandle.close(); - this.client.close(); + private static final String TABLE_NAME_PREFIX = "KVTable"; + + private IKVSerializer kvSerializer; + + private PaimonTableRWHandle tableHandle; + + private int[] projection; + + @Override + public void init(StoreContext storeContext) { + super.init(storeContext); + this.kvSerializer = + (IKVSerializer) + Preconditions.checkNotNull( + storeContext.getKeySerializer(), "keySerializer must be set"); + String tableName = TABLE_NAME_PREFIX + "#" + shardId; + + this.projection = new int[] {KEY_COLUMN_INDEX, VALUE_COLUMN_INDEX}; + Identifier identifier = new Identifier(paimonStoreName, tableName); + this.tableHandle = createKVTableHandle(identifier); + } + + @Override + public void archive(long checkpointId) { + this.lastCheckpointId = checkpointId; + this.tableHandle.commit(lastCheckpointId); + } + + @Override + public void recovery(long checkpointId) { + throw new UnsupportedOperationException(); + } + + @Override + public long recoveryLatest() { + throw new UnsupportedOperationException(); + } + + @Override + public void compact() {} + + @Override + public V get(K key) { + byte[] binaryKey = this.kvSerializer.serializeKey(key); + RowType rowType = this.tableHandle.getTable().rowType(); + Predicate predicate = + new LeafPredicate( + Equal.INSTANCE, + rowType.getTypeAt(0), + 0, + rowType.getField(0).name(), + singletonList(binaryKey)); + RecordReaderIterator iterator = + this.tableHandle.getIterator(predicate, null, projection); + try (PaimonIterator paimonIterator = new PaimonIterator(iterator)) { + if (paimonIterator.hasNext()) { + Tuple row = paimonIterator.next(); + return this.kvSerializer.deserializeValue(row.getF1()); + } + return null; } + } + + @Override + public void put(K key, V value) { + byte[] keyArray = this.kvSerializer.serializeKey(key); + byte[] valueArray = this.kvSerializer.serializeValue(value); + GenericRow record = GenericRow.of(keyArray, valueArray); + this.tableHandle.write(record); + } + + @Override + public void remove(K key) { + byte[] keyArray = this.kvSerializer.serializeKey(key); + GenericRow record = GenericRow.ofKind(RowKind.DELETE, keyArray, null); + this.tableHandle.write(record); + } + + @Override + public void drop() { + this.client.dropTable(tableHandle.getIdentifier()); + super.drop(); + } + + @Override + public void flush() { + this.tableHandle.flush(lastCheckpointId); + } + + @Override + public void close() { + this.tableHandle.close(); + this.client.close(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/PaimonCatalogClient.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/PaimonCatalogClient.java index f4d163ed7..f2abf258d 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/PaimonCatalogClient.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/PaimonCatalogClient.java @@ -31,59 +31,59 @@ public class PaimonCatalogClient { - private static final Logger LOGGER = LoggerFactory.getLogger(PaimonCatalogClient.class); + private static final Logger LOGGER = LoggerFactory.getLogger(PaimonCatalogClient.class); - private final Configuration config; + private final Configuration config; - private Catalog catalog; + private Catalog catalog; - public PaimonCatalogClient(Catalog catalog, Configuration config) { - this.config = config; - this.catalog = catalog; - } + public PaimonCatalogClient(Catalog catalog, Configuration config) { + this.config = config; + this.catalog = catalog; + } - public Table createTable(Schema schema, Identifier identifier) { - try { - LOGGER.info("create table {}", identifier.getFullName()); - this.catalog.createDatabase(identifier.getDatabaseName(), true); - this.catalog.createTable(identifier, schema, true); - return getTable(identifier); - } catch (Exception e) { - throw new GeaflowRuntimeException("Create database or table failed.", e); - } + public Table createTable(Schema schema, Identifier identifier) { + try { + LOGGER.info("create table {}", identifier.getFullName()); + this.catalog.createDatabase(identifier.getDatabaseName(), true); + this.catalog.createTable(identifier, schema, true); + return getTable(identifier); + } catch (Exception e) { + throw new GeaflowRuntimeException("Create database or table failed.", e); } + } - public Catalog getCatalog() { - return this.catalog; - } + public Catalog getCatalog() { + return this.catalog; + } - public Table getTable(Identifier identifier) throws TableNotExistException { - return this.catalog.getTable(identifier); - } + public Table getTable(Identifier identifier) throws TableNotExistException { + return this.catalog.getTable(identifier); + } - public void close() { - if (catalog != null) { - try { - catalog.close(); - } catch (Exception e) { - throw new GeaflowRuntimeException("Failed to close catalog.", e); - } - } + public void close() { + if (catalog != null) { + try { + catalog.close(); + } catch (Exception e) { + throw new GeaflowRuntimeException("Failed to close catalog.", e); + } } + } - public void dropDatabase(String dbName) { - try { - catalog.dropDatabase(dbName, true, true); - } catch (Exception e) { - throw new GeaflowRuntimeException("Failed to drop database.", e); - } + public void dropDatabase(String dbName) { + try { + catalog.dropDatabase(dbName, true, true); + } catch (Exception e) { + throw new GeaflowRuntimeException("Failed to drop database.", e); } + } - public void dropTable(Identifier identifier) { - try { - catalog.dropTable(identifier, true); - } catch (Exception e) { - throw new GeaflowRuntimeException("Failed to drop table.", e); - } + public void dropTable(Identifier identifier) { + try { + catalog.dropTable(identifier, true); + } catch (Exception e) { + throw new GeaflowRuntimeException("Failed to drop table.", e); } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/PaimonCatalogManager.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/PaimonCatalogManager.java index b8294fc0a..9a3deaf46 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/PaimonCatalogManager.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/PaimonCatalogManager.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.store.paimon.config.PaimonStoreConfig; import org.apache.paimon.catalog.Catalog; @@ -36,36 +37,35 @@ public class PaimonCatalogManager { - private static final Logger LOGGER = LoggerFactory.getLogger(PaimonCatalogManager.class); - - private static final Map catalogMap = new ConcurrentHashMap<>(); + private static final Logger LOGGER = LoggerFactory.getLogger(PaimonCatalogManager.class); - public static synchronized PaimonCatalogClient getCatalogClient(Configuration config) { - String warehouse = config.getString(PAIMON_STORE_WAREHOUSE); - return catalogMap.computeIfAbsent(warehouse, k -> createCatalog(config)); - } + private static final Map catalogMap = new ConcurrentHashMap<>(); - private static PaimonCatalogClient createCatalog(Configuration config) { - String metastore = config.getString(PAIMON_STORE_META_STORE); - String warehouse = config.getString(PAIMON_STORE_WAREHOUSE); + public static synchronized PaimonCatalogClient getCatalogClient(Configuration config) { + String warehouse = config.getString(PAIMON_STORE_WAREHOUSE); + return catalogMap.computeIfAbsent(warehouse, k -> createCatalog(config)); + } - Options options = new Options(); - options.set(CatalogOptions.WAREHOUSE, warehouse.toLowerCase()); - options.set(CatalogOptions.METASTORE, metastore.toLowerCase()); - Map extraOptions = PaimonStoreConfig.getPaimonOptions(config); - if (extraOptions != null) { - for (Map.Entry entry : extraOptions.entrySet()) { - LOGGER.info("add option: {}={}", entry.getKey(), entry.getValue()); - options.set(entry.getKey(), entry.getValue()); - } - } - if (!metastore.equalsIgnoreCase("filesystem")) { - throw new UnsupportedOperationException("Not support meta store type " + metastore); - } + private static PaimonCatalogClient createCatalog(Configuration config) { + String metastore = config.getString(PAIMON_STORE_META_STORE); + String warehouse = config.getString(PAIMON_STORE_WAREHOUSE); - CatalogContext context = CatalogContext.create(options); - Catalog catalog = CatalogFactory.createCatalog(context); - return new PaimonCatalogClient(catalog, config); + Options options = new Options(); + options.set(CatalogOptions.WAREHOUSE, warehouse.toLowerCase()); + options.set(CatalogOptions.METASTORE, metastore.toLowerCase()); + Map extraOptions = PaimonStoreConfig.getPaimonOptions(config); + if (extraOptions != null) { + for (Map.Entry entry : extraOptions.entrySet()) { + LOGGER.info("add option: {}={}", entry.getKey(), entry.getValue()); + options.set(entry.getKey(), entry.getValue()); + } + } + if (!metastore.equalsIgnoreCase("filesystem")) { + throw new UnsupportedOperationException("Not support meta store type " + metastore); } + CatalogContext context = CatalogContext.create(options); + Catalog catalog = CatalogFactory.createCatalog(context); + return new PaimonCatalogClient(catalog, config); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/PaimonStoreBuilder.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/PaimonStoreBuilder.java index 1babc205d..0eb407c41 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/PaimonStoreBuilder.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/PaimonStoreBuilder.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -32,42 +33,42 @@ public class PaimonStoreBuilder implements IStoreBuilder { - private static final PaimonStoreDesc paimonStoreDesc = new PaimonStoreDesc(); + private static final PaimonStoreDesc paimonStoreDesc = new PaimonStoreDesc(); - @Override - public IBaseStore getStore(DataModel type, Configuration config) { - switch (type) { - case KV: - return new KVPaimonStore(); - case STATIC_GRAPH: - return new StaticGraphPaimonStoreBase(); - case DYNAMIC_GRAPH: - return new DynamicGraphPaimonStoreBase<>(); - default: - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError("not support " + type)); - } + @Override + public IBaseStore getStore(DataModel type, Configuration config) { + switch (type) { + case KV: + return new KVPaimonStore(); + case STATIC_GRAPH: + return new StaticGraphPaimonStoreBase(); + case DYNAMIC_GRAPH: + return new DynamicGraphPaimonStoreBase<>(); + default: + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError("not support " + type)); } + } - @Override - public StoreDesc getStoreDesc() { - return paimonStoreDesc; - } + @Override + public StoreDesc getStoreDesc() { + return paimonStoreDesc; + } - @Override - public List supportedDataModel() { - return Arrays.asList(DataModel.KV, DataModel.STATIC_GRAPH); - } + @Override + public List supportedDataModel() { + return Arrays.asList(DataModel.KV, DataModel.STATIC_GRAPH); + } - public static class PaimonStoreDesc implements StoreDesc { + public static class PaimonStoreDesc implements StoreDesc { - @Override - public boolean isLocalStore() { - return true; - } + @Override + public boolean isLocalStore() { + return true; + } - @Override - public String name() { - return StoreType.PAIMON.name(); - } + @Override + public String name() { + return StoreType.PAIMON.name(); } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/PaimonTableRWHandle.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/PaimonTableRWHandle.java index d88ecb1cc..6d7253178 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/PaimonTableRWHandle.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/PaimonTableRWHandle.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.OptionalLong; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.store.paimon.commit.PaimonCommitRegistry; import org.apache.paimon.catalog.Identifier; @@ -44,144 +45,152 @@ public class PaimonTableRWHandle { - private static final Logger LOGGER = LoggerFactory.getLogger(PaimonTableRWHandle.class); - private final int shardId; - private final Identifier identifier; - private final Table table; - private final boolean isDistributedMode; - private final PaimonCommitRegistry registry; - private StreamTableWrite streamTableWrite; - private List commitMessages = new ArrayList<>(); - - public PaimonTableRWHandle(Identifier identifier, Table table, int shardId) { - this(identifier, table, shardId, false); + private static final Logger LOGGER = LoggerFactory.getLogger(PaimonTableRWHandle.class); + private final int shardId; + private final Identifier identifier; + private final Table table; + private final boolean isDistributedMode; + private final PaimonCommitRegistry registry; + private StreamTableWrite streamTableWrite; + private List commitMessages = new ArrayList<>(); + + public PaimonTableRWHandle(Identifier identifier, Table table, int shardId) { + this(identifier, table, shardId, false); + } + + public PaimonTableRWHandle( + Identifier identifier, Table table, int shardId, boolean isDistributedMode) { + this.shardId = shardId; + this.identifier = identifier; + this.table = table; + this.streamTableWrite = table.newStreamWriteBuilder().newWrite(); + this.registry = PaimonCommitRegistry.initInstance(); + this.isDistributedMode = isDistributedMode; + } + + public void write(GenericRow row) { + write(row, shardId); + } + + public void write(GenericRow row, int bucket) { + try { + streamTableWrite.write(row, bucket); + } catch (Exception e) { + throw new GeaflowRuntimeException("Failed to put data into Paimon.", e); } - - public PaimonTableRWHandle(Identifier identifier, Table table, int shardId, - boolean isDistributedMode) { - this.shardId = shardId; - this.identifier = identifier; - this.table = table; - this.streamTableWrite = table.newStreamWriteBuilder().newWrite(); - this.registry = PaimonCommitRegistry.initInstance(); - this.isDistributedMode = isDistributedMode; + } + + public void commit(long checkpointId) { + commit(checkpointId, false); + } + + public void commit(long checkpointId, boolean waitCompaction) { + flush(checkpointId, waitCompaction); + List messages = new ArrayList<>(); + for (CommitMessage commitMessage : commitMessages) { + if (commitMessage instanceof CommitMessageImpl + && ((CommitMessageImpl) commitMessage).isEmpty()) { + continue; + } + messages.add(commitMessage); } - public void write(GenericRow row) { - write(row, shardId); + if (isDistributedMode) { + LOGGER.info( + "{} pre commit chkId:{} messages:{} wait:{}", + this.identifier, + checkpointId, + messages.size(), + waitCompaction); + registry.addMessages(shardId, table.name(), messages); + } else { + LOGGER.info( + "{} commit chkId:{} messages:{} wait:{}", + this.identifier, + checkpointId, + messages.size(), + waitCompaction); + try (StreamTableCommit commit = table.newStreamWriteBuilder().newCommit()) { + commit.commit(checkpointId, messages); + } catch (Exception e) { + throw new GeaflowRuntimeException("Failed to commit data into Paimon.", e); + } } + commitMessages.clear(); + } - public void write(GenericRow row, int bucket) { - try { - streamTableWrite.write(row, bucket); - } catch (Exception e) { - throw new GeaflowRuntimeException("Failed to put data into Paimon.", e); - } - } + public void rollbackTo(long snapshotId) { + table.rollbackTo(snapshotId); + } - public void commit(long checkpointId) { - commit(checkpointId, false); + public long rollbackToLatest() { + long latestSnapshotId = getLatestSnapshotId(); + if (latestSnapshotId < 0) { + throw new GeaflowRuntimeException("Not found any valid snapshot version"); } - - public void commit(long checkpointId, boolean waitCompaction) { - flush(checkpointId, waitCompaction); - List messages = new ArrayList<>(); - for (CommitMessage commitMessage : commitMessages) { - if (commitMessage instanceof CommitMessageImpl - && ((CommitMessageImpl) commitMessage).isEmpty()) { - continue; - } - messages.add(commitMessage); - } - - if (isDistributedMode) { - LOGGER.info("{} pre commit chkId:{} messages:{} wait:{}", - this.identifier, checkpointId, messages.size(), waitCompaction); - registry.addMessages(shardId, table.name(), messages); - } else { - LOGGER.info("{} commit chkId:{} messages:{} wait:{}", - this.identifier, checkpointId, messages.size(), waitCompaction); - try (StreamTableCommit commit = table.newStreamWriteBuilder().newCommit()) { - commit.commit(checkpointId, messages); - } catch (Exception e) { - throw new GeaflowRuntimeException("Failed to commit data into Paimon.", e); - } - } - commitMessages.clear(); + table.rollbackTo(latestSnapshotId); + return latestSnapshotId; + } + + public RecordReaderIterator getIterator( + Predicate predicate, Filter filter, int[] projection) { + try { + ReadBuilder readBuilder = table.newReadBuilder().withProjection(projection); + if (predicate != null) { + readBuilder.withFilter(predicate); + } + readBuilder.withBucketFilter(bucketId -> bucketId == shardId); + + List splits = readBuilder.newScan().plan().splits(); + TableRead tableRead = readBuilder.newRead(); + if (predicate != null) { + tableRead.executeFilter(); + } + RecordReader reader = tableRead.createReader(splits); + if (filter != null) { + reader = reader.filter(filter); + } + return new RecordReaderIterator<>(reader); + } catch (Exception e) { + throw new GeaflowRuntimeException("Failed to get data from Paimon.", e); } - - public void rollbackTo(long snapshotId) { - table.rollbackTo(snapshotId); + } + + public void flush(long checkpointIdentifier) { + flush(checkpointIdentifier, false); + } + + public void flush(long checkpointIdentifier, boolean waitCompaction) { + try { + this.commitMessages.addAll( + streamTableWrite.prepareCommit(waitCompaction, checkpointIdentifier)); + } catch (Exception e) { + throw new GeaflowRuntimeException("Failed to flush data into Paimon.", e); } + } - public long rollbackToLatest() { - long latestSnapshotId = getLatestSnapshotId(); - if (latestSnapshotId < 0) { - throw new GeaflowRuntimeException("Not found any valid snapshot version"); - } - table.rollbackTo(latestSnapshotId); - return latestSnapshotId; + public void close() { + try { + this.streamTableWrite.close(); + } catch (Exception e) { + throw new GeaflowRuntimeException("Close stream table write failed.", e); } - - public RecordReaderIterator getIterator(Predicate predicate, Filter filter, - int[] projection) { - try { - ReadBuilder readBuilder = table.newReadBuilder().withProjection(projection); - if (predicate != null) { - readBuilder.withFilter(predicate); - } - readBuilder.withBucketFilter(bucketId -> bucketId == shardId); - - List splits = readBuilder.newScan().plan().splits(); - TableRead tableRead = readBuilder.newRead(); - if (predicate != null) { - tableRead.executeFilter(); - } - RecordReader reader = tableRead.createReader(splits); - if (filter != null) { - reader = reader.filter(filter); - } - return new RecordReaderIterator<>(reader); - } catch (Exception e) { - throw new GeaflowRuntimeException("Failed to get data from Paimon.", e); - } - } - - public void flush(long checkpointIdentifier) { - flush(checkpointIdentifier, false); + } + + public Table getTable() { + return this.table; + } + + public long getLatestSnapshotId() { + OptionalLong latestCheckpoint = table.latestSnapshotId(); + if (latestCheckpoint.isPresent()) { + return latestCheckpoint.getAsLong(); + } else { + return -1L; } + } - public void flush(long checkpointIdentifier, boolean waitCompaction) { - try { - this.commitMessages.addAll( - streamTableWrite.prepareCommit(waitCompaction, checkpointIdentifier)); - } catch (Exception e) { - throw new GeaflowRuntimeException("Failed to flush data into Paimon.", e); - } - } - - public void close() { - try { - this.streamTableWrite.close(); - } catch (Exception e) { - throw new GeaflowRuntimeException("Close stream table write failed.", e); - } - } - - public Table getTable() { - return this.table; - } - - public long getLatestSnapshotId() { - OptionalLong latestCheckpoint = table.latestSnapshotId(); - if (latestCheckpoint.isPresent()) { - return latestCheckpoint.getAsLong(); - } else { - return -1L; - } - } - - public Identifier getIdentifier() { - return this.identifier; - } + public Identifier getIdentifier() { + return this.identifier; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/StaticGraphPaimonStoreBase.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/StaticGraphPaimonStoreBase.java index 4b1dfe66a..fd3c4834a 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/StaticGraphPaimonStoreBase.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/StaticGraphPaimonStoreBase.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.Map; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.common.tuple.Tuple; @@ -36,164 +37,162 @@ import org.apache.geaflow.store.paimon.proxy.IGraphPaimonProxy; import org.apache.geaflow.store.paimon.proxy.PaimonProxyBuilder; -public class StaticGraphPaimonStoreBase extends BasePaimonGraphStore implements - IStaticGraphStore { - - private EdgeAtom sortAtom; - - private IGraphPaimonProxy proxy; - - @Override - public void init(StoreContext storeContext) { - super.init(storeContext); - int[] projection = new int[]{KEY_COLUMN_INDEX, VALUE_COLUMN_INDEX}; - - PaimonTableRWHandle vertexHandle = createVertexTable(shardId); - PaimonTableRWHandle edgeHandle = createEdgeTable(shardId); - - this.sortAtom = storeContext.getGraphSchema().getEdgeAtoms().get(1); - IGraphKVEncoder encoder = GraphKVEncoderFactory.build(storeContext.getConfig(), - storeContext.getGraphSchema()); - this.proxy = PaimonProxyBuilder.build(storeContext.getConfig(), vertexHandle, edgeHandle, - projection, encoder); - } - - @Override - public void archive(long checkpointId) { - this.proxy.archive(checkpointId); - } - - @Override - public void recovery(long checkpointId) { - // TODO: Not implemented yet. - this.proxy.recover(checkpointId); - } - - @Override - public long recoveryLatest() { - return this.proxy.recoverLatest(); - } - - @Override - public void compact() { - - } - - @Override - public void addEdge(IEdge edge) { - this.proxy.addEdge(edge); - } - - @Override - public void addVertex(IVertex vertex) { - this.proxy.addVertex(vertex); - } - - @Override - public IVertex getVertex(K sid, IStatePushDown pushdown) { - return this.proxy.getVertex(sid, pushdown); - } - - @Override - public List> getEdges(K sid, IStatePushDown pushdown) { - checkOrderField(pushdown.getOrderFields()); - return this.proxy.getEdges(sid, pushdown); - } - - @Override - public OneDegreeGraph getOneDegreeGraph(K sid, IStatePushDown pushdown) { - checkOrderField(pushdown.getOrderFields()); - return this.proxy.getOneDegreeGraph(sid, pushdown); - } - - @Override - public CloseableIterator vertexIDIterator() { - return this.proxy.vertexIDIterator(); - } - - @Override - public CloseableIterator vertexIDIterator(IStatePushDown pushDown) { - return this.proxy.vertexIDIterator(pushDown); - } - - @Override - public CloseableIterator> getVertexIterator(IStatePushDown pushdown) { - return this.proxy.getVertexIterator(pushdown); - } - - @Override - public CloseableIterator> getVertexIterator(List keys, - IStatePushDown pushdown) { - return this.proxy.getVertexIterator(keys, pushdown); - } - - @Override - public CloseableIterator> getEdgeIterator(IStatePushDown pushdown) { - return this.proxy.getEdgeIterator(pushdown); - } - - @Override - public CloseableIterator> getEdgeIterator(List keys, IStatePushDown pushdown) { - return this.proxy.getEdgeIterator(keys, pushdown); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator( - IStatePushDown pushdown) { - return this.proxy.getOneDegreeGraphIterator(pushdown); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(List keys, - IStatePushDown pushdown) { - return this.proxy.getOneDegreeGraphIterator(keys, pushdown); - } - - @Override - public CloseableIterator> getEdgeProjectIterator( - IStatePushDown, R> pushdown) { - return this.proxy.getEdgeProjectIterator(pushdown); - } - - @Override - public CloseableIterator> getEdgeProjectIterator(List keys, - IStatePushDown, R> pushdown) { - return this.proxy.getEdgeProjectIterator(keys, pushdown); - } - - @Override - public Map getAggResult(IStatePushDown pushdown) { - return this.proxy.getAggResult(pushdown); - } - - @Override - public Map getAggResult(List keys, IStatePushDown pushdown) { - return this.proxy.getAggResult(keys, pushdown); - } - - @Override - public void flush() { - this.proxy.flush(); - } - - @Override - public void drop() { - super.drop(); - } - - @Override - public void close() { - proxy.close(); - super.close(); - } - - private void checkOrderField(List orderFields) { - boolean emptyFields = orderFields == null || orderFields.isEmpty(); - boolean checkOk = emptyFields || sortAtom == orderFields.get(0); - if (!checkOk) { - throw new GeaflowRuntimeException( - String.format("store is sort by %s but need %s", sortAtom, orderFields.get(0))); - } - } +public class StaticGraphPaimonStoreBase extends BasePaimonGraphStore + implements IStaticGraphStore { + + private EdgeAtom sortAtom; + + private IGraphPaimonProxy proxy; + + @Override + public void init(StoreContext storeContext) { + super.init(storeContext); + int[] projection = new int[] {KEY_COLUMN_INDEX, VALUE_COLUMN_INDEX}; + + PaimonTableRWHandle vertexHandle = createVertexTable(shardId); + PaimonTableRWHandle edgeHandle = createEdgeTable(shardId); + + this.sortAtom = storeContext.getGraphSchema().getEdgeAtoms().get(1); + IGraphKVEncoder encoder = + GraphKVEncoderFactory.build(storeContext.getConfig(), storeContext.getGraphSchema()); + this.proxy = + PaimonProxyBuilder.build( + storeContext.getConfig(), vertexHandle, edgeHandle, projection, encoder); + } + + @Override + public void archive(long checkpointId) { + this.proxy.archive(checkpointId); + } + + @Override + public void recovery(long checkpointId) { + // TODO: Not implemented yet. + this.proxy.recover(checkpointId); + } + + @Override + public long recoveryLatest() { + return this.proxy.recoverLatest(); + } + + @Override + public void compact() {} + + @Override + public void addEdge(IEdge edge) { + this.proxy.addEdge(edge); + } + + @Override + public void addVertex(IVertex vertex) { + this.proxy.addVertex(vertex); + } + + @Override + public IVertex getVertex(K sid, IStatePushDown pushdown) { + return this.proxy.getVertex(sid, pushdown); + } + + @Override + public List> getEdges(K sid, IStatePushDown pushdown) { + checkOrderField(pushdown.getOrderFields()); + return this.proxy.getEdges(sid, pushdown); + } + + @Override + public OneDegreeGraph getOneDegreeGraph(K sid, IStatePushDown pushdown) { + checkOrderField(pushdown.getOrderFields()); + return this.proxy.getOneDegreeGraph(sid, pushdown); + } + + @Override + public CloseableIterator vertexIDIterator() { + return this.proxy.vertexIDIterator(); + } + + @Override + public CloseableIterator vertexIDIterator(IStatePushDown pushDown) { + return this.proxy.vertexIDIterator(pushDown); + } + + @Override + public CloseableIterator> getVertexIterator(IStatePushDown pushdown) { + return this.proxy.getVertexIterator(pushdown); + } + + @Override + public CloseableIterator> getVertexIterator( + List keys, IStatePushDown pushdown) { + return this.proxy.getVertexIterator(keys, pushdown); + } + + @Override + public CloseableIterator> getEdgeIterator(IStatePushDown pushdown) { + return this.proxy.getEdgeIterator(pushdown); + } + + @Override + public CloseableIterator> getEdgeIterator(List keys, IStatePushDown pushdown) { + return this.proxy.getEdgeIterator(keys, pushdown); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + IStatePushDown pushdown) { + return this.proxy.getOneDegreeGraphIterator(pushdown); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + List keys, IStatePushDown pushdown) { + return this.proxy.getOneDegreeGraphIterator(keys, pushdown); + } + + @Override + public CloseableIterator> getEdgeProjectIterator( + IStatePushDown, R> pushdown) { + return this.proxy.getEdgeProjectIterator(pushdown); + } + + @Override + public CloseableIterator> getEdgeProjectIterator( + List keys, IStatePushDown, R> pushdown) { + return this.proxy.getEdgeProjectIterator(keys, pushdown); + } + + @Override + public Map getAggResult(IStatePushDown pushdown) { + return this.proxy.getAggResult(pushdown); + } + + @Override + public Map getAggResult(List keys, IStatePushDown pushdown) { + return this.proxy.getAggResult(keys, pushdown); + } + + @Override + public void flush() { + this.proxy.flush(); + } + + @Override + public void drop() { + super.drop(); + } + + @Override + public void close() { + proxy.close(); + super.close(); + } + + private void checkOrderField(List orderFields) { + boolean emptyFields = orderFields == null || orderFields.isEmpty(); + boolean checkOk = emptyFields || sortAtom == orderFields.get(0); + if (!checkOk) { + throw new GeaflowRuntimeException( + String.format("store is sort by %s but need %s", sortAtom, orderFields.get(0))); + } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/commit/PaimonCommitRegistry.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/commit/PaimonCommitRegistry.java index 8468a410f..c31c07609 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/commit/PaimonCommitRegistry.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/commit/PaimonCommitRegistry.java @@ -23,58 +23,59 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.ConcurrentHashMap; + import org.apache.paimon.table.sink.CommitMessage; public class PaimonCommitRegistry { - private static PaimonCommitRegistry instance; + private static PaimonCommitRegistry instance; - private final ConcurrentHashMap> index2CommitMessages = - new ConcurrentHashMap<>(); + private final ConcurrentHashMap> index2CommitMessages = + new ConcurrentHashMap<>(); - public static synchronized PaimonCommitRegistry initInstance() { - if (instance == null) { - instance = new PaimonCommitRegistry(); - } - return instance; + public static synchronized PaimonCommitRegistry initInstance() { + if (instance == null) { + instance = new PaimonCommitRegistry(); } + return instance; + } - public static synchronized PaimonCommitRegistry getInstance() { - return instance; - } + public static synchronized PaimonCommitRegistry getInstance() { + return instance; + } - public synchronized void addMessages(int index, String tableName, - List commitMessages) { - List message = index2CommitMessages.computeIfAbsent(index, - key -> new ArrayList<>()); - message.add(new TaskCommitMessage(tableName, commitMessages)); - } + public synchronized void addMessages( + int index, String tableName, List commitMessages) { + List message = + index2CommitMessages.computeIfAbsent(index, key -> new ArrayList<>()); + message.add(new TaskCommitMessage(tableName, commitMessages)); + } - public synchronized List pollMessages(int index) { - List messages = index2CommitMessages.get(index); - if (messages != null && !messages.isEmpty()) { - List result = new ArrayList<>(messages); - messages.clear(); - return result; - } - return messages; + public synchronized List pollMessages(int index) { + List messages = index2CommitMessages.get(index); + if (messages != null && !messages.isEmpty()) { + List result = new ArrayList<>(messages); + messages.clear(); + return result; } + return messages; + } - public static class TaskCommitMessage implements Serializable { - private final String tableName; - private final List messages; + public static class TaskCommitMessage implements Serializable { + private final String tableName; + private final List messages; - public TaskCommitMessage(String tableName, List messages) { - this.tableName = tableName; - this.messages = messages; - } + public TaskCommitMessage(String tableName, List messages) { + this.tableName = tableName; + this.messages = messages; + } - public String getTableName() { - return tableName; - } + public String getTableName() { + return tableName; + } - public List getMessages() { - return messages; - } + public List getMessages() { + return messages; } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/commit/PaimonMessage.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/commit/PaimonMessage.java index 686ef8f4a..36ac0b44f 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/commit/PaimonMessage.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/commit/PaimonMessage.java @@ -23,6 +23,7 @@ import java.io.ByteArrayOutputStream; import java.io.Serializable; import java.util.List; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.paimon.io.DataInputViewStreamWrapper; import org.apache.paimon.io.DataOutputViewStreamWrapper; @@ -33,57 +34,57 @@ public class PaimonMessage implements Serializable { - private static final Logger LOGGER = LoggerFactory.getLogger(PaimonMessage.class); - private static final ThreadLocal CACHE = ThreadLocal.withInitial( - CommitMessageSerializer::new); + private static final Logger LOGGER = LoggerFactory.getLogger(PaimonMessage.class); + private static final ThreadLocal CACHE = + ThreadLocal.withInitial(CommitMessageSerializer::new); - private transient List messages; - private byte[] messageBytes; - private int serializerVersion; - private String tableName; - private long chkId; + private transient List messages; + private byte[] messageBytes; + private int serializerVersion; + private String tableName; + private long chkId; - public PaimonMessage(long chkId, String tableName, List messages) { - this.chkId = chkId; - this.tableName = tableName; + public PaimonMessage(long chkId, String tableName, List messages) { + this.chkId = chkId; + this.tableName = tableName; - CommitMessageSerializer serializer = CACHE.get(); - this.serializerVersion = serializer.getVersion(); + CommitMessageSerializer serializer = CACHE.get(); + this.serializerVersion = serializer.getVersion(); - ByteArrayOutputStream out = new ByteArrayOutputStream(); - try { - serializer.serializeList(messages, new DataOutputViewStreamWrapper(out)); - this.messageBytes = out.toByteArray(); - LOGGER.info("ser bytes: {}", messageBytes.length); - } catch (Exception e) { - LOGGER.error("serialize message error", e); - throw new GeaflowRuntimeException(e); - } + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try { + serializer.serializeList(messages, new DataOutputViewStreamWrapper(out)); + this.messageBytes = out.toByteArray(); + LOGGER.info("ser bytes: {}", messageBytes.length); + } catch (Exception e) { + LOGGER.error("serialize message error", e); + throw new GeaflowRuntimeException(e); } + } - public List getMessages() { - if (messages == null) { - CommitMessageSerializer serializer = CACHE.get(); - try { - if (messageBytes == null) { - LOGGER.warn("deserialize message error, null"); - } - ByteArrayInputStream in = new ByteArrayInputStream(messageBytes); - messages = serializer.deserializeList(serializerVersion, - new DataInputViewStreamWrapper(in)); - } catch (Exception e) { - LOGGER.error("deserialize message error", e); - throw new GeaflowRuntimeException(e); - } + public List getMessages() { + if (messages == null) { + CommitMessageSerializer serializer = CACHE.get(); + try { + if (messageBytes == null) { + LOGGER.warn("deserialize message error, null"); } - return messages; + ByteArrayInputStream in = new ByteArrayInputStream(messageBytes); + messages = + serializer.deserializeList(serializerVersion, new DataInputViewStreamWrapper(in)); + } catch (Exception e) { + LOGGER.error("deserialize message error", e); + throw new GeaflowRuntimeException(e); + } } + return messages; + } - public String getTableName() { - return tableName; - } + public String getTableName() { + return tableName; + } - public long getCheckpointId() { - return chkId; - } + public long getCheckpointId() { + return chkId; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/config/PaimonConfigKeys.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/config/PaimonConfigKeys.java index 1e4c80b19..82ed1f401 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/config/PaimonConfigKeys.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/config/PaimonConfigKeys.java @@ -24,51 +24,52 @@ public class PaimonConfigKeys { - public static final ConfigKey PAIMON_STORE_WAREHOUSE = ConfigKeys - .key("geaflow.store.paimon.warehouse") - .defaultValue("file:///tmp/paimon/") - .description("paimon warehouse, default LOCAL path, now support path prefix: " - + "[file://], Options for future: [hdfs://, oss://, s3://]"); + public static final ConfigKey PAIMON_STORE_WAREHOUSE = + ConfigKeys.key("geaflow.store.paimon.warehouse") + .defaultValue("file:///tmp/paimon/") + .description( + "paimon warehouse, default LOCAL path, now support path prefix: " + + "[file://], Options for future: [hdfs://, oss://, s3://]"); - public static final ConfigKey PAIMON_STORE_META_STORE = ConfigKeys - .key("geaflow.store.paimon.meta.store") - .defaultValue("FILESYSTEM") - .description("Metastore of paimon catalog, now support [FILESYSTEM]. Options for future: " - + "[HIVE, JDBC]."); + public static final ConfigKey PAIMON_STORE_META_STORE = + ConfigKeys.key("geaflow.store.paimon.meta.store") + .defaultValue("FILESYSTEM") + .description( + "Metastore of paimon catalog, now support [FILESYSTEM]. Options for future: " + + "[HIVE, JDBC]."); - public static final ConfigKey PAIMON_STORE_OPTIONS = ConfigKeys - .key("geaflow.store.paimon.options") - .defaultValue(128) - .description("paimon memtable size, default 256MB"); + public static final ConfigKey PAIMON_STORE_OPTIONS = + ConfigKeys.key("geaflow.store.paimon.options") + .defaultValue(128) + .description("paimon memtable size, default 256MB"); - public static final ConfigKey PAIMON_STORE_DATABASE = ConfigKeys - .key("geaflow.store.paimon.database") - .defaultValue("graph") - .description("paimon graph store database"); + public static final ConfigKey PAIMON_STORE_DATABASE = + ConfigKeys.key("geaflow.store.paimon.database") + .defaultValue("graph") + .description("paimon graph store database"); - public static final ConfigKey PAIMON_STORE_VERTEX_TABLE = ConfigKeys - .key("geaflow.store.paimon.vertex.table") - .defaultValue("vertex") - .description("paimon graph store vertex table name"); + public static final ConfigKey PAIMON_STORE_VERTEX_TABLE = + ConfigKeys.key("geaflow.store.paimon.vertex.table") + .defaultValue("vertex") + .description("paimon graph store vertex table name"); - public static final ConfigKey PAIMON_STORE_EDGE_TABLE = ConfigKeys - .key("geaflow.store.paimon.edge.table") - .defaultValue("edge") - .description("paimon graph store edge table name"); + public static final ConfigKey PAIMON_STORE_EDGE_TABLE = + ConfigKeys.key("geaflow.store.paimon.edge.table") + .defaultValue("edge") + .description("paimon graph store edge table name"); - public static final ConfigKey PAIMON_STORE_INDEX_TABLE = ConfigKeys - .key("geaflow.store.paimon.index.table") - .defaultValue("index") - .description("paimon graph store index table name"); + public static final ConfigKey PAIMON_STORE_INDEX_TABLE = + ConfigKeys.key("geaflow.store.paimon.index.table") + .defaultValue("index") + .description("paimon graph store index table name"); - public static final ConfigKey PAIMON_STORE_DISTRIBUTED_MODE_ENABLE = ConfigKeys - .key("geaflow.store.paimon.distributed.mode.enable") - .defaultValue(true) - .description("paimon graph store distributed mode"); - - public static final ConfigKey PAIMON_STORE_TABLE_AUTO_CREATE_ENABLE = ConfigKeys - .key("geaflow.store.paimon.table.auto.create") - .defaultValue(false) - .description("paimon graph store table auto create"); + public static final ConfigKey PAIMON_STORE_DISTRIBUTED_MODE_ENABLE = + ConfigKeys.key("geaflow.store.paimon.distributed.mode.enable") + .defaultValue(true) + .description("paimon graph store distributed mode"); + public static final ConfigKey PAIMON_STORE_TABLE_AUTO_CREATE_ENABLE = + ConfigKeys.key("geaflow.store.paimon.table.auto.create") + .defaultValue(false) + .description("paimon graph store table auto create"); } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/config/PaimonStoreConfig.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/config/PaimonStoreConfig.java index e82da589b..b129cca7e 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/config/PaimonStoreConfig.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/config/PaimonStoreConfig.java @@ -21,26 +21,26 @@ import java.util.HashMap; import java.util.Map; + import org.apache.commons.lang.StringUtils; import org.apache.geaflow.common.config.Configuration; public class PaimonStoreConfig { - public static Map getPaimonOptions(Configuration config) { - String s = config.getString(PaimonConfigKeys.PAIMON_STORE_OPTIONS, ""); - if (StringUtils.isEmpty(s)) { - return null; - } - Map options = new HashMap<>(); - String[] pairs = s.trim().split(","); - for (String pair : pairs) { - String[] kv = pair.trim().split("="); - if (kv.length < 2) { - continue; - } - options.put(kv[0], kv[1]); - } - return options; + public static Map getPaimonOptions(Configuration config) { + String s = config.getString(PaimonConfigKeys.PAIMON_STORE_OPTIONS, ""); + if (StringUtils.isEmpty(s)) { + return null; } - + Map options = new HashMap<>(); + String[] pairs = s.trim().split(","); + for (String pair : pairs) { + String[] kv = pair.trim().split("="); + if (kv.length < 2) { + continue; + } + options.put(kv[0], kv[1]); + } + return options; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/iterator/PaimonIterator.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/iterator/PaimonIterator.java index 01bcac4e7..564c659cd 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/iterator/PaimonIterator.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/iterator/PaimonIterator.java @@ -27,42 +27,42 @@ public class PaimonIterator implements CloseableIterator> { - private final RecordReaderIterator paimonRowIter; - private Tuple next; - private boolean isClosed = false; + private final RecordReaderIterator paimonRowIter; + private Tuple next; + private boolean isClosed = false; - public PaimonIterator(RecordReaderIterator iterator) { - this.paimonRowIter = iterator; - } + public PaimonIterator(RecordReaderIterator iterator) { + this.paimonRowIter = iterator; + } - @Override - public boolean hasNext() { - next = null; - if (!isClosed && this.paimonRowIter.hasNext()) { - InternalRow nextRow = this.paimonRowIter.next(); - next = Tuple.of(nextRow.getBinary(0), nextRow.getBinary(1)); - } - if (next == null) { - close(); - return false; - } - return true; + @Override + public boolean hasNext() { + next = null; + if (!isClosed && this.paimonRowIter.hasNext()) { + InternalRow nextRow = this.paimonRowIter.next(); + next = Tuple.of(nextRow.getBinary(0), nextRow.getBinary(1)); } - - @Override - public Tuple next() { - return next; + if (next == null) { + close(); + return false; } + return true; + } + + @Override + public Tuple next() { + return next; + } - @Override - public void close() { - if (!isClosed) { - try { - this.paimonRowIter.close(); - } catch (Exception e) { - throw new GeaflowRuntimeException("Close paimon iterator failed.", e); - } - isClosed = true; - } + @Override + public void close() { + if (!isClosed) { + try { + this.paimonRowIter.close(); + } catch (Exception e) { + throw new GeaflowRuntimeException("Close paimon iterator failed.", e); + } + isClosed = true; } -} \ No newline at end of file + } +} diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/predicate/BytesStartsWith.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/predicate/BytesStartsWith.java index 3fe904383..81844404a 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/predicate/BytesStartsWith.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/predicate/BytesStartsWith.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.Optional; + import org.apache.paimon.memory.MemorySegment; import org.apache.paimon.predicate.FieldRef; import org.apache.paimon.predicate.FunctionVisitor; @@ -30,63 +31,61 @@ public class BytesStartsWith extends NullFalseLeafBinaryFunction { - public static final BytesStartsWith INSTANCE = new BytesStartsWith(); + public static final BytesStartsWith INSTANCE = new BytesStartsWith(); - private BytesStartsWith() { - } + private BytesStartsWith() {} - @Override - public boolean test(DataType type, Object field, Object patternLiteral) { - return startsWith((byte[]) field, (byte[]) patternLiteral); - } + @Override + public boolean test(DataType type, Object field, Object patternLiteral) { + return startsWith((byte[]) field, (byte[]) patternLiteral); + } - @Override - public boolean test(DataType type, long rowCount, Object min, Object max, Long nullCount, - Object patternLiteral) { - byte[] minBytes = (byte[]) min; - byte[] maxBytes = (byte[]) max; - byte[] pattern = (byte[]) patternLiteral; - return (startsWith(minBytes, pattern) || compareTo(minBytes, pattern) <= 0) && ( - startsWith(maxBytes, pattern) || compareTo(maxBytes, pattern) >= 0); - } + @Override + public boolean test( + DataType type, long rowCount, Object min, Object max, Long nullCount, Object patternLiteral) { + byte[] minBytes = (byte[]) min; + byte[] maxBytes = (byte[]) max; + byte[] pattern = (byte[]) patternLiteral; + return (startsWith(minBytes, pattern) || compareTo(minBytes, pattern) <= 0) + && (startsWith(maxBytes, pattern) || compareTo(maxBytes, pattern) >= 0); + } - @Override - public Optional negate() { - return Optional.empty(); - } + @Override + public Optional negate() { + return Optional.empty(); + } - @Override - public T visit(FunctionVisitor visitor, FieldRef fieldRef, List literals) { - return visitor.visitStartsWith(fieldRef, literals.get(0)); - } - - /** - * Judge whether the field starts with the pattern. - * - * @param field the field to be compared - * @param pattern the pattern to be compared - * @return true if the field starts with the pattern, false otherwise - */ - private static boolean startsWith(byte[] field, byte[] pattern) { - if (field.length < pattern.length) { - return false; - } - MemorySegment s1 = MemorySegment.wrap(field); - MemorySegment s2 = MemorySegment.wrap(pattern); + @Override + public T visit(FunctionVisitor visitor, FieldRef fieldRef, List literals) { + return visitor.visitStartsWith(fieldRef, literals.get(0)); + } - return s1.equalTo(s2, 0, 0, pattern.length); + /** + * Judge whether the field starts with the pattern. + * + * @param field the field to be compared + * @param pattern the pattern to be compared + * @return true if the field starts with the pattern, false otherwise + */ + private static boolean startsWith(byte[] field, byte[] pattern) { + if (field.length < pattern.length) { + return false; } + MemorySegment s1 = MemorySegment.wrap(field); + MemorySegment s2 = MemorySegment.wrap(pattern); - private static int compareTo(byte[] lhs, byte[] rhs) { - int len = Math.min(lhs.length, rhs.length); + return s1.equalTo(s2, 0, 0, pattern.length); + } - for (int i = 0; i < len; i++) { - int res = Byte.compare(lhs[i], rhs[i]); - if (res != 0) { - return res; - } - } - return lhs.length - rhs.length; - } + private static int compareTo(byte[] lhs, byte[] rhs) { + int len = Math.min(lhs.length, rhs.length); + for (int i = 0; i < len; i++) { + int res = Byte.compare(lhs[i], rhs[i]); + if (res != 0) { + return res; + } + } + return lhs.length - rhs.length; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/AsyncPaimonGraphRWProxy.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/AsyncPaimonGraphRWProxy.java index d7b990253..7fd64225d 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/AsyncPaimonGraphRWProxy.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/AsyncPaimonGraphRWProxy.java @@ -19,11 +19,11 @@ package org.apache.geaflow.store.paimon.proxy; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Collection; import java.util.LinkedHashSet; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.serialize.SerializerFactory; import org.apache.geaflow.model.graph.edge.IEdge; @@ -36,78 +36,83 @@ import org.apache.geaflow.store.data.GraphWriteBuffer; import org.apache.geaflow.store.paimon.PaimonTableRWHandle; -public class AsyncPaimonGraphRWProxy extends PaimonGraphRWProxy { - - private final AsyncFlushBuffer flushBuffer; - - public AsyncPaimonGraphRWProxy(PaimonTableRWHandle vertexHandle, PaimonTableRWHandle edgeHandle, - int[] projection, IGraphKVEncoder encoder, - Configuration config) { - super(vertexHandle, edgeHandle, projection, encoder); - this.flushBuffer = new AsyncFlushBuffer<>(config, this::flush, - SerializerFactory.getKryoSerializer()); - } - - private void flush(GraphWriteBuffer graphWriteBuffer) { - if (graphWriteBuffer.getSize() == 0) { - return; - } - - Collection> vertices = graphWriteBuffer.getVertexId2Vertex().values(); - for (IVertex vertex : vertices) { - super.addVertex(vertex); - } +import com.google.common.collect.Lists; - Collection>> edgesList = graphWriteBuffer.getVertexId2Edges().values(); - for (List> edges : edgesList) { - for (IEdge edge : edges) { - super.addEdge(edge); - } - } +public class AsyncPaimonGraphRWProxy extends PaimonGraphRWProxy { - super.flush(); + private final AsyncFlushBuffer flushBuffer; + + public AsyncPaimonGraphRWProxy( + PaimonTableRWHandle vertexHandle, + PaimonTableRWHandle edgeHandle, + int[] projection, + IGraphKVEncoder encoder, + Configuration config) { + super(vertexHandle, edgeHandle, projection, encoder); + this.flushBuffer = + new AsyncFlushBuffer<>(config, this::flush, SerializerFactory.getKryoSerializer()); + } + + private void flush(GraphWriteBuffer graphWriteBuffer) { + if (graphWriteBuffer.getSize() == 0) { + return; } - @Override - public void addVertex(IVertex vertex) { - this.flushBuffer.addVertex(vertex); + Collection> vertices = graphWriteBuffer.getVertexId2Vertex().values(); + for (IVertex vertex : vertices) { + super.addVertex(vertex); } - @Override - public void addEdge(IEdge edge) { - this.flushBuffer.addEdge(edge); + Collection>> edgesList = graphWriteBuffer.getVertexId2Edges().values(); + for (List> edges : edgesList) { + for (IEdge edge : edges) { + super.addEdge(edge); + } } - @Override - public IVertex getVertex(K sid, IStatePushDown pushdown) { - IVertex vertex = this.flushBuffer.readBufferedVertex(sid); - if (vertex != null) { - return ((IGraphFilter) pushdown.getFilter()).filterVertex(vertex) ? vertex : null; - } - return super.getVertex(sid, pushdown); - } + super.flush(); + } - @Override - public List> getEdges(K sid, IStatePushDown pushdown) { - List> list = this.flushBuffer.readBufferedEdges(sid); - LinkedHashSet> set = new LinkedHashSet<>(); + @Override + public void addVertex(IVertex vertex) { + this.flushBuffer.addVertex(vertex); + } - IGraphFilter filter = GraphFilter.of(pushdown.getFilter(), pushdown.getEdgeLimit()); - Lists.reverse(list).stream().filter(filter::filterEdge).forEach(set::add); - if (!filter.dropAllRemaining()) { - set.addAll(super.getEdges(sid, pushdown)); - } + @Override + public void addEdge(IEdge edge) { + this.flushBuffer.addEdge(edge); + } - return new ArrayList<>(set); + @Override + public IVertex getVertex(K sid, IStatePushDown pushdown) { + IVertex vertex = this.flushBuffer.readBufferedVertex(sid); + if (vertex != null) { + return ((IGraphFilter) pushdown.getFilter()).filterVertex(vertex) ? vertex : null; } - - @Override - public void flush() { - flushBuffer.flush(); + return super.getVertex(sid, pushdown); + } + + @Override + public List> getEdges(K sid, IStatePushDown pushdown) { + List> list = this.flushBuffer.readBufferedEdges(sid); + LinkedHashSet> set = new LinkedHashSet<>(); + + IGraphFilter filter = GraphFilter.of(pushdown.getFilter(), pushdown.getEdgeLimit()); + Lists.reverse(list).stream().filter(filter::filterEdge).forEach(set::add); + if (!filter.dropAllRemaining()) { + set.addAll(super.getEdges(sid, pushdown)); } - @Override - public void close() { - flushBuffer.close(); - } + return new ArrayList<>(set); + } + + @Override + public void flush() { + flushBuffer.flush(); + } + + @Override + public void close() { + flushBuffer.close(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/IGraphMultiVersionedPaimonProxy.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/IGraphMultiVersionedPaimonProxy.java index 9d5da5da7..673ec19bf 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/IGraphMultiVersionedPaimonProxy.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/IGraphMultiVersionedPaimonProxy.java @@ -21,7 +21,5 @@ import org.apache.geaflow.state.graph.DynamicGraphTrait; -public interface IGraphMultiVersionedPaimonProxy extends DynamicGraphTrait, - IPaimonProxy { - -} +public interface IGraphMultiVersionedPaimonProxy + extends DynamicGraphTrait, IPaimonProxy {} diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/IGraphPaimonProxy.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/IGraphPaimonProxy.java index adff95b35..72ad2a0a1 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/IGraphPaimonProxy.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/IGraphPaimonProxy.java @@ -21,6 +21,4 @@ import org.apache.geaflow.state.graph.StaticGraphTrait; -public interface IGraphPaimonProxy extends StaticGraphTrait, IPaimonProxy { - -} +public interface IGraphPaimonProxy extends StaticGraphTrait, IPaimonProxy {} diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/IPaimonProxy.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/IPaimonProxy.java index 2a5f6c728..2a2c8b83d 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/IPaimonProxy.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/IPaimonProxy.java @@ -21,34 +21,30 @@ public interface IPaimonProxy { - /** - * Archive data with a specific checkpoint id. - * - * @param checkpointId checkpoint id. - */ - void archive(long checkpointId); + /** + * Archive data with a specific checkpoint id. + * + * @param checkpointId checkpoint id. + */ + void archive(long checkpointId); - /** - * Recovery data with a specific checkpoint id. - * - * @param checkpointId checkpoint id. - */ - void recover(long checkpointId); + /** + * Recovery data with a specific checkpoint id. + * + * @param checkpointId checkpoint id. + */ + void recover(long checkpointId); - /** - * Recovery data to the latest checkpoint. - * - * @return checkpoint id. - */ - long recoverLatest(); + /** + * Recovery data to the latest checkpoint. + * + * @return checkpoint id. + */ + long recoverLatest(); - /** - * Flush data from memtable to disk. - */ - void flush(); + /** Flush data from memtable to disk. */ + void flush(); - /** - * Close proxy. - */ - void close(); + /** Close proxy. */ + void close(); } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/PaimonBaseGraphProxy.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/PaimonBaseGraphProxy.java index 7fd237c34..1bba24442 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/PaimonBaseGraphProxy.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/PaimonBaseGraphProxy.java @@ -25,6 +25,7 @@ import java.util.Map; import java.util.function.Function; import java.util.function.Predicate; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.common.tuple.Tuple; import org.apache.geaflow.model.graph.edge.IEdge; @@ -53,186 +54,189 @@ public abstract class PaimonBaseGraphProxy implements IGraphPaimonProxy { - protected static final int KEY_COLUMN_INDEX = 0; - - protected static final int VALUE_COLUMN_INDEX = 1; - - protected IGraphKVEncoder encoder; - - protected PaimonTableRWHandle vertexHandle; - - protected PaimonTableRWHandle edgeHandle; - - protected int[] projection; - - protected long lastCheckpointId; - - public PaimonBaseGraphProxy(PaimonTableRWHandle vertexHandle, PaimonTableRWHandle edgeHandle, - int[] projection, IGraphKVEncoder encoder) { - this.vertexHandle = vertexHandle; - this.edgeHandle = edgeHandle; - this.projection = projection; - this.encoder = encoder; - this.lastCheckpointId = 0; - } - - @Override - public void addEdge(IEdge edge) { - Tuple tuple = this.encoder.getEdgeEncoder().format(edge); - GenericRow record = GenericRow.of(tuple.f0, tuple.f1); - this.edgeHandle.write(record); - } - - @Override - public void addVertex(IVertex vertex) { - Tuple tuple = this.encoder.getVertexEncoder().format(vertex); - GenericRow record = GenericRow.of(tuple.f0, tuple.f1); - this.vertexHandle.write(record); - } - - @Override - public OneDegreeGraph getOneDegreeGraph(K sid, IStatePushDown pushdown) { - IVertex vertex = getVertex(sid, pushdown); - List> edgeList = getEdges(sid, pushdown); - IGraphFilter filter = GraphFilter.of(pushdown.getFilter(), pushdown.getEdgeLimit()); - OneDegreeGraph oneDegreeGraph = new OneDegreeGraph<>(sid, vertex, - IteratorWithClose.wrap(edgeList.iterator())); - if (filter.filterOneDegreeGraph(oneDegreeGraph)) { - return oneDegreeGraph; - } else { - return null; - } - } - - @Override - public CloseableIterator vertexIDIterator() { - flush(); - RecordReaderIterator iterator = this.vertexHandle.getIterator(null, null, - projection); - PaimonIterator it = new PaimonIterator(iterator); - return new IteratorWithFnThenFilter<>(it, - tuple2 -> encoder.getVertexEncoder().getVertexID(tuple2.f0), new Predicate() { - K last = null; - - @Override - public boolean test(K k) { - boolean res = k.equals(last); - last = k; - return !res; - } - } - ); - } - - @Override - public CloseableIterator vertexIDIterator(IStatePushDown pushDown) { - if (pushDown.getFilter() == null) { - return vertexIDIterator(); - } else { - return new IteratorWithFn<>(getVertexIterator(pushDown), IVertex::getId); - } - } - - @Override - public CloseableIterator> getVertexIterator(IStatePushDown pushdown) { - flush(); - RecordReaderIterator iterator = this.vertexHandle.getIterator(null, null, - projection); - PaimonIterator it = new PaimonIterator(iterator); - return new VertexScanIterator<>(it, pushdown, encoder.getVertexEncoder()::getVertex); - } - - @Override - public CloseableIterator> getVertexIterator(List keys, - IStatePushDown pushdown) { - return new KeysIterator<>(keys, this::getVertex, pushdown); - } - - @Override - public CloseableIterator> getEdgeIterator(IStatePushDown pushdown) { - flush(); - RecordReaderIterator iterator = this.edgeHandle.getIterator(null, - null, projection); - PaimonIterator it = new PaimonIterator(iterator); - return new EdgeScanIterator<>(it, pushdown, encoder.getEdgeEncoder()::getEdge); - } - - @Override - public CloseableIterator> getEdgeIterator(List keys, IStatePushDown pushdown) { - return new IteratorWithFlatFn<>(new KeysIterator<>(keys, this::getEdges, pushdown), - List::iterator); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator( - IStatePushDown pushdown) { - flush(); - return new OneDegreeGraphScanIterator<>(encoder.getKeyType(), getVertexIterator(pushdown), - getEdgeIterator(pushdown), pushdown); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(List keys, - IStatePushDown pushdown) { - return new KeysIterator<>(keys, this::getOneDegreeGraph, pushdown); - } - - @Override - public CloseableIterator> getEdgeProjectIterator( - IStatePushDown, R> pushdown) { - flush(); - return new IteratorWithFn<>(getEdgeIterator(pushdown), - e -> Tuple.of(e.getSrcId(), pushdown.getProjector().project(e))); - } - - @Override - public CloseableIterator> getEdgeProjectIterator(List keys, - IStatePushDown, R> pushdown) { - return new IteratorWithFn<>(getEdgeIterator(keys, pushdown), - e -> Tuple.of(e.getSrcId(), pushdown.getProjector().project(e))); - } - - @Override - public Map getAggResult(IStatePushDown pushdown) { - Map res = new HashMap<>(); - Iterator>> it = new EdgeListScanIterator<>(getEdgeIterator(pushdown)); - while (it.hasNext()) { - List> edges = it.next(); - K key = edges.get(0).getSrcId(); - res.put(key, (long) edges.size()); - } - return res; - } - - @Override - public Map getAggResult(List keys, IStatePushDown pushdown) { - Map res = new HashMap<>(keys.size()); - - Function pushdownFun; - if (pushdown.getFilters() == null) { - pushdownFun = key -> pushdown; - } else { - pushdownFun = key -> StatePushDown.of() - .withFilter((IFilter) pushdown.getFilters().get(key)); - } - - for (K key : keys) { - List> list = getEdges(key, pushdownFun.apply(key)); - res.put(key, (long) list.size()); - } - return res; - } - - @Override - public void flush() { - this.vertexHandle.flush(lastCheckpointId); - this.edgeHandle.flush(lastCheckpointId); - } - - @Override - public void close() { - this.vertexHandle.close(); - this.edgeHandle.close(); - } + protected static final int KEY_COLUMN_INDEX = 0; + + protected static final int VALUE_COLUMN_INDEX = 1; + + protected IGraphKVEncoder encoder; + + protected PaimonTableRWHandle vertexHandle; + + protected PaimonTableRWHandle edgeHandle; + + protected int[] projection; + + protected long lastCheckpointId; + + public PaimonBaseGraphProxy( + PaimonTableRWHandle vertexHandle, + PaimonTableRWHandle edgeHandle, + int[] projection, + IGraphKVEncoder encoder) { + this.vertexHandle = vertexHandle; + this.edgeHandle = edgeHandle; + this.projection = projection; + this.encoder = encoder; + this.lastCheckpointId = 0; + } + + @Override + public void addEdge(IEdge edge) { + Tuple tuple = this.encoder.getEdgeEncoder().format(edge); + GenericRow record = GenericRow.of(tuple.f0, tuple.f1); + this.edgeHandle.write(record); + } + + @Override + public void addVertex(IVertex vertex) { + Tuple tuple = this.encoder.getVertexEncoder().format(vertex); + GenericRow record = GenericRow.of(tuple.f0, tuple.f1); + this.vertexHandle.write(record); + } + + @Override + public OneDegreeGraph getOneDegreeGraph(K sid, IStatePushDown pushdown) { + IVertex vertex = getVertex(sid, pushdown); + List> edgeList = getEdges(sid, pushdown); + IGraphFilter filter = GraphFilter.of(pushdown.getFilter(), pushdown.getEdgeLimit()); + OneDegreeGraph oneDegreeGraph = + new OneDegreeGraph<>(sid, vertex, IteratorWithClose.wrap(edgeList.iterator())); + if (filter.filterOneDegreeGraph(oneDegreeGraph)) { + return oneDegreeGraph; + } else { + return null; + } + } + + @Override + public CloseableIterator vertexIDIterator() { + flush(); + RecordReaderIterator iterator = + this.vertexHandle.getIterator(null, null, projection); + PaimonIterator it = new PaimonIterator(iterator); + return new IteratorWithFnThenFilter<>( + it, + tuple2 -> encoder.getVertexEncoder().getVertexID(tuple2.f0), + new Predicate() { + K last = null; + + @Override + public boolean test(K k) { + boolean res = k.equals(last); + last = k; + return !res; + } + }); + } + + @Override + public CloseableIterator vertexIDIterator(IStatePushDown pushDown) { + if (pushDown.getFilter() == null) { + return vertexIDIterator(); + } else { + return new IteratorWithFn<>(getVertexIterator(pushDown), IVertex::getId); + } + } + + @Override + public CloseableIterator> getVertexIterator(IStatePushDown pushdown) { + flush(); + RecordReaderIterator iterator = + this.vertexHandle.getIterator(null, null, projection); + PaimonIterator it = new PaimonIterator(iterator); + return new VertexScanIterator<>(it, pushdown, encoder.getVertexEncoder()::getVertex); + } + + @Override + public CloseableIterator> getVertexIterator( + List keys, IStatePushDown pushdown) { + return new KeysIterator<>(keys, this::getVertex, pushdown); + } + + @Override + public CloseableIterator> getEdgeIterator(IStatePushDown pushdown) { + flush(); + RecordReaderIterator iterator = + this.edgeHandle.getIterator(null, null, projection); + PaimonIterator it = new PaimonIterator(iterator); + return new EdgeScanIterator<>(it, pushdown, encoder.getEdgeEncoder()::getEdge); + } + + @Override + public CloseableIterator> getEdgeIterator(List keys, IStatePushDown pushdown) { + return new IteratorWithFlatFn<>( + new KeysIterator<>(keys, this::getEdges, pushdown), List::iterator); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + IStatePushDown pushdown) { + flush(); + return new OneDegreeGraphScanIterator<>( + encoder.getKeyType(), getVertexIterator(pushdown), getEdgeIterator(pushdown), pushdown); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + List keys, IStatePushDown pushdown) { + return new KeysIterator<>(keys, this::getOneDegreeGraph, pushdown); + } + + @Override + public CloseableIterator> getEdgeProjectIterator( + IStatePushDown, R> pushdown) { + flush(); + return new IteratorWithFn<>( + getEdgeIterator(pushdown), e -> Tuple.of(e.getSrcId(), pushdown.getProjector().project(e))); + } + + @Override + public CloseableIterator> getEdgeProjectIterator( + List keys, IStatePushDown, R> pushdown) { + return new IteratorWithFn<>( + getEdgeIterator(keys, pushdown), + e -> Tuple.of(e.getSrcId(), pushdown.getProjector().project(e))); + } + + @Override + public Map getAggResult(IStatePushDown pushdown) { + Map res = new HashMap<>(); + Iterator>> it = new EdgeListScanIterator<>(getEdgeIterator(pushdown)); + while (it.hasNext()) { + List> edges = it.next(); + K key = edges.get(0).getSrcId(); + res.put(key, (long) edges.size()); + } + return res; + } + + @Override + public Map getAggResult(List keys, IStatePushDown pushdown) { + Map res = new HashMap<>(keys.size()); + + Function pushdownFun; + if (pushdown.getFilters() == null) { + pushdownFun = key -> pushdown; + } else { + pushdownFun = key -> StatePushDown.of().withFilter((IFilter) pushdown.getFilters().get(key)); + } + + for (K key : keys) { + List> list = getEdges(key, pushdownFun.apply(key)); + res.put(key, (long) list.size()); + } + return res; + } + + @Override + public void flush() { + this.vertexHandle.flush(lastCheckpointId); + this.edgeHandle.flush(lastCheckpointId); + } + + @Override + public void close() { + this.vertexHandle.close(); + this.edgeHandle.close(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/PaimonGraphMultiVersionedProxy.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/PaimonGraphMultiVersionedProxy.java index 0e77dc2b1..642d05c6e 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/PaimonGraphMultiVersionedProxy.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/PaimonGraphMultiVersionedProxy.java @@ -21,14 +21,13 @@ import static java.util.Collections.singletonList; -import com.google.common.primitives.Bytes; -import com.google.common.primitives.Longs; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.StateConfigKeys; import org.apache.geaflow.common.errorcode.RuntimeErrors; @@ -64,340 +63,376 @@ import org.apache.paimon.reader.RecordReaderIterator; import org.apache.paimon.types.RowType; -public class PaimonGraphMultiVersionedProxy implements - IGraphMultiVersionedPaimonProxy { - - private static final int VERSION_BYTES_SIZE = Long.BYTES; - private static final int VERTEX_INDEX_SUFFIX_SIZE = - VERSION_BYTES_SIZE + StateConfigKeys.DELIMITER.length; - protected static final byte[] EMPTY_BYTES = new byte[0]; - protected final Configuration config; - protected IGraphKVEncoder encoder; - protected IEdgeKVEncoder edgeEncoder; - protected IVertexKVEncoder vertexEncoder; - - protected PaimonTableRWHandle vertexHandle; - protected PaimonTableRWHandle vertexIndexHandle; - protected PaimonTableRWHandle edgeHandle; - - protected int[] projection; - protected long lastCheckpointId; - - public PaimonGraphMultiVersionedProxy(PaimonTableRWHandle vertexHandle, - PaimonTableRWHandle vertexIndexHandle, - PaimonTableRWHandle edgeHandle, int[] projection, - IGraphKVEncoder encoder, - Configuration config) { - this.vertexHandle = vertexHandle; - this.vertexIndexHandle = vertexIndexHandle; - this.edgeHandle = edgeHandle; - this.projection = projection; - this.encoder = encoder; - this.vertexEncoder = encoder.getVertexEncoder(); - this.edgeEncoder = encoder.getEdgeEncoder(); - this.config = config; - this.lastCheckpointId = 0; - } - - @Override - public void archive(long checkpointId) { - this.lastCheckpointId = checkpointId; - this.vertexHandle.commit(lastCheckpointId); - this.vertexIndexHandle.commit(lastCheckpointId); - this.edgeHandle.commit(lastCheckpointId); - } - - @Override - public void recover(long checkpointId) { - throw new UnsupportedOperationException(); - } - - @Override - public long recoverLatest() { - throw new UnsupportedOperationException(); - } - - @Override - public void flush() { - - } - - @Override - public void close() { - - } - - @Override - public void addEdge(long version, IEdge edge) { - Tuple tuple = edgeEncoder.format(edge); - byte[] bVersion = getBinaryVersion(version); - GenericRow record = GenericRow.of(concat(bVersion, tuple.f0), tuple.f1); - this.edgeHandle.write(record); - } - - @Override - public void addVertex(long version, IVertex vertex) { - Tuple tuple = vertexEncoder.format(vertex); - byte[] bVersion = getBinaryVersion(version); - GenericRow record = GenericRow.of(concat(bVersion, tuple.f0), tuple.f1); - GenericRow index = GenericRow.of(concat(tuple.f0, bVersion), EMPTY_BYTES); - this.vertexHandle.write(record); - this.vertexIndexHandle.write(index); - } +import com.google.common.primitives.Bytes; +import com.google.common.primitives.Longs; - @Override - public IVertex getVertex(long version, K sid, IStatePushDown pushdown) { - byte[] key = encoder.getKeyType().serialize(sid); - byte[] bVersion = getBinaryVersion(version); - RowType rowType = vertexHandle.getTable().rowType(); - Predicate predicate = new LeafPredicate(Equal.INSTANCE, rowType.getTypeAt(0), 0, - rowType.getField(0).name(), singletonList(concat(bVersion, key))); - RecordReaderIterator iterator = this.vertexHandle.getIterator(predicate, null, - projection); - try (PaimonIterator paimonIterator = new PaimonIterator(iterator)) { - if (paimonIterator.hasNext()) { - Tuple row = paimonIterator.next(); - IVertex vertex = vertexEncoder.getVertex(key, row.getF1()); - if (pushdown == null || ((IGraphFilter) pushdown.getFilter()).filterVertex( - vertex)) { - return vertex; - } - } - return null; +public class PaimonGraphMultiVersionedProxy + implements IGraphMultiVersionedPaimonProxy { + + private static final int VERSION_BYTES_SIZE = Long.BYTES; + private static final int VERTEX_INDEX_SUFFIX_SIZE = + VERSION_BYTES_SIZE + StateConfigKeys.DELIMITER.length; + protected static final byte[] EMPTY_BYTES = new byte[0]; + protected final Configuration config; + protected IGraphKVEncoder encoder; + protected IEdgeKVEncoder edgeEncoder; + protected IVertexKVEncoder vertexEncoder; + + protected PaimonTableRWHandle vertexHandle; + protected PaimonTableRWHandle vertexIndexHandle; + protected PaimonTableRWHandle edgeHandle; + + protected int[] projection; + protected long lastCheckpointId; + + public PaimonGraphMultiVersionedProxy( + PaimonTableRWHandle vertexHandle, + PaimonTableRWHandle vertexIndexHandle, + PaimonTableRWHandle edgeHandle, + int[] projection, + IGraphKVEncoder encoder, + Configuration config) { + this.vertexHandle = vertexHandle; + this.vertexIndexHandle = vertexIndexHandle; + this.edgeHandle = edgeHandle; + this.projection = projection; + this.encoder = encoder; + this.vertexEncoder = encoder.getVertexEncoder(); + this.edgeEncoder = encoder.getEdgeEncoder(); + this.config = config; + this.lastCheckpointId = 0; + } + + @Override + public void archive(long checkpointId) { + this.lastCheckpointId = checkpointId; + this.vertexHandle.commit(lastCheckpointId); + this.vertexIndexHandle.commit(lastCheckpointId); + this.edgeHandle.commit(lastCheckpointId); + } + + @Override + public void recover(long checkpointId) { + throw new UnsupportedOperationException(); + } + + @Override + public long recoverLatest() { + throw new UnsupportedOperationException(); + } + + @Override + public void flush() {} + + @Override + public void close() {} + + @Override + public void addEdge(long version, IEdge edge) { + Tuple tuple = edgeEncoder.format(edge); + byte[] bVersion = getBinaryVersion(version); + GenericRow record = GenericRow.of(concat(bVersion, tuple.f0), tuple.f1); + this.edgeHandle.write(record); + } + + @Override + public void addVertex(long version, IVertex vertex) { + Tuple tuple = vertexEncoder.format(vertex); + byte[] bVersion = getBinaryVersion(version); + GenericRow record = GenericRow.of(concat(bVersion, tuple.f0), tuple.f1); + GenericRow index = GenericRow.of(concat(tuple.f0, bVersion), EMPTY_BYTES); + this.vertexHandle.write(record); + this.vertexIndexHandle.write(index); + } + + @Override + public IVertex getVertex(long version, K sid, IStatePushDown pushdown) { + byte[] key = encoder.getKeyType().serialize(sid); + byte[] bVersion = getBinaryVersion(version); + RowType rowType = vertexHandle.getTable().rowType(); + Predicate predicate = + new LeafPredicate( + Equal.INSTANCE, + rowType.getTypeAt(0), + 0, + rowType.getField(0).name(), + singletonList(concat(bVersion, key))); + RecordReaderIterator iterator = + this.vertexHandle.getIterator(predicate, null, projection); + try (PaimonIterator paimonIterator = new PaimonIterator(iterator)) { + if (paimonIterator.hasNext()) { + Tuple row = paimonIterator.next(); + IVertex vertex = vertexEncoder.getVertex(key, row.getF1()); + if (pushdown == null || ((IGraphFilter) pushdown.getFilter()).filterVertex(vertex)) { + return vertex; } + } + return null; } - - @Override - public List> getEdges(long version, K sid, IStatePushDown pushdown) { - byte[] bVersion = getBinaryVersion(version); - byte[] prefixBytes = concat(bVersion, edgeEncoder.getScanBytes(sid)); - RowType rowType = edgeHandle.getTable().rowType(); - Predicate predicate = new LeafPredicate(BytesStartsWith.INSTANCE, rowType.getTypeAt(0), 0, - rowType.getField(0).name(), singletonList(prefixBytes)); - RecordReaderIterator iterator = this.edgeHandle.getIterator(predicate, null, - projection); - List> edges = new ArrayList<>(); - try (PaimonIterator paimonIterator = new PaimonIterator(iterator)) { - IGraphFilter graphFilter = GraphFilter.of(pushdown.getFilter(), - pushdown.getEdgeLimit()); - while (paimonIterator.hasNext()) { - Tuple row = paimonIterator.next(); - IEdge edge = edgeEncoder.getEdge(getKeyFromVersionToKey(row.f0), row.f1); - if (graphFilter.filterEdge(edge)) { - edges.add(edge); - } - } - return edges; + } + + @Override + public List> getEdges(long version, K sid, IStatePushDown pushdown) { + byte[] bVersion = getBinaryVersion(version); + byte[] prefixBytes = concat(bVersion, edgeEncoder.getScanBytes(sid)); + RowType rowType = edgeHandle.getTable().rowType(); + Predicate predicate = + new LeafPredicate( + BytesStartsWith.INSTANCE, + rowType.getTypeAt(0), + 0, + rowType.getField(0).name(), + singletonList(prefixBytes)); + RecordReaderIterator iterator = + this.edgeHandle.getIterator(predicate, null, projection); + List> edges = new ArrayList<>(); + try (PaimonIterator paimonIterator = new PaimonIterator(iterator)) { + IGraphFilter graphFilter = GraphFilter.of(pushdown.getFilter(), pushdown.getEdgeLimit()); + while (paimonIterator.hasNext()) { + Tuple row = paimonIterator.next(); + IEdge edge = edgeEncoder.getEdge(getKeyFromVersionToKey(row.f0), row.f1); + if (graphFilter.filterEdge(edge)) { + edges.add(edge); } + } + return edges; } - - @Override - public OneDegreeGraph getOneDegreeGraph(long version, K sid, - IStatePushDown pushdown) { - IVertex vertex = getVertex(version, sid, pushdown); - List> edgeList = getEdges(version, sid, pushdown); - OneDegreeGraph oneDegreeGraph = new OneDegreeGraph<>(sid, vertex, - IteratorWithClose.wrap(edgeList.iterator())); - if (((IGraphFilter) pushdown.getFilter()).filterOneDegreeGraph(oneDegreeGraph)) { - return oneDegreeGraph; - } else { - return null; - } + } + + @Override + public OneDegreeGraph getOneDegreeGraph(long version, K sid, IStatePushDown pushdown) { + IVertex vertex = getVertex(version, sid, pushdown); + List> edgeList = getEdges(version, sid, pushdown); + OneDegreeGraph oneDegreeGraph = + new OneDegreeGraph<>(sid, vertex, IteratorWithClose.wrap(edgeList.iterator())); + if (((IGraphFilter) pushdown.getFilter()).filterOneDegreeGraph(oneDegreeGraph)) { + return oneDegreeGraph; + } else { + return null; } - - @Override - public CloseableIterator vertexIDIterator() { - flush(); - - RecordReaderIterator iterator = this.vertexIndexHandle.getIterator(null, null, - projection); - - return new IteratorWithFnThenFilter<>(iterator, - tuple2 -> vertexEncoder.getVertexID(getKeyFromKeyToVersion(tuple2.getBinary(0))), - new DeDupPredicate<>()); + } + + @Override + public CloseableIterator vertexIDIterator() { + flush(); + + RecordReaderIterator iterator = + this.vertexIndexHandle.getIterator(null, null, projection); + + return new IteratorWithFnThenFilter<>( + iterator, + tuple2 -> vertexEncoder.getVertexID(getKeyFromKeyToVersion(tuple2.getBinary(0))), + new DeDupPredicate<>()); + } + + @Override + public CloseableIterator vertexIDIterator(long version, IStatePushDown pushdown) { + if (pushdown.getFilter() == null) { + flush(); + byte[] prefix = getVersionPrefix(version); + RowType rowType = vertexIndexHandle.getTable().rowType(); + Predicate predicate = + new LeafPredicate( + BytesStartsWith.INSTANCE, + rowType.getTypeAt(0), + 0, + rowType.getField(0).name(), + singletonList(prefix)); + RecordReaderIterator iterator = + this.vertexIndexHandle.getIterator(predicate, null, projection); + return new IteratorWithFnThenFilter<>( + iterator, + tuple2 -> vertexEncoder.getVertexID(getKeyFromVersionToKey(tuple2.getBinary(0))), + new DeDupPredicate<>()); + + } else { + return new IteratorWithFn<>(getVertexIterator(version, pushdown), IVertex::getId); } - - @Override - public CloseableIterator vertexIDIterator(long version, IStatePushDown pushdown) { - if (pushdown.getFilter() == null) { - flush(); - byte[] prefix = getVersionPrefix(version); - RowType rowType = vertexIndexHandle.getTable().rowType(); - Predicate predicate = new LeafPredicate(BytesStartsWith.INSTANCE, rowType.getTypeAt(0), - 0, rowType.getField(0).name(), singletonList(prefix)); - RecordReaderIterator iterator = this.vertexIndexHandle.getIterator( - predicate, null, projection); - return new IteratorWithFnThenFilter<>(iterator, - tuple2 -> vertexEncoder.getVertexID(getKeyFromVersionToKey(tuple2.getBinary(0))), - new DeDupPredicate<>()); - - } else { - return new IteratorWithFn<>(getVertexIterator(version, pushdown), IVertex::getId); + } + + @Override + public CloseableIterator> getVertexIterator( + long version, IStatePushDown pushdown) { + flush(); + byte[] prefix = getVersionPrefix(version); + RowType rowType = vertexHandle.getTable().rowType(); + Predicate predicate = + new LeafPredicate( + BytesStartsWith.INSTANCE, + rowType.getTypeAt(0), + 0, + rowType.getField(0).name(), + singletonList(prefix)); + RecordReaderIterator iterator = + this.vertexHandle.getIterator(predicate, null, projection); + PaimonIterator it = new PaimonIterator(iterator); + return new VertexScanIterator<>( + it, pushdown, (key, value) -> vertexEncoder.getVertex(getKeyFromVersionToKey(key), value)); + } + + @Override + public CloseableIterator> getVertexIterator( + long version, List keys, IStatePushDown pushdown) { + return new KeysIterator<>(keys, (k, f) -> getVertex(version, k, f), pushdown); + } + + @Override + public CloseableIterator> getEdgeIterator(long version, IStatePushDown pushdown) { + flush(); + byte[] prefix = getVersionPrefix(version); + RowType rowType = edgeHandle.getTable().rowType(); + Predicate predicate = + new LeafPredicate( + BytesStartsWith.INSTANCE, + rowType.getTypeAt(0), + 0, + rowType.getField(0).name(), + singletonList(prefix)); + RecordReaderIterator iterator = + this.edgeHandle.getIterator(predicate, null, projection); + PaimonIterator it = new PaimonIterator(iterator); + return new EdgeScanIterator<>( + it, pushdown, (key, value) -> edgeEncoder.getEdge(getKeyFromVersionToKey(key), value)); + } + + @Override + public CloseableIterator> getEdgeIterator( + long version, List keys, IStatePushDown pushdown) { + return new IteratorWithFlatFn<>( + new KeysIterator<>(keys, (k, f) -> getEdges(version, k, f), pushdown), List::iterator); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + long version, IStatePushDown pushdown) { + flush(); + return new OneDegreeGraphScanIterator<>( + encoder.getKeyType(), + getVertexIterator(version, pushdown), + getEdgeIterator(version, pushdown), + pushdown); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + long version, List keys, IStatePushDown pushdown) { + return new KeysIterator<>(keys, (k, f) -> getOneDegreeGraph(version, k, f), pushdown); + } + + @Override + public List getAllVersions(K id, DataType dataType) { + flush(); + if (dataType == DataType.V || dataType == DataType.V_TOPO) { + List list = new ArrayList<>(); + byte[] prefix = Bytes.concat(encoder.getKeyType().serialize(id), StateConfigKeys.DELIMITER); + RowType rowType = vertexIndexHandle.getTable().rowType(); + Predicate predicate = + new LeafPredicate( + BytesStartsWith.INSTANCE, + rowType.getTypeAt(0), + 0, + rowType.getField(0).name(), + singletonList(prefix)); + RecordReaderIterator iterator = + this.vertexIndexHandle.getIterator(predicate, null, projection); + try (PaimonIterator it = new PaimonIterator(iterator)) { + while (it.hasNext()) { + Tuple pair = it.next(); + list.add(getVersionFromKeyToVersion(pair.f0)); } + } + return list; } - - @Override - public CloseableIterator> getVertexIterator(long version, - IStatePushDown pushdown) { - flush(); - byte[] prefix = getVersionPrefix(version); - RowType rowType = vertexHandle.getTable().rowType(); - Predicate predicate = new LeafPredicate(BytesStartsWith.INSTANCE, rowType.getTypeAt(0), 0, - rowType.getField(0).name(), singletonList(prefix)); - RecordReaderIterator iterator = this.vertexHandle.getIterator(predicate, null, - projection); - PaimonIterator it = new PaimonIterator(iterator); - return new VertexScanIterator<>(it, pushdown, - (key, value) -> vertexEncoder.getVertex(getKeyFromVersionToKey(key), value)); - } - - @Override - public CloseableIterator> getVertexIterator(long version, List keys, - IStatePushDown pushdown) { - return new KeysIterator<>(keys, (k, f) -> getVertex(version, k, f), pushdown); - } - - @Override - public CloseableIterator> getEdgeIterator(long version, IStatePushDown pushdown) { - flush(); - byte[] prefix = getVersionPrefix(version); - RowType rowType = edgeHandle.getTable().rowType(); - Predicate predicate = new LeafPredicate(BytesStartsWith.INSTANCE, rowType.getTypeAt(0), 0, - rowType.getField(0).name(), singletonList(prefix)); - RecordReaderIterator iterator = this.edgeHandle.getIterator(predicate, null, - projection); - PaimonIterator it = new PaimonIterator(iterator); - return new EdgeScanIterator<>(it, pushdown, - (key, value) -> edgeEncoder.getEdge(getKeyFromVersionToKey(key), value)); - } - - @Override - public CloseableIterator> getEdgeIterator(long version, List keys, - IStatePushDown pushdown) { - return new IteratorWithFlatFn<>( - new KeysIterator<>(keys, (k, f) -> getEdges(version, k, f), pushdown), List::iterator); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(long version, - IStatePushDown pushdown) { - flush(); - return new OneDegreeGraphScanIterator<>(encoder.getKeyType(), - getVertexIterator(version, pushdown), getEdgeIterator(version, pushdown), pushdown); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(long version, - List keys, - IStatePushDown pushdown) { - return new KeysIterator<>(keys, (k, f) -> getOneDegreeGraph(version, k, f), pushdown); - } - - @Override - public List getAllVersions(K id, DataType dataType) { - flush(); - if (dataType == DataType.V || dataType == DataType.V_TOPO) { - List list = new ArrayList<>(); - byte[] prefix = Bytes.concat(encoder.getKeyType().serialize(id), - StateConfigKeys.DELIMITER); - RowType rowType = vertexIndexHandle.getTable().rowType(); - Predicate predicate = new LeafPredicate(BytesStartsWith.INSTANCE, rowType.getTypeAt(0), - 0, rowType.getField(0).name(), singletonList(prefix)); - RecordReaderIterator iterator = this.vertexIndexHandle.getIterator( - predicate, null, projection); - try (PaimonIterator it = new PaimonIterator(iterator)) { - while (it.hasNext()) { - Tuple pair = it.next(); - list.add(getVersionFromKeyToVersion(pair.f0)); - } - } - return list; + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + + @Override + public long getLatestVersion(K id, DataType dataType) { + flush(); + if (dataType == DataType.V || dataType == DataType.V_TOPO) { + byte[] prefix = getKeyPrefix(id); + RowType rowType = vertexIndexHandle.getTable().rowType(); + Predicate predicate = + new LeafPredicate( + BytesStartsWith.INSTANCE, + rowType.getTypeAt(0), + 0, + rowType.getField(0).name(), + singletonList(prefix)); + RecordReaderIterator iterator = + this.vertexIndexHandle.getIterator(predicate, null, projection); + try (PaimonIterator it = new PaimonIterator(iterator)) { + if (it.hasNext()) { + Tuple pair = it.next(); + return getVersionFromKeyToVersion(pair.f0); } - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + return -1; } - - @Override - public long getLatestVersion(K id, DataType dataType) { - flush(); - if (dataType == DataType.V || dataType == DataType.V_TOPO) { - byte[] prefix = getKeyPrefix(id); - RowType rowType = vertexIndexHandle.getTable().rowType(); - Predicate predicate = new LeafPredicate(BytesStartsWith.INSTANCE, rowType.getTypeAt(0), - 0, rowType.getField(0).name(), singletonList(prefix)); - RecordReaderIterator iterator = this.vertexIndexHandle.getIterator( - predicate, null, projection); - try (PaimonIterator it = new PaimonIterator(iterator)) { - if (it.hasNext()) { - Tuple pair = it.next(); - return getVersionFromKeyToVersion(pair.f0); - } - } - return -1; + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + + @Override + public Map> getAllVersionData( + K id, IStatePushDown pushdown, DataType dataType) { + List allVersions = getAllVersions(id, dataType); + return getVersionData(id, allVersions, pushdown, dataType); + } + + @Override + public Map> getVersionData( + K id, Collection versions, IStatePushDown pushdown, DataType dataType) { + if (dataType == DataType.V || dataType == DataType.V_TOPO) { + Map> map = new HashMap<>(); + for (long version : versions) { + IVertex vertex = getVertex(version, id, pushdown); + if (vertex != null) { + map.put(version, vertex); } - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + return map; } + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } - @Override - public Map> getAllVersionData(K id, IStatePushDown pushdown, - DataType dataType) { - List allVersions = getAllVersions(id, dataType); - return getVersionData(id, allVersions, pushdown, dataType); - } + private long getVersionFromKeyToVersion(byte[] key) { + byte[] bVersion = Arrays.copyOfRange(key, key.length - 8, key.length); + return Long.MAX_VALUE - Longs.fromByteArray(bVersion); + } - @Override - public Map> getVersionData(K id, Collection versions, - IStatePushDown pushdown, DataType dataType) { - if (dataType == DataType.V || dataType == DataType.V_TOPO) { - Map> map = new HashMap<>(); - for (long version : versions) { - IVertex vertex = getVertex(version, id, pushdown); - if (vertex != null) { - map.put(version, vertex); - } - } - return map; - } - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + protected byte[] getKeyFromKeyToVersion(byte[] key) { + return Arrays.copyOf(key, key.length - VERTEX_INDEX_SUFFIX_SIZE); + } - private long getVersionFromKeyToVersion(byte[] key) { - byte[] bVersion = Arrays.copyOfRange(key, key.length - 8, key.length); - return Long.MAX_VALUE - Longs.fromByteArray(bVersion); - } + protected byte[] getBinaryVersion(long version) { + return Longs.toByteArray(Long.MAX_VALUE - version); + } - protected byte[] getKeyFromKeyToVersion(byte[] key) { - return Arrays.copyOf(key, key.length - VERTEX_INDEX_SUFFIX_SIZE); - } + protected byte[] getKeyPrefix(K id) { + return Bytes.concat(this.encoder.getKeyType().serialize(id), StateConfigKeys.DELIMITER); + } - protected byte[] getBinaryVersion(long version) { - return Longs.toByteArray(Long.MAX_VALUE - version); - } + protected byte[] getVersionPrefix(long version) { + return Bytes.concat(getBinaryVersion(version), StateConfigKeys.DELIMITER); + } - protected byte[] getKeyPrefix(K id) { - return Bytes.concat(this.encoder.getKeyType().serialize(id), StateConfigKeys.DELIMITER); - } + protected byte[] getKeyFromVersionToKey(byte[] key) { + return Arrays.copyOfRange(key, 10, key.length); + } - protected byte[] getVersionPrefix(long version) { - return Bytes.concat(getBinaryVersion(version), StateConfigKeys.DELIMITER); - } + protected byte[] concat(byte[] a, byte[] b) { + return Bytes.concat(a, StateConfigKeys.DELIMITER, b); + } - protected byte[] getKeyFromVersionToKey(byte[] key) { - return Arrays.copyOfRange(key, 10, key.length); - } + protected static class DeDupPredicate implements java.util.function.Predicate { - protected byte[] concat(byte[] a, byte[] b) { - return Bytes.concat(a, StateConfigKeys.DELIMITER, b); - } + K last = null; - protected static class DeDupPredicate implements java.util.function.Predicate { - - K last = null; - - @Override - public boolean test(K k) { - boolean res = k.equals(last); - last = k; - return !res; - } + @Override + public boolean test(K k) { + boolean res = k.equals(last); + last = k; + return !res; } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/PaimonGraphRWProxy.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/PaimonGraphRWProxy.java index 9cd55f8b6..86fdc5802 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/PaimonGraphRWProxy.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/PaimonGraphRWProxy.java @@ -23,6 +23,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.common.tuple.Tuple; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; @@ -42,73 +43,83 @@ public class PaimonGraphRWProxy extends PaimonBaseGraphProxy { - public PaimonGraphRWProxy(PaimonTableRWHandle vertexHandle, PaimonTableRWHandle edgeHandle, - int[] projection, IGraphKVEncoder encoder) { - super(vertexHandle, edgeHandle, projection, encoder); - } + public PaimonGraphRWProxy( + PaimonTableRWHandle vertexHandle, + PaimonTableRWHandle edgeHandle, + int[] projection, + IGraphKVEncoder encoder) { + super(vertexHandle, edgeHandle, projection, encoder); + } - @Override - public void archive(long checkpointId) { - this.lastCheckpointId = checkpointId; - this.vertexHandle.commit(lastCheckpointId); - this.edgeHandle.commit(lastCheckpointId); - } + @Override + public void archive(long checkpointId) { + this.lastCheckpointId = checkpointId; + this.vertexHandle.commit(lastCheckpointId); + this.edgeHandle.commit(lastCheckpointId); + } - @Override - public void recover(long checkpointId) { - throw new UnsupportedOperationException(); - } + @Override + public void recover(long checkpointId) { + throw new UnsupportedOperationException(); + } - @Override - public long recoverLatest() { - throw new UnsupportedOperationException(); - } + @Override + public long recoverLatest() { + throw new UnsupportedOperationException(); + } - @Override - public IVertex getVertex(K sid, IStatePushDown pushdown) { - byte[] key = encoder.getKeyType().serialize(sid); - RowType rowType = vertexHandle.getTable().rowType(); - Predicate predicate = new LeafPredicate(Equal.INSTANCE, rowType.getTypeAt(0), 0, - rowType.getField(0).name(), singletonList(key)); - RecordReaderIterator iterator = this.vertexHandle.getIterator(predicate, null, - projection); - try (PaimonIterator paimonIterator = new PaimonIterator(iterator)) { - if (paimonIterator.hasNext()) { - Tuple row = paimonIterator.next(); - IVertex vertex = encoder.getVertexEncoder() - .getVertex(row.getF0(), row.getF1()); - if (pushdown == null || ((IGraphFilter) pushdown.getFilter()).filterVertex( - vertex)) { - return vertex; - } - } - return null; + @Override + public IVertex getVertex(K sid, IStatePushDown pushdown) { + byte[] key = encoder.getKeyType().serialize(sid); + RowType rowType = vertexHandle.getTable().rowType(); + Predicate predicate = + new LeafPredicate( + Equal.INSTANCE, + rowType.getTypeAt(0), + 0, + rowType.getField(0).name(), + singletonList(key)); + RecordReaderIterator iterator = + this.vertexHandle.getIterator(predicate, null, projection); + try (PaimonIterator paimonIterator = new PaimonIterator(iterator)) { + if (paimonIterator.hasNext()) { + Tuple row = paimonIterator.next(); + IVertex vertex = encoder.getVertexEncoder().getVertex(row.getF0(), row.getF1()); + if (pushdown == null || ((IGraphFilter) pushdown.getFilter()).filterVertex(vertex)) { + return vertex; } + } + return null; } + } - @Override - public List> getEdges(K sid, IStatePushDown pushdown) { - byte[] prefixBytes = encoder.getEdgeEncoder().getScanBytes(sid); - RowType rowType = edgeHandle.getTable().rowType(); - Predicate predicate = new LeafPredicate(BytesStartsWith.INSTANCE, rowType.getTypeAt(0), 0, - rowType.getField(0).name(), singletonList(prefixBytes)); - RecordReaderIterator iterator = this.edgeHandle.getIterator(predicate, null, - projection); - List> edges = new ArrayList<>(); - try (PaimonIterator paimonIterator = new PaimonIterator(iterator)) { - IGraphFilter graphFilter = GraphFilter.of(pushdown.getFilter(), - pushdown.getEdgeLimit()); - while (paimonIterator.hasNext()) { - Tuple row = paimonIterator.next(); - IEdge edge = encoder.getEdgeEncoder().getEdge(row.getF0(), row.getF1()); - if (graphFilter.filterEdge(edge)) { - edges.add(edge); - } - if (graphFilter.dropAllRemaining()) { - break; - } - } - return edges; + @Override + public List> getEdges(K sid, IStatePushDown pushdown) { + byte[] prefixBytes = encoder.getEdgeEncoder().getScanBytes(sid); + RowType rowType = edgeHandle.getTable().rowType(); + Predicate predicate = + new LeafPredicate( + BytesStartsWith.INSTANCE, + rowType.getTypeAt(0), + 0, + rowType.getField(0).name(), + singletonList(prefixBytes)); + RecordReaderIterator iterator = + this.edgeHandle.getIterator(predicate, null, projection); + List> edges = new ArrayList<>(); + try (PaimonIterator paimonIterator = new PaimonIterator(iterator)) { + IGraphFilter graphFilter = GraphFilter.of(pushdown.getFilter(), pushdown.getEdgeLimit()); + while (paimonIterator.hasNext()) { + Tuple row = paimonIterator.next(); + IEdge edge = encoder.getEdgeEncoder().getEdge(row.getF0(), row.getF1()); + if (graphFilter.filterEdge(edge)) { + edges.add(edge); + } + if (graphFilter.dropAllRemaining()) { + break; } + } + return edges; } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/PaimonProxyBuilder.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/PaimonProxyBuilder.java index faf22cad0..a3ec5eae3 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/PaimonProxyBuilder.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/main/java/org/apache/geaflow/store/paimon/proxy/PaimonProxyBuilder.java @@ -26,26 +26,29 @@ public class PaimonProxyBuilder { - public static IGraphPaimonProxy build(Configuration config, - PaimonTableRWHandle vertexRWHandle, - PaimonTableRWHandle edgeRWHandle, - int[] projection, - IGraphKVEncoder encoder) { - // TODO: add readonly proxy. - if (config.getBoolean(StateConfigKeys.STATE_WRITE_ASYNC_ENABLE)) { - return new AsyncPaimonGraphRWProxy<>(vertexRWHandle, edgeRWHandle, projection, encoder, - config); - } else { - return new PaimonGraphRWProxy<>(vertexRWHandle, edgeRWHandle, projection, encoder); - } - } - - public static IGraphMultiVersionedPaimonProxy buildMultiVersioned( - Configuration config, PaimonTableRWHandle vertexRWHandle, - PaimonTableRWHandle vertexIndexRWHandle, PaimonTableRWHandle edgeRWHandle, int[] projection, - IGraphKVEncoder encoder) { - return new PaimonGraphMultiVersionedProxy<>(vertexRWHandle, vertexIndexRWHandle, - edgeRWHandle, projection, encoder, config); + public static IGraphPaimonProxy build( + Configuration config, + PaimonTableRWHandle vertexRWHandle, + PaimonTableRWHandle edgeRWHandle, + int[] projection, + IGraphKVEncoder encoder) { + // TODO: add readonly proxy. + if (config.getBoolean(StateConfigKeys.STATE_WRITE_ASYNC_ENABLE)) { + return new AsyncPaimonGraphRWProxy<>( + vertexRWHandle, edgeRWHandle, projection, encoder, config); + } else { + return new PaimonGraphRWProxy<>(vertexRWHandle, edgeRWHandle, projection, encoder); } + } + public static IGraphMultiVersionedPaimonProxy buildMultiVersioned( + Configuration config, + PaimonTableRWHandle vertexRWHandle, + PaimonTableRWHandle vertexIndexRWHandle, + PaimonTableRWHandle edgeRWHandle, + int[] projection, + IGraphKVEncoder encoder) { + return new PaimonGraphMultiVersionedProxy<>( + vertexRWHandle, vertexIndexRWHandle, edgeRWHandle, projection, encoder, config); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/test/java/org/apache/geaflow/store/paimon/PaimonRWHandleTest.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/test/java/org/apache/geaflow/store/paimon/PaimonRWHandleTest.java index a951c1834..cb1db4a48 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/test/java/org/apache/geaflow/store/paimon/PaimonRWHandleTest.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/test/java/org/apache/geaflow/store/paimon/PaimonRWHandleTest.java @@ -32,99 +32,96 @@ public class PaimonRWHandleTest { - static class GraphPaimonStoreTest extends BasePaimonStore { + static class GraphPaimonStoreTest extends BasePaimonStore { - public void init(StoreContext storeContext) { - super.init(storeContext); - } - - @Override - public void flush() { - - } - - @Override - public void archive(long checkpointId) { - - } + public void init(StoreContext storeContext) { + super.init(storeContext); + } - @Override - public void recovery(long checkpointId) { + @Override + public void flush() {} - } + @Override + public void archive(long checkpointId) {} - @Override - public long recoveryLatest() { - return 0; - } + @Override + public void recovery(long checkpointId) {} - @Override - public void compact() { + @Override + public long recoveryLatest() { + return 0; + } - } + @Override + public void compact() {} - @Override - public void drop() { - super.drop(); - } + @Override + public void drop() { + super.drop(); } - - @Test - public void testGraphStore() throws Exception { - // create edge table. - GraphPaimonStoreTest storeBase = new GraphPaimonStoreTest(); - Configuration config = new Configuration(); - config.put(ExecutionConfigKeys.JOB_APP_NAME, "test_paimon_app"); - storeBase.init(new StoreContext("test_paimon_store").withConfig(config)); - - PaimonTableRWHandle edgeHandle = storeBase.createEdgeTableHandle( - new Identifier(storeBase.paimonStoreName, "edge")); - - // 写入一条数据 - 根据 schema 定义构造完整的行数据 - String srcId = "src1"; - String targetId = "dst1"; - long timestamp = System.currentTimeMillis(); - short direction = 1; // 假设 1 表示出边 - String label = "knows"; - String value = "edge-value-123"; - - // 按照 schema 顺序构造 row: src_id, target_id, ts, direction, label, value - GenericRow row = GenericRow.of(srcId.getBytes(), // src_id (主键) - targetId.getBytes(), // target_id - timestamp, // ts - direction, // direction - label.getBytes(), // label - value.getBytes() // value - ); - - edgeHandle.write(row); - long checkpointId = 1L; - edgeHandle.commit(checkpointId); - - // 读取数据并断言 - 使用所有列的投影 - int[] projection = new int[]{0, 1, 2, 3, 4, 5}; // 所有列 - RecordReaderIterator iterator = edgeHandle.getIterator(null, null, projection); - boolean found = false; - while (iterator.hasNext()) { - InternalRow internalRow = iterator.next(); - String readSrcId = new String(internalRow.getBinary(0)); - String readTargetId = new String(internalRow.getBinary(1)); - long readTs = internalRow.getLong(2); - short readDirection = internalRow.getShort(3); - String readLabel = new String(internalRow.getBinary(4)); - String readValue = new String(internalRow.getBinary(5)); - - if (srcId.equals(readSrcId) && targetId.equals(readTargetId) && timestamp == readTs - && direction == readDirection && label.equals(readLabel) && value.equals( - readValue)) { - found = true; - break; - } - } - iterator.close(); - assertTrue(found); - - storeBase.drop(); - storeBase.close(); + } + + @Test + public void testGraphStore() throws Exception { + // create edge table. + GraphPaimonStoreTest storeBase = new GraphPaimonStoreTest(); + Configuration config = new Configuration(); + config.put(ExecutionConfigKeys.JOB_APP_NAME, "test_paimon_app"); + storeBase.init(new StoreContext("test_paimon_store").withConfig(config)); + + PaimonTableRWHandle edgeHandle = + storeBase.createEdgeTableHandle(new Identifier(storeBase.paimonStoreName, "edge")); + + // 写入一条数据 - 根据 schema 定义构造完整的行数据 + String srcId = "src1"; + String targetId = "dst1"; + long timestamp = System.currentTimeMillis(); + short direction = 1; // 假设 1 表示出边 + String label = "knows"; + String value = "edge-value-123"; + + // 按照 schema 顺序构造 row: src_id, target_id, ts, direction, label, value + GenericRow row = + GenericRow.of( + srcId.getBytes(), // src_id (主键) + targetId.getBytes(), // target_id + timestamp, // ts + direction, // direction + label.getBytes(), // label + value.getBytes() // value + ); + + edgeHandle.write(row); + long checkpointId = 1L; + edgeHandle.commit(checkpointId); + + // 读取数据并断言 - 使用所有列的投影 + int[] projection = new int[] {0, 1, 2, 3, 4, 5}; // 所有列 + RecordReaderIterator iterator = edgeHandle.getIterator(null, null, projection); + boolean found = false; + while (iterator.hasNext()) { + InternalRow internalRow = iterator.next(); + String readSrcId = new String(internalRow.getBinary(0)); + String readTargetId = new String(internalRow.getBinary(1)); + long readTs = internalRow.getLong(2); + short readDirection = internalRow.getShort(3); + String readLabel = new String(internalRow.getBinary(4)); + String readValue = new String(internalRow.getBinary(5)); + + if (srcId.equals(readSrcId) + && targetId.equals(readTargetId) + && timestamp == readTs + && direction == readDirection + && label.equals(readLabel) + && value.equals(readValue)) { + found = true; + break; + } } + iterator.close(); + assertTrue(found); + + storeBase.drop(); + storeBase.close(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/test/java/org/apache/geaflow/store/paimon/predicate/BytesStartsWithTest.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/test/java/org/apache/geaflow/store/paimon/predicate/BytesStartsWithTest.java index 802147533..899085099 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/test/java/org/apache/geaflow/store/paimon/predicate/BytesStartsWithTest.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-paimon/src/test/java/org/apache/geaflow/store/paimon/predicate/BytesStartsWithTest.java @@ -24,15 +24,13 @@ public class BytesStartsWithTest { - @Test - public void testBytesStartsWith() { - byte[] prefixBytes = "pattern".getBytes(); - Assert.assertTrue(BytesStartsWith.INSTANCE.test(null, "pattern1".getBytes(), prefixBytes)); + @Test + public void testBytesStartsWith() { + byte[] prefixBytes = "pattern".getBytes(); + Assert.assertTrue(BytesStartsWith.INSTANCE.test(null, "pattern1".getBytes(), prefixBytes)); - Assert.assertFalse(BytesStartsWith.INSTANCE.test(null, "patter".getBytes(), prefixBytes)); - - Assert.assertFalse( - BytesStartsWith.INSTANCE.test(null, "a".getBytes(), "pattern".getBytes())); - } + Assert.assertFalse(BytesStartsWith.INSTANCE.test(null, "patter".getBytes(), prefixBytes)); + Assert.assertFalse(BytesStartsWith.INSTANCE.test(null, "a".getBytes(), "pattern".getBytes())); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/BaseRedisStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/BaseRedisStore.java index 7ff3b14ff..6d1fdc58c 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/BaseRedisStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/BaseRedisStore.java @@ -19,51 +19,51 @@ package org.apache.geaflow.store.redis; -import com.google.common.primitives.Bytes; import org.apache.commons.pool2.impl.GenericObjectPoolConfig; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.store.IBaseStore; import org.apache.geaflow.store.context.StoreContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + +import com.google.common.primitives.Bytes; + import redis.clients.jedis.JedisPool; public abstract class BaseRedisStore implements IBaseStore { - protected static final Logger LOGGER = LoggerFactory.getLogger(KVRedisStore.class); - protected static final char REDIS_NAMESPACE_SPLITTER = ':'; - protected transient JedisPool jedisPool; - protected byte[] prefix; - protected int retryTimes; - protected int retryIntervalMs; - - public void init(StoreContext storeContext) { - Configuration config = storeContext.getConfig(); - this.retryTimes = config.getInteger(RedisConfigKeys.REDIS_RETRY_TIMES); - this.retryIntervalMs = config.getInteger(RedisConfigKeys.REDIS_RETRY_INTERVAL_MS); - String host = config.getString(RedisConfigKeys.REDIS_HOST); - int port = config.getInteger(RedisConfigKeys.REDIS_PORT); - LOGGER.info("redis connect {}:{}", host, port); - GenericObjectPoolConfig poolConfig = new GenericObjectPoolConfig(); - int connectTimeout = config.getInteger(RedisConfigKeys.REDIS_CONNECT_TIMEOUT); - String user = config.getString(RedisConfigKeys.REDIS_USER.getKey()); - String password = config.getString(RedisConfigKeys.REDIS_PASSWORD.getKey()); - this.jedisPool = new JedisPool(poolConfig, host, port, connectTimeout, user, password); - String prefixStr = storeContext.getName() + REDIS_NAMESPACE_SPLITTER; - this.prefix = prefixStr.getBytes(); - } + protected static final Logger LOGGER = LoggerFactory.getLogger(KVRedisStore.class); + protected static final char REDIS_NAMESPACE_SPLITTER = ':'; + protected transient JedisPool jedisPool; + protected byte[] prefix; + protected int retryTimes; + protected int retryIntervalMs; - @Override - public void flush() { + public void init(StoreContext storeContext) { + Configuration config = storeContext.getConfig(); + this.retryTimes = config.getInteger(RedisConfigKeys.REDIS_RETRY_TIMES); + this.retryIntervalMs = config.getInteger(RedisConfigKeys.REDIS_RETRY_INTERVAL_MS); + String host = config.getString(RedisConfigKeys.REDIS_HOST); + int port = config.getInteger(RedisConfigKeys.REDIS_PORT); + LOGGER.info("redis connect {}:{}", host, port); + GenericObjectPoolConfig poolConfig = new GenericObjectPoolConfig(); + int connectTimeout = config.getInteger(RedisConfigKeys.REDIS_CONNECT_TIMEOUT); + String user = config.getString(RedisConfigKeys.REDIS_USER.getKey()); + String password = config.getString(RedisConfigKeys.REDIS_PASSWORD.getKey()); + this.jedisPool = new JedisPool(poolConfig, host, port, connectTimeout, user, password); + String prefixStr = storeContext.getName() + REDIS_NAMESPACE_SPLITTER; + this.prefix = prefixStr.getBytes(); + } - } + @Override + public void flush() {} - protected byte[] getRedisKey(byte[] key) { - return Bytes.concat(prefix, key); - } + protected byte[] getRedisKey(byte[] key) { + return Bytes.concat(prefix, key); + } - @Override - public void close() { - this.jedisPool.close(); - } + @Override + public void close() { + this.jedisPool.close(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/KListRedisStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/KListRedisStore.java index d44316fdc..b61b2bf11 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/KListRedisStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/KListRedisStore.java @@ -19,68 +19,82 @@ package org.apache.geaflow.store.redis; -import com.google.common.base.Preconditions; import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.common.utils.RetryCommand; import org.apache.geaflow.state.serializer.IKVSerializer; import org.apache.geaflow.store.api.key.IKListStore; import org.apache.geaflow.store.context.StoreContext; + +import com.google.common.base.Preconditions; + import redis.clients.jedis.Jedis; public class KListRedisStore extends BaseRedisStore implements IKListStore { - private IKVSerializer kvSerializer; + private IKVSerializer kvSerializer; - @Override - public void init(StoreContext storeContext) { - super.init(storeContext); - this.kvSerializer = (IKVSerializer) Preconditions.checkNotNull( - storeContext.getKeySerializer(), "kvSerializer must be set"); - } + @Override + public void init(StoreContext storeContext) { + super.init(storeContext); + this.kvSerializer = + (IKVSerializer) + Preconditions.checkNotNull(storeContext.getKeySerializer(), "kvSerializer must be set"); + } - @Override - public void add(K key, V... values) { - byte[] keyArray = this.kvSerializer.serializeKey(key); - byte[] redisKey = getRedisKey(keyArray); - byte[][] bValues = Arrays.stream(values).map(this.kvSerializer::serializeValue) - .toArray(byte[][]::new); + @Override + public void add(K key, V... values) { + byte[] keyArray = this.kvSerializer.serializeKey(key); + byte[] redisKey = getRedisKey(keyArray); + byte[][] bValues = + Arrays.stream(values).map(this.kvSerializer::serializeValue).toArray(byte[][]::new); - RetryCommand.run(() -> { - try (Jedis jedis = jedisPool.getResource()) { - jedis.lpush(redisKey, bValues); - } - return null; - }, retryTimes, retryIntervalMs); - } + RetryCommand.run( + () -> { + try (Jedis jedis = jedisPool.getResource()) { + jedis.lpush(redisKey, bValues); + } + return null; + }, + retryTimes, + retryIntervalMs); + } - @Override - public List get(K key) { - byte[] keyArray = this.kvSerializer.serializeKey(key); - byte[] redisKey = getRedisKey(keyArray); + @Override + public List get(K key) { + byte[] keyArray = this.kvSerializer.serializeKey(key); + byte[] redisKey = getRedisKey(keyArray); - List valueArray = RetryCommand.run(() -> { - try (Jedis jedis = jedisPool.getResource()) { + List valueArray = + RetryCommand.run( + () -> { + try (Jedis jedis = jedisPool.getResource()) { return jedis.lrange(redisKey, 0, -1); - } - }, retryTimes, retryIntervalMs); - - return valueArray.stream().map(this.kvSerializer::deserializeValue) - .collect(Collectors.toList()); - } + } + }, + retryTimes, + retryIntervalMs); + return valueArray.stream() + .map(this.kvSerializer::deserializeValue) + .collect(Collectors.toList()); + } - @Override - public void remove(K key) { - byte[] keyArray = this.kvSerializer.serializeKey(key); - byte[] redisKey = getRedisKey(keyArray); + @Override + public void remove(K key) { + byte[] keyArray = this.kvSerializer.serializeKey(key); + byte[] redisKey = getRedisKey(keyArray); - RetryCommand.run(() -> { - try (Jedis jedis = jedisPool.getResource()) { - jedis.del(redisKey); - } - return null; - }, retryTimes, retryIntervalMs); - } + RetryCommand.run( + () -> { + try (Jedis jedis = jedisPool.getResource()) { + jedis.del(redisKey); + } + return null; + }, + retryTimes, + retryIntervalMs); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/KMapRedisStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/KMapRedisStore.java index f42956191..6362ef554 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/KMapRedisStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/KMapRedisStore.java @@ -19,125 +19,150 @@ package org.apache.geaflow.store.redis; -import com.google.common.base.Preconditions; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.stream.Collectors; + import org.apache.geaflow.common.utils.RetryCommand; import org.apache.geaflow.state.serializer.IKMapSerializer; import org.apache.geaflow.store.api.key.IKMapStore; import org.apache.geaflow.store.context.StoreContext; -import redis.clients.jedis.Jedis; - -public class KMapRedisStore extends BaseRedisStore implements IKMapStore { - private IKMapSerializer kMapSerializer; +import com.google.common.base.Preconditions; - @Override - public void init(StoreContext storeContext) { - super.init(storeContext); - this.kMapSerializer = (IKMapSerializer) Preconditions.checkNotNull( - storeContext.getKeySerializer(), "keySerializer must be set"); - } +import redis.clients.jedis.Jedis; - @Override - public void add(K key, Map value) { - Map newMap = new HashMap<>(value.size()); - for (Entry entry : value.entrySet()) { - byte[] ukArray = this.kMapSerializer.serializeUK(entry.getKey()); - byte[] uvArray = this.kMapSerializer.serializeUV(entry.getValue()); - newMap.put(ukArray, uvArray); - } - byte[] keyArray = this.kMapSerializer.serializeKey(key); - byte[] redisKey = getRedisKey(keyArray); - - RetryCommand.run(() -> { - try (Jedis jedis = jedisPool.getResource()) { - jedis.hset(redisKey, newMap); - } - return null; - }, retryTimes, retryIntervalMs); - } +public class KMapRedisStore extends BaseRedisStore implements IKMapStore { - @Override - public void add(K key, UK uk, UV value) { - byte[] ukArray = this.kMapSerializer.serializeUK(uk); - byte[] uvArray = this.kMapSerializer.serializeUV(value); - byte[] keyArray = this.kMapSerializer.serializeKey(key); - byte[] redisKey = getRedisKey(keyArray); - - RetryCommand.run(() -> { - try (Jedis jedis = jedisPool.getResource()) { - jedis.hset(redisKey, ukArray, uvArray); - } - return null; - }, retryTimes, retryIntervalMs); + private IKMapSerializer kMapSerializer; + + @Override + public void init(StoreContext storeContext) { + super.init(storeContext); + this.kMapSerializer = + (IKMapSerializer) + Preconditions.checkNotNull( + storeContext.getKeySerializer(), "keySerializer must be set"); + } + + @Override + public void add(K key, Map value) { + Map newMap = new HashMap<>(value.size()); + for (Entry entry : value.entrySet()) { + byte[] ukArray = this.kMapSerializer.serializeUK(entry.getKey()); + byte[] uvArray = this.kMapSerializer.serializeUV(entry.getValue()); + newMap.put(ukArray, uvArray); } - - @Override - public Map get(K key) { - byte[] keyArray = this.kMapSerializer.serializeKey(key); - byte[] redisKey = getRedisKey(keyArray); - - Map map = RetryCommand.run(() -> { - try (Jedis jedis = jedisPool.getResource()) { + byte[] keyArray = this.kMapSerializer.serializeKey(key); + byte[] redisKey = getRedisKey(keyArray); + + RetryCommand.run( + () -> { + try (Jedis jedis = jedisPool.getResource()) { + jedis.hset(redisKey, newMap); + } + return null; + }, + retryTimes, + retryIntervalMs); + } + + @Override + public void add(K key, UK uk, UV value) { + byte[] ukArray = this.kMapSerializer.serializeUK(uk); + byte[] uvArray = this.kMapSerializer.serializeUV(value); + byte[] keyArray = this.kMapSerializer.serializeKey(key); + byte[] redisKey = getRedisKey(keyArray); + + RetryCommand.run( + () -> { + try (Jedis jedis = jedisPool.getResource()) { + jedis.hset(redisKey, ukArray, uvArray); + } + return null; + }, + retryTimes, + retryIntervalMs); + } + + @Override + public Map get(K key) { + byte[] keyArray = this.kMapSerializer.serializeKey(key); + byte[] redisKey = getRedisKey(keyArray); + + Map map = + RetryCommand.run( + () -> { + try (Jedis jedis = jedisPool.getResource()) { return jedis.hgetAll(redisKey); - } - }, retryTimes, retryIntervalMs); - - Map newMap = new HashMap<>(map.size()); - for (Entry entry : map.entrySet()) { - newMap.put(this.kMapSerializer.deserializeUK(entry.getKey()), - this.kMapSerializer.deserializeUV(entry.getValue())); - } - return newMap; + } + }, + retryTimes, + retryIntervalMs); + + Map newMap = new HashMap<>(map.size()); + for (Entry entry : map.entrySet()) { + newMap.put( + this.kMapSerializer.deserializeUK(entry.getKey()), + this.kMapSerializer.deserializeUV(entry.getValue())); } - - @Override - public List get(K key, UK... uks) { - byte[] keyArray = this.kMapSerializer.serializeKey(key); - byte[] redisKey = getRedisKey(keyArray); - byte[][] ukArray = Arrays.stream(uks).map(this.kMapSerializer::serializeUK) - .toArray(byte[][]::new); - - List uvArray = RetryCommand.run(() -> { - try (Jedis jedis = jedisPool.getResource()) { + return newMap; + } + + @Override + public List get(K key, UK... uks) { + byte[] keyArray = this.kMapSerializer.serializeKey(key); + byte[] redisKey = getRedisKey(keyArray); + byte[][] ukArray = + Arrays.stream(uks).map(this.kMapSerializer::serializeUK).toArray(byte[][]::new); + + List uvArray = + RetryCommand.run( + () -> { + try (Jedis jedis = jedisPool.getResource()) { return jedis.hmget(redisKey, ukArray); - } - }, retryTimes, retryIntervalMs); - - return uvArray.stream().map(this.kMapSerializer::deserializeUV) - .collect(Collectors.toList()); - } - - @Override - public void remove(K key) { - byte[] keyArray = this.kMapSerializer.serializeKey(key); - byte[] redisKey = getRedisKey(keyArray); - - RetryCommand.run(() -> { - try (Jedis jedis = jedisPool.getResource()) { - jedis.del(redisKey); - } - return null; - }, retryTimes, retryIntervalMs); - } - - @Override - public void remove(K key, UK... uks) { - byte[] keyArray = this.kMapSerializer.serializeKey(key); - byte[] redisKey = getRedisKey(keyArray); - byte[][] ukArray = Arrays.stream(uks).map(this.kMapSerializer::serializeUK) - .toArray(byte[][]::new); - - RetryCommand.run(() -> { - try (Jedis jedis = jedisPool.getResource()) { - jedis.hdel(redisKey, ukArray); - } - return null; - }, retryTimes, retryIntervalMs); - } + } + }, + retryTimes, + retryIntervalMs); + + return uvArray.stream().map(this.kMapSerializer::deserializeUV).collect(Collectors.toList()); + } + + @Override + public void remove(K key) { + byte[] keyArray = this.kMapSerializer.serializeKey(key); + byte[] redisKey = getRedisKey(keyArray); + + RetryCommand.run( + () -> { + try (Jedis jedis = jedisPool.getResource()) { + jedis.del(redisKey); + } + return null; + }, + retryTimes, + retryIntervalMs); + } + + @Override + public void remove(K key, UK... uks) { + byte[] keyArray = this.kMapSerializer.serializeKey(key); + byte[] redisKey = getRedisKey(keyArray); + byte[][] ukArray = + Arrays.stream(uks).map(this.kMapSerializer::serializeUK).toArray(byte[][]::new); + + RetryCommand.run( + () -> { + try (Jedis jedis = jedisPool.getResource()) { + jedis.hdel(redisKey, ukArray); + } + return null; + }, + retryTimes, + retryIntervalMs); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/KVRedisStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/KVRedisStore.java index ef15f90a1..31d58b37f 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/KVRedisStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/KVRedisStore.java @@ -19,64 +19,77 @@ package org.apache.geaflow.store.redis; -import com.google.common.base.Preconditions; import org.apache.geaflow.common.utils.RetryCommand; import org.apache.geaflow.state.serializer.IKVSerializer; import org.apache.geaflow.store.api.key.IKVStore; import org.apache.geaflow.store.context.StoreContext; + +import com.google.common.base.Preconditions; + import redis.clients.jedis.Jedis; public class KVRedisStore extends BaseRedisStore implements IKVStore { - private IKVSerializer kvSerializer; + private IKVSerializer kvSerializer; - @Override - public void init(StoreContext storeContext) { - super.init(storeContext); - this.kvSerializer = (IKVSerializer) Preconditions.checkNotNull( - storeContext.getKeySerializer(), "keySerializer must be set"); - } - - @Override - public void put(K key, V value) { - byte[] keyArray = this.kvSerializer.serializeKey(key); - byte[] redisKey = getRedisKey(keyArray); - byte[] valueArray = this.kvSerializer.serializeValue(value); - RetryCommand.run(() -> { - try (Jedis jedis = jedisPool.getResource()) { - jedis.set(redisKey, valueArray); - } - return null; - }, retryTimes, retryIntervalMs); + @Override + public void init(StoreContext storeContext) { + super.init(storeContext); + this.kvSerializer = + (IKVSerializer) + Preconditions.checkNotNull( + storeContext.getKeySerializer(), "keySerializer must be set"); + } - } + @Override + public void put(K key, V value) { + byte[] keyArray = this.kvSerializer.serializeKey(key); + byte[] redisKey = getRedisKey(keyArray); + byte[] valueArray = this.kvSerializer.serializeValue(value); + RetryCommand.run( + () -> { + try (Jedis jedis = jedisPool.getResource()) { + jedis.set(redisKey, valueArray); + } + return null; + }, + retryTimes, + retryIntervalMs); + } - @Override - public V get(K key) { - byte[] keyArray = this.kvSerializer.serializeKey(key); - byte[] redisKey = getRedisKey(keyArray); - byte[] valueArray = RetryCommand.run(() -> { - try (Jedis jedis = jedisPool.getResource()) { + @Override + public V get(K key) { + byte[] keyArray = this.kvSerializer.serializeKey(key); + byte[] redisKey = getRedisKey(keyArray); + byte[] valueArray = + RetryCommand.run( + () -> { + try (Jedis jedis = jedisPool.getResource()) { return jedis.get(redisKey); - } - }, retryTimes, retryIntervalMs); + } + }, + retryTimes, + retryIntervalMs); - if (valueArray == null) { - return null; - } - return this.kvSerializer.deserializeValue(valueArray); + if (valueArray == null) { + return null; } + return this.kvSerializer.deserializeValue(valueArray); + } - @Override - public void remove(K key) { - byte[] keyArray = this.kvSerializer.serializeKey(key); - byte[] redisKey = getRedisKey(keyArray); + @Override + public void remove(K key) { + byte[] keyArray = this.kvSerializer.serializeKey(key); + byte[] redisKey = getRedisKey(keyArray); - RetryCommand.run(() -> { - try (Jedis jedis = jedisPool.getResource()) { - jedis.del(redisKey); - } - return null; - }, retryTimes, retryIntervalMs); - } + RetryCommand.run( + () -> { + try (Jedis jedis = jedisPool.getResource()) { + jedis.del(redisKey); + } + return null; + }, + retryTimes, + retryIntervalMs); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/RedisConfigKeys.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/RedisConfigKeys.java index b470553eb..b9dee590e 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/RedisConfigKeys.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/RedisConfigKeys.java @@ -24,39 +24,38 @@ public class RedisConfigKeys { - public static final ConfigKey REDIS_HOST = ConfigKeys - .key("geaflow.store.redis.host") - .defaultValue("127.0.0.1") - .description("geaflow store redis server host"); - - public static final ConfigKey REDIS_PORT = ConfigKeys - .key("geaflow.store.redis.port") - .defaultValue(6379) - .description("geaflow store redis server port"); - - public static final ConfigKey REDIS_RETRY_TIMES = ConfigKeys - .key("geaflow.store.redis.retry.times") - .defaultValue(10) - .description("geaflow store redis retry times"); - - public static final ConfigKey REDIS_RETRY_INTERVAL_MS = ConfigKeys - .key("geaflow.store.redis.retry.interval.ms") - .defaultValue(500) - .description("geaflow store redis retry interval ms"); - - public static final ConfigKey REDIS_USER = ConfigKeys - .key("geaflow.store.redis.user") - .defaultValue("") - .description("redis connect user name"); - - public static final ConfigKey REDIS_PASSWORD = ConfigKeys - .key("geaflow.store.redis.password") - .defaultValue("") - .description("redis connect password"); - - public static final ConfigKey REDIS_CONNECT_TIMEOUT = ConfigKeys - .key("geaflow.store.redis.connection.timeout") - .defaultValue(5000) - .description("redis connect timeout in ms"); - + public static final ConfigKey REDIS_HOST = + ConfigKeys.key("geaflow.store.redis.host") + .defaultValue("127.0.0.1") + .description("geaflow store redis server host"); + + public static final ConfigKey REDIS_PORT = + ConfigKeys.key("geaflow.store.redis.port") + .defaultValue(6379) + .description("geaflow store redis server port"); + + public static final ConfigKey REDIS_RETRY_TIMES = + ConfigKeys.key("geaflow.store.redis.retry.times") + .defaultValue(10) + .description("geaflow store redis retry times"); + + public static final ConfigKey REDIS_RETRY_INTERVAL_MS = + ConfigKeys.key("geaflow.store.redis.retry.interval.ms") + .defaultValue(500) + .description("geaflow store redis retry interval ms"); + + public static final ConfigKey REDIS_USER = + ConfigKeys.key("geaflow.store.redis.user") + .defaultValue("") + .description("redis connect user name"); + + public static final ConfigKey REDIS_PASSWORD = + ConfigKeys.key("geaflow.store.redis.password") + .defaultValue("") + .description("redis connect password"); + + public static final ConfigKey REDIS_CONNECT_TIMEOUT = + ConfigKeys.key("geaflow.store.redis.connection.timeout") + .defaultValue(5000) + .description("redis connect timeout in ms"); } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/RedisStoreBuilder.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/RedisStoreBuilder.java index 029d893d5..78c65c1a0 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/RedisStoreBuilder.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/main/java/org/apache/geaflow/store/redis/RedisStoreBuilder.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -32,41 +33,41 @@ public class RedisStoreBuilder implements IStoreBuilder { - private static final StoreDesc STORE_DESC = new RedisStoreDesc(); + private static final StoreDesc STORE_DESC = new RedisStoreDesc(); - public IBaseStore getStore(DataModel type, Configuration config) { - switch (type) { - case KV: - return new KVRedisStore<>(); - case KList: - return new KListRedisStore<>(); - case KMap: - return new KMapRedisStore<>(); - default: - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError("not support " + type)); - } + public IBaseStore getStore(DataModel type, Configuration config) { + switch (type) { + case KV: + return new KVRedisStore<>(); + case KList: + return new KListRedisStore<>(); + case KMap: + return new KMapRedisStore<>(); + default: + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError("not support " + type)); } + } - @Override - public StoreDesc getStoreDesc() { - return STORE_DESC; - } + @Override + public StoreDesc getStoreDesc() { + return STORE_DESC; + } - @Override - public List supportedDataModel() { - return Arrays.asList(DataModel.KV, DataModel.KMap); - } + @Override + public List supportedDataModel() { + return Arrays.asList(DataModel.KV, DataModel.KMap); + } - public static class RedisStoreDesc implements StoreDesc { + public static class RedisStoreDesc implements StoreDesc { - @Override - public boolean isLocalStore() { - return false; - } + @Override + public boolean isLocalStore() { + return false; + } - @Override - public String name() { - return StoreType.REDIS.name(); - } + @Override + public String name() { + return StoreType.REDIS.name(); } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/test/java/org/apache/geaflow/store/redis/RedisStoreBuilderTest.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/test/java/org/apache/geaflow/store/redis/RedisStoreBuilderTest.java index c976d685a..b9feed915 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/test/java/org/apache/geaflow/store/redis/RedisStoreBuilderTest.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-redis/src/test/java/org/apache/geaflow/store/redis/RedisStoreBuilderTest.java @@ -19,7 +19,6 @@ package org.apache.geaflow.store.redis; -import com.github.fppt.jedismock.RedisServer; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -30,6 +29,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.state.DataModel; import org.apache.geaflow.state.StoreType; @@ -47,167 +47,172 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -public class RedisStoreBuilderTest { - - private RedisServer redisServer; - - @BeforeClass - public void prepare() throws IOException { - redisServer = RedisServer.newRedisServer().start(); - } +import com.github.fppt.jedismock.RedisServer; - @AfterClass - public void tearUp() throws IOException { - redisServer.stop(); - } +public class RedisStoreBuilderTest { - @Test - public void testMultiThread() throws ExecutionException, InterruptedException { - IStoreBuilder builder = StoreBuilderFactory.build(StoreType.REDIS.name()); - IKVStore kvStore = (IKVStore) builder.getStore(DataModel.KV, - new Configuration()); - Configuration configuration = new Configuration(); - configuration.put(RedisConfigKeys.REDIS_HOST, redisServer.getHost()); - configuration.put(RedisConfigKeys.REDIS_PORT, String.valueOf(redisServer.getBindPort())); - StoreContext storeContext = new StoreContext("redis").withConfig(configuration); - storeContext.withKeySerializer(new IKVSerializer() { - @Override - public byte[] serializeKey(String key) { - return key.getBytes(); - } - - @Override - public String deserializeKey(byte[] array) { - return new String(array); - } - - @Override - public byte[] serializeValue(String value) { - return value.getBytes(); - } - - @Override - public String deserializeValue(byte[] valueArray) { - return new String(valueArray); - } + private RedisServer redisServer; + + @BeforeClass + public void prepare() throws IOException { + redisServer = RedisServer.newRedisServer().start(); + } + + @AfterClass + public void tearUp() throws IOException { + redisServer.stop(); + } + + @Test + public void testMultiThread() throws ExecutionException, InterruptedException { + IStoreBuilder builder = StoreBuilderFactory.build(StoreType.REDIS.name()); + IKVStore kvStore = + (IKVStore) builder.getStore(DataModel.KV, new Configuration()); + Configuration configuration = new Configuration(); + configuration.put(RedisConfigKeys.REDIS_HOST, redisServer.getHost()); + configuration.put(RedisConfigKeys.REDIS_PORT, String.valueOf(redisServer.getBindPort())); + StoreContext storeContext = new StoreContext("redis").withConfig(configuration); + storeContext.withKeySerializer( + new IKVSerializer() { + @Override + public byte[] serializeKey(String key) { + return key.getBytes(); + } + + @Override + public String deserializeKey(byte[] array) { + return new String(array); + } + + @Override + public byte[] serializeValue(String value) { + return value.getBytes(); + } + + @Override + public String deserializeValue(byte[] valueArray) { + return new String(valueArray); + } }); - kvStore.init(storeContext); + kvStore.init(storeContext); - ExecutorService executors = Executors.newFixedThreadPool(10); + ExecutorService executors = Executors.newFixedThreadPool(10); - List futureList = new ArrayList<>(); - for (int i = 0; i < 20; i++) { - final int index = i; - Future future = executors.submit(() -> { + List futureList = new ArrayList<>(); + for (int i = 0; i < 20; i++) { + final int index = i; + Future future = + executors.submit( + () -> { for (int j = 0; j < 1000; j++) { - kvStore.put(index + "hello" + j, index + "world" + j); - Assert.assertEquals(kvStore.get(index + "hello" + j), index + "world" + j); + kvStore.put(index + "hello" + j, index + "world" + j); + Assert.assertEquals(kvStore.get(index + "hello" + j), index + "world" + j); } - }); - futureList.add(future); - } - for (Future f : futureList) { - f.get(); - } + }); + futureList.add(future); } - - @Test - public void testKV() { - IStoreBuilder builder = StoreBuilderFactory.build(StoreType.REDIS.name()); - IKVStore kvStore = (IKVStore) builder.getStore(DataModel.KV, - new Configuration()); - - Configuration configuration = new Configuration(); - configuration.put(RedisConfigKeys.REDIS_HOST, redisServer.getHost()); - configuration.put(RedisConfigKeys.REDIS_PORT, String.valueOf(redisServer.getBindPort())); - StoreContext storeContext = new StoreContext("redis").withConfig(configuration); - storeContext.withKeySerializer(new IKVSerializer() { - @Override - public byte[] serializeKey(String key) { - return key.getBytes(); - } - - @Override - public String deserializeKey(byte[] array) { - return new String(array); - } - - @Override - public byte[] serializeValue(String value) { - return value.getBytes(); - } - - @Override - public String deserializeValue(byte[] valueArray) { - return new String(valueArray); - } - }); - - kvStore.init(storeContext); - kvStore.put("hello", "world"); - kvStore.put("foo", "bar"); - - Assert.assertEquals(kvStore.get("hello"), "world"); - Assert.assertEquals(kvStore.get("foo"), "bar"); - - kvStore.remove("foo"); - Assert.assertNull(kvStore.get("foo")); - } - - @Test - public void testKMap() { - IStoreBuilder builder = StoreBuilderFactory.build(StoreType.REDIS.name()); - IKMapStore kMapStore = - (IKMapStore) builder.getStore( - DataModel.KMap, new Configuration()); - - Configuration configuration = new Configuration(); - configuration.put(RedisConfigKeys.REDIS_HOST, redisServer.getHost()); - configuration.put(RedisConfigKeys.REDIS_PORT, String.valueOf(redisServer.getBindPort())); - StoreContext storeContext = new StoreContext("redis").withConfig(configuration); - storeContext.withKeySerializer( - new DefaultKMapSerializer<>(String.class, String.class, String.class)); - kMapStore.init(storeContext); - - Map map = new HashMap<>(); - map.put("hello", "world"); - map.put("hello1", "world1"); - - kMapStore.add("hw", map); - - map.clear(); - map.put("foo", "bar"); - kMapStore.add("hw", map); - kMapStore.add("hw", "bar", "foo"); - - Assert.assertEquals(kMapStore.get("hw").size(), 4); - Assert.assertEquals(kMapStore.get("hw", "foo", "bar"), Arrays.asList("bar", "foo")); - - kMapStore.remove("hw", "bar"); - Assert.assertEquals(kMapStore.get("hw").size(), 3); - - kMapStore.remove("hw"); - Assert.assertEquals(kMapStore.get("hw").size(), 0); + for (Future f : futureList) { + f.get(); } + } + + @Test + public void testKV() { + IStoreBuilder builder = StoreBuilderFactory.build(StoreType.REDIS.name()); + IKVStore kvStore = + (IKVStore) builder.getStore(DataModel.KV, new Configuration()); + + Configuration configuration = new Configuration(); + configuration.put(RedisConfigKeys.REDIS_HOST, redisServer.getHost()); + configuration.put(RedisConfigKeys.REDIS_PORT, String.valueOf(redisServer.getBindPort())); + StoreContext storeContext = new StoreContext("redis").withConfig(configuration); + storeContext.withKeySerializer( + new IKVSerializer() { + @Override + public byte[] serializeKey(String key) { + return key.getBytes(); + } + + @Override + public String deserializeKey(byte[] array) { + return new String(array); + } + + @Override + public byte[] serializeValue(String value) { + return value.getBytes(); + } + + @Override + public String deserializeValue(byte[] valueArray) { + return new String(valueArray); + } + }); - @Test - public void testKList() { - IStoreBuilder builder = StoreBuilderFactory.build(StoreType.REDIS.name()); - IKListStore kListStore = (IKListStore) builder.getStore( - DataModel.KList, new Configuration()); - - Configuration configuration = new Configuration(); - configuration.put(RedisConfigKeys.REDIS_HOST, redisServer.getHost()); - configuration.put(RedisConfigKeys.REDIS_PORT, String.valueOf(redisServer.getBindPort())); - StoreContext storeContext = new StoreContext("redis").withConfig(configuration); - storeContext.withKeySerializer(new DefaultKVSerializer<>(String.class, String.class)); - kListStore.init(storeContext); - - kListStore.add("hw", "foo", "bar"); - kListStore.add("hw", "hello"); - - Assert.assertEquals(kListStore.get("hw").size(), 3); - kListStore.remove("hw"); - Assert.assertEquals(kListStore.get("hw").size(), 0); - } + kvStore.init(storeContext); + kvStore.put("hello", "world"); + kvStore.put("foo", "bar"); + + Assert.assertEquals(kvStore.get("hello"), "world"); + Assert.assertEquals(kvStore.get("foo"), "bar"); + + kvStore.remove("foo"); + Assert.assertNull(kvStore.get("foo")); + } + + @Test + public void testKMap() { + IStoreBuilder builder = StoreBuilderFactory.build(StoreType.REDIS.name()); + IKMapStore kMapStore = + (IKMapStore) builder.getStore(DataModel.KMap, new Configuration()); + + Configuration configuration = new Configuration(); + configuration.put(RedisConfigKeys.REDIS_HOST, redisServer.getHost()); + configuration.put(RedisConfigKeys.REDIS_PORT, String.valueOf(redisServer.getBindPort())); + StoreContext storeContext = new StoreContext("redis").withConfig(configuration); + storeContext.withKeySerializer( + new DefaultKMapSerializer<>(String.class, String.class, String.class)); + kMapStore.init(storeContext); + + Map map = new HashMap<>(); + map.put("hello", "world"); + map.put("hello1", "world1"); + + kMapStore.add("hw", map); + + map.clear(); + map.put("foo", "bar"); + kMapStore.add("hw", map); + kMapStore.add("hw", "bar", "foo"); + + Assert.assertEquals(kMapStore.get("hw").size(), 4); + Assert.assertEquals(kMapStore.get("hw", "foo", "bar"), Arrays.asList("bar", "foo")); + + kMapStore.remove("hw", "bar"); + Assert.assertEquals(kMapStore.get("hw").size(), 3); + + kMapStore.remove("hw"); + Assert.assertEquals(kMapStore.get("hw").size(), 0); + } + + @Test + public void testKList() { + IStoreBuilder builder = StoreBuilderFactory.build(StoreType.REDIS.name()); + IKListStore kListStore = + (IKListStore) builder.getStore(DataModel.KList, new Configuration()); + + Configuration configuration = new Configuration(); + configuration.put(RedisConfigKeys.REDIS_HOST, redisServer.getHost()); + configuration.put(RedisConfigKeys.REDIS_PORT, String.valueOf(redisServer.getBindPort())); + StoreContext storeContext = new StoreContext("redis").withConfig(configuration); + storeContext.withKeySerializer(new DefaultKVSerializer<>(String.class, String.class)); + kListStore.init(storeContext); + + kListStore.add("hw", "foo", "bar"); + kListStore.add("hw", "hello"); + + Assert.assertEquals(kListStore.get("hw").size(), 3); + kListStore.remove("hw"); + Assert.assertEquals(kListStore.get("hw").size(), 0); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/BaseRocksdbGraphStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/BaseRocksdbGraphStore.java index 2de72dbd0..16bdf5804 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/BaseRocksdbGraphStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/BaseRocksdbGraphStore.java @@ -21,6 +21,7 @@ import java.nio.file.Path; import java.nio.file.Paths; + import org.apache.geaflow.state.pushdown.inner.CodeGenFilterConverter; import org.apache.geaflow.state.pushdown.inner.DirectFilterConverter; import org.apache.geaflow.state.pushdown.inner.IFilterConverter; @@ -30,24 +31,23 @@ public abstract class BaseRocksdbGraphStore extends BaseRocksdbStore implements IPushDownStore { - protected IFilterConverter filterConverter; - - @Override - public void init(StoreContext storeContext) { - super.init(storeContext); - boolean codegenEnable = - storeContext.getConfig().getBoolean(StoreConfigKeys.STORE_FILTER_CODEGEN_ENABLE); - filterConverter = codegenEnable ? new CodeGenFilterConverter() : new DirectFilterConverter(); - } - - @Override - public IFilterConverter getFilterConverter() { - return filterConverter; - } - - @Override - protected Path getRemotePath() { - return Paths.get(root, storeContext.getName(), - Integer.toString(shardId)); - } + protected IFilterConverter filterConverter; + + @Override + public void init(StoreContext storeContext) { + super.init(storeContext); + boolean codegenEnable = + storeContext.getConfig().getBoolean(StoreConfigKeys.STORE_FILTER_CODEGEN_ENABLE); + filterConverter = codegenEnable ? new CodeGenFilterConverter() : new DirectFilterConverter(); + } + + @Override + public IFilterConverter getFilterConverter() { + return filterConverter; + } + + @Override + protected Path getRemotePath() { + return Paths.get(root, storeContext.getName(), Integer.toString(shardId)); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/BaseRocksdbStore.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/BaseRocksdbStore.java index d4e21f3a3..fa3de1b35 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/BaseRocksdbStore.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/BaseRocksdbStore.java @@ -22,6 +22,7 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.common.config.keys.FrameworkConfigKeys; @@ -37,121 +38,126 @@ public abstract class BaseRocksdbStore extends BaseGraphStore implements IStatefulStore { - private static final Logger LOGGER = LoggerFactory.getLogger(BaseRocksdbStore.class); + private static final Logger LOGGER = LoggerFactory.getLogger(BaseRocksdbStore.class); - protected Configuration config; - protected String rocksdbPath; - protected String remotePath; - protected RocksdbClient rocksdbClient; - protected RocksdbPersistClient persistClient; - protected long keepChkNum; + protected Configuration config; + protected String rocksdbPath; + protected String remotePath; + protected RocksdbClient rocksdbClient; + protected RocksdbPersistClient persistClient; + protected long keepChkNum; - protected String root; - protected String jobName; - protected int shardId; - protected long recoveryVersion = -1; + protected String root; + protected String jobName; + protected int shardId; + protected long recoveryVersion = -1; - @Override - public void init(StoreContext storeContext) { - super.init(storeContext); - this.config = storeContext.getConfig(); - this.shardId = storeContext.getShardId(); + @Override + public void init(StoreContext storeContext) { + super.init(storeContext); + this.config = storeContext.getConfig(); + this.shardId = storeContext.getShardId(); - String workerPath = this.config.getString(ExecutionConfigKeys.JOB_WORK_PATH); - this.jobName = this.config.getString(ExecutionConfigKeys.JOB_APP_NAME); + String workerPath = this.config.getString(ExecutionConfigKeys.JOB_WORK_PATH); + this.jobName = this.config.getString(ExecutionConfigKeys.JOB_APP_NAME); - this.rocksdbPath = Paths.get(workerPath, jobName, storeContext.getName(), - Integer.toString(shardId)).toString(); + this.rocksdbPath = + Paths.get(workerPath, jobName, storeContext.getName(), Integer.toString(shardId)) + .toString(); - this.root = this.config.getString(FileConfigKeys.ROOT); + this.root = this.config.getString(FileConfigKeys.ROOT); - this.remotePath = getRemotePath().toString(); - this.persistClient = new RocksdbPersistClient(this.config); - long chkRate = this.config.getLong(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT); - this.keepChkNum = Math.max( - this.config.getInteger(StateConfigKeys.STATE_ARCHIVED_VERSION_NUM), chkRate * 2); + this.remotePath = getRemotePath().toString(); + this.persistClient = new RocksdbPersistClient(this.config); + long chkRate = this.config.getLong(FrameworkConfigKeys.BATCH_NUMBER_PER_CHECKPOINT); + this.keepChkNum = + Math.max(this.config.getInteger(StateConfigKeys.STATE_ARCHIVED_VERSION_NUM), chkRate * 2); - boolean enableDynamicCreateColumnFamily = PartitionType.getEnum( + boolean enableDynamicCreateColumnFamily = + PartitionType.getEnum( this.config.getString(RocksdbConfigKeys.ROCKSDB_GRAPH_STORE_PARTITION_TYPE)) .isPartition(); - this.rocksdbClient = new RocksdbClient(rocksdbPath, getCfList(), config, - enableDynamicCreateColumnFamily); - LOGGER.info("ThreadId {}, BaseRocksdbStore initDB", Thread.currentThread().getId()); - this.rocksdbClient.initDB(); - } - - protected abstract List getCfList(); - - @Override - public void archive(long version) { - flush(); - String chkPath = RocksdbConfigKeys.getChkPath(this.rocksdbPath, version); - rocksdbClient.checkpoint(chkPath); - // sync file - try { - persistClient.archive(version, chkPath, remotePath, keepChkNum); - } catch (Exception e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("archive fail"), e); - } - } - - @Override - public void recovery(long version) { - if (version <= recoveryVersion) { - LOGGER.info("shardId {} recovery version {} <= last recovery version {}, ignore", - shardId, version, recoveryVersion); - return; - } - drop(); - String chkPath = RocksdbConfigKeys.getChkPath(this.rocksdbPath, version); - String recoverPath = remotePath; - boolean isScale = shardId != storeContext.getShardId(); - if (isScale) { - recoverPath = getRemotePath().toString(); - } - try { - persistClient.recover(version, this.rocksdbPath, chkPath, recoverPath); - } catch (Exception e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("recover fail"), e); - } - if (isScale) { - persistClient.clearFileInfo(); - shardId = storeContext.getShardId(); - } - this.rocksdbClient.initDB(); - recoveryVersion = version; - } - - protected Path getRemotePath() { - return Paths.get(root, jobName, storeContext.getName(), Integer.toString(shardId)); + this.rocksdbClient = + new RocksdbClient(rocksdbPath, getCfList(), config, enableDynamicCreateColumnFamily); + LOGGER.info("ThreadId {}, BaseRocksdbStore initDB", Thread.currentThread().getId()); + this.rocksdbClient.initDB(); + } + + protected abstract List getCfList(); + + @Override + public void archive(long version) { + flush(); + String chkPath = RocksdbConfigKeys.getChkPath(this.rocksdbPath, version); + rocksdbClient.checkpoint(chkPath); + // sync file + try { + persistClient.archive(version, chkPath, remotePath, keepChkNum); + } catch (Exception e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("archive fail"), e); } - - @Override - public long recoveryLatest() { - long chkId = persistClient.getLatestCheckpointId(remotePath); - if (chkId > 0) { - recovery(chkId); - } - return chkId; + } + + @Override + public void recovery(long version) { + if (version <= recoveryVersion) { + LOGGER.info( + "shardId {} recovery version {} <= last recovery version {}, ignore", + shardId, + version, + recoveryVersion); + return; } - - @Override - public void compact() { - this.rocksdbClient.compact(); + drop(); + String chkPath = RocksdbConfigKeys.getChkPath(this.rocksdbPath, version); + String recoverPath = remotePath; + boolean isScale = shardId != storeContext.getShardId(); + if (isScale) { + recoverPath = getRemotePath().toString(); } - - @Override - public void flush() { - this.rocksdbClient.flush(); + try { + persistClient.recover(version, this.rocksdbPath, chkPath, recoverPath); + } catch (Exception e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("recover fail"), e); } - - @Override - public void close() { - this.rocksdbClient.close(); + if (isScale) { + persistClient.clearFileInfo(); + shardId = storeContext.getShardId(); } - - @Override - public void drop() { - rocksdbClient.drop(); + this.rocksdbClient.initDB(); + recoveryVersion = version; + } + + protected Path getRemotePath() { + return Paths.get(root, jobName, storeContext.getName(), Integer.toString(shardId)); + } + + @Override + public long recoveryLatest() { + long chkId = persistClient.getLatestCheckpointId(remotePath); + if (chkId > 0) { + recovery(chkId); } + return chkId; + } + + @Override + public void compact() { + this.rocksdbClient.compact(); + } + + @Override + public void flush() { + this.rocksdbClient.flush(); + } + + @Override + public void close() { + this.rocksdbClient.close(); + } + + @Override + public void drop() { + rocksdbClient.drop(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/DynamicGraphRocksdbStoreBase.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/DynamicGraphRocksdbStoreBase.java index 3bfc005a1..b540426cb 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/DynamicGraphRocksdbStoreBase.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/DynamicGraphRocksdbStoreBase.java @@ -27,6 +27,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; @@ -43,122 +44,122 @@ public class DynamicGraphRocksdbStoreBase extends BaseRocksdbGraphStore implements IDynamicGraphStore { - private IGraphMultiVersionedRocksdbProxy proxy; - - @Override - public void init(StoreContext storeContext) { - super.init(storeContext); - IGraphKVEncoder encoder = GraphKVEncoderFactory.build(config, - storeContext.getGraphSchema()); - this.proxy = ProxyBuilder.buildMultiVersioned(config, rocksdbClient, encoder); - } - - @Override - protected List getCfList() { - return Arrays.asList(VERTEX_CF, EDGE_CF, VERTEX_INDEX_CF); - } - - @Override - public void addEdge(long version, IEdge edge) { - this.proxy.addEdge(version, edge); - } - - @Override - public void addVertex(long version, IVertex vertex) { - this.proxy.addVertex(version, vertex); - } - - @Override - public IVertex getVertex(long sliceId, K sid, IStatePushDown pushdown) { - return this.proxy.getVertex(sliceId, sid, pushdown); - } - - @Override - public List> getEdges(long sliceId, K sid, IStatePushDown pushdown) { - return this.proxy.getEdges(sliceId, sid, pushdown); - } - - @Override - public OneDegreeGraph getOneDegreeGraph(long sliceId, K sid, - IStatePushDown pushdown) { - return this.proxy.getOneDegreeGraph(sliceId, sid, pushdown); - } - - @Override - public CloseableIterator vertexIDIterator() { - return this.proxy.vertexIDIterator(); - } - - @Override - public CloseableIterator vertexIDIterator(long version, IStatePushDown pushdown) { - return this.proxy.vertexIDIterator(version, pushdown); - } - - @Override - public CloseableIterator> getVertexIterator(long version, IStatePushDown pushdown) { - return proxy.getVertexIterator(version, pushdown); - } - - @Override - public CloseableIterator> getVertexIterator(long version, List keys, - IStatePushDown pushdown) { - return proxy.getVertexIterator(version, keys, pushdown); - } - - @Override - public CloseableIterator> getEdgeIterator(long version, IStatePushDown pushdown) { - return proxy.getEdgeIterator(version, pushdown); - } - - @Override - public CloseableIterator> getEdgeIterator(long version, List keys, - IStatePushDown pushdown) { - return proxy.getEdgeIterator(version, keys, pushdown); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(long version, - IStatePushDown pushdown) { - return proxy.getOneDegreeGraphIterator(version, pushdown); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(long version, List keys, - IStatePushDown pushdown) { - return proxy.getOneDegreeGraphIterator(version, keys, pushdown); - } - - @Override - public List getAllVersions(K id, DataType dataType) { - return this.proxy.getAllVersions(id, dataType); - } - - @Override - public long getLatestVersion(K id, DataType dataType) { - return this.proxy.getLatestVersion(id, dataType); - } - - @Override - public Map> getAllVersionData(K id, IStatePushDown pushdown, - DataType dataType) { - return this.proxy.getAllVersionData(id, pushdown, dataType); - } - - @Override - public Map> getVersionData(K id, Collection slices, - IStatePushDown pushdown, DataType dataType) { - return this.proxy.getVersionData(id, slices, pushdown, dataType); - } - - @Override - public void flush() { - proxy.flush(); - super.flush(); - } - - @Override - public void close() { - proxy.close(); - super.close(); - } + private IGraphMultiVersionedRocksdbProxy proxy; + + @Override + public void init(StoreContext storeContext) { + super.init(storeContext); + IGraphKVEncoder encoder = + GraphKVEncoderFactory.build(config, storeContext.getGraphSchema()); + this.proxy = ProxyBuilder.buildMultiVersioned(config, rocksdbClient, encoder); + } + + @Override + protected List getCfList() { + return Arrays.asList(VERTEX_CF, EDGE_CF, VERTEX_INDEX_CF); + } + + @Override + public void addEdge(long version, IEdge edge) { + this.proxy.addEdge(version, edge); + } + + @Override + public void addVertex(long version, IVertex vertex) { + this.proxy.addVertex(version, vertex); + } + + @Override + public IVertex getVertex(long sliceId, K sid, IStatePushDown pushdown) { + return this.proxy.getVertex(sliceId, sid, pushdown); + } + + @Override + public List> getEdges(long sliceId, K sid, IStatePushDown pushdown) { + return this.proxy.getEdges(sliceId, sid, pushdown); + } + + @Override + public OneDegreeGraph getOneDegreeGraph(long sliceId, K sid, IStatePushDown pushdown) { + return this.proxy.getOneDegreeGraph(sliceId, sid, pushdown); + } + + @Override + public CloseableIterator vertexIDIterator() { + return this.proxy.vertexIDIterator(); + } + + @Override + public CloseableIterator vertexIDIterator(long version, IStatePushDown pushdown) { + return this.proxy.vertexIDIterator(version, pushdown); + } + + @Override + public CloseableIterator> getVertexIterator( + long version, IStatePushDown pushdown) { + return proxy.getVertexIterator(version, pushdown); + } + + @Override + public CloseableIterator> getVertexIterator( + long version, List keys, IStatePushDown pushdown) { + return proxy.getVertexIterator(version, keys, pushdown); + } + + @Override + public CloseableIterator> getEdgeIterator(long version, IStatePushDown pushdown) { + return proxy.getEdgeIterator(version, pushdown); + } + + @Override + public CloseableIterator> getEdgeIterator( + long version, List keys, IStatePushDown pushdown) { + return proxy.getEdgeIterator(version, keys, pushdown); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + long version, IStatePushDown pushdown) { + return proxy.getOneDegreeGraphIterator(version, pushdown); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + long version, List keys, IStatePushDown pushdown) { + return proxy.getOneDegreeGraphIterator(version, keys, pushdown); + } + + @Override + public List getAllVersions(K id, DataType dataType) { + return this.proxy.getAllVersions(id, dataType); + } + + @Override + public long getLatestVersion(K id, DataType dataType) { + return this.proxy.getLatestVersion(id, dataType); + } + + @Override + public Map> getAllVersionData( + K id, IStatePushDown pushdown, DataType dataType) { + return this.proxy.getAllVersionData(id, pushdown, dataType); + } + + @Override + public Map> getVersionData( + K id, Collection slices, IStatePushDown pushdown, DataType dataType) { + return this.proxy.getVersionData(id, slices, pushdown, dataType); + } + + @Override + public void flush() { + proxy.flush(); + super.flush(); + } + + @Override + public void close() { + proxy.close(); + super.close(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/KVRocksdbStoreBase.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/KVRocksdbStoreBase.java index 10773735a..6b0d26fe1 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/KVRocksdbStoreBase.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/KVRocksdbStoreBase.java @@ -21,50 +21,53 @@ import static org.apache.geaflow.store.rocksdb.RocksdbConfigKeys.DEFAULT_CF; -import com.google.common.base.Preconditions; import java.util.Arrays; import java.util.List; + import org.apache.geaflow.state.serializer.IKVSerializer; import org.apache.geaflow.store.api.key.IKVStatefulStore; import org.apache.geaflow.store.context.StoreContext; -public class KVRocksdbStoreBase extends BaseRocksdbStore implements - IKVStatefulStore { +import com.google.common.base.Preconditions; - private IKVSerializer kvSerializer; +public class KVRocksdbStoreBase extends BaseRocksdbStore implements IKVStatefulStore { - @Override - public void init(StoreContext storeContext) { - super.init(storeContext); - this.kvSerializer = (IKVSerializer) Preconditions.checkNotNull( - storeContext.getKeySerializer(), "keySerializer must be set"); - } + private IKVSerializer kvSerializer; - @Override - protected List getCfList() { - return Arrays.asList(DEFAULT_CF); - } + @Override + public void init(StoreContext storeContext) { + super.init(storeContext); + this.kvSerializer = + (IKVSerializer) + Preconditions.checkNotNull( + storeContext.getKeySerializer(), "keySerializer must be set"); + } - @Override - public void put(K key, V value) { - byte[] keyArray = this.kvSerializer.serializeKey(key); - byte[] valueArray = this.kvSerializer.serializeValue(value); - this.rocksdbClient.write(DEFAULT_CF, keyArray, valueArray); - } + @Override + protected List getCfList() { + return Arrays.asList(DEFAULT_CF); + } - @Override - public void remove(K key) { - byte[] keyArray = this.kvSerializer.serializeKey(key); - this.rocksdbClient.delete(DEFAULT_CF, keyArray); - } + @Override + public void put(K key, V value) { + byte[] keyArray = this.kvSerializer.serializeKey(key); + byte[] valueArray = this.kvSerializer.serializeValue(value); + this.rocksdbClient.write(DEFAULT_CF, keyArray, valueArray); + } + + @Override + public void remove(K key) { + byte[] keyArray = this.kvSerializer.serializeKey(key); + this.rocksdbClient.delete(DEFAULT_CF, keyArray); + } - @Override - public V get(K key) { - byte[] keyArray = this.kvSerializer.serializeKey(key); - byte[] valueArray = this.rocksdbClient.get(DEFAULT_CF, keyArray); - if (valueArray == null) { - return null; - } - return this.kvSerializer.deserializeValue(valueArray); + @Override + public V get(K key) { + byte[] keyArray = this.kvSerializer.serializeKey(key); + byte[] valueArray = this.rocksdbClient.get(DEFAULT_CF, keyArray); + if (valueArray == null) { + return null; } + return this.kvSerializer.deserializeValue(valueArray); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/PartitionType.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/PartitionType.java index 5c71cee93..701e94628 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/PartitionType.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/PartitionType.java @@ -23,41 +23,41 @@ // Partition type for rocksdb graph store public enum PartitionType { - LABEL(false, true), - // TODO: Support dt partition - DT(true, false), - // TODO: Support label dt partition - DT_LABEL(true, true), - NONE(false, false); - - private final boolean dtPartition; - private final boolean labelPartition; - - private static final PartitionType[] VALUES = values(); - - public static PartitionType getEnum(String value) { - for (PartitionType v : VALUES) { - if (v.name().equalsIgnoreCase(value)) { - return v; - } - } - throw new GeaflowRuntimeException("Illegal partition type " + value); + LABEL(false, true), + // TODO: Support dt partition + DT(true, false), + // TODO: Support label dt partition + DT_LABEL(true, true), + NONE(false, false); + + private final boolean dtPartition; + private final boolean labelPartition; + + private static final PartitionType[] VALUES = values(); + + public static PartitionType getEnum(String value) { + for (PartitionType v : VALUES) { + if (v.name().equalsIgnoreCase(value)) { + return v; + } } - - PartitionType(boolean dtPartition, boolean labelPartition) { - this.dtPartition = dtPartition; - this.labelPartition = labelPartition; - } - - public boolean isDtPartition() { - return dtPartition; - } - - public boolean isLabelPartition() { - return labelPartition; - } - - public boolean isPartition() { - return dtPartition || labelPartition; - } -} \ No newline at end of file + throw new GeaflowRuntimeException("Illegal partition type " + value); + } + + PartitionType(boolean dtPartition, boolean labelPartition) { + this.dtPartition = dtPartition; + this.labelPartition = labelPartition; + } + + public boolean isDtPartition() { + return dtPartition; + } + + public boolean isLabelPartition() { + return labelPartition; + } + + public boolean isPartition() { + return dtPartition || labelPartition; + } +} diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/RocksdbClient.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/RocksdbClient.java index 42e75c0f2..0d2fdc5e8 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/RocksdbClient.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/RocksdbClient.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.StateConfigKeys; @@ -47,272 +48,270 @@ public class RocksdbClient { - private static final Logger LOGGER = LoggerFactory.getLogger(RocksdbClient.class); - - private final String filePath; - private final String optionClass; - private final Configuration config; - private final List cfList; - private RocksDB rocksdb; - private IRocksDBOptions rocksDBOptions; - // column family name -> vertex or edge column family handle - private final Map handleMap = new HashMap<>(); - private final Map vertexHandleMap = new ConcurrentHashMap<>(); - private final Map edgeHandleMap = new ConcurrentHashMap<>(); - - // column family name -> column family descriptor - private Map descriptorMap; - private boolean enableDynamicCreateColumnFamily; - - public RocksdbClient(String filePath, List cfList, Configuration config, - boolean enableDynamicCreateColumnFamily) { - this(filePath, cfList, config); - this.enableDynamicCreateColumnFamily = enableDynamicCreateColumnFamily; - - if (enableDynamicCreateColumnFamily) { - // Using concurrent hashmap in partition situation - descriptorMap = new ConcurrentHashMap<>(); - } else { - descriptorMap = new HashMap<>(); - } + private static final Logger LOGGER = LoggerFactory.getLogger(RocksdbClient.class); + + private final String filePath; + private final String optionClass; + private final Configuration config; + private final List cfList; + private RocksDB rocksdb; + private IRocksDBOptions rocksDBOptions; + // column family name -> vertex or edge column family handle + private final Map handleMap = new HashMap<>(); + private final Map vertexHandleMap = new ConcurrentHashMap<>(); + private final Map edgeHandleMap = new ConcurrentHashMap<>(); + + // column family name -> column family descriptor + private Map descriptorMap; + private boolean enableDynamicCreateColumnFamily; + + public RocksdbClient( + String filePath, + List cfList, + Configuration config, + boolean enableDynamicCreateColumnFamily) { + this(filePath, cfList, config); + this.enableDynamicCreateColumnFamily = enableDynamicCreateColumnFamily; + + if (enableDynamicCreateColumnFamily) { + // Using concurrent hashmap in partition situation + descriptorMap = new ConcurrentHashMap<>(); + } else { + descriptorMap = new HashMap<>(); } - - public RocksdbClient(String filePath, List cfList, Configuration config) { - this.filePath = filePath; - this.cfList = cfList; - this.config = config; - this.optionClass = this.config.getString(RocksdbConfigKeys.ROCKSDB_OPTION_CLASS); + } + + public RocksdbClient(String filePath, List cfList, Configuration config) { + this.filePath = filePath; + this.cfList = cfList; + this.config = config; + this.optionClass = this.config.getString(RocksdbConfigKeys.ROCKSDB_OPTION_CLASS); + } + + private void initRocksDbOptions() { + if (this.rocksDBOptions == null || this.rocksDBOptions.isClosed()) { + LOGGER.info("rocksdb optionClass {}", optionClass); + try { + this.rocksDBOptions = (IRocksDBOptions) Class.forName(optionClass).newInstance(); + this.rocksDBOptions.init(config); + } catch (Throwable e) { + LOGGER.error("{} not found", optionClass); + throw new GeaflowRuntimeException( + RuntimeErrors.INST.runError(optionClass + "class not found"), e); + } + + if (this.config.getBoolean(RocksdbConfigKeys.ROCKSDB_STATISTICS_ENABLE)) { + this.rocksDBOptions.enableStatistics(); + } + if (this.config.getBoolean(StateConfigKeys.STATE_PARANOID_CHECK_ENABLE)) { + this.rocksDBOptions.enableParanoidCheck(); + } + } + } + + public void initDB() { + File dbFile = new File(filePath); + if (!dbFile.getParentFile().exists()) { + try { + FileUtils.forceMkdir(dbFile.getParentFile()); + } catch (IOException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("create file error"), e); + } } - private void initRocksDbOptions() { - if (this.rocksDBOptions == null || this.rocksDBOptions.isClosed()) { - LOGGER.info("rocksdb optionClass {}", optionClass); - try { - this.rocksDBOptions = (IRocksDBOptions) Class.forName(optionClass).newInstance(); - this.rocksDBOptions.init(config); - } catch (Throwable e) { - LOGGER.error("{} not found", optionClass); - throw new GeaflowRuntimeException( - RuntimeErrors.INST.runError(optionClass + "class not found"), e); - } + if (rocksdb == null) { + initRocksDbOptions(); + LOGGER.info("ThreadId {}, buildDB {}", Thread.currentThread().getId(), filePath); + int ttl = this.config.getInteger(RocksdbConfigKeys.ROCKSDB_TTL_SECOND); + this.descriptorMap.clear(); + List handles = new ArrayList<>(); + List ttls = new ArrayList<>(); - if (this.config.getBoolean(RocksdbConfigKeys.ROCKSDB_STATISTICS_ENABLE)) { - this.rocksDBOptions.enableStatistics(); - } - if (this.config.getBoolean(StateConfigKeys.STATE_PARANOID_CHECK_ENABLE)) { - this.rocksDBOptions.enableParanoidCheck(); + List validCfList = cfList; + if (enableDynamicCreateColumnFamily) { + try { + List cfNames = RocksDB.listColumnFamilies(new Options(), filePath); + if (!cfNames.isEmpty()) { + validCfList = new ArrayList<>(); + for (byte[] cfName : cfNames) { + validCfList.add(new String(cfName)); } + } + } catch (RocksDBException e) { + throw new GeaflowRuntimeException( + RuntimeErrors.INST.runError("List column family error"), e); } - } - - public void initDB() { - File dbFile = new File(filePath); - if (!dbFile.getParentFile().exists()) { - try { - FileUtils.forceMkdir(dbFile.getParentFile()); - } catch (IOException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("create file error"), - e); - } + } + + List descriptorList = new ArrayList<>(); + for (String name : validCfList) { + ColumnFamilyDescriptor descriptor = + new ColumnFamilyDescriptor(name.getBytes(), rocksDBOptions.buildFamilyOptions()); + descriptorList.add(descriptor); + descriptorMap.put(name, descriptor); + ttls.add(ttl); + } + + try { + rocksdb = + TtlDB.open( + rocksDBOptions.getDbOptions(), this.filePath, descriptorList, handles, ttls, false); + } catch (Exception e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("open rocksdb error"), e); + } + + if (enableDynamicCreateColumnFamily) { + for (int i = 0; i < validCfList.size(); i++) { + if (validCfList.get(i).contains(RocksdbConfigKeys.VERTEX_CF_PREFIX)) { + vertexHandleMap.put(validCfList.get(i), handles.get(i)); + } else if (validCfList.get(i).contains(RocksdbConfigKeys.EDGE_CF_PREFIX)) { + edgeHandleMap.put(validCfList.get(i), handles.get(i)); + } else { + handleMap.put(validCfList.get(i), handles.get(i)); + } } - - if (rocksdb == null) { - initRocksDbOptions(); - LOGGER.info("ThreadId {}, buildDB {}", Thread.currentThread().getId(), filePath); - int ttl = this.config.getInteger(RocksdbConfigKeys.ROCKSDB_TTL_SECOND); - this.descriptorMap.clear(); - List handles = new ArrayList<>(); - List ttls = new ArrayList<>(); - - List validCfList = cfList; - if (enableDynamicCreateColumnFamily) { - try { - List cfNames = RocksDB.listColumnFamilies(new Options(), filePath); - if (!cfNames.isEmpty()) { - validCfList = new ArrayList<>(); - for (byte[] cfName : cfNames) { - validCfList.add(new String(cfName)); - } - } - } catch (RocksDBException e) { - throw new GeaflowRuntimeException( - RuntimeErrors.INST.runError("List column family error"), e); - } - } - - List descriptorList = new ArrayList<>(); - for (String name : validCfList) { - ColumnFamilyDescriptor descriptor = new ColumnFamilyDescriptor(name.getBytes(), - rocksDBOptions.buildFamilyOptions()); - descriptorList.add(descriptor); - descriptorMap.put(name, descriptor); - ttls.add(ttl); - } - - try { - rocksdb = TtlDB.open(rocksDBOptions.getDbOptions(), this.filePath, descriptorList, - handles, ttls, false); - } catch (Exception e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("open rocksdb error"), - e); - } - - if (enableDynamicCreateColumnFamily) { - for (int i = 0; i < validCfList.size(); i++) { - if (validCfList.get(i).contains(RocksdbConfigKeys.VERTEX_CF_PREFIX)) { - vertexHandleMap.put(validCfList.get(i), handles.get(i)); - } else if (validCfList.get(i).contains(RocksdbConfigKeys.EDGE_CF_PREFIX)) { - edgeHandleMap.put(validCfList.get(i), handles.get(i)); - } else { - handleMap.put(validCfList.get(i), handles.get(i)); - } - } - } else { - for (int i = 0; i < validCfList.size(); i++) { - handleMap.put(validCfList.get(i), handles.get(i)); - } - } + } else { + for (int i = 0; i < validCfList.size(); i++) { + handleMap.put(validCfList.get(i), handles.get(i)); } + } } + } - public Map getColumnFamilyHandleMap() { - return handleMap; - } + public Map getColumnFamilyHandleMap() { + return handleMap; + } - public Map getVertexHandleMap() { - return vertexHandleMap; - } + public Map getVertexHandleMap() { + return vertexHandleMap; + } - public Map getEdgeHandleMap() { - return edgeHandleMap; - } + public Map getEdgeHandleMap() { + return edgeHandleMap; + } - public void flush() { - try { - this.rocksdb.flush(rocksDBOptions.getFlushOptions()); - } catch (RocksDBException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("rocksdb compact error"), - e); - } + public void flush() { + try { + this.rocksdb.flush(rocksDBOptions.getFlushOptions()); + } catch (RocksDBException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("rocksdb compact error"), e); } + } - - public void compact() { - try { - this.rocksdb.compactRange(); - } catch (RocksDBException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("rocksdb compact error"), - e); - } + public void compact() { + try { + this.rocksdb.compactRange(); + } catch (RocksDBException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("rocksdb compact error"), e); } - - public void checkpoint(String path) { - Checkpoint checkpoint = Checkpoint.create(rocksdb); - LOGGER.info("Delete path: {}", path); - FileUtils.deleteQuietly(new File(path)); - try { - checkpoint.createCheckpoint(path); - } catch (RocksDBException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("rocksdb chk error"), e); - } - checkpoint.close(); + } + + public void checkpoint(String path) { + Checkpoint checkpoint = Checkpoint.create(rocksdb); + LOGGER.info("Delete path: {}", path); + FileUtils.deleteQuietly(new File(path)); + try { + checkpoint.createCheckpoint(path); + } catch (RocksDBException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("rocksdb chk error"), e); } - - public void write(String cf, byte[] key, byte[] value) { - try { - this.rocksdb.put(handleMap.get(cf), key, value); - } catch (RocksDBException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("rocksdb put error"), e); - } + checkpoint.close(); + } + + public void write(String cf, byte[] key, byte[] value) { + try { + this.rocksdb.put(handleMap.get(cf), key, value); + } catch (RocksDBException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("rocksdb put error"), e); } + } - public void write(ColumnFamilyHandle handle, byte[] key, byte[] value) { - try { - this.rocksdb.put(handle, key, value); - } catch (RocksDBException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("rocksdb put error"), e); - } + public void write(ColumnFamilyHandle handle, byte[] key, byte[] value) { + try { + this.rocksdb.put(handle, key, value); + } catch (RocksDBException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("rocksdb put error"), e); } + } - public void write(WriteBatch writeBatch) { - try { - this.rocksdb.write(rocksDBOptions.getWriteOptions(), writeBatch); - } catch (RocksDBException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("rocksdb put error"), e); - } + public void write(WriteBatch writeBatch) { + try { + this.rocksdb.write(rocksDBOptions.getWriteOptions(), writeBatch); + } catch (RocksDBException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("rocksdb put error"), e); } - - public void write(String cf, List> list) { - try { - WriteBatch writeBatch = new WriteBatch(); - for (Tuple tuple : list) { - writeBatch.put(handleMap.get(cf), tuple.f0, tuple.f1); - } - this.rocksdb.write(rocksDBOptions.getWriteOptions(), writeBatch); - writeBatch.clear(); - writeBatch.close(); - } catch (RocksDBException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("rocksdb put error"), e); - } + } + + public void write(String cf, List> list) { + try { + WriteBatch writeBatch = new WriteBatch(); + for (Tuple tuple : list) { + writeBatch.put(handleMap.get(cf), tuple.f0, tuple.f1); + } + this.rocksdb.write(rocksDBOptions.getWriteOptions(), writeBatch); + writeBatch.clear(); + writeBatch.close(); + } catch (RocksDBException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("rocksdb put error"), e); } + } - public byte[] get(String cf, byte[] key) { - try { - return this.rocksdb.get(handleMap.get(cf), key); - } catch (RocksDBException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("rocksdb get error"), e); - } + public byte[] get(String cf, byte[] key) { + try { + return this.rocksdb.get(handleMap.get(cf), key); + } catch (RocksDBException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("rocksdb get error"), e); } + } - public byte[] get(ColumnFamilyHandle handle, byte[] key) { - try { - return this.rocksdb.get(handle, key); - } catch (RocksDBException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("rocksdb get error"), e); - } - } - - public void delete(String cf, byte[] key) { - try { - this.rocksdb.delete(handleMap.get(cf), key); - } catch (RocksDBException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("rocksdb delete error"), - e); - } + public byte[] get(ColumnFamilyHandle handle, byte[] key) { + try { + return this.rocksdb.get(handle, key); + } catch (RocksDBException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("rocksdb get error"), e); } + } - public RocksIterator getIterator(String cf) { - return this.rocksdb.newIterator(handleMap.get(cf)); + public void delete(String cf, byte[] key) { + try { + this.rocksdb.delete(handleMap.get(cf), key); + } catch (RocksDBException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("rocksdb delete error"), e); } - - public RocksIterator getIterator(ColumnFamilyHandle handle) { - return this.rocksdb.newIterator(handle); - } - - public void close() { - if (rocksdb != null) { - this.rocksdb.close(); - this.rocksDBOptions.close(); - this.descriptorMap.forEach((k, d) -> d.getOptions().close()); - this.rocksDBOptions = null; - this.rocksdb = null; - } + } + + public RocksIterator getIterator(String cf) { + return this.rocksdb.newIterator(handleMap.get(cf)); + } + + public RocksIterator getIterator(ColumnFamilyHandle handle) { + return this.rocksdb.newIterator(handle); + } + + public void close() { + if (rocksdb != null) { + this.rocksdb.close(); + this.rocksDBOptions.close(); + this.descriptorMap.forEach((k, d) -> d.getOptions().close()); + this.rocksDBOptions = null; + this.rocksdb = null; } + } - public void drop() { - close(); - FileUtils.deleteQuietly(new File(this.filePath)); - } + public void drop() { + close(); + FileUtils.deleteQuietly(new File(this.filePath)); + } - public RocksDB getRocksdb() { - return rocksdb; - } + public RocksDB getRocksdb() { + return rocksdb; + } - public IRocksDBOptions getRocksDBOptions() { - return rocksDBOptions; - } + public IRocksDBOptions getRocksDBOptions() { + return rocksDBOptions; + } - public Map getDescriptorMap() { - return descriptorMap; - } + public Map getDescriptorMap() { + return descriptorMap; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/RocksdbConfigKeys.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/RocksdbConfigKeys.java index 6f9290034..13f034a2a 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/RocksdbConfigKeys.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/RocksdbConfigKeys.java @@ -26,91 +26,90 @@ public class RocksdbConfigKeys { - public static final String CHK_SUFFIX = "_chk"; - public static final String DEFAULT_CF = "default"; - public static final String VERTEX_CF = "default"; - public static final String EDGE_CF = "e"; - public static final String VERTEX_INDEX_CF = "v_index"; - public static final char FILE_DOT = '.'; - public static final String VERTEX_CF_PREFIX = "v_"; - public static final String EDGE_CF_PREFIX = "e_"; - - - public static String getChkPath(String path, long checkpointId) { - return path + CHK_SUFFIX + checkpointId; - } - - public static boolean isChkPath(String path) { - // tmp file may exist. - return path.contains(CHK_SUFFIX) && path.indexOf(FILE_DOT) == -1; - } - - public static String getChkPathPrefix(String path) { - int end = path.indexOf(CHK_SUFFIX) + CHK_SUFFIX.length(); - return path.substring(0, end); - } - - public static long getChkIdFromChkPath(String path) { - return Long.parseLong(path.substring(path.lastIndexOf("chk") + 3)); - } - - public static final ConfigKey ROCKSDB_OPTION_CLASS = ConfigKeys - .key("geaflow.store.rocksdb.option.class") - .defaultValue(DefaultGraphOptions.class.getCanonicalName()) - .description("rocksdb option class"); - - public static final ConfigKey ROCKSDB_OPTIONS_TABLE_BLOCK_SIZE = ConfigKeys - .key("geaflow.store.rocksdb.table.block.size") - .defaultValue(128 * SizeUnit.KB) - .description("rocksdb table block size, default 128KB"); - - public static final ConfigKey ROCKSDB_OPTIONS_TABLE_BLOCK_CACHE_SIZE = ConfigKeys - .key("geaflow.store.rocksdb.table.block.cache.size") - .defaultValue(1024 * SizeUnit.MB) - .description("rocksdb table block cache size, default 1G"); - - public static final ConfigKey ROCKSDB_OPTIONS_MAX_WRITER_BUFFER_NUM = ConfigKeys - .key("geaflow.store.rocksdb.max.write.buffer.number") - .defaultValue(2) - .description("rocksdb max write buffer number, default 2"); - - public static final ConfigKey ROCKSDB_OPTIONS_WRITER_BUFFER_SIZE = ConfigKeys - .key("geaflow.store.rocksdb.write.buffer.size") - .defaultValue(128 * SizeUnit.MB) - .description("rocksdb write buffer size, default 128MB"); - - public static final ConfigKey ROCKSDB_OPTIONS_TARGET_FILE_SIZE = ConfigKeys - .key("geaflow.store.rocksdb.target.file.size") - .defaultValue(1024 * SizeUnit.MB) - .description("rocksdb target file size, default 1GB"); - - public static final ConfigKey ROCKSDB_STATISTICS_ENABLE = ConfigKeys - .key("geaflow.store.rocksdb.statistics.enable") - .defaultValue(false) - .description("rocksdb statistics, default false"); - - public static final ConfigKey ROCKSDB_TTL_SECOND = ConfigKeys - .key("geaflow.store.rocksdb.ttl.second") - .defaultValue(10 * 365 * 24 * 3600) // 10 years. - .description("rocksdb default ttl, default never ttl"); - - public static final ConfigKey ROCKSDB_PERSISTENT_CLEAN_THREAD_SIZE = ConfigKeys - .key("geaflow.store.rocksdb.persistent.clean.thread.size") - .defaultValue(4) - .description("rocksdb persistent clean thread size, default 4"); - - public static final ConfigKey ROCKSDB_GRAPH_STORE_PARTITION_TYPE = ConfigKeys - .key("geaflow.store.rocksdb.graph.store.partition.type") - .defaultValue("none") // Default none partition - .description("rocksdb graph store partition type, default none"); - - public static final ConfigKey ROCKSDB_GRAPH_STORE_DT_START = ConfigKeys - .key("geaflow.store.rocksdb.graph.store.dt.start") - .defaultValue("1735660800") // Default start timestamp 2025-01-01 00:00:00 - .description("rocksdb graph store start timestamp for dt partition"); - - public static final ConfigKey ROCKSDB_GRAPH_STORE_DT_CYCLE = ConfigKeys - .key("geaflow.store.rocksdb.graph.store.dt.cycle") - .defaultValue("2592000") // Default timestamp cycle 30 days - .description("rocksdb graph store start timestamp for dt partition"); + public static final String CHK_SUFFIX = "_chk"; + public static final String DEFAULT_CF = "default"; + public static final String VERTEX_CF = "default"; + public static final String EDGE_CF = "e"; + public static final String VERTEX_INDEX_CF = "v_index"; + public static final char FILE_DOT = '.'; + public static final String VERTEX_CF_PREFIX = "v_"; + public static final String EDGE_CF_PREFIX = "e_"; + + public static String getChkPath(String path, long checkpointId) { + return path + CHK_SUFFIX + checkpointId; + } + + public static boolean isChkPath(String path) { + // tmp file may exist. + return path.contains(CHK_SUFFIX) && path.indexOf(FILE_DOT) == -1; + } + + public static String getChkPathPrefix(String path) { + int end = path.indexOf(CHK_SUFFIX) + CHK_SUFFIX.length(); + return path.substring(0, end); + } + + public static long getChkIdFromChkPath(String path) { + return Long.parseLong(path.substring(path.lastIndexOf("chk") + 3)); + } + + public static final ConfigKey ROCKSDB_OPTION_CLASS = + ConfigKeys.key("geaflow.store.rocksdb.option.class") + .defaultValue(DefaultGraphOptions.class.getCanonicalName()) + .description("rocksdb option class"); + + public static final ConfigKey ROCKSDB_OPTIONS_TABLE_BLOCK_SIZE = + ConfigKeys.key("geaflow.store.rocksdb.table.block.size") + .defaultValue(128 * SizeUnit.KB) + .description("rocksdb table block size, default 128KB"); + + public static final ConfigKey ROCKSDB_OPTIONS_TABLE_BLOCK_CACHE_SIZE = + ConfigKeys.key("geaflow.store.rocksdb.table.block.cache.size") + .defaultValue(1024 * SizeUnit.MB) + .description("rocksdb table block cache size, default 1G"); + + public static final ConfigKey ROCKSDB_OPTIONS_MAX_WRITER_BUFFER_NUM = + ConfigKeys.key("geaflow.store.rocksdb.max.write.buffer.number") + .defaultValue(2) + .description("rocksdb max write buffer number, default 2"); + + public static final ConfigKey ROCKSDB_OPTIONS_WRITER_BUFFER_SIZE = + ConfigKeys.key("geaflow.store.rocksdb.write.buffer.size") + .defaultValue(128 * SizeUnit.MB) + .description("rocksdb write buffer size, default 128MB"); + + public static final ConfigKey ROCKSDB_OPTIONS_TARGET_FILE_SIZE = + ConfigKeys.key("geaflow.store.rocksdb.target.file.size") + .defaultValue(1024 * SizeUnit.MB) + .description("rocksdb target file size, default 1GB"); + + public static final ConfigKey ROCKSDB_STATISTICS_ENABLE = + ConfigKeys.key("geaflow.store.rocksdb.statistics.enable") + .defaultValue(false) + .description("rocksdb statistics, default false"); + + public static final ConfigKey ROCKSDB_TTL_SECOND = + ConfigKeys.key("geaflow.store.rocksdb.ttl.second") + .defaultValue(10 * 365 * 24 * 3600) // 10 years. + .description("rocksdb default ttl, default never ttl"); + + public static final ConfigKey ROCKSDB_PERSISTENT_CLEAN_THREAD_SIZE = + ConfigKeys.key("geaflow.store.rocksdb.persistent.clean.thread.size") + .defaultValue(4) + .description("rocksdb persistent clean thread size, default 4"); + + public static final ConfigKey ROCKSDB_GRAPH_STORE_PARTITION_TYPE = + ConfigKeys.key("geaflow.store.rocksdb.graph.store.partition.type") + .defaultValue("none") // Default none partition + .description("rocksdb graph store partition type, default none"); + + public static final ConfigKey ROCKSDB_GRAPH_STORE_DT_START = + ConfigKeys.key("geaflow.store.rocksdb.graph.store.dt.start") + .defaultValue("1735660800") // Default start timestamp 2025-01-01 00:00:00 + .description("rocksdb graph store start timestamp for dt partition"); + + public static final ConfigKey ROCKSDB_GRAPH_STORE_DT_CYCLE = + ConfigKeys.key("geaflow.store.rocksdb.graph.store.dt.cycle") + .defaultValue("2592000") // Default timestamp cycle 30 days + .description("rocksdb graph store start timestamp for dt partition"); } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/RocksdbPersistClient.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/RocksdbPersistClient.java index 6a3393fa6..4e3f6c18e 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/RocksdbPersistClient.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/RocksdbPersistClient.java @@ -19,11 +19,6 @@ package org.apache.geaflow.store.rocksdb; -import com.google.common.base.Joiner; -import com.google.common.base.Preconditions; -import com.google.common.base.Splitter; -import com.google.common.collect.Lists; -import com.google.common.collect.Sets; import java.io.File; import java.io.FilenameFilter; import java.io.IOException; @@ -52,6 +47,7 @@ import java.util.concurrent.ThreadPoolExecutor.DiscardOldestPolicy; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; + import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.concurrent.BasicThreadFactory; @@ -70,524 +66,598 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Joiner; +import com.google.common.base.Preconditions; +import com.google.common.base.Splitter; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + public class RocksdbPersistClient { - private static final Logger LOGGER = LoggerFactory.getLogger(RocksdbPersistClient.class); - - private static final String COMMIT_TAG_FILE = "_commit"; - private static final String FILES = "FILES"; - private static final int DELETE_CAPACITY = 64; - private static final String DATAS = "datas"; - private static final String META = "meta"; - private static final String FILE_SEPARATOR = ","; - private static final String SST_SUFFIX = "sst"; - - private final Long persistTimeout; - private final IPersistentIO persistIO; - private final NavigableMap checkPointFileInfo; - private final ExecutorService copyFileService; - private final ExecutorService deleteFileService; - private final ExecutorService backgroundDeleteService = new ThreadPoolExecutor( - 1, 1, 300L, TimeUnit.SECONDS, new LinkedBlockingQueue<>(1), - new BasicThreadFactory.Builder().namingPattern("asyncDeletes-%d").daemon(true).build(), new DiscardOldestPolicy()); - - public RocksdbPersistClient(Configuration configuration) { - this.persistIO = PersistentIOBuilder.build(configuration); - this.checkPointFileInfo = new ConcurrentSkipListMap<>(); - int persistThreadNum = configuration.getInteger(FileConfigKeys.PERSISTENT_THREAD_SIZE); - int persistCleanThreadNum = configuration.getInteger(RocksdbConfigKeys.ROCKSDB_PERSISTENT_CLEAN_THREAD_SIZE); - this.persistTimeout = (long) configuration.getInteger( - StateConfigKeys.STATE_ROCKSDB_PERSIST_TIMEOUT_SECONDS); - copyFileService = Executors.getExecutorService(1, persistThreadNum, "persist-%d"); - deleteFileService = Executors.getService(persistCleanThreadNum, DELETE_CAPACITY, 300L, TimeUnit.SECONDS); - ((ThreadPoolExecutor) deleteFileService).setRejectedExecutionHandler(new ThreadPoolExecutor.DiscardOldestPolicy()); - } - - public void clearFileInfo() { - checkPointFileInfo.clear(); - } - - public long getSstIndex(String filename) { - try { - return Long.parseLong(filename.substring(0, filename.indexOf(RocksdbConfigKeys.FILE_DOT))); - } catch (Throwable ignore) { - LOGGER.warn("filename {} is abnormal", filename); - return 0; + private static final Logger LOGGER = LoggerFactory.getLogger(RocksdbPersistClient.class); + + private static final String COMMIT_TAG_FILE = "_commit"; + private static final String FILES = "FILES"; + private static final int DELETE_CAPACITY = 64; + private static final String DATAS = "datas"; + private static final String META = "meta"; + private static final String FILE_SEPARATOR = ","; + private static final String SST_SUFFIX = "sst"; + + private final Long persistTimeout; + private final IPersistentIO persistIO; + private final NavigableMap checkPointFileInfo; + private final ExecutorService copyFileService; + private final ExecutorService deleteFileService; + private final ExecutorService backgroundDeleteService = + new ThreadPoolExecutor( + 1, + 1, + 300L, + TimeUnit.SECONDS, + new LinkedBlockingQueue<>(1), + new BasicThreadFactory.Builder().namingPattern("asyncDeletes-%d").daemon(true).build(), + new DiscardOldestPolicy()); + + public RocksdbPersistClient(Configuration configuration) { + this.persistIO = PersistentIOBuilder.build(configuration); + this.checkPointFileInfo = new ConcurrentSkipListMap<>(); + int persistThreadNum = configuration.getInteger(FileConfigKeys.PERSISTENT_THREAD_SIZE); + int persistCleanThreadNum = + configuration.getInteger(RocksdbConfigKeys.ROCKSDB_PERSISTENT_CLEAN_THREAD_SIZE); + this.persistTimeout = + (long) configuration.getInteger(StateConfigKeys.STATE_ROCKSDB_PERSIST_TIMEOUT_SECONDS); + copyFileService = Executors.getExecutorService(1, persistThreadNum, "persist-%d"); + deleteFileService = + Executors.getService(persistCleanThreadNum, DELETE_CAPACITY, 300L, TimeUnit.SECONDS); + ((ThreadPoolExecutor) deleteFileService) + .setRejectedExecutionHandler(new ThreadPoolExecutor.DiscardOldestPolicy()); + } + + public void clearFileInfo() { + checkPointFileInfo.clear(); + } + + public long getSstIndex(String filename) { + try { + return Long.parseLong(filename.substring(0, filename.indexOf(RocksdbConfigKeys.FILE_DOT))); + } catch (Throwable ignore) { + LOGGER.warn("filename {} is abnormal", filename); + return 0; + } + } + + private long getMetaFileId(String fileName) { + return Long.parseLong(fileName.substring(fileName.indexOf(RocksdbConfigKeys.FILE_DOT) + 1)); + } + + private static String getMetaFileName(long chkId) { + return META + RocksdbConfigKeys.FILE_DOT + chkId; + } + + public void archive(long chkId, String localChkPath, String remotePath, long keepCheckpointNum) + throws Exception { + Set lastFullFiles = getLastFullFiles(chkId, localChkPath, remotePath); + + CheckPointFileInfo currentFileInfo = new CheckPointFileInfo(chkId); + List> callers = new ArrayList<>(); + + File localChkFile = new File(localChkPath); + String[] sstFileNames = localChkFile.list((dir, name) -> name.endsWith(SST_SUFFIX)); + FileUtils.write( + FileUtils.getFile(localChkFile, FILES), + Joiner.on(FILE_SEPARATOR).join(sstFileNames), + Charset.defaultCharset()); + + // copy sst files. + long size = 0L; + String dataPath = Paths.get(remotePath, DATAS).toString(); + for (String subFileName : sstFileNames) { + currentFileInfo.addFullFile(subFileName); + if (!lastFullFiles.contains(subFileName)) { + currentFileInfo.addIncDataFile(subFileName); + File tmp = FileUtils.getFile(localChkFile, subFileName); + callers.add( + copyFromLocal( + new Path(tmp.getAbsolutePath()), new Path(dataPath, subFileName), tmp.length())); + size = size + tmp.length(); + } + } + String[] metaFileNames = localChkFile.list((dir, name) -> !name.endsWith(SST_SUFFIX)); + String metaPath = Paths.get(remotePath, getMetaFileName(chkId)).toString(); + for (String metaFileName : metaFileNames) { + File tmp = FileUtils.getFile(localChkFile, metaFileName); + callers.add( + copyFromLocal( + new Path(tmp.getAbsolutePath()), new Path(metaPath, metaFileName), tmp.length())); + size = size + tmp.length(); + } + LOGGER.info( + "checkpointId {}, full {}, lastFullFiles {}, currentIncre {}", + chkId, + Arrays.toString(sstFileNames), + lastFullFiles, + currentFileInfo.getIncDataFiles()); + + final long startTime = System.nanoTime(); + completeHandler(callers, copyFileService); + callers.clear(); + persistIO.createNewFile(new Path(metaPath, COMMIT_TAG_FILE)); + double costMs = (System.nanoTime() - startTime) / 1000000.0; + + LOGGER.info( + "RocksDB {} archive local:{} to {} (incre[{}]/full[{}]) took {}ms. incre data size {}KB," + + " speed {}KB/s {}", + persistIO.getPersistentType(), + localChkFile.getAbsolutePath(), + remotePath, + currentFileInfo.getIncDataFiles().size(), + currentFileInfo.getFullDataFiles().size(), + costMs, + size / 1024, + size * 1000 / (1024 * costMs), + currentFileInfo.getIncDataFiles().toString()); + + checkPointFileInfo.put(chkId, currentFileInfo); + + backgroundDeleteService.execute( + () -> cleanLocalAndRemoteFiles(chkId, remotePath, keepCheckpointNum, localChkFile)); + } + + public long getLatestCheckpointId(String remotePathStr) { + try { + if (!persistIO.exists(new Path(remotePathStr))) { + return -1; + } + List files = persistIO.listFileName(new Path(remotePathStr)); + List chkIds = + files.stream() + .filter(f -> f.startsWith(META)) + .map(this::getMetaFileId) + .filter(f -> f > 0) + .sorted(Collections.reverseOrder()) + .collect(Collectors.toList()); + LOGGER.info("find available chk {}", chkIds); + for (Long chkId : chkIds) { + String path = Paths.get(remotePathStr, getMetaFileName(chkId), COMMIT_TAG_FILE).toString(); + if (persistIO.exists(new Path(path))) { + return chkId; + } else { + LOGGER.info("chk {} has no path {}", chkId, path); } + } + } catch (IOException e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.stateRocksDbError("recover fail"), e); } - - private long getMetaFileId(String fileName) { - return Long.parseLong(fileName.substring(fileName.indexOf(RocksdbConfigKeys.FILE_DOT) + 1)); + return -1; + } + + public void recover(long chkId, String localRdbPath, String localChkPath, String remotePathStr) + throws Exception { + checkPointFileInfo.clear(); + File rocksDBChkFile = new File(localChkPath); + File rocksDBFile = new File(localRdbPath); + LOGGER.info("delete {} {}", localChkPath, localRdbPath); + FileUtils.deleteQuietly(rocksDBChkFile); + FileUtils.deleteQuietly(rocksDBFile); + + rocksDBChkFile.mkdirs(); + rocksDBFile.mkdirs(); + + final long startTime = System.currentTimeMillis(); + Path remotePath = new Path(remotePathStr); + if (!persistIO.exists(remotePath)) { + String msg = String.format("checkPoint: %s is not exist in remote", remotePath); + LOGGER.warn(msg); + throw new GeaflowRuntimeException(RuntimeErrors.INST.stateRocksDbError(msg)); } - private static String getMetaFileName(long chkId) { - return META + RocksdbConfigKeys.FILE_DOT + chkId; + // fetch manifests. + String remoteMeta = Paths.get(remotePathStr, getMetaFileName(chkId)).toString(); + InputStream in = persistIO.open(new Path(remoteMeta, FILES)); + String sstString = IOUtils.toString(in, Charset.defaultCharset()); + List list = Splitter.on(FILE_SEPARATOR).omitEmptyStrings().splitToList(sstString); + + CheckPointFileInfo commitedInfo = new CheckPointFileInfo(chkId); + recoveryData(remotePath, rocksDBChkFile, commitedInfo, list, remoteMeta); + LOGGER.info( + "recoveryFromRemote {} cost {}ms", remotePath, System.currentTimeMillis() - startTime); + checkPointFileInfo.put(chkId, commitedInfo); + + for (File file : rocksDBChkFile.listFiles()) { + Files.createLink( + FileSystems.getDefault().getPath(localRdbPath, file.getName()), file.toPath()); } - public void archive(long chkId, String localChkPath, String remotePath, - long keepCheckpointNum) throws Exception { - Set lastFullFiles = getLastFullFiles(chkId, localChkPath, remotePath); + backgroundDeleteService.execute(() -> cleanLocalChk(chkId, new File(localChkPath))); + } + + private static void cleanLocalChk(long chkId, File localChkFile) { + String chkPrefix = RocksdbConfigKeys.getChkPathPrefix(localChkFile.getName()); + FilenameFilter filter = + (dir, name) -> { + if (RocksdbConfigKeys.isChkPath(name) && name.startsWith(chkPrefix)) { + return chkId > RocksdbConfigKeys.getChkIdFromChkPath(name); + } else { + return false; + } + }; + File[] subFiles = localChkFile.getParentFile().listFiles(filter); + for (File path : subFiles) { + LOGGER.info("delete local chk {}", path.toURI()); + FileUtils.deleteQuietly(path); + } + } + + private Set getLastFullFiles(long chkId, String localChkPath, String remotePath) + throws IOException { + CheckPointFileInfo commitFileInfo = checkPointFileInfo.get(chkId); + if (commitFileInfo == null) { + Entry info = checkPointFileInfo.lowerEntry(chkId); + if (info != null) { + commitFileInfo = info.getValue(); + } else { + Path path = new Path(remotePath); + PathFilter filter = path1 -> path1.getName().startsWith(META); + if (persistIO.exists(path)) { + FileInfo[] metaFileStatuses = persistIO.listFileInfo(path, filter); + Path lastMetaPath = getLastMetaFile(chkId, metaFileStatuses); + if (lastMetaPath != null) { + commitFileInfo = new CheckPointFileInfo(chkId); + commitFileInfo.addFullFiles(getKeptFileName(lastMetaPath)); + } + } + } + } - CheckPointFileInfo currentFileInfo = new CheckPointFileInfo(chkId); - List> callers = new ArrayList<>(); + Set lastFullFiles; + if (commitFileInfo != null) { + lastFullFiles = new HashSet<>(commitFileInfo.getFullDataFiles()); + } else { + lastFullFiles = new HashSet<>(); + } + File file = new File(localChkPath); + + // current sst number must be larger than the last one. + String[] curNames = file.list(); + Preconditions.checkNotNull(curNames, localChkPath + " is null"); + + Optional chkLargestSst = + Arrays.stream(curNames) + .filter(c -> c.endsWith(SST_SUFFIX)) + .map(this::getSstIndex) + .max(Long::compareTo); + Optional lastLargestSst = + lastFullFiles.stream() + .filter(c -> c.endsWith(SST_SUFFIX)) + .map(this::getSstIndex) + .max(Long::compareTo); + if (chkLargestSst.isPresent() && lastLargestSst.isPresent()) { + Preconditions.checkArgument( + chkLargestSst.get().compareTo(lastLargestSst.get()) >= 0, + "%s < %s, chk path %s, check FO and recovery.", + chkLargestSst.get(), + lastLargestSst.get(), + localChkPath); + } + return lastFullFiles; + } + + private void cleanLocalAndRemoteFiles( + long chkId, String remotePath, long keepCheckpointNum, File localChkFile) { + try { + removeEarlyChk(remotePath, chkId - keepCheckpointNum); + } catch (IOException ignore) { + LOGGER.warn( + "remove Early chk fail and ignore {}, chkId {}, keepChkNum {}", + remotePath, + chkId, + keepCheckpointNum); + } + Long key; + while ((key = checkPointFileInfo.lowerKey(chkId)) != null) { + checkPointFileInfo.remove(key); + } - File localChkFile = new File(localChkPath); - String[] sstFileNames = localChkFile.list((dir, name) -> name.endsWith(SST_SUFFIX)); - FileUtils.write(FileUtils.getFile(localChkFile, FILES), Joiner.on(FILE_SEPARATOR).join(sstFileNames), - Charset.defaultCharset()); + cleanLocalChk(chkId, localChkFile); + } - // copy sst files. - long size = 0L; - String dataPath = Paths.get(remotePath, DATAS).toString(); - for (String subFileName : sstFileNames) { - currentFileInfo.addFullFile(subFileName); - if (!lastFullFiles.contains(subFileName)) { - currentFileInfo.addIncDataFile(subFileName); - File tmp = FileUtils.getFile(localChkFile, subFileName); - callers.add(copyFromLocal(new Path(tmp.getAbsolutePath()), - new Path(dataPath, subFileName), tmp.length())); - size = size + tmp.length(); - } - } - String[] metaFileNames = localChkFile.list((dir, name) -> !name.endsWith(SST_SUFFIX)); - String metaPath = Paths.get(remotePath, getMetaFileName(chkId)).toString(); - for (String metaFileName : metaFileNames) { - File tmp = FileUtils.getFile(localChkFile, metaFileName); - callers.add(copyFromLocal(new Path(tmp.getAbsolutePath()), - new Path(metaPath, metaFileName), tmp.length())); - size = size + tmp.length(); - } - LOGGER.info("checkpointId {}, full {}, lastFullFiles {}, currentIncre {}", chkId, - Arrays.toString(sstFileNames), lastFullFiles, currentFileInfo.getIncDataFiles()); + private void removeEarlyChk(String remotePath, long chkId) throws IOException { + final long start = System.currentTimeMillis(); + LOGGER.info("skip remove early chk {} {}", remotePath, chkId); - final long startTime = System.nanoTime(); - completeHandler(callers, copyFileService); - callers.clear(); - persistIO.createNewFile(new Path(metaPath, COMMIT_TAG_FILE)); - double costMs = (System.nanoTime() - startTime) / 1000000.0; + FileInfo[] sstFileStatuses = new FileInfo[] {}; + try { + // if there is no data, the directory will not exist. + sstFileStatuses = persistIO.listFileInfo(new Path(remotePath, DATAS)); + } catch (Exception e) { + LOGGER.warn("{} do not have data, just ignore", remotePath); + } + Path path = new Path(remotePath); + PathFilter filter = path1 -> path1.getName().startsWith(META); + FileInfo[] metaFileStatuses = persistIO.listFileInfo(path, filter); + Path delMetaPath = getLastMetaFile(chkId, metaFileStatuses); + if (delMetaPath == null) { + return; + } + Set toBeKepts = getKeptFileName(delMetaPath); + if (toBeKepts.size() == 0) { + return; + } + // commit tag is the latest file to upload. + long chkPointTime = + persistIO.getFileInfo(new Path(delMetaPath, COMMIT_TAG_FILE)).getModificationTime(); + LOGGER.info( + "remotePath {}, chkId: {}, chkPointTime {}, toBeKepts: {}", + remotePath, + chkId, + new Date(chkPointTime), + toBeKepts); + + List paths = + getDelPaths(chkId, chkPointTime, sstFileStatuses, metaFileStatuses, toBeKepts); + LOGGER.info( + "RocksDB({}) clean dfs checkpoint: ({}) took {}ms", + chkId, + paths.stream().map(Path::getName).collect(Collectors.joining(",")), + (System.currentTimeMillis() - start)); + asyncDeletes(paths); + } + + private List getDelPaths( + long chkId, + long chkPointTime, + FileInfo[] sstFileStatuses, + FileInfo[] metaFileStatuses, + Set toBeKepts) { + + Set toBeDels = new HashSet<>(); + List paths = Lists.newArrayList(); + for (FileInfo fileStatus : sstFileStatuses) { + if (fileStatus.getModificationTime() < chkPointTime + && !toBeKepts.contains(fileStatus.getPath().getName())) { + toBeDels.add(fileStatus.getPath().getName()); + paths.add(fileStatus.getPath()); LOGGER.info( - "RocksDB {} archive local:{} to {} (incre[{}]/full[{}]) took {}ms. incre data size {}KB, speed {}KB/s {}", - persistIO.getPersistentType(), localChkFile.getAbsolutePath(), remotePath, - currentFileInfo.getIncDataFiles().size(), currentFileInfo.getFullDataFiles().size(), - costMs, size / 1024, size * 1000 / (1024 * costMs), - currentFileInfo.getIncDataFiles().toString()); - - checkPointFileInfo.put(chkId, currentFileInfo); - - backgroundDeleteService.execute(() -> - cleanLocalAndRemoteFiles(chkId, remotePath, keepCheckpointNum, localChkFile)); + "delete file: {} time: {}", + fileStatus.getPath(), + new Date(fileStatus.getModificationTime())); + } } + LOGGER.info("kepts: {}, dels: {} ", toBeKepts, toBeDels); - public long getLatestCheckpointId(String remotePathStr) { - try { - if (!persistIO.exists(new Path(remotePathStr))) { - return -1; - } - List files = persistIO.listFileName(new Path(remotePathStr)); - List chkIds = files.stream().filter(f -> f.startsWith(META)).map(this::getMetaFileId) - .filter(f -> f > 0).sorted(Collections.reverseOrder()).collect(Collectors.toList()); - LOGGER.info("find available chk {}", chkIds); - for (Long chkId : chkIds) { - String path = Paths.get(remotePathStr, getMetaFileName(chkId), COMMIT_TAG_FILE).toString(); - if (persistIO.exists(new Path(path))) { - return chkId; - } else { - LOGGER.info("chk {} has no path {}", chkId, path); - } - } - } catch (IOException e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.stateRocksDbError("recover fail"), e); - } - return -1; + for (final FileInfo fileStatus : metaFileStatuses) { + long chkVersion = getChkVersion(fileStatus.getPath().getName()); + if (chkVersion < chkId) { + paths.add(fileStatus.getPath()); + } } - - public void recover(long chkId, String localRdbPath, String localChkPath, String remotePathStr) - throws Exception { - checkPointFileInfo.clear(); - File rocksDBChkFile = new File(localChkPath); - File rocksDBFile = new File(localRdbPath); - LOGGER.info("delete {} {}", localChkPath, localRdbPath); - FileUtils.deleteQuietly(rocksDBChkFile); - FileUtils.deleteQuietly(rocksDBFile); - - rocksDBChkFile.mkdirs(); - rocksDBFile.mkdirs(); - - final long startTime = System.currentTimeMillis(); - Path remotePath = new Path(remotePathStr); - if (!persistIO.exists(remotePath)) { - String msg = String.format("checkPoint: %s is not exist in remote", remotePath); - LOGGER.warn(msg); - throw new GeaflowRuntimeException(RuntimeErrors.INST.stateRocksDbError(msg)); - } - - // fetch manifests. - String remoteMeta = Paths.get(remotePathStr, getMetaFileName(chkId)).toString(); - InputStream in = persistIO.open(new Path(remoteMeta, FILES)); - String sstString = IOUtils.toString(in, Charset.defaultCharset()); - List list = Splitter.on(FILE_SEPARATOR).omitEmptyStrings().splitToList(sstString); - - CheckPointFileInfo commitedInfo = new CheckPointFileInfo(chkId); - recoveryData(remotePath, rocksDBChkFile, commitedInfo, list, remoteMeta); - LOGGER.info("recoveryFromRemote {} cost {}ms", remotePath, - System.currentTimeMillis() - startTime); - checkPointFileInfo.put(chkId, commitedInfo); - - for (File file : rocksDBChkFile.listFiles()) { - Files.createLink(FileSystems.getDefault().getPath(localRdbPath, file.getName()), file.toPath()); - } - - backgroundDeleteService.execute(() -> cleanLocalChk(chkId, new File(localChkPath))); + return paths; + } + + private Path getLastMetaFile(long chkId, FileInfo[] metaFileStatuses) { + // find the last meta file that indicates the largest committed chkId. + int maxMetaVersion = 0; + FileInfo fileInfo = null; + for (FileInfo fileStatus : metaFileStatuses) { + int metaVersion = getChkVersion(fileStatus.getPath().getName()); + if (metaVersion < chkId && metaVersion > maxMetaVersion) { + maxMetaVersion = metaVersion; + fileInfo = fileStatus; + } } - - private static void cleanLocalChk(long chkId, File localChkFile) { - String chkPrefix = RocksdbConfigKeys.getChkPathPrefix(localChkFile.getName()); - FilenameFilter filter = (dir, name) -> { - if (RocksdbConfigKeys.isChkPath(name) && name.startsWith(chkPrefix)) { - return chkId > RocksdbConfigKeys.getChkIdFromChkPath(name); - } else { - return false; - } - }; - File[] subFiles = localChkFile.getParentFile().listFiles(filter); - for (File path : subFiles) { - LOGGER.info("delete local chk {}", path.toURI()); - FileUtils.deleteQuietly(path); - } + if (maxMetaVersion == 0) { + return null; } - - private Set getLastFullFiles(long chkId, String localChkPath, String remotePath) - throws IOException { - CheckPointFileInfo commitFileInfo = checkPointFileInfo.get(chkId); - if (commitFileInfo == null) { - Entry info = checkPointFileInfo.lowerEntry(chkId); - if (info != null) { - commitFileInfo = info.getValue(); - } else { - Path path = new Path(remotePath); - PathFilter filter = path1 -> path1.getName().startsWith(META); - if (persistIO.exists(path)) { - FileInfo[] metaFileStatuses = persistIO.listFileInfo(path, filter); - Path lastMetaPath = getLastMetaFile(chkId, metaFileStatuses); - if (lastMetaPath != null) { - commitFileInfo = new CheckPointFileInfo(chkId); - commitFileInfo.addFullFiles(getKeptFileName(lastMetaPath)); - } - } - } - } - - Set lastFullFiles; - if (commitFileInfo != null) { - lastFullFiles = new HashSet<>(commitFileInfo.getFullDataFiles()); - } else { - lastFullFiles = new HashSet<>(); - } - File file = new File(localChkPath); - - // current sst number must be larger than the last one. - String[] curNames = file.list(); - Preconditions.checkNotNull(curNames, localChkPath + " is null"); - - Optional chkLargestSst = Arrays.stream(curNames) - .filter(c -> c.endsWith(SST_SUFFIX)).map(this::getSstIndex).max(Long::compareTo); - Optional lastLargestSst = lastFullFiles.stream().filter(c -> c.endsWith(SST_SUFFIX)) - .map(this::getSstIndex).max(Long::compareTo); - if (chkLargestSst.isPresent() && lastLargestSst.isPresent()) { - Preconditions.checkArgument(chkLargestSst.get().compareTo(lastLargestSst.get()) >= 0, - "%s < %s, chk path %s, check FO and recovery.", - chkLargestSst.get(), lastLargestSst.get(), localChkPath); - } - return lastFullFiles; + return fileInfo.getPath(); + } + + private Set getKeptFileName(Path metaPath) throws IOException { + Path filesPath = new Path(metaPath, FILES); + InputStream in = persistIO.open(filesPath); + String sstString = IOUtils.toString(in, Charset.defaultCharset()); + return Sets.newHashSet(Splitter.on(",").split(sstString)); + } + + private int getChkVersion(String filename) { + return Integer.parseInt(filename.substring(filename.indexOf('.') + 1)); + } + + private List completeHandler(List> callers, ExecutorService executorService) { + List> futures = new ArrayList<>(); + List results = new ArrayList<>(); + for (final Callable entry : callers) { + futures.add(executorService.submit(entry)); } - private void cleanLocalAndRemoteFiles(long chkId, String remotePath, long keepCheckpointNum, File localChkFile) { - try { - removeEarlyChk(remotePath, chkId - keepCheckpointNum); - } catch (IOException ignore) { - LOGGER.warn("remove Early chk fail and ignore {}, chkId {}, keepChkNum {}", remotePath, - chkId, keepCheckpointNum); - } - Long key; - while ((key = checkPointFileInfo.lowerKey(chkId)) != null) { - checkPointFileInfo.remove(key); - } - - cleanLocalChk(chkId, localChkFile); + try { + for (Future future : futures) { + results.add(future.get(persistTimeout, TimeUnit.SECONDS)); + } + } catch (Exception e) { + throw new GeaflowRuntimeException( + RuntimeErrors.INST.stateRocksDbError("persist time out or other exceptions"), e); } - - private void removeEarlyChk(String remotePath, long chkId) - throws IOException { - final long start = System.currentTimeMillis(); - LOGGER.info("skip remove early chk {} {}", remotePath, chkId); - - FileInfo[] sstFileStatuses = new FileInfo[]{}; + return results; + } + + private Tuple checkSizeSame(final Path dfsPath, final Path localPath) + throws IOException { + long len = persistIO.getFileSize(dfsPath); + File localFile = new File(localPath.toString()); + return Tuple.of(len == localFile.length(), len); + } + + private Callable copyFromLocal(final Path from, final Path to, final long size) { + return () -> { + int count = 0; + int maxTries = 3; + Tuple checkRes; + while (true) { try { - //if there is no data, the directory will not exist. - sstFileStatuses = persistIO.listFileInfo(new Path(remotePath, DATAS)); - } catch (Exception e) { - LOGGER.warn("{} do not have data, just ignore", remotePath); - } - - Path path = new Path(remotePath); - PathFilter filter = path1 -> path1.getName().startsWith(META); - FileInfo[] metaFileStatuses = persistIO.listFileInfo(path, filter); - Path delMetaPath = getLastMetaFile(chkId, metaFileStatuses); - if (delMetaPath == null) { - return; - } - Set toBeKepts = getKeptFileName(delMetaPath); - if (toBeKepts.size() == 0) { - return; - } - // commit tag is the latest file to upload. - long chkPointTime = persistIO.getFileInfo(new Path(delMetaPath, COMMIT_TAG_FILE)).getModificationTime(); - LOGGER.info("remotePath {}, chkId: {}, chkPointTime {}, toBeKepts: {}", - remotePath, chkId, new Date(chkPointTime), toBeKepts); - - List paths = getDelPaths(chkId, chkPointTime, sstFileStatuses, metaFileStatuses, toBeKepts); - LOGGER.info("RocksDB({}) clean dfs checkpoint: ({}) took {}ms", chkId, - paths.stream().map(Path::getName).collect(Collectors.joining(",")), - (System.currentTimeMillis() - start)); - asyncDeletes(paths); - } - - private List getDelPaths(long chkId, long chkPointTime, FileInfo[] sstFileStatuses, - FileInfo[] metaFileStatuses, Set toBeKepts) { - - Set toBeDels = new HashSet<>(); - List paths = Lists.newArrayList(); - for (FileInfo fileStatus : sstFileStatuses) { - if (fileStatus.getModificationTime() < chkPointTime - && !toBeKepts.contains(fileStatus.getPath().getName())) { - toBeDels.add(fileStatus.getPath().getName()); - paths.add(fileStatus.getPath()); - LOGGER.info("delete file: {} time: {}", - fileStatus.getPath(), new Date(fileStatus.getModificationTime())); + long start = System.currentTimeMillis(); + persistIO.copyFromLocalFile(from, to); + checkRes = checkSizeSame(to, from); + if (!checkRes.f0) { + LOGGER.warn("upload to dfs size not same {} -> {}", from, to); + if (++count == maxTries) { + throw new GeaflowRuntimeException( + RuntimeErrors.INST.stateRocksDbError("upload to dfs size not same")); } - } - - LOGGER.info("kepts: {}, dels: {} ", toBeKepts, toBeDels); - - for (final FileInfo fileStatus : metaFileStatuses) { - long chkVersion = getChkVersion(fileStatus.getPath().getName()); - if (chkVersion < chkId) { - paths.add(fileStatus.getPath()); + } else { + LOGGER.info( + "upload to dfs size {}KB took {}ms {} -> {}", + size / 1024, + System.currentTimeMillis() - start, + from, + to); + break; + } + } catch (IOException ex) { + if (++count == maxTries) { + throw new GeaflowRuntimeException( + RuntimeErrors.INST.stateRocksDbError("upload to dfs exception"), ex); + } + } + } + return checkRes.f1; + }; + } + + private Callable copyToLocal(final Path from, final Path to) { + return () -> { + int count = 0; + int maxTries = 3; + Tuple checkRes; + while (true) { + try { + persistIO.copyToLocalFile(from, to); + checkRes = checkSizeSame(from, to); + if (!checkRes.f0) { + LOGGER.warn("download from dfs size not same {} -> {}", from, to); + if (++count == maxTries) { + String msg = "download from dfs size not same: " + from; + throw new GeaflowRuntimeException(RuntimeErrors.INST.stateRocksDbError(msg)); } - } - return paths; - } - - private Path getLastMetaFile(long chkId, FileInfo[] metaFileStatuses) { - // find the last meta file that indicates the largest committed chkId. - int maxMetaVersion = 0; - FileInfo fileInfo = null; - for (FileInfo fileStatus : metaFileStatuses) { - int metaVersion = getChkVersion(fileStatus.getPath().getName()); - if (metaVersion < chkId && metaVersion > maxMetaVersion) { - maxMetaVersion = metaVersion; - fileInfo = fileStatus; + } else { + LOGGER.info("download from dfs {} -> {}", from, to); + break; + } + } catch (IOException ex) { + if (++count == maxTries) { + throw new GeaflowRuntimeException( + RuntimeErrors.INST.stateRocksDbError("copy from dfs exception"), ex); + } + } + } + return checkRes.f1; + }; + } + + private void asyncDeletes(final List paths) { + deleteFileService.execute( + () -> { + long start = System.currentTimeMillis(); + for (Path path : paths) { + try { + long s = System.nanoTime(); + persistIO.delete(path, true); + LOGGER.info("async Delete path {} cost {}us", path, (System.nanoTime() - s) / 1000); + } catch (IOException e) { + LOGGER.warn("delete fail", e); } - } - if (maxMetaVersion == 0) { - return null; - } - return fileInfo.getPath(); + } + LOGGER.info("asyncDeletes path {} cost {}ms", paths, System.currentTimeMillis() - start); + }); + } + + private long recoveryData( + Path remotePath, + File localChkFile, + CheckPointFileInfo committedInfo, + List list, + String remoteMeta) + throws Exception { + // fetch data list. + LOGGER.info("recoveryData {} list {}", remotePath, list); + + List> callers = new ArrayList<>(); + for (String sstName : list) { + callers.add( + copyToLocal( + new Path(Paths.get(remotePath.toString(), DATAS, sstName).toString()), + new Path(localChkFile.getAbsolutePath(), sstName))); } - - private Set getKeptFileName(Path metaPath) - throws IOException { - Path filesPath = new Path(metaPath, FILES); - InputStream in = persistIO.open(filesPath); - String sstString = IOUtils.toString(in, Charset.defaultCharset()); - return Sets.newHashSet(Splitter.on(",").split(sstString)); + List metaList = persistIO.listFileName(new Path(remoteMeta)); + for (String metaName : metaList) { + callers.add( + copyToLocal( + new Path(remoteMeta, metaName), new Path(localChkFile.getAbsolutePath(), metaName))); } - - private int getChkVersion(String filename) { - return Integer.parseInt(filename.substring(filename.indexOf('.') + 1)); + long start = System.currentTimeMillis(); + List res = completeHandler(callers, copyFileService); + long size = res.stream().mapToLong(i -> i).sum() / 1024; + long speed = 1000 * size / (System.currentTimeMillis() - start + 1); + + LOGGER.info( + "RocksDB {} copy ({} to local:{}) lastCommitInfo:{}. size: {}KB, speed: {}KB/s", + persistIO.getPersistentType(), + remotePath, + localChkFile, + committedInfo, + size / 1024, + speed); + String[] localChkFiles = localChkFile.list((dir, name) -> name.endsWith(SST_SUFFIX)); + if (localChkFiles != null) { + for (String chkFile : localChkFiles) { + committedInfo.addFullFile(chkFile); + } + } else { + Preconditions.checkArgument(list.size() == 0, "sst is not fetched."); } + return size; + } - private List completeHandler(List> callers, - ExecutorService executorService) { - List> futures = new ArrayList<>(); - List results = new ArrayList<>(); - for (final Callable entry : callers) { - futures.add(executorService.submit(entry)); - } + public static class CheckPointFileInfo { + private long checkPointId; + private Set incDataFiles = new HashSet<>(); + private Set fullDataFiles = new HashSet<>(); - try { - for (Future future : futures) { - results.add(future.get(persistTimeout, TimeUnit.SECONDS)); - } - } catch (Exception e) { - throw new GeaflowRuntimeException( - RuntimeErrors.INST.stateRocksDbError("persist time out or other exceptions"), e); - } - return results; - } - - private Tuple checkSizeSame(final Path dfsPath, final Path localPath) - throws IOException { - long len = persistIO.getFileSize(dfsPath); - File localFile = new File(localPath.toString()); - return Tuple.of(len == localFile.length(), len); - } - - private Callable copyFromLocal(final Path from, final Path to, final long size) { - return () -> { - int count = 0; - int maxTries = 3; - Tuple checkRes; - while (true) { - try { - long start = System.currentTimeMillis(); - persistIO.copyFromLocalFile(from, to); - checkRes = checkSizeSame(to, from); - if (!checkRes.f0) { - LOGGER.warn("upload to dfs size not same {} -> {}", from, to); - if (++count == maxTries) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.stateRocksDbError("upload to dfs size not same")); - } - } else { - LOGGER.info("upload to dfs size {}KB took {}ms {} -> {}", size / 1024, - System.currentTimeMillis() - start, from, to); - break; - } - } catch (IOException ex) { - if (++count == maxTries) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.stateRocksDbError( - "upload to dfs exception"), ex); - } - } - } - return checkRes.f1; - }; + public CheckPointFileInfo(long checkPointId) { + this.checkPointId = checkPointId; } - private Callable copyToLocal(final Path from, final Path to) { - return () -> { - int count = 0; - int maxTries = 3; - Tuple checkRes; - while (true) { - try { - persistIO.copyToLocalFile(from, to); - checkRes = checkSizeSame(from, to); - if (!checkRes.f0) { - LOGGER.warn("download from dfs size not same {} -> {}", from, to); - if (++count == maxTries) { - String msg = "download from dfs size not same: " + from; - throw new GeaflowRuntimeException(RuntimeErrors.INST.stateRocksDbError(msg)); - } - } else { - LOGGER.info("download from dfs {} -> {}", from, to); - break; - } - } catch (IOException ex) { - if (++count == maxTries) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.stateRocksDbError( - "copy from dfs exception"), ex); - } - } - } - return checkRes.f1; - }; + public long getCheckPointId() { + return checkPointId; } - private void asyncDeletes(final List paths) { - deleteFileService.execute(() -> { - long start = System.currentTimeMillis(); - for (Path path : paths) { - try { - long s = System.nanoTime(); - persistIO.delete(path, true); - LOGGER.info("async Delete path {} cost {}us", path, (System.nanoTime() - s) / 1000); - } catch (IOException e) { - LOGGER.warn("delete fail", e); - } - } - LOGGER.info("asyncDeletes path {} cost {}ms", paths, - System.currentTimeMillis() - start); - }); + public void addIncDataFile(String name) { + incDataFiles.add(name); } - - private long recoveryData(Path remotePath, File localChkFile, - CheckPointFileInfo committedInfo, List list, String remoteMeta) - throws Exception { - // fetch data list. - LOGGER.info("recoveryData {} list {}", remotePath, list); - - List> callers = new ArrayList<>(); - for (String sstName : list) { - callers.add( - copyToLocal(new Path(Paths.get(remotePath.toString(), DATAS, sstName).toString()), - new Path(localChkFile.getAbsolutePath(), sstName))); - } - List metaList = persistIO.listFileName(new Path(remoteMeta)); - for (String metaName : metaList) { - callers.add( - copyToLocal(new Path(remoteMeta, metaName), new Path(localChkFile.getAbsolutePath(), metaName))); - } - long start = System.currentTimeMillis(); - List res = completeHandler(callers, copyFileService); - long size = res.stream().mapToLong(i -> i).sum() / 1024; - long speed = 1000 * size / (System.currentTimeMillis() - start + 1); - - LOGGER.info( - "RocksDB {} copy ({} to local:{}) lastCommitInfo:{}. size: {}KB, speed: {}KB/s", - persistIO.getPersistentType(), remotePath, localChkFile, committedInfo, size / 1024, speed); - String[] localChkFiles = localChkFile.list((dir, name) -> name.endsWith(SST_SUFFIX)); - if (localChkFiles != null) { - for (String chkFile : localChkFiles) { - committedInfo.addFullFile(chkFile); - } - } else { - Preconditions.checkArgument(list.size() == 0, "sst is not fetched."); - } - return size; + public void addFullFile(String name) { + fullDataFiles.add(name); } - public static class CheckPointFileInfo { - private long checkPointId; - private Set incDataFiles = new HashSet<>(); - private Set fullDataFiles = new HashSet<>(); - - public CheckPointFileInfo(long checkPointId) { - this.checkPointId = checkPointId; - } - - public long getCheckPointId() { - return checkPointId; - } - - public void addIncDataFile(String name) { - incDataFiles.add(name); - } - - public void addFullFile(String name) { - fullDataFiles.add(name); - } - - public void addFullFiles(Collection name) { - fullDataFiles.addAll(name); - } + public void addFullFiles(Collection name) { + fullDataFiles.addAll(name); + } - @Override - public String toString() { - return String - .format("CheckPointFileInfo [checkPointId=%d, incDataFiles=%s, fullDataFiles=%s]", - this.checkPointId, this.incDataFiles, this.fullDataFiles); - } + @Override + public String toString() { + return String.format( + "CheckPointFileInfo [checkPointId=%d, incDataFiles=%s, fullDataFiles=%s]", + this.checkPointId, this.incDataFiles, this.fullDataFiles); + } - public Set getIncDataFiles() { - return this.incDataFiles; - } + public Set getIncDataFiles() { + return this.incDataFiles; + } - public Set getFullDataFiles() { - return this.fullDataFiles; - } + public Set getFullDataFiles() { + return this.fullDataFiles; } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/RocksdbStoreBuilder.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/RocksdbStoreBuilder.java index 0073932ac..f9cc2b873 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/RocksdbStoreBuilder.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/RocksdbStoreBuilder.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.List; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -32,41 +33,41 @@ public class RocksdbStoreBuilder implements IStoreBuilder { - private static final StoreDesc STORE_DESC = new RocksdbStoreDesc(); + private static final StoreDesc STORE_DESC = new RocksdbStoreDesc(); - public IBaseStore getStore(DataModel type, Configuration config) { - switch (type) { - case KV: - return new KVRocksdbStoreBase(); - case STATIC_GRAPH: - return new StaticGraphRocksdbStoreBase(); - case DYNAMIC_GRAPH: - return new DynamicGraphRocksdbStoreBase(); - default: - throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError("not support " + type)); - } + public IBaseStore getStore(DataModel type, Configuration config) { + switch (type) { + case KV: + return new KVRocksdbStoreBase(); + case STATIC_GRAPH: + return new StaticGraphRocksdbStoreBase(); + case DYNAMIC_GRAPH: + return new DynamicGraphRocksdbStoreBase(); + default: + throw new GeaflowRuntimeException(RuntimeErrors.INST.typeSysError("not support " + type)); } + } - @Override - public StoreDesc getStoreDesc() { - return STORE_DESC; - } + @Override + public StoreDesc getStoreDesc() { + return STORE_DESC; + } - @Override - public List supportedDataModel() { - return Arrays.asList(DataModel.KV, DataModel.DYNAMIC_GRAPH, DataModel.STATIC_GRAPH); - } + @Override + public List supportedDataModel() { + return Arrays.asList(DataModel.KV, DataModel.DYNAMIC_GRAPH, DataModel.STATIC_GRAPH); + } - public static class RocksdbStoreDesc implements StoreDesc { + public static class RocksdbStoreDesc implements StoreDesc { - @Override - public boolean isLocalStore() { - return true; - } + @Override + public boolean isLocalStore() { + return true; + } - @Override - public String name() { - return StoreType.ROCKSDB.name(); - } + @Override + public String name() { + return StoreType.ROCKSDB.name(); } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/StaticGraphRocksdbStoreBase.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/StaticGraphRocksdbStoreBase.java index 81bffca80..7a9a4f9db 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/StaticGraphRocksdbStoreBase.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/StaticGraphRocksdbStoreBase.java @@ -27,6 +27,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.common.tuple.Tuple; @@ -42,145 +43,152 @@ import org.apache.geaflow.store.rocksdb.proxy.IGraphRocksdbProxy; import org.apache.geaflow.store.rocksdb.proxy.ProxyBuilder; -public class StaticGraphRocksdbStoreBase extends BaseRocksdbGraphStore implements - IStaticGraphStore { - - private IGraphRocksdbProxy proxy; - private EdgeAtom sortAtom; - private PartitionType partitionType; - - @Override - public void init(StoreContext storeContext) { - // Init partition type for rocksdb graph store - partitionType = PartitionType.getEnum(storeContext.getConfig() - .getString(RocksdbConfigKeys.ROCKSDB_GRAPH_STORE_PARTITION_TYPE)); - - super.init(storeContext); - IGraphKVEncoder encoder = GraphKVEncoderFactory.build(config, - storeContext.getGraphSchema()); - sortAtom = storeContext.getGraphSchema().getEdgeAtoms().get(1); - this.proxy = ProxyBuilder.build(config, rocksdbClient, encoder); - } - - @Override - protected List getCfList() { - if (!partitionType.isPartition()) { - return Arrays.asList(VERTEX_CF, EDGE_CF); - } - - return Collections.singletonList(DEFAULT_CF); - } - - @Override - public void addEdge(IEdge edge) { - this.proxy.addEdge(edge); - } - - @Override - public void addVertex(IVertex vertex) { - this.proxy.addVertex(vertex); - } - - @Override - public IVertex getVertex(K sid, IStatePushDown pushdown) { - return this.proxy.getVertex(sid, pushdown); - } - - @Override - public List> getEdges(K sid, IStatePushDown pushdown) { - checkOrderField(pushdown.getOrderFields()); - return proxy.getEdges(sid, pushdown); - } - - @Override - public OneDegreeGraph getOneDegreeGraph(K sid, IStatePushDown pushdown) { - checkOrderField(pushdown.getOrderFields()); - return proxy.getOneDegreeGraph(sid, pushdown); - } - - @Override - public CloseableIterator vertexIDIterator() { - return this.proxy.vertexIDIterator(); - } - - @Override - public CloseableIterator vertexIDIterator(IStatePushDown pushDown) { - return proxy.vertexIDIterator(pushDown); - } - - @Override - public CloseableIterator> getVertexIterator(IStatePushDown pushdown) { - return proxy.getVertexIterator(pushdown); - } - - @Override - public CloseableIterator> getVertexIterator(List keys, IStatePushDown pushdown) { - return proxy.getVertexIterator(keys, pushdown); - } - - @Override - public CloseableIterator> getEdgeIterator(IStatePushDown pushdown) { - checkOrderField(pushdown.getOrderFields()); - return proxy.getEdgeIterator(pushdown); - } - - @Override - public CloseableIterator> getEdgeIterator(List keys, IStatePushDown pushdown) { - checkOrderField(pushdown.getOrderFields()); - return proxy.getEdgeIterator(keys, pushdown); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator( - IStatePushDown pushdown) { - checkOrderField(pushdown.getOrderFields()); - return proxy.getOneDegreeGraphIterator(pushdown); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(List keys, IStatePushDown pushdown) { - checkOrderField(pushdown.getOrderFields()); - return proxy.getOneDegreeGraphIterator(keys, pushdown); - } - - @Override - public CloseableIterator> getEdgeProjectIterator( - IStatePushDown, R> pushdown) { - return proxy.getEdgeProjectIterator(pushdown); - } - - @Override - public CloseableIterator> getEdgeProjectIterator(List keys, IStatePushDown, R> pushdown) { - return proxy.getEdgeProjectIterator(keys, pushdown); - } - - @Override - public Map getAggResult(IStatePushDown pushdown) { - return proxy.getAggResult(pushdown); - } - - @Override - public Map getAggResult(List keys, IStatePushDown pushdown) { - return proxy.getAggResult(keys, pushdown); - } - - private void checkOrderField(List orderFields) { - boolean emptyFields = orderFields == null || orderFields.isEmpty(); - boolean checkOk = emptyFields || sortAtom == orderFields.get(0); - if (!checkOk) { - throw new GeaflowRuntimeException(String.format("store is sort by %s but need %s", sortAtom, orderFields.get(0))); - } - } - - @Override - public void flush() { - proxy.flush(); - super.flush(); - } - - @Override - public void close() { - proxy.close(); - super.close(); - } +public class StaticGraphRocksdbStoreBase extends BaseRocksdbGraphStore + implements IStaticGraphStore { + + private IGraphRocksdbProxy proxy; + private EdgeAtom sortAtom; + private PartitionType partitionType; + + @Override + public void init(StoreContext storeContext) { + // Init partition type for rocksdb graph store + partitionType = + PartitionType.getEnum( + storeContext + .getConfig() + .getString(RocksdbConfigKeys.ROCKSDB_GRAPH_STORE_PARTITION_TYPE)); + + super.init(storeContext); + IGraphKVEncoder encoder = + GraphKVEncoderFactory.build(config, storeContext.getGraphSchema()); + sortAtom = storeContext.getGraphSchema().getEdgeAtoms().get(1); + this.proxy = ProxyBuilder.build(config, rocksdbClient, encoder); + } + + @Override + protected List getCfList() { + if (!partitionType.isPartition()) { + return Arrays.asList(VERTEX_CF, EDGE_CF); + } + + return Collections.singletonList(DEFAULT_CF); + } + + @Override + public void addEdge(IEdge edge) { + this.proxy.addEdge(edge); + } + + @Override + public void addVertex(IVertex vertex) { + this.proxy.addVertex(vertex); + } + + @Override + public IVertex getVertex(K sid, IStatePushDown pushdown) { + return this.proxy.getVertex(sid, pushdown); + } + + @Override + public List> getEdges(K sid, IStatePushDown pushdown) { + checkOrderField(pushdown.getOrderFields()); + return proxy.getEdges(sid, pushdown); + } + + @Override + public OneDegreeGraph getOneDegreeGraph(K sid, IStatePushDown pushdown) { + checkOrderField(pushdown.getOrderFields()); + return proxy.getOneDegreeGraph(sid, pushdown); + } + + @Override + public CloseableIterator vertexIDIterator() { + return this.proxy.vertexIDIterator(); + } + + @Override + public CloseableIterator vertexIDIterator(IStatePushDown pushDown) { + return proxy.vertexIDIterator(pushDown); + } + + @Override + public CloseableIterator> getVertexIterator(IStatePushDown pushdown) { + return proxy.getVertexIterator(pushdown); + } + + @Override + public CloseableIterator> getVertexIterator( + List keys, IStatePushDown pushdown) { + return proxy.getVertexIterator(keys, pushdown); + } + + @Override + public CloseableIterator> getEdgeIterator(IStatePushDown pushdown) { + checkOrderField(pushdown.getOrderFields()); + return proxy.getEdgeIterator(pushdown); + } + + @Override + public CloseableIterator> getEdgeIterator(List keys, IStatePushDown pushdown) { + checkOrderField(pushdown.getOrderFields()); + return proxy.getEdgeIterator(keys, pushdown); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + IStatePushDown pushdown) { + checkOrderField(pushdown.getOrderFields()); + return proxy.getOneDegreeGraphIterator(pushdown); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + List keys, IStatePushDown pushdown) { + checkOrderField(pushdown.getOrderFields()); + return proxy.getOneDegreeGraphIterator(keys, pushdown); + } + + @Override + public CloseableIterator> getEdgeProjectIterator( + IStatePushDown, R> pushdown) { + return proxy.getEdgeProjectIterator(pushdown); + } + + @Override + public CloseableIterator> getEdgeProjectIterator( + List keys, IStatePushDown, R> pushdown) { + return proxy.getEdgeProjectIterator(keys, pushdown); + } + + @Override + public Map getAggResult(IStatePushDown pushdown) { + return proxy.getAggResult(pushdown); + } + + @Override + public Map getAggResult(List keys, IStatePushDown pushdown) { + return proxy.getAggResult(keys, pushdown); + } + + private void checkOrderField(List orderFields) { + boolean emptyFields = orderFields == null || orderFields.isEmpty(); + boolean checkOk = emptyFields || sortAtom == orderFields.get(0); + if (!checkOk) { + throw new GeaflowRuntimeException( + String.format("store is sort by %s but need %s", sortAtom, orderFields.get(0))); + } + } + + @Override + public void flush() { + proxy.flush(); + super.flush(); + } + + @Override + public void close() { + proxy.close(); + super.close(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/iterator/RocksdbIterator.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/iterator/RocksdbIterator.java index 879b2671a..e3273702f 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/iterator/RocksdbIterator.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/iterator/RocksdbIterator.java @@ -26,50 +26,50 @@ public class RocksdbIterator implements CloseableIterator> { - private final RocksIterator rocksIt; - private byte[] prefix; - private Tuple next; - private boolean isClosed = false; + private final RocksIterator rocksIt; + private byte[] prefix; + private Tuple next; + private boolean isClosed = false; - public RocksdbIterator(RocksIterator iterator) { - this.rocksIt = iterator; - this.rocksIt.seekToFirst(); - } + public RocksdbIterator(RocksIterator iterator) { + this.rocksIt = iterator; + this.rocksIt.seekToFirst(); + } - public RocksdbIterator(RocksIterator iterator, byte[] prefix) { - this.rocksIt = iterator; - this.prefix = prefix; - this.rocksIt.seek(prefix); - } + public RocksdbIterator(RocksIterator iterator, byte[] prefix) { + this.rocksIt = iterator; + this.prefix = prefix; + this.rocksIt.seek(prefix); + } - private boolean isValid(byte[] key) { - return prefix == null || ByteUtils.isStartsWith(key, prefix); - } + private boolean isValid(byte[] key) { + return prefix == null || ByteUtils.isStartsWith(key, prefix); + } - @Override - public boolean hasNext() { - next = null; - if (!isClosed && this.rocksIt.isValid()) { - next = Tuple.of(this.rocksIt.key(), this.rocksIt.value()); - } - if (next == null || !isValid(next.f0)) { - close(); - return false; - } - return true; + @Override + public boolean hasNext() { + next = null; + if (!isClosed && this.rocksIt.isValid()) { + next = Tuple.of(this.rocksIt.key(), this.rocksIt.value()); } - - @Override - public Tuple next() { - this.rocksIt.next(); - return next; + if (next == null || !isValid(next.f0)) { + close(); + return false; } + return true; + } + + @Override + public Tuple next() { + this.rocksIt.next(); + return next; + } - @Override - public void close() { - if (!isClosed) { - this.rocksIt.close(); - isClosed = true; - } + @Override + public void close() { + if (!isClosed) { + this.rocksIt.close(); + isClosed = true; } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/options/DefaultGraphOptions.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/options/DefaultGraphOptions.java index 97361050b..197defebd 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/options/DefaultGraphOptions.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/options/DefaultGraphOptions.java @@ -36,151 +36,149 @@ public class DefaultGraphOptions implements IRocksDBOptions { - protected Statistics statistics; - - protected Options options = new Options(); - - protected DBOptions dbOptions = new DBOptions(); - - protected WriteOptions writeOptions = new WriteOptions(); - - protected ReadOptions readOptions = new ReadOptions(); - - protected FlushOptions flushOptions = new FlushOptions(); - - protected boolean closed; - private int maxWriteBufferNumber; - private long blockSize; - private long blockCacheSize; - private long writeBufferSize; - private long targetFileSize; - - public DefaultGraphOptions() { - - } - - protected void initOption() { - options.setUseDirectIoForFlushAndCompaction(true); - options.setCreateIfMissing(true); - options.setCreateMissingColumnFamilies(true); - options.setMergeOperator(new StringAppendOperator()); - options.setTableFormatConfig(buildBlockBasedTableConfig()); - options.setMaxWriteBufferNumber(maxWriteBufferNumber); - // Amount of data to build up in memory (backed by an unsorted log - // on disk) before converting to a sorted on-disk file. 64MB default - options.setWriteBufferSize(writeBufferSize); - // target_file_size_base is per-file size for level-1. - options.setTargetFileSizeBase(targetFileSize); - options.setMaxBackgroundFlushes(2); - options.setMaxBackgroundCompactions(2); - options.setLevelZeroFileNumCompactionTrigger(20); - // Soft limit on number of level-0 files. - options.setLevelZeroSlowdownWritesTrigger(30); - // Maximum number of level-0 files. We stop writes at this point. - options.setLevelZeroStopWritesTrigger(40); - options.setNumLevels(4); - options.setMaxManifestFileSize(50 * SizeUnit.KB); - - dbOptions.setCreateIfMissing(true); - dbOptions.setCreateMissingColumnFamilies(true); - dbOptions.setMaxBackgroundFlushes(2); - dbOptions.setMaxBackgroundCompactions(6); - dbOptions.setMaxManifestFileSize(50 * SizeUnit.KB); - - writeOptions.setDisableWAL(true); - flushOptions.setWaitForFlush(true); - } - - @Override - public void init(Configuration config) { - this.maxWriteBufferNumber = - config.getInteger(RocksdbConfigKeys.ROCKSDB_OPTIONS_MAX_WRITER_BUFFER_NUM); - this.writeBufferSize = config.getLong(RocksdbConfigKeys.ROCKSDB_OPTIONS_WRITER_BUFFER_SIZE); - this.targetFileSize = config.getLong(RocksdbConfigKeys.ROCKSDB_OPTIONS_TARGET_FILE_SIZE); - this.blockSize = config.getLong(RocksdbConfigKeys.ROCKSDB_OPTIONS_TABLE_BLOCK_SIZE); - this.blockCacheSize = config.getLong(RocksdbConfigKeys.ROCKSDB_OPTIONS_TABLE_BLOCK_CACHE_SIZE); - - initOption(); - } - - public BlockBasedTableConfig buildBlockBasedTableConfig() { - BlockBasedTableConfig tableOptions = new BlockBasedTableConfig(); - tableOptions.setBlockSize(blockSize); - tableOptions.setBlockCacheSize(blockCacheSize); - tableOptions.setFilter(new BloomFilter(10, false)); - tableOptions.setCacheIndexAndFilterBlocks(true); - tableOptions.setPinL0FilterAndIndexBlocksInCache(true); - return tableOptions; - } - - public Options getOptions() { - return options; - } - - public WriteOptions getWriteOptions() { - return writeOptions; - } - - public ReadOptions getReadOptions() { - return readOptions; - } - - public FlushOptions getFlushOptions() { - return flushOptions; - } - - @Override - public DBOptions getDbOptions() { - return dbOptions; - } - - @Override - public ColumnFamilyOptions buildFamilyOptions() { - ColumnFamilyOptions columnFamilyOptions = new ColumnFamilyOptions(); - columnFamilyOptions.setWriteBufferSize(writeBufferSize); - // target_file_size_base is per-file size for level-1. - columnFamilyOptions.setTargetFileSizeBase(targetFileSize); - columnFamilyOptions.setLevelZeroFileNumCompactionTrigger(20); - // Soft limit on number of level-0 files. - columnFamilyOptions.setLevelZeroSlowdownWritesTrigger(30); - // Maximum number of level-0 files. We stop writes at this point. - columnFamilyOptions.setLevelZeroStopWritesTrigger(40); - - BlockBasedTableConfig tableConfig = buildBlockBasedTableConfig(); - tableConfig.setBlockSize(blockSize); - tableConfig.setBlockCacheSize(blockCacheSize); - columnFamilyOptions.setTableFormatConfig(tableConfig); - columnFamilyOptions.setMaxWriteBufferNumber(2); - - return columnFamilyOptions; + protected Statistics statistics; + + protected Options options = new Options(); + + protected DBOptions dbOptions = new DBOptions(); + + protected WriteOptions writeOptions = new WriteOptions(); + + protected ReadOptions readOptions = new ReadOptions(); + + protected FlushOptions flushOptions = new FlushOptions(); + + protected boolean closed; + private int maxWriteBufferNumber; + private long blockSize; + private long blockCacheSize; + private long writeBufferSize; + private long targetFileSize; + + public DefaultGraphOptions() {} + + protected void initOption() { + options.setUseDirectIoForFlushAndCompaction(true); + options.setCreateIfMissing(true); + options.setCreateMissingColumnFamilies(true); + options.setMergeOperator(new StringAppendOperator()); + options.setTableFormatConfig(buildBlockBasedTableConfig()); + options.setMaxWriteBufferNumber(maxWriteBufferNumber); + // Amount of data to build up in memory (backed by an unsorted log + // on disk) before converting to a sorted on-disk file. 64MB default + options.setWriteBufferSize(writeBufferSize); + // target_file_size_base is per-file size for level-1. + options.setTargetFileSizeBase(targetFileSize); + options.setMaxBackgroundFlushes(2); + options.setMaxBackgroundCompactions(2); + options.setLevelZeroFileNumCompactionTrigger(20); + // Soft limit on number of level-0 files. + options.setLevelZeroSlowdownWritesTrigger(30); + // Maximum number of level-0 files. We stop writes at this point. + options.setLevelZeroStopWritesTrigger(40); + options.setNumLevels(4); + options.setMaxManifestFileSize(50 * SizeUnit.KB); + + dbOptions.setCreateIfMissing(true); + dbOptions.setCreateMissingColumnFamilies(true); + dbOptions.setMaxBackgroundFlushes(2); + dbOptions.setMaxBackgroundCompactions(6); + dbOptions.setMaxManifestFileSize(50 * SizeUnit.KB); + + writeOptions.setDisableWAL(true); + flushOptions.setWaitForFlush(true); + } + + @Override + public void init(Configuration config) { + this.maxWriteBufferNumber = + config.getInteger(RocksdbConfigKeys.ROCKSDB_OPTIONS_MAX_WRITER_BUFFER_NUM); + this.writeBufferSize = config.getLong(RocksdbConfigKeys.ROCKSDB_OPTIONS_WRITER_BUFFER_SIZE); + this.targetFileSize = config.getLong(RocksdbConfigKeys.ROCKSDB_OPTIONS_TARGET_FILE_SIZE); + this.blockSize = config.getLong(RocksdbConfigKeys.ROCKSDB_OPTIONS_TABLE_BLOCK_SIZE); + this.blockCacheSize = config.getLong(RocksdbConfigKeys.ROCKSDB_OPTIONS_TABLE_BLOCK_CACHE_SIZE); + + initOption(); + } + + public BlockBasedTableConfig buildBlockBasedTableConfig() { + BlockBasedTableConfig tableOptions = new BlockBasedTableConfig(); + tableOptions.setBlockSize(blockSize); + tableOptions.setBlockCacheSize(blockCacheSize); + tableOptions.setFilter(new BloomFilter(10, false)); + tableOptions.setCacheIndexAndFilterBlocks(true); + tableOptions.setPinL0FilterAndIndexBlocksInCache(true); + return tableOptions; + } + + public Options getOptions() { + return options; + } + + public WriteOptions getWriteOptions() { + return writeOptions; + } + + public ReadOptions getReadOptions() { + return readOptions; + } + + public FlushOptions getFlushOptions() { + return flushOptions; + } + + @Override + public DBOptions getDbOptions() { + return dbOptions; + } + + @Override + public ColumnFamilyOptions buildFamilyOptions() { + ColumnFamilyOptions columnFamilyOptions = new ColumnFamilyOptions(); + columnFamilyOptions.setWriteBufferSize(writeBufferSize); + // target_file_size_base is per-file size for level-1. + columnFamilyOptions.setTargetFileSizeBase(targetFileSize); + columnFamilyOptions.setLevelZeroFileNumCompactionTrigger(20); + // Soft limit on number of level-0 files. + columnFamilyOptions.setLevelZeroSlowdownWritesTrigger(30); + // Maximum number of level-0 files. We stop writes at this point. + columnFamilyOptions.setLevelZeroStopWritesTrigger(40); + + BlockBasedTableConfig tableConfig = buildBlockBasedTableConfig(); + tableConfig.setBlockSize(blockSize); + tableConfig.setBlockCacheSize(blockCacheSize); + columnFamilyOptions.setTableFormatConfig(tableConfig); + columnFamilyOptions.setMaxWriteBufferNumber(2); + + return columnFamilyOptions; + } + + @Override + public Statistics getStatistics() { + return statistics; + } + + @Override + public void enableParanoidCheck() { + options.setParanoidChecks(true); + } + + @Override + public void enableStatistics() { + statistics = new Statistics(); + statistics.setStatsLevel(StatsLevel.ALL); + options.setStatistics(this.statistics); + } + + public void close() { + this.options.close(); + if (statistics != null) { + statistics.close(); } + this.closed = true; + } - @Override - public Statistics getStatistics() { - return statistics; - } - - @Override - public void enableParanoidCheck() { - options.setParanoidChecks(true); - } - - @Override - public void enableStatistics() { - statistics = new Statistics(); - statistics.setStatsLevel(StatsLevel.ALL); - options.setStatistics(this.statistics); - } - - public void close() { - this.options.close(); - if (statistics != null) { - statistics.close(); - } - this.closed = true; - } - - public boolean isClosed() { - return closed; - } + public boolean isClosed() { + return closed; + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/options/IRocksDBOptions.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/options/IRocksDBOptions.java index 9c2d18cf6..642bd68b3 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/options/IRocksDBOptions.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/options/IRocksDBOptions.java @@ -30,28 +30,27 @@ public interface IRocksDBOptions { - void init(Configuration config); + void init(Configuration config); - Options getOptions(); + Options getOptions(); - WriteOptions getWriteOptions(); + WriteOptions getWriteOptions(); - ReadOptions getReadOptions(); + ReadOptions getReadOptions(); - FlushOptions getFlushOptions(); + FlushOptions getFlushOptions(); - DBOptions getDbOptions(); + DBOptions getDbOptions(); - ColumnFamilyOptions buildFamilyOptions(); + ColumnFamilyOptions buildFamilyOptions(); - Statistics getStatistics(); + Statistics getStatistics(); - void enableParanoidCheck(); + void enableParanoidCheck(); - void enableStatistics(); + void enableStatistics(); - void close(); - - boolean isClosed(); + void close(); + boolean isClosed(); } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/AsyncGraphMultiVersionedProxy.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/AsyncGraphMultiVersionedProxy.java index 8089d5ba2..fe2066608 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/AsyncGraphMultiVersionedProxy.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/AsyncGraphMultiVersionedProxy.java @@ -28,6 +28,7 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.serialize.SerializerFactory; @@ -43,89 +44,91 @@ import org.rocksdb.ColumnFamilyHandle; import org.rocksdb.WriteBatch; -public class AsyncGraphMultiVersionedProxy extends SyncGraphMultiVersionedProxy { +public class AsyncGraphMultiVersionedProxy + extends SyncGraphMultiVersionedProxy { + + private final AsyncFlushMultiVersionedBuffer flushBuffer; - private final AsyncFlushMultiVersionedBuffer flushBuffer; + public AsyncGraphMultiVersionedProxy( + RocksdbClient rocksdbStore, IGraphKVEncoder encoder, Configuration config) { + super(rocksdbStore, encoder, config); + this.flushBuffer = + new AsyncFlushMultiVersionedBuffer<>( + config, this::flush, SerializerFactory.getKryoSerializer()); + } - public AsyncGraphMultiVersionedProxy(RocksdbClient rocksdbStore, - IGraphKVEncoder encoder, - Configuration config) { - super(rocksdbStore, encoder, config); - this.flushBuffer = new AsyncFlushMultiVersionedBuffer<>(config, this::flush, SerializerFactory.getKryoSerializer()); + private void flush(GraphWriteMultiVersionedBuffer graphWriteBuffer) { + if (graphWriteBuffer.getSize() == 0) { + return; } + ColumnFamilyHandle vertexCF = this.rocksdbClient.getColumnFamilyHandleMap().get(VERTEX_CF); + ColumnFamilyHandle indexCF = this.rocksdbClient.getColumnFamilyHandleMap().get(VERTEX_INDEX_CF); + ColumnFamilyHandle edgeCF = this.rocksdbClient.getColumnFamilyHandleMap().get(EDGE_CF); + WriteBatch writeBatch = new WriteBatch(); + try { + for (Entry>> entry : + graphWriteBuffer.getVertexId2Vertex().entrySet()) { - private void flush(GraphWriteMultiVersionedBuffer graphWriteBuffer) { - if (graphWriteBuffer.getSize() == 0) { - return; + for (Entry> innerEntry : entry.getValue().entrySet()) { + Tuple tuple = vertexEncoder.format(innerEntry.getValue()); + byte[] bVersion = getBinaryVersion(innerEntry.getKey()); + writeBatch.put(vertexCF, concat(bVersion, tuple.f0), tuple.f1); + writeBatch.put(indexCF, concat(tuple.f0, bVersion), EMPTY_BYTES); } - ColumnFamilyHandle vertexCF = this.rocksdbClient.getColumnFamilyHandleMap().get(VERTEX_CF); - ColumnFamilyHandle indexCF = this.rocksdbClient.getColumnFamilyHandleMap().get(VERTEX_INDEX_CF); - ColumnFamilyHandle edgeCF = this.rocksdbClient.getColumnFamilyHandleMap().get(EDGE_CF); - WriteBatch writeBatch = new WriteBatch(); - try { - for (Entry>> entry : - graphWriteBuffer.getVertexId2Vertex().entrySet()) { - - for (Entry> innerEntry : entry.getValue().entrySet()) { - Tuple tuple = vertexEncoder.format(innerEntry.getValue()); - byte[] bVersion = getBinaryVersion(innerEntry.getKey()); - writeBatch.put(vertexCF, concat(bVersion, tuple.f0), tuple.f1); - writeBatch.put(indexCF, concat(tuple.f0, bVersion), EMPTY_BYTES); - } - } - for (Entry>>> entry : - graphWriteBuffer.getVertexId2Edges().entrySet()) { - for (Entry>> innerEntry : entry.getValue().entrySet()) { - byte[] bVersion = getBinaryVersion(innerEntry.getKey()); - for (IEdge c : innerEntry.getValue()) { - Tuple tuple = edgeEncoder.format(c); - writeBatch.put(edgeCF, concat(bVersion, tuple.f0), tuple.f1); - } - } - } - } catch (Exception ex) { - throw new GeaflowRuntimeException(ex); + } + for (Entry>>> entry : + graphWriteBuffer.getVertexId2Edges().entrySet()) { + for (Entry>> innerEntry : entry.getValue().entrySet()) { + byte[] bVersion = getBinaryVersion(innerEntry.getKey()); + for (IEdge c : innerEntry.getValue()) { + Tuple tuple = edgeEncoder.format(c); + writeBatch.put(edgeCF, concat(bVersion, tuple.f0), tuple.f1); + } } - this.rocksdbClient.write(writeBatch); - writeBatch.clear(); - writeBatch.close(); + } + } catch (Exception ex) { + throw new GeaflowRuntimeException(ex); } + this.rocksdbClient.write(writeBatch); + writeBatch.clear(); + writeBatch.close(); + } - @Override - public void addVertex(long version, IVertex vertex) { - this.flushBuffer.addVertex(version, vertex); - } + @Override + public void addVertex(long version, IVertex vertex) { + this.flushBuffer.addVertex(version, vertex); + } - @Override - public void addEdge(long version, IEdge edge) { - this.flushBuffer.addEdge(version, edge); - } + @Override + public void addEdge(long version, IEdge edge) { + this.flushBuffer.addEdge(version, edge); + } - @Override - public IVertex getVertex(long version, K sid, IStatePushDown pushdown) { - IVertex vertex = this.flushBuffer.readBufferedVertex(version, sid); - if (vertex != null) { - return ((IGraphFilter) pushdown.getFilter()).filterVertex(vertex) ? vertex : null; - } - return super.getVertex(version, sid, pushdown); + @Override + public IVertex getVertex(long version, K sid, IStatePushDown pushdown) { + IVertex vertex = this.flushBuffer.readBufferedVertex(version, sid); + if (vertex != null) { + return ((IGraphFilter) pushdown.getFilter()).filterVertex(vertex) ? vertex : null; } + return super.getVertex(version, sid, pushdown); + } - @Override - public List> getEdges(long version, K sid, IStatePushDown pushdown) { - List> list = this.flushBuffer.readBufferedEdges(version, sid); - LinkedHashSet> set = new LinkedHashSet<>(); - list.stream().filter(((IGraphFilter) pushdown.getFilter())::filterEdge).forEach(set::add); - set.addAll(super.getEdges(version, sid, pushdown)); - return new ArrayList<>(set); - } + @Override + public List> getEdges(long version, K sid, IStatePushDown pushdown) { + List> list = this.flushBuffer.readBufferedEdges(version, sid); + LinkedHashSet> set = new LinkedHashSet<>(); + list.stream().filter(((IGraphFilter) pushdown.getFilter())::filterEdge).forEach(set::add); + set.addAll(super.getEdges(version, sid, pushdown)); + return new ArrayList<>(set); + } - @Override - public void flush() { - flushBuffer.flush(); - } + @Override + public void flush() { + flushBuffer.flush(); + } - @Override - public void close() { - flushBuffer.close(); - } + @Override + public void close() { + flushBuffer.close(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/AsyncGraphRocksdbProxy.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/AsyncGraphRocksdbProxy.java index c7b327b2a..3275554a5 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/AsyncGraphRocksdbProxy.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/AsyncGraphRocksdbProxy.java @@ -22,11 +22,11 @@ import static org.apache.geaflow.store.rocksdb.RocksdbConfigKeys.EDGE_CF; import static org.apache.geaflow.store.rocksdb.RocksdbConfigKeys.VERTEX_CF; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.LinkedHashSet; import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.serialize.SerializerFactory; import org.apache.geaflow.common.tuple.Tuple; @@ -40,73 +40,77 @@ import org.apache.geaflow.store.data.GraphWriteBuffer; import org.apache.geaflow.store.rocksdb.RocksdbClient; -public class AsyncGraphRocksdbProxy extends SyncGraphRocksdbProxy { - - private final AsyncFlushBuffer flushBuffer; +import com.google.common.collect.Lists; - public AsyncGraphRocksdbProxy(RocksdbClient rocksdbClient, - IGraphKVEncoder encoder, - Configuration config) { - super(rocksdbClient, encoder, config); - this.flushBuffer = new AsyncFlushBuffer<>(config, this::flush, SerializerFactory.getKryoSerializer()); - } +public class AsyncGraphRocksdbProxy extends SyncGraphRocksdbProxy { - private void flush(GraphWriteBuffer graphWriteBuffer) { - if (graphWriteBuffer.getSize() == 0) { - return; - } + private final AsyncFlushBuffer flushBuffer; - List> list = graphWriteBuffer.getVertexId2Vertex().values() - .stream().map(v -> vertexEncoder.format(v)).collect(Collectors.toList()); - rocksdbClient.write(VERTEX_CF, list); + public AsyncGraphRocksdbProxy( + RocksdbClient rocksdbClient, IGraphKVEncoder encoder, Configuration config) { + super(rocksdbClient, encoder, config); + this.flushBuffer = + new AsyncFlushBuffer<>(config, this::flush, SerializerFactory.getKryoSerializer()); + } - list.clear(); - for (List> edges : graphWriteBuffer.getVertexId2Edges().values()) { - edges.forEach(e -> list.add(edgeEncoder.format(e))); - } - rocksdbClient.write(EDGE_CF, list); + private void flush(GraphWriteBuffer graphWriteBuffer) { + if (graphWriteBuffer.getSize() == 0) { + return; } - @Override - public void addVertex(IVertex vertex) { - this.flushBuffer.addVertex(vertex); - } + List> list = + graphWriteBuffer.getVertexId2Vertex().values().stream() + .map(v -> vertexEncoder.format(v)) + .collect(Collectors.toList()); + rocksdbClient.write(VERTEX_CF, list); - @Override - public void addEdge(IEdge edge) { - this.flushBuffer.addEdge(edge); + list.clear(); + for (List> edges : graphWriteBuffer.getVertexId2Edges().values()) { + edges.forEach(e -> list.add(edgeEncoder.format(e))); } - - @Override - public IVertex getVertex(K sid, IStatePushDown pushdown) { - IVertex vertex = this.flushBuffer.readBufferedVertex(sid); - if (vertex != null) { - return ((IGraphFilter) pushdown.getFilter()).filterVertex(vertex) ? vertex : null; - } - return super.getVertex(sid, pushdown); + rocksdbClient.write(EDGE_CF, list); + } + + @Override + public void addVertex(IVertex vertex) { + this.flushBuffer.addVertex(vertex); + } + + @Override + public void addEdge(IEdge edge) { + this.flushBuffer.addEdge(edge); + } + + @Override + public IVertex getVertex(K sid, IStatePushDown pushdown) { + IVertex vertex = this.flushBuffer.readBufferedVertex(sid); + if (vertex != null) { + return ((IGraphFilter) pushdown.getFilter()).filterVertex(vertex) ? vertex : null; } - - @Override - public List> getEdges(K sid, IStatePushDown pushdown) { - List> list = this.flushBuffer.readBufferedEdges(sid); - LinkedHashSet> set = new LinkedHashSet<>(); - - IGraphFilter filter = GraphFilter.of(pushdown.getFilter(), pushdown.getEdgeLimit()); - Lists.reverse(list).stream().filter(filter::filterEdge).forEach(set::add); - if (!filter.dropAllRemaining()) { - set.addAll(super.getEdges(sid, filter)); - } - - return new ArrayList<>(set); + return super.getVertex(sid, pushdown); + } + + @Override + public List> getEdges(K sid, IStatePushDown pushdown) { + List> list = this.flushBuffer.readBufferedEdges(sid); + LinkedHashSet> set = new LinkedHashSet<>(); + + IGraphFilter filter = GraphFilter.of(pushdown.getFilter(), pushdown.getEdgeLimit()); + Lists.reverse(list).stream().filter(filter::filterEdge).forEach(set::add); + if (!filter.dropAllRemaining()) { + set.addAll(super.getEdges(sid, filter)); } - @Override - public void flush() { - flushBuffer.flush(); - } + return new ArrayList<>(set); + } - @Override - public void close() { - flushBuffer.close(); - } + @Override + public void flush() { + flushBuffer.flush(); + } + + @Override + public void close() { + flushBuffer.close(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/IGraphMultiVersionedRocksdbProxy.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/IGraphMultiVersionedRocksdbProxy.java index 3221fdbf4..e7f2b2835 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/IGraphMultiVersionedRocksdbProxy.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/IGraphMultiVersionedRocksdbProxy.java @@ -21,7 +21,5 @@ import org.apache.geaflow.state.graph.DynamicGraphTrait; -public interface IGraphMultiVersionedRocksdbProxy extends DynamicGraphTrait, - IRocksdbProxy { - -} +public interface IGraphMultiVersionedRocksdbProxy + extends DynamicGraphTrait, IRocksdbProxy {} diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/IGraphRocksdbProxy.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/IGraphRocksdbProxy.java index 1d5ca23d6..01788f098 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/IGraphRocksdbProxy.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/IGraphRocksdbProxy.java @@ -21,6 +21,4 @@ import org.apache.geaflow.state.graph.StaticGraphTrait; -public interface IGraphRocksdbProxy extends StaticGraphTrait, IRocksdbProxy { - -} +public interface IGraphRocksdbProxy extends StaticGraphTrait, IRocksdbProxy {} diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/IRocksdbProxy.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/IRocksdbProxy.java index 885abc0dc..cc0d6c161 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/IRocksdbProxy.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/IRocksdbProxy.java @@ -23,9 +23,9 @@ public interface IRocksdbProxy { - RocksdbClient getClient(); + RocksdbClient getClient(); - void flush(); + void flush(); - void close(); + void close(); } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/ProxyBuilder.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/ProxyBuilder.java index 0f7e786ab..791298ffa 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/ProxyBuilder.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/ProxyBuilder.java @@ -29,36 +29,36 @@ public class ProxyBuilder { - public static IGraphRocksdbProxy build( - Configuration config, RocksdbClient rocksdbClient, - IGraphKVEncoder encoder) { - PartitionType partitionType = PartitionType.getEnum( + public static IGraphRocksdbProxy build( + Configuration config, RocksdbClient rocksdbClient, IGraphKVEncoder encoder) { + PartitionType partitionType = + PartitionType.getEnum( config.getString(RocksdbConfigKeys.ROCKSDB_GRAPH_STORE_PARTITION_TYPE)); - if (partitionType.isPartition()) { - if (partitionType == PartitionType.LABEL) { - // TODO: Support async graph proxy partitioned by label - return new SyncGraphLabelPartitionProxy<>(rocksdbClient, encoder, config); - } else if (partitionType == PartitionType.DT) { - return new SyncGraphDtPartitionProxy<>(rocksdbClient, encoder, config); - } - throw new GeaflowRuntimeException("unexpected partition type: " + config.getString( - RocksdbConfigKeys.ROCKSDB_GRAPH_STORE_PARTITION_TYPE)); - } else { - if (config.getBoolean(StateConfigKeys.STATE_WRITE_ASYNC_ENABLE)) { - return new AsyncGraphRocksdbProxy<>(rocksdbClient, encoder, config); - } else { - return new SyncGraphRocksdbProxy<>(rocksdbClient, encoder, config); - } - } + if (partitionType.isPartition()) { + if (partitionType == PartitionType.LABEL) { + // TODO: Support async graph proxy partitioned by label + return new SyncGraphLabelPartitionProxy<>(rocksdbClient, encoder, config); + } else if (partitionType == PartitionType.DT) { + return new SyncGraphDtPartitionProxy<>(rocksdbClient, encoder, config); + } + throw new GeaflowRuntimeException( + "unexpected partition type: " + + config.getString(RocksdbConfigKeys.ROCKSDB_GRAPH_STORE_PARTITION_TYPE)); + } else { + if (config.getBoolean(StateConfigKeys.STATE_WRITE_ASYNC_ENABLE)) { + return new AsyncGraphRocksdbProxy<>(rocksdbClient, encoder, config); + } else { + return new SyncGraphRocksdbProxy<>(rocksdbClient, encoder, config); + } } + } - public static IGraphMultiVersionedRocksdbProxy buildMultiVersioned( - Configuration config, RocksdbClient rocksdbClient, - IGraphKVEncoder encoder) { - if (config.getBoolean(StateConfigKeys.STATE_WRITE_ASYNC_ENABLE)) { - return new AsyncGraphMultiVersionedProxy<>(rocksdbClient, encoder, config); - } else { - return new SyncGraphMultiVersionedProxy<>(rocksdbClient, encoder, config); - } + public static IGraphMultiVersionedRocksdbProxy buildMultiVersioned( + Configuration config, RocksdbClient rocksdbClient, IGraphKVEncoder encoder) { + if (config.getBoolean(StateConfigKeys.STATE_WRITE_ASYNC_ENABLE)) { + return new AsyncGraphMultiVersionedProxy<>(rocksdbClient, encoder, config); + } else { + return new SyncGraphMultiVersionedProxy<>(rocksdbClient, encoder, config); } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/SyncGraphDtPartitionProxy.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/SyncGraphDtPartitionProxy.java index 185c8459f..a0cfc2abf 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/SyncGraphDtPartitionProxy.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/SyncGraphDtPartitionProxy.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.iterator.ChainedCloseableIterator; @@ -46,245 +47,242 @@ import org.rocksdb.ColumnFamilyHandle; import org.rocksdb.RocksDBException; -/** - * GraphProxy which supported being partitioned by timestamp. - */ +/** GraphProxy which supported being partitioned by timestamp. */ public class SyncGraphDtPartitionProxy extends SyncGraphRocksdbProxy { - // column family name -> column family handle - private final Map vertexHandleMap; - private final Map edgeHandleMap; - private final Map descriptorMap; - private final IRocksDBOptions rocksDBOptions; - - - // Dt config from RocksdbConfigKeys - private final long startTimestamp; - private final long cycle; - - public SyncGraphDtPartitionProxy(RocksdbClient rocksdbClient, - IGraphKVEncoder encoder, Configuration config) { - super(rocksdbClient, encoder, config); - this.vertexHandleMap = rocksdbClient.getVertexHandleMap(); - this.edgeHandleMap = rocksdbClient.getEdgeHandleMap(); - this.descriptorMap = rocksdbClient.getDescriptorMap(); - this.rocksDBOptions = rocksdbClient.getRocksDBOptions(); - - this.startTimestamp = Long.parseLong( - config.getString(RocksdbConfigKeys.ROCKSDB_GRAPH_STORE_DT_START)); - this.cycle = Long.parseLong( - config.getString(RocksdbConfigKeys.ROCKSDB_GRAPH_STORE_DT_CYCLE)); - + // column family name -> column family handle + private final Map vertexHandleMap; + private final Map edgeHandleMap; + private final Map descriptorMap; + private final IRocksDBOptions rocksDBOptions; + + // Dt config from RocksdbConfigKeys + private final long startTimestamp; + private final long cycle; + + public SyncGraphDtPartitionProxy( + RocksdbClient rocksdbClient, IGraphKVEncoder encoder, Configuration config) { + super(rocksdbClient, encoder, config); + this.vertexHandleMap = rocksdbClient.getVertexHandleMap(); + this.edgeHandleMap = rocksdbClient.getEdgeHandleMap(); + this.descriptorMap = rocksdbClient.getDescriptorMap(); + this.rocksDBOptions = rocksdbClient.getRocksDBOptions(); + + this.startTimestamp = + Long.parseLong(config.getString(RocksdbConfigKeys.ROCKSDB_GRAPH_STORE_DT_START)); + this.cycle = Long.parseLong(config.getString(RocksdbConfigKeys.ROCKSDB_GRAPH_STORE_DT_CYCLE)); + } + + public long getDt(long timestamp) { + if (timestamp < startTimestamp) { + throw new GeaflowRuntimeException( + "timestamp " + timestamp + " is less than start " + startTimestamp); } + long offset = (timestamp - startTimestamp) / cycle; + return startTimestamp + offset * cycle; + } - public long getDt(long timestamp) { - if (timestamp < startTimestamp) { - throw new GeaflowRuntimeException( - "timestamp " + timestamp + " is less than start " + startTimestamp); - } - long offset = (timestamp - startTimestamp) / cycle; - return startTimestamp + offset * cycle; - } - - private String getColumnFamilyNameByTimeStamp(long timestamp, boolean isVertex) { - return (isVertex ? RocksdbConfigKeys.VERTEX_CF_PREFIX : RocksdbConfigKeys.EDGE_CF_PREFIX) - + getDt(timestamp); - } + private String getColumnFamilyNameByTimeStamp(long timestamp, boolean isVertex) { + return (isVertex ? RocksdbConfigKeys.VERTEX_CF_PREFIX : RocksdbConfigKeys.EDGE_CF_PREFIX) + + getDt(timestamp); + } - private List getColumnFamilyNameListByTimeRange(TimeRange range, boolean isVertex) { - long start = range.getStart(); - long end = range.getEnd(); + private List getColumnFamilyNameListByTimeRange(TimeRange range, boolean isVertex) { + long start = range.getStart(); + long end = range.getEnd(); - if (end < start) { - throw new GeaflowRuntimeException( - "timestamp end " + end + " is less than start " + start); - } - if (end < startTimestamp) { - // 全部在start之前 - return new ArrayList<>(); - } - - long first = Math.max(start, startTimestamp); - long firstDt = getDt(first); - List result = new ArrayList<>(); - for (long ts = firstDt; ts <= end; ts += cycle) { - result.add( - (isVertex ? RocksdbConfigKeys.VERTEX_CF_PREFIX : RocksdbConfigKeys.EDGE_CF_PREFIX) - + ts); - } - - return result; + if (end < start) { + throw new GeaflowRuntimeException("timestamp end " + end + " is less than start " + start); } - - private ColumnFamilyHandle tryToGetOrCreateColumnFamilyHandle( - Map handleMap, IGraphElementWithTimeField element, - boolean isVertex) { - String cfName = getColumnFamilyNameByTimeStamp(element.getTime(), isVertex); - - return handleMap.computeIfAbsent(cfName, key -> { - // Create ColumnFamilyDescriptor - ColumnFamilyDescriptor descriptor = new ColumnFamilyDescriptor(cfName.getBytes(), - rocksDBOptions.buildFamilyOptions()); - - descriptorMap.put(cfName, descriptor); - - // Create ColumnFamilyHandle - try { - return rocksdbClient.getRocksdb().createColumnFamily(descriptor); - } catch (RocksDBException e) { - throw new GeaflowRuntimeException("Create column family " + cfName + " fail", e); - } - }); + if (end < startTimestamp) { + // 全部在start之前 + return new ArrayList<>(); } - @Override - public void addVertex(IVertex vertex) { - Tuple tuple = vertexEncoder.format(vertex); - // Get the timestamp of the vertex - // Create a new column family for a partitioned-dt never written - ColumnFamilyHandle handle = tryToGetOrCreateColumnFamilyHandle(vertexHandleMap, - (IGraphElementWithTimeField) vertex, true); - - this.rocksdbClient.write(handle, tuple.f0, tuple.f1); + long first = Math.max(start, startTimestamp); + long firstDt = getDt(first); + List result = new ArrayList<>(); + for (long ts = firstDt; ts <= end; ts += cycle) { + result.add( + (isVertex ? RocksdbConfigKeys.VERTEX_CF_PREFIX : RocksdbConfigKeys.EDGE_CF_PREFIX) + ts); } - @Override - public void addEdge(IEdge edge) { - Tuple tuple = edgeEncoder.format(edge); - // Get the timestamp of the edge - // Create a new column family for a partitioned-dt never written - ColumnFamilyHandle handle = tryToGetOrCreateColumnFamilyHandle(edgeHandleMap, - (IGraphElementWithTimeField) edge, false); - - this.rocksdbClient.write(handle, tuple.f0, tuple.f1); + return result; + } + + private ColumnFamilyHandle tryToGetOrCreateColumnFamilyHandle( + Map handleMap, + IGraphElementWithTimeField element, + boolean isVertex) { + String cfName = getColumnFamilyNameByTimeStamp(element.getTime(), isVertex); + + return handleMap.computeIfAbsent( + cfName, + key -> { + // Create ColumnFamilyDescriptor + ColumnFamilyDescriptor descriptor = + new ColumnFamilyDescriptor(cfName.getBytes(), rocksDBOptions.buildFamilyOptions()); + + descriptorMap.put(cfName, descriptor); + + // Create ColumnFamilyHandle + try { + return rocksdbClient.getRocksdb().createColumnFamily(descriptor); + } catch (RocksDBException e) { + throw new GeaflowRuntimeException("Create column family " + cfName + " fail", e); + } + }); + } + + @Override + public void addVertex(IVertex vertex) { + Tuple tuple = vertexEncoder.format(vertex); + // Get the timestamp of the vertex + // Create a new column family for a partitioned-dt never written + ColumnFamilyHandle handle = + tryToGetOrCreateColumnFamilyHandle( + vertexHandleMap, (IGraphElementWithTimeField) vertex, true); + + this.rocksdbClient.write(handle, tuple.f0, tuple.f1); + } + + @Override + public void addEdge(IEdge edge) { + Tuple tuple = edgeEncoder.format(edge); + // Get the timestamp of the edge + // Create a new column family for a partitioned-dt never written + ColumnFamilyHandle handle = + tryToGetOrCreateColumnFamilyHandle(edgeHandleMap, (IGraphElementWithTimeField) edge, false); + + this.rocksdbClient.write(handle, tuple.f0, tuple.f1); + } + + @Override + public IVertex getVertex(K sid, IStatePushDown pushdown) { + byte[] key = encoder.getKeyType().serialize(sid); + IGraphFilter filter = null; + TimeRange range = null; + + if (pushdown != null) { + filter = (IGraphFilter) pushdown.getFilter(); + range = FilterHelper.parseDt(filter, true); } - @Override - public IVertex getVertex(K sid, IStatePushDown pushdown) { - byte[] key = encoder.getKeyType().serialize(sid); - IGraphFilter filter = null; - TimeRange range = null; - - if (pushdown != null) { - filter = (IGraphFilter) pushdown.getFilter(); - range = FilterHelper.parseDt(filter, true); + if (range == null) { + for (ColumnFamilyHandle handle : vertexHandleMap.values()) { + byte[] value = this.rocksdbClient.get(handle, key); + if (value != null) { + IVertex vertex = vertexEncoder.getVertex(key, value); + if (filter == null || filter.filterVertex(vertex)) { + return vertex; + } } - - if (range == null) { - for (ColumnFamilyHandle handle : vertexHandleMap.values()) { - byte[] value = this.rocksdbClient.get(handle, key); - if (value != null) { - IVertex vertex = vertexEncoder.getVertex(key, value); - if (filter == null || filter.filterVertex(vertex)) { - return vertex; - } - } - } - } else { - List cfNames = getColumnFamilyNameListByTimeRange(range, true); - for (String cfName : cfNames) { - if (vertexHandleMap.containsKey(cfName)) { - byte[] value = this.rocksdbClient.get(vertexHandleMap.get(cfName), key); - if (value != null) { - IVertex vertex = vertexEncoder.getVertex(key, value); - if (filter.filterVertex(vertex)) { - return vertex; - } - } - } + } + } else { + List cfNames = getColumnFamilyNameListByTimeRange(range, true); + for (String cfName : cfNames) { + if (vertexHandleMap.containsKey(cfName)) { + byte[] value = this.rocksdbClient.get(vertexHandleMap.get(cfName), key); + if (value != null) { + IVertex vertex = vertexEncoder.getVertex(key, value); + if (filter.filterVertex(vertex)) { + return vertex; } + } } - - return null; + } } - @Override - public List> getEdges(K sid, IStatePushDown pushdown) { - IGraphFilter filter = GraphFilter.of(pushdown.getFilter(), pushdown.getEdgeLimit()); - return getEdges(sid, filter); - } - - protected List> getEdges(K sid, IGraphFilter filter) { - List> list = new ArrayList<>(); - TimeRange range = FilterHelper.parseDt(filter, true); - byte[] prefix = edgeEncoder.getScanBytes(sid); - - if (range == null) { - for (ColumnFamilyHandle handle : edgeHandleMap.values()) { - getEdgesFromSingleColumnFamily(handle, prefix, filter, list); - } - } else { - List cfNames = getColumnFamilyNameListByTimeRange(range, false); - - for (String cfName : cfNames) { - if (edgeHandleMap.containsKey(cfName)) { - getEdgesFromSingleColumnFamily(edgeHandleMap.get(cfName), prefix, filter, list); - } - } + return null; + } + + @Override + public List> getEdges(K sid, IStatePushDown pushdown) { + IGraphFilter filter = GraphFilter.of(pushdown.getFilter(), pushdown.getEdgeLimit()); + return getEdges(sid, filter); + } + + protected List> getEdges(K sid, IGraphFilter filter) { + List> list = new ArrayList<>(); + TimeRange range = FilterHelper.parseDt(filter, true); + byte[] prefix = edgeEncoder.getScanBytes(sid); + + if (range == null) { + for (ColumnFamilyHandle handle : edgeHandleMap.values()) { + getEdgesFromSingleColumnFamily(handle, prefix, filter, list); + } + } else { + List cfNames = getColumnFamilyNameListByTimeRange(range, false); + + for (String cfName : cfNames) { + if (edgeHandleMap.containsKey(cfName)) { + getEdgesFromSingleColumnFamily(edgeHandleMap.get(cfName), prefix, filter, list); } - - return list; + } } - protected void getEdgesFromSingleColumnFamily(ColumnFamilyHandle handle, byte[] prefix, - IGraphFilter filter, List> list) { - try (RocksdbIterator it = new RocksdbIterator(this.rocksdbClient.getIterator(handle), - prefix)) { - getEdgesFromRocksDBIterator(list, it, filter); - } - } + return list; + } - @Override - public CloseableIterator vertexIDIterator() { - flush(); - - List iterList = new ArrayList<>(); - for (ColumnFamilyHandle handle : vertexHandleMap.values()) { - iterList.add(new RocksdbIterator(this.rocksdbClient.getIterator(handle))); - } - - CloseableIterator> it = new ChainedCloseableIterator(iterList); - - return buildVertexIDIteratorFromRocksDBIter(it); + protected void getEdgesFromSingleColumnFamily( + ColumnFamilyHandle handle, byte[] prefix, IGraphFilter filter, List> list) { + try (RocksdbIterator it = new RocksdbIterator(this.rocksdbClient.getIterator(handle), prefix)) { + getEdgesFromRocksDBIterator(list, it, filter); } + } - @Override - public CloseableIterator> getVertexIterator(IStatePushDown pushdown) { - flush(); - - return new VertexScanIterator<>(getVertexOrEdgeIterator(vertexHandleMap, pushdown, true), - pushdown, vertexEncoder::getVertex); - } + @Override + public CloseableIterator vertexIDIterator() { + flush(); - @Override - public CloseableIterator> getEdgeIterator(IStatePushDown pushdown) { - flush(); - return new EdgeScanIterator<>(getVertexOrEdgeIterator(edgeHandleMap, pushdown, false), - pushdown, edgeEncoder::getEdge); + List iterList = new ArrayList<>(); + for (ColumnFamilyHandle handle : vertexHandleMap.values()) { + iterList.add(new RocksdbIterator(this.rocksdbClient.getIterator(handle))); } - private CloseableIterator> getVertexOrEdgeIterator( - Map handleMap, IStatePushDown pushdown, boolean isVertex) { - - IGraphFilter filter = (IGraphFilter) pushdown.getFilter(); - TimeRange range = FilterHelper.parseDt(filter, isVertex); - List iterList = new ArrayList<>(); - - if (range == null) { - for (ColumnFamilyHandle handle : handleMap.values()) { - iterList.add(new RocksdbIterator(this.rocksdbClient.getIterator(handle))); - } - } else { - List cfNames = getColumnFamilyNameListByTimeRange(range, isVertex); - for (String cfName : cfNames) { - - if (handleMap.containsKey(cfName)) { - iterList.add( - new RocksdbIterator(this.rocksdbClient.getIterator(handleMap.get(cfName)))); - } - } + CloseableIterator> it = new ChainedCloseableIterator(iterList); + + return buildVertexIDIteratorFromRocksDBIter(it); + } + + @Override + public CloseableIterator> getVertexIterator(IStatePushDown pushdown) { + flush(); + + return new VertexScanIterator<>( + getVertexOrEdgeIterator(vertexHandleMap, pushdown, true), + pushdown, + vertexEncoder::getVertex); + } + + @Override + public CloseableIterator> getEdgeIterator(IStatePushDown pushdown) { + flush(); + return new EdgeScanIterator<>( + getVertexOrEdgeIterator(edgeHandleMap, pushdown, false), pushdown, edgeEncoder::getEdge); + } + + private CloseableIterator> getVertexOrEdgeIterator( + Map handleMap, IStatePushDown pushdown, boolean isVertex) { + + IGraphFilter filter = (IGraphFilter) pushdown.getFilter(); + TimeRange range = FilterHelper.parseDt(filter, isVertex); + List iterList = new ArrayList<>(); + + if (range == null) { + for (ColumnFamilyHandle handle : handleMap.values()) { + iterList.add(new RocksdbIterator(this.rocksdbClient.getIterator(handle))); + } + } else { + List cfNames = getColumnFamilyNameListByTimeRange(range, isVertex); + for (String cfName : cfNames) { + + if (handleMap.containsKey(cfName)) { + iterList.add(new RocksdbIterator(this.rocksdbClient.getIterator(handleMap.get(cfName)))); } - - return new ChainedCloseableIterator(iterList); + } } + + return new ChainedCloseableIterator(iterList); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/SyncGraphLabelPartitionProxy.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/SyncGraphLabelPartitionProxy.java index 60cf9f204..b45fa94df 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/SyncGraphLabelPartitionProxy.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/SyncGraphLabelPartitionProxy.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.iterator.ChainedCloseableIterator; @@ -45,206 +46,207 @@ import org.rocksdb.ColumnFamilyHandle; import org.rocksdb.RocksDBException; -/** - * GraphProxy which supported being partitioned by label. - */ +/** GraphProxy which supported being partitioned by label. */ public class SyncGraphLabelPartitionProxy extends SyncGraphRocksdbProxy { - // column family name -> column family handle - private final Map vertexHandleMap; - private final Map edgeHandleMap; - private final Map descriptorMap; - private final IRocksDBOptions rocksDBOptions; - - public SyncGraphLabelPartitionProxy(RocksdbClient rocksdbClient, - IGraphKVEncoder encoder, Configuration config) { - super(rocksdbClient, encoder, config); - this.vertexHandleMap = rocksdbClient.getVertexHandleMap(); - this.edgeHandleMap = rocksdbClient.getEdgeHandleMap(); - - this.descriptorMap = rocksdbClient.getDescriptorMap(); - this.rocksDBOptions = rocksdbClient.getRocksDBOptions(); - } - - - private String getColumnFamilyName(String label, boolean isVertex) { - return (isVertex ? RocksdbConfigKeys.VERTEX_CF_PREFIX : RocksdbConfigKeys.EDGE_CF_PREFIX) - + label; - } - - private ColumnFamilyHandle tryToGetOrCreateColumnFamilyHandle( - Map handleMap, IGraphElementWithLabelField element, - boolean isVertex) { - String cfName = getColumnFamilyName(element.getLabel(), isVertex); - - return handleMap.computeIfAbsent(cfName, key -> { - // Create ColumnFamilyDescriptor - ColumnFamilyDescriptor descriptor = new ColumnFamilyDescriptor(cfName.getBytes(), - rocksDBOptions.buildFamilyOptions()); - - descriptorMap.put(cfName, descriptor); - - // Create ColumnFamilyHandle - try { - return rocksdbClient.getRocksdb().createColumnFamily(descriptor); - } catch (RocksDBException e) { - throw new GeaflowRuntimeException("Create column family " + cfName + " fail", e); - } + // column family name -> column family handle + private final Map vertexHandleMap; + private final Map edgeHandleMap; + private final Map descriptorMap; + private final IRocksDBOptions rocksDBOptions; + + public SyncGraphLabelPartitionProxy( + RocksdbClient rocksdbClient, IGraphKVEncoder encoder, Configuration config) { + super(rocksdbClient, encoder, config); + this.vertexHandleMap = rocksdbClient.getVertexHandleMap(); + this.edgeHandleMap = rocksdbClient.getEdgeHandleMap(); + + this.descriptorMap = rocksdbClient.getDescriptorMap(); + this.rocksDBOptions = rocksdbClient.getRocksDBOptions(); + } + + private String getColumnFamilyName(String label, boolean isVertex) { + return (isVertex ? RocksdbConfigKeys.VERTEX_CF_PREFIX : RocksdbConfigKeys.EDGE_CF_PREFIX) + + label; + } + + private ColumnFamilyHandle tryToGetOrCreateColumnFamilyHandle( + Map handleMap, + IGraphElementWithLabelField element, + boolean isVertex) { + String cfName = getColumnFamilyName(element.getLabel(), isVertex); + + return handleMap.computeIfAbsent( + cfName, + key -> { + // Create ColumnFamilyDescriptor + ColumnFamilyDescriptor descriptor = + new ColumnFamilyDescriptor(cfName.getBytes(), rocksDBOptions.buildFamilyOptions()); + + descriptorMap.put(cfName, descriptor); + + // Create ColumnFamilyHandle + try { + return rocksdbClient.getRocksdb().createColumnFamily(descriptor); + } catch (RocksDBException e) { + throw new GeaflowRuntimeException("Create column family " + cfName + " fail", e); + } }); + } + + @Override + public void addVertex(IVertex vertex) { + // TODO: Supports partitioning vertices not by label but edges by label + Tuple tuple = vertexEncoder.format(vertex); + // Get the label of the vertex + // Create a new column family for a label never written + ColumnFamilyHandle handle = + tryToGetOrCreateColumnFamilyHandle( + vertexHandleMap, (IGraphElementWithLabelField) vertex, true); + + this.rocksdbClient.write(handle, tuple.f0, tuple.f1); + } + + @Override + public void addEdge(IEdge edge) { + Tuple tuple = edgeEncoder.format(edge); + // Get the label of the edge + // Create a new column family for a label never written + ColumnFamilyHandle handle = + tryToGetOrCreateColumnFamilyHandle( + edgeHandleMap, (IGraphElementWithLabelField) edge, false); + + this.rocksdbClient.write(handle, tuple.f0, tuple.f1); + } + + @Override + public IVertex getVertex(K sid, IStatePushDown pushdown) { + byte[] key = encoder.getKeyType().serialize(sid); + List labels = new ArrayList<>(); + IGraphFilter filter = null; + + if (pushdown != null) { + filter = (IGraphFilter) pushdown.getFilter(); + labels = FilterHelper.parseLabel(filter, true); } - @Override - public void addVertex(IVertex vertex) { - // TODO: Supports partitioning vertices not by label but edges by label - Tuple tuple = vertexEncoder.format(vertex); - // Get the label of the vertex - // Create a new column family for a label never written - ColumnFamilyHandle handle = tryToGetOrCreateColumnFamilyHandle(vertexHandleMap, - (IGraphElementWithLabelField) vertex, true); - - this.rocksdbClient.write(handle, tuple.f0, tuple.f1); - } - - @Override - public void addEdge(IEdge edge) { - Tuple tuple = edgeEncoder.format(edge); - // Get the label of the edge - // Create a new column family for a label never written - ColumnFamilyHandle handle = tryToGetOrCreateColumnFamilyHandle(edgeHandleMap, - (IGraphElementWithLabelField) edge, false); - - this.rocksdbClient.write(handle, tuple.f0, tuple.f1); - } - - @Override - public IVertex getVertex(K sid, IStatePushDown pushdown) { - byte[] key = encoder.getKeyType().serialize(sid); - List labels = new ArrayList<>(); - IGraphFilter filter = null; - - if (pushdown != null) { - filter = (IGraphFilter) pushdown.getFilter(); - labels = FilterHelper.parseLabel(filter, true); + if (labels.isEmpty()) { + for (ColumnFamilyHandle handle : vertexHandleMap.values()) { + byte[] value = this.rocksdbClient.get(handle, key); + if (value != null) { + IVertex vertex = vertexEncoder.getVertex(key, value); + if (filter == null || filter.filterVertex(vertex)) { + return vertex; + } } - - if (labels.isEmpty()) { - for (ColumnFamilyHandle handle : vertexHandleMap.values()) { - byte[] value = this.rocksdbClient.get(handle, key); - if (value != null) { - IVertex vertex = vertexEncoder.getVertex(key, value); - if (filter == null || filter.filterVertex(vertex)) { - return vertex; - } - } - } - } else { - for (String label : labels) { - String cfName = getColumnFamilyName(label, true); - - if (vertexHandleMap.containsKey(cfName)) { - byte[] value = this.rocksdbClient.get(vertexHandleMap.get(cfName), key); - if (value != null) { - IVertex vertex = vertexEncoder.getVertex(key, value); - if (filter.filterVertex(vertex)) { - return vertex; - } - } - } + } + } else { + for (String label : labels) { + String cfName = getColumnFamilyName(label, true); + + if (vertexHandleMap.containsKey(cfName)) { + byte[] value = this.rocksdbClient.get(vertexHandleMap.get(cfName), key); + if (value != null) { + IVertex vertex = vertexEncoder.getVertex(key, value); + if (filter.filterVertex(vertex)) { + return vertex; } + } } - - return null; - } - - @Override - public List> getEdges(K sid, IStatePushDown pushdown) { - IGraphFilter filter = GraphFilter.of(pushdown.getFilter(), pushdown.getEdgeLimit()); - return getEdges(sid, filter); + } } - protected List> getEdges(K sid, IGraphFilter filter) { - List> list = new ArrayList<>(); - List labels = FilterHelper.parseLabel(filter, false); - byte[] prefix = edgeEncoder.getScanBytes(sid); - - // TODO: Multi-thread get edges - if (labels.isEmpty()) { - for (ColumnFamilyHandle handle : edgeHandleMap.values()) { - getEdgesFromSingleColumnFamily(handle, prefix, filter, list); - } - } else { - for (String label : labels) { - String cfName = getColumnFamilyName(label, false); - - if (edgeHandleMap.containsKey(cfName)) { - getEdgesFromSingleColumnFamily(edgeHandleMap.get(cfName), prefix, filter, - list); - } - } + return null; + } + + @Override + public List> getEdges(K sid, IStatePushDown pushdown) { + IGraphFilter filter = GraphFilter.of(pushdown.getFilter(), pushdown.getEdgeLimit()); + return getEdges(sid, filter); + } + + protected List> getEdges(K sid, IGraphFilter filter) { + List> list = new ArrayList<>(); + List labels = FilterHelper.parseLabel(filter, false); + byte[] prefix = edgeEncoder.getScanBytes(sid); + + // TODO: Multi-thread get edges + if (labels.isEmpty()) { + for (ColumnFamilyHandle handle : edgeHandleMap.values()) { + getEdgesFromSingleColumnFamily(handle, prefix, filter, list); + } + } else { + for (String label : labels) { + String cfName = getColumnFamilyName(label, false); + + if (edgeHandleMap.containsKey(cfName)) { + getEdgesFromSingleColumnFamily(edgeHandleMap.get(cfName), prefix, filter, list); } - - return list; + } } - protected void getEdgesFromSingleColumnFamily(ColumnFamilyHandle handle, byte[] prefix, - IGraphFilter filter, List> list) { - try (RocksdbIterator it = new RocksdbIterator(this.rocksdbClient.getIterator(handle), - prefix)) { - getEdgesFromRocksDBIterator(list, it, filter); - } - } + return list; + } - @Override - public CloseableIterator vertexIDIterator() { - flush(); - - List iterList = new ArrayList<>(); - for (ColumnFamilyHandle handle : vertexHandleMap.values()) { - iterList.add(new RocksdbIterator(this.rocksdbClient.getIterator(handle))); - } - - CloseableIterator> it = new ChainedCloseableIterator(iterList); - - return buildVertexIDIteratorFromRocksDBIter(it); + protected void getEdgesFromSingleColumnFamily( + ColumnFamilyHandle handle, byte[] prefix, IGraphFilter filter, List> list) { + try (RocksdbIterator it = new RocksdbIterator(this.rocksdbClient.getIterator(handle), prefix)) { + getEdgesFromRocksDBIterator(list, it, filter); } + } - @Override - public CloseableIterator> getVertexIterator(IStatePushDown pushdown) { - flush(); + @Override + public CloseableIterator vertexIDIterator() { + flush(); - return new VertexScanIterator<>(getVertexOrEdgeIterator(vertexHandleMap, pushdown, true), - pushdown, vertexEncoder::getVertex); + List iterList = new ArrayList<>(); + for (ColumnFamilyHandle handle : vertexHandleMap.values()) { + iterList.add(new RocksdbIterator(this.rocksdbClient.getIterator(handle))); } - @Override - public CloseableIterator> getEdgeIterator(IStatePushDown pushdown) { - flush(); - return new EdgeScanIterator<>(getVertexOrEdgeIterator(edgeHandleMap, pushdown, false), - pushdown, edgeEncoder::getEdge); - } - - private CloseableIterator> getVertexOrEdgeIterator( - Map handleMap, IStatePushDown pushdown, boolean isVertex) { - - IGraphFilter filter = (IGraphFilter) pushdown.getFilter(); - List labels = FilterHelper.parseLabel(filter, isVertex); - List iterList = new ArrayList<>(); - - if (labels.isEmpty()) { - for (ColumnFamilyHandle handle : handleMap.values()) { - iterList.add(new RocksdbIterator(this.rocksdbClient.getIterator(handle))); - } - } else { - for (String label : labels) { - String cfName = getColumnFamilyName(label, isVertex); - - if (handleMap.containsKey(cfName)) { - iterList.add(new RocksdbIterator( - this.rocksdbClient.getIterator(handleMap.get(cfName)))); - } - } + CloseableIterator> it = new ChainedCloseableIterator(iterList); + + return buildVertexIDIteratorFromRocksDBIter(it); + } + + @Override + public CloseableIterator> getVertexIterator(IStatePushDown pushdown) { + flush(); + + return new VertexScanIterator<>( + getVertexOrEdgeIterator(vertexHandleMap, pushdown, true), + pushdown, + vertexEncoder::getVertex); + } + + @Override + public CloseableIterator> getEdgeIterator(IStatePushDown pushdown) { + flush(); + return new EdgeScanIterator<>( + getVertexOrEdgeIterator(edgeHandleMap, pushdown, false), pushdown, edgeEncoder::getEdge); + } + + private CloseableIterator> getVertexOrEdgeIterator( + Map handleMap, IStatePushDown pushdown, boolean isVertex) { + + IGraphFilter filter = (IGraphFilter) pushdown.getFilter(); + List labels = FilterHelper.parseLabel(filter, isVertex); + List iterList = new ArrayList<>(); + + if (labels.isEmpty()) { + for (ColumnFamilyHandle handle : handleMap.values()) { + iterList.add(new RocksdbIterator(this.rocksdbClient.getIterator(handle))); + } + } else { + for (String label : labels) { + String cfName = getColumnFamilyName(label, isVertex); + + if (handleMap.containsKey(cfName)) { + iterList.add(new RocksdbIterator(this.rocksdbClient.getIterator(handleMap.get(cfName)))); } - - return new ChainedCloseableIterator(iterList); + } } + + return new ChainedCloseableIterator(iterList); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/SyncGraphMultiVersionedProxy.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/SyncGraphMultiVersionedProxy.java index 0fb2e9549..a1ca35701 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/SyncGraphMultiVersionedProxy.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/SyncGraphMultiVersionedProxy.java @@ -23,8 +23,6 @@ import static org.apache.geaflow.store.rocksdb.RocksdbConfigKeys.VERTEX_CF; import static org.apache.geaflow.store.rocksdb.RocksdbConfigKeys.VERTEX_INDEX_CF; -import com.google.common.primitives.Bytes; -import com.google.common.primitives.Longs; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -32,6 +30,7 @@ import java.util.List; import java.util.Map; import java.util.function.Predicate; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.StateConfigKeys; import org.apache.geaflow.common.errorcode.RuntimeErrors; @@ -58,274 +57,276 @@ import org.apache.geaflow.store.rocksdb.RocksdbClient; import org.apache.geaflow.store.rocksdb.iterator.RocksdbIterator; -public class SyncGraphMultiVersionedProxy implements IGraphMultiVersionedRocksdbProxy { - - private static final int VERSION_BYTES_SIZE = Long.BYTES; - private static final int VERTEX_INDEX_SUFFIX_SIZE = - VERSION_BYTES_SIZE + StateConfigKeys.DELIMITER.length; - protected static final byte[] EMPTY_BYTES = new byte[0]; - protected final Configuration config; - protected RocksdbClient rocksdbClient; - protected IGraphKVEncoder encoder; - protected IEdgeKVEncoder edgeEncoder; - protected IVertexKVEncoder vertexEncoder; - - public SyncGraphMultiVersionedProxy(RocksdbClient rocksdbStore, - IGraphKVEncoder encoder, - Configuration config) { - this.encoder = encoder; - this.rocksdbClient = rocksdbStore; - this.vertexEncoder = encoder.getVertexEncoder(); - this.edgeEncoder = encoder.getEdgeEncoder(); - this.config = config; - } - - @Override - public void addVertex(long version, IVertex vertex) { - Tuple tuple = vertexEncoder.format(vertex); - byte[] bVersion = getBinaryVersion(version); - this.rocksdbClient.write(VERTEX_CF, concat(bVersion, tuple.f0), tuple.f1); - this.rocksdbClient.write(VERTEX_INDEX_CF, concat(tuple.f0, bVersion), EMPTY_BYTES); - } - - @Override - public void addEdge(long version, IEdge edge) { - byte[] bVersion = getBinaryVersion(version); - Tuple tuple = edgeEncoder.format(edge); - this.rocksdbClient.write(EDGE_CF, concat(bVersion, tuple.f0), tuple.f1); - } +import com.google.common.primitives.Bytes; +import com.google.common.primitives.Longs; - @Override - public IVertex getVertex(long version, K sid, IStatePushDown pushdown) { - byte[] key = encoder.getKeyType().serialize(sid); - byte[] bVersion = getBinaryVersion(version); - byte[] value = this.rocksdbClient.get(VERTEX_CF, concat(bVersion, key)); - if (value != null) { - IVertex vertex = vertexEncoder.getVertex(key, value); - if (pushdown == null || ((IGraphFilter) pushdown.getFilter()).filterVertex(vertex)) { - return vertex; - } - } - return null; +public class SyncGraphMultiVersionedProxy + implements IGraphMultiVersionedRocksdbProxy { + + private static final int VERSION_BYTES_SIZE = Long.BYTES; + private static final int VERTEX_INDEX_SUFFIX_SIZE = + VERSION_BYTES_SIZE + StateConfigKeys.DELIMITER.length; + protected static final byte[] EMPTY_BYTES = new byte[0]; + protected final Configuration config; + protected RocksdbClient rocksdbClient; + protected IGraphKVEncoder encoder; + protected IEdgeKVEncoder edgeEncoder; + protected IVertexKVEncoder vertexEncoder; + + public SyncGraphMultiVersionedProxy( + RocksdbClient rocksdbStore, IGraphKVEncoder encoder, Configuration config) { + this.encoder = encoder; + this.rocksdbClient = rocksdbStore; + this.vertexEncoder = encoder.getVertexEncoder(); + this.edgeEncoder = encoder.getEdgeEncoder(); + this.config = config; + } + + @Override + public void addVertex(long version, IVertex vertex) { + Tuple tuple = vertexEncoder.format(vertex); + byte[] bVersion = getBinaryVersion(version); + this.rocksdbClient.write(VERTEX_CF, concat(bVersion, tuple.f0), tuple.f1); + this.rocksdbClient.write(VERTEX_INDEX_CF, concat(tuple.f0, bVersion), EMPTY_BYTES); + } + + @Override + public void addEdge(long version, IEdge edge) { + byte[] bVersion = getBinaryVersion(version); + Tuple tuple = edgeEncoder.format(edge); + this.rocksdbClient.write(EDGE_CF, concat(bVersion, tuple.f0), tuple.f1); + } + + @Override + public IVertex getVertex(long version, K sid, IStatePushDown pushdown) { + byte[] key = encoder.getKeyType().serialize(sid); + byte[] bVersion = getBinaryVersion(version); + byte[] value = this.rocksdbClient.get(VERTEX_CF, concat(bVersion, key)); + if (value != null) { + IVertex vertex = vertexEncoder.getVertex(key, value); + if (pushdown == null || ((IGraphFilter) pushdown.getFilter()).filterVertex(vertex)) { + return vertex; + } } - - @Override - public List> getEdges(long version, K sid, IStatePushDown pushdown) { - List> list = new ArrayList<>(); - byte[] bVersion = getBinaryVersion(version); - byte[] prefix = concat(bVersion, edgeEncoder.getScanBytes(sid)); - - IGraphFilter filter = (IGraphFilter) pushdown.getFilter(); - try (RocksdbIterator it = new RocksdbIterator( - this.rocksdbClient.getIterator(EDGE_CF), prefix)) { - while (it.hasNext()) { - Tuple pair = it.next(); - IEdge edge = edgeEncoder.getEdge(getKeyFromVersionToKey(pair.f0), pair.f1); - if (filter.filterEdge(edge)) { - list.add(edge); - } - } + return null; + } + + @Override + public List> getEdges(long version, K sid, IStatePushDown pushdown) { + List> list = new ArrayList<>(); + byte[] bVersion = getBinaryVersion(version); + byte[] prefix = concat(bVersion, edgeEncoder.getScanBytes(sid)); + + IGraphFilter filter = (IGraphFilter) pushdown.getFilter(); + try (RocksdbIterator it = + new RocksdbIterator(this.rocksdbClient.getIterator(EDGE_CF), prefix)) { + while (it.hasNext()) { + Tuple pair = it.next(); + IEdge edge = edgeEncoder.getEdge(getKeyFromVersionToKey(pair.f0), pair.f1); + if (filter.filterEdge(edge)) { + list.add(edge); } - return list; + } } - - @Override - public OneDegreeGraph getOneDegreeGraph(long version, K sid, IStatePushDown pushdown) { - IVertex vertex = getVertex(version, sid, pushdown); - List> edgeList = getEdges(version, sid, pushdown); - OneDegreeGraph oneDegreeGraph = new OneDegreeGraph<>(sid, vertex, - IteratorWithClose.wrap(edgeList.iterator())); - if (((IGraphFilter) pushdown.getFilter()).filterOneDegreeGraph(oneDegreeGraph)) { - return oneDegreeGraph; - } else { - return null; - } + return list; + } + + @Override + public OneDegreeGraph getOneDegreeGraph(long version, K sid, IStatePushDown pushdown) { + IVertex vertex = getVertex(version, sid, pushdown); + List> edgeList = getEdges(version, sid, pushdown); + OneDegreeGraph oneDegreeGraph = + new OneDegreeGraph<>(sid, vertex, IteratorWithClose.wrap(edgeList.iterator())); + if (((IGraphFilter) pushdown.getFilter()).filterOneDegreeGraph(oneDegreeGraph)) { + return oneDegreeGraph; + } else { + return null; } - - @Override - public CloseableIterator vertexIDIterator() { - flush(); - RocksdbIterator it = new RocksdbIterator(this.rocksdbClient.getIterator(VERTEX_INDEX_CF)); - - return new IteratorWithFnThenFilter<>(it, - tuple2 -> vertexEncoder.getVertexID(getKeyFromKeyToVersion(tuple2.f0)), - new DudupPredicate<>()); + } + + @Override + public CloseableIterator vertexIDIterator() { + flush(); + RocksdbIterator it = new RocksdbIterator(this.rocksdbClient.getIterator(VERTEX_INDEX_CF)); + + return new IteratorWithFnThenFilter<>( + it, + tuple2 -> vertexEncoder.getVertexID(getKeyFromKeyToVersion(tuple2.f0)), + new DudupPredicate<>()); + } + + @Override + public CloseableIterator vertexIDIterator(long version, IStatePushDown pushDown) { + if (pushDown.getFilter() == null) { + flush(); + byte[] prefix = getVersionPrefix(version); + RocksdbIterator it = new RocksdbIterator(rocksdbClient.getIterator(VERTEX_CF), prefix); + return new IteratorWithFnThenFilter<>( + it, + tuple2 -> vertexEncoder.getVertexID(getKeyFromVersionToKey(tuple2.f0)), + new DudupPredicate<>()); + + } else { + return new IteratorWithFn<>(getVertexIterator(version, pushDown), IVertex::getId); } - - @Override - public CloseableIterator vertexIDIterator(long version, IStatePushDown pushDown) { - if (pushDown.getFilter() == null) { - flush(); - byte[] prefix = getVersionPrefix(version); - RocksdbIterator it = new RocksdbIterator(rocksdbClient.getIterator(VERTEX_CF), prefix); - return new IteratorWithFnThenFilter<>(it, - tuple2 -> vertexEncoder.getVertexID(getKeyFromVersionToKey(tuple2.f0)), - new DudupPredicate<>()); - - } else { - return new IteratorWithFn<>(getVertexIterator(version, pushDown), IVertex::getId); + } + + @Override + public CloseableIterator> getVertexIterator( + long version, IStatePushDown pushdown) { + flush(); + byte[] prefix = getVersionPrefix(version); + RocksdbIterator it = new RocksdbIterator(rocksdbClient.getIterator(VERTEX_CF), prefix); + return new VertexScanIterator<>( + it, pushdown, (key, value) -> vertexEncoder.getVertex(getKeyFromVersionToKey(key), value)); + } + + @Override + public CloseableIterator> getVertexIterator( + long version, List keys, IStatePushDown pushdown) { + return new KeysIterator<>(keys, (k, f) -> getVertex(version, k, f), pushdown); + } + + @Override + public CloseableIterator> getEdgeIterator(long version, IStatePushDown pushdown) { + flush(); + byte[] prefix = getVersionPrefix(version); + RocksdbIterator it = new RocksdbIterator(rocksdbClient.getIterator(EDGE_CF), prefix); + return new EdgeScanIterator<>( + it, pushdown, (key, value) -> edgeEncoder.getEdge(getKeyFromVersionToKey(key), value)); + } + + @Override + public CloseableIterator> getEdgeIterator( + long version, List keys, IStatePushDown pushdown) { + return new IteratorWithFlatFn<>( + new KeysIterator<>(keys, (k, f) -> getEdges(version, k, f), pushdown), List::iterator); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + long version, IStatePushDown pushdown) { + flush(); + return new OneDegreeGraphScanIterator<>( + encoder.getKeyType(), + getVertexIterator(version, pushdown), + getEdgeIterator(version, pushdown), + pushdown); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + long version, List keys, IStatePushDown pushdown) { + return new KeysIterator<>(keys, (k, f) -> getOneDegreeGraph(version, k, f), pushdown); + } + + @Override + public List getAllVersions(K id, DataType dataType) { + flush(); + if (dataType == DataType.V || dataType == DataType.V_TOPO) { + List list = new ArrayList<>(); + byte[] prefix = Bytes.concat(encoder.getKeyType().serialize(id), StateConfigKeys.DELIMITER); + try (RocksdbIterator it = + new RocksdbIterator(this.rocksdbClient.getIterator(VERTEX_INDEX_CF), prefix)) { + while (it.hasNext()) { + Tuple pair = it.next(); + list.add(getVersionFromKeyToVersion(pair.f0)); } + } + return list; } - - @Override - public CloseableIterator> getVertexIterator(long version, IStatePushDown pushdown) { - flush(); - byte[] prefix = getVersionPrefix(version); - RocksdbIterator it = new RocksdbIterator(rocksdbClient.getIterator(VERTEX_CF), prefix); - return new VertexScanIterator<>(it, pushdown, - (key, value) -> vertexEncoder.getVertex(getKeyFromVersionToKey(key), value)); - } - - @Override - public CloseableIterator> getVertexIterator(long version, List keys, - IStatePushDown pushdown) { - return new KeysIterator<>(keys, (k, f) -> getVertex(version, k, f), pushdown); - } - - @Override - public CloseableIterator> getEdgeIterator(long version, IStatePushDown pushdown) { - flush(); - byte[] prefix = getVersionPrefix(version); - RocksdbIterator it = new RocksdbIterator(rocksdbClient.getIterator(EDGE_CF), prefix); - return new EdgeScanIterator<>(it, pushdown, - (key, value) -> edgeEncoder.getEdge(getKeyFromVersionToKey(key), value)); - } - - @Override - public CloseableIterator> getEdgeIterator(long version, List keys, - IStatePushDown pushdown) { - return new IteratorWithFlatFn<>(new KeysIterator<>(keys, (k, f) -> getEdges(version, k, f), pushdown), List::iterator); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(long version, - IStatePushDown pushdown) { - flush(); - return new OneDegreeGraphScanIterator<>( - encoder.getKeyType(), - getVertexIterator(version, pushdown), - getEdgeIterator(version, pushdown), - pushdown); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(long version, List keys, - IStatePushDown pushdown) { - return new KeysIterator<>(keys, (k, f) -> getOneDegreeGraph(version, k, f), pushdown); - } - - @Override - public List getAllVersions(K id, DataType dataType) { - flush(); - if (dataType == DataType.V || dataType == DataType.V_TOPO) { - List list = new ArrayList<>(); - byte[] prefix = Bytes.concat(encoder.getKeyType().serialize(id), StateConfigKeys.DELIMITER); - try (RocksdbIterator it = - new RocksdbIterator(this.rocksdbClient.getIterator(VERTEX_INDEX_CF), prefix)) { - while (it.hasNext()) { - Tuple pair = it.next(); - list.add(getVersionFromKeyToVersion(pair.f0)); - } - } - return list; + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + + @Override + public long getLatestVersion(K id, DataType dataType) { + flush(); + if (dataType == DataType.V || dataType == DataType.V_TOPO) { + byte[] prefix = getKeyPrefix(id); + try (RocksdbIterator it = + new RocksdbIterator(this.rocksdbClient.getIterator(VERTEX_INDEX_CF), prefix)) { + if (it.hasNext()) { + Tuple pair = it.next(); + return getVersionFromKeyToVersion(pair.f0); } - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + return -1; } - - @Override - public long getLatestVersion(K id, DataType dataType) { - flush(); - if (dataType == DataType.V || dataType == DataType.V_TOPO) { - byte[] prefix = getKeyPrefix(id); - try (RocksdbIterator it = - new RocksdbIterator(this.rocksdbClient.getIterator(VERTEX_INDEX_CF), prefix)) { - if (it.hasNext()) { - Tuple pair = it.next(); - return getVersionFromKeyToVersion(pair.f0); - } - } - return -1; + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + + @Override + public Map> getAllVersionData( + K id, IStatePushDown pushdown, DataType dataType) { + List allVersions = getAllVersions(id, dataType); + return getVersionData(id, allVersions, pushdown, dataType); + } + + @Override + public Map> getVersionData( + K id, Collection versions, IStatePushDown pushdown, DataType dataType) { + if (dataType == DataType.V || dataType == DataType.V_TOPO) { + Map> map = new HashMap<>(); + for (long version : versions) { + IVertex vertex = getVertex(version, id, pushdown); + if (vertex != null) { + map.put(version, vertex); } - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + return map; } + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } - @Override - public Map> getAllVersionData(K id, IStatePushDown pushdown, - DataType dataType) { - List allVersions = getAllVersions(id, dataType); - return getVersionData(id, allVersions, pushdown, dataType); - } + @Override + public RocksdbClient getClient() { + return rocksdbClient; + } - @Override - public Map> getVersionData(K id, Collection versions, - IStatePushDown pushdown, DataType dataType) { - if (dataType == DataType.V || dataType == DataType.V_TOPO) { - Map> map = new HashMap<>(); - for (long version : versions) { - IVertex vertex = getVertex(version, id, pushdown); - if (vertex != null) { - map.put(version, vertex); - } - } - return map; - } - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public void flush() {} + @Override + public void close() {} - @Override - public RocksdbClient getClient() { - return rocksdbClient; - } + private long getVersionFromKeyToVersion(byte[] key) { + byte[] bVersion = Arrays.copyOfRange(key, key.length - 8, key.length); + return Long.MAX_VALUE - Longs.fromByteArray(bVersion); + } - @Override - public void flush() { + protected byte[] getKeyFromKeyToVersion(byte[] key) { + return Arrays.copyOf(key, key.length - VERTEX_INDEX_SUFFIX_SIZE); + } - } + protected byte[] getBinaryVersion(long version) { + return Longs.toByteArray(Long.MAX_VALUE - version); + } - @Override - public void close() { + protected byte[] getKeyPrefix(K id) { + return Bytes.concat(this.encoder.getKeyType().serialize(id), StateConfigKeys.DELIMITER); + } - } + protected byte[] getVersionPrefix(long version) { + return Bytes.concat(getBinaryVersion(version), StateConfigKeys.DELIMITER); + } - private long getVersionFromKeyToVersion(byte[] key) { - byte[] bVersion = Arrays.copyOfRange(key, key.length - 8, key.length); - return Long.MAX_VALUE - Longs.fromByteArray(bVersion); - } + protected byte[] getKeyFromVersionToKey(byte[] key) { + return Arrays.copyOfRange(key, 10, key.length); + } - protected byte[] getKeyFromKeyToVersion(byte[] key) { - return Arrays.copyOf(key, key.length - VERTEX_INDEX_SUFFIX_SIZE); - } + protected byte[] concat(byte[] a, byte[] b) { + return Bytes.concat(a, StateConfigKeys.DELIMITER, b); + } - protected byte[] getBinaryVersion(long version) { - return Longs.toByteArray(Long.MAX_VALUE - version); - } - - protected byte[] getKeyPrefix(K id) { - return Bytes.concat(this.encoder.getKeyType().serialize(id), StateConfigKeys.DELIMITER); - } + protected static class DudupPredicate implements Predicate { - protected byte[] getVersionPrefix(long version) { - return Bytes.concat(getBinaryVersion(version), StateConfigKeys.DELIMITER); - } + K last = null; - protected byte[] getKeyFromVersionToKey(byte[] key) { - return Arrays.copyOfRange(key, 10, key.length); - } - - protected byte[] concat(byte[] a, byte[] b) { - return Bytes.concat(a, StateConfigKeys.DELIMITER, b); - } - - protected static class DudupPredicate implements Predicate { - - K last = null; - - @Override - public boolean test(K k) { - boolean res = k.equals(last); - last = k; - return !res; - } + @Override + public boolean test(K k) { + boolean res = k.equals(last); + last = k; + return !res; } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/SyncGraphRocksdbProxy.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/SyncGraphRocksdbProxy.java index e8aa2eafa..83905cb1d 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/SyncGraphRocksdbProxy.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/main/java/org/apache/geaflow/store/rocksdb/proxy/SyncGraphRocksdbProxy.java @@ -29,6 +29,7 @@ import java.util.Map; import java.util.function.Function; import java.util.function.Predicate; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.common.tuple.Tuple; @@ -57,221 +58,220 @@ public class SyncGraphRocksdbProxy implements IGraphRocksdbProxy { - protected final Configuration config; - protected final IVertexKVEncoder vertexEncoder; - protected final IEdgeKVEncoder edgeEncoder; - protected IGraphKVEncoder encoder; - protected final RocksdbClient rocksdbClient; - - public SyncGraphRocksdbProxy(RocksdbClient rocksdbClient, IGraphKVEncoder encoder, - Configuration config) { - this.encoder = encoder; - this.vertexEncoder = this.encoder.getVertexEncoder(); - this.edgeEncoder = this.encoder.getEdgeEncoder(); - this.rocksdbClient = rocksdbClient; - this.config = config; - } - - @Override - public RocksdbClient getClient() { - return rocksdbClient; - } - - @Override - public void addVertex(IVertex vertex) { - Tuple tuple = vertexEncoder.format(vertex); - this.rocksdbClient.write(VERTEX_CF, tuple.f0, tuple.f1); - } - - @Override - public void addEdge(IEdge edge) { - Tuple tuple = edgeEncoder.format(edge); - this.rocksdbClient.write(EDGE_CF, tuple.f0, tuple.f1); - } - - @Override - public IVertex getVertex(K sid, IStatePushDown pushdown) { - byte[] key = encoder.getKeyType().serialize(sid); - byte[] value = this.rocksdbClient.get(VERTEX_CF, key); - if (value != null) { - IVertex vertex = vertexEncoder.getVertex(key, value); - if (pushdown == null || ((IGraphFilter) pushdown.getFilter()).filterVertex(vertex)) { - return vertex; - } - } - return null; - } - - @Override - public List> getEdges(K sid, IStatePushDown pushdown) { - IGraphFilter filter = GraphFilter.of(pushdown.getFilter(), pushdown.getEdgeLimit()); - return getEdges(sid, filter); - } - - protected List> getEdges(K sid, IGraphFilter filter) { - List> list = new ArrayList<>(); - byte[] prefix = edgeEncoder.getScanBytes(sid); - try (RocksdbIterator it = new RocksdbIterator(this.rocksdbClient.getIterator(EDGE_CF), - prefix)) { - getEdgesFromRocksDBIterator(list, it, filter); - } - - return list; - } - - protected void getEdgesFromRocksDBIterator(List> list, RocksdbIterator it, - IGraphFilter filter) { - while (it.hasNext()) { - Tuple pair = it.next(); - IEdge edge = edgeEncoder.getEdge(pair.f0, pair.f1); - if (filter.filterEdge(edge)) { - list.add(edge); - } - if (filter.dropAllRemaining()) { - break; - } - } - } - - @Override - public OneDegreeGraph getOneDegreeGraph(K sid, IStatePushDown pushdown) { - IVertex vertex = getVertex(sid, pushdown); - List> edgeList = getEdges(sid, pushdown); - IGraphFilter filter = GraphFilter.of(pushdown.getFilter(), pushdown.getEdgeLimit()); - OneDegreeGraph oneDegreeGraph = new OneDegreeGraph<>(sid, vertex, - IteratorWithClose.wrap(edgeList.iterator())); - if (filter.filterOneDegreeGraph(oneDegreeGraph)) { - return oneDegreeGraph; - } else { - return null; - } - } - - @Override - public CloseableIterator vertexIDIterator() { - flush(); - RocksdbIterator it = new RocksdbIterator(this.rocksdbClient.getIterator(VERTEX_CF)); - return buildVertexIDIteratorFromRocksDBIter(it); - } - - protected CloseableIterator buildVertexIDIteratorFromRocksDBIter( - CloseableIterator> it) { - return new IteratorWithFnThenFilter<>(it, tuple2 -> vertexEncoder.getVertexID(tuple2.f0), - predicate()); + protected final Configuration config; + protected final IVertexKVEncoder vertexEncoder; + protected final IEdgeKVEncoder edgeEncoder; + protected IGraphKVEncoder encoder; + protected final RocksdbClient rocksdbClient; + + public SyncGraphRocksdbProxy( + RocksdbClient rocksdbClient, IGraphKVEncoder encoder, Configuration config) { + this.encoder = encoder; + this.vertexEncoder = this.encoder.getVertexEncoder(); + this.edgeEncoder = this.encoder.getEdgeEncoder(); + this.rocksdbClient = rocksdbClient; + this.config = config; + } + + @Override + public RocksdbClient getClient() { + return rocksdbClient; + } + + @Override + public void addVertex(IVertex vertex) { + Tuple tuple = vertexEncoder.format(vertex); + this.rocksdbClient.write(VERTEX_CF, tuple.f0, tuple.f1); + } + + @Override + public void addEdge(IEdge edge) { + Tuple tuple = edgeEncoder.format(edge); + this.rocksdbClient.write(EDGE_CF, tuple.f0, tuple.f1); + } + + @Override + public IVertex getVertex(K sid, IStatePushDown pushdown) { + byte[] key = encoder.getKeyType().serialize(sid); + byte[] value = this.rocksdbClient.get(VERTEX_CF, key); + if (value != null) { + IVertex vertex = vertexEncoder.getVertex(key, value); + if (pushdown == null || ((IGraphFilter) pushdown.getFilter()).filterVertex(vertex)) { + return vertex; + } } - - private Predicate predicate() { - return new Predicate() { - K last = null; - - @Override - public boolean test(K k) { - boolean res = k.equals(last); - last = k; - return !res; - } - }; + return null; + } + + @Override + public List> getEdges(K sid, IStatePushDown pushdown) { + IGraphFilter filter = GraphFilter.of(pushdown.getFilter(), pushdown.getEdgeLimit()); + return getEdges(sid, filter); + } + + protected List> getEdges(K sid, IGraphFilter filter) { + List> list = new ArrayList<>(); + byte[] prefix = edgeEncoder.getScanBytes(sid); + try (RocksdbIterator it = + new RocksdbIterator(this.rocksdbClient.getIterator(EDGE_CF), prefix)) { + getEdgesFromRocksDBIterator(list, it, filter); } - @Override - public CloseableIterator vertexIDIterator(IStatePushDown pushDown) { - if (pushDown.getFilter() == null) { - return vertexIDIterator(); - } else { - return new IteratorWithFn<>(getVertexIterator(pushDown), IVertex::getId); - } + return list; + } + + protected void getEdgesFromRocksDBIterator( + List> list, RocksdbIterator it, IGraphFilter filter) { + while (it.hasNext()) { + Tuple pair = it.next(); + IEdge edge = edgeEncoder.getEdge(pair.f0, pair.f1); + if (filter.filterEdge(edge)) { + list.add(edge); + } + if (filter.dropAllRemaining()) { + break; + } } - - @Override - public CloseableIterator> getVertexIterator(IStatePushDown pushdown) { - flush(); - RocksdbIterator it = new RocksdbIterator(rocksdbClient.getIterator(VERTEX_CF)); - return new VertexScanIterator<>(it, pushdown, vertexEncoder::getVertex); + } + + @Override + public OneDegreeGraph getOneDegreeGraph(K sid, IStatePushDown pushdown) { + IVertex vertex = getVertex(sid, pushdown); + List> edgeList = getEdges(sid, pushdown); + IGraphFilter filter = GraphFilter.of(pushdown.getFilter(), pushdown.getEdgeLimit()); + OneDegreeGraph oneDegreeGraph = + new OneDegreeGraph<>(sid, vertex, IteratorWithClose.wrap(edgeList.iterator())); + if (filter.filterOneDegreeGraph(oneDegreeGraph)) { + return oneDegreeGraph; + } else { + return null; } - - @Override - public CloseableIterator> getVertexIterator(List keys, IStatePushDown pushdown) { - return new KeysIterator<>(keys, this::getVertex, pushdown); + } + + @Override + public CloseableIterator vertexIDIterator() { + flush(); + RocksdbIterator it = new RocksdbIterator(this.rocksdbClient.getIterator(VERTEX_CF)); + return buildVertexIDIteratorFromRocksDBIter(it); + } + + protected CloseableIterator buildVertexIDIteratorFromRocksDBIter( + CloseableIterator> it) { + return new IteratorWithFnThenFilter<>( + it, tuple2 -> vertexEncoder.getVertexID(tuple2.f0), predicate()); + } + + private Predicate predicate() { + return new Predicate() { + K last = null; + + @Override + public boolean test(K k) { + boolean res = k.equals(last); + last = k; + return !res; + } + }; + } + + @Override + public CloseableIterator vertexIDIterator(IStatePushDown pushDown) { + if (pushDown.getFilter() == null) { + return vertexIDIterator(); + } else { + return new IteratorWithFn<>(getVertexIterator(pushDown), IVertex::getId); } - - @Override - public CloseableIterator> getEdgeIterator(IStatePushDown pushdown) { - flush(); - RocksdbIterator it = new RocksdbIterator(rocksdbClient.getIterator(EDGE_CF)); - return new EdgeScanIterator<>(it, pushdown, edgeEncoder::getEdge); + } + + @Override + public CloseableIterator> getVertexIterator(IStatePushDown pushdown) { + flush(); + RocksdbIterator it = new RocksdbIterator(rocksdbClient.getIterator(VERTEX_CF)); + return new VertexScanIterator<>(it, pushdown, vertexEncoder::getVertex); + } + + @Override + public CloseableIterator> getVertexIterator( + List keys, IStatePushDown pushdown) { + return new KeysIterator<>(keys, this::getVertex, pushdown); + } + + @Override + public CloseableIterator> getEdgeIterator(IStatePushDown pushdown) { + flush(); + RocksdbIterator it = new RocksdbIterator(rocksdbClient.getIterator(EDGE_CF)); + return new EdgeScanIterator<>(it, pushdown, edgeEncoder::getEdge); + } + + @Override + public CloseableIterator> getEdgeIterator(List keys, IStatePushDown pushdown) { + return new IteratorWithFlatFn<>( + new KeysIterator<>(keys, this::getEdges, pushdown), List::iterator); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + IStatePushDown pushdown) { + flush(); + return new OneDegreeGraphScanIterator<>( + encoder.getKeyType(), getVertexIterator(pushdown), getEdgeIterator(pushdown), pushdown); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + List keys, IStatePushDown pushdown) { + return new KeysIterator<>(keys, this::getOneDegreeGraph, pushdown); + } + + @Override + public CloseableIterator> getEdgeProjectIterator( + IStatePushDown, R> pushdown) { + flush(); + return new IteratorWithFn<>( + getEdgeIterator(pushdown), e -> Tuple.of(e.getSrcId(), pushdown.getProjector().project(e))); + } + + @Override + public CloseableIterator> getEdgeProjectIterator( + List keys, IStatePushDown, R> pushdown) { + return new IteratorWithFn<>( + getEdgeIterator(keys, pushdown), + e -> Tuple.of(e.getSrcId(), pushdown.getProjector().project(e))); + } + + @Override + public Map getAggResult(IStatePushDown pushdown) { + Map res = new HashMap<>(); + Iterator>> it = new EdgeListScanIterator<>(getEdgeIterator(pushdown)); + while (it.hasNext()) { + List> edges = it.next(); + K key = edges.get(0).getSrcId(); + res.put(key, (long) edges.size()); } - - @Override - public CloseableIterator> getEdgeIterator(List keys, IStatePushDown pushdown) { - return new IteratorWithFlatFn<>(new KeysIterator<>(keys, this::getEdges, pushdown), List::iterator); + return res; + } + + @Override + public Map getAggResult(List keys, IStatePushDown pushdown) { + Map res = new HashMap<>(keys.size()); + + Function pushdownFun; + if (pushdown.getFilters() == null) { + pushdownFun = key -> pushdown; + } else { + pushdownFun = key -> StatePushDown.of().withFilter((IFilter) pushdown.getFilters().get(key)); } - @Override - public CloseableIterator> getOneDegreeGraphIterator( - IStatePushDown pushdown) { - flush(); - return new OneDegreeGraphScanIterator<>(encoder.getKeyType(), - getVertexIterator(pushdown), getEdgeIterator(pushdown), pushdown); + for (K key : keys) { + List> list = getEdges(key, pushdownFun.apply(key)); + res.put(key, (long) list.size()); } + return res; + } - @Override - public CloseableIterator> getOneDegreeGraphIterator(List keys, IStatePushDown pushdown) { - return new KeysIterator<>(keys, this::getOneDegreeGraph, pushdown); - } - - @Override - public CloseableIterator> getEdgeProjectIterator( - IStatePushDown, R> pushdown) { - flush(); - return new IteratorWithFn<>(getEdgeIterator(pushdown), e -> Tuple.of(e.getSrcId(), pushdown.getProjector().project(e))); - } - - @Override - public CloseableIterator> getEdgeProjectIterator(List keys, - IStatePushDown, R> pushdown) { - return new IteratorWithFn<>(getEdgeIterator(keys, pushdown), e -> Tuple.of(e.getSrcId(), pushdown.getProjector().project(e))); - } - - @Override - public Map getAggResult(IStatePushDown pushdown) { - Map res = new HashMap<>(); - Iterator>> it = - new EdgeListScanIterator<>(getEdgeIterator(pushdown)); - while (it.hasNext()) { - List> edges = it.next(); - K key = edges.get(0).getSrcId(); - res.put(key, (long) edges.size()); - } - return res; - } - - @Override - public Map getAggResult(List keys, IStatePushDown pushdown) { - Map res = new HashMap<>(keys.size()); - - Function pushdownFun; - if (pushdown.getFilters() == null) { - pushdownFun = key -> pushdown; - } else { - pushdownFun = - key -> StatePushDown.of().withFilter((IFilter) pushdown.getFilters().get(key)); - } - - for (K key : keys) { - List> list = getEdges(key, pushdownFun.apply(key)); - res.put(key, (long) list.size()); - } - return res; - } - - @Override - public void flush() { - - } - - @Override - public void close() { - - } + @Override + public void flush() {} + @Override + public void close() {} } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/test/java/org/apache/geaflow/store/rocksdb/RocksdbConfigKeysTest.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/test/java/org/apache/geaflow/store/rocksdb/RocksdbConfigKeysTest.java index 6da414187..e341d2e35 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/test/java/org/apache/geaflow/store/rocksdb/RocksdbConfigKeysTest.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/test/java/org/apache/geaflow/store/rocksdb/RocksdbConfigKeysTest.java @@ -24,18 +24,18 @@ public class RocksdbConfigKeysTest { - @Test - public void testChkPath() { - String path1 = "0_chk4"; - String path2 = "51_chk4"; - String path3 = "111_chk4.tmp"; + @Test + public void testChkPath() { + String path1 = "0_chk4"; + String path2 = "51_chk4"; + String path3 = "111_chk4.tmp"; - Assert.assertEquals(RocksdbConfigKeys.getChkPathPrefix(path1), "0_chk"); - Assert.assertEquals(RocksdbConfigKeys.getChkPathPrefix(path2), "51_chk"); - Assert.assertEquals(RocksdbConfigKeys.getChkPathPrefix(path3), "111_chk"); + Assert.assertEquals(RocksdbConfigKeys.getChkPathPrefix(path1), "0_chk"); + Assert.assertEquals(RocksdbConfigKeys.getChkPathPrefix(path2), "51_chk"); + Assert.assertEquals(RocksdbConfigKeys.getChkPathPrefix(path3), "111_chk"); - Assert.assertTrue(RocksdbConfigKeys.isChkPath(path1)); - Assert.assertTrue(RocksdbConfigKeys.isChkPath(path2)); - Assert.assertFalse(RocksdbConfigKeys.isChkPath(path3)); - } + Assert.assertTrue(RocksdbConfigKeys.isChkPath(path1)); + Assert.assertTrue(RocksdbConfigKeys.isChkPath(path2)); + Assert.assertFalse(RocksdbConfigKeys.isChkPath(path3)); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/test/java/org/apache/geaflow/store/rocksdb/RocksdbStoreBuilderTest.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/test/java/org/apache/geaflow/store/rocksdb/RocksdbStoreBuilderTest.java index 9844e49ae..52fd421d7 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/test/java/org/apache/geaflow/store/rocksdb/RocksdbStoreBuilderTest.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-rocksdb/src/test/java/org/apache/geaflow/store/rocksdb/RocksdbStoreBuilderTest.java @@ -24,6 +24,7 @@ import java.io.File; import java.util.HashMap; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -42,86 +43,85 @@ public class RocksdbStoreBuilderTest { - Map config = new HashMap<>(); + Map config = new HashMap<>(); - @BeforeClass - public void setUp() { - FileUtils.deleteQuietly(new File("/tmp/RocksdbStoreBuilderTest")); - config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), "RocksdbStoreBuilderTest"); - config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); - config.put(FileConfigKeys.ROOT.getKey(), "/tmp/RocksdbStoreBuilderTest"); - config.put(JOB_MAX_PARALLEL.getKey(), "1"); - } + @BeforeClass + public void setUp() { + FileUtils.deleteQuietly(new File("/tmp/RocksdbStoreBuilderTest")); + config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), "RocksdbStoreBuilderTest"); + config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); + config.put(FileConfigKeys.ROOT.getKey(), "/tmp/RocksdbStoreBuilderTest"); + config.put(JOB_MAX_PARALLEL.getKey(), "1"); + } - @Test - public void testKV() { - IStoreBuilder builder = StoreBuilderFactory.build(StoreType.ROCKSDB.name()); - Configuration configuration = new Configuration(config); - IKVStatefulStore kvStore = (IKVStatefulStore) builder.getStore( - DataModel.KV, configuration); - StoreContext storeContext = new StoreContext("rocksdb_kv").withConfig(configuration); - storeContext.withKeySerializer(new DefaultKVSerializer<>(String.class, String.class)); + @Test + public void testKV() { + IStoreBuilder builder = StoreBuilderFactory.build(StoreType.ROCKSDB.name()); + Configuration configuration = new Configuration(config); + IKVStatefulStore kvStore = + (IKVStatefulStore) builder.getStore(DataModel.KV, configuration); + StoreContext storeContext = new StoreContext("rocksdb_kv").withConfig(configuration); + storeContext.withKeySerializer(new DefaultKVSerializer<>(String.class, String.class)); - kvStore.init(storeContext); - kvStore.put("hello", "world"); - kvStore.put("foo", "bar"); - kvStore.flush(); + kvStore.init(storeContext); + kvStore.put("hello", "world"); + kvStore.put("foo", "bar"); + kvStore.flush(); - Assert.assertEquals(kvStore.get("hello"), "world"); - Assert.assertEquals(kvStore.get("foo"), "bar"); + Assert.assertEquals(kvStore.get("hello"), "world"); + Assert.assertEquals(kvStore.get("foo"), "bar"); - kvStore.archive(1); - kvStore.drop(); + kvStore.archive(1); + kvStore.drop(); - kvStore = (IKVStatefulStore) builder.getStore(DataModel.KV, configuration); - kvStore.init(storeContext); - kvStore.recovery(1); + kvStore = (IKVStatefulStore) builder.getStore(DataModel.KV, configuration); + kvStore.init(storeContext); + kvStore.recovery(1); - Assert.assertEquals(kvStore.get("hello"), "world"); - Assert.assertEquals(kvStore.get("foo"), "bar"); - } + Assert.assertEquals(kvStore.get("hello"), "world"); + Assert.assertEquals(kvStore.get("foo"), "bar"); + } - @Test - public void testFO() { - IStoreBuilder builder = StoreBuilderFactory.build(StoreType.ROCKSDB.name()); - Configuration configuration = new Configuration(config); - KVRocksdbStoreBase kvStore = - (KVRocksdbStoreBase) builder.getStore( - DataModel.KV, configuration); - StoreContext storeContext = new StoreContext("rocksdb_kv").withConfig(configuration); - storeContext.withKeySerializer(new DefaultKVSerializer<>(String.class, String.class)); - kvStore.init(storeContext); - Assert.assertEquals(kvStore.recoveryLatest(), -1); - for (int i = 1; i < 10; i++) { - kvStore.put("hello", "world" + i); - kvStore.put("foo", "bar" + i); - kvStore.flush(); - kvStore.archive(i); - } - kvStore.close(); - kvStore.drop(); - kvStore = (KVRocksdbStoreBase) builder.getStore(DataModel.KV, - configuration); - kvStore.init(storeContext); - kvStore.recoveryLatest(); - Assert.assertEquals(kvStore.get("hello"), "world" + 9); - Assert.assertEquals(kvStore.get("foo"), "bar" + 9); - kvStore.close(); - kvStore.drop(); - FileUtils.deleteQuietly(new File("/tmp/RocksdbStoreBuilderTest/RocksdbStoreBuilderTest" - + "/rocksdb_kv/0/meta.9/_commit")); - kvStore = (KVRocksdbStoreBase) builder.getStore(DataModel.KV, - configuration); - kvStore.init(storeContext); - kvStore.recoveryLatest(); - Assert.assertEquals(kvStore.get("hello"), "world" + 8); - Assert.assertEquals(kvStore.get("foo"), "bar" + 8); - kvStore.close(); - kvStore.drop(); + @Test + public void testFO() { + IStoreBuilder builder = StoreBuilderFactory.build(StoreType.ROCKSDB.name()); + Configuration configuration = new Configuration(config); + KVRocksdbStoreBase kvStore = + (KVRocksdbStoreBase) builder.getStore(DataModel.KV, configuration); + StoreContext storeContext = new StoreContext("rocksdb_kv").withConfig(configuration); + storeContext.withKeySerializer(new DefaultKVSerializer<>(String.class, String.class)); + kvStore.init(storeContext); + Assert.assertEquals(kvStore.recoveryLatest(), -1); + for (int i = 1; i < 10; i++) { + kvStore.put("hello", "world" + i); + kvStore.put("foo", "bar" + i); + kvStore.flush(); + kvStore.archive(i); } + kvStore.close(); + kvStore.drop(); + kvStore = (KVRocksdbStoreBase) builder.getStore(DataModel.KV, configuration); + kvStore.init(storeContext); + kvStore.recoveryLatest(); + Assert.assertEquals(kvStore.get("hello"), "world" + 9); + Assert.assertEquals(kvStore.get("foo"), "bar" + 9); + kvStore.close(); + kvStore.drop(); + FileUtils.deleteQuietly( + new File( + "/tmp/RocksdbStoreBuilderTest/RocksdbStoreBuilderTest" + + "/rocksdb_kv/0/meta.9/_commit")); + kvStore = (KVRocksdbStoreBase) builder.getStore(DataModel.KV, configuration); + kvStore.init(storeContext); + kvStore.recoveryLatest(); + Assert.assertEquals(kvStore.get("hello"), "world" + 8); + Assert.assertEquals(kvStore.get("foo"), "bar" + 8); + kvStore.close(); + kvStore.drop(); + } - @AfterMethod - public void tearUp() { - FileUtils.deleteQuietly(new File("/tmp/RocksdbStoreBuilderTest")); - } + @AfterMethod + public void tearUp() { + FileUtils.deleteQuietly(new File("/tmp/RocksdbStoreBuilderTest")); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-vector/src/main/java/org/apache/geaflow/store/lucene/GraphVectorIndex.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-vector/src/main/java/org/apache/geaflow/store/lucene/GraphVectorIndex.java index ed49e667c..873a1eb7d 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-vector/src/main/java/org/apache/geaflow/store/lucene/GraphVectorIndex.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-vector/src/main/java/org/apache/geaflow/store/lucene/GraphVectorIndex.java @@ -20,6 +20,7 @@ package org.apache.geaflow.store.lucene; import java.io.IOException; + import org.apache.lucene.analysis.standard.StandardAnalyzer; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; @@ -39,154 +40,150 @@ /** * Graph vector index implementation using Apache Lucene. * - *

Thread Safety Note: This class is not thread-safe. - * All public methods in this class are not designed for concurrent access. - * External synchronization is required if instances of this class may be accessed - * from multiple threads.

+ *

Thread Safety Note: This class is not thread-safe. All + * public methods in this class are not designed for concurrent access. External + * synchronization is required if instances of this class may be accessed from multiple threads. + * + *

Specifically: * - *

Specifically:

*
    - *
  • {@link #addVectorIndex} uses a non-thread-safe {@link IndexWriter}
  • - *
  • {@link #searchVectorIndex} may see inconsistent data during concurrent writes
  • - *
  • The underlying {@link ByteBuffersDirectory} is not designed for multi-threaded access
  • + *
  • {@link #addVectorIndex} uses a non-thread-safe {@link IndexWriter} + *
  • {@link #searchVectorIndex} may see inconsistent data during concurrent writes + *
  • The underlying {@link ByteBuffersDirectory} is not designed for multi-threaded access *
* * @param the type of keys used to identify vectors */ public class GraphVectorIndex implements IVectorIndex { - private final Directory directory; - private final IndexWriter writer; - private final Class keyClass; - - private static final String KEY_FIELD_NAME = "key_field"; - - public GraphVectorIndex(Class keyClas) { - try { - this.directory = new ByteBuffersDirectory(); - StandardAnalyzer analyzer = new StandardAnalyzer(); - IndexWriterConfig config = new IndexWriterConfig(analyzer); - this.writer = new IndexWriter(directory, config); - this.keyClass = keyClas; - } catch (IOException e) { - throw new RuntimeException("Failed to initialize GraphVectorIndex", e); - } + private final Directory directory; + private final IndexWriter writer; + private final Class keyClass; + + private static final String KEY_FIELD_NAME = "key_field"; + + public GraphVectorIndex(Class keyClas) { + try { + this.directory = new ByteBuffersDirectory(); + StandardAnalyzer analyzer = new StandardAnalyzer(); + IndexWriterConfig config = new IndexWriterConfig(analyzer); + this.writer = new IndexWriter(directory, config); + this.keyClass = keyClas; + } catch (IOException e) { + throw new RuntimeException("Failed to initialize GraphVectorIndex", e); } - - /** - * Adds a vector to the index with the given key. - * - *

Note: This method is not thread-safe. - * The underlying {@link IndexWriter} is not designed for concurrent access. - * External synchronization is required if this method may be called from - * multiple threads.

- * - * @param isVertex whether this is a vertex (unused in current implementation) - * @param key the key associated with the vector - * @param fieldName the name of the vector field in the index - * @param vector the vector data to be indexed - * @throws RuntimeException if an I/O error occurs during indexing - */ - @Override - public void addVectorIndex(boolean isVertex, K key, String fieldName, float[] vector) { - try { - // Create document - Document doc = new Document(); - - // Select different Field based on the type of K - if (key instanceof Float) { - doc.add(new StoredField(KEY_FIELD_NAME, (float) key)); - } else if (key instanceof Long) { - doc.add(new StoredField(KEY_FIELD_NAME, (long) key)); - } else if (key instanceof Integer) { - doc.add(new StoredField(KEY_FIELD_NAME, (int) key)); - } else if (key instanceof String) { - doc.add(new TextField(KEY_FIELD_NAME, (String) key, Field.Store.YES)); - } else { - throw new IllegalArgumentException("Unsupported key type: " + key.getClass().getName()); - } - - // Add vector field - doc.add(new KnnVectorField(fieldName, vector)); - - // Add document to index - writer.addDocument(doc); - - // Commit write operation - writer.commit(); - } catch (IOException e) { - throw new RuntimeException("Failed to add vector index", e); - } + } + + /** + * Adds a vector to the index with the given key. + * + *

Note: This method is not thread-safe. The underlying + * {@link IndexWriter} is not designed for concurrent access. External synchronization is required + * if this method may be called from multiple threads. + * + * @param isVertex whether this is a vertex (unused in current implementation) + * @param key the key associated with the vector + * @param fieldName the name of the vector field in the index + * @param vector the vector data to be indexed + * @throws RuntimeException if an I/O error occurs during indexing + */ + @Override + public void addVectorIndex(boolean isVertex, K key, String fieldName, float[] vector) { + try { + // Create document + Document doc = new Document(); + + // Select different Field based on the type of K + if (key instanceof Float) { + doc.add(new StoredField(KEY_FIELD_NAME, (float) key)); + } else if (key instanceof Long) { + doc.add(new StoredField(KEY_FIELD_NAME, (long) key)); + } else if (key instanceof Integer) { + doc.add(new StoredField(KEY_FIELD_NAME, (int) key)); + } else if (key instanceof String) { + doc.add(new TextField(KEY_FIELD_NAME, (String) key, Field.Store.YES)); + } else { + throw new IllegalArgumentException("Unsupported key type: " + key.getClass().getName()); + } + + // Add vector field + doc.add(new KnnVectorField(fieldName, vector)); + + // Add document to index + writer.addDocument(doc); + + // Commit write operation + writer.commit(); + } catch (IOException e) { + throw new RuntimeException("Failed to add vector index", e); } - - /** - * Searches for the closest vector in the index. - * - *

Note: This method is not thread-safe. - * The underlying index access is not synchronized with write operations. - * External synchronization is required if this method may be called concurrently - * with {@link #addVectorIndex}.

- * - * @param isVertex whether this is a vertex (unused in current implementation) - * @param fieldName the name of the vector field to search - * @param vector the query vector - * @param topK the number of top results to return - * @return the key associated with the closest vector - * @throws RuntimeException if an I/O error occurs during search - */ - @Override - public K searchVectorIndex(boolean isVertex, String fieldName, float[] vector, int topK) { - try { - // Open index reader - IndexReader reader = DirectoryReader.open(directory); - IndexSearcher searcher = new IndexSearcher(reader); - - // Create KNN vector query - KnnVectorQuery knnQuery = new KnnVectorQuery(fieldName, vector, topK); - - // Execute search - TopDocs topDocs = searcher.search(knnQuery, topK); - - Document firstDoc = searcher.doc(topDocs.scoreDocs[0].doc); - - K result; - if (keyClass == String.class) { - String value = firstDoc.get(KEY_FIELD_NAME); - result = (K) value; - } else if (keyClass == Long.class) { - Number value = firstDoc.getField(KEY_FIELD_NAME).numericValue(); - result = (K) Long.valueOf(value.longValue()); - } else if (keyClass == Integer.class) { - Number value = firstDoc.getField(KEY_FIELD_NAME).numericValue(); - result = (K) Integer.valueOf(value.intValue()); - } else if (keyClass == Float.class) { - Number value = firstDoc.getField(KEY_FIELD_NAME).numericValue(); - result = (K) Float.valueOf(value.floatValue()); - } else { - throw new IllegalArgumentException("Unsupported key type: " + keyClass.getName()); - } - - reader.close(); - - return result; - } catch (IOException e) { - throw new RuntimeException("Failed to search vector index", e); - } + } + + /** + * Searches for the closest vector in the index. + * + *

Note: This method is not thread-safe. The underlying index + * access is not synchronized with write operations. External synchronization is required if this + * method may be called concurrently with {@link #addVectorIndex}. + * + * @param isVertex whether this is a vertex (unused in current implementation) + * @param fieldName the name of the vector field to search + * @param vector the query vector + * @param topK the number of top results to return + * @return the key associated with the closest vector + * @throws RuntimeException if an I/O error occurs during search + */ + @Override + public K searchVectorIndex(boolean isVertex, String fieldName, float[] vector, int topK) { + try { + // Open index reader + IndexReader reader = DirectoryReader.open(directory); + IndexSearcher searcher = new IndexSearcher(reader); + + // Create KNN vector query + KnnVectorQuery knnQuery = new KnnVectorQuery(fieldName, vector, topK); + + // Execute search + TopDocs topDocs = searcher.search(knnQuery, topK); + + Document firstDoc = searcher.doc(topDocs.scoreDocs[0].doc); + + K result; + if (keyClass == String.class) { + String value = firstDoc.get(KEY_FIELD_NAME); + result = (K) value; + } else if (keyClass == Long.class) { + Number value = firstDoc.getField(KEY_FIELD_NAME).numericValue(); + result = (K) Long.valueOf(value.longValue()); + } else if (keyClass == Integer.class) { + Number value = firstDoc.getField(KEY_FIELD_NAME).numericValue(); + result = (K) Integer.valueOf(value.intValue()); + } else if (keyClass == Float.class) { + Number value = firstDoc.getField(KEY_FIELD_NAME).numericValue(); + result = (K) Float.valueOf(value.floatValue()); + } else { + throw new IllegalArgumentException("Unsupported key type: " + keyClass.getName()); + } + + reader.close(); + + return result; + } catch (IOException e) { + throw new RuntimeException("Failed to search vector index", e); } - - /** - * Close index writer and directory resources. - */ - public void close() { - try { - if (writer != null) { - writer.close(); - } - if (directory != null) { - directory.close(); - } - } catch (IOException e) { - throw new RuntimeException("Failed to close resources", e); - } + } + + /** Close index writer and directory resources. */ + public void close() { + try { + if (writer != null) { + writer.close(); + } + if (directory != null) { + directory.close(); + } + } catch (IOException e) { + throw new RuntimeException("Failed to close resources", e); } + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-vector/src/main/java/org/apache/geaflow/store/lucene/IVectorIndex.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-vector/src/main/java/org/apache/geaflow/store/lucene/IVectorIndex.java index 333447de4..492a0c6ea 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-vector/src/main/java/org/apache/geaflow/store/lucene/IVectorIndex.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-vector/src/main/java/org/apache/geaflow/store/lucene/IVectorIndex.java @@ -22,9 +22,9 @@ // Vector index interface, defining basic operations of vector index public interface IVectorIndex { - // Add vector index - void addVectorIndex(boolean isVertex, K key, String fieldName, float[] vector); + // Add vector index + void addVectorIndex(boolean isVertex, K key, String fieldName, float[] vector); - // Search vector index - K searchVectorIndex(boolean isVertex, String fieldName, float[] vector, int topK); -} \ No newline at end of file + // Search vector index + K searchVectorIndex(boolean isVertex, String fieldName, float[] vector, int topK); +} diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-vector/src/test/java/org/apache/geaflow/store/lucene/GraphVectorIndexTest.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-vector/src/test/java/org/apache/geaflow/store/lucene/GraphVectorIndexTest.java index 3231ebe7c..7ea0fb0f7 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-vector/src/test/java/org/apache/geaflow/store/lucene/GraphVectorIndexTest.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-vector/src/test/java/org/apache/geaflow/store/lucene/GraphVectorIndexTest.java @@ -27,132 +27,132 @@ public class GraphVectorIndexTest { - private GraphVectorIndex stringIndex; - private GraphVectorIndex longIndex; - private GraphVectorIndex intIndex; - private GraphVectorIndex floatIndex; - - @BeforeMethod - public void setUp() { - stringIndex = new GraphVectorIndex<>(String.class); - longIndex = new GraphVectorIndex<>(Long.class); - intIndex = new GraphVectorIndex<>(Integer.class); - floatIndex = new GraphVectorIndex<>(Float.class); + private GraphVectorIndex stringIndex; + private GraphVectorIndex longIndex; + private GraphVectorIndex intIndex; + private GraphVectorIndex floatIndex; + + @BeforeMethod + public void setUp() { + stringIndex = new GraphVectorIndex<>(String.class); + longIndex = new GraphVectorIndex<>(Long.class); + intIndex = new GraphVectorIndex<>(Integer.class); + floatIndex = new GraphVectorIndex<>(Float.class); + } + + @AfterMethod + public void tearDown() { + stringIndex.close(); + longIndex.close(); + intIndex.close(); + floatIndex.close(); + } + + @Test + public void testAddAndSearchStringKey() { + String key = "vertex_1"; + String fieldName = "embedding"; + float[] vector = {0.1f, 0.2f, 0.3f, 0.4f}; + + // Add vector index + stringIndex.addVectorIndex(true, key, fieldName, vector); + + // Search vector index + String result = stringIndex.searchVectorIndex(true, fieldName, vector, 1); + + assertEquals(result, key); + } + + @Test + public void testAddAndSearchLongKey() { + Long key = 12345L; + String fieldName = "embedding"; + float[] vector = {0.5f, 0.6f, 0.7f, 0.8f}; + + // Add vector index + longIndex.addVectorIndex(false, key, fieldName, vector); + + // Search vector index + Long result = longIndex.searchVectorIndex(false, fieldName, vector, 1); + + assertEquals(result, key); + } + + @Test + public void testAddAndSearchIntegerKey() { + Integer key = 999; + String fieldName = "features"; + float[] vector = {0.9f, 0.8f, 0.7f, 0.6f}; + + // Add vector index + intIndex.addVectorIndex(true, key, fieldName, vector); + + // Search vector index + Integer result = intIndex.searchVectorIndex(true, fieldName, vector, 1); + + assertEquals(result, key); + } + + @Test + public void testAddAndSearchFloatKey() { + Float key = 3.14f; + String fieldName = "weights"; + float[] vector = {0.2f, 0.4f, 0.6f, 0.8f}; + + // Add vector index + floatIndex.addVectorIndex(false, key, fieldName, vector); + + // Search vector index + Float result = floatIndex.searchVectorIndex(false, fieldName, vector, 1); + + assertEquals(result, key); + } + + @Test + public void testVectorSimilaritySearch() { + String fieldName = "similarity_test"; + + // Add several vectors with different similarities + stringIndex.addVectorIndex(true, "doc1", fieldName, new float[] {1.0f, 0.0f, 0.0f, 0.0f}); + stringIndex.addVectorIndex(true, "doc2", fieldName, new float[] {0.8f, 0.2f, 0.0f, 0.0f}); + stringIndex.addVectorIndex(true, "doc3", fieldName, new float[] {0.0f, 0.0f, 1.0f, 0.0f}); + + // Query using a vector identical to doc1 + float[] queryVector = {1.0f, 0.0f, 0.0f, 0.0f}; + String result = stringIndex.searchVectorIndex(true, fieldName, queryVector, 1); + + // Should return the most similar document + assertEquals(result, "doc1"); + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testUnsupportedKeyType() { + // Use unsupported key type + Double key = 1.23; + GraphVectorIndex doubleIndex = new GraphVectorIndex<>(Double.class); + + try { + doubleIndex.addVectorIndex(true, key, "test_field", new float[] {0.1f, 0.2f}); + } finally { + doubleIndex.close(); } + } - @AfterMethod - public void tearDown() { - stringIndex.close(); - longIndex.close(); - intIndex.close(); - floatIndex.close(); - } - - @Test - public void testAddAndSearchStringKey() { - String key = "vertex_1"; - String fieldName = "embedding"; - float[] vector = {0.1f, 0.2f, 0.3f, 0.4f}; - - // Add vector index - stringIndex.addVectorIndex(true, key, fieldName, vector); - - // Search vector index - String result = stringIndex.searchVectorIndex(true, fieldName, vector, 1); - - assertEquals(result, key); - } - - @Test - public void testAddAndSearchLongKey() { - Long key = 12345L; - String fieldName = "embedding"; - float[] vector = {0.5f, 0.6f, 0.7f, 0.8f}; - - // Add vector index - longIndex.addVectorIndex(false, key, fieldName, vector); - - // Search vector index - Long result = longIndex.searchVectorIndex(false, fieldName, vector, 1); - - assertEquals(result, key); - } - - @Test - public void testAddAndSearchIntegerKey() { - Integer key = 999; - String fieldName = "features"; - float[] vector = {0.9f, 0.8f, 0.7f, 0.6f}; - - // Add vector index - intIndex.addVectorIndex(true, key, fieldName, vector); - - // Search vector index - Integer result = intIndex.searchVectorIndex(true, fieldName, vector, 1); - - assertEquals(result, key); - } - - @Test - public void testAddAndSearchFloatKey() { - Float key = 3.14f; - String fieldName = "weights"; - float[] vector = {0.2f, 0.4f, 0.6f, 0.8f}; - - // Add vector index - floatIndex.addVectorIndex(false, key, fieldName, vector); - - // Search vector index - Float result = floatIndex.searchVectorIndex(false, fieldName, vector, 1); + @Test + public void testMultipleVectorsSameKeyDifferentFields() { + String key = "multi_field_vertex"; + float[] vector1 = {0.1f, 0.2f, 0.3f, 0.4f}; + float[] vector2 = {0.5f, 0.6f, 0.7f, 0.8f}; - assertEquals(result, key); - } - - @Test - public void testVectorSimilaritySearch() { - String fieldName = "similarity_test"; + // Add vectors for the same key in different fields + stringIndex.addVectorIndex(true, key, "field1", vector1); + stringIndex.addVectorIndex(true, key, "field2", vector2); - // Add several vectors with different similarities - stringIndex.addVectorIndex(true, "doc1", fieldName, new float[]{1.0f, 0.0f, 0.0f, 0.0f}); - stringIndex.addVectorIndex(true, "doc2", fieldName, new float[]{0.8f, 0.2f, 0.0f, 0.0f}); - stringIndex.addVectorIndex(true, "doc3", fieldName, new float[]{0.0f, 0.0f, 1.0f, 0.0f}); + // Search different fields separately + String result1 = stringIndex.searchVectorIndex(true, "field1", vector1, 1); + String result2 = stringIndex.searchVectorIndex(true, "field2", vector2, 1); - // Query using a vector identical to doc1 - float[] queryVector = {1.0f, 0.0f, 0.0f, 0.0f}; - String result = stringIndex.searchVectorIndex(true, fieldName, queryVector, 1); - - // Should return the most similar document - assertEquals(result, "doc1"); - } - - @Test(expectedExceptions = IllegalArgumentException.class) - public void testUnsupportedKeyType() { - // Use unsupported key type - Double key = 1.23; - GraphVectorIndex doubleIndex = new GraphVectorIndex<>(Double.class); - - try { - doubleIndex.addVectorIndex(true, key, "test_field", new float[]{0.1f, 0.2f}); - } finally { - doubleIndex.close(); - } - } - - @Test - public void testMultipleVectorsSameKeyDifferentFields() { - String key = "multi_field_vertex"; - float[] vector1 = {0.1f, 0.2f, 0.3f, 0.4f}; - float[] vector2 = {0.5f, 0.6f, 0.7f, 0.8f}; - - // Add vectors for the same key in different fields - stringIndex.addVectorIndex(true, key, "field1", vector1); - stringIndex.addVectorIndex(true, key, "field2", vector2); - - // Search different fields separately - String result1 = stringIndex.searchVectorIndex(true, "field1", vector1, 1); - String result2 = stringIndex.searchVectorIndex(true, "field2", vector2, 1); - - assertEquals(result1, key); - assertEquals(result2, key); - } + assertEquals(result1, key); + assertEquals(result2, key); + } } diff --git a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-vector/src/test/java/org/apache/geaflow/store/lucene/LuceneTest.java b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-vector/src/test/java/org/apache/geaflow/store/lucene/LuceneTest.java index 67c5304da..698fbf865 100644 --- a/geaflow/geaflow-plugins/geaflow-store/geaflow-store-vector/src/test/java/org/apache/geaflow/store/lucene/LuceneTest.java +++ b/geaflow/geaflow-plugins/geaflow-store/geaflow-store-vector/src/test/java/org/apache/geaflow/store/lucene/LuceneTest.java @@ -24,6 +24,7 @@ import static org.testng.Assert.assertTrue; import java.io.IOException; + import org.apache.lucene.analysis.standard.StandardAnalyzer; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; @@ -43,202 +44,202 @@ public class LuceneTest { - @Test - public void test() throws IOException { - // Use ByteBuffersDirectory instead of RAMDirectory - Directory directory = new ByteBuffersDirectory(); - StandardAnalyzer analyzer = new StandardAnalyzer(); - - // Create index writer configuration - IndexWriterConfig config = new IndexWriterConfig(analyzer); - IndexWriter writer = new IndexWriter(directory, config); - - // Add documents (including vector fields) - addDocuments(writer); - - // Commit and close writer - writer.commit(); - writer.close(); - - // Search example - searchDocuments(directory, analyzer); - - // Close directory - directory.close(); - } - - - private static void addDocuments(IndexWriter writer) throws IOException { - // Document 1 - Document doc1 = new Document(); - doc1.add(new TextField("title", "机器学习基础", Field.Store.YES)); - doc1.add(new TextField("content", "机器学习是人工智能的重要分支", Field.Store.YES)); - // Add vector field - float[] vector1 = {0.1f, 0.2f, 0.3f, 0.4f}; - doc1.add(new KnnVectorField("vector_field", vector1)); - writer.addDocument(doc1); - - // Document 2 - Document doc2 = new Document(); - doc2.add(new TextField("title", "深度学习入门", Field.Store.YES)); - doc2.add(new TextField("content", "深度学习使用神经网络进行学习", Field.Store.YES)); - float[] vector2 = {0.2f, 0.3f, 0.4f, 0.5f}; - doc2.add(new KnnVectorField("vector_field", vector2)); - writer.addDocument(doc2); - - // Document 3 - Document doc3 = new Document(); - doc3.add(new TextField("title", "自然语言处理", Field.Store.YES)); - doc3.add(new TextField("content", "NLP是处理文本的技术", Field.Store.YES)); - float[] vector3 = {0.15f, 0.25f, 0.35f, 0.45f}; - doc3.add(new KnnVectorField("vector_field", vector3)); - writer.addDocument(doc3); + @Test + public void test() throws IOException { + // Use ByteBuffersDirectory instead of RAMDirectory + Directory directory = new ByteBuffersDirectory(); + StandardAnalyzer analyzer = new StandardAnalyzer(); + + // Create index writer configuration + IndexWriterConfig config = new IndexWriterConfig(analyzer); + IndexWriter writer = new IndexWriter(directory, config); + + // Add documents (including vector fields) + addDocuments(writer); + + // Commit and close writer + writer.commit(); + writer.close(); + + // Search example + searchDocuments(directory, analyzer); + + // Close directory + directory.close(); + } + + private static void addDocuments(IndexWriter writer) throws IOException { + // Document 1 + Document doc1 = new Document(); + doc1.add(new TextField("title", "机器学习基础", Field.Store.YES)); + doc1.add(new TextField("content", "机器学习是人工智能的重要分支", Field.Store.YES)); + // Add vector field + float[] vector1 = {0.1f, 0.2f, 0.3f, 0.4f}; + doc1.add(new KnnVectorField("vector_field", vector1)); + writer.addDocument(doc1); + + // Document 2 + Document doc2 = new Document(); + doc2.add(new TextField("title", "深度学习入门", Field.Store.YES)); + doc2.add(new TextField("content", "深度学习使用神经网络进行学习", Field.Store.YES)); + float[] vector2 = {0.2f, 0.3f, 0.4f, 0.5f}; + doc2.add(new KnnVectorField("vector_field", vector2)); + writer.addDocument(doc2); + + // Document 3 + Document doc3 = new Document(); + doc3.add(new TextField("title", "自然语言处理", Field.Store.YES)); + doc3.add(new TextField("content", "NLP是处理文本的技术", Field.Store.YES)); + float[] vector3 = {0.15f, 0.25f, 0.35f, 0.45f}; + doc3.add(new KnnVectorField("vector_field", vector3)); + writer.addDocument(doc3); + } + + private static void searchDocuments(Directory directory, StandardAnalyzer analyzer) + throws IOException { + IndexReader reader = DirectoryReader.open(directory); + IndexSearcher searcher = new IndexSearcher(reader); + + // Create query vector + float[] queryVector = {0.12f, 0.22f, 0.32f, 0.42f}; + + // Execute KNN search + KnnVectorQuery knnQuery = new KnnVectorQuery("vector_field", queryVector, 3); + + // Execute search + TopDocs topDocs = searcher.search(knnQuery, 10); + + // Verify search results + assertEquals(topDocs.scoreDocs.length, 3, "Should return 3 results"); + + // Verify results are sorted by relevance (scores from high to low) + for (int i = 0; i < topDocs.scoreDocs.length - 1; i++) { + assertTrue( + topDocs.scoreDocs[i].score >= topDocs.scoreDocs[i + 1].score, + "Results should be sorted by score from high to low"); } - private static void searchDocuments(Directory directory, StandardAnalyzer analyzer) throws IOException { - IndexReader reader = DirectoryReader.open(directory); - IndexSearcher searcher = new IndexSearcher(reader); - - // Create query vector - float[] queryVector = {0.12f, 0.22f, 0.32f, 0.42f}; - - // Execute KNN search - KnnVectorQuery knnQuery = new KnnVectorQuery("vector_field", queryVector, 3); - - // Execute search - TopDocs topDocs = searcher.search(knnQuery, 10); - - // Verify search results - assertEquals(topDocs.scoreDocs.length, 3, "Should return 3 results"); - - // Verify results are sorted by relevance (scores from high to low) - for (int i = 0; i < topDocs.scoreDocs.length - 1; i++) { - assertTrue(topDocs.scoreDocs[i].score >= topDocs.scoreDocs[i + 1].score, - "Results should be sorted by score from high to low"); - } - - // Verify each result has expected fields - for (ScoreDoc scoreDoc : topDocs.scoreDocs) { - Document doc = searcher.doc(scoreDoc.doc); - assertNotNull(doc.get("title"), "Title field should not be null"); - assertNotNull(doc.get("content"), "Content field should not be null"); - } - - reader.close(); - } - - // Add more test cases to test Lucene vector functionality - - @Test - public void testVectorSearchAccuracy() throws IOException { - Directory directory = new ByteBuffersDirectory(); - StandardAnalyzer analyzer = new StandardAnalyzer(); - IndexWriterConfig config = new IndexWriterConfig(analyzer); - IndexWriter writer = new IndexWriter(directory, config); - - // Create test data to ensure vector similarity has obvious differences - Document doc1 = new Document(); - doc1.add(new TextField("id", "1", Field.Store.YES)); - // Exactly the same vector, should get the highest score - float[] vector1 = {1.0f, 0.0f, 0.0f, 0.0f}; - doc1.add(new KnnVectorField("vector_field", vector1)); - writer.addDocument(doc1); - - Document doc2 = new Document(); - doc2.add(new TextField("id", "2", Field.Store.YES)); - // Partially similar vector - float[] vector2 = {0.8f, 0.2f, 0.0f, 0.0f}; - doc2.add(new KnnVectorField("vector_field", vector2)); - writer.addDocument(doc2); - - Document doc3 = new Document(); - doc3.add(new TextField("id", "3", Field.Store.YES)); - // Completely different vector, should get the lowest score - float[] vector3 = {0.0f, 0.0f, 1.0f, 0.0f}; - doc3.add(new KnnVectorField("vector_field", vector3)); - writer.addDocument(doc3); - - writer.commit(); - writer.close(); - - // Execute search - IndexReader reader = DirectoryReader.open(directory); - IndexSearcher searcher = new IndexSearcher(reader); - - // Query vector is exactly the same as the first document - float[] queryVector = {1.0f, 0.0f, 0.0f, 0.0f}; - KnnVectorQuery knnQuery = new KnnVectorQuery("vector_field", queryVector, 3); - TopDocs topDocs = searcher.search(knnQuery, 10); - - // Verify result count - assertEquals(topDocs.scoreDocs.length, 3); - - // Verify the first result should be the document with id 1 (highest score) - Document firstDoc = searcher.doc(topDocs.scoreDocs[0].doc); - assertEquals(firstDoc.get("id"), "1"); - - // Verify score order (first score should be highest) - assertTrue(topDocs.scoreDocs[0].score >= topDocs.scoreDocs[1].score); - assertTrue(topDocs.scoreDocs[1].score >= topDocs.scoreDocs[2].score); - - reader.close(); - directory.close(); - } - - @Test - public void testVectorFieldProperties() throws IOException { - // Test basic properties of vector fields - float[] vector = {0.1f, 0.2f, 0.3f, 0.4f}; - KnnVectorField vectorField = new KnnVectorField("test_vector", vector); - - // Verify field name - assertEquals(vectorField.name(), "test_vector"); - - // Verify field value - assertEquals(vectorField.vectorValue(), vector); - } - - @Test - public void testMultipleVectorFields() throws IOException { - Directory directory = new ByteBuffersDirectory(); - StandardAnalyzer analyzer = new StandardAnalyzer(); - IndexWriterConfig config = new IndexWriterConfig(analyzer); - IndexWriter writer = new IndexWriter(directory, config); - - // Create document with multiple vector fields - Document doc = new Document(); - doc.add(new TextField("title", "多向量测试", Field.Store.YES)); - - // Add first vector field - float[] vector1 = {0.1f, 0.2f, 0.3f, 0.4f}; - doc.add(new KnnVectorField("vector_field_1", vector1)); - - // Add second vector field - float[] vector2 = {0.5f, 0.6f, 0.7f, 0.8f}; - doc.add(new KnnVectorField("vector_field_2", vector2)); - - writer.addDocument(doc); - writer.commit(); - writer.close(); - - // Verify document in index - IndexReader reader = DirectoryReader.open(directory); - IndexSearcher searcher = new IndexSearcher(reader); - - assertEquals(reader.numDocs(), 1); - - // Test search on first vector field - KnnVectorQuery query1 = new KnnVectorQuery("vector_field_1", vector1, 1); - TopDocs results1 = searcher.search(query1, 10); - assertEquals(results1.scoreDocs.length, 1); - - // Test search on second vector field - KnnVectorQuery query2 = new KnnVectorQuery("vector_field_2", vector2, 1); - TopDocs results2 = searcher.search(query2, 10); - assertEquals(results2.scoreDocs.length, 1); - - reader.close(); - directory.close(); + // Verify each result has expected fields + for (ScoreDoc scoreDoc : topDocs.scoreDocs) { + Document doc = searcher.doc(scoreDoc.doc); + assertNotNull(doc.get("title"), "Title field should not be null"); + assertNotNull(doc.get("content"), "Content field should not be null"); } + reader.close(); + } + + // Add more test cases to test Lucene vector functionality + + @Test + public void testVectorSearchAccuracy() throws IOException { + Directory directory = new ByteBuffersDirectory(); + StandardAnalyzer analyzer = new StandardAnalyzer(); + IndexWriterConfig config = new IndexWriterConfig(analyzer); + IndexWriter writer = new IndexWriter(directory, config); + + // Create test data to ensure vector similarity has obvious differences + Document doc1 = new Document(); + doc1.add(new TextField("id", "1", Field.Store.YES)); + // Exactly the same vector, should get the highest score + float[] vector1 = {1.0f, 0.0f, 0.0f, 0.0f}; + doc1.add(new KnnVectorField("vector_field", vector1)); + writer.addDocument(doc1); + + Document doc2 = new Document(); + doc2.add(new TextField("id", "2", Field.Store.YES)); + // Partially similar vector + float[] vector2 = {0.8f, 0.2f, 0.0f, 0.0f}; + doc2.add(new KnnVectorField("vector_field", vector2)); + writer.addDocument(doc2); + + Document doc3 = new Document(); + doc3.add(new TextField("id", "3", Field.Store.YES)); + // Completely different vector, should get the lowest score + float[] vector3 = {0.0f, 0.0f, 1.0f, 0.0f}; + doc3.add(new KnnVectorField("vector_field", vector3)); + writer.addDocument(doc3); + + writer.commit(); + writer.close(); + + // Execute search + IndexReader reader = DirectoryReader.open(directory); + IndexSearcher searcher = new IndexSearcher(reader); + + // Query vector is exactly the same as the first document + float[] queryVector = {1.0f, 0.0f, 0.0f, 0.0f}; + KnnVectorQuery knnQuery = new KnnVectorQuery("vector_field", queryVector, 3); + TopDocs topDocs = searcher.search(knnQuery, 10); + + // Verify result count + assertEquals(topDocs.scoreDocs.length, 3); + + // Verify the first result should be the document with id 1 (highest score) + Document firstDoc = searcher.doc(topDocs.scoreDocs[0].doc); + assertEquals(firstDoc.get("id"), "1"); + + // Verify score order (first score should be highest) + assertTrue(topDocs.scoreDocs[0].score >= topDocs.scoreDocs[1].score); + assertTrue(topDocs.scoreDocs[1].score >= topDocs.scoreDocs[2].score); + + reader.close(); + directory.close(); + } + + @Test + public void testVectorFieldProperties() throws IOException { + // Test basic properties of vector fields + float[] vector = {0.1f, 0.2f, 0.3f, 0.4f}; + KnnVectorField vectorField = new KnnVectorField("test_vector", vector); + + // Verify field name + assertEquals(vectorField.name(), "test_vector"); + + // Verify field value + assertEquals(vectorField.vectorValue(), vector); + } + + @Test + public void testMultipleVectorFields() throws IOException { + Directory directory = new ByteBuffersDirectory(); + StandardAnalyzer analyzer = new StandardAnalyzer(); + IndexWriterConfig config = new IndexWriterConfig(analyzer); + IndexWriter writer = new IndexWriter(directory, config); + + // Create document with multiple vector fields + Document doc = new Document(); + doc.add(new TextField("title", "多向量测试", Field.Store.YES)); + + // Add first vector field + float[] vector1 = {0.1f, 0.2f, 0.3f, 0.4f}; + doc.add(new KnnVectorField("vector_field_1", vector1)); + + // Add second vector field + float[] vector2 = {0.5f, 0.6f, 0.7f, 0.8f}; + doc.add(new KnnVectorField("vector_field_2", vector2)); + + writer.addDocument(doc); + writer.commit(); + writer.close(); + + // Verify document in index + IndexReader reader = DirectoryReader.open(directory); + IndexSearcher searcher = new IndexSearcher(reader); + + assertEquals(reader.numDocs(), 1); + + // Test search on first vector field + KnnVectorQuery query1 = new KnnVectorQuery("vector_field_1", vector1, 1); + TopDocs results1 = searcher.search(query1, 10); + assertEquals(results1.scoreDocs.length, 1); + + // Test search on second vector field + KnnVectorQuery query2 = new KnnVectorQuery("vector_field_2", vector2, 1); + TopDocs results2 = searcher.search(query2, 10); + assertEquals(results2.scoreDocs.length, 1); + + reader.close(); + directory.close(); + } } diff --git a/geaflow/geaflow-plugins/geaflow-view-meta/src/main/java/org/apache/geaflow/view/meta/ViewMeta.java b/geaflow/geaflow-plugins/geaflow-view-meta/src/main/java/org/apache/geaflow/view/meta/ViewMeta.java index f825d7cf9..4ce56c287 100644 --- a/geaflow/geaflow-plugins/geaflow-view-meta/src/main/java/org/apache/geaflow/view/meta/ViewMeta.java +++ b/geaflow/geaflow-plugins/geaflow-view-meta/src/main/java/org/apache/geaflow/view/meta/ViewMeta.java @@ -19,96 +19,94 @@ package org.apache.geaflow.view.meta; -import com.google.common.base.Preconditions; -import com.google.protobuf.ByteString; import java.io.IOException; import java.io.InputStream; import java.nio.file.Files; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.file.FileInfo; import org.apache.geaflow.file.IPersistentIO; import org.apache.hadoop.fs.Path; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; +import com.google.protobuf.ByteString; + public class ViewMeta { - private static final Logger LOGGER = LoggerFactory.getLogger(ViewMeta.class); + private static final Logger LOGGER = LoggerFactory.getLogger(ViewMeta.class); - private static final String BACKUP_SUFFIX = ".bak"; - private static final String TMP_SUFFIX = ".tmp"; - private final Map innerKV = new HashMap<>(); - private final String filePath; - private final IPersistentIO persistIO; - private long fileModifyTS; + private static final String BACKUP_SUFFIX = ".bak"; + private static final String TMP_SUFFIX = ".tmp"; + private final Map innerKV = new HashMap<>(); + private final String filePath; + private final IPersistentIO persistIO; + private long fileModifyTS; - public ViewMeta(String filePath, IPersistentIO persistentIO) throws IOException { - this.filePath = filePath; - this.persistIO = persistentIO; - if (persistentIO.exists(new Path(filePath))) { - readFile(filePath); - } else if (persistIO.exists(new Path(filePath + BACKUP_SUFFIX))) { - readFile(filePath + BACKUP_SUFFIX); - } + public ViewMeta(String filePath, IPersistentIO persistentIO) throws IOException { + this.filePath = filePath; + this.persistIO = persistentIO; + if (persistentIO.exists(new Path(filePath))) { + readFile(filePath); + } else if (persistIO.exists(new Path(filePath + BACKUP_SUFFIX))) { + readFile(filePath + BACKUP_SUFFIX); } + } - private void readFile(String strFilePath) throws IOException { - FileInfo info = persistIO.getFileInfo(new Path(strFilePath)); - InputStream stream = persistIO.open(new Path(strFilePath)); - byte[] content = new byte[(int) info.getLength()]; - stream.read(content); - stream.close(); - ViewMetaPb.ViewMeta viewMeta = ViewMetaPb.ViewMeta.parseFrom(content); - this.innerKV.putAll(viewMeta.getKvInfoMap()); - this.fileModifyTS = info.getModificationTime(); - } + private void readFile(String strFilePath) throws IOException { + FileInfo info = persistIO.getFileInfo(new Path(strFilePath)); + InputStream stream = persistIO.open(new Path(strFilePath)); + byte[] content = new byte[(int) info.getLength()]; + stream.read(content); + stream.close(); + ViewMetaPb.ViewMeta viewMeta = ViewMetaPb.ViewMeta.parseFrom(content); + this.innerKV.putAll(viewMeta.getKvInfoMap()); + this.fileModifyTS = info.getModificationTime(); + } - public void tryRefresh() throws IOException { - if (!persistIO.exists(new Path(filePath))) { - return; - } - long currentModifyTime = this.persistIO.getFileInfo(new Path(filePath)).getModificationTime(); - if (currentModifyTime > fileModifyTS) { - LOGGER.info("refresh {}", filePath); - readFile(filePath); - } else { - LOGGER.info("last {} now {}", fileModifyTS, currentModifyTime); - } + public void tryRefresh() throws IOException { + if (!persistIO.exists(new Path(filePath))) { + return; } - - public Map getKVMap() { - return innerKV; + long currentModifyTime = this.persistIO.getFileInfo(new Path(filePath)).getModificationTime(); + if (currentModifyTime > fileModifyTS) { + LOGGER.info("refresh {}", filePath); + readFile(filePath); + } else { + LOGGER.info("last {} now {}", fileModifyTS, currentModifyTime); } + } - public byte[] toBinary() { - return ViewMetaPb.ViewMeta - .newBuilder() - .putAllKvInfo(innerKV) - .build() - .toByteArray(); - } + public Map getKVMap() { + return innerKV; + } - public void archive() throws IOException { - final long start = System.currentTimeMillis(); - // fo, keep history - if (persistIO.exists(new Path(filePath))) { - persistIO.delete(new Path(filePath + BACKUP_SUFFIX), false); - boolean res = persistIO.renameFile(new Path(filePath), new Path(filePath + BACKUP_SUFFIX)); - Preconditions.checkArgument(res, "renameFile fail " + filePath); - } - // fo, protect filePath. - final java.nio.file.Path path = Files.createTempFile("tmp", TMP_SUFFIX); - Files.write(path, toBinary()); - persistIO.copyFromLocalFile(new Path(path.toString()), new Path(filePath + TMP_SUFFIX)); - Files.deleteIfExists(path); + public byte[] toBinary() { + return ViewMetaPb.ViewMeta.newBuilder().putAllKvInfo(innerKV).build().toByteArray(); + } + + public void archive() throws IOException { + final long start = System.currentTimeMillis(); + // fo, keep history + if (persistIO.exists(new Path(filePath))) { + persistIO.delete(new Path(filePath + BACKUP_SUFFIX), false); + boolean res = persistIO.renameFile(new Path(filePath), new Path(filePath + BACKUP_SUFFIX)); + Preconditions.checkArgument(res, "renameFile fail " + filePath); + } + // fo, protect filePath. + final java.nio.file.Path path = Files.createTempFile("tmp", TMP_SUFFIX); + Files.write(path, toBinary()); + persistIO.copyFromLocalFile(new Path(path.toString()), new Path(filePath + TMP_SUFFIX)); + Files.deleteIfExists(path); - // clean. - boolean res = persistIO.renameFile(new Path(filePath + TMP_SUFFIX), new Path(filePath)); - Preconditions.checkArgument(res, "renameFile fail " + filePath + TMP_SUFFIX); - if (persistIO.exists(new Path(filePath + BACKUP_SUFFIX))) { - persistIO.delete(new Path(filePath + BACKUP_SUFFIX), false); - } - LOGGER.info("save {} cost {}ms", filePath, System.currentTimeMillis() - start); + // clean. + boolean res = persistIO.renameFile(new Path(filePath + TMP_SUFFIX), new Path(filePath)); + Preconditions.checkArgument(res, "renameFile fail " + filePath + TMP_SUFFIX); + if (persistIO.exists(new Path(filePath + BACKUP_SUFFIX))) { + persistIO.delete(new Path(filePath + BACKUP_SUFFIX), false); } + LOGGER.info("save {} cost {}ms", filePath, System.currentTimeMillis() - start); + } } diff --git a/geaflow/geaflow-plugins/geaflow-view-meta/src/main/java/org/apache/geaflow/view/meta/ViewMetaBookKeeper.java b/geaflow/geaflow-plugins/geaflow-view-meta/src/main/java/org/apache/geaflow/view/meta/ViewMetaBookKeeper.java index 3792b06c3..50900b531 100644 --- a/geaflow/geaflow-plugins/geaflow-view-meta/src/main/java/org/apache/geaflow/view/meta/ViewMetaBookKeeper.java +++ b/geaflow/geaflow-plugins/geaflow-view-meta/src/main/java/org/apache/geaflow/view/meta/ViewMetaBookKeeper.java @@ -19,12 +19,10 @@ package org.apache.geaflow.view.meta; -import com.google.common.base.Preconditions; -import com.google.common.primitives.Longs; -import com.google.protobuf.ByteString; import java.io.IOException; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.tuple.Tuple; import org.apache.geaflow.common.utils.FileUtil; @@ -34,113 +32,117 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class ViewMetaBookKeeper { - - private static final Logger LOGGER = LoggerFactory.getLogger(ViewMetaBookKeeper.class); - private static final String VIEW_VERSION = "view_version"; - private final String viewName; - private final ViewMetaKeeper viewMetaKeeper; +import com.google.common.base.Preconditions; +import com.google.common.primitives.Longs; +import com.google.protobuf.ByteString; - public ViewMetaBookKeeper(String myViewName, Configuration config) { - this.viewName = myViewName; - viewMetaKeeper = new ViewMetaKeeper(); - viewMetaKeeper.init(myViewName, config); - } +public class ViewMetaBookKeeper { - public long getLatestViewVersion(String name) throws IOException { - byte[] res = viewMetaKeeper.get(name, VIEW_VERSION); - return res == null ? -1L : Longs.fromByteArray(res); + private static final Logger LOGGER = LoggerFactory.getLogger(ViewMetaBookKeeper.class); + private static final String VIEW_VERSION = "view_version"; + private final String viewName; + private final ViewMetaKeeper viewMetaKeeper; + + public ViewMetaBookKeeper(String myViewName, Configuration config) { + this.viewName = myViewName; + viewMetaKeeper = new ViewMetaKeeper(); + viewMetaKeeper.init(myViewName, config); + } + + public long getLatestViewVersion(String name) throws IOException { + byte[] res = viewMetaKeeper.get(name, VIEW_VERSION); + return res == null ? -1L : Longs.fromByteArray(res); + } + + public void saveViewVersion(long viewVersion) throws IOException { + viewMetaKeeper.save(VIEW_VERSION, Longs.toByteArray(viewVersion)); + LOGGER.info("save view version {} {}", viewName, viewVersion); + } + + public void archive() { + try { + viewMetaKeeper.archive(); + } catch (IOException e) { + throw new RuntimeException(e); } - - public void saveViewVersion(long viewVersion) throws IOException { - viewMetaKeeper.save(VIEW_VERSION, Longs.toByteArray(viewVersion)); - LOGGER.info("save view version {} {}", viewName, viewVersion); + } + + public String getViewName() { + return viewName; + } + + public static class ViewMetaKeeper { + + private static final String INFO_FILE_NAME = "view.meta"; + // file modify time unit is second. + private static final long SHARED_STATE_INFO_REFRESH_MS_THRESHOLD = 1000; + private IPersistentIO persistIO; + private Map> sharedViewMeta = new HashMap<>(); + private ViewMeta myStateInfo; + private String myViewName; + private String namespace; + + public void init(String viewName, Configuration config) { + this.myViewName = viewName; + this.persistIO = PersistentIOBuilder.build(config); + this.namespace = config.getString(FileConfigKeys.ROOT); } - public void archive() { - try { - viewMetaKeeper.archive(); - } catch (IOException e) { - throw new RuntimeException(e); + private ViewMeta getOrInit(String viewName) throws IOException { + boolean isMyself = myViewName.equals(viewName); + if (!isMyself) { + Tuple tuple = sharedViewMeta.get(viewName); + ViewMeta viewMeta; + if (tuple == null) { + String file = FileUtil.concatPath(this.namespace, viewName) + "/" + INFO_FILE_NAME; + viewMeta = new ViewMeta(file, persistIO); + this.sharedViewMeta.put(viewName, Tuple.of(viewMeta, System.currentTimeMillis())); + } else if (System.currentTimeMillis() - tuple.f1 > SHARED_STATE_INFO_REFRESH_MS_THRESHOLD) { + tryRefresh(tuple); + viewMeta = tuple.f0; + } else { + viewMeta = tuple.f0; + } + return viewMeta; + } else { + if (myStateInfo == null) { + String file = FileUtil.concatPath(this.namespace, viewName) + "/" + INFO_FILE_NAME; + myStateInfo = new ViewMeta(file, persistIO); } + return myStateInfo; + } } - public String getViewName() { - return viewName; + private void tryRefresh(Tuple tuple) throws IOException { + tuple.f0.tryRefresh(); + tuple.f1 = System.currentTimeMillis(); } - public static class ViewMetaKeeper { - - private static final String INFO_FILE_NAME = "view.meta"; - // file modify time unit is second. - private static final long SHARED_STATE_INFO_REFRESH_MS_THRESHOLD = 1000; - private IPersistentIO persistIO; - private Map> sharedViewMeta = new HashMap<>(); - private ViewMeta myStateInfo; - private String myViewName; - private String namespace; - - public void init(String viewName, Configuration config) { - this.myViewName = viewName; - this.persistIO = PersistentIOBuilder.build(config); - this.namespace = config.getString(FileConfigKeys.ROOT); - } - - private ViewMeta getOrInit(String viewName) throws IOException { - boolean isMyself = myViewName.equals(viewName); - if (!isMyself) { - Tuple tuple = sharedViewMeta.get(viewName); - ViewMeta viewMeta; - if (tuple == null) { - String file = FileUtil.concatPath(this.namespace, viewName) + "/" + INFO_FILE_NAME; - viewMeta = new ViewMeta(file, persistIO); - this.sharedViewMeta.put(viewName, Tuple.of(viewMeta, System.currentTimeMillis())); - } else if (System.currentTimeMillis() - tuple.f1 > SHARED_STATE_INFO_REFRESH_MS_THRESHOLD) { - tryRefresh(tuple); - viewMeta = tuple.f0; - } else { - viewMeta = tuple.f0; - } - return viewMeta; - } else { - if (myStateInfo == null) { - String file = FileUtil.concatPath(this.namespace, viewName) + "/" + INFO_FILE_NAME; - myStateInfo = new ViewMeta(file, persistIO); - } - return myStateInfo; - } - } - - private void tryRefresh(Tuple tuple) throws IOException { - tuple.f0.tryRefresh(); - tuple.f1 = System.currentTimeMillis(); - } - - public void save(String k, byte[] v) throws IOException { - Preconditions.checkNotNull(k); - Preconditions.checkNotNull(v); - ViewMeta stateInfo = getOrInit(myViewName); - stateInfo.getKVMap().put(k, ByteString.copyFrom(v)); - } + public void save(String k, byte[] v) throws IOException { + Preconditions.checkNotNull(k); + Preconditions.checkNotNull(v); + ViewMeta stateInfo = getOrInit(myViewName); + stateInfo.getKVMap().put(k, ByteString.copyFrom(v)); + } - public byte[] get(String k) throws IOException { - return get(myViewName, k); - } + public byte[] get(String k) throws IOException { + return get(myViewName, k); + } - public byte[] get(String viewName, String k) throws IOException { - ViewMeta viewMeta = getOrInit(viewName); - Preconditions.checkNotNull(viewMeta); - ByteString value = viewMeta.getKVMap().get(k); - if (value == null) { - return null; - } - return value.toByteArray(); - } + public byte[] get(String viewName, String k) throws IOException { + ViewMeta viewMeta = getOrInit(viewName); + Preconditions.checkNotNull(viewMeta); + ByteString value = viewMeta.getKVMap().get(k); + if (value == null) { + return null; + } + return value.toByteArray(); + } - public void archive() throws IOException { - long t = System.currentTimeMillis(); - myStateInfo.archive(); - LOGGER.info("archive view meta cost {}ms", System.currentTimeMillis() - t); - } + public void archive() throws IOException { + long t = System.currentTimeMillis(); + myStateInfo.archive(); + LOGGER.info("archive view meta cost {}ms", System.currentTimeMillis() - t); } + } } diff --git a/geaflow/geaflow-plugins/geaflow-view-meta/src/main/java/org/apache/geaflow/view/meta/ViewMetaPb.java b/geaflow/geaflow-plugins/geaflow-view-meta/src/main/java/org/apache/geaflow/view/meta/ViewMetaPb.java index 54f44a734..ecc1d7423 100644 --- a/geaflow/geaflow-plugins/geaflow-view-meta/src/main/java/org/apache/geaflow/view/meta/ViewMetaPb.java +++ b/geaflow/geaflow-plugins/geaflow-view-meta/src/main/java/org/apache/geaflow/view/meta/ViewMetaPb.java @@ -23,865 +23,749 @@ package org.apache.geaflow.view.meta; public final class ViewMetaPb { - private ViewMetaPb() { - } + private ViewMetaPb() {} - public static void registerAllExtensions( - com.google.protobuf.ExtensionRegistryLite registry) { - } + public static void registerAllExtensions(com.google.protobuf.ExtensionRegistryLite registry) {} - public static void registerAllExtensions( - com.google.protobuf.ExtensionRegistry registry) { - registerAllExtensions( - (com.google.protobuf.ExtensionRegistryLite) registry); - } + public static void registerAllExtensions(com.google.protobuf.ExtensionRegistry registry) { + registerAllExtensions((com.google.protobuf.ExtensionRegistryLite) registry); + } - public interface ViewMetaOrBuilder extends - // @@protoc_insertion_point(interface_extends:ViewMeta) - com.google.protobuf.MessageOrBuilder { + public interface ViewMetaOrBuilder + extends + // @@protoc_insertion_point(interface_extends:ViewMeta) + com.google.protobuf.MessageOrBuilder { - /** - * map<string, bytes> kvInfo = 1; - */ - int getKvInfoCount(); + /** map<string, bytes> kvInfo = 1; */ + int getKvInfoCount(); - /** - * map<string, bytes> kvInfo = 1; - */ - boolean containsKvInfo( - java.lang.String key); + /** map<string, bytes> kvInfo = 1; */ + boolean containsKvInfo(java.lang.String key); - /** - * Use {@link #getKvInfoMap()} instead. - */ - @java.lang.Deprecated - java.util.Map - getKvInfo(); + /** Use {@link #getKvInfoMap()} instead. */ + @java.lang.Deprecated + java.util.Map getKvInfo(); - /** - * map<string, bytes> kvInfo = 1; - */ - java.util.Map - getKvInfoMap(); + /** map<string, bytes> kvInfo = 1; */ + java.util.Map getKvInfoMap(); - /** - * map<string, bytes> kvInfo = 1; - */ + /** map<string, bytes> kvInfo = 1; */ + com.google.protobuf.ByteString getKvInfoOrDefault( + java.lang.String key, com.google.protobuf.ByteString defaultValue); - com.google.protobuf.ByteString getKvInfoOrDefault( - java.lang.String key, - com.google.protobuf.ByteString defaultValue); + /** map<string, bytes> kvInfo = 1; */ + com.google.protobuf.ByteString getKvInfoOrThrow(java.lang.String key); + } - /** - * map<string, bytes> kvInfo = 1; - */ + /** Protobuf type {@code ViewMeta} */ + public static final class ViewMeta extends com.google.protobuf.GeneratedMessageV3 + implements + // @@protoc_insertion_point(message_implements:ViewMeta) + ViewMetaOrBuilder { + private static final long serialVersionUID = 0L; - com.google.protobuf.ByteString getKvInfoOrThrow( - java.lang.String key); + // Use ViewMeta.newBuilder() to construct. + private ViewMeta(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); } - /** - * Protobuf type {@code ViewMeta} - */ - public static final class ViewMeta extends - com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:ViewMeta) - ViewMetaOrBuilder { - private static final long serialVersionUID = 0L; - - // Use ViewMeta.newBuilder() to construct. - private ViewMeta(com.google.protobuf.GeneratedMessageV3.Builder builder) { - super(builder); - } - - private ViewMeta() { - } + private ViewMeta() {} - @java.lang.Override - @SuppressWarnings({"unused"}) - protected java.lang.Object newInstance( - UnusedPrivateParameter unused) { - return new ViewMeta(); - } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance(UnusedPrivateParameter unused) { + return new ViewMeta(); + } - @java.lang.Override - public final com.google.protobuf.UnknownFieldSet - getUnknownFields() { - return this.unknownFields; - } + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet getUnknownFields() { + return this.unknownFields; + } - private ViewMeta( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - this(); - if (extensionRegistry == null) { - throw new java.lang.NullPointerException(); - } - int mutable_bitField0_ = 0; - com.google.protobuf.UnknownFieldSet.Builder unknownFields = - com.google.protobuf.UnknownFieldSet.newBuilder(); - try { - boolean done = false; - while (!done) { - int tag = input.readTag(); - switch (tag) { - case 0: - done = true; - break; - case 10: { - if (!((mutable_bitField0_ & 0x00000001) != 0)) { - kvInfo_ = com.google.protobuf.MapField.newMapField( - KvInfoDefaultEntryHolder.defaultEntry); - mutable_bitField0_ |= 0x00000001; - } - com.google.protobuf.MapEntry - kvInfo__ = input.readMessage( - KvInfoDefaultEntryHolder.defaultEntry.getParserForType(), extensionRegistry); - kvInfo_.getMutableMap().put( - kvInfo__.getKey(), kvInfo__.getValue()); - break; - } - default: { - if (!parseUnknownField( - input, unknownFields, extensionRegistry, tag)) { - done = true; - } - break; - } - } + private ViewMeta( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: + { + if (!((mutable_bitField0_ & 0x00000001) != 0)) { + kvInfo_ = + com.google.protobuf.MapField.newMapField( + KvInfoDefaultEntryHolder.defaultEntry); + mutable_bitField0_ |= 0x00000001; } - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - throw e.setUnfinishedMessage(this); - } catch (java.io.IOException e) { - throw new com.google.protobuf.InvalidProtocolBufferException( - e).setUnfinishedMessage(this); - } finally { - this.unknownFields = unknownFields.build(); - makeExtensionsImmutable(); - } - } + com.google.protobuf.MapEntry + kvInfo__ = + input.readMessage( + KvInfoDefaultEntryHolder.defaultEntry.getParserForType(), + extensionRegistry); + kvInfo_.getMutableMap().put(kvInfo__.getKey(), kvInfo__.getValue()); + break; + } + default: + { + if (!parseUnknownField(input, unknownFields, extensionRegistry, tag)) { + done = true; + } + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e).setUnfinishedMessage(this); + } finally { + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return ViewMetaPb.internal_static_ViewMeta_descriptor; - } + public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { + return ViewMetaPb.internal_static_ViewMeta_descriptor; + } - @SuppressWarnings({"rawtypes"}) - @java.lang.Override - protected com.google.protobuf.MapField internalGetMapField( - int number) { - switch (number) { - case 1: - return internalGetKvInfo(); - default: - throw new RuntimeException( - "Invalid map field number: " + number); - } - } + @SuppressWarnings({"rawtypes"}) + @java.lang.Override + protected com.google.protobuf.MapField internalGetMapField(int number) { + switch (number) { + case 1: + return internalGetKvInfo(); + default: + throw new RuntimeException("Invalid map field number: " + number); + } + } - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return ViewMetaPb.internal_static_ViewMeta_fieldAccessorTable - .ensureFieldAccessorsInitialized( - ViewMetaPb.ViewMeta.class, ViewMetaPb.ViewMeta.Builder.class); - } + return ViewMetaPb.internal_static_ViewMeta_fieldAccessorTable.ensureFieldAccessorsInitialized( + ViewMetaPb.ViewMeta.class, ViewMetaPb.ViewMeta.Builder.class); + } - public static final int KVINFO_FIELD_NUMBER = 1; - - private static final class KvInfoDefaultEntryHolder { - static final com.google.protobuf.MapEntry< - java.lang.String, com.google.protobuf.ByteString> defaultEntry = - com.google.protobuf.MapEntry - .newDefaultInstance( - ViewMetaPb.internal_static_ViewMeta_KvInfoEntry_descriptor, - com.google.protobuf.WireFormat.FieldType.STRING, - "", - com.google.protobuf.WireFormat.FieldType.BYTES, - com.google.protobuf.ByteString.EMPTY); - } + public static final int KVINFO_FIELD_NUMBER = 1; + + private static final class KvInfoDefaultEntryHolder { + static final com.google.protobuf.MapEntry + defaultEntry = + com.google.protobuf.MapEntry + .newDefaultInstance( + ViewMetaPb.internal_static_ViewMeta_KvInfoEntry_descriptor, + com.google.protobuf.WireFormat.FieldType.STRING, + "", + com.google.protobuf.WireFormat.FieldType.BYTES, + com.google.protobuf.ByteString.EMPTY); + } - private com.google.protobuf.MapField< - java.lang.String, com.google.protobuf.ByteString> kvInfo_; + private com.google.protobuf.MapField kvInfo_; - private com.google.protobuf.MapField + private com.google.protobuf.MapField internalGetKvInfo() { - if (kvInfo_ == null) { - return com.google.protobuf.MapField.emptyMapField( - KvInfoDefaultEntryHolder.defaultEntry); - } - return kvInfo_; - } - - public int getKvInfoCount() { - return internalGetKvInfo().getMap().size(); - } - - /** - * map<string, bytes> kvInfo = 1; - */ - - public boolean containsKvInfo( - java.lang.String key) { - if (key == null) { - throw new java.lang.NullPointerException(); - } - return internalGetKvInfo().getMap().containsKey(key); - } - - /** - * Use {@link #getKvInfoMap()} instead. - */ - @java.lang.Deprecated - public java.util.Map getKvInfo() { - return getKvInfoMap(); - } - - /** - * map<string, bytes> kvInfo = 1; - */ - - public java.util.Map getKvInfoMap() { - return internalGetKvInfo().getMap(); - } - - /** - * map<string, bytes> kvInfo = 1; - */ - - public com.google.protobuf.ByteString getKvInfoOrDefault( - java.lang.String key, - com.google.protobuf.ByteString defaultValue) { - if (key == null) { - throw new java.lang.NullPointerException(); - } - java.util.Map map = - internalGetKvInfo().getMap(); - return map.containsKey(key) ? map.get(key) : defaultValue; - } - - /** - * map<string, bytes> kvInfo = 1; - */ - - public com.google.protobuf.ByteString getKvInfoOrThrow( - java.lang.String key) { - if (key == null) { - throw new java.lang.NullPointerException(); - } - java.util.Map map = - internalGetKvInfo().getMap(); - if (!map.containsKey(key)) { - throw new java.lang.IllegalArgumentException(); - } - return map.get(key); - } - - private byte memoizedIsInitialized = -1; + if (kvInfo_ == null) { + return com.google.protobuf.MapField.emptyMapField(KvInfoDefaultEntryHolder.defaultEntry); + } + return kvInfo_; + } - @java.lang.Override - public final boolean isInitialized() { - byte isInitialized = memoizedIsInitialized; - if (isInitialized == 1) return true; - if (isInitialized == 0) return false; + public int getKvInfoCount() { + return internalGetKvInfo().getMap().size(); + } - memoizedIsInitialized = 1; - return true; - } + /** map<string, bytes> kvInfo = 1; */ + public boolean containsKvInfo(java.lang.String key) { + if (key == null) { + throw new java.lang.NullPointerException(); + } + return internalGetKvInfo().getMap().containsKey(key); + } - @java.lang.Override - public void writeTo(com.google.protobuf.CodedOutputStream output) - throws java.io.IOException { - com.google.protobuf.GeneratedMessageV3 - .serializeStringMapTo( - output, - internalGetKvInfo(), - KvInfoDefaultEntryHolder.defaultEntry, - 1); - unknownFields.writeTo(output); - } + /** Use {@link #getKvInfoMap()} instead. */ + @java.lang.Deprecated + public java.util.Map getKvInfo() { + return getKvInfoMap(); + } - @java.lang.Override - public int getSerializedSize() { - int size = memoizedSize; - if (size != -1) return size; + /** map<string, bytes> kvInfo = 1; */ + public java.util.Map getKvInfoMap() { + return internalGetKvInfo().getMap(); + } - size = 0; - for (java.util.Map.Entry entry - : internalGetKvInfo().getMap().entrySet()) { - com.google.protobuf.MapEntry - kvInfo__ = KvInfoDefaultEntryHolder.defaultEntry.newBuilderForType() - .setKey(entry.getKey()) - .setValue(entry.getValue()) - .build(); - size += com.google.protobuf.CodedOutputStream - .computeMessageSize(1, kvInfo__); - } - size += unknownFields.getSerializedSize(); - memoizedSize = size; - return size; - } + /** map<string, bytes> kvInfo = 1; */ + public com.google.protobuf.ByteString getKvInfoOrDefault( + java.lang.String key, com.google.protobuf.ByteString defaultValue) { + if (key == null) { + throw new java.lang.NullPointerException(); + } + java.util.Map map = + internalGetKvInfo().getMap(); + return map.containsKey(key) ? map.get(key) : defaultValue; + } - @java.lang.Override - public boolean equals(final java.lang.Object obj) { - if (obj == this) { - return true; - } - if (!(obj instanceof ViewMetaPb.ViewMeta)) { - return super.equals(obj); - } - ViewMetaPb.ViewMeta other = (ViewMetaPb.ViewMeta) obj; - - if (!internalGetKvInfo().equals( - other.internalGetKvInfo())) return false; - if (!unknownFields.equals(other.unknownFields)) return false; - return true; - } + /** map<string, bytes> kvInfo = 1; */ + public com.google.protobuf.ByteString getKvInfoOrThrow(java.lang.String key) { + if (key == null) { + throw new java.lang.NullPointerException(); + } + java.util.Map map = + internalGetKvInfo().getMap(); + if (!map.containsKey(key)) { + throw new java.lang.IllegalArgumentException(); + } + return map.get(key); + } - @java.lang.Override - public int hashCode() { - if (memoizedHashCode != 0) { - return memoizedHashCode; - } - int hash = 41; - hash = (19 * hash) + getDescriptor().hashCode(); - if (!internalGetKvInfo().getMap().isEmpty()) { - hash = (37 * hash) + KVINFO_FIELD_NUMBER; - hash = (53 * hash) + internalGetKvInfo().hashCode(); - } - hash = (29 * hash) + unknownFields.hashCode(); - memoizedHashCode = hash; - return hash; - } + private byte memoizedIsInitialized = -1; - public static ViewMetaPb.ViewMeta parseFrom( - java.nio.ByteBuffer data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; - public static ViewMetaPb.ViewMeta parseFrom( - java.nio.ByteBuffer data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } + memoizedIsInitialized = 1; + return true; + } - public static ViewMetaPb.ViewMeta parseFrom( - com.google.protobuf.ByteString data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { + com.google.protobuf.GeneratedMessageV3.serializeStringMapTo( + output, internalGetKvInfo(), KvInfoDefaultEntryHolder.defaultEntry, 1); + unknownFields.writeTo(output); + } - public static ViewMetaPb.ViewMeta parseFrom( - com.google.protobuf.ByteString data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + for (java.util.Map.Entry entry : + internalGetKvInfo().getMap().entrySet()) { + com.google.protobuf.MapEntry kvInfo__ = + KvInfoDefaultEntryHolder.defaultEntry + .newBuilderForType() + .setKey(entry.getKey()) + .setValue(entry.getValue()) + .build(); + size += com.google.protobuf.CodedOutputStream.computeMessageSize(1, kvInfo__); + } + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } - public static ViewMetaPb.ViewMeta parseFrom(byte[] data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof ViewMetaPb.ViewMeta)) { + return super.equals(obj); + } + ViewMetaPb.ViewMeta other = (ViewMetaPb.ViewMeta) obj; + + if (!internalGetKvInfo().equals(other.internalGetKvInfo())) return false; + if (!unknownFields.equals(other.unknownFields)) return false; + return true; + } - public static ViewMetaPb.ViewMeta parseFrom( - byte[] data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + if (!internalGetKvInfo().getMap().isEmpty()) { + hash = (37 * hash) + KVINFO_FIELD_NUMBER; + hash = (53 * hash) + internalGetKvInfo().hashCode(); + } + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } - public static ViewMetaPb.ViewMeta parseFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } + public static ViewMetaPb.ViewMeta parseFrom(java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - public static ViewMetaPb.ViewMeta parseFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } + public static ViewMetaPb.ViewMeta parseFrom( + java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - public static ViewMetaPb.ViewMeta parseDelimitedFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input); - } + public static ViewMetaPb.ViewMeta parseFrom(com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - public static ViewMetaPb.ViewMeta parseDelimitedFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input, extensionRegistry); - } + public static ViewMetaPb.ViewMeta parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - public static ViewMetaPb.ViewMeta parseFrom( - com.google.protobuf.CodedInputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } + public static ViewMetaPb.ViewMeta parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - public static ViewMetaPb.ViewMeta parseFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } + public static ViewMetaPb.ViewMeta parseFrom( + byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - @java.lang.Override - public Builder newBuilderForType() { - return newBuilder(); - } + public static ViewMetaPb.ViewMeta parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException(PARSER, input); + } - public static Builder newBuilder() { - return DEFAULT_INSTANCE.toBuilder(); - } + public static ViewMetaPb.ViewMeta parseFrom( + java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException( + PARSER, input, extensionRegistry); + } - public static Builder newBuilder(ViewMetaPb.ViewMeta prototype) { - return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); - } + public static ViewMetaPb.ViewMeta parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseDelimitedWithIOException(PARSER, input); + } - @java.lang.Override - public Builder toBuilder() { - return this == DEFAULT_INSTANCE - ? new Builder() : new Builder().mergeFrom(this); - } + public static ViewMetaPb.ViewMeta parseDelimitedFrom( + java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseDelimitedWithIOException( + PARSER, input, extensionRegistry); + } - @java.lang.Override - protected Builder newBuilderForType( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - Builder builder = new Builder(parent); - return builder; - } + public static ViewMetaPb.ViewMeta parseFrom(com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException(PARSER, input); + } - /** - * Protobuf type {@code ViewMeta} - */ - public static final class Builder extends - com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:ViewMeta) - ViewMetaPb.ViewMetaOrBuilder { - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return ViewMetaPb.internal_static_ViewMeta_descriptor; - } - - @SuppressWarnings({"rawtypes"}) - protected com.google.protobuf.MapField internalGetMapField( - int number) { - switch (number) { - case 1: - return internalGetKvInfo(); - default: - throw new RuntimeException( - "Invalid map field number: " + number); - } - } - - @SuppressWarnings({"rawtypes"}) - protected com.google.protobuf.MapField internalGetMutableMapField( - int number) { - switch (number) { - case 1: - return internalGetMutableKvInfo(); - default: - throw new RuntimeException( - "Invalid map field number: " + number); - } - } - - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return ViewMetaPb.internal_static_ViewMeta_fieldAccessorTable - .ensureFieldAccessorsInitialized( - ViewMetaPb.ViewMeta.class, ViewMetaPb.ViewMeta.Builder.class); - } - - // Construct using org.apache.geaflow.view.meta.ViewMetaPb.ViewMeta.newBuilder() - private Builder() { - maybeForceBuilderInitialization(); - } - - private Builder( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - super(parent); - maybeForceBuilderInitialization(); - } - - private void maybeForceBuilderInitialization() { - } - - @java.lang.Override - public Builder clear() { - super.clear(); - internalGetMutableKvInfo().clear(); - return this; - } - - @java.lang.Override - public com.google.protobuf.Descriptors.Descriptor - getDescriptorForType() { - return ViewMetaPb.internal_static_ViewMeta_descriptor; - } - - @java.lang.Override - public ViewMetaPb.ViewMeta getDefaultInstanceForType() { - return ViewMetaPb.ViewMeta.getDefaultInstance(); - } - - @java.lang.Override - public ViewMetaPb.ViewMeta build() { - ViewMetaPb.ViewMeta result = buildPartial(); - if (!result.isInitialized()) { - throw newUninitializedMessageException(result); - } - return result; - } - - @java.lang.Override - public ViewMetaPb.ViewMeta buildPartial() { - ViewMetaPb.ViewMeta result = new ViewMetaPb.ViewMeta(this); - result.kvInfo_ = internalGetKvInfo(); - result.kvInfo_.makeImmutable(); - onBuilt(); - return result; - } - - @java.lang.Override - public Builder clone() { - return super.clone(); - } - - @java.lang.Override - public Builder setField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.setField(field, value); - } - - @java.lang.Override - public Builder clearField( - com.google.protobuf.Descriptors.FieldDescriptor field) { - return super.clearField(field); - } - - @java.lang.Override - public Builder clearOneof( - com.google.protobuf.Descriptors.OneofDescriptor oneof) { - return super.clearOneof(oneof); - } - - @java.lang.Override - public Builder setRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - int index, java.lang.Object value) { - return super.setRepeatedField(field, index, value); - } - - @java.lang.Override - public Builder addRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.addRepeatedField(field, value); - } - - @java.lang.Override - public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof ViewMetaPb.ViewMeta) { - return mergeFrom((ViewMetaPb.ViewMeta) other); - } else { - super.mergeFrom(other); - return this; - } - } - - public Builder mergeFrom(ViewMetaPb.ViewMeta other) { - if (other == ViewMetaPb.ViewMeta.getDefaultInstance()) return this; - internalGetMutableKvInfo().mergeFrom( - other.internalGetKvInfo()); - this.mergeUnknownFields(other.unknownFields); - onChanged(); - return this; - } - - @java.lang.Override - public final boolean isInitialized() { - return true; - } - - @java.lang.Override - public Builder mergeFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - ViewMetaPb.ViewMeta parsedMessage = null; - try { - parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - parsedMessage = (ViewMetaPb.ViewMeta) e.getUnfinishedMessage(); - throw e.unwrapIOException(); - } finally { - if (parsedMessage != null) { - mergeFrom(parsedMessage); - } - } - return this; - } + public static ViewMetaPb.ViewMeta parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException( + PARSER, input, extensionRegistry); + } - private int bitField0_; + @java.lang.Override + public Builder newBuilderForType() { + return newBuilder(); + } - private com.google.protobuf.MapField< - java.lang.String, com.google.protobuf.ByteString> kvInfo_; + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } - private com.google.protobuf.MapField - internalGetKvInfo() { - if (kvInfo_ == null) { - return com.google.protobuf.MapField.emptyMapField( - KvInfoDefaultEntryHolder.defaultEntry); - } - return kvInfo_; - } - - private com.google.protobuf.MapField - internalGetMutableKvInfo() { - onChanged(); - if (kvInfo_ == null) { - kvInfo_ = com.google.protobuf.MapField.newMapField( - KvInfoDefaultEntryHolder.defaultEntry); - } - if (!kvInfo_.isMutable()) { - kvInfo_ = kvInfo_.copy(); - } - return kvInfo_; - } + public static Builder newBuilder(ViewMetaPb.ViewMeta prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } - public int getKvInfoCount() { - return internalGetKvInfo().getMap().size(); - } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE ? new Builder() : new Builder().mergeFrom(this); + } - /** - * map<string, bytes> kvInfo = 1; - */ + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } - public boolean containsKvInfo( - java.lang.String key) { - if (key == null) { - throw new java.lang.NullPointerException(); - } - return internalGetKvInfo().getMap().containsKey(key); - } - - /** - * Use {@link #getKvInfoMap()} instead. - */ - @java.lang.Deprecated - public java.util.Map getKvInfo() { - return getKvInfoMap(); - } - - /** - * map<string, bytes> kvInfo = 1; - */ - - public java.util.Map getKvInfoMap() { - return internalGetKvInfo().getMap(); - } - - /** - * map<string, bytes> kvInfo = 1; - */ - - public com.google.protobuf.ByteString getKvInfoOrDefault( - java.lang.String key, - com.google.protobuf.ByteString defaultValue) { - if (key == null) { - throw new java.lang.NullPointerException(); - } - java.util.Map map = - internalGetKvInfo().getMap(); - return map.containsKey(key) ? map.get(key) : defaultValue; - } - - /** - * map<string, bytes> kvInfo = 1; - */ - - public com.google.protobuf.ByteString getKvInfoOrThrow( - java.lang.String key) { - if (key == null) { - throw new java.lang.NullPointerException(); - } - java.util.Map map = - internalGetKvInfo().getMap(); - if (!map.containsKey(key)) { - throw new java.lang.IllegalArgumentException(); - } - return map.get(key); - } - - public Builder clearKvInfo() { - internalGetMutableKvInfo().getMutableMap() - .clear(); - return this; - } - - /** - * map<string, bytes> kvInfo = 1; - */ - - public Builder removeKvInfo( - java.lang.String key) { - if (key == null) { - throw new java.lang.NullPointerException(); - } - internalGetMutableKvInfo().getMutableMap() - .remove(key); - return this; - } - - /** - * Use alternate mutation accessors instead. - */ - @java.lang.Deprecated - public java.util.Map - getMutableKvInfo() { - return internalGetMutableKvInfo().getMutableMap(); - } - - /** - * map<string, bytes> kvInfo = 1; - */ - public Builder putKvInfo( - java.lang.String key, - com.google.protobuf.ByteString value) { - if (key == null) { - throw new java.lang.NullPointerException(); - } - if (value == null) { - throw new java.lang.NullPointerException(); - } - internalGetMutableKvInfo().getMutableMap() - .put(key, value); - return this; - } - - /** - * map<string, bytes> kvInfo = 1; - */ - - public Builder putAllKvInfo( - java.util.Map values) { - internalGetMutableKvInfo().getMutableMap() - .putAll(values); - return this; - } - - @java.lang.Override - public final Builder setUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.setUnknownFields(unknownFields); - } - - @java.lang.Override - public final Builder mergeUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.mergeUnknownFields(unknownFields); - } - - - // @@protoc_insertion_point(builder_scope:ViewMeta) - } + /** Protobuf type {@code ViewMeta} */ + public static final class Builder + extends com.google.protobuf.GeneratedMessageV3.Builder + implements + // @@protoc_insertion_point(builder_implements:ViewMeta) + ViewMetaPb.ViewMetaOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { + return ViewMetaPb.internal_static_ViewMeta_descriptor; + } + + @SuppressWarnings({"rawtypes"}) + protected com.google.protobuf.MapField internalGetMapField(int number) { + switch (number) { + case 1: + return internalGetKvInfo(); + default: + throw new RuntimeException("Invalid map field number: " + number); + } + } + + @SuppressWarnings({"rawtypes"}) + protected com.google.protobuf.MapField internalGetMutableMapField(int number) { + switch (number) { + case 1: + return internalGetMutableKvInfo(); + default: + throw new RuntimeException("Invalid map field number: " + number); + } + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return ViewMetaPb.internal_static_ViewMeta_fieldAccessorTable + .ensureFieldAccessorsInitialized( + ViewMetaPb.ViewMeta.class, ViewMetaPb.ViewMeta.Builder.class); + } + + // Construct using org.apache.geaflow.view.meta.ViewMetaPb.ViewMeta.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder(com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + + private void maybeForceBuilderInitialization() {} + + @java.lang.Override + public Builder clear() { + super.clear(); + internalGetMutableKvInfo().clear(); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { + return ViewMetaPb.internal_static_ViewMeta_descriptor; + } + + @java.lang.Override + public ViewMetaPb.ViewMeta getDefaultInstanceForType() { + return ViewMetaPb.ViewMeta.getDefaultInstance(); + } + + @java.lang.Override + public ViewMetaPb.ViewMeta build() { + ViewMetaPb.ViewMeta result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public ViewMetaPb.ViewMeta buildPartial() { + ViewMetaPb.ViewMeta result = new ViewMetaPb.ViewMeta(this); + result.kvInfo_ = internalGetKvInfo(); + result.kvInfo_.makeImmutable(); + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, java.lang.Object value) { + return super.setField(field, value); + } + + @java.lang.Override + public Builder clearField(com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + + @java.lang.Override + public Builder clearOneof(com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, + java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, java.lang.Object value) { + return super.addRepeatedField(field, value); + } + + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof ViewMetaPb.ViewMeta) { + return mergeFrom((ViewMetaPb.ViewMeta) other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(ViewMetaPb.ViewMeta other) { + if (other == ViewMetaPb.ViewMeta.getDefaultInstance()) return this; + internalGetMutableKvInfo().mergeFrom(other.internalGetKvInfo()); + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + ViewMetaPb.ViewMeta parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (ViewMetaPb.ViewMeta) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + + private int bitField0_; + + private com.google.protobuf.MapField + kvInfo_; + + private com.google.protobuf.MapField + internalGetKvInfo() { + if (kvInfo_ == null) { + return com.google.protobuf.MapField.emptyMapField(KvInfoDefaultEntryHolder.defaultEntry); + } + return kvInfo_; + } + + private com.google.protobuf.MapField + internalGetMutableKvInfo() { + onChanged(); + if (kvInfo_ == null) { + kvInfo_ = com.google.protobuf.MapField.newMapField(KvInfoDefaultEntryHolder.defaultEntry); + } + if (!kvInfo_.isMutable()) { + kvInfo_ = kvInfo_.copy(); + } + return kvInfo_; + } + + public int getKvInfoCount() { + return internalGetKvInfo().getMap().size(); + } + + /** map<string, bytes> kvInfo = 1; */ + public boolean containsKvInfo(java.lang.String key) { + if (key == null) { + throw new java.lang.NullPointerException(); + } + return internalGetKvInfo().getMap().containsKey(key); + } + + /** Use {@link #getKvInfoMap()} instead. */ + @java.lang.Deprecated + public java.util.Map getKvInfo() { + return getKvInfoMap(); + } + + /** map<string, bytes> kvInfo = 1; */ + public java.util.Map getKvInfoMap() { + return internalGetKvInfo().getMap(); + } + + /** map<string, bytes> kvInfo = 1; */ + public com.google.protobuf.ByteString getKvInfoOrDefault( + java.lang.String key, com.google.protobuf.ByteString defaultValue) { + if (key == null) { + throw new java.lang.NullPointerException(); + } + java.util.Map map = + internalGetKvInfo().getMap(); + return map.containsKey(key) ? map.get(key) : defaultValue; + } + + /** map<string, bytes> kvInfo = 1; */ + public com.google.protobuf.ByteString getKvInfoOrThrow(java.lang.String key) { + if (key == null) { + throw new java.lang.NullPointerException(); + } + java.util.Map map = + internalGetKvInfo().getMap(); + if (!map.containsKey(key)) { + throw new java.lang.IllegalArgumentException(); + } + return map.get(key); + } + + public Builder clearKvInfo() { + internalGetMutableKvInfo().getMutableMap().clear(); + return this; + } + + /** map<string, bytes> kvInfo = 1; */ + public Builder removeKvInfo(java.lang.String key) { + if (key == null) { + throw new java.lang.NullPointerException(); + } + internalGetMutableKvInfo().getMutableMap().remove(key); + return this; + } + + /** Use alternate mutation accessors instead. */ + @java.lang.Deprecated + public java.util.Map getMutableKvInfo() { + return internalGetMutableKvInfo().getMutableMap(); + } + + /** map<string, bytes> kvInfo = 1; */ + public Builder putKvInfo(java.lang.String key, com.google.protobuf.ByteString value) { + if (key == null) { + throw new java.lang.NullPointerException(); + } + if (value == null) { + throw new java.lang.NullPointerException(); + } + internalGetMutableKvInfo().getMutableMap().put(key, value); + return this; + } + + /** map<string, bytes> kvInfo = 1; */ + public Builder putAllKvInfo( + java.util.Map values) { + internalGetMutableKvInfo().getMutableMap().putAll(values); + return this; + } + + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + // @@protoc_insertion_point(builder_scope:ViewMeta) + } - // @@protoc_insertion_point(class_scope:ViewMeta) - private static final ViewMetaPb.ViewMeta DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:ViewMeta) + private static final ViewMetaPb.ViewMeta DEFAULT_INSTANCE; - static { - DEFAULT_INSTANCE = new ViewMetaPb.ViewMeta(); - } + static { + DEFAULT_INSTANCE = new ViewMetaPb.ViewMeta(); + } - public static ViewMetaPb.ViewMeta getDefaultInstance() { - return DEFAULT_INSTANCE; - } + public static ViewMetaPb.ViewMeta getDefaultInstance() { + return DEFAULT_INSTANCE; + } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { - @java.lang.Override - public ViewMeta parsePartialFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return new ViewMeta(input, extensionRegistry); - } + private static final com.google.protobuf.Parser PARSER = + new com.google.protobuf.AbstractParser() { + @java.lang.Override + public ViewMeta parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new ViewMeta(input, extensionRegistry); + } }; - public static com.google.protobuf.Parser parser() { - return PARSER; - } - - @java.lang.Override - public com.google.protobuf.Parser getParserForType() { - return PARSER; - } - - @java.lang.Override - public ViewMetaPb.ViewMeta getDefaultInstanceForType() { - return DEFAULT_INSTANCE; - } - + public static com.google.protobuf.Parser parser() { + return PARSER; } - private static final com.google.protobuf.Descriptors.Descriptor - internal_static_ViewMeta_descriptor; - private static final - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internal_static_ViewMeta_fieldAccessorTable; - private static final com.google.protobuf.Descriptors.Descriptor - internal_static_ViewMeta_KvInfoEntry_descriptor; - private static final - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internal_static_ViewMeta_KvInfoEntry_fieldAccessorTable; - - public static com.google.protobuf.Descriptors.FileDescriptor - getDescriptor() { - return descriptor; + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; } - private static com.google.protobuf.Descriptors.FileDescriptor - descriptor; - - static { - java.lang.String[] descriptorData = { - "\n\014view_meta.pb\"`\n\010ViewMeta\022%\n\006kvInfo\030\001 \003" + - "(\0132\025.ViewMeta.KvInfoEntry\032-\n\013KvInfoEntry" + - "\022\013\n\003key\030\001 \001(\t\022\r\n\005value\030\002 \001(\014:\0028\001B \n\036com." + - "antgroup.geaflow.view.metab\006proto3" - }; - descriptor = com.google.protobuf.Descriptors.FileDescriptor - .internalBuildGeneratedFileFrom(descriptorData, - new com.google.protobuf.Descriptors.FileDescriptor[]{ - }); - internal_static_ViewMeta_descriptor = - getDescriptor().getMessageTypes().get(0); - internal_static_ViewMeta_fieldAccessorTable = new - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + @java.lang.Override + public ViewMetaPb.ViewMeta getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + } + + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_ViewMeta_descriptor; + private static final com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_ViewMeta_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_ViewMeta_KvInfoEntry_descriptor; + private static final com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_ViewMeta_KvInfoEntry_fieldAccessorTable; + + public static com.google.protobuf.Descriptors.FileDescriptor getDescriptor() { + return descriptor; + } + + private static com.google.protobuf.Descriptors.FileDescriptor descriptor; + + static { + java.lang.String[] descriptorData = { + "\n" + + "\014view_meta.pb\"`\n" + + "\010ViewMeta\022%\n" + + "\006kvInfo\030\001 \003(\0132\025.ViewMeta.KvInfoEntry\032-\n" + + "\013KvInfoEntry\022\013\n" + + "\003key\030\001 \001(\t\022\r\n" + + "\005value\030\002 \001(\014:\0028\001B \n" + + "\036com.antgroup.geaflow.view.metab\006proto3" + }; + descriptor = + com.google.protobuf.Descriptors.FileDescriptor.internalBuildGeneratedFileFrom( + descriptorData, new com.google.protobuf.Descriptors.FileDescriptor[] {}); + internal_static_ViewMeta_descriptor = getDescriptor().getMessageTypes().get(0); + internal_static_ViewMeta_fieldAccessorTable = + new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_ViewMeta_descriptor, - new java.lang.String[]{"KvInfo",}); - internal_static_ViewMeta_KvInfoEntry_descriptor = - internal_static_ViewMeta_descriptor.getNestedTypes().get(0); - internal_static_ViewMeta_KvInfoEntry_fieldAccessorTable = new - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + new java.lang.String[] { + "KvInfo", + }); + internal_static_ViewMeta_KvInfoEntry_descriptor = + internal_static_ViewMeta_descriptor.getNestedTypes().get(0); + internal_static_ViewMeta_KvInfoEntry_fieldAccessorTable = + new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_ViewMeta_KvInfoEntry_descriptor, - new java.lang.String[]{"Key", "Value",}); - } + new java.lang.String[] { + "Key", "Value", + }); + } - // @@protoc_insertion_point(outer_class_scope) + // @@protoc_insertion_point(outer_class_scope) } diff --git a/geaflow/geaflow-plugins/geaflow-view-meta/src/test/java/org/apache/geaflow/view/meta/ViewMetaKeeperTest.java b/geaflow/geaflow-plugins/geaflow-view-meta/src/test/java/org/apache/geaflow/view/meta/ViewMetaKeeperTest.java index 96a5f46fa..ad0eb3c9d 100644 --- a/geaflow/geaflow-plugins/geaflow-view-meta/src/test/java/org/apache/geaflow/view/meta/ViewMetaKeeperTest.java +++ b/geaflow/geaflow-plugins/geaflow-view-meta/src/test/java/org/apache/geaflow/view/meta/ViewMetaKeeperTest.java @@ -20,6 +20,7 @@ package org.apache.geaflow.view.meta; import java.io.File; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.file.FileConfigKeys; import org.apache.geaflow.view.meta.ViewMetaBookKeeper.ViewMetaKeeper; @@ -28,65 +29,65 @@ public class ViewMetaKeeperTest { - @Test - public void testFO() throws Exception { - Configuration configuration = new Configuration(); - configuration.put(FileConfigKeys.PERSISTENT_TYPE, "LOCAL"); - configuration.put(FileConfigKeys.ROOT, "/tmp/"); - - testFO("panguInfo" + System.currentTimeMillis(), configuration); - } - - private void testFO(String jobName, Configuration config) throws Exception { - ViewMetaKeeper keeper = new ViewMetaKeeper(); - keeper.init(jobName, config); - - keeper.save("successBatchId", "10".getBytes()); - Assert.assertEquals(keeper.get("successBatchId"), "10".getBytes()); - keeper.archive(); - - ViewMetaKeeper sharedKeeper = new ViewMetaKeeper(); - sharedKeeper.init("shared", config); - Thread.sleep(1000); - Assert.assertEquals(sharedKeeper.get(jobName, "successBatchId"), "10".getBytes()); - - keeper.save("done", "true".getBytes()); - keeper.archive(); - - Thread.sleep(1000); - Assert.assertEquals(sharedKeeper.get(jobName, "done"), "true".getBytes()); - - // normal fo - keeper = new ViewMetaKeeper(); - keeper.init(jobName, config); - Assert.assertEquals(keeper.get("successBatchId"), "10".getBytes()); - Assert.assertEquals(keeper.get("done"), "true".getBytes()); - - String filePath = "/tmp/" + jobName + "/view.meta"; - Assert.assertTrue(new File(filePath).exists()); - - // fail fo, luckily, we have backup file - new File(filePath).renameTo(new File(filePath + ".bak")); - Thread.sleep(1000); - keeper = new ViewMetaKeeper(); - keeper.init(jobName, config); - Assert.assertEquals(keeper.get("successBatchId"), "10".getBytes()); - Assert.assertEquals(keeper.get("done"), "true".getBytes()); - Assert.assertEquals(sharedKeeper.get(jobName, "done"), "true".getBytes()); - - keeper.save("finish", "true".getBytes()); - keeper.archive(); - - keeper = new ViewMetaKeeper(); - keeper.init(jobName, config); - Assert.assertEquals(keeper.get("successBatchId"), "10".getBytes()); - Assert.assertEquals(keeper.get("done"), "true".getBytes()); - Assert.assertEquals(keeper.get("finish"), "true".getBytes()); - Thread.sleep(1000); - Assert.assertEquals(sharedKeeper.get(jobName, "finish"), "true".getBytes()); - keeper.archive(); - - Assert.assertTrue(new File(filePath).exists()); - Assert.assertFalse(new File(filePath + ".bak").exists()); - } + @Test + public void testFO() throws Exception { + Configuration configuration = new Configuration(); + configuration.put(FileConfigKeys.PERSISTENT_TYPE, "LOCAL"); + configuration.put(FileConfigKeys.ROOT, "/tmp/"); + + testFO("panguInfo" + System.currentTimeMillis(), configuration); + } + + private void testFO(String jobName, Configuration config) throws Exception { + ViewMetaKeeper keeper = new ViewMetaKeeper(); + keeper.init(jobName, config); + + keeper.save("successBatchId", "10".getBytes()); + Assert.assertEquals(keeper.get("successBatchId"), "10".getBytes()); + keeper.archive(); + + ViewMetaKeeper sharedKeeper = new ViewMetaKeeper(); + sharedKeeper.init("shared", config); + Thread.sleep(1000); + Assert.assertEquals(sharedKeeper.get(jobName, "successBatchId"), "10".getBytes()); + + keeper.save("done", "true".getBytes()); + keeper.archive(); + + Thread.sleep(1000); + Assert.assertEquals(sharedKeeper.get(jobName, "done"), "true".getBytes()); + + // normal fo + keeper = new ViewMetaKeeper(); + keeper.init(jobName, config); + Assert.assertEquals(keeper.get("successBatchId"), "10".getBytes()); + Assert.assertEquals(keeper.get("done"), "true".getBytes()); + + String filePath = "/tmp/" + jobName + "/view.meta"; + Assert.assertTrue(new File(filePath).exists()); + + // fail fo, luckily, we have backup file + new File(filePath).renameTo(new File(filePath + ".bak")); + Thread.sleep(1000); + keeper = new ViewMetaKeeper(); + keeper.init(jobName, config); + Assert.assertEquals(keeper.get("successBatchId"), "10".getBytes()); + Assert.assertEquals(keeper.get("done"), "true".getBytes()); + Assert.assertEquals(sharedKeeper.get(jobName, "done"), "true".getBytes()); + + keeper.save("finish", "true".getBytes()); + keeper.archive(); + + keeper = new ViewMetaKeeper(); + keeper.init(jobName, config); + Assert.assertEquals(keeper.get("successBatchId"), "10".getBytes()); + Assert.assertEquals(keeper.get("done"), "true".getBytes()); + Assert.assertEquals(keeper.get("finish"), "true".getBytes()); + Thread.sleep(1000); + Assert.assertEquals(sharedKeeper.get(jobName, "finish"), "true".getBytes()); + keeper.archive(); + + Assert.assertTrue(new File(filePath).exists()); + Assert.assertFalse(new File(filePath + ".bak").exists()); + } } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DataModel.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DataModel.java index 5a9b6a0e0..9736228f1 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DataModel.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DataModel.java @@ -20,24 +20,14 @@ package org.apache.geaflow.state; public enum DataModel { - /** - * static graph model. - */ - STATIC_GRAPH, - /** - * dynamic graph model. - */ - DYNAMIC_GRAPH, - /** - * kv model. - */ - KV, - /** - * kMap model. - */ - KMap, - /** - * kList model. - */ - KList, + /** static graph model. */ + STATIC_GRAPH, + /** dynamic graph model. */ + DYNAMIC_GRAPH, + /** kv model. */ + KV, + /** kMap model. */ + KMap, + /** kList model. */ + KList, } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DynamicEdgeState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DynamicEdgeState.java index 99b651fae..e684f0df8 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DynamicEdgeState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DynamicEdgeState.java @@ -21,7 +21,6 @@ import org.apache.geaflow.model.graph.edge.IEdge; -public interface DynamicEdgeState extends DynamicQueryableState>, - MultiVersionedRevisableState> { - -} +public interface DynamicEdgeState + extends DynamicQueryableState>, + MultiVersionedRevisableState> {} diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DynamicGraphState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DynamicGraphState.java index 4609dd55c..ab4c4c455 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DynamicGraphState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DynamicGraphState.java @@ -20,25 +20,18 @@ package org.apache.geaflow.state; /** - * The dynamic graph state is the interface controlling - * dynamic vertex/edge or one degree subgraph. - * Dynamic graph is composed by multi graph snapshots, - * and all the operators are made on the snapshots. + * The dynamic graph state is the interface controlling dynamic vertex/edge or one degree subgraph. + * Dynamic graph is composed by multi graph snapshots, and all the operators are made on the + * snapshots. */ public interface DynamicGraphState { - /** - * Returns the dynamic vertex handler. - */ - DynamicVertexState V(); + /** Returns the dynamic vertex handler. */ + DynamicVertexState V(); - /** - * Returns the dynamic edge handler. - */ - DynamicEdgeState E(); + /** Returns the dynamic edge handler. */ + DynamicEdgeState E(); - /** - * Returns the one degree handler. - */ - DynamicOneDegreeGraphState VE(); + /** Returns the one degree handler. */ + DynamicOneDegreeGraphState VE(); } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DynamicOneDegreeGraphState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DynamicOneDegreeGraphState.java index a9b7c95f4..2ec24d8f8 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DynamicOneDegreeGraphState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DynamicOneDegreeGraphState.java @@ -22,6 +22,4 @@ import org.apache.geaflow.state.data.OneDegreeGraph; public interface DynamicOneDegreeGraphState - extends DynamicQueryableState> { - -} + extends DynamicQueryableState> {} diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DynamicQueryableState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DynamicQueryableState.java index fdce2d0a3..c15c6bbba 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DynamicQueryableState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DynamicQueryableState.java @@ -21,59 +21,40 @@ import java.util.Collection; import java.util.List; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.state.query.QueryableAllGraphState; import org.apache.geaflow.state.query.QueryableKeysGraphState; import org.apache.geaflow.state.query.QueryableVersionGraphState; import org.apache.geaflow.utils.keygroup.KeyGroup; -/** - * The query interface for dynamic graph. - */ +/** The query interface for dynamic graph. */ public interface DynamicQueryableState { - /** - * Returns the version list for some specific vertex id. - */ - List getAllVersions(K id); + /** Returns the version list for some specific vertex id. */ + List getAllVersions(K id); - /** - * Returns the latest version for some specific vertex id. - */ - long getLatestVersion(K id); + /** Returns the latest version for some specific vertex id. */ + long getLatestVersion(K id); - /** - * Returns the versioned-query interface for some specific id. - */ - QueryableVersionGraphState query(K id); + /** Returns the versioned-query interface for some specific id. */ + QueryableVersionGraphState query(K id); - /** - * Returns the versioned-query interface for some specific id and versions. - */ - QueryableVersionGraphState query(K id, Collection versions); + /** Returns the versioned-query interface for some specific id and versions. */ + QueryableVersionGraphState query(K id, Collection versions); - /** - * Returns the full graph query interface for some specific version. - */ - QueryableAllGraphState query(long version); + /** Returns the full graph query interface for some specific version. */ + QueryableAllGraphState query(long version); - /** - * Returns the full graph query interface for some specific version. - */ - QueryableAllGraphState query(long version, KeyGroup keyGroup); + /** Returns the full graph query interface for some specific version. */ + QueryableAllGraphState query(long version, KeyGroup keyGroup); - /** - * Returns the point query interface for some specific version and ids. - */ - QueryableKeysGraphState query(long version, K... ids); + /** Returns the point query interface for some specific version and ids. */ + QueryableKeysGraphState query(long version, K... ids); - /** - * Returns the point query interface for some specific version and ids. - */ - QueryableKeysGraphState query(long version, List ids); + /** Returns the point query interface for some specific version and ids. */ + QueryableKeysGraphState query(long version, List ids); - /** - * Returns the graph id iterator. - */ - CloseableIterator idIterator(); + /** Returns the graph id iterator. */ + CloseableIterator idIterator(); } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DynamicVertexState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DynamicVertexState.java index 157b01f9a..f9c9511d0 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DynamicVertexState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/DynamicVertexState.java @@ -21,8 +21,6 @@ import org.apache.geaflow.model.graph.vertex.IVertex; -public interface DynamicVertexState extends - DynamicQueryableState>, - MultiVersionedRevisableState> { - -} +public interface DynamicVertexState + extends DynamicQueryableState>, + MultiVersionedRevisableState> {} diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/GraphState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/GraphState.java index cf9f9cc6d..3e9c4d9df 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/GraphState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/GraphState.java @@ -21,24 +21,15 @@ import org.apache.geaflow.state.manage.ManageableGraphState; -/** - * The Graph State Interface, including static graph, - * dynamic graph and management state. - */ +/** The Graph State Interface, including static graph, dynamic graph and management state. */ public interface GraphState extends IState { - /** - * Returns the static graph state handler. - */ - StaticGraphState staticGraph(); + /** Returns the static graph state handler. */ + StaticGraphState staticGraph(); - /** - * Returns the dynamic graph state handler. - */ - DynamicGraphState dynamicGraph(); + /** Returns the dynamic graph state handler. */ + DynamicGraphState dynamicGraph(); - /** - * Returns the graph management state handler. - */ - ManageableGraphState manage(); + /** Returns the graph management state handler. */ + ManageableGraphState manage(); } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/IState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/IState.java index 01a08c42d..8ef92cf70 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/IState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/IState.java @@ -20,15 +20,12 @@ package org.apache.geaflow.state; import java.io.Serializable; + import org.apache.geaflow.state.manage.ManageableState; -/** - * Basic State Interface, including management state. - */ +/** Basic State Interface, including management state. */ public interface IState extends Serializable { - /** - * Returns the graph management state handler. - */ - ManageableState manage(); + /** Returns the graph management state handler. */ + ManageableState manage(); } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/KeyListState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/KeyListState.java index c1ad0a475..5a0fc70a2 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/KeyListState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/KeyListState.java @@ -20,16 +20,12 @@ package org.apache.geaflow.state; import java.util.List; + import org.apache.geaflow.state.key.KeyListTrait; -/** - * The key to list state interface. - * The interface is inspired by the Flink's ListState. - */ +/** The key to list state interface. The interface is inspired by the Flink's ListState. */ public interface KeyListState extends KeyListTrait, IState { - /** - * Override the list state. - */ - void put(K key, List list); + /** Override the list state. */ + void put(K key, List list); } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/KeyMapState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/KeyMapState.java index df0c5291d..983030af2 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/KeyMapState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/KeyMapState.java @@ -20,17 +20,12 @@ package org.apache.geaflow.state; import java.util.Map; + import org.apache.geaflow.state.key.KeyMapTrait; -/** - * The key to map state interface. - * The interface is inspired by the Flink's MapState. - */ +/** The key to map state interface. The interface is inspired by the Flink's MapState. */ public interface KeyMapState extends KeyMapTrait, IState { - /** - * Override the given map into the state. - */ - void put(K key, Map map); - + /** Override the given map into the state. */ + void put(K key, Map map); } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/KeyValueState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/KeyValueState.java index 133f16a35..395379d49 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/KeyValueState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/KeyValueState.java @@ -21,25 +21,15 @@ import org.apache.geaflow.state.key.KeyValueTrait; -/** - * The key to value state interface. - * The interface is inspired by the Flink's ValueState. - */ +/** The key to value state interface. The interface is inspired by the Flink's ValueState. */ public interface KeyValueState extends KeyValueTrait, IState { - /** - * Get the value. - */ - V get(K key); - - /** - * Update the value. - */ - void put(K key, V value); + /** Get the value. */ + V get(K key); - /** - * Remove key. - */ - void remove(K key); + /** Update the value. */ + void put(K key, V value); + /** Remove key. */ + void remove(K key); } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/MultiVersionedRevisableState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/MultiVersionedRevisableState.java index bf18055ae..9c8b69963 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/MultiVersionedRevisableState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/MultiVersionedRevisableState.java @@ -21,34 +21,21 @@ import java.util.Collection; -/** - * The interface describe the state supporting multi versioned graph modification. - */ +/** The interface describe the state supporting multi versioned graph modification. */ public interface MultiVersionedRevisableState { - /** - * add state method with version and data. - */ - void add(long version, R r); - - /** - * update state method with version and data. - */ - void update(long version, R r); + /** add state method with version and data. */ + void add(long version, R r); - /** - * delete state method with version and data. - */ - void delete(long version, R r); + /** update state method with version and data. */ + void update(long version, R r); - /** - * delete state method with version and multi keys. - */ - void delete(long version, K... ids); + /** delete state method with version and data. */ + void delete(long version, R r); - /** - * delete state method with version and multi keys. - */ - void delete(long version, Collection ids); + /** delete state method with version and multi keys. */ + void delete(long version, K... ids); + /** delete state method with version and multi keys. */ + void delete(long version, Collection ids); } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/RevisableState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/RevisableState.java index d763da657..4be664544 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/RevisableState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/RevisableState.java @@ -21,33 +21,21 @@ import java.util.Collection; -/** - * The interface describe the state supporting modification. - */ +/** The interface describe the state supporting modification. */ public interface RevisableState { - /** - * Add state method with data. - */ - void add(R r); - - /** - * Update state method with data. - */ - void update(R r); - - /** - * Delete state method with data. - */ - void delete(R r); - - /** - * Delete state method with multi keys. - */ - void delete(K... ids); - - /** - * Delete state method with multi keys. - */ - void delete(Collection ids); + /** Add state method with data. */ + void add(R r); + + /** Update state method with data. */ + void update(R r); + + /** Delete state method with data. */ + void delete(R r); + + /** Delete state method with multi keys. */ + void delete(K... ids); + + /** Delete state method with multi keys. */ + void delete(Collection ids); } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/StaticEdgeState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/StaticEdgeState.java index 77f29dc97..74777c943 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/StaticEdgeState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/StaticEdgeState.java @@ -21,7 +21,5 @@ import org.apache.geaflow.model.graph.edge.IEdge; -public interface StaticEdgeState extends StaticQueryableState>, - RevisableState> { - -} +public interface StaticEdgeState + extends StaticQueryableState>, RevisableState> {} diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/StaticGraphState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/StaticGraphState.java index 504100bd5..aea5f6a64 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/StaticGraphState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/StaticGraphState.java @@ -20,25 +20,18 @@ package org.apache.geaflow.state; /** - * The static graph state is the interface controlling - * static vertex/edge or one degree subgraph. - * Static graph is an intact graph with only one snapshot - * comparing to the {@link DynamicGraphState}. + * The static graph state is the interface controlling static vertex/edge or one degree subgraph. + * Static graph is an intact graph with only one snapshot comparing to the {@link + * DynamicGraphState}. */ public interface StaticGraphState { - /** - * Returns the static vertex handler. - */ - StaticVertexState V(); + /** Returns the static vertex handler. */ + StaticVertexState V(); - /** - * Returns the static edge handler. - */ - StaticEdgeState E(); + /** Returns the static edge handler. */ + StaticEdgeState E(); - /** - * Returns the one degree handler. - */ - StaticOneDegreeGraphState VE(); + /** Returns the one degree handler. */ + StaticOneDegreeGraphState VE(); } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/StaticOneDegreeGraphState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/StaticOneDegreeGraphState.java index 4cd6100b8..3a2bbfd5b 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/StaticOneDegreeGraphState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/StaticOneDegreeGraphState.java @@ -22,6 +22,4 @@ import org.apache.geaflow.state.data.OneDegreeGraph; public interface StaticOneDegreeGraphState - extends StaticQueryableState> { - -} + extends StaticQueryableState> {} diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/StaticQueryableState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/StaticQueryableState.java index 3e1bda183..fe74ed84a 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/StaticQueryableState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/StaticQueryableState.java @@ -20,54 +20,36 @@ package org.apache.geaflow.state; import java.util.List; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.state.query.QueryableAllGraphState; import org.apache.geaflow.state.query.QueryableKeysGraphState; import org.apache.geaflow.utils.keygroup.KeyGroup; -/** - * The query interface for static graph. - */ +/** The query interface for static graph. */ public interface StaticQueryableState { - /** - * Returns the all graph handler. - */ - QueryableAllGraphState query(); - - /** - * Returns the all graph handler by KeyGroup. - */ - QueryableAllGraphState query(KeyGroup keyGroup); + /** Returns the all graph handler. */ + QueryableAllGraphState query(); - /** - * Returns the point query graph handler. - */ - QueryableKeysGraphState query(K id); + /** Returns the all graph handler by KeyGroup. */ + QueryableAllGraphState query(KeyGroup keyGroup); - /** - * Returns the point query graph handler. - */ - QueryableKeysGraphState query(K... ids); + /** Returns the point query graph handler. */ + QueryableKeysGraphState query(K id); - /** - * Returns the point query graph handler. - */ - QueryableKeysGraphState query(List ids); + /** Returns the point query graph handler. */ + QueryableKeysGraphState query(K... ids); - /** - * Returns the graph id iterator. - */ - CloseableIterator idIterator(); + /** Returns the point query graph handler. */ + QueryableKeysGraphState query(List ids); - /** - * Returns the graph query result iterator. - */ - CloseableIterator iterator(); + /** Returns the graph id iterator. */ + CloseableIterator idIterator(); - /** - * Returns the graph query result list. - */ - List asList(); + /** Returns the graph query result iterator. */ + CloseableIterator iterator(); + /** Returns the graph query result list. */ + List asList(); } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/StaticVertexState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/StaticVertexState.java index d9f548a8e..1cf9c98d9 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/StaticVertexState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/StaticVertexState.java @@ -21,7 +21,5 @@ import org.apache.geaflow.model.graph.vertex.IVertex; -public interface StaticVertexState extends StaticQueryableState>, - RevisableState> { - -} +public interface StaticVertexState + extends StaticQueryableState>, RevisableState> {} diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/context/StateContext.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/context/StateContext.java index f3bc864cb..1c0d2c6cd 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/context/StateContext.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/context/StateContext.java @@ -26,78 +26,74 @@ import org.apache.geaflow.state.graph.StateMode; import org.apache.geaflow.utils.keygroup.KeyGroup; -/** - * This class describe the external dependencies of state sub system. - */ +/** This class describe the external dependencies of state sub system. */ public class StateContext { - private BaseStateDescriptor descriptor; - private Configuration config; - private int shardId; - private boolean isLocalStore; - - public StateContext(BaseStateDescriptor descriptor, Configuration config) { - this.descriptor = descriptor; - this.config = config; - } - - public StateContext withShardId(int shardId) { - this.shardId = shardId; - return this; - } - - public StateContext withLocalStore(boolean localStore) { - this.isLocalStore = localStore; - return this; - } - - public String getName() { - return descriptor.getName(); - } - - public Configuration getConfig() { - return config; - } - - public MetricGroup getMetricGroup() { - return descriptor.getMetricGroup(); - } - - public KeyGroup getKeyGroup() { - return descriptor.getKeyGroup(); - } - - public String getStoreType() { - return descriptor.getStoreType(); - } - - public BaseStateDescriptor getDescriptor() { - return descriptor; - } - - public int getShardId() { - return shardId; - } - - public boolean isLocalStore() { - return isLocalStore; - } - - public int getTotalShardNum() { - return getDescriptor().getAssigner().getKeyGroupNumber(); - } - - public DataModel getDataModel() { - return this.descriptor.getDateModel(); - } - - public StateMode getStateMode() { - return this.descriptor.getStateMode(); - } - - public StateContext clone() { - return new StateContext(descriptor, config) - .withShardId(shardId) - .withLocalStore(isLocalStore); - } + private BaseStateDescriptor descriptor; + private Configuration config; + private int shardId; + private boolean isLocalStore; + + public StateContext(BaseStateDescriptor descriptor, Configuration config) { + this.descriptor = descriptor; + this.config = config; + } + + public StateContext withShardId(int shardId) { + this.shardId = shardId; + return this; + } + + public StateContext withLocalStore(boolean localStore) { + this.isLocalStore = localStore; + return this; + } + + public String getName() { + return descriptor.getName(); + } + + public Configuration getConfig() { + return config; + } + + public MetricGroup getMetricGroup() { + return descriptor.getMetricGroup(); + } + + public KeyGroup getKeyGroup() { + return descriptor.getKeyGroup(); + } + + public String getStoreType() { + return descriptor.getStoreType(); + } + + public BaseStateDescriptor getDescriptor() { + return descriptor; + } + + public int getShardId() { + return shardId; + } + + public boolean isLocalStore() { + return isLocalStore; + } + + public int getTotalShardNum() { + return getDescriptor().getAssigner().getKeyGroupNumber(); + } + + public DataModel getDataModel() { + return this.descriptor.getDateModel(); + } + + public StateMode getStateMode() { + return this.descriptor.getStateMode(); + } + + public StateContext clone() { + return new StateContext(descriptor, config).withShardId(shardId).withLocalStore(isLocalStore); + } } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/BaseKeyDescriptor.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/BaseKeyDescriptor.java index 2a15d68f9..cc0b67c3b 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/BaseKeyDescriptor.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/BaseKeyDescriptor.java @@ -23,18 +23,18 @@ public abstract class BaseKeyDescriptor extends BaseStateDescriptor { - protected Class keyClazz; - protected IKeySerializer keySerializer; + protected Class keyClazz; + protected IKeySerializer keySerializer; - protected BaseKeyDescriptor(String name, String storeType) { - super(name, storeType); - } + protected BaseKeyDescriptor(String name, String storeType) { + super(name, storeType); + } - public Class getKeyClazz() { - return keyClazz; - } + public Class getKeyClazz() { + return keyClazz; + } - public IKeySerializer getKeySerializer() { - return keySerializer; - } + public IKeySerializer getKeySerializer() { + return keySerializer; + } } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/BaseStateDescriptor.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/BaseStateDescriptor.java index 88b3937b1..9e0d50bb8 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/BaseStateDescriptor.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/BaseStateDescriptor.java @@ -22,6 +22,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import java.io.Serializable; + import org.apache.geaflow.metrics.common.api.MetricGroup; import org.apache.geaflow.state.DataModel; import org.apache.geaflow.state.graph.StateMode; @@ -30,76 +31,76 @@ public abstract class BaseStateDescriptor implements Serializable { - protected String name; - protected String storeType; - protected KeyGroup keyGroup; - protected MetricGroup metricGroup; - protected IKeyGroupAssigner assigner; - protected StateMode stateMode = StateMode.RW; - protected DataModel dateModel; - - protected BaseStateDescriptor(String name, String storeType) { - this.name = name; - this.storeType = storeType; - } - - public BaseStateDescriptor withName(String name) { - this.name = name; - return this; - } - - public BaseStateDescriptor withMetricGroup(MetricGroup metricGroup) { - this.metricGroup = metricGroup; - return this; - } - - public BaseStateDescriptor withKeyGroup(KeyGroup keyGroup) { - this.keyGroup = keyGroup; - return this; - } - - public BaseStateDescriptor withDataModel(DataModel dataModel) { - this.dateModel = dataModel; - return this; - } - - public BaseStateDescriptor withStateMode(StateMode stateMode) { - this.stateMode = stateMode; - return this; - } - - public BaseStateDescriptor withKeyGroupAssigner(IKeyGroupAssigner assigner) { - this.assigner = assigner; - return this; - } - - public StateMode getStateMode() { - return stateMode; - } - - public DataModel getDateModel() { - return dateModel; - } - - public MetricGroup getMetricGroup() { - return metricGroup; - } - - public KeyGroup getKeyGroup() { - return checkNotNull(keyGroup, "keyGroup must be set"); - } - - public String getName() { - return checkNotNull(name, "descriptor name must be set"); - } - - public String getStoreType() { - return checkNotNull(storeType, "storeType must be set"); - } - - public IKeyGroupAssigner getAssigner() { - return assigner; - } - - public abstract DescriptorType getDescriptorType(); + protected String name; + protected String storeType; + protected KeyGroup keyGroup; + protected MetricGroup metricGroup; + protected IKeyGroupAssigner assigner; + protected StateMode stateMode = StateMode.RW; + protected DataModel dateModel; + + protected BaseStateDescriptor(String name, String storeType) { + this.name = name; + this.storeType = storeType; + } + + public BaseStateDescriptor withName(String name) { + this.name = name; + return this; + } + + public BaseStateDescriptor withMetricGroup(MetricGroup metricGroup) { + this.metricGroup = metricGroup; + return this; + } + + public BaseStateDescriptor withKeyGroup(KeyGroup keyGroup) { + this.keyGroup = keyGroup; + return this; + } + + public BaseStateDescriptor withDataModel(DataModel dataModel) { + this.dateModel = dataModel; + return this; + } + + public BaseStateDescriptor withStateMode(StateMode stateMode) { + this.stateMode = stateMode; + return this; + } + + public BaseStateDescriptor withKeyGroupAssigner(IKeyGroupAssigner assigner) { + this.assigner = assigner; + return this; + } + + public StateMode getStateMode() { + return stateMode; + } + + public DataModel getDateModel() { + return dateModel; + } + + public MetricGroup getMetricGroup() { + return metricGroup; + } + + public KeyGroup getKeyGroup() { + return checkNotNull(keyGroup, "keyGroup must be set"); + } + + public String getName() { + return checkNotNull(name, "descriptor name must be set"); + } + + public String getStoreType() { + return checkNotNull(storeType, "storeType must be set"); + } + + public IKeyGroupAssigner getAssigner() { + return assigner; + } + + public abstract DescriptorType getDescriptorType(); } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/DescriptorType.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/DescriptorType.java index 1d6f060be..c2b1c533b 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/DescriptorType.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/DescriptorType.java @@ -20,20 +20,12 @@ package org.apache.geaflow.state.descriptor; public enum DescriptorType { - /** - * graph descriptor. - */ - GRAPH, - /** - * key value descriptor. - */ - KEY_VALUE, - /** - * key list descriptor. - */ - KEY_LIST, - /** - * key map descriptor. - */ - KEY_MAP + /** graph descriptor. */ + GRAPH, + /** key value descriptor. */ + KEY_VALUE, + /** key list descriptor. */ + KEY_LIST, + /** key map descriptor. */ + KEY_MAP } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/GraphStateDescriptor.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/GraphStateDescriptor.java index 2658afb00..91a9ed1db 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/GraphStateDescriptor.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/GraphStateDescriptor.java @@ -24,37 +24,37 @@ public class GraphStateDescriptor extends BaseStateDescriptor { - private GraphDataSchema graphSchema; - private boolean singleton = false; - - private GraphStateDescriptor(String name, String storeType) { - super(name, storeType); - } - - @Override - public DescriptorType getDescriptorType() { - return DescriptorType.GRAPH; - } - - public static GraphStateDescriptor build(String name, String storeType) { - return new GraphStateDescriptor<>(name, storeType); - } - - public GraphStateDescriptor withGraphMeta(GraphMeta descriptor) { - this.graphSchema = new GraphDataSchema(descriptor); - return this; - } - - public GraphStateDescriptor withSingleton() { - this.singleton = true; - return this; - } - - public boolean isSingleton() { - return this.singleton; - } - - public GraphDataSchema getGraphSchema() { - return graphSchema; - } + private GraphDataSchema graphSchema; + private boolean singleton = false; + + private GraphStateDescriptor(String name, String storeType) { + super(name, storeType); + } + + @Override + public DescriptorType getDescriptorType() { + return DescriptorType.GRAPH; + } + + public static GraphStateDescriptor build(String name, String storeType) { + return new GraphStateDescriptor<>(name, storeType); + } + + public GraphStateDescriptor withGraphMeta(GraphMeta descriptor) { + this.graphSchema = new GraphDataSchema(descriptor); + return this; + } + + public GraphStateDescriptor withSingleton() { + this.singleton = true; + return this; + } + + public boolean isSingleton() { + return this.singleton; + } + + public GraphDataSchema getGraphSchema() { + return graphSchema; + } } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/KeyListStateDescriptor.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/KeyListStateDescriptor.java index c621cb27d..a03fffd14 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/KeyListStateDescriptor.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/KeyListStateDescriptor.java @@ -23,33 +23,33 @@ public class KeyListStateDescriptor extends BaseKeyDescriptor { - private Class valueClazz; - - protected KeyListStateDescriptor(String name, String storeType) { - super(name, storeType); - } - - @Override - public DescriptorType getDescriptorType() { - return DescriptorType.KEY_LIST; - } - - public static KeyListStateDescriptor build(String name, String storeType) { - return new KeyListStateDescriptor<>(name, storeType); - } - - public KeyListStateDescriptor withTypeInfo(Class keyClazz, Class valueClazz) { - this.keyClazz = keyClazz; - this.valueClazz = valueClazz; - return this; - } - - public KeyListStateDescriptor withKeySerializer(IKeySerializer keySerializer) { - this.keySerializer = keySerializer; - return this; - } - - public Class getValueClazz() { - return valueClazz; - } + private Class valueClazz; + + protected KeyListStateDescriptor(String name, String storeType) { + super(name, storeType); + } + + @Override + public DescriptorType getDescriptorType() { + return DescriptorType.KEY_LIST; + } + + public static KeyListStateDescriptor build(String name, String storeType) { + return new KeyListStateDescriptor<>(name, storeType); + } + + public KeyListStateDescriptor withTypeInfo(Class keyClazz, Class valueClazz) { + this.keyClazz = keyClazz; + this.valueClazz = valueClazz; + return this; + } + + public KeyListStateDescriptor withKeySerializer(IKeySerializer keySerializer) { + this.keySerializer = keySerializer; + return this; + } + + public Class getValueClazz() { + return valueClazz; + } } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/KeyMapStateDescriptor.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/KeyMapStateDescriptor.java index 1867af66a..44b45c62a 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/KeyMapStateDescriptor.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/KeyMapStateDescriptor.java @@ -23,40 +23,40 @@ public class KeyMapStateDescriptor extends BaseKeyDescriptor { - private Class valueClazz; - private Class subKeyClazz; - - protected KeyMapStateDescriptor(String name, String storeType) { - super(name, storeType); - } - - @Override - public DescriptorType getDescriptorType() { - return DescriptorType.KEY_MAP; - } - - public static KeyMapStateDescriptor build(String name, - String storeType) { - return new KeyMapStateDescriptor<>(name, storeType); - } - - public KeyMapStateDescriptor withTypeInfo(Class keyClazz, Class subKeyClazz, Class valueClazz) { - this.keyClazz = keyClazz; - this.subKeyClazz = subKeyClazz; - this.valueClazz = valueClazz; - return this; - } - - public KeyMapStateDescriptor withKeySerializer(IKeySerializer kvSerializer) { - this.keySerializer = kvSerializer; - return this; - } - - public Class getValueClazz() { - return valueClazz; - } - - public Class getSubKeyClazz() { - return subKeyClazz; - } + private Class valueClazz; + private Class subKeyClazz; + + protected KeyMapStateDescriptor(String name, String storeType) { + super(name, storeType); + } + + @Override + public DescriptorType getDescriptorType() { + return DescriptorType.KEY_MAP; + } + + public static KeyMapStateDescriptor build(String name, String storeType) { + return new KeyMapStateDescriptor<>(name, storeType); + } + + public KeyMapStateDescriptor withTypeInfo( + Class keyClazz, Class subKeyClazz, Class valueClazz) { + this.keyClazz = keyClazz; + this.subKeyClazz = subKeyClazz; + this.valueClazz = valueClazz; + return this; + } + + public KeyMapStateDescriptor withKeySerializer(IKeySerializer kvSerializer) { + this.keySerializer = kvSerializer; + return this; + } + + public Class getValueClazz() { + return valueClazz; + } + + public Class getSubKeyClazz() { + return subKeyClazz; + } } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/KeyValueStateDescriptor.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/KeyValueStateDescriptor.java index f7c559884..9bffef98c 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/KeyValueStateDescriptor.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/descriptor/KeyValueStateDescriptor.java @@ -20,51 +20,52 @@ package org.apache.geaflow.state.descriptor; import java.util.function.Supplier; + import org.apache.geaflow.state.serializer.IKVSerializer; public class KeyValueStateDescriptor extends BaseKeyDescriptor { - private Supplier defaultValueSupplier; - private Class valueClazz; + private Supplier defaultValueSupplier; + private Class valueClazz; - protected KeyValueStateDescriptor(String name, String storeType) { - super(name, storeType); - } + protected KeyValueStateDescriptor(String name, String storeType) { + super(name, storeType); + } - @Override - public DescriptorType getDescriptorType() { - return DescriptorType.KEY_VALUE; - } + @Override + public DescriptorType getDescriptorType() { + return DescriptorType.KEY_VALUE; + } - public static KeyValueStateDescriptor build(String name, String storeType) { - return new KeyValueStateDescriptor<>(name, storeType); - } + public static KeyValueStateDescriptor build(String name, String storeType) { + return new KeyValueStateDescriptor<>(name, storeType); + } - public KeyValueStateDescriptor withDefaultValue(Supplier valueSupplier) { - this.defaultValueSupplier = valueSupplier; - return this; - } + public KeyValueStateDescriptor withDefaultValue(Supplier valueSupplier) { + this.defaultValueSupplier = valueSupplier; + return this; + } - public KeyValueStateDescriptor withTypeInfo(Class keyClazz, Class valueClazz) { - this.keyClazz = keyClazz; - this.valueClazz = valueClazz; - return this; - } + public KeyValueStateDescriptor withTypeInfo(Class keyClazz, Class valueClazz) { + this.keyClazz = keyClazz; + this.valueClazz = valueClazz; + return this; + } - public KeyValueStateDescriptor withKVSerializer(IKVSerializer kvSerializer) { - this.keySerializer = kvSerializer; - return this; - } + public KeyValueStateDescriptor withKVSerializer(IKVSerializer kvSerializer) { + this.keySerializer = kvSerializer; + return this; + } - public V getDefaultValue() { - if (defaultValueSupplier == null) { - return null; - } - // It is better to get a default copy. - return defaultValueSupplier.get(); + public V getDefaultValue() { + if (defaultValueSupplier == null) { + return null; } + // It is better to get a default copy. + return defaultValueSupplier.get(); + } - public Class getValueClazz() { - return valueClazz; - } + public Class getValueClazz() { + return valueClazz; + } } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/GraphStateSummary.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/GraphStateSummary.java index d833ba473..39b589f7c 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/GraphStateSummary.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/GraphStateSummary.java @@ -21,5 +21,5 @@ public interface GraphStateSummary { - SummaryResult collect(); + SummaryResult collect(); } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/LoadOption.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/LoadOption.java index dc14f0093..c7d43003d 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/LoadOption.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/LoadOption.java @@ -23,49 +23,47 @@ public class LoadOption { - private KeyGroup keyGroup; - private long checkPointId; - private LoadEnum loadEnum = LoadEnum.REMOTE_TO_DISK; + private KeyGroup keyGroup; + private long checkPointId; + private LoadEnum loadEnum = LoadEnum.REMOTE_TO_DISK; - private LoadOption() { + private LoadOption() {} - } + public static LoadOption of() { + return new LoadOption(); + } - public static LoadOption of() { - return new LoadOption(); - } + public LoadOption withKeyGroup(KeyGroup keyGroup) { + this.keyGroup = keyGroup; + return this; + } - public LoadOption withKeyGroup(KeyGroup keyGroup) { - this.keyGroup = keyGroup; - return this; - } + public LoadOption withLoadEnum(LoadEnum loadEnum) { + this.loadEnum = loadEnum; + return this; + } - public LoadOption withLoadEnum(LoadEnum loadEnum) { - this.loadEnum = loadEnum; - return this; - } + public LoadOption withCheckpointId(long checkpointId) { + this.checkPointId = checkpointId; + return this; + } - public LoadOption withCheckpointId(long checkpointId) { - this.checkPointId = checkpointId; - return this; - } + public KeyGroup getKeyGroup() { + return keyGroup; + } - public KeyGroup getKeyGroup() { - return keyGroup; - } + public LoadEnum getLoadEnum() { + return loadEnum; + } - public LoadEnum getLoadEnum() { - return loadEnum; - } + public long getCheckPointId() { + return checkPointId; + } - public long getCheckPointId() { - return checkPointId; - } - - public enum LoadEnum { - // Download remote files to local disk. - REMOTE_TO_DISK, - // Download remote files to local disk and load to memory. - REMOTE_TO_MEM - } + public enum LoadEnum { + // Download remote files to local disk. + REMOTE_TO_DISK, + // Download remote files to local disk and load to memory. + REMOTE_TO_MEM + } } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/ManageableGraphState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/ManageableGraphState.java index d4fcc1195..ed617eb53 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/ManageableGraphState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/ManageableGraphState.java @@ -21,7 +21,7 @@ public interface ManageableGraphState extends ManageableState { - GraphStateSummary summary(); + GraphStateSummary summary(); - StateMetric metric(); + StateMetric metric(); } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/ManageableState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/ManageableState.java index dc37e8417..4712cbfe2 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/ManageableState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/ManageableState.java @@ -21,6 +21,5 @@ public interface ManageableState { - StateOperator operate(); - + StateOperator operate(); } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/MetricResult.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/MetricResult.java index 16b29410e..dd9a0f6ab 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/MetricResult.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/MetricResult.java @@ -19,6 +19,4 @@ package org.apache.geaflow.state.manage; -public class MetricResult { - -} +public class MetricResult {} diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/StateMetric.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/StateMetric.java index 0bd24d1eb..09f3b9e81 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/StateMetric.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/StateMetric.java @@ -21,5 +21,5 @@ public interface StateMetric { - MetricResult collect(); + MetricResult collect(); } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/StateOperator.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/StateOperator.java index f51438910..1cc96c002 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/StateOperator.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/StateOperator.java @@ -19,48 +19,30 @@ package org.apache.geaflow.state.manage; -/** - * The state operator class. - */ +/** The state operator class. */ public interface StateOperator { - /** - * Load state. - */ - void load(LoadOption loadOption); + /** Load state. */ + void load(LoadOption loadOption); - /** - * Set checkpoint id to state. - */ - void setCheckpointId(long checkpointId); + /** Set checkpoint id to state. */ + void setCheckpointId(long checkpointId); - /** - * Flush state to disk. - */ - void finish(); + /** Flush state to disk. */ + void finish(); - /** - * Compact state data. - */ - void compact(); + /** Compact state data. */ + void compact(); - /** - * Persist data. - */ - void archive(); + /** Persist data. */ + void archive(); - /** - * Recover data from persistent storage. - */ - void recover(); + /** Recover data from persistent storage. */ + void recover(); - /** - * Close state and release used resource. - */ - void close(); + /** Close state and release used resource. */ + void close(); - /** - * Drop disk data and close. - */ - void drop(); + /** Drop disk data and close. */ + void drop(); } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/SummaryResult.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/SummaryResult.java index 24b2b18d3..ee0ae2d89 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/SummaryResult.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/manage/SummaryResult.java @@ -19,6 +19,4 @@ package org.apache.geaflow.state.manage; -public class SummaryResult { - -} +public class SummaryResult {} diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/query/QueryType.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/query/QueryType.java index 892bb1149..63bf2e580 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/query/QueryType.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/query/QueryType.java @@ -23,13 +23,13 @@ public class QueryType { - private final DataType type; + private final DataType type; - public QueryType(DataType type) { - this.type = type; - } + public QueryType(DataType type) { + this.type = type; + } - public DataType getType() { - return type; - } + public DataType getType() { + return type; + } } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/query/QueryableAllGraphState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/query/QueryableAllGraphState.java index 192b44787..01688899d 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/query/QueryableAllGraphState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/query/QueryableAllGraphState.java @@ -23,8 +23,6 @@ public interface QueryableAllGraphState extends QueryableGraphState { - /** - * Query by a filter, sharing by all the keys. - */ - QueryableGraphState by(IFilter filter); + /** Query by a filter, sharing by all the keys. */ + QueryableGraphState by(IFilter filter); } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/query/QueryableGraphState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/query/QueryableGraphState.java index b1aa4c7bd..4b4412863 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/query/QueryableGraphState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/query/QueryableGraphState.java @@ -21,52 +21,35 @@ import java.util.List; import java.util.Map; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.state.graph.encoder.EdgeAtom; import org.apache.geaflow.state.pushdown.project.IProjector; -/** - * The base query interface for graph. - */ +/** The base query interface for graph. */ public interface QueryableGraphState { - /** - * Query by some projector. - */ - QueryableGraphState select(IProjector projector); + /** Query by some projector. */ + QueryableGraphState select(IProjector projector); - /** - * Query by edge limit. - */ - QueryableGraphState limit(long out, long in); + /** Query by edge limit. */ + QueryableGraphState limit(long out, long in); - /** - * Query by some order. - */ - QueryableGraphState orderBy(EdgeAtom atom); + /** Query by some order. */ + QueryableGraphState orderBy(EdgeAtom atom); - /** - * Query a aggregate result. - */ - Map aggregate(); + /** Query a aggregate result. */ + Map aggregate(); - /** - * Query result is a list. - */ - List asList(); + /** Query result is a list. */ + List asList(); - /** - * Get id Iterator. - */ - CloseableIterator idIterator(); + /** Get id Iterator. */ + CloseableIterator idIterator(); - /** - * Query result is a iterator. - */ - CloseableIterator iterator(); + /** Query result is a iterator. */ + CloseableIterator iterator(); - /** - * Get a simple result like a vertex. - */ - R get(); + /** Get a simple result like a vertex. */ + R get(); } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/query/QueryableKeysGraphState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/query/QueryableKeysGraphState.java index c9e55690e..ec580e5cc 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/query/QueryableKeysGraphState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/query/QueryableKeysGraphState.java @@ -21,18 +21,12 @@ import org.apache.geaflow.state.pushdown.filter.IFilter; -/** - * The point query interface for graph. - */ +/** The point query interface for graph. */ public interface QueryableKeysGraphState extends QueryableGraphState { - /** - * Query by the filters, corresponding to the search keys. - */ - QueryableGraphState by(IFilter... filter); + /** Query by the filters, corresponding to the search keys. */ + QueryableGraphState by(IFilter... filter); - /** - * Query by a filter, sharing by the search keys. - */ - QueryableGraphState by(IFilter filter); + /** Query by a filter, sharing by the search keys. */ + QueryableGraphState by(IFilter filter); } diff --git a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/query/QueryableVersionGraphState.java b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/query/QueryableVersionGraphState.java index 2cfac7a1d..f1827ffdd 100644 --- a/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/query/QueryableVersionGraphState.java +++ b/geaflow/geaflow-state/geaflow-state-api/src/main/java/org/apache/geaflow/state/query/QueryableVersionGraphState.java @@ -20,20 +20,15 @@ package org.apache.geaflow.state.query; import java.util.Map; + import org.apache.geaflow.state.pushdown.filter.IFilter; -/** - * The dynamic graph query interface. - */ +/** The dynamic graph query interface. */ public interface QueryableVersionGraphState { - /** - * Query by a filter, sharing by the search keys. - */ - QueryableVersionGraphState by(IFilter filter); + /** Query by a filter, sharing by the search keys. */ + QueryableVersionGraphState by(IFilter filter); - /** - * Query result is a map, which key is the version. - */ - Map asMap(); + /** Query result is a map, which key is the version. */ + Map asMap(); } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/StoreType.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/StoreType.java index 0e22e46af..fe8519bb3 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/StoreType.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/StoreType.java @@ -21,37 +21,25 @@ public enum StoreType { - /** - * MEMORY. - */ - MEMORY, - /** - * ROCKSDB. - */ - ROCKSDB, - /** - * HBASE. - */ - HBASE, - /** - * REDIS. - */ - REDIS, - /** - * JDBC. - */ - JDBC, - /** - * PAIMON (Experimental). - */ - PAIMON; + /** MEMORY. */ + MEMORY, + /** ROCKSDB. */ + ROCKSDB, + /** HBASE. */ + HBASE, + /** REDIS. */ + REDIS, + /** JDBC. */ + JDBC, + /** PAIMON (Experimental). */ + PAIMON; - public static StoreType getEnum(String value) { - for (StoreType v : values()) { - if (v.name().equalsIgnoreCase(value)) { - return v; - } - } - return MEMORY; + public static StoreType getEnum(String value) { + for (StoreType v : values()) { + if (v.name().equalsIgnoreCase(value)) { + return v; + } } + return MEMORY; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/data/DataType.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/data/DataType.java index 247075141..01ef053a1 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/data/DataType.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/data/DataType.java @@ -20,40 +20,22 @@ package org.apache.geaflow.state.data; public enum DataType { - /** - * VType. - */ - V, - /** - * EType. - */ - E, - /** - * VEType. - */ - VE, - /** - * VID Type. - */ - VID, - /** - * Vertex topology type. - */ - V_TOPO, - /** - * Edge topology type. - */ - E_TOPO, - /** - * Vertex and Edge topology type. - */ - VE_TOPO, - /** - * Project field. - */ - PROJECT_FIELD, - /** - * Unknown type. - */ - OTHER; + /** VType. */ + V, + /** EType. */ + E, + /** VEType. */ + VE, + /** VID Type. */ + VID, + /** Vertex topology type. */ + V_TOPO, + /** Edge topology type. */ + E_TOPO, + /** Vertex and Edge topology type. */ + VE_TOPO, + /** Project field. */ + PROJECT_FIELD, + /** Unknown type. */ + OTHER; } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/data/OneDegreeGraph.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/data/OneDegreeGraph.java index 43552c9ca..7150e4938 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/data/OneDegreeGraph.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/data/OneDegreeGraph.java @@ -20,32 +20,33 @@ package org.apache.geaflow.state.data; import java.io.Serializable; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; public class OneDegreeGraph implements Serializable { - private IVertex vertex; - protected CloseableIterator> edgeIterator; - protected K key; + private IVertex vertex; + protected CloseableIterator> edgeIterator; + protected K key; - public OneDegreeGraph(K key, IVertex vertex, CloseableIterator> edgeIterator) { - this.key = key; - this.vertex = vertex; - this.edgeIterator = edgeIterator; - } + public OneDegreeGraph( + K key, IVertex vertex, CloseableIterator> edgeIterator) { + this.key = key; + this.vertex = vertex; + this.edgeIterator = edgeIterator; + } - public K getKey() { - return key; - } + public K getKey() { + return key; + } - public IVertex getVertex() { - return vertex; - } + public IVertex getVertex() { + return vertex; + } - public CloseableIterator> getEdgeIterator() { - return edgeIterator; - } + public CloseableIterator> getEdgeIterator() { + return edgeIterator; + } } - diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/data/TimeRange.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/data/TimeRange.java index e17e43d05..d1ec98985 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/data/TimeRange.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/data/TimeRange.java @@ -20,66 +20,63 @@ package org.apache.geaflow.state.data; import java.io.Serializable; + import org.apache.geaflow.utils.math.MathUtil; public class TimeRange implements Comparable, Serializable { - private final long start; - private final long end; + private final long start; + private final long end; - private TimeRange(long start, long end) { - this.start = start; - this.end = end; - } + private TimeRange(long start, long end) { + this.start = start; + this.end = end; + } - /** - * Return a TimeRange from start(INCLUSIVE) to end(EXCLUSIVE). - */ - public static TimeRange of(long start, long end) { - return new TimeRange(start, end); - } + /** Return a TimeRange from start(INCLUSIVE) to end(EXCLUSIVE). */ + public static TimeRange of(long start, long end) { + return new TimeRange(start, end); + } - public long getStart() { - return start; - } + public long getStart() { + return start; + } - public long getEnd() { - return end; - } + public long getEnd() { + return end; + } - @Override - public int hashCode() { - return MathUtil.longToIntWithBitMixing(start + end); - } + @Override + public int hashCode() { + return MathUtil.longToIntWithBitMixing(start + end); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } - TimeRange range = (TimeRange) o; + TimeRange range = (TimeRange) o; - return end == range.end && start == range.start; - } + return end == range.end && start == range.start; + } - @Override - public String toString() { - return String.format("TimeRange{start=%d, end=%d}", start, end); - } + @Override + public String toString() { + return String.format("TimeRange{start=%d, end=%d}", start, end); + } - /** - * Returns {@code true} if this range contain the given ts. - */ - public boolean contain(long ts) { - return ts >= start && ts < end; - } + /** Returns {@code true} if this range contain the given ts. */ + public boolean contain(long ts) { + return ts >= start && ts < end; + } - @Override - public int compareTo(TimeRange o) { - return Long.compare(end, o.getEnd()); - } + @Override + public int compareTo(TimeRange o) { + return Long.compare(end, o.getEnd()); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/DynamicGraphTrait.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/DynamicGraphTrait.java index 08e330720..f4c6def77 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/DynamicGraphTrait.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/DynamicGraphTrait.java @@ -22,6 +22,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; @@ -29,94 +30,64 @@ import org.apache.geaflow.state.data.OneDegreeGraph; import org.apache.geaflow.state.pushdown.IStatePushDown; -/** - * This interface describes the dynamic graph traits. - */ +/** This interface describes the dynamic graph traits. */ public interface DynamicGraphTrait { - /** - * Add the edge and its corresponding version. - */ - void addEdge(long version, IEdge edge); - - /** - * Add the vertex and its corresponding version. - */ - void addVertex(long version, IVertex vertex); - - /** - * Fetch the vertex according to the version, id and pushdown condition. - */ - IVertex getVertex(long version, K sid, IStatePushDown pushdown); - - /** - * Fetch the edge list according to the version, id and pushdown condition. - */ - List> getEdges(long version, K sid, IStatePushDown pushdown); - - /** - * Fetch the one degree graph according to the version, id and pushdown condition. - */ - OneDegreeGraph getOneDegreeGraph(long version, K sid, IStatePushDown pushdown); - - /** - * Fetch the iterator of the ids of all graph vertices. - */ - CloseableIterator vertexIDIterator(); - - /** - * Fetch the iterator of the ids of all graph vertices by pushdown condition. - */ - CloseableIterator vertexIDIterator(long version, IStatePushDown pushdown); - - /** - * Fetch the iterator of the graph vertices according to the version and pushdown condition. - */ - CloseableIterator> getVertexIterator(long version, IStatePushDown pushdown); - - /** - * Fetch the iterator of some vertices according to the version, ids and pushdown condition. - */ - CloseableIterator> getVertexIterator(long version, List keys, IStatePushDown pushdown); - - /** - * Fetch the iterator of graph edges according to the version and pushdown condition. - */ - CloseableIterator> getEdgeIterator(long version, IStatePushDown pushdown); - - /** - * Fetch the iterator of some edges according to the version, ids and pushdown condition. - */ - CloseableIterator> getEdgeIterator(long version, List keys, IStatePushDown pushdown); - - /** - * Fetch the iterator of one degree graph according to the version and pushdown condition. - */ - CloseableIterator> getOneDegreeGraphIterator(long version, IStatePushDown pushdown); - - /** - * Fetch the iterator of one degree graph according to the version, ids and pushdown condition. - */ - CloseableIterator> getOneDegreeGraphIterator(long version, List keys, IStatePushDown pushdown); - - /** - * Fetch the versions of some id. - */ - List getAllVersions(K id, DataType dataType); - - /** - * Fetch the latest version of some id. - */ - long getLatestVersion(K id, DataType dataType); - - /** - * Fetch all versioned data by id and pushdown condition. - */ - Map> getAllVersionData(K id, IStatePushDown pushdown, DataType dataType); - - /** - * Fetch some specific versioned data by id and pushdown condition. - */ - Map> getVersionData(K id, Collection versions, - IStatePushDown pushdown, DataType dataType); + /** Add the edge and its corresponding version. */ + void addEdge(long version, IEdge edge); + + /** Add the vertex and its corresponding version. */ + void addVertex(long version, IVertex vertex); + + /** Fetch the vertex according to the version, id and pushdown condition. */ + IVertex getVertex(long version, K sid, IStatePushDown pushdown); + + /** Fetch the edge list according to the version, id and pushdown condition. */ + List> getEdges(long version, K sid, IStatePushDown pushdown); + + /** Fetch the one degree graph according to the version, id and pushdown condition. */ + OneDegreeGraph getOneDegreeGraph(long version, K sid, IStatePushDown pushdown); + + /** Fetch the iterator of the ids of all graph vertices. */ + CloseableIterator vertexIDIterator(); + + /** Fetch the iterator of the ids of all graph vertices by pushdown condition. */ + CloseableIterator vertexIDIterator(long version, IStatePushDown pushdown); + + /** Fetch the iterator of the graph vertices according to the version and pushdown condition. */ + CloseableIterator> getVertexIterator(long version, IStatePushDown pushdown); + + /** Fetch the iterator of some vertices according to the version, ids and pushdown condition. */ + CloseableIterator> getVertexIterator( + long version, List keys, IStatePushDown pushdown); + + /** Fetch the iterator of graph edges according to the version and pushdown condition. */ + CloseableIterator> getEdgeIterator(long version, IStatePushDown pushdown); + + /** Fetch the iterator of some edges according to the version, ids and pushdown condition. */ + CloseableIterator> getEdgeIterator( + long version, List keys, IStatePushDown pushdown); + + /** Fetch the iterator of one degree graph according to the version and pushdown condition. */ + CloseableIterator> getOneDegreeGraphIterator( + long version, IStatePushDown pushdown); + + /** + * Fetch the iterator of one degree graph according to the version, ids and pushdown condition. + */ + CloseableIterator> getOneDegreeGraphIterator( + long version, List keys, IStatePushDown pushdown); + + /** Fetch the versions of some id. */ + List getAllVersions(K id, DataType dataType); + + /** Fetch the latest version of some id. */ + long getLatestVersion(K id, DataType dataType); + + /** Fetch all versioned data by id and pushdown condition. */ + Map> getAllVersionData(K id, IStatePushDown pushdown, DataType dataType); + + /** Fetch some specific versioned data by id and pushdown condition. */ + Map> getVersionData( + K id, Collection versions, IStatePushDown pushdown, DataType dataType); } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/StateMode.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/StateMode.java index 40e31a3a3..989e0169d 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/StateMode.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/StateMode.java @@ -22,35 +22,29 @@ import org.apache.geaflow.common.exception.GeaflowRuntimeException; public enum StateMode { - /** - * normal state. - */ - RW((short) 0), - /** - * read only state for state sharing. - */ - RDONLY((short) 1), - /** - * copy on write state for cold start. - */ - COW((short) 2); + /** normal state. */ + RW((short) 0), + /** read only state for state sharing. */ + RDONLY((short) 1), + /** copy on write state for cold start. */ + COW((short) 2); - private final short flag; + private final short flag; - StateMode(short flag) { - this.flag = flag; - } + StateMode(short flag) { + this.flag = flag; + } - public short value() { - return flag; - } + public short value() { + return flag; + } - public static StateMode getEnum(String value) { - for (StateMode v : values()) { - if (v.name().equalsIgnoreCase(value)) { - return v; - } - } - throw new GeaflowRuntimeException("not support " + value); + public static StateMode getEnum(String value) { + for (StateMode v : values()) { + if (v.name().equalsIgnoreCase(value)) { + return v; + } } + throw new GeaflowRuntimeException("not support " + value); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/StaticGraphTrait.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/StaticGraphTrait.java index 0a38bb643..72436088f 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/StaticGraphTrait.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/StaticGraphTrait.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.Map; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.common.tuple.Tuple; import org.apache.geaflow.model.graph.edge.IEdge; @@ -28,97 +29,60 @@ import org.apache.geaflow.state.data.OneDegreeGraph; import org.apache.geaflow.state.pushdown.IStatePushDown; -/** - * This interface describes the static graph traits. - */ +/** This interface describes the static graph traits. */ public interface StaticGraphTrait { - /** - * Add the edge to the static graph. - */ - void addEdge(IEdge edge); - - /** - * Add the vertex to the static graph. - */ - void addVertex(IVertex vertex); - - /** - * Fetch the vertex according to the id and pushdown condition. - */ - IVertex getVertex(K sid, IStatePushDown pushdown); - - /** - * Fetch the edges according to the id and pushdown condition. - */ - List> getEdges(K sid, IStatePushDown pushdown); - - /** - * Fetch the one degree graph according to the id and pushdown condition. - */ - OneDegreeGraph getOneDegreeGraph(K sid, IStatePushDown pushdown); - - /** - * Fetch the iterator of the ids of all graph vertices. - */ - CloseableIterator vertexIDIterator(); - - /** - * Fetch the iterator of the ids of all graph vertices by pushdown condition. - */ - CloseableIterator vertexIDIterator(IStatePushDown pushDown); - - /** - * Fetch the iterator of the graph vertices according to the pushdown condition. - */ - CloseableIterator> getVertexIterator(IStatePushDown pushdown); - - /** - * Fetch the iterator of some vertices according to the ids and pushdown condition. - */ - CloseableIterator> getVertexIterator(List keys, IStatePushDown pushdown); - - /** - * Fetch the iterator of graph edges according to the pushdown condition. - */ - CloseableIterator> getEdgeIterator(IStatePushDown pushdown); - - /** - * Fetch the iterator of the graph edges according to the ids and pushdown condition. - */ - CloseableIterator> getEdgeIterator(List keys, IStatePushDown pushdown); - - /** - * Fetch the iterator of one degree graph according to the pushdown condition. - */ - CloseableIterator> getOneDegreeGraphIterator(IStatePushDown pushdown); - - /** - * Fetch the iterator of one degree graph according to the ids and pushdown condition. - */ - CloseableIterator> getOneDegreeGraphIterator(List keys, - IStatePushDown pushdown); - - - /** - * Fetch the project result of edges according to the pushdown condition. - */ - CloseableIterator> getEdgeProjectIterator( - IStatePushDown, R> pushdown); - - /** - * Fetch the project result of edges according to the ids and pushdown condition. - */ - CloseableIterator> getEdgeProjectIterator(List keys, - IStatePushDown, R> pushdown); - - /** - * Fetch the aggregated results according to the pushdown condition. - */ - Map getAggResult(IStatePushDown pushdown); - - /** - * Fetch the aggregated results according to the ids and pushdown condition. - */ - Map getAggResult(List keys, IStatePushDown pushdown); + /** Add the edge to the static graph. */ + void addEdge(IEdge edge); + + /** Add the vertex to the static graph. */ + void addVertex(IVertex vertex); + + /** Fetch the vertex according to the id and pushdown condition. */ + IVertex getVertex(K sid, IStatePushDown pushdown); + + /** Fetch the edges according to the id and pushdown condition. */ + List> getEdges(K sid, IStatePushDown pushdown); + + /** Fetch the one degree graph according to the id and pushdown condition. */ + OneDegreeGraph getOneDegreeGraph(K sid, IStatePushDown pushdown); + + /** Fetch the iterator of the ids of all graph vertices. */ + CloseableIterator vertexIDIterator(); + + /** Fetch the iterator of the ids of all graph vertices by pushdown condition. */ + CloseableIterator vertexIDIterator(IStatePushDown pushDown); + + /** Fetch the iterator of the graph vertices according to the pushdown condition. */ + CloseableIterator> getVertexIterator(IStatePushDown pushdown); + + /** Fetch the iterator of some vertices according to the ids and pushdown condition. */ + CloseableIterator> getVertexIterator(List keys, IStatePushDown pushdown); + + /** Fetch the iterator of graph edges according to the pushdown condition. */ + CloseableIterator> getEdgeIterator(IStatePushDown pushdown); + + /** Fetch the iterator of the graph edges according to the ids and pushdown condition. */ + CloseableIterator> getEdgeIterator(List keys, IStatePushDown pushdown); + + /** Fetch the iterator of one degree graph according to the pushdown condition. */ + CloseableIterator> getOneDegreeGraphIterator(IStatePushDown pushdown); + + /** Fetch the iterator of one degree graph according to the ids and pushdown condition. */ + CloseableIterator> getOneDegreeGraphIterator( + List keys, IStatePushDown pushdown); + + /** Fetch the project result of edges according to the pushdown condition. */ + CloseableIterator> getEdgeProjectIterator( + IStatePushDown, R> pushdown); + + /** Fetch the project result of edges according to the ids and pushdown condition. */ + CloseableIterator> getEdgeProjectIterator( + List keys, IStatePushDown, R> pushdown); + + /** Fetch the aggregated results according to the pushdown condition. */ + Map getAggResult(IStatePushDown pushdown); + + /** Fetch the aggregated results according to the ids and pushdown condition. */ + Map getAggResult(List keys, IStatePushDown pushdown); } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/BaseBytesEncoder.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/BaseBytesEncoder.java index 37c1fed4e..e88128f15 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/BaseBytesEncoder.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/BaseBytesEncoder.java @@ -21,15 +21,15 @@ public abstract class BaseBytesEncoder implements IBytesEncoder { - protected static final int BYTE_SHIFT = 4; - private static final byte MAGIC_MASK = 0x0F; + protected static final int BYTE_SHIFT = 4; + private static final byte MAGIC_MASK = 0x0F; - protected byte combine(byte x, byte y) { - return (byte) ((x << BYTE_SHIFT) | y); - } + protected byte combine(byte x, byte y) { + return (byte) ((x << BYTE_SHIFT) | y); + } - @Override - public byte parseMagicNumber(byte b) { - return (byte) (b & MAGIC_MASK); - } + @Override + public byte parseMagicNumber(byte b) { + return (byte) (b & MAGIC_MASK); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/BytesEncoderRepo.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/BytesEncoderRepo.java index 15e9f4314..e402f17b7 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/BytesEncoderRepo.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/BytesEncoderRepo.java @@ -19,24 +19,25 @@ package org.apache.geaflow.state.graph.encoder; -import com.google.common.base.Preconditions; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import com.google.common.base.Preconditions; + public class BytesEncoderRepo { - private static final Map ENCODER_MAP = new ConcurrentHashMap<>(); + private static final Map ENCODER_MAP = new ConcurrentHashMap<>(); - public static void register(IBytesEncoder encoder) { - ENCODER_MAP.put(encoder.getMyMagicNumber(), encoder); - } + public static void register(IBytesEncoder encoder) { + ENCODER_MAP.put(encoder.getMyMagicNumber(), encoder); + } - public static IBytesEncoder get(byte myMagicNumber) { - return Preconditions.checkNotNull(ENCODER_MAP.get(myMagicNumber), - "not found encoder " + myMagicNumber); - } + public static IBytesEncoder get(byte myMagicNumber) { + return Preconditions.checkNotNull( + ENCODER_MAP.get(myMagicNumber), "not found encoder " + myMagicNumber); + } - static { - register(new DefaultBytesEncoder()); - } + static { + register(new DefaultBytesEncoder()); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/DefaultBytesEncoder.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/DefaultBytesEncoder.java index 0b5d40e89..ed452f534 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/DefaultBytesEncoder.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/DefaultBytesEncoder.java @@ -25,94 +25,95 @@ public class DefaultBytesEncoder extends BaseBytesEncoder implements IBytesEncoder { - private final byte magicNumber = 0x01; + private final byte magicNumber = 0x01; - @Override - public byte[] combine(List listBytes) { - int len = listBytes.stream().mapToInt(c -> c.length).sum(); - ByteBuffer bf = ByteBuffer.wrap(new byte[len + Short.BYTES * listBytes.size() + Byte.BYTES]); - for (byte[] bytes : listBytes) { - bf.put(bytes); - } - for (byte[] bytes : listBytes) { - bf.putShort((short) bytes.length); - } - bf.put(combine((byte) listBytes.size(), magicNumber)); - return bf.array(); + @Override + public byte[] combine(List listBytes) { + int len = listBytes.stream().mapToInt(c -> c.length).sum(); + ByteBuffer bf = ByteBuffer.wrap(new byte[len + Short.BYTES * listBytes.size() + Byte.BYTES]); + for (byte[] bytes : listBytes) { + bf.put(bytes); } + for (byte[] bytes : listBytes) { + bf.putShort((short) bytes.length); + } + bf.put(combine((byte) listBytes.size(), magicNumber)); + return bf.array(); + } - @Override - public List split(byte[] bytes) { - ByteBuffer bf = ByteBuffer.wrap(bytes); - bf.position(bytes.length - Byte.BYTES); - byte lastByte = bf.get(); - if (magicNumber != parseMagicNumber(lastByte)) { - return null; - } - int size = lastByte >> BYTE_SHIFT; - int pos = bytes.length - Short.BYTES * size - Byte.BYTES; - short[] lenArray = new short[size]; - bf.position(pos); - for (int i = 0; i < size; i++) { - lenArray[i] = bf.getShort(); - } + @Override + public List split(byte[] bytes) { + ByteBuffer bf = ByteBuffer.wrap(bytes); + bf.position(bytes.length - Byte.BYTES); + byte lastByte = bf.get(); + if (magicNumber != parseMagicNumber(lastByte)) { + return null; + } + int size = lastByte >> BYTE_SHIFT; + int pos = bytes.length - Short.BYTES * size - Byte.BYTES; + short[] lenArray = new short[size]; + bf.position(pos); + for (int i = 0; i < size; i++) { + lenArray[i] = bf.getShort(); + } - bf.position(0); - List list = new ArrayList<>(size); - for (int i = 0; i < size; i++) { - byte[] tmp = new byte[lenArray[i]]; - bf.get(tmp); - list.add(tmp); - } - return list; + bf.position(0); + List list = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + byte[] tmp = new byte[lenArray[i]]; + bf.get(tmp); + list.add(tmp); } + return list; + } - @Override - public byte[] combine(List listBytes, byte[] delimiter) { - int len = listBytes.stream().mapToInt(c -> c.length).sum(); - ByteBuffer bf = ByteBuffer.wrap( + @Override + public byte[] combine(List listBytes, byte[] delimiter) { + int len = listBytes.stream().mapToInt(c -> c.length).sum(); + ByteBuffer bf = + ByteBuffer.wrap( new byte[len + (delimiter.length + Short.BYTES) * listBytes.size() + Byte.BYTES]); - for (byte[] bytes : listBytes) { - bf.put(bytes); - bf.put(delimiter); - } - for (byte[] bytes : listBytes) { - bf.putShort((short) bytes.length); - } - bf.put(combine((byte) listBytes.size(), magicNumber)); - return bf.array(); + for (byte[] bytes : listBytes) { + bf.put(bytes); + bf.put(delimiter); } + for (byte[] bytes : listBytes) { + bf.putShort((short) bytes.length); + } + bf.put(combine((byte) listBytes.size(), magicNumber)); + return bf.array(); + } - @Override - public List split(byte[] bytes, byte[] delimiter) { - ByteBuffer bf = ByteBuffer.wrap(bytes); - bf.position(bytes.length - Byte.BYTES); - byte lastByte = bf.get(); - if (magicNumber != parseMagicNumber(lastByte)) { - return null; - } - int size = lastByte >> BYTE_SHIFT; - int pos = bytes.length - Short.BYTES * size - Byte.BYTES; - short[] lenArray = new short[size]; - bf.position(pos); - for (int i = 0; i < size; i++) { - lenArray[i] = bf.getShort(); - } - - byte[] empty = new byte[delimiter.length]; - bf.position(0); - List list = new ArrayList<>(size); - for (int i = 0; i < size; i++) { - byte[] tmp = new byte[lenArray[i]]; - bf.get(tmp); - list.add(tmp); - bf.get(empty); - } - return list; + @Override + public List split(byte[] bytes, byte[] delimiter) { + ByteBuffer bf = ByteBuffer.wrap(bytes); + bf.position(bytes.length - Byte.BYTES); + byte lastByte = bf.get(); + if (magicNumber != parseMagicNumber(lastByte)) { + return null; + } + int size = lastByte >> BYTE_SHIFT; + int pos = bytes.length - Short.BYTES * size - Byte.BYTES; + short[] lenArray = new short[size]; + bf.position(pos); + for (int i = 0; i < size; i++) { + lenArray[i] = bf.getShort(); } - @Override - public byte getMyMagicNumber() { - return magicNumber; + byte[] empty = new byte[delimiter.length]; + bf.position(0); + List list = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + byte[] tmp = new byte[lenArray[i]]; + bf.get(tmp); + list.add(tmp); + bf.get(empty); } + return list; + } + + @Override + public byte getMyMagicNumber() { + return magicNumber; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/EdgeAtom.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/EdgeAtom.java index 524f53dbc..9f036479f 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/EdgeAtom.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/EdgeAtom.java @@ -19,12 +19,11 @@ package org.apache.geaflow.state.graph.encoder; -import com.google.common.collect.ImmutableMap; -import com.google.common.primitives.Longs; import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.Map; + import org.apache.commons.collections.comparators.ComparatorChain; import org.apache.geaflow.model.graph.IGraphElementWithLabelField; import org.apache.geaflow.model.graph.IGraphElementWithTimeField; @@ -34,262 +33,262 @@ import org.apache.geaflow.state.pushdown.inner.PushDownPb; import org.apache.geaflow.state.schema.GraphDataSchema; +import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.Longs; + public enum EdgeAtom { - /** - * SRC EDGE ATOM. - */ - SRC_ID { - @Override - public Object getValue(IEdge edge, GraphDataSchema graphDataSchema) { - return edge.getSrcId(); - } - - @Override - public byte[] getBinaryValue(IEdge edge, GraphDataSchema graphDataSchema) { - return graphDataSchema.getKeyType().serialize(edge.getSrcId()); - } - - @Override - public void setValue(IEdge edge, Object value, - GraphDataSchema graphDataSchema) { - edge.setSrcId((K) value); - } - - @Override - public void setBinaryValue(IEdge edge, byte[] value, - GraphDataSchema graphDataSchema) { - edge.setSrcId((K) graphDataSchema.getKeyType().deserialize(value)); - } - - @Override - public GraphFiledName getGraphFieldName() { - return GraphFiledName.SRC_ID; - } - }, - /** - * TARGET EDGE ATOM. - */ - DST_ID { - @Override - public Object getValue(IEdge edge, - GraphDataSchema graphDataSchema) { - return edge.getTargetId(); - } - - @Override - public byte[] getBinaryValue(IEdge edge, GraphDataSchema graphDataSchema) { - return graphDataSchema.getKeyType().serialize(edge.getTargetId()); - } - - @Override - public void setValue(IEdge edge, Object value, - GraphDataSchema graphDataSchema) { - edge.setTargetId((K) value); - } - - @Override - public void setBinaryValue(IEdge edge, byte[] value, - GraphDataSchema graphDataSchema) { - edge.setTargetId((K) graphDataSchema.getKeyType().deserialize(value)); - } - - @Override - public GraphFiledName getGraphFieldName() { - return GraphFiledName.DST_ID; - } - }, - /** - * TIME EDGE ATOM. - */ - TIME { - @Override - public Object getValue(IEdge edge, GraphDataSchema graphDataSchema) { - return graphDataSchema.getEdgeMeta().getGraphFieldSerializer().getValue(edge, GraphFiledName.TIME); - } - - @Override - public byte[] getBinaryValue(IEdge edge, GraphDataSchema graphDataSchema) { - return Longs.toByteArray((Long) getValue(edge, graphDataSchema)); - } - - @Override - public void setValue(IEdge edge, Object value, - GraphDataSchema graphDataSchema) { - graphDataSchema.getEdgeMeta().getGraphFieldSerializer().setValue(edge, GraphFiledName.TIME, value); - } - - @Override - public void setBinaryValue(IEdge edge, byte[] value, - GraphDataSchema graphDataSchema) { - setValue(edge, Longs.fromByteArray(value), graphDataSchema); - } - - @Override - public GraphFiledName getGraphFieldName() { - return GraphFiledName.TIME; - } - }, - DESC_TIME { - @Override - public Object getValue(IEdge edge, GraphDataSchema graphDataSchema) { - return Long.MAX_VALUE - (Long) TIME.getValue(edge, graphDataSchema); - } - - @Override - public byte[] getBinaryValue(IEdge edge, GraphDataSchema graphDataSchema) { - return Longs.toByteArray((Long) getValue(edge, graphDataSchema)); - } - - @Override - public void setValue(IEdge edge, Object value, - GraphDataSchema graphDataSchema) { - graphDataSchema.getEdgeMeta().getGraphFieldSerializer().setValue(edge, - GraphFiledName.TIME, Long.MAX_VALUE - (Long) value); - } - - @Override - public void setBinaryValue(IEdge edge, byte[] value, - GraphDataSchema graphDataSchema) { - setValue(edge, Longs.fromByteArray(value), graphDataSchema); - } - - @Override - public GraphFiledName getGraphFieldName() { - return GraphFiledName.TIME; - } - }, - /** - * DIRECTION EDGE ATOM. - */ - DIRECTION { - @Override - public Object getValue(IEdge edge, - GraphDataSchema graphDataSchema) { - return edge.getDirect(); - } - - @Override - public byte[] getBinaryValue(IEdge edge, GraphDataSchema graphDataSchema) { - return edge.getDirect() == EdgeDirection.IN ? new byte[]{0} : new byte[]{1}; - } - - @Override - public void setValue(IEdge edge, Object value, - GraphDataSchema graphDataSchema) { - edge.setDirect((EdgeDirection) value); - } - - @Override - public void setBinaryValue(IEdge edge, byte[] value, - GraphDataSchema graphDataSchema) { - if (value[0] == 0) { - edge.setDirect(EdgeDirection.IN); - } else { - edge.setDirect(EdgeDirection.OUT); - } - } - - @Override - public GraphFiledName getGraphFieldName() { - return GraphFiledName.DIRECTION; - } - }, - /** - * LABEL EDGE ATOM. - */ - LABEL { - @Override - public Object getValue(IEdge edge, - GraphDataSchema graphDataSchema) { - return graphDataSchema.getEdgeMeta().getGraphFieldSerializer().getValue(edge, GraphFiledName.LABEL); - } - - @Override - public byte[] getBinaryValue(IEdge edge, GraphDataSchema graphDataSchema) { - return getValue(edge, graphDataSchema).toString().getBytes(); - } - - @Override - public void setValue(IEdge edge, Object value, - GraphDataSchema graphDataSchema) { - graphDataSchema.getEdgeMeta().getGraphFieldSerializer().setValue(edge, GraphFiledName.LABEL, value); - } - - @Override - public void setBinaryValue(IEdge edge, byte[] value, - GraphDataSchema graphDataSchema) { - setValue(edge, new String(value), graphDataSchema); - } - - @Override - public GraphFiledName getGraphFieldName() { - return GraphFiledName.LABEL; - } - }; - - private PushDownPb.SortType sortType; - - EdgeAtom() { - sortType = PushDownPb.SortType.valueOf(this.name()); - } - - public abstract Object getValue(IEdge edge, GraphDataSchema graphDataSchema); - - public abstract byte[] getBinaryValue(IEdge edge, GraphDataSchema graphDataSchema); - - public abstract void setValue(IEdge edge, Object value, GraphDataSchema graphDataSchema); - - public abstract void setBinaryValue(IEdge edge, byte[] value, GraphDataSchema graphDataSchema); - - public abstract GraphFiledName getGraphFieldName(); - - public PushDownPb.SortType toPbSortType() { - return sortType; - } - - public static final Map EDGE_ATOM_MAP = ImmutableMap.of( - GraphFiledName.SRC_ID, SRC_ID, - GraphFiledName.DST_ID, DST_ID, - GraphFiledName.TIME, TIME, - GraphFiledName.DIRECTION, DIRECTION, - GraphFiledName.LABEL, LABEL - ); - - - public static EdgeAtom getEnum(String value) { - for (EdgeAtom v : values()) { - if (v.name().equalsIgnoreCase(value)) { - return v; - } - } - return null; - } - - public Comparator getComparator() { - switch (this) { - case TIME: - return Comparator.comparingLong(o -> ((IGraphElementWithTimeField) o).getTime()); - case DESC_TIME: - return Collections.reverseOrder(Comparator.comparingLong(o -> ((IGraphElementWithTimeField) o).getTime())); - case DIRECTION: - return (o1, o2) -> Integer.compare(o2.getDirect().ordinal(), o1.getDirect().ordinal()); - case DST_ID: - return (o1, o2) -> ((Comparable) o1.getTargetId()).compareTo(o2.getTargetId()); - case SRC_ID: - return (o1, o2) -> ((Comparable) o1.getSrcId()).compareTo(o2.getTargetId()); - case LABEL: - return Comparator.comparing(o -> ((IGraphElementWithLabelField) o).getLabel()); - default: - throw new RuntimeException("no comparator"); - } - } - - public static Comparator getComparator(List fields) { - if (fields == null || fields.size() == 0) { - return null; - } - ComparatorChain chain = new ComparatorChain(); - fields.forEach(f -> chain.addComparator(f.getComparator())); - return chain; + /** SRC EDGE ATOM. */ + SRC_ID { + @Override + public Object getValue(IEdge edge, GraphDataSchema graphDataSchema) { + return edge.getSrcId(); + } + + @Override + public byte[] getBinaryValue(IEdge edge, GraphDataSchema graphDataSchema) { + return graphDataSchema.getKeyType().serialize(edge.getSrcId()); + } + + @Override + public void setValue(IEdge edge, Object value, GraphDataSchema graphDataSchema) { + edge.setSrcId((K) value); + } + + @Override + public void setBinaryValue( + IEdge edge, byte[] value, GraphDataSchema graphDataSchema) { + edge.setSrcId((K) graphDataSchema.getKeyType().deserialize(value)); + } + + @Override + public GraphFiledName getGraphFieldName() { + return GraphFiledName.SRC_ID; + } + }, + /** TARGET EDGE ATOM. */ + DST_ID { + @Override + public Object getValue(IEdge edge, GraphDataSchema graphDataSchema) { + return edge.getTargetId(); + } + + @Override + public byte[] getBinaryValue(IEdge edge, GraphDataSchema graphDataSchema) { + return graphDataSchema.getKeyType().serialize(edge.getTargetId()); + } + + @Override + public void setValue(IEdge edge, Object value, GraphDataSchema graphDataSchema) { + edge.setTargetId((K) value); + } + + @Override + public void setBinaryValue( + IEdge edge, byte[] value, GraphDataSchema graphDataSchema) { + edge.setTargetId((K) graphDataSchema.getKeyType().deserialize(value)); + } + + @Override + public GraphFiledName getGraphFieldName() { + return GraphFiledName.DST_ID; + } + }, + /** TIME EDGE ATOM. */ + TIME { + @Override + public Object getValue(IEdge edge, GraphDataSchema graphDataSchema) { + return graphDataSchema + .getEdgeMeta() + .getGraphFieldSerializer() + .getValue(edge, GraphFiledName.TIME); + } + + @Override + public byte[] getBinaryValue(IEdge edge, GraphDataSchema graphDataSchema) { + return Longs.toByteArray((Long) getValue(edge, graphDataSchema)); + } + + @Override + public void setValue(IEdge edge, Object value, GraphDataSchema graphDataSchema) { + graphDataSchema + .getEdgeMeta() + .getGraphFieldSerializer() + .setValue(edge, GraphFiledName.TIME, value); + } + + @Override + public void setBinaryValue( + IEdge edge, byte[] value, GraphDataSchema graphDataSchema) { + setValue(edge, Longs.fromByteArray(value), graphDataSchema); + } + + @Override + public GraphFiledName getGraphFieldName() { + return GraphFiledName.TIME; + } + }, + DESC_TIME { + @Override + public Object getValue(IEdge edge, GraphDataSchema graphDataSchema) { + return Long.MAX_VALUE - (Long) TIME.getValue(edge, graphDataSchema); + } + + @Override + public byte[] getBinaryValue(IEdge edge, GraphDataSchema graphDataSchema) { + return Longs.toByteArray((Long) getValue(edge, graphDataSchema)); + } + + @Override + public void setValue(IEdge edge, Object value, GraphDataSchema graphDataSchema) { + graphDataSchema + .getEdgeMeta() + .getGraphFieldSerializer() + .setValue(edge, GraphFiledName.TIME, Long.MAX_VALUE - (Long) value); + } + + @Override + public void setBinaryValue( + IEdge edge, byte[] value, GraphDataSchema graphDataSchema) { + setValue(edge, Longs.fromByteArray(value), graphDataSchema); + } + + @Override + public GraphFiledName getGraphFieldName() { + return GraphFiledName.TIME; + } + }, + /** DIRECTION EDGE ATOM. */ + DIRECTION { + @Override + public Object getValue(IEdge edge, GraphDataSchema graphDataSchema) { + return edge.getDirect(); + } + + @Override + public byte[] getBinaryValue(IEdge edge, GraphDataSchema graphDataSchema) { + return edge.getDirect() == EdgeDirection.IN ? new byte[] {0} : new byte[] {1}; + } + + @Override + public void setValue(IEdge edge, Object value, GraphDataSchema graphDataSchema) { + edge.setDirect((EdgeDirection) value); + } + + @Override + public void setBinaryValue( + IEdge edge, byte[] value, GraphDataSchema graphDataSchema) { + if (value[0] == 0) { + edge.setDirect(EdgeDirection.IN); + } else { + edge.setDirect(EdgeDirection.OUT); + } + } + + @Override + public GraphFiledName getGraphFieldName() { + return GraphFiledName.DIRECTION; + } + }, + /** LABEL EDGE ATOM. */ + LABEL { + @Override + public Object getValue(IEdge edge, GraphDataSchema graphDataSchema) { + return graphDataSchema + .getEdgeMeta() + .getGraphFieldSerializer() + .getValue(edge, GraphFiledName.LABEL); + } + + @Override + public byte[] getBinaryValue(IEdge edge, GraphDataSchema graphDataSchema) { + return getValue(edge, graphDataSchema).toString().getBytes(); + } + + @Override + public void setValue(IEdge edge, Object value, GraphDataSchema graphDataSchema) { + graphDataSchema + .getEdgeMeta() + .getGraphFieldSerializer() + .setValue(edge, GraphFiledName.LABEL, value); + } + + @Override + public void setBinaryValue( + IEdge edge, byte[] value, GraphDataSchema graphDataSchema) { + setValue(edge, new String(value), graphDataSchema); + } + + @Override + public GraphFiledName getGraphFieldName() { + return GraphFiledName.LABEL; + } + }; + + private PushDownPb.SortType sortType; + + EdgeAtom() { + sortType = PushDownPb.SortType.valueOf(this.name()); + } + + public abstract Object getValue(IEdge edge, GraphDataSchema graphDataSchema); + + public abstract byte[] getBinaryValue(IEdge edge, GraphDataSchema graphDataSchema); + + public abstract void setValue( + IEdge edge, Object value, GraphDataSchema graphDataSchema); + + public abstract void setBinaryValue( + IEdge edge, byte[] value, GraphDataSchema graphDataSchema); + + public abstract GraphFiledName getGraphFieldName(); + + public PushDownPb.SortType toPbSortType() { + return sortType; + } + + public static final Map EDGE_ATOM_MAP = + ImmutableMap.of( + GraphFiledName.SRC_ID, SRC_ID, + GraphFiledName.DST_ID, DST_ID, + GraphFiledName.TIME, TIME, + GraphFiledName.DIRECTION, DIRECTION, + GraphFiledName.LABEL, LABEL); + + public static EdgeAtom getEnum(String value) { + for (EdgeAtom v : values()) { + if (v.name().equalsIgnoreCase(value)) { + return v; + } + } + return null; + } + + public Comparator getComparator() { + switch (this) { + case TIME: + return Comparator.comparingLong(o -> ((IGraphElementWithTimeField) o).getTime()); + case DESC_TIME: + return Collections.reverseOrder( + Comparator.comparingLong(o -> ((IGraphElementWithTimeField) o).getTime())); + case DIRECTION: + return (o1, o2) -> Integer.compare(o2.getDirect().ordinal(), o1.getDirect().ordinal()); + case DST_ID: + return (o1, o2) -> ((Comparable) o1.getTargetId()).compareTo(o2.getTargetId()); + case SRC_ID: + return (o1, o2) -> ((Comparable) o1.getSrcId()).compareTo(o2.getTargetId()); + case LABEL: + return Comparator.comparing(o -> ((IGraphElementWithLabelField) o).getLabel()); + default: + throw new RuntimeException("no comparator"); + } + } + + public static Comparator getComparator(List fields) { + if (fields == null || fields.size() == 0) { + return null; } + ComparatorChain chain = new ComparatorChain(); + fields.forEach(f -> chain.addComparator(f.getComparator())); + return chain; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/EdgeKVEncoder.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/EdgeKVEncoder.java index c37a68b42..29c08a1a9 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/EdgeKVEncoder.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/EdgeKVEncoder.java @@ -25,20 +25,20 @@ public class EdgeKVEncoder extends EdgeKVEncoderWithoutValue { - public EdgeKVEncoder(GraphDataSchema graphDataSchema, IBytesEncoder bytesEncoder) { - super(graphDataSchema, bytesEncoder); - } + public EdgeKVEncoder(GraphDataSchema graphDataSchema, IBytesEncoder bytesEncoder) { + super(graphDataSchema, bytesEncoder); + } - @Override - public Tuple format(IEdge edge) { - Tuple res = super.format(edge); - res.f1 = graphDataSchema.getEdgePropertySerFun().apply(edge.getValue()); - return res; - } + @Override + public Tuple format(IEdge edge) { + Tuple res = super.format(edge); + res.f1 = graphDataSchema.getEdgePropertySerFun().apply(edge.getValue()); + return res; + } - @Override - public IEdge getEdge(byte[] key, byte[] value) { - IEdge edge = super.getEdge(key, value); - return edge.withValue(graphDataSchema.getEdgePropertyDeFun().apply(value)); - } + @Override + public IEdge getEdge(byte[] key, byte[] value) { + IEdge edge = super.getEdge(key, value); + return edge.withValue(graphDataSchema.getEdgePropertyDeFun().apply(value)); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/EdgeKVEncoderWithoutValue.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/EdgeKVEncoderWithoutValue.java index 2a6b6d196..d784b2060 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/EdgeKVEncoderWithoutValue.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/EdgeKVEncoderWithoutValue.java @@ -19,62 +19,64 @@ package org.apache.geaflow.state.graph.encoder; -import com.google.common.primitives.Bytes; import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.common.config.keys.StateConfigKeys; import org.apache.geaflow.common.tuple.Tuple; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.state.schema.GraphDataSchema; +import com.google.common.primitives.Bytes; + public class EdgeKVEncoderWithoutValue implements IEdgeKVEncoder { - protected static final byte[] EMPTY_BYTES = new byte[0]; - protected final GraphDataSchema graphDataSchema; - protected final List edgeSchema; - protected final IType keyType; - protected final IBytesEncoder bytesEncoder; + protected static final byte[] EMPTY_BYTES = new byte[0]; + protected final GraphDataSchema graphDataSchema; + protected final List edgeSchema; + protected final IType keyType; + protected final IBytesEncoder bytesEncoder; - public EdgeKVEncoderWithoutValue(GraphDataSchema graphDataSchema, IBytesEncoder bytesEncoder) { - this.graphDataSchema = graphDataSchema; - this.edgeSchema = graphDataSchema.getEdgeAtoms(); - this.keyType = graphDataSchema.getKeyType(); - this.bytesEncoder = bytesEncoder; - } + public EdgeKVEncoderWithoutValue(GraphDataSchema graphDataSchema, IBytesEncoder bytesEncoder) { + this.graphDataSchema = graphDataSchema; + this.edgeSchema = graphDataSchema.getEdgeAtoms(); + this.keyType = graphDataSchema.getKeyType(); + this.bytesEncoder = bytesEncoder; + } - @Override - public byte[] getScanBytes(K key) { - return Bytes.concat(keyType.serialize(key), StateConfigKeys.DELIMITER); - } + @Override + public byte[] getScanBytes(K key) { + return Bytes.concat(keyType.serialize(key), StateConfigKeys.DELIMITER); + } - @Override - public Tuple format(IEdge edge) { - List list = new ArrayList<>(edgeSchema.size()); - for (int i = 0; i < edgeSchema.size(); i++) { - list.add(edgeSchema.get(i).getBinaryValue(edge, graphDataSchema)); - } - byte[] a = bytesEncoder.combine(list, StateConfigKeys.DELIMITER); - return new Tuple<>(a, EMPTY_BYTES); + @Override + public Tuple format(IEdge edge) { + List list = new ArrayList<>(edgeSchema.size()); + for (int i = 0; i < edgeSchema.size(); i++) { + list.add(edgeSchema.get(i).getBinaryValue(edge, graphDataSchema)); } + byte[] a = bytesEncoder.combine(list, StateConfigKeys.DELIMITER); + return new Tuple<>(a, EMPTY_BYTES); + } - @Override - public IEdge getEdge(byte[] key, byte[] value) { - IEdge edge = this.graphDataSchema.getEdgeConsFun().get(); - List values = bytesEncoder.split(key, StateConfigKeys.DELIMITER); - if (values == null) { - IBytesEncoder encoder = BytesEncoderRepo.get( - bytesEncoder.parseMagicNumber(key[key.length - 1])); - values = encoder.split(key, StateConfigKeys.DELIMITER); - } - for (int i = 0; i < edgeSchema.size(); i++) { - edgeSchema.get(i).setBinaryValue(edge, values.get(i), graphDataSchema); - } - return edge; + @Override + public IEdge getEdge(byte[] key, byte[] value) { + IEdge edge = this.graphDataSchema.getEdgeConsFun().get(); + List values = bytesEncoder.split(key, StateConfigKeys.DELIMITER); + if (values == null) { + IBytesEncoder encoder = + BytesEncoderRepo.get(bytesEncoder.parseMagicNumber(key[key.length - 1])); + values = encoder.split(key, StateConfigKeys.DELIMITER); } - - @Override - public IBytesEncoder getBytesEncoder() { - return this.bytesEncoder; + for (int i = 0; i < edgeSchema.size(); i++) { + edgeSchema.get(i).setBinaryValue(edge, values.get(i), graphDataSchema); } + return edge; + } + + @Override + public IBytesEncoder getBytesEncoder() { + return this.bytesEncoder; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/GraphKVEncoder.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/GraphKVEncoder.java index e40e79598..6bd0c73f7 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/GraphKVEncoder.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/GraphKVEncoder.java @@ -24,40 +24,40 @@ public class GraphKVEncoder implements IGraphKVEncoder { - private GraphDataSchema graphDataSchema; - private IType keyType; - private IVertexKVEncoder vertexKVEncoder; - private IEdgeKVEncoder edgeKVEncoder; - - public GraphKVEncoder() { - - } - - @Override - public void init(GraphDataSchema graphDataSchema) { - this.graphDataSchema = graphDataSchema; - this.keyType = graphDataSchema.getKeyType(); - IBytesEncoder bytesEncoder = new DefaultBytesEncoder(); - this.vertexKVEncoder = this.graphDataSchema.isEmptyVertexProperty() + private GraphDataSchema graphDataSchema; + private IType keyType; + private IVertexKVEncoder vertexKVEncoder; + private IEdgeKVEncoder edgeKVEncoder; + + public GraphKVEncoder() {} + + @Override + public void init(GraphDataSchema graphDataSchema) { + this.graphDataSchema = graphDataSchema; + this.keyType = graphDataSchema.getKeyType(); + IBytesEncoder bytesEncoder = new DefaultBytesEncoder(); + this.vertexKVEncoder = + this.graphDataSchema.isEmptyVertexProperty() ? new VertexKVEncoderWithoutValue<>(graphDataSchema, bytesEncoder) : new VertexKVEncoder<>(graphDataSchema, bytesEncoder); - this.edgeKVEncoder = this.graphDataSchema.isEmptyEdgeProperty() + this.edgeKVEncoder = + this.graphDataSchema.isEmptyEdgeProperty() ? new EdgeKVEncoderWithoutValue<>(graphDataSchema, bytesEncoder) : new EdgeKVEncoder<>(graphDataSchema, bytesEncoder); - } - - @Override - public IType getKeyType() { - return keyType; - } - - @Override - public IVertexKVEncoder getVertexEncoder() { - return vertexKVEncoder; - } - - @Override - public IEdgeKVEncoder getEdgeEncoder() { - return edgeKVEncoder; - } + } + + @Override + public IType getKeyType() { + return keyType; + } + + @Override + public IVertexKVEncoder getVertexEncoder() { + return vertexKVEncoder; + } + + @Override + public IEdgeKVEncoder getEdgeEncoder() { + return edgeKVEncoder; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/GraphKVEncoderFactory.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/GraphKVEncoderFactory.java index bedd2cc12..6530d7c5e 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/GraphKVEncoderFactory.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/GraphKVEncoderFactory.java @@ -19,34 +19,38 @@ package org.apache.geaflow.state.graph.encoder; -import com.google.common.base.Splitter; import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.StateConfigKeys; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.state.schema.GraphDataSchema; +import com.google.common.base.Splitter; + public class GraphKVEncoderFactory { - private static final char EDGE_ORDER_SPLITTER = ','; + private static final char EDGE_ORDER_SPLITTER = ','; - public static IGraphKVEncoder build(Configuration config, - GraphDataSchema schema) { - String clazz = config.getString(StateConfigKeys.STATE_KV_ENCODER_CLASS); - String edgeOrder = config.getString(StateConfigKeys.STATE_KV_ENCODER_EDGE_ORDER); - if (edgeOrder != null && edgeOrder.length() > 0) { - List list = Splitter.on(EDGE_ORDER_SPLITTER).splitToList(edgeOrder).stream() - .map(c -> EdgeAtom.getEnum(c.trim())).collect(Collectors.toList()); - schema.setEdgeAtoms(list); - } - try { - IGraphKVEncoder encoder = (IGraphKVEncoder) Class.forName(clazz).newInstance(); - encoder.init(schema); - return encoder; - } catch (Exception e) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.runError(e.getMessage()), e); - } + public static IGraphKVEncoder build( + Configuration config, GraphDataSchema schema) { + String clazz = config.getString(StateConfigKeys.STATE_KV_ENCODER_CLASS); + String edgeOrder = config.getString(StateConfigKeys.STATE_KV_ENCODER_EDGE_ORDER); + if (edgeOrder != null && edgeOrder.length() > 0) { + List list = + Splitter.on(EDGE_ORDER_SPLITTER).splitToList(edgeOrder).stream() + .map(c -> EdgeAtom.getEnum(c.trim())) + .collect(Collectors.toList()); + schema.setEdgeAtoms(list); + } + try { + IGraphKVEncoder encoder = (IGraphKVEncoder) Class.forName(clazz).newInstance(); + encoder.init(schema); + return encoder; + } catch (Exception e) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.runError(e.getMessage()), e); } + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/IBytesEncoder.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/IBytesEncoder.java index 3095405cf..93c936228 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/IBytesEncoder.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/IBytesEncoder.java @@ -23,15 +23,15 @@ public interface IBytesEncoder { - byte[] combine(List listBytes); + byte[] combine(List listBytes); - List split(byte[] bytes); + List split(byte[] bytes); - byte[] combine(List listBytes, byte[] delimiter); + byte[] combine(List listBytes, byte[] delimiter); - List split(byte[] bytes, byte[] delimiter); + List split(byte[] bytes, byte[] delimiter); - byte getMyMagicNumber(); + byte getMyMagicNumber(); - byte parseMagicNumber(byte b); + byte parseMagicNumber(byte b); } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/IEdgeKVEncoder.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/IEdgeKVEncoder.java index 9ae2b1a8d..9a0e476b3 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/IEdgeKVEncoder.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/IEdgeKVEncoder.java @@ -24,11 +24,11 @@ public interface IEdgeKVEncoder { - byte[] getScanBytes(K key); + byte[] getScanBytes(K key); - Tuple format(IEdge edge); + Tuple format(IEdge edge); - IEdge getEdge(byte[] key, byte[] value); + IEdge getEdge(byte[] key, byte[] value); - IBytesEncoder getBytesEncoder(); + IBytesEncoder getBytesEncoder(); } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/IGraphKVEncoder.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/IGraphKVEncoder.java index c9188a82b..d1da34ea5 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/IGraphKVEncoder.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/IGraphKVEncoder.java @@ -24,11 +24,11 @@ public interface IGraphKVEncoder { - void init(GraphDataSchema graphDataSchema); + void init(GraphDataSchema graphDataSchema); - IType getKeyType(); + IType getKeyType(); - IVertexKVEncoder getVertexEncoder(); + IVertexKVEncoder getVertexEncoder(); - IEdgeKVEncoder getEdgeEncoder(); + IEdgeKVEncoder getEdgeEncoder(); } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/IVertexKVEncoder.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/IVertexKVEncoder.java index 34483867e..e69a58cfc 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/IVertexKVEncoder.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/IVertexKVEncoder.java @@ -24,11 +24,11 @@ public interface IVertexKVEncoder { - Tuple format(IVertex vertexData); + Tuple format(IVertex vertexData); - IVertex getVertex(byte[] key, byte[] value); + IVertex getVertex(byte[] key, byte[] value); - K getVertexID(byte[] key); + K getVertexID(byte[] key); - IBytesEncoder getBytesEncoder(); + IBytesEncoder getBytesEncoder(); } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/VertexAtom.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/VertexAtom.java index 455537b5c..05d0b7168 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/VertexAtom.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/VertexAtom.java @@ -19,171 +19,171 @@ package org.apache.geaflow.state.graph.encoder; -import com.google.common.collect.ImmutableMap; -import com.google.common.primitives.Longs; import java.util.Map; + import org.apache.geaflow.model.graph.meta.GraphFiledName; import org.apache.geaflow.model.graph.vertex.IVertex; import org.apache.geaflow.state.schema.GraphDataSchema; +import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.Longs; + public enum VertexAtom { - /** - * ID VERTEX ATOM. - */ - ID { - @Override - public Object getValue(IVertex vertex, GraphDataSchema graphDataSchema) { - return vertex.getId(); - } - - @Override - public byte[] getBinaryValue(IVertex vertex, - GraphDataSchema graphDataSchema) { - return graphDataSchema.getKeyType().serialize(vertex.getId()); - } - - @Override - public void setValue(IVertex vertex, Object value, - GraphDataSchema graphDataSchema) { - vertex.setId((K) value); - } - - @Override - public void setBinaryValue(IVertex vertex, byte[] value, - GraphDataSchema graphDataSchema) { - vertex.setId((K) graphDataSchema.getKeyType().deserialize(value)); - } - - @Override - public GraphFiledName getGraphFieldName() { - return GraphFiledName.ID; - } - }, - /** - * TIME VERTEX ATOM. - */ - TIME { - @Override - public Object getValue(IVertex vertex, GraphDataSchema graphDataSchema) { - return graphDataSchema.getVertexMeta().getGraphFieldSerializer() - .getValue(vertex, GraphFiledName.TIME); - } - - @Override - public byte[] getBinaryValue(IVertex vertex, - GraphDataSchema graphDataSchema) { - return Longs.toByteArray((Long) getValue(vertex, graphDataSchema)); - } - - @Override - public void setValue(IVertex vertex, Object value, - GraphDataSchema graphDataSchema) { - graphDataSchema.getVertexMeta().getGraphFieldSerializer() - .setValue(vertex, GraphFiledName.TIME, value); - } - - @Override - public void setBinaryValue(IVertex vertex, byte[] value, - GraphDataSchema graphDataSchema) { - setValue(vertex, Longs.fromByteArray(value), graphDataSchema); - } - - @Override - public GraphFiledName getGraphFieldName() { - return GraphFiledName.TIME; - } - }, - DESC_TIME { - @Override - public Object getValue(IVertex vertex, GraphDataSchema graphDataSchema) { - return Long.MAX_VALUE - (Long) TIME.getValue(vertex, graphDataSchema); - } - - @Override - public byte[] getBinaryValue(IVertex vertex, - GraphDataSchema graphDataSchema) { - return Longs.toByteArray((Long) getValue(vertex, graphDataSchema)); - } - - @Override - public void setValue(IVertex vertex, Object value, - GraphDataSchema graphDataSchema) { - graphDataSchema.getEdgeMeta().getGraphFieldSerializer().setValue(vertex, - GraphFiledName.TIME, Long.MAX_VALUE - (Long) value); - } - - @Override - public void setBinaryValue(IVertex vertex, byte[] value, - GraphDataSchema graphDataSchema) { - setValue(vertex, Longs.fromByteArray(value), graphDataSchema); - } - - @Override - public GraphFiledName getGraphFieldName() { - return GraphFiledName.TIME; - } - }, - /** - * LABEL VERTEX ATOM. - */ - LABEL { - public Object getValue(IVertex vertex, - GraphDataSchema graphDataSchema) { - return graphDataSchema.getVertexMeta().getGraphFieldSerializer() - .getValue(vertex, GraphFiledName.LABEL); - } - - @Override - public byte[] getBinaryValue(IVertex vertex, - GraphDataSchema graphDataSchema) { - return getValue(vertex, graphDataSchema).toString().getBytes(); - } - - @Override - public void setValue(IVertex vertex, Object value, - GraphDataSchema graphDataSchema) { - graphDataSchema.getVertexMeta().getGraphFieldSerializer() - .setValue(vertex, GraphFiledName.LABEL, value); - } - - @Override - public void setBinaryValue(IVertex vertex, byte[] value, - GraphDataSchema graphDataSchema) { - setValue(vertex, new String(value), graphDataSchema); - } - - @Override - public GraphFiledName getGraphFieldName() { - return GraphFiledName.TIME; - } - }; - - public abstract Object getValue(IVertex vertex, GraphDataSchema graphDataSchema); - - public abstract byte[] getBinaryValue(IVertex vertex, - GraphDataSchema graphDataSchema); - - public abstract void setValue(IVertex vertex, Object value, - GraphDataSchema graphDataSchema); - - public abstract void setBinaryValue(IVertex vertex, byte[] value, - GraphDataSchema graphDataSchema); - - public abstract GraphFiledName getGraphFieldName(); - - public static final Map VERTEX_ATOM_MAP = ImmutableMap.of( - GraphFiledName.ID, ID, - GraphFiledName.TIME, TIME, - GraphFiledName.LABEL, LABEL - ); - - public static VertexAtom getEnum(String value) { - for (VertexAtom v : values()) { - if (v.name().equalsIgnoreCase(value)) { - return v; - } - } - return null; + /** ID VERTEX ATOM. */ + ID { + @Override + public Object getValue(IVertex vertex, GraphDataSchema graphDataSchema) { + return vertex.getId(); + } + + @Override + public byte[] getBinaryValue(IVertex vertex, GraphDataSchema graphDataSchema) { + return graphDataSchema.getKeyType().serialize(vertex.getId()); + } + + @Override + public void setValue( + IVertex vertex, Object value, GraphDataSchema graphDataSchema) { + vertex.setId((K) value); + } + + @Override + public void setBinaryValue( + IVertex vertex, byte[] value, GraphDataSchema graphDataSchema) { + vertex.setId((K) graphDataSchema.getKeyType().deserialize(value)); + } + + @Override + public GraphFiledName getGraphFieldName() { + return GraphFiledName.ID; + } + }, + /** TIME VERTEX ATOM. */ + TIME { + @Override + public Object getValue(IVertex vertex, GraphDataSchema graphDataSchema) { + return graphDataSchema + .getVertexMeta() + .getGraphFieldSerializer() + .getValue(vertex, GraphFiledName.TIME); + } + + @Override + public byte[] getBinaryValue(IVertex vertex, GraphDataSchema graphDataSchema) { + return Longs.toByteArray((Long) getValue(vertex, graphDataSchema)); + } + + @Override + public void setValue( + IVertex vertex, Object value, GraphDataSchema graphDataSchema) { + graphDataSchema + .getVertexMeta() + .getGraphFieldSerializer() + .setValue(vertex, GraphFiledName.TIME, value); + } + + @Override + public void setBinaryValue( + IVertex vertex, byte[] value, GraphDataSchema graphDataSchema) { + setValue(vertex, Longs.fromByteArray(value), graphDataSchema); + } + + @Override + public GraphFiledName getGraphFieldName() { + return GraphFiledName.TIME; + } + }, + DESC_TIME { + @Override + public Object getValue(IVertex vertex, GraphDataSchema graphDataSchema) { + return Long.MAX_VALUE - (Long) TIME.getValue(vertex, graphDataSchema); } + @Override + public byte[] getBinaryValue(IVertex vertex, GraphDataSchema graphDataSchema) { + return Longs.toByteArray((Long) getValue(vertex, graphDataSchema)); + } + + @Override + public void setValue( + IVertex vertex, Object value, GraphDataSchema graphDataSchema) { + graphDataSchema + .getEdgeMeta() + .getGraphFieldSerializer() + .setValue(vertex, GraphFiledName.TIME, Long.MAX_VALUE - (Long) value); + } + + @Override + public void setBinaryValue( + IVertex vertex, byte[] value, GraphDataSchema graphDataSchema) { + setValue(vertex, Longs.fromByteArray(value), graphDataSchema); + } + + @Override + public GraphFiledName getGraphFieldName() { + return GraphFiledName.TIME; + } + }, + /** LABEL VERTEX ATOM. */ + LABEL { + public Object getValue(IVertex vertex, GraphDataSchema graphDataSchema) { + return graphDataSchema + .getVertexMeta() + .getGraphFieldSerializer() + .getValue(vertex, GraphFiledName.LABEL); + } + + @Override + public byte[] getBinaryValue(IVertex vertex, GraphDataSchema graphDataSchema) { + return getValue(vertex, graphDataSchema).toString().getBytes(); + } + + @Override + public void setValue( + IVertex vertex, Object value, GraphDataSchema graphDataSchema) { + graphDataSchema + .getVertexMeta() + .getGraphFieldSerializer() + .setValue(vertex, GraphFiledName.LABEL, value); + } + + @Override + public void setBinaryValue( + IVertex vertex, byte[] value, GraphDataSchema graphDataSchema) { + setValue(vertex, new String(value), graphDataSchema); + } + + @Override + public GraphFiledName getGraphFieldName() { + return GraphFiledName.TIME; + } + }; + + public abstract Object getValue(IVertex vertex, GraphDataSchema graphDataSchema); + + public abstract byte[] getBinaryValue( + IVertex vertex, GraphDataSchema graphDataSchema); + + public abstract void setValue( + IVertex vertex, Object value, GraphDataSchema graphDataSchema); + + public abstract void setBinaryValue( + IVertex vertex, byte[] value, GraphDataSchema graphDataSchema); + + public abstract GraphFiledName getGraphFieldName(); + + public static final Map VERTEX_ATOM_MAP = + ImmutableMap.of( + GraphFiledName.ID, ID, + GraphFiledName.TIME, TIME, + GraphFiledName.LABEL, LABEL); + + public static VertexAtom getEnum(String value) { + for (VertexAtom v : values()) { + if (v.name().equalsIgnoreCase(value)) { + return v; + } + } + return null; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/VertexKVEncoder.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/VertexKVEncoder.java index 9a16ae57e..7c1e5e456 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/VertexKVEncoder.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/VertexKVEncoder.java @@ -21,42 +21,43 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.common.tuple.Tuple; import org.apache.geaflow.model.graph.vertex.IVertex; import org.apache.geaflow.state.schema.GraphDataSchema; public class VertexKVEncoder extends VertexKVEncoderWithoutValue { - public VertexKVEncoder(GraphDataSchema graphDataSchema, IBytesEncoder bytesEncoder) { - super(graphDataSchema, bytesEncoder); - } + public VertexKVEncoder(GraphDataSchema graphDataSchema, IBytesEncoder bytesEncoder) { + super(graphDataSchema, bytesEncoder); + } - @Override - public Tuple format(IVertex vertex) { - byte[] a = keyType.serialize(vertex.getId()); + @Override + public Tuple format(IVertex vertex) { + byte[] a = keyType.serialize(vertex.getId()); - List bytes = new ArrayList<>(); - for (int i = 1; i < vertexSchema.size(); i++) { - bytes.add(vertexSchema.get(i).getBinaryValue(vertex, graphDataSchema)); - } - bytes.add(graphDataSchema.getVertexPropertySerFun().apply(vertex.getValue())); - return new Tuple<>(a, bytesEncoder.combine(bytes)); + List bytes = new ArrayList<>(); + for (int i = 1; i < vertexSchema.size(); i++) { + bytes.add(vertexSchema.get(i).getBinaryValue(vertex, graphDataSchema)); } + bytes.add(graphDataSchema.getVertexPropertySerFun().apply(vertex.getValue())); + return new Tuple<>(a, bytesEncoder.combine(bytes)); + } - @Override - public IVertex getVertex(byte[] key, byte[] value) { - IVertex vertex = this.graphDataSchema.getVertexConsFun().get(); - vertex.setId(keyType.deserialize(key)); - List values = bytesEncoder.split(value); - if (values == null) { - IBytesEncoder encoder = BytesEncoderRepo.get( - bytesEncoder.parseMagicNumber(key[value.length - 1])); - values = encoder.split(value); - } - int i = 1; - for (; i < vertexSchema.size(); i++) { - vertexSchema.get(i).setBinaryValue(vertex, values.get(i - 1), graphDataSchema); - } - return vertex.withValue(graphDataSchema.getVertexPropertyDeFun().apply(values.get(i - 1))); + @Override + public IVertex getVertex(byte[] key, byte[] value) { + IVertex vertex = this.graphDataSchema.getVertexConsFun().get(); + vertex.setId(keyType.deserialize(key)); + List values = bytesEncoder.split(value); + if (values == null) { + IBytesEncoder encoder = + BytesEncoderRepo.get(bytesEncoder.parseMagicNumber(key[value.length - 1])); + values = encoder.split(value); + } + int i = 1; + for (; i < vertexSchema.size(); i++) { + vertexSchema.get(i).setBinaryValue(vertex, values.get(i - 1), graphDataSchema); } + return vertex.withValue(graphDataSchema.getVertexPropertyDeFun().apply(values.get(i - 1))); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/VertexKVEncoderWithoutValue.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/VertexKVEncoderWithoutValue.java index 9094297ff..cd18d92a7 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/VertexKVEncoderWithoutValue.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/graph/encoder/VertexKVEncoderWithoutValue.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.common.tuple.Tuple; import org.apache.geaflow.common.type.IType; import org.apache.geaflow.model.graph.vertex.IVertex; @@ -28,54 +29,54 @@ public class VertexKVEncoderWithoutValue implements IVertexKVEncoder { - protected final GraphDataSchema graphDataSchema; - protected final List vertexSchema; - protected final IType keyType; - protected final IBytesEncoder bytesEncoder; - - public VertexKVEncoderWithoutValue(GraphDataSchema graphDataSchema, IBytesEncoder bytesEncoder) { - this.graphDataSchema = graphDataSchema; - this.vertexSchema = graphDataSchema.getVertexAtoms(); - this.keyType = graphDataSchema.getKeyType(); - this.bytesEncoder = bytesEncoder; - } + protected final GraphDataSchema graphDataSchema; + protected final List vertexSchema; + protected final IType keyType; + protected final IBytesEncoder bytesEncoder; - @Override - public Tuple format(IVertex vertex) { - byte[] a = keyType.serialize(vertex.getId()); + public VertexKVEncoderWithoutValue(GraphDataSchema graphDataSchema, IBytesEncoder bytesEncoder) { + this.graphDataSchema = graphDataSchema; + this.vertexSchema = graphDataSchema.getVertexAtoms(); + this.keyType = graphDataSchema.getKeyType(); + this.bytesEncoder = bytesEncoder; + } - List bytes = new ArrayList<>(); - for (int i = 1; i < vertexSchema.size(); i++) { - bytes.add(vertexSchema.get(i).getBinaryValue(vertex, graphDataSchema)); - } + @Override + public Tuple format(IVertex vertex) { + byte[] a = keyType.serialize(vertex.getId()); - return new Tuple<>(a, bytesEncoder.combine(bytes)); + List bytes = new ArrayList<>(); + for (int i = 1; i < vertexSchema.size(); i++) { + bytes.add(vertexSchema.get(i).getBinaryValue(vertex, graphDataSchema)); } - @Override - public IVertex getVertex(byte[] key, byte[] value) { - IVertex vertex = this.graphDataSchema.getVertexConsFun().get(); - vertex.setId(keyType.deserialize(key)); - List values = bytesEncoder.split(value); - if (values == null) { - IBytesEncoder encoder = BytesEncoderRepo.get( - bytesEncoder.parseMagicNumber(key[value.length - 1])); - values = encoder.split(value); - } - int i = 1; - for (; i < vertexSchema.size(); i++) { - vertexSchema.get(i).setBinaryValue(vertex, values.get(i - 1), graphDataSchema); - } - return vertex; - } + return new Tuple<>(a, bytesEncoder.combine(bytes)); + } - @Override - public K getVertexID(byte[] key) { - return (K) keyType.deserialize(key); + @Override + public IVertex getVertex(byte[] key, byte[] value) { + IVertex vertex = this.graphDataSchema.getVertexConsFun().get(); + vertex.setId(keyType.deserialize(key)); + List values = bytesEncoder.split(value); + if (values == null) { + IBytesEncoder encoder = + BytesEncoderRepo.get(bytesEncoder.parseMagicNumber(key[value.length - 1])); + values = encoder.split(value); } - - @Override - public IBytesEncoder getBytesEncoder() { - return this.bytesEncoder; + int i = 1; + for (; i < vertexSchema.size(); i++) { + vertexSchema.get(i).setBinaryValue(vertex, values.get(i - 1), graphDataSchema); } + return vertex; + } + + @Override + public K getVertexID(byte[] key) { + return (K) keyType.deserialize(key); + } + + @Override + public IBytesEncoder getBytesEncoder() { + return this.bytesEncoder; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/BaseCloseableIterator.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/BaseCloseableIterator.java index a61fec158..349c29b6a 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/BaseCloseableIterator.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/BaseCloseableIterator.java @@ -20,31 +20,32 @@ package org.apache.geaflow.state.iterator; import java.util.Iterator; + import org.apache.geaflow.common.iterator.CloseableIterator; public abstract class BaseCloseableIterator implements CloseableIterator { - protected final Iterator iterator; - protected final boolean closeable; + protected final Iterator iterator; + protected final boolean closeable; - public BaseCloseableIterator(CloseableIterator iterator) { - this.iterator = iterator; - this.closeable = true; - } + public BaseCloseableIterator(CloseableIterator iterator) { + this.iterator = iterator; + this.closeable = true; + } - public BaseCloseableIterator(Iterator iterator) { - this.iterator = iterator; - this.closeable = iterator instanceof AutoCloseable; - } + public BaseCloseableIterator(Iterator iterator) { + this.iterator = iterator; + this.closeable = iterator instanceof AutoCloseable; + } - @Override - public void close() { - if (this.closeable) { - try { - ((AutoCloseable) this.iterator).close(); - } catch (Exception e) { - throw new RuntimeException(e); - } - } + @Override + public void close() { + if (this.closeable) { + try { + ((AutoCloseable) this.iterator).close(); + } catch (Exception e) { + throw new RuntimeException(e); + } } + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IOneDegreeGraphIterator.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IOneDegreeGraphIterator.java index c6fec68ab..95944c400 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IOneDegreeGraphIterator.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IOneDegreeGraphIterator.java @@ -20,10 +20,9 @@ package org.apache.geaflow.state.iterator; import java.io.Serializable; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.state.data.OneDegreeGraph; -public interface IOneDegreeGraphIterator extends - CloseableIterator>, Serializable { - -} +public interface IOneDegreeGraphIterator + extends CloseableIterator>, Serializable {} diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IVertexIterator.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IVertexIterator.java index 0aaa3dc02..4cac79dbd 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IVertexIterator.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IVertexIterator.java @@ -20,9 +20,8 @@ package org.apache.geaflow.state.iterator; import java.io.Serializable; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.model.graph.vertex.IVertex; -public interface IVertexIterator extends CloseableIterator>, Serializable { - -} +public interface IVertexIterator extends CloseableIterator>, Serializable {} diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithClose.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithClose.java index a6059b28e..122026299 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithClose.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithClose.java @@ -23,21 +23,21 @@ public class IteratorWithClose extends BaseCloseableIterator { - private IteratorWithClose(Iterator iterator) { - super(iterator); - } + private IteratorWithClose(Iterator iterator) { + super(iterator); + } - public static IteratorWithClose wrap(Iterator iterator) { - return new IteratorWithClose<>(iterator); - } + public static IteratorWithClose wrap(Iterator iterator) { + return new IteratorWithClose<>(iterator); + } - @Override - public boolean hasNext() { - return iterator.hasNext(); - } + @Override + public boolean hasNext() { + return iterator.hasNext(); + } - @Override - public T next() { - return iterator.next(); - } + @Override + public T next() { + return iterator.next(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithFilter.java index 7d6a730fc..8095eca40 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithFilter.java @@ -19,43 +19,43 @@ package org.apache.geaflow.state.iterator; -import com.google.common.base.Preconditions; import java.util.Iterator; import java.util.function.Predicate; -import org.apache.geaflow.common.iterator.CloseableIterator; - -/** - * This class accepts a iterator and makes its value filtered. - */ -public class IteratorWithFilter extends BaseCloseableIterator { - private final Predicate predicate; - private T nextValue; +import org.apache.geaflow.common.iterator.CloseableIterator; - public IteratorWithFilter(CloseableIterator iterator, Predicate predicate) { - super(iterator); - this.predicate = Preconditions.checkNotNull(predicate); - } +import com.google.common.base.Preconditions; - public IteratorWithFilter(Iterator iterator, Predicate predicate) { - super(iterator); - this.predicate = Preconditions.checkNotNull(predicate); - } +/** This class accepts a iterator and makes its value filtered. */ +public class IteratorWithFilter extends BaseCloseableIterator { - @Override - public boolean hasNext() { - while (this.iterator.hasNext()) { - nextValue = this.iterator.next(); - if (!predicate.test(nextValue)) { - continue; - } - return true; - } - return false; + private final Predicate predicate; + private T nextValue; + + public IteratorWithFilter(CloseableIterator iterator, Predicate predicate) { + super(iterator); + this.predicate = Preconditions.checkNotNull(predicate); + } + + public IteratorWithFilter(Iterator iterator, Predicate predicate) { + super(iterator); + this.predicate = Preconditions.checkNotNull(predicate); + } + + @Override + public boolean hasNext() { + while (this.iterator.hasNext()) { + nextValue = this.iterator.next(); + if (!predicate.test(nextValue)) { + continue; + } + return true; } + return false; + } - @Override - public T next() { - return nextValue; - } + @Override + public T next() { + return nextValue; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithFilterThenFn.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithFilterThenFn.java index af9499202..ceef9c9e7 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithFilterThenFn.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithFilterThenFn.java @@ -19,47 +19,49 @@ package org.apache.geaflow.state.iterator; -import com.google.common.base.Preconditions; import java.util.Iterator; import java.util.function.Function; import java.util.function.Predicate; + import org.apache.geaflow.common.iterator.CloseableIterator; -/** - * This class accepts a iterator and makes its value filtered and transformed. - */ +import com.google.common.base.Preconditions; + +/** This class accepts a iterator and makes its value filtered and transformed. */ public class IteratorWithFilterThenFn extends BaseCloseableIterator { - private final Predicate predicate; - private final Function function; - private T nextValue; + private final Predicate predicate; + private final Function function; + private T nextValue; - public IteratorWithFilterThenFn(CloseableIterator iterator, Predicate predicate, Function function) { - super(iterator); - this.predicate = Preconditions.checkNotNull(predicate); - this.function = function; - } + public IteratorWithFilterThenFn( + CloseableIterator iterator, Predicate predicate, Function function) { + super(iterator); + this.predicate = Preconditions.checkNotNull(predicate); + this.function = function; + } - public IteratorWithFilterThenFn(Iterator iterator, Predicate predicate, Function function) { - super(iterator); - this.predicate = Preconditions.checkNotNull(predicate); - this.function = function; - } + public IteratorWithFilterThenFn( + Iterator iterator, Predicate predicate, Function function) { + super(iterator); + this.predicate = Preconditions.checkNotNull(predicate); + this.function = function; + } - @Override - public boolean hasNext() { - while (this.iterator.hasNext()) { - nextValue = this.iterator.next(); - if (!predicate.test(nextValue)) { - continue; - } - return true; - } - return false; + @Override + public boolean hasNext() { + while (this.iterator.hasNext()) { + nextValue = this.iterator.next(); + if (!predicate.test(nextValue)) { + continue; + } + return true; } + return false; + } - @Override - public R next() { - return function.apply(nextValue); - } + @Override + public R next() { + return function.apply(nextValue); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithFlatFn.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithFlatFn.java index 2b288bd86..04100ea4d 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithFlatFn.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithFlatFn.java @@ -21,42 +21,41 @@ import java.util.Iterator; import java.util.function.Function; + import org.apache.geaflow.common.iterator.CloseableIterator; -/** - * This class accepts a iterator and makes its value fanned out. - */ +/** This class accepts a iterator and makes its value fanned out. */ public class IteratorWithFlatFn extends BaseCloseableIterator { - private final Function> fn; - private Iterator inIt; + private final Function> fn; + private Iterator inIt; - public IteratorWithFlatFn(CloseableIterator iterator, Function> function) { - super(iterator); - this.fn = function; - } + public IteratorWithFlatFn(CloseableIterator iterator, Function> function) { + super(iterator); + this.fn = function; + } - public IteratorWithFlatFn(Iterator iterator, Function> function) { - super(iterator); - this.fn = function; - } + public IteratorWithFlatFn(Iterator iterator, Function> function) { + super(iterator); + this.fn = function; + } - @Override - public boolean hasNext() { - if (inIt != null && inIt.hasNext()) { - return true; - } - while (iterator.hasNext()) { - inIt = fn.apply(iterator.next()); - if (inIt.hasNext()) { - return true; - } - } - return false; + @Override + public boolean hasNext() { + if (inIt != null && inIt.hasNext()) { + return true; } - - @Override - public R next() { - return inIt.next(); + while (iterator.hasNext()) { + inIt = fn.apply(iterator.next()); + if (inIt.hasNext()) { + return true; + } } + return false; + } + + @Override + public R next() { + return inIt.next(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithFn.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithFn.java index 9caf8ad8f..a32d5ec27 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithFn.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithFn.java @@ -21,32 +21,31 @@ import java.util.Iterator; import java.util.function.Function; + import org.apache.geaflow.common.iterator.CloseableIterator; -/** - * This class accepts a iterator and makes its value transformed. - */ +/** This class accepts a iterator and makes its value transformed. */ public class IteratorWithFn extends BaseCloseableIterator { - protected final Function fn; + protected final Function fn; - public IteratorWithFn(CloseableIterator iterator, Function function) { - super(iterator); - this.fn = function; - } + public IteratorWithFn(CloseableIterator iterator, Function function) { + super(iterator); + this.fn = function; + } - public IteratorWithFn(Iterator iterator, Function function) { - super(iterator); - this.fn = function; - } + public IteratorWithFn(Iterator iterator, Function function) { + super(iterator); + this.fn = function; + } - @Override - public boolean hasNext() { - return iterator != null && iterator.hasNext(); - } + @Override + public boolean hasNext() { + return iterator != null && iterator.hasNext(); + } - @Override - public R next() { - return fn.apply(iterator.next()); - } + @Override + public R next() { + return fn.apply(iterator.next()); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithFnThenFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithFnThenFilter.java index 18ec58c0d..774b0728d 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithFnThenFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/IteratorWithFnThenFilter.java @@ -19,47 +19,49 @@ package org.apache.geaflow.state.iterator; -import com.google.common.base.Preconditions; import java.util.Iterator; import java.util.function.Function; import java.util.function.Predicate; + import org.apache.geaflow.common.iterator.CloseableIterator; -/** - * This class accepts a iterator and makes its value transformed and filtered. - */ +import com.google.common.base.Preconditions; + +/** This class accepts a iterator and makes its value transformed and filtered. */ public class IteratorWithFnThenFilter extends BaseCloseableIterator { - protected final Function fn; - private final Predicate predicate; - private R nextValue; + protected final Function fn; + private final Predicate predicate; + private R nextValue; - public IteratorWithFnThenFilter(CloseableIterator iterator, Function function, Predicate predicate) { - super(iterator); - this.fn = function; - this.predicate = Preconditions.checkNotNull(predicate); - } + public IteratorWithFnThenFilter( + CloseableIterator iterator, Function function, Predicate predicate) { + super(iterator); + this.fn = function; + this.predicate = Preconditions.checkNotNull(predicate); + } - public IteratorWithFnThenFilter(Iterator iterator, Function function, Predicate predicate) { - super(iterator); - this.fn = function; - this.predicate = Preconditions.checkNotNull(predicate); - } + public IteratorWithFnThenFilter( + Iterator iterator, Function function, Predicate predicate) { + super(iterator); + this.fn = function; + this.predicate = Preconditions.checkNotNull(predicate); + } - @Override - public boolean hasNext() { - while (iterator != null && this.iterator.hasNext()) { - nextValue = fn.apply(iterator.next()); - if (!predicate.test(nextValue)) { - continue; - } - return true; - } - return false; + @Override + public boolean hasNext() { + while (iterator != null && this.iterator.hasNext()) { + nextValue = fn.apply(iterator.next()); + if (!predicate.test(nextValue)) { + continue; + } + return true; } + return false; + } - @Override - public R next() { - return nextValue; - } + @Override + public R next() { + return nextValue; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/MultiIterator.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/MultiIterator.java index 0eab1cc33..fbf4a17cb 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/MultiIterator.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/MultiIterator.java @@ -18,64 +18,65 @@ import java.util.Arrays; import java.util.Iterator; + import org.apache.geaflow.common.iterator.CloseableIterator; /** - * This class is an adaptation of Guava's Iterators.concat - * by fixing the issue https://github.com/google/guava/issues/3178. + * This class is an adaptation of Guava's Iterators.concat by fixing the issue + * https://github.com/google/guava/issues/3178. */ public class MultiIterator implements CloseableIterator { - private final Iterator> iterators; - private CloseableIterator currIterator; - private T nextValue; + private final Iterator> iterators; + private CloseableIterator currIterator; + private T nextValue; - public MultiIterator(Iterator> iterators) { - this.iterators = iterators; - if (iterators.hasNext()) { - this.currIterator = iterators.next(); - } + public MultiIterator(Iterator> iterators) { + this.iterators = iterators; + if (iterators.hasNext()) { + this.currIterator = iterators.next(); } + } - public MultiIterator(CloseableIterator... iteratorCandidates) { - this.iterators = Arrays.asList(iteratorCandidates).iterator(); - if (iterators.hasNext()) { - this.currIterator = iterators.next(); - } + public MultiIterator(CloseableIterator... iteratorCandidates) { + this.iterators = Arrays.asList(iteratorCandidates).iterator(); + if (iterators.hasNext()) { + this.currIterator = iterators.next(); } + } - @Override - public boolean hasNext() { - if (currIterator == null) { - return false; + @Override + public boolean hasNext() { + if (currIterator == null) { + return false; + } + if (!currIterator.hasNext()) { + currIterator.close(); + do { + if (!this.iterators.hasNext()) { + return false; } - if (!currIterator.hasNext()) { - currIterator.close(); - do { - if (!this.iterators.hasNext()) { - return false; - } - currIterator = this.iterators.next(); - } while (!currIterator.hasNext()); - } - nextValue = currIterator.next(); - return true; + currIterator = this.iterators.next(); + } while (!currIterator.hasNext()); } + nextValue = currIterator.next(); + return true; + } - @Override - public T next() { - return nextValue; - } + @Override + public T next() { + return nextValue; + } - @Override - public void close() { - if (currIterator != null) { - currIterator.close(); - } - while (this.iterators.hasNext()) { - currIterator = this.iterators.next(); - currIterator.close(); - } + @Override + public void close() { + if (currIterator != null) { + currIterator.close(); + } + while (this.iterators.hasNext()) { + currIterator = this.iterators.next(); + currIterator.close(); } + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/ScannerType.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/ScannerType.java index c0120025d..2cbad6b0c 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/ScannerType.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/ScannerType.java @@ -20,36 +20,28 @@ package org.apache.geaflow.state.iterator; public enum ScannerType { - /** - * FULL SCANNER. - */ - FULL(false, true), - /** - * FULL SCANNER WITH FILTER. - */ - FULL_WITH_FILTER(false, false), - /** - * MULTI KEYS SCANNER. - */ - KEYS(true, true), - /** - * MULTI KEYS SCANNER WITH FILTERS. - */ - KEYS_WITH_FILTER(true, false); + /** FULL SCANNER. */ + FULL(false, true), + /** FULL SCANNER WITH FILTER. */ + FULL_WITH_FILTER(false, false), + /** MULTI KEYS SCANNER. */ + KEYS(true, true), + /** MULTI KEYS SCANNER WITH FILTERS. */ + KEYS_WITH_FILTER(true, false); - ScannerType(boolean isMultiKey, boolean emptyFilter) { - this.isMultiKey = isMultiKey; - this.emptyFilter = emptyFilter; - } + ScannerType(boolean isMultiKey, boolean emptyFilter) { + this.isMultiKey = isMultiKey; + this.emptyFilter = emptyFilter; + } - private final boolean isMultiKey; - private final boolean emptyFilter; + private final boolean isMultiKey; + private final boolean emptyFilter; - public boolean isMultiKey() { - return isMultiKey; - } + public boolean isMultiKey() { + return isMultiKey; + } - public boolean emptyFilter() { - return emptyFilter; - } + public boolean emptyFilter() { + return emptyFilter; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/StandardIterator.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/StandardIterator.java index ff28f99f0..00c9f9a4e 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/StandardIterator.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/iterator/StandardIterator.java @@ -21,44 +21,42 @@ import org.apache.geaflow.common.iterator.CloseableIterator; -/** - * This class is a wrapper iterator, allowing multiple hasNext call but one next call. - */ +/** This class is a wrapper iterator, allowing multiple hasNext call but one next call. */ public class StandardIterator extends BaseCloseableIterator { - private boolean nextCalled; - private boolean hasNextValue; - private T nextValue; - - public StandardIterator(CloseableIterator iterator) { - super(iterator); - innerNext(); - } - - private void innerNext() { - this.hasNextValue = this.iterator.hasNext(); - this.nextValue = this.hasNextValue ? this.iterator.next() : null; - this.nextCalled = false; + private boolean nextCalled; + private boolean hasNextValue; + private T nextValue; + + public StandardIterator(CloseableIterator iterator) { + super(iterator); + innerNext(); + } + + private void innerNext() { + this.hasNextValue = this.iterator.hasNext(); + this.nextValue = this.hasNextValue ? this.iterator.next() : null; + this.nextCalled = false; + } + + @Override + public boolean hasNext() { + // only next has called, we trigger hasNext method. + if (nextCalled) { + innerNext(); } - - @Override - public boolean hasNext() { - // only next has called, we trigger hasNext method. - if (nextCalled) { - innerNext(); - } - return hasNextValue; - } - - @Override - public T next() { - if (nextValue != null) { - nextCalled = true; - T next = nextValue; - nextValue = null; - return next; - } - return null; - // throw new NoSuchElementException("hasNext not called or has no next data"); + return hasNextValue; + } + + @Override + public T next() { + if (nextValue != null) { + nextCalled = true; + T next = nextValue; + nextValue = null; + return next; } + return null; + // throw new NoSuchElementException("hasNext not called or has no next data"); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/key/KeyListTrait.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/key/KeyListTrait.java index c461d2772..cbea18f6e 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/key/KeyListTrait.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/key/KeyListTrait.java @@ -23,18 +23,12 @@ public interface KeyListTrait { - /** - * get List from key. - */ - List get(K key); + /** get List from key. */ + List get(K key); - /** - * add the value to list. - */ - void add(K key, V... value); + /** add the value to list. */ + void add(K key, V... value); - /** - * remove key. - */ - void remove(K key); + /** remove key. */ + void remove(K key); } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/key/KeyMapTrait.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/key/KeyMapTrait.java index a9b89bad5..cd5d69d97 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/key/KeyMapTrait.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/key/KeyMapTrait.java @@ -24,35 +24,21 @@ public interface KeyMapTrait { - /** - * Returns the current value associated with the given key. - */ - Map get(K key); + /** Returns the current value associated with the given key. */ + Map get(K key); - /** - * Returns the current value associated with the given key. - */ - List get(K key, UK... subKeys); + /** Returns the current value associated with the given key. */ + List get(K key, UK... subKeys); - /** - * Associates a new value with the given key. - */ - void add(K key, UK subKey, UV value); + /** Associates a new value with the given key. */ + void add(K key, UK subKey, UV value); - /** - * Resets the state value. - */ - void add(K key, Map map); - - /** - * Deletes the mapping of the given key. - */ - void remove(K key); - - /** - * Deletes the mapping of the given key. - */ - void remove(K key, UK... subKeys); + /** Resets the state value. */ + void add(K key, Map map); + /** Deletes the mapping of the given key. */ + void remove(K key); + /** Deletes the mapping of the given key. */ + void remove(K key, UK... subKeys); } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/key/KeyValueTrait.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/key/KeyValueTrait.java index 55627ccc9..dee9bd97f 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/key/KeyValueTrait.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/key/KeyValueTrait.java @@ -21,18 +21,12 @@ public interface KeyValueTrait { - /** - * Returns the current value associated with the given key. - */ - V get(K key); + /** Returns the current value associated with the given key. */ + V get(K key); - /** - * Override the value. - */ - void put(K key, V value); + /** Override the value. */ + void put(K key, V value); - /** - * Delete the key. - */ - void remove(K key); + /** Delete the key. */ + void remove(K key); } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/IStatePushDown.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/IStatePushDown.java index 25fb899f8..e0526a77e 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/IStatePushDown.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/IStatePushDown.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.Map; + import org.apache.geaflow.state.graph.encoder.EdgeAtom; import org.apache.geaflow.state.pushdown.filter.IFilter; import org.apache.geaflow.state.pushdown.limit.IEdgeLimit; @@ -28,17 +29,17 @@ public interface IStatePushDown { - IFilter getFilter(); + IFilter getFilter(); - Map getFilters(); + Map getFilters(); - IEdgeLimit getEdgeLimit(); + IEdgeLimit getEdgeLimit(); - List getOrderFields(); + List getOrderFields(); - IProjector getProjector(); + IProjector getProjector(); - PushDownType getType(); + PushDownType getType(); - boolean isEmpty(); + boolean isEmpty(); } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/KeyGroupStatePushDown.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/KeyGroupStatePushDown.java index 72c33a6f5..ec2ff5bcd 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/KeyGroupStatePushDown.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/KeyGroupStatePushDown.java @@ -23,36 +23,35 @@ public class KeyGroupStatePushDown extends StatePushDown { - private KeyGroup keyGroup; - - private KeyGroupStatePushDown() { - } - - public static KeyGroupStatePushDown of() { - return new KeyGroupStatePushDown(); - } - - public static KeyGroupStatePushDown of(KeyGroup keyGroup) { - return new KeyGroupStatePushDown().withKeyGroup(keyGroup); - } - - public static KeyGroupStatePushDown of(StatePushDown statePushDown) { - KeyGroupStatePushDown pushDown = new KeyGroupStatePushDown(); - pushDown.filter = statePushDown.filter; - pushDown.edgeLimit = statePushDown.edgeLimit; - pushDown.filters = statePushDown.filters; - pushDown.orderFields = statePushDown.orderFields; - pushDown.projector = statePushDown.projector; - pushDown.pushdownType = statePushDown.pushdownType; - return pushDown; - } - - public KeyGroup getKeyGroup() { - return keyGroup; - } - - public KeyGroupStatePushDown withKeyGroup(KeyGroup keyGroup) { - this.keyGroup = keyGroup; - return this; - } + private KeyGroup keyGroup; + + private KeyGroupStatePushDown() {} + + public static KeyGroupStatePushDown of() { + return new KeyGroupStatePushDown(); + } + + public static KeyGroupStatePushDown of(KeyGroup keyGroup) { + return new KeyGroupStatePushDown().withKeyGroup(keyGroup); + } + + public static KeyGroupStatePushDown of(StatePushDown statePushDown) { + KeyGroupStatePushDown pushDown = new KeyGroupStatePushDown(); + pushDown.filter = statePushDown.filter; + pushDown.edgeLimit = statePushDown.edgeLimit; + pushDown.filters = statePushDown.filters; + pushDown.orderFields = statePushDown.orderFields; + pushDown.projector = statePushDown.projector; + pushDown.pushdownType = statePushDown.pushdownType; + return pushDown; + } + + public KeyGroup getKeyGroup() { + return keyGroup; + } + + public KeyGroupStatePushDown withKeyGroup(KeyGroup keyGroup) { + this.keyGroup = keyGroup; + return this; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/PushDownType.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/PushDownType.java index bb3249d34..f1b075d1e 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/PushDownType.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/PushDownType.java @@ -20,8 +20,8 @@ package org.apache.geaflow.state.pushdown; public enum PushDownType { - NORMAL, - PROJECT, - AGGREGATE, - PROJECT_THEN_AGGREGATE + NORMAL, + PROJECT, + AGGREGATE, + PROJECT_THEN_AGGREGATE } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/StatePushDown.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/StatePushDown.java index de1b5ff78..f27a895ac 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/StatePushDown.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/StatePushDown.java @@ -22,6 +22,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; + import org.apache.geaflow.state.graph.encoder.EdgeAtom; import org.apache.geaflow.state.pushdown.filter.FilterType; import org.apache.geaflow.state.pushdown.filter.IFilter; @@ -31,88 +32,87 @@ public class StatePushDown implements IStatePushDown { - protected IProjector projector; - protected IFilter filter = EmptyGraphFilter.of(); - protected Map filters; - protected IEdgeLimit edgeLimit; - protected List orderFields; - protected PushDownType pushdownType = PushDownType.NORMAL; - - protected StatePushDown() { - } - - public static StatePushDown of() { - return new StatePushDown(); - } - - public StatePushDown withFilters(Map filters) { - this.filters = filters; - return this; - } - - public StatePushDown withFilter(IFilter filter) { - this.filter = filter; - return this; - } - - public StatePushDown withEdgeLimit(IEdgeLimit edgeLimit) { - this.edgeLimit = edgeLimit; - return this; - } - - public StatePushDown withOrderField(EdgeAtom orderField) { - if (orderField != null) { - this.orderFields = Collections.singletonList(orderField); - } - return this; - } - - public StatePushDown withOrderFields(List orderFields) { - this.orderFields = orderFields; - return this; - } - - public StatePushDown withProjector(IProjector projector) { - this.projector = projector; - this.pushdownType = PushDownType.PROJECT; - return this; - } - - @Override - public IProjector getProjector() { - return projector; - } - - @Override - public IFilter getFilter() { - return filter; - } - - @Override - public Map getFilters() { - return filters; - } - - @Override - public IEdgeLimit getEdgeLimit() { - return edgeLimit; - } - - @Override - public List getOrderFields() { - return orderFields; - } - - @Override - public PushDownType getType() { - return pushdownType; - } - - @Override - public boolean isEmpty() { - boolean filterEmpty = filter == null || filter.getFilterType() == FilterType.EMPTY; - boolean filtersEmpty = filters == null || filters.isEmpty(); - boolean orderEmpty = orderFields == null || orderFields.isEmpty(); - return filterEmpty && filtersEmpty && orderEmpty && edgeLimit == null && projector == null; + protected IProjector projector; + protected IFilter filter = EmptyGraphFilter.of(); + protected Map filters; + protected IEdgeLimit edgeLimit; + protected List orderFields; + protected PushDownType pushdownType = PushDownType.NORMAL; + + protected StatePushDown() {} + + public static StatePushDown of() { + return new StatePushDown(); + } + + public StatePushDown withFilters(Map filters) { + this.filters = filters; + return this; + } + + public StatePushDown withFilter(IFilter filter) { + this.filter = filter; + return this; + } + + public StatePushDown withEdgeLimit(IEdgeLimit edgeLimit) { + this.edgeLimit = edgeLimit; + return this; + } + + public StatePushDown withOrderField(EdgeAtom orderField) { + if (orderField != null) { + this.orderFields = Collections.singletonList(orderField); } + return this; + } + + public StatePushDown withOrderFields(List orderFields) { + this.orderFields = orderFields; + return this; + } + + public StatePushDown withProjector(IProjector projector) { + this.projector = projector; + this.pushdownType = PushDownType.PROJECT; + return this; + } + + @Override + public IProjector getProjector() { + return projector; + } + + @Override + public IFilter getFilter() { + return filter; + } + + @Override + public Map getFilters() { + return filters; + } + + @Override + public IEdgeLimit getEdgeLimit() { + return edgeLimit; + } + + @Override + public List getOrderFields() { + return orderFields; + } + + @Override + public PushDownType getType() { + return pushdownType; + } + + @Override + public boolean isEmpty() { + boolean filterEmpty = filter == null || filter.getFilterType() == FilterType.EMPTY; + boolean filtersEmpty = filters == null || filters.isEmpty(); + boolean orderEmpty = orderFields == null || orderFields.isEmpty(); + return filterEmpty && filtersEmpty && orderEmpty && edgeLimit == null && projector == null; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/AndFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/AndFilter.java index 04d2207dd..ab45173c2 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/AndFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/AndFilter.java @@ -20,36 +20,37 @@ package org.apache.geaflow.state.pushdown.filter; import java.util.List; + import org.apache.geaflow.state.data.DataType; public class AndFilter extends BaseLogicFilter { - public AndFilter(List inputs) { - super(inputs); - } + public AndFilter(List inputs) { + super(inputs); + } - public AndFilter(IFilter... inputs) { - super(inputs); - } + public AndFilter(IFilter... inputs) { + super(inputs); + } - void sanityCheck(IFilter filter) { - DataType filterDataType = filter.dateType(); - if (this.dateType == null) { - this.dateType = filterDataType; - } else if (this.dateType != filterDataType) { - this.dateType = DataType.VE; - } - - FilterType filterType = filter.getFilterType(); - if (filterType == FilterType.AND || filterType == FilterType.OR) { - this.rootBitSet.or(((BaseLogicFilter) filter).rootBitSet); - } else if (filter.getFilterType().isRootFilter()) { - this.rootBitSet.set(filter.getFilterType().ordinal()); - } + void sanityCheck(IFilter filter) { + DataType filterDataType = filter.dateType(); + if (this.dateType == null) { + this.dateType = filterDataType; + } else if (this.dateType != filterDataType) { + this.dateType = DataType.VE; } - @Override - public FilterType getFilterType() { - return FilterType.AND; + FilterType filterType = filter.getFilterType(); + if (filterType == FilterType.AND || filterType == FilterType.OR) { + this.rootBitSet.or(((BaseLogicFilter) filter).rootBitSet); + } else if (filter.getFilterType().isRootFilter()) { + this.rootBitSet.set(filter.getFilterType().ordinal()); } + } + + @Override + public FilterType getFilterType() { + return FilterType.AND; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/BaseLogicFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/BaseLogicFilter.java index 3738b6e20..92c383a28 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/BaseLogicFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/BaseLogicFilter.java @@ -19,83 +19,84 @@ package org.apache.geaflow.state.pushdown.filter; -import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.BitSet; import java.util.List; -import org.apache.geaflow.state.data.DataType; -public abstract class BaseLogicFilter implements IFilter { - - protected List filters = new ArrayList<>(); - protected BitSet rootBitSet = new BitSet(); - protected DataType dateType; +import org.apache.geaflow.state.data.DataType; - public BaseLogicFilter(List inputs) { - Preconditions.checkArgument(inputs != null && inputs.size() > 0); - for (IFilter filter : inputs) { - handleFilter(filter); - } - } +import com.google.common.base.Preconditions; - public BaseLogicFilter(IFilter... inputs) { - Preconditions.checkNotNull(inputs); - for (IFilter filter : inputs) { - handleFilter(filter); - } - } +public abstract class BaseLogicFilter implements IFilter { - public boolean filter(Object value) { - throw new UnsupportedOperationException(); - } + protected List filters = new ArrayList<>(); + protected BitSet rootBitSet = new BitSet(); + protected DataType dateType; - public DataType dateType() { - return dateType; + public BaseLogicFilter(List inputs) { + Preconditions.checkArgument(inputs != null && inputs.size() > 0); + for (IFilter filter : inputs) { + handleFilter(filter); } + } - - public List getFilters() { - return filters; + public BaseLogicFilter(IFilter... inputs) { + Preconditions.checkNotNull(inputs); + for (IFilter filter : inputs) { + handleFilter(filter); } - - public BitSet getRootBitSet() { - return rootBitSet; + } + + public boolean filter(Object value) { + throw new UnsupportedOperationException(); + } + + public DataType dateType() { + return dateType; + } + + public List getFilters() { + return filters; + } + + public BitSet getRootBitSet() { + return rootBitSet; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("{\"").append(getFilterType()).append("\":["); + if (filters != null && filters.size() > 0) { + int size = filters.size(); + int index = 0; + for (IFilter filter : filters) { + sb.append(filter.toString()); + if (index < size - 1) { + sb.append(","); + } + index++; + } } - @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append("{\"").append(getFilterType()).append("\":["); - if (filters != null && filters.size() > 0) { - int size = filters.size(); - int index = 0; - for (IFilter filter : filters) { - sb.append(filter.toString()); - if (index < size - 1) { - sb.append(","); - } - index++; - } - } + sb.append("]"); + sb.append("}"); - sb.append("]"); - sb.append("}"); + return sb.toString(); + } - return sb.toString(); + private void handleFilter(IFilter filter) { + if (filter.getFilterType() == FilterType.EMPTY) { + return; } - private void handleFilter(IFilter filter) { - if (filter.getFilterType() == FilterType.EMPTY) { - return; - } - - sanityCheck(filter); - if (filter.getFilterType() == getFilterType()) { - filters.addAll(((BaseLogicFilter) filter).filters); - } else { - filters.add(filter); - } + sanityCheck(filter); + if (filter.getFilterType() == getFilterType()) { + filters.addAll(((BaseLogicFilter) filter).filters); + } else { + filters.add(filter); } + } - abstract void sanityCheck(IFilter filter); + abstract void sanityCheck(IFilter filter); } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/EdgeLabelFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/EdgeLabelFilter.java index 43c022450..0fc6761da 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/EdgeLabelFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/EdgeLabelFilter.java @@ -19,46 +19,47 @@ package org.apache.geaflow.state.pushdown.filter; -import com.google.common.base.Joiner; -import com.google.common.collect.Sets; import java.util.Collection; import java.util.Set; + import org.apache.geaflow.model.graph.IGraphElementWithLabelField; import org.apache.geaflow.model.graph.edge.IEdge; +import com.google.common.base.Joiner; +import com.google.common.collect.Sets; + public class EdgeLabelFilter implements IEdgeFilter { - private Set labels; + private Set labels; - public EdgeLabelFilter(Collection list) { - this.labels = Sets.newHashSet(list); - } + public EdgeLabelFilter(Collection list) { + this.labels = Sets.newHashSet(list); + } - public EdgeLabelFilter(String... labels) { - this.labels = Sets.newHashSet(labels); - } + public EdgeLabelFilter(String... labels) { + this.labels = Sets.newHashSet(labels); + } - public static EdgeLabelFilter getInstance(String... labels) { - return new EdgeLabelFilter<>(labels); - } + public static EdgeLabelFilter getInstance(String... labels) { + return new EdgeLabelFilter<>(labels); + } - @Override - public boolean filter(IEdge value) { - return labels.contains(((IGraphElementWithLabelField) value).getLabel()); - } + @Override + public boolean filter(IEdge value) { + return labels.contains(((IGraphElementWithLabelField) value).getLabel()); + } - public Set getLabels() { - return labels; - } + public Set getLabels() { + return labels; + } - @Override - public FilterType getFilterType() { - return FilterType.EDGE_LABEL; - } + @Override + public FilterType getFilterType() { + return FilterType.EDGE_LABEL; + } - @Override - public String toString() { - return String.format("\"%s(%s)\"", getFilterType().name(), - Joiner.on(',').join(labels)); - } + @Override + public String toString() { + return String.format("\"%s(%s)\"", getFilterType().name(), Joiner.on(',').join(labels)); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/EdgeTsFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/EdgeTsFilter.java index 2b5e54464..7cf4c2737 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/EdgeTsFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/EdgeTsFilter.java @@ -25,33 +25,33 @@ public class EdgeTsFilter implements IEdgeFilter { - private final TimeRange timeRange; - - public EdgeTsFilter(TimeRange timeRange) { - this.timeRange = timeRange; - } - - public static EdgeTsFilter getInstance(long start, long end) { - return new EdgeTsFilter<>(TimeRange.of(start, end)); - } - - @Override - public boolean filter(IEdge value) { - return timeRange.contain(((IGraphElementWithTimeField) value).getTime()); - } - - public TimeRange getTimeRange() { - return timeRange; - } - - @Override - public FilterType getFilterType() { - return FilterType.EDGE_TS; - } - - @Override - public String toString() { - return String.format("\"%s[%d,%d)\"", getFilterType().name(), - timeRange.getStart(), timeRange.getEnd()); - } + private final TimeRange timeRange; + + public EdgeTsFilter(TimeRange timeRange) { + this.timeRange = timeRange; + } + + public static EdgeTsFilter getInstance(long start, long end) { + return new EdgeTsFilter<>(TimeRange.of(start, end)); + } + + @Override + public boolean filter(IEdge value) { + return timeRange.contain(((IGraphElementWithTimeField) value).getTime()); + } + + public TimeRange getTimeRange() { + return timeRange; + } + + @Override + public FilterType getFilterType() { + return FilterType.EDGE_TS; + } + + @Override + public String toString() { + return String.format( + "\"%s[%d,%d)\"", getFilterType().name(), timeRange.getStart(), timeRange.getEnd()); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/EdgeValueDropFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/EdgeValueDropFilter.java index 4c93042f2..c51571dbf 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/EdgeValueDropFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/EdgeValueDropFilter.java @@ -23,25 +23,25 @@ public class EdgeValueDropFilter implements IEdgeFilter { - private static final EdgeValueDropFilter filter = new EdgeValueDropFilter(); - - public static EdgeValueDropFilter getInstance() { - return filter; - } - - @Override - public boolean filter(IEdge value) { - value.withValue(null); - return true; - } - - @Override - public FilterType getFilterType() { - return FilterType.EDGE_VALUE_DROP; - } - - @Override - public String toString() { - return getFilterType().name(); - } + private static final EdgeValueDropFilter filter = new EdgeValueDropFilter(); + + public static EdgeValueDropFilter getInstance() { + return filter; + } + + @Override + public boolean filter(IEdge value) { + value.withValue(null); + return true; + } + + @Override + public FilterType getFilterType() { + return FilterType.EDGE_VALUE_DROP; + } + + @Override + public String toString() { + return getFilterType().name(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/EmptyFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/EmptyFilter.java index 8a81d4c00..b38dca495 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/EmptyFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/EmptyFilter.java @@ -23,29 +23,29 @@ public class EmptyFilter implements IFilter { - private static final EmptyFilter FILTER = new EmptyFilter(); - - public static EmptyFilter getInstance() { - return FILTER; - } - - @Override - public boolean filter(Object value) { - return true; - } - - @Override - public DataType dateType() { - return DataType.OTHER; - } - - @Override - public FilterType getFilterType() { - return FilterType.EMPTY; - } - - @Override - public String toString() { - return getFilterType().name(); - } + private static final EmptyFilter FILTER = new EmptyFilter(); + + public static EmptyFilter getInstance() { + return FILTER; + } + + @Override + public boolean filter(Object value) { + return true; + } + + @Override + public DataType dateType() { + return DataType.OTHER; + } + + @Override + public FilterType getFilterType() { + return FilterType.EMPTY; + } + + @Override + public String toString() { + return getFilterType().name(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/FilterType.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/FilterType.java index 1a33d45dc..9d895bf62 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/FilterType.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/FilterType.java @@ -22,85 +22,53 @@ import org.apache.geaflow.state.pushdown.inner.PushDownPb; public enum FilterType { - /** - * empty filter. - */ - EMPTY(false), - /** - * in edge filter. - */ - IN_EDGE(false), - /** - * out edge filter. - */ - OUT_EDGE(false), - /** - * edge ts filter. - */ - EDGE_TS(false), - /** - * ttl filter. - */ - TTL(false), - /** - * logic or filters. - */ - OR(false), - /** - * logic and filters. - */ - AND(false), - /** - * only fetch vertex. - */ - ONLY_VERTEX(true), - /** - * vertex ts filter. - */ - VERTEX_TS(false), - /** - * edge value drop. - */ - EDGE_VALUE_DROP(true), - /** - * vertex value drop. - */ - VERTEX_VALUE_DROP(true), - /** - * edge label filter. - */ - EDGE_LABEL(false), - /** - * vertex label filter. - */ - VERTEX_LABEL(false), - /** - * result must contain vertex. - */ - VERTEX_MUST_CONTAIN(true), + /** empty filter. */ + EMPTY(false), + /** in edge filter. */ + IN_EDGE(false), + /** out edge filter. */ + OUT_EDGE(false), + /** edge ts filter. */ + EDGE_TS(false), + /** ttl filter. */ + TTL(false), + /** logic or filters. */ + OR(false), + /** logic and filters. */ + AND(false), + /** only fetch vertex. */ + ONLY_VERTEX(true), + /** vertex ts filter. */ + VERTEX_TS(false), + /** edge value drop. */ + EDGE_VALUE_DROP(true), + /** vertex value drop. */ + VERTEX_VALUE_DROP(true), + /** edge label filter. */ + EDGE_LABEL(false), + /** vertex label filter. */ + VERTEX_LABEL(false), + /** result must contain vertex. */ + VERTEX_MUST_CONTAIN(true), - /** - * generated filter. - */ - GENERATED(false), - /** - * other filter type. - */ - OTHER(false); + /** generated filter. */ + GENERATED(false), + /** other filter type. */ + OTHER(false); - private final PushDownPb.FilterType pbFilterType; - private boolean isRootFilter; + private final PushDownPb.FilterType pbFilterType; + private boolean isRootFilter; - FilterType(boolean isRootFilter) { - this.isRootFilter = isRootFilter; - this.pbFilterType = PushDownPb.FilterType.valueOf(this.name()); - } + FilterType(boolean isRootFilter) { + this.isRootFilter = isRootFilter; + this.pbFilterType = PushDownPb.FilterType.valueOf(this.name()); + } - public boolean isRootFilter() { - return isRootFilter; - } + public boolean isRootFilter() { + return isRootFilter; + } - public PushDownPb.FilterType toPbFilterType() { - return pbFilterType; - } + public PushDownPb.FilterType toPbFilterType() { + return pbFilterType; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/IEdgeFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/IEdgeFilter.java index 304beb377..1242d0554 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/IEdgeFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/IEdgeFilter.java @@ -24,8 +24,8 @@ public interface IEdgeFilter extends IFilter> { - @Override - default DataType dateType() { - return DataType.E; - } + @Override + default DataType dateType() { + return DataType.E; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/IFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/IFilter.java index 3050877f6..26df395f2 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/IFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/IFilter.java @@ -20,41 +20,30 @@ package org.apache.geaflow.state.pushdown.filter; import java.io.Serializable; + import org.apache.geaflow.state.data.DataType; -/** - * The Filter interface is used for condition pushdown. - */ +/** The Filter interface is used for condition pushdown. */ public interface IFilter extends Serializable { - /** - * Filter the specific value, true means keep. - */ - boolean filter(T value); - - /** - * Returns the filter value type {@link DataType}. - */ - DataType dateType(); - - /** - * Returns the filter's type {@link FilterType}. - */ - default FilterType getFilterType() { - return FilterType.OTHER; - } - - /** - * Returns the logical and filter of two filter. - */ - default AndFilter and(IFilter filter) { - return new AndFilter(this, filter); - } - - /** - * Returns the logical or filter of two filter. - */ - default OrFilter or(IFilter filter) { - return new OrFilter(this, filter); - } + /** Filter the specific value, true means keep. */ + boolean filter(T value); + + /** Returns the filter value type {@link DataType}. */ + DataType dateType(); + + /** Returns the filter's type {@link FilterType}. */ + default FilterType getFilterType() { + return FilterType.OTHER; + } + + /** Returns the logical and filter of two filter. */ + default AndFilter and(IFilter filter) { + return new AndFilter(this, filter); + } + + /** Returns the logical or filter of two filter. */ + default OrFilter or(IFilter filter) { + return new OrFilter(this, filter); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/IOneDegreeGraphFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/IOneDegreeGraphFilter.java index 6b86fa317..916af2a4f 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/IOneDegreeGraphFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/IOneDegreeGraphFilter.java @@ -24,8 +24,8 @@ public interface IOneDegreeGraphFilter extends IFilter> { - @Override - default DataType dateType() { - return DataType.VE; - } + @Override + default DataType dateType() { + return DataType.VE; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/IVertexFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/IVertexFilter.java index 11d9c2da9..3379b1aca 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/IVertexFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/IVertexFilter.java @@ -24,8 +24,8 @@ public interface IVertexFilter extends IFilter> { - @Override - default DataType dateType() { - return DataType.V; - } + @Override + default DataType dateType() { + return DataType.V; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/InEdgeFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/InEdgeFilter.java index c9dcc2bc3..01fbe9511 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/InEdgeFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/InEdgeFilter.java @@ -24,24 +24,24 @@ public class InEdgeFilter implements IEdgeFilter { - private static final InEdgeFilter inEdgeFilter = new InEdgeFilter(); - - @Override - public boolean filter(IEdge value) { - return value.getDirect() == EdgeDirection.IN; - } - - @Override - public FilterType getFilterType() { - return FilterType.IN_EDGE; - } - - public static InEdgeFilter getInstance() { - return inEdgeFilter; - } - - @Override - public String toString() { - return getFilterType().name(); - } + private static final InEdgeFilter inEdgeFilter = new InEdgeFilter(); + + @Override + public boolean filter(IEdge value) { + return value.getDirect() == EdgeDirection.IN; + } + + @Override + public FilterType getFilterType() { + return FilterType.IN_EDGE; + } + + public static InEdgeFilter getInstance() { + return inEdgeFilter; + } + + @Override + public String toString() { + return getFilterType().name(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/OrFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/OrFilter.java index f172686b3..11bd58ef5 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/OrFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/OrFilter.java @@ -19,61 +19,65 @@ package org.apache.geaflow.state.pushdown.filter; -import com.google.common.base.Preconditions; import java.util.BitSet; import java.util.List; -import org.apache.geaflow.state.data.DataType; -public class OrFilter extends BaseLogicFilter { +import org.apache.geaflow.state.data.DataType; - private static final BitSet EMPTY_BIT_SET = new BitSet(); - private boolean singleLimit; +import com.google.common.base.Preconditions; - public OrFilter(List inputs) { - super(inputs); - } +public class OrFilter extends BaseLogicFilter { - public OrFilter(IFilter... inputs) { - super(inputs); - } + private static final BitSet EMPTY_BIT_SET = new BitSet(); + private boolean singleLimit; - void sanityCheck(IFilter filter) { - FilterType filterType = filter.getFilterType(); - DataType filterDataType = filter.dateType(); + public OrFilter(List inputs) { + super(inputs); + } - Preconditions.checkArgument(!filterType.isRootFilter(), "filter illegal %s", filterType); - if (this.dateType == null) { - this.dateType = filterDataType; - } else { - Preconditions.checkArgument(this.dateType == filterDataType, - "mix vertex filter and edge filter in OR list, %s, %s", - this.dateType, filterDataType); - } + public OrFilter(IFilter... inputs) { + super(inputs); + } - // root filters in or must be the same. - BitSet filterRootBitSet = EMPTY_BIT_SET; - if (filterType == FilterType.OR || filter.getFilterType() == FilterType.AND) { - filterRootBitSet = ((BaseLogicFilter) filter).rootBitSet; - } + void sanityCheck(IFilter filter) { + FilterType filterType = filter.getFilterType(); + DataType filterDataType = filter.dateType(); - if (filters.size() == 0 && filterRootBitSet != EMPTY_BIT_SET) { - this.rootBitSet = filterRootBitSet; - } else if (filters.size() > 0) { - Preconditions.checkArgument(this.rootBitSet.equals(filterRootBitSet)); - } + Preconditions.checkArgument(!filterType.isRootFilter(), "filter illegal %s", filterType); + if (this.dateType == null) { + this.dateType = filterDataType; + } else { + Preconditions.checkArgument( + this.dateType == filterDataType, + "mix vertex filter and edge filter in OR list, %s, %s", + this.dateType, + filterDataType); } - public OrFilter singleLimit() { - this.singleLimit = true; - return this; + // root filters in or must be the same. + BitSet filterRootBitSet = EMPTY_BIT_SET; + if (filterType == FilterType.OR || filter.getFilterType() == FilterType.AND) { + filterRootBitSet = ((BaseLogicFilter) filter).rootBitSet; } - public boolean isSingleLimit() { - return singleLimit; + if (filters.size() == 0 && filterRootBitSet != EMPTY_BIT_SET) { + this.rootBitSet = filterRootBitSet; + } else if (filters.size() > 0) { + Preconditions.checkArgument(this.rootBitSet.equals(filterRootBitSet)); } + } - @Override - public FilterType getFilterType() { - return FilterType.OR; - } + public OrFilter singleLimit() { + this.singleLimit = true; + return this; + } + + public boolean isSingleLimit() { + return singleLimit; + } + + @Override + public FilterType getFilterType() { + return FilterType.OR; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/OutEdgeFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/OutEdgeFilter.java index e8240d29f..99baeadc5 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/OutEdgeFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/OutEdgeFilter.java @@ -24,24 +24,24 @@ public class OutEdgeFilter implements IEdgeFilter { - private static final OutEdgeFilter outEdgeFilter = new OutEdgeFilter(); - - public static OutEdgeFilter getInstance() { - return outEdgeFilter; - } - - @Override - public boolean filter(IEdge value) { - return value.getDirect() == EdgeDirection.OUT; - } - - @Override - public FilterType getFilterType() { - return FilterType.OUT_EDGE; - } - - @Override - public String toString() { - return getFilterType().name(); - } + private static final OutEdgeFilter outEdgeFilter = new OutEdgeFilter(); + + public static OutEdgeFilter getInstance() { + return outEdgeFilter; + } + + @Override + public boolean filter(IEdge value) { + return value.getDirect() == EdgeDirection.OUT; + } + + @Override + public FilterType getFilterType() { + return FilterType.OUT_EDGE; + } + + @Override + public String toString() { + return getFilterType().name(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/VertexLabelFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/VertexLabelFilter.java index c84c47529..60450cb2f 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/VertexLabelFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/VertexLabelFilter.java @@ -19,46 +19,47 @@ package org.apache.geaflow.state.pushdown.filter; -import com.google.common.base.Joiner; -import com.google.common.collect.Sets; import java.util.Collection; import java.util.Set; + import org.apache.geaflow.model.graph.IGraphElementWithLabelField; import org.apache.geaflow.model.graph.vertex.IVertex; +import com.google.common.base.Joiner; +import com.google.common.collect.Sets; + public class VertexLabelFilter implements IVertexFilter { - private Set labels; + private Set labels; - public VertexLabelFilter(Collection list) { - this.labels = Sets.newHashSet(list); - } + public VertexLabelFilter(Collection list) { + this.labels = Sets.newHashSet(list); + } - public VertexLabelFilter(String... labels) { - this.labels = Sets.newHashSet(labels); - } + public VertexLabelFilter(String... labels) { + this.labels = Sets.newHashSet(labels); + } - public static VertexLabelFilter getInstance(String... labels) { - return new VertexLabelFilter<>(labels); - } + public static VertexLabelFilter getInstance(String... labels) { + return new VertexLabelFilter<>(labels); + } - @Override - public boolean filter(IVertex value) { - return labels.contains(((IGraphElementWithLabelField) value).getLabel()); - } + @Override + public boolean filter(IVertex value) { + return labels.contains(((IGraphElementWithLabelField) value).getLabel()); + } - public Set getLabels() { - return labels; - } + public Set getLabels() { + return labels; + } - @Override - public FilterType getFilterType() { - return FilterType.VERTEX_LABEL; - } + @Override + public FilterType getFilterType() { + return FilterType.VERTEX_LABEL; + } - @Override - public String toString() { - return String.format("\"%s(%s)\"", getFilterType().toString(), - Joiner.on(',').join(labels)); - } + @Override + public String toString() { + return String.format("\"%s(%s)\"", getFilterType().toString(), Joiner.on(',').join(labels)); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/VertexMustContainFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/VertexMustContainFilter.java index 745d84cda..85bc9f6e3 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/VertexMustContainFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/VertexMustContainFilter.java @@ -23,24 +23,24 @@ public class VertexMustContainFilter implements IOneDegreeGraphFilter { - private static final VertexMustContainFilter filter = new VertexMustContainFilter(); - - @Override - public boolean filter(OneDegreeGraph value) { - return value.getVertex() != null; - } - - @Override - public FilterType getFilterType() { - return FilterType.VERTEX_MUST_CONTAIN; - } - - public static VertexMustContainFilter getInstance() { - return filter; - } - - @Override - public String toString() { - return getFilterType().name(); - } + private static final VertexMustContainFilter filter = new VertexMustContainFilter(); + + @Override + public boolean filter(OneDegreeGraph value) { + return value.getVertex() != null; + } + + @Override + public FilterType getFilterType() { + return FilterType.VERTEX_MUST_CONTAIN; + } + + public static VertexMustContainFilter getInstance() { + return filter; + } + + @Override + public String toString() { + return getFilterType().name(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/VertexTsFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/VertexTsFilter.java index fb7b5723f..f773898b5 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/VertexTsFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/VertexTsFilter.java @@ -25,33 +25,33 @@ public class VertexTsFilter implements IVertexFilter { - private final TimeRange timeRange; - - public VertexTsFilter(TimeRange timeRange) { - this.timeRange = timeRange; - } - - public static VertexTsFilter getInstance(long start, long end) { - return new VertexTsFilter<>(TimeRange.of(start, end)); - } - - @Override - public boolean filter(IVertex value) { - return timeRange.contain(((IGraphElementWithTimeField) value).getTime()); - } - - public TimeRange getTimeRange() { - return timeRange; - } - - @Override - public FilterType getFilterType() { - return FilterType.VERTEX_TS; - } - - @Override - public String toString() { - return String.format("\"%s[%d,%d)\"", getFilterType().toString(), - timeRange.getStart(), timeRange.getEnd()); - } + private final TimeRange timeRange; + + public VertexTsFilter(TimeRange timeRange) { + this.timeRange = timeRange; + } + + public static VertexTsFilter getInstance(long start, long end) { + return new VertexTsFilter<>(TimeRange.of(start, end)); + } + + @Override + public boolean filter(IVertex value) { + return timeRange.contain(((IGraphElementWithTimeField) value).getTime()); + } + + public TimeRange getTimeRange() { + return timeRange; + } + + @Override + public FilterType getFilterType() { + return FilterType.VERTEX_TS; + } + + @Override + public String toString() { + return String.format( + "\"%s[%d,%d)\"", getFilterType().toString(), timeRange.getStart(), timeRange.getEnd()); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/VertexValueDropFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/VertexValueDropFilter.java index 197b0a8fe..b4e6e40c1 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/VertexValueDropFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/VertexValueDropFilter.java @@ -23,25 +23,25 @@ public class VertexValueDropFilter implements IVertexFilter { - private static final VertexValueDropFilter filter = new VertexValueDropFilter(); - - public static VertexValueDropFilter getInstance() { - return filter; - } - - @Override - public boolean filter(IVertex value) { - value.withValue(null); - return true; - } - - @Override - public FilterType getFilterType() { - return FilterType.VERTEX_VALUE_DROP; - } - - @Override - public String toString() { - return getFilterType().name(); - } + private static final VertexValueDropFilter filter = new VertexValueDropFilter(); + + public static VertexValueDropFilter getInstance() { + return filter; + } + + @Override + public boolean filter(IVertex value) { + value.withValue(null); + return true; + } + + @Override + public FilterType getFilterType() { + return FilterType.VERTEX_VALUE_DROP; + } + + @Override + public String toString() { + return getFilterType().name(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/AndGraphFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/AndGraphFilter.java index 10859c3c8..1160f091e 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/AndGraphFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/AndGraphFilter.java @@ -20,6 +20,7 @@ package org.apache.geaflow.state.pushdown.filter.inner; import java.util.List; + import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; import org.apache.geaflow.state.data.OneDegreeGraph; @@ -27,105 +28,99 @@ public class AndGraphFilter extends BaseComposeGraphFilter { - public AndGraphFilter(List childrenFilters) { - super(childrenFilters); - } - - @Override - public IGraphFilter and(IGraphFilter other) { - switch (other.getFilterType()) { - case EMPTY: - return this; - case AND: - childrenFilters.addAll(((AndGraphFilter) (other)).getFilterList()); - return this; - default: - childrenFilters.add(other); - return this; - } - - } + public AndGraphFilter(List childrenFilters) { + super(childrenFilters); + } - /** - * An edge will be filtered if return is false. - * - * @param edge - */ - @Override - public boolean filterEdge(IEdge edge) { - for (IGraphFilter filter : childrenFilters) { - if (!filter.filterEdge(edge)) { - return false; - } - } - return true; + @Override + public IGraphFilter and(IGraphFilter other) { + switch (other.getFilterType()) { + case EMPTY: + return this; + case AND: + childrenFilters.addAll(((AndGraphFilter) (other)).getFilterList()); + return this; + default: + childrenFilters.add(other); + return this; } + } - /** - * A Vertex will be filtered if return is false. - * - * @param vertex - */ - @Override - public boolean filterVertex(IVertex vertex) { - for (IGraphFilter filter : childrenFilters) { - if (!filter.filterVertex(vertex)) { - return false; - } - } - return true; + /** + * An edge will be filtered if return is false. + * + * @param edge + */ + @Override + public boolean filterEdge(IEdge edge) { + for (IGraphFilter filter : childrenFilters) { + if (!filter.filterEdge(edge)) { + return false; + } } + return true; + } - /** - * A oneDegreeGraph will be filtered if return is false. - */ - @Override - public boolean filterOneDegreeGraph(OneDegreeGraph oneDegreeGraph) { - for (IGraphFilter filter : childrenFilters) { - if (!filter.filterOneDegreeGraph(oneDegreeGraph)) { - return false; - } - } - return true; + /** + * A Vertex will be filtered if return is false. + * + * @param vertex + */ + @Override + public boolean filterVertex(IVertex vertex) { + for (IGraphFilter filter : childrenFilters) { + if (!filter.filterVertex(vertex)) { + return false; + } } + return true; + } - @Override - public AndGraphFilter clone() { - return new AndGraphFilter(cloneFilterList()); + /** A oneDegreeGraph will be filtered if return is false. */ + @Override + public boolean filterOneDegreeGraph(OneDegreeGraph oneDegreeGraph) { + for (IGraphFilter filter : childrenFilters) { + if (!filter.filterOneDegreeGraph(oneDegreeGraph)) { + return false; + } } + return true; + } - @Override - public boolean contains(FilterType type) { - if (type == FilterType.AND) { - return true; - } - for (IGraphFilter filter : childrenFilters) { - if (filter.contains(type)) { - return true; - } - } - return false; + @Override + public AndGraphFilter clone() { + return new AndGraphFilter(cloneFilterList()); + } + @Override + public boolean contains(FilterType type) { + if (type == FilterType.AND) { + return true; } - - @Override - public IGraphFilter retrieve(FilterType type) { - if (type == FilterType.AND) { - return this; - } - for (IGraphFilter filter : childrenFilters) { - if (filter.contains(type)) { - return filter.retrieve(type); - } - } - return null; + for (IGraphFilter filter : childrenFilters) { + if (filter.contains(type)) { + return true; + } } + return false; + } - /** - * FilterTypes. - */ - @Override - public FilterType getFilterType() { - return FilterType.AND; + @Override + public IGraphFilter retrieve(FilterType type) { + if (type == FilterType.AND) { + return this; } + for (IGraphFilter filter : childrenFilters) { + if (filter.contains(type)) { + return filter.retrieve(type); + } + } + return null; + } + + /** FilterTypes. */ + @Override + public FilterType getFilterType() { + return FilterType.AND; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/BaseComposeGraphFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/BaseComposeGraphFilter.java index 7cbc3cba3..879a7d586 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/BaseComposeGraphFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/BaseComposeGraphFilter.java @@ -21,64 +21,63 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.state.data.DataType; import org.apache.geaflow.state.pushdown.filter.IFilter; public abstract class BaseComposeGraphFilter extends BaseGraphFilter { - protected List childrenFilters; + protected List childrenFilters; - public BaseComposeGraphFilter(List childrenFilters) { - this.childrenFilters = childrenFilters; - } + public BaseComposeGraphFilter(List childrenFilters) { + this.childrenFilters = childrenFilters; + } - @Override - public DataType dateType() { - return DataType.OTHER; - } + @Override + public DataType dateType() { + return DataType.OTHER; + } - /** - * If this returns true, the edge scan will terminate. - */ - @Override - public boolean dropAllRemaining() { - for (IGraphFilter filter : childrenFilters) { - if (!filter.dropAllRemaining()) { - return false; - } - } - return true; + /** If this returns true, the edge scan will terminate. */ + @Override + public boolean dropAllRemaining() { + for (IGraphFilter filter : childrenFilters) { + if (!filter.dropAllRemaining()) { + return false; + } } + return true; + } - public List getFilterList() { - return this.childrenFilters; - } + public List getFilterList() { + return this.childrenFilters; + } - public List cloneFilterList() { - List copyList = new ArrayList<>(childrenFilters.size()); - childrenFilters.forEach(c -> copyList.add(c.clone())); - return copyList; - } + public List cloneFilterList() { + List copyList = new ArrayList<>(childrenFilters.size()); + childrenFilters.forEach(c -> copyList.add(c.clone())); + return copyList; + } - @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append("{\"").append(getFilterType().name()).append("\":["); - if (childrenFilters != null && childrenFilters.size() > 0) { - int size = childrenFilters.size(); - int index = 0; - for (IFilter filter : childrenFilters) { - sb.append(filter.toString()); - if (index < size - 1) { - sb.append(","); - } - index++; - } + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("{\"").append(getFilterType().name()).append("\":["); + if (childrenFilters != null && childrenFilters.size() > 0) { + int size = childrenFilters.size(); + int index = 0; + for (IFilter filter : childrenFilters) { + sb.append(filter.toString()); + if (index < size - 1) { + sb.append(","); } + index++; + } + } - sb.append("]"); - sb.append("}"); + sb.append("]"); + sb.append("}"); - return sb.toString(); - } + return sb.toString(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/BaseGraphFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/BaseGraphFilter.java index ec903c98a..8d6249ab3 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/BaseGraphFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/BaseGraphFilter.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.model.graph.edge.IEdge; @@ -30,111 +31,101 @@ public abstract class BaseGraphFilter implements IGraphFilter { - @Override - public boolean filter(Object value) { - switch (dateType()) { - case V: - return filterVertex((IVertex) value); - case E: - return filterEdge((IEdge) value); - case VE: - return filterOneDegreeGraph((OneDegreeGraph) value); - default: - throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("not support " + dateType())); - } + @Override + public boolean filter(Object value) { + switch (dateType()) { + case V: + return filterVertex((IVertex) value); + case E: + return filterEdge((IEdge) value); + case VE: + return filterOneDegreeGraph((OneDegreeGraph) value); + default: + throw new GeaflowRuntimeException(RuntimeErrors.INST.runError("not support " + dateType())); } + } - /** - * An edge will be filtered if return is false. - * - * @param edge - */ - @Override - public boolean filterEdge(IEdge edge) { - return true; - } + /** + * An edge will be filtered if return is false. + * + * @param edge + */ + @Override + public boolean filterEdge(IEdge edge) { + return true; + } - /** - * A Vertex will be filtered if return is false. - * - * @param vertex - */ - @Override - public boolean filterVertex(IVertex vertex) { - return true; - } + /** + * A Vertex will be filtered if return is false. + * + * @param vertex + */ + @Override + public boolean filterVertex(IVertex vertex) { + return true; + } - /** - * A oneDegreeGraph will be filtered if return is false. - */ - @Override - public boolean filterOneDegreeGraph(OneDegreeGraph oneDegreeGraph) { - return true; - } + /** A oneDegreeGraph will be filtered if return is false. */ + @Override + public boolean filterOneDegreeGraph(OneDegreeGraph oneDegreeGraph) { + return true; + } - /** - * If this returns true, the edge scan will terminate. - */ - @Override - public boolean dropAllRemaining() { - return false; - } + /** If this returns true, the edge scan will terminate. */ + @Override + public boolean dropAllRemaining() { + return false; + } - /** - * FilterTypes. - */ - @Override - public FilterType getFilterType() { - return FilterType.OTHER; - } + /** FilterTypes. */ + @Override + public FilterType getFilterType() { + return FilterType.OTHER; + } - /** - * Union other filter with AND logic. - */ - @Override - public IGraphFilter and(IGraphFilter other) { - List list = new ArrayList<>(); - list.add(this); - if (other.getFilterType() == FilterType.AND) { - list.addAll(((AndGraphFilter) other).getFilterList()); - } else { - list.add(other); - } - return new AndGraphFilter(list); + /** Union other filter with AND logic. */ + @Override + public IGraphFilter and(IGraphFilter other) { + List list = new ArrayList<>(); + list.add(this); + if (other.getFilterType() == FilterType.AND) { + list.addAll(((AndGraphFilter) other).getFilterList()); + } else { + list.add(other); } + return new AndGraphFilter(list); + } - /** - * Union other filter with OR logic. - */ - @Override - public IGraphFilter or(IGraphFilter other) { - List list = new ArrayList<>(); - list.add(this); - if (other.getFilterType() == FilterType.OR) { - list.addAll(((OrGraphFilter) other).getFilterList()); - } else { - list.add(other); - } - return new OrGraphFilter(list); + /** Union other filter with OR logic. */ + @Override + public IGraphFilter or(IGraphFilter other) { + List list = new ArrayList<>(); + list.add(this); + if (other.getFilterType() == FilterType.OR) { + list.addAll(((OrGraphFilter) other).getFilterList()); + } else { + list.add(other); } + return new OrGraphFilter(list); + } - @Override - public boolean contains(FilterType type) { - return type == getFilterType(); - } + @Override + public boolean contains(FilterType type) { + return type == getFilterType(); + } - @Override - public IGraphFilter retrieve(FilterType type) { - return contains(type) ? this : null; - } + @Override + public IGraphFilter retrieve(FilterType type) { + return contains(type) ? this : null; + } - @Override - public String toString() { - return getFilterType().name(); - } + @Override + public String toString() { + return getFilterType().name(); + } - @Override - public IGraphFilter clone() { - return this; - } + @Override + public IGraphFilter clone() { + return this; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/EmptyGraphFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/EmptyGraphFilter.java index cede42738..8e63fcb45 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/EmptyGraphFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/EmptyGraphFilter.java @@ -27,44 +27,44 @@ public class EmptyGraphFilter extends BaseGraphFilter { - private static final EmptyGraphFilter filter = new EmptyGraphFilter(); + private static final EmptyGraphFilter filter = new EmptyGraphFilter(); - public static EmptyGraphFilter of() { - return filter; - } + public static EmptyGraphFilter of() { + return filter; + } - @Override - public IGraphFilter and(IGraphFilter other) { - return other; - } + @Override + public IGraphFilter and(IGraphFilter other) { + return other; + } - @Override - public IGraphFilter or(IGraphFilter other) { - return other; - } + @Override + public IGraphFilter or(IGraphFilter other) { + return other; + } - @Override - public boolean filterEdge(IEdge edge) { - return true; - } + @Override + public boolean filterEdge(IEdge edge) { + return true; + } - @Override - public boolean filterVertex(IVertex vertex) { - return true; - } + @Override + public boolean filterVertex(IVertex vertex) { + return true; + } - @Override - public boolean filterOneDegreeGraph(OneDegreeGraph oneDegreeGraph) { - return true; - } + @Override + public boolean filterOneDegreeGraph(OneDegreeGraph oneDegreeGraph) { + return true; + } - @Override - public DataType dateType() { - return DataType.OTHER; - } + @Override + public DataType dateType() { + return DataType.OTHER; + } - @Override - public FilterType getFilterType() { - return FilterType.EMPTY; - } + @Override + public FilterType getFilterType() { + return FilterType.EMPTY; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/FilterHelper.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/FilterHelper.java index 49163be10..96ed9b4fe 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/FilterHelper.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/FilterHelper.java @@ -19,9 +19,9 @@ package org.apache.geaflow.state.pushdown.filter.inner; -import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.state.data.TimeRange; import org.apache.geaflow.state.pushdown.filter.EdgeLabelFilter; import org.apache.geaflow.state.pushdown.filter.EdgeTsFilter; @@ -31,140 +31,141 @@ import org.apache.geaflow.state.pushdown.filter.VertexLabelFilter; import org.apache.geaflow.state.pushdown.filter.VertexTsFilter; +import com.google.common.base.Preconditions; + public class FilterHelper { - public static boolean isSingleLimit(IFilter[] filters) { - int orNumber = 0; - int singleOrNumber = 0; - for (IFilter filter : filters) { - if (filter.getFilterType() == FilterType.OR) { - orNumber++; - if (((OrFilter) filter).isSingleLimit()) { - singleOrNumber++; - } - } + public static boolean isSingleLimit(IFilter[] filters) { + int orNumber = 0; + int singleOrNumber = 0; + for (IFilter filter : filters) { + if (filter.getFilterType() == FilterType.OR) { + orNumber++; + if (((OrFilter) filter).isSingleLimit()) { + singleOrNumber++; } - Preconditions.checkArgument(singleOrNumber == 0 || orNumber == singleOrNumber, - "some or filter is not single"); - return singleOrNumber > 0; + } } - - /** - * Parse label from graph filter, only direct filter converter support. - * - * @param filter graph filter. - * @return list of labels to be filtered. - */ - public static List parseLabel(IGraphFilter filter, boolean isVertex) { - if (filter.getFilterType() == FilterType.OR) { - return parseLabelOr(filter, isVertex); - } else if (filter.getFilterType() == FilterType.AND) { - return parseLabelAnd(filter, isVertex); - } else { - return parseLabelNormal(filter, isVertex); - } + Preconditions.checkArgument( + singleOrNumber == 0 || orNumber == singleOrNumber, "some or filter is not single"); + return singleOrNumber > 0; + } + + /** + * Parse label from graph filter, only direct filter converter support. + * + * @param filter graph filter. + * @return list of labels to be filtered. + */ + public static List parseLabel(IGraphFilter filter, boolean isVertex) { + if (filter.getFilterType() == FilterType.OR) { + return parseLabelOr(filter, isVertex); + } else if (filter.getFilterType() == FilterType.AND) { + return parseLabelAnd(filter, isVertex); + } else { + return parseLabelNormal(filter, isVertex); } + } - private static List parseLabelOr(IGraphFilter filter, boolean isVertex) { - List labels = new ArrayList<>(); - OrGraphFilter orFilter = (OrGraphFilter) filter; + private static List parseLabelOr(IGraphFilter filter, boolean isVertex) { + List labels = new ArrayList<>(); + OrGraphFilter orFilter = (OrGraphFilter) filter; - if (orFilter.getFilterList().isEmpty()) { - return labels; - } - - for (IGraphFilter childFilter : orFilter.getFilterList()) { - List tmpLabels = parseLabel(childFilter, isVertex); - if (tmpLabels.isEmpty()) { - labels.clear(); - return labels; - } else { - labels.addAll(tmpLabels); - } - } + if (orFilter.getFilterList().isEmpty()) { + return labels; + } + for (IGraphFilter childFilter : orFilter.getFilterList()) { + List tmpLabels = parseLabel(childFilter, isVertex); + if (tmpLabels.isEmpty()) { + labels.clear(); return labels; + } else { + labels.addAll(tmpLabels); + } } - private static List parseLabelAnd(IGraphFilter filter, boolean isVertex) { - List labels = new ArrayList<>(); - AndGraphFilter andFilter = (AndGraphFilter) filter; + return labels; + } - if (andFilter.getFilterList().isEmpty()) { - return labels; - } + private static List parseLabelAnd(IGraphFilter filter, boolean isVertex) { + List labels = new ArrayList<>(); + AndGraphFilter andFilter = (AndGraphFilter) filter; - for (IGraphFilter childFilter : andFilter.getFilterList()) { - labels.addAll(parseLabel(childFilter, isVertex)); - if (labels.size() > 1) { - return new ArrayList<>(); - } - } + if (andFilter.getFilterList().isEmpty()) { + return labels; + } - return labels; + for (IGraphFilter childFilter : andFilter.getFilterList()) { + labels.addAll(parseLabel(childFilter, isVertex)); + if (labels.size() > 1) { + return new ArrayList<>(); + } } - private static List parseLabelNormal(IGraphFilter filter, boolean isVertex) { - List labels = new ArrayList<>(); - if (!isVertex && filter.getFilterType() == FilterType.EDGE_LABEL) { - EdgeLabelFilter labelFilter = (EdgeLabelFilter) filter.getFilter(); - labels.addAll(labelFilter.getLabels()); - } else if (isVertex && filter.getFilterType() == FilterType.VERTEX_LABEL) { - VertexLabelFilter labelFilter = (VertexLabelFilter) filter.getFilter(); - labels.addAll(labelFilter.getLabels()); - } + return labels; + } + + private static List parseLabelNormal(IGraphFilter filter, boolean isVertex) { + List labels = new ArrayList<>(); + if (!isVertex && filter.getFilterType() == FilterType.EDGE_LABEL) { + EdgeLabelFilter labelFilter = (EdgeLabelFilter) filter.getFilter(); + labels.addAll(labelFilter.getLabels()); + } else if (isVertex && filter.getFilterType() == FilterType.VERTEX_LABEL) { + VertexLabelFilter labelFilter = (VertexLabelFilter) filter.getFilter(); + labels.addAll(labelFilter.getLabels()); + } - return labels; + return labels; + } + + /** + * Parse dt from graph filter, only direct filter converter support. + * + * @param filter graph filter. + * @return time range + */ + public static TimeRange parseDt(IGraphFilter filter, boolean isVertex) { + if (filter.getFilterType() == FilterType.OR || filter.getFilterType() == FilterType.AND) { + return parseDtCompose(filter, isVertex); + } else { + return parseDtNormal(filter, isVertex); } + } - /** - * Parse dt from graph filter, only direct filter converter support. - * - * @param filter graph filter. - * @return time range - */ - public static TimeRange parseDt(IGraphFilter filter, boolean isVertex) { - if (filter.getFilterType() == FilterType.OR || filter.getFilterType() == FilterType.AND) { - return parseDtCompose(filter, isVertex); - } else { - return parseDtNormal(filter, isVertex); - } + private static TimeRange parseDtCompose(IGraphFilter filter, boolean isVertex) { + BaseComposeGraphFilter composeFilter = (BaseComposeGraphFilter) filter; + if (composeFilter.getFilterList().isEmpty()) { + return null; } - private static TimeRange parseDtCompose(IGraphFilter filter, boolean isVertex) { - BaseComposeGraphFilter composeFilter = (BaseComposeGraphFilter) filter; - if (composeFilter.getFilterList().isEmpty()) { - return null; - } + TimeRange range = null; + for (IGraphFilter childFilter : composeFilter.getFilterList()) { + range = parseDtNormal(childFilter, isVertex); + if (range != null) { + return range; + } + } - TimeRange range = null; - for (IGraphFilter childFilter : composeFilter.getFilterList()) { - range = parseDtNormal(childFilter, isVertex); - if (range != null) { - return range; - } - } + return null; + } - return null; - } + private static TimeRange parseDtNormal(IGraphFilter filter, boolean isVertex) { + if (isVertex) { + if (filter.contains(FilterType.VERTEX_TS)) { + VertexTsFilter tsFilter = + (VertexTsFilter) (filter.retrieve(FilterType.VERTEX_TS).getFilter()); - private static TimeRange parseDtNormal(IGraphFilter filter, boolean isVertex) { - if (isVertex) { - if (filter.contains(FilterType.VERTEX_TS)) { - VertexTsFilter tsFilter = (VertexTsFilter) (filter.retrieve(FilterType.VERTEX_TS) - .getFilter()); - - return tsFilter.getTimeRange(); - } - } else { - if (filter.contains(FilterType.EDGE_TS)) { - EdgeTsFilter tsFilter = (EdgeTsFilter) (filter.retrieve(FilterType.EDGE_TS) - .getFilter()); - - return tsFilter.getTimeRange(); - } - } + return tsFilter.getTimeRange(); + } + } else { + if (filter.contains(FilterType.EDGE_TS)) { + EdgeTsFilter tsFilter = (EdgeTsFilter) (filter.retrieve(FilterType.EDGE_TS).getFilter()); - return null; + return tsFilter.getTimeRange(); + } } + + return null; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/GeneratedQueryFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/GeneratedQueryFilter.java index 5e6d1fc98..f72917829 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/GeneratedQueryFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/GeneratedQueryFilter.java @@ -21,8 +21,6 @@ public interface GeneratedQueryFilter { - /** - * setting variables in code generated plan. - */ - void initVariables(Object[] variables); + /** setting variables in code generated plan. */ + void initVariables(Object[] variables); } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/GraphFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/GraphFilter.java index 14323af52..113c8286f 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/GraphFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/GraphFilter.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; import org.apache.geaflow.state.data.DataType; @@ -33,79 +34,81 @@ public class GraphFilter extends BaseGraphFilter { - private final DataType dataType; - private IFilter filter; + private final DataType dataType; + private IFilter filter; - private GraphFilter(IFilter filter) { - this.filter = filter; - this.dataType = filter.dateType(); - } + private GraphFilter(IFilter filter) { + this.filter = filter; + this.dataType = filter.dateType(); + } - public static IGraphFilter of(IFilter filter) { - if (filter instanceof IGraphFilter) { - return (IGraphFilter) filter; - } - switch (filter.getFilterType()) { - case EMPTY: - return EmptyGraphFilter.of(); - case AND: - List childrenFilters = - ((AndFilter) filter).getFilters().stream().map(GraphFilter::of).collect(Collectors.toList()); - return new AndGraphFilter(childrenFilters); - case OR: - childrenFilters = - ((OrFilter) filter).getFilters().stream().map(GraphFilter::of).collect(Collectors.toList()); - return new OrGraphFilter(childrenFilters); - default: - return new GraphFilter(filter); - } + public static IGraphFilter of(IFilter filter) { + if (filter instanceof IGraphFilter) { + return (IGraphFilter) filter; } - - public static IGraphFilter of(IGraphFilter graphFilter, IEdgeLimit limit) { - return limit == null ? graphFilter : LimitFilterBuilder.build(graphFilter, limit); + switch (filter.getFilterType()) { + case EMPTY: + return EmptyGraphFilter.of(); + case AND: + List childrenFilters = + ((AndFilter) filter) + .getFilters().stream().map(GraphFilter::of).collect(Collectors.toList()); + return new AndGraphFilter(childrenFilters); + case OR: + childrenFilters = + ((OrFilter) filter) + .getFilters().stream().map(GraphFilter::of).collect(Collectors.toList()); + return new OrGraphFilter(childrenFilters); + default: + return new GraphFilter(filter); } + } - public static IGraphFilter of(IFilter filter, IEdgeLimit limit) { - IGraphFilter graphFilter = of(filter); - return of(graphFilter, limit); - } + public static IGraphFilter of(IGraphFilter graphFilter, IEdgeLimit limit) { + return limit == null ? graphFilter : LimitFilterBuilder.build(graphFilter, limit); + } - @Override - public boolean filterEdge(IEdge edge) { - if (dataType == DataType.E) { - return filter.filter(edge); - } - return true; - } + public static IGraphFilter of(IFilter filter, IEdgeLimit limit) { + IGraphFilter graphFilter = of(filter); + return of(graphFilter, limit); + } - @Override - public boolean filterVertex(IVertex vertex) { - if (dataType == DataType.V) { - return filter.filter(vertex); - } - return true; + @Override + public boolean filterEdge(IEdge edge) { + if (dataType == DataType.E) { + return filter.filter(edge); } + return true; + } - @Override - public boolean filterOneDegreeGraph(OneDegreeGraph oneDegreeGraph) { - if (dataType == DataType.VE) { - return filter.filter(oneDegreeGraph); - } - return true; + @Override + public boolean filterVertex(IVertex vertex) { + if (dataType == DataType.V) { + return filter.filter(vertex); } + return true; + } - @Override - public DataType dateType() { - return filter.dateType(); + @Override + public boolean filterOneDegreeGraph(OneDegreeGraph oneDegreeGraph) { + if (dataType == DataType.VE) { + return filter.filter(oneDegreeGraph); } + return true; + } - @Override - public FilterType getFilterType() { - return filter.getFilterType(); - } + @Override + public DataType dateType() { + return filter.dateType(); + } - @Override - public IFilter getFilter() { - return filter; - } + @Override + public FilterType getFilterType() { + return filter.getFilterType(); + } + + @Override + public IFilter getFilter() { + return filter; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/IGraphFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/IGraphFilter.java index 2dd63916d..2958713c6 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/IGraphFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/IGraphFilter.java @@ -29,63 +29,38 @@ public interface IGraphFilter extends IFilter { - /** - * An edge will be filtered if return is false. - */ - boolean filterEdge(IEdge edge); + /** An edge will be filtered if return is false. */ + boolean filterEdge(IEdge edge); - /** - * A Vertex will be filtered if return is false. - */ - boolean filterVertex(IVertex vertex); + /** A Vertex will be filtered if return is false. */ + boolean filterVertex(IVertex vertex); - /** - * A oneDegreeGraph will be filtered if return is false. - */ - boolean filterOneDegreeGraph(OneDegreeGraph oneDegreeGraph); + /** A oneDegreeGraph will be filtered if return is false. */ + boolean filterOneDegreeGraph(OneDegreeGraph oneDegreeGraph); - /** - * If this returns true, the edge scan will terminate. - */ - boolean dropAllRemaining(); + /** If this returns true, the edge scan will terminate. */ + boolean dropAllRemaining(); - /** - * FilterTypes. - */ - FilterType getFilterType(); + /** FilterTypes. */ + FilterType getFilterType(); - /** - * Union other filter with AND logic. - */ - IGraphFilter and(IGraphFilter other); + /** Union other filter with AND logic. */ + IGraphFilter and(IGraphFilter other); - /** - * Union other filter with or logic. - */ - IGraphFilter or(IGraphFilter other); + /** Union other filter with or logic. */ + IGraphFilter or(IGraphFilter other); - /** - * Check filter whether contains some type. - * Check itself If the filter is simple. - */ - boolean contains(FilterType type); + /** Check filter whether contains some type. Check itself If the filter is simple. */ + boolean contains(FilterType type); - /** - * retrieve target filter. - * return itself If the filter is simple. - */ - IGraphFilter retrieve(FilterType type); + /** retrieve target filter. return itself If the filter is simple. */ + IGraphFilter retrieve(FilterType type); - /** - * clone the filter. - */ - IGraphFilter clone(); + /** clone the filter. */ + IGraphFilter clone(); - /** - * get the inner filter, only direct filter converter support. - * return inner filter. - */ - default IFilter getFilter() { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + /** get the inner filter, only direct filter converter support. return inner filter. */ + default IFilter getFilter() { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/LimitFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/LimitFilter.java index fa25db33f..9f9de1594 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/LimitFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/LimitFilter.java @@ -32,85 +32,85 @@ public class LimitFilter extends BaseGraphFilter { - protected long inCounter; - protected long outCounter; - protected IGraphFilter filter; - - LimitFilter(IGraphFilter filter, IEdgeLimit limit) { - this.filter = filter; - inCounter = limit.inEdgeLimit(); - outCounter = limit.outEdgeLimit(); + protected long inCounter; + protected long outCounter; + protected IGraphFilter filter; + + LimitFilter(IGraphFilter filter, IEdgeLimit limit) { + this.filter = filter; + inCounter = limit.inEdgeLimit(); + outCounter = limit.outEdgeLimit(); + } + + @Override + public DataType dateType() { + return DataType.OTHER; + } + + @Override + public boolean filterEdge(IEdge edge) { + if (!filter.filterEdge(edge)) { + return false; } - - @Override - public DataType dateType() { - return DataType.OTHER; - } - - @Override - public boolean filterEdge(IEdge edge) { - if (!filter.filterEdge(edge)) { - return false; - } - if (edge.getDirect() == EdgeDirection.OUT && outCounter-- > 0) { - return true; - } else { - return edge.getDirect() == EdgeDirection.IN && inCounter-- > 0; - } - } - - @Override - public boolean filterVertex(IVertex vertex) { - return filter.filterVertex(vertex); - } - - @Override - public boolean filterOneDegreeGraph(OneDegreeGraph oneDegreeGraph) { - return filter.filterOneDegreeGraph(oneDegreeGraph); - } - - @Override - public boolean dropAllRemaining() { - return (outCounter <= 0 && inCounter <= 0) || filter.dropAllRemaining(); - } - - @Override - public String toString() { - return String.format("Limit(%d, %d)", inCounter, outCounter); - } - - @Override - public FilterType getFilterType() { - return filter.getFilterType(); - } - - @Override - public IGraphFilter and(IGraphFilter other) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } - - @Override - public IGraphFilter or(IGraphFilter other) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } - - @Override - public boolean contains(FilterType type) { - return filter.contains(type); - } - - @Override - public IGraphFilter retrieve(FilterType type) { - return filter.retrieve(type); - } - - @Override - public IGraphFilter clone() { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } - - @Override - public IFilter getFilter() { - return filter.getFilter(); + if (edge.getDirect() == EdgeDirection.OUT && outCounter-- > 0) { + return true; + } else { + return edge.getDirect() == EdgeDirection.IN && inCounter-- > 0; } + } + + @Override + public boolean filterVertex(IVertex vertex) { + return filter.filterVertex(vertex); + } + + @Override + public boolean filterOneDegreeGraph(OneDegreeGraph oneDegreeGraph) { + return filter.filterOneDegreeGraph(oneDegreeGraph); + } + + @Override + public boolean dropAllRemaining() { + return (outCounter <= 0 && inCounter <= 0) || filter.dropAllRemaining(); + } + + @Override + public String toString() { + return String.format("Limit(%d, %d)", inCounter, outCounter); + } + + @Override + public FilterType getFilterType() { + return filter.getFilterType(); + } + + @Override + public IGraphFilter and(IGraphFilter other) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + + @Override + public IGraphFilter or(IGraphFilter other) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + + @Override + public boolean contains(FilterType type) { + return filter.contains(type); + } + + @Override + public IGraphFilter retrieve(FilterType type) { + return filter.retrieve(type); + } + + @Override + public IGraphFilter clone() { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + + @Override + public IFilter getFilter() { + return filter.getFilter(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/LimitFilterBuilder.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/LimitFilterBuilder.java index 988bfc81c..6eeafc95e 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/LimitFilterBuilder.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/LimitFilterBuilder.java @@ -25,10 +25,10 @@ public class LimitFilterBuilder { - public static LimitFilter build(IGraphFilter filter, IEdgeLimit limit) { - if (limit.limitType() == LimitType.SINGLE && filter.getFilterType() == FilterType.OR) { - return new SingleLimitFilter(filter, limit); - } - return new LimitFilter(filter, limit); + public static LimitFilter build(IGraphFilter filter, IEdgeLimit limit) { + if (limit.limitType() == LimitType.SINGLE && filter.getFilterType() == FilterType.OR) { + return new SingleLimitFilter(filter, limit); } + return new LimitFilter(filter, limit); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/OrGraphFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/OrGraphFilter.java index 62776c8d7..0bb72e1f4 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/OrGraphFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/OrGraphFilter.java @@ -20,6 +20,7 @@ package org.apache.geaflow.state.pushdown.filter.inner; import java.util.List; + import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; import org.apache.geaflow.state.data.OneDegreeGraph; @@ -27,91 +28,89 @@ public class OrGraphFilter extends BaseComposeGraphFilter { - public OrGraphFilter(List childrenFilters) { - super(childrenFilters); - } + public OrGraphFilter(List childrenFilters) { + super(childrenFilters); + } - @Override - public IGraphFilter or(IGraphFilter filter) { - if (filter.getFilterType() == FilterType.OR) { - childrenFilters.addAll(((OrGraphFilter) (filter)).getFilterList()); - return this; - } else if (filter.getFilterType() == FilterType.EMPTY) { - return this; - } - childrenFilters.add(filter); - return this; + @Override + public IGraphFilter or(IGraphFilter filter) { + if (filter.getFilterType() == FilterType.OR) { + childrenFilters.addAll(((OrGraphFilter) (filter)).getFilterList()); + return this; + } else if (filter.getFilterType() == FilterType.EMPTY) { + return this; } + childrenFilters.add(filter); + return this; + } - /** - * An edge will be filtered if return is false. - * - * @param edge - */ - @Override - public boolean filterEdge(IEdge edge) { - if (childrenFilters.isEmpty()) { - return true; - } - for (IGraphFilter filter : childrenFilters) { - if (filter.filterEdge(edge)) { - return true; - } - } - return false; + /** + * An edge will be filtered if return is false. + * + * @param edge + */ + @Override + public boolean filterEdge(IEdge edge) { + if (childrenFilters.isEmpty()) { + return true; } - - /** - * A Vertex will be filtered if return is false. - * - * @param vertex - */ - @Override - public boolean filterVertex(IVertex vertex) { - if (childrenFilters.isEmpty()) { - return true; - } - for (IGraphFilter filter : childrenFilters) { - if (filter.filterVertex(vertex)) { - return true; - } - } - return false; + for (IGraphFilter filter : childrenFilters) { + if (filter.filterEdge(edge)) { + return true; + } } + return false; + } - /** - * A oneDegreeGraph will be filtered if return is false. - */ - @Override - public boolean filterOneDegreeGraph(OneDegreeGraph oneDegreeGraph) { - if (childrenFilters.isEmpty()) { - return true; - } - for (IGraphFilter filter : childrenFilters) { - if (filter.filterOneDegreeGraph(oneDegreeGraph)) { - return true; - } - } - return false; + /** + * A Vertex will be filtered if return is false. + * + * @param vertex + */ + @Override + public boolean filterVertex(IVertex vertex) { + if (childrenFilters.isEmpty()) { + return true; } - - @Override - public boolean contains(FilterType type) { - return false; + for (IGraphFilter filter : childrenFilters) { + if (filter.filterVertex(vertex)) { + return true; + } } + return false; + } - @Override - public IGraphFilter retrieve(FilterType type) { - return null; + /** A oneDegreeGraph will be filtered if return is false. */ + @Override + public boolean filterOneDegreeGraph(OneDegreeGraph oneDegreeGraph) { + if (childrenFilters.isEmpty()) { + return true; } - - @Override - public OrGraphFilter clone() { - return new OrGraphFilter(cloneFilterList()); + for (IGraphFilter filter : childrenFilters) { + if (filter.filterOneDegreeGraph(oneDegreeGraph)) { + return true; + } } + return false; + } - @Override - public FilterType getFilterType() { - return FilterType.OR; - } + @Override + public boolean contains(FilterType type) { + return false; + } + + @Override + public IGraphFilter retrieve(FilterType type) { + return null; + } + + @Override + public OrGraphFilter clone() { + return new OrGraphFilter(cloneFilterList()); + } + + @Override + public FilterType getFilterType() { + return FilterType.OR; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/SingleLimitFilter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/SingleLimitFilter.java index 8e9617fb5..2f6c3583e 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/SingleLimitFilter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/filter/inner/SingleLimitFilter.java @@ -20,58 +20,59 @@ package org.apache.geaflow.state.pushdown.filter.inner; import java.util.List; + import org.apache.geaflow.model.graph.edge.EdgeDirection; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.state.pushdown.limit.IEdgeLimit; public class SingleLimitFilter extends LimitFilter { - private List filters; - private final long[] outCounters; - private final long[] inCounters; - private long needHitMaxVersion; - private long hitMaxVersion; + private List filters; + private final long[] outCounters; + private final long[] inCounters; + private long needHitMaxVersion; + private long hitMaxVersion; - SingleLimitFilter(IGraphFilter filter, IEdgeLimit limit) { - super(filter, limit); - this.filters = ((OrGraphFilter) filter).childrenFilters; - this.outCounters = new long[filters.size()]; - this.inCounters = new long[filters.size()]; - this.needHitMaxVersion = filters.size() * (inCounter + outCounter); - if (this.needHitMaxVersion < 0) { - this.needHitMaxVersion = Long.MAX_VALUE; - } + SingleLimitFilter(IGraphFilter filter, IEdgeLimit limit) { + super(filter, limit); + this.filters = ((OrGraphFilter) filter).childrenFilters; + this.outCounters = new long[filters.size()]; + this.inCounters = new long[filters.size()]; + this.needHitMaxVersion = filters.size() * (inCounter + outCounter); + if (this.needHitMaxVersion < 0) { + this.needHitMaxVersion = Long.MAX_VALUE; } + } - @Override - public boolean filterEdge(IEdge edge) { - boolean keep = false; - int i = 0; - for (IGraphFilter filter : filters) { - if (filter.filterEdge(edge)) { - if (edge.getDirect() == EdgeDirection.OUT && outCounters[i] < outCounter) { - outCounters[i]++; - hitMaxVersion++; - keep = true; - } - if (edge.getDirect() == EdgeDirection.IN && inCounters[i] < inCounter) { - inCounters[i]++; - hitMaxVersion++; - keep = true; - } - } - i++; + @Override + public boolean filterEdge(IEdge edge) { + boolean keep = false; + int i = 0; + for (IGraphFilter filter : filters) { + if (filter.filterEdge(edge)) { + if (edge.getDirect() == EdgeDirection.OUT && outCounters[i] < outCounter) { + outCounters[i]++; + hitMaxVersion++; + keep = true; + } + if (edge.getDirect() == EdgeDirection.IN && inCounters[i] < inCounter) { + inCounters[i]++; + hitMaxVersion++; + keep = true; } - return keep; + } + i++; } + return keep; + } - @Override - public boolean dropAllRemaining() { - return hitMaxVersion >= needHitMaxVersion; - } + @Override + public boolean dropAllRemaining() { + return hitMaxVersion >= needHitMaxVersion; + } - @Override - public String toString() { - return String.format("%s(%d, %d)", getClass().getSimpleName(), inCounter, outCounter); - } + @Override + public String toString() { + return String.format("%s(%d, %d)", getClass().getSimpleName(), inCounter, outCounter); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/CodeGenFilterConverter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/CodeGenFilterConverter.java index 5447a5a44..f65d98c0f 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/CodeGenFilterConverter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/CodeGenFilterConverter.java @@ -19,10 +19,6 @@ package org.apache.geaflow.state.pushdown.inner; -import com.google.common.base.Preconditions; -import com.google.common.cache.Cache; -import com.google.common.cache.CacheBuilder; -import com.google.protobuf.ProtocolStringList; import java.io.IOException; import java.io.InputStream; import java.nio.charset.Charset; @@ -33,6 +29,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Collectors; + import org.apache.commons.io.IOUtils; import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; @@ -49,294 +46,315 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import com.google.protobuf.ProtocolStringList; + public class CodeGenFilterConverter implements IFilterConverter { - private static final Logger LOGGER = LoggerFactory.getLogger(CodeGenFilterConverter.class); - private static final String CODE_GEN_PACKAGE = "org.apache.geaflow.state.pushdown.filter.inner."; - private static final String TEMPLATE_FILE_NAME = "Filter.template"; - private static final String FILTER_CLASS_HEADER = "GraphFilter_"; - private static final String TEMPLATE; - private static final AtomicLong COUNTER = new AtomicLong(0); - private static final int CACHE_SIZE = 1024; - private static final Cache FILTER_CACHE = - CacheBuilder.newBuilder().initialCapacity(CACHE_SIZE).build(); - - static { - try (InputStream is = CodeGenFilterConverter.class.getClassLoader().getResourceAsStream(TEMPLATE_FILE_NAME)) { - TEMPLATE = IOUtils.toString(is, Charset.defaultCharset()); - } catch (IOException e) { - throw new RuntimeException(e); - } + private static final Logger LOGGER = LoggerFactory.getLogger(CodeGenFilterConverter.class); + private static final String CODE_GEN_PACKAGE = "org.apache.geaflow.state.pushdown.filter.inner."; + private static final String TEMPLATE_FILE_NAME = "Filter.template"; + private static final String FILTER_CLASS_HEADER = "GraphFilter_"; + private static final String TEMPLATE; + private static final AtomicLong COUNTER = new AtomicLong(0); + private static final int CACHE_SIZE = 1024; + private static final Cache FILTER_CACHE = + CacheBuilder.newBuilder().initialCapacity(CACHE_SIZE).build(); + + static { + try (InputStream is = + CodeGenFilterConverter.class.getClassLoader().getResourceAsStream(TEMPLATE_FILE_NAME)) { + TEMPLATE = IOUtils.toString(is, Charset.defaultCharset()); + } catch (IOException e) { + throw new RuntimeException(e); } + } - @Override - public IFilter convert(IFilter origin) { - if (origin.getFilterType() == FilterType.EMPTY) { - return EmptyGraphFilter.of(); - } - - if (origin.getFilterType() == FilterType.OR && ((OrFilter) origin).isSingleLimit()) { - List list = ((OrFilter) origin).getFilters().stream() - .map(f -> (IGraphFilter) innerConvert(f)).collect(Collectors.toList()); - return new OrGraphFilter(list); - } - return innerConvert(origin); + @Override + public IFilter convert(IFilter origin) { + if (origin.getFilterType() == FilterType.EMPTY) { + return EmptyGraphFilter.of(); } - private IFilter innerConvert(IFilter origin) { - try { - FilterPlanWithData planWithData = FilterGenerator.getFilterPlanWithData(origin); - IGraphFilter filter = FILTER_CACHE.getIfPresent(planWithData.plan); - if (filter == null) { - filter = (IGraphFilter) convert(planWithData.plan); - FILTER_CACHE.put(planWithData.plan, filter); - } - IGraphFilter genFilter = filter.clone(); - VariableContext varContext = new VariableContext(); - variableGen(planWithData.data, varContext); - ((GeneratedQueryFilter) genFilter).initVariables(varContext.variables.toArray(new Object[0])); - return genFilter; - } catch (Exception ex) { - LOGGER.warn("code gen fail {}, return origin", ex.getMessage()); - return GraphFilter.of(origin); - } + if (origin.getFilterType() == FilterType.OR && ((OrFilter) origin).isSingleLimit()) { + List list = + ((OrFilter) origin) + .getFilters().stream() + .map(f -> (IGraphFilter) innerConvert(f)) + .collect(Collectors.toList()); + return new OrGraphFilter(list); } - - @Override - public IFilter convert(FilterNode filterNode) { - String className = FILTER_CLASS_HEADER + COUNTER.getAndIncrement(); - String src = codeGen(className, filterNode); - try { - SimpleCompiler compiler = new SimpleCompiler(); - compiler.cook(src); - Class aClass = compiler.getClassLoader().loadClass(CODE_GEN_PACKAGE + className); - return (IGraphFilter) aClass.newInstance(); - } catch (Exception e) { - LOGGER.error("code gen compile fail\n{}", src); - throw new RuntimeException(e); - } + return innerConvert(origin); + } + + private IFilter innerConvert(IFilter origin) { + try { + FilterPlanWithData planWithData = FilterGenerator.getFilterPlanWithData(origin); + IGraphFilter filter = FILTER_CACHE.getIfPresent(planWithData.plan); + if (filter == null) { + filter = (IGraphFilter) convert(planWithData.plan); + FILTER_CACHE.put(planWithData.plan, filter); + } + IGraphFilter genFilter = filter.clone(); + VariableContext varContext = new VariableContext(); + variableGen(planWithData.data, varContext); + ((GeneratedQueryFilter) genFilter).initVariables(varContext.variables.toArray(new Object[0])); + return genFilter; + } catch (Exception ex) { + LOGGER.warn("code gen fail {}, return origin", ex.getMessage()); + return GraphFilter.of(origin); } - - private String codeGen(String className, FilterNode filterNode) { - CodeGenContext context = new CodeGenContext(new AtomicInteger(0)); - innerCodeGen(filterNode, context); - return context.getCode(className); + } + + @Override + public IFilter convert(FilterNode filterNode) { + String className = FILTER_CLASS_HEADER + COUNTER.getAndIncrement(); + String src = codeGen(className, filterNode); + try { + SimpleCompiler compiler = new SimpleCompiler(); + compiler.cook(src); + Class aClass = compiler.getClassLoader().loadClass(CODE_GEN_PACKAGE + className); + return (IGraphFilter) aClass.newInstance(); + } catch (Exception e) { + LOGGER.error("code gen compile fail\n{}", src); + throw new RuntimeException(e); } - - private static void variableGen(FilterNode filterNode, VariableContext context) { - if (filterNode.getFiltersCount() == 0) { - switch (filterNode.getContentCase()) { - case INT_CONTENT: - context.variables.addAll(filterNode.getIntContent().getIntList()); - break; - case BYTES_CONTENT: - context.variables.addAll(filterNode.getBytesContent().getBytesList()); - break; - case LONG_CONTENT: - context.variables.addAll(filterNode.getLongContent().getLongList()); - break; - case STR_CONTENT: - ProtocolStringList list = filterNode.getStrContent().getStrList(); - PushDownPb.FilterType type = filterNode.getFilterType(); - if (type == PushDownPb.FilterType.VERTEX_LABEL || type == PushDownPb.FilterType.EDGE_LABEL) { - context.variables.add(new HashSet<>(list)); - } else { - context.variables.add(list); - } - break; - default: - } - } else { - List filterNodes = filterNode.getFiltersList(); - for (FilterNode inNode : filterNodes) { - variableGen(inNode, context); - } - } + } + + private String codeGen(String className, FilterNode filterNode) { + CodeGenContext context = new CodeGenContext(new AtomicInteger(0)); + innerCodeGen(filterNode, context); + return context.getCode(className); + } + + private static void variableGen(FilterNode filterNode, VariableContext context) { + if (filterNode.getFiltersCount() == 0) { + switch (filterNode.getContentCase()) { + case INT_CONTENT: + context.variables.addAll(filterNode.getIntContent().getIntList()); + break; + case BYTES_CONTENT: + context.variables.addAll(filterNode.getBytesContent().getBytesList()); + break; + case LONG_CONTENT: + context.variables.addAll(filterNode.getLongContent().getLongList()); + break; + case STR_CONTENT: + ProtocolStringList list = filterNode.getStrContent().getStrList(); + PushDownPb.FilterType type = filterNode.getFilterType(); + if (type == PushDownPb.FilterType.VERTEX_LABEL + || type == PushDownPb.FilterType.EDGE_LABEL) { + context.variables.add(new HashSet<>(list)); + } else { + context.variables.add(list); + } + break; + default: + } + } else { + List filterNodes = filterNode.getFiltersList(); + for (FilterNode inNode : filterNodes) { + variableGen(inNode, context); + } } - - private static void innerCodeGen(FilterNode filterNode, CodeGenContext context) { - PushDownPb.FilterType type = filterNode.getFilterType(); - switch (type) { - case AND: - CodeGenContext inContext = new CodeGenContext(context.varIdx); - for (FilterNode node : filterNode.getFiltersList()) { - innerCodeGen(node, inContext); - } - inContext.doAnd(); - context.merge(inContext); - break; - case OR: - inContext = new CodeGenContext(context.varIdx); - for (FilterNode node : filterNode.getFiltersList()) { - innerCodeGen(node, inContext); - } - inContext.doOr(); - context.merge(inContext); - break; - case VERTEX_LABEL: - context.addVertexPreCompute(type); - context.addVertexFormula(String.format("((Set)var[%d]).contains(label)", - context.varIdx.getAndIncrement())); - break; - case EDGE_LABEL: - context.addEdgePreCompute(type); - context.addEdgeFormula(String.format("((Set)var[%d]).contains(label)", - context.varIdx.getAndIncrement())); - break; - case VERTEX_TS: - context.addVertexPreCompute(type); - int start = context.varIdx.getAndIncrement(); - int end = context.varIdx.getAndIncrement(); - context.addVertexFormula(String.format("ts >= (Long)var[%d] && ts < (Long)var[%d]", start, end)); - break; - case EDGE_TS: - context.addEdgePreCompute(type); - start = context.varIdx.getAndIncrement(); - end = context.varIdx.getAndIncrement(); - context.addEdgeFormula(String.format("ts >= (Long)var[%d] && ts < (Long)var[%d]", start, end)); - break; - case IN_EDGE: - context.addEdgeFormula("edge.getDirect() == EdgeDirection.IN"); - break; - case OUT_EDGE: - context.addEdgeFormula("edge.getDirect() == EdgeDirection.OUT"); - break; - case VERTEX_VALUE_DROP: - context.addVertexPreCompute(type); - break; - case EDGE_VALUE_DROP: - context.addEdgePreCompute(type); - break; - case VERTEX_MUST_CONTAIN: - context.addOneDegreeFormula("oneDegreeGraph.getVertex() != null"); - break; - default: - throw new GeaflowRuntimeException("not find type" + type); + } + + private static void innerCodeGen(FilterNode filterNode, CodeGenContext context) { + PushDownPb.FilterType type = filterNode.getFilterType(); + switch (type) { + case AND: + CodeGenContext inContext = new CodeGenContext(context.varIdx); + for (FilterNode node : filterNode.getFiltersList()) { + innerCodeGen(node, inContext); } + inContext.doAnd(); + context.merge(inContext); + break; + case OR: + inContext = new CodeGenContext(context.varIdx); + for (FilterNode node : filterNode.getFiltersList()) { + innerCodeGen(node, inContext); + } + inContext.doOr(); + context.merge(inContext); + break; + case VERTEX_LABEL: + context.addVertexPreCompute(type); + context.addVertexFormula( + String.format( + "((Set)var[%d]).contains(label)", context.varIdx.getAndIncrement())); + break; + case EDGE_LABEL: + context.addEdgePreCompute(type); + context.addEdgeFormula( + String.format( + "((Set)var[%d]).contains(label)", context.varIdx.getAndIncrement())); + break; + case VERTEX_TS: + context.addVertexPreCompute(type); + int start = context.varIdx.getAndIncrement(); + int end = context.varIdx.getAndIncrement(); + context.addVertexFormula( + String.format("ts >= (Long)var[%d] && ts < (Long)var[%d]", start, end)); + break; + case EDGE_TS: + context.addEdgePreCompute(type); + start = context.varIdx.getAndIncrement(); + end = context.varIdx.getAndIncrement(); + context.addEdgeFormula( + String.format("ts >= (Long)var[%d] && ts < (Long)var[%d]", start, end)); + break; + case IN_EDGE: + context.addEdgeFormula("edge.getDirect() == EdgeDirection.IN"); + break; + case OUT_EDGE: + context.addEdgeFormula("edge.getDirect() == EdgeDirection.OUT"); + break; + case VERTEX_VALUE_DROP: + context.addVertexPreCompute(type); + break; + case EDGE_VALUE_DROP: + context.addEdgePreCompute(type); + break; + case VERTEX_MUST_CONTAIN: + context.addOneDegreeFormula("oneDegreeGraph.getVertex() != null"); + break; + default: + throw new GeaflowRuntimeException("not find type" + type); } + } - public static class CodeGenContext { + public static class CodeGenContext { - private Set edgePreFields = new HashSet<>(); - private Set vertexPreFields = new HashSet<>(); - private List edgeFormulas = new ArrayList<>(); - private List vertexFormulas = new ArrayList<>(); - private List oneDegreeFormulas = new ArrayList<>(); - private AtomicInteger varIdx; + private Set edgePreFields = new HashSet<>(); + private Set vertexPreFields = new HashSet<>(); + private List edgeFormulas = new ArrayList<>(); + private List vertexFormulas = new ArrayList<>(); + private List oneDegreeFormulas = new ArrayList<>(); + private AtomicInteger varIdx; - public CodeGenContext(AtomicInteger varIdx) { - this.varIdx = varIdx; - } + public CodeGenContext(AtomicInteger varIdx) { + this.varIdx = varIdx; + } - private String getPreComputeCode(PushDownPb.FilterType type) { - switch (type) { - case VERTEX_TS: - return "long ts = ((IGraphElementWithTimeField)vertex).getTime();"; - case EDGE_TS: - return "long ts = ((IGraphElementWithTimeField)edge).getTime();"; - case VERTEX_LABEL: - return "String label = ((IGraphElementWithLabelField)vertex).getLabel();"; - case EDGE_LABEL: - return "String label = ((IGraphElementWithLabelField)edge).getLabel();"; - case VERTEX_VALUE_DROP: - return "vertex.withValue(null);"; - case EDGE_VALUE_DROP: - return "edge.withValue(null)"; - default: - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } - } + private String getPreComputeCode(PushDownPb.FilterType type) { + switch (type) { + case VERTEX_TS: + return "long ts = ((IGraphElementWithTimeField)vertex).getTime();"; + case EDGE_TS: + return "long ts = ((IGraphElementWithTimeField)edge).getTime();"; + case VERTEX_LABEL: + return "String label = ((IGraphElementWithLabelField)vertex).getLabel();"; + case EDGE_LABEL: + return "String label = ((IGraphElementWithLabelField)edge).getLabel();"; + case VERTEX_VALUE_DROP: + return "vertex.withValue(null);"; + case EDGE_VALUE_DROP: + return "edge.withValue(null)"; + default: + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + } - public void addEdgeFormula(String code) { - edgeFormulas.add(code); - } + public void addEdgeFormula(String code) { + edgeFormulas.add(code); + } - public void addVertexFormula(String code) { - vertexFormulas.add(code); - } + public void addVertexFormula(String code) { + vertexFormulas.add(code); + } - public void addEdgePreCompute(PushDownPb.FilterType type) { - edgePreFields.add(type); - } + public void addEdgePreCompute(PushDownPb.FilterType type) { + edgePreFields.add(type); + } - public void addVertexPreCompute(PushDownPb.FilterType type) { - vertexPreFields.add(type); - } + public void addVertexPreCompute(PushDownPb.FilterType type) { + vertexPreFields.add(type); + } - public void doAnd() { - doMerge(vertexFormulas, " && "); - doMerge(edgeFormulas, " && "); - doMerge(oneDegreeFormulas, " && "); - } + public void doAnd() { + doMerge(vertexFormulas, " && "); + doMerge(edgeFormulas, " && "); + doMerge(oneDegreeFormulas, " && "); + } - public void doOr() { - doMerge(vertexFormulas, " || "); - doMerge(edgeFormulas, " || "); - doMerge(oneDegreeFormulas, " || "); - } + public void doOr() { + doMerge(vertexFormulas, " || "); + doMerge(edgeFormulas, " || "); + doMerge(oneDegreeFormulas, " || "); + } - private void doMerge(List formulas, String logic) { - if (formulas.size() <= 1) { - return; - } - StringBuilder mergedFormula = new StringBuilder(); - for (String formula : formulas) { - mergedFormula.append("(").append(formula).append(")").append(logic); - } - formulas.clear(); - formulas.add(mergedFormula.substring(0, mergedFormula.length() - logic.length())); - } + private void doMerge(List formulas, String logic) { + if (formulas.size() <= 1) { + return; + } + StringBuilder mergedFormula = new StringBuilder(); + for (String formula : formulas) { + mergedFormula.append("(").append(formula).append(")").append(logic); + } + formulas.clear(); + formulas.add(mergedFormula.substring(0, mergedFormula.length() - logic.length())); + } - public String getCode(String className) { - Preconditions.checkArgument(vertexFormulas.size() <= 1); - Preconditions.checkArgument(edgeFormulas.size() <= 1); - - String vertexCode = Boolean.TRUE.toString(); - String vertexPreCompute = ""; - if (vertexFormulas.size() == 1) { - StringBuilder preCompute = new StringBuilder(); - for (PushDownPb.FilterType filterType : vertexPreFields) { - preCompute.append(getPreComputeCode(filterType)).append("\n"); - } - vertexPreCompute = preCompute.toString(); - vertexCode = vertexFormulas.get(0); - } - String edgeCode = Boolean.TRUE.toString(); - String edgePreCompute = ""; - if (edgeFormulas.size() == 1) { - StringBuilder preCompute = new StringBuilder(); - for (PushDownPb.FilterType filterType : edgePreFields) { - preCompute.append(getPreComputeCode(filterType)).append("\n"); - } - edgePreCompute = preCompute.toString(); - edgeCode = edgeFormulas.get(0); - } - - String oneDegreeCode = Boolean.TRUE.toString(); - if (oneDegreeFormulas.size() == 1) { - oneDegreeCode = edgeFormulas.get(0); - } - - return String.format(TEMPLATE, className, className, vertexPreCompute, vertexCode, edgePreCompute, edgeCode, oneDegreeCode); + public String getCode(String className) { + Preconditions.checkArgument(vertexFormulas.size() <= 1); + Preconditions.checkArgument(edgeFormulas.size() <= 1); + String vertexCode = Boolean.TRUE.toString(); + String vertexPreCompute = ""; + if (vertexFormulas.size() == 1) { + StringBuilder preCompute = new StringBuilder(); + for (PushDownPb.FilterType filterType : vertexPreFields) { + preCompute.append(getPreComputeCode(filterType)).append("\n"); } - - public void merge(CodeGenContext inContext) { - vertexPreFields.addAll(inContext.vertexPreFields); - edgePreFields.addAll(inContext.edgePreFields); - - Preconditions.checkArgument(inContext.vertexFormulas.size() <= 1); - Preconditions.checkArgument(inContext.edgeFormulas.size() <= 1); - vertexFormulas.addAll(inContext.vertexFormulas); - edgeFormulas.addAll(inContext.edgeFormulas); + vertexPreCompute = preCompute.toString(); + vertexCode = vertexFormulas.get(0); + } + String edgeCode = Boolean.TRUE.toString(); + String edgePreCompute = ""; + if (edgeFormulas.size() == 1) { + StringBuilder preCompute = new StringBuilder(); + for (PushDownPb.FilterType filterType : edgePreFields) { + preCompute.append(getPreComputeCode(filterType)).append("\n"); } + edgePreCompute = preCompute.toString(); + edgeCode = edgeFormulas.get(0); + } + + String oneDegreeCode = Boolean.TRUE.toString(); + if (oneDegreeFormulas.size() == 1) { + oneDegreeCode = edgeFormulas.get(0); + } + + return String.format( + TEMPLATE, + className, + className, + vertexPreCompute, + vertexCode, + edgePreCompute, + edgeCode, + oneDegreeCode); + } - public void addOneDegreeFormula(String formula) { - oneDegreeFormulas.add(formula); - } + public void merge(CodeGenContext inContext) { + vertexPreFields.addAll(inContext.vertexPreFields); + edgePreFields.addAll(inContext.edgePreFields); + + Preconditions.checkArgument(inContext.vertexFormulas.size() <= 1); + Preconditions.checkArgument(inContext.edgeFormulas.size() <= 1); + vertexFormulas.addAll(inContext.vertexFormulas); + edgeFormulas.addAll(inContext.edgeFormulas); } - public static class VariableContext { - private List variables = new ArrayList<>(); + public void addOneDegreeFormula(String formula) { + oneDegreeFormulas.add(formula); } + } + + public static class VariableContext { + private List variables = new ArrayList<>(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/DirectFilterConverter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/DirectFilterConverter.java index 9c50276ee..007176d71 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/DirectFilterConverter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/DirectFilterConverter.java @@ -26,13 +26,13 @@ public class DirectFilterConverter implements IFilterConverter { - @Override - public IFilter convert(IFilter origin) { - return GraphFilter.of(origin); - } + @Override + public IFilter convert(IFilter origin) { + return GraphFilter.of(origin); + } - @Override - public IFilter convert(FilterNode filterNode) { - throw new GeaflowRuntimeException("not support"); - } + @Override + public IFilter convert(FilterNode filterNode) { + throw new GeaflowRuntimeException("not support"); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/FilterGenerator.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/FilterGenerator.java index 90123cd47..37d0c72d3 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/FilterGenerator.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/FilterGenerator.java @@ -24,6 +24,7 @@ import java.util.Comparator; import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.state.data.TimeRange; import org.apache.geaflow.state.pushdown.filter.AndFilter; @@ -41,109 +42,111 @@ public class FilterGenerator { - public static FilterPlanWithData getFilterPlanWithData(IFilter filter) { - FilterNode plan = getFilterPlan(filter); - FilterNode data = getFilterData(filter); - return new FilterPlanWithData(plan, data); - } + public static FilterPlanWithData getFilterPlanWithData(IFilter filter) { + FilterNode plan = getFilterPlan(filter); + FilterNode data = getFilterData(filter); + return new FilterPlanWithData(plan, data); + } - private static FilterNode getFilterPlan(IFilter filter) { - FilterType filterType = filter.getFilterType(); - switch (filterType) { - case AND: - return getLogicalPlan(filterType, ((AndFilter) filter).getFilters()); - case OR: - return getLogicalPlan(filterType, ((OrFilter) filter).getFilters()); - case VERTEX_LABEL: - case EDGE_LABEL: - case VERTEX_TS: - case EDGE_TS: - case IN_EDGE: - case OUT_EDGE: - case VERTEX_VALUE_DROP: - case EDGE_VALUE_DROP: - case VERTEX_MUST_CONTAIN: - return getNormalPlan(filterType); - default: - throw new GeaflowRuntimeException("not support user defined filter " + filter.getFilterType()); - } + private static FilterNode getFilterPlan(IFilter filter) { + FilterType filterType = filter.getFilterType(); + switch (filterType) { + case AND: + return getLogicalPlan(filterType, ((AndFilter) filter).getFilters()); + case OR: + return getLogicalPlan(filterType, ((OrFilter) filter).getFilters()); + case VERTEX_LABEL: + case EDGE_LABEL: + case VERTEX_TS: + case EDGE_TS: + case IN_EDGE: + case OUT_EDGE: + case VERTEX_VALUE_DROP: + case EDGE_VALUE_DROP: + case VERTEX_MUST_CONTAIN: + return getNormalPlan(filterType); + default: + throw new GeaflowRuntimeException( + "not support user defined filter " + filter.getFilterType()); } + } - private static FilterNode getLogicalPlan(FilterType filterType, List filters) { - return FilterNode.newBuilder() - .setFilterType(filterType.toPbFilterType()) - .addAllFilters(filters.stream() + private static FilterNode getLogicalPlan(FilterType filterType, List filters) { + return FilterNode.newBuilder() + .setFilterType(filterType.toPbFilterType()) + .addAllFilters( + filters.stream() .sorted(Comparator.comparingInt(o -> o.getFilterType().ordinal())) - .map(FilterGenerator::getFilterPlan).collect(Collectors.toList())) - .build(); - } + .map(FilterGenerator::getFilterPlan) + .collect(Collectors.toList())) + .build(); + } - private static FilterNode getNormalPlan(FilterType filterType) { - return FilterNode.newBuilder() - .setFilterType(filterType.toPbFilterType()) - .build(); - } + private static FilterNode getNormalPlan(FilterType filterType) { + return FilterNode.newBuilder().setFilterType(filterType.toPbFilterType()).build(); + } - public static FilterNode getFilterData(IFilter filter) { - FilterType filterType = filter.getFilterType(); - switch (filterType) { - case AND: - return getLogicalFilterData(filterType, ((AndFilter) filter).getFilters()); - case OR: - return getLogicalFilterData(filterType, ((OrFilter) filter).getFilters()); - case VERTEX_LABEL: - return getStringFilterData(filterType, ((VertexLabelFilter) filter).getLabels()); - case EDGE_LABEL: - return getStringFilterData(filterType, ((EdgeLabelFilter) filter).getLabels()); - case VERTEX_TS: - TimeRange range = ((VertexTsFilter) filter).getTimeRange(); - Collection longs = Arrays.asList(range.getStart(), range.getEnd()); - return getLongFilterData(filterType, longs); - case EDGE_TS: - range = ((EdgeTsFilter) filter).getTimeRange(); - longs = Arrays.asList(range.getStart(), range.getEnd()); - return getLongFilterData(filterType, longs); - case IN_EDGE: - case OUT_EDGE: - case VERTEX_VALUE_DROP: - case EDGE_VALUE_DROP: - case VERTEX_MUST_CONTAIN: - case EMPTY: - return FilterNode.newBuilder() - .setFilterType(filterType.toPbFilterType()) - .build(); - default: - throw new GeaflowRuntimeException("not support user defined filter " + filter.getFilterType()); - } + public static FilterNode getFilterData(IFilter filter) { + FilterType filterType = filter.getFilterType(); + switch (filterType) { + case AND: + return getLogicalFilterData(filterType, ((AndFilter) filter).getFilters()); + case OR: + return getLogicalFilterData(filterType, ((OrFilter) filter).getFilters()); + case VERTEX_LABEL: + return getStringFilterData(filterType, ((VertexLabelFilter) filter).getLabels()); + case EDGE_LABEL: + return getStringFilterData(filterType, ((EdgeLabelFilter) filter).getLabels()); + case VERTEX_TS: + TimeRange range = ((VertexTsFilter) filter).getTimeRange(); + Collection longs = Arrays.asList(range.getStart(), range.getEnd()); + return getLongFilterData(filterType, longs); + case EDGE_TS: + range = ((EdgeTsFilter) filter).getTimeRange(); + longs = Arrays.asList(range.getStart(), range.getEnd()); + return getLongFilterData(filterType, longs); + case IN_EDGE: + case OUT_EDGE: + case VERTEX_VALUE_DROP: + case EDGE_VALUE_DROP: + case VERTEX_MUST_CONTAIN: + case EMPTY: + return FilterNode.newBuilder().setFilterType(filterType.toPbFilterType()).build(); + default: + throw new GeaflowRuntimeException( + "not support user defined filter " + filter.getFilterType()); } + } - private static FilterNode getLogicalFilterData(FilterType filterType, List filters) { - return FilterNode.newBuilder() - .setFilterType(filterType.toPbFilterType()) - .addAllFilters(filters.stream() + private static FilterNode getLogicalFilterData(FilterType filterType, List filters) { + return FilterNode.newBuilder() + .setFilterType(filterType.toPbFilterType()) + .addAllFilters( + filters.stream() .sorted(Comparator.comparingInt(o -> o.getFilterType().ordinal())) - .map(FilterGenerator::getFilterData).collect(Collectors.toList())) - .build(); - } + .map(FilterGenerator::getFilterData) + .collect(Collectors.toList())) + .build(); + } - private static FilterNode getStringFilterData(FilterType filterType, Collection strs) { - return FilterNode.newBuilder() - .setFilterType(filterType.toPbFilterType()) - .setStrContent(StringList.newBuilder().addAllStr(strs).build()) - .build(); - } + private static FilterNode getStringFilterData(FilterType filterType, Collection strs) { + return FilterNode.newBuilder() + .setFilterType(filterType.toPbFilterType()) + .setStrContent(StringList.newBuilder().addAllStr(strs).build()) + .build(); + } - private static FilterNode getLongFilterData(FilterType filterType, Collection longs) { - return FilterNode.newBuilder() - .setFilterType(filterType.toPbFilterType()) - .setLongContent(LongList.newBuilder().addAllLong(longs).build()) - .build(); - } + private static FilterNode getLongFilterData(FilterType filterType, Collection longs) { + return FilterNode.newBuilder() + .setFilterType(filterType.toPbFilterType()) + .setLongContent(LongList.newBuilder().addAllLong(longs).build()) + .build(); + } - private static FilterNode getIntFilterData(FilterType filterType, Collection ints) { - return FilterNode.newBuilder() - .setFilterType(filterType.toPbFilterType()) - .setIntContent(IntList.newBuilder().addAllInt(ints).build()) - .build(); - } + private static FilterNode getIntFilterData(FilterType filterType, Collection ints) { + return FilterNode.newBuilder() + .setFilterType(filterType.toPbFilterType()) + .setIntContent(IntList.newBuilder().addAllInt(ints).build()) + .build(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/FilterPlanWithData.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/FilterPlanWithData.java index 82faf2066..8a612a2c0 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/FilterPlanWithData.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/FilterPlanWithData.java @@ -23,11 +23,11 @@ public class FilterPlanWithData { - public FilterNode plan; - public FilterNode data; + public FilterNode plan; + public FilterNode data; - public FilterPlanWithData(FilterNode plan, FilterNode data) { - this.plan = plan; - this.data = data; - } + public FilterPlanWithData(FilterNode plan, FilterNode data) { + this.plan = plan; + this.data = data; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/IFilterConverter.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/IFilterConverter.java index a3d8f1827..48e18365e 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/IFilterConverter.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/IFilterConverter.java @@ -23,18 +23,14 @@ import org.apache.geaflow.state.pushdown.inner.PushDownPb.FilterNode; /** - * The filter converter is used to convert the user filter - * to the filter recognized or optimized by the underground store. + * The filter converter is used to convert the user filter to the filter recognized or optimized by + * the underground store. */ public interface IFilterConverter { - /** - * Returns the converted filter from the original filter. - */ - IFilter convert(IFilter origin); + /** Returns the converted filter from the original filter. */ + IFilter convert(IFilter origin); - /** - * Returns the converted filter from the protobuf formatted filter. - */ - IFilter convert(FilterNode filterNode); + /** Returns the converted filter from the protobuf formatted filter. */ + IFilter convert(FilterNode filterNode); } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/PushDownPb.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/PushDownPb.java index 14df7425c..4f991e489 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/PushDownPb.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/PushDownPb.java @@ -22,8204 +22,7303 @@ package org.apache.geaflow.state.pushdown.inner; public final class PushDownPb { - private PushDownPb() { + private PushDownPb() {} + + public static void registerAllExtensions(com.google.protobuf.ExtensionRegistryLite registry) {} + + public static void registerAllExtensions(com.google.protobuf.ExtensionRegistry registry) { + registerAllExtensions((com.google.protobuf.ExtensionRegistryLite) registry); + } + + /** Protobuf enum {@code FilterType} */ + public enum FilterType implements com.google.protobuf.ProtocolMessageEnum { + /** EMPTY = 0; */ + EMPTY(0), + /** ONLY_VERTEX = 1; */ + ONLY_VERTEX(1), + /** IN_EDGE = 2; */ + IN_EDGE(2), + /** OUT_EDGE = 3; */ + OUT_EDGE(3), + /** VERTEX_TS = 4; */ + VERTEX_TS(4), + /** EDGE_TS = 5; */ + EDGE_TS(5), + /** MULTI_EDGE_TS = 6; */ + MULTI_EDGE_TS(6), + /** VERTEX_LABEL = 7; */ + VERTEX_LABEL(7), + /** EDGE_LABEL = 8; */ + EDGE_LABEL(8), + /** VERTEX_VALUE_DROP = 9; */ + VERTEX_VALUE_DROP(9), + /** EDGE_VALUE_DROP = 10; */ + EDGE_VALUE_DROP(10), + /** TTL = 11; */ + TTL(11), + /** AND = 12; */ + AND(12), + /** OR = 13; */ + OR(13), + /** VERTEX_MUST_CONTAIN = 14; */ + VERTEX_MUST_CONTAIN(14), + /** GENERATED = 15; */ + GENERATED(15), + /** OTHER = 16; */ + OTHER(16), + UNRECOGNIZED(-1), + ; + + /** EMPTY = 0; */ + public static final int EMPTY_VALUE = 0; + + /** ONLY_VERTEX = 1; */ + public static final int ONLY_VERTEX_VALUE = 1; + + /** IN_EDGE = 2; */ + public static final int IN_EDGE_VALUE = 2; + + /** OUT_EDGE = 3; */ + public static final int OUT_EDGE_VALUE = 3; + + /** VERTEX_TS = 4; */ + public static final int VERTEX_TS_VALUE = 4; + + /** EDGE_TS = 5; */ + public static final int EDGE_TS_VALUE = 5; + + /** MULTI_EDGE_TS = 6; */ + public static final int MULTI_EDGE_TS_VALUE = 6; + + /** VERTEX_LABEL = 7; */ + public static final int VERTEX_LABEL_VALUE = 7; + + /** EDGE_LABEL = 8; */ + public static final int EDGE_LABEL_VALUE = 8; + + /** VERTEX_VALUE_DROP = 9; */ + public static final int VERTEX_VALUE_DROP_VALUE = 9; + + /** EDGE_VALUE_DROP = 10; */ + public static final int EDGE_VALUE_DROP_VALUE = 10; + + /** TTL = 11; */ + public static final int TTL_VALUE = 11; + + /** AND = 12; */ + public static final int AND_VALUE = 12; + + /** OR = 13; */ + public static final int OR_VALUE = 13; + + /** VERTEX_MUST_CONTAIN = 14; */ + public static final int VERTEX_MUST_CONTAIN_VALUE = 14; + + /** GENERATED = 15; */ + public static final int GENERATED_VALUE = 15; + + /** OTHER = 16; */ + public static final int OTHER_VALUE = 16; + + public final int getNumber() { + if (this == UNRECOGNIZED) { + throw new java.lang.IllegalArgumentException( + "Can't get the number of an unknown enum value."); + } + return value; } - public static void registerAllExtensions( - com.google.protobuf.ExtensionRegistryLite registry) { + /** + * @deprecated Use {@link #forNumber(int)} instead. + */ + @java.lang.Deprecated + public static FilterType valueOf(int value) { + return forNumber(value); } - public static void registerAllExtensions( - com.google.protobuf.ExtensionRegistry registry) { - registerAllExtensions( - (com.google.protobuf.ExtensionRegistryLite) registry); + public static FilterType forNumber(int value) { + switch (value) { + case 0: + return EMPTY; + case 1: + return ONLY_VERTEX; + case 2: + return IN_EDGE; + case 3: + return OUT_EDGE; + case 4: + return VERTEX_TS; + case 5: + return EDGE_TS; + case 6: + return MULTI_EDGE_TS; + case 7: + return VERTEX_LABEL; + case 8: + return EDGE_LABEL; + case 9: + return VERTEX_VALUE_DROP; + case 10: + return EDGE_VALUE_DROP; + case 11: + return TTL; + case 12: + return AND; + case 13: + return OR; + case 14: + return VERTEX_MUST_CONTAIN; + case 15: + return GENERATED; + case 16: + return OTHER; + default: + return null; + } } - /** - * Protobuf enum {@code FilterType} - */ - public enum FilterType - implements com.google.protobuf.ProtocolMessageEnum { - /** - * EMPTY = 0; - */ - EMPTY(0), - /** - * ONLY_VERTEX = 1; - */ - ONLY_VERTEX(1), - /** - * IN_EDGE = 2; - */ - IN_EDGE(2), - /** - * OUT_EDGE = 3; - */ - OUT_EDGE(3), - /** - * VERTEX_TS = 4; - */ - VERTEX_TS(4), - /** - * EDGE_TS = 5; - */ - EDGE_TS(5), - /** - * MULTI_EDGE_TS = 6; - */ - MULTI_EDGE_TS(6), - /** - * VERTEX_LABEL = 7; - */ - VERTEX_LABEL(7), - /** - * EDGE_LABEL = 8; - */ - EDGE_LABEL(8), - /** - * VERTEX_VALUE_DROP = 9; - */ - VERTEX_VALUE_DROP(9), - /** - * EDGE_VALUE_DROP = 10; - */ - EDGE_VALUE_DROP(10), - /** - * TTL = 11; - */ - TTL(11), - /** - * AND = 12; - */ - AND(12), - /** - * OR = 13; - */ - OR(13), - /** - * VERTEX_MUST_CONTAIN = 14; - */ - VERTEX_MUST_CONTAIN(14), - /** - * GENERATED = 15; - */ - GENERATED(15), - /** - * OTHER = 16; - */ - OTHER(16), - UNRECOGNIZED(-1), - ; - - /** - * EMPTY = 0; - */ - public static final int EMPTY_VALUE = 0; - /** - * ONLY_VERTEX = 1; - */ - public static final int ONLY_VERTEX_VALUE = 1; - /** - * IN_EDGE = 2; - */ - public static final int IN_EDGE_VALUE = 2; - /** - * OUT_EDGE = 3; - */ - public static final int OUT_EDGE_VALUE = 3; - /** - * VERTEX_TS = 4; - */ - public static final int VERTEX_TS_VALUE = 4; - /** - * EDGE_TS = 5; - */ - public static final int EDGE_TS_VALUE = 5; - /** - * MULTI_EDGE_TS = 6; - */ - public static final int MULTI_EDGE_TS_VALUE = 6; - /** - * VERTEX_LABEL = 7; - */ - public static final int VERTEX_LABEL_VALUE = 7; - /** - * EDGE_LABEL = 8; - */ - public static final int EDGE_LABEL_VALUE = 8; - /** - * VERTEX_VALUE_DROP = 9; - */ - public static final int VERTEX_VALUE_DROP_VALUE = 9; - /** - * EDGE_VALUE_DROP = 10; - */ - public static final int EDGE_VALUE_DROP_VALUE = 10; - /** - * TTL = 11; - */ - public static final int TTL_VALUE = 11; - /** - * AND = 12; - */ - public static final int AND_VALUE = 12; - /** - * OR = 13; - */ - public static final int OR_VALUE = 13; - /** - * VERTEX_MUST_CONTAIN = 14; - */ - public static final int VERTEX_MUST_CONTAIN_VALUE = 14; - /** - * GENERATED = 15; - */ - public static final int GENERATED_VALUE = 15; - /** - * OTHER = 16; - */ - public static final int OTHER_VALUE = 16; - - - public final int getNumber() { - if (this == UNRECOGNIZED) { - throw new java.lang.IllegalArgumentException( - "Can't get the number of an unknown enum value."); - } - return value; - } - - /** - * @deprecated Use {@link #forNumber(int)} instead. - */ - @java.lang.Deprecated - public static FilterType valueOf(int value) { - return forNumber(value); - } - - public static FilterType forNumber(int value) { - switch (value) { - case 0: - return EMPTY; - case 1: - return ONLY_VERTEX; - case 2: - return IN_EDGE; - case 3: - return OUT_EDGE; - case 4: - return VERTEX_TS; - case 5: - return EDGE_TS; - case 6: - return MULTI_EDGE_TS; - case 7: - return VERTEX_LABEL; - case 8: - return EDGE_LABEL; - case 9: - return VERTEX_VALUE_DROP; - case 10: - return EDGE_VALUE_DROP; - case 11: - return TTL; - case 12: - return AND; - case 13: - return OR; - case 14: - return VERTEX_MUST_CONTAIN; - case 15: - return GENERATED; - case 16: - return OTHER; - default: - return null; - } - } - - public static com.google.protobuf.Internal.EnumLiteMap - internalGetValueMap() { - return internalValueMap; - } + public static com.google.protobuf.Internal.EnumLiteMap internalGetValueMap() { + return internalValueMap; + } - private static final com.google.protobuf.Internal.EnumLiteMap< - FilterType> internalValueMap = - new com.google.protobuf.Internal.EnumLiteMap() { - public FilterType findValueByNumber(int number) { - return FilterType.forNumber(number); - } - }; + private static final com.google.protobuf.Internal.EnumLiteMap internalValueMap = + new com.google.protobuf.Internal.EnumLiteMap() { + public FilterType findValueByNumber(int number) { + return FilterType.forNumber(number); + } + }; - public final com.google.protobuf.Descriptors.EnumValueDescriptor - getValueDescriptor() { - return getDescriptor().getValues().get(ordinal()); - } + public final com.google.protobuf.Descriptors.EnumValueDescriptor getValueDescriptor() { + return getDescriptor().getValues().get(ordinal()); + } - public final com.google.protobuf.Descriptors.EnumDescriptor - getDescriptorForType() { - return getDescriptor(); - } + public final com.google.protobuf.Descriptors.EnumDescriptor getDescriptorForType() { + return getDescriptor(); + } - public static final com.google.protobuf.Descriptors.EnumDescriptor - getDescriptor() { - return PushDownPb.getDescriptor().getEnumTypes().get(0); - } + public static final com.google.protobuf.Descriptors.EnumDescriptor getDescriptor() { + return PushDownPb.getDescriptor().getEnumTypes().get(0); + } - private static final FilterType[] VALUES = values(); + private static final FilterType[] VALUES = values(); - public static FilterType valueOf( - com.google.protobuf.Descriptors.EnumValueDescriptor desc) { - if (desc.getType() != getDescriptor()) { - throw new java.lang.IllegalArgumentException( - "EnumValueDescriptor is not for this type."); - } - if (desc.getIndex() == -1) { - return UNRECOGNIZED; - } - return VALUES[desc.getIndex()]; - } + public static FilterType valueOf(com.google.protobuf.Descriptors.EnumValueDescriptor desc) { + if (desc.getType() != getDescriptor()) { + throw new java.lang.IllegalArgumentException("EnumValueDescriptor is not for this type."); + } + if (desc.getIndex() == -1) { + return UNRECOGNIZED; + } + return VALUES[desc.getIndex()]; + } - private final int value; + private final int value; - private FilterType(int value) { - this.value = value; - } + private FilterType(int value) { + this.value = value; + } - // @@protoc_insertion_point(enum_scope:FilterType) + // @@protoc_insertion_point(enum_scope:FilterType) + } + + /** Protobuf enum {@code SortType} */ + public enum SortType implements com.google.protobuf.ProtocolMessageEnum { + /** SRC_ID = 0; */ + SRC_ID(0), + /** DIRECTION = 1; */ + DIRECTION(1), + /** DESC_TIME = 2; */ + DESC_TIME(2), + /** TIME = 3; */ + TIME(3), + /** LABEL = 4; */ + LABEL(4), + /** DST_ID = 5; */ + DST_ID(5), + UNRECOGNIZED(-1), + ; + + /** SRC_ID = 0; */ + public static final int SRC_ID_VALUE = 0; + + /** DIRECTION = 1; */ + public static final int DIRECTION_VALUE = 1; + + /** DESC_TIME = 2; */ + public static final int DESC_TIME_VALUE = 2; + + /** TIME = 3; */ + public static final int TIME_VALUE = 3; + + /** LABEL = 4; */ + public static final int LABEL_VALUE = 4; + + /** DST_ID = 5; */ + public static final int DST_ID_VALUE = 5; + + public final int getNumber() { + if (this == UNRECOGNIZED) { + throw new java.lang.IllegalArgumentException( + "Can't get the number of an unknown enum value."); + } + return value; } /** - * Protobuf enum {@code SortType} + * @deprecated Use {@link #forNumber(int)} instead. */ - public enum SortType - implements com.google.protobuf.ProtocolMessageEnum { - /** - * SRC_ID = 0; - */ - SRC_ID(0), - /** - * DIRECTION = 1; - */ - DIRECTION(1), - /** - * DESC_TIME = 2; - */ - DESC_TIME(2), - /** - * TIME = 3; - */ - TIME(3), - /** - * LABEL = 4; - */ - LABEL(4), - /** - * DST_ID = 5; - */ - DST_ID(5), - UNRECOGNIZED(-1), - ; - - /** - * SRC_ID = 0; - */ - public static final int SRC_ID_VALUE = 0; - /** - * DIRECTION = 1; - */ - public static final int DIRECTION_VALUE = 1; - /** - * DESC_TIME = 2; - */ - public static final int DESC_TIME_VALUE = 2; - /** - * TIME = 3; - */ - public static final int TIME_VALUE = 3; - /** - * LABEL = 4; - */ - public static final int LABEL_VALUE = 4; - /** - * DST_ID = 5; - */ - public static final int DST_ID_VALUE = 5; - - - public final int getNumber() { - if (this == UNRECOGNIZED) { - throw new java.lang.IllegalArgumentException( - "Can't get the number of an unknown enum value."); - } - return value; - } - - /** - * @deprecated Use {@link #forNumber(int)} instead. - */ - @java.lang.Deprecated - public static SortType valueOf(int value) { - return forNumber(value); - } - - public static SortType forNumber(int value) { - switch (value) { - case 0: - return SRC_ID; - case 1: - return DIRECTION; - case 2: - return DESC_TIME; - case 3: - return TIME; - case 4: - return LABEL; - case 5: - return DST_ID; - default: - return null; - } - } + @java.lang.Deprecated + public static SortType valueOf(int value) { + return forNumber(value); + } - public static com.google.protobuf.Internal.EnumLiteMap - internalGetValueMap() { - return internalValueMap; - } + public static SortType forNumber(int value) { + switch (value) { + case 0: + return SRC_ID; + case 1: + return DIRECTION; + case 2: + return DESC_TIME; + case 3: + return TIME; + case 4: + return LABEL; + case 5: + return DST_ID; + default: + return null; + } + } - private static final com.google.protobuf.Internal.EnumLiteMap< - SortType> internalValueMap = - new com.google.protobuf.Internal.EnumLiteMap() { - public SortType findValueByNumber(int number) { - return SortType.forNumber(number); - } - }; + public static com.google.protobuf.Internal.EnumLiteMap internalGetValueMap() { + return internalValueMap; + } - public final com.google.protobuf.Descriptors.EnumValueDescriptor - getValueDescriptor() { - return getDescriptor().getValues().get(ordinal()); - } + private static final com.google.protobuf.Internal.EnumLiteMap internalValueMap = + new com.google.protobuf.Internal.EnumLiteMap() { + public SortType findValueByNumber(int number) { + return SortType.forNumber(number); + } + }; - public final com.google.protobuf.Descriptors.EnumDescriptor - getDescriptorForType() { - return getDescriptor(); - } + public final com.google.protobuf.Descriptors.EnumValueDescriptor getValueDescriptor() { + return getDescriptor().getValues().get(ordinal()); + } - public static final com.google.protobuf.Descriptors.EnumDescriptor - getDescriptor() { - return PushDownPb.getDescriptor().getEnumTypes().get(1); - } + public final com.google.protobuf.Descriptors.EnumDescriptor getDescriptorForType() { + return getDescriptor(); + } - private static final SortType[] VALUES = values(); + public static final com.google.protobuf.Descriptors.EnumDescriptor getDescriptor() { + return PushDownPb.getDescriptor().getEnumTypes().get(1); + } - public static SortType valueOf( - com.google.protobuf.Descriptors.EnumValueDescriptor desc) { - if (desc.getType() != getDescriptor()) { - throw new java.lang.IllegalArgumentException( - "EnumValueDescriptor is not for this type."); - } - if (desc.getIndex() == -1) { - return UNRECOGNIZED; - } - return VALUES[desc.getIndex()]; - } + private static final SortType[] VALUES = values(); - private final int value; + public static SortType valueOf(com.google.protobuf.Descriptors.EnumValueDescriptor desc) { + if (desc.getType() != getDescriptor()) { + throw new java.lang.IllegalArgumentException("EnumValueDescriptor is not for this type."); + } + if (desc.getIndex() == -1) { + return UNRECOGNIZED; + } + return VALUES[desc.getIndex()]; + } - private SortType(int value) { - this.value = value; - } + private final int value; - // @@protoc_insertion_point(enum_scope:SortType) + private SortType(int value) { + this.value = value; } - public interface PushDownOrBuilder extends - // @@protoc_insertion_point(interface_extends:PushDown) - com.google.protobuf.MessageOrBuilder { + // @@protoc_insertion_point(enum_scope:SortType) + } - /** - * .FilterNode filter_node = 1; - */ - boolean hasFilterNode(); + public interface PushDownOrBuilder + extends + // @@protoc_insertion_point(interface_extends:PushDown) + com.google.protobuf.MessageOrBuilder { - /** - * .FilterNode filter_node = 1; - */ - PushDownPb.FilterNode getFilterNode(); + /** .FilterNode filter_node = 1; */ + boolean hasFilterNode(); - /** - * .FilterNode filter_node = 1; - */ - PushDownPb.FilterNodeOrBuilder getFilterNodeOrBuilder(); + /** .FilterNode filter_node = 1; */ + PushDownPb.FilterNode getFilterNode(); - /** - * .FilterNodes filter_nodes = 2; - */ - boolean hasFilterNodes(); + /** .FilterNode filter_node = 1; */ + PushDownPb.FilterNodeOrBuilder getFilterNodeOrBuilder(); - /** - * .FilterNodes filter_nodes = 2; - */ - PushDownPb.FilterNodes getFilterNodes(); + /** .FilterNodes filter_nodes = 2; */ + boolean hasFilterNodes(); - /** - * .FilterNodes filter_nodes = 2; - */ - PushDownPb.FilterNodesOrBuilder getFilterNodesOrBuilder(); + /** .FilterNodes filter_nodes = 2; */ + PushDownPb.FilterNodes getFilterNodes(); - /** - * .EdgeLimit edge_limit = 3; - */ - boolean hasEdgeLimit(); + /** .FilterNodes filter_nodes = 2; */ + PushDownPb.FilterNodesOrBuilder getFilterNodesOrBuilder(); - /** - * .EdgeLimit edge_limit = 3; - */ - PushDownPb.EdgeLimit getEdgeLimit(); + /** .EdgeLimit edge_limit = 3; */ + boolean hasEdgeLimit(); - /** - * .EdgeLimit edge_limit = 3; - */ - PushDownPb.EdgeLimitOrBuilder getEdgeLimitOrBuilder(); + /** .EdgeLimit edge_limit = 3; */ + PushDownPb.EdgeLimit getEdgeLimit(); - /** - * repeated .SortType sort_type = 4; - */ - java.util.List getSortTypeList(); + /** .EdgeLimit edge_limit = 3; */ + PushDownPb.EdgeLimitOrBuilder getEdgeLimitOrBuilder(); - /** - * repeated .SortType sort_type = 4; - */ - int getSortTypeCount(); + /** repeated .SortType sort_type = 4; */ + java.util.List getSortTypeList(); - /** - * repeated .SortType sort_type = 4; - */ - PushDownPb.SortType getSortType(int index); + /** repeated .SortType sort_type = 4; */ + int getSortTypeCount(); - /** - * repeated .SortType sort_type = 4; - */ - java.util.List - getSortTypeValueList(); + /** repeated .SortType sort_type = 4; */ + PushDownPb.SortType getSortType(int index); - /** - * repeated .SortType sort_type = 4; - */ - int getSortTypeValue(int index); + /** repeated .SortType sort_type = 4; */ + java.util.List getSortTypeValueList(); - public PushDownPb.PushDown.FilterCase getFilterCase(); - } + /** repeated .SortType sort_type = 4; */ + int getSortTypeValue(int index); - /** - * Protobuf type {@code PushDown} - */ - public static final class PushDown extends - com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:PushDown) - PushDownOrBuilder { - private static final long serialVersionUID = 0L; + public PushDownPb.PushDown.FilterCase getFilterCase(); + } - // Use PushDown.newBuilder() to construct. - private PushDown(com.google.protobuf.GeneratedMessageV3.Builder builder) { - super(builder); - } + /** Protobuf type {@code PushDown} */ + public static final class PushDown extends com.google.protobuf.GeneratedMessageV3 + implements + // @@protoc_insertion_point(message_implements:PushDown) + PushDownOrBuilder { + private static final long serialVersionUID = 0L; - private PushDown() { - sortType_ = java.util.Collections.emptyList(); - } + // Use PushDown.newBuilder() to construct. + private PushDown(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } - @java.lang.Override - @SuppressWarnings({"unused"}) - protected java.lang.Object newInstance( - UnusedPrivateParameter unused) { - return new PushDown(); - } + private PushDown() { + sortType_ = java.util.Collections.emptyList(); + } - @java.lang.Override - public final com.google.protobuf.UnknownFieldSet - getUnknownFields() { - return this.unknownFields; - } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance(UnusedPrivateParameter unused) { + return new PushDown(); + } - private PushDown( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - this(); - if (extensionRegistry == null) { - throw new java.lang.NullPointerException(); - } - int mutable_bitField0_ = 0; - com.google.protobuf.UnknownFieldSet.Builder unknownFields = - com.google.protobuf.UnknownFieldSet.newBuilder(); - try { - boolean done = false; - while (!done) { - int tag = input.readTag(); - switch (tag) { - case 0: - done = true; - break; - case 10: { - PushDownPb.FilterNode.Builder subBuilder = null; - if (filterCase_ == 1) { - subBuilder = ((PushDownPb.FilterNode) filter_).toBuilder(); - } - filter_ = - input.readMessage(PushDownPb.FilterNode.parser(), extensionRegistry); - if (subBuilder != null) { - subBuilder.mergeFrom((PushDownPb.FilterNode) filter_); - filter_ = subBuilder.buildPartial(); - } - filterCase_ = 1; - break; - } - case 18: { - PushDownPb.FilterNodes.Builder subBuilder = null; - if (filterCase_ == 2) { - subBuilder = ((PushDownPb.FilterNodes) filter_).toBuilder(); - } - filter_ = - input.readMessage(PushDownPb.FilterNodes.parser(), extensionRegistry); - if (subBuilder != null) { - subBuilder.mergeFrom((PushDownPb.FilterNodes) filter_); - filter_ = subBuilder.buildPartial(); - } - filterCase_ = 2; - break; - } - case 26: { - PushDownPb.EdgeLimit.Builder subBuilder = null; - if (edgeLimit_ != null) { - subBuilder = edgeLimit_.toBuilder(); - } - edgeLimit_ = input.readMessage(PushDownPb.EdgeLimit.parser(), extensionRegistry); - if (subBuilder != null) { - subBuilder.mergeFrom(edgeLimit_); - edgeLimit_ = subBuilder.buildPartial(); - } - - break; - } - case 32: { - int rawValue = input.readEnum(); - if (!((mutable_bitField0_ & 0x00000001) != 0)) { - sortType_ = new java.util.ArrayList(); - mutable_bitField0_ |= 0x00000001; - } - sortType_.add(rawValue); - break; - } - case 34: { - int length = input.readRawVarint32(); - int oldLimit = input.pushLimit(length); - while (input.getBytesUntilLimit() > 0) { - int rawValue = input.readEnum(); - if (!((mutable_bitField0_ & 0x00000001) != 0)) { - sortType_ = new java.util.ArrayList(); - mutable_bitField0_ |= 0x00000001; - } - sortType_.add(rawValue); - } - input.popLimit(oldLimit); - break; - } - default: { - if (!parseUnknownField( - input, unknownFields, extensionRegistry, tag)) { - done = true; - } - break; - } - } + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet getUnknownFields() { + return this.unknownFields; + } + + private PushDown( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: + { + PushDownPb.FilterNode.Builder subBuilder = null; + if (filterCase_ == 1) { + subBuilder = ((PushDownPb.FilterNode) filter_).toBuilder(); } - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - throw e.setUnfinishedMessage(this); - } catch (java.io.IOException e) { - throw new com.google.protobuf.InvalidProtocolBufferException( - e).setUnfinishedMessage(this); - } finally { - if (((mutable_bitField0_ & 0x00000001) != 0)) { - sortType_ = java.util.Collections.unmodifiableList(sortType_); + filter_ = input.readMessage(PushDownPb.FilterNode.parser(), extensionRegistry); + if (subBuilder != null) { + subBuilder.mergeFrom((PushDownPb.FilterNode) filter_); + filter_ = subBuilder.buildPartial(); } - this.unknownFields = unknownFields.build(); - makeExtensionsImmutable(); - } - } + filterCase_ = 1; + break; + } + case 18: + { + PushDownPb.FilterNodes.Builder subBuilder = null; + if (filterCase_ == 2) { + subBuilder = ((PushDownPb.FilterNodes) filter_).toBuilder(); + } + filter_ = input.readMessage(PushDownPb.FilterNodes.parser(), extensionRegistry); + if (subBuilder != null) { + subBuilder.mergeFrom((PushDownPb.FilterNodes) filter_); + filter_ = subBuilder.buildPartial(); + } + filterCase_ = 2; + break; + } + case 26: + { + PushDownPb.EdgeLimit.Builder subBuilder = null; + if (edgeLimit_ != null) { + subBuilder = edgeLimit_.toBuilder(); + } + edgeLimit_ = input.readMessage(PushDownPb.EdgeLimit.parser(), extensionRegistry); + if (subBuilder != null) { + subBuilder.mergeFrom(edgeLimit_); + edgeLimit_ = subBuilder.buildPartial(); + } + + break; + } + case 32: + { + int rawValue = input.readEnum(); + if (!((mutable_bitField0_ & 0x00000001) != 0)) { + sortType_ = new java.util.ArrayList(); + mutable_bitField0_ |= 0x00000001; + } + sortType_.add(rawValue); + break; + } + case 34: + { + int length = input.readRawVarint32(); + int oldLimit = input.pushLimit(length); + while (input.getBytesUntilLimit() > 0) { + int rawValue = input.readEnum(); + if (!((mutable_bitField0_ & 0x00000001) != 0)) { + sortType_ = new java.util.ArrayList(); + mutable_bitField0_ |= 0x00000001; + } + sortType_.add(rawValue); + } + input.popLimit(oldLimit); + break; + } + default: + { + if (!parseUnknownField(input, unknownFields, extensionRegistry, tag)) { + done = true; + } + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e).setUnfinishedMessage(this); + } finally { + if (((mutable_bitField0_ & 0x00000001) != 0)) { + sortType_ = java.util.Collections.unmodifiableList(sortType_); + } + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return PushDownPb.internal_static_PushDown_descriptor; - } + public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { + return PushDownPb.internal_static_PushDown_descriptor; + } - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return PushDownPb.internal_static_PushDown_fieldAccessorTable - .ensureFieldAccessorsInitialized( - PushDownPb.PushDown.class, PushDownPb.PushDown.Builder.class); - } + return PushDownPb.internal_static_PushDown_fieldAccessorTable.ensureFieldAccessorsInitialized( + PushDownPb.PushDown.class, PushDownPb.PushDown.Builder.class); + } - private int filterCase_ = 0; - private java.lang.Object filter_; + private int filterCase_ = 0; + private java.lang.Object filter_; + + public enum FilterCase implements com.google.protobuf.Internal.EnumLite { + FILTER_NODE(1), + FILTER_NODES(2), + FILTER_NOT_SET(0); + private final int value; + + private FilterCase(int value) { + this.value = value; + } + + /** + * @deprecated Use {@link #forNumber(int)} instead. + */ + @java.lang.Deprecated + public static FilterCase valueOf(int value) { + return forNumber(value); + } + + public static FilterCase forNumber(int value) { + switch (value) { + case 1: + return FILTER_NODE; + case 2: + return FILTER_NODES; + case 0: + return FILTER_NOT_SET; + default: + return null; + } + } + + public int getNumber() { + return this.value; + } + }; + + public FilterCase getFilterCase() { + return FilterCase.forNumber(filterCase_); + } - public enum FilterCase - implements com.google.protobuf.Internal.EnumLite { - FILTER_NODE(1), - FILTER_NODES(2), - FILTER_NOT_SET(0); - private final int value; + public static final int FILTER_NODE_FIELD_NUMBER = 1; - private FilterCase(int value) { - this.value = value; - } + /** .FilterNode filter_node = 1; */ + public boolean hasFilterNode() { + return filterCase_ == 1; + } - /** - * @deprecated Use {@link #forNumber(int)} instead. - */ - @java.lang.Deprecated - public static FilterCase valueOf(int value) { - return forNumber(value); - } + /** .FilterNode filter_node = 1; */ + public PushDownPb.FilterNode getFilterNode() { + if (filterCase_ == 1) { + return (PushDownPb.FilterNode) filter_; + } + return PushDownPb.FilterNode.getDefaultInstance(); + } - public static FilterCase forNumber(int value) { - switch (value) { - case 1: - return FILTER_NODE; - case 2: - return FILTER_NODES; - case 0: - return FILTER_NOT_SET; - default: - return null; - } - } + /** .FilterNode filter_node = 1; */ + public PushDownPb.FilterNodeOrBuilder getFilterNodeOrBuilder() { + if (filterCase_ == 1) { + return (PushDownPb.FilterNode) filter_; + } + return PushDownPb.FilterNode.getDefaultInstance(); + } - public int getNumber() { - return this.value; - } - } + public static final int FILTER_NODES_FIELD_NUMBER = 2; - ; + /** .FilterNodes filter_nodes = 2; */ + public boolean hasFilterNodes() { + return filterCase_ == 2; + } - public FilterCase - getFilterCase() { - return FilterCase.forNumber( - filterCase_); - } + /** .FilterNodes filter_nodes = 2; */ + public PushDownPb.FilterNodes getFilterNodes() { + if (filterCase_ == 2) { + return (PushDownPb.FilterNodes) filter_; + } + return PushDownPb.FilterNodes.getDefaultInstance(); + } - public static final int FILTER_NODE_FIELD_NUMBER = 1; + /** .FilterNodes filter_nodes = 2; */ + public PushDownPb.FilterNodesOrBuilder getFilterNodesOrBuilder() { + if (filterCase_ == 2) { + return (PushDownPb.FilterNodes) filter_; + } + return PushDownPb.FilterNodes.getDefaultInstance(); + } - /** - * .FilterNode filter_node = 1; - */ - public boolean hasFilterNode() { - return filterCase_ == 1; - } + public static final int EDGE_LIMIT_FIELD_NUMBER = 3; + private PushDownPb.EdgeLimit edgeLimit_; - /** - * .FilterNode filter_node = 1; - */ - public PushDownPb.FilterNode getFilterNode() { - if (filterCase_ == 1) { - return (PushDownPb.FilterNode) filter_; - } - return PushDownPb.FilterNode.getDefaultInstance(); - } + /** .EdgeLimit edge_limit = 3; */ + public boolean hasEdgeLimit() { + return edgeLimit_ != null; + } - /** - * .FilterNode filter_node = 1; - */ - public PushDownPb.FilterNodeOrBuilder getFilterNodeOrBuilder() { - if (filterCase_ == 1) { - return (PushDownPb.FilterNode) filter_; - } - return PushDownPb.FilterNode.getDefaultInstance(); - } + /** .EdgeLimit edge_limit = 3; */ + public PushDownPb.EdgeLimit getEdgeLimit() { + return edgeLimit_ == null ? PushDownPb.EdgeLimit.getDefaultInstance() : edgeLimit_; + } - public static final int FILTER_NODES_FIELD_NUMBER = 2; + /** .EdgeLimit edge_limit = 3; */ + public PushDownPb.EdgeLimitOrBuilder getEdgeLimitOrBuilder() { + return getEdgeLimit(); + } - /** - * .FilterNodes filter_nodes = 2; - */ - public boolean hasFilterNodes() { - return filterCase_ == 2; - } + public static final int SORT_TYPE_FIELD_NUMBER = 4; + private java.util.List sortType_; + private static final com.google.protobuf.Internal.ListAdapter.Converter< + java.lang.Integer, PushDownPb.SortType> + sortType_converter_ = + new com.google.protobuf.Internal.ListAdapter.Converter< + java.lang.Integer, PushDownPb.SortType>() { + public PushDownPb.SortType convert(java.lang.Integer from) { + @SuppressWarnings("deprecation") + PushDownPb.SortType result = PushDownPb.SortType.valueOf(from); + return result == null ? PushDownPb.SortType.UNRECOGNIZED : result; + } + }; - /** - * .FilterNodes filter_nodes = 2; - */ - public PushDownPb.FilterNodes getFilterNodes() { - if (filterCase_ == 2) { - return (PushDownPb.FilterNodes) filter_; - } - return PushDownPb.FilterNodes.getDefaultInstance(); - } + /** repeated .SortType sort_type = 4; */ + public java.util.List getSortTypeList() { + return new com.google.protobuf.Internal.ListAdapter( + sortType_, sortType_converter_); + } - /** - * .FilterNodes filter_nodes = 2; - */ - public PushDownPb.FilterNodesOrBuilder getFilterNodesOrBuilder() { - if (filterCase_ == 2) { - return (PushDownPb.FilterNodes) filter_; - } - return PushDownPb.FilterNodes.getDefaultInstance(); - } + /** repeated .SortType sort_type = 4; */ + public int getSortTypeCount() { + return sortType_.size(); + } - public static final int EDGE_LIMIT_FIELD_NUMBER = 3; - private PushDownPb.EdgeLimit edgeLimit_; + /** repeated .SortType sort_type = 4; */ + public PushDownPb.SortType getSortType(int index) { + return sortType_converter_.convert(sortType_.get(index)); + } - /** - * .EdgeLimit edge_limit = 3; - */ - public boolean hasEdgeLimit() { - return edgeLimit_ != null; - } + /** repeated .SortType sort_type = 4; */ + public java.util.List getSortTypeValueList() { + return sortType_; + } - /** - * .EdgeLimit edge_limit = 3; - */ - public PushDownPb.EdgeLimit getEdgeLimit() { - return edgeLimit_ == null ? PushDownPb.EdgeLimit.getDefaultInstance() : edgeLimit_; - } + /** repeated .SortType sort_type = 4; */ + public int getSortTypeValue(int index) { + return sortType_.get(index); + } - /** - * .EdgeLimit edge_limit = 3; - */ - public PushDownPb.EdgeLimitOrBuilder getEdgeLimitOrBuilder() { - return getEdgeLimit(); - } + private int sortTypeMemoizedSerializedSize; - public static final int SORT_TYPE_FIELD_NUMBER = 4; - private java.util.List sortType_; - private static final com.google.protobuf.Internal.ListAdapter.Converter< - java.lang.Integer, PushDownPb.SortType> sortType_converter_ = - new com.google.protobuf.Internal.ListAdapter.Converter< - java.lang.Integer, PushDownPb.SortType>() { - public PushDownPb.SortType convert(java.lang.Integer from) { - @SuppressWarnings("deprecation") - PushDownPb.SortType result = PushDownPb.SortType.valueOf(from); - return result == null ? PushDownPb.SortType.UNRECOGNIZED : result; - } - }; + private byte memoizedIsInitialized = -1; - /** - * repeated .SortType sort_type = 4; - */ - public java.util.List getSortTypeList() { - return new com.google.protobuf.Internal.ListAdapter< - java.lang.Integer, PushDownPb.SortType>(sortType_, sortType_converter_); - } + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; - /** - * repeated .SortType sort_type = 4; - */ - public int getSortTypeCount() { - return sortType_.size(); - } + memoizedIsInitialized = 1; + return true; + } - /** - * repeated .SortType sort_type = 4; - */ - public PushDownPb.SortType getSortType(int index) { - return sortType_converter_.convert(sortType_.get(index)); - } + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { + getSerializedSize(); + if (filterCase_ == 1) { + output.writeMessage(1, (PushDownPb.FilterNode) filter_); + } + if (filterCase_ == 2) { + output.writeMessage(2, (PushDownPb.FilterNodes) filter_); + } + if (edgeLimit_ != null) { + output.writeMessage(3, getEdgeLimit()); + } + if (getSortTypeList().size() > 0) { + output.writeUInt32NoTag(34); + output.writeUInt32NoTag(sortTypeMemoizedSerializedSize); + } + for (int i = 0; i < sortType_.size(); i++) { + output.writeEnumNoTag(sortType_.get(i)); + } + unknownFields.writeTo(output); + } - /** - * repeated .SortType sort_type = 4; - */ - public java.util.List - getSortTypeValueList() { - return sortType_; - } + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (filterCase_ == 1) { + size += + com.google.protobuf.CodedOutputStream.computeMessageSize( + 1, (PushDownPb.FilterNode) filter_); + } + if (filterCase_ == 2) { + size += + com.google.protobuf.CodedOutputStream.computeMessageSize( + 2, (PushDownPb.FilterNodes) filter_); + } + if (edgeLimit_ != null) { + size += com.google.protobuf.CodedOutputStream.computeMessageSize(3, getEdgeLimit()); + } + { + int dataSize = 0; + for (int i = 0; i < sortType_.size(); i++) { + dataSize += com.google.protobuf.CodedOutputStream.computeEnumSizeNoTag(sortType_.get(i)); + } + size += dataSize; + if (!getSortTypeList().isEmpty()) { + size += 1; + size += com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(dataSize); + } + sortTypeMemoizedSerializedSize = dataSize; + } + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } - /** - * repeated .SortType sort_type = 4; - */ - public int getSortTypeValue(int index) { - return sortType_.get(index); - } + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof PushDownPb.PushDown)) { + return super.equals(obj); + } + PushDownPb.PushDown other = (PushDownPb.PushDown) obj; + + if (hasEdgeLimit() != other.hasEdgeLimit()) return false; + if (hasEdgeLimit()) { + if (!getEdgeLimit().equals(other.getEdgeLimit())) return false; + } + if (!sortType_.equals(other.sortType_)) return false; + if (!getFilterCase().equals(other.getFilterCase())) return false; + switch (filterCase_) { + case 1: + if (!getFilterNode().equals(other.getFilterNode())) return false; + break; + case 2: + if (!getFilterNodes().equals(other.getFilterNodes())) return false; + break; + case 0: + default: + } + if (!unknownFields.equals(other.unknownFields)) return false; + return true; + } - private int sortTypeMemoizedSerializedSize; + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + if (hasEdgeLimit()) { + hash = (37 * hash) + EDGE_LIMIT_FIELD_NUMBER; + hash = (53 * hash) + getEdgeLimit().hashCode(); + } + if (getSortTypeCount() > 0) { + hash = (37 * hash) + SORT_TYPE_FIELD_NUMBER; + hash = (53 * hash) + sortType_.hashCode(); + } + switch (filterCase_) { + case 1: + hash = (37 * hash) + FILTER_NODE_FIELD_NUMBER; + hash = (53 * hash) + getFilterNode().hashCode(); + break; + case 2: + hash = (37 * hash) + FILTER_NODES_FIELD_NUMBER; + hash = (53 * hash) + getFilterNodes().hashCode(); + break; + case 0: + default: + } + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } - private byte memoizedIsInitialized = -1; + public static PushDownPb.PushDown parseFrom(java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - @java.lang.Override - public final boolean isInitialized() { - byte isInitialized = memoizedIsInitialized; - if (isInitialized == 1) return true; - if (isInitialized == 0) return false; + public static PushDownPb.PushDown parseFrom( + java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - memoizedIsInitialized = 1; - return true; - } + public static PushDownPb.PushDown parseFrom(com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - @java.lang.Override - public void writeTo(com.google.protobuf.CodedOutputStream output) - throws java.io.IOException { - getSerializedSize(); - if (filterCase_ == 1) { - output.writeMessage(1, (PushDownPb.FilterNode) filter_); - } - if (filterCase_ == 2) { - output.writeMessage(2, (PushDownPb.FilterNodes) filter_); - } - if (edgeLimit_ != null) { - output.writeMessage(3, getEdgeLimit()); - } - if (getSortTypeList().size() > 0) { - output.writeUInt32NoTag(34); - output.writeUInt32NoTag(sortTypeMemoizedSerializedSize); - } - for (int i = 0; i < sortType_.size(); i++) { - output.writeEnumNoTag(sortType_.get(i)); - } - unknownFields.writeTo(output); - } + public static PushDownPb.PushDown parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - @java.lang.Override - public int getSerializedSize() { - int size = memoizedSize; - if (size != -1) return size; + public static PushDownPb.PushDown parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - size = 0; - if (filterCase_ == 1) { - size += com.google.protobuf.CodedOutputStream - .computeMessageSize(1, (PushDownPb.FilterNode) filter_); - } - if (filterCase_ == 2) { - size += com.google.protobuf.CodedOutputStream - .computeMessageSize(2, (PushDownPb.FilterNodes) filter_); - } - if (edgeLimit_ != null) { - size += com.google.protobuf.CodedOutputStream - .computeMessageSize(3, getEdgeLimit()); - } - { - int dataSize = 0; - for (int i = 0; i < sortType_.size(); i++) { - dataSize += com.google.protobuf.CodedOutputStream - .computeEnumSizeNoTag(sortType_.get(i)); - } - size += dataSize; - if (!getSortTypeList().isEmpty()) { - size += 1; - size += com.google.protobuf.CodedOutputStream - .computeUInt32SizeNoTag(dataSize); - } - sortTypeMemoizedSerializedSize = dataSize; - } - size += unknownFields.getSerializedSize(); - memoizedSize = size; - return size; - } + public static PushDownPb.PushDown parseFrom( + byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - @java.lang.Override - public boolean equals(final java.lang.Object obj) { - if (obj == this) { - return true; - } - if (!(obj instanceof PushDownPb.PushDown)) { - return super.equals(obj); - } - PushDownPb.PushDown other = (PushDownPb.PushDown) obj; + public static PushDownPb.PushDown parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException(PARSER, input); + } - if (hasEdgeLimit() != other.hasEdgeLimit()) return false; - if (hasEdgeLimit()) { - if (!getEdgeLimit() - .equals(other.getEdgeLimit())) return false; - } - if (!sortType_.equals(other.sortType_)) return false; - if (!getFilterCase().equals(other.getFilterCase())) return false; - switch (filterCase_) { - case 1: - if (!getFilterNode() - .equals(other.getFilterNode())) return false; - break; - case 2: - if (!getFilterNodes() - .equals(other.getFilterNodes())) return false; - break; - case 0: - default: - } - if (!unknownFields.equals(other.unknownFields)) return false; - return true; - } + public static PushDownPb.PushDown parseFrom( + java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException( + PARSER, input, extensionRegistry); + } - @java.lang.Override - public int hashCode() { - if (memoizedHashCode != 0) { - return memoizedHashCode; - } - int hash = 41; - hash = (19 * hash) + getDescriptor().hashCode(); - if (hasEdgeLimit()) { - hash = (37 * hash) + EDGE_LIMIT_FIELD_NUMBER; - hash = (53 * hash) + getEdgeLimit().hashCode(); - } - if (getSortTypeCount() > 0) { - hash = (37 * hash) + SORT_TYPE_FIELD_NUMBER; - hash = (53 * hash) + sortType_.hashCode(); - } - switch (filterCase_) { - case 1: - hash = (37 * hash) + FILTER_NODE_FIELD_NUMBER; - hash = (53 * hash) + getFilterNode().hashCode(); - break; - case 2: - hash = (37 * hash) + FILTER_NODES_FIELD_NUMBER; - hash = (53 * hash) + getFilterNodes().hashCode(); - break; - case 0: - default: - } - hash = (29 * hash) + unknownFields.hashCode(); - memoizedHashCode = hash; - return hash; - } + public static PushDownPb.PushDown parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseDelimitedWithIOException(PARSER, input); + } - public static PushDownPb.PushDown parseFrom( - java.nio.ByteBuffer data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } + public static PushDownPb.PushDown parseDelimitedFrom( + java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseDelimitedWithIOException( + PARSER, input, extensionRegistry); + } - public static PushDownPb.PushDown parseFrom( - java.nio.ByteBuffer data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } + public static PushDownPb.PushDown parseFrom(com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException(PARSER, input); + } - public static PushDownPb.PushDown parseFrom( - com.google.protobuf.ByteString data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } + public static PushDownPb.PushDown parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException( + PARSER, input, extensionRegistry); + } - public static PushDownPb.PushDown parseFrom( - com.google.protobuf.ByteString data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } + @java.lang.Override + public Builder newBuilderForType() { + return newBuilder(); + } - public static PushDownPb.PushDown parseFrom(byte[] data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } - public static PushDownPb.PushDown parseFrom( - byte[] data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } + public static Builder newBuilder(PushDownPb.PushDown prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } - public static PushDownPb.PushDown parseFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE ? new Builder() : new Builder().mergeFrom(this); + } - public static PushDownPb.PushDown parseFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } - public static PushDownPb.PushDown parseDelimitedFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input); - } + /** Protobuf type {@code PushDown} */ + public static final class Builder + extends com.google.protobuf.GeneratedMessageV3.Builder + implements + // @@protoc_insertion_point(builder_implements:PushDown) + PushDownPb.PushDownOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { + return PushDownPb.internal_static_PushDown_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return PushDownPb.internal_static_PushDown_fieldAccessorTable + .ensureFieldAccessorsInitialized( + PushDownPb.PushDown.class, PushDownPb.PushDown.Builder.class); + } + + // Construct using PushDownPb.PushDown.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder(com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + + private void maybeForceBuilderInitialization() {} + + @java.lang.Override + public Builder clear() { + super.clear(); + if (edgeLimitBuilder_ == null) { + edgeLimit_ = null; + } else { + edgeLimit_ = null; + edgeLimitBuilder_ = null; + } + sortType_ = java.util.Collections.emptyList(); + bitField0_ = (bitField0_ & ~0x00000001); + filterCase_ = 0; + filter_ = null; + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { + return PushDownPb.internal_static_PushDown_descriptor; + } + + @java.lang.Override + public PushDownPb.PushDown getDefaultInstanceForType() { + return PushDownPb.PushDown.getDefaultInstance(); + } + + @java.lang.Override + public PushDownPb.PushDown build() { + PushDownPb.PushDown result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public PushDownPb.PushDown buildPartial() { + PushDownPb.PushDown result = new PushDownPb.PushDown(this); + if (filterCase_ == 1) { + if (filterNodeBuilder_ == null) { + result.filter_ = filter_; + } else { + result.filter_ = filterNodeBuilder_.build(); + } + } + if (filterCase_ == 2) { + if (filterNodesBuilder_ == null) { + result.filter_ = filter_; + } else { + result.filter_ = filterNodesBuilder_.build(); + } + } + if (edgeLimitBuilder_ == null) { + result.edgeLimit_ = edgeLimit_; + } else { + result.edgeLimit_ = edgeLimitBuilder_.build(); + } + if (((bitField0_ & 0x00000001) != 0)) { + sortType_ = java.util.Collections.unmodifiableList(sortType_); + bitField0_ = (bitField0_ & ~0x00000001); + } + result.sortType_ = sortType_; + result.filterCase_ = filterCase_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, java.lang.Object value) { + return super.setField(field, value); + } + + @java.lang.Override + public Builder clearField(com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + + @java.lang.Override + public Builder clearOneof(com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, + java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, java.lang.Object value) { + return super.addRepeatedField(field, value); + } + + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof PushDownPb.PushDown) { + return mergeFrom((PushDownPb.PushDown) other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(PushDownPb.PushDown other) { + if (other == PushDownPb.PushDown.getDefaultInstance()) return this; + if (other.hasEdgeLimit()) { + mergeEdgeLimit(other.getEdgeLimit()); + } + if (!other.sortType_.isEmpty()) { + if (sortType_.isEmpty()) { + sortType_ = other.sortType_; + bitField0_ = (bitField0_ & ~0x00000001); + } else { + ensureSortTypeIsMutable(); + sortType_.addAll(other.sortType_); + } + onChanged(); + } + switch (other.getFilterCase()) { + case FILTER_NODE: + { + mergeFilterNode(other.getFilterNode()); + break; + } + case FILTER_NODES: + { + mergeFilterNodes(other.getFilterNodes()); + break; + } + case FILTER_NOT_SET: + { + break; + } + } + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + PushDownPb.PushDown parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (PushDownPb.PushDown) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + + private int filterCase_ = 0; + private java.lang.Object filter_; + + public FilterCase getFilterCase() { + return FilterCase.forNumber(filterCase_); + } + + public Builder clearFilter() { + filterCase_ = 0; + filter_ = null; + onChanged(); + return this; + } + + private int bitField0_; + + private com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.FilterNode, PushDownPb.FilterNode.Builder, PushDownPb.FilterNodeOrBuilder> + filterNodeBuilder_; + + /** .FilterNode filter_node = 1; */ + public boolean hasFilterNode() { + return filterCase_ == 1; + } + + /** .FilterNode filter_node = 1; */ + public PushDownPb.FilterNode getFilterNode() { + if (filterNodeBuilder_ == null) { + if (filterCase_ == 1) { + return (PushDownPb.FilterNode) filter_; + } + return PushDownPb.FilterNode.getDefaultInstance(); + } else { + if (filterCase_ == 1) { + return filterNodeBuilder_.getMessage(); + } + return PushDownPb.FilterNode.getDefaultInstance(); + } + } + + /** .FilterNode filter_node = 1; */ + public Builder setFilterNode(PushDownPb.FilterNode value) { + if (filterNodeBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + filter_ = value; + onChanged(); + } else { + filterNodeBuilder_.setMessage(value); + } + filterCase_ = 1; + return this; + } + + /** .FilterNode filter_node = 1; */ + public Builder setFilterNode(PushDownPb.FilterNode.Builder builderForValue) { + if (filterNodeBuilder_ == null) { + filter_ = builderForValue.build(); + onChanged(); + } else { + filterNodeBuilder_.setMessage(builderForValue.build()); + } + filterCase_ = 1; + return this; + } + + /** .FilterNode filter_node = 1; */ + public Builder mergeFilterNode(PushDownPb.FilterNode value) { + if (filterNodeBuilder_ == null) { + if (filterCase_ == 1 && filter_ != PushDownPb.FilterNode.getDefaultInstance()) { + filter_ = + PushDownPb.FilterNode.newBuilder((PushDownPb.FilterNode) filter_) + .mergeFrom(value) + .buildPartial(); + } else { + filter_ = value; + } + onChanged(); + } else { + if (filterCase_ == 1) { + filterNodeBuilder_.mergeFrom(value); + } + filterNodeBuilder_.setMessage(value); + } + filterCase_ = 1; + return this; + } + + /** .FilterNode filter_node = 1; */ + public Builder clearFilterNode() { + if (filterNodeBuilder_ == null) { + if (filterCase_ == 1) { + filterCase_ = 0; + filter_ = null; + onChanged(); + } + } else { + if (filterCase_ == 1) { + filterCase_ = 0; + filter_ = null; + } + filterNodeBuilder_.clear(); + } + return this; + } + + /** .FilterNode filter_node = 1; */ + public PushDownPb.FilterNode.Builder getFilterNodeBuilder() { + return getFilterNodeFieldBuilder().getBuilder(); + } + + /** .FilterNode filter_node = 1; */ + public PushDownPb.FilterNodeOrBuilder getFilterNodeOrBuilder() { + if ((filterCase_ == 1) && (filterNodeBuilder_ != null)) { + return filterNodeBuilder_.getMessageOrBuilder(); + } else { + if (filterCase_ == 1) { + return (PushDownPb.FilterNode) filter_; + } + return PushDownPb.FilterNode.getDefaultInstance(); + } + } + + /** .FilterNode filter_node = 1; */ + private com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.FilterNode, PushDownPb.FilterNode.Builder, PushDownPb.FilterNodeOrBuilder> + getFilterNodeFieldBuilder() { + if (filterNodeBuilder_ == null) { + if (!(filterCase_ == 1)) { + filter_ = PushDownPb.FilterNode.getDefaultInstance(); + } + filterNodeBuilder_ = + new com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.FilterNode, + PushDownPb.FilterNode.Builder, + PushDownPb.FilterNodeOrBuilder>( + (PushDownPb.FilterNode) filter_, getParentForChildren(), isClean()); + filter_ = null; + } + filterCase_ = 1; + onChanged(); + return filterNodeBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.FilterNodes, + PushDownPb.FilterNodes.Builder, + PushDownPb.FilterNodesOrBuilder> + filterNodesBuilder_; + + /** .FilterNodes filter_nodes = 2; */ + public boolean hasFilterNodes() { + return filterCase_ == 2; + } + + /** .FilterNodes filter_nodes = 2; */ + public PushDownPb.FilterNodes getFilterNodes() { + if (filterNodesBuilder_ == null) { + if (filterCase_ == 2) { + return (PushDownPb.FilterNodes) filter_; + } + return PushDownPb.FilterNodes.getDefaultInstance(); + } else { + if (filterCase_ == 2) { + return filterNodesBuilder_.getMessage(); + } + return PushDownPb.FilterNodes.getDefaultInstance(); + } + } + + /** .FilterNodes filter_nodes = 2; */ + public Builder setFilterNodes(PushDownPb.FilterNodes value) { + if (filterNodesBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + filter_ = value; + onChanged(); + } else { + filterNodesBuilder_.setMessage(value); + } + filterCase_ = 2; + return this; + } + + /** .FilterNodes filter_nodes = 2; */ + public Builder setFilterNodes(PushDownPb.FilterNodes.Builder builderForValue) { + if (filterNodesBuilder_ == null) { + filter_ = builderForValue.build(); + onChanged(); + } else { + filterNodesBuilder_.setMessage(builderForValue.build()); + } + filterCase_ = 2; + return this; + } + + /** .FilterNodes filter_nodes = 2; */ + public Builder mergeFilterNodes(PushDownPb.FilterNodes value) { + if (filterNodesBuilder_ == null) { + if (filterCase_ == 2 && filter_ != PushDownPb.FilterNodes.getDefaultInstance()) { + filter_ = + PushDownPb.FilterNodes.newBuilder((PushDownPb.FilterNodes) filter_) + .mergeFrom(value) + .buildPartial(); + } else { + filter_ = value; + } + onChanged(); + } else { + if (filterCase_ == 2) { + filterNodesBuilder_.mergeFrom(value); + } + filterNodesBuilder_.setMessage(value); + } + filterCase_ = 2; + return this; + } + + /** .FilterNodes filter_nodes = 2; */ + public Builder clearFilterNodes() { + if (filterNodesBuilder_ == null) { + if (filterCase_ == 2) { + filterCase_ = 0; + filter_ = null; + onChanged(); + } + } else { + if (filterCase_ == 2) { + filterCase_ = 0; + filter_ = null; + } + filterNodesBuilder_.clear(); + } + return this; + } + + /** .FilterNodes filter_nodes = 2; */ + public PushDownPb.FilterNodes.Builder getFilterNodesBuilder() { + return getFilterNodesFieldBuilder().getBuilder(); + } + + /** .FilterNodes filter_nodes = 2; */ + public PushDownPb.FilterNodesOrBuilder getFilterNodesOrBuilder() { + if ((filterCase_ == 2) && (filterNodesBuilder_ != null)) { + return filterNodesBuilder_.getMessageOrBuilder(); + } else { + if (filterCase_ == 2) { + return (PushDownPb.FilterNodes) filter_; + } + return PushDownPb.FilterNodes.getDefaultInstance(); + } + } + + /** .FilterNodes filter_nodes = 2; */ + private com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.FilterNodes, + PushDownPb.FilterNodes.Builder, + PushDownPb.FilterNodesOrBuilder> + getFilterNodesFieldBuilder() { + if (filterNodesBuilder_ == null) { + if (!(filterCase_ == 2)) { + filter_ = PushDownPb.FilterNodes.getDefaultInstance(); + } + filterNodesBuilder_ = + new com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.FilterNodes, + PushDownPb.FilterNodes.Builder, + PushDownPb.FilterNodesOrBuilder>( + (PushDownPb.FilterNodes) filter_, getParentForChildren(), isClean()); + filter_ = null; + } + filterCase_ = 2; + onChanged(); + return filterNodesBuilder_; + } + + private PushDownPb.EdgeLimit edgeLimit_; + private com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.EdgeLimit, PushDownPb.EdgeLimit.Builder, PushDownPb.EdgeLimitOrBuilder> + edgeLimitBuilder_; + + /** .EdgeLimit edge_limit = 3; */ + public boolean hasEdgeLimit() { + return edgeLimitBuilder_ != null || edgeLimit_ != null; + } + + /** .EdgeLimit edge_limit = 3; */ + public PushDownPb.EdgeLimit getEdgeLimit() { + if (edgeLimitBuilder_ == null) { + return edgeLimit_ == null ? PushDownPb.EdgeLimit.getDefaultInstance() : edgeLimit_; + } else { + return edgeLimitBuilder_.getMessage(); + } + } + + /** .EdgeLimit edge_limit = 3; */ + public Builder setEdgeLimit(PushDownPb.EdgeLimit value) { + if (edgeLimitBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + edgeLimit_ = value; + onChanged(); + } else { + edgeLimitBuilder_.setMessage(value); + } + + return this; + } + + /** .EdgeLimit edge_limit = 3; */ + public Builder setEdgeLimit(PushDownPb.EdgeLimit.Builder builderForValue) { + if (edgeLimitBuilder_ == null) { + edgeLimit_ = builderForValue.build(); + onChanged(); + } else { + edgeLimitBuilder_.setMessage(builderForValue.build()); + } + + return this; + } + + /** .EdgeLimit edge_limit = 3; */ + public Builder mergeEdgeLimit(PushDownPb.EdgeLimit value) { + if (edgeLimitBuilder_ == null) { + if (edgeLimit_ != null) { + edgeLimit_ = + PushDownPb.EdgeLimit.newBuilder(edgeLimit_).mergeFrom(value).buildPartial(); + } else { + edgeLimit_ = value; + } + onChanged(); + } else { + edgeLimitBuilder_.mergeFrom(value); + } + + return this; + } + + /** .EdgeLimit edge_limit = 3; */ + public Builder clearEdgeLimit() { + if (edgeLimitBuilder_ == null) { + edgeLimit_ = null; + onChanged(); + } else { + edgeLimit_ = null; + edgeLimitBuilder_ = null; + } + + return this; + } + + /** .EdgeLimit edge_limit = 3; */ + public PushDownPb.EdgeLimit.Builder getEdgeLimitBuilder() { + + onChanged(); + return getEdgeLimitFieldBuilder().getBuilder(); + } + + /** .EdgeLimit edge_limit = 3; */ + public PushDownPb.EdgeLimitOrBuilder getEdgeLimitOrBuilder() { + if (edgeLimitBuilder_ != null) { + return edgeLimitBuilder_.getMessageOrBuilder(); + } else { + return edgeLimit_ == null ? PushDownPb.EdgeLimit.getDefaultInstance() : edgeLimit_; + } + } + + /** .EdgeLimit edge_limit = 3; */ + private com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.EdgeLimit, PushDownPb.EdgeLimit.Builder, PushDownPb.EdgeLimitOrBuilder> + getEdgeLimitFieldBuilder() { + if (edgeLimitBuilder_ == null) { + edgeLimitBuilder_ = + new com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.EdgeLimit, + PushDownPb.EdgeLimit.Builder, + PushDownPb.EdgeLimitOrBuilder>(getEdgeLimit(), getParentForChildren(), isClean()); + edgeLimit_ = null; + } + return edgeLimitBuilder_; + } + + private java.util.List sortType_ = java.util.Collections.emptyList(); + + private void ensureSortTypeIsMutable() { + if (!((bitField0_ & 0x00000001) != 0)) { + sortType_ = new java.util.ArrayList(sortType_); + bitField0_ |= 0x00000001; + } + } + + /** repeated .SortType sort_type = 4; */ + public java.util.List getSortTypeList() { + return new com.google.protobuf.Internal.ListAdapter( + sortType_, sortType_converter_); + } + + /** repeated .SortType sort_type = 4; */ + public int getSortTypeCount() { + return sortType_.size(); + } + + /** repeated .SortType sort_type = 4; */ + public PushDownPb.SortType getSortType(int index) { + return sortType_converter_.convert(sortType_.get(index)); + } + + /** repeated .SortType sort_type = 4; */ + public Builder setSortType(int index, PushDownPb.SortType value) { + if (value == null) { + throw new NullPointerException(); + } + ensureSortTypeIsMutable(); + sortType_.set(index, value.getNumber()); + onChanged(); + return this; + } + + /** repeated .SortType sort_type = 4; */ + public Builder addSortType(PushDownPb.SortType value) { + if (value == null) { + throw new NullPointerException(); + } + ensureSortTypeIsMutable(); + sortType_.add(value.getNumber()); + onChanged(); + return this; + } + + /** repeated .SortType sort_type = 4; */ + public Builder addAllSortType(java.lang.Iterable values) { + ensureSortTypeIsMutable(); + for (PushDownPb.SortType value : values) { + sortType_.add(value.getNumber()); + } + onChanged(); + return this; + } + + /** repeated .SortType sort_type = 4; */ + public Builder clearSortType() { + sortType_ = java.util.Collections.emptyList(); + bitField0_ = (bitField0_ & ~0x00000001); + onChanged(); + return this; + } + + /** repeated .SortType sort_type = 4; */ + public java.util.List getSortTypeValueList() { + return java.util.Collections.unmodifiableList(sortType_); + } + + /** repeated .SortType sort_type = 4; */ + public int getSortTypeValue(int index) { + return sortType_.get(index); + } + + /** repeated .SortType sort_type = 4; */ + public Builder setSortTypeValue(int index, int value) { + ensureSortTypeIsMutable(); + sortType_.set(index, value); + onChanged(); + return this; + } + + /** repeated .SortType sort_type = 4; */ + public Builder addSortTypeValue(int value) { + ensureSortTypeIsMutable(); + sortType_.add(value); + onChanged(); + return this; + } + + /** repeated .SortType sort_type = 4; */ + public Builder addAllSortTypeValue(java.lang.Iterable values) { + ensureSortTypeIsMutable(); + for (int value : values) { + sortType_.add(value); + } + onChanged(); + return this; + } + + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + // @@protoc_insertion_point(builder_scope:PushDown) + } - public static PushDownPb.PushDown parseDelimitedFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input, extensionRegistry); - } + // @@protoc_insertion_point(class_scope:PushDown) + private static final PushDownPb.PushDown DEFAULT_INSTANCE; - public static PushDownPb.PushDown parseFrom( - com.google.protobuf.CodedInputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } + static { + DEFAULT_INSTANCE = new PushDownPb.PushDown(); + } - public static PushDownPb.PushDown parseFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } + public static PushDownPb.PushDown getDefaultInstance() { + return DEFAULT_INSTANCE; + } - @java.lang.Override - public Builder newBuilderForType() { - return newBuilder(); - } + private static final com.google.protobuf.Parser PARSER = + new com.google.protobuf.AbstractParser() { + @java.lang.Override + public PushDown parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new PushDown(input, extensionRegistry); + } + }; - public static Builder newBuilder() { - return DEFAULT_INSTANCE.toBuilder(); - } + public static com.google.protobuf.Parser parser() { + return PARSER; + } - public static Builder newBuilder(PushDownPb.PushDown prototype) { - return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); - } + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } - @java.lang.Override - public Builder toBuilder() { - return this == DEFAULT_INSTANCE - ? new Builder() : new Builder().mergeFrom(this); - } + @java.lang.Override + public PushDownPb.PushDown getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + } - @java.lang.Override - protected Builder newBuilderForType( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - Builder builder = new Builder(parent); - return builder; - } + public interface EdgeLimitOrBuilder + extends + // @@protoc_insertion_point(interface_extends:EdgeLimit) + com.google.protobuf.MessageOrBuilder { - /** - * Protobuf type {@code PushDown} - */ - public static final class Builder extends - com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:PushDown) - PushDownPb.PushDownOrBuilder { - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return PushDownPb.internal_static_PushDown_descriptor; - } + /** uint64 in = 1; */ + long getIn(); - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return PushDownPb.internal_static_PushDown_fieldAccessorTable - .ensureFieldAccessorsInitialized( - PushDownPb.PushDown.class, PushDownPb.PushDown.Builder.class); - } + /** uint64 out = 2; */ + long getOut(); - // Construct using PushDownPb.PushDown.newBuilder() - private Builder() { - maybeForceBuilderInitialization(); - } + /** bool is_single = 3; */ + boolean getIsSingle(); + } - private Builder( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - super(parent); - maybeForceBuilderInitialization(); - } + /** Protobuf type {@code EdgeLimit} */ + public static final class EdgeLimit extends com.google.protobuf.GeneratedMessageV3 + implements + // @@protoc_insertion_point(message_implements:EdgeLimit) + EdgeLimitOrBuilder { + private static final long serialVersionUID = 0L; - private void maybeForceBuilderInitialization() { - } + // Use EdgeLimit.newBuilder() to construct. + private EdgeLimit(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } - @java.lang.Override - public Builder clear() { - super.clear(); - if (edgeLimitBuilder_ == null) { - edgeLimit_ = null; - } else { - edgeLimit_ = null; - edgeLimitBuilder_ = null; - } - sortType_ = java.util.Collections.emptyList(); - bitField0_ = (bitField0_ & ~0x00000001); - filterCase_ = 0; - filter_ = null; - return this; - } + private EdgeLimit() {} - @java.lang.Override - public com.google.protobuf.Descriptors.Descriptor - getDescriptorForType() { - return PushDownPb.internal_static_PushDown_descriptor; - } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance(UnusedPrivateParameter unused) { + return new EdgeLimit(); + } - @java.lang.Override - public PushDownPb.PushDown getDefaultInstanceForType() { - return PushDownPb.PushDown.getDefaultInstance(); - } + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet getUnknownFields() { + return this.unknownFields; + } - @java.lang.Override - public PushDownPb.PushDown build() { - PushDownPb.PushDown result = buildPartial(); - if (!result.isInitialized()) { - throw newUninitializedMessageException(result); - } - return result; - } + private EdgeLimit( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 8: + { + in_ = input.readUInt64(); + break; + } + case 16: + { + out_ = input.readUInt64(); + break; + } + case 24: + { + isSingle_ = input.readBool(); + break; + } + default: + { + if (!parseUnknownField(input, unknownFields, extensionRegistry, tag)) { + done = true; + } + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e).setUnfinishedMessage(this); + } finally { + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } - @java.lang.Override - public PushDownPb.PushDown buildPartial() { - PushDownPb.PushDown result = new PushDownPb.PushDown(this); - if (filterCase_ == 1) { - if (filterNodeBuilder_ == null) { - result.filter_ = filter_; - } else { - result.filter_ = filterNodeBuilder_.build(); - } - } - if (filterCase_ == 2) { - if (filterNodesBuilder_ == null) { - result.filter_ = filter_; - } else { - result.filter_ = filterNodesBuilder_.build(); - } - } - if (edgeLimitBuilder_ == null) { - result.edgeLimit_ = edgeLimit_; - } else { - result.edgeLimit_ = edgeLimitBuilder_.build(); - } - if (((bitField0_ & 0x00000001) != 0)) { - sortType_ = java.util.Collections.unmodifiableList(sortType_); - bitField0_ = (bitField0_ & ~0x00000001); - } - result.sortType_ = sortType_; - result.filterCase_ = filterCase_; - onBuilt(); - return result; - } + public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { + return PushDownPb.internal_static_EdgeLimit_descriptor; + } - @java.lang.Override - public Builder clone() { - return super.clone(); - } + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return PushDownPb.internal_static_EdgeLimit_fieldAccessorTable + .ensureFieldAccessorsInitialized( + PushDownPb.EdgeLimit.class, PushDownPb.EdgeLimit.Builder.class); + } - @java.lang.Override - public Builder setField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.setField(field, value); - } + public static final int IN_FIELD_NUMBER = 1; + private long in_; - @java.lang.Override - public Builder clearField( - com.google.protobuf.Descriptors.FieldDescriptor field) { - return super.clearField(field); - } + /** uint64 in = 1; */ + public long getIn() { + return in_; + } - @java.lang.Override - public Builder clearOneof( - com.google.protobuf.Descriptors.OneofDescriptor oneof) { - return super.clearOneof(oneof); - } + public static final int OUT_FIELD_NUMBER = 2; + private long out_; - @java.lang.Override - public Builder setRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - int index, java.lang.Object value) { - return super.setRepeatedField(field, index, value); - } + /** uint64 out = 2; */ + public long getOut() { + return out_; + } - @java.lang.Override - public Builder addRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.addRepeatedField(field, value); - } + public static final int IS_SINGLE_FIELD_NUMBER = 3; + private boolean isSingle_; - @java.lang.Override - public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof PushDownPb.PushDown) { - return mergeFrom((PushDownPb.PushDown) other); - } else { - super.mergeFrom(other); - return this; - } - } + /** bool is_single = 3; */ + public boolean getIsSingle() { + return isSingle_; + } - public Builder mergeFrom(PushDownPb.PushDown other) { - if (other == PushDownPb.PushDown.getDefaultInstance()) return this; - if (other.hasEdgeLimit()) { - mergeEdgeLimit(other.getEdgeLimit()); - } - if (!other.sortType_.isEmpty()) { - if (sortType_.isEmpty()) { - sortType_ = other.sortType_; - bitField0_ = (bitField0_ & ~0x00000001); - } else { - ensureSortTypeIsMutable(); - sortType_.addAll(other.sortType_); - } - onChanged(); - } - switch (other.getFilterCase()) { - case FILTER_NODE: { - mergeFilterNode(other.getFilterNode()); - break; - } - case FILTER_NODES: { - mergeFilterNodes(other.getFilterNodes()); - break; - } - case FILTER_NOT_SET: { - break; - } - } - this.mergeUnknownFields(other.unknownFields); - onChanged(); - return this; - } + private byte memoizedIsInitialized = -1; - @java.lang.Override - public final boolean isInitialized() { - return true; - } + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; - @java.lang.Override - public Builder mergeFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - PushDownPb.PushDown parsedMessage = null; - try { - parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - parsedMessage = (PushDownPb.PushDown) e.getUnfinishedMessage(); - throw e.unwrapIOException(); - } finally { - if (parsedMessage != null) { - mergeFrom(parsedMessage); - } - } - return this; - } + memoizedIsInitialized = 1; + return true; + } - private int filterCase_ = 0; - private java.lang.Object filter_; + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { + if (in_ != 0L) { + output.writeUInt64(1, in_); + } + if (out_ != 0L) { + output.writeUInt64(2, out_); + } + if (isSingle_ != false) { + output.writeBool(3, isSingle_); + } + unknownFields.writeTo(output); + } - public FilterCase - getFilterCase() { - return FilterCase.forNumber( - filterCase_); - } + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (in_ != 0L) { + size += com.google.protobuf.CodedOutputStream.computeUInt64Size(1, in_); + } + if (out_ != 0L) { + size += com.google.protobuf.CodedOutputStream.computeUInt64Size(2, out_); + } + if (isSingle_ != false) { + size += com.google.protobuf.CodedOutputStream.computeBoolSize(3, isSingle_); + } + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } - public Builder clearFilter() { - filterCase_ = 0; - filter_ = null; - onChanged(); - return this; - } + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof PushDownPb.EdgeLimit)) { + return super.equals(obj); + } + PushDownPb.EdgeLimit other = (PushDownPb.EdgeLimit) obj; + + if (getIn() != other.getIn()) return false; + if (getOut() != other.getOut()) return false; + if (getIsSingle() != other.getIsSingle()) return false; + if (!unknownFields.equals(other.unknownFields)) return false; + return true; + } - private int bitField0_; + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + IN_FIELD_NUMBER; + hash = (53 * hash) + com.google.protobuf.Internal.hashLong(getIn()); + hash = (37 * hash) + OUT_FIELD_NUMBER; + hash = (53 * hash) + com.google.protobuf.Internal.hashLong(getOut()); + hash = (37 * hash) + IS_SINGLE_FIELD_NUMBER; + hash = (53 * hash) + com.google.protobuf.Internal.hashBoolean(getIsSingle()); + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } - private com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.FilterNode, PushDownPb.FilterNode.Builder, PushDownPb.FilterNodeOrBuilder> filterNodeBuilder_; + public static PushDownPb.EdgeLimit parseFrom(java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - /** - * .FilterNode filter_node = 1; - */ - public boolean hasFilterNode() { - return filterCase_ == 1; - } + public static PushDownPb.EdgeLimit parseFrom( + java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - /** - * .FilterNode filter_node = 1; - */ - public PushDownPb.FilterNode getFilterNode() { - if (filterNodeBuilder_ == null) { - if (filterCase_ == 1) { - return (PushDownPb.FilterNode) filter_; - } - return PushDownPb.FilterNode.getDefaultInstance(); - } else { - if (filterCase_ == 1) { - return filterNodeBuilder_.getMessage(); - } - return PushDownPb.FilterNode.getDefaultInstance(); - } - } + public static PushDownPb.EdgeLimit parseFrom(com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - /** - * .FilterNode filter_node = 1; - */ - public Builder setFilterNode(PushDownPb.FilterNode value) { - if (filterNodeBuilder_ == null) { - if (value == null) { - throw new NullPointerException(); - } - filter_ = value; - onChanged(); - } else { - filterNodeBuilder_.setMessage(value); - } - filterCase_ = 1; - return this; - } - - /** - * .FilterNode filter_node = 1; - */ - public Builder setFilterNode( - PushDownPb.FilterNode.Builder builderForValue) { - if (filterNodeBuilder_ == null) { - filter_ = builderForValue.build(); - onChanged(); - } else { - filterNodeBuilder_.setMessage(builderForValue.build()); - } - filterCase_ = 1; - return this; - } - - /** - * .FilterNode filter_node = 1; - */ - public Builder mergeFilterNode(PushDownPb.FilterNode value) { - if (filterNodeBuilder_ == null) { - if (filterCase_ == 1 && - filter_ != PushDownPb.FilterNode.getDefaultInstance()) { - filter_ = PushDownPb.FilterNode.newBuilder((PushDownPb.FilterNode) filter_) - .mergeFrom(value).buildPartial(); - } else { - filter_ = value; - } - onChanged(); - } else { - if (filterCase_ == 1) { - filterNodeBuilder_.mergeFrom(value); - } - filterNodeBuilder_.setMessage(value); - } - filterCase_ = 1; - return this; - } - - /** - * .FilterNode filter_node = 1; - */ - public Builder clearFilterNode() { - if (filterNodeBuilder_ == null) { - if (filterCase_ == 1) { - filterCase_ = 0; - filter_ = null; - onChanged(); - } - } else { - if (filterCase_ == 1) { - filterCase_ = 0; - filter_ = null; - } - filterNodeBuilder_.clear(); - } - return this; - } - - /** - * .FilterNode filter_node = 1; - */ - public PushDownPb.FilterNode.Builder getFilterNodeBuilder() { - return getFilterNodeFieldBuilder().getBuilder(); - } - - /** - * .FilterNode filter_node = 1; - */ - public PushDownPb.FilterNodeOrBuilder getFilterNodeOrBuilder() { - if ((filterCase_ == 1) && (filterNodeBuilder_ != null)) { - return filterNodeBuilder_.getMessageOrBuilder(); - } else { - if (filterCase_ == 1) { - return (PushDownPb.FilterNode) filter_; - } - return PushDownPb.FilterNode.getDefaultInstance(); - } - } - - /** - * .FilterNode filter_node = 1; - */ - private com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.FilterNode, PushDownPb.FilterNode.Builder, PushDownPb.FilterNodeOrBuilder> - getFilterNodeFieldBuilder() { - if (filterNodeBuilder_ == null) { - if (!(filterCase_ == 1)) { - filter_ = PushDownPb.FilterNode.getDefaultInstance(); - } - filterNodeBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.FilterNode, PushDownPb.FilterNode.Builder, PushDownPb.FilterNodeOrBuilder>( - (PushDownPb.FilterNode) filter_, - getParentForChildren(), - isClean()); - filter_ = null; - } - filterCase_ = 1; - onChanged(); - return filterNodeBuilder_; - } - - private com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.FilterNodes, PushDownPb.FilterNodes.Builder, PushDownPb.FilterNodesOrBuilder> filterNodesBuilder_; - - /** - * .FilterNodes filter_nodes = 2; - */ - public boolean hasFilterNodes() { - return filterCase_ == 2; - } - - /** - * .FilterNodes filter_nodes = 2; - */ - public PushDownPb.FilterNodes getFilterNodes() { - if (filterNodesBuilder_ == null) { - if (filterCase_ == 2) { - return (PushDownPb.FilterNodes) filter_; - } - return PushDownPb.FilterNodes.getDefaultInstance(); - } else { - if (filterCase_ == 2) { - return filterNodesBuilder_.getMessage(); - } - return PushDownPb.FilterNodes.getDefaultInstance(); - } - } - - /** - * .FilterNodes filter_nodes = 2; - */ - public Builder setFilterNodes(PushDownPb.FilterNodes value) { - if (filterNodesBuilder_ == null) { - if (value == null) { - throw new NullPointerException(); - } - filter_ = value; - onChanged(); - } else { - filterNodesBuilder_.setMessage(value); - } - filterCase_ = 2; - return this; - } - - /** - * .FilterNodes filter_nodes = 2; - */ - public Builder setFilterNodes( - PushDownPb.FilterNodes.Builder builderForValue) { - if (filterNodesBuilder_ == null) { - filter_ = builderForValue.build(); - onChanged(); - } else { - filterNodesBuilder_.setMessage(builderForValue.build()); - } - filterCase_ = 2; - return this; - } - - /** - * .FilterNodes filter_nodes = 2; - */ - public Builder mergeFilterNodes(PushDownPb.FilterNodes value) { - if (filterNodesBuilder_ == null) { - if (filterCase_ == 2 && - filter_ != PushDownPb.FilterNodes.getDefaultInstance()) { - filter_ = PushDownPb.FilterNodes.newBuilder((PushDownPb.FilterNodes) filter_) - .mergeFrom(value).buildPartial(); - } else { - filter_ = value; - } - onChanged(); - } else { - if (filterCase_ == 2) { - filterNodesBuilder_.mergeFrom(value); - } - filterNodesBuilder_.setMessage(value); - } - filterCase_ = 2; - return this; - } - - /** - * .FilterNodes filter_nodes = 2; - */ - public Builder clearFilterNodes() { - if (filterNodesBuilder_ == null) { - if (filterCase_ == 2) { - filterCase_ = 0; - filter_ = null; - onChanged(); - } - } else { - if (filterCase_ == 2) { - filterCase_ = 0; - filter_ = null; - } - filterNodesBuilder_.clear(); - } - return this; - } + public static PushDownPb.EdgeLimit parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - /** - * .FilterNodes filter_nodes = 2; - */ - public PushDownPb.FilterNodes.Builder getFilterNodesBuilder() { - return getFilterNodesFieldBuilder().getBuilder(); - } + public static PushDownPb.EdgeLimit parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - /** - * .FilterNodes filter_nodes = 2; - */ - public PushDownPb.FilterNodesOrBuilder getFilterNodesOrBuilder() { - if ((filterCase_ == 2) && (filterNodesBuilder_ != null)) { - return filterNodesBuilder_.getMessageOrBuilder(); - } else { - if (filterCase_ == 2) { - return (PushDownPb.FilterNodes) filter_; - } - return PushDownPb.FilterNodes.getDefaultInstance(); - } - } + public static PushDownPb.EdgeLimit parseFrom( + byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - /** - * .FilterNodes filter_nodes = 2; - */ - private com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.FilterNodes, PushDownPb.FilterNodes.Builder, PushDownPb.FilterNodesOrBuilder> - getFilterNodesFieldBuilder() { - if (filterNodesBuilder_ == null) { - if (!(filterCase_ == 2)) { - filter_ = PushDownPb.FilterNodes.getDefaultInstance(); - } - filterNodesBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.FilterNodes, PushDownPb.FilterNodes.Builder, PushDownPb.FilterNodesOrBuilder>( - (PushDownPb.FilterNodes) filter_, - getParentForChildren(), - isClean()); - filter_ = null; - } - filterCase_ = 2; - onChanged(); - return filterNodesBuilder_; - } + public static PushDownPb.EdgeLimit parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException(PARSER, input); + } - private PushDownPb.EdgeLimit edgeLimit_; - private com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.EdgeLimit, PushDownPb.EdgeLimit.Builder, PushDownPb.EdgeLimitOrBuilder> edgeLimitBuilder_; + public static PushDownPb.EdgeLimit parseFrom( + java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException( + PARSER, input, extensionRegistry); + } - /** - * .EdgeLimit edge_limit = 3; - */ - public boolean hasEdgeLimit() { - return edgeLimitBuilder_ != null || edgeLimit_ != null; - } + public static PushDownPb.EdgeLimit parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseDelimitedWithIOException(PARSER, input); + } - /** - * .EdgeLimit edge_limit = 3; - */ - public PushDownPb.EdgeLimit getEdgeLimit() { - if (edgeLimitBuilder_ == null) { - return edgeLimit_ == null ? PushDownPb.EdgeLimit.getDefaultInstance() : edgeLimit_; - } else { - return edgeLimitBuilder_.getMessage(); - } - } + public static PushDownPb.EdgeLimit parseDelimitedFrom( + java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseDelimitedWithIOException( + PARSER, input, extensionRegistry); + } - /** - * .EdgeLimit edge_limit = 3; - */ - public Builder setEdgeLimit(PushDownPb.EdgeLimit value) { - if (edgeLimitBuilder_ == null) { - if (value == null) { - throw new NullPointerException(); - } - edgeLimit_ = value; - onChanged(); - } else { - edgeLimitBuilder_.setMessage(value); - } + public static PushDownPb.EdgeLimit parseFrom(com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException(PARSER, input); + } - return this; - } + public static PushDownPb.EdgeLimit parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException( + PARSER, input, extensionRegistry); + } - /** - * .EdgeLimit edge_limit = 3; - */ - public Builder setEdgeLimit( - PushDownPb.EdgeLimit.Builder builderForValue) { - if (edgeLimitBuilder_ == null) { - edgeLimit_ = builderForValue.build(); - onChanged(); - } else { - edgeLimitBuilder_.setMessage(builderForValue.build()); - } + @java.lang.Override + public Builder newBuilderForType() { + return newBuilder(); + } - return this; - } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } - /** - * .EdgeLimit edge_limit = 3; - */ - public Builder mergeEdgeLimit(PushDownPb.EdgeLimit value) { - if (edgeLimitBuilder_ == null) { - if (edgeLimit_ != null) { - edgeLimit_ = - PushDownPb.EdgeLimit.newBuilder(edgeLimit_).mergeFrom(value).buildPartial(); - } else { - edgeLimit_ = value; - } - onChanged(); - } else { - edgeLimitBuilder_.mergeFrom(value); - } + public static Builder newBuilder(PushDownPb.EdgeLimit prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } - return this; - } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE ? new Builder() : new Builder().mergeFrom(this); + } - /** - * .EdgeLimit edge_limit = 3; - */ - public Builder clearEdgeLimit() { - if (edgeLimitBuilder_ == null) { - edgeLimit_ = null; - onChanged(); - } else { - edgeLimit_ = null; - edgeLimitBuilder_ = null; - } + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } - return this; - } + /** Protobuf type {@code EdgeLimit} */ + public static final class Builder + extends com.google.protobuf.GeneratedMessageV3.Builder + implements + // @@protoc_insertion_point(builder_implements:EdgeLimit) + PushDownPb.EdgeLimitOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { + return PushDownPb.internal_static_EdgeLimit_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return PushDownPb.internal_static_EdgeLimit_fieldAccessorTable + .ensureFieldAccessorsInitialized( + PushDownPb.EdgeLimit.class, PushDownPb.EdgeLimit.Builder.class); + } + + // Construct using PushDownPb.EdgeLimit.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder(com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + + private void maybeForceBuilderInitialization() {} + + @java.lang.Override + public Builder clear() { + super.clear(); + in_ = 0L; + + out_ = 0L; + + isSingle_ = false; + + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { + return PushDownPb.internal_static_EdgeLimit_descriptor; + } + + @java.lang.Override + public PushDownPb.EdgeLimit getDefaultInstanceForType() { + return PushDownPb.EdgeLimit.getDefaultInstance(); + } + + @java.lang.Override + public PushDownPb.EdgeLimit build() { + PushDownPb.EdgeLimit result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public PushDownPb.EdgeLimit buildPartial() { + PushDownPb.EdgeLimit result = new PushDownPb.EdgeLimit(this); + result.in_ = in_; + result.out_ = out_; + result.isSingle_ = isSingle_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, java.lang.Object value) { + return super.setField(field, value); + } + + @java.lang.Override + public Builder clearField(com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + + @java.lang.Override + public Builder clearOneof(com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, + java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, java.lang.Object value) { + return super.addRepeatedField(field, value); + } + + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof PushDownPb.EdgeLimit) { + return mergeFrom((PushDownPb.EdgeLimit) other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(PushDownPb.EdgeLimit other) { + if (other == PushDownPb.EdgeLimit.getDefaultInstance()) return this; + if (other.getIn() != 0L) { + setIn(other.getIn()); + } + if (other.getOut() != 0L) { + setOut(other.getOut()); + } + if (other.getIsSingle() != false) { + setIsSingle(other.getIsSingle()); + } + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + PushDownPb.EdgeLimit parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (PushDownPb.EdgeLimit) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + + private long in_; + + /** uint64 in = 1; */ + public long getIn() { + return in_; + } + + /** uint64 in = 1; */ + public Builder setIn(long value) { + + in_ = value; + onChanged(); + return this; + } + + /** uint64 in = 1; */ + public Builder clearIn() { + + in_ = 0L; + onChanged(); + return this; + } + + private long out_; + + /** uint64 out = 2; */ + public long getOut() { + return out_; + } + + /** uint64 out = 2; */ + public Builder setOut(long value) { + + out_ = value; + onChanged(); + return this; + } + + /** uint64 out = 2; */ + public Builder clearOut() { + + out_ = 0L; + onChanged(); + return this; + } + + private boolean isSingle_; + + /** bool is_single = 3; */ + public boolean getIsSingle() { + return isSingle_; + } + + /** bool is_single = 3; */ + public Builder setIsSingle(boolean value) { + + isSingle_ = value; + onChanged(); + return this; + } + + /** bool is_single = 3; */ + public Builder clearIsSingle() { + + isSingle_ = false; + onChanged(); + return this; + } + + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + // @@protoc_insertion_point(builder_scope:EdgeLimit) + } - /** - * .EdgeLimit edge_limit = 3; - */ - public PushDownPb.EdgeLimit.Builder getEdgeLimitBuilder() { + // @@protoc_insertion_point(class_scope:EdgeLimit) + private static final PushDownPb.EdgeLimit DEFAULT_INSTANCE; - onChanged(); - return getEdgeLimitFieldBuilder().getBuilder(); - } + static { + DEFAULT_INSTANCE = new PushDownPb.EdgeLimit(); + } - /** - * .EdgeLimit edge_limit = 3; - */ - public PushDownPb.EdgeLimitOrBuilder getEdgeLimitOrBuilder() { - if (edgeLimitBuilder_ != null) { - return edgeLimitBuilder_.getMessageOrBuilder(); - } else { - return edgeLimit_ == null ? - PushDownPb.EdgeLimit.getDefaultInstance() : edgeLimit_; - } - } + public static PushDownPb.EdgeLimit getDefaultInstance() { + return DEFAULT_INSTANCE; + } - /** - * .EdgeLimit edge_limit = 3; - */ - private com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.EdgeLimit, PushDownPb.EdgeLimit.Builder, PushDownPb.EdgeLimitOrBuilder> - getEdgeLimitFieldBuilder() { - if (edgeLimitBuilder_ == null) { - edgeLimitBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.EdgeLimit, PushDownPb.EdgeLimit.Builder, PushDownPb.EdgeLimitOrBuilder>( - getEdgeLimit(), - getParentForChildren(), - isClean()); - edgeLimit_ = null; - } - return edgeLimitBuilder_; - } + private static final com.google.protobuf.Parser PARSER = + new com.google.protobuf.AbstractParser() { + @java.lang.Override + public EdgeLimit parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new EdgeLimit(input, extensionRegistry); + } + }; - private java.util.List sortType_ = - java.util.Collections.emptyList(); + public static com.google.protobuf.Parser parser() { + return PARSER; + } - private void ensureSortTypeIsMutable() { - if (!((bitField0_ & 0x00000001) != 0)) { - sortType_ = new java.util.ArrayList(sortType_); - bitField0_ |= 0x00000001; - } - } + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } - /** - * repeated .SortType sort_type = 4; - */ - public java.util.List getSortTypeList() { - return new com.google.protobuf.Internal.ListAdapter< - java.lang.Integer, PushDownPb.SortType>(sortType_, sortType_converter_); - } + @java.lang.Override + public PushDownPb.EdgeLimit getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + } - /** - * repeated .SortType sort_type = 4; - */ - public int getSortTypeCount() { - return sortType_.size(); - } + public interface FilterNodesOrBuilder + extends + // @@protoc_insertion_point(interface_extends:FilterNodes) + com.google.protobuf.MessageOrBuilder { - /** - * repeated .SortType sort_type = 4; - */ - public PushDownPb.SortType getSortType(int index) { - return sortType_converter_.convert(sortType_.get(index)); - } + /** repeated bytes keys = 1; */ + java.util.List getKeysList(); - /** - * repeated .SortType sort_type = 4; - */ - public Builder setSortType( - int index, PushDownPb.SortType value) { - if (value == null) { - throw new NullPointerException(); - } - ensureSortTypeIsMutable(); - sortType_.set(index, value.getNumber()); - onChanged(); - return this; - } + /** repeated bytes keys = 1; */ + int getKeysCount(); - /** - * repeated .SortType sort_type = 4; - */ - public Builder addSortType(PushDownPb.SortType value) { - if (value == null) { - throw new NullPointerException(); - } - ensureSortTypeIsMutable(); - sortType_.add(value.getNumber()); - onChanged(); - return this; - } + /** repeated bytes keys = 1; */ + com.google.protobuf.ByteString getKeys(int index); - /** - * repeated .SortType sort_type = 4; - */ - public Builder addAllSortType( - java.lang.Iterable values) { - ensureSortTypeIsMutable(); - for (PushDownPb.SortType value : values) { - sortType_.add(value.getNumber()); - } - onChanged(); - return this; - } + /** repeated .FilterNode filter_nodes = 2; */ + java.util.List getFilterNodesList(); - /** - * repeated .SortType sort_type = 4; - */ - public Builder clearSortType() { - sortType_ = java.util.Collections.emptyList(); - bitField0_ = (bitField0_ & ~0x00000001); - onChanged(); - return this; - } + /** repeated .FilterNode filter_nodes = 2; */ + PushDownPb.FilterNode getFilterNodes(int index); - /** - * repeated .SortType sort_type = 4; - */ - public java.util.List - getSortTypeValueList() { - return java.util.Collections.unmodifiableList(sortType_); - } + /** repeated .FilterNode filter_nodes = 2; */ + int getFilterNodesCount(); - /** - * repeated .SortType sort_type = 4; - */ - public int getSortTypeValue(int index) { - return sortType_.get(index); - } + /** repeated .FilterNode filter_nodes = 2; */ + java.util.List getFilterNodesOrBuilderList(); - /** - * repeated .SortType sort_type = 4; - */ - public Builder setSortTypeValue( - int index, int value) { - ensureSortTypeIsMutable(); - sortType_.set(index, value); - onChanged(); - return this; - } + /** repeated .FilterNode filter_nodes = 2; */ + PushDownPb.FilterNodeOrBuilder getFilterNodesOrBuilder(int index); + } - /** - * repeated .SortType sort_type = 4; - */ - public Builder addSortTypeValue(int value) { - ensureSortTypeIsMutable(); - sortType_.add(value); - onChanged(); - return this; - } + /** Protobuf type {@code FilterNodes} */ + public static final class FilterNodes extends com.google.protobuf.GeneratedMessageV3 + implements + // @@protoc_insertion_point(message_implements:FilterNodes) + FilterNodesOrBuilder { + private static final long serialVersionUID = 0L; - /** - * repeated .SortType sort_type = 4; - */ - public Builder addAllSortTypeValue( - java.lang.Iterable values) { - ensureSortTypeIsMutable(); - for (int value : values) { - sortType_.add(value); - } - onChanged(); - return this; - } + // Use FilterNodes.newBuilder() to construct. + private FilterNodes(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } - @java.lang.Override - public final Builder setUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.setUnknownFields(unknownFields); - } + private FilterNodes() { + keys_ = java.util.Collections.emptyList(); + filterNodes_ = java.util.Collections.emptyList(); + } - @java.lang.Override - public final Builder mergeUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.mergeUnknownFields(unknownFields); - } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance(UnusedPrivateParameter unused) { + return new FilterNodes(); + } + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet getUnknownFields() { + return this.unknownFields; + } - // @@protoc_insertion_point(builder_scope:PushDown) - } + private FilterNodes( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: + { + if (!((mutable_bitField0_ & 0x00000001) != 0)) { + keys_ = new java.util.ArrayList(); + mutable_bitField0_ |= 0x00000001; + } + keys_.add(input.readBytes()); + break; + } + case 18: + { + if (!((mutable_bitField0_ & 0x00000002) != 0)) { + filterNodes_ = new java.util.ArrayList(); + mutable_bitField0_ |= 0x00000002; + } + filterNodes_.add( + input.readMessage(PushDownPb.FilterNode.parser(), extensionRegistry)); + break; + } + default: + { + if (!parseUnknownField(input, unknownFields, extensionRegistry, tag)) { + done = true; + } + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e).setUnfinishedMessage(this); + } finally { + if (((mutable_bitField0_ & 0x00000001) != 0)) { + keys_ = java.util.Collections.unmodifiableList(keys_); // C + } + if (((mutable_bitField0_ & 0x00000002) != 0)) { + filterNodes_ = java.util.Collections.unmodifiableList(filterNodes_); + } + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } - // @@protoc_insertion_point(class_scope:PushDown) - private static final PushDownPb.PushDown DEFAULT_INSTANCE; + public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { + return PushDownPb.internal_static_FilterNodes_descriptor; + } - static { - DEFAULT_INSTANCE = new PushDownPb.PushDown(); - } + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return PushDownPb.internal_static_FilterNodes_fieldAccessorTable + .ensureFieldAccessorsInitialized( + PushDownPb.FilterNodes.class, PushDownPb.FilterNodes.Builder.class); + } - public static PushDownPb.PushDown getDefaultInstance() { - return DEFAULT_INSTANCE; - } + public static final int KEYS_FIELD_NUMBER = 1; + private java.util.List keys_; - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { - @java.lang.Override - public PushDown parsePartialFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return new PushDown(input, extensionRegistry); - } - }; + /** repeated bytes keys = 1; */ + public java.util.List getKeysList() { + return keys_; + } - public static com.google.protobuf.Parser parser() { - return PARSER; - } + /** repeated bytes keys = 1; */ + public int getKeysCount() { + return keys_.size(); + } - @java.lang.Override - public com.google.protobuf.Parser getParserForType() { - return PARSER; - } + /** repeated bytes keys = 1; */ + public com.google.protobuf.ByteString getKeys(int index) { + return keys_.get(index); + } - @java.lang.Override - public PushDownPb.PushDown getDefaultInstanceForType() { - return DEFAULT_INSTANCE; - } + public static final int FILTER_NODES_FIELD_NUMBER = 2; + private java.util.List filterNodes_; + /** repeated .FilterNode filter_nodes = 2; */ + public java.util.List getFilterNodesList() { + return filterNodes_; } - public interface EdgeLimitOrBuilder extends - // @@protoc_insertion_point(interface_extends:EdgeLimit) - com.google.protobuf.MessageOrBuilder { + /** repeated .FilterNode filter_nodes = 2; */ + public java.util.List getFilterNodesOrBuilderList() { + return filterNodes_; + } - /** - * uint64 in = 1; - */ - long getIn(); + /** repeated .FilterNode filter_nodes = 2; */ + public int getFilterNodesCount() { + return filterNodes_.size(); + } - /** - * uint64 out = 2; - */ - long getOut(); + /** repeated .FilterNode filter_nodes = 2; */ + public PushDownPb.FilterNode getFilterNodes(int index) { + return filterNodes_.get(index); + } - /** - * bool is_single = 3; - */ - boolean getIsSingle(); + /** repeated .FilterNode filter_nodes = 2; */ + public PushDownPb.FilterNodeOrBuilder getFilterNodesOrBuilder(int index) { + return filterNodes_.get(index); } - /** - * Protobuf type {@code EdgeLimit} - */ - public static final class EdgeLimit extends - com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:EdgeLimit) - EdgeLimitOrBuilder { - private static final long serialVersionUID = 0L; + private byte memoizedIsInitialized = -1; - // Use EdgeLimit.newBuilder() to construct. - private EdgeLimit(com.google.protobuf.GeneratedMessageV3.Builder builder) { - super(builder); - } + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; - private EdgeLimit() { - } + memoizedIsInitialized = 1; + return true; + } - @java.lang.Override - @SuppressWarnings({"unused"}) - protected java.lang.Object newInstance( - UnusedPrivateParameter unused) { - return new EdgeLimit(); - } + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { + for (int i = 0; i < keys_.size(); i++) { + output.writeBytes(1, keys_.get(i)); + } + for (int i = 0; i < filterNodes_.size(); i++) { + output.writeMessage(2, filterNodes_.get(i)); + } + unknownFields.writeTo(output); + } - @java.lang.Override - public final com.google.protobuf.UnknownFieldSet - getUnknownFields() { - return this.unknownFields; - } + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + { + int dataSize = 0; + for (int i = 0; i < keys_.size(); i++) { + dataSize += com.google.protobuf.CodedOutputStream.computeBytesSizeNoTag(keys_.get(i)); + } + size += dataSize; + size += 1 * getKeysList().size(); + } + for (int i = 0; i < filterNodes_.size(); i++) { + size += com.google.protobuf.CodedOutputStream.computeMessageSize(2, filterNodes_.get(i)); + } + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } - private EdgeLimit( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - this(); - if (extensionRegistry == null) { - throw new java.lang.NullPointerException(); - } - com.google.protobuf.UnknownFieldSet.Builder unknownFields = - com.google.protobuf.UnknownFieldSet.newBuilder(); - try { - boolean done = false; - while (!done) { - int tag = input.readTag(); - switch (tag) { - case 0: - done = true; - break; - case 8: { - - in_ = input.readUInt64(); - break; - } - case 16: { - - out_ = input.readUInt64(); - break; - } - case 24: { - - isSingle_ = input.readBool(); - break; - } - default: { - if (!parseUnknownField( - input, unknownFields, extensionRegistry, tag)) { - done = true; - } - break; - } - } - } - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - throw e.setUnfinishedMessage(this); - } catch (java.io.IOException e) { - throw new com.google.protobuf.InvalidProtocolBufferException( - e).setUnfinishedMessage(this); - } finally { - this.unknownFields = unknownFields.build(); - makeExtensionsImmutable(); - } - } + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof PushDownPb.FilterNodes)) { + return super.equals(obj); + } + PushDownPb.FilterNodes other = (PushDownPb.FilterNodes) obj; + + if (!getKeysList().equals(other.getKeysList())) return false; + if (!getFilterNodesList().equals(other.getFilterNodesList())) return false; + if (!unknownFields.equals(other.unknownFields)) return false; + return true; + } - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return PushDownPb.internal_static_EdgeLimit_descriptor; - } + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + if (getKeysCount() > 0) { + hash = (37 * hash) + KEYS_FIELD_NUMBER; + hash = (53 * hash) + getKeysList().hashCode(); + } + if (getFilterNodesCount() > 0) { + hash = (37 * hash) + FILTER_NODES_FIELD_NUMBER; + hash = (53 * hash) + getFilterNodesList().hashCode(); + } + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return PushDownPb.internal_static_EdgeLimit_fieldAccessorTable - .ensureFieldAccessorsInitialized( - PushDownPb.EdgeLimit.class, PushDownPb.EdgeLimit.Builder.class); - } + public static PushDownPb.FilterNodes parseFrom(java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - public static final int IN_FIELD_NUMBER = 1; - private long in_; + public static PushDownPb.FilterNodes parseFrom( + java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - /** - * uint64 in = 1; - */ - public long getIn() { - return in_; - } + public static PushDownPb.FilterNodes parseFrom(com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - public static final int OUT_FIELD_NUMBER = 2; - private long out_; + public static PushDownPb.FilterNodes parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - /** - * uint64 out = 2; - */ - public long getOut() { - return out_; - } + public static PushDownPb.FilterNodes parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - public static final int IS_SINGLE_FIELD_NUMBER = 3; - private boolean isSingle_; + public static PushDownPb.FilterNodes parseFrom( + byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - /** - * bool is_single = 3; - */ - public boolean getIsSingle() { - return isSingle_; - } + public static PushDownPb.FilterNodes parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException(PARSER, input); + } - private byte memoizedIsInitialized = -1; + public static PushDownPb.FilterNodes parseFrom( + java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException( + PARSER, input, extensionRegistry); + } - @java.lang.Override - public final boolean isInitialized() { - byte isInitialized = memoizedIsInitialized; - if (isInitialized == 1) return true; - if (isInitialized == 0) return false; + public static PushDownPb.FilterNodes parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseDelimitedWithIOException(PARSER, input); + } - memoizedIsInitialized = 1; - return true; - } + public static PushDownPb.FilterNodes parseDelimitedFrom( + java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseDelimitedWithIOException( + PARSER, input, extensionRegistry); + } - @java.lang.Override - public void writeTo(com.google.protobuf.CodedOutputStream output) - throws java.io.IOException { - if (in_ != 0L) { - output.writeUInt64(1, in_); - } - if (out_ != 0L) { - output.writeUInt64(2, out_); - } - if (isSingle_ != false) { - output.writeBool(3, isSingle_); - } - unknownFields.writeTo(output); - } + public static PushDownPb.FilterNodes parseFrom(com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException(PARSER, input); + } - @java.lang.Override - public int getSerializedSize() { - int size = memoizedSize; - if (size != -1) return size; + public static PushDownPb.FilterNodes parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException( + PARSER, input, extensionRegistry); + } - size = 0; - if (in_ != 0L) { - size += com.google.protobuf.CodedOutputStream - .computeUInt64Size(1, in_); - } - if (out_ != 0L) { - size += com.google.protobuf.CodedOutputStream - .computeUInt64Size(2, out_); - } - if (isSingle_ != false) { - size += com.google.protobuf.CodedOutputStream - .computeBoolSize(3, isSingle_); - } - size += unknownFields.getSerializedSize(); - memoizedSize = size; - return size; - } + @java.lang.Override + public Builder newBuilderForType() { + return newBuilder(); + } - @java.lang.Override - public boolean equals(final java.lang.Object obj) { - if (obj == this) { - return true; - } - if (!(obj instanceof PushDownPb.EdgeLimit)) { - return super.equals(obj); - } - PushDownPb.EdgeLimit other = (PushDownPb.EdgeLimit) obj; - - if (getIn() - != other.getIn()) return false; - if (getOut() - != other.getOut()) return false; - if (getIsSingle() - != other.getIsSingle()) return false; - if (!unknownFields.equals(other.unknownFields)) return false; - return true; - } - - @java.lang.Override - public int hashCode() { - if (memoizedHashCode != 0) { - return memoizedHashCode; - } - int hash = 41; - hash = (19 * hash) + getDescriptor().hashCode(); - hash = (37 * hash) + IN_FIELD_NUMBER; - hash = (53 * hash) + com.google.protobuf.Internal.hashLong( - getIn()); - hash = (37 * hash) + OUT_FIELD_NUMBER; - hash = (53 * hash) + com.google.protobuf.Internal.hashLong( - getOut()); - hash = (37 * hash) + IS_SINGLE_FIELD_NUMBER; - hash = (53 * hash) + com.google.protobuf.Internal.hashBoolean( - getIsSingle()); - hash = (29 * hash) + unknownFields.hashCode(); - memoizedHashCode = hash; - return hash; - } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } - public static PushDownPb.EdgeLimit parseFrom( - java.nio.ByteBuffer data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } + public static Builder newBuilder(PushDownPb.FilterNodes prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } - public static PushDownPb.EdgeLimit parseFrom( - java.nio.ByteBuffer data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE ? new Builder() : new Builder().mergeFrom(this); + } - public static PushDownPb.EdgeLimit parseFrom( - com.google.protobuf.ByteString data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } - public static PushDownPb.EdgeLimit parseFrom( - com.google.protobuf.ByteString data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } + /** Protobuf type {@code FilterNodes} */ + public static final class Builder + extends com.google.protobuf.GeneratedMessageV3.Builder + implements + // @@protoc_insertion_point(builder_implements:FilterNodes) + PushDownPb.FilterNodesOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { + return PushDownPb.internal_static_FilterNodes_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return PushDownPb.internal_static_FilterNodes_fieldAccessorTable + .ensureFieldAccessorsInitialized( + PushDownPb.FilterNodes.class, PushDownPb.FilterNodes.Builder.class); + } + + // Construct using PushDownPb.FilterNodes.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder(com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + + private void maybeForceBuilderInitialization() { + if (com.google.protobuf.GeneratedMessageV3.alwaysUseFieldBuilders) { + getFilterNodesFieldBuilder(); + } + } + + @java.lang.Override + public Builder clear() { + super.clear(); + keys_ = java.util.Collections.emptyList(); + bitField0_ = (bitField0_ & ~0x00000001); + if (filterNodesBuilder_ == null) { + filterNodes_ = java.util.Collections.emptyList(); + bitField0_ = (bitField0_ & ~0x00000002); + } else { + filterNodesBuilder_.clear(); + } + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { + return PushDownPb.internal_static_FilterNodes_descriptor; + } + + @java.lang.Override + public PushDownPb.FilterNodes getDefaultInstanceForType() { + return PushDownPb.FilterNodes.getDefaultInstance(); + } + + @java.lang.Override + public PushDownPb.FilterNodes build() { + PushDownPb.FilterNodes result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public PushDownPb.FilterNodes buildPartial() { + PushDownPb.FilterNodes result = new PushDownPb.FilterNodes(this); + if (((bitField0_ & 0x00000001) != 0)) { + keys_ = java.util.Collections.unmodifiableList(keys_); + bitField0_ = (bitField0_ & ~0x00000001); + } + result.keys_ = keys_; + if (filterNodesBuilder_ == null) { + if (((bitField0_ & 0x00000002) != 0)) { + filterNodes_ = java.util.Collections.unmodifiableList(filterNodes_); + bitField0_ = (bitField0_ & ~0x00000002); + } + result.filterNodes_ = filterNodes_; + } else { + result.filterNodes_ = filterNodesBuilder_.build(); + } + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, java.lang.Object value) { + return super.setField(field, value); + } + + @java.lang.Override + public Builder clearField(com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + + @java.lang.Override + public Builder clearOneof(com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, + java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, java.lang.Object value) { + return super.addRepeatedField(field, value); + } + + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof PushDownPb.FilterNodes) { + return mergeFrom((PushDownPb.FilterNodes) other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(PushDownPb.FilterNodes other) { + if (other == PushDownPb.FilterNodes.getDefaultInstance()) return this; + if (!other.keys_.isEmpty()) { + if (keys_.isEmpty()) { + keys_ = other.keys_; + bitField0_ = (bitField0_ & ~0x00000001); + } else { + ensureKeysIsMutable(); + keys_.addAll(other.keys_); + } + onChanged(); + } + if (filterNodesBuilder_ == null) { + if (!other.filterNodes_.isEmpty()) { + if (filterNodes_.isEmpty()) { + filterNodes_ = other.filterNodes_; + bitField0_ = (bitField0_ & ~0x00000002); + } else { + ensureFilterNodesIsMutable(); + filterNodes_.addAll(other.filterNodes_); + } + onChanged(); + } + } else { + if (!other.filterNodes_.isEmpty()) { + if (filterNodesBuilder_.isEmpty()) { + filterNodesBuilder_.dispose(); + filterNodesBuilder_ = null; + filterNodes_ = other.filterNodes_; + bitField0_ = (bitField0_ & ~0x00000002); + filterNodesBuilder_ = + com.google.protobuf.GeneratedMessageV3.alwaysUseFieldBuilders + ? getFilterNodesFieldBuilder() + : null; + } else { + filterNodesBuilder_.addAllMessages(other.filterNodes_); + } + } + } + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + PushDownPb.FilterNodes parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (PushDownPb.FilterNodes) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + + private int bitField0_; + + private java.util.List keys_ = + java.util.Collections.emptyList(); + + private void ensureKeysIsMutable() { + if (!((bitField0_ & 0x00000001) != 0)) { + keys_ = new java.util.ArrayList(keys_); + bitField0_ |= 0x00000001; + } + } + + /** repeated bytes keys = 1; */ + public java.util.List getKeysList() { + return ((bitField0_ & 0x00000001) != 0) + ? java.util.Collections.unmodifiableList(keys_) + : keys_; + } + + /** repeated bytes keys = 1; */ + public int getKeysCount() { + return keys_.size(); + } + + /** repeated bytes keys = 1; */ + public com.google.protobuf.ByteString getKeys(int index) { + return keys_.get(index); + } + + /** repeated bytes keys = 1; */ + public Builder setKeys(int index, com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + ensureKeysIsMutable(); + keys_.set(index, value); + onChanged(); + return this; + } + + /** repeated bytes keys = 1; */ + public Builder addKeys(com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + ensureKeysIsMutable(); + keys_.add(value); + onChanged(); + return this; + } + + /** repeated bytes keys = 1; */ + public Builder addAllKeys( + java.lang.Iterable values) { + ensureKeysIsMutable(); + com.google.protobuf.AbstractMessageLite.Builder.addAll(values, keys_); + onChanged(); + return this; + } + + /** repeated bytes keys = 1; */ + public Builder clearKeys() { + keys_ = java.util.Collections.emptyList(); + bitField0_ = (bitField0_ & ~0x00000001); + onChanged(); + return this; + } + + private java.util.List filterNodes_ = + java.util.Collections.emptyList(); + + private void ensureFilterNodesIsMutable() { + if (!((bitField0_ & 0x00000002) != 0)) { + filterNodes_ = new java.util.ArrayList(filterNodes_); + bitField0_ |= 0x00000002; + } + } + + private com.google.protobuf.RepeatedFieldBuilderV3< + PushDownPb.FilterNode, PushDownPb.FilterNode.Builder, PushDownPb.FilterNodeOrBuilder> + filterNodesBuilder_; + + /** repeated .FilterNode filter_nodes = 2; */ + public java.util.List getFilterNodesList() { + if (filterNodesBuilder_ == null) { + return java.util.Collections.unmodifiableList(filterNodes_); + } else { + return filterNodesBuilder_.getMessageList(); + } + } + + /** repeated .FilterNode filter_nodes = 2; */ + public int getFilterNodesCount() { + if (filterNodesBuilder_ == null) { + return filterNodes_.size(); + } else { + return filterNodesBuilder_.getCount(); + } + } + + /** repeated .FilterNode filter_nodes = 2; */ + public PushDownPb.FilterNode getFilterNodes(int index) { + if (filterNodesBuilder_ == null) { + return filterNodes_.get(index); + } else { + return filterNodesBuilder_.getMessage(index); + } + } + + /** repeated .FilterNode filter_nodes = 2; */ + public Builder setFilterNodes(int index, PushDownPb.FilterNode value) { + if (filterNodesBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureFilterNodesIsMutable(); + filterNodes_.set(index, value); + onChanged(); + } else { + filterNodesBuilder_.setMessage(index, value); + } + return this; + } + + /** repeated .FilterNode filter_nodes = 2; */ + public Builder setFilterNodes(int index, PushDownPb.FilterNode.Builder builderForValue) { + if (filterNodesBuilder_ == null) { + ensureFilterNodesIsMutable(); + filterNodes_.set(index, builderForValue.build()); + onChanged(); + } else { + filterNodesBuilder_.setMessage(index, builderForValue.build()); + } + return this; + } + + /** repeated .FilterNode filter_nodes = 2; */ + public Builder addFilterNodes(PushDownPb.FilterNode value) { + if (filterNodesBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureFilterNodesIsMutable(); + filterNodes_.add(value); + onChanged(); + } else { + filterNodesBuilder_.addMessage(value); + } + return this; + } + + /** repeated .FilterNode filter_nodes = 2; */ + public Builder addFilterNodes(int index, PushDownPb.FilterNode value) { + if (filterNodesBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureFilterNodesIsMutable(); + filterNodes_.add(index, value); + onChanged(); + } else { + filterNodesBuilder_.addMessage(index, value); + } + return this; + } + + /** repeated .FilterNode filter_nodes = 2; */ + public Builder addFilterNodes(PushDownPb.FilterNode.Builder builderForValue) { + if (filterNodesBuilder_ == null) { + ensureFilterNodesIsMutable(); + filterNodes_.add(builderForValue.build()); + onChanged(); + } else { + filterNodesBuilder_.addMessage(builderForValue.build()); + } + return this; + } + + /** repeated .FilterNode filter_nodes = 2; */ + public Builder addFilterNodes(int index, PushDownPb.FilterNode.Builder builderForValue) { + if (filterNodesBuilder_ == null) { + ensureFilterNodesIsMutable(); + filterNodes_.add(index, builderForValue.build()); + onChanged(); + } else { + filterNodesBuilder_.addMessage(index, builderForValue.build()); + } + return this; + } + + /** repeated .FilterNode filter_nodes = 2; */ + public Builder addAllFilterNodes(java.lang.Iterable values) { + if (filterNodesBuilder_ == null) { + ensureFilterNodesIsMutable(); + com.google.protobuf.AbstractMessageLite.Builder.addAll(values, filterNodes_); + onChanged(); + } else { + filterNodesBuilder_.addAllMessages(values); + } + return this; + } + + /** repeated .FilterNode filter_nodes = 2; */ + public Builder clearFilterNodes() { + if (filterNodesBuilder_ == null) { + filterNodes_ = java.util.Collections.emptyList(); + bitField0_ = (bitField0_ & ~0x00000002); + onChanged(); + } else { + filterNodesBuilder_.clear(); + } + return this; + } + + /** repeated .FilterNode filter_nodes = 2; */ + public Builder removeFilterNodes(int index) { + if (filterNodesBuilder_ == null) { + ensureFilterNodesIsMutable(); + filterNodes_.remove(index); + onChanged(); + } else { + filterNodesBuilder_.remove(index); + } + return this; + } + + /** repeated .FilterNode filter_nodes = 2; */ + public PushDownPb.FilterNode.Builder getFilterNodesBuilder(int index) { + return getFilterNodesFieldBuilder().getBuilder(index); + } + + /** repeated .FilterNode filter_nodes = 2; */ + public PushDownPb.FilterNodeOrBuilder getFilterNodesOrBuilder(int index) { + if (filterNodesBuilder_ == null) { + return filterNodes_.get(index); + } else { + return filterNodesBuilder_.getMessageOrBuilder(index); + } + } + + /** repeated .FilterNode filter_nodes = 2; */ + public java.util.List + getFilterNodesOrBuilderList() { + if (filterNodesBuilder_ != null) { + return filterNodesBuilder_.getMessageOrBuilderList(); + } else { + return java.util.Collections.unmodifiableList(filterNodes_); + } + } + + /** repeated .FilterNode filter_nodes = 2; */ + public PushDownPb.FilterNode.Builder addFilterNodesBuilder() { + return getFilterNodesFieldBuilder().addBuilder(PushDownPb.FilterNode.getDefaultInstance()); + } + + /** repeated .FilterNode filter_nodes = 2; */ + public PushDownPb.FilterNode.Builder addFilterNodesBuilder(int index) { + return getFilterNodesFieldBuilder() + .addBuilder(index, PushDownPb.FilterNode.getDefaultInstance()); + } + + /** repeated .FilterNode filter_nodes = 2; */ + public java.util.List getFilterNodesBuilderList() { + return getFilterNodesFieldBuilder().getBuilderList(); + } + + private com.google.protobuf.RepeatedFieldBuilderV3< + PushDownPb.FilterNode, PushDownPb.FilterNode.Builder, PushDownPb.FilterNodeOrBuilder> + getFilterNodesFieldBuilder() { + if (filterNodesBuilder_ == null) { + filterNodesBuilder_ = + new com.google.protobuf.RepeatedFieldBuilderV3< + PushDownPb.FilterNode, + PushDownPb.FilterNode.Builder, + PushDownPb.FilterNodeOrBuilder>( + filterNodes_, + ((bitField0_ & 0x00000002) != 0), + getParentForChildren(), + isClean()); + filterNodes_ = null; + } + return filterNodesBuilder_; + } + + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + // @@protoc_insertion_point(builder_scope:FilterNodes) + } - public static PushDownPb.EdgeLimit parseFrom(byte[] data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } + // @@protoc_insertion_point(class_scope:FilterNodes) + private static final PushDownPb.FilterNodes DEFAULT_INSTANCE; - public static PushDownPb.EdgeLimit parseFrom( - byte[] data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } + static { + DEFAULT_INSTANCE = new PushDownPb.FilterNodes(); + } - public static PushDownPb.EdgeLimit parseFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } + public static PushDownPb.FilterNodes getDefaultInstance() { + return DEFAULT_INSTANCE; + } - public static PushDownPb.EdgeLimit parseFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } + private static final com.google.protobuf.Parser PARSER = + new com.google.protobuf.AbstractParser() { + @java.lang.Override + public FilterNodes parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new FilterNodes(input, extensionRegistry); + } + }; - public static PushDownPb.EdgeLimit parseDelimitedFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input); - } + public static com.google.protobuf.Parser parser() { + return PARSER; + } - public static PushDownPb.EdgeLimit parseDelimitedFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input, extensionRegistry); - } + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } - public static PushDownPb.EdgeLimit parseFrom( - com.google.protobuf.CodedInputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } + @java.lang.Override + public PushDownPb.FilterNodes getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + } - public static PushDownPb.EdgeLimit parseFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } + public interface FilterNodeOrBuilder + extends + // @@protoc_insertion_point(interface_extends:FilterNode) + com.google.protobuf.MessageOrBuilder { - @java.lang.Override - public Builder newBuilderForType() { - return newBuilder(); - } + /** .FilterType filter_type = 1; */ + int getFilterTypeValue(); - public static Builder newBuilder() { - return DEFAULT_INSTANCE.toBuilder(); - } + /** .FilterType filter_type = 1; */ + PushDownPb.FilterType getFilterType(); - public static Builder newBuilder(PushDownPb.EdgeLimit prototype) { - return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); - } + /** repeated .FilterNode filters = 2; */ + java.util.List getFiltersList(); - @java.lang.Override - public Builder toBuilder() { - return this == DEFAULT_INSTANCE - ? new Builder() : new Builder().mergeFrom(this); - } + /** repeated .FilterNode filters = 2; */ + PushDownPb.FilterNode getFilters(int index); - @java.lang.Override - protected Builder newBuilderForType( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - Builder builder = new Builder(parent); - return builder; - } + /** repeated .FilterNode filters = 2; */ + int getFiltersCount(); - /** - * Protobuf type {@code EdgeLimit} - */ - public static final class Builder extends - com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:EdgeLimit) - PushDownPb.EdgeLimitOrBuilder { - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return PushDownPb.internal_static_EdgeLimit_descriptor; - } + /** repeated .FilterNode filters = 2; */ + java.util.List getFiltersOrBuilderList(); - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return PushDownPb.internal_static_EdgeLimit_fieldAccessorTable - .ensureFieldAccessorsInitialized( - PushDownPb.EdgeLimit.class, PushDownPb.EdgeLimit.Builder.class); - } + /** repeated .FilterNode filters = 2; */ + PushDownPb.FilterNodeOrBuilder getFiltersOrBuilder(int index); - // Construct using PushDownPb.EdgeLimit.newBuilder() - private Builder() { - maybeForceBuilderInitialization(); - } + /** .IntList int_content = 3; */ + boolean hasIntContent(); - private Builder( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - super(parent); - maybeForceBuilderInitialization(); - } + /** .IntList int_content = 3; */ + PushDownPb.IntList getIntContent(); - private void maybeForceBuilderInitialization() { - } + /** .IntList int_content = 3; */ + PushDownPb.IntListOrBuilder getIntContentOrBuilder(); - @java.lang.Override - public Builder clear() { - super.clear(); - in_ = 0L; + /** .LongList long_content = 4; */ + boolean hasLongContent(); - out_ = 0L; + /** .LongList long_content = 4; */ + PushDownPb.LongList getLongContent(); - isSingle_ = false; + /** .LongList long_content = 4; */ + PushDownPb.LongListOrBuilder getLongContentOrBuilder(); - return this; - } + /** .StringList str_content = 5; */ + boolean hasStrContent(); - @java.lang.Override - public com.google.protobuf.Descriptors.Descriptor - getDescriptorForType() { - return PushDownPb.internal_static_EdgeLimit_descriptor; - } + /** .StringList str_content = 5; */ + PushDownPb.StringList getStrContent(); - @java.lang.Override - public PushDownPb.EdgeLimit getDefaultInstanceForType() { - return PushDownPb.EdgeLimit.getDefaultInstance(); - } + /** .StringList str_content = 5; */ + PushDownPb.StringListOrBuilder getStrContentOrBuilder(); - @java.lang.Override - public PushDownPb.EdgeLimit build() { - PushDownPb.EdgeLimit result = buildPartial(); - if (!result.isInitialized()) { - throw newUninitializedMessageException(result); - } - return result; - } + /** .BytesList bytes_content = 6; */ + boolean hasBytesContent(); - @java.lang.Override - public PushDownPb.EdgeLimit buildPartial() { - PushDownPb.EdgeLimit result = new PushDownPb.EdgeLimit(this); - result.in_ = in_; - result.out_ = out_; - result.isSingle_ = isSingle_; - onBuilt(); - return result; - } + /** .BytesList bytes_content = 6; */ + PushDownPb.BytesList getBytesContent(); - @java.lang.Override - public Builder clone() { - return super.clone(); - } + /** .BytesList bytes_content = 6; */ + PushDownPb.BytesListOrBuilder getBytesContentOrBuilder(); - @java.lang.Override - public Builder setField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.setField(field, value); - } + public PushDownPb.FilterNode.ContentCase getContentCase(); + } - @java.lang.Override - public Builder clearField( - com.google.protobuf.Descriptors.FieldDescriptor field) { - return super.clearField(field); - } + /** Protobuf type {@code FilterNode} */ + public static final class FilterNode extends com.google.protobuf.GeneratedMessageV3 + implements + // @@protoc_insertion_point(message_implements:FilterNode) + FilterNodeOrBuilder { + private static final long serialVersionUID = 0L; - @java.lang.Override - public Builder clearOneof( - com.google.protobuf.Descriptors.OneofDescriptor oneof) { - return super.clearOneof(oneof); - } + // Use FilterNode.newBuilder() to construct. + private FilterNode(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } - @java.lang.Override - public Builder setRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - int index, java.lang.Object value) { - return super.setRepeatedField(field, index, value); - } + private FilterNode() { + filterType_ = 0; + filters_ = java.util.Collections.emptyList(); + } - @java.lang.Override - public Builder addRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.addRepeatedField(field, value); - } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance(UnusedPrivateParameter unused) { + return new FilterNode(); + } - @java.lang.Override - public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof PushDownPb.EdgeLimit) { - return mergeFrom((PushDownPb.EdgeLimit) other); - } else { - super.mergeFrom(other); - return this; - } - } + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet getUnknownFields() { + return this.unknownFields; + } - public Builder mergeFrom(PushDownPb.EdgeLimit other) { - if (other == PushDownPb.EdgeLimit.getDefaultInstance()) return this; - if (other.getIn() != 0L) { - setIn(other.getIn()); - } - if (other.getOut() != 0L) { - setOut(other.getOut()); + private FilterNode( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 8: + { + int rawValue = input.readEnum(); + + filterType_ = rawValue; + break; + } + case 18: + { + if (!((mutable_bitField0_ & 0x00000001) != 0)) { + filters_ = new java.util.ArrayList(); + mutable_bitField0_ |= 0x00000001; + } + filters_.add(input.readMessage(PushDownPb.FilterNode.parser(), extensionRegistry)); + break; + } + case 26: + { + PushDownPb.IntList.Builder subBuilder = null; + if (contentCase_ == 3) { + subBuilder = ((PushDownPb.IntList) content_).toBuilder(); } - if (other.getIsSingle() != false) { - setIsSingle(other.getIsSingle()); + content_ = input.readMessage(PushDownPb.IntList.parser(), extensionRegistry); + if (subBuilder != null) { + subBuilder.mergeFrom((PushDownPb.IntList) content_); + content_ = subBuilder.buildPartial(); } - this.mergeUnknownFields(other.unknownFields); - onChanged(); - return this; - } - - @java.lang.Override - public final boolean isInitialized() { - return true; - } - - @java.lang.Override - public Builder mergeFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - PushDownPb.EdgeLimit parsedMessage = null; - try { - parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - parsedMessage = (PushDownPb.EdgeLimit) e.getUnfinishedMessage(); - throw e.unwrapIOException(); - } finally { - if (parsedMessage != null) { - mergeFrom(parsedMessage); - } + contentCase_ = 3; + break; + } + case 34: + { + PushDownPb.LongList.Builder subBuilder = null; + if (contentCase_ == 4) { + subBuilder = ((PushDownPb.LongList) content_).toBuilder(); } - return this; - } - - private long in_; - - /** - * uint64 in = 1; - */ - public long getIn() { - return in_; - } - - /** - * uint64 in = 1; - */ - public Builder setIn(long value) { - - in_ = value; - onChanged(); - return this; - } - - /** - * uint64 in = 1; - */ - public Builder clearIn() { - - in_ = 0L; - onChanged(); - return this; - } - - private long out_; - - /** - * uint64 out = 2; - */ - public long getOut() { - return out_; - } - - /** - * uint64 out = 2; - */ - public Builder setOut(long value) { - - out_ = value; - onChanged(); - return this; - } - - /** - * uint64 out = 2; - */ - public Builder clearOut() { - - out_ = 0L; - onChanged(); - return this; - } - - private boolean isSingle_; - - /** - * bool is_single = 3; - */ - public boolean getIsSingle() { - return isSingle_; - } - - /** - * bool is_single = 3; - */ - public Builder setIsSingle(boolean value) { - - isSingle_ = value; - onChanged(); - return this; - } - - /** - * bool is_single = 3; - */ - public Builder clearIsSingle() { - - isSingle_ = false; - onChanged(); - return this; - } - - @java.lang.Override - public final Builder setUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.setUnknownFields(unknownFields); - } - - @java.lang.Override - public final Builder mergeUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.mergeUnknownFields(unknownFields); - } - - - // @@protoc_insertion_point(builder_scope:EdgeLimit) - } - - // @@protoc_insertion_point(class_scope:EdgeLimit) - private static final PushDownPb.EdgeLimit DEFAULT_INSTANCE; - - static { - DEFAULT_INSTANCE = new PushDownPb.EdgeLimit(); - } - - public static PushDownPb.EdgeLimit getDefaultInstance() { - return DEFAULT_INSTANCE; - } - - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { - @java.lang.Override - public EdgeLimit parsePartialFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return new EdgeLimit(input, extensionRegistry); - } - }; - - public static com.google.protobuf.Parser parser() { - return PARSER; - } - - @java.lang.Override - public com.google.protobuf.Parser getParserForType() { - return PARSER; - } - - @java.lang.Override - public PushDownPb.EdgeLimit getDefaultInstanceForType() { - return DEFAULT_INSTANCE; - } - - } - - public interface FilterNodesOrBuilder extends - // @@protoc_insertion_point(interface_extends:FilterNodes) - com.google.protobuf.MessageOrBuilder { - - /** - * repeated bytes keys = 1; - */ - java.util.List getKeysList(); - - /** - * repeated bytes keys = 1; - */ - int getKeysCount(); - - /** - * repeated bytes keys = 1; - */ - com.google.protobuf.ByteString getKeys(int index); - - /** - * repeated .FilterNode filter_nodes = 2; - */ - java.util.List - getFilterNodesList(); - - /** - * repeated .FilterNode filter_nodes = 2; - */ - PushDownPb.FilterNode getFilterNodes(int index); - - /** - * repeated .FilterNode filter_nodes = 2; - */ - int getFilterNodesCount(); - - /** - * repeated .FilterNode filter_nodes = 2; - */ - java.util.List - getFilterNodesOrBuilderList(); - - /** - * repeated .FilterNode filter_nodes = 2; - */ - PushDownPb.FilterNodeOrBuilder getFilterNodesOrBuilder( - int index); - } - - /** - * Protobuf type {@code FilterNodes} - */ - public static final class FilterNodes extends - com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:FilterNodes) - FilterNodesOrBuilder { - private static final long serialVersionUID = 0L; - - // Use FilterNodes.newBuilder() to construct. - private FilterNodes(com.google.protobuf.GeneratedMessageV3.Builder builder) { - super(builder); - } - - private FilterNodes() { - keys_ = java.util.Collections.emptyList(); - filterNodes_ = java.util.Collections.emptyList(); - } - - @java.lang.Override - @SuppressWarnings({"unused"}) - protected java.lang.Object newInstance( - UnusedPrivateParameter unused) { - return new FilterNodes(); - } - - @java.lang.Override - public final com.google.protobuf.UnknownFieldSet - getUnknownFields() { - return this.unknownFields; - } - - private FilterNodes( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - this(); - if (extensionRegistry == null) { - throw new java.lang.NullPointerException(); - } - int mutable_bitField0_ = 0; - com.google.protobuf.UnknownFieldSet.Builder unknownFields = - com.google.protobuf.UnknownFieldSet.newBuilder(); - try { - boolean done = false; - while (!done) { - int tag = input.readTag(); - switch (tag) { - case 0: - done = true; - break; - case 10: { - if (!((mutable_bitField0_ & 0x00000001) != 0)) { - keys_ = new java.util.ArrayList(); - mutable_bitField0_ |= 0x00000001; - } - keys_.add(input.readBytes()); - break; - } - case 18: { - if (!((mutable_bitField0_ & 0x00000002) != 0)) { - filterNodes_ = new java.util.ArrayList(); - mutable_bitField0_ |= 0x00000002; - } - filterNodes_.add( - input.readMessage(PushDownPb.FilterNode.parser(), extensionRegistry)); - break; - } - default: { - if (!parseUnknownField( - input, unknownFields, extensionRegistry, tag)) { - done = true; - } - break; - } - } - } - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - throw e.setUnfinishedMessage(this); - } catch (java.io.IOException e) { - throw new com.google.protobuf.InvalidProtocolBufferException( - e).setUnfinishedMessage(this); - } finally { - if (((mutable_bitField0_ & 0x00000001) != 0)) { - keys_ = java.util.Collections.unmodifiableList(keys_); // C - } - if (((mutable_bitField0_ & 0x00000002) != 0)) { - filterNodes_ = java.util.Collections.unmodifiableList(filterNodes_); - } - this.unknownFields = unknownFields.build(); - makeExtensionsImmutable(); - } - } - - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return PushDownPb.internal_static_FilterNodes_descriptor; - } - - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return PushDownPb.internal_static_FilterNodes_fieldAccessorTable - .ensureFieldAccessorsInitialized( - PushDownPb.FilterNodes.class, PushDownPb.FilterNodes.Builder.class); - } - - public static final int KEYS_FIELD_NUMBER = 1; - private java.util.List keys_; - - /** - * repeated bytes keys = 1; - */ - public java.util.List - getKeysList() { - return keys_; - } - - /** - * repeated bytes keys = 1; - */ - public int getKeysCount() { - return keys_.size(); - } - - /** - * repeated bytes keys = 1; - */ - public com.google.protobuf.ByteString getKeys(int index) { - return keys_.get(index); - } - - public static final int FILTER_NODES_FIELD_NUMBER = 2; - private java.util.List filterNodes_; - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public java.util.List getFilterNodesList() { - return filterNodes_; - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public java.util.List - getFilterNodesOrBuilderList() { - return filterNodes_; - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public int getFilterNodesCount() { - return filterNodes_.size(); - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public PushDownPb.FilterNode getFilterNodes(int index) { - return filterNodes_.get(index); - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public PushDownPb.FilterNodeOrBuilder getFilterNodesOrBuilder( - int index) { - return filterNodes_.get(index); - } - - private byte memoizedIsInitialized = -1; - - @java.lang.Override - public final boolean isInitialized() { - byte isInitialized = memoizedIsInitialized; - if (isInitialized == 1) return true; - if (isInitialized == 0) return false; - - memoizedIsInitialized = 1; - return true; - } - - @java.lang.Override - public void writeTo(com.google.protobuf.CodedOutputStream output) - throws java.io.IOException { - for (int i = 0; i < keys_.size(); i++) { - output.writeBytes(1, keys_.get(i)); - } - for (int i = 0; i < filterNodes_.size(); i++) { - output.writeMessage(2, filterNodes_.get(i)); - } - unknownFields.writeTo(output); - } - - @java.lang.Override - public int getSerializedSize() { - int size = memoizedSize; - if (size != -1) return size; - - size = 0; - { - int dataSize = 0; - for (int i = 0; i < keys_.size(); i++) { - dataSize += com.google.protobuf.CodedOutputStream - .computeBytesSizeNoTag(keys_.get(i)); - } - size += dataSize; - size += 1 * getKeysList().size(); - } - for (int i = 0; i < filterNodes_.size(); i++) { - size += com.google.protobuf.CodedOutputStream - .computeMessageSize(2, filterNodes_.get(i)); - } - size += unknownFields.getSerializedSize(); - memoizedSize = size; - return size; - } - - @java.lang.Override - public boolean equals(final java.lang.Object obj) { - if (obj == this) { - return true; - } - if (!(obj instanceof PushDownPb.FilterNodes)) { - return super.equals(obj); - } - PushDownPb.FilterNodes other = (PushDownPb.FilterNodes) obj; - - if (!getKeysList() - .equals(other.getKeysList())) return false; - if (!getFilterNodesList() - .equals(other.getFilterNodesList())) return false; - if (!unknownFields.equals(other.unknownFields)) return false; - return true; - } - - @java.lang.Override - public int hashCode() { - if (memoizedHashCode != 0) { - return memoizedHashCode; - } - int hash = 41; - hash = (19 * hash) + getDescriptor().hashCode(); - if (getKeysCount() > 0) { - hash = (37 * hash) + KEYS_FIELD_NUMBER; - hash = (53 * hash) + getKeysList().hashCode(); - } - if (getFilterNodesCount() > 0) { - hash = (37 * hash) + FILTER_NODES_FIELD_NUMBER; - hash = (53 * hash) + getFilterNodesList().hashCode(); - } - hash = (29 * hash) + unknownFields.hashCode(); - memoizedHashCode = hash; - return hash; - } - - public static PushDownPb.FilterNodes parseFrom( - java.nio.ByteBuffer data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - - public static PushDownPb.FilterNodes parseFrom( - java.nio.ByteBuffer data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - - public static PushDownPb.FilterNodes parseFrom( - com.google.protobuf.ByteString data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - - public static PushDownPb.FilterNodes parseFrom( - com.google.protobuf.ByteString data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - - public static PushDownPb.FilterNodes parseFrom(byte[] data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - - public static PushDownPb.FilterNodes parseFrom( - byte[] data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - - public static PushDownPb.FilterNodes parseFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } - - public static PushDownPb.FilterNodes parseFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } - - public static PushDownPb.FilterNodes parseDelimitedFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input); - } - - public static PushDownPb.FilterNodes parseDelimitedFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input, extensionRegistry); - } - - public static PushDownPb.FilterNodes parseFrom( - com.google.protobuf.CodedInputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } - - public static PushDownPb.FilterNodes parseFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } - - @java.lang.Override - public Builder newBuilderForType() { - return newBuilder(); - } - - public static Builder newBuilder() { - return DEFAULT_INSTANCE.toBuilder(); - } - - public static Builder newBuilder(PushDownPb.FilterNodes prototype) { - return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); - } - - @java.lang.Override - public Builder toBuilder() { - return this == DEFAULT_INSTANCE - ? new Builder() : new Builder().mergeFrom(this); - } - - @java.lang.Override - protected Builder newBuilderForType( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - Builder builder = new Builder(parent); - return builder; - } - - /** - * Protobuf type {@code FilterNodes} - */ - public static final class Builder extends - com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:FilterNodes) - PushDownPb.FilterNodesOrBuilder { - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return PushDownPb.internal_static_FilterNodes_descriptor; - } - - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return PushDownPb.internal_static_FilterNodes_fieldAccessorTable - .ensureFieldAccessorsInitialized( - PushDownPb.FilterNodes.class, PushDownPb.FilterNodes.Builder.class); - } - - // Construct using PushDownPb.FilterNodes.newBuilder() - private Builder() { - maybeForceBuilderInitialization(); - } - - private Builder( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - super(parent); - maybeForceBuilderInitialization(); - } - - private void maybeForceBuilderInitialization() { - if (com.google.protobuf.GeneratedMessageV3 - .alwaysUseFieldBuilders) { - getFilterNodesFieldBuilder(); - } - } - - @java.lang.Override - public Builder clear() { - super.clear(); - keys_ = java.util.Collections.emptyList(); - bitField0_ = (bitField0_ & ~0x00000001); - if (filterNodesBuilder_ == null) { - filterNodes_ = java.util.Collections.emptyList(); - bitField0_ = (bitField0_ & ~0x00000002); - } else { - filterNodesBuilder_.clear(); - } - return this; - } - - @java.lang.Override - public com.google.protobuf.Descriptors.Descriptor - getDescriptorForType() { - return PushDownPb.internal_static_FilterNodes_descriptor; - } - - @java.lang.Override - public PushDownPb.FilterNodes getDefaultInstanceForType() { - return PushDownPb.FilterNodes.getDefaultInstance(); - } - - @java.lang.Override - public PushDownPb.FilterNodes build() { - PushDownPb.FilterNodes result = buildPartial(); - if (!result.isInitialized()) { - throw newUninitializedMessageException(result); - } - return result; - } - - @java.lang.Override - public PushDownPb.FilterNodes buildPartial() { - PushDownPb.FilterNodes result = new PushDownPb.FilterNodes(this); - if (((bitField0_ & 0x00000001) != 0)) { - keys_ = java.util.Collections.unmodifiableList(keys_); - bitField0_ = (bitField0_ & ~0x00000001); - } - result.keys_ = keys_; - if (filterNodesBuilder_ == null) { - if (((bitField0_ & 0x00000002) != 0)) { - filterNodes_ = java.util.Collections.unmodifiableList(filterNodes_); - bitField0_ = (bitField0_ & ~0x00000002); - } - result.filterNodes_ = filterNodes_; - } else { - result.filterNodes_ = filterNodesBuilder_.build(); - } - onBuilt(); - return result; - } - - @java.lang.Override - public Builder clone() { - return super.clone(); - } - - @java.lang.Override - public Builder setField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.setField(field, value); - } - - @java.lang.Override - public Builder clearField( - com.google.protobuf.Descriptors.FieldDescriptor field) { - return super.clearField(field); - } - - @java.lang.Override - public Builder clearOneof( - com.google.protobuf.Descriptors.OneofDescriptor oneof) { - return super.clearOneof(oneof); - } - - @java.lang.Override - public Builder setRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - int index, java.lang.Object value) { - return super.setRepeatedField(field, index, value); - } - - @java.lang.Override - public Builder addRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.addRepeatedField(field, value); - } - - @java.lang.Override - public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof PushDownPb.FilterNodes) { - return mergeFrom((PushDownPb.FilterNodes) other); - } else { - super.mergeFrom(other); - return this; - } - } - - public Builder mergeFrom(PushDownPb.FilterNodes other) { - if (other == PushDownPb.FilterNodes.getDefaultInstance()) return this; - if (!other.keys_.isEmpty()) { - if (keys_.isEmpty()) { - keys_ = other.keys_; - bitField0_ = (bitField0_ & ~0x00000001); - } else { - ensureKeysIsMutable(); - keys_.addAll(other.keys_); - } - onChanged(); - } - if (filterNodesBuilder_ == null) { - if (!other.filterNodes_.isEmpty()) { - if (filterNodes_.isEmpty()) { - filterNodes_ = other.filterNodes_; - bitField0_ = (bitField0_ & ~0x00000002); - } else { - ensureFilterNodesIsMutable(); - filterNodes_.addAll(other.filterNodes_); - } - onChanged(); - } - } else { - if (!other.filterNodes_.isEmpty()) { - if (filterNodesBuilder_.isEmpty()) { - filterNodesBuilder_.dispose(); - filterNodesBuilder_ = null; - filterNodes_ = other.filterNodes_; - bitField0_ = (bitField0_ & ~0x00000002); - filterNodesBuilder_ = - com.google.protobuf.GeneratedMessageV3.alwaysUseFieldBuilders ? - getFilterNodesFieldBuilder() : null; - } else { - filterNodesBuilder_.addAllMessages(other.filterNodes_); - } - } - } - this.mergeUnknownFields(other.unknownFields); - onChanged(); - return this; - } - - @java.lang.Override - public final boolean isInitialized() { - return true; - } - - @java.lang.Override - public Builder mergeFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - PushDownPb.FilterNodes parsedMessage = null; - try { - parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - parsedMessage = (PushDownPb.FilterNodes) e.getUnfinishedMessage(); - throw e.unwrapIOException(); - } finally { - if (parsedMessage != null) { - mergeFrom(parsedMessage); - } - } - return this; - } - - private int bitField0_; - - private java.util.List keys_ = java.util.Collections.emptyList(); - - private void ensureKeysIsMutable() { - if (!((bitField0_ & 0x00000001) != 0)) { - keys_ = new java.util.ArrayList(keys_); - bitField0_ |= 0x00000001; - } - } - - /** - * repeated bytes keys = 1; - */ - public java.util.List - getKeysList() { - return ((bitField0_ & 0x00000001) != 0) ? - java.util.Collections.unmodifiableList(keys_) : keys_; - } - - /** - * repeated bytes keys = 1; - */ - public int getKeysCount() { - return keys_.size(); - } - - /** - * repeated bytes keys = 1; - */ - public com.google.protobuf.ByteString getKeys(int index) { - return keys_.get(index); - } - - /** - * repeated bytes keys = 1; - */ - public Builder setKeys( - int index, com.google.protobuf.ByteString value) { - if (value == null) { - throw new NullPointerException(); - } - ensureKeysIsMutable(); - keys_.set(index, value); - onChanged(); - return this; - } - - /** - * repeated bytes keys = 1; - */ - public Builder addKeys(com.google.protobuf.ByteString value) { - if (value == null) { - throw new NullPointerException(); - } - ensureKeysIsMutable(); - keys_.add(value); - onChanged(); - return this; - } - - /** - * repeated bytes keys = 1; - */ - public Builder addAllKeys( - java.lang.Iterable values) { - ensureKeysIsMutable(); - com.google.protobuf.AbstractMessageLite.Builder.addAll( - values, keys_); - onChanged(); - return this; - } - - /** - * repeated bytes keys = 1; - */ - public Builder clearKeys() { - keys_ = java.util.Collections.emptyList(); - bitField0_ = (bitField0_ & ~0x00000001); - onChanged(); - return this; - } - - private java.util.List filterNodes_ = - java.util.Collections.emptyList(); - - private void ensureFilterNodesIsMutable() { - if (!((bitField0_ & 0x00000002) != 0)) { - filterNodes_ = new java.util.ArrayList(filterNodes_); - bitField0_ |= 0x00000002; - } - } - - private com.google.protobuf.RepeatedFieldBuilderV3< - PushDownPb.FilterNode, PushDownPb.FilterNode.Builder, PushDownPb.FilterNodeOrBuilder> filterNodesBuilder_; - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public java.util.List getFilterNodesList() { - if (filterNodesBuilder_ == null) { - return java.util.Collections.unmodifiableList(filterNodes_); - } else { - return filterNodesBuilder_.getMessageList(); - } - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public int getFilterNodesCount() { - if (filterNodesBuilder_ == null) { - return filterNodes_.size(); - } else { - return filterNodesBuilder_.getCount(); - } - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public PushDownPb.FilterNode getFilterNodes(int index) { - if (filterNodesBuilder_ == null) { - return filterNodes_.get(index); - } else { - return filterNodesBuilder_.getMessage(index); - } - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public Builder setFilterNodes( - int index, PushDownPb.FilterNode value) { - if (filterNodesBuilder_ == null) { - if (value == null) { - throw new NullPointerException(); - } - ensureFilterNodesIsMutable(); - filterNodes_.set(index, value); - onChanged(); - } else { - filterNodesBuilder_.setMessage(index, value); - } - return this; - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public Builder setFilterNodes( - int index, PushDownPb.FilterNode.Builder builderForValue) { - if (filterNodesBuilder_ == null) { - ensureFilterNodesIsMutable(); - filterNodes_.set(index, builderForValue.build()); - onChanged(); - } else { - filterNodesBuilder_.setMessage(index, builderForValue.build()); - } - return this; - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public Builder addFilterNodes(PushDownPb.FilterNode value) { - if (filterNodesBuilder_ == null) { - if (value == null) { - throw new NullPointerException(); - } - ensureFilterNodesIsMutable(); - filterNodes_.add(value); - onChanged(); - } else { - filterNodesBuilder_.addMessage(value); - } - return this; - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public Builder addFilterNodes( - int index, PushDownPb.FilterNode value) { - if (filterNodesBuilder_ == null) { - if (value == null) { - throw new NullPointerException(); - } - ensureFilterNodesIsMutable(); - filterNodes_.add(index, value); - onChanged(); - } else { - filterNodesBuilder_.addMessage(index, value); - } - return this; - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public Builder addFilterNodes( - PushDownPb.FilterNode.Builder builderForValue) { - if (filterNodesBuilder_ == null) { - ensureFilterNodesIsMutable(); - filterNodes_.add(builderForValue.build()); - onChanged(); - } else { - filterNodesBuilder_.addMessage(builderForValue.build()); - } - return this; - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public Builder addFilterNodes( - int index, PushDownPb.FilterNode.Builder builderForValue) { - if (filterNodesBuilder_ == null) { - ensureFilterNodesIsMutable(); - filterNodes_.add(index, builderForValue.build()); - onChanged(); - } else { - filterNodesBuilder_.addMessage(index, builderForValue.build()); - } - return this; - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public Builder addAllFilterNodes( - java.lang.Iterable values) { - if (filterNodesBuilder_ == null) { - ensureFilterNodesIsMutable(); - com.google.protobuf.AbstractMessageLite.Builder.addAll( - values, filterNodes_); - onChanged(); - } else { - filterNodesBuilder_.addAllMessages(values); - } - return this; - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public Builder clearFilterNodes() { - if (filterNodesBuilder_ == null) { - filterNodes_ = java.util.Collections.emptyList(); - bitField0_ = (bitField0_ & ~0x00000002); - onChanged(); - } else { - filterNodesBuilder_.clear(); - } - return this; - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public Builder removeFilterNodes(int index) { - if (filterNodesBuilder_ == null) { - ensureFilterNodesIsMutable(); - filterNodes_.remove(index); - onChanged(); - } else { - filterNodesBuilder_.remove(index); - } - return this; - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public PushDownPb.FilterNode.Builder getFilterNodesBuilder( - int index) { - return getFilterNodesFieldBuilder().getBuilder(index); - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public PushDownPb.FilterNodeOrBuilder getFilterNodesOrBuilder( - int index) { - if (filterNodesBuilder_ == null) { - return filterNodes_.get(index); - } else { - return filterNodesBuilder_.getMessageOrBuilder(index); - } - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public java.util.List - getFilterNodesOrBuilderList() { - if (filterNodesBuilder_ != null) { - return filterNodesBuilder_.getMessageOrBuilderList(); - } else { - return java.util.Collections.unmodifiableList(filterNodes_); - } - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public PushDownPb.FilterNode.Builder addFilterNodesBuilder() { - return getFilterNodesFieldBuilder().addBuilder( - PushDownPb.FilterNode.getDefaultInstance()); - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public PushDownPb.FilterNode.Builder addFilterNodesBuilder( - int index) { - return getFilterNodesFieldBuilder().addBuilder( - index, PushDownPb.FilterNode.getDefaultInstance()); - } - - /** - * repeated .FilterNode filter_nodes = 2; - */ - public java.util.List - getFilterNodesBuilderList() { - return getFilterNodesFieldBuilder().getBuilderList(); - } - - private com.google.protobuf.RepeatedFieldBuilderV3< - PushDownPb.FilterNode, PushDownPb.FilterNode.Builder, PushDownPb.FilterNodeOrBuilder> - getFilterNodesFieldBuilder() { - if (filterNodesBuilder_ == null) { - filterNodesBuilder_ = new com.google.protobuf.RepeatedFieldBuilderV3< - PushDownPb.FilterNode, PushDownPb.FilterNode.Builder, PushDownPb.FilterNodeOrBuilder>( - filterNodes_, - ((bitField0_ & 0x00000002) != 0), - getParentForChildren(), - isClean()); - filterNodes_ = null; - } - return filterNodesBuilder_; - } - - @java.lang.Override - public final Builder setUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.setUnknownFields(unknownFields); - } - - @java.lang.Override - public final Builder mergeUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.mergeUnknownFields(unknownFields); - } - - - // @@protoc_insertion_point(builder_scope:FilterNodes) - } - - // @@protoc_insertion_point(class_scope:FilterNodes) - private static final PushDownPb.FilterNodes DEFAULT_INSTANCE; - - static { - DEFAULT_INSTANCE = new PushDownPb.FilterNodes(); - } - - public static PushDownPb.FilterNodes getDefaultInstance() { - return DEFAULT_INSTANCE; - } - - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { - @java.lang.Override - public FilterNodes parsePartialFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return new FilterNodes(input, extensionRegistry); - } - }; - - public static com.google.protobuf.Parser parser() { - return PARSER; - } - - @java.lang.Override - public com.google.protobuf.Parser getParserForType() { - return PARSER; - } - - @java.lang.Override - public PushDownPb.FilterNodes getDefaultInstanceForType() { - return DEFAULT_INSTANCE; - } - - } - - public interface FilterNodeOrBuilder extends - // @@protoc_insertion_point(interface_extends:FilterNode) - com.google.protobuf.MessageOrBuilder { - - /** - * .FilterType filter_type = 1; - */ - int getFilterTypeValue(); - - /** - * .FilterType filter_type = 1; - */ - PushDownPb.FilterType getFilterType(); - - /** - * repeated .FilterNode filters = 2; - */ - java.util.List - getFiltersList(); - - /** - * repeated .FilterNode filters = 2; - */ - PushDownPb.FilterNode getFilters(int index); - - /** - * repeated .FilterNode filters = 2; - */ - int getFiltersCount(); - - /** - * repeated .FilterNode filters = 2; - */ - java.util.List - getFiltersOrBuilderList(); - - /** - * repeated .FilterNode filters = 2; - */ - PushDownPb.FilterNodeOrBuilder getFiltersOrBuilder( - int index); - - /** - * .IntList int_content = 3; - */ - boolean hasIntContent(); - - /** - * .IntList int_content = 3; - */ - PushDownPb.IntList getIntContent(); - - /** - * .IntList int_content = 3; - */ - PushDownPb.IntListOrBuilder getIntContentOrBuilder(); - - /** - * .LongList long_content = 4; - */ - boolean hasLongContent(); - - /** - * .LongList long_content = 4; - */ - PushDownPb.LongList getLongContent(); - - /** - * .LongList long_content = 4; - */ - PushDownPb.LongListOrBuilder getLongContentOrBuilder(); - - /** - * .StringList str_content = 5; - */ - boolean hasStrContent(); - - /** - * .StringList str_content = 5; - */ - PushDownPb.StringList getStrContent(); - - /** - * .StringList str_content = 5; - */ - PushDownPb.StringListOrBuilder getStrContentOrBuilder(); - - /** - * .BytesList bytes_content = 6; - */ - boolean hasBytesContent(); - - /** - * .BytesList bytes_content = 6; - */ - PushDownPb.BytesList getBytesContent(); - - /** - * .BytesList bytes_content = 6; - */ - PushDownPb.BytesListOrBuilder getBytesContentOrBuilder(); - - public PushDownPb.FilterNode.ContentCase getContentCase(); - } - - /** - * Protobuf type {@code FilterNode} - */ - public static final class FilterNode extends - com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:FilterNode) - FilterNodeOrBuilder { - private static final long serialVersionUID = 0L; - - // Use FilterNode.newBuilder() to construct. - private FilterNode(com.google.protobuf.GeneratedMessageV3.Builder builder) { - super(builder); - } - - private FilterNode() { - filterType_ = 0; - filters_ = java.util.Collections.emptyList(); - } - - @java.lang.Override - @SuppressWarnings({"unused"}) - protected java.lang.Object newInstance( - UnusedPrivateParameter unused) { - return new FilterNode(); - } - - @java.lang.Override - public final com.google.protobuf.UnknownFieldSet - getUnknownFields() { - return this.unknownFields; - } - - private FilterNode( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - this(); - if (extensionRegistry == null) { - throw new java.lang.NullPointerException(); - } - int mutable_bitField0_ = 0; - com.google.protobuf.UnknownFieldSet.Builder unknownFields = - com.google.protobuf.UnknownFieldSet.newBuilder(); - try { - boolean done = false; - while (!done) { - int tag = input.readTag(); - switch (tag) { - case 0: - done = true; - break; - case 8: { - int rawValue = input.readEnum(); - - filterType_ = rawValue; - break; - } - case 18: { - if (!((mutable_bitField0_ & 0x00000001) != 0)) { - filters_ = new java.util.ArrayList(); - mutable_bitField0_ |= 0x00000001; - } - filters_.add( - input.readMessage(PushDownPb.FilterNode.parser(), extensionRegistry)); - break; - } - case 26: { - PushDownPb.IntList.Builder subBuilder = null; - if (contentCase_ == 3) { - subBuilder = ((PushDownPb.IntList) content_).toBuilder(); - } - content_ = - input.readMessage(PushDownPb.IntList.parser(), extensionRegistry); - if (subBuilder != null) { - subBuilder.mergeFrom((PushDownPb.IntList) content_); - content_ = subBuilder.buildPartial(); - } - contentCase_ = 3; - break; - } - case 34: { - PushDownPb.LongList.Builder subBuilder = null; - if (contentCase_ == 4) { - subBuilder = ((PushDownPb.LongList) content_).toBuilder(); - } - content_ = - input.readMessage(PushDownPb.LongList.parser(), extensionRegistry); - if (subBuilder != null) { - subBuilder.mergeFrom((PushDownPb.LongList) content_); - content_ = subBuilder.buildPartial(); - } - contentCase_ = 4; - break; - } - case 42: { - PushDownPb.StringList.Builder subBuilder = null; - if (contentCase_ == 5) { - subBuilder = ((PushDownPb.StringList) content_).toBuilder(); - } - content_ = - input.readMessage(PushDownPb.StringList.parser(), extensionRegistry); - if (subBuilder != null) { - subBuilder.mergeFrom((PushDownPb.StringList) content_); - content_ = subBuilder.buildPartial(); - } - contentCase_ = 5; - break; - } - case 50: { - PushDownPb.BytesList.Builder subBuilder = null; - if (contentCase_ == 6) { - subBuilder = ((PushDownPb.BytesList) content_).toBuilder(); - } - content_ = - input.readMessage(PushDownPb.BytesList.parser(), extensionRegistry); - if (subBuilder != null) { - subBuilder.mergeFrom((PushDownPb.BytesList) content_); - content_ = subBuilder.buildPartial(); - } - contentCase_ = 6; - break; - } - default: { - if (!parseUnknownField( - input, unknownFields, extensionRegistry, tag)) { - done = true; - } - break; - } - } - } - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - throw e.setUnfinishedMessage(this); - } catch (java.io.IOException e) { - throw new com.google.protobuf.InvalidProtocolBufferException( - e).setUnfinishedMessage(this); - } finally { - if (((mutable_bitField0_ & 0x00000001) != 0)) { - filters_ = java.util.Collections.unmodifiableList(filters_); - } - this.unknownFields = unknownFields.build(); - makeExtensionsImmutable(); - } - } - - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return PushDownPb.internal_static_FilterNode_descriptor; - } - - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return PushDownPb.internal_static_FilterNode_fieldAccessorTable - .ensureFieldAccessorsInitialized( - PushDownPb.FilterNode.class, PushDownPb.FilterNode.Builder.class); - } - - private int contentCase_ = 0; - private java.lang.Object content_; - - public enum ContentCase - implements com.google.protobuf.Internal.EnumLite { - INT_CONTENT(3), - LONG_CONTENT(4), - STR_CONTENT(5), - BYTES_CONTENT(6), - CONTENT_NOT_SET(0); - private final int value; - - private ContentCase(int value) { - this.value = value; - } - - /** - * @deprecated Use {@link #forNumber(int)} instead. - */ - @java.lang.Deprecated - public static ContentCase valueOf(int value) { - return forNumber(value); - } - - public static ContentCase forNumber(int value) { - switch (value) { - case 3: - return INT_CONTENT; - case 4: - return LONG_CONTENT; - case 5: - return STR_CONTENT; - case 6: - return BYTES_CONTENT; - case 0: - return CONTENT_NOT_SET; - default: - return null; - } - } - - public int getNumber() { - return this.value; - } - } - - ; - - public ContentCase - getContentCase() { - return ContentCase.forNumber( - contentCase_); - } - - public static final int FILTER_TYPE_FIELD_NUMBER = 1; - private int filterType_; - - /** - * .FilterType filter_type = 1; - */ - public int getFilterTypeValue() { - return filterType_; - } - - /** - * .FilterType filter_type = 1; - */ - public PushDownPb.FilterType getFilterType() { - @SuppressWarnings("deprecation") - PushDownPb.FilterType result = PushDownPb.FilterType.valueOf(filterType_); - return result == null ? PushDownPb.FilterType.UNRECOGNIZED : result; - } - - public static final int FILTERS_FIELD_NUMBER = 2; - private java.util.List filters_; - - /** - * repeated .FilterNode filters = 2; - */ - public java.util.List getFiltersList() { - return filters_; - } - - /** - * repeated .FilterNode filters = 2; - */ - public java.util.List - getFiltersOrBuilderList() { - return filters_; - } - - /** - * repeated .FilterNode filters = 2; - */ - public int getFiltersCount() { - return filters_.size(); - } - - /** - * repeated .FilterNode filters = 2; - */ - public PushDownPb.FilterNode getFilters(int index) { - return filters_.get(index); - } - - /** - * repeated .FilterNode filters = 2; - */ - public PushDownPb.FilterNodeOrBuilder getFiltersOrBuilder( - int index) { - return filters_.get(index); - } - - public static final int INT_CONTENT_FIELD_NUMBER = 3; - - /** - * .IntList int_content = 3; - */ - public boolean hasIntContent() { - return contentCase_ == 3; - } - - /** - * .IntList int_content = 3; - */ - public PushDownPb.IntList getIntContent() { - if (contentCase_ == 3) { - return (PushDownPb.IntList) content_; - } - return PushDownPb.IntList.getDefaultInstance(); - } - - /** - * .IntList int_content = 3; - */ - public PushDownPb.IntListOrBuilder getIntContentOrBuilder() { - if (contentCase_ == 3) { - return (PushDownPb.IntList) content_; - } - return PushDownPb.IntList.getDefaultInstance(); - } - - public static final int LONG_CONTENT_FIELD_NUMBER = 4; - - /** - * .LongList long_content = 4; - */ - public boolean hasLongContent() { - return contentCase_ == 4; - } - - /** - * .LongList long_content = 4; - */ - public PushDownPb.LongList getLongContent() { - if (contentCase_ == 4) { - return (PushDownPb.LongList) content_; - } - return PushDownPb.LongList.getDefaultInstance(); - } - - /** - * .LongList long_content = 4; - */ - public PushDownPb.LongListOrBuilder getLongContentOrBuilder() { - if (contentCase_ == 4) { - return (PushDownPb.LongList) content_; - } - return PushDownPb.LongList.getDefaultInstance(); - } - - public static final int STR_CONTENT_FIELD_NUMBER = 5; - - /** - * .StringList str_content = 5; - */ - public boolean hasStrContent() { - return contentCase_ == 5; - } - - /** - * .StringList str_content = 5; - */ - public PushDownPb.StringList getStrContent() { - if (contentCase_ == 5) { - return (PushDownPb.StringList) content_; - } - return PushDownPb.StringList.getDefaultInstance(); - } - - /** - * .StringList str_content = 5; - */ - public PushDownPb.StringListOrBuilder getStrContentOrBuilder() { - if (contentCase_ == 5) { - return (PushDownPb.StringList) content_; - } - return PushDownPb.StringList.getDefaultInstance(); - } - - public static final int BYTES_CONTENT_FIELD_NUMBER = 6; - - /** - * .BytesList bytes_content = 6; - */ - public boolean hasBytesContent() { - return contentCase_ == 6; - } - - /** - * .BytesList bytes_content = 6; - */ - public PushDownPb.BytesList getBytesContent() { - if (contentCase_ == 6) { - return (PushDownPb.BytesList) content_; - } - return PushDownPb.BytesList.getDefaultInstance(); - } - - /** - * .BytesList bytes_content = 6; - */ - public PushDownPb.BytesListOrBuilder getBytesContentOrBuilder() { - if (contentCase_ == 6) { - return (PushDownPb.BytesList) content_; - } - return PushDownPb.BytesList.getDefaultInstance(); - } - - private byte memoizedIsInitialized = -1; - - @java.lang.Override - public final boolean isInitialized() { - byte isInitialized = memoizedIsInitialized; - if (isInitialized == 1) return true; - if (isInitialized == 0) return false; - - memoizedIsInitialized = 1; - return true; - } - - @java.lang.Override - public void writeTo(com.google.protobuf.CodedOutputStream output) - throws java.io.IOException { - if (filterType_ != PushDownPb.FilterType.EMPTY.getNumber()) { - output.writeEnum(1, filterType_); - } - for (int i = 0; i < filters_.size(); i++) { - output.writeMessage(2, filters_.get(i)); - } - if (contentCase_ == 3) { - output.writeMessage(3, (PushDownPb.IntList) content_); - } - if (contentCase_ == 4) { - output.writeMessage(4, (PushDownPb.LongList) content_); - } - if (contentCase_ == 5) { - output.writeMessage(5, (PushDownPb.StringList) content_); - } - if (contentCase_ == 6) { - output.writeMessage(6, (PushDownPb.BytesList) content_); - } - unknownFields.writeTo(output); - } - - @java.lang.Override - public int getSerializedSize() { - int size = memoizedSize; - if (size != -1) return size; - - size = 0; - if (filterType_ != PushDownPb.FilterType.EMPTY.getNumber()) { - size += com.google.protobuf.CodedOutputStream - .computeEnumSize(1, filterType_); - } - for (int i = 0; i < filters_.size(); i++) { - size += com.google.protobuf.CodedOutputStream - .computeMessageSize(2, filters_.get(i)); - } - if (contentCase_ == 3) { - size += com.google.protobuf.CodedOutputStream - .computeMessageSize(3, (PushDownPb.IntList) content_); - } - if (contentCase_ == 4) { - size += com.google.protobuf.CodedOutputStream - .computeMessageSize(4, (PushDownPb.LongList) content_); - } - if (contentCase_ == 5) { - size += com.google.protobuf.CodedOutputStream - .computeMessageSize(5, (PushDownPb.StringList) content_); - } - if (contentCase_ == 6) { - size += com.google.protobuf.CodedOutputStream - .computeMessageSize(6, (PushDownPb.BytesList) content_); - } - size += unknownFields.getSerializedSize(); - memoizedSize = size; - return size; - } - - @java.lang.Override - public boolean equals(final java.lang.Object obj) { - if (obj == this) { - return true; - } - if (!(obj instanceof PushDownPb.FilterNode)) { - return super.equals(obj); - } - PushDownPb.FilterNode other = (PushDownPb.FilterNode) obj; - - if (filterType_ != other.filterType_) return false; - if (!getFiltersList() - .equals(other.getFiltersList())) return false; - if (!getContentCase().equals(other.getContentCase())) return false; - switch (contentCase_) { - case 3: - if (!getIntContent() - .equals(other.getIntContent())) return false; - break; - case 4: - if (!getLongContent() - .equals(other.getLongContent())) return false; - break; - case 5: - if (!getStrContent() - .equals(other.getStrContent())) return false; - break; - case 6: - if (!getBytesContent() - .equals(other.getBytesContent())) return false; - break; - case 0: - default: - } - if (!unknownFields.equals(other.unknownFields)) return false; - return true; - } - - @java.lang.Override - public int hashCode() { - if (memoizedHashCode != 0) { - return memoizedHashCode; - } - int hash = 41; - hash = (19 * hash) + getDescriptor().hashCode(); - hash = (37 * hash) + FILTER_TYPE_FIELD_NUMBER; - hash = (53 * hash) + filterType_; - if (getFiltersCount() > 0) { - hash = (37 * hash) + FILTERS_FIELD_NUMBER; - hash = (53 * hash) + getFiltersList().hashCode(); - } - switch (contentCase_) { - case 3: - hash = (37 * hash) + INT_CONTENT_FIELD_NUMBER; - hash = (53 * hash) + getIntContent().hashCode(); - break; - case 4: - hash = (37 * hash) + LONG_CONTENT_FIELD_NUMBER; - hash = (53 * hash) + getLongContent().hashCode(); - break; - case 5: - hash = (37 * hash) + STR_CONTENT_FIELD_NUMBER; - hash = (53 * hash) + getStrContent().hashCode(); - break; - case 6: - hash = (37 * hash) + BYTES_CONTENT_FIELD_NUMBER; - hash = (53 * hash) + getBytesContent().hashCode(); - break; - case 0: - default: - } - hash = (29 * hash) + unknownFields.hashCode(); - memoizedHashCode = hash; - return hash; - } - - public static PushDownPb.FilterNode parseFrom( - java.nio.ByteBuffer data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - - public static PushDownPb.FilterNode parseFrom( - java.nio.ByteBuffer data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - - public static PushDownPb.FilterNode parseFrom( - com.google.protobuf.ByteString data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - - public static PushDownPb.FilterNode parseFrom( - com.google.protobuf.ByteString data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - - public static PushDownPb.FilterNode parseFrom(byte[] data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - - public static PushDownPb.FilterNode parseFrom( - byte[] data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - - public static PushDownPb.FilterNode parseFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } - - public static PushDownPb.FilterNode parseFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } - - public static PushDownPb.FilterNode parseDelimitedFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input); - } - - public static PushDownPb.FilterNode parseDelimitedFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input, extensionRegistry); - } - - public static PushDownPb.FilterNode parseFrom( - com.google.protobuf.CodedInputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } - - public static PushDownPb.FilterNode parseFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } - - @java.lang.Override - public Builder newBuilderForType() { - return newBuilder(); - } - - public static Builder newBuilder() { - return DEFAULT_INSTANCE.toBuilder(); - } - - public static Builder newBuilder(PushDownPb.FilterNode prototype) { - return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); - } - - @java.lang.Override - public Builder toBuilder() { - return this == DEFAULT_INSTANCE - ? new Builder() : new Builder().mergeFrom(this); - } - - @java.lang.Override - protected Builder newBuilderForType( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - Builder builder = new Builder(parent); - return builder; - } - - /** - * Protobuf type {@code FilterNode} - */ - public static final class Builder extends - com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:FilterNode) - PushDownPb.FilterNodeOrBuilder { - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return PushDownPb.internal_static_FilterNode_descriptor; - } - - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return PushDownPb.internal_static_FilterNode_fieldAccessorTable - .ensureFieldAccessorsInitialized( - PushDownPb.FilterNode.class, PushDownPb.FilterNode.Builder.class); - } - - // Construct using PushDownPb.FilterNode.newBuilder() - private Builder() { - maybeForceBuilderInitialization(); - } - - private Builder( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - super(parent); - maybeForceBuilderInitialization(); - } - - private void maybeForceBuilderInitialization() { - if (com.google.protobuf.GeneratedMessageV3 - .alwaysUseFieldBuilders) { - getFiltersFieldBuilder(); - } - } - - @java.lang.Override - public Builder clear() { - super.clear(); - filterType_ = 0; - - if (filtersBuilder_ == null) { - filters_ = java.util.Collections.emptyList(); - bitField0_ = (bitField0_ & ~0x00000001); - } else { - filtersBuilder_.clear(); - } - contentCase_ = 0; - content_ = null; - return this; - } - - @java.lang.Override - public com.google.protobuf.Descriptors.Descriptor - getDescriptorForType() { - return PushDownPb.internal_static_FilterNode_descriptor; - } - - @java.lang.Override - public PushDownPb.FilterNode getDefaultInstanceForType() { - return PushDownPb.FilterNode.getDefaultInstance(); - } - - @java.lang.Override - public PushDownPb.FilterNode build() { - PushDownPb.FilterNode result = buildPartial(); - if (!result.isInitialized()) { - throw newUninitializedMessageException(result); - } - return result; - } - - @java.lang.Override - public PushDownPb.FilterNode buildPartial() { - PushDownPb.FilterNode result = new PushDownPb.FilterNode(this); - result.filterType_ = filterType_; - if (filtersBuilder_ == null) { - if (((bitField0_ & 0x00000001) != 0)) { - filters_ = java.util.Collections.unmodifiableList(filters_); - bitField0_ = (bitField0_ & ~0x00000001); - } - result.filters_ = filters_; - } else { - result.filters_ = filtersBuilder_.build(); - } - if (contentCase_ == 3) { - if (intContentBuilder_ == null) { - result.content_ = content_; - } else { - result.content_ = intContentBuilder_.build(); - } - } - if (contentCase_ == 4) { - if (longContentBuilder_ == null) { - result.content_ = content_; - } else { - result.content_ = longContentBuilder_.build(); - } - } - if (contentCase_ == 5) { - if (strContentBuilder_ == null) { - result.content_ = content_; - } else { - result.content_ = strContentBuilder_.build(); - } - } - if (contentCase_ == 6) { - if (bytesContentBuilder_ == null) { - result.content_ = content_; - } else { - result.content_ = bytesContentBuilder_.build(); - } - } - result.contentCase_ = contentCase_; - onBuilt(); - return result; - } - - @java.lang.Override - public Builder clone() { - return super.clone(); - } - - @java.lang.Override - public Builder setField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.setField(field, value); - } - - @java.lang.Override - public Builder clearField( - com.google.protobuf.Descriptors.FieldDescriptor field) { - return super.clearField(field); - } - - @java.lang.Override - public Builder clearOneof( - com.google.protobuf.Descriptors.OneofDescriptor oneof) { - return super.clearOneof(oneof); - } - - @java.lang.Override - public Builder setRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - int index, java.lang.Object value) { - return super.setRepeatedField(field, index, value); - } - - @java.lang.Override - public Builder addRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.addRepeatedField(field, value); - } - - @java.lang.Override - public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof PushDownPb.FilterNode) { - return mergeFrom((PushDownPb.FilterNode) other); - } else { - super.mergeFrom(other); - return this; - } - } - - public Builder mergeFrom(PushDownPb.FilterNode other) { - if (other == PushDownPb.FilterNode.getDefaultInstance()) return this; - if (other.filterType_ != 0) { - setFilterTypeValue(other.getFilterTypeValue()); - } - if (filtersBuilder_ == null) { - if (!other.filters_.isEmpty()) { - if (filters_.isEmpty()) { - filters_ = other.filters_; - bitField0_ = (bitField0_ & ~0x00000001); - } else { - ensureFiltersIsMutable(); - filters_.addAll(other.filters_); - } - onChanged(); - } - } else { - if (!other.filters_.isEmpty()) { - if (filtersBuilder_.isEmpty()) { - filtersBuilder_.dispose(); - filtersBuilder_ = null; - filters_ = other.filters_; - bitField0_ = (bitField0_ & ~0x00000001); - filtersBuilder_ = - com.google.protobuf.GeneratedMessageV3.alwaysUseFieldBuilders ? - getFiltersFieldBuilder() : null; - } else { - filtersBuilder_.addAllMessages(other.filters_); - } - } - } - switch (other.getContentCase()) { - case INT_CONTENT: { - mergeIntContent(other.getIntContent()); - break; - } - case LONG_CONTENT: { - mergeLongContent(other.getLongContent()); - break; - } - case STR_CONTENT: { - mergeStrContent(other.getStrContent()); - break; - } - case BYTES_CONTENT: { - mergeBytesContent(other.getBytesContent()); - break; - } - case CONTENT_NOT_SET: { - break; - } - } - this.mergeUnknownFields(other.unknownFields); - onChanged(); - return this; - } - - @java.lang.Override - public final boolean isInitialized() { - return true; - } - - @java.lang.Override - public Builder mergeFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - PushDownPb.FilterNode parsedMessage = null; - try { - parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - parsedMessage = (PushDownPb.FilterNode) e.getUnfinishedMessage(); - throw e.unwrapIOException(); - } finally { - if (parsedMessage != null) { - mergeFrom(parsedMessage); - } - } - return this; - } - - private int contentCase_ = 0; - private java.lang.Object content_; - - public ContentCase - getContentCase() { - return ContentCase.forNumber( - contentCase_); - } - - public Builder clearContent() { - contentCase_ = 0; - content_ = null; - onChanged(); - return this; - } - - private int bitField0_; - - private int filterType_ = 0; - - /** - * .FilterType filter_type = 1; - */ - public int getFilterTypeValue() { - return filterType_; - } - - /** - * .FilterType filter_type = 1; - */ - public Builder setFilterTypeValue(int value) { - filterType_ = value; - onChanged(); - return this; - } - - /** - * .FilterType filter_type = 1; - */ - public PushDownPb.FilterType getFilterType() { - @SuppressWarnings("deprecation") - PushDownPb.FilterType result = PushDownPb.FilterType.valueOf(filterType_); - return result == null ? PushDownPb.FilterType.UNRECOGNIZED : result; - } - - /** - * .FilterType filter_type = 1; - */ - public Builder setFilterType(PushDownPb.FilterType value) { - if (value == null) { - throw new NullPointerException(); - } - - filterType_ = value.getNumber(); - onChanged(); - return this; - } - - /** - * .FilterType filter_type = 1; - */ - public Builder clearFilterType() { - - filterType_ = 0; - onChanged(); - return this; - } - - private java.util.List filters_ = - java.util.Collections.emptyList(); - - private void ensureFiltersIsMutable() { - if (!((bitField0_ & 0x00000001) != 0)) { - filters_ = new java.util.ArrayList(filters_); - bitField0_ |= 0x00000001; - } - } - - private com.google.protobuf.RepeatedFieldBuilderV3< - PushDownPb.FilterNode, PushDownPb.FilterNode.Builder, PushDownPb.FilterNodeOrBuilder> filtersBuilder_; - - /** - * repeated .FilterNode filters = 2; - */ - public java.util.List getFiltersList() { - if (filtersBuilder_ == null) { - return java.util.Collections.unmodifiableList(filters_); - } else { - return filtersBuilder_.getMessageList(); - } - } - - /** - * repeated .FilterNode filters = 2; - */ - public int getFiltersCount() { - if (filtersBuilder_ == null) { - return filters_.size(); - } else { - return filtersBuilder_.getCount(); - } - } - - /** - * repeated .FilterNode filters = 2; - */ - public PushDownPb.FilterNode getFilters(int index) { - if (filtersBuilder_ == null) { - return filters_.get(index); - } else { - return filtersBuilder_.getMessage(index); - } - } - - /** - * repeated .FilterNode filters = 2; - */ - public Builder setFilters( - int index, PushDownPb.FilterNode value) { - if (filtersBuilder_ == null) { - if (value == null) { - throw new NullPointerException(); - } - ensureFiltersIsMutable(); - filters_.set(index, value); - onChanged(); - } else { - filtersBuilder_.setMessage(index, value); - } - return this; - } - - /** - * repeated .FilterNode filters = 2; - */ - public Builder setFilters( - int index, PushDownPb.FilterNode.Builder builderForValue) { - if (filtersBuilder_ == null) { - ensureFiltersIsMutable(); - filters_.set(index, builderForValue.build()); - onChanged(); - } else { - filtersBuilder_.setMessage(index, builderForValue.build()); - } - return this; - } - - /** - * repeated .FilterNode filters = 2; - */ - public Builder addFilters(PushDownPb.FilterNode value) { - if (filtersBuilder_ == null) { - if (value == null) { - throw new NullPointerException(); - } - ensureFiltersIsMutable(); - filters_.add(value); - onChanged(); - } else { - filtersBuilder_.addMessage(value); - } - return this; - } - - /** - * repeated .FilterNode filters = 2; - */ - public Builder addFilters( - int index, PushDownPb.FilterNode value) { - if (filtersBuilder_ == null) { - if (value == null) { - throw new NullPointerException(); - } - ensureFiltersIsMutable(); - filters_.add(index, value); - onChanged(); - } else { - filtersBuilder_.addMessage(index, value); - } - return this; - } - - /** - * repeated .FilterNode filters = 2; - */ - public Builder addFilters( - PushDownPb.FilterNode.Builder builderForValue) { - if (filtersBuilder_ == null) { - ensureFiltersIsMutable(); - filters_.add(builderForValue.build()); - onChanged(); - } else { - filtersBuilder_.addMessage(builderForValue.build()); - } - return this; - } - - /** - * repeated .FilterNode filters = 2; - */ - public Builder addFilters( - int index, PushDownPb.FilterNode.Builder builderForValue) { - if (filtersBuilder_ == null) { - ensureFiltersIsMutable(); - filters_.add(index, builderForValue.build()); - onChanged(); - } else { - filtersBuilder_.addMessage(index, builderForValue.build()); - } - return this; - } - - /** - * repeated .FilterNode filters = 2; - */ - public Builder addAllFilters( - java.lang.Iterable values) { - if (filtersBuilder_ == null) { - ensureFiltersIsMutable(); - com.google.protobuf.AbstractMessageLite.Builder.addAll( - values, filters_); - onChanged(); - } else { - filtersBuilder_.addAllMessages(values); - } - return this; - } - - /** - * repeated .FilterNode filters = 2; - */ - public Builder clearFilters() { - if (filtersBuilder_ == null) { - filters_ = java.util.Collections.emptyList(); - bitField0_ = (bitField0_ & ~0x00000001); - onChanged(); - } else { - filtersBuilder_.clear(); - } - return this; - } - - /** - * repeated .FilterNode filters = 2; - */ - public Builder removeFilters(int index) { - if (filtersBuilder_ == null) { - ensureFiltersIsMutable(); - filters_.remove(index); - onChanged(); - } else { - filtersBuilder_.remove(index); - } - return this; - } - - /** - * repeated .FilterNode filters = 2; - */ - public PushDownPb.FilterNode.Builder getFiltersBuilder( - int index) { - return getFiltersFieldBuilder().getBuilder(index); - } - - /** - * repeated .FilterNode filters = 2; - */ - public PushDownPb.FilterNodeOrBuilder getFiltersOrBuilder( - int index) { - if (filtersBuilder_ == null) { - return filters_.get(index); - } else { - return filtersBuilder_.getMessageOrBuilder(index); - } - } - - /** - * repeated .FilterNode filters = 2; - */ - public java.util.List - getFiltersOrBuilderList() { - if (filtersBuilder_ != null) { - return filtersBuilder_.getMessageOrBuilderList(); - } else { - return java.util.Collections.unmodifiableList(filters_); - } - } - - /** - * repeated .FilterNode filters = 2; - */ - public PushDownPb.FilterNode.Builder addFiltersBuilder() { - return getFiltersFieldBuilder().addBuilder( - PushDownPb.FilterNode.getDefaultInstance()); - } - - /** - * repeated .FilterNode filters = 2; - */ - public PushDownPb.FilterNode.Builder addFiltersBuilder( - int index) { - return getFiltersFieldBuilder().addBuilder( - index, PushDownPb.FilterNode.getDefaultInstance()); - } - - /** - * repeated .FilterNode filters = 2; - */ - public java.util.List - getFiltersBuilderList() { - return getFiltersFieldBuilder().getBuilderList(); - } - - private com.google.protobuf.RepeatedFieldBuilderV3< - PushDownPb.FilterNode, PushDownPb.FilterNode.Builder, PushDownPb.FilterNodeOrBuilder> - getFiltersFieldBuilder() { - if (filtersBuilder_ == null) { - filtersBuilder_ = new com.google.protobuf.RepeatedFieldBuilderV3< - PushDownPb.FilterNode, PushDownPb.FilterNode.Builder, PushDownPb.FilterNodeOrBuilder>( - filters_, - ((bitField0_ & 0x00000001) != 0), - getParentForChildren(), - isClean()); - filters_ = null; - } - return filtersBuilder_; - } - - private com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.IntList, PushDownPb.IntList.Builder, PushDownPb.IntListOrBuilder> intContentBuilder_; - - /** - * .IntList int_content = 3; - */ - public boolean hasIntContent() { - return contentCase_ == 3; - } - - /** - * .IntList int_content = 3; - */ - public PushDownPb.IntList getIntContent() { - if (intContentBuilder_ == null) { - if (contentCase_ == 3) { - return (PushDownPb.IntList) content_; - } - return PushDownPb.IntList.getDefaultInstance(); - } else { - if (contentCase_ == 3) { - return intContentBuilder_.getMessage(); - } - return PushDownPb.IntList.getDefaultInstance(); - } - } - - /** - * .IntList int_content = 3; - */ - public Builder setIntContent(PushDownPb.IntList value) { - if (intContentBuilder_ == null) { - if (value == null) { - throw new NullPointerException(); - } - content_ = value; - onChanged(); - } else { - intContentBuilder_.setMessage(value); - } - contentCase_ = 3; - return this; - } - - /** - * .IntList int_content = 3; - */ - public Builder setIntContent( - PushDownPb.IntList.Builder builderForValue) { - if (intContentBuilder_ == null) { - content_ = builderForValue.build(); - onChanged(); - } else { - intContentBuilder_.setMessage(builderForValue.build()); - } - contentCase_ = 3; - return this; - } - - /** - * .IntList int_content = 3; - */ - public Builder mergeIntContent(PushDownPb.IntList value) { - if (intContentBuilder_ == null) { - if (contentCase_ == 3 && - content_ != PushDownPb.IntList.getDefaultInstance()) { - content_ = PushDownPb.IntList.newBuilder((PushDownPb.IntList) content_) - .mergeFrom(value).buildPartial(); - } else { - content_ = value; - } - onChanged(); - } else { - if (contentCase_ == 3) { - intContentBuilder_.mergeFrom(value); - } - intContentBuilder_.setMessage(value); - } - contentCase_ = 3; - return this; - } - - /** - * .IntList int_content = 3; - */ - public Builder clearIntContent() { - if (intContentBuilder_ == null) { - if (contentCase_ == 3) { - contentCase_ = 0; - content_ = null; - onChanged(); - } - } else { - if (contentCase_ == 3) { - contentCase_ = 0; - content_ = null; - } - intContentBuilder_.clear(); - } - return this; - } - - /** - * .IntList int_content = 3; - */ - public PushDownPb.IntList.Builder getIntContentBuilder() { - return getIntContentFieldBuilder().getBuilder(); - } - - /** - * .IntList int_content = 3; - */ - public PushDownPb.IntListOrBuilder getIntContentOrBuilder() { - if ((contentCase_ == 3) && (intContentBuilder_ != null)) { - return intContentBuilder_.getMessageOrBuilder(); - } else { - if (contentCase_ == 3) { - return (PushDownPb.IntList) content_; - } - return PushDownPb.IntList.getDefaultInstance(); - } - } - - /** - * .IntList int_content = 3; - */ - private com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.IntList, PushDownPb.IntList.Builder, PushDownPb.IntListOrBuilder> - getIntContentFieldBuilder() { - if (intContentBuilder_ == null) { - if (!(contentCase_ == 3)) { - content_ = PushDownPb.IntList.getDefaultInstance(); - } - intContentBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.IntList, PushDownPb.IntList.Builder, PushDownPb.IntListOrBuilder>( - (PushDownPb.IntList) content_, - getParentForChildren(), - isClean()); - content_ = null; - } - contentCase_ = 3; - onChanged(); - return intContentBuilder_; - } - - private com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.LongList, PushDownPb.LongList.Builder, PushDownPb.LongListOrBuilder> longContentBuilder_; - - /** - * .LongList long_content = 4; - */ - public boolean hasLongContent() { - return contentCase_ == 4; - } - - /** - * .LongList long_content = 4; - */ - public PushDownPb.LongList getLongContent() { - if (longContentBuilder_ == null) { - if (contentCase_ == 4) { - return (PushDownPb.LongList) content_; - } - return PushDownPb.LongList.getDefaultInstance(); - } else { - if (contentCase_ == 4) { - return longContentBuilder_.getMessage(); - } - return PushDownPb.LongList.getDefaultInstance(); - } - } - - /** - * .LongList long_content = 4; - */ - public Builder setLongContent(PushDownPb.LongList value) { - if (longContentBuilder_ == null) { - if (value == null) { - throw new NullPointerException(); - } - content_ = value; - onChanged(); - } else { - longContentBuilder_.setMessage(value); - } - contentCase_ = 4; - return this; - } - - /** - * .LongList long_content = 4; - */ - public Builder setLongContent( - PushDownPb.LongList.Builder builderForValue) { - if (longContentBuilder_ == null) { - content_ = builderForValue.build(); - onChanged(); - } else { - longContentBuilder_.setMessage(builderForValue.build()); - } - contentCase_ = 4; - return this; - } - - /** - * .LongList long_content = 4; - */ - public Builder mergeLongContent(PushDownPb.LongList value) { - if (longContentBuilder_ == null) { - if (contentCase_ == 4 && - content_ != PushDownPb.LongList.getDefaultInstance()) { - content_ = PushDownPb.LongList.newBuilder((PushDownPb.LongList) content_) - .mergeFrom(value).buildPartial(); - } else { - content_ = value; - } - onChanged(); - } else { - if (contentCase_ == 4) { - longContentBuilder_.mergeFrom(value); - } - longContentBuilder_.setMessage(value); - } - contentCase_ = 4; - return this; - } - - /** - * .LongList long_content = 4; - */ - public Builder clearLongContent() { - if (longContentBuilder_ == null) { - if (contentCase_ == 4) { - contentCase_ = 0; - content_ = null; - onChanged(); - } - } else { - if (contentCase_ == 4) { - contentCase_ = 0; - content_ = null; - } - longContentBuilder_.clear(); - } - return this; - } - - /** - * .LongList long_content = 4; - */ - public PushDownPb.LongList.Builder getLongContentBuilder() { - return getLongContentFieldBuilder().getBuilder(); - } - - /** - * .LongList long_content = 4; - */ - public PushDownPb.LongListOrBuilder getLongContentOrBuilder() { - if ((contentCase_ == 4) && (longContentBuilder_ != null)) { - return longContentBuilder_.getMessageOrBuilder(); - } else { - if (contentCase_ == 4) { - return (PushDownPb.LongList) content_; - } - return PushDownPb.LongList.getDefaultInstance(); - } - } - - /** - * .LongList long_content = 4; - */ - private com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.LongList, PushDownPb.LongList.Builder, PushDownPb.LongListOrBuilder> - getLongContentFieldBuilder() { - if (longContentBuilder_ == null) { - if (!(contentCase_ == 4)) { - content_ = PushDownPb.LongList.getDefaultInstance(); - } - longContentBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.LongList, PushDownPb.LongList.Builder, PushDownPb.LongListOrBuilder>( - (PushDownPb.LongList) content_, - getParentForChildren(), - isClean()); - content_ = null; - } - contentCase_ = 4; - onChanged(); - return longContentBuilder_; - } - - private com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.StringList, PushDownPb.StringList.Builder, PushDownPb.StringListOrBuilder> strContentBuilder_; - - /** - * .StringList str_content = 5; - */ - public boolean hasStrContent() { - return contentCase_ == 5; - } - - /** - * .StringList str_content = 5; - */ - public PushDownPb.StringList getStrContent() { - if (strContentBuilder_ == null) { - if (contentCase_ == 5) { - return (PushDownPb.StringList) content_; - } - return PushDownPb.StringList.getDefaultInstance(); - } else { - if (contentCase_ == 5) { - return strContentBuilder_.getMessage(); - } - return PushDownPb.StringList.getDefaultInstance(); - } - } - - /** - * .StringList str_content = 5; - */ - public Builder setStrContent(PushDownPb.StringList value) { - if (strContentBuilder_ == null) { - if (value == null) { - throw new NullPointerException(); - } - content_ = value; - onChanged(); - } else { - strContentBuilder_.setMessage(value); - } - contentCase_ = 5; - return this; - } - - /** - * .StringList str_content = 5; - */ - public Builder setStrContent( - PushDownPb.StringList.Builder builderForValue) { - if (strContentBuilder_ == null) { - content_ = builderForValue.build(); - onChanged(); - } else { - strContentBuilder_.setMessage(builderForValue.build()); - } - contentCase_ = 5; - return this; - } - - /** - * .StringList str_content = 5; - */ - public Builder mergeStrContent(PushDownPb.StringList value) { - if (strContentBuilder_ == null) { - if (contentCase_ == 5 && - content_ != PushDownPb.StringList.getDefaultInstance()) { - content_ = PushDownPb.StringList.newBuilder((PushDownPb.StringList) content_) - .mergeFrom(value).buildPartial(); - } else { - content_ = value; - } - onChanged(); - } else { - if (contentCase_ == 5) { - strContentBuilder_.mergeFrom(value); - } - strContentBuilder_.setMessage(value); - } - contentCase_ = 5; - return this; - } - - /** - * .StringList str_content = 5; - */ - public Builder clearStrContent() { - if (strContentBuilder_ == null) { - if (contentCase_ == 5) { - contentCase_ = 0; - content_ = null; - onChanged(); - } - } else { - if (contentCase_ == 5) { - contentCase_ = 0; - content_ = null; - } - strContentBuilder_.clear(); - } - return this; - } - - /** - * .StringList str_content = 5; - */ - public PushDownPb.StringList.Builder getStrContentBuilder() { - return getStrContentFieldBuilder().getBuilder(); - } - - /** - * .StringList str_content = 5; - */ - public PushDownPb.StringListOrBuilder getStrContentOrBuilder() { - if ((contentCase_ == 5) && (strContentBuilder_ != null)) { - return strContentBuilder_.getMessageOrBuilder(); - } else { - if (contentCase_ == 5) { - return (PushDownPb.StringList) content_; - } - return PushDownPb.StringList.getDefaultInstance(); - } - } - - /** - * .StringList str_content = 5; - */ - private com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.StringList, PushDownPb.StringList.Builder, PushDownPb.StringListOrBuilder> - getStrContentFieldBuilder() { - if (strContentBuilder_ == null) { - if (!(contentCase_ == 5)) { - content_ = PushDownPb.StringList.getDefaultInstance(); - } - strContentBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.StringList, PushDownPb.StringList.Builder, PushDownPb.StringListOrBuilder>( - (PushDownPb.StringList) content_, - getParentForChildren(), - isClean()); - content_ = null; - } - contentCase_ = 5; - onChanged(); - return strContentBuilder_; - } - - private com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.BytesList, PushDownPb.BytesList.Builder, PushDownPb.BytesListOrBuilder> bytesContentBuilder_; - - /** - * .BytesList bytes_content = 6; - */ - public boolean hasBytesContent() { - return contentCase_ == 6; - } - - /** - * .BytesList bytes_content = 6; - */ - public PushDownPb.BytesList getBytesContent() { - if (bytesContentBuilder_ == null) { - if (contentCase_ == 6) { - return (PushDownPb.BytesList) content_; - } - return PushDownPb.BytesList.getDefaultInstance(); - } else { - if (contentCase_ == 6) { - return bytesContentBuilder_.getMessage(); - } - return PushDownPb.BytesList.getDefaultInstance(); - } - } - - /** - * .BytesList bytes_content = 6; - */ - public Builder setBytesContent(PushDownPb.BytesList value) { - if (bytesContentBuilder_ == null) { - if (value == null) { - throw new NullPointerException(); - } - content_ = value; - onChanged(); - } else { - bytesContentBuilder_.setMessage(value); - } - contentCase_ = 6; - return this; - } - - /** - * .BytesList bytes_content = 6; - */ - public Builder setBytesContent( - PushDownPb.BytesList.Builder builderForValue) { - if (bytesContentBuilder_ == null) { - content_ = builderForValue.build(); - onChanged(); - } else { - bytesContentBuilder_.setMessage(builderForValue.build()); - } - contentCase_ = 6; - return this; - } - - /** - * .BytesList bytes_content = 6; - */ - public Builder mergeBytesContent(PushDownPb.BytesList value) { - if (bytesContentBuilder_ == null) { - if (contentCase_ == 6 && - content_ != PushDownPb.BytesList.getDefaultInstance()) { - content_ = PushDownPb.BytesList.newBuilder((PushDownPb.BytesList) content_) - .mergeFrom(value).buildPartial(); - } else { - content_ = value; - } - onChanged(); - } else { - if (contentCase_ == 6) { - bytesContentBuilder_.mergeFrom(value); - } - bytesContentBuilder_.setMessage(value); - } - contentCase_ = 6; - return this; - } - - /** - * .BytesList bytes_content = 6; - */ - public Builder clearBytesContent() { - if (bytesContentBuilder_ == null) { - if (contentCase_ == 6) { - contentCase_ = 0; - content_ = null; - onChanged(); - } - } else { - if (contentCase_ == 6) { - contentCase_ = 0; - content_ = null; - } - bytesContentBuilder_.clear(); - } - return this; - } - - /** - * .BytesList bytes_content = 6; - */ - public PushDownPb.BytesList.Builder getBytesContentBuilder() { - return getBytesContentFieldBuilder().getBuilder(); - } - - /** - * .BytesList bytes_content = 6; - */ - public PushDownPb.BytesListOrBuilder getBytesContentOrBuilder() { - if ((contentCase_ == 6) && (bytesContentBuilder_ != null)) { - return bytesContentBuilder_.getMessageOrBuilder(); - } else { - if (contentCase_ == 6) { - return (PushDownPb.BytesList) content_; - } - return PushDownPb.BytesList.getDefaultInstance(); - } - } - - /** - * .BytesList bytes_content = 6; - */ - private com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.BytesList, PushDownPb.BytesList.Builder, PushDownPb.BytesListOrBuilder> - getBytesContentFieldBuilder() { - if (bytesContentBuilder_ == null) { - if (!(contentCase_ == 6)) { - content_ = PushDownPb.BytesList.getDefaultInstance(); - } - bytesContentBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< - PushDownPb.BytesList, PushDownPb.BytesList.Builder, PushDownPb.BytesListOrBuilder>( - (PushDownPb.BytesList) content_, - getParentForChildren(), - isClean()); - content_ = null; - } - contentCase_ = 6; - onChanged(); - return bytesContentBuilder_; - } - - @java.lang.Override - public final Builder setUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.setUnknownFields(unknownFields); - } - - @java.lang.Override - public final Builder mergeUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.mergeUnknownFields(unknownFields); - } - - - // @@protoc_insertion_point(builder_scope:FilterNode) - } - - // @@protoc_insertion_point(class_scope:FilterNode) - private static final PushDownPb.FilterNode DEFAULT_INSTANCE; - - static { - DEFAULT_INSTANCE = new PushDownPb.FilterNode(); - } - - public static PushDownPb.FilterNode getDefaultInstance() { - return DEFAULT_INSTANCE; - } - - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { - @java.lang.Override - public FilterNode parsePartialFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return new FilterNode(input, extensionRegistry); - } - }; - - public static com.google.protobuf.Parser parser() { - return PARSER; - } - - @java.lang.Override - public com.google.protobuf.Parser getParserForType() { - return PARSER; - } - - @java.lang.Override - public PushDownPb.FilterNode getDefaultInstanceForType() { - return DEFAULT_INSTANCE; - } - - } - - public interface IntListOrBuilder extends - // @@protoc_insertion_point(interface_extends:IntList) - com.google.protobuf.MessageOrBuilder { - - /** - * repeated int32 int = 1; - */ - java.util.List getIntList(); - - /** - * repeated int32 int = 1; - */ - int getIntCount(); - - /** - * repeated int32 int = 1; - */ - int getInt(int index); - } - - /** - * Protobuf type {@code IntList} - */ - public static final class IntList extends - com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:IntList) - IntListOrBuilder { - private static final long serialVersionUID = 0L; - - // Use IntList.newBuilder() to construct. - private IntList(com.google.protobuf.GeneratedMessageV3.Builder builder) { - super(builder); - } - - private IntList() { - int_ = emptyIntList(); - } - - @java.lang.Override - @SuppressWarnings({"unused"}) - protected java.lang.Object newInstance( - UnusedPrivateParameter unused) { - return new IntList(); - } - - @java.lang.Override - public final com.google.protobuf.UnknownFieldSet - getUnknownFields() { - return this.unknownFields; - } - - private IntList( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - this(); - if (extensionRegistry == null) { - throw new java.lang.NullPointerException(); - } - int mutable_bitField0_ = 0; - com.google.protobuf.UnknownFieldSet.Builder unknownFields = - com.google.protobuf.UnknownFieldSet.newBuilder(); - try { - boolean done = false; - while (!done) { - int tag = input.readTag(); - switch (tag) { - case 0: - done = true; - break; - case 8: { - if (!((mutable_bitField0_ & 0x00000001) != 0)) { - int_ = newIntList(); - mutable_bitField0_ |= 0x00000001; - } - int_.addInt(input.readInt32()); - break; - } - case 10: { - int length = input.readRawVarint32(); - int limit = input.pushLimit(length); - if (!((mutable_bitField0_ & 0x00000001) != 0) && input.getBytesUntilLimit() > 0) { - int_ = newIntList(); - mutable_bitField0_ |= 0x00000001; - } - while (input.getBytesUntilLimit() > 0) { - int_.addInt(input.readInt32()); - } - input.popLimit(limit); - break; - } - default: { - if (!parseUnknownField( - input, unknownFields, extensionRegistry, tag)) { - done = true; - } - break; - } - } - } - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - throw e.setUnfinishedMessage(this); - } catch (java.io.IOException e) { - throw new com.google.protobuf.InvalidProtocolBufferException( - e).setUnfinishedMessage(this); - } finally { - if (((mutable_bitField0_ & 0x00000001) != 0)) { - int_.makeImmutable(); // C - } - this.unknownFields = unknownFields.build(); - makeExtensionsImmutable(); - } - } - - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return PushDownPb.internal_static_IntList_descriptor; - } - - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return PushDownPb.internal_static_IntList_fieldAccessorTable - .ensureFieldAccessorsInitialized( - PushDownPb.IntList.class, PushDownPb.IntList.Builder.class); - } - - public static final int INT_FIELD_NUMBER = 1; - private com.google.protobuf.Internal.IntList int_; - - /** - * repeated int32 int = 1; - */ - public java.util.List - getIntList() { - return int_; - } - - /** - * repeated int32 int = 1; - */ - public int getIntCount() { - return int_.size(); - } - - /** - * repeated int32 int = 1; - */ - public int getInt(int index) { - return int_.getInt(index); - } - - private int intMemoizedSerializedSize = -1; - - private byte memoizedIsInitialized = -1; - - @java.lang.Override - public final boolean isInitialized() { - byte isInitialized = memoizedIsInitialized; - if (isInitialized == 1) return true; - if (isInitialized == 0) return false; - - memoizedIsInitialized = 1; - return true; - } - - @java.lang.Override - public void writeTo(com.google.protobuf.CodedOutputStream output) - throws java.io.IOException { - getSerializedSize(); - if (getIntList().size() > 0) { - output.writeUInt32NoTag(10); - output.writeUInt32NoTag(intMemoizedSerializedSize); - } - for (int i = 0; i < int_.size(); i++) { - output.writeInt32NoTag(int_.getInt(i)); - } - unknownFields.writeTo(output); - } - - @java.lang.Override - public int getSerializedSize() { - int size = memoizedSize; - if (size != -1) return size; - - size = 0; - { - int dataSize = 0; - for (int i = 0; i < int_.size(); i++) { - dataSize += com.google.protobuf.CodedOutputStream - .computeInt32SizeNoTag(int_.getInt(i)); - } - size += dataSize; - if (!getIntList().isEmpty()) { - size += 1; - size += com.google.protobuf.CodedOutputStream - .computeInt32SizeNoTag(dataSize); - } - intMemoizedSerializedSize = dataSize; - } - size += unknownFields.getSerializedSize(); - memoizedSize = size; - return size; - } - - @java.lang.Override - public boolean equals(final java.lang.Object obj) { - if (obj == this) { - return true; - } - if (!(obj instanceof PushDownPb.IntList)) { - return super.equals(obj); - } - PushDownPb.IntList other = (PushDownPb.IntList) obj; - - if (!getIntList() - .equals(other.getIntList())) return false; - if (!unknownFields.equals(other.unknownFields)) return false; - return true; - } - - @java.lang.Override - public int hashCode() { - if (memoizedHashCode != 0) { - return memoizedHashCode; - } - int hash = 41; - hash = (19 * hash) + getDescriptor().hashCode(); - if (getIntCount() > 0) { - hash = (37 * hash) + INT_FIELD_NUMBER; - hash = (53 * hash) + getIntList().hashCode(); - } - hash = (29 * hash) + unknownFields.hashCode(); - memoizedHashCode = hash; - return hash; - } - - public static PushDownPb.IntList parseFrom( - java.nio.ByteBuffer data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - - public static PushDownPb.IntList parseFrom( - java.nio.ByteBuffer data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - - public static PushDownPb.IntList parseFrom( - com.google.protobuf.ByteString data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - - public static PushDownPb.IntList parseFrom( - com.google.protobuf.ByteString data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - - public static PushDownPb.IntList parseFrom(byte[] data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - - public static PushDownPb.IntList parseFrom( - byte[] data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - - public static PushDownPb.IntList parseFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } - - public static PushDownPb.IntList parseFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } - - public static PushDownPb.IntList parseDelimitedFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input); - } - - public static PushDownPb.IntList parseDelimitedFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input, extensionRegistry); - } - - public static PushDownPb.IntList parseFrom( - com.google.protobuf.CodedInputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } - - public static PushDownPb.IntList parseFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } - - @java.lang.Override - public Builder newBuilderForType() { - return newBuilder(); - } - - public static Builder newBuilder() { - return DEFAULT_INSTANCE.toBuilder(); - } - - public static Builder newBuilder(PushDownPb.IntList prototype) { - return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); - } - - @java.lang.Override - public Builder toBuilder() { - return this == DEFAULT_INSTANCE - ? new Builder() : new Builder().mergeFrom(this); - } - - @java.lang.Override - protected Builder newBuilderForType( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - Builder builder = new Builder(parent); - return builder; - } - - /** - * Protobuf type {@code IntList} - */ - public static final class Builder extends - com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:IntList) - PushDownPb.IntListOrBuilder { - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return PushDownPb.internal_static_IntList_descriptor; - } - - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return PushDownPb.internal_static_IntList_fieldAccessorTable - .ensureFieldAccessorsInitialized( - PushDownPb.IntList.class, PushDownPb.IntList.Builder.class); - } - - // Construct using PushDownPb.IntList.newBuilder() - private Builder() { - maybeForceBuilderInitialization(); - } - - private Builder( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - super(parent); - maybeForceBuilderInitialization(); - } - - private void maybeForceBuilderInitialization() { - } - - @java.lang.Override - public Builder clear() { - super.clear(); - int_ = emptyIntList(); - bitField0_ = (bitField0_ & ~0x00000001); - return this; - } - - @java.lang.Override - public com.google.protobuf.Descriptors.Descriptor - getDescriptorForType() { - return PushDownPb.internal_static_IntList_descriptor; - } - - @java.lang.Override - public PushDownPb.IntList getDefaultInstanceForType() { - return PushDownPb.IntList.getDefaultInstance(); - } - - @java.lang.Override - public PushDownPb.IntList build() { - PushDownPb.IntList result = buildPartial(); - if (!result.isInitialized()) { - throw newUninitializedMessageException(result); - } - return result; - } - - @java.lang.Override - public PushDownPb.IntList buildPartial() { - PushDownPb.IntList result = new PushDownPb.IntList(this); - if (((bitField0_ & 0x00000001) != 0)) { - int_.makeImmutable(); - bitField0_ = (bitField0_ & ~0x00000001); - } - result.int_ = int_; - onBuilt(); - return result; - } - - @java.lang.Override - public Builder clone() { - return super.clone(); - } - - @java.lang.Override - public Builder setField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.setField(field, value); - } - - @java.lang.Override - public Builder clearField( - com.google.protobuf.Descriptors.FieldDescriptor field) { - return super.clearField(field); - } - - @java.lang.Override - public Builder clearOneof( - com.google.protobuf.Descriptors.OneofDescriptor oneof) { - return super.clearOneof(oneof); - } - - @java.lang.Override - public Builder setRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - int index, java.lang.Object value) { - return super.setRepeatedField(field, index, value); - } - - @java.lang.Override - public Builder addRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.addRepeatedField(field, value); - } - - @java.lang.Override - public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof PushDownPb.IntList) { - return mergeFrom((PushDownPb.IntList) other); - } else { - super.mergeFrom(other); - return this; - } - } - - public Builder mergeFrom(PushDownPb.IntList other) { - if (other == PushDownPb.IntList.getDefaultInstance()) return this; - if (!other.int_.isEmpty()) { - if (int_.isEmpty()) { - int_ = other.int_; - bitField0_ = (bitField0_ & ~0x00000001); - } else { - ensureIntIsMutable(); - int_.addAll(other.int_); - } - onChanged(); - } - this.mergeUnknownFields(other.unknownFields); - onChanged(); - return this; - } - - @java.lang.Override - public final boolean isInitialized() { - return true; - } - - @java.lang.Override - public Builder mergeFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - PushDownPb.IntList parsedMessage = null; - try { - parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - parsedMessage = (PushDownPb.IntList) e.getUnfinishedMessage(); - throw e.unwrapIOException(); - } finally { - if (parsedMessage != null) { - mergeFrom(parsedMessage); - } - } - return this; - } - - private int bitField0_; - - private com.google.protobuf.Internal.IntList int_ = emptyIntList(); - - private void ensureIntIsMutable() { - if (!((bitField0_ & 0x00000001) != 0)) { - int_ = mutableCopy(int_); - bitField0_ |= 0x00000001; - } - } - - /** - * repeated int32 int = 1; - */ - public java.util.List - getIntList() { - return ((bitField0_ & 0x00000001) != 0) ? - java.util.Collections.unmodifiableList(int_) : int_; - } - - /** - * repeated int32 int = 1; - */ - public int getIntCount() { - return int_.size(); - } - - /** - * repeated int32 int = 1; - */ - public int getInt(int index) { - return int_.getInt(index); - } - - /** - * repeated int32 int = 1; - */ - public Builder setInt( - int index, int value) { - ensureIntIsMutable(); - int_.setInt(index, value); - onChanged(); - return this; - } - - /** - * repeated int32 int = 1; - */ - public Builder addInt(int value) { - ensureIntIsMutable(); - int_.addInt(value); - onChanged(); - return this; - } - - /** - * repeated int32 int = 1; - */ - public Builder addAllInt( - java.lang.Iterable values) { - ensureIntIsMutable(); - com.google.protobuf.AbstractMessageLite.Builder.addAll( - values, int_); - onChanged(); - return this; - } - - /** - * repeated int32 int = 1; - */ - public Builder clearInt() { - int_ = emptyIntList(); - bitField0_ = (bitField0_ & ~0x00000001); - onChanged(); - return this; - } - - @java.lang.Override - public final Builder setUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.setUnknownFields(unknownFields); - } - - @java.lang.Override - public final Builder mergeUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.mergeUnknownFields(unknownFields); - } - - - // @@protoc_insertion_point(builder_scope:IntList) - } - - // @@protoc_insertion_point(class_scope:IntList) - private static final PushDownPb.IntList DEFAULT_INSTANCE; - - static { - DEFAULT_INSTANCE = new PushDownPb.IntList(); - } - - public static PushDownPb.IntList getDefaultInstance() { - return DEFAULT_INSTANCE; - } - - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { - @java.lang.Override - public IntList parsePartialFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return new IntList(input, extensionRegistry); - } - }; - - public static com.google.protobuf.Parser parser() { - return PARSER; - } - - @java.lang.Override - public com.google.protobuf.Parser getParserForType() { - return PARSER; - } - - @java.lang.Override - public PushDownPb.IntList getDefaultInstanceForType() { - return DEFAULT_INSTANCE; - } - + content_ = input.readMessage(PushDownPb.LongList.parser(), extensionRegistry); + if (subBuilder != null) { + subBuilder.mergeFrom((PushDownPb.LongList) content_); + content_ = subBuilder.buildPartial(); + } + contentCase_ = 4; + break; + } + case 42: + { + PushDownPb.StringList.Builder subBuilder = null; + if (contentCase_ == 5) { + subBuilder = ((PushDownPb.StringList) content_).toBuilder(); + } + content_ = input.readMessage(PushDownPb.StringList.parser(), extensionRegistry); + if (subBuilder != null) { + subBuilder.mergeFrom((PushDownPb.StringList) content_); + content_ = subBuilder.buildPartial(); + } + contentCase_ = 5; + break; + } + case 50: + { + PushDownPb.BytesList.Builder subBuilder = null; + if (contentCase_ == 6) { + subBuilder = ((PushDownPb.BytesList) content_).toBuilder(); + } + content_ = input.readMessage(PushDownPb.BytesList.parser(), extensionRegistry); + if (subBuilder != null) { + subBuilder.mergeFrom((PushDownPb.BytesList) content_); + content_ = subBuilder.buildPartial(); + } + contentCase_ = 6; + break; + } + default: + { + if (!parseUnknownField(input, unknownFields, extensionRegistry, tag)) { + done = true; + } + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e).setUnfinishedMessage(this); + } finally { + if (((mutable_bitField0_ & 0x00000001) != 0)) { + filters_ = java.util.Collections.unmodifiableList(filters_); + } + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } } - public interface LongListOrBuilder extends - // @@protoc_insertion_point(interface_extends:LongList) - com.google.protobuf.MessageOrBuilder { - - /** - * repeated int64 long = 1; - */ - java.util.List getLongList(); + public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { + return PushDownPb.internal_static_FilterNode_descriptor; + } - /** - * repeated int64 long = 1; - */ - int getLongCount(); + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return PushDownPb.internal_static_FilterNode_fieldAccessorTable + .ensureFieldAccessorsInitialized( + PushDownPb.FilterNode.class, PushDownPb.FilterNode.Builder.class); + } - /** - * repeated int64 long = 1; - */ - long getLong(int index); + private int contentCase_ = 0; + private java.lang.Object content_; + + public enum ContentCase implements com.google.protobuf.Internal.EnumLite { + INT_CONTENT(3), + LONG_CONTENT(4), + STR_CONTENT(5), + BYTES_CONTENT(6), + CONTENT_NOT_SET(0); + private final int value; + + private ContentCase(int value) { + this.value = value; + } + + /** + * @deprecated Use {@link #forNumber(int)} instead. + */ + @java.lang.Deprecated + public static ContentCase valueOf(int value) { + return forNumber(value); + } + + public static ContentCase forNumber(int value) { + switch (value) { + case 3: + return INT_CONTENT; + case 4: + return LONG_CONTENT; + case 5: + return STR_CONTENT; + case 6: + return BYTES_CONTENT; + case 0: + return CONTENT_NOT_SET; + default: + return null; + } + } + + public int getNumber() { + return this.value; + } + }; + + public ContentCase getContentCase() { + return ContentCase.forNumber(contentCase_); } - /** - * Protobuf type {@code LongList} - */ - public static final class LongList extends - com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:LongList) - LongListOrBuilder { - private static final long serialVersionUID = 0L; + public static final int FILTER_TYPE_FIELD_NUMBER = 1; + private int filterType_; - // Use LongList.newBuilder() to construct. - private LongList(com.google.protobuf.GeneratedMessageV3.Builder builder) { - super(builder); - } + /** .FilterType filter_type = 1; */ + public int getFilterTypeValue() { + return filterType_; + } - private LongList() { - long_ = emptyLongList(); - } + /** .FilterType filter_type = 1; */ + public PushDownPb.FilterType getFilterType() { + @SuppressWarnings("deprecation") + PushDownPb.FilterType result = PushDownPb.FilterType.valueOf(filterType_); + return result == null ? PushDownPb.FilterType.UNRECOGNIZED : result; + } - @java.lang.Override - @SuppressWarnings({"unused"}) - protected java.lang.Object newInstance( - UnusedPrivateParameter unused) { - return new LongList(); - } + public static final int FILTERS_FIELD_NUMBER = 2; + private java.util.List filters_; - @java.lang.Override - public final com.google.protobuf.UnknownFieldSet - getUnknownFields() { - return this.unknownFields; - } + /** repeated .FilterNode filters = 2; */ + public java.util.List getFiltersList() { + return filters_; + } - private LongList( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - this(); - if (extensionRegistry == null) { - throw new java.lang.NullPointerException(); - } - int mutable_bitField0_ = 0; - com.google.protobuf.UnknownFieldSet.Builder unknownFields = - com.google.protobuf.UnknownFieldSet.newBuilder(); - try { - boolean done = false; - while (!done) { - int tag = input.readTag(); - switch (tag) { - case 0: - done = true; - break; - case 8: { - if (!((mutable_bitField0_ & 0x00000001) != 0)) { - long_ = newLongList(); - mutable_bitField0_ |= 0x00000001; - } - long_.addLong(input.readInt64()); - break; - } - case 10: { - int length = input.readRawVarint32(); - int limit = input.pushLimit(length); - if (!((mutable_bitField0_ & 0x00000001) != 0) && input.getBytesUntilLimit() > 0) { - long_ = newLongList(); - mutable_bitField0_ |= 0x00000001; - } - while (input.getBytesUntilLimit() > 0) { - long_.addLong(input.readInt64()); - } - input.popLimit(limit); - break; - } - default: { - if (!parseUnknownField( - input, unknownFields, extensionRegistry, tag)) { - done = true; - } - break; - } - } - } - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - throw e.setUnfinishedMessage(this); - } catch (java.io.IOException e) { - throw new com.google.protobuf.InvalidProtocolBufferException( - e).setUnfinishedMessage(this); - } finally { - if (((mutable_bitField0_ & 0x00000001) != 0)) { - long_.makeImmutable(); // C - } - this.unknownFields = unknownFields.build(); - makeExtensionsImmutable(); - } - } + /** repeated .FilterNode filters = 2; */ + public java.util.List getFiltersOrBuilderList() { + return filters_; + } - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return PushDownPb.internal_static_LongList_descriptor; - } + /** repeated .FilterNode filters = 2; */ + public int getFiltersCount() { + return filters_.size(); + } - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return PushDownPb.internal_static_LongList_fieldAccessorTable - .ensureFieldAccessorsInitialized( - PushDownPb.LongList.class, PushDownPb.LongList.Builder.class); - } + /** repeated .FilterNode filters = 2; */ + public PushDownPb.FilterNode getFilters(int index) { + return filters_.get(index); + } - public static final int LONG_FIELD_NUMBER = 1; - private com.google.protobuf.Internal.LongList long_; + /** repeated .FilterNode filters = 2; */ + public PushDownPb.FilterNodeOrBuilder getFiltersOrBuilder(int index) { + return filters_.get(index); + } - /** - * repeated int64 long = 1; - */ - public java.util.List - getLongList() { - return long_; - } + public static final int INT_CONTENT_FIELD_NUMBER = 3; - /** - * repeated int64 long = 1; - */ - public int getLongCount() { - return long_.size(); - } + /** .IntList int_content = 3; */ + public boolean hasIntContent() { + return contentCase_ == 3; + } - /** - * repeated int64 long = 1; - */ - public long getLong(int index) { - return long_.getLong(index); - } + /** .IntList int_content = 3; */ + public PushDownPb.IntList getIntContent() { + if (contentCase_ == 3) { + return (PushDownPb.IntList) content_; + } + return PushDownPb.IntList.getDefaultInstance(); + } - private int longMemoizedSerializedSize = -1; + /** .IntList int_content = 3; */ + public PushDownPb.IntListOrBuilder getIntContentOrBuilder() { + if (contentCase_ == 3) { + return (PushDownPb.IntList) content_; + } + return PushDownPb.IntList.getDefaultInstance(); + } - private byte memoizedIsInitialized = -1; + public static final int LONG_CONTENT_FIELD_NUMBER = 4; - @java.lang.Override - public final boolean isInitialized() { - byte isInitialized = memoizedIsInitialized; - if (isInitialized == 1) return true; - if (isInitialized == 0) return false; + /** .LongList long_content = 4; */ + public boolean hasLongContent() { + return contentCase_ == 4; + } - memoizedIsInitialized = 1; - return true; - } + /** .LongList long_content = 4; */ + public PushDownPb.LongList getLongContent() { + if (contentCase_ == 4) { + return (PushDownPb.LongList) content_; + } + return PushDownPb.LongList.getDefaultInstance(); + } - @java.lang.Override - public void writeTo(com.google.protobuf.CodedOutputStream output) - throws java.io.IOException { - getSerializedSize(); - if (getLongList().size() > 0) { - output.writeUInt32NoTag(10); - output.writeUInt32NoTag(longMemoizedSerializedSize); - } - for (int i = 0; i < long_.size(); i++) { - output.writeInt64NoTag(long_.getLong(i)); - } - unknownFields.writeTo(output); - } + /** .LongList long_content = 4; */ + public PushDownPb.LongListOrBuilder getLongContentOrBuilder() { + if (contentCase_ == 4) { + return (PushDownPb.LongList) content_; + } + return PushDownPb.LongList.getDefaultInstance(); + } - @java.lang.Override - public int getSerializedSize() { - int size = memoizedSize; - if (size != -1) return size; + public static final int STR_CONTENT_FIELD_NUMBER = 5; - size = 0; - { - int dataSize = 0; - for (int i = 0; i < long_.size(); i++) { - dataSize += com.google.protobuf.CodedOutputStream - .computeInt64SizeNoTag(long_.getLong(i)); - } - size += dataSize; - if (!getLongList().isEmpty()) { - size += 1; - size += com.google.protobuf.CodedOutputStream - .computeInt32SizeNoTag(dataSize); - } - longMemoizedSerializedSize = dataSize; - } - size += unknownFields.getSerializedSize(); - memoizedSize = size; - return size; - } + /** .StringList str_content = 5; */ + public boolean hasStrContent() { + return contentCase_ == 5; + } - @java.lang.Override - public boolean equals(final java.lang.Object obj) { - if (obj == this) { - return true; - } - if (!(obj instanceof PushDownPb.LongList)) { - return super.equals(obj); - } - PushDownPb.LongList other = (PushDownPb.LongList) obj; + /** .StringList str_content = 5; */ + public PushDownPb.StringList getStrContent() { + if (contentCase_ == 5) { + return (PushDownPb.StringList) content_; + } + return PushDownPb.StringList.getDefaultInstance(); + } - if (!getLongList() - .equals(other.getLongList())) return false; - if (!unknownFields.equals(other.unknownFields)) return false; - return true; - } + /** .StringList str_content = 5; */ + public PushDownPb.StringListOrBuilder getStrContentOrBuilder() { + if (contentCase_ == 5) { + return (PushDownPb.StringList) content_; + } + return PushDownPb.StringList.getDefaultInstance(); + } - @java.lang.Override - public int hashCode() { - if (memoizedHashCode != 0) { - return memoizedHashCode; - } - int hash = 41; - hash = (19 * hash) + getDescriptor().hashCode(); - if (getLongCount() > 0) { - hash = (37 * hash) + LONG_FIELD_NUMBER; - hash = (53 * hash) + getLongList().hashCode(); - } - hash = (29 * hash) + unknownFields.hashCode(); - memoizedHashCode = hash; - return hash; - } + public static final int BYTES_CONTENT_FIELD_NUMBER = 6; - public static PushDownPb.LongList parseFrom( - java.nio.ByteBuffer data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } + /** .BytesList bytes_content = 6; */ + public boolean hasBytesContent() { + return contentCase_ == 6; + } - public static PushDownPb.LongList parseFrom( - java.nio.ByteBuffer data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } + /** .BytesList bytes_content = 6; */ + public PushDownPb.BytesList getBytesContent() { + if (contentCase_ == 6) { + return (PushDownPb.BytesList) content_; + } + return PushDownPb.BytesList.getDefaultInstance(); + } - public static PushDownPb.LongList parseFrom( - com.google.protobuf.ByteString data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } + /** .BytesList bytes_content = 6; */ + public PushDownPb.BytesListOrBuilder getBytesContentOrBuilder() { + if (contentCase_ == 6) { + return (PushDownPb.BytesList) content_; + } + return PushDownPb.BytesList.getDefaultInstance(); + } - public static PushDownPb.LongList parseFrom( - com.google.protobuf.ByteString data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } + private byte memoizedIsInitialized = -1; - public static PushDownPb.LongList parseFrom(byte[] data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; - public static PushDownPb.LongList parseFrom( - byte[] data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } + memoizedIsInitialized = 1; + return true; + } - public static PushDownPb.LongList parseFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { + if (filterType_ != PushDownPb.FilterType.EMPTY.getNumber()) { + output.writeEnum(1, filterType_); + } + for (int i = 0; i < filters_.size(); i++) { + output.writeMessage(2, filters_.get(i)); + } + if (contentCase_ == 3) { + output.writeMessage(3, (PushDownPb.IntList) content_); + } + if (contentCase_ == 4) { + output.writeMessage(4, (PushDownPb.LongList) content_); + } + if (contentCase_ == 5) { + output.writeMessage(5, (PushDownPb.StringList) content_); + } + if (contentCase_ == 6) { + output.writeMessage(6, (PushDownPb.BytesList) content_); + } + unknownFields.writeTo(output); + } - public static PushDownPb.LongList parseFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (filterType_ != PushDownPb.FilterType.EMPTY.getNumber()) { + size += com.google.protobuf.CodedOutputStream.computeEnumSize(1, filterType_); + } + for (int i = 0; i < filters_.size(); i++) { + size += com.google.protobuf.CodedOutputStream.computeMessageSize(2, filters_.get(i)); + } + if (contentCase_ == 3) { + size += + com.google.protobuf.CodedOutputStream.computeMessageSize( + 3, (PushDownPb.IntList) content_); + } + if (contentCase_ == 4) { + size += + com.google.protobuf.CodedOutputStream.computeMessageSize( + 4, (PushDownPb.LongList) content_); + } + if (contentCase_ == 5) { + size += + com.google.protobuf.CodedOutputStream.computeMessageSize( + 5, (PushDownPb.StringList) content_); + } + if (contentCase_ == 6) { + size += + com.google.protobuf.CodedOutputStream.computeMessageSize( + 6, (PushDownPb.BytesList) content_); + } + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } - public static PushDownPb.LongList parseDelimitedFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input); - } + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof PushDownPb.FilterNode)) { + return super.equals(obj); + } + PushDownPb.FilterNode other = (PushDownPb.FilterNode) obj; + + if (filterType_ != other.filterType_) return false; + if (!getFiltersList().equals(other.getFiltersList())) return false; + if (!getContentCase().equals(other.getContentCase())) return false; + switch (contentCase_) { + case 3: + if (!getIntContent().equals(other.getIntContent())) return false; + break; + case 4: + if (!getLongContent().equals(other.getLongContent())) return false; + break; + case 5: + if (!getStrContent().equals(other.getStrContent())) return false; + break; + case 6: + if (!getBytesContent().equals(other.getBytesContent())) return false; + break; + case 0: + default: + } + if (!unknownFields.equals(other.unknownFields)) return false; + return true; + } - public static PushDownPb.LongList parseDelimitedFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input, extensionRegistry); - } + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + FILTER_TYPE_FIELD_NUMBER; + hash = (53 * hash) + filterType_; + if (getFiltersCount() > 0) { + hash = (37 * hash) + FILTERS_FIELD_NUMBER; + hash = (53 * hash) + getFiltersList().hashCode(); + } + switch (contentCase_) { + case 3: + hash = (37 * hash) + INT_CONTENT_FIELD_NUMBER; + hash = (53 * hash) + getIntContent().hashCode(); + break; + case 4: + hash = (37 * hash) + LONG_CONTENT_FIELD_NUMBER; + hash = (53 * hash) + getLongContent().hashCode(); + break; + case 5: + hash = (37 * hash) + STR_CONTENT_FIELD_NUMBER; + hash = (53 * hash) + getStrContent().hashCode(); + break; + case 6: + hash = (37 * hash) + BYTES_CONTENT_FIELD_NUMBER; + hash = (53 * hash) + getBytesContent().hashCode(); + break; + case 0: + default: + } + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } - public static PushDownPb.LongList parseFrom( - com.google.protobuf.CodedInputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } + public static PushDownPb.FilterNode parseFrom(java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - public static PushDownPb.LongList parseFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } + public static PushDownPb.FilterNode parseFrom( + java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - @java.lang.Override - public Builder newBuilderForType() { - return newBuilder(); - } + public static PushDownPb.FilterNode parseFrom(com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - public static Builder newBuilder() { - return DEFAULT_INSTANCE.toBuilder(); - } + public static PushDownPb.FilterNode parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - public static Builder newBuilder(PushDownPb.LongList prototype) { - return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); - } + public static PushDownPb.FilterNode parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - @java.lang.Override - public Builder toBuilder() { - return this == DEFAULT_INSTANCE - ? new Builder() : new Builder().mergeFrom(this); - } + public static PushDownPb.FilterNode parseFrom( + byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - @java.lang.Override - protected Builder newBuilderForType( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - Builder builder = new Builder(parent); - return builder; - } + public static PushDownPb.FilterNode parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException(PARSER, input); + } - /** - * Protobuf type {@code LongList} - */ - public static final class Builder extends - com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:LongList) - PushDownPb.LongListOrBuilder { - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return PushDownPb.internal_static_LongList_descriptor; - } + public static PushDownPb.FilterNode parseFrom( + java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException( + PARSER, input, extensionRegistry); + } - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return PushDownPb.internal_static_LongList_fieldAccessorTable - .ensureFieldAccessorsInitialized( - PushDownPb.LongList.class, PushDownPb.LongList.Builder.class); - } + public static PushDownPb.FilterNode parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseDelimitedWithIOException(PARSER, input); + } - // Construct using PushDownPb.LongList.newBuilder() - private Builder() { - maybeForceBuilderInitialization(); - } + public static PushDownPb.FilterNode parseDelimitedFrom( + java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseDelimitedWithIOException( + PARSER, input, extensionRegistry); + } - private Builder( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - super(parent); - maybeForceBuilderInitialization(); - } + public static PushDownPb.FilterNode parseFrom(com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException(PARSER, input); + } - private void maybeForceBuilderInitialization() { - } + public static PushDownPb.FilterNode parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException( + PARSER, input, extensionRegistry); + } - @java.lang.Override - public Builder clear() { - super.clear(); - long_ = emptyLongList(); - bitField0_ = (bitField0_ & ~0x00000001); - return this; - } + @java.lang.Override + public Builder newBuilderForType() { + return newBuilder(); + } - @java.lang.Override - public com.google.protobuf.Descriptors.Descriptor - getDescriptorForType() { - return PushDownPb.internal_static_LongList_descriptor; - } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } - @java.lang.Override - public PushDownPb.LongList getDefaultInstanceForType() { - return PushDownPb.LongList.getDefaultInstance(); - } + public static Builder newBuilder(PushDownPb.FilterNode prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } - @java.lang.Override - public PushDownPb.LongList build() { - PushDownPb.LongList result = buildPartial(); - if (!result.isInitialized()) { - throw newUninitializedMessageException(result); - } - return result; - } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE ? new Builder() : new Builder().mergeFrom(this); + } - @java.lang.Override - public PushDownPb.LongList buildPartial() { - PushDownPb.LongList result = new PushDownPb.LongList(this); - if (((bitField0_ & 0x00000001) != 0)) { - long_.makeImmutable(); - bitField0_ = (bitField0_ & ~0x00000001); - } - result.long_ = long_; - onBuilt(); - return result; - } + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } - @java.lang.Override - public Builder clone() { - return super.clone(); + /** Protobuf type {@code FilterNode} */ + public static final class Builder + extends com.google.protobuf.GeneratedMessageV3.Builder + implements + // @@protoc_insertion_point(builder_implements:FilterNode) + PushDownPb.FilterNodeOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { + return PushDownPb.internal_static_FilterNode_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return PushDownPb.internal_static_FilterNode_fieldAccessorTable + .ensureFieldAccessorsInitialized( + PushDownPb.FilterNode.class, PushDownPb.FilterNode.Builder.class); + } + + // Construct using PushDownPb.FilterNode.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder(com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + + private void maybeForceBuilderInitialization() { + if (com.google.protobuf.GeneratedMessageV3.alwaysUseFieldBuilders) { + getFiltersFieldBuilder(); + } + } + + @java.lang.Override + public Builder clear() { + super.clear(); + filterType_ = 0; + + if (filtersBuilder_ == null) { + filters_ = java.util.Collections.emptyList(); + bitField0_ = (bitField0_ & ~0x00000001); + } else { + filtersBuilder_.clear(); + } + contentCase_ = 0; + content_ = null; + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { + return PushDownPb.internal_static_FilterNode_descriptor; + } + + @java.lang.Override + public PushDownPb.FilterNode getDefaultInstanceForType() { + return PushDownPb.FilterNode.getDefaultInstance(); + } + + @java.lang.Override + public PushDownPb.FilterNode build() { + PushDownPb.FilterNode result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public PushDownPb.FilterNode buildPartial() { + PushDownPb.FilterNode result = new PushDownPb.FilterNode(this); + result.filterType_ = filterType_; + if (filtersBuilder_ == null) { + if (((bitField0_ & 0x00000001) != 0)) { + filters_ = java.util.Collections.unmodifiableList(filters_); + bitField0_ = (bitField0_ & ~0x00000001); + } + result.filters_ = filters_; + } else { + result.filters_ = filtersBuilder_.build(); + } + if (contentCase_ == 3) { + if (intContentBuilder_ == null) { + result.content_ = content_; + } else { + result.content_ = intContentBuilder_.build(); + } + } + if (contentCase_ == 4) { + if (longContentBuilder_ == null) { + result.content_ = content_; + } else { + result.content_ = longContentBuilder_.build(); + } + } + if (contentCase_ == 5) { + if (strContentBuilder_ == null) { + result.content_ = content_; + } else { + result.content_ = strContentBuilder_.build(); + } + } + if (contentCase_ == 6) { + if (bytesContentBuilder_ == null) { + result.content_ = content_; + } else { + result.content_ = bytesContentBuilder_.build(); + } + } + result.contentCase_ = contentCase_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, java.lang.Object value) { + return super.setField(field, value); + } + + @java.lang.Override + public Builder clearField(com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + + @java.lang.Override + public Builder clearOneof(com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, + java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, java.lang.Object value) { + return super.addRepeatedField(field, value); + } + + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof PushDownPb.FilterNode) { + return mergeFrom((PushDownPb.FilterNode) other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(PushDownPb.FilterNode other) { + if (other == PushDownPb.FilterNode.getDefaultInstance()) return this; + if (other.filterType_ != 0) { + setFilterTypeValue(other.getFilterTypeValue()); + } + if (filtersBuilder_ == null) { + if (!other.filters_.isEmpty()) { + if (filters_.isEmpty()) { + filters_ = other.filters_; + bitField0_ = (bitField0_ & ~0x00000001); + } else { + ensureFiltersIsMutable(); + filters_.addAll(other.filters_); + } + onChanged(); + } + } else { + if (!other.filters_.isEmpty()) { + if (filtersBuilder_.isEmpty()) { + filtersBuilder_.dispose(); + filtersBuilder_ = null; + filters_ = other.filters_; + bitField0_ = (bitField0_ & ~0x00000001); + filtersBuilder_ = + com.google.protobuf.GeneratedMessageV3.alwaysUseFieldBuilders + ? getFiltersFieldBuilder() + : null; + } else { + filtersBuilder_.addAllMessages(other.filters_); + } + } + } + switch (other.getContentCase()) { + case INT_CONTENT: + { + mergeIntContent(other.getIntContent()); + break; } - - @java.lang.Override - public Builder setField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.setField(field, value); + case LONG_CONTENT: + { + mergeLongContent(other.getLongContent()); + break; } - - @java.lang.Override - public Builder clearField( - com.google.protobuf.Descriptors.FieldDescriptor field) { - return super.clearField(field); + case STR_CONTENT: + { + mergeStrContent(other.getStrContent()); + break; } - - @java.lang.Override - public Builder clearOneof( - com.google.protobuf.Descriptors.OneofDescriptor oneof) { - return super.clearOneof(oneof); + case BYTES_CONTENT: + { + mergeBytesContent(other.getBytesContent()); + break; } + case CONTENT_NOT_SET: + { + break; + } + } + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + PushDownPb.FilterNode parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (PushDownPb.FilterNode) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + + private int contentCase_ = 0; + private java.lang.Object content_; + + public ContentCase getContentCase() { + return ContentCase.forNumber(contentCase_); + } + + public Builder clearContent() { + contentCase_ = 0; + content_ = null; + onChanged(); + return this; + } + + private int bitField0_; + + private int filterType_ = 0; + + /** .FilterType filter_type = 1; */ + public int getFilterTypeValue() { + return filterType_; + } + + /** .FilterType filter_type = 1; */ + public Builder setFilterTypeValue(int value) { + filterType_ = value; + onChanged(); + return this; + } + + /** .FilterType filter_type = 1; */ + public PushDownPb.FilterType getFilterType() { + @SuppressWarnings("deprecation") + PushDownPb.FilterType result = PushDownPb.FilterType.valueOf(filterType_); + return result == null ? PushDownPb.FilterType.UNRECOGNIZED : result; + } + + /** .FilterType filter_type = 1; */ + public Builder setFilterType(PushDownPb.FilterType value) { + if (value == null) { + throw new NullPointerException(); + } + + filterType_ = value.getNumber(); + onChanged(); + return this; + } + + /** .FilterType filter_type = 1; */ + public Builder clearFilterType() { + + filterType_ = 0; + onChanged(); + return this; + } + + private java.util.List filters_ = java.util.Collections.emptyList(); + + private void ensureFiltersIsMutable() { + if (!((bitField0_ & 0x00000001) != 0)) { + filters_ = new java.util.ArrayList(filters_); + bitField0_ |= 0x00000001; + } + } + + private com.google.protobuf.RepeatedFieldBuilderV3< + PushDownPb.FilterNode, PushDownPb.FilterNode.Builder, PushDownPb.FilterNodeOrBuilder> + filtersBuilder_; + + /** repeated .FilterNode filters = 2; */ + public java.util.List getFiltersList() { + if (filtersBuilder_ == null) { + return java.util.Collections.unmodifiableList(filters_); + } else { + return filtersBuilder_.getMessageList(); + } + } + + /** repeated .FilterNode filters = 2; */ + public int getFiltersCount() { + if (filtersBuilder_ == null) { + return filters_.size(); + } else { + return filtersBuilder_.getCount(); + } + } + + /** repeated .FilterNode filters = 2; */ + public PushDownPb.FilterNode getFilters(int index) { + if (filtersBuilder_ == null) { + return filters_.get(index); + } else { + return filtersBuilder_.getMessage(index); + } + } + + /** repeated .FilterNode filters = 2; */ + public Builder setFilters(int index, PushDownPb.FilterNode value) { + if (filtersBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureFiltersIsMutable(); + filters_.set(index, value); + onChanged(); + } else { + filtersBuilder_.setMessage(index, value); + } + return this; + } + + /** repeated .FilterNode filters = 2; */ + public Builder setFilters(int index, PushDownPb.FilterNode.Builder builderForValue) { + if (filtersBuilder_ == null) { + ensureFiltersIsMutable(); + filters_.set(index, builderForValue.build()); + onChanged(); + } else { + filtersBuilder_.setMessage(index, builderForValue.build()); + } + return this; + } + + /** repeated .FilterNode filters = 2; */ + public Builder addFilters(PushDownPb.FilterNode value) { + if (filtersBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureFiltersIsMutable(); + filters_.add(value); + onChanged(); + } else { + filtersBuilder_.addMessage(value); + } + return this; + } + + /** repeated .FilterNode filters = 2; */ + public Builder addFilters(int index, PushDownPb.FilterNode value) { + if (filtersBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureFiltersIsMutable(); + filters_.add(index, value); + onChanged(); + } else { + filtersBuilder_.addMessage(index, value); + } + return this; + } + + /** repeated .FilterNode filters = 2; */ + public Builder addFilters(PushDownPb.FilterNode.Builder builderForValue) { + if (filtersBuilder_ == null) { + ensureFiltersIsMutable(); + filters_.add(builderForValue.build()); + onChanged(); + } else { + filtersBuilder_.addMessage(builderForValue.build()); + } + return this; + } + + /** repeated .FilterNode filters = 2; */ + public Builder addFilters(int index, PushDownPb.FilterNode.Builder builderForValue) { + if (filtersBuilder_ == null) { + ensureFiltersIsMutable(); + filters_.add(index, builderForValue.build()); + onChanged(); + } else { + filtersBuilder_.addMessage(index, builderForValue.build()); + } + return this; + } + + /** repeated .FilterNode filters = 2; */ + public Builder addAllFilters(java.lang.Iterable values) { + if (filtersBuilder_ == null) { + ensureFiltersIsMutable(); + com.google.protobuf.AbstractMessageLite.Builder.addAll(values, filters_); + onChanged(); + } else { + filtersBuilder_.addAllMessages(values); + } + return this; + } + + /** repeated .FilterNode filters = 2; */ + public Builder clearFilters() { + if (filtersBuilder_ == null) { + filters_ = java.util.Collections.emptyList(); + bitField0_ = (bitField0_ & ~0x00000001); + onChanged(); + } else { + filtersBuilder_.clear(); + } + return this; + } + + /** repeated .FilterNode filters = 2; */ + public Builder removeFilters(int index) { + if (filtersBuilder_ == null) { + ensureFiltersIsMutable(); + filters_.remove(index); + onChanged(); + } else { + filtersBuilder_.remove(index); + } + return this; + } + + /** repeated .FilterNode filters = 2; */ + public PushDownPb.FilterNode.Builder getFiltersBuilder(int index) { + return getFiltersFieldBuilder().getBuilder(index); + } + + /** repeated .FilterNode filters = 2; */ + public PushDownPb.FilterNodeOrBuilder getFiltersOrBuilder(int index) { + if (filtersBuilder_ == null) { + return filters_.get(index); + } else { + return filtersBuilder_.getMessageOrBuilder(index); + } + } + + /** repeated .FilterNode filters = 2; */ + public java.util.List getFiltersOrBuilderList() { + if (filtersBuilder_ != null) { + return filtersBuilder_.getMessageOrBuilderList(); + } else { + return java.util.Collections.unmodifiableList(filters_); + } + } + + /** repeated .FilterNode filters = 2; */ + public PushDownPb.FilterNode.Builder addFiltersBuilder() { + return getFiltersFieldBuilder().addBuilder(PushDownPb.FilterNode.getDefaultInstance()); + } + + /** repeated .FilterNode filters = 2; */ + public PushDownPb.FilterNode.Builder addFiltersBuilder(int index) { + return getFiltersFieldBuilder() + .addBuilder(index, PushDownPb.FilterNode.getDefaultInstance()); + } + + /** repeated .FilterNode filters = 2; */ + public java.util.List getFiltersBuilderList() { + return getFiltersFieldBuilder().getBuilderList(); + } + + private com.google.protobuf.RepeatedFieldBuilderV3< + PushDownPb.FilterNode, PushDownPb.FilterNode.Builder, PushDownPb.FilterNodeOrBuilder> + getFiltersFieldBuilder() { + if (filtersBuilder_ == null) { + filtersBuilder_ = + new com.google.protobuf.RepeatedFieldBuilderV3< + PushDownPb.FilterNode, + PushDownPb.FilterNode.Builder, + PushDownPb.FilterNodeOrBuilder>( + filters_, ((bitField0_ & 0x00000001) != 0), getParentForChildren(), isClean()); + filters_ = null; + } + return filtersBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.IntList, PushDownPb.IntList.Builder, PushDownPb.IntListOrBuilder> + intContentBuilder_; + + /** .IntList int_content = 3; */ + public boolean hasIntContent() { + return contentCase_ == 3; + } + + /** .IntList int_content = 3; */ + public PushDownPb.IntList getIntContent() { + if (intContentBuilder_ == null) { + if (contentCase_ == 3) { + return (PushDownPb.IntList) content_; + } + return PushDownPb.IntList.getDefaultInstance(); + } else { + if (contentCase_ == 3) { + return intContentBuilder_.getMessage(); + } + return PushDownPb.IntList.getDefaultInstance(); + } + } + + /** .IntList int_content = 3; */ + public Builder setIntContent(PushDownPb.IntList value) { + if (intContentBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + content_ = value; + onChanged(); + } else { + intContentBuilder_.setMessage(value); + } + contentCase_ = 3; + return this; + } + + /** .IntList int_content = 3; */ + public Builder setIntContent(PushDownPb.IntList.Builder builderForValue) { + if (intContentBuilder_ == null) { + content_ = builderForValue.build(); + onChanged(); + } else { + intContentBuilder_.setMessage(builderForValue.build()); + } + contentCase_ = 3; + return this; + } + + /** .IntList int_content = 3; */ + public Builder mergeIntContent(PushDownPb.IntList value) { + if (intContentBuilder_ == null) { + if (contentCase_ == 3 && content_ != PushDownPb.IntList.getDefaultInstance()) { + content_ = + PushDownPb.IntList.newBuilder((PushDownPb.IntList) content_) + .mergeFrom(value) + .buildPartial(); + } else { + content_ = value; + } + onChanged(); + } else { + if (contentCase_ == 3) { + intContentBuilder_.mergeFrom(value); + } + intContentBuilder_.setMessage(value); + } + contentCase_ = 3; + return this; + } + + /** .IntList int_content = 3; */ + public Builder clearIntContent() { + if (intContentBuilder_ == null) { + if (contentCase_ == 3) { + contentCase_ = 0; + content_ = null; + onChanged(); + } + } else { + if (contentCase_ == 3) { + contentCase_ = 0; + content_ = null; + } + intContentBuilder_.clear(); + } + return this; + } + + /** .IntList int_content = 3; */ + public PushDownPb.IntList.Builder getIntContentBuilder() { + return getIntContentFieldBuilder().getBuilder(); + } + + /** .IntList int_content = 3; */ + public PushDownPb.IntListOrBuilder getIntContentOrBuilder() { + if ((contentCase_ == 3) && (intContentBuilder_ != null)) { + return intContentBuilder_.getMessageOrBuilder(); + } else { + if (contentCase_ == 3) { + return (PushDownPb.IntList) content_; + } + return PushDownPb.IntList.getDefaultInstance(); + } + } + + /** .IntList int_content = 3; */ + private com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.IntList, PushDownPb.IntList.Builder, PushDownPb.IntListOrBuilder> + getIntContentFieldBuilder() { + if (intContentBuilder_ == null) { + if (!(contentCase_ == 3)) { + content_ = PushDownPb.IntList.getDefaultInstance(); + } + intContentBuilder_ = + new com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.IntList, PushDownPb.IntList.Builder, PushDownPb.IntListOrBuilder>( + (PushDownPb.IntList) content_, getParentForChildren(), isClean()); + content_ = null; + } + contentCase_ = 3; + onChanged(); + return intContentBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.LongList, PushDownPb.LongList.Builder, PushDownPb.LongListOrBuilder> + longContentBuilder_; + + /** .LongList long_content = 4; */ + public boolean hasLongContent() { + return contentCase_ == 4; + } + + /** .LongList long_content = 4; */ + public PushDownPb.LongList getLongContent() { + if (longContentBuilder_ == null) { + if (contentCase_ == 4) { + return (PushDownPb.LongList) content_; + } + return PushDownPb.LongList.getDefaultInstance(); + } else { + if (contentCase_ == 4) { + return longContentBuilder_.getMessage(); + } + return PushDownPb.LongList.getDefaultInstance(); + } + } + + /** .LongList long_content = 4; */ + public Builder setLongContent(PushDownPb.LongList value) { + if (longContentBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + content_ = value; + onChanged(); + } else { + longContentBuilder_.setMessage(value); + } + contentCase_ = 4; + return this; + } + + /** .LongList long_content = 4; */ + public Builder setLongContent(PushDownPb.LongList.Builder builderForValue) { + if (longContentBuilder_ == null) { + content_ = builderForValue.build(); + onChanged(); + } else { + longContentBuilder_.setMessage(builderForValue.build()); + } + contentCase_ = 4; + return this; + } + + /** .LongList long_content = 4; */ + public Builder mergeLongContent(PushDownPb.LongList value) { + if (longContentBuilder_ == null) { + if (contentCase_ == 4 && content_ != PushDownPb.LongList.getDefaultInstance()) { + content_ = + PushDownPb.LongList.newBuilder((PushDownPb.LongList) content_) + .mergeFrom(value) + .buildPartial(); + } else { + content_ = value; + } + onChanged(); + } else { + if (contentCase_ == 4) { + longContentBuilder_.mergeFrom(value); + } + longContentBuilder_.setMessage(value); + } + contentCase_ = 4; + return this; + } + + /** .LongList long_content = 4; */ + public Builder clearLongContent() { + if (longContentBuilder_ == null) { + if (contentCase_ == 4) { + contentCase_ = 0; + content_ = null; + onChanged(); + } + } else { + if (contentCase_ == 4) { + contentCase_ = 0; + content_ = null; + } + longContentBuilder_.clear(); + } + return this; + } + + /** .LongList long_content = 4; */ + public PushDownPb.LongList.Builder getLongContentBuilder() { + return getLongContentFieldBuilder().getBuilder(); + } + + /** .LongList long_content = 4; */ + public PushDownPb.LongListOrBuilder getLongContentOrBuilder() { + if ((contentCase_ == 4) && (longContentBuilder_ != null)) { + return longContentBuilder_.getMessageOrBuilder(); + } else { + if (contentCase_ == 4) { + return (PushDownPb.LongList) content_; + } + return PushDownPb.LongList.getDefaultInstance(); + } + } + + /** .LongList long_content = 4; */ + private com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.LongList, PushDownPb.LongList.Builder, PushDownPb.LongListOrBuilder> + getLongContentFieldBuilder() { + if (longContentBuilder_ == null) { + if (!(contentCase_ == 4)) { + content_ = PushDownPb.LongList.getDefaultInstance(); + } + longContentBuilder_ = + new com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.LongList, PushDownPb.LongList.Builder, PushDownPb.LongListOrBuilder>( + (PushDownPb.LongList) content_, getParentForChildren(), isClean()); + content_ = null; + } + contentCase_ = 4; + onChanged(); + return longContentBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.StringList, PushDownPb.StringList.Builder, PushDownPb.StringListOrBuilder> + strContentBuilder_; + + /** .StringList str_content = 5; */ + public boolean hasStrContent() { + return contentCase_ == 5; + } + + /** .StringList str_content = 5; */ + public PushDownPb.StringList getStrContent() { + if (strContentBuilder_ == null) { + if (contentCase_ == 5) { + return (PushDownPb.StringList) content_; + } + return PushDownPb.StringList.getDefaultInstance(); + } else { + if (contentCase_ == 5) { + return strContentBuilder_.getMessage(); + } + return PushDownPb.StringList.getDefaultInstance(); + } + } + + /** .StringList str_content = 5; */ + public Builder setStrContent(PushDownPb.StringList value) { + if (strContentBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + content_ = value; + onChanged(); + } else { + strContentBuilder_.setMessage(value); + } + contentCase_ = 5; + return this; + } + + /** .StringList str_content = 5; */ + public Builder setStrContent(PushDownPb.StringList.Builder builderForValue) { + if (strContentBuilder_ == null) { + content_ = builderForValue.build(); + onChanged(); + } else { + strContentBuilder_.setMessage(builderForValue.build()); + } + contentCase_ = 5; + return this; + } + + /** .StringList str_content = 5; */ + public Builder mergeStrContent(PushDownPb.StringList value) { + if (strContentBuilder_ == null) { + if (contentCase_ == 5 && content_ != PushDownPb.StringList.getDefaultInstance()) { + content_ = + PushDownPb.StringList.newBuilder((PushDownPb.StringList) content_) + .mergeFrom(value) + .buildPartial(); + } else { + content_ = value; + } + onChanged(); + } else { + if (contentCase_ == 5) { + strContentBuilder_.mergeFrom(value); + } + strContentBuilder_.setMessage(value); + } + contentCase_ = 5; + return this; + } + + /** .StringList str_content = 5; */ + public Builder clearStrContent() { + if (strContentBuilder_ == null) { + if (contentCase_ == 5) { + contentCase_ = 0; + content_ = null; + onChanged(); + } + } else { + if (contentCase_ == 5) { + contentCase_ = 0; + content_ = null; + } + strContentBuilder_.clear(); + } + return this; + } + + /** .StringList str_content = 5; */ + public PushDownPb.StringList.Builder getStrContentBuilder() { + return getStrContentFieldBuilder().getBuilder(); + } + + /** .StringList str_content = 5; */ + public PushDownPb.StringListOrBuilder getStrContentOrBuilder() { + if ((contentCase_ == 5) && (strContentBuilder_ != null)) { + return strContentBuilder_.getMessageOrBuilder(); + } else { + if (contentCase_ == 5) { + return (PushDownPb.StringList) content_; + } + return PushDownPb.StringList.getDefaultInstance(); + } + } + + /** .StringList str_content = 5; */ + private com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.StringList, PushDownPb.StringList.Builder, PushDownPb.StringListOrBuilder> + getStrContentFieldBuilder() { + if (strContentBuilder_ == null) { + if (!(contentCase_ == 5)) { + content_ = PushDownPb.StringList.getDefaultInstance(); + } + strContentBuilder_ = + new com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.StringList, + PushDownPb.StringList.Builder, + PushDownPb.StringListOrBuilder>( + (PushDownPb.StringList) content_, getParentForChildren(), isClean()); + content_ = null; + } + contentCase_ = 5; + onChanged(); + return strContentBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.BytesList, PushDownPb.BytesList.Builder, PushDownPb.BytesListOrBuilder> + bytesContentBuilder_; + + /** .BytesList bytes_content = 6; */ + public boolean hasBytesContent() { + return contentCase_ == 6; + } + + /** .BytesList bytes_content = 6; */ + public PushDownPb.BytesList getBytesContent() { + if (bytesContentBuilder_ == null) { + if (contentCase_ == 6) { + return (PushDownPb.BytesList) content_; + } + return PushDownPb.BytesList.getDefaultInstance(); + } else { + if (contentCase_ == 6) { + return bytesContentBuilder_.getMessage(); + } + return PushDownPb.BytesList.getDefaultInstance(); + } + } + + /** .BytesList bytes_content = 6; */ + public Builder setBytesContent(PushDownPb.BytesList value) { + if (bytesContentBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + content_ = value; + onChanged(); + } else { + bytesContentBuilder_.setMessage(value); + } + contentCase_ = 6; + return this; + } + + /** .BytesList bytes_content = 6; */ + public Builder setBytesContent(PushDownPb.BytesList.Builder builderForValue) { + if (bytesContentBuilder_ == null) { + content_ = builderForValue.build(); + onChanged(); + } else { + bytesContentBuilder_.setMessage(builderForValue.build()); + } + contentCase_ = 6; + return this; + } + + /** .BytesList bytes_content = 6; */ + public Builder mergeBytesContent(PushDownPb.BytesList value) { + if (bytesContentBuilder_ == null) { + if (contentCase_ == 6 && content_ != PushDownPb.BytesList.getDefaultInstance()) { + content_ = + PushDownPb.BytesList.newBuilder((PushDownPb.BytesList) content_) + .mergeFrom(value) + .buildPartial(); + } else { + content_ = value; + } + onChanged(); + } else { + if (contentCase_ == 6) { + bytesContentBuilder_.mergeFrom(value); + } + bytesContentBuilder_.setMessage(value); + } + contentCase_ = 6; + return this; + } + + /** .BytesList bytes_content = 6; */ + public Builder clearBytesContent() { + if (bytesContentBuilder_ == null) { + if (contentCase_ == 6) { + contentCase_ = 0; + content_ = null; + onChanged(); + } + } else { + if (contentCase_ == 6) { + contentCase_ = 0; + content_ = null; + } + bytesContentBuilder_.clear(); + } + return this; + } + + /** .BytesList bytes_content = 6; */ + public PushDownPb.BytesList.Builder getBytesContentBuilder() { + return getBytesContentFieldBuilder().getBuilder(); + } + + /** .BytesList bytes_content = 6; */ + public PushDownPb.BytesListOrBuilder getBytesContentOrBuilder() { + if ((contentCase_ == 6) && (bytesContentBuilder_ != null)) { + return bytesContentBuilder_.getMessageOrBuilder(); + } else { + if (contentCase_ == 6) { + return (PushDownPb.BytesList) content_; + } + return PushDownPb.BytesList.getDefaultInstance(); + } + } + + /** .BytesList bytes_content = 6; */ + private com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.BytesList, PushDownPb.BytesList.Builder, PushDownPb.BytesListOrBuilder> + getBytesContentFieldBuilder() { + if (bytesContentBuilder_ == null) { + if (!(contentCase_ == 6)) { + content_ = PushDownPb.BytesList.getDefaultInstance(); + } + bytesContentBuilder_ = + new com.google.protobuf.SingleFieldBuilderV3< + PushDownPb.BytesList, + PushDownPb.BytesList.Builder, + PushDownPb.BytesListOrBuilder>( + (PushDownPb.BytesList) content_, getParentForChildren(), isClean()); + content_ = null; + } + contentCase_ = 6; + onChanged(); + return bytesContentBuilder_; + } + + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + // @@protoc_insertion_point(builder_scope:FilterNode) + } - @java.lang.Override - public Builder setRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - int index, java.lang.Object value) { - return super.setRepeatedField(field, index, value); - } + // @@protoc_insertion_point(class_scope:FilterNode) + private static final PushDownPb.FilterNode DEFAULT_INSTANCE; - @java.lang.Override - public Builder addRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.addRepeatedField(field, value); - } + static { + DEFAULT_INSTANCE = new PushDownPb.FilterNode(); + } - @java.lang.Override - public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof PushDownPb.LongList) { - return mergeFrom((PushDownPb.LongList) other); - } else { - super.mergeFrom(other); - return this; - } - } + public static PushDownPb.FilterNode getDefaultInstance() { + return DEFAULT_INSTANCE; + } - public Builder mergeFrom(PushDownPb.LongList other) { - if (other == PushDownPb.LongList.getDefaultInstance()) return this; - if (!other.long_.isEmpty()) { - if (long_.isEmpty()) { - long_ = other.long_; - bitField0_ = (bitField0_ & ~0x00000001); - } else { - ensureLongIsMutable(); - long_.addAll(other.long_); - } - onChanged(); - } - this.mergeUnknownFields(other.unknownFields); - onChanged(); - return this; - } + private static final com.google.protobuf.Parser PARSER = + new com.google.protobuf.AbstractParser() { + @java.lang.Override + public FilterNode parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new FilterNode(input, extensionRegistry); + } + }; - @java.lang.Override - public final boolean isInitialized() { - return true; - } + public static com.google.protobuf.Parser parser() { + return PARSER; + } - @java.lang.Override - public Builder mergeFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - PushDownPb.LongList parsedMessage = null; - try { - parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - parsedMessage = (PushDownPb.LongList) e.getUnfinishedMessage(); - throw e.unwrapIOException(); - } finally { - if (parsedMessage != null) { - mergeFrom(parsedMessage); - } - } - return this; - } + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } - private int bitField0_; + @java.lang.Override + public PushDownPb.FilterNode getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + } - private com.google.protobuf.Internal.LongList long_ = emptyLongList(); + public interface IntListOrBuilder + extends + // @@protoc_insertion_point(interface_extends:IntList) + com.google.protobuf.MessageOrBuilder { - private void ensureLongIsMutable() { - if (!((bitField0_ & 0x00000001) != 0)) { - long_ = mutableCopy(long_); - bitField0_ |= 0x00000001; - } - } + /** repeated int32 int = 1; */ + java.util.List getIntList(); - /** - * repeated int64 long = 1; - */ - public java.util.List - getLongList() { - return ((bitField0_ & 0x00000001) != 0) ? - java.util.Collections.unmodifiableList(long_) : long_; - } + /** repeated int32 int = 1; */ + int getIntCount(); - /** - * repeated int64 long = 1; - */ - public int getLongCount() { - return long_.size(); - } + /** repeated int32 int = 1; */ + int getInt(int index); + } - /** - * repeated int64 long = 1; - */ - public long getLong(int index) { - return long_.getLong(index); - } + /** Protobuf type {@code IntList} */ + public static final class IntList extends com.google.protobuf.GeneratedMessageV3 + implements + // @@protoc_insertion_point(message_implements:IntList) + IntListOrBuilder { + private static final long serialVersionUID = 0L; - /** - * repeated int64 long = 1; - */ - public Builder setLong( - int index, long value) { - ensureLongIsMutable(); - long_.setLong(index, value); - onChanged(); - return this; - } + // Use IntList.newBuilder() to construct. + private IntList(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } - /** - * repeated int64 long = 1; - */ - public Builder addLong(long value) { - ensureLongIsMutable(); - long_.addLong(value); - onChanged(); - return this; - } + private IntList() { + int_ = emptyIntList(); + } - /** - * repeated int64 long = 1; - */ - public Builder addAllLong( - java.lang.Iterable values) { - ensureLongIsMutable(); - com.google.protobuf.AbstractMessageLite.Builder.addAll( - values, long_); - onChanged(); - return this; - } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance(UnusedPrivateParameter unused) { + return new IntList(); + } - /** - * repeated int64 long = 1; - */ - public Builder clearLong() { - long_ = emptyLongList(); - bitField0_ = (bitField0_ & ~0x00000001); - onChanged(); - return this; - } + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet getUnknownFields() { + return this.unknownFields; + } - @java.lang.Override - public final Builder setUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.setUnknownFields(unknownFields); - } + private IntList( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 8: + { + if (!((mutable_bitField0_ & 0x00000001) != 0)) { + int_ = newIntList(); + mutable_bitField0_ |= 0x00000001; + } + int_.addInt(input.readInt32()); + break; + } + case 10: + { + int length = input.readRawVarint32(); + int limit = input.pushLimit(length); + if (!((mutable_bitField0_ & 0x00000001) != 0) && input.getBytesUntilLimit() > 0) { + int_ = newIntList(); + mutable_bitField0_ |= 0x00000001; + } + while (input.getBytesUntilLimit() > 0) { + int_.addInt(input.readInt32()); + } + input.popLimit(limit); + break; + } + default: + { + if (!parseUnknownField(input, unknownFields, extensionRegistry, tag)) { + done = true; + } + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e).setUnfinishedMessage(this); + } finally { + if (((mutable_bitField0_ & 0x00000001) != 0)) { + int_.makeImmutable(); // C + } + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } - @java.lang.Override - public final Builder mergeUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.mergeUnknownFields(unknownFields); - } + public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { + return PushDownPb.internal_static_IntList_descriptor; + } + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return PushDownPb.internal_static_IntList_fieldAccessorTable.ensureFieldAccessorsInitialized( + PushDownPb.IntList.class, PushDownPb.IntList.Builder.class); + } - // @@protoc_insertion_point(builder_scope:LongList) - } + public static final int INT_FIELD_NUMBER = 1; + private com.google.protobuf.Internal.IntList int_; - // @@protoc_insertion_point(class_scope:LongList) - private static final PushDownPb.LongList DEFAULT_INSTANCE; + /** repeated int32 int = 1; */ + public java.util.List getIntList() { + return int_; + } - static { - DEFAULT_INSTANCE = new PushDownPb.LongList(); - } + /** repeated int32 int = 1; */ + public int getIntCount() { + return int_.size(); + } - public static PushDownPb.LongList getDefaultInstance() { - return DEFAULT_INSTANCE; - } + /** repeated int32 int = 1; */ + public int getInt(int index) { + return int_.getInt(index); + } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { - @java.lang.Override - public LongList parsePartialFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return new LongList(input, extensionRegistry); - } - }; + private int intMemoizedSerializedSize = -1; - public static com.google.protobuf.Parser parser() { - return PARSER; - } + private byte memoizedIsInitialized = -1; - @java.lang.Override - public com.google.protobuf.Parser getParserForType() { - return PARSER; - } + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; - @java.lang.Override - public PushDownPb.LongList getDefaultInstanceForType() { - return DEFAULT_INSTANCE; - } + memoizedIsInitialized = 1; + return true; + } + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { + getSerializedSize(); + if (getIntList().size() > 0) { + output.writeUInt32NoTag(10); + output.writeUInt32NoTag(intMemoizedSerializedSize); + } + for (int i = 0; i < int_.size(); i++) { + output.writeInt32NoTag(int_.getInt(i)); + } + unknownFields.writeTo(output); } - public interface StringListOrBuilder extends - // @@protoc_insertion_point(interface_extends:StringList) - com.google.protobuf.MessageOrBuilder { + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + { + int dataSize = 0; + for (int i = 0; i < int_.size(); i++) { + dataSize += com.google.protobuf.CodedOutputStream.computeInt32SizeNoTag(int_.getInt(i)); + } + size += dataSize; + if (!getIntList().isEmpty()) { + size += 1; + size += com.google.protobuf.CodedOutputStream.computeInt32SizeNoTag(dataSize); + } + intMemoizedSerializedSize = dataSize; + } + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } - /** - * repeated string str = 1; - */ - java.util.List - getStrList(); + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof PushDownPb.IntList)) { + return super.equals(obj); + } + PushDownPb.IntList other = (PushDownPb.IntList) obj; + + if (!getIntList().equals(other.getIntList())) return false; + if (!unknownFields.equals(other.unknownFields)) return false; + return true; + } - /** - * repeated string str = 1; - */ - int getStrCount(); + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + if (getIntCount() > 0) { + hash = (37 * hash) + INT_FIELD_NUMBER; + hash = (53 * hash) + getIntList().hashCode(); + } + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } - /** - * repeated string str = 1; - */ - java.lang.String getStr(int index); + public static PushDownPb.IntList parseFrom(java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - /** - * repeated string str = 1; - */ - com.google.protobuf.ByteString - getStrBytes(int index); + public static PushDownPb.IntList parseFrom( + java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); } - /** - * Protobuf type {@code StringList} - */ - public static final class StringList extends - com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:StringList) - StringListOrBuilder { - private static final long serialVersionUID = 0L; + public static PushDownPb.IntList parseFrom(com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - // Use StringList.newBuilder() to construct. - private StringList(com.google.protobuf.GeneratedMessageV3.Builder builder) { - super(builder); - } + public static PushDownPb.IntList parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - private StringList() { - str_ = com.google.protobuf.LazyStringArrayList.EMPTY; - } + public static PushDownPb.IntList parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - @java.lang.Override - @SuppressWarnings({"unused"}) - protected java.lang.Object newInstance( - UnusedPrivateParameter unused) { - return new StringList(); - } + public static PushDownPb.IntList parseFrom( + byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - @java.lang.Override - public final com.google.protobuf.UnknownFieldSet - getUnknownFields() { - return this.unknownFields; - } + public static PushDownPb.IntList parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException(PARSER, input); + } - private StringList( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - this(); - if (extensionRegistry == null) { - throw new java.lang.NullPointerException(); - } - int mutable_bitField0_ = 0; - com.google.protobuf.UnknownFieldSet.Builder unknownFields = - com.google.protobuf.UnknownFieldSet.newBuilder(); - try { - boolean done = false; - while (!done) { - int tag = input.readTag(); - switch (tag) { - case 0: - done = true; - break; - case 10: { - java.lang.String s = input.readStringRequireUtf8(); - if (!((mutable_bitField0_ & 0x00000001) != 0)) { - str_ = new com.google.protobuf.LazyStringArrayList(); - mutable_bitField0_ |= 0x00000001; - } - str_.add(s); - break; - } - default: { - if (!parseUnknownField( - input, unknownFields, extensionRegistry, tag)) { - done = true; - } - break; - } - } - } - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - throw e.setUnfinishedMessage(this); - } catch (java.io.IOException e) { - throw new com.google.protobuf.InvalidProtocolBufferException( - e).setUnfinishedMessage(this); - } finally { - if (((mutable_bitField0_ & 0x00000001) != 0)) { - str_ = str_.getUnmodifiableView(); - } - this.unknownFields = unknownFields.build(); - makeExtensionsImmutable(); - } - } + public static PushDownPb.IntList parseFrom( + java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException( + PARSER, input, extensionRegistry); + } - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return PushDownPb.internal_static_StringList_descriptor; - } + public static PushDownPb.IntList parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseDelimitedWithIOException(PARSER, input); + } - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return PushDownPb.internal_static_StringList_fieldAccessorTable - .ensureFieldAccessorsInitialized( - PushDownPb.StringList.class, PushDownPb.StringList.Builder.class); - } + public static PushDownPb.IntList parseDelimitedFrom( + java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseDelimitedWithIOException( + PARSER, input, extensionRegistry); + } - public static final int STR_FIELD_NUMBER = 1; - private com.google.protobuf.LazyStringList str_; + public static PushDownPb.IntList parseFrom(com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException(PARSER, input); + } - /** - * repeated string str = 1; - */ - public com.google.protobuf.ProtocolStringList - getStrList() { - return str_; - } + public static PushDownPb.IntList parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException( + PARSER, input, extensionRegistry); + } - /** - * repeated string str = 1; - */ - public int getStrCount() { - return str_.size(); - } + @java.lang.Override + public Builder newBuilderForType() { + return newBuilder(); + } - /** - * repeated string str = 1; - */ - public java.lang.String getStr(int index) { - return str_.get(index); - } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } - /** - * repeated string str = 1; - */ - public com.google.protobuf.ByteString - getStrBytes(int index) { - return str_.getByteString(index); - } + public static Builder newBuilder(PushDownPb.IntList prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } - private byte memoizedIsInitialized = -1; + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE ? new Builder() : new Builder().mergeFrom(this); + } - @java.lang.Override - public final boolean isInitialized() { - byte isInitialized = memoizedIsInitialized; - if (isInitialized == 1) return true; - if (isInitialized == 0) return false; + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } - memoizedIsInitialized = 1; - return true; - } + /** Protobuf type {@code IntList} */ + public static final class Builder + extends com.google.protobuf.GeneratedMessageV3.Builder + implements + // @@protoc_insertion_point(builder_implements:IntList) + PushDownPb.IntListOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { + return PushDownPb.internal_static_IntList_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return PushDownPb.internal_static_IntList_fieldAccessorTable + .ensureFieldAccessorsInitialized( + PushDownPb.IntList.class, PushDownPb.IntList.Builder.class); + } + + // Construct using PushDownPb.IntList.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder(com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + + private void maybeForceBuilderInitialization() {} + + @java.lang.Override + public Builder clear() { + super.clear(); + int_ = emptyIntList(); + bitField0_ = (bitField0_ & ~0x00000001); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { + return PushDownPb.internal_static_IntList_descriptor; + } + + @java.lang.Override + public PushDownPb.IntList getDefaultInstanceForType() { + return PushDownPb.IntList.getDefaultInstance(); + } + + @java.lang.Override + public PushDownPb.IntList build() { + PushDownPb.IntList result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public PushDownPb.IntList buildPartial() { + PushDownPb.IntList result = new PushDownPb.IntList(this); + if (((bitField0_ & 0x00000001) != 0)) { + int_.makeImmutable(); + bitField0_ = (bitField0_ & ~0x00000001); + } + result.int_ = int_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, java.lang.Object value) { + return super.setField(field, value); + } + + @java.lang.Override + public Builder clearField(com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + + @java.lang.Override + public Builder clearOneof(com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, + java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, java.lang.Object value) { + return super.addRepeatedField(field, value); + } + + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof PushDownPb.IntList) { + return mergeFrom((PushDownPb.IntList) other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(PushDownPb.IntList other) { + if (other == PushDownPb.IntList.getDefaultInstance()) return this; + if (!other.int_.isEmpty()) { + if (int_.isEmpty()) { + int_ = other.int_; + bitField0_ = (bitField0_ & ~0x00000001); + } else { + ensureIntIsMutable(); + int_.addAll(other.int_); + } + onChanged(); + } + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + PushDownPb.IntList parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (PushDownPb.IntList) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + + private int bitField0_; + + private com.google.protobuf.Internal.IntList int_ = emptyIntList(); + + private void ensureIntIsMutable() { + if (!((bitField0_ & 0x00000001) != 0)) { + int_ = mutableCopy(int_); + bitField0_ |= 0x00000001; + } + } + + /** repeated int32 int = 1; */ + public java.util.List getIntList() { + return ((bitField0_ & 0x00000001) != 0) + ? java.util.Collections.unmodifiableList(int_) + : int_; + } + + /** repeated int32 int = 1; */ + public int getIntCount() { + return int_.size(); + } + + /** repeated int32 int = 1; */ + public int getInt(int index) { + return int_.getInt(index); + } + + /** repeated int32 int = 1; */ + public Builder setInt(int index, int value) { + ensureIntIsMutable(); + int_.setInt(index, value); + onChanged(); + return this; + } + + /** repeated int32 int = 1; */ + public Builder addInt(int value) { + ensureIntIsMutable(); + int_.addInt(value); + onChanged(); + return this; + } + + /** repeated int32 int = 1; */ + public Builder addAllInt(java.lang.Iterable values) { + ensureIntIsMutable(); + com.google.protobuf.AbstractMessageLite.Builder.addAll(values, int_); + onChanged(); + return this; + } + + /** repeated int32 int = 1; */ + public Builder clearInt() { + int_ = emptyIntList(); + bitField0_ = (bitField0_ & ~0x00000001); + onChanged(); + return this; + } + + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + // @@protoc_insertion_point(builder_scope:IntList) + } - @java.lang.Override - public void writeTo(com.google.protobuf.CodedOutputStream output) - throws java.io.IOException { - for (int i = 0; i < str_.size(); i++) { - com.google.protobuf.GeneratedMessageV3.writeString(output, 1, str_.getRaw(i)); - } - unknownFields.writeTo(output); - } + // @@protoc_insertion_point(class_scope:IntList) + private static final PushDownPb.IntList DEFAULT_INSTANCE; - @java.lang.Override - public int getSerializedSize() { - int size = memoizedSize; - if (size != -1) return size; + static { + DEFAULT_INSTANCE = new PushDownPb.IntList(); + } - size = 0; - { - int dataSize = 0; - for (int i = 0; i < str_.size(); i++) { - dataSize += computeStringSizeNoTag(str_.getRaw(i)); - } - size += dataSize; - size += 1 * getStrList().size(); - } - size += unknownFields.getSerializedSize(); - memoizedSize = size; - return size; - } + public static PushDownPb.IntList getDefaultInstance() { + return DEFAULT_INSTANCE; + } - @java.lang.Override - public boolean equals(final java.lang.Object obj) { - if (obj == this) { - return true; - } - if (!(obj instanceof PushDownPb.StringList)) { - return super.equals(obj); - } - PushDownPb.StringList other = (PushDownPb.StringList) obj; + private static final com.google.protobuf.Parser PARSER = + new com.google.protobuf.AbstractParser() { + @java.lang.Override + public IntList parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new IntList(input, extensionRegistry); + } + }; - if (!getStrList() - .equals(other.getStrList())) return false; - if (!unknownFields.equals(other.unknownFields)) return false; - return true; - } + public static com.google.protobuf.Parser parser() { + return PARSER; + } - @java.lang.Override - public int hashCode() { - if (memoizedHashCode != 0) { - return memoizedHashCode; - } - int hash = 41; - hash = (19 * hash) + getDescriptor().hashCode(); - if (getStrCount() > 0) { - hash = (37 * hash) + STR_FIELD_NUMBER; - hash = (53 * hash) + getStrList().hashCode(); - } - hash = (29 * hash) + unknownFields.hashCode(); - memoizedHashCode = hash; - return hash; - } + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } - public static PushDownPb.StringList parseFrom( - java.nio.ByteBuffer data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } + @java.lang.Override + public PushDownPb.IntList getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + } - public static PushDownPb.StringList parseFrom( - java.nio.ByteBuffer data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } + public interface LongListOrBuilder + extends + // @@protoc_insertion_point(interface_extends:LongList) + com.google.protobuf.MessageOrBuilder { - public static PushDownPb.StringList parseFrom( - com.google.protobuf.ByteString data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } + /** repeated int64 long = 1; */ + java.util.List getLongList(); - public static PushDownPb.StringList parseFrom( - com.google.protobuf.ByteString data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } + /** repeated int64 long = 1; */ + int getLongCount(); - public static PushDownPb.StringList parseFrom(byte[] data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } + /** repeated int64 long = 1; */ + long getLong(int index); + } - public static PushDownPb.StringList parseFrom( - byte[] data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } + /** Protobuf type {@code LongList} */ + public static final class LongList extends com.google.protobuf.GeneratedMessageV3 + implements + // @@protoc_insertion_point(message_implements:LongList) + LongListOrBuilder { + private static final long serialVersionUID = 0L; - public static PushDownPb.StringList parseFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } + // Use LongList.newBuilder() to construct. + private LongList(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } - public static PushDownPb.StringList parseFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } + private LongList() { + long_ = emptyLongList(); + } - public static PushDownPb.StringList parseDelimitedFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input); - } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance(UnusedPrivateParameter unused) { + return new LongList(); + } - public static PushDownPb.StringList parseDelimitedFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input, extensionRegistry); - } + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet getUnknownFields() { + return this.unknownFields; + } - public static PushDownPb.StringList parseFrom( - com.google.protobuf.CodedInputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } + private LongList( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 8: + { + if (!((mutable_bitField0_ & 0x00000001) != 0)) { + long_ = newLongList(); + mutable_bitField0_ |= 0x00000001; + } + long_.addLong(input.readInt64()); + break; + } + case 10: + { + int length = input.readRawVarint32(); + int limit = input.pushLimit(length); + if (!((mutable_bitField0_ & 0x00000001) != 0) && input.getBytesUntilLimit() > 0) { + long_ = newLongList(); + mutable_bitField0_ |= 0x00000001; + } + while (input.getBytesUntilLimit() > 0) { + long_.addLong(input.readInt64()); + } + input.popLimit(limit); + break; + } + default: + { + if (!parseUnknownField(input, unknownFields, extensionRegistry, tag)) { + done = true; + } + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e).setUnfinishedMessage(this); + } finally { + if (((mutable_bitField0_ & 0x00000001) != 0)) { + long_.makeImmutable(); // C + } + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } - public static PushDownPb.StringList parseFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } + public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { + return PushDownPb.internal_static_LongList_descriptor; + } - @java.lang.Override - public Builder newBuilderForType() { - return newBuilder(); - } + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return PushDownPb.internal_static_LongList_fieldAccessorTable.ensureFieldAccessorsInitialized( + PushDownPb.LongList.class, PushDownPb.LongList.Builder.class); + } - public static Builder newBuilder() { - return DEFAULT_INSTANCE.toBuilder(); - } + public static final int LONG_FIELD_NUMBER = 1; + private com.google.protobuf.Internal.LongList long_; - public static Builder newBuilder(PushDownPb.StringList prototype) { - return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); - } + /** repeated int64 long = 1; */ + public java.util.List getLongList() { + return long_; + } - @java.lang.Override - public Builder toBuilder() { - return this == DEFAULT_INSTANCE - ? new Builder() : new Builder().mergeFrom(this); - } + /** repeated int64 long = 1; */ + public int getLongCount() { + return long_.size(); + } - @java.lang.Override - protected Builder newBuilderForType( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - Builder builder = new Builder(parent); - return builder; - } + /** repeated int64 long = 1; */ + public long getLong(int index) { + return long_.getLong(index); + } - /** - * Protobuf type {@code StringList} - */ - public static final class Builder extends - com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:StringList) - PushDownPb.StringListOrBuilder { - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return PushDownPb.internal_static_StringList_descriptor; - } + private int longMemoizedSerializedSize = -1; - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return PushDownPb.internal_static_StringList_fieldAccessorTable - .ensureFieldAccessorsInitialized( - PushDownPb.StringList.class, PushDownPb.StringList.Builder.class); - } + private byte memoizedIsInitialized = -1; - // Construct using PushDownPb.StringList.newBuilder() - private Builder() { - maybeForceBuilderInitialization(); - } + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; - private Builder( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - super(parent); - maybeForceBuilderInitialization(); - } + memoizedIsInitialized = 1; + return true; + } - private void maybeForceBuilderInitialization() { - } + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { + getSerializedSize(); + if (getLongList().size() > 0) { + output.writeUInt32NoTag(10); + output.writeUInt32NoTag(longMemoizedSerializedSize); + } + for (int i = 0; i < long_.size(); i++) { + output.writeInt64NoTag(long_.getLong(i)); + } + unknownFields.writeTo(output); + } - @java.lang.Override - public Builder clear() { - super.clear(); - str_ = com.google.protobuf.LazyStringArrayList.EMPTY; - bitField0_ = (bitField0_ & ~0x00000001); - return this; - } + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + { + int dataSize = 0; + for (int i = 0; i < long_.size(); i++) { + dataSize += com.google.protobuf.CodedOutputStream.computeInt64SizeNoTag(long_.getLong(i)); + } + size += dataSize; + if (!getLongList().isEmpty()) { + size += 1; + size += com.google.protobuf.CodedOutputStream.computeInt32SizeNoTag(dataSize); + } + longMemoizedSerializedSize = dataSize; + } + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } - @java.lang.Override - public com.google.protobuf.Descriptors.Descriptor - getDescriptorForType() { - return PushDownPb.internal_static_StringList_descriptor; - } + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof PushDownPb.LongList)) { + return super.equals(obj); + } + PushDownPb.LongList other = (PushDownPb.LongList) obj; + + if (!getLongList().equals(other.getLongList())) return false; + if (!unknownFields.equals(other.unknownFields)) return false; + return true; + } - @java.lang.Override - public PushDownPb.StringList getDefaultInstanceForType() { - return PushDownPb.StringList.getDefaultInstance(); - } + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + if (getLongCount() > 0) { + hash = (37 * hash) + LONG_FIELD_NUMBER; + hash = (53 * hash) + getLongList().hashCode(); + } + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } - @java.lang.Override - public PushDownPb.StringList build() { - PushDownPb.StringList result = buildPartial(); - if (!result.isInitialized()) { - throw newUninitializedMessageException(result); - } - return result; - } + public static PushDownPb.LongList parseFrom(java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - @java.lang.Override - public PushDownPb.StringList buildPartial() { - PushDownPb.StringList result = new PushDownPb.StringList(this); - if (((bitField0_ & 0x00000001) != 0)) { - str_ = str_.getUnmodifiableView(); - bitField0_ = (bitField0_ & ~0x00000001); - } - result.str_ = str_; - onBuilt(); - return result; - } + public static PushDownPb.LongList parseFrom( + java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - @java.lang.Override - public Builder clone() { - return super.clone(); - } + public static PushDownPb.LongList parseFrom(com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - @java.lang.Override - public Builder setField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.setField(field, value); - } + public static PushDownPb.LongList parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - @java.lang.Override - public Builder clearField( - com.google.protobuf.Descriptors.FieldDescriptor field) { - return super.clearField(field); - } + public static PushDownPb.LongList parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - @java.lang.Override - public Builder clearOneof( - com.google.protobuf.Descriptors.OneofDescriptor oneof) { - return super.clearOneof(oneof); - } + public static PushDownPb.LongList parseFrom( + byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - @java.lang.Override - public Builder setRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - int index, java.lang.Object value) { - return super.setRepeatedField(field, index, value); - } + public static PushDownPb.LongList parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException(PARSER, input); + } - @java.lang.Override - public Builder addRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.addRepeatedField(field, value); - } + public static PushDownPb.LongList parseFrom( + java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException( + PARSER, input, extensionRegistry); + } - @java.lang.Override - public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof PushDownPb.StringList) { - return mergeFrom((PushDownPb.StringList) other); - } else { - super.mergeFrom(other); - return this; - } - } + public static PushDownPb.LongList parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseDelimitedWithIOException(PARSER, input); + } - public Builder mergeFrom(PushDownPb.StringList other) { - if (other == PushDownPb.StringList.getDefaultInstance()) return this; - if (!other.str_.isEmpty()) { - if (str_.isEmpty()) { - str_ = other.str_; - bitField0_ = (bitField0_ & ~0x00000001); - } else { - ensureStrIsMutable(); - str_.addAll(other.str_); - } - onChanged(); - } - this.mergeUnknownFields(other.unknownFields); - onChanged(); - return this; - } + public static PushDownPb.LongList parseDelimitedFrom( + java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseDelimitedWithIOException( + PARSER, input, extensionRegistry); + } - @java.lang.Override - public final boolean isInitialized() { - return true; - } + public static PushDownPb.LongList parseFrom(com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException(PARSER, input); + } - @java.lang.Override - public Builder mergeFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - PushDownPb.StringList parsedMessage = null; - try { - parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - parsedMessage = (PushDownPb.StringList) e.getUnfinishedMessage(); - throw e.unwrapIOException(); - } finally { - if (parsedMessage != null) { - mergeFrom(parsedMessage); - } - } - return this; - } + public static PushDownPb.LongList parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException( + PARSER, input, extensionRegistry); + } - private int bitField0_; + @java.lang.Override + public Builder newBuilderForType() { + return newBuilder(); + } - private com.google.protobuf.LazyStringList str_ = com.google.protobuf.LazyStringArrayList.EMPTY; + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } - private void ensureStrIsMutable() { - if (!((bitField0_ & 0x00000001) != 0)) { - str_ = new com.google.protobuf.LazyStringArrayList(str_); - bitField0_ |= 0x00000001; - } - } + public static Builder newBuilder(PushDownPb.LongList prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } - /** - * repeated string str = 1; - */ - public com.google.protobuf.ProtocolStringList - getStrList() { - return str_.getUnmodifiableView(); - } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE ? new Builder() : new Builder().mergeFrom(this); + } - /** - * repeated string str = 1; - */ - public int getStrCount() { - return str_.size(); - } + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } - /** - * repeated string str = 1; - */ - public java.lang.String getStr(int index) { - return str_.get(index); - } + /** Protobuf type {@code LongList} */ + public static final class Builder + extends com.google.protobuf.GeneratedMessageV3.Builder + implements + // @@protoc_insertion_point(builder_implements:LongList) + PushDownPb.LongListOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { + return PushDownPb.internal_static_LongList_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return PushDownPb.internal_static_LongList_fieldAccessorTable + .ensureFieldAccessorsInitialized( + PushDownPb.LongList.class, PushDownPb.LongList.Builder.class); + } + + // Construct using PushDownPb.LongList.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder(com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + + private void maybeForceBuilderInitialization() {} + + @java.lang.Override + public Builder clear() { + super.clear(); + long_ = emptyLongList(); + bitField0_ = (bitField0_ & ~0x00000001); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { + return PushDownPb.internal_static_LongList_descriptor; + } + + @java.lang.Override + public PushDownPb.LongList getDefaultInstanceForType() { + return PushDownPb.LongList.getDefaultInstance(); + } + + @java.lang.Override + public PushDownPb.LongList build() { + PushDownPb.LongList result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public PushDownPb.LongList buildPartial() { + PushDownPb.LongList result = new PushDownPb.LongList(this); + if (((bitField0_ & 0x00000001) != 0)) { + long_.makeImmutable(); + bitField0_ = (bitField0_ & ~0x00000001); + } + result.long_ = long_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, java.lang.Object value) { + return super.setField(field, value); + } + + @java.lang.Override + public Builder clearField(com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + + @java.lang.Override + public Builder clearOneof(com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, + java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, java.lang.Object value) { + return super.addRepeatedField(field, value); + } + + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof PushDownPb.LongList) { + return mergeFrom((PushDownPb.LongList) other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(PushDownPb.LongList other) { + if (other == PushDownPb.LongList.getDefaultInstance()) return this; + if (!other.long_.isEmpty()) { + if (long_.isEmpty()) { + long_ = other.long_; + bitField0_ = (bitField0_ & ~0x00000001); + } else { + ensureLongIsMutable(); + long_.addAll(other.long_); + } + onChanged(); + } + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + PushDownPb.LongList parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (PushDownPb.LongList) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + + private int bitField0_; + + private com.google.protobuf.Internal.LongList long_ = emptyLongList(); + + private void ensureLongIsMutable() { + if (!((bitField0_ & 0x00000001) != 0)) { + long_ = mutableCopy(long_); + bitField0_ |= 0x00000001; + } + } + + /** repeated int64 long = 1; */ + public java.util.List getLongList() { + return ((bitField0_ & 0x00000001) != 0) + ? java.util.Collections.unmodifiableList(long_) + : long_; + } + + /** repeated int64 long = 1; */ + public int getLongCount() { + return long_.size(); + } + + /** repeated int64 long = 1; */ + public long getLong(int index) { + return long_.getLong(index); + } + + /** repeated int64 long = 1; */ + public Builder setLong(int index, long value) { + ensureLongIsMutable(); + long_.setLong(index, value); + onChanged(); + return this; + } + + /** repeated int64 long = 1; */ + public Builder addLong(long value) { + ensureLongIsMutable(); + long_.addLong(value); + onChanged(); + return this; + } + + /** repeated int64 long = 1; */ + public Builder addAllLong(java.lang.Iterable values) { + ensureLongIsMutable(); + com.google.protobuf.AbstractMessageLite.Builder.addAll(values, long_); + onChanged(); + return this; + } + + /** repeated int64 long = 1; */ + public Builder clearLong() { + long_ = emptyLongList(); + bitField0_ = (bitField0_ & ~0x00000001); + onChanged(); + return this; + } + + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + // @@protoc_insertion_point(builder_scope:LongList) + } - /** - * repeated string str = 1; - */ - public com.google.protobuf.ByteString - getStrBytes(int index) { - return str_.getByteString(index); - } + // @@protoc_insertion_point(class_scope:LongList) + private static final PushDownPb.LongList DEFAULT_INSTANCE; - /** - * repeated string str = 1; - */ - public Builder setStr( - int index, java.lang.String value) { - if (value == null) { - throw new NullPointerException(); - } - ensureStrIsMutable(); - str_.set(index, value); - onChanged(); - return this; - } + static { + DEFAULT_INSTANCE = new PushDownPb.LongList(); + } - /** - * repeated string str = 1; - */ - public Builder addStr( - java.lang.String value) { - if (value == null) { - throw new NullPointerException(); - } - ensureStrIsMutable(); - str_.add(value); - onChanged(); - return this; - } + public static PushDownPb.LongList getDefaultInstance() { + return DEFAULT_INSTANCE; + } - /** - * repeated string str = 1; - */ - public Builder addAllStr( - java.lang.Iterable values) { - ensureStrIsMutable(); - com.google.protobuf.AbstractMessageLite.Builder.addAll( - values, str_); - onChanged(); - return this; - } + private static final com.google.protobuf.Parser PARSER = + new com.google.protobuf.AbstractParser() { + @java.lang.Override + public LongList parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new LongList(input, extensionRegistry); + } + }; - /** - * repeated string str = 1; - */ - public Builder clearStr() { - str_ = com.google.protobuf.LazyStringArrayList.EMPTY; - bitField0_ = (bitField0_ & ~0x00000001); - onChanged(); - return this; - } + public static com.google.protobuf.Parser parser() { + return PARSER; + } - /** - * repeated string str = 1; - */ - public Builder addStrBytes( - com.google.protobuf.ByteString value) { - if (value == null) { - throw new NullPointerException(); - } - checkByteStringIsUtf8(value); - ensureStrIsMutable(); - str_.add(value); - onChanged(); - return this; - } + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } - @java.lang.Override - public final Builder setUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.setUnknownFields(unknownFields); - } + @java.lang.Override + public PushDownPb.LongList getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + } - @java.lang.Override - public final Builder mergeUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.mergeUnknownFields(unknownFields); - } + public interface StringListOrBuilder + extends + // @@protoc_insertion_point(interface_extends:StringList) + com.google.protobuf.MessageOrBuilder { + /** repeated string str = 1; */ + java.util.List getStrList(); - // @@protoc_insertion_point(builder_scope:StringList) - } + /** repeated string str = 1; */ + int getStrCount(); - // @@protoc_insertion_point(class_scope:StringList) - private static final PushDownPb.StringList DEFAULT_INSTANCE; + /** repeated string str = 1; */ + java.lang.String getStr(int index); - static { - DEFAULT_INSTANCE = new PushDownPb.StringList(); - } + /** repeated string str = 1; */ + com.google.protobuf.ByteString getStrBytes(int index); + } - public static PushDownPb.StringList getDefaultInstance() { - return DEFAULT_INSTANCE; - } + /** Protobuf type {@code StringList} */ + public static final class StringList extends com.google.protobuf.GeneratedMessageV3 + implements + // @@protoc_insertion_point(message_implements:StringList) + StringListOrBuilder { + private static final long serialVersionUID = 0L; - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { - @java.lang.Override - public StringList parsePartialFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return new StringList(input, extensionRegistry); - } - }; + // Use StringList.newBuilder() to construct. + private StringList(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } - public static com.google.protobuf.Parser parser() { - return PARSER; - } + private StringList() { + str_ = com.google.protobuf.LazyStringArrayList.EMPTY; + } - @java.lang.Override - public com.google.protobuf.Parser getParserForType() { - return PARSER; - } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance(UnusedPrivateParameter unused) { + return new StringList(); + } - @java.lang.Override - public PushDownPb.StringList getDefaultInstanceForType() { - return DEFAULT_INSTANCE; - } + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet getUnknownFields() { + return this.unknownFields; + } + private StringList( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: + { + java.lang.String s = input.readStringRequireUtf8(); + if (!((mutable_bitField0_ & 0x00000001) != 0)) { + str_ = new com.google.protobuf.LazyStringArrayList(); + mutable_bitField0_ |= 0x00000001; + } + str_.add(s); + break; + } + default: + { + if (!parseUnknownField(input, unknownFields, extensionRegistry, tag)) { + done = true; + } + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e).setUnfinishedMessage(this); + } finally { + if (((mutable_bitField0_ & 0x00000001) != 0)) { + str_ = str_.getUnmodifiableView(); + } + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } } - public interface BytesListOrBuilder extends - // @@protoc_insertion_point(interface_extends:BytesList) - com.google.protobuf.MessageOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { + return PushDownPb.internal_static_StringList_descriptor; + } - /** - * repeated bytes bytes = 1; - */ - java.util.List getBytesList(); + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return PushDownPb.internal_static_StringList_fieldAccessorTable + .ensureFieldAccessorsInitialized( + PushDownPb.StringList.class, PushDownPb.StringList.Builder.class); + } - /** - * repeated bytes bytes = 1; - */ - int getBytesCount(); + public static final int STR_FIELD_NUMBER = 1; + private com.google.protobuf.LazyStringList str_; - /** - * repeated bytes bytes = 1; - */ - com.google.protobuf.ByteString getBytes(int index); + /** repeated string str = 1; */ + public com.google.protobuf.ProtocolStringList getStrList() { + return str_; } - /** - * Protobuf type {@code BytesList} - */ - public static final class BytesList extends - com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:BytesList) - BytesListOrBuilder { - private static final long serialVersionUID = 0L; + /** repeated string str = 1; */ + public int getStrCount() { + return str_.size(); + } - // Use BytesList.newBuilder() to construct. - private BytesList(com.google.protobuf.GeneratedMessageV3.Builder builder) { - super(builder); - } + /** repeated string str = 1; */ + public java.lang.String getStr(int index) { + return str_.get(index); + } - private BytesList() { - bytes_ = java.util.Collections.emptyList(); - } + /** repeated string str = 1; */ + public com.google.protobuf.ByteString getStrBytes(int index) { + return str_.getByteString(index); + } - @java.lang.Override - @SuppressWarnings({"unused"}) - protected java.lang.Object newInstance( - UnusedPrivateParameter unused) { - return new BytesList(); - } + private byte memoizedIsInitialized = -1; - @java.lang.Override - public final com.google.protobuf.UnknownFieldSet - getUnknownFields() { - return this.unknownFields; - } + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; - private BytesList( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - this(); - if (extensionRegistry == null) { - throw new java.lang.NullPointerException(); - } - int mutable_bitField0_ = 0; - com.google.protobuf.UnknownFieldSet.Builder unknownFields = - com.google.protobuf.UnknownFieldSet.newBuilder(); - try { - boolean done = false; - while (!done) { - int tag = input.readTag(); - switch (tag) { - case 0: - done = true; - break; - case 10: { - if (!((mutable_bitField0_ & 0x00000001) != 0)) { - bytes_ = new java.util.ArrayList(); - mutable_bitField0_ |= 0x00000001; - } - bytes_.add(input.readBytes()); - break; - } - default: { - if (!parseUnknownField( - input, unknownFields, extensionRegistry, tag)) { - done = true; - } - break; - } - } - } - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - throw e.setUnfinishedMessage(this); - } catch (java.io.IOException e) { - throw new com.google.protobuf.InvalidProtocolBufferException( - e).setUnfinishedMessage(this); - } finally { - if (((mutable_bitField0_ & 0x00000001) != 0)) { - bytes_ = java.util.Collections.unmodifiableList(bytes_); // C - } - this.unknownFields = unknownFields.build(); - makeExtensionsImmutable(); - } - } + memoizedIsInitialized = 1; + return true; + } - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return PushDownPb.internal_static_BytesList_descriptor; - } + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { + for (int i = 0; i < str_.size(); i++) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 1, str_.getRaw(i)); + } + unknownFields.writeTo(output); + } - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return PushDownPb.internal_static_BytesList_fieldAccessorTable - .ensureFieldAccessorsInitialized( - PushDownPb.BytesList.class, PushDownPb.BytesList.Builder.class); - } + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + { + int dataSize = 0; + for (int i = 0; i < str_.size(); i++) { + dataSize += computeStringSizeNoTag(str_.getRaw(i)); + } + size += dataSize; + size += 1 * getStrList().size(); + } + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } - public static final int BYTES_FIELD_NUMBER = 1; - private java.util.List bytes_; + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof PushDownPb.StringList)) { + return super.equals(obj); + } + PushDownPb.StringList other = (PushDownPb.StringList) obj; + + if (!getStrList().equals(other.getStrList())) return false; + if (!unknownFields.equals(other.unknownFields)) return false; + return true; + } - /** - * repeated bytes bytes = 1; - */ - public java.util.List - getBytesList() { - return bytes_; - } + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + if (getStrCount() > 0) { + hash = (37 * hash) + STR_FIELD_NUMBER; + hash = (53 * hash) + getStrList().hashCode(); + } + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } - /** - * repeated bytes bytes = 1; - */ - public int getBytesCount() { - return bytes_.size(); - } + public static PushDownPb.StringList parseFrom(java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - /** - * repeated bytes bytes = 1; - */ - public com.google.protobuf.ByteString getBytes(int index) { - return bytes_.get(index); - } + public static PushDownPb.StringList parseFrom( + java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - private byte memoizedIsInitialized = -1; + public static PushDownPb.StringList parseFrom(com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - @java.lang.Override - public final boolean isInitialized() { - byte isInitialized = memoizedIsInitialized; - if (isInitialized == 1) return true; - if (isInitialized == 0) return false; + public static PushDownPb.StringList parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - memoizedIsInitialized = 1; - return true; - } + public static PushDownPb.StringList parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - @java.lang.Override - public void writeTo(com.google.protobuf.CodedOutputStream output) - throws java.io.IOException { - for (int i = 0; i < bytes_.size(); i++) { - output.writeBytes(1, bytes_.get(i)); - } - unknownFields.writeTo(output); - } + public static PushDownPb.StringList parseFrom( + byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - @java.lang.Override - public int getSerializedSize() { - int size = memoizedSize; - if (size != -1) return size; + public static PushDownPb.StringList parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException(PARSER, input); + } - size = 0; - { - int dataSize = 0; - for (int i = 0; i < bytes_.size(); i++) { - dataSize += com.google.protobuf.CodedOutputStream - .computeBytesSizeNoTag(bytes_.get(i)); - } - size += dataSize; - size += 1 * getBytesList().size(); - } - size += unknownFields.getSerializedSize(); - memoizedSize = size; - return size; - } + public static PushDownPb.StringList parseFrom( + java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException( + PARSER, input, extensionRegistry); + } - @java.lang.Override - public boolean equals(final java.lang.Object obj) { - if (obj == this) { - return true; - } - if (!(obj instanceof PushDownPb.BytesList)) { - return super.equals(obj); - } - PushDownPb.BytesList other = (PushDownPb.BytesList) obj; + public static PushDownPb.StringList parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseDelimitedWithIOException(PARSER, input); + } - if (!getBytesList() - .equals(other.getBytesList())) return false; - if (!unknownFields.equals(other.unknownFields)) return false; - return true; - } + public static PushDownPb.StringList parseDelimitedFrom( + java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseDelimitedWithIOException( + PARSER, input, extensionRegistry); + } - @java.lang.Override - public int hashCode() { - if (memoizedHashCode != 0) { - return memoizedHashCode; - } - int hash = 41; - hash = (19 * hash) + getDescriptor().hashCode(); - if (getBytesCount() > 0) { - hash = (37 * hash) + BYTES_FIELD_NUMBER; - hash = (53 * hash) + getBytesList().hashCode(); - } - hash = (29 * hash) + unknownFields.hashCode(); - memoizedHashCode = hash; - return hash; - } + public static PushDownPb.StringList parseFrom(com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException(PARSER, input); + } - public static PushDownPb.BytesList parseFrom( - java.nio.ByteBuffer data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } + public static PushDownPb.StringList parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException( + PARSER, input, extensionRegistry); + } - public static PushDownPb.BytesList parseFrom( - java.nio.ByteBuffer data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } + @java.lang.Override + public Builder newBuilderForType() { + return newBuilder(); + } - public static PushDownPb.BytesList parseFrom( - com.google.protobuf.ByteString data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } - public static PushDownPb.BytesList parseFrom( - com.google.protobuf.ByteString data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } + public static Builder newBuilder(PushDownPb.StringList prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } - public static PushDownPb.BytesList parseFrom(byte[] data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE ? new Builder() : new Builder().mergeFrom(this); + } - public static PushDownPb.BytesList parseFrom( - byte[] data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } - public static PushDownPb.BytesList parseFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } + /** Protobuf type {@code StringList} */ + public static final class Builder + extends com.google.protobuf.GeneratedMessageV3.Builder + implements + // @@protoc_insertion_point(builder_implements:StringList) + PushDownPb.StringListOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { + return PushDownPb.internal_static_StringList_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return PushDownPb.internal_static_StringList_fieldAccessorTable + .ensureFieldAccessorsInitialized( + PushDownPb.StringList.class, PushDownPb.StringList.Builder.class); + } + + // Construct using PushDownPb.StringList.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder(com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + + private void maybeForceBuilderInitialization() {} + + @java.lang.Override + public Builder clear() { + super.clear(); + str_ = com.google.protobuf.LazyStringArrayList.EMPTY; + bitField0_ = (bitField0_ & ~0x00000001); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { + return PushDownPb.internal_static_StringList_descriptor; + } + + @java.lang.Override + public PushDownPb.StringList getDefaultInstanceForType() { + return PushDownPb.StringList.getDefaultInstance(); + } + + @java.lang.Override + public PushDownPb.StringList build() { + PushDownPb.StringList result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public PushDownPb.StringList buildPartial() { + PushDownPb.StringList result = new PushDownPb.StringList(this); + if (((bitField0_ & 0x00000001) != 0)) { + str_ = str_.getUnmodifiableView(); + bitField0_ = (bitField0_ & ~0x00000001); + } + result.str_ = str_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, java.lang.Object value) { + return super.setField(field, value); + } + + @java.lang.Override + public Builder clearField(com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + + @java.lang.Override + public Builder clearOneof(com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, + java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, java.lang.Object value) { + return super.addRepeatedField(field, value); + } + + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof PushDownPb.StringList) { + return mergeFrom((PushDownPb.StringList) other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(PushDownPb.StringList other) { + if (other == PushDownPb.StringList.getDefaultInstance()) return this; + if (!other.str_.isEmpty()) { + if (str_.isEmpty()) { + str_ = other.str_; + bitField0_ = (bitField0_ & ~0x00000001); + } else { + ensureStrIsMutable(); + str_.addAll(other.str_); + } + onChanged(); + } + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + PushDownPb.StringList parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (PushDownPb.StringList) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + + private int bitField0_; + + private com.google.protobuf.LazyStringList str_ = + com.google.protobuf.LazyStringArrayList.EMPTY; + + private void ensureStrIsMutable() { + if (!((bitField0_ & 0x00000001) != 0)) { + str_ = new com.google.protobuf.LazyStringArrayList(str_); + bitField0_ |= 0x00000001; + } + } + + /** repeated string str = 1; */ + public com.google.protobuf.ProtocolStringList getStrList() { + return str_.getUnmodifiableView(); + } + + /** repeated string str = 1; */ + public int getStrCount() { + return str_.size(); + } + + /** repeated string str = 1; */ + public java.lang.String getStr(int index) { + return str_.get(index); + } + + /** repeated string str = 1; */ + public com.google.protobuf.ByteString getStrBytes(int index) { + return str_.getByteString(index); + } + + /** repeated string str = 1; */ + public Builder setStr(int index, java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + ensureStrIsMutable(); + str_.set(index, value); + onChanged(); + return this; + } + + /** repeated string str = 1; */ + public Builder addStr(java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + ensureStrIsMutable(); + str_.add(value); + onChanged(); + return this; + } + + /** repeated string str = 1; */ + public Builder addAllStr(java.lang.Iterable values) { + ensureStrIsMutable(); + com.google.protobuf.AbstractMessageLite.Builder.addAll(values, str_); + onChanged(); + return this; + } + + /** repeated string str = 1; */ + public Builder clearStr() { + str_ = com.google.protobuf.LazyStringArrayList.EMPTY; + bitField0_ = (bitField0_ & ~0x00000001); + onChanged(); + return this; + } + + /** repeated string str = 1; */ + public Builder addStrBytes(com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + ensureStrIsMutable(); + str_.add(value); + onChanged(); + return this; + } + + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + // @@protoc_insertion_point(builder_scope:StringList) + } - public static PushDownPb.BytesList parseFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } + // @@protoc_insertion_point(class_scope:StringList) + private static final PushDownPb.StringList DEFAULT_INSTANCE; - public static PushDownPb.BytesList parseDelimitedFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input); - } + static { + DEFAULT_INSTANCE = new PushDownPb.StringList(); + } - public static PushDownPb.BytesList parseDelimitedFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input, extensionRegistry); - } + public static PushDownPb.StringList getDefaultInstance() { + return DEFAULT_INSTANCE; + } - public static PushDownPb.BytesList parseFrom( - com.google.protobuf.CodedInputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } + private static final com.google.protobuf.Parser PARSER = + new com.google.protobuf.AbstractParser() { + @java.lang.Override + public StringList parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new StringList(input, extensionRegistry); + } + }; - public static PushDownPb.BytesList parseFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } + public static com.google.protobuf.Parser parser() { + return PARSER; + } - @java.lang.Override - public Builder newBuilderForType() { - return newBuilder(); - } + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } - public static Builder newBuilder() { - return DEFAULT_INSTANCE.toBuilder(); - } + @java.lang.Override + public PushDownPb.StringList getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + } - public static Builder newBuilder(PushDownPb.BytesList prototype) { - return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); - } + public interface BytesListOrBuilder + extends + // @@protoc_insertion_point(interface_extends:BytesList) + com.google.protobuf.MessageOrBuilder { - @java.lang.Override - public Builder toBuilder() { - return this == DEFAULT_INSTANCE - ? new Builder() : new Builder().mergeFrom(this); - } + /** repeated bytes bytes = 1; */ + java.util.List getBytesList(); - @java.lang.Override - protected Builder newBuilderForType( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - Builder builder = new Builder(parent); - return builder; - } + /** repeated bytes bytes = 1; */ + int getBytesCount(); - /** - * Protobuf type {@code BytesList} - */ - public static final class Builder extends - com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:BytesList) - PushDownPb.BytesListOrBuilder { - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return PushDownPb.internal_static_BytesList_descriptor; - } + /** repeated bytes bytes = 1; */ + com.google.protobuf.ByteString getBytes(int index); + } - @java.lang.Override - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return PushDownPb.internal_static_BytesList_fieldAccessorTable - .ensureFieldAccessorsInitialized( - PushDownPb.BytesList.class, PushDownPb.BytesList.Builder.class); - } + /** Protobuf type {@code BytesList} */ + public static final class BytesList extends com.google.protobuf.GeneratedMessageV3 + implements + // @@protoc_insertion_point(message_implements:BytesList) + BytesListOrBuilder { + private static final long serialVersionUID = 0L; - // Construct using PushDownPb.BytesList.newBuilder() - private Builder() { - maybeForceBuilderInitialization(); - } + // Use BytesList.newBuilder() to construct. + private BytesList(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } - private Builder( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - super(parent); - maybeForceBuilderInitialization(); - } + private BytesList() { + bytes_ = java.util.Collections.emptyList(); + } - private void maybeForceBuilderInitialization() { - } + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance(UnusedPrivateParameter unused) { + return new BytesList(); + } - @java.lang.Override - public Builder clear() { - super.clear(); - bytes_ = java.util.Collections.emptyList(); - bitField0_ = (bitField0_ & ~0x00000001); - return this; - } + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet getUnknownFields() { + return this.unknownFields; + } - @java.lang.Override - public com.google.protobuf.Descriptors.Descriptor - getDescriptorForType() { - return PushDownPb.internal_static_BytesList_descriptor; - } + private BytesList( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: + { + if (!((mutable_bitField0_ & 0x00000001) != 0)) { + bytes_ = new java.util.ArrayList(); + mutable_bitField0_ |= 0x00000001; + } + bytes_.add(input.readBytes()); + break; + } + default: + { + if (!parseUnknownField(input, unknownFields, extensionRegistry, tag)) { + done = true; + } + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e).setUnfinishedMessage(this); + } finally { + if (((mutable_bitField0_ & 0x00000001) != 0)) { + bytes_ = java.util.Collections.unmodifiableList(bytes_); // C + } + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } - @java.lang.Override - public PushDownPb.BytesList getDefaultInstanceForType() { - return PushDownPb.BytesList.getDefaultInstance(); - } + public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { + return PushDownPb.internal_static_BytesList_descriptor; + } - @java.lang.Override - public PushDownPb.BytesList build() { - PushDownPb.BytesList result = buildPartial(); - if (!result.isInitialized()) { - throw newUninitializedMessageException(result); - } - return result; - } + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return PushDownPb.internal_static_BytesList_fieldAccessorTable + .ensureFieldAccessorsInitialized( + PushDownPb.BytesList.class, PushDownPb.BytesList.Builder.class); + } - @java.lang.Override - public PushDownPb.BytesList buildPartial() { - PushDownPb.BytesList result = new PushDownPb.BytesList(this); - if (((bitField0_ & 0x00000001) != 0)) { - bytes_ = java.util.Collections.unmodifiableList(bytes_); - bitField0_ = (bitField0_ & ~0x00000001); - } - result.bytes_ = bytes_; - onBuilt(); - return result; - } + public static final int BYTES_FIELD_NUMBER = 1; + private java.util.List bytes_; - @java.lang.Override - public Builder clone() { - return super.clone(); - } + /** repeated bytes bytes = 1; */ + public java.util.List getBytesList() { + return bytes_; + } - @java.lang.Override - public Builder setField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.setField(field, value); - } + /** repeated bytes bytes = 1; */ + public int getBytesCount() { + return bytes_.size(); + } - @java.lang.Override - public Builder clearField( - com.google.protobuf.Descriptors.FieldDescriptor field) { - return super.clearField(field); - } + /** repeated bytes bytes = 1; */ + public com.google.protobuf.ByteString getBytes(int index) { + return bytes_.get(index); + } - @java.lang.Override - public Builder clearOneof( - com.google.protobuf.Descriptors.OneofDescriptor oneof) { - return super.clearOneof(oneof); - } + private byte memoizedIsInitialized = -1; - @java.lang.Override - public Builder setRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - int index, java.lang.Object value) { - return super.setRepeatedField(field, index, value); - } + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; - @java.lang.Override - public Builder addRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return super.addRepeatedField(field, value); - } + memoizedIsInitialized = 1; + return true; + } - @java.lang.Override - public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof PushDownPb.BytesList) { - return mergeFrom((PushDownPb.BytesList) other); - } else { - super.mergeFrom(other); - return this; - } - } + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { + for (int i = 0; i < bytes_.size(); i++) { + output.writeBytes(1, bytes_.get(i)); + } + unknownFields.writeTo(output); + } - public Builder mergeFrom(PushDownPb.BytesList other) { - if (other == PushDownPb.BytesList.getDefaultInstance()) return this; - if (!other.bytes_.isEmpty()) { - if (bytes_.isEmpty()) { - bytes_ = other.bytes_; - bitField0_ = (bitField0_ & ~0x00000001); - } else { - ensureBytesIsMutable(); - bytes_.addAll(other.bytes_); - } - onChanged(); - } - this.mergeUnknownFields(other.unknownFields); - onChanged(); - return this; - } + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + { + int dataSize = 0; + for (int i = 0; i < bytes_.size(); i++) { + dataSize += com.google.protobuf.CodedOutputStream.computeBytesSizeNoTag(bytes_.get(i)); + } + size += dataSize; + size += 1 * getBytesList().size(); + } + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } - @java.lang.Override - public final boolean isInitialized() { - return true; - } + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof PushDownPb.BytesList)) { + return super.equals(obj); + } + PushDownPb.BytesList other = (PushDownPb.BytesList) obj; + + if (!getBytesList().equals(other.getBytesList())) return false; + if (!unknownFields.equals(other.unknownFields)) return false; + return true; + } - @java.lang.Override - public Builder mergeFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - PushDownPb.BytesList parsedMessage = null; - try { - parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - parsedMessage = (PushDownPb.BytesList) e.getUnfinishedMessage(); - throw e.unwrapIOException(); - } finally { - if (parsedMessage != null) { - mergeFrom(parsedMessage); - } - } - return this; - } + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + if (getBytesCount() > 0) { + hash = (37 * hash) + BYTES_FIELD_NUMBER; + hash = (53 * hash) + getBytesList().hashCode(); + } + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } - private int bitField0_; + public static PushDownPb.BytesList parseFrom(java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - private java.util.List bytes_ = java.util.Collections.emptyList(); + public static PushDownPb.BytesList parseFrom( + java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - private void ensureBytesIsMutable() { - if (!((bitField0_ & 0x00000001) != 0)) { - bytes_ = new java.util.ArrayList(bytes_); - bitField0_ |= 0x00000001; - } - } + public static PushDownPb.BytesList parseFrom(com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - /** - * repeated bytes bytes = 1; - */ - public java.util.List - getBytesList() { - return ((bitField0_ & 0x00000001) != 0) ? - java.util.Collections.unmodifiableList(bytes_) : bytes_; - } + public static PushDownPb.BytesList parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - /** - * repeated bytes bytes = 1; - */ - public int getBytesCount() { - return bytes_.size(); - } + public static PushDownPb.BytesList parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } - /** - * repeated bytes bytes = 1; - */ - public com.google.protobuf.ByteString getBytes(int index) { - return bytes_.get(index); - } + public static PushDownPb.BytesList parseFrom( + byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } - /** - * repeated bytes bytes = 1; - */ - public Builder setBytes( - int index, com.google.protobuf.ByteString value) { - if (value == null) { - throw new NullPointerException(); - } - ensureBytesIsMutable(); - bytes_.set(index, value); - onChanged(); - return this; - } + public static PushDownPb.BytesList parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException(PARSER, input); + } - /** - * repeated bytes bytes = 1; - */ - public Builder addBytes(com.google.protobuf.ByteString value) { - if (value == null) { - throw new NullPointerException(); - } - ensureBytesIsMutable(); - bytes_.add(value); - onChanged(); - return this; - } + public static PushDownPb.BytesList parseFrom( + java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException( + PARSER, input, extensionRegistry); + } - /** - * repeated bytes bytes = 1; - */ - public Builder addAllBytes( - java.lang.Iterable values) { - ensureBytesIsMutable(); - com.google.protobuf.AbstractMessageLite.Builder.addAll( - values, bytes_); - onChanged(); - return this; - } + public static PushDownPb.BytesList parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseDelimitedWithIOException(PARSER, input); + } - /** - * repeated bytes bytes = 1; - */ - public Builder clearBytes() { - bytes_ = java.util.Collections.emptyList(); - bitField0_ = (bitField0_ & ~0x00000001); - onChanged(); - return this; - } + public static PushDownPb.BytesList parseDelimitedFrom( + java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseDelimitedWithIOException( + PARSER, input, extensionRegistry); + } - @java.lang.Override - public final Builder setUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.setUnknownFields(unknownFields); - } + public static PushDownPb.BytesList parseFrom(com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException(PARSER, input); + } - @java.lang.Override - public final Builder mergeUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.mergeUnknownFields(unknownFields); - } + public static PushDownPb.BytesList parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3.parseWithIOException( + PARSER, input, extensionRegistry); + } + @java.lang.Override + public Builder newBuilderForType() { + return newBuilder(); + } - // @@protoc_insertion_point(builder_scope:BytesList) - } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } - // @@protoc_insertion_point(class_scope:BytesList) - private static final PushDownPb.BytesList DEFAULT_INSTANCE; + public static Builder newBuilder(PushDownPb.BytesList prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } - static { - DEFAULT_INSTANCE = new PushDownPb.BytesList(); - } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE ? new Builder() : new Builder().mergeFrom(this); + } - public static PushDownPb.BytesList getDefaultInstance() { - return DEFAULT_INSTANCE; - } + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { - @java.lang.Override - public BytesList parsePartialFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return new BytesList(input, extensionRegistry); - } - }; + /** Protobuf type {@code BytesList} */ + public static final class Builder + extends com.google.protobuf.GeneratedMessageV3.Builder + implements + // @@protoc_insertion_point(builder_implements:BytesList) + PushDownPb.BytesListOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { + return PushDownPb.internal_static_BytesList_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return PushDownPb.internal_static_BytesList_fieldAccessorTable + .ensureFieldAccessorsInitialized( + PushDownPb.BytesList.class, PushDownPb.BytesList.Builder.class); + } + + // Construct using PushDownPb.BytesList.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder(com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + + private void maybeForceBuilderInitialization() {} + + @java.lang.Override + public Builder clear() { + super.clear(); + bytes_ = java.util.Collections.emptyList(); + bitField0_ = (bitField0_ & ~0x00000001); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { + return PushDownPb.internal_static_BytesList_descriptor; + } + + @java.lang.Override + public PushDownPb.BytesList getDefaultInstanceForType() { + return PushDownPb.BytesList.getDefaultInstance(); + } + + @java.lang.Override + public PushDownPb.BytesList build() { + PushDownPb.BytesList result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public PushDownPb.BytesList buildPartial() { + PushDownPb.BytesList result = new PushDownPb.BytesList(this); + if (((bitField0_ & 0x00000001) != 0)) { + bytes_ = java.util.Collections.unmodifiableList(bytes_); + bitField0_ = (bitField0_ & ~0x00000001); + } + result.bytes_ = bytes_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, java.lang.Object value) { + return super.setField(field, value); + } + + @java.lang.Override + public Builder clearField(com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + + @java.lang.Override + public Builder clearOneof(com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, + java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, java.lang.Object value) { + return super.addRepeatedField(field, value); + } + + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof PushDownPb.BytesList) { + return mergeFrom((PushDownPb.BytesList) other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(PushDownPb.BytesList other) { + if (other == PushDownPb.BytesList.getDefaultInstance()) return this; + if (!other.bytes_.isEmpty()) { + if (bytes_.isEmpty()) { + bytes_ = other.bytes_; + bitField0_ = (bitField0_ & ~0x00000001); + } else { + ensureBytesIsMutable(); + bytes_.addAll(other.bytes_); + } + onChanged(); + } + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + PushDownPb.BytesList parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (PushDownPb.BytesList) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + + private int bitField0_; + + private java.util.List bytes_ = + java.util.Collections.emptyList(); + + private void ensureBytesIsMutable() { + if (!((bitField0_ & 0x00000001) != 0)) { + bytes_ = new java.util.ArrayList(bytes_); + bitField0_ |= 0x00000001; + } + } + + /** repeated bytes bytes = 1; */ + public java.util.List getBytesList() { + return ((bitField0_ & 0x00000001) != 0) + ? java.util.Collections.unmodifiableList(bytes_) + : bytes_; + } + + /** repeated bytes bytes = 1; */ + public int getBytesCount() { + return bytes_.size(); + } + + /** repeated bytes bytes = 1; */ + public com.google.protobuf.ByteString getBytes(int index) { + return bytes_.get(index); + } + + /** repeated bytes bytes = 1; */ + public Builder setBytes(int index, com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + ensureBytesIsMutable(); + bytes_.set(index, value); + onChanged(); + return this; + } + + /** repeated bytes bytes = 1; */ + public Builder addBytes(com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + ensureBytesIsMutable(); + bytes_.add(value); + onChanged(); + return this; + } + + /** repeated bytes bytes = 1; */ + public Builder addAllBytes( + java.lang.Iterable values) { + ensureBytesIsMutable(); + com.google.protobuf.AbstractMessageLite.Builder.addAll(values, bytes_); + onChanged(); + return this; + } + + /** repeated bytes bytes = 1; */ + public Builder clearBytes() { + bytes_ = java.util.Collections.emptyList(); + bitField0_ = (bitField0_ & ~0x00000001); + onChanged(); + return this; + } + + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + // @@protoc_insertion_point(builder_scope:BytesList) + } - public static com.google.protobuf.Parser parser() { - return PARSER; - } - - @java.lang.Override - public com.google.protobuf.Parser getParserForType() { - return PARSER; - } - - @java.lang.Override - public PushDownPb.BytesList getDefaultInstanceForType() { - return DEFAULT_INSTANCE; - } - - } - - private static final com.google.protobuf.Descriptors.Descriptor - internal_static_PushDown_descriptor; - private static final - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internal_static_PushDown_fieldAccessorTable; - private static final com.google.protobuf.Descriptors.Descriptor - internal_static_EdgeLimit_descriptor; - private static final - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internal_static_EdgeLimit_fieldAccessorTable; - private static final com.google.protobuf.Descriptors.Descriptor - internal_static_FilterNodes_descriptor; - private static final - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internal_static_FilterNodes_fieldAccessorTable; - private static final com.google.protobuf.Descriptors.Descriptor - internal_static_FilterNode_descriptor; - private static final - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internal_static_FilterNode_fieldAccessorTable; - private static final com.google.protobuf.Descriptors.Descriptor - internal_static_IntList_descriptor; - private static final - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internal_static_IntList_fieldAccessorTable; - private static final com.google.protobuf.Descriptors.Descriptor - internal_static_LongList_descriptor; - private static final - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internal_static_LongList_fieldAccessorTable; - private static final com.google.protobuf.Descriptors.Descriptor - internal_static_StringList_descriptor; - private static final - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internal_static_StringList_fieldAccessorTable; - private static final com.google.protobuf.Descriptors.Descriptor - internal_static_BytesList_descriptor; - private static final - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internal_static_BytesList_fieldAccessorTable; - - public static com.google.protobuf.Descriptors.FileDescriptor - getDescriptor() { - return descriptor; - } - - private static com.google.protobuf.Descriptors.FileDescriptor - descriptor; + // @@protoc_insertion_point(class_scope:BytesList) + private static final PushDownPb.BytesList DEFAULT_INSTANCE; static { - java.lang.String[] descriptorData = { - "\n\016pushdown.proto\"\234\001\n\010PushDown\022\"\n\013filter_" + - "node\030\001 \001(\0132\013.FilterNodeH\000\022$\n\014filter_node" + - "s\030\002 \001(\0132\014.FilterNodesH\000\022\036\n\nedge_limit\030\003 " + - "\001(\0132\n.EdgeLimit\022\034\n\tsort_type\030\004 \003(\0162\t.Sor" + - "tTypeB\010\n\006filter\"7\n\tEdgeLimit\022\n\n\002in\030\001 \001(\004" + - "\022\013\n\003out\030\002 \001(\004\022\021\n\tis_single\030\003 \001(\010\">\n\013Filt" + - "erNodes\022\014\n\004keys\030\001 \003(\014\022!\n\014filter_nodes\030\002 " + - "\003(\0132\013.FilterNode\"\344\001\n\nFilterNode\022 \n\013filte" + - "r_type\030\001 \001(\0162\013.FilterType\022\034\n\007filters\030\002 \003" + - "(\0132\013.FilterNode\022\037\n\013int_content\030\003 \001(\0132\010.I" + - "ntListH\000\022!\n\014long_content\030\004 \001(\0132\t.LongLis" + - "tH\000\022\"\n\013str_content\030\005 \001(\0132\013.StringListH\000\022" + - "#\n\rbytes_content\030\006 \001(\0132\n.BytesListH\000B\t\n\007" + - "content\"\026\n\007IntList\022\013\n\003int\030\001 \003(\005\"\030\n\010LongL" + - "ist\022\014\n\004long\030\001 \003(\003\"\031\n\nStringList\022\013\n\003str\030\001" + - " \003(\t\"\032\n\tBytesList\022\r\n\005bytes\030\001 \003(\014*\215\002\n\nFil" + - "terType\022\t\n\005EMPTY\020\000\022\017\n\013ONLY_VERTEX\020\001\022\013\n\007I" + - "N_EDGE\020\002\022\014\n\010OUT_EDGE\020\003\022\r\n\tVERTEX_TS\020\004\022\013\n" + - "\007EDGE_TS\020\005\022\021\n\rMULTI_EDGE_TS\020\006\022\020\n\014VERTEX_" + - "LABEL\020\007\022\016\n\nEDGE_LABEL\020\010\022\025\n\021VERTEX_VALUE_" + - "DROP\020\t\022\023\n\017EDGE_VALUE_DROP\020\n\022\007\n\003TTL\020\013\022\007\n\003" + - "AND\020\014\022\006\n\002OR\020\r\022\027\n\023VERTEX_MUST_CONTAIN\020\016\022\r" + - "\n\tGENERATED\020\017\022\t\n\005OTHER\020\020*U\n\010SortType\022\n\n\006" + - "SRC_ID\020\000\022\r\n\tDIRECTION\020\001\022\r\n\tDESC_TIME\020\002\022\010" + - "\n\004TIME\020\003\022\t\n\005LABEL\020\004\022\n\n\006DST_ID\020\005B\014B\nPushD" + - "ownPbb\006proto3" + DEFAULT_INSTANCE = new PushDownPb.BytesList(); + } + + public static PushDownPb.BytesList getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser PARSER = + new com.google.protobuf.AbstractParser() { + @java.lang.Override + public BytesList parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new BytesList(input, extensionRegistry); + } }; - descriptor = com.google.protobuf.Descriptors.FileDescriptor - .internalBuildGeneratedFileFrom(descriptorData, - new com.google.protobuf.Descriptors.FileDescriptor[]{ - }); - internal_static_PushDown_descriptor = - getDescriptor().getMessageTypes().get(0); - internal_static_PushDown_fieldAccessorTable = new - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public PushDownPb.BytesList getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + } + + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_PushDown_descriptor; + private static final com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_PushDown_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_EdgeLimit_descriptor; + private static final com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_EdgeLimit_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_FilterNodes_descriptor; + private static final com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_FilterNodes_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_FilterNode_descriptor; + private static final com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_FilterNode_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_IntList_descriptor; + private static final com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_IntList_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_LongList_descriptor; + private static final com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_LongList_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_StringList_descriptor; + private static final com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_StringList_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_BytesList_descriptor; + private static final com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_BytesList_fieldAccessorTable; + + public static com.google.protobuf.Descriptors.FileDescriptor getDescriptor() { + return descriptor; + } + + private static com.google.protobuf.Descriptors.FileDescriptor descriptor; + + static { + java.lang.String[] descriptorData = { + "\n" + + "\016pushdown.proto\"\234\001\n" + + "\010PushDown\022\"\n" + + "\013filter_node\030\001 \001(\0132\013.FilterNodeH\000\022$\n" + + "\014filter_nodes\030\002 \001(\0132\014.FilterNodesH\000\022\036\n\n" + + "edge_limit\030\003 \001(\0132\n" + + ".EdgeLimit\022\034\n" + + "\tsort_type\030\004 \003(\0162\t.SortTypeB\010\n" + + "\006filter\"7\n" + + "\tEdgeLimit\022\n\n" + + "\002in\030\001 \001(\004\022\013\n" + + "\003out\030\002 \001(\004\022\021\n" + + "\tis_single\030\003 \001(\010\">\n" + + "\013FilterNodes\022\014\n" + + "\004keys\030\001 \003(\014\022!\n" + + "\014filter_nodes\030\002 \003(\0132\013.FilterNode\"\344\001\n\n" + + "FilterNode\022 \n" + + "\013filter_type\030\001 \001(\0162\013.FilterType\022\034\n" + + "\007filters\030\002 \003(\0132\013.FilterNode\022\037\n" + + "\013int_content\030\003 \001(\0132\010.IntListH\000\022!\n" + + "\014long_content\030\004 \001(\0132\t.LongListH\000\022\"\n" + + "\013str_content\030\005 \001(\0132\013.StringListH\000\022#\n\r" + + "bytes_content\030\006 \001(\0132\n" + + ".BytesListH\000B\t\n" + + "\007content\"\026\n" + + "\007IntList\022\013\n" + + "\003int\030\001 \003(\005\"\030\n" + + "\010LongList\022\014\n" + + "\004long\030\001 \003(\003\"\031\n\n" + + "StringList\022\013\n" + + "\003str\030\001 \003(\t\"\032\n" + + "\tBytesList\022\r\n" + + "\005bytes\030\001 \003(\014*\215\002\n\n" + + "FilterType\022\t\n" + + "\005EMPTY\020\000\022\017\n" + + "\013ONLY_VERTEX\020\001\022\013\n" + + "\007IN_EDGE\020\002\022\014\n" + + "\010OUT_EDGE\020\003\022\r\n" + + "\tVERTEX_TS\020\004\022\013\n" + + "\007EDGE_TS\020\005\022\021\n\r" + + "MULTI_EDGE_TS\020\006\022\020\n" + + "\014VERTEX_LABEL\020\007\022\016\n\n" + + "EDGE_LABEL\020\010\022\025\n" + + "\021VERTEX_VALUE_DROP\020\t\022\023\n" + + "\017EDGE_VALUE_DROP\020\n" + + "\022\007\n" + + "\003TTL\020\013\022\007\n" + + "\003AND\020\014\022\006\n" + + "\002OR\020\r" + + "\022\027\n" + + "\023VERTEX_MUST_CONTAIN\020\016\022\r" + + "\n" + + "\tGENERATED\020\017\022\t\n" + + "\005OTHER\020\020*U\n" + + "\010SortType\022\n\n" + + "\006SRC_ID\020\000\022\r\n" + + "\tDIRECTION\020\001\022\r\n" + + "\tDESC_TIME\020\002\022\010\n" + + "\004TIME\020\003\022\t\n" + + "\005LABEL\020\004\022\n\n" + + "\006DST_ID\020\005B\014B\n" + + "PushDownPbb\006proto3" + }; + descriptor = + com.google.protobuf.Descriptors.FileDescriptor.internalBuildGeneratedFileFrom( + descriptorData, new com.google.protobuf.Descriptors.FileDescriptor[] {}); + internal_static_PushDown_descriptor = getDescriptor().getMessageTypes().get(0); + internal_static_PushDown_fieldAccessorTable = + new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_PushDown_descriptor, - new java.lang.String[]{"FilterNode", "FilterNodes", "EdgeLimit", "SortType", "Filter",}); - internal_static_EdgeLimit_descriptor = - getDescriptor().getMessageTypes().get(1); - internal_static_EdgeLimit_fieldAccessorTable = new - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + new java.lang.String[] { + "FilterNode", "FilterNodes", "EdgeLimit", "SortType", "Filter", + }); + internal_static_EdgeLimit_descriptor = getDescriptor().getMessageTypes().get(1); + internal_static_EdgeLimit_fieldAccessorTable = + new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_EdgeLimit_descriptor, - new java.lang.String[]{"In", "Out", "IsSingle",}); - internal_static_FilterNodes_descriptor = - getDescriptor().getMessageTypes().get(2); - internal_static_FilterNodes_fieldAccessorTable = new - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + new java.lang.String[] { + "In", "Out", "IsSingle", + }); + internal_static_FilterNodes_descriptor = getDescriptor().getMessageTypes().get(2); + internal_static_FilterNodes_fieldAccessorTable = + new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_FilterNodes_descriptor, - new java.lang.String[]{"Keys", "FilterNodes",}); - internal_static_FilterNode_descriptor = - getDescriptor().getMessageTypes().get(3); - internal_static_FilterNode_fieldAccessorTable = new - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + new java.lang.String[] { + "Keys", "FilterNodes", + }); + internal_static_FilterNode_descriptor = getDescriptor().getMessageTypes().get(3); + internal_static_FilterNode_fieldAccessorTable = + new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_FilterNode_descriptor, - new java.lang.String[]{"FilterType", "Filters", "IntContent", "LongContent", "StrContent", "BytesContent", "Content",}); - internal_static_IntList_descriptor = - getDescriptor().getMessageTypes().get(4); - internal_static_IntList_fieldAccessorTable = new - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + new java.lang.String[] { + "FilterType", + "Filters", + "IntContent", + "LongContent", + "StrContent", + "BytesContent", + "Content", + }); + internal_static_IntList_descriptor = getDescriptor().getMessageTypes().get(4); + internal_static_IntList_fieldAccessorTable = + new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_IntList_descriptor, - new java.lang.String[]{"Int",}); - internal_static_LongList_descriptor = - getDescriptor().getMessageTypes().get(5); - internal_static_LongList_fieldAccessorTable = new - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + new java.lang.String[] { + "Int", + }); + internal_static_LongList_descriptor = getDescriptor().getMessageTypes().get(5); + internal_static_LongList_fieldAccessorTable = + new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_LongList_descriptor, - new java.lang.String[]{"Long",}); - internal_static_StringList_descriptor = - getDescriptor().getMessageTypes().get(6); - internal_static_StringList_fieldAccessorTable = new - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + new java.lang.String[] { + "Long", + }); + internal_static_StringList_descriptor = getDescriptor().getMessageTypes().get(6); + internal_static_StringList_fieldAccessorTable = + new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_StringList_descriptor, - new java.lang.String[]{"Str",}); - internal_static_BytesList_descriptor = - getDescriptor().getMessageTypes().get(7); - internal_static_BytesList_fieldAccessorTable = new - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + new java.lang.String[] { + "Str", + }); + internal_static_BytesList_descriptor = getDescriptor().getMessageTypes().get(7); + internal_static_BytesList_fieldAccessorTable = + new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_BytesList_descriptor, - new java.lang.String[]{"Bytes",}); - } + new java.lang.String[] { + "Bytes", + }); + } - // @@protoc_insertion_point(outer_class_scope) + // @@protoc_insertion_point(outer_class_scope) } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/PushDownPbGenerator.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/PushDownPbGenerator.java index 97de85104..52cae0cc9 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/PushDownPbGenerator.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/inner/PushDownPbGenerator.java @@ -19,11 +19,11 @@ package org.apache.geaflow.state.pushdown.inner; -import com.google.protobuf.ByteString; import java.util.ArrayList; import java.util.List; import java.util.Map.Entry; import java.util.stream.Collectors; + import org.apache.geaflow.common.type.IType; import org.apache.geaflow.state.graph.encoder.EdgeAtom; import org.apache.geaflow.state.pushdown.IStatePushDown; @@ -35,45 +35,49 @@ import org.apache.geaflow.state.pushdown.limit.IEdgeLimit; import org.apache.geaflow.state.pushdown.limit.LimitType; +import com.google.protobuf.ByteString; + public class PushDownPbGenerator { - private static final byte[] EMPTY = new byte[0]; + private static final byte[] EMPTY = new byte[0]; - public static PushDown getPushDownPb(IType type, IStatePushDown pushDown) { - PushDown.Builder builder = PushDown.newBuilder(); - if (pushDown.getFilters() == null) { - builder.setFilterNode(FilterGenerator.getFilterData(pushDown.getFilter())); - } else { - List keys = new ArrayList<>(pushDown.getFilters().size()); - List filterNodes = new ArrayList<>(pushDown.getFilters().size()); - for (Object obj : pushDown.getFilters().entrySet()) { - Entry entry = (Entry) obj; - keys.add(ByteString.copyFrom(type.serialize(entry.getKey()))); - filterNodes.add(FilterGenerator.getFilterData(entry.getValue())); - } - FilterNodes nodes = FilterNodes.newBuilder() - .addAllKeys(keys).addAllFilterNodes(filterNodes).build(); - builder.setFilterNodes(nodes); - } - IEdgeLimit limit = pushDown.getEdgeLimit(); - if (limit != null) { - builder.setEdgeLimit(EdgeLimit.newBuilder() - .setIn(limit.inEdgeLimit()) - .setOut(limit.outEdgeLimit()) - .setIsSingle(limit.limitType() == LimitType.SINGLE) - .build()); - } - if (pushDown.getOrderFields() != null) { - List edgeAtoms = pushDown.getOrderFields(); - builder.addAllSortType(edgeAtoms.stream().map(EdgeAtom::toPbSortType).collect(Collectors.toList())); - } - return builder.build(); + public static PushDown getPushDownPb(IType type, IStatePushDown pushDown) { + PushDown.Builder builder = PushDown.newBuilder(); + if (pushDown.getFilters() == null) { + builder.setFilterNode(FilterGenerator.getFilterData(pushDown.getFilter())); + } else { + List keys = new ArrayList<>(pushDown.getFilters().size()); + List filterNodes = new ArrayList<>(pushDown.getFilters().size()); + for (Object obj : pushDown.getFilters().entrySet()) { + Entry entry = (Entry) obj; + keys.add(ByteString.copyFrom(type.serialize(entry.getKey()))); + filterNodes.add(FilterGenerator.getFilterData(entry.getValue())); + } + FilterNodes nodes = + FilterNodes.newBuilder().addAllKeys(keys).addAllFilterNodes(filterNodes).build(); + builder.setFilterNodes(nodes); + } + IEdgeLimit limit = pushDown.getEdgeLimit(); + if (limit != null) { + builder.setEdgeLimit( + EdgeLimit.newBuilder() + .setIn(limit.inEdgeLimit()) + .setOut(limit.outEdgeLimit()) + .setIsSingle(limit.limitType() == LimitType.SINGLE) + .build()); + } + if (pushDown.getOrderFields() != null) { + List edgeAtoms = pushDown.getOrderFields(); + builder.addAllSortType( + edgeAtoms.stream().map(EdgeAtom::toPbSortType).collect(Collectors.toList())); } + return builder.build(); + } - public static byte[] getPushDownPbBytes(IType type, IStatePushDown pushDown) { - if (pushDown.isEmpty()) { - return EMPTY; - } - return getPushDownPb(type, pushDown).toByteArray(); + public static byte[] getPushDownPbBytes(IType type, IStatePushDown pushDown) { + if (pushDown.isEmpty()) { + return EMPTY; } + return getPushDownPb(type, pushDown).toByteArray(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/limit/ComposedEdgeLimit.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/limit/ComposedEdgeLimit.java index 52311c742..9637d9b36 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/limit/ComposedEdgeLimit.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/limit/ComposedEdgeLimit.java @@ -21,27 +21,26 @@ public class ComposedEdgeLimit implements IEdgeLimit { - private long outEdgeLimit; - private long inEdgeLimit; - - public ComposedEdgeLimit(long outEdgeLimit, long inEdgeLimit) { - this.outEdgeLimit = outEdgeLimit; - this.inEdgeLimit = inEdgeLimit; - } - - @Override - public long inEdgeLimit() { - return this.inEdgeLimit; - } - - @Override - public long outEdgeLimit() { - return this.outEdgeLimit; - } - - @Override - public LimitType limitType() { - return LimitType.COMPOSED; - } - + private long outEdgeLimit; + private long inEdgeLimit; + + public ComposedEdgeLimit(long outEdgeLimit, long inEdgeLimit) { + this.outEdgeLimit = outEdgeLimit; + this.inEdgeLimit = inEdgeLimit; + } + + @Override + public long inEdgeLimit() { + return this.inEdgeLimit; + } + + @Override + public long outEdgeLimit() { + return this.outEdgeLimit; + } + + @Override + public LimitType limitType() { + return LimitType.COMPOSED; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/limit/IEdgeLimit.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/limit/IEdgeLimit.java index 2024fcb8e..ef2d92f99 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/limit/IEdgeLimit.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/limit/IEdgeLimit.java @@ -21,23 +21,15 @@ import java.io.Serializable; -/** - * The edge limit interface describes the queried edge limit number. - */ +/** The edge limit interface describes the queried edge limit number. */ public interface IEdgeLimit extends Serializable { - /** - * Returns the in edge limit number. - */ - long inEdgeLimit(); + /** Returns the in edge limit number. */ + long inEdgeLimit(); - /** - * Returns the out edge limit number. - */ - long outEdgeLimit(); + /** Returns the out edge limit number. */ + long outEdgeLimit(); - /** - * Returns the limit type. - */ - LimitType limitType(); + /** Returns the limit type. */ + LimitType limitType(); } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/limit/LimitType.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/limit/LimitType.java index 72c59ac43..17c2a20b4 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/limit/LimitType.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/limit/LimitType.java @@ -21,13 +21,11 @@ public enum LimitType { - /** - * Composed limit type is to take the whole filter into consideration. - */ - COMPOSED, - /** - * Single limit type is used for or filter, which - * each inner filter is independence for the limit condition. - */ - SINGLE, + /** Composed limit type is to take the whole filter into consideration. */ + COMPOSED, + /** + * Single limit type is used for or filter, which each inner filter is independence for the limit + * condition. + */ + SINGLE, } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/limit/SingleEdgeLimit.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/limit/SingleEdgeLimit.java index eac74e689..60a41850f 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/limit/SingleEdgeLimit.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/limit/SingleEdgeLimit.java @@ -21,12 +21,12 @@ public class SingleEdgeLimit extends ComposedEdgeLimit { - public SingleEdgeLimit(long outEdgeLimit, long inEdgeLimit) { - super(outEdgeLimit, inEdgeLimit); - } + public SingleEdgeLimit(long outEdgeLimit, long inEdgeLimit) { + super(outEdgeLimit, inEdgeLimit); + } - @Override - public LimitType limitType() { - return LimitType.SINGLE; - } + @Override + public LimitType limitType() { + return LimitType.SINGLE; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/order/EdgeOrder.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/order/EdgeOrder.java index df6e63128..d5d6245a4 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/order/EdgeOrder.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/order/EdgeOrder.java @@ -23,21 +23,21 @@ public class EdgeOrder implements IEdgeOrder { - private GraphFiledName filedName; - private boolean desc; + private GraphFiledName filedName; + private boolean desc; - public EdgeOrder(GraphFiledName filedName, boolean desc) { - this.filedName = filedName; - this.desc = desc; - } + public EdgeOrder(GraphFiledName filedName, boolean desc) { + this.filedName = filedName; + this.desc = desc; + } - @Override - public GraphFiledName getField() { - return this.filedName; - } + @Override + public GraphFiledName getField() { + return this.filedName; + } - @Override - public boolean desc() { - return this.desc; - } + @Override + public boolean desc() { + return this.desc; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/order/IEdgeOrder.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/order/IEdgeOrder.java index 6f01a4a8d..d0f1e81cf 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/order/IEdgeOrder.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/order/IEdgeOrder.java @@ -20,11 +20,12 @@ package org.apache.geaflow.state.pushdown.order; import java.io.Serializable; + import org.apache.geaflow.model.graph.meta.GraphFiledName; public interface IEdgeOrder extends Serializable { - GraphFiledName getField(); + GraphFiledName getField(); - boolean desc(); + boolean desc(); } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/project/DstIdProjector.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/project/DstIdProjector.java index 809a00e1c..57e5c5d3b 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/project/DstIdProjector.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/project/DstIdProjector.java @@ -23,13 +23,13 @@ public class DstIdProjector implements IProjector, K> { - @Override - public K project(IEdge value) { - return value.getTargetId(); - } + @Override + public K project(IEdge value) { + return value.getTargetId(); + } - @Override - public ProjectType projectType() { - return ProjectType.DST_ID; - } + @Override + public ProjectType projectType() { + return ProjectType.DST_ID; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/project/IProjector.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/project/IProjector.java index d96bb6a6a..ebb846088 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/project/IProjector.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/project/IProjector.java @@ -23,7 +23,7 @@ public interface IProjector extends Serializable { - R project(T value); + R project(T value); - ProjectType projectType(); + ProjectType projectType(); } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/project/ProjectType.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/project/ProjectType.java index da624b62c..d767d3b4c 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/project/ProjectType.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/project/ProjectType.java @@ -21,16 +21,10 @@ public enum ProjectType { - /** - * dst id projector. - */ - DST_ID, - /** - * time projector. - */ - TIME, - /** - * property projector. - */ - PROPERTY, + /** dst id projector. */ + DST_ID, + /** time projector. */ + TIME, + /** property projector. */ + PROPERTY, } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/project/PropertyProjector.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/project/PropertyProjector.java index 64433d6a9..c2c1b0d2d 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/project/PropertyProjector.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/project/PropertyProjector.java @@ -21,19 +21,19 @@ public class PropertyProjector implements IProjector { - private final String[] columns; + private final String[] columns; - public PropertyProjector(String[] columns) { - this.columns = columns; - } + public PropertyProjector(String[] columns) { + this.columns = columns; + } - @Override - public T project(T value) { - return value; - } + @Override + public T project(T value) { + return value; + } - @Override - public ProjectType projectType() { - return ProjectType.PROPERTY; - } + @Override + public ProjectType projectType() { + return ProjectType.PROPERTY; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/project/TimeProjector.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/project/TimeProjector.java index 52a1c9c53..17dd348a5 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/project/TimeProjector.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/pushdown/project/TimeProjector.java @@ -24,13 +24,13 @@ public class TimeProjector implements IProjector, Long> { - @Override - public Long project(IEdge value) { - return ((IGraphElementWithTimeField) value).getTime(); - } + @Override + public Long project(IEdge value) { + return ((IGraphElementWithTimeField) value).getTime(); + } - @Override - public ProjectType projectType() { - return ProjectType.TIME; - } + @Override + public ProjectType projectType() { + return ProjectType.TIME; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/schema/GraphDataSchema.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/schema/GraphDataSchema.java index 3638a3dd5..99ee4e9d4 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/schema/GraphDataSchema.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/schema/GraphDataSchema.java @@ -19,8 +19,6 @@ package org.apache.geaflow.state.schema; - -import com.google.common.base.Preconditions; import java.lang.reflect.Modifier; import java.util.ArrayList; import java.util.HashMap; @@ -32,6 +30,7 @@ import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.schema.Field; import org.apache.geaflow.common.serialize.ISerializer; @@ -49,186 +48,194 @@ import org.apache.geaflow.state.graph.encoder.EdgeAtom; import org.apache.geaflow.state.graph.encoder.VertexAtom; -public class GraphDataSchema { - - private final IGraphElementMeta vertexMeta; - private final IGraphElementMeta edgeMeta; - private final Map metaIdMap = new HashMap<>(); - private final Map idMetaMap = new HashMap<>(); - private final boolean emptyVertexProperty; - private final boolean emptyEdgeProperty; - private Supplier vertexConsFun; - private Supplier edgeConsFun; - - private Function vertexPropertySerFun; - private Function edgePropertySerFun; - private Function vertexPropertyDeFun; - private Function edgePropertyDeFun; - private List edgeAtoms = new ArrayList<>(); - private List vertexAtoms = new ArrayList<>(); - private IType keyType; - - // Currently, only one schema is supported. Multiple schemas need to be considered when adding HLA. - @SuppressWarnings("unchecked") - public GraphDataSchema(GraphMeta meta) { - this.vertexMeta = meta.getVertexMeta(); - this.edgeMeta = meta.getEdgeMeta(); - this.keyType = meta.getKeyType(); - this.emptyVertexProperty = this.vertexMeta.getPropertyClass() == EmptyProperty.class; - this.emptyEdgeProperty = this.edgeMeta.getPropertyClass() == EmptyProperty.class; - vertexConsFun = Objects.requireNonNull( - (Supplier) meta.getVertexMeta().getGraphElementConstruct()); - - edgeConsFun = Objects.requireNonNull( - (Supplier) meta.getEdgeMeta().getGraphElementConstruct()); - - transform(meta.getVertexMeta()); - transform(meta.getEdgeMeta()); - } - - private LinkedHashMap transform(IGraphElementMeta elementMeta) { - metaIdMap.put(elementMeta.getGraphElementClass(), (int) elementMeta.getGraphElementId()); - idMetaMap.put((int) elementMeta.getGraphElementId(), elementMeta.getGraphElementClass()); - - boolean isEdge = IEdge.class.isAssignableFrom(elementMeta.getGraphElementClass()); - LinkedHashMap map = new LinkedHashMap<>(); - for (Field field : elementMeta.getGraphMeta().getFields()) { - map.put(field.getName(), field.getType()); - if (isEdge) { - edgeAtoms.add(Preconditions.checkNotNull( - EdgeAtom.EDGE_ATOM_MAP.get(GraphFiledName.valueOf(field.getName())))); - } else { - vertexAtoms.add(Preconditions.checkNotNull( - VertexAtom.VERTEX_ATOM_MAP.get(GraphFiledName.valueOf(field.getName())))); - } - } - Tuple, Function> tuple = - getPropertySerde(elementMeta.getPropertyClass()); - - if (isEdge) { - this.edgePropertySerFun = tuple.f0; - this.edgePropertyDeFun = tuple.f1; - } else { - this.vertexPropertySerFun = tuple.f0; - this.vertexPropertyDeFun = tuple.f1; - } - return map; - } - - private Tuple, Function> getPropertySerde(Class propertyClass) { - Function serFun; - Function deFun; - IType type = Types.getType(propertyClass); - if (type != null) { - serFun = type::serialize; - deFun = type::deserialize; - return Tuple.of(serFun, deFun); - } - - boolean cloneable = IPropertySerializable.class.isAssignableFrom(propertyClass); - if (cloneable && isFinalClass(propertyClass)) { - serFun = o -> ((IPropertySerializable) o).toBytes(); - IPropertySerializable cleanProperty; - try { - cleanProperty = (IPropertySerializable) propertyClass.newInstance(); - } catch (Exception e) { - throw new GeaflowRuntimeException(e); - } - deFun = bytes -> { - IPropertySerializable clone = cleanProperty.clone(); - clone.fromBinary(bytes); - return clone; - }; - } else { - ISerializer kryoSerializer = SerializerFactory.getKryoSerializer(); - serFun = kryoSerializer::serialize; - deFun = kryoSerializer::deserialize; - } - return Tuple.of(serFun, deFun); - } - - private boolean isFinalClass(Class clazz) { - return Modifier.isFinal(clazz.getModifiers()); - } - - public Map getMetaIdMap() { - return metaIdMap; - } - - public Map getIdMetaMap() { - return idMetaMap; - } - - public IGraphElementMeta getVertexMeta() { - return vertexMeta; - } - - public IGraphElementMeta getEdgeMeta() { - return edgeMeta; - } - - public Supplier getVertexConsFun() { - return vertexConsFun; - } - - public Supplier getEdgeConsFun() { - return edgeConsFun; - } - - public Function getVertexPropertySerFun() { - return vertexPropertySerFun; - } - - public Function getEdgePropertySerFun() { - return edgePropertySerFun; - } - - public Function getVertexPropertyDeFun() { - return vertexPropertyDeFun; - } - - public Function getEdgePropertyDeFun() { - return edgePropertyDeFun; - } - - public IType getKeyType() { - return keyType; - } - - public List getEdgeAtoms() { - return edgeAtoms; - } - - public void setEdgeAtoms(List list) { - Set edgeFieldSet = this.edgeAtoms.stream().map(EdgeAtom::getGraphFieldName) - .collect(Collectors.toSet()); - Set newFieldSet = list.stream().map(EdgeAtom::getGraphFieldName) - .collect(Collectors.toSet()); - Preconditions.checkArgument(edgeFieldSet.equals(newFieldSet), - "edge element not match %s, elements are %s", list, edgeFieldSet); - this.edgeAtoms = list; - } - - public List getVertexAtoms() { - return vertexAtoms; - } - - public void setVertexAtoms(List list) { - Set vertexFieldSet = - this.vertexAtoms.stream().map(VertexAtom::getGraphFieldName) - .collect(Collectors.toSet()); - Set newFieldSet = list.stream().map(VertexAtom::getGraphFieldName) - .collect(Collectors.toSet()); - Preconditions.checkArgument(vertexFieldSet.equals(newFieldSet), - "illegal vertex order " + list); - this.vertexAtoms = list; - } +import com.google.common.base.Preconditions; - public boolean isEmptyVertexProperty() { - return emptyVertexProperty; - } +public class GraphDataSchema { - public boolean isEmptyEdgeProperty() { - return emptyEdgeProperty; - } + private final IGraphElementMeta vertexMeta; + private final IGraphElementMeta edgeMeta; + private final Map metaIdMap = new HashMap<>(); + private final Map idMetaMap = new HashMap<>(); + private final boolean emptyVertexProperty; + private final boolean emptyEdgeProperty; + private Supplier vertexConsFun; + private Supplier edgeConsFun; + + private Function vertexPropertySerFun; + private Function edgePropertySerFun; + private Function vertexPropertyDeFun; + private Function edgePropertyDeFun; + private List edgeAtoms = new ArrayList<>(); + private List vertexAtoms = new ArrayList<>(); + private IType keyType; + + // Currently, only one schema is supported. Multiple schemas need to be considered when adding + // HLA. + @SuppressWarnings("unchecked") + public GraphDataSchema(GraphMeta meta) { + this.vertexMeta = meta.getVertexMeta(); + this.edgeMeta = meta.getEdgeMeta(); + this.keyType = meta.getKeyType(); + this.emptyVertexProperty = this.vertexMeta.getPropertyClass() == EmptyProperty.class; + this.emptyEdgeProperty = this.edgeMeta.getPropertyClass() == EmptyProperty.class; + vertexConsFun = + Objects.requireNonNull((Supplier) meta.getVertexMeta().getGraphElementConstruct()); + + edgeConsFun = + Objects.requireNonNull((Supplier) meta.getEdgeMeta().getGraphElementConstruct()); + + transform(meta.getVertexMeta()); + transform(meta.getEdgeMeta()); + } + + private LinkedHashMap transform(IGraphElementMeta elementMeta) { + metaIdMap.put(elementMeta.getGraphElementClass(), (int) elementMeta.getGraphElementId()); + idMetaMap.put((int) elementMeta.getGraphElementId(), elementMeta.getGraphElementClass()); + + boolean isEdge = IEdge.class.isAssignableFrom(elementMeta.getGraphElementClass()); + LinkedHashMap map = new LinkedHashMap<>(); + for (Field field : elementMeta.getGraphMeta().getFields()) { + map.put(field.getName(), field.getType()); + if (isEdge) { + edgeAtoms.add( + Preconditions.checkNotNull( + EdgeAtom.EDGE_ATOM_MAP.get(GraphFiledName.valueOf(field.getName())))); + } else { + vertexAtoms.add( + Preconditions.checkNotNull( + VertexAtom.VERTEX_ATOM_MAP.get(GraphFiledName.valueOf(field.getName())))); + } + } + Tuple, Function> tuple = + getPropertySerde(elementMeta.getPropertyClass()); + + if (isEdge) { + this.edgePropertySerFun = tuple.f0; + this.edgePropertyDeFun = tuple.f1; + } else { + this.vertexPropertySerFun = tuple.f0; + this.vertexPropertyDeFun = tuple.f1; + } + return map; + } + + private Tuple, Function> getPropertySerde( + Class propertyClass) { + Function serFun; + Function deFun; + IType type = Types.getType(propertyClass); + if (type != null) { + serFun = type::serialize; + deFun = type::deserialize; + return Tuple.of(serFun, deFun); + } + + boolean cloneable = IPropertySerializable.class.isAssignableFrom(propertyClass); + if (cloneable && isFinalClass(propertyClass)) { + serFun = o -> ((IPropertySerializable) o).toBytes(); + IPropertySerializable cleanProperty; + try { + cleanProperty = (IPropertySerializable) propertyClass.newInstance(); + } catch (Exception e) { + throw new GeaflowRuntimeException(e); + } + deFun = + bytes -> { + IPropertySerializable clone = cleanProperty.clone(); + clone.fromBinary(bytes); + return clone; + }; + } else { + ISerializer kryoSerializer = SerializerFactory.getKryoSerializer(); + serFun = kryoSerializer::serialize; + deFun = kryoSerializer::deserialize; + } + return Tuple.of(serFun, deFun); + } + + private boolean isFinalClass(Class clazz) { + return Modifier.isFinal(clazz.getModifiers()); + } + + public Map getMetaIdMap() { + return metaIdMap; + } + + public Map getIdMetaMap() { + return idMetaMap; + } + + public IGraphElementMeta getVertexMeta() { + return vertexMeta; + } + + public IGraphElementMeta getEdgeMeta() { + return edgeMeta; + } + + public Supplier getVertexConsFun() { + return vertexConsFun; + } + + public Supplier getEdgeConsFun() { + return edgeConsFun; + } + + public Function getVertexPropertySerFun() { + return vertexPropertySerFun; + } + + public Function getEdgePropertySerFun() { + return edgePropertySerFun; + } + + public Function getVertexPropertyDeFun() { + return vertexPropertyDeFun; + } + + public Function getEdgePropertyDeFun() { + return edgePropertyDeFun; + } + + public IType getKeyType() { + return keyType; + } + + public List getEdgeAtoms() { + return edgeAtoms; + } + + public void setEdgeAtoms(List list) { + Set edgeFieldSet = + this.edgeAtoms.stream().map(EdgeAtom::getGraphFieldName).collect(Collectors.toSet()); + Set newFieldSet = + list.stream().map(EdgeAtom::getGraphFieldName).collect(Collectors.toSet()); + Preconditions.checkArgument( + edgeFieldSet.equals(newFieldSet), + "edge element not match %s, elements are %s", + list, + edgeFieldSet); + this.edgeAtoms = list; + } + + public List getVertexAtoms() { + return vertexAtoms; + } + + public void setVertexAtoms(List list) { + Set vertexFieldSet = + this.vertexAtoms.stream().map(VertexAtom::getGraphFieldName).collect(Collectors.toSet()); + Set newFieldSet = + list.stream().map(VertexAtom::getGraphFieldName).collect(Collectors.toSet()); + Preconditions.checkArgument(vertexFieldSet.equals(newFieldSet), "illegal vertex order " + list); + this.vertexAtoms = list; + } + + public boolean isEmptyVertexProperty() { + return emptyVertexProperty; + } + + public boolean isEmptyEdgeProperty() { + return emptyEdgeProperty; + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/serializer/DefaultKMapSerializer.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/serializer/DefaultKMapSerializer.java index b34afbee5..ae82d2a27 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/serializer/DefaultKMapSerializer.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/serializer/DefaultKMapSerializer.java @@ -20,6 +20,7 @@ package org.apache.geaflow.state.serializer; import java.util.function.Function; + import org.apache.geaflow.common.serialize.ISerializer; import org.apache.geaflow.common.serialize.SerializerFactory; import org.apache.geaflow.common.type.IType; @@ -27,54 +28,56 @@ public class DefaultKMapSerializer implements IKMapSerializer { - private Function keySerializer; - private Function keyDeserializer; - private Function subKeySerializer; - private Function subKeyDeserializer; - private Function valueSerializer; - private Function valueDeserializer; + private Function keySerializer; + private Function keyDeserializer; + private Function subKeySerializer; + private Function subKeyDeserializer; + private Function valueSerializer; + private Function valueDeserializer; - public DefaultKMapSerializer(Class keyClazz, Class subKeyClazz, Class valueClazz) { - final IType keyType = Types.getType(keyClazz); - final IType subKeyType = Types.getType(subKeyClazz); - final IType valueType = Types.getType(valueClazz); - ISerializer serializer = SerializerFactory.getKryoSerializer(); + public DefaultKMapSerializer(Class keyClazz, Class subKeyClazz, Class valueClazz) { + final IType keyType = Types.getType(keyClazz); + final IType subKeyType = Types.getType(subKeyClazz); + final IType valueType = Types.getType(valueClazz); + ISerializer serializer = SerializerFactory.getKryoSerializer(); - keySerializer = keyType == null ? serializer::serialize : keyType::serialize; - keyDeserializer = keyType == null ? c -> (K) serializer.deserialize(c) : keyType::deserialize; - subKeySerializer = subKeyType == null ? serializer::serialize : subKeyType::serialize; - subKeyDeserializer = subKeyType == null ? c -> (UK) serializer.deserialize(c) : subKeyType::deserialize; - valueSerializer = valueType == null ? serializer::serialize : valueType::serialize; - valueDeserializer = valueType == null ? c -> (UV) serializer.deserialize(c) : valueType::deserialize; - } + keySerializer = keyType == null ? serializer::serialize : keyType::serialize; + keyDeserializer = keyType == null ? c -> (K) serializer.deserialize(c) : keyType::deserialize; + subKeySerializer = subKeyType == null ? serializer::serialize : subKeyType::serialize; + subKeyDeserializer = + subKeyType == null ? c -> (UK) serializer.deserialize(c) : subKeyType::deserialize; + valueSerializer = valueType == null ? serializer::serialize : valueType::serialize; + valueDeserializer = + valueType == null ? c -> (UV) serializer.deserialize(c) : valueType::deserialize; + } - @Override - public byte[] serializeKey(K key) { - return keySerializer.apply(key); - } + @Override + public byte[] serializeKey(K key) { + return keySerializer.apply(key); + } - @Override - public K deserializeKey(byte[] array) { - return keyDeserializer.apply(array); - } + @Override + public K deserializeKey(byte[] array) { + return keyDeserializer.apply(array); + } - @Override - public byte[] serializeUK(UK key) { - return subKeySerializer.apply(key); - } + @Override + public byte[] serializeUK(UK key) { + return subKeySerializer.apply(key); + } - @Override - public UK deserializeUK(byte[] array) { - return subKeyDeserializer.apply(array); - } + @Override + public UK deserializeUK(byte[] array) { + return subKeyDeserializer.apply(array); + } - @Override - public byte[] serializeUV(UV value) { - return valueSerializer.apply(value); - } + @Override + public byte[] serializeUV(UV value) { + return valueSerializer.apply(value); + } - @Override - public UV deserializeUV(byte[] valueArray) { - return valueDeserializer.apply(valueArray); - } + @Override + public UV deserializeUV(byte[] valueArray) { + return valueDeserializer.apply(valueArray); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/serializer/DefaultKVSerializer.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/serializer/DefaultKVSerializer.java index e662958f8..4e365f891 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/serializer/DefaultKVSerializer.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/serializer/DefaultKVSerializer.java @@ -20,6 +20,7 @@ package org.apache.geaflow.state.serializer; import java.util.function.Function; + import org.apache.geaflow.common.serialize.ISerializer; import org.apache.geaflow.common.serialize.SerializerFactory; import org.apache.geaflow.common.type.IType; @@ -27,39 +28,40 @@ public class DefaultKVSerializer implements IKVSerializer { - private Function keySerializer; - private Function valueSerializer; - private Function keyDeserializer; - private Function valueDeserializer; + private Function keySerializer; + private Function valueSerializer; + private Function keyDeserializer; + private Function valueDeserializer; - public DefaultKVSerializer(Class keyClazz, Class valueClazz) { - IType keyType = Types.getType(keyClazz); - IType valueType = Types.getType(valueClazz); - ISerializer serializer = SerializerFactory.getKryoSerializer(); + public DefaultKVSerializer(Class keyClazz, Class valueClazz) { + IType keyType = Types.getType(keyClazz); + IType valueType = Types.getType(valueClazz); + ISerializer serializer = SerializerFactory.getKryoSerializer(); - keySerializer = keyType == null ? serializer::serialize : keyType::serialize; - keyDeserializer = keyType == null ? c -> (K) serializer.deserialize(c) : keyType::deserialize; - valueSerializer = valueType == null ? serializer::serialize : valueType::serialize; - valueDeserializer = valueType == null ? c -> (V) serializer.deserialize(c) : valueType::deserialize; - } + keySerializer = keyType == null ? serializer::serialize : keyType::serialize; + keyDeserializer = keyType == null ? c -> (K) serializer.deserialize(c) : keyType::deserialize; + valueSerializer = valueType == null ? serializer::serialize : valueType::serialize; + valueDeserializer = + valueType == null ? c -> (V) serializer.deserialize(c) : valueType::deserialize; + } - @Override - public byte[] serializeKey(K key) { - return keySerializer.apply(key); - } + @Override + public byte[] serializeKey(K key) { + return keySerializer.apply(key); + } - @Override - public K deserializeKey(byte[] array) { - return keyDeserializer.apply(array); - } + @Override + public K deserializeKey(byte[] array) { + return keyDeserializer.apply(array); + } - @Override - public byte[] serializeValue(V value) { - return valueSerializer.apply(value); - } + @Override + public byte[] serializeValue(V value) { + return valueSerializer.apply(value); + } - @Override - public V deserializeValue(byte[] valueArray) { - return valueDeserializer.apply(valueArray); - } + @Override + public V deserializeValue(byte[] valueArray) { + return valueDeserializer.apply(valueArray); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/serializer/IKMapSerializer.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/serializer/IKMapSerializer.java index 2a37b963e..61b64cd49 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/serializer/IKMapSerializer.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/serializer/IKMapSerializer.java @@ -21,11 +21,11 @@ public interface IKMapSerializer extends IKeySerializer { - byte[] serializeUK(UK uk); + byte[] serializeUK(UK uk); - UK deserializeUK(byte[] array); + UK deserializeUK(byte[] array); - byte[] serializeUV(UV uv); + byte[] serializeUV(UV uv); - UV deserializeUV(byte[] array); + UV deserializeUV(byte[] array); } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/serializer/IKVSerializer.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/serializer/IKVSerializer.java index 32828d595..29b806669 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/serializer/IKVSerializer.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/serializer/IKVSerializer.java @@ -21,7 +21,7 @@ public interface IKVSerializer extends IKeySerializer { - byte[] serializeValue(V value); + byte[] serializeValue(V value); - V deserializeValue(byte[] valueArray); + V deserializeValue(byte[] valueArray); } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/serializer/IKeySerializer.java b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/serializer/IKeySerializer.java index 56772b2ba..db3aef342 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/serializer/IKeySerializer.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/main/java/org/apache/geaflow/state/serializer/IKeySerializer.java @@ -23,7 +23,7 @@ public interface IKeySerializer extends Serializable { - byte[] serializeKey(K key); + byte[] serializeKey(K key); - K deserializeKey(byte[] array); + K deserializeKey(byte[] array); } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/data/TimeRangeTest.java b/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/data/TimeRangeTest.java index 9f4d63bc2..edd48271d 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/data/TimeRangeTest.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/data/TimeRangeTest.java @@ -24,13 +24,13 @@ public class TimeRangeTest { - @Test - public void testContain() { - TimeRange range = TimeRange.of(1, 5); - Assert.assertTrue(range.contain(1L)); - Assert.assertTrue(range.contain(3L)); - Assert.assertFalse(range.contain(5L)); + @Test + public void testContain() { + TimeRange range = TimeRange.of(1, 5); + Assert.assertTrue(range.contain(1L)); + Assert.assertTrue(range.contain(3L)); + Assert.assertFalse(range.contain(5L)); - Assert.assertEquals(range, TimeRange.of(1, 5)); - } + Assert.assertEquals(range, TimeRange.of(1, 5)); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/filter/StateFilterTest.java b/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/filter/StateFilterTest.java index 1d77e09b8..ea1d23b52 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/filter/StateFilterTest.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/filter/StateFilterTest.java @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; + import org.apache.geaflow.model.graph.edge.EdgeDirection; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.edge.impl.IDEdge; @@ -52,240 +53,261 @@ public class StateFilterTest { - private static final Logger LOGGER = LoggerFactory.getLogger(StateFilterTest.class); - - @Test - public void testValidate() { - EdgeTsFilter tsFilter = new EdgeTsFilter(TimeRange.of(100, 200)); - Exception e = null; - try { - IFilter filter = tsFilter.or(new EdgeValueDropFilter()); - } catch (Exception ex) { - e = ex; - } - Assert.assertEquals(e.getClass(), IllegalArgumentException.class); - - e = null; - try { - IFilter filter = tsFilter.or(new VertexTsFilter(TimeRange.of(100, 200))); - } catch (Exception ex) { - e = ex; - } - Assert.assertEquals(e.getClass(), IllegalArgumentException.class); - - e = null; - try { - IFilter filter = tsFilter.or(new VertexTsFilter(TimeRange.of(100, 200)).and(tsFilter)); - } catch (Exception ex) { - e = ex; - } - Assert.assertEquals(e.getClass(), IllegalArgumentException.class); - - e = null; - try { - IFilter filter = tsFilter.or( - new EdgeTsFilter(TimeRange.of(100, 200)).and(new EdgeValueDropFilter())); - } catch (Exception ex) { - e = ex; - } - Assert.assertEquals(e.getClass(), IllegalArgumentException.class); - - e = null; - try { - IFilter filter = tsFilter.or(new EdgeValueDropFilter()); - } catch (Exception ex) { - e = ex; - } - Assert.assertEquals(e.getClass(), IllegalArgumentException.class); - - e = null; - try { - IFilter filter = tsFilter.and(new EdgeValueDropFilter()) - .or(new EdgeTsFilter(TimeRange.of(100, 200))); - } catch (Exception ex) { - e = ex; - } - Assert.assertEquals(e.getClass(), IllegalArgumentException.class); - - IFilter filter = tsFilter.and(new EdgeValueDropFilter()) - .or(new EdgeTsFilter(TimeRange.of(100, 200)).and(new EdgeValueDropFilter())); - filter = new EdgeValueDropFilter().and(new VertexValueDropFilter()) - .or(new EdgeTsFilter(TimeRange.of(100, 200)).and(new EdgeValueDropFilter()) - .and(new VertexValueDropFilter())); - } + private static final Logger LOGGER = LoggerFactory.getLogger(StateFilterTest.class); - @Test - public void testEdgeTsFilter() { - EdgeTsFilter filter = new EdgeTsFilter(TimeRange.of(100, 200)); - Assert.assertEquals(filter.getFilterType(), FilterType.EDGE_TS); - IEdge edge = new IDTimeEdge<>("hello", "world", 99); - Assert.assertFalse(filter.filter(edge)); - - IGraphFilter complexFilter = GraphFilter.of(filter); - Assert.assertFalse(complexFilter.filterEdge(edge)); + @Test + public void testValidate() { + EdgeTsFilter tsFilter = new EdgeTsFilter(TimeRange.of(100, 200)); + Exception e = null; + try { + IFilter filter = tsFilter.or(new EdgeValueDropFilter()); + } catch (Exception ex) { + e = ex; } + Assert.assertEquals(e.getClass(), IllegalArgumentException.class); - @Test - public void testEdgeDirection() { - IEdge edge1 = new IDEdge<>("hello", "world"); - edge1.setDirect(EdgeDirection.OUT); - IEdge edge2 = new IDEdge<>("hello", "world"); - edge2.setDirect(EdgeDirection.IN); - - IGraphFilter filter = GraphFilter.of(new InEdgeFilter()); - Assert.assertFalse(filter.filterEdge(edge1)); - Assert.assertTrue(filter.filterEdge(edge2)); - - filter = GraphFilter.of(new OutEdgeFilter()); - Assert.assertTrue(filter.filterEdge(edge1)); - Assert.assertFalse(filter.filterEdge(edge2)); + e = null; + try { + IFilter filter = tsFilter.or(new VertexTsFilter(TimeRange.of(100, 200))); + } catch (Exception ex) { + e = ex; } + Assert.assertEquals(e.getClass(), IllegalArgumentException.class); - @Test - public void testAndFilter() { - IFilter filter = new EdgeTsFilter(TimeRange.of(1L, 2L)); - IFilter appendFilter = new InEdgeFilter().and(new EdgeTsFilter(TimeRange.of(1L, 2L))); - filter = filter.and(appendFilter); - Assert.assertEquals(appendFilter.getFilterType(), FilterType.AND); - Assert.assertEquals(filter.toString(), filter.and(EmptyGraphFilter.of()).toString()); - - Assert.assertEquals(((AndFilter) appendFilter).getFilters().size(), 2); - Assert.assertEquals(filter.getFilterType(), FilterType.AND); - Assert.assertEquals(((AndFilter) filter).getFilters().size(), 3); - - filter = new InEdgeFilter().and(new EdgeTsFilter(TimeRange.of(100, 200))); - IEdge edge = new IDTimeEdge<>("hello", "world", 100); - IGraphFilter stateFilter = GraphFilter.of(filter); - Assert.assertFalse(stateFilter.filterEdge(edge)); - - edge.setDirect(EdgeDirection.IN); - Assert.assertTrue(stateFilter.filterEdge(edge)); - Assert.assertFalse(stateFilter.dropAllRemaining()); - - List> list = Arrays.asList(new IDTimeEdge<>("hello", "world", 100), - new IDTimeEdge<>("hello", "world", 120), new IDTimeEdge<>("hello", "world", 140), - new IDTimeEdge<>("hello", "world", 160), new IDTimeEdge<>("hello", "world", 180)); - - list.get(0).setDirect(EdgeDirection.IN); - list.get(2).setDirect(EdgeDirection.IN); - list.get(4).setDirect(EdgeDirection.IN); - - filter = new InEdgeFilter().and(new EdgeTsFilter(TimeRange.of(80, 141))); - stateFilter = GraphFilter.of(filter); - LOGGER.info("filter {}", stateFilter.toString()); - IGraphFilter graphFilter = GraphFilter.of(filter); - Assert.assertTrue(graphFilter.contains(FilterType.EDGE_TS)); - Assert.assertEquals(graphFilter.retrieve(FilterType.EDGE_TS).getFilterType(), FilterType.EDGE_TS); - List> res = list.stream().filter(stateFilter::filterEdge) - .collect(Collectors.toList()); - Assert.assertEquals(res.size(), 2); - Assert.assertTrue(res.contains(list.get(0))); - Assert.assertTrue(res.contains(list.get(2))); - - graphFilter = GraphFilter.of(filter); - Assert.assertTrue(graphFilter.filterVertex(new IDVertex("1"))); - Assert.assertTrue(graphFilter.filterOneDegreeGraph(new OneDegreeGraph<>("1", null, null))); + e = null; + try { + IFilter filter = tsFilter.or(new VertexTsFilter(TimeRange.of(100, 200)).and(tsFilter)); + } catch (Exception ex) { + e = ex; } - - @Test - public void testOrFilter() { - IFilter filter = new EdgeTsFilter(TimeRange.of(1L, 2L)); - filter = filter.or(new EdgeTsFilter(TimeRange.of(3L, 4L))); - - Assert.assertEquals(filter.getFilterType(), FilterType.OR); - Assert.assertEquals(filter.or(EmptyGraphFilter.of()).toString(), filter.toString()); - - List> list = Arrays.asList(new IDTimeEdge<>("hello", "world", 1), - new IDTimeEdge<>("hello", "world", 3), new IDTimeEdge<>("hello", "world", 5)); - - IGraphFilter graphFilter = GraphFilter.of(filter); - List> res = new ArrayList<>(); - for (IEdge stringObjectIEdge : list) { - if (graphFilter.filterEdge(stringObjectIEdge)) { - res.add(stringObjectIEdge); - } - } - Assert.assertEquals(res.size(), 2); - - filter = new EdgeTsFilter(TimeRange.of(5L, 6L)).or(filter); - res.clear(); - graphFilter = GraphFilter.of(filter); - for (IEdge stringObjectIEdge : list) { - if (graphFilter.filterEdge(stringObjectIEdge)) { - res.add(stringObjectIEdge); - } - } - Assert.assertEquals(res.size(), 3); - - graphFilter = GraphFilter.of(filter); - Assert.assertTrue(graphFilter.filterVertex(new IDVertex("1"))); - Assert.assertTrue(graphFilter.filterOneDegreeGraph(new OneDegreeGraph<>("1", null, null))); + Assert.assertEquals(e.getClass(), IllegalArgumentException.class); + + e = null; + try { + IFilter filter = + tsFilter.or(new EdgeTsFilter(TimeRange.of(100, 200)).and(new EdgeValueDropFilter())); + } catch (Exception ex) { + e = ex; } + Assert.assertEquals(e.getClass(), IllegalArgumentException.class); - @Test - public void testParseLabelFilter() { - String label1 = "person"; - String label2 = "trade"; - String label3 = "relation"; - - IGraphFilter filter = GraphFilter.of(new EdgeLabelFilter(label1)); - List labels = FilterHelper.parseLabel(GraphFilter.of(filter), false); - Assert.assertEquals(labels.size(), 1); - Assert.assertEquals(labels.get(0), label1); + e = null; + try { + IFilter filter = tsFilter.or(new EdgeValueDropFilter()); + } catch (Exception ex) { + e = ex; + } + Assert.assertEquals(e.getClass(), IllegalArgumentException.class); + + e = null; + try { + IFilter filter = + tsFilter.and(new EdgeValueDropFilter()).or(new EdgeTsFilter(TimeRange.of(100, 200))); + } catch (Exception ex) { + e = ex; + } + Assert.assertEquals(e.getClass(), IllegalArgumentException.class); - filter = GraphFilter.of(new EdgeLabelFilter(label1)) + IFilter filter = + tsFilter + .and(new EdgeValueDropFilter()) + .or(new EdgeTsFilter(TimeRange.of(100, 200)).and(new EdgeValueDropFilter())); + filter = + new EdgeValueDropFilter() + .and(new VertexValueDropFilter()) + .or( + new EdgeTsFilter(TimeRange.of(100, 200)) + .and(new EdgeValueDropFilter()) + .and(new VertexValueDropFilter())); + } + + @Test + public void testEdgeTsFilter() { + EdgeTsFilter filter = new EdgeTsFilter(TimeRange.of(100, 200)); + Assert.assertEquals(filter.getFilterType(), FilterType.EDGE_TS); + IEdge edge = new IDTimeEdge<>("hello", "world", 99); + Assert.assertFalse(filter.filter(edge)); + + IGraphFilter complexFilter = GraphFilter.of(filter); + Assert.assertFalse(complexFilter.filterEdge(edge)); + } + + @Test + public void testEdgeDirection() { + IEdge edge1 = new IDEdge<>("hello", "world"); + edge1.setDirect(EdgeDirection.OUT); + IEdge edge2 = new IDEdge<>("hello", "world"); + edge2.setDirect(EdgeDirection.IN); + + IGraphFilter filter = GraphFilter.of(new InEdgeFilter()); + Assert.assertFalse(filter.filterEdge(edge1)); + Assert.assertTrue(filter.filterEdge(edge2)); + + filter = GraphFilter.of(new OutEdgeFilter()); + Assert.assertTrue(filter.filterEdge(edge1)); + Assert.assertFalse(filter.filterEdge(edge2)); + } + + @Test + public void testAndFilter() { + IFilter filter = new EdgeTsFilter(TimeRange.of(1L, 2L)); + IFilter appendFilter = new InEdgeFilter().and(new EdgeTsFilter(TimeRange.of(1L, 2L))); + filter = filter.and(appendFilter); + Assert.assertEquals(appendFilter.getFilterType(), FilterType.AND); + Assert.assertEquals(filter.toString(), filter.and(EmptyGraphFilter.of()).toString()); + + Assert.assertEquals(((AndFilter) appendFilter).getFilters().size(), 2); + Assert.assertEquals(filter.getFilterType(), FilterType.AND); + Assert.assertEquals(((AndFilter) filter).getFilters().size(), 3); + + filter = new InEdgeFilter().and(new EdgeTsFilter(TimeRange.of(100, 200))); + IEdge edge = new IDTimeEdge<>("hello", "world", 100); + IGraphFilter stateFilter = GraphFilter.of(filter); + Assert.assertFalse(stateFilter.filterEdge(edge)); + + edge.setDirect(EdgeDirection.IN); + Assert.assertTrue(stateFilter.filterEdge(edge)); + Assert.assertFalse(stateFilter.dropAllRemaining()); + + List> list = + Arrays.asList( + new IDTimeEdge<>("hello", "world", 100), + new IDTimeEdge<>("hello", "world", 120), + new IDTimeEdge<>("hello", "world", 140), + new IDTimeEdge<>("hello", "world", 160), + new IDTimeEdge<>("hello", "world", 180)); + + list.get(0).setDirect(EdgeDirection.IN); + list.get(2).setDirect(EdgeDirection.IN); + list.get(4).setDirect(EdgeDirection.IN); + + filter = new InEdgeFilter().and(new EdgeTsFilter(TimeRange.of(80, 141))); + stateFilter = GraphFilter.of(filter); + LOGGER.info("filter {}", stateFilter.toString()); + IGraphFilter graphFilter = GraphFilter.of(filter); + Assert.assertTrue(graphFilter.contains(FilterType.EDGE_TS)); + Assert.assertEquals( + graphFilter.retrieve(FilterType.EDGE_TS).getFilterType(), FilterType.EDGE_TS); + List> res = + list.stream().filter(stateFilter::filterEdge).collect(Collectors.toList()); + Assert.assertEquals(res.size(), 2); + Assert.assertTrue(res.contains(list.get(0))); + Assert.assertTrue(res.contains(list.get(2))); + + graphFilter = GraphFilter.of(filter); + Assert.assertTrue(graphFilter.filterVertex(new IDVertex("1"))); + Assert.assertTrue(graphFilter.filterOneDegreeGraph(new OneDegreeGraph<>("1", null, null))); + } + + @Test + public void testOrFilter() { + IFilter filter = new EdgeTsFilter(TimeRange.of(1L, 2L)); + filter = filter.or(new EdgeTsFilter(TimeRange.of(3L, 4L))); + + Assert.assertEquals(filter.getFilterType(), FilterType.OR); + Assert.assertEquals(filter.or(EmptyGraphFilter.of()).toString(), filter.toString()); + + List> list = + Arrays.asList( + new IDTimeEdge<>("hello", "world", 1), + new IDTimeEdge<>("hello", "world", 3), + new IDTimeEdge<>("hello", "world", 5)); + + IGraphFilter graphFilter = GraphFilter.of(filter); + List> res = new ArrayList<>(); + for (IEdge stringObjectIEdge : list) { + if (graphFilter.filterEdge(stringObjectIEdge)) { + res.add(stringObjectIEdge); + } + } + Assert.assertEquals(res.size(), 2); + + filter = new EdgeTsFilter(TimeRange.of(5L, 6L)).or(filter); + res.clear(); + graphFilter = GraphFilter.of(filter); + for (IEdge stringObjectIEdge : list) { + if (graphFilter.filterEdge(stringObjectIEdge)) { + res.add(stringObjectIEdge); + } + } + Assert.assertEquals(res.size(), 3); + + graphFilter = GraphFilter.of(filter); + Assert.assertTrue(graphFilter.filterVertex(new IDVertex("1"))); + Assert.assertTrue(graphFilter.filterOneDegreeGraph(new OneDegreeGraph<>("1", null, null))); + } + + @Test + public void testParseLabelFilter() { + String label1 = "person"; + String label2 = "trade"; + String label3 = "relation"; + + IGraphFilter filter = GraphFilter.of(new EdgeLabelFilter(label1)); + List labels = FilterHelper.parseLabel(GraphFilter.of(filter), false); + Assert.assertEquals(labels.size(), 1); + Assert.assertEquals(labels.get(0), label1); + + filter = + GraphFilter.of(new EdgeLabelFilter(label1)) .or(GraphFilter.of(new EdgeTsFilter(TimeRange.of(5L, 6L)))); - labels = FilterHelper.parseLabel(filter, false); - Assert.assertEquals(labels.size(), 0); - - filter = GraphFilter.of(new EdgeLabelFilter(label1)) - .or(GraphFilter.of(new EdgeLabelFilter(label2))); - labels = FilterHelper.parseLabel(filter, false); - Assert.assertEquals(labels.size(), 2); - Assert.assertEquals(labels.get(0), label1); - Assert.assertEquals(labels.get(1), label2); - - filter = GraphFilter.of(new EdgeLabelFilter(label1)) + labels = FilterHelper.parseLabel(filter, false); + Assert.assertEquals(labels.size(), 0); + + filter = + GraphFilter.of(new EdgeLabelFilter(label1)).or(GraphFilter.of(new EdgeLabelFilter(label2))); + labels = FilterHelper.parseLabel(filter, false); + Assert.assertEquals(labels.size(), 2); + Assert.assertEquals(labels.get(0), label1); + Assert.assertEquals(labels.get(1), label2); + + filter = + GraphFilter.of(new EdgeLabelFilter(label1)) .and(GraphFilter.of(new EdgeLabelFilter(label2))); - labels = FilterHelper.parseLabel(filter, false); - Assert.assertEquals(labels.size(), 0); + labels = FilterHelper.parseLabel(filter, false); + Assert.assertEquals(labels.size(), 0); - filter = GraphFilter.of(new VertexLabelFilter(label1)) + filter = + GraphFilter.of(new VertexLabelFilter(label1)) .or(GraphFilter.of(new VertexLabelFilter(label2))) .or(GraphFilter.of(new VertexLabelFilter(label3))); - labels = FilterHelper.parseLabel(filter, true); - Assert.assertEquals(labels.size(), 3); - Assert.assertEquals(labels.get(0), label1); - Assert.assertEquals(labels.get(1), label2); - Assert.assertEquals(labels.get(2), label3); - - labels = FilterHelper.parseLabel(filter, false); - Assert.assertEquals(labels.size(), 0); - } - - @Test - public void testParseDtFilter() { - TimeRange range = TimeRange.of(1735660800, 1740758400); - IGraphFilter filter = GraphFilter.of(new EdgeTsFilter(range)); - TimeRange parse_range = FilterHelper.parseDt(filter, false); - Assert.assertEquals(parse_range, range); - parse_range = FilterHelper.parseDt(filter, true); - Assert.assertNull(parse_range); - - filter = GraphFilter.of(new EdgeLabelFilter("person").or( - new EdgeLabelFilter("trade").or(new EdgeTsFilter(range))) - .or(new EdgeLabelFilter("foo"))); - parse_range = FilterHelper.parseDt(filter, false); - Assert.assertEquals(parse_range, range); - parse_range = FilterHelper.parseDt(filter, true); - Assert.assertNull(parse_range); - - filter = GraphFilter.of(new VertexLabelFilter("person").or( - new VertexLabelFilter("trade").or(new VertexTsFilter(range))) - .or(new VertexLabelFilter("foo"))); - parse_range = FilterHelper.parseDt(filter, true); - Assert.assertEquals(parse_range, range); - parse_range = FilterHelper.parseDt(filter, false); - Assert.assertNull(parse_range); - } + labels = FilterHelper.parseLabel(filter, true); + Assert.assertEquals(labels.size(), 3); + Assert.assertEquals(labels.get(0), label1); + Assert.assertEquals(labels.get(1), label2); + Assert.assertEquals(labels.get(2), label3); + + labels = FilterHelper.parseLabel(filter, false); + Assert.assertEquals(labels.size(), 0); + } + + @Test + public void testParseDtFilter() { + TimeRange range = TimeRange.of(1735660800, 1740758400); + IGraphFilter filter = GraphFilter.of(new EdgeTsFilter(range)); + TimeRange parse_range = FilterHelper.parseDt(filter, false); + Assert.assertEquals(parse_range, range); + parse_range = FilterHelper.parseDt(filter, true); + Assert.assertNull(parse_range); + + filter = + GraphFilter.of( + new EdgeLabelFilter("person") + .or(new EdgeLabelFilter("trade").or(new EdgeTsFilter(range))) + .or(new EdgeLabelFilter("foo"))); + parse_range = FilterHelper.parseDt(filter, false); + Assert.assertEquals(parse_range, range); + parse_range = FilterHelper.parseDt(filter, true); + Assert.assertNull(parse_range); + + filter = + GraphFilter.of( + new VertexLabelFilter("person") + .or(new VertexLabelFilter("trade").or(new VertexTsFilter(range))) + .or(new VertexLabelFilter("foo"))); + parse_range = FilterHelper.parseDt(filter, true); + Assert.assertEquals(parse_range, range); + parse_range = FilterHelper.parseDt(filter, false); + Assert.assertNull(parse_range); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/graph/encoder/BytesEncoderTest.java b/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/graph/encoder/BytesEncoderTest.java index 84bdb6f6e..41f064c1a 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/graph/encoder/BytesEncoderTest.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/graph/encoder/BytesEncoderTest.java @@ -21,33 +21,34 @@ import java.util.ArrayList; import java.util.List; + import org.apache.geaflow.common.config.keys.StateConfigKeys; import org.testng.Assert; import org.testng.annotations.Test; public class BytesEncoderTest { - @Test - public void test() { - byte a = 0x3; - IBytesEncoder encoder = new DefaultBytesEncoder(); - List list = new ArrayList<>(); - list.add("a".getBytes()); - list.add("b".getBytes()); - byte[] combined = encoder.combine(list); - byte magic = encoder.parseMagicNumber(combined[combined.length - 1]); - Assert.assertEquals(magic, encoder.getMyMagicNumber()); - - List res = encoder.split(combined); - Assert.assertEquals(res.get(0), list.get(0)); - Assert.assertEquals(res.get(1), list.get(1)); - - combined = encoder.combine(list, StateConfigKeys.DELIMITER); - magic = encoder.parseMagicNumber(combined[combined.length - 1]); - Assert.assertEquals(magic, encoder.getMyMagicNumber()); - - res = encoder.split(combined, StateConfigKeys.DELIMITER); - Assert.assertEquals(res.get(0), list.get(0)); - Assert.assertEquals(res.get(1), list.get(1)); - } + @Test + public void test() { + byte a = 0x3; + IBytesEncoder encoder = new DefaultBytesEncoder(); + List list = new ArrayList<>(); + list.add("a".getBytes()); + list.add("b".getBytes()); + byte[] combined = encoder.combine(list); + byte magic = encoder.parseMagicNumber(combined[combined.length - 1]); + Assert.assertEquals(magic, encoder.getMyMagicNumber()); + + List res = encoder.split(combined); + Assert.assertEquals(res.get(0), list.get(0)); + Assert.assertEquals(res.get(1), list.get(1)); + + combined = encoder.combine(list, StateConfigKeys.DELIMITER); + magic = encoder.parseMagicNumber(combined[combined.length - 1]); + Assert.assertEquals(magic, encoder.getMyMagicNumber()); + + res = encoder.split(combined, StateConfigKeys.DELIMITER); + Assert.assertEquals(res.get(0), list.get(0)); + Assert.assertEquals(res.get(1), list.get(1)); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/iterator/IteratorWithFilterTest.java b/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/iterator/IteratorWithFilterTest.java index fdfa64921..18073f97d 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/iterator/IteratorWithFilterTest.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/iterator/IteratorWithFilterTest.java @@ -19,28 +19,28 @@ package org.apache.geaflow.state.iterator; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Arrays; + import org.testng.Assert; import org.testng.annotations.Test; -public class IteratorWithFilterTest { +import com.google.common.collect.Lists; - @Test - public void test() { - IteratorWithFilter it = new IteratorWithFilter<>(Arrays.asList(1, 2, 3, 4, 5).iterator(), - o -> o > 3); +public class IteratorWithFilterTest { - ArrayList list = Lists.newArrayList(it); - Assert.assertEquals(list.size(), 2); - Assert.assertTrue(list.contains(4)); - Assert.assertTrue(list.contains(5)); + @Test + public void test() { + IteratorWithFilter it = + new IteratorWithFilter<>(Arrays.asList(1, 2, 3, 4, 5).iterator(), o -> o > 3); - it = new IteratorWithFilter<>(Arrays.asList(1, 2, 3, 4, 5).iterator(), - o -> o < 3); - list = Lists.newArrayList(it); - Assert.assertEquals(list.size(), 2); - } + ArrayList list = Lists.newArrayList(it); + Assert.assertEquals(list.size(), 2); + Assert.assertTrue(list.contains(4)); + Assert.assertTrue(list.contains(5)); + it = new IteratorWithFilter<>(Arrays.asList(1, 2, 3, 4, 5).iterator(), o -> o < 3); + list = Lists.newArrayList(it); + Assert.assertEquals(list.size(), 2); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/iterator/IteratorWithFilterThenFnTest.java b/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/iterator/IteratorWithFilterThenFnTest.java index 5a7c2268a..1f64e13c0 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/iterator/IteratorWithFilterThenFnTest.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/iterator/IteratorWithFilterThenFnTest.java @@ -19,29 +19,32 @@ package org.apache.geaflow.state.iterator; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Arrays; import java.util.Objects; + import org.testng.Assert; import org.testng.annotations.Test; +import com.google.common.collect.Lists; + public class IteratorWithFilterThenFnTest { - @Test - public void testFunction() { - IteratorWithFilterThenFn it = - new IteratorWithFilterThenFn<>(Arrays.asList(1, 2, 3, 4, 5).iterator(), - o -> o > 3, Object::toString); - - ArrayList list = Lists.newArrayList(it); - Assert.assertEquals(list.size(), 2); - Assert.assertTrue(list.contains("4")); - Assert.assertTrue(list.contains("5")); - - it = new IteratorWithFilterThenFn<>(Arrays.asList(1, 2, 3, 4, 5).iterator(), - o -> o <= 3, Objects::toString); - list = Lists.newArrayList(it); - Assert.assertEquals(list.size(), 3); - } -} \ No newline at end of file + @Test + public void testFunction() { + IteratorWithFilterThenFn it = + new IteratorWithFilterThenFn<>( + Arrays.asList(1, 2, 3, 4, 5).iterator(), o -> o > 3, Object::toString); + + ArrayList list = Lists.newArrayList(it); + Assert.assertEquals(list.size(), 2); + Assert.assertTrue(list.contains("4")); + Assert.assertTrue(list.contains("5")); + + it = + new IteratorWithFilterThenFn<>( + Arrays.asList(1, 2, 3, 4, 5).iterator(), o -> o <= 3, Objects::toString); + list = Lists.newArrayList(it); + Assert.assertEquals(list.size(), 3); + } +} diff --git a/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/iterator/MultiIteratorTest.java b/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/iterator/MultiIteratorTest.java index 319a03a04..4344b6b12 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/iterator/MultiIteratorTest.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/iterator/MultiIteratorTest.java @@ -19,29 +19,27 @@ package org.apache.geaflow.state.iterator; -import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.Arrays; import java.util.List; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.testng.Assert; import org.testng.annotations.Test; -public class MultiIteratorTest { +import com.google.common.collect.Lists; - @Test - public void test() { - List> list = new ArrayList<>(); - list.add(new IteratorWithFilter<>(Arrays.asList(1, 2, 3, 4, 5).iterator(), - o -> o > 3)); - list.add(new IteratorWithFilter<>(Arrays.asList(1, 2, 3, 4, 5).iterator(), - o -> o > 6)); - list.add(new IteratorWithFilter<>(Arrays.asList(1, 2, 3, 4, 5).iterator(), - o -> o > 2)); +public class MultiIteratorTest { - MultiIterator it = new MultiIterator<>(list.iterator()); - List res = Lists.newArrayList(it); - Assert.assertEquals(res, Arrays.asList(4, 5, 3, 4, 5)); - } + @Test + public void test() { + List> list = new ArrayList<>(); + list.add(new IteratorWithFilter<>(Arrays.asList(1, 2, 3, 4, 5).iterator(), o -> o > 3)); + list.add(new IteratorWithFilter<>(Arrays.asList(1, 2, 3, 4, 5).iterator(), o -> o > 6)); + list.add(new IteratorWithFilter<>(Arrays.asList(1, 2, 3, 4, 5).iterator(), o -> o > 2)); + MultiIterator it = new MultiIterator<>(list.iterator()); + List res = Lists.newArrayList(it); + Assert.assertEquals(res, Arrays.asList(4, 5, 3, 4, 5)); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/pushdown/inner/CodeGenFilterConverterTest.java b/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/pushdown/inner/CodeGenFilterConverterTest.java index ff2a3940f..7d5eb7bd1 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/pushdown/inner/CodeGenFilterConverterTest.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/pushdown/inner/CodeGenFilterConverterTest.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; + import org.apache.geaflow.model.graph.edge.EdgeDirection; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.edge.impl.ValueLabelTimeEdge; @@ -47,88 +48,101 @@ public class CodeGenFilterConverterTest { - private IFilter simpleFilter; - private IFilter middleFilter; - private IFilter complexFilter; - private IGraphFilter simpleGraphFilter; - private IGraphFilter middleGraphFilter; - private IGraphFilter complexGraphFilter; - private List> edges = new ArrayList<>(); - private List> vertices = new ArrayList<>(); - private IFilterConverter converter; - - @BeforeClass - public void setUp() { - simpleFilter = new EdgeLabelFilter(Arrays.asList("label1", "你好", "label2")) + private IFilter simpleFilter; + private IFilter middleFilter; + private IFilter complexFilter; + private IGraphFilter simpleGraphFilter; + private IGraphFilter middleGraphFilter; + private IGraphFilter complexGraphFilter; + private List> edges = new ArrayList<>(); + private List> vertices = new ArrayList<>(); + private IFilterConverter converter; + + @BeforeClass + public void setUp() { + simpleFilter = + new EdgeLabelFilter(Arrays.asList("label1", "你好", "label2")) .and(new VertexTsFilter(TimeRange.of(100, 1000))); - middleFilter = new VertexTsFilter(TimeRange.of(100, 1000)) - .or(new VertexTsFilter(TimeRange.of(10, 100)).and(new VertexLabelFilter(Arrays.asList("label2")))) + middleFilter = + new VertexTsFilter(TimeRange.of(100, 1000)) + .or( + new VertexTsFilter(TimeRange.of(10, 100)) + .and(new VertexLabelFilter(Arrays.asList("label2")))) .or(new VertexLabelFilter(Arrays.asList("label3"))); - List list = new ArrayList<>(); - for (int i = 0; i < 100; i++) { - list.add(new AndFilter(i / 2 == 0 ? InEdgeFilter.getInstance() : OutEdgeFilter.getInstance(), - new EdgeLabelFilter("label" + i)).and(new EdgeTsFilter(TimeRange.of(10, 100)))); - } - complexFilter = new OrFilter(list); - - simpleGraphFilter = - GraphFilter.of(new EdgeLabelFilter(Arrays.asList("label1", "label2"))).and(GraphFilter.of(new VertexTsFilter( - TimeRange.of(100, 1000)))); - middleGraphFilter = - GraphFilter.of(new VertexTsFilter(TimeRange.of(100, 1000))) - .or(GraphFilter.of(new VertexTsFilter(TimeRange.of(10, 100))).and(GraphFilter.of(new VertexLabelFilter(Arrays.asList("label2"))))) - .or(GraphFilter.of(new VertexLabelFilter(Arrays.asList("label3")))); - - List graphFilters = new ArrayList<>(); - for (int i = 0; i < 100; i++) { - graphFilters.add(new AndGraphFilter(Arrays.asList(GraphFilter.of(i / 2 == 0 ? InEdgeFilter.getInstance() : OutEdgeFilter.getInstance()), - GraphFilter.of(new EdgeLabelFilter("label" + i)), - GraphFilter.of(new EdgeTsFilter(TimeRange.of(10, 100)))))); - } - complexGraphFilter = new OrGraphFilter(graphFilters); - for (int i = 0; i < 10000000; i++) { - this.edges.add(new ValueLabelTimeEdge<>(0, i, 0, - i / 3 == 0 ? EdgeDirection.IN : EdgeDirection.OUT, - "label" + i % 6, i)); - } - for (int i = 0; i < 10000000; i++) { - this.vertices.add(new ValueLabelTimeVertex<>(i, 0, - "label" + i % 6, i)); - } - - converter = new CodeGenFilterConverter(); + List list = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + list.add( + new AndFilter( + i / 2 == 0 ? InEdgeFilter.getInstance() : OutEdgeFilter.getInstance(), + new EdgeLabelFilter("label" + i)) + .and(new EdgeTsFilter(TimeRange.of(10, 100)))); + } + complexFilter = new OrFilter(list); + + simpleGraphFilter = + GraphFilter.of(new EdgeLabelFilter(Arrays.asList("label1", "label2"))) + .and(GraphFilter.of(new VertexTsFilter(TimeRange.of(100, 1000)))); + middleGraphFilter = + GraphFilter.of(new VertexTsFilter(TimeRange.of(100, 1000))) + .or( + GraphFilter.of(new VertexTsFilter(TimeRange.of(10, 100))) + .and(GraphFilter.of(new VertexLabelFilter(Arrays.asList("label2"))))) + .or(GraphFilter.of(new VertexLabelFilter(Arrays.asList("label3")))); + + List graphFilters = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + graphFilters.add( + new AndGraphFilter( + Arrays.asList( + GraphFilter.of( + i / 2 == 0 ? InEdgeFilter.getInstance() : OutEdgeFilter.getInstance()), + GraphFilter.of(new EdgeLabelFilter("label" + i)), + GraphFilter.of(new EdgeTsFilter(TimeRange.of(10, 100)))))); + } + complexGraphFilter = new OrGraphFilter(graphFilters); + for (int i = 0; i < 10000000; i++) { + this.edges.add( + new ValueLabelTimeEdge<>( + 0, i, 0, i / 3 == 0 ? EdgeDirection.IN : EdgeDirection.OUT, "label" + i % 6, i)); + } + for (int i = 0; i < 10000000; i++) { + this.vertices.add(new ValueLabelTimeVertex<>(i, 0, "label" + i % 6, i)); } - @Test - public void testSimple() throws Exception { - IGraphFilter genFilter = (IGraphFilter) converter.convert(simpleFilter); + converter = new CodeGenFilterConverter(); + } - long vNumber = vertices.stream().filter(genFilter::filterVertex).count(); - long eNumber = edges.stream().filter(genFilter::filterEdge).count(); + @Test + public void testSimple() throws Exception { + IGraphFilter genFilter = (IGraphFilter) converter.convert(simpleFilter); - Assert.assertEquals(vNumber, vertices.stream().filter(simpleGraphFilter::filterVertex).count()); - Assert.assertEquals(eNumber, edges.stream().filter(simpleGraphFilter::filterEdge).count()); - } + long vNumber = vertices.stream().filter(genFilter::filterVertex).count(); + long eNumber = edges.stream().filter(genFilter::filterEdge).count(); - @Test - public void testMiddle() throws Exception { - IGraphFilter genFilter = (IGraphFilter) converter.convert(middleFilter); + Assert.assertEquals(vNumber, vertices.stream().filter(simpleGraphFilter::filterVertex).count()); + Assert.assertEquals(eNumber, edges.stream().filter(simpleGraphFilter::filterEdge).count()); + } - long vNumber = vertices.stream().filter(genFilter::filterVertex).count(); - long eNumber = edges.stream().filter(genFilter::filterEdge).count(); + @Test + public void testMiddle() throws Exception { + IGraphFilter genFilter = (IGraphFilter) converter.convert(middleFilter); - Assert.assertEquals(vNumber, vertices.stream().filter(middleGraphFilter::filterVertex).count()); - Assert.assertEquals(eNumber, edges.stream().filter(middleGraphFilter::filterEdge).count()); - } + long vNumber = vertices.stream().filter(genFilter::filterVertex).count(); + long eNumber = edges.stream().filter(genFilter::filterEdge).count(); - @Test - public void testComplex() throws Exception { - IGraphFilter genFilter = (IGraphFilter) converter.convert(complexFilter); + Assert.assertEquals(vNumber, vertices.stream().filter(middleGraphFilter::filterVertex).count()); + Assert.assertEquals(eNumber, edges.stream().filter(middleGraphFilter::filterEdge).count()); + } - long vNumber = vertices.stream().filter(genFilter::filterVertex).count(); - long eNumber = edges.stream().filter(genFilter::filterEdge).count(); + @Test + public void testComplex() throws Exception { + IGraphFilter genFilter = (IGraphFilter) converter.convert(complexFilter); - Assert.assertEquals(vNumber, vertices.stream().filter(complexGraphFilter::filterVertex).count()); - Assert.assertEquals(eNumber, edges.stream().filter(complexGraphFilter::filterEdge).count()); - } + long vNumber = vertices.stream().filter(genFilter::filterVertex).count(); + long eNumber = edges.stream().filter(genFilter::filterEdge).count(); + + Assert.assertEquals( + vNumber, vertices.stream().filter(complexGraphFilter::filterVertex).count()); + Assert.assertEquals(eNumber, edges.stream().filter(complexGraphFilter::filterEdge).count()); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/pushdown/inner/PushDownCodeGenJMH.java b/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/pushdown/inner/PushDownCodeGenJMH.java index a446e1552..8bde16389 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/pushdown/inner/PushDownCodeGenJMH.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/pushdown/inner/PushDownCodeGenJMH.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; + import org.apache.geaflow.model.graph.edge.EdgeDirection; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.edge.impl.ValueLabelTimeEdge; @@ -71,128 +72,143 @@ @State(Scope.Benchmark) public class PushDownCodeGenJMH { - private IFilter simpleFilter; - private IFilter middleFilter; - private IFilter complexFilter; - private IGraphFilter simpleGraphFilter; - private IGraphFilter middleGraphFilter; - private IGraphFilter complexGraphFilter; + private IFilter simpleFilter; + private IFilter middleFilter; + private IFilter complexFilter; + private IGraphFilter simpleGraphFilter; + private IGraphFilter middleGraphFilter; + private IGraphFilter complexGraphFilter; - @Param({"500000", "1000000", "3000000"}) - private int vertexAndEdgeNum; + @Param({"500000", "1000000", "3000000"}) + private int vertexAndEdgeNum; - private List> edges = new ArrayList<>(); - private List> vertices = new ArrayList<>(); - private IFilterConverter converter; + private List> edges = new ArrayList<>(); + private List> vertices = new ArrayList<>(); + private IFilterConverter converter; - @Setup - public void setUp() { - simpleFilter = new EdgeLabelFilter(Arrays.asList("label1", "label2")) + @Setup + public void setUp() { + simpleFilter = + new EdgeLabelFilter(Arrays.asList("label1", "label2")) .and(new VertexTsFilter(TimeRange.of(100, 1000))); - middleFilter = new VertexTsFilter(TimeRange.of(100, 1000)) - .or(new VertexTsFilter(TimeRange.of(10, 100)).and(new VertexLabelFilter(Arrays.asList("label2")))) + middleFilter = + new VertexTsFilter(TimeRange.of(100, 1000)) + .or( + new VertexTsFilter(TimeRange.of(10, 100)) + .and(new VertexLabelFilter(Arrays.asList("label2")))) .or(new VertexLabelFilter(Arrays.asList("label3"))); - List list = new ArrayList<>(); - for (int i = 0; i < 100; i++) { - list.add(new AndFilter(i / 2 == 0 ? InEdgeFilter.getInstance() : OutEdgeFilter.getInstance(), - new EdgeLabelFilter("label" + i % 5)).and(new EdgeTsFilter(TimeRange.of(10, 100)))); - } - complexFilter = new OrFilter(list); - - simpleGraphFilter = - GraphFilter.of(new EdgeLabelFilter(Arrays.asList("label1", "label2"))).and(GraphFilter.of(new VertexTsFilter( - TimeRange.of(100, 1000)))); - middleGraphFilter = - GraphFilter.of(new VertexTsFilter(TimeRange.of(100, 1000))) - .or(GraphFilter.of(new VertexTsFilter(TimeRange.of(10, 100))).and(GraphFilter.of(new VertexLabelFilter(Arrays.asList("label2"))))) - .or(GraphFilter.of(new VertexLabelFilter(Arrays.asList("label3")))); - - List graphFilters = new ArrayList<>(); - for (int i = 0; i < 100; i++) { - graphFilters.add(new AndGraphFilter(Arrays.asList(GraphFilter.of(i / 2 == 0 ? InEdgeFilter.getInstance() : OutEdgeFilter.getInstance()), - GraphFilter.of(new EdgeLabelFilter("label" + i)), - GraphFilter.of(new EdgeTsFilter(TimeRange.of(10, 100)))))); - } - complexGraphFilter = new OrGraphFilter(graphFilters); - - for (int i = 0; i < vertexAndEdgeNum; i++) { - this.edges.add(new ValueLabelTimeEdge<>(0, i, 0, - i / 3 == 0 ? EdgeDirection.IN : EdgeDirection.OUT, - "label" + i % 6, i)); - } - for (int i = 0; i < vertexAndEdgeNum; i++) { - this.vertices.add(new ValueLabelTimeVertex<>(i, 0, - "label" + i % 6, i)); - } - converter = new CodeGenFilterConverter(); + List list = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + list.add( + new AndFilter( + i / 2 == 0 ? InEdgeFilter.getInstance() : OutEdgeFilter.getInstance(), + new EdgeLabelFilter("label" + i % 5)) + .and(new EdgeTsFilter(TimeRange.of(10, 100)))); } - - @Benchmark - public void simpleCodeGenFilter(Blackhole blackhole) throws Exception { - IGraphFilter genFilter = (IGraphFilter) converter.convert(simpleFilter); - for (IEdge edge : edges) { - genFilter.filterEdge(edge); - } - for (IVertex vertex : vertices) { - genFilter.filterVertex(vertex); - } + complexFilter = new OrFilter(list); + + simpleGraphFilter = + GraphFilter.of(new EdgeLabelFilter(Arrays.asList("label1", "label2"))) + .and(GraphFilter.of(new VertexTsFilter(TimeRange.of(100, 1000)))); + middleGraphFilter = + GraphFilter.of(new VertexTsFilter(TimeRange.of(100, 1000))) + .or( + GraphFilter.of(new VertexTsFilter(TimeRange.of(10, 100))) + .and(GraphFilter.of(new VertexLabelFilter(Arrays.asList("label2"))))) + .or(GraphFilter.of(new VertexLabelFilter(Arrays.asList("label3")))); + + List graphFilters = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + graphFilters.add( + new AndGraphFilter( + Arrays.asList( + GraphFilter.of( + i / 2 == 0 ? InEdgeFilter.getInstance() : OutEdgeFilter.getInstance()), + GraphFilter.of(new EdgeLabelFilter("label" + i)), + GraphFilter.of(new EdgeTsFilter(TimeRange.of(10, 100)))))); } + complexGraphFilter = new OrGraphFilter(graphFilters); - @Benchmark - public void simpleGraphFilter(Blackhole blackhole) throws Exception { - IGraphFilter genFilter = simpleGraphFilter; - for (IEdge edge : edges) { - genFilter.filterEdge(edge); - } - for (IVertex vertex : vertices) { - genFilter.filterVertex(vertex); - } + for (int i = 0; i < vertexAndEdgeNum; i++) { + this.edges.add( + new ValueLabelTimeEdge<>( + 0, i, 0, i / 3 == 0 ? EdgeDirection.IN : EdgeDirection.OUT, "label" + i % 6, i)); } - - @Benchmark - public void middleCodeGenFilter(Blackhole blackhole) throws Exception { - IGraphFilter genFilter = (IGraphFilter) converter.convert(middleFilter); - for (IEdge edge : edges) { - genFilter.filterEdge(edge); - } - for (IVertex vertex : vertices) { - genFilter.filterVertex(vertex); - } + for (int i = 0; i < vertexAndEdgeNum; i++) { + this.vertices.add(new ValueLabelTimeVertex<>(i, 0, "label" + i % 6, i)); } - - @Benchmark - public void middleGraphFilter(Blackhole blackhole) throws Exception { - IGraphFilter genFilter = middleGraphFilter; - for (IEdge edge : edges) { - genFilter.filterEdge(edge); - } - for (IVertex vertex : vertices) { - genFilter.filterVertex(vertex); - } + converter = new CodeGenFilterConverter(); + } + + @Benchmark + public void simpleCodeGenFilter(Blackhole blackhole) throws Exception { + IGraphFilter genFilter = (IGraphFilter) converter.convert(simpleFilter); + for (IEdge edge : edges) { + genFilter.filterEdge(edge); } - - @Benchmark - public void complexCodeGenFilter(Blackhole blackhole) throws Exception { - IGraphFilter genFilter = (IGraphFilter) converter.convert(complexFilter); - blackhole.consume(edges.stream().filter(genFilter::filterEdge).collect(Collectors.toList())); - blackhole.consume(vertices.stream().filter(genFilter::filterVertex).collect(Collectors.toList())); + for (IVertex vertex : vertices) { + genFilter.filterVertex(vertex); } + } - @Benchmark - public void complexGraphFilter(Blackhole blackhole) throws Exception { - IGraphFilter genFilter = complexGraphFilter; - blackhole.consume(edges.stream().filter(genFilter::filterEdge).collect(Collectors.toList())); - blackhole.consume(vertices.stream().filter(genFilter::filterVertex).collect(Collectors.toList())); + @Benchmark + public void simpleGraphFilter(Blackhole blackhole) throws Exception { + IGraphFilter genFilter = simpleGraphFilter; + for (IEdge edge : edges) { + genFilter.filterEdge(edge); + } + for (IVertex vertex : vertices) { + genFilter.filterVertex(vertex); } + } - public static void main(String[] args) throws RunnerException { + @Benchmark + public void middleCodeGenFilter(Blackhole blackhole) throws Exception { + IGraphFilter genFilter = (IGraphFilter) converter.convert(middleFilter); + for (IEdge edge : edges) { + genFilter.filterEdge(edge); + } + for (IVertex vertex : vertices) { + genFilter.filterVertex(vertex); + } + } - Options opt = new OptionsBuilder() + @Benchmark + public void middleGraphFilter(Blackhole blackhole) throws Exception { + IGraphFilter genFilter = middleGraphFilter; + for (IEdge edge : edges) { + genFilter.filterEdge(edge); + } + for (IVertex vertex : vertices) { + genFilter.filterVertex(vertex); + } + } + + @Benchmark + public void complexCodeGenFilter(Blackhole blackhole) throws Exception { + IGraphFilter genFilter = (IGraphFilter) converter.convert(complexFilter); + blackhole.consume(edges.stream().filter(genFilter::filterEdge).collect(Collectors.toList())); + blackhole.consume( + vertices.stream().filter(genFilter::filterVertex).collect(Collectors.toList())); + } + + @Benchmark + public void complexGraphFilter(Blackhole blackhole) throws Exception { + IGraphFilter genFilter = complexGraphFilter; + blackhole.consume(edges.stream().filter(genFilter::filterEdge).collect(Collectors.toList())); + blackhole.consume( + vertices.stream().filter(genFilter::filterVertex).collect(Collectors.toList())); + } + + public static void main(String[] args) throws RunnerException { + + Options opt = + new OptionsBuilder() // import test class. .include(PushDownCodeGenJMH.class.getSimpleName()) .resultFormat(ResultFormatType.JSON) .result("allocation.json") .build(); - new Runner(opt).run(); - } + new Runner(opt).run(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/schema/GraphDataSchemaTest.java b/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/schema/GraphDataSchemaTest.java index 6064adbec..9b0261ca1 100644 --- a/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/schema/GraphDataSchemaTest.java +++ b/geaflow/geaflow-state/geaflow-state-common/src/test/java/org/apache/geaflow/state/schema/GraphDataSchemaTest.java @@ -36,63 +36,79 @@ public class GraphDataSchemaTest { - @Test - public void test() throws Exception { - GraphMeta graphMeta = new GraphMeta(new GraphMetaType(Types.STRING, - ValueTimeVertex.class, ValueTimeVertex::new, String.class, - ValueLabelEdge.class, ValueLabelEdge::new, ClassProperty.class)); - - GraphDataSchema graphDataSchema = new GraphDataSchema(graphMeta); - Assert.assertEquals(graphDataSchema.getKeyType(), Types.STRING); - Assert.assertFalse(graphDataSchema.isEmptyEdgeProperty()); - Assert.assertFalse(graphDataSchema.isEmptyVertexProperty()); - - IVertex vertex = graphDataSchema.getVertexConsFun().get(); - Assert.assertEquals(vertex.getClass(), graphMeta.getVertexMeta().getGraphElementClass()); - IEdge edge = graphDataSchema.getEdgeConsFun().get(); - Assert.assertEquals(edge.getClass(), graphMeta.getEdgeMeta().getGraphElementClass()); - - Assert.assertTrue(graphDataSchema.getVertexAtoms().contains(VertexAtom.TIME)); - Assert.assertTrue(graphDataSchema.getEdgeAtoms().contains(EdgeAtom.LABEL)); - - byte[] bytes = StringType.INSTANCE.serialize("foobar"); - Assert.assertEquals(graphDataSchema.getVertexPropertySerFun().apply("foobar"), bytes); - Assert.assertEquals(graphDataSchema.getVertexPropertyDeFun().apply(bytes), "foobar"); - - Assert.assertEquals(graphDataSchema.getEdgePropertySerFun().apply(new ClassProperty()), new byte[0]); - Assert.assertEquals(graphDataSchema.getEdgePropertyDeFun().apply(new byte[0]), property); - - GraphElementMetas.clearCache(); - graphMeta = new GraphMeta(new GraphMetaType(Types.STRING, - ValueTimeVertex.class, ValueTimeVertex::new, ClassProperty.class, - ValueLabelEdge.class, ValueLabelEdge::new, String.class)); - graphDataSchema = new GraphDataSchema(graphMeta); - - bytes = StringType.INSTANCE.serialize("foobar"); - Assert.assertEquals(graphDataSchema.getEdgePropertySerFun().apply("foobar"), bytes); - Assert.assertEquals(graphDataSchema.getEdgePropertyDeFun().apply(bytes), "foobar"); - - Assert.assertEquals(graphDataSchema.getVertexPropertySerFun().apply(new ClassProperty()), new byte[0]); - Assert.assertEquals(graphDataSchema.getVertexPropertyDeFun().apply(new byte[0]), property); + @Test + public void test() throws Exception { + GraphMeta graphMeta = + new GraphMeta( + new GraphMetaType( + Types.STRING, + ValueTimeVertex.class, + ValueTimeVertex::new, + String.class, + ValueLabelEdge.class, + ValueLabelEdge::new, + ClassProperty.class)); + + GraphDataSchema graphDataSchema = new GraphDataSchema(graphMeta); + Assert.assertEquals(graphDataSchema.getKeyType(), Types.STRING); + Assert.assertFalse(graphDataSchema.isEmptyEdgeProperty()); + Assert.assertFalse(graphDataSchema.isEmptyVertexProperty()); + + IVertex vertex = graphDataSchema.getVertexConsFun().get(); + Assert.assertEquals(vertex.getClass(), graphMeta.getVertexMeta().getGraphElementClass()); + IEdge edge = graphDataSchema.getEdgeConsFun().get(); + Assert.assertEquals(edge.getClass(), graphMeta.getEdgeMeta().getGraphElementClass()); + + Assert.assertTrue(graphDataSchema.getVertexAtoms().contains(VertexAtom.TIME)); + Assert.assertTrue(graphDataSchema.getEdgeAtoms().contains(EdgeAtom.LABEL)); + + byte[] bytes = StringType.INSTANCE.serialize("foobar"); + Assert.assertEquals(graphDataSchema.getVertexPropertySerFun().apply("foobar"), bytes); + Assert.assertEquals(graphDataSchema.getVertexPropertyDeFun().apply(bytes), "foobar"); + + Assert.assertEquals( + graphDataSchema.getEdgePropertySerFun().apply(new ClassProperty()), new byte[0]); + Assert.assertEquals(graphDataSchema.getEdgePropertyDeFun().apply(new byte[0]), property); + + GraphElementMetas.clearCache(); + graphMeta = + new GraphMeta( + new GraphMetaType( + Types.STRING, + ValueTimeVertex.class, + ValueTimeVertex::new, + ClassProperty.class, + ValueLabelEdge.class, + ValueLabelEdge::new, + String.class)); + graphDataSchema = new GraphDataSchema(graphMeta); + + bytes = StringType.INSTANCE.serialize("foobar"); + Assert.assertEquals(graphDataSchema.getEdgePropertySerFun().apply("foobar"), bytes); + Assert.assertEquals(graphDataSchema.getEdgePropertyDeFun().apply(bytes), "foobar"); + + Assert.assertEquals( + graphDataSchema.getVertexPropertySerFun().apply(new ClassProperty()), new byte[0]); + Assert.assertEquals(graphDataSchema.getVertexPropertyDeFun().apply(new byte[0]), property); + } + + private static final ClassProperty property = new ClassProperty(); + + public static final class ClassProperty implements IPropertySerializable { + + @Override + public IPropertySerializable fromBinary(byte[] bytes) { + return property; } - private static final ClassProperty property = new ClassProperty(); - - public static final class ClassProperty implements IPropertySerializable { - - @Override - public IPropertySerializable fromBinary(byte[] bytes) { - return property; - } - - @Override - public byte[] toBytes() { - return new byte[0]; - } + @Override + public byte[] toBytes() { + return new byte[0]; + } - @Override - public IPropertySerializable clone() { - return property; - } + @Override + public IPropertySerializable clone() { + return property; } + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/BaseDynamicQueryState.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/BaseDynamicQueryState.java index eb592d785..e217f1c43 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/BaseDynamicQueryState.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/BaseDynamicQueryState.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.List; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.state.query.QueryCondition; import org.apache.geaflow.state.query.QueryType; @@ -34,64 +35,64 @@ import org.apache.geaflow.state.strategy.manager.IGraphManager; import org.apache.geaflow.utils.keygroup.KeyGroup; -public abstract class BaseDynamicQueryState implements - DynamicQueryableState { +public abstract class BaseDynamicQueryState + implements DynamicQueryableState { - protected final QueryType queryType; - protected final IGraphManager graphManager; + protected final QueryType queryType; + protected final IGraphManager graphManager; - public BaseDynamicQueryState(QueryType queryType, IGraphManager graphManager) { - this.queryType = queryType; - this.graphManager = graphManager; - } + public BaseDynamicQueryState(QueryType queryType, IGraphManager graphManager) { + this.queryType = queryType; + this.graphManager = graphManager; + } - @Override - public QueryableAllGraphState query(long version) { - QueryCondition queryCondition = new QueryCondition<>(); - queryCondition.queryIds = null; - queryCondition.isFullScan = true; - return new QueryableAllGraphStateImpl<>(version, queryType, graphManager, queryCondition); - } + @Override + public QueryableAllGraphState query(long version) { + QueryCondition queryCondition = new QueryCondition<>(); + queryCondition.queryIds = null; + queryCondition.isFullScan = true; + return new QueryableAllGraphStateImpl<>(version, queryType, graphManager, queryCondition); + } - @Override - public QueryableAllGraphState query(long version, KeyGroup keyGroup) { - QueryCondition queryCondition = new QueryCondition<>(); - queryCondition.keyGroup = keyGroup; - queryCondition.queryIds = null; - queryCondition.isFullScan = true; - return new QueryableAllGraphStateImpl<>(version, queryType, graphManager, queryCondition); - } + @Override + public QueryableAllGraphState query(long version, KeyGroup keyGroup) { + QueryCondition queryCondition = new QueryCondition<>(); + queryCondition.keyGroup = keyGroup; + queryCondition.queryIds = null; + queryCondition.isFullScan = true; + return new QueryableAllGraphStateImpl<>(version, queryType, graphManager, queryCondition); + } - @Override - public QueryableKeysGraphState query(long version, K... ids) { - return query(version, Arrays.asList(ids)); - } + @Override + public QueryableKeysGraphState query(long version, K... ids) { + return query(version, Arrays.asList(ids)); + } - @Override - public QueryableKeysGraphState query(long version, List ids) { - QueryCondition queryCondition = new QueryCondition<>(); - queryCondition.queryIds = ids; - queryCondition.isFullScan = false; - return new QueryableKeysGraphStateImpl<>(version, queryType, graphManager, queryCondition); - } + @Override + public QueryableKeysGraphState query(long version, List ids) { + QueryCondition queryCondition = new QueryCondition<>(); + queryCondition.queryIds = ids; + queryCondition.isFullScan = false; + return new QueryableKeysGraphStateImpl<>(version, queryType, graphManager, queryCondition); + } - @Override - public QueryableVersionGraphState query(K id) { - QueryCondition queryCondition = new QueryCondition<>(); - queryCondition.queryIds = Arrays.asList(id); - return new QueryableVersionGraphStateImpl<>(queryType, graphManager, queryCondition); - } + @Override + public QueryableVersionGraphState query(K id) { + QueryCondition queryCondition = new QueryCondition<>(); + queryCondition.queryIds = Arrays.asList(id); + return new QueryableVersionGraphStateImpl<>(queryType, graphManager, queryCondition); + } - @Override - public QueryableVersionGraphState query(K id, Collection versions) { - QueryCondition queryCondition = new QueryCondition<>(); - queryCondition.queryIds = Arrays.asList(id); - queryCondition.versions = versions; - return new QueryableVersionGraphStateImpl<>(queryType, graphManager, queryCondition); - } + @Override + public QueryableVersionGraphState query(K id, Collection versions) { + QueryCondition queryCondition = new QueryCondition<>(); + queryCondition.queryIds = Arrays.asList(id); + queryCondition.versions = versions; + return new QueryableVersionGraphStateImpl<>(queryType, graphManager, queryCondition); + } - @Override - public CloseableIterator idIterator() { - return this.graphManager.getDynamicGraphTrait().vertexIDIterator(); - } + @Override + public CloseableIterator idIterator() { + return this.graphManager.getDynamicGraphTrait().vertexIDIterator(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/BaseKeyStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/BaseKeyStateImpl.java index ec3353dae..6d30ae6e6 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/BaseKeyStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/BaseKeyStateImpl.java @@ -27,16 +27,16 @@ public abstract class BaseKeyStateImpl { - protected final IKeyStateManager keyStateManager; - protected final ManageableState manageableState; + protected final IKeyStateManager keyStateManager; + protected final ManageableState manageableState; - public BaseKeyStateImpl(StateContext context) { - this.keyStateManager = new KeyStateManager<>(); - this.keyStateManager.init(context); - this.manageableState = new ManageableStateImpl(this.keyStateManager); - } + public BaseKeyStateImpl(StateContext context) { + this.keyStateManager = new KeyStateManager<>(); + this.keyStateManager.init(context); + this.manageableState = new ManageableStateImpl(this.keyStateManager); + } - public ManageableState manage() { - return manageableState; - } + public ManageableState manage() { + return manageableState; + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/BaseQueryState.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/BaseQueryState.java index 33a49b109..80485a87b 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/BaseQueryState.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/BaseQueryState.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.List; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.state.query.QueryCondition; import org.apache.geaflow.state.query.QueryType; @@ -34,64 +35,64 @@ public abstract class BaseQueryState implements StaticQueryableState { - protected final QueryType queryType; - protected final IGraphManager graphManager; + protected final QueryType queryType; + protected final IGraphManager graphManager; - public BaseQueryState(QueryType queryType, IGraphManager graphManager) { - this.queryType = queryType; - this.graphManager = graphManager; - } + public BaseQueryState(QueryType queryType, IGraphManager graphManager) { + this.queryType = queryType; + this.graphManager = graphManager; + } - @Override - public QueryableAllGraphState query() { - QueryCondition queryCondition = new QueryCondition<>(); - queryCondition.queryIds = null; - queryCondition.isFullScan = true; - return new QueryableAllGraphStateImpl<>(queryType, graphManager, queryCondition); - } + @Override + public QueryableAllGraphState query() { + QueryCondition queryCondition = new QueryCondition<>(); + queryCondition.queryIds = null; + queryCondition.isFullScan = true; + return new QueryableAllGraphStateImpl<>(queryType, graphManager, queryCondition); + } - @Override - public QueryableAllGraphState query(KeyGroup keyGroup) { - QueryCondition queryCondition = new QueryCondition<>(); - queryCondition.keyGroup = keyGroup; - queryCondition.queryIds = null; - queryCondition.isFullScan = true; - return new QueryableAllGraphStateImpl<>(queryType, graphManager, queryCondition); - } + @Override + public QueryableAllGraphState query(KeyGroup keyGroup) { + QueryCondition queryCondition = new QueryCondition<>(); + queryCondition.keyGroup = keyGroup; + queryCondition.queryIds = null; + queryCondition.isFullScan = true; + return new QueryableAllGraphStateImpl<>(queryType, graphManager, queryCondition); + } - @Override - public QueryableKeysGraphState query(K id) { - QueryCondition queryCondition = new QueryCondition<>(); - queryCondition.queryId = id; - queryCondition.isFullScan = false; - return new QueryableOneKeyGraphStateImpl<>(queryType, graphManager, queryCondition); - } + @Override + public QueryableKeysGraphState query(K id) { + QueryCondition queryCondition = new QueryCondition<>(); + queryCondition.queryId = id; + queryCondition.isFullScan = false; + return new QueryableOneKeyGraphStateImpl<>(queryType, graphManager, queryCondition); + } - @Override - public QueryableKeysGraphState query(K... ids) { - return query(Arrays.asList(ids)); - } + @Override + public QueryableKeysGraphState query(K... ids) { + return query(Arrays.asList(ids)); + } - @Override - public QueryableKeysGraphState query(List ids) { - QueryCondition queryCondition = new QueryCondition<>(); - queryCondition.queryIds = ids; - queryCondition.isFullScan = false; - return new QueryableKeysGraphStateImpl<>(queryType, graphManager, queryCondition); - } + @Override + public QueryableKeysGraphState query(List ids) { + QueryCondition queryCondition = new QueryCondition<>(); + queryCondition.queryIds = ids; + queryCondition.isFullScan = false; + return new QueryableKeysGraphStateImpl<>(queryType, graphManager, queryCondition); + } - @Override - public CloseableIterator idIterator() { - return this.graphManager.getStaticGraphTrait().vertexIDIterator(); - } + @Override + public CloseableIterator idIterator() { + return this.graphManager.getStaticGraphTrait().vertexIDIterator(); + } - @Override - public CloseableIterator iterator() { - return query().iterator(); - } + @Override + public CloseableIterator iterator() { + return query().iterator(); + } - @Override - public List asList() { - return query().asList(); - } + @Override + public List asList() { + return query().asList(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/DynamicEdgeStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/DynamicEdgeStateImpl.java index 58e0dc900..f59b581fb 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/DynamicEdgeStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/DynamicEdgeStateImpl.java @@ -21,6 +21,7 @@ import java.util.Collection; import java.util.List; + import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.iterator.CloseableIterator; @@ -29,51 +30,50 @@ import org.apache.geaflow.state.query.QueryType; import org.apache.geaflow.state.strategy.manager.IGraphManager; -public class DynamicEdgeStateImpl extends - BaseDynamicQueryState> implements - DynamicEdgeState { +public class DynamicEdgeStateImpl extends BaseDynamicQueryState> + implements DynamicEdgeState { - public DynamicEdgeStateImpl(IGraphManager manager) { - super(new QueryType<>(DataType.E), manager); - } + public DynamicEdgeStateImpl(IGraphManager manager) { + super(new QueryType<>(DataType.E), manager); + } - @Override - public List getAllVersions(K id) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public List getAllVersions(K id) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } - @Override - public long getLatestVersion(K id) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public long getLatestVersion(K id) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } - @Override - public CloseableIterator idIterator() { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public CloseableIterator idIterator() { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } - @Override - public void add(long version, IEdge edge) { - this.graphManager.getDynamicGraphTrait().addEdge(version, edge); - } + @Override + public void add(long version, IEdge edge) { + this.graphManager.getDynamicGraphTrait().addEdge(version, edge); + } - @Override - public void update(long version, IEdge edge) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public void update(long version, IEdge edge) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } - @Override - public void delete(long version, IEdge edge) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public void delete(long version, IEdge edge) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } - @Override - public void delete(long version, K... ids) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public void delete(long version, K... ids) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } - @Override - public void delete(long version, Collection ids) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public void delete(long version, Collection ids) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/DynamicGraphStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/DynamicGraphStateImpl.java index e4836d058..4e3329f35 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/DynamicGraphStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/DynamicGraphStateImpl.java @@ -23,36 +23,36 @@ public class DynamicGraphStateImpl implements DynamicGraphState { - private final IGraphManager graphManager; - private DynamicVertexState vertexState; - private DynamicEdgeState edgeState; - private DynamicOneDegreeGraphState oneDegreeGraphState; - - public DynamicGraphStateImpl(IGraphManager graphManager) { - this.graphManager = graphManager; - } - - @Override - public DynamicVertexState V() { - if (vertexState == null) { - vertexState = new DynamicVertexStateImpl<>(this.graphManager); - } - return vertexState; + private final IGraphManager graphManager; + private DynamicVertexState vertexState; + private DynamicEdgeState edgeState; + private DynamicOneDegreeGraphState oneDegreeGraphState; + + public DynamicGraphStateImpl(IGraphManager graphManager) { + this.graphManager = graphManager; + } + + @Override + public DynamicVertexState V() { + if (vertexState == null) { + vertexState = new DynamicVertexStateImpl<>(this.graphManager); } + return vertexState; + } - @Override - public DynamicEdgeState E() { - if (edgeState == null) { - edgeState = new DynamicEdgeStateImpl<>(this.graphManager); - } - return edgeState; + @Override + public DynamicEdgeState E() { + if (edgeState == null) { + edgeState = new DynamicEdgeStateImpl<>(this.graphManager); } + return edgeState; + } - @Override - public DynamicOneDegreeGraphState VE() { - if (oneDegreeGraphState == null) { - oneDegreeGraphState = new DynamicOneDegreeGraphStateImpl<>(this.graphManager); - } - return oneDegreeGraphState; + @Override + public DynamicOneDegreeGraphState VE() { + if (oneDegreeGraphState == null) { + oneDegreeGraphState = new DynamicOneDegreeGraphStateImpl<>(this.graphManager); } + return oneDegreeGraphState; + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/DynamicOneDegreeGraphStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/DynamicOneDegreeGraphStateImpl.java index 81eb0ce88..87d74a7dd 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/DynamicOneDegreeGraphStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/DynamicOneDegreeGraphStateImpl.java @@ -20,6 +20,7 @@ package org.apache.geaflow.state; import java.util.List; + import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.state.data.DataType; @@ -27,20 +28,21 @@ import org.apache.geaflow.state.query.QueryType; import org.apache.geaflow.state.strategy.manager.IGraphManager; -public class DynamicOneDegreeGraphStateImpl extends BaseDynamicQueryState> +public class DynamicOneDegreeGraphStateImpl + extends BaseDynamicQueryState> implements DynamicOneDegreeGraphState { - public DynamicOneDegreeGraphStateImpl(IGraphManager manager) { - super(new QueryType<>(DataType.VE), manager); - } + public DynamicOneDegreeGraphStateImpl(IGraphManager manager) { + super(new QueryType<>(DataType.VE), manager); + } - @Override - public List getAllVersions(K id) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public List getAllVersions(K id) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } - @Override - public long getLatestVersion(K id) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public long getLatestVersion(K id) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/DynamicVertexStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/DynamicVertexStateImpl.java index 6d041bc6f..59685d608 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/DynamicVertexStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/DynamicVertexStateImpl.java @@ -21,6 +21,7 @@ import java.util.Collection; import java.util.List; + import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.model.graph.vertex.IVertex; @@ -28,46 +29,46 @@ import org.apache.geaflow.state.query.QueryType; import org.apache.geaflow.state.strategy.manager.IGraphManager; -public class DynamicVertexStateImpl extends - BaseDynamicQueryState> implements - DynamicVertexState { +public class DynamicVertexStateImpl + extends BaseDynamicQueryState> + implements DynamicVertexState { - public DynamicVertexStateImpl(IGraphManager graphManager) { - super(new QueryType<>(DataType.V), graphManager); - } + public DynamicVertexStateImpl(IGraphManager graphManager) { + super(new QueryType<>(DataType.V), graphManager); + } - @Override - public List getAllVersions(K id) { - return this.graphManager.getDynamicGraphTrait().getAllVersions(id, DataType.V); - } + @Override + public List getAllVersions(K id) { + return this.graphManager.getDynamicGraphTrait().getAllVersions(id, DataType.V); + } - @Override - public long getLatestVersion(K id) { - return this.graphManager.getDynamicGraphTrait().getLatestVersion(id, DataType.V); - } + @Override + public long getLatestVersion(K id) { + return this.graphManager.getDynamicGraphTrait().getLatestVersion(id, DataType.V); + } - @Override - public void add(long version, IVertex vertex) { - this.graphManager.getDynamicGraphTrait().addVertex(version, vertex); - } + @Override + public void add(long version, IVertex vertex) { + this.graphManager.getDynamicGraphTrait().addVertex(version, vertex); + } - @Override - public void update(long version, IVertex vertex) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public void update(long version, IVertex vertex) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } - @Override - public void delete(long version, IVertex vertex) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public void delete(long version, IVertex vertex) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } - @Override - public void delete(long version, K... ids) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public void delete(long version, K... ids) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } - @Override - public void delete(long version, Collection ids) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public void delete(long version, Collection ids) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/GraphStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/GraphStateImpl.java index ab2d03a3a..bd7a1cff0 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/GraphStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/GraphStateImpl.java @@ -28,35 +28,34 @@ import org.slf4j.LoggerFactory; public class GraphStateImpl implements GraphState { - private static final Logger LOGGER = LoggerFactory.getLogger(GraphStateImpl.class); - private final IGraphManager graphManager; - private final ManageableGraphState manageableGraphState; - - private StaticGraphState staticGraphState; - private DynamicGraphState dynamicGraphState; - - public GraphStateImpl(StateContext context) { - this.graphManager = new GraphManagerImpl<>(); - LOGGER.info("ThreadId {}, GraphStateImpl initDB", Thread.currentThread().getId()); - this.graphManager.init(context); - this.manageableGraphState = new ManageableGraphStateImpl(this.graphManager); - this.staticGraphState = new StaticGraphStateImpl<>(this.graphManager); - this.dynamicGraphState = new DynamicGraphStateImpl<>(this.graphManager); - } - - @Override - public StaticGraphState staticGraph() { - return staticGraphState; - } - - @Override - public DynamicGraphState dynamicGraph() { - return dynamicGraphState; - } - - @Override - public ManageableGraphState manage() { - return manageableGraphState; - } - + private static final Logger LOGGER = LoggerFactory.getLogger(GraphStateImpl.class); + private final IGraphManager graphManager; + private final ManageableGraphState manageableGraphState; + + private StaticGraphState staticGraphState; + private DynamicGraphState dynamicGraphState; + + public GraphStateImpl(StateContext context) { + this.graphManager = new GraphManagerImpl<>(); + LOGGER.info("ThreadId {}, GraphStateImpl initDB", Thread.currentThread().getId()); + this.graphManager.init(context); + this.manageableGraphState = new ManageableGraphStateImpl(this.graphManager); + this.staticGraphState = new StaticGraphStateImpl<>(this.graphManager); + this.dynamicGraphState = new DynamicGraphStateImpl<>(this.graphManager); + } + + @Override + public StaticGraphState staticGraph() { + return staticGraphState; + } + + @Override + public DynamicGraphState dynamicGraph() { + return dynamicGraphState; + } + + @Override + public ManageableGraphState manage() { + return manageableGraphState; + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/KeyListStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/KeyListStateImpl.java index 66c8e2086..166f23338 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/KeyListStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/KeyListStateImpl.java @@ -20,37 +20,37 @@ package org.apache.geaflow.state; import java.util.List; + import org.apache.geaflow.state.context.StateContext; import org.apache.geaflow.state.key.KeyListTrait; public class KeyListStateImpl extends BaseKeyStateImpl implements KeyListState { - private final KeyListTrait trait; - - public KeyListStateImpl(StateContext context, Class valueClazz) { - super(context); - this.trait = this.keyStateManager.getKeyListTrait(valueClazz); - } - - @Override - public List get(K key) { - return this.trait.get(key); - } - - @Override - public void add(K key, V... value) { - this.trait.add(key, value); - } - - @Override - public void remove(K key) { - this.trait.remove(key); - } - - @Override - public void put(K key, List list) { - this.trait.remove(key); - this.trait.add(key, (V[]) list.toArray(new Object[0])); - } - + private final KeyListTrait trait; + + public KeyListStateImpl(StateContext context, Class valueClazz) { + super(context); + this.trait = this.keyStateManager.getKeyListTrait(valueClazz); + } + + @Override + public List get(K key) { + return this.trait.get(key); + } + + @Override + public void add(K key, V... value) { + this.trait.add(key, value); + } + + @Override + public void remove(K key) { + this.trait.remove(key); + } + + @Override + public void put(K key, List list) { + this.trait.remove(key); + this.trait.add(key, (V[]) list.toArray(new Object[0])); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/KeyMapStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/KeyMapStateImpl.java index 13256c80d..a2f49641a 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/KeyMapStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/KeyMapStateImpl.java @@ -21,51 +21,53 @@ import java.util.List; import java.util.Map; + import org.apache.geaflow.state.context.StateContext; import org.apache.geaflow.state.key.KeyMapTrait; -public class KeyMapStateImpl extends BaseKeyStateImpl implements KeyMapState { +public class KeyMapStateImpl extends BaseKeyStateImpl + implements KeyMapState { - private final KeyMapTrait trait; + private final KeyMapTrait trait; - public KeyMapStateImpl(StateContext context, Class subKeyClazz, Class valueClazz) { - super(context); - this.trait = this.keyStateManager.getKeyMapTrait(subKeyClazz, valueClazz); - } + public KeyMapStateImpl(StateContext context, Class subKeyClazz, Class valueClazz) { + super(context); + this.trait = this.keyStateManager.getKeyMapTrait(subKeyClazz, valueClazz); + } - @Override - public Map get(K key) { - return this.trait.get(key); - } + @Override + public Map get(K key) { + return this.trait.get(key); + } - @Override - public List get(K key, UK... subKeys) { - return this.trait.get(key, subKeys); - } + @Override + public List get(K key, UK... subKeys) { + return this.trait.get(key, subKeys); + } - @Override - public void add(K key, UK subKey, UV value) { - this.trait.add(key, subKey, value); - } + @Override + public void add(K key, UK subKey, UV value) { + this.trait.add(key, subKey, value); + } - @Override - public void add(K key, Map map) { - this.trait.add(key, map); - } + @Override + public void add(K key, Map map) { + this.trait.add(key, map); + } - @Override - public void put(K key, Map map) { - remove(key); - this.trait.add(key, map); - } + @Override + public void put(K key, Map map) { + remove(key); + this.trait.add(key, map); + } - @Override - public void remove(K key) { - this.trait.remove(key); - } + @Override + public void remove(K key) { + this.trait.remove(key); + } - @Override - public void remove(K key, UK... subKeys) { - this.trait.remove(key, subKeys); - } + @Override + public void remove(K key, UK... subKeys) { + this.trait.remove(key, subKeys); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/KeyValueListStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/KeyValueListStateImpl.java index 4cfcb744c..3539beba7 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/KeyValueListStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/KeyValueListStateImpl.java @@ -22,41 +22,42 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; + import org.apache.geaflow.state.context.StateContext; import org.apache.geaflow.state.key.KeyValueTrait; public class KeyValueListStateImpl extends BaseKeyStateImpl implements KeyListState { - private final KeyValueTrait kListTrait; - - public KeyValueListStateImpl(StateContext context) { - super(context); - this.kListTrait = this.keyStateManager.getKeyValueTrait(List.class); - } + private final KeyValueTrait kListTrait; - @Override - public List get(K key) { - List list = this.kListTrait.get(key); - return list == null ? new ArrayList<>() : list; - } + public KeyValueListStateImpl(StateContext context) { + super(context); + this.kListTrait = this.keyStateManager.getKeyValueTrait(List.class); + } - @Override - public void add(K key, V... value) { - List list = get(key); - if (list == null) { - list = new ArrayList<>(); - } - list.addAll(Arrays.asList(value)); - put(key, list); - } - - @Override - public void remove(K key) { - this.kListTrait.remove(key); - } + @Override + public List get(K key) { + List list = this.kListTrait.get(key); + return list == null ? new ArrayList<>() : list; + } - @Override - public void put(K key, List list) { - this.kListTrait.put(key, list); + @Override + public void add(K key, V... value) { + List list = get(key); + if (list == null) { + list = new ArrayList<>(); } + list.addAll(Arrays.asList(value)); + put(key, list); + } + + @Override + public void remove(K key) { + this.kListTrait.remove(key); + } + + @Override + public void put(K key, List list) { + this.kListTrait.put(key, list); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/KeyValueMapStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/KeyValueMapStateImpl.java index 830738846..35f0f82e7 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/KeyValueMapStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/KeyValueMapStateImpl.java @@ -24,64 +24,66 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; + import org.apache.geaflow.state.context.StateContext; import org.apache.geaflow.state.key.KeyValueTrait; -public class KeyValueMapStateImpl extends BaseKeyStateImpl implements KeyMapState { +public class KeyValueMapStateImpl extends BaseKeyStateImpl + implements KeyMapState { - private final KeyValueTrait kvTrait; + private final KeyValueTrait kvTrait; - public KeyValueMapStateImpl(StateContext context) { - super(context); - this.kvTrait = this.keyStateManager.getKeyValueTrait(Map.class); - } + public KeyValueMapStateImpl(StateContext context) { + super(context); + this.kvTrait = this.keyStateManager.getKeyValueTrait(Map.class); + } - @Override - public Map get(K key) { - Map map = this.kvTrait.get(key); - return map == null ? new HashMap<>() : map; - } + @Override + public Map get(K key) { + Map map = this.kvTrait.get(key); + return map == null ? new HashMap<>() : map; + } - @Override - public List get(K key, UK... subKeys) { - Map map = get(key); - return Arrays.stream(subKeys).map(map::get).collect(Collectors.toList()); - } + @Override + public List get(K key, UK... subKeys) { + Map map = get(key); + return Arrays.stream(subKeys).map(map::get).collect(Collectors.toList()); + } - @Override - public void add(K key, UK subKey, UV value) { - Map map = get(key); - if (map == null) { - map = new HashMap<>(); - } - map.put(subKey, value); - add(key, map); + @Override + public void add(K key, UK subKey, UV value) { + Map map = get(key); + if (map == null) { + map = new HashMap<>(); } + map.put(subKey, value); + add(key, map); + } - @Override - public void add(K key, Map map) { - Map tmp = get(key); - if (tmp == null) { - tmp = new HashMap<>(); - } - tmp.putAll(map); - put(key, tmp); + @Override + public void add(K key, Map map) { + Map tmp = get(key); + if (tmp == null) { + tmp = new HashMap<>(); } + tmp.putAll(map); + put(key, tmp); + } - @Override - public void remove(K key) { - this.kvTrait.remove(key); - } + @Override + public void remove(K key) { + this.kvTrait.remove(key); + } - @Override - public void put(K key, Map map) { - this.kvTrait.put(key, map); - } + @Override + public void put(K key, Map map) { + this.kvTrait.put(key, map); + } - @Override - public void remove(K key, UK... subKeys) { - Map map = get(key); - Arrays.stream(subKeys).forEach(map::remove); - add(key, map); - } + @Override + public void remove(K key, UK... subKeys) { + Map map = get(key); + Arrays.stream(subKeys).forEach(map::remove); + add(key, map); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/KeyValueStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/KeyValueStateImpl.java index de12c1a16..f2384816e 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/KeyValueStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/KeyValueStateImpl.java @@ -25,31 +25,31 @@ public class KeyValueStateImpl extends BaseKeyStateImpl implements KeyValueState { - private final KeyValueStateDescriptor desc; - private final KeyValueTrait trait; - - public KeyValueStateImpl(StateContext context) { - super(context); - this.desc = (KeyValueStateDescriptor) context.getDescriptor(); - trait = this.keyStateManager.getKeyValueTrait(desc.getValueClazz()); - } - - @Override - public V get(K k) { - V res = this.trait.get(k); - if (res == null) { - res = desc.getDefaultValue(); - } - return res; - } - - @Override - public void put(K k, V value) { - this.trait.put(k, value); - } - - @Override - public void remove(K key) { - this.trait.remove(key); + private final KeyValueStateDescriptor desc; + private final KeyValueTrait trait; + + public KeyValueStateImpl(StateContext context) { + super(context); + this.desc = (KeyValueStateDescriptor) context.getDescriptor(); + trait = this.keyStateManager.getKeyValueTrait(desc.getValueClazz()); + } + + @Override + public V get(K k) { + V res = this.trait.get(k); + if (res == null) { + res = desc.getDefaultValue(); } + return res; + } + + @Override + public void put(K k, V value) { + this.trait.put(k, value); + } + + @Override + public void remove(K key) { + this.trait.remove(key); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/StateFactory.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/StateFactory.java index 2406b6e65..eff82b328 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/StateFactory.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/StateFactory.java @@ -19,12 +19,12 @@ package org.apache.geaflow.state; -import com.google.common.base.Preconditions; import java.io.Serializable; import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.state.context.StateContext; import org.apache.geaflow.state.descriptor.GraphStateDescriptor; @@ -36,93 +36,101 @@ import org.apache.geaflow.store.IStoreBuilder; import org.apache.geaflow.store.api.StoreBuilderFactory; +import com.google.common.base.Preconditions; + public class StateFactory implements Serializable { - private static final long serialVersionUID = 6070809556097701233L; + private static final long serialVersionUID = 6070809556097701233L; - private static final List SUPPORTED_KEY_STORE_TYPES = - Arrays.asList(StoreType.ROCKSDB, StoreType.MEMORY, StoreType.PAIMON); + private static final List SUPPORTED_KEY_STORE_TYPES = + Arrays.asList(StoreType.ROCKSDB, StoreType.MEMORY, StoreType.PAIMON); - private static final Map GRAPH_STATE_MAP = new ConcurrentHashMap<>(); + private static final Map GRAPH_STATE_MAP = new ConcurrentHashMap<>(); - public static GraphState buildGraphState( - GraphStateDescriptor descriptor, Configuration configuration) { - if (descriptor.isSingleton()) { - return GRAPH_STATE_MAP.computeIfAbsent(descriptor.getName(), - k -> getGraphState(descriptor, configuration)); - } - return getGraphState(descriptor, configuration); + public static GraphState buildGraphState( + GraphStateDescriptor descriptor, Configuration configuration) { + if (descriptor.isSingleton()) { + return GRAPH_STATE_MAP.computeIfAbsent( + descriptor.getName(), k -> getGraphState(descriptor, configuration)); } + return getGraphState(descriptor, configuration); + } - private static GraphState getGraphState( - GraphStateDescriptor descriptor, Configuration configuration) { - if (descriptor.getDateModel() == null) { - descriptor.withDataModel(DataModel.STATIC_GRAPH); - } - return new GraphStateImpl<>(new StateContext(descriptor, configuration)); + private static GraphState getGraphState( + GraphStateDescriptor descriptor, Configuration configuration) { + if (descriptor.getDateModel() == null) { + descriptor.withDataModel(DataModel.STATIC_GRAPH); } + return new GraphStateImpl<>(new StateContext(descriptor, configuration)); + } + public static KeyValueState buildKeyValueState( + KeyValueStateDescriptor descriptor, Configuration configuration) { - public static KeyValueState buildKeyValueState( - KeyValueStateDescriptor descriptor, Configuration configuration) { - - Preconditions.checkArgument(SUPPORTED_KEY_STORE_TYPES.contains( - StoreType.getEnum(descriptor.getStoreType())), - "only support %s", SUPPORTED_KEY_STORE_TYPES); - if (descriptor.getKeySerializer() == null) { - descriptor.withKVSerializer(new DefaultKVSerializer<>( - descriptor.getKeyClazz(), descriptor.getValueClazz())); - } - descriptor.withDataModel(DataModel.KV); - return new KeyValueStateImpl<>(new StateContext(descriptor, configuration)); + Preconditions.checkArgument( + SUPPORTED_KEY_STORE_TYPES.contains(StoreType.getEnum(descriptor.getStoreType())), + "only support %s", + SUPPORTED_KEY_STORE_TYPES); + if (descriptor.getKeySerializer() == null) { + descriptor.withKVSerializer( + new DefaultKVSerializer<>(descriptor.getKeyClazz(), descriptor.getValueClazz())); } + descriptor.withDataModel(DataModel.KV); + return new KeyValueStateImpl<>(new StateContext(descriptor, configuration)); + } - public static KeyListState buildKeyListState( - KeyListStateDescriptor descriptor, Configuration configuration) { - Preconditions.checkArgument(SUPPORTED_KEY_STORE_TYPES.contains( - StoreType.getEnum(descriptor.getStoreType())), - "only support %s", SUPPORTED_KEY_STORE_TYPES); - String storeType = descriptor.getStoreType(); - IStoreBuilder builder = StoreBuilderFactory.build(storeType); - if (builder.supportedDataModel().contains(DataModel.KList)) { - if (descriptor.getKeySerializer() == null) { - descriptor.withKeySerializer(new DefaultKVSerializer<>(descriptor.getKeyClazz(), - descriptor.getValueClazz())); - } - descriptor.withDataModel(DataModel.KList); - return new KeyListStateImpl<>(new StateContext(descriptor, configuration), descriptor.getValueClazz()); - } else { - if (descriptor.getKeySerializer() == null) { - descriptor.withKeySerializer(new DefaultKVSerializer<>(descriptor.getKeyClazz(), List.class)); - } - descriptor.withDataModel(DataModel.KV); - return new KeyValueListStateImpl<>(new StateContext(descriptor, configuration)); - } - + public static KeyListState buildKeyListState( + KeyListStateDescriptor descriptor, Configuration configuration) { + Preconditions.checkArgument( + SUPPORTED_KEY_STORE_TYPES.contains(StoreType.getEnum(descriptor.getStoreType())), + "only support %s", + SUPPORTED_KEY_STORE_TYPES); + String storeType = descriptor.getStoreType(); + IStoreBuilder builder = StoreBuilderFactory.build(storeType); + if (builder.supportedDataModel().contains(DataModel.KList)) { + if (descriptor.getKeySerializer() == null) { + descriptor.withKeySerializer( + new DefaultKVSerializer<>(descriptor.getKeyClazz(), descriptor.getValueClazz())); + } + descriptor.withDataModel(DataModel.KList); + return new KeyListStateImpl<>( + new StateContext(descriptor, configuration), descriptor.getValueClazz()); + } else { + if (descriptor.getKeySerializer() == null) { + descriptor.withKeySerializer( + new DefaultKVSerializer<>(descriptor.getKeyClazz(), List.class)); + } + descriptor.withDataModel(DataModel.KV); + return new KeyValueListStateImpl<>(new StateContext(descriptor, configuration)); } + } - public static KeyMapState buildKeyMapState( - KeyMapStateDescriptor descriptor, Configuration configuration) { - Preconditions.checkArgument(SUPPORTED_KEY_STORE_TYPES.contains( - StoreType.getEnum(descriptor.getStoreType())), - "only support %s", SUPPORTED_KEY_STORE_TYPES); - String storeType = descriptor.getStoreType(); - IStoreBuilder builder = StoreBuilderFactory.build(storeType); - if (builder.supportedDataModel().contains(DataModel.KMap)) { - if (descriptor.getKeySerializer() == null) { - descriptor.withKeySerializer(new DefaultKMapSerializer<>(descriptor.getKeyClazz(), - descriptor.getSubKeyClazz(), descriptor.getValueClazz())); - } - descriptor.withDataModel(DataModel.KMap); - return new KeyMapStateImpl<>(new StateContext(descriptor, configuration), - descriptor.getSubKeyClazz(), descriptor.getValueClazz()); - } else { - if (descriptor.getKeySerializer() == null) { - descriptor.withKeySerializer( - new DefaultKVSerializer<>(descriptor.getKeyClazz(), Map.class)); - } - descriptor.withDataModel(DataModel.KV); - return new KeyValueMapStateImpl<>(new StateContext(descriptor, configuration)); - } + public static KeyMapState buildKeyMapState( + KeyMapStateDescriptor descriptor, Configuration configuration) { + Preconditions.checkArgument( + SUPPORTED_KEY_STORE_TYPES.contains(StoreType.getEnum(descriptor.getStoreType())), + "only support %s", + SUPPORTED_KEY_STORE_TYPES); + String storeType = descriptor.getStoreType(); + IStoreBuilder builder = StoreBuilderFactory.build(storeType); + if (builder.supportedDataModel().contains(DataModel.KMap)) { + if (descriptor.getKeySerializer() == null) { + descriptor.withKeySerializer( + new DefaultKMapSerializer<>( + descriptor.getKeyClazz(), descriptor.getSubKeyClazz(), descriptor.getValueClazz())); + } + descriptor.withDataModel(DataModel.KMap); + return new KeyMapStateImpl<>( + new StateContext(descriptor, configuration), + descriptor.getSubKeyClazz(), + descriptor.getValueClazz()); + } else { + if (descriptor.getKeySerializer() == null) { + descriptor.withKeySerializer( + new DefaultKVSerializer<>(descriptor.getKeyClazz(), Map.class)); + } + descriptor.withDataModel(DataModel.KV); + return new KeyValueMapStateImpl<>(new StateContext(descriptor, configuration)); } + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/StaticEdgeStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/StaticEdgeStateImpl.java index d6a5fef15..3caec55e3 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/StaticEdgeStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/StaticEdgeStateImpl.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.Collection; + import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.iterator.CloseableIterator; @@ -29,40 +30,40 @@ import org.apache.geaflow.state.query.QueryType; import org.apache.geaflow.state.strategy.manager.IGraphManager; -public class StaticEdgeStateImpl extends BaseQueryState> implements - StaticEdgeState { +public class StaticEdgeStateImpl extends BaseQueryState> + implements StaticEdgeState { - public StaticEdgeStateImpl(IGraphManager manager) { - super(new QueryType<>(DataType.E), manager); - } + public StaticEdgeStateImpl(IGraphManager manager) { + super(new QueryType<>(DataType.E), manager); + } - @Override - public void add(IEdge edge) { - this.graphManager.getStaticGraphTrait().addEdge(edge); - } + @Override + public void add(IEdge edge) { + this.graphManager.getStaticGraphTrait().addEdge(edge); + } - @Override - public void update(IEdge edge) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public void update(IEdge edge) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } - @Override - public void delete(IEdge edge) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public void delete(IEdge edge) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } - @Override - public void delete(K... ids) { - delete(Arrays.asList(ids)); - } + @Override + public void delete(K... ids) { + delete(Arrays.asList(ids)); + } - @Override - public void delete(Collection ids) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public void delete(Collection ids) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } - @Override - public CloseableIterator idIterator() { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public CloseableIterator idIterator() { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/StaticGraphStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/StaticGraphStateImpl.java index 6ce5df01b..1593728df 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/StaticGraphStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/StaticGraphStateImpl.java @@ -23,30 +23,30 @@ public class StaticGraphStateImpl implements StaticGraphState { - private final IGraphManager graphManager; - private StaticVertexState vertexState; - private StaticEdgeState edgeState; - private StaticOneDegreeGraphState oneDegreeGraphState; - - public StaticGraphStateImpl(IGraphManager graphManager) { - this.graphManager = graphManager; - this.vertexState = new StaticVertexStateImpl<>(this.graphManager); - this.edgeState = new StaticEdgeStateImpl<>(this.graphManager); - this.oneDegreeGraphState = new StaticOneDegreeGraphStateImpl<>(this.graphManager); - } - - @Override - public StaticVertexState V() { - return vertexState; - } - - @Override - public StaticEdgeState E() { - return edgeState; - } - - @Override - public StaticOneDegreeGraphState VE() { - return oneDegreeGraphState; - } + private final IGraphManager graphManager; + private StaticVertexState vertexState; + private StaticEdgeState edgeState; + private StaticOneDegreeGraphState oneDegreeGraphState; + + public StaticGraphStateImpl(IGraphManager graphManager) { + this.graphManager = graphManager; + this.vertexState = new StaticVertexStateImpl<>(this.graphManager); + this.edgeState = new StaticEdgeStateImpl<>(this.graphManager); + this.oneDegreeGraphState = new StaticOneDegreeGraphStateImpl<>(this.graphManager); + } + + @Override + public StaticVertexState V() { + return vertexState; + } + + @Override + public StaticEdgeState E() { + return edgeState; + } + + @Override + public StaticOneDegreeGraphState VE() { + return oneDegreeGraphState; + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/StaticOneDegreeGraphStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/StaticOneDegreeGraphStateImpl.java index 0020daa3d..aec679e0e 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/StaticOneDegreeGraphStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/StaticOneDegreeGraphStateImpl.java @@ -20,6 +20,7 @@ package org.apache.geaflow.state; import java.util.Collection; + import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.state.data.DataType; @@ -27,36 +28,36 @@ import org.apache.geaflow.state.query.QueryType; import org.apache.geaflow.state.strategy.manager.IGraphManager; -public class StaticOneDegreeGraphStateImpl extends BaseQueryState> - implements StaticOneDegreeGraphState, - RevisableState> { - - public StaticOneDegreeGraphStateImpl(IGraphManager manager) { - super(new QueryType<>(DataType.VE), manager); - } - - @Override - public void add(OneDegreeGraph oneDegreeGraph) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } - - @Override - public void update(OneDegreeGraph oneDegreeGraph) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } - - @Override - public void delete(OneDegreeGraph oneDegreeGraph) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } - - @Override - public void delete(K... ids) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } - - @Override - public void delete(Collection ids) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } +public class StaticOneDegreeGraphStateImpl + extends BaseQueryState> + implements StaticOneDegreeGraphState, RevisableState> { + + public StaticOneDegreeGraphStateImpl(IGraphManager manager) { + super(new QueryType<>(DataType.VE), manager); + } + + @Override + public void add(OneDegreeGraph oneDegreeGraph) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + + @Override + public void update(OneDegreeGraph oneDegreeGraph) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + + @Override + public void delete(OneDegreeGraph oneDegreeGraph) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + + @Override + public void delete(K... ids) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + + @Override + public void delete(Collection ids) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/StaticVertexStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/StaticVertexStateImpl.java index 16e00a6f9..c4578f9e8 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/StaticVertexStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/StaticVertexStateImpl.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.Collection; + import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.model.graph.vertex.IVertex; @@ -28,35 +29,35 @@ import org.apache.geaflow.state.query.QueryType; import org.apache.geaflow.state.strategy.manager.IGraphManager; -public class StaticVertexStateImpl extends BaseQueryState> implements - StaticVertexState { - - public StaticVertexStateImpl(IGraphManager graphManager) { - super(new QueryType<>(DataType.V), graphManager); - } - - @Override - public void add(IVertex vertex) { - this.graphManager.getStaticGraphTrait().addVertex(vertex); - } - - @Override - public void update(IVertex vertex) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } - - @Override - public void delete(IVertex vertex) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } - - @Override - public void delete(K... ids) { - delete(Arrays.asList(ids)); - } - - @Override - public void delete(Collection id) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } +public class StaticVertexStateImpl extends BaseQueryState> + implements StaticVertexState { + + public StaticVertexStateImpl(IGraphManager graphManager) { + super(new QueryType<>(DataType.V), graphManager); + } + + @Override + public void add(IVertex vertex) { + this.graphManager.getStaticGraphTrait().addVertex(vertex); + } + + @Override + public void update(IVertex vertex) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + + @Override + public void delete(IVertex vertex) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } + + @Override + public void delete(K... ids) { + delete(Arrays.asList(ids)); + } + + @Override + public void delete(Collection id) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/manage/ManageableGraphStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/manage/ManageableGraphStateImpl.java index 6f345f314..7f533ac8d 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/manage/ManageableGraphStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/manage/ManageableGraphStateImpl.java @@ -25,17 +25,17 @@ public class ManageableGraphStateImpl extends ManageableStateImpl implements ManageableGraphState { - public ManageableGraphStateImpl(IGraphManager graphManager) { - super(graphManager); - } + public ManageableGraphStateImpl(IGraphManager graphManager) { + super(graphManager); + } - @Override - public GraphStateSummary summary() { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public GraphStateSummary summary() { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } - @Override - public StateMetric metric() { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public StateMetric metric() { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/manage/ManageableStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/manage/ManageableStateImpl.java index 9bc57ff07..c7f5d073c 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/manage/ManageableStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/manage/ManageableStateImpl.java @@ -23,16 +23,15 @@ public class ManageableStateImpl implements ManageableState { - private IStateManager stateManager; - private StateOperator operator; + private IStateManager stateManager; + private StateOperator operator; - public ManageableStateImpl(IStateManager stateManager) { - this.stateManager = stateManager; - this.operator = new StateOperatorImpl(this.stateManager); - } - - public StateOperator operate() { - return operator; - } + public ManageableStateImpl(IStateManager stateManager) { + this.stateManager = stateManager; + this.operator = new StateOperatorImpl(this.stateManager); + } + public StateOperator operate() { + return operator; + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/manage/StateOperatorImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/manage/StateOperatorImpl.java index 794201a64..c31ec4e67 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/manage/StateOperatorImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/manage/StateOperatorImpl.java @@ -25,53 +25,53 @@ public class StateOperatorImpl implements StateOperator { - private final IStateManager stateManager; - private long checkpointId; + private final IStateManager stateManager; + private long checkpointId; - public StateOperatorImpl(IStateManager stateManager) { - this.stateManager = stateManager; - } + public StateOperatorImpl(IStateManager stateManager) { + this.stateManager = stateManager; + } - @Override - public void load(LoadOption loadOption) { - if (loadOption.getCheckPointId() == 0) { - loadOption.withCheckpointId(this.checkpointId); - } - this.stateManager.doStoreAction(ActionType.LOAD, new ActionRequest(loadOption)); + @Override + public void load(LoadOption loadOption) { + if (loadOption.getCheckPointId() == 0) { + loadOption.withCheckpointId(this.checkpointId); } + this.stateManager.doStoreAction(ActionType.LOAD, new ActionRequest(loadOption)); + } - @Override - public void setCheckpointId(long checkpointId) { - this.checkpointId = checkpointId; - } + @Override + public void setCheckpointId(long checkpointId) { + this.checkpointId = checkpointId; + } - @Override - public void finish() { - this.stateManager.doStoreAction(ActionType.FINISH, new ActionRequest()); - } + @Override + public void finish() { + this.stateManager.doStoreAction(ActionType.FINISH, new ActionRequest()); + } - @Override - public void compact() { - this.stateManager.doStoreAction(ActionType.COMPACT, new ActionRequest()); - } + @Override + public void compact() { + this.stateManager.doStoreAction(ActionType.COMPACT, new ActionRequest()); + } - @Override - public void archive() { - this.stateManager.doStoreAction(ActionType.ARCHIVE, new ActionRequest<>(checkpointId)); - } + @Override + public void archive() { + this.stateManager.doStoreAction(ActionType.ARCHIVE, new ActionRequest<>(checkpointId)); + } - @Override - public void recover() { - this.stateManager.doStoreAction(ActionType.RECOVER, new ActionRequest<>(checkpointId)); - } + @Override + public void recover() { + this.stateManager.doStoreAction(ActionType.RECOVER, new ActionRequest<>(checkpointId)); + } - @Override - public void close() { - this.stateManager.doStoreAction(ActionType.CLOSE, new ActionRequest()); - } + @Override + public void close() { + this.stateManager.doStoreAction(ActionType.CLOSE, new ActionRequest()); + } - @Override - public void drop() { - this.stateManager.doStoreAction(ActionType.DROP, new ActionRequest()); - } + @Override + public void drop() { + this.stateManager.doStoreAction(ActionType.DROP, new ActionRequest()); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryCondition.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryCondition.java index 1a93afbd0..831b95271 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryCondition.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryCondition.java @@ -21,6 +21,7 @@ import java.util.Collection; import java.util.List; + import org.apache.geaflow.state.graph.encoder.EdgeAtom; import org.apache.geaflow.state.pushdown.filter.EmptyFilter; import org.apache.geaflow.state.pushdown.filter.IFilter; @@ -30,13 +31,13 @@ public class QueryCondition { - public Collection versions; - public K queryId; - public List queryIds; - public boolean isFullScan; - public IFilter[] stateFilters = new IFilter[] {EmptyFilter.getInstance()}; - public IProjector projector; - public IEdgeLimit limit; - public EdgeAtom order; - public KeyGroup keyGroup; + public Collection versions; + public K queryId; + public List queryIds; + public boolean isFullScan; + public IFilter[] stateFilters = new IFilter[] {EmptyFilter.getInstance()}; + public IProjector projector; + public IEdgeLimit limit; + public EdgeAtom order; + public KeyGroup keyGroup; } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryableAllGraphStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryableAllGraphStateImpl.java index a1029c300..ba7c1f793 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryableAllGraphStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryableAllGraphStateImpl.java @@ -25,23 +25,24 @@ public class QueryableAllGraphStateImpl extends QueryableGraphStateImpl implements QueryableAllGraphState { - public QueryableAllGraphStateImpl(Long version, QueryType type, - IGraphManager graphManager, - QueryCondition queryCondition) { - super(version, type, graphManager); - this.queryCondition = queryCondition; - } + public QueryableAllGraphStateImpl( + Long version, + QueryType type, + IGraphManager graphManager, + QueryCondition queryCondition) { + super(version, type, graphManager); + this.queryCondition = queryCondition; + } - public QueryableAllGraphStateImpl(QueryType type, - IGraphManager graphManager, - QueryCondition queryCondition) { - super(type, graphManager); - this.queryCondition = queryCondition; - } + public QueryableAllGraphStateImpl( + QueryType type, IGraphManager graphManager, QueryCondition queryCondition) { + super(type, graphManager); + this.queryCondition = queryCondition; + } - @Override - public QueryableGraphState by(IFilter filter) { - queryCondition.stateFilters[0] = filter; - return this; - } + @Override + public QueryableGraphState by(IFilter filter) { + queryCondition.stateFilters[0] = filter; + return this; + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryableGraphStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryableGraphStateImpl.java index e9a0a7b36..87f94fdc5 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryableGraphStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryableGraphStateImpl.java @@ -19,12 +19,11 @@ package org.apache.geaflow.state.query; -import com.google.common.base.Preconditions; -import com.google.common.collect.Lists; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; + import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.iterator.CloseableIterator; @@ -46,183 +45,229 @@ import org.apache.geaflow.state.pushdown.project.ProjectType; import org.apache.geaflow.state.strategy.manager.IGraphManager; +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; + public class QueryableGraphStateImpl implements QueryableGraphState { - protected final QueryType type; - protected final IGraphManager graphManager; - protected IFilterConverter filterConverter; - protected long version = -1; - protected QueryCondition queryCondition; + protected final QueryType type; + protected final IGraphManager graphManager; + protected IFilterConverter filterConverter; + protected long version = -1; + protected QueryCondition queryCondition; - public QueryableGraphStateImpl(QueryType type, IGraphManager graphManager) { - this.type = type; - this.graphManager = graphManager; - this.filterConverter = this.graphManager.getFilterConverter(); - } + public QueryableGraphStateImpl(QueryType type, IGraphManager graphManager) { + this.type = type; + this.graphManager = graphManager; + this.filterConverter = this.graphManager.getFilterConverter(); + } - public QueryableGraphStateImpl(QueryType type, IGraphManager graphManager, - QueryCondition queryCondition) { - this(type, graphManager); - this.queryCondition = queryCondition; - } + public QueryableGraphStateImpl( + QueryType type, IGraphManager graphManager, QueryCondition queryCondition) { + this(type, graphManager); + this.queryCondition = queryCondition; + } - public QueryableGraphStateImpl(Long version, QueryType type, - IGraphManager graphManager) { - this(type, graphManager); - this.version = version; - } + public QueryableGraphStateImpl( + Long version, QueryType type, IGraphManager graphManager) { + this(type, graphManager); + this.version = version; + } - @Override - public QueryableGraphState select(IProjector projector) { - queryCondition.projector = projector; - if (projector.projectType() == ProjectType.DST_ID || projector.projectType() == ProjectType.TIME) { - QueryType queryType = new QueryType<>(DataType.PROJECT_FIELD); - return new QueryableGraphStateImpl<>(queryType, graphManager, queryCondition); - } else { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public QueryableGraphState select(IProjector projector) { + queryCondition.projector = projector; + if (projector.projectType() == ProjectType.DST_ID + || projector.projectType() == ProjectType.TIME) { + QueryType queryType = new QueryType<>(DataType.PROJECT_FIELD); + return new QueryableGraphStateImpl<>(queryType, graphManager, queryCondition); + } else { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); } + } - @Override - public QueryableGraphState limit(long out, long in) { - boolean isSingleLimit = FilterHelper.isSingleLimit(queryCondition.stateFilters); - queryCondition.limit = isSingleLimit ? new SingleEdgeLimit(out, in) : new ComposedEdgeLimit(out, in); - return this; - } + @Override + public QueryableGraphState limit(long out, long in) { + boolean isSingleLimit = FilterHelper.isSingleLimit(queryCondition.stateFilters); + queryCondition.limit = + isSingleLimit ? new SingleEdgeLimit(out, in) : new ComposedEdgeLimit(out, in); + return this; + } - @Override - public QueryableGraphState orderBy(EdgeAtom atom) { - queryCondition.order = atom; - return this; - } + @Override + public QueryableGraphState orderBy(EdgeAtom atom) { + queryCondition.order = atom; + return this; + } - @Override - public List asList() { - return Lists.newArrayList(iterator()); - } + @Override + public List asList() { + return Lists.newArrayList(iterator()); + } - @Override - public CloseableIterator idIterator() { - return version < 0 - ? this.graphManager.getStaticGraphTrait().vertexIDIterator(getPushDown()) - : this.graphManager.getDynamicGraphTrait().vertexIDIterator(version, getPushDown()); - } + @Override + public CloseableIterator idIterator() { + return version < 0 + ? this.graphManager.getStaticGraphTrait().vertexIDIterator(getPushDown()) + : this.graphManager.getDynamicGraphTrait().vertexIDIterator(version, getPushDown()); + } - private Map buildMapFilter() { - Map mapFilters = new HashMap<>(queryCondition.stateFilters.length); - Preconditions.checkArgument(queryCondition.stateFilters.length == queryCondition.queryIds.size()); - for (int i = 0; i < queryCondition.stateFilters.length; i++) { - mapFilters.put(queryCondition.queryIds.get(i), - filterConverter.convert(queryCondition.stateFilters[i])); - } - return mapFilters; + private Map buildMapFilter() { + Map mapFilters = new HashMap<>(queryCondition.stateFilters.length); + Preconditions.checkArgument( + queryCondition.stateFilters.length == queryCondition.queryIds.size()); + for (int i = 0; i < queryCondition.stateFilters.length; i++) { + mapFilters.put( + queryCondition.queryIds.get(i), filterConverter.convert(queryCondition.stateFilters[i])); } + return mapFilters; + } - protected StatePushDown getPushDown() { - StatePushDown pushDown = queryCondition.keyGroup == null ? StatePushDown.of() : - KeyGroupStatePushDown.of(queryCondition.keyGroup); + protected StatePushDown getPushDown() { + StatePushDown pushDown = + queryCondition.keyGroup == null + ? StatePushDown.of() + : KeyGroupStatePushDown.of(queryCondition.keyGroup); - pushDown.withEdgeLimit(queryCondition.limit).withOrderField(queryCondition.order); + pushDown.withEdgeLimit(queryCondition.limit).withOrderField(queryCondition.order); - if (queryCondition.stateFilters.length > 1) { - pushDown.withFilters(buildMapFilter()); - } else { - pushDown.withFilter(filterConverter.convert(queryCondition.stateFilters[0])); - } - return pushDown; + if (queryCondition.stateFilters.length > 1) { + pushDown.withFilters(buildMapFilter()); + } else { + pushDown.withFilter(filterConverter.convert(queryCondition.stateFilters[0])); } + return pushDown; + } + + private CloseableIterator staticIterator() { + StatePushDown condition = getPushDown(); + CloseableIterator it; - private CloseableIterator staticIterator() { - StatePushDown condition = getPushDown(); - CloseableIterator it; - - switch (this.type.getType()) { - case V: - it = queryCondition.isFullScan - ? (CloseableIterator) this.graphManager.getStaticGraphTrait().getVertexIterator(condition) - : (CloseableIterator) this.graphManager.getStaticGraphTrait().getVertexIterator(queryCondition.queryIds, condition); - break; - case E: - it = queryCondition.isFullScan - ? (CloseableIterator) this.graphManager.getStaticGraphTrait().getEdgeIterator(condition) - : (CloseableIterator) this.graphManager.getStaticGraphTrait().getEdgeIterator(queryCondition.queryIds, condition); - break; - case VE: - it = queryCondition.isFullScan - ? (CloseableIterator) this.graphManager.getStaticGraphTrait().getOneDegreeGraphIterator(condition) - : (CloseableIterator) this.graphManager.getStaticGraphTrait().getOneDegreeGraphIterator(queryCondition.queryIds, condition); - break; - case PROJECT_FIELD: - IStatePushDown, R> projectCondition = condition.withProjector(queryCondition.projector); - Iterator> res = queryCondition.isFullScan - ? this.graphManager.getStaticGraphTrait().getEdgeProjectIterator(projectCondition) - : this.graphManager.getStaticGraphTrait().getEdgeProjectIterator(queryCondition.queryIds, projectCondition); - it = new IteratorWithFn<>(res, Tuple::getF1); - break; - default: - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } - return it; + switch (this.type.getType()) { + case V: + it = + queryCondition.isFullScan + ? (CloseableIterator) + this.graphManager.getStaticGraphTrait().getVertexIterator(condition) + : (CloseableIterator) + this.graphManager + .getStaticGraphTrait() + .getVertexIterator(queryCondition.queryIds, condition); + break; + case E: + it = + queryCondition.isFullScan + ? (CloseableIterator) + this.graphManager.getStaticGraphTrait().getEdgeIterator(condition) + : (CloseableIterator) + this.graphManager + .getStaticGraphTrait() + .getEdgeIterator(queryCondition.queryIds, condition); + break; + case VE: + it = + queryCondition.isFullScan + ? (CloseableIterator) + this.graphManager.getStaticGraphTrait().getOneDegreeGraphIterator(condition) + : (CloseableIterator) + this.graphManager + .getStaticGraphTrait() + .getOneDegreeGraphIterator(queryCondition.queryIds, condition); + break; + case PROJECT_FIELD: + IStatePushDown, R> projectCondition = + condition.withProjector(queryCondition.projector); + Iterator> res = + queryCondition.isFullScan + ? this.graphManager.getStaticGraphTrait().getEdgeProjectIterator(projectCondition) + : this.graphManager + .getStaticGraphTrait() + .getEdgeProjectIterator(queryCondition.queryIds, projectCondition); + it = new IteratorWithFn<>(res, Tuple::getF1); + break; + default: + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); } + return it; + } + + private CloseableIterator dynamicIterator() { + StatePushDown condition = getPushDown(); + CloseableIterator it; - private CloseableIterator dynamicIterator() { - StatePushDown condition = getPushDown(); - CloseableIterator it; - - switch (this.type.getType()) { - case V: - it = queryCondition.isFullScan - ? (CloseableIterator) this.graphManager.getDynamicGraphTrait().getVertexIterator(version, condition) - : (CloseableIterator) this.graphManager.getDynamicGraphTrait().getVertexIterator(version, queryCondition.queryIds, condition); - break; - case E: - it = queryCondition.isFullScan - ? (CloseableIterator) this.graphManager.getDynamicGraphTrait().getEdgeIterator(version, condition) - : (CloseableIterator) this.graphManager.getDynamicGraphTrait().getEdgeIterator(version, queryCondition.queryIds, condition); - break; - case VE: - it = queryCondition.isFullScan - ? (CloseableIterator) this.graphManager.getDynamicGraphTrait().getOneDegreeGraphIterator(version, condition) - : (CloseableIterator) this.graphManager.getDynamicGraphTrait().getOneDegreeGraphIterator(version, queryCondition.queryIds, condition); - break; - default: - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } - return it; + switch (this.type.getType()) { + case V: + it = + queryCondition.isFullScan + ? (CloseableIterator) + this.graphManager.getDynamicGraphTrait().getVertexIterator(version, condition) + : (CloseableIterator) + this.graphManager + .getDynamicGraphTrait() + .getVertexIterator(version, queryCondition.queryIds, condition); + break; + case E: + it = + queryCondition.isFullScan + ? (CloseableIterator) + this.graphManager.getDynamicGraphTrait().getEdgeIterator(version, condition) + : (CloseableIterator) + this.graphManager + .getDynamicGraphTrait() + .getEdgeIterator(version, queryCondition.queryIds, condition); + break; + case VE: + it = + queryCondition.isFullScan + ? (CloseableIterator) + this.graphManager + .getDynamicGraphTrait() + .getOneDegreeGraphIterator(version, condition) + : (CloseableIterator) + this.graphManager + .getDynamicGraphTrait() + .getOneDegreeGraphIterator(version, queryCondition.queryIds, condition); + break; + default: + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); } + return it; + } - @Override - public CloseableIterator iterator() { - if (queryCondition.isFullScan) { - Preconditions.checkArgument(queryCondition.stateFilters.length <= 1, - "full scan only support single or none filter now."); - } - CloseableIterator it = version < 0 ? staticIterator() : dynamicIterator(); - return new StandardIterator<>(it); + @Override + public CloseableIterator iterator() { + if (queryCondition.isFullScan) { + Preconditions.checkArgument( + queryCondition.stateFilters.length <= 1, + "full scan only support single or none filter now."); } + CloseableIterator it = version < 0 ? staticIterator() : dynamicIterator(); + return new StandardIterator<>(it); + } - @Override - public R get() { - Iterator it = iterator(); - if (it.hasNext()) { - return it.next(); - } - return null; + @Override + public R get() { + Iterator it = iterator(); + if (it.hasNext()) { + return it.next(); } + return null; + } - @Override - public Map aggregate() { - Preconditions.checkArgument(type.getType() == DataType.E, "only edge agg is supported now."); - Preconditions.checkArgument(version < 0, "only static graph is supported now."); - Preconditions.checkArgument(queryCondition.limit == null, "limit not supported now."); - if (queryCondition.queryIds != null) { - Preconditions.checkArgument(queryCondition.stateFilters.length == 1 - || queryCondition.stateFilters.length == queryCondition.queryIds.size(), - "filter number must be 1 or equal to key number."); - } - StatePushDown condition = getPushDown(); - return queryCondition.isFullScan - ? this.graphManager.getStaticGraphTrait().getAggResult(condition) - : this.graphManager.getStaticGraphTrait().getAggResult(queryCondition.queryIds, condition); + @Override + public Map aggregate() { + Preconditions.checkArgument(type.getType() == DataType.E, "only edge agg is supported now."); + Preconditions.checkArgument(version < 0, "only static graph is supported now."); + Preconditions.checkArgument(queryCondition.limit == null, "limit not supported now."); + if (queryCondition.queryIds != null) { + Preconditions.checkArgument( + queryCondition.stateFilters.length == 1 + || queryCondition.stateFilters.length == queryCondition.queryIds.size(), + "filter number must be 1 or equal to key number."); } + StatePushDown condition = getPushDown(); + return queryCondition.isFullScan + ? this.graphManager.getStaticGraphTrait().getAggResult(condition) + : this.graphManager.getStaticGraphTrait().getAggResult(queryCondition.queryIds, condition); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryableKeysGraphStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryableKeysGraphStateImpl.java index de1fee36f..6c7764a3f 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryableKeysGraphStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryableKeysGraphStateImpl.java @@ -25,29 +25,30 @@ public class QueryableKeysGraphStateImpl extends QueryableGraphStateImpl implements QueryableKeysGraphState { - public QueryableKeysGraphStateImpl(Long version, QueryType type, - IGraphManager graphManager, - QueryCondition queryCondition) { - super(version, type, graphManager); - this.queryCondition = queryCondition; - } + public QueryableKeysGraphStateImpl( + Long version, + QueryType type, + IGraphManager graphManager, + QueryCondition queryCondition) { + super(version, type, graphManager); + this.queryCondition = queryCondition; + } - public QueryableKeysGraphStateImpl(QueryType type, - IGraphManager graphManager, - QueryCondition queryCondition) { - super(type, graphManager); - this.queryCondition = queryCondition; - } + public QueryableKeysGraphStateImpl( + QueryType type, IGraphManager graphManager, QueryCondition queryCondition) { + super(type, graphManager); + this.queryCondition = queryCondition; + } - @Override - public QueryableGraphState by(IFilter... filters) { - queryCondition.stateFilters = filters; - return this; - } + @Override + public QueryableGraphState by(IFilter... filters) { + queryCondition.stateFilters = filters; + return this; + } - @Override - public QueryableGraphState by(IFilter filter) { - queryCondition.stateFilters[0] = filter; - return this; - } + @Override + public QueryableGraphState by(IFilter filter) { + queryCondition.stateFilters[0] = filter; + return this; + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryableOneKeyGraphStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryableOneKeyGraphStateImpl.java index 365fdee5e..0b79077d7 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryableOneKeyGraphStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryableOneKeyGraphStateImpl.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; + import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.iterator.CloseableIterator; @@ -34,62 +35,65 @@ public class QueryableOneKeyGraphStateImpl extends QueryableKeysGraphStateImpl { - public QueryableOneKeyGraphStateImpl(QueryType type, IGraphManager graphManager, - QueryCondition queryCondition) { - super(type, graphManager, queryCondition); - } + public QueryableOneKeyGraphStateImpl( + QueryType type, IGraphManager graphManager, QueryCondition queryCondition) { + super(type, graphManager, queryCondition); + } - @Override - public QueryableGraphState by(IFilter[] filters) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + @Override + public QueryableGraphState by(IFilter[] filters) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); + } - @Override - public QueryableGraphState by(IFilter filter) { - this.queryCondition.stateFilters[0] = filter; - return this; - } + @Override + public QueryableGraphState by(IFilter filter) { + this.queryCondition.stateFilters[0] = filter; + return this; + } - @Override - public QueryableGraphState select(IProjector projector) { - throw new UnsupportedOperationException(); - } + @Override + public QueryableGraphState select(IProjector projector) { + throw new UnsupportedOperationException(); + } - @Override - public CloseableIterator iterator() { - return IteratorWithClose.wrap(asList().iterator()); - } + @Override + public CloseableIterator iterator() { + return IteratorWithClose.wrap(asList().iterator()); + } - protected StatePushDown getPushDown() { - return StatePushDown.of() - .withFilter(filterConverter.convert(queryCondition.stateFilters[0])) - .withEdgeLimit(queryCondition.limit) - .withOrderField(queryCondition.order); - } + protected StatePushDown getPushDown() { + return StatePushDown.of() + .withFilter(filterConverter.convert(queryCondition.stateFilters[0])) + .withEdgeLimit(queryCondition.limit) + .withOrderField(queryCondition.order); + } - @Override - public List asList() { - if (DataType.E == this.type.getType()) { - return (List) this.graphManager.getStaticGraphTrait().getEdges( - queryCondition.queryId, getPushDown()); - } else { - return Collections.singletonList(get()); - } + @Override + public List asList() { + if (DataType.E == this.type.getType()) { + return (List) + this.graphManager.getStaticGraphTrait().getEdges(queryCondition.queryId, getPushDown()); + } else { + return Collections.singletonList(get()); } + } - @Override - public R get() { - switch (this.type.getType()) { - case V: - return (R) this.graphManager.getStaticGraphTrait().getVertex( - queryCondition.queryId, getPushDown()); - case VE: - return (R) this.graphManager.getStaticGraphTrait().getOneDegreeGraph( - queryCondition.queryId, getPushDown()); - default: - throw new GeaflowRuntimeException( - RuntimeErrors.INST.runError("not supported " + this.type.getType())); - } - + @Override + public R get() { + switch (this.type.getType()) { + case V: + return (R) + this.graphManager + .getStaticGraphTrait() + .getVertex(queryCondition.queryId, getPushDown()); + case VE: + return (R) + this.graphManager + .getStaticGraphTrait() + .getOneDegreeGraph(queryCondition.queryId, getPushDown()); + default: + throw new GeaflowRuntimeException( + RuntimeErrors.INST.runError("not supported " + this.type.getType())); } + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryableVersionGraphStateImpl.java b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryableVersionGraphStateImpl.java index bd54f845d..eb3b36114 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryableVersionGraphStateImpl.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/main/java/org/apache/geaflow/state/query/QueryableVersionGraphStateImpl.java @@ -19,8 +19,8 @@ package org.apache.geaflow.state.query; -import com.google.common.base.Preconditions; import java.util.Map; + import org.apache.geaflow.common.errorcode.RuntimeErrors; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.model.graph.vertex.IVertex; @@ -30,47 +30,59 @@ import org.apache.geaflow.state.pushdown.inner.IFilterConverter; import org.apache.geaflow.state.strategy.manager.IGraphManager; -public class QueryableVersionGraphStateImpl implements QueryableVersionGraphState { +import com.google.common.base.Preconditions; + +public class QueryableVersionGraphStateImpl + implements QueryableVersionGraphState { - protected final QueryType type; - protected final IGraphManager graphManager; - protected QueryCondition queryCondition; - protected IFilterConverter filterConverter; + protected final QueryType type; + protected final IGraphManager graphManager; + protected QueryCondition queryCondition; + protected IFilterConverter filterConverter; - public QueryableVersionGraphStateImpl(QueryType type, - IGraphManager graphManager, - QueryCondition queryCondition) { - this.type = type; - this.graphManager = graphManager; - this.queryCondition = queryCondition; - this.filterConverter = this.graphManager.getFilterConverter(); - } + public QueryableVersionGraphStateImpl( + QueryType type, IGraphManager graphManager, QueryCondition queryCondition) { + this.type = type; + this.graphManager = graphManager; + this.queryCondition = queryCondition; + this.filterConverter = this.graphManager.getFilterConverter(); + } - @Override - public QueryableVersionGraphState by(IFilter filter) { - this.queryCondition.stateFilters[0] = filter; - return this; - } + @Override + public QueryableVersionGraphState by(IFilter filter) { + this.queryCondition.stateFilters[0] = filter; + return this; + } - @Override - public Map asMap() { - Preconditions.checkArgument(queryCondition.queryIds.size() == 1); - K id = queryCondition.queryIds.iterator().next(); - Map> res; - if (this.type.getType() == DataType.V) { - if (queryCondition.versions != null) { - res = this.graphManager.getDynamicGraphTrait().getVersionData(id, + @Override + public Map asMap() { + Preconditions.checkArgument(queryCondition.queryIds.size() == 1); + K id = queryCondition.queryIds.iterator().next(); + Map> res; + if (this.type.getType() == DataType.V) { + if (queryCondition.versions != null) { + res = + this.graphManager + .getDynamicGraphTrait() + .getVersionData( + id, queryCondition.versions, - StatePushDown.of().withFilter(filterConverter.convert(queryCondition.stateFilters[0])), + StatePushDown.of() + .withFilter(filterConverter.convert(queryCondition.stateFilters[0])), DataType.V); - } else { - res = this.graphManager.getDynamicGraphTrait().getAllVersionData(id, - StatePushDown.of().withFilter(filterConverter.convert(queryCondition.stateFilters[0])), + } else { + res = + this.graphManager + .getDynamicGraphTrait() + .getAllVersionData( + id, + StatePushDown.of() + .withFilter(filterConverter.convert(queryCondition.stateFilters[0])), DataType.V); - } - } else { - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } - return (Map) res; + } + } else { + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); } + return (Map) res; + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/GraphReadOnlyStateTest.java b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/GraphReadOnlyStateTest.java index 46a48c1d4..0d8f1f86c 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/GraphReadOnlyStateTest.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/GraphReadOnlyStateTest.java @@ -19,7 +19,6 @@ package org.apache.geaflow.state; -import com.google.common.collect.Iterators; import java.io.File; import java.io.IOException; import java.util.Arrays; @@ -27,6 +26,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -52,219 +52,257 @@ import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; -public class GraphReadOnlyStateTest { - - private Map config; - - @BeforeMethod - public void setUp() throws IOException { - FileUtils.deleteQuietly(new File("/tmp/geaflow/chk/")); - FileUtils.deleteQuietly(new File("/tmp/RocksDBGraphStateTest")); - Map persistConfig = new HashMap<>(); - config = new HashMap<>(); - config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), - "RocksDBGraphStateTest" + System.currentTimeMillis()); - config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); - config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); - config.put(FileConfigKeys.JSON_CONFIG.getKey(), GsonUtil.toJson(persistConfig)); - } - - @Test - public void testStaticReadOnlyState() { - GraphStateDescriptor desc = GraphStateDescriptor.build("test1", StoreType.ROCKSDB.name()); - desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); - desc.withGraphMeta(new GraphMeta(new GraphMetaType<>(Types.STRING, ValueVertex.class, - String.class, ValueEdge.class, String.class))); - - GraphState graphState = StateFactory.buildGraphState(desc, new Configuration(config)); - - graphState.manage().operate().setCheckpointId(1); - - graphState.staticGraph().E().add(new ValueEdge<>("1", "2", "hello", EdgeDirection.IN)); - graphState.staticGraph().E().add(new ValueEdge<>("1", "3", "hello", EdgeDirection.OUT)); - graphState.staticGraph().E().add(new ValueEdge<>("2", "2", "world", EdgeDirection.IN)); - graphState.staticGraph().E().add(new ValueEdge<>("2", "3", "world", EdgeDirection.OUT)); - graphState.staticGraph().V().add(new ValueVertex<>("1", "3")); - graphState.staticGraph().V().add(new ValueVertex<>("2", "4")); - graphState.manage().operate().finish(); - - List> list = graphState.staticGraph().E().query("1").asList(); - Assert.assertEquals(list.size(), 2); - - list = graphState.staticGraph().E().query("1").by( - (IEdgeFilter) value -> !value.getTargetId().equals("2")).asList(); - Assert.assertEquals(list.size(), 1); - Assert.assertEquals(list.get(0).getTargetId(), "3"); - - Iterator> iterator = graphState.staticGraph().V().iterator(); - Assert.assertEquals(Iterators.size(iterator), 2); - - IVertex vertex = graphState.staticGraph().V().query("1").get(); - Assert.assertEquals(vertex.getValue(), "3"); - - graphState.manage().operate().archive(); - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - - - desc = GraphStateDescriptor.build("test1", StoreType.ROCKSDB.name()); - desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); - desc.withGraphMeta(new GraphMeta(new GraphMetaType<>(Types.STRING, ValueVertex.class, - String.class, ValueEdge.class, String.class))); - desc.withStateMode(StateMode.RDONLY); - - graphState = StateFactory.buildGraphState(desc, new Configuration(config)); - graphState.manage().operate().setCheckpointId(1); - graphState.manage().operate().recover(); - - list = graphState.staticGraph().E().query("1").asList(); - Assert.assertEquals(list.size(), 2); - - list = graphState.staticGraph().E().query("1").by( - (IEdgeFilter) value -> !value.getTargetId().equals("2")).asList(); - Assert.assertEquals(list.size(), 1); - Assert.assertEquals(list.get(0).getTargetId(), "3"); - - iterator = graphState.staticGraph().V().iterator(); - Assert.assertEquals(Iterators.size(iterator), 2); - - vertex = graphState.staticGraph().V().query("1").get(); - Assert.assertEquals(vertex.getValue(), "3"); - - graphState.manage().operate().drop(); - } - - @Test - public void testDynamicReadOnlyState() { - GraphStateDescriptor desc = GraphStateDescriptor.build("test2", StoreType.ROCKSDB.name()); - desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); - desc.withGraphMeta(new GraphMeta(new GraphMetaType<>(Types.STRING, ValueVertex.class, - String.class, ValueEdge.class, String.class))); - desc.withDataModel(DataModel.DYNAMIC_GRAPH); - - GraphState graphState = StateFactory.buildGraphState(desc, new Configuration(config)); - - graphState.manage().operate().setCheckpointId(1); - - graphState.dynamicGraph().E().add(1, new ValueEdge<>("1", "2", "hello", EdgeDirection.IN)); - graphState.dynamicGraph().E().add(2, new ValueEdge<>("1", "3", "hello", EdgeDirection.OUT)); - graphState.dynamicGraph().E().add(1, new ValueEdge<>("2", "2", "world", EdgeDirection.IN)); - graphState.dynamicGraph().E().add(2, new ValueEdge<>("2", "3", "world", EdgeDirection.OUT)); - graphState.dynamicGraph().V().add(1, new ValueVertex<>("1", "3")); - graphState.dynamicGraph().V().add(1, new ValueVertex<>("2", "4")); - graphState.dynamicGraph().V().add(2, new ValueVertex<>("1", "5")); - graphState.dynamicGraph().V().add(2, new ValueVertex<>("2", "6")); - graphState.manage().operate().finish(); - - List> list = - graphState.dynamicGraph().E().query(1L, Arrays.asList("1", "2")).asList(); - Assert.assertEquals(list.size(), 2); - - long version = graphState.dynamicGraph().V().getLatestVersion("2"); - Assert.assertEquals(version, 2); - - IVertex vertex = graphState.dynamicGraph().V().query(2, "1").get(); - Assert.assertEquals(vertex.getValue(), "5"); - - graphState.manage().operate().archive(); - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - - - desc = GraphStateDescriptor.build("test2", StoreType.ROCKSDB.name()); - desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); - desc.withGraphMeta(new GraphMeta(new GraphMetaType<>(Types.STRING, ValueVertex.class, - String.class, ValueEdge.class, String.class))); - desc.withDataModel(DataModel.DYNAMIC_GRAPH).withStateMode(StateMode.RDONLY); - - graphState = StateFactory.buildGraphState(desc, new Configuration(config)); - graphState.manage().operate().setCheckpointId(1); - graphState.manage().operate().recover(); - - list = graphState.dynamicGraph().E().query(1L, Arrays.asList("1", "2")).asList(); - Assert.assertEquals(list.size(), 2); - - version = graphState.dynamicGraph().V().getLatestVersion("2"); - Assert.assertEquals(version, 2); - - vertex = graphState.dynamicGraph().V().query(2, "1").get(); - Assert.assertEquals(vertex.getValue(), "5"); - - graphState.manage().operate().archive(); - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - } - - @Test(invocationCount = 5) - public void testReadOnlyStateWarmup() throws IOException { - GraphStateDescriptor desc = GraphStateDescriptor.build("test1", StoreType.ROCKSDB.name()); - desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); - desc.withGraphMeta(new GraphMeta(new GraphMetaType<>(Types.STRING, ValueVertex.class, - String.class, ValueEdge.class, String.class))); - - Configuration configuration = new Configuration(config); - ViewMetaBookKeeper viewMetaBookKeeper = new ViewMetaBookKeeper("test1", configuration); - GraphState graphState = StateFactory.buildGraphState(desc, configuration); - - graphState.manage().operate().setCheckpointId(1); - - graphState.staticGraph().E().add(new ValueEdge<>("1", "2", "hello", EdgeDirection.IN)); - graphState.staticGraph().E().add(new ValueEdge<>("1", "3", "hello", EdgeDirection.OUT)); - graphState.staticGraph().E().add(new ValueEdge<>("2", "2", "world", EdgeDirection.IN)); - graphState.staticGraph().E().add(new ValueEdge<>("2", "3", "world", EdgeDirection.OUT)); - graphState.staticGraph().V().add(new ValueVertex<>("1", "3")); - graphState.staticGraph().V().add(new ValueVertex<>("2", "4")); - graphState.manage().operate().finish(); - - List> list = graphState.staticGraph().E().query("1").asList(); - Assert.assertEquals(list.size(), 2); - - list = graphState.staticGraph().E().query("1").by( - (IEdgeFilter) value -> !value.getTargetId().equals("2")).asList(); - Assert.assertEquals(list.size(), 1); - Assert.assertEquals(list.get(0).getTargetId(), "3"); - - Iterator> iterator = graphState.staticGraph().V().iterator(); - Assert.assertEquals(Iterators.size(iterator), 2); - - IVertex vertex = graphState.staticGraph().V().query("1").get(); - Assert.assertEquals(vertex.getValue(), "3"); - - graphState.manage().operate().archive(); - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - - viewMetaBookKeeper.saveViewVersion(1); - viewMetaBookKeeper.archive(); - - - desc = GraphStateDescriptor.build("test1", StoreType.ROCKSDB.name()); - desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); - desc.withGraphMeta(new GraphMeta(new GraphMetaType<>(Types.STRING, ValueVertex.class, - String.class, ValueEdge.class, String.class))); - desc.withStateMode(StateMode.RDONLY); - - config.put(StateConfigKeys.STATE_BACKGROUND_SYNC_ENABLE.getKey(), "true"); - config.put(StateConfigKeys.STATE_RECOVER_LATEST_VERSION_ENABLE.getKey(), "true"); - - graphState = StateFactory.buildGraphState(desc, new Configuration(config)); - graphState.manage().operate().load(LoadOption.of()); - - list = graphState.staticGraph().E().query("1").asList(); - Assert.assertEquals(list.size(), 2); - - list = graphState.staticGraph().E().query("1").by( - (IEdgeFilter) value -> !value.getTargetId().equals("2")).asList(); - Assert.assertEquals(list.size(), 1); - Assert.assertEquals(list.get(0).getTargetId(), "3"); - - iterator = graphState.staticGraph().V().iterator(); - Assert.assertEquals(Iterators.size(iterator), 2); +import com.google.common.collect.Iterators; - vertex = graphState.staticGraph().V().query("1").get(); - Assert.assertEquals(vertex.getValue(), "3"); +public class GraphReadOnlyStateTest { - graphState.manage().operate().drop(); - viewMetaBookKeeper.archive(); - } -} \ No newline at end of file + private Map config; + + @BeforeMethod + public void setUp() throws IOException { + FileUtils.deleteQuietly(new File("/tmp/geaflow/chk/")); + FileUtils.deleteQuietly(new File("/tmp/RocksDBGraphStateTest")); + Map persistConfig = new HashMap<>(); + config = new HashMap<>(); + config.put( + ExecutionConfigKeys.JOB_APP_NAME.getKey(), + "RocksDBGraphStateTest" + System.currentTimeMillis()); + config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); + config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); + config.put(FileConfigKeys.JSON_CONFIG.getKey(), GsonUtil.toJson(persistConfig)); + } + + @Test + public void testStaticReadOnlyState() { + GraphStateDescriptor desc = + GraphStateDescriptor.build("test1", StoreType.ROCKSDB.name()); + desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); + desc.withGraphMeta( + new GraphMeta( + new GraphMetaType<>( + Types.STRING, ValueVertex.class, String.class, ValueEdge.class, String.class))); + + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration(config)); + + graphState.manage().operate().setCheckpointId(1); + + graphState.staticGraph().E().add(new ValueEdge<>("1", "2", "hello", EdgeDirection.IN)); + graphState.staticGraph().E().add(new ValueEdge<>("1", "3", "hello", EdgeDirection.OUT)); + graphState.staticGraph().E().add(new ValueEdge<>("2", "2", "world", EdgeDirection.IN)); + graphState.staticGraph().E().add(new ValueEdge<>("2", "3", "world", EdgeDirection.OUT)); + graphState.staticGraph().V().add(new ValueVertex<>("1", "3")); + graphState.staticGraph().V().add(new ValueVertex<>("2", "4")); + graphState.manage().operate().finish(); + + List> list = graphState.staticGraph().E().query("1").asList(); + Assert.assertEquals(list.size(), 2); + + list = + graphState + .staticGraph() + .E() + .query("1") + .by((IEdgeFilter) value -> !value.getTargetId().equals("2")) + .asList(); + Assert.assertEquals(list.size(), 1); + Assert.assertEquals(list.get(0).getTargetId(), "3"); + + Iterator> iterator = graphState.staticGraph().V().iterator(); + Assert.assertEquals(Iterators.size(iterator), 2); + + IVertex vertex = graphState.staticGraph().V().query("1").get(); + Assert.assertEquals(vertex.getValue(), "3"); + + graphState.manage().operate().archive(); + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + + desc = GraphStateDescriptor.build("test1", StoreType.ROCKSDB.name()); + desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); + desc.withGraphMeta( + new GraphMeta( + new GraphMetaType<>( + Types.STRING, ValueVertex.class, String.class, ValueEdge.class, String.class))); + desc.withStateMode(StateMode.RDONLY); + + graphState = StateFactory.buildGraphState(desc, new Configuration(config)); + graphState.manage().operate().setCheckpointId(1); + graphState.manage().operate().recover(); + + list = graphState.staticGraph().E().query("1").asList(); + Assert.assertEquals(list.size(), 2); + + list = + graphState + .staticGraph() + .E() + .query("1") + .by((IEdgeFilter) value -> !value.getTargetId().equals("2")) + .asList(); + Assert.assertEquals(list.size(), 1); + Assert.assertEquals(list.get(0).getTargetId(), "3"); + + iterator = graphState.staticGraph().V().iterator(); + Assert.assertEquals(Iterators.size(iterator), 2); + + vertex = graphState.staticGraph().V().query("1").get(); + Assert.assertEquals(vertex.getValue(), "3"); + + graphState.manage().operate().drop(); + } + + @Test + public void testDynamicReadOnlyState() { + GraphStateDescriptor desc = + GraphStateDescriptor.build("test2", StoreType.ROCKSDB.name()); + desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); + desc.withGraphMeta( + new GraphMeta( + new GraphMetaType<>( + Types.STRING, ValueVertex.class, String.class, ValueEdge.class, String.class))); + desc.withDataModel(DataModel.DYNAMIC_GRAPH); + + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration(config)); + + graphState.manage().operate().setCheckpointId(1); + + graphState.dynamicGraph().E().add(1, new ValueEdge<>("1", "2", "hello", EdgeDirection.IN)); + graphState.dynamicGraph().E().add(2, new ValueEdge<>("1", "3", "hello", EdgeDirection.OUT)); + graphState.dynamicGraph().E().add(1, new ValueEdge<>("2", "2", "world", EdgeDirection.IN)); + graphState.dynamicGraph().E().add(2, new ValueEdge<>("2", "3", "world", EdgeDirection.OUT)); + graphState.dynamicGraph().V().add(1, new ValueVertex<>("1", "3")); + graphState.dynamicGraph().V().add(1, new ValueVertex<>("2", "4")); + graphState.dynamicGraph().V().add(2, new ValueVertex<>("1", "5")); + graphState.dynamicGraph().V().add(2, new ValueVertex<>("2", "6")); + graphState.manage().operate().finish(); + + List> list = + graphState.dynamicGraph().E().query(1L, Arrays.asList("1", "2")).asList(); + Assert.assertEquals(list.size(), 2); + + long version = graphState.dynamicGraph().V().getLatestVersion("2"); + Assert.assertEquals(version, 2); + + IVertex vertex = graphState.dynamicGraph().V().query(2, "1").get(); + Assert.assertEquals(vertex.getValue(), "5"); + + graphState.manage().operate().archive(); + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + + desc = GraphStateDescriptor.build("test2", StoreType.ROCKSDB.name()); + desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); + desc.withGraphMeta( + new GraphMeta( + new GraphMetaType<>( + Types.STRING, ValueVertex.class, String.class, ValueEdge.class, String.class))); + desc.withDataModel(DataModel.DYNAMIC_GRAPH).withStateMode(StateMode.RDONLY); + + graphState = StateFactory.buildGraphState(desc, new Configuration(config)); + graphState.manage().operate().setCheckpointId(1); + graphState.manage().operate().recover(); + + list = graphState.dynamicGraph().E().query(1L, Arrays.asList("1", "2")).asList(); + Assert.assertEquals(list.size(), 2); + + version = graphState.dynamicGraph().V().getLatestVersion("2"); + Assert.assertEquals(version, 2); + + vertex = graphState.dynamicGraph().V().query(2, "1").get(); + Assert.assertEquals(vertex.getValue(), "5"); + + graphState.manage().operate().archive(); + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } + + @Test(invocationCount = 5) + public void testReadOnlyStateWarmup() throws IOException { + GraphStateDescriptor desc = + GraphStateDescriptor.build("test1", StoreType.ROCKSDB.name()); + desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); + desc.withGraphMeta( + new GraphMeta( + new GraphMetaType<>( + Types.STRING, ValueVertex.class, String.class, ValueEdge.class, String.class))); + + Configuration configuration = new Configuration(config); + ViewMetaBookKeeper viewMetaBookKeeper = new ViewMetaBookKeeper("test1", configuration); + GraphState graphState = + StateFactory.buildGraphState(desc, configuration); + + graphState.manage().operate().setCheckpointId(1); + + graphState.staticGraph().E().add(new ValueEdge<>("1", "2", "hello", EdgeDirection.IN)); + graphState.staticGraph().E().add(new ValueEdge<>("1", "3", "hello", EdgeDirection.OUT)); + graphState.staticGraph().E().add(new ValueEdge<>("2", "2", "world", EdgeDirection.IN)); + graphState.staticGraph().E().add(new ValueEdge<>("2", "3", "world", EdgeDirection.OUT)); + graphState.staticGraph().V().add(new ValueVertex<>("1", "3")); + graphState.staticGraph().V().add(new ValueVertex<>("2", "4")); + graphState.manage().operate().finish(); + + List> list = graphState.staticGraph().E().query("1").asList(); + Assert.assertEquals(list.size(), 2); + + list = + graphState + .staticGraph() + .E() + .query("1") + .by((IEdgeFilter) value -> !value.getTargetId().equals("2")) + .asList(); + Assert.assertEquals(list.size(), 1); + Assert.assertEquals(list.get(0).getTargetId(), "3"); + + Iterator> iterator = graphState.staticGraph().V().iterator(); + Assert.assertEquals(Iterators.size(iterator), 2); + + IVertex vertex = graphState.staticGraph().V().query("1").get(); + Assert.assertEquals(vertex.getValue(), "3"); + + graphState.manage().operate().archive(); + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + + viewMetaBookKeeper.saveViewVersion(1); + viewMetaBookKeeper.archive(); + + desc = GraphStateDescriptor.build("test1", StoreType.ROCKSDB.name()); + desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); + desc.withGraphMeta( + new GraphMeta( + new GraphMetaType<>( + Types.STRING, ValueVertex.class, String.class, ValueEdge.class, String.class))); + desc.withStateMode(StateMode.RDONLY); + + config.put(StateConfigKeys.STATE_BACKGROUND_SYNC_ENABLE.getKey(), "true"); + config.put(StateConfigKeys.STATE_RECOVER_LATEST_VERSION_ENABLE.getKey(), "true"); + + graphState = StateFactory.buildGraphState(desc, new Configuration(config)); + graphState.manage().operate().load(LoadOption.of()); + + list = graphState.staticGraph().E().query("1").asList(); + Assert.assertEquals(list.size(), 2); + + list = + graphState + .staticGraph() + .E() + .query("1") + .by((IEdgeFilter) value -> !value.getTargetId().equals("2")) + .asList(); + Assert.assertEquals(list.size(), 1); + Assert.assertEquals(list.get(0).getTargetId(), "3"); + + iterator = graphState.staticGraph().V().iterator(); + Assert.assertEquals(Iterators.size(iterator), 2); + + vertex = graphState.staticGraph().V().query("1").get(); + Assert.assertEquals(vertex.getValue(), "3"); + + graphState.manage().operate().drop(); + viewMetaBookKeeper.archive(); + } +} diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/MemoryDynamicGraphStateTest.java b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/MemoryDynamicGraphStateTest.java index 0f8a42dfc..77010a1de 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/MemoryDynamicGraphStateTest.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/MemoryDynamicGraphStateTest.java @@ -19,13 +19,12 @@ package org.apache.geaflow.state; -import com.google.common.collect.Iterators; -import com.google.common.collect.Lists; import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.edge.impl.ValueEdge; @@ -41,93 +40,111 @@ import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; + public class MemoryDynamicGraphStateTest { - private long ts; - private String testName; - private GraphState graphState; - - @BeforeMethod - public void setUp() { - this.ts = System.currentTimeMillis(); - this.testName = "graph-state-test-" + ts; - graphState = prepareData(); - } - - @Test - public void test() { - List> list = graphState.dynamicGraph().E().query(1L, "1").asList(); - Assert.assertEquals(list.size(), 2); - - list = graphState.dynamicGraph().E().query(1L, "1").by( - (IEdgeFilter) value -> !value.getTargetId().equals("2")).asList(); - Assert.assertEquals(list.size(), 1); - Assert.assertEquals(list.get(0).getTargetId(), "3"); - - Iterator> iterator = graphState.dynamicGraph().V().query(2L).iterator(); - Assert.assertEquals(Iterators.size(iterator), 2); - - IVertex vertex = graphState.dynamicGraph().V().query(1L, "1").get(); - Assert.assertEquals(vertex.getValue(), "3"); - - Assert.assertEquals(graphState.dynamicGraph().V().getLatestVersion("2"), 2L); - Assert.assertEquals(graphState.dynamicGraph().V().getAllVersions("1").size(), 3); - - Map> map = - graphState.dynamicGraph().V().query("1").asMap(); - Assert.assertEquals(map.size(), 3); - - map = graphState.dynamicGraph().V().query("1", Arrays.asList(2L, 3L)).asMap(); - Assert.assertEquals(map.size(), 2); - - map = graphState.dynamicGraph().V().query("1", Arrays.asList(2L, 3L)).by( - (IVertexFilter) value -> !value.getValue().equals("5")).asMap(); - Assert.assertEquals(map.size(), 1); - - map = graphState.dynamicGraph().V().query("1").by( - (IVertexFilter) value -> !value.getValue().equals("5")).asMap(); - Assert.assertEquals(map.size(), 2); - - map = graphState.dynamicGraph().V().query("1", Arrays.asList(2L, 3L, 4L)).asMap(); - Assert.assertEquals(map.size(), 2); - - List> res = - graphState.dynamicGraph().VE().query(2L, "2").asList(); - Assert.assertEquals(res.size(), 1); - - res = graphState.dynamicGraph().VE().query(3L, "1").asList(); - Assert.assertEquals(res.size(), 1); - - Iterator idIterator = graphState.dynamicGraph().V().idIterator(); - List idList = Lists.newArrayList(idIterator); - Assert.assertEquals(idList.size(), 2); - - idIterator = graphState.dynamicGraph().V().query(2L, new KeyGroup(1, 1)).idIterator(); - Assert.assertEquals(Iterators.size(idIterator), 1); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - } - - private GraphState prepareData() { - Map config = new HashMap<>(); - - GraphStateDescriptor desc = GraphStateDescriptor.build(testName, StoreType.MEMORY.name()); - desc.withDataModel(DataModel.DYNAMIC_GRAPH).withKeyGroup(new KeyGroup(0, 1)) - .withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration(config)); - - graphState.manage().operate().setCheckpointId(1); - - graphState.dynamicGraph().E().add(1L, new ValueEdge<>("1", "2", "hello")); - graphState.dynamicGraph().E().add(1L, new ValueEdge<>("1", "3", "hello")); - graphState.dynamicGraph().E().add(2L, new ValueEdge<>("2", "2", "world")); - graphState.dynamicGraph().E().add(2L, new ValueEdge<>("2", "3", "world")); - graphState.dynamicGraph().V().add(1L, new ValueVertex<>("1", "3")); - graphState.dynamicGraph().V().add(2L, new ValueVertex<>("2", "4")); - graphState.dynamicGraph().V().add(2L, new ValueVertex<>("1", "5")); - graphState.dynamicGraph().V().add(3L, new ValueVertex<>("1", "6")); - return graphState; - } + private long ts; + private String testName; + private GraphState graphState; + + @BeforeMethod + public void setUp() { + this.ts = System.currentTimeMillis(); + this.testName = "graph-state-test-" + ts; + graphState = prepareData(); + } + + @Test + public void test() { + List> list = graphState.dynamicGraph().E().query(1L, "1").asList(); + Assert.assertEquals(list.size(), 2); + + list = + graphState + .dynamicGraph() + .E() + .query(1L, "1") + .by((IEdgeFilter) value -> !value.getTargetId().equals("2")) + .asList(); + Assert.assertEquals(list.size(), 1); + Assert.assertEquals(list.get(0).getTargetId(), "3"); + + Iterator> iterator = graphState.dynamicGraph().V().query(2L).iterator(); + Assert.assertEquals(Iterators.size(iterator), 2); + + IVertex vertex = graphState.dynamicGraph().V().query(1L, "1").get(); + Assert.assertEquals(vertex.getValue(), "3"); + + Assert.assertEquals(graphState.dynamicGraph().V().getLatestVersion("2"), 2L); + Assert.assertEquals(graphState.dynamicGraph().V().getAllVersions("1").size(), 3); + + Map> map = graphState.dynamicGraph().V().query("1").asMap(); + Assert.assertEquals(map.size(), 3); + + map = graphState.dynamicGraph().V().query("1", Arrays.asList(2L, 3L)).asMap(); + Assert.assertEquals(map.size(), 2); + + map = + graphState + .dynamicGraph() + .V() + .query("1", Arrays.asList(2L, 3L)) + .by((IVertexFilter) value -> !value.getValue().equals("5")) + .asMap(); + Assert.assertEquals(map.size(), 1); + + map = + graphState + .dynamicGraph() + .V() + .query("1") + .by((IVertexFilter) value -> !value.getValue().equals("5")) + .asMap(); + Assert.assertEquals(map.size(), 2); + + map = graphState.dynamicGraph().V().query("1", Arrays.asList(2L, 3L, 4L)).asMap(); + Assert.assertEquals(map.size(), 2); + + List> res = + graphState.dynamicGraph().VE().query(2L, "2").asList(); + Assert.assertEquals(res.size(), 1); + + res = graphState.dynamicGraph().VE().query(3L, "1").asList(); + Assert.assertEquals(res.size(), 1); + + Iterator idIterator = graphState.dynamicGraph().V().idIterator(); + List idList = Lists.newArrayList(idIterator); + Assert.assertEquals(idList.size(), 2); + + idIterator = graphState.dynamicGraph().V().query(2L, new KeyGroup(1, 1)).idIterator(); + Assert.assertEquals(Iterators.size(idIterator), 1); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } + + private GraphState prepareData() { + Map config = new HashMap<>(); + + GraphStateDescriptor desc = GraphStateDescriptor.build(testName, StoreType.MEMORY.name()); + desc.withDataModel(DataModel.DYNAMIC_GRAPH) + .withKeyGroup(new KeyGroup(0, 1)) + .withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration(config)); + + graphState.manage().operate().setCheckpointId(1); + + graphState.dynamicGraph().E().add(1L, new ValueEdge<>("1", "2", "hello")); + graphState.dynamicGraph().E().add(1L, new ValueEdge<>("1", "3", "hello")); + graphState.dynamicGraph().E().add(2L, new ValueEdge<>("2", "2", "world")); + graphState.dynamicGraph().E().add(2L, new ValueEdge<>("2", "3", "world")); + graphState.dynamicGraph().V().add(1L, new ValueVertex<>("1", "3")); + graphState.dynamicGraph().V().add(2L, new ValueVertex<>("2", "4")); + graphState.dynamicGraph().V().add(2L, new ValueVertex<>("1", "5")); + graphState.dynamicGraph().V().add(3L, new ValueVertex<>("1", "6")); + return graphState; + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/MemoryGraphStateTest.java b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/MemoryGraphStateTest.java index f987ec528..fb724801c 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/MemoryGraphStateTest.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/MemoryGraphStateTest.java @@ -19,15 +19,13 @@ package org.apache.geaflow.state; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Iterators; -import com.google.common.collect.Lists; import java.io.File; import java.io.IOException; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.type.Types; @@ -54,159 +52,200 @@ import org.testng.annotations.Factory; import org.testng.annotations.Test; -public class MemoryGraphStateTest { - - private final Map additionalConfig; - private long ts; - private String testName; - - @BeforeMethod - public void setUp() { - this.ts = System.currentTimeMillis(); - this.testName = "graph-state-test-" + ts; - } - - @AfterMethod - public void tearDown() throws IOException { - FileUtils.deleteQuietly(new File("/tmp/" + testName)); - } - - public MemoryGraphStateTest(Map config) { - this.additionalConfig = config; - } - - public static class GraphMemoryStoreTestFactory { - - @Factory - public Object[] factoryMethod() { - return new Object[]{ - new MemoryGraphStateTest(new HashMap<>()), - new MemoryGraphStateTest( - ImmutableMap.of(MemoryConfigKeys.CSR_MEMORY_ENABLE.getKey(), "true")), - }; - } - } - - @Test - public void test() { - GraphStateDescriptor desc = GraphStateDescriptor.build("test1", StoreType.MEMORY.name()); - desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - desc.withGraphMeta(new GraphMeta(new GraphMetaType<>(Types.STRING, ValueVertex.class, - String.class, ValueEdge.class, String.class))); - Map config = new HashMap<>(additionalConfig); - - GraphState graphState = StateFactory.buildGraphState(desc, new Configuration(config)); - - graphState.manage().operate().setCheckpointId(1); - - graphState.staticGraph().E().add(new ValueEdge<>("1", "2", "hello", EdgeDirection.IN)); - graphState.staticGraph().E().add(new ValueEdge<>("1", "3", "hello", EdgeDirection.OUT)); - graphState.staticGraph().E().add(new ValueEdge<>("2", "2", "world", EdgeDirection.IN)); - graphState.staticGraph().E().add(new ValueEdge<>("2", "3", "world", EdgeDirection.OUT)); - graphState.staticGraph().V().add(new ValueVertex<>("1", "3")); - graphState.staticGraph().V().add(new ValueVertex<>("2", "4")); - graphState.manage().operate().finish(); - - List> list = graphState.staticGraph().E().query("1").asList(); - Assert.assertEquals(list.size(), 2); - - list = graphState.staticGraph().E().query("1").by( - (IEdgeFilter) value -> !value.getTargetId().equals("2")).asList(); - Assert.assertEquals(list.size(), 1); - Assert.assertEquals(list.get(0).getTargetId(), "3"); - - Iterator> iterator = graphState.staticGraph().V().iterator(); - Assert.assertEquals(Iterators.size(iterator), 2); - - IVertex vertex = graphState.staticGraph().V().query("1").get(); - Assert.assertEquals(vertex.getValue(), "3"); - - - Iterator idIterator = graphState.staticGraph().V().idIterator(); - List idList = Lists.newArrayList(idIterator); - Assert.assertEquals(idList.size(), 2); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - } - - @Test - public void testFilter() { - GraphStateDescriptor desc = GraphStateDescriptor.build("test1", StoreType.MEMORY.name()); - desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - desc.withGraphMeta(new GraphMeta(new GraphMetaType<>(Types.STRING, ValueVertex.class, - String.class, ValueLabelTimeEdge.class, String.class))); - Map config = new HashMap<>(additionalConfig); - - GraphState graphState = StateFactory.buildGraphState(desc, new Configuration(config)); - - graphState.manage().operate().setCheckpointId(1); - - graphState.staticGraph().E().add(new ValueLabelTimeEdge<>("1", "2", "hello", "foo", 1000)); - graphState.staticGraph().E().add(new ValueLabelTimeEdge<>("1", "3", "hello", "bar", 100)); - graphState.staticGraph().E().add(new ValueLabelTimeEdge<>("2", "2", "world", "foo", 1000)); - graphState.staticGraph().E().add(new ValueLabelTimeEdge<>("2", "3", "world", "bar", 100)); - graphState.staticGraph().V().add(new ValueVertex<>("1", "3")); - graphState.staticGraph().V().add(new ValueVertex<>("2", "4")); - - graphState.manage().operate().finish(); - - List> list = graphState.staticGraph().E().query("1", "2").by( - new EdgeTsFilter<>(TimeRange.of(0, 500))).asList(); - Assert.assertEquals(list.size(), 2); +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; - list = graphState.staticGraph().E().query("1", "2").by( - new EdgeTsFilter<>(TimeRange.of(0, 500)).or(new EdgeTsFilter<>(TimeRange.of(800, - 1100)))).asList(); - Assert.assertEquals(list.size(), 4); +public class MemoryGraphStateTest { - graphState.manage().operate().close(); - graphState.manage().operate().drop(); + private final Map additionalConfig; + private long ts; + private String testName; + + @BeforeMethod + public void setUp() { + this.ts = System.currentTimeMillis(); + this.testName = "graph-state-test-" + ts; + } + + @AfterMethod + public void tearDown() throws IOException { + FileUtils.deleteQuietly(new File("/tmp/" + testName)); + } + + public MemoryGraphStateTest(Map config) { + this.additionalConfig = config; + } + + public static class GraphMemoryStoreTestFactory { + + @Factory + public Object[] factoryMethod() { + return new Object[] { + new MemoryGraphStateTest(new HashMap<>()), + new MemoryGraphStateTest( + ImmutableMap.of(MemoryConfigKeys.CSR_MEMORY_ENABLE.getKey(), "true")), + }; } - - @Test - public void testLimit() { - GraphStateDescriptor desc = GraphStateDescriptor.build("test1", StoreType.MEMORY.name()); - desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - desc.withGraphMeta(new GraphMeta(new GraphMetaType<>(Types.STRING, ValueVertex.class, - String.class, ValueEdge.class, String.class))); - Map config = new HashMap<>(additionalConfig); - - GraphState graphState = StateFactory.buildGraphState(desc, new Configuration(config)); - - graphState.manage().operate().setCheckpointId(1); - for (int i = 0; i < 10; i++) { - String src = Integer.toString(i); - for (int j = 1; j < 10; j++) { - String dst = Integer.toString(j); - graphState.staticGraph().E().add(new ValueEdge<>(src, dst, "hello" + src + dst, - EdgeDirection.values()[j % 2])); - } - graphState.staticGraph().V().add(new ValueVertex<>(src, "world" + src)); - } - graphState.manage().operate().finish(); - - List> list = - graphState.staticGraph().E().query("1", "2", "3") - .limit(1L, 1L).asList(); - System.out.println(list); - Assert.assertEquals(list.size(), 6); - - list = - graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()) - .limit(1L, 1L).asList(); - Assert.assertEquals(list.size(), 10); - - list = graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()) - .limit(1L, 2L).asList(); - Assert.assertEquals(list.size(), 20); - - List targetIds = - graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()) - .select(new DstIdProjector<>()).limit(1L, 2L).asList(); - - Assert.assertEquals(targetIds.size(), 20); - graphState.manage().operate().close(); - graphState.manage().operate().drop(); + } + + @Test + public void test() { + GraphStateDescriptor desc = + GraphStateDescriptor.build("test1", StoreType.MEMORY.name()); + desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + desc.withGraphMeta( + new GraphMeta( + new GraphMetaType<>( + Types.STRING, ValueVertex.class, String.class, ValueEdge.class, String.class))); + Map config = new HashMap<>(additionalConfig); + + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration(config)); + + graphState.manage().operate().setCheckpointId(1); + + graphState.staticGraph().E().add(new ValueEdge<>("1", "2", "hello", EdgeDirection.IN)); + graphState.staticGraph().E().add(new ValueEdge<>("1", "3", "hello", EdgeDirection.OUT)); + graphState.staticGraph().E().add(new ValueEdge<>("2", "2", "world", EdgeDirection.IN)); + graphState.staticGraph().E().add(new ValueEdge<>("2", "3", "world", EdgeDirection.OUT)); + graphState.staticGraph().V().add(new ValueVertex<>("1", "3")); + graphState.staticGraph().V().add(new ValueVertex<>("2", "4")); + graphState.manage().operate().finish(); + + List> list = graphState.staticGraph().E().query("1").asList(); + Assert.assertEquals(list.size(), 2); + + list = + graphState + .staticGraph() + .E() + .query("1") + .by((IEdgeFilter) value -> !value.getTargetId().equals("2")) + .asList(); + Assert.assertEquals(list.size(), 1); + Assert.assertEquals(list.get(0).getTargetId(), "3"); + + Iterator> iterator = graphState.staticGraph().V().iterator(); + Assert.assertEquals(Iterators.size(iterator), 2); + + IVertex vertex = graphState.staticGraph().V().query("1").get(); + Assert.assertEquals(vertex.getValue(), "3"); + + Iterator idIterator = graphState.staticGraph().V().idIterator(); + List idList = Lists.newArrayList(idIterator); + Assert.assertEquals(idList.size(), 2); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } + + @Test + public void testFilter() { + GraphStateDescriptor desc = + GraphStateDescriptor.build("test1", StoreType.MEMORY.name()); + desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + desc.withGraphMeta( + new GraphMeta( + new GraphMetaType<>( + Types.STRING, + ValueVertex.class, + String.class, + ValueLabelTimeEdge.class, + String.class))); + Map config = new HashMap<>(additionalConfig); + + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration(config)); + + graphState.manage().operate().setCheckpointId(1); + + graphState.staticGraph().E().add(new ValueLabelTimeEdge<>("1", "2", "hello", "foo", 1000)); + graphState.staticGraph().E().add(new ValueLabelTimeEdge<>("1", "3", "hello", "bar", 100)); + graphState.staticGraph().E().add(new ValueLabelTimeEdge<>("2", "2", "world", "foo", 1000)); + graphState.staticGraph().E().add(new ValueLabelTimeEdge<>("2", "3", "world", "bar", 100)); + graphState.staticGraph().V().add(new ValueVertex<>("1", "3")); + graphState.staticGraph().V().add(new ValueVertex<>("2", "4")); + + graphState.manage().operate().finish(); + + List> list = + graphState + .staticGraph() + .E() + .query("1", "2") + .by(new EdgeTsFilter<>(TimeRange.of(0, 500))) + .asList(); + Assert.assertEquals(list.size(), 2); + + list = + graphState + .staticGraph() + .E() + .query("1", "2") + .by( + new EdgeTsFilter<>(TimeRange.of(0, 500)) + .or(new EdgeTsFilter<>(TimeRange.of(800, 1100)))) + .asList(); + Assert.assertEquals(list.size(), 4); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } + + @Test + public void testLimit() { + GraphStateDescriptor desc = + GraphStateDescriptor.build("test1", StoreType.MEMORY.name()); + desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + desc.withGraphMeta( + new GraphMeta( + new GraphMetaType<>( + Types.STRING, ValueVertex.class, String.class, ValueEdge.class, String.class))); + Map config = new HashMap<>(additionalConfig); + + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration(config)); + + graphState.manage().operate().setCheckpointId(1); + for (int i = 0; i < 10; i++) { + String src = Integer.toString(i); + for (int j = 1; j < 10; j++) { + String dst = Integer.toString(j); + graphState + .staticGraph() + .E() + .add(new ValueEdge<>(src, dst, "hello" + src + dst, EdgeDirection.values()[j % 2])); + } + graphState.staticGraph().V().add(new ValueVertex<>(src, "world" + src)); } + graphState.manage().operate().finish(); + + List> list = + graphState.staticGraph().E().query("1", "2", "3").limit(1L, 1L).asList(); + System.out.println(list); + Assert.assertEquals(list.size(), 6); + + list = + graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()).limit(1L, 1L).asList(); + Assert.assertEquals(list.size(), 10); + + list = + graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()).limit(1L, 2L).asList(); + Assert.assertEquals(list.size(), 20); + + List targetIds = + graphState + .staticGraph() + .E() + .query() + .by(InEdgeFilter.getInstance()) + .select(new DstIdProjector<>()) + .limit(1L, 2L) + .asList(); + + Assert.assertEquals(targetIds.size(), 20); + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/PaimonDynamicGraphStateTest.java b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/PaimonDynamicGraphStateTest.java index 2922d7e58..4f6c2159e 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/PaimonDynamicGraphStateTest.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/PaimonDynamicGraphStateTest.java @@ -19,14 +19,13 @@ package org.apache.geaflow.state; -import com.google.common.collect.Iterators; -import com.google.common.collect.Lists; import java.io.File; import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -53,188 +52,217 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -public class PaimonDynamicGraphStateTest { - - static Map config = new HashMap<>(); - - @BeforeClass - public static void setUp() { - FileUtils.deleteQuietly(new File("/tmp/geaflow/chk/")); - FileUtils.deleteQuietly(new File("/tmp/PaimonDynamicGraphStateTest")); - Map persistConfig = new HashMap<>(); - config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), - "PaimonDynamicGraphStateTest" + System.currentTimeMillis()); - config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); - config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); - config.put(FileConfigKeys.JSON_CONFIG.getKey(), GsonUtil.toJson(persistConfig)); - config.put(PaimonConfigKeys.PAIMON_STORE_WAREHOUSE.getKey(), - "file:///tmp/PaimonDynamicGraphStateTest/"); - config.put(PaimonConfigKeys.PAIMON_STORE_TABLE_AUTO_CREATE_ENABLE.getKey(), "true"); - } +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; - @AfterClass - public static void tearUp() { - FileUtils.deleteQuietly(new File("/tmp/PaimonDynamicGraphStateTest")); - } +public class PaimonDynamicGraphStateTest { - private GraphState getGraphState(IType type, String name, - Map conf) { - return getGraphState(type, name, conf, new KeyGroup(0, 1), 2); + static Map config = new HashMap<>(); + + @BeforeClass + public static void setUp() { + FileUtils.deleteQuietly(new File("/tmp/geaflow/chk/")); + FileUtils.deleteQuietly(new File("/tmp/PaimonDynamicGraphStateTest")); + Map persistConfig = new HashMap<>(); + config.put( + ExecutionConfigKeys.JOB_APP_NAME.getKey(), + "PaimonDynamicGraphStateTest" + System.currentTimeMillis()); + config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); + config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); + config.put(FileConfigKeys.JSON_CONFIG.getKey(), GsonUtil.toJson(persistConfig)); + config.put( + PaimonConfigKeys.PAIMON_STORE_WAREHOUSE.getKey(), + "file:///tmp/PaimonDynamicGraphStateTest/"); + config.put(PaimonConfigKeys.PAIMON_STORE_TABLE_AUTO_CREATE_ENABLE.getKey(), "true"); + } + + @AfterClass + public static void tearUp() { + FileUtils.deleteQuietly(new File("/tmp/PaimonDynamicGraphStateTest")); + } + + private GraphState getGraphState( + IType type, String name, Map conf) { + return getGraphState(type, name, conf, new KeyGroup(0, 1), 2); + } + + private GraphState getGraphState( + IType type, String name, Map conf, KeyGroup group, int maxPara) { + GraphMetaType tag = + new GraphMetaType( + type, ValueVertex.class, type.getTypeClass(), ValueEdge.class, type.getTypeClass()); + + GraphStateDescriptor desc = GraphStateDescriptor.build(name, StoreType.PAIMON.name()); + desc.withKeyGroup(group) + .withDataModel(DataModel.DYNAMIC_GRAPH) + .withKeyGroupAssigner(new DefaultKeyGroupAssigner(maxPara)); + desc.withGraphMeta(new GraphMeta(tag)); + GraphState graphState = StateFactory.buildGraphState(desc, new Configuration(conf)); + return graphState; + } + + @Test + public void testWriteRead() { + Map conf = new HashMap<>(config); + conf.put(StateConfigKeys.STATE_WRITE_ASYNC_ENABLE.getKey(), "false"); + conf.put(PaimonConfigKeys.PAIMON_STORE_DISTRIBUTED_MODE_ENABLE.getKey(), "false"); + testApi(conf); + } + + @Test + public void testWriteRead2() { + Map conf = new HashMap<>(config); + conf.put(StateConfigKeys.STATE_WRITE_ASYNC_ENABLE.getKey(), "false"); + conf.put(PaimonConfigKeys.PAIMON_STORE_DISTRIBUTED_MODE_ENABLE.getKey(), "false"); + conf.put(PaimonConfigKeys.PAIMON_STORE_DATABASE.getKey(), "graph"); + testApi(conf); + } + + private void testApi(Map conf) { + conf.put(StateConfigKeys.STATE_WRITE_BUFFER_SIZE.getKey(), "100"); + GraphState graphState = + getGraphState(StringType.INSTANCE, "testApi", conf); + + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 1000; i++) { + graphState.dynamicGraph().E().add(1L, new ValueEdge<>("1", "2", "hello")); + graphState.dynamicGraph().E().add(1L, new ValueEdge<>("1", "3", "hello")); + graphState.dynamicGraph().E().add(2L, new ValueEdge<>("2", "2", "world")); + graphState.dynamicGraph().E().add(2L, new ValueEdge<>("2", "3", "world")); + graphState.dynamicGraph().V().add(1L, new ValueVertex<>("1", "3")); + graphState.dynamicGraph().V().add(2L, new ValueVertex<>("2", "4")); + graphState.dynamicGraph().V().add(2L, new ValueVertex<>("1", "5")); + graphState.dynamicGraph().V().add(3L, new ValueVertex<>("1", "6")); } - private GraphState getGraphState(IType type, String name, - Map conf, KeyGroup group, - int maxPara) { - GraphMetaType tag = new GraphMetaType(type, ValueVertex.class, type.getTypeClass(), - ValueEdge.class, type.getTypeClass()); - - GraphStateDescriptor desc = GraphStateDescriptor.build(name, StoreType.PAIMON.name()); - desc.withKeyGroup(group).withDataModel(DataModel.DYNAMIC_GRAPH) - .withKeyGroupAssigner(new DefaultKeyGroupAssigner(maxPara)); - desc.withGraphMeta(new GraphMeta(tag)); - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration(conf)); - return graphState; + graphState.dynamicGraph().V().add(4L, new ValueVertex<>("1", "6")); + graphState.dynamicGraph().V().add(4L, new ValueVertex<>("3", "6")); + graphState.dynamicGraph().E().add(4L, new ValueEdge<>("1", "1", "6")); + graphState.dynamicGraph().E().add(4L, new ValueEdge<>("1", "2", "6")); + graphState.manage().operate().finish(); + graphState.manage().operate().archive(); + + List> list = graphState.dynamicGraph().E().query(1L, "1").asList(); + Assert.assertEquals(list.size(), 2); + + list = + graphState + .dynamicGraph() + .E() + .query(1L, "1") + .by((IEdgeFilter) value -> !value.getTargetId().equals("2")) + .asList(); + Assert.assertEquals(list.size(), 1); + Assert.assertEquals(list.get(0).getTargetId(), "3"); + + Iterator> iterator = graphState.dynamicGraph().V().query(2L).iterator(); + Assert.assertEquals(Iterators.size(iterator), 2); + + IVertex vertex = graphState.dynamicGraph().V().query(1L, "1").get(); + Assert.assertEquals(vertex.getValue(), "3"); + + Assert.assertEquals(graphState.dynamicGraph().V().getLatestVersion("2"), 2L); + Assert.assertEquals(graphState.dynamicGraph().V().getAllVersions("1").size(), 4); + Assert.assertEquals(graphState.dynamicGraph().V().getLatestVersion("1"), 4); + + Map> map = graphState.dynamicGraph().V().query("1").asMap(); + Assert.assertEquals(map.size(), 4); + + map = graphState.dynamicGraph().V().query("1", Arrays.asList(2L, 3L)).asMap(); + Assert.assertEquals(map.size(), 2); + + map = + graphState + .dynamicGraph() + .V() + .query("1", Arrays.asList(2L, 3L, 4L)) + .by((IVertexFilter) value -> !value.getValue().equals("5")) + .asMap(); + Assert.assertEquals(map.size(), 2); + + map = + graphState + .dynamicGraph() + .V() + .query("1") + .by((IVertexFilter) value -> !value.getValue().equals("5")) + .asMap(); + Assert.assertEquals(map.size(), 3); + + map = graphState.dynamicGraph().V().query("1", Arrays.asList(2L, 3L, 4L, 5L)).asMap(); + Assert.assertEquals(map.size(), 3); + + List> res = + graphState.dynamicGraph().VE().query(2L, "2").asList(); + Assert.assertEquals(res.size(), 1); + + res = graphState.dynamicGraph().VE().query(3L, "1").asList(); + Assert.assertEquals(res.size(), 1); + + Iterator idIterator = graphState.dynamicGraph().V().idIterator(); + List idList = Lists.newArrayList(idIterator); + Assert.assertEquals(idList.size(), 3); + + res = graphState.dynamicGraph().VE().query(4L, "1").asList(); + Assert.assertEquals(res.size(), 1); + Assert.assertEquals(Iterators.size(res.get(0).getEdgeIterator()), 2); + + res = + graphState + .dynamicGraph() + .VE() + .query(4L, "1") + .by((IEdgeFilter) value -> !value.getTargetId().equals("1")) + .asList(); + Assert.assertEquals(res.size(), 1); + Assert.assertEquals(Iterators.size(res.get(0).getEdgeIterator()), 1); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } + + @Test + public void testKeyGroup() { + Map conf = new HashMap<>(config); + conf.put(PaimonConfigKeys.PAIMON_STORE_DISTRIBUTED_MODE_ENABLE.getKey(), "false"); + conf.put(PaimonConfigKeys.PAIMON_STORE_TABLE_AUTO_CREATE_ENABLE.getKey(), "true"); + GraphState graphState = + getGraphState(StringType.INSTANCE, "testKeyGroup", conf); + + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 10; i++) { + graphState.dynamicGraph().E().add(i, new ValueEdge<>("1", "2", "hello" + i)); + graphState.dynamicGraph().E().add(i, new ValueEdge<>("1", "3", "hello" + i)); + graphState.dynamicGraph().E().add(i, new ValueEdge<>("2", "2", "world" + i)); + graphState.dynamicGraph().E().add(i, new ValueEdge<>("2", "3", "world" + i)); + graphState.dynamicGraph().V().add(i, new ValueVertex<>("1", "3" + i)); + graphState.dynamicGraph().V().add(i, new ValueVertex<>("2", "4" + i)); + graphState.dynamicGraph().V().add(i, new ValueVertex<>("1", "5" + i)); + graphState.dynamicGraph().V().add(i, new ValueVertex<>("1", "6" + i)); } - @Test - public void testWriteRead() { - Map conf = new HashMap<>(config); - conf.put(StateConfigKeys.STATE_WRITE_ASYNC_ENABLE.getKey(), "false"); - conf.put(PaimonConfigKeys.PAIMON_STORE_DISTRIBUTED_MODE_ENABLE.getKey(), "false"); - testApi(conf); - } + graphState.manage().operate().finish(); + graphState.manage().operate().archive(); - @Test - public void testWriteRead2() { - Map conf = new HashMap<>(config); - conf.put(StateConfigKeys.STATE_WRITE_ASYNC_ENABLE.getKey(), "false"); - conf.put(PaimonConfigKeys.PAIMON_STORE_DISTRIBUTED_MODE_ENABLE.getKey(), "false"); - conf.put(PaimonConfigKeys.PAIMON_STORE_DATABASE.getKey(), "graph"); - testApi(conf); - } + Iterator idIterator = graphState.dynamicGraph().V().idIterator(); + Assert.assertEquals(Iterators.size(idIterator), 2); - private void testApi(Map conf) { - conf.put(StateConfigKeys.STATE_WRITE_BUFFER_SIZE.getKey(), "100"); - GraphState graphState = getGraphState(StringType.INSTANCE, "testApi", conf); - - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 1000; i++) { - graphState.dynamicGraph().E().add(1L, new ValueEdge<>("1", "2", "hello")); - graphState.dynamicGraph().E().add(1L, new ValueEdge<>("1", "3", "hello")); - graphState.dynamicGraph().E().add(2L, new ValueEdge<>("2", "2", "world")); - graphState.dynamicGraph().E().add(2L, new ValueEdge<>("2", "3", "world")); - graphState.dynamicGraph().V().add(1L, new ValueVertex<>("1", "3")); - graphState.dynamicGraph().V().add(2L, new ValueVertex<>("2", "4")); - graphState.dynamicGraph().V().add(2L, new ValueVertex<>("1", "5")); - graphState.dynamicGraph().V().add(3L, new ValueVertex<>("1", "6")); - } - - graphState.dynamicGraph().V().add(4L, new ValueVertex<>("1", "6")); - graphState.dynamicGraph().V().add(4L, new ValueVertex<>("3", "6")); - graphState.dynamicGraph().E().add(4L, new ValueEdge<>("1", "1", "6")); - graphState.dynamicGraph().E().add(4L, new ValueEdge<>("1", "2", "6")); - graphState.manage().operate().finish(); - graphState.manage().operate().archive(); - - List> list = graphState.dynamicGraph().E().query(1L, "1").asList(); - Assert.assertEquals(list.size(), 2); - - list = graphState.dynamicGraph().E().query(1L, "1").by( - (IEdgeFilter) value -> !value.getTargetId().equals("2")).asList(); - Assert.assertEquals(list.size(), 1); - Assert.assertEquals(list.get(0).getTargetId(), "3"); - - Iterator> iterator = graphState.dynamicGraph().V().query(2L).iterator(); - Assert.assertEquals(Iterators.size(iterator), 2); - - IVertex vertex = graphState.dynamicGraph().V().query(1L, "1").get(); - Assert.assertEquals(vertex.getValue(), "3"); - - Assert.assertEquals(graphState.dynamicGraph().V().getLatestVersion("2"), 2L); - Assert.assertEquals(graphState.dynamicGraph().V().getAllVersions("1").size(), 4); - Assert.assertEquals(graphState.dynamicGraph().V().getLatestVersion("1"), 4); - - Map> map = graphState.dynamicGraph().V().query("1").asMap(); - Assert.assertEquals(map.size(), 4); - - map = graphState.dynamicGraph().V().query("1", Arrays.asList(2L, 3L)).asMap(); - Assert.assertEquals(map.size(), 2); - - map = graphState.dynamicGraph().V().query("1", Arrays.asList(2L, 3L, 4L)).by( - (IVertexFilter) value -> !value.getValue().equals("5")).asMap(); - Assert.assertEquals(map.size(), 2); - - map = graphState.dynamicGraph().V().query("1").by( - (IVertexFilter) value -> !value.getValue().equals("5")).asMap(); - Assert.assertEquals(map.size(), 3); - - map = graphState.dynamicGraph().V().query("1", Arrays.asList(2L, 3L, 4L, 5L)).asMap(); - Assert.assertEquals(map.size(), 3); - - List> res = - graphState.dynamicGraph().VE().query(2L, "2").asList(); - Assert.assertEquals(res.size(), 1); - - res = graphState.dynamicGraph().VE().query(3L, "1").asList(); - Assert.assertEquals(res.size(), 1); - - Iterator idIterator = graphState.dynamicGraph().V().idIterator(); - List idList = Lists.newArrayList(idIterator); - Assert.assertEquals(idList.size(), 3); - - res = - graphState.dynamicGraph().VE().query(4L, "1").asList(); - Assert.assertEquals(res.size(), 1); - Assert.assertEquals(Iterators.size(res.get(0).getEdgeIterator()), 2); - - res = - graphState.dynamicGraph().VE().query(4L, "1") - .by((IEdgeFilter) value -> !value.getTargetId().equals("1")).asList(); - Assert.assertEquals(res.size(), 1); - Assert.assertEquals(Iterators.size(res.get(0).getEdgeIterator()), 1); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - } + idIterator = graphState.dynamicGraph().V().query(1L, new KeyGroup(0, 0)).idIterator(); + Assert.assertEquals(Iterators.size(idIterator), 1); - @Test - public void testKeyGroup() { - Map conf = new HashMap<>(config); - conf.put(PaimonConfigKeys.PAIMON_STORE_DISTRIBUTED_MODE_ENABLE.getKey(), "false"); - conf.put(PaimonConfigKeys.PAIMON_STORE_TABLE_AUTO_CREATE_ENABLE.getKey(), "true"); - GraphState graphState = getGraphState(StringType.INSTANCE, "testKeyGroup", conf); - - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 10; i++) { - graphState.dynamicGraph().E().add(i, new ValueEdge<>("1", "2", "hello" + i)); - graphState.dynamicGraph().E().add(i, new ValueEdge<>("1", "3", "hello" + i)); - graphState.dynamicGraph().E().add(i, new ValueEdge<>("2", "2", "world" + i)); - graphState.dynamicGraph().E().add(i, new ValueEdge<>("2", "3", "world" + i)); - graphState.dynamicGraph().V().add(i, new ValueVertex<>("1", "3" + i)); - graphState.dynamicGraph().V().add(i, new ValueVertex<>("2", "4" + i)); - graphState.dynamicGraph().V().add(i, new ValueVertex<>("1", "5" + i)); - graphState.dynamicGraph().V().add(i, new ValueVertex<>("1", "6" + i)); - } - - graphState.manage().operate().finish(); - graphState.manage().operate().archive(); - - Iterator idIterator = graphState.dynamicGraph().V().idIterator(); - Assert.assertEquals(Iterators.size(idIterator), 2); - - idIterator = graphState.dynamicGraph().V().query(1L, new KeyGroup(0, 0)).idIterator(); - Assert.assertEquals(Iterators.size(idIterator), 1); - - List> list = graphState.dynamicGraph().E().query(1L, - new KeyGroup(0, 0)).by((IEdgeFilter) value -> !value.getTargetId().equals("2")).asList(); - Assert.assertEquals(list.size(), 1); - Assert.assertEquals(list.get(0).getTargetId(), "3"); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - } + List> list = + graphState + .dynamicGraph() + .E() + .query(1L, new KeyGroup(0, 0)) + .by((IEdgeFilter) value -> !value.getTargetId().equals("2")) + .asList(); + Assert.assertEquals(list.size(), 1); + Assert.assertEquals(list.get(0).getTargetId(), "3"); + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/PaimonGraphStateTest.java b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/PaimonGraphStateTest.java index ed1305dc7..e2d421db8 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/PaimonGraphStateTest.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/PaimonGraphStateTest.java @@ -26,6 +26,7 @@ import java.io.File; import java.util.HashMap; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -49,130 +50,134 @@ public class PaimonGraphStateTest { - static Map config = new HashMap<>(); - - @BeforeClass - public static void setUp() { - FileUtils.deleteQuietly(new File("/tmp/geaflow/chk/")); - FileUtils.deleteQuietly(new File("/tmp/PaimonGraphStateTest")); - Map persistConfig = new HashMap<>(); - config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), - "PaimonGraphStateTest" + System.currentTimeMillis()); - config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); - config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); - config.put(FileConfigKeys.JSON_CONFIG.getKey(), GsonUtil.toJson(persistConfig)); - config.put(PaimonConfigKeys.PAIMON_STORE_WAREHOUSE.getKey(), - "file:///tmp/PaimonGraphStateTest/"); - config.put(PaimonConfigKeys.PAIMON_STORE_TABLE_AUTO_CREATE_ENABLE.getKey(), "true"); - config.put(PaimonConfigKeys.PAIMON_STORE_DISTRIBUTED_MODE_ENABLE.getKey(), "false"); - - } - - @AfterClass - public static void tearUp() { - FileUtils.deleteQuietly(new File("/tmp/PaimonGraphStateTest")); - } - - private GraphState getGraphState(IType type, String name, - Map conf) { - return getGraphState(type, name, conf, new KeyGroup(0, 1), 2); - } - - private GraphState getGraphState(IType type, String name, - Map conf, KeyGroup keyGroup, - int maxPara) { - GraphElementMetas.clearCache(); - GraphMetaType tag = new GraphMetaType(type, ValueVertex.class, ValueVertex::new, - type.getTypeClass(), ValueEdge.class, ValueEdge::new, type.getTypeClass()); - - GraphStateDescriptor desc = GraphStateDescriptor.build(name, StoreType.PAIMON.name()); - desc.withKeyGroup(keyGroup).withKeyGroupAssigner(new DefaultKeyGroupAssigner(maxPara)); - desc.withGraphMeta(new GraphMeta(tag)); - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration(conf)); - return graphState; + static Map config = new HashMap<>(); + + @BeforeClass + public static void setUp() { + FileUtils.deleteQuietly(new File("/tmp/geaflow/chk/")); + FileUtils.deleteQuietly(new File("/tmp/PaimonGraphStateTest")); + Map persistConfig = new HashMap<>(); + config.put( + ExecutionConfigKeys.JOB_APP_NAME.getKey(), + "PaimonGraphStateTest" + System.currentTimeMillis()); + config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); + config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); + config.put(FileConfigKeys.JSON_CONFIG.getKey(), GsonUtil.toJson(persistConfig)); + config.put( + PaimonConfigKeys.PAIMON_STORE_WAREHOUSE.getKey(), "file:///tmp/PaimonGraphStateTest/"); + config.put(PaimonConfigKeys.PAIMON_STORE_TABLE_AUTO_CREATE_ENABLE.getKey(), "true"); + config.put(PaimonConfigKeys.PAIMON_STORE_DISTRIBUTED_MODE_ENABLE.getKey(), "false"); + } + + @AfterClass + public static void tearUp() { + FileUtils.deleteQuietly(new File("/tmp/PaimonGraphStateTest")); + } + + private GraphState getGraphState( + IType type, String name, Map conf) { + return getGraphState(type, name, conf, new KeyGroup(0, 1), 2); + } + + private GraphState getGraphState( + IType type, String name, Map conf, KeyGroup keyGroup, int maxPara) { + GraphElementMetas.clearCache(); + GraphMetaType tag = + new GraphMetaType( + type, + ValueVertex.class, + ValueVertex::new, + type.getTypeClass(), + ValueEdge.class, + ValueEdge::new, + type.getTypeClass()); + + GraphStateDescriptor desc = GraphStateDescriptor.build(name, StoreType.PAIMON.name()); + desc.withKeyGroup(keyGroup).withKeyGroupAssigner(new DefaultKeyGroupAssigner(maxPara)); + desc.withGraphMeta(new GraphMeta(tag)); + GraphState graphState = StateFactory.buildGraphState(desc, new Configuration(conf)); + return graphState; + } + + public void testWriteRead(Map conf) { + GraphState graphState = + getGraphState(StringType.INSTANCE, "write_read", conf); + + // set chk = 1 + graphState.manage().operate().setCheckpointId(1L); + // write 1 vertex and 100 edges. + graphState.staticGraph().V().add(new ValueVertex<>("1", "vertex_hello")); + for (int i = 0; i < 100; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueEdge<>("1", id, "edge_hello")); } - - public void testWriteRead(Map conf) { - GraphState graphState = getGraphState(StringType.INSTANCE, - "write_read", conf); - - // set chk = 1 - graphState.manage().operate().setCheckpointId(1L); - // write 1 vertex and 100 edges. - graphState.staticGraph().V().add(new ValueVertex<>("1", "vertex_hello")); - for (int i = 0; i < 100; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueEdge<>("1", id, "edge_hello")); - } - // read nothing since not committed - boolean async = Boolean.parseBoolean(conf.get(STATE_WRITE_ASYNC_ENABLE.getKey())); - if (!async) { - Assert.assertNull(graphState.staticGraph().V().query("1").get()); - Assert.assertEquals(graphState.staticGraph().E().query("1").asList().size(), 0); - } - // commit chk = 1, now be able to read data - graphState.manage().operate().archive(); - Assert.assertNotNull(graphState.staticGraph().V().query("1").get()); - Assert.assertEquals(graphState.staticGraph().E().query("1").asList().size(), 100); - - // set chk = 2 - graphState.manage().operate().setCheckpointId(2L); - // write 1 new vertex and 390 new edges. - graphState.staticGraph().V().add(new ValueVertex<>("2", "vertex_hello")); - for (int i = 0; i < 200; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueEdge<>("2", id, "edge_hello")); - } - // be not able to read data with chk = 2 since not committed. - Assert.assertNotNull(graphState.staticGraph().V().query("1").get()); - Assert.assertEquals(graphState.staticGraph().E().query("1").asList().size(), 100); - if (!async) { - Assert.assertNull(graphState.staticGraph().V().query("2").get()); - Assert.assertEquals(graphState.staticGraph().E().query("2").asList().size(), 0); - } - // commit chk = 2, now be able to read data - graphState.manage().operate().archive(); - Assert.assertNotNull(graphState.staticGraph().V().query("1").get()); - Assert.assertEquals(graphState.staticGraph().E().query("1").asList().size(), 100); - Assert.assertNotNull(graphState.staticGraph().V().query("2").get()); - Assert.assertEquals(graphState.staticGraph().E().query("2").asList().size(), 200); - // Read data which not exists. - Assert.assertEquals(graphState.staticGraph().E().query("3").asList().size(), 0); - Assert.assertEquals(graphState.staticGraph().E().query("").asList().size(), 0); - - // TODO. Rollback to chk = 1, then be not able to read data with chk = 2. - // graphState.manage().operate().setCheckpointId(1); - // graphState.manage().operate().recover(); - // Assert.assertNotNull(graphState.staticGraph().V().query("1").get()); - // Assert.assertEquals(graphState.staticGraph().E().query("1").asList().size(), 390); - // Assert.assertNull(graphState.staticGraph().V().query("2").get()); - // Assert.assertEquals(graphState.staticGraph().E().query("2").asList().size(), 0); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); + // read nothing since not committed + boolean async = Boolean.parseBoolean(conf.get(STATE_WRITE_ASYNC_ENABLE.getKey())); + if (!async) { + Assert.assertNull(graphState.staticGraph().V().query("1").get()); + Assert.assertEquals(graphState.staticGraph().E().query("1").asList().size(), 0); } - - @Test - public void testBothWriteMode() { - Map conf = new HashMap<>(config); - conf.put(STATE_WRITE_ASYNC_ENABLE.getKey(), Boolean.TRUE.toString()); - testWriteRead(conf); - - conf.put(STATE_WRITE_ASYNC_ENABLE.getKey(), Boolean.TRUE.toString()); - testWriteRead(conf); + // commit chk = 1, now be able to read data + graphState.manage().operate().archive(); + Assert.assertNotNull(graphState.staticGraph().V().query("1").get()); + Assert.assertEquals(graphState.staticGraph().E().query("1").asList().size(), 100); + + // set chk = 2 + graphState.manage().operate().setCheckpointId(2L); + // write 1 new vertex and 390 new edges. + graphState.staticGraph().V().add(new ValueVertex<>("2", "vertex_hello")); + for (int i = 0; i < 200; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueEdge<>("2", id, "edge_hello")); } - - @Test - public void testBothWriteMode2() { - Map conf = new HashMap<>(config); - conf.put(STATE_WRITE_ASYNC_ENABLE.getKey(), Boolean.TRUE.toString()); - conf.put(PAIMON_STORE_WAREHOUSE.getKey(), "/tmp/testBothWriteMode2"); - conf.put(PAIMON_STORE_DATABASE.getKey(), "graph"); - testWriteRead(conf); - - conf.put(STATE_WRITE_ASYNC_ENABLE.getKey(), Boolean.TRUE.toString()); - testWriteRead(conf); + // be not able to read data with chk = 2 since not committed. + Assert.assertNotNull(graphState.staticGraph().V().query("1").get()); + Assert.assertEquals(graphState.staticGraph().E().query("1").asList().size(), 100); + if (!async) { + Assert.assertNull(graphState.staticGraph().V().query("2").get()); + Assert.assertEquals(graphState.staticGraph().E().query("2").asList().size(), 0); } - + // commit chk = 2, now be able to read data + graphState.manage().operate().archive(); + Assert.assertNotNull(graphState.staticGraph().V().query("1").get()); + Assert.assertEquals(graphState.staticGraph().E().query("1").asList().size(), 100); + Assert.assertNotNull(graphState.staticGraph().V().query("2").get()); + Assert.assertEquals(graphState.staticGraph().E().query("2").asList().size(), 200); + // Read data which not exists. + Assert.assertEquals(graphState.staticGraph().E().query("3").asList().size(), 0); + Assert.assertEquals(graphState.staticGraph().E().query("").asList().size(), 0); + + // TODO. Rollback to chk = 1, then be not able to read data with chk = 2. + // graphState.manage().operate().setCheckpointId(1); + // graphState.manage().operate().recover(); + // Assert.assertNotNull(graphState.staticGraph().V().query("1").get()); + // Assert.assertEquals(graphState.staticGraph().E().query("1").asList().size(), 390); + // Assert.assertNull(graphState.staticGraph().V().query("2").get()); + // Assert.assertEquals(graphState.staticGraph().E().query("2").asList().size(), 0); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } + + @Test + public void testBothWriteMode() { + Map conf = new HashMap<>(config); + conf.put(STATE_WRITE_ASYNC_ENABLE.getKey(), Boolean.TRUE.toString()); + testWriteRead(conf); + + conf.put(STATE_WRITE_ASYNC_ENABLE.getKey(), Boolean.TRUE.toString()); + testWriteRead(conf); + } + + @Test + public void testBothWriteMode2() { + Map conf = new HashMap<>(config); + conf.put(STATE_WRITE_ASYNC_ENABLE.getKey(), Boolean.TRUE.toString()); + conf.put(PAIMON_STORE_WAREHOUSE.getKey(), "/tmp/testBothWriteMode2"); + conf.put(PAIMON_STORE_DATABASE.getKey(), "graph"); + testWriteRead(conf); + + conf.put(STATE_WRITE_ASYNC_ENABLE.getKey(), Boolean.TRUE.toString()); + testWriteRead(conf); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/PaimonKeyStateTest.java b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/PaimonKeyStateTest.java index 4d640b90a..5c28ce006 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/PaimonKeyStateTest.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/PaimonKeyStateTest.java @@ -22,6 +22,7 @@ import java.io.File; import java.util.HashMap; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -37,78 +38,75 @@ public class PaimonKeyStateTest { - static Map config = new HashMap<>(); - - @BeforeClass - public static void setUp() { - FileUtils.deleteQuietly(new File("/tmp/geaflow/chk/")); - FileUtils.deleteQuietly(new File("/tmp/PaimonKeyStateTest/")); - config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), "PaimonKeyStateTest"); - config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); - config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); - config.put(PaimonConfigKeys.PAIMON_STORE_WAREHOUSE.getKey(), "file:///tmp" - + "/PaimonKeyStateTest/"); - config.put(PaimonConfigKeys.PAIMON_STORE_DISTRIBUTED_MODE_ENABLE.getKey(), "false"); - config.put(PaimonConfigKeys.PAIMON_STORE_TABLE_AUTO_CREATE_ENABLE.getKey(), "true"); - } - - @AfterClass - public static void tearUp() { - FileUtils.deleteQuietly(new File("/tmp/PaimonKeyStateTest")); - } - - @Test - public void testKMap() { - KeyMapStateDescriptor desc = - KeyMapStateDescriptor.build("testKV", StoreType.PAIMON.name()); - desc.withKeyGroup(new KeyGroup(0, 0)) - .withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - KeyMapState mapState = StateFactory.buildKeyMapState(desc, - new Configuration(config)); - - // set chk = 1 - mapState.manage().operate().setCheckpointId(1L); - // write data - Map conf = new HashMap<>(config); - mapState.put("hello", conf); - // read nothing since not committed. - Assert.assertEquals(mapState.get("hello").size(), 0); - // commit chk = 1, now be able to read data. - mapState.manage().operate().archive(); - Assert.assertEquals(mapState.get("hello").size(), 6); - - // set chk = 2 - mapState.manage().operate().setCheckpointId(2L); - - Map conf2 = new HashMap<>(config); - conf2.put("conf2", "test"); - mapState.put("hello2", conf2); - // cannot read data with chk = 2 since chk2 not committed. - Assert.assertEquals(mapState.get("hello").size(), 6); - Assert.assertEquals(mapState.get("hello2").size(), 0); - - // commit chk = 2 - mapState.manage().operate().finish(); - mapState.manage().operate().archive(); - - // now be able to read data - Assert.assertEquals(mapState.get("hello").size(), 6); - Assert.assertEquals(mapState.get("hello2").size(), 7); - - // read data which not exists - Assert.assertEquals(mapState.get("hello3").size(), 0); - - // TODO. recover to chk = 1, then be not able to read data with chk = 2. - // mapState = StateFactory.buildKeyMapState(desc, - // new Configuration(config)); - // mapState.manage().operate().setCheckpointId(1L); - // mapState.manage().operate().recover(); - // Assert.assertEquals(mapState.get("hello").size(), 4); - // Assert.assertEquals(mapState.get("hell2").size(), 0); - - - mapState.manage().operate().close(); - mapState.manage().operate().drop(); - } - + static Map config = new HashMap<>(); + + @BeforeClass + public static void setUp() { + FileUtils.deleteQuietly(new File("/tmp/geaflow/chk/")); + FileUtils.deleteQuietly(new File("/tmp/PaimonKeyStateTest/")); + config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), "PaimonKeyStateTest"); + config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); + config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); + config.put( + PaimonConfigKeys.PAIMON_STORE_WAREHOUSE.getKey(), "file:///tmp" + "/PaimonKeyStateTest/"); + config.put(PaimonConfigKeys.PAIMON_STORE_DISTRIBUTED_MODE_ENABLE.getKey(), "false"); + config.put(PaimonConfigKeys.PAIMON_STORE_TABLE_AUTO_CREATE_ENABLE.getKey(), "true"); + } + + @AfterClass + public static void tearUp() { + FileUtils.deleteQuietly(new File("/tmp/PaimonKeyStateTest")); + } + + @Test + public void testKMap() { + KeyMapStateDescriptor desc = + KeyMapStateDescriptor.build("testKV", StoreType.PAIMON.name()); + desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + KeyMapState mapState = + StateFactory.buildKeyMapState(desc, new Configuration(config)); + + // set chk = 1 + mapState.manage().operate().setCheckpointId(1L); + // write data + Map conf = new HashMap<>(config); + mapState.put("hello", conf); + // read nothing since not committed. + Assert.assertEquals(mapState.get("hello").size(), 0); + // commit chk = 1, now be able to read data. + mapState.manage().operate().archive(); + Assert.assertEquals(mapState.get("hello").size(), 6); + + // set chk = 2 + mapState.manage().operate().setCheckpointId(2L); + + Map conf2 = new HashMap<>(config); + conf2.put("conf2", "test"); + mapState.put("hello2", conf2); + // cannot read data with chk = 2 since chk2 not committed. + Assert.assertEquals(mapState.get("hello").size(), 6); + Assert.assertEquals(mapState.get("hello2").size(), 0); + + // commit chk = 2 + mapState.manage().operate().finish(); + mapState.manage().operate().archive(); + + // now be able to read data + Assert.assertEquals(mapState.get("hello").size(), 6); + Assert.assertEquals(mapState.get("hello2").size(), 7); + + // read data which not exists + Assert.assertEquals(mapState.get("hello3").size(), 0); + + // TODO. recover to chk = 1, then be not able to read data with chk = 2. + // mapState = StateFactory.buildKeyMapState(desc, + // new Configuration(config)); + // mapState.manage().operate().setCheckpointId(1L); + // mapState.manage().operate().recover(); + // Assert.assertEquals(mapState.get("hello").size(), 4); + // Assert.assertEquals(mapState.get("hell2").size(), 0); + + mapState.manage().operate().close(); + mapState.manage().operate().drop(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/RocksDBDtPartitionGraphStateTest.java b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/RocksDBDtPartitionGraphStateTest.java index 552ee6562..75980982b 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/RocksDBDtPartitionGraphStateTest.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/RocksDBDtPartitionGraphStateTest.java @@ -19,15 +19,13 @@ package org.apache.geaflow.state; -import com.google.common.collect.Iterators; -import com.google.common.collect.Lists; -import com.google.common.collect.Maps; import java.io.File; import java.io.IOException; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -68,480 +66,637 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; + public class RocksDBDtPartitionGraphStateTest { - Map config = new HashMap<>(); - - @BeforeClass - public void setUp() { - FileUtils.deleteQuietly(new File("/tmp/geaflow/chk/")); - FileUtils.deleteQuietly(new File("/tmp/RocksDBDtPartitionGraphStateTest")); - Map persistConfig = new HashMap<>(); - config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), - "RocksDBDtPartitionGraphStateTest" + System.currentTimeMillis()); - config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); - config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); - config.put(FileConfigKeys.JSON_CONFIG.getKey(), GsonUtil.toJson(persistConfig)); - config.put(RocksdbConfigKeys.ROCKSDB_GRAPH_STORE_PARTITION_TYPE.getKey(), "dt"); - // 2025-01-01 00:00:00 - config.put(RocksdbConfigKeys.ROCKSDB_GRAPH_STORE_DT_START.getKey(), "1735660800"); - // 7 days - config.put(RocksdbConfigKeys.ROCKSDB_GRAPH_STORE_DT_CYCLE.getKey(), "604800"); - config.put(StoreConfigKeys.STORE_FILTER_CODEGEN_ENABLE.getKey(), "false"); + Map config = new HashMap<>(); + + @BeforeClass + public void setUp() { + FileUtils.deleteQuietly(new File("/tmp/geaflow/chk/")); + FileUtils.deleteQuietly(new File("/tmp/RocksDBDtPartitionGraphStateTest")); + Map persistConfig = new HashMap<>(); + config.put( + ExecutionConfigKeys.JOB_APP_NAME.getKey(), + "RocksDBDtPartitionGraphStateTest" + System.currentTimeMillis()); + config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); + config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); + config.put(FileConfigKeys.JSON_CONFIG.getKey(), GsonUtil.toJson(persistConfig)); + config.put(RocksdbConfigKeys.ROCKSDB_GRAPH_STORE_PARTITION_TYPE.getKey(), "dt"); + // 2025-01-01 00:00:00 + config.put(RocksdbConfigKeys.ROCKSDB_GRAPH_STORE_DT_START.getKey(), "1735660800"); + // 7 days + config.put(RocksdbConfigKeys.ROCKSDB_GRAPH_STORE_DT_CYCLE.getKey(), "604800"); + config.put(StoreConfigKeys.STORE_FILTER_CODEGEN_ENABLE.getKey(), "false"); + } + + private GraphState getGraphState( + IType type, String name, Map conf) { + return getGraphState(type, name, conf, new KeyGroup(0, 1), 2); + } + + private GraphState getGraphState( + IType type, String name, Map conf, KeyGroup keyGroup, int maxPara) { + GraphElementMetas.clearCache(); + GraphMetaType tag = + new GraphMetaType( + type, + ValueTimeVertex.class, + type.getTypeClass(), + ValueTimeEdge.class, + type.getTypeClass()); + + GraphStateDescriptor desc = GraphStateDescriptor.build(name, StoreType.ROCKSDB.name()); + desc.withKeyGroup(keyGroup).withKeyGroupAssigner(new DefaultKeyGroupAssigner(maxPara)); + desc.withGraphMeta(new GraphMeta(tag)); + GraphState graphState = StateFactory.buildGraphState(desc, new Configuration(conf)); + return graphState; + } + + @Test(invocationCount = 10) + public void testWrite() { + Map conf = Maps.newHashMap(config); + GraphState graphState = + getGraphState(StringType.INSTANCE, "testWrite", conf); + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 390; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueTimeEdge<>("0", id, "hello", 1735660800 + i)); } - private GraphState getGraphState(IType type, String name, - Map conf) { - return getGraphState(type, name, conf, new KeyGroup(0, 1), 2); + for (int i = 0; i < 360; i++) { + String id = Integer.toString(i); + graphState + .staticGraph() + .E() + .add(new ValueTimeEdge<>("0", id, "world", 1736265600 /*1735660800 + 604800*/ + i)); } - private GraphState getGraphState(IType type, String name, - Map conf, KeyGroup keyGroup, - int maxPara) { - GraphElementMetas.clearCache(); - GraphMetaType tag = new GraphMetaType(type, ValueTimeVertex.class, type.getTypeClass(), - ValueTimeEdge.class, type.getTypeClass()); - - GraphStateDescriptor desc = GraphStateDescriptor.build(name, StoreType.ROCKSDB.name()); - desc.withKeyGroup(keyGroup).withKeyGroupAssigner(new DefaultKeyGroupAssigner(maxPara)); - desc.withGraphMeta(new GraphMeta(tag)); - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration(conf)); - return graphState; + for (int i = 0; i < 390; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueTimeEdge<>("1", id, "val", 1735660800 + i)); } - @Test(invocationCount = 10) - public void testWrite() { - Map conf = Maps.newHashMap(config); - GraphState graphState = getGraphState(StringType.INSTANCE, - "testWrite", conf); - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 390; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueTimeEdge<>("0", id, "hello", 1735660800 + i)); - } - - for (int i = 0; i < 360; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E() - .add(new ValueTimeEdge<>("0", id, "world", 1736265600/*1735660800 + 604800*/ + i)); - } - - for (int i = 0; i < 390; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueTimeEdge<>("1", id, "val", 1735660800 + i)); - } - - graphState.manage().operate().finish(); - List> edgeList1 = graphState.staticGraph().E().query("0") + graphState.manage().operate().finish(); + List> edgeList1 = + graphState + .staticGraph() + .E() + .query("0") .by(OutEdgeFilter.getInstance().and(EdgeTsFilter.getInstance(1735660800, 1735662000))) .asList(); - Assert.assertEquals(edgeList1.size(), 390); - graphState.manage().operate().close(); - - graphState = getGraphState(StringType.INSTANCE, "testWrite", conf); - graphState.manage().operate().setCheckpointId(1); - - List> edgeList = graphState.staticGraph().E().query("0").asList(); - Assert.assertEquals(edgeList.size(), 750); - - edgeList = graphState.staticGraph().E().query("0").by(OutEdgeFilter.getInstance()).asList(); - Assert.assertEquals(edgeList.size(), 750); - edgeList = graphState.staticGraph().E().query("0").by(InEdgeFilter.getInstance()).asList(); - Assert.assertEquals(edgeList.size(), 0); - edgeList = graphState.staticGraph().E().query() + Assert.assertEquals(edgeList1.size(), 390); + graphState.manage().operate().close(); + + graphState = getGraphState(StringType.INSTANCE, "testWrite", conf); + graphState.manage().operate().setCheckpointId(1); + + List> edgeList = graphState.staticGraph().E().query("0").asList(); + Assert.assertEquals(edgeList.size(), 750); + + edgeList = graphState.staticGraph().E().query("0").by(OutEdgeFilter.getInstance()).asList(); + Assert.assertEquals(edgeList.size(), 750); + edgeList = graphState.staticGraph().E().query("0").by(InEdgeFilter.getInstance()).asList(); + Assert.assertEquals(edgeList.size(), 0); + edgeList = + graphState + .staticGraph() + .E() + .query() .by(OutEdgeFilter.getInstance().and(EdgeTsFilter.getInstance(1735660800, 1735662000))) .asList(); - Assert.assertEquals(edgeList.size(), 780); - edgeList = graphState.staticGraph().E().query() + Assert.assertEquals(edgeList.size(), 780); + edgeList = + graphState + .staticGraph() + .E() + .query() .by(OutEdgeFilter.getInstance().and(EdgeTsFilter.getInstance(1736265600, 1736265699))) .asList(); - Assert.assertEquals(edgeList.size(), 99); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - } + Assert.assertEquals(edgeList.size(), 99); - @Test - public void testRead() { - Map conf = new HashMap<>(config); + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } - GraphState graphState = getGraphState(StringType.INSTANCE, - "testRead", conf); - graphState.manage().operate().setCheckpointId(1); + @Test + public void testRead() { + Map conf = new HashMap<>(config); - for (int i = 0; i < 10000; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueTimeEdge<>("2", id, "hello", 1735660800 + i)); - graphState.staticGraph().V().add(new ValueTimeVertex<>(id, "hello", 1735660800 + i)); - } - for (int i = 0; i < 10000; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E() - .add(new ValueTimeEdge<>("2", id, "world", 1736265600/*1735660800 + 604800*/ + i)); - graphState.staticGraph().V().add(new ValueTimeVertex<>(id, "world", 1736870400 + i - /*1735660800 + 2*604800*/)); - } + GraphState graphState = + getGraphState(StringType.INSTANCE, "testRead", conf); + graphState.manage().operate().setCheckpointId(1); - List> edges1 = graphState.staticGraph().E().query("2") - .by(EdgeTsFilter.getInstance(1735660800, 1735662800)).asList(); - List> edges2 = graphState.staticGraph().E().query("2") - .by(EdgeTsFilter.getInstance(1736265600, 1736365600)).asList(); - List> edges3 = graphState.staticGraph().E().query("2") - .by(EdgeTsFilter.getInstance(1835660800, 1835760800)).asList(); - IVertex vertex1 = graphState.staticGraph().V().query("9999") - .by(VertexTsFilter.getInstance(1736870400, 1736970400)).get(); - IVertex vertex2 = graphState.staticGraph().V().query("9999") - .by(VertexTsFilter.getInstance(1735660800, 1735760800)).get(); - IVertex vertex3 = graphState.staticGraph().V().query("9999") - .by(VertexTsFilter.getInstance(1835660800, 1835760800)).get(); - - Assert.assertEquals(edges1.size(), 2000); - Assert.assertEquals(edges2.size(), 10000); - Assert.assertEquals(edges3.size(), 0); - Assert.assertEquals(edges1.get(1).getValue(), "hello"); - Assert.assertEquals(edges2.get(1).getValue(), "world"); - Assert.assertEquals(vertex1, new ValueTimeVertex<>("9999", "world", 1736880399 - /*1736870400 + 9999*/)); - Assert.assertEquals(vertex2, new ValueTimeVertex<>("9999", "hello", 1735670799 - /*1735660800 + 9999*/)); - Assert.assertNull(vertex3); - - graphState.manage().operate().finish(); - graphState.manage().operate().drop(); + for (int i = 0; i < 10000; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueTimeEdge<>("2", id, "hello", 1735660800 + i)); + graphState.staticGraph().V().add(new ValueTimeVertex<>(id, "hello", 1735660800 + i)); } - - @Test - public void testIterator() { - Map conf = new HashMap<>(config); - - GraphState graphState = getGraphState(StringType.INSTANCE, - "testIterator", conf); - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 200; i++) { - String id = Integer.toString(i); - if (i % 4 <= 1) { - if (i % 4 == 0) { - graphState.staticGraph().E() - .add(new ValueTimeEdge<>(id, id, "hello", 1735660800 + i)); - } else { - graphState.staticGraph().E().add(new ValueTimeEdge<>(id, id, "world", - 1736265600/*1735660800 + 604800*/ + i)); - } - } - - if (i % 4 >= 1) { - if (i % 4 == 1) { - graphState.staticGraph().V() - .add(new ValueTimeVertex<>(id, "tom", 1735660800 + i)); - } else { - graphState.staticGraph().V() - .add(new ValueTimeVertex<>(id, "company", 1736870400 + i - /*1735660800 + 2*604800*/)); - } - } - } - graphState.manage().operate().finish(); - - Iterator> it = graphState.staticGraph().V().query() - .by(VertexTsFilter.getInstance(1736870400, 1736970400)).iterator(); - List> vertices = Lists.newArrayList(it); - Assert.assertEquals(vertices.size(), 100); - - it = graphState.staticGraph().V().query() - .by(VertexTsFilter.getInstance(1836870400, 1836970400)).iterator(); - vertices = Lists.newArrayList(it); - Assert.assertEquals(vertices.size(), 0); - - it = graphState.staticGraph().V().query("122", "151").iterator(); - vertices = Lists.newArrayList(it); - Assert.assertEquals(vertices.size(), 2); - - it = graphState.staticGraph().V().query().iterator(); - Assert.assertEquals(Iterators.size(it), 150); - Iterator idIt = graphState.staticGraph().V().query().idIterator(); - Assert.assertEquals(Iterators.size(idIt), 150); - - Iterator> it2 = graphState.staticGraph().VE().query() - .by(VertexTsFilter.getInstance(1735660800, 1736265600) - .and(EdgeTsFilter.getInstance(1735660800, 1736265600))).iterator(); - List res = Lists.newArrayList(it2); - Assert.assertEquals(res.size(), 100); - - it2 = graphState.staticGraph().VE().query() - .by(VertexTsFilter.getInstance(1836870400, 1836970400) - .and(EdgeTsFilter.getInstance(1836870400, 1836970400))).iterator(); - res = Lists.newArrayList(it2); - Assert.assertEquals(res.size(), 0); - - it2 = graphState.staticGraph().VE().query("109", "115").iterator(); - res = Lists.newArrayList(it2); - Assert.assertEquals(res.size(), 2); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); + for (int i = 0; i < 10000; i++) { + String id = Integer.toString(i); + graphState + .staticGraph() + .E() + .add(new ValueTimeEdge<>("2", id, "world", 1736265600 /*1735660800 + 604800*/ + i)); + graphState + .staticGraph() + .V() + .add(new ValueTimeVertex<>(id, "world", 1736870400 + i /*1735660800 + 2*604800*/)); } - @Test - public void testFilter() { - Map conf = new HashMap<>(config); - GraphMetaType tag = new GraphMetaType(StringType.INSTANCE, ValueLabelTimeVertex.class, - ValueLabelTimeVertex::new, EmptyProperty.class, IDLabelTimeEdge.class, - IDLabelTimeEdge::new, EmptyProperty.class); - - GraphStateDescriptor desc = GraphStateDescriptor.build("filter", StoreType.ROCKSDB.name()); - desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - desc.withGraphMeta(new GraphMeta(tag)); - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration(conf)); - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 10000; i++) { - String id = Integer.toString(i); - IEdge edge1 = new IDLabelTimeEdge<>("2", id, "hello", 1735660800 + i); - edge1.setDirect(EdgeDirection.OUT); - graphState.staticGraph().E().add(edge1); - - IEdge edge2 = new IDLabelTimeEdge<>("2", id, "hello", 1735660800 + i); - edge2.setDirect(EdgeDirection.IN); - graphState.staticGraph().E().add(edge2); - - IEdge edge3 = new IDLabelTimeEdge<>("2", id, "world", 1735660800 + i); - edge3.setDirect(EdgeDirection.IN); - graphState.staticGraph().E().add(edge3); + List> edges1 = + graphState + .staticGraph() + .E() + .query("2") + .by(EdgeTsFilter.getInstance(1735660800, 1735662800)) + .asList(); + List> edges2 = + graphState + .staticGraph() + .E() + .query("2") + .by(EdgeTsFilter.getInstance(1736265600, 1736365600)) + .asList(); + List> edges3 = + graphState + .staticGraph() + .E() + .query("2") + .by(EdgeTsFilter.getInstance(1835660800, 1835760800)) + .asList(); + IVertex vertex1 = + graphState + .staticGraph() + .V() + .query("9999") + .by(VertexTsFilter.getInstance(1736870400, 1736970400)) + .get(); + IVertex vertex2 = + graphState + .staticGraph() + .V() + .query("9999") + .by(VertexTsFilter.getInstance(1735660800, 1735760800)) + .get(); + IVertex vertex3 = + graphState + .staticGraph() + .V() + .query("9999") + .by(VertexTsFilter.getInstance(1835660800, 1835760800)) + .get(); + + Assert.assertEquals(edges1.size(), 2000); + Assert.assertEquals(edges2.size(), 10000); + Assert.assertEquals(edges3.size(), 0); + Assert.assertEquals(edges1.get(1).getValue(), "hello"); + Assert.assertEquals(edges2.get(1).getValue(), "world"); + Assert.assertEquals( + vertex1, new ValueTimeVertex<>("9999", "world", 1736880399 /*1736870400 + 9999*/)); + Assert.assertEquals( + vertex2, new ValueTimeVertex<>("9999", "hello", 1735670799 /*1735660800 + 9999*/)); + Assert.assertNull(vertex3); + + graphState.manage().operate().finish(); + graphState.manage().operate().drop(); + } + + @Test + public void testIterator() { + Map conf = new HashMap<>(config); + + GraphState graphState = + getGraphState(StringType.INSTANCE, "testIterator", conf); + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 200; i++) { + String id = Integer.toString(i); + if (i % 4 <= 1) { + if (i % 4 == 0) { + graphState.staticGraph().E().add(new ValueTimeEdge<>(id, id, "hello", 1735660800 + i)); + } else { + graphState + .staticGraph() + .E() + .add(new ValueTimeEdge<>(id, id, "world", 1736265600 /*1735660800 + 604800*/ + i)); } - - graphState.manage().operate().finish(); - List> edges = graphState.staticGraph().E().query("2") - .by(new EdgeTsFilter(TimeRange.of(1735660800, 1735665800))).asList(); - - Assert.assertEquals(edges.size(), 15000); - long maxTime = edges.stream().mapToLong(e -> ((IDLabelTimeEdge) e).getTime()).max() - .getAsLong(); - Assert.assertEquals(maxTime, 1735665799); - - long num = edges.stream().filter(e -> e.getDirect() == EdgeDirection.OUT).count(); - Assert.assertEquals(num, 5000); - - edges = graphState.staticGraph().E().query("2").by(OutEdgeFilter.getInstance() - .and(new EdgeTsFilter(TimeRange.of(1735660800, 1735661800)))).asList(); - Assert.assertEquals(edges.size(), 1000); - - maxTime = edges.stream().mapToLong(e -> ((IDLabelTimeEdge) e).getTime()).max().getAsLong(); - Assert.assertEquals(maxTime, 1735661799); - - num = edges.stream().filter(e -> e.getDirect() == EdgeDirection.OUT).count(); - Assert.assertEquals(num, 1000); - - num = edges.stream().filter(e -> e.getDirect() == EdgeDirection.IN).count(); - Assert.assertEquals(num, 0); - - edges = graphState.staticGraph().E().query("2") - .by(new EdgeTsFilter(TimeRange.of(1735667800, 1735670800)).and( - EdgeLabelFilter.getInstance("world"))).asList(); - Assert.assertEquals(edges.size(), 3000); - long minTime = edges.stream().mapToLong(e -> ((IDLabelTimeEdge) e).getTime()).min() - .getAsLong(); - Assert.assertEquals(minTime, 1735667800); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); + } + + if (i % 4 >= 1) { + if (i % 4 == 1) { + graphState.staticGraph().V().add(new ValueTimeVertex<>(id, "tom", 1735660800 + i)); + } else { + graphState + .staticGraph() + .V() + .add(new ValueTimeVertex<>(id, "company", 1736870400 + i /*1735660800 + 2*604800*/)); + } + } } - - @Test(expectedExceptions = IllegalArgumentException.class) - public void testEdgeOrderError() { - Map conf = Maps.newHashMap(config); - conf.put(StateConfigKeys.STATE_KV_ENCODER_EDGE_ORDER.getKey(), - "SRC_ID, DESC_TIME, LABEL, DIRECTION, DST_ID"); - getGraphState(StringType.INSTANCE, "testEdgeOrderError", conf); + graphState.manage().operate().finish(); + + Iterator> it = + graphState + .staticGraph() + .V() + .query() + .by(VertexTsFilter.getInstance(1736870400, 1736970400)) + .iterator(); + List> vertices = Lists.newArrayList(it); + Assert.assertEquals(vertices.size(), 100); + + it = + graphState + .staticGraph() + .V() + .query() + .by(VertexTsFilter.getInstance(1836870400, 1836970400)) + .iterator(); + vertices = Lists.newArrayList(it); + Assert.assertEquals(vertices.size(), 0); + + it = graphState.staticGraph().V().query("122", "151").iterator(); + vertices = Lists.newArrayList(it); + Assert.assertEquals(vertices.size(), 2); + + it = graphState.staticGraph().V().query().iterator(); + Assert.assertEquals(Iterators.size(it), 150); + Iterator idIt = graphState.staticGraph().V().query().idIterator(); + Assert.assertEquals(Iterators.size(idIt), 150); + + Iterator> it2 = + graphState + .staticGraph() + .VE() + .query() + .by( + VertexTsFilter.getInstance(1735660800, 1736265600) + .and(EdgeTsFilter.getInstance(1735660800, 1736265600))) + .iterator(); + List res = Lists.newArrayList(it2); + Assert.assertEquals(res.size(), 100); + + it2 = + graphState + .staticGraph() + .VE() + .query() + .by( + VertexTsFilter.getInstance(1836870400, 1836970400) + .and(EdgeTsFilter.getInstance(1836870400, 1836970400))) + .iterator(); + res = Lists.newArrayList(it2); + Assert.assertEquals(res.size(), 0); + + it2 = graphState.staticGraph().VE().query("109", "115").iterator(); + res = Lists.newArrayList(it2); + Assert.assertEquals(res.size(), 2); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } + + @Test + public void testFilter() { + Map conf = new HashMap<>(config); + GraphMetaType tag = + new GraphMetaType( + StringType.INSTANCE, + ValueLabelTimeVertex.class, + ValueLabelTimeVertex::new, + EmptyProperty.class, + IDLabelTimeEdge.class, + IDLabelTimeEdge::new, + EmptyProperty.class); + + GraphStateDescriptor desc = GraphStateDescriptor.build("filter", StoreType.ROCKSDB.name()); + desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + desc.withGraphMeta(new GraphMeta(tag)); + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration(conf)); + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 10000; i++) { + String id = Integer.toString(i); + IEdge edge1 = new IDLabelTimeEdge<>("2", id, "hello", 1735660800 + i); + edge1.setDirect(EdgeDirection.OUT); + graphState.staticGraph().E().add(edge1); + + IEdge edge2 = new IDLabelTimeEdge<>("2", id, "hello", 1735660800 + i); + edge2.setDirect(EdgeDirection.IN); + graphState.staticGraph().E().add(edge2); + + IEdge edge3 = new IDLabelTimeEdge<>("2", id, "world", 1735660800 + i); + edge3.setDirect(EdgeDirection.IN); + graphState.staticGraph().E().add(edge3); } - @Test - public void testEdgeSort() { - Map conf = Maps.newHashMap(config); - conf.put(StateConfigKeys.STATE_KV_ENCODER_EDGE_ORDER.getKey(), - "SRC_ID, DESC_TIME, LABEL, DIRECTION, DST_ID"); - - GraphMetaType tag = new GraphMetaType(StringType.INSTANCE, ValueLabelTimeVertex.class, - ValueLabelTimeVertex::new, Object.class, ValueLabelTimeEdge.class, - ValueLabelTimeEdge::new, Object.class); - - GraphStateDescriptor desc = GraphStateDescriptor.build("testEdgeSort", - StoreType.ROCKSDB.name()); - desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - desc.withGraphMeta(new GraphMeta(tag)); - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration(conf)); - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 10000; i++) { - String id = Integer.toString(i); - IEdge edge = new ValueLabelTimeEdge<>("2", id, null, "hello", - 1735660800 + i); - graphState.staticGraph().E().add(edge.withValue("world")); - } - graphState.manage().operate().finish(); + graphState.manage().operate().finish(); + List> edges = + graphState + .staticGraph() + .E() + .query("2") + .by(new EdgeTsFilter(TimeRange.of(1735660800, 1735665800))) + .asList(); - List> list = graphState.staticGraph().E().asList(); - Assert.assertEquals(((ValueLabelTimeEdge) list.get(0)).getTime(), 1735670799); - Assert.assertEquals(((ValueLabelTimeEdge) list.get(9999)).getTime(), 1735660800); - } + Assert.assertEquals(edges.size(), 15000); + long maxTime = edges.stream().mapToLong(e -> ((IDLabelTimeEdge) e).getTime()).max().getAsLong(); + Assert.assertEquals(maxTime, 1735665799); + + long num = edges.stream().filter(e -> e.getDirect() == EdgeDirection.OUT).count(); + Assert.assertEquals(num, 5000); + + edges = + graphState + .staticGraph() + .E() + .query("2") + .by( + OutEdgeFilter.getInstance() + .and(new EdgeTsFilter(TimeRange.of(1735660800, 1735661800)))) + .asList(); + Assert.assertEquals(edges.size(), 1000); - @Test - public void testLimit() { - Map conf = new HashMap<>(config); - String name = "testLimit"; - conf.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), - "RocksDBDtPartitionGraphStateTest" + System.currentTimeMillis()); - GraphState graphState = getGraphState(StringType.INSTANCE, name, - conf); - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 10; i++) { - String src = Integer.toString(i); - for (int j = 1; j < 10; j++) { - String dst = Integer.toString(j); - graphState.staticGraph().E().add(new ValueTimeEdge<>(src, dst, "hello" + src + dst, - EdgeDirection.values()[j % 2], 1735660800 + i * j)); - } - graphState.staticGraph().V() - .add(new ValueTimeVertex<>(src, "world" + src, 1735660800 + i)); - } - graphState.manage().operate().finish(); + maxTime = edges.stream().mapToLong(e -> ((IDLabelTimeEdge) e).getTime()).max().getAsLong(); + Assert.assertEquals(maxTime, 1735661799); - List> list = graphState.staticGraph().E().query("1", "2", "3") - .limit(1L, 1L).asList(); - Assert.assertEquals(list.size(), 6); + num = edges.stream().filter(e -> e.getDirect() == EdgeDirection.OUT).count(); + Assert.assertEquals(num, 1000); - list = graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()).limit(1L, 1L) - .asList(); - Assert.assertEquals(list.size(), 10); + num = edges.stream().filter(e -> e.getDirect() == EdgeDirection.IN).count(); + Assert.assertEquals(num, 0); - list = graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()).limit(1L, 2L) + edges = + graphState + .staticGraph() + .E() + .query("2") + .by( + new EdgeTsFilter(TimeRange.of(1735667800, 1735670800)) + .and(EdgeLabelFilter.getInstance("world"))) .asList(); - Assert.assertEquals(list.size(), 20); - - List targetIds = graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()) - .select(new DstIdProjector<>()).limit(1L, 2L).asList(); - - Assert.assertEquals(targetIds.size(), 20); - graphState.manage().operate().close(); - graphState.manage().operate().drop(); + Assert.assertEquals(edges.size(), 3000); + long minTime = edges.stream().mapToLong(e -> ((IDLabelTimeEdge) e).getTime()).min().getAsLong(); + Assert.assertEquals(minTime, 1735667800); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testEdgeOrderError() { + Map conf = Maps.newHashMap(config); + conf.put( + StateConfigKeys.STATE_KV_ENCODER_EDGE_ORDER.getKey(), + "SRC_ID, DESC_TIME, LABEL, DIRECTION, DST_ID"); + getGraphState(StringType.INSTANCE, "testEdgeOrderError", conf); + } + + @Test + public void testEdgeSort() { + Map conf = Maps.newHashMap(config); + conf.put( + StateConfigKeys.STATE_KV_ENCODER_EDGE_ORDER.getKey(), + "SRC_ID, DESC_TIME, LABEL, DIRECTION, DST_ID"); + + GraphMetaType tag = + new GraphMetaType( + StringType.INSTANCE, + ValueLabelTimeVertex.class, + ValueLabelTimeVertex::new, + Object.class, + ValueLabelTimeEdge.class, + ValueLabelTimeEdge::new, + Object.class); + + GraphStateDescriptor desc = + GraphStateDescriptor.build("testEdgeSort", StoreType.ROCKSDB.name()); + desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + desc.withGraphMeta(new GraphMeta(tag)); + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration(conf)); + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 10000; i++) { + String id = Integer.toString(i); + IEdge edge = new ValueLabelTimeEdge<>("2", id, null, "hello", 1735660800 + i); + graphState.staticGraph().E().add(edge.withValue("world")); } + graphState.manage().operate().finish(); + + List> list = graphState.staticGraph().E().asList(); + Assert.assertEquals(((ValueLabelTimeEdge) list.get(0)).getTime(), 1735670799); + Assert.assertEquals(((ValueLabelTimeEdge) list.get(9999)).getTime(), 1735660800); + } + + @Test + public void testLimit() { + Map conf = new HashMap<>(config); + String name = "testLimit"; + conf.put( + ExecutionConfigKeys.JOB_APP_NAME.getKey(), + "RocksDBDtPartitionGraphStateTest" + System.currentTimeMillis()); + GraphState graphState = getGraphState(StringType.INSTANCE, name, conf); + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 10; i++) { + String src = Integer.toString(i); + for (int j = 1; j < 10; j++) { + String dst = Integer.toString(j); + graphState + .staticGraph() + .E() + .add( + new ValueTimeEdge<>( + src, + dst, + "hello" + src + dst, + EdgeDirection.values()[j % 2], + 1735660800 + i * j)); + } + graphState.staticGraph().V().add(new ValueTimeVertex<>(src, "world" + src, 1735660800 + i)); + } + graphState.manage().operate().finish(); + + List> list = + graphState.staticGraph().E().query("1", "2", "3").limit(1L, 1L).asList(); + Assert.assertEquals(list.size(), 6); + + list = + graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()).limit(1L, 1L).asList(); + Assert.assertEquals(list.size(), 10); + + list = + graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()).limit(1L, 2L).asList(); + Assert.assertEquals(list.size(), 20); + + List targetIds = + graphState + .staticGraph() + .E() + .query() + .by(InEdgeFilter.getInstance()) + .select(new DstIdProjector<>()) + .limit(1L, 2L) + .asList(); - @Test - public void testFO() throws IOException { - Map conf = new HashMap<>(config); - String name = "testFO"; - GraphState graphState = getGraphState(StringType.INSTANCE, name, - conf); - graphState.manage().operate().setCheckpointId(1); - - graphState.manage().operate().finish(); - graphState.manage().operate().archive(); + Assert.assertEquals(targetIds.size(), 20); + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } - graphState.manage().operate().drop(); - graphState = getGraphState(StringType.INSTANCE, name, conf); + @Test + public void testFO() throws IOException { + Map conf = new HashMap<>(config); + String name = "testFO"; + GraphState graphState = getGraphState(StringType.INSTANCE, name, conf); + graphState.manage().operate().setCheckpointId(1); - graphState.manage().operate().setCheckpointId(1); - graphState.manage().operate().recover(); - graphState.manage().operate().setCheckpointId(2); + graphState.manage().operate().finish(); + graphState.manage().operate().archive(); - for (int i = 0; i < 100; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueTimeEdge<>("1", id, "hello", 1735660800 + i)); - graphState.staticGraph().V().add(new ValueTimeVertex<>("1", "hello", 1735660800 + i)); - } + graphState.manage().operate().drop(); + graphState = getGraphState(StringType.INSTANCE, name, conf); - List> edges = graphState.staticGraph().E().query("1").asList(); - Assert.assertEquals(edges.size(), 100); - List> vertices = graphState.staticGraph().V().asList(); - Assert.assertEquals(vertices.size(), 1); + graphState.manage().operate().setCheckpointId(1); + graphState.manage().operate().recover(); + graphState.manage().operate().setCheckpointId(2); - graphState.manage().operate().archive(); - graphState.manage().operate().finish(); - graphState.manage().operate().close(); - graphState.manage().operate().drop(); + for (int i = 0; i < 100; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueTimeEdge<>("1", id, "hello", 1735660800 + i)); + graphState.staticGraph().V().add(new ValueTimeVertex<>("1", "hello", 1735660800 + i)); + } - graphState = getGraphState(StringType.INSTANCE, name, conf); - graphState.manage().operate().setCheckpointId(2); - graphState.manage().operate().recover(); - graphState.manage().operate().setCheckpointId(3); - - edges = graphState.staticGraph().E().asList(); - Assert.assertEquals(edges.size(), 100); - vertices = graphState.staticGraph().V().asList(); - Assert.assertEquals(vertices.size(), 1); - - for (int i = 0; i < 80; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E() - .add(new ValueTimeEdge<>("2", id, "hello", 1736265600/*1735660800 + 604800*/ + i)); - graphState.staticGraph().V() - .add(new ValueTimeVertex<>("2", "hello", 1736265600/*1735660800 + 604800*/ + i)); - } - graphState.manage().operate().finish(); - graphState.manage().operate().archive(); - edges = graphState.staticGraph().E().asList(); - Assert.assertEquals(edges.size(), 180); - edges = graphState.staticGraph().E().query() - .by(EdgeTsFilter.getInstance(1736265600, 1737265600)).asList(); - Assert.assertEquals(edges.size(), 80); - edges = graphState.staticGraph().E().query() - .by(EdgeTsFilter.getInstance(1735660800, 1735661800)).asList(); - Assert.assertEquals(edges.size(), 100); - vertices = graphState.staticGraph().V().asList(); - Assert.assertEquals(vertices.size(), 2); - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - - graphState = getGraphState(StringType.INSTANCE, name, conf); - graphState.manage().operate().setCheckpointId(3); + List> edges = graphState.staticGraph().E().query("1").asList(); + Assert.assertEquals(edges.size(), 100); + List> vertices = graphState.staticGraph().V().asList(); + Assert.assertEquals(vertices.size(), 1); + + graphState.manage().operate().archive(); + graphState.manage().operate().finish(); + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + + graphState = getGraphState(StringType.INSTANCE, name, conf); + graphState.manage().operate().setCheckpointId(2); + graphState.manage().operate().recover(); + graphState.manage().operate().setCheckpointId(3); + + edges = graphState.staticGraph().E().asList(); + Assert.assertEquals(edges.size(), 100); + vertices = graphState.staticGraph().V().asList(); + Assert.assertEquals(vertices.size(), 1); + + for (int i = 0; i < 80; i++) { + String id = Integer.toString(i); + graphState + .staticGraph() + .E() + .add(new ValueTimeEdge<>("2", id, "hello", 1736265600 /*1735660800 + 604800*/ + i)); + graphState + .staticGraph() + .V() + .add(new ValueTimeVertex<>("2", "hello", 1736265600 /*1735660800 + 604800*/ + i)); + } + graphState.manage().operate().finish(); + graphState.manage().operate().archive(); + edges = graphState.staticGraph().E().asList(); + Assert.assertEquals(edges.size(), 180); + edges = + graphState + .staticGraph() + .E() + .query() + .by(EdgeTsFilter.getInstance(1736265600, 1737265600)) + .asList(); + Assert.assertEquals(edges.size(), 80); + edges = + graphState + .staticGraph() + .E() + .query() + .by(EdgeTsFilter.getInstance(1735660800, 1735661800)) + .asList(); + Assert.assertEquals(edges.size(), 100); + vertices = graphState.staticGraph().V().asList(); + Assert.assertEquals(vertices.size(), 2); + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + + graphState = getGraphState(StringType.INSTANCE, name, conf); + graphState.manage().operate().setCheckpointId(3); + graphState.manage().operate().recover(); + graphState.staticGraph().V().add(new ValueTimeVertex<>("2", "world", 1737265600)); + Assert.assertEquals( + graphState + .staticGraph() + .V() + .query("2") + .by(VertexTsFilter.getInstance(1737265600, 1737265601)) + .get() + .getValue(), + "world"); + graphState.manage().operate().finish(); + Assert.assertEquals( + graphState + .staticGraph() + .V() + .query("2") + .by(VertexTsFilter.getInstance(1737265600, 1737265601)) + .get() + .getValue(), + "world"); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + IPersistentIO persistentIO = PersistentIOBuilder.build(new Configuration(conf)); + persistentIO.delete( + new Path( + Configuration.getString(FileConfigKeys.ROOT, conf), + Configuration.getString(ExecutionConfigKeys.JOB_APP_NAME, conf)), + true); + } + + @Test + public void testArchive() throws IOException { + Map conf = new HashMap<>(config); + IPersistentIO persistentIO = PersistentIOBuilder.build(new Configuration(conf)); + persistentIO.delete( + new Path( + Configuration.getString(FileConfigKeys.ROOT, conf), + Configuration.getString(ExecutionConfigKeys.JOB_APP_NAME, conf)), + true); + + GraphState graphState = null; + + for (int v = 1; v < 10; v++) { + graphState = getGraphState(StringType.INSTANCE, "archive", conf); + if (v > 1) { + graphState.manage().operate().setCheckpointId(v - 1); graphState.manage().operate().recover(); - graphState.staticGraph().V().add(new ValueTimeVertex<>("2", "world", 1737265600)); - Assert.assertEquals(graphState.staticGraph().V().query("2") - .by(VertexTsFilter.getInstance(1737265600, 1737265601)).get().getValue(), "world"); - graphState.manage().operate().finish(); - Assert.assertEquals(graphState.staticGraph().V().query("2") - .by(VertexTsFilter.getInstance(1737265600, 1737265601)).get().getValue(), "world"); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - IPersistentIO persistentIO = PersistentIOBuilder.build(new Configuration(conf)); - persistentIO.delete(new Path(Configuration.getString(FileConfigKeys.ROOT, conf), - Configuration.getString(ExecutionConfigKeys.JOB_APP_NAME, conf)), true); + } + graphState.manage().operate().setCheckpointId(v); + for (int i = 0; i < 10; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueTimeEdge<>(id, id, id, 1735660800 + i)); + } + graphState.manage().operate().finish(); + graphState.manage().operate().archive(); + graphState.manage().operate().close(); } - - @Test - public void testArchive() throws IOException { - Map conf = new HashMap<>(config); - IPersistentIO persistentIO = PersistentIOBuilder.build(new Configuration(conf)); - persistentIO.delete(new Path(Configuration.getString(FileConfigKeys.ROOT, conf), - Configuration.getString(ExecutionConfigKeys.JOB_APP_NAME, conf)), true); - - GraphState graphState = null; - - for (int v = 1; v < 10; v++) { - graphState = getGraphState(StringType.INSTANCE, "archive", conf); - if (v > 1) { - graphState.manage().operate().setCheckpointId(v - 1); - graphState.manage().operate().recover(); - } - graphState.manage().operate().setCheckpointId(v); - for (int i = 0; i < 10; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueTimeEdge<>(id, id, id, 1735660800 + i)); - } - graphState.manage().operate().finish(); - graphState.manage().operate().archive(); - graphState.manage().operate().close(); - } - - graphState.manage().operate().drop(); - persistentIO.delete(new Path(Configuration.getString(FileConfigKeys.ROOT, conf), - Configuration.getString(ExecutionConfigKeys.JOB_APP_NAME, conf)), true); - } + graphState.manage().operate().drop(); + persistentIO.delete( + new Path( + Configuration.getString(FileConfigKeys.ROOT, conf), + Configuration.getString(ExecutionConfigKeys.JOB_APP_NAME, conf)), + true); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/RocksDBGraphStateTest.java b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/RocksDBGraphStateTest.java index 5c3327936..22237b59c 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/RocksDBGraphStateTest.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/RocksDBGraphStateTest.java @@ -19,15 +19,13 @@ package org.apache.geaflow.state; -import com.google.common.collect.Iterators; -import com.google.common.collect.Lists; -import com.google.common.collect.Maps; import java.io.File; import java.io.IOException; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -67,480 +65,543 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; + public class RocksDBGraphStateTest { - Map config = new HashMap<>(); - - @BeforeClass - public void setUp() { - FileUtils.deleteQuietly(new File("/tmp/geaflow/chk/")); - FileUtils.deleteQuietly(new File("/tmp/RocksDBGraphStateTest")); - Map persistConfig = new HashMap<>(); - config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), - "RocksDBGraphStateTest" + System.currentTimeMillis()); - config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); - config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); - config.put(FileConfigKeys.JSON_CONFIG.getKey(), GsonUtil.toJson(persistConfig)); + Map config = new HashMap<>(); + + @BeforeClass + public void setUp() { + FileUtils.deleteQuietly(new File("/tmp/geaflow/chk/")); + FileUtils.deleteQuietly(new File("/tmp/RocksDBGraphStateTest")); + Map persistConfig = new HashMap<>(); + config.put( + ExecutionConfigKeys.JOB_APP_NAME.getKey(), + "RocksDBGraphStateTest" + System.currentTimeMillis()); + config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); + config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); + config.put(FileConfigKeys.JSON_CONFIG.getKey(), GsonUtil.toJson(persistConfig)); + } + + private GraphState getGraphState( + IType type, String name, Map conf) { + return getGraphState(type, name, conf, new KeyGroup(0, 1), 2); + } + + private GraphState getGraphState( + IType type, String name, Map conf, KeyGroup keyGroup, int maxPara) { + GraphElementMetas.clearCache(); + GraphMetaType tag = + new GraphMetaType( + type, + ValueVertex.class, + ValueVertex::new, + type.getTypeClass(), + ValueEdge.class, + ValueEdge::new, + type.getTypeClass()); + + GraphStateDescriptor desc = GraphStateDescriptor.build(name, StoreType.ROCKSDB.name()); + desc.withKeyGroup(keyGroup).withKeyGroupAssigner(new DefaultKeyGroupAssigner(maxPara)); + desc.withGraphMeta(new GraphMeta(tag)); + GraphState graphState = StateFactory.buildGraphState(desc, new Configuration(conf)); + return graphState; + } + + public void testWrite(boolean async) { + Map conf = Maps.newHashMap(config); + conf.put(StateConfigKeys.STATE_WRITE_ASYNC_ENABLE.getKey(), Boolean.toString(async)); + + GraphState graphState = + getGraphState(StringType.INSTANCE, "writeTest", conf); + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 390; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueEdge<>("0", id, "hello")); } - private GraphState getGraphState(IType type, String name, - Map conf) { - return getGraphState(type, name, conf, new KeyGroup(0, 1), 2); + for (int i = 0; i < 390; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueEdge<>("1", id, "hello")); + } + graphState.manage().operate().finish(); + graphState.manage().operate().close(); + + graphState = getGraphState(StringType.INSTANCE, "writeTest", conf); + graphState.manage().operate().setCheckpointId(1); + + List> edgeList = graphState.staticGraph().E().query("0").asList(); + Assert.assertEquals(edgeList.size(), 390); + + edgeList = graphState.staticGraph().E().query("0").by(OutEdgeFilter.getInstance()).asList(); + Assert.assertEquals(edgeList.size(), 390); + edgeList = graphState.staticGraph().E().query("0").by(InEdgeFilter.getInstance()).asList(); + Assert.assertEquals(edgeList.size(), 0); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } + + @Test(invocationCount = 10) + public void testBothWriteMode() { + testWrite(true); + testWrite(false); + } + + @Test + public void testAsyncRead() { + Map conf = new HashMap<>(config); + + GraphState graphState = + getGraphState(StringType.INSTANCE, "async", conf); + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 10000; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueEdge<>("2", id, "hello")); + graphState.staticGraph().V().add(new ValueVertex<>(id, "hello")); + } + for (int i = 0; i < 10000; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueEdge<>("3", id, "world")); + graphState.staticGraph().V().add(new ValueVertex<>(id, "world")); } - private GraphState getGraphState(IType type, String name, - Map conf, KeyGroup keyGroup, - int maxPara) { - GraphElementMetas.clearCache(); - GraphMetaType tag = new GraphMetaType(type, ValueVertex.class, ValueVertex::new, - type.getTypeClass(), ValueEdge.class, ValueEdge::new, type.getTypeClass()); - - GraphStateDescriptor desc = GraphStateDescriptor.build(name, StoreType.ROCKSDB.name()); - desc.withKeyGroup(keyGroup).withKeyGroupAssigner(new DefaultKeyGroupAssigner(maxPara)); - desc.withGraphMeta(new GraphMeta(tag)); - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration(conf)); - return graphState; + List> edges1 = graphState.staticGraph().E().query("2").asList(); + List> edges2 = graphState.staticGraph().E().query("3").asList(); + IVertex vertex = graphState.staticGraph().V().query("9999").get(); + + Assert.assertEquals(edges1.size(), 10000); + Assert.assertEquals(edges2.size(), 10000); + Assert.assertEquals(edges1.get(1).getValue(), "hello"); + Assert.assertEquals(edges2.get(1).getValue(), "world"); + Assert.assertEquals(vertex, new ValueVertex<>("9999", "world")); + + graphState.manage().operate().finish(); + graphState.manage().operate().drop(); + } + + @Test + public void testIterator() { + Map conf = new HashMap<>(config); + + GraphState graphState = + getGraphState(StringType.INSTANCE, "iterator", conf); + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 200; i++) { + String id = Integer.toString(i); + if (i % 3 <= 1) { + graphState.staticGraph().E().add(new ValueEdge<>(id, id, "hello")); + } + if (i % 3 >= 1) { + graphState.staticGraph().V().add(new ValueVertex<>(id, "world")); + } } + graphState.manage().operate().finish(); - public void testWrite(boolean async) { - Map conf = Maps.newHashMap(config); - conf.put(StateConfigKeys.STATE_WRITE_ASYNC_ENABLE.getKey(), Boolean.toString(async)); + Iterator> it = graphState.staticGraph().V().iterator(); - GraphState graphState = getGraphState(StringType.INSTANCE, - "writeTest", conf); - graphState.manage().operate().setCheckpointId(1); + List> vertices = Lists.newArrayList(it); + Assert.assertEquals(vertices.size(), 133); - for (int i = 0; i < 390; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueEdge<>("0", id, "hello")); - } + it = graphState.staticGraph().V().query("122", "151").iterator(); + vertices = Lists.newArrayList(it); + Assert.assertEquals(vertices.size(), 2); - for (int i = 0; i < 390; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueEdge<>("1", id, "hello")); - } - graphState.manage().operate().finish(); - graphState.manage().operate().close(); + it = graphState.staticGraph().V().query(new KeyGroup(1, 1)).iterator(); + Assert.assertEquals(Iterators.size(it), 73); - graphState = getGraphState(StringType.INSTANCE, "writeTest", conf); - graphState.manage().operate().setCheckpointId(1); + Iterator idIt = graphState.staticGraph().V().query(new KeyGroup(1, 1)).idIterator(); + Assert.assertEquals(Iterators.size(idIt), 73); - List> edgeList = graphState.staticGraph().E().query("0").asList(); - Assert.assertEquals(edgeList.size(), 390); + Iterator> it2 = + graphState.staticGraph().VE().query().iterator(); - edgeList = graphState.staticGraph().E().query("0").by(OutEdgeFilter.getInstance()).asList(); - Assert.assertEquals(edgeList.size(), 390); - edgeList = graphState.staticGraph().E().query("0").by(InEdgeFilter.getInstance()).asList(); - Assert.assertEquals(edgeList.size(), 0); + List res = Lists.newArrayList(it2); + Assert.assertEquals(res.size(), 200); - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - } + it2 = graphState.staticGraph().VE().query("111", "115").iterator(); + res = Lists.newArrayList(it2); + Assert.assertEquals(res.size(), 2); - @Test(invocationCount = 10) - public void testBothWriteMode() { - testWrite(true); - testWrite(false); - } + graphState.manage().operate().close(); + graphState.manage().operate().drop(); - @Test - public void testAsyncRead() { - Map conf = new HashMap<>(config); - - GraphState graphState = getGraphState(StringType.INSTANCE, "async", - conf); - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 10000; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueEdge<>("2", id, "hello")); - graphState.staticGraph().V().add(new ValueVertex<>(id, "hello")); - } - for (int i = 0; i < 10000; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueEdge<>("3", id, "world")); - graphState.staticGraph().V().add(new ValueVertex<>(id, "world")); - } - - List> edges1 = graphState.staticGraph().E().query("2").asList(); - List> edges2 = graphState.staticGraph().E().query("3").asList(); - IVertex vertex = graphState.staticGraph().V().query("9999").get(); - - Assert.assertEquals(edges1.size(), 10000); - Assert.assertEquals(edges2.size(), 10000); - Assert.assertEquals(edges1.get(1).getValue(), "hello"); - Assert.assertEquals(edges2.get(1).getValue(), "world"); - Assert.assertEquals(vertex, new ValueVertex<>("9999", "world")); - - graphState.manage().operate().finish(); - graphState.manage().operate().drop(); + GraphState graphState2 = + getGraphState(IntegerType.INSTANCE, "iterator", conf); + graphState2.manage().operate().setCheckpointId(1); + for (int i = 0; i < 200; i++) { + if (i % 3 <= 1) { + graphState2.staticGraph().E().add(new ValueEdge<>(i, i, i)); + } + if (i % 3 >= 1) { + graphState2.staticGraph().V().add(new ValueVertex<>(i, i)); + } } + graphState2.manage().operate().finish(); + Iterator> it3 = + graphState2.staticGraph().VE().query().iterator(); + + res = Lists.newArrayList(it3); + Assert.assertEquals(res.size(), 200); + + Iterator idIterator = graphState2.staticGraph().V().idIterator(); + List idList = Lists.newArrayList(idIterator); + Assert.assertEquals(idList.size(), 133); + + graphState2.manage().operate().close(); + graphState2.manage().operate().drop(); + } + + @Test + public void testOtherVE() { + Map conf = new HashMap<>(config); + + GraphMetaType tag = + new GraphMetaType( + IntegerType.INSTANCE, + IDVertex.class, + IDVertex::new, + EmptyProperty.class, + ValueLabelTimeEdge.class, + ValueLabelTimeEdge::new, + Object.class); - @Test - public void testIterator() { - Map conf = new HashMap<>(config); - - GraphState graphState = getGraphState(StringType.INSTANCE, - "iterator", conf); - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 200; i++) { - String id = Integer.toString(i); - if (i % 3 <= 1) { - graphState.staticGraph().E().add(new ValueEdge<>(id, id, "hello")); - } - if (i % 3 >= 1) { - graphState.staticGraph().V().add(new ValueVertex<>(id, "world")); - } - } - graphState.manage().operate().finish(); - - Iterator> it = graphState.staticGraph().V().iterator(); - - List> vertices = Lists.newArrayList(it); - Assert.assertEquals(vertices.size(), 133); - - it = graphState.staticGraph().V().query("122", "151").iterator(); - vertices = Lists.newArrayList(it); - Assert.assertEquals(vertices.size(), 2); - - it = graphState.staticGraph().V().query(new KeyGroup(1, 1)).iterator(); - Assert.assertEquals(Iterators.size(it), 73); - - Iterator idIt = graphState.staticGraph().V().query(new KeyGroup(1, 1)).idIterator(); - Assert.assertEquals(Iterators.size(idIt), 73); - - Iterator> it2 = graphState.staticGraph().VE().query() - .iterator(); - - List res = Lists.newArrayList(it2); - Assert.assertEquals(res.size(), 200); - - it2 = graphState.staticGraph().VE().query("111", "115").iterator(); - res = Lists.newArrayList(it2); - Assert.assertEquals(res.size(), 2); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - - GraphState graphState2 = getGraphState(IntegerType.INSTANCE, - "iterator", conf); - graphState2.manage().operate().setCheckpointId(1); - for (int i = 0; i < 200; i++) { - if (i % 3 <= 1) { - graphState2.staticGraph().E().add(new ValueEdge<>(i, i, i)); - } - if (i % 3 >= 1) { - graphState2.staticGraph().V().add(new ValueVertex<>(i, i)); - } - } - graphState2.manage().operate().finish(); - Iterator> it3 = graphState2.staticGraph().VE() - .query().iterator(); - - res = Lists.newArrayList(it3); - Assert.assertEquals(res.size(), 200); - - Iterator idIterator = graphState2.staticGraph().V().idIterator(); - List idList = Lists.newArrayList(idIterator); - Assert.assertEquals(idList.size(), 133); - - graphState2.manage().operate().close(); - graphState2.manage().operate().drop(); + GraphStateDescriptor desc = GraphStateDescriptor.build("OtherVE", StoreType.ROCKSDB.name()); + desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + desc.withGraphMeta(new GraphMeta(tag)); + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration(conf)); + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 1000; i++) { + graphState.staticGraph().V().add(new IDVertex<>(i)); + IEdge edge = new ValueLabelTimeEdge<>(i, i + 1, null, "foo", i); + edge.setDirect(EdgeDirection.IN); + graphState.staticGraph().E().add(edge.withValue("bar")); } - - @Test - public void testOtherVE() { - Map conf = new HashMap<>(config); - - GraphMetaType tag = new GraphMetaType(IntegerType.INSTANCE, IDVertex.class, IDVertex::new, - EmptyProperty.class, ValueLabelTimeEdge.class, ValueLabelTimeEdge::new, Object.class); - - GraphStateDescriptor desc = GraphStateDescriptor.build("OtherVE", StoreType.ROCKSDB.name()); - desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - desc.withGraphMeta(new GraphMeta(tag)); - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration(conf)); - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 1000; i++) { - graphState.staticGraph().V().add(new IDVertex<>(i)); - IEdge edge = new ValueLabelTimeEdge<>(i, i + 1, null, "foo", i); - edge.setDirect(EdgeDirection.IN); - graphState.staticGraph().E().add(edge.withValue("bar")); - } - graphState.manage().operate().finish(); - - List> list = graphState.staticGraph().VE().query() - .asList(); - Assert.assertEquals(list.size(), 1000); - int key = list.get(0).getKey(); - Assert.assertEquals(list.get(0).getVertex(), new IDVertex<>(key)); - IEdge edge = new ValueLabelTimeEdge<>(key, key + 1, null, "foo", key); - edge.setDirect(EdgeDirection.IN); - Assert.assertEquals(list.get(0).getEdgeIterator().next(), edge.withValue("bar")); - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - - tag = new GraphMetaType(IntegerType.INSTANCE, ValueLabelTimeVertex.class, Object.class, - IDEdge.class, EmptyProperty.class); - desc.withGraphMeta(new GraphMeta(tag)); - graphState = StateFactory.buildGraphState(desc, new Configuration(conf)); - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 1000; i++) { - graphState.staticGraph().V().add(new ValueLabelTimeVertex<>(i, "bar", "foo", i)); - IEdge idEdge = new IDEdge<>(i, i + 1); - idEdge.setDirect(EdgeDirection.IN); - graphState.staticGraph().E().add(idEdge); - } - graphState.manage().operate().finish(); - list = graphState.staticGraph().VE().query().asList(); - Assert.assertEquals(list.size(), 1000); - key = list.get(0).getKey(); - Assert.assertEquals(list.get(0).getVertex(), - new ValueLabelTimeVertex<>(key, "bar", "foo", key)); - - IEdge idEdge = new IDEdge<>(key, key + 1); - idEdge.setDirect(EdgeDirection.IN); - Assert.assertEquals(list.get(0).getEdgeIterator().next(), idEdge); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); + graphState.manage().operate().finish(); + + List> list = + graphState.staticGraph().VE().query().asList(); + Assert.assertEquals(list.size(), 1000); + int key = list.get(0).getKey(); + Assert.assertEquals(list.get(0).getVertex(), new IDVertex<>(key)); + IEdge edge = new ValueLabelTimeEdge<>(key, key + 1, null, "foo", key); + edge.setDirect(EdgeDirection.IN); + Assert.assertEquals(list.get(0).getEdgeIterator().next(), edge.withValue("bar")); + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + + tag = + new GraphMetaType( + IntegerType.INSTANCE, + ValueLabelTimeVertex.class, + Object.class, + IDEdge.class, + EmptyProperty.class); + desc.withGraphMeta(new GraphMeta(tag)); + graphState = StateFactory.buildGraphState(desc, new Configuration(conf)); + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 1000; i++) { + graphState.staticGraph().V().add(new ValueLabelTimeVertex<>(i, "bar", "foo", i)); + IEdge idEdge = new IDEdge<>(i, i + 1); + idEdge.setDirect(EdgeDirection.IN); + graphState.staticGraph().E().add(idEdge); } - - @Test - public void testFilter() { - Map conf = new HashMap<>(config); - GraphMetaType tag = new GraphMetaType(StringType.INSTANCE, IDVertex.class, IDVertex::new, - EmptyProperty.class, IDLabelTimeEdge.class, IDLabelTimeEdge::new, EmptyProperty.class); - - GraphStateDescriptor desc = GraphStateDescriptor.build("filter", StoreType.ROCKSDB.name()); - desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - desc.withGraphMeta(new GraphMeta(tag)); - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration(conf)); - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 10000; i++) { - String id = Integer.toString(i); - IEdge edge1 = new IDLabelTimeEdge<>("2", id, "hello", i); - edge1.setDirect(EdgeDirection.OUT); - graphState.staticGraph().E().add(edge1); - - IEdge edge2 = new IDLabelTimeEdge<>("2", id, "hello", i); - edge2.setDirect(EdgeDirection.IN); - graphState.staticGraph().E().add(edge2); - } - - graphState.manage().operate().finish(); - List> edges = graphState.staticGraph().E().query("2") - .by(new EdgeTsFilter(TimeRange.of(0, 5000))).asList(); - - Assert.assertEquals(edges.size(), 10000); - long maxTime = edges.stream().mapToLong(e -> ((IDLabelTimeEdge) e).getTime()).max() - .getAsLong(); - Assert.assertEquals(maxTime, 4999); - - long num = edges.stream().filter(e -> e.getDirect() == EdgeDirection.OUT).count(); - Assert.assertEquals(num, 5000); - - edges = graphState.staticGraph().E().query("2") - .by(OutEdgeFilter.getInstance().and(new EdgeTsFilter(TimeRange.of(0, 1000)))).asList(); - Assert.assertEquals(edges.size(), 1000); - - maxTime = edges.stream().mapToLong(e -> ((IDLabelTimeEdge) e).getTime()).max().getAsLong(); - Assert.assertEquals(maxTime, 999); - - num = edges.stream().filter(e -> e.getDirect() == EdgeDirection.OUT).count(); - Assert.assertEquals(num, 1000); - - num = edges.stream().filter(e -> e.getDirect() == EdgeDirection.IN).count(); - Assert.assertEquals(num, 0); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - // TODO: MAX VERSION FILTER SUPPORT + graphState.manage().operate().finish(); + list = graphState.staticGraph().VE().query().asList(); + Assert.assertEquals(list.size(), 1000); + key = list.get(0).getKey(); + Assert.assertEquals( + list.get(0).getVertex(), new ValueLabelTimeVertex<>(key, "bar", "foo", key)); + + IEdge idEdge = new IDEdge<>(key, key + 1); + idEdge.setDirect(EdgeDirection.IN); + Assert.assertEquals(list.get(0).getEdgeIterator().next(), idEdge); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } + + @Test + public void testFilter() { + Map conf = new HashMap<>(config); + GraphMetaType tag = + new GraphMetaType( + StringType.INSTANCE, + IDVertex.class, + IDVertex::new, + EmptyProperty.class, + IDLabelTimeEdge.class, + IDLabelTimeEdge::new, + EmptyProperty.class); + + GraphStateDescriptor desc = GraphStateDescriptor.build("filter", StoreType.ROCKSDB.name()); + desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + desc.withGraphMeta(new GraphMeta(tag)); + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration(conf)); + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 10000; i++) { + String id = Integer.toString(i); + IEdge edge1 = new IDLabelTimeEdge<>("2", id, "hello", i); + edge1.setDirect(EdgeDirection.OUT); + graphState.staticGraph().E().add(edge1); + + IEdge edge2 = new IDLabelTimeEdge<>("2", id, "hello", i); + edge2.setDirect(EdgeDirection.IN); + graphState.staticGraph().E().add(edge2); } - @Test(expectedExceptions = IllegalArgumentException.class) - public void testEdgeOrderError() { - Map conf = Maps.newHashMap(config); - conf.put(StateConfigKeys.STATE_KV_ENCODER_EDGE_ORDER.getKey(), - "SRC_ID, DESC_TIME, LABEL, DIRECTION, DST_ID"); - getGraphState(StringType.INSTANCE, "testEdgeOrderError", conf); - } + graphState.manage().operate().finish(); + List> edges = + graphState + .staticGraph() + .E() + .query("2") + .by(new EdgeTsFilter(TimeRange.of(0, 5000))) + .asList(); - @Test - public void testEdgeSort() { - Map conf = Maps.newHashMap(config); - conf.put(StateConfigKeys.STATE_KV_ENCODER_EDGE_ORDER.getKey(), - "SRC_ID, DESC_TIME, LABEL, DIRECTION, DST_ID"); + Assert.assertEquals(edges.size(), 10000); + long maxTime = edges.stream().mapToLong(e -> ((IDLabelTimeEdge) e).getTime()).max().getAsLong(); + Assert.assertEquals(maxTime, 4999); - GraphMetaType tag = new GraphMetaType(StringType.INSTANCE, ValueVertex.class, - ValueVertex::new, Object.class, ValueLabelTimeEdge.class, ValueLabelTimeEdge::new, - Object.class); + long num = edges.stream().filter(e -> e.getDirect() == EdgeDirection.OUT).count(); + Assert.assertEquals(num, 5000); - GraphStateDescriptor desc = GraphStateDescriptor.build("testEdgeSort", - StoreType.ROCKSDB.name()); - desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - desc.withGraphMeta(new GraphMeta(tag)); - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration(conf)); - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 10000; i++) { - String id = Integer.toString(i); - IEdge edge = new ValueLabelTimeEdge<>("2", id, null, "hello", i); - graphState.staticGraph().E().add(edge.withValue("world")); - } - graphState.manage().operate().finish(); - - List> list = graphState.staticGraph().E().asList(); - Assert.assertEquals(((ValueLabelTimeEdge) list.get(0)).getTime(), 9999); - Assert.assertEquals(((ValueLabelTimeEdge) list.get(9999)).getTime(), 0); - } - - @Test - public void testLimit() { - Map conf = new HashMap<>(config); - String name = "testLimit"; - conf.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), - "RocksDBGraphStateTest" + System.currentTimeMillis()); - GraphState graphState = getGraphState(StringType.INSTANCE, name, - conf); - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 10; i++) { - String src = Integer.toString(i); - for (int j = 1; j < 10; j++) { - String dst = Integer.toString(j); - graphState.staticGraph().E().add( - new ValueEdge<>(src, dst, "hello" + src + dst, EdgeDirection.values()[j % 2])); - } - graphState.staticGraph().V().add(new ValueVertex<>(src, "world" + src)); - } - graphState.manage().operate().finish(); - - List> list = graphState.staticGraph().E().query("1", "2", "3") - .limit(1L, 1L).asList(); - Assert.assertEquals(list.size(), 6); - - list = graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()).limit(1L, 1L) + edges = + graphState + .staticGraph() + .E() + .query("2") + .by(OutEdgeFilter.getInstance().and(new EdgeTsFilter(TimeRange.of(0, 1000)))) .asList(); - Assert.assertEquals(list.size(), 10); + Assert.assertEquals(edges.size(), 1000); + + maxTime = edges.stream().mapToLong(e -> ((IDLabelTimeEdge) e).getTime()).max().getAsLong(); + Assert.assertEquals(maxTime, 999); + + num = edges.stream().filter(e -> e.getDirect() == EdgeDirection.OUT).count(); + Assert.assertEquals(num, 1000); + + num = edges.stream().filter(e -> e.getDirect() == EdgeDirection.IN).count(); + Assert.assertEquals(num, 0); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + // TODO: MAX VERSION FILTER SUPPORT + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testEdgeOrderError() { + Map conf = Maps.newHashMap(config); + conf.put( + StateConfigKeys.STATE_KV_ENCODER_EDGE_ORDER.getKey(), + "SRC_ID, DESC_TIME, LABEL, DIRECTION, DST_ID"); + getGraphState(StringType.INSTANCE, "testEdgeOrderError", conf); + } + + @Test + public void testEdgeSort() { + Map conf = Maps.newHashMap(config); + conf.put( + StateConfigKeys.STATE_KV_ENCODER_EDGE_ORDER.getKey(), + "SRC_ID, DESC_TIME, LABEL, DIRECTION, DST_ID"); + + GraphMetaType tag = + new GraphMetaType( + StringType.INSTANCE, + ValueVertex.class, + ValueVertex::new, + Object.class, + ValueLabelTimeEdge.class, + ValueLabelTimeEdge::new, + Object.class); - list = graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()).limit(1L, 2L) + GraphStateDescriptor desc = + GraphStateDescriptor.build("testEdgeSort", StoreType.ROCKSDB.name()); + desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + desc.withGraphMeta(new GraphMeta(tag)); + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration(conf)); + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 10000; i++) { + String id = Integer.toString(i); + IEdge edge = new ValueLabelTimeEdge<>("2", id, null, "hello", i); + graphState.staticGraph().E().add(edge.withValue("world")); + } + graphState.manage().operate().finish(); + + List> list = graphState.staticGraph().E().asList(); + Assert.assertEquals(((ValueLabelTimeEdge) list.get(0)).getTime(), 9999); + Assert.assertEquals(((ValueLabelTimeEdge) list.get(9999)).getTime(), 0); + } + + @Test + public void testLimit() { + Map conf = new HashMap<>(config); + String name = "testLimit"; + conf.put( + ExecutionConfigKeys.JOB_APP_NAME.getKey(), + "RocksDBGraphStateTest" + System.currentTimeMillis()); + GraphState graphState = getGraphState(StringType.INSTANCE, name, conf); + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 10; i++) { + String src = Integer.toString(i); + for (int j = 1; j < 10; j++) { + String dst = Integer.toString(j); + graphState + .staticGraph() + .E() + .add(new ValueEdge<>(src, dst, "hello" + src + dst, EdgeDirection.values()[j % 2])); + } + graphState.staticGraph().V().add(new ValueVertex<>(src, "world" + src)); + } + graphState.manage().operate().finish(); + + List> list = + graphState.staticGraph().E().query("1", "2", "3").limit(1L, 1L).asList(); + Assert.assertEquals(list.size(), 6); + + list = + graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()).limit(1L, 1L).asList(); + Assert.assertEquals(list.size(), 10); + + list = + graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()).limit(1L, 2L).asList(); + Assert.assertEquals(list.size(), 20); + + List targetIds = + graphState + .staticGraph() + .E() + .query() + .by(InEdgeFilter.getInstance()) + .select(new DstIdProjector<>()) + .limit(1L, 2L) .asList(); - Assert.assertEquals(list.size(), 20); - List targetIds = graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()) - .select(new DstIdProjector<>()).limit(1L, 2L).asList(); - - Assert.assertEquals(targetIds.size(), 20); - graphState.manage().operate().close(); - graphState.manage().operate().drop(); + Assert.assertEquals(targetIds.size(), 20); + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } + + @Test + public void testFO() throws IOException { + Map conf = new HashMap<>(config); + String name = "fo"; + conf.put( + ExecutionConfigKeys.JOB_APP_NAME.getKey(), + "RocksDBGraphStateTest" + System.currentTimeMillis()); + GraphState graphState = getGraphState(StringType.INSTANCE, name, conf); + graphState.manage().operate().setCheckpointId(1); + + graphState.manage().operate().finish(); + graphState.manage().operate().archive(); + + graphState.manage().operate().drop(); + graphState = getGraphState(StringType.INSTANCE, name, conf); + + graphState.manage().operate().setCheckpointId(1); + graphState.manage().operate().recover(); + graphState.manage().operate().setCheckpointId(2); + + for (int i = 0; i < 100; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueEdge<>("1", id, "hello")); + graphState.staticGraph().V().add(new ValueVertex<>("1", "hello")); } - - @Test - public void testFO() throws IOException { - Map conf = new HashMap<>(config); - String name = "fo"; - conf.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), - "RocksDBGraphStateTest" + System.currentTimeMillis()); - GraphState graphState = getGraphState(StringType.INSTANCE, name, - conf); - graphState.manage().operate().setCheckpointId(1); - - graphState.manage().operate().finish(); - graphState.manage().operate().archive(); - - graphState.manage().operate().drop(); - graphState = getGraphState(StringType.INSTANCE, name, conf); - - graphState.manage().operate().setCheckpointId(1); - graphState.manage().operate().recover(); - graphState.manage().operate().setCheckpointId(2); - - for (int i = 0; i < 100; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueEdge<>("1", id, "hello")); - graphState.staticGraph().V().add(new ValueVertex<>("1", "hello")); - } - graphState.manage().operate().finish(); - graphState.manage().operate().archive(); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - - graphState = getGraphState(StringType.INSTANCE, name, conf); - graphState.manage().operate().setCheckpointId(2); - graphState.manage().operate().recover(); - graphState.manage().operate().setCheckpointId(3); - - List> edges = graphState.staticGraph().E().asList(); - Assert.assertEquals(edges.size(), 100); - List> vertices = graphState.staticGraph().V().asList(); - Assert.assertEquals(vertices.size(), 1); - - for (int i = 0; i < 100; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueEdge<>("2", id, "hello")); - graphState.staticGraph().V().add(new ValueVertex<>("2", "hello")); - } - graphState.manage().operate().finish(); - graphState.manage().operate().archive(); - edges = graphState.staticGraph().E().asList(); - Assert.assertEquals(edges.size(), 200); - vertices = graphState.staticGraph().V().asList(); - Assert.assertEquals(vertices.size(), 2); - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - - graphState = getGraphState(StringType.INSTANCE, name, conf); - graphState.manage().operate().setCheckpointId(3); + graphState.manage().operate().finish(); + graphState.manage().operate().archive(); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + + graphState = getGraphState(StringType.INSTANCE, name, conf); + graphState.manage().operate().setCheckpointId(2); + graphState.manage().operate().recover(); + graphState.manage().operate().setCheckpointId(3); + + List> edges = graphState.staticGraph().E().asList(); + Assert.assertEquals(edges.size(), 100); + List> vertices = graphState.staticGraph().V().asList(); + Assert.assertEquals(vertices.size(), 1); + + for (int i = 0; i < 100; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueEdge<>("2", id, "hello")); + graphState.staticGraph().V().add(new ValueVertex<>("2", "hello")); + } + graphState.manage().operate().finish(); + graphState.manage().operate().archive(); + edges = graphState.staticGraph().E().asList(); + Assert.assertEquals(edges.size(), 200); + vertices = graphState.staticGraph().V().asList(); + Assert.assertEquals(vertices.size(), 2); + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + + graphState = getGraphState(StringType.INSTANCE, name, conf); + graphState.manage().operate().setCheckpointId(3); + graphState.manage().operate().recover(); + graphState.staticGraph().V().add(new ValueVertex<>("2", "world")); + Assert.assertEquals(graphState.staticGraph().V().query("2").get().getValue(), "world"); + graphState.manage().operate().finish(); + Assert.assertEquals(graphState.staticGraph().V().query("2").get().getValue(), "world"); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + IPersistentIO persistentIO = PersistentIOBuilder.build(new Configuration(conf)); + persistentIO.delete( + new Path( + Configuration.getString(FileConfigKeys.ROOT, conf), + Configuration.getString(ExecutionConfigKeys.JOB_APP_NAME, conf)), + true); + } + + @Test + public void testArchive() throws IOException { + Map conf = new HashMap<>(config); + IPersistentIO persistentIO = PersistentIOBuilder.build(new Configuration(conf)); + persistentIO.delete( + new Path( + Configuration.getString(FileConfigKeys.ROOT, conf), + Configuration.getString(ExecutionConfigKeys.JOB_APP_NAME, conf)), + true); + + GraphState graphState = null; + + for (int v = 1; v < 10; v++) { + graphState = getGraphState(StringType.INSTANCE, "archive", conf); + if (v > 1) { + graphState.manage().operate().setCheckpointId(v - 1); graphState.manage().operate().recover(); - graphState.staticGraph().V().add(new ValueVertex<>("2", "world")); - Assert.assertEquals(graphState.staticGraph().V().query("2").get().getValue(), "world"); - graphState.manage().operate().finish(); - Assert.assertEquals(graphState.staticGraph().V().query("2").get().getValue(), "world"); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - IPersistentIO persistentIO = PersistentIOBuilder.build(new Configuration(conf)); - persistentIO.delete(new Path(Configuration.getString(FileConfigKeys.ROOT, conf), - Configuration.getString(ExecutionConfigKeys.JOB_APP_NAME, conf)), true); + } + graphState.manage().operate().setCheckpointId(v); + for (int i = 0; i < 10; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueEdge<>(id, id, id)); + } + graphState.manage().operate().finish(); + graphState.manage().operate().archive(); + graphState.manage().operate().close(); } - - @Test - public void testArchive() throws IOException { - Map conf = new HashMap<>(config); - IPersistentIO persistentIO = PersistentIOBuilder.build(new Configuration(conf)); - persistentIO.delete(new Path(Configuration.getString(FileConfigKeys.ROOT, conf), - Configuration.getString(ExecutionConfigKeys.JOB_APP_NAME, conf)), true); - - GraphState graphState = null; - - for (int v = 1; v < 10; v++) { - graphState = getGraphState(StringType.INSTANCE, "archive", conf); - if (v > 1) { - graphState.manage().operate().setCheckpointId(v - 1); - graphState.manage().operate().recover(); - } - graphState.manage().operate().setCheckpointId(v); - for (int i = 0; i < 10; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueEdge<>(id, id, id)); - } - graphState.manage().operate().finish(); - graphState.manage().operate().archive(); - graphState.manage().operate().close(); - } - - graphState.manage().operate().drop(); - persistentIO.delete(new Path(Configuration.getString(FileConfigKeys.ROOT, conf), - Configuration.getString(ExecutionConfigKeys.JOB_APP_NAME, conf)), true); - } + graphState.manage().operate().drop(); + persistentIO.delete( + new Path( + Configuration.getString(FileConfigKeys.ROOT, conf), + Configuration.getString(ExecutionConfigKeys.JOB_APP_NAME, conf)), + true); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/RocksDBLabelPartitionGraphStateTest.java b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/RocksDBLabelPartitionGraphStateTest.java index 37e437307..0334a5440 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/RocksDBLabelPartitionGraphStateTest.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/RocksDBLabelPartitionGraphStateTest.java @@ -19,15 +19,13 @@ package org.apache.geaflow.state; -import com.google.common.collect.Iterators; -import com.google.common.collect.Lists; -import com.google.common.collect.Maps; import java.io.File; import java.io.IOException; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -69,525 +67,624 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; + public class RocksDBLabelPartitionGraphStateTest { - Map config = new HashMap<>(); - - @BeforeClass - public void setUp() { - FileUtils.deleteQuietly(new File("/tmp/geaflow/chk/")); - FileUtils.deleteQuietly(new File("/tmp/RocksDBLabelPartitionGraphStateTest")); - Map persistConfig = new HashMap<>(); - config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), - "RocksDBLabelPartitionGraphStateTest" + System.currentTimeMillis()); - config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); - config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); - config.put(FileConfigKeys.JSON_CONFIG.getKey(), GsonUtil.toJson(persistConfig)); - config.put(RocksdbConfigKeys.ROCKSDB_GRAPH_STORE_PARTITION_TYPE.getKey(), "label"); - config.put(StoreConfigKeys.STORE_FILTER_CODEGEN_ENABLE.getKey(), "false"); + Map config = new HashMap<>(); + + @BeforeClass + public void setUp() { + FileUtils.deleteQuietly(new File("/tmp/geaflow/chk/")); + FileUtils.deleteQuietly(new File("/tmp/RocksDBLabelPartitionGraphStateTest")); + Map persistConfig = new HashMap<>(); + config.put( + ExecutionConfigKeys.JOB_APP_NAME.getKey(), + "RocksDBLabelPartitionGraphStateTest" + System.currentTimeMillis()); + config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); + config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); + config.put(FileConfigKeys.JSON_CONFIG.getKey(), GsonUtil.toJson(persistConfig)); + config.put(RocksdbConfigKeys.ROCKSDB_GRAPH_STORE_PARTITION_TYPE.getKey(), "label"); + config.put(StoreConfigKeys.STORE_FILTER_CODEGEN_ENABLE.getKey(), "false"); + } + + private GraphState getGraphState( + IType type, String name, Map conf) { + return getGraphState(type, name, conf, new KeyGroup(0, 1), 2); + } + + private GraphState getGraphState( + IType type, String name, Map conf, KeyGroup keyGroup, int maxPara) { + GraphElementMetas.clearCache(); + GraphMetaType tag = + new GraphMetaType( + type, + ValueLabelVertex.class, + type.getTypeClass(), + ValueLabelEdge.class, + type.getTypeClass()); + + GraphStateDescriptor desc = GraphStateDescriptor.build(name, StoreType.ROCKSDB.name()); + desc.withKeyGroup(keyGroup).withKeyGroupAssigner(new DefaultKeyGroupAssigner(maxPara)); + desc.withGraphMeta(new GraphMeta(tag)); + GraphState graphState = StateFactory.buildGraphState(desc, new Configuration(conf)); + return graphState; + } + + @Test(invocationCount = 10) + public void testWrite() { + Map conf = Maps.newHashMap(config); + GraphState graphState = + getGraphState(StringType.INSTANCE, "testWrite", conf); + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 390; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueLabelEdge<>("0", id, "hello", "person")); } - private GraphState getGraphState(IType type, String name, - Map conf) { - return getGraphState(type, name, conf, new KeyGroup(0, 1), 2); + for (int i = 0; i < 360; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueLabelEdge<>("0", id, "world", "trade")); } - private GraphState getGraphState(IType type, String name, - Map conf, KeyGroup keyGroup, - int maxPara) { - GraphElementMetas.clearCache(); - GraphMetaType tag = new GraphMetaType(type, ValueLabelVertex.class, type.getTypeClass(), - ValueLabelEdge.class, type.getTypeClass()); - - GraphStateDescriptor desc = GraphStateDescriptor.build(name, StoreType.ROCKSDB.name()); - desc.withKeyGroup(keyGroup).withKeyGroupAssigner(new DefaultKeyGroupAssigner(maxPara)); - desc.withGraphMeta(new GraphMeta(tag)); - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration(conf)); - return graphState; + for (int i = 0; i < 390; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueLabelEdge<>("1", id, "val", "person")); } - @Test(invocationCount = 10) - public void testWrite() { - Map conf = Maps.newHashMap(config); - GraphState graphState = getGraphState(StringType.INSTANCE, - "testWrite", conf); - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 390; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueLabelEdge<>("0", id, "hello", "person")); - } - - for (int i = 0; i < 360; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueLabelEdge<>("0", id, "world", "trade")); - } - - for (int i = 0; i < 390; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueLabelEdge<>("1", id, "val", "person")); - } - - graphState.manage().operate().finish(); - List> edgeList1 = graphState.staticGraph().E().query("0") - .by(OutEdgeFilter.getInstance().and(EdgeLabelFilter.getInstance("person"))).asList(); - Assert.assertEquals(edgeList1.size(), 390); - graphState.manage().operate().close(); + graphState.manage().operate().finish(); + List> edgeList1 = + graphState + .staticGraph() + .E() + .query("0") + .by(OutEdgeFilter.getInstance().and(EdgeLabelFilter.getInstance("person"))) + .asList(); + Assert.assertEquals(edgeList1.size(), 390); + graphState.manage().operate().close(); + + graphState = getGraphState(StringType.INSTANCE, "testWrite", conf); + graphState.manage().operate().setCheckpointId(1); + + List> edgeList = graphState.staticGraph().E().query("0").asList(); + Assert.assertEquals(edgeList.size(), 750); + + edgeList = graphState.staticGraph().E().query("0").by(OutEdgeFilter.getInstance()).asList(); + Assert.assertEquals(edgeList.size(), 750); + edgeList = graphState.staticGraph().E().query("0").by(InEdgeFilter.getInstance()).asList(); + Assert.assertEquals(edgeList.size(), 0); + edgeList = + graphState + .staticGraph() + .E() + .query("0") + .by(OutEdgeFilter.getInstance().and(EdgeLabelFilter.getInstance("person"))) + .asList(); + Assert.assertEquals(edgeList.size(), 390); - graphState = getGraphState(StringType.INSTANCE, "testWrite", conf); - graphState.manage().operate().setCheckpointId(1); + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } - List> edgeList = graphState.staticGraph().E().query("0").asList(); - Assert.assertEquals(edgeList.size(), 750); + @Test + public void testRead() { + Map conf = new HashMap<>(config); - edgeList = graphState.staticGraph().E().query("0").by(OutEdgeFilter.getInstance()).asList(); - Assert.assertEquals(edgeList.size(), 750); - edgeList = graphState.staticGraph().E().query("0").by(InEdgeFilter.getInstance()).asList(); - Assert.assertEquals(edgeList.size(), 0); - edgeList = graphState.staticGraph().E().query("0") - .by(OutEdgeFilter.getInstance().and(EdgeLabelFilter.getInstance("person"))).asList(); - Assert.assertEquals(edgeList.size(), 390); + GraphState graphState = + getGraphState(StringType.INSTANCE, "testRead", conf); + graphState.manage().operate().setCheckpointId(1); - graphState.manage().operate().close(); - graphState.manage().operate().drop(); + for (int i = 0; i < 10000; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueLabelEdge<>("2", id, "hello", "person")); + graphState.staticGraph().V().add(new ValueLabelVertex<>(id, "hello", "person")); } - - @Test - public void testRead() { - Map conf = new HashMap<>(config); - - GraphState graphState = getGraphState(StringType.INSTANCE, - "testRead", conf); - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 10000; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueLabelEdge<>("2", id, "hello", "person")); - graphState.staticGraph().V().add(new ValueLabelVertex<>(id, "hello", "person")); - } - for (int i = 0; i < 10000; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueLabelEdge<>("2", id, "world", "trade")); - graphState.staticGraph().V().add(new ValueLabelVertex<>(id, "world", "relation")); - } - - List> edges1 = graphState.staticGraph().E().query("2") - .by(EdgeLabelFilter.getInstance("person")).asList(); - List> edges2 = graphState.staticGraph().E().query("2") - .by(EdgeLabelFilter.getInstance("trade")).asList(); - List> edges3 = graphState.staticGraph().E().query("2") - .by(EdgeLabelFilter.getInstance("illegal")).asList(); - IVertex vertex1 = graphState.staticGraph().V().query("9999") - .by(VertexLabelFilter.getInstance("relation")).get(); - IVertex vertex2 = graphState.staticGraph().V().query("9999") - .by(VertexLabelFilter.getInstance("person")).get(); - IVertex vertex3 = graphState.staticGraph().V().query("9999") - .by(VertexLabelFilter.getInstance("illegal")).get(); - - Assert.assertEquals(edges1.size(), 10000); - Assert.assertEquals(edges2.size(), 10000); - Assert.assertEquals(edges3.size(), 0); - Assert.assertEquals(edges1.get(1).getValue(), "hello"); - Assert.assertEquals(edges2.get(1).getValue(), "world"); - Assert.assertEquals(vertex1, new ValueLabelVertex<>("9999", "world", "relation")); - Assert.assertEquals(vertex2, new ValueLabelVertex<>("9999", "hello", "person")); - Assert.assertNull(vertex3); - - graphState.manage().operate().finish(); - graphState.manage().operate().drop(); + for (int i = 0; i < 10000; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueLabelEdge<>("2", id, "world", "trade")); + graphState.staticGraph().V().add(new ValueLabelVertex<>(id, "world", "relation")); } - @Test - public void testIterator() { - Map conf = new HashMap<>(config); - - GraphState graphState = getGraphState(StringType.INSTANCE, - "testIterator", conf); - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 200; i++) { - String id = Integer.toString(i); - if (i % 4 <= 1) { - if (i % 4 == 0) { - graphState.staticGraph().E() - .add(new ValueLabelEdge<>(id, id, "hello", "person")); - } else { - graphState.staticGraph().E() - .add(new ValueLabelEdge<>(id, id, "world", "trade")); - } - } - - if (i % 4 >= 1) { - if (i % 4 == 1) { - graphState.staticGraph().V().add(new ValueLabelVertex<>(id, "tom", "person")); - } else { - graphState.staticGraph().V() - .add(new ValueLabelVertex<>(id, "company", "entity")); - } - } + List> edges1 = + graphState.staticGraph().E().query("2").by(EdgeLabelFilter.getInstance("person")).asList(); + List> edges2 = + graphState.staticGraph().E().query("2").by(EdgeLabelFilter.getInstance("trade")).asList(); + List> edges3 = + graphState.staticGraph().E().query("2").by(EdgeLabelFilter.getInstance("illegal")).asList(); + IVertex vertex1 = + graphState + .staticGraph() + .V() + .query("9999") + .by(VertexLabelFilter.getInstance("relation")) + .get(); + IVertex vertex2 = + graphState + .staticGraph() + .V() + .query("9999") + .by(VertexLabelFilter.getInstance("person")) + .get(); + IVertex vertex3 = + graphState + .staticGraph() + .V() + .query("9999") + .by(VertexLabelFilter.getInstance("illegal")) + .get(); + + Assert.assertEquals(edges1.size(), 10000); + Assert.assertEquals(edges2.size(), 10000); + Assert.assertEquals(edges3.size(), 0); + Assert.assertEquals(edges1.get(1).getValue(), "hello"); + Assert.assertEquals(edges2.get(1).getValue(), "world"); + Assert.assertEquals(vertex1, new ValueLabelVertex<>("9999", "world", "relation")); + Assert.assertEquals(vertex2, new ValueLabelVertex<>("9999", "hello", "person")); + Assert.assertNull(vertex3); + + graphState.manage().operate().finish(); + graphState.manage().operate().drop(); + } + + @Test + public void testIterator() { + Map conf = new HashMap<>(config); + + GraphState graphState = + getGraphState(StringType.INSTANCE, "testIterator", conf); + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 200; i++) { + String id = Integer.toString(i); + if (i % 4 <= 1) { + if (i % 4 == 0) { + graphState.staticGraph().E().add(new ValueLabelEdge<>(id, id, "hello", "person")); + } else { + graphState.staticGraph().E().add(new ValueLabelEdge<>(id, id, "world", "trade")); } - graphState.manage().operate().finish(); + } - Iterator> it = graphState.staticGraph().V().query() - .by(VertexLabelFilter.getInstance("entity")).iterator(); - List> vertices = Lists.newArrayList(it); - Assert.assertEquals(vertices.size(), 100); - - it = graphState.staticGraph().V().query().by(VertexLabelFilter.getInstance("illegal")) + if (i % 4 >= 1) { + if (i % 4 == 1) { + graphState.staticGraph().V().add(new ValueLabelVertex<>(id, "tom", "person")); + } else { + graphState.staticGraph().V().add(new ValueLabelVertex<>(id, "company", "entity")); + } + } + } + graphState.manage().operate().finish(); + + Iterator> it = + graphState.staticGraph().V().query().by(VertexLabelFilter.getInstance("entity")).iterator(); + List> vertices = Lists.newArrayList(it); + Assert.assertEquals(vertices.size(), 100); + + it = + graphState + .staticGraph() + .V() + .query() + .by(VertexLabelFilter.getInstance("illegal")) .iterator(); - vertices = Lists.newArrayList(it); - Assert.assertEquals(vertices.size(), 0); - - it = graphState.staticGraph().V().query("122", "151").iterator(); - vertices = Lists.newArrayList(it); - Assert.assertEquals(vertices.size(), 2); - - it = graphState.staticGraph().V().query().iterator(); - Assert.assertEquals(Iterators.size(it), 150); - Iterator idIt = graphState.staticGraph().V().query().idIterator(); - Assert.assertEquals(Iterators.size(idIt), 150); - - Iterator> it2 = graphState.staticGraph().VE().query() - .by(VertexLabelFilter.getInstance("person")).iterator(); - List res = Lists.newArrayList(it2); - Assert.assertEquals(res.size(), 100); - - it2 = graphState.staticGraph().VE().query().by(VertexLabelFilter.getInstance("entity")) + vertices = Lists.newArrayList(it); + Assert.assertEquals(vertices.size(), 0); + + it = graphState.staticGraph().V().query("122", "151").iterator(); + vertices = Lists.newArrayList(it); + Assert.assertEquals(vertices.size(), 2); + + it = graphState.staticGraph().V().query().iterator(); + Assert.assertEquals(Iterators.size(it), 150); + Iterator idIt = graphState.staticGraph().V().query().idIterator(); + Assert.assertEquals(Iterators.size(idIt), 150); + + Iterator> it2 = + graphState + .staticGraph() + .VE() + .query() + .by(VertexLabelFilter.getInstance("person")) .iterator(); - res = Lists.newArrayList(it2); - Assert.assertEquals(res.size(), 200); - - it2 = graphState.staticGraph().VE().query("109", "115").iterator(); - res = Lists.newArrayList(it2); - Assert.assertEquals(res.size(), 2); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - } - - @Test - public void testOtherVE() { - Map conf = new HashMap<>(config); - - GraphMetaType tag = new GraphMetaType(IntegerType.INSTANCE, ValueLabelVertex.class, - StringType.class, ValueLabelTimeEdge.class, StringType.class); - GraphStateDescriptor desc = GraphStateDescriptor.build("OtherVE", StoreType.ROCKSDB.name()); - desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - desc.withGraphMeta(new GraphMeta(tag)); - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration(conf)); - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 1000; i++) { - if (i % 3 == 0) { - // 334 - graphState.staticGraph().V().add(new ValueLabelVertex<>(i, "hello", "default")); - IEdge edge = new ValueLabelTimeEdge<>(i, i + 1, "hello", "foo", i); - edge.setDirect(EdgeDirection.IN); - graphState.staticGraph().E().add(edge.withValue("bar")); - } else { - // 666 - graphState.staticGraph().V().add(new ValueLabelVertex<>(i, "hello", "default")); - IEdge edge = new ValueLabelTimeEdge<>(i, i + 1, "hello", "person", i); - edge.setDirect(EdgeDirection.OUT); - graphState.staticGraph().E().add(edge.withValue("male")); - } - } - graphState.manage().operate().finish(); - - List> list = graphState.staticGraph().VE().query() - .by(EdgeLabelFilter.getInstance("foo")).asList(); - Assert.assertEquals(list.size(), 1000); - int key = list.get(0).getKey(); - Assert.assertEquals(list.get(0).getVertex(), - new ValueLabelVertex<>(key, "hello", "default")); - IEdge edge = new ValueLabelTimeEdge<>(key, key + 1, "hello", "foo", key); - edge.setDirect(EdgeDirection.IN); - Assert.assertEquals(list.get(0).getEdgeIterator().next(), edge.withValue("bar")); - Assert.assertFalse(list.get(1).getEdgeIterator().hasNext()); - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - - tag = new GraphMetaType(IntegerType.INSTANCE, ValueLabelTimeVertex.class, Object.class, - ValueLabelEdge.class, Object.class); - desc.withGraphMeta(new GraphMeta(tag)); - graphState = StateFactory.buildGraphState(desc, new Configuration(conf)); - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 1000; i++) { - graphState.staticGraph().V().add(new ValueLabelTimeVertex<>(i, "bar", "foo", i)); - IEdge idEdge = new ValueLabelEdge<>(i, i + 1, "3.0", "default"); - idEdge.setDirect(EdgeDirection.IN); - graphState.staticGraph().E().add(idEdge); - } - graphState.manage().operate().finish(); - list = graphState.staticGraph().VE().query().asList(); - Assert.assertEquals(list.size(), 1000); - key = list.get(0).getKey(); - Assert.assertEquals(list.get(0).getVertex(), - new ValueLabelTimeVertex<>(key, "bar", "foo", key)); - - edge = new ValueLabelEdge<>(key, key + 1, "3.0", "default"); + List res = Lists.newArrayList(it2); + Assert.assertEquals(res.size(), 100); + + it2 = + graphState + .staticGraph() + .VE() + .query() + .by(VertexLabelFilter.getInstance("entity")) + .iterator(); + res = Lists.newArrayList(it2); + Assert.assertEquals(res.size(), 200); + + it2 = graphState.staticGraph().VE().query("109", "115").iterator(); + res = Lists.newArrayList(it2); + Assert.assertEquals(res.size(), 2); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } + + @Test + public void testOtherVE() { + Map conf = new HashMap<>(config); + + GraphMetaType tag = + new GraphMetaType( + IntegerType.INSTANCE, + ValueLabelVertex.class, + StringType.class, + ValueLabelTimeEdge.class, + StringType.class); + GraphStateDescriptor desc = GraphStateDescriptor.build("OtherVE", StoreType.ROCKSDB.name()); + desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + desc.withGraphMeta(new GraphMeta(tag)); + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration(conf)); + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 1000; i++) { + if (i % 3 == 0) { + // 334 + graphState.staticGraph().V().add(new ValueLabelVertex<>(i, "hello", "default")); + IEdge edge = new ValueLabelTimeEdge<>(i, i + 1, "hello", "foo", i); edge.setDirect(EdgeDirection.IN); - Assert.assertEquals(list.get(0).getEdgeIterator().next(), edge); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); + graphState.staticGraph().E().add(edge.withValue("bar")); + } else { + // 666 + graphState.staticGraph().V().add(new ValueLabelVertex<>(i, "hello", "default")); + IEdge edge = new ValueLabelTimeEdge<>(i, i + 1, "hello", "person", i); + edge.setDirect(EdgeDirection.OUT); + graphState.staticGraph().E().add(edge.withValue("male")); + } } - - @Test - public void testFilter() { - Map conf = new HashMap<>(config); - GraphMetaType tag = new GraphMetaType(StringType.INSTANCE, ValueLabelVertex.class, - ValueLabelVertex::new, EmptyProperty.class, IDLabelTimeEdge.class, IDLabelTimeEdge::new, + graphState.manage().operate().finish(); + + List> list = + graphState.staticGraph().VE().query().by(EdgeLabelFilter.getInstance("foo")).asList(); + Assert.assertEquals(list.size(), 1000); + int key = list.get(0).getKey(); + Assert.assertEquals(list.get(0).getVertex(), new ValueLabelVertex<>(key, "hello", "default")); + IEdge edge = new ValueLabelTimeEdge<>(key, key + 1, "hello", "foo", key); + edge.setDirect(EdgeDirection.IN); + Assert.assertEquals(list.get(0).getEdgeIterator().next(), edge.withValue("bar")); + Assert.assertFalse(list.get(1).getEdgeIterator().hasNext()); + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + + tag = + new GraphMetaType( + IntegerType.INSTANCE, + ValueLabelTimeVertex.class, + Object.class, + ValueLabelEdge.class, + Object.class); + desc.withGraphMeta(new GraphMeta(tag)); + graphState = StateFactory.buildGraphState(desc, new Configuration(conf)); + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 1000; i++) { + graphState.staticGraph().V().add(new ValueLabelTimeVertex<>(i, "bar", "foo", i)); + IEdge idEdge = new ValueLabelEdge<>(i, i + 1, "3.0", "default"); + idEdge.setDirect(EdgeDirection.IN); + graphState.staticGraph().E().add(idEdge); + } + graphState.manage().operate().finish(); + list = graphState.staticGraph().VE().query().asList(); + Assert.assertEquals(list.size(), 1000); + key = list.get(0).getKey(); + Assert.assertEquals( + list.get(0).getVertex(), new ValueLabelTimeVertex<>(key, "bar", "foo", key)); + + edge = new ValueLabelEdge<>(key, key + 1, "3.0", "default"); + edge.setDirect(EdgeDirection.IN); + Assert.assertEquals(list.get(0).getEdgeIterator().next(), edge); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } + + @Test + public void testFilter() { + Map conf = new HashMap<>(config); + GraphMetaType tag = + new GraphMetaType( + StringType.INSTANCE, + ValueLabelVertex.class, + ValueLabelVertex::new, + EmptyProperty.class, + IDLabelTimeEdge.class, + IDLabelTimeEdge::new, EmptyProperty.class); - GraphStateDescriptor desc = GraphStateDescriptor.build("filter", StoreType.ROCKSDB.name()); - desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - desc.withGraphMeta(new GraphMeta(tag)); - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration(conf)); - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 10000; i++) { - String id = Integer.toString(i); - IEdge edge1 = new IDLabelTimeEdge<>("2", id, "hello", i); - edge1.setDirect(EdgeDirection.OUT); - graphState.staticGraph().E().add(edge1); - - IEdge edge2 = new IDLabelTimeEdge<>("2", id, "hello", i); - edge2.setDirect(EdgeDirection.IN); - graphState.staticGraph().E().add(edge2); - - IEdge edge3 = new IDLabelTimeEdge<>("2", id, "world", i); - edge3.setDirect(EdgeDirection.IN); - graphState.staticGraph().E().add(edge3); - } + GraphStateDescriptor desc = GraphStateDescriptor.build("filter", StoreType.ROCKSDB.name()); + desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + desc.withGraphMeta(new GraphMeta(tag)); + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration(conf)); + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 10000; i++) { + String id = Integer.toString(i); + IEdge edge1 = new IDLabelTimeEdge<>("2", id, "hello", i); + edge1.setDirect(EdgeDirection.OUT); + graphState.staticGraph().E().add(edge1); + + IEdge edge2 = new IDLabelTimeEdge<>("2", id, "hello", i); + edge2.setDirect(EdgeDirection.IN); + graphState.staticGraph().E().add(edge2); + + IEdge edge3 = new IDLabelTimeEdge<>("2", id, "world", i); + edge3.setDirect(EdgeDirection.IN); + graphState.staticGraph().E().add(edge3); + } - graphState.manage().operate().finish(); - List> edges = graphState.staticGraph().E().query("2") - .by(new EdgeTsFilter(TimeRange.of(0, 5000))).asList(); + graphState.manage().operate().finish(); + List> edges = + graphState + .staticGraph() + .E() + .query("2") + .by(new EdgeTsFilter(TimeRange.of(0, 5000))) + .asList(); - Assert.assertEquals(edges.size(), 15000); - long maxTime = edges.stream().mapToLong(e -> ((IDLabelTimeEdge) e).getTime()).max() - .getAsLong(); - Assert.assertEquals(maxTime, 4999); + Assert.assertEquals(edges.size(), 15000); + long maxTime = edges.stream().mapToLong(e -> ((IDLabelTimeEdge) e).getTime()).max().getAsLong(); + Assert.assertEquals(maxTime, 4999); - long num = edges.stream().filter(e -> e.getDirect() == EdgeDirection.OUT).count(); - Assert.assertEquals(num, 5000); + long num = edges.stream().filter(e -> e.getDirect() == EdgeDirection.OUT).count(); + Assert.assertEquals(num, 5000); - edges = graphState.staticGraph().E().query("2") - .by(OutEdgeFilter.getInstance().and(new EdgeTsFilter(TimeRange.of(0, 1000)))).asList(); - Assert.assertEquals(edges.size(), 1000); + edges = + graphState + .staticGraph() + .E() + .query("2") + .by(OutEdgeFilter.getInstance().and(new EdgeTsFilter(TimeRange.of(0, 1000)))) + .asList(); + Assert.assertEquals(edges.size(), 1000); - maxTime = edges.stream().mapToLong(e -> ((IDLabelTimeEdge) e).getTime()).max().getAsLong(); - Assert.assertEquals(maxTime, 999); + maxTime = edges.stream().mapToLong(e -> ((IDLabelTimeEdge) e).getTime()).max().getAsLong(); + Assert.assertEquals(maxTime, 999); - num = edges.stream().filter(e -> e.getDirect() == EdgeDirection.OUT).count(); - Assert.assertEquals(num, 1000); + num = edges.stream().filter(e -> e.getDirect() == EdgeDirection.OUT).count(); + Assert.assertEquals(num, 1000); - num = edges.stream().filter(e -> e.getDirect() == EdgeDirection.IN).count(); - Assert.assertEquals(num, 0); + num = edges.stream().filter(e -> e.getDirect() == EdgeDirection.IN).count(); + Assert.assertEquals(num, 0); - edges = graphState.staticGraph().E().query("2") - .by(new EdgeTsFilter(TimeRange.of(7000, 10000)).and(EdgeLabelFilter.getInstance("world"))) + edges = + graphState + .staticGraph() + .E() + .query("2") + .by( + new EdgeTsFilter(TimeRange.of(7000, 10000)) + .and(EdgeLabelFilter.getInstance("world"))) .asList(); - Assert.assertEquals(edges.size(), 3000); - long minTime = edges.stream().mapToLong(e -> ((IDLabelTimeEdge) e).getTime()).min() - .getAsLong(); - Assert.assertEquals(minTime, 7000); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - // TODO: MAX VERSION FILTER SUPPORT - } - - @Test(expectedExceptions = IllegalArgumentException.class) - public void testEdgeOrderError() { - Map conf = Maps.newHashMap(config); - conf.put(StateConfigKeys.STATE_KV_ENCODER_EDGE_ORDER.getKey(), - "SRC_ID, DESC_TIME, LABEL, DIRECTION, DST_ID"); - getGraphState(StringType.INSTANCE, "testEdgeOrderError", conf); - } - - @Test - public void testEdgeSort() { - Map conf = Maps.newHashMap(config); - conf.put(StateConfigKeys.STATE_KV_ENCODER_EDGE_ORDER.getKey(), - "SRC_ID, DESC_TIME, LABEL, DIRECTION, DST_ID"); - - GraphMetaType tag = new GraphMetaType(StringType.INSTANCE, ValueLabelVertex.class, - ValueLabelVertex::new, Object.class, ValueLabelTimeEdge.class, ValueLabelTimeEdge::new, + Assert.assertEquals(edges.size(), 3000); + long minTime = edges.stream().mapToLong(e -> ((IDLabelTimeEdge) e).getTime()).min().getAsLong(); + Assert.assertEquals(minTime, 7000); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + // TODO: MAX VERSION FILTER SUPPORT + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testEdgeOrderError() { + Map conf = Maps.newHashMap(config); + conf.put( + StateConfigKeys.STATE_KV_ENCODER_EDGE_ORDER.getKey(), + "SRC_ID, DESC_TIME, LABEL, DIRECTION, DST_ID"); + getGraphState(StringType.INSTANCE, "testEdgeOrderError", conf); + } + + @Test + public void testEdgeSort() { + Map conf = Maps.newHashMap(config); + conf.put( + StateConfigKeys.STATE_KV_ENCODER_EDGE_ORDER.getKey(), + "SRC_ID, DESC_TIME, LABEL, DIRECTION, DST_ID"); + + GraphMetaType tag = + new GraphMetaType( + StringType.INSTANCE, + ValueLabelVertex.class, + ValueLabelVertex::new, + Object.class, + ValueLabelTimeEdge.class, + ValueLabelTimeEdge::new, Object.class); - GraphStateDescriptor desc = GraphStateDescriptor.build("testEdgeSort", - StoreType.ROCKSDB.name()); - desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - desc.withGraphMeta(new GraphMeta(tag)); - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration(conf)); - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 10000; i++) { - String id = Integer.toString(i); - IEdge edge = new ValueLabelTimeEdge<>("2", id, null, "hello", i); - graphState.staticGraph().E().add(edge.withValue("world")); - } - graphState.manage().operate().finish(); - - List> list = graphState.staticGraph().E().asList(); - Assert.assertEquals(((ValueLabelTimeEdge) list.get(0)).getTime(), 9999); - Assert.assertEquals(((ValueLabelTimeEdge) list.get(9999)).getTime(), 0); + GraphStateDescriptor desc = + GraphStateDescriptor.build("testEdgeSort", StoreType.ROCKSDB.name()); + desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + desc.withGraphMeta(new GraphMeta(tag)); + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration(conf)); + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 10000; i++) { + String id = Integer.toString(i); + IEdge edge = new ValueLabelTimeEdge<>("2", id, null, "hello", i); + graphState.staticGraph().E().add(edge.withValue("world")); } - - @Test - public void testLimit() { - Map conf = new HashMap<>(config); - String name = "testLimit"; - conf.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), - "RocksDBLabelPartitionGraphStateTest" + System.currentTimeMillis()); - GraphState graphState = getGraphState(StringType.INSTANCE, name, - conf); - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 10; i++) { - String src = Integer.toString(i); - for (int j = 1; j < 10; j++) { - String dst = Integer.toString(j); - graphState.staticGraph().E().add(new ValueLabelEdge<>(src, dst, "hello" + src + dst, - EdgeDirection.values()[j % 2], "foo")); - } - graphState.staticGraph().V().add(new ValueLabelVertex<>(src, "world" + src, "foo")); - } - graphState.manage().operate().finish(); - - List> list = graphState.staticGraph().E().query("1", "2", "3") - .limit(1L, 1L).asList(); - Assert.assertEquals(list.size(), 6); - - list = graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()).limit(1L, 1L) - .asList(); - Assert.assertEquals(list.size(), 10); - - list = graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()).limit(1L, 2L) - .asList(); - Assert.assertEquals(list.size(), 20); - - List targetIds = graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()) - .select(new DstIdProjector<>()).limit(1L, 2L).asList(); - - Assert.assertEquals(targetIds.size(), 20); - graphState.manage().operate().close(); - graphState.manage().operate().drop(); + graphState.manage().operate().finish(); + + List> list = graphState.staticGraph().E().asList(); + Assert.assertEquals(((ValueLabelTimeEdge) list.get(0)).getTime(), 9999); + Assert.assertEquals(((ValueLabelTimeEdge) list.get(9999)).getTime(), 0); + } + + @Test + public void testLimit() { + Map conf = new HashMap<>(config); + String name = "testLimit"; + conf.put( + ExecutionConfigKeys.JOB_APP_NAME.getKey(), + "RocksDBLabelPartitionGraphStateTest" + System.currentTimeMillis()); + GraphState graphState = getGraphState(StringType.INSTANCE, name, conf); + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 10; i++) { + String src = Integer.toString(i); + for (int j = 1; j < 10; j++) { + String dst = Integer.toString(j); + graphState + .staticGraph() + .E() + .add( + new ValueLabelEdge<>( + src, dst, "hello" + src + dst, EdgeDirection.values()[j % 2], "foo")); + } + graphState.staticGraph().V().add(new ValueLabelVertex<>(src, "world" + src, "foo")); } + graphState.manage().operate().finish(); + + List> list = + graphState.staticGraph().E().query("1", "2", "3").limit(1L, 1L).asList(); + Assert.assertEquals(list.size(), 6); + + list = + graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()).limit(1L, 1L).asList(); + Assert.assertEquals(list.size(), 10); + + list = + graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()).limit(1L, 2L).asList(); + Assert.assertEquals(list.size(), 20); + + List targetIds = + graphState + .staticGraph() + .E() + .query() + .by(InEdgeFilter.getInstance()) + .select(new DstIdProjector<>()) + .limit(1L, 2L) + .asList(); - @Test - public void testFO() throws IOException { - Map conf = new HashMap<>(config); - String name = "testFO"; - GraphState graphState = getGraphState(StringType.INSTANCE, name, - conf); - graphState.manage().operate().setCheckpointId(1); - - graphState.manage().operate().finish(); - graphState.manage().operate().archive(); + Assert.assertEquals(targetIds.size(), 20); + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } - graphState.manage().operate().drop(); - graphState = getGraphState(StringType.INSTANCE, name, conf); + @Test + public void testFO() throws IOException { + Map conf = new HashMap<>(config); + String name = "testFO"; + GraphState graphState = getGraphState(StringType.INSTANCE, name, conf); + graphState.manage().operate().setCheckpointId(1); - graphState.manage().operate().setCheckpointId(1); - graphState.manage().operate().recover(); - graphState.manage().operate().setCheckpointId(2); + graphState.manage().operate().finish(); + graphState.manage().operate().archive(); - for (int i = 0; i < 100; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueLabelEdge<>("1", id, "hello", "foo")); - graphState.staticGraph().V().add(new ValueLabelVertex<>("1", "hello", "foo")); - } + graphState.manage().operate().drop(); + graphState = getGraphState(StringType.INSTANCE, name, conf); - List> edges = graphState.staticGraph().E().query("1").asList(); - Assert.assertEquals(edges.size(), 100); - List> vertices = graphState.staticGraph().V().asList(); - Assert.assertEquals(vertices.size(), 1); + graphState.manage().operate().setCheckpointId(1); + graphState.manage().operate().recover(); + graphState.manage().operate().setCheckpointId(2); - graphState.manage().operate().archive(); - graphState.manage().operate().finish(); - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - - graphState = getGraphState(StringType.INSTANCE, name, conf); - graphState.manage().operate().setCheckpointId(2); - graphState.manage().operate().recover(); - graphState.manage().operate().setCheckpointId(3); - - edges = graphState.staticGraph().E().asList(); - Assert.assertEquals(edges.size(), 100); - vertices = graphState.staticGraph().V().asList(); - Assert.assertEquals(vertices.size(), 1); + for (int i = 0; i < 100; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueLabelEdge<>("1", id, "hello", "foo")); + graphState.staticGraph().V().add(new ValueLabelVertex<>("1", "hello", "foo")); + } - for (int i = 0; i < 80; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueLabelEdge<>("2", id, "hello", "person")); - graphState.staticGraph().V().add(new ValueLabelVertex<>("2", "hello", "person")); - } - graphState.manage().operate().finish(); - graphState.manage().operate().archive(); - edges = graphState.staticGraph().E().asList(); - Assert.assertEquals(edges.size(), 180); - edges = graphState.staticGraph().E().query().by(EdgeLabelFilter.getInstance("person")) - .asList(); - Assert.assertEquals(edges.size(), 80); - edges = graphState.staticGraph().E().query().by(EdgeLabelFilter.getInstance("foo")).asList(); - Assert.assertEquals(edges.size(), 100); - vertices = graphState.staticGraph().V().asList(); - Assert.assertEquals(vertices.size(), 2); - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - - graphState = getGraphState(StringType.INSTANCE, name, conf); - graphState.manage().operate().setCheckpointId(3); + List> edges = graphState.staticGraph().E().query("1").asList(); + Assert.assertEquals(edges.size(), 100); + List> vertices = graphState.staticGraph().V().asList(); + Assert.assertEquals(vertices.size(), 1); + + graphState.manage().operate().archive(); + graphState.manage().operate().finish(); + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + + graphState = getGraphState(StringType.INSTANCE, name, conf); + graphState.manage().operate().setCheckpointId(2); + graphState.manage().operate().recover(); + graphState.manage().operate().setCheckpointId(3); + + edges = graphState.staticGraph().E().asList(); + Assert.assertEquals(edges.size(), 100); + vertices = graphState.staticGraph().V().asList(); + Assert.assertEquals(vertices.size(), 1); + + for (int i = 0; i < 80; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueLabelEdge<>("2", id, "hello", "person")); + graphState.staticGraph().V().add(new ValueLabelVertex<>("2", "hello", "person")); + } + graphState.manage().operate().finish(); + graphState.manage().operate().archive(); + edges = graphState.staticGraph().E().asList(); + Assert.assertEquals(edges.size(), 180); + edges = graphState.staticGraph().E().query().by(EdgeLabelFilter.getInstance("person")).asList(); + Assert.assertEquals(edges.size(), 80); + edges = graphState.staticGraph().E().query().by(EdgeLabelFilter.getInstance("foo")).asList(); + Assert.assertEquals(edges.size(), 100); + vertices = graphState.staticGraph().V().asList(); + Assert.assertEquals(vertices.size(), 2); + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + + graphState = getGraphState(StringType.INSTANCE, name, conf); + graphState.manage().operate().setCheckpointId(3); + graphState.manage().operate().recover(); + graphState.staticGraph().V().add(new ValueLabelVertex<>("2", "world", "person")); + Assert.assertEquals(graphState.staticGraph().V().query("2").get().getValue(), "world"); + graphState.manage().operate().finish(); + Assert.assertEquals(graphState.staticGraph().V().query("2").get().getValue(), "world"); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + IPersistentIO persistentIO = PersistentIOBuilder.build(new Configuration(conf)); + persistentIO.delete( + new Path( + Configuration.getString(FileConfigKeys.ROOT, conf), + Configuration.getString(ExecutionConfigKeys.JOB_APP_NAME, conf)), + true); + } + + @Test + public void testArchive() throws IOException { + Map conf = new HashMap<>(config); + IPersistentIO persistentIO = PersistentIOBuilder.build(new Configuration(conf)); + persistentIO.delete( + new Path( + Configuration.getString(FileConfigKeys.ROOT, conf), + Configuration.getString(ExecutionConfigKeys.JOB_APP_NAME, conf)), + true); + + GraphState graphState = null; + + for (int v = 1; v < 10; v++) { + graphState = getGraphState(StringType.INSTANCE, "archive", conf); + if (v > 1) { + graphState.manage().operate().setCheckpointId(v - 1); graphState.manage().operate().recover(); - graphState.staticGraph().V().add(new ValueLabelVertex<>("2", "world", "person")); - Assert.assertEquals(graphState.staticGraph().V().query("2").get().getValue(), "world"); - graphState.manage().operate().finish(); - Assert.assertEquals(graphState.staticGraph().V().query("2").get().getValue(), "world"); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - IPersistentIO persistentIO = PersistentIOBuilder.build(new Configuration(conf)); - persistentIO.delete(new Path(Configuration.getString(FileConfigKeys.ROOT, conf), - Configuration.getString(ExecutionConfigKeys.JOB_APP_NAME, conf)), true); + } + graphState.manage().operate().setCheckpointId(v); + for (int i = 0; i < 10; i++) { + String id = Integer.toString(i); + graphState.staticGraph().E().add(new ValueLabelEdge<>(id, id, id, "foo")); + } + graphState.manage().operate().finish(); + graphState.manage().operate().archive(); + graphState.manage().operate().close(); } - - @Test - public void testArchive() throws IOException { - Map conf = new HashMap<>(config); - IPersistentIO persistentIO = PersistentIOBuilder.build(new Configuration(conf)); - persistentIO.delete(new Path(Configuration.getString(FileConfigKeys.ROOT, conf), - Configuration.getString(ExecutionConfigKeys.JOB_APP_NAME, conf)), true); - - GraphState graphState = null; - - for (int v = 1; v < 10; v++) { - graphState = getGraphState(StringType.INSTANCE, "archive", conf); - if (v > 1) { - graphState.manage().operate().setCheckpointId(v - 1); - graphState.manage().operate().recover(); - } - graphState.manage().operate().setCheckpointId(v); - for (int i = 0; i < 10; i++) { - String id = Integer.toString(i); - graphState.staticGraph().E().add(new ValueLabelEdge<>(id, id, id, "foo")); - } - graphState.manage().operate().finish(); - graphState.manage().operate().archive(); - graphState.manage().operate().close(); - } - - graphState.manage().operate().drop(); - persistentIO.delete(new Path(Configuration.getString(FileConfigKeys.ROOT, conf), - Configuration.getString(ExecutionConfigKeys.JOB_APP_NAME, conf)), true); - } + graphState.manage().operate().drop(); + persistentIO.delete( + new Path( + Configuration.getString(FileConfigKeys.ROOT, conf), + Configuration.getString(ExecutionConfigKeys.JOB_APP_NAME, conf)), + true); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/RocksdbDynamicGraphStateTest.java b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/RocksdbDynamicGraphStateTest.java index 8e017527d..7cea7e312 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/RocksdbDynamicGraphStateTest.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/RocksdbDynamicGraphStateTest.java @@ -19,14 +19,13 @@ package org.apache.geaflow.state; -import com.google.common.collect.Iterators; -import com.google.common.collect.Lists; import java.io.File; import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -51,157 +50,186 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -public class RocksdbDynamicGraphStateTest { - - Map config = new HashMap<>(); - - @BeforeClass - public void setUp() { - FileUtils.deleteQuietly(new File("/tmp/geaflow/chk/")); - FileUtils.deleteQuietly(new File("/tmp/RocksDBGraphStateTest")); - Map persistConfig = new HashMap<>(); - - config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), "RocksDBGraphStateTest"); - config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); - config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); - config.put(FileConfigKeys.JSON_CONFIG.getKey(), GsonUtil.toJson(persistConfig)); - } - - private GraphState getGraphState(IType type, String name, Map conf) { - GraphMetaType tag = new GraphMetaType(type, ValueVertex.class, - type.getTypeClass(), ValueEdge.class, type.getTypeClass()); +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; - GraphStateDescriptor desc = GraphStateDescriptor.build(name, StoreType.ROCKSDB.name()); - desc.withKeyGroup(new KeyGroup(0, 1)).withDataModel(DataModel.DYNAMIC_GRAPH) - .withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); - desc.withGraphMeta(new GraphMeta(tag)); - GraphState graphState = StateFactory.buildGraphState(desc, new Configuration(conf)); - return graphState; - } +public class RocksdbDynamicGraphStateTest { - @Test - public void testBothWriteMode() { - testApi(true); - testApi(false); + Map config = new HashMap<>(); + + @BeforeClass + public void setUp() { + FileUtils.deleteQuietly(new File("/tmp/geaflow/chk/")); + FileUtils.deleteQuietly(new File("/tmp/RocksDBGraphStateTest")); + Map persistConfig = new HashMap<>(); + + config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), "RocksDBGraphStateTest"); + config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); + config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); + config.put(FileConfigKeys.JSON_CONFIG.getKey(), GsonUtil.toJson(persistConfig)); + } + + private GraphState getGraphState( + IType type, String name, Map conf) { + GraphMetaType tag = + new GraphMetaType( + type, ValueVertex.class, type.getTypeClass(), ValueEdge.class, type.getTypeClass()); + + GraphStateDescriptor desc = GraphStateDescriptor.build(name, StoreType.ROCKSDB.name()); + desc.withKeyGroup(new KeyGroup(0, 1)) + .withDataModel(DataModel.DYNAMIC_GRAPH) + .withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); + desc.withGraphMeta(new GraphMeta(tag)); + GraphState graphState = StateFactory.buildGraphState(desc, new Configuration(conf)); + return graphState; + } + + @Test + public void testBothWriteMode() { + testApi(true); + testApi(false); + } + + private void testApi(boolean async) { + Map conf = config; + conf.put(StateConfigKeys.STATE_WRITE_BUFFER_SIZE.getKey(), "100"); + conf.put(StateConfigKeys.STATE_WRITE_ASYNC_ENABLE.getKey(), String.valueOf(async)); + GraphState graphState = + getGraphState(StringType.INSTANCE, "testApi", conf); + + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 1000; i++) { + graphState.dynamicGraph().E().add(1L, new ValueEdge<>("1", "2", "hello")); + graphState.dynamicGraph().E().add(1L, new ValueEdge<>("1", "3", "hello")); + graphState.dynamicGraph().E().add(2L, new ValueEdge<>("2", "2", "world")); + graphState.dynamicGraph().E().add(2L, new ValueEdge<>("2", "3", "world")); + graphState.dynamicGraph().V().add(1L, new ValueVertex<>("1", "3")); + graphState.dynamicGraph().V().add(2L, new ValueVertex<>("2", "4")); + graphState.dynamicGraph().V().add(2L, new ValueVertex<>("1", "5")); + graphState.dynamicGraph().V().add(3L, new ValueVertex<>("1", "6")); } - private void testApi(boolean async) { - Map conf = config; - conf.put(StateConfigKeys.STATE_WRITE_BUFFER_SIZE.getKey(), "100"); - conf.put(StateConfigKeys.STATE_WRITE_ASYNC_ENABLE.getKey(), String.valueOf(async)); - GraphState graphState = getGraphState(StringType.INSTANCE, "testApi", conf); - - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 1000; i++) { - graphState.dynamicGraph().E().add(1L, new ValueEdge<>("1", "2", "hello")); - graphState.dynamicGraph().E().add(1L, new ValueEdge<>("1", "3", "hello")); - graphState.dynamicGraph().E().add(2L, new ValueEdge<>("2", "2", "world")); - graphState.dynamicGraph().E().add(2L, new ValueEdge<>("2", "3", "world")); - graphState.dynamicGraph().V().add(1L, new ValueVertex<>("1", "3")); - graphState.dynamicGraph().V().add(2L, new ValueVertex<>("2", "4")); - graphState.dynamicGraph().V().add(2L, new ValueVertex<>("1", "5")); - graphState.dynamicGraph().V().add(3L, new ValueVertex<>("1", "6")); - } - - graphState.manage().operate().finish(); - graphState.dynamicGraph().V().add(4L, new ValueVertex<>("1", "6")); - graphState.dynamicGraph().V().add(4L, new ValueVertex<>("3", "6")); - graphState.dynamicGraph().E().add(4L, new ValueEdge<>("1", "1", "6")); - graphState.dynamicGraph().E().add(4L, new ValueEdge<>("1", "2", "6")); - - List> list = graphState.dynamicGraph().E().query(1L, "1").asList(); - Assert.assertEquals(list.size(), 2); - - list = graphState.dynamicGraph().E().query(1L, "1").by( - (IEdgeFilter) value -> !value.getTargetId().equals("2")).asList(); - Assert.assertEquals(list.size(), 1); - Assert.assertEquals(list.get(0).getTargetId(), "3"); - - Iterator> iterator = graphState.dynamicGraph().V().query(2L).iterator(); - Assert.assertEquals(Iterators.size(iterator), 2); - - IVertex vertex = graphState.dynamicGraph().V().query(1L, "1").get(); - Assert.assertEquals(vertex.getValue(), "3"); - - Assert.assertEquals(graphState.dynamicGraph().V().getLatestVersion("2"), 2L); - Assert.assertEquals(graphState.dynamicGraph().V().getAllVersions("1").size(), 4); - Assert.assertEquals(graphState.dynamicGraph().V().getLatestVersion("1"), 4); - - Map> map = graphState.dynamicGraph().V().query("1").asMap(); - Assert.assertEquals(map.size(), 4); - - map = graphState.dynamicGraph().V().query("1", Arrays.asList(2L, 3L)).asMap(); - Assert.assertEquals(map.size(), 2); - - map = graphState.dynamicGraph().V().query("1", Arrays.asList(2L, 3L, 4L)).by( - (IVertexFilter) value -> !value.getValue().equals("5")).asMap(); - Assert.assertEquals(map.size(), 2); - - map = graphState.dynamicGraph().V().query("1").by( - (IVertexFilter) value -> !value.getValue().equals("5")).asMap(); - Assert.assertEquals(map.size(), 3); - - map = graphState.dynamicGraph().V().query("1", Arrays.asList(2L, 3L, 4L, 5L)).asMap(); - Assert.assertEquals(map.size(), 3); - - List> res = - graphState.dynamicGraph().VE().query(2L, "2").asList(); - Assert.assertEquals(res.size(), 1); - - res = graphState.dynamicGraph().VE().query(3L, "1").asList(); - Assert.assertEquals(res.size(), 1); - - Iterator idIterator = graphState.dynamicGraph().V().idIterator(); - List idList = Lists.newArrayList(idIterator); - Assert.assertEquals(idList.size(), 3); - - res = - graphState.dynamicGraph().VE().query(4L, "1").asList(); - Assert.assertEquals(res.size(), 1); - Assert.assertEquals(Iterators.size(res.get(0).getEdgeIterator()), 2); - - res = - graphState.dynamicGraph().VE().query(4L, "1") - .by((IEdgeFilter) value -> !value.getTargetId().equals("1")).asList(); - Assert.assertEquals(res.size(), 1); - Assert.assertEquals(Iterators.size(res.get(0).getEdgeIterator()), 1); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); + graphState.manage().operate().finish(); + graphState.dynamicGraph().V().add(4L, new ValueVertex<>("1", "6")); + graphState.dynamicGraph().V().add(4L, new ValueVertex<>("3", "6")); + graphState.dynamicGraph().E().add(4L, new ValueEdge<>("1", "1", "6")); + graphState.dynamicGraph().E().add(4L, new ValueEdge<>("1", "2", "6")); + + List> list = graphState.dynamicGraph().E().query(1L, "1").asList(); + Assert.assertEquals(list.size(), 2); + + list = + graphState + .dynamicGraph() + .E() + .query(1L, "1") + .by((IEdgeFilter) value -> !value.getTargetId().equals("2")) + .asList(); + Assert.assertEquals(list.size(), 1); + Assert.assertEquals(list.get(0).getTargetId(), "3"); + + Iterator> iterator = graphState.dynamicGraph().V().query(2L).iterator(); + Assert.assertEquals(Iterators.size(iterator), 2); + + IVertex vertex = graphState.dynamicGraph().V().query(1L, "1").get(); + Assert.assertEquals(vertex.getValue(), "3"); + + Assert.assertEquals(graphState.dynamicGraph().V().getLatestVersion("2"), 2L); + Assert.assertEquals(graphState.dynamicGraph().V().getAllVersions("1").size(), 4); + Assert.assertEquals(graphState.dynamicGraph().V().getLatestVersion("1"), 4); + + Map> map = graphState.dynamicGraph().V().query("1").asMap(); + Assert.assertEquals(map.size(), 4); + + map = graphState.dynamicGraph().V().query("1", Arrays.asList(2L, 3L)).asMap(); + Assert.assertEquals(map.size(), 2); + + map = + graphState + .dynamicGraph() + .V() + .query("1", Arrays.asList(2L, 3L, 4L)) + .by((IVertexFilter) value -> !value.getValue().equals("5")) + .asMap(); + Assert.assertEquals(map.size(), 2); + + map = + graphState + .dynamicGraph() + .V() + .query("1") + .by((IVertexFilter) value -> !value.getValue().equals("5")) + .asMap(); + Assert.assertEquals(map.size(), 3); + + map = graphState.dynamicGraph().V().query("1", Arrays.asList(2L, 3L, 4L, 5L)).asMap(); + Assert.assertEquals(map.size(), 3); + + List> res = + graphState.dynamicGraph().VE().query(2L, "2").asList(); + Assert.assertEquals(res.size(), 1); + + res = graphState.dynamicGraph().VE().query(3L, "1").asList(); + Assert.assertEquals(res.size(), 1); + + Iterator idIterator = graphState.dynamicGraph().V().idIterator(); + List idList = Lists.newArrayList(idIterator); + Assert.assertEquals(idList.size(), 3); + + res = graphState.dynamicGraph().VE().query(4L, "1").asList(); + Assert.assertEquals(res.size(), 1); + Assert.assertEquals(Iterators.size(res.get(0).getEdgeIterator()), 2); + + res = + graphState + .dynamicGraph() + .VE() + .query(4L, "1") + .by((IEdgeFilter) value -> !value.getTargetId().equals("1")) + .asList(); + Assert.assertEquals(res.size(), 1); + Assert.assertEquals(Iterators.size(res.get(0).getEdgeIterator()), 1); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } + + @Test + public void testKeyGroup() { + Map conf = config; + GraphState graphState = + getGraphState(StringType.INSTANCE, "testKeyGroup", conf); + + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 10; i++) { + graphState.dynamicGraph().E().add(i, new ValueEdge<>("1", "2", "hello" + i)); + graphState.dynamicGraph().E().add(i, new ValueEdge<>("1", "3", "hello" + i)); + graphState.dynamicGraph().E().add(i, new ValueEdge<>("2", "2", "world" + i)); + graphState.dynamicGraph().E().add(i, new ValueEdge<>("2", "3", "world" + i)); + graphState.dynamicGraph().V().add(i, new ValueVertex<>("1", "3" + i)); + graphState.dynamicGraph().V().add(i, new ValueVertex<>("2", "4" + i)); + graphState.dynamicGraph().V().add(i, new ValueVertex<>("1", "5" + i)); + graphState.dynamicGraph().V().add(i, new ValueVertex<>("1", "6" + i)); } - @Test - public void testKeyGroup() { - Map conf = config; - GraphState graphState = getGraphState(StringType.INSTANCE, "testKeyGroup", conf); - - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 10; i++) { - graphState.dynamicGraph().E().add(i, new ValueEdge<>("1", "2", "hello" + i)); - graphState.dynamicGraph().E().add(i, new ValueEdge<>("1", "3", "hello" + i)); - graphState.dynamicGraph().E().add(i, new ValueEdge<>("2", "2", "world" + i)); - graphState.dynamicGraph().E().add(i, new ValueEdge<>("2", "3", "world" + i)); - graphState.dynamicGraph().V().add(i, new ValueVertex<>("1", "3" + i)); - graphState.dynamicGraph().V().add(i, new ValueVertex<>("2", "4" + i)); - graphState.dynamicGraph().V().add(i, new ValueVertex<>("1", "5" + i)); - graphState.dynamicGraph().V().add(i, new ValueVertex<>("1", "6" + i)); - } + graphState.manage().operate().finish(); - graphState.manage().operate().finish(); + Iterator idIterator = graphState.dynamicGraph().V().idIterator(); + Assert.assertEquals(Iterators.size(idIterator), 2); - Iterator idIterator = graphState.dynamicGraph().V().idIterator(); - Assert.assertEquals(Iterators.size(idIterator), 2); - - idIterator = graphState.dynamicGraph().V().query(1L, new KeyGroup(0, 0)).idIterator(); - Assert.assertEquals(Iterators.size(idIterator), 1); - - List> list = graphState.dynamicGraph().E().query(1L, - new KeyGroup(0, 0)).by((IEdgeFilter) value -> !value.getTargetId().equals("2")).asList(); - Assert.assertEquals(list.size(), 1); - Assert.assertEquals(list.get(0).getTargetId(), "3"); - - } + idIterator = graphState.dynamicGraph().V().query(1L, new KeyGroup(0, 0)).idIterator(); + Assert.assertEquals(Iterators.size(idIterator), 1); + List> list = + graphState + .dynamicGraph() + .E() + .query(1L, new KeyGroup(0, 0)) + .by((IEdgeFilter) value -> !value.getTargetId().equals("2")) + .asList(); + Assert.assertEquals(list.size(), 1); + Assert.assertEquals(list.get(0).getTargetId(), "3"); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/RocksdbKeyStateTest.java b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/RocksdbKeyStateTest.java index daa57a3cd..2c09541db 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/RocksdbKeyStateTest.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/RocksdbKeyStateTest.java @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -39,189 +40,182 @@ public class RocksdbKeyStateTest { - Map config = new HashMap<>(); - - @BeforeClass - public void setUp() { - FileUtils.deleteQuietly(new File("/tmp/geaflow/chk/")); - FileUtils.deleteQuietly(new File("/tmp/RocksdbKeyStateTest")); - config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), "RocksdbKeyStateTest"); - config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); - config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); - } - - @AfterClass - public void tearUp() { - FileUtils.deleteQuietly(new File("/tmp/RocksdbKeyStateTest")); - } - - @Test - public void testKMap() { - KeyMapStateDescriptor desc = - KeyMapStateDescriptor.build("testKV", StoreType.ROCKSDB.name()); - desc.withKeyGroup(new KeyGroup(0, 0)) - .withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - KeyMapState mapState = StateFactory.buildKeyMapState(desc, - new Configuration(config)); - mapState.manage().operate().setCheckpointId(1L); - - Map conf = new HashMap<>(config); - mapState.put("hello", conf); - mapState.add("foo", "bar1", "bar2"); - Assert.assertEquals(mapState.get("hello").size(), conf.size()); - Assert.assertEquals(mapState.get("foo").get("bar1"), "bar2"); - - mapState.manage().operate().finish(); - mapState.manage().operate().archive(); - - mapState.manage().operate().close(); - mapState.manage().operate().drop(); - - mapState = StateFactory.buildKeyMapState(desc, - new Configuration(config)); - mapState.manage().operate().setCheckpointId(1L); - mapState.manage().operate().recover(); - - mapState.manage().operate().setCheckpointId(2L); - Assert.assertEquals(mapState.get("hello").size(), conf.size()); - Assert.assertEquals(mapState.get("foo").get("bar1"), "bar2"); - - mapState.add("foo", "bar2", "bar3"); - mapState.manage().operate().finish(); - mapState.manage().operate().archive(); - - mapState.manage().operate().close(); - mapState.manage().operate().drop(); - - mapState = StateFactory.buildKeyMapState(desc, - new Configuration(config)); - mapState.manage().operate().setCheckpointId(2L); - mapState.manage().operate().recover(); - Assert.assertEquals(mapState.get("hello").size(), conf.size()); - Assert.assertEquals(mapState.get("foo").get("bar2"), "bar3"); - - mapState.manage().operate().close(); - mapState.manage().operate().drop(); - } - - @Test - public void testKList() { - KeyListStateDescriptor desc = - KeyListStateDescriptor.build("testKList", StoreType.ROCKSDB.name()); - desc.withKeyGroup(new KeyGroup(0, 0)) - .withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - KeyListState listState = StateFactory.buildKeyListState(desc, - new Configuration(config)); - listState.manage().operate().setCheckpointId(1L); - - listState.add("hello", "world"); - listState.put("foo", Arrays.asList("bar1", "bar2")); - Assert.assertEquals(listState.get("hello"), Arrays.asList("world")); - Assert.assertEquals(listState.get("foo"), Arrays.asList("bar1", "bar2")); - - listState.manage().operate().finish(); - listState.manage().operate().archive(); - - listState.manage().operate().close(); - listState.manage().operate().drop(); - - listState = StateFactory.buildKeyListState(desc, new Configuration(config)); - listState.manage().operate().setCheckpointId(1L); - listState.manage().operate().recover(); - - listState.manage().operate().setCheckpointId(2L); - Assert.assertEquals(listState.get("hello"), Arrays.asList("world")); - Assert.assertEquals(listState.get("foo"), Arrays.asList("bar1", "bar2")); - - listState.manage().operate().close(); - listState.manage().operate().drop(); - } - - @Test - public void testKV() { - KeyValueStateDescriptor desc = - KeyValueStateDescriptor.build("testKV", StoreType.ROCKSDB.name()); - desc.withDefaultValue(() -> "foobar").withKeyGroup(new KeyGroup(0, 0)) - .withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - KeyValueState valueState = StateFactory.buildKeyValueState(desc, - new Configuration(config)); - - valueState.manage().operate().setCheckpointId(1L); - - valueState.put("hello", "world"); - Assert.assertEquals(valueState.get("hello"), "world"); - Assert.assertEquals(valueState.get("foo"), "foobar"); - - valueState.manage().operate().finish(); - valueState.manage().operate().archive(); - - valueState.manage().operate().close(); - valueState.manage().operate().drop(); - - valueState = StateFactory.buildKeyValueState(desc, new Configuration(config)); - valueState.manage().operate().setCheckpointId(1L); - valueState.manage().operate().recover(); - - valueState.manage().operate().setCheckpointId(2L); - Assert.assertEquals(valueState.get("hello"), "world"); - Assert.assertEquals(valueState.get("foo"), "foobar"); - valueState.manage().operate().close(); - valueState.manage().operate().drop(); - - - desc.withTypeInfo(String.class, String.class); - valueState = StateFactory.buildKeyValueState(desc, - new Configuration(config)); - - valueState.manage().operate().setCheckpointId(1L); - - valueState.put("hello", "world"); - Assert.assertEquals(valueState.get("hello"), "world"); - Assert.assertEquals(valueState.get("foo"), "foobar"); - - valueState.manage().operate().finish(); - valueState.manage().operate().archive(); - + Map config = new HashMap<>(); + + @BeforeClass + public void setUp() { + FileUtils.deleteQuietly(new File("/tmp/geaflow/chk/")); + FileUtils.deleteQuietly(new File("/tmp/RocksdbKeyStateTest")); + config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), "RocksdbKeyStateTest"); + config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); + config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); + } + + @AfterClass + public void tearUp() { + FileUtils.deleteQuietly(new File("/tmp/RocksdbKeyStateTest")); + } + + @Test + public void testKMap() { + KeyMapStateDescriptor desc = + KeyMapStateDescriptor.build("testKV", StoreType.ROCKSDB.name()); + desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + KeyMapState mapState = + StateFactory.buildKeyMapState(desc, new Configuration(config)); + mapState.manage().operate().setCheckpointId(1L); + + Map conf = new HashMap<>(config); + mapState.put("hello", conf); + mapState.add("foo", "bar1", "bar2"); + Assert.assertEquals(mapState.get("hello").size(), conf.size()); + Assert.assertEquals(mapState.get("foo").get("bar1"), "bar2"); + + mapState.manage().operate().finish(); + mapState.manage().operate().archive(); + + mapState.manage().operate().close(); + mapState.manage().operate().drop(); + + mapState = StateFactory.buildKeyMapState(desc, new Configuration(config)); + mapState.manage().operate().setCheckpointId(1L); + mapState.manage().operate().recover(); + + mapState.manage().operate().setCheckpointId(2L); + Assert.assertEquals(mapState.get("hello").size(), conf.size()); + Assert.assertEquals(mapState.get("foo").get("bar1"), "bar2"); + + mapState.add("foo", "bar2", "bar3"); + mapState.manage().operate().finish(); + mapState.manage().operate().archive(); + + mapState.manage().operate().close(); + mapState.manage().operate().drop(); + + mapState = StateFactory.buildKeyMapState(desc, new Configuration(config)); + mapState.manage().operate().setCheckpointId(2L); + mapState.manage().operate().recover(); + Assert.assertEquals(mapState.get("hello").size(), conf.size()); + Assert.assertEquals(mapState.get("foo").get("bar2"), "bar3"); + + mapState.manage().operate().close(); + mapState.manage().operate().drop(); + } + + @Test + public void testKList() { + KeyListStateDescriptor desc = + KeyListStateDescriptor.build("testKList", StoreType.ROCKSDB.name()); + desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + KeyListState listState = + StateFactory.buildKeyListState(desc, new Configuration(config)); + listState.manage().operate().setCheckpointId(1L); + + listState.add("hello", "world"); + listState.put("foo", Arrays.asList("bar1", "bar2")); + Assert.assertEquals(listState.get("hello"), Arrays.asList("world")); + Assert.assertEquals(listState.get("foo"), Arrays.asList("bar1", "bar2")); + + listState.manage().operate().finish(); + listState.manage().operate().archive(); + + listState.manage().operate().close(); + listState.manage().operate().drop(); + + listState = StateFactory.buildKeyListState(desc, new Configuration(config)); + listState.manage().operate().setCheckpointId(1L); + listState.manage().operate().recover(); + + listState.manage().operate().setCheckpointId(2L); + Assert.assertEquals(listState.get("hello"), Arrays.asList("world")); + Assert.assertEquals(listState.get("foo"), Arrays.asList("bar1", "bar2")); + + listState.manage().operate().close(); + listState.manage().operate().drop(); + } + + @Test + public void testKV() { + KeyValueStateDescriptor desc = + KeyValueStateDescriptor.build("testKV", StoreType.ROCKSDB.name()); + desc.withDefaultValue(() -> "foobar") + .withKeyGroup(new KeyGroup(0, 0)) + .withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + KeyValueState valueState = + StateFactory.buildKeyValueState(desc, new Configuration(config)); + + valueState.manage().operate().setCheckpointId(1L); + + valueState.put("hello", "world"); + Assert.assertEquals(valueState.get("hello"), "world"); + Assert.assertEquals(valueState.get("foo"), "foobar"); + + valueState.manage().operate().finish(); + valueState.manage().operate().archive(); + + valueState.manage().operate().close(); + valueState.manage().operate().drop(); + + valueState = StateFactory.buildKeyValueState(desc, new Configuration(config)); + valueState.manage().operate().setCheckpointId(1L); + valueState.manage().operate().recover(); + + valueState.manage().operate().setCheckpointId(2L); + Assert.assertEquals(valueState.get("hello"), "world"); + Assert.assertEquals(valueState.get("foo"), "foobar"); + valueState.manage().operate().close(); + valueState.manage().operate().drop(); + + desc.withTypeInfo(String.class, String.class); + valueState = StateFactory.buildKeyValueState(desc, new Configuration(config)); + + valueState.manage().operate().setCheckpointId(1L); + + valueState.put("hello", "world"); + Assert.assertEquals(valueState.get("hello"), "world"); + Assert.assertEquals(valueState.get("foo"), "foobar"); + + valueState.manage().operate().finish(); + valueState.manage().operate().archive(); + + valueState.manage().operate().close(); + valueState.manage().operate().drop(); + + valueState = StateFactory.buildKeyValueState(desc, new Configuration(config)); + valueState.manage().operate().setCheckpointId(1L); + valueState.manage().operate().recover(); + + valueState.manage().operate().setCheckpointId(2L); + Assert.assertEquals(valueState.get("hello"), "world"); + Assert.assertEquals(valueState.get("foo"), "foobar"); + valueState.manage().operate().close(); + valueState.manage().operate().drop(); + } + + @Test + public void testKVFO() { + KeyValueStateDescriptor desc = + KeyValueStateDescriptor.build("testKVFO", StoreType.ROCKSDB.name()); + desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + KeyValueState valueState = + StateFactory.buildKeyValueState(desc, new Configuration(config)); + + for (int i = 5; i < 200; i += 5) { + valueState.manage().operate().setCheckpointId(i); + if (i > 100) { + for (int j = 0; j < 100; j++) { + valueState.put("hello", "world" + j); + } + } + valueState.manage().operate().finish(); + valueState.manage().operate().archive(); + if (i % 50 == 0) { valueState.manage().operate().close(); valueState.manage().operate().drop(); - valueState = StateFactory.buildKeyValueState(desc, new Configuration(config)); - valueState.manage().operate().setCheckpointId(1L); + valueState.manage().operate().setCheckpointId(i); valueState.manage().operate().recover(); - - valueState.manage().operate().setCheckpointId(2L); - Assert.assertEquals(valueState.get("hello"), "world"); - Assert.assertEquals(valueState.get("foo"), "foobar"); - valueState.manage().operate().close(); - valueState.manage().operate().drop(); - } - - @Test - public void testKVFO() { - KeyValueStateDescriptor desc = - KeyValueStateDescriptor.build("testKVFO", StoreType.ROCKSDB.name()); - desc.withKeyGroup(new KeyGroup(0, 0)) - .withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - KeyValueState valueState = StateFactory.buildKeyValueState(desc, - new Configuration(config)); - - for (int i = 5; i < 200; i += 5) { - valueState.manage().operate().setCheckpointId(i); - if (i > 100) { - for (int j = 0; j < 100; j++) { - valueState.put("hello", "world" + j); - } - } - valueState.manage().operate().finish(); - valueState.manage().operate().archive(); - if (i % 50 == 0) { - valueState.manage().operate().close(); - valueState.manage().operate().drop(); - valueState = StateFactory.buildKeyValueState(desc, - new Configuration(config)); - valueState.manage().operate().setCheckpointId(i); - valueState.manage().operate().recover(); - } - } + } } + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/StateFactoryTest.java b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/StateFactoryTest.java index 9e415e096..83456b38e 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/StateFactoryTest.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/StateFactoryTest.java @@ -19,7 +19,6 @@ package org.apache.geaflow.state; -import com.google.common.collect.Iterators; import java.io.File; import java.io.IOException; import java.util.ArrayList; @@ -32,6 +31,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicInteger; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -56,193 +56,211 @@ import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; +import com.google.common.collect.Iterators; + public class StateFactoryTest { - private long ts; - private String testName; + private long ts; + private String testName; + + @BeforeMethod + public void setUp() { + this.ts = System.currentTimeMillis(); + this.testName = "factory-test-" + ts; + } + + @AfterMethod + public void tearDown() throws IOException { + FileUtils.deleteQuietly(new File("/tmp/" + testName)); + } - @BeforeMethod - public void setUp() { - this.ts = System.currentTimeMillis(); - this.testName = "factory-test-" + ts; + @Test + public void testSingleton() throws ExecutionException, InterruptedException { + Map config = new HashMap<>(); + Map persistConfig = new HashMap<>(); + config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), testName); + config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); + config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); + config.put(FileConfigKeys.JSON_CONFIG.getKey(), GsonUtil.toJson(persistConfig)); + + GraphMetaType tag = + new GraphMetaType( + Types.STRING, ValueVertex.class, String.class, ValueEdge.class, String.class); + GraphStateDescriptor desc = GraphStateDescriptor.build(testName, StoreType.ROCKSDB.name()); + desc.withGraphMeta(new GraphMeta(tag)) + .withKeyGroup(new KeyGroup(0, 1)) + .withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); + desc.withSingleton(); + ExecutorService executors = Executors.getExecutorService(8, "testSingleton"); + List> list = new ArrayList<>(); + for (int i = 0; i < 8; i++) { + Future future = + executors.submit(() -> StateFactory.buildGraphState(desc, new Configuration(config))); + list.add(future); + } + GraphState graphState = list.get(0).get(); + for (int i = 1; i < 8; i++) { + Assert.assertTrue(graphState == list.get(i).get()); } - @AfterMethod - public void tearDown() throws IOException { - FileUtils.deleteQuietly(new File("/tmp/" + testName)); + graphState.manage().operate().setCheckpointId(1); + for (int i = 0; i < 10000; i++) { + graphState.staticGraph().V().add(new ValueVertex<>("hello" + i, "hello")); + graphState.staticGraph().E().add(new ValueEdge<>("hello" + i, "world" + i, "hello" + i)); } + graphState.manage().operate().finish(); + graphState.manage().operate().archive(); + graphState.manage().operate().drop(); - @Test - public void testSingleton() throws ExecutionException, InterruptedException { - Map config = new HashMap<>(); - Map persistConfig = new HashMap<>(); - config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), testName); - config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); - config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); - config.put(FileConfigKeys.JSON_CONFIG.getKey(), GsonUtil.toJson(persistConfig)); - - GraphMetaType tag = new GraphMetaType(Types.STRING, ValueVertex.class, - String.class, ValueEdge.class, String.class); - GraphStateDescriptor desc = GraphStateDescriptor.build(testName, StoreType.ROCKSDB.name()); - desc.withGraphMeta(new GraphMeta(tag)).withKeyGroup(new KeyGroup(0, 1)) - .withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); - desc.withSingleton(); - ExecutorService executors = Executors.getExecutorService(8, "testSingleton"); - List> list = new ArrayList<>(); - for (int i = 0; i < 8; i++) { - Future future = executors.submit( - () -> StateFactory.buildGraphState(desc, new Configuration(config))); - list.add(future); - } - GraphState graphState = list.get(0).get(); - for (int i = 1; i < 8; i++) { - Assert.assertTrue(graphState == list.get(i).get()); - } - - graphState.manage().operate().setCheckpointId(1); - for (int i = 0; i < 10000; i++) { - graphState.staticGraph().V().add(new ValueVertex<>("hello" + i, "hello")); - graphState.staticGraph().E().add(new ValueEdge<>("hello" + i, "world" + i, "hello" + i)); - } - graphState.manage().operate().finish(); - graphState.manage().operate().archive(); - graphState.manage().operate().drop(); - - final GraphState graphState2 = StateFactory.buildGraphState(desc, new Configuration(config)); - List list2 = new ArrayList<>(); - AtomicInteger count = new AtomicInteger(0); - for (int i = 0; i < 8; i++) { - final int keyGroupId = i % 2; - list2.add(executors.submit(() -> { - graphState2.manage().operate().load(LoadOption.of().withKeyGroup(new KeyGroup(keyGroupId, keyGroupId)).withCheckpointId(1L)); - int size = Iterators.size(graphState2.staticGraph().V().query(new KeyGroup(keyGroupId, keyGroupId)).idIterator()); + final GraphState graphState2 = StateFactory.buildGraphState(desc, new Configuration(config)); + List list2 = new ArrayList<>(); + AtomicInteger count = new AtomicInteger(0); + for (int i = 0; i < 8; i++) { + final int keyGroupId = i % 2; + list2.add( + executors.submit( + () -> { + graphState2 + .manage() + .operate() + .load( + LoadOption.of() + .withKeyGroup(new KeyGroup(keyGroupId, keyGroupId)) + .withCheckpointId(1L)); + int size = + Iterators.size( + graphState2 + .staticGraph() + .V() + .query(new KeyGroup(keyGroupId, keyGroupId)) + .idIterator()); count.addAndGet(size); - })); - } - for (Future f : list2) { - f.get(); - } - Assert.assertEquals(count.get(), 40000); - graphState2.manage().operate().drop(); - executors.shutdown(); - - FileUtils.deleteQuietly(new File("/tmp/geaflow/chk")); + })); } - - @Test - public void testGraphState() throws Exception { - Map config = new HashMap<>(); - - GraphMetaType tag = new GraphMetaType(Types.STRING, ValueVertex.class, - String.class, ValueEdge.class, String.class); - GraphStateDescriptor desc = GraphStateDescriptor.build(testName, StoreType.MEMORY.name()); - desc.withGraphMeta(new GraphMeta(tag)).withKeyGroup(new KeyGroup(0, 0)) - .withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration(config)); - - graphState.manage().operate().setCheckpointId(1); - for (int i = 0; i < 100; i++) { - graphState.staticGraph().V().add(new ValueVertex<>("hello" + ts, "hello")); - graphState.staticGraph().E().add(new ValueEdge<>("hello" + ts, "world" + i, - "hello" + i)); - } - graphState.staticGraph().V().add(new ValueVertex<>("world" + ts, "world")); - - Iterator> it = graphState.staticGraph().VE().query( - "hello" + ts, "world" + ts).iterator(); - Map> res = new HashMap<>(); - while (it.hasNext()) { - OneDegreeGraph oneDegreeGraph = it.next(); - res.put(oneDegreeGraph.getKey(), oneDegreeGraph); - } - - Assert.assertEquals(res.size(), 2); - OneDegreeGraph oneDegreeGraph = res.get("hello" + ts); - Assert.assertEquals(oneDegreeGraph.getVertex().getValue(), "hello"); - Assert.assertEquals(Iterators.size(oneDegreeGraph.getEdgeIterator()), 100); - oneDegreeGraph = res.get("world" + ts); - Assert.assertEquals(oneDegreeGraph.getVertex().getValue(), "world"); - Assert.assertEquals(Iterators.size(oneDegreeGraph.getEdgeIterator()), 0); + for (Future f : list2) { + f.get(); } + Assert.assertEquals(count.get(), 40000); + graphState2.manage().operate().drop(); + executors.shutdown(); - @Test - public void testKeyValueState() { - Map config = new HashMap<>(); - KeyValueStateDescriptor desc = - KeyValueStateDescriptor.build(testName, StoreType.MEMORY.name()); - desc.withDefaultValue(() -> "foobar").withKeyGroup(new KeyGroup(0, 0)) - .withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - KeyValueState valueState = StateFactory.buildKeyValueState(desc, - new Configuration(config)); - - valueState.put("hello", "world"); - Assert.assertEquals(valueState.get("hello"), "world"); - Assert.assertEquals(valueState.get("foo"), "foobar"); - } + FileUtils.deleteQuietly(new File("/tmp/geaflow/chk")); + } - @Test - public void testKeyListState() { - Map config = new HashMap<>(); - KeyListStateDescriptor desc = - KeyListStateDescriptor.build(testName, StoreType.MEMORY.name()); - desc.withKeyGroup(new KeyGroup(0, 0)) - .withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - ; - KeyListState listState = StateFactory.buildKeyListState(desc, - new Configuration(config)); - - listState.add("hello", "world"); - listState.put("foo", Arrays.asList("bar1", "bar2")); - Assert.assertEquals(listState.get("hello"), Arrays.asList("world")); - Assert.assertEquals(listState.get("foo"), Arrays.asList("bar1", "bar2")); - listState.remove("foo"); - Assert.assertEquals(listState.get("foo").size(), 0); - } + @Test + public void testGraphState() throws Exception { + Map config = new HashMap<>(); - @Test - public void testKeyMapState() { - Map config = new HashMap<>(); - KeyMapStateDescriptor desc = - KeyMapStateDescriptor.build(testName, StoreType.MEMORY.name()); - desc.withKeyGroup(new KeyGroup(0, 0)) - .withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - ; - KeyMapState mapState = StateFactory.buildKeyMapState(desc, - new Configuration(config)); - - mapState.put("hello", config); - mapState.add("foo", "bar1", "bar2"); - Assert.assertEquals(mapState.get("hello").size(), config.size()); - Assert.assertEquals(mapState.get("foo").get("bar1"), "bar2"); - mapState.remove("hello"); - Assert.assertEquals(mapState.get("hello").size(), 0); - } + GraphMetaType tag = + new GraphMetaType( + Types.STRING, ValueVertex.class, String.class, ValueEdge.class, String.class); + GraphStateDescriptor desc = GraphStateDescriptor.build(testName, StoreType.MEMORY.name()); + desc.withGraphMeta(new GraphMeta(tag)) + .withKeyGroup(new KeyGroup(0, 0)) + .withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration(config)); - @Test(expectedExceptions = IllegalArgumentException.class) - public void testUnsupportedStore_0() { - KeyValueStateDescriptor desc = - KeyValueStateDescriptor.build(testName, StoreType.REDIS.name()); - StateFactory.buildKeyValueState(desc, new Configuration()); + graphState.manage().operate().setCheckpointId(1); + for (int i = 0; i < 100; i++) { + graphState.staticGraph().V().add(new ValueVertex<>("hello" + ts, "hello")); + graphState.staticGraph().E().add(new ValueEdge<>("hello" + ts, "world" + i, "hello" + i)); } + graphState.staticGraph().V().add(new ValueVertex<>("world" + ts, "world")); - @Test(expectedExceptions = IllegalArgumentException.class) - public void testUnsupportedStore_1() { - KeyListStateDescriptor desc = - KeyListStateDescriptor.build(testName, StoreType.REDIS.name()); - StateFactory.buildKeyListState(desc, new Configuration()); + Iterator> it = + graphState.staticGraph().VE().query("hello" + ts, "world" + ts).iterator(); + Map> res = new HashMap<>(); + while (it.hasNext()) { + OneDegreeGraph oneDegreeGraph = it.next(); + res.put(oneDegreeGraph.getKey(), oneDegreeGraph); } - @Test(expectedExceptions = IllegalArgumentException.class) - public void testUnsupportedStore_2() { - KeyMapStateDescriptor desc = - KeyMapStateDescriptor.build(testName, StoreType.REDIS.name()); - StateFactory.buildKeyMapState(desc, new Configuration()); - } + Assert.assertEquals(res.size(), 2); + OneDegreeGraph oneDegreeGraph = res.get("hello" + ts); + Assert.assertEquals(oneDegreeGraph.getVertex().getValue(), "hello"); + Assert.assertEquals(Iterators.size(oneDegreeGraph.getEdgeIterator()), 100); + oneDegreeGraph = res.get("world" + ts); + Assert.assertEquals(oneDegreeGraph.getVertex().getValue(), "world"); + Assert.assertEquals(Iterators.size(oneDegreeGraph.getEdgeIterator()), 0); + } - @Test(expectedExceptions = Exception.class) - public void testUnsupportedStore_3() { - GraphStateDescriptor desc = GraphStateDescriptor.build(testName, StoreType.REDIS.name()); - StateFactory.buildGraphState(desc, new Configuration()); - } + @Test + public void testKeyValueState() { + Map config = new HashMap<>(); + KeyValueStateDescriptor desc = + KeyValueStateDescriptor.build(testName, StoreType.MEMORY.name()); + desc.withDefaultValue(() -> "foobar") + .withKeyGroup(new KeyGroup(0, 0)) + .withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + KeyValueState valueState = + StateFactory.buildKeyValueState(desc, new Configuration(config)); + + valueState.put("hello", "world"); + Assert.assertEquals(valueState.get("hello"), "world"); + Assert.assertEquals(valueState.get("foo"), "foobar"); + } + + @Test + public void testKeyListState() { + Map config = new HashMap<>(); + KeyListStateDescriptor desc = + KeyListStateDescriptor.build(testName, StoreType.MEMORY.name()); + desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + ; + KeyListState listState = + StateFactory.buildKeyListState(desc, new Configuration(config)); + + listState.add("hello", "world"); + listState.put("foo", Arrays.asList("bar1", "bar2")); + Assert.assertEquals(listState.get("hello"), Arrays.asList("world")); + Assert.assertEquals(listState.get("foo"), Arrays.asList("bar1", "bar2")); + listState.remove("foo"); + Assert.assertEquals(listState.get("foo").size(), 0); + } + + @Test + public void testKeyMapState() { + Map config = new HashMap<>(); + KeyMapStateDescriptor desc = + KeyMapStateDescriptor.build(testName, StoreType.MEMORY.name()); + desc.withKeyGroup(new KeyGroup(0, 0)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + ; + KeyMapState mapState = + StateFactory.buildKeyMapState(desc, new Configuration(config)); + + mapState.put("hello", config); + mapState.add("foo", "bar1", "bar2"); + Assert.assertEquals(mapState.get("hello").size(), config.size()); + Assert.assertEquals(mapState.get("foo").get("bar1"), "bar2"); + mapState.remove("hello"); + Assert.assertEquals(mapState.get("hello").size(), 0); + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testUnsupportedStore_0() { + KeyValueStateDescriptor desc = + KeyValueStateDescriptor.build(testName, StoreType.REDIS.name()); + StateFactory.buildKeyValueState(desc, new Configuration()); + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testUnsupportedStore_1() { + KeyListStateDescriptor desc = + KeyListStateDescriptor.build(testName, StoreType.REDIS.name()); + StateFactory.buildKeyListState(desc, new Configuration()); + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testUnsupportedStore_2() { + KeyMapStateDescriptor desc = + KeyMapStateDescriptor.build(testName, StoreType.REDIS.name()); + StateFactory.buildKeyMapState(desc, new Configuration()); + } + + @Test(expectedExceptions = Exception.class) + public void testUnsupportedStore_3() { + GraphStateDescriptor desc = GraphStateDescriptor.build(testName, StoreType.REDIS.name()); + StateFactory.buildGraphState(desc, new Configuration()); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/StaticGraphStateTest.java b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/StaticGraphStateTest.java index 316381dcf..909ae8e3c 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/StaticGraphStateTest.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/StaticGraphStateTest.java @@ -19,14 +19,12 @@ package org.apache.geaflow.state; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Iterators; -import com.google.common.primitives.Longs; import java.io.File; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -60,275 +58,400 @@ import org.testng.annotations.Factory; import org.testng.annotations.Test; -public class StaticGraphStateTest { - - private final Map additionalConfig; - private final StoreType storeType; - - @AfterMethod - public void tearDown() { - FileUtils.deleteQuietly(new File("/tmp/StaticGraphStateTest")); - } +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterators; +import com.google.common.primitives.Longs; - public StaticGraphStateTest(StoreType storeType, Map config) { - this.storeType = storeType; - this.additionalConfig = config; - } +public class StaticGraphStateTest { - public static class GraphStateTestFactory { - - @Factory - public Object[] factoryMethod() { - return new Object[]{new StaticGraphStateTest(StoreType.MEMORY, new HashMap<>()), - new StaticGraphStateTest(StoreType.MEMORY, - ImmutableMap.of(MemoryConfigKeys.CSR_MEMORY_ENABLE.getKey(), "true")), - new StaticGraphStateTest(StoreType.ROCKSDB, - ImmutableMap.of(ExecutionConfigKeys.JOB_APP_NAME.getKey(), - "StaticGraphStateTest"))}; - } + private final Map additionalConfig; + private final StoreType storeType; + + @AfterMethod + public void tearDown() { + FileUtils.deleteQuietly(new File("/tmp/StaticGraphStateTest")); + } + + public StaticGraphStateTest(StoreType storeType, Map config) { + this.storeType = storeType; + this.additionalConfig = config; + } + + public static class GraphStateTestFactory { + + @Factory + public Object[] factoryMethod() { + return new Object[] { + new StaticGraphStateTest(StoreType.MEMORY, new HashMap<>()), + new StaticGraphStateTest( + StoreType.MEMORY, ImmutableMap.of(MemoryConfigKeys.CSR_MEMORY_ENABLE.getKey(), "true")), + new StaticGraphStateTest( + StoreType.ROCKSDB, + ImmutableMap.of(ExecutionConfigKeys.JOB_APP_NAME.getKey(), "StaticGraphStateTest")) + }; } - - @Test - public void testNormalGet() { - GraphStateDescriptor desc = GraphStateDescriptor.build("test1", - storeType.name()); - desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); - desc.withGraphMeta(new GraphMeta( - new GraphMetaType<>(Types.STRING, ValueVertex.class, String.class, ValueEdge.class, + } + + @Test + public void testNormalGet() { + GraphStateDescriptor desc = + GraphStateDescriptor.build("test1", storeType.name()); + desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); + desc.withGraphMeta( + new GraphMeta( + new GraphMetaType<>( + Types.STRING, ValueVertex.class, String.class, ValueEdge.class, String.class))); + Map config = new HashMap<>(additionalConfig); + + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration(config)); + + graphState.manage().operate().setCheckpointId(1); + + graphState.staticGraph().E().add(new ValueEdge<>("1", "2", "hello", EdgeDirection.IN)); + graphState.staticGraph().E().add(new ValueEdge<>("1", "3", "hello", EdgeDirection.OUT)); + graphState.staticGraph().E().add(new ValueEdge<>("2", "2", "world", EdgeDirection.IN)); + graphState.staticGraph().E().add(new ValueEdge<>("2", "3", "world", EdgeDirection.OUT)); + graphState.staticGraph().V().add(new ValueVertex<>("1", "3")); + graphState.staticGraph().V().add(new ValueVertex<>("2", "4")); + graphState.manage().operate().finish(); + + List> list = graphState.staticGraph().E().query("1").asList(); + Assert.assertEquals(list.size(), 2); + + Iterator> iterator = graphState.staticGraph().V().iterator(); + Assert.assertEquals(Iterators.size(iterator), 2); + + IVertex vertex = graphState.staticGraph().V().query("1").get(); + Assert.assertEquals(vertex.getValue(), "3"); + + // keyGroup get + iterator = graphState.staticGraph().V().query(new KeyGroup(0, 0)).iterator(); + Assert.assertEquals(Iterators.size(iterator), 1); + + Iterator idIterator = + graphState.staticGraph().V().query(new KeyGroup(0, 0)).idIterator(); + Assert.assertEquals(Iterators.size(idIterator), 1); + + list = graphState.staticGraph().E().query(new KeyGroup(0, 0)).asList(); + Assert.assertEquals(list.size(), 2); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } + + @Test + public void testFilter() { + GraphStateDescriptor desc = + GraphStateDescriptor.build("testFilter", storeType.name()); + desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); + desc.withGraphMeta( + new GraphMeta( + new GraphMetaType<>( + Types.STRING, + ValueVertex.class, + String.class, + ValueLabelTimeEdge.class, String.class))); - Map config = new HashMap<>(additionalConfig); - - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration(config)); - - graphState.manage().operate().setCheckpointId(1); + Map config = new HashMap<>(additionalConfig); - graphState.staticGraph().E().add(new ValueEdge<>("1", "2", "hello", EdgeDirection.IN)); - graphState.staticGraph().E().add(new ValueEdge<>("1", "3", "hello", EdgeDirection.OUT)); - graphState.staticGraph().E().add(new ValueEdge<>("2", "2", "world", EdgeDirection.IN)); - graphState.staticGraph().E().add(new ValueEdge<>("2", "3", "world", EdgeDirection.OUT)); - graphState.staticGraph().V().add(new ValueVertex<>("1", "3")); - graphState.staticGraph().V().add(new ValueVertex<>("2", "4")); - graphState.manage().operate().finish(); + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration(config)); - List> list = graphState.staticGraph().E().query("1").asList(); - Assert.assertEquals(list.size(), 2); + graphState.manage().operate().setCheckpointId(1); - Iterator> iterator = graphState.staticGraph().V().iterator(); - Assert.assertEquals(Iterators.size(iterator), 2); + graphState.staticGraph().E().add(new ValueLabelTimeEdge<>("1", "2", "hello", "foo", 1000)); + graphState.staticGraph().E().add(new ValueLabelTimeEdge<>("1", "3", "hello", "bar", 100)); + graphState.staticGraph().E().add(new ValueLabelTimeEdge<>("2", "2", "world", "foo", 1000)); + graphState.staticGraph().E().add(new ValueLabelTimeEdge<>("2", "3", "world", "bar", 100)); + graphState.staticGraph().V().add(new ValueVertex<>("1", "3")); + graphState.staticGraph().V().add(new ValueVertex<>("2", "4")); - IVertex vertex = graphState.staticGraph().V().query("1").get(); - Assert.assertEquals(vertex.getValue(), "3"); + graphState.manage().operate().finish(); - // keyGroup get - iterator = graphState.staticGraph().V().query(new KeyGroup(0, 0)).iterator(); - Assert.assertEquals(Iterators.size(iterator), 1); - - Iterator idIterator = graphState.staticGraph().V().query(new KeyGroup(0, 0)) - .idIterator(); - Assert.assertEquals(Iterators.size(idIterator), 1); - - list = graphState.staticGraph().E().query(new KeyGroup(0, 0)).asList(); - Assert.assertEquals(list.size(), 2); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); + List> list = + graphState + .staticGraph() + .E() + .query("1", "2") + .by(new EdgeTsFilter<>(TimeRange.of(0, 500))) + .asList(); + Assert.assertEquals(list.size(), 2); + + list = + graphState + .staticGraph() + .E() + .query("1", "2") + .by( + new EdgeTsFilter<>(TimeRange.of(0, 500)) + .or(new EdgeTsFilter<>(TimeRange.of(800, 1100)))) + .asList(); + Assert.assertEquals(list.size(), 4); + + list = + graphState + .staticGraph() + .E() + .query("1", "2") + .by( + new IFilter[] { + new EdgeTsFilter<>(TimeRange.of(0, 500)), + new EdgeTsFilter<>(TimeRange.of(800, 1100)) + }) + .asList(); + Assert.assertEquals(list.size(), 2); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } + + @Test + public void testLimitAndSort() { + GraphStateDescriptor desc = + GraphStateDescriptor.build("testLimitAndSort", storeType.name()); + desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); + desc.withGraphMeta( + new GraphMeta( + new GraphMetaType<>( + Types.STRING, ValueVertex.class, String.class, ValueTimeEdge.class, String.class))); + Map config = new HashMap<>(additionalConfig); + config.put( + StateConfigKeys.STATE_KV_ENCODER_EDGE_ORDER.getKey(), + "SRC_ID, DESC_TIME, DIRECTION, DST_ID"); + + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration(config)); + + graphState.manage().operate().setCheckpointId(1); + + for (int i = 0; i < 100; i++) { + String src = Integer.toString(i); + for (int j = 1; j < 100; j++) { + String dst = Integer.toString(j); + graphState + .staticGraph() + .E() + .add( + new ValueTimeEdge<>( + src, + dst, + "hello" + src + dst, + EdgeDirection.values()[j % 2], + i >= 10 ? j : (i + 1) * 100 + j)); + } + graphState.staticGraph().V().add(new ValueVertex<>(src, "world" + src)); } + graphState.manage().operate().finish(); + + // key limit + List> list = + graphState.staticGraph().E().query("1", "2", "3").limit(1L, 1L).asList(); + Assert.assertEquals(list.size(), 6); + + list = + graphState + .staticGraph() + .E() + .query("1", "2", "3") + .by(InEdgeFilter.getInstance()) + .limit(1L, 1L) + .asList(); + Assert.assertEquals(list.size(), 3); - @Test - public void testFilter() { - GraphStateDescriptor desc = GraphStateDescriptor.build("testFilter", - storeType.name()); - desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); - desc.withGraphMeta(new GraphMeta( - new GraphMetaType<>(Types.STRING, ValueVertex.class, String.class, - ValueLabelTimeEdge.class, String.class))); - Map config = new HashMap<>(additionalConfig); - - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration(config)); - - graphState.manage().operate().setCheckpointId(1); - - graphState.staticGraph().E().add(new ValueLabelTimeEdge<>("1", "2", "hello", "foo", 1000)); - graphState.staticGraph().E().add(new ValueLabelTimeEdge<>("1", "3", "hello", "bar", 100)); - graphState.staticGraph().E().add(new ValueLabelTimeEdge<>("2", "2", "world", "foo", 1000)); - graphState.staticGraph().E().add(new ValueLabelTimeEdge<>("2", "3", "world", "bar", 100)); - graphState.staticGraph().V().add(new ValueVertex<>("1", "3")); - graphState.staticGraph().V().add(new ValueVertex<>("2", "4")); - - graphState.manage().operate().finish(); - - List> list = graphState.staticGraph().E().query("1", "2") - .by(new EdgeTsFilter<>(TimeRange.of(0, 500))).asList(); - Assert.assertEquals(list.size(), 2); - - list = graphState.staticGraph().E().query("1", "2") - .by(new EdgeTsFilter<>(TimeRange.of(0, 500)).or( - new EdgeTsFilter<>(TimeRange.of(800, 1100)))).asList(); - Assert.assertEquals(list.size(), 4); - - list = graphState.staticGraph().E().query("1", "2") - .by(new IFilter[]{new EdgeTsFilter<>(TimeRange.of(0, 500)), - new EdgeTsFilter<>(TimeRange.of(800, 1100))}).asList(); - Assert.assertEquals(list.size(), 2); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - } + list = graphState.staticGraph().E().query("1").limit(1L, 1L).asList(); + Assert.assertEquals(list.size(), 2); - @Test - public void testLimitAndSort() { - GraphStateDescriptor desc = GraphStateDescriptor.build( - "testLimitAndSort", storeType.name()); - desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); - desc.withGraphMeta(new GraphMeta( - new GraphMetaType<>(Types.STRING, ValueVertex.class, String.class, ValueTimeEdge.class, - String.class))); - Map config = new HashMap<>(additionalConfig); - config.put(StateConfigKeys.STATE_KV_ENCODER_EDGE_ORDER.getKey(), - "SRC_ID, DESC_TIME, DIRECTION, DST_ID"); - - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration(config)); - - graphState.manage().operate().setCheckpointId(1); - - for (int i = 0; i < 100; i++) { - String src = Integer.toString(i); - for (int j = 1; j < 100; j++) { - String dst = Integer.toString(j); - graphState.staticGraph().E().add(new ValueTimeEdge<>(src, dst, "hello" + src + dst, - EdgeDirection.values()[j % 2], i >= 10 ? j : (i + 1) * 100 + j)); - } - graphState.staticGraph().V().add(new ValueVertex<>(src, "world" + src)); - } - graphState.manage().operate().finish(); - - // key limit - List> list = graphState.staticGraph().E().query("1", "2", "3") - .limit(1L, 1L).asList(); - Assert.assertEquals(list.size(), 6); - - list = graphState.staticGraph().E().query("1", "2", "3").by(InEdgeFilter.getInstance()) - .limit(1L, 1L).asList(); - Assert.assertEquals(list.size(), 3); - - list = graphState.staticGraph().E().query("1").limit(1L, 1L).asList(); - Assert.assertEquals(list.size(), 2); - - list = graphState.staticGraph().E().query("11", "12", "13") + list = + graphState + .staticGraph() + .E() + .query("11", "12", "13") .by(EdgeTsFilter.getInstance(10, 20).or(EdgeTsFilter.getInstance(50, 60)).singleLimit()) - .limit(2L, 1L).asList(); - Assert.assertEquals(list.size(), 18); - - list = graphState.staticGraph().E().query("11", "12", "13") - .by(EdgeTsFilter.getInstance(10, 20).or(EdgeTsFilter.getInstance(50, 60))).limit(2L, 1L) + .limit(2L, 1L) .asList(); - Assert.assertEquals(list.size(), 9); - - // full limit - list = graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()).limit(1L, 1L) + Assert.assertEquals(list.size(), 18); + + list = + graphState + .staticGraph() + .E() + .query("11", "12", "13") + .by(EdgeTsFilter.getInstance(10, 20).or(EdgeTsFilter.getInstance(50, 60))) + .limit(2L, 1L) .asList(); - Assert.assertEquals(list.size(), 100); - - list = graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()).limit(1L, 2L) - .asList(); - Assert.assertEquals(list.size(), 200); - - list = graphState.staticGraph().E().query() + Assert.assertEquals(list.size(), 9); + + // full limit + list = + graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()).limit(1L, 1L).asList(); + Assert.assertEquals(list.size(), 100); + + list = + graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()).limit(1L, 2L).asList(); + Assert.assertEquals(list.size(), 200); + + list = + graphState + .staticGraph() + .E() + .query() .by(EdgeTsFilter.getInstance(10, 20).or(EdgeTsFilter.getInstance(50, 60)).singleLimit()) - .limit(2L, 1L).asList(); - Assert.assertEquals(list.size(), 540); - - list = graphState.staticGraph().E().query() - .by(EdgeTsFilter.getInstance(10, 20).or(EdgeTsFilter.getInstance(50, 60))).limit(2L, 1L) + .limit(2L, 1L) .asList(); - Assert.assertEquals(list.size(), 270); - - // sort keys - long[] times = Longs.toArray(graphState.staticGraph().E().query("1", "2", "3").limit(3L, 3L) - .orderBy(EdgeAtom.DESC_TIME).select(new TimeProjector<>()).asList()); - Assert.assertEquals(times.length, 18); - Assert.assertEquals(Longs.max(times), 4 * 100 + 99L); - Assert.assertEquals(Longs.min(times), 294L); - - // sort all - times = Longs.toArray( - graphState.staticGraph().E().query().limit(0L, 1L).orderBy(EdgeAtom.DESC_TIME) - .select(new TimeProjector<>()).asList()); - Assert.assertEquals(times.length, 100); - Assert.assertEquals(Longs.max(times), 10 * 100 + 98L); - Assert.assertEquals(Longs.min(times), 98L); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - } - - @Test - public void testProjectAndAgg() { - GraphStateDescriptor desc = GraphStateDescriptor.build( - "testProjectAndAgg", storeType.name()); - desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); - desc.withGraphMeta(new GraphMeta( - new GraphMetaType<>(Types.STRING, ValueVertex.class, String.class, ValueLabelEdge.class, + Assert.assertEquals(list.size(), 540); + + list = + graphState + .staticGraph() + .E() + .query() + .by(EdgeTsFilter.getInstance(10, 20).or(EdgeTsFilter.getInstance(50, 60))) + .limit(2L, 1L) + .asList(); + Assert.assertEquals(list.size(), 270); + + // sort keys + long[] times = + Longs.toArray( + graphState + .staticGraph() + .E() + .query("1", "2", "3") + .limit(3L, 3L) + .orderBy(EdgeAtom.DESC_TIME) + .select(new TimeProjector<>()) + .asList()); + Assert.assertEquals(times.length, 18); + Assert.assertEquals(Longs.max(times), 4 * 100 + 99L); + Assert.assertEquals(Longs.min(times), 294L); + + // sort all + times = + Longs.toArray( + graphState + .staticGraph() + .E() + .query() + .limit(0L, 1L) + .orderBy(EdgeAtom.DESC_TIME) + .select(new TimeProjector<>()) + .asList()); + Assert.assertEquals(times.length, 100); + Assert.assertEquals(Longs.max(times), 10 * 100 + 98L); + Assert.assertEquals(Longs.min(times), 98L); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } + + @Test + public void testProjectAndAgg() { + GraphStateDescriptor desc = + GraphStateDescriptor.build("testProjectAndAgg", storeType.name()); + desc.withKeyGroup(new KeyGroup(0, 1)).withKeyGroupAssigner(new DefaultKeyGroupAssigner(2)); + desc.withGraphMeta( + new GraphMeta( + new GraphMetaType<>( + Types.STRING, + ValueVertex.class, + String.class, + ValueLabelEdge.class, String.class))); - Map config = new HashMap<>(additionalConfig); - - GraphState graphState = StateFactory.buildGraphState(desc, - new Configuration(config)); - - graphState.manage().operate().setCheckpointId(1); - String[] labels = new String[]{"teacher", "student", "president"}; - - for (int i = 0; i < 10; i++) { - String src = Integer.toString(i); - for (int j = 1; j < 10; j++) { - String dst = Integer.toString(j); - graphState.staticGraph().E().add(new ValueLabelEdge<>(src, dst, "hello" + src + dst, - EdgeDirection.values()[j % 2], labels[j % 3])); - } - graphState.staticGraph().V().add(new ValueVertex<>(src, "world" + src)); - } - graphState.manage().operate().finish(); - - // project test - List targetIds = graphState.staticGraph().E().query().by(InEdgeFilter.getInstance()) - .select(new DstIdProjector<>()).limit(1L, 2L).asList(); + Map config = new HashMap<>(additionalConfig); + + GraphState graphState = + StateFactory.buildGraphState(desc, new Configuration(config)); + + graphState.manage().operate().setCheckpointId(1); + String[] labels = new String[] {"teacher", "student", "president"}; + + for (int i = 0; i < 10; i++) { + String src = Integer.toString(i); + for (int j = 1; j < 10; j++) { + String dst = Integer.toString(j); + graphState + .staticGraph() + .E() + .add( + new ValueLabelEdge<>( + src, dst, "hello" + src + dst, EdgeDirection.values()[j % 2], labels[j % 3])); + } + graphState.staticGraph().V().add(new ValueVertex<>(src, "world" + src)); + } + graphState.manage().operate().finish(); + + // project test + List targetIds = + graphState + .staticGraph() + .E() + .query() + .by(InEdgeFilter.getInstance()) + .select(new DstIdProjector<>()) + .limit(1L, 2L) + .asList(); - Assert.assertEquals(targetIds.size(), 20); + Assert.assertEquals(targetIds.size(), 20); - targetIds = graphState.staticGraph().E().query() + targetIds = + graphState + .staticGraph() + .E() + .query() .by(InEdgeFilter.getInstance().or(OutEdgeFilter.getInstance()).singleLimit()) - .select(new DstIdProjector<>()).limit(1L, 2L).asList(); - Assert.assertEquals(targetIds.size(), 30); - - // full agg test - Map res = graphState.staticGraph().E().query() - .by(InEdgeFilter.getInstance().and(new EdgeLabelFilter("teacher"))).aggregate(); - Assert.assertEquals(res.size(), 10); - Assert.assertTrue(res.get("2") == 1L); - - res = graphState.staticGraph().E().query() - .by(OutEdgeFilter.getInstance().and(new EdgeLabelFilter("student"))).aggregate(); - Assert.assertEquals(res.size(), 10); - Assert.assertTrue(res.get("2") == 2L); - - // key agg test - res = graphState.staticGraph().E().query("2", "5") - .by(InEdgeFilter.getInstance().and(new EdgeLabelFilter("teacher"))).aggregate(); - - Assert.assertEquals(res.size(), 2); - Assert.assertTrue(res.get("2") == 1L); - - res = graphState.staticGraph().E().query("2", "5") - .by(InEdgeFilter.getInstance().and(new EdgeLabelFilter("teacher")), - OutEdgeFilter.getInstance().and(new EdgeLabelFilter("student"))).aggregate(); - Assert.assertEquals(res.size(), 2); - Assert.assertTrue(res.get("2") == 1L); - Assert.assertTrue(res.get("5") == 2L); - - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - } + .select(new DstIdProjector<>()) + .limit(1L, 2L) + .asList(); + Assert.assertEquals(targetIds.size(), 30); + + // full agg test + Map res = + graphState + .staticGraph() + .E() + .query() + .by(InEdgeFilter.getInstance().and(new EdgeLabelFilter("teacher"))) + .aggregate(); + Assert.assertEquals(res.size(), 10); + Assert.assertTrue(res.get("2") == 1L); + + res = + graphState + .staticGraph() + .E() + .query() + .by(OutEdgeFilter.getInstance().and(new EdgeLabelFilter("student"))) + .aggregate(); + Assert.assertEquals(res.size(), 10); + Assert.assertTrue(res.get("2") == 2L); + + // key agg test + res = + graphState + .staticGraph() + .E() + .query("2", "5") + .by(InEdgeFilter.getInstance().and(new EdgeLabelFilter("teacher"))) + .aggregate(); + + Assert.assertEquals(res.size(), 2); + Assert.assertTrue(res.get("2") == 1L); + + res = + graphState + .staticGraph() + .E() + .query("2", "5") + .by( + InEdgeFilter.getInstance().and(new EdgeLabelFilter("teacher")), + OutEdgeFilter.getInstance().and(new EdgeLabelFilter("student"))) + .aggregate(); + Assert.assertEquals(res.size(), 2); + Assert.assertTrue(res.get("2") == 1L); + Assert.assertTrue(res.get("5") == 2L); + + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/jmh/DirectStoreReadJMH.java b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/jmh/DirectStoreReadJMH.java index 436bae578..ddf243610 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/jmh/DirectStoreReadJMH.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/jmh/DirectStoreReadJMH.java @@ -21,6 +21,7 @@ import java.util.Iterator; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.common.type.Types; import org.apache.geaflow.model.graph.edge.EdgeDirection; @@ -62,79 +63,89 @@ @State(Scope.Benchmark) public class DirectStoreReadJMH extends JMHParameter { - IStaticGraphStore store; - IStatePushDown pushdown = StatePushDown.of(); + IStaticGraphStore store; + IStatePushDown pushdown = StatePushDown.of(); - @Setup(Level.Trial) - public void setUp() { - composeGraph(); - } + @Setup(Level.Trial) + public void setUp() { + composeGraph(); + } - @TearDown(Level.Trial) - public void tearDown() { - store.close(); - store.drop(); - } + @TearDown(Level.Trial) + public void tearDown() { + store.close(); + store.drop(); + } - public void composeGraph() { - GraphMetaType tag = new GraphMetaType(Types.INTEGER, ValueVertex.class, - Integer.class, ValueLabelTimeEdge.class, Integer.class); - StoreContext storeContext = new StoreContext("test"); - storeContext.withDataSchema(new GraphDataSchema(new GraphMeta(tag))); - configuration.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), getClass().getSimpleName()); - storeContext.withConfig(configuration); + public void composeGraph() { + GraphMetaType tag = + new GraphMetaType( + Types.INTEGER, + ValueVertex.class, + Integer.class, + ValueLabelTimeEdge.class, + Integer.class); + StoreContext storeContext = new StoreContext("test"); + storeContext.withDataSchema(new GraphDataSchema(new GraphMeta(tag))); + configuration.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), getClass().getSimpleName()); + storeContext.withConfig(configuration); - store = - (IStaticGraphStore) StoreBuilderFactory.build( - storeType).getStore(DataModel.STATIC_GRAPH, configuration); - store.init(storeContext); - for (int i = 0; i < vNum; i++) { - IVertex vertex = new ValueVertex<>(i, i); - store.addVertex(vertex); - for (int j = 1; j < outE; j++) { - IEdge edge = new ValueLabelTimeEdge<>(i, j, i, - i % 2 == 0 ? EdgeDirection.IN : EdgeDirection.OUT, - Integer.toString(i % 10), i + 10000000); - store.addEdge(edge); - } - } + store = + (IStaticGraphStore) + StoreBuilderFactory.build(storeType).getStore(DataModel.STATIC_GRAPH, configuration); + store.init(storeContext); + for (int i = 0; i < vNum; i++) { + IVertex vertex = new ValueVertex<>(i, i); + store.addVertex(vertex); + for (int j = 1; j < outE; j++) { + IEdge edge = + new ValueLabelTimeEdge<>( + i, + j, + i, + i % 2 == 0 ? EdgeDirection.IN : EdgeDirection.OUT, + Integer.toString(i % 10), + i + 10000000); + store.addEdge(edge); + } } + } - @Benchmark - public void getVertex(Blackhole blackhole) { - for (int i = 0; i < vNum; i++) { - blackhole.consume(store.getVertex(i, pushdown)); - } + @Benchmark + public void getVertex(Blackhole blackhole) { + for (int i = 0; i < vNum; i++) { + blackhole.consume(store.getVertex(i, pushdown)); } + } - @Benchmark - public void getEdges(Blackhole blackhole) { - for (int i = 0; i < vNum; i++) { - blackhole.consume(store.getEdges(i, pushdown)); - } + @Benchmark + public void getEdges(Blackhole blackhole) { + for (int i = 0; i < vNum; i++) { + blackhole.consume(store.getEdges(i, pushdown)); } + } - @Benchmark - public void getOneGraph(Blackhole blackhole) { - for (int i = 0; i < vNum; i++) { - blackhole.consume(store.getOneDegreeGraph(i, pushdown)); - } + @Benchmark + public void getOneGraph(Blackhole blackhole) { + for (int i = 0; i < vNum; i++) { + blackhole.consume(store.getOneDegreeGraph(i, pushdown)); } + } - @Benchmark - public void getVertexIterator(Blackhole blackhole) { - Iterator> it = store.getVertexIterator(pushdown); - while (it.hasNext()) { - blackhole.consume(it.next()); - } + @Benchmark + public void getVertexIterator(Blackhole blackhole) { + Iterator> it = store.getVertexIterator(pushdown); + while (it.hasNext()) { + blackhole.consume(it.next()); } + } - @Benchmark - public void getOneDegreeGraphIterator(Blackhole blackhole) { - Iterator> it = - store.getOneDegreeGraphIterator(pushdown); - while (it.hasNext()) { - blackhole.consume(it.next()); - } + @Benchmark + public void getOneDegreeGraphIterator(Blackhole blackhole) { + Iterator> it = + store.getOneDegreeGraphIterator(pushdown); + while (it.hasNext()) { + blackhole.consume(it.next()); } + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/jmh/JMHParameter.java b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/jmh/JMHParameter.java index dc3cd4660..be16e9b91 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/jmh/JMHParameter.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/jmh/JMHParameter.java @@ -21,8 +21,8 @@ import static org.apache.geaflow.common.config.keys.FrameworkConfigKeys.JOB_MAX_PARALLEL; -import com.google.common.collect.ImmutableMap; import java.util.HashMap; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.file.FileConfigKeys; @@ -30,23 +30,26 @@ import org.openjdk.jmh.annotations.Scope; import org.openjdk.jmh.annotations.State; +import com.google.common.collect.ImmutableMap; + @State(Scope.Benchmark) public class JMHParameter { - @Param({"rocksdb"}) - public String storeType; - - @Param({"1", "50"}) - public int outE; + @Param({"rocksdb"}) + public String storeType; - @Param({"10000"}) - public int vNum; + @Param({"1", "50"}) + public int outE; - public Configuration configuration = new Configuration(new HashMap<>(ImmutableMap.of( - ExecutionConfigKeys.JOB_WORK_PATH.getKey(), "/tmp", - FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL", - FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/", - JOB_MAX_PARALLEL.getKey(), "1" - ))); + @Param({"10000"}) + public int vNum; + public Configuration configuration = + new Configuration( + new HashMap<>( + ImmutableMap.of( + ExecutionConfigKeys.JOB_WORK_PATH.getKey(), "/tmp", + FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL", + FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/", + JOB_MAX_PARALLEL.getKey(), "1"))); } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/jmh/StaticGraphJMHRunner.java b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/jmh/StaticGraphJMHRunner.java index a2b5c3e64..2ae412341 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/jmh/StaticGraphJMHRunner.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/jmh/StaticGraphJMHRunner.java @@ -26,15 +26,16 @@ public class StaticGraphJMHRunner { - public static void main(String[] args) throws RunnerException { + public static void main(String[] args) throws RunnerException { - Options opt = new OptionsBuilder() + Options opt = + new OptionsBuilder() // 导入要测试的类 .include(DirectStoreReadJMH.class.getSimpleName()) - //.include(StaticGraphStateReadJMH10.class.getSimpleName()) - //.include(StaticGraphStateWriteJMH10.class.getSimpleName()) - //.include(DirectStoreReadJMH.class.getSimpleName()) + // .include(StaticGraphStateReadJMH10.class.getSimpleName()) + // .include(StaticGraphStateWriteJMH10.class.getSimpleName()) + // .include(DirectStoreReadJMH.class.getSimpleName()) .build(); - new Runner(opt).run(); - } + new Runner(opt).run(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/jmh/StaticGraphStateReadJMH10.java b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/jmh/StaticGraphStateReadJMH10.java index 3f229af69..a4d1420ef 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/jmh/StaticGraphStateReadJMH10.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/jmh/StaticGraphStateReadJMH10.java @@ -21,6 +21,7 @@ import java.util.Iterator; import java.util.concurrent.TimeUnit; + import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.common.type.Types; import org.apache.geaflow.model.graph.edge.EdgeDirection; @@ -59,86 +60,96 @@ @State(Scope.Benchmark) public class StaticGraphStateReadJMH10 extends JMHParameter { - GraphState graphState; + GraphState graphState; - @Setup(Level.Trial) - public void setUp() { - composeGraph(); - } + @Setup(Level.Trial) + public void setUp() { + composeGraph(); + } - @TearDown(Level.Trial) - public void tearDown() { - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - } + @TearDown(Level.Trial) + public void tearDown() { + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + } - public void composeGraph() { - GraphStateDescriptor desc = GraphStateDescriptor.build( - "StaticGraphStateJMH", storeType); - GraphMetaType tag = new GraphMetaType(Types.INTEGER, ValueVertex.class, - Integer.class, ValueLabelTimeEdge.class, Integer.class); - desc.withGraphMeta(new GraphMeta(tag)).withKeyGroup(new KeyGroup(0, 0)); - configuration.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), getClass().getSimpleName()); - graphState = StateFactory.buildGraphState(desc, configuration); + public void composeGraph() { + GraphStateDescriptor desc = + GraphStateDescriptor.build("StaticGraphStateJMH", storeType); + GraphMetaType tag = + new GraphMetaType( + Types.INTEGER, + ValueVertex.class, + Integer.class, + ValueLabelTimeEdge.class, + Integer.class); + desc.withGraphMeta(new GraphMeta(tag)).withKeyGroup(new KeyGroup(0, 0)); + configuration.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), getClass().getSimpleName()); + graphState = StateFactory.buildGraphState(desc, configuration); - graphState.manage().operate().setCheckpointId(1); - for (int i = 0; i < vNum; i++) { - IVertex vertex = new ValueVertex<>(i, i); - graphState.staticGraph().V().add(vertex); - for (int j = 1; j < outE; j++) { - IEdge edge = new ValueLabelTimeEdge<>(i, j, i, - i % 2 == 0 ? EdgeDirection.IN : EdgeDirection.OUT, - Integer.toString(i % 10), i + 10000000); - graphState.staticGraph().E().add(edge); - } - } - graphState.manage().operate().finish(); + graphState.manage().operate().setCheckpointId(1); + for (int i = 0; i < vNum; i++) { + IVertex vertex = new ValueVertex<>(i, i); + graphState.staticGraph().V().add(vertex); + for (int j = 1; j < outE; j++) { + IEdge edge = + new ValueLabelTimeEdge<>( + i, + j, + i, + i % 2 == 0 ? EdgeDirection.IN : EdgeDirection.OUT, + Integer.toString(i % 10), + i + 10000000); + graphState.staticGraph().E().add(edge); + } } + graphState.manage().operate().finish(); + } - @Benchmark - public void getVertex(Blackhole blackhole) { - for (int i = 0; i < vNum; i++) { - blackhole.consume(graphState.staticGraph().V().query(i).get()); - } + @Benchmark + public void getVertex(Blackhole blackhole) { + for (int i = 0; i < vNum; i++) { + blackhole.consume(graphState.staticGraph().V().query(i).get()); } + } - @Benchmark - public void getEdges(Blackhole blackhole) { - for (int i = 0; i < vNum; i++) { - blackhole.consume(graphState.staticGraph().E().query(i).asList()); - } + @Benchmark + public void getEdges(Blackhole blackhole) { + for (int i = 0; i < vNum; i++) { + blackhole.consume(graphState.staticGraph().E().query(i).asList()); } + } - @Benchmark - public void getOneGraph(Blackhole blackhole) { - for (int i = 0; i < vNum; i++) { - blackhole.consume(graphState.staticGraph().VE().query(i).get()); - } + @Benchmark + public void getOneGraph(Blackhole blackhole) { + for (int i = 0; i < vNum; i++) { + blackhole.consume(graphState.staticGraph().VE().query(i).get()); } + } - @Benchmark - public void getVertexIterator(Blackhole blackhole) { - Iterator> it = graphState.staticGraph().V().iterator(); - while (it.hasNext()) { - blackhole.consume(it.next()); - } + @Benchmark + public void getVertexIterator(Blackhole blackhole) { + Iterator> it = graphState.staticGraph().V().iterator(); + while (it.hasNext()) { + blackhole.consume(it.next()); } + } - //@Benchmark - public void getEdgeIterator(Blackhole blackhole) { - Iterator> it = graphState.staticGraph().E().iterator(); - while (it.hasNext()) { - blackhole.consume(it.next()); - } + // @Benchmark + public void getEdgeIterator(Blackhole blackhole) { + Iterator> it = graphState.staticGraph().E().iterator(); + while (it.hasNext()) { + blackhole.consume(it.next()); } + } - @Benchmark - public void getOneGraphIterator(Blackhole blackhole) { - Iterator> it = - graphState.staticGraph().VE().iterator(); - while (it.hasNext()) { - OneDegreeGraph next = it.next(); - blackhole.consume(next); - } + @Benchmark + public void getOneGraphIterator(Blackhole blackhole) { + Iterator> it = + graphState.staticGraph().VE().iterator(); + while (it.hasNext()) { + OneDegreeGraph next = it.next(); + blackhole.consume(next); } + } } diff --git a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/jmh/StaticGraphStateWriteJMH10.java b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/jmh/StaticGraphStateWriteJMH10.java index e855db5b1..96b6b4286 100644 --- a/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/jmh/StaticGraphStateWriteJMH10.java +++ b/geaflow/geaflow-state/geaflow-state-impl/src/test/java/org/apache/geaflow/state/jmh/StaticGraphStateWriteJMH10.java @@ -25,6 +25,7 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.TimeUnit; + import org.apache.commons.io.FileUtils; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; @@ -62,39 +63,48 @@ @State(Scope.Benchmark) public class StaticGraphStateWriteJMH10 extends JMHParameter { - GraphState graphState; + GraphState graphState; - @Benchmark - public void composeGraph() { - GraphStateDescriptor desc = GraphStateDescriptor.build( - "StaticGraphStateJMH", storeType); - GraphMetaType tag = new GraphMetaType(Types.INTEGER, ValueVertex.class, - Integer.class, ValueLabelTimeEdge.class, Integer.class); - desc.withGraphMeta(new GraphMeta(tag)).withKeyGroup(new KeyGroup(0, 0)); - desc.withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); - Map config = new HashMap<>(); - config.put(JOB_MAX_PARALLEL.getKey(), "1"); - config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), "StaticGraphStateJMH"); - config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); - config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); + @Benchmark + public void composeGraph() { + GraphStateDescriptor desc = + GraphStateDescriptor.build("StaticGraphStateJMH", storeType); + GraphMetaType tag = + new GraphMetaType( + Types.INTEGER, + ValueVertex.class, + Integer.class, + ValueLabelTimeEdge.class, + Integer.class); + desc.withGraphMeta(new GraphMeta(tag)).withKeyGroup(new KeyGroup(0, 0)); + desc.withKeyGroupAssigner(new DefaultKeyGroupAssigner(1)); + Map config = new HashMap<>(); + config.put(JOB_MAX_PARALLEL.getKey(), "1"); + config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), "StaticGraphStateJMH"); + config.put(FileConfigKeys.PERSISTENT_TYPE.getKey(), "LOCAL"); + config.put(FileConfigKeys.ROOT.getKey(), "/tmp/geaflow/chk/"); - graphState = StateFactory.buildGraphState(desc, new Configuration(config)); + graphState = StateFactory.buildGraphState(desc, new Configuration(config)); - graphState.manage().operate().setCheckpointId(1); - for (int i = 0; i < vNum; i++) { - IVertex vertex = new ValueVertex<>(i, i); - graphState.staticGraph().V().add(vertex); - for (int j = 1; j < outE; j++) { - IEdge edge = new ValueLabelTimeEdge<>(i, j, i, - i % 2 == 0 ? EdgeDirection.IN : EdgeDirection.OUT, - Integer.toString(i % 10), i + 10000000); - graphState.staticGraph().E().add(edge); - } - } - graphState.manage().operate().finish(); - graphState.manage().operate().close(); - graphState.manage().operate().drop(); - FileUtils.deleteQuietly(new File("/tmp/geaflow_store_local")); + graphState.manage().operate().setCheckpointId(1); + for (int i = 0; i < vNum; i++) { + IVertex vertex = new ValueVertex<>(i, i); + graphState.staticGraph().V().add(vertex); + for (int j = 1; j < outE; j++) { + IEdge edge = + new ValueLabelTimeEdge<>( + i, + j, + i, + i % 2 == 0 ? EdgeDirection.IN : EdgeDirection.OUT, + Integer.toString(i % 10), + i + 10000000); + graphState.staticGraph().E().add(edge); + } } - + graphState.manage().operate().finish(); + graphState.manage().operate().close(); + graphState.manage().operate().drop(); + FileUtils.deleteQuietly(new File("/tmp/geaflow_store_local")); + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/ActionRequest.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/ActionRequest.java index 643b8a37c..ade9aed21 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/ActionRequest.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/ActionRequest.java @@ -21,25 +21,24 @@ public class ActionRequest { - private T request; - private int shard; + private T request; + private int shard; - public ActionRequest() { - } + public ActionRequest() {} - public ActionRequest(T request) { - this.request = request; - } + public ActionRequest(T request) { + this.request = request; + } - public T getRequest() { - return this.request; - } + public T getRequest() { + return this.request; + } - public int getShard() { - return shard; - } + public int getShard() { + return shard; + } - public void setShard(int shard) { - this.shard = shard; - } + public void setShard(int shard) { + this.shard = shard; + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/ActionType.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/ActionType.java index 8c5ef0d26..92afa2f98 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/ActionType.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/ActionType.java @@ -20,11 +20,11 @@ package org.apache.geaflow.state.action; public enum ActionType { - LOAD, - ARCHIVE, - RECOVER, - FINISH, - CLOSE, - DROP, - COMPACT + LOAD, + ARCHIVE, + RECOVER, + FINISH, + CLOSE, + DROP, + COMPACT } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/BaseAction.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/BaseAction.java index daf1f1436..d69500b1c 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/BaseAction.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/BaseAction.java @@ -21,10 +21,10 @@ public abstract class BaseAction implements IAction { - protected StateActionContext context; + protected StateActionContext context; - @Override - public void init(StateActionContext context) { - this.context = context; - } + @Override + public void init(StateActionContext context) { + this.context = context; + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/EmptyAction.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/EmptyAction.java index 908b1b1b8..6b1dcd8de 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/EmptyAction.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/EmptyAction.java @@ -21,19 +21,17 @@ public class EmptyAction extends BaseAction { - private ActionType actionType; + private ActionType actionType; - public EmptyAction(ActionType actionType) { - this.actionType = actionType; - } + public EmptyAction(ActionType actionType) { + this.actionType = actionType; + } - @Override - public void apply(ActionRequest request) { + @Override + public void apply(ActionRequest request) {} - } - - @Override - public ActionType getActionType() { - return actionType; - } + @Override + public ActionType getActionType() { + return actionType; + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/IAction.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/IAction.java index 8e3bdb958..d67646f8c 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/IAction.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/IAction.java @@ -21,9 +21,9 @@ public interface IAction { - void init(StateActionContext context); + void init(StateActionContext context); - void apply(ActionRequest request); + void apply(ActionRequest request); - ActionType getActionType(); + ActionType getActionType(); } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/StateActionContext.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/StateActionContext.java index 15380b227..f291de161 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/StateActionContext.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/StateActionContext.java @@ -24,19 +24,19 @@ public class StateActionContext { - private IStatefulStore baseStore; - private Configuration config; + private IStatefulStore baseStore; + private Configuration config; - public StateActionContext(IStatefulStore baseStore, Configuration config) { - this.baseStore = baseStore; - this.config = config; - } + public StateActionContext(IStatefulStore baseStore, Configuration config) { + this.baseStore = baseStore; + this.config = config; + } - public IStatefulStore getBaseStore() { - return baseStore; - } + public IStatefulStore getBaseStore() { + return baseStore; + } - public Configuration getConfig() { - return config; - } + public Configuration getConfig() { + return config; + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/archive/ArchiveAction.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/archive/ArchiveAction.java index 242dd21f2..6289cd641 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/archive/ArchiveAction.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/archive/ArchiveAction.java @@ -25,13 +25,13 @@ public class ArchiveAction extends BaseAction { - @Override - public void apply(ActionRequest request) { - context.getBaseStore().archive((long) (request.getRequest())); - } + @Override + public void apply(ActionRequest request) { + context.getBaseStore().archive((long) (request.getRequest())); + } - @Override - public ActionType getActionType() { - return ActionType.ARCHIVE; - } + @Override + public ActionType getActionType() { + return ActionType.ARCHIVE; + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/close/CloseAction.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/close/CloseAction.java index 1c8457490..804713512 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/close/CloseAction.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/close/CloseAction.java @@ -25,13 +25,13 @@ public class CloseAction extends BaseAction { - @Override - public void apply(ActionRequest request) { - context.getBaseStore().close(); - } + @Override + public void apply(ActionRequest request) { + context.getBaseStore().close(); + } - @Override - public ActionType getActionType() { - return ActionType.CLOSE; - } + @Override + public ActionType getActionType() { + return ActionType.CLOSE; + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/compact/CompactAction.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/compact/CompactAction.java index 244ae976f..776f29f77 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/compact/CompactAction.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/compact/CompactAction.java @@ -25,13 +25,13 @@ public class CompactAction extends BaseAction { - @Override - public void apply(ActionRequest request) { - context.getBaseStore().compact(); - } + @Override + public void apply(ActionRequest request) { + context.getBaseStore().compact(); + } - @Override - public ActionType getActionType() { - return ActionType.COMPACT; - } + @Override + public ActionType getActionType() { + return ActionType.COMPACT; + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/drop/DropAction.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/drop/DropAction.java index 2c6cd2db2..3ebdd0e58 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/drop/DropAction.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/drop/DropAction.java @@ -27,16 +27,16 @@ public class DropAction extends BaseAction { - private static final Logger LOGGER = LoggerFactory.getLogger(DropAction.class); + private static final Logger LOGGER = LoggerFactory.getLogger(DropAction.class); - @Override - public void apply(ActionRequest request) { - LOGGER.info("shard {} drop action trigger", request.getShard()); - context.getBaseStore().drop(); - } + @Override + public void apply(ActionRequest request) { + LOGGER.info("shard {} drop action trigger", request.getShard()); + context.getBaseStore().drop(); + } - @Override - public ActionType getActionType() { - return ActionType.DROP; - } + @Override + public ActionType getActionType() { + return ActionType.DROP; + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/finish/FinishAction.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/finish/FinishAction.java index 848401bc1..d5d3b0f87 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/finish/FinishAction.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/finish/FinishAction.java @@ -25,13 +25,13 @@ public class FinishAction extends BaseAction { - @Override - public void apply(ActionRequest request) { - context.getBaseStore().flush(); - } + @Override + public void apply(ActionRequest request) { + context.getBaseStore().flush(); + } - @Override - public ActionType getActionType() { - return ActionType.FINISH; - } + @Override + public ActionType getActionType() { + return ActionType.FINISH; + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/load/LoadAction.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/load/LoadAction.java index 71bddebd9..cf19e043c 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/load/LoadAction.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/load/LoadAction.java @@ -28,22 +28,24 @@ public class LoadAction extends BaseAction { - private static final Logger LOGGER = LoggerFactory.getLogger(LoadAction.class); + private static final Logger LOGGER = LoggerFactory.getLogger(LoadAction.class); - @Override - public void apply(ActionRequest request) { - LoadOption option = (LoadOption) request.getRequest(); - if (option.getKeyGroup() == null || option.getKeyGroup().contains(request.getShard())) { - LOGGER.info("base store recover version {}", option.getCheckPointId()); - context.getBaseStore().recovery(option.getCheckPointId()); - } else { - LOGGER.warn("key group is null or key group {} not contain shard {}, ignore", - option.getKeyGroup(), request.getShard()); - } + @Override + public void apply(ActionRequest request) { + LoadOption option = (LoadOption) request.getRequest(); + if (option.getKeyGroup() == null || option.getKeyGroup().contains(request.getShard())) { + LOGGER.info("base store recover version {}", option.getCheckPointId()); + context.getBaseStore().recovery(option.getCheckPointId()); + } else { + LOGGER.warn( + "key group is null or key group {} not contain shard {}, ignore", + option.getKeyGroup(), + request.getShard()); } + } - @Override - public ActionType getActionType() { - return ActionType.LOAD; - } + @Override + public ActionType getActionType() { + return ActionType.LOAD; + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/recovery/RecoveryAction.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/recovery/RecoveryAction.java index 40224a052..2dc98926b 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/recovery/RecoveryAction.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/action/recovery/RecoveryAction.java @@ -25,13 +25,13 @@ public class RecoveryAction extends BaseAction { - @Override - public void apply(ActionRequest request) { - context.getBaseStore().recovery((long) (request.getRequest())); - } + @Override + public void apply(ActionRequest request) { + context.getBaseStore().recovery((long) (request.getRequest())); + } - @Override - public ActionType getActionType() { - return ActionType.RECOVER; - } + @Override + public ActionType getActionType() { + return ActionType.RECOVER; + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/AccessorBuilder.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/AccessorBuilder.java index 41b5592a9..ff097ddfa 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/AccessorBuilder.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/AccessorBuilder.java @@ -28,73 +28,73 @@ public class AccessorBuilder { - public static final Logger LOGGER = LoggerFactory.getLogger(AccessorBuilder.class); + public static final Logger LOGGER = LoggerFactory.getLogger(AccessorBuilder.class); - public static IAccessor getAccessor(DataModel dataModel, StateMode stateMode) { - switch (dataModel) { - case STATIC_GRAPH: - return getStaticGraphAccessor(stateMode); - case DYNAMIC_GRAPH: - return getDynamicGraphAccessor(stateMode); - case KV: - return getKVAccessor(stateMode); - case KList: - return getKListAccessor(stateMode); - case KMap: - return getKMapAccessor(stateMode); - default: - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + public static IAccessor getAccessor(DataModel dataModel, StateMode stateMode) { + switch (dataModel) { + case STATIC_GRAPH: + return getStaticGraphAccessor(stateMode); + case DYNAMIC_GRAPH: + return getDynamicGraphAccessor(stateMode); + case KV: + return getKVAccessor(stateMode); + case KList: + return getKListAccessor(stateMode); + case KMap: + return getKMapAccessor(stateMode); + default: + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); } + } - private static IAccessor getStaticGraphAccessor(StateMode stateMode) { - switch (stateMode) { - case RW: - return new RWStaticGraphAccessor<>(); - case RDONLY: - return new ReadOnlyStaticGraphAccessor<>(); - case COW: - return new COWGraphAccessor<>(); - default: - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + private static IAccessor getStaticGraphAccessor(StateMode stateMode) { + switch (stateMode) { + case RW: + return new RWStaticGraphAccessor<>(); + case RDONLY: + return new ReadOnlyStaticGraphAccessor<>(); + case COW: + return new COWGraphAccessor<>(); + default: + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); } + } - private static IAccessor getDynamicGraphAccessor(StateMode stateMode) { - switch (stateMode) { - case RW: - return new RWDynamicGraphAccessor<>(); - case RDONLY: - return new ReadOnlyDynamicGraphAccessor<>(); - default: - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + private static IAccessor getDynamicGraphAccessor(StateMode stateMode) { + switch (stateMode) { + case RW: + return new RWDynamicGraphAccessor<>(); + case RDONLY: + return new ReadOnlyDynamicGraphAccessor<>(); + default: + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); } + } - private static IAccessor getKVAccessor(StateMode stateMode) { - switch (stateMode) { - case RW: - return new RWKeyValueAccessor<>(); - default: - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + private static IAccessor getKVAccessor(StateMode stateMode) { + switch (stateMode) { + case RW: + return new RWKeyValueAccessor<>(); + default: + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); } + } - private static IAccessor getKListAccessor(StateMode stateMode) { - switch (stateMode) { - case RW: - return new RWKeyListAccessor<>(); - default: - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + private static IAccessor getKListAccessor(StateMode stateMode) { + switch (stateMode) { + case RW: + return new RWKeyListAccessor<>(); + default: + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); } + } - private static IAccessor getKMapAccessor(StateMode stateMode) { - switch (stateMode) { - case RW: - return new RWKeyMapAccessor<>(); - default: - throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); - } + private static IAccessor getKMapAccessor(StateMode stateMode) { + switch (stateMode) { + case RW: + return new RWKeyMapAccessor<>(); + default: + throw new GeaflowRuntimeException(RuntimeErrors.INST.unsupportedError()); } + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/ActionBuilder.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/ActionBuilder.java index 22c22ffd0..cb51b4e42 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/ActionBuilder.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/ActionBuilder.java @@ -32,24 +32,24 @@ public class ActionBuilder { - public static IAction build(ActionType actionType) { - switch (actionType) { - case ARCHIVE: - return new ArchiveAction(); - case RECOVER: - return new RecoveryAction(); - case FINISH: - return new FinishAction(); - case CLOSE: - return new CloseAction(); - case DROP: - return new DropAction(); - case COMPACT: - return new CompactAction(); - case LOAD: - return new LoadAction(); - default: - return new EmptyAction(actionType); - } + public static IAction build(ActionType actionType) { + switch (actionType) { + case ARCHIVE: + return new ArchiveAction(); + case RECOVER: + return new RecoveryAction(); + case FINISH: + return new FinishAction(); + case CLOSE: + return new CloseAction(); + case DROP: + return new DropAction(); + case COMPACT: + return new CompactAction(); + case LOAD: + return new LoadAction(); + default: + return new EmptyAction(actionType); } + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/BaseActionAccess.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/BaseActionAccess.java index d83e72468..16bd0c764 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/BaseActionAccess.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/BaseActionAccess.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; + import org.apache.geaflow.state.action.ActionRequest; import org.apache.geaflow.state.action.ActionType; import org.apache.geaflow.state.action.EmptyAction; @@ -34,34 +35,34 @@ public abstract class BaseActionAccess { - private Lock lock = new ReentrantLock(); - protected Map registeredAction = new HashMap<>(); + private Lock lock = new ReentrantLock(); + protected Map registeredAction = new HashMap<>(); - protected abstract List allowActionTypes(); + protected abstract List allowActionTypes(); - protected void initAction(IStatefulStore baseStore, StateContext stateContext) { - List allowActionTypes = allowActionTypes(); - for (ActionType actionType : allowActionTypes()) { - if (allowActionTypes.contains(actionType)) { - IAction action = ActionBuilder.build(actionType); - action.init(new StateActionContext(baseStore, stateContext.getConfig())); - this.registerAction(action); - } else { - this.registerAction(new EmptyAction(actionType)); - } - } + protected void initAction(IStatefulStore baseStore, StateContext stateContext) { + List allowActionTypes = allowActionTypes(); + for (ActionType actionType : allowActionTypes()) { + if (allowActionTypes.contains(actionType)) { + IAction action = ActionBuilder.build(actionType); + action.init(new StateActionContext(baseStore, stateContext.getConfig())); + this.registerAction(action); + } else { + this.registerAction(new EmptyAction(actionType)); + } } + } - public void registerAction(IAction action) { - if (action != null) { - this.registeredAction.put(action.getActionType(), action); - } + public void registerAction(IAction action) { + if (action != null) { + this.registeredAction.put(action.getActionType(), action); } + } - public void doStoreAction(int shard, ActionType actionType, ActionRequest request) { - request.setShard(shard); - lock.lock(); - this.registeredAction.get(actionType).apply(request); - lock.unlock(); - } + public void doStoreAction(int shard, ActionType actionType, ActionRequest request) { + request.setShard(shard); + lock.lock(); + this.registeredAction.get(actionType).apply(request); + lock.unlock(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/COWGraphAccessor.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/COWGraphAccessor.java index 1b6daf92e..1c01072c0 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/COWGraphAccessor.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/COWGraphAccessor.java @@ -19,16 +19,17 @@ package org.apache.geaflow.state.strategy.accessor; -import com.google.common.base.Preconditions; import org.apache.geaflow.state.context.StateContext; import org.apache.geaflow.state.graph.StateMode; import org.apache.geaflow.store.IStoreBuilder; +import com.google.common.base.Preconditions; + public class COWGraphAccessor extends RWStaticGraphAccessor { - @Override - public void init(StateContext context, IStoreBuilder storeBuilder) { - Preconditions.checkArgument(context.getStateMode() == StateMode.COW); - super.init(context, storeBuilder); - } + @Override + public void init(StateContext context, IStoreBuilder storeBuilder) { + Preconditions.checkArgument(context.getStateMode() == StateMode.COW); + super.init(context, storeBuilder); + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/DynamicGraphAccessor.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/DynamicGraphAccessor.java index b3ae36a91..531b04be9 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/DynamicGraphAccessor.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/DynamicGraphAccessor.java @@ -21,6 +21,4 @@ import org.apache.geaflow.state.graph.DynamicGraphTrait; -public interface DynamicGraphAccessor extends DynamicGraphTrait, IAccessor { - -} +public interface DynamicGraphAccessor extends DynamicGraphTrait, IAccessor {} diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/IAccessor.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/IAccessor.java index 3c07e2a1a..3d799b887 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/IAccessor.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/IAccessor.java @@ -28,12 +28,12 @@ public interface IAccessor { - void init(StateContext stateContext, IStoreBuilder storeBuilder); + void init(StateContext stateContext, IStoreBuilder storeBuilder); - IBaseStore getStore(); + IBaseStore getStore(); - // action - void registerAction(IAction action); + // action + void registerAction(IAction action); - void doStoreAction(int shardId, ActionType actionType, ActionRequest request); + void doStoreAction(int shardId, ActionType actionType, ActionRequest request); } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/IStaticGraphAccessor.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/IStaticGraphAccessor.java index 75e88d13a..287177708 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/IStaticGraphAccessor.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/IStaticGraphAccessor.java @@ -21,6 +21,4 @@ import org.apache.geaflow.state.graph.StaticGraphTrait; -public interface IStaticGraphAccessor extends StaticGraphTrait, IAccessor { - -} +public interface IStaticGraphAccessor extends StaticGraphTrait, IAccessor {} diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWDynamicGraphAccessor.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWDynamicGraphAccessor.java index f3ea304cd..7166e68f3 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWDynamicGraphAccessor.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWDynamicGraphAccessor.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.stream.Collectors; import java.util.stream.Stream; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; @@ -40,126 +41,130 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class RWDynamicGraphAccessor extends BaseActionAccess implements DynamicGraphAccessor { +public class RWDynamicGraphAccessor extends BaseActionAccess + implements DynamicGraphAccessor { - private static final Logger LOGGER = LoggerFactory.getLogger(RWDynamicGraphAccessor.class); - private IDynamicGraphStore graphStore; + private static final Logger LOGGER = LoggerFactory.getLogger(RWDynamicGraphAccessor.class); + private IDynamicGraphStore graphStore; - @Override - public void init(StateContext context, IStoreBuilder storeBuilder) { - this.graphStore = (IDynamicGraphStore) storeBuilder.getStore( - DataModel.DYNAMIC_GRAPH, context.getConfig()); + @Override + public void init(StateContext context, IStoreBuilder storeBuilder) { + this.graphStore = + (IDynamicGraphStore) + storeBuilder.getStore(DataModel.DYNAMIC_GRAPH, context.getConfig()); - GraphStateDescriptor desc = (GraphStateDescriptor) context.getDescriptor(); - StoreContext storeContext = new StoreContext(context.getName()) + GraphStateDescriptor desc = + (GraphStateDescriptor) context.getDescriptor(); + StoreContext storeContext = + new StoreContext(context.getName()) .withConfig(context.getConfig()) .withMetricGroup(context.getMetricGroup()) .withDataSchema(desc.getGraphSchema()) .withName(context.getName()) .withShardId(context.getShardId()); - this.graphStore.init(storeContext); - - initAction(this.graphStore, context); - } - - @Override - public IDynamicGraphStore getStore() { - return graphStore; - } - - protected List allowActionTypes() { - return Stream.of(ActionType.values()).collect(Collectors.toList()); - } - - @Override - public void addEdge(long version, IEdge edge) { - getStore().addEdge(version, edge); - } - - @Override - public List> getEdges(long version, K sid, IStatePushDown pushdown) { - return getStore().getEdges(version, sid, pushdown); - } - - @Override - public OneDegreeGraph getOneDegreeGraph(long version, K sid, - IStatePushDown pushdown) { - return getStore().getOneDegreeGraph(version, sid, pushdown); - } - - @Override - public CloseableIterator vertexIDIterator() { - return getStore().vertexIDIterator(); - } - - @Override - public CloseableIterator vertexIDIterator(long version, IStatePushDown pushdown) { - return getStore().vertexIDIterator(version, pushdown); - } - - @Override - public void addVertex(long version, IVertex vertex) { - getStore().addVertex(version, vertex); - } - - @Override - public IVertex getVertex(long version, K sid, IStatePushDown pushdown) { - return getStore().getVertex(version, sid, pushdown); - } - - @Override - public CloseableIterator> getVertexIterator(long version, IStatePushDown pushdown) { - return getStore().getVertexIterator(version, pushdown); - } - - @Override - public CloseableIterator> getVertexIterator(long version, List keys, - IStatePushDown pushdown) { - return getStore().getVertexIterator(version, keys, pushdown); - } - - @Override - public CloseableIterator> getEdgeIterator(long version, IStatePushDown pushdown) { - return getStore().getEdgeIterator(version, pushdown); - } - - @Override - public CloseableIterator> getEdgeIterator(long version, List keys, - IStatePushDown pushdown) { - return getStore().getEdgeIterator(version, keys, pushdown); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(long version, - IStatePushDown pushdown) { - return getStore().getOneDegreeGraphIterator(version, pushdown); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(long version, List keys, - IStatePushDown pushdown) { - return getStore().getOneDegreeGraphIterator(version, keys, pushdown); - } - - @Override - public List getAllVersions(K id, DataType dataType) { - return getStore().getAllVersions(id, dataType); - } - - @Override - public long getLatestVersion(K id, DataType dataType) { - return getStore().getLatestVersion(id, dataType); - } - - @Override - public Map> getAllVersionData(K id, IStatePushDown pushdown, - DataType dataType) { - return getStore().getAllVersionData(id, pushdown, dataType); - } - - @Override - public Map> getVersionData(K id, Collection versions, - IStatePushDown pushdown, DataType dataType) { - return getStore().getVersionData(id, versions, pushdown, dataType); - } + this.graphStore.init(storeContext); + + initAction(this.graphStore, context); + } + + @Override + public IDynamicGraphStore getStore() { + return graphStore; + } + + protected List allowActionTypes() { + return Stream.of(ActionType.values()).collect(Collectors.toList()); + } + + @Override + public void addEdge(long version, IEdge edge) { + getStore().addEdge(version, edge); + } + + @Override + public List> getEdges(long version, K sid, IStatePushDown pushdown) { + return getStore().getEdges(version, sid, pushdown); + } + + @Override + public OneDegreeGraph getOneDegreeGraph(long version, K sid, IStatePushDown pushdown) { + return getStore().getOneDegreeGraph(version, sid, pushdown); + } + + @Override + public CloseableIterator vertexIDIterator() { + return getStore().vertexIDIterator(); + } + + @Override + public CloseableIterator vertexIDIterator(long version, IStatePushDown pushdown) { + return getStore().vertexIDIterator(version, pushdown); + } + + @Override + public void addVertex(long version, IVertex vertex) { + getStore().addVertex(version, vertex); + } + + @Override + public IVertex getVertex(long version, K sid, IStatePushDown pushdown) { + return getStore().getVertex(version, sid, pushdown); + } + + @Override + public CloseableIterator> getVertexIterator( + long version, IStatePushDown pushdown) { + return getStore().getVertexIterator(version, pushdown); + } + + @Override + public CloseableIterator> getVertexIterator( + long version, List keys, IStatePushDown pushdown) { + return getStore().getVertexIterator(version, keys, pushdown); + } + + @Override + public CloseableIterator> getEdgeIterator(long version, IStatePushDown pushdown) { + return getStore().getEdgeIterator(version, pushdown); + } + + @Override + public CloseableIterator> getEdgeIterator( + long version, List keys, IStatePushDown pushdown) { + return getStore().getEdgeIterator(version, keys, pushdown); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + long version, IStatePushDown pushdown) { + return getStore().getOneDegreeGraphIterator(version, pushdown); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + long version, List keys, IStatePushDown pushdown) { + return getStore().getOneDegreeGraphIterator(version, keys, pushdown); + } + + @Override + public List getAllVersions(K id, DataType dataType) { + return getStore().getAllVersions(id, dataType); + } + + @Override + public long getLatestVersion(K id, DataType dataType) { + return getStore().getLatestVersion(id, dataType); + } + + @Override + public Map> getAllVersionData( + K id, IStatePushDown pushdown, DataType dataType) { + return getStore().getAllVersionData(id, pushdown, dataType); + } + + @Override + public Map> getVersionData( + K id, Collection versions, IStatePushDown pushdown, DataType dataType) { + return getStore().getVersionData(id, versions, pushdown, dataType); + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWKeyAccessor.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWKeyAccessor.java index 3cd43d937..b565c07e7 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWKeyAccessor.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWKeyAccessor.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.stream.Collectors; import java.util.stream.Stream; + import org.apache.geaflow.state.action.ActionType; import org.apache.geaflow.state.context.StateContext; import org.apache.geaflow.state.descriptor.BaseKeyDescriptor; @@ -32,29 +33,32 @@ public class RWKeyAccessor extends BaseActionAccess implements IAccessor { - protected IStatefulStore store; + protected IStatefulStore store; - @Override - public void init(StateContext context, IStoreBuilder storeBuilder) { - this.store = (IStatefulStore) storeBuilder.getStore(context.getDataModel(), - context.getConfig()); + @Override + public void init(StateContext context, IStoreBuilder storeBuilder) { + this.store = + (IStatefulStore) storeBuilder.getStore(context.getDataModel(), context.getConfig()); - BaseKeyDescriptor desc = (BaseKeyDescriptor) context.getDescriptor(); + BaseKeyDescriptor desc = (BaseKeyDescriptor) context.getDescriptor(); - StoreContext storeContext = new StoreContext(context.getName()).withConfig( - context.getConfig()).withMetricGroup(context.getMetricGroup()) - .withShardId(context.getShardId()).withKeySerializer(desc.getKeySerializer()); + StoreContext storeContext = + new StoreContext(context.getName()) + .withConfig(context.getConfig()) + .withMetricGroup(context.getMetricGroup()) + .withShardId(context.getShardId()) + .withKeySerializer(desc.getKeySerializer()); - this.store.init(storeContext); - initAction(this.store, context); - } + this.store.init(storeContext); + initAction(this.store, context); + } - @Override - public IBaseStore getStore() { - return this.store; - } + @Override + public IBaseStore getStore() { + return this.store; + } - protected List allowActionTypes() { - return Stream.of(ActionType.values()).collect(Collectors.toList()); - } + protected List allowActionTypes() { + return Stream.of(ActionType.values()).collect(Collectors.toList()); + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWKeyListAccessor.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWKeyListAccessor.java index 3e6a79bf4..f8096c0f5 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWKeyListAccessor.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWKeyListAccessor.java @@ -20,6 +20,7 @@ package org.apache.geaflow.state.strategy.accessor; import java.util.List; + import org.apache.geaflow.state.context.StateContext; import org.apache.geaflow.state.key.KeyListTrait; import org.apache.geaflow.store.IStoreBuilder; @@ -27,26 +28,26 @@ public class RWKeyListAccessor extends RWKeyAccessor implements KeyListTrait { - private IKListStore kListStore; + private IKListStore kListStore; - @Override - public void init(StateContext context, IStoreBuilder storeBuilder) { - super.init(context, storeBuilder); - this.kListStore = (IKListStore) store; - } + @Override + public void init(StateContext context, IStoreBuilder storeBuilder) { + super.init(context, storeBuilder); + this.kListStore = (IKListStore) store; + } - @Override - public List get(K key) { - return this.kListStore.get(key); - } + @Override + public List get(K key) { + return this.kListStore.get(key); + } - @Override - public void add(K key, V... value) { - this.kListStore.add(key, value); - } + @Override + public void add(K key, V... value) { + this.kListStore.add(key, value); + } - @Override - public void remove(K key) { - this.kListStore.remove(key); - } + @Override + public void remove(K key) { + this.kListStore.remove(key); + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWKeyMapAccessor.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWKeyMapAccessor.java index 9fc4dee20..0a96e58c9 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWKeyMapAccessor.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWKeyMapAccessor.java @@ -21,48 +21,50 @@ import java.util.List; import java.util.Map; + import org.apache.geaflow.state.context.StateContext; import org.apache.geaflow.state.key.KeyMapTrait; import org.apache.geaflow.store.IStoreBuilder; import org.apache.geaflow.store.api.key.IKMapStore; -public class RWKeyMapAccessor extends RWKeyAccessor implements KeyMapTrait { +public class RWKeyMapAccessor extends RWKeyAccessor + implements KeyMapTrait { - private IKMapStore kMapStore; + private IKMapStore kMapStore; - @Override - public void init(StateContext context, IStoreBuilder storeBuilder) { - super.init(context, storeBuilder); - this.kMapStore = (IKMapStore) store; - } + @Override + public void init(StateContext context, IStoreBuilder storeBuilder) { + super.init(context, storeBuilder); + this.kMapStore = (IKMapStore) store; + } - @Override - public void add(K key, Map value) { - this.kMapStore.add(key, value); - } + @Override + public void add(K key, Map value) { + this.kMapStore.add(key, value); + } - @Override - public void remove(K key) { - this.kMapStore.remove(key); - } + @Override + public void remove(K key) { + this.kMapStore.remove(key); + } - @Override - public void remove(K key, UK... subKeys) { - this.kMapStore.remove(key, subKeys); - } + @Override + public void remove(K key, UK... subKeys) { + this.kMapStore.remove(key, subKeys); + } - @Override - public void add(K key, UK uk, UV value) { - this.kMapStore.add(key, uk, value); - } + @Override + public void add(K key, UK uk, UV value) { + this.kMapStore.add(key, uk, value); + } - @Override - public Map get(K key) { - return this.kMapStore.get(key); - } + @Override + public Map get(K key) { + return this.kMapStore.get(key); + } - @Override - public List get(K key, UK... uk) { - return this.kMapStore.get(key, uk); - } + @Override + public List get(K key, UK... uk) { + return this.kMapStore.get(key, uk); + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWKeyValueAccessor.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWKeyValueAccessor.java index cd2ff8e61..8bc827973 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWKeyValueAccessor.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWKeyValueAccessor.java @@ -26,26 +26,26 @@ public class RWKeyValueAccessor extends RWKeyAccessor implements KeyValueTrait { - private IKVStore kvStore; - - @Override - public void init(StateContext context, IStoreBuilder storeBuilder) { - super.init(context, storeBuilder); - this.kvStore = (IKVStore) store; - } - - @Override - public V get(K key) { - return this.kvStore.get(key); - } - - @Override - public void put(K key, V value) { - this.kvStore.put(key, value); - } - - @Override - public void remove(K key) { - this.kvStore.remove(key); - } + private IKVStore kvStore; + + @Override + public void init(StateContext context, IStoreBuilder storeBuilder) { + super.init(context, storeBuilder); + this.kvStore = (IKVStore) store; + } + + @Override + public V get(K key) { + return this.kvStore.get(key); + } + + @Override + public void put(K key, V value) { + this.kvStore.put(key, value); + } + + @Override + public void remove(K key) { + this.kvStore.remove(key); + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWStaticGraphAccessor.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWStaticGraphAccessor.java index f12c07b06..1312345dd 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWStaticGraphAccessor.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/RWStaticGraphAccessor.java @@ -23,6 +23,7 @@ import java.util.Map; import java.util.stream.Collectors; import java.util.stream.Stream; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.common.tuple.Tuple; import org.apache.geaflow.model.graph.edge.IEdge; @@ -39,120 +40,128 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class RWStaticGraphAccessor extends BaseActionAccess implements IStaticGraphAccessor { +public class RWStaticGraphAccessor extends BaseActionAccess + implements IStaticGraphAccessor { - private static final Logger LOGGER = LoggerFactory.getLogger(RWStaticGraphAccessor.class); - private IStaticGraphStore graphStore; + private static final Logger LOGGER = LoggerFactory.getLogger(RWStaticGraphAccessor.class); + private IStaticGraphStore graphStore; - @Override - public void init(StateContext context, IStoreBuilder storeBuilder) { - this.graphStore = (IStaticGraphStore) storeBuilder.getStore(DataModel.STATIC_GRAPH, context.getConfig()); + @Override + public void init(StateContext context, IStoreBuilder storeBuilder) { + this.graphStore = + (IStaticGraphStore) + storeBuilder.getStore(DataModel.STATIC_GRAPH, context.getConfig()); - GraphStateDescriptor desc = (GraphStateDescriptor) context.getDescriptor(); + GraphStateDescriptor desc = + (GraphStateDescriptor) context.getDescriptor(); - StoreContext storeContext = new StoreContext(context.getName()) + StoreContext storeContext = + new StoreContext(context.getName()) .withConfig(context.getConfig()) .withMetricGroup(context.getMetricGroup()) .withDataSchema(desc.getGraphSchema()) .withName(context.getName()) .withShardId(context.getShardId()); - this.graphStore.init(storeContext); - initAction(this.graphStore, context); - } - - @Override - public IStaticGraphStore getStore() { - return this.graphStore; - } - - protected List allowActionTypes() { - return Stream.of(ActionType.values()).collect(Collectors.toList()); - } - - @Override - public void addEdge(IEdge edge) { - getStore().addEdge(edge); - } - - @Override - public List> getEdges(K sid, IStatePushDown pushdown) { - return getStore().getEdges(sid, pushdown); - } - - @Override - public OneDegreeGraph getOneDegreeGraph(K sid, IStatePushDown pushdown) { - return getStore().getOneDegreeGraph(sid, pushdown); - } - - @Override - public CloseableIterator vertexIDIterator() { - return getStore().vertexIDIterator(); - } - - @Override - public CloseableIterator vertexIDIterator(IStatePushDown pushDown) { - return getStore().vertexIDIterator(pushDown); - } - - @Override - public void addVertex(IVertex vertex) { - getStore().addVertex(vertex); - } - - @Override - public IVertex getVertex(K sid, IStatePushDown pushdown) { - return getStore().getVertex(sid, pushdown); - } - - @Override - public CloseableIterator> getVertexIterator(IStatePushDown pushdown) { - return getStore().getVertexIterator(pushdown); - } - - @Override - public CloseableIterator> getVertexIterator(List keys, IStatePushDown pushdown) { - return getStore().getVertexIterator(keys, pushdown); - } - - @Override - public CloseableIterator> getEdgeIterator(IStatePushDown pushdown) { - return getStore().getEdgeIterator(pushdown); - } - - @Override - public CloseableIterator> getEdgeIterator(List keys, IStatePushDown pushdown) { - return getStore().getEdgeIterator(keys, pushdown); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator( - IStatePushDown pushdown) { - return getStore().getOneDegreeGraphIterator(pushdown); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(List keys, IStatePushDown pushdown) { - return getStore().getOneDegreeGraphIterator(keys, pushdown); - } - - @Override - public CloseableIterator> getEdgeProjectIterator( - IStatePushDown, R> pushdown) { - return getStore().getEdgeProjectIterator(pushdown); - } - - @Override - public CloseableIterator> getEdgeProjectIterator(List keys, IStatePushDown, R> pushdown) { - return getStore().getEdgeProjectIterator(keys, pushdown); - } - - @Override - public Map getAggResult(IStatePushDown pushdown) { - return getStore().getAggResult(pushdown); - } - - @Override - public Map getAggResult(List keys, IStatePushDown pushdown) { - return getStore().getAggResult(keys, pushdown); - } + this.graphStore.init(storeContext); + initAction(this.graphStore, context); + } + + @Override + public IStaticGraphStore getStore() { + return this.graphStore; + } + + protected List allowActionTypes() { + return Stream.of(ActionType.values()).collect(Collectors.toList()); + } + + @Override + public void addEdge(IEdge edge) { + getStore().addEdge(edge); + } + + @Override + public List> getEdges(K sid, IStatePushDown pushdown) { + return getStore().getEdges(sid, pushdown); + } + + @Override + public OneDegreeGraph getOneDegreeGraph(K sid, IStatePushDown pushdown) { + return getStore().getOneDegreeGraph(sid, pushdown); + } + + @Override + public CloseableIterator vertexIDIterator() { + return getStore().vertexIDIterator(); + } + + @Override + public CloseableIterator vertexIDIterator(IStatePushDown pushDown) { + return getStore().vertexIDIterator(pushDown); + } + + @Override + public void addVertex(IVertex vertex) { + getStore().addVertex(vertex); + } + + @Override + public IVertex getVertex(K sid, IStatePushDown pushdown) { + return getStore().getVertex(sid, pushdown); + } + + @Override + public CloseableIterator> getVertexIterator(IStatePushDown pushdown) { + return getStore().getVertexIterator(pushdown); + } + + @Override + public CloseableIterator> getVertexIterator( + List keys, IStatePushDown pushdown) { + return getStore().getVertexIterator(keys, pushdown); + } + + @Override + public CloseableIterator> getEdgeIterator(IStatePushDown pushdown) { + return getStore().getEdgeIterator(pushdown); + } + + @Override + public CloseableIterator> getEdgeIterator(List keys, IStatePushDown pushdown) { + return getStore().getEdgeIterator(keys, pushdown); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + IStatePushDown pushdown) { + return getStore().getOneDegreeGraphIterator(pushdown); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + List keys, IStatePushDown pushdown) { + return getStore().getOneDegreeGraphIterator(keys, pushdown); + } + + @Override + public CloseableIterator> getEdgeProjectIterator( + IStatePushDown, R> pushdown) { + return getStore().getEdgeProjectIterator(pushdown); + } + + @Override + public CloseableIterator> getEdgeProjectIterator( + List keys, IStatePushDown, R> pushdown) { + return getStore().getEdgeProjectIterator(keys, pushdown); + } + + @Override + public Map getAggResult(IStatePushDown pushdown) { + return getStore().getAggResult(pushdown); + } + + @Override + public Map getAggResult(List keys, IStatePushDown pushdown) { + return getStore().getAggResult(keys, pushdown); + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/ReadOnlyDynamicGraphAccessor.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/ReadOnlyDynamicGraphAccessor.java index 6bb063b75..bd5bc8368 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/ReadOnlyDynamicGraphAccessor.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/ReadOnlyDynamicGraphAccessor.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; + import org.apache.geaflow.state.action.ActionRequest; import org.apache.geaflow.state.action.ActionType; import org.apache.geaflow.state.context.StateContext; @@ -30,29 +31,29 @@ public class ReadOnlyDynamicGraphAccessor extends RWDynamicGraphAccessor { - private final ReadOnlyGraph readOnlyGraph = new ReadOnlyGraph<>(); - private Lock lock = new ReentrantLock(); - - @Override - public void init(StateContext context, IStoreBuilder storeBuilder) { - this.readOnlyGraph.init(context, storeBuilder); - } - - @Override - protected List allowActionTypes() { - return this.readOnlyGraph.allowActionTypes(); - } - - @Override - public void doStoreAction(int shard, ActionType actionType, ActionRequest request) { - request.setShard(shard); - lock.lock(); - this.readOnlyGraph.doStoreAction(shard, actionType, request); - lock.unlock(); - } - - @Override - public IDynamicGraphStore getStore() { - return (IDynamicGraphStore) this.readOnlyGraph.getStore(); - } + private final ReadOnlyGraph readOnlyGraph = new ReadOnlyGraph<>(); + private Lock lock = new ReentrantLock(); + + @Override + public void init(StateContext context, IStoreBuilder storeBuilder) { + this.readOnlyGraph.init(context, storeBuilder); + } + + @Override + protected List allowActionTypes() { + return this.readOnlyGraph.allowActionTypes(); + } + + @Override + public void doStoreAction(int shard, ActionType actionType, ActionRequest request) { + request.setShard(shard); + lock.lock(); + this.readOnlyGraph.doStoreAction(shard, actionType, request); + lock.unlock(); + } + + @Override + public IDynamicGraphStore getStore() { + return (IDynamicGraphStore) this.readOnlyGraph.getStore(); + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/ReadOnlyGraph.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/ReadOnlyGraph.java index 17162a8f9..0fef5178d 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/ReadOnlyGraph.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/ReadOnlyGraph.java @@ -19,13 +19,13 @@ package org.apache.geaflow.state.strategy.accessor; -import com.google.common.base.Preconditions; import java.util.Arrays; import java.util.List; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; + import org.apache.geaflow.common.config.keys.StateConfigKeys; import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.utils.SleepUtils; @@ -48,175 +48,192 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; + public class ReadOnlyGraph { - private static final Logger LOGGER = LoggerFactory.getLogger(ReadOnlyGraph.class); - - protected StateContext context; - protected IStoreBuilder storeBuilder; - protected ViewMetaBookKeeper viewMetaBookKeeper; - - protected ScheduledExecutorService syncExecutor; - protected long currentVersion; - protected Throwable warmupException; - protected AtomicBoolean initialized; - // InUseGraphStore is the store in using. - protected IStatefulStore inUseGraphStore; - // LatestGraphStore is the store referring to the latest version. - protected IStatefulStore latestGraphStore; - // LazyCloseGraphStore is the store in querying while current store is switching. - protected IStatefulStore lazyCloseGraphStore; - protected boolean enableRecoverLatestVersion; - protected boolean enableStateBackgroundSync; - protected int syncGapMs; - - public void init(StateContext context, IStoreBuilder storeBuilder) { - Preconditions.checkArgument(context.getStateMode() == StateMode.RDONLY); - this.context = context; - this.storeBuilder = storeBuilder; - this.viewMetaBookKeeper = new ViewMetaBookKeeper(context.getName(), context.getConfig()); - this.enableRecoverLatestVersion = context.getConfig() - .getBoolean(StateConfigKeys.STATE_RECOVER_LATEST_VERSION_ENABLE); - this.enableStateBackgroundSync = context.getConfig() - .getBoolean(StateConfigKeys.STATE_BACKGROUND_SYNC_ENABLE); - this.syncGapMs = context.getConfig().getInteger(StateConfigKeys.STATE_SYNC_GAP_MS); - this.initialized = new AtomicBoolean(false); - if (this.enableStateBackgroundSync) { - this.enableRecoverLatestVersion = true; - LOGGER.info("initialize background sync service"); - this.syncExecutor = Executors.newSingleThreadScheduledExecutor( - ThreadUtil.namedThreadFactory(false, - Thread.currentThread().getName() + "read-only-background-sync-" - + context.getShardId())); - this.startStateSyncService(); - } - } + private static final Logger LOGGER = LoggerFactory.getLogger(ReadOnlyGraph.class); - protected List allowActionTypes() { - return Arrays.asList(ActionType.RECOVER, ActionType.LOAD, ActionType.DROP, - ActionType.CLOSE); - } + protected StateContext context; + protected IStoreBuilder storeBuilder; + protected ViewMetaBookKeeper viewMetaBookKeeper; - public void doStoreAction(int shard, ActionType actionType, ActionRequest request) { - if (actionType == ActionType.DROP || actionType == ActionType.CLOSE) { - IAction action = actionType == ActionType.DROP ? new DropAction() : new CloseAction(); - StateActionContext stateActionContext = new StateActionContext(latestGraphStore, - context.getConfig()); - action.init(stateActionContext); - action.apply(request); - if (enableStateBackgroundSync) { - this.syncExecutor.shutdown(); - } - } - if (actionType == ActionType.RECOVER) { - long version = (long) request.getRequest(); - if (!enableStateBackgroundSync) { - recover(version); - } - } else if (actionType == ActionType.LOAD) { - LOGGER.info("wait async background sync to be finished"); - LoadOption option = (LoadOption) request.getRequest(); - if (option.getKeyGroup() != null && !option.getKeyGroup().contains(shard)) { - return; - } - if (enableStateBackgroundSync) { - while (!this.initialized.get()) { - if (warmupException != null) { - throw new GeaflowRuntimeException("warmup error", this.warmupException); - } - SleepUtils.sleepMilliSecond(1000); - } - } else { - recover(option.getCheckPointId()); - } - } + protected ScheduledExecutorService syncExecutor; + protected long currentVersion; + protected Throwable warmupException; + protected AtomicBoolean initialized; + // InUseGraphStore is the store in using. + protected IStatefulStore inUseGraphStore; + // LatestGraphStore is the store referring to the latest version. + protected IStatefulStore latestGraphStore; + // LazyCloseGraphStore is the store in querying while current store is switching. + protected IStatefulStore lazyCloseGraphStore; + protected boolean enableRecoverLatestVersion; + protected boolean enableStateBackgroundSync; + protected int syncGapMs; + + public void init(StateContext context, IStoreBuilder storeBuilder) { + Preconditions.checkArgument(context.getStateMode() == StateMode.RDONLY); + this.context = context; + this.storeBuilder = storeBuilder; + this.viewMetaBookKeeper = new ViewMetaBookKeeper(context.getName(), context.getConfig()); + this.enableRecoverLatestVersion = + context.getConfig().getBoolean(StateConfigKeys.STATE_RECOVER_LATEST_VERSION_ENABLE); + this.enableStateBackgroundSync = + context.getConfig().getBoolean(StateConfigKeys.STATE_BACKGROUND_SYNC_ENABLE); + this.syncGapMs = context.getConfig().getInteger(StateConfigKeys.STATE_SYNC_GAP_MS); + this.initialized = new AtomicBoolean(false); + if (this.enableStateBackgroundSync) { + this.enableRecoverLatestVersion = true; + LOGGER.info("initialize background sync service"); + this.syncExecutor = + Executors.newSingleThreadScheduledExecutor( + ThreadUtil.namedThreadFactory( + false, + Thread.currentThread().getName() + + "read-only-background-sync-" + + context.getShardId())); + this.startStateSyncService(); } + } - protected void recover(long version) { - if (enableRecoverLatestVersion) { - try { - version = viewMetaBookKeeper.getLatestViewVersion(context.getName()); - } catch (Throwable t) { - throw new GeaflowRuntimeException("failed to get latest version", t); - } - } - if (latestGraphStore == null) { - createReadOnlyState(version); - } else { - updateVersion(version); + protected List allowActionTypes() { + return Arrays.asList(ActionType.RECOVER, ActionType.LOAD, ActionType.DROP, ActionType.CLOSE); + } + + public void doStoreAction(int shard, ActionType actionType, ActionRequest request) { + if (actionType == ActionType.DROP || actionType == ActionType.CLOSE) { + IAction action = actionType == ActionType.DROP ? new DropAction() : new CloseAction(); + StateActionContext stateActionContext = + new StateActionContext(latestGraphStore, context.getConfig()); + action.init(stateActionContext); + action.apply(request); + if (enableStateBackgroundSync) { + this.syncExecutor.shutdown(); + } + } + if (actionType == ActionType.RECOVER) { + long version = (long) request.getRequest(); + if (!enableStateBackgroundSync) { + recover(version); + } + } else if (actionType == ActionType.LOAD) { + LOGGER.info("wait async background sync to be finished"); + LoadOption option = (LoadOption) request.getRequest(); + if (option.getKeyGroup() != null && !option.getKeyGroup().contains(shard)) { + return; + } + if (enableStateBackgroundSync) { + while (!this.initialized.get()) { + if (warmupException != null) { + throw new GeaflowRuntimeException("warmup error", this.warmupException); + } + SleepUtils.sleepMilliSecond(1000); } + } else { + recover(option.getCheckPointId()); + } } + } - protected void startStateSyncService() { - this.syncExecutor.scheduleAtFixedRate(() -> { - try { - final long start = System.currentTimeMillis(); - long latestVersion = viewMetaBookKeeper.getLatestViewVersion(context.getName()); - Preconditions.checkArgument(latestVersion > 0); - if (latestVersion != currentVersion) { - createReadOnlyState(latestVersion); - currentVersion = latestVersion; - } else { - LOGGER.info("don't need recover, current version {} latest version {}", - currentVersion, latestVersion); - } - // Try to update in-use connection. - getStore(); - LOGGER.info("background sync finished cost {}", System.currentTimeMillis() - start); - this.initialized.set(true); - } catch (Throwable t) { - if (!initialized.get()) { - this.warmupException = t; - } - LOGGER.error("background sync error", t); - } - }, 0, syncGapMs, TimeUnit.MILLISECONDS); + protected void recover(long version) { + if (enableRecoverLatestVersion) { + try { + version = viewMetaBookKeeper.getLatestViewVersion(context.getName()); + } catch (Throwable t) { + throw new GeaflowRuntimeException("failed to get latest version", t); + } + } + if (latestGraphStore == null) { + createReadOnlyState(version); + } else { + updateVersion(version); } + } - public IBaseStore getStore() { - if (enableStateBackgroundSync) { - if (latestGraphStore != inUseGraphStore) { - synchronized (ReadOnlyStaticGraphAccessor.class) { - if (latestGraphStore == inUseGraphStore) { - return inUseGraphStore; - } - if (lazyCloseGraphStore != null) { - lazyCloseGraphStore.close(); - } - lazyCloseGraphStore = inUseGraphStore; - inUseGraphStore = latestGraphStore; - } + protected void startStateSyncService() { + this.syncExecutor.scheduleAtFixedRate( + () -> { + try { + final long start = System.currentTimeMillis(); + long latestVersion = viewMetaBookKeeper.getLatestViewVersion(context.getName()); + Preconditions.checkArgument(latestVersion > 0); + if (latestVersion != currentVersion) { + createReadOnlyState(latestVersion); + currentVersion = latestVersion; + } else { + LOGGER.info( + "don't need recover, current version {} latest version {}", + currentVersion, + latestVersion); } - return inUseGraphStore; - } else { - if (latestGraphStore == null) { - LOGGER.warn("create graph store is null, shardId {}, keyGroup {}", - context.getShardId(), context.getKeyGroup()); + // Try to update in-use connection. + getStore(); + LOGGER.info("background sync finished cost {}", System.currentTimeMillis() - start); + this.initialized.set(true); + } catch (Throwable t) { + if (!initialized.get()) { + this.warmupException = t; } - return latestGraphStore; + LOGGER.error("background sync error", t); + } + }, + 0, + syncGapMs, + TimeUnit.MILLISECONDS); + } + + public IBaseStore getStore() { + if (enableStateBackgroundSync) { + if (latestGraphStore != inUseGraphStore) { + synchronized (ReadOnlyStaticGraphAccessor.class) { + if (latestGraphStore == inUseGraphStore) { + return inUseGraphStore; + } + if (lazyCloseGraphStore != null) { + lazyCloseGraphStore.close(); + } + lazyCloseGraphStore = inUseGraphStore; + inUseGraphStore = latestGraphStore; } + } + return inUseGraphStore; + } else { + if (latestGraphStore == null) { + LOGGER.warn( + "create graph store is null, shardId {}, keyGroup {}", + context.getShardId(), + context.getKeyGroup()); + } + return latestGraphStore; } + } - protected void createReadOnlyState(long version) { - LOGGER.info("create new read only state, state index {} version {} backend type {}", - context.getShardId(), version, context.getStoreType()); - IStatefulStore graphStoreTmp = (IStatefulStore) storeBuilder.getStore(this.context.getDataModel(), - context.getConfig()); - GraphStateDescriptor desc = - (GraphStateDescriptor) context.getDescriptor(); - - StoreContext storeContext = new StoreContext(context.getName()).withConfig( - context.getConfig()).withMetricGroup(context.getMetricGroup()) - .withDataSchema(desc.getGraphSchema()).withName(context.getName()) + protected void createReadOnlyState(long version) { + LOGGER.info( + "create new read only state, state index {} version {} backend type {}", + context.getShardId(), + version, + context.getStoreType()); + IStatefulStore graphStoreTmp = + (IStatefulStore) storeBuilder.getStore(this.context.getDataModel(), context.getConfig()); + GraphStateDescriptor desc = + (GraphStateDescriptor) context.getDescriptor(); + + StoreContext storeContext = + new StoreContext(context.getName()) + .withConfig(context.getConfig()) + .withMetricGroup(context.getMetricGroup()) + .withDataSchema(desc.getGraphSchema()) + .withName(context.getName()) .withShardId(context.getShardId()); - graphStoreTmp.init(storeContext); - graphStoreTmp.recovery(version); - latestGraphStore = graphStoreTmp; - } + graphStoreTmp.init(storeContext); + graphStoreTmp.recovery(version); + latestGraphStore = graphStoreTmp; + } - protected void updateVersion(long version) { - LOGGER.info("update read only state, state index {} version {}", context.getShardId(), - version); - latestGraphStore.recovery(version); - } + protected void updateVersion(long version) { + LOGGER.info("update read only state, state index {} version {}", context.getShardId(), version); + latestGraphStore.recovery(version); + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/ReadOnlyStaticGraphAccessor.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/ReadOnlyStaticGraphAccessor.java index 0a801251f..2e4f93474 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/ReadOnlyStaticGraphAccessor.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/accessor/ReadOnlyStaticGraphAccessor.java @@ -20,6 +20,7 @@ package org.apache.geaflow.state.strategy.accessor; import java.util.List; + import org.apache.geaflow.state.action.ActionRequest; import org.apache.geaflow.state.action.ActionType; import org.apache.geaflow.state.context.StateContext; @@ -28,21 +29,21 @@ public class ReadOnlyStaticGraphAccessor extends RWStaticGraphAccessor { - private final ReadOnlyGraph readOnlyGraph = new ReadOnlyGraph<>(); + private final ReadOnlyGraph readOnlyGraph = new ReadOnlyGraph<>(); - public void init(StateContext context, IStoreBuilder storeBuilder) { - this.readOnlyGraph.init(context, storeBuilder); - } + public void init(StateContext context, IStoreBuilder storeBuilder) { + this.readOnlyGraph.init(context, storeBuilder); + } - protected List allowActionTypes() { - return this.readOnlyGraph.allowActionTypes(); - } + protected List allowActionTypes() { + return this.readOnlyGraph.allowActionTypes(); + } - public void doStoreAction(int shard, ActionType actionType, ActionRequest request) { - this.readOnlyGraph.doStoreAction(shard, actionType, request); - } + public void doStoreAction(int shard, ActionType actionType, ActionRequest request) { + this.readOnlyGraph.doStoreAction(shard, actionType, request); + } - public IStaticGraphStore getStore() { - return (IStaticGraphStore) this.readOnlyGraph.getStore(); - } -} \ No newline at end of file + public IStaticGraphStore getStore() { + return (IStaticGraphStore) this.readOnlyGraph.getStore(); + } +} diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/BaseShardManager.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/BaseShardManager.java index 525d8ce25..706f29a8c 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/BaseShardManager.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/BaseShardManager.java @@ -19,10 +19,10 @@ package org.apache.geaflow.state.strategy.manager; -import com.google.common.base.Preconditions; import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.state.context.StateContext; import org.apache.geaflow.state.pushdown.IStatePushDown; @@ -33,50 +33,55 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class BaseShardManager { +import com.google.common.base.Preconditions; - private static final Logger LOGGER = LoggerFactory.getLogger(BaseShardManager.class); +public abstract class BaseShardManager { - protected final int totalShardNum; - protected Map traitMap; - protected KeyGroup shardGroup; - protected IKeyGroupAssigner assigner; - protected boolean mayScale; + private static final Logger LOGGER = LoggerFactory.getLogger(BaseShardManager.class); - public BaseShardManager(StateContext context, Map accessorMap) { - this.shardGroup = context.getKeyGroup(); - this.assigner = context.getDescriptor().getAssigner(); - Preconditions.checkArgument(this.assigner != null, "The assigner must be not null"); - LOGGER.info("key group {}, key group num {}", this.shardGroup, this.assigner.getKeyGroupNumber()); - this.mayScale = context.isLocalStore(); - this.totalShardNum = this.assigner.getKeyGroupNumber(); - this.traitMap = new HashMap<>(accessorMap.size()); - for (Entry entry : accessorMap.entrySet()) { - this.traitMap.put(entry.getKey(), (T) entry.getValue()); - } - } + protected final int totalShardNum; + protected Map traitMap; + protected KeyGroup shardGroup; + protected IKeyGroupAssigner assigner; + protected boolean mayScale; - protected T getTraitByKey(K key) { - return getTraitById(assigner.assign(key)); + public BaseShardManager(StateContext context, Map accessorMap) { + this.shardGroup = context.getKeyGroup(); + this.assigner = context.getDescriptor().getAssigner(); + Preconditions.checkArgument(this.assigner != null, "The assigner must be not null"); + LOGGER.info( + "key group {}, key group num {}", this.shardGroup, this.assigner.getKeyGroupNumber()); + this.mayScale = context.isLocalStore(); + this.totalShardNum = this.assigner.getKeyGroupNumber(); + this.traitMap = new HashMap<>(accessorMap.size()); + for (Entry entry : accessorMap.entrySet()) { + this.traitMap.put(entry.getKey(), (T) entry.getValue()); } + } - protected T getTraitById(int keyGroupId) { - T trait = traitMap.get(keyGroupId); - if (trait == null) { - throw new GeaflowRuntimeException( - "we have " + traitMap.keySet() + " need keyGroupId " + keyGroupId); + protected T getTraitByKey(K key) { + return getTraitById(assigner.assign(key)); + } - } - return trait; + protected T getTraitById(int keyGroupId) { + T trait = traitMap.get(keyGroupId); + if (trait == null) { + throw new GeaflowRuntimeException( + "we have " + traitMap.keySet() + " need keyGroupId " + keyGroupId); } + return trait; + } - protected KeyGroup getShardGroup(IStatePushDown pushdown) { - KeyGroup queryShardGroup = this.shardGroup; - if (pushdown instanceof KeyGroupStatePushDown) { - queryShardGroup = ((KeyGroupStatePushDown) pushdown).getKeyGroup(); - Preconditions.checkArgument(this.shardGroup.contains(queryShardGroup), - "state keyGroup %s, query keyGroup %s", this.shardGroup, queryShardGroup); - } - return queryShardGroup; + protected KeyGroup getShardGroup(IStatePushDown pushdown) { + KeyGroup queryShardGroup = this.shardGroup; + if (pushdown instanceof KeyGroupStatePushDown) { + queryShardGroup = ((KeyGroupStatePushDown) pushdown).getKeyGroup(); + Preconditions.checkArgument( + this.shardGroup.contains(queryShardGroup), + "state keyGroup %s, query keyGroup %s", + this.shardGroup, + queryShardGroup); } + return queryShardGroup; + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/BaseStateManager.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/BaseStateManager.java index 7baa08f25..10ffdbff9 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/BaseStateManager.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/BaseStateManager.java @@ -19,9 +19,9 @@ package org.apache.geaflow.state.strategy.manager; -import com.google.common.base.Preconditions; import java.util.HashMap; import java.util.Map; + import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.metrics.common.api.MetricGroup; import org.apache.geaflow.state.action.ActionRequest; @@ -34,40 +34,41 @@ import org.apache.geaflow.store.api.StoreBuilderFactory; import org.apache.geaflow.utils.keygroup.KeyGroup; +import com.google.common.base.Preconditions; + public class BaseStateManager { - protected StateContext context; - protected Configuration config; - protected MetricGroup metricGroup; - protected KeyGroup keyGroup; - protected Map accessorMap = new HashMap<>(); + protected StateContext context; + protected Configuration config; + protected MetricGroup metricGroup; + protected KeyGroup keyGroup; + protected Map accessorMap = new HashMap<>(); - public void init(StateContext context) { - this.context = context; - this.config = context.getConfig(); - this.metricGroup = context.getMetricGroup(); - this.keyGroup = context.getKeyGroup(); - IStoreBuilder storeBuilder = StoreBuilderFactory.build(context.getStoreType()); - this.context.withLocalStore(storeBuilder.getStoreDesc().isLocalStore()); + public void init(StateContext context) { + this.context = context; + this.config = context.getConfig(); + this.metricGroup = context.getMetricGroup(); + this.keyGroup = context.getKeyGroup(); + IStoreBuilder storeBuilder = StoreBuilderFactory.build(context.getStoreType()); + this.context.withLocalStore(storeBuilder.getStoreDesc().isLocalStore()); - for (int shardId = keyGroup.getStartKeyGroup(); shardId <= keyGroup.getEndKeyGroup(); - shardId++) { - IAccessor accessor = AccessorBuilder.getAccessor(context.getDataModel(), - context.getStateMode()); - StateContext newContext = context.clone().withShardId(shardId); - accessor.init(newContext, storeBuilder); + for (int shardId = keyGroup.getStartKeyGroup(); + shardId <= keyGroup.getEndKeyGroup(); + shardId++) { + IAccessor accessor = + AccessorBuilder.getAccessor(context.getDataModel(), context.getStateMode()); + StateContext newContext = context.clone().withShardId(shardId); + accessor.init(newContext, storeBuilder); - this.accessorMap.put(shardId, accessor); - } + this.accessorMap.put(shardId, accessor); } + } - public void doStoreAction(ActionType actionType, ActionRequest request) { - if (actionType == ActionType.LOAD) { - KeyGroup loadKeyGroup = ((LoadOption) request.getRequest()).getKeyGroup(); - Preconditions.checkArgument( - loadKeyGroup == null || this.keyGroup.contains(loadKeyGroup)); - } - this.accessorMap.forEach((key, value) -> value.doStoreAction(key, actionType, request)); - + public void doStoreAction(ActionType actionType, ActionRequest request) { + if (actionType == ActionType.LOAD) { + KeyGroup loadKeyGroup = ((LoadOption) request.getRequest()).getKeyGroup(); + Preconditions.checkArgument(loadKeyGroup == null || this.keyGroup.contains(loadKeyGroup)); } + this.accessorMap.forEach((key, value) -> value.doStoreAction(key, actionType, request)); + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/DynamicGraphManagerImpl.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/DynamicGraphManagerImpl.java index 3984c3bcb..88b52739a 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/DynamicGraphManagerImpl.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/DynamicGraphManagerImpl.java @@ -19,7 +19,6 @@ package org.apache.geaflow.state.strategy.manager; -import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -28,6 +27,7 @@ import java.util.function.BiFunction; import java.util.function.Function; import java.util.stream.Collectors; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; @@ -43,166 +43,184 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class DynamicGraphManagerImpl extends BaseShardManager> implements DynamicGraphTrait { - - private static final Logger LOGGER = LoggerFactory.getLogger(DynamicGraphManagerImpl.class); - - public DynamicGraphManagerImpl(StateContext context, Map accessorMap) { - super(context, accessorMap); - } - - @Override - public void addEdge(long version, IEdge edge) { - getTraitByKey(edge.getSrcId()).addEdge(version, edge); - } - - @Override - public List> getEdges(long version, K sid, IStatePushDown pushdown) { - return getTraitByKey(sid).getEdges(version, sid, pushdown); - } - - @Override - public OneDegreeGraph getOneDegreeGraph(long version, K sid, - IStatePushDown pushdown) { - return getTraitByKey(sid).getOneDegreeGraph(version, sid, pushdown); - } - - @Override - public void addVertex(long version, IVertex vertex) { - getTraitByKey(vertex.getId()).addVertex(version, vertex); - } - - @Override - public IVertex getVertex(long version, K sid, IStatePushDown pushdown) { - return getTraitByKey(sid).getVertex(version, sid, pushdown); - } - - @Override - public CloseableIterator vertexIDIterator() { - List> iterators = new ArrayList<>(); - for (Entry> entry : traitMap.entrySet()) { - CloseableIterator iterator = entry.getValue().vertexIDIterator(); - iterators.add(this.mayScale ? shardFilter(iterator, entry.getKey(), k -> k) : iterator); - } - return iterators.size() == 1 ? iterators.get(0) : new MultiIterator<>(iterators.iterator()); - } - - @Override - public CloseableIterator vertexIDIterator(long version, IStatePushDown pushdown) { - List> iterators = new ArrayList<>(); - KeyGroup shardGroup = getShardGroup(pushdown); - - for (int shard = shardGroup.getStartKeyGroup(); shard <= shardGroup.getEndKeyGroup(); shard++) { - DynamicGraphTrait trait = traitMap.get(shard); - CloseableIterator iterator = trait.vertexIDIterator(version, pushdown); - iterators.add(this.mayScale ? shardFilter(iterator, shard, k -> k) : iterator); - } - return iterators.size() == 1 ? iterators.get(0) : new MultiIterator<>(iterators.iterator()); - } - - @Override - public CloseableIterator> getVertexIterator(long version, IStatePushDown pushdown) { - return getIterator(IVertex::getId, pushdown, (trait, pushdown1) -> trait.getVertexIterator(version, pushdown1)); - } - - @Override - public CloseableIterator> getVertexIterator(long version, List keys, - IStatePushDown pushdown) { - return getIterator(keys, pushdown, - (trait, keys1, pushdown1) -> trait.getVertexIterator(version, keys1, pushdown1)); - } - - @Override - public CloseableIterator> getEdgeIterator(long version, IStatePushDown pushdown) { - return getIterator(IEdge::getSrcId, pushdown, (trait, pushdown1) -> trait.getEdgeIterator(version, pushdown1)); - } - - @Override - public CloseableIterator> getEdgeIterator(long version, List keys, - IStatePushDown pushdown) { - return getIterator(keys, pushdown, - (trait, keys1, pushdown1) -> trait.getEdgeIterator(version, keys1, pushdown1)); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(long version, - IStatePushDown pushdown) { - return getIterator(OneDegreeGraph::getKey, pushdown, - (trait, pushdown1) -> trait.getOneDegreeGraphIterator(version, pushdown1)); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(long version, List keys, - IStatePushDown pushdown) { - return getIterator(keys, pushdown, - (trait, keys1, pushdown1) -> trait.getOneDegreeGraphIterator(version, keys1, pushdown1)); - } - - private CloseableIterator getIterator(List keys, - IStatePushDown pushdown, - TriFunction, List, IStatePushDown, CloseableIterator> function) { - List> iterators = new ArrayList<>(); - Map> keyGroupMap = getKeyGroupMap(keys); - for (Entry> entry : keyGroupMap.entrySet()) { - Preconditions.checkArgument( - entry.getKey() >= this.shardGroup.getStartKeyGroup() - && entry.getKey() <= this.shardGroup.getEndKeyGroup()); - - CloseableIterator iterator = function.apply(getTraitById(entry.getKey()), entry.getValue(), pushdown); - iterators.add(iterator); - } - return iterators.size() == 1 ? iterators.get(0) : new MultiIterator<>(iterators.iterator()); - } - - - private CloseableIterator getIterator( - Function keyExtractor, - IStatePushDown pushdown, - BiFunction, IStatePushDown, CloseableIterator> function) { - List> iterators = new ArrayList<>(); - - KeyGroup shardGroup = getShardGroup(pushdown); - int startShard = shardGroup.getStartKeyGroup(); - int endShard = shardGroup.getEndKeyGroup(); - - for (int shard = startShard; shard <= endShard; shard++) { - DynamicGraphTrait trait = traitMap.get(shard); - CloseableIterator iterator = function.apply(trait, pushdown); - iterators.add(this.mayScale ? shardFilter(iterator, shard, keyExtractor) : iterator); - } - - return iterators.size() == 1 ? iterators.get(0) : new MultiIterator<>(iterators.iterator()); - } - - @Override - public List getAllVersions(K id, DataType dataType) { - return getTraitByKey(id).getAllVersions(id, dataType); - } - - @Override - public long getLatestVersion(K id, DataType dataType) { - return getTraitByKey(id).getLatestVersion(id, dataType); - } - - @Override - public Map> getAllVersionData(K id, IStatePushDown pushdown, - DataType dataType) { - return getTraitByKey(id).getAllVersionData(id, pushdown, dataType); - } - - @Override - public Map> getVersionData(K id, Collection versions, - IStatePushDown pushdown, DataType dataType) { - return getTraitByKey(id).getVersionData(id, versions, pushdown, dataType); - } - - private CloseableIterator shardFilter(CloseableIterator iterator, int keyGroupId, - Function keyExtractor) { - return new IteratorWithFilter<>(iterator, t -> assigner.assign(keyExtractor.apply(t)) == keyGroupId); - } +import com.google.common.base.Preconditions; - private Map> getKeyGroupMap(Collection keySet) { - return keySet.stream().collect(Collectors.groupingBy(c -> assigner.assign(c))); - } +public class DynamicGraphManagerImpl + extends BaseShardManager> + implements DynamicGraphTrait { + + private static final Logger LOGGER = LoggerFactory.getLogger(DynamicGraphManagerImpl.class); + + public DynamicGraphManagerImpl(StateContext context, Map accessorMap) { + super(context, accessorMap); + } + + @Override + public void addEdge(long version, IEdge edge) { + getTraitByKey(edge.getSrcId()).addEdge(version, edge); + } + + @Override + public List> getEdges(long version, K sid, IStatePushDown pushdown) { + return getTraitByKey(sid).getEdges(version, sid, pushdown); + } + + @Override + public OneDegreeGraph getOneDegreeGraph(long version, K sid, IStatePushDown pushdown) { + return getTraitByKey(sid).getOneDegreeGraph(version, sid, pushdown); + } + + @Override + public void addVertex(long version, IVertex vertex) { + getTraitByKey(vertex.getId()).addVertex(version, vertex); + } + + @Override + public IVertex getVertex(long version, K sid, IStatePushDown pushdown) { + return getTraitByKey(sid).getVertex(version, sid, pushdown); + } + + @Override + public CloseableIterator vertexIDIterator() { + List> iterators = new ArrayList<>(); + for (Entry> entry : traitMap.entrySet()) { + CloseableIterator iterator = entry.getValue().vertexIDIterator(); + iterators.add(this.mayScale ? shardFilter(iterator, entry.getKey(), k -> k) : iterator); + } + return iterators.size() == 1 ? iterators.get(0) : new MultiIterator<>(iterators.iterator()); + } + + @Override + public CloseableIterator vertexIDIterator(long version, IStatePushDown pushdown) { + List> iterators = new ArrayList<>(); + KeyGroup shardGroup = getShardGroup(pushdown); + + for (int shard = shardGroup.getStartKeyGroup(); shard <= shardGroup.getEndKeyGroup(); shard++) { + DynamicGraphTrait trait = traitMap.get(shard); + CloseableIterator iterator = trait.vertexIDIterator(version, pushdown); + iterators.add(this.mayScale ? shardFilter(iterator, shard, k -> k) : iterator); + } + return iterators.size() == 1 ? iterators.get(0) : new MultiIterator<>(iterators.iterator()); + } + + @Override + public CloseableIterator> getVertexIterator( + long version, IStatePushDown pushdown) { + return getIterator( + IVertex::getId, + pushdown, + (trait, pushdown1) -> trait.getVertexIterator(version, pushdown1)); + } + + @Override + public CloseableIterator> getVertexIterator( + long version, List keys, IStatePushDown pushdown) { + return getIterator( + keys, + pushdown, + (trait, keys1, pushdown1) -> trait.getVertexIterator(version, keys1, pushdown1)); + } + + @Override + public CloseableIterator> getEdgeIterator(long version, IStatePushDown pushdown) { + return getIterator( + IEdge::getSrcId, pushdown, (trait, pushdown1) -> trait.getEdgeIterator(version, pushdown1)); + } + + @Override + public CloseableIterator> getEdgeIterator( + long version, List keys, IStatePushDown pushdown) { + return getIterator( + keys, + pushdown, + (trait, keys1, pushdown1) -> trait.getEdgeIterator(version, keys1, pushdown1)); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + long version, IStatePushDown pushdown) { + return getIterator( + OneDegreeGraph::getKey, + pushdown, + (trait, pushdown1) -> trait.getOneDegreeGraphIterator(version, pushdown1)); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + long version, List keys, IStatePushDown pushdown) { + return getIterator( + keys, + pushdown, + (trait, keys1, pushdown1) -> trait.getOneDegreeGraphIterator(version, keys1, pushdown1)); + } + + private CloseableIterator getIterator( + List keys, + IStatePushDown pushdown, + TriFunction, List, IStatePushDown, CloseableIterator> + function) { + List> iterators = new ArrayList<>(); + Map> keyGroupMap = getKeyGroupMap(keys); + for (Entry> entry : keyGroupMap.entrySet()) { + Preconditions.checkArgument( + entry.getKey() >= this.shardGroup.getStartKeyGroup() + && entry.getKey() <= this.shardGroup.getEndKeyGroup()); + + CloseableIterator iterator = + function.apply(getTraitById(entry.getKey()), entry.getValue(), pushdown); + iterators.add(iterator); + } + return iterators.size() == 1 ? iterators.get(0) : new MultiIterator<>(iterators.iterator()); + } + + private CloseableIterator getIterator( + Function keyExtractor, + IStatePushDown pushdown, + BiFunction, IStatePushDown, CloseableIterator> function) { + List> iterators = new ArrayList<>(); + + KeyGroup shardGroup = getShardGroup(pushdown); + int startShard = shardGroup.getStartKeyGroup(); + int endShard = shardGroup.getEndKeyGroup(); + + for (int shard = startShard; shard <= endShard; shard++) { + DynamicGraphTrait trait = traitMap.get(shard); + CloseableIterator iterator = function.apply(trait, pushdown); + iterators.add(this.mayScale ? shardFilter(iterator, shard, keyExtractor) : iterator); + } + + return iterators.size() == 1 ? iterators.get(0) : new MultiIterator<>(iterators.iterator()); + } + + @Override + public List getAllVersions(K id, DataType dataType) { + return getTraitByKey(id).getAllVersions(id, dataType); + } + + @Override + public long getLatestVersion(K id, DataType dataType) { + return getTraitByKey(id).getLatestVersion(id, dataType); + } + + @Override + public Map> getAllVersionData( + K id, IStatePushDown pushdown, DataType dataType) { + return getTraitByKey(id).getAllVersionData(id, pushdown, dataType); + } + + @Override + public Map> getVersionData( + K id, Collection versions, IStatePushDown pushdown, DataType dataType) { + return getTraitByKey(id).getVersionData(id, versions, pushdown, dataType); + } + + private CloseableIterator shardFilter( + CloseableIterator iterator, int keyGroupId, Function keyExtractor) { + return new IteratorWithFilter<>( + iterator, t -> assigner.assign(keyExtractor.apply(t)) == keyGroupId); + } + + private Map> getKeyGroupMap(Collection keySet) { + return keySet.stream().collect(Collectors.groupingBy(c -> assigner.assign(c))); + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/GraphManagerImpl.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/GraphManagerImpl.java index 83f3b7280..84cb4bb05 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/GraphManagerImpl.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/GraphManagerImpl.java @@ -19,7 +19,6 @@ package org.apache.geaflow.state.strategy.manager; -import com.google.common.base.Preconditions; import org.apache.geaflow.state.context.StateContext; import org.apache.geaflow.state.graph.DynamicGraphTrait; import org.apache.geaflow.state.graph.StaticGraphTrait; @@ -30,54 +29,55 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class GraphManagerImpl extends BaseStateManager implements IGraphManager { +import com.google.common.base.Preconditions; - private static final Logger LOGGER = LoggerFactory.getLogger(GraphManagerImpl.class); +public class GraphManagerImpl extends BaseStateManager + implements IGraphManager { - private StaticGraphTrait staticGraphTrait; - private DynamicGraphTrait dynamicGraphTrait; - private IFilterConverter filterConverter; + private static final Logger LOGGER = LoggerFactory.getLogger(GraphManagerImpl.class); - public GraphManagerImpl() { + private StaticGraphTrait staticGraphTrait; + private DynamicGraphTrait dynamicGraphTrait; + private IFilterConverter filterConverter; - } + public GraphManagerImpl() {} - @Override - public void init(StateContext context) { - super.init(context); - } + @Override + public void init(StateContext context) { + super.init(context); + } - @Override - public StaticGraphTrait getStaticGraphTrait() { - if (staticGraphTrait == null) { - staticGraphTrait = new StaticGraphManagerImpl<>(this.context, this.accessorMap); - } - return staticGraphTrait; + @Override + public StaticGraphTrait getStaticGraphTrait() { + if (staticGraphTrait == null) { + staticGraphTrait = new StaticGraphManagerImpl<>(this.context, this.accessorMap); } + return staticGraphTrait; + } - @Override - public DynamicGraphTrait getDynamicGraphTrait() { - if (dynamicGraphTrait == null) { - dynamicGraphTrait = new DynamicGraphManagerImpl<>(this.context, this.accessorMap); - } - return dynamicGraphTrait; + @Override + public DynamicGraphTrait getDynamicGraphTrait() { + if (dynamicGraphTrait == null) { + dynamicGraphTrait = new DynamicGraphManagerImpl<>(this.context, this.accessorMap); } + return dynamicGraphTrait; + } - @Override - public IFilterConverter getFilterConverter() { - if (filterConverter == null) { - if (this.accessorMap.values().size() > 0) { - for (IAccessor value : this.accessorMap.values()) { - IBaseStore store = value.getStore(); - if (store == null) { - continue; - } - if (store instanceof IPushDownStore) { - filterConverter = ((IPushDownStore) store).getFilterConverter(); - } - } - } + @Override + public IFilterConverter getFilterConverter() { + if (filterConverter == null) { + if (this.accessorMap.values().size() > 0) { + for (IAccessor value : this.accessorMap.values()) { + IBaseStore store = value.getStore(); + if (store == null) { + continue; + } + if (store instanceof IPushDownStore) { + filterConverter = ((IPushDownStore) store).getFilterConverter(); + } } - return Preconditions.checkNotNull(filterConverter); + } } + return Preconditions.checkNotNull(filterConverter); + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/IGraphManager.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/IGraphManager.java index 9115ca50e..f89129fbd 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/IGraphManager.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/IGraphManager.java @@ -25,9 +25,9 @@ public interface IGraphManager extends IStateManager { - StaticGraphTrait getStaticGraphTrait(); + StaticGraphTrait getStaticGraphTrait(); - DynamicGraphTrait getDynamicGraphTrait(); + DynamicGraphTrait getDynamicGraphTrait(); - IFilterConverter getFilterConverter(); + IFilterConverter getFilterConverter(); } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/IKeyStateManager.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/IKeyStateManager.java index 83ea0c570..260827f88 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/IKeyStateManager.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/IKeyStateManager.java @@ -25,9 +25,9 @@ public interface IKeyStateManager extends IStateManager { - KeyValueTrait getKeyValueTrait(Class valueClazz); + KeyValueTrait getKeyValueTrait(Class valueClazz); - KeyListTrait getKeyListTrait(Class valueClazz); + KeyListTrait getKeyListTrait(Class valueClazz); - KeyMapTrait getKeyMapTrait(Class subKeyClazz, Class valueClazz); + KeyMapTrait getKeyMapTrait(Class subKeyClazz, Class valueClazz); } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/IStateManager.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/IStateManager.java index 7d66445dc..7d6bfcfef 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/IStateManager.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/IStateManager.java @@ -25,7 +25,7 @@ public interface IStateManager { - void init(StateContext context); + void init(StateContext context); - void doStoreAction(ActionType actionType, ActionRequest request); + void doStoreAction(ActionType actionType, ActionRequest request); } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/KeyListManagerImpl.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/KeyListManagerImpl.java index 6f564f402..92b00147a 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/KeyListManagerImpl.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/KeyListManagerImpl.java @@ -21,28 +21,30 @@ import java.util.List; import java.util.Map; + import org.apache.geaflow.state.context.StateContext; import org.apache.geaflow.state.key.KeyListTrait; import org.apache.geaflow.state.strategy.accessor.IAccessor; -public class KeyListManagerImpl extends BaseShardManager> implements KeyListTrait { +public class KeyListManagerImpl extends BaseShardManager> + implements KeyListTrait { - public KeyListManagerImpl(StateContext context, Map accessorMap) { - super(context, accessorMap); - } + public KeyListManagerImpl(StateContext context, Map accessorMap) { + super(context, accessorMap); + } - @Override - public List get(K key) { - return getTraitByKey(key).get(key); - } + @Override + public List get(K key) { + return getTraitByKey(key).get(key); + } - @Override - public void add(K key, V... value) { - getTraitByKey(key).add(key, value); - } + @Override + public void add(K key, V... value) { + getTraitByKey(key).add(key, value); + } - @Override - public void remove(K key) { - getTraitByKey(key).remove(key); - } + @Override + public void remove(K key) { + getTraitByKey(key).remove(key); + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/KeyMapManagerImpl.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/KeyMapManagerImpl.java index 54951f816..3bacb64a5 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/KeyMapManagerImpl.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/KeyMapManagerImpl.java @@ -21,44 +21,45 @@ import java.util.List; import java.util.Map; + import org.apache.geaflow.state.context.StateContext; import org.apache.geaflow.state.key.KeyMapTrait; import org.apache.geaflow.state.strategy.accessor.IAccessor; -public class KeyMapManagerImpl extends BaseShardManager> implements KeyMapTrait { +public class KeyMapManagerImpl extends BaseShardManager> + implements KeyMapTrait { - public KeyMapManagerImpl(StateContext context, Map accessorMap) { - super(context, accessorMap); - } + public KeyMapManagerImpl(StateContext context, Map accessorMap) { + super(context, accessorMap); + } - @Override - public void add(K key, Map value) { - getTraitByKey(key).add(key, value); - } + @Override + public void add(K key, Map value) { + getTraitByKey(key).add(key, value); + } - @Override - public void remove(K key) { - getTraitByKey(key).remove(key); - } + @Override + public void remove(K key) { + getTraitByKey(key).remove(key); + } - @Override - public void remove(K key, UK... subKeys) { - getTraitByKey(key).remove(key, subKeys); - } + @Override + public void remove(K key, UK... subKeys) { + getTraitByKey(key).remove(key, subKeys); + } - @Override - public void add(K key, UK uk, UV value) { - getTraitByKey(key).add(key, uk, value); - } + @Override + public void add(K key, UK uk, UV value) { + getTraitByKey(key).add(key, uk, value); + } - @Override - public Map get(K key) { - return getTraitByKey(key).get(key); - } + @Override + public Map get(K key) { + return getTraitByKey(key).get(key); + } - @Override - public List get(K key, UK... subKeys) { - return getTraitByKey(key).get(key, subKeys); - } + @Override + public List get(K key, UK... subKeys) { + return getTraitByKey(key).get(key, subKeys); + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/KeyStateManager.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/KeyStateManager.java index d0653c513..c6167f505 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/KeyStateManager.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/KeyStateManager.java @@ -25,32 +25,32 @@ public class KeyStateManager extends BaseStateManager implements IKeyStateManager { - private KeyValueTrait keyValueTrait; - private KeyListTrait keyListTrait; - private KeyMapTrait keyMapTrait; + private KeyValueTrait keyValueTrait; + private KeyListTrait keyListTrait; + private KeyMapTrait keyMapTrait; - @Override - public KeyValueTrait getKeyValueTrait(Class valueClazz) { - if (this.keyValueTrait == null) { - this.keyValueTrait = new KeyValueManagerImpl<>(context, accessorMap); - } - return this.keyValueTrait; + @Override + public KeyValueTrait getKeyValueTrait(Class valueClazz) { + if (this.keyValueTrait == null) { + this.keyValueTrait = new KeyValueManagerImpl<>(context, accessorMap); } + return this.keyValueTrait; + } - @Override - public KeyListTrait getKeyListTrait(Class valueClazz) { - if (this.keyListTrait == null) { - this.keyListTrait = new KeyListManagerImpl<>(context, accessorMap); - } - return this.keyListTrait; + @Override + public KeyListTrait getKeyListTrait(Class valueClazz) { + if (this.keyListTrait == null) { + this.keyListTrait = new KeyListManagerImpl<>(context, accessorMap); } + return this.keyListTrait; + } - @Override - public KeyMapTrait getKeyMapTrait(Class subKeyClazz, - Class valueClazz) { - if (this.keyMapTrait == null) { - this.keyMapTrait = new KeyMapManagerImpl<>(context, accessorMap); - } - return this.keyMapTrait; + @Override + public KeyMapTrait getKeyMapTrait( + Class subKeyClazz, Class valueClazz) { + if (this.keyMapTrait == null) { + this.keyMapTrait = new KeyMapManagerImpl<>(context, accessorMap); } + return this.keyMapTrait; + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/KeyValueManagerImpl.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/KeyValueManagerImpl.java index 161ae5678..2aee0fa53 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/KeyValueManagerImpl.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/KeyValueManagerImpl.java @@ -20,28 +20,30 @@ package org.apache.geaflow.state.strategy.manager; import java.util.Map; + import org.apache.geaflow.state.context.StateContext; import org.apache.geaflow.state.key.KeyValueTrait; import org.apache.geaflow.state.strategy.accessor.IAccessor; -public class KeyValueManagerImpl extends BaseShardManager> implements KeyValueTrait { +public class KeyValueManagerImpl extends BaseShardManager> + implements KeyValueTrait { - public KeyValueManagerImpl(StateContext context, Map accessorMap) { - super(context, accessorMap); - } + public KeyValueManagerImpl(StateContext context, Map accessorMap) { + super(context, accessorMap); + } - @Override - public V get(K key) { - return getTraitByKey(key).get(key); - } + @Override + public V get(K key) { + return getTraitByKey(key).get(key); + } - @Override - public void put(K key, V value) { - getTraitByKey(key).put(key, value); - } + @Override + public void put(K key, V value) { + getTraitByKey(key).put(key, value); + } - @Override - public void remove(K key) { - getTraitByKey(key).remove(key); - } + @Override + public void remove(K key) { + getTraitByKey(key).remove(key); + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/StaticGraphManagerImpl.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/StaticGraphManagerImpl.java index 2cdafe697..0cb08a23b 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/StaticGraphManagerImpl.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/StaticGraphManagerImpl.java @@ -19,7 +19,6 @@ package org.apache.geaflow.state.strategy.manager; -import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; @@ -29,6 +28,7 @@ import java.util.function.BiFunction; import java.util.function.Function; import java.util.stream.Collectors; + import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.common.tuple.Tuple; import org.apache.geaflow.model.graph.edge.IEdge; @@ -42,183 +42,198 @@ import org.apache.geaflow.state.strategy.accessor.IAccessor; import org.apache.geaflow.utils.keygroup.KeyGroup; -public class StaticGraphManagerImpl extends BaseShardManager> implements StaticGraphTrait { - - - public StaticGraphManagerImpl(StateContext context, Map accessorMap) { - super(context, accessorMap); - } - - @Override - public void addEdge(IEdge edge) { - getTraitByKey(edge.getSrcId()).addEdge(edge); - } - - @Override - public List> getEdges(K sid, IStatePushDown pushdown) { - return getTraitByKey(sid).getEdges(sid, pushdown); - } - - @Override - public OneDegreeGraph getOneDegreeGraph(K sid, IStatePushDown pushdown) { - return getTraitByKey(sid).getOneDegreeGraph(sid, pushdown); - } - - @Override - public void addVertex(IVertex vertex) { - getTraitByKey(vertex.getId()).addVertex(vertex); - } - - @Override - public IVertex getVertex(K sid, IStatePushDown pushdown) { - return getTraitByKey(sid).getVertex(sid, pushdown); - } - - @Override - public CloseableIterator vertexIDIterator() { - List> iterators = new ArrayList<>(); - for (Entry> entry : traitMap.entrySet()) { - CloseableIterator iterator = entry.getValue().vertexIDIterator(); - iterators.add(this.mayScale ? shardFilter(iterator, entry.getKey(), k -> k) : iterator); - } - return iterators.size() == 1 ? iterators.get(0) : new MultiIterator<>(iterators.iterator()); - } - - @Override - public CloseableIterator vertexIDIterator(IStatePushDown pushdown) { - List> iterators = new ArrayList<>(); - KeyGroup shardGroup = getShardGroup(pushdown); - - for (int shard = shardGroup.getStartKeyGroup(); shard <= shardGroup.getEndKeyGroup(); shard++) { - StaticGraphTrait trait = traitMap.get(shard); - CloseableIterator iterator = trait.vertexIDIterator(pushdown); - iterators.add(this.mayScale ? shardFilter(iterator, shard, k -> k) : iterator); - } - return iterators.size() == 1 ? iterators.get(0) : new MultiIterator<>(iterators.iterator()); - } - - @Override - public CloseableIterator> getVertexIterator(IStatePushDown pushdown) { - return getIterator(IVertex::getId, pushdown, StaticGraphTrait::getVertexIterator); - } - - @Override - public CloseableIterator> getVertexIterator(List keys, IStatePushDown pushdown) { - return getIterator(keys, pushdown, StaticGraphTrait::getVertexIterator); - } - - @Override - public CloseableIterator> getEdgeIterator(IStatePushDown pushdown) { - return getIterator(IEdge::getSrcId, pushdown, StaticGraphTrait::getEdgeIterator); - } - - @Override - public CloseableIterator> getEdgeIterator(List keys, IStatePushDown pushdown) { - return getIterator(keys, pushdown, StaticGraphTrait::getEdgeIterator); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator( - IStatePushDown pushdown) { - return getIterator(OneDegreeGraph::getKey, pushdown, StaticGraphTrait::getOneDegreeGraphIterator); - } - - @Override - public CloseableIterator> getOneDegreeGraphIterator(List keys, IStatePushDown pushdown) { - return getIterator(keys, pushdown, StaticGraphTrait::getOneDegreeGraphIterator); - } - - @Override - public CloseableIterator> getEdgeProjectIterator( - IStatePushDown, R> pushdown) { - return getIterator(Tuple::getF0, pushdown, - (trait, pd) -> (CloseableIterator>) trait.getEdgeProjectIterator(pd)); - } - - @Override - public CloseableIterator> getEdgeProjectIterator(List keys, - IStatePushDown, R> pushdown) { - return getIterator(keys, pushdown, - (trait, keys1, pushdown1) -> trait.getEdgeProjectIterator(keys1, pushdown1)); - } - - @Override - public Map getAggResult(IStatePushDown pushdown) { - Map> map = new HashMap<>(); - KeyGroup shardGroup = getShardGroup(pushdown); - - for (int shard = shardGroup.getStartKeyGroup(); shard <= shardGroup.getEndKeyGroup(); shard++) { - map.put(shard, traitMap.get(shard).getAggResult(pushdown)); - } - - Map res = new HashMap<>(); - for (Entry> partRes : map.entrySet()) { - int keyGroupId = partRes.getKey(); - for (Entry entry : partRes.getValue().entrySet()) { - if (keyGroupId != assigner.assign(entry.getKey())) { - continue; - } - res.put(entry.getKey(), entry.getValue()); - } - } - return res; - } - - @Override - public Map getAggResult(List keys, IStatePushDown pushdown) { - Map res = new HashMap<>(); - Map> keyGroupMap = getKeyGroupMap(keys); - for (Entry> entry : keyGroupMap.entrySet()) { - Preconditions.checkArgument(entry.getKey() >= this.shardGroup.getStartKeyGroup() - && entry.getKey() <= this.shardGroup.getEndKeyGroup()); - - res.putAll(getTraitById(entry.getKey()).getAggResult(entry.getValue(), pushdown)); - } - return res; - } - - private CloseableIterator getIterator(List keys, - IStatePushDown pushdown, - TriFunction, List, IStatePushDown, CloseableIterator> function) { - List> iterators = new ArrayList<>(); - Map> keyGroupMap = getKeyGroupMap(keys); - for (Entry> entry : keyGroupMap.entrySet()) { - Preconditions.checkArgument(entry.getKey() >= this.shardGroup.getStartKeyGroup() - && entry.getKey() <= this.shardGroup.getEndKeyGroup()); - - CloseableIterator iterator = function.apply(getTraitById(entry.getKey()), entry.getValue(), pushdown); - iterators.add(iterator); - } - return iterators.size() == 1 ? iterators.get(0) : new MultiIterator<>(iterators.iterator()); - } - - private CloseableIterator getIterator( - Function keyExtractor, - IStatePushDown pushdown, - BiFunction, IStatePushDown, CloseableIterator> function) { - List> iterators = new ArrayList<>(); - - KeyGroup shardGroup = getShardGroup(pushdown); - - int startShard = shardGroup.getStartKeyGroup(); - int endShard = shardGroup.getEndKeyGroup(); +import com.google.common.base.Preconditions; - for (int shard = startShard; shard <= endShard; shard++) { - StaticGraphTrait trait = traitMap.get(shard); - CloseableIterator iterator = function.apply(trait, pushdown); - iterators.add(this.mayScale ? shardFilter(iterator, shard, keyExtractor) : iterator); +public class StaticGraphManagerImpl + extends BaseShardManager> + implements StaticGraphTrait { + + public StaticGraphManagerImpl(StateContext context, Map accessorMap) { + super(context, accessorMap); + } + + @Override + public void addEdge(IEdge edge) { + getTraitByKey(edge.getSrcId()).addEdge(edge); + } + + @Override + public List> getEdges(K sid, IStatePushDown pushdown) { + return getTraitByKey(sid).getEdges(sid, pushdown); + } + + @Override + public OneDegreeGraph getOneDegreeGraph(K sid, IStatePushDown pushdown) { + return getTraitByKey(sid).getOneDegreeGraph(sid, pushdown); + } + + @Override + public void addVertex(IVertex vertex) { + getTraitByKey(vertex.getId()).addVertex(vertex); + } + + @Override + public IVertex getVertex(K sid, IStatePushDown pushdown) { + return getTraitByKey(sid).getVertex(sid, pushdown); + } + + @Override + public CloseableIterator vertexIDIterator() { + List> iterators = new ArrayList<>(); + for (Entry> entry : traitMap.entrySet()) { + CloseableIterator iterator = entry.getValue().vertexIDIterator(); + iterators.add(this.mayScale ? shardFilter(iterator, entry.getKey(), k -> k) : iterator); + } + return iterators.size() == 1 ? iterators.get(0) : new MultiIterator<>(iterators.iterator()); + } + + @Override + public CloseableIterator vertexIDIterator(IStatePushDown pushdown) { + List> iterators = new ArrayList<>(); + KeyGroup shardGroup = getShardGroup(pushdown); + + for (int shard = shardGroup.getStartKeyGroup(); shard <= shardGroup.getEndKeyGroup(); shard++) { + StaticGraphTrait trait = traitMap.get(shard); + CloseableIterator iterator = trait.vertexIDIterator(pushdown); + iterators.add(this.mayScale ? shardFilter(iterator, shard, k -> k) : iterator); + } + return iterators.size() == 1 ? iterators.get(0) : new MultiIterator<>(iterators.iterator()); + } + + @Override + public CloseableIterator> getVertexIterator(IStatePushDown pushdown) { + return getIterator(IVertex::getId, pushdown, StaticGraphTrait::getVertexIterator); + } + + @Override + public CloseableIterator> getVertexIterator( + List keys, IStatePushDown pushdown) { + return getIterator(keys, pushdown, StaticGraphTrait::getVertexIterator); + } + + @Override + public CloseableIterator> getEdgeIterator(IStatePushDown pushdown) { + return getIterator(IEdge::getSrcId, pushdown, StaticGraphTrait::getEdgeIterator); + } + + @Override + public CloseableIterator> getEdgeIterator(List keys, IStatePushDown pushdown) { + return getIterator(keys, pushdown, StaticGraphTrait::getEdgeIterator); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + IStatePushDown pushdown) { + return getIterator( + OneDegreeGraph::getKey, pushdown, StaticGraphTrait::getOneDegreeGraphIterator); + } + + @Override + public CloseableIterator> getOneDegreeGraphIterator( + List keys, IStatePushDown pushdown) { + return getIterator(keys, pushdown, StaticGraphTrait::getOneDegreeGraphIterator); + } + + @Override + public CloseableIterator> getEdgeProjectIterator( + IStatePushDown, R> pushdown) { + return getIterator( + Tuple::getF0, + pushdown, + (trait, pd) -> (CloseableIterator>) trait.getEdgeProjectIterator(pd)); + } + + @Override + public CloseableIterator> getEdgeProjectIterator( + List keys, IStatePushDown, R> pushdown) { + return getIterator( + keys, + pushdown, + (trait, keys1, pushdown1) -> trait.getEdgeProjectIterator(keys1, pushdown1)); + } + + @Override + public Map getAggResult(IStatePushDown pushdown) { + Map> map = new HashMap<>(); + KeyGroup shardGroup = getShardGroup(pushdown); + + for (int shard = shardGroup.getStartKeyGroup(); shard <= shardGroup.getEndKeyGroup(); shard++) { + map.put(shard, traitMap.get(shard).getAggResult(pushdown)); + } + + Map res = new HashMap<>(); + for (Entry> partRes : map.entrySet()) { + int keyGroupId = partRes.getKey(); + for (Entry entry : partRes.getValue().entrySet()) { + if (keyGroupId != assigner.assign(entry.getKey())) { + continue; } - - return iterators.size() == 1 ? iterators.get(0) : new MultiIterator<>(iterators.iterator()); - } - - private CloseableIterator shardFilter(CloseableIterator iterator, int keyGroupId, - Function keyExtractor) { - return new IteratorWithFilter<>(iterator, t -> assigner.assign(keyExtractor.apply(t)) == keyGroupId); - } - - private Map> getKeyGroupMap(Collection keySet) { - return keySet.stream().collect(Collectors.groupingBy(c -> assigner.assign(c))); - } + res.put(entry.getKey(), entry.getValue()); + } + } + return res; + } + + @Override + public Map getAggResult(List keys, IStatePushDown pushdown) { + Map res = new HashMap<>(); + Map> keyGroupMap = getKeyGroupMap(keys); + for (Entry> entry : keyGroupMap.entrySet()) { + Preconditions.checkArgument( + entry.getKey() >= this.shardGroup.getStartKeyGroup() + && entry.getKey() <= this.shardGroup.getEndKeyGroup()); + + res.putAll(getTraitById(entry.getKey()).getAggResult(entry.getValue(), pushdown)); + } + return res; + } + + private CloseableIterator getIterator( + List keys, + IStatePushDown pushdown, + TriFunction, List, IStatePushDown, CloseableIterator> + function) { + List> iterators = new ArrayList<>(); + Map> keyGroupMap = getKeyGroupMap(keys); + for (Entry> entry : keyGroupMap.entrySet()) { + Preconditions.checkArgument( + entry.getKey() >= this.shardGroup.getStartKeyGroup() + && entry.getKey() <= this.shardGroup.getEndKeyGroup()); + + CloseableIterator iterator = + function.apply(getTraitById(entry.getKey()), entry.getValue(), pushdown); + iterators.add(iterator); + } + return iterators.size() == 1 ? iterators.get(0) : new MultiIterator<>(iterators.iterator()); + } + + private CloseableIterator getIterator( + Function keyExtractor, + IStatePushDown pushdown, + BiFunction, IStatePushDown, CloseableIterator> function) { + List> iterators = new ArrayList<>(); + + KeyGroup shardGroup = getShardGroup(pushdown); + + int startShard = shardGroup.getStartKeyGroup(); + int endShard = shardGroup.getEndKeyGroup(); + + for (int shard = startShard; shard <= endShard; shard++) { + StaticGraphTrait trait = traitMap.get(shard); + CloseableIterator iterator = function.apply(trait, pushdown); + iterators.add(this.mayScale ? shardFilter(iterator, shard, keyExtractor) : iterator); + } + + return iterators.size() == 1 ? iterators.get(0) : new MultiIterator<>(iterators.iterator()); + } + + private CloseableIterator shardFilter( + CloseableIterator iterator, int keyGroupId, Function keyExtractor) { + return new IteratorWithFilter<>( + iterator, t -> assigner.assign(keyExtractor.apply(t)) == keyGroupId); + } + + private Map> getKeyGroupMap(Collection keySet) { + return keySet.stream().collect(Collectors.groupingBy(c -> assigner.assign(c))); + } } diff --git a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/TriFunction.java b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/TriFunction.java index 717392709..4acea89c9 100644 --- a/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/TriFunction.java +++ b/geaflow/geaflow-state/geaflow-state-strategy/src/main/java/org/apache/geaflow/state/strategy/manager/TriFunction.java @@ -22,6 +22,5 @@ @FunctionalInterface public interface TriFunction { - R apply(T t, U u, V v); - + R apply(T t, U u, V v); } diff --git a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/ByteUtils.java b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/ByteUtils.java index 3ec7a7700..e3271766a 100644 --- a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/ByteUtils.java +++ b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/ByteUtils.java @@ -21,16 +21,16 @@ public class ByteUtils { - public static boolean isStartsWith(byte[] bytes, byte[] prefix) { - if (bytes == null || prefix == null || prefix.length > bytes.length) { - return false; - } else { - for (int i = 0, j = 0; i < prefix.length; i++, j++) { - if (bytes[i] != prefix[j]) { - return false; - } - } + public static boolean isStartsWith(byte[] bytes, byte[] prefix) { + if (bytes == null || prefix == null || prefix.length > bytes.length) { + return false; + } else { + for (int i = 0, j = 0; i < prefix.length; i++, j++) { + if (bytes[i] != prefix[j]) { + return false; } - return true; + } } + return true; + } } diff --git a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/HttpUtil.java b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/HttpUtil.java index 051fb549e..579f26614 100644 --- a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/HttpUtil.java +++ b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/HttpUtil.java @@ -19,11 +19,20 @@ package org.apache.geaflow.utils; -import com.google.gson.Gson; import java.io.IOException; import java.lang.reflect.Type; import java.util.Map; import java.util.concurrent.Callable; + +import org.apache.geaflow.common.errorcode.RuntimeErrors; +import org.apache.geaflow.common.exception.GeaflowRuntimeException; +import org.apache.geaflow.common.utils.RetryCommand; +import org.apache.geaflow.utils.client.HttpResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.gson.Gson; + import okhttp3.Headers; import okhttp3.MediaType; import okhttp3.OkHttpClient; @@ -32,138 +41,138 @@ import okhttp3.RequestBody; import okhttp3.Response; import okhttp3.ResponseBody; -import org.apache.geaflow.common.errorcode.RuntimeErrors; -import org.apache.geaflow.common.exception.GeaflowRuntimeException; -import org.apache.geaflow.common.utils.RetryCommand; -import org.apache.geaflow.utils.client.HttpResponse; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class HttpUtil { - private static final Logger LOGGER = LoggerFactory.getLogger(HttpUtil.class); - private static final MediaType MEDIA_TYPE = MediaType.parse("application/json; charset=utf-8"); - private static final Gson GSON = new Gson(); - private static final int DEFAULT_RETRY_TIMES = 3; - - public static Object post(String url, String json) { - return post(url, json, Object.class); - } - - public static Object post(String url, String json, Map headers) { - return post(url, json, headers, Object.class); - } - - public static T post(String url, String json, Class resultClass) { - return post(url, json, null, resultClass); - } - - public static T post(String url, String body, Map headers, - Class resultClass) { - LOGGER.info("post url: {} body: {}", url, body); - RequestBody requestBody = RequestBody.create(MEDIA_TYPE, body); - Builder builder = getRequestBuilder(url, headers); - Request request = builder.post(requestBody).build(); - - long t = System.currentTimeMillis(); - OkHttpClient client = new OkHttpClient(); - return RetryCommand.run(new Callable() { - @Override - public T call() throws Exception { - try (Response response = client.newCall(request).execute()) { - ResponseBody responseBody = response.body(); - String msg = (responseBody != null) ? responseBody.string() : "{}"; - if (!response.isSuccessful()) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(msg)); - } - HttpResponse httpResponse = GSON.fromJson(msg, HttpResponse.class); - if (!httpResponse.isSuccess()) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(msg)); - } - T result = GSON.fromJson(httpResponse.getData(), resultClass); - LOGGER.info("post {} response cost {}ms: {}", url, System.currentTimeMillis() - t, msg); - return result; - } catch (IOException e) { - LOGGER.info("execute post failed: {}", e.getCause(), e); - throw new GeaflowRuntimeException(e); - } + private static final Logger LOGGER = LoggerFactory.getLogger(HttpUtil.class); + private static final MediaType MEDIA_TYPE = MediaType.parse("application/json; charset=utf-8"); + private static final Gson GSON = new Gson(); + private static final int DEFAULT_RETRY_TIMES = 3; + + public static Object post(String url, String json) { + return post(url, json, Object.class); + } + + public static Object post(String url, String json, Map headers) { + return post(url, json, headers, Object.class); + } + + public static T post(String url, String json, Class resultClass) { + return post(url, json, null, resultClass); + } + + public static T post( + String url, String body, Map headers, Class resultClass) { + LOGGER.info("post url: {} body: {}", url, body); + RequestBody requestBody = RequestBody.create(MEDIA_TYPE, body); + Builder builder = getRequestBuilder(url, headers); + Request request = builder.post(requestBody).build(); + + long t = System.currentTimeMillis(); + OkHttpClient client = new OkHttpClient(); + return RetryCommand.run( + new Callable() { + @Override + public T call() throws Exception { + try (Response response = client.newCall(request).execute()) { + ResponseBody responseBody = response.body(); + String msg = (responseBody != null) ? responseBody.string() : "{}"; + if (!response.isSuccessful()) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(msg)); + } + HttpResponse httpResponse = GSON.fromJson(msg, HttpResponse.class); + if (!httpResponse.isSuccess()) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(msg)); + } + T result = GSON.fromJson(httpResponse.getData(), resultClass); + LOGGER.info( + "post {} response cost {}ms: {}", url, System.currentTimeMillis() - t, msg); + return result; + } catch (IOException e) { + LOGGER.info("execute post failed: {}", e.getCause(), e); + throw new GeaflowRuntimeException(e); } - }, DEFAULT_RETRY_TIMES); - } - - public static T get(String url, Map headers, Class resultClass) { - return get(url, headers, (Type) resultClass); - } - - public static T get(String url, Map headers, Type typeOfT) { - LOGGER.info("get url: {}", url); - Builder builder = getRequestBuilder(url, headers); - Request request = builder.get().build(); - - long t = System.currentTimeMillis(); - OkHttpClient client = new OkHttpClient(); - return RetryCommand.run(new Callable() { - @Override - public T call() throws Exception { - try (Response response = client.newCall(request).execute()) { - ResponseBody responseBody = response.body(); - String msg = (responseBody != null) ? responseBody.string() : "{}"; - if (!response.isSuccessful()) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(msg)); - } - HttpResponse httpResponse = GSON.fromJson(msg, HttpResponse.class); - if (!httpResponse.isSuccess()) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(msg)); - } - T result = GSON.fromJson(httpResponse.getData(), typeOfT); - LOGGER.info("get {} response cost {}ms: {}", url, System.currentTimeMillis() - t, - msg); - return result; - } catch (IOException e) { - LOGGER.info("execute get failed: {}", e.getCause(), e); - throw new GeaflowRuntimeException(e); - } + } + }, + DEFAULT_RETRY_TIMES); + } + + public static T get(String url, Map headers, Class resultClass) { + return get(url, headers, (Type) resultClass); + } + + public static T get(String url, Map headers, Type typeOfT) { + LOGGER.info("get url: {}", url); + Builder builder = getRequestBuilder(url, headers); + Request request = builder.get().build(); + + long t = System.currentTimeMillis(); + OkHttpClient client = new OkHttpClient(); + return RetryCommand.run( + new Callable() { + @Override + public T call() throws Exception { + try (Response response = client.newCall(request).execute()) { + ResponseBody responseBody = response.body(); + String msg = (responseBody != null) ? responseBody.string() : "{}"; + if (!response.isSuccessful()) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(msg)); + } + HttpResponse httpResponse = GSON.fromJson(msg, HttpResponse.class); + if (!httpResponse.isSuccess()) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(msg)); + } + T result = GSON.fromJson(httpResponse.getData(), typeOfT); + LOGGER.info( + "get {} response cost {}ms: {}", url, System.currentTimeMillis() - t, msg); + return result; + } catch (IOException e) { + LOGGER.info("execute get failed: {}", e.getCause(), e); + throw new GeaflowRuntimeException(e); } - }, DEFAULT_RETRY_TIMES); - } - - public static boolean delete(String url) { - return delete(url, null); - } - - public static boolean delete(String url, Map headers) { - LOGGER.info("delete url: {}", url); - OkHttpClient client = new OkHttpClient(); - Builder requestBuilder = getRequestBuilder(url, headers); - Request request = requestBuilder.delete().build(); - long t = System.currentTimeMillis(); - return RetryCommand.run(new Callable() { - @Override - public Boolean call() throws Exception { - try (Response response = client.newCall(request).execute()) { - ResponseBody body = response.body(); - String msg = (body != null) ? body.string() : "{}"; - if (!response.isSuccessful()) { - throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(msg)); - } else { - LOGGER.info("delete {} cost {}ms", url, System.currentTimeMillis() - t); - return true; - } - } catch (IOException e) { - LOGGER.info("execute delete failed: {}", e.getCause(), e); - throw new GeaflowRuntimeException(e); - } + } + }, + DEFAULT_RETRY_TIMES); + } + + public static boolean delete(String url) { + return delete(url, null); + } + + public static boolean delete(String url, Map headers) { + LOGGER.info("delete url: {}", url); + OkHttpClient client = new OkHttpClient(); + Builder requestBuilder = getRequestBuilder(url, headers); + Request request = requestBuilder.delete().build(); + long t = System.currentTimeMillis(); + return RetryCommand.run( + new Callable() { + @Override + public Boolean call() throws Exception { + try (Response response = client.newCall(request).execute()) { + ResponseBody body = response.body(); + String msg = (body != null) ? body.string() : "{}"; + if (!response.isSuccessful()) { + throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(msg)); + } else { + LOGGER.info("delete {} cost {}ms", url, System.currentTimeMillis() - t); + return true; + } + } catch (IOException e) { + LOGGER.info("execute delete failed: {}", e.getCause(), e); + throw new GeaflowRuntimeException(e); } - }, DEFAULT_RETRY_TIMES); - } - - private static Builder getRequestBuilder(String url, Map headers) { - Builder requestBuilder = new Request.Builder().url(url); - if (headers != null) { - Headers requestHeaders = Headers.of(headers); - requestBuilder.headers(requestHeaders); - } - return requestBuilder; + } + }, + DEFAULT_RETRY_TIMES); + } + + private static Builder getRequestBuilder(String url, Map headers) { + Builder requestBuilder = new Request.Builder().url(url); + if (headers != null) { + Headers requestHeaders = Headers.of(headers); + requestBuilder.headers(requestHeaders); } - + return requestBuilder; + } } diff --git a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/JsonUtils.java b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/JsonUtils.java index ba68e700c..62f615fb9 100644 --- a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/JsonUtils.java +++ b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/JsonUtils.java @@ -19,46 +19,47 @@ package org.apache.geaflow.utils; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.ContainerNode; import java.io.IOException; import java.util.HashMap; import java.util.Iterator; import java.util.Map; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ContainerNode; + public class JsonUtils { - public static final ObjectMapper MAPPER = new ObjectMapper(); + public static final ObjectMapper MAPPER = new ObjectMapper(); - public static Map parseJson2map(String str) { - try { - JsonNode jsonNode = MAPPER.readTree(str); - Map map = new HashMap<>(); - Iterator fieldNames = jsonNode.fieldNames(); - while (fieldNames.hasNext()) { - String key = fieldNames.next(); - JsonNode value = jsonNode.get(key); - if (value instanceof ContainerNode) { - map.put(key, value.toString()); - } else { - map.put(key, value.asText()); - } - } - return map; - } catch (IOException e) { - throw new GeaflowRuntimeException(e); + public static Map parseJson2map(String str) { + try { + JsonNode jsonNode = MAPPER.readTree(str); + Map map = new HashMap<>(); + Iterator fieldNames = jsonNode.fieldNames(); + while (fieldNames.hasNext()) { + String key = fieldNames.next(); + JsonNode value = jsonNode.get(key); + if (value instanceof ContainerNode) { + map.put(key, value.toString()); + } else { + map.put(key, value.asText()); } + } + return map; + } catch (IOException e) { + throw new GeaflowRuntimeException(e); } + } - public static String toJsonString(Object object) { - try { - return MAPPER.writeValueAsString(object); - } catch (JsonProcessingException e) { - throw new GeaflowRuntimeException(e); - } + public static String toJsonString(Object object) { + try { + return MAPPER.writeValueAsString(object); + } catch (JsonProcessingException e) { + throw new GeaflowRuntimeException(e); } - + } } diff --git a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/NetworkUtil.java b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/NetworkUtil.java index cc0a9d908..b2fa0e7c3 100644 --- a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/NetworkUtil.java +++ b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/NetworkUtil.java @@ -25,12 +25,11 @@ public class NetworkUtil { - public static void checkServiceAvailable(String hostName, int port, int connectTimeout) - throws IOException { - try (Socket socket = new Socket()) { - InetSocketAddress socketAddress = new InetSocketAddress(hostName, port); - socket.connect(socketAddress, connectTimeout); - } + public static void checkServiceAvailable(String hostName, int port, int connectTimeout) + throws IOException { + try (Socket socket = new Socket()) { + InetSocketAddress socketAddress = new InetSocketAddress(hostName, port); + socket.connect(socketAddress, connectTimeout); } - + } } diff --git a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/TicToc.java b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/TicToc.java index ce871d636..f81fe2c60 100644 --- a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/TicToc.java +++ b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/TicToc.java @@ -23,31 +23,30 @@ public class TicToc implements Serializable { - private long start = 0L; - - public void tic() { - this.start = System.currentTimeMillis(); - } - - public long toc() { - long end = System.currentTimeMillis(); - long duration = end - this.start; - this.start = end; - return duration; - } - - public void ticNano() { - start = System.nanoTime(); - } - - public long tocNano() { - long end = System.nanoTime(); - long duration = end - this.start; - if (duration < 0) { - return -1; - } - this.start = end; - return duration; + private long start = 0L; + + public void tic() { + this.start = System.currentTimeMillis(); + } + + public long toc() { + long end = System.currentTimeMillis(); + long duration = end - this.start; + this.start = end; + return duration; + } + + public void ticNano() { + start = System.nanoTime(); + } + + public long tocNano() { + long end = System.nanoTime(); + long duration = end - this.start; + if (duration < 0) { + return -1; } - + this.start = end; + return duration; + } } diff --git a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/client/HttpResponse.java b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/client/HttpResponse.java index 747514335..da9735225 100644 --- a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/client/HttpResponse.java +++ b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/client/HttpResponse.java @@ -19,54 +19,55 @@ package org.apache.geaflow.utils.client; -import com.google.gson.JsonElement; import java.io.Serializable; +import com.google.gson.JsonElement; + public class HttpResponse implements Serializable { - private String code; - private String message; - private String host; - private boolean success; - private JsonElement data; + private String code; + private String message; + private String host; + private boolean success; + private JsonElement data; - public String getCode() { - return code; - } + public String getCode() { + return code; + } - public void setCode(String code) { - this.code = code; - } + public void setCode(String code) { + this.code = code; + } - public String getMessage() { - return message; - } + public String getMessage() { + return message; + } - public void setMessage(String message) { - this.message = message; - } + public void setMessage(String message) { + this.message = message; + } - public String getHost() { - return host; - } + public String getHost() { + return host; + } - public void setHost(String host) { - this.host = host; - } + public void setHost(String host) { + this.host = host; + } - public JsonElement getData() { - return data; - } + public JsonElement getData() { + return data; + } - public void setData(JsonElement data) { - this.data = data; - } + public void setData(JsonElement data) { + this.data = data; + } - public boolean isSuccess() { - return success; - } + public boolean isSuccess() { + return success; + } - public void setSuccess(boolean success) { - this.success = success; - } + public void setSuccess(boolean success) { + this.success = success; + } } diff --git a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/DefaultKeyGroupAssigner.java b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/DefaultKeyGroupAssigner.java index 5f5c2fdad..176c61b3d 100644 --- a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/DefaultKeyGroupAssigner.java +++ b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/DefaultKeyGroupAssigner.java @@ -21,22 +21,22 @@ public class DefaultKeyGroupAssigner implements IKeyGroupAssigner { - private int maxPara; + private int maxPara; - public DefaultKeyGroupAssigner(int maxPara) { - this.maxPara = maxPara; - } + public DefaultKeyGroupAssigner(int maxPara) { + this.maxPara = maxPara; + } - @Override - public int getKeyGroupNumber() { - return this.maxPara; - } + @Override + public int getKeyGroupNumber() { + return this.maxPara; + } - @Override - public int assign(Object key) { - if (key == null) { - return -1; - } - return KeyGroupAssignment.assignToKeyGroup(key, maxPara); + @Override + public int assign(Object key) { + if (key == null) { + return -1; } + return KeyGroupAssignment.assignToKeyGroup(key, maxPara); + } } diff --git a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/IKeyGroupAssigner.java b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/IKeyGroupAssigner.java index c02b7e813..ca3d735e0 100644 --- a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/IKeyGroupAssigner.java +++ b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/IKeyGroupAssigner.java @@ -21,7 +21,7 @@ public interface IKeyGroupAssigner { - int getKeyGroupNumber(); + int getKeyGroupNumber(); - int assign(Object key); + int assign(Object key); } diff --git a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/KeyGroup.java b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/KeyGroup.java index e7da0c648..e66a3578f 100644 --- a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/KeyGroup.java +++ b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/KeyGroup.java @@ -19,70 +19,68 @@ package org.apache.geaflow.utils.keygroup; -import com.google.common.base.Preconditions; import java.io.Serializable; import java.util.Objects; +import com.google.common.base.Preconditions; + public class KeyGroup implements Serializable { - private final int startKeyGroup; - private final int endKeyGroup; + private final int startKeyGroup; + private final int endKeyGroup; - /** - * Defines the range [startKeyGroup, endKeyGroup]. - * - * @param startKeyGroup start of the range (inclusive) - * @param endKeyGroup end of the range (inclusive) - */ - public KeyGroup(int startKeyGroup, int endKeyGroup) { - Preconditions.checkArgument(startKeyGroup >= 0); - Preconditions.checkArgument(startKeyGroup <= endKeyGroup); - this.startKeyGroup = startKeyGroup; - this.endKeyGroup = endKeyGroup; - Preconditions.checkArgument(getNumberOfKeyGroups() >= 0, "Potential overflow detected."); - } + /** + * Defines the range [startKeyGroup, endKeyGroup]. + * + * @param startKeyGroup start of the range (inclusive) + * @param endKeyGroup end of the range (inclusive) + */ + public KeyGroup(int startKeyGroup, int endKeyGroup) { + Preconditions.checkArgument(startKeyGroup >= 0); + Preconditions.checkArgument(startKeyGroup <= endKeyGroup); + this.startKeyGroup = startKeyGroup; + this.endKeyGroup = endKeyGroup; + Preconditions.checkArgument(getNumberOfKeyGroups() >= 0, "Potential overflow detected."); + } - /** - * Get the number of key-groups in the range. - */ - public int getNumberOfKeyGroups() { - return 1 + endKeyGroup - startKeyGroup; - } + /** Get the number of key-groups in the range. */ + public int getNumberOfKeyGroups() { + return 1 + endKeyGroup - startKeyGroup; + } - public int getStartKeyGroup() { - return startKeyGroup; - } + public int getStartKeyGroup() { + return startKeyGroup; + } - public int getEndKeyGroup() { - return endKeyGroup; - } + public int getEndKeyGroup() { + return endKeyGroup; + } - public boolean contains(KeyGroup other) { - return this.startKeyGroup <= other.startKeyGroup && this.endKeyGroup >= other.endKeyGroup; - } + public boolean contains(KeyGroup other) { + return this.startKeyGroup <= other.startKeyGroup && this.endKeyGroup >= other.endKeyGroup; + } - public boolean contains(int keyGroupId) { - return this.startKeyGroup <= keyGroupId && this.endKeyGroup >= keyGroupId; - } + public boolean contains(int keyGroupId) { + return this.startKeyGroup <= keyGroupId && this.endKeyGroup >= keyGroupId; + } - @Override - public int hashCode() { - return Objects.hash(startKeyGroup, endKeyGroup); - } + @Override + public int hashCode() { + return Objects.hash(startKeyGroup, endKeyGroup); + } - @Override - public boolean equals(Object obj) { - if (!(obj instanceof KeyGroup)) { - return false; - } - KeyGroup keyGroup = (KeyGroup) obj; - return this.startKeyGroup == keyGroup.getStartKeyGroup() && this.endKeyGroup == keyGroup - .getEndKeyGroup(); + @Override + public boolean equals(Object obj) { + if (!(obj instanceof KeyGroup)) { + return false; } + KeyGroup keyGroup = (KeyGroup) obj; + return this.startKeyGroup == keyGroup.getStartKeyGroup() + && this.endKeyGroup == keyGroup.getEndKeyGroup(); + } - @Override - public String toString() { - return "KeyGroup{" + "startKeyGroup=" + startKeyGroup + ", endKeyGroup=" + endKeyGroup - + '}'; - } + @Override + public String toString() { + return "KeyGroup{" + "startKeyGroup=" + startKeyGroup + ", endKeyGroup=" + endKeyGroup + '}'; + } } diff --git a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/KeyGroupAssignerFactory.java b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/KeyGroupAssignerFactory.java index b919595c4..19140d822 100644 --- a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/KeyGroupAssignerFactory.java +++ b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/KeyGroupAssignerFactory.java @@ -21,12 +21,12 @@ public class KeyGroupAssignerFactory { - public static IKeyGroupAssigner createKeyGroupAssigner(KeyGroup keyGroup, int taskIndex, - int maxPara) { - if (keyGroup.getNumberOfKeyGroups() == 1 && keyGroup.getStartKeyGroup() == taskIndex) { - return new SingleKeyGroupAssigner(keyGroup.getStartKeyGroup()); - } else { - return new DefaultKeyGroupAssigner(maxPara); - } + public static IKeyGroupAssigner createKeyGroupAssigner( + KeyGroup keyGroup, int taskIndex, int maxPara) { + if (keyGroup.getNumberOfKeyGroups() == 1 && keyGroup.getStartKeyGroup() == taskIndex) { + return new SingleKeyGroupAssigner(keyGroup.getStartKeyGroup()); + } else { + return new DefaultKeyGroupAssigner(maxPara); } + } } diff --git a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/KeyGroupAssignment.java b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/KeyGroupAssignment.java index 6d59ec0b2..c8f2d6e40 100644 --- a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/KeyGroupAssignment.java +++ b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/KeyGroupAssignment.java @@ -19,14 +19,16 @@ package org.apache.geaflow.utils.keygroup; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; + import org.apache.geaflow.utils.math.MathUtil; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + /* This file is based on source code from the Flink Project (http://flink.apache.org/), licensed by the Apache * Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for * additional information regarding copyright ownership. */ @@ -36,96 +38,106 @@ */ public final class KeyGroupAssignment { - /** - * Computes the range of key-groups that are assigned to a given operator under the given - * parallelism and maximum parallelism. - * @param maxParallelism Maximal parallelism that the job was initially created with. - * @param parallelism The current parallelism under which the job runs. Must be <= - * maxParallelism. - * @param index Id of a key-group. 0 <= keyGroupID < maxParallelism. - */ - public static KeyGroup computeKeyGroupRangeForOperatorIndex(int maxParallelism, int parallelism, - int index) { - Preconditions.checkArgument(maxParallelism > 0, "maxParallelism should be > 0"); - if (parallelism > maxParallelism) { - throw new IllegalArgumentException("Maximum parallelism " + maxParallelism + " must " - + "not be smaller than parallelism " + parallelism); - } - int start = index == 0 ? 0 : ((index * maxParallelism - 1) / parallelism) + 1; - int end = ((index + 1) * maxParallelism - 1) / parallelism; - return new KeyGroup(start, end); + /** + * Computes the range of key-groups that are assigned to a given operator under the given + * parallelism and maximum parallelism. + * + * @param maxParallelism Maximal parallelism that the job was initially created with. + * @param parallelism The current parallelism under which the job runs. Must be <= maxParallelism. + * @param index Id of a key-group. 0 <= keyGroupID < maxParallelism. + */ + public static KeyGroup computeKeyGroupRangeForOperatorIndex( + int maxParallelism, int parallelism, int index) { + Preconditions.checkArgument(maxParallelism > 0, "maxParallelism should be > 0"); + if (parallelism > maxParallelism) { + throw new IllegalArgumentException( + "Maximum parallelism " + + maxParallelism + + " must " + + "not be smaller than parallelism " + + parallelism); } + int start = index == 0 ? 0 : ((index * maxParallelism - 1) / parallelism) + 1; + int end = ((index + 1) * maxParallelism - 1) / parallelism; + return new KeyGroup(start, end); + } - /** - * Assigns the given key to a parallel operator index. - * @param key the key to assign - * @param maxParallelism the maximum supported parallelism, aka the number of key-groups. - * @param parallelism the current parallelism of the operator - * @return the index of the parallel operator to which the given key should be routed. - */ - public static int assignKeyToParallelTask(Object key, int maxParallelism, int parallelism) { - return computeTaskIndexForKeyGroup(maxParallelism, parallelism, - assignToKeyGroup(key, maxParallelism)); - } + /** + * Assigns the given key to a parallel operator index. + * + * @param key the key to assign + * @param maxParallelism the maximum supported parallelism, aka the number of key-groups. + * @param parallelism the current parallelism of the operator + * @return the index of the parallel operator to which the given key should be routed. + */ + public static int assignKeyToParallelTask(Object key, int maxParallelism, int parallelism) { + return computeTaskIndexForKeyGroup( + maxParallelism, parallelism, assignToKeyGroup(key, maxParallelism)); + } - /** - * Computes the index of the operator to which a key-group belongs under the given parallelism - * and maximum parallelism. - * IMPORTANT: maxParallelism must be <= Short.MAX_VALUE to avoid rounding problems in this - * method. If we ever want - * to go beyond this boundary, this method must perform arithmetic on long values. - * @param maxParallelism Maximal parallelism that the job was initially created with. - * 0 < parallelism <= maxParallelism <= Short.MAX_VALUE must hold. - * @param parallelism The current parallelism under which the job runs. Must be <= - * maxParallelism. - * @param keyGroupId Id of a key-group. 0 <= keyGroupID < maxParallelism. - * @return The index of the operator to which elements from the given key-group should be routed - * under the given parallelism and maxParallelism. - */ - public static int computeTaskIndexForKeyGroup(int maxParallelism, int parallelism, - int keyGroupId) { - Preconditions.checkArgument(maxParallelism > 0, "maxParallelism should be > 0"); - if (parallelism > maxParallelism) { - throw new IllegalArgumentException("Maximum parallelism " + maxParallelism + " must " - + "not be smaller than parallelism " + parallelism); - } - return keyGroupId * parallelism / maxParallelism; + /** + * Computes the index of the operator to which a key-group belongs under the given parallelism and + * maximum parallelism. IMPORTANT: maxParallelism must be <= Short.MAX_VALUE to avoid rounding + * problems in this method. If we ever want to go beyond this boundary, this method must perform + * arithmetic on long values. + * + * @param maxParallelism Maximal parallelism that the job was initially created with. 0 < + * parallelism <= maxParallelism <= Short.MAX_VALUE must hold. + * @param parallelism The current parallelism under which the job runs. Must be <= maxParallelism. + * @param keyGroupId Id of a key-group. 0 <= keyGroupID < maxParallelism. + * @return The index of the operator to which elements from the given key-group should be routed + * under the given parallelism and maxParallelism. + */ + public static int computeTaskIndexForKeyGroup( + int maxParallelism, int parallelism, int keyGroupId) { + Preconditions.checkArgument(maxParallelism > 0, "maxParallelism should be > 0"); + if (parallelism > maxParallelism) { + throw new IllegalArgumentException( + "Maximum parallelism " + + maxParallelism + + " must " + + "not be smaller than parallelism " + + parallelism); } + return keyGroupId * parallelism / maxParallelism; + } - /** - * Assigns the given key to a key-group index. - * @param key the key to assign - * @param maxParallelism the maximum supported parallelism, aka the number of key-groups. - * @return the key-group to which the given key is assigned - */ - public static int assignToKeyGroup(Object key, int maxParallelism) { - return computeKeyGroupForKeyHash(key.hashCode(), maxParallelism); - } + /** + * Assigns the given key to a key-group index. + * + * @param key the key to assign + * @param maxParallelism the maximum supported parallelism, aka the number of key-groups. + * @return the key-group to which the given key is assigned + */ + public static int assignToKeyGroup(Object key, int maxParallelism) { + return computeKeyGroupForKeyHash(key.hashCode(), maxParallelism); + } - /** - * Assigns the given key to a key-group index. - * @param keyHash the hash of the key to assign - * @param maxParallelism the maximum supported parallelism, aka the number of key-groups. - * @return the key-group to which the given key is assigned - */ - public static int computeKeyGroupForKeyHash(int keyHash, int maxParallelism) { - // we can rehash keyHash - return MathUtil.murmurHash(keyHash) % maxParallelism; - } + /** + * Assigns the given key to a key-group index. + * + * @param keyHash the hash of the key to assign + * @param maxParallelism the maximum supported parallelism, aka the number of key-groups. + * @return the key-group to which the given key is assigned + */ + public static int computeKeyGroupForKeyHash(int keyHash, int maxParallelism) { + // we can rehash keyHash + return MathUtil.murmurHash(keyHash) % maxParallelism; + } - @VisibleForTesting - public static Map> computeKeyGroupToTask(int maxParallelism, - List targetTasks) { - Map> keyGroupToTask = new ConcurrentHashMap<>(); - for (int index = 0; index < targetTasks.size(); index++) { - KeyGroup taskKeyGroup = computeKeyGroupRangeForOperatorIndex(maxParallelism, - targetTasks.size(), index); - for (int groupId = taskKeyGroup.getStartKeyGroup(); - groupId <= taskKeyGroup.getEndKeyGroup(); groupId++) { - keyGroupToTask.put(groupId, ImmutableList.of(targetTasks.get(index))); - } - } - return keyGroupToTask; + @VisibleForTesting + public static Map> computeKeyGroupToTask( + int maxParallelism, List targetTasks) { + Map> keyGroupToTask = new ConcurrentHashMap<>(); + for (int index = 0; index < targetTasks.size(); index++) { + KeyGroup taskKeyGroup = + computeKeyGroupRangeForOperatorIndex(maxParallelism, targetTasks.size(), index); + for (int groupId = taskKeyGroup.getStartKeyGroup(); + groupId <= taskKeyGroup.getEndKeyGroup(); + groupId++) { + keyGroupToTask.put(groupId, ImmutableList.of(targetTasks.get(index))); + } } - + return keyGroupToTask; + } } diff --git a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/SingleKeyGroupAssigner.java b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/SingleKeyGroupAssigner.java index 52a590bd4..d2cc8c291 100644 --- a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/SingleKeyGroupAssigner.java +++ b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/keygroup/SingleKeyGroupAssigner.java @@ -21,20 +21,19 @@ public class SingleKeyGroupAssigner implements IKeyGroupAssigner { - private final int keyGroupId; + private final int keyGroupId; - public SingleKeyGroupAssigner(int keyGroupId) { - this.keyGroupId = keyGroupId; - } + public SingleKeyGroupAssigner(int keyGroupId) { + this.keyGroupId = keyGroupId; + } - @Override - public int getKeyGroupNumber() { - return 1; - } - - @Override - public int assign(Object key) { - return keyGroupId; - } + @Override + public int getKeyGroupNumber() { + return 1; + } + @Override + public int assign(Object key) { + return keyGroupId; + } } diff --git a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/math/MathUtil.java b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/math/MathUtil.java index aef6f3d13..3ecbe319e 100644 --- a/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/math/MathUtil.java +++ b/geaflow/geaflow-utils/src/main/java/org/apache/geaflow/utils/math/MathUtil.java @@ -25,318 +25,320 @@ public class MathUtil { - public static final long SECOND_IN_MS = 1000L; - public static final long MINUTE_IN_MS = 60L * SECOND_IN_MS; - public static final long HOUR_IN_MS = 60L * MINUTE_IN_MS; + public static final long SECOND_IN_MS = 1000L; + public static final long MINUTE_IN_MS = 60L * SECOND_IN_MS; + public static final long HOUR_IN_MS = 60L * MINUTE_IN_MS; + + public static final long MINUTE = 60L; + public static final long HOUR = 60L * MINUTE; + + public static final int MAXIMUM_CAPACITY = 1 << 30; + + private MathUtil() {} + + /** + * Check if the array has deviating elements. + * + *

Deviating elements are found by comparing each individual value against the average. + * + * @param values the array of values to check + * @param buffer the amount to ignore as a buffer for smaller valued lists + * @param factor the amount of allowed deviation is calculated from average * factor + * @return the index of the deviating value, or -1 if + */ + public static int[] deviates(long[] values, long buffer, double factor) { + if (values == null || values.length == 0) { + return new int[0]; + } - public static final long MINUTE = 60L; - public static final long HOUR = 60L * MINUTE; + long avg = average(values); - public static final int MAXIMUM_CAPACITY = 1 << 30; + // Find deviated elements + long minimumDiff = Math.max(buffer, (long) (avg * factor)); + List deviatedElements = new ArrayList(); - private MathUtil() { + for (int i = 0; i < values.length; i++) { + long diff = values[i] - avg; + if (diff > minimumDiff) { + deviatedElements.add(i); + } } - /** - * Check if the array has deviating elements. - *

- * Deviating elements are found by comparing each individual value against the average. - * @param values the array of values to check - * @param buffer the amount to ignore as a buffer for smaller valued lists - * @param factor the amount of allowed deviation is calculated from average * factor - * @return the index of the deviating value, or -1 if - */ - public static int[] deviates(long[] values, long buffer, double factor) { - if (values == null || values.length == 0) { - return new int[0]; - } - - long avg = average(values); - - // Find deviated elements - long minimumDiff = Math.max(buffer, (long) (avg * factor)); - List deviatedElements = new ArrayList(); - - for (int i = 0; i < values.length; i++) { - long diff = values[i] - avg; - if (diff > minimumDiff) { - deviatedElements.add(i); - } - } - - int[] result = new int[deviatedElements.size()]; - for (int i = 0; i < result.length; i++) { - result[i] = deviatedElements.get(i); - } - - return result; + int[] result = new int[deviatedElements.size()]; + for (int i = 0; i < result.length; i++) { + result[i] = deviatedElements.get(i); } - /** - * The percentile method returns the least value from the given list which has at least given - * percentile. - * @param values The list of values to find the percentile from - * @param percentile The percentile - * @return The least value from the list with at least the given percentile - */ - public static long percentile(List values, int percentile) { - - if (values.size() == 0) { - throw new IllegalArgumentException("Percentile of empty list is not defined."); - } - - if (percentile > 100 || percentile < 0) { - throw new IllegalArgumentException("Percentile has to be between 0-100"); - } - - if (percentile == 0) { - return 0; - } - - Collections.sort(values); - - // Use Nearest Rank method. - // https://en.wikipedia.org/wiki/Percentile#The_Nearest_Rank_method - int position = (int) Math.ceil(values.size() * percentile / 100.0); - - // should never happen. - if (position == 0) { - return values.get(position); - } - - // position is always one greater than index. Return value at the proper index - return values.get(position - 1); + return result; + } + + /** + * The percentile method returns the least value from the given list which has at least given + * percentile. + * + * @param values The list of values to find the percentile from + * @param percentile The percentile + * @return The least value from the list with at least the given percentile + */ + public static long percentile(List values, int percentile) { + + if (values.size() == 0) { + throw new IllegalArgumentException("Percentile of empty list is not defined."); } - /** - * This function hashes an integer value. - * - *

It is crucial to use different hash functions to partition data across machines and the - * internal partitioning of data structures. This hash function is intended for partitioning - * across machines. - * - * @param code The integer to be hashed. - * @return The non-negative hash code for the integer. - */ - public static int murmurHash(int code) { - code *= 0xcc9e2d51; - code = Integer.rotateLeft(code, 15); - code *= 0x1b873593; - - code = Integer.rotateLeft(code, 13); - code = code * 5 + 0xe6546b64; - - code ^= 4; - code = bitMix(code); - - if (code >= 0) { - return code; - } else if (code != Integer.MIN_VALUE) { - return -code; - } else { - return 0; - } + if (percentile > 100 || percentile < 0) { + throw new IllegalArgumentException("Percentile has to be between 0-100"); } - - public static long[][] findTwoGroups(long[] values) { - return findTwoGroupsRecursive(values, average(values), 2); + if (percentile == 0) { + return 0; } - public static long[][] findTwoGroupsRecursive(long[] values, long middle, int levels) { - if (levels > 0) { - long[][] result = twoMeans(values, middle); - long newMiddle = average(result[1]) - average(result[0]); - return findTwoGroupsRecursive(values, newMiddle, levels - 1); - } - return twoMeans(values, middle); - } + Collections.sort(values); - private static long[][] twoMeans(long[] values, long middle) { - List smaller = new ArrayList(); - List larger = new ArrayList(); - for (int i = 0; i < values.length; i++) { - if (values[i] < middle) { - smaller.add(values[i]); - } else { - larger.add(values[i]); - } - } - - long[][] result = new long[2][]; - result[0] = toArray(smaller); - result[1] = toArray(larger); - - return result; - } + // Use Nearest Rank method. + // https://en.wikipedia.org/wiki/Percentile#The_Nearest_Rank_method + int position = (int) Math.ceil(values.size() * percentile / 100.0); - private static long[] toArray(List input) { - long[] result = new long[input.size()]; - for (int i = 0; i < result.length; i++) { - result[i] = input.get(i); - } - return result; + // should never happen. + if (position == 0) { + return values.get(position); } - /** - * Compute average for the given array of long. - * @param values the values - * @return The average(values) - */ - public static long average(long[] values) { - //Find average - double sum = 0d; - for (long value : values) { - sum += value; - } - return (long) (sum / (double) values.length); + // position is always one greater than index. Return value at the proper index + return values.get(position - 1); + } + + /** + * This function hashes an integer value. + * + *

It is crucial to use different hash functions to partition data across machines and the + * internal partitioning of data structures. This hash function is intended for partitioning + * across machines. + * + * @param code The integer to be hashed. + * @return The non-negative hash code for the integer. + */ + public static int murmurHash(int code) { + code *= 0xcc9e2d51; + code = Integer.rotateLeft(code, 15); + code *= 0x1b873593; + + code = Integer.rotateLeft(code, 13); + code = code * 5 + 0xe6546b64; + + code ^= 4; + code = bitMix(code); + + if (code >= 0) { + return code; + } else if (code != Integer.MIN_VALUE) { + return -code; + } else { + return 0; } + } - /** - * Compute average for a List of long values. - * @param values the values - * @return The average(values) - */ - public static long average(List values) { - //Find average - double sum = 0d; - for (long value : values) { - sum += value; - } - return (long) (sum / (double) values.size()); - } + public static long[][] findTwoGroups(long[] values) { + return findTwoGroupsRecursive(values, average(values), 2); + } - /** - * Find the median of the given list. - * @param values The values - * @return The median(values) - */ - public static long median(List values) { - if (values.size() == 0) { - throw new IllegalArgumentException("Median of an empty list is not defined."); - } - Collections.sort(values); - int middle = values.size() / 2; - if (values.size() % 2 == 0) { - return (values.get(middle - 1) + values.get(middle)) / 2; - } else { - return values.get(middle); - } + public static long[][] findTwoGroupsRecursive(long[] values, long middle, int levels) { + if (levels > 0) { + long[][] result = twoMeans(values, middle); + long newMiddle = average(result[1]) - average(result[0]); + return findTwoGroupsRecursive(values, newMiddle, levels - 1); } - - /** - * Fast method of finding the next power of 2 greater than or equal to the supplied value. - * - *

- * If the value is {@code <= 0} then 1 will be returned. - * This method is not suitable for {@link Integer#MIN_VALUE} or numbers greater than 2^30. - * - * @param value from which to search for next power of 2 - * @return The next power of 2 or the value itself if it is a power of 2 - */ - public static int findNextPositivePowerOfTwo(final int value) { - assert value > Integer.MIN_VALUE && value < 0x40000000; - return 1 << (32 - Integer.numberOfLeadingZeros(value - 1)); + return twoMeans(values, middle); + } + + private static long[][] twoMeans(long[] values, long middle) { + List smaller = new ArrayList(); + List larger = new ArrayList(); + for (int i = 0; i < values.length; i++) { + if (values[i] < middle) { + smaller.add(values[i]); + } else { + larger.add(values[i]); + } } - /** - * Fast method of finding the next power of 2 greater than or equal to the supplied value. - *

- * This method will do runtime bounds checking and call {@link #findNextPositivePowerOfTwo(int)} if within a - * valid range. - * @param value from which to search for next power of 2 - * @return The next power of 2 or the value itself if it is a power of 2. - * Special cases for return values are as follows: - *

    - *
  • {@code <= 0} -> 1
  • - *
  • {@code >= 2^30} -> 2^30
  • - *
- */ - public static int safeFindNextPositivePowerOfTwo(final int value) { - return value <= 0 ? 1 : value >= 0x40000000 ? 0x40000000 : findNextPositivePowerOfTwo(value); - } + long[][] result = new long[2][]; + result[0] = toArray(smaller); + result[1] = toArray(larger); - public static boolean isPrime(int n) { - for (int i = 2; i * i <= n; ++i) { - if (n % i == 0) { - return false; - } - } - return true; - } + return result; + } - public static int nextPrime(int n) { - for (int num = n; num < n * 2; num++) { - if (isPrime(num)) { - return num; - } - } - return n; + private static long[] toArray(List input) { + long[] result = new long[input.size()]; + for (int i = 0; i < result.length; i++) { + result[i] = input.get(i); } - - /** - * Pseudo-randomly maps a long (64-bit) to an integer (32-bit) using some bit-mixing for better - * distribution. - * - * @param in the long (64-bit)input. - * @return the bit-mixed int (32-bit) output - */ - public static int longToIntWithBitMixing(long in) { - in = (in ^ (in >>> 30)) * 0xbf58476d1ce4e5b9L; - in = (in ^ (in >>> 27)) * 0x94d049bb133111ebL; - in = in ^ (in >>> 31); - return (int) in; + return result; + } + + /** + * Compute average for the given array of long. + * + * @param values the values + * @return The average(values) + */ + public static long average(long[] values) { + // Find average + double sum = 0d; + for (long value : values) { + sum += value; } - - // ============================================================================================ - - /** - * Bit-mixing for pseudo-randomization of integers (e.g., to guard against bad hash functions). - * Implementation is from Murmur's 32 bit finalizer. - * - * @param in the input value - * @return the bit-mixed output value - */ - public static int bitMix(int in) { - in ^= in >>> 16; - in *= 0x85ebca6b; - in ^= in >>> 13; - in *= 0xc2b2ae35; - in ^= in >>> 16; - return in; + return (long) (sum / (double) values.length); + } + + /** + * Compute average for a List of long values. + * + * @param values the values + * @return The average(values) + */ + public static long average(List values) { + // Find average + double sum = 0d; + for (long value : values) { + sum += value; } - - public static int multiplesOf50(int input) { - return input / 50 * 50; + return (long) (sum / (double) values.size()); + } + + /** + * Find the median of the given list. + * + * @param values The values + * @return The median(values) + */ + public static long median(List values) { + if (values.size() == 0) { + throw new IllegalArgumentException("Median of an empty list is not defined."); } - - /** - * Check whether input is power of two. - * - * @param input - * @return - */ - public static boolean isPowerOf2(int input) { - if (input > 0) { - return input == 1 || (input & (-input)) == input; - } + Collections.sort(values); + int middle = values.size() / 2; + if (values.size() % 2 == 0) { + return (values.get(middle - 1) + values.get(middle)) / 2; + } else { + return values.get(middle); + } + } + + /** + * Fast method of finding the next power of 2 greater than or equal to the supplied value. + * + *

If the value is {@code <= 0} then 1 will be returned. This method is not suitable for {@link + * Integer#MIN_VALUE} or numbers greater than 2^30. + * + * @param value from which to search for next power of 2 + * @return The next power of 2 or the value itself if it is a power of 2 + */ + public static int findNextPositivePowerOfTwo(final int value) { + assert value > Integer.MIN_VALUE && value < 0x40000000; + return 1 << (32 - Integer.numberOfLeadingZeros(value - 1)); + } + + /** + * Fast method of finding the next power of 2 greater than or equal to the supplied value. + * + *

This method will do runtime bounds checking and call {@link + * #findNextPositivePowerOfTwo(int)} if within a valid range. + * + * @param value from which to search for next power of 2 + * @return The next power of 2 or the value itself if it is a power of 2. Special cases for return + * values are as follows: + *

    + *
  • {@code <= 0} -> 1 + *
  • {@code >= 2^30} -> 2^30 + *
+ */ + public static int safeFindNextPositivePowerOfTwo(final int value) { + return value <= 0 ? 1 : value >= 0x40000000 ? 0x40000000 : findNextPositivePowerOfTwo(value); + } + + public static boolean isPrime(int n) { + for (int i = 2; i * i <= n; ++i) { + if (n % i == 0) { return false; + } } - - /** - * Compute the min num n power of two. - * - * @param input - * @return The min n power of two for input. - */ - public static int minPowerOf2(int input) { - int n = input - 1; - n |= n >>> 1; - n |= n >>> 2; - n |= n >>> 4; - n |= n >>> 8; - n |= n >>> 16; - return (n < 0) ? 1 : (n >= MAXIMUM_CAPACITY) ? MAXIMUM_CAPACITY : n + 1; + return true; + } + + public static int nextPrime(int n) { + for (int num = n; num < n * 2; num++) { + if (isPrime(num)) { + return num; + } } - + return n; + } + + /** + * Pseudo-randomly maps a long (64-bit) to an integer (32-bit) using some bit-mixing for better + * distribution. + * + * @param in the long (64-bit)input. + * @return the bit-mixed int (32-bit) output + */ + public static int longToIntWithBitMixing(long in) { + in = (in ^ (in >>> 30)) * 0xbf58476d1ce4e5b9L; + in = (in ^ (in >>> 27)) * 0x94d049bb133111ebL; + in = in ^ (in >>> 31); + return (int) in; + } + + // ============================================================================================ + + /** + * Bit-mixing for pseudo-randomization of integers (e.g., to guard against bad hash functions). + * Implementation is from Murmur's 32 bit finalizer. + * + * @param in the input value + * @return the bit-mixed output value + */ + public static int bitMix(int in) { + in ^= in >>> 16; + in *= 0x85ebca6b; + in ^= in >>> 13; + in *= 0xc2b2ae35; + in ^= in >>> 16; + return in; + } + + public static int multiplesOf50(int input) { + return input / 50 * 50; + } + + /** + * Check whether input is power of two. + * + * @param input + * @return + */ + public static boolean isPowerOf2(int input) { + if (input > 0) { + return input == 1 || (input & (-input)) == input; + } + return false; + } + + /** + * Compute the min num n power of two. + * + * @param input + * @return The min n power of two for input. + */ + public static int minPowerOf2(int input) { + int n = input - 1; + n |= n >>> 1; + n |= n >>> 2; + n |= n >>> 4; + n |= n >>> 8; + n |= n >>> 16; + return (n < 0) ? 1 : (n >= MAXIMUM_CAPACITY) ? MAXIMUM_CAPACITY : n + 1; + } } diff --git a/geaflow/geaflow-utils/src/test/java/org/apache/geaflow/utils/BlockingListBenchmark.java b/geaflow/geaflow-utils/src/test/java/org/apache/geaflow/utils/BlockingListBenchmark.java index 41f8180a8..473d83c37 100644 --- a/geaflow/geaflow-utils/src/test/java/org/apache/geaflow/utils/BlockingListBenchmark.java +++ b/geaflow/geaflow-utils/src/test/java/org/apache/geaflow/utils/BlockingListBenchmark.java @@ -22,6 +22,7 @@ import java.util.LinkedList; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; + import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -44,32 +45,28 @@ @State(Scope.Benchmark) public class BlockingListBenchmark { - private static final String TEST_STR = "test"; - - @Benchmark - public void testQueue() { - LinkedList queue = new LinkedList<>(); - for (int i = 0; i < 10_000_000; i++) { - queue.offer(TEST_STR); - queue.poll(); - } - } - + private static final String TEST_STR = "test"; - @Benchmark - public void testBlockingQueue() throws Exception { - LinkedBlockingQueue queue = new LinkedBlockingQueue<>(); - for (int i = 0; i < 10_000_000; i++) { - queue.offer(TEST_STR); - queue.poll(); - } + @Benchmark + public void testQueue() { + LinkedList queue = new LinkedList<>(); + for (int i = 0; i < 10_000_000; i++) { + queue.offer(TEST_STR); + queue.poll(); } + } - public static void main(String[] args) throws RunnerException { - Options opt = new OptionsBuilder() - .include(BlockingListBenchmark.class.getSimpleName()) - .build(); - new Runner(opt).run(); + @Benchmark + public void testBlockingQueue() throws Exception { + LinkedBlockingQueue queue = new LinkedBlockingQueue<>(); + for (int i = 0; i < 10_000_000; i++) { + queue.offer(TEST_STR); + queue.poll(); } + } + public static void main(String[] args) throws RunnerException { + Options opt = new OptionsBuilder().include(BlockingListBenchmark.class.getSimpleName()).build(); + new Runner(opt).run(); + } } diff --git a/geaflow/geaflow-utils/src/test/java/org/apache/geaflow/utils/HttpUtilTest.java b/geaflow/geaflow-utils/src/test/java/org/apache/geaflow/utils/HttpUtilTest.java index 10108c7ec..227b4aa82 100644 --- a/geaflow/geaflow-utils/src/test/java/org/apache/geaflow/utils/HttpUtilTest.java +++ b/geaflow/geaflow-utils/src/test/java/org/apache/geaflow/utils/HttpUtilTest.java @@ -23,72 +23,74 @@ import java.net.URI; import java.util.HashMap; import java.util.Map; -import okhttp3.mockwebserver.MockResponse; -import okhttp3.mockwebserver.MockWebServer; -import okhttp3.mockwebserver.RecordedRequest; + import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.testng.Assert; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; import org.testng.annotations.Test; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; + public class HttpUtilTest { - MockWebServer server; - String baseUrl; + MockWebServer server; + String baseUrl; - @BeforeTest - public void prepare() throws IOException { - // Create a MockWebServer. - server = new MockWebServer(); - // Start the server. - server.start(); - baseUrl = "http://" + server.getHostName() + ":" + server.getPort(); - } + @BeforeTest + public void prepare() throws IOException { + // Create a MockWebServer. + server = new MockWebServer(); + // Start the server. + server.start(); + baseUrl = "http://" + server.getHostName() + ":" + server.getPort(); + } - @AfterTest - public void tearUp() throws IOException { - // Shut down the server. Instances cannot be reused. - server.shutdown(); - } + @AfterTest + public void tearUp() throws IOException { + // Shut down the server. Instances cannot be reused. + server.shutdown(); + } - @Test - public void test() throws InterruptedException { - // Schedule some responses. - server.enqueue(new MockResponse().setBody("{key:value,success:true}")); - server.enqueue(new MockResponse().setBody("{delete:true,success:true}")); + @Test + public void test() throws InterruptedException { + // Schedule some responses. + server.enqueue(new MockResponse().setBody("{key:value,success:true}")); + server.enqueue(new MockResponse().setBody("{delete:true,success:true}")); - // Ask the server for its URL. You'll need this to make HTTP requests. - String url = URI.create(baseUrl).resolve("/v1/cluster").toString(); - Object result = HttpUtil.post(url, "{}"); - Assert.assertNull(result); + // Ask the server for its URL. You'll need this to make HTTP requests. + String url = URI.create(baseUrl).resolve("/v1/cluster").toString(); + Object result = HttpUtil.post(url, "{}"); + Assert.assertNull(result); - // confirm that your app made the HTTP requests you were expecting. - RecordedRequest request1 = server.takeRequest(); - Assert.assertEquals("/v1/cluster", request1.getPath()); + // confirm that your app made the HTTP requests you were expecting. + RecordedRequest request1 = server.takeRequest(); + Assert.assertEquals("/v1/cluster", request1.getPath()); - Map headers = new HashMap<>(); - headers.put("header1", "value"); - Object result2 = HttpUtil.delete(url, headers); - Assert.assertNotNull(result2); + Map headers = new HashMap<>(); + headers.put("header1", "value"); + Object result2 = HttpUtil.delete(url, headers); + Assert.assertNotNull(result2); - // confirm that your app made the HTTP requests you were expecting. - RecordedRequest request2 = server.takeRequest(); - Assert.assertEquals("/v1/cluster", request2.getPath()); - Assert.assertEquals("value", request2.getHeaders().get("header1")); - } + // confirm that your app made the HTTP requests you were expecting. + RecordedRequest request2 = server.takeRequest(); + Assert.assertEquals("/v1/cluster", request2.getPath()); + Assert.assertEquals("value", request2.getHeaders().get("header1")); + } - @Test(expectedExceptions = GeaflowRuntimeException.class) - public void testPostInvalid() { - server.enqueue(new MockResponse().setResponseCode(500)); - String invalidUrl = URI.create(baseUrl).resolve("/invalid").toString(); - HttpUtil.post(invalidUrl, "{}"); - } + @Test(expectedExceptions = GeaflowRuntimeException.class) + public void testPostInvalid() { + server.enqueue(new MockResponse().setResponseCode(500)); + String invalidUrl = URI.create(baseUrl).resolve("/invalid").toString(); + HttpUtil.post(invalidUrl, "{}"); + } - @Test(expectedExceptions = GeaflowRuntimeException.class) - public void testDeleteInvalid() { - server.enqueue(new MockResponse().setResponseCode(500)); - String invalidUrl = URI.create(baseUrl).resolve("/invalid").toString(); - HttpUtil.delete(invalidUrl); - } + @Test(expectedExceptions = GeaflowRuntimeException.class) + public void testDeleteInvalid() { + server.enqueue(new MockResponse().setResponseCode(500)); + String invalidUrl = URI.create(baseUrl).resolve("/invalid").toString(); + HttpUtil.delete(invalidUrl); + } } diff --git a/geaflow/geaflow-utils/src/test/java/org/apache/geaflow/utils/TicTocBenchmark.java b/geaflow/geaflow-utils/src/test/java/org/apache/geaflow/utils/TicTocBenchmark.java index 9e05d982a..4b22622e9 100644 --- a/geaflow/geaflow-utils/src/test/java/org/apache/geaflow/utils/TicTocBenchmark.java +++ b/geaflow/geaflow-utils/src/test/java/org/apache/geaflow/utils/TicTocBenchmark.java @@ -20,6 +20,7 @@ package org.apache.geaflow.utils; import java.util.concurrent.TimeUnit; + import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -42,29 +43,26 @@ @State(Scope.Benchmark) public class TicTocBenchmark { - private final TicToc ticToc = new TicToc(); - - @Benchmark - public void testMs() { - for (int i = 0; i < 10_000_000; i++) { - this.ticToc.tic(); - this.ticToc.toc(); - } - } + private final TicToc ticToc = new TicToc(); - @Benchmark - public void testNs() { - for (int i = 0; i < 10_000_000; i++) { - this.ticToc.ticNano(); - this.ticToc.tocNano(); - } + @Benchmark + public void testMs() { + for (int i = 0; i < 10_000_000; i++) { + this.ticToc.tic(); + this.ticToc.toc(); } + } - public static void main(String[] args) throws RunnerException { - Options opt = new OptionsBuilder() - .include(TicTocBenchmark.class.getSimpleName()) - .build(); - new Runner(opt).run(); + @Benchmark + public void testNs() { + for (int i = 0; i < 10_000_000; i++) { + this.ticToc.ticNano(); + this.ticToc.tocNano(); } + } + public static void main(String[] args) throws RunnerException { + Options opt = new OptionsBuilder().include(TicTocBenchmark.class.getSimpleName()).build(); + new Runner(opt).run(); + } } diff --git a/geaflow/geaflow-utils/src/test/java/org/apache/geaflow/utils/keygroup/KeyGroupTest.java b/geaflow/geaflow-utils/src/test/java/org/apache/geaflow/utils/keygroup/KeyGroupTest.java index 7685cdeec..7ce57e9e3 100644 --- a/geaflow/geaflow-utils/src/test/java/org/apache/geaflow/utils/keygroup/KeyGroupTest.java +++ b/geaflow/geaflow-utils/src/test/java/org/apache/geaflow/utils/keygroup/KeyGroupTest.java @@ -24,17 +24,16 @@ public class KeyGroupTest { - @Test - public void testKeyGroupContains() { - KeyGroup a = new KeyGroup(0, 100); - KeyGroup b = new KeyGroup(0, 1); - Assert.assertTrue(a.contains(b)); + @Test + public void testKeyGroupContains() { + KeyGroup a = new KeyGroup(0, 100); + KeyGroup b = new KeyGroup(0, 1); + Assert.assertTrue(a.contains(b)); - KeyGroup c = new KeyGroup(0, 100); - Assert.assertTrue(a.contains(c)); + KeyGroup c = new KeyGroup(0, 100); + Assert.assertTrue(a.contains(c)); - KeyGroup d = new KeyGroup(9, 101); - Assert.assertFalse(a.contains(d)); - } - -} \ No newline at end of file + KeyGroup d = new KeyGroup(9, 101); + Assert.assertFalse(a.contains(d)); + } +} diff --git a/geaflow/geaflow-utils/src/test/java/org/apache/geaflow/utils/math/MathUtilTest.java b/geaflow/geaflow-utils/src/test/java/org/apache/geaflow/utils/math/MathUtilTest.java index 4177d2a62..9871b27ae 100644 --- a/geaflow/geaflow-utils/src/test/java/org/apache/geaflow/utils/math/MathUtilTest.java +++ b/geaflow/geaflow-utils/src/test/java/org/apache/geaflow/utils/math/MathUtilTest.java @@ -21,161 +21,160 @@ import java.util.Arrays; import java.util.List; + import org.testng.Assert; import org.testng.annotations.Test; public class MathUtilTest { - @Test - public void testDeviates() { - } - - @Test - public void testPercentile() { - - List input = Arrays.asList(1L, 2L, 8L, 9L, 13L, 3L, 12L, 4L, 7L, 6L, 11L, 10L, 5L); - - long result = MathUtil.percentile(input, 0); - Assert.assertEquals(0, result); - - result = MathUtil.percentile(input, 50); - Assert.assertEquals(7, result); - - result = MathUtil.percentile(input, 90); - Assert.assertEquals(12, result); - } - - @Test - public void testFindTwoGroups() { - long[] input = new long[]{1L, 2L, 8L, 9L, 13L, 3L, 12L, 4L, 7L, 6L, 11L, 10L, 5L}; - long[][] result = MathUtil.findTwoGroups(input); - Assert.assertEquals(6, result[0].length); - Arrays.sort(result[0]); - Assert.assertEquals(new long[]{1L, 2L, 3L, 4L, 5L, 6L}, result[0]); - Arrays.sort(result[1]); - Assert.assertEquals(7, result[1].length); - Assert.assertEquals(new long[]{7L, 8L, 9L, 10L, 11L, 12L, 13L}, result[1]); - } - - @Test - public void testAverage() { - - long[] input = new long[]{1L, 2L, 8L, 9L, 13L, 3L, 12L, 4L, 7L, 6L, 11L, 10L, 5L}; - long result = MathUtil.average(input); - Assert.assertEquals(7, result); - - List array = Arrays.asList(1L, 2L, 8L, 9L, 13L, 3L, 12L, 4L, 7L, 6L, 11L, 10L, 5L); - result = MathUtil.average(array); - Assert.assertEquals(7, result); - } - - @Test - public void testMedian() { - List array = Arrays.asList(1L, 2L, 8L, 9L, 13L, 3L, 12L, 4L, 7L, 6L, 11L, 10L, 5L); - long result = MathUtil.median(array); - Assert.assertEquals(7, result); - } - - @Test - public void testFindNextPositivePowerOfTwo() { - int result = MathUtil.findNextPositivePowerOfTwo(10); - Assert.assertEquals(16, result); - } - - @Test - public void testSafeFindNextPositivePowerOfTwo() { - - int result = MathUtil.safeFindNextPositivePowerOfTwo(Integer.MAX_VALUE); - Assert.assertEquals(0x40000000, result); - } - - @Test - public void testIsPrime() { - boolean result = MathUtil.isPrime(7); - Assert.assertEquals(true, result); - - result = MathUtil.isPrime(10); - Assert.assertEquals(false, result); - } - - @Test - public void testNextPrime() { - int result = MathUtil.nextPrime(7); - Assert.assertEquals(7, result); - - result = MathUtil.nextPrime(10); - Assert.assertEquals(11, result); - } - - @Test - public void testLongToIntWithBitMixing() { - int result = MathUtil.longToIntWithBitMixing(10L); - Assert.assertEquals(-1456339591, result); - - } - - @Test - public void testBitMix() { - int result = MathUtil.bitMix(10); - Assert.assertEquals(-383449968, result); - } - - @Test - public void testMurmurHash() { - int code = 1; - Assert.assertEquals(68075478, MathUtil.murmurHash(code)); - code = -1; - Assert.assertEquals(1982413648, MathUtil.murmurHash(code)); - - Object objectCode = 1; - Assert.assertEquals(68075478, MathUtil.murmurHash(objectCode.hashCode())); - objectCode = -1; - Assert.assertEquals(1982413648, MathUtil.murmurHash(objectCode.hashCode())); - - int minCode = Integer.MIN_VALUE; - Assert.assertEquals(1718298732, MathUtil.murmurHash(minCode)); - - Object minHashCode = Integer.MIN_VALUE; - Assert.assertEquals(1718298732, MathUtil.murmurHash(minHashCode.hashCode())); - - int maxCode = Integer.MAX_VALUE; - Assert.assertEquals(1653689534, MathUtil.murmurHash(maxCode)); - - Object maxHashCode = Integer.MAX_VALUE; - Assert.assertEquals(1653689534, MathUtil.murmurHash(maxHashCode.hashCode())); - - Object StringCode = "hello"; - Assert.assertEquals(1715862179, MathUtil.murmurHash(StringCode.hashCode())); - } - - @Test - public void testIsPowerOf2() { - Assert.assertFalse(MathUtil.isPowerOf2(0)); - Assert.assertFalse(MathUtil.isPowerOf2(3)); - Assert.assertFalse(MathUtil.isPowerOf2(-3)); - Assert.assertTrue(MathUtil.isPowerOf2(1)); - Assert.assertTrue(MathUtil.isPowerOf2(4)); - Assert.assertFalse(MathUtil.isPowerOf2(-4)); - } - - @Test - public void testMinPowerOf2() { - int input = 1; - Assert.assertEquals(1, MathUtil.minPowerOf2(input)); - - input = 2; - Assert.assertEquals(2, MathUtil.minPowerOf2(input)); - - input = 3; - Assert.assertEquals(4, MathUtil.minPowerOf2(input)); - - input = 5; - Assert.assertEquals(8, MathUtil.minPowerOf2(input)); - input = 6; - Assert.assertEquals(8, MathUtil.minPowerOf2(input)); - input = 7; - Assert.assertEquals(8, MathUtil.minPowerOf2(input)); - input = 8; - Assert.assertEquals(8, MathUtil.minPowerOf2(input)); - } + @Test + public void testDeviates() {} + + @Test + public void testPercentile() { + + List input = Arrays.asList(1L, 2L, 8L, 9L, 13L, 3L, 12L, 4L, 7L, 6L, 11L, 10L, 5L); + + long result = MathUtil.percentile(input, 0); + Assert.assertEquals(0, result); + + result = MathUtil.percentile(input, 50); + Assert.assertEquals(7, result); + + result = MathUtil.percentile(input, 90); + Assert.assertEquals(12, result); + } + + @Test + public void testFindTwoGroups() { + long[] input = new long[] {1L, 2L, 8L, 9L, 13L, 3L, 12L, 4L, 7L, 6L, 11L, 10L, 5L}; + long[][] result = MathUtil.findTwoGroups(input); + Assert.assertEquals(6, result[0].length); + Arrays.sort(result[0]); + Assert.assertEquals(new long[] {1L, 2L, 3L, 4L, 5L, 6L}, result[0]); + Arrays.sort(result[1]); + Assert.assertEquals(7, result[1].length); + Assert.assertEquals(new long[] {7L, 8L, 9L, 10L, 11L, 12L, 13L}, result[1]); + } + + @Test + public void testAverage() { + + long[] input = new long[] {1L, 2L, 8L, 9L, 13L, 3L, 12L, 4L, 7L, 6L, 11L, 10L, 5L}; + long result = MathUtil.average(input); + Assert.assertEquals(7, result); + + List array = Arrays.asList(1L, 2L, 8L, 9L, 13L, 3L, 12L, 4L, 7L, 6L, 11L, 10L, 5L); + result = MathUtil.average(array); + Assert.assertEquals(7, result); + } + + @Test + public void testMedian() { + List array = Arrays.asList(1L, 2L, 8L, 9L, 13L, 3L, 12L, 4L, 7L, 6L, 11L, 10L, 5L); + long result = MathUtil.median(array); + Assert.assertEquals(7, result); + } + + @Test + public void testFindNextPositivePowerOfTwo() { + int result = MathUtil.findNextPositivePowerOfTwo(10); + Assert.assertEquals(16, result); + } + + @Test + public void testSafeFindNextPositivePowerOfTwo() { + + int result = MathUtil.safeFindNextPositivePowerOfTwo(Integer.MAX_VALUE); + Assert.assertEquals(0x40000000, result); + } + + @Test + public void testIsPrime() { + boolean result = MathUtil.isPrime(7); + Assert.assertEquals(true, result); + + result = MathUtil.isPrime(10); + Assert.assertEquals(false, result); + } + + @Test + public void testNextPrime() { + int result = MathUtil.nextPrime(7); + Assert.assertEquals(7, result); + + result = MathUtil.nextPrime(10); + Assert.assertEquals(11, result); + } + + @Test + public void testLongToIntWithBitMixing() { + int result = MathUtil.longToIntWithBitMixing(10L); + Assert.assertEquals(-1456339591, result); + } + + @Test + public void testBitMix() { + int result = MathUtil.bitMix(10); + Assert.assertEquals(-383449968, result); + } + + @Test + public void testMurmurHash() { + int code = 1; + Assert.assertEquals(68075478, MathUtil.murmurHash(code)); + code = -1; + Assert.assertEquals(1982413648, MathUtil.murmurHash(code)); + + Object objectCode = 1; + Assert.assertEquals(68075478, MathUtil.murmurHash(objectCode.hashCode())); + objectCode = -1; + Assert.assertEquals(1982413648, MathUtil.murmurHash(objectCode.hashCode())); + + int minCode = Integer.MIN_VALUE; + Assert.assertEquals(1718298732, MathUtil.murmurHash(minCode)); + + Object minHashCode = Integer.MIN_VALUE; + Assert.assertEquals(1718298732, MathUtil.murmurHash(minHashCode.hashCode())); + + int maxCode = Integer.MAX_VALUE; + Assert.assertEquals(1653689534, MathUtil.murmurHash(maxCode)); + + Object maxHashCode = Integer.MAX_VALUE; + Assert.assertEquals(1653689534, MathUtil.murmurHash(maxHashCode.hashCode())); + + Object StringCode = "hello"; + Assert.assertEquals(1715862179, MathUtil.murmurHash(StringCode.hashCode())); + } + + @Test + public void testIsPowerOf2() { + Assert.assertFalse(MathUtil.isPowerOf2(0)); + Assert.assertFalse(MathUtil.isPowerOf2(3)); + Assert.assertFalse(MathUtil.isPowerOf2(-3)); + Assert.assertTrue(MathUtil.isPowerOf2(1)); + Assert.assertTrue(MathUtil.isPowerOf2(4)); + Assert.assertFalse(MathUtil.isPowerOf2(-4)); + } + + @Test + public void testMinPowerOf2() { + int input = 1; + Assert.assertEquals(1, MathUtil.minPowerOf2(input)); + + input = 2; + Assert.assertEquals(2, MathUtil.minPowerOf2(input)); + + input = 3; + Assert.assertEquals(4, MathUtil.minPowerOf2(input)); + + input = 5; + Assert.assertEquals(8, MathUtil.minPowerOf2(input)); + input = 6; + Assert.assertEquals(8, MathUtil.minPowerOf2(input)); + input = 7; + Assert.assertEquals(8, MathUtil.minPowerOf2(input)); + input = 8; + Assert.assertEquals(8, MathUtil.minPowerOf2(input)); + } } diff --git a/pom.xml b/pom.xml index 66d19baee..76fba71e8 100644 --- a/pom.xml +++ b/pom.xml @@ -264,6 +264,43 @@ versions-maven-plugin 2.8.1 + + + com.diffplug.spotless + spotless-maven-plugin + 2.43.0 + + + + src/main/java/**/*.java + src/test/java/**/*.java + + + **/generated/**/* + **/proto/**/* + + + 1.17.0 + + true + true + + + java,javax,org,com, + + + + + + + + + check + + validate + + + org.apache.maven.plugins